評價此頁

CTCLoss#

class torch.nn.modules.loss.CTCLoss(blank=0, reduction='mean', zero_infinity=False)[原始碼]#

連線主義時間分類損失。

計算連續(未分割)時間序列與目標序列之間的損失。CTCLoss 對輸入到目標的可能對齊的機率進行求和,生成一個相對於每個輸入節點可微的損失值。輸入到目標的對齊被假定為“多對一”,這限制了目標序列的長度,使其必須 \leq 輸入長度。

引數
  • blank (int, optional) – 空標籤。預設為 00

  • reduction (str, optional) – 指定要應用於輸出的縮減:'none' | 'mean' | 'sum''none':不應用縮減;'mean':輸出的損失將除以目標長度,然後對批次取平均值;'sum':輸出的損失將進行求和。預設為 'mean'

  • zero_infinity (bool, optional) – 是否將無窮損失及其相關梯度歸零。預設為 False。無窮損失主要發生在輸入太短而無法與目標對齊時。

形狀
  • Log_probs:大小為 (T,N,C)(T, N, C)(T,C)(T, C) 的張量,其中 T=輸入長度T = \text{input length}N=批次大小N = \text{batch size},以及 C=類別數(包括空標籤)C = \text{number of classes (including blank)}。輸出的對數機率(例如,透過 torch.nn.functional.log_softmax() 獲得)。

  • Targets:大小為 (N,S)(N, S)(sum(target_lengths))(\operatorname{sum}(\text{target\_lengths})) 的張量,其中 N=批次大小N = \text{batch size}S=最大目標長度,如果形狀為(N,S)S = \text{max target length, if shape is } (N, S)。它表示目標序列。目標序列中的每個元素都是一個類別索引。目標索引不能是空標籤(預設值為 0)。在 (N,S)(N, S) 形式中,目標被填充到最長序列的長度,然後堆疊。在 (sum(target_lengths))(\operatorname{sum}(\text{target\_lengths})) 形式中,目標被假定為未填充且在 1 維內連線。

  • Input_lengths:大小為 (N)(N)()() 的元組或張量,其中 N=批次大小N = \text{batch size}。它表示輸入的長度(每個長度必須 T\leq T)。並且長度是為每個序列指定的,以實現掩碼,假定序列被填充到相同的長度。

  • Target_lengths:大小為 (N)(N)()() 的元組或張量,其中 N=批次大小N = \text{batch size}。它表示目標的長度。長度是為每個序列指定的,以實現掩碼,假定序列被填充到相同的長度。如果目標形狀為 (N,S)(N,S),則 target_lengths 實際上是每個目標序列的停止索引 sns_n,使得對於批次中的每個目標 target_n = targets[n,0:s_n]。長度必須每個都 S\leq S。如果目標以 1d 張量的形式給出,它是各個目標的連線,則 target_lengths 必須加起來等於張量的總長度。

  • Output:如果 reduction'mean'(預設)或 'sum',則為標量。如果 reduction'none',則為大小為 (N)(N)(如果輸入是批次的)或 ()()(如果輸入不是批次的)的張量,其中 N=批次大小N = \text{batch size}

示例

>>> # Target are to be padded
>>> T = 50  # Input sequence length
>>> C = 20  # Number of classes (including blank)
>>> N = 16  # Batch size
>>> S = 30  # Target sequence length of longest target in batch (padding length)
>>> S_min = 10  # Minimum target length, for demonstration purposes
>>>
>>> # Initialize random batch of input vectors, for *size = (T,N,C)
>>> input = torch.randn(T, N, C).log_softmax(2).detach().requires_grad_()
>>>
>>> # Initialize random batch of targets (0 = blank, 1:C = classes)
>>> target = torch.randint(low=1, high=C, size=(N, S), dtype=torch.long)
>>>
>>> input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.long)
>>> target_lengths = torch.randint(
...     low=S_min,
...     high=S,
...     size=(N,),
...     dtype=torch.long,
... )
>>> ctc_loss = nn.CTCLoss()
>>> loss = ctc_loss(input, target, input_lengths, target_lengths)
>>> loss.backward()
>>>
>>>
>>> # Target are to be un-padded
>>> T = 50  # Input sequence length
>>> C = 20  # Number of classes (including blank)
>>> N = 16  # Batch size
>>>
>>> # Initialize random batch of input vectors, for *size = (T,N,C)
>>> input = torch.randn(T, N, C).log_softmax(2).detach().requires_grad_()
>>> input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.long)
>>>
>>> # Initialize random batch of targets (0 = blank, 1:C = classes)
>>> target_lengths = torch.randint(low=1, high=T, size=(N,), dtype=torch.long)
>>> target = torch.randint(
...     low=1,
...     high=C,
...     size=(sum(target_lengths),),
...     dtype=torch.long,
... )
>>> ctc_loss = nn.CTCLoss()
>>> loss = ctc_loss(input, target, input_lengths, target_lengths)
>>> loss.backward()
>>>
>>>
>>> # Target are to be un-padded and unbatched (effectively N=1)
>>> T = 50  # Input sequence length
>>> C = 20  # Number of classes (including blank)
>>>
>>> # Initialize random batch of input vectors, for *size = (T,C)
>>> input = torch.randn(T, C).log_softmax(1).detach().requires_grad_()
>>> input_lengths = torch.tensor(T, dtype=torch.long)
>>>
>>> # Initialize random batch of targets (0 = blank, 1:C = classes)
>>> target_lengths = torch.randint(low=1, high=T, size=(), dtype=torch.long)
>>> target = torch.randint(
...     low=1,
...     high=C,
...     size=(target_lengths,),
...     dtype=torch.long,
... )
>>> ctc_loss = nn.CTCLoss()
>>> loss = ctc_loss(input, target, input_lengths, target_lengths)
>>> loss.backward()
參考

A. Graves 等人:Connectionist Temporal Classification: Labelling Unsegmented Sequence Data with Recurrent Neural Networks:https://www.cs.toronto.edu/~graves/icml_2006.pdf

注意

要使用 CuDNN,必須滿足以下條件:targets 必須是連線格式,所有 input_lengths 必須為 Tblank=0blank=0target_lengths 256\leq 256,整數引數必須是 torch.int32 型別。

常規實現使用(在 PyTorch 中更常用)的 torch.long 資料型別。

注意

在某些情況下,當使用帶有 CuDNN 的 CUDA 後端時,此運算子可能會選擇非確定性演算法以提高效能。如果您不希望這樣做,可以嘗試透過設定 torch.backends.cudnn.deterministic = True 來使操作確定化(可能會犧牲效能)。有關背景資訊,請參閱關於 可復現性 的說明。

forward(log_probs, targets, input_lengths, target_lengths)[原始碼]#

執行前向傳播。

返回型別

張量