GaussianNLLLoss#
- class torch.nn.modules.loss.GaussianNLLLoss(*, full=False, eps=1e-06, reduction='mean')[原始碼]#
高斯負對數似然損失。
目標被視為來自高斯分佈的樣本,其期望和方差由神經網路預測。對於一個被建模為具有期望張量
input和正方差張量var的高斯分佈的目標張量target,損失為:其中
eps用於提高穩定性。預設情況下,損失函式的常數項會被省略,除非full為True。如果var的大小與input不同(由於同方差假設),則其最終維度必須為 1,或者維度少一個(其他所有尺寸相同)才能正確廣播。- 引數
- 形狀
輸入: 或 ,其中 表示任意數量的附加維度。
目標: 或 ,形狀與輸入相同,或形狀與輸入相同但有一個維度為 1(允許廣播)。
方差: 或 ,形狀與輸入相同,或形狀與輸入相同但有一個維度為 1,或形狀與輸入相同但維度少一個(允許廣播),或為一個標量值。
輸出:如果
reduction為'mean'(預設)或'sum',則輸出為標量。如果reduction為'none',則輸出形狀為 ,與輸入形狀相同。
示例
>>> loss = nn.GaussianNLLLoss() >>> input = torch.randn(5, 2, requires_grad=True) >>> target = torch.randn(5, 2) >>> var = torch.ones(5, 2, requires_grad=True) # heteroscedastic >>> output = loss(input, target, var) >>> output.backward()
>>> loss = nn.GaussianNLLLoss() >>> input = torch.randn(5, 2, requires_grad=True) >>> target = torch.randn(5, 2) >>> var = torch.ones(5, 1, requires_grad=True) # homoscedastic >>> output = loss(input, target, var) >>> output.backward()
注意
對
var的截斷對於自動微分(autograd)是忽略的,因此梯度不受其影響。- 參考
Nix, D. A. and Weigend, A. S., “Estimating the mean and variance of the target probability distribution”, Proceedings of 1994 IEEE International Conference on Neural Networks (ICNN’94), Orlando, FL, USA, 1994, pp. 55-60 vol.1, doi: 10.1109/ICNN.1994.374138。