評價此頁

torch.matmul#

torch.matmul(input, other, *, out=None) Tensor#

兩個張量的矩陣乘積。

行為取決於張量的維度,如下所示:

  • 如果兩個張量都是一維的,則返回點積(標量)。

  • 如果兩個引數都是二維的,則返回矩陣-矩陣乘積。

  • 如果第一個引數是一維的,第二個引數是二維的,則在進行矩陣乘法時,其維度前面會加上一個 1。矩陣乘法完成後,會移除新增的維度。

  • 如果第一個引數是二維的,第二個引數是一維的,則返回矩陣-向量乘積。

  • 如果兩個引數至少為一維,並且至少有一個引數是 N 維(N > 2),則返回批處理矩陣乘法。如果第一個引數是一維的,則在進行批處理矩陣乘法時,其維度前面會加上一個 1,並在之後移除。如果第二個引數是一維的,則在進行批處理矩陣乘法時,其維度後面會加上一個 1,並在之後移除。

    每個引數的前 N-2 個維度(批處理維度)將進行廣播(因此必須是可廣播的)。最後的 2 個維度(矩陣維度)按照矩陣-矩陣乘積的方式處理。

    例如,如果 input 是一個 (j×1×n×m)(j \times 1 \times n \times m) 張量,而 other 是一個 (k×m×p)(k \times m \times p) 張量,則批處理維度為 (j×1)(j \times 1)(k)(k),矩陣維度為 (n×m)(n \times m)(m×p)(m \times p)out 將是一個 (j×k×n×p)(j \times k \times n \times p) 張量。

此操作支援具有稀疏佈局的引數。特別是,矩陣-矩陣(兩個引數都是二維的)支援稀疏引數,其限制與torch.mm()相同。

警告

稀疏支援是測試版功能,某些佈局/資料型別/裝置組合可能不支援,或可能不支援自動求導。如果您發現缺少功能,請提交功能請求。

此運算子支援TensorFloat32

在某些 ROCm 裝置上,當使用 float16 輸入時,此模組將對反向傳播使用不同精度

注意

此函式的 out 引數的一維點積版本不支援。

引數
  • input (Tensor) – 要相乘的第一個張量

  • other (Tensor) – 要相乘的第二個張量

關鍵字引數

out (Tensor, optional) – 輸出張量。

示例

>>> # vector x vector
>>> tensor1 = torch.randn(3)
>>> tensor2 = torch.randn(3)
>>> torch.matmul(tensor1, tensor2).size()
torch.Size([])
>>> # matrix x vector
>>> tensor1 = torch.randn(3, 4)
>>> tensor2 = torch.randn(4)
>>> torch.matmul(tensor1, tensor2).size()
torch.Size([3])
>>> # batched matrix x broadcasted vector
>>> tensor1 = torch.randn(10, 3, 4)
>>> tensor2 = torch.randn(4)
>>> torch.matmul(tensor1, tensor2).size()
torch.Size([10, 3])
>>> # batched matrix x batched matrix
>>> tensor1 = torch.randn(10, 3, 4)
>>> tensor2 = torch.randn(10, 4, 5)
>>> torch.matmul(tensor1, tensor2).size()
torch.Size([10, 3, 5])
>>> # batched matrix x broadcasted matrix
>>> tensor1 = torch.randn(10, 3, 4)
>>> tensor2 = torch.randn(4, 5)
>>> torch.matmul(tensor1, tensor2).size()
torch.Size([10, 3, 5])