快捷方式

WeightUpdaterBase

class torchrl.collectors.WeightUpdaterBase[來源]

用於在推理工作器上更新遠端策略權重的基類。

權重更新器是權重更新方案的核心部分

  • 在葉收集器節點中,它負責將權重發送到策略,這可以很簡單地更新 state-dict,也可以更復雜,如果使用了推理伺服器。

  • 在伺服器收集器節點中,它負責將權重發送到葉收集器。

在收集器中,更新器在 update_policy_weights_() 中被呼叫。

此類的主方法是 _push_weights() 方法,它更新工作器/策略中的策略權重。此方法由 push_weights() 呼叫,該方法也呼叫後置鉤子:只有 _push_weights 應由子類實現。

要擴充套件此類,請實現以下抽象方法

  • _get_server_weights (可選):定義如何從伺服器檢索權重,如果它們未傳遞給

    更新器。僅當權重(控制代碼)未直接傳遞時,才呼叫此方法。

  • _sync_weights_with_worker:定義如何與特定工作器同步權重。

    此方法必須由子類實現。

  • _maybe_map_weights:可選地在分發之前轉換伺服器權重。

    預設情況下,此方法返回的權重不變。

  • all_worker_ids:提供所有工作器識別符號的列表。

    預設返回 None(無工作器 ID)。

  • from_policy(可選類方法):定義如何從策略建立權重更新器的例項。

    如果實現,在初始化收集器中的權重更新器時,將先呼叫此方法,然後再回退到預設建構函式。

變數:

collector – 權重接收器的收集器(或任何容器)。收集器透過 register_collector() 註冊。

push_weights()[來源]

在指定的或所有遠端工作器上更新權重。__call__ 方法是 push_weights 的代理。

register_collector()[來源]

透過弱引用在接收器中註冊收集器(或任何容器)。當更新器註冊時,收集器將自動呼叫此方法。

from_policy()[來源]

可選的類方法,用於從策略建立例項。

後置鉤子
  • register_post_hook:註冊一個後置鉤子,在權重更新後被呼叫。

    後置鉤子必須是一個不接受引數的可呼叫物件。後置鉤子將在權重更新後被呼叫。後置鉤子將在與權重更新器相同的程序中呼叫。後置鉤子將按照註冊後置鉤子的順序呼叫。

另請參閱

update_policy_weights_().

all_worker_ids() list[int] | list[torch.device] | None[來源]

獲取所有工作程序 ID 的列表。

預設返回 None。子類應覆蓋以返回實際的工作程序 ID。

返回:

工作程序 ID 列表或 None。

返回型別:

list[int] | list[torch.device] | None

property collector: Any | None

接收器的收集器或容器。

如果容器超出範圍或未設定,則返回None

property collectors: list[Any] | None

收集器或接收者容器。

classmethod from_policy(policy: TensorDictModuleBase) WeightUpdaterBase | None[來源]

可選的類方法,用於從策略建立權重更新器例項。

子類可以實現此方法以提供基於策略的自定義初始化邏輯。如果實現,將在收集器中初始化權重更新器時呼叫此方法,然後再回退到預設建構函式。

引數:

policy (TensorDictModuleBase) – 要從中建立權重更新器的策略。

返回:

權重更新器的例項,或者如果策略無法建立例項則為 None。

無法用於建立例項的例項。

返回型別:

WeightUpdaterBase | None

increment_version()[來源]

增加策略版本。

init(*args, **kwargs)[來源]

使用自定義引數初始化權重更新器。

子類可以覆蓋此方法以處理自定義初始化。預設情況下,這是一個無操作。

引數:
  • *args – 初始化位置引數

  • **kwargs – 初始化關鍵字引數

property post_hooks: list[collections.abc.Callable[[], None]]

註冊到權重更新器的後置鉤子列表。

push_weights(policy_or_weights: TensorDictModuleBase | TensorDictBase | dict | None = None, worker_ids: torch.device | int | list[int] | list[torch.device] | None = None)[來源]

更新策略的權重,或在指定/所有遠端工作程序上更新。

引數:
  • policy_or_weights – 從中獲取權重的來源。可以是: - TensorDictModuleBase:將提取權重的策略模組 - TensorDictBase:包含權重的 TensorDict - dict:一個包含權重的普通字典 - None:將嘗試使用 _get_server_weights() 從伺服器獲取權重。

  • worker_ids – 要更新的工作程序的可選列表。

返回:無。

register_collector(collector)[來源]

在更新器中註冊一個收集器。

註冊後,更新器將不再接受另一個收集器。

引數:

collector (DataCollectorBase) – 要註冊的 collector。

register_post_hook(hook: Callable[[], None])[來源]

註冊一個後置鉤子,在權重更新後呼叫。

引數:

hook (Callable[[], None]) – 要註冊的後置鉤子。

文件

訪問全面的 PyTorch 開發者文件

檢視文件

教程

為初學者和高階開發者提供深入的教程

檢視教程

資源

查詢開發資源並讓您的問題得到解答

檢視資源