巢狀圖中斷#
建立時間:2025 年 7 月 28 日 | 最後更新時間:2025 年 7 月 28 日
摘要
巢狀函式中的圖中斷可能導致編譯器行為難以理解,我們將在下面進行文件說明。
巢狀圖中斷會導致 重複圖中斷行為。
回想一下,當 torch.compile 應用於一個函式時,任何巢狀的函式呼叫也會被跟蹤。巢狀圖中斷 指的是發生在巢狀函式呼叫中的任何圖中斷。
def inner(x):
...
torch._dynamo.graph_break() # nested graph break
...
@torch.compile
def outer(x):
...
y = inner(x)
...
巢狀圖中斷的恢復語義可能令人困惑,因此我們在此描述其行為。
回想一下,在 fullgraph=False 中,圖中斷會被處理,即編譯到目前為止確定的 FX 圖,以常規 Python 執行不支援的程式碼,然後在新 FX 圖中恢復跟蹤。恢復函式實際上是一項相當複雜的技術壯舉,因此恢復跟蹤僅支援頂級函式。
因此,我們可以按照以下方式在巢狀圖中斷後恢復跟蹤(在此限制下):
首先,考慮下面的示例,其中 torch.compile 從 f 開始跟蹤,並一直跟蹤直到遇到 inner1 中的圖中斷。
def inner1(x):
x = x + 1
torch._dynamo.graph_break() # stop tracing due to graph break
return x + 2
def inner2(x):
x = x + 4
x = inner1(x)
x = x + 8
@torch.compile
def f(x):
# start tracing from here
x = x + 16
x = inner2(x)
x = x + 32
f(torch.randn(3))
由於我們只能從頂級函式恢復,因此我們在 f 中對 inner2 的呼叫進行圖中斷。
# The semantics of torch.compile(f)(x) is roughly this:
def compiled_f_semantics(x):
y = x + 16
z = inner2(y)
return torch.compile(resume_f_semantics)(z)
def resume_f_semantics(x):
return x + 32
compiled_f_semantics(torch.randn(3))
inner2 然後會自動編譯為頂級函式。我們一直跟蹤直到再次遇到 inner1 中的圖中斷。
def inner1(x):
x = x + 1
torch._dynamo.graph_break() # stop tracing due to graph break
return x + 2
# this torch.compile is automatically applied
@torch.compile
def inner2(x):
# start tracing from here
x = x + 4
x = inner1(x)
x = x + 8
def compiled_f_semantics(x):
y = x + 16
z = inner2(y)
return torch.compile(resume_f_semantics)(z)
def resume_f_semantics(x):
return x + 32
compiled_f_semantics(torch.randn(3))
然後,我們在 inner2 中對 inner1 的呼叫進行圖中斷。
def compiled_inner2_semantics(x):
y = x + 4
z = inner1(y)
return torch.compile(resume_inner2_semantics)(z)
def resume_inner2_semantics(x):
return x + 8
inner1 然後會自動編譯為頂級函式。圖中斷來自 inner1,因此我們正常處理該圖中斷。
# this torch.compile is automatically applied
@torch.compile
def inner1(x):
# start tracing from here
x = x + 1
torch._dynamo.graph_break() # stop tracing due to graph break
return x + 2
def compiled_f_semantics(x):
y = x + 16
z = compiled_inner2_semantics(y)
return torch.compile(resume_f_semantics)(z)
def resume_f_semantics(x):
return x + 32
def compiled_inner2_semantics(x):
y = x + 4
z = inner1(y)
return torch.compile(resume_inner2_semantics)(z)
def resume_inner2_semantics(x):
return x + 8
compiled_f_semantics(torch.randn(3))
inner1 被正常處理。
def compiled_inner1_semantics(x):
y = x + 1
torch._dynamo.graph_break()
return torch.compile(resume_inner1_semantics)(y)
def resume_inner1_semantics(x):
return x + 2
因此,初始程式碼在語義上等同於:
def compiled_f_semantics(x):
y = x + 16
z = compiled_inner2_semantics(y)
return torch.compile(resume_f_semantics)(z)
def resume_f_semantics(x):
return x + 32
def compiled_inner2_semantics(x):
y = x + 4
z = compiled_inner1_semantics(y)
return torch.compile(resume_inner2_semantics)(z)
def resume_inner2_semantics(x):
return x + 8
def compiled_inner1_semantics(x):
y = x + 1
torch._dynamo.graph_break()
return torch.compile(resume_inner1_semantics)(y)
def resume_inner1_semantics(x):
return x + 2
compiled_f_semantics(torch.randn(3))
請特別注意,我們跟蹤了 3 個頂級函式,並且跟蹤了相同的圖中斷 3 次。這就是為什麼在使用 torch.compile 時可能會遇到重複圖中斷的原因。
總而言之,巢狀圖中斷的處理方式如下:
從頂級函式一直跟蹤到巢狀的圖中斷。
在頂級函式中,在呼叫二級函式時進行圖中斷。
編譯到目前為止跟蹤到的 PyTorch 操作並執行編譯後的圖。
呼叫二級函式,該函式會被自動編譯為頂級函式。
在呼叫二級函式後恢復跟蹤。
請注意,處理此圖中斷的執行時為 ,其中 是巢狀深度, 是從頂級函式到圖中斷的指令數。我們最終會跟蹤 幀,並且我們跟蹤相同的圖中斷 次。