評價此頁

TransformerDecoder#

class torch.nn.TransformerDecoder(decoder_layer, num_layers, norm=None)[原始碼]#

TransformerDecoder 是 N 個解碼器層的堆疊。

此 TransformerDecoder 層實現了 Attention Is All You Need 論文中描述的原始架構。本層的目的是作為基礎理解的參考實現,因此與較新的 Transformer 架構相比,它只包含有限的功能。鑑於類 Transformer 架構的快速創新,我們建議您探索此 教程,以從核心構建塊構建高效層,或者使用 PyTorch 生態系統 的更高級別庫。

警告

TransformerDecoder 中的所有層都使用相同的引數進行初始化。建議在建立 TransformerDecoder 例項後手動初始化各層。

引數
  • decoder_layer (TransformerDecoderLayer) – TransformerDecoderLayer() 類的例項(必需)。

  • num_layers (int) – 解碼器中子解碼器層的數量(必需)。

  • norm (Optional[Module]) – 層歸一化元件(可選)。

示例

>>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
>>> transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)
>>> memory = torch.rand(10, 32, 512)
>>> tgt = torch.rand(20, 32, 512)
>>> out = transformer_decoder(tgt, memory)
forward(tgt, memory, tgt_mask=None, memory_mask=None, tgt_key_padding_mask=None, memory_key_padding_mask=None, tgt_is_causal=None, memory_is_causal=False)[原始碼]#

依次透過解碼器層傳遞輸入(和掩碼)。

引數
  • tgt (Tensor) – 到解碼器的序列(必需)。

  • memory (Tensor) – 來自編碼器最後一層的序列(必需)。

  • tgt_mask (Optional[Tensor]) – tgt 序列的掩碼(可選)。

  • memory_mask (Optional[Tensor]) – memory 序列的掩碼(可選)。

  • tgt_key_padding_mask (Optional[Tensor]) – 每個批次的 tgt 鍵的掩碼(可選)。

  • memory_key_padding_mask (Optional[Tensor]) – 每個批次的 memory 鍵的掩碼(可選)。

  • tgt_is_causal (Optional[bool]) – 如果指定,則將因果掩碼應用於 tgt mask。預設為 None;嘗試檢測因果掩碼。警告:tgt_is_causal 提供了一個提示,表明 tgt_mask 是因果掩碼。提供不正確的提示可能導致執行不正確,包括前向和後向相容性。

  • memory_is_causal (bool) – 如果指定,則將因果掩碼應用為 memory mask。預設值:False。警告:memory_is_causal 提供了一個提示,即 memory_mask 是因果掩碼。提供錯誤的提示可能導致執行不正確,包括向前和向後相容性。

返回型別

張量

形狀

請參閱Transformer中的文件。