快捷方式

ConsistentDropout

class torchrl.modules.ConsistentDropout(p: float = 0.5)[原始碼]

實現了一個 Dropout 變體,具有一致性 dropout。

該方法在 “Consistent Dropout for Policy Gradient Reinforcement Learning” (Hausknecht & Wagener, 2022) 中提出。

這個 Dropout 變體試圖透過在 rollout 期間快取 dropout 掩碼並在更新階段重用它們來提高訓練穩定性和減少更新方差。

您正在檢視的類獨立於 TorchRL 的其餘 API,並且不需要 tensordict 即可執行。 ConsistentDropoutModuleConsistentDropout 的包裝器,它利用了 TensorDict 的可擴充套件性,透過 生成的 dropout 掩碼 儲存在 transition ``TensorDict 本身中。有關詳細說明和用法示例,請參閱此類。

除此之外,與 PyTorch 的 Dropout 實現相比,概念上的偏差很小。

..note:: TorchRL 的資料收集器在 no_grad() 模式下執行 rollout,但不在 eval 模式下執行,

因此,除非傳遞給收集器的策略處於 eval 模式,否則將應用 dropout 掩碼。

注意

與其他探索模組不同,ConsistentDropoutModule 使用 train/eval 模式以符合 PyTorch 中常規的 Dropout API。 set_exploration_type() 上下文管理器對此模組無效。

引數:

p (float, 可選) – Dropout 機率。預設為 0.5

另請參閱

forward(x: torch.Tensor, mask: torch.Tensor | None = None) torch.Tensor[原始碼]

在訓練(rollouts & updates)期間,此呼叫在乘以輸入張量之前,會掩蓋一個全為 1 的張量。

在評估期間,此呼叫將不執行任何操作,僅返回輸入。

引數:

返回: 在訓練模式下返回一個張量和一個對應的掩碼,在評估模式下僅返回一個張量。

文件

訪問全面的 PyTorch 開發者文件

檢視文件

教程

為初學者和高階開發者提供深入的教程

檢視教程

資源

查詢開發資源並讓您的問題得到解答

檢視資源