torch.nn.utils.parametrizations.orthogonal#
- torch.nn.utils.parametrizations.orthogonal(module, name='weight', orthogonal_map=None, *, use_trivialization=True)[source]#
將矩陣或矩陣批次應用正交或酉引數化。
令 為 或 ,引數化矩陣 是**正交**的,具體定義如下:
其中 是 為複數時的共軛轉置,為實數時的轉置,而 是 n 維單位矩陣。通俗地說,當 時, 的列是正交的,否則行是正交的。
如果張量有多個維度,我們將其視為形狀為 (…, m, n) 的矩陣批次。
矩陣 可以透過三種不同的
orthogonal_map來引數化,這些對映作用於原始張量:"matrix_exp"/"cayley":matrix_exp()和 Cayley 變換 作用於一個斜對稱矩陣 以得到一個正交矩陣。"householder":計算 Householder 反射的乘積(householder_product())。
"matrix_exp"/"cayley"通常比"householder"能更快地使引數化權重收斂,但對於非常瘦或非常寬的矩陣計算速度較慢。如果
use_trivialization=True(預設值),則引數化實現“動態平凡化框架”,其中一個額外的矩陣 儲存在module.parametrizations.weight[0].base下。這有助於引數化層的收斂,但會消耗一些額外的記憶體。請參閱 Trivializations for Gradient-Based Optimization on Manifolds。的初始值:如果原始張量未引數化且
use_trivialization=True(預設值),則 的初始值是原始張量本身(如果它已經是正交的,或者在複數情況下是酉的),否則透過 QR 分解進行正交化(參見torch.linalg.qr())。如果它未引數化且orthogonal_map="householder",即使use_trivialization=False,情況也相同。否則,初始值是應用於原始張量的所有已註冊引數化的組合結果。注意
此函式使用
register_parametrization()中的引數化功能實現。- 引數
- 返回
具有已向指定權重註冊正交引數化的原始模組
- 返回型別
示例
>>> orth_linear = orthogonal(nn.Linear(20, 40)) >>> orth_linear ParametrizedLinear( in_features=20, out_features=40, bias=True (parametrizations): ModuleDict( (weight): ParametrizationList( (0): _Orthogonal() ) ) ) >>> Q = orth_linear.weight >>> torch.dist(Q.T @ Q, torch.eye(20)) tensor(4.9332e-07)