注意
轉到末尾 下載完整的示例程式碼。
強化學習(PPO)使用 TorchRL 教程#
建立日期:2023 年 3 月 15 日 | 最後更新:2025 年 9 月 17 日 | 最後驗證:2024 年 11 月 5 日
本教程演示瞭如何使用 PyTorch 和 torchrl 訓練引數化策略網路,以解決 OpenAI-Gym/Farama-Gymnasium 控制庫中的倒立擺任務。
倒立擺#
主要學習內容
如何在 TorchRL 中建立環境,轉換其輸出,並從該環境中收集資料;
如何使用
TensorDict讓您的類之間進行通訊;使用 TorchRL 構建訓練迴圈的基礎知識
如何為策略梯度方法計算優勢訊號;
如何使用機率神經網路建立隨機策略;
如何建立動態回放緩衝區並從中不重複地取樣。
我們將介紹 TorchRL 的六個關鍵元件
如果您在 Google Colab 中執行此程式碼,請確保安裝以下依賴項
!pip3 install torchrl
!pip3 install gym[mujoco]
!pip3 install tqdm
近端策略最佳化(PPO)是一種策略梯度演算法,它收集並直接消耗一批資料,以在存在某些近端約束的情況下最大化預期回報來訓練策略。您可以將其視為 REINFORCE(基礎策略最佳化演算法)的複雜版本。有關更多資訊,請參閱 近端策略最佳化演算法論文。
PPO 通常被認為是一種快速有效的線上、on-policy 強化演算法。TorchRL 提供了一個為您完成所有工作的損失模組,這樣您就可以依賴此實現,專注於解決您的問題,而不是每次想訓練策略時都重新發明輪子。
為了完整起見,這裡簡要概述了損失的計算方法,儘管這由我們的 ClipPPOLoss 模組處理—演算法如下:1. 我們將透過在環境中執行策略一定步數來取樣一批資料。2. 然後,我們將使用裁剪版的 REINFORCE 損失,透過隨機子批次對該批次進行一定次數的最佳化。3. 裁剪將對我們的損失設定一個悲觀的界限:與更高的回報估計相比,更低的回報估計將受到青睞。損失的精確公式是
該損失中有兩個組成部分:在最小運算元第一部分,我們計算了 REINFORCE 損失的加權版本(例如,我們已根據當前策略配置滯後於用於資料收集的配置的事實進行了校正的 REINFORCE 損失)。該最小運算元的第二部分是一個類似的損失,我們在其中裁剪了超出或低於給定閾值對的比例。
此損失確保無論優勢是正數還是負數,都會抑制那些會導致與先前配置發生重大偏移的策略更新。
本教程結構如下
首先,我們將定義一組將在訓練中使用的超引數。
接下來,我們將專注於使用 TorchRL 的包裝器和變換來建立我們的環境或模擬器。
接下來,我們將設計策略網路和價值模型,這對於損失函式是必不可少的。這些模組將用於配置我們的損失模組。
接下來,我們將建立回放緩衝區和資料載入器。
最後,我們將執行訓練迴圈並分析結果。
在本教程中,我們將使用 tensordict 庫。TensorDict 是 TorchRL 的通用語言:它幫助我們抽象出模組讀取和寫入的內容,讓我們更少關注具體的資料描述,而更多關注演算法本身。
import warnings
warnings.filterwarnings("ignore")
from torch import multiprocessing
from collections import defaultdict
import matplotlib.pyplot as plt
import torch
from tensordict.nn import TensorDictModule
from tensordict.nn.distributions import NormalParamExtractor
from torch import nn
from torchrl.collectors import SyncDataCollector
from torchrl.data.replay_buffers import ReplayBuffer
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
from torchrl.data.replay_buffers.storages import LazyTensorStorage
from torchrl.envs import (Compose, DoubleToFloat, ObservationNorm, StepCounter,
TransformedEnv)
from torchrl.envs.libs.gym import GymEnv
from torchrl.envs.utils import check_env_specs, ExplorationType, set_exploration_type
from torchrl.modules import ProbabilisticActor, TanhNormal, ValueOperator
from torchrl.objectives import ClipPPOLoss
from torchrl.objectives.value import GAE
from tqdm import tqdm
定義超引數#
我們為演算法設定了超引數。根據可用資源,可以選擇在 GPU 或其他裝置上執行策略。frame_skip 將控制單個動作執行多少幀。其餘計算幀數的引數必須針對此值進行校正(因為一個環境步驟實際上將返回 frame_skip 幀)。
is_fork = multiprocessing.get_start_method() == "fork"
device = (
torch.device(0)
if torch.cuda.is_available() and not is_fork
else torch.device("cpu")
)
num_cells = 256 # number of cells in each layer i.e. output dim.
lr = 3e-4
max_grad_norm = 1.0
資料收集引數#
在收集資料時,我們將能夠透過定義 frames_per_batch 引數來選擇每個批次的大小。我們還將定義允許使用的幀數(例如與模擬器的互動次數)。一般來說,RL 演算法的目標是學會盡快解決任務(就環境互動而言):total_frames 越低越好。
frames_per_batch = 1000
# For a complete training, bring the number of frames up to 1M
total_frames = 50_000
PPO 引數#
在每次資料收集(或批次收集)時,我們將在一定數量的“epoch”上執行最佳化,每次都透過內部訓練迴圈消耗我們剛剛獲取的全部資料。這裡,sub_batch_size 與上面的 frames_per_batch 不同:請記住,我們正在處理來自收集器的“批次資料”,其大小由 frames_per_batch 定義,並且我們將在內部訓練迴圈中將其進一步劃分為更小的子批次。這些子批次的大小由 sub_batch_size 控制。
sub_batch_size = 64 # cardinality of the sub-samples gathered from the current data in the inner loop
num_epochs = 10 # optimization steps per batch of data collected
clip_epsilon = (
0.2 # clip value for PPO loss: see the equation in the intro for more context.
)
gamma = 0.99
lmbda = 0.95
entropy_eps = 1e-4
定義環境#
在 RL 中,*環境* 通常是我們對模擬器或控制系統的稱呼。各種庫都提供強化學習的模擬環境,包括 Gymnasium(以前稱為 OpenAI Gym)、DeepMind 控制套件等。作為一個通用庫,TorchRL 的目標是為大量 RL 模擬器提供可互換的介面,讓您輕鬆地將一個環境替換為另一個。例如,使用幾個字元就可以建立一個包裝好的 gym 環境。
Gym has been unmaintained since 2022 and does not support NumPy 2.0 amongst other critical functionality.
Please upgrade to Gymnasium, the maintained drop-in replacement of Gym, or contact the authors of your software and request that they upgrade.
Users of this version of Gym should be able to simply replace 'import gym' with 'import gymnasium as gym' in the vast majority of cases.
See the migration guide at https://gymnasium.farama.org/introduction/migration_guide/ for additional information.
這段程式碼有幾點需要注意:首先,我們透過呼叫 GymEnv 包裝器來建立環境。如果傳遞了額外的關鍵字引數,它們將被傳輸到 gym.make 方法,從而涵蓋最常見的環境建立命令。或者,您也可以直接使用 gym.make(env_name, **kwargs) 建立一個 gym 環境,並將其包裝在 GymWrapper 類中。
還有 device 引數:對於 gym,這僅控制儲存輸入動作和觀察狀態的裝置,但執行始終在 CPU 上進行。原因很簡單,gym 不支援裝置上執行,除非另有說明。對於其他庫,我們可以控制執行裝置,並且在可能的情況下,我們會盡量在儲存和執行後端方面保持一致。
轉換#
我們將向環境新增一些變換,以準備好策略的資料。在 Gym 中,這通常透過包裝器實現。TorchRL 採用不同的方法,更類似於其他 PyTorch 領域庫,透過使用變換。要向環境新增變換,只需將其包裝在 TransformedEnv 例項中,並將其變換序列附加到其中。轉換後的環境將繼承被包裝環境的裝置和元資料,並根據其包含的變換序列對其進行轉換。
歸一化#
首先編碼的是一個歸一化變換。經驗法則,最好使資料大致匹配單位高斯分佈:為了實現這一點,我們將執行一定數量的隨機步驟,並計算這些觀察的統計摘要。
我們將附加另外兩個變換:DoubleToFloat 變換會將雙精度條目轉換為單精度數字,以便策略讀取。StepCounter 變換將用於計算環境終止之前的步數。我們將使用此度量作為補充效能度量。
正如我們稍後將看到的,TorchRL 的許多類依賴於 TensorDict 進行通訊。您可以將其視為具有一些額外張量功能的 Python 字典。在實踐中,這意味著我們將處理的許多模組需要被告知要讀取什麼鍵(in_keys)以及在它們將接收的 tensordict 中寫入什麼鍵(out_keys)。通常,如果省略 out_keys,則假定 in_keys 條目將被原地更新。對於我們的變換,我們感興趣的唯一條目是 "observation",我們的變換層將被告知修改此條目,僅此而已。
env = TransformedEnv(
base_env,
Compose(
# normalize observations
ObservationNorm(in_keys=["observation"]),
DoubleToFloat(),
StepCounter(),
),
)
您可能已經注意到,我們建立了一個歸一化層,但沒有設定其歸一化引數。要做到這一點,ObservationNorm 可以自動收集我們環境的統計摘要。
env.transform[0].init_stats(num_iter=1000, reduce_dim=0, cat_dim=0)
現在,ObservationNorm 變換已填充了用於歸一化資料的均值和方差。
讓我們對我們的統計摘要的形狀進行一些健全性檢查。
print("normalization constant shape:", env.transform[0].loc.shape)
normalization constant shape: torch.Size([11])
環境不僅由其模擬器和變換定義,還由一系列元資料定義,這些元資料描述了在執行期間可以預期什麼。出於效率原因,TorchRL 在環境規範方面非常嚴格,但您可以輕鬆檢查您的環境規範是否足夠。在我們的示例中,GymWrapper 和繼承自它的 GymEnv 已經負責為您的環境設定正確的規範,所以您不必擔心這一點。
不過,讓我們透過檢視轉換後的環境的規範來具體看一下。有三個規範需要檢視:observation_spec 定義了在環境中執行動作時可以預期什麼,reward_spec 指示了獎勵域,最後是 input_spec(其中包含 action_spec),它代表了環境執行單個步驟所需的一切。
print("observation_spec:", env.observation_spec)
print("reward_spec:", env.reward_spec)
print("input_spec:", env.input_spec)
print("action_spec (as defined by input_spec):", env.action_spec)
observation_spec: Composite(
observation: UnboundedContinuous(
shape=torch.Size([11]),
space=ContinuousBox(
low=Tensor(shape=torch.Size([11]), device=cpu, dtype=torch.float32, contiguous=True),
high=Tensor(shape=torch.Size([11]), device=cpu, dtype=torch.float32, contiguous=True)),
device=cpu,
dtype=torch.float32,
domain=continuous),
step_count: BoundedDiscrete(
shape=torch.Size([1]),
space=ContinuousBox(
low=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.int64, contiguous=True),
high=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.int64, contiguous=True)),
device=cpu,
dtype=torch.int64,
domain=discrete),
device=cpu,
shape=torch.Size([]),
data_cls=None)
reward_spec: UnboundedContinuous(
shape=torch.Size([1]),
space=ContinuousBox(
low=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True),
high=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True)),
device=cpu,
dtype=torch.float32,
domain=continuous)
input_spec: Composite(
full_state_spec: Composite(
step_count: BoundedDiscrete(
shape=torch.Size([1]),
space=ContinuousBox(
low=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.int64, contiguous=True),
high=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.int64, contiguous=True)),
device=cpu,
dtype=torch.int64,
domain=discrete),
device=cpu,
shape=torch.Size([]),
data_cls=None),
full_action_spec: Composite(
action: BoundedContinuous(
shape=torch.Size([1]),
space=ContinuousBox(
low=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True),
high=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True)),
device=cpu,
dtype=torch.float32,
domain=continuous),
device=cpu,
shape=torch.Size([]),
data_cls=None),
device=cpu,
shape=torch.Size([]),
data_cls=None)
action_spec (as defined by input_spec): BoundedContinuous(
shape=torch.Size([1]),
space=ContinuousBox(
low=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True),
high=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True)),
device=cpu,
dtype=torch.float32,
domain=continuous)
該 check_env_specs() 函式執行一個小的回滾,並將其輸出與環境規範進行比較。如果沒有引發錯誤,我們可以確信規範已正確定義。
2025-10-15 19:17:32,789 [torchrl][INFO] check_env_specs succeeded! [END]
為了好玩,讓我們看看簡單的隨機回滾是什麼樣的。您可以呼叫 env.rollout(n_steps) 並獲取環境輸入和輸出的概覽。動作將自動從動作規範域中抽取,因此您不必擔心設計一個隨機取樣器。
通常,在每一步,RL 環境都會接收一個動作作為輸入,並輸出一個觀察、一個獎勵和一個完成狀態。觀察可能是複合的,這意味著它可能由多個張量組成。這對於 TorchRL 來說不是問題,因為整個觀察集會自動打包到輸出的 TensorDict 中。在執行完給定步數的回滾(例如,一系列環境步驟和隨機動作生成)後,我們將檢索一個 TensorDict 例項,其形狀與此軌跡長度匹配。
rollout = env.rollout(3)
print("rollout of three steps:", rollout)
print("Shape of the rollout TensorDict:", rollout.batch_size)
rollout of three steps: TensorDict(
fields={
action: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False),
done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
next: TensorDict(
fields={
done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
observation: Tensor(shape=torch.Size([3, 11]), device=cpu, dtype=torch.float32, is_shared=False),
reward: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False),
step_count: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.int64, is_shared=False),
terminated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
truncated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
batch_size=torch.Size([3]),
device=cpu,
is_shared=False),
observation: Tensor(shape=torch.Size([3, 11]), device=cpu, dtype=torch.float32, is_shared=False),
step_count: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.int64, is_shared=False),
terminated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
truncated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
batch_size=torch.Size([3]),
device=cpu,
is_shared=False)
Shape of the rollout TensorDict: torch.Size([3])
我們的回滾資料的形狀為 torch.Size([3]),這與我們執行它的步數相符。"next" 條目指向當前步驟之後的資料。在大多數情況下,時間 t 的 "next" 資料與 t+1 的資料匹配,但如果我們使用某些特定變換(例如,多步),則情況可能並非如此。
策略#
PPO 使用隨機策略來處理探索。這意味著我們的神經網路將必須輸出分佈的引數,而不是對應於所採取動作的單個值。
由於資料是連續的,我們使用 Tanh-Normal 分佈來尊重動作空間邊界。TorchRL 提供了這種分佈,我們只需要關心構建一個輸出正確數量引數的神經網路,以便策略可以與之配合(均值和方差)。
這裡唯一的額外難點是將我們的輸出分成兩部分,並將第二部分對映到嚴格正的空間。
我們分三個步驟設計策略
定義一個神經網路
D_obs->2 * D_action。確實,我們的loc(均值)和scale(方差)的維度都是D_action。附加一個
NormalParamExtractor以提取均值和方差(例如,將輸入分成兩部分,並對方差引數應用正變換)。建立一個機率性的
TensorDictModule,它可以生成此分佈並從中取樣。
actor_net = nn.Sequential(
nn.LazyLinear(num_cells, device=device),
nn.Tanh(),
nn.LazyLinear(num_cells, device=device),
nn.Tanh(),
nn.LazyLinear(num_cells, device=device),
nn.Tanh(),
nn.LazyLinear(2 * env.action_spec.shape[-1], device=device),
NormalParamExtractor(),
)
為了使策略能夠透過 tensordict 資料載體與環境“交流”,我們將 nn.Module 包裝在 TensorDictModule 中。此類將簡單地讀取其提供的 in_keys,並將輸出原地寫入已註冊的 out_keys。
policy_module = TensorDictModule(
actor_net, in_keys=["observation"], out_keys=["loc", "scale"]
)
現在我們需要從我們的正態分佈的均值和方差構建一個分佈。為此,我們指示 ProbabilisticActor 類從均值和方差引數構建一個 TanhNormal。我們還提供了此分佈的最小值和最大值,這些值是從環境規範中獲取的。
in_keys 的名稱(因此也是上面 TensorDictModule 的 out_keys 的名稱)不能隨意設定,因為 TanhNormal 分佈建構函式將期望 loc 和 scale 關鍵字引數。話雖如此,ProbabilisticActor 還接受 Dict[str, str] 型別的 in_keys,其中鍵值對指示了將用於每個要使用的關鍵字引數的 in_key 字串。
policy_module = ProbabilisticActor(
module=policy_module,
spec=env.action_spec,
in_keys=["loc", "scale"],
distribution_class=TanhNormal,
distribution_kwargs={
"low": env.action_spec.space.low,
"high": env.action_spec.space.high,
},
return_log_prob=True,
# we'll need the log-prob for the numerator of the importance weights
)
價值網路#
價值網路是 PPO 演算法的關鍵組成部分,儘管它不會在推理時使用。此模組將讀取觀察值並返回後續軌跡的折扣回報估計。這允許我們依賴在訓練期間即時學習的效用估計來分攤學習成本。我們的價值網路與策略具有相同的結構,但為簡單起見,我們為其分配了自己的一組引數。
value_net = nn.Sequential(
nn.LazyLinear(num_cells, device=device),
nn.Tanh(),
nn.LazyLinear(num_cells, device=device),
nn.Tanh(),
nn.LazyLinear(num_cells, device=device),
nn.Tanh(),
nn.LazyLinear(1, device=device),
)
value_module = ValueOperator(
module=value_net,
in_keys=["observation"],
)
讓我們試試我們的策略和價值模組。如前所述,TensorDictModule 的使用使得可以直接讀取環境的輸出來執行這些模組,因為它們知道要讀取哪些資訊以及在哪裡寫入。
print("Running policy:", policy_module(env.reset()))
print("Running value:", value_module(env.reset()))
Running policy: TensorDict(
fields={
action: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
action_log_prob: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
loc: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
observation: Tensor(shape=torch.Size([11]), device=cpu, dtype=torch.float32, is_shared=False),
scale: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
step_count: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.int64, is_shared=False),
terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
batch_size=torch.Size([]),
device=cpu,
is_shared=False)
Running value: TensorDict(
fields={
done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
observation: Tensor(shape=torch.Size([11]), device=cpu, dtype=torch.float32, is_shared=False),
state_value: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
step_count: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.int64, is_shared=False),
terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
batch_size=torch.Size([]),
device=cpu,
is_shared=False)
資料收集器#
TorchRL 提供了一組 資料收集器類。簡而言之,這些類執行三個操作:重置環境,根據最新觀察計算動作,在環境中執行一步,然後重複最後兩個步驟,直到環境發出停止訊號(或達到完成狀態)。
它們允許您控制每次迭代收集多少幀(透過 frames_per_batch 引數),何時重置環境(透過 max_frames_per_traj 引數),策略應該在哪個 device 上執行,等等。它們還設計為與批處理和多程序環境高效配合。
最簡單的資料收集器是 SyncDataCollector:它是一個迭代器,您可以用來獲取指定長度的資料批次,並在收集完總幀數(total_frames)後停止。其他資料收集器(MultiSyncDataCollector 和 MultiaSyncDataCollector)將在多個程序的計算節點上以同步和非同步方式執行相同的操作。
與之前的策略和環境一樣,資料收集器將返回 TensorDict 例項,其總元素數量將匹配 frames_per_batch。使用 TensorDict 將資料傳遞給訓練迴圈,允許您編寫 100% 忽略回滾內容實際特異性的資料載入管道。
collector = SyncDataCollector(
env,
policy_module,
frames_per_batch=frames_per_batch,
total_frames=total_frames,
split_trajs=False,
device=device,
)
回放緩衝區#
回放緩衝區是離策略 RL 演算法的常見構建模組。在策略環境中,每當收集一批資料時,回放緩衝區就會被重新填充,並且其資料會在一定數量的 epoch 中被重複消耗。
TorchRL 的回放緩衝區是使用通用的容器 ReplayBuffer 構建的,它接受緩衝區元件作為引數:儲存、寫入器、取樣器以及可能的變換。只有儲存(指示回放緩衝區容量)是強制性的。我們還指定了一個無重複取樣器,以避免在一個 epoch 中多次取樣同一個專案。使用回放緩衝區進行 PPO 不是強制性的,我們可以直接從收集的批次中取樣子批次,但使用這些類可以方便我們以可重現的方式構建內部訓練迴圈。
replay_buffer = ReplayBuffer(
storage=LazyTensorStorage(max_size=frames_per_batch),
sampler=SamplerWithoutReplacement(),
)
損失函式#
PPO 損失可以直接從 TorchRL 匯入,以方便地使用 ClipPPOLoss 類。這是使用 PPO 最簡單的方法:它隱藏了 PPO 的數學運算以及與之相關的控制流。
PPO 需要計算一些“優勢估計”。簡而言之,優勢是一個值,它反映了在處理偏差/方差權衡時對回報值的預期。要計算優勢,只需(1)構建優勢模組,該模組利用我們的價值運算子,(2)在每個 epoch 之前將每個資料批次透過它。GAE 模組將使用新的 "advantage" 和 "value_target" 條目更新輸入的 tensordict。"value_target" 是一個無梯度張量,表示價值網路應以輸入觀察值表示的經驗值。這兩者都將由 ClipPPOLoss 用於返回策略和價值損失。
advantage_module = GAE(
gamma=gamma, lmbda=lmbda, value_network=value_module, average_gae=True, device=device,
)
loss_module = ClipPPOLoss(
actor_network=policy_module,
critic_network=value_module,
clip_epsilon=clip_epsilon,
entropy_bonus=bool(entropy_eps),
entropy_coef=entropy_eps,
# these keys match by default but we set this for completeness
critic_coef=1.0,
loss_critic_type="smooth_l1",
)
optim = torch.optim.Adam(loss_module.parameters(), lr)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
optim, total_frames // frames_per_batch, 0.0
)
/usr/local/lib/python3.10/dist-packages/torchrl/objectives/ppo.py:384: DeprecationWarning:
'critic_coef' is deprecated and will be removed in torchrl v0.11. Please use 'critic_coeff' instead.
/usr/local/lib/python3.10/dist-packages/torchrl/objectives/ppo.py:450: DeprecationWarning:
'entropy_coef' is deprecated and will be removed in torchrl v0.11. Please use 'entropy_coeff' instead.
訓練迴圈#
現在我們有了編寫訓練迴圈所需的所有元件。步驟包括
收集資料
計算優勢
迴圈遍歷收集的資料以計算損失值
反向傳播
最佳化
重複
重複
重複
logs = defaultdict(list)
pbar = tqdm(total=total_frames)
eval_str = ""
# We iterate over the collector until it reaches the total number of frames it was
# designed to collect:
for i, tensordict_data in enumerate(collector):
# we now have a batch of data to work with. Let's learn something from it.
for _ in range(num_epochs):
# We'll need an "advantage" signal to make PPO work.
# We re-compute it at each epoch as its value depends on the value
# network which is updated in the inner loop.
advantage_module(tensordict_data)
data_view = tensordict_data.reshape(-1)
replay_buffer.extend(data_view.cpu())
for _ in range(frames_per_batch // sub_batch_size):
subdata = replay_buffer.sample(sub_batch_size)
loss_vals = loss_module(subdata.to(device))
loss_value = (
loss_vals["loss_objective"]
+ loss_vals["loss_critic"]
+ loss_vals["loss_entropy"]
)
# Optimization: backward, grad clipping and optimization step
loss_value.backward()
# this is not strictly mandatory but it's good practice to keep
# your gradient norm bounded
torch.nn.utils.clip_grad_norm_(loss_module.parameters(), max_grad_norm)
optim.step()
optim.zero_grad()
logs["reward"].append(tensordict_data["next", "reward"].mean().item())
pbar.update(tensordict_data.numel())
cum_reward_str = (
f"average reward={logs['reward'][-1]: 4.4f} (init={logs['reward'][0]: 4.4f})"
)
logs["step_count"].append(tensordict_data["step_count"].max().item())
stepcount_str = f"step count (max): {logs['step_count'][-1]}"
logs["lr"].append(optim.param_groups[0]["lr"])
lr_str = f"lr policy: {logs['lr'][-1]: 4.4f}"
if i % 10 == 0:
# We evaluate the policy once every 10 batches of data.
# Evaluation is rather simple: execute the policy without exploration
# (take the expected value of the action distribution) for a given
# number of steps (1000, which is our ``env`` horizon).
# The ``rollout`` method of the ``env`` can take a policy as argument:
# it will then execute this policy at each step.
with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
# execute a rollout with the trained policy
eval_rollout = env.rollout(1000, policy_module)
logs["eval reward"].append(eval_rollout["next", "reward"].mean().item())
logs["eval reward (sum)"].append(
eval_rollout["next", "reward"].sum().item()
)
logs["eval step_count"].append(eval_rollout["step_count"].max().item())
eval_str = (
f"eval cumulative reward: {logs['eval reward (sum)'][-1]: 4.4f} "
f"(init: {logs['eval reward (sum)'][0]: 4.4f}), "
f"eval step-count: {logs['eval step_count'][-1]}"
)
del eval_rollout
pbar.set_description(", ".join([eval_str, cum_reward_str, stepcount_str, lr_str]))
# We're also using a learning rate scheduler. Like the gradient clipping,
# this is a nice-to-have but nothing necessary for PPO to work.
scheduler.step()
0%| | 0/50000 [00:00<?, ?it/s]
2%|▏ | 1000/50000 [00:03<02:37, 310.41it/s]
eval cumulative reward: 82.8892 (init: 82.8892), eval step-count: 8, average reward= 9.0941 (init= 9.0941), step count (max): 14, lr policy: 0.0003: 2%|▏ | 1000/50000 [00:03<02:37, 310.41it/s]
eval cumulative reward: 82.8892 (init: 82.8892), eval step-count: 8, average reward= 9.0941 (init= 9.0941), step count (max): 14, lr policy: 0.0003: 4%|▍ | 2000/50000 [00:06<02:25, 330.59it/s]
eval cumulative reward: 82.8892 (init: 82.8892), eval step-count: 8, average reward= 9.1147 (init= 9.0941), step count (max): 14, lr policy: 0.0003: 4%|▍ | 2000/50000 [00:06<02:25, 330.59it/s]
eval cumulative reward: 82.8892 (init: 82.8892), eval step-count: 8, average reward= 9.1147 (init= 9.0941), step count (max): 14, lr policy: 0.0003: 6%|▌ | 3000/50000 [00:08<02:18, 339.71it/s]
eval cumulative reward: 82.8892 (init: 82.8892), eval step-count: 8, average reward= 9.1557 (init= 9.0941), step count (max): 17, lr policy: 0.0003: 6%|▌ | 3000/50000 [00:08<02:18, 339.71it/s]
eval cumulative reward: 82.8892 (init: 82.8892), eval step-count: 8, average reward= 9.1557 (init= 9.0941), step count (max): 17, lr policy: 0.0003: 8%|▊ | 4000/50000 [00:11<02:13, 344.80it/s]
eval cumulative reward: 82.8892 (init: 82.8892), eval step-count: 8, average reward= 9.1727 (init= 9.0941), step count (max): 22, lr policy: 0.0003: 8%|▊ | 4000/50000 [00:11<02:13, 344.80it/s]
eval cumulative reward: 82.8892 (init: 82.8892), eval step-count: 8, average reward= 9.1727 (init= 9.0941), step count (max): 22, lr policy: 0.0003: 10%|█ | 5000/50000 [00:14<02:08, 348.84it/s]
eval cumulative reward: 82.8892 (init: 82.8892), eval step-count: 8, average reward= 9.2145 (init= 9.0941), step count (max): 27, lr policy: 0.0003: 10%|█ | 5000/50000 [00:14<02:08, 348.84it/s]
eval cumulative reward: 82.8892 (init: 82.8892), eval step-count: 8, average reward= 9.2145 (init= 9.0941), step count (max): 27, lr policy: 0.0003: 12%|█▏ | 6000/50000 [00:17<02:04, 352.28it/s]
eval cumulative reward: 82.8892 (init: 82.8892), eval step-count: 8, average reward= 9.2239 (init= 9.0941), step count (max): 34, lr policy: 0.0003: 12%|█▏ | 6000/50000 [00:17<02:04, 352.28it/s]
eval cumulative reward: 82.8892 (init: 82.8892), eval step-count: 8, average reward= 9.2239 (init= 9.0941), step count (max): 34, lr policy: 0.0003: 14%|█▍ | 7000/50000 [00:20<02:03, 349.15it/s]
eval cumulative reward: 82.8892 (init: 82.8892), eval step-count: 8, average reward= 9.2411 (init= 9.0941), step count (max): 34, lr policy: 0.0003: 14%|█▍ | 7000/50000 [00:20<02:03, 349.15it/s]
eval cumulative reward: 82.8892 (init: 82.8892), eval step-count: 8, average reward= 9.2411 (init= 9.0941), step count (max): 34, lr policy: 0.0003: 16%|█▌ | 8000/50000 [00:23<01:59, 352.67it/s]
eval cumulative reward: 82.8892 (init: 82.8892), eval step-count: 8, average reward= 9.2366 (init= 9.0941), step count (max): 33, lr policy: 0.0003: 16%|█▌ | 8000/50000 [00:23<01:59, 352.67it/s]
eval cumulative reward: 82.8892 (init: 82.8892), eval step-count: 8, average reward= 9.2366 (init= 9.0941), step count (max): 33, lr policy: 0.0003: 18%|█▊ | 9000/50000 [00:25<01:55, 355.66it/s]
eval cumulative reward: 82.8892 (init: 82.8892), eval step-count: 8, average reward= 9.2558 (init= 9.0941), step count (max): 63, lr policy: 0.0003: 18%|█▊ | 9000/50000 [00:25<01:55, 355.66it/s]
eval cumulative reward: 82.8892 (init: 82.8892), eval step-count: 8, average reward= 9.2558 (init= 9.0941), step count (max): 63, lr policy: 0.0003: 20%|██ | 10000/50000 [00:28<01:51, 357.59it/s]
eval cumulative reward: 82.8892 (init: 82.8892), eval step-count: 8, average reward= 9.2629 (init= 9.0941), step count (max): 63, lr policy: 0.0003: 20%|██ | 10000/50000 [00:28<01:51, 357.59it/s]
eval cumulative reward: 82.8892 (init: 82.8892), eval step-count: 8, average reward= 9.2629 (init= 9.0941), step count (max): 63, lr policy: 0.0003: 22%|██▏ | 11000/50000 [00:31<01:48, 359.10it/s]
eval cumulative reward: 212.8816 (init: 82.8892), eval step-count: 22, average reward= 9.2628 (init= 9.0941), step count (max): 50, lr policy: 0.0003: 22%|██▏ | 11000/50000 [00:31<01:48, 359.10it/s]
eval cumulative reward: 212.8816 (init: 82.8892), eval step-count: 22, average reward= 9.2628 (init= 9.0941), step count (max): 50, lr policy: 0.0003: 24%|██▍ | 12000/50000 [00:34<01:45, 358.80it/s]
eval cumulative reward: 212.8816 (init: 82.8892), eval step-count: 22, average reward= 9.2694 (init= 9.0941), step count (max): 56, lr policy: 0.0003: 24%|██▍ | 12000/50000 [00:34<01:45, 358.80it/s]
eval cumulative reward: 212.8816 (init: 82.8892), eval step-count: 22, average reward= 9.2694 (init= 9.0941), step count (max): 56, lr policy: 0.0003: 26%|██▌ | 13000/50000 [00:36<01:42, 359.92it/s]
eval cumulative reward: 212.8816 (init: 82.8892), eval step-count: 22, average reward= 9.2515 (init= 9.0941), step count (max): 39, lr policy: 0.0003: 26%|██▌ | 13000/50000 [00:36<01:42, 359.92it/s]
eval cumulative reward: 212.8816 (init: 82.8892), eval step-count: 22, average reward= 9.2515 (init= 9.0941), step count (max): 39, lr policy: 0.0003: 28%|██▊ | 14000/50000 [00:39<01:41, 355.31it/s]
eval cumulative reward: 212.8816 (init: 82.8892), eval step-count: 22, average reward= 9.2663 (init= 9.0941), step count (max): 81, lr policy: 0.0003: 28%|██▊ | 14000/50000 [00:39<01:41, 355.31it/s]
eval cumulative reward: 212.8816 (init: 82.8892), eval step-count: 22, average reward= 9.2663 (init= 9.0941), step count (max): 81, lr policy: 0.0003: 30%|███ | 15000/50000 [00:42<01:37, 357.64it/s]
eval cumulative reward: 212.8816 (init: 82.8892), eval step-count: 22, average reward= 9.2876 (init= 9.0941), step count (max): 61, lr policy: 0.0002: 30%|███ | 15000/50000 [00:42<01:37, 357.64it/s]
eval cumulative reward: 212.8816 (init: 82.8892), eval step-count: 22, average reward= 9.2876 (init= 9.0941), step count (max): 61, lr policy: 0.0002: 32%|███▏ | 16000/50000 [00:45<01:34, 359.32it/s]
eval cumulative reward: 212.8816 (init: 82.8892), eval step-count: 22, average reward= 9.2780 (init= 9.0941), step count (max): 55, lr policy: 0.0002: 32%|███▏ | 16000/50000 [00:45<01:34, 359.32it/s]
eval cumulative reward: 212.8816 (init: 82.8892), eval step-count: 22, average reward= 9.2780 (init= 9.0941), step count (max): 55, lr policy: 0.0002: 34%|███▍ | 17000/50000 [00:48<01:31, 360.20it/s]
eval cumulative reward: 212.8816 (init: 82.8892), eval step-count: 22, average reward= 9.2695 (init= 9.0941), step count (max): 42, lr policy: 0.0002: 34%|███▍ | 17000/50000 [00:48<01:31, 360.20it/s]
eval cumulative reward: 212.8816 (init: 82.8892), eval step-count: 22, average reward= 9.2695 (init= 9.0941), step count (max): 42, lr policy: 0.0002: 36%|███▌ | 18000/50000 [00:50<01:28, 360.82it/s]
eval cumulative reward: 212.8816 (init: 82.8892), eval step-count: 22, average reward= 9.2753 (init= 9.0941), step count (max): 61, lr policy: 0.0002: 36%|███▌ | 18000/50000 [00:50<01:28, 360.82it/s]
eval cumulative reward: 212.8816 (init: 82.8892), eval step-count: 22, average reward= 9.2753 (init= 9.0941), step count (max): 61, lr policy: 0.0002: 38%|███▊ | 19000/50000 [00:53<01:25, 361.27it/s]
eval cumulative reward: 212.8816 (init: 82.8892), eval step-count: 22, average reward= 9.2847 (init= 9.0941), step count (max): 69, lr policy: 0.0002: 38%|███▊ | 19000/50000 [00:53<01:25, 361.27it/s]
eval cumulative reward: 212.8816 (init: 82.8892), eval step-count: 22, average reward= 9.2847 (init= 9.0941), step count (max): 69, lr policy: 0.0002: 40%|████ | 20000/50000 [00:56<01:24, 356.27it/s]
eval cumulative reward: 212.8816 (init: 82.8892), eval step-count: 22, average reward= 9.2992 (init= 9.0941), step count (max): 77, lr policy: 0.0002: 40%|████ | 20000/50000 [00:56<01:24, 356.27it/s]
eval cumulative reward: 212.8816 (init: 82.8892), eval step-count: 22, average reward= 9.2992 (init= 9.0941), step count (max): 77, lr policy: 0.0002: 42%|████▏ | 21000/50000 [00:59<01:20, 358.39it/s]
eval cumulative reward: 343.8970 (init: 82.8892), eval step-count: 36, average reward= 9.2825 (init= 9.0941), step count (max): 62, lr policy: 0.0002: 42%|████▏ | 21000/50000 [00:59<01:20, 358.39it/s]
eval cumulative reward: 343.8970 (init: 82.8892), eval step-count: 36, average reward= 9.2825 (init= 9.0941), step count (max): 62, lr policy: 0.0002: 44%|████▍ | 22000/50000 [01:02<01:18, 357.41it/s]
eval cumulative reward: 343.8970 (init: 82.8892), eval step-count: 36, average reward= 9.2864 (init= 9.0941), step count (max): 55, lr policy: 0.0002: 44%|████▍ | 22000/50000 [01:02<01:18, 357.41it/s]
eval cumulative reward: 343.8970 (init: 82.8892), eval step-count: 36, average reward= 9.2864 (init= 9.0941), step count (max): 55, lr policy: 0.0002: 46%|████▌ | 23000/50000 [01:04<01:15, 358.40it/s]
eval cumulative reward: 343.8970 (init: 82.8892), eval step-count: 36, average reward= 9.3024 (init= 9.0941), step count (max): 90, lr policy: 0.0002: 46%|████▌ | 23000/50000 [01:04<01:15, 358.40it/s]
eval cumulative reward: 343.8970 (init: 82.8892), eval step-count: 36, average reward= 9.3024 (init= 9.0941), step count (max): 90, lr policy: 0.0002: 48%|████▊ | 24000/50000 [01:07<01:12, 359.93it/s]
eval cumulative reward: 343.8970 (init: 82.8892), eval step-count: 36, average reward= 9.3009 (init= 9.0941), step count (max): 84, lr policy: 0.0002: 48%|████▊ | 24000/50000 [01:07<01:12, 359.93it/s]
eval cumulative reward: 343.8970 (init: 82.8892), eval step-count: 36, average reward= 9.3009 (init= 9.0941), step count (max): 84, lr policy: 0.0002: 50%|█████ | 25000/50000 [01:10<01:09, 361.25it/s]
eval cumulative reward: 343.8970 (init: 82.8892), eval step-count: 36, average reward= 9.2990 (init= 9.0941), step count (max): 93, lr policy: 0.0002: 50%|█████ | 25000/50000 [01:10<01:09, 361.25it/s]
eval cumulative reward: 343.8970 (init: 82.8892), eval step-count: 36, average reward= 9.2990 (init= 9.0941), step count (max): 93, lr policy: 0.0002: 52%|█████▏ | 26000/50000 [01:13<01:06, 361.82it/s]
eval cumulative reward: 343.8970 (init: 82.8892), eval step-count: 36, average reward= 9.2913 (init= 9.0941), step count (max): 57, lr policy: 0.0001: 52%|█████▏ | 26000/50000 [01:13<01:06, 361.82it/s]
eval cumulative reward: 343.8970 (init: 82.8892), eval step-count: 36, average reward= 9.2913 (init= 9.0941), step count (max): 57, lr policy: 0.0001: 54%|█████▍ | 27000/50000 [01:15<01:04, 356.54it/s]
eval cumulative reward: 343.8970 (init: 82.8892), eval step-count: 36, average reward= 9.2905 (init= 9.0941), step count (max): 51, lr policy: 0.0001: 54%|█████▍ | 27000/50000 [01:15<01:04, 356.54it/s]
eval cumulative reward: 343.8970 (init: 82.8892), eval step-count: 36, average reward= 9.2905 (init= 9.0941), step count (max): 51, lr policy: 0.0001: 56%|█████▌ | 28000/50000 [01:18<01:01, 358.12it/s]
eval cumulative reward: 343.8970 (init: 82.8892), eval step-count: 36, average reward= 9.2960 (init= 9.0941), step count (max): 66, lr policy: 0.0001: 56%|█████▌ | 28000/50000 [01:18<01:01, 358.12it/s]
eval cumulative reward: 343.8970 (init: 82.8892), eval step-count: 36, average reward= 9.2960 (init= 9.0941), step count (max): 66, lr policy: 0.0001: 58%|█████▊ | 29000/50000 [01:21<00:58, 358.96it/s]
eval cumulative reward: 343.8970 (init: 82.8892), eval step-count: 36, average reward= 9.2957 (init= 9.0941), step count (max): 85, lr policy: 0.0001: 58%|█████▊ | 29000/50000 [01:21<00:58, 358.96it/s]
eval cumulative reward: 343.8970 (init: 82.8892), eval step-count: 36, average reward= 9.2957 (init= 9.0941), step count (max): 85, lr policy: 0.0001: 60%|██████ | 30000/50000 [01:24<00:55, 359.97it/s]
eval cumulative reward: 343.8970 (init: 82.8892), eval step-count: 36, average reward= 9.3115 (init= 9.0941), step count (max): 72, lr policy: 0.0001: 60%|██████ | 30000/50000 [01:24<00:55, 359.97it/s]
eval cumulative reward: 343.8970 (init: 82.8892), eval step-count: 36, average reward= 9.3115 (init= 9.0941), step count (max): 72, lr policy: 0.0001: 62%|██████▏ | 31000/50000 [01:27<00:52, 360.61it/s]
eval cumulative reward: 512.9731 (init: 82.8892), eval step-count: 54, average reward= 9.2996 (init= 9.0941), step count (max): 75, lr policy: 0.0001: 62%|██████▏ | 31000/50000 [01:27<00:52, 360.61it/s]
eval cumulative reward: 512.9731 (init: 82.8892), eval step-count: 54, average reward= 9.2996 (init= 9.0941), step count (max): 75, lr policy: 0.0001: 64%|██████▍ | 32000/50000 [01:29<00:50, 357.88it/s]
eval cumulative reward: 512.9731 (init: 82.8892), eval step-count: 54, average reward= 9.3046 (init= 9.0941), step count (max): 137, lr policy: 0.0001: 64%|██████▍ | 32000/50000 [01:29<00:50, 357.88it/s]
eval cumulative reward: 512.9731 (init: 82.8892), eval step-count: 54, average reward= 9.3046 (init= 9.0941), step count (max): 137, lr policy: 0.0001: 66%|██████▌ | 33000/50000 [01:32<00:48, 353.91it/s]
eval cumulative reward: 512.9731 (init: 82.8892), eval step-count: 54, average reward= 9.3000 (init= 9.0941), step count (max): 76, lr policy: 0.0001: 66%|██████▌ | 33000/50000 [01:32<00:48, 353.91it/s]
eval cumulative reward: 512.9731 (init: 82.8892), eval step-count: 54, average reward= 9.3000 (init= 9.0941), step count (max): 76, lr policy: 0.0001: 68%|██████▊ | 34000/50000 [01:35<00:44, 356.73it/s]
eval cumulative reward: 512.9731 (init: 82.8892), eval step-count: 54, average reward= 9.2978 (init= 9.0941), step count (max): 105, lr policy: 0.0001: 68%|██████▊ | 34000/50000 [01:35<00:44, 356.73it/s]
eval cumulative reward: 512.9731 (init: 82.8892), eval step-count: 54, average reward= 9.2978 (init= 9.0941), step count (max): 105, lr policy: 0.0001: 70%|███████ | 35000/50000 [01:38<00:41, 358.22it/s]
eval cumulative reward: 512.9731 (init: 82.8892), eval step-count: 54, average reward= 9.3078 (init= 9.0941), step count (max): 95, lr policy: 0.0001: 70%|███████ | 35000/50000 [01:38<00:41, 358.22it/s]
eval cumulative reward: 512.9731 (init: 82.8892), eval step-count: 54, average reward= 9.3078 (init= 9.0941), step count (max): 95, lr policy: 0.0001: 72%|███████▏ | 36000/50000 [01:41<00:38, 359.68it/s]
eval cumulative reward: 512.9731 (init: 82.8892), eval step-count: 54, average reward= 9.3004 (init= 9.0941), step count (max): 65, lr policy: 0.0001: 72%|███████▏ | 36000/50000 [01:41<00:38, 359.68it/s]
eval cumulative reward: 512.9731 (init: 82.8892), eval step-count: 54, average reward= 9.3004 (init= 9.0941), step count (max): 65, lr policy: 0.0001: 74%|███████▍ | 37000/50000 [01:43<00:36, 359.95it/s]
eval cumulative reward: 512.9731 (init: 82.8892), eval step-count: 54, average reward= 9.3121 (init= 9.0941), step count (max): 105, lr policy: 0.0001: 74%|███████▍ | 37000/50000 [01:43<00:36, 359.95it/s]
eval cumulative reward: 512.9731 (init: 82.8892), eval step-count: 54, average reward= 9.3121 (init= 9.0941), step count (max): 105, lr policy: 0.0001: 76%|███████▌ | 38000/50000 [01:46<00:33, 360.34it/s]
eval cumulative reward: 512.9731 (init: 82.8892), eval step-count: 54, average reward= 9.3048 (init= 9.0941), step count (max): 95, lr policy: 0.0000: 76%|███████▌ | 38000/50000 [01:46<00:33, 360.34it/s]
eval cumulative reward: 512.9731 (init: 82.8892), eval step-count: 54, average reward= 9.3048 (init= 9.0941), step count (max): 95, lr policy: 0.0000: 78%|███████▊ | 39000/50000 [01:49<00:30, 361.20it/s]
eval cumulative reward: 512.9731 (init: 82.8892), eval step-count: 54, average reward= 9.3072 (init= 9.0941), step count (max): 94, lr policy: 0.0000: 78%|███████▊ | 39000/50000 [01:49<00:30, 361.20it/s]
eval cumulative reward: 512.9731 (init: 82.8892), eval step-count: 54, average reward= 9.3072 (init= 9.0941), step count (max): 94, lr policy: 0.0000: 80%|████████ | 40000/50000 [01:52<00:28, 356.76it/s]
eval cumulative reward: 512.9731 (init: 82.8892), eval step-count: 54, average reward= 9.3090 (init= 9.0941), step count (max): 76, lr policy: 0.0000: 80%|████████ | 40000/50000 [01:52<00:28, 356.76it/s]
eval cumulative reward: 512.9731 (init: 82.8892), eval step-count: 54, average reward= 9.3090 (init= 9.0941), step count (max): 76, lr policy: 0.0000: 82%|████████▏ | 41000/50000 [01:54<00:25, 358.52it/s]
eval cumulative reward: 718.6249 (init: 82.8892), eval step-count: 76, average reward= 9.3085 (init= 9.0941), step count (max): 88, lr policy: 0.0000: 82%|████████▏ | 41000/50000 [01:55<00:25, 358.52it/s]
eval cumulative reward: 718.6249 (init: 82.8892), eval step-count: 76, average reward= 9.3085 (init= 9.0941), step count (max): 88, lr policy: 0.0000: 84%|████████▍ | 42000/50000 [01:57<00:22, 355.16it/s]
eval cumulative reward: 718.6249 (init: 82.8892), eval step-count: 76, average reward= 9.3075 (init= 9.0941), step count (max): 75, lr policy: 0.0000: 84%|████████▍ | 42000/50000 [01:57<00:22, 355.16it/s]
eval cumulative reward: 718.6249 (init: 82.8892), eval step-count: 76, average reward= 9.3075 (init= 9.0941), step count (max): 75, lr policy: 0.0000: 86%|████████▌ | 43000/50000 [02:00<00:19, 357.56it/s]
eval cumulative reward: 718.6249 (init: 82.8892), eval step-count: 76, average reward= 9.3102 (init= 9.0941), step count (max): 96, lr policy: 0.0000: 86%|████████▌ | 43000/50000 [02:00<00:19, 357.56it/s]
eval cumulative reward: 718.6249 (init: 82.8892), eval step-count: 76, average reward= 9.3102 (init= 9.0941), step count (max): 96, lr policy: 0.0000: 88%|████████▊ | 44000/50000 [02:03<00:16, 358.51it/s]
eval cumulative reward: 718.6249 (init: 82.8892), eval step-count: 76, average reward= 9.3246 (init= 9.0941), step count (max): 163, lr policy: 0.0000: 88%|████████▊ | 44000/50000 [02:03<00:16, 358.51it/s]
eval cumulative reward: 718.6249 (init: 82.8892), eval step-count: 76, average reward= 9.3246 (init= 9.0941), step count (max): 163, lr policy: 0.0000: 90%|█████████ | 45000/50000 [02:06<00:13, 359.31it/s]
eval cumulative reward: 718.6249 (init: 82.8892), eval step-count: 76, average reward= 9.3108 (init= 9.0941), step count (max): 89, lr policy: 0.0000: 90%|█████████ | 45000/50000 [02:06<00:13, 359.31it/s]
eval cumulative reward: 718.6249 (init: 82.8892), eval step-count: 76, average reward= 9.3108 (init= 9.0941), step count (max): 89, lr policy: 0.0000: 92%|█████████▏| 46000/50000 [02:08<00:11, 359.28it/s]
eval cumulative reward: 718.6249 (init: 82.8892), eval step-count: 76, average reward= 9.3064 (init= 9.0941), step count (max): 82, lr policy: 0.0000: 92%|█████████▏| 46000/50000 [02:08<00:11, 359.28it/s]
eval cumulative reward: 718.6249 (init: 82.8892), eval step-count: 76, average reward= 9.3064 (init= 9.0941), step count (max): 82, lr policy: 0.0000: 94%|█████████▍| 47000/50000 [02:11<00:08, 353.94it/s]
eval cumulative reward: 718.6249 (init: 82.8892), eval step-count: 76, average reward= 9.3119 (init= 9.0941), step count (max): 102, lr policy: 0.0000: 94%|█████████▍| 47000/50000 [02:11<00:08, 353.94it/s]
eval cumulative reward: 718.6249 (init: 82.8892), eval step-count: 76, average reward= 9.3119 (init= 9.0941), step count (max): 102, lr policy: 0.0000: 96%|█████████▌| 48000/50000 [02:14<00:05, 356.89it/s]
eval cumulative reward: 718.6249 (init: 82.8892), eval step-count: 76, average reward= 9.3227 (init= 9.0941), step count (max): 140, lr policy: 0.0000: 96%|█████████▌| 48000/50000 [02:14<00:05, 356.89it/s]
eval cumulative reward: 718.6249 (init: 82.8892), eval step-count: 76, average reward= 9.3227 (init= 9.0941), step count (max): 140, lr policy: 0.0000: 98%|█████████▊| 49000/50000 [02:17<00:02, 358.40it/s]
eval cumulative reward: 718.6249 (init: 82.8892), eval step-count: 76, average reward= 9.3111 (init= 9.0941), step count (max): 87, lr policy: 0.0000: 98%|█████████▊| 49000/50000 [02:17<00:02, 358.40it/s]
eval cumulative reward: 718.6249 (init: 82.8892), eval step-count: 76, average reward= 9.3111 (init= 9.0941), step count (max): 87, lr policy: 0.0000: 100%|██████████| 50000/50000 [02:20<00:00, 359.59it/s]
eval cumulative reward: 718.6249 (init: 82.8892), eval step-count: 76, average reward= 9.3161 (init= 9.0941), step count (max): 86, lr policy: 0.0000: 100%|██████████| 50000/50000 [02:20<00:00, 359.59it/s]
結果#
在達到 1M 步的上限之前,演算法應該已經達到了 1000 步的最大步數,這是軌跡被截斷之前的最大步數。
plt.figure(figsize=(10, 10))
plt.subplot(2, 2, 1)
plt.plot(logs["reward"])
plt.title("training rewards (average)")
plt.subplot(2, 2, 2)
plt.plot(logs["step_count"])
plt.title("Max step count (training)")
plt.subplot(2, 2, 3)
plt.plot(logs["eval reward (sum)"])
plt.title("Return (test)")
plt.subplot(2, 2, 4)
plt.plot(logs["eval step_count"])
plt.title("Max step count (test)")
plt.show()

結論和後續步驟#
在本教程中,我們學習了
如何使用
torchrl建立和自定義環境;如何編寫模型和損失函式;
如何設定典型的訓練迴圈。
如果您想進一步嘗試本教程,可以進行以下修改
從效率的角度來看,我們可以並行執行多個模擬來加快資料收集。有關更多資訊,請檢視
ParallelEnv。從日誌記錄的角度來看,可以在請求渲染後向環境新增一個
torchrl.record.VideoRecorder變換,以獲得倒立擺執行的視覺化渲染。有關更多資訊,請檢視torchrl.record。
指令碼總執行時間: (2 分 22.007 秒)