評價此頁

torch.Tensor.scatter_add_#

Tensor.scatter_add_(dim, index, src) Tensor#

類似於 scatter_(),將 src 張量中的所有值加到 index 張量指定的 self 的索引處。對於 src 中的每個值,它被加到 self 的一個索引中,該索引透過其在 src 中的索引(對於 dimension != dim)和 index 中的相應值(對於 dimension = dim)來指定。

對於一個 3-D 張量,self 的更新方式為

self[index[i][j][k]][j][k] += src[i][j][k]  # if dim == 0
self[i][index[i][j][k]][k] += src[i][j][k]  # if dim == 1
self[i][j][index[i][j][k]] += src[i][j][k]  # if dim == 2

self, indexsrc 必須具有相同的維度數。此外,對於所有維度 d,要求 index.size(d) <= src.size(d),並且對於所有維度 d != dim,要求 index.size(d) <= self.size(d)。請注意,indexsrc 不會廣播。當 index 為空時,我們始終返回原始張量,而不進行進一步的錯誤檢查。

注意

當在 CUDA 裝置上使用張量時,此操作可能行為不確定。有關更多資訊,請參閱 隨機性

注意

反向傳播僅對 src.shape == index.shape 進行了實現。

引數
  • dim (int) – 索引的軸

  • index (LongTensor) – 要進行散播和相加的元素的索引,可以是空的,也可以與 src 具有相同的維度。當為空時,操作將保持 self 不變。

  • src (Tensor) – 要進行散播和相加的源元素

示例

>>> src = torch.ones((2, 5))
>>> index = torch.tensor([[0, 1, 2, 0, 0]])
>>> torch.zeros(3, 5, dtype=src.dtype).scatter_add_(0, index, src)
tensor([[1., 0., 0., 1., 1.],
        [0., 1., 0., 0., 0.],
        [0., 0., 1., 0., 0.]])
>>> index = torch.tensor([[0, 1, 2, 0, 0], [0, 1, 2, 2, 2]])
>>> torch.zeros(3, 5, dtype=src.dtype).scatter_add_(0, index, src)
tensor([[2., 0., 0., 1., 1.],
        [0., 2., 0., 0., 0.],
        [0., 0., 2., 1., 1.]])