快捷方式

DistributedWeightUpdater

class torchrl.collectors.distributed.DistributedWeightUpdater(store: dict[str, str], policy_weights: TensorDictBase, num_workers: int, sync: bool)[source]

一個遠端權重更新器,用於跨分散式工作節點同步策略權重。

DistributedWeightUpdater 類提供了一種機制,用於跨分散式推理工作節點更新策略的權重。它被設計為與 DistributedDataCollector 一起工作,以確保每個工作節點都能獲得最新的策略權重。此類通常用於分散式資料收集場景,其中需要多個工作節點與中央策略權重保持同步。

引數:
  • store (dict[str, str]) – 一個字典式儲存,用於伺服器與分散式工作節點之間的通訊。

  • policy_weights (TensorDictBase) – 需要分發給工作節點的策略的當前權重。

  • num_workers (int) – 將接收更新的策略權重的分散式工作節點的數量。

  • sync (bool) – 如果為 True,則同步發生(伺服器等待工作節點完成更新後才能重新開始執行)。

update_weights()

更新指定或所有分散式工作節點上的權重。

all_worker_ids()[source]

返回所有工作節點識別符號的列表(此類中未實現)。

_sync_weights_with_worker()[source]

將伺服器權重與特定工作節點同步(未實現)。

_get_server_weights()[source]

從伺服器檢索最新權重(未實現)。

_maybe_map_weights()[source]

在分發前可選地對映伺服器權重(未實現)。

注意

此類假定伺服器權重可以直接應用於分散式工作節點,而無需任何額外處理。如果您的用例需要更復雜的權重對映或同步邏輯,請考慮繼承 WeightUpdaterBase 並進行自定義實現。

丟擲:

RuntimeError – 如果工作節點 rank 小於 1 或從儲存返回的狀態不是“updated”。

all_worker_ids() list[int] | list[torch.device][source]

獲取所有工作程序 ID 的列表。

預設返回 None。子類應覆蓋以返回實際的工作程序 ID。

返回:

工作程序 ID 列表或 None。

返回型別:

list[int] | list[torch.device] | None

property collector: Any | None

接收器的收集器或容器。

如果容器超出範圍或未設定,則返回None

property collectors: list[Any] | None

收集器或接收者容器。

classmethod from_policy(policy: TensorDictModuleBase) WeightUpdaterBase | None

可選的類方法,用於從策略建立權重更新器例項。

子類可以實現此方法以提供基於策略的自定義初始化邏輯。如果實現,將在收集器中初始化權重更新器時呼叫此方法,然後再回退到預設建構函式。

引數:

policy (TensorDictModuleBase) – 要從中建立權重更新器的策略。

返回:

權重更新器的例項,或者如果策略無法建立例項則為 None。

無法用於建立例項的例項。

返回型別:

WeightUpdaterBase | None

increment_version()

增加策略版本。

init(*args, **kwargs)

使用自定義引數初始化權重更新器。

子類可以覆蓋此方法以處理自定義初始化。預設情況下,這是一個無操作。

引數:
  • *args – 初始化位置引數

  • **kwargs – 初始化關鍵字引數

property post_hooks: list[collections.abc.Callable[[], None]]

註冊到權重更新器的後置鉤子列表。

push_weights(policy_or_weights: TensorDictModuleBase | TensorDictBase | dict | None = None, worker_ids: torch.device | int | list[int] | list[torch.device] | None = None)

更新策略的權重,或在指定/所有遠端工作程序上更新。

引數:
  • policy_or_weights – 從中獲取權重的來源。可以是: - TensorDictModuleBase:將提取權重的策略模組 - TensorDictBase:包含權重的 TensorDict - dict:一個包含權重的普通字典 - None:將嘗試使用 _get_server_weights() 從伺服器獲取權重。

  • worker_ids – 要更新的工作程序的可選列表。

返回:無。

register_collector(collector)

在更新器中註冊一個收集器。

註冊後,更新器將不再接受另一個收集器。

引數:

collector (DataCollectorBase) – 要註冊的 collector。

register_post_hook(hook: Callable[[], None])

註冊一個後置鉤子,在權重更新後呼叫。

引數:

hook (Callable[[], None]) – 要註冊的後置鉤子。

文件

訪問全面的 PyTorch 開發者文件

檢視文件

教程

為初學者和高階開發者提供深入的教程

檢視教程

資源

查詢開發資源並讓您的問題得到解答

檢視資源