LazyModuleMixin#
- class torch.nn.modules.lazy.LazyModuleMixin(*args, **kwargs)[原始碼]#
一種用於延遲初始化引數的模組的混合類,也稱為“延遲模組”。
延遲初始化引數的模組,即“延遲模組”,從其 forward 方法的第一個輸入推斷其引數的形狀。在第一次 forward 之前,它們包含不應被訪問或使用的
torch.nn.UninitializedParameter,之後它們包含常規的torch.nn.Parameter。延遲模組很方便,因為它們不需要計算某些模組引數,例如典型torch.nn.Linear的in_features引數。構造完成後,帶有延遲模組的網路應首先轉換為所需的 dtype 並放置在預期的裝置上。這是因為延遲模組僅執行形狀推斷,因此通常的 dtype 和裝置放置行為適用。然後,延遲模組應執行“試執行”以初始化模組中的所有元件。這些“試執行”會將正確大小、dtype 和裝置的輸入透過網路傳送到其每個延遲模組。之後,網路就可以像平常一樣使用。
>>> class LazyMLP(torch.nn.Module): ... def __init__(self) -> None: ... super().__init__() ... self.fc1 = torch.nn.LazyLinear(10) ... self.relu1 = torch.nn.ReLU() ... self.fc2 = torch.nn.LazyLinear(1) ... self.relu2 = torch.nn.ReLU() ... ... def forward(self, input): ... x = self.relu1(self.fc1(input)) ... y = self.relu2(self.fc2(x)) ... return y >>> # constructs a network with lazy modules >>> lazy_mlp = LazyMLP() >>> # transforms the network's device and dtype >>> # NOTE: these transforms can and should be applied after construction and before any 'dry runs' >>> lazy_mlp = lazy_mlp.cuda() >>> lazy_mlp LazyMLP( (fc1): LazyLinear(in_features=0, out_features=10, bias=True) (relu1): ReLU() (fc2): LazyLinear(in_features=0, out_features=1, bias=True) (relu2): ReLU() ) >>> # performs a dry run to initialize the network's lazy modules >>> lazy_mlp(torch.ones(10, 10).cuda()) >>> # after initialization, LazyLinear modules become regular Linear modules >>> lazy_mlp LazyMLP( (fc1): Linear(in_features=10, out_features=10, bias=True) (relu1): ReLU() (fc2): Linear(in_features=10, out_features=1, bias=True) (relu2): ReLU() ) >>> # attaches an optimizer, since parameters can now be used as usual >>> optim = torch.optim.SGD(lazy_mlp.parameters(), lr=0.01)
使用延遲模組時的最後一點注意事項是,網路引數的初始化順序可能會發生變化,因為延遲模組總是在其他模組之後初始化。例如,如果上面定義的 LazyMLP 類首先有一個
torch.nn.LazyLinear模組,然後是第二個常規torch.nn.Linear模組,那麼第二個模組將在構造時初始化,而第一個模組將在第一次試執行時初始化。這可能會導致使用延遲模組的網路引數的初始化方式與不使用延遲模組的網路的引數不同,因為引數初始化的順序(通常取決於有狀態的隨機數生成器)是不同的。有關更多詳細資訊,請參閱 可復現性。延遲模組可以像其他模組一樣使用 state dict 進行序列化。例如
>>> lazy_mlp = LazyMLP() >>> # The state dict shows the uninitialized parameters >>> lazy_mlp.state_dict() OrderedDict({'fc1.weight': <UninitializedParameter>, 'fc1.bias': <UninitializedParameter>, 'fc2.weight': <UninitializedParameter>, 'fc2.bias': <UninitializedParameter>})
延遲模組可以載入常規的
torch.nn.Parameter(即,您可以序列化/反序列化已初始化的 LazyModule,它們將保持已初始化狀態)>>> full_mlp = LazyMLP() >>> # Dry run to initialize another module >>> full_mlp.forward(torch.ones(10, 1)) >>> # Load an initialized state into a lazy module >>> lazy_mlp.load_state_dict(full_mlp.state_dict()) >>> # The state dict now holds valid values >>> lazy_mlp.state_dict() OrderedDict([('fc1.weight', tensor([[-0.3837], [ 0.0907], [ 0.6708], [-0.5223], [-0.9028], [ 0.2851], [-0.4537], [ 0.6813], [ 0.5766], [-0.8678]])), ('fc1.bias', tensor([-1.8832e+25, 4.5636e-41, -1.8832e+25, 4.5636e-41, -6.1598e-30, 4.5637e-41, -1.8788e+22, 4.5636e-41, -2.0042e-31, 4.5637e-41])), ('fc2.weight', tensor([[ 0.1320, 0.2938, 0.0679, 0.2793, 0.1088, -0.1795, -0.2301, 0.2807, 0.2479, 0.1091]])), ('fc2.bias', tensor([0.0019]))])
但是請注意,如果在載入狀態時已初始化引數,則在執行“試執行”時不會替換載入的引數。這可以防止在不同上下文中重用已初始化的模組。