ConditionalPolicySwitch
- class torchrl.envs.transforms.ConditionalPolicySwitch(policy: Callable[[TensorDictBase], TensorDictBase], condition: Callable[[TensorDictBase], bool])[source]
A transform that conditionally switches between policies based on a specified condition.
This transform evaluates a condition on the data returned by the environment's step method. If the condition is met, it applies the specified policy to that data; otherwise, the data is returned unchanged. This is useful in scenarios where different policies must be applied based on specific criteria, such as alternating turns in a game.
- Parameters:
policy (Callable[[TensorDictBase], TensorDictBase]) – the policy to apply when the condition is met. This should be a callable that takes a TensorDictBase and returns a TensorDictBase.
condition (Callable[[TensorDictBase], bool]) – a callable that takes a TensorDictBase and returns a boolean or a tensor indicating whether the policy should be applied.
Warning
This transform must have a parent environment.
Note
Ideally, this should be the last transform in the stack. If the policy requires transformed data (e.g., images) and this transform is applied before the transforms that produce that data, the policy will not receive the transformed data.
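Below is a minimal sketch of the recommended ordering for a pixel-based environment; the environment id and the cond and pixel_policy callables are illustrative placeholders, not part of this class's API. The image transforms come first so that the switched policy receives the transformed "pixels" entry:
>>> from torchrl.envs import GymEnv, Compose, ToTensorImage, Resize, ConditionalPolicySwitch
>>> env = GymEnv("ALE/Pong-v5", from_pixels=True).append_transform(
...     Compose(
...         ToTensorImage(),   # converts the raw "pixels" array to a float tensor
...         Resize(84, 84),
...         # placed last in the stack, so `pixel_policy` sees the resized image tensors
...         ConditionalPolicySwitch(condition=cond, policy=pixel_policy),
...     )
... )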
Examples
>>> import torch
>>> from tensordict.nn import TensorDictModule as Mod
>>>
>>> from torchrl.envs import GymEnv, ConditionalPolicySwitch, Compose, StepCounter
>>> # Create a CartPole environment. We'll be looking at the obs: if the first element of the obs is greater than
>>> # 0 (left position) we do a right action (action=0) using the switch policy. Otherwise, we use our main
>>> # policy which does a left action.
>>> base_env = GymEnv("CartPole-v1", categorical_action_encoding=True)
>>>
>>> policy = Mod(lambda: torch.ones((), dtype=torch.int64), in_keys=[], out_keys=["action"])
>>> policy_switch = Mod(lambda: torch.zeros((), dtype=torch.int64), in_keys=[], out_keys=["action"])
>>>
>>> cond = lambda td: td.get("observation")[..., 0] >= 0
>>>
>>> env = base_env.append_transform(
...     Compose(
...         # We use two step counters to show that one counts the global steps, whereas the other
...         # only counts the steps where the main policy is executed
...         StepCounter(step_count_key="step_count_total"),
...         ConditionalPolicySwitch(condition=cond, policy=policy_switch),
...         StepCounter(step_count_key="step_count_main"),
...     )
... )
>>>
>>> env.set_seed(0)
>>> torch.manual_seed(0)
>>>
>>> r = env.rollout(100, policy=policy)
>>> print("action", r["action"])
action tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
>>> print("obs", r["observation"])
obs tensor([[ 0.0322, -0.1540, 0.0111, 0.3190],
        [ 0.0299, -0.1544, 0.0181, 0.3280],
        [ 0.0276, -0.1550, 0.0255, 0.3414],
        [ 0.0253, -0.1558, 0.0334, 0.3596],
        [ 0.0230, -0.1569, 0.0422, 0.3828],
        [ 0.0206, -0.1582, 0.0519, 0.4117],
        [ 0.0181, -0.1598, 0.0629, 0.4469],
        [ 0.0156, -0.1617, 0.0753, 0.4891],
        [ 0.0130, -0.1639, 0.0895, 0.5394],
        [ 0.0104, -0.1665, 0.1058, 0.5987],
        [ 0.0076, -0.1696, 0.1246, 0.6685],
        [ 0.0047, -0.1732, 0.1463, 0.7504],
        [ 0.0016, -0.1774, 0.1715, 0.8459],
        [-0.0020, 0.0150, 0.1884, 0.6117],
        [-0.0017, 0.2071, 0.2006, 0.3838]])
>>> print("obs'", r["next", "observation"])
obs' tensor([[ 0.0299, -0.1544, 0.0181, 0.3280],
        [ 0.0276, -0.1550, 0.0255, 0.3414],
        [ 0.0253, -0.1558, 0.0334, 0.3596],
        [ 0.0230, -0.1569, 0.0422, 0.3828],
        [ 0.0206, -0.1582, 0.0519, 0.4117],
        [ 0.0181, -0.1598, 0.0629, 0.4469],
        [ 0.0156, -0.1617, 0.0753, 0.4891],
        [ 0.0130, -0.1639, 0.0895, 0.5394],
        [ 0.0104, -0.1665, 0.1058, 0.5987],
        [ 0.0076, -0.1696, 0.1246, 0.6685],
        [ 0.0047, -0.1732, 0.1463, 0.7504],
        [ 0.0016, -0.1774, 0.1715, 0.8459],
        [-0.0020, 0.0150, 0.1884, 0.6117],
        [-0.0017, 0.2071, 0.2006, 0.3838],
        [ 0.0105, 0.2015, 0.2115, 0.5110]])
>>> print("total step count", r["step_count_total"].squeeze())
total step count tensor([ 1,  3,  5,  7,  9, 11, 13, 15, 17, 19, 21, 23, 25, 26, 27])
>>> print("total step with main policy", r["step_count_main"].squeeze())
total step with main policy tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14])
- forward(tensordict: TensorDictBase) → Any[source]
Reads the input tensordict and, for the selected keys, applies the transform.
By default, this method:
- calls _apply_transform() directly,
- does not call _step() or _call().
This method is not called within env.step at any point. However, it is called within sample().
Note
forward also works with regular keyword arguments, using dispatch to cast the argument names to keys.
Examples
>>> class TransformThatMeasuresBytes(Transform):
...     '''Measures the number of bytes in the tensordict, and writes it under `"bytes"`.'''
...     def __init__(self):
...         super().__init__(in_keys=[], out_keys=["bytes"])
...
...     def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
...         bytes_in_td = tensordict.bytes()
...         tensordict["bytes"] = bytes_in_td
...         return tensordict
>>> t = TransformThatMeasuresBytes()
>>> env = env.append_transform(t)  # works within envs
>>> t(TensorDict(a=0))  # Works offline too.