SliceSampler¶
- class torchrl.data.replay_buffers.SliceSampler(*, num_slices: int | None = None, slice_len: int | None = None, end_key: NestedKey | None = None, traj_key: NestedKey | None = None, ends: torch.Tensor | None = None, trajectories: torch.Tensor | None = None, cache_values: bool = False, truncated_key: NestedKey | None = ('next', 'truncated'), strict_length: bool = True, compile: bool | dict = False, span: bool | int | tuple[bool | int, bool | int] = False, use_gpu: torch.device | bool = False)[原始碼]¶
沿第一維度對資料切片進行取樣,給定開始和停止訊號。
此類有放回地取樣子軌跡。無放回版本請參閱
SliceSamplerWithoutReplacement。注意
SliceSampler 檢索軌跡索引可能會很慢。為了加快其執行速度,請優先使用 end_key 而非 traj_key,並考慮以下關鍵字引數:
compile、cache_values和use_gpu。- 關鍵字引數:
num_slices (int) – 要抽樣的切片數量。批次大小必須大於或等於
num_slices引數。與slice_len互斥。slice_len (int) – 要抽樣的切片的長度。批次大小必須大於或等於
slice_len引數,並且可以被其整除。與num_slices互斥。end_key (NestedKey, optional) – 指示軌跡(或回合)結束的鍵。預設為
("next", "done")。traj_key (NestedKey, optional) – 指示軌跡的鍵。預設為
"episode"(在 TorchRL 的資料集中常用)。ends (torch.Tensor, 可選) – 一個包含執行結束訊號的一維布林張量。當
end_key或traj_key獲取成本高昂,或此訊號易於獲得時使用。必須與cache_values=True一起使用,且不能與end_key或traj_key結合使用。如果提供,則假定儲存已滿,並且如果ends張量的最後一個元素為False,則相同的軌跡會跨越結束和開始。trajectories (torch.Tensor, 可選) – 一個包含執行 ID 的一維整數張量。當
end_key或traj_key獲取成本高昂,或此訊號易於獲得時使用。必須與cache_values=True一起使用,且不能與end_key或traj_key結合使用。如果提供,則假定儲存已滿,並且如果軌跡張量的最後一個元素與第一個元素相同,則相同的軌跡會跨越結束和開始。cache_values (bool, 可選) –
與靜態資料集一起使用。將快取軌跡的開始和結束訊號。即使軌跡索引在呼叫
extend時發生更改,也可以安全地使用此選項,因為此操作將清除快取。警告
cache_values=True在以下情況將無法正常工作:取樣器與由另一個緩衝區擴充套件的儲存一起使用。例如:>>> buffer0 = ReplayBuffer(storage=storage, ... sampler=SliceSampler(num_slices=8, cache_values=True), ... writer=ImmutableWriter()) >>> buffer1 = ReplayBuffer(storage=storage, ... sampler=other_sampler) >>> # Wrong! Does not erase the buffer from the sampler of buffer0 >>> buffer1.extend(data)
警告
cache_values=True在以下情況將無法按預期工作:緩衝區在程序之間共享,一個程序負責寫入,另一個程序負責取樣,因為清除快取只能在本地進行。truncated_key (NestedKey, optional) – 如果不為
None,則此引數指示截斷訊號應寫入輸出資料的位置。這用於告知值估計器提供的軌跡在哪裡中斷。預設為("next", "truncated")。此功能僅適用於TensorDictReplayBuffer例項(否則,截斷鍵將在sample()方法返回的資訊字典中)。strict_length (bool, optional) – 如果為
False,則允許長度小於 slice_len(或 batch_size // num_slices)的軌跡出現在批次中。如果為True,則將過濾掉長度不足的軌跡。請注意,這可能導致有效的 batch_size 短於請求的 batch_size!可以使用split_trajectories()來拆分軌跡。預設為True。compile (bool 或 dict of kwargs, optional) – 如果為
True,則sample()方法的瓶頸將使用compile()進行編譯。也可以透過此引數將關鍵字引數傳遞給 torch.compile。預設為False。span (bool, int, Tuple[bool | int, bool | int], 可選) – 如果提供,則取樣軌跡將跨越左側和/或右側。這意味著可能提供的元素少於所需的元素。布林值表示每個軌跡至少會取樣一個元素。整數 i 表示每個取樣軌跡至少會收集 slice_len - i 個樣本。使用元組可以精細控制跨越左側(儲存軌跡的開頭)和右側(儲存軌跡的結尾)的跨度。
use_gpu (bool 或 torch.device) – 如果為
True(或傳遞了裝置),則將使用加速器來檢索軌跡的起始索引。當緩衝區內容很大時,這可以顯著加快取樣速度。預設為False。
注意
要恢復儲存中的軌跡分割,
SliceSampler將首先嚐試在儲存中查詢traj_key條目。如果找不到,將使用end_key來重建劇集。注意
當使用 strict_length=False 時,建議使用
split_trajectories()來分割取樣軌跡。但是,如果來自同一劇集的兩個樣本並排放置,這可能會產生不正確的結果。為避免此問題,請考慮以下解決方案之一:使用帶有切片取樣器的
TensorDictReplayBuffer例項>>> import torch >>> from tensordict import TensorDict >>> from torchrl.collectors.utils import split_trajectories >>> from torchrl.data import TensorDictReplayBuffer, ReplayBuffer, LazyTensorStorage, SliceSampler, SliceSamplerWithoutReplacement >>> >>> rb = TensorDictReplayBuffer(storage=LazyTensorStorage(max_size=1000), ... sampler=SliceSampler( ... slice_len=5, traj_key="episode",strict_length=False, ... )) ... >>> ep_1 = TensorDict( ... {"obs": torch.arange(100), ... "episode": torch.zeros(100),}, ... batch_size=[100] ... ) >>> ep_2 = TensorDict( ... {"obs": torch.arange(4), ... "episode": torch.ones(4),}, ... batch_size=[4] ... ) >>> rb.extend(ep_1) >>> rb.extend(ep_2) >>> >>> s = rb.sample(50) >>> print(s) TensorDict( fields={ episode: Tensor(shape=torch.Size([46]), device=cpu, dtype=torch.float32, is_shared=False), index: Tensor(shape=torch.Size([46, 1]), device=cpu, dtype=torch.int64, is_shared=False), next: TensorDict( fields={ done: Tensor(shape=torch.Size([46, 1]), device=cpu, dtype=torch.bool, is_shared=False), terminated: Tensor(shape=torch.Size([46, 1]), device=cpu, dtype=torch.bool, is_shared=False), truncated: Tensor(shape=torch.Size([46, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, batch_size=torch.Size([46]), device=cpu, is_shared=False), obs: Tensor(shape=torch.Size([46]), device=cpu, dtype=torch.int64, is_shared=False)}, batch_size=torch.Size([46]), device=cpu, is_shared=False) >>> t = split_trajectories(s, done_key="truncated") >>> print(t["obs"]) tensor([[73, 74, 75, 76, 77], [ 0, 1, 2, 3, 0], [ 0, 1, 2, 3, 0], [41, 42, 43, 44, 45], [ 0, 1, 2, 3, 0], [67, 68, 69, 70, 71], [27, 28, 29, 30, 31], [80, 81, 82, 83, 84], [17, 18, 19, 20, 21], [ 0, 1, 2, 3, 0]]) >>> print(t["episode"]) tensor([[0., 0., 0., 0., 0.], [1., 1., 1., 1., 0.], [1., 1., 1., 1., 0.], [0., 0., 0., 0., 0.], [1., 1., 1., 1., 0.], [0., 0., 0., 0., 0.], [0., 0., 0., 0., 0.], [0., 0., 0., 0., 0.], [0., 0., 0., 0., 0.], [1., 1., 1., 1., 0.]])
使用
SliceSamplerWithoutReplacement>>> import torch >>> from tensordict import TensorDict >>> from torchrl.collectors.utils import split_trajectories >>> from torchrl.data import ReplayBuffer, LazyTensorStorage, SliceSampler, SliceSamplerWithoutReplacement >>> >>> rb = ReplayBuffer(storage=LazyTensorStorage(max_size=1000), ... sampler=SliceSamplerWithoutReplacement( ... slice_len=5, traj_key="episode",strict_length=False ... )) ... >>> ep_1 = TensorDict( ... {"obs": torch.arange(100), ... "episode": torch.zeros(100),}, ... batch_size=[100] ... ) >>> ep_2 = TensorDict( ... {"obs": torch.arange(4), ... "episode": torch.ones(4),}, ... batch_size=[4] ... ) >>> rb.extend(ep_1) >>> rb.extend(ep_2) >>> >>> s = rb.sample(50) >>> t = split_trajectories(s, trajectory_key="episode") >>> print(t["obs"]) tensor([[75, 76, 77, 78, 79], [ 0, 1, 2, 3, 0]]) >>> print(t["episode"]) tensor([[0., 0., 0., 0., 0.], [1., 1., 1., 1., 0.]])
示例
>>> import torch >>> from tensordict import TensorDict >>> from torchrl.data.replay_buffers import LazyMemmapStorage, TensorDictReplayBuffer >>> from torchrl.data.replay_buffers.samplers import SliceSampler >>> torch.manual_seed(0) >>> rb = TensorDictReplayBuffer( ... storage=LazyMemmapStorage(1_000_000), ... sampler=SliceSampler(cache_values=True, num_slices=10), ... batch_size=320, ... ) >>> episode = torch.zeros(1000, dtype=torch.int) >>> episode[:300] = 1 >>> episode[300:550] = 2 >>> episode[550:700] = 3 >>> episode[700:] = 4 >>> data = TensorDict( ... { ... "episode": episode, ... "obs": torch.randn((3, 4, 5)).expand(1000, 3, 4, 5), ... "act": torch.randn((20,)).expand(1000, 20), ... "other": torch.randn((20, 50)).expand(1000, 20, 50), ... }, [1000] ... ) >>> rb.extend(data) >>> sample = rb.sample() >>> print("sample:", sample) >>> print("episodes", sample.get("episode").unique()) episodes tensor([1, 2, 3, 4], dtype=torch.int32)
SliceSampler與大多數 TorchRL 的資料集預設相容示例
>>> import torch >>> >>> from torchrl.data.datasets import RobosetExperienceReplay >>> from torchrl.data import SliceSampler >>> >>> torch.manual_seed(0) >>> num_slices = 10 >>> dataid = list(RobosetExperienceReplay.available_datasets)[0] >>> data = RobosetExperienceReplay(dataid, batch_size=320, sampler=SliceSampler(num_slices=num_slices)) >>> for batch in data: ... batch = batch.reshape(num_slices, -1) ... break >>> print("check that each batch only has one episode:", batch["episode"].unique(dim=1)) check that each batch only has one episode: tensor([[19], [14], [ 8], [10], [13], [ 4], [ 2], [ 3], [22], [ 8]])