快捷方式

RPCWeightUpdater

class torchrl.collectors.distributed.RPCWeightUpdater(collector_infos, collector_class, collector_rrefs, policy_weights: TensorDictBase, num_workers: int)[原始碼]

一個遠端權重更新器,用於使用 RPC 在遠端工作者之間同步策略權重。

The RPCWeightUpdater class provides a mechanism for updating the weights of a policy across remote inference workers using RPC. It is designed to work with the RPCDataCollector to ensure that each worker receives the latest policy weights. This class is typically used in distributed data collection scenarios where remote workers are managed via RPC and need to be kept in sync with the central policy weights。

引數:
  • collector_infos – 關於 collector 的資訊,用於 RPC 通訊。

  • collector_class – 所使用的 collector 的類。

  • collector_rrefs – collector 的遠端引用。

  • policy_weights (TensorDictBase) – 需要分發給工作者的策略的當前權重。

  • num_workers (int) – 將接收更新策略權重的遠端工作者的數量。

update_weights()

使用 RPC 更新指定或所有遠端工作者的權重。

all_worker_ids()[原始碼]

返回所有工作者識別符號的列表(在此類中未實現)。

_sync_weights_with_worker()[原始碼]

將伺服器權重與特定工作者同步(未實現)。

_get_server_weights()[原始碼]

從伺服器檢索最新權重(未實現)。

_maybe_map_weights()[原始碼]

分發前可選地對映伺服器權重(未實現)。

注意

該類假定伺服器權重可以直接應用於遠端工作者,而無需任何額外的處理。如果您的用例需要更復雜的權重對映或同步邏輯,請考慮擴充套件 WeightUpdaterBase 並實現自定義實現。

另請參閱

WeightUpdaterBaseRPCDataCollector

all_worker_ids() list[int] | list[torch.device][原始碼]

獲取所有工作程序 ID 的列表。

預設返回 None。子類應覆蓋以返回實際的工作程序 ID。

返回:

工作程序 ID 列表或 None。

返回型別:

list[int] | list[torch.device] | None

property collector: Any | None

接收器的收集器或容器。

如果容器超出範圍或未設定,則返回None

property collectors: list[Any] | None

收集器或接收者容器。

classmethod from_policy(policy: TensorDictModuleBase) WeightUpdaterBase | None

可選的類方法,用於從策略建立權重更新器例項。

子類可以實現此方法以提供基於策略的自定義初始化邏輯。如果實現,將在收集器中初始化權重更新器時呼叫此方法,然後再回退到預設建構函式。

引數:

policy (TensorDictModuleBase) – 要從中建立權重更新器的策略。

返回:

權重更新器的例項,或者如果策略無法建立例項則為 None。

無法用於建立例項的例項。

返回型別:

WeightUpdaterBase | None

increment_version()

增加策略版本。

init(*args, **kwargs)

使用自定義引數初始化權重更新器。

子類可以覆蓋此方法以處理自定義初始化。預設情況下,這是一個無操作。

引數:
  • *args – 初始化位置引數

  • **kwargs – 初始化關鍵字引數

property post_hooks: list[collections.abc.Callable[[], None]]

註冊到權重更新器的後置鉤子列表。

push_weights(weights: TensorDictBase | None = None, worker_ids: torch.device | int | list[int] | list[torch.device] | None = None, **kwargs)[原始碼]

更新策略的權重,或在指定/所有遠端工作程序上更新。

引數:
  • policy_or_weights – 從中獲取權重的來源。可以是: - TensorDictModuleBase:將提取權重的策略模組 - TensorDictBase:包含權重的 TensorDict - dict:一個包含權重的普通字典 - None:將嘗試使用 _get_server_weights() 從伺服器獲取權重。

  • worker_ids – 要更新的工作程序的可選列表。

返回:無。

register_collector(collector)

在更新器中註冊一個收集器。

註冊後,更新器將不再接受另一個收集器。

引數:

collector (DataCollectorBase) – 要註冊的 collector。

register_post_hook(hook: Callable[[], None])

註冊一個後置鉤子,在權重更新後呼叫。

引數:

hook (Callable[[], None]) – 要註冊的後置鉤子。

文件

訪問全面的 PyTorch 開發者文件

檢視文件

教程

為初學者和高階開發者提供深入的教程

檢視教程

資源

查詢開發資源並讓您的問題得到解答

檢視資源