TruncatedNormal¶
- class torchrl.modules.TruncatedNormal(loc: torch.Tensor, scale: torch.Tensor, upscale: torch.Tensor | float = 5.0, low: torch.Tensor | float = - 1.0, high: torch.Tensor | float = 1.0, tanh_loc: bool = False)[原始碼]¶
實現了一個帶有位置縮放的截斷正態分佈。
位置縮放可以防止位置“離”0“太遠”,這最終會導致數值不穩定的樣本和差的梯度計算(例如,梯度爆炸)。在實踐中,位置是根據以下公式計算的:
\[loc = tanh(loc / upscale) * upscale.\]透過關閉 tanh_loc 引數(見下文)可以停用此行為。
- 引數:
loc (torch.Tensor) – 正態分佈位置引數
scale (torch.Tensor) – 正態分佈 sigma 引數(方差的平方根)
upscale (torch.Tensor 或 數字,可選)–
公式中的“a”縮放因子
\[loc = tanh(loc / upscale) * upscale.\]預設為 5.0
low (torch.Tensor 或 數字,可選)–分佈的最小值。預設為 -1.0;
high (torch.Tensor 或 數字,可選)–分佈的最大值。預設為 1.0;
tanh_loc (bool,可選)–如果為
True,則使用上述公式進行位置縮放,否則保留原始值。預設為False;
- property mode¶
返回分佈的眾數。