評價此頁

DistributedDataParallel#

class torch.nn.parallel.DistributedDataParallel(module, device_ids=None, output_device=None, dim=0, broadcast_buffers=True, init_sync=True, process_group=None, bucket_cap_mb=None, find_unused_parameters=False, check_reduction=False, gradient_as_bucket_view=False, static_graph=False, delay_all_reduce_named_params=None, param_to_hook_all_reduce=None, mixed_precision=None, device_mesh=None, skip_all_reduce_unused_params=False)[source]#

在模組級別實現基於 torch.distributed 的分散式資料並行。

此容器透過在每個模型副本之間同步梯度來實現資料並行。用於同步的裝置透過輸入的 process_group 指定,預設情況下為整個通訊組。請注意,DistributedDataParallel 不會對輸入進行分塊或分片到參與的 GPU 上;使用者有責任定義如何做到這一點,例如透過使用 DistributedSampler

另請參閱:基礎知識使用 nn.parallel.DistributedDataParallel 而非 multiprocessing 或 nn.DataParallel。與 torch.nn.DataParallel 中的輸入限制相同。

建立此類需要 torch.distributed 已被初始化,透過呼叫 torch.distributed.init_process_group()

DistributedDataParallel 在單節點多 GPU 資料並行訓練中,已被證明比 torch.nn.DataParallel 快得多。

要在具有 N 個 GPU 的主機上使用 DistributedDataParallel,您應該啟動 N 個程序,確保每個程序專門處理 0 到 N-1 的單個 GPU。這可以透過為每個程序設定 CUDA_VISIBLE_DEVICES 來實現,或者透過為 GPU 呼叫以下 API 來實現:

>>> torch.cuda.set_device(i)

或者為 加速器 呼叫統一 API:

>>> torch.accelerator.set_device_index(i)

其中 i 從 0 到 N-1。在每個程序中,您應該參考以下內容來構建此模組:

>>> if torch.accelerator.is_available():
>>>     device_type = torch.accelerator.current_accelerator().type
>>>     vendor_backend = torch.distributed.get_default_backend_for_device(device_type)
>>>
>>> torch.distributed.init_process_group(
>>>     backend=vendor_backend, world_size=N, init_method='...'
>>> )
>>> model = DistributedDataParallel(model, device_ids=[i], output_device=i)

或者您可以使用最新的初始化 API:

>>> torch.distributed.init_process_group(device_id=i)

為了在每個節點上啟動多個程序,您可以使用 torch.distributed.launchtorch.multiprocessing.spawn

注意

有關分散式訓練所有功能的簡要介紹,請參閱 PyTorch 分散式概述

注意

DistributedDataParallel 可以與 torch.distributed.optim.ZeroRedundancyOptimizer 結合使用,以減少每個副本的最佳化器狀態記憶體佔用。有關更多詳細資訊,請參閱 ZeroRedundancyOptimizer 教程

注意

nccl 後端是目前使用 GPU 時最快且強烈推薦的後端。這適用於單節點和多節點分散式訓練。

注意

此模組還支援混合精度分散式訓練。這意味著您的模型可以具有不同型別的引數,例如 fp16fp32 的混合型別,這些混合型別引數的梯度歸約將正常工作。

注意

如果您在一個程序中使用 torch.save 來儲存檢查點模組,並在其他程序中使用 torch.load 來恢復它,請確保為每個程序正確配置了 map_location。如果沒有 map_locationtorch.load 會將模組恢復到模組被儲存的裝置上。

注意

當一個模型在 M 個節點上以 batch=N 進行訓練時,如果損失是跨批次中的例項求和(而非通常的平均),則梯度將是單節點以 batch=M*N 訓練的同等模型梯度的 M 倍小(因為不同節點之間的梯度被平均了)。當您想要獲得與本地訓練對等體在數學上等效的訓練過程時,您應該考慮這一點。但在大多數情況下,您可以將 DistributedDataParallel 包裝的模型、DataParallel 包裝的模型以及單 GPU 上的普通模型視為相同(例如,使用相同的學習率來獲得等效的批次大小)。

注意

引數永遠不會在程序之間廣播。該模組會在梯度上執行 all-reduce 步驟,並假定它們將在所有程序中以相同的方式被最佳化器修改。在每次迭代中,緩衝區(例如 BatchNorm 統計量)會從 rank 0 程序中的模組廣播到系統中的所有其他副本。

注意

