快捷方式

EnvCreator

class torchrl.envs.EnvCreator(create_env_fn: Callable[..., EnvBase], create_env_kwargs: dict | None = None, share_memory: bool = True, **kwargs)[原始碼]

環境建立者類。

EnvCreator 是一個通用的環境建立者類,可以在多程序環境中建立環境時替代 lambda 函式。如果必須在子程序中建立的環境與主程序共享資訊(例如,用於 VecNorm 轉換),EnvCreator 會將 tensordicts 的指標傳遞給共享記憶體中的每個程序,以確保它們全部同步。

引數:
  • create_env_fn (callable) – 一個返回 EnvBase 例項的可呼叫物件。

  • create_env_kwargs (dict, optional) – env 建立者的關鍵字引數。

  • share_memory (bool, optional) – 如果為 False,則從環境中獲得的 tensordict 不會放置在共享記憶體中。

  • **kwargs – 在構造期間傳遞給環境的其他關鍵字引數。

示例

>>> # We create the same environment on 2 processes using VecNorm
>>> # and check that the discounted count of observations matches on
>>> # both workers, even if one has not executed any step
>>> import time
>>> from torchrl.envs.libs.gym import GymEnv
>>> from torchrl.envs.transforms import VecNorm, TransformedEnv
>>> from torchrl.envs import EnvCreator
>>> from torch import multiprocessing as mp
>>> env_fn = lambda: TransformedEnv(GymEnv("Pendulum-v1"), VecNorm())
>>> env_creator = EnvCreator(env_fn)
>>>
>>> def test_env1(env_creator):
...     env = env_creator()
...     tensordict = env.reset()
...     for _ in range(10):
...         env.rand_step(tensordict)
...         if tensordict.get(("next", "done")):
...             tensordict = env.reset(tensordict)
...     print("env 1: ", env.transform._td.get(("next", "observation_count")))
>>>
>>> def test_env2(env_creator):
...     env = env_creator()
...     time.sleep(5)
...     print("env 2: ", env.transform._td.get(("next", "observation_count")))
>>>
>>> if __name__ == "__main__":
...     ps = []
...     p1 = mp.Process(target=test_env1, args=(env_creator,))
...     p1.start()
...     ps.append(p1)
...     p2 = mp.Process(target=test_env2, args=(env_creator,))
...     p2.start()
...     ps.append(p1)
...     for p in ps:
...         p.join()
env 1:  tensor([11.9934])
env 2:  tensor([11.9934])
make_variant(**kwargs) EnvCreator[原始碼]

建立 EnvCreator 的一個變體,指向相同的底層元資料,但在構造期間使用不同的關鍵字引數。

這在共享狀態的轉換(如 TrajCounter)中可能很有用。

示例

>>> from torchrl.envs import GymEnv
>>> env_creator_pendulum = EnvCreator(GymEnv, env_name="Pendulum-v1")
>>> env_creator_cartpole = env_creator_pendulum.make_variant(env_name="CartPole-v1")

文件

訪問全面的 PyTorch 開發者文件

檢視文件

教程

為初學者和高階開發者提供深入的教程

檢視教程

資源

查詢開發資源並讓您的問題得到解答

檢視資源