入門全分片資料並行(FSDP)#
作者: Hamid Shojanazeri, Yanli Zhao, Shen Li
注意
FSDP1 已棄用。請檢視 FSDP2 教程。
大規模訓練 AI 模型是一項艱鉅的任務,需要大量的計算能力和資源。處理這些非常大的模型的訓練也帶來了相當大的工程複雜性。PyTorch 1.11 中釋出的 PyTorch FSDP 使這項工作變得更加容易。
在本教程中,我們將演示如何使用 FSDP API 來處理簡單的 MNIST 模型,這些模型可以擴充套件到其他更大的模型,例如 HuggingFace BERT 模型、高達 1T 引數的 GPT 3 模型。DDP MNIST 示例程式碼由 Patrick Hu 提供。
FSDP 如何工作#
在 DistributedDataParallel (DDP) 訓練中,每個程序/工作節點都擁有一個模型副本,並處理一批資料,最後使用 all-reduce 操作在不同工作節點之間彙總梯度。在 DDP 中,模型權重和最佳化器狀態會複製到所有工作節點。FSDP 是一種資料並行方法,它將模型引數、最佳化器狀態和梯度分片到 DDP 程序中。
使用 FSDP 進行訓練時,與所有工作節點上的 DDP 訓練相比,GPU 記憶體佔用更小。這使得一些非常大的模型的訓練成為可能,因為它允許更大的模型或批次大小能夠容納到裝置中。這樣做是以增加通訊量為代價的。透過諸如通訊與計算重疊之類的內部最佳化來減少通訊開銷。
FSDP 工作流程#
FSDP 的工作流程大致如下:
在建構函式中
對模型引數進行分片,每個程序只保留自己的分片
在前向傳播中
執行 all_gather 以收集來自所有程序的分片,以恢復此 FSDP 單元的完整引數
執行前向計算
丟棄剛剛收集的引數分片
在反向傳播中
執行 all_gather 以收集來自所有程序的分片,以恢復此 FSDP 單元的完整引數
執行反向計算
執行 reduce_scatter 以同步梯度
丟棄引數。
理解 FSDP 分片的一種方法是將 DDP 的梯度 all-reduce 分解為 reduce-scatter 和 all-gather。具體來說,在反向傳播過程中,FSDP 會對梯度進行 reduce-scatter 操作,確保每個程序擁有梯度的一個分片。然後,在最佳化器步驟中更新相應的引數分片。最後,在隨後的前向傳播中,它執行 all-gather 操作來收集和組合更新後的引數分片。
FSDP Allreduce#
如何使用 FSDP#
在這裡,我們使用一個玩具模型來演示在 MNIST 資料集上進行訓練。這些 API 和邏輯也可以應用於訓練更大的模型。
設定
1.1 安裝 PyTorch 和 Torchvision
有關安裝資訊,請參閱 入門指南。
我們將以下程式碼片段新增到 Python 指令碼“FSDP_mnist.py”中。
1.2 匯入必要的包
注意
本教程適用於 PyTorch 1.12 及更高版本。如果您使用的是早期版本,請將所有 size_based_auto_wrap_policy 例項替換為 default_auto_wrap_policy,並將 fsdp_auto_wrap_policy 替換為 auto_wrap_policy。
# Based on: https://github.com/pytorch/examples/blob/master/mnist/main.py
import os
import argparse
import functools
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import (
CPUOffload,
BackwardPrefetch,
)
from torch.distributed.fsdp.wrap import (
size_based_auto_wrap_policy,
enable_wrap,
wrap,
)
1.3 分散式訓練設定。正如我們之前提到的,FSDP 是一種資料並行方法,需要分散式訓練環境,因此我們使用兩個輔助函式來初始化分散式訓練程序並進行清理。
def setup(rank, world_size):
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12355'
# initialize the process group
dist.init_process_group("nccl", rank=rank, world_size=world_size)
def cleanup():
dist.destroy_process_group()
2.1 定義我們的玩具模型用於手寫數字分類。
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 32, 3, 1)
self.conv2 = nn.Conv2d(32, 64, 3, 1)
self.dropout1 = nn.Dropout(0.25)
self.dropout2 = nn.Dropout(0.5)
self.fc1 = nn.Linear(9216, 128)
self.fc2 = nn.Linear(128, 10)
def forward(self, x):
x = self.conv1(x)
x = F.relu(x)
x = self.conv2(x)
x = F.relu(x)
x = F.max_pool2d(x, 2)
x = self.dropout1(x)
x = torch.flatten(x, 1)
x = self.fc1(x)
x = F.relu(x)
x = self.dropout2(x)
x = self.fc2(x)
output = F.log_softmax(x, dim=1)
return output
2.2 定義一個訓練函式
def train(args, model, rank, world_size, train_loader, optimizer, epoch, sampler=None):
model.train()
ddp_loss = torch.zeros(2).to(rank)
if sampler:
sampler.set_epoch(epoch)
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(rank), target.to(rank)
optimizer.zero_grad()
output = model(data)
loss = F.nll_loss(output, target, reduction='sum')
loss.backward()
optimizer.step()
ddp_loss[0] += loss.item()
ddp_loss[1] += len(data)
dist.all_reduce(ddp_loss, op=dist.ReduceOp.SUM)
if rank == 0:
print('Train Epoch: {} \tLoss: {:.6f}'.format(epoch, ddp_loss[0] / ddp_loss[1]))
2.3 定義一個驗證函式
def test(model, rank, world_size, test_loader):
model.eval()
correct = 0
ddp_loss = torch.zeros(3).to(rank)
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(rank), target.to(rank)
output = model(data)
ddp_loss[0] += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss
pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
ddp_loss[1] += pred.eq(target.view_as(pred)).sum().item()
ddp_loss[2] += len(data)
dist.all_reduce(ddp_loss, op=dist.ReduceOp.SUM)
if rank == 0:
test_loss = ddp_loss[0] / ddp_loss[2]
print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
test_loss, int(ddp_loss[1]), int(ddp_loss[2]),
100. * ddp_loss[1] / ddp_loss[2]))
2.4 定義一個包裝模型為 FSDP 的分散式訓練函式
注意:要儲存 FSDP 模型,我們需要在每個程序上呼叫 state_dict,然後在 Rank 0 上儲存整體狀態。
def fsdp_main(rank, world_size, args):
setup(rank, world_size)
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])
dataset1 = datasets.MNIST('../data', train=True, download=True,
transform=transform)
dataset2 = datasets.MNIST('../data', train=False,
transform=transform)
sampler1 = DistributedSampler(dataset1, rank=rank, num_replicas=world_size, shuffle=True)
sampler2 = DistributedSampler(dataset2, rank=rank, num_replicas=world_size)
train_kwargs = {'batch_size': args.batch_size, 'sampler': sampler1}
test_kwargs = {'batch_size': args.test_batch_size, 'sampler': sampler2}
cuda_kwargs = {'num_workers': 2,
'pin_memory': True,
'shuffle': False}
train_kwargs.update(cuda_kwargs)
test_kwargs.update(cuda_kwargs)
train_loader = torch.utils.data.DataLoader(dataset1,**train_kwargs)
test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)
my_auto_wrap_policy = functools.partial(
size_based_auto_wrap_policy, min_num_params=100
)
torch.cuda.set_device(rank)
init_start_event = torch.cuda.Event(enable_timing=True)
init_end_event = torch.cuda.Event(enable_timing=True)
model = Net().to(rank)
model = FSDP(model)
optimizer = optim.Adadelta(model.parameters(), lr=args.lr)
scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
init_start_event.record()
for epoch in range(1, args.epochs + 1):
train(args, model, rank, world_size, train_loader, optimizer, epoch, sampler=sampler1)
test(model, rank, world_size, test_loader)
scheduler.step()
init_end_event.record()
if rank == 0:
init_end_event.synchronize()
print(f"CUDA event elapsed time: {init_start_event.elapsed_time(init_end_event) / 1000}sec")
print(f"{model}")
if args.save_model:
# use a barrier to make sure training is done on all ranks
dist.barrier()
states = model.state_dict()
if rank == 0:
torch.save(states, "mnist_cnn.pt")
cleanup()
2.5 最後,解析引數並設定主函式
if __name__ == '__main__':
# Training settings
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
parser.add_argument('--batch-size', type=int, default=64, metavar='N',
help='input batch size for training (default: 64)')
parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
help='input batch size for testing (default: 1000)')
parser.add_argument('--epochs', type=int, default=10, metavar='N',
help='number of epochs to train (default: 14)')
parser.add_argument('--lr', type=float, default=1.0, metavar='LR',
help='learning rate (default: 1.0)')
parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
help='Learning rate step gamma (default: 0.7)')
parser.add_argument('--no-cuda', action='store_true', default=False,
help='disables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S',
help='random seed (default: 1)')
parser.add_argument('--save-model', action='store_true', default=False,
help='For Saving the current Model')
args = parser.parse_args()
torch.manual_seed(args.seed)
WORLD_SIZE = torch.cuda.device_count()
mp.spawn(fsdp_main,
args=(WORLD_SIZE, args),
nprocs=WORLD_SIZE,
join=True)
我們記錄了 CUDA 事件以測量 FSDP 模型特定部分的耗時。CUDA 事件耗時為 110.85 秒。
python FSDP_mnist.py
CUDA event elapsed time on training loop 40.67462890625sec
透過使用 FSDP 包裝模型,模型將如下所示,我們可以看到模型已被包裝在一個 FSDP 單元中。或者,我們將在接下來的內容中介紹新增 auto_wrap_policy,並討論其差異。
FullyShardedDataParallel(
(_fsdp_wrapped_module): FlattenParamsWrapper(
(_fpw_module): Net(
(conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
(conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
(dropout1): Dropout(p=0.25, inplace=False)
(dropout2): Dropout(p=0.5, inplace=False)
(fc1): Linear(in_features=9216, out_features=128, bias=True)
(fc2): Linear(in_features=128, out_features=10, bias=True)
)
)
)
以下是 FSDP MNIST 訓練在 g4dn.12.xlarge AWS EC2 例項上(使用 4 個 GPU)的峰值記憶體使用情況,透過 PyTorch Profiler 捕獲。
FSDP 峰值記憶體使用情況#
應用 FSDP 的auto_wrap_policy,否則 FSDP 會將整個模型放在一個 FSDP 單元中,這將降低計算效率和記憶體效率。其工作原理是,假設您的模型包含 100 個 Linear 層。如果您執行 FSDP(model),則只有一個 FSDP 單元包裝整個模型。在這種情況下,allgather 將收集所有 100 個線性層的完整引數,因此不會節省用於引數分片的 CUDA 記憶體。此外,對於所有 100 個線性層只有一個阻塞的 allgather 呼叫,層之間不會發生通訊與計算重疊。
為避免這種情況,您可以傳遞一個 auto_wrap_policy,它會在滿足指定條件時(例如,大小限制)自動封存當前的 FSDP 單元並啟動一個新的。這樣您將擁有多個 FSDP 單元,並且一次只有一個 FSDP 單元需要收集完整引數。例如,假設您有 5 個 FSDP 單元,每個單元包裝 20 個線性層。那麼,在前向傳播中,第一個 FSDP 單元將 allgather 前 20 個線性層的引數,進行計算,丟棄引數,然後處理接下來的 20 個線性層。因此,在任何給定時間點,每個程序只例項化 20 個線性層的引數/梯度,而不是 100 個。
要在 2.4 中實現這一點,我們定義 auto_wrap_policy 並將其傳遞給 FSDP 包裝器。在以下示例中,my_auto_wrap_policy 定義了一個層,如果該層的引數數量大於 100,則該層可以被 FSDP 包裝或分片。如果該層的引數數量小於 100,它將與其他小型層一起被 FSDP 包裝。尋找最優的 auto wrap policy 具有挑戰性,PyTorch 將在未來新增對該配置的自動調整。沒有自動調整工具,最好透過實驗性地使用不同的 auto wrap policies 來剖析您的工作流程並找到最優策略。
my_auto_wrap_policy = functools.partial(
size_based_auto_wrap_policy, min_num_params=20000
)
torch.cuda.set_device(rank)
model = Net().to(rank)
model = FSDP(model,
auto_wrap_policy=my_auto_wrap_policy)
應用 auto_wrap_policy 後,模型將如下所示
FullyShardedDataParallel(
(_fsdp_wrapped_module): FlattenParamsWrapper(
(_fpw_module): Net(
(conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
(conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
(dropout1): Dropout(p=0.25, inplace=False)
(dropout2): Dropout(p=0.5, inplace=False)
(fc1): FullyShardedDataParallel(
(_fsdp_wrapped_module): FlattenParamsWrapper(
(_fpw_module): Linear(in_features=9216, out_features=128, bias=True)
)
)
(fc2): Linear(in_features=128, out_features=10, bias=True)
)
)
python FSDP_mnist.py
CUDA event elapsed time on training loop 41.89130859375sec
以下是使用 auto_wrap policy 的 FSDP MNIST 訓練在 g4dn.12.xlarge AWS EC2 例項上(使用 4 個 GPU)的峰值記憶體使用情況,透過 PyTorch Profiler 捕獲。可以觀察到,與未使用 auto wrap policy 的 FSDP 相比,每個裝置的峰值記憶體使用量有所減小,從約 75 MB 降至 66 MB。
使用 Auto_wrap policy 的 FSDP 峰值記憶體使用情況#
CPU 解除安裝:如果模型非常大,即使使用 FSDP 也無法容納到 GPU 中,那麼 CPU 解除安裝會很有幫助。
目前只支援引數和梯度 CPU 解除安裝。可以透過傳遞 cpu_offload=CPUOffload(offload_params=True) 來啟用。
請注意,這目前隱式啟用了到 CPU 的梯度解除安裝,以便引數和梯度位於同一裝置上以便與最佳化器一起使用。此 API 可能會更改。預設值為 None,在這種情況下將不會解除安裝。
使用此功能可能會導致訓練速度顯著變慢,因為張量需要在主機和裝置之間頻繁複制,但它可以幫助提高記憶體效率並訓練更大規模的模型。
在 2.4 中,我們將其新增到 FSDP 包裝器中
model = FSDP(model,
auto_wrap_policy=my_auto_wrap_policy,
cpu_offload=CPUOffload(offload_params=True))
與 DDP 相比,如果在 2.4 中我們只是正常地將模型包裝在 DPP 中,並將更改儲存在“DDP_mnist.py”中。
model = Net().to(rank)
model = DDP(model)
python DDP_mnist.py
CUDA event elapsed time on training loop 39.77766015625sec
以下是 DDP MNIST 訓練在 g4dn.12.xlarge AWS EC2 例項上(使用 4 個 GPU)的峰值記憶體使用情況,透過 PyTorch profiler 捕獲。
使用 Auto_wrap policy 的 DDP 峰值記憶體使用情況#
考慮到我們在這裡定義的玩具示例和微小的 MNIST 模型,我們可以觀察到 DDP 和 FSDP 之間峰值記憶體使用量的差異。在 DDP 中,每個程序都持有模型的副本,因此記憶體佔用比 FSDP 高,FSDP 將模型引數、最佳化器狀態和梯度分片到 DDP 程序中。使用帶有 auto_wrap policy 的 FSDP 的峰值記憶體使用量最低,其次是 FSDP 和 DDP。
此外,在檢視時間方面,考慮到小模型和在單臺機器上執行訓練,帶或不帶 auto_wrap policy 的 FSDP 的效能幾乎與 DDP 一樣快。此示例不能代表大多數實際應用,有關 DDP 和 FSDP 之間詳細的分析和比較,請參閱此 部落格文章 。