快捷方式

torchrl.trainers 包

trainer 包提供了用於編寫可重用訓練指令碼的實用程式。核心思想是使用一個實現巢狀迴圈的 trainer,其中外層迴圈執行資料收集步驟,內層迴圈執行最佳化步驟。我們認為這適合多種 RL 訓練方案,例如線上策略、離線策略、基於模型和無模型解決方案、離線 RL 等。更具體的情況,例如元 RL 演算法可能具有截然不同的訓練方案。

trainer.train() 方法可以概括如下:

Trainer 迴圈
        >>> for batch in collector:
        ...     batch = self._process_batch_hook(batch)  # "batch_process"
        ...     self._pre_steps_log_hook(batch)  # "pre_steps_log"
        ...     self._pre_optim_hook()  # "pre_optim_steps"
        ...     for j in range(self.optim_steps_per_batch):
        ...         sub_batch = self._process_optim_batch_hook(batch)  # "process_optim_batch"
        ...         losses = self.loss_module(sub_batch)
        ...         self._post_loss_hook(sub_batch)  # "post_loss"
        ...         self.optimizer.step()
        ...         self.optimizer.zero_grad()
        ...         self._post_optim_hook()  # "post_optim"
        ...         self._post_optim_log(sub_batch)  # "post_optim_log"
        ...     self._post_steps_hook()  # "post_steps"
        ...     self._post_steps_log_hook(batch)  #  "post_steps_log"

There are 10 hooks that can be used in a trainer loop:

        >>> for batch in collector:
        ...     batch = self._process_batch_hook(batch)  # "batch_process"
        ...     self._pre_steps_log_hook(batch)  # "pre_steps_log"
        ...     self._pre_optim_hook()  # "pre_optim_steps"
        ...     for j in range(self.optim_steps_per_batch):
        ...         sub_batch = self._process_optim_batch_hook(batch)  # "process_optim_batch"
        ...         losses = self.loss_module(sub_batch)
        ...         self._post_loss_hook(sub_batch)  # "post_loss"
        ...         self.optimizer.step()
        ...         self.optimizer.zero_grad()
        ...         self._post_optim_hook()  # "post_optim"
        ...         self._post_optim_log(sub_batch)  # "post_optim_log"
        ...     self._post_steps_hook()  # "post_steps"
        ...     self._post_steps_log_hook(batch)  #  "post_steps_log"

There are 10 hooks that can be used in a trainer loop:

     >>> for batch in collector:
     ...     batch = self._process_batch_hook(batch)  # "batch_process"
     ...     self._pre_steps_log_hook(batch)  # "pre_steps_log"
     ...     self._pre_optim_hook()  # "pre_optim_steps"
     ...     for j in range(self.optim_steps_per_batch):
     ...         sub_batch = self._process_optim_batch_hook(batch)  # "process_optim_batch"
     ...         losses = self.loss_module(sub_batch)
     ...         self._post_loss_hook(sub_batch)  # "post_loss"
     ...         self.optimizer.step()
     ...         self.optimizer.zero_grad()
     ...         self._post_optim_hook()  # "post_optim"
     ...         self._post_optim_log(sub_batch)  # "post_optim_log"
     ...     self._post_steps_hook()  # "post_steps"
     ...     self._post_steps_log_hook(batch)  #  "post_steps_log"

trainer 迴圈中有 10 個鉤子:"batch_process""pre_optim_steps""process_optim_batch""post_loss""post_steps""post_optim""pre_steps_log""post_steps_log""post_optim_log""optimizer"。它們在註釋中指示了它們的應用位置。鉤子可分為三類:**資料處理**("batch_process""process_optim_batch")、**日誌記錄**("pre_steps_log""post_optim_log""post_steps_log")以及**操作**鉤子("pre_optim_steps""post_loss""post_optim""post_steps")。

  • **資料處理**鉤子會更新一個數據的 tensordict。鉤子的 __call__ 方法應該接受一個 TensorDict 物件作為輸入,並根據某些策略對其進行更新。這類鉤子的示例如回放緩衝區擴充套件(ReplayBufferTrainer.extend)、資料歸一化(包括歸一化常數更新)、資料子取樣(:class:~torchrl.trainers.BatchSubSampler)等。

  • **日誌記錄**鉤子接受一個以 TensorDict 形式呈現的資料批次,並將從中檢索到的資訊寫入日誌記錄器。示例包括 LogValidationReward 鉤子、獎勵日誌記錄器(LogScalar)等。鉤子應返回一個字典(或 None 值),其中包含要記錄的資料。鍵 "log_pbar" 保留給布林值,指示記錄的值是否應顯示在訓練日誌上列印的進度條上。

  • **操作**鉤子是執行模型、資料收集器、目標網路更新等特定操作的鉤子。例如,使用 UpdateWeights 同步收集器的權重或使用 ReplayBufferTrainer.update_priority 更新回放緩衝區的優先順序是操作鉤子的示例。它們與資料無關(不需要 TensorDict 輸入),只需要在每次迭代(或每 N 次迭代)時執行一次。

