torch.compile 具有不同的 autograd 語義#
建立日期:2025 年 6 月 26 日 | 最後更新日期:2025 年 6 月 26 日
當您將 torch.compile 應用到模型前向傳播的某個函式時,它會自動為編譯後的函式生成一個反向傳播。在編譯期間,它會為反向傳播跟蹤出一個圖,該圖將在每次呼叫 autograd 時使用。我們將 torch.compile 內部負責此任務的元件稱為 AOTDispatcher(有時也稱為 AOTAutograd)。
因此,torch.compile 會在前向傳播的函式編譯過程中,將計算的細節“烘焙”到跟蹤出的反向傳播圖中。然而,在 eager 模式的 PyTorch 中,反向傳播是動態的:在前向傳播之外,您可以將 tensor.backward() 或 torch.autograd.grad(...) 的呼叫包裝在可能改變其行為的上下文管理器中。
此頁面記錄了 torch.compile 的 autograd 語義與 eager 模式 PyTorch 的不同之處,以及如何規避這些差異。
Autocast 行為#
torch.compile 會預先假設反向傳播是否會在環境 autocast 上下文管理器下執行。預設情況下,使用 torch._functorch.config.backward_pass_autocast 來控制該假設;不正確的假設可能導致靜默的錯誤。
選項包括:
"same_as_forward"(預設)。我們假設torch.compile編譯區域的反向傳播將在該區域執行的相同 autocast 上下文管理器下執行(如果存在)。如果您的程式碼如下所示,請使用此選項:with torch.amp.autocast(...): y = torch.compile(region)(x) ... # backward pass run under the same autocast context as the compiled region z.backward()
"off"。我們假設torch.compile編譯區域的反向傳播不會在任何 autocast 上下文管理器下執行。如果您的程式碼如下所示,請使用此選項:with torch.amp.autocast(...): y = torch.compile(region)(x) ... # Backward pass runs under no autocast. z.backward()
還有第三個選項。如果將
torch._functorch.config.backward_pass_autocast設定為 kwargs 列表,我們將假定反向傳播在由這些 kwargs 構建的 autocast 上下文下執行。例如,如果您的程式碼如下所示:
y = torch.compile(region)(x) ... # Backward pass runs under special context manager with torch.amp.autocast(**kwargs): z.backward()
則將
torch._functorch.config.backward_pass_autocast = kwargs設定為。
使用 patch 將選項應用於特定的 torch.compile 呼叫。
with torch.amp.autocast(...):
with torch._functorch.config.patch(backward_pass_autocast="same_as_forward")
y = torch.compile(region)(x)
...
# backward pass run under the same autocast context as the compiled region
z.backward()