OneHotCategorical¶
- class torchrl.modules.OneHotCategorical(logits: torch.Tensor | None = None, probs: torch.Tensor | None = None, grad_method: ReparamGradientStrategy = ReparamGradientStrategy.PassThrough, **kwargs)[source]¶
獨熱(One-hot)分類分佈。
此類行為與 torch.distributions.Categorical 完全相同,但它讀取和生成離散張量的一次熱編碼。
- 引數:
logits (torch.Tensor) – 事件的對數機率(未歸一化)
probs (torch.Tensor) – 事件的機率
grad_method (ReparamGradientStrategy, optional) –
用於收集重引數化樣本的策略。
ReparamGradientStrategy.PassThrough將使用 softmax 值對數機率作為樣本梯度的代理來計算樣本梯度。使用 softmax 值作為樣本梯度的代理來計算樣本梯度。
ReparamGradientStrategy.RelaxedOneHot將使用torch.distributions.RelaxedOneHot從分佈中取樣。
示例
>>> torch.manual_seed(0) >>> logits = torch.randn(4) >>> dist = OneHotCategorical(logits=logits) >>> print(dist.rsample((3,))) tensor([[1., 0., 0., 0.], [0., 0., 0., 1.], [1., 0., 0., 0.]])
- log_prob(value: torch.Tensor) torch.Tensor[source]¶
返回在 value 處評估的機率密度/質量函式的對數。
- 引數:
value (Tensor) –
- rsample(sample_shape: torch.Size | Sequence = None) torch.Tensor[source]¶
生成 sample_shape 形狀的重引數化樣本,如果分佈引數是批處理的,則生成 sample_shape 形狀的重引數化樣本批次。
- sample(sample_shape: torch.Size | Sequence | None = None) torch.Tensor[source]¶
生成 sample_shape 形狀的樣本,如果分佈引數是批處理的,則生成 sample_shape 形狀的樣本批次。