評價此頁

torch.topk#

torch.topk(input, k, dim=None, largest=True, sorted=True, *, out=None)#

返回給定 input 張量在給定維度上最大的 k 個元素。

如果未指定 `dim`,則選擇 `input` 的最後一個維度。

如果 largestFalse,則返回最小的 k 個元素。

返回一個包含 valuesindices 的命名元組,其中 valuesinput 張量在給定維度 dim 的每行的最大的 k 個元素的值,indices 是這些元素在原始張量中的索引。

布林選項 sorted 如果為 True,將確保返回的 k 個元素本身是有序的。

注意

在使用 torch.topk 時,相等元素的索引不保證是穩定的,並且在不同的呼叫之間可能會有所不同。

引數
  • input (Tensor) – 輸入張量。

  • k (int) – “top-k”中的 k

  • dim (int, optional) – 要沿其排序的維度

  • largest (bool, optional) – 控制返回最大值還是最小值元素

  • sorted (bool, optional) – 控制是否按排序順序返回元素

關鍵字引數

out (tuple, optional) – 可選的 (Tensor, LongTensor) 輸出元組,用作輸出緩衝區

示例

>>> x = torch.arange(1., 6.)
>>> x
tensor([ 1.,  2.,  3.,  4.,  5.])
>>> torch.topk(x, 3)
torch.return_types.topk(values=tensor([5., 4., 3.]), indices=tensor([4, 3, 2]))