評價此頁

torch.nn.utils.parametrize.register_parametrization#

torch.nn.utils.parametrize.register_parametrization(module, tensor_name, parametrization, *, unsafe=False)[原始碼]#

將引數化註冊到一個模組的張量上。

為簡單起見,假設 tensor_name="weight"。當訪問 module.weight 時,模組將返回引數化版本 parametrization(module.weight)。如果原始張量需要梯度,則反向傳播將透過 parametrization 進行微分,最佳化器將相應地更新張量。

模組第一次註冊引數化時,此函式將向模組新增一個型別為 ParametrizationList 的屬性 parametrizations

張量 weight 上的引數化列表將可以在 module.parametrizations.weight 下訪問。

原始張量將可以在 module.parametrizations.weight.original 下訪問。

可以透過在同一屬性上註冊多個引數化來連線引數化。

註冊的引數化的訓練模式會進行更新,以匹配宿主模組的訓練模式。

引數化引數和緩衝區具有內建快取系統,可以使用上下文管理器 cached() 啟用。

引數化可以有一個可選的實現方法,簽名如下:

def right_inverse(self, X: Tensor) -> Union[Tensor, Sequence[Tensor]]

當註冊第一個引數化時,將呼叫此方法處理未引數化的張量,以計算原始張量的初始值。如果未實現此方法,則原始張量就是未引數化的張量。

如果一個張量上註冊的所有引數化都實現了 right_inverse,則可以透過賦值來初始化引數化張量,如下面的示例所示。

第一個引數化可以依賴於多個輸入。這可以透過從 right_inverse 返回一個張量元組來實現(參見下方 RankOne 引數化的示例實現)。

在這種情況下,無約束張量也位於 module.parametrizations.weight 下,名稱分別為 original0original1,等等。

注意

如果 unsafe=False(預設值),則 forward 和 right_inverse 方法都將被呼叫一次,以執行一系列一致性檢查。如果 unsafe=True,則將在張量未引數化時呼叫 right_inverse,否則將不呼叫任何方法。

注意

在大多數情況下,right_inverse 是一個函式,使得 forward(right_inverse(X)) == X(參見 右逆)。有時,當引數化不是滿射時,放寬此限制可能是合理的。

警告

如果一個引數化依賴於多個輸入,register_parametrization() 將會註冊一些新的引數。如果此類引數化在最佳化器建立後註冊,則需要手動將這些新引數新增到最佳化器中。請參閱 torch.Optimizer.add_param_group()

引數
  • module (nn.Module) – 要註冊引數化的模組

  • tensor_name (str) – 要註冊引數化的引數或緩衝區名稱

  • parametrization (nn.Module) – 要註冊的引數化

關鍵字引數

unsafe (bool) – 一個布林標誌,表示引數化是否可能更改張量的 dtype 和形狀。預設值:False 警告:註冊時未檢查引數化的一致性。請自行承擔啟用此標誌的風險。

引發

ValueError – 如果模組沒有名為 tensor_name 的引數或緩衝區

返回型別

模組

示例

>>> import torch
>>> import torch.nn as nn
>>> import torch.nn.utils.parametrize as P
>>>
>>> class Symmetric(nn.Module):
>>>     def forward(self, X):
>>>         return X.triu() + X.triu(1).T  # Return a symmetric matrix
>>>
>>>     def right_inverse(self, A):
>>>         return A.triu()
>>>
>>> m = nn.Linear(5, 5)
>>> P.register_parametrization(m, "weight", Symmetric())
>>> print(torch.allclose(m.weight, m.weight.T))  # m.weight is now symmetric
True
>>> A = torch.rand(5, 5)
>>> A = A + A.T  # A is now symmetric
>>> m.weight = A  # Initialize the weight to be the symmetric matrix A
>>> print(torch.allclose(m.weight, A))
True
>>> class RankOne(nn.Module):
>>>     def forward(self, x, y):
>>> # Form a rank 1 matrix multiplying two vectors
>>>         return x.unsqueeze(-1) @ y.unsqueeze(-2)
>>>
>>>     def right_inverse(self, Z):
>>> # Project Z onto the rank 1 matrices
>>>         U, S, Vh = torch.linalg.svd(Z, full_matrices=False)
>>> # Return rescaled singular vectors
>>>         s0_sqrt = S[0].sqrt().unsqueeze(-1)
>>>         return U[..., :, 0] * s0_sqrt, Vh[..., 0, :] * s0_sqrt
>>>
>>> linear_rank_one = P.register_parametrization(
...     nn.Linear(4, 4), "weight", RankOne()
... )
>>> print(torch.linalg.matrix_rank(linear_rank_one.weight).item())
1