Dynamo Core Concepts#
Created On: Jul 28, 2025 | Last Updated On: Jul 28, 2025
Summary
Dynamo is the frontend of torch.compile. It performs tracing to capture the semantics of a Python function (and its nested function calls) into a linear sequence of operations (the "(FX) graph"), residual bytecode, and "guards" (a list of conditions under which the graph and bytecode are valid). Unsupported Python features result in graph breaks, where Dynamo compiles the partial graph traced so far, runs the unsupported code in regular Python, then resumes tracing after the unsupported code.
Graph breaks may cause slowness with torch.compile and prevent backend optimization opportunities. If you are not getting the performance you expect, check for graph breaks.
Dynamo Tracing#
torch.compile's frontend (Dynamo) is a custom Python bytecode interpreter designed to allow graph compilation in PyTorch programs while retaining the full flexibility of Python. Given a function to compile, Dynamo interprets the Python bytecode to extract sequences of PyTorch operations into one or more FX graphs that can be further optimized by a backend.
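One way to observe what Dynamo captures is to pass a custom backend to torch.compile. This is a minimal sketch (the backend name inspect_backend is ours, not a library API); a real deployment would hand the graph to a compiler rather than run it unoptimized:

import torch

# A backend receives the traced torch.fx.GraphModule and example inputs,
# and returns a callable. Here we print the captured ops and then run
# the graph as-is.
def inspect_backend(gm: torch.fx.GraphModule, example_inputs):
    gm.graph.print_tabular()
    return gm.forward

@torch.compile(backend=inspect_backend)
def f(x):
    return torch.relu(x) + 1

f(torch.randn(3))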
[Figure: Dynamo traces the example function f into an FX graph, replacement Python bytecode, and guards]
For example, for the function f in the figure above, Dynamo produces:
- An FX graph that takes in the original inputs, plus some other inputs required by the function.
- Python bytecode that can be used as a drop-in replacement for f. In our example, the bytecode retrieves the extra inputs and passes them to the graph, and it also contains the Python side effect that could not be optimized away (the list append).
- Guards that specify the conditions under which the graph and bytecode are valid. Unless otherwise specified, the graph produced by Dynamo specializes on the shapes of the input Tensors.
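These artifacts can also be inspected programmatically. A small sketch using torch._dynamo.explain, which reports the captured graphs, guard counts, and any graph break reasons for the given inputs:

import torch

def f(x):
    x = x + 1
    return torch.relu(x)

# explain(f)(inputs) traces f and returns a summary of the captured
# graphs, guards, and graph break reasons without compiling f in place.
print(torch._dynamo.explain(f)(torch.randn(3)))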
Graph Breaks#
Dynamo traces your code and attempts to capture your PyTorch code into a single computation graph of PyTorch operators (an FX graph). However, this is not always possible: when code that cannot be traced is encountered, a "graph break" occurs. Under the default torch.compile settings, a graph break consists of compiling the FX graph determined so far, running the unsupported code in regular Python, then resuming tracing in a new FX graph.
Graph breaks are a feature that allows Dynamo to run arbitrary Python code while carving out functional subgraphs that can each be optimized individually.
However, graph breaks may lead to unexpected slowness with torch.compile. If you are not getting the speedups you expect, we recommend checking for graph breaks and removing them.
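One quick way to check is to compile with fullgraph=True, which turns any graph break into a hard error instead of silently splitting the graph. A minimal sketch:

import torch

@torch.compile(fullgraph=True)
def f(x):
    torch.save(x, "foo.pt")  # unsupported: would normally cause a graph break
    return x + 1

try:
    f(torch.randn(3))
except Exception as e:
    # with fullgraph=True, Dynamo raises instead of breaking the graph
    print(type(e).__name__)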
Graph breaks may occur on things like:
- Data-dependent if-statements (see the sketch after this list)
- Many Python built-in functions
- C functions
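For instance, a data-dependent if-statement breaks the graph because the branch to trace depends on a runtime tensor value (a sketch):

import torch

@torch.compile
def g(x):
    # The branch taken depends on the runtime value of x, which Dynamo
    # cannot know at trace time, so a graph break occurs here.
    if x.sum() > 0:
        return x + 1
    return x - 1

g(torch.randn(4))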
Below is an example of a graph break caused by calling the unsupported operation torch.save.
import torch

@torch.compile
def f(x):
    y = x ** 2 / 2
    torch.save(y, "foo.pt")  # torch.save is an unsupported operation
    z = y ** 3 / 6
    return z

x = torch.randn(3)
print(f(x))
tensor([6.4034e-07, 2.4362e-15, 3.6929e+00])
Graph break in user code at /tmp/ipykernel_699/215272159.py:4
Graph Break Reason: Attempted to call function marked as skipped
Explanation: Dynamo developers have intentionally marked that the function `save` in file `/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/serialization.py` should not be traced.
Hint: Avoid calling the function `save`.
Hint: Apply `@torch._dynamo.dont_skip_tracing` to the function `save` to force tracing into the function. More graph breaks may occur as a result of attempting to trace into the function.
Hint: Please file an issue to PyTorch.
Developer debug context: module: torch.serialization, qualname: save, skip reason: <missing reason>
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0007.html
User code traceback:
File "/opt/conda/envs/py_3.10/lib/python3.10/runpy.py", line 196, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/opt/conda/envs/py_3.10/lib/python3.10/runpy.py", line 86, in _run_code
exec(code, run_globals)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/ipykernel_launcher.py", line 18, in <module>
app.launch_new_instance()
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/traitlets/config/application.py", line 1075, in launch_instance
app.start()
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/ipykernel/kernelapp.py", line 739, in start
self.io_loop.start()
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/tornado/platform/asyncio.py", line 211, in start
self.asyncio_loop.run_forever()
File "/opt/conda/envs/py_3.10/lib/python3.10/asyncio/base_events.py", line 603, in run_forever
self._run_once()
File "/opt/conda/envs/py_3.10/lib/python3.10/asyncio/base_events.py", line 1909, in _run_once
handle._run()
File "/opt/conda/envs/py_3.10/lib/python3.10/asyncio/events.py", line 80, in _run
self._context.run(self._callback, *self._args)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/ipykernel/kernelbase.py", line 519, in dispatch_queue
await self.process_one()
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/ipykernel/kernelbase.py", line 508, in process_one
await dispatch(*args)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/ipykernel/kernelbase.py", line 400, in dispatch_shell
await result
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/ipykernel/ipkernel.py", line 368, in execute_request
await super().execute_request(stream, ident, parent)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/ipykernel/kernelbase.py", line 767, in execute_request
reply_content = await reply_content
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/ipykernel/ipkernel.py", line 455, in do_execute
res = shell.run_cell(
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/ipykernel/zmqshell.py", line 577, in run_cell
return super().run_cell(*args, **kwargs)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3006, in run_cell
result = self._run_cell(
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3061, in _run_cell
result = runner(coro)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/IPython/core/async_helpers.py", line 129, in _pseudo_sync_runner
coro.send(None)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3266, in run_cell_async
has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3445, in run_ast_nodes
if await self.run_code(code, result, async_=asy):
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3505, in run_code
exec(code_obj, self.user_global_ns, self.user_ns)
File "/tmp/ipykernel_699/215272159.py", line 9, in <module>
print(f(x))
File "/tmp/ipykernel_699/215272159.py", line 4, in f
torch.save(y, "foo.pt") # torch.save is an unsupported operation
The semantics of torch.compile(f)(x) are roughly:
def compiled_f_semantics(x):
    y = torch.compile(g, fullgraph=True)(x)
    torch.save(y, "foo.pt")
    z = torch.compile(h, fullgraph=True)(y)
    return z

def g(x):
    return x ** 2 / 2

def h(y):
    return y ** 3 / 6
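If the graph break is undesirable, one workaround (a sketch, not the only option) is to keep the unsupported call outside the compiled region so the computation can be captured in a single graph:

import torch

@torch.compile
def f_compute(x):
    # both operations are now traceable into one FX graph
    y = x ** 2 / 2
    z = y ** 3 / 6
    return y, z

x = torch.randn(3)
y, z = f_compute(x)
torch.save(y, "foo.pt")  # runs in regular Python, outside the graph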
Guards#
torch.compile makes some assumptions about runtime values as it traces through code. During tracing, we generate "guards", which are runtime checks of these assumptions. Guards are run on future calls to the compiled function to determine whether previously compiled code can be reused. Examples of runtime checks are constant values, types, and object IDs.
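The guard dump shown after the next example comes from guard logging; a sketch of turning it on programmatically (equivalent to setting the TORCH_LOGS="guards" environment variable):

import torch

# Print the guards generated for each compiled function.
torch._logging.set_logs(guards=True)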
Below is an example of generated guards. The TENSOR_MATCH guard checks the input's type, device, dtype, shape, and so on.
@torch.compile
def fn(x):
    return x + 1

print(fn(torch.ones(3, 3)))
tensor([[2., 2., 2.],
[2., 2., 2.],
[2., 2., 2.]])
GUARDS:
TREE_GUARD_MANAGER:
+- RootGuardManager
| +- LAMBDA_GUARD: torch._functorch.aot_autograd.utils.top_saved_tensors_hooks ids == None # _dynamo/output_graph.py:688 in init_ambient_guards
| +- DEFAULT_DEVICE: utils_device.CURRENT_DEVICE == None # _dynamo/output_graph.py:676 in init_ambient_guards
| +- GLOBAL_STATE: ___check_global_state()
| +- TORCH_FUNCTION_MODE_STACK: ___check_torch_function_mode_stack()
| +- GuardManager: source=L['x'], accessed_by=FrameLocalsGuardAccessor(key='x', framelocals_idx=0), type=<class 'torch.Tensor'>, tag_safe=(True, False)
| | +- TENSOR_MATCH: check_tensor(L['x'], Tensor, DispatchKeySet(CPU, BackendSelect, ADInplaceOrView, AutogradCPU), torch.float32, device=None, requires_grad=False, size=[3, 3], stride=[3, 1]) # return x + 1 # mp/ipykernel_699/1068332425.py:3 in fn
| | +- NO_HASATTR: hasattr(L['x'], '_dynamo_dynamic_indices') == False # return x + 1 # mp/ipykernel_699/1068332425.py:3 in fn
Guard eval latency = 569.96 us
Recompilation#
If the guards of every instance of previously compiled code fail, then torch.compile must "recompile" the function, requiring the original code to be traced again. In the example below, recompilation is necessary because the guard checking the shape of the tensor argument failed.
@torch.compile
def fn(x):
    return x + 1

print(fn(torch.ones(3, 3)))
print(fn(torch.ones(4, 4)))
tensor([[2., 2., 2.],
[2., 2., 2.],
[2., 2., 2.]])
tensor([[2., 2., 2., 2.],
[2., 2., 2., 2.],
[2., 2., 2., 2.],
[2., 2., 2., 2.]])
Recompiling function fn in /tmp/ipykernel_699/420870727.py:1
triggered by the following guard failure(s):
- 3/0: tensor 'x' size mismatch at index 0. expected 3, actual 4
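The "Recompiling function" message above comes from recompilation logging, which can be enabled as follows (equivalent to TORCH_LOGS="recompiles"); a sketch:

import torch

# Report which guard failures triggered each recompilation.
torch._logging.set_logs(recompiles=True)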
Dynamic Shapes#
torch.compile initially assumes tensor shapes are static/constant and guards based on these assumptions. By using "dynamic shapes", we can get torch.compile to produce compiled code that accepts tensor inputs with different shapes - this avoids recompiling every time shapes differ. By default, automatic dynamic shapes are enabled with torch.compile(dynamic=None): if compilation fails due to a shape mismatch, recompilation is attempted with dynamic shapes. Dynamic shapes can also be fully enabled (dynamic=True) or disabled (dynamic=False).
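A sketch of our reading of the default automatic behavior (the exact recompile count below is an assumption, not a guarantee):

import torch

@torch.compile  # dynamic=None: start static, go dynamic on mismatch
def fn(x):
    return x + 1

fn(torch.ones(3, 3))  # compiled with static shape (3, 3)
fn(torch.ones(4, 4))  # guard fails -> one recompile with symbolic sizes
fn(torch.ones(5, 5))  # expected to reuse the dynamic code, no recompile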
Below, we enable dynamic shapes explicitly and note that we no longer need to recompile.
@torch.compile(dynamic=True)
def fn(x):
    return x + 1

print(fn(torch.ones(3, 3)))
print(fn(torch.ones(4, 4)))
tensor([[2., 2., 2.],
[2., 2., 2.],
[2., 2., 2.]])
tensor([[2., 2., 2., 2.],
[2., 2., 2., 2.],
[2., 2., 2., 2.],
[2., 2., 2., 2.]])
create_env
create_symbol s77 = 3 for L['x'].size()[0] [2, int_oo] return x + 1 # mp/ipykernel_699/1458103805.py:3 in fn (_dynamo/variables/builder.py:3508 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s77" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
create_symbol s77 duck sized L['x'].size()[1]
eval False == False [statically known]
eval False == False [statically known]
produce_guards
track_symint L['x'].size()[0] s77 None
track_symint L['x'].size()[1] s77 None
track_symint L['x'].stride()[0] s77 None
track_symint L['x'].stride()[1] 1 None
track_symint L['x'].storage_offset() 0 None
Skipping guard L['x'].stride()[1] == 1
Skipping guard L['x'].storage_offset() == 0
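Individual dimensions can also be marked dynamic ahead of time, avoiding even the first shape-mismatch recompilation. A sketch using torch._dynamo.mark_dynamic (the choice of dimension here is ours):

import torch

@torch.compile
def fn(x):
    return x + 1

x = torch.ones(3, 3)
# Compile dim 0 with a symbolic size from the start.
torch._dynamo.mark_dynamic(x, 0)
print(fn(x))
print(fn(torch.ones(5, 3)))  # should hit the same compiled code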
For more information on dynamic shapes, see the dynamic shapes manual.