torch.nn.functional.one_hot#
- torch.nn.functional.one_hot(tensor, num_classes=-1) LongTensor#
接收形狀為
(*)的索引值的 LongTensor,並返回形狀為(*, num_classes)的張量。該張量在最後一個維度上,除了索引值對應的位置為 1 外,其餘位置均為 0。另請參閱 Wikipedia 上的獨熱編碼。
- 引數
tensor (LongTensor) – 任意形狀的類別值。
num_classes (int, optional) – 總類別數。如果設定為 -1,則類別數將根據輸入張量中最大的類別值推斷(最大類別值 + 1)。預設為 -1
- 返回
返回一個 LongTensor,它比輸入張量多一個維度,在該維度上,輸入張量指示的索引位置為 1,其餘位置為 0。
示例
>>> F.one_hot(torch.arange(0, 5) % 3) tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0], [0, 1, 0]]) >>> F.one_hot(torch.arange(0, 5) % 3, num_classes=5) tensor([[1, 0, 0, 0, 0], [0, 1, 0, 0, 0], [0, 0, 1, 0, 0], [1, 0, 0, 0, 0], [0, 1, 0, 0, 0]]) >>> F.one_hot(torch.arange(0, 6).view(3,2) % 3) tensor([[[1, 0, 0], [0, 1, 0]], [[0, 0, 1], [1, 0, 0]], [[0, 1, 0], [0, 0, 1]]])