訓練指令碼#
創建於:2021年5月04日 | 最後更新於:2023年2月09日
如果你的訓練指令碼使用 torch.distributed.launch 工作,它將繼續與 torchrun 一起工作,但有以下不同:
無需手動傳遞
RANK、WORLD_SIZE、MASTER_ADDR和MASTER_PORT。可以提供
rdzv_backend和rdzv_endpoint。對於大多數使用者,這將設定為c10d(參見 rendezvous)。預設的rdzv_backend建立一個非彈性的 rendezvous,其中rdzv_endpoint包含主地址。請確保你的指令碼中包含
load_checkpoint(path)和save_checkpoint(path)的邏輯。當任意數量的工作程序失敗時,我們將使用相同的程式引數重新啟動所有工作程序,因此你將丟失直到最近一次檢查點之間的所有進度(參見 elastic launch)。use_env標誌已被移除。如果你曾透過解析--local-rank選項來解析本地 rank,則需要從環境變數LOCAL_RANK獲取本地 rank(例如,int(os.environ["LOCAL_RANK"]))。
下面是一個訓練指令碼的說明性示例,該指令碼在每個 epoch 進行檢查點儲存,因此在失敗時丟失的最大進度相當於一個完整的 epoch 訓練。
def main():
args = parse_args(sys.argv[1:])
state = load_checkpoint(args.checkpoint_path)
initialize(state)
# torch.distributed.run ensures that this will work
# by exporting all the env vars needed to initialize the process group
torch.distributed.init_process_group(backend=args.backend)
for i in range(state.epoch, state.total_num_epochs)
for batch in iter(state.dataset)
train(batch, state.model)
state.epoch += 1
save_checkpoint(state)
有關符合 torchelastic 的訓練指令碼的具體示例,請訪問我們的 示例 頁面。