torch.nn.functional.gumbel_softmax#
- torch.nn.functional.gumbel_softmax(logits, tau=1, hard=False, eps=1e-10, dim=-1)[原始碼]#
從 Gumbel-Softmax 分佈 (連結 1 連結 2) 中取樣,並可選擇性地離散化。
- 引數
- 返回
從 Gumbel-Softmax 分佈中取樣的張量,其形狀與 logits 相同。如果
hard=True,則返回的樣本將是 one-hot 的,否則它們將是機率分佈,在 dim 維度上求和為 1。- 返回型別
注意
此函數出於歷史原因而保留,將來可能會從 nn.Functional 中移除。
注意
對於 hard 的主要技巧是執行 y_hard - y_soft.detach() + y_soft
這實現了兩個目的:- 使輸出值精確為 one-hot(因為我們新增然後減去 y_soft 值)- 使梯度等於 y_soft 的梯度(因為我們去除了所有其他梯度)
- 示例:
>>> logits = torch.randn(20, 32) >>> # Sample soft categorical using reparametrization trick: >>> F.gumbel_softmax(logits, tau=1, hard=False) >>> # Sample hard categorical using "Straight-through" trick: >>> F.gumbel_softmax(logits, tau=1, hard=True)