快捷方式

vLLMUpdater

class torchrl.collectors.llm.vLLMUpdater(*args, v2=False, **kwargs)[原始碼]

將權重發送到 vLLM 工作節點的類。

此類負責在訓練策略和 vLLM 推理工作節點之間同步權重。它支援本地 vLLM 例項和遠端 Ray Actor。

引數:
  • master_address (str, optional) – 分散式訓練的主地址。預設為 localhost。

  • master_port (int, optional) – 分散式訓練的主埠。如果為 None,則會自動分配。

  • model_metadata (dict[str, tuple[torch.dtype, torch.Size]], optional) – 模型元資料,將引數名稱對映到它們的 dtype 和 shape。如果未提供,將從策略中提取。

  • vllm_tp_size (int, optional) – vLLM 的張量並行大小。預設為 1。

  • v2 (bool, optional) – 如果為 True,則返回 vLLMUpdaterV2 例項。這是一個實驗性功能,提供了與 AsyncVLLM 引擎更好的整合。使用 v2=True 時,必須提供 vllm_engine 引數而不是上述引數。預設為 False。

init()[原始碼]

使用模型元資料初始化更新器並初始化組。

_sync_weights_with_worker()[原始碼]

與 vLLM 工作節點同步權重。

_get_server_weights()[原始碼]

未使用 - 必須直接傳遞權重。

_maybe_map_weights()[原始碼]

無需對映。

all_worker_ids()[原始碼]

返回 [0],因為我們只有一個工作節點。

注意

此類假定策略是一個可由 vLLM 載入的 Transformers 模型。策略必須具有 state_dict() 方法,該方法返回模型權重。

警告

v2=True 選項是實驗性的,在未來的版本中可能會有向後不相容的更改。但是,它通常被認為是與 AsyncVLLM 引擎配合使用的更好選擇,並提供更好的效能和可靠性。

all_worker_ids() list[int][原始碼]

返回 [0],因為我們只有一個工作節點。

property collector: Any | None

接收器的收集器或容器。

如果容器超出範圍或未設定,則返回None

property collectors: list[Any] | None

收集器或接收者容器。

classmethod from_policy(policy: TensorDictModuleBase) WeightUpdaterBase | None

可選的類方法,用於從策略建立權重更新器例項。

子類可以實現此方法以提供基於策略的自定義初始化邏輯。如果實現,將在收集器中初始化權重更新器時呼叫此方法,然後再回退到預設建構函式。

引數:

policy (TensorDictModuleBase) – 要從中建立權重更新器的策略。

返回:

權重更新器的例項,或者如果策略無法建立例項則為 None。

無法用於建立例項的例項。

返回型別:

WeightUpdaterBase | None

classmethod get_model_metadata(model: TensorDictModuleBase) dict[str, tuple[torch.dtype, torch.Size]][原始碼]

從模型中獲取模型元資料。

引數:

model (TensorDictModuleBase) – 要從中獲取元資料的模型。必須是 TransformersWrapper 或同等模型。

返回:

模型元資料。

返回型別:

dict[str, tuple[torch.dtype, torch.Size]]

increment_version()

增加策略版本。

init(model_metadata: dict[str, tuple[torch.dtype, torch.Size]]) None[原始碼]

使用模型元資料初始化更新器並初始化組。

引數:

model_metadata (dict[str, tuple[torch.dtype, torch.Size]]) – 模型元資料,將引數名稱對映到它們的 dtype 和 shape。

property post_hooks: list[collections.abc.Callable[[], None]]

註冊到權重更新器的後置鉤子列表。

push_weights(policy_or_weights: TensorDictModuleBase | TensorDictBase | dict | None = None, worker_ids: torch.device | int | list[int] | list[torch.device] | None = None)

更新策略的權重,或在指定/所有遠端工作程序上更新。

引數:
  • policy_or_weights – 從中獲取權重的來源。可以是: - TensorDictModuleBase:將提取權重的策略模組 - TensorDictBase:包含權重的 TensorDict - dict:一個包含權重的普通字典 - None:將嘗試使用 _get_server_weights() 從伺服器獲取權重。

  • worker_ids – 要更新的工作程序的可選列表。

返回:無。

register_collector(collector: DataCollectorBase)[原始碼]

在更新器中註冊一個收集器。

註冊後,更新器將不再接受另一個收集器。

引數:

collector (DataCollectorBase) – 要註冊的 collector。

register_post_hook(hook: Callable[[], None])

註冊一個後置鉤子,在權重更新後呼叫。

引數:

hook (Callable[[], None]) – 要註冊的後置鉤子。

文件

訪問全面的 PyTorch 開發者文件

檢視文件

教程

為初學者和高階開發者提供深入的教程

檢視教程

資源

查詢開發資源並讓您的問題得到解答

檢視資源