評價此頁

torch.export IR 規範#

創建於:2023年10月05日 | 最後更新於:2025年07月16日

Export IR(匯出中間表示)是用於編譯器的中間表示(IR),它與MLIR和TorchScript有相似之處。它專門用於表達PyTorch程式的語義。Export IR主要以精簡的操作列表形式表示計算,並對動態性(如控制流)提供有限的支援。

要建立Export IR圖,可以使用前端,透過一個健全的、專用於追蹤的機制來捕獲PyTorch程式。生成的Export IR隨後可以由後端進行最佳化和執行。今天,可以透過torch.export.export()來完成此操作。

本文件將涵蓋的關鍵概念包括:

  • ExportedProgram:包含Export IR程式的資料結構

  • Graph:由節點列表組成。

  • Nodes:表示在此節點上儲存的操作、控制流和元資料。

  • 值由節點產生和消費。

  • 型別與值和節點相關聯。

  • 還定義了值的尺寸和記憶體佈局。

假設#

本文件假設讀者對PyTorch有充分的瞭解,特別是對torch.fx及其相關工具鏈有深入的瞭解。因此,它將不再描述torch.fx文件和論文中已有的內容。

什麼是Export IR#

Export IR是PyTorch程式的基於圖的中間表示IR。Export IR建立在torch.fx.Graph之上實現的。換句話說,**所有Export IR圖也是有效的FX圖**,如果使用標準的FX語義來解釋,Export IR可以被健全地解釋。一個隱含的含義是,透過標準的FX程式碼生成,匯出的圖可以轉換為有效的Python程式。

本文件將主要關注Export IR在嚴格性方面與FX不同的地方,而跳過與FX相似的部分。

ExportedProgram#

頂級Export IR結構是torch.export.ExportedProgram類。它將PyTorch模型的計算圖(通常是torch.nn.Module)與其消耗的引數或權重捆綁在一起。

torch.export.ExportedProgram類的一些值得注意的屬性包括:

  • graph_moduletorch.fx.GraphModule):包含PyTorch模型扁平化計算圖的資料結構。可以透過ExportedProgram.graph直接訪問該圖。

  • graph_signaturetorch.export.ExportGraphSignature):圖的簽名,它指定了圖中使用的和修改的引數和緩衝區名稱。引數和緩衝區沒有儲存為圖的屬性,而是被提升為圖的輸入。graph_signature用於跟蹤這些引數和緩衝區的額外資訊。

  • state_dictDict[str, Union[torch.Tensor, torch.nn.Parameter]]):包含引數和緩衝區的資料結構。

  • range_constraintsDict[sympy.Symbol, RangeConstraint]):對於具有資料依賴行為匯出的程式,每個節點的元資料將包含符號形狀(看起來像s0, i0)。此屬性將符號形狀對映到它們的下限/上限範圍。

#

Export IR圖是以DAG(有向無環圖)形式表示的PyTorch程式。圖中的每個節點代表一個特定的計算或操作,圖的邊由節點之間的引用組成。

我們可以將圖看作具有以下模式:

class Graph:
  nodes: List[Node]

實際上,Export IR的圖是透過torch.fx.Graph Python類實現的。

Export IR圖包含以下節點(節點將在下一節中更詳細地描述):

  • 0個或多個placeholder型別的節點

  • 0個或多個call_function型別的節點

  • 正好1個output型別的節點

推論: 最小的有效圖將是單個節點。即節點列表從不為空。

定義: 圖的placeholder節點集合代表GraphModule的**輸入**。圖的output節點代表GraphModule的**輸出**。

示例

import torch
from torch import nn

class MyModule(nn.Module):

    def forward(self, x, y):
      return x + y

example_args = (torch.randn(1), torch.randn(1))
mod = torch.export.export(MyModule(), example_args)
print(mod.graph)
graph():
  %x : [num_users=1] = placeholder[target=x]
  %y : [num_users=1] = placeholder[target=y]
  %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %y), kwargs = {})
  return (add,)

以上是圖的文字表示,每一行代表一個節點。

節點#

節點代表一個特定的計算或操作,在Python中使用torch.fx.Node類表示。節點之間的邊透過Node類的args屬性直接引用其他節點來表示。使用相同的FX機制,我們可以表示計算圖通常需要的以下操作,例如操作呼叫、佔位符(又名輸入)、條件和迴圈。

Node具有以下模式:

class Node:
  name: str # name of node
  op_name: str  # type of operation

  # interpretation of the fields below depends on op_name
  target: [str|Callable]
  args: List[object]
  kwargs: Dict[str, object]
  meta: Dict[str, object]

