評價此頁

使用 ZeroRedundancyOptimizer 分片最佳化器狀態#

創建於:2021 年 2 月 26 日 | 最後更新:2021 年 10 月 20 日 | 最後驗證:未驗證

在本實踐中,您將學習

要求#

什麼是 ZeroRedundancyOptimizer#

ZeroRedundancyOptimizer 的思想源自 DeepSpeed/ZeRO 專案Marian,它們將最佳化器狀態分片到分散式資料並行程序中,以減少每個程序的記憶體佔用。在 開始使用分散式資料並行 教程中,我們展示瞭如何使用 DistributedDataParallel (DDP) 來訓練模型。在該教程中,每個程序都保留一個獨立的最佳化器副本。由於 DDP 在反向傳播時已經同步了梯度,因此所有最佳化器副本在每次迭代中都會操作相同的引數和梯度值,這就是 DDP 保持模型副本處於相同狀態的方式。通常,最佳化器也會維護本地狀態。例如,Adam 最佳化器使用每個引數的 exp_avgexp_avg_sq 狀態。因此,Adam 最佳化器的記憶體消耗至少是模型大小的兩倍。基於此觀察,我們可以透過將最佳化器狀態分片到 DDP 程序中來減少最佳化器記憶體佔用。更具體地說,每個 DDP 程序中的最佳化器例項不再為所有引數建立每個引數的狀態,而是僅為所有模型引數的一個分片保留最佳化器狀態。最佳化器 step() 函式僅更新其分片中的引數,然後將其更新後的引數廣播到所有其他對等 DDP 程序,以確保所有模型副本仍然處於相同狀態。

如何使用 ZeroRedundancyOptimizer#

下面的程式碼演示瞭如何使用 ZeroRedundancyOptimizer。大部分程式碼與 分散式資料並行注意事項 中介紹的簡單 DDP 示例類似。主要區別在於 example 函式中的 if-else 子句,該子句包裝了最佳化器構造,並在 ZeroRedundancyOptimizerAdam 最佳化器之間切換。

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from torch.distributed.optim import ZeroRedundancyOptimizer
from torch.nn.parallel import DistributedDataParallel as DDP

def print_peak_memory(prefix, device):
    if device == 0:
        print(f"{prefix}: {torch.cuda.max_memory_allocated(device) // 1e6}MB ")

def example(rank, world_size, use_zero):
    torch.manual_seed(0)
    torch.cuda.manual_seed(0)
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    # create default process group
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # create local model
    model = nn.Sequential(*[nn.Linear(2000, 2000).to(rank) for _ in range(20)])
    print_peak_memory("Max memory allocated after creating local model", rank)

    # construct DDP model
    ddp_model = DDP(model, device_ids=[rank])
    print_peak_memory("Max memory allocated after creating DDP", rank)

    # define loss function and optimizer
    loss_fn = nn.MSELoss()
    if use_zero:
        optimizer = ZeroRedundancyOptimizer(
            ddp_model.parameters(),
            optimizer_class=torch.optim.Adam,
            lr=0.01
        )
    else:
        optimizer = torch.optim.Adam(ddp_model.parameters(), lr=0.01)

    # forward pass
    outputs = ddp_model(torch.randn(20, 2000).to(rank))
    labels = torch.randn(20, 2000).to(rank)
    # backward pass
    loss_fn(outputs, labels).backward()

    # update parameters
    print_peak_memory("Max memory allocated before optimizer step()", rank)
    optimizer.step()
    print_peak_memory("Max memory allocated after optimizer step()", rank)

    print(f"params sum is: {sum(model.parameters()).sum()}")



def main():
    world_size = 2
    print("=== Using ZeroRedundancyOptimizer ===")
    mp.spawn(example,
        args=(world_size, True),
        nprocs=world_size,
        join=True)

    print("=== Not Using ZeroRedundancyOptimizer ===")
    mp.spawn(example,
        args=(world_size, False),
        nprocs=world_size,
        join=True)

if __name__=="__main__":
    main()

輸出如下所示。當使用 Adam 啟用 ZeroRedundancyOptimizer 時,最佳化器 step() 的峰值記憶體消耗是標準 Adam 記憶體消耗的一半。這符合我們的預期,因為我們將 Adam 最佳化器狀態分片到了兩個程序中。輸出還表明,使用 ZeroRedundancyOptimizer 時,模型引數在一次迭代後仍然具有相同的值(與未使用 ZeroRedundancyOptimizer 相比,引數總和是相同的)。

=== Using ZeroRedundancyOptimizer ===
Max memory allocated after creating local model: 335.0MB
Max memory allocated after creating DDP: 656.0MB
Max memory allocated before optimizer step(): 992.0MB
Max memory allocated after optimizer step(): 1361.0MB
params sum is: -3453.6123046875
params sum is: -3453.6123046875
=== Not Using ZeroRedundancyOptimizer ===
Max memory allocated after creating local model: 335.0MB
Max memory allocated after creating DDP: 656.0MB
Max memory allocated before optimizer step(): 992.0MB
Max memory allocated after optimizer step(): 1697.0MB
params sum is: -3453.6123046875
params sum is: -3453.6123046875