• 文件 >
  • 使用回放緩衝區
快捷方式

使用回放緩衝區

作者Vincent Moens

回放緩衝區是任何強化學習或控制演算法的核心組成部分。監督學習方法通常的特點是訓練迴圈,其中資料從靜態資料集中隨機抽取,並依次輸入模型和損失函式。在強化學習中,情況通常略有不同:資料是透過模型收集的,然後臨時儲存在動態結構(經驗回放緩衝區)中,該結構作為損失模組的資料集。

與往常一樣,緩衝區使用的上下文會極大地影響其構建方式:有些人可能希望儲存軌跡,而另一些人則希望儲存單個轉換。在某些情況下,特定的取樣策略可能更可取:某些項可能比其他項具有更高的優先順序,或者以替換或不替換的方式進行取樣可能很重要。計算因素也可能發揮作用,例如緩衝區的大小可能超過可用的 RAM 儲存。

因此,TorchRL 的回放緩衝區是完全可組合的:雖然它們“開箱即用”,需要最少的精力來構建,但它們也支援許多自定義,例如儲存型別、取樣策略或資料轉換。

在本教程中,您將學習

基礎:構建一個標準的replay buffer

TorchRL 的回放緩衝區設計旨在優先考慮模組化、可組合性、效率和簡潔性。例如,建立一個基本的回放緩衝區是一個簡單的過程,如下面的示例所示

import gc

import tempfile

from torchrl.data import ReplayBuffer

buffer = ReplayBuffer()

預設情況下,此回放緩衝區的尺寸為 1000。讓我們透過使用 extend() 方法填充我們的緩衝區來檢查這一點

print("length before adding elements:", len(buffer))

buffer.extend(range(2000))

print("length after adding elements:", len(buffer))

我們使用了 extend() 方法,該方法設計用於一次新增多個項。如果傳遞給 extend 的物件具有多個維度,則其第一個維度將被視為在緩衝區中拆分為單獨的元素。

這本質上意味著,在將多維張量或 tensordict 新增到緩衝區時,緩衝區在計算其記憶體中的元素數量時只考慮第一個維度。如果傳遞的物件不可迭代,則會丟擲異常。

要一次新增一個專案,應改用 add() 方法。

自定義儲存

我們看到緩衝區被限制為我們傳遞給它的前 1000 個元素。要更改大小,我們需要自定義我們的儲存。

TorchRL 提供三種類型的儲存

  • ListStorage 將元素獨立地儲存在列表中。它支援任何資料型別,但這種靈活性以效率為代價;

  • LazyTensorStorage 將張量資料結構連續儲存。它自然地與 TensorDict(或 tensorclass)物件配合使用。儲存是按張量連續儲存的,這意味著取樣效率將高於使用列表時,但隱含的限制是傳遞給它的任何資料都必須具有與用於例項化緩衝區的第一個批次資料相同的基本屬性(如形狀和 dtype)。傳遞不符合此要求的資料將引發異常或導致某些未定義行為。

  • LazyMemmapStorage 的工作方式類似於 LazyTensorStorage,因為它也是惰性的(即,它期望第一個批次資料用於例項化),並且它要求每個儲存的批次具有匹配的形狀和 dtype 的資料。這種儲存的獨特之處在於它指向磁碟檔案(或使用檔案系統儲存),這意味著它可以支援非常大的資料集,同時仍然以連續的方式訪問資料。

讓我們看看如何使用這些儲存中的每一種

from torchrl.data import LazyMemmapStorage, LazyTensorStorage, ListStorage

# We define the maximum size of the buffer
size = 100

帶有列表儲存的緩衝區可以儲存任何型別的資料(但我們必須更改 collate_fn,因為預設值需要數值資料)

buffer_list = ReplayBuffer(storage=ListStorage(size), collate_fn=lambda x: x)
buffer_list.extend(["a", 0, "b"])
print(buffer_list.sample(3))

由於 ListStorage 的假設最少,因此它是 TorchRL 中的預設儲存。

一個 LazyTensorStorage 可以連續儲存資料。在處理中等大小的複雜但不改變的資料結構時,這應該是首選選項。

buffer_lazytensor = ReplayBuffer(storage=LazyTensorStorage(size))

