torch.nn.functional.ctc_loss#
- torch.nn.functional.ctc_loss(log_probs, targets, input_lengths, target_lengths, blank=0, reduction='mean', zero_infinity=False)[原始碼]#
計算連線主義時序分類(Connectionist Temporal Classification)損失。
詳情請參閱
CTCLoss。注意
在某些情況下,當在 CUDA 裝置上使用張量並利用 CuDNN 時,此運算元可能會選擇一個非確定性演算法來提高效能。如果這不可取,你可以嘗試將操作設定為確定性的(可能以效能為代價),方法是設定
torch.backends.cudnn.deterministic = True。有關更多資訊,請參閱 可復現性。注意
此操作在使用 CUDA 裝置上的張量時可能會產生非確定性梯度。有關更多資訊,請參閱 可復現性。
- 引數
log_probs (Tensor) – 或 ,其中 C = 字母表中的字元數(包括空白字元),T = 輸入長度,N = 批次大小。輸出的對數機率(例如,透過
torch.nn.functional.log_softmax()獲得)。targets (Tensor) – 或 (sum(target_lengths))。如果 target_lengths 中的所有條目都為零,則可能為空張量。在第二種形式中,目標被假定為已連線。
input_lengths (Tensor) – 或 。輸入的長度(每個都必須 )
target_lengths (Tensor) – 或 。目標的長度
blank (int, optional) – 空白標籤。預設為 。
reduction (str, optional) – 指定應用於輸出的規約:
'none'|'mean'|'sum'。'none':不進行規約,'mean':輸出損失將除以目標長度,然後取批次上的均值,'sum':輸出將被求和。預設為'mean'。zero_infinity (bool, optional) – 是否將無窮損失及其相關梯度歸零。預設為
False。無窮損失主要發生在輸入太短而無法與目標對齊時。
- 返回型別
示例
>>> log_probs = torch.randn(50, 16, 20).log_softmax(2).detach().requires_grad_() >>> targets = torch.randint(1, 20, (16, 30), dtype=torch.long) >>> input_lengths = torch.full((16,), 50, dtype=torch.long) >>> target_lengths = torch.randint(10, 30, (16,), dtype=torch.long) >>> loss = F.ctc_loss(log_probs, targets, input_lengths, target_lengths) >>> loss.backward()