torch.nn.utils.parametrize.register_parametrization#
- torch.nn.utils.parametrize.register_parametrization(module, tensor_name, parametrization, *, unsafe=False)[原始碼]#
將引數化註冊到一個模組的張量上。
為簡單起見,假設
tensor_name="weight"。當訪問module.weight時,模組將返回引數化版本parametrization(module.weight)。如果原始張量需要梯度,則反向傳播將透過parametrization進行微分,最佳化器將相應地更新張量。模組第一次註冊引數化時,此函式將向模組新增一個型別為
ParametrizationList的屬性parametrizations。張量
weight上的引數化列表將可以在module.parametrizations.weight下訪問。原始張量將可以在
module.parametrizations.weight.original下訪問。可以透過在同一屬性上註冊多個引數化來連線引數化。
註冊的引數化的訓練模式會進行更新,以匹配宿主模組的訓練模式。
引數化引數和緩衝區具有內建快取系統,可以使用上下文管理器
cached()啟用。引數化可以有一個可選的實現方法,簽名如下:
def right_inverse(self, X: Tensor) -> Union[Tensor, Sequence[Tensor]]
當註冊第一個引數化時,將呼叫此方法處理未引數化的張量,以計算原始張量的初始值。如果未實現此方法,則原始張量就是未引數化的張量。
如果一個張量上註冊的所有引數化都實現了 right_inverse,則可以透過賦值來初始化引數化張量,如下面的示例所示。
第一個引數化可以依賴於多個輸入。這可以透過從
right_inverse返回一個張量元組來實現(參見下方RankOne引數化的示例實現)。在這種情況下,無約束張量也位於
module.parametrizations.weight下,名稱分別為original0、original1,等等。注意
如果 unsafe=False(預設值),則 forward 和 right_inverse 方法都將被呼叫一次,以執行一系列一致性檢查。如果 unsafe=True,則將在張量未引數化時呼叫 right_inverse,否則將不呼叫任何方法。
注意
在大多數情況下,
right_inverse是一個函式,使得forward(right_inverse(X)) == X(參見 右逆)。有時,當引數化不是滿射時,放寬此限制可能是合理的。警告
如果一個引數化依賴於多個輸入,
register_parametrization()將會註冊一些新的引數。如果此類引數化在最佳化器建立後註冊,則需要手動將這些新引數新增到最佳化器中。請參閱torch.Optimizer.add_param_group()。- 引數
- 關鍵字引數
unsafe (bool) – 一個布林標誌,表示引數化是否可能更改張量的 dtype 和形狀。預設值:False 警告:註冊時未檢查引數化的一致性。請自行承擔啟用此標誌的風險。
- 引發
ValueError – 如果模組沒有名為
tensor_name的引數或緩衝區- 返回型別
示例
>>> import torch >>> import torch.nn as nn >>> import torch.nn.utils.parametrize as P >>> >>> class Symmetric(nn.Module): >>> def forward(self, X): >>> return X.triu() + X.triu(1).T # Return a symmetric matrix >>> >>> def right_inverse(self, A): >>> return A.triu() >>> >>> m = nn.Linear(5, 5) >>> P.register_parametrization(m, "weight", Symmetric()) >>> print(torch.allclose(m.weight, m.weight.T)) # m.weight is now symmetric True >>> A = torch.rand(5, 5) >>> A = A + A.T # A is now symmetric >>> m.weight = A # Initialize the weight to be the symmetric matrix A >>> print(torch.allclose(m.weight, A)) True
>>> class RankOne(nn.Module): >>> def forward(self, x, y): >>> # Form a rank 1 matrix multiplying two vectors >>> return x.unsqueeze(-1) @ y.unsqueeze(-2) >>> >>> def right_inverse(self, Z): >>> # Project Z onto the rank 1 matrices >>> U, S, Vh = torch.linalg.svd(Z, full_matrices=False) >>> # Return rescaled singular vectors >>> s0_sqrt = S[0].sqrt().unsqueeze(-1) >>> return U[..., :, 0] * s0_sqrt, Vh[..., 0, :] * s0_sqrt >>> >>> linear_rank_one = P.register_parametrization( ... nn.Linear(4, 4), "weight", RankOne() ... ) >>> print(torch.linalg.matrix_rank(linear_rank_one.weight).item()) 1