讓我們建立一個包含 2 個儲存張量的大小為 torch.Size([3]) 的資料批次。

import torch
from tensordict import TensorDict

data = TensorDict(
    {
        "a": torch.arange(12).view(3, 4),
        ("b", "c"): torch.arange(15).view(3, 5),
    },
    batch_size=[3],
)
print(data)

第一次呼叫 extend() 將例項化儲存。資料的第一個維度被解綁為單獨的資料點。

buffer_lazytensor.extend(data)
print(f"The buffer has {len(buffer_lazytensor)} elements")

讓我們從緩衝區取樣並列印資料。

sample = buffer_lazytensor.sample(5)
print("samples", sample["a"], sample["b", "c"])

一個 LazyMemmapStorage 以相同的方式建立。我們也可以自定義磁碟上的儲存位置。

with tempfile.TemporaryDirectory() as tempdir:
    buffer_lazymemmap = ReplayBuffer(
        storage=LazyMemmapStorage(size, scratch_dir=tempdir)
    )
    buffer_lazymemmap.extend(data)
    print(f"The buffer has {len(buffer_lazymemmap)} elements")
    print(
        "the 'a' tensor is stored in", buffer_lazymemmap._storage._storage["a"].filename
    )
    print(
        "the ('b', 'c') tensor is stored in",
        buffer_lazymemmap._storage._storage["b", "c"].filename,
    )
    sample = buffer_lazytensor.sample(5)
    print("samples: a=", sample["a"], "\n('b', 'c'):", sample["b", "c"])
    del buffer_lazymemmap

與 TensorDict 整合

張量位置遵循與包含它們的 TensorDict 相同的結構:這使得在訓練期間輕鬆儲存和載入緩衝區。

要充分發揮 TensorDict 作為資料載體的潛力,可以使用 TensorDictReplayBuffer 類。它的主要優點之一是它能夠處理取樣資料的組織,以及可能需要的任何附加資訊(例如取樣索引)。

它可以以與標準 ReplayBuffer 相同的方式構建,並且通常可以互換使用。

from torchrl.data import TensorDictReplayBuffer

with tempfile.TemporaryDirectory() as tempdir:
    buffer_lazymemmap = TensorDictReplayBuffer(
        storage=LazyMemmapStorage(size, scratch_dir=tempdir), batch_size=12
    )
    buffer_lazymemmap.extend(data)
    print(f"The buffer has {len(buffer_lazymemmap)} elements")
    sample = buffer_lazymemmap.sample()
    print("sample:", sample)
    del buffer_lazymemmap

我們的取樣現在有一個額外的 "index" 鍵,它表示取樣了哪些索引。讓我們看看這些索引。

print(sample["index"])

與 tensorclass 整合

ReplayBuffer 類及其子類也原生支援 tensorclass 類,這些類可以方便地用於更顯式地編碼資料集。

from tensordict import tensorclass


@tensorclass
class MyData:
    images: torch.Tensor
    labels: torch.Tensor


data = MyData(
    images=torch.randint(
        255,
        (10, 64, 64, 3),
    ),
    labels=torch.randint(100, (10,)),
    batch_size=[10],
)

buffer_lazy = ReplayBuffer(storage=LazyTensorStorage(size), batch_size=12)
buffer_lazy.extend(data)
print(f"The buffer has {len(buffer_lazy)} elements")
sample = buffer_lazy.sample()
print("sample:", sample)

正如預期的那樣,資料具有正確的類和形狀!

與其它張量結構(PyTrees)整合

TorchRL 的回放緩衝區也支援任何 pytree 資料結構。PyTree 是一個任意深度的巢狀結構,由字典、列表和/或元組構成,其葉子是張量。這意味著我們可以將任何此類樹狀結構儲存在連續記憶體中!可以使用各種儲存:TensorStorageLazyMemmapStorageLazyTensorStorage 都接受此類資料。

這是一個關於此功能如何工作的簡短演示。

from torch.utils._pytree import tree_map

我們在 RAM 上構建我們的回放緩衝區。

rb = ReplayBuffer(storage=LazyTensorStorage(size))
data = {
    "a": torch.randn(3),
    "b": {"c": (torch.zeros(2), [torch.ones(1)])},
    30: -torch.ones(()),  # non-string keys also work
}
rb.add(data)

