快捷方式

ProbabilisticTensorDictSequential

class tensordict.nn.ProbabilisticTensorDictSequential(*args, **kwargs)

一系列包含至少一個 ProbabilisticTensorDictModuleTensorDictModules

此類擴充套件了 TensorDictSequential,通常用一系列模組配置,其中最後一個模組是 ProbabilisticTensorDictModule 的例項。但是,它也支援將一個或多箇中間模組配置為 ProbabilisticTensorDictModule 的例項,而最後一個模組可能是也可能不是機率性的。在所有情況下,它都會公開 get_dist() 方法,以從序列中的 ProbabilisticTensorDictModule 例項恢復分佈物件。

多個機率性模組可以共存於單個 ProbabilisticTensorDictSequential 中。如果 return_compositeFalse(預設值),則只有最後一個模組會產生分佈,而其他模組將作為常規 TensorDictModule 例項執行。但是,如果 ProbabilisticTensorDictModule 不是序列中的最後一個模組且 return_composite=False,則在嘗試查詢該模組時會引發 ValueError。如果 return_composite=True,則所有中間的 ProbabilisticTensorDictModule 例項都將貢獻給一個 CompositeDistribution 例項。

結果的對數機率將是條件機率,如果樣本是相互依賴的:每當

\[Z = F(X, Y)\]

那麼 Z 的對數機率將是

\[log(p(z | x, y))\]
引數:

*modules (sequenceOrderedDict of TensorDictModuleBaseProbabilisticTensorDictModule) – 一個有序的 TensorDictModule 例項序列,通常以 ProbabilisticTensorDictModule 結尾,將按順序執行。模組可以是 TensorDictModuleBase 的例項或任何符合此簽名的其他函式。請注意,如果使用非 TensorDictModuleBase 可呼叫物件,則其輸入和輸出鍵不會被跟蹤,因此不會影響 TensorDictSequential 的 in_keysout_keys 屬性。

關鍵字引數:
  • partial_tolerant (bool, optional) – 如果為 True,則輸入 tensordict 可以缺少一些輸入鍵。如果是這樣,將僅執行在存在鍵的情況下可以執行的模組。此外,如果輸入 tensordict 是 tensordicts 的惰性堆疊,並且 partial_tolerant 為 True,並且堆疊不包含所需的鍵,則 TensorDictSequential 將掃描子 tensordicts 以查詢具有所需鍵的那些(如果有)。預設為 False

  • return_composite (bool, optional) – 如果為 True 並且找到多個 ProbabilisticTensorDictModuleProbabilisticTensorDictSequential 例項,則將使用 CompositeDistribution 例項。否則,只有最後一個模組將用於構建分佈。如果 return_compositeFalse 且上述任一條件均不滿足,則會引發錯誤。預設為 True,只要存在一個以上的機率模組或最後一個模組不是機率性的。

  • inplace (bool, optional) – 如果為 True,則輸入 tensordict 將被就地修改。如果為 False,則會建立一個新的空的 TensorDict 例項。如果為 “empty”,則改用 input.empty()(即,輸出保留型別、裝置和批次大小)。預設為 None(依賴於子模組)。

丟擲:

示例

