快捷方式

UnaryTransform

class torchrl.envs.transforms.UnaryTransform(in_keys: Sequence[NestedKey], out_keys: Sequence[NestedKey], in_keys_inv: Sequence[NestedKey] | None = None, out_keys_inv: Sequence[NestedKey] | None = None, *, fn: Callable[[Any], Tensor | TensorDictBase], inv_fn: Callable[[Any], Any] | None = None, use_raw_nontensor: bool = False)[原始碼]

對指定的輸入應用一元操作。

引數:
  • in_keys (NestedKey 序列) – 一元操作輸入的鍵。

  • out_keys (NestedKey 序列) – 一元操作輸出的鍵。

  • in_keys_inv (NestedKey 序列, 可選) – 在反向呼叫期間一元操作輸入的鍵。

  • out_keys_inv (NestedKey 序列, 可選) – 在反向呼叫期間一元操作輸出的鍵。

關鍵字引數:
  • fn (Callable[[Any], Tensor | TensorDictBase]) – 用作一元操作的函式。如果它接受非張量輸入,則還必須接受 None

  • inv_fn (Callable[[Any], Any], 可選) – 在反向呼叫期間用作一元操作的函式。如果它接受非張量輸入,則還必須接受 None。可以省略,在這種情況下將使用 fn 進行反向對映。

  • use_raw_nontensor (bool, 可選) – 如果為 False,則在呼叫 fn 之前,從 NonTensorData/NonTensorStack 輸入中提取資料。如果為 True,則直接將原始 NonTensorData/NonTensorStack 輸入提供給 fn,後者必須支援這些輸入。預設為 False

示例

>>> from torchrl.envs import GymEnv, UnaryTransform
>>> env = GymEnv("Pendulum-v1")
>>> env = env.append_transform(
...     UnaryTransform(
...         in_keys=["observation"],
...         out_keys=["observation_trsf"],
...             fn=lambda tensor: str(tensor.numpy().tobytes())))
>>> env.observation_spec
Composite(
    observation: BoundedContinuous(
        shape=torch.Size([3]),
        space=ContinuousBox(
            low=Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, contiguous=True),
            high=Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, contiguous=True)),
        device=cpu,
        dtype=torch.float32,
        domain=continuous),
    observation_trsf: NonTensor(
        shape=torch.Size([]),
        space=None,
        device=cpu,
        dtype=None,
        domain=None),
    device=None,
    shape=torch.Size([]))
>>> env.rollout(3)
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                observation: Tensor(shape=torch.Size([3, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                observation_trsf: NonTensorStack(
                    ["b'\\xbe\\xbc\\x7f?8\\x859=/\\x81\\xbe;'", "b'\\x...,
                    batch_size=torch.Size([3]),
                    device=None),
                reward: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                terminated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                truncated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([3]),
            device=None,
            is_shared=False),
        observation: Tensor(shape=torch.Size([3, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        observation_trsf: NonTensorStack(
            ["b'\\x9a\\xbd\\x7f?\\xb8T8=8.c>'", "b'\\xbe\\xbc\...,
            batch_size=torch.Size([3]),
            device=None),
        terminated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        truncated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)
>>> env.check_env_specs()
[torchrl][INFO] check_env_specs succeeded!
transform_action_spec(action_spec: TensorSpec, test_input_spec: TensorSpec) TensorSpec[原始碼]

轉換動作規範,使結果規範與變換對映匹配。

引數:

action_spec (TensorSpec) – 變換前的規範

返回:

轉換後的預期規範

transform_done_spec(done_spec: TensorSpec, test_output_spec: TensorSpec) TensorSpec[原始碼]

變換 done spec,使結果 spec 與變換對映匹配。

引數:

done_spec (TensorSpec) – 變換前的 spec

返回:

轉換後的預期規範

transform_input_spec(input_spec: Composite) Composite[原始碼]

轉換輸入規範,使結果規範與轉換對映匹配。

引數:

input_spec (TensorSpec) – 轉換前的規範

返回:

轉換後的預期規範

transform_observation_spec(observation_spec: TensorSpec, test_output_spec: TensorSpec) TensorSpec[原始碼]

轉換觀察規範,使結果規範與轉換對映匹配。

引數:

observation_spec (TensorSpec) – 轉換前的規範

返回:

轉換後的預期規範

transform_output_spec(output_spec: Composite) Composite[原始碼]

轉換輸出規範,使結果規範與轉換對映匹配。

此方法通常應保持不變。更改應透過 transform_observation_spec()transform_reward_spec()transform_full_done_spec() 實現。 :param output_spec: 轉換前的 spec :type output_spec: TensorSpec

返回:

轉換後的預期規範

transform_reward_spec(reward_spec: TensorSpec, test_output_spec: TensorSpec) TensorSpec[原始碼]

轉換獎勵的 spec,使其與變換對映匹配。

引數:

reward_spec (TensorSpec) – 變換前的 spec

返回:

轉換後的預期規範

transform_state_spec(state_spec: TensorSpec, test_input_spec: TensorSpec) TensorSpec[原始碼]

轉換狀態規範,使結果規範與變換對映匹配。

引數:

state_spec (TensorSpec) – 變換前的規範

返回:

轉換後的預期規範

文件

訪問全面的 PyTorch 開發者文件

檢視文件

教程

為初學者和高階開發者提供深入的教程

檢視教程

資源

查詢開發資源並讓您的問題得到解答

檢視資源