評價此頁

torch.argwhere#

torch.argwhere(input) Tensor#

返回一個包含 input 中所有非零元素的索引的張量。結果中的每一行包含 input 中一個非零元素的索引。結果按字典序排序,最後一個索引變化最快(C 風格)。

如果 inputnn 個維度,那麼得到的索引張量 out 的大小為 (z×n)(z \times n),其中 zzinput 張量中非零元素的總數。

注意

此函式類似於 NumPy 的 argwhere

input 位於 CUDA 上時,此函式會導致主機-裝置同步。

引數

{input}

示例

>>> t = torch.tensor([1, 0, 1])
>>> torch.argwhere(t)
tensor([[0],
        [2]])
>>> t = torch.tensor([[1, 0, 1], [0, 1, 1]])
>>> torch.argwhere(t)
tensor([[0, 0],
        [0, 2],
        [1, 1],
        [1, 2]])