使用 Join 上下文管理器進行不均勻輸入分散式訓練#
創建於:2021 年 8 月 4 日 | 最後更新:2025 年 9 月 3 日 | 最後驗證:2024 年 11 月 5 日
作者:Andrew Gu
注意
在 github 上檢視和編輯本教程。
注意
Join 作為一項原型功能在 PyTorch 1.10 中引入。此 API 可能會發生更改。
在本教程中,您將看到
Join 上下文管理器的概述。
一個如何將上下文管理器與
DistributedDataParallel結合使用的示例。一個如何將上下文管理器與
DistributedDataParallel和ZeroRedundancyOptimizer結合使用的示例。一個傳遞關鍵字引數給上下文管理器的示例。
深入瞭解 Join 上下文管理器的工作原理。
一個展示如何使玩具類與上下文管理器相容的示例。
要求#
PyTorch 1.10+
什麼是 Join?#
在 使用分散式資料並行入門 - 基本用例 中,您瞭解了使用 DistributedDataParallel 進行資料並行訓練的通用框架。這會在每次反向傳播時隱式排程 all-reduces 操作以同步不同 rank 之間的梯度。此類 集體通訊 需要程序組中的所有 rank 參與,因此如果某個 rank 的輸入較少,其他 rank 將會掛起或報錯(取決於後端)。更廣泛地說,這個問題會存在於任何執行每迭代同步集體通訊的類中。
Join 是一個上下文管理器,用於包裝每個 rank 的訓練迴圈,以支援不均勻輸入進行訓練。該上下文管理器允許儘早耗盡輸入的 rank(即,儘早加入)來“隱藏”尚未加入的 rank 所執行的集體通訊。通訊的隱藏方式由 hooks 指定。
將 Join 與 DistributedDataParallel 結合使用#
PyTorch 的 DistributedDataParallel 可與 Join 上下文管理器開箱即用。以下是一個示例用法:
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.distributed.algorithms.join import Join
from torch.nn.parallel import DistributedDataParallel as DDP
BACKEND = "nccl"
WORLD_SIZE = 2
NUM_INPUTS = 5
def worker(rank):
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '29500'
dist.init_process_group(BACKEND, rank=rank, world_size=WORLD_SIZE)
model = DDP(torch.nn.Linear(1, 1).to(rank), device_ids=[rank])
# Rank 1 gets one more input than rank 0
inputs = [torch.tensor([1]).float() for _ in range(NUM_INPUTS + rank)]
num_inputs = 0
with Join([model]):
for input in inputs:
num_inputs += 1
loss = model(input).sum()
loss.backward()
print(f"Rank {rank} has exhausted all {num_inputs} of its inputs!")
def main():
mp.spawn(worker, nprocs=WORLD_SIZE, join=True)
if __name__ == "__main__":
main()
這將產生以下輸出(來自 rank 0 和 rank 1 的 print() 語句的順序可能任意):
Rank 0 has exhausted all 5 of its inputs!
Rank 1 has exhausted all 6 of its inputs!
注意
在此通用 Join 上下文管理器引入之前,DistributedDataParallel 提供了自己的 join() 上下文管理器。在上面的示例中,使用 with Join([model]): 等同於使用 with model.join():。現有 DistributedDataParallel.join() 的一個限制是它不允許多個參與類,例如 DistributedDataParallel 和 ZeroRedundancyOptimizer 同時使用。
將 Join 與 DistributedDataParallel 和 ZeroRedundancyOptimizer 結合使用#
Join 上下文管理器不僅可以與單個類一起工作,還可以與多個類一起工作。PyTorch 的 ZeroRedundancyOptimizer 也與上下文管理器相容,因此在此,我們考察如何修改之前的示例以同時使用 DistributedDataParallel 和 ZeroRedundancyOptimizer。
from torch.distributed.optim import ZeroRedundancyOptimizer as ZeRO
from torch.optim import Adam
def worker(rank):
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '29500'
dist.init_process_group(BACKEND, rank=rank, world_size=WORLD_SIZE)
model = DDP(torch.nn.Linear(1, 1).to(rank), device_ids=[rank])
optim = ZeRO(model.parameters(), Adam, lr=0.01)
# Rank 1 gets one more input than rank 0
inputs = [torch.tensor([1]).float() for _ in range(NUM_INPUTS + rank)]
num_inputs = 0
# Pass both `model` and `optim` into `Join()`
with Join([model, optim]):
for input in inputs:
num_inputs += 1
loss = model(input).sum()
loss.backward()
optim.step()
print(f"Rank {rank} has exhausted all {num_inputs} of its inputs!")
這將產生與之前相同的輸出。值得注意的更改是額外將 ZeroRedundancyOptimizer 例項傳遞給了 Join()。
傳遞關鍵字引數#
類可以在執行時提供關鍵字引數來修改其在上下文管理器中的行為。例如,DistributedDataParallel 提供了一個 divide_by_initial_world_size 引數,該引數決定梯度是除以初始 world size 還是除以有效 world size(即非加入 rank 的數量)。此類關鍵字引數可以直接傳遞給上下文管理器。
with Join([model, optim], divide_by_initial_world_size=False):
for input in inputs:
...
警告
傳遞給上下文管理器的關鍵字引數會在所有參與類之間共享。這不應成為限制,因為我們不期望出現多個 Joinable 需要相同引數的不同設定的情況。儘管如此,這一點還是值得注意的。
Join 是如何工作的?#
現在我們已經看到了一些使用 Join 上下文管理器的初步示例,讓我們深入瞭解它是如何工作的。這將幫助您更深入地瞭解它提供的全部功能,併為您準備好使自己的自定義類相容。在這裡,我們將介紹 Join 類以及支援類 Joinable 和 JoinHook。
Joinable#
首先,與 Join 上下文管理器相容的類必須繼承自抽象基類 Joinable。特別是,一個 Joinable 必須實現:
join_hook(self, **kwargs) -> JoinHook
這會返回 Joinable 的 JoinHook 例項,該例項決定已加入的程序應如何“隱藏” Joinable 在每次迭代中執行的集體通訊。
join_device(self) -> torch.device
這會返回一個裝置,供 Join 上下文管理器用於執行集體通訊,例如 torch.device("cuda:0") 或 torch.device("cpu")。
join_process_group(self) -> ProcessGroup
這會返回一個程序組,供 Join 上下文管理器用於執行集體通訊。
特別是,join_device 和 join_process_group 是必需的屬性,以確保上下文管理器能夠排程已加入和未加入程序之間的集體通訊。一種用法是使用 all-reduce 在每次迭代中計算未加入程序的數量。另一種用法是實現 throw_on_early_termination=True 所需的機制,我們將在下文進一步解釋。
DistributedDataParallel 和 ZeroRedundancyOptimizer 已經繼承了 Joinable 並實現了上述方法,這就是為什麼我們可以在之前的示例中直接使用它們。
Joinable 類應確保呼叫 Joinable 建構函式,因為它會初始化一個 JoinConfig 例項,該例項由上下文管理器內部使用以確保正確性。這將作為欄位 _join_config 儲存在每個 Joinable 中。
JoinHook#
接下來,我們分解 JoinHook 類。一個 JoinHook 為上下文管理器提供了兩個入口點:
main_hook(self) -> None
此 hook 在存在尚未加入的 rank 時,被每個已加入的 rank 反覆呼叫。它旨在“隱藏” Joinable 在每個訓練迭代中執行的集體通訊(例如,一次前向傳播、反向傳播和最佳化器步進)。
post_hook(self, is_last_joiner: bool) -> None
所有 rank 加入後,此 hook 會被呼叫一次。它會接收一個額外的布林引數 is_last_joiner,該引數指示該 rank 是否是最後加入的 rank 之一。該引數可能對同步有用。
為了給出這些 hook 可能樣子的具體示例,提供的 ZeroRedundancyOptimizer main hook 正常執行一次最佳化器步進,因為已加入的 rank 仍然負責更新和同步其引數分片;而提供的 DistributedDataParallel post-hook 則將最終更新的模型從最後一個加入的 rank 之一廣播出去,以確保它在所有 rank 中都相同。
Join#
最後,讓我們看看這些如何融入 Join 類本身。
__init__(self, joinables: List[Joinable], enable: bool = True, throw_on_early_termination: bool = False)
正如我們在前面的示例中所見,建構函式接受參與訓練迴圈的 Joinable 物件列表。這些應該是每個迭代中執行集體通訊的類。
enable 是一個布林值,如果知道不會出現不均勻輸入,則可以將其設定為 False。在這種情況下,上下文管理器將變得空泛,類似於 contextlib.nullcontext()。這也會停用參與的 Joinable 中的 join 相關計算。
throw_on_early_termination 是一個布林值,如果設定為 True,則會在檢測到不均勻輸入時讓每個 rank 引發一個異常。這對於不符合上下文管理器要求的情況很有用,最常見的情況是當 DistributedDataParallel 與具有 SyncBatchNorm 層的模型一起使用時,存在來自不同類的集體通訊可能任意交織。在這種情況下,應將此引數設定為 True,以便應用程式邏輯可以捕獲異常並確定如何繼續。
核心邏輯發生在
__exit__()方法中,該方法在存在未加入的 rank 時迴圈,呼叫每個Joinable的 main hook,然後一旦所有 rank 都加入,就會呼叫它們的 post hook。main hook 和 post hook 都按照傳遞Joinable的順序進行迭代。上下文管理器要求未加入的程序傳送心跳訊號。因此,每個
Joinable類在其每次迭代的集體通訊之前都應呼叫Join.notify_join_context()。上下文管理器將確保只有傳遞的第一個Joinable實際傳送心跳。
警告
如上所述關於 throw_on_early_termination,Join 上下文管理器與某些類的組合不相容。Joinable 的 JoinHook 必須是可序列化的,因為每個 hook 在繼續執行下一個之前都會完全執行。換句話說,兩個 hook 不能重疊。此外,目前 main hook 和 post hook 都以相同的確定性順序進行迭代。如果這似乎是一個主要限制,我們可能會修改 API 以允許可自定義的排序。
使玩具類與 Join 相容#
由於上一節介紹了一些概念,讓我們透過一個玩具示例在實踐中看看它們。在這裡,我們將實現一個類,該類計算在自己的 rank 加入之前,所有 rank 中看到的輸入數量。這應該提供一個基本的想法,說明如何使您自己的類與 Join 上下文管理器相容。
具體來說,以下程式碼讓每個 rank 打印出 (1) 在其加入之前所有 rank 中看到的輸入數量,以及 (2) 所有 rank 中的總輸入數量。
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.distributed.algorithms.join import Join, Joinable, JoinHook
BACKEND = "nccl"
WORLD_SIZE = 2
NUM_INPUTS = 5
class CounterJoinHook(JoinHook):
r"""
Join hook for :class:`Counter`.
Arguments:
counter (Counter): the :class:`Counter` object using this hook.
sync_max_count (bool): whether to sync the max count once all ranks
join.
"""
def __init__(
self,
counter,
sync_max_count
):
self.counter = counter
self.sync_max_count = sync_max_count
def main_hook(self):
r"""
Shadows the counter's all-reduce by all-reducing a dim-1 zero tensor.
"""
t = torch.zeros(1, device=self.counter.device)
dist.all_reduce(t)
def post_hook(self, is_last_joiner: bool):
r"""
Synchronizes the max count across all :class:`Counter` s if
``sync_max_count=True``.
"""
if not self.sync_max_count:
return
rank = dist.get_rank(self.counter.process_group)
common_rank = self.counter.find_common_rank(rank, is_last_joiner)
if rank == common_rank:
self.counter.max_count = self.counter.count.detach().clone()
dist.broadcast(self.counter.max_count, src=common_rank)
class Counter(Joinable):
r"""
Example :class:`Joinable` that counts the number of training iterations
that it participates in.
"""
def __init__(self, device, process_group):
super(Counter, self).__init__()
self.device = device
self.process_group = process_group
self.count = torch.tensor([0], device=device).float()
self.max_count = torch.tensor([0], device=device).float()
def __call__(self):
r"""
Counts the number of inputs processed on this iteration by all ranks
by all-reducing a dim-1 one tensor; increments its own internal count.
"""
Join.notify_join_context(self)
t = torch.ones(1, device=self.device).float()
dist.all_reduce(t)
self.count += t
def join_hook(self, **kwargs) -> JoinHook:
r"""
Return a join hook that shadows the all-reduce in :meth:`__call__`.
This join hook supports the following keyword arguments:
sync_max_count (bool, optional): whether to synchronize the maximum
count across all ranks once all ranks join; default is ``False``.
"""
sync_max_count = kwargs.get("sync_max_count", False)
return CounterJoinHook(self, sync_max_count)
@property
def join_device(self) -> torch.device:
return self.device
@property
def join_process_group(self):
return self.process_group
def find_common_rank(self, rank, to_consider):
r"""
Returns the max rank of the ones to consider over the process group.
"""
common_rank = torch.tensor([rank if to_consider else -1], device=self.device)
dist.all_reduce(common_rank, op=dist.ReduceOp.MAX, group=self.process_group)
common_rank = common_rank.item()
return common_rank
def worker(rank):
assert torch.cuda.device_count() >= WORLD_SIZE
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '29500'
dist.init_process_group(BACKEND, rank=rank, world_size=WORLD_SIZE)
counter = Counter(torch.device(f"cuda:{rank}"), dist.group.WORLD)
inputs = [torch.tensor([1]).float() for _ in range(NUM_INPUTS + rank)]
with Join([counter], sync_max_count=True):
for _ in inputs:
counter()
print(f"{int(counter.count.item())} inputs processed before rank {rank} joined!")
print(f"{int(counter.max_count.item())} inputs processed across all ranks!")
def main():
mp.spawn(worker, nprocs=WORLD_SIZE, join=True)
if __name__ == "__main__":
main()
由於 rank 0 看到 5 個輸入,rank 1 看到 6 個輸入,這將產生輸出:
10 inputs processed before rank 0 joined!
11 inputs processed across all ranks!
11 inputs processed before rank 1 joined!
11 inputs processed across all ranks!
一些需要強調的關鍵點
一個
Counter例項每迭代執行一次 all-reduce,因此 main hook 也執行一次 all-reduce 來“隱藏”它。Counter類在其__call__()方法的開頭呼叫Join.notify_join_context(),因為這是在其每次迭代的集體通訊(即其 all-reduce)之前的位置。is_last_joiner引數用於確定 post-hooks 中的廣播源。我們將
sync_max_count關鍵字引數傳遞給上下文管理器,然後該引數會被轉發到Counter的 join hook。