評價此頁

帶描述符的聯合#

創建於:2025 年 8 月 11 日 | 最後更新於:2025 年 8 月 11 日

帶描述符的聯合 (Joint with descriptors) 是一個實驗性 API,用於匯出支援 `torch.compile` 所有功能的、最通用的追蹤聯合圖 (traced joint graph),並且在處理後可以轉換回可微分的可呼叫物件 (differentiable callable),以正常方式執行。例如,它用於實現 autoparallel,這是一個接受模型並重新劃分輸入和引數以使其成為分散式 SPMD 程式的系統。

torch._functorch.aot_autograd.aot_export_joint_with_descriptors(stack, mod, args, kwargs=None, *, decompositions=None, keep_inference_input_mutations=False, ignore_shape_env=False, fw_compiler=<function boxed_nop_preserve_node_meta>, bw_compiler=<function boxed_nop_preserve_node_meta>)[source]#

此 API 捕獲 `nn.Module` 的聯合圖。然而,與 `aot_export_joint_simple` 或 `aot_export_module(trace_joint=True)` 不同,生成的聯合圖的呼叫約定不遵循固定的位置模式;例如,您不能依賴於追蹤聯合圖的第二個引數對應於您追蹤的模組的第二個引數。但是,追蹤圖的輸入和輸出是用 **描述符 (descriptors)** 進行模式化的,這些描述符標註在佔位符和返回的 FX 節點上的 `meta['desc']` 中,您可以使用它們來確定引數的含義。

與 `aot_export_joint_simple` 相比,使用此匯出的主要好處是,我們擁有 `torch.compile` 支援的所有情況(透過 `aot_module_simplified`)的功能對等性,包括處理更復雜的情況,例如多個可微分輸出、必須在圖外部處理的輸入變異、張量子類等。

您可以使用帶有描述符的聯合圖做什麼?其主要用例(autoparallel)涉及獲取聯合圖,對其進行最佳化,然後將其轉換回可呼叫物件,以便稍後可以進行 `torch.compile`。由於兩個原因,這不能作為傳統的 `torch.compile` 聯合圖傳遞完成:

  1. 引數的切分 (sharding) 必須在引數初始化/檢查點載入之前決定,這遠早於 `torch.compile` 通常執行的時間。

  2. 我們需要改變引數的含義(例如,我們可能會用分片版本替換複製引數,從而改變其輸入大小)。`torch.compile` 通常是語義保留的,不允許更改輸入的含義。

一些描述符可能相當奇異,因此我們建議仔細考慮是否存在一個安全的後備方案可以應用於您不理解的描述符。例如,您應該有一些方法來處理在最終 FX 圖輸入中找不到完全相同的特定輸入的情況。

注意:使用此 API 時,您必須建立並進入 `ExitStack` 上下文管理器,並將其傳遞給此函式。如果您呼叫 `compile` 函式來完成編譯,則此上下文管理器必須保持活動狀態。(TODO:我們可能會放寬此要求,讓 AOTAutograd 能夠跟蹤如何稍後重建所有上下文管理器。)

注意:您不必在第二階段執行 /完整的/ 編譯;相反,您可以不指定前向/後向編譯器,在這種情況下,分割槽後的 FX 圖將直接執行。整體 `autograd.Function` 可以保留在圖中,以便您可以在(可能更大)已編譯區域的上下文中稍後重新處理它。

注意:這些 API **不** 命中快取,因為我們只快取最終的編譯結果,而不快取中間匯出結果。

注意:如果傳入的 `nn.Module` 具有引數和緩衝區,我們將生成額外的隱式引數/緩衝區引數,併為其分配 `ParamAOTInput` 和 `BufferAOTInput` 描述符。但是,如果您從 Dynamo 等機制生成輸入 `nn.Module`,則不會得到這些描述符(因為 Dynamo 已經處理了將引數/緩衝區提升為引數!)。在這種情況下,有必要分析輸入的 `Sources` 以確定輸入是否是引數及其 FQN。

返回型別

JointWithDescriptors

torch._functorch.aot_autograd.aot_compile_joint_with_descriptors(jd)[source]#

與 `aot_export_joint_with_descriptors` 配套的函式,它將聯合圖編譯成一個遵循標準呼叫約定的可呼叫函式。`params_flat` 都是引數。

注意:我們 **不** 例項化模組;這為您提供了子類化並自定義其行為的靈活性,而無需擔心 FQN 重新繫結。

