評價此頁

torch.export API 參考#

創建於: 2025年7月17日 | 最後更新於: 2025年7月17日

torch.export.export(mod, args, kwargs=None, *, dynamic_shapes=None, strict=False, preserve_module_call_signature=(), prefer_deferred_runtime_asserts_over_guards=False)[source]#

export() 接收任意 nn.Module 和示例輸入,以提前(AOT)方式生成一個僅表示函式張量計算的跟蹤圖,之後可以使用不同的輸入執行或序列化該圖。跟蹤圖 (1) 在函式式 ATen 運算子集中生成規範化的運算子(以及使用者指定的任何自定義運算子),(2) 消除了所有 Python 控制流和資料結構(某些例外情況除外),並且 (3) 記錄了顯示這種規範化和控制流消除對於未來輸入是健全的形狀約束。

健全性保證

在跟蹤期間,export() 會記錄使用者程式和底層 PyTorch 運算子核心所做的與形狀相關的假設。只有當這些假設成立時,生成的 ExportedProgram 才被認為是有效的。

跟蹤會做出關於輸入張量形狀(而非值)的假設。為了使 export() 成功,這些假設必須在圖捕獲時進行驗證。具體來說:

  • 對輸入張量靜態形狀的假設無需額外工作即可自動驗證。

  • 對輸入張量動態形狀的假設需要透過使用 Dim() API 來構造動態維度,並透過 dynamic_shapes 引數將其與示例輸入關聯來顯式指定。

如果任何假設無法驗證,將引發致命錯誤。發生這種情況時,錯誤訊息將包含對驗證假設所需的規範的建議修復。例如,export() 可能會為動態維度 dim0_x 的定義提出以下修復,該維度出現在輸入 x 的形狀中,該維度以前定義為 Dim("dim0_x")

dim = Dim("dim0_x", max=5)

此示例意味著生成的程式碼要求輸入 x 的維度 0 小於或等於 5 才能有效。您可以檢查動態維度定義的建議修復,然後將其逐字複製到您的程式碼中,而無需更改傳遞給 export() 呼叫的 dynamic_shapes 引數。

引數
  • mod (Module) – 我們將跟蹤此模組的 forward 方法。

  • args (tuple[Any, ...]) – 示例位置輸入。

  • kwargs (Optional[Mapping[str, Any]]) – 可選的示例關鍵字輸入。

  • dynamic_shapes (Optional[Union[dict[str, Any], tuple[Any, ...], list[Any]]]) –

    一個可選引數,其型別應為:1)一個字典,從 f 的引數名稱對映到其動態形狀規範,2)一個元組,指定按原始順序排列的每個輸入的動態形狀規範。如果您正在為關鍵字引數指定動態性,則需要按照函式原始簽名中定義的順序傳遞它們。

    張量引數的動態形狀可以指定為:1)一個從動態維度索引到 Dim() 型別的字典,其中不需要在此字典中包含靜態維度索引,但當它們存在時,應將其對映到 None;或 2)一個 Dim() 型別或 None 的元組/列表,其中 Dim() 型別對應動態維度,靜態維度由 None 表示。由字典或張量元組/列表組成的引數透過使用包含的規範的對映或序列來遞迴指定。

  • strict (bool) – 當停用(預設)時,export 函式將透過 Python 執行時跟蹤程式,這本身不會驗證圖中的一些隱式假設。它仍然會驗證大多數關鍵假設,例如形狀安全性。當啟用(透過設定 strict=True)時,export 函式將透過 TorchDynamo 跟蹤程式,這將確保生成圖的健全性。TorchDynamo 對 Python 特性的覆蓋有限,因此您可能會遇到更多錯誤。請注意,切換此引數不會影響生成的 IR 規範,模型將以相同的方式序列化,無論此處傳遞什麼值。

  • preserve_module_call_signature (tuple[str, ...]) – 一個子模組路徑列表,將保留其原始呼叫約定作為元資料。呼叫 torch.export.unflatten 時將使用元資料來保留模組的原始呼叫約定。

返回

一個包含跟蹤的可呼叫物件的 ExportedProgram

返回型別

ExportedProgram

可接受的輸入/輸出型別

可接受的輸入(用於 argskwargs)和輸出型別包括:

  • 基本型別,即 torch.Tensorintfloatboolstr

  • 資料類,但它們必須首先透過呼叫 register_dataclass() 進行註冊。

  • (巢狀的) 資料結構,由 dictlisttuplenamedtupleOrderedDict 組成,其中包含上述所有型別。

class torch.export.ExportedProgram(root, graph, graph_signature, state_dict, range_constraints, module_call_graph, example_inputs=None, constants=None, *, verifiers=None)[source]#

來自 export() 的程式包。它包含一個 torch.fx.Graph,該圖表示張量計算,一個包含所有提升的引數和緩衝區張量值的 state_dict,以及各種元資料。

