
torch.export#

Created On: June 12, 2025 | Last Updated On: August 11, 2025

Overview#

torch.export.export() takes a torch.nn.Module and produces, in an Ahead-of-Time (AOT) fashion, a traced graph representing only the Tensor computation of the function, which can subsequently be executed with different inputs or serialized.

import torch
from torch.export import export, ExportedProgram

class Mod(torch.nn.Module):
    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        a = torch.sin(x)
        b = torch.cos(y)
        return a + b

example_args = (torch.randn(10, 10), torch.randn(10, 10))

exported_program: ExportedProgram = export(Mod(), args=example_args)
print(exported_program)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[10, 10]", y: "f32[10, 10]"):
             # File: /tmp/ipykernel_647/2550508656.py:6 in forward, code: a = torch.sin(x)
            sin: "f32[10, 10]" = torch.ops.aten.sin.default(x);  x = None
            
             # File: /tmp/ipykernel_647/2550508656.py:7 in forward, code: b = torch.cos(y)
            cos: "f32[10, 10]" = torch.ops.aten.cos.default(y);  y = None
            
             # File: /tmp/ipykernel_647/2550508656.py:8 in forward, code: return a + b
            add: "f32[10, 10]" = torch.ops.aten.add.Tensor(sin, cos);  sin = cos = None
            return (add,)
            
Graph signature: 
    # inputs
    x: USER_INPUT
    y: USER_INPUT
    
    # outputs
    add: USER_OUTPUT
    
Range constraints: {}

torch.export produces a clean intermediate representation (IR) with the following invariants. More specifications about the IR can be found here.

  • Soundness: It is guaranteed to be a sound representation of the original program, and maintains the same calling conventions as the original program.

  • Normalized: There are no Python semantics within the graph. Submodules from the original program are inlined to form one fully flattened computational graph.

  • Graph properties: The graph is purely functional, meaning it contains no operations with side effects such as mutations or aliasing. It does not mutate any intermediate values, parameters, or buffers.

  • Metadata: The graph contains metadata captured during tracing, such as a stack trace from the user's code.

Under the hood, torch.export leverages the following latest technologies:

  • TorchDynamo (torch._dynamo) is an internal API that uses a CPython feature called the Frame Evaluation API to safely trace PyTorch graphs. This provides a massively improved graph capturing experience, with many fewer rewrites needed in order to fully trace PyTorch code.

  • AOT Autograd provides a functionalized PyTorch graph and ensures the graph is decomposed/lowered to the ATen operator set.

  • Torch FX (torch.fx) is the underlying representation of the graph, allowing flexible Python-based transformations.

Existing frameworks#

torch.compile() also utilizes the same PT2 stack as torch.export, but is slightly different:

  • JIT vs. AOT: torch.compile() is a JIT compiler, whereas torch.export is an AOT compiler that is not intended to be used to produce compiled artifacts outside of deployment.

  • Partial vs. full graph capture: When torch.compile() runs into an untraceable part of a model, it will "graph break" and fall back to running the program in the eager Python runtime. In comparison, torch.export aims to get a full graph representation of a PyTorch model, so it will error out when something untraceable is reached. Since torch.export produces a full graph disjoint from any Python features or runtime, this graph can then be saved, loaded, and run in different environments and languages.

  • Usability tradeoff: Since torch.compile() is able to fall back to the Python runtime whenever it reaches something untraceable, it is much more flexible, whereas torch.export requires users to provide more information or rewrite their code to make it traceable.

Compared to torch.fx.symbolic_trace(), torch.export traces using TorchDynamo, which operates at the Python bytecode level, giving it the ability to trace arbitrary Python constructs, not limited by what Python operator overloading supports. Additionally, torch.export keeps fine-grained track of Tensor metadata, so that conditionals on things like Tensor shapes do not fail tracing. In general, torch.export is expected to work on more user programs and produce lower-level graphs (at the torch.ops.aten operator level). Note that users can still use torch.fx.symbolic_trace() as a preprocessing step before torch.export.

Compared to torch.jit.script(), torch.export does not capture Python control flow or data structures unless explicit control flow operators are used, but it supports more Python language features thanks to its comprehensive coverage of Python bytecode. The resulting graphs are simpler and only have straight-line control flow, except for explicit control flow operators.

Compared to torch.jit.trace(), torch.export is sound: it can trace code that performs integer computation on Tensor sizes, and records all of the side conditions necessary to ensure that a particular trace is valid for other inputs.

Exporting a PyTorch Model#

The main entry point is through torch.export.export(), which takes a torch.nn.Module and sample inputs, and captures the computation graph into a torch.export.ExportedProgram. An example:

import torch
from torch.export import export, ExportedProgram

# Simple module for demonstration
class M(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = torch.nn.Conv2d(
            in_channels=3, out_channels=16, kernel_size=3, padding=1
        )
        self.relu = torch.nn.ReLU()
        self.maxpool = torch.nn.MaxPool2d(kernel_size=3)

    def forward(self, x: torch.Tensor, *, constant=None) -> torch.Tensor:
        a = self.conv(x)
        a.add_(constant)
        return self.maxpool(self.relu(a))

example_args = (torch.randn(1, 3, 256, 256),)
example_kwargs = {"constant": torch.ones(1, 16, 256, 256)}

exported_program: ExportedProgram = export(
    M(), args=example_args, kwargs=example_kwargs
)
print(exported_program)

# To run the exported program, we can use the `module()` method
print(exported_program.module()(torch.randn(1, 3, 256, 256), constant=torch.ones(1, 16, 256, 256)))
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, p_conv_weight: "f32[16, 3, 3, 3]", p_conv_bias: "f32[16]", x: "f32[1, 3, 256, 256]", constant: "f32[1, 16, 256, 256]"):
             # File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/conv.py:548 in forward, code: return self._conv_forward(input, self.weight, self.bias)
            conv2d: "f32[1, 16, 256, 256]" = torch.ops.aten.conv2d.default(x, p_conv_weight, p_conv_bias, [1, 1], [1, 1]);  x = p_conv_weight = p_conv_bias = None
            
             # File: /tmp/ipykernel_647/2848084713.py:16 in forward, code: a.add_(constant)
            add_: "f32[1, 16, 256, 256]" = torch.ops.aten.add_.Tensor(conv2d, constant);  conv2d = constant = None
            
             # File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/activation.py:144 in forward, code: return F.relu(input, inplace=self.inplace)
            relu: "f32[1, 16, 256, 256]" = torch.ops.aten.relu.default(add_);  add_ = None
            
             # File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/pooling.py:226 in forward, code: return F.max_pool2d(
            max_pool2d: "f32[1, 16, 85, 85]" = torch.ops.aten.max_pool2d.default(relu, [3, 3], [3, 3]);  relu = None
            return (max_pool2d,)
            
Graph signature: 
    # inputs
    p_conv_weight: PARAMETER target='conv.weight'
    p_conv_bias: PARAMETER target='conv.bias'
    x: USER_INPUT
    constant: USER_INPUT
    
    # outputs
    max_pool2d: USER_OUTPUT
    
Range constraints: {}

tensor([[[[1.3647, 2.2724, 1.9313,  ..., 2.2872, 2.0532, 1.7862],
          [2.3376, 1.7866, 2.0821,  ..., 1.8447, 2.5090, 1.4548],
          [2.1719, 2.0520, 1.6153,  ..., 2.1426, 1.4364, 1.7749],
          ...,
          [2.1165, 2.0080, 1.6747,  ..., 1.8013, 2.3714, 2.1069],
          [1.9574, 1.9642, 1.6017,  ..., 2.2444, 1.7903, 2.1224],
          [1.6735, 1.4687, 1.4179,  ..., 2.1677, 1.4616, 1.9001]],

         [[1.7637, 2.2845, 1.6327,  ..., 1.8567, 1.6745, 1.3376],
          [1.8699, 1.1547, 2.5045,  ..., 1.1387, 2.0210, 1.3825],
          [1.7050, 1.3393, 1.9955,  ..., 1.4296, 1.8792, 1.7073],
          ...,
          [1.5292, 1.9183, 1.8844,  ..., 2.0815, 1.8089, 2.5264],
          [2.0338, 2.3296, 2.1650,  ..., 2.0727, 2.0166, 1.7309],
          [1.8381, 1.6556, 1.9402,  ..., 1.6529, 1.8134, 1.6075]],

         [[1.6673, 2.1074, 1.5976,  ..., 1.7421, 1.7998, 1.6087],
          [2.0269, 1.3379, 1.6679,  ..., 1.6671, 1.6000, 1.9894],
          [2.0480, 1.5340, 1.4017,  ..., 1.7944, 0.9860, 2.3785],
          ...,
          [1.5266, 1.3949, 1.2980,  ..., 1.3569, 1.9492, 1.8062],
          [1.8315, 1.2293, 1.1087,  ..., 1.5446, 1.6492, 1.6620],
          [1.4799, 1.3720, 1.5748,  ..., 1.8854, 1.2940, 1.7422]],

         ...,

         [[1.5734, 2.2576, 1.6242,  ..., 2.2690, 1.5416, 1.6914],
          [1.6577, 1.1605, 1.3565,  ..., 0.8677, 1.1838, 1.7662],
          [2.0769, 1.6546, 1.6169,  ..., 2.1301, 1.3892, 1.7564],
          ...,
          [1.4403, 2.0147, 2.0693,  ..., 1.8359, 1.3394, 1.6654],
          [1.6513, 2.2535, 2.0069,  ..., 1.0825, 1.6039, 1.3635],
          [1.3937, 1.0050, 2.8575,  ..., 1.9308, 1.4201, 1.4665]],

         [[1.8365, 1.3995, 1.8011,  ..., 1.7474, 2.0621, 1.7876],
          [1.7984, 2.0523, 1.4683,  ..., 1.7031, 2.4383, 1.2690],
          [1.6586, 1.4678, 1.7569,  ..., 1.4851, 1.5530, 2.1754],
          ...,
          [1.3299, 1.2796, 1.9345,  ..., 1.8214, 2.7972, 2.0472],
          [1.3732, 1.3926, 1.5657,  ..., 1.4157, 2.2771, 1.9893],
          [1.4669, 1.7627, 1.6258,  ..., 1.5350, 2.0609, 2.1192]],

         [[1.8094, 2.4813, 1.9160,  ..., 2.0701, 1.2966, 1.0297],
          [2.2580, 1.9149, 1.8033,  ..., 1.4026, 1.9687, 1.5859],
          [1.3034, 1.1883, 1.5391,  ..., 1.5249, 1.5463, 2.0138],
          ...,
          [2.2308, 1.7266, 1.8078,  ..., 1.7636, 2.2842, 1.5676],
          [1.7690, 1.3066, 1.8903,  ..., 1.8704, 1.5074, 1.5758],
          [2.3337, 1.5704, 1.8364,  ..., 1.4995, 1.8263, 1.8836]]]],
       grad_fn=<MaxPool2DWithIndicesBackward0>)

Inspecting the ExportedProgram, we can note the following:

  • The torch.fx.Graph contains the computation graph of the original program, along with records of the original code for easy debugging.

  • The graph contains only torch.ops.aten operators, which can be found here, as well as custom operators.

  • The parameters (weight and bias of conv) are lifted as inputs of the graph, so there are no `get_attr` nodes in the graph, which previously existed in the result of torch.fx.symbolic_trace().

  • The torch.export.ExportGraphSignature models the input and output signature, along with specifying which inputs are parameters.

  • The resulting shape and dtype of the Tensor produced by each node in the graph are recorded. For example, the `conv2d` node will produce a Tensor of dtype `torch.float32` and shape (1, 16, 256, 256). This per-node metadata can be inspected directly, as shown in the sketch below.
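
Here is a minimal sketch of reading that metadata, assuming the `exported_program` from the example above is in scope; each node's `meta["val"]` holds a FakeTensor describing the shape and dtype the node produces:

for node in exported_program.graph.nodes:
    if node.op == "call_function":
        fake_val = node.meta["val"]  # FakeTensor carrying shape/dtype metadata
        print(node.name, fake_val.shape, fake_val.dtype)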

Expressing Dynamism#

By default, torch.export will trace the program assuming all input shapes are **static**, and specialize the exported program to those dimensions. One consequence of this is that at runtime, the program will not work on inputs with different shapes, even if they are valid in eager mode.

An example:

import torch
import traceback as tb

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()

        self.branch1 = torch.nn.Sequential(
            torch.nn.Linear(64, 32), torch.nn.ReLU()
        )
        self.branch2 = torch.nn.Sequential(
            torch.nn.Linear(128, 64), torch.nn.ReLU()
        )
        self.buffer = torch.ones(32)

    def forward(self, x1, x2):
        out1 = self.branch1(x1)
        out2 = self.branch2(x2)
        return (out1 + self.buffer, out2)

example_args = (torch.randn(32, 64), torch.randn(32, 128))

ep = torch.export.export(M(), example_args)
print(ep)

example_args2 = (torch.randn(64, 64), torch.randn(64, 128))
try:
    ep.module()(*example_args2)  # fails
except Exception:
    tb.print_exc()
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, p_branch1_0_weight: "f32[32, 64]", p_branch1_0_bias: "f32[32]", p_branch2_0_weight: "f32[64, 128]", p_branch2_0_bias: "f32[64]", c_buffer: "f32[32]", x1: "f32[32, 64]", x2: "f32[32, 128]"):
             # File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
            linear: "f32[32, 32]" = torch.ops.aten.linear.default(x1, p_branch1_0_weight, p_branch1_0_bias);  x1 = p_branch1_0_weight = p_branch1_0_bias = None
            
             # File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/activation.py:144 in forward, code: return F.relu(input, inplace=self.inplace)
            relu: "f32[32, 32]" = torch.ops.aten.relu.default(linear);  linear = None
            
             # File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_1: "f32[32, 64]" = torch.ops.aten.linear.default(x2, p_branch2_0_weight, p_branch2_0_bias);  x2 = p_branch2_0_weight = p_branch2_0_bias = None
            
             # File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/activation.py:144 in forward, code: return F.relu(input, inplace=self.inplace)
            relu_1: "f32[32, 64]" = torch.ops.aten.relu.default(linear_1);  linear_1 = None
            
             # File: /tmp/ipykernel_647/1522925308.py:19 in forward, code: return (out1 + self.buffer, out2)
            add: "f32[32, 32]" = torch.ops.aten.add.Tensor(relu, c_buffer);  relu = c_buffer = None
            return (add, relu_1)
            
Graph signature: 
    # inputs
    p_branch1_0_weight: PARAMETER target='branch1.0.weight'
    p_branch1_0_bias: PARAMETER target='branch1.0.bias'
    p_branch2_0_weight: PARAMETER target='branch2.0.weight'
    p_branch2_0_bias: PARAMETER target='branch2.0.bias'
    c_buffer: CONSTANT_TENSOR target='buffer'
    x1: USER_INPUT
    x2: USER_INPUT
    
    # outputs
    add: USER_OUTPUT
    relu_1: USER_OUTPUT
    
Range constraints: {}
Traceback (most recent call last):
  File "/tmp/ipykernel_647/1522925308.py", line 28, in <module>
    ep.module()(*example_args2)  # fails
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/graph_module.py", line 837, in call_wrapped
    return self._wrapped_call(self, *args, **kwargs)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/graph_module.py", line 413, in __call__
    raise e
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/graph_module.py", line 400, in __call__
    return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1881, in _call_impl
    return inner()
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1829, in inner
    result = forward_call(*args, **kwargs)
  File "<eval_with_key>.25", line 11, in forward
    _guards_fn = self._guards_fn(x1, x2);  _guards_fn = None
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/external_utils.py", line 209, in inner
    return func(*args, **kwargs)
  File "<string>", line 3, in _
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/__init__.py", line 2185, in _assert
    assert condition, message
AssertionError: Guard failed: x1.size()[0] == 32

However, some dimensions, such as a batch dimension, can be dynamic and vary from run to run. Such dimensions must be specified by using the torch.export.Dim() API to create them, and by passing them into torch.export.export() through the dynamic_shapes argument:

import torch

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()

        self.branch1 = torch.nn.Sequential(
            torch.nn.Linear(64, 32), torch.nn.ReLU()
        )
        self.branch2 = torch.nn.Sequential(
            torch.nn.Linear(128, 64), torch.nn.ReLU()
        )
        self.buffer = torch.ones(32)

    def forward(self, x1, x2):
        out1 = self.branch1(x1)
        out2 = self.branch2(x2)
        return (out1 + self.buffer, out2)

example_args = (torch.randn(32, 64), torch.randn(32, 128))

# Create a dynamic batch size
batch = torch.export.Dim("batch")
# Specify that the first dimension of each input is that batch size
dynamic_shapes = {"x1": {0: batch}, "x2": {0: batch}}

ep = torch.export.export(
    M(), args=example_args, dynamic_shapes=dynamic_shapes
)
print(ep)

example_args2 = (torch.randn(64, 64), torch.randn(64, 128))
ep.module()(*example_args2)  # success
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, p_branch1_0_weight: "f32[32, 64]", p_branch1_0_bias: "f32[32]", p_branch2_0_weight: "f32[64, 128]", p_branch2_0_bias: "f32[64]", c_buffer: "f32[32]", x1: "f32[s24, 64]", x2: "f32[s24, 128]"):
             # File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
            linear: "f32[s24, 32]" = torch.ops.aten.linear.default(x1, p_branch1_0_weight, p_branch1_0_bias);  x1 = p_branch1_0_weight = p_branch1_0_bias = None
            
             # File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/activation.py:144 in forward, code: return F.relu(input, inplace=self.inplace)
            relu: "f32[s24, 32]" = torch.ops.aten.relu.default(linear);  linear = None
            
             # File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_1: "f32[s24, 64]" = torch.ops.aten.linear.default(x2, p_branch2_0_weight, p_branch2_0_bias);  x2 = p_branch2_0_weight = p_branch2_0_bias = None
            
             # File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/activation.py:144 in forward, code: return F.relu(input, inplace=self.inplace)
            relu_1: "f32[s24, 64]" = torch.ops.aten.relu.default(linear_1);  linear_1 = None
            
             # File: /tmp/ipykernel_647/3456136871.py:18 in forward, code: return (out1 + self.buffer, out2)
            add: "f32[s24, 32]" = torch.ops.aten.add.Tensor(relu, c_buffer);  relu = c_buffer = None
            return (add, relu_1)
            
Graph signature: 
    # inputs
    p_branch1_0_weight: PARAMETER target='branch1.0.weight'
    p_branch1_0_bias: PARAMETER target='branch1.0.bias'
    p_branch2_0_weight: PARAMETER target='branch2.0.weight'
    p_branch2_0_bias: PARAMETER target='branch2.0.bias'
    c_buffer: CONSTANT_TENSOR target='buffer'
    x1: USER_INPUT
    x2: USER_INPUT
    
    # outputs
    add: USER_OUTPUT
    relu_1: USER_OUTPUT
    
Range constraints: {s24: VR[0, int_oo]}
(tensor([[1.0000, 1.8873, 1.2537,  ..., 1.8441, 1.5857, 1.2611],
         [1.0000, 1.0000, 1.0598,  ..., 1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 2.3600,  ..., 1.8214, 1.0710, 1.0000],
         ...,
         [1.0000, 1.0516, 1.1907,  ..., 1.4364, 1.0000, 1.0000],
         [1.0000, 1.1808, 1.0000,  ..., 1.0000, 1.0000, 1.3670],
         [1.5627, 2.3273, 1.0000,  ..., 1.0000, 1.0000, 1.6566]],
        grad_fn=<AddBackward0>),
 tensor([[0.0000, 0.0000, 0.0000,  ..., 0.1734, 0.0999, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.4386, 0.7500],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         ...,
         [1.1618, 0.0000, 0.0000,  ..., 0.0000, 0.1554, 0.0000],
         [0.0000, 0.4619, 0.0904,  ..., 0.0000, 0.0792, 0.5539],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]],
        grad_fn=<ReluBackward0>))

Some additional things to note:

  • Through the torch.export.Dim() API and the dynamic_shapes argument, we specified the first dimension of each input to be dynamic. Looking at the inputs `x1` and `x2`, they have symbolic shapes (s24, 64) and (s24, 128), instead of the (32, 64) and (32, 128) shaped Tensors that we passed in as example inputs. `s24` is a symbol representing that this dimension can take a range of values.

  • exported_program.range_constraints describes the ranges of each symbol appearing in the graph. In this case, we see that `s24` has the range [0, int_oo]. For technical reasons that are difficult to explain here, these symbols are assumed to be neither 0 nor 1. This is not a bug, and does not necessarily mean that the exported program will not work for dimensions 0 or 1. See The 0/1 Specialization Problem for an in-depth discussion of this topic.

In this example, we used Dim("batch") to create a dynamic dimension. This is the most explicit way to specify dynamism. We can also use Dim.DYNAMIC and Dim.AUTO to specify dynamism; we will cover both of these methods in the sections below.

Named Dims#

For every dimension specified with Dim("name"), we will allocate a symbolic shape. Specifying a Dim with the same name will result in the same symbol being generated. This allows users to specify which symbols are allocated for the dimensions of each input.

batch = Dim("batch")
dynamic_shapes = {"x1": {0: dim}, "x2": {0: batch}}

For each Dim, we can specify minimum and maximum values. We also allow relations between Dims to be specified as univariate linear expressions: A * dim + B. This allows users to specify more complex constraints, such as integer divisibility, for dynamic dimensions. These features allow users to place explicit restrictions on the dynamic behavior of the ExportedProgram that is produced.

dx = Dim("dx", min=4, max=256)
dh = Dim("dh", max=512)
dynamic_shapes = {
    "x": (dx, None),
    "y": (2 * dx, dh),
}

However, ConstraintViolationErrors will be raised if guards emitted while tracing conflict with the given relations or static/dynamic specifications. For example, the above specification asserts the following:

  • x.shape[0] has range [4, 256], and is related to y.shape[0] by y.shape[0] == 2 * x.shape[0].

  • x.shape[1] is static.

  • y.shape[1] has range [0, 512], and is unrelated to any other dimension.

If any of these assertions are found to be incorrect while tracing (for example, `x.shape[0]` is static, or `y.shape[1]` has a smaller range, or `y.shape[0] != 2 * x.shape[0]`), then a ConstraintViolationError will be raised, and the user will need to change their dynamic_shapes specification. A minimal sketch of such a conflict follows below.
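
As a sketch of this behavior (the module `N` below is made up for illustration): the Linear layer forces x.shape[1] to be exactly 64, so claiming that dimension is dynamic conflicts with the guard emitted during tracing, and export fails with a constraint violation asking us to fix the specification:

import torch
from torch.export import Dim, export

class N(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(64, 32)

    def forward(self, x):
        return self.linear(x)

d = Dim("d")
try:
    # dim 1 is claimed to be dynamic, but tracing specializes it to 64
    export(N(), (torch.randn(8, 64),), dynamic_shapes={"x": {1: d}})
except Exception as e:
    print(type(e).__name__, e)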

Dim Hints#

Instead of explicitly specifying dynamism with Dim("name"), we can let torch.export infer the ranges and relationships of dynamic values by using Dim.DYNAMIC. This is also a more convenient way to specify dynamism when you don't know exactly how dynamic your dynamic values are.

dynamic_shapes = {
    "x": (Dim.DYNAMIC, None),
    "y": (Dim.DYNAMIC, Dim.DYNAMIC),
}

We can also specify min/max values for Dim.DYNAMIC, which will serve as hints to export. But if export finds a different range while tracing, it will automatically update the range without raising an error. We also cannot specify relationships between dynamic values; instead, these are inferred by export and exposed to users through inspection of the assertions within the graph. With this method of specifying dynamism, ConstraintViolationErrors are only raised if the specified value is inferred to be **static**.

An even more convenient way to specify dynamism is to use Dim.AUTO, which behaves like Dim.DYNAMIC but does not raise an error if the dimension is inferred to be static. This is useful when you have no idea how dynamic your values are, and want to export the program with a "best effort" dynamic approach. A short sketch follows below.
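
A minimal sketch, reusing the module `M` and `example_args` from the dynamic-shapes example above; each first dimension is marked with Dim.AUTO, so export infers it to be dynamic where tracing allows, and silently specializes it to static otherwise:

from torch.export import Dim

dynamic_shapes = {
    "x1": {0: Dim.AUTO},
    "x2": {0: Dim.AUTO},
}
ep = torch.export.export(M(), example_args, dynamic_shapes=dynamic_shapes)
print(ep.range_constraints)  # inspect the ranges inferred for the batch symbol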

ShapesCollection#

When specifying which inputs are dynamic via dynamic_shapes, we must specify the dynamism of every input. For example, given the following inputs:

args = {"x": tensor_x, "others": [tensor_y, tensor_z]}

we would need to specify the dynamism of `tensor_x`, `tensor_y`, and `tensor_z` along with the dynamic shapes:

# With named-Dims
dim = torch.export.Dim(...)
dynamic_shapes = {"x": {0: dim, 1: dim + 1}, "others": [{0: dim * 2}, None]}

torch.export(..., args, dynamic_shapes=dynamic_shapes)

However, this is particularly complicated, as we need to express the `dynamic_shapes` specification in the same nested input structure as the input arguments. Instead, an easier way to specify dynamic shapes is with the helper utility torch.export.ShapesCollection, where instead of specifying the dynamism of every single input, we can just assign dynamism directly to the input dimensions that are dynamic.

import torch

class M(torch.nn.Module):
    def forward(self, inp):
        x = inp["x"] * 1
        y = inp["others"][0] * 2
        z = inp["others"][1] * 3
        return x, y, z

tensor_x = torch.randn(3, 4, 8)
tensor_y = torch.randn(6)
tensor_z = torch.randn(6)
args = {"x": tensor_x, "others": [tensor_y, tensor_z]}

dim = torch.export.Dim("dim")
sc = torch.export.ShapesCollection()
sc[tensor_x] = (dim, dim + 1, 8)
sc[tensor_y] = {0: dim * 2}

print(sc.dynamic_shapes(M(), (args,)))
ep = torch.export.export(M(), (args,), dynamic_shapes=sc)
print(ep)
{'inp': {'x': (Dim('dim', min=0), dim + 1, 8), 'others': [{0: 2*dim}, None]}}
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, inp_x: "f32[s96, s96 + 1, 8]", inp_others_0: "f32[2*s96]", inp_others_1: "f32[6]"):
             # File: /tmp/ipykernel_647/1070110726.py:5 in forward, code: x = inp["x"] * 1
            mul: "f32[s96, s96 + 1, 8]" = torch.ops.aten.mul.Tensor(inp_x, 1);  inp_x = None
            
             # File: /tmp/ipykernel_647/1070110726.py:6 in forward, code: y = inp["others"][0] * 2
            mul_1: "f32[2*s96]" = torch.ops.aten.mul.Tensor(inp_others_0, 2);  inp_others_0 = None
            
             # File: /tmp/ipykernel_647/1070110726.py:7 in forward, code: z = inp["others"][1] * 3
            mul_2: "f32[6]" = torch.ops.aten.mul.Tensor(inp_others_1, 3);  inp_others_1 = None
            return (mul, mul_1, mul_2)
            
Graph signature: 
    # inputs
    inp_x: USER_INPUT
    inp_others_0: USER_INPUT
    inp_others_1: USER_INPUT
    
    # outputs
    mul: USER_OUTPUT
    mul_1: USER_OUTPUT
    mul_2: USER_OUTPUT
    
Range constraints: {s96: VR[0, int_oo], s96 + 1: VR[1, int_oo], 2*s96: VR[0, int_oo]}

AdditionalInputs#

In the case where you don't know how dynamic your inputs are, but you have an ample set of testing or profiling data that gives a fair sense of representative inputs for the model, you can use torch.export.AdditionalInputs in place of dynamic_shapes. You specify all the possible inputs that could be used to trace the program, and AdditionalInputs infers which inputs are dynamic based on which input shapes change.

Example:

import dataclasses
import torch
import torch.utils._pytree as pytree

@dataclasses.dataclass
class D:
    b: bool
    i: int
    f: float
    t: torch.Tensor

pytree.register_dataclass(D)

class M(torch.nn.Module):
    def forward(self, d: D):
        return d.i + d.f + d.t

input1 = (D(True, 3, 3.0, torch.ones(3)),)
input2 = (D(True, 4, 3.0, torch.ones(4)),)
ai = torch.export.AdditionalInputs()
ai.add(input1)
ai.add(input2)

print(ai.dynamic_shapes(M(), input1))
ep = torch.export.export(M(), input1, dynamic_shapes=ai)
print(ep)
{'d': [None, _DimHint(type=<_DimHintType.DYNAMIC: 3>, min=None, max=None, _factory=True), None, (_DimHint(type=<_DimHintType.DYNAMIC: 3>, min=None, max=None, _factory=True),)]}
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, d_b, d_i: "Sym(s37)", d_f, d_t: "f32[s99]"):
             # File: /tmp/ipykernel_647/829931439.py:16 in forward, code: return d.i + d.f + d.t
            sym_float: "Sym(ToFloat(s37))" = torch.sym_float(d_i);  d_i = None
            add: "Sym(ToFloat(s37) + 3.0)" = sym_float + 3.0;  sym_float = None
            add_1: "f32[s99]" = torch.ops.aten.add.Tensor(d_t, add);  d_t = add = None
            return (add_1,)
            
Graph signature: 
    # inputs
    d_b: USER_INPUT
    d_i: USER_INPUT
    d_f: USER_INPUT
    d_t: USER_INPUT
    
    # outputs
    add_1: USER_OUTPUT
    
Range constraints: {s37: VR[0, int_oo], s99: VR[2, int_oo]}

Serialization#

To save the ExportedProgram, users can use the torch.export.save() and torch.export.load() APIs. The resulting file is a zip file with a specific structure; the details of the structure are defined in the PT2 Archive Spec.

An example:

import torch

class MyModule(torch.nn.Module):
    def forward(self, x):
        return x + 10

exported_program = torch.export.export(MyModule(), (torch.randn(5),))

torch.export.save(exported_program, 'exported_program.pt2')
saved_exported_program = torch.export.load('exported_program.pt2')

Export IR, Decompositions#

The graph produced by torch.export contains only ATen operators, which are the basic units of computation in PyTorch. As there are over 3000 ATen operators, export provides a way to narrow down the operator set used in the graph based on certain characteristics, creating different IRs.

By default, export produces the most generic IR, which contains all ATen operators, including both functional and non-functional ones. A functional operator is one that does not mutate or alias its inputs. You can find a list of all ATen operators here, and you can check whether an operator is functional by inspecting `op._schema.is_mutable`, as in the minimal check below.
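
For instance, the out-of-place `add` reports a non-mutable schema (functional), while the in-place `add_` does not:

import torch

print(torch.ops.aten.add.Tensor._schema.is_mutable)   # False -> functional
print(torch.ops.aten.add_.Tensor._schema.is_mutable)  # True  -> mutates its input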

This generic IR can be used to train in eager PyTorch Autograd.

import torch

class M(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = torch.nn.Conv2d(1, 3, 1, 1)
        self.bn = torch.nn.BatchNorm2d(3)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return (x,)

ep_for_training = torch.export.export(M(), (torch.randn(1, 1, 3, 3),))
print(ep_for_training.graph_module.print_readable(print_output=False))
class GraphModule(torch.nn.Module):
    def forward(self, p_conv_weight: "f32[3, 1, 1, 1]", p_conv_bias: "f32[3]", p_bn_weight: "f32[3]", p_bn_bias: "f32[3]", b_bn_running_mean: "f32[3]", b_bn_running_var: "f32[3]", b_bn_num_batches_tracked: "i64[]", x: "f32[1, 1, 3, 3]"):
         # File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/conv.py:548 in forward, code: return self._conv_forward(input, self.weight, self.bias)
        conv2d: "f32[1, 3, 3, 3]" = torch.ops.aten.conv2d.default(x, p_conv_weight, p_conv_bias);  x = p_conv_weight = p_conv_bias = None
        
         # File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/batchnorm.py:173 in forward, code: self.num_batches_tracked.add_(1)  # type: ignore[has-type]
        add_: "i64[]" = torch.ops.aten.add_.Tensor(b_bn_num_batches_tracked, 1);  b_bn_num_batches_tracked = add_ = None
        
         # File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/batchnorm.py:193 in forward, code: return F.batch_norm(
        batch_norm: "f32[1, 3, 3, 3]" = torch.ops.aten.batch_norm.default(conv2d, p_bn_weight, p_bn_bias, b_bn_running_mean, b_bn_running_var, True, 0.1, 1e-05, True);  conv2d = p_bn_weight = p_bn_bias = b_bn_running_mean = b_bn_running_var = None
        return (batch_norm,)
        

However, if you want to use the IR for inference, or decrease the number of operators being used, you can lower the graph through the `ExportedProgram.run_decompositions()` API. This method decomposes the ATen operators into the ones specified in the decomposition table, and functionalizes the graph.

By specifying an empty decomposition table, we only perform functionalization, without any additional decompositions. This results in an IR which contains ~2000 operators (instead of the 3000 operators above), and is ideal for inference cases.

import torch

class M(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = torch.nn.Conv2d(1, 3, 1, 1)
        self.bn = torch.nn.BatchNorm2d(3)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return (x,)

ep_for_training = torch.export.export(M(), (torch.randn(1, 1, 3, 3),))
with torch.no_grad():
    ep_for_inference = ep_for_training.run_decompositions(decomp_table={})
print(ep_for_inference.graph_module.print_readable(print_output=False))
class GraphModule(torch.nn.Module):
    def forward(self, p_conv_weight: "f32[3, 1, 1, 1]", p_conv_bias: "f32[3]", p_bn_weight: "f32[3]", p_bn_bias: "f32[3]", b_bn_running_mean: "f32[3]", b_bn_running_var: "f32[3]", b_bn_num_batches_tracked: "i64[]", x: "f32[1, 1, 3, 3]"):
         # File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/conv.py:548 in forward, code: return self._conv_forward(input, self.weight, self.bias)
        conv2d: "f32[1, 3, 3, 3]" = torch.ops.aten.conv2d.default(x, p_conv_weight, p_conv_bias);  x = p_conv_weight = p_conv_bias = None
        
         # File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/batchnorm.py:173 in forward, code: self.num_batches_tracked.add_(1)  # type: ignore[has-type]
        add: "i64[]" = torch.ops.aten.add.Tensor(b_bn_num_batches_tracked, 1);  b_bn_num_batches_tracked = None
        
         # File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/batchnorm.py:193 in forward, code: return F.batch_norm(
        _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(conv2d, p_bn_weight, p_bn_bias, b_bn_running_mean, b_bn_running_var, True, 0.1, 1e-05);  conv2d = p_bn_weight = p_bn_bias = b_bn_running_mean = b_bn_running_var = None
        getitem: "f32[1, 3, 3, 3]" = _native_batch_norm_legit_functional[0]
        getitem_3: "f32[3]" = _native_batch_norm_legit_functional[3]
        getitem_4: "f32[3]" = _native_batch_norm_legit_functional[4];  _native_batch_norm_legit_functional = None
        return (getitem_3, getitem_4, add, getitem)
        

As we can see, the previously in-place operator `torch.ops.aten.add_.default` has now been replaced with `torch.ops.aten.add.default`, a functional operator.

We can also further lower this exported program to an operator set which only contains the Core ATen Operator Set, a collection of only ~180 operators. This IR is optimal for backends who do not want to reimplement all ATen operators.

import torch

class M(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = torch.nn.Conv2d(1, 3, 1, 1)
        self.bn = torch.nn.BatchNorm2d(3)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return (x,)

ep_for_training = torch.export.export(M(), (torch.randn(1, 1, 3, 3),))
with torch.no_grad():
    core_aten_ir = ep_for_training.run_decompositions(decomp_table=None)
print(core_aten_ir.graph_module.print_readable(print_output=False))
class GraphModule(torch.nn.Module):
    def forward(self, p_conv_weight: "f32[3, 1, 1, 1]", p_conv_bias: "f32[3]", p_bn_weight: "f32[3]", p_bn_bias: "f32[3]", b_bn_running_mean: "f32[3]", b_bn_running_var: "f32[3]", b_bn_num_batches_tracked: "i64[]", x: "f32[1, 1, 3, 3]"):
         # File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/conv.py:548 in forward, code: return self._conv_forward(input, self.weight, self.bias)
        convolution: "f32[1, 3, 3, 3]" = torch.ops.aten.convolution.default(x, p_conv_weight, p_conv_bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1);  x = p_conv_weight = p_conv_bias = None
        
         # File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/batchnorm.py:173 in forward, code: self.num_batches_tracked.add_(1)  # type: ignore[has-type]
        add: "i64[]" = torch.ops.aten.add.Tensor(b_bn_num_batches_tracked, 1);  b_bn_num_batches_tracked = None
        
         # File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/batchnorm.py:193 in forward, code: return F.batch_norm(
        _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(convolution, p_bn_weight, p_bn_bias, b_bn_running_mean, b_bn_running_var, True, 0.1, 1e-05);  convolution = p_bn_weight = p_bn_bias = b_bn_running_mean = b_bn_running_var = None
        getitem: "f32[1, 3, 3, 3]" = _native_batch_norm_legit_functional[0]
        getitem_3: "f32[3]" = _native_batch_norm_legit_functional[3]
        getitem_4: "f32[3]" = _native_batch_norm_legit_functional[4];  _native_batch_norm_legit_functional = None
        return (getitem_3, getitem_4, add, getitem)
        

We now see that `torch.ops.aten.conv2d.default` has been decomposed into `torch.ops.aten.convolution.default`. This is because `convolution` is a more "core" operator, as operations like `conv1d` and `conv2d` can be implemented using the same op.

We can also specify our own decomposition behavior:

class M(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = torch.nn.Conv2d(1, 3, 1, 1)
        self.bn = torch.nn.BatchNorm2d(3)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return (x,)

ep_for_training = torch.export.export(M(), (torch.randn(1, 1, 3, 3),))

my_decomp_table = torch.export.default_decompositions()

def my_awesome_custom_conv2d_function(x, weight, bias, stride=[1, 1], padding=[0, 0], dilation=[1, 1], groups=1):
    return 2 * torch.ops.aten.convolution(x, weight, bias, stride, padding, dilation, False, [0, 0], groups)

my_decomp_table[torch.ops.aten.conv2d.default] = my_awesome_custom_conv2d_function
my_ep = ep_for_training.run_decompositions(my_decomp_table)
print(my_ep.graph_module.print_readable(print_output=False))
class GraphModule(torch.nn.Module):
    def forward(self, p_conv_weight: "f32[3, 1, 1, 1]", p_conv_bias: "f32[3]", p_bn_weight: "f32[3]", p_bn_bias: "f32[3]", b_bn_running_mean: "f32[3]", b_bn_running_var: "f32[3]", b_bn_num_batches_tracked: "i64[]", x: "f32[1, 1, 3, 3]"):
         # File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/conv.py:548 in forward, code: return self._conv_forward(input, self.weight, self.bias)
        convolution: "f32[1, 3, 3, 3]" = torch.ops.aten.convolution.default(x, p_conv_weight, p_conv_bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1);  x = p_conv_weight = p_conv_bias = None
        mul: "f32[1, 3, 3, 3]" = torch.ops.aten.mul.Tensor(convolution, 2);  convolution = None
        
         # File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/batchnorm.py:173 in forward, code: self.num_batches_tracked.add_(1)  # type: ignore[has-type]
        add: "i64[]" = torch.ops.aten.add.Tensor(b_bn_num_batches_tracked, 1);  b_bn_num_batches_tracked = None
        
         # File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/batchnorm.py:193 in forward, code: return F.batch_norm(
        _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(mul, p_bn_weight, p_bn_bias, b_bn_running_mean, b_bn_running_var, True, 0.1, 1e-05);  mul = p_bn_weight = p_bn_bias = b_bn_running_mean = b_bn_running_var = None
        getitem: "f32[1, 3, 3, 3]" = _native_batch_norm_legit_functional[0]
        getitem_3: "f32[3]" = _native_batch_norm_legit_functional[3]
        getitem_4: "f32[3]" = _native_batch_norm_legit_functional[4];  _native_batch_norm_legit_functional = None
        return (getitem_3, getitem_4, add, getitem)
        

Notice that instead of `torch.ops.aten.conv2d.default` being decomposed into `torch.ops.aten.convolution.default` alone, it is now decomposed into `torch.ops.aten.convolution.default` and `torch.ops.aten.mul.Tensor`, which matches our custom decomposition rule.

Limitations of torch.export#

As torch.export is a one-shot process for capturing a computation graph from a PyTorch program, it might ultimately run into untraceable parts of the program, as it is nearly impossible to support tracing all PyTorch and Python features. In the case of torch.compile, an unsupported operation causes a "graph break", and the unsupported operation is run with default Python evaluation. In contrast, torch.export requires users to provide additional information or rewrite parts of their code to make it traceable.

Draft-export is a great resource for listing out the graph breaks that will be encountered when tracing the program, along with additional debug information to solve those errors.

ExportDB is also a great resource for learning about the kinds of programs that are supported and unsupported, along with ways to rewrite programs to make them traceable.

TorchDynamo unsupported#

When using torch.export with `strict=True`, it uses TorchDynamo to evaluate the program at the Python bytecode level and trace it into a graph. Compared to previous tracing frameworks, far fewer rewrites are needed to make a program traceable, but some Python features remain unsupported. One way to get past such graph breaks is to use non-strict export by changing the `strict` flag to `strict=False`, as in the minimal sketch below.
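
A minimal sketch (assuming a module `M` and its `example_args`); the only change from strict export is the flag:

# Non-strict export traces with the Python interpreter instead of TorchDynamo,
# sidestepping Python features that TorchDynamo does not yet support.
ep = torch.export.export(M(), example_args, strict=False)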

Data/Shape-Dependent Control Flow#

Graph breaks can also be encountered on data-dependent control flow (such as `if x.shape[0] > 2`) when shapes are not specialized, as a tracing compiler cannot possibly deal with it without generating code for a combinatorially exploding number of paths. In such cases, users need to rewrite their code using special control flow operators. Currently, we support torch.cond to express if-else-like control flow (more coming soon!); a sketch is shown below.
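
A minimal sketch of rewriting a data-dependent branch with torch.cond; both branches are traced into the graph, and the predicate is evaluated at runtime:

import torch

class CondModule(torch.nn.Module):
    def forward(self, x):
        def true_fn(x):
            return x.sin()

        def false_fn(x):
            return x.cos()

        # Instead of `if x.sum() > 0: ...`, which cannot be traced into a single
        # graph, torch.cond captures both branches in the exported program.
        return torch.cond(x.sum() > 0, true_fn, false_fn, (x,))

ep = torch.export.export(CondModule(), (torch.randn(4, 3),))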

You can also refer to this tutorial for more ways of addressing data-dependent errors.

Missing Fake/Meta Kernels for Operators#

When tracing, a FakeTensor kernel (also known as a meta kernel) is required for all operators. This is used to reason about the input/output shapes of the operator.

Please see this tutorial for more details.
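
For custom operators, a fake kernel can be registered so that export can propagate shapes without running the real implementation. A minimal sketch using the torch.library custom-op API (the operator name `mylib::my_sin` is made up for illustration):

import torch

@torch.library.custom_op("mylib::my_sin", mutates_args=())
def my_sin(x: torch.Tensor) -> torch.Tensor:
    # Real implementation, used during eager execution
    return torch.sin(x)

@my_sin.register_fake
def _(x):
    # Fake (meta) kernel: only describes the output's shape and dtype
    return torch.empty_like(x)

class SinModule(torch.nn.Module):
    def forward(self, x):
        return my_sin(x)

ep = torch.export.export(SinModule(), (torch.randn(3),))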

In the unfortunate case where your model uses an ATen operator that does not yet have a FakeTensor kernel implementation, please file an issue.