vLLMUpdaterV2¶
- class torchrl.collectors.llm.vLLMUpdaterV2(vllm_engine: RLvLLMEngine)[源]¶
使用 RLvLLMEngine 介面的簡化 vLLM 權重更新器。
此更新器可與任何實現 RLvLLMEngine 介面的 vLLM 引擎配合使用,自動提取配置並透過引擎自己的方法處理權重更新。
- 引數:
vllm_engine – 實現 RLvLLMEngine 介面的 vLLM 引擎。
注意
可以透過
torchrl.collectors.llm.vLLMUpdater和 v2=True 來建立此類。- property collector: Any | None¶
接收器的收集器或容器。
如果容器超出範圍或未設定,則返回None。
- property collectors: list[Any] | None¶
收集器或接收者容器。
- classmethod from_policy(policy: TensorDictModuleBase) WeightUpdaterBase | None¶
可選的類方法,用於從策略建立權重更新器例項。
子類可以實現此方法以提供基於策略的自定義初始化邏輯。如果實現,將在收集器中初始化權重更新器時呼叫此方法,然後再回退到預設建構函式。
- 引數:
policy (TensorDictModuleBase) – 要從中建立權重更新器的策略。
- 返回:
- 權重更新器的例項,或者如果策略無法建立例項則為 None。
無法用於建立例項的例項。
- 返回型別:
WeightUpdaterBase | None
- classmethod get_model_metadata(model) dict[str, tuple[torch.dtype, torch.Size]][源]¶
從模型中獲取模型元資料。
- 引數:
model – 具有 state_dict() 方法的模型(例如,TransformersWrapper)
- 返回:
引數名稱到 (dtype, shape) 元組的對映
- 返回型別:
dict
- increment_version()¶
增加策略版本。
- init(model_metadata: dict[str, tuple[torch.dtype, torch.Size]] | None = None) None[源]¶
初始化權重更新器。
- 引數:
model_metadata – 可選的模型元資料。如果未提供,則使用引擎的元資料。
- property post_hooks: list[collections.abc.Callable[[], None]]¶
註冊到權重更新器的後置鉤子列表。
- push_weights(weights: Iterator[tuple[str, torch.Tensor]] | TensorDictBase)[源]¶
將權重推送到 vLLM 引擎。
- 引數:
weights – (name, tensor) 對的迭代器或 TensorDictBase
- push_weights_from_transformers(transformers_model)[源]¶
從 transformers 模型推送權重。
- 引數:
transformers_model – Transformers PreTrainedModel 或 TorchRL 包裝器
- push_weights_from_transformers_optimized(transformers_model, batch_size=50)[源]¶
push_weights_from_transformers 的最佳化版本,支援 GPU 預載入。
此方法提供了多項最佳化:1. 在傳輸前將所有權重預載入到 GPU。2. 可選地批處理權重以實現更好的記憶體管理。3. 在可能的情況下使用非阻塞傳輸。
- 引數:
transformers_model – Transformers PreTrainedModel 或 TorchRL 包裝器
batch_size – 每次傳輸的權重數量(0 = 不分批)
- register_post_hook(hook: Callable[[], None])¶
註冊一個後置鉤子,在權重更新後呼叫。
- 引數:
hook (Callable[[], None]) – 要註冊的後置鉤子。