SafeProbabilisticTensorDictSequential¶
- class torchrl.modules.tensordict_module.SafeProbabilisticTensorDictSequential(*args, **kwargs)[原始碼]¶
tensordict.nn.ProbabilisticTensorDictSequential的子類,它接受TensorSpec作為引數來控制輸出域。與
TensorDictSequential類似,但強制要求序列中的最後一個模組是ProbabilisticTensorDictModule,並且還公開了get_dist方法來從ProbabilisticTensorDictModule恢復分佈物件。- 引數:
modules (TensorDictModules 的可迭代物件) – 按順序排列的 TensorDictModule 例項序列,以 ProbabilisticTensorDictModule 結尾,將按順序執行。
partial_tolerant (bool, optional) – 如果為
True,則輸入的 tensordict 可能缺少某些輸入鍵。如果是這樣,則只會執行可以根據存在的鍵執行的模組。此外,如果輸入的 tensordict 是 tensordicts 的惰性堆疊,並且 partial_tolerant 為True,並且堆疊不包含所需的鍵,那麼 TensorDictSequential 將掃描子 tensordicts 查詢具有所需鍵的 tensordicts(如果有)。