torch.export#
Created On: Jun 12, 2025 | Last Updated On: Aug 11, 2025
Overview#
torch.export.export() takes a torch.nn.Module and produces a traced graph representing only the Tensor computation of the function in an Ahead-of-Time (AOT) fashion, which can subsequently be executed with different inputs or serialized.
import torch
from torch.export import export, ExportedProgram
class Mod(torch.nn.Module):
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
a = torch.sin(x)
b = torch.cos(y)
return a + b
example_args = (torch.randn(10, 10), torch.randn(10, 10))
exported_program: ExportedProgram = export(Mod(), args=example_args)
print(exported_program)
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, x: "f32[10, 10]", y: "f32[10, 10]"):
# File: /tmp/ipykernel_647/2550508656.py:6 in forward, code: a = torch.sin(x)
sin: "f32[10, 10]" = torch.ops.aten.sin.default(x); x = None
# File: /tmp/ipykernel_647/2550508656.py:7 in forward, code: b = torch.cos(y)
cos: "f32[10, 10]" = torch.ops.aten.cos.default(y); y = None
# File: /tmp/ipykernel_647/2550508656.py:8 in forward, code: return a + b
add: "f32[10, 10]" = torch.ops.aten.add.Tensor(sin, cos); sin = cos = None
return (add,)
Graph signature:
# inputs
x: USER_INPUT
y: USER_INPUT
# outputs
add: USER_OUTPUT
Range constraints: {}
torch.export produces a clean intermediate representation (IR) with the following invariants. More specifications about the IR can be found here.
Soundness: It is guaranteed to be a sound representation of the original program, and maintains the same calling conventions as the original program.
Normalized: There are no Python semantics within the graph. Submodules from the original program are inlined to form one fully flattened computational graph.
Graph properties: The graph is purely functional, meaning it does not contain operations with side effects such as mutations or aliasing. It does not mutate any intermediate values, parameters, or buffers.
Metadata: The graph contains metadata captured during tracing, such as stack traces from the user's code.
Under the hood, torch.export leverages the following latest technologies:
TorchDynamo (torch._dynamo) is an internal API that uses a CPython feature called the Frame Evaluation API to safely trace PyTorch graphs. This provides a massively improved graph-capturing experience, with far fewer rewrites needed in order to fully trace PyTorch code.
AOT Autograd provides a functionalized PyTorch graph and ensures the graph is decomposed/lowered to the ATen operator set.
Torch FX (torch.fx) is the underlying representation of the graph, allowing flexible Python-based transformations.
Existing frameworks#
torch.compile() also utilizes the same PT2 stack as torch.export, but with some differences:
JIT vs. AOT: torch.compile() is a JIT compiler, whereas torch.export is an AOT compiler that is not intended to produce compiled artifacts outside of deployment.
Partial vs. full graph capture: when torch.compile() runs into an untraceable part of a model, it will "graph break" and fall back to running the program in the eager Python runtime. In comparison, torch.export aims to get a full graph representation of a PyTorch model, so it will error out when something untraceable is reached. Since torch.export produces a full graph disjoint from any Python features or runtime, this graph can be saved, loaded, and run in different environments and languages. (A sketch contrasting the two behaviors follows this list.)
Usability tradeoff: since torch.compile() is able to fall back to the Python runtime whenever it reaches something untraceable, it is much more flexible. torch.export instead requires users to provide more information or rewrite their code to make it traceable.
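To make the partial-vs-full graph capture tradeoff concrete, here is a minimal sketch (the Branchy module is a made-up example, not from the original documentation): torch.compile() graph-breaks on the data-dependent branch and falls back to eager Python, while torch.export.export() errors out instead of falling back.
import torch

class Branchy(torch.nn.Module):
    def forward(self, x):
        if x.sum() > 0:  # data-dependent Python control flow
            return x.sin()
        return x.cos()

x = torch.randn(4)
compiled = torch.compile(Branchy())
print(compiled(x))  # works: torch.compile graph-breaks and runs the branch eagerly

try:
    torch.export.export(Branchy(), (x,))  # export needs a single full graph, so it raises
except Exception as e:
    print(type(e).__name__)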
Compared to torch.fx.symbolic_trace(), torch.export traces using TorchDynamo, which operates at the Python bytecode level, giving it the ability to trace arbitrary Python constructs without being limited by what Python operator overloading supports. Additionally, torch.export keeps fine-grained track of Tensor metadata, so that conditionals on things like Tensor shapes do not fail tracing. In general, torch.export is expected to work on more user programs and to produce lower-level graphs (at the torch.ops.aten operator level). Note that users can still use torch.fx.symbolic_trace() as a preprocessing step before torch.export.
Compared to torch.jit.script(), torch.export does not capture Python control flow or data structures unless explicit control flow operators are used, but it supports more Python language features due to its comprehensive coverage of Python bytecode. The resulting graphs are simpler and only have straight-line control flow, except for explicit control flow operators.
Compared to torch.jit.trace(), torch.export is sound: it can trace code that performs integer computation on Tensor shapes, and it records all of the side conditions necessary to ensure that a particular trace is valid for other inputs.
Exporting a PyTorch Model#
The main entry point is torch.export.export(), which takes a torch.nn.Module and example inputs, and captures the computation graph into a torch.export.ExportedProgram. An example:
import torch
from torch.export import export, ExportedProgram
# Simple module for demonstration
class M(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.conv = torch.nn.Conv2d(
in_channels=3, out_channels=16, kernel_size=3, padding=1
)
self.relu = torch.nn.ReLU()
self.maxpool = torch.nn.MaxPool2d(kernel_size=3)
def forward(self, x: torch.Tensor, *, constant=None) -> torch.Tensor:
a = self.conv(x)
a.add_(constant)
return self.maxpool(self.relu(a))
example_args = (torch.randn(1, 3, 256, 256),)
example_kwargs = {"constant": torch.ones(1, 16, 256, 256)}
exported_program: ExportedProgram = export(
M(), args=example_args, kwargs=example_kwargs
)
print(exported_program)
# To run the exported program, we can use the `module()` method
print(exported_program.module()(torch.randn(1, 3, 256, 256), constant=torch.ones(1, 16, 256, 256)))
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, p_conv_weight: "f32[16, 3, 3, 3]", p_conv_bias: "f32[16]", x: "f32[1, 3, 256, 256]", constant: "f32[1, 16, 256, 256]"):
# File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/conv.py:548 in forward, code: return self._conv_forward(input, self.weight, self.bias)
conv2d: "f32[1, 16, 256, 256]" = torch.ops.aten.conv2d.default(x, p_conv_weight, p_conv_bias, [1, 1], [1, 1]); x = p_conv_weight = p_conv_bias = None
# File: /tmp/ipykernel_647/2848084713.py:16 in forward, code: a.add_(constant)
add_: "f32[1, 16, 256, 256]" = torch.ops.aten.add_.Tensor(conv2d, constant); conv2d = constant = None
# File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/activation.py:144 in forward, code: return F.relu(input, inplace=self.inplace)
relu: "f32[1, 16, 256, 256]" = torch.ops.aten.relu.default(add_); add_ = None
# File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/pooling.py:226 in forward, code: return F.max_pool2d(
max_pool2d: "f32[1, 16, 85, 85]" = torch.ops.aten.max_pool2d.default(relu, [3, 3], [3, 3]); relu = None
return (max_pool2d,)
Graph signature:
# inputs
p_conv_weight: PARAMETER target='conv.weight'
p_conv_bias: PARAMETER target='conv.bias'
x: USER_INPUT
constant: USER_INPUT
# outputs
max_pool2d: USER_OUTPUT
Range constraints: {}
tensor([[[[1.3647, 2.2724, 1.9313, ..., 2.2872, 2.0532, 1.7862],
[2.3376, 1.7866, 2.0821, ..., 1.8447, 2.5090, 1.4548],
[2.1719, 2.0520, 1.6153, ..., 2.1426, 1.4364, 1.7749],
...,
[2.1165, 2.0080, 1.6747, ..., 1.8013, 2.3714, 2.1069],
[1.9574, 1.9642, 1.6017, ..., 2.2444, 1.7903, 2.1224],
[1.6735, 1.4687, 1.4179, ..., 2.1677, 1.4616, 1.9001]],
[[1.7637, 2.2845, 1.6327, ..., 1.8567, 1.6745, 1.3376],
[1.8699, 1.1547, 2.5045, ..., 1.1387, 2.0210, 1.3825],
[1.7050, 1.3393, 1.9955, ..., 1.4296, 1.8792, 1.7073],
...,
[1.5292, 1.9183, 1.8844, ..., 2.0815, 1.8089, 2.5264],
[2.0338, 2.3296, 2.1650, ..., 2.0727, 2.0166, 1.7309],
[1.8381, 1.6556, 1.9402, ..., 1.6529, 1.8134, 1.6075]],
[[1.6673, 2.1074, 1.5976, ..., 1.7421, 1.7998, 1.6087],
[2.0269, 1.3379, 1.6679, ..., 1.6671, 1.6000, 1.9894],
[2.0480, 1.5340, 1.4017, ..., 1.7944, 0.9860, 2.3785],
...,
[1.5266, 1.3949, 1.2980, ..., 1.3569, 1.9492, 1.8062],
[1.8315, 1.2293, 1.1087, ..., 1.5446, 1.6492, 1.6620],
[1.4799, 1.3720, 1.5748, ..., 1.8854, 1.2940, 1.7422]],
...,
[[1.5734, 2.2576, 1.6242, ..., 2.2690, 1.5416, 1.6914],
[1.6577, 1.1605, 1.3565, ..., 0.8677, 1.1838, 1.7662],
[2.0769, 1.6546, 1.6169, ..., 2.1301, 1.3892, 1.7564],
...,
[1.4403, 2.0147, 2.0693, ..., 1.8359, 1.3394, 1.6654],
[1.6513, 2.2535, 2.0069, ..., 1.0825, 1.6039, 1.3635],
[1.3937, 1.0050, 2.8575, ..., 1.9308, 1.4201, 1.4665]],
[[1.8365, 1.3995, 1.8011, ..., 1.7474, 2.0621, 1.7876],
[1.7984, 2.0523, 1.4683, ..., 1.7031, 2.4383, 1.2690],
[1.6586, 1.4678, 1.7569, ..., 1.4851, 1.5530, 2.1754],
...,
[1.3299, 1.2796, 1.9345, ..., 1.8214, 2.7972, 2.0472],
[1.3732, 1.3926, 1.5657, ..., 1.4157, 2.2771, 1.9893],
[1.4669, 1.7627, 1.6258, ..., 1.5350, 2.0609, 2.1192]],
[[1.8094, 2.4813, 1.9160, ..., 2.0701, 1.2966, 1.0297],
[2.2580, 1.9149, 1.8033, ..., 1.4026, 1.9687, 1.5859],
[1.3034, 1.1883, 1.5391, ..., 1.5249, 1.5463, 2.0138],
...,
[2.2308, 1.7266, 1.8078, ..., 1.7636, 2.2842, 1.5676],
[1.7690, 1.3066, 1.8903, ..., 1.8704, 1.5074, 1.5758],
[2.3337, 1.5704, 1.8364, ..., 1.4995, 1.8263, 1.8836]]]],
grad_fn=<MaxPool2DWithIndicesBackward0>)
Inspecting the ExportedProgram, we can note the following:
The torch.fx.Graph contains the computation graph of the original program, along with records of the original code for easy debugging.
The graph contains only torch.ops.aten operators, which can be found here, as well as custom operators.
The parameters (the weight and bias of conv) are lifted as inputs to the graph, so there are no `get_attr` nodes in the graph, which previously existed in the result of torch.fx.symbolic_trace().
torch.export.ExportGraphSignature models the input and output signature, along with specifying which inputs are parameters.
The shape and dtype of the Tensor produced by each node in the graph is recorded. For example, the `conv2d` node produces a Tensor of dtype `torch.float32` and shape (1, 16, 256, 256).
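These properties can also be inspected programmatically. A minimal sketch, reusing the `exported_program` from the example above:
for node in exported_program.graph.nodes:
    if node.op == "call_function":
        val = node.meta.get("val")  # FakeTensor recording the shape/dtype of this node's output
        print(node.target, getattr(val, "shape", None), getattr(val, "dtype", None))

# The graph signature records which inputs are lifted parameters vs. user inputs
print(exported_program.graph_signature)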
Expressing Dynamism#
By default, torch.export traces the program assuming all input shapes are **static**, and specializes the exported program to those dimensions. One consequence of this is that at runtime, the program will not work on inputs with different shapes, even if they are valid in eager mode.
An example:
import torch
import traceback as tb
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.branch1 = torch.nn.Sequential(
torch.nn.Linear(64, 32), torch.nn.ReLU()
)
self.branch2 = torch.nn.Sequential(
torch.nn.Linear(128, 64), torch.nn.ReLU()
)
self.buffer = torch.ones(32)
def forward(self, x1, x2):
out1 = self.branch1(x1)
out2 = self.branch2(x2)
return (out1 + self.buffer, out2)
example_args = (torch.randn(32, 64), torch.randn(32, 128))
ep = torch.export.export(M(), example_args)
print(ep)
example_args2 = (torch.randn(64, 64), torch.randn(64, 128))
try:
ep.module()(*example_args2) # fails
except Exception:
tb.print_exc()
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, p_branch1_0_weight: "f32[32, 64]", p_branch1_0_bias: "f32[32]", p_branch2_0_weight: "f32[64, 128]", p_branch2_0_bias: "f32[64]", c_buffer: "f32[32]", x1: "f32[32, 64]", x2: "f32[32, 128]"):
# File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
linear: "f32[32, 32]" = torch.ops.aten.linear.default(x1, p_branch1_0_weight, p_branch1_0_bias); x1 = p_branch1_0_weight = p_branch1_0_bias = None
# File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/activation.py:144 in forward, code: return F.relu(input, inplace=self.inplace)
relu: "f32[32, 32]" = torch.ops.aten.relu.default(linear); linear = None
# File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
linear_1: "f32[32, 64]" = torch.ops.aten.linear.default(x2, p_branch2_0_weight, p_branch2_0_bias); x2 = p_branch2_0_weight = p_branch2_0_bias = None
# File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/activation.py:144 in forward, code: return F.relu(input, inplace=self.inplace)
relu_1: "f32[32, 64]" = torch.ops.aten.relu.default(linear_1); linear_1 = None
# File: /tmp/ipykernel_647/1522925308.py:19 in forward, code: return (out1 + self.buffer, out2)
add: "f32[32, 32]" = torch.ops.aten.add.Tensor(relu, c_buffer); relu = c_buffer = None
return (add, relu_1)
Graph signature:
# inputs
p_branch1_0_weight: PARAMETER target='branch1.0.weight'
p_branch1_0_bias: PARAMETER target='branch1.0.bias'
p_branch2_0_weight: PARAMETER target='branch2.0.weight'
p_branch2_0_bias: PARAMETER target='branch2.0.bias'
c_buffer: CONSTANT_TENSOR target='buffer'
x1: USER_INPUT
x2: USER_INPUT
# outputs
add: USER_OUTPUT
relu_1: USER_OUTPUT
Range constraints: {}
Traceback (most recent call last):
File "/tmp/ipykernel_647/1522925308.py", line 28, in <module>
ep.module()(*example_args2) # fails
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/graph_module.py", line 837, in call_wrapped
return self._wrapped_call(self, *args, **kwargs)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/graph_module.py", line 413, in __call__
raise e
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/graph_module.py", line 400, in __call__
return super(self.cls, obj).__call__(*args, **kwargs) # type: ignore[misc]
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1881, in _call_impl
return inner()
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1829, in inner
result = forward_call(*args, **kwargs)
File "<eval_with_key>.25", line 11, in forward
_guards_fn = self._guards_fn(x1, x2); _guards_fn = None
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
return forward_call(*args, **kwargs)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/external_utils.py", line 209, in inner
return func(*args, **kwargs)
File "<string>", line 3, in _
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/__init__.py", line 2185, in _assert
assert condition, message
AssertionError: Guard failed: x1.size()[0] == 32
However, some dimensions, such as a batch dimension, can be dynamic and vary from run to run. Such dimensions must be created using the torch.export.Dim() API and passed into torch.export.export() through the dynamic_shapes argument.
import torch
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.branch1 = torch.nn.Sequential(
torch.nn.Linear(64, 32), torch.nn.ReLU()
)
self.branch2 = torch.nn.Sequential(
torch.nn.Linear(128, 64), torch.nn.ReLU()
)
self.buffer = torch.ones(32)
def forward(self, x1, x2):
out1 = self.branch1(x1)
out2 = self.branch2(x2)
return (out1 + self.buffer, out2)
example_args = (torch.randn(32, 64), torch.randn(32, 128))
# Create a dynamic batch size
batch = torch.export.Dim("batch")
# Specify that the first dimension of each input is that batch size
dynamic_shapes = {"x1": {0: batch}, "x2": {0: batch}}
ep = torch.export.export(
M(), args=example_args, dynamic_shapes=dynamic_shapes
)
print(ep)
example_args2 = (torch.randn(64, 64), torch.randn(64, 128))
ep.module()(*example_args2) # success
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, p_branch1_0_weight: "f32[32, 64]", p_branch1_0_bias: "f32[32]", p_branch2_0_weight: "f32[64, 128]", p_branch2_0_bias: "f32[64]", c_buffer: "f32[32]", x1: "f32[s24, 64]", x2: "f32[s24, 128]"):
# File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
linear: "f32[s24, 32]" = torch.ops.aten.linear.default(x1, p_branch1_0_weight, p_branch1_0_bias); x1 = p_branch1_0_weight = p_branch1_0_bias = None
# File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/activation.py:144 in forward, code: return F.relu(input, inplace=self.inplace)
relu: "f32[s24, 32]" = torch.ops.aten.relu.default(linear); linear = None
# File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
linear_1: "f32[s24, 64]" = torch.ops.aten.linear.default(x2, p_branch2_0_weight, p_branch2_0_bias); x2 = p_branch2_0_weight = p_branch2_0_bias = None
# File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/activation.py:144 in forward, code: return F.relu(input, inplace=self.inplace)
relu_1: "f32[s24, 64]" = torch.ops.aten.relu.default(linear_1); linear_1 = None
# File: /tmp/ipykernel_647/3456136871.py:18 in forward, code: return (out1 + self.buffer, out2)
add: "f32[s24, 32]" = torch.ops.aten.add.Tensor(relu, c_buffer); relu = c_buffer = None
return (add, relu_1)
Graph signature:
# inputs
p_branch1_0_weight: PARAMETER target='branch1.0.weight'
p_branch1_0_bias: PARAMETER target='branch1.0.bias'
p_branch2_0_weight: PARAMETER target='branch2.0.weight'
p_branch2_0_bias: PARAMETER target='branch2.0.bias'
c_buffer: CONSTANT_TENSOR target='buffer'
x1: USER_INPUT
x2: USER_INPUT
# outputs
add: USER_OUTPUT
relu_1: USER_OUTPUT
Range constraints: {s24: VR[0, int_oo]}
(tensor([[1.0000, 1.8873, 1.2537, ..., 1.8441, 1.5857, 1.2611],
[1.0000, 1.0000, 1.0598, ..., 1.0000, 1.0000, 1.0000],
[1.0000, 1.0000, 2.3600, ..., 1.8214, 1.0710, 1.0000],
...,
[1.0000, 1.0516, 1.1907, ..., 1.4364, 1.0000, 1.0000],
[1.0000, 1.1808, 1.0000, ..., 1.0000, 1.0000, 1.3670],
[1.5627, 2.3273, 1.0000, ..., 1.0000, 1.0000, 1.6566]],
grad_fn=<AddBackward0>),
tensor([[0.0000, 0.0000, 0.0000, ..., 0.1734, 0.0999, 0.0000],
[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.4386, 0.7500],
[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
...,
[1.1618, 0.0000, 0.0000, ..., 0.0000, 0.1554, 0.0000],
[0.0000, 0.4619, 0.0904, ..., 0.0000, 0.0792, 0.5539],
[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
grad_fn=<ReluBackward0>))
Some additional things to note:
Through the torch.export.Dim() API and the dynamic_shapes argument, we specified the first dimension of each input to be dynamic. Looking at the inputs `x1` and `x2`, they have symbolic shapes of (s24, 64) and (s24, 128), instead of the (32, 64)- and (32, 128)-shaped Tensors we passed in as example inputs. `s24` is a symbol representing that this dimension can take a range of values.
exported_program.range_constraints describes the ranges of each symbol appearing in the graph. In this case, we see that `s24` has the range [0, int_oo]. For technical reasons that are difficult to explain here, these symbols are assumed to be neither 0 nor 1. This is not a bug, and does not necessarily mean that the exported program will not work for dimensions 0 or 1. See The 0/1 Specialization Problem for an in-depth discussion of this topic.
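A small sketch of inspecting these symbols programmatically, reusing the `ep` from the example above:
print(ep.range_constraints)  # ranges of each symbol, e.g. the batch symbol -> VR[0, int_oo]
for node in ep.graph.nodes:
    if node.op == "placeholder" and node.name in ("x1", "x2"):
        print(node.name, node.meta["val"].shape)  # symbolic shapes such as (s24, 64)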
In this example, we used Dim("batch") to create a dynamic dimension. This is the most explicit way to specify dynamism. We can also use Dim.DYNAMIC and Dim.AUTO to specify dynamism, which we will cover in the next sections.
Named Dims#
For every dimension specified with Dim("name"), we allocate a symbolic shape. Specifying a Dim with the same name results in the same symbol being generated. This allows users to specify which symbols are allocated for each input dimension.
batch = Dim("batch")
dynamic_shapes = {"x1": {0: dim}, "x2": {0: batch}}
For every Dim, we can specify minimum and maximum values. We also allow specifying relations between Dims as univariate linear expressions: A * dim + B. This allows users to specify more complex constraints, such as integer divisibility, for dynamic dimensions. These features allow users to place explicit restrictions on the dynamic behavior of the ExportedProgram produced.
dx = Dim("dx", min=4, max=256)
dh = Dim("dh", max=512)
dynamic_shapes = {
"x": (dx, None),
"y": (2 * dx, dh),
}
However, ConstraintViolationErrors will be raised if guards emitted during tracing conflict with the given relations or static/dynamic specifications. For example, the above specification asserts the following:
x.shape[0] has the range [4, 256], and is related to y.shape[0] by y.shape[0] == 2 * x.shape[0].
x.shape[1] is static.
y.shape[1] has the range [0, 512], and is unrelated to any other dimension.
If any of these assertions are found to be incorrect during tracing (for example, `x.shape[0]` is static, or `y.shape[1]` has a smaller range, or `y.shape[0] != 2 * x.shape[0]`), then a ConstraintViolationError will be raised, and the user will need to change their dynamic_shapes specification.
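For instance, here is a minimal sketch (the StaticFeature module and its spec are made up for illustration) where the specification conflicts with the model: the Linear layer forces x.shape[1] == 64, so marking that dimension dynamic makes export fail with a constraint-violation style error.
import torch
from torch.export import Dim, export

class StaticFeature(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(64, 32)
    def forward(self, x):
        return self.lin(x)

bad_spec = {"x": {0: Dim("batch"), 1: Dim("feat")}}  # dim 1 is actually static (64)
try:
    export(StaticFeature(), (torch.randn(8, 64),), dynamic_shapes=bad_spec)
except Exception as e:
    print(type(e).__name__)  # export explains that "feat" was specialized to a constant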
Dim Hints#
Instead of specifying dynamism explicitly with Dim("name"), we can let torch.export infer the ranges and relationships of the dynamic values by using Dim.DYNAMIC. This is also a more convenient way to specify dynamism when you don't know exactly how dynamic your dynamic values are.
dynamic_shapes = {
"x": (Dim.DYNAMIC, None),
"y": (Dim.DYNAMIC, Dim.DYNAMIC),
}
We can also specify min/max values for Dim.DYNAMIC, which serve as hints for export. If export finds a different range during tracing, it automatically updates the range without raising an error. We also cannot specify relations between dynamic values; these are instead inferred by export and exposed to users through inspection of the assertions within the graph. With this way of specifying dynamism, ConstraintViolationErrors will only be raised if the specified value is inferred to be **static**.
An even more convenient way to specify dynamism is Dim.AUTO, which behaves like Dim.DYNAMIC but does not raise an error if the dimension is inferred to be static. This is useful when you know nothing about the ranges of the dynamic values and want to export the program with a "best effort" amount of dynamism.
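A minimal sketch, assuming the two-branch M module and example_args from the batch-size example above, for when we are unsure which dimensions are dynamic:
from torch.export import Dim

dynamic_shapes = {
    "x1": {0: Dim.AUTO, 1: Dim.AUTO},  # dim 1 will be inferred static (64) without erroring
    "x2": {0: Dim.AUTO, 1: Dim.AUTO},  # dim 1 will be inferred static (128) without erroring
}
ep = torch.export.export(M(), example_args, dynamic_shapes=dynamic_shapes)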
ShapesCollection#
When specifying which inputs are dynamic via dynamic_shapes, we must specify the dynamism of every input. For example, given the following inputs:
args = {"x": tensor_x, "others": [tensor_y, tensor_z]}
we would need to specify the dynamic shapes of `tensor_x`, `tensor_y`, and `tensor_z` as follows:
# With named-Dims
dim = torch.export.Dim(...)
dynamic_shapes = {"x": {0: dim, 1: dim + 1}, "others": [{0: dim * 2}, None]}
torch.export(..., args, dynamic_shapes=dynamic_shapes)
However, this is particularly cumbersome, because we need to write the `dynamic_shapes` specification in the same nested structure as the input arguments. Instead, an easier way to specify dynamic shapes is with the helper utility torch.export.ShapesCollection, where instead of specifying the dynamism of every input, we can directly assign which input dimensions are dynamic.
import torch
class M(torch.nn.Module):
def forward(self, inp):
x = inp["x"] * 1
y = inp["others"][0] * 2
z = inp["others"][1] * 3
return x, y, z
tensor_x = torch.randn(3, 4, 8)
tensor_y = torch.randn(6)
tensor_z = torch.randn(6)
args = {"x": tensor_x, "others": [tensor_y, tensor_z]}
dim = torch.export.Dim("dim")
sc = torch.export.ShapesCollection()
sc[tensor_x] = (dim, dim + 1, 8)
sc[tensor_y] = {0: dim * 2}
print(sc.dynamic_shapes(M(), (args,)))
ep = torch.export.export(M(), (args,), dynamic_shapes=sc)
print(ep)
{'inp': {'x': (Dim('dim', min=0), dim + 1, 8), 'others': [{0: 2*dim}, None]}}
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, inp_x: "f32[s96, s96 + 1, 8]", inp_others_0: "f32[2*s96]", inp_others_1: "f32[6]"):
# File: /tmp/ipykernel_647/1070110726.py:5 in forward, code: x = inp["x"] * 1
mul: "f32[s96, s96 + 1, 8]" = torch.ops.aten.mul.Tensor(inp_x, 1); inp_x = None
# File: /tmp/ipykernel_647/1070110726.py:6 in forward, code: y = inp["others"][0] * 2
mul_1: "f32[2*s96]" = torch.ops.aten.mul.Tensor(inp_others_0, 2); inp_others_0 = None
# File: /tmp/ipykernel_647/1070110726.py:7 in forward, code: z = inp["others"][1] * 3
mul_2: "f32[6]" = torch.ops.aten.mul.Tensor(inp_others_1, 3); inp_others_1 = None
return (mul, mul_1, mul_2)
Graph signature:
# inputs
inp_x: USER_INPUT
inp_others_0: USER_INPUT
inp_others_1: USER_INPUT
# outputs
mul: USER_OUTPUT
mul_1: USER_OUTPUT
mul_2: USER_OUTPUT
Range constraints: {s96: VR[0, int_oo], s96 + 1: VR[1, int_oo], 2*s96: VR[0, int_oo]}
AdditionalInputs#
If you don't know how dynamic your inputs are, but you have an ample set of testing or profiling data that gives a fair sense of representative inputs for the model, you can use torch.export.AdditionalInputs in place of dynamic_shapes. You specify all the possible inputs that could be used to trace the program, and AdditionalInputs infers which inputs are dynamic based on which input shapes change.
Example:
import dataclasses
import torch
import torch.utils._pytree as pytree
@dataclasses.dataclass
class D:
b: bool
i: int
f: float
t: torch.Tensor
pytree.register_dataclass(D)
class M(torch.nn.Module):
def forward(self, d: D):
return d.i + d.f + d.t
input1 = (D(True, 3, 3.0, torch.ones(3)),)
input2 = (D(True, 4, 3.0, torch.ones(4)),)
ai = torch.export.AdditionalInputs()
ai.add(input1)
ai.add(input2)
print(ai.dynamic_shapes(M(), input1))
ep = torch.export.export(M(), input1, dynamic_shapes=ai)
print(ep)
{'d': [None, _DimHint(type=<_DimHintType.DYNAMIC: 3>, min=None, max=None, _factory=True), None, (_DimHint(type=<_DimHintType.DYNAMIC: 3>, min=None, max=None, _factory=True),)]}
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, d_b, d_i: "Sym(s37)", d_f, d_t: "f32[s99]"):
# File: /tmp/ipykernel_647/829931439.py:16 in forward, code: return d.i + d.f + d.t
sym_float: "Sym(ToFloat(s37))" = torch.sym_float(d_i); d_i = None
add: "Sym(ToFloat(s37) + 3.0)" = sym_float + 3.0; sym_float = None
add_1: "f32[s99]" = torch.ops.aten.add.Tensor(d_t, add); d_t = add = None
return (add_1,)
Graph signature:
# inputs
d_b: USER_INPUT
d_i: USER_INPUT
d_f: USER_INPUT
d_t: USER_INPUT
# outputs
add_1: USER_OUTPUT
Range constraints: {s37: VR[0, int_oo], s99: VR[2, int_oo]}
Serialization#
To save the ExportedProgram, users can use the torch.export.save() and torch.export.load() APIs. The resulting file is a zip file with a specific structure; the details of the structure are defined in the PT2 Archive Spec.
An example:
import torch
class MyModule(torch.nn.Module):
def forward(self, x):
return x + 10
exported_program = torch.export.export(MyModule(), (torch.randn(5),))
torch.export.save(exported_program, 'exported_program.pt2')
saved_exported_program = torch.export.load('exported_program.pt2')
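The loaded program can then be run just like the original; a small usage sketch:
print(saved_exported_program.module()(torch.randn(5)))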
Export IR, Decompositions#
The graph produced by torch.export contains only ATen operators, which are the basic unit of computation in PyTorch. As there are over 3000 ATen operators, export provides a way to narrow down the operator set used in the graph based on certain characteristics, creating different IRs.
By default, export produces the most generic IR, which contains all ATen operators, including both functional and non-functional operators. A functional operator is one that does not mutate or alias its inputs. You can find a list of all ATen operators here, and you can inspect whether an operator is functional by checking `op._schema.is_mutable`.
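A small sketch of this check on two ATen overloads (assuming `is_mutable` is exposed on the schema as described above):
import torch

print(torch.ops.aten.add.Tensor._schema.is_mutable)   # False: functional
print(torch.ops.aten.add_.Tensor._schema.is_mutable)  # True: mutates its first input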
This generic IR can be used to train in eager PyTorch Autograd.
import torch
class M(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.conv = torch.nn.Conv2d(1, 3, 1, 1)
self.bn = torch.nn.BatchNorm2d(3)
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
return (x,)
ep_for_training = torch.export.export(M(), (torch.randn(1, 1, 3, 3),))
print(ep_for_training.graph_module.print_readable(print_output=False))
class GraphModule(torch.nn.Module):
def forward(self, p_conv_weight: "f32[3, 1, 1, 1]", p_conv_bias: "f32[3]", p_bn_weight: "f32[3]", p_bn_bias: "f32[3]", b_bn_running_mean: "f32[3]", b_bn_running_var: "f32[3]", b_bn_num_batches_tracked: "i64[]", x: "f32[1, 1, 3, 3]"):
# File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/conv.py:548 in forward, code: return self._conv_forward(input, self.weight, self.bias)
conv2d: "f32[1, 3, 3, 3]" = torch.ops.aten.conv2d.default(x, p_conv_weight, p_conv_bias); x = p_conv_weight = p_conv_bias = None
# File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/batchnorm.py:173 in forward, code: self.num_batches_tracked.add_(1) # type: ignore[has-type]
add_: "i64[]" = torch.ops.aten.add_.Tensor(b_bn_num_batches_tracked, 1); b_bn_num_batches_tracked = add_ = None
# File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/batchnorm.py:193 in forward, code: return F.batch_norm(
batch_norm: "f32[1, 3, 3, 3]" = torch.ops.aten.batch_norm.default(conv2d, p_bn_weight, p_bn_bias, b_bn_running_mean, b_bn_running_var, True, 0.1, 1e-05, True); conv2d = p_bn_weight = p_bn_bias = b_bn_running_mean = b_bn_running_var = None
return (batch_norm,)
However, if you want to use the IR for inference, or reduce the number of operators being used, you can lower the graph through the `ExportedProgram.run_decompositions()` API. This method decomposes the ATen operators into the ones specified in the decomposition table, and functionalizes the graph.
By specifying an empty decomposition table, we only perform functionalization without any additional decompositions. This results in an IR containing ~2000 operators (instead of the 3000 operators above), which is ideal for inference scenarios.
import torch
class M(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.conv = torch.nn.Conv2d(1, 3, 1, 1)
self.bn = torch.nn.BatchNorm2d(3)
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
return (x,)
ep_for_training = torch.export.export(M(), (torch.randn(1, 1, 3, 3),))
with torch.no_grad():
ep_for_inference = ep_for_training.run_decompositions(decomp_table={})
print(ep_for_inference.graph_module.print_readable(print_output=False))
class GraphModule(torch.nn.Module):
def forward(self, p_conv_weight: "f32[3, 1, 1, 1]", p_conv_bias: "f32[3]", p_bn_weight: "f32[3]", p_bn_bias: "f32[3]", b_bn_running_mean: "f32[3]", b_bn_running_var: "f32[3]", b_bn_num_batches_tracked: "i64[]", x: "f32[1, 1, 3, 3]"):
# File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/conv.py:548 in forward, code: return self._conv_forward(input, self.weight, self.bias)
conv2d: "f32[1, 3, 3, 3]" = torch.ops.aten.conv2d.default(x, p_conv_weight, p_conv_bias); x = p_conv_weight = p_conv_bias = None
# File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/batchnorm.py:173 in forward, code: self.num_batches_tracked.add_(1) # type: ignore[has-type]
add: "i64[]" = torch.ops.aten.add.Tensor(b_bn_num_batches_tracked, 1); b_bn_num_batches_tracked = None
# File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/batchnorm.py:193 in forward, code: return F.batch_norm(
_native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(conv2d, p_bn_weight, p_bn_bias, b_bn_running_mean, b_bn_running_var, True, 0.1, 1e-05); conv2d = p_bn_weight = p_bn_bias = b_bn_running_mean = b_bn_running_var = None
getitem: "f32[1, 3, 3, 3]" = _native_batch_norm_legit_functional[0]
getitem_3: "f32[3]" = _native_batch_norm_legit_functional[3]
getitem_4: "f32[3]" = _native_batch_norm_legit_functional[4]; _native_batch_norm_legit_functional = None
return (getitem_3, getitem_4, add, getitem)
We can see that the previously in-place operator `torch.ops.aten.add_.Tensor` has now been replaced with `torch.ops.aten.add.Tensor`, a functional operator.
We can also further lower this exported program to an operator set which only contains the Core ATen Operator Set, a collection of ~180 operators. This IR is optimal for backends that do not want to reimplement all ATen operators.
import torch
class M(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.conv = torch.nn.Conv2d(1, 3, 1, 1)
self.bn = torch.nn.BatchNorm2d(3)
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
return (x,)
ep_for_training = torch.export.export(M(), (torch.randn(1, 1, 3, 3),))
with torch.no_grad():
core_aten_ir = ep_for_training.run_decompositions(decomp_table=None)
print(core_aten_ir.graph_module.print_readable(print_output=False))
class GraphModule(torch.nn.Module):
def forward(self, p_conv_weight: "f32[3, 1, 1, 1]", p_conv_bias: "f32[3]", p_bn_weight: "f32[3]", p_bn_bias: "f32[3]", b_bn_running_mean: "f32[3]", b_bn_running_var: "f32[3]", b_bn_num_batches_tracked: "i64[]", x: "f32[1, 1, 3, 3]"):
# File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/conv.py:548 in forward, code: return self._conv_forward(input, self.weight, self.bias)
convolution: "f32[1, 3, 3, 3]" = torch.ops.aten.convolution.default(x, p_conv_weight, p_conv_bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1); x = p_conv_weight = p_conv_bias = None
# File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/batchnorm.py:173 in forward, code: self.num_batches_tracked.add_(1) # type: ignore[has-type]
add: "i64[]" = torch.ops.aten.add.Tensor(b_bn_num_batches_tracked, 1); b_bn_num_batches_tracked = None
# File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/batchnorm.py:193 in forward, code: return F.batch_norm(
_native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(convolution, p_bn_weight, p_bn_bias, b_bn_running_mean, b_bn_running_var, True, 0.1, 1e-05); convolution = p_bn_weight = p_bn_bias = b_bn_running_mean = b_bn_running_var = None
getitem: "f32[1, 3, 3, 3]" = _native_batch_norm_legit_functional[0]
getitem_3: "f32[3]" = _native_batch_norm_legit_functional[3]
getitem_4: "f32[3]" = _native_batch_norm_legit_functional[4]; _native_batch_norm_legit_functional = None
return (getitem_3, getitem_4, add, getitem)
We now see that `torch.ops.aten.conv2d.default` has been decomposed into `torch.ops.aten.convolution.default`. This is because `convolution` is a more "core" operator, since operations such as `conv1d` and `conv2d` can be implemented using the same op.
We can also specify our own decomposition behavior:
class M(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.conv = torch.nn.Conv2d(1, 3, 1, 1)
self.bn = torch.nn.BatchNorm2d(3)
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
return (x,)
ep_for_training = torch.export.export(M(), (torch.randn(1, 1, 3, 3),))
my_decomp_table = torch.export.default_decompositions()
def my_awesome_custom_conv2d_function(x, weight, bias, stride=[1, 1], padding=[0, 0], dilation=[1, 1], groups=1):
return 2 * torch.ops.aten.convolution(x, weight, bias, stride, padding, dilation, False, [0, 0], groups)
my_decomp_table[torch.ops.aten.conv2d.default] = my_awesome_custom_conv2d_function
my_ep = ep_for_training.run_decompositions(my_decomp_table)
print(my_ep.graph_module.print_readable(print_output=False))
class GraphModule(torch.nn.Module):
def forward(self, p_conv_weight: "f32[3, 1, 1, 1]", p_conv_bias: "f32[3]", p_bn_weight: "f32[3]", p_bn_bias: "f32[3]", b_bn_running_mean: "f32[3]", b_bn_running_var: "f32[3]", b_bn_num_batches_tracked: "i64[]", x: "f32[1, 1, 3, 3]"):
# File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/conv.py:548 in forward, code: return self._conv_forward(input, self.weight, self.bias)
convolution: "f32[1, 3, 3, 3]" = torch.ops.aten.convolution.default(x, p_conv_weight, p_conv_bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1); x = p_conv_weight = p_conv_bias = None
mul: "f32[1, 3, 3, 3]" = torch.ops.aten.mul.Tensor(convolution, 2); convolution = None
# File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/batchnorm.py:173 in forward, code: self.num_batches_tracked.add_(1) # type: ignore[has-type]
add: "i64[]" = torch.ops.aten.add.Tensor(b_bn_num_batches_tracked, 1); b_bn_num_batches_tracked = None
# File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/batchnorm.py:193 in forward, code: return F.batch_norm(
_native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(mul, p_bn_weight, p_bn_bias, b_bn_running_mean, b_bn_running_var, True, 0.1, 1e-05); mul = p_bn_weight = p_bn_bias = b_bn_running_mean = b_bn_running_var = None
getitem: "f32[1, 3, 3, 3]" = _native_batch_norm_legit_functional[0]
getitem_3: "f32[3]" = _native_batch_norm_legit_functional[3]
getitem_4: "f32[3]" = _native_batch_norm_legit_functional[4]; _native_batch_norm_legit_functional = None
return (getitem_3, getitem_4, add, getitem)
Notice that instead of `torch.ops.aten.conv2d.default` being decomposed into `torch.ops.aten.convolution.default` alone, it is now decomposed into `torch.ops.aten.convolution.default` and `torch.ops.aten.mul.Tensor`, which matches our custom decomposition rule.
Limitations of torch.export#
As torch.export is a one-shot process for capturing a computation graph from a PyTorch program, it may ultimately run into untraceable parts of a program, since it is nearly impossible to support tracing all PyTorch and Python features. In the case of torch.compile, an unsupported operation causes a "graph break", and the unsupported operation is run with default Python evaluation. In contrast, torch.export requires users to provide additional information or rewrite parts of their code to make it traceable.
Draft-export is a great resource for listing the graph breaks that will be encountered when tracing a program, along with additional debug information to solve those errors.
ExportDB is also a great resource for learning about the kinds of programs that are supported and unsupported, along with ways to rewrite programs to make them traceable.
TorchDynamo unsupported#
When using torch.export with `strict=True`, TorchDynamo is used to evaluate the program at the Python bytecode level and trace it into a graph. Compared to previous tracing frameworks, significantly fewer rewrites are required to make a program traceable, but some Python features will still be unsupported. One option for getting past such graph breaks is to use non-strict export by changing the flag to `strict=False`.
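A minimal sketch of non-strict export (the Plus10 module is a made-up toy example):
import torch

class Plus10(torch.nn.Module):
    def forward(self, x):
        return x + 10

# strict=False traces with the Python interpreter instead of TorchDynamo bytecode analysis
ep = torch.export.export(Plus10(), (torch.randn(3),), strict=False)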
Data/Shape-Dependent Control Flow#
Graph breaks can also be encountered on data-dependent control flow (e.g., `if x.shape[0] > 2`) when shapes are not specialized, as a tracing compiler cannot handle this without generating code for a combinatorially exploding number of paths. In such cases, users need to rewrite their code using special control flow operators. Currently, we support torch.cond to express if-else-like control flow (more coming soon!).
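A minimal sketch of such a rewrite using torch.cond (the CondModel module is made up for illustration):
import torch
from torch.export import export

class CondModel(torch.nn.Module):
    def forward(self, x):
        def true_fn(x):
            return x.sin()
        def false_fn(x):
            return x.cos()
        # Both branches are captured in the exported graph; the predicate may depend on data
        return torch.cond(x.sum() > 0, true_fn, false_fn, (x,))

ep = export(CondModel(), (torch.randn(4),))
print(ep)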
You can also refer to this tutorial for more ways of addressing data-dependent errors.