評價此頁
torch._dynamo.nonstrict_trace">

使用 torch._dynamo.nonstrict_trace#

建立時間:2025 年 7 月 28 日 | 最後更新時間:2025 年 7 月 28 日

摘要

  • 使用 nonstrict_tracetorch.compile 編譯區域內部使用非嚴格跟蹤來跟蹤函式。您可能希望這樣做,因為 Dynamo 圖在函式內部的某個地方斷裂了,而您確定該函式是可進行非嚴格跟蹤的。

考慮以下場景

def get_magic_num():
    # This explicit graph break call is meant to emulate any kind of Dynamo
    # graph break, e.g., the function is implemented in C, or uses some python
    # language feature Dynamo doesn't yet support.
    torch._dynamo.graph_break()
    return torch.tensor([42])
@torch.compile(fullgraph=True)
def func(x):
    n = get_magic_num()
    return x + n
try:
    func(torch.rand(10))
except Exception as e:
    print(e)
Call to `torch._dynamo.graph_break()`
  Explanation: User-inserted graph break. Message: None
  Hint: Remove the `torch._dynamo.graph_break()` call.

  Developer debug context: Called `torch._dynamo.graph_break()` with args `[]`, kwargs `{}`

 For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0025.html

from user code:
   File "/tmp/ipykernel_850/2253748958.py", line 9, in func
    n = get_magic_num()
  File "/tmp/ipykernel_850/2253748958.py", line 5, in get_magic_num
    torch._dynamo.graph_break()

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"

如果我們執行上面的程式碼,我們將收到一個來自 Dynamo 的錯誤,因為儘管使用者指定了 fullgraph=True,但它仍然看到一個圖斷裂。

在這些情況下,如果使用者仍然希望保持 fullgraph=True,他們通常有幾種選擇

  1. 圖斷裂是由於 Dynamo 尚不支援的語言特性。在這種情況下,使用者要麼重寫他們的程式碼,要麼在 GitHub 上提交一個 issue。

  2. 圖斷裂是由於呼叫了用 C 實現的函式。在這種情況下,使用者可以嘗試使用自定義操作。使用者也可以嘗試提供一個 polyfill(Python 中的引用實現),以便 Dynamo 可以跟蹤它。

  3. 最壞的情況——內部編譯器錯誤。在這種情況下,使用者很可能需要在 GitHub 上提交一個 issue。

除了所有這些選項之外,PyTorch 還提供了一個替代方案 torch._dynamo.nonstrict_trace,前提是引發圖斷裂的函式呼叫滿足某些要求

  • 通用非嚴格跟蹤 的要求。

  • 輸入和輸出必須包含基本型別(例如,intfloatlistdicttorch.Tensor),或者已註冊到 torch.utils._pytree 的使用者定義型別。

  • 該函式必須定義在 torch.compile 編譯區域之外。

  • 函式讀取的任何非輸入值將被視為常量(例如,全域性張量),並且不會對其進行保護。

在跟蹤對 torch._dynamo.nonstrict_trace 跟蹤的函式的呼叫時,torch.compile 會切換到非嚴格跟蹤,並且 FX 圖最終將包含該函式內部發生的所有相關張量操作。

對於上面的示例,我們可以使用 torch._dynamo.nonstrict_trace 消除 圖斷裂

@torch._dynamo.nonstrict_trace
def get_magic_num():
    # This explicit graph break call is meant to emulate any kind of Dynamo
    # graph break, e.g., the function is implemented in C, or uses some python
    # language feature Dynamo doesn't yet support.
    torch._dynamo.graph_break()
    return torch.tensor([42])
@torch.compile(fullgraph=True)
def func(x):
    n = get_magic_num()
    return x + n
print(func(torch.rand(10)))
# No graph break and no error.
tensor([42.1627, 42.1384, 42.9075, 42.5242, 42.4176, 42.1747, 42.9599, 42.2383,
        42.5449, 42.0285])

請注意,也可以在 torch.compile 編譯區域內部使用它

def get_magic_num():
    # This explicit graph break call is meant to emulate any kind of Dynamo
    # graph break, e.g., the function is implemented in C, or uses some python
    # language feature Dynamo doesn't yet support.
    torch._dynamo.graph_break()
    return torch.tensor([42])
@torch.compile(fullgraph=True)
def func(x):
    n = torch._dynamo.nonstrict_trace(get_magic_num)()
    return x + n
print(func(torch.rand(10)))
# No graph break and no error.
tensor([42.5935, 42.2370, 42.2154, 42.5488, 42.9691, 42.9799, 42.1668, 42.0909,
        42.4228, 42.2204])