快捷方式

VecNorm

class torchrl.envs.transforms.VecNorm(*args, **kwargs)[source]

用於 torchrl 環境的移動平均歸一化層。

警告

此類將棄用,取而代之的是 VecNormV2,並將在 v0.10 中被該類替換。您可以透過使用 new_api 引數或從 torchrl.envs 匯入 VecNormV2 類來適應這些更改。

VecNorm 跟蹤資料集的摘要統計資訊,以便對其進行即時標準化。如果該轉換處於“eval”模式,則不會更新執行統計資訊。

如果多個程序執行相似的環境,則可以傳遞一個放置在共享記憶體中的 TensorDictBase 例項:如果是這樣,每次查詢歸一化層時,它都會更新共享相同引用的所有程序的值。

要在推理時使用 VecNorm 並避免使用新觀察值更新值,應將此層替換為 to_observation_norm()。這將提供一個靜態版本的 VecNorm,在源轉換更新時不會更新。要獲取 VecNorm 層的凍結副本,請參閱 frozen_copy()

引數:
  • in_keys (sequence of NestedKey, optional) – 要更新的鍵。預設為 [“observation”, “reward”]

  • out_keys (sequence of NestedKey, optional) – 目標鍵。預設為 in_keys

  • shared_td (TensorDictBase, optional) – 一個共享的 tensordict,包含轉換的鍵。

  • lock (mp.Lock) – 一個鎖,用於防止程序之間發生競爭條件。預設為 None(在初始化期間建立鎖)。

  • decay (number, optional) – 移動平均的衰減率。預設為 0.99

  • eps (number, optional) – 執行標準差的下界(用於數值下溢)。預設為 1e-4。

  • shapes (List[torch.Size], optional) – 如果提供,則表示每個 in_keys 的形狀。其長度必須與 in_keys 的長度匹配。每個形狀都必須匹配相應條目的最後一個維度。如果不是,則將條目的特徵維度(即不屬於 tensordict 批次大小的所有維度)視為特徵維度。

  • new_api (bool or None, optional) – 如果 True,則將返回 VecNormV2 的例項。如果未傳遞,將引發警告。預設為 False

示例

>>> from torchrl.envs.libs.gym import GymEnv
>>> t = VecNorm(decay=0.9)
>>> env = GymEnv("Pendulum-v0")
>>> env = TransformedEnv(env, t)
>>> tds = []
>>> for _ in range(1000):
...     td = env.rand_step()
...     if td.get("done"):
...         _ = env.reset()
...     tds += [td]
>>> tds = torch.stack(tds, 0)
>>> print((abs(tds.get(("next", "observation")).mean(0))<0.2).all())
tensor(True)
>>> print((abs(tds.get(("next", "observation")).std(0)-1)<0.2).all())
tensor(True)
static build_td_for_shared_vecnorm(env: EnvBase, keys: Sequence[str] | None = None, memmap: bool = False) TensorDictBase[source]

為跨程序歸一化建立共享的 tensordict。

引數:
  • env (EnvBase) – 用於建立 tensordict 的示例環境

  • keys (sequence of NestedKey, optional) – 需要歸一化的鍵。預設為 [“next”, “reward”]

  • memmap (bool) – 如果 True,則生成的 tensordict 將被轉換為記憶體對映(使用 memmap_())。否則,tensordict 將被放置在共享記憶體中。

返回:

一個共享記憶體,用於傳送到每個程序。

示例

>>> from torch import multiprocessing as mp
>>> queue = mp.Queue()
>>> env = make_env()
>>> td_shared = VecNorm.build_td_for_shared_vecnorm(env,
...     ["next", "reward"])
>>> assert td_shared.is_shared()
>>> queue.put(td_shared)
>>> # on workers
>>> v = VecNorm(shared_td=queue.get())
>>> env = TransformedEnv(make_env(), v)
forward(next_tensordict: TensorDictBase) TensorDictBase

