
Introduction to ONNX || Exporting a PyTorch model to ONNX || Extend the ONNX exporter operator support || Export a model with control flow to ONNX

Export a model with control flow to ONNX#

Author: Xavier Dupré

Overview#

This tutorial demonstrates how to handle control flow logic while exporting a PyTorch model to ONNX. It highlights the challenges of exporting conditional statements directly and provides solutions to work around them.

Conditional logic cannot be exported to ONNX unless it is refactored to use torch.cond(). Let's start with a simple model implementing a test.

What you will learn

  • How to refactor the model to use torch.cond() for export.

  • How to export a model with control flow logic to ONNX.

  • How to optimize the exported model using the ONNX optimizer.

Prerequisites#

  • torch >= 2.6

import torch
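
As a quick, illustrative check (not part of the original tutorial), the prerequisite can be verified at runtime; torch.__version__ is a TorchVersion object that supports version-aware comparison with strings:

# Fails early with a clear message if the installed torch is too old.
assert torch.__version__ >= "2.6", "this tutorial requires torch >= 2.6"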

Define the Models#

Two models are defined:

  • ForwardWithControlFlowTest: A model with a forward method containing an if-else conditional.

  • ModelWithControlFlowTest: A model that incorporates ForwardWithControlFlowTest within a simple MLP. The models are tested with a random input tensor to confirm they execute as expected.

class ForwardWithControlFlowTest(torch.nn.Module):
    def forward(self, x):
        if x.sum():
            return x * 2
        return -x


class ModelWithControlFlowTest(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(3, 2),
            torch.nn.Linear(2, 1),
            ForwardWithControlFlowTest(),
        )

    def forward(self, x):
        out = self.mlp(x)
        return out


model = ModelWithControlFlowTest()

Exporting the Model: First Attempt#

Exporting this model using torch.export.export fails because the control flow logic in the forward pass creates a graph break that the exporter cannot handle. This behavior is expected, as conditional logic not written with torch.cond() is unsupported.

A try-except block is used to capture the expected failure during the export process. If the export unexpectedly succeeds, an AssertionError is raised.

x = torch.randn(3)
model(x)

try:
    torch.export.export(model, (x,), strict=False)
    raise AssertionError("This export should fail unless PyTorch now supports this model.")
except Exception as e:
    print(e)
def forward(self, arg0_1: "f32[2, 3]", arg1_1: "f32[2]", arg2_1: "f32[1, 2]", arg3_1: "f32[1]", arg4_1: "f32[3]"):
     # File: /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
    linear: "f32[2]" = torch.ops.aten.linear.default(arg4_1, arg0_1, arg1_1);  arg4_1 = arg0_1 = arg1_1 = None
    linear_1: "f32[1]" = torch.ops.aten.linear.default(linear, arg2_1, arg3_1);  linear = arg2_1 = arg3_1 = None

     # File: /var/lib/workspace/beginner_source/onnx/export_control_flow_model_to_onnx_tutorial.py:56 in forward, code: if x.sum():
    sum_1: "f32[]" = torch.ops.aten.sum.default(linear_1);  linear_1 = None
    ne: "b8[]" = torch.ops.aten.ne.Scalar(sum_1, 0);  sum_1 = None
    item: "Sym(Eq(u0, 1))" = torch.ops.aten.item.default(ne);  ne = item = None

Could not guard on data-dependent expression Eq(u0, 1) (unhinted: Eq(u0, 1)).  (Size-like symbols: none)

consider using data-dependent friendly APIs such as guard_or_false, guard_or_true and statically_known_true
Caused by: (_export/non_strict_utils.py:1066 in __torch_function__)
For more information, run with TORCH_LOGS="dynamic"
For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u0"
If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing

For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1

The following call raised this error:
  File "/var/lib/workspace/beginner_source/onnx/export_control_flow_model_to_onnx_tutorial.py", line 56, in forward
    if x.sum():


The error above occurred when calling torch.export.export. If you would like to view some more information about this error, and get a list of all other errors that may occur in your export call, you can replace your `export()` call with `draft_export()`.

Suggested Patch: Refactoring with torch.cond()#

To make the control flow exportable, the tutorial demonstrates replacing the forward method in ForwardWithControlFlowTest with a refactored version that uses torch.cond().

Details of the refactoring:

Two helper functions (identity2 and neg) represent the branches of the conditional logic:

  • torch.cond() is used to specify the condition and the two branches, along with the input arguments.

  • The updated forward method is then dynamically assigned to the ForwardWithControlFlowTest instance within the model. A list of submodules is printed to confirm the replacement.

def new_forward(x):
    def identity2(x):
        return x * 2

    def neg(x):
        return -x

    return torch.cond(x.sum() > 0, identity2, neg, (x,))


print("the list of submodules")
for name, mod in model.named_modules():
    print(name, type(mod))
    if isinstance(mod, ForwardWithControlFlowTest):
        mod.forward = new_forward
the list of submodules
 <class '__main__.ModelWithControlFlowTest'>
mlp <class 'torch.nn.modules.container.Sequential'>
mlp.0 <class 'torch.nn.modules.linear.Linear'>
mlp.1 <class 'torch.nn.modules.linear.Linear'>
mlp.2 <class '__main__.ForwardWithControlFlowTest'>
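
As a quick sanity check (an illustrative addition, not part of the original tutorial), we can assert that the instance-level forward was actually replaced:

# The patched submodule now carries new_forward as an instance attribute,
# which shadows the class-level forward that nn.Module.__call__ looks up.
assert model.mlp[2].forward is new_forward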

Let's see what the FX graph looks like.

print(torch.export.export(model, (x,), strict=False))
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, p_mlp_0_weight: "f32[2, 3]", p_mlp_0_bias: "f32[2]", p_mlp_1_weight: "f32[1, 2]", p_mlp_1_bias: "f32[1]", x: "f32[3]"):
             # File: /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
            linear: "f32[2]" = torch.ops.aten.linear.default(x, p_mlp_0_weight, p_mlp_0_bias);  x = p_mlp_0_weight = p_mlp_0_bias = None
            linear_1: "f32[1]" = torch.ops.aten.linear.default(linear, p_mlp_1_weight, p_mlp_1_bias);  linear = p_mlp_1_weight = p_mlp_1_bias = None

             # File: /usr/local/lib/python3.10/dist-packages/torch/nn/modules/container.py:250 in forward, code: input = module(input)
            sum_1: "f32[]" = torch.ops.aten.sum.default(linear_1)
            gt: "b8[]" = torch.ops.aten.gt.Scalar(sum_1, 0);  sum_1 = None

             # File: <eval_with_key>.3:9 in forward, code: cond = torch.ops.higher_order.cond(l_args_0_, cond_true_0, cond_false_0, (l_args_3_0_,));  l_args_0_ = cond_true_0 = cond_false_0 = l_args_3_0_ = None
            true_graph_0 = self.true_graph_0
            false_graph_0 = self.false_graph_0
            cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, (linear_1,));  gt = true_graph_0 = false_graph_0 = linear_1 = None
            getitem: "f32[1]" = cond[0];  cond = None
            return (getitem,)

        class true_graph_0(torch.nn.Module):
            def forward(self, linear_1: "f32[1]"):
                 # File: <eval_with_key>.0:6 in forward, code: mul = l_args_3_0__1 * 2;  l_args_3_0__1 = None
                mul: "f32[1]" = torch.ops.aten.mul.Tensor(linear_1, 2);  linear_1 = None
                return (mul,)

        class false_graph_0(torch.nn.Module):
            def forward(self, linear_1: "f32[1]"):
                 # File: <eval_with_key>.1:6 in forward, code: neg = -l_args_3_0__1;  l_args_3_0__1 = None
                neg: "f32[1]" = torch.ops.aten.neg.default(linear_1);  linear_1 = None
                return (neg,)

Graph signature:
    # inputs
    p_mlp_0_weight: PARAMETER target='mlp.0.weight'
    p_mlp_0_bias: PARAMETER target='mlp.0.bias'
    p_mlp_1_weight: PARAMETER target='mlp.1.weight'
    p_mlp_1_bias: PARAMETER target='mlp.1.bias'
    x: USER_INPUT

    # outputs
    getitem: USER_OUTPUT

Range constraints: {}

Let's export the model again.

onnx_program = torch.onnx.export(model, (x,), dynamo=True)
print(onnx_program.model)
[torch.onnx] Obtain model graph for `ModelWithControlFlowTest([...]` with `torch.export.export(..., strict=False)`...
[torch.onnx] Obtain model graph for `ModelWithControlFlowTest([...]` with `torch.export.export(..., strict=False)`... ✅
[torch.onnx] Run decomposition...
[torch.onnx] Run decomposition... ✅
[torch.onnx] Translate the graph into ONNX...
[torch.onnx] Translate the graph into ONNX... ✅
<
    ir_version=10,
    opset_imports={'': 20},
    producer_name='pytorch',
    producer_version='2.9.0+cu128',
    domain=None,
    model_version=None,
>
graph(
    name=main_graph,
    inputs=(
        %"x"<FLOAT,[3]>
    ),
    outputs=(
        %"getitem"<FLOAT,[1]>
    ),
    initializers=(
        %"mlp.0.bias"<FLOAT,[2]>{TorchTensor<FLOAT,[2]>(Parameter containing: tensor([0.1949, 0.2263], requires_grad=True), name='mlp.0.bias')},
        %"mlp.1.bias"<FLOAT,[1]>{TorchTensor<FLOAT,[1]>(Parameter containing: tensor([-0.4595], requires_grad=True), name='mlp.1.bias')},
        %"val_0"<FLOAT,[3,2]>{Tensor<FLOAT,[3,2]>(array([[-0.03271979,  0.5771984 ], [ 0.48856184, -0.1434506 ], [ 0.53982925, -0.26834208]], dtype=float32), name='val_0')},
        %"val_2"<FLOAT,[2,1]>{Tensor<FLOAT,[2,1]>(array([[-0.6689668 ], [ 0.59571433]], dtype=float32), name='val_2')},
        %"scalar_tensor_default"<FLOAT,[]>{Tensor<FLOAT,[]>(array(0., dtype=float32), name='scalar_tensor_default')},
        %"scalar_tensor_default_2"<FLOAT,[]>{Tensor<FLOAT,[]>(array(2., dtype=float32), name='scalar_tensor_default_2')}
    ),
) {
    0 |  # node_MatMul_1
         %"val_1"<FLOAT,[2]> ⬅️ ::MatMul(%"x", %"val_0"{[[-0.03271978721022606, 0.5771983861923218], [0.48856183886528015, -0.14345060288906097], [0.5398292541503906, -0.2683420777320862]]})
    1 |  # node_linear
         %"linear"<FLOAT,[2]> ⬅️ ::Add(%"val_1", %"mlp.0.bias"{[0.19491833448410034, 0.22633221745491028]})
    2 |  # node_MatMul_3
         %"val_3"<FLOAT,[1]> ⬅️ ::MatMul(%"linear", %"val_2"{[[-0.6689668297767639], [0.5957143306732178]]})
    3 |  # node_linear_1
         %"linear_1"<FLOAT,[1]> ⬅️ ::Add(%"val_3", %"mlp.1.bias"{[-0.45953357219696045]})
    4 |  # node_sum_1
         %"sum_1"<FLOAT,[]> ⬅️ ::ReduceSum(%"linear_1") {noop_with_empty_axes=0, keepdims=0}
    5 |  # node_gt
         %"gt"<BOOL,[]> ⬅️ ::Greater(%"sum_1", %"scalar_tensor_default"{0.0})
    6 |  # node_cond__0
         %"getitem"<FLOAT,[1]> ⬅️ ::If(%"gt") {then_branch=
             graph(
                 name=true_graph_0,
                 inputs=(

                 ),
                 outputs=(
                     %"mul_true_graph_0"<FLOAT,[1]>
                 ),
             ) {
                 0 |  # node_mul
                      %"mul_true_graph_0"<FLOAT,[1]> ⬅️ ::Mul(%"linear_1", %"scalar_tensor_default_2"{2.0})
                 return %"mul_true_graph_0"<FLOAT,[1]>
             }, else_branch=
             graph(
                 name=false_graph_0,
                 inputs=(

                 ),
                 outputs=(
                     %"neg_false_graph_0"<FLOAT,[1]>
                 ),
             ) {
                 0 |  # node_neg
                      %"neg_false_graph_0"<FLOAT,[1]> ⬅️ ::Neg(%"linear_1")
                 return %"neg_false_graph_0"<FLOAT,[1]>
             }}
    return %"getitem"<FLOAT,[1]>
}

We can optimize the model and get rid of the model-local functions created to capture the control flow branches.
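
A minimal sketch of this step, assuming the ONNXProgram.optimize() API of the dynamo-based exporter:

onnx_program.optimize()
print(onnx_program.model)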

<
    ir_version=10,
    opset_imports={'': 20},
    producer_name='pytorch',
    producer_version='2.9.0+cu128',
    domain=None,
    model_version=None,
>
graph(
    name=main_graph,
    inputs=(
        %"x"<FLOAT,[3]>
    ),
    outputs=(
        %"getitem"<FLOAT,[1]>
    ),
    initializers=(
        %"mlp.0.bias"<FLOAT,[2]>{TorchTensor<FLOAT,[2]>(Parameter containing: tensor([0.1949, 0.2263], requires_grad=True), name='mlp.0.bias')},
        %"mlp.1.bias"<FLOAT,[1]>{TorchTensor<FLOAT,[1]>(Parameter containing: tensor([-0.4595], requires_grad=True), name='mlp.1.bias')},
        %"val_0"<FLOAT,[3,2]>{Tensor<FLOAT,[3,2]>(array([[-0.03271979,  0.5771984 ], [ 0.48856184, -0.1434506 ], [ 0.53982925, -0.26834208]], dtype=float32), name='val_0')},
        %"val_2"<FLOAT,[2,1]>{Tensor<FLOAT,[2,1]>(array([[-0.6689668 ], [ 0.59571433]], dtype=float32), name='val_2')},
        %"scalar_tensor_default"<FLOAT,[]>{Tensor<FLOAT,[]>(array(0., dtype=float32), name='scalar_tensor_default')},
        %"scalar_tensor_default_2"<FLOAT,[]>{Tensor<FLOAT,[]>(array(2., dtype=float32), name='scalar_tensor_default_2')}
    ),
) {
    0 |  # node_MatMul_1
         %"val_1"<FLOAT,[2]> ⬅️ ::MatMul(%"x", %"val_0"{[[-0.03271978721022606, 0.5771983861923218], [0.48856183886528015, -0.14345060288906097], [0.5398292541503906, -0.2683420777320862]]})
    1 |  # node_linear
         %"linear"<FLOAT,[2]> ⬅️ ::Add(%"val_1", %"mlp.0.bias"{[0.19491833448410034, 0.22633221745491028]})
    2 |  # node_MatMul_3
         %"val_3"<FLOAT,[1]> ⬅️ ::MatMul(%"linear", %"val_2"{[[-0.6689668297767639], [0.5957143306732178]]})
    3 |  # node_linear_1
         %"linear_1"<FLOAT,[1]> ⬅️ ::Add(%"val_3", %"mlp.1.bias"{[-0.45953357219696045]})
    4 |  # node_sum_1
         %"sum_1"<FLOAT,[]> ⬅️ ::ReduceSum(%"linear_1") {noop_with_empty_axes=0, keepdims=0}
    5 |  # node_gt
         %"gt"<BOOL,[]> ⬅️ ::Greater(%"sum_1", %"scalar_tensor_default"{0.0})
    6 |  # node_cond__0
         %"getitem"<FLOAT,[1]> ⬅️ ::If(%"gt") {then_branch=
             graph(
                 name=true_graph_0,
                 inputs=(

                 ),
                 outputs=(
                     %"mul_true_graph_0"<FLOAT,[1]>
                 ),
             ) {
                 0 |  # node_mul
                      %"mul_true_graph_0"<FLOAT,[1]> ⬅️ ::Mul(%"linear_1", %"scalar_tensor_default_2"{2.0})
                 return %"mul_true_graph_0"<FLOAT,[1]>
             }, else_branch=
             graph(
                 name=false_graph_0,
                 inputs=(

                 ),
                 outputs=(
                     %"neg_false_graph_0"<FLOAT,[1]>
                 ),
             ) {
                 0 |  # node_neg
                      %"neg_false_graph_0"<FLOAT,[1]> ⬅️ ::Neg(%"linear_1")
                 return %"neg_false_graph_0"<FLOAT,[1]>
             }}
    return %"getitem"<FLOAT,[1]>
}
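
To double-check the exported graph, here is a short validation sketch. It assumes onnxruntime is installed; the file name is arbitrary, and "x" is the input name visible in the graph above:

import onnxruntime

# Save the exported program to disk and run it with ONNX Runtime.
onnx_program.save("model_with_control_flow.onnx")
sess = onnxruntime.InferenceSession(
    "model_with_control_flow.onnx", providers=["CPUExecutionProvider"]
)
(onnx_out,) = sess.run(None, {"x": x.numpy()})

# The ONNX output should match the eager PyTorch output.
torch.testing.assert_close(model(x).detach(), torch.from_numpy(onnx_out))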

Conclusion#

This tutorial demonstrated the challenges of exporting models with conditional logic to ONNX and presented a practical solution using torch.cond(). While the default exporter may fail or produce an imperfect graph, refactoring the model's logic ensures compatibility and generates a faithful ONNX representation.

By understanding these techniques, we can overcome common pitfalls when working with control flow in PyTorch models and ensure smooth integration with ONNX workflows.

Further reading#

The list below refers to tutorials that range from basic examples to advanced scenarios, not necessarily in the order they are listed. Feel free to jump directly to a specific topic of interest, or sit tight and have fun going through everything to learn all there is about the ONNX exporter.

Total running time of the script: (0 minutes 6.284 seconds)