如果您將 DistributedDataParallel分散式 RPC 框架 結合使用,您應該始終使用 torch.distributed.autograd.backward() 來計算梯度,並使用 torch.distributed.optim.DistributedOptimizer 來最佳化引數。

示例

>>> import torch.distributed.autograd as dist_autograd
>>> from torch.nn.parallel import DistributedDataParallel as DDP
>>> import torch
>>> from torch import optim
>>> from torch.distributed.optim import DistributedOptimizer
>>> import torch.distributed.rpc as rpc
>>> from torch.distributed.rpc import RRef
>>>
>>> t1 = torch.rand((3, 3), requires_grad=True)
>>> t2 = torch.rand((3, 3), requires_grad=True)
>>> rref = rpc.remote("worker1", torch.add, args=(t1, t2))
>>> ddp_model = DDP(my_model)
>>>
>>> # Setup optimizer
>>> optimizer_params = [rref]
>>> for param in ddp_model.parameters():
>>>     optimizer_params.append(RRef(param))
>>>
>>> dist_optim = DistributedOptimizer(
>>>     optim.SGD,
>>>     optimizer_params,
>>>     lr=0.05,
>>> )
>>>
>>> with dist_autograd.context() as context_id:
>>>     pred = ddp_model(rref.to_here())
>>>     loss = loss_func(pred, target)
>>>     dist_autograd.backward(context_id, [loss])
>>>     dist_optim.step(context_id)

注意

DistributedDataParallel 目前對使用 torch.utils.checkpoint() 的梯度檢查點支援有限。如果檢查點使用 use_reentrant=False(推薦),DDP 將按預期工作,沒有任何限制。然而,如果檢查點使用 use_reentrant=True(預設值),並且模型中沒有未使用的引數,並且每個層最多隻檢查點一次(確保您沒有將 find_unused_parameters=True 傳遞給 DDP),DDP 將按預期工作。我們目前不支援檢查點多次層或檢查點模型中存在未使用的引數的情況。

注意

為了讓非 DDP 模型載入 DDP 模型的 state dict,需要應用 consume_prefix_in_state_dict_if_present() 來剝離 DDP state dict 中的字首“module.”,然後再載入。

警告

建構函式、前向方法以及輸出(或此模組輸出的函式)的微分是分散式同步點。在不同程序可能執行不同程式碼的情況下,請予以考慮。

警告

此模組假定所有引數在建立時都已在模型中註冊。之後不應新增或刪除引數。緩衝區也是如此。

警告

此模組假定每個分散式程序的模型中註冊的所有引數順序相同。該模組本身將按照模型註冊引數的逆序進行梯度 allreduce。換句話說,使用者有責任確保每個分散式程序擁有完全相同的模型,從而擁有完全相同的引數註冊順序。

警告

此模組允許使用非行主序連續步幅的引數。例如,您的模型可能包含一些 torch.memory_formattorch.contiguous_format,而其他引數的格式為 torch.channels_last。但是,不同程序中對應的引數必須具有相同的步幅。

警告

此模組不適用於 torch.autograd.grad()(即,它僅在梯度要累積到引數的 .grad 屬性中時才有效)。

警告

如果您計劃將此模組與 nccl 後端或 gloo 後端(使用 Infiniband)結合使用,並結合使用多工作程序的 DataLoader,請將多程序啟動方法更改為 forkserver(僅限 Python 3)或 spawn。不幸的是,Gloo(使用 Infiniband)和 NCCL2 不是 fork 安全的,如果您不更改此設定,很可能會遇到死鎖。

警告

您永遠不應嘗試在用 DistributedDataParallel 包裝模型後更改模型的引數。因為,在用 DistributedDataParallel 包裝模型時,DistributedDataParallel 的建構函式將在構造時為模型本身的所有引數註冊額外的梯度歸約函式。如果您之後更改了模型的引數,梯度歸約函式將不再匹配正確的引數集。

警告

DistributedDataParallel分散式 RPC 框架 結合使用是實驗性的,並可能發生更改。

