評價此頁

分散式 Checkpoint (DCP) 入門#

建立日期: 2023 年 10 月 02 日 | 最後更新: 2025 年 07 月 10 日 | 最後驗證: 2024 年 11 月 05 日

作者: Iris Zhang, Rodrigo Kumpera, Chien-Chin Huang, Lucas Pasqualin

注意

editgithub 上檢視和編輯此教程。

先決條件

在分散式訓練過程中儲存 AI 模型檢查點可能具有挑戰性,因為引數和梯度分佈在不同的訓練器之間,並且可用訓練器的數量可能會在您恢復訓練時發生變化。PyTorch分散式 Checkpoint (DCP) 可以幫助簡化此過程。

在本教程中,我們將展示如何使用 DCP API 和一個簡單的 FSDP 包裝模型。

DCP 如何工作#

torch.distributed.checkpoint() 支援從多個程序並行儲存和載入模型。您可以使用此模組在任意數量的程序上並行儲存,然後在載入時跨不同的叢集拓撲進行重新分片。

此外,透過使用 torch.distributed.checkpoint.state_dict() 中的模組,DCP 支援在分散式環境中優雅地處理 state_dict 的生成和載入。這包括管理模型和最佳化器之間的完全限定名稱 (FQN) 對映,以及為 PyTorch 提供的並行性設定預設引數。

DCP 在幾個重要方面與 torch.save()torch.load() 不同:

  • 它為每個檢查點生成多個檔案,每個程序至少有一個檔案。

  • 它在原地操作,意味著模型應首先分配其資料,DCP 使用該儲存空間。

  • DCP 為有狀態物件 (在 torch.distributed.checkpoint.stateful 中正式定義) 提供特殊處理,如果定義了 state_dictload_state_dict 方法,則會自動呼叫它們。

注意

本教程中的程式碼在一個 8-GPU 伺服器上執行,但可以輕鬆推廣到其他環境。

如何使用 DCP#

這裡我們使用一個 toy 模型包裝 FSDP 進行演示。同樣,這些 API 和邏輯也可應用於更大的模型以進行檢查點儲存。

儲存#

現在,讓我們建立一個 toy 模組,用 FSDP 包裝它,用一些 dummy 輸入資料餵給它,然後儲存它。

import os

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
import torch.multiprocessing as mp
import torch.nn as nn

from torch.distributed.fsdp import fully_shard
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
from torch.distributed.checkpoint.stateful import Stateful

CHECKPOINT_DIR = "checkpoint"


class AppState(Stateful):
    """This is a useful wrapper for checkpointing the Application State. Since this object is compliant
    with the Stateful protocol, DCP will automatically call state_dict/load_stat_dict as needed in the
    dcp.save/load APIs.

    Note: We take advantage of this wrapper to hande calling distributed state dict methods on the model
    and optimizer.
    """

    def __init__(self, model, optimizer=None):
        self.model = model
        self.optimizer = optimizer

    def state_dict(self):
        # this line automatically manages FSDP FQN's, as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT
        model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
        return {
            "model": model_state_dict,
            "optim": optimizer_state_dict
        }

    def load_state_dict(self, state_dict):
        # sets our state dicts on the model and optimizer, now that we've loaded
        set_state_dict(
            self.model,
            self.optimizer,
            model_state_dict=state_dict["model"],
            optim_state_dict=state_dict["optim"]
        )

class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(16, 16)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(16, 8)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


