評價此頁
torch.compile">

已編譯的自動梯度:捕獲更大的反向傳播圖以用於 torch.compile#

創建於:2024 年 10 月 09 日 | 最後更新:2024 年 10 月 23 日 | 最後驗證:2024 年 10 月 09 日

作者: Simon Fan

您將學到什麼
  • 已編譯的自動梯度如何與 torch.compile 互動

  • 如何使用已編譯的自動梯度 API

  • 如何使用 TORCH_LOGS 檢查日誌

先決條件

概述#

已編譯的自動梯度是 PyTorch 2.4 中引入的一個 torch.compile 擴充套件,它允許捕獲更大的反向傳播圖。

雖然 torch.compile 會捕獲反向傳播圖,但它是部分捕獲的。AOTAutograd 元件會提前捕獲反向傳播圖,但存在一些限制。

  • 前向傳播中的圖中斷會導致反向傳播中的圖中斷

  • 反向傳播鉤子未被捕獲

已編譯的自動梯度透過直接與自動梯度引擎整合來解決這些限制,允許它在執行時捕獲完整的反向傳播圖。具有這兩種特性的模型應嘗試使用已編譯的自動梯度,並可能觀察到更好的效能。

然而,已編譯的自動梯度也引入了自己的限制

  • 在反向傳播開始時增加了快取查詢的執行時開銷

  • 由於捕獲範圍更大,更容易發生重新編譯和圖中斷

注意

已編譯的自動梯度正在積極開發中,尚未與所有現有的 PyTorch 功能相容。有關特定功能的最新狀態,請參閱 已編譯自動梯度登陸頁面

設定#

在本教程中,我們將基於這個簡單的神經網路模型來舉例。它接收一個 10 維的輸入向量,透過一個線性層進行處理,並輸出另一個 10 維向量。

import torch

class Model(torch.nn.Module):
   def __init__(self):
      super().__init__()
      self.linear = torch.nn.Linear(10, 10)

   def forward(self, x):
      return self.linear(x)

基本用法#

在呼叫 torch.compile API 之前,請確保將 torch._dynamo.config.compiled_autograd 設定為 True

model = Model()
x = torch.randn(10)

torch._dynamo.config.compiled_autograd = True
@torch.compile
def train(model, x):
   loss = model(x).sum()
   loss.backward()

train(model, x)

在上面的程式碼中,我們建立了一個 Model 類的例項,並使用 torch.randn(10) 生成了一個隨機的 10 維張量 x。我們定義了訓練迴圈函式 train,並使用 @torch.compile 裝飾它以最佳化其執行。當呼叫 train(model, x) 時:

  • Python 直譯器呼叫 Dynamo,因為此呼叫被裝飾了 @torch.compile

  • Dynamo 攔截 Python 位元組碼,模擬其執行並將操作記錄到圖中。

  • AOTDispatcher 停用鉤子,並呼叫自動梯度引擎來計算 model.linear.weightmodel.linear.bias 的梯度,並將操作記錄到圖中。使用 torch.autograd.Function,AOTDispatcher 重寫了 train 的前向和反向傳播實現。

  • Inductor 生成一個對應於 AOTDispatcher 前向和反向傳播最佳化實現的函式。

  • Dynamo 設定最佳化後的函式,以便 Python 直譯器接下來進行評估。

  • Python 直譯器執行最佳化後的函式,該函式執行 loss = model(x).sum()

  • Python 直譯器執行 loss.backward(),呼叫自動梯度引擎,該引擎會路由到已編譯的自動梯度引擎,因為我們將 torch._dynamo.config.compiled_autograd = True 設定為 True。

  • 已編譯的自動梯度計算 model.linear.weightmodel.linear.bias 的梯度,並將操作記錄到圖中,包括它遇到的任何鉤子。在此過程中,它將記錄 AOTDispatcher 之前重寫的反向傳播。然後,已編譯的自動梯度生成一個新函式,該函式對應於 loss.backward() 的完全跟蹤實現,並以推理模式使用 torch.compile 執行它。

  • 相同的步驟將遞迴應用於已編譯的自動梯度圖,但這次 AOTDispatcher 將不需要劃分圖。

檢查已編譯的自動梯度日誌#

使用 TORCH_LOGS 環境變數執行指令碼。

  • 要僅列印已編譯的自動梯度圖,請使用 TORCH_LOGS="compiled_autograd" python example.py

  • 要以損失效能為代價,列印帶有更多張量元資料和重新編譯原因的圖,請使用 TORCH_LOGS="compiled_autograd_verbose" python example.py

