使用 torch._dynamo.nonstrict_trace#
建立時間:2025 年 7 月 28 日 | 最後更新時間:2025 年 7 月 28 日
摘要
使用
nonstrict_trace在torch.compile編譯區域內部使用非嚴格跟蹤來跟蹤函式。您可能希望這樣做,因為 Dynamo 圖在函式內部的某個地方斷裂了,而您確定該函式是可進行非嚴格跟蹤的。
考慮以下場景
def get_magic_num():
# This explicit graph break call is meant to emulate any kind of Dynamo
# graph break, e.g., the function is implemented in C, or uses some python
# language feature Dynamo doesn't yet support.
torch._dynamo.graph_break()
return torch.tensor([42])
@torch.compile(fullgraph=True)
def func(x):
n = get_magic_num()
return x + n
try:
func(torch.rand(10))
except Exception as e:
print(e)
Call to `torch._dynamo.graph_break()`
Explanation: User-inserted graph break. Message: None
Hint: Remove the `torch._dynamo.graph_break()` call.
Developer debug context: Called `torch._dynamo.graph_break()` with args `[]`, kwargs `{}`
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0025.html
from user code:
File "/tmp/ipykernel_850/2253748958.py", line 9, in func
n = get_magic_num()
File "/tmp/ipykernel_850/2253748958.py", line 5, in get_magic_num
torch._dynamo.graph_break()
Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
如果我們執行上面的程式碼,我們將收到一個來自 Dynamo 的錯誤,因為儘管使用者指定了 fullgraph=True,但它仍然看到一個圖斷裂。
在這些情況下,如果使用者仍然希望保持 fullgraph=True,他們通常有幾種選擇
圖斷裂是由於 Dynamo 尚不支援的語言特性。在這種情況下,使用者要麼重寫他們的程式碼,要麼在 GitHub 上提交一個 issue。
圖斷裂是由於呼叫了用 C 實現的函式。在這種情況下,使用者可以嘗試使用自定義操作。使用者也可以嘗試提供一個 polyfill(Python 中的引用實現),以便 Dynamo 可以跟蹤它。
最壞的情況——內部編譯器錯誤。在這種情況下,使用者很可能需要在 GitHub 上提交一個 issue。
除了所有這些選項之外,PyTorch 還提供了一個替代方案 torch._dynamo.nonstrict_trace,前提是引發圖斷裂的函式呼叫滿足某些要求
通用非嚴格跟蹤 的要求。
輸入和輸出必須包含基本型別(例如,
int、float、list、dict、torch.Tensor),或者已註冊到torch.utils._pytree的使用者定義型別。該函式必須定義在
torch.compile編譯區域之外。函式讀取的任何非輸入值將被視為常量(例如,全域性張量),並且不會對其進行保護。
在跟蹤對 torch._dynamo.nonstrict_trace 跟蹤的函式的呼叫時,torch.compile 會切換到非嚴格跟蹤,並且 FX 圖最終將包含該函式內部發生的所有相關張量操作。
對於上面的示例,我們可以使用 torch._dynamo.nonstrict_trace 來 消除 圖斷裂
@torch._dynamo.nonstrict_trace
def get_magic_num():
# This explicit graph break call is meant to emulate any kind of Dynamo
# graph break, e.g., the function is implemented in C, or uses some python
# language feature Dynamo doesn't yet support.
torch._dynamo.graph_break()
return torch.tensor([42])
@torch.compile(fullgraph=True)
def func(x):
n = get_magic_num()
return x + n
print(func(torch.rand(10)))
# No graph break and no error.
tensor([42.1627, 42.1384, 42.9075, 42.5242, 42.4176, 42.1747, 42.9599, 42.2383,
42.5449, 42.0285])
請注意,也可以在 torch.compile 編譯區域內部使用它
def get_magic_num():
# This explicit graph break call is meant to emulate any kind of Dynamo
# graph break, e.g., the function is implemented in C, or uses some python
# language feature Dynamo doesn't yet support.
torch._dynamo.graph_break()
return torch.tensor([42])
@torch.compile(fullgraph=True)
def func(x):
n = torch._dynamo.nonstrict_trace(get_magic_num)()
return x + n
print(func(torch.rand(10)))
# No graph break and no error.
tensor([42.5935, 42.2370, 42.2154, 42.5488, 42.9691, 42.9799, 42.1668, 42.0909,
42.4228, 42.2204])