快捷方式

SafeProbabilisticTensorDictSequential

class torchrl.modules.tensordict_module.SafeProbabilisticTensorDictSequential(*args, **kwargs)[原始碼]

tensordict.nn.ProbabilisticTensorDictSequential 的子類,它接受 TensorSpec 作為引數來控制輸出域。

TensorDictSequential 類似,但強制要求序列中的最後一個模組是 ProbabilisticTensorDictModule,並且還公開了 get_dist 方法來從 ProbabilisticTensorDictModule 恢復分佈物件。

引數:
  • modules (TensorDictModules 的可迭代物件) – 按順序排列的 TensorDictModule 例項序列,以 ProbabilisticTensorDictModule 結尾,將按順序執行。

  • partial_tolerant (bool, optional) – 如果為 True,則輸入的 tensordict 可能缺少某些輸入鍵。如果是這樣,則只會執行可以根據存在的鍵執行的模組。此外,如果輸入的 tensordict 是 tensordicts 的惰性堆疊,並且 partial_tolerant 為 True,並且堆疊不包含所需的鍵,那麼 TensorDictSequential 將掃描子 tensordicts 查詢具有所需鍵的 tensordicts(如果有)。

文件

訪問全面的 PyTorch 開發者文件

檢視文件

教程

為初學者和高階開發者提供深入的教程

檢視教程

資源

查詢開發資源並讓您的問題得到解答

檢視資源