快捷方式

step_mdp

torchrl.envs.step_mdp(tensordict: TensorDictBase, next_tensordict: TensorDictBase = None, keep_other: bool = True, exclude_reward: bool = True, exclude_done: bool = False, exclude_action: bool = True, reward_keys: NestedKey | list[NestedKey] = 'reward', done_keys: NestedKey | list[NestedKey] = 'done', action_keys: NestedKey | list[NestedKey] = 'action') TensorDictBase[原始碼]

建立一個新的 tensordict,反映輸入 tensordict 的時間步。

給定一個在步進後檢索到的 tensordict,返回 "next" 索引的 tensordict。引數允許精確控制哪些內容應該被保留,哪些內容應該從 "next" 條目中複製。預設行為是:將 observation 條目、獎勵和 done 狀態移動到根目錄,排除當前 action,並保留所有額外的鍵(非 action、非 done、非 reward)。

引數:
  • tensordict (TensorDictBase) – 要重新命名的鍵的 tensordict。

  • next_tensordict (TensorDictBase, 可選) – 目標 tensordict。如果為 None,則建立一個新的 tensordict。

  • keep_other (bool, 可選) – 如果為 True,則會保留所有不以 'next_' 開頭的鍵。預設為 True

  • exclude_reward (bool, 可選) – 如果為 True,則 "reward" 鍵將被從結果 tensordict 中丟棄。如果為 False,它將被從 "next" 條目(如果存在)複製(並替換)。預設為 True

  • exclude_done (bool, 可選) – 如果為 True,則 "done" 鍵將被從結果 tensordict 中丟棄。如果為 False,它將被從 "next" 條目(如果存在)複製(並替換)。預設為 False

  • exclude_action (bool, 可選) – 如果為 True,則 "action" 鍵將被從結果 tensordict 中丟棄。如果為 False,它將被保留在根 tensordict 中(因為它不應出現在 "next" 條目中)。預設為 True

  • reward_keys (NestedKeyNestedKey 列表, 可選) – 寫入獎勵的鍵。預設為“reward”。

  • done_keys (NestedKeyNestedKey 列表, 可選) – 寫入 done 的鍵。預設為“done”。

  • action_keys (NestedKeyNestedKey 列表, 可選) – 寫入 action 的鍵。預設為“action”。

返回:

包含 t+1 步張量的新的 tensordict(或如果提供了 next_tensordict 則為 next_tensordict)。

返回型別:

TensorDictBase

另請參閱

EnvBase.step_mdp() 是此自由函式的基於類的版本。它將嘗試快取鍵值以減少 MDP 步進的開銷。

示例

>>> from tensordict import TensorDict
>>> import torch
>>> td = TensorDict({
...     "done": torch.zeros((), dtype=torch.bool),
...     "reward": torch.zeros(()),
...     "extra": torch.zeros(()),
...     "next": TensorDict({
...         "done": torch.zeros((), dtype=torch.bool),
...         "reward": torch.zeros(()),
...         "obs": torch.zeros(()),
...     }, []),
...     "obs": torch.zeros(()),
...     "action": torch.zeros(()),
... }, [])
>>> print(step_mdp(td))
TensorDict(
    fields={
        done: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.bool, is_shared=False),
        extra: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        obs: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> print(step_mdp(td, exclude_done=True))  # "done" is dropped
TensorDict(
    fields={
        extra: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        obs: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> print(step_mdp(td, exclude_reward=False))  # "reward" is kept
TensorDict(
    fields={
        done: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.bool, is_shared=False),
        extra: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        obs: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        reward: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> print(step_mdp(td, exclude_action=False))  # "action" persists at the root
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.bool, is_shared=False),
        extra: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        obs: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> print(step_mdp(td, keep_other=False))  # "extra" is missing
TensorDict(
    fields={
        done: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.bool, is_shared=False),
        obs: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)

警告

如果獎勵鍵也包含在輸入鍵中(當排除獎勵鍵時),此函式將無法正常工作。這就是為什麼 RewardSum 轉換預設將劇集獎勵註冊到 observation 而不是 reward spec。當使用此函式的快速快取版本(_StepMDP)時,不應觀察到此問題。

文件

訪問全面的 PyTorch 開發者文件

檢視文件

教程

為初學者和高階開發者提供深入的教程

檢視教程

資源

查詢開發資源並讓您的問題得到解答

檢視資源