快捷方式

TorchRL 配置系統

TorchRL 提供了一個強大的配置系統,該系統構建在 Hydra 之上,使您可以輕鬆配置和執行強化學習實驗。該系統使用基於資料類的結構化配置,這些配置可以進行組合、覆蓋和擴充套件。

使用配置系統的優點包括: - 快速輕鬆上手:提供您的任務,讓系統處理其餘部分 - 一次性概覽可用選項及其預設值:python sota-implementations/ppo_trainer/train.py --help 將顯示所有可用選項及其預設值 - 易於覆蓋和擴充套件:您可以覆蓋配置檔案中的任何選項,也可以使用自己的自定義配置擴充套件配置檔案 - 易於共享和復現:您可以與他人共享配置檔案,他們只需執行相同的命令即可復現您的結果。 - 易於版本控制:您可以輕鬆地對配置檔案進行版本控制。

快速入門示例

讓我們從一個建立 Gym 環境的簡單示例開始。這是一個最小的配置檔案

# config.yaml
defaults:
  - env@training_env: gym

training_env:
  env_name: CartPole-v1

此配置有兩個主要部分

1. defaults **部分**

defaults 部分告訴 Hydra 要包含哪些配置組。在這種情況下

  • env@training_env: gym 意味著“使用 'env' 組中的 'gym' 配置來作為 'training_env' 目標”

這等同於包含一個預定義的 Gym 環境配置,該配置設定了正確的 target 類和預設引數。

2. 配置覆蓋

training_env 部分允許您覆蓋或指定所選配置的引數

  • env_name: CartPole-v1 設定了特定的環境名稱

配置類別和組

TorchRL 使用 @ 語法將配置組織成多個類別,以實現目標化配置

  • env@<target>:環境配置(Gym、DMControl、Brax 等)以及批處理環境

  • transform@<target>:轉換配置(觀察/獎勵處理)

  • model@<target>:模型配置(策略和價值網路)

  • network@<target>:神經網路配置(MLP、ConvNet)

  • collector@<target>:資料收集配置

  • replay_buffer@<target>:回放緩衝區配置

  • storage@<target>:儲存後端配置

  • sampler@<target>:取樣策略配置

  • writer@<target>:寫入器策略配置

  • trainer@<target>:訓練迴圈配置

  • optimizer@<target>:最佳化器配置

  • loss@<target>:損失函式配置

  • logger@<target>:日誌記錄配置

@<target> 語法允許您將配置分配到配置結構中的特定位置。

更復雜的示例:帶轉換的並行環境

這是一個更復雜的示例,它建立了一個並行環境,併為每個工作程序應用多個轉換

defaults:
  - env@training_env: batched_env
  - env@training_env.create_env_fn: transformed_env
  - env@training_env.create_env_fn.base_env: gym
  - transform@training_env.create_env_fn.transform: compose
  - transform@transform0: noop_reset
  - transform@transform1: step_counter

# Transform configurations
transform0:
  noops: 30
  random: true

transform1:
  max_steps: 200
  step_count_key: "step_count"

# Environment configuration
training_env:
  num_workers: 4
  create_env_fn:
    base_env:
      env_name: Pendulum-v1
    transform:
      transforms:
        - ${transform0}
        - ${transform1}
    _partial_: true

此配置建立的內容

此配置構建了一個**具有 4 個工作程序的並行環境**,其中每個工作程序執行一個**應用了兩個轉換的 Pendulum-v1 環境**

  1. 並行環境結構: - batched_env 建立一個執行多個環境例項的並行環境 - num_workers: 4 表示 4 個並行環境程序

  2. 單個環境構建(為 4 個工作程序中的每個程序重複): - **基本環境**:gym 配合 env_name: Pendulum-v1 建立一個 Pendulum 環境 - **轉換層 1**:noop_reset 在每集開始時執行 30 次隨機 no-op 動作 - **轉換層 2**:step_counter 將每集限制為 200 步並跟蹤步數 - **轉換組合**:compose 將兩個轉換組合成一個單一的轉換

  3. 最終結果:4 個並行的 Pendulum 環境,每個環境具有: - 隨機 no-op 重置(開始時 0-30 次動作) - 最大每集 200 步 - 步數計數功能

關鍵配置概念

  1. 巢狀目標env@training_env.create_env_fn.base_env: gym 將 gym 配置深度放置在結構中

  2. 函式工廠_partial_: true 建立一個可以呼叫多次(每個工作程序一次)的函式

  3. 轉換組合:多個轉換被組合並應用於每個環境例項

  4. 變數插值${transform0}${transform1} 引用單獨定義的轉換配置

