torch.Tensor.masked_scatter_#
- Tensor.masked_scatter_(mask, source)#
將
source中的元素複製到self張量中mask為 True 的位置。從source的位置 0 開始,按順序將source中的元素逐個複製到self中,每當mask為 True 時進行一次複製。mask的形狀必須 可以廣播 到底層張量的形狀。source的元素數量應至少等於mask中 True 的數量。- 引數
mask (BoolTensor) – 布林掩碼
source (Tensor) – 要從中複製的張量
注意
mask操作的是self張量,而不是給定的source張量。示例
>>> self = torch.tensor([[0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]) >>> mask = torch.tensor( ... [[0, 0, 0, 1, 1], [1, 1, 0, 1, 1]], ... dtype=torch.bool, ... ) >>> source = torch.tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]) >>> self.masked_scatter_(mask, source) tensor([[0, 0, 0, 0, 1], [2, 3, 0, 4, 5]])