torch.mm#
- torch.mm(input, mat2, out_dtype=None, *, out=None) Tensor#
執行矩陣
input和mat2的矩陣乘法。如果
input是一個 張量,mat2是一個 張量,out將是一個 張量。注意
此函式不支援 廣播。有關廣播矩陣乘積,請參見
torch.matmul()。支援具有 strided 和 sparse 佈局的 2 維張量作為輸入,並支援與 strided 輸入進行自動求導。
此操作支援具有 sparse 佈局 的引數。如果提供了
out,則將使用其佈局。否則,結果佈局將根據input的佈局來推斷。警告
稀疏支援是測試版功能,某些佈局/資料型別/裝置組合可能不支援,或可能不支援自動求導。如果您發現缺少功能,請提交功能請求。
此運算子支援TensorFloat32。
在某些 ROCm 裝置上,當使用 float16 輸入時,此模組將對反向傳播使用不同精度。
- 引數
- 關鍵字引數
out (Tensor, optional) – 輸出張量。
示例
>>> mat1 = torch.randn(2, 3) >>> mat2 = torch.randn(3, 3) >>> torch.mm(mat1, mat2) tensor([[ 0.4851, 0.5037, -0.3633], [-0.0760, -3.6705, 2.4784]])