評價此頁

torch.nn.utils.memory_format.convert_conv3d_weight_memory_format#

torch.nn.utils.memory_format.convert_conv3d_weight_memory_format(module, memory_format)[source]#

nn.Conv3d.weightmemory_format 轉換為指定的 memory_format。此轉換會遞迴地應用於巢狀的 nn.Module,包括 module 本身。請注意,它僅更改 memory_format,而不改變每個維度的語義。此函式用於促進計算採用 NHWC 核心,這在計算能力 >= 7.0 的 CUDA 裝置上對 fp16 資料提供了可觀的加速。

注意

呼叫 model.to(memory_format=torch.channels_last_3d) 比工具函式 convert_conv3d_weight_memory_format 更具侵略性。任何具有 4d 權重的層都會受到 model.to 的影響,而這些層不一定受益於轉換為指定的 memory_format。我們確信的一點是,cuDNN 中對卷積進行 NDHWC (channels_last_3d) 轉換是有益的,因為它有利於以 NDHWC 格式運行卷積,即使在必須對輸入張量應用置換的情況下也是如此。

因此,我們的策略是僅將卷積的權重轉換為 channels_last_3d。這確保了:1. 將使用快速卷積核心,其優勢可能超過置換的開銷(如果輸入格式不同)。2. 不會對不受益於 memory_format 轉換的層應用不必要的置換。

最佳情況是,卷積層之間的層是 channels last 相容的。輸入張量在遇到第一個卷積層時會被排列成 channels last,並保持在該記憶體格式。因此,後續的卷積層將不需要排列其輸入張量。

如果 channels last 不相容的層位於卷積層之間,我們需要將輸入張量排列回連續格式(contiguous format)以供該層使用。輸入張量將以連續格式透過剩餘的層,並在遇到另一個卷積層時被排列成 channels last。將該排列傳播到更早的層是沒有意義的,因為大多數層對 memory_format 都相當不敏感。

當 PyTorch 支援排列的融合時,這個說法可能會改變,因為可能存在比卷積層前立即融合排列更好的位置。

引數
  • module (nn.Module) – nn.Conv3d & nn.ConvTranspose3d 或容器 nn.Module

  • memory_format (memory_format) – 使用者指定的 memory_format,例如 torch.channels_lasttorch.contiguous_format

返回

帶有更新後的 nn.Conv3d 的原始模組

返回型別

_M

示例

>>> input = torch.randint(
...     1, 10, (2, 8, 4, 4, 4), dtype=torch.float16, device="cuda"
... )
>>> model = nn.Sequential(
>>>     nn.Conv3d(8, 4, 3)).cuda().half()
>>> # This is identical to:
>>> # nn.utils.convert_conv3d_weight_memory_format(model, torch.channels_last_3d)
>>> model = nn.utils.convert_conv3d_weight_memory_format(
...     model, torch.channels_last_3d
... )
>>> out = model(input)