評價此頁

torch.kthvalue#

torch.kthvalue(input, k, dim=None, keepdim=False, *, out=None)#

返回一個命名元組 (values, indices),其中 valuesinput 張量在指定維度 dim 上每行的第 k 小的元素。而 indices 是找到的每個元素的索引位置。

如果未指定 `dim`,則選擇 `input` 的最後一個維度。

如果 keepdimTrue,則 valuesindices 張量的大小與 input 相同,只是在 dim 維度的大小為 1。否則,dim 會被壓縮(參見 torch.squeeze()),導致 valuesindices 張量比 input 張量少一個維度。

注意

input 是 CUDA 張量且存在多個有效的第 k 小值時,此函式可能會不確定地返回其中任何一個的 indices

引數
  • input (Tensor) – 輸入張量。

  • k (int) – 第 k 小元素中的 k 值

  • dim (int, optional) – 查詢第 k 小值所在的維度

  • keepdim (bool, optional) – 輸出張量是否保留 dim。預設為 False

關鍵字引數

out (tuple, optional) – 可選地提供用於作為輸出緩衝區的 (Tensor, LongTensor) 輸出元組

示例

>>> x = torch.arange(1., 6.)
>>> x
tensor([ 1.,  2.,  3.,  4.,  5.])
>>> torch.kthvalue(x, 4)
torch.return_types.kthvalue(values=tensor(4.), indices=tensor(3))

>>> x=torch.arange(1.,7.).resize_(2,3)
>>> x
tensor([[ 1.,  2.,  3.],
        [ 4.,  5.,  6.]])
>>> torch.kthvalue(x, 2, 0, True)
torch.return_types.kthvalue(values=tensor([[4., 5., 6.]]), indices=tensor([[1, 1, 1]]))