快捷方式

PixelRenderTransform

torchrl.record.PixelRenderTransform(out_keys: list[NestedKey] = None, preproc: Callable[[np.ndarray | torch.Tensor], np.ndarray | torch.Tensor] = None, as_non_tensor: bool | None = None, render_method: str = 'render', pass_tensordict: bool = False, **kwargs) None[原始碼]

一個呼叫父環境的 render 方法並將畫素觀察註冊到 tensordict 中的轉換。

此轉換提供了一種替代方法,用於在例項化提供渲染功能的 RL 環境時使用 from_pixels 語法糖,尤其是在例項化環境成本很高或 from_pixels 未實現的情況下。它可以用於單個環境或批處理環境。

引數:
  • out_keys (List[NestedKey] or Nested) – 用於註冊畫素觀察值的鍵列表。

  • preproc (Callable, optional) – 一個預處理函式。可用於重塑觀察值,或應用任何其他使其能夠註冊到輸出資料中的轉換。

  • as_non_tensor (bool, optional) – 如果設定為 True,則資料將作為 NonTensorData 寫入,從而放寬形狀要求。如果未提供,則將根據輸入資料的型別和形狀自動推斷。

  • render_method (str, optional) – 渲染方法的名稱。預設為 "render"

  • pass_tensordict (bool, optional) – 如果設定為 True,則輸入 tensordict 將傳遞給渲染方法。這使得無狀態環境能夠進行渲染。預設為 False

  • **kwargs – 傳遞給渲染函式的其他關鍵字引數(例如 mode="rgb_array")。

示例

>>> from torchrl.envs import GymEnv, check_env_specs, ParallelEnv, EnvCreator
>>> from torchrl.record.loggers import CSVLogger
>>> from torchrl.record.recorder import PixelRenderTransform, VideoRecorder
>>>
>>> def make_env():
>>>     env = GymEnv("CartPole-v1", render_mode="rgb_array")
>>>     env = env.append_transform(PixelRenderTransform())
>>>     return env
>>>
>>> if __name__ == "__main__":
...     logger = CSVLogger("dummy", video_format="mp4")
...
...     env = ParallelEnv(4, EnvCreator(make_env))
...
...     env = env.append_transform(VideoRecorder(logger=logger, tag="pixels_record"))
...     env.rollout(3)
...
...     check_env_specs(env)
...
...     r = env.rollout(30)
...     print(env)
...     env.transform.dump()
...     env.close()

當批處理環境 render() 返回單個影像時,也可以使用此轉換。

示例

>>> from torchrl.envs import check_env_specs
>>> from torchrl.envs.libs.vmas import VmasEnv
>>> from torchrl.record.loggers import CSVLogger
>>> from torchrl.record.recorder import PixelRenderTransform, VideoRecorder
>>>
>>> env = VmasEnv(
...     scenario="flocking",
...     num_envs=32,
...     continuous_actions=True,
...     max_steps=200,
...     device="cpu",
...     seed=None,
...     # Scenario kwargs
...     n_agents=5,
... )
>>>
>>> logger = CSVLogger("dummy", video_format="mp4")
>>>
>>> env = env.append_transform(PixelRenderTransform(mode="rgb_array", preproc=lambda x: x.copy()))
>>> env = env.append_transform(VideoRecorder(logger=logger, tag="pixels_record"))
>>>
>>> check_env_specs(env)
>>>
>>> r = env.rollout(30)
>>> env.transform[-1].dump()

可以使用 switch() 方法停用此轉換,該方法將開啟渲染(如果已關閉)或關閉渲染(如果已開啟)(也可以傳遞一個引數來控制此行為)。由於轉換是 Module 例項,因此可以使用 apply() 來控制此行為。

>>> def switch(module):
...     if isinstance(module, PixelRenderTransform):
...         module.switch()
>>> env.apply(switch)

文件

訪問全面的 PyTorch 開發者文件

檢視文件

教程

為初學者和高階開發者提供深入的教程

檢視教程

資源

查詢開發資源並讓您的問題得到解答

檢視資源