評價此頁
fullgraph=False">

巢狀圖中斷#

建立時間:2025 年 7 月 28 日 | 最後更新時間:2025 年 7 月 28 日

摘要

  • 巢狀函式中的圖中斷可能導致編譯器行為難以理解,我們將在下面進行文件說明。

  • 巢狀圖中斷會導致 O(N)\mathcal O(N) 重複圖中斷行為。

回想一下,當 torch.compile 應用於一個函式時,任何巢狀的函式呼叫也會被跟蹤。巢狀圖中斷 指的是發生在巢狀函式呼叫中的任何圖中斷。

def inner(x):
    ...
    torch._dynamo.graph_break()  # nested graph break
    ...

@torch.compile
def outer(x):
    ...
    y = inner(x)
    ...

巢狀圖中斷的恢復語義可能令人困惑,因此我們在此描述其行為。

回想一下,在 fullgraph=False 中,圖中斷會被處理,即編譯到目前為止確定的 FX 圖,以常規 Python 執行不支援的程式碼,然後在新 FX 圖中恢復跟蹤。恢復函式實際上是一項相當複雜的技術壯舉,因此恢復跟蹤僅支援頂級函式。

因此,我們可以按照以下方式在巢狀圖中斷後恢復跟蹤(在此限制下):

首先,考慮下面的示例,其中 torch.compilef 開始跟蹤,並一直跟蹤直到遇到 inner1 中的圖中斷。

def inner1(x):
    x = x + 1
    torch._dynamo.graph_break()  # stop tracing due to graph break
    return x + 2

def inner2(x):
    x = x + 4
    x = inner1(x)
    x = x + 8

@torch.compile
def f(x):
    # start tracing from here
    x = x + 16
    x = inner2(x)
    x = x + 32

f(torch.randn(3))

由於我們只能從頂級函式恢復,因此我們在 f 中對 inner2 的呼叫進行圖中斷。

# The semantics of torch.compile(f)(x) is roughly this:
def compiled_f_semantics(x):
    y = x + 16
    z = inner2(y)
    return torch.compile(resume_f_semantics)(z)

def resume_f_semantics(x):
    return x + 32

compiled_f_semantics(torch.randn(3))

inner2 然後會自動編譯為頂級函式。我們一直跟蹤直到再次遇到 inner1 中的圖中斷。

def inner1(x):
    x = x + 1
    torch._dynamo.graph_break()  # stop tracing due to graph break
    return x + 2

# this torch.compile is automatically applied
@torch.compile
def inner2(x):
    # start tracing from here
    x = x + 4
    x = inner1(x)
    x = x + 8

def compiled_f_semantics(x):
    y = x + 16
    z = inner2(y)
    return torch.compile(resume_f_semantics)(z)

def resume_f_semantics(x):
    return x + 32

compiled_f_semantics(torch.randn(3))

然後,我們在 inner2 中對 inner1 的呼叫進行圖中斷。

def compiled_inner2_semantics(x):
    y = x + 4
    z = inner1(y)
    return torch.compile(resume_inner2_semantics)(z)

def resume_inner2_semantics(x):
    return x + 8

inner1 然後會自動編譯為頂級函式。圖中斷來自 inner1,因此我們正常處理該圖中斷。

# this torch.compile is automatically applied
@torch.compile
def inner1(x):
    # start tracing from here
    x = x + 1
    torch._dynamo.graph_break()  # stop tracing due to graph break
    return x + 2

def compiled_f_semantics(x):
    y = x + 16
    z = compiled_inner2_semantics(y)
    return torch.compile(resume_f_semantics)(z)

def resume_f_semantics(x):
    return x + 32

def compiled_inner2_semantics(x):
    y = x + 4
    z = inner1(y)
    return torch.compile(resume_inner2_semantics)(z)

def resume_inner2_semantics(x):
    return x + 8

compiled_f_semantics(torch.randn(3))

inner1 被正常處理。

def compiled_inner1_semantics(x):
    y = x + 1
    torch._dynamo.graph_break()
    return torch.compile(resume_inner1_semantics)(y)

def resume_inner1_semantics(x):
    return x + 2

因此,初始程式碼在語義上等同於:

def compiled_f_semantics(x):
    y = x + 16
    z = compiled_inner2_semantics(y)
    return torch.compile(resume_f_semantics)(z)

def resume_f_semantics(x):
    return x + 32

def compiled_inner2_semantics(x):
    y = x + 4
    z = compiled_inner1_semantics(y)
    return torch.compile(resume_inner2_semantics)(z)

def resume_inner2_semantics(x):
    return x + 8

def compiled_inner1_semantics(x):
    y = x + 1
    torch._dynamo.graph_break()
    return torch.compile(resume_inner1_semantics)(y)

def resume_inner1_semantics(x):
    return x + 2

compiled_f_semantics(torch.randn(3))

請特別注意,我們跟蹤了 3 個頂級函式,並且跟蹤了相同的圖中斷 3 次。這就是為什麼在使用 torch.compile 時可能會遇到重複圖中斷的原因。

總而言之,巢狀圖中斷的處理方式如下:

  • 從頂級函式一直跟蹤到巢狀的圖中斷。

  • 在頂級函式中,在呼叫二級函式時進行圖中斷。

  • 編譯到目前為止跟蹤到的 PyTorch 操作並執行編譯後的圖。

  • 呼叫二級函式,該函式會被自動編譯為頂級函式。

  • 在呼叫二級函式後恢復跟蹤。

請注意,處理此圖中斷的執行時為 O(NK)\mathcal O(NK),其中 NN 是巢狀深度,KK 是從頂級函式到圖中斷的指令數。我們最終會跟蹤 O(N2)\mathcal O(N^2) 幀,並且我們跟蹤相同的圖中斷 O(N)\mathcal O(N) 次。