TODO:考慮我們是否應預設允許在圖中 (allow_in_graph) 返回結果。

返回型別

callable

描述符#

class torch._functorch._aot_autograd.descriptors.AOTInput[source]#

描述來自 AOTAutograd 生成的 FX 圖的輸入的來源

is_buffer()[source]#

如果此輸入是緩衝區或派生自緩衝區(例如,子類屬性),則為 True

返回型別

布林值

is_param()[source]#

如果此輸入是引數或派生自引數(例如,子類屬性),則為 True

返回型別

布林值

is_tangent()[source]#

如果此輸入是切線 (tangent) 或派生自切線(例如,子類屬性),則為 True

返回型別

布林值

class torch._functorch._aot_autograd.descriptors.AOTOutput[source]#

描述 AOTAutograd 生成的 FX 圖的輸出最終將如何打包到最終輸出中

is_grad()[source]#

如果此輸出是梯度或派生自梯度(例如,子類屬性),則為 True

返回型別

布林值

class torch._functorch._aot_autograd.descriptors.BackwardTokenAOTInput(idx)[source]#

用於反向傳播的副作用操作的世界令牌 (world token)

class torch._functorch._aot_autograd.descriptors.BackwardTokenAOTOutput(idx)[source]#

副作用呼叫的世界令牌輸出,返回以便我們不會對其進行 DCE(死程式碼消除),僅用於反向傳播

class torch._functorch._aot_autograd.descriptors.BufferAOTInput(target)[source]#

輸入是緩衝區,其 FQN 為 target

class torch._functorch._aot_autograd.descriptors.DummyAOTInput(idx)[source]#

在某些情況下,我們希望呼叫一個期望 `AOTInput` 的函式,但我們實際上並不關心該邏輯(最典型的是,因為某些程式碼同時用於編譯時和執行時;在此情況下不需要 `AOTInput` 處理)。在這種情況下傳入一個 dummy;但最好是有一個根本沒有這個的函式版本。

class torch._functorch._aot_autograd.descriptors.DummyAOTOutput(idx)[source]#

在您實際上不關心描述符傳播的情況下,請勿在正常情況下使用。

class torch._functorch._aot_autograd.descriptors.GradAOTOutput(grad_of)[source]#

一個輸出,表示在聯合圖中為可微分輸入計算出的梯度

class torch._functorch._aot_autograd.descriptors.InputMutationAOTOutput(mutated_input)[source]#

輸入的變異值,返回以便我們能夠適當地傳播自動微分。

class torch._functorch._aot_autograd.descriptors.IntermediateBaseAOTOutput(base_of)[source]#

多個別名(aliasing)輸出的中間基。我們只報告一個貢獻給該基的輸出

class torch._functorch._aot_autograd.descriptors.ParamAOTInput(target)[source]#

輸入是引數,其 FQN 為 target

class torch._functorch._aot_autograd.descriptors.PhiloxBackwardBaseOffsetAOTInput[source]#

功能化的 Philox RNG 呼叫的偏移量,專用於後向圖。

class torch._functorch._aot_autograd.descriptors.PhiloxBackwardSeedAOTInput[source]#

功能化的 Philox RNG 呼叫的種子,專用於後向圖。

class torch._functorch._aot_autograd.descriptors.PhiloxForwardBaseOffsetAOTInput[source]#

功能化的 Philox RNG 呼叫的偏移量,專用於前向圖。

class torch._functorch._aot_autograd.descriptors.PhiloxForwardSeedAOTInput[source]#

功能化的 Philox RNG 呼叫的種子,專用於前向圖。

class torch._functorch._aot_autograd.descriptors.PhiloxUpdatedBackwardOffsetAOTOutput[source]#

功能化 RNG 呼叫的最終偏移量,僅用於後向傳播

class torch._functorch._aot_autograd.descriptors.PhiloxUpdatedForwardOffsetAOTOutput[source]#

功能化 RNG 呼叫的最終偏移量,僅用於前向傳播

class torch._functorch._aot_autograd.descriptors.PlainAOTInput(idx)[source]#

輸入是普通輸入,對應於特定的位置索引。

注意,`AOTInput` 始終相對於具有 **扁平** 呼叫約定的函式(例如 `aot_module_simplified` 接受的)。有一些 AOTAutograd API 會扁平化 pytrees,我們不記錄扁平化中的 PyTree 鍵路徑(但我們應該能夠!)。

class torch._functorch._aot_autograd.descriptors.PlainAOTOutput(idx)[source]#

