CTCLoss#
- class torch.nn.CTCLoss(blank=0, reduction='mean', zero_infinity=False)[原始碼]#
連線主義時間分類損失。
計算連續(非分段)時間序列與目標序列之間的損失。CTCLoss 對輸入到目標的可能對齊的機率進行求和,生成一個相對於每個輸入節點可微分的損失值。輸入到目標的對齊被假定為“多對一”,這限制了目標序列的長度,使其必須 輸入長度。
- 引數
- 形狀
Log_probs:大小為 或 的張量,其中 ,,以及 。輸出的對數機率(例如,透過
torch.nn.functional.log_softmax()獲得)。Targets:大小為 或 的張量,其中 ,且 。它表示目標序列。目標序列中的每個元素都是一個類別索引。並且目標索引不能是空白(預設值為 0)。在 形式中,目標會被填充到最長序列的長度,然後堆疊。在 形式中,目標被假定為未填充的,並在 1 維內連線。
Input_lengths:大小為 或 的元組或張量,其中 。它表示輸入的長度(每個必須)。長度是為每個序列指定的,以便在序列被填充到相等長度的假設下實現掩碼。
Target_lengths:大小為 或 的元組或張量,其中 。它表示目標的長度。長度是為每個序列指定的,以便在序列被填充到相等長度的假設下實現掩碼。如果目標形狀為 ,則 target_lengths 實際上是每個目標序列的停止索引 ,使得
target_n = targets[n,0:s_n]對批次中的每個目標都成立。長度必須分別。如果目標以 1D 張量形式給出,該張量是各個目標的串聯,則 target_lengths 必須加起來等於張量的總長度。Output:如果
reduction是'mean'(預設)或'sum',則為標量。如果reduction是'none',則為 (如果輸入是批處理的)或 (如果輸入不是批處理的),其中 。
示例
>>> # Target are to be padded >>> T = 50 # Input sequence length >>> C = 20 # Number of classes (including blank) >>> N = 16 # Batch size >>> S = 30 # Target sequence length of longest target in batch (padding length) >>> S_min = 10 # Minimum target length, for demonstration purposes >>> >>> # Initialize random batch of input vectors, for *size = (T,N,C) >>> input = torch.randn(T, N, C).log_softmax(2).detach().requires_grad_() >>> >>> # Initialize random batch of targets (0 = blank, 1:C = classes) >>> target = torch.randint(low=1, high=C, size=(N, S), dtype=torch.long) >>> >>> input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.long) >>> target_lengths = torch.randint( ... low=S_min, ... high=S, ... size=(N,), ... dtype=torch.long, ... ) >>> ctc_loss = nn.CTCLoss() >>> loss = ctc_loss(input, target, input_lengths, target_lengths) >>> loss.backward() >>> >>> >>> # Target are to be un-padded >>> T = 50 # Input sequence length >>> C = 20 # Number of classes (including blank) >>> N = 16 # Batch size >>> >>> # Initialize random batch of input vectors, for *size = (T,N,C) >>> input = torch.randn(T, N, C).log_softmax(2).detach().requires_grad_() >>> input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.long) >>> >>> # Initialize random batch of targets (0 = blank, 1:C = classes) >>> target_lengths = torch.randint(low=1, high=T, size=(N,), dtype=torch.long) >>> target = torch.randint( ... low=1, ... high=C, ... size=(sum(target_lengths),), ... dtype=torch.long, ... ) >>> ctc_loss = nn.CTCLoss() >>> loss = ctc_loss(input, target, input_lengths, target_lengths) >>> loss.backward() >>> >>> >>> # Target are to be un-padded and unbatched (effectively N=1) >>> T = 50 # Input sequence length >>> C = 20 # Number of classes (including blank) >>> >>> # Initialize random batch of input vectors, for *size = (T,C) >>> input = torch.randn(T, C).log_softmax(1).detach().requires_grad_() >>> input_lengths = torch.tensor(T, dtype=torch.long) >>> >>> # Initialize random batch of targets (0 = blank, 1:C = classes) >>> target_lengths = torch.randint(low=1, high=T, size=(), dtype=torch.long) >>> target = torch.randint( ... low=1, ... high=C, ... size=(target_lengths,), ... dtype=torch.long, ... ) >>> ctc_loss = nn.CTCLoss() >>> loss = ctc_loss(input, target, input_lengths, target_lengths) >>> loss.backward()
- 參考
A. Graves et al.: Connectionist Temporal Classification: Labelling Unsegmented Sequence Data with Recurrent Neural Networks: https://www.cs.toronto.edu/~graves/icml_2006.pdf
注意
為了使用 CuDNN,必須滿足以下條件:
targets必須是串聯格式,所有input_lengths必須是 T。,target_lengths,整數引數必須是torch.int32型別。常規實現使用的是(在 PyTorch 中更常見的)torch.long 型別。
注意
在某些情況下,當使用帶有 CuDNN 的 CUDA 後端時,此運算子可能會選擇非確定性演算法以提高效能。如果您不希望這樣做,可以嘗試透過設定
torch.backends.cudnn.deterministic = True來使操作確定化(可能會犧牲效能)。有關背景資訊,請參閱關於 可復現性 的說明。