DistributionalQValueActor¶
- class torchrl.modules.tensordict_module.DistributionalQValueActor(*args, **kwargs)[原始碼]¶
一個分散式的 DQN Actor 類。
此類在輸入模組之後附加一個
QValueModule,以便動作值用於選擇一個動作。- 引數:
module (nn.Module) – 一個
torch.nn.Module,用於將輸入對映到輸出引數空間。如果模組不是torchrl.modules.DistributionalDQNnet型別,DistributionalQValueActor將確保沿維度-2對動作值張量應用 log-softmax 操作。這可以透過關閉make_log_softmax關鍵字引數來停用。- 關鍵字引數:
in_keys (iterable of str, optional) – 要從輸入 tensordict 中讀取並傳遞給 module 的鍵。如果包含多個元素,則值將按 in_keys 可迭代物件給出的順序傳遞。預設為
["observation"]。spec (TensorSpec, optional) – 僅關鍵字引數。輸出 tensor 的 spec。如果 module 輸出多個 tensor,spec 表徵第一個輸出 tensor 的空間。
safe (bool) – 僅關鍵字引數。如果為
True,則會針對輸入規範檢查輸出值。由於探索策略或數值下溢/溢位問題,可能會發生域外取樣。如果此值越界,則使用TensorSpec.project方法將其投影回所需空間。預設為False。var_nums (int, optional) – 如果
action_space = "mult-one-hot",此值表示每個動作分量的基數。support (torch.Tensor) – 動作值的支援集。
action_space (str, optional) – 動作空間。必須是
"one-hot"、"mult-one-hot"、"binary"或"categorical"之一。此引數與spec互斥,因為spec條件化了 action_space。make_log_softmax (bool, optional) – 如果為
True,並且模組不是torchrl.modules.DistributionalDQNnet型別,則沿動作值張量的維度 -2 應用 log-softmax 操作。action_value_key (str 或 tuple of str, optional) – 如果輸入模組是
tensordict.nn.TensorDictModuleBase例項,則它必須與其輸出鍵之一匹配。否則,此字串表示輸出 tensordict 中動作值條目的名稱。action_mask_key (str or tuple of str, optional) – 表示動作掩碼的輸入鍵。預設為
"None"(相當於沒有掩碼)。
示例
>>> import torch >>> from tensordict import TensorDict >>> from tensordict.nn import TensorDictModule, TensorDictSequential >>> from torch import nn >>> from torchrl.data import OneHot >>> from torchrl.modules import DistributionalQValueActor, MLP >>> td = TensorDict({'observation': torch.randn(5, 4)}, [5]) >>> nbins = 3 >>> module = MLP(out_features=(nbins, 4), depth=2) >>> # let us make sure that the output is a log-softmax >>> module = TensorDictSequential( ... TensorDictModule(module, ["observation"], ["action_value"]), ... TensorDictModule(lambda x: x.log_softmax(-2), ["action_value"], ["action_value"]), ... ) >>> action_spec = OneHot(4) >>> qvalue_actor = DistributionalQValueActor( ... module=module, ... spec=action_spec, ... support=torch.arange(nbins)) >>> td = qvalue_actor(td) >>> print(td) TensorDict( fields={ action: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.int64, is_shared=False), action_value: Tensor(shape=torch.Size([5, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False), observation: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([5]), device=None, is_shared=False)