快捷方式

DdpgCnnActor

class torchrl.modules.DdpgCnnActor(action_dim: int, conv_net_kwargs: dict | None = None, mlp_net_kwargs: dict | None = None, use_avg_pooling: bool = False, device: DEVICE_TYPING | None = None)[原始碼]

DDPG 卷積 Actor 類。

在“CONTINUOUS CONTROL WITH DEEP REINFORCEMENT LEARNING”中提出,https://arxiv.org/pdf/1509.02971.pdf

DDPG 卷積 Actor 以觀察(對觀察到的畫素進行的一些簡單轉換)為輸入,並從中輸出一個動作向量,以及一個可用於值估計的觀察嵌入。它應該被訓練以最大化 DDPG Q 值網路返回的值。

引數:
  • action_dim (int) – 動作向量的長度。

  • conv_net_kwargs (dictlist of dicts, optional) –

    ConvNet 的關鍵字引數。預設為

    >>> {
    ...     'in_features': None,
    ...     "num_cells": [32, 64, 64],
    ...     "kernel_sizes": [8, 4, 3],
    ...     "strides": [4, 2, 1],
    ...     "paddings": [0, 0, 1],
    ...     'activation_class': torch.nn.ELU,
    ...     'norm_class': None,
    ...     'aggregator_class': SquashDims,
    ...     'aggregator_kwargs': {"ndims_in": 3},
    ...     'squeeze_output': True,
    ... }  #
    

  • mlp_net_kwargs

    MLP 的關鍵字引數。預設為

    >>> {
    ...     'in_features': None,
    ...     'out_features': action_dim,
    ...     'depth': 2,
    ...     'num_cells': 200,
    ...     'activation_class': nn.ELU,
    ...     'bias_last_layer': True,
    ... }
    

  • use_avg_pooling (bool, optional) – 如果為 True,則使用 AvgPooling 層進行聚合。預設為 False

  • device (torch.device, optional) – 建立模組的裝置。

示例

>>> import torch
>>> from torchrl.modules import DdpgCnnActor
>>> actor = DdpgCnnActor(action_dim=4)
>>> print(actor)
DdpgCnnActor(
  (convnet): ConvNet(
    (0): LazyConv2d(0, 32, kernel_size=(8, 8), stride=(4, 4))
    (1): ELU(alpha=1.0)
    (2): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2))
    (3): ELU(alpha=1.0)
    (4): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (5): ELU(alpha=1.0)
    (6): SquashDims()
  )
  (mlp): MLP(
    (0): LazyLinear(in_features=0, out_features=200, bias=True)
    (1): ELU(alpha=1.0)
    (2): Linear(in_features=200, out_features=200, bias=True)
    (3): ELU(alpha=1.0)
    (4): Linear(in_features=200, out_features=4, bias=True)
  )
)
>>> obs = torch.randn(10, 3, 64, 64)
>>> action, hidden = actor(obs)
>>> print(action.shape)
torch.Size([10, 4])
>>> print(hidden.shape)
torch.Size([10, 2304])
forward(observation: Tensor) tuple[torch.Tensor, torch.Tensor][原始碼]

定義每次呼叫時執行的計算。

所有子類都應重寫此方法。

注意

儘管前向傳播的實現需要在此函式中定義,但您應該在之後呼叫 Module 例項而不是此函式,因為前者會處理註冊的鉤子,而後者則會靜默忽略它們。

文件

訪問全面的 PyTorch 開發者文件

檢視文件

教程

為初學者和高階開發者提供深入的教程

檢視教程

資源

查詢開發資源並讓您的問題得到解答

檢視資源