評價此頁

torch.nn.utils.parametrizations.orthogonal#

torch.nn.utils.parametrizations.orthogonal(module, name='weight', orthogonal_map=None, *, use_trivialization=True)[source]#

將矩陣或矩陣批次應用正交或酉引數化。

K\mathbb{K}R\mathbb{R}C\mathbb{C},引數化矩陣 QKm×nQ \in \mathbb{K}^{m \times n} 是**正交**的,具體定義如下:

QHQ=Inif mnQQH=Imif m<n\begin{align*} Q^{\text{H}}Q &= \mathrm{I}_n \mathrlap{\qquad \text{if }m \geq n}\\ QQ^{\text{H}} &= \mathrm{I}_m \mathrlap{\qquad \text{if }m < n} \end{align*}

其中 QHQ^{\text{H}}QQ 為複數時的共軛轉置,為實數時的轉置,而 In\mathrm{I}_nn 維單位矩陣。通俗地說,當 mnm \geq n 時,QQ 的列是正交的,否則行是正交的。

如果張量有多個維度,我們將其視為形狀為 (…, m, n) 的矩陣批次。

矩陣 QQ 可以透過三種不同的 orthogonal_map 來引數化,這些對映作用於原始張量:

  • "matrix_exp"/"cayley"matrix_exp() Q=exp(A)Q = \exp(A)Cayley 變換 Q=(In+A/2)(InA/2)1Q = (\mathrm{I}_n + A/2)(\mathrm{I}_n - A/2)^{-1} 作用於一個斜對稱矩陣 AA 以得到一個正交矩陣。

  • "householder":計算 Householder 反射的乘積(householder_product())。

"matrix_exp"/"cayley" 通常比 "householder" 能更快地使引數化權重收斂,但對於非常瘦或非常寬的矩陣計算速度較慢。

如果 use_trivialization=True(預設值),則引數化實現“動態平凡化框架”,其中一個額外的矩陣 BKn×nB \in \mathbb{K}^{n \times n} 儲存在 module.parametrizations.weight[0].base 下。這有助於引數化層的收斂,但會消耗一些額外的記憶體。請參閱 Trivializations for Gradient-Based Optimization on Manifolds

QQ 的初始值:如果原始張量未引數化且 use_trivialization=True(預設值),則 QQ 的初始值是原始張量本身(如果它已經是正交的,或者在複數情況下是酉的),否則透過 QR 分解進行正交化(參見 torch.linalg.qr())。如果它未引數化且 orthogonal_map="householder",即使 use_trivialization=False,情況也相同。否則,初始值是應用於原始張量的所有已註冊引數化的組合結果。

注意

此函式使用 register_parametrization() 中的引數化功能實現。

引數
  • module (nn.Module) – 要註冊引數化的模組。

  • name (str, optional) – 要使其正交的張量名稱。預設為 "weight"

  • orthogonal_map (str, optional) – 以下之一:“matrix_exp”、“cayley” 或 “householder”。預設為,當矩陣為方陣或複數時為 "matrix_exp",否則為 "householder"

  • use_trivialization (bool, optional) – 是否使用動態平凡化框架。預設為 True

返回

具有已向指定權重註冊正交引數化的原始模組

返回型別

模組

示例

>>> orth_linear = orthogonal(nn.Linear(20, 40))
>>> orth_linear
ParametrizedLinear(
in_features=20, out_features=40, bias=True
(parametrizations): ModuleDict(
    (weight): ParametrizationList(
    (0): _Orthogonal()
    )
)
)
>>> Q = orth_linear.weight
>>> torch.dist(Q.T @ Q, torch.eye(20))
tensor(4.9332e-07)