評價此頁

MultiheadAttention#

class torch.nn.MultiheadAttention(embed_dim, num_heads, dropout=0.0, bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None, batch_first=False, device=None, dtype=None)[源]#

允許模型聯合關注來自不同表示子空間的資訊。

此 MultiheadAttention 層實現了 Attention Is All You Need 論文中描述的原始架構。該層旨在作為基礎理解的參考實現,因此其功能相對於較新的架構而言僅限於有限的功能。鑑於 Transformer 類架構的快速創新步伐,我們建議您探索此 教程,以從核心構建塊構建高效層,或使用 PyTorch 生態系統中的更高階庫。

多頭注意力 (Multi-Head Attention) 定義為:

MultiHead(Q,K,V)=Concat(head1,,headh)WO\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1,\dots,\text{head}_h)W^O

其中 headi=Attention(QWiQ,KWiK,VWiV)\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)

nn.MultiheadAttention 將在可能的情況下使用 scaled_dot_product_attention() 的最佳化實現。

除了支援新的 scaled_dot_product_attention() 函式外,為了加速推理,MHA 還將使用支援 Nested Tensors 的 fastpath 推理,前提是:

  • 計算自注意力(即 querykeyvalue 是同一個張量)。

  • 輸入是批處理的(3D),並且 batch_first==True

  • 自動梯度被停用(使用 torch.inference_modetorch.no_grad)或者沒有張量引數 requires_grad

  • 訓練被停用(使用 .eval()

  • add_bias_kvFalse

  • add_zero_attnFalse

  • kdimvdim 等於 embed_dim

  • 如果傳遞了 NestedTensor,則不傳遞 key_padding_maskattn_mask

  • autocast 已停用。

如果使用了最佳化的推理 fastpath 實現,則可以為 query/key/value 傳遞 NestedTensor,以比使用 padding mask 更有效地表示 padding。在這種情況下,將返回 NestedTensor,並且可以預期額外的加速與輸入中 padding 的比例成正比。

引數
  • embed_dim – 模型的總維度。

  • num_heads – 並行注意力頭的數量。請注意,embed_dim 將被分割到 num_heads 中(即每個頭的維度為 embed_dim // num_heads)。

  • dropoutattn_output_weights 上的 Dropout 機率。預設值:0.0(無 dropout)。

  • bias – 如果指定,則向輸入/輸出投影層新增偏置。預設值:True

  • add_bias_kv – 如果指定,則向 key 和 value 序列在 dim=0 處新增偏置。預設值:False

  • add_zero_attn – 如果指定,則向 key 和 value 序列在 dim=1 處新增新的零批次。預設值:False

  • kdim – key 的總特徵數。預設值:None(使用 kdim=embed_dim)。

  • vdim – value 的總特徵數。預設值:None(使用 vdim=embed_dim)。

  • batch_first – 如果為 True,則輸入和輸出張量為 (batch, seq, feature)。預設值:False (seq, batch, feature)。

示例

>>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
>>> attn_output, attn_output_weights = multihead_attn(query, key, value)
forward(query, key, value, key_padding_mask=None, need_weights=True, attn_mask=None, average_attn_weights=True, is_causal=False)[源]#

使用 query、key 和 value 嵌入計算注意力輸出。

支援用於 padding、mask 和注意力權重的可選引數。

引數
  • query (Tensor) – shape 為 (L,Eq)(L, E_q) 的未批處理輸入,(L,N,Eq)(L, N, E_q)batch_first=False 時,或 (N,L,Eq)(N, L, E_q)batch_first=True 時,其中 LL 是目標序列長度,NN 是批次大小,EqE_q 是查詢嵌入維度 embed_dim。查詢與鍵值對進行比較以生成輸出。更多細節請參閱“Attention Is All You Need”。

  • key (Tensor) – shape 為 (S,Ek)(S, E_k) 的未批處理輸入,(S,N,Ek)(S, N, E_k)batch_first=False 時,或 (N,S,Ek)(N, S, E_k)batch_first=True 時,其中 SS 是源序列長度,NN 是批次大小,EkE_k 是鍵嵌入維度 kdim。更多細節請參閱“Attention Is All You Need”。

  • value (Tensor) – shape 為 (S,Ev)(S, E_v) 的未批處理輸入,(S,N,Ev)(S, N, E_v)batch_first=False 時,或 (N,S,Ev)(N, S, E_v)batch_first=True 時,其中 SS 是源序列長度,NN 是批次大小,EvE_v 是值嵌入維度 vdim。更多細節請參閱“Attention Is All You Need”。

  • key_padding_mask (Optional[Tensor]) – 如果指定,則為 shape 為 (N,S)(N, S) 的掩碼,指示 key 中哪些元素在注意力計算中被忽略(即被視為“padding”)。對於未批處理的 query,shape 應為 (S)(S)。支援二進位制和浮點掩碼。對於二進位制掩碼,True 值表示相應的 key 值在注意力計算中將被忽略。對於浮點掩碼,它將直接加到相應的 key 值上。

  • need_weights (bool) – 如果指定,則除了 attn_outputs 之外,還返回 attn_output_weights。將 need_weights=False 以使用最佳化的 scaled_dot_product_attention 併為 MHA 獲得最佳效能。預設值:True

  • attn_mask (Optional[Tensor]) – 如果指定,則為 2D 或 3D 掩碼,可防止注意力關注某些位置。必須為 shape (L,S)(L, S)(Nnum_heads,L,S)(N\cdot\text{num\_heads}, L, S),其中 NN 是批次大小,LL 是目標序列長度,SS 是源序列長度。2D 掩碼將廣播到整個批次,而 3D 掩碼則允許每個批次條目都有不同的掩碼。支援二進位制和浮點掩碼。對於二進位制掩碼,True 值表示不允許關注相應的位置。對於浮點掩碼,掩碼值將加到注意力權重上。如果同時提供了 attn_mask 和 key_padding_mask,它們的型別應匹配。

  • average_attn_weights (bool) – 如果為 true,表示返回的 attn_weights 應在各頭之間取平均值。否則,attn_weights 按頭分開提供。請注意,此標誌僅在 need_weights=True 時生效。預設值:True(即平均各頭權重)。

  • is_causal (bool) – 如果指定,則將因果掩碼作為注意力掩碼應用。預設值:False。警告:is_causal 提供了一個提示,即 attn_mask 是因果掩碼。提供錯誤的提示可能導致執行錯誤,包括向前和向後相容性問題。

返回型別

tuple[torch.Tensor, Optional[torch.Tensor]]

輸出
  • attn_output - shape 為 (L,E)(L, E) 的注意力輸出,當輸入未批處理時;(L,N,E)(L, N, E)batch_first=False 時;或 (N,L,E)(N, L, E)batch_first=True 時,其中 LL 是目標序列長度,NN 是批次大小,EE 是嵌入維度 embed_dim

  • attn_output_weights - 僅當 need_weights=True 時返回。如果 average_attn_weights=True,則返回平均後的注意力權重,shape 為 (L,S)(L, S),當輸入未批處理時,或 (N,L,S)(N, L, S),當 batch_first=False 時,其中 NN 是批次大小,LL 是目標序列長度,SS 是源序列長度。如果 average_attn_weights=False,則返回各頭的注意力權重,shape 為 (num_heads,L,S)(\text{num\_heads}, L, S),當輸入未批處理時,或 (N,num_heads,L,S)(N, \text{num\_heads}, L, S),當 batch_first=False 時。

注意

batch_first 引數對於未批處理的輸入將被忽略。

merge_masks(attn_mask, key_padding_mask, query)[源]#

確定掩碼型別並根據需要合併掩碼。

如果只提供一個掩碼,則返回該掩碼和對應的掩碼型別。如果同時提供兩個掩碼,它們將被擴充套件到 shape (batch_size, num_heads, seq_len, seq_len),並透過邏輯 or 合併,並返回掩碼型別 2::param attn_mask: attention mask,shape 為 (seq_len, seq_len),掩碼型別 0 :param key_padding_mask: padding mask,shape 為 (batch_size, seq_len),掩碼型別 1 :param query: query embeddings,shape 為 (batch_size, seq_len, embed_dim)

返回

merged mask mask_type: 合併後的掩碼型別(0、1 或 2)。

返回型別

merged_mask