DecisionTransformerInferenceWrapper¶
- class torchrl.modules.tensordict_module.DecisionTransformerInferenceWrapper(*args, **kwargs)[原始碼]¶
Decision Transformer 的推理動作包裝器。
一個專門為 Decision Transformer 設計的包裝器,它將遮蔽輸入 tensordict 序列到推理上下文中。輸出將是一個 TensorDict,其鍵與輸入相同,但只包含預測動作序列的最後一個動作和最後一個剩餘獎勵。
此模組建立並返回 tensordict 的修改副本,即它 **不會** 就地修改 tensordict。
注意
如果動作、觀測或獎勵到目標鍵不是標準的,則應使用方法
set_tensor_keys(),例如:>>> dt_inference_wrapper.set_tensor_keys(action="foo", observation="bar", return_to_go="baz")
in_keys 是觀測、動作和剩餘獎勵的鍵。out_keys 匹配 in_keys,並新增策略中的任何其他 out_key(例如,分佈的引數或隱藏值)。
- 引數:
policy (TensorDictModule) – 接收觀測併產生動作值的策略模組
- 關鍵字引數:
inference_context (int) – 將在上下文中不被遮蔽的先前動作的數量。例如,對於形狀為 [batch_size, context, obs_dim] 且 context=20 和 inference_context=5 的觀測輸入,上下文的前 15 個條目將被遮蔽。預設為 5。
spec (Optional[TensorSpec]) – 輸入 TensorDict 的規格。如果為 None,將從策略模組推斷。
device (torch.device, optional) – 如果提供,則指定緩衝區/規格放置的裝置。
示例
>>> import torch >>> from tensordict import TensorDict >>> from tensordict.nn import TensorDictModule >>> from torchrl.modules import ( ... ProbabilisticActor, ... TanhDelta, ... DTActor, ... DecisionTransformerInferenceWrapper, ... ) >>> dtactor = DTActor(state_dim=4, action_dim=2, ... transformer_config=DTActor.default_config() ... ) >>> actor_module = TensorDictModule( ... dtactor, ... in_keys=["observation", "action", "return_to_go"], ... out_keys=["param"]) >>> dist_class = TanhDelta >>> dist_kwargs = { ... "low": -1.0, ... "high": 1.0, ... } >>> actor = ProbabilisticActor( ... in_keys=["param"], ... out_keys=["action"], ... module=actor_module, ... distribution_class=dist_class, ... distribution_kwargs=dist_kwargs) >>> inference_actor = DecisionTransformerInferenceWrapper(actor) >>> sequence_length = 20 >>> td = TensorDict({"observation": torch.randn(1, sequence_length, 4), ... "action": torch.randn(1, sequence_length, 2), ... "return_to_go": torch.randn(1, sequence_length, 1)}, [1,]) >>> result = inference_actor(td) >>> print(result) TensorDict( fields={ action: Tensor(shape=torch.Size([1, 2]), device=cpu, dtype=torch.float32, is_shared=False), observation: Tensor(shape=torch.Size([1, 20, 4]), device=cpu, dtype=torch.float32, is_shared=False), param: Tensor(shape=torch.Size([1, 20, 2]), device=cpu, dtype=torch.float32, is_shared=False), return_to_go: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([1]), device=None, is_shared=False)