使用張量並行(TP)進行大規模 Transformer 模型訓練#
建立日期:2024 年 4 月 19 日 | 最後更新:2025 年 7 月 18 日 | 最後驗證:2024 年 11 月 5 日
注意
在 github 上檢視和編輯此教程。
本教程演示瞭如何使用張量並行(TP)和完全分片資料並行(FSDP)在數百到數千個 GPU 上訓練大型 Transformer 類模型。
先決條件
安裝了 CUDA/Linux 的 PyTorch 2.3.0 或更高版本
張量並行如何工作?#
張量並行(TP)最初在 Megatron-LM 論文中提出,是一種用於訓練大規模 Transformer 模型的高效模型並行技術。序列並行(SP)是我們在此教程中提到的張量並行的變體,它在序列維度上對 `nn.LayerNorm` 或 `nn.RMSNorm` 進行分片,以進一步節省訓練期間的啟用記憶體。隨著模型變得越來越大,啟用記憶體成為瓶頸,因此在張量並行訓練中,通常對 `LayerNorm` 或 `RMSNorm` 層應用序列並行。
圖 1. 在 Transformer 模型的 MLP 和自注意力層上表示張量並行風格的分片,其中注意力/MLP 中的矩陣乘法透過分片計算完成(圖片來源)#
總的來說,PyTorch 張量並行的工作方式如下:
分片初始化
確定將哪個 `ParallelStyle` 應用於每個層,並透過呼叫 `parallelize_module` 來分片初始化後的模組。
並行化後的模組的模型引數將被切換為 DTensors,DTensor 將負責使用分片計算來執行並行化後的模組。
執行時前向/後向傳播
根據使用者為每個 `ParallelStyle` 指定的輸入/輸出 DTensor 佈局,它將執行適當的通訊操作來轉換輸入/輸出的 DTensor 佈局(例如 `allreduce`、`allgather` 和 `reduce_scatter`)。
執行並行化層的分片計算以節省計算/記憶體(例如 `nn.Linear`、`nn.Embedding`)。
何時以及為何應該應用張量並行#
PyTorch 完全分片資料並行(FSDP)已經具備了將模型訓練擴充套件到特定數量 GPU 的能力。然而,當在模型大小和 GPU 數量方面進一步擴充套件模型訓練時,許多額外的挑戰會出現,這可能需要將張量並行與 FSDP 相結合。
當世界規模(GPU 數量)變得過大(超過 128/256 個 GPU)時,FSDP 的集合通訊(如 `allgather`)會被環延遲所主導。透過在 FSDP 之上實現 TP/SP,可以將 FSDP 世界規模減少 8 倍,透過僅對主機內部應用 FSDP,從而將延遲成本降低相同的數量。
在資料並行達到極限,由於收斂性和 GPU 記憶體限制而無法將全域性批次大小提高到超過 GPU 數量時,張量/序列並行是唯一已知的方法可以“大致”確定全域性批次大小並繼續使用更多 GPU 進行擴充套件。這意味著模型大小和 GPU 數量都可以繼續擴充套件。
對於某些型別的模型,當局部批次大小變小時,TP/SP 可以產生更最佳化的浮點運算(FLOPS)矩陣乘法形狀。
那麼,在預訓練時,達到這些限制有多容易?目前,預訓練一個具有數十億或數萬億個 token 的大型語言模型(LLM)可能需要數月時間,即使使用數千個 GPU。
在 LLM 大規模訓練時,總會遇到限制 1。例如,Llama 2 70B 在 2k 個 GPU 上訓練了 35 天,在 2k 的規模下需要多維並行。
當 Transformer 模型變得更大(例如 Llama2 70B)時,也會很快達到限制 2。即使區域性 `batch_size=1`,由於記憶體和收斂性限制,也無法僅使用 FSDP。例如,Llama 2 的全域性批次大小為 1K,因此在 2K 個 GPU 上無法僅使用資料並行。
如何應用張量並行#
PyTorch 張量並行 API 提供了一組模組級原語(`ParallelStyle`),用於配置模型每個單獨層的分片,包括:
`ColwiseParallel` 和 `RowwiseParallel`:按列或按行分片 `nn.Linear` 和 `nn.Embedding`。
`SequenceParallel`:對 `nn.LayerNorm`、`nn.Dropout`、`RMSNormPython` 等執行分片計算。
`PrepareModuleInput` 和 `PrepareModuleOutput`:使用適當的通訊操作配置模組輸入/輸出的分片佈局。
為了演示如何使用 PyTorch 原生的張量並行 API,讓我們看一下一個常見的 Transformer 模型。在本教程中,我們使用最新的 Llama2 模型作為參考 Transformer 模型實現,因為它也在社群中被廣泛使用。
由於張量並行會在一組裝置上分片單個張量,因此我們需要先設定分散式環境(例如 NCCL communicators)。張量並行是一種單程式多資料(SPMD)分片演算法,類似於 PyTorch DDP/FSDP,它在底層利用 PyTorch DTensor 來執行分片。它還利用 DeviceMesh 抽象(底層管理 ProcessGroups)來進行裝置管理和分片。有關如何使用 DeviceMesh 設定多維並行,請參閱本教程。張量並行通常在每個主機內部工作,所以我們先初始化一個連線主機內 8 個 GPU 的 DeviceMesh。
from torch.distributed.device_mesh import init_device_mesh
tp_mesh = init_device_mesh("cuda", (8,))
現在我們已經初始化了 DeviceMesh,讓我們詳細看看 Llama 2 模型架構,並瞭解如何執行張量並行分片。這裡我們重點關注核心的 `TransformerBlock`,其中 Transformer 模型堆疊相同的 `TransformerBlock` 以擴充套件模型。
核心的 `TransformerBlock` 由一個 `Attention` 層和一個 `FeedForward` 層組成。讓我們先看看更簡單的 `FeedForward` 層。對於 `FeedForward` 層,它包含三個 Linear 層,執行 SwiGLU 風格的 MLP,檢視其 forward 函式:
# forward in the FeedForward layer
def forward(self, x):
return self.w2(F.silu(self.w1(x)) * self.w3(x))
它並行執行 `w1` 和 `w3` 的矩陣乘法,然後執行 `w2` 的矩陣乘法,結果是 `w1`/`w3` 線性投影結果的組合。這意味著我們可以借鑑張量並行論文中的思想,將 `w1`/`w3` Linear 層按列分片,並將 `w2` Linear層按行分片,這樣在所有三個層結束時只有一個 `allreduce` 通訊。使用 PyTorch 原生的張量並行,我們可以為 `FeedForward` 層簡單地建立一個 `parallelize_plan`,如下所示:
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, parallelize_module
layer_tp_plan = {
# by default ColwiseParallel input layouts is replicated
# and RowwiseParallel output layouts is replicated
"feed_foward.w1": ColwiseParallel(),
"feed_forward.w2": RowwiseParallel(),
"feed_forward.w3": ColwiseParallel(),
}
這就是我們使用 PyTorch 張量並行 API 為 `FeedForward` 層配置分片的方式。請注意,使用者只需指定如何分片各個層,通訊(例如 `allreduce`)將在後臺自動進行。
接下來是 `Attention` 層。它由 `wq`、`wk`、`wv` Linear 層組成,用於將輸入投影到 `q` / `k` / `v`,然後它執行注意力計算並與 `wo` Linear 層進行輸出投影。張量並行在這裡旨在對 q/k/v 投影執行按列分片,並對 `wo` Linear 投影執行按行分片。因此,我們可以將 Attention 的計劃新增到我們剛剛起草的 `tp_plan` 中:
layer_tp_plan = {
# by default ColwiseParallel input layouts is replicated
# and RowwiseParallel output layouts is replicated
"attention.wq": ColwiseParallel(use_local_output=False),
"attention.wk": ColwiseParallel(use_local_output=False),
"attention.wv": ColwiseParallel(use_local_output=False),
"attention.wo": RowwiseParallel(),
"feed_forward.w1": ColwiseParallel(),
"feed_forward.w2": RowwiseParallel(),
"feed_forward.w3": ColwiseParallel(),
}
這幾乎是我們用於對 `TransformerBlock` 應用張量並行所需的 `layer_tp_plan`。但是,我們需要注意的一點是,當按列分片線性層時,線性層的輸出將在最後一個張量維度上分片,而按行分片線性層直接接受一個在最後一個維度上分片的輸入。如果在按列分片的線性層和按行分片的線性層之間有任何其他張量操作(例如 view 操作),我們需要調整相關的形狀相關操作以適應分片形狀。
對於 Llama 模型,在注意力層中,有幾個與形狀相關的 view 操作。具體來說,對於 `wq` / `wk` / `wv` 線性層的按列並行,啟用張量在 `num_heads` 維度上進行分片。為了管理全域性 `num_heads` 和區域性 `num_heads` 之間的差異,我們應該設定 `use_local_output=False` 以確保輸出是 DTensor。與常規張量不同,DTensor 瞭解並行計劃,並將自動處理 `num_heads` 維度的變化。
最後,我們需要呼叫 `parallelize_module` API 來使每個 `TransformerBlock` 的計劃生效。在底層,它將 `Attention` 和 `FeedForward` 層中的模型引數分佈到 DTensors,並(如果需要)為模組的輸入和輸出(分別在每個模組之前和之後)註冊通訊鉤子。
for layer_id, transformer_block in enumerate(model.layers):
layer_tp_plan = {...} # i.e. the plan we just generated
parallelize_module(
module=transformer_block,
device_mesh=tp_mesh,
parallelize_plan=layer_tp_plan,
)
現在我們已經詳細說明了每個 `TransformerBlock` 的分片計劃,通常在第一層有一個 `nn.Embedding` 層,在最後一層有一個 `nn.Linear` 投影層。使用者可以選擇按行或按列分片第一個 `nn.Embedding` 層,並按列分片最後一個 `nn.Linear` 投影層,並指定適當的輸入和輸出佈局。下面是一個示例:
model = parallelize_module(
model,
tp_mesh,
{
"tok_embeddings": RowwiseParallel(
input_layouts=Replicate(),
),
"output": ColwiseParallel(
output_layouts=Replicate(),
),
}
)
注意
如果需要分割槽的模型太大而無法放入 CPU 記憶體,可以採用 `meta` 裝置初始化(例如,先在 meta 裝置上初始化模型,分片層,然後具體化模型),或者在 Transformer 模型初始化過程中逐層並行化 `TransformerBlock`。
將序列並行應用於 `LayerNorm/RMSNorm` 層#
序列並行建立在上面介紹的張量並行之上。與僅在 `Attention` 模組和 `FeedForward` 模組內對張量進行分片並保持其模組輸入和輸出(即前向傳播中的啟用和後向傳播中的梯度)副本的常規張量並行相比,序列並行在序列維度上保持它們的分片狀態。
在一個典型的 `TransformerBlock` 中,forward 函式結合了 norm 層(`LayerNorm` 或 `RMSNorm`)、一個注意力層、一個前饋層以及殘差連線。例如:
# forward in a TransformerBlock
def forward(self, x):
h = x + self.attention(self.attention_norm(x))
out = h + self.feed_forward(self.ffn_norm(h))
return out
在大多數用例中,`Attention` 和 `FeedForward` 模組之外的啟用(和梯度)的形狀為 `[batch size, sequence length, hidden dimension]`。用 DTensor 的術語來說,序列並行在模組的前向/後向計算中使用 `Shard(1)` 佈局。遵循之前的程式碼示例,下面的程式碼演示瞭如何將序列並行應用於 `TransformerBlock` 中的 norm 層:
首先,讓我們匯入序列並行所需的依賴項:
from torch.distributed.tensor.parallel import (
PrepareModuleInput,
SequenceParallel,
)
接下來,讓我們調整 `layer_tp_plan` 以在 `RMSNorm` 層上啟用序列並行:
layer_tp_plan = {
# Now the input and output of SequenceParallel has Shard(1) layouts,
# to represent the input/output tensors sharded on the sequence dimension
"attention_norm": SequenceParallel(),
"attention": PrepareModuleInput(
input_layouts=(Shard(1), Replicate()),
desired_input_layouts=(Replicate(), Replicate()),
),
"attention.wq": ColwiseParallel(use_local_output=False),
"attention.wk": ColwiseParallel(use_local_output=False),
"attention.wv": ColwiseParallel(use_local_output=False),
"attention.wo": RowwiseParallel(output_layouts=Shard(1)),
"ffn_norm": SequenceParallel(),
"feed_forward": PrepareModuleInput(
input_layouts=(Shard(1),),
desired_input_layouts=(Replicate(),),
),
"feed_forward.w1": ColwiseParallel(),
"feed_forward.w2": RowwiseParallel(output_layouts=Shard(1)),
"feed_forward.w3": ColwiseParallel(),
}
可以看到,我們現在使用 `PrepareModuleInput` 來修改 Attention 和 FeedForward 層的模組輸入佈局,從 `Shard(1)` 更改為 `Replicate()`,並將其輸出佈局標記為 `Shard(1)`。與張量並行發生的情況類似,使用者只需指定輸入和輸出的張量分片佈局,層之間的通訊將自動發生。
請注意,使用序列並行時,我們假設 `TransformerBlock` 的輸入和輸出始終在序列維度上分片,以便可以無縫地連線多個 `TransformerBlock`。這可以透過顯式地將第一個 `nn.Embedding` 層的輸出和最後一個 `nn.Linear` 投影層的輸入指定為 `Shard(1)` 來實現。
model = parallelize_module(
model,
tp_mesh,
{
"tok_embeddings": RowwiseParallel(
input_layouts=Replicate(),
output_layouts=Shard(1),
),
"norm": SequenceParallel(),
"output": ColwiseParallel(
input_layouts=Shard(1),
output_layouts=Replicate()
),
}
)
應用損失並行#
損失並行是一項相關的技術,用於在計算損失函式時節省記憶體和通訊,因為模型輸出通常非常大。在損失並行中,當模型輸出在(通常非常大的)詞彙表維度上分片時,交叉熵損失可以高效地計算,而無需將所有模型輸出收集到每個 GPU 上。這不僅顯著減少了記憶體消耗,而且透過減少通訊開銷並並行執行分片計算來提高了訓練速度。下圖簡要說明了損失並行如何透過執行分片計算來避免將所有模型輸出收集到每個 GPU 上。
圖 2. 在一個 GPU 上使用損失平行計算的交叉熵損失前向傳播。藍色表示分片張量;綠色表示複製張量;黃色表示具有部分值的張量(待 all-reduce)。黑色箭頭表示本地計算;紅色箭頭表示 GPU 之間的函式式集合通訊。#
在 PyTorch 張量並行 API 中,可以透過上下文管理器 `loss_parallel` 來啟用損失並行,使用它可以直接使用 `torch.nn.functional.cross_entropy` 或 `torch.nn.CrossEntropyLoss`,而無需修改程式碼的其他部分。
為了應用損失並行,模型預測(通常形狀為 `[batch size, sequence length, vocabulary size]`)應該在詞彙表維度上進行分片。這可以透過標記最後一個線性投影層輸出的輸出佈局輕鬆實現:
model = parallelize_module(
model,
tp_mesh,
{
"tok_embeddings": RowwiseParallel(
input_layouts=Replicate(),
output_layouts=Shard(1),
),
"norm": SequenceParallel(),
"output": ColwiseParallel(
input_layouts=Shard(1),
# use DTensor as the output
use_local_output=False,
),
},
)
在上面的程式碼中,我們還對輸出之前的 norm 層應用了序列並行。我們應用 `use_local_output=False` 以使輸出保持為 DTensor,以便與 `loss_parallel` 上下文管理器配合使用。之後,可以像下面那樣簡單地呼叫 cross_entropy 損失函式。請注意,後向傳播也需要在此上下文內進行。
import torch.nn.functional as F
from torch.distributed.tensor.parallel import loss_parallel
pred = model(input_ids)
with loss_parallel():
# assuming pred and labels are of the shape [batch, seq, vocab]
loss = F.cross_entropy(pred.flatten(0, 1), labels.flatten(0, 1))
loss.backward()
將張量並行與完全分片資料並行結合#
既然我們已經展示瞭如何將張量/序列並行應用於模型,我們還可以看看張量並行和完全分片資料並行如何協同工作。由於張量並行會產生阻塞計算的通訊,因此我們希望確保它在快速通訊通道(如 NVLink)內執行。實際上,我們通常在每個主機內部應用張量並行,並在主機之間應用完全分片資料並行。
圖 3. FSDP 和 TP 在獨立的裝置維度上工作,FSDP 通訊發生在主機之間,TP 通訊發生在主機內部。#
這種二維並行模式可以透過二維 DeviceMesh 輕鬆表達,我們只需要將每個“子” DeviceMesh 傳遞給每個單獨的並行 API:
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, parallelize_module
from torch.distributed.fsdp import fully_shard
# i.e. 2-D mesh is [dp, tp], training on 64 GPUs that performs 8 way DP and 8 way TP
mesh_2d = init_device_mesh("cuda", (8, 8))
tp_mesh = mesh_2d["tp"] # a submesh that connects intra-host devices
dp_mesh = mesh_2d["dp"] # a submesh that connects inter-host devices
model = Model(...)
tp_plan = {...}
# apply Tensor Parallel intra-host on tp_mesh
model_tp = parallelize_module(model, tp_mesh, tp_plan)
# apply FSDP inter-host on dp_mesh
model_2d = fully_shard(model_tp, mesh=dp_mesh, ...)
這將使我們能夠在每個主機內部(主機內)輕鬆應用張量並行,並在主機之間(主機間)應用 FSDP,並且 Llama 模型 **無需更改任何程式碼**。張量(模型)並行和資料並行技術相結合,能夠透過大量 GPU 繼續增加模型大小並高效地進行訓練。
結論#
本教程演示瞭如何將張量並行與完全分片資料並行結合,在數百到數千個 GPU 上訓練大型 Transformer 類模型。它解釋瞭如何將張量並行應用於模型的不同部分,而 **無需修改模型本身的程式碼**。張量並行是一種用於大規模訓練的高效模型並行技術。
要檢視本教程中解釋的完整的端到端程式碼示例,請參閱 pytorch/examples 儲存庫中的 張量並行示例。