Trainer¶
- class torchrl.trainers.Trainer(*args, **kwargs)[source]¶
一個通用的 Trainer 類。
Trainer 負責收集資料和訓練模型。為了使該類儘可能通用,Trainer 不會構建任何特定操作:所有操作都必須在訓練迴圈的特定點進行鉤接。
要構建一個 Trainer,需要一個可迭代的資料來源(一個
collector),一個損失模組和一個最佳化器。- 引數:
collector (Sequence[TensorDictBase]) – 一個返回 TensorDict 格式資料批次的迭代器,形狀為 [batch x time steps]。
total_frames (int) – 訓練期間要收集的總幀數。
loss_module (LossModule) – 一個讀取 TensorDict 批次(可能從回放緩衝區取樣)並返回損失 TensorDict 的模組,其中每個鍵都指向不同的損失元件。
optimizer (optim.Optimizer) – 一個用於訓練模型引數的最佳化器。
logger (Logger, optional) – 一個將處理日誌記錄的 Logger。
optim_steps_per_batch (int, optional) – 每個資料收集批次的最佳化步數。Trainer 的工作原理如下:主迴圈收集資料批次(epoch loop),子迴圈(training loop)在兩次資料收集之間執行模型更新。如果為 None,則 trainer 將使用 worker 的數量作為最佳化步數。
clip_grad_norm (bool, optional) – 如果為 True,則梯度將根據模型引數的總範數進行裁剪。如果為 False,則所有偏導數都將被限制在 (-clip_norm, clip_norm) 範圍內。預設為
True。clip_norm (Number, optional) – 用於裁剪梯度的值。預設為 None(無裁剪範數)。
progress_bar (bool, optional) – 如果為 True,將使用 tqdm 顯示進度條。如果未安裝 tqdm,則此選項無效。預設為
Trueseed (int, optional) – 將用於 collector、pytorch 和 numpy 的種子。預設為
None。save_trainer_interval (int, optional) – Trainer 儲存到磁碟的頻率,以幀數計。預設為 10000。
log_interval (int, optional) – 值記錄的頻率,以幀數計。預設為 10000。
save_trainer_file (path, optional) – 儲存 trainer 的路徑。預設為 None(不儲存)