評價此頁

torch.mm#

torch.mm(input, mat2, out_dtype=None, *, out=None) Tensor#

執行矩陣 inputmat2 的矩陣乘法。

如果 input 是一個 (n×m)(n \times m) 張量,mat2 是一個 (m×p)(m \times p) 張量,out 將是一個 (n×p)(n \times p) 張量。

注意

此函式不支援 廣播。有關廣播矩陣乘積,請參見 torch.matmul()

支援具有 stridedsparse 佈局的 2 維張量作為輸入,並支援與 strided 輸入進行自動求導。

此操作支援具有 sparse 佈局 的引數。如果提供了 out,則將使用其佈局。否則,結果佈局將根據 input 的佈局來推斷。

警告

稀疏支援是測試版功能,某些佈局/資料型別/裝置組合可能不支援,或可能不支援自動求導。如果您發現缺少功能,請提交功能請求。

此運算子支援TensorFloat32

在某些 ROCm 裝置上,當使用 float16 輸入時,此模組將對反向傳播使用不同精度

引數
  • input (Tensor) – 第一個要進行矩陣相乘的矩陣

  • mat2 (Tensor) – 第二個要進行矩陣相乘的矩陣

  • out_dtype (dtype, optional) – 輸出張量的資料型別。僅在 CUDA 上支援,並且當輸入資料型別為 torch.float16/torch.bfloat16 時,支援 torch.float32。

關鍵字引數

out (Tensor, optional) – 輸出張量。

示例

>>> mat1 = torch.randn(2, 3)
>>> mat2 = torch.randn(3, 3)
>>> torch.mm(mat1, mat2)
tensor([[ 0.4851,  0.5037, -0.3633],
        [-0.0760, -3.6705,  2.4784]])