評價此頁

GaussianNLLLoss#

class torch.nn.modules.loss.GaussianNLLLoss(*, full=False, eps=1e-06, reduction='mean')[原始碼]#

高斯負對數似然損失。

目標被視為來自高斯分佈的樣本,其期望和方差由神經網路預測。對於一個被建模為具有期望張量 input 和正方差張量 var 的高斯分佈的目標張量 target,損失為:

loss=12(log(max(var, eps))+(inputtarget)2max(var, eps))+const.\text{loss} = \frac{1}{2}\left(\log\left(\text{max}\left(\text{var}, \ \text{eps}\right)\right) + \frac{\left(\text{input} - \text{target}\right)^2} {\text{max}\left(\text{var}, \ \text{eps}\right)}\right) + \text{const.}

其中 eps 用於提高穩定性。預設情況下,損失函式的常數項會被省略,除非 fullTrue。如果 var 的大小與 input 不同(由於同方差假設),則其最終維度必須為 1,或者維度少一個(其他所有尺寸相同)才能正確廣播。

引數
  • full (bool, optional) – 在損失計算中包含常數項。預設為 False

  • eps (float, optional) – 用於截斷 var 的值(見下文註釋),以提高穩定性。預設為 1e-6。

  • reduction (str, optional) – 指定應用於輸出的歸約方式:'none' | 'mean' | 'sum''none':不進行歸約;'mean':輸出為所有批次成員損失的平均值;'sum':輸出為所有批次成員損失的總和。預設為 'mean'

形狀
  • 輸入:(N,)(N, *)()(*),其中 * 表示任意數量的附加維度。

  • 目標:(N,)(N, *)()(*),形狀與輸入相同,或形狀與輸入相同但有一個維度為 1(允許廣播)。

  • 方差:(N,)(N, *)()(*),形狀與輸入相同,或形狀與輸入相同但有一個維度為 1,或形狀與輸入相同但維度少一個(允許廣播),或為一個標量值。

  • 輸出:如果 reduction'mean'(預設)或 'sum',則輸出為標量。如果 reduction'none',則輸出形狀為 (N,)(N, *),與輸入形狀相同。

示例

>>> loss = nn.GaussianNLLLoss()
>>> input = torch.randn(5, 2, requires_grad=True)
>>> target = torch.randn(5, 2)
>>> var = torch.ones(5, 2, requires_grad=True)  # heteroscedastic
>>> output = loss(input, target, var)
>>> output.backward()
>>> loss = nn.GaussianNLLLoss()
>>> input = torch.randn(5, 2, requires_grad=True)
>>> target = torch.randn(5, 2)
>>> var = torch.ones(5, 1, requires_grad=True)  # homoscedastic
>>> output = loss(input, target, var)
>>> output.backward()

注意

var 的截斷對於自動微分(autograd)是忽略的,因此梯度不受其影響。

參考

Nix, D. A. and Weigend, A. S., “Estimating the mean and variance of the target probability distribution”, Proceedings of 1994 IEEE International Conference on Neural Networks (ICNN’94), Orlando, FL, USA, 1994, pp. 55-60 vol.1, doi: 10.1109/ICNN.1994.374138。

forward(input, target, var)[原始碼]#

執行前向傳播。

返回型別

張量