快捷方式

make_trainer

torchrl.trainers.helpers.make_trainer(collector: DataCollectorBase, loss_module: LossModule, recorder: EnvBase | None = None, target_net_updater: TargetNetUpdater | None = None, policy_exploration: None | TensorDictModuleWrapper | TensorDictModule = None, replay_buffer: ReplayBuffer | None = None, logger: Logger | None = None, cfg: DictConfig = None) Trainer[原始碼]

根據其組成部分建立 Trainer 例項。

引數:
  • collector (DataCollectorBase) – 要用於收集資料的資料收集器。

  • loss_module (LossModule) – 一個 TorchRL 損失模組

  • recorder (EnvBase, 可選) – 一個記錄器環境。如果為 None,則 trainer 將在不測試策略的情況下訓練策略。

  • target_net_updater (TargetNetUpdater, 可選) – 一個目標網路更新物件。

  • policy_exploration (TDModuleTensorDictModuleWrapper, 可選) – 用於記錄和探索更新的策略(應與學習到的策略同步)。

  • replay_buffer (ReplayBuffer, 可選) – 用於收集資料的一個經驗回放緩衝區。

  • logger (Logger, 可選) – 一個用於日誌記錄的 Logger。

  • cfg (DictConfig, 可選) – 一個包含指令碼引數的 DictConfig。如果為 None,則使用預設引數。

返回:

使用輸入物件構建的 trainer。最佳化器由這個輔助函式使用提供的 cfg 構建。

示例

>>> import torch
>>> import tempfile
>>> from torchrl.trainers.loggers import TensorboardLogger
>>> from torchrl.trainers import Trainer
>>> from torchrl.envs import EnvCreator
>>> from torchrl.collectors import SyncDataCollector
>>> from torchrl.data import TensorDictReplayBuffer
>>> from torchrl.envs.libs.gym import GymEnv
>>> from torchrl.modules import TensorDictModuleWrapper, SafeModule, ValueOperator, EGreedyWrapper
>>> from torchrl.objectives.common import LossModule
>>> from torchrl.objectives.utils import TargetNetUpdater
>>> from torchrl.objectives import DDPGLoss
>>> env_maker = EnvCreator(lambda: GymEnv("Pendulum-v0"))
>>> env_proof = env_maker()
>>> obs_spec = env_proof.observation_spec
>>> action_spec = env_proof.action_spec
>>> net = torch.nn.Linear(env_proof.observation_spec.shape[-1], action_spec.shape[-1])
>>> net_value = torch.nn.Linear(env_proof.observation_spec.shape[-1], 1)  # for the purpose of testing
>>> policy = SafeModule(action_spec, net, in_keys=["observation"], out_keys=["action"])
>>> value = ValueOperator(net_value, in_keys=["observation"], out_keys=["state_action_value"])
>>> collector = SyncDataCollector(env_maker, policy, total_frames=100)
>>> loss_module = DDPGLoss(policy, value, gamma=0.99)
>>> recorder = env_proof
>>> target_net_updater = None
>>> policy_exploration = EGreedyWrapper(policy)
>>> replay_buffer = TensorDictReplayBuffer()
>>> dir = tempfile.gettempdir()
>>> logger = TensorboardLogger(exp_name=dir)
>>> trainer = make_trainer(collector, loss_module, recorder, target_net_updater, policy_exploration,
...    replay_buffer, logger)
>>> print(trainer)

文件

訪問全面的 PyTorch 開發者文件

檢視文件

教程

為初學者和高階開發者提供深入的教程

檢視教程

資源

查詢開發資源並讓您的問題得到解答

檢視資源