torch.argwhere#
- torch.argwhere(input) Tensor#
返回一個包含
input中所有非零元素的索引的張量。結果中的每一行包含input中一個非零元素的索引。結果按字典序排序,最後一個索引變化最快(C 風格)。如果
input有 個維度,那麼得到的索引張量out的大小為 ,其中 是input張量中非零元素的總數。注意
此函式類似於 NumPy 的 argwhere。
當
input位於 CUDA 上時,此函式會導致主機-裝置同步。- 引數
{input} –
示例
>>> t = torch.tensor([1, 0, 1]) >>> torch.argwhere(t) tensor([[0], [2]]) >>> t = torch.tensor([[1, 0, 1], [0, 1, 1]]) >>> torch.argwhere(t) tensor([[0, 0], [0, 2], [1, 1], [1, 2]])