評價此頁

分散式最佳化器#

建立時間: 2021年03月01日 | 最後更新時間: 2025年06月16日

警告

目前使用 CUDA 張量時不支援分散式最佳化器

torch.distributed.optim 暴露了 DistributedOptimizer,它接受一個遠端引數列表(RRef),並在引數所在的 worker 上本地執行最佳化器。分散式最佳化器可以使用任何本地最佳化器 基類 在每個 worker 上應用梯度。

class torch.distributed.optim.DistributedOptimizer(optimizer_class, params_rref, *args, **kwargs)[source]#

DistributedOptimizer 接受分佈在不同 worker 上的引數的遠端引用,並在每個引數上本地應用給定的最佳化器。

此類使用 get_gradients() 來檢索特定引數的梯度。

step() 的併發呼叫,無論來自同一客戶端還是不同客戶端,都會在每個 worker 上進行序列化處理——因為每個 worker 的最佳化器一次只能處理一組梯度。然而,不能保證整個前向-後向-最佳化器序列一次只為一個客戶端執行。這意味著正在應用的梯度可能不對應於給定 worker 上執行的最新前向傳遞。此外,跨 worker 也沒有保證的順序。

DistributedOptimizer 預設情況下使用 TorchScript 建立本地最佳化器,以便在多執行緒訓練(例如分散式模型並行)的情況下,最佳化器更新不會被 Python 全域性直譯器鎖 (GIL) 阻塞。此功能目前對大多數最佳化器都已啟用。您也可以按照 PyTorch 教程中的 示例 為自己的自定義最佳化器啟用 TorchScript 支援。

引數
  • optimizer_class (optim.Optimizer) – 在每個 worker 上例項化的最佳化器類。

  • params_rref (list[RRef]) – 到本地或遠端引數的 RRefs 列表,用於最佳化。

  • args – 傳遞給每個 worker 上最佳化器建構函式的引數。

  • kwargs – 傳遞給每個 worker 上最佳化器建構函式的引數。

示例:
>>> import torch.distributed.autograd as dist_autograd
>>> import torch.distributed.rpc as rpc
>>> from torch import optim
>>> from torch.distributed.optim import DistributedOptimizer
>>>
>>> with dist_autograd.context() as context_id:
>>>   # Forward pass.
>>>   rref1 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 3))
>>>   rref2 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 1))
>>>   loss = rref1.to_here() + rref2.to_here()
>>>
>>>   # Backward pass.
>>>   dist_autograd.backward(context_id, [loss.sum()])
>>>
>>>   # Optimizer.
>>>   dist_optim = DistributedOptimizer(
>>>      optim.SGD,
>>>      [rref1, rref2],
>>>      lr=0.05,
>>>   )
>>>   dist_optim.step(context_id)
step(context_id)[source]#

執行一次最佳化步驟。

這將呼叫包含要最佳化的引數的每個 worker 上的 torch.optim.Optimizer.step(),並一直阻塞直到所有 worker 返回。提供的 context_id 將用於檢索包含應應用於引數的梯度的相應 context

引數

context_id – 我們應該為其執行最佳化器步驟的 autograd 上下文 id。

class torch.distributed.optim.PostLocalSGDOptimizer(optim, averager)[source]#

包裝任意 torch.optim.Optimizer 並執行 post-local SGD。此最佳化器在每一步執行本地最佳化器。預熱階段之後,它在應用本地最佳化器後定期平均引數。

引數
  • optim (Optimizer) – 本地最佳化器。

  • averager (ModelAverager) – 用於執行 post-localSGD 演算法的模型平均器例項。

示例

