快捷方式

DataCollectorBase

class torchrl.collectors.DataCollectorBase[原始碼]

資料收集器的基類。

async_shutdown(timeout: float | None = None, close_env: bool = True) None[原始碼]

當收集器透過 start 方法非同步啟動時,關閉收集器。

引數

timeout (float, optional): 等待收集器關閉的最長時間。 close_env (bool, optional): 如果為 True,收集器將關閉包含的環境。

預設為 True

另請參閱

start()

init_updater(*args, **kwargs)[原始碼]

使用自定義引數初始化權重更新器。

此方法將引數傳遞給權重更新器的 init 方法。如果未設定權重更新器,則此方法無效。

引數:
  • *args – 用於權重更新器初始化的位置引數

  • **kwargs – 用於權重更新器初始化的關鍵字引數

pause()[原始碼]

上下文管理器,如果收集器正在自由執行,則暫停收集器。

start()[原始碼]

啟動收集器以進行非同步資料收集。

此方法啟動後臺資料收集,允許資料收集和訓練解耦。

收集的資料通常儲存在收集器初始化期間傳入的經驗回放緩衝區中。

注意

呼叫此方法後,在使用完畢後務必使用 async_shutdown() 關閉收集器以釋放資源。

警告

由於其解耦的性質,非同步資料收集可能會顯著影響訓練效能。在使用此模式之前,請確保瞭解其對您特定演算法的影響。

丟擲:

NotImplementedError – 如果子類未實現。

update_policy_weights_(policy_or_weights: TensorDictBase | TensorDictModuleBase | dict | None = None, *, worker_ids: int | list[int] | torch.device | list[torch.device] | None = None, **kwargs) None[原始碼]

更新資料收集器的策略權重,支援本地和遠端執行上下文。

此方法確保資料收集器使用的策略權重與最新的訓練權重同步。它支援本地和遠端權重更新,具體取決於資料收集器的配置。本地(下載)更新在遠端(上傳)更新之前執行,以便可以將權重從伺服器傳輸到子工作器。

引數:
  • policy_or_weights (TensorDictBase | TensorDictModuleBase | dict | None) – 要更新的權重。可以是: - TensorDictModuleBase:將提取其權重的策略模組 - TensorDictBase:包含權重的 TensorDict - dict:包含權重的常規 dict - None:將嘗試使用 _get_server_weights() 從伺服器獲取權重

  • worker_ids (int | List[int] | torch.device | List[torch.device] | None, optional) – 需要更新的工作器的識別符號。當收集器關聯多個工作器時,此項很重要。

丟擲:

TypeError – 如果提供了 worker_ids 但未配置 weight_updater

注意

使用者應擴充套件 WeightUpdaterBase 類來定製特定用例的權重更新邏輯。不應覆蓋此方法。

另請參閱

LocalWeightsUpdaterBaseRemoteWeightsUpdaterBase()

文件

訪問全面的 PyTorch 開發者文件

檢視文件

教程

為初學者和高階開發者提供深入的教程

檢視教程

資源

查詢開發資源並讓您的問題得到解答

檢視資源