PrioritizedSampler¶
- class torchrl.data.replay_buffers.PrioritizedSampler(max_capacity: int, alpha: float, beta: float, eps: float = 1e-08, dtype: dtype = torch.float32, reduction: str = 'max', max_priority_within_buffer: bool = False)[原始碼]¶
經驗回放的優先採樣器。
此取樣器實現了優先經驗回放 (PER),如“Schaul, T.; Quan, J.; Antonoglou, I.; and Silver, D. 2015. Prioritized experience replay.” 所述 (https://arxiv.org/abs/1511.05952)。
核心思想:PER 不再從經驗回放緩衝區中統一取樣經驗,而是根據其“重要性”(通常由其時序差 (TD) 誤差的大小衡量)的機率來取樣經驗。這種優先排序可以透過關注最具資訊量的經驗來加速學習。
工作原理:1. 每個經驗根據其 TD 誤差分配優先順序:\(p_i = |\delta_i| + \epsilon\) 2. 取樣機率計算如下:\(P(i) = \frac{p_i^\alpha}{\sum_j p_j^\alpha}\) 3. 重要性取樣權重用於糾正偏差:\(w_i = (N \cdot P(i))^{-\beta}\)
- 引數:
max_capacity (int) – 緩衝區的最大容量。
alpha (
float) – 指數 \(\alpha\) 決定了優先排序的程度。 - \(\alpha = 0\):統一取樣(無優先排序) - \(\alpha = 1\):基於 TD 誤差大小的完全優先排序 - 典型值:0.4-0.7,用於平衡優先排序 - 較高的 \(\alpha\) 值意味著對高誤差經驗的優先排序更激進beta (
float) – 重要性取樣的負指數 \(\beta\)。 - \(\beta\) 控制對優先排序引入的偏差的糾正 - \(\beta = 0\):無糾正(偏向於高優先順序樣本) - \(\beta = 1\):完全糾正(無偏但可能不穩定) - 典型值:在訓練期間從 0.4-0.6 開始,然後退火到 1.0 - 訓練早期較低的 \(\beta\) 值提供穩定性,後期較高的 \(\beta\) 值減少偏差eps (
float, 可選) – 新增到優先順序的微小常數,以確保沒有任何經驗的優先順序為零。這可以防止某些經驗永遠不被取樣。預設為 1e-8。reduction (str, 可選) – 用於多維 tensordicts(即儲存的軌跡)的縮減方法。可以是 "max"、"min"、"median" 或 "mean" 之一。
max_priority_within_buffer (bool, 可選) – 如果為
True,則在緩衝區內跟蹤最大優先順序。如果為False,則最大優先順序會跟蹤自採樣器例項化以來的最大值。
引數指南: - :math:`alpha` (alpha):控制優先排序高誤差經驗的程度
0.4-0.7:學習速度和穩定性之間的良好平衡
1.0:最大優先排序(可能不穩定)
0.0:統一取樣(無優先排序益處)
- :math:`beta` (beta):控制重要性取樣糾正
從 0.4-0.6 開始進行訓練以獲得穩定性
在訓練過程中退火到 1.0 以減少偏差
較低的值 = 更穩定但有偏差
較高的值 = 偏差較小但可能不穩定
- :math:`epsilon`:防止優先順序為零的小常數
1e-8:良好的預設值
太小:可能導致數值問題
太大:會降低優先排序的效果
示例
>>> from torchrl.data.replay_buffers import ReplayBuffer, LazyTensorStorage, PrioritizedSampler >>> from tensordict import TensorDict >>> rb = ReplayBuffer(storage=LazyTensorStorage(10), sampler=PrioritizedSampler(max_capacity=10, alpha=1.0, beta=1.0)) >>> priority = torch.tensor([0, 1000]) >>> data_0 = TensorDict({"reward": 0, "obs": [0], "action": [0], "priority": priority[0]}, []) >>> data_1 = TensorDict({"reward": 1, "obs": [1], "action": [2], "priority": priority[1]}, []) >>> rb.add(data_0) >>> rb.add(data_1) >>> rb.update_priority(torch.tensor([0, 1]), priority=priority) >>> sample, info = rb.sample(10, return_info=True) >>> print(sample) TensorDict( fields={ action: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.int64, is_shared=False), obs: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.int64, is_shared=False), priority: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.int64, is_shared=False), reward: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.int64, is_shared=False)}, batch_size=torch.Size([10]), device=cpu, is_shared=False) >>> print(info) {'_weight': array([1.e-11, 1.e-11, 1.e-11, 1.e-11, 1.e-11, 1.e-11, 1.e-11, 1.e-11, 1.e-11, 1.e-11], dtype=float32), 'index': array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1])}
注意
使用
TensorDictReplayBuffer可以平滑更新優先順序的過程>>> from torchrl.data.replay_buffers import TensorDictReplayBuffer as TDRB, LazyTensorStorage, PrioritizedSampler >>> from tensordict import TensorDict >>> rb = TDRB( ... storage=LazyTensorStorage(10), ... sampler=PrioritizedSampler(max_capacity=10, alpha=1.0, beta=1.0), ... priority_key="priority", # This kwarg isn't present in regular RBs ... ) >>> priority = torch.tensor([0, 1000]) >>> data_0 = TensorDict({"reward": 0, "obs": [0], "action": [0], "priority": priority[0]}, []) >>> data_1 = TensorDict({"reward": 1, "obs": [1], "action": [2], "priority": priority[1]}, []) >>> data = torch.stack([data_0, data_1]) >>> rb.extend(data) >>> rb.update_priority(data) # Reads the "priority" key as indicated in the constructor >>> sample, info = rb.sample(10, return_info=True) >>> print(sample['index']) # The index is packed with the tensordict tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
- update_priority(index: int | torch.Tensor, priority: float | torch.Tensor, *, storage: TensorStorage | None = None) None[原始碼]¶
更新由索引指向的資料的優先順序。
- 引數:
index (int 或 torch.Tensor) – 要更新優先順序的索引。
priority (Number 或 torch.Tensor) – 索引元素的新的優先順序。
- 關鍵字引數:
storage (Storage, 可選) – 一個用於將 Nd 索引大小對映到 sum_tree 和 min_tree 的一維大小的儲存。只有當
index.ndim > 2時才需要。