快捷方式

TrajCounter

class torchrl.envs.transforms.TrajCounter(out_key: NestedKey = 'traj_count', *, repeats: int | None = None)[原始碼]

全域性軌跡計數器轉換。

TrajCounter 可用於計算任何 TorchRL 環境中軌跡的數量(即呼叫 reset 的次數)。此轉換將在單個節點內的多個程序中工作(請參閱下面的注意事項)。單個轉換隻能計算與單個完成狀態相關的軌跡,但只要其字首與計數器鍵的字首匹配,就接受巢狀的完成狀態。

引數:

out_key (NestedKey, 可選) – 軌跡計數器的條目名稱。預設為 "traj_count"

示例

>>> from torchrl.envs import GymEnv, StepCounter, TrajCounter
>>> env = GymEnv("Pendulum-v1").append_transform(StepCounter(6))
>>> env = env.append_transform(TrajCounter())
>>> r = env.rollout(18, break_when_any_done=False)  # 18 // 6 = 3 trajectories
>>> r["next", "traj_count"]
tensor([[0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [2],
        [2],
        [2],
        [2],
        [2],
        [2]])

注意

可以透過多種方式在工作程序之間共享軌跡計數器,但這通常會涉及將環境包裝在 EnvCreator 中。否則,在轉換的序列化過程中可能會出錯。計數器將在工作程序之間共享,這意味著在任何給定時間,都保證不會有兩個環境共享相同的軌跡計數(每個(步驟計數,軌跡計數)對都是唯一的)。以下是共享 TrajCounter 物件在程序之間的有效方法的示例。

>>> # Option 1: Create the trajectory counter outside the environment.
>>> #  This requires the counter to be cloned within the transformed env, as a single transform object cannot have two parents.
>>> t = TrajCounter()
>>> def make_env(max_steps=4, t=t):
...     # See CountingEnv in torchrl.test.mocking_classes
...     env = TransformedEnv(CountingEnv(max_steps=max_steps), t.clone())
...     env.transform.transform_observation_spec(env.base_env.observation_spec)
...     return env
>>> penv = ParallelEnv(
...     2,
...     [EnvCreator(make_env, max_steps=4), EnvCreator(make_env, max_steps=5)],
...     mp_start_method="spawn",
... )
>>> # Option 2: Create the transform within the constructor.
>>> #  In this scenario, we still need to tell each sub-env what kwarg has to be used.
>>> #  Both EnvCreator and ParallelEnv offer that possibility.
>>> def make_env(max_steps=4):
...     t = TrajCounter()
...     env = TransformedEnv(CountingEnv(max_steps=max_steps), t)
...     env.transform.transform_observation_spec(env.base_env.observation_spec)
...     return env
>>> make_env_c0 = EnvCreator(make_env)
>>> # Create a variant of the env with different kwargs
>>> make_env_c1 = make_env_c0.make_variant(max_steps=5)
>>> penv = ParallelEnv(
...     2,
...     [make_env_c0, make_env_c1],
...     mp_start_method="spawn",
... )
>>> # Alternatively, pass the kwargs to the ParallelEnv
>>> penv = ParallelEnv(
...     2,
...     [make_env_c0, make_env_c0],
...     create_env_kwargs=[{"max_steps": 5}, {"max_steps": 4}],
...     mp_start_method="spawn",
... )
forward(tensordict: TensorDictBase) TensorDictBase[原始碼]

讀取輸入 tensordict,並對選定的鍵應用轉換。

預設情況下,此方法

  • 直接呼叫 _apply_transform()

  • 不呼叫 _step()_call()

此方法不會在任何時候在 env.step 中呼叫。但是,它會在 sample() 中呼叫。

注意

forward 也可以使用 dispatch 將引數名稱轉換為鍵,並使用常規關鍵字引數。

示例

>>> class TransformThatMeasuresBytes(Transform):
...     '''Measures the number of bytes in the tensordict, and writes it under `"bytes"`.'''
...     def __init__(self):
...         super().__init__(in_keys=[], out_keys=["bytes"])
...
...     def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
...         bytes_in_td = tensordict.bytes()
...         tensordict["bytes"] = bytes
...         return tensordict
>>> t = TransformThatMeasuresBytes()
>>> env = env.append_transform(t) # works within envs
>>> t(TensorDict(a=0))  # Works offline too.
load_state_dict(state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False)[原始碼]

state_dict 將引數和緩衝區複製到此模組及其後代中。

如果 strictTrue,則 state_dict 的鍵必須與此模組的 state_dict() 函式返回的鍵完全匹配。

警告

如果 assignTrue,則必須在呼叫 load_state_dict 之後建立最佳化器,除非 get_swap_module_params_on_conversion()True

引數:
  • state_dict (dict) – 包含引數和持久 buffer 的字典。

  • strict (bool, 可選) – 是否嚴格強制 state_dict 中的鍵與此模組的 state_dict() 函式返回的鍵匹配。預設值:True

  • assign (bool, optional) – 當設定為 False 時,將保留當前模組中張量的屬性;當設定為 True 時,將保留 state_dict 中張量的屬性。唯一的例外是 Parameterrequires_grad 欄位,此時將保留模組的值。預設值:False

返回:

  • missing_keys 是一個包含此模組期望但

    在提供的 state_dict 中缺失的任何鍵的字串列表。

  • unexpected_keys 是一個字串列表,包含此模組

    不期望但在提供的 state_dict 中存在的鍵。

返回型別:

NamedTuple,包含 missing_keysunexpected_keys 欄位。

注意

如果引數或緩衝區被註冊為 None 並且其對應的鍵存在於 state_dict 中,load_state_dict() 將引發 RuntimeError

state_dict(*args, destination=None, prefix='', keep_vars=False)[原始碼]

返回一個字典,其中包含對模組整個狀態的引用。

引數和持久緩衝區(例如,執行平均值)都包含在內。鍵是相應的引數和緩衝區名稱。設定為 None 的引數和緩衝區不包含在內。

注意

返回的物件是淺複製。它包含對模組引數和緩衝區的引用。

警告

當前 state_dict() 還接受 destinationprefixkeep_vars 的位置引數,順序為。但是,這正在被棄用,並且在未來的版本中將強制使用關鍵字引數。

警告

請避免使用引數 destination,因為它不是為終端使用者設計的。

引數:
  • destination (dict, optional) – 如果提供,模組的狀態將更新到 dict 中,並返回相同的物件。否則,將建立一個 OrderedDict 並返回。預設為 None

  • prefix (str, optional) – a prefix added to parameter and buffer names to compose the keys in state_dict. Default: ''

  • keep_vars (bool, optional) – 預設情況下,state dict 中返回的 Tensors 會從 autograd 中分離。如果設定為 True,則不會執行分離。預設為 False

返回:

包含模組整體狀態的字典

返回型別:

dict

示例

>>> # xdoctest: +SKIP("undefined vars")
>>> module.state_dict().keys()
['bias', 'weight']
transform_observation_spec(observation_spec: Composite) Composite[原始碼]

轉換觀察規範,使結果規範與轉換對映匹配。

引數:

observation_spec (TensorSpec) – 轉換前的規範

返回:

轉換後的預期規範

文件

訪問全面的 PyTorch 開發者文件

檢視文件

教程

為初學者和高階開發者提供深入的教程

檢視教程

資源

查詢開發資源並讓您的問題得到解答

檢視資源