獲取可用選項

要探索所有可用的配置及其引數,可以使用 --help 標誌與任何 TorchRL 指令碼結合使用

python sota-implementations/ppo_trainer/train.py --help

這將顯示所有配置組及其選項,方便您發現可用的內容。它應該會打印出類似如下的內容


完整的訓練示例

這是一個用於 PPO 訓練的完整配置

defaults:
  - env@training_env: batched_env
  - env@training_env.create_env_fn: gym
  - model@models.policy_model: tanh_normal
  - model@models.value_model: value
  - network@networks.policy_network: mlp
  - network@networks.value_network: mlp
  - collector: sync
  - replay_buffer: base
  - storage: tensor
  - sampler: without_replacement
  - writer: round_robin
  - trainer: ppo
  - optimizer: adam
  - loss: ppo
  - logger: wandb

# Network configurations
networks:
  policy_network:
    out_features: 2
    in_features: 4
    num_cells: [128, 128]

  value_network:
    out_features: 1
    in_features: 4
    num_calls: [128, 128]

# Model configurations
models:
  policy_model:
    network: ${networks.policy_network}
    in_keys: ["observation"]
    out_keys: ["action"]

  value_model:
    network: ${networks.value_network}
    in_keys: ["observation"]
    out_keys: ["state_value"]

# Environment
training_env:
  num_workers: 2
  create_env_fn:
    env_name: CartPole-v1
    _partial_: true

# Training components
trainer:
  collector: ${collector}
  optimizer: ${optimizer}
  loss_module: ${loss}
  logger: ${logger}
  total_frames: 100000

collector:
  create_env_fn: ${training_env}
  policy: ${models.policy_model}
  frames_per_batch: 1024

optimizer:
  lr: 0.001

loss:
  actor_network: ${models.policy_model}
  critic_network: ${models.value_model}

logger:
  exp_name: my_experiment

執行實驗

基本用法

# Use default configuration
python sota-implementations/ppo_trainer/train.py

# Override specific parameters
python sota-implementations/ppo_trainer/train.py optimizer.lr=0.0001

# Change environment
python sota-implementations/ppo_trainer/train.py training_env.create_env_fn.env_name=Pendulum-v1

# Use different collector
python sota-implementations/ppo_trainer/train.py collector=async

超引數搜尋

# Sweep over learning rates
python sota-implementations/ppo_trainer/train.py --multirun optimizer.lr=0.0001,0.001,0.01

# Multiple parameter sweep
python sota-implementations/ppo_trainer/train.py --multirun \
  optimizer.lr=0.0001,0.001 \
  training_env.num_workers=2,4,8

自定義配置檔案

# Use custom config file
python sota-implementations/ppo_trainer/train.py --config-name my_custom_config

配置儲存實現細節

在底層,TorchRL 使用 Hydra 的 ConfigStore 來註冊所有配置類。這提供了型別安全、驗證和 IDE 支援。當您匯入 configs 模組時,註冊會自動發生。

from hydra.core.config_store import ConfigStore
from torchrl.trainers.algorithms.configs import *

cs = ConfigStore.instance()

# Environments
cs.store(group="env", name="gym", node=GymEnvConfig)
cs.store(group="env", name="batched_env", node=BatchedEnvConfig)

# Models
cs.store(group="model", name="tanh_normal", node=TanhNormalModelConfig)
# ... and many more

可用的配置類

基類

ConfigBase()

所有配置類的抽象基類。

環境配置

EnvConfig([_partial_])

環境的基類配置。

BatchedEnvConfig(_partial_, create_env_fn, ...)

批處理環境的配置。

TransformedEnvConfig([_partial_, base_env, ...])

轉換環境的配置。

環境庫配置

EnvLibsConfig([_partial_])

環境庫的基類配置。

GymEnvConfig([_partial_, env_name, ...])

GymEnv 環境的配置。

DMControlEnvConfig([_partial_, env_name, ...])

DMControlEnv 環境的配置。

BraxEnvConfig([_partial_, env_name, ...])

BraxEnv 環境的配置。

HabitatEnvConfig([_partial_, env_name, ...])

HabitatEnv 環境的配置。

IsaacGymEnvConfig([_partial_, env_name, ...])

IsaacGymEnv 環境的配置。

JumanjiEnvConfig([_partial_, env_name, ...])

JumanjiEnv 環境的配置。

MeltingpotEnvConfig([_partial_, env_name, ...])

