PixelRenderTransform¶
- torchrl.record.PixelRenderTransform(out_keys: list[NestedKey] = None, preproc: Callable[[np.ndarray | torch.Tensor], np.ndarray | torch.Tensor] = None, as_non_tensor: bool | None = None, render_method: str = 'render', pass_tensordict: bool = False, **kwargs) None[原始碼]¶
一個呼叫父環境的 render 方法並將畫素觀察註冊到 tensordict 中的轉換。
此轉換提供了一種替代方法,用於在例項化提供渲染功能的 RL 環境時使用
from_pixels語法糖,尤其是在例項化環境成本很高或from_pixels未實現的情況下。它可以用於單個環境或批處理環境。- 引數:
out_keys (List[NestedKey] or Nested) – 用於註冊畫素觀察值的鍵列表。
preproc (Callable, optional) – 一個預處理函式。可用於重塑觀察值,或應用任何其他使其能夠註冊到輸出資料中的轉換。
as_non_tensor (bool, optional) – 如果設定為
True,則資料將作為NonTensorData寫入,從而放寬形狀要求。如果未提供,則將根據輸入資料的型別和形狀自動推斷。render_method (str, optional) – 渲染方法的名稱。預設為
"render"。pass_tensordict (bool, optional) – 如果設定為
True,則輸入 tensordict 將傳遞給渲染方法。這使得無狀態環境能夠進行渲染。預設為False。**kwargs – 傳遞給渲染函式的其他關鍵字引數(例如
mode="rgb_array")。
示例
>>> from torchrl.envs import GymEnv, check_env_specs, ParallelEnv, EnvCreator >>> from torchrl.record.loggers import CSVLogger >>> from torchrl.record.recorder import PixelRenderTransform, VideoRecorder >>> >>> def make_env(): >>> env = GymEnv("CartPole-v1", render_mode="rgb_array") >>> env = env.append_transform(PixelRenderTransform()) >>> return env >>> >>> if __name__ == "__main__": ... logger = CSVLogger("dummy", video_format="mp4") ... ... env = ParallelEnv(4, EnvCreator(make_env)) ... ... env = env.append_transform(VideoRecorder(logger=logger, tag="pixels_record")) ... env.rollout(3) ... ... check_env_specs(env) ... ... r = env.rollout(30) ... print(env) ... env.transform.dump() ... env.close()
當批處理環境
render()返回單個影像時,也可以使用此轉換。示例
>>> from torchrl.envs import check_env_specs >>> from torchrl.envs.libs.vmas import VmasEnv >>> from torchrl.record.loggers import CSVLogger >>> from torchrl.record.recorder import PixelRenderTransform, VideoRecorder >>> >>> env = VmasEnv( ... scenario="flocking", ... num_envs=32, ... continuous_actions=True, ... max_steps=200, ... device="cpu", ... seed=None, ... # Scenario kwargs ... n_agents=5, ... ) >>> >>> logger = CSVLogger("dummy", video_format="mp4") >>> >>> env = env.append_transform(PixelRenderTransform(mode="rgb_array", preproc=lambda x: x.copy())) >>> env = env.append_transform(VideoRecorder(logger=logger, tag="pixels_record")) >>> >>> check_env_specs(env) >>> >>> r = env.rollout(30) >>> env.transform[-1].dump()
可以使用
switch()方法停用此轉換,該方法將開啟渲染(如果已關閉)或關閉渲染(如果已開啟)(也可以傳遞一個引數來控制此行為)。由於轉換是Module例項,因此可以使用apply()來控制此行為。>>> def switch(module): ... if isinstance(module, PixelRenderTransform): ... module.switch() >>> env.apply(switch)