注意
跳轉到末尾 下載完整的示例程式碼。
模型最佳化入門¶
注意
要在 notebook 中執行本教程,請在開頭新增一個安裝單元格,其中包含:
!pip install tensordict !pip install torchrl
在 TorchRL 中,我們嘗試按照 PyTorch 的慣例來處理最佳化,使用專用的損失模組,這些模組的唯一目的是最佳化模型。這種方法有效地將策略的執行與其訓練分離開來,使我們能夠設計出與傳統監督學習示例中找到的訓練迴圈相似的訓練迴圈。
因此,典型的訓練迴圈如下所示:
..code - block::Python
>>> for i in range(n_collections): ... data = get_next_batch(env, policy) ... for j in range(n_optim): ... loss = loss_fn(data) ... loss.backward() ... optim.step()
在本簡短教程中,您將獲得對損失模組的簡要概述。由於基本用法的 API 通常很簡單,因此本教程將保持簡短。
強化學習目標函式¶
在強化學習中,創新通常涉及探索最佳化策略的新方法(即新演算法),而不是像其他領域那樣專注於新架構。在 TorchRL 中,這些演算法被封裝在損失模組中。損失模組協調您演算法的各個元件,併產生一組損失值,這些值可以被反向傳播以訓練相應的元件。
在本教程中,我們將以一個流行的離策略演算法為例,DDPG。
要構建損失模組,唯一需要的是定義為 :class:`~tensordict.nn.TensorDictModule` 的網路集。大多數情況下,其中一個模組將是策略。也可能需要其他輔助網路,例如 Q 值網路或某種形式的評估器。讓我們看看實際情況:DDPG 需要一個從觀察空間到動作空間的確定性對映,以及一個預測狀態-動作對值的價值網路。DDPG 損失將嘗試找到能夠輸出最大化給定狀態價值的動作的策略引數。
為了構建損失,我們需要 Actor 和 Value 網路。如果它們是按照 DDPG 的預期構建的,那麼它們就是我們獲得可訓練損失模組所需的一切。
from torchrl.envs import GymEnv
env = GymEnv("Pendulum-v1")
from torchrl.modules import Actor, MLP, ValueOperator
from torchrl.objectives import DDPGLoss
n_obs = env.observation_spec["observation"].shape[-1]
n_act = env.action_spec.shape[-1]
actor = Actor(MLP(in_features=n_obs, out_features=n_act, num_cells=[32, 32]))
value_net = ValueOperator(
MLP(in_features=n_obs + n_act, out_features=1, num_cells=[32, 32]),
in_keys=["observation", "action"],
)
ddpg_loss = DDPGLoss(actor_network=actor, value_network=value_net)
僅此而已!我們的損失模組現在可以使用來自環境的資料執行(我們省略了探索、儲存和其他功能,以專注於損失功能)。
rollout = env.rollout(max_steps=100, policy=actor)
loss_vals = ddpg_loss(rollout)
print(loss_vals)
LossModule 的輸出¶
如您所見,我們從損失中收到的值不是單個標量,而是一個包含多個損失的字典。
原因很簡單:因為可能一次會訓練多個網路,並且由於某些使用者可能希望將每個模組的最佳化分成不同的步驟,因此 TorchRL 的目標將返回包含各種損失元件的字典。
這種格式還允許我們將元資料與損失值一起傳遞。通常,我們確保只有損失值是可微分的,因此您可以簡單地對字典中的值求和以獲得總損失。如果您想確保您完全控制正在發生的事情,您可以只對鍵以 "loss_" 字首開頭的條目求和。
total_loss = 0
for key, val in loss_vals.items():
if key.startswith("loss_"):
total_loss += val
訓練 LossModule¶
鑑於以上所有內容,訓練模組與任何其他訓練迴圈中的操作並沒有太大區別。因為它封裝了模組,所以獲取可訓練引數列表的最簡單方法是查詢 parameters() 方法。
我們將需要一個最佳化器(或者如果您選擇,每個模組一個最佳化器)。
以下專案通常會在您的訓練迴圈中找到:
optim.step()
optim.zero_grad()
進一步考慮:目標引數¶
另一個需要考慮的重要方面是 DDPG 等離策略演算法中的目標引數。目標引數通常代表引數隨時間延遲或平滑的版本,並且它們在策略訓練期間的價值估計中起著至關重要的作用。與使用當前價值網路引數配置相比,使用目標引數進行策略訓練通常效率更高。通常,目標引數的管理由損失模組處理,讓使用者無需直接關心。但是,使用者仍有責任根據具體要求及時更新這些值。TorchRL 提供了幾個更新器,即 HardUpdate 和 SoftUpdate,它們可以輕鬆例項化,而無需深入瞭解損失模組的底層機制。
from torchrl.objectives import SoftUpdate
updater = SoftUpdate(ddpg_loss, eps=0.99)
在您的訓練迴圈中,您需要在每個最佳化步驟或每個收集步驟中更新目標引數。
updater.step()
以上是您入門所需的有關損失模組的所有知識!
要進一步探索該主題,請檢視:
《損失模組參考頁面》;
《編寫 DDPG 損失教程》;