快捷方式

TanhNormal

class torchrl.modules.TanhNormal(loc: torch.Tensor, scale: torch.Tensor, upscale: torch.Tensor | Number = 5.0, low: torch.Tensor | Number = - 1.0, high: torch.Tensor | Number = 1.0, event_dims: int | None = None, tanh_loc: bool = False, safe_tanh: bool = True)[原始碼]

實現帶位置縮放的 TanhNormal 分佈。

位置縮放可防止位置在應用 TanhTransform 時“離 0”太遠,但這最終會導致取樣不穩定和梯度計算不良(例如,梯度爆炸)。實際上,在位置縮放的情況下,位置根據以下公式計算:

\[loc = tanh(loc / upscale) * upscale.\]
引數:
  • loc (torch.Tensor) – 正態分佈位置引數

  • scale (torch.Tensor) – 正態分佈 sigma 引數(方差的平方根)

  • upscale (torch.Tensor數字) –

    公式中的“a”縮放因子

    \[loc = tanh(loc / upscale) * upscale.\]

  • low (torch.Tensor數字, 可選) – 分佈的最小值。預設為 -1.0;

  • high (torch.Tensor數字, 可選) – 分佈的最大值。預設為 1.0;

  • event_dims (int, 可選) – 描述動作的維度數。預設為 1。將 event_dims 設定為 0 將導致日誌機率與輸入形狀相同,設定為 1 將對最後一個維度求和,設定為 2 將對最後兩個維度求和,依此類推。

  • tanh_loc (bool, 可選) – 如果為 True,則使用上述公式進行位置縮放,否則保留原始值。預設為 False

  • safe_tanh (bool, 可選) – 如果為 True,則 Tanh 變換會“安全地”進行,以避免數值溢位。這目前會與 torch.compile() 發生衝突。

get_mode()[原始碼]

使用 Adam 最佳化器計算模式的估計值。

property mean

返回分佈的均值。

property mode

返回分佈的眾數。

property support

返回一個 Constraint 物件,表示此分佈的支援域。

文件

訪問全面的 PyTorch 開發者文件

檢視文件

教程

為初學者和高階開發者提供深入的教程

檢視教程

資源

查詢開發資源並讓您的問題得到解答

檢視資源