快捷方式

RayLLMCollector

class torchrl.collectors.llm.RayLLMCollector(env: EnvBase | Callable[[], EnvBase], *, policy: Callable[[TensorDictBase], TensorDictBase] | None = None, policy_factory: Callable[[], Callable[[TensorDictBase], TensorDictBase]] | None = None, dialog_turns_per_batch: int, total_dialog_turns: int = - 1, yield_only_last_steps: bool | None = None, yield_completed_trajectories: bool | None = None, postproc: Callable[[TensorDictBase], TensorDictBase] | None = None, async_envs: bool | None = None, replay_buffer: ReplayBuffer | None = None, reset_at_each_iter: bool = False, flatten_data: bool | None = None, weight_updater: WeightUpdaterBase | Callable[[], WeightUpdaterBase] | None = None, ray_init_config: dict[str, Any] | None = None, remote_config: dict[str, Any] | None = None, track_policy_version: bool | PolicyVersion = False, sync_iter: bool = True, verbose: bool = False, num_cpus: int | None = None, num_gpus: int | None = None)[source]

一個輕量級的 Ray 實現的 LLM Collector,可以遠端擴充套件和取樣。

引數:

env (EnvBaseEnvBase 建構函式) – 用於資料收集的環境。

關鍵字引數:
  • policy (Callable[[TensorDictBase], TensorDictBase]) – 用於資料收集的策略。

  • policy_factory (Callable[[], Callable], optional) – 一個返回策略例項的可呼叫物件。這與 policy 引數互斥。

  • dialog_turns_per_batch (int) – 一個關鍵字引數,表示批次中的總元素數量。

  • total_dialog_turns (int) – 一個關鍵字引數,表示收集器在其生命週期內返回的總對話輪次數。

  • yield_only_last_steps (bool, optional) – 是否生成軌跡的每一步,還是隻生成最後(完成)的步驟。

  • yield_completed_trajectories (bool, optional) – 是生成具有給定步數的 rollout 批次,還是生成單個、完整的軌跡。

  • postproc (Callable, optional) – 一個後處理轉換。

  • async_envs (bool, optional) – 如果為 True,環境將非同步執行。

  • replay_buffer (ReplayBuffer, optional) – 如果提供,收集器將不會生成 tensordicts,而是填充緩衝區。

  • reset_at_each_iter (bool, optional) – 如果為 True,環境將在每次迭代時重置。

  • flatten_data (bool, optional) – 如果為 True,收集器將在返回前展平收集到的資料。

  • weight_updater (WeightUpdaterBase建構函式, optional) – WeightUpdaterBase 例項或其子類,負責在遠端推理工作器上更新策略權重。

  • ray_init_config (dict[str, Any], optional) – 傳遞給 ray.init() 的關鍵字引數。

  • remote_config (dict[str, Any], optional) – 傳遞給 cls.as_remote() 的關鍵字引數。

  • num_cpus (int, optional) – Actor 的 CPU 數量。預設為 None (從 remote_config 獲取)。

  • num_gpus (int, optional) – Actor 的 GPU 數量。預設為 None (從 remote_config 獲取)。

  • sync_iter (bool, optional) –

    如果為 True,收集器生成的專案將被同步到本地程序。如果為 False,收集器將在生成之間收集下一批資料。這在透過 start() 方法收集資料時無效。例如

    >>> collector = RayLLMCollector(..., sync_iter=True)
    >>> for data in collector:  # blocking
    ...     # expensive operation - collector is idle
    >>> collector = RayLLMCollector(..., sync_iter=False)
    >>> for data in collector:  # non-blocking
    ...     # expensive operation - collector is collecting data
    

    這在某種程度上等同於使用 MultiSyncDataCollector (sync_iter=True) 或 MultiAsyncDataCollector (sync_iter=False) 。預設為 True

  • verbose (bool, optional) – 如果為 True,收集器將列印進度資訊。預設為 False

classmethod as_remote(remote_config: dict[str, Any] | None = None)

建立一個遠端 ray 類的例項。

引數:
  • cls (Python Class) – 要遠端例項化的類。

  • remote_config (dict) – 為此類保留的 CPU 核心數量。

返回:

一個建立 ray 遠端類例項的函式。

