ConsistentDropout¶
- class torchrl.modules.ConsistentDropout(p: float = 0.5)[原始碼]¶
實現了一個
Dropout變體,具有一致性 dropout。該方法在 “Consistent Dropout for Policy Gradient Reinforcement Learning” (Hausknecht & Wagener, 2022) 中提出。
這個
Dropout變體試圖透過在 rollout 期間快取 dropout 掩碼並在更新階段重用它們來提高訓練穩定性和減少更新方差。您正在檢視的類獨立於 TorchRL 的其餘 API,並且不需要 tensordict 即可執行。
ConsistentDropoutModule是ConsistentDropout的包裝器,它利用了TensorDict的可擴充套件性,透過 將 生成的 dropout 掩碼 儲存在 transition ``TensorDict本身中。有關詳細說明和用法示例,請參閱此類。除此之外,與 PyTorch 的
Dropout實現相比,概念上的偏差很小。- ..note:: TorchRL 的資料收集器在
no_grad()模式下執行 rollout,但不在 eval 模式下執行, 因此,除非傳遞給收集器的策略處於 eval 模式,否則將應用 dropout 掩碼。
注意
與其他探索模組不同,
ConsistentDropoutModule使用train/eval模式以符合 PyTorch 中常規的 Dropout API。set_exploration_type()上下文管理器對此模組無效。- 引數:
p (
float, 可選) – Dropout 機率。預設為0.5。
另請參閱
MultiSyncDataCollector: 在底層使用_main_async_collector()(SyncDataCollector)
- forward(x: torch.Tensor, mask: torch.Tensor | None = None) torch.Tensor[原始碼]¶
在訓練(rollouts & updates)期間,此呼叫在乘以輸入張量之前,會掩蓋一個全為 1 的張量。
在評估期間,此呼叫將不執行任何操作,僅返回輸入。
- 引數:
x (torch.Tensor) – 輸入張量。
mask (torch.Tensor, 可選) – dropout 的可選掩碼。
返回: 在訓練模式下返回一個張量和一個對應的掩碼,在評估模式下僅返回一個張量。
- ..note:: TorchRL 的資料收集器在