評價此頁
fullgraph=True 識別和消除圖中斷>

使用 fullgraph=True 識別和消除圖中斷#

建立時間:2025 年 7 月 28 日 | 最後更新時間:2025 年 7 月 28 日

使用 torch.compile(fullgraph=False)(預設值)是開始使用 torch.compile 的好方法:它開箱即用地支援所有 Python 程式,方法是允許圖中斷,並在常見情況下提供良好的效能。

但是,如果您試圖從模型中獲得更多效能,您應該明確考慮哪些程式碼區域應該被編譯

  • 我們建議使用 torch.compile(fullgraph=True) 來查詢和消除程式碼中的圖中斷。

  • 如果您是庫開發者(或正在測試您的程式碼是否“可以”與 torch.compile 一起使用),我們建議使用 torch.compile(fullgraph=True) 進行測試。

torch.compile(fullgraph=True) 相比 fullgraph=False 提供了更強的保證:我們將始終捕獲一個單獨的 FX 圖進行編譯(如果由於圖中斷而無法編譯,則會出錯)。**特別是,您將被迫解決遇到的每一個圖中斷。**

解決圖中斷有多種策略。

策略 1:重寫不支援的程式碼,使其使用 Dynamo 支援的功能#

許多圖中斷錯誤訊息會提供一些關於如何重寫程式碼以避免圖中斷的建議。如果圖中斷仍然難以解決,請繼續進行下一策略,或在 PyTorch GitHub 倉庫 提交一個 issue。

更多圖中斷示例以及如何解決它們,請參見 常見的圖中斷

示例:Dynamo 不支援對被編譯函式的輸入 list_iterator 物件呼叫 next

@torch.compile(fullgraph=True)
def f(xs):
    a = next(xs)
    b = next(xs)
    return a + b

xs = [torch.tensor(1.), torch.tensor(2.)]
try:
    out = f(iter(xs))
except Exception as e:
    print(e)
Unsupported method call
  Explanation: Dynamo does not know how to trace method `__next__` of class `list_iterator`
  Hint: Avoid calling `list_iterator.__next__` in your code.
  Hint: Please report an issue to PyTorch.
  Hint: Dynamo does not fully support tracing builtin iterators (e.g. `map`, `zip`, `enumerate`) passed in from uncompiled to compiled regions (e.g. `torch.compile(fn)(enumerate(...))`). This can happen unintentionally if a previous graph break happens with a builtin iterator in the local scope.
  Hint: List/dict comprehensions in Python <= 3.11 result in implicit function calls, which Dynamo cannot trace as a top level frame. Possible workarounds are (1) use a loop instead of a comprehension, (2) fix any graph breaks in the function above the comprehension, (3) wrap the comprehension in a function, or (4) use Python 3.12+.

  Developer debug context: call_method UserDefinedObjectVariable(list_iterator) __next__ [] {}

 For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0156.html

from user code:
   File "/tmp/ipykernel_904/1195637716.py", line 3, in f
    a = next(xs)

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"

相反,重寫編譯函式以接受列表。

@torch.compile(fullgraph=True)
def f_rewritten(xs):
    it = iter(xs)
    a = next(it)
    b = next(it)
    return a + b

f_rewritten(xs)
tensor(3.)

策略 2:純函式始終可以透過逃生艙進行編譯#

**摘要**:所有 Python 函式的空間都非常廣闊,因此 Dynamo 無法在沒有圖中斷的情況下跟蹤每個 Python 函式。對於 Dynamo 無法在沒有圖中斷的情況下跟蹤的“純”Python 函式,我們提供了一些逃生艙來嘗試跟蹤這些函式。

  1. 對純 Triton 核心使用 custom_optriton_op

  2. 對僅使用 PyTorch Tensor 運算的純函式使用 nonstrict_trace

  3. 對所有其他純函式使用 custom_op

“純函式”是具有以下屬性的函式

  • 確定性。給定相同的輸入,純函式總是返回相同的輸出。

  • 無外部副作用。純函式沒有任何外部可見的副作用,例如修改外部狀態或執行 I/O 操作。函式內部的副作用是允許的(例如,突變中間張量)。一個值得注意的例外是,函式輸入張量上的 torch.* 運算的突變通常是允許的。

  • 顯式輸入/輸出。所有輸入資料都必須透過函式引數傳遞,並且所有輸出都從函式返回。

