torch.functional.lu#
- torch.functional.lu(*args, **kwargs)[source]#
計算矩陣或矩陣批次的 LU 分解
A。返回一個包含A的 LU 分解和透視(pivots)的元組。如果pivot設定為True,則進行部分主元法(partial pivoting)。警告
torch.lu()已棄用,推薦使用torch.linalg.lu_factor()和torch.linalg.lu_factor_ex()。torch.lu()將在 PyTorch 的未來版本中移除。LU, pivots, info = torch.lu(A, compute_pivots)應替換為LU, pivots = torch.linalg.lu_factor(A, compute_pivots)
LU, pivots, info = torch.lu(A, compute_pivots, get_infos=True)應替換為LU, pivots, info = torch.linalg.lu_factor_ex(A, compute_pivots)
注意
返回的批次中每個矩陣的置換矩陣(permutation matrix)表示為一個長度為
min(A.shape[-2], A.shape[-1])的 1-indexed 向量。pivots[i] == j表示在演算法的第i步中,第i行與第j-1行進行了置換。帶
pivot=False的 LU 分解不適用於 CPU,嘗試這樣做將引發錯誤。但是,帶pivot=False的 LU 分解可用於 CUDA。如果
get_infos設定為True,此函式不會檢查分解是否成功,因為分解的狀態已包含在返回元組的第三個元素中。對於 CUDA 裝置上大小小於或等於 32 的方形矩陣批次,由於 MAGMA 庫中的一個 bug(請參見 magma issue 13),奇異矩陣的 LU 分解會被重複執行。
可以使用
torch.lu_unpack()推匯出L、U和P。
警告
該函式的梯度僅在
A是滿秩矩陣時才是有限的。這是因為 LU 分解僅在滿秩矩陣上是可微的。此外,如果A接近於非滿秩矩陣,則梯度在數值上是不穩定的,因為它依賴於 和 的計算。- 引數
A (Tensor) – 要分解的張量,大小為
pivot (bool, optional) – 是否要計算帶部分主元法的 LU 分解,還是常規的 LU 分解。
pivot= False 在 CPU 上不受支援。預設為 True。get_infos (bool, optional) – 如果設定為
True,則返回一個 IntTensor。預設為Falseout (tuple, optional) – 可選的輸出元組。如果
get_infos為True,則元組中的元素為 Tensor、IntTensor 和 IntTensor。如果get_infos為False,則元組中的元素為 Tensor 和 IntTensor。預設為None
- 返回
一個包含以下內容的張量元組:
factorization (Tensor): 分解結果,大小為
pivots (IntTensor): 透視(pivots)結果,大小為 。
pivots儲存了所有的中間行交換。最終的置換perm可以透過對初始的 個元素的單位置換perm執行swap(perm[i], perm[pivots[i] - 1])來重構(對於i = 0, ..., pivots.size(-1) - 1),這本質上就是torch.lu_unpack()的作用。infos (IntTensor, optional): 如果
get_infos為True,則這是一個大小為 的張量,其中非零值表示矩陣或每個小批次的分解是否成功或失敗。
- 返回型別
(Tensor, IntTensor, IntTensor (optional))
示例
>>> A = torch.randn(2, 3, 3) >>> A_LU, pivots = torch.lu(A) >>> A_LU tensor([[[ 1.3506, 2.5558, -0.0816], [ 0.1684, 1.1551, 0.1940], [ 0.1193, 0.6189, -0.5497]], [[ 0.4526, 1.2526, -0.3285], [-0.7988, 0.7175, -0.9701], [ 0.2634, -0.9255, -0.3459]]]) >>> pivots tensor([[ 3, 3, 3], [ 3, 3, 3]], dtype=torch.int32) >>> A_LU, pivots, info = torch.lu(A, get_infos=True) >>> if info.nonzero().size(0) == 0: ... print('LU factorization succeeded for all samples!') LU factorization succeeded for all samples!