評價此頁

torch.cuda.comm.gather#

torch.cuda.comm.gather(tensors, dim=0, destination=None, *, out=None)[source]#

從多個 GPU 裝置收集張量。

引數
  • tensors (Iterable[Tensor]) – 要收集的張量可迭代物件。除 dim 以外的所有維度的張量大小必須匹配。

  • dim (int, optional) – 張量將沿此維度連線。預設值:0

  • destination (torch.device, str, or int, optional) – 輸出裝置。可以是 CPU 或 CUDA。預設值:當前 CUDA 裝置。

  • out (Tensor, optional, keyword-only) – 用於儲存收集結果的張量。其大小必須與 tensors 的大小匹配,除了 dim 維度,該維度的大小必須等於 sum(tensor.size(dim) for tensor in tensors)。可以位於 CPU 或 CUDA 上。

注意

destination 不能與 out 同時指定。

返回

  • 如果指定了 destination,則

    一個位於 destination 裝置的張量,它是將 tensors 沿 dim 連線的結果。

  • 如果指定了 out

    包含 tensors 沿 dim 連線結果的 out 張量。