評價此頁

torch.linalg.matrix_power#

torch.linalg.matrix_power(A, n, *, out=None) Tensor#

計算一個整數 n 的方陣的 n 次冪。

支援浮點 (float)、雙精度浮點 (double)、複數浮點 (cfloat) 和複數雙精度浮點 (cdouble) 資料型別。還支援矩陣批處理,如果 `A` 是一個矩陣批處理,則輸出具有相同的批處理維度。

如果 n= 0,它將返回與 A 具有相同形狀的單位矩陣(或批次)。如果 n 為負數,它將返回每個矩陣的逆(如果可逆)的 abs(n) 次冪。

注意

如果可能,請考慮使用 torch.linalg.solve() 來將一個矩陣乘以負冪,因為這比顯式計算 AnA^{-n} 更好。

torch.linalg.solve(matrix_power(A, n), B) == matrix_power(A, -n)  @ B

當可能時,總是優先使用 solve(),因為它比顯式計算 AnA^{-n} 更快且數值更穩定。

另請參閱

torch.linalg.solve() 使用數值穩定的演算法計算 A.inverse() @ B

引數
  • A (Tensor) – 形狀為 (*, m, m) 的張量,其中 * 表示零個或多個批次維度。

  • n (int) – 指數。

關鍵字引數

out (Tensor, optional) – 輸出張量。如果為 None 則忽略。預設為 None

引發

RuntimeError – 如果 n< 0 且矩陣 AA 的矩陣批次中的任何矩陣不可逆。

示例

>>> A = torch.randn(3, 3)
>>> torch.linalg.matrix_power(A, 0)
tensor([[1., 0., 0.],
        [0., 1., 0.],
        [0., 0., 1.]])
>>> torch.linalg.matrix_power(A, 3)
tensor([[ 1.0756,  0.4980,  0.0100],
        [-1.6617,  1.4994, -1.9980],
        [-0.4509,  0.2731,  0.8001]])
>>> torch.linalg.matrix_power(A.expand(2, -1, -1), -2)
tensor([[[ 0.2640,  0.4571, -0.5511],
        [-1.0163,  0.3491, -1.5292],
        [-0.4899,  0.0822,  0.2773]],
        [[ 0.2640,  0.4571, -0.5511],
        [-1.0163,  0.3491, -1.5292],
        [-0.4899,  0.0822,  0.2773]]])