您可以使用與 export() 跟蹤的原始可呼叫物件相同的呼叫約定來呼叫 ExportedProgram。

要對圖執行轉換,請使用 `.module` 屬性訪問 torch.fx.GraphModule。然後,您可以使用 FX 轉換 來重寫圖。之後,您只需再次使用 export() 即可構建一個正確的 ExportedProgram。

buffers()[source]#

返回原始模組緩衝區的迭代器。

警告

此 API 仍處於實驗階段,並且 *不* 向後相容。

返回型別

Iterator[Tensor]

property call_spec#

警告

此 API 仍處於實驗階段,並且 *不* 向後相容。

property constants#

警告

此 API 仍處於實驗階段,並且 *不* 向後相容。

property dialect: str#

警告

此 API 仍處於實驗階段,並且 *不* 向後相容。

property example_inputs#

警告

此 API 仍處於實驗階段,並且 *不* 向後相容。

property graph#

警告

此 API 仍處於實驗階段,並且 *不* 向後相容。

property graph_module#

警告

此 API 仍處於實驗階段,並且 *不* 向後相容。

property graph_signature#

警告

此 API 仍處於實驗階段,並且 *不* 向後相容。

module(check_guards=True)[source]#

返回一個自包含的 GraphModule,其中所有引數/緩衝區都已內聯。

  • check_guards=True (預設) 時,將生成一個 _guards_fn 子模組,並在圖中的佔位符之後插入一個對 _guards_fn 子模組的呼叫。此模組檢查輸入的 guard。

  • check_guards=False 時,一部分檢查將由圖模組的 forward pre-hook 執行。不會生成 _guards_fn 子模組。

返回型別

GraphModule

property module_call_graph#

警告

此 API 仍處於實驗階段,並且 *不* 向後相容。

named_buffers()[source]#

返回一個迭代器,其中包含原始模組緩衝區,同時生成緩衝區的名稱以及緩衝區本身。

警告

此 API 仍處於實驗階段,並且 *不* 向後相容。

返回型別

Iterator[tuple[str, torch.Tensor]]

named_parameters()[source]#

返回一個迭代器,其中包含原始模組引數,同時生成引數的名稱以及引數本身。

警告

此 API 仍處於實驗階段,並且 *不* 向後相容。

返回型別

Iterator[tuple[str, torch.nn.parameter.Parameter]]

parameters()[source]#

返回原始模組引數的迭代器。

警告

此 API 仍處於實驗階段,並且 *不* 向後相容。

返回型別

Iterator[Parameter]

property range_constraints#

警告

此 API 仍處於實驗階段,並且 *不* 向後相容。

run_decompositions(decomp_table=None, decompose_custom_triton_ops=False)[source]#

對匯出的程式執行一組分解,並返回一個新的匯出的程式。預設情況下,我們將執行核心 ATen 分解,以在 Core ATen Operator Set 中獲得運算子。

目前,我們不分解聯合圖。

引數

decomp_table (Optional[dict[torch._ops.OperatorBase, Callable]]) – 一個可選引數,指定 Aten ops 的分解行為 (1) 如果為 None,我們分解為核心 aten 分解 (2) 如果為空,我們不分解任何運算子

返回型別

ExportedProgram

一些例子

如果您不想分解任何內容

ep = torch.export.export(model, ...)
ep = ep.run_decompositions(decomp_table={})

如果您想獲取核心 aten 運算子集,但排除某些運算子,您可以這樣做:

ep = torch.export.export(model, ...)
decomp_table = torch.export.default_decompositions()
decomp_table[your_op] = your_custom_decomp
ep = ep.run_decompositions(decomp_table=decomp_table)
property state_dict#

警告

此 API 仍處於實驗階段,並且 *不* 向後相容。

property tensor_constants#

警告

此 API 仍處於實驗階段,並且 *不* 向後相容。

validate()[source]#

警告

此 API 仍處於實驗階段,並且 *不* 向後相容。

property verifier: Any#

警告

此 API 仍處於實驗階段,並且 *不* 向後相容。

property verifiers#

警告

此 API 仍處於實驗階段,並且 *不* 向後相容。

class torch.export.dynamic_shapes.AdditionalInputs[source]#

根據附加輸入推斷 dynamic_shapes。

這對於部署工程師特別有用,他們一方面可能擁有充足的測試或分析資料,可以提供對模型代表性輸入的良好認識,但另一方面,他們可能對模型瞭解不夠,無法猜測哪些輸入形狀應該是動態的。

與原始輸入不同的輸入形狀被視為動態;反之,與原始輸入相同的形狀被視為靜態。此外,我們驗證附加輸入對於匯出的程式是有效的。這保證了用它們代替原始輸入進行跟蹤會生成相同的圖。

