FSDP2 入門#
創建於:2022 年 3 月 17 日 | 最後更新:2025 年 9 月 2 日 | 最後驗證:2024 年 11 月 5 日
作者:Wei Feng, Will Constable, Yifan Mao
注意
從 pytorch/examples 檢視本教程中的程式碼。FSDP1 已棄用。FSDP1 教程已存檔在 [1] 和 [2]
FSDP2 工作原理#
在 DistributedDataParallel (DDP) 訓練中,每個 rank 擁有一個模型副本並處理一個數據批次,最後使用 all-reduce 在 ranks 之間同步梯度。
與 DDP 相比,FSDP 透過分片模型引數、梯度和最佳化器狀態來減少 GPU 記憶體佔用。這使得訓練無法放入單個 GPU 的模型成為可能。如下圖所示,
在前向和後向計算之外,引數是完全分片的
在前向和後向計算之前,分片引數被 all-gather 成未分片引數
在後向計算內部,本地未分片梯度被 reduce-scatter 成分片梯度
最佳化器使用分片梯度更新分片引數,從而產生分片最佳化器狀態
FSDP 可以被認為是 DDP 的 all-reduce 操作分解為 reduce-scatter 和 all-gather 操作
與 FSDP1 相比,FSDP2 具有以下優點
如何使用 FSDP2#
模型初始化#
在子模組上應用 fully_shard:與 DDP 不同,我們應該在子模組以及根模型上應用 fully_shard。在下面的 Transformer 示例中,我們首先在每個層上應用了 fully_shard,然後是根模型
在
layers[i]的前向計算過程中,其餘層被分片以減少記憶體佔用在
fully_shard(model)內部,FSDP2 會排除model.layers中的引數,並將剩餘引數分類到引數組中,以便高效地進行 all-gather 和 reduce-scatter。fully_shard將分片模型移動到實際的訓練裝置(例如cuda)
命令:torchrun --nproc_per_node 2 train.py
from torch.distributed.fsdp import fully_shard, FSDPModule
model = Transformer()
for layer in model.layers:
fully_shard(layer)
fully_shard(model)
assert isinstance(model, Transformer)
assert isinstance(model, FSDPModule)
print(model)
# FSDPTransformer(
# (tok_embeddings): Embedding(...)
# ...
# (layers): 3 x FSDPTransformerBlock(...)
# (output): Linear(...)
# )
我們可以使用 print(model) 來檢查巢狀的包裝。 FSDPTransformer 是 Transformer 和 FSDPModule 的聯合類。對於 FSDPTransformerBlock 也是如此。所有 FSDP2 公共 API 都透過 FSDPModule 公開。例如,使用者可以呼叫 model.unshard() 來手動控制 all-gather 排程。有關詳細資訊,請參閱下面的“顯式預取”。
model.parameters() 作為 DTensor:fully_shard 在 ranks 之間分片引數,並將 model.parameters() 從普通的 torch.Tensor 轉換為 DTensor 來表示分片引數。FSDP2 預設在 dim-0 上分片,因此 DTensor 的 placement 是 Shard(dim=0)。假設我們有 N 個 ranks 和一個在分片前有 N 行的引數。分片後,每個 rank 將擁有該引數的 1 行。我們可以使用 param.to_local() 來檢查分片引數。
from torch.distributed.tensor import DTensor
for param in model.parameters():
assert isinstance(param, DTensor)
assert param.placements == (Shard(0),)
# inspect sharded parameters with param.to_local()
optim = torch.optim.Adam(model.parameters(), lr=1e-2)
請注意,最佳化器是在應用 fully_shard 後構造的。模型和最佳化器狀態字典都以 DTensor 的形式表示。
DTensor 促進了最佳化器、梯度裁剪和檢查點
torch.optim.Adam和torch.nn.utils.clip_grad_norm_對 DTensor 引數開箱即用。這使得程式碼在單裝置和分散式訓練之間保持一致。我們可以使用 DTensor 和 DCP API 來操作引數以獲取完整的 state dict,有關詳細資訊,請參閱“state dict”部分。對於分散式 state dict,我們可以儲存/載入檢查點(文件),而無需額外的通訊。
帶預取的 Forward/Backward#
命令:torchrun --nproc_per_node 2 train.py
for _ in range(epochs):
x = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
loss = model(x).sum()
loss.backward()
optim.step()
optim.zero_grad()
fully_shard 會向所有層註冊前向/後向鉤子,以便在計算前 all-gather 引數,並在計算後重新分片引數。為了重疊 all-gathers 和計算,FSDP2 提供了**隱式預取**,它與上述訓練迴圈開箱即用,以及供高階使用者手動控制 all-gather 排程的**顯式預取**。
隱式預取:CPU 執行緒在層 i 之前發出 all-gather i。All-gathers 被排入其自己的 cuda 流,而層 i 的計算發生在預設流中。對於非 CPU 密集型工作負載(例如具有大批次大小的 Transformer),all-gather i+1 可以與層 i 的計算重疊。隱式預取在後向操作中類似,只是 all-gathers 的發出順序與前向後順序相反。
我們建議使用者從隱式預取開始,以瞭解開箱即用的效能。
顯式預取:使用者可以透過 set_modules_to_forward_prefetch 指定前向順序,並透過 set_modules_to_backward_prefetch 指定後向順序。如下面的程式碼所示,CPU 執行緒在層 i 處發出 all-gather i + 1 和 i + 2。
顯式預取在以下情況下效果很好
CPU 密集型工作負載:如果使用隱式預取,當層 i 的核心執行時,CPU 執行緒將太慢而無法發出層 i+1 的 all-gather。我們必須在執行層 i 的前向計算之前顯式發出 all-gather i+1。
2 層及以上預取:隱式預取一次只 all-gather 下一層,以最大限度地減少記憶體佔用。透過顯式預取,可以一次 all-gather 多層,以可能獲得更好的效能,同時增加記憶體。請參閱程式碼中的 layers_to_prefetch。
提前發出第一個 all-gather:隱式預取在呼叫 model(x) 時發生。第一個 all-gather 會被暴露。我們可以顯式呼叫 model.unshard() 來提前發出第一個 all-gather。
命令:torchrun --nproc_per_node 2 train.py --explicit-prefetching
num_to_forward_prefetch = 2
for i, layer in enumerate(model.layers):
if i >= len(model.layers) - num_to_forward_prefetch:
break
layers_to_prefetch = [
model.layers[i + j] for j in range(1, num_to_forward_prefetch + 1)
]
layer.set_modules_to_forward_prefetch(layers_to_prefetch)
num_to_backward_prefetch = 2
for i, layer in enumerate(model.layers):
if i < num_to_backward_prefetch:
continue
layers_to_prefetch = [
model.layers[i - j] for j in range(1, num_to_backward_prefetch + 1)
]
layer.set_modules_to_backward_prefetch(layers_to_prefetch)
for _ in range(epochs):
# trigger 1st all-gather earlier
# this overlaps all-gather with any computation before model(x)
model.unshard()
x = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
loss = model(x).sum()
loss.backward()
optim.step()
optim.zero_grad()
啟用混合精度#
FSDP2 提供了一個靈活的 混合精度策略 來加速訓練。一個典型的用例是:
將 float32 引數轉換為 bfloat16 進行前向/後向計算,請參閱
param_dtype=torch.bfloat16將梯度提升為 float32 進行 reduce-scatter 以保持精度,請參閱
reduce_dtype=torch.float32
與 torch.amp 相比,FSDP2 混合精度具有以下優點:
高效靈活的引數轉換:
FSDPModule內的所有引數都在模組邊界(前向/後向之前和之後)一起轉換。我們可以為每個層設定不同的混合精度策略。例如,前幾層可以是 float32,其餘層可以是 bfloat16。float32 梯度歸約(reduce-scatter):梯度可能因 rank 而異。以 float32 歸約梯度對於數值計算可能至關重要。
命令:torchrun --nproc_per_node 2 train.py --mixed-precision
model = Transformer(model_args)
fsdp_kwargs = {
"mp_policy": MixedPrecisionPolicy(
param_dtype=torch.bfloat16,
reduce_dtype=torch.float32,
)
}
for layer in model.layers:
fully_shard(layer, **fsdp_kwargs)
fully_shard(model, **fsdp_kwargs)
# sharded parameters are float32
for param in model.parameters():
assert param.dtype == torch.float32
# unsharded parameters are bfloat16
model.unshard()
for param in model.parameters(recurse=False):
assert param.dtype == torch.bfloat16
model.reshard()
# optimizer states are in float32
optim = torch.optim.Adam(model.parameters(), lr=1e-2)
# training loop
# ...
梯度裁剪和帶 DTensor 的最佳化器#
命令:torchrun --nproc_per_node 2 train.py
# optim is constructed base on DTensor model parameters
optim = torch.optim.Adam(model.parameters(), lr=1e-2)
for _ in range(epochs):
x = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
loss = model(x).sum()
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_norm)
optim.step()
optim.zero_grad()
最佳化器在對模型應用 fully_shard 後進行初始化,並持有對 DTensor model.parameters() 的引用。對於梯度裁剪,torch.nn.utils.clip_grad_norm_ 可用於 DTensor 引數。Tensor 操作將在 DTensor 內部正確分派,以跨 ranks 通訊部分張量,從而保持單裝置語義。
帶 DTensor API 的 State Dicts#
我們展示瞭如何將完整的 state dict 轉換為 DTensor state dict 進行載入,以及如何將其轉換回完整的 state dict 進行儲存。
命令:torchrun --nproc_per_node 2 train.py
第一次執行時,它會為模型和最佳化器建立檢查點
第二次執行時,它從先前的檢查點載入以恢復訓練
載入 state dicts:我們在 meta device 下初始化模型,然後呼叫 fully_shard 將 model.parameters() 從普通的 torch.Tensor 轉換為 DTensor。從 torch.load 讀取完整 state dict 後,我們可以呼叫 distribute_tensor 將普通的 torch.Tensor 轉換為 DTensor,使用與 model.state_dict() 相同的 placement 和 device mesh。最後,我們可以呼叫 model.load_state_dict 將 DTensor state dicts 載入到模型中。
from torch.distributed.tensor import distribute_tensor
# mmap=True reduces CPU memory usage
full_sd = torch.load(
"checkpoints/model_state_dict.pt",
mmap=True,
weights_only=True,
map_location='cpu',
)
meta_sharded_sd = model.state_dict()
sharded_sd = {}
for param_name, full_tensor in full_sd.items():
sharded_meta_param = meta_sharded_sd.get(param_name)
sharded_tensor = distribute_tensor(
full_tensor,
sharded_meta_param.device_mesh,
sharded_meta_param.placements,
)
sharded_sd[param_name] = nn.Parameter(sharded_tensor)
# `assign=True` since we cannot call `copy_` on meta tensor
model.load_state_dict(sharded_sd, assign=True)
儲存 state dicts:model.state_dict() 返回一個 DTensor state dict。我們可以透過呼叫 full_tensor() 將 DTensor 轉換為普通的 torch.Tensor。內部它會發出一個跨 ranks 的 all-gather 來獲取未分片的普通 torch.Tensor 引數。對於 rank 0,full_param.cpu() 會逐個將張量解除安裝到 CPU,以避免因未分片引數而導致 GPU 記憶體峰值。
sharded_sd = model.state_dict()
cpu_state_dict = {}
for param_name, sharded_param in sharded_sd.items():
full_param = sharded_param.full_tensor()
if torch.distributed.get_rank() == 0:
cpu_state_dict[param_name] = full_param.cpu()
else:
del full_param
torch.save(cpu_state_dict, "checkpoints/model_state_dict.pt")
最佳化器狀態字典也類似工作(程式碼)。使用者可以自定義上述 DTensor 指令碼以與第三方檢查點一起使用。
如果不需要自定義,我們可以直接使用 DCP API 來支援單節點和多節點訓練。
帶 DCP API 的 State Dict#
命令:torchrun --nproc_per_node 2 train.py --dcp-api
第一次執行時,它會為模型和最佳化器建立檢查點
第二次執行時,它從先前的檢查點載入以恢復訓練
載入 state dicts:我們可以使用 set_model_state_dict 將一個完整的 state dict 載入到一個 FSDP2 模型中。使用 broadcast_from_rank0=True,我們可以在 rank 0 上僅載入完整的 state dict,以避免 CPU 記憶體峰值。DCP 會分片張量並將其廣播到其他 ranks。
from torch.distributed.checkpoint.state_dict import set_model_state_dict
set_model_state_dict(
model=model,
model_state_dict=full_sd,
options=StateDictOptions(
full_state_dict=True,
broadcast_from_rank0=True,
),
)
儲存 state dicts:get_model_state_dict 配合 full_state_dict=True 和 cpu_offload=True 會 all-gather 張量並將其解除安裝到 CPU。它的工作方式與 DTensor API 類似。
from torch.distributed.checkpoint.state_dict import get_model_state_dict
model_state_dict = get_model_state_dict(
model=model,
options=StateDictOptions(
full_state_dict=True,
cpu_offload=True,
)
)
torch.save(model_state_dict, "model_state_dict.pt")
有關使用 set_optimizer_state_dict 和 get_optimizer_state_dict 載入和儲存最佳化器狀態字典,請參考 pytorch/examples。
FSDP1 到 FSDP2 遷移指南#
讓我們看一個 FSDP 用法和等效的 fully_shard 用法的示例。我們將重點介紹關鍵差異並提出遷移步驟。
原始 FSDP() 用法
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
with torch.device("meta"):
model = Transformer()
policy = ModuleWrapPolicy({TransformerBlock})
model = FSDP(model, auto_wrap_policy=policy)
def param_init_fn(module: nn.Module) -> None: ...
model = FSDP(model, auto_wrap_policy=policy, param_init_fn=param_init_fn)
新的 fully_shard() 用法
with torch.device("meta"):
model = Transformer()
for module in model.modules():
if isinstance(module, TransformerBlock):
fully_shard(module)
fully_shard(model)
for tensor in itertools.chain(model.parameters(), model.buffers()):
assert tensor.device == torch.device("meta")
# Initialize the model after sharding
model.to_empty(device="cuda")
model.reset_parameters()
遷移步驟
替換匯入
直接實現你的“策略”(將
fully_shard應用於所需的子層)使用
fully_shard而不是FSDP來包裝你的根模型刪除
param_init_fn並手動呼叫model.reset_parameters()替換其他 FSDP1 kwargs(見下文)
sharding_strategy
FULL_SHARD:
reshard_after_forward=TrueSHARD_GRAD_OP:
reshard_after_forward=FalseHYBRID_SHARD:
reshard_after_forward=True配合 2D device mesh_HYBRID_SHARD_ZERO2:
reshard_after_forward=False配合 2D device mesh
cpu_offload
CPUOffload.offload_params=False:
offload_policy=NoneCPUOffload.offload_params = True:
offload_policy=CPUOffloadPolicy()
backward_prefetch
BACKWARD_PRE: 始終使用
BACKWARD_POST: 不支援
mixed_precision
buffer_dtype被省略,因為 fully_shard 不分片 buffersfully_shard 的
cast_forward_inputs對映到 FSDP1 中的cast_forward_inputs和cast_root_forward_inputsoutput_dtype是 fully_shard 的新配置
device_id: 從 device_mesh 的 device 推斷
sync_module_states=True/False: 已移至 DCP。使用者可以使用 set_model_state_dict 配合 broadcast_from_rank0=True 從 rank0 廣播 state dicts。
forward_prefetch: 可以透過以下方式手動控制預取:
使用這些 API 控制自動預取:set_modules_to_forward_prefetch 和 set_modules_to_backward_prefetch
limit_all_gathers: 不再需要,因為 fully_shard 移除了 CPU 同步
use_orig_params: 始終使用原始引數(不再是 flat parameter)
no_sync(): set_requires_gradient_sync
ignored_params 和 ignored_states: ignored_params