快捷方式

SelectKeys

class torchrl.trainers.SelectKeys(keys: Sequence[str])[原始碼]

在 TensorDict 批次中選擇鍵。

引數:

keys (字串的可迭代物件) – 要在 tensordict 中選擇的鍵。

示例

>>> trainer = make_trainer()
>>> key1 = "first key"
>>> key2 = "second key"
>>> td = TensorDict(
...     {
...         key1: torch.randn(3),
...         key2: torch.randn(3),
...     },
...     [],
... )
>>> trainer.register_op("batch_process", SelectKeys([key1]))
>>> td_out = trainer._process_batch_hook(td)
>>> assert key1 in td_out.keys()
>>> assert key2 not in td_out.keys()
register(trainer, name='select_keys') None[原始碼]

Registers the hook in the trainer at a default location.

引數:
  • trainer (Trainer) – the trainer where the hook must be registered.

  • name (str) – the name of the hook.

注意

To register the hook at another location than the default, use register_op().

文件

訪問全面的 PyTorch 開發者文件

檢視文件

教程

為初學者和高階開發者提供深入的教程

檢視教程

資源

查詢開發資源並讓您的問題得到解答

檢視資源