示例

args0, kwargs0 = ...  # example inputs for export

# other representative inputs that the exported program will run on
dynamic_shapes = torch.export.AdditionalInputs()
dynamic_shapes.add(args1, kwargs1)
...
dynamic_shapes.add(argsN, kwargsN)

torch.export(..., args0, kwargs0, dynamic_shapes=dynamic_shapes)
add(args, kwargs=None)[source]#

附加輸入 args()kwargs()

dynamic_shapes(m, args, kwargs=None)[source]#

透過合併原始輸入 args()kwargs() 以及每個附加輸入 args 和 kwargs 的形狀,推斷出 dynamic_shapes() 的 Pytree 結構。

verify(ep)[source]#

驗證匯出的程式對於每個附加輸入是否有效。

class torch.export.dynamic_shapes.Dim(name, *, min=None, max=None)[source]#

Dim 類允許使用者在匯出的程式中指定動態性。透過用 Dim 標記一個維度,編譯器會將該維度與包含動態範圍的符號整數關聯起來。

該 API 可以以兩種方式使用:Dim 提示(即自動動態形狀:Dim.AUTODim.DYNAMICDim.STATIC)或命名 Dim(即 Dim("name", min=1, max=2))。

Dim 提示提供了匯出能力的最低門檻,使用者只需指定維度是動態的、靜態的,還是由編譯器決定(Dim.AUTO)。匯出過程將自動推斷剩餘的關於最小/最大範圍以及維度之間關係的約束。

示例

class Foo(nn.Module):
    def forward(self, x, y):
        assert x.shape[0] == 4
        assert y.shape[0] >= 16
        return x @ y


x = torch.randn(4, 8)
y = torch.randn(8, 16)
dynamic_shapes = {
    "x": {0: Dim.AUTO, 1: Dim.AUTO},
    "y": {0: Dim.AUTO, 1: Dim.AUTO},
}
ep = torch.export(Foo(), (x, y), dynamic_shapes=dynamic_shapes)

在這裡,如果我們用 Dim.DYNAMIC 替換所有 Dim.AUTO 的用法,匯出將引發異常,因為模型已將 x.shape[0] 約束為靜態。

維度之間更復雜的關係也可能被編譯器編碼為執行時斷言節點,例如 (x.shape[0] + y.shape[1]) % 4 == 0,如果執行時輸入不滿足這些約束,將引發該斷言。

您還可以為 Dim 提示指定最小-最大邊界,例如 Dim.AUTO(min=16, max=32)Dim.DYNAMIC(max=64),編譯器將在這些範圍內的剩餘約束進行推斷。如果有效範圍完全超出使用者指定的範圍,將引發異常。

命名 Dim 提供了一種更嚴格的方式來指定動態性,如果編譯器推斷出的約束與使用者規範不匹配,則會引發異常。例如,匯出之前的模型,使用者將需要以下 dynamic_shapes 引數。

s0 = Dim("s0")
s1 = Dim("s1", min=16)
dynamic_shapes = {
    "x": {0: 4, 1: s0},
    "y": {0: s0, 1: s1},
}
ep = torch.export(Foo(), (x, y), dynamic_shapes=dynamic_shapes)

命名 Dim 還允許指定維度之間的關係,最多為單變數線性關係。例如,以下表示一個維度是另一個維度的倍數加 4。

s0 = Dim("s0")
s1 = 3 * s0 + 4
class torch.export.dynamic_shapes.ShapesCollection[source]#

dynamic_shapes 的構建器。用於為輸入中出現的張量分配動態形狀規範。

這特別有用,當 args() 是巢狀輸入結構時,索引輸入張量比在 dynamic_shapes() 規範中複製 args() 的結構更容易。

示例

args = {"x": tensor_x, "others": [tensor_y, tensor_z]}

dim = torch.export.Dim(...)
dynamic_shapes = torch.export.ShapesCollection()
dynamic_shapes[tensor_x] = (dim, dim + 1, 8)
dynamic_shapes[tensor_y] = {0: dim * 2}
# This is equivalent to the following (now auto-generated):
# dynamic_shapes = {"x": (dim, dim + 1, 8), "others": [{0: dim * 2}, None]}

torch.export(..., args, dynamic_shapes=dynamic_shapes)

要為整數指定動態性,我們需要先用 `_IntWrapper` 包裝整數,這樣我們就可以為每個整數擁有一個“唯一識別符號”。

示例

args = {"x": tensor_x, "others": [int_x, int_y]}
# Wrap all ints with _IntWrapper
mapped_args = pytree.tree_map_only(int, lambda a: _IntWrapper(a), args)

dynamic_shapes = torch.export.ShapesCollection()
dynamic_shapes[tensor_x] = (dim, dim + 1, 8)
dynamic_shapes[mapped_args["others"][0]] = Dim.DYNAMIC

