快捷方式

ConditionalPolicySwitch

class torchrl.envs.transforms.ConditionalPolicySwitch(policy: Callable[[TensorDictBase], TensorDictBase], condition: Callable[[TensorDictBase], bool])[原始碼]

一個根據指定條件有條件地在策略之間切換的轉換。

此轉換會評估環境 step 方法返回的資料的條件。如果滿足條件,它會將指定的策略應用於資料。否則,資料將按原樣返回。這對於需要根據特定標準應用不同策略的場景非常有用,例如在遊戲中輪流進行。

引數:
  • policy (Callable[[TensorDictBase], TensorDictBase]) – 滿足條件時要應用的策略。這應該是一個可呼叫的物件,它接受一個 TensorDictBase 並返回一個 TensorDictBase

  • condition (Callable[[TensorDictBase], bool]) – 一個可呼叫的物件,它接受一個 TensorDictBase 並返回一個布林值或一個張量,指示是否應應用該策略。

警告

此轉換必須有父環境。

注意

理想情況下,它應該是堆疊中的最後一個轉換。如果策略需要轉換後的資料(例如影像),並且此轉換在此類轉換之前應用,則策略將收不到轉換後的資料。

示例

>>> import torch
>>> from tensordict.nn import TensorDictModule as Mod
>>>
>>> from torchrl.envs import GymEnv, ConditionalPolicySwitch, Compose, StepCounter
>>> # Create a CartPole environment. We'll be looking at the obs: if the first element of the obs is greater than
>>> # 0 (left position) we do a right action (action=0) using the switch policy. Otherwise, we use our main
>>> # policy which does a left action.
>>> base_env = GymEnv("CartPole-v1", categorical_action_encoding=True)
>>>
>>> policy = Mod(lambda: torch.ones((), dtype=torch.int64), in_keys=[], out_keys=["action"])
>>> policy_switch = Mod(lambda: torch.zeros((), dtype=torch.int64), in_keys=[], out_keys=["action"])
>>>
>>> cond = lambda td: td.get("observation")[..., 0] >= 0
>>>
>>> env = base_env.append_transform(
...     Compose(
...         # We use two step counters to show that one counts the global steps, whereas the other
...         # only counts the steps where the main policy is executed
...         StepCounter(step_count_key="step_count_total"),
...         ConditionalPolicySwitch(condition=cond, policy=policy_switch),
...         StepCounter(step_count_key="step_count_main"),
...     )
... )
>>>
>>> env.set_seed(0)
>>> torch.manual_seed(0)
>>>
>>> r = env.rollout(100, policy=policy)
>>> print("action", r["action"])
action tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
>>> print("obs", r["observation"])
obs tensor([[ 0.0322, -0.1540,  0.0111,  0.3190],
        [ 0.0299, -0.1544,  0.0181,  0.3280],
        [ 0.0276, -0.1550,  0.0255,  0.3414],
        [ 0.0253, -0.1558,  0.0334,  0.3596],
        [ 0.0230, -0.1569,  0.0422,  0.3828],
        [ 0.0206, -0.1582,  0.0519,  0.4117],
        [ 0.0181, -0.1598,  0.0629,  0.4469],
        [ 0.0156, -0.1617,  0.0753,  0.4891],
        [ 0.0130, -0.1639,  0.0895,  0.5394],
        [ 0.0104, -0.1665,  0.1058,  0.5987],
        [ 0.0076, -0.1696,  0.1246,  0.6685],
        [ 0.0047, -0.1732,  0.1463,  0.7504],
        [ 0.0016, -0.1774,  0.1715,  0.8459],
        [-0.0020,  0.0150,  0.1884,  0.6117],
        [-0.0017,  0.2071,  0.2006,  0.3838]])
>>> print("obs'", r["next", "observation"])
obs' tensor([[ 0.0299, -0.1544,  0.0181,  0.3280],
        [ 0.0276, -0.1550,  0.0255,  0.3414],
        [ 0.0253, -0.1558,  0.0334,  0.3596],
        [ 0.0230, -0.1569,  0.0422,  0.3828],
        [ 0.0206, -0.1582,  0.0519,  0.4117],
        [ 0.0181, -0.1598,  0.0629,  0.4469],
        [ 0.0156, -0.1617,  0.0753,  0.4891],
        [ 0.0130, -0.1639,  0.0895,  0.5394],
        [ 0.0104, -0.1665,  0.1058,  0.5987],
        [ 0.0076, -0.1696,  0.1246,  0.6685],
        [ 0.0047, -0.1732,  0.1463,  0.7504],
        [ 0.0016, -0.1774,  0.1715,  0.8459],
        [-0.0020,  0.0150,  0.1884,  0.6117],
        [-0.0017,  0.2071,  0.2006,  0.3838],
        [ 0.0105,  0.2015,  0.2115,  0.5110]])
>>> print("total step count", r["step_count_total"].squeeze())
total step count tensor([ 1,  3,  5,  7,  9, 11, 13, 15, 17, 19, 21, 23, 25, 26, 27])
>>> print("total step with main policy", r["step_count_main"].squeeze())
total step with main policy tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14])
forward(tensordict: TensorDictBase) Any[原始碼]

讀取輸入 tensordict,並對選定的鍵應用轉換。

預設情況下,此方法

  • 直接呼叫 _apply_transform()

  • 不呼叫 _step()_call()

此方法不會在任何時候在 env.step 中呼叫。但是,它會在 sample() 中呼叫。

注意

forward 也可以使用 dispatch 將引數名稱轉換為鍵,並使用常規關鍵字引數。

示例

>>> class TransformThatMeasuresBytes(Transform):
...     '''Measures the number of bytes in the tensordict, and writes it under `"bytes"`.'''
...     def __init__(self):
...         super().__init__(in_keys=[], out_keys=["bytes"])
...
...     def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
...         bytes_in_td = tensordict.bytes()
...         tensordict["bytes"] = bytes
...         return tensordict
>>> t = TransformThatMeasuresBytes()
>>> env = env.append_transform(t) # works within envs
>>> t(TensorDict(a=0))  # Works offline too.

文件

訪問全面的 PyTorch 開發者文件

檢視文件

教程

為初學者和高階開發者提供深入的教程

檢視教程

資源

查詢開發資源並讓您的問題得到解答

檢視資源