快捷方式

ActionMask

class torchrl.envs.transforms.ActionMask(action_key: NestedKey = 'action', mask_key: NestedKey = 'action_mask')[source]

一個自適應動作掩碼器。

此轉換器可用於透過掩碼動作規範來確保隨機生成的動作遵守合法動作。它在執行步驟後從 input tensordict 中讀取掩碼,並調整有限動作規範的掩碼。

注意

此轉換器在未配備環境時使用會失敗。

引數:
  • action_key (NestedKey, optional) – 可找到動作張量的鍵。預設為 "action"

  • mask_key (NestedKey, optional) – 可找到動作掩碼的鍵。預設為 "action_mask"

示例

>>> import torch
>>> from torchrl.data.tensor_specs import Categorical, Binary, Unbounded, Composite
>>> from torchrl.envs.transforms import ActionMask, TransformedEnv
>>> from torchrl.envs.common import EnvBase
>>> class MaskedEnv(EnvBase):
...     def __init__(self, *args, **kwargs):
...         super().__init__(*args, **kwargs)
...         self.action_spec = Categorical(4)
...         self.state_spec = Composite(action_mask=Binary(4, dtype=torch.bool))
...         self.observation_spec = Composite(obs=Unbounded(3))
...         self.reward_spec = Unbounded(1)
...
...     def _reset(self, tensordict=None):
...         td = self.observation_spec.rand()
...         td.update(torch.ones_like(self.state_spec.rand()))
...         return td
...
...     def _step(self, data):
...         td = self.observation_spec.rand()
...         mask = data.get("action_mask")
...         action = data.get("action")
...         mask = mask.scatter(-1, action.unsqueeze(-1), 0)
...
...         td.set("action_mask", mask)
...         td.set("reward", self.reward_spec.rand())
...         td.set("done", ~mask.any().view(1))
...         return td
...
...     def _set_seed(self, seed) -> None:
...         pass
...
>>> torch.manual_seed(0)
>>> base_env = MaskedEnv()
>>> env = TransformedEnv(base_env, ActionMask())
>>> r = env.rollout(10)
>>> r["action_mask"]
tensor([[ True,  True,  True,  True],
        [ True,  True, False,  True],
        [ True,  True, False, False],
        [ True, False, False, False]])
forward(tensordict: TensorDictBase) TensorDictBase[source]

讀取輸入 tensordict,並對選定的鍵應用轉換。

預設情況下,此方法

  • 直接呼叫 _apply_transform()

  • 不呼叫 _step()_call()

此方法不會在任何時候在 env.step 中呼叫。但是,它會在 sample() 中呼叫。

注意

forward 也可以使用 dispatch 將引數名稱轉換為鍵,並使用常規關鍵字引數。

示例

>>> class TransformThatMeasuresBytes(Transform):
...     '''Measures the number of bytes in the tensordict, and writes it under `"bytes"`.'''
...     def __init__(self):
...         super().__init__(in_keys=[], out_keys=["bytes"])
...
...     def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
...         bytes_in_td = tensordict.bytes()
...         tensordict["bytes"] = bytes
...         return tensordict
>>> t = TransformThatMeasuresBytes()
>>> env = env.append_transform(t) # works within envs
>>> t(TensorDict(a=0))  # Works offline too.

文件

訪問全面的 PyTorch 開發者文件

檢視文件

教程

為初學者和高階開發者提供深入的教程

檢視教程

資源

查詢開發資源並讓您的問題得到解答

檢視資源