評價此頁

torch.export IR 規範#

創建於: 2023年10月05日 | 最後更新於: 2025年06月13日

Export IR 是一個用於編譯器的中間表示 (IR),它與 MLIR 和 TorchScript 類似。它專門用於表達 PyTorch 程式的語義。Export IR 主要以簡化的操作列表來表示計算,對控制流等動態性支援有限。

要建立 Export IR 圖,可以使用前端,透過一個跟蹤特殊化機制來可靠地捕獲 PyTorch 程式。生成的 Export IR 隨後可以由後端進行最佳化和執行。目前可以透過 torch.export.export() 來實現這一點。

本文件將涵蓋的關鍵概念包括:

  • ExportedProgram:包含 Export IR 程式的的資料結構

  • Graph:由節點列表組成。

  • Nodes:代表操作、控制流以及儲存在該節點上的元資料。

  • 值由節點生成和消耗。

  • 型別與值和節點相關聯。

  • 還定義了值的尺寸和記憶體佈局。

假設#

本文件假設讀者已充分熟悉 PyTorch,特別是 torch.fx 及其相關工具。因此,它將不再描述 torch.fx 文件和論文中已包含的內容。

什麼是 Export IR#

Export IR 是 PyTorch 程式的基於圖的中間表示 IR。Export IR 實現於 torch.fx.Graph 之上。換句話說,**所有 Export IR 圖也是有效的 FX 圖**,如果使用標準的 FX 語義進行解釋,Export IR 可以被可靠地解釋。一個隱含的結論是,透過標準的 FX 程式碼生成,匯出的圖可以被轉換為有效的 Python 程式。

本文件將主要關注 Export IR 與 FX 在嚴格性方面的差異,並跳過它們共享的相似部分。

ExportedProgram#

頂級的 Export IR 構造是 torch.export.ExportedProgram 類。它將 PyTorch 模型(通常是 torch.nn.Module)的計算圖與該模型消耗的引數或權重捆綁在一起。

torch.export.ExportedProgram 類的一些值得注意的屬性包括:

  • graph_moduletorch.fx.GraphModule):包含 PyTorch 模型展平計算圖的資料結構。可以透過 ExportedProgram.graph 直接訪問該圖。

  • graph_signaturetorch.export.ExportGraphSignature):圖簽名,它指定了圖中使用的引數和緩衝區名稱以及被修改的引數和緩衝區。它不是將引數和緩衝區儲存為圖的屬性,而是將它們提升為圖的輸入。graph_signature 用於跟蹤這些引數和緩衝區上的附加資訊。

  • state_dictDict[str, Union[torch.Tensor, torch.nn.Parameter]]):包含引數和緩衝區的的資料結構。

  • range_constraintsDict[sympy.Symbol, RangeConstraint]):對於具有資料依賴行為匯出的程式,每個節點上的元資料將包含符號形狀(看起來像 s0i0)。此屬性將符號形狀對映到它們的下限/上限範圍。

Graph#

Export IR Graph 是以 DAG(有向無環圖)形式表示的 PyTorch 程式。圖中的每個節點代表一個特定的計算或操作,圖的邊由節點之間的引用組成。

我們可以將 Graph 看作具有以下模式:

class Graph:
  nodes: List[Node]

在實踐中,Export IR 的圖是透過 torch.fx.Graph Python 類實現的。

Export IR 圖包含以下節點(節點將在下一節更詳細地描述):

  • 0 個或多個 placeholder 型別的節點

  • 0 個或多個 call_function 型別的節點

  • 恰好 1 個 output 型別的節點

推論: 最小的有效 Graph 將是單個節點。即節點永遠不會為空。

定義: Graph 的 placeholder 節點集代表 GraphModule 的**輸入**。Graph 的 output 節點代表 GraphModule 的**輸出**。

示例

import torch
from torch import nn

class MyModule(nn.Module):

    def forward(self, x, y):
      return x + y

example_args = (torch.randn(1), torch.randn(1))
mod = torch.export.export(MyModule(), example_args)
print(mod.graph)
graph():
  %x : [num_users=1] = placeholder[target=x]
  %y : [num_users=1] = placeholder[target=y]
  %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %y), kwargs = {})
  return (add,)

以上是 Graph 的文字表示,每一行代表一個節點。

Node#

Node 代表一個特定的計算或操作,並使用 torch.fx.Node 類在 Python 中表示。節點之間的邊透過 Node 類的 args 屬性直接表示為對其他節點的引用。使用相同的 FX 機制,我們可以表示計算圖通常需要的以下操作,例如操作呼叫、佔位符(也稱為輸入)、條件和迴圈。

Node 具有以下模式:

class Node:
  name: str # name of node
  op_name: str  # type of operation

  # interpretation of the fields below depends on op_name
  target: [str|Callable]
  args: List[object]
  kwargs: Dict[str, object]
  meta: Dict[str, object]