>>> from tensordict.nn import ProbabilisticTensorDictModule as Prob, ProbabilisticTensorDictSequential as Seq
>>> import torch
>>> # Typical usage: a single distribution is computed last in the sequence
>>> import torch
>>> from tensordict import TensorDict
>>> from tensordict.nn import ProbabilisticTensorDictModule as Prob, ProbabilisticTensorDictSequential as Seq,         ...     TensorDictModule as Mod
>>> torch.manual_seed(0)
>>>
>>> module = Seq(
...     Mod(lambda x: x + 1, in_keys=["x"], out_keys=["loc"]),
...     Prob(in_keys=["loc"], out_keys=["sample"], distribution_class=torch.distributions.Normal,
...          distribution_kwargs={"scale": 1}),
... )
>>> input = TensorDict(x=torch.ones(3))
>>> td = module(input.copy())
>>> print(td)
TensorDict(
    fields={
        loc: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        sample: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        x: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> print(module.get_dist(input))
Normal(loc: torch.Size([3]), scale: torch.Size([3]))
>>> print(module.log_prob(td))
tensor([-0.9189, -0.9189, -0.9189])
>>> # Intermediate distributions are ignored when return_composite=False
>>> module = Seq(
...     Mod(lambda x: x + 1, in_keys=["x"], out_keys=["loc"]),
...     Prob(in_keys=["loc"], out_keys=["sample0"], distribution_class=torch.distributions.Normal,
...          distribution_kwargs={"scale": 1}),
...     Mod(lambda x: x + 1, in_keys=["sample0"], out_keys=["loc2"]),
...     Prob(in_keys={"loc": "loc2"}, out_keys=["sample1"], distribution_class=torch.distributions.Normal,
...          distribution_kwargs={"scale": 1}),
...     return_composite=False,
... )
>>> td = module(TensorDict(x=torch.ones(3)))
>>> print(td)
TensorDict(
    fields={
        loc2: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        loc: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        sample0: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        sample1: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        x: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> print(module.get_dist(input))
Normal(loc: torch.Size([3]), scale: torch.Size([3]))
>>> print(module.log_prob(td))
tensor([-0.9189, -0.9189, -0.9189])
>>> # Intermediate distributions produce a CompositeDistribution when return_composite=True
>>> module = Seq(
...     Mod(lambda x: x + 1, in_keys=["x"], out_keys=["loc"]),
...     Prob(in_keys=["loc"], out_keys=["sample0"], distribution_class=torch.distributions.Normal,
...          distribution_kwargs={"scale": 1}),
...     Mod(lambda x: x + 1, in_keys=["sample0"], out_keys=["loc2"]),
...     Prob(in_keys={"loc": "loc2"}, out_keys=["sample1"], distribution_class=torch.distributions.Normal,
...          distribution_kwargs={"scale": 1}),
...     return_composite=True,
... )
>>> input = TensorDict(x=torch.ones(3))
>>> td = module(input.copy())
>>> print(td)
TensorDict(
    fields={
        loc2: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        loc: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        sample0: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        sample1: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        x: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> print(module.get_dist(input))
CompositeDistribution({'sample0': Normal(loc: torch.Size([3]), scale: torch.Size([3])), 'sample1': Normal(loc: torch.Size([3]), scale: torch.Size([3]))})
>>> print(module.log_prob(td))
TensorDict(
    fields={
        sample0_log_prob: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        sample1_log_prob: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> # Even a single intermediate distribution is wrapped in a CompositeDistribution when
>>> # return_composite=True
>>> module = Seq(
...     Mod(lambda x: x + 1, in_keys=["x"], out_keys=["loc"]),
...     Prob(in_keys=["loc"], out_keys=["sample0"], distribution_class=torch.distributions.Normal,
...          distribution_kwargs={"scale": 1}),
...     Mod(lambda x: x + 1, in_keys=["sample0"], out_keys=["y"]),
...     return_composite=True,
... )
>>> td = module(TensorDict(x=torch.ones(3)))
>>> print(td)
TensorDict(
    fields={
        loc: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        sample0: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        x: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        y: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> print(module.get_dist(input))
CompositeDistribution({'sample0': Normal(loc: torch.Size([3]), scale: torch.Size([3]))})
>>> print(module.log_prob(td))
TensorDict(
    fields={
        sample0_log_prob: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
build_dist_from_params(tensordict: TensorDictBase) Distribution

在不評估序列中其他模組的情況下,從輸入引數構造一個分佈。

此方法搜尋序列中的最後一個 ProbabilisticTensorDictModule,並使用它來構建分佈。

引數:

tensordict (TensorDictBase) – 包含分佈引數的輸入 tensordict。

返回:

構造的分佈物件。

返回型別:

D.Distribution

丟擲:

RuntimeError – 如果序列中未找到 ProbabilisticTensorDictModule

property default_interaction_type

使用迭代啟發式方法返回模組的 default_interaction_type

此屬性按反向順序迭代所有模組,嘗試從任何子模組檢索 default_interaction_type 屬性。返回遇到的第一個非 None 值。如果找不到此類值,則返回預設 interaction_type()

property dist_params_keys: List[NestedKey]

返回指向分佈引數的所有鍵。

property dist_sample_keys: List[NestedKey]

返回指向分佈樣本的所有鍵。

forward(tensordict: TensorDictBase = None, tensordict_out: tensordict.base.TensorDictBase | None = None, **kwargs) TensorDictBase

當 tensordict 引數未設定時,kwargs 用於建立 TensorDict 的例項。

get_dist(tensordict: TensorDictBase, tensordict_out: Optional[TensorDictBase] = None, **kwargs) Distribution

返回透過序列傳遞輸入 tensordict 所產生的分佈。

如果 return_compositeFalse(預設值),此方法將僅考慮序列中的最後一個機率模組。

否則,它將返回一個 CompositeDistribution 例項,其中包含所有機率模組的分佈。

引數:
  • tensordict (TensorDictBase) – 輸入 tensordict。

  • tensordict_out (TensorDictBase, optional) – 輸出 tensordict。如果為 None,則將建立一個新的 tensordict。預設為 None

關鍵字引數:

**kwargs – 傳遞給底層模組的其他關鍵字引數。

返回:

產生的分佈物件。

返回型別:

D.Distribution

丟擲:

RuntimeError – 如果序列中未找到機率模組。

注意

return_compositeTrue 時,分佈會根據序列中的先前樣本進行條件化。這意味著,如果一個模組依賴於前一個機率模組的輸出,那麼它的分佈將是條件化的。

get_dist_params(tensordict: TensorDictBase, tensordict_out: Optional[TensorDictBase] = None, **kwargs) tuple[torch.distributions.distribution.Distribution, tensordict.base.TensorDictBase]

返回分佈引數和輸出 tensordict。

此方法執行 ProbabilisticTensorDictSequential 模組的確定性部分以獲取分佈引數。互動型別設定為當前全域性互動型別(如果可用),否則預設為最後一個模組的互動型別。

引數:
  • tensordict (TensorDictBase) – 輸入 tensordict。

  • tensordict_out (TensorDictBase, optional) – 輸出 tensordict。如果為 None,則將建立一個新的 tensordict。預設為 None

關鍵字引數:

**kwargs – 傳遞給模組確定性部分的附加關鍵字引數。

返回:

包含分佈物件和輸出 tensordict 的元組。

返回型別:

tuple[D.Distribution, TensorDictBase]

注意

在執行此方法期間,互動型別將臨時設定為指定值。

log_prob(tensordict, tensordict_out: Optional[TensorDictBase] = None, *, dist: Optional[Distribution] = None, **kwargs) tensordict.base.TensorDictBase | torch.Tensor

返回輸入 tensordict 的對數機率。

如果 self.return_compositeTrue 且分佈為 CompositeDistribution,則此方法將返回整個複合分佈的對數機率。

否則,它將僅考慮序列中的最後一個機率模組。

引數:
  • tensordict (TensorDictBase) – 輸入 tensordict。

  • tensordict_out (TensorDictBase, optional) – 輸出 tensordict。如果為 None,則將建立一個新的 tensordict。預設為 None

關鍵字引數:

dist (torch.distributions.Distribution, optional) – 分佈物件。如果為 None,則將使用 get_dist 計算。預設為 None

返回:

輸入 tensordict 的對數機率。

返回型別:

TensorDictBasetorch.Tensor

文件

訪問全面的 PyTorch 開發者文件

檢視文件

教程

為初學者和高階開發者提供深入的教程

檢視教程

資源

查詢開發資源並讓您的問題得到解答

檢視資源