有關示例,請參見 純函式

理論上,Dynamo 能夠處理各種各樣的非純函式,但可能缺少對特定 Python 語言功能的覆蓋。然而,純函式總是可以透過逃生艙進行編譯。

如果您有圖中斷,可以將圍繞它的程式碼重構為純函式,並使用繞過 Dynamo 跟蹤的逃生艙。

  1. 如果您希望函式中的 Tensor 運算顯示在 Dynamo 輸出圖中(從而可以最佳化),請使用 torch._dynamo.nonstrict_tracenonstrict_trace 告訴 Dynamo 使用**非嚴格跟蹤**。

  2. 如果您希望函式相對於 torch.compile(包括前端 Dynamo 和後端)是“不透明”的,請使用自定義運算子。

請注意,沒有任何東西阻止這些逃生艙應用於非純函式,但**我們不提供任何健全性保證**。

示例:如果 Dynamo 不支援某些 Python 功能或 API(例如,它使用 PyTorch 運算),並且該功能是嚴格可跟蹤的,請 使用 torch._dynamo.nonstrict_trace 來捕獲它

# this is a function that Dynamo doesn't support (due to the graph_break() call).
def g(x):
    y = x.sin()
    torch._dynamo.graph_break()
    z = y.sin()
    return z

@torch.compile(fullgraph=True)
def f(x):
    w = x.sin()
    return g(w)

x = torch.randn(3)
try:
    f(x)  # Graph Break: there was a call to torch._dynamo.graph_break()
except Exception as e:
    print(e)

@torch.compile(fullgraph=True)
def f_rewritten(x):
    w = x.sin()
    return torch._dynamo.nonstrict_trace(g)(w)
f_rewritten(x)  # works
Call to `torch._dynamo.graph_break()`
  Explanation: User-inserted graph break. Message: None
  Hint: Remove the `torch._dynamo.graph_break()` call.

  Developer debug context: Called `torch._dynamo.graph_break()` with args `[]`, kwargs `{}`

 For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0025.html

from user code:
   File "/tmp/ipykernel_904/2422769198.py", line 11, in f
    return g(w)
  File "/tmp/ipykernel_904/2422769198.py", line 4, in g
    torch._dynamo.graph_break()

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
tensor([-0.0326, -0.7442, -0.2731])

示例:使用 自定義運算子 來建立相對於 torch.compile 不透明的函式。

from torch.utils.cpp_extension import load_inline

# C++ source code for the square operation
cpp_source = """
torch::Tensor square_cpu(torch::Tensor input) {
    // Check that input is a CPU tensor
    TORCH_CHECK(input.device().is_cpu(), "Input must be a CPU tensor");

    // Create output tensor with same shape and dtype as input
    torch::Tensor output = torch::empty_like(input);

    // Get data pointers
    float* input_data = input.data_ptr<float>();
    float* output_data = output.data_ptr<float>();

    // Get total number of elements
    int64_t numel = input.numel();

    // For loop to compute square of each element
    for (int64_t i = 0; i < numel; i++) {
        output_data[i] = input_data[i] * input_data[i];
    }

    return output;
}
"""

# Load the extension inline
square_module = load_inline(
    name="square_cpu_kernel",
    cpp_sources=cpp_source,
    functions=["square_cpu"],
    verbose=True
)

def square(x):
    return square_module.square_cpu(x)

@torch.compile(fullgraph=True)
def f(x):
    return square(x)

try:
    f(torch.randn(3, 3))  # graph break
except Exception as e:
    print(e)