def setup(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355 "

    # initialize the process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)


def cleanup():
    dist.destroy_process_group()


def run_fsdp_checkpoint_save_example(rank, world_size):
    print(f"Running basic FSDP checkpoint saving example on rank {rank}.")
    setup(rank, world_size)

    # create a model and move it to GPU with id rank
    model = ToyModel().to(rank)
    model = fully_shard(model)

    loss_fn = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

    optimizer.zero_grad()
    model(torch.rand(8, 16, device="cuda")).sum().backward()
    optimizer.step()

    state_dict = { "app": AppState(model, optimizer) }
    dcp.save(state_dict, checkpoint_id=CHECKPOINT_DIR)

    cleanup()


if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    print(f"Running fsdp checkpoint example on {world_size} devices.")
    mp.spawn(
        run_fsdp_checkpoint_save_example,
        args=(world_size,),
        nprocs=world_size,
        join=True,
    )

請繼續檢查 checkpoint 目錄。您應該看到與下面顯示的相同數量的檢查點檔案。例如,如果您有 8 個裝置,您應該看到 8 個檔案。

Distributed Checkpoint

載入#

儲存後,讓我們建立相同的 FSDP 包裝模型,並將儲存的 state dict 從儲存載入到模型中。您可以以相同的 world size 或不同的 world size 載入。

請注意,在載入之前,您需要呼叫 model.state_dict() 並將其傳遞給 DCP 的 load_state_dict() API。這與 torch.load() fundamentally 不同,因為 torch.load() 僅需要載入前的檢查點路徑。我們需要在載入前使用 state_dict 的原因如下:

  • DCP 使用從 model state_dict 預分配的儲存來從檢查點目錄載入。在載入過程中,傳入的 state_dict 將被原地更新。

  • DCP 在載入前需要模型的 sharding 資訊,以支援重新分片。

import os

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.stateful import Stateful
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
import torch.multiprocessing as mp
import torch.nn as nn

from torch.distributed.fsdp import fully_shard

CHECKPOINT_DIR = "checkpoint"


class AppState(Stateful):
    """This is a useful wrapper for checkpointing the Application State. Since this object is compliant
    with the Stateful protocol, DCP will automatically call state_dict/load_stat_dict as needed in the
    dcp.save/load APIs.

    Note: We take advantage of this wrapper to hande calling distributed state dict methods on the model
    and optimizer.
    """

    def __init__(self, model, optimizer=None):
        self.model = model
        self.optimizer = optimizer

    def state_dict(self):
        # this line automatically manages FSDP FQN's, as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT
        model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
        return {
            "model": model_state_dict,
            "optim": optimizer_state_dict
        }

    def load_state_dict(self, state_dict):
        # sets our state dicts on the model and optimizer, now that we've loaded
        set_state_dict(
            self.model,
            self.optimizer,
            model_state_dict=state_dict["model"],
            optim_state_dict=state_dict["optim"]
        )

class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(16, 16)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(16, 8)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


def setup(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355 "

    # initialize the process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)


def cleanup():
    dist.destroy_process_group()


def run_fsdp_checkpoint_load_example(rank, world_size):
    print(f"Running basic FSDP checkpoint loading example on rank {rank}.")
    setup(rank, world_size)

    # create a model and move it to GPU with id rank
    model = ToyModel().to(rank)
    model = fully_shard(model)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

    state_dict = { "app": AppState(model, optimizer)}
    dcp.load(
        state_dict=state_dict,
        checkpoint_id=CHECKPOINT_DIR,
    )

    cleanup()


if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    print(f"Running fsdp checkpoint example on {world_size} devices.")
    mp.spawn(
        run_fsdp_checkpoint_load_example,
        args=(world_size,),
        nprocs=world_size,
        join=True,
    )

如果您想將儲存的檢查點載入到一個非分散式設定的非 FSDP 包裝模型中,例如用於推理,您也可以使用 DCP 來完成。預設情況下,DCP 以單程式多資料 (SPMD) 風格儲存和載入分散式 state_dict。但是,如果未初始化程序組,DCP 會推斷意圖是以“非分散式”風格進行儲存或載入,即完全在當前程序中。

注意

多程式多資料 (MPMD) 的分散式檢查點支援仍在開發中。

import os

import torch
import torch.distributed.checkpoint as dcp
import torch.nn as nn


CHECKPOINT_DIR = "checkpoint"


class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(16, 16)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(16, 8)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


def run_checkpoint_load_example():
    # create the non FSDP-wrapped toy model
    model = ToyModel()
    state_dict = {
        "model": model.state_dict(),
    }

    # since no progress group is initialized, DCP will disable any collectives.
    dcp.load(
        state_dict=state_dict,
        checkpoint_id=CHECKPOINT_DIR,
    )
    model.load_state_dict(state_dict["model"])

if __name__ == "__main__":
    print(f"Running basic DCP checkpoint loading example.")
    run_checkpoint_load_example()

格式#

還有一個尚未提到的缺點是,DCP 儲存的檢查點格式與使用 torch.save 生成的格式本質上是不同的。當用戶希望與習慣 torch.save 格式的使用者共享模型,或者只想為應用程式新增格式靈活性時,這可能會成為一個問題。在這種情況下,我們在 torch.distributed.checkpoint.format_utils 中提供了 format_utils 模組。

為方便使用者,提供了一個命令列實用程式,其格式如下:

python -m torch.distributed.checkpoint.format_utils <mode> <checkpoint location> <location to write formats to>

在上面的命令中,modetorch_to_dcpdcp_to_torch 之一。

或者,也提供了方法供希望直接轉換檢查點的使用者使用。

import os

import torch
import torch.distributed.checkpoint as DCP
from torch.distributed.checkpoint.format_utils import dcp_to_torch_save, torch_save_to_dcp

CHECKPOINT_DIR = "checkpoint"
TORCH_SAVE_CHECKPOINT_DIR = "torch_save_checkpoint.pth"

# convert dcp model to torch.save (assumes checkpoint was generated as above)
dcp_to_torch_save(CHECKPOINT_DIR, TORCH_SAVE_CHECKPOINT_DIR)

# converts the torch.save model back to DCP
torch_save_to_dcp(TORCH_SAVE_CHECKPOINT_DIR, f"{CHECKPOINT_DIR}_new")

結論#

總而言之,我們學習瞭如何使用 DCP 的 save()load() API,以及它們與 torch.save()torch.load() 的區別。此外,我們還學習瞭如何使用 get_state_dict()set_state_dict() 來在 state dict 生成和載入過程中自動管理特定於並行性的 FQN 和預設值。

有關更多資訊,請參閱以下內容: