評價此頁

torch.nn.functional.one_hot#

torch.nn.functional.one_hot(tensor, num_classes=-1) LongTensor#

接收形狀為 (*) 的索引值的 LongTensor,並返回形狀為 (*, num_classes) 的張量。該張量在最後一個維度上,除了索引值對應的位置為 1 外,其餘位置均為 0。

另請參閱 Wikipedia 上的獨熱編碼

引數
  • tensor (LongTensor) – 任意形狀的類別值。

  • num_classes (int, optional) – 總類別數。如果設定為 -1,則類別數將根據輸入張量中最大的類別值推斷(最大類別值 + 1)。預設為 -1

返回

返回一個 LongTensor,它比輸入張量多一個維度,在該維度上,輸入張量指示的索引位置為 1,其餘位置為 0。

示例

>>> F.one_hot(torch.arange(0, 5) % 3)
tensor([[1, 0, 0],
        [0, 1, 0],
        [0, 0, 1],
        [1, 0, 0],
        [0, 1, 0]])
>>> F.one_hot(torch.arange(0, 5) % 3, num_classes=5)
tensor([[1, 0, 0, 0, 0],
        [0, 1, 0, 0, 0],
        [0, 0, 1, 0, 0],
        [1, 0, 0, 0, 0],
        [0, 1, 0, 0, 0]])
>>> F.one_hot(torch.arange(0, 6).view(3,2) % 3)
tensor([[[1, 0, 0],
         [0, 1, 0]],
        [[0, 0, 1],
         [1, 0, 0]],
        [[0, 1, 0],
         [0, 0, 1]]])