Patching Batch Norm#
創建於:2023年01月03日 | 最後更新於:2025年06月11日
發生了什麼?#
Batch Norm 需要對 running_mean 和 running_var 進行原地更新,其大小與輸入相同。Functorch 不支援對接受批處理張量的常規張量進行原地更新(即不允許 regular.add_(batched))。因此,當對單個模組的輸入批次進行 vmap 時,我們會遇到此錯誤。
如何解決#
最受支援的方法之一是將 BatchNorm 切換為 GroupNorm。選項 1 和 2 支援這一點。
所有這些選項都假設您不需要 running stats。如果您正在使用一個模組,這意味著假設您不會在評估模式下使用 batch norm。如果您有在評估模式下使用 running batch norm 和 vmap 的用例,請提交一個 issue。
選項 1:更改 BatchNorm#
如果您想更改為 GroupNorm,請將所有 BatchNorm 替換為:
BatchNorm2d(C, G, track_running_stats=False)
這裡的 C 與原始 BatchNorm 中的 C 相同。G 是將 C 分割成的組數。因此,C % G == 0,作為回退,您可以將 C == G,這意味著每個通道將單獨處理。
如果您必須使用 BatchNorm 並且您自己構建了模組,您可以更改模組以不使用 running stats。換句話說,在任何有 BatchNorm 模組的地方,將 track_running_stats 標誌設定為 False。
BatchNorm2d(64, track_running_stats=False)
選項 2:torchvision 引數#
一些 torchvision 模型,如 resnet 和 regnet,可以接受 norm_layer 引數。這些引數通常預設為 BatchNorm2d,如果它們被預設設定的話。
相反,您可以將其設定為 GroupNorm。
import torchvision
from functools import partial
torchvision.models.resnet18(norm_layer=lambda c: GroupNorm(num_groups=g, c))
這裡,再次強調,c % g == 0,所以作為回退,請將 g = c。
如果您一定要使用 BatchNorm,請確保使用不使用 running stats 的版本。
import torchvision
from functools import partial
torchvision.models.resnet18(norm_layer=partial(BatchNorm2d, track_running_stats=False))
選項 3:functorch 的 patching#
functorch 添加了一些功能,允許快速、原地修改模組以不使用 running stats。更改 norm 層更易出錯,因此我們未提供此選項。如果您有一個網路,並且希望 BatchNorm 不使用 running stats,您可以執行 replace_all_batch_norm_modules_ 以原地修改模組,使其不使用 running stats。
from torch.func import replace_all_batch_norm_modules_
replace_all_batch_norm_modules_(net)
選項 4:eval 模式#
在 eval 模式下執行時,running_mean 和 running_var 不會更新。因此,vmap 可以支援此模式。
model.eval()
vmap(model)(x)
model.train()