torch.export 程式設計模型#
建立時間:2024 年 12 月 18 日 | 最後更新時間:2025 年 7 月 16 日
本文件旨在解釋 torch.export.export() 的行為和能力。它旨在幫助您直觀地理解 torch.export.export() 如何處理程式碼。
追蹤基礎知識#
torch.export.export() 透過在“示例”輸入上追蹤其執行並記錄追蹤路徑上觀察到的 PyTorch 操作和條件來捕獲表示您模型的圖。然後,只要輸入滿足相同的條件,就可以使用不同的輸入執行此圖。
torch.export.export() 的基本輸出是單個 PyTorch 操作圖,以及相關的元資料。該輸出的確切格式在 export IR 規範 中介紹。
嚴格追蹤與非嚴格追蹤#
torch.export.export() 提供兩種追蹤模式。
在非嚴格模式下,我們使用標準的 Python 直譯器進行追蹤。您的程式碼會像在 eager 模式下一樣執行;唯一的區別是所有 Tensor 都被替換為 Fake Tensor,這些 Fake Tensor 具有形狀和其他元資料,但沒有資料,並被封裝在 Proxy 物件 中,這些 Proxy 物件會將所有對它們的運算記錄到一個圖中。我們還捕獲 Tensor 形狀的條件,這些條件會保護生成程式碼的正確性。
在嚴格模式下,我們首先使用 TorchDynamo(一個 Python 位元組碼分析引擎)進行追蹤。TorchDynamo 實際上不執行您的 Python 程式碼。相反,它會對其進行符號分析並根據結果構建圖。一方面,這種分析允許 torch.export.export() 提供對 Python 級別安全性的額外保證(除了在非嚴格模式下捕獲 Tensor 形狀的條件之外)。另一方面,並非所有 Python 特性都支援這種分析。
儘管目前預設的追蹤模式是嚴格模式,但我們強烈建議使用非嚴格模式,該模式很快將成為預設模式。對於大多數模型而言,Tensor 形狀的條件足以保證正確性,並且對 Python 級別安全性的額外保證沒有影響;同時,在 TorchDynamo 中遇到不支援的 Python 特性的可能性會帶來不必要的風險。
在本檔的其餘部分,我們假設我們正在以 非嚴格模式 進行追蹤;特別是,我們假設所有 Python 特性都得到支援。
值:靜態與動態#
理解 torch.export.export() 行為的一個關鍵概念是靜態值與動態值之間的區別。
靜態值#
靜態值是在匯出時固定的,在匯出程式執行之間不會改變的值。當在追蹤過程中遇到該值時,我們將其視為常量並將其硬編碼到圖中。
當執行一個操作(例如 x + y)且所有輸入都是靜態的時,該操作的輸出將被直接硬編碼到圖中,並且該操作不會顯示(即它會被“常量摺疊”)。
當一個值被硬編碼到圖中時,我們稱該圖已被專門化到該值。例如
import torch
class MyMod(torch.nn.Module):
def forward(self, x, y):
z = y + 7
return x + z
m = torch.export.export(MyMod(), (torch.randn(1), 3))
print(m.graph_module.code)
"""
def forward(self, arg0_1, arg1_1):
add = torch.ops.aten.add.Tensor(arg0_1, 10); arg0_1 = None
return (add,)
"""
在這裡,我們將 3 作為 y 的追蹤值;它被視為一個靜態值,並被加到 7 中,從而在圖中固化了靜態值 10。
動態值#
動態值是可以從一次執行到另一次執行而改變的值。它的行為就像一個“正常”的函式引數:您可以傳遞不同的輸入並期望您的函式執行正確。
哪些值是靜態的,哪些是動態的?#
一個值是靜態還是動態取決於它的型別
對於 Tensor
Tensor 的資料被視為動態。
Tensor 的形狀可以被系統視為靜態或動態。
預設情況下,所有輸入 Tensor 的形狀都被視為靜態。使用者可以透過為任何輸入 Tensor 指定 動態形狀 來覆蓋此行為。
作為模組狀態一部分的 Tensor,即引數和緩衝區,始終具有靜態形狀。
Tensor 的其他形式的元資料(例如
device、dtype)是靜態的。
Python原始型別(
int、float、bool、str、None)是靜態的。對於某些原始型別有動態變體(
SymInt、SymFloat、SymBool)。通常使用者不必處理它們。使用者可以透過為其指定 動態形狀 來將整數輸入指定為動態。
對於 Python標準容器(
list、tuple、dict、namedtuple)結構(即
list和tuple的長度,dict和namedtuple的鍵序列)是靜態的。包含的元素遞迴地應用這些規則(基本上是 PyTree 方案),其中葉子是 Tensor 或原始型別。
其他類(包括資料類)可以註冊到 PyTree(見下文),並遵循與標準容器相同的規則。
輸入型別#
輸入將根據其型別(如上所述)被視為靜態或動態。
靜態輸入將被硬編碼到圖中,在執行時傳遞不同的值將導致錯誤。請記住,這些主要是原始型別的值。
動態輸入就像“正常”函式輸入一樣。請記住,這些主要是 Tensor 型別的值。
預設情況下,您可以使用以下型別的輸入來執行程式:
張量
Python 原始型別(
int、float、bool、str、None)Python 標準容器(
list、tuple、dict、namedtuple)
自定義輸入型別(PyTree)#
此外,您還可以定義自己的(自定義)類並將其用作輸入型別,但您需要將此類註冊為 PyTree。
以下是一個使用實用程式註冊用作輸入型別的資料類的示例。
@dataclass
class Input:
f: torch.Tensor
p: torch.Tensor
import torch.utils._pytree as pytree
pytree.register_dataclass(Input)
class M(torch.nn.Module):
def forward(self, x: Input):
return x.f + 1
torch.export.export(M(), (Input(f=torch.ones(10, 4), p=torch.zeros(10, 4)),))
可選輸入型別#
對於程式的可選輸入但未傳入的引數,torch.export.export() 將專門化其預設值。因此,匯出的程式將要求使用者顯式傳入所有引數,並且會丟失預設行為。例如
class M(torch.nn.Module):
def forward(self, x, y=None):
if y is not None:
return y * x
return x + x
# Optional input is passed in
ep = torch.export.export(M(), (torch.randn(3, 3), torch.randn(3, 3)))
print(ep)
"""
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, x: "f32[3, 3]", y: "f32[3, 3]"):
# File: /data/users/angelayi/pytorch/moo.py:15 in forward, code: return y * x
mul: "f32[3, 3]" = torch.ops.aten.mul.Tensor(y, x); y = x = None
return (mul,)
"""
# Optional input is not passed in
ep = torch.export.export(M(), (torch.randn(3, 3),))
print(ep)
"""
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, x: "f32[3, 3]", y):
# File: /data/users/angelayi/pytorch/moo.py:16 in forward, code: return x + x
add: "f32[3, 3]" = torch.ops.aten.add.Tensor(x, x); x = None
return (add,)
"""
控制流:靜態與動態#
torch.export.export() 支援控制流。控制流的行為取決於您分支的值是靜態還是動態。
靜態控制流#
基於靜態值的 Python 控制流得到透明支援。(請記住,靜態值包括靜態形狀,因此基於靜態形狀的控制流也包含在此情況中。)
如上所述,我們“固化”靜態值,因此匯出的圖將永遠不會看到任何基於靜態值的控制流。
在 if 語句的情況下,我們將繼續追蹤匯出時選擇的分支。在 for 或 while 語句的情況下,我們將透過展開迴圈繼續追蹤。
動態控制流:依賴形狀 vs. 依賴資料#
當控制流中涉及的值是動態的時,它可能依賴於動態形狀或動態資料。考慮到編譯器使用形狀資訊而不是資料進行追蹤,這些情況對程式設計模型的影響是不同的。
依賴動態形狀的控制流#
當控制流中涉及的值是 動態形狀 時,在大多數情況下我們也能在追蹤期間得知動態形狀的具體值:關於編譯器如何追蹤此資訊,請參閱以下部分。
在這些情況下,我們稱控制流為依賴形狀的。我們使用動態形狀的具體值來評估條件,以使其為 True 或 False,然後繼續追蹤(如上所述),並額外發出一個對應於剛剛評估的條件的 guard。
否則,控制流被視為依賴資料的。我們無法將條件評估為 True 或 False,因此無法繼續追蹤,必須在匯出時引發錯誤。請參閱下一節。
依賴資料的控制流#
支援對動態值進行依賴資料的控制流,但您必須使用 PyTorch 的顯式運算元之一才能繼續追蹤。不允許在動態值上使用 Python 控制流語句,因為編譯器無法評估繼續追蹤所需的條件,因此必須在匯出時引發錯誤。
我們提供用於表示動態值上的通用條件和迴圈的運算元,例如 torch.cond、torch.map。請注意,只有當您確實想要依賴資料的控制流時,才需要使用它們。
以下是一個依賴資料條件(x.sum() > 0,其中 x 是輸入 Tensor)的 if 語句,使用 torch.cond 重寫。它不必決定追蹤哪個分支,現在兩個分支都被追蹤了。
class M_old(torch.nn.Module):
def forward(self, x):
if x.sum() > 0:
return x.sin()
else:
return x.cos()
class M_new(torch.nn.Module):
def forward(self, x):
return torch.cond(
pred=x.sum() > 0,
true_fn=lambda x: x.sin(),
false_fn=lambda x: x.cos(),
operands=(x,),
)
資料相關控制流的一個特例是涉及 依賴資料的動態形狀:通常,某些中間 Tensor 的形狀依賴於輸入資料而不是輸入形狀(因此不依賴形狀)。在這種情況下,您可以使用斷言來決定條件是 True 還是 False,而不是使用控制流運算元。給定這樣的斷言,我們可以繼續追蹤,並像上面一樣發出 guard。
我們提供用於表示動態形狀斷言的運算元,例如 torch._check。請注意,只有當存在依賴資料的動態形狀上的控制流時,才需要使用它。
以下是一個涉及依賴資料的動態形狀(nz.shape[0] > 0,其中 nz 是呼叫 torch.nonzero() 的結果,一個輸出形狀依賴於輸入資料的運算元)的 if 語句。您可以透過使用 torch._check 新增斷言來有效地決定追蹤哪個分支,而無需重寫。
class M_old(torch.nn.Module):
def forward(self, x):
nz = x.nonzero()
if nz.shape[0] > 0:
return x.sin()
else:
return x.cos()
class M_new(torch.nn.Module):
def forward(self, x):
nz = x.nonzero()
torch._check(nz.shape[0] > 0)
if nz.shape[0] > 0:
return x.sin()
else:
return x.cos()
符號形狀基礎知識#
在追蹤過程中,動態 Tensor 形狀及其上的條件被編碼為“符號表達式”。(相比之下,靜態 Tensor 形狀及其上的條件就是 int 和 bool 值。)
符號類似於變數;它描述了一個動態 Tensor 形狀。
隨著追蹤的進行,中間 Tensor 的形狀可能由更通用的表示式描述,通常涉及整數算術運算子。這是因為對於大多數 PyTorch 運算元,輸出 Tensor 的形狀可以描述為輸入 Tensor 形狀的函式。例如,torch.cat() 的輸出形狀是其輸入形狀的總和。
此外,當我們遇到程式中的控制流時,我們會建立布林表示式,通常涉及關係運算符,以描述追蹤路徑上的條件。這些表示式被評估以決定要追蹤的程式路徑,並記錄在 形狀環境中,以保護追蹤路徑的正確性並評估後續建立的表示式。
我們將在下面簡要介紹這些子系統。
PyTorch 運算元的 Fake 實現#
請注意,在追蹤過程中,我們使用 Fake Tensor 執行程式,這些 Fake Tensor 沒有資料。通常我們無法使用 Fake Tensor 呼叫實際的 PyTorch 運算元實現。因此,每個運算元都需要有一個額外的 Fake(也稱為“meta”)實現,該實現接受和輸出 Fake Tensor,並在形狀和其他由 Fake Tensor 攜帶的元資料方面與實際實現的行為匹配。
例如,注意 torch.index_select() 的 Fake 實現如何使用輸入形狀來計算輸出形狀(同時忽略輸入資料並返回空輸出資料)。
def meta_index_select(self, dim, index):
result_size = list(self.size())
if self.dim() > 0:
result_size[dim] = index.numel()
return self.new_empty(result_size)
形狀傳播:Backed 與 Unbacked 動態形狀#
形狀使用 PyTorch 運算元的 Fake 實現進行傳播。
理解動態形狀傳播的一個關鍵概念是backed(已支援)和unbacked(未支援)動態形狀之間的區別:前者我們知道具體值,後者則不知道。
形狀的傳播,包括追蹤 backed 和 unbacked 動態形狀,過程如下:
代表輸入的 Tensor 的形狀可以是靜態的,也可以是動態的。當是動態的時,它們由符號描述;此外,這些符號是 backed 的,因為在匯出時我們還知道使用者提供的“真實”示例輸入的具體值。
運算元的輸出形狀由其 Fake 實現計算,可以是靜態的,也可以是動態的。當是動態的時,通常由一個符號表達式描述。此外:
如果輸出形狀僅取決於輸入形狀,則當所有輸入形狀都是靜態的或 backed 動態的時,它就是靜態的或 backed 動態的。
另一方面,如果輸出形狀依賴於輸入資料,它必然是動態的,而且,因為我們不知道其具體值,所以它是 unbacked 的。
控制流:Guards 和 Assertions#
當遇到形狀上的條件時,它要麼只涉及靜態形狀,在這種情況下它是 bool,要麼涉及動態形狀,在這種情況下它是符號布林表示式。對於後者:
當條件僅涉及 backed 動態形狀時,我們可以使用這些動態形狀的具體值將條件評估為
True或False。然後,我們可以向形狀環境新增一個 guard,宣告相應的符號布林表示式為True或False,並繼續追蹤。否則,條件涉及 unbacked 動態形狀。通常,我們無法在沒有額外資訊的情況下評估這種條件;因此,我們無法繼續追蹤,並且必須在匯出時引發錯誤。使用者應使用顯式的 PyTorch 運算元進行追蹤以繼續。此資訊將作為 guard 新增到形狀環境中,並且還可能有助於將其他後續遇到的條件評估為
True或False。
模型匯出後,任何關於 backed 動態形狀的 guard 都可以被理解為對輸入動態形狀的條件。這些條件會與必須提供給 export 的動態形狀規範進行驗證,該規範描述了不僅示例輸入,而且所有未來輸入都應滿足的動態形狀條件,以確保生成程式碼的正確性。更準確地說,動態形狀規範必須邏輯上包含生成的 guard,否則將在匯出時引發錯誤(並提供對動態形狀規範的建議修復)。另一方面,當沒有關於 backed 動態形狀的 guard 生成時(特別是在所有形狀都是靜態的時),則不需要向 export 提供動態形狀規範。通常,動態形狀規範會被轉換為生成程式碼的執行時斷言。
最後,關於 unbacked 動態形狀的任何 guard 都將被轉換為“內聯”執行時斷言。這些斷言被新增到生成程式碼中,在建立那些 unbacked 動態形狀的位置:通常是在資料相關運算元呼叫之後。
允許的 PyTorch 運算元#
允許使用所有 PyTorch 運算元。
自定義運算元#
此外,您還可以定義和使用 自定義運算元。定義自定義運算元包括為其定義 Fake 實現,就像任何其他 PyTorch 運算元一樣(參見上一節)。
這是一個包裝 NumPy 的自定義 sin 運算元的示例,以及它註冊的(平凡)Fake 實現。
@torch.library.custom_op("mylib::sin", mutates_args=())
def sin(x: Tensor) -> Tensor:
x_np = x.numpy()
y_np = np.sin(x_np)
return torch.from_numpy(y_np)
@torch.library.register_fake("mylib::sin")
def _(x: Tensor) -> Tensor:
return torch.empty_like(x)
有時您的自定義運算元的 Fake 實現會涉及依賴資料的形狀。以下是一個自定義 nonzero 的 Fake 實現可能的樣子。
...
@torch.library.register_fake("mylib::custom_nonzero")
def _(x):
nnz = torch.library.get_ctx().new_dynamic_size()
shape = [nnz, x.dim()]
return x.new_empty(shape, dtype=torch.int64)
模組狀態:讀取 vs. 更新#
模組狀態包括引數、緩衝區和常規屬性。
常規屬性可以是任何型別。
另一方面,引數和緩衝區始終是 Tensor。
模組狀態可以是動態的或靜態的,具體取決於它們的型別(如上所述)。例如,self.training 是一個 bool,這意味著它是靜態的;另一方面,任何引數或緩衝區都是動態的。
模組狀態中包含的任何 Tensor 的形狀都不能是動態的,即這些形狀在匯出時是固定的,並且不能在匯出程式的執行之間改變。
訪問規則#
所有模組狀態都必須初始化。訪問尚未初始化的模組狀態將在匯出時引發錯誤。
讀取模組狀態總是被允許的。.
更新模組狀態是可能的,但必須遵循以下規則:
靜態常規屬性(例如,原始型別)可以被更新。讀取和更新可以自由交錯,並且正如預期的那樣,任何讀取都將始終看到最新更新的值。因為這些屬性是靜態的,所以我們也將固化這些值,因此生成程式碼將不會有任何實際“獲取”或“設定”這些屬性的指令。
動態常規屬性(例如,Tensor 型別)不能被更新。要做到這一點,它必須在模組初始化期間註冊為緩衝區。
緩衝區可以被更新,更新可以是原地更新(例如,
self.buffer[:] = ...)或非原地更新(例如,self.buffer = ...)。引數不能被更新。通常,引數僅在訓練期間更新,而在推理期間不更新。我們建議使用
torch.no_grad()進行匯出,以避免在匯出時更新引數。