快捷方式

ConditionalSkip

class torchrl.envs.transforms.ConditionalSkip(cond: Callable[[TensorDict], bool | torch.Tensor])[原始碼]

當滿足特定條件時,跳過環境中步驟的轉換。

此轉換將 `cond(tensordict)` 的結果寫入傳遞給 `TransformedEnv.base_env._step` 方法的 tensordict 的 “_step” 條目中。如果 `base_env` 不是批處理鎖定的(一般來說,它是無狀態的),則 tensordict 將被縮減到需要透過步驟的元素。如果它是批處理鎖定的(一般來說,它是狀態化的),如果 “_step” 中的任何值不是 `True`,則完全跳過該步驟。否則,可以信任環境會相應地處理 “_step” 訊號。

注意

跳過操作也會影響修改環境輸出的轉換,即,如果滿足條件,任何將在 `step()` 返回的 tensordict 上執行的轉換都將被跳過。如果此效果不理想,可以透過將轉換後的環境包裝在另一個轉換後的環境中來緩解,因為跳過操作只會影響 `ConditionalSkip` 轉換的第一級父級。請參閱下面的示例。

引數:

cond (Callable[[TensorDictBase], bool | torch.Tensor]) – 一個可呼叫物件,用於輸入 tensordict,它檢查下一個環境步驟是否必須被跳過(`True` = 跳過,`False` = 執行 `env.step`)。

示例

>>> import torch
>>>
>>> from torchrl.envs import GymEnv
>>> from torchrl.envs.transforms.transforms import ConditionalSkip, StepCounter, TransformedEnv, Compose
>>>
>>> torch.manual_seed(0)
>>>
>>> base_env = TransformedEnv(
...     GymEnv("Pendulum-v1"),
...     StepCounter(step_count_key="inner_count"),
... )
>>> middle_env = TransformedEnv(
...     base_env,
...     Compose(
...         StepCounter(step_count_key="middle_count"),
...         ConditionalSkip(cond=lambda td: td["step_count"] % 2 == 1),
...     ),
...     auto_unwrap=False)  # makes sure that transformed envs are properly wrapped
>>> env = TransformedEnv(
...     middle_env,
...     StepCounter(step_count_key="step_count"),
...     auto_unwrap=False)
>>> env.set_seed(0)
>>>
>>> r = env.rollout(10)
>>> print(r["observation"])
tensor([[-0.9670, -0.2546, -0.9669],
        [-0.9802, -0.1981, -1.1601],
        [-0.9802, -0.1981, -1.1601],
        [-0.9926, -0.1214, -1.5556],
        [-0.9926, -0.1214, -1.5556],
        [-0.9994, -0.0335, -1.7622],
        [-0.9994, -0.0335, -1.7622],
        [-0.9984,  0.0561, -1.7933],
        [-0.9984,  0.0561, -1.7933],
        [-0.9895,  0.1445, -1.7779]])
>>> print(r["inner_count"])
tensor([[0],
        [1],
        [1],
        [2],
        [2],
        [3],
        [3],
        [4],
        [4],
        [5]])
>>> print(r["middle_count"])
tensor([[0],
        [1],
        [1],
        [2],
        [2],
        [3],
        [3],
        [4],
        [4],
        [5]])
>>> print(r["step_count"])
tensor([[0],
        [1],
        [2],
        [3],
        [4],
        [5],
        [6],
        [7],
        [8],
        [9]])
forward(tensordict: TensorDictBase) TensorDictBase[原始碼]

讀取輸入 tensordict,並對選定的鍵應用轉換。

預設情況下,此方法

  • 直接呼叫 _apply_transform()

  • 不呼叫 _step()_call()

此方法不會在任何時候在 env.step 中呼叫。但是,它會在 sample() 中呼叫。

注意

forward 也可以使用 dispatch 將引數名稱轉換為鍵,並使用常規關鍵字引數。

示例

>>> class TransformThatMeasuresBytes(Transform):
...     '''Measures the number of bytes in the tensordict, and writes it under `"bytes"`.'''
...     def __init__(self):
...         super().__init__(in_keys=[], out_keys=["bytes"])
...
...     def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
...         bytes_in_td = tensordict.bytes()
...         tensordict["bytes"] = bytes
...         return tensordict
>>> t = TransformThatMeasuresBytes()
>>> env = env.append_transform(t) # works within envs
>>> t(TensorDict(a=0))  # Works offline too.

文件

訪問全面的 PyTorch 開發者文件

檢視文件

教程

為初學者和高階開發者提供深入的教程

檢視教程

資源

查詢開發資源並讓您的問題得到解答

檢視資源