# This is equivalent to the following (now auto-generated):
# dynamic_shapes = {"x": (dim, dim + 1, 8), "others": [Dim.DYNAMIC, None]}

torch.export(..., args, dynamic_shapes=dynamic_shapes)
dynamic_shapes(m, args, kwargs=None)[source]#

根據 args()kwargs() 生成 dynamic_shapes() Pytree 結構。

torch.export.dynamic_shapes.refine_dynamic_shapes_from_suggested_fixes(msg, dynamic_shapes)[source]#

當使用 dynamic_shapes() 進行匯出時,如果規範與從跟蹤模型推斷出的約束不匹配,匯出可能會因 ConstraintViolation 錯誤而失敗。錯誤訊息可能會提供建議的修復 - 可以對 dynamic_shapes() 進行的更改,以成功匯出。

示例 ConstraintViolation 錯誤訊息

Suggested fixes:

    dim = Dim('dim', min=3, max=6)  # this just refines the dim's range
    dim = 4  # this specializes to a constant
    dy = dx + 1  # dy was specified as an independent dim, but is actually tied to dx with this relation

這是一個輔助函式,它接收 ConstraintViolation 錯誤訊息和原始 dynamic_shapes() 規範,並返回一個包含建議修復的新 dynamic_shapes() 規範。

使用示例

try:
    ep = export(mod, args, dynamic_shapes=dynamic_shapes)
except torch._dynamo.exc.UserError as exc:
    new_shapes = refine_dynamic_shapes_from_suggested_fixes(
        exc.msg, dynamic_shapes
    )
    ep = export(mod, args, dynamic_shapes=new_shapes)
返回型別

Union[dict[str, Any], tuple[Any], list[Any]]

torch.export.save(ep, f, *, extra_files=None, opset_version=None, pickle_protocol=2)[source]#

警告

正在積極開發中,儲存的檔案可能無法在新版本的 PyTorch 中使用。

ExportedProgram 儲存到檔案類物件。然後可以使用 Python API torch.export.load 載入它。

引數
  • ep (ExportedProgram) – 要儲存的匯出的程式。

  • f (str | os.PathLike[str] | IO[bytes]) – 實現 write 和 flush 的檔案物件,或包含檔名的字串。

  • extra_files (Optional[Dict[str, Any]]) – 從檔名到內容的對映,將作為 f 的一部分儲存。

  • opset_version (Optional[Dict[str, int]]) – 一個 opset 名稱到該 opset 版本的對映

  • pickle_protocol (int) – 可以指定以覆蓋預設協議

示例

import torch
import io


class MyModule(torch.nn.Module):
    def forward(self, x):
        return x + 10


ep = torch.export.export(MyModule(), (torch.randn(5),))

# Save to file
torch.export.save(ep, "exported_program.pt2")

# Save to io.BytesIO buffer
buffer = io.BytesIO()
torch.export.save(ep, buffer)

# Save with extra files
extra_files = {"foo.txt": b"bar".decode("utf-8")}
torch.export.save(ep, "exported_program.pt2", extra_files=extra_files)
torch.export.load(f, *, extra_files=None, expected_opset_version=None)[source]#

警告

正在積極開發中,儲存的檔案可能無法在新版本的 PyTorch 中使用。

載入之前使用 torch.export.save 儲存的 ExportedProgram

引數
  • f (str | os.PathLike[str] | IO[bytes]) – 檔案類物件(必須實現 write 和 flush)或包含檔名的字串。

  • extra_files (Optional[Dict[str, Any]]) – 此對映中提供的附加檔名將被載入,其內容將儲存在提供的對映中。

  • expected_opset_version (Optional[Dict[str, int]]) – 一個 opset 名稱到預期 opset 版本的對映

返回

一個 ExportedProgram 物件

返回型別

ExportedProgram

示例

import torch
import io

# Load ExportedProgram from file
ep = torch.export.load("exported_program.pt2")

# Load ExportedProgram from io.BytesIO object
with open("exported_program.pt2", "rb") as f:
    buffer = io.BytesIO(f.read())
buffer.seek(0)
ep = torch.export.load(buffer)

# Load with extra files.
extra_files = {"foo.txt": ""}  # values will be replaced with data
ep = torch.export.load("exported_program.pt2", extra_files=extra_files)
print(extra_files["foo.txt"])
print(ep(torch.randn(5)))
torch.export.pt2_archive._package.package_pt2(f, *, exported_programs=None, aoti_files=None, extra_files=None, opset_version=None, pickle_protocol=2)[source]#

將工件儲存為 PT2Archive 格式。該工件隨後可以使用 `load_pt2` 載入。

