評價此頁

torch.nn.utils.parametrizations.spectral_norm#

torch.nn.utils.parametrizations.spectral_norm(module, name='weight', n_power_iterations=1, eps=1e-12, dim=None)[原始碼]#

對給定模組中的引數應用譜歸一化。

WSN=Wσ(W),σ(W)=maxh:h0Wh2h2\mathbf{W}_{SN} = \dfrac{\mathbf{W}}{\sigma(\mathbf{W})}, \sigma(\mathbf{W}) = \max_{\mathbf{h}: \mathbf{h} \ne 0} \dfrac{\|\mathbf{W} \mathbf{h}\|_2}{\|\mathbf{h}\|_2}

當應用於向量時,它簡化為

xSN=xx2\mathbf{x}_{SN} = \dfrac{\mathbf{x}}{\|\mathbf{x}\|_2}

譜歸一化透過降低模型的Lipschitz常數來穩定生成對抗網路(GAN)中判別器(critic)的訓練。σ\sigma是利用冪法進行一次迭代來近似的,每次訪問權重時都會進行。如果權重張量的維度大於2,則在冪法中會將其重塑為2D以獲得譜範數。

請參閱 Spectral Normalization for Generative Adversarial Networks

注意

此函式使用 register_parametrization() 中的引數化功能來實現。它是 torch.nn.utils.spectral_norm() 的重新實現。

注意

當註冊此約束時,將估計與最大奇異值相關的奇異向量,而不是隨機取樣。然後,當模組在訓練模式下訪問張量時,會透過進行 n_power_iterations冪法 來更新它們。

注意

如果 _SpectralNorm 模組,即 module.parametrization.weight[idx],在移除時處於訓練模式,它將執行另一次冪迭代。如果您想避免此迭代,請在移除模組之前將其設定為評估模式。

引數
  • module (nn.Module) – 包含的模組

  • name (str, optional) – 權重引數的名稱。預設為 "weight"

  • n_power_iterations (int, optional) – 計算譜範數的冪迭代次數。預設為 1

  • eps (float, optional) – 計算範數時用於數值穩定性的 epsilon。預設為 1e-12

  • dim (int, optional) – 對應於輸出數量的維度。預設為 0,但對於 ConvTranspose{1,2,3}d 型別的模組,它為 1

返回

註冊了新引數化的原始模組,該引數化已應用於指定的權重

返回型別

模組

示例

>>> snm = spectral_norm(nn.Linear(20, 40))
>>> snm
ParametrizedLinear(
  in_features=20, out_features=40, bias=True
  (parametrizations): ModuleDict(
    (weight): ParametrizationList(
      (0): _SpectralNorm()
    )
  )
)
>>> torch.linalg.matrix_norm(snm.weight, 2)
tensor(1.0081, grad_fn=<AmaxBackward0>)