評價此頁

torch.triangular_solve#

torch.triangular_solve(b, A, upper=True, transpose=False, unitriangular=False, *, out=None)#

求解具有方陣(上三角或下三角可逆矩陣)AA 和多個右側項 bb 的方程組。

用符號表示,它求解 AX=bAX = b,並假設 AA 是方上三角(如果 upper= False 則為方下三角)並且對角線上沒有零。

torch.triangular_solve(b, A) 可以接受 2D 輸入 b, A 或批次 2D 矩陣的輸入。如果輸入是批次的,則返回批次的輸出 X

如果 A 的對角線包含零或接近零的元素,並且 unitriangular= False(預設值),或者如果輸入矩陣條件較差,結果可能包含 NaN

支援 float, double, cfloat 和 cdouble 資料型別的輸入。

警告

torch.triangular_solve() 已棄用,推薦使用 torch.linalg.solve_triangular(),並且將在未來的 PyTorch 版本中移除。torch.linalg.solve_triangular() 的引數順序已顛倒,並且不返回輸入之一的副本。

X = torch.triangular_solve(B, A).solution 應替換為

X = torch.linalg.solve_triangular(A, B)
引數
  • b (Tensor) – 多個右側項,大小為 (,m,k)(*, m, k),其中 * 是零個或多個批次維度

  • A (Tensor) – 輸入的三角係數矩陣,大小為 (,m,m)(*, m, m),其中 * 是零個或多個批次維度

  • upper (bool, optional) – AA 是上三角還是下三角。預設為 True

  • transpose (bool, optional) – 求解 op(A)X = b,其中當此標誌為 Trueop(A) = A^T,當此標誌為 Falseop(A) = A。預設為 False

  • unitriangular (bool, optional) – AA 是否為單位三角矩陣。如果為 True,則假定 AA 的對角線元素為 1 並且不從 AA 中引用。預設為 False

關鍵字引數

out ((Tensor, Tensor), optional) – 用於寫入輸出的兩個張量的元組。如果為 None 則忽略。預設為 None

返回

一個命名元組 (solution, cloned_coefficient),其中 cloned_coefficientAA 的克隆,而 solution 是方程 AX=bAX = b(或根據關鍵字引數的方程變體)的解 XX

示例

>>> A = torch.randn(2, 2).triu()
>>> A
tensor([[ 1.1527, -1.0753],
        [ 0.0000,  0.7986]])
>>> b = torch.randn(2, 3)
>>> b
tensor([[-0.0210,  2.3513, -1.5492],
        [ 1.5429,  0.7403, -1.0243]])
>>> torch.triangular_solve(b, A)
torch.return_types.triangular_solve(
solution=tensor([[ 1.7841,  2.9046, -2.5405],
        [ 1.9320,  0.9270, -1.2826]]),
cloned_coefficient=tensor([[ 1.1527, -1.0753],
        [ 0.0000,  0.7986]]))