torch.nn.utils.parametrizations.spectral_norm#
- torch.nn.utils.parametrizations.spectral_norm(module, name='weight', n_power_iterations=1, eps=1e-12, dim=None)[原始碼]#
對給定模組中的引數應用譜歸一化。
當應用於向量時,它簡化為
譜歸一化透過降低模型的Lipschitz常數來穩定生成對抗網路(GAN)中判別器(critic)的訓練。是利用冪法進行一次迭代來近似的,每次訪問權重時都會進行。如果權重張量的維度大於2,則在冪法中會將其重塑為2D以獲得譜範數。
請參閱 Spectral Normalization for Generative Adversarial Networks。
注意
此函式使用
register_parametrization()中的引數化功能來實現。它是torch.nn.utils.spectral_norm()的重新實現。注意
當註冊此約束時,將估計與最大奇異值相關的奇異向量,而不是隨機取樣。然後,當模組在訓練模式下訪問張量時,會透過進行
n_power_iterations次 冪法 來更新它們。注意
如果 _SpectralNorm 模組,即 module.parametrization.weight[idx],在移除時處於訓練模式,它將執行另一次冪迭代。如果您想避免此迭代,請在移除模組之前將其設定為評估模式。
- 引數
- 返回
註冊了新引數化的原始模組,該引數化已應用於指定的權重
- 返回型別
示例
>>> snm = spectral_norm(nn.Linear(20, 40)) >>> snm ParametrizedLinear( in_features=20, out_features=40, bias=True (parametrizations): ModuleDict( (weight): ParametrizationList( (0): _SpectralNorm() ) ) ) >>> torch.linalg.matrix_norm(snm.weight, 2) tensor(1.0081, grad_fn=<AmaxBackward0>)