Tensor 並行 - torch.distributed.tensor.parallel#
創建於: 2025年6月13日 | 最後更新於: 2025年6月13日
Tensor 並行 (TP) 構建在 PyTorch DistributedTensor (DTensor)[https://github.com/pytorch/pytorch/blob/main/torch/distributed/tensor/README.md] 之上,並提供不同的並行樣式:逐列 (Colwise)、逐行 (Rowwise) 和序列並行 (Sequence Parallelism)。
警告
Tensor 並行 API 處於實驗階段,可能會發生更改。
使用 Tensor 並行來並行化您的 nn.Module 的入口點是:
- torch.distributed.tensor.parallel.parallelize_module(module, device_mesh=None, parallelize_plan=None, *, src_data_rank=0)[source]#
透過根據使用者指定的計劃並行化模組或子模組,在 PyTorch 中應用 Tensor 並行。
我們根據 `parallelize_plan` 來並行化模組或子模組。`parallelize_plan` 包含 `
ParallelStyle`,它指示使用者希望模組或子模組如何被並行化。使用者還可以為每個模組的完全限定名 (FQN) 指定不同的並行樣式。
請注意,
parallelize_module只接受一維DeviceMesh。如果您有一個二維或 N 維DeviceMesh,請先將其切片到一個一維子DeviceMesh,然後再傳遞給此 API (例如device_mesh["tp"])。- 引數
module (
nn.Module) – 要並行化的模組。device_mesh (
DeviceMesh, optional) – 描述 DTensor 裝置網格拓撲的物件。如果未指定,呼叫必須在 `DeviceMesh` 上下文中進行。parallelize_plan (Union[
ParallelStyle, Dict[str,ParallelStyle]], optional) – 用於並行化模組的計劃。它可以是包含我們如何準備 Tensor 並行輸入的/輸出的 `ParallelStyle` 物件,也可以是模組 FQN 及其對應的 `ParallelStyle` 物件的字典。如果未指定,呼叫將不做任何操作。
- 關鍵字引數
src_data_rank (int, optional) – 邏輯/全域性張量的源資料的 rank,它由
distribute_tensor()用於將分片/副本散佈/廣播到其他 rank。預設情況下,我們在每個 `DeviceMesh` 維度上使用 `group_rank=0` 作為源資料,以保留單裝置語義。如果顯式傳遞None,則parallelize_module()僅使用其本地資料,而不是嘗試透過散佈/廣播來保留單裝置語義。預設值:0- 返回
一個已並行化的
nn.Module物件。- 返回型別
- 示例:
>>> from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel >>> from torch.distributed.device_mesh import init_device_mesh >>> >>> # Define the module. >>> m = Model(...) >>> tp_mesh = init_device_mesh("cuda", (8,)) >>> m = parallelize_module(m, tp_mesh, {"w1": ColwiseParallel(), "w2": RowwiseParallel()}) >>>
注意
對於複雜的模組架構,如 Attention、MLP 層,我們建議組合不同的
ParallelStyle(例如ColwiseParallel和RowwiseParallel) 並將它們作為 `parallelize_plan` 傳遞,以實現所需的 Sharding 計算。
Tensor 並行支援以下並行樣式:
- class torch.distributed.tensor.parallel.ColwiseParallel(*, input_layouts=None, output_layouts=None, use_local_output=True)[source]#
以逐列 (column-wise) 的方式劃分相容的 `nn.Module`。目前支援 `nn.Linear` 和 `nn.Embedding`。使用者可以將其與 `RowwiseParallel` 組合以實現更復雜的模組 (例如 MLP、Attention) 的劃分。
- 關鍵字引數
input_layouts (Placement, optional) – `nn.Module` 的輸入張量的 DTensor 佈局,用於註解輸入張量以使其成為 DTensor。如果未指定,我們假設輸入張量是複製的。
output_layouts (Placement, optional) – `nn.Module` 輸出的 DTensor 佈局,用於確保 `nn.Module` 的輸出具有使用者期望的佈局。如果未指定,輸出張量將在最後一個維度上進行分片。
use_local_output (bool, optional) – 是否使用本地
torch.Tensor而不是 `DTensor` 作為模組輸出,預設為 True。
- 返回
一個表示 `nn.Module` 逐列分片的 `
ParallelStyle` 物件。
- 示例:
>>> from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel >>> from torch.distributed.device_mesh import init_device_mesh >>> ... >>> m = Model(...) # m is a nn.Module that contains a "w1" nn.Linear submodule >>> tp_mesh = init_device_mesh("cuda", (8,)) >>> >>> # By default, the input of the "w1" Linear will be converted to Replicated DTensor >>> # and the output of "w1" will return :class:`torch.Tensor` that shards on the last dim. >>> >>> sharded_mod = parallelize_module(m, tp_mesh, {"w1": ColwiseParallel()}) >>> ...
注意
預設情況下,如果未指定 `output_layouts`,`ColwiseParallel` 的輸出將在最後一個維度上進行分片。如果存在需要特定張量形狀的運算子 (例如,在配對的 `RowwiseParallel` 之前),請注意,如果輸出被分片,則該運算子可能需要調整以適應分片後的大小。
- class torch.distributed.tensor.parallel.RowwiseParallel(*, input_layouts=None, output_layouts=None, use_local_output=True)[source]#
以逐行 (row-wise) 的方式劃分相容的 `nn.Module`。目前支援 `nn.Linear` 和 `nn.Embedding`。使用者可以將其與 `ColwiseParallel` 組合以實現更復雜的模組 (例如 MLP、Attention) 的劃分。
- 關鍵字引數
input_layouts (Placement, optional) – `nn.Module` 的輸入張量的 DTensor 佈局,用於註解輸入張量以使其成為 DTensor。如果未指定,我們假設輸入張量在最後一個維度上進行分片。
output_layouts (Placement, optional) – `nn.Module` 輸出的 DTensor 佈局,用於確保 `nn.Module` 的輸出具有使用者期望的佈局。如果未指定,輸出張量將被複制。
use_local_output (bool, optional) – 是否使用本地
torch.Tensor而不是 `DTensor` 作為模組輸出,預設為 True。
- 返回
一個表示 `nn.Module` 逐行分片的 `
ParallelStyle` 物件。
- 示例:
>>> from torch.distributed.tensor.parallel import parallelize_module, RowwiseParallel >>> from torch.distributed.device_mesh import init_device_mesh >>> ... >>> m = Model(...) # m is a nn.Module that contains a "w2" nn.Linear submodule >>> tp_mesh = init_device_mesh("cuda", (8,)) >>> >>> # By default, the input of the "w2" Linear will be converted to DTensor that shards on the last dim >>> # and the output of "w2" will return a replicated :class:`torch.Tensor`. >>> >>> sharded_mod = parallelize_module(m, tp_mesh, {"w2": RowwiseParallel()}), >>> ...
- class torch.distributed.tensor.parallel.SequenceParallel(*, sequence_dim=1, use_local_output=False)[source]#
SequenceParallel 複製相容的
nn.Module引數,並使用在序列維度上分片的輸入進行分片計算。這目前支援nn.LayerNorm、nn.Dropout以及 RMSNorm Python 實現。此樣式實現了論文 Reducing Activation Recomputation in Large Transformer Models 中描述的操作。
如果傳遞給此
nn.Module的輸入是torch.Tensor,它會假設輸入已在序列維度上分片,並將輸入轉換為在序列維度上分片的DTensor。如果傳遞給此nn.Module的輸入已經是DTensor但未在序列維度上分片,它將重新分發輸入以在序列維度上分片。`nn.Module` 的輸出將在序列維度上進行分片。
- 關鍵字引數
sequence_dim (int, optional) – `nn.Module` 的輸入張量的序列維度,用於註解輸入張量以使其成為在序列維度上分片的 DTensor,預設值為 1。
use_local_output (bool, optional) – 是否使用本地
torch.Tensor而不是DTensor作為模組輸出,預設值為 False。
- 返回
一個表示 `nn.Module` 序列並行的 `
ParallelStyle` 物件。
- 示例:
>>> from torch.distributed.tensor.parallel import parallelize_module, SequenceParallel >>> from torch.distributed.device_mesh import init_device_mesh >>> ... >>> m = Model(...) # m is a nn.Module that contains a "norm" nn.LayerNorm submodule >>> tp_mesh = init_device_mesh("cuda", (8,)) >>> >>> # By default, the input of the "norm" will be converted to DTensor that shards on the sequence dim >>> # and the output of "norm" will return a sharded on sequence dimension :class:`DTensor`. >>> >>> sharded_mod = parallelize_module(m, tp_mesh, {"norm": SequenceParallel()}), >>> ...
注意
SequenceParallel 樣式假設 `nn.Module` 中的權重 (例如
nn.LayerNorm或RMSNorm) 具有全1初始化 (預設情況下)。如果您對這些模組的權重有自定義初始化,您需要在並行化之前/之後廣播權重,以確保它們被複制。
要僅使用 `parallelize_module` 呼叫中的 `parallelize_plan` 來配置 `nn.Module` 的輸入和輸出的 DTensor 佈局並執行必要的佈局重分佈,而不分發模組引數到 DTensor,可以使用以下 `ParallelStyle`:
- class torch.distributed.tensor.parallel.PrepareModuleInput(*, input_layouts=None, desired_input_layouts=None, input_kwarg_layouts=None, desired_input_kwarg_layouts=None, use_local_output=False)[source]#
根據 `input_layouts` 配置 `nn.Module` 的輸入,在執行時將 `nn.Module` 的輸入張量轉換為 DTensor,並根據 `desired_input_layouts` 進行佈局重分佈。
- 關鍵字引數
input_layouts (Union[Placement, Tuple[Optional[Placement]]]) – `nn.Module` 的輸入張量的 DTensor 佈局,用於將輸入張量轉換為 DTensor。如果某些輸入不是 `torch.Tensor` 或不需要轉換為 DTensor,則需要指定 `None` 作為佔位符。預設為 None。
desired_input_layouts (Union[Placement, Tuple[Optional[Placement]]]) – `nn.Module` 輸入張量的期望 DTensor 佈局,用於確保 `nn.Module` 的輸入具有期望的 DTensor 佈局。此引數的長度必須與 `input_layouts` 相同。預設為 None。
input_kwarg_layouts (Dict[str, Placement]) – `nn.Module` 的輸入 kwargs 的 DTensor 佈局,用於將輸入 kwargs 張量轉換為 DTensor。預設為 None。
desired_input_kwarg_layouts – (Dict[str, Placement]): `nn.Module` 的輸入 kwargs 的期望 DTensor 佈局,用於確保 `nn.Module` 的輸入具有期望的 DTensor 佈局。預設為 None。
use_local_output (bool, optional) – 是否使用本地
torch.Tensor而不是DTensor作為模組輸入,預設為 False。
- 返回
一個 `
ParallelStyle` 物件,用於準備 `nn.Module` 輸入的分片佈局。
- 示例:
>>> from torch.distributed.tensor.parallel import parallelize_module, PrepareModuleInput >>> from torch.distributed.device_mesh import init_device_mesh >>> ... >>> block = TransformerBlock(...) # block is a nn.Module that contains an "attn" Attention submodule >>> tp_mesh = init_device_mesh("cuda", (8,)) >>> >>> # According to the style specified below, the first input of attn will be annotated to Sharded DTensor >>> # and then redistributed to Replicated DTensor. >>> parallelize_module( >>> block, # this can be a submodule or module >>> tp_mesh, >>> parallelize_plan={ >>> "attn": PrepareModuleInput( >>> input_layouts=(Shard(0), None, None, ...), >>> desired_input_layouts=(Replicate(), None, None, ...) >>> ), >>> } >>> )
- class torch.distributed.tensor.parallel.PrepareModuleOutput(*, output_layouts, desired_output_layouts, use_local_output=True)[source]#
根據 `output_layouts` 配置 `nn.Module` 的輸出,在執行時將 `nn.Module` 的輸出張量轉換為 DTensor,並根據 `desired_output_layouts` 進行佈局重分佈。
- 關鍵字引數
output_layouts (Union[Placement, Tuple[Placement]]) – `nn.Module` 的輸出張量的 DTensor 佈局,用於將輸出張量轉換為 DTensor (如果它們是
torch.Tensor)。如果某些輸出不是 `torch.Tensor` 或不需要轉換為 DTensor,則需要指定 `None` 作為佔位符。desired_output_layouts (Union[Placement, Tuple[Placement]]) – `nn.Module` 輸出張量的期望 DTensor 佈局,用於確保 `nn.Module` 的輸出具有期望的 DTensor 佈局。
use_local_output (bool, optional) – 是否使用本地
torch.Tensor而不是DTensor作為模組輸出,預設為 True。
- 返回
一個 `ParallelStyle` 物件,用於準備 `nn.Module` 輸出的分片佈局。
- 示例:
>>> from torch.distributed.tensor.parallel import parallelize_module, PrepareModuleOutput >>> from torch.distributed.device_mesh import init_device_mesh >>> ... >>> block = TransformerBlock(...) # block is a nn.Module that contains an "attn" Attention submodule >>> tp_mesh = init_device_mesh("cuda", (8,)) >>> >>> # According to the style specified below, the output of the TransformerBlock will be converted to Replicated DTensor >>> # and then redistributed to Sharded DTensor. >>> parallelize_module( >>> block, # this can be a submodule or module >>> tp_mesh, >>> parallelize_plan = PrepareModuleOutput( >>> output_layouts=Replicate(), >>> desired_output_layouts=Shard(0) >>> ) >>> )
- class torch.distributed.tensor.parallel.PrepareModuleInputOutput(*, input_layouts=None, desired_input_layouts=None, input_kwarg_layouts=None, desired_input_kwarg_layouts=None, use_local_input=False, output_layouts, desired_output_layouts, use_local_output=True)[source]#
配置 `nn.Module` 的輸入 (和輸出),在執行時根據 `input_layouts` (和 `output_layouts`,分別) 將 `nn.Module` 的輸入張量 (和輸出張量,分別) 轉換為 DTensor,並根據 `desired_input_layouts` (和 `desired_output_layouts`,分別) 進行佈局重分佈。這是 `
PrepareModuleInput` 和 `PrepareModuleOutput` 的組合。- 關鍵字引數
input_layouts (Union[Placement, Tuple[Optional[Placement]]]) – `nn.Module` 的輸入張量的 DTensor 佈局,用於將輸入張量轉換為 DTensor。如果某些輸入不是 `torch.Tensor` 或不需要轉換為 DTensor,則需要指定 `None` 作為佔位符。預設為 None。
desired_input_layouts (Union[Placement, Tuple[Optional[Placement]]]) – `nn.Module` 輸入張量的期望 DTensor 佈局,用於確保 `nn.Module` 的輸入具有期望的 DTensor 佈局。此引數的長度必須與 `input_layouts` 相同。預設為 None。
input_kwarg_layouts (Dict[str, Placement]) – `nn.Module` 的輸入 kwargs 的 DTensor 佈局,用於將輸入 kwargs 張量轉換為 DTensor。預設為 None。
desired_input_kwarg_layouts – (Dict[str, Placement]): `nn.Module` 的輸入 kwargs 的期望 DTensor 佈局,用於確保 `nn.Module` 的輸入具有期望的 DTensor 佈局。預設為 None。
use_local_input (bool, optional) – 是否使用本地
torch.Tensor而不是DTensor作為模組輸入,預設為 False。output_layouts (Union[Placement, Tuple[Placement]]) – `nn.Module` 的輸出張量的 DTensor 佈局,用於將輸出張量轉換為 DTensor (如果它們是
torch.Tensor)。如果某些輸出不是 `torch.Tensor` 或不需要轉換為 DTensor,則需要指定 `None` 作為佔位符。desired_output_layouts (Union[Placement, Tuple[Placement]]) – `nn.Module` 輸出張量的期望 DTensor 佈局,用於確保 `nn.Module` 的輸出具有期望的 DTensor 佈局。
use_local_output (bool, optional) – 是否使用本地
torch.Tensor而不是DTensor作為模組輸出,預設為 True。
- 返回
一個 `
ParallelStyle` 物件,用於準備 `nn.Module` 輸入和輸出的分片佈局。
- 示例:
>>> from torch.distributed.tensor.parallel import parallelize_module, PrepareModuleInputOutput >>> from torch.distributed.device_mesh import init_device_mesh >>> ... >>> block = TransformerBlock(...) # block is a nn.Module that contains an "attn" Attention submodule >>> tp_mesh = init_device_mesh("cuda", (8,)) >>> >>> # According to the style specified below, the first input of attn will be annotated as Sharded DTensor >>> # and then redistributed to Replicated DTensor, and the output of the TransformerBlock will be annotated >>> # as Replicated DTensor and then redistributed to Sharded DTensor. >>> parallelize_module( >>> block, # this can be a submodule or module >>> tp_mesh, >>> parallelize_plan={ >>> "attn": PrepareModuleInputOutput( >>> input_layouts=(Shard(0), None, None, ...), >>> desired_input_layouts=(Replicate(), None, None, ...), >>> output_layouts=Replicate(), >>> desired_output_layouts=Shard(0), >>> ), >>> } >>> )
注意
當使用 `Shard(dim)` 作為上述 `ParallelStyle` 的輸入/輸出佈局時,我們假設輸入/輸出啟用張量在 TP 操作的 `DeviceMesh` 上均勻分片於張量維度 `dim`。例如,由於 `RowwiseParallel` 接受在最後一個維度上分片的輸入,它假設輸入張量已經均勻分片於最後一個維度。對於不均勻分片的啟用張量,可以先將 DTensor 直接傳遞給分片後的模組,並使用 `use_local_output=False` 在每個 `ParallelStyle` 後返回 DTensor,DTensor 可以跟蹤不均勻分片資訊。
對於 Transformer 等模型,我們建議使用者在 `parallelize_plan` 中一起使用 `ColwiseParallel` 和 `RowwiseParallel`,以實現整個模型 (例如 Attention 和 MLP) 的期望分片。
可以透過以下上下文管理器支援並行化的交叉熵損失計算 (損失並行):
- torch.distributed.tensor.parallel.loss_parallel()[source]#
一個啟用損失並行的上下文管理器,當輸入在類別維度上分片時,可以執行高效的並行化損失計算。目前僅支援交叉熵損失。
在此上下文管理器中,您可以像平常一樣使用 `
cross_entropy()` 或 `CrossEntropyLoss`,並滿足以下輸入引數假設。相應的backward()呼叫 (如果有) 也需要在此上下文管理器下進行。- 引數
input (
DTensor) – 輸入 logits。假定在類別維度上分片。target (Union[
torch.Tensor,DTensor]) – 必須是真實類索引 (目前不支援類別機率)。假定在 `DeviceMesh` 上覆制。weight (Union[
torch.Tensor,DTensor], optional) – 如果提供,假定在 `DeviceMesh` 上覆制。label_smoothing – 目前不支援。
- 返回
一個複製的 `
DTensor`。
示例
此處手動建立了一個分片 DTensor 以展示用法。實際上,它通常是 TP 模組的輸出。
>>> from torch.distributed.tensor.parallel import loss_parallel >>> from torch.distributed.device_mesh import init_device_mesh >>> ... >>> device_mesh = init_device_mesh("cuda", (8,)) >>> input = torch.randn(4, 16, device="cuda", requires_grad=True) >>> dist_input = distribute_tensor(input, device_mesh, placements=[Shard(1)]) >>> target = torch.randint(16, (4,), device="cuda") >>> with loss_parallel(): >>> loss = F.cross_entropy(dist_input, target, reduction="mean") >>> loss.backward() >>> ...
警告
The loss_parallel API is experimental and subject to change.