UnaryTransform
- class torchrl.envs.transforms.UnaryTransform(in_keys: Sequence[NestedKey], out_keys: Sequence[NestedKey], in_keys_inv: Sequence[NestedKey] | None = None, out_keys_inv: Sequence[NestedKey] | None = None, *, fn: Callable[[Any], Tensor | TensorDictBase], inv_fn: Callable[[Any], Any] | None = None, use_raw_nontensor: bool = False)[source]
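Applies a unary operation to the specified inputs.
- Parameters:
in_keys (sequence of NestedKey) – the keys of the inputs to the unary operation.
out_keys (sequence of NestedKey) – the keys of the outputs of the unary operation.
in_keys_inv (sequence of NestedKey, optional) – the keys of the inputs to the unary operation during the inverse call.
out_keys_inv (sequence of NestedKey, optional) – the keys of the outputs of the unary operation during the inverse call.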
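To make the key pairing concrete, here is a minimal pure-Python sketch, assuming a plain dict in place of a TensorDict. The helper `apply_unary` is hypothetical and is not TorchRL's implementation; it only illustrates the idea that each entry of in_keys is read, passed through the operation, and written under the entry of out_keys at the same position.

```python
from typing import Any, Callable, Dict, Sequence


def apply_unary(
    data: Dict[str, Any],
    in_keys: Sequence[str],
    out_keys: Sequence[str],
    fn: Callable[[Any], Any],
) -> Dict[str, Any]:
    """Hypothetical helper: read each in_key, apply fn, write to the paired out_key."""
    if len(in_keys) != len(out_keys):
        raise ValueError("in_keys and out_keys must have the same length")
    result = dict(data)  # shallow copy so the input mapping is left untouched
    for in_key, out_key in zip(in_keys, out_keys):
        if in_key in data:
            result[out_key] = fn(data[in_key])
    return result


# Illustrative data only; a real environment would provide tensors.
sample = {"observation": [1.0, 2.0, 3.0], "reward": 0.5}
transformed = apply_unary(
    sample,
    in_keys=["observation"],
    out_keys=["observation_trsf"],
    fn=lambda values: str(values),
)
print(transformed["observation_trsf"])
```

Because the output is written under a different key, the original entry stays available next to the transformed one. Reusing the same name for the input and output key would overwrite the entry instead.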
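- Keyword Arguments:
fn (Callable[[Any], Tensor | TensorDictBase]) – the function to use as the unary operation. If it accepts non-tensor inputs, it must also accept None.
inv_fn (Callable[[Any], Any], optional) – the function to use as the unary operation during inverse calls. If it accepts non-tensor inputs, it must also accept None. Can be omitted, in which case fn is used for the inverse mapping.
use_raw_nontensor (bool, optional) – if False, the data is extracted from NonTensorData/NonTensorStack inputs before fn is called. If True, the raw NonTensorData/NonTensorStack inputs are passed directly to fn, which must support them. Defaults to False.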
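The interplay of fn, inv_fn and use_raw_nontensor can be sketched in plain Python. Everything below is an assumption made for illustration: `Wrapped` is a hypothetical stand-in for a non-tensor container such as NonTensorData, and `call_forward` / `call_inverse` are hypothetical helpers, not TorchRL functions. The sketch only mirrors the rules stated above: unwrap before calling fn unless raw inputs are requested, fall back to fn when inv_fn is omitted, and tolerate None when handling non-tensor data.

```python
from dataclasses import dataclass
from typing import Any, Callable, Optional


@dataclass
class Wrapped:
    """Hypothetical stand-in for a non-tensor container such as NonTensorData."""

    data: Any


def call_forward(
    fn: Callable[[Any], Any], value: Any, use_raw_nontensor: bool = False
) -> Any:
    """Unwrap the container before calling fn unless raw inputs are requested."""
    if isinstance(value, Wrapped) and not use_raw_nontensor:
        value = value.data
    return fn(value)


def call_inverse(
    fn: Callable[[Any], Any], inv_fn: Optional[Callable[[Any], Any]], value: Any
) -> Any:
    """Use inv_fn when it is provided; otherwise fall back to fn."""
    op = inv_fn if inv_fn is not None else fn
    return op(value)


def shout(text: Optional[str]) -> Optional[str]:
    # Handles non-tensor (str) inputs, so it also accepts None.
    return None if text is None else text.upper()


unwrapped = call_forward(shout, Wrapped("left"))
raw = call_forward(lambda w: w.data.upper(), Wrapped("left"), use_raw_nontensor=True)
fallback = call_inverse(shout, None, "right")
explicit = call_inverse(shout, str.lower, "RIGHT")
missing = call_forward(shout, None)
print(unwrapped, raw, fallback, explicit, missing)
```

With use_raw_nontensor left at False, the function only ever sees the extracted payload, which keeps simple callables such as `shout` free of container handling.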
Examples
>>> from torchrl.envs import GymEnv, UnaryTransform
>>> env = GymEnv("Pendulum-v1")
>>> env = env.append_transform(
...     UnaryTransform(
...         in_keys=["observation"],
...         out_keys=["observation_trsf"],
...         fn=lambda tensor: str(tensor.numpy().tobytes())))
>>> env.observation_spec
Composite(
    observation: BoundedContinuous(
        shape=torch.Size([3]),
        space=ContinuousBox(
            low=Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, contiguous=True),
            high=Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, contiguous=True)),
        device=cpu,
        dtype=torch.float32,
        domain=continuous),
    observation_trsf: NonTensor(
        shape=torch.Size([]),
        space=None,
        device=cpu,
        dtype=None,
        domain=None),
    device=None,
    shape=torch.Size([]))
>>> env.rollout(3)
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                observation: Tensor(shape=torch.Size([3, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                observation_trsf: NonTensorStack(
                    ["b'\\xbe\\xbc\\x7f?8\\x859=/\\x81\\xbe;'", "b'\\x...,
                    batch_size=torch.Size([3]),
                    device=None),
                reward: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                terminated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                truncated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([3]),
            device=None,
            is_shared=False),
        observation: Tensor(shape=torch.Size([3, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        observation_trsf: NonTensorStack(
            ["b'\\x9a\\xbd\\x7f?\\xb8T8=8.c>'", "b'\\xbe\\xbc\...,
            batch_size=torch.Size([3]),
            device=None),
        terminated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        truncated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)
>>> env.check_env_specs()
[torchrl][INFO] check_env_specs succeeded!
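- transform_action_spec(action_spec: TensorSpec, test_input_spec: TensorSpec) → TensorSpec[source]
Transforms the action spec such that the resulting spec matches the transform mapping.
- Parameters:
action_spec (TensorSpec) – spec before the transform
- Returns:
expected spec after the transform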
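- transform_done_spec(done_spec: TensorSpec, test_output_spec: TensorSpec) → TensorSpec[source]
Transforms the done spec such that the resulting spec matches the transform mapping.
- Parameters:
done_spec (TensorSpec) – spec before the transform
- Returns:
expected spec after the transform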
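- transform_input_spec(input_spec: Composite) → Composite[source]
Transforms the input spec such that the resulting spec matches the transform mapping.
- Parameters:
input_spec (TensorSpec) – spec before the transform
- Returns:
expected spec after the transform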
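- transform_observation_spec(observation_spec: TensorSpec, test_output_spec: TensorSpec) → TensorSpec[source]
Transforms the observation spec such that the resulting spec matches the transform mapping.
- Parameters:
observation_spec (TensorSpec) – spec before the transform
- Returns:
expected spec after the transform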
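- transform_output_spec(output_spec: Composite) → Composite[source]
Transforms the output spec such that the resulting spec matches the transform mapping.
This method should generally be left untouched. Changes should be implemented using transform_observation_spec(), transform_reward_spec() and transform_full_done_spec().
- Parameters:
output_spec (TensorSpec) – spec before the transform
- Returns:
expected spec after the transform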
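The advice to leave transform_output_spec() untouched and override the per-part methods instead can be illustrated with a small dispatch sketch. This is a simplified illustration under stated assumptions, not TorchRL's code: plain dicts stand in for Composite specs, the string values are placeholders for real spec objects, the classes `SketchTransform` and `AddStringObservation` are hypothetical, and the hook signatures omit the extra `test_output_spec` argument shown in the real signatures.

```python
from typing import Any, Dict

Spec = Dict[str, Any]  # plain dicts stand in for Composite specs in this sketch


class SketchTransform:
    """Hypothetical base class, not TorchRL's Transform."""

    def transform_observation_spec(self, spec: Spec) -> Spec:
        return spec

    def transform_reward_spec(self, spec: Spec) -> Spec:
        return spec

    def transform_full_done_spec(self, spec: Spec) -> Spec:
        return spec

    def transform_output_spec(self, output_spec: Spec) -> Spec:
        # Dispatch each part of the output spec to its dedicated hook.
        return {
            "full_observation_spec": self.transform_observation_spec(
                dict(output_spec["full_observation_spec"])
            ),
            "full_reward_spec": self.transform_reward_spec(
                dict(output_spec["full_reward_spec"])
            ),
            "full_done_spec": self.transform_full_done_spec(
                dict(output_spec["full_done_spec"])
            ),
        }


class AddStringObservation(SketchTransform):
    """Overrides only the observation hook, leaving transform_output_spec alone."""

    def transform_observation_spec(self, spec: Spec) -> Spec:
        spec = dict(spec)
        # Placeholder string standing in for a NonTensor spec entry.
        spec["observation_trsf"] = "NonTensor(shape=[])"
        return spec


output_spec = {
    "full_observation_spec": {"observation": "BoundedContinuous(shape=[3])"},
    "full_reward_spec": {"reward": "Unbounded(shape=[1])"},
    "full_done_spec": {"done": "Binary(shape=[1])"},
}
new_spec = AddStringObservation().transform_output_spec(output_spec)
print(sorted(new_spec["full_observation_spec"]))
```

Keeping the dispatch in one place means a subclass touches only the part of the spec it actually changes, which is the design the method description recommends.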
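- transform_reward_spec(reward_spec: TensorSpec, test_output_spec: TensorSpec) → TensorSpec[source]
Transforms the reward spec such that the resulting spec matches the transform mapping.
- Parameters:
reward_spec (TensorSpec) – spec before the transform
- Returns:
expected spec after the transform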
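- transform_state_spec(state_spec: TensorSpec, test_input_spec: TensorSpec) → TensorSpec[source]
Transforms the state spec such that the resulting spec matches the transform mapping.
- Parameters:
state_spec (TensorSpec) – spec before the transform
- Returns:
expected spec after the transform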