引數
  • f (str | os.PathLike[str] | IO[bytes]) – 檔案類物件(必須實現 write 和 flush)或包含檔名的字串。

  • exported_programs (Union[ExportedProgram, dict[str, ExportedProgram]]) – 要儲存的匯出的程式,或者是一個將模型名稱對映到匯出的程式的字典。匯出的程式將儲存在 models/*.json 下。如果只指定了一個 ExportedProgram,它將自動命名為“model”。

  • aoti_files (Union[list[str], dict[str, list[str]]]) – 由 AOTInductor 透過 torch._inductor.aot_compile(..., {"aot_inductor.package": True}) 生成的檔案列表,或者是一個將模型名稱對映到其 AOTInductor 生成檔案的字典。如果只指定了一組檔案,它將自動命名為“model”。

  • extra_files (Optional[Dict[str, Any]]) – 從檔名到內容的對映,將作為 pt2 的一部分儲存。

  • opset_version (Optional[Dict[str, int]]) – 一個 opset 名稱到該 opset 版本的對映

  • pickle_protocol (int) – 可以指定以覆蓋預設協議

返回型別

Union[str, PathLike[str], IO[bytes]]

torch.export.pt2_archive._package.load_pt2(f, *, expected_opset_version=None, run_single_threaded=False, num_runners=1, device_index=-1, load_weights_from_disk=False)[source]#

載入使用 `package_pt2` 儲存的所有工件。

引數
  • f (str | os.PathLike[str] | IO[bytes]) – 檔案類物件(必須實現 write 和 flush)或包含檔名的字串。

  • expected_opset_version (Optional[Dict[str, int]]) – 一個 opset 名稱到預期 opset 版本的對映

  • num_runners (int) – 載入 AOTInductor 工件的執行器數量

  • run_single_threaded (bool) – 模型是否應在沒有執行緒同步邏輯的情況下執行。這有助於避免與 CUDAGraphs 衝突。

  • device_index (int) – 將 PT2 包載入到的裝置索引。預設情況下,使用 device_index=-1,當使用 CUDA 時,它對應於 cuda 裝置。例如,傳遞 device_index=1 會將包載入到 cuda:1

返回

一個包含 PT2 中所有物件的 PT2ArchiveContents 物件。

返回型別

PT2ArchiveContents

torch.export.draft_export(mod, args, kwargs=None, *, dynamic_shapes=None, preserve_module_call_signature=(), strict=False, prefer_deferred_runtime_asserts_over_guards=False)[source]#

一個 `torch.export.export` 的版本,旨在始終如一地生成 ExportedProgram,即使存在潛在的健全性問題,並生成一份報告列出發現的問題。

返回型別

ExportedProgram

class torch.export.unflatten.FlatArgsAdapter[source]#

使用 `input_spec` 調整輸入引數,以匹配 `target_spec`。

abstract adapt(target_spec, input_spec, input_args, metadata=None, obj=None)[source]#

注意:此介面卡可能會修改給定的 `input_args_with_path`。

返回型別

list[Any]

get_flat_arg_paths()[source]#

返回用於訪問扁平引數的路徑列表。

返回型別

list[str]

class torch.export.unflatten.InterpreterModule(graph, ty=None)[source]#

一個使用 torch.fx.Interpreter 執行的模組,而不是 GraphModule 使用的常規程式碼生成。這提供了更好的堆疊跟蹤資訊,並使執行除錯更容易。

class torch.export.unflatten.InterpreterModuleDispatcher(attrs, call_modules)[source]#

一個模組,它攜帶與該模組呼叫序列相對應的 InterpreterModules 序列。每次呼叫模組時,它會分派到下一個 InterpreterModule,並在最後一個之後迴繞。

torch.export.unflatten.unflatten(module, flat_args_adapter=None)[source]#

取消扁平化 ExportedProgram,生成一個具有與原始 eager 模組相同模組層次結構的模組。如果您嘗試將 torch.export 與期望模組層次結構而不是 torch.export 通常生成的扁平圖的其他系統一起使用,這會很有用。

注意

未扁平化模組的 args/kwargs 不一定與 eager 模組匹配,因此進行模組交換(例如,self.submod = new_mod)不一定有效。如果您需要交換模組,則需要設定 torch.export.export()preserve_module_call_signature 引數。

引數
  • module (ExportedProgram) – 要取消扁平化的 ExportedProgram。

  • flat_args_adapter (Optional[FlatArgsAdapter]) – 如果輸入 TreeSpec 與匯出的模組不匹配,則調整扁平引數。

返回

一個 UnflattenedModule 例項,它具有與匯出現象之前的原始 eager 模組相同的模組層次結構。

返回型別

UnflattenedModule

torch.export.register_dataclass(cls, *, serialized_type_name=None)[source]#

將資料類註冊為 torch.export.export() 的有效輸入/輸出型別。

引數
  • cls (type[Any]) – 要註冊的資料類型別

  • serialized_type_name (Optional[str]) – 資料類的序列化名稱。這是

  • this (required if you want to serialize the pytree TreeSpec containing) – (如果要序列化包含資料類的 Pytree TreeSpec,則需要此項)

  • dataclass.

示例

import torch
from dataclasses import dataclass


@dataclass
class InputDataClass:
    feature: torch.Tensor
    bias: int


@dataclass
class OutputDataClass:
    res: torch.Tensor


torch.export.register_dataclass(InputDataClass)
torch.export.register_dataclass(OutputDataClass)


class Mod(torch.nn.Module):
    def forward(self, x: InputDataClass) -> OutputDataClass:
        res = x.feature + x.bias
        return OutputDataClass(res=res)


ep = torch.export.export(Mod(), (InputDataClass(torch.ones(2, 2), 1),))
print(ep)
class torch.export.decomp_utils.CustomDecompTable[source]#

這是一個自定義字典,專門用於處理 export 中的 decomp_table。我們需要這個是因為在新世界中,您只能 *刪除* 一個 op 來從 decomp table 中保留它。這對於自定義 op 來說是個問題,因為我們不知道自定義 op 何時才能實際載入到排程器中。因此,我們需要記錄自定義 op 操作,直到我們真正需要實現它(這發生在執行分解傳遞時)。

我們保持的不變數是:
  1. 所有 aten 分解都在初始化時載入

  2. 當用戶從表中讀取時,我們會實現 *所有* op,以便排程器更有可能拾取自定義 op。

  3. 如果是寫操作,我們不一定實現

  4. 我們將在呼叫 run_decompositions() 之前,在 export 的最後一次載入。

copy()[source]#
返回型別

CustomDecompTable

items()[source]#
keys()[source]#
materialize()[source]#
返回型別

dict[torch._ops.OperatorBase, Callable]

pop(*args)[source]#
update(other_dict)[source]#
torch.export.passes.move_to_device_pass(ep, location)[source]#

將匯出的程式移動到指定裝置。

引數
  • ep (ExportedProgram) – 要移動的匯出的程式。

  • location (Union[torch.device, str, Dict[str, str]]) – 要將匯出的程式移動到的裝置。如果為字串,則將其解釋為裝置名稱。如果是字典,則將其解釋為從現有裝置到目標裝置的對映。

返回

移動後的匯出的程式。

返回型別

ExportedProgram

class torch.export.pt2_archive.PT2ArchiveReader(archive_path_or_buffer)#

用於讀取 PT2 存檔的上下文管理器。

archive_version()[source]#

獲取存檔版本。

返回型別

int

get_file_names()[source]#

獲取存檔中的檔名。

返回型別

list[str]

read_bytes(name)[source]#

從存檔中讀取位元組物件。name:存檔內的原始檔名。

返回型別

位元組

read_string(name)[source]#

從存檔中讀取字串物件。name:存檔內的原始檔名。

返回型別

str

class torch.export.pt2_archive.PT2ArchiveWriter(archive_path_or_buffer)#

用於寫入 PT2 存檔的上下文管理器。

close()[source]#

關閉存檔。

count_prefix(prefix)[source]#

計算以給定字首開頭記錄的數量。

返回型別

int

has_record(name)[source]#

檢查存檔中是否存在記錄。

返回型別

布林值

write_bytes(name, data)[source]#

將位元組物件寫入存檔。name:存檔內的目標檔名。data:要寫入的位元組物件。

write_file(name, file_path)[source]#

將檔案複製到存檔中。name:存檔內的目標檔名。file_path:磁碟上的原始檔。

write_folder(archive_dir, folder_dir)[source]#

將資料夾複製到歸檔中。archive_dir:歸檔內的目標資料夾。folder_dir:磁碟上的原始檔夾。

write_string(name, data)[源]#

將字串物件寫入歸檔。name:歸檔內的目標檔名。data:要寫入的字串物件。

torch.export.pt2_archive.is_pt2_package(serialized_model)[源]#

檢查序列化模型是否為 PT2 歸檔包。

返回型別

布林值

class torch.export.exported_program.ModuleCallEntry(fqn: str, signature: Optional[torch.export.exported_program.ModuleCallSignature] = None)[源]#
class torch.export.exported_program.ModuleCallSignature(inputs: list[Union[torch.export.graph_signature.TensorArgument, torch.export.graph_signature.SymIntArgument, torch.export.graph_signature.SymFloatArgument, torch.export.graph_signature.SymBoolArgument, torch.export.graph_signature.ConstantArgument, torch.export.graph_signature.CustomObjArgument, torch.export.graph_signature.TokenArgument]], outputs: list[Union[torch.export.graph_signature.TensorArgument, torch.export.graph_signature.SymIntArgument, torch.export.graph_signature.SymFloatArgument, torch.export.graph_signature.SymBoolArgument, torch.export.graph_signature.ConstantArgument, torch.export.graph_signature.CustomObjArgument, torch.export.graph_signature.TokenArgument]], in_spec: torch.utils._pytree.TreeSpec, out_spec: torch.utils._pytree.TreeSpec, forward_arg_names: Optional[list[str]] = None)[源]#
torch.export.exported_program.default_decompositions()[源]#

這是預設的分解表,其中包含所有 ATEN 運算元到核心 aten opset 的分解。請將此 API 與 run_decompositions() 一起使用。

返回型別

CustomDecompTable

class torch.export.custom_obj.ScriptObjectMeta(constant_name, class_fqn)[源]#

儲存在代表 ScriptObjects 的節點上的元資料。

class torch.export.graph_signature.ConstantArgument(name: str, value: Union[int, float, bool, str, NoneType])[源]#
name: str#
value: Optional[Union[int, float, bool, str]]#
class torch.export.graph_signature.CustomObjArgument(name: str, class_fqn: str, fake_val: Optional[torch._library.fake_class_registry.FakeScriptObject] = None)[源]#
class_fqn: str#
fake_val: Optional[FakeScriptObject] = None#
name: str#
class torch.export.graph_signature.ExportBackwardSignature(gradients_to_parameters: dict[str, str], gradients_to_user_inputs: dict[str, str], loss_output: str)[源]#
gradients_to_parameters: dict[str, str]#
gradients_to_user_inputs: dict[str, str]#
loss_output: str#
class torch.export.graph_signature.ExportGraphSignature(input_specs, output_specs)[源]#

ExportGraphSignature 模擬 Export Graph 的輸入/輸出簽名,這是一個具有更強不變性保證的 fx.Graph。

Export Graph 是函式式的,不會透過 getattr 節點訪問圖內的“狀態”,例如引數或緩衝區。相反,export() 保證引數、緩衝區和常量張量被提升到圖外部作為輸入。類似地,對緩衝區的任何修改也不會包含在圖中,而是將修改後的緩衝區值建模為 Export Graph 的附加輸出。

所有輸入和輸出的順序是

Inputs = [*parameters_buffers_constant_tensors, *flattened_user_inputs]
Outputs = [*mutated_inputs, *flattened_user_outputs]

例如,如果匯出以下模組

class CustomModule(nn.Module):
    def __init__(self) -> None:
        super(CustomModule, self).__init__()

        # Define a parameter
        self.my_parameter = nn.Parameter(torch.tensor(2.0))

        # Define two buffers
        self.register_buffer("my_buffer1", torch.tensor(3.0))
        self.register_buffer("my_buffer2", torch.tensor(4.0))

    def forward(self, x1, x2):
        # Use the parameter, buffers, and both inputs in the forward method
        output = (
            x1 + self.my_parameter
        ) * self.my_buffer1 + x2 * self.my_buffer2

        # Mutate one of the buffers (e.g., increment it by 1)
        self.my_buffer2.add_(1.0)  # In-place addition

        return output


mod = CustomModule()
ep = torch.export.export(mod, (torch.tensor(1.0), torch.tensor(2.0)))

產生的圖是非函式式的

graph():
    %p_my_parameter : [num_users=1] = placeholder[target=p_my_parameter]
    %b_my_buffer1 : [num_users=1] = placeholder[target=b_my_buffer1]
    %b_my_buffer2 : [num_users=2] = placeholder[target=b_my_buffer2]
    %x1 : [num_users=1] = placeholder[target=x1]
    %x2 : [num_users=1] = placeholder[target=x2]
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x1, %p_my_parameter), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add, %b_my_buffer1), kwargs = {})
    %mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%x2, %b_my_buffer2), kwargs = {})
    %add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul, %mul_1), kwargs = {})
    %add_ : [num_users=0] = call_function[target=torch.ops.aten.add_.Tensor](args = (%b_my_buffer2, 1.0), kwargs = {})
    return (add_1,)

非函式式圖產生的 ExportGraphSignature 將是

# inputs
p_my_parameter: PARAMETER target='my_parameter'
b_my_buffer1: BUFFER target='my_buffer1' persistent=True
b_my_buffer2: BUFFER target='my_buffer2' persistent=True
x1: USER_INPUT
x2: USER_INPUT

# outputs
add_1: USER_OUTPUT

要獲得函式式圖,您可以使用 run_decompositions()

mod = CustomModule()
ep = torch.export.export(mod, (torch.tensor(1.0), torch.tensor(2.0)))
ep = ep.run_decompositions()

產生的圖是函式式的

graph():
    %p_my_parameter : [num_users=1] = placeholder[target=p_my_parameter]
    %b_my_buffer1 : [num_users=1] = placeholder[target=b_my_buffer1]
    %b_my_buffer2 : [num_users=2] = placeholder[target=b_my_buffer2]
    %x1 : [num_users=1] = placeholder[target=x1]
    %x2 : [num_users=1] = placeholder[target=x2]
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x1, %p_my_parameter), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add, %b_my_buffer1), kwargs = {})
    %mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%x2, %b_my_buffer2), kwargs = {})
    %add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul, %mul_1), kwargs = {})
    %add_2 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%b_my_buffer2, 1.0), kwargs = {})
    return (add_2, add_1)

函式式圖產生的 ExportGraphSignature 將是

# inputs
p_my_parameter: PARAMETER target='my_parameter'
b_my_buffer1: BUFFER target='my_buffer1' persistent=True
b_my_buffer2: BUFFER target='my_buffer2' persistent=True
x1: USER_INPUT
x2: USER_INPUT

# outputs
add_2: BUFFER_MUTATION target='my_buffer2'
add_1: USER_OUTPUT
property assertion_dep_token: Optional[Mapping[int, str]]#
property backward_signature: Optional[ExportBackwardSignature]#
property buffers: Collection[str]#
property buffers_to_mutate: Mapping[str, str]#
get_replace_hook(replace_inputs=False)[源]#
input_specs: list[torch.export.graph_signature.InputSpec]#
property input_tokens: Collection[str]#
property inputs_to_buffers: Mapping[str, str]#
property inputs_to_lifted_custom_objs: Mapping[str, str]#
property inputs_to_lifted_tensor_constants: Mapping[str, str]#
property inputs_to_parameters: Mapping[str, str]#
property lifted_custom_objs: Collection[str]#
property lifted_tensor_constants: Collection[str]#
property non_persistent_buffers: Collection[str]#
output_specs: list[torch.export.graph_signature.OutputSpec]#
property output_tokens: Collection[str]#
property parameters: Collection[str]#
property parameters_to_mutate: Mapping[str, str]#
replace_all_uses(old, new)[源]#

在簽名中用新名稱替換所有舊名稱的使用。

property user_inputs: Collection[Union[int, float, bool, None, str]]#
property user_inputs_to_mutate: Mapping[str, str]#
property user_outputs: Collection[Union[int, float, bool, None, str]]#
class torch.export.graph_signature.InputKind(value)[源]#

一個列舉。

BUFFER = 3#
CONSTANT_TENSOR = 4#
CUSTOM_OBJ = 5#
PARAMETER = 2#
TOKEN = 6#
USER_INPUT = 1#
class torch.export.graph_signature.InputSpec(kind: torch.export.graph_signature.InputKind, arg: Union[torch.export.graph_signature.TensorArgument, torch.export.graph_signature.SymIntArgument, torch.export.graph_signature.SymFloatArgument, torch.export.graph_signature.SymBoolArgument, torch.export.graph_signature.ConstantArgument, torch.export.graph_signature.CustomObjArgument, torch.export.graph_signature.TokenArgument], target: Optional[str], persistent: Optional[bool] = None)[源]#
arg: Union[TensorArgument, SymIntArgument, SymFloatArgument, SymBoolArgument, ConstantArgument, CustomObjArgument, TokenArgument]#
kind: InputKind#
persistent: Optional[bool] = None#
target: Optional[str]#
class torch.export.graph_signature.OutputKind(value)[源]#

一個列舉。

BUFFER_MUTATION = 3#
GRADIENT_TO_PARAMETER = 5#
GRADIENT_TO_USER_INPUT = 6#
LOSS_OUTPUT = 2#
PARAMETER_MUTATION = 4#
TOKEN = 8#
USER_INPUT_MUTATION = 7#
USER_OUTPUT = 1#
class torch.export.graph_signature.OutputSpec(kind: torch.export.graph_signature.OutputKind, arg: Union[torch.export.graph_signature.TensorArgument, torch.export.graph_signature.SymIntArgument, torch.export.graph_signature.SymFloatArgument, torch.export.graph_signature.SymBoolArgument, torch.export.graph_signature.ConstantArgument, torch.export.graph_signature.CustomObjArgument, torch.export.graph_signature.TokenArgument], target: Optional[str])[源]#
arg: Union[TensorArgument, SymIntArgument, SymFloatArgument, SymBoolArgument, ConstantArgument, CustomObjArgument, TokenArgument]#
kind: OutputKind#
target: Optional[str]#
class torch.export.graph_signature.SymBoolArgument(name: str)[源]#
name: str#
class torch.export.graph_signature.SymFloatArgument(name: str)[源]#
name: str#
class torch.export.graph_signature.SymIntArgument(name: str)[源]#
name: str#
class torch.export.graph_signature.TensorArgument(name: str)[源]#
name: str#
class torch.export.graph_signature.TokenArgument(name: str)[源]#
name: str#