FX 文字格式

如上例所示,請注意,每一行都遵循此格式:

%<name>:[...] = <op_name>[target=<target>](args = (%arg1, %arg2, arg3, arg4, …)), kwargs = {"keyword": arg5})

此格式以緊湊的方式捕獲了Node類中的所有內容,除了meta欄位。

具體來說:

  • 是節點在node.name中出現的名稱。

  • <op_name>node.op欄位,必須是以下之一:<call_function><placeholder><get_attr><output>

  • 是節點的目標,即node.target。此欄位的含義取決於op_name

  • args1, … args 4…node.args元組中列出的內容。如果列表中的值是torch.fx.Node,則會透過前導的%進行特別指示。

例如,對add操作的呼叫將顯示為:

%add1 = call_function[target = torch.op.aten.add.Tensor](args = (%x, %y), kwargs = {})

其中%x%y是另外兩個名為x和y的節點。值得注意的是,字串torch.op.aten.add.Tensor代表儲存在target欄位中的可呼叫物件,而不僅僅是其字串名稱。

此文字格式的最後一行是:

return [add]

這是一個op_name為output的節點,表示我們正在返回這個單一的元素。

call_function#

一個call_function節點表示對一個操作的呼叫。

定義

  • 函式式: 我們說一個可呼叫物件是“函式式”的,如果它滿足以下所有要求:

    • 非修改性:該操作不會修改其輸入的(對於張量,這包括元資料和資料)。

    • 無副作用:該操作不會修改從外部可見的狀態,例如更改模組引數的值。

  • 操作: 是一個具有預定義模式的函式式可呼叫物件。這類操作的例子包括函式式ATen操作。

在FX中的表示

%name = call_function[target = operator](args = (%x, %y, …), kwargs = {})

與普通FX的call_function的區別

  1. 在FX圖中,call_function可以引用任何可呼叫物件,而在Export IR中,我們將其限制為僅限於選定的ATen操作、自定義操作和控制流操作。

  2. 在Export IR中,常量引數將被嵌入到圖中。

  3. 在FX圖中,get_attr節點可以表示讀取圖模組中儲存的任何屬性。然而,在Export IR中,這僅限於讀取子模組,因為所有引數/緩衝區都將作為輸入傳遞給圖模組。

元資料#

Node.meta是一個附加到每個FX節點上的字典。然而,FX規範並未指定可以或將有哪些元資料。Export IR提供了更強的契約,特別是所有call_function節點將保證具有並且僅具有以下元資料欄位:

  • node.meta["stack_trace"]是一個包含Python堆疊跟蹤的字串,它引用了原始的Python原始碼。一個堆疊跟蹤的例子如下:

    File "my_module.py", line 19, in forward
    return x + dummy_helper(y)
    File "helper_utility.py", line 89, in dummy_helper
    return y + 1
    
  • node.meta["val"]描述了執行操作的輸出。它可以是<symint><FakeTensor>List[Union[FakeTensor, SymInt]]None型別的。

  • node.meta["nn_module_stack"]描述了該節點來自的torch.nn.Module的“堆疊跟蹤”,如果它來自torch.nn.Module呼叫。例如,如果一個包含addmm操作的節點是從torch.nn.Linear模組中呼叫,而該模組又在torch.nn.Sequential模組內,那麼nn_module_stack將看起來像這樣:

    {'self_linear': ('self.linear', <class 'torch.nn.Linear'>), 'self_sequential': ('self.sequential', <class 'torch.nn.Sequential'>)}
    
  • node.meta["source_fn_stack"]包含該節點在分解之前被呼叫的torch函式或葉torch.nn.Module類。例如,一個包含addmm操作且來自torch.nn.Linear模組呼叫的節點,將在其source_fn中包含torch.nn.Linear,而一個包含addmm操作且來自torch.nn.functional.Linear模組呼叫的節點,將在其source_fn中包含torch.nn.functional.Linear

placeholder#

Placeholder代表圖的輸入。其語義與FX中的完全相同。Placeholder節點必須是圖中節點列表的最初N個節點。N可以為零。

在FX中的表示

%name = placeholder[target = name](args = ())

target欄位是一個字串,表示輸入的名稱。

args,如果不為空,則應為大小為1,表示此輸入的預設值。

元資料

Placeholder節點也具有meta[‘val’],類似於call_function節點。在這種情況下,val欄位表示圖期望接收的此輸入引數的形狀/資料型別。

output#

輸出呼叫代表函式中的return語句;因此,它會終止當前圖。只有一個輸出節點,並且它將始終是圖的最後一個節點。

在FX中的表示