重新執行上面的片段,已編譯的自動梯度圖現在應該被記錄到 stderr。某些圖節點將帶有以 aot0_ 為字首的名稱,這些名稱對應於之前在 AOTAutograd 反向傳播圖 0 中預先編譯的節點,例如,aot0_view_2 對應於 ID 為 0 的 AOT 反向傳播圖的 view_2

在下面的影像中,紅色框包含了在沒有已編譯自動梯度的情況下被 torch.compile 捕獲的 AOT 反向傳播圖。

../_images/entire_verbose_log.png

注意

這是我們將呼叫 torch.compile 的圖,而不是最佳化後的圖。已編譯的自動梯度本質上會生成一些未最佳化的 Python 程式碼來表示整個 C++ 自動梯度執行。

使用不同的標誌編譯前向和反向傳播#

您可以使用不同的編譯器配置進行兩次編譯,例如,即使前向傳播中有圖中斷,反向傳播也可以是 fullgraph。

def train(model, x):
    model = torch.compile(model)
    loss = model(x).sum()
    torch._dynamo.config.compiled_autograd = True
    torch.compile(lambda: loss.backward(), fullgraph=True)()

或者,您可以使用上下文管理器,它將應用於其作用域內的所有自動梯度呼叫。

def train(model, x):
   model = torch.compile(model)
   loss = model(x).sum()
   with torch._dynamo.compiled_autograd.enable(torch.compile(fullgraph=True)):
      loss.backward()

已編譯的自動梯度解決了 AOTAutograd 的某些限制#

  1. 前向傳播中的圖中斷不再必然導致反向傳播中的圖中斷。

@torch.compile(backend="aot_eager")
def fn(x):
   # 1st graph
   temp = x + 10
   torch._dynamo.graph_break()
   # 2nd graph
   temp = temp + 10
   torch._dynamo.graph_break()
   # 3rd graph
   return temp.sum()

x = torch.randn(10, 10, requires_grad=True)
torch._dynamo.utils.counters.clear()
loss = fn(x)

# 1. base torch.compile
loss.backward(retain_graph=True)
assert(torch._dynamo.utils.counters["stats"]["unique_graphs"] == 3)
torch._dynamo.utils.counters.clear()

# 2. torch.compile with compiled autograd
with torch._dynamo.compiled_autograd.enable(torch.compile(backend="aot_eager")):
   loss.backward()

# single graph for the backward
assert(torch._dynamo.utils.counters["stats"]["unique_graphs"] == 1)

在第一個 torch.compile 案例中,我們看到由於編譯函式 fn 中的 2 次圖中斷,產生了 3 個反向傳播圖。而在第二個使用已編譯自動梯度的 torch.compile 案例中,我們看到即使存在圖中斷,也跟蹤了一個完整的反向傳播圖。

注意

Dynamo 在跟蹤已編譯自動梯度捕獲的反向傳播鉤子時,仍有可能發生圖中斷。

  1. 現在可以捕獲反向傳播鉤子了。

@torch.compile(backend="aot_eager")
def fn(x):
   return x.sum()

x = torch.randn(10, 10, requires_grad=True)
x.register_hook(lambda grad: grad+10)
loss = fn(x)

with torch._dynamo.compiled_autograd.enable(torch.compile(backend="aot_eager")):
   loss.backward()

圖中應該有一個 call_hook 節點,Dynamo 稍後會將其內聯到以下內容:

../_images/call_hook_node.png

已編譯自動梯度的常見重新編譯原因#

  1. 由於損失值自動梯度結構的變化

torch._dynamo.config.compiled_autograd = True
x = torch.randn(10, requires_grad=True)
for op in [torch.add, torch.sub, torch.mul, torch.div]:
   loss = op(x, x).sum()
   torch.compile(lambda: loss.backward(), backend="eager")()

在上面的示例中,我們在每次迭代時呼叫不同的運算子,導致 loss 跟蹤不同的自動梯度歷史。您應該會看到一些重新編譯訊息:由於新的自動梯度節點導致快取未命中

../_images/recompile_due_to_node.png
  1. 由於張量形狀發生變化

torch._dynamo.config.compiled_autograd = True
for i in [10, 100, 10]:
   x = torch.randn(i, i, requires_grad=True)
   loss = x.sum()
   torch.compile(lambda: loss.backward(), backend="eager")()

在上面的示例中,x 的形狀發生變化,在第一次變化後,已編譯的自動梯度會將 x 標記為動態形狀張量。您應該會看到重新編譯訊息:由於形狀變化導致快取未命中

../_images/recompile_due_to_dynamic.png

結論#

在本教程中,我們概述了 torch.compile 與已編譯自動梯度的生態系統、已編譯自動梯地的基礎知識以及一些常見的重新編譯原因。請繼續關注 dev-discuss 上的深度探討。