GroupNorm#
- class torch.nn.modules.normalization.GroupNorm(num_groups, num_channels, eps=1e-05, affine=True, device=None, dtype=None)[source]#
對輸入的小批次應用組歸一化。
此層實現了論文 Group Normalization 中描述的操作。
輸入通道被分成
num_groups組,每組包含num_channels / num_groups個通道。num_channels必須能被num_groups整除。均值和標準差在每個組內單獨計算。 和 是可學習的每個通道仿射變換引數向量,其大小為num_channels,如果affine設定為True。方差透過有偏估計量計算,等同於 torch.var(input, unbiased=False)。此層在訓練和評估模式下都使用從輸入資料計算的統計量。
- 引數
- 形狀
輸入: 其中
輸出: (與輸入形狀相同)
示例
>>> input = torch.randn(20, 6, 10, 10) >>> # Separate 6 channels into 3 groups >>> m = nn.GroupNorm(3, 6) >>> # Separate 6 channels into 6 groups (equivalent with InstanceNorm) >>> m = nn.GroupNorm(6, 6) >>> # Put all 6 channels into a single group (equivalent with LayerNorm) >>> m = nn.GroupNorm(1, 6) >>> # Activating the module >>> output = m(input)