LLMCollector¶
- class torchrl.collectors.llm.LLMCollector(env: EnvBase | Callable[[], EnvBase], *, policy: Callable[[TensorDictBase], TensorDictBase] | None = None, policy_factory: Callable[[], Callable[[TensorDictBase], TensorDictBase]] | None = None, dialog_turns_per_batch: int | None = None, yield_only_last_steps: bool | None = None, yield_completed_trajectories: bool | None = None, postproc: Callable[[TensorDictBase], TensorDictBase] | None = None, total_dialog_turns: int = - 1, async_envs: bool | None = None, replay_buffer: ReplayBuffer | None = None, reset_at_each_iter: bool = False, flatten_data: bool | None = None, weight_updater: WeightUpdaterBase | Callable[[], WeightUpdaterBase] | None = None, queue: Any | None = None, track_policy_version: bool | PolicyVersion = False, verbose: bool = False)[原始碼]¶
SyncDataCollector 的簡化版本,用於 LLM 推理。
- 引數:
env (EnvBase 或 EnvBase 建構函式) – 用於資料收集的環境。
- 關鍵字引數:
policy (Callable[[TensorDictBase], TensorDictBase]) – 用於資料收集的策略。
policy_factory (Callable[[], Callable], optional) –
一個可呼叫物件,它返回一個策略例項。這與 policy 引數互斥。
注意
policy_factory 在策略無法序列化時非常有用。
dialog_turns_per_batch (int, optional) – 一個僅關鍵字引數,表示批次中的總元素數。除了 yield_completed_trajectories=True 的情況外,始終需要。
total_dialog_turns (int) – 一個僅關鍵字引數,表示收集器在其生命週期內返回的總步數。-1 表示永不結束(直到關閉)。預設為 -1。
yield_completed_trajectories (bool, optional) –
是否生成具有給定步數的滾滾批次(yield_completed_trajectories=False,預設)或單個、已完成的軌跡(yield_completed_trajectories=True)。預設為 False,除非 yield_only_last_steps=True,此時它不能為 False。
警告
如果環境的 done 狀態未正確設定,這可能導致收集器永遠不產生任何資料。
yield_only_last_steps (bool, optional) –
是否生成軌跡的每一步,還是僅生成最後(已完成)的步驟。如果為 True,一次只生成(或寫入緩衝區)一個軌跡。
警告
如果環境的 done 狀態未正確設定,這可能導致收集器永遠不產生任何資料。
postproc (Callable, optional) – 一個後處理轉換,例如
Transform或MultiStep例項。預設為None。async_envs (bool, optional) – 如果為
True,環境將非同步執行。如果環境是AsyncEnvPool例項,則預設為 True。replay_buffer (ReplayBuffer, optional) – 如果提供,collector 將不會產生 tensordicts,而是填充 buffer。預設為
None。reset_at_each_iter (bool, optional) – 如果為
True,環境將在每次迭代時重置。flatten_data (bool, optional) – 如果為
True,收集器將在返回前展平收集到的資料。實際上,這意味著如果使用批次大小為 (B,) 的環境並執行 T 步,flatten_data=True 將呈現形狀為 (B*T,) 的資料,而 flatten_data=False 將不呈現形狀為 (B, T) 的資料。如果提供了 replay_buffer,則預設為 True,否則預設為 False。weight_updater (WeightUpdaterBase 或 建構函式, optional) –
WeightUpdaterBase或其子類的例項,負責在遠端推理工作器上更新策略權重。在SyncDataCollector中通常不使用此引數,因為它在單程序環境中執行。如果更新器需要序列化,請考慮使用建構函式。track_policy_version (bool 或 PolicyVersion, optional) – 如果為
True,收集器將跟蹤策略的版本。這將由PolicyVersion轉換器進行中介,該轉換器將被新增到環境中。或者,也可以傳遞一個PolicyVersion例項,用於跟蹤策略版本。預設為 False。verbose (bool, optional) – 如果為
True,收集器將列印進度資訊。預設為 False。
示例
>>> import vllm >>> from torchrl.modules import vLLMWrapper >>> from pytorch.rl.test.mocking_classes import DummyStrDataLoader >>> from torchrl.envs import LLMEnv >>> llm_model = vllm.LLM("gpt2") >>> tokenizer = llm_model.get_tokenizer() >>> tokenizer.pad_token = tokenizer.eos_token >>> policy = vLLMWrapper(llm_model) >>> dataloader = DummyStrDataLoader(1) >>> env = LLMEnv.from_dataloader( ... dataloader=dataloader, ... tokenizer=tokenizer, ... from_text=True, ... batch_size=1, ... group_repeats=True, ... ) >>> collector = LLMCollector( ... env=env, ... policy_factory=lambda: policy, ... dialog_turns_per_batch=env.batch_size[0], ... total_dialog_turns=3, ... ) >>> for i, data in enumerate(collector): ... if i == 2: ... print(data) ... break LazyStackedTensorDict( fields={ attention_mask: Tensor(shape=torch.Size([1, 1, 22]), device=cpu, dtype=torch.int64, is_shared=False), collector: LazyStackedTensorDict( fields={ traj_ids: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.int64, is_shared=False)}, exclusive_fields={ }, batch_size=torch.Size([1, 1]), device=None, is_shared=False, stack_dim=1), done: Tensor(shape=torch.Size([1, 1, 1]), device=cpu, dtype=torch.bool, is_shared=False), terminated: Tensor(shape=torch.Size([1, 1, 1]), device=cpu, dtype=torch.bool, is_shared=False), text: NonTensorStack( [['plsgqejeyd']], batch_size=torch.Size([1, 1]), device=None), text_response: NonTensorStack( [['ec.n.n.n.tjbjz3perwhz']], batch_size=torch.Size([1, 1]), device=None), tokens: Tensor(shape=torch.Size([1, 1, 22]), device=cpu, dtype=torch.int64, is_shared=False), tokens_response: Tensor(shape=torch.Size([1, 1, 16]), device=cpu, dtype=torch.int64, is_shared=False)}, exclusive_fields={ }, batch_size=torch.Size([1, 1]), device=None, is_shared=False, stack_dim=1) >>> del collector
- classmethod as_remote(remote_config: dict[str, Any] | None = None)¶
建立一個遠端 ray 類的例項。
- 引數:
cls (Python Class) – 要遠端例項化的類。
remote_config (dict) – 為此類保留的 CPU 核心數量。
- 返回:
一個建立 ray 遠端類例項的函式。
- async_shutdown(timeout: float | None = None, close_env: bool = True) None¶
結束 ray.init() 在非同步執行期間啟動的程序。
- property dialog_turns_per_batch: int¶
到 frames_per_batch 的別名。
- get_policy_version() str | int | None[原始碼]¶
獲取當前策略版本。
此方法是為了支援 Ray actor 中的遠端呼叫而存在的,因為屬性不能直接透過 Ray 的 RPC 機制訪問。
- 返回:
當前版本號(int 或 str UUID),或在停用版本跟蹤時為 None。
- init_updater(*args, **kwargs)¶
使用自定義引數初始化權重更新器。
此方法將引數傳遞給權重更新器的 init 方法。如果未設定權重更新器,則此方法無效。
- 引數:
*args – 用於權重更新器初始化的位置引數
**kwargs – 用於權重更新器初始化的關鍵字引數
- iterator() Iterator[TensorDictBase]¶
迭代 DataCollector。
Yields: 包含軌跡 (塊) 的 TensorDictBase 物件
- load_state_dict(state_dict: OrderedDict, **kwargs) None¶
在環境和策略上載入 state_dict。
- 引數:
state_dict (OrderedDict) – 包含 “policy_state_dict” 和
"env_state_dict"欄位的有序字典。
- pause()¶
上下文管理器,如果收集器正在自由執行,則暫停收集器。
- property policy_version: str | int | None¶
當前策略版本。
- reset(index=None, **kwargs) None¶
將環境重置到新的初始狀態。
- property rollout: Callable[[], TensorDictBase]¶
使用提供的策略在環境中計算一次滾滾。
- 返回:
包含已計算滾滾的 TensorDictBase。
- set_seed(seed: int, static_seed: bool = False) int¶
設定 DataCollector 中儲存的環境的種子。
- 引數:
seed (int) – 用於環境的種子整數。
static_seed (bool, optional) – 如果
True,種子不會遞增。預設為 False
- 返回:
輸出種子。當 DataCollector 包含多個環境時,這很有用,因為種子會為每個環境遞增。結果種子是最後一個環境的種子。
示例
>>> from torchrl.envs import ParallelEnv >>> from torchrl.envs.libs.gym import GymEnv >>> from tensordict.nn import TensorDictModule >>> from torch import nn >>> env_fn = lambda: GymEnv("Pendulum-v1") >>> env_fn_parallel = ParallelEnv(6, env_fn) >>> policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"]) >>> collector = SyncDataCollector(env_fn_parallel, policy, total_frames=300, frames_per_batch=100) >>> out_seed = collector.set_seed(1) # out_seed = 6
- shutdown(timeout: float | None = None, close_env: bool = True) None¶
關閉所有工作器和/或關閉本地環境。
- 引數:
timeout (float, optional) – 關閉工作器之間管道的超時時間。對此類無效。
close_env (bool, optional) – 是否關閉環境。預設為 True。
- start()¶
在單獨的執行緒中啟動收集器以進行非同步資料收集。
收集到的資料儲存在提供的回放緩衝區中。當您想將資料收集與訓練解耦時,此方法非常有用,它允許您的訓練迴圈獨立於資料收集過程執行。
- 丟擲:
RuntimeError – 如果在收集器初始化期間未定義回放緩衝區。
示例
>>> import time >>> from functools import partial >>> >>> import tqdm >>> >>> from torchrl.collectors import SyncDataCollector, RandomPolicy >>> from torchrl.data import LazyTensorStorage, ReplayBuffer >>> from torchrl.envs import GymEnv, set_gym_backend >>> import ale_py >>> >>> # Set the gym backend to gymnasium >>> set_gym_backend("gymnasium").set() >>> >>> if __name__ == "__main__": ... # Create a random policy for the Pong environment ... env = GymEnv("ALE/Pong-v5") ... policy = RandomPolicy(env.action_spec) ... ... # Initialize a shared replay buffer ... rb = ReplayBuffer(storage=LazyTensorStorage(1000), shared=True) ... ... # Create a synchronous data collector ... collector = SyncDataCollector( ... env, ... policy=policy, ... replay_buffer=rb, ... frames_per_batch=256, ... total_frames=-1, ... ) ... ... # Progress bar to track the number of collected frames ... pbar = tqdm.tqdm(total=100_000) ... ... # Start the collector asynchronously ... collector.start() ... ... # Track the write count of the replay buffer ... prec_wc = 0 ... while True: ... wc = rb.write_count ... c = wc - prec_wc ... prec_wc = wc ... ... # Update the progress bar ... pbar.update(c) ... pbar.set_description(f"Write Count: {rb.write_count}") ... ... # Check the write count every 0.5 seconds ... time.sleep(0.5) ... ... # Stop when the desired number of frames is reached ... if rb.write_count . 100_000: ... break ... ... # Shut down the collector ... collector.async_shutdown()
- state_dict() OrderedDict¶
返回資料收集器的本地 state_dict(環境和策略)。
- 返回:
包含
"policy_state_dict"和 “env_state_dict” 欄位的有序字典。
- update_policy_weights_(policy_or_weights: TensorDictBase | TensorDictModuleBase | dict | None = None, *, worker_ids: int | list[int] | torch.device | list[torch.device] | None = None, **kwargs) None¶
更新資料收集器的策略權重,支援本地和遠端執行上下文。
此方法確保資料收集器使用的策略權重與最新的訓練權重同步。它支援本地和遠端權重更新,具體取決於資料收集器的配置。本地(下載)更新在遠端(上傳)更新之前執行,以便可以將權重從伺服器傳輸到子工作器。
- 引數:
policy_or_weights (TensorDictBase | TensorDictModuleBase | dict | None) – 要更新的權重。可以是: - TensorDictModuleBase:將提取其權重的策略模組 - TensorDictBase:包含權重的 TensorDict - dict:包含權重的常規 dict - None:將嘗試使用 _get_server_weights() 從伺服器獲取權重
worker_ids (int | List[int] | torch.device | List[torch.device] | None, optional) – 需要更新的工作器的識別符號。當收集器關聯多個工作器時,此項很重要。
- 丟擲:
TypeError – 如果提供了 worker_ids 但未配置 weight_updater。
注意
使用者應擴充套件 WeightUpdaterBase 類來定製特定用例的權重更新邏輯。不應覆蓋此方法。
另請參閱
LocalWeightsUpdaterBase和RemoteWeightsUpdaterBase()。