>>> import torch
>>> import torch.distributed as dist
>>> import torch.distributed.algorithms.model_averaging.averagers as averagers
>>> import torch.nn as nn
>>> from torch.distributed.optim import PostLocalSGDOptimizer
>>> from torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook import (
>>>   PostLocalSGDState,
>>>   post_localSGD_hook,
>>> )
>>>
>>> model = nn.parallel.DistributedDataParallel(
>>>    module, device_ids=[rank], output_device=rank
>>> )
>>>
>>> # Register a post-localSGD communication hook.
>>> state = PostLocalSGDState(process_group=None, subgroup=None, start_localSGD_iter=100)
>>> model.register_comm_hook(state, post_localSGD_hook)
>>>
>>> # Create a post-localSGD optimizer that wraps a local optimizer.
>>> # Note that ``warmup_steps`` used in ``PostLocalSGDOptimizer`` must be the same as
>>> # ``start_localSGD_iter`` used in ``PostLocalSGDState``.
>>> local_optim = torch.optim.SGD(params=model.parameters(), lr=0.01)
>>> opt = PostLocalSGDOptimizer(
>>>     optim=local_optim,
>>>     averager=averagers.PeriodicModelAverager(period=4, warmup_steps=100)
>>> )
>>>
>>> # In the first 100 steps, DDP runs global gradient averaging at every step.
>>> # After 100 steps, DDP runs gradient averaging within each subgroup (intra-node by default),
>>> # and post-localSGD optimizer runs global model averaging every 4 steps after applying the local optimizer.
>>> for step in range(0, 200):
>>>    opt.zero_grad()
>>>    loss = loss_fn(output, labels)
>>>    loss.backward()
>>>    opt.step()
load_state_dict(state_dict)[source]#

這與 torch.optim.Optimizerload_state_dict() 相同,但還恢復了模型平均器的步數到提供的 state_dict 中儲存的值。

如果在 state_dict 中沒有 "step" 條目,它將發出警告並將模型平均器的步數初始化為 0。

state_dict()[source]#

這與 torch.optim.Optimizerstate_dict() 相同,但增加了一個額外的條目來記錄模型平均器的步數到檢查點,以確保重新載入時不會再次導致不必要的預熱。

step()[source]#

執行一次最佳化步驟(引數更新)。

class torch.distributed.optim.ZeroRedundancyOptimizer(params, optimizer_class, process_group=None, parameters_as_bucket_view=False, overlap_with_ddp=False, **defaults)[source]#

包裝任意 optim.Optimizer 並將其狀態分片到組內的 ranks 中。

共享方式描述如 ZeRO 所述。

每個 rank 的本地最佳化器例項僅負責更新大約 1 / world_size 的引數,因此只需要保留 1 / world_size 的最佳化器狀態。在本地更新引數後,每個 rank 將其引數廣播給所有其他對等節點,以使所有模型副本保持相同的狀態。ZeroRedundancyOptimizer 可以與 torch.nn.parallel.DistributedDataParallel 結合使用,以減少每個 rank 的峰值記憶體消耗。

ZeroRedundancyOptimizer 使用排序貪心演算法在每個 rank 上打包一定數量的引數。每個引數屬於一個 rank,並且不跨 rank 分割。分割槽是任意的,可能不匹配引數的註冊或使用順序。

引數

params (Iterable) – 一個 Iterable,包含 torch.Tensordict,它們將跨 rank 進行分片。

