評價此頁
torch.compile 具有不同的 autograd 語義"

torch.compile 具有不同的 autograd 語義#

建立日期:2025 年 6 月 26 日 | 最後更新日期:2025 年 6 月 26 日

當您將 torch.compile 應用到模型前向傳播的某個函式時,它會自動為編譯後的函式生成一個反向傳播。在編譯期間,它會為反向傳播跟蹤出一個圖,該圖將在每次呼叫 autograd 時使用。我們將 torch.compile 內部負責此任務的元件稱為 AOTDispatcher(有時也稱為 AOTAutograd)。

因此,torch.compile 會在前向傳播的函式編譯過程中,將計算的細節“烘焙”到跟蹤出的反向傳播圖中。然而,在 eager 模式的 PyTorch 中,反向傳播是動態的:在前向傳播之外,您可以將 tensor.backward()torch.autograd.grad(...) 的呼叫包裝在可能改變其行為的上下文管理器中。

此頁面記錄了 torch.compile 的 autograd 語義與 eager 模式 PyTorch 的不同之處,以及如何規避這些差異。

Autocast 行為#

torch.compile 會預先假設反向傳播是否會在環境 autocast 上下文管理器下執行。預設情況下,使用 torch._functorch.config.backward_pass_autocast 來控制該假設;不正確的假設可能導致靜默的錯誤。

選項包括:

  • "same_as_forward"(預設)。我們假設 torch.compile 編譯區域的反向傳播將在該區域執行的相同 autocast 上下文管理器下執行(如果存在)。如果您的程式碼如下所示,請使用此選項:

    with torch.amp.autocast(...):
        y = torch.compile(region)(x)
        ...
        # backward pass run under the same autocast context as the compiled region
        z.backward()
    
  • "off"。我們假設 torch.compile 編譯區域的反向傳播不會在任何 autocast 上下文管理器下執行。如果您的程式碼如下所示,請使用此選項:

    with torch.amp.autocast(...):
        y = torch.compile(region)(x)
        ...
    # Backward pass runs under no autocast.
    z.backward()
    
  • 還有第三個選項。如果將 torch._functorch.config.backward_pass_autocast 設定為 kwargs 列表,我們將假定反向傳播在由這些 kwargs 構建的 autocast 上下文下執行。

    例如,如果您的程式碼如下所示:

    y = torch.compile(region)(x)
    ...
    # Backward pass runs under special context manager
    with torch.amp.autocast(**kwargs):
        z.backward()
    

    則將 torch._functorch.config.backward_pass_autocast = kwargs 設定為。

使用 patch 將選項應用於特定的 torch.compile 呼叫。

with torch.amp.autocast(...):
    with torch._functorch.config.patch(backward_pass_autocast="same_as_forward")
    y = torch.compile(region)(x)
    ...
    # backward pass run under the same autocast context as the compiled region
    z.backward()