快捷方式

RemoteTransformersWrapper

class torchrl.modules.llm.RemoteTransformersWrapper(model, max_concurrency: int = 16, validate_model: bool = True, actor_name: Optional[str] = None, num_gpus: int = 1, num_cpus: int = 1, **kwargs)[原始碼]

一個遠端 Ray actor 包裝器,用於 TransformersWrapper,提供了一個簡化的介面。

此類將 TransformersWrapper 例項包裝為 Ray actor,允許遠端執行,同時提供了一個不需要顯式 remote()get() 呼叫的乾淨介面。

引數:
  • model (str) – 要包裝的 Hugging Face Transformers 模型。必須是一個字串(模型名稱或路徑),它將被傳遞給 transformers.AutoModelForCausalLM.from_pretrained。Transformers 模型不可序列化,因此僅支援模型名稱/路徑。

  • max_concurrency (int, optional) – 到遠端 actor 的併發呼叫最大數量。預設為 16。

  • validate_model (bool, optional) – 是否驗證模型。預設為 True。

  • num_gpus (int, optional) – 要使用的 GPU 數量。預設為 0。

  • num_cpus (int, optional) – 要使用的 CPU 數量。預設為 0。

  • **kwargs – 所有其他引數將直接傳遞給 TransformersWrapper。

示例

>>> import ray
>>> from torchrl.modules.llm.policies import RemoteTransformersWrapper
>>>
>>> # Initialize Ray if not already done
>>> if not ray.is_initialized():
...     ray.init()
>>>
>>> # Create remote wrapper
>>> remote_wrapper = RemoteTransformersWrapper(
...     model="gpt2",
...     input_mode="history",
...     generate=True,
...     generate_kwargs={"max_new_tokens": 50}
... )
>>>
>>> # Use like a regular wrapper (no remote/get calls needed)
>>> result = remote_wrapper(tensordict_input)
>>> print(result["text"].response)
property batching

批處理是否已啟用。

cleanup_batching()[原始碼]

清理批處理資源。

property collector

與模組關聯的 collector。

property device

用於計算的裝置。

property dist_params_keys

分佈引數的鍵。

property dist_sample_keys

分佈樣本的鍵。

property generate

文字生成是否啟用。

get_batching_state()[原始碼]

獲取當前的批處理狀態。

get_dist(tensordict, **kwargs)[原始碼]

使用可選的掩碼獲取分佈(從 logits/log-probs)。

get_dist_with_prompt_mask(tensordict, **kwargs)[原始碼]

獲取僅包含響應 token(排除提示)的分佈。

get_new_version(**kwargs)[原始碼]

獲取具有更改引數的新版本包裝器。

property in_keys

輸入鍵。

property inplace

是否使用原地操作。

property layout

輸出張量使用的佈局。

log_prob(data, **kwargs)[原始碼]

計算對數機率。

property log_prob_keys

對數機率的鍵。

property log_probs_key

對數機率輸出的鍵。

property masks_key

掩碼輸出的鍵。

property num_samples

要生成的樣本數量。

property out_keys

輸出鍵。

property pad_output

輸出序列是否填充。

property text_key

文字輸出的鍵。

property tokens_key

token 輸出的鍵。

文件

訪問全面的 PyTorch 開發者文件

檢視文件

教程

為初學者和高階開發者提供深入的教程

檢視教程

資源

查詢開發資源並讓您的問題得到解答

檢視資源