評價此頁

torch.nn.utils.rnn.pack_padded_sequence#

torch.nn.utils.rnn.pack_padded_sequence(input, lengths, batch_first=False, enforce_sorted=True)[原始碼]#

將包含可變長度填充序列的 Tensor 打包。

input 的形狀可以是 T x B x * (如果 batch_firstFalse) 或 B x T x * (如果 batch_firstTrue),其中 T 是最長序列的長度,B 是批次大小,而 * 是任意數量的維度(包括 0)。

對於未排序的序列,請使用 enforce_sorted = False。如果 enforce_sortedTrue,序列應該按長度遞減排序,即 input[:,0] 應該是最長序列,input[:,B-1] 應該是最短序列。enforce_sorted = True 僅對 ONNX 匯出是必需的。

它是 pad_packed_sequence() 的逆向操作,因此可以使用 pad_packed_sequence() 來恢復 PackedSequence 中打包的底層張量。

注意

此函式接受任何至少有兩個維度的輸入。您可以將其應用於打包標籤,並使用 RNN 的輸出來直接計算損失。可以透過訪問 PackedSequence 物件的 .data 屬性來檢索張量。

引數
  • input (Tensor) – 可變長度序列的填充批次。

  • lengths (Tensorlist(int)) – 每個批次元素的序列長度列表(如果以 tensor 形式提供,則必須在 CPU 上)。

  • batch_first (bool, optional) – 如果為 True,則輸入格式為 B x T x *;否則為 T x B x *。預設為 False

  • enforce_sorted (bool, optional) – 如果為 True,則輸入應包含按長度遞減排序的序列。如果為 False,則輸入將無條件排序。預設為 True

返回型別

PackedSequence

警告

如果 input 張量的維度大於 length 中對應值,則該維度將被截斷。

返回

一個 PackedSequence 物件

返回型別

PackedSequence