async_shutdown(timeout=None)[source]

非同步關閉收集器。

property dialog_turns_per_batch: int

每個批次的對話輪次數。

get_policy_model()

獲取策略模型。

RayLLMCollector 使用此方法來獲取用於權重更新的遠端 LLM 例項。

返回:

策略模型例項

get_policy_version() str | int | None

獲取當前策略版本。

此方法用於支援 Ray actor 中的遠端呼叫,因為屬性無法透過 Ray 的 RPC 機制直接訪問。

返回:

當前版本號(整數)或 UUID(字串),如果版本跟蹤已停用則為 None。

increment_version()[source]

增加策略版本。

init_updater(*args, **kwargs)[source]

使用自定義引數初始化權重更新器。

此方法呼叫遠端收集器上的 init_updater。

引數:
  • *args – 用於權重更新器初始化的位置引數

  • **kwargs – 用於權重更新器初始化的關鍵字引數

is_initialized() bool

檢查收集器是否已初始化並準備好。

返回:

如果收集器已初始化並準備好收集資料,則為 True。

返回型別:

布林值

iterator() Iterator[TensorDictBase]

迭代 DataCollector。

Yields: 包含軌跡 (塊) 的 TensorDictBase 物件

load_state_dict(state_dict: OrderedDict, **kwargs) None

在環境和策略上載入 state_dict。

引數:

state_dict (OrderedDict) – 包含 “policy_state_dict”"env_state_dict" 欄位的有序字典。

next() None[source]

從收集器獲取下一批資料。

返回:

None,因為資料直接寫入回放緩衝區。

pause()

上下文管理器,如果收集器正在自由執行,則暫停收集器。

property policy_version: str | int | None

策略的當前版本。

返回:

當前版本號(整數)或 UUID(字串),如果版本跟蹤已停用則為 None。

reset(index=None, **kwargs) None

將環境重置到新的初始狀態。

property rollout: Callable[[], TensorDictBase]

返回 rollout 函式。

set_seed(seed: int, static_seed: bool = False) int

設定 DataCollector 中儲存的環境的種子。

引數:
  • seed (int) – 用於環境的種子整數。

  • static_seed (bool, optional) – 如果 True,種子不會遞增。預設為 False

返回:

輸出種子。當 DataCollector 包含多個環境時,這很有用,因為種子會為每個環境遞增。結果種子是最後一個環境的種子。

示例

>>> from torchrl.envs import ParallelEnv
>>> from torchrl.envs.libs.gym import GymEnv
>>> from tensordict.nn import TensorDictModule
>>> from torch import nn
>>> env_fn = lambda: GymEnv("Pendulum-v1")
>>> env_fn_parallel = ParallelEnv(6, env_fn)
>>> policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"])
>>> collector = SyncDataCollector(env_fn_parallel, policy, total_frames=300, frames_per_batch=100)
>>> out_seed = collector.set_seed(1)  # out_seed = 6
shutdown()[source]

關閉收集器。

start()[source]

在後臺執行緒中啟動收集器。

state_dict() OrderedDict

返回資料收集器的本地 state_dict(環境和策略)。

返回:

包含 "policy_state_dict"“env_state_dict” 欄位的有序字典。

property total_dialog_turns

要收集的總對話輪次數。

update_policy_weights_(policy_or_weights: TensorDictBase | TensorDictModuleBase | dict | None = None, *, worker_ids: torch.device | int | list[int] | list[torch.device] | None = None, **kwargs)[source]

在遠端工作器上更新策略權重。

引數:
  • policy_or_weights – 要更新的權重。可以是: - TensorDictModuleBase:一個將提取其權重的策略模組 - TensorDictBase:一個包含權重的 TensorDict - dict:一個包含權重的常規 dict - None:將嘗試使用 _get_server_weights() 從伺服器獲取權重。

  • worker_ids – 要更新的工作器。如果為 None,則更新所有工作器。

property weight_updater: WeightUpdaterBase

權重更新器例項。

我們可以傳遞權重更新器,因為它是無狀態的,因此是可序列化的。

文件

訪問全面的 PyTorch 開發者文件

檢視文件

教程

為初學者和高階開發者提供深入的教程

檢視教程

資源

查詢開發資源並讓您的問題得到解答

檢視資源