評價此頁
torchrun">

簡介 || 什麼是 DDP || 單節點多 GPU 訓練 || 容錯 || 多節點訓練 || minGPT 訓練

使用 torchrun 進行容錯分散式訓練#

創建於:2022 年 9 月 27 日 | 最後更新:2024 年 11 月 12 日 | 最後驗證:2024 年 11 月 5 日

作者: Suraj Subramanian

您將學到什麼
  • 使用 torchrun 啟動多 GPU 訓練任務

  • 儲存和載入訓練任務的快照

  • 為平穩重啟構建訓練指令碼

GitHub 上檢視本教程使用的程式碼

先決條件
  • DDP 的高階概述

  • 熟悉DDP 程式碼

  • 具有多個 GPU 的機器(本教程使用 AWS p3.8xlarge 例項)

  • 已安裝支援 CUDA 的 PyTorch

請跟隨下面的影片或在 youtube 上觀看。

在分散式訓練中,單個程序的失敗可能會中斷整個訓練任務。由於這裡的故障敏感性可能更高,因此使您的訓練指令碼具有魯棒性尤為重要。您可能還希望您的訓練任務是彈性的,例如,計算資源可以在任務過程中動態地加入和離開。

PyTorch 提供了一個名為 torchrun 的實用程式,它提供了容錯和彈性訓練功能。當發生故障時,torchrun 會記錄錯誤並嘗試從訓練任務的最後一個儲存的“快照”自動重啟所有程序。

快照儲存的內容不僅僅是模型狀態;它還可以包含有關已執行的 epoch 數量、最佳化器狀態或訓練任務連續性所需的任何其他有狀態屬性的詳細資訊。

為什麼使用 torchrun#

torchrun 處理分散式訓練的細節,因此您無需這樣做。例如,

  • 您無需設定環境變數或顯式傳遞 rankworld_sizetorchrun 會分配這些以及其他幾個 環境變數

  • 無需在指令碼中呼叫 mp.spawn;您只需要一個通用的 main() 入口點,然後使用 torchrun 啟動指令碼。這樣,同一個指令碼就可以在非分散式、單節點和多節點設定中執行。

  • 從最後一個儲存的訓練快照平穩地重啟訓練。

平穩重啟#

為了實現平穩重啟,您應該像這樣構建您的訓練指令碼

def main():
  load_snapshot(snapshot_path)
  initialize()
  train()

def train():
  for batch in iter(dataset):
    train_step(batch)

    if should_checkpoint:
      save_snapshot(snapshot_path)

如果發生故障,torchrun 將終止所有程序並重新啟動它們。每個程序入口點首先載入並初始化最後一個儲存的快照,然後從那裡繼續訓練。因此,在任何故障發生時,您只會丟失自最後一個儲存的快照以來的訓練進度。

在彈性訓練中,每當發生任何成員資格更改(新增或刪除節點)時,torchrun 將終止並在可用裝置上生成程序。擁有此結構可確保您的訓練任務能夠繼續進行,而無需手動干預。

multigpu.pymultigpu_torchrun.py 的差異

程序組初始化#

- def ddp_setup(rank, world_size):
+ def ddp_setup():
-     """
-     Args:
-         rank: Unique identifier of each process
-         world_size: Total number of processes
-     """
-     os.environ["MASTER_ADDR"] = "localhost"
-     os.environ["MASTER_PORT"] = "12355"
-     init_process_group(backend="nccl", rank=rank, world_size=world_size)
+     init_process_group(backend="nccl")
     torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

使用 torchrun 提供的環境變數#

- self.gpu_id = gpu_id
+ self.gpu_id = int(os.environ["LOCAL_RANK"])

儲存和載入快照#

定期將所有相關資訊儲存在快照中,使我們的訓練任務能夠在中斷後無縫恢復。

+ def _save_snapshot(self, epoch):
+     snapshot = {}
+     snapshot["MODEL_STATE"] = self.model.module.state_dict()
+     snapshot["EPOCHS_RUN"] = epoch
+     torch.save(snapshot, "snapshot.pt")
+     print(f"Epoch {epoch} | Training snapshot saved at snapshot.pt")

+ def _load_snapshot(self, snapshot_path):
+     snapshot = torch.load(snapshot_path)
+     self.model.load_state_dict(snapshot["MODEL_STATE"])
+     self.epochs_run = snapshot["EPOCHS_RUN"]
+     print(f"Resuming training from snapshot at Epoch {self.epochs_run}")

在 Trainer 建構函式中載入快照#

當恢復中斷的訓練任務時,您的指令碼將首先嚐試載入快照以繼續訓練。

class Trainer:
   def __init__(self, snapshot_path, ...):
   ...
+  if os.path.exists(snapshot_path):
+     self._load_snapshot(snapshot_path)
   ...

恢復訓練#

訓練可以從最後一個執行的 epoch 繼續,而無需從頭開始。

def train(self, max_epochs: int):
-  for epoch in range(max_epochs):
+  for epoch in range(self.epochs_run, max_epochs):
      self._run_epoch(epoch)

執行指令碼#

就像執行非多程序指令碼一樣呼叫您的入口點函式;torchrun 會自動生成程序。

if __name__ == "__main__":
   import sys
   total_epochs = int(sys.argv[1])
   save_every = int(sys.argv[2])
-  world_size = torch.cuda.device_count()
-  mp.spawn(main, args=(world_size, total_epochs, save_every,), nprocs=world_size)
+  main(save_every, total_epochs)
- python multigpu.py 50 10
+ torchrun --standalone --nproc_per_node=4 multigpu_torchrun.py 50 10

進一步閱讀#