快捷方式

TruncatedNormal

class torchrl.modules.TruncatedNormal(loc: torch.Tensor, scale: torch.Tensor, upscale: torch.Tensor | float = 5.0, low: torch.Tensor | float = - 1.0, high: torch.Tensor | float = 1.0, tanh_loc: bool = False)[原始碼]

實現了一個帶有位置縮放的截斷正態分佈。

位置縮放可以防止位置“離”0“太遠”,這最終會導致數值不穩定的樣本和差的梯度計算(例如,梯度爆炸)。在實踐中,位置是根據以下公式計算的:

\[loc = tanh(loc / upscale) * upscale.\]

透過關閉 tanh_loc 引數(見下文)可以停用此行為。

引數:
  • loc (torch.Tensor) – 正態分佈位置引數

  • scale (torch.Tensor) – 正態分佈 sigma 引數(方差的平方根)

  • upscaletorch.Tensor數字可選)–

    公式中的“a”縮放因子

    \[loc = tanh(loc / upscale) * upscale.\]

    預設為 5.0

  • lowtorch.Tensor數字可選)–分佈的最小值。預設為 -1.0;

  • hightorch.Tensor數字可選)–分佈的最大值。預設為 1.0;

  • tanh_locbool可選)–如果為 True,則使用上述公式進行位置縮放,否則保留原始值。預設為 False

log_prob(value, **kwargs)[原始碼]

返回在 value 處評估的機率密度/質量函式的對數。

引數:

value (Tensor) –

property mode

返回分佈的眾數。

文件

訪問全面的 PyTorch 開發者文件

檢視文件

教程

為初學者和高階開發者提供深入的教程

檢視教程

資源

查詢開發資源並讓您的問題得到解答

檢視資源