評價此頁

torch.lu#

torch.lu(*args, **kwargs)[原始碼]#

計算矩陣或矩陣批次的 LU 分解。返回一個包含 A 的 LU 分解和樞軸(pivots)的元組。當 pivot 設定為 True 時,將執行樞軸操作。

警告

torch.lu() 已棄用,推薦使用 torch.linalg.lu_factor()torch.linalg.lu_factor_ex()torch.lu() 將在未來的 PyTorch 版本中被移除。 LU, pivots, info = torch.lu(A, compute_pivots) 應替換為

LU, pivots = torch.linalg.lu_factor(A, compute_pivots)

LU, pivots, info = torch.lu(A, compute_pivots, get_infos=True) 應替換為

LU, pivots, info = torch.linalg.lu_factor_ex(A, compute_pivots)

注意

  • 對於批次中的每個矩陣,返回的置換矩陣表示為一個大小為 min(A.shape[-2], A.shape[-1]) 的 1-索引向量。 pivots[i] == j 表示在演算法的第 i 步中,第 i 行與第 j-1 行進行了置換。

  • pivot = False 的 LU 分解在 CPU 上不可用,嘗試這樣做將丟擲錯誤。然而,帶 pivot = False 的 LU 分解對 CUDA 裝置可用。

  • 如果 get_infos 設定為 True,此函式不會檢查分解是否成功,因為分解的狀態存在於返回元組的第三個元素中。

  • 對於 CUDA 裝置上尺寸小於等於 32 的方矩陣批次,由於 MAGMA 庫中的錯誤(請參閱 magma issue 13),LU 分解會對奇異矩陣重複執行。

  • 可以使用 torch.lu_unpack() 推匯出 LUP

警告

此函式的梯度僅在 A 滿秩時才有限。這是因為 LU 分解僅在滿秩矩陣處可微。此外,如果 A 接近不滿秩,梯度將不穩定,因為它依賴於 L1L^{-1}U1U^{-1} 的計算。

引數
  • A (Tensor) – 要分解的張量,尺寸為 (,m,n)(*, m, n)

  • pivot (bool, optional) – 是否要計算帶部分主元法的 LU 分解,還是常規的 LU 分解。pivot= False 在 CPU 上不受支援。預設為 True

  • get_infos (bool, optional) – 如果設定為 True,則返回一個 info IntTensor。預設為 False

  • out (tuple, optional) – 可選的輸出元組。如果 get_infosTrue,則元組中的元素為 Tensor, IntTensor, 和 IntTensor。如果 get_infosFalse,則元組中的元素為 Tensor, IntTensor。預設為 None

返回

一個包含以下內容的張量元組:

  • factorization (Tensor): 分解後的張量,尺寸為 (,m,n)(*, m, n)

  • pivots (IntTensor): 樞軸(pivots)張量,尺寸為 (,min(m,n))(*, \text{min}(m, n))pivots 儲存了所有中間的行交換。最終的置換 perm 可以透過對 i = 0, ..., pivots.size(-1) - 1 應用 swap(perm[i], perm[pivots[i] - 1]) 來重建,其中 perm 最初是 mm 個元素的單位置換(這基本上就是 torch.lu_unpack() 所做的)。

  • infos (IntTensor, optional): 如果 get_infosTrue,這是一個尺寸為 ()(*) 的張量,其中非零值表示矩陣或每個小批次(minibatch)的分解是否成功。

返回型別

(Tensor, IntTensor, IntTensor (optional))

示例

>>> A = torch.randn(2, 3, 3)
>>> A_LU, pivots = torch.lu(A)
>>> A_LU
tensor([[[ 1.3506,  2.5558, -0.0816],
         [ 0.1684,  1.1551,  0.1940],
         [ 0.1193,  0.6189, -0.5497]],

        [[ 0.4526,  1.2526, -0.3285],
         [-0.7988,  0.7175, -0.9701],
         [ 0.2634, -0.9255, -0.3459]]])
>>> pivots
tensor([[ 3,  3,  3],
        [ 3,  3,  3]], dtype=torch.int32)
>>> A_LU, pivots, info = torch.lu(A, get_infos=True)
>>> if info.nonzero().size(0) == 0:
...     print('LU factorization succeeded for all samples!')
LU factorization succeeded for all samples!