CatTensors¶
- class torchrl.envs.transforms.CatTensors(in_keys: Sequence[NestedKey] | None = None, out_key: NestedKey = 'observation_vector', dim: int = - 1, *, del_keys: bool = True, unsqueeze_if_oor: bool = False, sort: bool = True)[原始碼]¶
將多個鍵連線成一個張量。
這在多個鍵描述一個狀態(例如,“observation_position”和“observation_velocity”)時特別有用。
- 引數:
in_keys (NestedKey 序列) – 要連線的鍵。如果為 None(或未提供),則在轉換首次使用時將從父環境中檢索鍵。只有當設定了父環境時,此行為才會起作用。
out_key (NestedKey) – 結果張量的鍵。
dim (int, optional) – 連線將發生的維度。預設為
-1。
- 關鍵字引數:
del_keys (bool, optional) – 如果為
True,則輸入值將在連線後被刪除。預設為True。unsqueeze_if_oor (bool, optional) – 如果為
True,CatTensor 將檢查要連線的張量是否存在指定的維度。如果不存在,張量將在該維度上進行 unsqueeze。預設為False。sort (bool, optional) – 如果為
True,則將在轉換中對鍵進行排序。否則,將優先使用使用者提供的順序。預設為True。
示例
>>> transform = CatTensors(in_keys=["key1", "key2"]) >>> td = TensorDict({"key1": torch.zeros(1, 1), ... "key2": torch.ones(1, 1)}, [1]) >>> _ = transform(td) >>> print(td.get("observation_vector")) tensor([[0., 1.]]) >>> transform = CatTensors(in_keys=["key1", "key2"], dim=-2, unsqueeze_if_oor=True) >>> td = TensorDict({"key1": torch.zeros(1), ... "key2": torch.ones(1)}, []) >>> _ = transform(td) >>> print(td.get("observation_vector").shape) torch.Size([2, 1])
- forward(next_tensordict: TensorDictBase) TensorDictBase¶
讀取輸入 tensordict,並對選定的鍵應用轉換。
預設情況下,此方法
直接呼叫
_apply_transform()。不呼叫
_step()或_call()。
此方法不會在任何時候在 env.step 中呼叫。但是,它會在
sample()中呼叫。注意
forward也可以使用dispatch將引數名稱轉換為鍵,並使用常規關鍵字引數。示例
>>> class TransformThatMeasuresBytes(Transform): ... '''Measures the number of bytes in the tensordict, and writes it under `"bytes"`.''' ... def __init__(self): ... super().__init__(in_keys=[], out_keys=["bytes"]) ... ... def forward(self, tensordict: TensorDictBase) -> TensorDictBase: ... bytes_in_td = tensordict.bytes() ... tensordict["bytes"] = bytes ... return tensordict >>> t = TransformThatMeasuresBytes() >>> env = env.append_transform(t) # works within envs >>> t(TensorDict(a=0)) # Works offline too.
- transform_observation_spec(observation_spec: TensorSpec) TensorSpec[原始碼]¶
轉換觀察規範,使結果規範與轉換對映匹配。
- 引數:
observation_spec (TensorSpec) – 轉換前的規範
- 返回:
轉換後的預期規範