SliceSamplerWithoutReplacement¶
- class torchrl.data.replay_buffers.SliceSamplerWithoutReplacement(*, num_slices: int | None = None, slice_len: int | None = None, drop_last: bool = False, end_key: NestedKey | None = None, traj_key: NestedKey | None = None, ends: torch.Tensor | None = None, trajectories: torch.Tensor | None = None, truncated_key: NestedKey | None = ('next', 'truncated'), strict_length: bool = True, shuffle: bool = True, compile: bool | dict = False, use_gpu: bool | torch.device = False)[原始碼]¶
在不重複抽樣的情況下,給定開始和停止訊號,沿著第一個維度對資料切片進行抽樣。
在此上下文中,
不重複抽樣意味著在計數器自動重置之前,同一個元素(不是軌跡)不會被抽取兩次。然而,在單個樣本中,只會出現一個給定軌跡的切片(請參閱下面的示例)。此類應與靜態回放緩衝區一起使用,或在兩個回放緩衝區擴充套件之間使用。擴充套件回放緩衝區將重置取樣器,目前不允許連續不重複抽樣。
注意
SliceSamplerWithoutReplacement 檢索軌跡索引可能會很慢。為了加速其執行,請優先使用 end_key 而不是 traj_key,並考慮以下關鍵字引數:
compile、cache_values和use_gpu。- 關鍵字引數:
drop_last (bool, optional) – 如果為
True,則將刪除最後一個不完整的樣本(如果有)。如果為False,則將保留最後一個樣本。預設為False。num_slices (int) – 要抽樣的切片數量。批次大小必須大於或等於
num_slices引數。與slice_len互斥。slice_len (int) – 要抽樣的切片的長度。批次大小必須大於或等於
slice_len引數,並且可以被其整除。與num_slices互斥。end_key (NestedKey, optional) – 指示軌跡(或回合)結束的鍵。預設為
("next", "done")。traj_key (NestedKey, optional) – 指示軌跡的鍵。預設為
"episode"(在 TorchRL 的資料集中常用)。ends (torch.Tensor, optional) – 一個包含執行結束訊號的一維布林張量。當
end_key或traj_key獲取成本高昂,或該訊號易於獲得時使用。必須與cache_values=True一起使用,並且不能與end_key或traj_key結合使用。trajectories (torch.Tensor, optional) – 一個包含執行 ID 的一維整數張量。當
end_key或traj_key獲取成本高昂,或該訊號易於獲得時使用。必須與cache_values=True一起使用,並且不能與end_key或traj_key結合使用。truncated_key (NestedKey, optional) – 如果不為
None,則此引數指示截斷訊號應寫入輸出資料的位置。這用於告知值估計器提供的軌跡在哪裡中斷。預設為("next", "truncated")。此功能僅適用於TensorDictReplayBuffer例項(否則,截斷鍵將在sample()方法返回的資訊字典中)。strict_length (bool, optional) – 如果為
False,則允許長度小於 slice_len(或 batch_size // num_slices)的軌跡出現在批次中。如果為True,則將過濾掉長度不足的軌跡。請注意,這可能導致有效的 batch_size 短於請求的 batch_size!可以使用split_trajectories()來拆分軌跡。預設為True。shuffle (bool, optional) – 如果為
False,則不打亂軌跡的順序。預設為True。compile (bool 或 dict of kwargs, optional) – 如果為
True,則sample()方法的瓶頸將使用compile()進行編譯。也可以透過此引數將關鍵字引數傳遞給 torch.compile。預設為False。use_gpu (bool 或 torch.device) – 如果為
True(或傳遞了裝置),則將使用加速器來檢索軌跡起始點的索引。當緩衝區內容很大時,這可以顯著加速取樣。預設為False。
注意
要恢復儲存中的軌跡分割,
SliceSamplerWithoutReplacement將首先嚐試在儲存中找到traj_key條目。如果找不到,將使用end_key來重建回合。示例
>>> import torch >>> from tensordict import TensorDict >>> from torchrl.data.replay_buffers import LazyMemmapStorage, TensorDictReplayBuffer >>> from torchrl.data.replay_buffers.samplers import SliceSamplerWithoutReplacement >>> >>> rb = TensorDictReplayBuffer( ... storage=LazyMemmapStorage(1000), ... # asking for 10 slices for a total of 320 elements, ie, 10 trajectories of 32 transitions each ... sampler=SliceSamplerWithoutReplacement(num_slices=10), ... batch_size=320, ... ) >>> episode = torch.zeros(1000, dtype=torch.int) >>> episode[:300] = 1 >>> episode[300:550] = 2 >>> episode[550:700] = 3 >>> episode[700:] = 4 >>> data = TensorDict( ... { ... "episode": episode, ... "obs": torch.randn((3, 4, 5)).expand(1000, 3, 4, 5), ... "act": torch.randn((20,)).expand(1000, 20), ... "other": torch.randn((20, 50)).expand(1000, 20, 50), ... }, [1000] ... ) >>> rb.extend(data) >>> sample = rb.sample() >>> # since we want trajectories of 32 transitions but there are only 4 episodes to >>> # sample from, we only get 4 x 32 = 128 transitions in this batch >>> print("sample:", sample) >>> print("trajectories in sample", sample.get("episode").unique())
SliceSamplerWithoutReplacement與大多數 TorchRL 的資料集預設相容,並允許使用者以類似資料載入器的方式使用資料集。示例
>>> import torch >>> >>> from torchrl.data.datasets import RobosetExperienceReplay >>> from torchrl.data import SliceSamplerWithoutReplacement >>> >>> torch.manual_seed(0) >>> num_slices = 10 >>> dataid = list(RobosetExperienceReplay.available_datasets)[0] >>> data = RobosetExperienceReplay(dataid, batch_size=320, ... sampler=SliceSamplerWithoutReplacement(num_slices=num_slices)) >>> # the last sample is kept, since drop_last=False by default >>> for i, batch in enumerate(data): ... print(batch.get("episode").unique()) tensor([ 5, 6, 8, 11, 12, 14, 16, 17, 19, 24]) tensor([ 1, 2, 7, 9, 10, 13, 15, 18, 21, 22]) tensor([ 0, 3, 4, 20, 23])
當請求大量總樣本,但軌跡數量很少且跨度很小時,批次最多隻會包含每個軌跡的一個樣本。
示例
>>> import torch >>> from tensordict import TensorDict >>> from torchrl.collectors.utils import split_trajectories >>> from torchrl.data import ReplayBuffer, LazyTensorStorage, SliceSampler, SliceSamplerWithoutReplacement >>> >>> rb = ReplayBuffer(storage=LazyTensorStorage(max_size=1000), ... sampler=SliceSamplerWithoutReplacement( ... slice_len=5, traj_key="episode",strict_length=False ... )) ... >>> ep_1 = TensorDict( ... {"obs": torch.arange(100), ... "episode": torch.zeros(100),}, ... batch_size=[100] ... ) >>> ep_2 = TensorDict( ... {"obs": torch.arange(51), ... "episode": torch.ones(51),}, ... batch_size=[51] ... ) >>> rb.extend(ep_1) >>> rb.extend(ep_2) >>> >>> s = rb.sample(50) >>> t = split_trajectories(s, trajectory_key="episode") >>> print(t["obs"]) tensor([[14, 15, 16, 17, 18], [ 3, 4, 5, 6, 7]]) >>> print(t["episode"]) tensor([[0., 0., 0., 0., 0.], [1., 1., 1., 1., 1.]]) >>> >>> s = rb.sample(50) >>> t = split_trajectories(s, trajectory_key="episode") >>> print(t["obs"]) tensor([[ 4, 5, 6, 7, 8], [26, 27, 28, 29, 30]]) >>> print(t["episode"]) tensor([[0., 0., 0., 0., 0.], [1., 1., 1., 1., 1.]])