RMSNorm#
- class torch.nn.RMSNorm(normalized_shape, eps=None, elementwise_affine=True, device=None, dtype=None)[source]#
對輸入的小批次應用均方根層歸一化。
此層實現了論文 Root Mean Square Layer Normalization 中描述的操作。
RMS 是在最後
D個維度上計算的,其中D是normalized_shape的維度。例如,如果normalized_shape是(3, 5)(一個二維形狀),則 RMS 是在輸入的最後 2 個維度上計算的。- 引數
normalized_shape (int 或 list 或 torch.Size) –
input shape from an expected input of size
If a single integer is used, it is treated as a singleton list, and this module will normalize over the last dimension which is expected to be of that specific size.
eps (Optional[float]) – 用於數值穩定性的分母的加法值。預設值:
torch.finfo(x.dtype).epselementwise_affine (bool) – 一個布林值,當設定為
True時,此模組具有可學習的逐元素仿射引數,並初始化為 1(用於權重)。預設值:True。
- 形狀
輸入:
輸出:(與輸入形狀相同)
示例
>>> rms_norm = nn.RMSNorm([2, 3]) >>> input = torch.randn(2, 2, 3) >>> rms_norm(input)