StepCounter¶
- class torchrl.envs.transforms.StepCounter(max_steps: int | None = None, truncated_key: str | None = 'truncated', step_count_key: str | None = 'step_count', update_done: bool = True)[原始碼]¶
從重置開始計數,並可選擇在達到一定步數後將 truncated 狀態設定為
True。"done"狀態也會相應調整(因為 done 是任務完成和提前截斷的析取)。- 引數:
max_steps (int, optional) – 一個正整數,表示在將
truncated_key條目設定為True之前要採取的最大步數。truncated_key (str, optional) – 寫入截斷條目的鍵。預設為
"truncated",資料收集器將其識別為重置訊號。此引數只能是字串(不能是巢狀鍵),因為它將匹配父環境中的每個葉子 done 鍵(例如,如果使用了"truncated"鍵名,則("agent", "done")鍵將伴隨一個("agent", "truncated"))。step_count_key (str, optional) – 寫入步數計數值的鍵。預設為
"step_count"。此引數只能是字串(不能是巢狀鍵),因為它將匹配父環境中的每個葉子 done 鍵(例如,如果使用了"step_count"鍵名,則("agent", "done")鍵將伴隨一個("agent", "step_count"))。update_done (bool, optional) – 如果為
True,則將更新"truncated"級別的"done"布林張量。此訊號表示軌跡已到達其末尾,原因可能是任務已完成("completed"條目為True)或已被截斷("truncated"條目為True)。預設為True。
注意
為了確保與具有多個 done_key(s) 的環境相容,此轉換將為 tensordict 內的每個 done 條目寫入一個 step_count 條目。
示例
>>> import gymnasium >>> from torchrl.envs import GymWrapper >>> base_env = GymWrapper(gymnasium.make("Pendulum-v1")) >>> env = TransformedEnv(base_env, ... StepCounter(max_steps=5)) >>> rollout = env.rollout(100) >>> print(rollout) TensorDict( fields={ action: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.float32, is_shared=False), done: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False), completed: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, next: TensorDict( fields={ done: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False), completed: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, observation: Tensor(shape=torch.Size([5, 3]), device=cpu, dtype=torch.float32, is_shared=False), reward: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.float32, is_shared=False), step_count: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.int64, is_shared=False), truncated: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, batch_size=torch.Size([5]), device=cpu, is_shared=False), observation: Tensor(shape=torch.Size([5, 3]), device=cpu, dtype=torch.float32, is_shared=False), step_count: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.int64, is_shared=False), truncated: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, batch_size=torch.Size([5]), device=cpu, is_shared=False) >>> print(rollout["next", "step_count"]) tensor([[1], [2], [3], [4], [5]])
- forward(tensordict: TensorDictBase) TensorDictBase[原始碼]¶
讀取輸入 tensordict,並對選定的鍵應用轉換。
預設情況下,此方法
直接呼叫
_apply_transform()。不呼叫
_step()或_call()。
此方法不會在任何時候在 env.step 中呼叫。但是,它會在
sample()中呼叫。注意
forward也可以使用dispatch將引數名稱轉換為鍵,並使用常規關鍵字引數。示例
>>> class TransformThatMeasuresBytes(Transform): ... '''Measures the number of bytes in the tensordict, and writes it under `"bytes"`.''' ... def __init__(self): ... super().__init__(in_keys=[], out_keys=["bytes"]) ... ... def forward(self, tensordict: TensorDictBase) -> TensorDictBase: ... bytes_in_td = tensordict.bytes() ... tensordict["bytes"] = bytes ... return tensordict >>> t = TransformThatMeasuresBytes() >>> env = env.append_transform(t) # works within envs >>> t(TensorDict(a=0)) # Works offline too.
- transform_input_spec(input_spec: Composite) Composite[原始碼]¶
轉換輸入規範,使結果規範與轉換對映匹配。
- 引數:
input_spec (TensorSpec) – 轉換前的規範
- 返回:
轉換後的預期規範
- transform_observation_spec(observation_spec: Composite) Composite[原始碼]¶
轉換觀察規範,使結果規範與轉換對映匹配。
- 引數:
observation_spec (TensorSpec) – 轉換前的規範
- 返回:
轉換後的預期規範
- transform_output_spec(output_spec: Composite) Composite[原始碼]¶
轉換輸出規範,使結果規範與轉換對映匹配。
此方法通常應保持不變。更改應使用
transform_observation_spec()、transform_reward_spec()和transform_full_done_spec()來實現。 :param output_spec: 轉換前的 spec :type output_spec: TensorSpec- 返回:
轉換後的預期規範