# The sample has a similar structure to the data (with a leading dimension of 10 for each tensor)
sample = rb.sample(10)

使用 pytrees,任何可呼叫物件都可以用作轉換。

def transform(x):
    # Zeros all the data in the pytree
    return tree_map(lambda y: y * 0, x)


rb.append_transform(transform)
sample = rb.sample(batch_size=12)

讓我們檢查一下我們的轉換是否起作用。

def assert0(x):
    assert (x == 0).all()


tree_map(assert0, sample)

從緩衝區取樣和迭代

回放緩衝區支援多種取樣策略。

  • 如果批次大小是固定的,並且可以在構造時定義,則可以將其作為關鍵字引數傳遞給緩衝區;

  • 在批次大小固定的情況下,可以迭代回放緩衝區來收集樣本;

  • 如果批次大小是動態的,則可以在執行時將其傳遞給 sample 方法。

取樣可以使用多執行緒完成,但這與最後一個選項不相容(因為它需要緩衝區提前知道下一個批次的大小)。

讓我們看幾個例子。

固定批次大小

如果在構造過程中傳遞了批次大小,則在取樣時應省略它。

data = MyData(
    images=torch.randint(
        255,
        (200, 64, 64, 3),
    ),
    labels=torch.randint(100, (200,)),
    batch_size=[200],
)

buffer_lazy = ReplayBuffer(storage=LazyTensorStorage(size), batch_size=128)
buffer_lazy.extend(data)
buffer_lazy.sample()

此資料批次的大小是我們想要的(128)。

要啟用多執行緒取樣,只需在構造過程中將一個正整數傳遞給 prefetch 關鍵字引數。這應該會大大加快取樣速度,尤其是在取樣耗時時(例如,在使用優先採樣器時)。

buffer_lazy = ReplayBuffer(
    storage=LazyTensorStorage(size), batch_size=128, prefetch=10
)  # creates a queue of 10 elements to be prefetched in the background
buffer_lazy.extend(data)
print(buffer_lazy.sample())

迭代固定批次大小的緩衝區

只要預定義了批次大小,我們也可以像使用常規資料載入器一樣迭代緩衝區。

for i, data in enumerate(buffer_lazy):
    if i == 3:
        print(data)
        break

del buffer_lazy

由於我們的取樣技術是完全隨機的並且允許替換,因此所討論的迭代器是無限的。但是,我們可以改用 SamplerWithoutReplacement,它會將我們的緩衝區轉換為有限迭代器。

from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement

buffer_lazy = ReplayBuffer(
    storage=LazyTensorStorage(size), batch_size=32, sampler=SamplerWithoutReplacement()
)

我們建立一個足夠大的資料以獲取幾個樣本。

data = TensorDict(
    {
        "a": torch.arange(64).view(16, 4),
        ("b", "c"): torch.arange(128).view(16, 8),
    },
    batch_size=[16],
)

buffer_lazy.extend(data)
for _i, _ in enumerate(buffer_lazy):
    continue
print(f"A total of {_i+1} batches have been collected")

del buffer_lazy

動態批次大小

與我們之前看到的相反,可以省略 batch_size 關鍵字引數,並直接將其傳遞給 sample 方法。

buffer_lazy = ReplayBuffer(
    storage=LazyTensorStorage(size), sampler=SamplerWithoutReplacement()
)
buffer_lazy.extend(data)
print("sampling 3 elements:", buffer_lazy.sample(3))
print("sampling 5 elements:", buffer_lazy.sample(5))

del buffer_lazy

優先回放緩衝區

TorchRL 還提供了 優先回放緩衝區 的介面。這個緩衝區類根據透過資料傳遞的優先順序訊號對資料進行取樣。

雖然此工具相容非 tensordict 資料,但我們鼓勵使用 TensorDict,因為它能夠輕鬆地在緩衝區內外傳遞元資料。

讓我們首先看看在通用情況下如何構建一個優先回放緩衝區。必須手動設定 \(\alpha\)\(\beta\) 超引數。

from torchrl.data.replay_buffers.samplers import PrioritizedSampler