讀取輸入 tensordict,並對選定的鍵應用轉換。

預設情況下,此方法

  • 直接呼叫 _apply_transform()

  • 不呼叫 _step()_call()

此方法不會在任何時候在 env.step 中呼叫。但是,它會在 sample() 中呼叫。

注意

forward 也可以使用 dispatch 將引數名稱轉換為鍵,並使用常規關鍵字引數。

示例

>>> class TransformThatMeasuresBytes(Transform):
...     '''Measures the number of bytes in the tensordict, and writes it under `"bytes"`.'''
...     def __init__(self):
...         super().__init__(in_keys=[], out_keys=["bytes"])
...
...     def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
...         bytes_in_td = tensordict.bytes()
...         tensordict["bytes"] = bytes
...         return tensordict
>>> t = TransformThatMeasuresBytes()
>>> env = env.append_transform(t) # works within envs
>>> t(TensorDict(a=0))  # Works offline too.
freeze() VecNorm[source]

凍結 VecNorm,避免在呼叫時更新統計資訊。

請參閱 unfreeze()

frozen_copy() VecNorm[source]

返回一個 Transform 的副本,該副本會跟蹤統計資訊但不會更新它們。

get_extra_state() OrderedDict[source]

返回要包含在模組 state_dict 中的任何額外狀態。

如果您的模組需要儲存額外狀態,請實現此函式以及相應的 set_extra_state()。在構建模組的 state_dict() 時呼叫此函式。

注意,為了保證 state_dict 的序列化工作正常,額外狀態應該是可被 pickle 的。我們僅為 Tensors 的序列化提供向後相容性保證;其他物件的序列化形式若發生變化,可能導致向後相容性中斷。

返回:

要儲存在模組 state_dict 中的任何額外狀態

返回型別:

物件

property loc

返回一個用於仿射變換的 loc 的 TensorDict。

property scale

返回一個用於仿射變換的 scale 的 TensorDict。

set_extra_state(state: OrderedDict) None[source]

設定載入的 state_dict 中包含的額外狀態。

此函式從 load_state_dict() 呼叫,以處理 state_dict 中的任何額外狀態。如果您的模組需要在其 state_dict 中儲存額外狀態,請實現此函式以及相應的 get_extra_state()

引數:

state (dict) – 來自 state_dict 的額外狀態

property standard_normal: bool

locscale 提供的仿射變換是否遵循標準正態方程。

類似於 ObservationNorm 的 standard_normal 屬性。

始終返回 True

to_observation_norm() Compose | ObservationNorm[source]

將 VecNorm 轉換為一個可以在推理時使用的 ObservationNorm 類。

可以使用 state_dict() API 更新 ObservationNorm 層。

示例

>>> from torchrl.envs import GymEnv, VecNorm
>>> vecnorm = VecNorm(in_keys=["observation"])
>>> train_env = GymEnv("CartPole-v1", device=None).append_transform(
...     vecnorm)
>>>
>>> r = train_env.rollout(4)
>>>
>>> eval_env = GymEnv("CartPole-v1").append_transform(
...     vecnorm.to_observation_norm())
>>> print(eval_env.transform.loc, eval_env.transform.scale)
>>>
>>> r = train_env.rollout(4)
>>> # Update entries with state_dict
>>> eval_env.transform.load_state_dict(
...     vecnorm.to_observation_norm().state_dict())
>>> print(eval_env.transform.loc, eval_env.transform.scale)
transform_observation_spec(observation_spec: TensorSpec) TensorSpec[source]

轉換觀察規範,使結果規範與轉換對映匹配。

引數:

observation_spec (TensorSpec) – 轉換前的規範

返回:

轉換後的預期規範

unfreeze() VecNorm[source]

解凍 VecNorm。

請參閱 freeze()

文件

訪問全面的 PyTorch 開發者文件

檢視文件

教程

為初學者和高階開發者提供深入的教程

檢視教程

資源

查詢開發資源並讓您的問題得到解答

檢視資源