torch.bmm#
- torch.bmm(input, mat2, out_dtype=None, *, out=None) Tensor#
對儲存在
input和mat2中的矩陣進行批處理矩陣-矩陣乘積。input和mat2必須是 3 維張量,每個張量包含相同數量的矩陣。如果
input是一個 張量,mat2是一個 張量,則out將是一個 張量。此運算子支援TensorFloat32。
在某些 ROCm 裝置上,當使用 float16 輸入時,此模組將對反向傳播使用不同精度。
注意
此函式不執行 廣播。有關廣播矩陣乘積,請參閱
torch.matmul()。- 引數
- 關鍵字引數
out (Tensor, optional) – 輸出張量。
示例
>>> input = torch.randn(10, 3, 4) >>> mat2 = torch.randn(10, 4, 5) >>> res = torch.bmm(input, mat2) >>> res.size() torch.Size([10, 3, 5])