size = 100

rb = ReplayBuffer(
    storage=ListStorage(size),
    sampler=PrioritizedSampler(max_capacity=size, alpha=0.8, beta=1.1),
    collate_fn=lambda x: x,
)

擴充套件回放緩衝區會返回項的索引,我們稍後需要它們來更新優先順序。

indices = rb.extend([1, "foo", None])

取樣器期望每個元素都有一個優先順序。新增到緩衝區時,優先順序設定為預設值 1。一旦計算出優先順序(通常透過損失),就必須在緩衝區中更新它。

這是透過 update_priority() 方法完成的,該方法需要索引和優先順序。我們將資料集中的第二個樣本的人為高優先順序分配給它,以觀察它對取樣效果的影響。

rb.update_priority(index=indices, priority=torch.tensor([0, 1_000, 0.1]))

我們觀察到從緩衝區取樣主要返回第二個樣本("foo")。

sample, info = rb.sample(10, return_info=True)
print(sample)

info 包含項的相對權重以及索引。

print(info)

我們看到使用優先回放緩衝區比使用常規緩衝區需要一系列額外的訓練迴圈步驟。

  • 在收集資料並擴充套件緩衝區後,必須更新項的優先順序;

  • 在計算損失並從損失中獲得“優先順序訊號”後,我們必須再次更新緩衝區中項的優先順序。這需要我們跟蹤索引。

這極大地阻礙了緩衝區的可重用性:如果有人要編寫一個可以建立優先緩衝區和常規緩衝區的訓練指令碼,她必須新增大量的控制流,以確保在僅使用優先緩衝區的情況下,在適當的位置呼叫適當的方法。

讓我們看看如何使用 TensorDict 來改進這一點。我們看到 TensorDictReplayBuffer 返回的資料中添加了其相對儲存索引。我們沒有提到的一項功能是,此類還確保在擴充套件時自動將優先順序訊號解析到優先採樣器。

這些功能的結合在幾個方面簡化了事情:- 擴充套件緩衝區時,優先順序訊號將自動

如果存在,則會解析,並且優先順序將被準確分配;

  • 索引將儲存在取樣到的 tensordicts 中,便於在計算損失後更新優先順序。

  • 在計算損失時,優先順序訊號將註冊到傳遞給損失模組的 tensordict 中,從而可以輕鬆地更新權重。

    ..code - block::Python

    >>> data = replay_buffer.sample()
    >>> loss_val = loss_module(data)
    >>> replay_buffer.update_tensordict_priority(data)
    

以下程式碼說明了這些概念。我們構建一個帶有優先採樣器的回放緩衝區,並在建構函式中指示應從中獲取優先順序訊號的條目。

rb = TensorDictReplayBuffer(
    storage=ListStorage(size),
    sampler=PrioritizedSampler(size, alpha=0.8, beta=1.1),
    priority_key="td_error",
    batch_size=1024,
)

讓我們選擇一個與儲存索引成比例的優先順序訊號。

data["td_error"] = torch.arange(data.numel())

rb.extend(data)

sample = rb.sample()

較高的索引應該更頻繁地出現。

from matplotlib import pyplot as plt

fig = plt.hist(sample["index"].numpy())
plt.show()

在處理完我們的樣本後,我們使用 torchrl.data.TensorDictReplayBuffer.update_tensordict_priority() 方法更新優先順序鍵。為了展示其工作原理,讓我們反轉取樣項的優先順序。

sample = rb.sample()
sample["td_error"] = data.numel() - sample["index"]
rb.update_tensordict_priority(sample)

現在,較高的索引應該不太頻繁地出現。

sample = rb.sample()

fig = plt.hist(sample["index"].numpy())
plt.show()

使用轉換

儲存在回放緩衝區中的資料可能未準備好呈現給損失模組。在某些情況下,收集器生成的資料可能過於龐大,無法按原樣儲存。例如,將影像從 uint8 轉換為浮點張量,或在使用決策轉換器時連線連續幀。

透過將適當的轉換附加到緩衝區,可以進出緩衝區處理資料。以下是一些示例。

儲存原始影像

