OnlineDTActor¶
- class torchrl.modules.OnlineDTActor(state_dim: int, action_dim: int, transformer_config: dict | DecisionTransformer.DTConfig = None, device: DEVICE_TYPING | None = None)[來源]¶
Online Decision Transformer Actor 類。
用於 Online Decision Transformer 的 Actor 類,用於從高斯分佈中取樣動作,如 “Online Decision Transformer” 中所述。
返回用於從高斯分佈中取樣動作的均值和標準差。
- 引數:
state_dim (int) – 狀態維度。
action_dim (int) – 動作維度。
transformer_config (Dict 或
DecisionTransformer.DTConfig) – GPT2 transformer 的配置。預設為default_config()。device (torch.device, 可選) – 要使用的裝置。預設為 None。
示例
>>> model = OnlineDTActor(state_dim=4, action_dim=2, ... transformer_config=OnlineDTActor.default_config()) >>> observation = torch.randn(32, 10, 4) >>> action = torch.randn(32, 10, 2) >>> return_to_go = torch.randn(32, 10, 1) >>> mu, std = model(observation, action, return_to_go) >>> mu.shape torch.Size([32, 10, 2]) >>> std.shape torch.Size([32, 10, 2])
- classmethod default_config()[來源]¶
OnlineDTActor的預設配置。
- forward(observation: Tensor, action: Tensor, return_to_go: Tensor) tuple[torch.Tensor, torch.Tensor][來源]¶
定義每次呼叫時執行的計算。
所有子類都應重寫此方法。
注意
儘管前向傳播的實現需要在此函式中定義,但您應該在之後呼叫
Module例項而不是此函式,因為前者會處理註冊的鉤子,而後者則會靜默忽略它們。