評價此頁

TransformerEncoder#

class torch.nn.modules.transformer.TransformerEncoder(encoder_layer, num_layers, norm=None, enable_nested_tensor=True, mask_check=True)[原始碼]#

TransformerEncoder 是 N 個編碼器層的堆疊。

此 TransformerEncoder 層實現了 Attention Is All You Need 論文中描述的原始架構。此層的目的是作為基礎理解的參考實現,因此相比於較新的 Transformer 架構,它只包含有限的功能。鑑於 Transformer 類架構的快速創新步伐,我們建議探索此 教程,使用核心模組中的構建塊來構建高效的層,或使用 PyTorch 生態系統 中的更高階庫。

警告

TransformerEncoder 中的所有層都使用相同的引數進行初始化。建議在建立 TransformerEncoder 例項後手動初始化層。

引數
  • encoder_layer (TransformerEncoderLayer) – TransformerEncoderLayer() 類的例項(必需)。

  • num_layers (int) – 編碼器中子編碼器層的數量(必需)。

  • norm (Optional[Module]) – 層歸一化元件(可選)。

  • enable_nested_tensor (bool) – 如果為 True,輸入將自動轉換為巢狀張量(並在輸出時轉換回來)。當填充率很高時,這將提高 TransformerEncoder 的整體效能。預設為 True(啟用)。

示例

>>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
>>> transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)
>>> src = torch.rand(10, 32, 512)
>>> out = transformer_encoder(src)
forward(src, mask=None, src_key_padding_mask=None, is_causal=None)[原始碼]#

依次將輸入透過編碼器層。

引數
  • src (Tensor) – 到編碼器的序列(必需)。

  • mask (Optional[Tensor]) – src 序列的掩碼(可選)。

  • src_key_padding_mask (Optional[Tensor]) – src 鍵的每批掩碼(可選)。

  • is_causal (Optional[bool]) – 如果指定,將 mask 應用為因果掩碼。預設為 None;嘗試檢測因果掩碼。警告:is_causal 提供了一個提示,表明 mask 是因果掩碼。提供錯誤的提示可能導致執行錯誤,包括向前和向後相容性。

返回型別

張量

形狀

請參閱Transformer中的文件。