output[](args = (%something, …))

這與torch.fx中的語義完全相同。args代表要返回的節點。

元資料

輸出節點具有與call_function節點相同的元資料。

get_attr#

get_attr節點表示從包含的torch.fx.GraphModule中讀取子模組。與torch.fx.symbolic_trace()生成的普通FX圖不同,在普通FX圖中get_attr節點用於從頂層torch.fx.GraphModule讀取引數和緩衝區等屬性,而在Export IR中,引數和緩衝區作為輸入傳遞給圖模組,並存儲在頂層的torch.export.ExportedProgram中。

在FX中的表示

%name = get_attr[target = name](args = ())

示例

考慮以下模型:

from functorch.experimental.control_flow import cond

def true_fn(x):
    return x.sin()

def false_fn(x):
    return x.cos()

def f(x, y):
    return cond(y, true_fn, false_fn, [x])

Graph

graph():
    %x_1 : [num_users=1] = placeholder[target=x_1]
    %y_1 : [num_users=1] = placeholder[target=y_1]
    %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
    %false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
    %conditional : [num_users=1] = call_function[target=torch.ops.higher_order.cond](args = (%y_1, %true_graph_0, %false_graph_0, [%x_1]), kwargs = {})
    return conditional

%true_graph_0 : [num_users=1] = get_attr[target=true_graph_0],讀取包含sin操作的子模組true_graph_0

參考文獻#

SymInt#

SymInt 是一個物件,它可以是字面整數,也可以是代表整數的符號(在Python中由sympy.Symbol類表示)。當SymInt是符號時,它描述了一個在編譯時圖中未知型別的整數變數,也就是說,它的值只在執行時才知道。

FakeTensor#

FakeTensor 是一個包含張量元資料的物件。它可以被看作具有以下元資料:

class FakeTensor:
  size: List[SymInt]
  dtype: torch.dtype
  device: torch.device
  dim_order: List[int]  # This doesn't exist yet

FakeTensor的size欄位是整數或SymInts的列表。如果存在SymInts,則表示該張量具有動態形狀。如果存在整數,則假定該張量具有精確的靜態形狀。TensorMeta的秩永遠不會是動態的。dtype欄位表示該節點輸出的資料型別。Edge IR中沒有隱式型別提升。FakeTensor中沒有步幅(strides)。

換句話說:

  • 如果node.target中的操作返回一個Tensor,那麼node.meta['val']是一個描述該張量的FakeTensor。

  • 如果node.target中的操作返回一個Tensor的n元組,那麼node.meta['val']是一個描述每個張量的FakeTensor的n元組。

  • 如果node.target中的操作返回一個在編譯時已知的int/float/scalar,那麼node.meta['val']為None。

  • 如果node.target中的操作返回一個在編譯時未知的int/float/scalar,那麼node.meta['val']是SymInt型別。

例如

  • aten::add返回一個Tensor;因此,其規範將是一個FakeTensor,包含該操作返回的張量的資料型別和尺寸。

  • aten::sym_size返回一個整數;因此,其val將是一個SymInt,因為它的值只在執行時可用。

  • max_pool2d_with_indexes返回一個(Tensor,Tensor)元組;因此,其規範也將是一個FakeTensor物件的2元組,第一個TensorMeta描述返回值的第一項,依此類推。

Python程式碼

def add_one(x):
  return torch.ops.aten(x, 1)

Graph

graph():
  %ph_0 : [#users=1] = placeholder[target=ph_0]
  %add_tensor : [#users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%ph_0, 1), kwargs = {})
  return [add_tensor]

FakeTensor

FakeTensor(dtype=torch.int, size=[2,], device=CPU)

Pytree-able 型別#

我們將一個型別定義為“Pytree-able”,如果它是葉子型別或包含其他Pytree-able型別的容器型別。

注意

Pytree的概念與JAX文件此處中記錄的概念相同。

以下型別被定義為**葉子型別**:

型別

定義

張量

torch.Tensor

Scalar

Python中的任何數值型別,包括整數型別、浮點數型別和零維張量。

int

Python int(在C++中繫結為int64_t)

浮點數

Python float(在C++中繫結為double)

布林值

Python bool

str

Python string

ScalarType

torch.dtype

Layout

torch.layout

MemoryFormat

torch.memory_format

裝置

torch.device

以下型別被定義為**容器型別**:

型別

定義

Tuple

Python tuple

List

Python list

Dict

具有標量鍵的Python dict

NamedTuple

Python namedtuple

Dataclass

必須透過register_dataclass註冊

自定義類

使用_register_pytree_node定義的任何自定義類