評價此頁

Patching Batch Norm#

創建於:2023年01月03日 | 最後更新於:2025年06月11日

發生了什麼?#

Batch Norm 需要對 running_mean 和 running_var 進行原地更新,其大小與輸入相同。Functorch 不支援對接受批處理張量的常規張量進行原地更新(即不允許 regular.add_(batched))。因此,當對單個模組的輸入批次進行 vmap 時,我們會遇到此錯誤。

如何解決#

最受支援的方法之一是將 BatchNorm 切換為 GroupNorm。選項 1 和 2 支援這一點。

所有這些選項都假設您不需要 running stats。如果您正在使用一個模組,這意味著假設您不會在評估模式下使用 batch norm。如果您有在評估模式下使用 running batch norm 和 vmap 的用例,請提交一個 issue。

選項 1:更改 BatchNorm#

如果您想更改為 GroupNorm,請將所有 BatchNorm 替換為:

BatchNorm2d(C, G, track_running_stats=False)

這裡的 C 與原始 BatchNorm 中的 C 相同。G 是將 C 分割成的組數。因此,C % G == 0,作為回退,您可以將 C == G,這意味著每個通道將單獨處理。

如果您必須使用 BatchNorm 並且您自己構建了模組,您可以更改模組以不使用 running stats。換句話說,在任何有 BatchNorm 模組的地方,將 track_running_stats 標誌設定為 False。

BatchNorm2d(64, track_running_stats=False)

選項 2:torchvision 引數#

一些 torchvision 模型,如 resnet 和 regnet,可以接受 norm_layer 引數。這些引數通常預設為 BatchNorm2d,如果它們被預設設定的話。

相反,您可以將其設定為 GroupNorm。

import torchvision
from functools import partial
torchvision.models.resnet18(norm_layer=lambda c: GroupNorm(num_groups=g, c))

這裡,再次強調,c % g == 0,所以作為回退,請將 g = c

如果您一定要使用 BatchNorm,請確保使用不使用 running stats 的版本。

import torchvision
from functools import partial
torchvision.models.resnet18(norm_layer=partial(BatchNorm2d, track_running_stats=False))

選項 3:functorch 的 patching#

functorch 添加了一些功能,允許快速、原地修改模組以不使用 running stats。更改 norm 層更易出錯,因此我們未提供此選項。如果您有一個網路,並且希望 BatchNorm 不使用 running stats,您可以執行 replace_all_batch_norm_modules_ 以原地修改模組,使其不使用 running stats。

from torch.func import replace_all_batch_norm_modules_
replace_all_batch_norm_modules_(net)

選項 4:eval 模式#

在 eval 模式下執行時,running_mean 和 running_var 不會更新。因此,vmap 可以支援此模式。

model.eval()
vmap(model)(x)
model.train()