GroupNorm#
- class torch.nn.GroupNorm(num_groups, num_channels, eps=1e-05, affine=True, device=None, dtype=None)[原始碼]#
對輸入的小批次應用組歸一化。
該層實現了論文 Group Normalization 中描述的操作。
輸入通道被分成
num_groups組,每組包含num_channels / num_groups個通道。num_channels必須能被num_groups整除。均值和標準差分別在每組內計算。 和 是可學習的、大小為num_channels的逐通道仿射變換引數向量(如果affine為True)。方差透過有偏估計量計算,等同於 torch.var(input, unbiased=False)。此層在訓練和評估模式下都使用從輸入資料計算的統計量。
- 引數
- 形狀
輸入:,其中
輸出: (與輸入形狀相同)
示例
>>> input = torch.randn(20, 6, 10, 10) >>> # Separate 6 channels into 3 groups >>> m = nn.GroupNorm(3, 6) >>> # Separate 6 channels into 6 groups (equivalent with InstanceNorm) >>> m = nn.GroupNorm(6, 6) >>> # Put all 6 channels into a single group (equivalent with LayerNorm) >>> m = nn.GroupNorm(1, 6) >>> # Activating the module >>> output = m(input)