torch.baddbmm#
- torch.baddbmm(input, batch1, batch2, out_dtype=None, *, beta=1, alpha=1, out=None) Tensor#
執行
batch1和batch2中矩陣的批次矩陣乘法。input被加到最終結果中。batch1和batch2必須是 3 維張量,每個張量包含相同數量的矩陣。如果
batch1是一個 張量,batch2是一個 張量,那麼input必須與一個 張量 可廣播,並且out將是一個 張量。alpha和beta的含義與torch.addbmm()中使用的縮放因子相同。如果
beta為 0,則input的內容將被忽略,並且其中的 nan 和 inf 不會被傳播。對於 FloatTensor 或 DoubleTensor 型別的輸入,引數
beta和alpha必須是實數,否則它們應該是整數。此運算子支援TensorFloat32。
在某些 ROCm 裝置上,當使用 float16 輸入時,此模組將對反向傳播使用不同精度。
- 引數
- 關鍵字引數
beta (Number, optional) –
input的乘數()alpha (Number, optional) – 的乘數()
out (Tensor, optional) – 輸出張量。
示例
>>> M = torch.randn(10, 3, 5) >>> batch1 = torch.randn(10, 3, 4) >>> batch2 = torch.randn(10, 4, 5) >>> torch.baddbmm(M, batch1, batch2).size() torch.Size([10, 3, 5])