評價此頁

torch.bmm#

torch.bmm(input, mat2, out_dtype=None, *, out=None) Tensor#

對儲存在 inputmat2 中的矩陣進行批處理矩陣-矩陣乘積。

inputmat2 必須是 3 維張量,每個張量包含相同數量的矩陣。

如果 input 是一個 (b×n×m)(b \times n \times m) 張量,mat2 是一個 (b×m×p)(b \times m \times p) 張量,則 out 將是一個 (b×n×p)(b \times n \times p) 張量。

outi=inputi@mat2i\text{out}_i = \text{input}_i \mathbin{@} \text{mat2}_i

此運算子支援TensorFloat32

在某些 ROCm 裝置上,當使用 float16 輸入時,此模組將對反向傳播使用不同精度

注意

此函式不執行 廣播。有關廣播矩陣乘積,請參閱 torch.matmul()

引數
  • input (Tensor) – 要相乘的第一個矩陣批次

  • mat2 (Tensor) – 要相乘的第二個矩陣批次

  • out_dtype (dtype, optional) – 輸出張量的資料型別。僅在 CUDA 上支援,並且當輸入資料型別為 torch.float16/torch.bfloat16 時,支援 torch.float32。

關鍵字引數

out (Tensor, optional) – 輸出張量。

示例

>>> input = torch.randn(10, 3, 4)
>>> mat2 = torch.randn(10, 4, 5)
>>> res = torch.bmm(input, mat2)
>>> res.size()
torch.Size([10, 3, 5])