torch.export IR 規範#
創建於: 2023年10月05日 | 最後更新於: 2025年06月13日
Export IR 是一個用於編譯器的中間表示 (IR),它與 MLIR 和 TorchScript 類似。它專門用於表達 PyTorch 程式的語義。Export IR 主要以簡化的操作列表來表示計算,對控制流等動態性支援有限。
要建立 Export IR 圖,可以使用前端,透過一個跟蹤特殊化機制來可靠地捕獲 PyTorch 程式。生成的 Export IR 隨後可以由後端進行最佳化和執行。目前可以透過 torch.export.export() 來實現這一點。
本文件將涵蓋的關鍵概念包括:
ExportedProgram:包含 Export IR 程式的的資料結構
Graph:由節點列表組成。
Nodes:代表操作、控制流以及儲存在該節點上的元資料。
值由節點生成和消耗。
型別與值和節點相關聯。
還定義了值的尺寸和記憶體佈局。
什麼是 Export IR#
Export IR 是 PyTorch 程式的基於圖的中間表示 IR。Export IR 實現於 torch.fx.Graph 之上。換句話說,**所有 Export IR 圖也是有效的 FX 圖**,如果使用標準的 FX 語義進行解釋,Export IR 可以被可靠地解釋。一個隱含的結論是,透過標準的 FX 程式碼生成,匯出的圖可以被轉換為有效的 Python 程式。
本文件將主要關注 Export IR 與 FX 在嚴格性方面的差異,並跳過它們共享的相似部分。
ExportedProgram#
頂級的 Export IR 構造是 torch.export.ExportedProgram 類。它將 PyTorch 模型(通常是 torch.nn.Module)的計算圖與該模型消耗的引數或權重捆綁在一起。
torch.export.ExportedProgram 類的一些值得注意的屬性包括:
graph_module(torch.fx.GraphModule):包含 PyTorch 模型展平計算圖的資料結構。可以透過ExportedProgram.graph直接訪問該圖。graph_signature(torch.export.ExportGraphSignature):圖簽名,它指定了圖中使用的引數和緩衝區名稱以及被修改的引數和緩衝區。它不是將引數和緩衝區儲存為圖的屬性,而是將它們提升為圖的輸入。graph_signature 用於跟蹤這些引數和緩衝區上的附加資訊。state_dict(Dict[str, Union[torch.Tensor, torch.nn.Parameter]]):包含引數和緩衝區的的資料結構。range_constraints(Dict[sympy.Symbol, RangeConstraint]):對於具有資料依賴行為匯出的程式,每個節點上的元資料將包含符號形狀(看起來像s0、i0)。此屬性將符號形狀對映到它們的下限/上限範圍。
Graph#
Export IR Graph 是以 DAG(有向無環圖)形式表示的 PyTorch 程式。圖中的每個節點代表一個特定的計算或操作,圖的邊由節點之間的引用組成。
我們可以將 Graph 看作具有以下模式:
class Graph:
nodes: List[Node]
在實踐中,Export IR 的圖是透過 torch.fx.Graph Python 類實現的。
Export IR 圖包含以下節點(節點將在下一節更詳細地描述):
0 個或多個
placeholder型別的節點0 個或多個
call_function型別的節點恰好 1 個
output型別的節點
推論: 最小的有效 Graph 將是單個節點。即節點永遠不會為空。
定義: Graph 的 placeholder 節點集代表 GraphModule 的**輸入**。Graph 的 output 節點代表 GraphModule 的**輸出**。
示例
import torch
from torch import nn
class MyModule(nn.Module):
def forward(self, x, y):
return x + y
example_args = (torch.randn(1), torch.randn(1))
mod = torch.export.export(MyModule(), example_args)
print(mod.graph)
graph():
%x : [num_users=1] = placeholder[target=x]
%y : [num_users=1] = placeholder[target=y]
%add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %y), kwargs = {})
return (add,)
以上是 Graph 的文字表示,每一行代表一個節點。
Node#
Node 代表一個特定的計算或操作,並使用 torch.fx.Node 類在 Python 中表示。節點之間的邊透過 Node 類的 args 屬性直接表示為對其他節點的引用。使用相同的 FX 機制,我們可以表示計算圖通常需要的以下操作,例如操作呼叫、佔位符(也稱為輸入)、條件和迴圈。
Node 具有以下模式:
class Node:
name: str # name of node
op_name: str # type of operation
# interpretation of the fields below depends on op_name
target: [str|Callable]
args: List[object]
kwargs: Dict[str, object]
meta: Dict[str, object]
FX 文字格式
如上例所示,請注意,每行都遵循以下格式:
%<name>:[...] = <op_name>[target=<target>](args = (%arg1, %arg2, arg3, arg4, …)), kwargs = {"keyword": arg5})
此格式以緊湊的方式捕獲了 Node 類中的所有內容,meta 除外。
具體來說:
<name> 是節點在
node.name中出現的名稱。<op_name> 是
node.op欄位,必須是以下之一:<call_function>、<placeholder>、<get_attr>或<output>。<target> 是節點作為
node.target的目標。此欄位的含義取決於op_name。args1, … args 4… 是
node.args元組中列出的內容。如果列表中的值為torch.fx.Node,則會以前導的 % 特別指示。
例如,對 add 運算子的呼叫將顯示為
%add1 = call_function[target = torch.op.aten.add.Tensor](args = (%x, %y), kwargs = {})
其中 %x、%y 是另外兩個名為 x 和 y 的節點。值得注意的是,字串 torch.op.aten.add.Tensor 代表實際儲存在 target 欄位中的可呼叫物件,而不僅僅是其字串名稱。
這種文字格式的最後一行是
return [add]
這是一個 op_name = output 的節點,表示我們正在返回該元素。
call_function#
一個 call_function 節點表示對運算子的呼叫。
定義
函式式: 我們說一個可呼叫物件是“函式式”的,如果它滿足以下所有要求:
非變異:運算子不會改變其輸入的 D值(對於張量,這包括元資料和資料)。
無副作用:運算子不會改變從外部可見的狀態,例如更改模組引數的值。
運算子: 是具有預定義模式的函式式可呼叫物件。此類運算子的示例包括函式式 ATen 運算子。
在 FX 中的表示
%name = call_function[target = operator](args = (%x, %y, …), kwargs = {})
與普通 FX call_function 的區別
在 FX 圖中,call_function 可以引用任何可呼叫物件,而在 Export IR 中,我們將其限制為僅選擇一部分 ATen 運算子、自定義運算子和控制流運算子。
在 Export IR 中,常量引數將嵌入到圖中。
在 FX 圖中,get_attr 節點可以表示讀取圖模組中儲存的任何屬性。然而,在 Export IR 中,這被限制為僅讀取子模組,因為所有引數/緩衝區都將作為輸入傳遞給圖模組。
元資料#
Node.meta 是附加到每個 FX 節點的字典。但是,FX 規範並未指定哪些元資料可能存在或將會存在。Export IR 提供了更強的約定,特別是所有 call_function 節點都將保證具有且僅具有以下元資料欄位:
node.meta["stack_trace"]是一個字串,包含引用原始 Python 原始碼的 Python 堆疊跟蹤。堆疊跟蹤示例看起來像File "my_module.py", line 19, in forward return x + dummy_helper(y) File "helper_utility.py", line 89, in dummy_helper return y + 1
node.meta["val"]描述了執行操作的輸出。它可以是<symint>、<FakeTensor>、List[Union[FakeTensor, SymInt]]或None型別。node.meta["nn_module_stack"]描述了節點來自的torch.nn.Module的“堆疊跟蹤”,如果它來自torch.nn.Module呼叫。例如,如果一個包含addmm運算子的節點是從torch.nn.Linear模組內部的torch.nn.Sequential模組呼叫的,則nn_module_stack將如下所示:{'self_linear': ('self.linear', <class 'torch.nn.Linear'>), 'self_sequential': ('self.sequential', <class 'torch.nn.Sequential'>)}node.meta["source_fn_stack"]包含在分解之前呼叫該節點的 torch 函式或葉子torch.nn.Module類。例如,一個包含來自torch.nn.Linear模組呼叫的addmm運算子的節點將在其source_fn中包含torch.nn.Linear,而一個包含來自torch.nn.functional.Linear模組呼叫的addmm運算子的節點將在其source_fn中包含torch.nn.functional.Linear。
placeholder#
Placeholder 代表圖的輸入。其語義與 FX 中的完全相同。Placeholder 節點必須是圖中節點列表的前 N 個節點。N 可以為零。
在 FX 中的表示
%name = placeholder[target = name](args = ())
target 欄位是輸入名稱的字串。
args(如果非空)應大小為 1,表示此輸入的預設值。
元資料
Placeholder 節點也具有 meta[‘val’],就像 call_function 節點一樣。在這種情況下,val 欄位表示圖在編譯時預期接收的該輸入的形狀/dtype。
output#
輸出呼叫代表函式中的 return 語句;因此,它終止了當前圖。只有一個輸出節點,並且它將始終是圖的最後一個節點。
在 FX 中的表示
output[](args = (%something, …))
這與 torch.fx 中的語義完全相同。args 表示要返回的節點。
元資料
輸出節點的元資料與 call_function 節點相同。
get_attr#
get_attr 節點表示從封裝的 torch.fx.GraphModule 讀取子模組。與 torch.fx.symbolic_trace() 的普通 FX 圖不同,在普通 FX 圖中 get_attr 節點用於從頂層 torch.fx.GraphModule 讀取引數和緩衝區等屬性,在 Export IR 中,引數和緩衝區作為輸入傳遞給圖模組,並存儲在頂層 torch.export.ExportedProgram 中。
在 FX 中的表示
%name = get_attr[target = name](args = ())
示例
考慮以下模型
from functorch.experimental.control_flow import cond
def true_fn(x):
return x.sin()
def false_fn(x):
return x.cos()
def f(x, y):
return cond(y, true_fn, false_fn, [x])
圖
graph():
%x_1 : [num_users=1] = placeholder[target=x_1]
%y_1 : [num_users=1] = placeholder[target=y_1]
%true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
%false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
%conditional : [num_users=1] = call_function[target=torch.ops.higher_order.cond](args = (%y_1, %true_graph_0, %false_graph_0, [%x_1]), kwargs = {})
return conditional
行 %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0] 讀取包含 sin 運算子的子模組 true_graph_0。
參考文獻#
SymInt#
SymInt 是一個物件,它可以是字面整數,也可以是表示整數的符號(在 Python 中用 sympy.Symbol 類表示)。當 SymInt 是符號時,它描述了一個在編譯時對圖未知的整數型別變數,也就是說,它的值僅在執行時才知道。
FakeTensor#
FakeTensor 是一個包含張量元資料quoi的物件。它可以被視為具有以下元資料。
class FakeTensor:
size: List[SymInt]
dtype: torch.dtype
device: torch.device
dim_order: List[int] # This doesn't exist yet
FakeTensor 的 size 欄位是整數或 SymInts 的列表。如果存在 SymInts,則表示此張量具有動態形狀。如果存在整數,則假定張量將具有該確切的靜態形狀。TensorMeta 的秩永遠不是動態的。dtype 欄位表示該節點輸出的 dtype。Edge IR 中沒有隱式型別提升。FakeTensor 中沒有 strides。
換句話說
如果 node.target 中的運算子返回一個 Tensor,則
node.meta['val']是一個描述該張量的 FakeTensor。如果 node.target 中的運算子返回一個 n 元組的 Tensor,則
node.meta['val']是一個描述每個張量的 n 元組的 FakeTensors。如果 node.target 中的運算子返回一個在編譯時已知的 int/float/scalar,則
node.meta['val']為 None。如果 node.target 中的運算子返回一個在編譯時未知的 int/float/scalar,則
node.meta['val']的型別為 SymInt。
例如
aten::add返回一個 Tensor;因此,其規範將是描述該運算子返回的張量的 dtype 和大小的 FakeTensor。aten::sym_size返回一個整數;因此,其 val 將是 SymInt,因為其值僅在執行時可用。max_pool2d_with_indexes返回一個(Tensor,Tensor)元組;因此,規範也將是一個 FakeTensor 物件的 2 元組,第一個 TensorMeta 描述返回值的第一個元素,依此類推。
Python 程式碼
def add_one(x):
return torch.ops.aten(x, 1)
圖
graph():
%ph_0 : [#users=1] = placeholder[target=ph_0]
%add_tensor : [#users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%ph_0, 1), kwargs = {})
return [add_tensor]
FakeTensor
FakeTensor(dtype=torch.int, size=[2,], device=CPU)
Pytree-able 型別#
我們將一種型別定義為“Pytree-able”,如果它是一個葉子型別或一個包含其他 Pytree-able 型別的容器型別。
注意
Pytree 的概念與 JAX 的文件 此處 記錄的概念相同。
以下型別定義為葉子型別:
型別 |
定義 |
|---|---|
張量 |
|
Scalar |
Python 中的任何數值型別,包括整數型別、浮點型別和零維張量。 |
int |
Python int(在 C++ 中繫結為 int64_t) |
浮點數 |
Python float(在 C++ 中繫結為 double) |
布林值 |
Python bool |
str |
Python string |
ScalarType |
|
Layout |
|
MemoryFormat |
|
裝置 |
以下型別定義為容器型別:
型別 |
定義 |
|---|---|
Tuple |
Python tuple |
List |
Python list |
Dict |
鍵為 Scalar 的 Python dict |
NamedTuple |
Python namedtuple |
Dataclass |
必須透過 register_dataclass 註冊 |
Custom class |
透過 _register_pytree_node 定義的任何自定義類 |