ActionMask¶
- class torchrl.envs.transforms.ActionMask(action_key: NestedKey = 'action', mask_key: NestedKey = 'action_mask')[source]¶
一個自適應動作掩碼器。
此轉換器可用於透過掩碼動作規範來確保隨機生成的動作遵守合法動作。它在執行步驟後從 input tensordict 中讀取掩碼,並調整有限動作規範的掩碼。
注意
此轉換器在未配備環境時使用會失敗。
- 引數:
action_key (NestedKey, optional) – 可找到動作張量的鍵。預設為
"action"。mask_key (NestedKey, optional) – 可找到動作掩碼的鍵。預設為
"action_mask"。
示例
>>> import torch >>> from torchrl.data.tensor_specs import Categorical, Binary, Unbounded, Composite >>> from torchrl.envs.transforms import ActionMask, TransformedEnv >>> from torchrl.envs.common import EnvBase >>> class MaskedEnv(EnvBase): ... def __init__(self, *args, **kwargs): ... super().__init__(*args, **kwargs) ... self.action_spec = Categorical(4) ... self.state_spec = Composite(action_mask=Binary(4, dtype=torch.bool)) ... self.observation_spec = Composite(obs=Unbounded(3)) ... self.reward_spec = Unbounded(1) ... ... def _reset(self, tensordict=None): ... td = self.observation_spec.rand() ... td.update(torch.ones_like(self.state_spec.rand())) ... return td ... ... def _step(self, data): ... td = self.observation_spec.rand() ... mask = data.get("action_mask") ... action = data.get("action") ... mask = mask.scatter(-1, action.unsqueeze(-1), 0) ... ... td.set("action_mask", mask) ... td.set("reward", self.reward_spec.rand()) ... td.set("done", ~mask.any().view(1)) ... return td ... ... def _set_seed(self, seed) -> None: ... pass ... >>> torch.manual_seed(0) >>> base_env = MaskedEnv() >>> env = TransformedEnv(base_env, ActionMask()) >>> r = env.rollout(10) >>> r["action_mask"] tensor([[ True, True, True, True], [ True, True, False, True], [ True, True, False, False], [ True, False, False, False]])
- forward(tensordict: TensorDictBase) TensorDictBase[source]¶
讀取輸入 tensordict,並對選定的鍵應用轉換。
預設情況下,此方法
直接呼叫
_apply_transform()。不呼叫
_step()或_call()。
此方法不會在任何時候在 env.step 中呼叫。但是,它會在
sample()中呼叫。注意
forward也可以使用dispatch將引數名稱轉換為鍵,並使用常規關鍵字引數。示例
>>> class TransformThatMeasuresBytes(Transform): ... '''Measures the number of bytes in the tensordict, and writes it under `"bytes"`.''' ... def __init__(self): ... super().__init__(in_keys=[], out_keys=["bytes"]) ... ... def forward(self, tensordict: TensorDictBase) -> TensorDictBase: ... bytes_in_td = tensordict.bytes() ... tensordict["bytes"] = bytes ... return tensordict >>> t = TransformThatMeasuresBytes() >>> env = env.append_transform(t) # works within envs >>> t(TensorDict(a=0)) # Works offline too.