RemoteTransformersWrapper¶
- class torchrl.modules.llm.RemoteTransformersWrapper(model, max_concurrency: int = 16, validate_model: bool = True, actor_name: Optional[str] = None, num_gpus: int = 1, num_cpus: int = 1, **kwargs)[原始碼]¶
一個遠端 Ray actor 包裝器,用於 TransformersWrapper,提供了一個簡化的介面。
此類將 TransformersWrapper 例項包裝為 Ray actor,允許遠端執行,同時提供了一個不需要顯式 remote() 和 get() 呼叫的乾淨介面。
- 引數:
model (str) – 要包裝的 Hugging Face Transformers 模型。必須是一個字串(模型名稱或路徑),它將被傳遞給 transformers.AutoModelForCausalLM.from_pretrained。Transformers 模型不可序列化,因此僅支援模型名稱/路徑。
max_concurrency (int, optional) – 到遠端 actor 的併發呼叫最大數量。預設為 16。
validate_model (bool, optional) – 是否驗證模型。預設為 True。
num_gpus (int, optional) – 要使用的 GPU 數量。預設為 0。
num_cpus (int, optional) – 要使用的 CPU 數量。預設為 0。
**kwargs – 所有其他引數將直接傳遞給 TransformersWrapper。
示例
>>> import ray >>> from torchrl.modules.llm.policies import RemoteTransformersWrapper >>> >>> # Initialize Ray if not already done >>> if not ray.is_initialized(): ... ray.init() >>> >>> # Create remote wrapper >>> remote_wrapper = RemoteTransformersWrapper( ... model="gpt2", ... input_mode="history", ... generate=True, ... generate_kwargs={"max_new_tokens": 50} ... ) >>> >>> # Use like a regular wrapper (no remote/get calls needed) >>> result = remote_wrapper(tensordict_input) >>> print(result["text"].response)
- property batching¶
批處理是否已啟用。
- property collector¶
與模組關聯的 collector。
- property device¶
用於計算的裝置。
- property dist_params_keys¶
分佈引數的鍵。
- property dist_sample_keys¶
分佈樣本的鍵。
- property generate¶
文字生成是否啟用。
- property in_keys¶
輸入鍵。
- property inplace¶
是否使用原地操作。
- property layout¶
輸出張量使用的佈局。
- property log_prob_keys¶
對數機率的鍵。
- property log_probs_key¶
對數機率輸出的鍵。
- property masks_key¶
掩碼輸出的鍵。
- property num_samples¶
要生成的樣本數量。
- property out_keys¶
輸出鍵。
- property pad_output¶
輸出序列是否填充。
- property text_key¶
文字輸出的鍵。
- property tokens_key¶
token 輸出的鍵。