uint8 型別張量在記憶體佔用方面遠小於我們通常輸入模型的浮點張量。因此,儲存原始影像可能很有用。以下指令碼展示瞭如何構建一個只返回原始影像但使用轉換後的影像進行推理的收集器,以及如何在回放緩衝區中回收這些轉換。

from torchrl.collectors import SyncDataCollector
from torchrl.envs.libs.gym import GymEnv
from torchrl.envs.transforms import (
    Compose,
    GrayScale,
    Resize,
    ToTensorImage,
    TransformedEnv,
)
from torchrl.envs.utils import RandomPolicy

env = TransformedEnv(
    GymEnv("CartPole-v1", from_pixels=True),
    Compose(
        ToTensorImage(in_keys=["pixels"], out_keys=["pixels_trsf"]),
        Resize(in_keys=["pixels_trsf"], w=64, h=64),
        GrayScale(in_keys=["pixels_trsf"]),
    ),
)

讓我們看看一個 rollouts。

print(env.rollout(3))

我們剛剛建立了一個產生畫素的環境。這些影像經過處理以輸入策略。我們希望儲存原始影像,而不是它們的轉換。為此,我們將向收集器附加一個轉換,以選擇我們希望出現的鍵。

from torchrl.envs.transforms import ExcludeTransform

collector = SyncDataCollector(
    env,
    RandomPolicy(env.action_spec),
    frames_per_batch=10,
    total_frames=1000,
    postproc=ExcludeTransform("pixels_trsf", ("next", "pixels_trsf"), "collector"),
)

讓我們看看一個數據批次,並控制 "pixels_trsf" 鍵已被丟棄。

for data in collector:
    print(data)
    break

collector.shutdown()

我們使用與環境相同的轉換建立了一個回放緩衝區。但是,有一個細節需要解決:在沒有環境的情況下使用的轉換對資料結構是不可知的。將轉換附加到環境時,"next" 巢狀 tensordict 中的資料首先被轉換,然後在 rollouts 執行期間複製到根目錄。使用靜態資料時,情況並非如此。但是,我們的資料帶有巢狀的“next”tensordict,如果不是明確指示它來處理,我們的轉換將會忽略它。我們手動將這些鍵新增到轉換中。

t = Compose(
    ToTensorImage(
        in_keys=["pixels", ("next", "pixels")],
        out_keys=["pixels_trsf", ("next", "pixels_trsf")],
    ),
    Resize(in_keys=["pixels_trsf", ("next", "pixels_trsf")], w=64, h=64),
    GrayScale(in_keys=["pixels_trsf", ("next", "pixels_trsf")]),
)
rb = TensorDictReplayBuffer(storage=LazyTensorStorage(1000), transform=t, batch_size=16)
rb.extend(data)

我們可以檢查 sample 方法是否會看到轉換後的影像重新出現。

print(rb.sample())

更復雜的示例:使用 CatFrames

CatFrames 轉換器透過時間展開觀測值,建立 n 步過去的事件記憶,使模型能夠考慮過去的事件(在 POMDP 或迴圈策略(如決策轉換器)的情況下)。儲存這些連線的幀可能會消耗大量記憶體。當 n 步視窗在訓練和推理期間需要不同(通常更長)時,這也會有問題。我們透過在兩個階段分別執行 CatFrames 轉換來解決這個問題。

from torchrl.envs import CatFrames, UnsqueezeTransform

我們為返回基於畫素的觀測值的環境建立了一組標準的轉換。

env = TransformedEnv(
    GymEnv("CartPole-v1", from_pixels=True),
    Compose(
        ToTensorImage(in_keys=["pixels"], out_keys=["pixels_trsf"]),
        Resize(in_keys=["pixels_trsf"], w=64, h=64),
        GrayScale(in_keys=["pixels_trsf"]),
        UnsqueezeTransform(-4, in_keys=["pixels_trsf"]),
        CatFrames(dim=-4, N=4, in_keys=["pixels_trsf"]),
    ),
)
collector = SyncDataCollector(
    env,
    RandomPolicy(env.action_spec),
    frames_per_batch=10,
    total_frames=1000,
)
for data in collector:
    print(data)
    break

collector.shutdown()

緩衝區轉換看起來與環境轉換非常相似,但具有額外的 ("next", ...) 鍵,如前所述。