輸出元組位置 `idx` 處的普通張量輸出

class torch._functorch._aot_autograd.descriptors.SavedForBackwardsAOTOutput(idx: int)[source]#
class torch._functorch._aot_autograd.descriptors.SubclassGetAttrAOTInput(base, attr)[source]#

子類輸入在進入 FX 圖之前會解包成其組成部分。這告訴您此輸入對應於子類(原始子類引數)的哪個特定屬性。

class torch._functorch._aot_autograd.descriptors.SubclassGetAttrAOTOutput(base, attr)[source]#

此輸出將被打包到此位置的子類中

class torch._functorch._aot_autograd.descriptors.SubclassSizeAOTInput(base, idx)[source]#

這個特定的外部大小 SymInt 輸入(在維度 idx 處)來自哪個子類。

class torch._functorch._aot_autograd.descriptors.SubclassSizeAOTOutput(base, idx)[source]#

此輸出大小將被打包到此位置的子類中

class torch._functorch._aot_autograd.descriptors.SubclassStrideAOTInput(base, idx)[source]#

這個特定的外部步幅 SymInt 輸入(在維度 idx 處)來自哪個子類。

class torch._functorch._aot_autograd.descriptors.SubclassStrideAOTOutput(base, idx)[source]#

此輸出步幅將被打包到此位置的子類中

class torch._functorch._aot_autograd.descriptors.SyntheticBaseAOTInput(base_of)[source]#

這與 `ViewBaseAOTInput` 類似,但當沒有檢視是可微分的時,我們會發生這種情況,因此我們無法獲取真正的原始檢視,而是為了自動微分而構造了一個合成檢視。

class torch._functorch._aot_autograd.descriptors.ViewBaseAOTInput(base_of)[source]#

當多個可微分輸入是同一輸入的檢視時,AOTAutograd 會將這些檢視替換為單個表示基的輸入。如果您不希望這樣,可以在將檢視示例輸入傳遞給 AOTAutograd 之前克隆它們。

TODO:原則上,我們可以報告所有貢獻給此基的輸入。

FX 工具#

