評價此頁

torch.functional.lu#

torch.functional.lu(*args, **kwargs)[source]#

計算矩陣或矩陣批次的 LU 分解 A。返回一個包含 A 的 LU 分解和透視(pivots)的元組。如果 pivot 設定為 True,則進行部分主元法(partial pivoting)。

警告

torch.lu() 已棄用,推薦使用 torch.linalg.lu_factor()torch.linalg.lu_factor_ex()torch.lu() 將在 PyTorch 的未來版本中移除。 LU, pivots, info = torch.lu(A, compute_pivots) 應替換為

LU, pivots = torch.linalg.lu_factor(A, compute_pivots)

LU, pivots, info = torch.lu(A, compute_pivots, get_infos=True) 應替換為

LU, pivots, info = torch.linalg.lu_factor_ex(A, compute_pivots)

注意

  • 返回的批次中每個矩陣的置換矩陣(permutation matrix)表示為一個長度為 min(A.shape[-2], A.shape[-1]) 的 1-indexed 向量。 pivots[i] == j 表示在演算法的第 i 步中,第 i 行與第 j-1 行進行了置換。

  • pivot = False 的 LU 分解不適用於 CPU,嘗試這樣做將引發錯誤。但是,帶 pivot = False 的 LU 分解可用於 CUDA。

  • 如果 get_infos 設定為 True,此函式不會檢查分解是否成功,因為分解的狀態已包含在返回元組的第三個元素中。

  • 對於 CUDA 裝置上大小小於或等於 32 的方形矩陣批次,由於 MAGMA 庫中的一個 bug(請參見 magma issue 13),奇異矩陣的 LU 分解會被重複執行。

  • 可以使用 torch.lu_unpack() 推匯出 LUP

警告

該函式的梯度僅在 A 是滿秩矩陣時才是有限的。這是因為 LU 分解僅在滿秩矩陣上是可微的。此外,如果 A 接近於非滿秩矩陣,則梯度在數值上是不穩定的,因為它依賴於 L1L^{-1}U1U^{-1} 的計算。

引數
  • A (Tensor) – 要分解的張量,大小為 (,m,n)(*, m, n)

  • pivot (bool, optional) – 是否要計算帶部分主元法的 LU 分解,還是常規的 LU 分解。pivot= False 在 CPU 上不受支援。預設為 True

  • get_infos (bool, optional) – 如果設定為 True,則返回一個 IntTensor。預設為 False

  • out (tuple, optional) – 可選的輸出元組。如果 get_infosTrue,則元組中的元素為 Tensor、IntTensor 和 IntTensor。如果 get_infosFalse,則元組中的元素為 Tensor 和 IntTensor。預設為 None

返回

一個包含以下內容的張量元組:

  • factorization (Tensor): 分解結果,大小為 (,m,n)(*, m, n)

  • pivots (IntTensor): 透視(pivots)結果,大小為 (,min(m,n))(*, \text{min}(m, n))pivots 儲存了所有的中間行交換。最終的置換 perm 可以透過對初始的 mm 個元素的單位置換 perm 執行 swap(perm[i], perm[pivots[i] - 1]) 來重構(對於 i = 0, ..., pivots.size(-1) - 1),這本質上就是 torch.lu_unpack() 的作用。

  • infos (IntTensor, optional): 如果 get_infosTrue,則這是一個大小為 ()(*) 的張量,其中非零值表示矩陣或每個小批次的分解是否成功或失敗。

返回型別

(Tensor, IntTensor, IntTensor (optional))

示例

>>> A = torch.randn(2, 3, 3)
>>> A_LU, pivots = torch.lu(A)
>>> A_LU
tensor([[[ 1.3506,  2.5558, -0.0816],
         [ 0.1684,  1.1551,  0.1940],
         [ 0.1193,  0.6189, -0.5497]],

        [[ 0.4526,  1.2526, -0.3285],
         [-0.7988,  0.7175, -0.9701],
         [ 0.2634, -0.9255, -0.3459]]])
>>> pivots
tensor([[ 3,  3,  3],
        [ 3,  3,  3]], dtype=torch.int32)
>>> A_LU, pivots, info = torch.lu(A, get_infos=True)
>>> if info.nonzero().size(0) == 0:
...     print('LU factorization succeeded for all samples!')
LU factorization succeeded for all samples!