torch.nn.utils.init.skip_init#
- torch.nn.utils.init.skip_init(module_cls, *args, **kwargs)[原始碼]#
給定一個模組類物件和引數/關鍵字引數,在不初始化引數/緩衝區的情況下例項化模組。
當初始化過程緩慢或者將執行自定義初始化時,此功能很有用,可以使預設初始化變得不必要。由於此函式的實現方式,存在一些注意事項:
1. 模組在其建構函式中必須接受一個 device 引數,該引數將被傳遞給建構函式中建立的任何引數或緩衝區。
2. 模組在其建構函式中不得對引數執行除初始化(例如來自
torch.nn.init的函式)之外的任何計算。如果滿足這些條件,則可以例項化模組,使其引數/緩衝區值未被初始化,如同使用
torch.empty()建立的一樣。- 引數
module_cls – 類物件;應為
torch.nn.Module的子類args – 傳遞給模組建構函式的引數
kwargs – 傳遞給模組建構函式的關鍵字引數
- 返回
例項化後的模組,具有未初始化的引數/緩衝區
示例
>>> import torch >>> m = torch.nn.utils.skip_init(torch.nn.Linear, 5, 1) >>> m.weight Parameter containing: tensor([[0.0000e+00, 1.5846e+29, 7.8307e+00, 2.5250e-29, 1.1210e-44]], requires_grad=True) >>> m2 = torch.nn.utils.skip_init(torch.nn.Linear, in_features=6, out_features=1) >>> m2.weight Parameter containing: tensor([[-1.4677e+24, 4.5915e-41, 1.4013e-45, 0.0000e+00, -1.4677e+24, 4.5915e-41]], requires_grad=True)