快捷方式

MaskedCategorical

class torchrl.modules.MaskedCategorical(logits: torch.Tensor | None = None, probs: torch.Tensor | None = None, *, mask: torch.Tensor | None = None, indices: torch.Tensor | None = None, neg_inf: float = - inf, padding_value: int | None = None, use_cross_entropy: bool = True, padding_side: str = 'left')[source]

MaskedCategorical 分佈。

參考: https://www.tensorflow.org/agents/api_docs/python/tf_agents/distributions/masked/MaskedCategorical

引數:
  • logits (torch.Tensor) – 事件的對數機率(未歸一化)

  • probs (torch.Tensor) – 事件機率。如果提供了機率,則被掩碼的項的機率將被歸零,並且機率將在其最後一個維度上重新歸一化。

關鍵字引數:
  • mask (torch.Tensor) – 一個布林掩碼,形狀與 logits/probs 相同,其中 False 條目是被掩碼的項。或者,如果 sparse_mask 為 True,它代表分佈中有效索引的列表。與 indices 互斥。

  • indices (torch.Tensor) – 一個密集索引張量,表示必須考慮哪些動作。與 mask 互斥。

  • neg_inf (float, optional) – 分配給無效(超出掩碼)索引的對數機率值。預設為 -inf。

  • padding_value – 掩碼張量中的填充值。當 sparse_mask == True 時,將忽略 padding_value。

  • use_cross_entropy (bool, optional) – 為了更快地計算對數機率,可以使用 cross_entropy 損失函式。預設為 True

  • padding_side (str, optional) – 填充的側邊。預設為 "left"

示例

>>> torch.manual_seed(0)
>>> logits = torch.randn(4) / 100  # almost equal probabilities
>>> mask = torch.tensor([True, False, True, True])
>>> dist = MaskedCategorical(logits=logits, mask=mask)
>>> sample = dist.sample((10,))
>>> print(sample)  # no `1` in the sample
tensor([2, 3, 0, 2, 2, 0, 2, 0, 2, 2])
>>> print(dist.log_prob(sample))
tensor([-1.1203, -1.0928, -1.0831, -1.1203, -1.1203, -1.0831, -1.1203, -1.0831,
        -1.1203, -1.1203])
>>> print(dist.log_prob(torch.ones_like(sample)))
tensor([-inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf])
>>> # with probabilities
>>> prob = torch.ones(10)
>>> prob = prob / prob.sum()
>>> mask = torch.tensor([False] + 9 * [True])  # first outcome is masked
>>> dist = MaskedCategorical(probs=prob, mask=mask)
>>> print(dist.log_prob(torch.arange(10)))
tensor([   -inf, -2.1972, -2.1972, -2.1972, -2.1972, -2.1972, -2.1972, -2.1972,
        -2.1972, -2.1972])
entropy()[source]

計算分佈的熵。

對於帶掩碼的分佈,我們只考慮有效(未掩碼)結果的熵。無效結果的機率為零,不計入熵。

log_prob(value: Tensor) Tensor[source]

返回在 value 處評估的機率密度/質量函式的對數。

引數:

value (Tensor) –

property padding_value

分佈掩碼的填充值。

如果未設定填充值,將從 logits 中推斷。

sample(sample_shape: torch.Size | Sequence[int] | None = None) Tensor[source]

生成 sample_shape 形狀的樣本,如果分佈引數是批處理的,則生成 sample_shape 形狀的樣本批次。

文件

訪問全面的 PyTorch 開發者文件

檢視文件

教程

為初學者和高階開發者提供深入的教程

檢視教程

資源

查詢開發資源並讓您的問題得到解答

檢視資源