torchrl.collectors 包¶
資料收集器在某種程度上等同於 PyTorch 的資料載入器,除了 (1) 它們從非靜態資料來源收集資料,以及 (2) 資料是使用模型(很可能是正在訓練的模型的一個版本)收集的。
TorchRL 的資料收集器接受兩個主要引數:一個環境(或一組環境建構函式)和一個策略。它們將在定義的步數內迭代地執行一個環境步驟和一個策略查詢,然後將收集到的資料堆疊提供給使用者。當環境達到完成狀態和/或達到預定義的步數後,環境將被重置。
由於資料收集是一個潛在的計算密集型過程,因此適當配置執行超引數至關重要。需要考慮的第一個引數是資料收集應該與最佳化步驟序列發生還是並行發生。SyncDataCollector 類將在訓練工作程序上執行資料收集。MultiSyncDataCollector 將工作負載分配給多個工作程序,並彙總將提供給訓練工作程序的結果。最後,MultiaSyncDataCollector 將在多個工作程序上執行資料收集,並提供它能收集到的第一批結果。此執行將連續不斷地發生,並與網路訓練同時進行:這意味著用於資料收集的策略權重可能略滯後於訓練工作程序上的策略配置。因此,儘管此類收集資料的速度可能最快,但其代價是僅適用於可以非同步收集資料的設定(例如,離策略 RL 或課程 RL)。對於遠端執行的 rollout(MultiSyncDataCollector 或 MultiaSyncDataCollector),有必要使用 collector.update_policy_weights_() 或在建構函式中設定 update_at_each_batch=True 來同步遠端策略的權重與訓練工作程序上的權重。
第二個要考慮的引數(在遠端設定中)是資料收集的裝置以及執行環境和策略操作的裝置。例如,在 CPU 上執行的策略可能比在 CUDA 上執行的策略慢。當多個推理工作程序同時執行時,跨可用裝置分派計算工作負載可能會加快收集速度或避免 OOM 錯誤。最後,批次大小和傳遞裝置(即等待傳遞給收集工作程序的資料的儲存裝置)的選擇也可能影響記憶體管理。控制的關鍵引數是 devices,它控制執行裝置(即策略的裝置),以及 storing_device,它控制在 rollout 期間儲存環境和資料的裝置。一個好的經驗法則是通常使用相同的裝置進行儲存和計算,當僅傳遞 devices 引數時,這是預設行為。
除了這些計算引數外,使用者還可以選擇配置以下引數
max_frames_per_traj:在呼叫
env.reset()之後的幀數frames_per_batch:每次迭代收集器提供的幀數
init_random_frames:隨機步數(呼叫
env.rand_step()的步數)reset_at_each_iter:如果為
True,則在每次批次收集後重置環境。split_trajs:如果為
True,則軌跡將被拆分,並以填充的 tensordict 和一個"mask"鍵的形式提供,該鍵將指向表示有效值的布林掩碼。exploration_type:與策略一起使用的探索策略。
reset_when_done:是否在達到完成狀態時重置環境。
收集器和批次大小¶
由於每個收集器都有其組織內部執行的環境的方式,因此根據收集器的具體情況,資料將具有不同的批次大小。下表總結了收集資料時的情況。
SyncDataCollector |
MultiSyncDataCollector (n=B) |
MultiaSyncDataCollector (n=B) |
|||
|---|---|---|---|---|---|
cat_results |
NA |
“stack” |
0 |
-1 |
NA |
單環境 |
[T] |
[B, T] |
[B*(T//B) |
[B*(T//B)] |
[T] |
批處理環境 (n=P) |
[P, T] |
[B, P, T] |
[B * P, T] |
[P, T * B] |
[P, T] |
在所有這些情況下,最後一個維度(T 表示 time)會進行調整,以使批次大小等於傳遞給收集器的 frames_per_batch 引數。
警告
MultiSyncDataCollector 不應與 cat_results=0 一起使用,因為資料將與批處理環境一起沿著批次維度堆疊,或者對於單環境則沿著時間維度堆疊,這在兩者之間進行切換時可能會引起混淆。cat_results="stack" 是一種更好、更一致的與環境互動的方式,因為它會使每個維度保持獨立,並提供更好的配置、收集器類和其他元件之間的可互換性。
而 MultiSyncDataCollector 有一個對應於正在執行的子收集器數量的維度(B),而 MultiaSyncDataCollector 則沒有。這一點很容易理解,因為 MultiaSyncDataCollector 是基於“先到先得”的原則提供資料批次的,而 MultiSyncDataCollector 則在提供資料之前會從每個子收集器收集資料。
收集器和策略副本¶
當將策略傳遞給收集器時,我們可以選擇策略執行的裝置。這可以用於將策略的訓練版本放在一個裝置上,而將推理版本放在另一個裝置上。例如,如果您有兩個 CUDA 裝置,最好在一個裝置上訓練,並在另一個裝置上執行策略進行推理。如果是這種情況,可以使用 update_policy_weights_() 將引數從一個裝置複製到另一個裝置(如果不需要複製,則此方法無效)。
由於目的是避免顯式呼叫 policy.to(policy_device),因此收集器將在例項化時對策略結構進行深度複製,並將引數放在新裝置上(如果需要)。由於並非所有策略都支援深度複製(例如,使用 CUDA 圖或依賴第三方庫的策略),因此我們儘量限制執行深度複製的情況。以下圖表顯示了何時會發生這種情況。
收集器中的策略複製決策樹。¶
分散式環境中的權重同步¶
在分散式和多程序環境中,確保所有策略例項都與最新的訓練權重同步對於保持效能一致至關重要。API 引入了一種靈活且可擴充套件的機制,用於在不同裝置和程序之間更新策略權重,以適應各種部署場景。
使用 WeightUpdaters 傳送和接收模型權重¶
權重同步過程透過一個專用的擴充套件點進行協調:WeightUpdaterBase。這個基類提供了一個結構化的介面來實現自定義權重更新邏輯,允許使用者根據自己的具體需求定製同步過程。
WeightUpdaterBase 負責將策略權重分發給策略或遠端推理工作程序,以及在必要時從伺服器格式化/收集權重。每個收集器(伺服器或工作程序)都應有一個 WeightUpdaterBase 例項來處理與策略的權重同步。即使是最簡單的收集器也使用 VanillaWeightUpdater 例項來更新策略的 state_dict(假設它是一個 Module 例項)。
擴充套件 Updater 類¶
為了適應不同的用例,API 允許使用者擴充套件 updater 類並進行自定義實現。目標是能夠自定義權重同步策略,同時不修改收集器和策略的實現。這種靈活性在涉及複雜網路架構或專用硬體設定的場景中尤其有益。透過實現這些基類中的抽象方法,使用者可以定義如何檢索、轉換和應用權重,確保與他們現有的基礎設施無縫整合。
用於在推理工作程序上更新遠端策略權重的基類。 |
|
|
用於更新本地策略權重的 |
|
用於跨多個程序或裝置同步策略權重的遠端權重更新器。 |
|
使用 Ray 在遠端工作程序之間同步策略權重的遠端權重更新器。 |
|
使用 RPC 在遠端工作程序之間同步策略權重的遠端權重更新器。 |
|
用於在分散式工作程序之間同步策略權重的遠端權重更新器。 |
收集器與回放緩衝區互操作性¶
在需要從回放緩衝區取樣單個轉換的最簡單場景中,幾乎不需要關注收集器的構建方式。在填充儲存之前,將資料展平作為預處理步驟就足夠了。
>>> memory = ReplayBuffer(
... storage=LazyTensorStorage(N),
... transform=lambda data: data.reshape(-1))
>>> for data in collector:
... memory.extend(data)
如果需要收集軌跡切片,推薦的方法是建立一個多維緩衝區,並使用 SliceSampler 取樣器類進行取樣。必須確保傳遞給緩衝區的資料形狀正確,並且 time 和 batch 維度清晰分開。實際上,以下配置將起作用。
>>> # Single environment: no need for a multi-dimensional buffer
>>> memory = ReplayBuffer(
... storage=LazyTensorStorage(N),
... sampler=SliceSampler(num_slices=4, trajectory_key=("collector", "traj_ids"))
... )
>>> collector = SyncDataCollector(env, policy, frames_per_batch=N, total_frames=-1)
>>> for data in collector:
... memory.extend(data)
>>> # Batched environments: a multi-dim buffer is required
>>> memory = ReplayBuffer(
... storage=LazyTensorStorage(N, ndim=2),
... sampler=SliceSampler(num_slices=4, trajectory_key=("collector", "traj_ids"))
... )
>>> env = ParallelEnv(4, make_env)
>>> collector = SyncDataCollector(env, policy, frames_per_batch=N, total_frames=-1)
>>> for data in collector:
... memory.extend(data)
>>> # MultiSyncDataCollector + regular env: behaves like a ParallelEnv if cat_results="stack"
>>> memory = ReplayBuffer(
... storage=LazyTensorStorage(N, ndim=2),
... sampler=SliceSampler(num_slices=4, trajectory_key=("collector", "traj_ids"))
... )
>>> collector = MultiSyncDataCollector([make_env] * 4,
... policy,
... frames_per_batch=N,
... total_frames=-1,
... cat_results="stack")
>>> for data in collector:
... memory.extend(data)
>>> # MultiSyncDataCollector + parallel env: the ndim must be adapted accordingly
>>> memory = ReplayBuffer(
... storage=LazyTensorStorage(N, ndim=3),
... sampler=SliceSampler(num_slices=4, trajectory_key=("collector", "traj_ids"))
... )
>>> collector = MultiSyncDataCollector([ParallelEnv(2, make_env)] * 4,
... policy,
... frames_per_batch=N,
... total_frames=-1,
... cat_results="stack")
>>> for data in collector:
... memory.extend(data)
使用 MultiSyncDataCollector 取樣軌跡的回放緩衝區目前不受完全支援,因為資料批次可能來自任何工作程序,並且在大多數情況下,寫入緩衝區中的連續批次不會來自同一來源(從而中斷了軌跡)。
非同步執行收集器¶
將回放緩衝區傳遞給收集器可以讓我們開始收集,並擺脫收集器的迭代特性。如果你想在後臺執行資料收集器,只需執行 start()。
>>> collector = SyncDataCollector(..., replay_buffer=rb) # pass your replay buffer
>>> collector.start()
>>> # little pause
>>> time.sleep(10)
>>> # Start training
>>> for i in range(optim_steps):
... data = rb.sample() # Sampling from the replay buffer
... # rest of the training loop
單程序收集器(SyncDataCollector)將使用多執行緒執行程序,因此請注意 Python 的 GIL 和相關的多執行緒限制。
另一方面,多程序收集器將允許子程序自己處理緩衝區的填充,從而真正解耦資料收集和訓練。
使用 start() 啟動的資料收集器應使用 async_shutdown() 關閉。
警告
非同步執行收集器可以將收集與訓練解耦,這意味著訓練效能可能因硬體、負載和其他因素而有很大差異(儘管通常預期會提供顯著的速度提升)。請確保您瞭解這可能如何影響您的演算法,以及這是否是合理的做法!(例如,PPO 等 on-policy 演算法不應非同步執行,除非經過適當的基準測試)。
單節點資料收集器¶
資料收集器的基類。 |
|
|
RL 問題的通用資料收集器。 |
|
在單獨的程序中同步執行給定數量的 DataCollectors。 |
|
在單獨的程序中非同步執行給定數量的 DataCollectors。 |
|
在單獨的程序中執行單個 DataCollector。 |
分散式資料收集器¶
TorchRL 提供了一系列分散式資料收集器。這些工具支援多種後端('gloo'、'nccl'、'mpi'(使用 DistributedDataCollector)或 PyTorch RPC(使用 RPCDataCollector))和啟動器('ray'、submitit 或 torch.multiprocessing)。它們可以在同步或非同步模式下,在單節點或跨多個節點上高效使用。
資源:在專用資料夾中查詢這些收集器的示例。
注意
選擇子收集器:所有分散式收集器都支援各種單機收集器。人們可能會想為什麼還要使用 MultiSyncDataCollector 或 ParallelEnv。總的來說,多程序收集器的 IO 開銷比並行環境低,因為並行環境需要在每一步進行通訊。然而,模型規格在反方向上起作用,因為使用並行環境將導致策略(和/或轉換)執行速度更快,因為這些操作將被向量化。
注意
選擇收集器(或並行環境)的裝置:程序間資料共享是透過共享記憶體緩衝區實現的,並行環境和在 CPU 上執行的多程序環境。根據所用機器的功能,這可能比在 GPU 上共享資料(由 CUDA 驅動程式原生支援)慢得令人無法接受。實際上,這意味著在構建並行環境或收集器時使用 device="cpu" 關鍵字引數可能會比在可用時使用 device="cuda" 導致收集速度變慢。
注意
考慮到該庫的許多可選依賴項(例如,Gym、Gymnasium 以及許多其他庫),在多程序/分散式環境中,警告可能會很快變得非常煩人。預設情況下,TorchRL 會在子程序中過濾掉這些警告。如果仍然希望看到這些警告,可以透過設定 torchrl.filter_warnings_subprocess=False 來顯示它們。
|
帶有 torch.distributed 後端的分散式資料收集器。 |
|
基於 RPC 的分散式資料收集器。 |
|
帶有 torch.distributed 後端的分散式同步資料收集器。 |
|
Submitit 的延遲啟動器。 |
|
帶有 Ray 後端的分散式資料收集器。 |
輔助函式¶
|
用於軌跡分離的實用函式。 |