DdpgCnnActor
- class torchrl.modules.DdpgCnnActor(action_dim: int, conv_net_kwargs: dict | None = None, mlp_net_kwargs: dict | None = None, use_avg_pooling: bool = False, device: DEVICE_TYPING | None = None)
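DDPG convolutional actor class.
Presented in "CONTINUOUS CONTROL WITH DEEP REINFORCEMENT LEARNING", https://arxiv.org/pdf/1509.02971.pdf
The DDPG convolutional actor takes an observation as input (some simple transformation of the observed pixels) and returns an action vector, together with an observation embedding that can be reused for value estimation. It should be trained to maximise the value returned by the DDPG Q-value network.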
- Parameters:
action_dim (int) – length of the action vector.
conv_net_kwargs (dict or list of dicts, optional) –
keyword arguments for the ConvNet. Defaults to
>>> {
...     'in_features': None,
...     "num_cells": [32, 64, 64],
...     "kernel_sizes": [8, 4, 3],
...     "strides": [4, 2, 1],
...     "paddings": [0, 0, 1],
...     'activation_class': torch.nn.ELU,
...     'norm_class': None,
...     'aggregator_class': SquashDims,
...     'aggregator_kwargs': {"ndims_in": 3},
...     'squeeze_output': True,
... }
mlp_net_kwargs (dict, optional) –
keyword arguments for the MLP. Defaults to
>>> {
...     'in_features': None,
...     'out_features': action_dim,
...     'depth': 2,
...     'num_cells': 200,
...     'activation_class': nn.ELU,
...     'bias_last_layer': True,
... }
use_avg_pooling (bool, optional) – if True, an AvgPooling layer is used to aggregate the output. Defaults to False.
device (torch.device, optional) – device on which to create the module.
Examples
>>> import torch
>>> from torchrl.modules import DdpgCnnActor
>>> actor = DdpgCnnActor(action_dim=4)
>>> print(actor)
DdpgCnnActor(
  (convnet): ConvNet(
    (0): LazyConv2d(0, 32, kernel_size=(8, 8), stride=(4, 4))
    (1): ELU(alpha=1.0)
    (2): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2))
    (3): ELU(alpha=1.0)
    (4): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (5): ELU(alpha=1.0)
    (6): SquashDims()
  )
  (mlp): MLP(
    (0): LazyLinear(in_features=0, out_features=200, bias=True)
    (1): ELU(alpha=1.0)
    (2): Linear(in_features=200, out_features=200, bias=True)
    (3): ELU(alpha=1.0)
    (4): Linear(in_features=200, out_features=4, bias=True)
  )
)
>>> obs = torch.randn(10, 3, 64, 64)
>>> action, hidden = actor(obs)
>>> print(action.shape)
torch.Size([10, 4])
>>> print(hidden.shape)
torch.Size([10, 2304])
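The embedding width of 2304 in the example can be checked by hand from the default convolution settings. The sketch below is plain Python and does not call TorchRL; the helper `conv_out_size` is a hypothetical name introduced here, and it assumes the standard convolution output-size formula with dilation 1 and no average pooling (`use_avg_pooling=False`).

```python
def conv_out_size(size: int, kernel: int, stride: int, padding: int) -> int:
    """Output length of a convolution along one spatial dimension (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


# Default ConvNet settings, copied from the parameter description.
num_cells = [32, 64, 64]
kernel_sizes = [8, 4, 3]
strides = [4, 2, 1]
paddings = [0, 0, 1]

size = 64  # height and width of the example observation
for kernel, stride, padding in zip(kernel_sizes, strides, paddings):
    size = conv_out_size(size, kernel, stride, padding)

# SquashDims flattens (channels, height, width) into a single dimension.
hidden_dim = num_cells[-1] * size * size
print(size, hidden_dim)
```

The spatial size shrinks from 64 to 15, then to 6, and stays at 6 after the padded last layer, so the flattened embedding has 64 × 6 × 6 = 2304 features per sample. A different input resolution would give a different embedding width.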
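- forward(observation: Tensor) → tuple[torch.Tensor, torch.Tensor]
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, you should call the Module instance afterwards instead of this function, since the former takes care of running the registered hooks while the latter silently ignores them.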
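The difference between calling the module instance and calling `forward` directly can be seen with a small PyTorch module. This is a minimal sketch that assumes only `torch` is installed; `Doubler` is a hypothetical toy class used for illustration and is not part of TorchRL.

```python
import torch
from torch import nn


class Doubler(nn.Module):  # hypothetical toy module, for illustration only
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return 2 * x


calls = []
module = Doubler()
# The hook records the output shape each time it runs.
module.register_forward_hook(lambda mod, inputs, output: calls.append(output.shape))

x = torch.ones(3)
module(x)          # goes through __call__, so the hook runs
module.forward(x)  # bypasses __call__, so the hook is skipped
print(len(calls))
```

Only the first call is recorded, which is why `actor(obs)` is preferred over `actor.forward(obs)` in the example above.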