評價此頁

torch.take#

torch.take(input, index) Tensor#

返回一個新張量,其中包含 input 張量在給定索引處的元素。輸入張量被視為 1-D 張量。結果的形狀與索引的形狀相同。

引數
  • input (Tensor) – 輸入張量。

  • index (LongTensor) – 張量的索引

示例

>>> src = torch.tensor([[4, 3, 5],
...                     [6, 7, 8]])
>>> torch.take(src, torch.tensor([0, 2, 5]))
tensor([ 4,  5,  8])