快捷方式

tensordict.nn.distributions.CompositeDistribution

class tensordict.nn.distributions.CompositeDistribution(params: TensorDictBase, distribution_map: dict, *, name_map: Optional[dict] = None, extra_kwargs=None, log_prob_key: Optional[NestedKey] = None, entropy_key: Optional[NestedKey] = None)

一個複合分佈,使用 TensorDict 介面將多個分佈組合在一起。

此類允許對一組分佈執行諸如 log_prob_compositeentropy_compositecdficdfrsamplesample 等操作,並返回一個 TensorDict。輸入的 TensorDict 可能會被就地修改。

引數:
  • params (TensorDictBase) – 一個巢狀的鍵-張量對映,其中根條目對應於樣本名稱,葉子是分佈引數。條目名稱必須與 distribution_map 中指定的名稱匹配。

  • distribution_map (Dict[NestedKey, Type[torch.distribution.Distribution]]) – 指定要使用的分佈型別。分佈的名稱應與 TensorDict 中的樣本名稱匹配。

關鍵字引數:
  • name_map (Dict[NestedKey, NestedKey], optional) – 一個對映,指定每個樣本應寫入的位置。如果未提供,將使用來自 distribution_map 的鍵名。

  • extra_kwargs (Dict[NestedKey, Dict], optional) – 用於構造分佈的額外關鍵字引數的字典。

  • log_prob_key (NestedKey, optional) –

    將儲存聚合對數機率的鍵。預設為 ‘sample_log_prob’

    注意

    如果 tensordict.nn.probabilistic.composite_lp_aggregate() 返回 False,則對數機率將寫入 (“path”, “to”, “leaf”, “<sample_name>_log_prob”),其中 (“path”, “to”, “leaf”, “<sample_name>”) 是與被取樣葉子張量對應的 NestedKey。在這種情況下,將忽略 log_prob_key 引數。

  • entropy_key (NestedKey, optional) –

    將儲存熵的鍵。預設為 ‘entropy’

    注意

    如果 tensordict.nn.probabilistic.composite_lp_aggregate() 返回 False,則熵將寫入 (“path”, “to”, “leaf”, “<sample_name>_entropy”),其中 (“path”, “to”, “leaf”, “<sample_name>”) 是與被取樣葉子張量對應的 NestedKey。在這種情況下,將忽略 entropy_key 引數。

注意

包含引數(params)的輸入 TensorDict 的批處理大小決定了分佈的批處理形狀。例如,呼叫 log_prob 後生成的 “sample_log_prob” 條目的形狀將是引數的形狀加上任何額外的批處理維度。

另請參閱

ProbabilisticTensorDictModuleProbabilisticTensorDictSequential 來了解如何將此類作為模型的一部分使用。

另請參閱

set_composite_lp_aggregate 用於控制對數機率的聚合。

示例

>>> params = TensorDict({
...     "cont": {"loc": torch.randn(3, 4), "scale": torch.rand(3, 4)},
...     ("nested", "disc"): {"logits": torch.randn(3, 10)}
... }, [3])
>>> dist = CompositeDistribution(params,
...     distribution_map={"cont": d.Normal, ("nested", "disc"): d.Categorical})
>>> sample = dist.sample((4,))
>>> with set_composite_lp_aggregate(False):
...     sample = dist.log_prob(sample)
...     print(sample)
TensorDict(
    fields={
        cont: Tensor(shape=torch.Size([4, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        cont_log_prob: Tensor(shape=torch.Size([4, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        nested: TensorDict(
            fields={
                disc: Tensor(shape=torch.Size([4, 3]), device=cpu, dtype=torch.int64, is_shared=False),
                disc_log_prob: Tensor(shape=torch.Size([4, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([4]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([4]),
    device=None,
    is_shared=False)

文件

訪問全面的 PyTorch 開發者文件

檢視文件

教程

為初學者和高階開發者提供深入的教程

檢視教程

資源

查詢開發資源並讓您的問題得到解答

檢視資源