TorchRL 提供的鉤子通常繼承自一個共同的抽象類 TrainerHookBase,並都實現了三個基本方法:用於檢查點的 state_dictload_state_dict 方法,以及一個用於將鉤子註冊到 trainer 中預設值的 register 方法。此方法接受 trainer 和模組名稱作為輸入。例如,以下日誌鉤子每呼叫 10 次 "post_optim_log" 就會執行一次:

>>> class LoggingHook(TrainerHookBase):
...     def __init__(self):
...         self.counter = 0
...
...     def register(self, trainer, name):
...         trainer.register_module(self, "logging_hook")
...         trainer.register_op("post_optim_log", self)
...
...     def save_dict(self):
...         return {"counter": self.counter}
...
...     def load_state_dict(self, state_dict):
...         self.counter = state_dict["counter"]
...
...     def __call__(self, batch):
...         if self.counter % 10 == 0:
...             self.counter += 1
...             out = {"some_value": batch["some_value"].item(), "log_pbar": False}
...         else:
...             out = None
...         self.counter += 1
...         return out

檢查點

trainer 類和鉤子支援檢查點,可以透過 torchsnapshot 後端或常規 torch 後端實現。這可以透過全域性變數 CKPT_BACKEND 來控制。

$ CKPT_BACKEND=torchsnapshot python script.py

CKPT_BACKEND 預設為 torch。torchsnapshot 相對於 pytorch 的優勢在於它是一個更靈活的 API,支援分散式檢查點,並且允許使用者將儲存在磁碟上的檔案中的張量載入到具有物理儲存的張量中(這是 pytorch 目前不支援的)。例如,這允許將張量載入到否則無法放入記憶體的回放緩衝區中,或者從回放緩衝區載入。

在構建 trainer 時,可以提供檢查點要寫入的路徑。對於 torchsnapshot 後端,需要一個目錄路徑,而 torch 後端需要一個檔案路徑(通常是 .pt 檔案)。

>>> filepath = "path/to/dir/or/file"
>>> trainer = Trainer(
...     collector=collector,
...     total_frames=total_frames,
...     frame_skip=frame_skip,
...     loss_module=loss_module,
...     optimizer=optimizer,
...     save_trainer_file=filepath,
... )
>>> select_keys = SelectKeys(["action", "observation"])
>>> select_keys.register(trainer)
>>> # to save to a path
>>> trainer.save_trainer(True)
>>> # to load from a path
>>> trainer.load_from_file(filepath)

Trainer.train() 方法可用於執行上述迴圈及其所有鉤子,儘管僅使用 Trainer 類來實現其檢查點功能也是完全有效的用法。

Trainer 和鉤子

BatchSubSampler(batch_size[, sub_traj_len, ...])

線上 RL SOTA 實現的資料子取樣器。

ClearCudaCache(interval)

按給定間隔清除 CUDA 快取。

CountFramesLog(*args, **kwargs)

幀計數器鉤子。

LogScalar([key, logname, log_pbar, ...])

用於批次中任何張量值的通用標量日誌記錄器鉤子。

OptimizerHook(optimizer[, loss_components])

為一或多個損失元件新增最佳化器。

LogValidationReward(*, record_interval, ...)

用於 Trainer 的記錄器鉤子。

ReplayBufferTrainer(replay_buffer[, ...])

回放緩衝區鉤子提供程式。

RewardNormalizer([decay, scale, eps, ...])

獎勵歸一化器鉤子。

SelectKeys(keys)

選擇 TensorDict 批次中的鍵。

Trainer(*args, **kwargs)

通用 Trainer 類。

TrainerHookBase()

torchrl Trainer 類的抽象鉤子類。

UpdateWeights(collector, update_weights_interval)

收集器權重更新鉤子類。

特定於演算法的 Trainer(實驗性)

警告

以下 Trainer 是實驗性/原型功能。API 可能在未來版本中發生更改。請報告任何問題或反饋,以幫助改進這些實現!

TorchRL 提供高階、特定於演算法的 Trainer,它們將模組化元件組合成完整的訓練解決方案,具有合理的預設值和全面的配置選項。

PPOTrainer(*args, **kwargs)

PPO(Proximal Policy Optimization)Trainer 實現。

PPOTrainer

PPOTrainer 提供了一個完整的 PPO 訓練解決方案,具有可配置的預設值和基於 Hydra 的全面配置系統。

主要特性

  • 完整的訓練流程,包括環境設定、資料收集和最佳化

  • 使用資料類和 Hydra 的廣泛配置系統

  • 內建的獎勵、動作和訓練統計資料日誌記錄

  • 基於現有 TorchRL 元件的模組化設計

  • **最少程式碼**:僅用約 20 行程式碼即可完成 SOTA 實現!

警告

這是一項實驗性功能。API 可能在未來版本中發生更改。我們歡迎反饋和貢獻,以幫助改進此實現!

快速入門 - 命令列介面

