快捷方式

vLLMUpdaterV2

class torchrl.collectors.llm.vLLMUpdaterV2(vllm_engine: RLvLLMEngine)[源]

使用 RLvLLMEngine 介面的簡化 vLLM 權重更新器。

此更新器可與任何實現 RLvLLMEngine 介面的 vLLM 引擎配合使用,自動提取配置並透過引擎自己的方法處理權重更新。

引數:

vllm_engine – 實現 RLvLLMEngine 介面的 vLLM 引擎。

注意

可以透過 torchrl.collectors.llm.vLLMUpdaterv2=True 來建立此類。

all_worker_ids()[源]

返回工作 ID 列表。

property collector: Any | None

接收器的收集器或容器。

如果容器超出範圍或未設定,則返回None

property collectors: list[Any] | None

收集器或接收者容器。

classmethod from_policy(policy: TensorDictModuleBase) WeightUpdaterBase | None

可選的類方法,用於從策略建立權重更新器例項。

子類可以實現此方法以提供基於策略的自定義初始化邏輯。如果實現,將在收集器中初始化權重更新器時呼叫此方法,然後再回退到預設建構函式。

引數:

policy (TensorDictModuleBase) – 要從中建立權重更新器的策略。

返回:

權重更新器的例項,或者如果策略無法建立例項則為 None。

無法用於建立例項的例項。

返回型別:

WeightUpdaterBase | None

classmethod get_model_metadata(model) dict[str, tuple[torch.dtype, torch.Size]][源]

從模型中獲取模型元資料。

引數:

model – 具有 state_dict() 方法的模型(例如,TransformersWrapper)

返回:

引數名稱到 (dtype, shape) 元組的對映

返回型別:

dict

get_tp_size() int[源]

獲取張量並行大小。

increment_version()

增加策略版本。

init(model_metadata: dict[str, tuple[torch.dtype, torch.Size]] | None = None) None[源]

初始化權重更新器。

引數:

model_metadata – 可選的模型元資料。如果未提供,則使用引擎的元資料。

property post_hooks: list[collections.abc.Callable[[], None]]

註冊到權重更新器的後置鉤子列表。

push_weights(weights: Iterator[tuple[str, torch.Tensor]] | TensorDictBase)[源]

將權重推送到 vLLM 引擎。

引數:

weights – (name, tensor) 對的迭代器或 TensorDictBase

push_weights_from_transformers(transformers_model)[源]

從 transformers 模型推送權重。

引數:

transformers_model – Transformers PreTrainedModel 或 TorchRL 包裝器

push_weights_from_transformers_optimized(transformers_model, batch_size=50)[源]

push_weights_from_transformers 的最佳化版本,支援 GPU 預載入。

此方法提供了多項最佳化:1. 在傳輸前將所有權重預載入到 GPU。2. 可選地批處理權重以實現更好的記憶體管理。3. 在可能的情況下使用非阻塞傳輸。

引數:
  • transformers_model – Transformers PreTrainedModel 或 TorchRL 包裝器

  • batch_size – 每次傳輸的權重數量(0 = 不分批)

register_collector(collector)[源]

註冊一個收集器並設定策略版本增量後置鉤子。

引數:

collector – 要註冊的收集器(DataCollectorBase)

register_post_hook(hook: Callable[[], None])

註冊一個後置鉤子,在權重更新後呼叫。

引數:

hook (Callable[[], None]) – 要註冊的後置鉤子。

文件

訪問全面的 PyTorch 開發者文件

檢視文件

教程

為初學者和高階開發者提供深入的教程

檢視教程

資源

查詢開發資源並讓您的問題得到解答

檢視資源