評價此頁

torch.nn.functional.ctc_loss#

torch.nn.functional.ctc_loss(log_probs, targets, input_lengths, target_lengths, blank=0, reduction='mean', zero_infinity=False)[原始碼]#

計算連線主義時序分類(Connectionist Temporal Classification)損失。

詳情請參閱 CTCLoss

注意

在某些情況下,當在 CUDA 裝置上使用張量並利用 CuDNN 時,此運算元可能會選擇一個非確定性演算法來提高效能。如果這不可取,你可以嘗試將操作設定為確定性的(可能以效能為代價),方法是設定 torch.backends.cudnn.deterministic = True。有關更多資訊,請參閱 可復現性

注意

此操作在使用 CUDA 裝置上的張量時可能會產生非確定性梯度。有關更多資訊,請參閱 可復現性

引數
  • log_probs (Tensor) – (T,N,C)(T, N, C)(T,C)(T, C),其中 C = 字母表中的字元數(包括空白字元)T = 輸入長度N = 批次大小。輸出的對數機率(例如,透過 torch.nn.functional.log_softmax() 獲得)。

  • targets (Tensor) – (N,S)(N, S)(sum(target_lengths))。如果 target_lengths 中的所有條目都為零,則可能為空張量。在第二種形式中,目標被假定為已連線。

  • input_lengths (Tensor) – (N)(N)()()。輸入的長度(每個都必須 T\leq T)

  • target_lengths (Tensor) – (N)(N)()()。目標的長度

  • blank (int, optional) – 空白標籤。預設為 00

  • reduction (str, optional) – 指定應用於輸出的規約:'none' | 'mean' | 'sum''none':不進行規約,'mean':輸出損失將除以目標長度,然後取批次上的均值,'sum':輸出將被求和。預設為 'mean'

  • zero_infinity (bool, optional) – 是否將無窮損失及其相關梯度歸零。預設為 False。無窮損失主要發生在輸入太短而無法與目標對齊時。

返回型別

張量

示例

>>> log_probs = torch.randn(50, 16, 20).log_softmax(2).detach().requires_grad_()
>>> targets = torch.randint(1, 20, (16, 30), dtype=torch.long)
>>> input_lengths = torch.full((16,), 50, dtype=torch.long)
>>> target_lengths = torch.randint(10, 30, (16,), dtype=torch.long)
>>> loss = F.ctc_loss(log_probs, targets, input_lengths, target_lengths)
>>> loss.backward()