Dynamo 深度解析#
建立時間:2024 年 4 月 2 日 | 最後更新時間:2025 年 8 月 12 日
torch.compile 中的追蹤器是 TorchDynamo(或簡稱 Dynamo),它通常是那些令人費解的堆疊回溯的“罪魁禍首”。然而,我們不能一概而論地將錯誤歸咎於 Dynamo。為了給使用者提供所需的靈活性,Dynamo 肩負著理解任何 Python 程式的艱鉅任務。特別是,Dynamo 需要在內部實現相當一部分 Python 語言的功能!
在本文中,我們將從頭開始介紹 Dynamo 的內部設計。我們將討論它提供的功能以及它的實現方式。閱讀本文後,您將更深入地瞭解當您 torch.compiled 一個 PyTorch 程式並遇到編譯錯誤時,或者編譯成功但速度提升不如預期時,到底發生了什麼。
Dynamo 入門指南#
在深入探討所有實現細節之前,讓我們先討論一下 Dynamo 的作用。
Dynamo 是一個追蹤器。這意味著,給定一個函式及其輸入,它會執行該函式並將指令的線性序列(不包含控制流)記錄到一個圖中。例如,考慮以下程式
import torch
@torch.compile
def mse(x, y):
z = (x - y) ** 2
return z.sum()
x = torch.randn(200)
y = torch.randn(200)
mse(x, y)
如果我們將此程式儲存到檔案 example.py 並執行
TORCH_LOGS=graph_code python example.py
我們將看到 Dynamo 追蹤的輸出
def forward(l_x_: torch.Tensor, l_y_: torch.Tensor):
# File: example.py:5, code: z = (x - y) ** 2
sub = l_x_ - l_y_
z = sub ** 2
# File: example.py:6, code: return z.sum()
sum_1 = z.sum()
return (sum_1,)
我們稱之為給定輸入的函式的圖(或追蹤)。這透過 FX graph 表示。我們將 FX graph 簡單地視為一個儲存函式呼叫列表的容器。
我們首先注意到的是,該圖是 PyTorch 操作的線性序列。1 Dynamo 記錄所有 PyTorch 操作並將它們按順序儲存。例如,它將 z = (x - y) ** 2 分解為其兩個組成操作 sub = l_x_ - l_y_ 和 z = sub ** 2。
當我們說追蹤是線性的時,我們指的是沒有分支或任何控制流。要理解這一點,請考慮
import torch
@torch.compile
def fn(x, n):
y = x ** 2
if n >= 0:
return (n + 1) * y
else:
return y / n
x = torch.randn(200)
fn(x, 2)
當使用 TORCH_LOGS=graph_code 執行時,它會返回
def forward(l_x_: torch.Tensor):
# File: example.py:5, code: y = x ** 2
y = l_x_ ** 2
# File: example.py:7, code: return (n + 1) * y
mul = 3 * y
return (mul,)
我們看到 Dynamo 完全從追蹤中移除了 if 語句,只記錄了使用輸入執行的操作。
因此,應該清楚的是,函式的追蹤取決於輸入。特別地,這意味著當我們在 @torch.compile 中編寫程式碼時,追蹤並不會生成,而是在使用實際引數執行函式 fn(x, 2) 時生成。
另一個值得注意的有趣之處是 Dynamo 移除了函式的第二個引數。相反,它將其視為常量,並在圖中記錄操作 n + 1 的結果。這是 Dynamo 的另一個特性:Dynamo 會將任何非張量值視為常量…除了整數。現在讓我們看看整數是如何特殊的。
Dynamo 的最後一個定義屬性是它知道如何處理動態形狀。符號形狀是指 Dynamo 追蹤形狀的能力,更廣泛地說,是追蹤整數而不是將它們保留為常量。這使得可以避免重新編譯,並在生產環境中部署通用的、適用於任何大小的模型。動態形狀出現的主要例子是批次大小(batch size),我們可以用固定的批次大小訓練模型,然後在推理時使用任意批次大小,或者是在處理文字或音訊時遇到的可變序列長度。
我們可以透過多次執行上面的示例來看到這一點
import torch
@torch.compile
def fn(x, n):
y = x ** 2
if n >= 0:
return (n + 1) * y
else:
return y / n
x = torch.randn(200)
fn(x, 2)
fn(x, 3)
fn(x, -2)
在這種情況下,TORCH_LOGS=graph_code 生成了另外兩個圖
# Graph for n==2 omitted
def forward(self, l_x_: torch.Tensor, l_n_: torch.SymInt):
# File: a.py:5, code: y = x ** 2
y = l_x_ ** 2
# File: a.py:7, code: return (n + 1) * y
add = l_n_ + 1
mul = add * y
return (mul,)
def forward(self, l_x_: torch.Tensor, l_n_: torch.SymInt):
# File: a.py:5, code: y = x ** 2
y = l_x_ ** 2
# File: a.py:9, code: return y / n
truediv = y / l_n_
return (truediv,)
Dynamo 檢測到一個整數在其第一次呼叫後改變了值,並開始追蹤它。我們看到這些圖是通用的,並透過 SymInt 型別的一個物件來符號化地追蹤變數 n。
如果在這些呼叫之後呼叫 fn(x, 4),Dynamo 不會重新編譯,而是會重用已追蹤的圖。
總結:1. Dynamo 是一個 Python 追蹤器 2. 給定一些輸入,它返回一個包含執行的 PyTorch 函式的 FX 圖 3. 如果它檢測到整數在呼叫之間發生變化,它也可以追蹤整數 4. 它會專門化任何非張量或標量值
當然,Dynamo 還做了很多其他事情,比如確定何時需要重新追蹤、重寫函式的位元組碼、實現圖中斷…為了保持介紹簡短,我們將在後續文章中逐步討論所有這些內容。
PEP 523:為 CPython 新增幀評估 API#
現在,想象一下我們的任務是實現 Dynamo。我們從哪裡開始呢?恰好,PEP 523 在 Python 3.6 中釋出了。該 PEP 旨在 讓第三方能夠建立 Python 的 JIT 編譯器。讓我們看看它是如何實現的。
關於 CPython 的說明:CPython 在內部實現為一個棧式虛擬機器。Python 程式被編譯成位元組碼,然後由直譯器執行。要了解更多關於這些位元組碼的資訊,請參閱標準庫中的 dis 模組。另請參閱 開發者文件 以瞭解 CPython 直譯器的介紹。我們假設讀者熟悉棧式虛擬機器這個概念。
PEP 523 公開了一個 API,使用者可以透過該 API 新增自定義的每個函式的直譯器。然後,CPython 將使用該直譯器而不是其自身的直譯器來執行函式。為了能夠執行函式,在進入時,CPython 會向自定義直譯器提供諸如以下內容:- 函式的位元組碼 - 函式引數的值(即區域性變數)及其名稱 - 全域性變數的值及其名稱 - 內建函式,如 abs 或 print
總之,CPython 向用戶的直譯器提供了執行函式所需的所有資訊。3
有了這個 API,我們可以透過實現一個直譯器來建立一個追蹤器,該直譯器執行程式碼並將所有在執行過程中發生的 PyTorch 操作記錄到一個圖中。這正是 Dynamo 的做法。
Dynamo 使用這個 CPython API 來解析所有這些物件,並將它們打包成一個 Python 結構。完成之後…它從 C 回到 Python。除了這段與 CPython 通訊的程式碼之外,Dynamo 完全是用 Python 實現的。
應該清楚的是,裝飾器 @torch.compile 的工作是安裝必要的腳手架,以便在呼叫函式時將位元組碼、引數、全域性變數等傳遞給 Dynamo。同樣,@torch.compile 實際上並沒有編譯任何東西。
用 Python 實現 CPython#
所以,我們回到了 Python 世界。我們擁有函式的位元組碼以及執行它所需的所有上下文。特別是,我們來到了 _convert_frame_assert。這是裝飾器 torch.compile 返回的函式!我們從 _dynamo.optimize 到達此函式。裝飾器 torch.compile 只是 _dynamo.optimize 的一個方便的 API。
在開始實現 Python 直譯器之前,我們想定義一個中間表示 (IR)。特別是,我們想將所有區域性變數和全域性變數包裝在我們自己的內部類中。這使我們能夠更好地跟蹤這些物件,並將可以以相同方式處理的物件組合起來,以便 Dynamo 識別。
內部類結構中的父類是 VariableTracker,它代表 Dynamo 理解的不同物件。例如,ListVariable 代表一個 list 物件,並在內部維護一個 VariableTrackers 列表。另一個 VariableTracker 的例子是 ConstantVariable。ConstantVariable 包裝了所有 Dynamo 認為常量的物件。我們還有專門的子類來處理需要特殊關注的物件,例如 TensorVariable。所有這些內部類都定義在 torch/_dynamo/variables 資料夾中。
Python 物件被包裝到其對應的 VariableTracker 類中,在 VariableBuilder._wrap 中。這個函式只是一個非常長的 elif 鏈,它試圖遞迴地將 Python 輸入模式匹配到適當的 VariableTracker 型別。
除錯技巧。當 Dynamo 產生非預期結果時,有時是由於構建器引起的。如果構建器的邏輯錯誤,有時 Dynamo 可能會將變數包裝到錯誤的 VariableTracker 型別中,這可能會在後續導致問題。在遇到 Dynamo 錯誤時,檢視錯誤中出現的 VariableTracker 型別以及丟擲異常的 VariableTracker 方法非常有幫助。特別是,有時我們會發現一個物件被追蹤為 UserDefinedObjectVariable(這是 Dynamo 的通用類),而它應該被追蹤為更具體的型別。在這些情況下,VariableBuilder 的邏輯通常是罪魁禍首。
除錯技巧。當使用 TORCH_LOGS=dynamo 執行程式時,打印出的一個偽像是這樣的行:
TRACE LOAD_GLOBAL y [TorchInGraphFunctionVariable(<built-in method any>), TensorVariable()]
這是原始程式的位元組碼以及當時堆的狀態。這對於查詢物件未被正確追蹤到 VariableTracker 中的位置非常有用。
好了,我們有了一個追蹤器的 IR,現在我們只需要重新實現 CPython 的棧式虛擬機器。這由 InstructorTranslatorBase 在 symbolic_convert.py 中實現。
InstructionTranslatorBase 擁有大約 200 個方法,實現了幾乎所有的 Python 位元組碼。例如,我們可以看到 BUILD_LIST 的實現
def BUILD_LIST(self, inst):
items = self.popn(inst.argval)
self.push(ListVariable(items, mutation_type=ValueMutationNew()))
這是由 l = [2, 3, 4] 等構造生成的位元組碼。在這種情況下,由於有三個元素,生成的位元組碼是 BUILD_LIST 3。這意味著我們將堆疊頂部的 3 個元素彈出,並將由這三個元素組成的新列表物件推到堆疊頂部。
生成輸出圖#
透過一種符號化執行 Python 程式碼的方法,我們可以提取在給定輸入下對程式進行符號化執行期間發生的 PyTorch 操作。這在 Dynamo 中透過 OutputGraph 物件實現。OutputGraph 物件繫結到 InstructionTranslator 物件,並跟蹤建立 Dynamo 返回的 FX 圖所需的所有資料。
FX 圖的所有輸入和中間元素都是 fx.Node。在 Dynamo 中,fx.Node 被包裝在 fx.Proxy 中。fx.Proxy 用於構建 FX 圖。特別是,它們會將對它們執行的每個 PyTorch 操作記錄到圖中。您可以建立一個新的操作並透過呼叫 create_proxy 來新增到圖中。然後,我們可以透過函式 wrap_fx_proxy 將其新增到圖中。
一個圖儲存對張量的操作…以及對符號整數的操作。我們稍後將討論符號整數,但首先我們將討論 Dynamo 如何解決一個相當重要的正確性問題。
使 Dynamo 正確:Guard#
此時,我們有了一種方法可以完全忽略控制流來追蹤程式。為此,我們重新實現了所有 CPython…如果這聽起來有點殺雞焉用牛刀,那是因為它確實是。torch.jit.trace 已經實現了這一點,而無需所有這些機制,那麼 Dynamo 有何優勢?
正如其文件中所警告的,torch.jit.trace 的問題在於,只有當追蹤的程式不是資料依賴的時,它才能正常工作。換句話說,如果程式本身是線性的,它就能工作。這意味著編寫程式時不能使用 if-else、for-while 迴圈、異常。更重要的是,我們使用的任何庫都不能使用任何控制流!總而言之,在一個像 Python 這樣動態的語言中不使用控制流,實際上是一個巨大的限制。
JAX 透過始終重新追蹤並在重新追蹤後快取圖來解決這個問題。另一方面,Dynamo 使用 Guard 來避免每次都重新追蹤整個程式。
Guard 是一個假設(關於輸入的布林表示式),為了使一個幀專門化為一組示例輸入而做的。重用圖只有在這些假設在新輸入上成立時才有效。
例如,函式中的任何常量輸入(如字串)都會安裝一個 Guard,宣告該輸入應為 str 型別,並且等於我們傳入的字串。執行
import torch
@torch.compile
def fn(a, b):
return a * len(b)
fn(torch.arange(10), "Hello")
並使用 TORCH_LOGS=guards 列印(除其他 Guard 外)
___check_type_id(L['b'], 94334122025024)
L['b'] == 'Hello'
這可以解讀為“區域性變數 b 應該具有特定的型別(在這種情況下為 str,由常量 9433... 表示),並且其值應為 'Hello'”。如果我們然後再次執行該函式並傳入一個不同的引數
import torch
@torch.compile
def fn(a, b):
return a * len(b)
fn(torch.arange(10), "Hello")
fn(torch.arange(10), "Hi")
透過執行 TORCH_LOGS=recompiles,我們可以看到失敗的 Guard
Recompiling function fn in script.py:3
triggered by the following guard failure(s):
- L['b'] == 'Hello'
Guard 在函式輸入被包裝在構建器中以及程式執行期間累積。我們將在下一節中展示更多 Guard 的示例,但首先讓我們討論 Source。
Source 跟蹤如何從進入當前幀時存在的原始區域性變數或全域性變數中重建變數。特別是,它跟蹤原始區域性變數和全域性變數以及它們包含的任何物件。在
def foo(x: Tensor, y: List[Tensor]):
a = x * y[0]
return a * x
x 和 y 的 Source 是 LocalSource,而 y[0] 的 Source 是 GetItemSource,它在內部儲存了一個 LocalSource。另一方面,a 將沒有 Source,因為它是僅存在於 fx 圖中的中間變數。
所有這些都在 torch/_dynamo/source.py 中定義。我們可以在以下示例中看到由 GetItemSource 生成的 Guard
import torch
@torch.compile
def fn(x, l):
return x * len(l[0])
fn(torch.randn(8), ["Hi", "Hello"])
生成以下 Guard
___check_type_id(L['l'], 94439025877664)
len(L['l']) == 2
___check_type_id(L['l'][0], 94439025840192)
L['l'][0] == 'Hi'
___check_type_id(L['l'][1], 94439025840192)
L['l'][1] == 'Hello'
在這裡,我們看到由 GetItemSource([0] 和 [1])生成的程式碼,它包裝了一個 LocalSource(L['l'])。
至此,我們有了 Source 和 Guard,我們就能夠實現一個快取系統,以避免不必要的重新編譯,而無需每次都重新追蹤。我們將在後續文章中更詳細地討論這個快取系統。
細心的讀者會注意到,這還沒有解釋為什麼我們需要對 Python 直譯器進行如此精細地控制,以至於需要重新實現它。我們展示的 Guard 示例依賴於輸入物件,因此我們仍然可以在執行函式之前計算它們。換句話說,我們可以將這個 Guard 系統實現在 torch.jit.trace 的頂層,並以更少的精力獲得相同的功能…進入符號形狀。
符號形狀#
我們在介紹中討論的另一個觀點是,Dynamo 知道如何追蹤整數。為了實現這一點,我們使用一個符號類 torch.SymInt,它像一個 int 一樣工作,但它會將所有對其執行的操作記錄到輸出的 FX 圖中。4 在介紹中介紹符號整數追蹤時,我們已經見過這個類。
現在讓我們討論定義 Dynamo 中符號形狀追蹤的三種屬性以及如何實現它們。
預設靜態#
Dynamo 假定每個整數,無論是輸入還是張量的形狀,預設都是靜態的。換句話說,在函式第一次執行時不會追蹤任何整數。然後,只有當它檢測到在執行過程中整數或形狀的值發生了變化時,它才會追蹤它並生成一個關於該變數的通用圖。
我們在介紹中已經使用整數看到了這種行為。現在讓我們來看一個使用張量形狀的示例。
import torch
@torch.compile
def fn(a, b):
return a.shape[0] * a * b
fn(torch.randn(4, 3), torch.randn(4, 3))
fn(torch.randn(8, 3), torch.randn(8, 3))
使用 TORCH_LOGS=graph_code 執行此程式,我們看到這兩個呼叫被追蹤為
def forward(self, l_a_: torch.Tensor, l_b_: torch.Tensor):
mul = 4 * l_a_
mul_1 = mul * l_b_
return (mul_1,)
def forward(self, s0: torch.SymInt, l_a_: torch.Tensor, l_b_: torch.Tensor):
size = l_a_.size()
getitem = size[0]
mul = getitem * l_a_
mul_1 = mul * l_b_
return (mul_1,)
在第一個圖中,形狀被追蹤為常量,但一旦它發生變化,它就會使用 SymInt 符號化地追蹤它。通常,檢視中間值的形狀的更簡單方法是使用 TORCH_LOGS=graph_sizes 執行程式
TRACED GRAPH TENSOR SIZES
===== __compiled_fn_1 =====
l_a_: (s0, 3)
l_a_ (concrete): (8, 3)
l_b_: (s0, 3)
l_b_ (concrete): (8, 3)
mul: (s0, 3)
mul (concrete): (8, 3)
mul_1: (s0, 3)
mul_1 (concrete): (8, 3)
在那裡,我們可以看到兩個張量引數的第一個維度是動態的,因為它由 s0 變量表示。
我們可以透過執行 TORCH_LOGS=guards 來找到 Dynamo 實現這一點的方法
# Guards first call
check_tensor(L['a'], torch.float32, device=None, requires_grad=False, size=[4, 3], stride=[3, 1])
check_tensor(L['b'], torch.float32, device=None, requires_grad=False, size=[4, 3], stride=[3, 1])
# Guards second call
check_tensor(L['a'], torch.float32, device=None, requires_grad=False, size=[None, 3], stride=[3, 1])
check_tensor(L['b'], torch.float32, device=None, requires_grad=False, size=[None, 3], stride=[3, 1])
L['b'].size()[0] == L['a'].size()[0]
2 <= L['a'].size()[0]
我們看到在第一次呼叫時,Guard 檢查張量是否具有固定的尺寸和步幅。這些 Guard 在第二次執行時會失敗,所以它會重新追蹤。由於這是一個 int Guard 失敗了,所以在第二次迭代中,它會將這個 int 符號化地追蹤,併為這個更通用的核心安裝更通用的 Guard。
編譯效能技巧。如果您知道某個維度的大小會變化,可以在呼叫 torch.compile 之前透過呼叫 torch._dynamo.mark_dynamic 將其標記為動態。這將避免第一次使用靜態形狀進行編譯。還有其他有用的實用函式,如 maybe_mark_dynamic 或 mark_static。您也可以透過呼叫 torch.compile(dynamic=True) 來追蹤所有整數和形狀。這主要用於除錯目的。
0 和 1 總是被專門化#
無論我們是否將某個維度標記為動態,如果我們傳入一個該維度為 0 或 1 的輸入,Dynamo 將會將其追蹤為非動態,併為其生成一個特定的圖。這就是為什麼在上面的示例中我們會發現形式為 2 <= L['a'].size()[0] 的 Guard。
這個選擇有幾個原因。其中有兩個尤其重要:- 張量為空當且僅當其任何維度為零- 張量僅在步幅之一為一時才能是連續的
這個策略決定不適用於普通 Python 整數;如果我們認為一個 Python 整數應該被動態編譯,我們預設不會專門化它們;相反,它是否被專門化取決於它的用法。
鴨子型別形狀(Duck Shaping)#
Dynamo 執行所謂的“鴨子型別形狀”。如果兩個動態整數在追蹤時具有相同的值,我們將假設它們相等併為此建立 Guard。有效地,這意味著我們不再像上面的示例那樣有兩個符號 s0、s1,而是將它們統一為 s0,並擁有 Guard L['b'].size()[0] == L['a'].size()[0]。這使得在編譯器內部進行融合成為可能,同時能夠生成足夠通用的核心。
對符號整數的 Guard#
我們現在在高層面上理解了符號形狀是如何實現的以及它們的屬性。那麼,為什麼符號形狀迫使我們透過控制 CPython 直譯器這種棘手的路線呢?考慮以下示例
import torch
@torch.compile(dynamic=True)
def fn(a):
if a.shape[0] * 2 < 16:
return a
else:
return a + 1
fn(torch.randn(8))
這段程式碼有一個形式為 2*L['a'].size()[0] >= 16 的 Guard。這是一個關於函式輸入的非平凡 Guard,但它是在程式執行中間註冊的。更重要的是,我們直到看到依賴於 SymNodeVariable 引數的 if 語句時,才能知道這個 Guard 是必需的。這些條件對 torch.jit.trace 是不可見的,並且需要對 Python 程式碼進行深度分析。
除錯技巧。使用 TORCH_LOGS=dynamo 執行這段程式碼會告訴我們這個 Guard 是在哪裡新增的
eval 2*s0 >= 16 [guard added] at script.py:5 in fn (_dynamo/variables/tensor.py:812 in evaluate_expr)
在該處設定斷點並查看回溯對於理解 Guard 的來源非常有用。
使 Dynamo 完整:圖中斷(Graph Breaks)#
有了我們討論過的所有工具,我們就擁有了一個追蹤器,它可以追蹤張量和整數上的 PyTorch 操作,並擁有一個知道何時可以重用先前追蹤的圖以及何時需要重新追蹤的快取系統。所有這一切都是在執行任意 Python 程式碼!
這裡只有一個小問題。陳述“執行任意 Python 程式碼”可能有點太籠統了。Dynamo 實現了一大部分 Python,但它是否實現了協程或非同步等更復雜的部分?它是否實現了整個 Python 標準庫?NumPy 也有一個 Python API。 torch.compile 是否也理解 NumPy?以及 Django?5
Python 的生態系統非常龐大,其中很大一部分是用 C++ 或 Rust 等更高效的語言編寫的,它們只是暴露了 Python 繫結。Dynamo 無法追蹤用 C++ 實現的 Python 物件。當追蹤器遇到它不理解的操作時,它能做什麼?
機器學習追蹤器處理這個問題的常規方法是告知使用者它們卡住的操作,然後完全放棄追蹤。在 PyTorch 的情況下,這會帶來真正的可用性問題,因為它的使用者習慣了它提供的靈活性。例如,doctr_det_predictor 模型使用 NumPy 和 cv2 庫來後處理模型的輸出。
這裡是 CPython 的另一個有趣之處。Dynamo 不是丟擲錯誤,而是可以讓 CPython 執行那個有問題的程式碼!為了做到這一點,Dynamo 在追蹤時生成一個圖,包含有問題程式碼之前的所有操作,以及一個包含有問題程式碼之後的所有操作的圖。6 然後,在執行時,它會將執行第一個圖、有問題的程式碼以及第二個圖的任務委託給 CPython。這種停止追蹤並生成多個圖的過程稱為圖中斷。
一個小小的坦白:我在整個介紹和前幾節都說了謊。Dynamo 不只生成一個圖,而是生成多個圖!對於所有實際目的,在第二個圖之後開始重新追蹤可以被認為是在開始追蹤一個新函式。圖中斷後的新圖將有自己的 Guard,自己的一組區域性變數,等等。
為了討論如何實現圖中斷,我們首先需要回顧一下 Dynamo 如何與 CPython 互動。使用 PEP 523,CPython 允許使用者使用自己的幀評估機制。我們還沒有討論的是,CPython 還公開了自己的幀評估供他人使用。Dynamo 利用這一點讓快速的 CPython 直譯器執行編譯後的程式碼。對於沒有圖中斷的函式,程式第一次和第二次呼叫函式(引數相同)的整個追蹤/執行過程如下:
在第一次呼叫函式時
Dynamo 將函式追蹤成一個 FX 圖
FX 圖由編譯器(Inductor)編譯成高效的低階程式碼…但這又是另一個故事了
它重寫函式的位元組碼,使其只需呼叫編譯後的函式
它將這個新的位元組碼提供給 CPython,並要求它執行它在此處
在第二次呼叫函式時
這個過程本身看起來過於複雜。為什麼還要生成新的位元組碼並要求 CPython 執行它,而不是直接建立一個 C++ 繫結到編譯後的函式並執行它?嗯,這個模式允許我們實現圖中斷!由圖中斷生成的位元組碼具有以下結構:
執行第一個圖的位元組碼
離開堆疊的位元組碼,就像 CPython 執行第一個圖時的狀態一樣。它還會重放當時可見的對區域性或全域性變數的任何修改
導致 Dynamo 圖中斷的位元組碼
執行第二個圖的位元組碼
讓我們看一個簡單的例子
import torch
@torch.compile
def fn(a):
b = a + 2
print("Hi")
return b + a
fn(torch.randn(4))
使用 TORCH_LOGS=bytecode 執行此程式碼會向我們顯示初始位元組碼和修改後的位元組碼
MODIFIED BYTECODE fn script.py line 3
0 LOAD_GLOBAL 1 (__compiled_fn_0)
2 LOAD_FAST 0 (a)
4 CALL_FUNCTION 1
6 STORE_FAST 3 (graph_out_0)
8 LOAD_GLOBAL 0 (print)
10 LOAD_CONST 2 ('Hi')
12 LOAD_FAST 3 (graph_out_0)
14 LOAD_CONST 3 (0)
16 BINARY_SUBSCR
18 STORE_FAST 1 (b)
20 CALL_FUNCTION 1
22 LOAD_GLOBAL 2 (__resume_at_14_1)
24 ROT_TWO
26 LOAD_FAST 0 (a)
28 LOAD_FAST 1 (b)
30 CALL_FUNCTION 3
32 RETURN_VALUE
MODIFIED BYTECODE resume_in_fn script.py line 6
0 LOAD_GLOBAL 1 (__compiled_fn_2)
2 LOAD_FAST 2 (b)
4 LOAD_FAST 1 (a)
6 CALL_FUNCTION 2
8 UNPACK_SEQUENCE 1
10 RETURN_VALUE
我們可以看到修改後的位元組碼被分成兩個函式:fn(原始函式)和一個名為 resume_in_fn 的函式。第二個函式是由 Dynamo 建立的,用於實現從圖中斷開始的程式執行。這通常被稱為續延函式 (continuation function)。此續延函式只需使用正確的引數呼叫第二個編譯後的函式。初始函式的程式碼被重寫,實現了我們之前描述的策略:
L0-4. 呼叫編譯後的函式(
a + 2)。L6. 將其結果儲存在名為
graph_out_0的區域性變數中。graph_out_0是一個元組L8-18. 離開堆疊,使其處於圖中斷時的狀態
L20. 執行導致圖中斷的程式碼
L22-32. 呼叫編譯後的續延函式(
a + b)
Dynamo 中堆疊的程式碼生成委託給了 VariableTracker 子類。Dynamo 中的每個 VariableTracker 物件都有一個 reconstruct 方法,該方法生成必要的位元組碼以在堆疊上建立它所代表的 Python 物件。
除錯技巧。圖中斷會影響效能,因此最好避免它們。使用 TORCH_LOGS=graph_breaks 執行程式是查詢程式觸發了多少圖中斷的一個好方法。它返回的資訊是關於 VariableTracker 物件,所以上面的除錯技巧有時也有助於弄清楚是什麼原因導致了該圖中斷。
結論#
Dynamo 是一個複雜的軟體。一旦你決定實現一個 CPython 直譯器,你就知道這是一段艱難的旅程。話雖如此,我們希望這篇文章能幫助您對其進行一些解釋。
Dynamo (主要)是用 Python 實現的。我們留下了許多指向我們討論過的程式碼片段的連結。我們希望閱讀這些程式碼片段,然後搜尋呼叫它們的地方,或者在它們上面設定斷點並檢視呼叫堆疊,有助於理解程式碼庫的其餘部分。
當然,學習軟體工作原理的最佳方法是擴充套件它。在這種情況下,最好的方法是檢視 GitHub 上的開放 Dynamo 問題。其中許多隻需要對程式碼進行非常小的更改,一旦您找到需要進行更改的地方。
腳註#
以下是本文件中提到的概念的附加詳細資訊和參考。
- 1
在文獻中,這被稱為有向無環圖 (DAG)。
- 2
所有這些繫結程式碼都位於
torch/csrc/dynamo/eval_frame.c中。- 3
在 CPython 術語中,所有這些物件的集合被稱為幀 (frame)。
- 4
還有
SymBool和SymFloat類。後者在撰寫本文時使用得不多。- 5
有趣的是,它確實理解 NumPy 程式碼!可以看看這篇博文和文件。現在,這之所以成為可能,僅僅是因為我們用 PyTorch 重寫了 NumPy。不過,要想用 PyTorch 實現 Django,祝你好運…
- 6
假設只有一段有問題的程式碼。如果有多段,Dynamo 可以將程式碼分割成任意數量的圖。