使用 Fully Sharded Data Parallel (FSDP) 進行高階模型訓練#
創建於:2024年10月31日 | 最後更新:2024年10月31日 | 最後驗證:2024年11月5日
作者: Hamid Shojanazeri, Less Wright, Rohan Varma, Yanli Zhao
PyTorch 的 Fully Sharded Data Parallel 模組:一個用於跨資料並行工作節點分片模組引數的包裝器。
資料並行工作節點。
PyTorch 1.12 或更高版本
閱讀 FSDP API。
本教程作為 PyTorch 1.12 版本的一部分,介紹了 Fully Sharded Data Parallel (FSDP) 的更高階功能。要熟悉 FSDP,請參閱 FSDP 入門教程。
在本教程中,我們將以 HuggingFace (HF) T5 模型進行文字摘要微調作為工作示例。
該示例使用了 Wikihow 資料集,為簡單起見,我們將在單個節點、配備 8 個 A100 GPU 的 P4dn 例項上展示訓練。我們現在有幾篇部落格文章( (連結1), (連結2))以及一篇關於多節點叢集上大規模 FSDP 訓練的 論文。
FSDP 是一個生產就緒的軟體包,專注於易用性、效能和長期支援。FSDP 的主要優點之一是減少每個 GPU 的記憶體佔用。這使得訓練更大的模型所需的總記憶體低於 DDP,並利用計算和通訊的重疊來高效地訓練模型。這種減少的記憶體壓力可以用來訓練更大的模型或增加批次大小,從而可能提高整體訓練吞吐量。您可以在此處閱讀更多關於 PyTorch FSDP 的資訊:這裡。
本教程中的 FSDP 功能#
Transformer 自動包裝策略
混合精度
在裝置上初始化 FSDP 模型
分片策略
向後預取
透過流式傳輸到 CPU 來儲存模型檢查點
FSDP 工作原理回顧#
高層次上,FSDP 的工作方式如下:
在建構函式中
分片模型引數,每個 rank 只保留自己的分片。
在前向傳播中
執行 `all_gather` 來收集所有 rank 的所有分片,以恢復此 FSDP 單元的完整引數,並執行前向計算。
丟棄已收集的非擁有引數分片以釋放記憶體。
在反向傳播中
執行 `all_gather` 來收集所有 rank 的所有分片,以恢復此 FSDP 單元的完整引數,並執行反向計算。
丟棄非擁有引數以釋放記憶體。
執行 `reduce_scatter` 來同步梯度。
微調 HF T5#
HF T5 預訓練模型有四種不同的尺寸,從 6000 萬引數的小模型到 110 億引數的 XXL 模型。在本教程中,我們將展示使用 FSDP 和 WikiHow 資料集對 T5 3B 模型進行文字摘要微調。本教程的主要重點是展示 FSDP 中有助於訓練 3B 以上大模型的各種可用功能。此外,我們還將涵蓋 Transformer 模型的特定功能。本教程的程式碼可在 Pytorch examples 中找到。
設定
1.1 安裝最新 PyTorch
pip3 install torch torchvision torchaudio
1.2 資料集設定
請建立一個 `data` 資料夾,從 wikihowAll.csv 和 wikihowSep.cs 下載 WikiHow 資料集,並將它們放在 `data` 資料夾中。我們將使用來自 summarization_dataset 的 wikihow 資料集。
接下來,我們將以下程式碼片段新增到 Python 指令碼 "T5_training.py" 中。
注意
本教程的完整原始碼可在 PyTorch examples 中找到。
1.3 匯入必要的包
import os
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from transformers import AutoTokenizer, GPT2TokenizerFast
from transformers import T5Tokenizer, T5ForConditionalGeneration
import functools
from torch.optim.lr_scheduler import StepLR
import torch.nn.functional as F
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
from transformers.models.t5.modeling_t5 import T5Block
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
checkpoint_wrapper,
CheckpointImpl,
apply_activation_checkpointing_wrapper)
from torch.distributed.fsdp import (
FullyShardedDataParallel as FSDP,
MixedPrecision,
BackwardPrefetch,
ShardingStrategy,
FullStateDictConfig,
StateDictType,
)
from torch.distributed.fsdp.wrap import (
transformer_auto_wrap_policy,
enable_wrap,
wrap,
)
from functools import partial
from torch.utils.data import DataLoader
from pathlib import Path
from summarization_dataset import *
from transformers.models.t5.modeling_t5 import T5Block
from typing import Type
import time
import tqdm
from datetime import datetime
1.4 分散式訓練設定。我們使用兩個輔助函式來初始化分散式訓練的程序,並在訓練完成後進行清理。在本教程中,我們將使用 torch elastic,透過 torchrun,它將自動設定 `RANK` 和 `WORLD_SIZE` 工作節點。
def setup():
# initialize the process group
dist.init_process_group("nccl")
def cleanup():
dist.destroy_process_group()
2.1 設定 HuggingFace T5 模型
def setup_model(model_name):
model = T5ForConditionalGeneration.from_pretrained(model_name)
tokenizer = T5Tokenizer.from_pretrained(model_name)
return model, tokenizer
我們還添加了幾個輔助函式來處理日期和格式化記憶體指標。
def get_date_of_run():
"""create date and time for file save uniqueness
example: 2022-05-07-08:31:12_PM'
"""
date_of_run = datetime.now().strftime("%Y-%m-%d-%I:%M:%S_%p")
print(f"--> current date and time of run = {date_of_run}")
return date_of_run
def format_metrics_to_gb(item):
"""quick function to format numbers to gigabyte and round to 4 digit precision"""
metric_num = item / g_gigabyte
metric_num = round(metric_num, ndigits=4)
return metric_num
2.2 定義訓練函式
def train(args, model, rank, world_size, train_loader, optimizer, epoch, sampler=None):
model.train()
local_rank = int(os.environ['LOCAL_RANK'])
fsdp_loss = torch.zeros(2).to(local_rank)
if sampler:
sampler.set_epoch(epoch)
if rank==0:
inner_pbar = tqdm.tqdm(
range(len(train_loader)), colour="blue", desc="r0 Training Epoch"
)
for batch in train_loader:
for key in batch.keys():
batch[key] = batch[key].to(local_rank)
optimizer.zero_grad()
output = model(input_ids=batch["source_ids"],attention_mask=batch["source_mask"],labels=batch["target_ids"] )
loss = output["loss"]
loss.backward()
optimizer.step()
fsdp_loss[0] += loss.item()
fsdp_loss[1] += len(batch)
if rank==0:
inner_pbar.update(1)
dist.all_reduce(fsdp_loss, op=dist.ReduceOp.SUM)
train_accuracy = fsdp_loss[0] / fsdp_loss[1]
if rank == 0:
inner_pbar.close()
print(
f"Train Epoch: \t{epoch}, Loss: \t{train_accuracy:.4f}"
)
return train_accuracy
2.3 定義驗證函式
def validation(model, rank, world_size, val_loader):
model.eval()
correct = 0
local_rank = int(os.environ['LOCAL_RANK'])
fsdp_loss = torch.zeros(3).to(local_rank)
if rank == 0:
inner_pbar = tqdm.tqdm(
range(len(val_loader)), colour="green", desc="Validation Epoch"
)
with torch.no_grad():
for batch in val_loader:
for key in batch.keys():
batch[key] = batch[key].to(local_rank)
output = model(input_ids=batch["source_ids"],attention_mask=batch["source_mask"],labels=batch["target_ids"])
fsdp_loss[0] += output["loss"].item() # sum up batch loss
fsdp_loss[1] += len(batch)
if rank==0:
inner_pbar.update(1)
dist.all_reduce(fsdp_loss, op=dist.ReduceOp.SUM)
val_loss = fsdp_loss[0] / fsdp_loss[1]
if rank == 0:
inner_pbar.close()
print(f"Validation Loss: {val_loss:.4f}")
return val_loss
2.4 定義包裝了 FSDP 的分散式訓練函式
def fsdp_main(args):
model, tokenizer = setup_model("t5-base")
local_rank = int(os.environ['LOCAL_RANK'])
rank = int(os.environ['RANK'])
world_size = int(os.environ['WORLD_SIZE'])
dataset = load_dataset('wikihow', 'all', data_dir='data/')
print(dataset.keys())
print("Size of train dataset: ", dataset['train'].shape)
print("Size of Validation dataset: ", dataset['validation'].shape)
#wikihow(tokenizer, type_path, num_samples, input_length, output_length, print_text=False)
train_dataset = wikihow(tokenizer, 'train', 1500, 512, 150, False)
val_dataset = wikihow(tokenizer, 'validation', 300, 512, 150, False)
sampler1 = DistributedSampler(train_dataset, rank=rank, num_replicas=world_size, shuffle=True)
sampler2 = DistributedSampler(val_dataset, rank=rank, num_replicas=world_size)
setup()
train_kwargs = {'batch_size': args.batch_size, 'sampler': sampler1}
test_kwargs = {'batch_size': args.test_batch_size, 'sampler': sampler2}
cuda_kwargs = {'num_workers': 2,
'pin_memory': True,
'shuffle': False}
train_kwargs.update(cuda_kwargs)
test_kwargs.update(cuda_kwargs)
train_loader = torch.utils.data.DataLoader(train_dataset,**train_kwargs)
val_loader = torch.utils.data.DataLoader(val_dataset, **test_kwargs)
t5_auto_wrap_policy = functools.partial(
transformer_auto_wrap_policy,
transformer_layer_cls={
T5Block,
},
)
sharding_strategy: ShardingStrategy = ShardingStrategy.SHARD_GRAD_OP #for Zero2 and FULL_SHARD for Zero3
torch.cuda.set_device(local_rank)
#init_start_event = torch.cuda.Event(enable_timing=True)
#init_end_event = torch.cuda.Event(enable_timing=True)
#init_start_event.record()
bf16_ready = (
torch.version.cuda
and torch.cuda.is_bf16_supported()
and LooseVersion(torch.version.cuda) >= "11.0"
and dist.is_nccl_available()
and nccl.version() >= (2, 10)
)
if bf16_ready:
mp_policy = bfSixteen
else:
mp_policy = None # defaults to fp32
# model is on CPU before input to FSDP
model = FSDP(model,
auto_wrap_policy=t5_auto_wrap_policy,
mixed_precision=mp_policy,
#sharding_strategy=sharding_strategy,
device_id=torch.cuda.current_device())
optimizer = optim.AdamW(model.parameters(), lr=args.lr)
scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
best_val_loss = float("inf")
curr_val_loss = float("inf")
file_save_name = "T5-model-"
if rank == 0:
time_of_run = get_date_of_run()
dur = []
train_acc_tracking = []
val_acc_tracking = []
training_start_time = time.time()
if rank == 0 and args.track_memory:
mem_alloc_tracker = []
mem_reserved_tracker = []
for epoch in range(1, args.epochs + 1):
t0 = time.time()
train_accuracy = train(args, model, rank, world_size, train_loader, optimizer, epoch, sampler=sampler1)
if args.run_validation:
curr_val_loss = validation(model, rank, world_size, val_loader)
scheduler.step()
if rank == 0:
print(f"--> epoch {epoch} completed...entering save and stats zone")
dur.append(time.time() - t0)
train_acc_tracking.append(train_accuracy.item())
if args.run_validation:
val_acc_tracking.append(curr_val_loss.item())
if args.track_memory:
mem_alloc_tracker.append(
format_metrics_to_gb(torch.cuda.memory_allocated())
)
mem_reserved_tracker.append(
format_metrics_to_gb(torch.cuda.memory_reserved())
)
print(f"completed save and stats zone...")
if args.save_model and curr_val_loss < best_val_loss:
# save
if rank == 0:
print(f"--> entering save model state")
save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(
model, StateDictType.FULL_STATE_DICT, save_policy
):
cpu_state = model.state_dict()
#print(f"saving process: rank {rank} done w state_dict")
if rank == 0:
print(f"--> saving model ...")
currEpoch = (
"-" + str(epoch) + "-" + str(round(curr_val_loss.item(), 4)) + ".pt"
)
print(f"--> attempting to save model prefix {currEpoch}")
save_name = file_save_name + "-" + time_of_run + "-" + currEpoch
print(f"--> saving as model name {save_name}")
torch.save(cpu_state, save_name)
if curr_val_loss < best_val_loss:
best_val_loss = curr_val_loss
if rank==0:
print(f"-->>>> New Val Loss Record: {best_val_loss}")
dist.barrier()
cleanup()
2.5 解析引數並設定主函式
if __name__ == '__main__':
# Training settings
parser = argparse.ArgumentParser(description='PyTorch T5 FSDP Example')
parser.add_argument('--batch-size', type=int, default=4, metavar='N',
help='input batch size for training (default: 64)')
parser.add_argument('--test-batch-size', type=int, default=4, metavar='N',
help='input batch size for testing (default: 1000)')
parser.add_argument('--epochs', type=int, default=2, metavar='N',
help='number of epochs to train (default: 3)')
parser.add_argument('--lr', type=float, default=.002, metavar='LR',
help='learning rate (default: .002)')
parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
help='Learning rate step gamma (default: 0.7)')
parser.add_argument('--no-cuda', action='store_true', default=False,
help='disables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S',
help='random seed (default: 1)')
parser.add_argument('--track_memory', action='store_false', default=True,
help='track the gpu memory')
parser.add_argument('--run_validation', action='store_false', default=True,
help='running the validation')
parser.add_argument('--save-model', action='store_false', default=True,
help='For Saving the current Model')
args = parser.parse_args()
torch.manual_seed(args.seed)
fsdp_main(args)
使用 torchrun 執行訓練
torchrun --nnodes 1 --nproc_per_node 4 T5_training.py
Transformer 包裝策略#
如 前一教程 所述,`auto_wrap_policy` 是 FSDP 的一項功能,可以輕鬆地自動分片給定模型,並將模型、最佳化器和梯度分片放入不同的 FSDP 單元中。
對於 Transformer 編碼器-解碼器等一些架構,模型的某些部分(如嵌入表)由編碼器和解碼器共享。在這種情況下,我們需要將嵌入表放在外部 FSDP 單元中,以便編碼器和解碼器都能訪問它。此外,透過註冊 Transformer 的層類,分片計劃可以大大提高通訊效率。在 PyTorch 1.12 中,FSDP 增加了此支援,現在我們有了一個用於 Transformer 的包裝策略。
可以如下建立,其中 T5Block 代表 T5 Transformer 層類(包含 MHSA 和 FFN)。
t5_auto_wrap_policy = functools.partial(
transformer_auto_wrap_policy,
transformer_layer_cls={
T5Block,
},
)
torch.cuda.set_device(local_rank)
model = FSDP(model,
auto_wrap_policy=t5_auto_wrap_policy)
要檢視包裝後的模型,您可以輕鬆地列印模型並直觀地檢查分片和 FSDP 單元。
混合精度#
FSDP 支援靈活的混合精度訓練,允許使用任意的降低精度型別(如 fp16 或 bfloat16)。目前 BFloat16 僅在 Ampere GPU 上可用,因此在使用它之前您需要確認本機支援。例如,在 V100 上,仍然可以執行 BFloat16,但由於是非本機執行,可能會導致顯著的減速。
要檢查 BFloat16 是否得到本機支援,您可以使用以下方法:
bf16_ready = (
torch.version.cuda
and torch.cuda.is_bf16_supported()
and LooseVersion(torch.version.cuda) >= "11.0"
and dist.is_nccl_available()
and nccl.version() >= (2, 10)
)
FSDP 中混合精度的優點之一是允許對引數、梯度和緩衝區提供不同精度級別的精細控制,如下所示:
fpSixteen = MixedPrecision(
param_dtype=torch.float16,
# Gradient communication precision.
reduce_dtype=torch.float16,
# Buffer precision.
buffer_dtype=torch.float16,
)
bfSixteen = MixedPrecision(
param_dtype=torch.bfloat16,
# Gradient communication precision.
reduce_dtype=torch.bfloat16,
# Buffer precision.
buffer_dtype=torch.bfloat16,
)
fp32_policy = MixedPrecision(
param_dtype=torch.float32,
# Gradient communication precision.
reduce_dtype=torch.float32,
# Buffer precision.
buffer_dtype=torch.float32,
)
請注意,如果未指定某種型別(引數、reduce、buffer),則它們將完全不進行轉換。
這種靈活性允許使用者進行精細控制,例如僅設定以降低精度進行梯度通訊,而所有引數/緩衝區計算都以全精度進行。這在節點內通訊是主要瓶頸且引數/緩衝區必須保持全精度以避免準確性問題的情況下可能很有用。這可以透過以下策略實現:
grad_bf16 = MixedPrecision(reduce_dtype=torch.bfloat16)
在 2.4 中,我們只需將相關的混合精度策略新增到 FSDP 包裝器中:
model = FSDP(model,
auto_wrap_policy=t5_auto_wrap_policy,
mixed_precision=bfSixteen)
在我們的實驗中,我們觀察到使用 BFloat16 進行訓練可以提高高達 4 倍的速度,並且在某些實驗中記憶體減少了約 30%,可用於增加批次大小。
在裝置上初始化 FSDP 模型#
在 1.12 版本中,FSDP 支援一個 `device_id` 引數,用於在由 `device_id` 指定的裝置上初始化輸入的 CPU 模組。當整個模型不適合單個 GPU 但適合主機 CPU 記憶體時,這非常有用。當指定 `device_id` 時,FSDP 將在每個 FSDP 單元的基礎上將模型移動到指定裝置,避免 GPU OOM 問題,同時初始化速度比基於 CPU 的初始化快得多。
torch.cuda.set_device(local_rank)
model = FSDP(model,
auto_wrap_policy=t5_auto_wrap_policy,
mixed_precision=bfSixteen,
device_id=torch.cuda.current_device())
向後預取#
向後預取設定控制著何時請求下一個 FSDP 單元的引數。透過將其設定為 `BACKWARD_PRE`,可以在當前單元的計算開始之前,更早地開始請求並接收下一個 FSDP 單元的引數。這使得 `all_gather` 通訊和梯度計算重疊,從而可能提高訓練速度,但會略微增加記憶體消耗。可以在 2.4 中的 FSDP 包裝器中按如下方式使用:
torch.cuda.set_device(local_rank)
model = FSDP(model,
auto_wrap_policy=t5_auto_wrap_policy,
mixed_precision=bfSixteen,
device_id=torch.cuda.current_device(),
backward_prefetch = BackwardPrefetch.BACKWARD_PRE)
`backward_prefetch` 有兩種模式:`BACKWARD_PRE` 和 `BACKWARD_POST`。`BACKWARD_POST` 意味著下一個 FSDP 單元的引數在當前 FSDP 單元處理完成之前不會被請求,從而最大限度地減少記憶體開銷。在某些情況下,使用 `BACKWARD_PRE` 可以將模型訓練速度提高 1-20%,對於更大的模型,甚至可以提高更多。
透過流式傳輸到 Rank0 CPU 來儲存模型檢查點#
要使用 FULL_STATE_DICT 儲存模型檢查點(以與本地模型相同的方式儲存模型),PyTorch 1.12 提供了一些實用程式來支援儲存更大的模型。
首先,可以指定一個 `FullStateDictConfig`,允許 state_dict 僅在 rank 0 上填充並解除安裝到 CPU。
使用此配置時,FSDP 將 `allgather` 模型引數,一次將它們解除安裝到 CPU,僅在 rank 0 上進行。當 state_dict 最終儲存時,它將僅在 rank 0 上填充,幷包含 CPU 張量。這避免了對於大於單個 GPU 記憶體的模型可能出現的 OOM 問題,並允許使用者儲存大小約為使用者機器可用 CPU RAM 的模型。
此功能可以如下執行:
save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(
model, StateDictType.FULL_STATE_DICT, save_policy
):
cpu_state = model.state_dict()
if rank == 0:
save_name = file_save_name + "-" + time_of_run + "-" + currEpoch
torch.save(cpu_state, save_name)
摘要#
在本教程中,我們介紹了 PyTorch 1.12 中 FSDP 的許多新功能,並以 HF T5 作為執行示例。使用合適的包裝策略,特別是對於 Transformer 模型,結合混合精度和向後預取,應該可以加快您的訓練執行。此外,諸如在裝置上初始化模型以及透過流式傳輸到 CPU 來儲存檢查點等功能,在處理大型模型時應該有助於避免 OOM 錯誤。
我們正在積極為下一個版本新增 FSDP 的新功能。如果您有反饋、功能請求、問題或在使用 FSDP 時遇到問題,請隨時透過在 PyTorch Github 倉庫 中開啟 issue 來聯絡我們。