快捷方式

OneHotCategorical

class torchrl.modules.OneHotCategorical(logits: torch.Tensor | None = None, probs: torch.Tensor | None = None, grad_method: ReparamGradientStrategy = ReparamGradientStrategy.PassThrough, **kwargs)[source]

獨熱(One-hot)分類分佈。

此類行為與 torch.distributions.Categorical 完全相同,但它讀取和生成離散張量的一次熱編碼。

引數:
  • logits (torch.Tensor) – 事件的對數機率(未歸一化)

  • probs (torch.Tensor) – 事件的機率

  • grad_method (ReparamGradientStrategy, optional) –

    用於收集重引數化樣本的策略。ReparamGradientStrategy.PassThrough 將使用 softmax 值對數機率作為樣本梯度的代理來計算樣本梯度。

    使用 softmax 值作為樣本梯度的代理來計算樣本梯度。

    ReparamGradientStrategy.RelaxedOneHot 將使用 torch.distributions.RelaxedOneHot 從分佈中取樣。

示例

>>> torch.manual_seed(0)
>>> logits = torch.randn(4)
>>> dist = OneHotCategorical(logits=logits)
>>> print(dist.rsample((3,)))
tensor([[1., 0., 0., 0.],
        [0., 0., 0., 1.],
        [1., 0., 0., 0.]])
entropy()[source]

返回分佈的熵,按 batch_shape 批次計算。

返回:

形狀為 batch_shape 的張量。

log_prob(value: torch.Tensor) torch.Tensor[source]

返回在 value 處評估的機率密度/質量函式的對數。

引數:

value (Tensor) –

property mode: Tensor

返回分佈的眾數。

rsample(sample_shape: torch.Size | Sequence = None) torch.Tensor[source]

生成 sample_shape 形狀的重引數化樣本,如果分佈引數是批處理的,則生成 sample_shape 形狀的重引數化樣本批次。

sample(sample_shape: torch.Size | Sequence | None = None) torch.Tensor[source]

生成 sample_shape 形狀的樣本,如果分佈引數是批處理的,則生成 sample_shape 形狀的樣本批次。

文件

訪問全面的 PyTorch 開發者文件

檢視文件

教程

為初學者和高階開發者提供深入的教程

檢視教程

資源

查詢開發資源並讓您的問題得到解答

檢視資源