GaussianNLLLoss#
- class torch.nn.GaussianNLLLoss(*, full=False, eps=1e-06, reduction='mean')[source]#
高斯負對數似然損失。
目標被視為來自高斯分佈的樣本,其期望和方差由神經網路預測。對於一個
target張量,建模為具有期望為input張量和正方差var張量的高斯分佈,損失為其中
eps用於保持穩定性。預設情況下,損失函式中的常數項會被省略,除非full為True。如果var的尺寸與input不同(由於同方差假設),為了正確廣播,它必須具有一個尺寸為1的最後一個維度,或者比input少一個維度(所有其他尺寸均相同)。- 引數
- 形狀
輸入: 或 ,其中 表示任意數量的附加維度
目標: 或 ,與輸入形狀相同,或與輸入形狀相同但有一個維度等於1(允許廣播)
方差: 或 ,與輸入形狀相同,或與輸入形狀相同但有一個維度等於1,或與輸入形狀相同但少一個維度(允許廣播),或為標量值
輸出:如果
reduction為'mean'(預設)或'sum',則為標量。如果reduction為'none',則為,與輸入形狀相同
示例
>>> loss = nn.GaussianNLLLoss() >>> input = torch.randn(5, 2, requires_grad=True) >>> target = torch.randn(5, 2) >>> var = torch.ones(5, 2, requires_grad=True) # heteroscedastic >>> output = loss(input, target, var) >>> output.backward()
>>> loss = nn.GaussianNLLLoss() >>> input = torch.randn(5, 2, requires_grad=True) >>> target = torch.randn(5, 2) >>> var = torch.ones(5, 1, requires_grad=True) # homoscedastic >>> output = loss(input, target, var) >>> output.backward()
注意
對
var的裁剪會被自動梯度忽略,因此梯度不受其影響。- 參考
Nix, D. A. and Weigend, A. S., “Estimating the mean and variance of the target probability distribution”, Proceedings of 1994 IEEE International Conference on Neural Networks (ICNN’94), Orlando, FL, USA, 1994, pp. 55-60 vol.1, doi: 10.1109/ICNN.1994.374138。