BCEWithLogitsLoss#
- class torch.nn.modules.loss.BCEWithLogitsLoss(weight=None, size_average=None, reduce=None, reduction='mean', pos_weight=None)[source]#
此損失函式將 Sigmoid 層和 BCELoss 結合到一個類中。此版本比單獨使用 Sigmoid 再使用 BCELoss 更具數值穩定性,因為它透過將這兩個操作合併到一個層中,利用了 log-sum-exp 技巧來提高數值穩定性。
未約簡的(即
reduction設定為'none')損失可以描述為其中 是批次大小。如果
reduction不是'none'(預設為'mean'),則:這用於衡量例如自動編碼器中重構的誤差。請注意,目標 t[i] 應該是介於 0 和 1 之間的數字。
透過為正樣本新增權重,可以權衡召回率和精確率。在多標籤分類的情況下,損失可以描述為:
其中 是類索引(對於多標籤二分類,;對於單標籤二分類,), 是批次中樣本的編號, 是類 的正樣本權重。
增加召回率, 增加精確率。
例如,如果一個數據集包含一個類別的 100 個正樣本和 300 個負樣本,那麼該類的
pos_weight應等於 . 損失函式的作用將如同資料集包含 個正樣本一樣。示例
>>> target = torch.ones([10, 64], dtype=torch.float32) # 64 classes, batch size = 10 >>> output = torch.full([10, 64], 1.5) # A prediction (logit) >>> pos_weight = torch.ones([64]) # All weights are equal to 1 >>> criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight) >>> criterion(output, target) # -log(sigmoid(1.5)) tensor(0.20...)
在上述示例中,
pos_weight張量中的元素對應於多標籤二分類場景中的 64 個不同類別。pos_weight中的每個元素旨在根據相應類別的負樣本和正樣本之間的不平衡來調整損失函式。這種方法在類別不平衡程度不同的資料集中很有用,可確保損失計算準確地考慮每個類別的分佈。- 引數
weight (Tensor, optional) – 為每個批次元素的手動重縮放權重。如果提供,則必須是大小為 nbatch 的 Tensor。
size_average (bool, optional) – 已棄用 (參見
reduction)。預設情況下,損失值在批次中的每個損失元素上取平均值。請注意,對於某些損失,每個樣本有多個元素。如果欄位size_average設定為False,則損失值在每個小批次中而是求和。當reduce為False時忽略。預設值:Truereduce (bool, optional) – 已棄用 (參見
reduction)。預設情況下,損失值在每個小批次中根據size_average對觀測值進行平均或求和。當reduce為False時,返回每個批次元素的損失值,並忽略size_average。預設值:Truereduction (str, optional) – 指定應用於輸出的歸約方法:
'none'|'mean'|'sum'。'none':不應用歸約;'mean':輸出的總和將除以輸出中的元素數量;'sum':將對輸出進行求和。注意:size_average和reduce正在被棄用,在此期間,指定這兩個引數中的任何一個都將覆蓋reduction。 預設值:'mean'pos_weight (Tensor, optional) – 正樣本的權重,將與目標張量廣播。必須是一個在類別維度上大小等於類別數量的張量。請密切注意 PyTorch 的廣播語義,以實現期望的操作。對於大小為 [B, C, H, W] 的目標(其中 B 是批次大小),大小為 [B, C, H, W] 的 pos_weight 將為批次中的每個元素應用不同的 pos_weight,或者大小為 [C, H, W] 的 pos_weight 將為批次中的所有元素應用相同的 pos_weight。要為 2D 多類目標 [C, H, W] 沿所有空間維度應用相同的正權重,請使用:[C, 1, 1]。預設值:
None
- 形狀
輸入: ,其中 表示任意數量的維度。
目標:,與輸入形狀相同。
輸出:標量。如果
reduction為'none',則 ,與輸入形狀相同。
示例
>>> loss = nn.BCEWithLogitsLoss() >>> input = torch.randn(3, requires_grad=True) >>> target = torch.empty(3).random_(2) >>> output = loss(input, target) >>> output.backward()