FX 文字格式

如上例所示,請注意,每行都遵循以下格式:

%<name>:[...] = <op_name>[target=<target>](args = (%arg1, %arg2, arg3, arg4, …)), kwargs = {"keyword": arg5})

此格式以緊湊的方式捕獲了 Node 類中的所有內容,meta 除外。

具體來說:

  • <name> 是節點在 node.name 中出現的名稱。

  • <op_name>node.op 欄位,必須是以下之一:<call_function><placeholder><get_attr><output>

  • <target> 是節點作為 node.target 的目標。此欄位的含義取決於 op_name

  • args1, … args 4…node.args 元組中列出的內容。如果列表中的值為 torch.fx.Node,則會以前導的 % 特別指示。

例如,對 add 運算子的呼叫將顯示為

%add1 = call_function[target = torch.op.aten.add.Tensor](args = (%x, %y), kwargs = {})

其中 %x%y 是另外兩個名為 x 和 y 的節點。值得注意的是,字串 torch.op.aten.add.Tensor 代表實際儲存在 target 欄位中的可呼叫物件,而不僅僅是其字串名稱。

這種文字格式的最後一行是

return [add]

這是一個 op_name = output 的節點,表示我們正在返回該元素。

call_function#

一個 call_function 節點表示對運算子的呼叫。

定義

  • 函式式: 我們說一個可呼叫物件是“函式式”的,如果它滿足以下所有要求:

    • 非變異:運算子不會改變其輸入的 D值(對於張量,這包括元資料和資料)。

    • 無副作用:運算子不會改變從外部可見的狀態,例如更改模組引數的值。

  • 運算子: 是具有預定義模式的函式式可呼叫物件。此類運算子的示例包括函式式 ATen 運算子。

在 FX 中的表示

%name = call_function[target = operator](args = (%x, %y, …), kwargs = {})

與普通 FX call_function 的區別

  1. 在 FX 圖中,call_function 可以引用任何可呼叫物件,而在 Export IR 中,我們將其限制為僅選擇一部分 ATen 運算子、自定義運算子和控制流運算子。

  2. 在 Export IR 中,常量引數將嵌入到圖中。

  3. 在 FX 圖中,get_attr 節點可以表示讀取圖模組中儲存的任何屬性。然而,在 Export IR 中,這被限制為僅讀取子模組,因為所有引數/緩衝區都將作為輸入傳遞給圖模組。

元資料#

Node.meta 是附加到每個 FX 節點的字典。但是,FX 規範並未指定哪些元資料可能存在或將會存在。Export IR 提供了更強的約定,特別是所有 call_function 節點都將保證具有且僅具有以下元資料欄位:

  • node.meta["stack_trace"] 是一個字串,包含引用原始 Python 原始碼的 Python 堆疊跟蹤。堆疊跟蹤示例看起來像

    File "my_module.py", line 19, in forward
    return x + dummy_helper(y)
    File "helper_utility.py", line 89, in dummy_helper
    return y + 1
    
  • node.meta["val"] 描述了執行操作的輸出。它可以是 <symint><FakeTensor>List[Union[FakeTensor, SymInt]]None 型別。

  • node.meta["nn_module_stack"] 描述了節點來自的 torch.nn.Module 的“堆疊跟蹤”,如果它來自 torch.nn.Module 呼叫。例如,如果一個包含 addmm 運算子的節點是從 torch.nn.Linear 模組內部的 torch.nn.Sequential 模組呼叫的,則 nn_module_stack 將如下所示:

    {'self_linear': ('self.linear', <class 'torch.nn.Linear'>), 'self_sequential': ('self.sequential', <class 'torch.nn.Sequential'>)}
    
  • node.meta["source_fn_stack"] 包含在分解之前呼叫該節點的 torch 函式或葉子 torch.nn.Module 類。例如,一個包含來自 torch.nn.Linear 模組呼叫的 addmm 運算子的節點將在其 source_fn 中包含 torch.nn.Linear,而一個包含來自 torch.nn.functional.Linear 模組呼叫的 addmm 運算子的節點將在其 source_fn 中包含 torch.nn.functional.Linear

placeholder#

Placeholder 代表圖的輸入。其語義與 FX 中的完全相同。Placeholder 節點必須是圖中節點列表的前 N 個節點。N 可以為零。

在 FX 中的表示

%name = placeholder[target = name](args = ())

target 欄位是輸入名稱的字串。

args(如果非空)應大小為 1,表示此輸入的預設值。

元資料

Placeholder 節點也具有 meta[‘val’],就像 call_function 節點一樣。在這種情況下,val 欄位表示圖在編譯時預期接收的該輸入的形狀/dtype。

output#

輸出呼叫代表函式中的 return 語句;因此,它終止了當前圖。只有一個輸出節點,並且它將始終是圖的最後一個節點。

