torch.Tensor.scatter_add_#
- Tensor.scatter_add_(dim, index, src) Tensor#
類似於
scatter_(),將src張量中的所有值加到index張量指定的self的索引處。對於src中的每個值,它被加到self的一個索引中,該索引透過其在src中的索引(對於dimension != dim)和index中的相應值(對於dimension = dim)來指定。對於一個 3-D 張量,
self的更新方式為self[index[i][j][k]][j][k] += src[i][j][k] # if dim == 0 self[i][index[i][j][k]][k] += src[i][j][k] # if dim == 1 self[i][j][index[i][j][k]] += src[i][j][k] # if dim == 2
self,index和src必須具有相同的維度數。此外,對於所有維度d,要求index.size(d) <= src.size(d),並且對於所有維度d != dim,要求index.size(d) <= self.size(d)。請注意,index和src不會廣播。當index為空時,我們始終返回原始張量,而不進行進一步的錯誤檢查。注意
當在 CUDA 裝置上使用張量時,此操作可能行為不確定。有關更多資訊,請參閱 隨機性。
注意
反向傳播僅對
src.shape == index.shape進行了實現。- 引數
示例
>>> src = torch.ones((2, 5)) >>> index = torch.tensor([[0, 1, 2, 0, 0]]) >>> torch.zeros(3, 5, dtype=src.dtype).scatter_add_(0, index, src) tensor([[1., 0., 0., 1., 1.], [0., 1., 0., 0., 0.], [0., 0., 1., 0., 0.]]) >>> index = torch.tensor([[0, 1, 2, 0, 0], [0, 1, 2, 2, 2]]) >>> torch.zeros(3, 5, dtype=src.dtype).scatter_add_(0, index, src) tensor([[2., 0., 0., 1., 1.], [0., 2., 0., 0., 0.], [0., 0., 2., 1., 1.]])