torch.export 程式設計模型#
創建於:2024年12月18日 | 最後更新於:2025年6月11日
本文件旨在解釋 torch.export.export() 的行為和能力。旨在幫助您直觀理解 torch.export.export() 如何處理程式碼。
追蹤基礎知識#
torch.export.export() 透過追蹤模型在“示例”輸入上的執行,並記錄觀察到的 PyTorch 操作和條件來捕獲代表模型的圖。然後,該圖可以針對滿足相同條件的其他輸入進行執行。
torch.export.export() 的基本輸出是單個 PyTorch 操作圖,以及相關的元資料。輸出的具體格式在 torch.export IR 規範 中有詳細介紹。
嚴格追蹤與非嚴格追蹤#
torch.export.export() 提供兩種追蹤模式。
在非嚴格模式下,我們透過正常的 Python 直譯器來跟蹤程式。您的程式碼的執行方式與貪婪模式完全相同;唯一的區別是所有 Tensor 都被替換為 Fake Tensors,它們具有形狀和其他元資料,但沒有資料,並被包裝在 Proxy 物件 中,這些物件會將所有對它們的運算記錄到一個圖中。我們還捕獲 Tensor 形狀的條件,這些條件用於保證生成程式碼的正確性。
在嚴格模式下,我們首先透過 TorchDynamo(一個 Python 位元組碼分析引擎)來跟蹤程式。TorchDynamo 實際上不執行您的 Python 程式碼。相反,它對其進行符號化分析,並根據結果構建一個圖。一方面,這種分析允許 torch.export.export() 提供額外的 Python 級別安全保證(除了捕獲 Tensor 形狀的條件,如非嚴格模式)。另一方面,並非所有 Python 功能都支援此分析。
儘管目前跟蹤的預設模式是嚴格模式,但我們強烈建議使用非嚴格模式,它將很快成為預設模式。對於大多數模型而言,Tensor 形狀的條件足以保證正確性,而額外的 Python 級別安全保證沒有影響;同時,在 TorchDynamo 中遇到不支援的 Python 功能的可能性會帶來不必要的風險。
在本文件的其餘部分,我們假定我們正在以非嚴格模式進行跟蹤;特別是,我們假定所有 Python 功能都得到支援。
值:靜態 vs. 動態#
理解 torch.export.export() 行為的關鍵概念是靜態和動態值之間的區別。
靜態值#
靜態值是在匯出時固定的,並且在匯出程式的不同執行之間不會改變。在跟蹤期間遇到該值時,我們將其視為常量並將其硬編碼到圖中。
當執行一個運算(例如 x + y)並且所有輸入都是靜態的,該運算的輸出將直接硬編碼到圖中,該運算本身則不會顯示(即它被“常量摺疊”了)。
當一個值被硬編碼到圖中時,我們說該圖已經被特化為該值。例如
import torch
class MyMod(torch.nn.Module):
def forward(self, x, y):
z = y + 7
return x + z
m = torch.export.export(MyMod(), (torch.randn(1), 3))
print(m.graph_module.code)
"""
def forward(self, arg0_1, arg1_1):
add = torch.ops.aten.add.Tensor(arg0_1, 10); arg0_1 = None
return (add,)
"""
在這裡,我們將 3 作為 y 的跟蹤值;它被視為靜態值並加到 7 上,在圖中燒入靜態值 10。
動態值#
動態值是每次執行時都可能改變的值。它的行為就像一個“普通”的函式引數:您可以傳遞不同的輸入並期望函式能夠正確執行。
哪些值是靜態的,哪些是動態的?#
一個值是靜態還是動態取決於它的型別
對於 Tensor
Tensor 的資料被視為動態。
Tensor 的形狀可以被系統視為靜態或動態。
預設情況下,所有輸入 Tensor 的形狀都被視為靜態。使用者可以透過為任何輸入 Tensor 指定一個動態形狀來覆蓋此行為。
作為模組狀態一部分的 Tensor(即引數和緩衝區)始終具有靜態形狀。
其他形式的 Tensor元資料(例如
device、dtype)是靜態的。
Python基本型別(
int、float、bool、str、None)是靜態的。有一些基本型別的動態變體(
SymInt、SymFloat、SymBool)。通常使用者不必處理它們。
對於 Python標準容器(
list、tuple、dict、namedtuple)結構(即
list和tuple的長度,以及dict和namedtuple的鍵序列)是靜態的。包含的元素將遞迴地應用這些規則(基本上是 PyTree 方案),葉子是 Tensor 或基本型別。
其他類(包括資料類)可以註冊為 PyTree(見下文),並遵循與標準容器相同的規則。
輸入型別#
輸入將根據其型別(如上所述)被視為靜態或動態。
靜態輸入將被硬編碼到圖中,在執行時傳遞不同的值將導致錯誤。請記住,這些大多是基本型別的值。
動態輸入表現得像“普通”函式輸入。請記住,這些大多是 Tensor 型別的值。
預設情況下,您可以用於程式的輸入型別是
張量
Python 基本型別(
int、float、bool、str、None)Python 標準容器(
list、tuple、dict、namedtuple)
自定義輸入型別#
此外,您還可以定義自己的(自定義)類並將其用作輸入型別,但您需要將此類註冊為 PyTree。
這是一個使用實用工具註冊用作輸入型別的 dataclass 的示例。
@dataclass
class Input:
f: torch.Tensor
p: torch.Tensor
torch.export.register_dataclass(Input)
class M(torch.nn.Module):
def forward(self, x: Input):
return x.f + 1
torch.export.export(M(), (Input(f=torch.ones(10, 4), p=torch.zeros(10, 4)),))
可選輸入型別#
對於未傳入程式的可選輸入,torch.export.export() 將特化為它們的預設值。因此,匯出的程式將要求使用者顯式傳入所有引數,並會丟失預設行為。例如
class M(torch.nn.Module):
def forward(self, x, y=None):
if y is not None:
return y * x
return x + x
# Optional input is passed in
ep = torch.export.export(M(), (torch.randn(3, 3), torch.randn(3, 3)))
print(ep)
"""
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, x: "f32[3, 3]", y: "f32[3, 3]"):
# File: /data/users/angelayi/pytorch/moo.py:15 in forward, code: return y * x
mul: "f32[3, 3]" = torch.ops.aten.mul.Tensor(y, x); y = x = None
return (mul,)
"""
# Optional input is not passed in
ep = torch.export.export(M(), (torch.randn(3, 3),))
print(ep)
"""
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, x: "f32[3, 3]", y):
# File: /data/users/angelayi/pytorch/moo.py:16 in forward, code: return x + x
add: "f32[3, 3]" = torch.ops.aten.add.Tensor(x, x); x = None
return (add,)
"""
控制流:靜態 vs. 動態#
torch.export.export() 支援控制流。控制流的行為取決於您分支的值是靜態還是動態。
靜態控制流#
關於靜態值的 Python 控制流被透明地支援。(回想一下,靜態值包括靜態形狀,因此關於靜態形狀的控制流也包含在此情況中。)
如上所述,我們“燒入”靜態值,因此匯出的圖將永遠不會看到關於靜態值的任何控制流。
在 if 語句的情況下,我們將繼續跟蹤匯出時所採取的分支。在 for 或 while 語句的情況下,我們將透過展開迴圈來繼續跟蹤。
動態控制流:依賴形狀 vs. 依賴資料#
當控制流中涉及的值是動態的時,它可能依賴於動態形狀或動態資料。鑑於編譯器使用形狀資訊而不是資料進行跟蹤,在這些情況下對程式設計模型的影響是不同的。
動態依賴形狀的控制流#
當控制流中涉及的值是動態形狀時,在大多數情況下我們也會在跟蹤時知道動態形狀的具體值:有關編譯器如何跟蹤此資訊的更多詳細資訊,請參閱下一節。
在這些情況下,我們稱控制流是依賴形狀的。我們使用動態形狀的具體值來評估條件,以確定是True還是False,然後繼續跟蹤(如上所述),並額外發出一個對應於剛剛評估的條件的 guard。
否則,控制流被認為是依賴資料的。我們無法將條件評估為True或False,因此無法繼續跟蹤,並且必須在匯出時引發錯誤。請參閱下一節。
動態依賴資料的控制流#
支援動態值的依賴資料的控制流,但您必須使用 PyTorch 的顯式運算子之一來繼續跟蹤。不允許在動態值上使用 Python 控制流語句,因為編譯器無法評估繼續跟蹤所需的條件,因此必須在匯出時引發錯誤。
我們提供用於表示動態值的通用條件和迴圈的運算子,例如 torch.cond、torch.map。請注意,只有當您確實想要依賴資料的控制流時,才需要使用它們。
這是一個關於資料依賴條件 x.sum() > 0 的 if 語句的示例,其中 x 是一個輸入 Tensor,使用 torch.cond 重寫。與其必須決定跟蹤哪個分支,不如現在兩個分支都會被跟蹤。
class M_old(torch.nn.Module):
def forward(self, x):
if x.sum() > 0:
return x.sin()
else:
return x.cos()
class M_new(torch.nn.Module):
def forward(self, x):
return torch.cond(
pred=x.sum() > 0,
true_fn=lambda x: x.sin(),
false_fn=lambda x: x.cos(),
operands=(x,),
)
資料依賴控制流的一個特殊情況是當它涉及依賴資料的動態形狀時:通常,某些中間 Tensor 的形狀依賴於輸入資料而不是輸入形狀(因此不是依賴形狀的)。在這種情況下,您可以使用斷言來決定條件是True還是False,而不是使用控制流運算子。給定這樣的斷言,我們可以繼續跟蹤,併發出一個 guard。
我們提供用於表示動態形狀斷言的運算子,例如 torch._check。請注意,只有當存在依賴資料的動態形狀的控制流時,您才需要使用它。
這是一個關於涉及依賴資料的動態形狀的條件 nz.shape[0] > 0 的 if 語句的示例,其中 nz 是呼叫 torch.nonzero() 的結果,這是一個輸出形狀依賴於輸入資料的運算子。與其重寫它,不如使用 torch._check 新增斷言來有效地決定跟蹤哪個分支。
class M_old(torch.nn.Module):
def forward(self, x):
nz = x.nonzero()
if nz.shape[0] > 0:
return x.sin()
else:
return x.cos()
class M_new(torch.nn.Module):
def forward(self, x):
nz = x.nonzero()
torch._check(nz.shape[0] > 0)
if nz.shape[0] > 0:
return x.sin()
else:
return x.cos()
符號形狀基礎#
在跟蹤期間,動態 Tensor 形狀和關於它們的條件被編碼為“符號表達式”。(相比之下,靜態 Tensor 形狀和關於它們的條件僅僅是 int 和 bool 值。)
符號就像一個變數;它描述了一個動態 Tensor 形狀。
隨著跟蹤的進行,中間 Tensor 的形狀可能由更通用的表示式描述,通常涉及整數算術運算子。這是因為對於大多數 PyTorch 運算子,輸出 Tensor 的形狀可以描述為輸入 Tensor 形狀的函式。例如,torch.cat() 的輸出形狀是其輸入形狀的總和。
此外,當我們遇到程式中的控制流時,我們會建立布林表示式,通常涉及關係運算符,以描述沿跟蹤路徑的條件。這些表示式用於決定透過程式的哪個路徑進行跟蹤,並記錄在形狀環境中,以保證跟蹤路徑的正確性並評估隨後建立的表示式。
我們將在下面簡要介紹這些子系統。
PyTorch 運算子的 Fake 實現#
回想一下,在跟蹤期間,我們使用Fake Tensors(沒有資料)來執行程式。一般情況下,我們不能用 Fake Tensors 呼叫實際的 PyTorch 運算子實現。因此,每個運算子都需要有一個額外的 Fake(也稱為“meta”)實現,它接收和輸出 Fake Tensors,並且在形狀和 Fake Tensors 所攜帶的其他元資料方面與實際實現的行為匹配。
例如,注意 torch.index_select() 的 Fake 實現如何使用輸入形狀計算輸出形狀(同時忽略輸入資料並返回空輸出資料)。
def meta_index_select(self, dim, index):
result_size = list(self.size())
if self.dim() > 0:
result_size[dim] = index.numel()
return self.new_empty(result_size)
形狀傳播:已支援 vs. 未支援的動態形狀#
形狀透過 PyTorch 運算子的 Fake 實現進行傳播。
要理解動態形狀的傳播,特別是其傳播方式,一個關鍵概念是已支援(backed)和未支援(unbacked)動態形狀之間的區別:我們知道前者(已支援)的具體值,但不知道後者(未支援)的具體值。
形狀的傳播,包括跟蹤已支援和未支援的動態形狀,按如下方式進行
表示輸入的 Tensor 的形狀可以是靜態的或動態的。當是動態的時,它們由符號描述;此外,這些符號是已支援的,因為我們還知道它們在匯出時由使用者提供的“真實”示例輸入的具體值。
運算子的輸出形狀由其 Fake 實現計算,可以是靜態的或動態的。當是動態的時,通常由符號表達式描述。此外
如果輸出形狀僅依賴於輸入形狀,則當所有輸入形狀都是靜態的或已支援的動態的時,它要麼是靜態的,要麼是已支援的動態。
另一方面,如果輸出形狀依賴於輸入資料,那麼它必然是動態的,而且,因為我們無法知道其具體值,所以它是未支援的。
控制流:Guards 和 Assertions#
遇到形狀條件時,它要麼只涉及靜態形狀,在這種情況下它是一個 bool,要麼涉及動態形狀,在這種情況下它是一個符號布林表示式。對於後者
當條件僅涉及已支援的動態形狀時,我們可以使用這些動態形狀的具體值將條件評估為
True或False。然後,我們可以將一個 guard 新增到形狀環境中,宣告相應的符號布林表示式為True或False,然後繼續跟蹤。否則,條件涉及未支援的動態形狀。通常,我們無法在沒有額外資訊的情況下評估此類條件;因此,我們無法繼續跟蹤,並且必須在匯出時引發錯誤。使用者應該使用顯式的 PyTorch 運算子進行跟蹤以繼續。這些資訊作為 guard 新增到形狀環境中,並且還可能幫助評估隨後遇到的其他條件為
True或False。
模型匯出後,任何關於已支援動態形狀的 guard 都可以被理解為關於輸入動態形狀的條件。這些條件將針對匯出時必須提供的動態形狀規範進行驗證,該規範描述了不僅示例輸入,而且所有未來輸入都應滿足的動態形狀條件,以保證生成程式碼的正確性。更準確地說,動態形狀規範必須邏輯上暗示生成的 guard,否則將在匯出時引發錯誤(並提供對動態形狀規範的修復建議)。另一方面,當沒有關於已支援動態形狀的 guard 時(特別是當所有形狀都是靜態的時),則無需嚮導出提供動態形狀規範。通常,動態形狀規範被轉換為生成程式碼的輸入的執行時斷言。
最後,任何關於未支援動態形狀的 guard 都被轉換為“內聯”執行時斷言。這些斷言將新增到生成程式碼中,在那些未支援動態形狀被建立的位置:通常是資料依賴運算子呼叫之後。
允許的 PyTorch 運算子#
所有 PyTorch 運算子都是允許的。
自定義運算子#
此外,您還可以定義和使用自定義運算子。定義自定義運算子包括為其定義 Fake 實現,就像任何其他 PyTorch 運算子一樣(參見上一節)。
這是一個包裝 NumPy 的自定義 sin 運算子及其註冊的(平凡)Fake 實現的示例。
@torch.library.custom_op("mylib::sin", mutates_args=())
def sin(x: Tensor) -> Tensor:
x_np = x.numpy()
y_np = np.sin(x_np)
return torch.from_numpy(y_np)
@torch.library.register_fake("mylib::sin")
def _(x: Tensor) -> Tensor:
return torch.empty_like(x)
有時您的自定義運算子的 Fake 實現會涉及依賴資料的形狀。以下是自定義 nonzero 的 Fake 實現可能的樣子。
...
@torch.library.register_fake("mylib::custom_nonzero")
def _(x):
nnz = torch.library.get_ctx().new_dynamic_size()
shape = [nnz, x.dim()]
return x.new_empty(shape, dtype=torch.int64)
模組狀態:讀取 vs. 更新#
模組狀態包括引數、緩衝區和常規屬性。
常規屬性可以是任何型別。
另一方面,引數和緩衝區始終是 Tensor。
模組狀態可以是動態的或靜態的,具體取決於它們的型別(如上所述)。例如,self.training 是一個 bool,這意味著它是靜態的;另一方面,任何引數或緩衝區都是動態的。
模組狀態中包含的任何 Tensor 的形狀都不能是動態的,也就是說,這些形狀在匯出時是固定的,並且在匯出程式的執行之間不能改變。
訪問規則#
所有模組狀態都必須初始化。訪問尚未初始化的模組狀態將在匯出時引發錯誤。
讀取模組狀態始終是允許的.
更新模組狀態是可能的,但必須遵循以下規則
一個靜態的常規屬性(例如,基本型別)可以被更新。讀取和更新可以自由地交錯,並且如預期,任何讀取都將始終看到最新更新的值。由於這些屬性是靜態的,我們也將燒入值,因此生成的程式碼將不會有任何實際“獲取”或“設定”此類屬性的指令。
一個動態的常規屬性(例如,Tensor 型別)不能被更新。要做到這一點,它必須在模組初始化期間註冊為緩衝區。
緩衝區可以被更新,其中更新可以是就地(例如,
self.buffer[:] = ...)或非就地(例如,self.buffer = ...)。引數不能被更新。通常,引數僅在訓練期間更新,而不是在推理期間更新。我們建議使用
torch.no_grad()進行匯出,以避免在匯出時更新引數。