此模組包含用於處理 AOTAutograd 生成的帶描述符的聯合 FX 圖的實用函式。它們**不**適用於通用 FX 圖。另請參閱 torch._functorch.aot_autograd.aot_export_joint_with_descriptors()。我們還建議閱讀 :mod:torch._functorch._aot_autograd.descriptors`。

torch._functorch._aot_autograd.fx_utils.get_all_input_and_grad_nodes(g)[source]#

給定一個帶描述符的聯合圖(佔位符和輸出上的 `meta['desc']`),返回每個輸入及其對應的梯度輸出節點(如果存在)。這些元組儲存在一個字典中,該字典由描述輸入的 `AOTInput` 描述符索引。

注意:返回 **所有** 前向張量輸入,包括不可微分輸入(這些輸入只有一個 `None` 梯度),因此安全地使用此函式來對所有輸入執行操作。(非張量輸入,如符號整數、令牌或 RNG 狀態,**不** 被此函式遍歷。)

引數

g (Graph) – 帶描述符的 FX 聯合圖

返回

一個字典,將每個 `DifferentiableAOTInput` 描述符對映到一個元組,該元組包含: - 輸入節點本身 - 梯度(輸出)節點(如果存在),否則為 `None`

引發
  • RuntimeError – 如果聯合圖包含子類張量輸入/輸出;此

  • API 不支援,因為當涉及子類時,輸入和梯度之間不一定存在一對一的對應關係

  • 當涉及子類時。

返回型別

dict[torch._functorch._aot_autograd.descriptors.DifferentiableAOTInput, tuple[torch.fx.node.Node, Optional[torch.fx.node.Node]]]

torch._functorch._aot_autograd.fx_utils.get_all_output_and_tangent_nodes(g)[source]#

從聯合圖中獲取所有輸出節點及其對應的切線節點。

與 `get_all_input_and_grad_nodes` 類似,但返回輸出節點與其切線節點配對(如果存在)。此函式遍歷圖以查詢所有可微分輸出,並將它們與其在正向模式自動微分中使用的相應切線輸入進行匹配。

注意:返回 **所有** 前向張量輸出,包括不可微分輸出,因此您可以使用此函式來對所有輸出執行操作。

引數

g (Graph) – 帶描述符的 FX 聯合圖

返回

一個字典,將每個 `DifferentiableAOTOutput` 描述符對映到一個元組,該元組包含: - 輸出節點本身 - 切線(輸入)節點(如果存在),否則為 `None`

引發
  • RuntimeError – 如果聯合圖包含子類張量輸入/輸出;此

  • API 不支援,因為當涉及子類時,輸入和梯度之間不一定存在一對一的對應關係

  • 當涉及子類時,輸出和切線之間不存在一一對應關係。

返回型別

dict[torch._functorch._aot_autograd.descriptors.DifferentiableAOTOutput, tuple[torch.fx.node.Node, Optional[torch.fx.node.Node]]]

torch._functorch._aot_autograd.fx_utils.get_buffer_nodes(graph)[source]#

將圖中的所有緩衝區節點獲取為一個列表。

您可以依賴此函式提供您需要饋入聯合圖(在引數之後)的正確緩衝區順序。

引數

graph (Graph) – 帶描述符的 FX 聯合圖

返回

表示圖中所有緩衝區的 FX 節點列表。

引發
  • RuntimeError – 如果遇到子類張量(尚不支援),因為

  • 不清楚您是否想要子類的每個單獨的組成部分

  • 還是希望將它們以某種方式分組。

返回型別

list[torch.fx.node.Node]

torch._functorch._aot_autograd.fx_utils.get_named_buffer_nodes(graph)[source]#

按完全限定名稱對映緩衝區節點。

此函式遍歷圖以查詢所有緩衝區輸入節點,並返回一個字典,其中鍵是緩衝區名稱 (FQN),值是相應的 FX 節點。

引數

graph (Graph) – 帶描述符的 FX 聯合圖

返回

將緩衝區名稱 (str) 對映到其相應 FX 節點的字典。

引發
  • RuntimeError – 如果遇到子類張量(尚不支援),因為

  • 對於子類,FQN 不一定對映到一個普通張量。

返回型別

dict[str, torch.fx.node.Node]

torch._functorch._aot_autograd.fx_utils.get_named_param_nodes(graph)[source]#

按完全限定名稱對映引數節點。

此函式遍歷圖以查詢所有引數輸入節點,並返回一個字典,其中鍵是引數名稱 (FQN),值是相應的 FX 節點。

引數

graph (Graph) – 帶描述符的 FX 聯合圖

返回

將引數名稱 (str) 對映到其相應 FX 節點的字典。

引發
  • RuntimeError – 如果遇到子類張量(尚不支援),因為

  • 對於子類,FQN 不一定對映到一個普通張量。

返回型別

dict[str, torch.fx.node.Node]

torch._functorch._aot_autograd.fx_utils.get_param_and_grad_nodes(graph)[source]#

從聯合圖中獲取引數節點及其對應的梯度節點。

引數

graph (Graph) – 帶描述符的 FX 聯合圖

返回

  • 引數輸入節點

  • 梯度(輸出)節點(如果存在),否則為 `None`

返回型別

一個字典,將每個 `ParamAOTInput` 描述符對映到一個元組,該元組包含

torch._functorch._aot_autograd.fx_utils.get_param_nodes(graph)[source]#

將圖中的所有引數節點獲取為一個列表。

您可以依賴此函式提供您需要饋入聯合圖(在引數列表的開頭,在緩衝區之前)的正確引數順序。

引數

graph (Graph) – 帶描述符的 FX 聯合圖

返回

表示圖中所有引數的 FX 節點列表。

引發
  • RuntimeError – 如果遇到子類張量(尚不支援),因為

  • 不清楚您是否想要子類的每個單獨的組成部分

  • 還是希望將它們以某種方式分組。

返回型別

list[torch.fx.node.Node]

torch._functorch._aot_autograd.fx_utils.get_plain_input_and_grad_nodes(graph)[source]#

從聯合圖中獲取普通輸入節點及其對應的梯度節點。

引數

graph (Graph) – 帶描述符的 FX 聯合圖

返回

  • 普通輸入節點

  • 梯度(輸出)節點(如果存在),否則為 `None`

返回型別

一個字典,將每個 `PlainAOTInput` 描述符對映到一個元組,該元組包含

torch._functorch._aot_autograd.fx_utils.get_plain_output_and_tangent_nodes(graph)[source]#

從聯合圖中獲取普通輸出節點及其對應的切線節點。

引數

graph (Graph) – 帶描述符的 FX 聯合圖

返回

  • 普通輸出節點

  • 切線(輸入)節點(如果存在),否則為 `None`

返回型別

一個字典,將每個 `PlainAOTOutput` 描述符對映到一個元組,該元組包含