torch.nn.utils.rnn.pad_packed_sequence#
- torch.nn.utils.rnn.pad_packed_sequence(sequence, batch_first=False, padding_value=0.0, total_length=None)[原始碼]#
填充一個已打包的可變長度序列批次。
它是
pack_padded_sequence()的逆操作。返回的 Tensor 的資料大小將是
T x B x *(如果batch_first為False) 或B x T x *(如果batch_first為True) ,其中T是最長序列的長度,B是批次大小。示例
>>> from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence >>> seq = torch.tensor([[1, 2, 0], [3, 0, 0], [4, 5, 6]]) >>> lens = [2, 1, 3] >>> packed = pack_padded_sequence( ... seq, lens, batch_first=True, enforce_sorted=False ... ) >>> packed PackedSequence(data=tensor([4, 1, 3, 5, 2, 6]), batch_sizes=tensor([3, 2, 1]), sorted_indices=tensor([2, 0, 1]), unsorted_indices=tensor([1, 2, 0])) >>> seq_unpacked, lens_unpacked = pad_packed_sequence(packed, batch_first=True) >>> seq_unpacked tensor([[1, 2, 0], [3, 0, 0], [4, 5, 6]]) >>> lens_unpacked tensor([2, 1, 3])
注意
total_length對於在Module中包裝的DataParallel實現pack sequence -> recurrent network -> unpack sequence模式非常有用。詳情請參見 本 FAQ 部分。- 引數
sequence (PackedSequence) – 要填充的批次
batch_first (bool, optional) – 如果為
True,輸出格式為B x T x *,否則為T x B x *。padding_value (float, optional) – 填充元素的取值。
total_length (int, optional) – 如果不為
None,則輸出將填充到長度為total_length。如果total_length小於sequence中的最大序列長度,此方法將丟擲ValueError。
- 返回
包含填充序列的 Tensor 和包含批次中每個序列長度的 Tensor 的元組。批次元素將按原始傳遞給
pack_padded_sequence或pack_sequence時的順序重新排序。- 返回型別