MeltingpotEnv 環境的配置。

MOGymEnvConfig([_partial_, env_name, ...])

MOGymEnv 環境的配置。

MultiThreadedEnvConfig([_partial_, ...])

MultiThreadedEnv 環境的配置。

OpenMLEnvConfig([_partial_, env_name, ...])

OpenMLEnv 環境的配置。

OpenSpielEnvConfig([_partial_, env_name, ...])

OpenSpielEnv 環境的配置。

PettingZooEnvConfig([_partial_, env_name, ...])

PettingZooEnv 環境的配置。

RoboHiveEnvConfig([_partial_, env_name, ...])

RoboHiveEnv 環境的配置。

SMACv2EnvConfig([_partial_, env_name, ...])

SMACv2Env 環境的配置。

UnityMLAgentsEnvConfig([_partial_, ...])

UnityMLAgentsEnv 環境的配置。

VmasEnvConfig([_partial_, env_name, ...])

VmasEnv 環境的配置。

模型和網路配置

ModelConfig([_partial_, in_keys, out_keys])

配置模型的父類。

NetworkConfig([_partial_])

配置網路的父類。

MLPConfig(_partial_, in_features, ...)

配置多層感知機的類。

ConvNetConfig(_partial_, in_features, depth, ...)

配置卷積網路的類。

TensorDictModuleConfig([_partial_, in_keys, ...])

配置 TensorDictModule 的類。

TanhNormalModelConfig([_partial_, in_keys, ...])

配置 TanhNormal 模型的類。

ValueModelConfig([_partial_, in_keys, ...])

配置價值模型的類。

轉換配置

TransformConfig()

轉換的基類配置。

ComposeConfig([transforms, _target_])

Compose 轉換的配置。

NoopResetEnvConfig([noops, random, _target_])

NoopResetEnv 轉換的配置。

StepCounterConfig([max_steps, ...])

StepCounter 轉換的配置。

DoubleToFloatConfig([in_keys, out_keys, ...])

DoubleToFloat 轉換的配置。

ToTensorImageConfig([from_int, unsqueeze, ...])

ToTensorImage 轉換的配置。

ClipTransformConfig([in_keys, out_keys, ...])

ClipTransform 的配置。

ResizeConfig([w, h, interpolation, in_keys, ...])

Resize 轉換的配置。

CenterCropConfig([height, width, in_keys, ...])

CenterCrop 轉換的配置。

CropConfig([top, left, height, width, ...])

Crop 轉換的配置。

FlattenObservationConfig([in_keys, ...])

FlattenObservation 轉換的配置。

GrayScaleConfig([in_keys, out_keys, _target_])

GrayScale 轉換的配置。

ObservationNormConfig([loc, scale, in_keys, ...])

ObservationNorm 轉換的配置。

CatFramesConfig([N, dim, in_keys, out_keys, ...])

CatFrames 轉換的配置。

RewardClippingConfig([clamp_min, clamp_max, ...])

RewardClipping 轉換的配置。

RewardScalingConfig([loc, scale, in_keys, ...])

RewardScaling 轉換的配置。

BinarizeRewardConfig([in_keys, out_keys, ...])

BinarizeReward 轉換的配置。

TargetReturnConfig([target_return, mode, ...])

TargetReturn 轉換的配置。

VecNormConfig([in_keys, out_keys, decay, ...])

VecNorm 轉換的配置。

FrameSkipTransformConfig([frame_skip, ...])

FrameSkipTransform 的配置。

DeviceCastTransformConfig([device, in_keys, ...])

DeviceCastTransform 的配置。

DTypeCastTransformConfig([dtype, in_keys, ...])

DTypeCastTransform 的配置。

UnsqueezeTransformConfig([dim, in_keys, ...])

UnsqueezeTransform 的配置。

SqueezeTransformConfig([dim, in_keys, ...])

SqueezeTransform 的配置。

PermuteTransformConfig([dims, in_keys, ...])

PermuteTransform 的配置。

CatTensorsConfig([dim, in_keys, out_keys, ...])

CatTensors 轉換的配置。

StackConfig([dim, in_keys, out_keys, _target_])

Stack 轉換的配置。

DiscreteActionProjectionConfig([...])

DiscreteActionProjection 轉換的配置。

TensorDictPrimerConfig([primer_spec, ...])

TensorDictPrimer 轉換的配置。

PinMemoryTransformConfig([in_keys, ...])

PinMemoryTransform 的配置。

RewardSumConfig([in_keys, out_keys, _target_])

RewardSum 轉換的配置。

