評價此頁

torch.take_along_dim#

torch.take_along_dim(input, indices, dim=None, *, out=None) Tensor#

input 中,在給定的 dim 維度上,按照 indices 中的一維索引選擇值。

如果 dim 為 None,則輸入陣列將被視為已展平成一維。

返回沿某個維度索引的函式,例如 torch.argmax()torch.argsort(),可以與此函式配合使用。請參閱下面的示例。

注意

此函式類似於 NumPy 的 take_along_axis。另請參閱 torch.gather()

引數
  • input (Tensor) – 輸入張量。

  • indices (LongTensor) – input 的索引。必須是 long 型別。

  • dim (int, optional) – 選擇的維度。預設值:0

關鍵字引數

out (Tensor, optional) – 輸出張量。

示例

>>> t = torch.tensor([[10, 30, 20], [60, 40, 50]])
>>> max_idx = torch.argmax(t)
>>> torch.take_along_dim(t, max_idx)
tensor([60])
>>> sorted_idx = torch.argsort(t, dim=1)
>>> torch.take_along_dim(t, sorted_idx, dim=1)
tensor([[10, 20, 30],
        [40, 50, 60]])