關鍵字引數
  • optimizer_class (torch.nn.Optimizer) – 本地最佳化器的類。

  • process_group (ProcessGroup, 可選) – torch.distributed ProcessGroup(預設值:由 torch.distributed.init_process_group()初始化的 dist.group.WORLD)。

  • parameters_as_bucket_view (bool, 可選) – 如果為 True,則引數會被打包到 bucket 中以加快通訊速度,並且 param.data 欄位指向不同偏移量的 bucket 檢視;如果為 False,則每個單獨的引數都會單獨通訊,並且每個 params.data 保持不變(預設值:False)。

  • overlap_with_ddp (bool, 可選) – 如果為 True,則 step()DistributedDataParallel 的梯度同步重疊;這需要 (1) 要麼為 optimizer_class 引數提供一個函式式最佳化器,要麼提供一個具有函式式等價物的最佳化器,並且 (2) 註冊一個從 ddp_zero_hook.py 中的函式之一構造的 DDP 通訊鉤子;引數將被打包到與 DistributedDataParallel 匹配的 bucket 中,這意味著 parameters_as_bucket_view 引數將被忽略。如果為 False,則 step() 在後向傳播之後(正常情況下)獨立執行。(預設值:False

  • **defaults – 任何尾隨引數,它們會被原樣轉發給本地最佳化器。

示例

>>> import torch.nn as nn
>>> from torch.distributed.optim import ZeroRedundancyOptimizer
>>> from torch.nn.parallel import DistributedDataParallel as DDP
>>> model = nn.Sequential(*[nn.Linear(2000, 2000).to(rank) for _ in range(20)])
>>> ddp = DDP(model, device_ids=[rank])
>>> opt = ZeroRedundancyOptimizer(
>>>     ddp.parameters(),
>>>     optimizer_class=torch.optim.Adam,
>>>     lr=0.01
>>> )
>>> ddp(inputs).sum().backward()
>>> opt.step()

警告

目前,ZeroRedundancyOptimizer 要求所有傳入的引數都是相同的密集型別。

警告

如果您傳遞 overlap_with_ddp=True,請注意以下幾點:鑑於當前實現中的 DistributedDataParallelZeroRedundancyOptimizer 重疊的方式,前兩到三次訓練迭代在最佳化器步驟中不執行引數更新,具體取決於 static_graph=Falsestatic_graph=True,分別是。這是因為需要關於 DistributedDataParallel 使用的梯度分桶策略的資訊,如果 static_graph=False,則該資訊在第二次前向傳遞時確定;如果 static_graph=True,則在第三次前向傳遞時確定。為了彌補這一點,一種選擇是預置虛擬輸入。

警告

ZeroRedundancyOptimizer 是實驗性的,可能會發生變化。

add_param_group(param_group)[source]#

將一個引數組新增到 Optimizerparam_groups 中。

這在微調預訓練網路時可能很有用,因為在訓練過程中,可以使凍結層可訓練並將其新增到 Optimizer 中。

引數

param_group (dict) – 指定要最佳化的引數以及特定於組的最佳化選項。

警告

此方法處理更新所有分片上的分片,但需要在所有 rank 上呼叫。在部分 rank 上呼叫此方法將導致訓練掛起,因為通訊原語是根據所管理的引數呼叫的,並且期望所有 rank 都參與同一組引數。

consolidate_state_dict(to=0)[source]#

在目標 rank 上合併一個 state_dict 列表(每個 rank 一個)。

引數

to (int) – 接收最佳化器狀態的 rank(預設值:0)。

引發

RuntimeError – 如果 overlap_with_ddp=True 且在 ZeroRedundancyOptimizer 例項完全初始化(一旦 DistributedDataParallel 梯度 bucket 重建後)之前呼叫此方法。

警告

這需要在所有 rank 上呼叫。

property join_device: device#

返回預設裝置。

join_hook(**kwargs)[source]#

返回 ZeRO join 鉤子。

它透過在最佳化器步驟中對集體通訊進行影子處理,從而支援在不均勻輸入上的訓練。

在呼叫此鉤子之前,必須正確設定梯度。

引數

kwargs (dict) – 一個 dict,包含任何用於在執行時修改 join 鉤子行為的關鍵字引數;共享相同 join 上下文管理器所有 Joinable 例項將接收相同的 kwargs 值。

此鉤子不支援任何關鍵字引數;即 kwargs 未使用。

property join_process_group: Any#

返回程序組。

load_state_dict(state_dict)[source]#

從輸入的 state_dict 載入與給定 rank 相關的狀態,並根據需要更新本地最佳化器。

引數

state_dict (dict) – 最佳化器狀態;應為呼叫 state_dict() 返回的物件。

引發

RuntimeError – 如果 overlap_with_ddp=True 且在 ZeroRedundancyOptimizer 例項完全初始化(一旦 DistributedDataParallel 梯度 bucket 重建後)之前呼叫此方法。

state_dict()[source]#

返回此 rank 已知的最後一個全域性最佳化器狀態。

引發

RuntimeError – 如果 overlap_with_ddp=True 且在 ZeroRedundancyOptimizer 例項完全初始化(一旦 DistributedDataParallel 梯度 bucket 重建後)之前呼叫此方法;或者如果在呼叫 consolidate_state_dict() 之前呼叫此方法。

返回型別

dict[str, Any]

step(closure=None, **kwargs)[source]#

執行一次最佳化步驟並跨所有 rank 同步引數。

引數

closure (Callable) – 一個重新評估模型並返回損失的閉包;對大多數最佳化器來說是可選的。

返回

可選損失,取決於底層本地最佳化器。

返回型別

Optional[float]

注意

任何額外的引數都會原樣傳遞給基礎最佳化器。