評價此頁

通用 Join 上下文管理器#

創建於:2025 年 6 月 6 日 | 最後更新於:2025 年 6 月 6 日

通用 Join 上下文管理器有助於在輸入不均勻的情況下進行分散式訓練。本頁概述了相關類的 API:JoinJoinableJoinHook。有關教程,請參閱 使用 Join 上下文管理器進行帶有不均勻輸入的分散式訓練

class torch.distributed.algorithms.Join(joinables, enable=True, throw_on_early_termination=False, **kwargs)[source]#

此類定義了通用的 Join 上下文管理器,它允許在程序加入後呼叫自定義鉤子。

這些鉤子應該模擬未加入程序的集體通訊,以防止掛起和錯誤,並確保演算法的正確性。有關鉤子定義的詳細資訊,請參閱 JoinHook

警告

上下文管理器要求每個參與的 Joinable 在其自身的每個迭代集體通訊之前呼叫 notify_join_context() 方法,以確保正確性。

警告

上下文管理器要求 JoinHook 物件中的所有 process_group 屬性都相同。如果存在多個 JoinHook 物件,則使用第一個物件的 device。程序組和裝置資訊用於檢查未加入的程序,並用於在啟用 throw_on_early_termination 時通知程序丟擲異常,這兩者都使用 all-reduce。

引數
  • joinables (List[Joinable]) – 參與的 Joinable s 的列表;它們的鉤子按給定順序迭代。

  • enable (bool) – 一個啟用不均勻輸入檢測的標誌;將其設定為 False 會停用上下文管理器的功能,僅應在使用者知道輸入不會不均勻時才設定(預設值:True)。

  • throw_on_early_termination (bool) – 一個控制檢測到不均勻輸入時是否丟擲異常的標誌(預設值:False)。

示例

>>> import os
>>> import torch
>>> import torch.distributed as dist
>>> import torch.multiprocessing as mp
>>> import torch.nn.parallel.DistributedDataParallel as DDP
>>> import torch.distributed.optim.ZeroRedundancyOptimizer as ZeRO
>>> from torch.distributed.algorithms.join import Join
>>>
>>> # On each spawned worker
>>> def worker(rank):
>>>     dist.init_process_group("nccl", rank=rank, world_size=2)
>>>     model = DDP(torch.nn.Linear(1, 1).to(rank), device_ids=[rank])
>>>     optim = ZeRO(model.parameters(), torch.optim.Adam, lr=0.01)
>>>     # Rank 1 gets one more input than rank 0
>>>     inputs = [torch.tensor([1.]).to(rank) for _ in range(10 + rank)]
>>>     with Join([model, optim]):
>>>         for input in inputs:
>>>             loss = model(input).sum()
>>>             loss.backward()
>>>             optim.step()
>>>     # All ranks reach here without hanging/erroring
static notify_join_context(joinable)[source]#

通知 Join 上下文管理器呼叫程序尚未加入。

然後,如果 throw_on_early_termination=True,則檢查是否檢測到不均勻輸入(即是否有程序已加入),並在此情況下丟擲異常。

此方法應從 Joinable 物件在其每個迭代集體通訊的開始處呼叫。例如,這應該在 DistributedDataParallel 的 forward pass 開始時呼叫。

只有傳遞給上下文管理器的第一個 Joinable 物件在此方法中執行集體通訊,而對於其他物件,此方法是空的。

引數

joinable (Joinable) – 呼叫此方法的 Joinable 物件。

返回

如果 joinable 是第一個傳遞給上下文管理器的物件,則為 all-reduce 的非同步工作控制代碼,用於通知上下文管理器該程序尚未加入;否則為 None

class torch.distributed.algorithms.Joinable[source]#

這定義了一個可連線類的抽象基類。

可連線類(繼承自 Joinable)應實現 join_hook(),它返回一個 JoinHook 例項,此外還需要實現 join_device()join_process_group(),它們分別返回裝置和程序組資訊。

abstract property join_device: device#

返回用於執行 Join 上下文管理器所需的集體通訊的裝置。

abstract join_hook(**kwargs)[source]#

為給定的 Joinable 返回一個 JoinHook 例項。

引數

kwargs (dict) – 一個包含任何關鍵字引數的 dict,用於在執行時修改 JoinHook 的行為;共享相同 Join 上下文管理器的所有 Joinable 例項都會收到相同的 kwargs 值。

返回型別

JoinHook

abstract property join_process_group: Any#

返回 Join 上下文管理器本身所需的集體通訊的程序組。

class torch.distributed.algorithms.JoinHook[source]#

這定義了一個 Join Hook,它在 Join 上下文管理器中提供了兩個入口點。

入口點:一個主鉤子,在存在未加入的程序時會反覆呼叫;一個後鉤子,在所有程序都加入後呼叫一次。

要為通用 Join 上下文管理器實現 Join Hook,請定義一個繼承自 JoinHook 的類,並根據需要重寫 main_hook()post_hook()

main_hook()[source]#

在存在未加入程序時呼叫此鉤子,以模擬訓練迭代中的集體通訊。

訓練迭代,即一次前向傳播、後向傳播和最佳化器步驟。

post_hook(is_last_joiner)[source]#

在所有程序都加入後呼叫鉤子。

它會傳入一個額外的 bool 引數 is_last_joiner,指示該 rank 是否是最後加入的 rank 之一。

引數

is_last_joiner (bool) – 如果 rank 是最後加入的 rank 之一,則為 True;否則為 False