評價此頁

KLDivLoss#

class torch.nn.modules.loss.KLDivLoss(size_average=None, reduce=None, reduction='mean', log_target=False)[原始碼]#

Kullback-Leibler 散度損失。

對於形狀相同的張量 ypred, ytruey_{\text{pred}},\ y_{\text{true}}, 其中 ypredy_{\text{pred}}input,而 ytruey_{\text{true}}target,我們定義**逐點 KL 散度**為

L(ypred, ytrue)=ytruelogytrueypred=ytrue(logytruelogypred)L(y_{\text{pred}},\ y_{\text{true}}) = y_{\text{true}} \cdot \log \frac{y_{\text{true}}}{y_{\text{pred}}} = y_{\text{true}} \cdot (\log y_{\text{true}} - \log y_{\text{pred}})ytrue(logytruelogypred)

為了避免在計算此數量時出現下溢問題,此損失函式期望 input 引數在對數空間中。如果 log_target= True,那麼 target 引數也可以以對數空間形式提供。

總結來說,此函式大致等同於計算

if not log_target:  # default
    loss_pointwise = target * (target.log() - input)
else:
    loss_pointwise = target.exp() * (target - input)

然後根據 reduction 引數對結果進行約簡,如下所示:

if reduction == "mean":  # default
    loss = loss_pointwise.mean()
elif reduction == "batchmean":  # mathematically correct
    loss = loss_pointwise.sum() / input.size(0)
elif reduction == "sum":
    loss = loss_pointwise.sum()
else:  # reduction == "none"
    loss = loss_pointwise

注意

與 PyTorch 中的所有其他損失函式一樣,此函式期望第一個引數 input 是模型的輸出(例如,神經網路),第二個引數 target 是資料集中的觀測值。這與標準的數學符號 KL(P  Q)KL(P\ ||\ Q) 不同,其中 PP 表示觀測值的分佈,而 QQ 表示模型。

警告

reduction= “mean” 不返回真實的 KL 散度值,請使用 reduction= “batchmean”,這與數學定義一致。

引數
  • size_average (bool, optional) – 已棄用(參見 reduction)。預設情況下,損失按批次中的每個損失元素進行平均。請注意,對於某些損失,每個樣本有多個元素。如果將 size_average 欄位設定為 False,則損失會按小批次進行求和。當 reduceFalse 時忽略。預設值:True

  • reduce (bool, optional) – 已棄用(參見 reduction)。預設情況下,損失會根據 size_average 的值,在每個小批次上按觀測值進行平均或求和。當 reduceFalse 時,則返回每個批次元素的損失,並忽略 size_average。預設值:True

  • reduction (str, optional) – 指定要應用於輸出的約簡方式。預設值:“mean”

  • log_target (bool, optional) – 指定 target 是否為對數空間。預設值:False

形狀
  • 輸入: ()(*),其中 * 表示任意數量的維度。

  • 目標:()(*),與輸入形狀相同。

  • 輸出:預設情況下為標量。如果 reduction‘none’,則 ()(*),形狀與輸入相同。

示例

>>> kl_loss = nn.KLDivLoss(reduction="batchmean")
>>> # input should be a distribution in the log space
>>> input = F.log_softmax(torch.randn(3, 5, requires_grad=True), dim=1)
>>> # Sample a batch of distributions. Usually this would come from the dataset
>>> target = F.softmax(torch.rand(3, 5), dim=1)
>>> output = kl_loss(input, target)
>>>
>>> kl_loss = nn.KLDivLoss(reduction="batchmean", log_target=True)
>>> log_target = F.log_softmax(torch.rand(3, 5), dim=1)
>>> output = kl_loss(input, log_target)
forward(input, target)[原始碼]#

執行前向傳播。

返回型別

張量