分散式最佳化器#
建立時間: 2021年03月01日 | 最後更新時間: 2025年06月16日
警告
目前使用 CUDA 張量時不支援分散式最佳化器
torch.distributed.optim 暴露了 DistributedOptimizer,它接受一個遠端引數列表(RRef),並在引數所在的 worker 上本地執行最佳化器。分散式最佳化器可以使用任何本地最佳化器 基類 在每個 worker 上應用梯度。
- class torch.distributed.optim.DistributedOptimizer(optimizer_class, params_rref, *args, **kwargs)[source]#
DistributedOptimizer 接受分佈在不同 worker 上的引數的遠端引用,並在每個引數上本地應用給定的最佳化器。
此類使用
get_gradients()來檢索特定引數的梯度。對
step()的併發呼叫,無論來自同一客戶端還是不同客戶端,都會在每個 worker 上進行序列化處理——因為每個 worker 的最佳化器一次只能處理一組梯度。然而,不能保證整個前向-後向-最佳化器序列一次只為一個客戶端執行。這意味著正在應用的梯度可能不對應於給定 worker 上執行的最新前向傳遞。此外,跨 worker 也沒有保證的順序。DistributedOptimizer 預設情況下使用 TorchScript 建立本地最佳化器,以便在多執行緒訓練(例如分散式模型並行)的情況下,最佳化器更新不會被 Python 全域性直譯器鎖 (GIL) 阻塞。此功能目前對大多數最佳化器都已啟用。您也可以按照 PyTorch 教程中的 示例 為自己的自定義最佳化器啟用 TorchScript 支援。
- 引數
optimizer_class (optim.Optimizer) – 在每個 worker 上例項化的最佳化器類。
params_rref (list[RRef]) – 到本地或遠端引數的 RRefs 列表,用於最佳化。
args – 傳遞給每個 worker 上最佳化器建構函式的引數。
kwargs – 傳遞給每個 worker 上最佳化器建構函式的引數。
- 示例:
>>> import torch.distributed.autograd as dist_autograd >>> import torch.distributed.rpc as rpc >>> from torch import optim >>> from torch.distributed.optim import DistributedOptimizer >>> >>> with dist_autograd.context() as context_id: >>> # Forward pass. >>> rref1 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 3)) >>> rref2 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 1)) >>> loss = rref1.to_here() + rref2.to_here() >>> >>> # Backward pass. >>> dist_autograd.backward(context_id, [loss.sum()]) >>> >>> # Optimizer. >>> dist_optim = DistributedOptimizer( >>> optim.SGD, >>> [rref1, rref2], >>> lr=0.05, >>> ) >>> dist_optim.step(context_id)
- step(context_id)[source]#
執行一次最佳化步驟。
這將呼叫包含要最佳化的引數的每個 worker 上的
torch.optim.Optimizer.step(),並一直阻塞直到所有 worker 返回。提供的context_id將用於檢索包含應應用於引數的梯度的相應context。- 引數
context_id – 我們應該為其執行最佳化器步驟的 autograd 上下文 id。
- class torch.distributed.optim.PostLocalSGDOptimizer(optim, averager)[source]#
包裝任意
torch.optim.Optimizer並執行 post-local SGD。此最佳化器在每一步執行本地最佳化器。預熱階段之後,它在應用本地最佳化器後定期平均引數。- 引數
optim (Optimizer) – 本地最佳化器。
averager (ModelAverager) – 用於執行 post-localSGD 演算法的模型平均器例項。
示例
>>> import torch >>> import torch.distributed as dist >>> import torch.distributed.algorithms.model_averaging.averagers as averagers >>> import torch.nn as nn >>> from torch.distributed.optim import PostLocalSGDOptimizer >>> from torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook import ( >>> PostLocalSGDState, >>> post_localSGD_hook, >>> ) >>> >>> model = nn.parallel.DistributedDataParallel( >>> module, device_ids=[rank], output_device=rank >>> ) >>> >>> # Register a post-localSGD communication hook. >>> state = PostLocalSGDState(process_group=None, subgroup=None, start_localSGD_iter=100) >>> model.register_comm_hook(state, post_localSGD_hook) >>> >>> # Create a post-localSGD optimizer that wraps a local optimizer. >>> # Note that ``warmup_steps`` used in ``PostLocalSGDOptimizer`` must be the same as >>> # ``start_localSGD_iter`` used in ``PostLocalSGDState``. >>> local_optim = torch.optim.SGD(params=model.parameters(), lr=0.01) >>> opt = PostLocalSGDOptimizer( >>> optim=local_optim, >>> averager=averagers.PeriodicModelAverager(period=4, warmup_steps=100) >>> ) >>> >>> # In the first 100 steps, DDP runs global gradient averaging at every step. >>> # After 100 steps, DDP runs gradient averaging within each subgroup (intra-node by default), >>> # and post-localSGD optimizer runs global model averaging every 4 steps after applying the local optimizer. >>> for step in range(0, 200): >>> opt.zero_grad() >>> loss = loss_fn(output, labels) >>> loss.backward() >>> opt.step()
- load_state_dict(state_dict)[source]#
這與
torch.optim.Optimizer的load_state_dict()相同,但還恢復了模型平均器的步數到提供的state_dict中儲存的值。如果在
state_dict中沒有"step"條目,它將發出警告並將模型平均器的步數初始化為 0。
- state_dict()[source]#
這與
torch.optim.Optimizer的state_dict()相同,但增加了一個額外的條目來記錄模型平均器的步數到檢查點,以確保重新載入時不會再次導致不必要的預熱。
- class torch.distributed.optim.ZeroRedundancyOptimizer(params, optimizer_class, process_group=None, parameters_as_bucket_view=False, overlap_with_ddp=False, **defaults)[source]#
包裝任意
optim.Optimizer並將其狀態分片到組內的 ranks 中。共享方式描述如 ZeRO 所述。
每個 rank 的本地最佳化器例項僅負責更新大約
1 / world_size的引數,因此只需要保留1 / world_size的最佳化器狀態。在本地更新引數後,每個 rank 將其引數廣播給所有其他對等節點,以使所有模型副本保持相同的狀態。ZeroRedundancyOptimizer可以與torch.nn.parallel.DistributedDataParallel結合使用,以減少每個 rank 的峰值記憶體消耗。ZeroRedundancyOptimizer使用排序貪心演算法在每個 rank 上打包一定數量的引數。每個引數屬於一個 rank,並且不跨 rank 分割。分割槽是任意的,可能不匹配引數的註冊或使用順序。- 引數
params (
Iterable) – 一個Iterable,包含torch.Tensor或dict,它們將跨 rank 進行分片。- 關鍵字引數
optimizer_class (
torch.nn.Optimizer) – 本地最佳化器的類。process_group (
ProcessGroup, 可選) –torch.distributedProcessGroup(預設值:由torch.distributed.init_process_group()初始化的dist.group.WORLD)。parameters_as_bucket_view (bool, 可選) – 如果為
True,則引數會被打包到 bucket 中以加快通訊速度,並且param.data欄位指向不同偏移量的 bucket 檢視;如果為False,則每個單獨的引數都會單獨通訊,並且每個params.data保持不變(預設值:False)。overlap_with_ddp (bool, 可選) – 如果為
True,則step()與DistributedDataParallel的梯度同步重疊;這需要 (1) 要麼為optimizer_class引數提供一個函式式最佳化器,要麼提供一個具有函式式等價物的最佳化器,並且 (2) 註冊一個從ddp_zero_hook.py中的函式之一構造的 DDP 通訊鉤子;引數將被打包到與DistributedDataParallel匹配的 bucket 中,這意味著parameters_as_bucket_view引數將被忽略。如果為False,則step()在後向傳播之後(正常情況下)獨立執行。(預設值:False)**defaults – 任何尾隨引數,它們會被原樣轉發給本地最佳化器。
示例
>>> import torch.nn as nn >>> from torch.distributed.optim import ZeroRedundancyOptimizer >>> from torch.nn.parallel import DistributedDataParallel as DDP >>> model = nn.Sequential(*[nn.Linear(2000, 2000).to(rank) for _ in range(20)]) >>> ddp = DDP(model, device_ids=[rank]) >>> opt = ZeroRedundancyOptimizer( >>> ddp.parameters(), >>> optimizer_class=torch.optim.Adam, >>> lr=0.01 >>> ) >>> ddp(inputs).sum().backward() >>> opt.step()
警告
目前,
ZeroRedundancyOptimizer要求所有傳入的引數都是相同的密集型別。警告
如果您傳遞
overlap_with_ddp=True,請注意以下幾點:鑑於當前實現中的DistributedDataParallel與ZeroRedundancyOptimizer重疊的方式,前兩到三次訓練迭代在最佳化器步驟中不執行引數更新,具體取決於static_graph=False或static_graph=True,分別是。這是因為需要關於DistributedDataParallel使用的梯度分桶策略的資訊,如果static_graph=False,則該資訊在第二次前向傳遞時確定;如果static_graph=True,則在第三次前向傳遞時確定。為了彌補這一點,一種選擇是預置虛擬輸入。警告
ZeroRedundancyOptimizer 是實驗性的,可能會發生變化。
- add_param_group(param_group)[source]#
將一個引數組新增到
Optimizer的param_groups中。這在微調預訓練網路時可能很有用,因為在訓練過程中,可以使凍結層可訓練並將其新增到
Optimizer中。- 引數
param_group (dict) – 指定要最佳化的引數以及特定於組的最佳化選項。
警告
此方法處理更新所有分片上的分片,但需要在所有 rank 上呼叫。在部分 rank 上呼叫此方法將導致訓練掛起,因為通訊原語是根據所管理的引數呼叫的,並且期望所有 rank 都參與同一組引數。
- consolidate_state_dict(to=0)[source]#
在目標 rank 上合併一個
state_dict列表(每個 rank 一個)。- 引數
to (int) – 接收最佳化器狀態的 rank(預設值:0)。
- 引發
RuntimeError – 如果
overlap_with_ddp=True且在ZeroRedundancyOptimizer例項完全初始化(一旦DistributedDataParallel梯度 bucket 重建後)之前呼叫此方法。
警告
這需要在所有 rank 上呼叫。
- join_hook(**kwargs)[source]#
返回 ZeRO join 鉤子。
它透過在最佳化器步驟中對集體通訊進行影子處理,從而支援在不均勻輸入上的訓練。
在呼叫此鉤子之前,必須正確設定梯度。
- 引數
kwargs (dict) – 一個
dict,包含任何用於在執行時修改 join 鉤子行為的關鍵字引數;共享相同 join 上下文管理器所有Joinable例項將接收相同的kwargs值。
此鉤子不支援任何關鍵字引數;即
kwargs未使用。
- load_state_dict(state_dict)[source]#
從輸入的
state_dict載入與給定 rank 相關的狀態,並根據需要更新本地最佳化器。- 引數
state_dict (dict) – 最佳化器狀態;應為呼叫
state_dict()返回的物件。- 引發
RuntimeError – 如果
overlap_with_ddp=True且在ZeroRedundancyOptimizer例項完全初始化(一旦DistributedDataParallel梯度 bucket 重建後)之前呼叫此方法。
- state_dict()[source]#
返回此 rank 已知的最後一個全域性最佳化器狀態。
- 引發
RuntimeError – 如果
overlap_with_ddp=True且在ZeroRedundancyOptimizer例項完全初始化(一旦DistributedDataParallel梯度 bucket 重建後)之前呼叫此方法;或者如果在呼叫consolidate_state_dict()之前呼叫此方法。- 返回型別