torch.linalg.pinv#
- torch.linalg.pinv(A, *, atol=None, rtol=None, hermitian=False, out=None) Tensor#
計算矩陣的偽逆(摩爾-彭羅斯逆)。
偽逆可以 代數上定義,但透過 SVD 來理解它在計算上更方便。
支援浮點 (float)、雙精度浮點 (double)、複數浮點 (cfloat) 和複數雙精度浮點 (cdouble) 資料型別。還支援矩陣批處理,如果 `A` 是一個矩陣批處理,則輸出具有相同的批處理維度。
如果
hermitian= True,則假設A是厄米特(複數)或對稱(實數)的,但這不會在內部進行檢查。相反,在計算中僅使用矩陣的下三角部分。奇異值(或特徵值的範數,當
hermitian= True 時)低於 的閾值將被視為零並在計算中被忽略,其中 是最大的奇異值(或特徵值)。如果未指定
rtol,並且A是維度為 (m, n) 的矩陣,則相對容差設定為 ,其中 是A的 dtype 的 epsilon 值(參見finfo)。如果未指定rtol且atol被指定為大於零,則rtol將被設定為零。如果
atol或rtol是一個torch.Tensor,則其形狀必須可廣播到A的奇異值形狀,這些奇異值由torch.linalg.svd()返回。注意
如果
hermitian= False,此函式使用torch.linalg.svd();如果hermitian= True,則使用torch.linalg.eigh()。對於 CUDA 輸入,此函式會將該裝置與 CPU 同步。注意
如果可能,請考慮使用
torch.linalg.lstsq()將偽逆乘以左側,因為torch.linalg.lstsq(A, B).solution == A.pinv() @ B
如果可能,始終優先使用
lstsq(),因為它比顯式計算偽逆更快且數值更穩定。注意
此函式有一個與 NumPy 相容的版本 linalg.pinv(A, rcond, hermitian=False)。但是,使用位置引數
rcond已棄用,建議使用rtol。警告
此函式內部使用
torch.linalg.svd()(或在hermitian= True 時使用torch.linalg.eigh()),因此其導數與這些函式導數存在相同的問題。有關更多詳細資訊,請參閱torch.linalg.svd()和torch.linalg.eigh()中的警告。- 引數
- 關鍵字引數
示例
>>> A = torch.randn(3, 5) >>> A tensor([[ 0.5495, 0.0979, -1.4092, -0.1128, 0.4132], [-1.1143, -0.3662, 0.3042, 1.6374, -0.9294], [-0.3269, -0.5745, -0.0382, -0.5922, -0.6759]]) >>> torch.linalg.pinv(A) tensor([[ 0.0600, -0.1933, -0.2090], [-0.0903, -0.0817, -0.4752], [-0.7124, -0.1631, -0.2272], [ 0.1356, 0.3933, -0.5023], [-0.0308, -0.1725, -0.5216]]) >>> A = torch.randn(2, 6, 3) >>> Apinv = torch.linalg.pinv(A) >>> torch.dist(Apinv @ A, torch.eye(3)) tensor(8.5633e-07) >>> A = torch.randn(3, 3, dtype=torch.complex64) >>> A = A + A.T.conj() # creates a Hermitian matrix >>> Apinv = torch.linalg.pinv(A, hermitian=True) >>> torch.dist(Apinv @ A, torch.eye(3)) tensor(1.0830e-06)