QValueActor¶
- class torchrl.modules.tensordict_module.QValueActor(*args, **kwargs)[原始碼]¶
一個 Q 值 Actor 類。
此類在輸入模組後附加一個
QValueModule,以便使用動作值來選擇動作。- 引數:
module (nn.Module) – 一個
torch.nn.Module,用於將輸入對映到輸出引數空間。如果提供的類與tensordict.nn.TensorDictModuleBase不相容,它將被包裝在一個tensordict.nn.TensorDictModule中,並使用以下關鍵字引數指定的in_keys。- 關鍵字引數:
in_keys (iterable of str, optional) – 如果提供的類與
tensordict.nn.TensorDictModuleBase不相容,此鍵列表指示需要傳遞給包裝模組的觀測值,以獲取動作值。預設為["observation"]。spec (TensorSpec, optional) – 僅關鍵字引數。輸出 tensor 的 spec。如果 module 輸出多個 tensor,spec 表徵第一個輸出 tensor 的空間。
safe (bool) – 僅關鍵字引數。如果為
True,則輸出值將針對輸入規範進行檢查。由於探索策略或數值溢位/下溢問題,可能會發生域外取樣。如果此值超出範圍,則使用TensorSpec.project方法將其投影回所需空間。預設為False。action_space (str, optional) – 動作空間。必須是
"one-hot"、"mult-one-hot"、"binary"或"categorical"之一。此引數與spec互斥,因為spec條件化了 action_space。action_value_key (str or tuple of str, optional) – 如果輸入模組是
tensordict.nn.TensorDictModuleBase例項,則它必須匹配其輸出鍵之一。否則,此字串表示輸出 tensordict 中動作值條目的名稱。action_mask_key (str or tuple of str, optional) – 表示動作掩碼的輸入鍵。預設為
"None"(相當於沒有掩碼)。
注意
out_keys不能傳遞。如果模組是tensordict.nn.TensorDictModule例項,則 out_keys 將相應更新。對於常規torch.nn.Module例項,將使用三元組["action", action_value_key, "chosen_action_value"]。示例
>>> import torch >>> from tensordict import TensorDict >>> from torch import nn >>> from torchrl.data import OneHot >>> from torchrl.modules.tensordict_module.actors import QValueActor >>> td = TensorDict({'observation': torch.randn(5, 4)}, [5]) >>> # with a regular nn.Module >>> module = nn.Linear(4, 4) >>> action_spec = OneHot(4) >>> qvalue_actor = QValueActor(module=module, spec=action_spec) >>> td = qvalue_actor(td) >>> print(td) TensorDict( fields={ action: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.int64, is_shared=False), action_value: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.float32, is_shared=False), chosen_action_value: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.float32, is_shared=False), observation: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([5]), device=None, is_shared=False) >>> # with a TensorDictModule >>> td = TensorDict({'obs': torch.randn(5, 4)}, [5]) >>> module = TensorDictModule(lambda x: x, in_keys=["obs"], out_keys=["action_value"]) >>> action_spec = OneHot(4) >>> qvalue_actor = QValueActor(module=module, spec=action_spec) >>> td = qvalue_actor(td) >>> print(td) TensorDict( fields={ action: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.int64, is_shared=False), action_value: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.float32, is_shared=False), chosen_action_value: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.float32, is_shared=False), obs: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([5]), device=None, is_shared=False)