VanillaWeightUpdater¶
- class torchrl.collectors.VanillaWeightUpdater(*, weight_getter: Callable[[], TensorDictBase] | None = None, policy_weights: TensorDictBase)[source]¶
一個簡單的
WeightUpdaterBase實現,用於更新本地策略的權重。“VanillaWeightSender”類提供了一種透過直接從指定源獲取權重來更新本地策略權重的基本機制。它通常用於權重更新邏輯簡單且不需要任何複雜對映或轉換的場景。
當未提供自定義權重發送器時,此類的 SyncDataCollector 預設使用它。
另請參閱
- 關鍵字引數:
weight_getter (Callable[[], TensorDictBase], optional) – 一個返回伺服器權重的可呼叫物件。如果未提供,則必須將權重直接傳遞給
update_weights()。policy_weights (TensorDictBase) – 一個 TensorDictBase,其中包含要就地更新的策略權重。
- all_worker_ids() list[int] | list[torch.device] | None¶
獲取所有工作程序 ID 的列表。
預設返回 None。子類應覆蓋以返回實際的工作程序 ID。
- 返回:
工作程序 ID 列表或 None。
- 返回型別:
list[int] | list[torch.device] | None
- property collector: Any | None¶
接收器的收集器或容器。
如果容器超出範圍或未設定,則返回None。
- property collectors: list[Any] | None¶
收集器或接收者容器。
- classmethod from_policy(policy: TensorDictModuleBase) WeightUpdaterBase | None[source]¶
從策略建立 VanillaWeightUpdater 例項。
此方法建立一個權重更新器,該更新器將直接使用策略的 state dict 來更新其權重。
- 引數:
policy (TensorDictModuleBase) – 要從中建立權重更新器的策略。
- 返回:
- 已配置為更新的權重更新器例項
策略的權重。
- 返回型別:
- increment_version()¶
增加策略版本。
- init(*args, **kwargs)¶
使用自定義引數初始化權重更新器。
子類可以覆蓋此方法以處理自定義初始化。預設情況下,這是一個無操作。
- 引數:
*args – 初始化位置引數
**kwargs – 初始化關鍵字引數
- property post_hooks: list[collections.abc.Callable[[], None]]¶
註冊到權重更新器的後置鉤子列表。
- push_weights(policy_or_weights: TensorDictModuleBase | TensorDictBase | dict | None = None, worker_ids: torch.device | int | list[int] | list[torch.device] | None = None)¶
更新策略的權重,或在指定/所有遠端工作程序上更新。
- 引數:
policy_or_weights – 從中獲取權重的來源。可以是: - TensorDictModuleBase:將提取權重的策略模組 - TensorDictBase:包含權重的 TensorDict - dict:一個包含權重的普通字典 - None:將嘗試使用 _get_server_weights() 從伺服器獲取權重。
worker_ids – 要更新的工作程序的可選列表。
返回:無。
- register_collector(collector)¶
在更新器中註冊一個收集器。
註冊後,更新器將不再接受另一個收集器。
- 引數:
collector (DataCollectorBase) – 要註冊的 collector。
- register_post_hook(hook: Callable[[], None])¶
註冊一個後置鉤子,在權重更新後呼叫。
- 引數:
hook (Callable[[], None]) – 要註冊的後置鉤子。