快捷方式

DTypeCastTransform

class torchrl.envs.transforms.DTypeCastTransform(dtype_in: torch.dtype, dtype_out: torch.dtype, in_keys: Sequence[NestedKey] | None = None, out_keys: Sequence[NestedKey] | None = None, in_keys_inv: Sequence[NestedKey] | None = None, out_keys_inv: Sequence[NestedKey] | None = None)[原始碼]

將一個 dtype 轉換為另一個 dtype,針對選定的鍵。

根據構造時是否提供 in_keysin_keys_inv,類的行為將有所不同

  • 如果提供了鍵,則只有這些條目將從 dtype_in 轉換為 dtype_out 條目;

  • 如果未提供鍵且該物件位於轉換的環境登錄檔中,則將具有 dtype_in 資料型別的輸入和輸出規範將分別用作 in_keys_inv / in_keys。

  • 如果未提供鍵且物件在沒有環境的情況下使用,則 forward / inverse 傳遞將掃描輸入 tensordict 中所有 dtype_in 值,並將它們對映到 dtype_out 張量。對於大型資料結構,這可能會影響效能,因為這種掃描並非沒有成本。將被轉換的鍵不會被快取。請注意,在這種情況下,不能傳遞 out_keys(或 out_keys_inv),因為無法精確預測處理鍵的順序。

引數:
  • dtype_in (torch.dtype) – 輸入資料型別(來自環境)。

  • dtype_out (torch.dtype) – 輸出資料型別(用於模型訓練)。

  • in_keys (NestedKey 的序列, 可選) – 要轉換為 dtype_outdtype_in 鍵列表,這些鍵將在暴露給外部物件和函式之前進行轉換。

  • out_keys (NestedKey 的序列, 可選) – 目標鍵列表。如果未提供,則預設為 in_keys

  • in_keys_inv (NestedKey 的序列, 可選) – 要轉換為 dtype_indtype_out 鍵列表,這些鍵將在傳遞給包含的基礎環境或儲存之前進行轉換。

  • out_keys_inv (NestedKey 的序列, 可選) – 逆轉換的目標鍵列表。如果未提供,則預設為 in_keys_inv

示例

>>> td = TensorDict(
...     {'obs': torch.ones(1, dtype=torch.double),
...     'not_transformed': torch.ones(1, dtype=torch.double),
... }, [])
>>> transform = DTypeCastTransform(torch.double, torch.float, in_keys=["obs"])
>>> _ = transform(td)
>>> print(td.get("obs").dtype)
torch.float32
>>> print(td.get("not_transformed").dtype)
torch.float64

在“自動”模式下,所有 float64 條目都會被轉換

示例

>>> td = TensorDict(
...     {'obs': torch.ones(1, dtype=torch.double),
...     'not_transformed': torch.ones(1, dtype=torch.double),
... }, [])
>>> transform = DTypeCastTransform(torch.double, torch.float)
>>> _ = transform(td)
>>> print(td.get("obs").dtype)
torch.float32
>>> print(td.get("not_transformed").dtype)
torch.float32

在不指定轉換鍵的情況下構造環境時,也遵循相同的行為規則

示例

>>> class MyEnv(EnvBase):
...     def __init__(self):
...         super().__init__()
...         self.observation_spec = Composite(obs=Unbounded((), dtype=torch.float64))
...         self.action_spec = Unbounded((), dtype=torch.float64)
...         self.reward_spec = Unbounded((1,), dtype=torch.float64)
...         self.done_spec = Unbounded((1,), dtype=torch.bool)
...     def _reset(self, data=None):
...         return TensorDict({"done": torch.zeros((1,), dtype=torch.bool), **self.observation_spec.rand()}, [])
...     def _step(self, data):
...         assert data["action"].dtype == torch.float64
...         reward = self.reward_spec.rand()
...         done = torch.zeros((1,), dtype=torch.bool)
...         obs = self.observation_spec.rand()
...         assert reward.dtype == torch.float64
...         assert obs["obs"].dtype == torch.float64
...         return obs.empty().set("next", obs.update({"reward": reward, "done": done}))
...     def _set_seed(self, seed) -> None:
...         pass
>>> env = TransformedEnv(MyEnv(), DTypeCastTransform(torch.double, torch.float))
>>> assert env.action_spec.dtype == torch.float32
>>> assert env.observation_spec["obs"].dtype == torch.float32
>>> assert env.reward_spec.dtype == torch.float32, env.reward_spec.dtype
>>> print(env.rollout(2))
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                obs: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
                reward: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([2]),
            device=cpu,
            is_shared=False),
        obs: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([2]),
    device=cpu,
    is_shared=False)
>>> assert env.transform.in_keys == ["obs", "reward"]
>>> assert env.transform.in_keys_inv == ["action"]
forward(tensordict: TensorDictBase = None) TensorDictBase[原始碼]

讀取輸入 tensordict,並對選定的鍵應用轉換。

transform_input_spec(input_spec: TensorSpec) TensorSpec[原始碼]

轉換輸入規範,使結果規範與轉換對映匹配。

引數:

input_spec (TensorSpec) – 轉換前的規範

返回:

轉換後的預期規範

transform_observation_spec(observation_spec)[原始碼]

轉換觀察規範,使結果規範與轉換對映匹配。

引數:

observation_spec (TensorSpec) – 轉換前的規範

返回:

轉換後的預期規範

transform_output_spec(output_spec: Composite) Composite[原始碼]

轉換輸出規範,使結果規範與轉換對映匹配。

此方法通常應保持不變。應使用 transform_observation_spec()transform_reward_spec()transform_full_done_spec() 來實現更改。 :param output_spec: 轉換前的規範 :type output_spec: TensorSpec

返回:

轉換後的預期規範

文件

訪問全面的 PyTorch 開發者文件

檢視文件

教程

為初學者和高階開發者提供深入的教程

檢視教程

資源

查詢開發資源並讓您的問題得到解答

檢視資源