# Basic usage - train PPO on Pendulum-v1 with default settings
python sota-implementations/ppo_trainer/train.py

自定義配置

# Override specific parameters via command line
python sota-implementations/ppo_trainer/train.py \
    trainer.total_frames=2000000 \
    training_env.create_env_fn.base_env.env_name=HalfCheetah-v4 \
    networks.policy_network.num_cells=[256,256] \
    optimizer.lr=0.0003

環境切換

# Switch to a different environment and logger
python sota-implementations/ppo_trainer/train.py \
    env=gym \
    training_env.create_env_fn.base_env.env_name=Walker2d-v4 \
    logger=tensorboard

檢視所有選項

# View all available configuration options
python sota-implementations/ppo_trainer/train.py --help

配置組

PPOTrainer 的配置組織成邏輯組。

  • **環境**:env_cfg__env_nameenv_cfg__backendenv_cfg__device

  • **網路**:actor_network__network__num_cellscritic_network__module__num_cells

  • **訓練**:total_framesclip_normnum_epochsoptimizer_cfg__lr

  • **日誌記錄**:log_rewardslog_actionslog_observations

工作示例

sota-implementations/ppo_trainer/ 目錄包含一個完整的、可用的 PPO 實現,它演示了 trainer 系統的簡潔性和強大功能。

import hydra
from torchrl.trainers.algorithms.configs import *

@hydra.main(config_path="config", config_name="config", version_base="1.1")
def main(cfg):
    trainer = hydra.utils.instantiate(cfg.trainer)
    trainer.train()

if __name__ == "__main__":
    main()

完整的 PPO 訓練,在約 20 行程式碼中實現完全可配置!

配置類

PPOTrainer 使用分層配置系統,包含以下主要配置類。

注意

由於使用了現代型別註解語法,該配置系統需要 Python 3.10+。

未來發展

這是計劃中的第一個特定於演算法的 Trainer。未來的版本將包含:

  • 其他演算法:SAC、TD3、DQN、A2C 等

  • 將所有 TorchRL 元件完全整合到配置系統中

  • 增強的配置驗證和錯誤報告

  • 高階 Trainer 的分散式訓練支援

請參閱完整的配置系統文件,瞭解所有可用選項。

Builders

make_collector_offpolicy(make_env, ...[, ...])

返回用於離線策略 SOTA 實現的資料收集器。

make_collector_onpolicy(make_env, ...[, ...])

在線上策略設定中建立收集器。

make_dqn_loss(model, cfg)

構建 DQN 損失模組。

make_replay_buffer(device, cfg)

使用從 ReplayArgsConfig 構建的配置來構建回放緩衝區。

make_target_updater(cfg, loss_module)

構建目標網路權重更新物件。

make_trainer(collector, loss_module[, ...])

給定其組成部分,建立 Trainer 例項。

parallel_env_constructor(cfg, **kwargs)

使用適當的解析器建構函式構建的 argparse.Namespace 返回一個並行環境。

sync_async_collector(env_fns, env_kwargs[, ...])

執行非同步收集器,每個收集器運行同步環境。

sync_sync_collector(env_fns, env_kwargs[, ...])

運行同步收集器,每個收集器運行同步環境。

transformed_env_constructor(cfg[, ...])

使用適當的解析器建構函式構建的 argparse.Namespace 返回一個環境建立器。

Utils

correct_for_frame_skip(cfg)

透過將所有反映幀數的引數除以 frame_skip 來更正輸入 frame_skip 的引數。

get_stats_random_rollout(cfg[, ...])

使用隨機 rollouts 從環境中收集統計資料(loc 和 scale)。

Loggers

Logger(exp_name, log_dir)

日誌記錄器的模板。

csv.CSVLogger(exp_name[, log_dir, ...])

極簡依賴的 CSV 日誌記錄器。

mlflow.MLFlowLogger(exp_name, tracking_uri)

mlflow 日誌記錄器的包裝器。

tensorboard.TensorboardLogger(exp_name[, ...])

Tensorboard 日誌記錄器的包裝器。

wandb.WandbLogger(*args, **kwargs)

wandb 日誌記錄器的包裝器。

get_logger(logger_type, logger_name, ...)

獲取指定 logger_type 的日誌記錄器例項。

generate_exp_name(model_name, experiment_name)

使用 UUID 和當前日期生成指定實驗的 ID(字串)。

Recording utils

Recording utils 詳細介紹請參閱此處

VideoRecorder(logger, tag[, in_keys, skip, ...])

影片錄製器轉換。

TensorDictRecorder(out_file_base[, ...])

TensorDict 錄製器。

PixelRenderTransform([out_keys, preproc, ...])

一個呼叫父環境的 render 方法並將畫素觀察註冊到 tensordict 中的轉換。

文件

訪問全面的 PyTorch 開發者文件

檢視文件

教程

為初學者和高階開發者提供深入的教程

檢視教程

資源

查詢開發資源並讓您的問題得到解答

檢視資源