通用 Join 上下文管理器#
創建於:2025 年 6 月 6 日 | 最後更新於:2025 年 6 月 6 日
通用 Join 上下文管理器有助於在輸入不均勻的情況下進行分散式訓練。本頁概述了相關類的 API:Join、Joinable 和 JoinHook。有關教程,請參閱 使用 Join 上下文管理器進行帶有不均勻輸入的分散式訓練。
- class torch.distributed.algorithms.Join(joinables, enable=True, throw_on_early_termination=False, **kwargs)[source]#
此類定義了通用的 Join 上下文管理器,它允許在程序加入後呼叫自定義鉤子。
這些鉤子應該模擬未加入程序的集體通訊,以防止掛起和錯誤,並確保演算法的正確性。有關鉤子定義的詳細資訊,請參閱
JoinHook。警告
上下文管理器要求每個參與的
Joinable在其自身的每個迭代集體通訊之前呼叫notify_join_context()方法,以確保正確性。警告
上下文管理器要求
JoinHook物件中的所有process_group屬性都相同。如果存在多個JoinHook物件,則使用第一個物件的device。程序組和裝置資訊用於檢查未加入的程序,並用於在啟用throw_on_early_termination時通知程序丟擲異常,這兩者都使用 all-reduce。- 引數
示例
>>> import os >>> import torch >>> import torch.distributed as dist >>> import torch.multiprocessing as mp >>> import torch.nn.parallel.DistributedDataParallel as DDP >>> import torch.distributed.optim.ZeroRedundancyOptimizer as ZeRO >>> from torch.distributed.algorithms.join import Join >>> >>> # On each spawned worker >>> def worker(rank): >>> dist.init_process_group("nccl", rank=rank, world_size=2) >>> model = DDP(torch.nn.Linear(1, 1).to(rank), device_ids=[rank]) >>> optim = ZeRO(model.parameters(), torch.optim.Adam, lr=0.01) >>> # Rank 1 gets one more input than rank 0 >>> inputs = [torch.tensor([1.]).to(rank) for _ in range(10 + rank)] >>> with Join([model, optim]): >>> for input in inputs: >>> loss = model(input).sum() >>> loss.backward() >>> optim.step() >>> # All ranks reach here without hanging/erroring
- class torch.distributed.algorithms.Joinable[source]#
這定義了一個可連線類的抽象基類。
可連線類(繼承自
Joinable)應實現join_hook(),它返回一個JoinHook例項,此外還需要實現join_device()和join_process_group(),它們分別返回裝置和程序組資訊。