TanhNormal¶
- class torchrl.modules.TanhNormal(loc: torch.Tensor, scale: torch.Tensor, upscale: torch.Tensor | Number = 5.0, low: torch.Tensor | Number = - 1.0, high: torch.Tensor | Number = 1.0, event_dims: int | None = None, tanh_loc: bool = False, safe_tanh: bool = True)[原始碼]¶
實現帶位置縮放的 TanhNormal 分佈。
位置縮放可防止位置在應用
TanhTransform時“離 0”太遠,但這最終會導致取樣不穩定和梯度計算不良(例如,梯度爆炸)。實際上,在位置縮放的情況下,位置根據以下公式計算:\[loc = tanh(loc / upscale) * upscale.\]- 引數:
loc (torch.Tensor) – 正態分佈位置引數
scale (torch.Tensor) – 正態分佈 sigma 引數(方差的平方根)
upscale (torch.Tensor 或 數字) –
公式中的“a”縮放因子
\[loc = tanh(loc / upscale) * upscale.\]low (torch.Tensor 或 數字, 可選) – 分佈的最小值。預設為 -1.0;
high (torch.Tensor 或 數字, 可選) – 分佈的最大值。預設為 1.0;
event_dims (int, 可選) – 描述動作的維度數。預設為 1。將
event_dims設定為0將導致日誌機率與輸入形狀相同,設定為1將對最後一個維度求和,設定為2將對最後兩個維度求和,依此類推。tanh_loc (bool, 可選) – 如果為
True,則使用上述公式進行位置縮放,否則保留原始值。預設為False;safe_tanh (bool, 可選) – 如果為
True,則 Tanh 變換會“安全地”進行,以避免數值溢位。這目前會與torch.compile()發生衝突。
- property mean¶
返回分佈的均值。
- property mode¶
返回分佈的眾數。
- property support¶
返回一個
Constraint物件,表示此分佈的支援域。