torch.nn.utils.convert_conv2d_weight_memory_format#
- torch.nn.utils.convert_conv2d_weight_memory_format(module, memory_format)[原始碼]#
將
nn.Conv2d.weight的memory_format轉換為指定的memory_format。此轉換會遞迴應用於巢狀的
nn.Module,包括module本身。請注意,它僅更改 memory_format,而不改變每個維度的語義。此函式用於促進計算採用 NHWC 核心,這可以為計算能力 >= 7.0 的 CUDA 裝置上的 fp16 資料提供可觀的加速。注意
呼叫
model.to(memory_format=torch.channels_last)比實用函式convert_conv2d_weight_memory_format更激進。任何具有 4D 權重的層都會受到model.to的影響,但這些層不一定受益於轉換為指定的memory_format。一個我們可以確定的地方是,cuDNN 中卷積的 NHWC (channels_last) 轉換,因為在 NHWC 中運行卷積是有益的,即使在必須對輸入張量進行排列的情況下。因此,我們的策略是僅將卷積的權重轉換為 channels_last。這可以確保:1. 使用快速卷積核心,其優勢可能超過排列的開銷(如果輸入格式不同)。2. 不會對不受益於 memory_format 轉換的層應用不必要的排列。
最佳情況是,卷積層之間的層是 channels last 相容的。輸入張量在遇到第一個卷積層時會被排列成 channels last,並保持在該記憶體格式。因此,後續的卷積層將不需要排列其輸入張量。
如果 channels last 不相容的層位於卷積層之間,我們需要將輸入張量排列回連續格式(contiguous format)以供該層使用。輸入張量將以連續格式透過剩餘的層,並在遇到另一個卷積層時被排列成 channels last。將該排列傳播到更早的層是沒有意義的,因為大多數層對
memory_format都相當不敏感。當 PyTorch 支援排列的融合時,這個說法可能會改變,因為可能存在比卷積層前立即融合排列更好的位置。
- 引數
module (nn.Module) –
nn.Conv2d&nn.ConvTranspose2d或容器nn.Modulememory_format (memory_format) – 使用者指定的
memory_format,例如torch.channels_last或torch.contiguous_format
- 返回
具有更新的
nn.Conv2d的原始模組- 返回型別
_M
示例
>>> input = torch.randint( ... 1, 10, (2, 8, 4, 4), dtype=torch.float16, device="cuda" ... ) >>> model = nn.Sequential( >>> nn.Conv2d(8, 4, 3)).cuda().half() >>> # This is identical to: >>> # nn.utils.convert_conv2d_weight_memory_format(model, torch.channels_last) >>> model = nn.utils.convert_conv2d_weight_memory_format( ... model, torch.channels_last ... ) >>> out = model(input)