torch.argmax#
- torch.argmax(input) LongTensor#
返回
input張量中所有元素最大值的索引。這是
torch.max()返回的第二個值。有關此方法的精確語義,請參閱其文件。注意
如果有多個最大值,則返回第一個最大值的索引。
- 引數
input (Tensor) – 輸入張量。
示例
>>> a = torch.randn(4, 4) >>> a tensor([[ 1.3398, 0.2663, -0.2686, 0.2450], [-0.7401, -0.8805, -0.3402, -1.1936], [ 0.4907, -1.3948, -1.0691, -0.3132], [-1.6092, 0.5419, -0.2993, 0.3195]]) >>> torch.argmax(a) tensor(0)
- torch.argmax(input, dim, keepdim=False) LongTensor
返回張量在一維上的最大值索引。
這是
torch.max()返回的第二個值。有關此方法的精確語義,請參閱其文件。- 引數
input (Tensor) – 輸入張量。
dim – 要約的維度。
示例
>>> a = torch.randn(4, 4) >>> a tensor([[ 1.3398, 0.2663, -0.2686, 0.2450], [-0.7401, -0.8805, -0.3402, -1.1936], [ 0.4907, -1.3948, -1.0691, -0.3132], [-1.6092, 0.5419, -0.2993, 0.3195]]) >>> torch.argmax(a, dim=1) tensor([ 0, 2, 0, 1])