set_composite_lp_aggregate

class tensordict.nn.set_composite_lp_aggregate(mode: bool = True)

Controls whether the log-probabilities and entropies of a CompositeDistribution are aggregated into a single tensor.

When composite_lp_aggregate() returns True, the log-probabilities / entropies of a CompositeDistribution are summed into a single tensor whose shape matches the root tensordict. This behavior is deprecated in favor of non-aggregated log-probabilities, which offer greater flexibility and a somewhat more natural API (tensordict samples, tensordict log-probabilities, tensordict entropies).

The value of composite_lp_aggregate can also be controlled through the COMPOSITE_LP_AGGREGATE environment variable.
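For instance, the following is a minimal sketch of the environment-variable route; it assumes the variable is read when tensordict is imported and that boolean-like strings (e.g. "0"/"1" or "False"/"True") are accepted:

>>> import os
>>> # Assumption: set the variable before importing tensordict so the
>>> # default aggregation mode is picked up at import time.
>>> os.environ["COMPOSITE_LP_AGGREGATE"] = "0"
>>> from tensordict.nn import composite_lp_aggregate
>>> composite_lp_aggregate()
False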

Examples

>>> import torch
>>> from torch import distributions as d
>>> from tensordict import TensorDict
>>> from tensordict.nn import CompositeDistribution, set_composite_lp_aggregate
>>> _ = torch.manual_seed(0)
>>> params = TensorDict({
...     "cont": {"loc": torch.randn(3, 4), "scale": torch.rand(3, 4)},
...     ("nested", "disc"): {"logits": torch.randn(3, 10)}
... }, [3])
>>> dist = CompositeDistribution(params,
...     distribution_map={"cont": d.Normal, ("nested", "disc"): d.Categorical})
>>> sample = dist.sample((4,))
>>> with set_composite_lp_aggregate(False):
...     lp = dist.log_prob(sample)
...     print(lp)
TensorDict(
    fields={
        cont_log_prob: Tensor(shape=torch.Size([4, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        nested: TensorDict(
            fields={
                disc_log_prob: Tensor(shape=torch.Size([4, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([4, 3]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([4, 3]),
    device=None,
    is_shared=False)
>>> with set_composite_lp_aggregate(True):
...     lp = dist.log_prob(sample)
...     print(lp)
tensor([[-2.0886, -1.2155, -0.0414],
        [-2.8973, -5.5165,  2.4402],
        [-0.2806, -1.2799,  3.1733],
        [-3.0407, -4.3593,  0.5763]])
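
Like other tensordict context managers, set_composite_lp_aggregate can likely also be applied as a function decorator; the sketch below assumes it follows the usual decorator-capable context-manager pattern (the function name is illustrative, not part of the API):

>>> @set_composite_lp_aggregate(False)
... def non_aggregated_log_prob(dist, sample):
...     # within this scope, log_prob returns a TensorDict of per-leaf log-probs
...     return dist.log_prob(sample)
>>> lp = non_aggregated_log_prob(dist, sample)
>>> assert isinstance(lp, TensorDict)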