ExcludeTransformConfig([exclude_keys, _target_])

ExcludeTransform 的配置。

SelectTransformConfig([include_keys, _target_])

SelectTransform 的配置。

TimeMaxPoolConfig([dim, in_keys, out_keys, ...])

TimeMaxPool 轉換的配置。

RandomCropTensorDictConfig([crop_size, ...])

RandomCropTensorDict 轉換的配置。

InitTrackerConfig([in_keys, out_keys, _target_])

InitTracker 轉換的配置。

RenameTransformConfig([key_mapping, _target_])

RenameTransform 的配置。

Reward2GoTransformConfig([gamma, in_keys, ...])

Reward2GoTransform 的配置。

ActionMaskConfig([mask_key, in_keys, ...])

ActionMask 轉換的配置。

VecGymEnvTransformConfig([in_keys, ...])

VecGymEnvTransform 的配置。

BurnInTransformConfig([burn_in, in_keys, ...])

BurnInTransform 的配置。

SignTransformConfig([in_keys, out_keys, ...])

SignTransform 的配置。

RemoveEmptySpecsConfig([_target_])

RemoveEmptySpecs 轉換的配置。

BatchSizeTransformConfig([batch_size, ...])

BatchSizeTransform 的配置。

AutoResetTransformConfig([replace, ...])

AutoResetTransform 的配置。

ActionDiscretizerConfig([num_intervals, ...])

ActionDiscretizer 轉換的配置。

TrajCounterConfig([out_key, repeats, _target_])

TrajCounter 轉換的配置。

LineariseRewardsConfig([in_keys, out_keys, ...])

LineariseRewards 轉換的配置。

ConditionalSkipConfig([cond, _target_])

ConditionalSkip 轉換的配置。

MultiActionConfig([dim, stack_rewards, ...])

MultiAction 轉換的配置。

TimerConfig([out_keys, time_key, _target_])

Timer 轉換的配置。

ConditionalPolicySwitchConfig([policy, ...])

ConditionalPolicySwitch 轉換的配置。

FiniteTensorDictCheckConfig([in_keys, ...])

FiniteTensorDictCheck 轉換的配置。

UnaryTransformConfig([fn, in_keys, ...])

UnaryTransform 的配置。

HashConfig([in_keys, out_keys, _target_])

Hash 轉換的配置。

TokenizerConfig([vocab_size, in_keys, ...])

Tokenizer 轉換的配置。

EndOfLifeTransformConfig([eol_key, ...])

EndOfLifeTransform 的配置。

MultiStepTransformConfig([n_steps, gamma, ...])

MultiStepTransform 的配置。

KLRewardTransformConfig([in_keys, out_keys, ...])

KLRewardTransform 的配置。

R3MTransformConfig([in_keys, out_keys, ...])

R3MTransform 的配置。

VC1TransformConfig([in_keys, out_keys, ...])

VC1Transform 的配置。

VIPTransformConfig([in_keys, out_keys, ...])

VIPTransform 的配置。

VIPRewardTransformConfig([in_keys, ...])

VIPRewardTransform 的配置。

VecNormV2Config([in_keys, out_keys, decay, ...])

VecNormV2 轉換的配置。

資料收集配置

DataCollectorConfig()

配置資料收集器的父類。

SyncDataCollectorConfig([create_env_fn, ...])

配置同步資料收集器的類。

AsyncDataCollectorConfig(create_env_fn, ...)

非同步資料收集器的配置。

MultiSyncDataCollectorConfig([...])

多同步資料收集器的配置。

MultiaSyncDataCollectorConfig([...])

多非同步資料收集器的配置。

回放緩衝區和儲存配置

ReplayBufferConfig([_partial_, _target_, ...])

通用回放緩衝區的配置。

TensorDictReplayBufferConfig([_partial_, ...])

基於 TensorDict 的回放緩衝區的配置。

RandomSamplerConfig([_target_])

從回放緩衝區進行隨機取樣的配置。

SamplerWithoutReplacementConfig([_target_, ...])

無替換取樣配置。

PrioritizedSamplerConfig([_target_, ...])

從回放緩衝區進行優先採樣的配置。

SliceSamplerConfig([_target_, num_slices, ...])

從回放緩衝區進行切片取樣的配置。

SliceSamplerWithoutReplacementConfig([...])

無替換切片取樣的配置。

ListStorageConfig([_partial_, _target_, ...])

回放緩衝區中基於列表的儲存配置。

TensorStorageConfig([_partial_, _target_, ...])

