評價此頁

torch.Tensor.masked_scatter_#

Tensor.masked_scatter_(mask, source)#

source 中的元素複製到 self 張量中 mask 為 True 的位置。從 source 的位置 0 開始,按順序將 source 中的元素逐個複製到 self 中,每當 mask 為 True 時進行一次複製。 mask 的形狀必須 可以廣播 到底層張量的形狀。 source 的元素數量應至少等於 mask 中 True 的數量。

引數
  • mask (BoolTensor) – 布林掩碼

  • source (Tensor) – 要從中複製的張量

注意

mask 操作的是 self 張量,而不是給定的 source 張量。

示例

>>> self = torch.tensor([[0, 0, 0, 0, 0], [0, 0, 0, 0, 0]])
>>> mask = torch.tensor(
...     [[0, 0, 0, 1, 1], [1, 1, 0, 1, 1]],
...     dtype=torch.bool,
... )
>>> source = torch.tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]])
>>> self.masked_scatter_(mask, source)
tensor([[0, 0, 0, 0, 1],
        [2, 3, 0, 4, 5]])