引數
  • module (Module) – 要並行化的模組

  • device_ids (list of int or torch.device) –

    CUDA 裝置。1) 對於單裝置模組,device_ids 必須包含一個裝置 ID,該 ID 代表該程序對應的輸入模組所在的唯一 CUDA 裝置。或者,device_ids 也可以是 None。2) 對於多裝置模組和 CPU 模組,device_ids 必須是 None

    當兩種情況下的 device_ids 均為 None 時,前向傳遞的輸入資料和實際模組都必須放置在正確的裝置上。(預設:對於單裝置模組為 None

  • output_device (int or torch.device) – 單裝置 CUDA 模組的輸出裝置位置。對於多裝置模組和 CPU 模組,它必須是 None,並且模組本身決定輸出位置。(預設:單裝置模組為 device_ids[0]

  • broadcast_buffers (bool) – 在 forward 函式開始時同步(廣播)模組緩衝區的標誌。(預設:True

  • init_sync (bool) – 初始化時是否同步以驗證引數形狀並廣播引數和緩衝區。警告:如果設定為 False,使用者必須自行確保所有 rank 上的權重是相同的。(預設:True

  • process_group – 用於分散式資料 all-reduction 的程序組。如果為 None,則使用預設程序組,該程序組由 torch.distributed.init_process_group() 建立。(預設:None

  • bucket_cap_mbDistributedDataParallel 將引數分桶到多個桶中,以便每個桶的梯度歸約可以與後向計算重疊。bucket_cap_mb 控制以兆位元組 (MiB) 為單位的桶大小。如果為 None,則使用預設大小 25 MiB。(預設:None

  • find_unused_parameters (bool) – 從包裝模組的 forward 函式返回值中包含的所有張量開始遍歷 autograd 圖。未作為此圖一部分接收梯度的引數將被預先標記為準備好進行歸約。此外,可能已在包裝模組的 forward 函式中使用但未包含在損失計算中因此也不會接收梯度的引數將被預先標記為準備好進行歸約。(預設:False

  • check_reduction – 此引數已棄用。

  • gradient_as_bucket_view (bool) – 設定為 True 時,梯度將是檢視,指向 allreduce 通訊桶的不同偏移量。這可以減少峰值記憶體使用量,節省的記憶體大小等於總梯度大小。此外,它避免了在梯度和 allreduce 通訊桶之間複製的開銷。當梯度是檢視時,不能在梯度上呼叫 detach_()。如果遇到此類錯誤,請參考 torch/optim/optimizer.py 中的 zero_grad() 函式作為解決方案。請注意,梯度將在第一次迭代後成為檢視,因此峰值記憶體節省應在第一次迭代後進行檢查。

  • static_graph (bool) –

    設定為 True 時,DDP 知道訓練圖是靜態的。靜態圖意味著 1) 使用和未使用的引數集在整個訓練迴圈中不會改變;在這種情況下,使用者設定 find_unused_parameters = True 與否無關緊要。2) 圖的訓練方式在整個訓練迴圈中不會改變(即,沒有依賴於迭代的控制流)。當 static_graph 設定為 True 時,DDP 將支援過去無法支援的場景:1) 可重入後向傳播。2) 啟用檢查點多次。3) 模型具有未使用的引數時進行啟用檢查點。4) 模型引數存在於前向函式之外。5) 當存在未使用的引數時,可能提高效能,因為當 static_graph 設定為 True 時,DDP 不會在每次迭代中搜索圖來檢測未使用的引數。要檢查您是否可以將 static_graph 設定為 True,一種方法是檢查您之前的模型訓練結束時的 ddp 日誌資料,如果 ddp_logging_data.get("can_set_static_graph") == True,那麼您很可能也可以設定 static_graph = True

    示例:
    >>> model_DDP = torch.nn.parallel.DistributedDataParallel(model)
    >>> # Training loop
    >>> ...
    >>> ddp_logging_data = model_DDP._get_ddp_logging_data()
    >>> static_graph = ddp_logging_data.get("can_set_static_graph")
    

  • delay_all_reduce_named_params (list of tuple of str and torch.nn.Parameter) – 一個命名引數列表,當 param_to_hook_all_reduce 中指定的引數的梯度就緒時,它們的 all reduce 將被延遲。DDP 縮減器將忽略此引數中指定的命名引數,因此 DDP 的其他引數不適用於它們。

  • param_to_hook_all_reduce (torch.nn.Parameter) – 一個引數,用於掛鉤 delay_all_reduce_named_params 中指定的引數的延遲 all reduce。

  • skip_all_reduce_unused_params – 設定為 True 時,DDP 將跳過未使用的引數的歸約。這要求未使用的引數在整個訓練過程中在所有 rank 上保持不變。如果此條件不滿足,可能會導致不同步並導致訓練掛起。

變數

