EnvCreator
- class torchrl.envs.EnvCreator(create_env_fn: Callable[..., EnvBase], create_env_kwargs: dict | None = None, share_memory: bool = True, **kwargs)
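Environment creator class.
EnvCreator is a generic environment creator class that can replace lambda functions when environments are created in a multiprocessing context. If an environment created in a subprocess must share information with the main process (for example, for the VecNorm transform), EnvCreator passes pointers to the tensordicts held in shared memory to every process, so that all of them stay synchronized.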
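- Parameters:
create_env_fn (callable) – a callable that returns an EnvBase instance.
create_env_kwargs (dict, optional) – keyword arguments for the environment creator.
share_memory (bool, optional) – if False, the tensordict obtained from the environment is not placed in shared memory. Defaults to True.
**kwargs – additional keyword arguments passed to the environment during construction.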
Examples
>>> # We create the same environment on 2 processes using VecNorm
>>> # and check that the discounted count of observations matches on
>>> # both workers, even if one has not executed any step
>>> import time
>>> from torchrl.envs.libs.gym import GymEnv
>>> from torchrl.envs.transforms import VecNorm, TransformedEnv
>>> from torchrl.envs import EnvCreator
>>> from torch import multiprocessing as mp
>>> env_fn = lambda: TransformedEnv(GymEnv("Pendulum-v1"), VecNorm())
>>> env_creator = EnvCreator(env_fn)
>>>
>>> def test_env1(env_creator):
...     # This worker steps the environment, which updates the shared statistics
...     env = env_creator()
...     tensordict = env.reset()
...     for _ in range(10):
...         env.rand_step(tensordict)
...         if tensordict.get(("next", "done")):
...             tensordict = env.reset(tensordict)
...     print("env 1: ", env.transform._td.get(("next", "observation_count")))
>>>
>>> def test_env2(env_creator):
...     # This worker never steps; it only reads the shared statistics
...     env = env_creator()
...     time.sleep(5)
...     print("env 2: ", env.transform._td.get(("next", "observation_count")))
>>>
>>> if __name__ == "__main__":
...     ps = []
...     p1 = mp.Process(target=test_env1, args=(env_creator,))
...     p1.start()
...     ps.append(p1)
...     p2 = mp.Process(target=test_env2, args=(env_creator,))
...     p2.start()
...     ps.append(p2)
...     for p in ps:
...         p.join()
env 1:  tensor([11.9934])
env 2:  tensor([11.9934])
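The synchronization described above relies on every worker holding a handle to the same shared buffer rather than a private copy. The following is a minimal sketch of that general idea using only the Python standard library. It is not how EnvCreator is implemented (TorchRL shares tensordicts, not raw byte buffers), and the function name `demo_shared_counter` is hypothetical. For simplicity, both handles live in one process here; in practice the second handle would be opened in a worker process.

```python
import struct
from multiprocessing import shared_memory


def demo_shared_counter() -> float:
    """Write through one handle to a shared block, read through another."""
    # The "main process" allocates 8 bytes, enough for one float64 statistic.
    owner = shared_memory.SharedMemory(create=True, size=8)
    try:
        struct.pack_into("d", owner.buf, 0, 0.0)

        # A second handle attaches to the same block by name (assumption:
        # in a real setup this would happen inside a worker process).
        worker_view = shared_memory.SharedMemory(name=owner.name)
        try:
            # The "worker" updates the statistic through its own handle.
            struct.pack_into("d", worker_view.buf, 0, 11.5)
            # The owner sees the update without any message passing.
            return struct.unpack_from("d", owner.buf, 0)[0]
        finally:
            worker_view.close()
    finally:
        owner.close()
        owner.unlink()  # release the block once nobody needs it


value = demo_shared_counter()
```

Because both handles point at the same memory, a worker that never steps its environment can still read up-to-date statistics, which is the behaviour the VecNorm example checks.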
- make_variant(**kwargs) -> EnvCreator
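Creates a variant of the EnvCreator that points to the same underlying metadata but uses different keyword arguments during construction.
This can be useful with transforms that share a state, such as TrajCounter.
Examples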
>>> from torchrl.envs import EnvCreator, GymEnv
>>> env_creator_pendulum = EnvCreator(GymEnv, env_name="Pendulum-v1")
>>> env_creator_cartpole = env_creator_pendulum.make_variant(env_name="CartPole-v1")
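To make the "same metadata, different keyword arguments" behaviour concrete, here is a minimal, library-free sketch of the pattern. `SketchCreator`, `make_env`, and `shared_meta` are hypothetical names introduced for illustration only; this is not TorchRL's implementation, and the way new keyword arguments are merged over the stored ones is an assumption of the sketch.

```python
from typing import Any, Callable, Dict, Optional


class SketchCreator:
    """Hypothetical stand-in for a creator object; not TorchRL's EnvCreator."""

    def __init__(
        self,
        create_fn: Callable[..., Any],
        shared_meta: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> None:
        self.create_fn = create_fn
        self.kwargs = kwargs
        # State that every variant should see, e.g. a trajectory counter.
        self.shared_meta = {} if shared_meta is None else shared_meta

    def __call__(self) -> Any:
        # Build the object with the stored keyword arguments.
        return self.create_fn(**self.kwargs)

    def make_variant(self, **kwargs: Any) -> "SketchCreator":
        # Assumption: new keyword arguments override the stored ones.
        merged = {**self.kwargs, **kwargs}
        # The metadata dict is passed by reference, so it stays shared.
        return SketchCreator(self.create_fn, shared_meta=self.shared_meta, **merged)


def make_env(env_name: str) -> str:
    # Placeholder "environment": just a descriptive string.
    return f"env<{env_name}>"


pendulum = SketchCreator(make_env, env_name="Pendulum-v1")
cartpole = pendulum.make_variant(env_name="CartPole-v1")

# An update made through one creator is visible through the other.
pendulum.shared_meta["traj_count"] = 3
```

The two creators build different objects, yet `cartpole.shared_meta` and `pendulum.shared_meta` are the same dictionary, which mirrors how a state-sharing transform could keep a single counter across variants.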