
torch.export Tutorial#

Created On: Oct 02, 2023 | Last Updated: Jan 27, 2025 | Last Verified: Nov 05, 2024

Authors: William Wen, Zhengxu Chen, Angela Yi, Pian Pawakapan

Warning

torch.export and its related features are in prototype status and are subject to backwards compatibility breaking changes. This tutorial provides a snapshot of torch.export usage as of PyTorch 2.5.

torch.export() is the PyTorch 2.X way to export PyTorch models into standardized model representations, intended to be run on different (that is, Python-less) environments. The official documentation can be found here.

In this tutorial, you will learn how to use torch.export() to extract an ExportedProgram (that is, a single-graph representation) from a PyTorch program. We also detail some considerations/modifications that you may need to make in order to make your model compatible with torch.export.

Contents

Basic Usage#

torch.export extracts a single-graph representation from a PyTorch program by tracing the target function, given example inputs. torch.export.export() is the main entry point for torch.export.

In this tutorial, torch.export and torch.export.export() are practically synonymous, though torch.export generally refers to the PyTorch 2.X export process, while torch.export.export() generally refers to the actual function call.

The signature of torch.export.export() is:

export(
    mod: torch.nn.Module,
    args: Tuple[Any, ...],
    kwargs: Optional[Dict[str, Any]] = None,
    *,
    dynamic_shapes: Optional[Dict[str, Dict[int, Dim]]] = None
) -> ExportedProgram

torch.export.export() traces the tensor computation graph from calling mod(*args, **kwargs) and wraps it in an ExportedProgram, which can be serialized or executed later with different inputs. To execute the ExportedProgram we can call its .module() method, which returns a torch.nn.Module that is callable, just like the original program. We will detail the dynamic_shapes argument later in the tutorial.

import torch
from torch.export import export

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(100, 10)

    def forward(self, x, y):
        return torch.nn.functional.relu(self.lin(x + y), inplace=True)

mod = MyModule()
exported_mod = export(mod, (torch.randn(8, 100), torch.randn(8, 100)))
print(type(exported_mod))
print(exported_mod.module()(torch.randn(8, 100), torch.randn(8, 100)))
<class 'torch.export.exported_program.ExportedProgram'>
tensor([[0.0000, 0.0000, 0.0000, 0.0000, 1.1344, 0.0000, 0.5624, 0.3435, 0.6017,
         2.7909],
        [0.0000, 0.0000, 0.0000, 1.0477, 0.0000, 0.0000, 2.0461, 0.0000, 0.7889,
         0.0000],
        [0.0000, 0.0000, 0.9877, 0.0000, 0.5176, 0.5654, 0.0000, 1.5727, 0.6437,
         0.0000],
        [0.5942, 0.7034, 0.0564, 0.0000, 0.4003, 0.0000, 0.0000, 0.2568, 0.0000,
         0.0000],
        [0.0000, 0.2937, 0.3767, 0.5728, 0.0000, 1.2748, 1.0475, 1.4762, 1.1964,
         0.0000],
        [0.0032, 0.0920, 0.3563, 0.0403, 0.2645, 0.0000, 0.7189, 0.7865, 0.0000,
         0.0000],
        [0.9222, 0.0000, 0.0495, 0.7542, 1.1252, 0.8088, 0.0000, 0.0000, 0.7890,
         0.0707],
        [0.0000, 0.5030, 0.0000, 0.2587, 0.8563, 0.6805, 1.0002, 0.9677, 0.0000,
         1.3133]], grad_fn=<ReluBackward0>)

Let's review some attributes of ExportedProgram that are of interest.

The graph attribute is an FX graph traced from the function we exported, that is, the computation graph of all PyTorch operations. The FX graph is in "ATen IR", meaning that it contains only "ATen-level" operations.

The graph_signature attribute gives a more detailed description of the input and output nodes in the exported graph, describing which are parameters, buffers, user inputs, or user outputs.

The range_constraints attribute will be covered later.

print(exported_mod)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, p_lin_weight: "f32[10, 100]", p_lin_bias: "f32[10]", x: "f32[8, 100]", y: "f32[8, 100]"):
             # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:71 in forward, code: return torch.nn.functional.relu(self.lin(x + y), inplace=True)
            add: "f32[8, 100]" = torch.ops.aten.add.Tensor(x, y);  x = y = None

             # File: /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
            linear: "f32[8, 10]" = torch.ops.aten.linear.default(add, p_lin_weight, p_lin_bias);  add = p_lin_weight = p_lin_bias = None

             # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:71 in forward, code: return torch.nn.functional.relu(self.lin(x + y), inplace=True)
            relu_: "f32[8, 10]" = torch.ops.aten.relu_.default(linear);  linear = None
            return (relu_,)

Graph signature:
    # inputs
    p_lin_weight: PARAMETER target='lin.weight'
    p_lin_bias: PARAMETER target='lin.bias'
    x: USER_INPUT
    y: USER_INPUT

    # outputs
    relu_: USER_OUTPUT

Range constraints: {}

See the torch.export documentation for more details.

Graph Breaks#

Although torch.export shares components with torch.compile, the key limitation of torch.export, especially when compared to torch.compile, is that it does not support graph breaks. This is because handling graph breaks involves interpreting the unsupported operation with default Python evaluation, which is incompatible with the export use case. Therefore, in order to make your model code compatible with torch.export, you will need to modify your code to remove graph breaks.

A graph break is necessary in cases such as:

  • data-dependent control flow

class Bad1(torch.nn.Module):
    def forward(self, x):
        if x.sum() > 0:
            return torch.sin(x)
        return torch.cos(x)

import traceback as tb
try:
    export(Bad1(), (torch.randn(3, 3),))
except Exception:
    tb.print_exc()
def forward(self, arg0_1: "f32[3, 3]"):
     # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:116 in forward, code: if x.sum() > 0:
    sum_1: "f32[]" = torch.ops.aten.sum.default(arg0_1);  arg0_1 = None
    gt: "b8[]" = torch.ops.aten.gt.Scalar(sum_1, 0);  sum_1 = None
    ne: "b8[]" = torch.ops.aten.ne.Scalar(gt, 0);  gt = None
    item: "Sym(Eq(u0, 1))" = torch.ops.aten.item.default(ne);  ne = item = None




Traceback (most recent call last):
  File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 122, in <module>
    export(Bad1(), (torch.randn(3, 3),))
  File "/usr/local/lib/python3.10/dist-packages/torch/export/__init__.py", line 311, in export
    raise e
  File "/usr/local/lib/python3.10/dist-packages/torch/export/__init__.py", line 277, in export
    return _export(
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1163, in wrapper
    raise e
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1129, in wrapper
    ep = fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/exported_program.py", line 124, in wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 2255, in _export
    ep = _export_for_training(
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1163, in wrapper
    raise e
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1129, in wrapper
    ep = fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/exported_program.py", line 124, in wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 2071, in _export_for_training
    export_artifact = export_func(
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 2002, in _non_strict_export
    aten_export_artifact = _to_aten_func(  # type: ignore[operator]
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1793, in _export_to_aten_ir_make_fx
    gm, graph_signature = transform(_make_fx_helper)(
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1922, in _aot_export_non_strict
    gm, sig = aot_export(wrapped_mod, args, kwargs=kwargs, **flags)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1706, in _make_fx_helper
    gm = make_fx(
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 2429, in wrapped
    return make_fx_tracer.trace(f, *args)
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 2356, in trace
    return self._trace_inner(f, *args)
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 2318, in _trace_inner
    t = dispatch_trace(
  File "/usr/local/lib/python3.10/dist-packages/torch/_compile.py", line 53, in inner
    return disable_fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 1044, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1303, in dispatch_trace
    graph = tracer.trace(root, concrete_args)  # type: ignore[arg-type]
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1908, in trace
    res = super().trace(root, concrete_args)
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/_symbolic_trace.py", line 868, in trace
    (self.create_arg(fn(*args)),),
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1361, in wrapped
    out = f(*tensors)  # type:ignore[call-arg]
  File "<string>", line 1, in <lambda>
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1593, in wrapped_fn
    return tuple(flat_fn(*args))
  File "/usr/local/lib/python3.10/dist-packages/torch/_functorch/_aot_autograd/utils.py", line 187, in flat_fn
    tree_out = fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_functorch/_aot_autograd/graph_capture_wrappers.py", line 1354, in functional_call
    out = mod(*args[params_len:], **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/_symbolic_trace.py", line 843, in module_call_wrapper
    return self.call_module(mod, forward, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1997, in call_module
    return Tracer.call_module(self, m, forward, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/_symbolic_trace.py", line 560, in call_module
    ret_val = forward(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/_symbolic_trace.py", line 836, in forward
    return _orig_module_call(mod, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1906, in forward
    tree_out = mod(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/_symbolic_trace.py", line 843, in module_call_wrapper
    return self.call_module(mod, forward, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1997, in call_module
    return Tracer.call_module(self, m, forward, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/_symbolic_trace.py", line 560, in call_module
    ret_val = forward(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/_symbolic_trace.py", line 836, in forward
    return _orig_module_call(mod, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
  File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 116, in forward
    if x.sum() > 0:
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1409, in __torch_function__
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1479, in __torch_function__
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_export/non_strict_utils.py", line 1066, in __torch_function__
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/sym_node.py", line 538, in guard_bool
    r = self.evaluate()
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/sym_node.py", line 512, in evaluate
    return self.shape_env.evaluate_sym_node(self, size_oblivious)
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 7233, in evaluate_sym_node
    return self.evaluate_expr(
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 7333, in evaluate_expr
    return self._inner_evaluate_expr(
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/recording.py", line 272, in wrapper
    return retlog(fn(*args, **kwargs))
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 7356, in _inner_evaluate_expr
    return self._evaluate_expr(
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 7574, in _evaluate_expr
    raise self._make_data_dependent_error(
torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: Could not guard on data-dependent expression Eq(u0, 1) (unhinted: Eq(u0, 1)).  (Size-like symbols: none)

consider using data-dependent friendly APIs such as guard_or_false, guard_or_true and statically_known_true
Caused by: (_export/non_strict_utils.py:1066 in __torch_function__)
For more information, run with TORCH_LOGS="dynamic"
For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u0"
If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing

For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1

The following call raised this error:
  File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 116, in forward
    if x.sum() > 0:


The error above occurred when calling torch.export.export. If you would like to view some more information about this error, and get a list of all other errors that may occur in your export call, you can replace your `export()` call with `draft_export()`.
  • accessing tensor data with .data

class Bad2(torch.nn.Module):
    def forward(self, x):
        x.data[0, 0] = 3
        return x

try:
    export(Bad2(), (torch.randn(3, 3),))
except Exception:
    tb.print_exc()
  • calling unsupported functions (such as many built-in functions)

class Bad3(torch.nn.Module):
    def forward(self, x):
        x = x + 1
        return x + id(x)

try:
    export(Bad3(), (torch.randn(3, 3),))
except Exception:
    tb.print_exc()

Non-Strict Export#

By default, torch.export uses TorchDynamo, a byte-code analysis engine, to symbolically analyze the Python code and build a graph based on the results. This analysis allows torch.export to provide stronger safety guarantees, but not all Python code is supported, causing these graph breaks.

To address this issue, in PyTorch 2.3, we introduced a new mode of exporting called non-strict mode, where we trace through the program using the Python interpreter, executing it exactly as it would in eager mode, allowing us to skip over unsupported Python features. This is done by adding a strict=False flag.

Looking at some of the previous examples which resulted in graph breaks:

  • Calling unsupported functions (such as many built-in functions) traces through,

but in this case, id(x) gets specialized as a constant integer. This is because id(x) is not a tensor operation, so the operation is not recorded in the graph.

class Bad3(torch.nn.Module):
    def forward(self, x):
        x = x + 1
        return x + id(x)

bad3_nonstrict = export(Bad3(), (torch.randn(3, 3),), strict=False)
print(bad3_nonstrict)
print(bad3_nonstrict.module()(torch.ones(3, 3)))
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[3, 3]"):
             # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:179 in forward, code: x = x + 1
            add: "f32[3, 3]" = torch.ops.aten.add.Tensor(x, 1);  x = None

             # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:180 in forward, code: return x + id(x)
            add_1: "f32[3, 3]" = torch.ops.aten.add.Tensor(add, 140226861055920);  add = None
            return (add_1,)

Graph signature:
    # inputs
    x: USER_INPUT

    # outputs
    add_1: USER_OUTPUT

Range constraints: {}

tensor([[1.4023e+14, 1.4023e+14, 1.4023e+14],
        [1.4023e+14, 1.4023e+14, 1.4023e+14],
        [1.4023e+14, 1.4023e+14, 1.4023e+14]])

However, there are still some features that require rewrites to the original module:

Control Flow Ops#

torch.export actually does support data-dependent control flow. But it needs to be expressed using control flow ops. For example, we can fix the control flow example above using the cond op, like so:

class Bad1Fixed(torch.nn.Module):
    def forward(self, x):
        def true_fn(x):
            return torch.sin(x)
        def false_fn(x):
            return torch.cos(x)
        return torch.cond(x.sum() > 0, true_fn, false_fn, [x])

exported_bad1_fixed = export(Bad1Fixed(), (torch.randn(3, 3),))
print(exported_bad1_fixed)
print(exported_bad1_fixed.module()(torch.ones(3, 3)))
print(exported_bad1_fixed.module()(-torch.ones(3, 3)))
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[3, 3]"):
             # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:205 in forward, code: return torch.cond(x.sum() > 0, true_fn, false_fn, [x])
            sum_1: "f32[]" = torch.ops.aten.sum.default(x)
            gt: "b8[]" = torch.ops.aten.gt.Scalar(sum_1, 0);  sum_1 = None

             # File: <eval_with_key>.35:9 in forward, code: cond = torch.ops.higher_order.cond(l_args_0_, cond_true_0, cond_false_0, (l_args_3_0_,));  l_args_0_ = cond_true_0 = cond_false_0 = l_args_3_0_ = None
            true_graph_0 = self.true_graph_0
            false_graph_0 = self.false_graph_0
            cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, (x,));  gt = true_graph_0 = false_graph_0 = x = None
            getitem: "f32[3, 3]" = cond[0];  cond = None
            return (getitem,)

        class true_graph_0(torch.nn.Module):
            def forward(self, x: "f32[3, 3]"):
                 # File: <eval_with_key>.32:6 in forward, code: sin = torch.sin(l_args_3_0__1);  l_args_3_0__1 = None
                sin: "f32[3, 3]" = torch.ops.aten.sin.default(x);  x = None
                return (sin,)

        class false_graph_0(torch.nn.Module):
            def forward(self, x: "f32[3, 3]"):
                 # File: <eval_with_key>.33:6 in forward, code: cos = torch.cos(l_args_3_0__1);  l_args_3_0__1 = None
                cos: "f32[3, 3]" = torch.ops.aten.cos.default(x);  x = None
                return (cos,)

Graph signature:
    # inputs
    x: USER_INPUT

    # outputs
    getitem: USER_OUTPUT

Range constraints: {}

tensor([[0.8415, 0.8415, 0.8415],
        [0.8415, 0.8415, 0.8415],
        [0.8415, 0.8415, 0.8415]])
tensor([[0.5403, 0.5403, 0.5403],
        [0.5403, 0.5403, 0.5403],
        [0.5403, 0.5403, 0.5403]])

There are some limitations of cond to note:

  • The predicate (i.e. x.sum() > 0) must result in a boolean or a single-element tensor.

  • The operands (i.e. [x]) must be tensors.

  • The branch function signatures (i.e. true_fn and false_fn) must match the operands, and they must both return a single tensor with the same metadata (for example, dtype, shape, etc.).

  • Branch functions cannot mutate inputs or globals.

  • Branch functions cannot access closure variables, except for self if the function is defined in the scope of a method.

For more details about cond, check out the cond documentation.

We can also use map, which applies a function across the first dimension of the first tensor argument.

from torch._higher_order_ops.map import map as torch_map

class MapModule(torch.nn.Module):
    def forward(self, xs, y, z):
        def body(x, y, z):
            return x + y + z

        return torch_map(body, xs, y, z)

inps = (torch.ones(6, 4), torch.tensor(5), torch.tensor(4))
exported_map_example = export(MapModule(), inps)
print(exported_map_example)
print(exported_map_example.module()(*inps))
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, xs: "f32[6, 4]", y: "i64[]", z: "i64[]"):
             # File: <eval_with_key>.64:9 in forward, code: map_impl = torch.ops.higher_order.map_impl(map_body_0, [l_flat_xs_0_], [l_flat_args_0_, l_flat_args_1_]);  map_body_0 = l_flat_xs_0_ = l_flat_args_0_ = l_flat_args_1_ = None
            body_graph_0 = self.body_graph_0
            map_impl = torch.ops.higher_order.map_impl(body_graph_0, [xs], [y, z]);  body_graph_0 = xs = y = z = None
            getitem: "f32[6, 4]" = map_impl[0];  map_impl = None
            return (getitem,)

        class body_graph_0(torch.nn.Module):
            def forward(self, xs: "f32[4]", y: "i64[]", z: "i64[]"):
                 # File: <eval_with_key>.62:5 in forward, code: add = child + l_flat_args_0_;  child = l_flat_args_0_ = None
                add: "f32[4]" = torch.ops.aten.add.Tensor(xs, y);  xs = y = None

                 # File: <eval_with_key>.62:6 in forward, code: add_1 = add + l_flat_args_1_;  add = l_flat_args_1_ = None
                add_1: "f32[4]" = torch.ops.aten.add.Tensor(add, z);  add = z = None
                return (add_1,)

Graph signature:
    # inputs
    xs: USER_INPUT
    y: USER_INPUT
    z: USER_INPUT

    # outputs
    getitem: USER_OUTPUT

Range constraints: {}

tensor([[10., 10., 10., 10.],
        [10., 10., 10., 10.],
        [10., 10., 10., 10.],
        [10., 10., 10., 10.],
        [10., 10., 10., 10.],
        [10., 10., 10., 10.]])

Other control flow ops include while_loop, associative_scan, and scan. For more documentation on each operator, please refer to this page.

Constraints/Dynamic Shapes#

This section covers the dynamic behavior and representation of exported programs. Dynamic behavior is subjective to the particular model being exported, so for the most part of this tutorial, we'll focus on this particular toy model (with the resulting tensor shapes annotated):

class DynamicModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.l = torch.nn.Linear(5, 3)

    def forward(
        self,
        w: torch.Tensor,  # [6, 5]
        x: torch.Tensor,  # [4]
        y: torch.Tensor,  # [8, 4]
        z: torch.Tensor,  # [32]
    ):
        x0 = x + y  # [8, 4]
        x1 = self.l(w)  # [6, 3]
        x2 = x0.flatten()  # [32]
        x3 = x2 + z  # [32]
        return x1, x3

By default, torch.export produces a static program. One consequence of this is that at runtime, the program won't work on inputs with different shapes, even if they're valid in eager mode.

w = torch.randn(6, 5)
x = torch.randn(4)
y = torch.randn(8, 4)
z = torch.randn(32)
model = DynamicModel()
ep = export(model, (w, x, y, z))
model(w, x, torch.randn(3, 4), torch.randn(12))
try:
    ep.module()(w, x, torch.randn(3, 4), torch.randn(12))
except Exception:
    tb.print_exc()
Traceback (most recent call last):
  File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 286, in <module>
    ep.module()(w, x, torch.randn(3, 4), torch.randn(12))
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/graph_module.py", line 837, in call_wrapped
    return self._wrapped_call(self, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/graph_module.py", line 413, in __call__
    raise e
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/graph_module.py", line 400, in __call__
    return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1881, in _call_impl
    return inner()
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1829, in inner
    result = forward_call(*args, **kwargs)
  File "<eval_with_key>.95", line 8, in forward
    _guards_fn = self._guards_fn(w, x, y, z);  _guards_fn = None
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/external_utils.py", line 209, in inner
    return func(*args, **kwargs)
  File "<string>", line 6, in _
  File "/usr/local/lib/python3.10/dist-packages/torch/__init__.py", line 2185, in _assert
    assert condition, message
AssertionError: Guard failed: y.size()[0] == 8

The Basics: Symbols and Guards#

To enable dynamism, export() provides a dynamic_shapes argument. The easiest way to work with dynamic shapes is using Dim.AUTO and looking at the program that's returned. Dynamic behavior is specified at the level of input dimensions; for each input we can specify a tuple of values:

from torch.export.dynamic_shapes import Dim

dynamic_shapes = {
    "w": (Dim.AUTO, Dim.AUTO),
    "x": (Dim.AUTO,),
    "y": (Dim.AUTO, Dim.AUTO),
    "z": (Dim.AUTO,),
}
ep = export(model, (w, x, y, z), dynamic_shapes=dynamic_shapes)

Before we look at the program that's produced, let's understand what specifying dynamic_shapes entails, and how it interacts with export. For every input dimension where a Dim object is specified, a symbol is allocated, taking on a range of [2, inf] (why not [0, inf] or [1, inf]? We'll explain later in the section on 0/1 specialization).

Export then performs model tracing, looking at each operation the model performs. Each individual operation can emit what are called "guards"; basically boolean conditions that are required to be true for the program to be valid. When guards involve the symbols allocated for input dimensions, the program contains restrictions on what input shapes are valid; i.e. the program's dynamic behavior. The symbolic shapes subsystem is the part responsible for taking in all the emitted guards and producing a final program representation that adheres to all of these guards. Before we see this "final representation" in an ExportedProgram, let's look at the guards emitted by the toy model we're tracing.

Here, each forward input tensor is annotated with the symbol allocated at the start of the trace:

class DynamicModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.l = torch.nn.Linear(5, 3)

    def forward(
        self,
        w: torch.Tensor,  # [s0, s1]
        x: torch.Tensor,  # [s2]
        y: torch.Tensor,  # [s3, s4]
        z: torch.Tensor,  # [s5]
    ):
        x0 = x + y  # guard: s2 == s4
        x1 = self.l(w)  # guard: s1 == 5
        x2 = x0.flatten()  # no guard added here
        x3 = x2 + z  # guard: s3 * s4 == s5
        return x1, x3

Let's understand each of the operations and the emitted guards:

  • x0 = x + y: This is an element-wise add with broadcasting, since x is a 1-d tensor and y a 2-d tensor. x is broadcast along the last dimension of y, emitting the guard s2 == s4.

  • x1 = self.l(w): Calling nn.Linear() performs a matrix multiplication with the model parameters. In export, parameters, buffers, and constants are considered program state, which is treated as static, and so this is a matmul between a dynamic input (w: [s0, s1]) and a statically-shaped tensor. This emits the guard s1 == 5.

  • x2 = x0.flatten(): This call actually doesn't emit any guards! (at least none relevant to input shapes)

  • x3 = x2 + z: x2 has shape [s3*s4] after flattening, and this element-wise add emits s3 * s4 == s5.

Writing all of these guards down and summarizing is almost like a mathematical proof, which is what the symbolic shapes subsystem tries to do! In summary, we can conclude that the program must take the following input shapes to be valid:

  • w: [s0, 5]

  • x: [s2]

  • y: [s3, s2]

  • z: [s2*s3]

These shapes are the annotations we see on the corresponding inputs when we finally print the exported program.

print(ep)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, p_l_weight: "f32[3, 5]", p_l_bias: "f32[3]", w: "f32[s15, 5]", x: "f32[s77]", y: "f32[s17, s77]", z: "f32[s17*s77]"):
             # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:268 in forward, code: x0 = x + y  # [8, 4]
            add: "f32[s17, s77]" = torch.ops.aten.add.Tensor(x, y);  x = y = None

             # File: /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
            linear: "f32[s15, 3]" = torch.ops.aten.linear.default(w, p_l_weight, p_l_bias);  w = p_l_weight = p_l_bias = None

             # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:270 in forward, code: x2 = x0.flatten()  # [32]
            flatten: "f32[s17*s77]" = torch.ops.aten.flatten.using_ints(add);  add = None

             # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:271 in forward, code: x3 = x2 + z  # [32]
            add_1: "f32[s17*s77]" = torch.ops.aten.add.Tensor(flatten, z);  flatten = z = None
            return (linear, add_1)

Graph signature:
    # inputs
    p_l_weight: PARAMETER target='l.weight'
    p_l_bias: PARAMETER target='l.bias'
    w: USER_INPUT
    x: USER_INPUT
    y: USER_INPUT
    z: USER_INPUT

    # outputs
    linear: USER_OUTPUT
    add_1: USER_OUTPUT

Range constraints: {s15: VR[2, int_oo], s77: VR[2, int_oo], s17: VR[2, int_oo], s17*s77: VR[4, int_oo]}

One additional feature to note is the range_constraints field above, which contains a valid range for each symbol. This isn't so interesting currently, since this export call doesn't emit any guards related to symbol bounds and each base symbol has a generic bound, but this will come up later.

So far, because we've been exporting this toy model, this experience has not been representative of how hard it typically is to debug dynamic shapes guards and issues. In most cases it isn't obvious what guards are being emitted, or which operations and parts of user code are responsible. For this toy model we pinpointed the exact lines, and the guards are rather intuitive.

In more complicated cases, a helpful first step is always to enable verbose logging. This can be done either with the environment variable TORCH_LOGS="+dynamic", or interactively with torch._logging.set_logs(dynamic=10):

torch._logging.set_logs(dynamic=10)
ep = export(model, (w, x, y, z), dynamic_shapes=dynamic_shapes)
I1015 19:14:15.352000 18276 torch/fx/experimental/symbolic_shapes.py:3769] create_env
I1015 19:14:15.354000 18276 torch/fx/experimental/symbolic_shapes.py:5114] create_symbol s15 = 6 for L['w'].size()[0] [2, int_oo] (_export/non_strict_utils.py:232 in fakify), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s15" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I1015 19:14:15.354000 18276 torch/fx/experimental/symbolic_shapes.py:5114] create_symbol s21 = 5 for L['w'].size()[1] [2, int_oo] (_export/non_strict_utils.py:232 in fakify), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s21" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I1015 19:14:15.356000 18276 torch/fx/experimental/symbolic_shapes.py:5114] create_symbol s77 = 4 for L['x'].size()[0] [2, int_oo] (_export/non_strict_utils.py:232 in fakify), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s77" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I1015 19:14:15.358000 18276 torch/fx/experimental/symbolic_shapes.py:5114] create_symbol s17 = 8 for L['y'].size()[0] [2, int_oo] (_export/non_strict_utils.py:232 in fakify), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s17" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I1015 19:14:15.358000 18276 torch/fx/experimental/symbolic_shapes.py:5114] create_symbol s94 = 4 for L['y'].size()[1] [2, int_oo] (_export/non_strict_utils.py:232 in fakify), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s94" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I1015 19:14:15.360000 18276 torch/fx/experimental/symbolic_shapes.py:5114] create_symbol s68 = 32 for L['z'].size()[0] [2, int_oo] (_export/non_strict_utils.py:232 in fakify), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s68" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
V1015 19:14:15.366000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == False [statically known]
V1015 19:14:15.367000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == False [statically known]
V1015 19:14:15.367000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == False [statically known]
V1015 19:14:15.369000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == False [statically known]
V1015 19:14:15.369000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == False [statically known]
V1015 19:14:15.371000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == False [statically known]
V1015 19:14:15.371000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == False [statically known]
V1015 19:14:15.372000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == False [statically known]
V1015 19:14:15.373000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == False [statically known]
V1015 19:14:15.374000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == False [statically known]
I1015 19:14:15.376000 18276 torch/fx/experimental/symbolic_shapes.py:7207] eval Eq(s77, s94) [guard added] (_subclasses/fake_impls.py:1148 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s77, s94)"
I1015 19:14:15.377000 18276 torch/fx/experimental/symbolic_shapes.py:6786] set_replacement s94 = s77 (solve) VR[2, int_oo]
V1015 19:14:15.379000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == False [statically known]
I1015 19:14:15.385000 18276 torch/fx/experimental/symbolic_shapes.py:7207] eval Eq(s21, 5) [guard added] (_meta_registrations.py:2248 in meta_mm), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s21, 5)"
V1015 19:14:15.386000 18276 torch/fx/experimental/symbolic_shapes.py:6619] _update_var_to_range s21 = VR[5, 5] (update)
I1015 19:14:15.386000 18276 torch/fx/experimental/symbolic_shapes.py:6786] set_replacement s21 = 5 (range_refined_to_singleton) VR[5, 5]
V1015 19:14:15.398000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == False [statically known]
V1015 19:14:15.401000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == False [statically known]
I1015 19:14:15.403000 18276 torch/fx/experimental/symbolic_shapes.py:7207] eval Eq(s17*s77, s68) [guard added] (_subclasses/fake_impls.py:1148 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s17*s77, s68)"
V1015 19:14:15.405000 18276 torch/fx/experimental/symbolic_shapes.py:6619] _update_var_to_range s68 = VR[4, int_oo] (update)
I1015 19:14:15.405000 18276 torch/fx/experimental/symbolic_shapes.py:6786] set_replacement s68 = s17*s77 (solve) VR[4, int_oo]
I1015 19:14:15.410000 18276 torch/fx/experimental/symbolic_shapes.py:5242] produce_guards
V1015 19:14:15.411000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['w'].size()[0] s15 None
V1015 19:14:15.411000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['w'].size()[1] 5 None
V1015 19:14:15.412000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['w'].stride()[0] 5 None
V1015 19:14:15.412000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['w'].stride()[1] 1 None
V1015 19:14:15.412000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['w'].storage_offset() 0 None
V1015 19:14:15.412000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['x'].size()[0] s77 None
V1015 19:14:15.413000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['x'].stride()[0] 1 None
V1015 19:14:15.413000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['x'].storage_offset() 0 None
V1015 19:14:15.413000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['y'].size()[0] s17 None
V1015 19:14:15.413000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['y'].size()[1] s77 None
V1015 19:14:15.414000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['y'].stride()[0] s77 None
V1015 19:14:15.414000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['y'].stride()[1] 1 None
V1015 19:14:15.414000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['y'].storage_offset() 0 None
V1015 19:14:15.415000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['z'].size()[0] s17*s77 None
V1015 19:14:15.415000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['z'].stride()[0] 1 None
V1015 19:14:15.415000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['z'].storage_offset() 0 None
V1015 19:14:15.428000 18276 torch/fx/experimental/symbolic_shapes.py:7471] eval 5 [trivial]

Even for this simple toy model, this produces a large amount of information. The log lines here have been truncated at the beginning and end to omit unnecessary detail, but looking through the logs we can find content relevant to what was described above; for example, the allocation of symbols:

"""
create_symbol s0 = 6 for L['w'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2841 in <lambda>)
create_symbol s1 = 5 for L['w'].size()[1] [2, int_oo] (_dynamo/variables/builder.py:2841 in <lambda>)
runtime_assert True == True [statically known]
create_symbol s2 = 4 for L['x'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2841 in <lambda>)
create_symbol s3 = 8 for L['y'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2841 in <lambda>)
create_symbol s4 = 4 for L['y'].size()[1] [2, int_oo] (_dynamo/variables/builder.py:2841 in <lambda>)
create_symbol s5 = 32 for L['z'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2841 in <lambda>)
"""

The `create_symbol` lines show when a new symbol is allocated, and the logs also identify the tensor variable name and the dimension it was allocated for. On other lines, we can also see the guards that were emitted:
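To make the allocation pattern concrete, here is a toy sketch (this is not the real `ShapeEnv`, and `allocate_symbols` is a hypothetical helper, not a torch API) of how a fresh symbol is handed out for each input dimension, matching the `s0`-`s5` assignments in the `create_symbol` lines above:

```python
# Toy sketch (NOT the real ShapeEnv): hand out a fresh symbol for every
# input dimension, mirroring the create_symbol log lines above.
def allocate_symbols(inputs):
    symbols, counter = {}, 0
    for name, shape in inputs.items():
        for dim, size in enumerate(shape):
            # record the symbol name and the concrete size it was created from
            symbols[f"{name}.size()[{dim}]"] = (f"s{counter}", size)
            counter += 1
    return symbols

# shapes of w, x, y, z from the example above
syms = allocate_symbols({"w": (6, 5), "x": (4,), "y": (8, 4), "z": (32,)})
print(syms["w.size()[0]"])  # ('s0', 6)
print(syms["z.size()[0]"])  # ('s5', 32)
```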

"""
runtime_assert Eq(s2, s4) [guard added] x0 = x + y  # output shape: [8, 4]  # dynamic_shapes_tutorial.py:16 in forward (_subclasses/fake_impls.py:845 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s2, s4)"
runtime_assert Eq(s1, 5) [guard added] x1 = self.l(w)  # [6, 3]  # dynamic_shapes_tutorial.py:17 in forward (_meta_registrations.py:2127 in meta_mm), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s1, 5)"
runtime_assert Eq(s2*s3, s5) [guard added] x3 = x2 + z  # [32]  # dynamic_shapes_tutorial.py:19 in forward (_subclasses/fake_impls.py:845 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s2*s3, s5)"
"""

[guard added] 訊息旁邊,我們還看到了負責的使用者程式碼行 - 幸運的是,這裡的模型足夠簡單。在許多實際情況下,情況並非如此直接:高階 torch 操作可能具有複雜的偽核心實現或運算元分解,這會使守衛的發出位置和內容變得複雜。在這種情況下,深入研究和調查的最佳方法是遵循日誌的建議,並使用環境變數 TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="..." 重新執行,以進一步歸因感興趣的守衛。

`Dim.AUTO` is only one of the available options for interacting with `dynamic_shapes`; as of writing, there are two other options: `Dim.DYNAMIC` and `Dim.STATIC`. `Dim.STATIC` simply marks a dimension as static, while `Dim.DYNAMIC` is similar to `Dim.AUTO` in all respects except one: it raises an error when the dimension specializes to a constant; this is designed to maintain dynamism. See, for example, what happens when a static guard is emitted on a dynamically-marked dimension:

dynamic_shapes["w"] = (Dim.AUTO, Dim.DYNAMIC)
try:
    export(model, (w, x, y, z), dynamic_shapes=dynamic_shapes)
except Exception:
    tb.print_exc()
I1015 19:14:15.432000 18276 torch/fx/experimental/symbolic_shapes.py:3769] create_env
I1015 19:14:15.434000 18276 torch/fx/experimental/symbolic_shapes.py:5114] create_symbol s15 = 6 for L['w'].size()[0] [2, int_oo] (_export/non_strict_utils.py:232 in fakify), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s15" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I1015 19:14:15.434000 18276 torch/fx/experimental/symbolic_shapes.py:5114] create_symbol s21 = 5 for L['w'].size()[1] [2, int_oo] (_export/non_strict_utils.py:232 in fakify), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s21" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I1015 19:14:15.436000 18276 torch/fx/experimental/symbolic_shapes.py:5114] create_symbol s77 = 4 for L['x'].size()[0] [2, int_oo] (_export/non_strict_utils.py:232 in fakify), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s77" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I1015 19:14:15.438000 18276 torch/fx/experimental/symbolic_shapes.py:5114] create_symbol s17 = 8 for L['y'].size()[0] [2, int_oo] (_export/non_strict_utils.py:232 in fakify), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s17" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I1015 19:14:15.438000 18276 torch/fx/experimental/symbolic_shapes.py:5114] create_symbol s94 = 4 for L['y'].size()[1] [2, int_oo] (_export/non_strict_utils.py:232 in fakify), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s94" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I1015 19:14:15.440000 18276 torch/fx/experimental/symbolic_shapes.py:5114] create_symbol s68 = 32 for L['z'].size()[0] [2, int_oo] (_export/non_strict_utils.py:232 in fakify), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s68" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
V1015 19:14:15.446000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == False [statically known]
V1015 19:14:15.447000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == False [statically known]
V1015 19:14:15.448000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == False [statically known]
V1015 19:14:15.449000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == False [statically known]
V1015 19:14:15.449000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == False [statically known]
V1015 19:14:15.451000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == False [statically known]
V1015 19:14:15.451000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == False [statically known]
V1015 19:14:15.452000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == False [statically known]
V1015 19:14:15.453000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == False [statically known]
V1015 19:14:15.454000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == False [statically known]
I1015 19:14:15.456000 18276 torch/fx/experimental/symbolic_shapes.py:7207] eval Eq(s77, s94) [guard added] (_subclasses/fake_impls.py:1148 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s77, s94)"
I1015 19:14:15.457000 18276 torch/fx/experimental/symbolic_shapes.py:6786] set_replacement s94 = s77 (solve) VR[2, int_oo]
V1015 19:14:15.459000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == False [statically known]
I1015 19:14:15.465000 18276 torch/fx/experimental/symbolic_shapes.py:7207] eval Eq(s21, 5) [guard added] (_meta_registrations.py:2248 in meta_mm), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s21, 5)"
V1015 19:14:15.466000 18276 torch/fx/experimental/symbolic_shapes.py:6619] _update_var_to_range s21 = VR[5, 5] (update)
I1015 19:14:15.466000 18276 torch/fx/experimental/symbolic_shapes.py:6786] set_replacement s21 = 5 (range_refined_to_singleton) VR[5, 5]
V1015 19:14:15.478000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == False [statically known]
V1015 19:14:15.481000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == False [statically known]
I1015 19:14:15.483000 18276 torch/fx/experimental/symbolic_shapes.py:7207] eval Eq(s17*s77, s68) [guard added] (_subclasses/fake_impls.py:1148 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s17*s77, s68)"
V1015 19:14:15.485000 18276 torch/fx/experimental/symbolic_shapes.py:6619] _update_var_to_range s68 = VR[4, int_oo] (update)
I1015 19:14:15.485000 18276 torch/fx/experimental/symbolic_shapes.py:6786] set_replacement s68 = s17*s77 (solve) VR[4, int_oo]
I1015 19:14:15.490000 18276 torch/fx/experimental/symbolic_shapes.py:5242] produce_guards
V1015 19:14:15.491000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['w'].size()[0] s15 None
V1015 19:14:15.491000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['w'].size()[1] 5 RelaxedUnspecConstraint(warn_only=False)
V1015 19:14:15.491000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['w'].stride()[0] 5 None
V1015 19:14:15.492000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['w'].stride()[1] 1 None
V1015 19:14:15.492000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['w'].storage_offset() 0 None
V1015 19:14:15.492000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['x'].size()[0] s77 None
V1015 19:14:15.493000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['x'].stride()[0] 1 None
V1015 19:14:15.493000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['x'].storage_offset() 0 None
V1015 19:14:15.493000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['y'].size()[0] s17 None
V1015 19:14:15.493000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['y'].size()[1] s77 None
V1015 19:14:15.494000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['y'].stride()[0] s77 None
V1015 19:14:15.494000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['y'].stride()[1] 1 None
V1015 19:14:15.494000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['y'].storage_offset() 0 None
V1015 19:14:15.494000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['z'].size()[0] s17*s77 None
V1015 19:14:15.495000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['z'].stride()[0] 1 None
V1015 19:14:15.495000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['z'].storage_offset() 0 None
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1821, in _export_to_aten_ir_make_fx
    produce_guards_callback(gm)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1968, in _produce_guards_callback
    return produce_guards_and_solve_constraints(
  File "/usr/local/lib/python3.10/dist-packages/torch/_export/non_strict_utils.py", line 533, in produce_guards_and_solve_constraints
    raise constraint_violation_error
  File "/usr/local/lib/python3.10/dist-packages/torch/_export/non_strict_utils.py", line 500, in produce_guards_and_solve_constraints
    shape_env.produce_guards(
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 5204, in produce_guards
    return self.produce_guards_verbose(*args, **kwargs, langs=("python",))[0].exprs
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 5928, in produce_guards_verbose
    raise ConstraintViolationError(
torch.fx.experimental.symbolic_shapes.ConstraintViolationError: Constraints violated (L['w'].size()[1])! For more information, run with TORCH_LOGS="+dynamic".
  - You marked L['w'].size()[1] as dynamic but your code specialized it to be a constant (5). If you're using mark_dynamic, either remove it or use maybe_mark_dynamic. If you're using Dim.DYNAMIC, replace it with either Dim.STATIC or Dim.AUTO.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 418, in <module>
    export(model, (w, x, y, z), dynamic_shapes=dynamic_shapes)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/__init__.py", line 311, in export
    raise e
  File "/usr/local/lib/python3.10/dist-packages/torch/export/__init__.py", line 277, in export
    return _export(
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1163, in wrapper
    raise e
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1129, in wrapper
    ep = fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/exported_program.py", line 124, in wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 2255, in _export
    ep = _export_for_training(
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1163, in wrapper
    raise e
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1129, in wrapper
    ep = fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/exported_program.py", line 124, in wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 2071, in _export_for_training
    export_artifact = export_func(
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 2002, in _non_strict_export
    aten_export_artifact = _to_aten_func(  # type: ignore[operator]
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1823, in _export_to_aten_ir_make_fx
    raise UserError(UserErrorType.CONSTRAINT_VIOLATION, str(e))  # noqa: B904
torch._dynamo.exc.UserError: Constraints violated (L['w'].size()[1])! For more information, run with TORCH_LOGS="+dynamic".
  - You marked L['w'].size()[1] as dynamic but your code specialized it to be a constant (5). If you're using mark_dynamic, either remove it or use maybe_mark_dynamic. If you're using Dim.DYNAMIC, replace it with either Dim.STATIC or Dim.AUTO.

The error above occurred when calling torch.export.export. If you would like to view some more information about this error, and get a list of all other errors that may occur in your export call, you can replace your `export()` call with `draft_export()`.

Static guards are also not always inherent to the model; they can come from user specifications as well. In fact, a common pitfall leading to shape specialization is when the user specifies conflicting markings for equivalent dimensions: one dynamic and the other static. The same error type is raised when this happens for `x.shape[0]` and `y.shape[1]`:

dynamic_shapes["w"] = (Dim.AUTO, Dim.AUTO)
dynamic_shapes["x"] = (Dim.STATIC,)
dynamic_shapes["y"] = (Dim.AUTO, Dim.DYNAMIC)
try:
    export(model, (w, x, y, z), dynamic_shapes=dynamic_shapes)
except Exception:
    tb.print_exc()
I1015 19:14:15.507000 18276 torch/fx/experimental/symbolic_shapes.py:3769] create_env
I1015 19:14:15.508000 18276 torch/fx/experimental/symbolic_shapes.py:5114] create_symbol s15 = 6 for L['w'].size()[0] [2, int_oo] (_export/non_strict_utils.py:232 in fakify), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s15" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I1015 19:14:15.508000 18276 torch/fx/experimental/symbolic_shapes.py:5114] create_symbol s21 = 5 for L['w'].size()[1] [2, int_oo] (_export/non_strict_utils.py:232 in fakify), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s21" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I1015 19:14:15.511000 18276 torch/fx/experimental/symbolic_shapes.py:5114] create_symbol s17 = 8 for L['y'].size()[0] [2, int_oo] (_export/non_strict_utils.py:232 in fakify), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s17" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I1015 19:14:15.512000 18276 torch/fx/experimental/symbolic_shapes.py:5114] create_symbol s94 = 4 for L['y'].size()[1] [2, int_oo] (_export/non_strict_utils.py:232 in fakify), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s94" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I1015 19:14:15.514000 18276 torch/fx/experimental/symbolic_shapes.py:5114] create_symbol s68 = 32 for L['z'].size()[0] [2, int_oo] (_export/non_strict_utils.py:232 in fakify), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s68" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
V1015 19:14:15.520000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == False [statically known]
V1015 19:14:15.520000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == False [statically known]
V1015 19:14:15.521000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == False [statically known]
V1015 19:14:15.523000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == False [statically known]
V1015 19:14:15.523000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == False [statically known]
V1015 19:14:15.524000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == False [statically known]
V1015 19:14:15.525000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == False [statically known]
V1015 19:14:15.526000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == False [statically known]
I1015 19:14:15.531000 18276 torch/fx/experimental/symbolic_shapes.py:7207] eval Eq(s94, 4) [guard added] (_subclasses/fake_impls.py:1148 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s94, 4)"
V1015 19:14:15.531000 18276 torch/fx/experimental/symbolic_shapes.py:6619] _update_var_to_range s94 = VR[4, 4] (update)
I1015 19:14:15.532000 18276 torch/fx/experimental/symbolic_shapes.py:6786] set_replacement s94 = 4 (range_refined_to_singleton) VR[4, 4]
I1015 19:14:15.539000 18276 torch/fx/experimental/symbolic_shapes.py:7207] eval Eq(s21, 5) [guard added] (_meta_registrations.py:2248 in meta_mm), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s21, 5)"
V1015 19:14:15.540000 18276 torch/fx/experimental/symbolic_shapes.py:6619] _update_var_to_range s21 = VR[5, 5] (update)
I1015 19:14:15.541000 18276 torch/fx/experimental/symbolic_shapes.py:6786] set_replacement s21 = 5 (range_refined_to_singleton) VR[5, 5]
I1015 19:14:15.560000 18276 torch/fx/experimental/symbolic_shapes.py:7207] eval Eq(4*s17, s68) [guard added] (_subclasses/fake_impls.py:1148 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(4*s17, s68)"
V1015 19:14:15.564000 18276 torch/fx/experimental/symbolic_shapes.py:6619] _update_var_to_range s68 = VR[8, int_oo] (update)
I1015 19:14:15.565000 18276 torch/fx/experimental/symbolic_shapes.py:6786] set_replacement s68 = 4*s17 (solve) VR[8, int_oo]
I1015 19:14:15.570000 18276 torch/fx/experimental/symbolic_shapes.py:5242] produce_guards
V1015 19:14:15.570000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['w'].size()[0] s15 None
V1015 19:14:15.571000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['w'].size()[1] 5 None
V1015 19:14:15.571000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['w'].stride()[0] 5 None
V1015 19:14:15.571000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['w'].stride()[1] 1 None
V1015 19:14:15.572000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['w'].storage_offset() 0 None
V1015 19:14:15.572000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['x'].size()[0] 4 None
V1015 19:14:15.572000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['x'].stride()[0] 1 None
V1015 19:14:15.573000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['x'].storage_offset() 0 None
V1015 19:14:15.573000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['y'].size()[0] s17 None
V1015 19:14:15.573000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['y'].size()[1] 4 RelaxedUnspecConstraint(warn_only=False)
V1015 19:14:15.574000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['y'].stride()[0] 4 None
V1015 19:14:15.574000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['y'].stride()[1] 1 None
V1015 19:14:15.574000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['y'].storage_offset() 0 None
V1015 19:14:15.574000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['z'].size()[0] 4*s17 None
V1015 19:14:15.575000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['z'].stride()[0] 1 None
V1015 19:14:15.575000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['z'].storage_offset() 0 None
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1821, in _export_to_aten_ir_make_fx
    produce_guards_callback(gm)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1968, in _produce_guards_callback
    return produce_guards_and_solve_constraints(
  File "/usr/local/lib/python3.10/dist-packages/torch/_export/non_strict_utils.py", line 533, in produce_guards_and_solve_constraints
    raise constraint_violation_error
  File "/usr/local/lib/python3.10/dist-packages/torch/_export/non_strict_utils.py", line 500, in produce_guards_and_solve_constraints
    shape_env.produce_guards(
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 5204, in produce_guards
    return self.produce_guards_verbose(*args, **kwargs, langs=("python",))[0].exprs
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 5928, in produce_guards_verbose
    raise ConstraintViolationError(
torch.fx.experimental.symbolic_shapes.ConstraintViolationError: Constraints violated (L['y'].size()[1])! For more information, run with TORCH_LOGS="+dynamic".
  - You marked L['y'].size()[1] as dynamic but your code specialized it to be a constant (4). If you're using mark_dynamic, either remove it or use maybe_mark_dynamic. If you're using Dim.DYNAMIC, replace it with either Dim.STATIC or Dim.AUTO.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 431, in <module>
    export(model, (w, x, y, z), dynamic_shapes=dynamic_shapes)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/__init__.py", line 311, in export
    raise e
  File "/usr/local/lib/python3.10/dist-packages/torch/export/__init__.py", line 277, in export
    return _export(
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1163, in wrapper
    raise e
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1129, in wrapper
    ep = fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/exported_program.py", line 124, in wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 2255, in _export
    ep = _export_for_training(
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1163, in wrapper
    raise e
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1129, in wrapper
    ep = fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/exported_program.py", line 124, in wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 2071, in _export_for_training
    export_artifact = export_func(
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 2002, in _non_strict_export
    aten_export_artifact = _to_aten_func(  # type: ignore[operator]
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1823, in _export_to_aten_ir_make_fx
    raise UserError(UserErrorType.CONSTRAINT_VIOLATION, str(e))  # noqa: B904
torch._dynamo.exc.UserError: Constraints violated (L['y'].size()[1])! For more information, run with TORCH_LOGS="+dynamic".
  - You marked L['y'].size()[1] as dynamic but your code specialized it to be a constant (4). If you're using mark_dynamic, either remove it or use maybe_mark_dynamic. If you're using Dim.DYNAMIC, replace it with either Dim.STATIC or Dim.AUTO.

The error above occurred when calling torch.export.export. If you would like to view some more information about this error, and get a list of all other errors that may occur in your export call, you can replace your `export()` call with `draft_export()`.

Here, you might ask why export "specializes", i.e. why we resolve this static/dynamic conflict by going with the static route. The answer is the symbolic shapes system described above, of symbols and guards. When `x.shape[0]` is marked static, we don't allocate a symbol, and instead compile this shape as the concrete integer 4. A symbol is allocated for `y.shape[1]`, and so we end up emitting the guard `s3 == 4`, leading to specialization.
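As a toy illustration of this mechanism (this is not PyTorch's actual `infer_size`; the helper name is hypothetical), consider the shape check of an elementwise op: when a symbolic size meets a concrete one, an equality guard is emitted and the symbol specializes to the constant:

```python
# Toy sketch of an elementwise op's shape check: each dim is either an
# int (static) or a str symbol (dynamic).
def infer_size_guarded(lhs, rhs, guards):
    out = []
    for a, b in zip(lhs, rhs):
        if a == b:
            out.append(a)
        else:
            # sizes must match; record the equality guard, e.g. Eq(s3, 4),
            # and prefer the concrete size for the output shape
            guards.append(f"Eq({a}, {b})")
            out.append(b if isinstance(b, int) else a)
    return out

guards = []
# y.shape[1] is the symbol s3; x.shape[0] is the concrete 4
print(infer_size_guarded(["s3"], [4], guards))  # [4]
print(guards)  # ['Eq(s3, 4)']
```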

One feature of export is that during tracing, statements such as asserts, `torch._check()`, and `if/else` conditions will also emit guards. See what happens when we augment the existing model with such statements:

class DynamicModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.l = torch.nn.Linear(5, 3)

    def forward(self, w, x, y, z):
        assert w.shape[0] <= 512
        torch._check(x.shape[0] >= 4)
        if w.shape[0] == x.shape[0] + 2:
            x0 = x + y
            x1 = self.l(w)
            x2 = x0.flatten()
            x3 = x2 + z
            return x1, x3
        else:
            return w

dynamic_shapes = {
    "w": (Dim.AUTO, Dim.AUTO),
    "x": (Dim.AUTO,),
    "y": (Dim.AUTO, Dim.AUTO),
    "z": (Dim.AUTO,),
}
try:
    ep = export(DynamicModel(), (w, x, y, z), dynamic_shapes=dynamic_shapes)
except Exception:
    tb.print_exc()
I1015 19:14:15.585000 18276 torch/fx/experimental/symbolic_shapes.py:3769] create_env
I1015 19:14:15.587000 18276 torch/fx/experimental/symbolic_shapes.py:5114] create_symbol s15 = 6 for L['w'].size()[0] [2, int_oo] (_export/non_strict_utils.py:232 in fakify), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s15" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I1015 19:14:15.587000 18276 torch/fx/experimental/symbolic_shapes.py:5114] create_symbol s21 = 5 for L['w'].size()[1] [2, int_oo] (_export/non_strict_utils.py:232 in fakify), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s21" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I1015 19:14:15.589000 18276 torch/fx/experimental/symbolic_shapes.py:5114] create_symbol s77 = 4 for L['x'].size()[0] [2, int_oo] (_export/non_strict_utils.py:232 in fakify), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s77" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I1015 19:14:15.591000 18276 torch/fx/experimental/symbolic_shapes.py:5114] create_symbol s17 = 8 for L['y'].size()[0] [2, int_oo] (_export/non_strict_utils.py:232 in fakify), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s17" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I1015 19:14:15.591000 18276 torch/fx/experimental/symbolic_shapes.py:5114] create_symbol s94 = 4 for L['y'].size()[1] [2, int_oo] (_export/non_strict_utils.py:232 in fakify), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s94" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I1015 19:14:15.593000 18276 torch/fx/experimental/symbolic_shapes.py:5114] create_symbol s68 = 32 for L['z'].size()[0] [2, int_oo] (_export/non_strict_utils.py:232 in fakify), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s68" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
V1015 19:14:15.599000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == False [statically known]
V1015 19:14:15.600000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == False [statically known]
V1015 19:14:15.601000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == False [statically known]
V1015 19:14:15.602000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == False [statically known]
V1015 19:14:15.603000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == False [statically known]
V1015 19:14:15.604000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == False [statically known]
V1015 19:14:15.604000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == False [statically known]
V1015 19:14:15.605000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == False [statically known]
V1015 19:14:15.606000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == False [statically known]
V1015 19:14:15.607000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == False [statically known]
I1015 19:14:15.613000 18276 torch/fx/experimental/symbolic_shapes.py:7207] eval s15 <= 512 [guard added] (ar/lib/workspace/intermediate_source/torch_export_tutorial.py:450 in forward), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="s15 <= 512"
V1015 19:14:15.613000 18276 torch/fx/experimental/symbolic_shapes.py:6619] _update_var_to_range s15 = VR[2, 512] (update)
I1015 19:14:15.617000 18276 torch/fx/experimental/symbolic_shapes.py:7207] eval s77 >= 4 [guard added] (ar/lib/workspace/intermediate_source/torch_export_tutorial.py:451 in forward), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="s77 >= 4"
V1015 19:14:15.617000 18276 torch/fx/experimental/symbolic_shapes.py:6619] _update_var_to_range s77 = VR[4, int_oo] (update)
I1015 19:14:15.622000 18276 torch/fx/experimental/symbolic_shapes.py:7207] eval Eq(s15, s77 + 2) [guard added] (ar/lib/workspace/intermediate_source/torch_export_tutorial.py:452 in forward), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s15, s77 + 2)"
V1015 19:14:15.624000 18276 torch/fx/experimental/symbolic_shapes.py:6619] _update_var_to_range s77 = VR[4, 510] (update)
V1015 19:14:15.625000 18276 torch/fx/experimental/symbolic_shapes.py:6619] _update_var_to_range s15 = VR[6, 512] (update)
I1015 19:14:15.626000 18276 torch/fx/experimental/symbolic_shapes.py:6786] set_replacement s15 = s77 + 2 (solve) VR[6, 512]
I1015 19:14:15.630000 18276 torch/fx/experimental/symbolic_shapes.py:7207] eval Eq(s77, s94) [guard added] (_subclasses/fake_impls.py:1148 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s77, s94)"
V1015 19:14:15.631000 18276 torch/fx/experimental/symbolic_shapes.py:6619] _update_var_to_range s94 = VR[4, 510] (update)
I1015 19:14:15.631000 18276 torch/fx/experimental/symbolic_shapes.py:6786] set_replacement s94 = s77 (solve) VR[4, 510]
V1015 19:14:15.635000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == False [statically known]
I1015 19:14:15.642000 18276 torch/fx/experimental/symbolic_shapes.py:7207] eval Eq(s21, 5) [guard added] (_meta_registrations.py:2248 in meta_mm), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s21, 5)"
V1015 19:14:15.643000 18276 torch/fx/experimental/symbolic_shapes.py:6619] _update_var_to_range s21 = VR[5, 5] (update)
I1015 19:14:15.644000 18276 torch/fx/experimental/symbolic_shapes.py:6786] set_replacement s21 = 5 (range_refined_to_singleton) VR[5, 5]
V1015 19:14:15.658000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == False [statically known]
V1015 19:14:15.661000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == False [statically known]
I1015 19:14:15.669000 18276 torch/fx/experimental/symbolic_shapes.py:7207] eval Eq(s17*s77, s68) [guard added] (_subclasses/fake_impls.py:1148 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s17*s77, s68)"
V1015 19:14:15.670000 18276 torch/fx/experimental/symbolic_shapes.py:6619] _update_var_to_range s68 = VR[8, int_oo] (update)
I1015 19:14:15.671000 18276 torch/fx/experimental/symbolic_shapes.py:6786] set_replacement s68 = s17*s77 (solve) VR[8, int_oo]
I1015 19:14:15.676000 18276 torch/fx/experimental/symbolic_shapes.py:5242] produce_guards
V1015 19:14:15.677000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['w'].size()[0] s77 + 2 None
V1015 19:14:15.677000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['w'].size()[1] 5 None
V1015 19:14:15.678000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['w'].stride()[0] 5 None
V1015 19:14:15.678000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['w'].stride()[1] 1 None
V1015 19:14:15.678000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['w'].storage_offset() 0 None
V1015 19:14:15.679000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['x'].size()[0] s77 None
V1015 19:14:15.679000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['x'].stride()[0] 1 None
V1015 19:14:15.679000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['x'].storage_offset() 0 None
V1015 19:14:15.680000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['y'].size()[0] s17 None
V1015 19:14:15.680000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['y'].size()[1] s77 None
V1015 19:14:15.680000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['y'].stride()[0] s77 None
V1015 19:14:15.681000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['y'].stride()[1] 1 None
V1015 19:14:15.681000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['y'].storage_offset() 0 None
V1015 19:14:15.681000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['z'].size()[0] s17*s77 None
V1015 19:14:15.682000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['z'].stride()[0] 1 None
V1015 19:14:15.682000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['z'].storage_offset() 0 None
V1015 19:14:15.698000 18276 torch/fx/experimental/symbolic_shapes.py:7471] eval 5 [trivial]

Each of these statements emits an additional guard, and the exported program shows the changes; s15 is replaced by s77 + 2, and s77 now carries lower and upper bounds, reflected in range_constraints.

For the if/else condition, you might ask why the True branch was taken, and why it wasn't the w.shape[0] != x.shape[0] + 2 guard that got emitted from tracing. The answer is that export is guided by the sample inputs provided to tracing, and specializes on the branch taken. If different sample input shapes were provided that failed the if condition, export would trace and emit guards corresponding to the else branch. Additionally, you might ask why we traced only the if branch, and whether it's possible to maintain control flow in your program and keep both branches alive. For that, refer to rewriting your model code following the Control Flow Ops section above.

0/1 specialization#

Since we're talking about guards and specializations, it's a good time to talk about the 0/1 specialization issue we brought up earlier. The bottom line is that export will specialize on sample input dimensions with value 0 or 1, because these shapes have trace-time properties that don't generalize to other shapes. For example, size-1 tensors can broadcast while other sizes fail; and size-0 tensors ... . This just means that you should specify 0/1 sample inputs when you'd like your program to hardcode them, and non-0/1 sample inputs when dynamic behavior is desired. See what happens at runtime when we export this linear layer:

ep = export(
    torch.nn.Linear(4, 3),
    (torch.randn(1, 4),),
    dynamic_shapes={
        "input": (Dim.AUTO, Dim.STATIC),
    },
)
try:
    ep.module()(torch.randn(2, 4))
except Exception:
    tb.print_exc()
I1015 19:14:15.703000 18276 torch/fx/experimental/symbolic_shapes.py:3769] create_env
I1015 19:14:15.715000 18276 torch/fx/experimental/symbolic_shapes.py:5242] produce_guards
V1015 19:14:15.716000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['input'].size()[0] 1 None
V1015 19:14:15.716000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['input'].size()[1] 4 None
V1015 19:14:15.716000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['input'].stride()[0] 4 None
V1015 19:14:15.716000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['input'].stride()[1] 1 None
V1015 19:14:15.717000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['input'].storage_offset() 0 None
W1015 19:14:15.719000 18276 torch/_export/non_strict_utils.py:564] dimension inputs['input'].shape[0] 0/1 specialized; Dim.AUTO was specified along with a sample input with hint = 1.
Traceback (most recent call last):
  File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 500, in <module>
    ep.module()(torch.randn(2, 4))
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/graph_module.py", line 837, in call_wrapped
    return self._wrapped_call(self, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/graph_module.py", line 413, in __call__
    raise e
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/graph_module.py", line 400, in __call__
    return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1881, in _call_impl
    return inner()
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1829, in inner
    result = forward_call(*args, **kwargs)
  File "<eval_with_key>.125", line 9, in forward
    _guards_fn = self._guards_fn(input_1);  _guards_fn = None
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/external_utils.py", line 209, in inner
    return func(*args, **kwargs)
  File "<string>", line 3, in _
  File "/usr/local/lib/python3.10/dist-packages/torch/__init__.py", line 2185, in _assert
    assert condition, message
AssertionError: Guard failed: input.size()[0] == 1

Named Dims#

So far we've only been talking about 3 ways to specify dynamic shapes: Dim.AUTO, Dim.DYNAMIC, and Dim.STATIC. The attraction of these is the low-friction user experience; all the guards emitted during model tracing are adhered to, and dynamic behavior like min/max ranges, relations, and static/dynamic dimensions is automatically figured out underneath export. The dynamic shapes subsystem essentially acts as a "discovery" process, summarizing these guards and presenting what export believes is the overall dynamic behavior of the program. The drawback of this design appears once the user has stronger expectations or beliefs about the dynamic behavior of these models - maybe there is a strong desire for dynamism and specializations on particular dimensions are to be avoided, or maybe we just want to catch changes in dynamic behavior caused by changes to the original model code, or to underlying decompositions or meta-kernels. Such changes won't be detected, and the export() call will most likely succeed, unless tests are in place that check the resulting ExportedProgram representation.

For such cases, our stance is to recommend the "traditional" way of specifying dynamic shapes, which longtime users of export might be familiar with: named Dims:

dx = Dim("dx", min=4, max=256)
dh = Dim("dh", max=512)
dynamic_shapes = {
    "x": (dx, None),
    "y": (2 * dx, dh),
}

This style of dynamic shapes allows the user to specify what symbols are allocated for input dimensions, the min/max bounds on those symbols, and places restrictions on the dynamic behavior of the resulting ExportedProgram; ConstraintViolation errors will be raised if model tracing emits guards that conflict with the relations or static/dynamic specifications given. For example, in the above specification, the following is asserted:

  • x.shape[0] has range [4, 256], and is related to y.shape[0] via y.shape[0] == 2 * x.shape[0].

  • x.shape[1] is static.

  • y.shape[1] has range [2, 512], and is unrelated to any other dimension.

In this design, we allow relations between dimensions to be specified with univariate linear expressions: A * dim + B can be specified for any dimension. This allows users to specify more complex constraints, like integer divisibility, for dynamic dimensions:

dx = Dim("dx", min=4, max=512)
dynamic_shapes = {
    "x": (4 * dx, None)  # x.shape[0] has range [16, 2048], and is divisible by 4.
}

Constraint violations, suggested fixes#

One common issue with this specification style (before Dim.AUTO was introduced) is that the specification would often be mismatched with what model tracing produced. That leads to ConstraintViolation errors and export suggested fixes - see for example this model & spec, where the model itself requires equality between dimension 0 of x and y, and requires dimension 1 to be static:

class Foo(torch.nn.Module):
    def forward(self, x, y):
        w = x + y
        return w + torch.ones(4)

dx, dy, d1 = torch.export.dims("dx", "dy", "d1")
try:
    ep = export(
        Foo(),
        (torch.randn(6, 4), torch.randn(6, 4)),
        dynamic_shapes={
            "x": (dx, d1),
            "y": (dy, d1),
        },
    )
except Exception:
    tb.print_exc()
I1015 19:14:15.728000 18276 torch/fx/experimental/symbolic_shapes.py:3769] create_env
I1015 19:14:15.730000 18276 torch/fx/experimental/symbolic_shapes.py:5114] create_symbol s77 = 6 for L['x'].size()[0] [2, int_oo] (_export/non_strict_utils.py:232 in fakify), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s77" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I1015 19:14:15.732000 18276 torch/fx/experimental/symbolic_shapes.py:5114] create_symbol s27 = 4 for L['x'].size()[1] [2, int_oo] (_export/non_strict_utils.py:232 in fakify), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s27" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I1015 19:14:15.734000 18276 torch/fx/experimental/symbolic_shapes.py:5114] create_symbol s17 = 6 for L['y'].size()[0] [2, int_oo] (_export/non_strict_utils.py:232 in fakify), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s17" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I1015 19:14:15.735000 18276 torch/fx/experimental/symbolic_shapes.py:5114] create_symbol s94 = 4 for L['y'].size()[1] [2, int_oo] (_export/non_strict_utils.py:232 in fakify), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s94" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
V1015 19:14:15.741000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == False [statically known]
V1015 19:14:15.743000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == False [statically known]
V1015 19:14:15.744000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == False [statically known]
V1015 19:14:15.745000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == False [statically known]
V1015 19:14:15.746000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == False [statically known]
V1015 19:14:15.747000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == False [statically known]
I1015 19:14:15.751000 18276 torch/fx/experimental/symbolic_shapes.py:7207] eval Eq(s27, s94) [guard added] (_subclasses/fake_impls.py:1148 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s27, s94)"
I1015 19:14:15.752000 18276 torch/fx/experimental/symbolic_shapes.py:6786] set_replacement s94 = s27 (solve) VR[2, int_oo]
I1015 19:14:15.754000 18276 torch/fx/experimental/symbolic_shapes.py:7207] eval Eq(s77, s17) [guard added] (_subclasses/fake_impls.py:1148 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s77, s17)"
I1015 19:14:15.755000 18276 torch/fx/experimental/symbolic_shapes.py:6786] set_replacement s77 = s17 (solve) VR[2, int_oo]
V1015 19:14:15.757000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == False [statically known]
I1015 19:14:15.767000 18276 torch/fx/experimental/symbolic_shapes.py:7207] eval Eq(s27, 4) [guard added] (_subclasses/fake_impls.py:1148 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s27, 4)"
V1015 19:14:15.768000 18276 torch/fx/experimental/symbolic_shapes.py:6619] _update_var_to_range s27 = VR[4, 4] (update)
I1015 19:14:15.768000 18276 torch/fx/experimental/symbolic_shapes.py:6786] set_replacement s27 = 4 (range_refined_to_singleton) VR[4, 4]
I1015 19:14:15.774000 18276 torch/fx/experimental/symbolic_shapes.py:5242] produce_guards
V1015 19:14:15.774000 18276 torch/fx/experimental/symbolic_shapes.py:6619] _update_var_to_range s94 = VR[4, 4] (update)
I1015 19:14:15.775000 18276 torch/fx/experimental/symbolic_shapes.py:6786] set_replacement s94 = 4 (find) VR[4, 4]
V1015 19:14:15.775000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['x'].size()[0] s17 StrictMinMaxConstraint(warn_only=False, vr=VR[0, int_oo])
V1015 19:14:15.776000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['x'].size()[1] 4 StrictMinMaxConstraint(warn_only=False, vr=VR[0, int_oo])
V1015 19:14:15.776000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['x'].stride()[0] 4 None
V1015 19:14:15.776000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['x'].stride()[1] 1 None
V1015 19:14:15.777000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['x'].storage_offset() 0 None
V1015 19:14:15.777000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['y'].size()[0] s17 StrictMinMaxConstraint(warn_only=False, vr=VR[0, int_oo])
V1015 19:14:15.777000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['y'].size()[1] 4 StrictMinMaxConstraint(warn_only=False, vr=VR[0, int_oo])
V1015 19:14:15.778000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['y'].stride()[0] 4 None
V1015 19:14:15.778000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['y'].stride()[1] 1 None
V1015 19:14:15.778000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['y'].storage_offset() 0 None
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1821, in _export_to_aten_ir_make_fx
    produce_guards_callback(gm)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1968, in _produce_guards_callback
    return produce_guards_and_solve_constraints(
  File "/usr/local/lib/python3.10/dist-packages/torch/_export/non_strict_utils.py", line 533, in produce_guards_and_solve_constraints
    raise constraint_violation_error
  File "/usr/local/lib/python3.10/dist-packages/torch/_export/non_strict_utils.py", line 500, in produce_guards_and_solve_constraints
    shape_env.produce_guards(
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 5204, in produce_guards
    return self.produce_guards_verbose(*args, **kwargs, langs=("python",))[0].exprs
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 5928, in produce_guards_verbose
    raise ConstraintViolationError(
torch.fx.experimental.symbolic_shapes.ConstraintViolationError: Constraints violated (d1, dy)! For more information, run with TORCH_LOGS="+dynamic".
  - You marked d1 as dynamic but your code specialized it to be a constant (4). If you're using mark_dynamic, either remove it or use maybe_mark_dynamic. If you're using Dim.DYNAMIC, replace it with either Dim.STATIC or Dim.AUTO.
  - You marked d1 as dynamic but your code specialized it to be a constant (4). If you're using mark_dynamic, either remove it or use maybe_mark_dynamic. If you're using Dim.DYNAMIC, replace it with either Dim.STATIC or Dim.AUTO.
  - The values of dy = L['y'].size()[0] and dx = L['x'].size()[0] must always be equal.
Suggested fixes:
  d1 = 4
  dy = dx

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 557, in <module>
    ep = export(
  File "/usr/local/lib/python3.10/dist-packages/torch/export/__init__.py", line 311, in export
    raise e
  File "/usr/local/lib/python3.10/dist-packages/torch/export/__init__.py", line 277, in export
    return _export(
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1163, in wrapper
    raise e
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1129, in wrapper
    ep = fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/exported_program.py", line 124, in wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 2255, in _export
    ep = _export_for_training(
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1163, in wrapper
    raise e
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1129, in wrapper
    ep = fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/exported_program.py", line 124, in wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 2071, in _export_for_training
    export_artifact = export_func(
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 2002, in _non_strict_export
    aten_export_artifact = _to_aten_func(  # type: ignore[operator]
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1823, in _export_to_aten_ir_make_fx
    raise UserError(UserErrorType.CONSTRAINT_VIOLATION, str(e))  # noqa: B904
torch._dynamo.exc.UserError: Constraints violated (d1, dy)! For more information, run with TORCH_LOGS="+dynamic".
  - You marked d1 as dynamic but your code specialized it to be a constant (4). If you're using mark_dynamic, either remove it or use maybe_mark_dynamic. If you're using Dim.DYNAMIC, replace it with either Dim.STATIC or Dim.AUTO.
  - You marked d1 as dynamic but your code specialized it to be a constant (4). If you're using mark_dynamic, either remove it or use maybe_mark_dynamic. If you're using Dim.DYNAMIC, replace it with either Dim.STATIC or Dim.AUTO.
  - The values of dy = L['y'].size()[0] and dx = L['x'].size()[0] must always be equal.
Suggested fixes:
  d1 = 4
  dy = dx

The error above occurred when calling torch.export.export. If you would like to view some more information about this error, and get a list of all other errors that may occur in your export call, you can replace your `export()` call with `draft_export()`.

The expectation with suggested fixes is that the user can interactively copy-paste the changes into their dynamic shapes specification, and successfully export afterwards.

Lastly, there are a couple of nice-to-knows about the options for specification:

  • None is a good option for static behavior: - dynamic_shapes=None (default) exports with the entire model being static. - specifying None at the input level exports with all tensor dimensions static, and is also required for non-tensor inputs. - specifying None at the dimension level specializes that dimension, though this is deprecated in favor of Dim.STATIC.

  • Specifying per-dimension integer values also produces static behavior, and will additionally check that the provided sample input matches the specification.

These options are combined in the inputs & dynamic shapes spec below:

inputs = (
    torch.randn(4, 4),
    torch.randn(3, 3),
    16,
    False,
)
dynamic_shapes = {
    "tensor_0": (Dim.AUTO, None),
    "tensor_1": None,
    "int_val": None,
    "bool_val": None,
}

Data-dependent errors#

While trying to export models, you may have encountered errors like "Could not guard on data-dependent expression", or "Could not extract specialized integer from data-dependent expression". These errors exist because torch.export() compiles programs using FakeTensors, which symbolically represent their real tensor counterparts. While these have equivalent symbolic properties (e.g. sizes, strides, dtypes), they diverge in that FakeTensors do not contain any data values. While this avoids unnecessary memory usage and expensive computation, it does mean that export may be unable to compile, out of the box, parts of user code where compilation relies on data values. In short, if the compiler requires a concrete, data-dependent value in order to proceed, it will error out, complaining that the value is not available.

Data-dependent values appear in many places, and common sources are calls like item(), tolist(), or torch.unbind() that extract scalar values from tensors. How are these values represented in the exported program? In the Constraints/Dynamic Shapes section, we talked about allocating symbols for dynamic input dimensions. The same happens here: we allocate symbols for every data-dependent value that appears in the program. The important distinction is that these are "unbacked" symbols, in contrast to the "backed" symbols allocated for input dimensions. The "backed/unbacked" nomenclature refers to the presence or absence of a "hint" for the symbol: a concrete value backing the symbol, which can inform the compiler on how to proceed.

In the input shape symbol case (backed symbols), these hints are simply the sample input shapes provided, which explains why control-flow branching is determined by the sample input properties. For data-dependent values, the symbols are taken from FakeTensor "data" during tracing, and so the compiler doesn't know the actual values (hints) that these symbols would take on.

Let's see how these show up in exported programs:

class Foo(torch.nn.Module):
    def forward(self, x, y):
        a = x.item()
        b = y.tolist()
        return b + [a]

inps = (
    torch.tensor(1),
    torch.tensor([2, 3]),
)
ep = export(Foo(), inps)
print(ep)
I1015 19:14:15.787000 18276 torch/fx/experimental/symbolic_shapes.py:3769] create_env
I1015 19:14:15.794000 18276 torch/fx/experimental/symbolic_shapes.py:4780] create_unbacked_symint u0 [-int_oo, int_oo] (_subclasses/fake_impls.py:651 in local_scalar_dense)
I1015 19:14:15.795000 18276 torch/fx/experimental/symbolic_shapes.py:1289] compute_unbacked_bindings [u0]
I1015 19:14:15.798000 18276 torch/fx/experimental/symbolic_shapes.py:4780] create_unbacked_symint u1 [-int_oo, int_oo] (_subclasses/fake_impls.py:651 in local_scalar_dense)
I1015 19:14:15.799000 18276 torch/fx/experimental/symbolic_shapes.py:1289] compute_unbacked_bindings [u1]
I1015 19:14:15.800000 18276 torch/fx/experimental/symbolic_shapes.py:4780] create_unbacked_symint u2 [-int_oo, int_oo] (_subclasses/fake_impls.py:651 in local_scalar_dense)
I1015 19:14:15.800000 18276 torch/fx/experimental/symbolic_shapes.py:1289] compute_unbacked_bindings [u2]
I1015 19:14:15.802000 18276 torch/fx/experimental/symbolic_shapes.py:5242] produce_guards
V1015 19:14:15.803000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['x'].storage_offset() 0 None
V1015 19:14:15.803000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['y'].size()[0] 2 None
V1015 19:14:15.803000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['y'].stride()[0] 1 None
V1015 19:14:15.804000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['y'].storage_offset() 0 None
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "i64[]", y: "i64[2]"):
             # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:618 in forward, code: a = x.item()
            item: "Sym(u0)" = torch.ops.aten.item.default(x);  x = None

             # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:619 in forward, code: b = y.tolist()
            unbind = torch.ops.aten.unbind.int(y);  y = None
            getitem: "i64[]" = unbind[0]
            getitem_1: "i64[]" = unbind[1];  unbind = None
            item_1: "Sym(u1)" = torch.ops.aten.item.default(getitem);  getitem = None
            item_2: "Sym(u2)" = torch.ops.aten.item.default(getitem_1);  getitem_1 = None
            return (item_1, item_2, item)

Graph signature:
    # inputs
    x: USER_INPUT
    y: USER_INPUT

    # outputs
    item_1: USER_OUTPUT
    item_2: USER_OUTPUT
    item: USER_OUTPUT

Range constraints: {u0: VR[-int_oo, int_oo], u1: VR[-int_oo, int_oo], u2: VR[-int_oo, int_oo]}

The result is that 3 unbacked symbols (notice they're prefixed with "u", instead of the usual "s" for input shape/backed symbols) are allocated and returned: 1 for the item() call, and 1 for each of the elements of y with the tolist() call. Note from the range constraints field that these take on ranges of [-int_oo, int_oo], not the default [0, int_oo] range allocated to input shape symbols, since we have no information on these values - they don't represent sizes, so don't necessarily have positive values.

Guards, torch._check()#

But the case above is easy to export, because the concrete values of these symbols aren't used in any compiler decision-making; all that's relevant is that the return values are unbacked symbols. The data-dependent errors highlighted in this section are cases like the following, where a data-dependent guard is encountered:

class Foo(torch.nn.Module):
    def forward(self, x, y):
        a = x.item()
        if a // 2 >= 5:
            return y + 2
        else:
            return y * 5

Here we actually need the "hint", or the concrete value of a, for the compiler to decide whether to trace return y + 2 or return y * 5 as the output. Because we trace with FakeTensors, we don't know what a // 2 >= 5 actually evaluates to, and export errors out with "Could not guard on data-dependent expression u0 // 2 >= 5 (unhinted)".

So how do we export this toy model? Unlike torch.compile(), export requires full graph compilation, and we can't just graph-break on this. Here are some basic options:

  1. Manual specialization: we could intervene by selecting the branch to trace, either by removing the control-flow code to contain only the specialized branch, or by using torch.compiler.is_compiling() to gate what's traced at compile time.

  2. torch.cond(): we could rewrite the control-flow code to use torch.cond(), so we don't specialize on a branch.

While these options are valid, they have their pitfalls. Option 1 sometimes requires drastic, invasive rewrites of the model code in order to specialize, and torch.cond() is not a comprehensive system for handling data-dependent errors. As we will see, there are data-dependent errors that do not involve control flow.

The generally recommended approach is to start with torch._check() calls. While these give the impression of being purely assert statements, they are in fact a system for informing the compiler of properties of symbols. While a torch._check() call does act as an assertion at runtime, when traced at compile time, the checked expression is sent to the symbolic shapes subsystem for reasoning, and any symbol properties that follow from the expression being true are stored as symbol properties (provided the subsystem is smart enough to infer those properties). So even if unbacked symbols don't have hints, if we're able to communicate properties that are generally true for these symbols via torch._check() calls, we can potentially bypass data-dependent guards without rewriting the offending model code.

For example, in the model above, inserting torch._check(a >= 10) would tell the compiler that y + 2 can always be returned, while torch._check(a == 4) tells it to return y * 5. See what happens when we re-export this model:

class Foo(torch.nn.Module):
    def forward(self, x, y):
        a = x.item()
        torch._check(a >= 10)
        torch._check(a <= 60)
        if a // 2 >= 5:
            return y + 2
        else:
            return y * 5

inps = (
    torch.tensor(32),
    torch.randn(4),
)
ep = export(Foo(), inps)
print(ep)
I1015 19:14:15.811000 18276 torch/fx/experimental/symbolic_shapes.py:3769] create_env
I1015 19:14:15.816000 18276 torch/fx/experimental/symbolic_shapes.py:4780] create_unbacked_symint u0 [-int_oo, int_oo] (_subclasses/fake_impls.py:651 in local_scalar_dense)
I1015 19:14:15.817000 18276 torch/fx/experimental/symbolic_shapes.py:1289] compute_unbacked_bindings [u0]
I1015 19:14:15.819000 18276 torch/fx/experimental/symbolic_shapes.py:7207] runtime_assert u0 >= 10 [guard added] (ar/lib/workspace/intermediate_source/torch_export_tutorial.py:673 in forward), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="u0 >= 10"
V1015 19:14:15.819000 18276 torch/fx/experimental/symbolic_shapes.py:6619] _update_var_to_range u0 = VR[10, int_oo] (update)
I1015 19:14:15.824000 18276 torch/fx/experimental/symbolic_shapes.py:7207] runtime_assert u0 <= 60 [guard added] (ar/lib/workspace/intermediate_source/torch_export_tutorial.py:674 in forward), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="u0 <= 60"
V1015 19:14:15.825000 18276 torch/fx/experimental/symbolic_shapes.py:6619] _update_var_to_range u0 = VR[10, 60] (update)
V1015 19:14:15.830000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == True [statically known]
I1015 19:14:15.834000 18276 torch/fx/experimental/symbolic_shapes.py:5242] produce_guards
V1015 19:14:15.834000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['x'].storage_offset() 0 None
V1015 19:14:15.834000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['y'].size()[0] 4 None
V1015 19:14:15.835000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['y'].stride()[0] 1 None
V1015 19:14:15.835000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['y'].storage_offset() 0 None
V1015 19:14:15.837000 18276 torch/fx/experimental/symbolic_shapes.py:7700] runtime_assert u0 >= 10 == True [statically known]
V1015 19:14:15.838000 18276 torch/fx/experimental/symbolic_shapes.py:7700] runtime_assert u0 <= 60 == True [statically known]
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "i64[]", y: "f32[4]"):
             # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:672 in forward, code: a = x.item()
            item: "Sym(u0)" = torch.ops.aten.item.default(x);  x = None
            ge_2: "Sym(u0 >= 10)" = item >= 10
            _assert_scalar_default = torch.ops.aten._assert_scalar.default(ge_2, "Runtime assertion failed for expression u0 >= 10 on node 'ge_2'");  ge_2 = _assert_scalar_default = None
            le_1: "Sym(u0 <= 60)" = item <= 60;  item = None
            _assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(le_1, "Runtime assertion failed for expression u0 <= 60 on node 'le_1'");  le_1 = _assert_scalar_default_1 = None

             # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:676 in forward, code: return y + 2
            add: "f32[4]" = torch.ops.aten.add.Tensor(y, 2);  y = None
            return (add,)

Graph signature:
    # inputs
    x: USER_INPUT
    y: USER_INPUT

    # outputs
    add: USER_OUTPUT

Range constraints: {u0: VR[10, 60]}

Export succeeds, and note from the range constraints field that u0 takes on a range of [10, 60].

So what information do torch._check() calls actually communicate? This varies as the symbolic shapes subsystem gets smarter, but fundamentally, it's usually the following:

  1. Equality with non-data-dependent expressions: torch._check() calls that communicate equalities like u0 == s0 + 4 or u0 == 5.

  2. Range refinement: calls that provide lower or upper bounds for symbols, like above.

  3. Some basic reasoning around more complicated expressions: inserting torch._check(a < 4) will typically tell the compiler that a >= 4 is false. Checks on complex expressions like torch._check(a ** 2 - 3 * a <= 10) will typically get you past identical guards.

As mentioned previously, torch._check() calls have applicability outside of data-dependent control flow. For example, here's a model where torch._check() insertion is required, while neither manual specialization nor torch.cond() helps:

class Foo(torch.nn.Module):
    def forward(self, x, y):
        a = x.item()
        return y[a]

inps = (
    torch.tensor(32),
    torch.randn(60),
)
try:
    export(Foo(), inps)
except Exception:
    tb.print_exc()
I1015 19:14:15.844000 18276 torch/fx/experimental/symbolic_shapes.py:3769] create_env
I1015 19:14:15.850000 18276 torch/fx/experimental/symbolic_shapes.py:4780] create_unbacked_symint u0 [-int_oo, int_oo] (_subclasses/fake_impls.py:651 in local_scalar_dense)
I1015 19:14:15.850000 18276 torch/fx/experimental/symbolic_shapes.py:1289] compute_unbacked_bindings [u0]
I1015 19:14:15.852000 18276 torch/fx/experimental/symbolic_shapes.py:7379] could not evaluate u0 >= 0 due to data dependency, it was assumed to be False with no runtime assertions (_subclasses/fake_impls.py:388 in meta_select)
I1015 19:14:15.852000 18276 torch/fx/experimental/symbolic_shapes.py:7379] For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
I1015 19:14:15.853000 18276 torch/fx/experimental/symbolic_shapes.py:7379] could not evaluate u0 < 0 due to data dependency, it was assumed to be False with no runtime assertions (_subclasses/fake_impls.py:390 in meta_select)
I1015 19:14:15.853000 18276 torch/fx/experimental/symbolic_shapes.py:7379] For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
I1015 19:14:15.854000 18276 torch/fx/experimental/symbolic_shapes.py:4780] create_unbacked_symint u1 [-int_oo, int_oo] (_subclasses/fake_impls.py:402 in meta_select)
I1015 19:14:15.855000 18276 torch/fx/experimental/symbolic_shapes.py:7379] could not evaluate u1 >= 0 due to data dependency, it was assumed to be True with no runtime assertions (utils/_stats.py:28 in wrapper)
I1015 19:14:15.855000 18276 torch/fx/experimental/symbolic_shapes.py:7379] For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
I1015 19:14:15.856000 18276 torch/fx/experimental/symbolic_shapes.py:1289] compute_unbacked_bindings [u1]
I1015 19:14:15.858000 18276 torch/fx/experimental/symbolic_shapes.py:5242] produce_guards
V1015 19:14:15.859000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['x'].storage_offset() 0 None
V1015 19:14:15.859000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['y'].size()[0] 60 None
V1015 19:14:15.859000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['y'].stride()[0] 1 None
V1015 19:14:15.859000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['y'].storage_offset() 0 None

This is a scenario where torch._check() insertion is required simply to prevent an operation from failing. The export call fails with "Could not guard on data-dependent expression -u0 > 60", meaning the compiler doesn't know whether this is a valid indexing operation - whether the value of x is out of bounds for y or not. Here, manual specialization is too prohibitive, and torch.cond() has no place. Instead, informing the compiler of u0's range is sufficient:

class Foo(torch.nn.Module):
    def forward(self, x, y):
        a = x.item()
        torch._check(a >= 0)
        torch._check(a < y.shape[0])
        return y[a]

inps = (
    torch.tensor(32),
    torch.randn(60),
)
ep = export(Foo(), inps)
print(ep)
I1015 19:14:15.863000 18276 torch/fx/experimental/symbolic_shapes.py:3769] create_env
I1015 19:14:15.869000 18276 torch/fx/experimental/symbolic_shapes.py:4780] create_unbacked_symint u0 [-int_oo, int_oo] (_subclasses/fake_impls.py:651 in local_scalar_dense)
I1015 19:14:15.869000 18276 torch/fx/experimental/symbolic_shapes.py:1289] compute_unbacked_bindings [u0]
I1015 19:14:15.871000 18276 torch/fx/experimental/symbolic_shapes.py:7207] runtime_assert u0 >= 0 [guard added] (ar/lib/workspace/intermediate_source/torch_export_tutorial.py:722 in forward), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="u0 >= 0"
V1015 19:14:15.871000 18276 torch/fx/experimental/symbolic_shapes.py:6619] _update_var_to_range u0 = VR[0, int_oo] (update)
I1015 19:14:15.874000 18276 torch/fx/experimental/symbolic_shapes.py:7207] runtime_assert u0 < 60 [guard added] (ar/lib/workspace/intermediate_source/torch_export_tutorial.py:723 in forward), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="u0 < 60"
V1015 19:14:15.875000 18276 torch/fx/experimental/symbolic_shapes.py:6619] _update_var_to_range u0 = VR[0, 59] (update)
V1015 19:14:15.877000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == True [statically known]
V1015 19:14:15.878000 18276 torch/fx/experimental/symbolic_shapes.py:7485] eval False == True [statically known]
I1015 19:14:15.881000 18276 torch/fx/experimental/symbolic_shapes.py:5242] produce_guards
V1015 19:14:15.881000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['x'].storage_offset() 0 None
V1015 19:14:15.882000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['y'].size()[0] 60 None
V1015 19:14:15.882000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['y'].stride()[0] 1 None
V1015 19:14:15.882000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['y'].storage_offset() 0 None
V1015 19:14:15.884000 18276 torch/fx/experimental/symbolic_shapes.py:7700] runtime_assert u0 >= 0 == True [statically known]
V1015 19:14:15.886000 18276 torch/fx/experimental/symbolic_shapes.py:7700] runtime_assert u0 <= 59 == True [statically known]
V1015 19:14:15.887000 18276 torch/fx/experimental/symbolic_shapes.py:7700] runtime_assert u0 < 60 == True [statically known]
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "i64[]", y: "f32[60]"):
             # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:721 in forward, code: a = x.item()
            item: "Sym(u0)" = torch.ops.aten.item.default(x);  x = None
            ge_1: "Sym(u0 >= 0)" = item >= 0
            _assert_scalar_default = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge_1'");  ge_1 = _assert_scalar_default = None
            le: "Sym(u0 <= 59)" = item <= 59
            _assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(le, "Runtime assertion failed for expression u0 <= 59 on node 'le'");  le = _assert_scalar_default_1 = None

             #
            lt_1: "Sym(u0 < 60)" = item < 60
            _assert_scalar_default_2 = torch.ops.aten._assert_scalar.default(lt_1, "Runtime assertion failed for expression u0 < 60 on node 'lt_1'");  lt_1 = _assert_scalar_default_2 = None

             # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:724 in forward, code: return y[a]
            select: "f32[]" = torch.ops.aten.select.int(y, 0, item);  y = item = None
            return (select,)

Graph signature:
    # inputs
    x: USER_INPUT
    y: USER_INPUT

    # outputs
    select: USER_OUTPUT

Range constraints: {u0: VR[0, 59]}

Specialized values#

Another category of data-dependent errors occurs when the program attempts to extract a concrete data-dependent integer/float value while tracing. These errors look like "Could not extract specialized integer from data-dependent expression", and are analogous to the previous class of errors: where data-dependent guard errors arise from evaluating concrete boolean values, these errors arise from attempting to evaluate concrete integer/float values.

This error typically occurs with an explicit or implicit int() cast on a data-dependent expression. For example, this list comprehension has a range() call that implicitly does an int() cast on the size of the list:

class Foo(torch.nn.Module):
    def forward(self, x, y):
        a = x.item()
        b = torch.cat([y for y in range(a)], dim=0)
        return b + int(a)

inps = (
    torch.tensor(32),
    torch.randn(60),
)
try:
    export(Foo(), inps, strict=False)
except Exception:
    tb.print_exc()
I1015 19:14:15.893000 18276 torch/fx/experimental/symbolic_shapes.py:3769] create_env
I1015 19:14:15.899000 18276 torch/fx/experimental/symbolic_shapes.py:4780] create_unbacked_symint u0 [-int_oo, int_oo] (_subclasses/fake_impls.py:651 in local_scalar_dense)
I1015 19:14:15.899000 18276 torch/fx/experimental/symbolic_shapes.py:1289] compute_unbacked_bindings [u0]
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532] Data dependent variable 'u0' allocated at:
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]   File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 756, in <module>
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]     export(Foo(), inps, strict=False)
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]   File "/usr/local/lib/python3.10/dist-packages/torch/export/__init__.py", line 277, in export
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]     return _export(
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]   File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1129, in wrapper
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]     ep = fn(*args, **kwargs)
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]   File "/usr/local/lib/python3.10/dist-packages/torch/export/exported_program.py", line 124, in wrapper
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]     return fn(*args, **kwargs)
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]   File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 2255, in _export
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]     ep = _export_for_training(
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]   File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1129, in wrapper
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]     ep = fn(*args, **kwargs)
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]   File "/usr/local/lib/python3.10/dist-packages/torch/export/exported_program.py", line 124, in wrapper
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]     return fn(*args, **kwargs)
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]   File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 2071, in _export_for_training
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]     export_artifact = export_func(
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]   File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 2002, in _non_strict_export
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]     aten_export_artifact = _to_aten_func(  # type: ignore[operator]
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]   File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1793, in _export_to_aten_ir_make_fx
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]     gm, graph_signature = transform(_make_fx_helper)(
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]   File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1922, in _aot_export_non_strict
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]     gm, sig = aot_export(wrapped_mod, args, kwargs=kwargs, **flags)
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]   File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1706, in _make_fx_helper
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]     gm = make_fx(
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]   File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 2429, in wrapped
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]     return make_fx_tracer.trace(f, *args)
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]   File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 2356, in trace
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]     return self._trace_inner(f, *args)
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]   File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 2318, in _trace_inner
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]     t = dispatch_trace(
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]   File "/usr/local/lib/python3.10/dist-packages/torch/_compile.py", line 53, in inner
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]     return disable_fn(*args, **kwargs)
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 1044, in _fn
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]     return fn(*args, **kwargs)
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]   File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1303, in dispatch_trace
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]     graph = tracer.trace(root, concrete_args)  # type: ignore[arg-type]
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]   File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1908, in trace
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]     res = super().trace(root, concrete_args)
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 1044, in _fn
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]     return fn(*args, **kwargs)
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]   File "/usr/local/lib/python3.10/dist-packages/torch/fx/_symbolic_trace.py", line 868, in trace
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]     (self.create_arg(fn(*args)),),
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]   File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1361, in wrapped
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]     out = f(*tensors)  # type:ignore[call-arg]
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]   File "<string>", line 1, in <lambda>
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]   File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1593, in wrapped_fn
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]     return tuple(flat_fn(*args))
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]   File "/usr/local/lib/python3.10/dist-packages/torch/_functorch/_aot_autograd/utils.py", line 187, in flat_fn
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]     tree_out = fn(*args, **kwargs)
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]   File "/usr/local/lib/python3.10/dist-packages/torch/_functorch/_aot_autograd/graph_capture_wrappers.py", line 1354, in functional_call
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]     out = mod(*args[params_len:], **kwargs)
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]   File "/usr/local/lib/python3.10/dist-packages/torch/fx/_symbolic_trace.py", line 843, in module_call_wrapper
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]     return self.call_module(mod, forward, args, kwargs)
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]   File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1997, in call_module
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]     return Tracer.call_module(self, m, forward, args, kwargs)
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]   File "/usr/local/lib/python3.10/dist-packages/torch/fx/_symbolic_trace.py", line 560, in call_module
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]     ret_val = forward(*args, **kwargs)
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]   File "/usr/local/lib/python3.10/dist-packages/torch/fx/_symbolic_trace.py", line 836, in forward
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]     return _orig_module_call(mod, *args, **kwargs)
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]     return self._call_impl(*args, **kwargs)
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]     return forward_call(*args, **kwargs)
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]   File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1906, in forward
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]     tree_out = mod(*args, **kwargs)
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]   File "/usr/local/lib/python3.10/dist-packages/torch/fx/_symbolic_trace.py", line 843, in module_call_wrapper
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]     return self.call_module(mod, forward, args, kwargs)
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]   File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1997, in call_module
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]     return Tracer.call_module(self, m, forward, args, kwargs)
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]   File "/usr/local/lib/python3.10/dist-packages/torch/fx/_symbolic_trace.py", line 560, in call_module
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]     ret_val = forward(*args, **kwargs)
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]   File "/usr/local/lib/python3.10/dist-packages/torch/fx/_symbolic_trace.py", line 836, in forward
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]     return _orig_module_call(mod, *args, **kwargs)
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]     return self._call_impl(*args, **kwargs)
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]     return forward_call(*args, **kwargs)
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]   File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 747, in forward
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]     a = x.item()
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]   File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1409, in __torch_function__
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]     return func(*args, **kwargs)
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]   File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1479, in __torch_function__
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]     return func(*args, **kwargs)
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]   File "/usr/local/lib/python3.10/dist-packages/torch/_export/non_strict_utils.py", line 1066, in __torch_function__
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]     return func(*args, **kwargs)
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]   File "/usr/local/lib/python3.10/dist-packages/torch/_ops.py", line 962, in handler
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]     return torch._library.utils.handle_dispatch_mode(
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]   File "/usr/local/lib/python3.10/dist-packages/torch/_library/utils.py", line 286, in handle_dispatch_mode
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]     return curr_mode.__torch_dispatch__(op_overload, overload_types, args, kwargs)
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]   File "/usr/local/lib/python3.10/dist-packages/torch/utils/_stats.py", line 28, in wrapper
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]     return fn(*args, **kwargs)
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]   File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1534, in __torch_dispatch__
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]     return proxy_call(self, func, self.pre_dispatch, args, kwargs)
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]   File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 994, in proxy_call
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]     out = func(*args, **kwargs)
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]   File "/usr/local/lib/python3.10/dist-packages/torch/_ops.py", line 841, in __call__
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]     return self._op(*args, **kwargs)
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]   File "/usr/local/lib/python3.10/dist-packages/torch/utils/_stats.py", line 28, in wrapper
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]     return fn(*args, **kwargs)
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]   File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_tensor.py", line 1376, in __torch_dispatch__
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]     return self.dispatch(func, types, args, kwargs)
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]   File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_tensor.py", line 2096, in dispatch
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]     return self._cached_dispatch_impl(func, types, args, kwargs)
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]   File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_tensor.py", line 1498, in _cached_dispatch_impl
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]     return self._dispatch_impl(func, types, args, kwargs)
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]   File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_tensor.py", line 2725, in _dispatch_impl
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]     op_impl_out = op_impl(self, func, *args, **kwargs)
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]   File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_impls.py", line 169, in dispatch_to_op_implementations_dict
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]     return op_implementations_dict[func](fake_mode, func, *args, **kwargs)
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]   File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_impls.py", line 651, in local_scalar_dense
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]     r = fake_mode.shape_env.create_unbacked_symint()
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]   File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/recording.py", line 272, in wrapper
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]     return retlog(fn(*args, **kwargs))
V1015 19:14:15.900000 18276 torch/fx/experimental/symbolic_shapes.py:6532]



def forward(self, arg0_1: "i64[]", arg1_1: "f32[60]"):
     # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:747 in forward, code: a = x.item()
    item: "Sym(u0)" = torch.ops.aten.item.default(arg0_1);  arg0_1 = item = None




Traceback (most recent call last):
  File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 756, in <module>
    export(Foo(), inps, strict=False)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/__init__.py", line 311, in export
    raise e
  File "/usr/local/lib/python3.10/dist-packages/torch/export/__init__.py", line 277, in export
    return _export(
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1163, in wrapper
    raise e
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1129, in wrapper
    ep = fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/exported_program.py", line 124, in wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 2255, in _export
    ep = _export_for_training(
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1163, in wrapper
    raise e
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1129, in wrapper
    ep = fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/exported_program.py", line 124, in wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 2071, in _export_for_training
    export_artifact = export_func(
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 2002, in _non_strict_export
    aten_export_artifact = _to_aten_func(  # type: ignore[operator]
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1793, in _export_to_aten_ir_make_fx
    gm, graph_signature = transform(_make_fx_helper)(
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1922, in _aot_export_non_strict
    gm, sig = aot_export(wrapped_mod, args, kwargs=kwargs, **flags)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1706, in _make_fx_helper
    gm = make_fx(
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 2429, in wrapped
    return make_fx_tracer.trace(f, *args)
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 2356, in trace
    return self._trace_inner(f, *args)
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 2318, in _trace_inner
    t = dispatch_trace(
  File "/usr/local/lib/python3.10/dist-packages/torch/_compile.py", line 53, in inner
    return disable_fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 1044, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1303, in dispatch_trace
    graph = tracer.trace(root, concrete_args)  # type: ignore[arg-type]
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1908, in trace
    res = super().trace(root, concrete_args)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 1044, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/_symbolic_trace.py", line 868, in trace
    (self.create_arg(fn(*args)),),
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1361, in wrapped
    out = f(*tensors)  # type:ignore[call-arg]
  File "<string>", line 1, in <lambda>
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1593, in wrapped_fn
    return tuple(flat_fn(*args))
  File "/usr/local/lib/python3.10/dist-packages/torch/_functorch/_aot_autograd/utils.py", line 187, in flat_fn
    tree_out = fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_functorch/_aot_autograd/graph_capture_wrappers.py", line 1354, in functional_call
    out = mod(*args[params_len:], **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/_symbolic_trace.py", line 843, in module_call_wrapper
    return self.call_module(mod, forward, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1997, in call_module
    return Tracer.call_module(self, m, forward, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/_symbolic_trace.py", line 560, in call_module
    ret_val = forward(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/_symbolic_trace.py", line 836, in forward
    return _orig_module_call(mod, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1906, in forward
    tree_out = mod(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/_symbolic_trace.py", line 843, in module_call_wrapper
    return self.call_module(mod, forward, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1997, in call_module
    return Tracer.call_module(self, m, forward, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/_symbolic_trace.py", line 560, in call_module
    ret_val = forward(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/_symbolic_trace.py", line 836, in forward
    return _orig_module_call(mod, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
  File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 748, in forward
    b = torch.cat([y for y in range(a)], dim=0)
  File "/usr/local/lib/python3.10/dist-packages/torch/__init__.py", line 449, in __index__
    return self.node.int_()
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/sym_node.py", line 468, in int_
    return self.guard_int("", 0)  # NB: uses Python backtrace
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/sym_node.py", line 518, in guard_int
    r = self.evaluate()
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/sym_node.py", line 512, in evaluate
    return self.shape_env.evaluate_sym_node(self, size_oblivious)
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 7233, in evaluate_sym_node
    return self.evaluate_expr(
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 7333, in evaluate_expr
    return self._inner_evaluate_expr(
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/recording.py", line 272, in wrapper
    return retlog(fn(*args, **kwargs))
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 7356, in _inner_evaluate_expr
    return self._evaluate_expr(
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 7574, in _evaluate_expr
    raise self._make_data_dependent_error(
torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: Could not extract specialized integer from data-dependent expression u0 (unhinted: u0).  (Size-like symbols: none)

Caused by: (/var/lib/workspace/intermediate_source/torch_export_tutorial.py:748 in forward)
For more information, run with TORCH_LOGS="dynamic"
For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u0"
If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing

For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1

The error above occurred when calling torch.export.export. If you would like to view some more information about this error, and get a list of all other errors that may occur in your export call, you can replace your `export()` call with `draft_export()`.

For these errors, some basic options you have are:

  1. Avoid unnecessary int() cast calls, in this case the int(a) in the return statement.

  2. Use torch._check() calls; unfortunately, all you might be able to do in this case is specialize (with torch._check(a == 60)).

  3. Rewrite the offending code at a higher level. For example, the list comprehension is semantically a repeat() operation, which doesn't involve an int() cast. The following rewrite avoids the data-dependent error:

class Foo(torch.nn.Module):
    def forward(self, x, y):
        a = x.item()
        b = y.unsqueeze(0).repeat(a, 1)
        return b + a

inps = (
    torch.tensor(32),
    torch.randn(60),
)
ep = export(Foo(), inps, strict=False)
print(ep)
I1015 19:14:15.918000 18276 torch/fx/experimental/symbolic_shapes.py:3769] create_env
I1015 19:14:15.924000 18276 torch/fx/experimental/symbolic_shapes.py:4780] create_unbacked_symint u0 [-int_oo, int_oo] (_subclasses/fake_impls.py:651 in local_scalar_dense)
I1015 19:14:15.924000 18276 torch/fx/experimental/symbolic_shapes.py:1289] compute_unbacked_bindings [u0]
I1015 19:14:15.927000 18276 torch/fx/experimental/symbolic_shapes.py:7207] runtime_assert u0 >= 0 [guard added] (_meta_registrations.py:4109 in meta_repeat), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="u0 >= 0"
V1015 19:14:15.928000 18276 torch/fx/experimental/symbolic_shapes.py:6619] _update_var_to_range u0 = VR[0, int_oo] (update)
V1015 19:14:15.929000 18276 torch/fx/experimental/symbolic_shapes.py:7700] runtime_assert u0 >= 0 == True [statically known]
I1015 19:14:15.932000 18276 torch/fx/experimental/symbolic_shapes.py:7379] could not evaluate Eq(u0, 0) due to data dependency, it was assumed to be False with no runtime assertions (utils/_stats.py:28 in wrapper)
I1015 19:14:15.932000 18276 torch/fx/experimental/symbolic_shapes.py:7379] For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
I1015 19:14:15.938000 18276 torch/fx/experimental/symbolic_shapes.py:7379] could not evaluate 60*u0 < 2 due to data dependency, it was assumed to be False with no runtime assertions (_prims_common/__init__.py:310 in is_contiguous)
I1015 19:14:15.938000 18276 torch/fx/experimental/symbolic_shapes.py:7379] For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
I1015 19:14:15.939000 18276 torch/fx/experimental/symbolic_shapes.py:7379] could not evaluate Eq(u0, 1) due to data dependency, it was assumed to be False with no runtime assertions (_prims_common/__init__.py:276 in check_contiguous_sizes_strides)
I1015 19:14:15.939000 18276 torch/fx/experimental/symbolic_shapes.py:7379] For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
I1015 19:14:15.946000 18276 torch/fx/experimental/symbolic_shapes.py:5242] produce_guards
V1015 19:14:15.947000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['x'].storage_offset() 0 None
V1015 19:14:15.947000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['y'].size()[0] 60 None
V1015 19:14:15.947000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['y'].stride()[0] 1 None
V1015 19:14:15.947000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['y'].storage_offset() 0 None
V1015 19:14:15.949000 18276 torch/fx/experimental/symbolic_shapes.py:7700] runtime_assert u0 >= 0 == True [statically known]
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "i64[]", y: "f32[60]"):
             # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:769 in forward, code: a = x.item()
            item: "Sym(u0)" = torch.ops.aten.item.default(x);  x = None

             #
            sym_constrain_range_for_size_default = torch.ops.aten.sym_constrain_range_for_size.default(item);  sym_constrain_range_for_size_default = None

             # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:769 in forward, code: a = x.item()
            ge: "Sym(u0 >= 0)" = item >= 0
            _assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'");  ge = _assert_scalar_default = None

             # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:770 in forward, code: b = y.unsqueeze(0).repeat(a, 1)
            unsqueeze: "f32[1, 60]" = torch.ops.aten.unsqueeze.default(y, 0);  y = None
            repeat: "f32[u0, 60]" = torch.ops.aten.repeat.default(unsqueeze, [item, 1]);  unsqueeze = None

             # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:771 in forward, code: return b + a
            add: "f32[u0, 60]" = torch.ops.aten.add.Tensor(repeat, item);  repeat = item = None
            return (add,)

Graph signature:
    # inputs
    x: USER_INPUT
    y: USER_INPUT

    # outputs
    add: USER_OUTPUT

Range constraints: {u0: VR[0, int_oo]}

Data-dependent errors can be much more involved, and there are many more options in your toolkit to deal with them: torch._check_is_size(), guard_size_oblivious(), or real-tensor tracing, as starters. For more in-depth guides, please refer to the Export Programming Model, or Dealing with GuardOnDataDependentSymNode errors.

Custom Ops#

torch.export can export PyTorch programs with custom operators. See this page on how to write a custom operator in either C++ or Python.

The following is an example of registering a custom operator in Python to be used by torch.export. The important thing to note is that the custom op must have a FakeTensor kernel.

@torch.library.custom_op("my_custom_library::custom_op", mutates_args={})
def custom_op(x: torch.Tensor) -> torch.Tensor:
    print("custom_op called!")
    return torch.relu(x)

@custom_op.register_fake
def custom_op_meta(x):
    # Returns an empty tensor with the same shape as the expected output
    return torch.empty_like(x)

Here is an example of exporting a program with the custom op.

class CustomOpExample(torch.nn.Module):
    def forward(self, x):
        x = torch.sin(x)
        x = torch.ops.my_custom_library.custom_op(x)
        x = torch.cos(x)
        return x

exported_custom_op_example = export(CustomOpExample(), (torch.randn(3, 3),))
print(exported_custom_op_example)
print(exported_custom_op_example.module()(torch.randn(3, 3)))
I1015 19:14:16.027000 18276 torch/fx/experimental/symbolic_shapes.py:3769] create_env
I1015 19:14:16.038000 18276 torch/fx/experimental/symbolic_shapes.py:5242] produce_guards
V1015 19:14:16.038000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['x'].size()[0] 3 None
V1015 19:14:16.038000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['x'].size()[1] 3 None
V1015 19:14:16.039000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['x'].stride()[0] 3 None
V1015 19:14:16.039000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['x'].stride()[1] 1 None
V1015 19:14:16.039000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['x'].storage_offset() 0 None
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[3, 3]"):
             # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:812 in forward, code: x = torch.sin(x)
            sin: "f32[3, 3]" = torch.ops.aten.sin.default(x);  x = None

             # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:813 in forward, code: x = torch.ops.my_custom_library.custom_op(x)
            custom_op: "f32[3, 3]" = torch.ops.my_custom_library.custom_op.default(sin);  sin = None

             # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:814 in forward, code: x = torch.cos(x)
            cos: "f32[3, 3]" = torch.ops.aten.cos.default(custom_op);  custom_op = None
            return (cos,)

Graph signature:
    # inputs
    x: USER_INPUT

    # outputs
    cos: USER_OUTPUT

Range constraints: {}

custom_op called!
tensor([[1.0000, 0.5618, 0.9935],
        [0.9409, 0.9454, 0.8364],
        [0.5766, 1.0000, 1.0000]])

Note that in the ExportedProgram, the custom operator is included in the graph.

IR/Decompositions#

The graph produced by torch.export contains only ATen operators, which are the basic units of computation in PyTorch. As there are over 3000 ATen operators, export provides a way to narrow down the operator set used in the graph based on certain characteristics, creating different IRs.

By default, export produces the most generic IR, which contains all ATen operators, including both functional and non-functional ops. A functional operator is one that does not contain any mutations or aliasing of its inputs. You can find a list of all ATen operators here, and you can inspect whether an operator is functional by checking op._schema.is_mutable, for example:

print(torch.ops.aten.add.Tensor._schema.is_mutable)
print(torch.ops.aten.add_.Tensor._schema.is_mutable)
False
True

This generic IR can be used to train in eager PyTorch Autograd. It can be reached more explicitly through the torch.export.export_for_training API, introduced in PyTorch 2.5; as of PyTorch 2.6, calling torch.export.export produces the same graph.

class DecompExample(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = torch.nn.Conv2d(1, 3, 1, 1)
        self.bn = torch.nn.BatchNorm2d(3)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return (x,)

ep_for_training = torch.export.export_for_training(DecompExample(), (torch.randn(1, 1, 3, 3),))
print(ep_for_training.graph)
/var/lib/workspace/intermediate_source/torch_export_tutorial.py:862: FutureWarning:

`torch.export.export_for_training` is deprecated and will be removed in PyTorch 2.10. Please use `torch.export.export` instead, which is functionally equivalent.

I1015 19:14:16.050000 18276 torch/fx/experimental/symbolic_shapes.py:3769] create_env
I1015 19:14:16.083000 18276 torch/fx/experimental/symbolic_shapes.py:5242] produce_guards
V1015 19:14:16.084000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['x'].size()[0] 1 None
V1015 19:14:16.084000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['x'].size()[1] 1 None
V1015 19:14:16.084000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['x'].size()[2] 3 None
V1015 19:14:16.085000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['x'].size()[3] 3 None
V1015 19:14:16.085000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['x'].stride()[0] 9 None
V1015 19:14:16.085000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['x'].stride()[1] 9 None
V1015 19:14:16.085000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['x'].stride()[2] 3 None
V1015 19:14:16.086000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['x'].stride()[3] 1 None
V1015 19:14:16.086000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['x'].storage_offset() 0 None
graph():
    %p_conv_weight : [num_users=1] = placeholder[target=p_conv_weight]
    %p_conv_bias : [num_users=1] = placeholder[target=p_conv_bias]
    %p_bn_weight : [num_users=1] = placeholder[target=p_bn_weight]
    %p_bn_bias : [num_users=1] = placeholder[target=p_bn_bias]
    %b_bn_running_mean : [num_users=1] = placeholder[target=b_bn_running_mean]
    %b_bn_running_var : [num_users=1] = placeholder[target=b_bn_running_var]
    %b_bn_num_batches_tracked : [num_users=1] = placeholder[target=b_bn_num_batches_tracked]
    %x : [num_users=1] = placeholder[target=x]
    %conv2d : [num_users=1] = call_function[target=torch.ops.aten.conv2d.default](args = (%x, %p_conv_weight, %p_conv_bias), kwargs = {})
    %add_ : [num_users=0] = call_function[target=torch.ops.aten.add_.Tensor](args = (%b_bn_num_batches_tracked, 1), kwargs = {})
    %batch_norm : [num_users=1] = call_function[target=torch.ops.aten.batch_norm.default](args = (%conv2d, %p_bn_weight, %p_bn_bias, %b_bn_running_mean, %b_bn_running_var, True, 0.1, 1e-05, True), kwargs = {})
    return (batch_norm,)

We can then lower this exported program to an operator set containing only functional ATen operators through the run_decompositions API, which decomposes ATen operators into the ones specified in a decomposition table and functionalizes the graph. By specifying an empty table, we perform only functionalization, without any additional decompositions. This results in an IR containing ~2000 operators (instead of the 3000 above), which is ideal for inference cases.

graph():
    %p_conv_weight : [num_users=1] = placeholder[target=p_conv_weight]
    %p_conv_bias : [num_users=1] = placeholder[target=p_conv_bias]
    %p_bn_weight : [num_users=1] = placeholder[target=p_bn_weight]
    %p_bn_bias : [num_users=1] = placeholder[target=p_bn_bias]
    %b_bn_running_mean : [num_users=1] = placeholder[target=b_bn_running_mean]
    %b_bn_running_var : [num_users=1] = placeholder[target=b_bn_running_var]
    %b_bn_num_batches_tracked : [num_users=1] = placeholder[target=b_bn_num_batches_tracked]
    %x : [num_users=1] = placeholder[target=x]
    %conv2d : [num_users=1] = call_function[target=torch.ops.aten.conv2d.default](args = (%x, %p_conv_weight, %p_conv_bias), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%b_bn_num_batches_tracked, 1), kwargs = {})
    %_native_batch_norm_legit_functional : [num_users=3] = call_function[target=torch.ops.aten._native_batch_norm_legit_functional.default](args = (%conv2d, %p_bn_weight, %p_bn_bias, %b_bn_running_mean, %b_bn_running_var, True, 0.1, 1e-05), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%_native_batch_norm_legit_functional, 0), kwargs = {})
    %getitem_3 : [num_users=1] = call_function[target=operator.getitem](args = (%_native_batch_norm_legit_functional, 3), kwargs = {})
    %getitem_4 : [num_users=1] = call_function[target=operator.getitem](args = (%_native_batch_norm_legit_functional, 4), kwargs = {})
    return (getitem_3, getitem_4, add, getitem)

As we can see, the previously mutable operator torch.ops.aten.add_.Tensor has now been replaced with torch.ops.aten.add.Tensor, a functional operator.

We can further lower this exported program to an operator set containing only the Core ATen Operator Set, a collection of only ~180 operators. This IR is optimal for backends that do not want to reimplement all ATen operators.

from torch.export import default_decompositions

core_aten_decomp_table = default_decompositions()
core_aten_ep = ep_for_training.run_decompositions(decomp_table=core_aten_decomp_table)
print(core_aten_ep.graph)
graph():
    %p_conv_weight : [num_users=1] = placeholder[target=p_conv_weight]
    %p_conv_bias : [num_users=1] = placeholder[target=p_conv_bias]
    %p_bn_weight : [num_users=1] = placeholder[target=p_bn_weight]
    %p_bn_bias : [num_users=1] = placeholder[target=p_bn_bias]
    %b_bn_running_mean : [num_users=1] = placeholder[target=b_bn_running_mean]
    %b_bn_running_var : [num_users=1] = placeholder[target=b_bn_running_var]
    %b_bn_num_batches_tracked : [num_users=1] = placeholder[target=b_bn_num_batches_tracked]
    %x : [num_users=1] = placeholder[target=x]
    %convolution : [num_users=1] = call_function[target=torch.ops.aten.convolution.default](args = (%x, %p_conv_weight, %p_conv_bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%b_bn_num_batches_tracked, 1), kwargs = {})
    %_native_batch_norm_legit_functional : [num_users=3] = call_function[target=torch.ops.aten._native_batch_norm_legit_functional.default](args = (%convolution, %p_bn_weight, %p_bn_bias, %b_bn_running_mean, %b_bn_running_var, True, 0.1, 1e-05), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%_native_batch_norm_legit_functional, 0), kwargs = {})
    %getitem_3 : [num_users=1] = call_function[target=operator.getitem](args = (%_native_batch_norm_legit_functional, 3), kwargs = {})
    %getitem_4 : [num_users=1] = call_function[target=operator.getitem](args = (%_native_batch_norm_legit_functional, 4), kwargs = {})
    return (getitem_3, getitem_4, add, getitem)

Notice that torch.ops.aten.conv2d.default has now been decomposed into torch.ops.aten.convolution.default. This is because convolution is a more "core" operator, in that operations like conv1d and conv2d can be implemented using the same op.

We can also specify our own decomposition behaviors:

my_decomp_table = torch.export.default_decompositions()

def my_awesome_custom_conv2d_function(x, weight, bias, stride=[1, 1], padding=[0, 0], dilation=[1, 1], groups=1):
    return 2 * torch.ops.aten.convolution(x, weight, bias, stride, padding, dilation, False, [0, 0], groups)

my_decomp_table[torch.ops.aten.conv2d.default] = my_awesome_custom_conv2d_function
my_ep = ep_for_training.run_decompositions(my_decomp_table)
print(my_ep.graph)
graph():
    %p_conv_weight : [num_users=1] = placeholder[target=p_conv_weight]
    %p_conv_bias : [num_users=1] = placeholder[target=p_conv_bias]
    %p_bn_weight : [num_users=1] = placeholder[target=p_bn_weight]
    %p_bn_bias : [num_users=1] = placeholder[target=p_bn_bias]
    %b_bn_running_mean : [num_users=1] = placeholder[target=b_bn_running_mean]
    %b_bn_running_var : [num_users=1] = placeholder[target=b_bn_running_var]
    %b_bn_num_batches_tracked : [num_users=1] = placeholder[target=b_bn_num_batches_tracked]
    %x : [num_users=1] = placeholder[target=x]
    %convolution : [num_users=1] = call_function[target=torch.ops.aten.convolution.default](args = (%x, %p_conv_weight, %p_conv_bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convolution, 2), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%b_bn_num_batches_tracked, 1), kwargs = {})
    %_native_batch_norm_legit_functional : [num_users=3] = call_function[target=torch.ops.aten._native_batch_norm_legit_functional.default](args = (%mul, %p_bn_weight, %p_bn_bias, %b_bn_running_mean, %b_bn_running_var, True, 0.1, 1e-05), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%_native_batch_norm_legit_functional, 0), kwargs = {})
    %getitem_3 : [num_users=1] = call_function[target=operator.getitem](args = (%_native_batch_norm_legit_functional, 3), kwargs = {})
    %getitem_4 : [num_users=1] = call_function[target=operator.getitem](args = (%_native_batch_norm_legit_functional, 4), kwargs = {})
    return (getitem_3, getitem_4, add, getitem)

Notice that instead of torch.ops.aten.conv2d.default being decomposed into torch.ops.aten.convolution.default alone, it is now decomposed into torch.ops.aten.convolution.default and torch.ops.aten.mul.Tensor, which matches our custom decomposition rule.

ExportDB#

torch.export will only ever export a single computation graph from a PyTorch program. Because of this requirement, there will be Python or PyTorch features that are not compatible with torch.export, which will require users to rewrite parts of their model code. We have seen examples of this earlier in the tutorial -- for example, rewriting if-statements using cond.

ExportDB is the standard reference that documents supported and unsupported Python/PyTorch features for torch.export. It is essentially a list of program samples, each of which represents the usage of one particular Python/PyTorch feature and its interaction with torch.export. Examples are also tagged by category so that they can be more easily searched.

For example, let's use ExportDB to get a better understanding of how the predicate works in the cond operator. We can look at the example called cond_predicate, which has a torch.cond tag. The example code looks like:

def cond_predicate(x):
    """
    The conditional statement (aka predicate) passed to ``cond()`` must be one of the following:
    - ``torch.Tensor`` with a single element
    - boolean expression
    NOTE: If the `pred` is tested on a dim with batch size < 2, it will be specialized.
    """
    pred = x.dim() > 2 and x.shape[2] > 10
    return cond(pred, lambda x: x.cos(), lambda y: y.sin(), [x])

More generally, ExportDB can be used as a reference when one of the following occurs:

  1. Before attempting torch.export, you know ahead of time that your model uses some tricky Python/PyTorch features, and you want to know if torch.export covers them.

  2. When attempting torch.export, there is a failure and it's unclear how to work around it.

ExportDB is not exhaustive, but is intended to cover all use cases found in typical PyTorch code. Feel free to reach out if there is an important Python/PyTorch feature that should be added to ExportDB or supported by torch.export.

Running the Exported Program#

As torch.export is only a graph capturing mechanism, calling the artifact produced by torch.export eagerly will be equivalent to running the eager module. To optimize the execution of the Exported Program, we can pass this exported artifact to backends such as Inductor through torch.compile, AOTInductor, or TensorRT.

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(3, 3)

    def forward(self, x):
        x = self.linear(x)
        return x

inp = torch.randn(2, 3, device="cuda")
m = M().to(device="cuda")
ep = torch.export.export(m, (inp,))

# Run it eagerly
res = ep.module()(inp)
print(res)

# Run it with torch.compile
res = torch.compile(ep.module(), backend="inductor")(inp)
print(res)
I1015 19:14:17.072000 18276 torch/fx/experimental/symbolic_shapes.py:3769] create_env
I1015 19:14:17.085000 18276 torch/fx/experimental/symbolic_shapes.py:5242] produce_guards
V1015 19:14:17.086000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['x'].size()[0] 2 None
V1015 19:14:17.086000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['x'].size()[1] 3 None
V1015 19:14:17.086000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['x'].stride()[0] 3 None
V1015 19:14:17.087000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['x'].stride()[1] 1 None
V1015 19:14:17.087000 18276 torch/fx/experimental/symbolic_shapes.py:5462] track_symint L['x'].storage_offset() 0 None
tensor([[ 1.1126,  0.1263,  0.8522],
        [ 0.2973, -1.2118,  1.0370]], device='cuda:0',
       grad_fn=<AddmmBackward0>)
I1015 19:14:18.236000 18276 torch/fx/experimental/symbolic_shapes.py:3769] [2/0] create_env
/usr/local/lib/python3.10/dist-packages/torch/backends/cuda/__init__.py:131: UserWarning:

Please use the new API settings to control TF32 behavior, such as torch.backends.cudnn.conv.fp32_precision = 'tf32' or torch.backends.cuda.matmul.fp32_precision = 'ieee'. Old settings, e.g, torch.backends.cuda.matmul.allow_tf32 = True, torch.backends.cudnn.allow_tf32 = True, allowTF32CuDNN() and allowTF32CuBLAS() will be deprecated after Pytorch 2.9. Please see https://pytorch.com.tw/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices (Triggered internally at /pytorch/aten/src/ATen/Context.cpp:80.)

/usr/local/lib/python3.10/dist-packages/torch/_inductor/compile_fx.py:312: UserWarning:

TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. Consider setting `torch.set_float32_matmul_precision('high')` for better performance.

I1015 19:14:19.222000 18276 torch/fx/experimental/symbolic_shapes.py:5242] [2/0] produce_guards
I1015 19:14:19.233000 18276 torch/fx/experimental/symbolic_shapes.py:5242] [2/0] produce_guards
V1015 19:14:19.233000 18276 torch/fx/experimental/symbolic_shapes.py:5462] [2/0] track_symint L['x'].size()[0] 2 None
V1015 19:14:19.234000 18276 torch/fx/experimental/symbolic_shapes.py:5462] [2/0] track_symint L['x'].size()[1] 3 None
V1015 19:14:19.234000 18276 torch/fx/experimental/symbolic_shapes.py:5462] [2/0] track_symint L['x'].stride()[0] 3 None
V1015 19:14:19.234000 18276 torch/fx/experimental/symbolic_shapes.py:5462] [2/0] track_symint L['x'].stride()[1] 1 None
V1015 19:14:19.234000 18276 torch/fx/experimental/symbolic_shapes.py:5462] [2/0] track_symint L['x'].storage_offset() 0 None
V1015 19:14:19.235000 18276 torch/fx/experimental/symbolic_shapes.py:5462] [2/0] track_symint L['self']._modules['linear']._parameters['weight'].size()[0] 3 None
V1015 19:14:19.235000 18276 torch/fx/experimental/symbolic_shapes.py:5462] [2/0] track_symint L['self']._modules['linear']._parameters['weight'].size()[1] 3 None
V1015 19:14:19.235000 18276 torch/fx/experimental/symbolic_shapes.py:5462] [2/0] track_symint L['self']._modules['linear']._parameters['weight'].stride()[0] 3 None
V1015 19:14:19.235000 18276 torch/fx/experimental/symbolic_shapes.py:5462] [2/0] track_symint L['self']._modules['linear']._parameters['weight'].stride()[1] 1 None
V1015 19:14:19.236000 18276 torch/fx/experimental/symbolic_shapes.py:5462] [2/0] track_symint L['self']._modules['linear']._parameters['weight'].storage_offset() 0 None
V1015 19:14:19.236000 18276 torch/fx/experimental/symbolic_shapes.py:5462] [2/0] track_symint L['self']._modules['linear']._parameters['bias'].size()[0] 3 None
V1015 19:14:19.236000 18276 torch/fx/experimental/symbolic_shapes.py:5462] [2/0] track_symint L['self']._modules['linear']._parameters['bias'].stride()[0] 1 None
V1015 19:14:19.236000 18276 torch/fx/experimental/symbolic_shapes.py:5462] [2/0] track_symint L['self']._modules['linear']._parameters['bias'].storage_offset() 0 None
V1015 19:14:19.237000 18276 torch/fx/experimental/symbolic_shapes.py:5675] [2/0] Skipping guard L['x'].size()[0] == 2
V1015 19:14:19.237000 18276 torch/fx/experimental/symbolic_shapes.py:5675] [2/0] Skipping guard L['x'].size()[1] == 3
V1015 19:14:19.237000 18276 torch/fx/experimental/symbolic_shapes.py:5675] [2/0] Skipping guard L['x'].stride()[0] == 3
V1015 19:14:19.237000 18276 torch/fx/experimental/symbolic_shapes.py:5675] [2/0] Skipping guard L['x'].stride()[1] == 1
V1015 19:14:19.238000 18276 torch/fx/experimental/symbolic_shapes.py:5675] [2/0] Skipping guard L['x'].storage_offset() == 0
V1015 19:14:19.238000 18276 torch/fx/experimental/symbolic_shapes.py:5675] [2/0] Skipping guard L['self']._modules['linear']._parameters['weight'].size()[0] == 3
V1015 19:14:19.238000 18276 torch/fx/experimental/symbolic_shapes.py:5675] [2/0] Skipping guard L['self']._modules['linear']._parameters['weight'].size()[1] == 3
V1015 19:14:19.238000 18276 torch/fx/experimental/symbolic_shapes.py:5675] [2/0] Skipping guard L['self']._modules['linear']._parameters['weight'].stride()[0] == 3
V1015 19:14:19.239000 18276 torch/fx/experimental/symbolic_shapes.py:5675] [2/0] Skipping guard L['self']._modules['linear']._parameters['weight'].stride()[1] == 1
V1015 19:14:19.239000 18276 torch/fx/experimental/symbolic_shapes.py:5675] [2/0] Skipping guard L['self']._modules['linear']._parameters['weight'].storage_offset() == 0
V1015 19:14:19.239000 18276 torch/fx/experimental/symbolic_shapes.py:5675] [2/0] Skipping guard L['self']._modules['linear']._parameters['bias'].size()[0] == 3
V1015 19:14:19.239000 18276 torch/fx/experimental/symbolic_shapes.py:5675] [2/0] Skipping guard L['self']._modules['linear']._parameters['bias'].stride()[0] == 1
V1015 19:14:19.240000 18276 torch/fx/experimental/symbolic_shapes.py:5675] [2/0] Skipping guard L['self']._modules['linear']._parameters['bias'].storage_offset() == 0
tensor([[ 1.1126,  0.1263,  0.8522],
        [ 0.2973, -1.2118,  1.0370]], device='cuda:0',
       grad_fn=<CompiledFunctionBackward>)
import torch._inductor

# Note: these APIs are subject to change
# Compile the exported program to a PT2 archive using ``AOTInductor``
with torch.no_grad():
    pt2_path = torch._inductor.aoti_compile_and_package(ep)

# Load and run the .so file in Python.
# To load and run it in a C++ environment, see:
# https://pytorch.com.tw/docs/stable/torch.compiler_aot_inductor.html
aoti_compiled = torch._inductor.aoti_load_package(pt2_path)
res = aoti_compiled(inp)

Conclusion#

We introduced torch.export, the new PyTorch 2.X way to export single computation graphs from PyTorch programs. In particular, we demonstrated several code modifications and considerations (control flow ops, constraints, etc.) that need to be made in order to export a graph.

Total running time of the script: (0 minutes 8.501 seconds)