module (Module) – 要並行化的模組。

示例

>>> torch.distributed.init_process_group(backend='nccl', world_size=4, init_method='...')
>>> net = torch.nn.parallel.DistributedDataParallel(model)
join(divide_by_initial_world_size=True, enable=True, throw_on_early_termination=False)[source]#

用於在 DDP 中處理程序間輸入不均的訓練的上下文管理器。

此上下文管理器將跟蹤已加入的 DDP 程序,並透過插入集體通訊操作來“映象”前向和後向傳播,以匹配已加入的 DDP 程序建立的操作。這將確保每個集體呼叫都有一個已加入的 DDP 程序的相應呼叫,從而防止在處理程序間輸入不均的訓練時出現的掛起或錯誤。或者,如果指定 throw_on_early_termination 標誌為 True,所有訓練器將在一個 rank 用完輸入時丟擲錯誤,從而允許根據應用程式邏輯捕獲和處理這些錯誤。

一旦所有 DDP 程序都已加入,上下文管理器將把最後一個加入程序的模型廣播到所有程序,以確保模型在所有程序中都相同(這由 DDP 保證)。

要使用此功能來實現在程序間輸入不均的訓練,只需將此上下文管理器包裝在您的訓練迴圈周圍即可。無需對模型或資料載入進行進一步修改。

警告

如果模型或訓練迴圈被此上下文管理器包裝,並且具有額外的分散式集體操作(例如模型前向傳播中的 SyncBatchNorm),則必須啟用 throw_on_early_termination 標誌。這是因為此上下文管理器不知道非 DDP 集體通訊。此標誌將導致所有 rank 在任何一個 rank 用完輸入時丟擲異常,從而允許跨所有 rank 捕獲和恢復這些錯誤。

引數
  • divide_by_initial_world_size (bool) – 如果為 True,則將梯度除以 DDP 啟動時的初始 world_size。如果為 False,則將梯度除以有效世界大小(即尚未用完輸入的 rank 數量),這會在 allreduce 期間完成。將 divide_by_initial_world_size=True 設定為確保每個輸入樣本(包括不均的輸入)在對全域性梯度的貢獻方面具有相等的權重。這是透過即使在遇到不均的輸入時也始終將梯度除以初始 world_size 來實現的。如果您將其設定為 False,我們將梯度除以剩餘的節點數。這確保了與使用較小的 world_size 進行訓練的對等性,儘管這也意味著不均的輸入對全域性梯度的貢獻會更大。通常,您希望將其設定為 True,用於訓練作業的最後幾個輸入不均的情況。在極端情況下,當輸入數量存在很大差異時,將此設定為 False 可能會提供更好的結果。

  • enable (bool) – 是否啟用不均輸入檢測。在您知道輸入在參與程序之間均勻的情況下,傳入 enable=False 來停用。預設值為 True

  • throw_on_early_termination (bool) – 當至少有一個 rank 用完輸入時,是丟擲錯誤還是繼續訓練。如果為 True,將在第一個 rank 到達資料末尾時丟擲。如果為 False,將繼續以較小的有效世界大小進行訓練,直到所有 rank 都加入。請注意,如果指定了此標誌,則將忽略 divide_by_initial_world_size 標誌。預設值為 False

示例

>>> import torch
>>> import torch.distributed as dist
>>> import os
>>> import torch.multiprocessing as mp
>>> import torch.nn as nn
>>> # On each spawned worker
>>> def worker(rank):
>>>     dist.init_process_group("nccl", rank=rank, world_size=2)
>>>     torch.cuda.set_device(rank)
>>>     model = nn.Linear(1, 1, bias=False).to(rank)
>>>     model = torch.nn.parallel.DistributedDataParallel(
>>>         model, device_ids=[rank], output_device=rank
>>>     )
>>>     # Rank 1 gets one more input than rank 0.
>>>     inputs = [torch.tensor([1]).float() for _ in range(10 + rank)]
>>>     with model.join():
>>>         for _ in range(5):
>>>             for inp in inputs:
>>>                 loss = model(inp).sum()
>>>                 loss.backward()
>>>     # Without the join() API, the below synchronization will hang
>>>     # blocking for rank 1's allreduce to complete.
>>>     torch.cuda.synchronize(device=rank)
join_hook(**kwargs)[source]#

DDP join hook 透過在前向和後向傳播中映象通訊來實現對不均輸入的訓練。

引數

