torch.linalg.matrix_power#
- torch.linalg.matrix_power(A, n, *, out=None) Tensor#
計算一個整數 n 的方陣的 n 次冪。
支援浮點 (float)、雙精度浮點 (double)、複數浮點 (cfloat) 和複數雙精度浮點 (cdouble) 資料型別。還支援矩陣批處理,如果 `A` 是一個矩陣批處理,則輸出具有相同的批處理維度。
如果
n= 0,它將返回與A具有相同形狀的單位矩陣(或批次)。如果n為負數,它將返回每個矩陣的逆(如果可逆)的 abs(n) 次冪。注意
如果可能,請考慮使用
torch.linalg.solve()來將一個矩陣乘以負冪,因為這比顯式計算 更好。torch.linalg.solve(matrix_power(A, n), B) == matrix_power(A, -n) @ B
當可能時,總是優先使用
solve(),因為它比顯式計算 更快且數值更穩定。另請參閱
torch.linalg.solve()使用數值穩定的演算法計算A.inverse() @B。- 引數
- 關鍵字引數
out (Tensor, optional) – 輸出張量。如果為 None 則忽略。預設為 None。
- 引發
RuntimeError – 如果
n< 0 且矩陣A或A的矩陣批次中的任何矩陣不可逆。
示例
>>> A = torch.randn(3, 3) >>> torch.linalg.matrix_power(A, 0) tensor([[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]]) >>> torch.linalg.matrix_power(A, 3) tensor([[ 1.0756, 0.4980, 0.0100], [-1.6617, 1.4994, -1.9980], [-0.4509, 0.2731, 0.8001]]) >>> torch.linalg.matrix_power(A.expand(2, -1, -1), -2) tensor([[[ 0.2640, 0.4571, -0.5511], [-1.0163, 0.3491, -1.5292], [-0.4899, 0.0822, 0.2773]], [[ 0.2640, 0.4571, -0.5511], [-1.0163, 0.3491, -1.5292], [-0.4899, 0.0822, 0.2773]]])