[1/2] c++ -MMD -MF main.o.d -DTORCH_EXTENSION_NAME=square_cpu_kernel -DTORCH_API_INCLUDE_EXTENSION_H -isystem /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/include -isystem /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/include/torch/csrc/api/include -isystem /opt/conda/envs/py_3.10/include/python3.10 -fPIC -std=c++17 -c /var/lib/jenkins/.cache/torch_extensions/py310_cpu/square_cpu_kernel/main.cpp -o main.o 
[2/2] c++ main.o -shared -L/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/lib -lc10 -ltorch_cpu -ltorch -ltorch_python -o square_cpu_kernel.so
Attempted to call function marked as skipped
  Explanation: Dynamo does not know how to trace the builtin `square_cpu_kernel.pybind11_detail_function_record_v1_system_libstdcpp_gxx_abi_1xxx_use_cxx11_abi_1.square_cpu.` This function is either a Python builtin (e.g. _warnings.warn) or a third-party C/C++ Python extension (perhaps created with pybind).
  Hint: If it is a Python builtin, please file an issue on GitHub so the PyTorch team can add support for it and see the next case for a workaround.
  Hint: If it is a third-party C/C++ Python extension, please either wrap it into a PyTorch-understood custom operator (see https://pytorch.com.tw/tutorials/advanced/custom_ops_landing_page.html for more details) or, if it is traceable, use `torch.compiler.allow_in_graph`.

  Developer debug context: module: square_cpu_kernel, qualname: pybind11_detail_function_record_v1_system_libstdcpp_gxx_abi_1xxx_use_cxx11_abi_1.square_cpu, skip reason: <missing reason>

 For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0007.html

from user code:
   File "/tmp/ipykernel_904/2059008136.py", line 41, in f
    return square(x)
  File "/tmp/ipykernel_904/2059008136.py", line 37, in square
    return square_module.square_cpu(x)

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py:1598: UserWarning: Dynamo does not know how to trace the builtin `square_cpu_kernel.pybind11_detail_function_record_v1_system_libstdcpp_gxx_abi_1xxx_use_cxx11_abi_1.square_cpu.` This function is either a Python builtin (e.g. _warnings.warn) or a third-party C/C++ Python extension (perhaps created with pybind).
If it is a Python builtin, please file an issue on GitHub so the PyTorch team can add support for it and see the next case for a workaround.
If it is a third-party C/C++ Python extension, please either wrap it into a PyTorch-understood custom operator (see https://pytorch.com.tw/tutorials/advanced/custom_ops_landing_page.html for more details) or, if it is traceable, use `torch.compiler.allow_in_graph`.
  torch._dynamo.utils.warn_once(explanation + "\n" + "\n".join(hints))
# Use torch.library.custom_op to define a new custom operator.
# Custom operators are opaque with respect to torch.compile:
# that is, torch.compile does not peek into them.

@torch.library.custom_op("mylib::square", mutates_args=())
def square(x: torch.Tensor) -> torch.Tensor:
    return square_module.square_cpu(x)

# Use register_fake to add a ``FakeTensor`` kernel for the operator
@square.register_fake
def _(x):
    return x.new_empty(x.size())

print(f(torch.randn(3, 3)))  # no graph break
tensor([[1.2862e-01, 8.5591e-02, 2.3450e-01],
        [1.5921e-01, 4.4706e-01, 1.5394e+00],
        [8.1086e-04, 4.4906e-01, 5.1608e-01]])

有關自定義 Triton 核心的 triton_op 的更多資訊,請參見 使用者定義的 Triton 核心教程

策略 3:不要編譯程式碼#

並非所有程式碼都適合編譯。torch.compile 是一個用於 Tensor 計算的編譯器;它無法最佳化磁碟 I/O 等內容。嘗試重構程式碼,使不受支援的程式碼不會在編譯區域內被呼叫。

@torch.compile(fullgraph=True)
def f(x):
   y = x ** 2  / 2
   torch.save(y, "foo.pt")
   z = y ** 3 / 6
   return z

x = torch.randn(3)
try:
    f(x)  # Graph Break: torch.save not supported
except Exception as e:
    print(e)
Attempted to call function marked as skipped
  Explanation: Dynamo developers have intentionally marked that the function `save` in file `/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/serialization.py` should not be traced.
  Hint: Avoid calling the function `save`.
  Hint: Apply `@torch._dynamo.dont_skip_tracing` to the function `save` to force tracing into the function. More graph breaks may occur as a result of attempting to trace into the function.
  Hint: Please file an issue to PyTorch.

  Developer debug context: module: torch.serialization, qualname: save, skip reason: <missing reason>

 For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0007.html

from user code:
   File "/tmp/ipykernel_904/150060719.py", line 4, in f
    torch.save(y, "foo.pt")

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
def f_rewritten(x):
   y = g(x)
   torch.save(y, "foo.pt")
   z = h(y)
   return z

@torch.compile(fullgraph=True)
def g(x):
   y = x ** 2  / 2
   return y

@torch.compile(fullgraph=True)
def h(y):
   z = y ** 3 / 6
   return z

f_rewritten(x)
tensor([1.3869e-06, 4.1821e-01, 9.4249e-03])