評價此頁

GaussianNLLLoss#

class torch.nn.GaussianNLLLoss(*, full=False, eps=1e-06, reduction='mean')[source]#

高斯負對數似然損失。

目標被視為來自高斯分佈的樣本,其期望和方差由神經網路預測。對於一個target張量,建模為具有期望為input張量和正方差var張量的高斯分佈,損失為

loss=12(log(max(var, eps))+(inputtarget)2max(var, eps))+const.\text{loss} = \frac{1}{2}\left(\log\left(\text{max}\left(\text{var}, \ \text{eps}\right)\right) + \frac{\left(\text{input} - \text{target}\right)^2} {\text{max}\left(\text{var}, \ \text{eps}\right)}\right) + \text{const.}

其中eps用於保持穩定性。預設情況下,損失函式中的常數項會被省略,除非fullTrue。如果var的尺寸與input不同(由於同方差假設),為了正確廣播,它必須具有一個尺寸為1的最後一個維度,或者比input少一個維度(所有其他尺寸均相同)。

引數
  • full (bool, optional) – 在損失計算中包含常數項。預設為False

  • eps (float, optional) – 用於裁剪var的值(見下文說明),以保持穩定性。預設為1e-6。

  • reduction (str, optional) – 指定應用於輸出的縮減方式:'none' | 'mean' | 'sum''none':不進行縮減;'mean':輸出為所有批次成員損失的平均值;'sum':輸出為所有批次成員損失的總和。預設為'mean'

形狀
  • 輸入:(N,)(N, *)()(*),其中* 表示任意數量的附加維度

  • 目標:(N,)(N, *)()(*),與輸入形狀相同,或與輸入形狀相同但有一個維度等於1(允許廣播)

  • 方差:(N,)(N, *)()(*),與輸入形狀相同,或與輸入形狀相同但有一個維度等於1,或與輸入形狀相同但少一個維度(允許廣播),或為標量值

  • 輸出:如果reduction'mean'(預設)或'sum',則為標量。如果reduction'none',則為(N,)(N, *),與輸入形狀相同

示例

>>> loss = nn.GaussianNLLLoss()
>>> input = torch.randn(5, 2, requires_grad=True)
>>> target = torch.randn(5, 2)
>>> var = torch.ones(5, 2, requires_grad=True)  # heteroscedastic
>>> output = loss(input, target, var)
>>> output.backward()
>>> loss = nn.GaussianNLLLoss()
>>> input = torch.randn(5, 2, requires_grad=True)
>>> target = torch.randn(5, 2)
>>> var = torch.ones(5, 1, requires_grad=True)  # homoscedastic
>>> output = loss(input, target, var)
>>> output.backward()

注意

var的裁剪會被自動梯度忽略,因此梯度不受其影響。

參考

Nix, D. A. and Weigend, A. S., “Estimating the mean and variance of the target probability distribution”, Proceedings of 1994 IEEE International Conference on Neural Networks (ICNN’94), Orlando, FL, USA, 1994, pp. 55-60 vol.1, doi: 10.1109/ICNN.1994.374138。

forward(input, target, var)[source]#

執行前向傳播。

返回型別

張量