t = Compose(
    ToTensorImage(
        in_keys=["pixels", ("next", "pixels")],
        out_keys=["pixels_trsf", ("next", "pixels_trsf")],
    ),
    Resize(in_keys=["pixels_trsf", ("next", "pixels_trsf")], w=64, h=64),
    GrayScale(in_keys=["pixels_trsf", ("next", "pixels_trsf")]),
    UnsqueezeTransform(-4, in_keys=["pixels_trsf", ("next", "pixels_trsf")]),
    CatFrames(dim=-4, N=4, in_keys=["pixels_trsf", ("next", "pixels_trsf")]),
)
rb = TensorDictReplayBuffer(storage=LazyTensorStorage(size), transform=t, batch_size=16)
data_exclude = data.exclude("pixels_trsf", ("next", "pixels_trsf"))
rb.add(data_exclude)

讓我們從緩衝區取樣一個批次。轉換後的畫素鍵的形狀應該從最後一個維度開始的第四個維度長度為 4。

s = rb.sample(1)  # the buffer has only one element
print(s)

經過一些處理(排除未使用的鍵等)後,我們看到線上生成的資料和離線生成的資料匹配!

assert (data.exclude("collector") == s.squeeze(0).exclude("index", "collector")).all()

儲存軌跡

在許多情況下,從緩衝區訪問軌跡而不是簡單轉換是可取的。TorchRL 提供了多種實現此目的的方法。

首選方法是沿著緩衝區的第一個維度儲存軌跡,並使用 SliceSampler 對這些資料批次進行取樣。此類只需瞭解您資料結構的一些資訊即可完成其工作(請注意,目前它僅與 tensordict 結構化資料相容):切片數量或其長度以及有關情節分隔位置的資訊(例如,回想一下,使用 DataCollector 時,軌跡 ID 儲存在 ("collector", "traj_ids") 中)。在此簡單示例中,我們構建了一個包含 4 個連續短軌跡的資料,並從中取樣了 4 個切片,每個切片長度為 2(因為批次大小為 8,8 個項 // 4 個切片 = 2 個時間步)。我們還標記了這些步驟。

from torchrl.data import SliceSampler

rb = TensorDictReplayBuffer(
    storage=LazyTensorStorage(size),
    sampler=SliceSampler(traj_key="episode", num_slices=4),
    batch_size=8,
)
episode = torch.zeros(10, dtype=torch.int)
episode[:3] = 1
episode[3:5] = 2
episode[5:7] = 3
episode[7:] = 4
steps = torch.cat([torch.arange(3), torch.arange(2), torch.arange(2), torch.arange(3)])
data = TensorDict(
    {
        "episode": episode,
        "obs": torch.randn((3, 4, 5)).expand(10, 3, 4, 5),
        "act": torch.randn((20,)).expand(10, 20),
        "other": torch.randn((20, 50)).expand(10, 20, 50),
        "steps": steps,
    },
    [10],
)
rb.extend(data)
sample = rb.sample()
print("episode are grouped", sample["episode"])
print("steps are successive", sample["steps"])

gc.collect()

結論

我們已經瞭解瞭如何在 TorchRL 中使用回放緩衝區,從最簡單的用法到更高階的用法,其中資料需要以特定方式進行轉換或儲存。您現在應該能夠

  • 建立回放緩衝區,自定義其儲存、取樣器和轉換;

  • 為您的問題選擇最佳的儲存型別(列表、記憶體或基於磁碟);

  • 最小化緩衝區的記憶體佔用。

後續步驟

  • 檢視資料 API 參考,瞭解 TorchRL 中的離線資料集,這些資料集基於我們的回放緩衝區 API;

  • 檢視其他取樣器,例如 SamplerWithoutReplacementPrioritizedSliceSamplerSliceSamplerWithoutReplacement,或其他寫入器,例如 TensorDictMaxValueWriter

  • 文件 中檢視如何檢查回放緩衝區點。

由 Sphinx-Gallery 生成的畫廊

文件

訪問全面的 PyTorch 開發者文件

檢視文件

教程

為初學者和高階開發者提供深入的教程

檢視教程

資源

查詢開發資源並讓您的問題得到解答

檢視資源