DataParallel#
- class torch.nn.DataParallel(module, device_ids=None, output_device=None, dim=0)[source]#
在模組級別實現資料並行。
這個容器透過在批次維度上分塊來將輸入分割到指定的裝置上,從而實現給定
module的並行應用(其他物件將每個裝置複製一次)。在前向傳播中,模組會在每個裝置上覆制,每個副本處理一部分輸入。在後向傳播過程中,每個副本的梯度將被累加到原始模組中。批次大小應大於使用的 GPU 數量。
警告
建議使用
DistributedDataParallel進行多 GPU 訓練,而不是此類,即使只有一個節點。請參見: 使用 nn.parallel.DistributedDataParallel 而不是 multiprocessing 或 nn.DataParallel 和 分散式資料並行。允許將任意位置引數和關鍵字引數傳遞給 DataParallel,但某些型別會得到特殊處理。張量將在指定的維度(預設為 0)上被**散佈**。元組、列表和字典型別將被淺複製。其他型別將在不同執行緒之間共享,如果它們在模型的正向傳播中被寫入,可能會損壞。
並行化的
module在執行此DataParallel模組之前,必須將其引數和緩衝區放在device_ids[0]上。警告
在每次正向傳播中,
module會在每個裝置上**複製**,因此在forward中對執行模組的任何更新都將丟失。例如,如果module有一個在每次forward中遞增的計數器屬性,它將始終保持初始值,因為更新是在副本上完成的,而這些副本在forward之後就會被銷燬。然而,DataParallel保證device[0]上的副本的引數和緩衝區將與基礎並行化的module共享儲存。因此,對device[0]上的引數或緩衝區的**原地**更新將被記錄。例如,BatchNorm2d和spectral_norm()依賴於此行為來更新緩衝區。警告
在
module及其子模組上定義的正向和反向鉤子將被呼叫len(device_ids)次,每次使用位於特定裝置上的輸入。特別地,僅保證鉤子相對於對應裝置上的操作以正確的順序執行。例如,不能保證透過register_forward_pre_hook()設定的鉤子會在 所有len(device_ids)次forward()呼叫之前執行,但保證每個這樣的鉤子會在對應裝置上對應的forward()呼叫之前執行。警告
當
module在forward()中返回一個標量(即 0 維張量)時,此包裝器將返回一個長度等於資料並行使用的裝置數量的向量,其中包含每個裝置的結果。注意
在使用
Module(包裝在DataParallel中)的pack sequence -> recurrent network -> unpack sequence模式時存在一個細微之處。詳情請參閱 FAQ 中的 我的迴圈網路無法與資料並行配合使用 部分。- 引數
module (Module) – 要並行化的模組
device_ids (list of int or torch.device) – CUDA 裝置(預設:所有裝置)
output_device (int or torch.device) – 輸出的裝置位置(預設:device_ids[0])
- 變數
module (Module) – 要並行化的模組
示例
>>> net = torch.nn.DataParallel(model, device_ids=[0, 1, 2]) >>> output = net(input_var) # input_var can be on any device, including CPU