• 文件 >
  • 開始資料收集和儲存
快捷方式

使用資料收集和儲存入門

作者Vincent Moens

注意

要在 notebook 中執行本教程,請在開頭新增一個安裝單元格,其中包含:

!pip install tensordict
!pip install torchrl
import tempfile

沒有資料就沒有學習。在監督學習中,使用者習慣於使用 DataLoader 等工具將資料整合到其訓練迴圈中。DataLoader 是可迭代物件,它們為您提供將用於訓練模型的資料。

TorchRL 在資料載入問題上的方法類似,儘管它在 RL 庫生態系統中卻出奇地獨特。TorchRL 的資料載入器被稱為 DataCollectors。大多數時候,資料收集並不僅限於原始資料的收集,因為資料需要在臨時緩衝區(或對線上策略演算法的等效結構)中儲存,然後才能被 損失模組 消耗。本教程將探討這兩個類。

資料收集器

這裡討論的主要資料收集器是 SyncDataCollector,這是本檔的重點。從根本上說,收集器是一個簡單的類,負責在環境中執行您的策略、在必要時重置環境,並提供預定義大小的批次。與 環境教程 中演示的 rollout() 方法不同,收集器在連續的資料批次之間不會重置。因此,兩個連續的資料批次可能包含來自同一軌跡的元素。

您需要傳遞給收集器的基本引數是您想要收集的批次大小(frames_per_batch)、迭代器的長度(可能無限)、策略和環境。為簡單起見,我們將在示例中使用一個模擬的隨機策略。

import torch

from torchrl.collectors import SyncDataCollector
from torchrl.envs import GymEnv
from torchrl.envs.utils import RandomPolicy

torch.manual_seed(0)

env = GymEnv("CartPole-v1")
env.set_seed(0)

policy = RandomPolicy(env.action_spec)
collector = SyncDataCollector(env, policy, frames_per_batch=200, total_frames=-1)

我們現在期望我們的收集器將提供大小為 200 的批次,無論收集過程中發生什麼。換句話說,我們這個批次中可能包含多個軌跡!total_frames 表示收集器應該有多長。值為 -1 將產生一個永不結束的收集器。

讓我們迭代收集器,瞭解一下這些資料是什麼樣的

for data in collector:
    print(data)
    break

正如您所見,我們的資料被添加了一些收集器特定的元資料,這些元資料被分組在一個 "collector" 子 Tensordict 中,我們在 環境回放 中沒有看到這些。這對於跟蹤軌跡 ID 非常有用。在下面的列表中,每個專案都標記了對應轉換所屬的軌跡編號。

print(data["collector", "traj_ids"])

資料收集器在編寫最先進的演算法時非常有用,因為效能通常透過特定技術在與環境互動的給定次數內解決問題的能力來衡量(收集器中的 total_frames 引數)。因此,我們示例中的大多數訓練迴圈如下所示:

..code - block::Python

>>> for data in collector:
...     # your algorithm here

回放緩衝區

既然我們已經探討了如何收集資料,我們想知道如何儲存它。在 RL 中,典型的情況是資料被收集、臨時儲存,並在一段時間後根據某種啟發式方法清除:先進先出或其他。典型的虛擬碼如下:

..code - block::Python

>>> for data in collector:
...     storage.store(data)
...     for i in range(n_optim):
...         sample = storage.sample()
...         loss_val = loss_fn(sample)
...         loss_val.backward()
...         optim.step() # etc

TorchRL 中儲存資料的父類被稱為 ReplayBuffer。TorchRL 的回放緩衝區是可組合的:您可以編輯儲存型別、取樣技術、寫入啟發式方法或應用於它們的轉換。我們將在專門的深入教程中介紹這些花哨的功能。通用的回放緩衝區只需要知道要使用哪種儲存。通常,我們推薦使用 TensorStorage 子類,它在大多數情況下都能正常工作。在本教程中,我們將使用 LazyMemmapStorage,它具有兩個優點:首先,“懶惰”,您無需提前顯式告訴它您的資料外觀。其次,它使用 MemoryMappedTensor 作為後端,以高效的方式將資料儲存到磁碟。您唯一需要知道的是您想要的緩衝區大小。

from torchrl.data.replay_buffers import LazyMemmapStorage, ReplayBuffer

buffer_scratch_dir = tempfile.TemporaryDirectory().name

buffer = ReplayBuffer(
    storage=LazyMemmapStorage(max_size=1000, scratch_dir=buffer_scratch_dir)
)

可以透過 add()(單個元素)或 extend()(多個元素)方法來填充緩衝區。使用我們剛剛收集的資料,我們一次性初始化和填充緩衝區。

indices = buffer.extend(data)

我們可以檢查緩衝區現在擁有的元素數量與我們從收集器獲得的數量相同。

assert len(buffer) == collector.frames_per_batch

唯一需要了解的是如何從緩衝區收集資料。自然,這依賴於 sample() 方法。因為我們沒有指定取樣必須無重複進行,所以無法保證從緩衝區收集的樣本是唯一的。

sample = buffer.sample(batch_size=30)
print(sample)

再次,我們的樣本看起來與我們從收集器收集的資料完全一樣!

後續步驟

  • 您可以檢視其他多程序收集器,例如 MultiSyncDataCollectorMultiaSyncDataCollector

  • TorchRL 還提供分散式收集器,如果您有多個節點用於推理。請在 API 參考 中檢視它們。

  • 檢視專門的 回放緩衝區教程 以瞭解構建緩衝區時可用的選項,或者 API 參考,其中詳細介紹了所有功能。回放緩衝區擁有無數的功能,例如多執行緒取樣、優先經驗回放等等……

  • 為了簡單起見,我們省略了回放緩衝區可迭代的能力。您可以自己嘗試:構建一個緩衝區並在建構函式中指定其批次大小,然後嘗試迭代它。這相當於在一個迴圈中呼叫 rb.sample()

由 Sphinx-Gallery 生成的畫廊

文件

訪問全面的 PyTorch 開發者文件

檢視文件

教程

為初學者和高階開發者提供深入的教程

檢視教程

資源

查詢開發資源並讓您的問題得到解答

檢視資源