torch.linalg.multi_dot#
- torch.linalg.multi_dot(tensors, *, out=None)#
高效地透過重新排序乘法來計算兩個或多個矩陣的乘積,以執行最少的算術運算。
支援 float, double, cfloat 和 cdouble 資料型別的輸入。此函式不支援批處理輸入。
在
tensors中的每個張量都必須是 2D 的,除了第一個和最後一個張量,它們可以是 1D 的。如果第一個張量是形狀為 (n,) 的 1D 向量,則將其視為形狀為 (1, n) 的行向量;類似地,如果最後一個張量是形狀為 (n,) 的 1D 向量,則將其視為形狀為 (n, 1) 的列向量。如果第一個和最後一個張量是矩陣,則輸出將是矩陣。但是,如果其中一個是 1D 向量,則輸出將是 1D 向量。
與 numpy.linalg.multi_dot 的區別
與 numpy.linalg.multi_dot 不同,第一個和最後一個張量必須是 1D 或 2D,而 NumPy 允許它們是 nD。
警告
此函式不執行廣播。
注意
此函式透過在計算最佳矩陣乘法順序後鏈式呼叫
torch.mm()來實現。注意
形狀為 (a, b) 和 (b, c) 的兩個矩陣相乘的成本為 a * b * c。給定形狀分別為 (10, 100)、(100, 5) 和 (5, 50) 的矩陣 A、B、C,我們可以計算不同乘法順序的成本如下:
在這種情況下,先計算 A 和 B 的乘積,然後再乘以 C 的速度是後者的 10 倍。
- 引數
tensors (Sequence[Tensor]) – 要相乘的兩個或多個張量。第一個和最後一個張量可以是 1D 或 2D。所有其他張量都必須是 2D。
- 關鍵字引數
out (Tensor, optional) – 輸出張量。如果為 None 則忽略。預設為 None。
示例
>>> from torch.linalg import multi_dot >>> multi_dot([torch.tensor([1, 2]), torch.tensor([2, 3])]) tensor(8) >>> multi_dot([torch.tensor([[1, 2]]), torch.tensor([2, 3])]) tensor([8]) >>> multi_dot([torch.tensor([[1, 2]]), torch.tensor([[2], [3]])]) tensor([[8]]) >>> A = torch.arange(2 * 3).view(2, 3) >>> B = torch.arange(3 * 2).view(3, 2) >>> C = torch.arange(2 * 2).view(2, 2) >>> multi_dot((A, B, C)) tensor([[ 26, 49], [ 80, 148]])