DataCollectorBase¶
- class torchrl.collectors.DataCollectorBase[原始碼]¶
資料收集器的基類。
- async_shutdown(timeout: float | None = None, close_env: bool = True) None[原始碼]¶
當收集器透過 start 方法非同步啟動時,關閉收集器。
- 引數
timeout (float, optional): 等待收集器關閉的最長時間。 close_env (bool, optional): 如果為 True,收集器將關閉包含的環境。
預設為 True。
另請參閱
- init_updater(*args, **kwargs)[原始碼]¶
使用自定義引數初始化權重更新器。
此方法將引數傳遞給權重更新器的 init 方法。如果未設定權重更新器,則此方法無效。
- 引數:
*args – 用於權重更新器初始化的位置引數
**kwargs – 用於權重更新器初始化的關鍵字引數
- start()[原始碼]¶
啟動收集器以進行非同步資料收集。
此方法啟動後臺資料收集,允許資料收集和訓練解耦。
收集的資料通常儲存在收集器初始化期間傳入的經驗回放緩衝區中。
注意
呼叫此方法後,在使用完畢後務必使用
async_shutdown()關閉收集器以釋放資源。警告
由於其解耦的性質,非同步資料收集可能會顯著影響訓練效能。在使用此模式之前,請確保瞭解其對您特定演算法的影響。
- 丟擲:
NotImplementedError – 如果子類未實現。
- update_policy_weights_(policy_or_weights: TensorDictBase | TensorDictModuleBase | dict | None = None, *, worker_ids: int | list[int] | torch.device | list[torch.device] | None = None, **kwargs) None[原始碼]¶
更新資料收集器的策略權重,支援本地和遠端執行上下文。
此方法確保資料收集器使用的策略權重與最新的訓練權重同步。它支援本地和遠端權重更新,具體取決於資料收集器的配置。本地(下載)更新在遠端(上傳)更新之前執行,以便可以將權重從伺服器傳輸到子工作器。
- 引數:
policy_or_weights (TensorDictBase | TensorDictModuleBase | dict | None) – 要更新的權重。可以是: - TensorDictModuleBase:將提取其權重的策略模組 - TensorDictBase:包含權重的 TensorDict - dict:包含權重的常規 dict - None:將嘗試使用 _get_server_weights() 從伺服器獲取權重
worker_ids (int | List[int] | torch.device | List[torch.device] | None, optional) – 需要更新的工作器的識別符號。當收集器關聯多個工作器時,此項很重要。
- 丟擲:
TypeError – 如果提供了 worker_ids 但未配置 weight_updater。
注意
使用者應擴充套件 WeightUpdaterBase 類來定製特定用例的權重更新邏輯。不應覆蓋此方法。
另請參閱
LocalWeightsUpdaterBase和RemoteWeightsUpdaterBase()。