kwargs (dict) – 一個包含任何關鍵字引數的 dict,用於在執行時修改 join hook 的行為;共享同一 join 上下文管理器的所有 Joinable 例項都會收到相同的 kwargs 值。

此 hook 支援以下關鍵字引數:
divide_by_initial_world_size (bool, optional)

如果為 True,則將梯度除以 DDP 啟動時的初始 world size。如果為 False,則將梯度除以有效 world size(即未加入程序的數量),這意味著不均的輸入對全域性梯度的貢獻更大。通常,如果差異程度較小,應將其設定為 True,但在極端情況下可以設定為 False 以獲得可能更好的結果。預設為 True

no_sync()[source]#

停用 DDP 程序之間梯度同步的上下文管理器。

在此上下文中,梯度將累積在模組變數上,稍後將在退出上下文的第一個前向-後向傳播中同步。

示例

>>> ddp = torch.nn.parallel.DistributedDataParallel(model, pg)
>>> with ddp.no_sync():
>>>     for input in inputs:
>>>         ddp(input).backward()  # no synchronization, accumulate grads
>>> ddp(another_input).backward()  # synchronize grads

警告

前向傳播應包含在上下文管理器內,否則梯度仍將被同步。

register_comm_hook(state, hook)[source]#

為使用者定義的 DDP 跨多個工作程序的梯度聚合註冊通訊 hook。

此 hook 對於研究人員嘗試新想法非常有用。例如,此 hook 可用於實現 GossipGrad 和梯度壓縮等演算法,這些演算法在執行 Distributed DataParallel 訓練時涉及不同的引數同步通訊策略。

引數
  • state (object) –

    傳遞給 hook,用於在訓練過程中維護任何狀態資訊。示例包括梯度壓縮中的錯誤反饋,GossipGrad 中下一個要通訊的對等方等。

    它由每個工作程序本地儲存,並由工作程序上的所有梯度張量共享。

  • hook (Callable) –

    可呼叫簽名如下:hook(state: object, bucket: dist.GradBucket) -> torch.futures.Future[torch.Tensor]

    當 bucket 準備好時,將呼叫此函式。hook 可以執行所需的任何處理,並返回一個 Future 來指示非同步工作的完成(例如:allreduce)。如果 hook 不執行任何通訊,它仍然必須返回一個已完成的 Future。Future 應包含 grad bucket 張量的新值。一旦 bucket 準備好,c10d 縮減器將呼叫此 hook,並使用從 Future 返回的張量將 grad 複製到各個引數。請注意,Future 的返回型別必須是單個張量。

    我們還提供了一個名為 get_future 的 API 來檢索與 c10d.ProcessGroup.Work 完成相關的 Future。get_future 目前支援 NCCL,也支援 GLOO 和 MPI 上的大多數操作,但不包括點對點操作(send/recv)。

警告

Grad bucket 的張量不會按 world_size 進行預除法。使用者負責在進行 allreduce 等操作時進行 world_size 的除法。

警告

DDP 通訊 hook 只能註冊一次,並且應在呼叫 backward 之前註冊。

警告

hook 返回的 Future 物件應包含一個與 grad bucket 中的張量形狀相同的單個張量。

警告

get_future API 支援 NCCL,以及部分 GLOO 和 MPI 後端(不支援點對點操作,如 send/recv),並將返回一個 torch.futures.Future

示例:

以下是一個 noop hook 的示例,它返回相同的張量。

>>> def noop(state: object, bucket: dist.GradBucket) -> torch.futures.Future[torch.Tensor]:
>>>     fut = torch.futures.Future()
>>>     fut.set_result(bucket.buffer())
>>>     return fut
>>> ddp.register_comm_hook(state=None, hook=noop)
示例:

以下是一個 Parallel SGD 演算法的示例,其中梯度在 allreduce 之前進行編碼,然後在 allreduce 之後進行解碼。

>>> def encode_and_decode(state: object, bucket: dist.GradBucket) -> torch.futures.Future[torch.Tensor]:
>>>     encoded_tensor = encode(bucket.buffer())  # encode gradients
>>>     fut = torch.distributed.all_reduce(encoded_tensor).get_future()
>>>     # Define the then callback to decode.
>>>     def decode(fut):
>>>         decoded_tensor = decode(fut.value()[0])  # decode gradients
>>>         return decoded_tensor
>>>     return fut.then(decode)
>>> ddp.register_comm_hook(state=None, hook=encode_and_decode)