torch.linalg.solve#
- torch.linalg.solve(A, B, *, left=True, out=None) Tensor#
計算具有唯一解的方線性方程組的解。
設 為 或 ,該函式計算 的解,該解是與 associated的**線性方程組**,其定義如下:
如果
left= False,則該函式返回矩陣 ,用於求解以下方程組:當且僅當 是 可逆矩陣 時,該線性方程組才有一個解。本函式假設 是可逆的。
支援float、double、cfloat和cdouble資料型別的輸入。也支援矩陣的批次,如果輸入是矩陣的批次,則輸出具有相同的批次維度。
設 * 為零個或多個批處理維度,
如果
A的形狀為 (*, n, n),而B的形狀為 (*, n)(向量批)或形狀 (*, n, k)(矩陣批或“多個右側值”),則此函式返回形狀分別為 (*, n) 或 (*, n, k) 的 X。否則,如果
A的形狀為 (*, n, n),而B的形狀為 (n,) 或 (n, k),則B將被廣播(broadcast)以具有 (*, n) 或 (*, n, k) 的形狀。然後,此函式返回由由此產生的線性方程組批次組成的解。
注意
此函式以比單獨執行計算更快、更數值穩定的方式計算 X =
A.inverse() @B。注意
可以透過傳遞
A和B的轉置,並轉置此函式返回的輸出來計算 線性方程組的解。注意
A可以是未批處理的 torch.sparse_csr_tensor,但僅限於 left=True。注意
當輸入位於 CUDA 裝置上時,此函式會使該裝置與 CPU 同步。有關不進行同步的該函式版本,請參閱
torch.linalg.solve_ex()。另請參閱
torch.linalg.solve_triangular()計算具有唯一解的三角線性方程組的解。- 引數
- 關鍵字引數
- 引發
RuntimeError – 如果
A矩陣不可逆,或批處理A中的任何矩陣不可逆。
示例
>>> A = torch.randn(3, 3) >>> b = torch.randn(3) >>> x = torch.linalg.solve(A, b) >>> torch.allclose(A @ x, b) True >>> A = torch.randn(2, 3, 3) >>> B = torch.randn(2, 3, 4) >>> X = torch.linalg.solve(A, B) >>> X.shape torch.Size([2, 3, 4]) >>> torch.allclose(A @ X, B) True >>> A = torch.randn(2, 3, 3) >>> b = torch.randn(3, 1) >>> x = torch.linalg.solve(A, b) # b is broadcasted to size (2, 3, 1) >>> x.shape torch.Size([2, 3, 1]) >>> torch.allclose(A @ x, b) True >>> b = torch.randn(3) >>> x = torch.linalg.solve(A, b) # b is broadcasted to size (2, 3) >>> x.shape torch.Size([2, 3]) >>> Ax = A @ x.unsqueeze(-1) >>> torch.allclose(Ax, b.unsqueeze(-1).expand_as(Ax)) True