在 FX 中的表示

output[](args = (%something, …))

這與 torch.fx 中的語義完全相同。args 表示要返回的節點。

元資料

輸出節點的元資料與 call_function 節點相同。

get_attr#

get_attr 節點表示從封裝的 torch.fx.GraphModule 讀取子模組。與 torch.fx.symbolic_trace() 的普通 FX 圖不同,在普通 FX 圖中 get_attr 節點用於從頂層 torch.fx.GraphModule 讀取引數和緩衝區等屬性,在 Export IR 中,引數和緩衝區作為輸入傳遞給圖模組,並存儲在頂層 torch.export.ExportedProgram 中。

在 FX 中的表示

%name = get_attr[target = name](args = ())

示例

考慮以下模型

from functorch.experimental.control_flow import cond

def true_fn(x):
    return x.sin()

def false_fn(x):
    return x.cos()

def f(x, y):
    return cond(y, true_fn, false_fn, [x])

graph():
    %x_1 : [num_users=1] = placeholder[target=x_1]
    %y_1 : [num_users=1] = placeholder[target=y_1]
    %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
    %false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
    %conditional : [num_users=1] = call_function[target=torch.ops.higher_order.cond](args = (%y_1, %true_graph_0, %false_graph_0, [%x_1]), kwargs = {})
    return conditional

%true_graph_0 : [num_users=1] = get_attr[target=true_graph_0] 讀取包含 sin 運算子的子模組 true_graph_0

參考文獻#

SymInt#

SymInt 是一個物件,它可以是字面整數,也可以是表示整數的符號(在 Python 中用 sympy.Symbol 類表示)。當 SymInt 是符號時,它描述了一個在編譯時對圖未知的整數型別變數,也就是說,它的值僅在執行時才知道。

FakeTensor#

FakeTensor 是一個包含張量元資料quoi的物件。它可以被視為具有以下元資料。

class FakeTensor:
  size: List[SymInt]
  dtype: torch.dtype
  device: torch.device
  dim_order: List[int]  # This doesn't exist yet

FakeTensor 的 size 欄位是整數或 SymInts 的列表。如果存在 SymInts,則表示此張量具有動態形狀。如果存在整數,則假定張量將具有該確切的靜態形狀。TensorMeta 的秩永遠不是動態的。dtype 欄位表示該節點輸出的 dtype。Edge IR 中沒有隱式型別提升。FakeTensor 中沒有 strides。

換句話說

  • 如果 node.target 中的運算子返回一個 Tensor,則 node.meta['val'] 是一個描述該張量的 FakeTensor。

  • 如果 node.target 中的運算子返回一個 n 元組的 Tensor,則 node.meta['val'] 是一個描述每個張量的 n 元組的 FakeTensors。

  • 如果 node.target 中的運算子返回一個在編譯時已知的 int/float/scalar,則 node.meta['val'] 為 None。

  • 如果 node.target 中的運算子返回一個在編譯時未知的 int/float/scalar,則 node.meta['val'] 的型別為 SymInt。

例如

  • aten::add 返回一個 Tensor;因此,其規範將是描述該運算子返回的張量的 dtype 和大小的 FakeTensor。

  • aten::sym_size 返回一個整數;因此,其 val 將是 SymInt,因為其值僅在執行時可用。

  • max_pool2d_with_indexes 返回一個(Tensor,Tensor)元組;因此,規範也將是一個 FakeTensor 物件的 2 元組,第一個 TensorMeta 描述返回值的第一個元素,依此類推。

Python 程式碼

def add_one(x):
  return torch.ops.aten(x, 1)

graph():
  %ph_0 : [#users=1] = placeholder[target=ph_0]
  %add_tensor : [#users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%ph_0, 1), kwargs = {})
  return [add_tensor]

FakeTensor

FakeTensor(dtype=torch.int, size=[2,], device=CPU)

Pytree-able 型別#

我們將一種型別定義為“Pytree-able”,如果它是一個葉子型別或一個包含其他 Pytree-able 型別的容器型別。

注意

Pytree 的概念與 JAX 的文件 此處 記錄的概念相同。

以下型別定義為葉子型別

型別

定義

張量

torch.Tensor

Scalar

Python 中的任何數值型別,包括整數型別、浮點型別和零維張量。

int

Python int(在 C++ 中繫結為 int64_t)

浮點數

Python float(在 C++ 中繫結為 double)

布林值

Python bool

str

Python string

ScalarType

torch.dtype

Layout

torch.layout

MemoryFormat

torch.memory_format

裝置

torch.device

以下型別定義為容器型別

型別

定義

Tuple

Python tuple

List

Python list

Dict

鍵為 Scalar 的 Python dict

NamedTuple

Python namedtuple

Dataclass

必須透過 register_dataclass 註冊

Custom class

透過 _register_pytree_node 定義的任何自定義類