評價此頁

Tensor 並行 - torch.distributed.tensor.parallel#

創建於: 2025年6月13日 | 最後更新於: 2025年6月13日

Tensor 並行 (TP) 構建在 PyTorch DistributedTensor (DTensor)[https://github.com/pytorch/pytorch/blob/main/torch/distributed/tensor/README.md] 之上,並提供不同的並行樣式:逐列 (Colwise)、逐行 (Rowwise) 和序列並行 (Sequence Parallelism)。

警告

Tensor 並行 API 處於實驗階段,可能會發生更改。

使用 Tensor 並行來並行化您的 nn.Module 的入口點是:

torch.distributed.tensor.parallel.parallelize_module(module, device_mesh=None, parallelize_plan=None, *, src_data_rank=0)[source]#

透過根據使用者指定的計劃並行化模組或子模組,在 PyTorch 中應用 Tensor 並行。

我們根據 `parallelize_plan` 來並行化模組或子模組。`parallelize_plan` 包含 `ParallelStyle`,它指示使用者希望模組或子模組如何被並行化。

使用者還可以為每個模組的完全限定名 (FQN) 指定不同的並行樣式。

請注意,parallelize_module 只接受一維 DeviceMesh。如果您有一個二維或 N 維 DeviceMesh,請先將其切片到一個一維子 DeviceMesh,然後再傳遞給此 API (例如 device_mesh["tp"])。

引數
  • module (nn.Module) – 要並行化的模組。

  • device_mesh (DeviceMesh, optional) – 描述 DTensor 裝置網格拓撲的物件。如果未指定,呼叫必須在 `DeviceMesh` 上下文中進行。

  • parallelize_plan (Union[ParallelStyle, Dict[str, ParallelStyle]], optional) – 用於並行化模組的計劃。它可以是包含我們如何準備 Tensor 並行輸入的/輸出的 `ParallelStyle` 物件,也可以是模組 FQN 及其對應的 `ParallelStyle` 物件的字典。如果未指定,呼叫將不做任何操作。

關鍵字引數

src_data_rank (int, optional) – 邏輯/全域性張量的源資料的 rank,它由 distribute_tensor() 用於將分片/副本散佈/廣播到其他 rank。預設情況下,我們在每個 `DeviceMesh` 維度上使用 `group_rank=0` 作為源資料,以保留單裝置語義。如果顯式傳遞 None,則 parallelize_module() 僅使用其本地資料,而不是嘗試透過散佈/廣播來保留單裝置語義。預設值:0

返回

一個已並行化的 nn.Module 物件。

返回型別

模組

示例:
>>> from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel
>>> from torch.distributed.device_mesh import init_device_mesh
>>>
>>> # Define the module.
>>> m = Model(...)
>>> tp_mesh = init_device_mesh("cuda", (8,))
>>> m = parallelize_module(m, tp_mesh, {"w1": ColwiseParallel(), "w2": RowwiseParallel()})
>>>

注意

對於複雜的模組架構,如 Attention、MLP 層,我們建議組合不同的 ParallelStyle (例如 ColwiseParallelRowwiseParallel) 並將它們作為 `parallelize_plan` 傳遞,以實現所需的 Sharding 計算。

Tensor 並行支援以下並行樣式:

class torch.distributed.tensor.parallel.ColwiseParallel(*, input_layouts=None, output_layouts=None, use_local_output=True)[source]#

以逐列 (column-wise) 的方式劃分相容的 `nn.Module`。目前支援 `nn.Linear` 和 `nn.Embedding`。使用者可以將其與 `RowwiseParallel` 組合以實現更復雜的模組 (例如 MLP、Attention) 的劃分。

關鍵字引數
  • input_layouts (Placement, optional) – `nn.Module` 的輸入張量的 DTensor 佈局,用於註解輸入張量以使其成為 DTensor。如果未指定,我們假設輸入張量是複製的。

  • output_layouts (Placement, optional) – `nn.Module` 輸出的 DTensor 佈局,用於確保 `nn.Module` 的輸出具有使用者期望的佈局。如果未指定,輸出張量將在最後一個維度上進行分片。

  • use_local_output (bool, optional) – 是否使用本地 torch.Tensor 而不是 `DTensor` 作為模組輸出,預設為 True。

返回

一個表示 `nn.Module` 逐列分片的 `ParallelStyle` 物件。

示例:
>>> from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel
>>> from torch.distributed.device_mesh import init_device_mesh
>>> ...
>>> m = Model(...)  # m is a nn.Module that contains a "w1" nn.Linear submodule
>>> tp_mesh = init_device_mesh("cuda", (8,))
>>>
>>> # By default, the input of the "w1" Linear will be converted to Replicated DTensor
>>> # and the output of "w1" will return :class:`torch.Tensor` that shards on the last dim.
>>>
>>> sharded_mod = parallelize_module(m, tp_mesh, {"w1": ColwiseParallel()})
>>> ...

注意

預設情況下,如果未指定 `output_layouts`,`ColwiseParallel` 的輸出將在最後一個維度上進行分片。如果存在需要特定張量形狀的運算子 (例如,在配對的 `RowwiseParallel` 之前),請注意,如果輸出被分片,則該運算子可能需要調整以適應分片後的大小。

class torch.distributed.tensor.parallel.RowwiseParallel(*, input_layouts=None, output_layouts=None, use_local_output=True)[source]#

以逐行 (row-wise) 的方式劃分相容的 `nn.Module`。目前支援 `nn.Linear` 和 `nn.Embedding`。使用者可以將其與 `ColwiseParallel` 組合以實現更復雜的模組 (例如 MLP、Attention) 的劃分。

關鍵字引數
  • input_layouts (Placement, optional) – `nn.Module` 的輸入張量的 DTensor 佈局,用於註解輸入張量以使其成為 DTensor。如果未指定,我們假設輸入張量在最後一個維度上進行分片。

  • output_layouts (Placement, optional) – `nn.Module` 輸出的 DTensor 佈局,用於確保 `nn.Module` 的輸出具有使用者期望的佈局。如果未指定,輸出張量將被複制。

  • use_local_output (bool, optional) – 是否使用本地 torch.Tensor 而不是 `DTensor` 作為模組輸出,預設為 True。

返回

一個表示 `nn.Module` 逐行分片的 `ParallelStyle` 物件。

示例:
>>> from torch.distributed.tensor.parallel import parallelize_module, RowwiseParallel
>>> from torch.distributed.device_mesh import init_device_mesh
>>> ...
>>> m = Model(...)  # m is a nn.Module that contains a "w2" nn.Linear submodule
>>> tp_mesh = init_device_mesh("cuda", (8,))
>>>
>>> # By default, the input of the "w2" Linear will be converted to DTensor that shards on the last dim
>>> # and the output of "w2" will return a replicated :class:`torch.Tensor`.
>>>
>>> sharded_mod = parallelize_module(m, tp_mesh, {"w2": RowwiseParallel()}),
>>> ...
class torch.distributed.tensor.parallel.SequenceParallel(*, sequence_dim=1, use_local_output=False)[source]#

SequenceParallel 複製相容的 nn.Module 引數,並使用在序列維度上分片的輸入進行分片計算。這目前支援 nn.LayerNormnn.Dropout 以及 RMSNorm Python 實現

此樣式實現了論文 Reducing Activation Recomputation in Large Transformer Models 中描述的操作。

如果傳遞給此 nn.Module 的輸入是 torch.Tensor,它會假設輸入已在序列維度上分片,並將輸入轉換為在序列維度上分片的 DTensor。如果傳遞給此 nn.Module 的輸入已經是 DTensor 但未在序列維度上分片,它將重新分發輸入以在序列維度上分片。

`nn.Module` 的輸出將在序列維度上進行分片。

關鍵字引數
  • sequence_dim (int, optional) – `nn.Module` 的輸入張量的序列維度,用於註解輸入張量以使其成為在序列維度上分片的 DTensor,預設值為 1。

  • use_local_output (bool, optional) – 是否使用本地 torch.Tensor 而不是 DTensor 作為模組輸出,預設值為 False。

返回

一個表示 `nn.Module` 序列並行的 `ParallelStyle` 物件。

示例:
>>> from torch.distributed.tensor.parallel import parallelize_module, SequenceParallel
>>> from torch.distributed.device_mesh import init_device_mesh
>>> ...
>>> m = Model(...)  # m is a nn.Module that contains a "norm" nn.LayerNorm submodule
>>> tp_mesh = init_device_mesh("cuda", (8,))
>>>
>>> # By default, the input of the "norm" will be converted to DTensor that shards on the sequence dim
>>> # and the output of "norm" will return a sharded on sequence dimension :class:`DTensor`.
>>>
>>> sharded_mod = parallelize_module(m, tp_mesh, {"norm": SequenceParallel()}),
>>> ...

注意

SequenceParallel 樣式假設 `nn.Module` 中的權重 (例如 nn.LayerNormRMSNorm) 具有全1初始化 (預設情況下)。如果您對這些模組的權重有自定義初始化,您需要在並行化之前/之後廣播權重,以確保它們被複制。

要僅使用 `parallelize_module` 呼叫中的 `parallelize_plan` 來配置 `nn.Module` 的輸入和輸出的 DTensor 佈局並執行必要的佈局重分佈,而不分發模組引數到 DTensor,可以使用以下 `ParallelStyle`:

class torch.distributed.tensor.parallel.PrepareModuleInput(*, input_layouts=None, desired_input_layouts=None, input_kwarg_layouts=None, desired_input_kwarg_layouts=None, use_local_output=False)[source]#

根據 `input_layouts` 配置 `nn.Module` 的輸入,在執行時將 `nn.Module` 的輸入張量轉換為 DTensor,並根據 `desired_input_layouts` 進行佈局重分佈。

關鍵字引數
  • input_layouts (Union[Placement, Tuple[Optional[Placement]]]) – `nn.Module` 的輸入張量的 DTensor 佈局,用於將輸入張量轉換為 DTensor。如果某些輸入不是 `torch.Tensor` 或不需要轉換為 DTensor,則需要指定 `None` 作為佔位符。預設為 None。

  • desired_input_layouts (Union[Placement, Tuple[Optional[Placement]]]) – `nn.Module` 輸入張量的期望 DTensor 佈局,用於確保 `nn.Module` 的輸入具有期望的 DTensor 佈局。此引數的長度必須與 `input_layouts` 相同。預設為 None。

  • input_kwarg_layouts (Dict[str, Placement]) – `nn.Module` 的輸入 kwargs 的 DTensor 佈局,用於將輸入 kwargs 張量轉換為 DTensor。預設為 None。

  • desired_input_kwarg_layouts – (Dict[str, Placement]): `nn.Module` 的輸入 kwargs 的期望 DTensor 佈局,用於確保 `nn.Module` 的輸入具有期望的 DTensor 佈局。預設為 None。

  • use_local_output (bool, optional) – 是否使用本地 torch.Tensor 而不是 DTensor 作為模組輸入,預設為 False。

返回

一個 `ParallelStyle` 物件,用於準備 `nn.Module` 輸入的分片佈局。

示例:
>>> from torch.distributed.tensor.parallel import parallelize_module, PrepareModuleInput
>>> from torch.distributed.device_mesh import init_device_mesh
>>> ...
>>> block = TransformerBlock(...)  # block is a nn.Module that contains an "attn" Attention submodule
>>> tp_mesh = init_device_mesh("cuda", (8,))
>>>
>>> # According to the style specified below, the first input of attn will be annotated to Sharded DTensor
>>> # and then redistributed to Replicated DTensor.
>>> parallelize_module(
>>>     block, # this can be a submodule or module
>>>     tp_mesh,
>>>     parallelize_plan={
>>>         "attn": PrepareModuleInput(
>>>             input_layouts=(Shard(0), None, None, ...),
>>>             desired_input_layouts=(Replicate(), None, None, ...)
>>>         ),
>>>     }
>>> )
class torch.distributed.tensor.parallel.PrepareModuleOutput(*, output_layouts, desired_output_layouts, use_local_output=True)[source]#

根據 `output_layouts` 配置 `nn.Module` 的輸出,在執行時將 `nn.Module` 的輸出張量轉換為 DTensor,並根據 `desired_output_layouts` 進行佈局重分佈。

關鍵字引數
  • output_layouts (Union[Placement, Tuple[Placement]]) – `nn.Module` 的輸出張量的 DTensor 佈局,用於將輸出張量轉換為 DTensor (如果它們是 torch.Tensor)。如果某些輸出不是 `torch.Tensor` 或不需要轉換為 DTensor,則需要指定 `None` 作為佔位符。

  • desired_output_layouts (Union[Placement, Tuple[Placement]]) – `nn.Module` 輸出張量的期望 DTensor 佈局,用於確保 `nn.Module` 的輸出具有期望的 DTensor 佈局。

  • use_local_output (bool, optional) – 是否使用本地 torch.Tensor 而不是 DTensor 作為模組輸出,預設為 True。

返回

一個 `ParallelStyle` 物件,用於準備 `nn.Module` 輸出的分片佈局。

示例:
>>> from torch.distributed.tensor.parallel import parallelize_module, PrepareModuleOutput
>>> from torch.distributed.device_mesh import init_device_mesh
>>> ...
>>> block = TransformerBlock(...)  # block is a nn.Module that contains an "attn" Attention submodule
>>> tp_mesh = init_device_mesh("cuda", (8,))
>>>
>>> # According to the style specified below, the output of the TransformerBlock will be converted to Replicated DTensor
>>> # and then redistributed to Sharded DTensor.
>>> parallelize_module(
>>>     block, # this can be a submodule or module
>>>     tp_mesh,
>>>     parallelize_plan = PrepareModuleOutput(
>>>         output_layouts=Replicate(),
>>>         desired_output_layouts=Shard(0)
>>>     )
>>> )
class torch.distributed.tensor.parallel.PrepareModuleInputOutput(*, input_layouts=None, desired_input_layouts=None, input_kwarg_layouts=None, desired_input_kwarg_layouts=None, use_local_input=False, output_layouts, desired_output_layouts, use_local_output=True)[source]#

配置 `nn.Module` 的輸入 (和輸出),在執行時根據 `input_layouts` (和 `output_layouts`,分別) 將 `nn.Module` 的輸入張量 (和輸出張量,分別) 轉換為 DTensor,並根據 `desired_input_layouts` (和 `desired_output_layouts`,分別) 進行佈局重分佈。這是 `PrepareModuleInput` 和 `PrepareModuleOutput` 的組合。

關鍵字引數
  • input_layouts (Union[Placement, Tuple[Optional[Placement]]]) – `nn.Module` 的輸入張量的 DTensor 佈局,用於將輸入張量轉換為 DTensor。如果某些輸入不是 `torch.Tensor` 或不需要轉換為 DTensor,則需要指定 `None` 作為佔位符。預設為 None。

  • desired_input_layouts (Union[Placement, Tuple[Optional[Placement]]]) – `nn.Module` 輸入張量的期望 DTensor 佈局,用於確保 `nn.Module` 的輸入具有期望的 DTensor 佈局。此引數的長度必須與 `input_layouts` 相同。預設為 None。

  • input_kwarg_layouts (Dict[str, Placement]) – `nn.Module` 的輸入 kwargs 的 DTensor 佈局,用於將輸入 kwargs 張量轉換為 DTensor。預設為 None。

  • desired_input_kwarg_layouts – (Dict[str, Placement]): `nn.Module` 的輸入 kwargs 的期望 DTensor 佈局,用於確保 `nn.Module` 的輸入具有期望的 DTensor 佈局。預設為 None。

  • use_local_input (bool, optional) – 是否使用本地 torch.Tensor 而不是 DTensor 作為模組輸入,預設為 False。

  • output_layouts (Union[Placement, Tuple[Placement]]) – `nn.Module` 的輸出張量的 DTensor 佈局,用於將輸出張量轉換為 DTensor (如果它們是 torch.Tensor)。如果某些輸出不是 `torch.Tensor` 或不需要轉換為 DTensor,則需要指定 `None` 作為佔位符。

  • desired_output_layouts (Union[Placement, Tuple[Placement]]) – `nn.Module` 輸出張量的期望 DTensor 佈局,用於確保 `nn.Module` 的輸出具有期望的 DTensor 佈局。

  • use_local_output (bool, optional) – 是否使用本地 torch.Tensor 而不是 DTensor 作為模組輸出,預設為 True。

返回

一個 `ParallelStyle` 物件,用於準備 `nn.Module` 輸入和輸出的分片佈局。

示例:
>>> from torch.distributed.tensor.parallel import parallelize_module, PrepareModuleInputOutput
>>> from torch.distributed.device_mesh import init_device_mesh
>>> ...
>>> block = TransformerBlock(...)  # block is a nn.Module that contains an "attn" Attention submodule
>>> tp_mesh = init_device_mesh("cuda", (8,))
>>>
>>> # According to the style specified below, the first input of attn will be annotated as Sharded DTensor
>>> # and then redistributed to Replicated DTensor, and the output of the TransformerBlock will be annotated
>>> # as Replicated DTensor and then redistributed to Sharded DTensor.
>>> parallelize_module(
>>>     block, # this can be a submodule or module
>>>     tp_mesh,
>>>     parallelize_plan={
>>>         "attn": PrepareModuleInputOutput(
>>>             input_layouts=(Shard(0), None, None, ...),
>>>             desired_input_layouts=(Replicate(), None, None, ...),
>>>             output_layouts=Replicate(),
>>>             desired_output_layouts=Shard(0),
>>>         ),
>>>     }
>>> )

注意

當使用 `Shard(dim)` 作為上述 `ParallelStyle` 的輸入/輸出佈局時,我們假設輸入/輸出啟用張量在 TP 操作的 `DeviceMesh` 上均勻分片於張量維度 `dim`。例如,由於 `RowwiseParallel` 接受在最後一個維度上分片的輸入,它假設輸入張量已經均勻分片於最後一個維度。對於不均勻分片的啟用張量,可以先將 DTensor 直接傳遞給分片後的模組,並使用 `use_local_output=False` 在每個 `ParallelStyle` 後返回 DTensor,DTensor 可以跟蹤不均勻分片資訊。

對於 Transformer 等模型,我們建議使用者在 `parallelize_plan` 中一起使用 `ColwiseParallel` 和 `RowwiseParallel`,以實現整個模型 (例如 Attention 和 MLP) 的期望分片。

可以透過以下上下文管理器支援並行化的交叉熵損失計算 (損失並行):

torch.distributed.tensor.parallel.loss_parallel()[source]#

一個啟用損失並行的上下文管理器,當輸入在類別維度上分片時,可以執行高效的並行化損失計算。目前僅支援交叉熵損失。

在此上下文管理器中,您可以像平常一樣使用 `cross_entropy()` 或 `CrossEntropyLoss`,並滿足以下輸入引數假設。相應的 backward() 呼叫 (如果有) 也需要在此上下文管理器下進行。

引數
  • input (DTensor) – 輸入 logits。假定在類別維度上分片。

  • target (Union[torch.Tensor, DTensor]) – 必須是真實類索引 (目前不支援類別機率)。假定在 `DeviceMesh` 上覆制。

  • weight (Union[torch.Tensor, DTensor], optional) – 如果提供,假定在 `DeviceMesh` 上覆制。

  • label_smoothing – 目前不支援。

返回

一個複製的 `DTensor`。

示例

此處手動建立了一個分片 DTensor 以展示用法。實際上,它通常是 TP 模組的輸出。

>>> from torch.distributed.tensor.parallel import loss_parallel
>>> from torch.distributed.device_mesh import init_device_mesh
>>> ...
>>> device_mesh = init_device_mesh("cuda", (8,))
>>> input = torch.randn(4, 16, device="cuda", requires_grad=True)
>>> dist_input = distribute_tensor(input, device_mesh, placements=[Shard(1)])
>>> target = torch.randint(16, (4,), device="cuda")
>>> with loss_parallel():
>>>     loss = F.cross_entropy(dist_input, target, reduction="mean")
>>>     loss.backward()
>>> ...

警告

The loss_parallel API is experimental and subject to change.