評價此頁

DataParallel#

class torch.nn.DataParallel(module, device_ids=None, output_device=None, dim=0)[source]#

在模組級別實現資料並行。

這個容器透過在批次維度上分塊來將輸入分割到指定的裝置上,從而實現給定 module 的並行應用(其他物件將每個裝置複製一次)。在前向傳播中,模組會在每個裝置上覆制,每個副本處理一部分輸入。在後向傳播過程中,每個副本的梯度將被累加到原始模組中。

批次大小應大於使用的 GPU 數量。

警告

建議使用 DistributedDataParallel 進行多 GPU 訓練,而不是此類,即使只有一個節點。請參見: 使用 nn.parallel.DistributedDataParallel 而不是 multiprocessing 或 nn.DataParallel分散式資料並行

允許將任意位置引數和關鍵字引數傳遞給 DataParallel,但某些型別會得到特殊處理。張量將在指定的維度(預設為 0)上被**散佈**。元組、列表和字典型別將被淺複製。其他型別將在不同執行緒之間共享,如果它們在模型的正向傳播中被寫入,可能會損壞。

並行化的 module 在執行此 DataParallel 模組之前,必須將其引數和緩衝區放在 device_ids[0] 上。

警告

在每次正向傳播中,module 會在每個裝置上**複製**,因此在 forward 中對執行模組的任何更新都將丟失。例如,如果 module 有一個在每次 forward 中遞增的計數器屬性,它將始終保持初始值,因為更新是在副本上完成的,而這些副本在 forward 之後就會被銷燬。然而,DataParallel 保證 device[0] 上的副本的引數和緩衝區將與基礎並行化的 module 共享儲存。因此,對 device[0] 上的引數或緩衝區的**原地**更新將被記錄。例如,BatchNorm2dspectral_norm() 依賴於此行為來更新緩衝區。

警告

module 及其子模組上定義的正向和反向鉤子將被呼叫 len(device_ids) 次,每次使用位於特定裝置上的輸入。特別地,僅保證鉤子相對於對應裝置上的操作以正確的順序執行。例如,不能保證透過 register_forward_pre_hook() 設定的鉤子會在 所有 len(device_ids)forward() 呼叫之前執行,但保證每個這樣的鉤子會在對應裝置上對應的 forward() 呼叫之前執行。

警告

moduleforward() 中返回一個標量(即 0 維張量)時,此包裝器將返回一個長度等於資料並行使用的裝置數量的向量,其中包含每個裝置的結果。

注意

在使用 Module(包裝在 DataParallel 中)的 pack sequence -> recurrent network -> unpack sequence 模式時存在一個細微之處。詳情請參閱 FAQ 中的 我的迴圈網路無法與資料並行配合使用 部分。

引數
  • module (Module) – 要並行化的模組

  • device_ids (list of int or torch.device) – CUDA 裝置(預設:所有裝置)

  • output_device (int or torch.device) – 輸出的裝置位置(預設:device_ids[0])

變數

module (Module) – 要並行化的模組

示例

>>> net = torch.nn.DataParallel(model, device_ids=[0, 1, 2])
>>> output = net(input_var)  # input_var can be on any device, including CPU