PrioritizedSliceSampler¶
- class torchrl.data.replay_buffers.PrioritizedSliceSampler(max_capacity: int, alpha: float, beta: float, eps: float = 1e-08, dtype: torch.dtype = torch.float32, reduction: str = 'max', *, num_slices: int | None = None, slice_len: int | None = None, end_key: NestedKey | None = None, traj_key: NestedKey | None = None, ends: torch.Tensor | None = None, trajectories: torch.Tensor | None = None, cache_values: bool = False, truncated_key: NestedKey | None = ('next', 'truncated'), strict_length: bool = True, compile: bool | dict = False, span: bool | int | tuple[bool | int, bool | int] = False, max_priority_within_buffer: bool = False)[原始碼]¶
使用優先採樣,根據開始和停止訊號,取樣資料沿第一個維度的切片。
此類結合了軌跡取樣和優先經驗回放 (PER),如“Schaul, T.; Quan, J.; Antonoglou, I.; and Silver, D. 2015. Prioritized experience replay.” 中所述(https://arxiv.org/abs/1511.05952)。
核心思想:此取樣器不均勻地取樣軌跡切片,而是根據軌跡切片中轉換的重要性來優先選擇軌跡的起始點。這使得學習能夠專注於軌跡中最具資訊量的部分。
工作原理:1. 每個轉換根據其 TD 誤差被分配一個優先順序:\(p_i = |\\delta_i| + \\epsilon\) 2. 軌跡起始點以以下機率取樣:\(P(i) = \frac{p_i^\alpha}{\\sum_j p_j^\alpha}\) 3. 重要性取樣權重用於糾正偏差:\(w_i = (N \\cdot P(i))^{-\beta}\) 4. 從取樣到的起始點提取完整的軌跡切片。
有關更多資訊,請參閱
SliceSampler和PrioritizedSampler。警告
PrioritizedSliceSampler 將檢視單個轉換的優先順序,並相應地對起始點進行取樣。這意味著優先順序較低的轉換也可能出現在樣本中,如果它們緊隨另一個優先順序較高的轉換之後;而優先順序很高但接近軌跡末尾的轉換,如果不能用作起始點,則可能永遠不會被取樣。目前,使用者有責任使用
update_priority()來聚合軌跡項的優先順序。- 引數:
max_capacity (int) – 緩衝區的最大容量。
alpha (
float) – 指數 \(\alpha\) 決定了優先順序的程度。 - \(\alpha = 0\):軌跡起始點的均勻取樣 - \(\alpha = 1\):基於起始點處 TD 誤差幅度的完全優先順序 - 典型值:0.4-0.7 以實現平衡的優先順序 - 較高的 \(\alpha\) 意味著對高誤差軌跡區域的優先順序更高。beta (
float) – 重要性取樣的負指數 \(\beta\)。 - \(\beta\) 控制對優先順序引入的偏差的校正 - \(\beta = 0\):無校正(偏向高優先順序軌跡區域) - \(\beta = 1\):完全校正(無偏差但可能不穩定) - 典型值:開始時為 0.4-0.6,訓練期間逐漸退火至 1.0 - 訓練早期較低的 \(\beta\) 可提供穩定性,後期較高的 \(\beta\) 可減少偏差。eps (
float, optional) – 新增到優先順序的微小常數,以確保沒有轉換的優先順序為零。這可以防止軌跡區域永遠不被取樣。預設為 1e-8。reduction (str, optional) – 多維 tensordicts(即儲存的軌跡)的縮減方法。可以是“max”、“min”、“median”或“mean”之一。
引數指南: - :math:`alpha` (alpha):控制對高誤差軌跡區域的優先順序程度
0.4-0.7:學習速度和穩定性之間的良好平衡
1.0:最大優先順序(可能不穩定)
0.0:均勻取樣(無優先順序優勢)
- :math:`beta` (beta):控制重要性取樣校正
訓練初期設定為 0.4-0.6 以獲得穩定性
訓練過程中退火至 1.0 以減少偏差
較低的值 = 更穩定但有偏差
較高的值 = 偏差較小但可能不穩定
- :math:`\epsilon`:防止優先順序為零的小常數
1e-8:良好的預設值
太小:可能導致數值問題
太大:降低優先順序效果
- 關鍵字引數:
num_slices (int) – 要抽樣的切片數量。批次大小必須大於或等於
num_slices引數。與slice_len互斥。slice_len (int) – 要抽樣的切片的長度。批次大小必須大於或等於
slice_len引數,並且可以被其整除。與num_slices互斥。end_key (NestedKey, optional) – 指示軌跡(或回合)結束的鍵。預設為
("next", "done")。traj_key (NestedKey, optional) – 指示軌跡的鍵。預設為
"episode"(在 TorchRL 的資料集中常用)。ends (torch.Tensor, optional) – 一個一維布林張量,包含執行結束訊號。當
end_key或traj_key的獲取成本很高,或者該訊號很容易獲得時使用。必須與cache_values=True一起使用,並且不能與end_key或traj_key結合使用。trajectories (torch.Tensor, optional) – 一個一維整數張量,包含執行 ID。當
end_key或traj_key的獲取成本很高,或者該訊號很容易獲得時使用。必須與cache_values=True一起使用,並且不能與end_key或traj_key結合使用。cache_values (bool, optional) –
用於靜態資料集。將快取軌跡的開始和結束訊號。即使在呼叫
extend期間軌跡索引發生更改,也可以安全地使用此選項,因為此操作將清除快取。警告
cache_values=True在取樣器與由另一個緩衝區擴充套件的儲存一起使用時將無法正常工作。例如:>>> buffer0 = ReplayBuffer(storage=storage, ... sampler=SliceSampler(num_slices=8, cache_values=True), ... writer=ImmutableWriter()) >>> buffer1 = ReplayBuffer(storage=storage, ... sampler=other_sampler) >>> # Wrong! Does not erase the buffer from the sampler of buffer0 >>> buffer1.extend(data)
警告
cache_values=True在緩衝區由多個程序共享,一個程序負責寫入而另一個程序負責取樣時,將無法按預期工作,因為清除快取只能在本地進行。truncated_key (NestedKey, optional) – 如果不為
None,則此引數指示截斷訊號應寫入輸出資料的位置。這用於告知值估計器提供的軌跡在哪裡中斷。預設為("next", "truncated")。此功能僅適用於TensorDictReplayBuffer例項(否則,截斷鍵將在sample()方法返回的資訊字典中)。strict_length (bool, optional) – 如果為
False,則允許長度小於 slice_len(或 batch_size // num_slices)的軌跡出現在批次中。如果為True,則將過濾掉長度不足的軌跡。請注意,這可能導致有效的 batch_size 短於請求的 batch_size!可以使用split_trajectories()來拆分軌跡。預設為True。compile (bool 或 dict of kwargs, optional) – 如果為
True,則sample()方法的瓶頸將使用compile()進行編譯。也可以透過此引數將關鍵字引數傳遞給 torch.compile。預設為False。span (bool, int, Tuple[bool | int, bool | int], optional) – 如果提供,則取樣的軌跡將跨越左側和/或右側。這意味著可能提供的元素少於所需元素。布林值表示每個軌跡至少採樣一個元素。整數 i 表示為每個取樣的軌跡收集的樣本至少為 slice_len - i。使用元組可以精細控制左側(儲存軌跡的開頭)和右側(儲存軌跡的結尾)的跨度。
max_priority_within_buffer (bool, optional) – 如果為
True,則在緩衝區內跟蹤最大優先順序。如果為False,則最大優先順序跟蹤自採樣器例項化以來的最大值。預設為False。
示例
>>> import torch >>> from torchrl.data.replay_buffers import TensorDictReplayBuffer, LazyMemmapStorage, PrioritizedSliceSampler >>> from tensordict import TensorDict >>> sampler = PrioritizedSliceSampler(max_capacity=9, num_slices=3, alpha=0.7, beta=0.9) >>> rb = TensorDictReplayBuffer(storage=LazyMemmapStorage(9), sampler=sampler, batch_size=6) >>> data = TensorDict( ... { ... "observation": torch.randn(9,16), ... "action": torch.randn(9, 1), ... "episode": torch.tensor([0,0,0,1,1,1,2,2,2], dtype=torch.long), ... "steps": torch.tensor([0,1,2,0,1,2,0,1,2], dtype=torch.long), ... ("next", "observation"): torch.randn(9,16), ... ("next", "reward"): torch.randn(9,1), ... ("next", "done"): torch.tensor([0,0,1,0,0,1,0,0,1], dtype=torch.bool).unsqueeze(1), ... }, ... batch_size=[9], ... ) >>> rb.extend(data) >>> sample, info = rb.sample(return_info=True) >>> print("episode", sample["episode"].tolist()) episode [2, 2, 2, 2, 1, 1] >>> print("steps", sample["steps"].tolist()) steps [1, 2, 0, 1, 1, 2] >>> print("weight", info["_weight"].tolist()) weight [1.0, 1.0, 1.0, 1.0, 1.0, 1.0] >>> priority = torch.tensor([0,3,3,0,0,0,1,1,1]) >>> rb.update_priority(torch.arange(0,9,1), priority=priority) >>> sample, info = rb.sample(return_info=True) >>> print("episode", sample["episode"].tolist()) episode [2, 2, 2, 2, 2, 2] >>> print("steps", sample["steps"].tolist()) steps [1, 2, 0, 1, 0, 1] >>> print("weight", info["_weight"].tolist()) weight [9.120110917137936e-06, 9.120110917137936e-06, 9.120110917137936e-06, 9.120110917137936e-06, 9.120110917137936e-06, 9.120110917137936e-06]
- update_priority(index: int | torch.Tensor, priority: float | torch.Tensor, *, storage: TensorStorage | None = None) None¶
更新由索引指向的資料的優先順序。
- 引數:
index (int 或 torch.Tensor) – 要更新優先順序的索引。
priority (Number 或 torch.Tensor) – 索引元素的新的優先順序。
- 關鍵字引數:
storage (Storage, optional) – 一個儲存,用於將 N 維索引大小對映到 sum_tree 和 min_tree 的一維大小。僅在
index.ndim > 2時需要。