回放緩衝區中基於張量的儲存配置。

LazyTensorStorageConfig([_partial_, ...])

延遲張量儲存配置。

LazyMemmapStorageConfig([_partial_, ...])

延遲記憶體對映儲存配置。

LazyStackStorageConfig([_partial_, ...])

延遲堆疊儲存配置。

StorageEnsembleConfig([_partial_, _target_, ...])

儲存集合的配置。

RoundRobinWriterConfig([_target_, compilable])

迴圈寫入器的配置,它將資料分發到多個儲存中。

StorageEnsembleWriterConfig([_partial_, ...])

儲存集合寫入器的配置。

訓練和最佳化配置

TrainerConfig()

訓練器的基類配置。

PPOTrainerConfig(collector, total_frames, ...)

PPO(近端策略最佳化)訓練器的配置類。

LossConfig([_partial_])

配置損失的類。

PPOLossConfig([_partial_, actor_network, ...])

配置 PPO 損失的類。

AdamConfig([lr, betas, eps, weight_decay, ...])

Adam 最佳化器的配置。

AdamWConfig([lr, betas, eps, weight_decay, ...])

AdamW 最佳化器的配置。

AdamaxConfig([lr, betas, eps, weight_decay, ...])

Adamax 最佳化器的配置。

AdadeltaConfig([lr, rho, eps, weight_decay, ...])

Adadelta 最佳化器的配置。

AdagradConfig([lr, lr_decay, weight_decay, ...])

Adagrad 最佳化器的配置。

ASGDConfig([lr, lambd, alpha, t0, ...])

ASGD 最佳化器的配置。

LBFGSConfig([lr, max_iter, max_eval, ...])

LBFGS 最佳化器的配置。

LionConfig([lr, betas, weight_decay, ...])

Lion 最佳化器的配置。

NAdamConfig([lr, betas, eps, weight_decay, ...])

NAdam 最佳化器的配置。

RAdamConfig([lr, betas, eps, weight_decay, ...])

RAdam 最佳化器的配置。

RMSpropConfig([lr, alpha, eps, ...])

RMSprop 最佳化器的配置。

RpropConfig([lr, etas, step_sizes, foreach, ...])

Rprop 最佳化器的配置。

SGDConfig([lr, momentum, dampening, ...])

SGD 最佳化器的配置。

SparseAdamConfig([lr, betas, eps, _target_, ...])

SparseAdam 最佳化器的配置。

日誌記錄配置

LoggerConfig()

配置日誌記錄器的類。

WandbLoggerConfig(exp_name[, offline, ...])

配置 Wandb 日誌記錄器的類。

TensorboardLoggerConfig(exp_name[, log_dir, ...])

配置 Tensorboard 日誌記錄器的類。

CSVLoggerConfig(exp_name[, log_dir, ...])

配置 CSV 日誌記錄器的類。

建立自定義配置

您可以透過繼承相應的基類來建立自定義配置類

from dataclasses import dataclass
from torchrl.trainers.algorithms.configs.envs_libs import EnvLibsConfig

@dataclass
class MyCustomEnvConfig(EnvLibsConfig):
    _target_: str = "my_module.MyCustomEnv"
    env_name: str = "MyEnv-v1"
    custom_param: float = 1.0

    def __post_init__(self):
        super().__post_init__()

# Register with ConfigStore
from hydra.core.config_store import ConfigStore
cs = ConfigStore.instance()
cs.store(group="env", name="my_custom", node=MyCustomEnvConfig)

最佳實踐

  1. 從小處著手:從基本配置開始,然後逐漸增加複雜性

  2. 使用預設值:利用 defaults 部分來組合配置

  3. 謹慎覆蓋:只覆蓋您需要更改的部分

  4. 驗證配置:測試您的配置是否能正確例項化

  5. 版本控制:將您的配置檔案保留在版本控制下

  6. 使用變數插值:使用 ${variable} 語法來避免重複

未來擴充套件

隨著 TorchRL 新增更多演算法(如 SAC、TD3、DQN),配置系統將擴充套件,包含

  • 新的訓練器配置(例如,SACTrainerConfigTD3TrainerConfig

  • 特定於演算法的損失配置

  • 針對不同演算法的專用收集器配置

  • 附加的環境和模型配置

模組化設計可確保輕鬆整合,同時保持向後相容性。

文件

訪問全面的 PyTorch 開發者文件

檢視文件

教程

為初學者和高階開發者提供深入的教程

檢視教程

資源

查詢開發資源並讓您的問題得到解答

檢視資源