torch.sparse.sampled_addmm#
- torch.sparse.sampled_addmm(input, mat1, mat2, *, beta=1., alpha=1., out=None) Tensor#
在
input的稀疏模式指定的 位置執行密集矩陣mat1和mat2的矩陣乘法。矩陣input被加到最終結果中。在數學上,這執行了以下操作:
其中 是
input的稀疏模式矩陣,alpha和beta是縮放因子。 在input為非零值的對應位置上值為 1,其他位置上為 0。注意
input必須是稀疏 CSR 張量。mat1和mat2必須是密集張量。- 引數
- 關鍵字引數
beta (Number, optional) –
input的乘數()alpha (Number, optional) – ()的乘數
out (Tensor, optional) – 輸出張量。如果為 None 則忽略。預設為 None。
示例
>>> input = torch.eye(3, device='cuda').to_sparse_csr() >>> mat1 = torch.randn(3, 5, device='cuda') >>> mat2 = torch.randn(5, 3, device='cuda') >>> torch.sparse.sampled_addmm(input, mat1, mat2) tensor(crow_indices=tensor([0, 1, 2, 3]), col_indices=tensor([0, 1, 2]), values=tensor([ 0.2847, -0.7805, -0.1900]), device='cuda:0', size=(3, 3), nnz=3, layout=torch.sparse_csr) >>> torch.sparse.sampled_addmm(input, mat1, mat2).to_dense() tensor([[ 0.2847, 0.0000, 0.0000], [ 0.0000, -0.7805, 0.0000], [ 0.0000, 0.0000, -0.1900]], device='cuda:0') >>> torch.sparse.sampled_addmm(input, mat1, mat2, beta=0.5, alpha=0.5) tensor(crow_indices=tensor([0, 1, 2, 3]), col_indices=tensor([0, 1, 2]), values=tensor([ 0.1423, -0.3903, -0.0950]), device='cuda:0', size=(3, 3), nnz=3, layout=torch.sparse_csr)