評價此頁

MultiheadAttention#

class torch.nn.modules.activation.MultiheadAttention(embed_dim, num_heads, dropout=0.0, bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None, batch_first=False, device=None, dtype=None)[原始碼]#

允許模型聯合關注來自不同表示子空間的資訊。

此 MultiheadAttention 層實現了 Attention Is All You Need 論文中描述的原始架構。此層的目的是作為基礎理解的參考實現,因此它僅包含相對於較新架構的有限功能。鑑於 Transformer 類架構的快速創新步伐,我們建議探索此 教程,從核心構建塊構建高效層,或使用 PyTorch 生態系統 中的更高階庫。

多頭注意力定義為

MultiHead(Q,K,V)=Concat(head1,,headh)WO\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1,\dots,\text{head}_h)W^O

其中 headi=Attention(QWiQ,KWiK,VWiV)\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)

nn.MultiheadAttention 在可能的情況下將使用 scaled_dot_product_attention() 的最佳化實現。

除了支援新的 scaled_dot_product_attention() 函式外,為了加速推理,MHA 將使用快速路徑推理,並支援 Nested Tensors,前提是:

  • 正在計算自注意力(即 querykeyvalue 是同一個張量)。

  • 輸入是批處理的(3D),並且 batch_first==True

  • 自動梯度被停用(使用 torch.inference_modetorch.no_grad)或者沒有張量引數 requires_grad

  • 訓練被停用(使用 .eval()

  • add_bias_kvFalse

  • add_zero_attnFalse

  • kdimvdim 等於 embed_dim

  • 如果傳遞了 NestedTensor,則不傳遞 key_padding_maskattn_mask

  • autocast 被停用。

如果正在使用最佳化的推理快速路徑實現,則可以將 NestedTensor 傳遞給 query/key/value,以比使用 padding mask 更有效地表示 padding。在這種情況下,將返回一個 NestedTensor,並且可以預期相對於輸入中 padding 的比例有額外的加速。

引數
  • embed_dim – 模型的總維度。

  • num_heads – 並行注意力頭的數量。請注意,embed_dim 將被分割給 num_heads(即每個頭的維度將是 embed_dim // num_heads)。

  • dropoutattn_output_weights 上的 Dropout 機率。預設為 0.0(無 dropout)。

  • bias – 如果指定,則為輸入/輸出投影層新增偏置。預設為 True

  • add_bias_kv – 如果指定,則在 dim=0 處為 key 和 value 序列新增偏置。預設為 False

  • add_zero_attn – 如果指定,則在 dim=1 處為 key 和 value 序列新增新的零批次。預設為 False

  • kdim – key 的總特徵數。預設為 None(使用 kdim=embed_dim)。

  • vdim – value 的總特徵數。預設為 None(使用 vdim=embed_dim)。

  • batch_first – 如果為 True,則輸入和輸出張量為 (batch, seq, feature)。預設為 False(seq, batch, feature)。

示例

>>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
>>> attn_output, attn_output_weights = multihead_attn(query, key, value)
forward(query, key, value, key_padding_mask=None, need_weights=True, attn_mask=None, average_attn_weights=True, is_causal=False)[原始碼]#

使用 query、key 和 value 嵌入計算注意力輸出。

支援 padding、mask 和注意力權重的可選引數。

引數
  • query (Tensor) – shape 為 (L,Eq)(L, E_q) 的未批處理輸入,當 batch_first=False 時 shape 為 (L,N,Eq)(L, N, E_q) 或當 batch_first=True 時 shape 為 (N,L,Eq)(N, L, E_q),其中 LL 是目標序列長度,NN 是批次大小,EqE_q 是 query 嵌入維度 embed_dim。Query 與 key-value 對進行比較以生成輸出。更多細節請參見“Attention Is All You Need”。

  • key (Tensor) – shape 為 (S,Ek)(S, E_k) 的未批處理輸入,當 batch_first=False 時 shape 為 (S,N,Ek)(S, N, E_k) 或當 batch_first=True 時 shape 為 (N,S,Ek)(N, S, E_k),其中 SS 是源序列長度,NN 是批次大小,EkE_k 是 key 嵌入維度 kdim。更多細節請參見“Attention Is All You Need”。

  • value (Tensor) – shape 為 (S,Ev)(S, E_v) 的未批處理輸入,當 batch_first=False 時 shape 為 (S,N,Ev)(S, N, E_v) 或當 batch_first=True 時 shape 為 (N,S,Ev)(N, S, E_v),其中 SS 是源序列長度,NN 是批次大小,EvE_v 是 value 嵌入維度 vdim。更多細節請參見“Attention Is All You Need”。

  • key_padding_mask (Optional[Tensor]) – 如果指定,則為 shape 為 (N,S)(N, S) 的掩碼,指示 key 中哪些元素應被忽略(即被視為“padding”)。對於未批處理的 query,shape 應為 (S)(S)。支援二進位制和浮點掩碼。對於二進位制掩碼,True 值表示相應的 key 值將被忽略。對於浮點掩碼,它將直接新增到相應的 key 值中。

  • need_weights (bool) – 如果指定,除了 attn_outputs 外,還將返回 attn_output_weights。將 need_weights=False 以使用最佳化的 scaled_dot_product_attention 併為 MHA 獲得最佳效能。預設為 True

  • attn_mask (Optional[Tensor]) – 如果指定,則為 2D 或 3D 掩碼,可防止注意力集中到某些位置。必須是 shape (L,S)(L, S)(Nnum_heads,L,S)(N\cdot\text{num\_heads}, L, S),其中 NN 是批次大小,LL 是目標序列長度,SS 是源序列長度。2D 掩碼將廣播到整個批次,而 3D 掩碼允許批次中的每個條目都有不同的掩碼。支援二進位制和浮點掩碼。對於二進位制掩碼,True 值表示相應的位置不允許注意力。對於浮點掩碼,掩碼值將加到注意力權重上。如果同時提供了 attn_mask 和 key_padding_mask,它們的型別應匹配。

  • average_attn_weights (bool) – 如果為 True,則表示返回的 attn_weights 應在 heads 之間取平均值。否則,attn_weights 將按 head 分別提供。請注意,此標誌僅在 need_weights=True 時生效。預設為 True(即在 heads 之間平均權重)。

  • is_causal (bool) – 如果指定,則將因果掩碼作為注意力掩碼應用。預設為 False。警告:is_causal 提供了一個 attn_mask 是因果掩碼的提示。提供不正確的提示可能導致執行錯誤,包括向前和向後相容性。

返回型別

tuple[torch.Tensor, Optional[torch.Tensor]]

輸出
  • attn_output - 注意力輸出,shape 為 (L,E)(L, E)(當輸入未批處理時),(L,N,E)(L, N, E)(當 batch_first=False 時)或 (N,L,E)(N, L, E)(當 batch_first=True 時),其中 LL 是目標序列長度,NN 是批次大小,EE 是嵌入維度 embed_dim

  • attn_output_weights – 僅當 need_weights=True 時返回。如果 average_attn_weights=True,則返回在 heads 之間平均後的注意力權重,shape 為 (L,S)(L, S)(當輸入未批處理時)或 (N,L,S)(N, L, S)(當 batch_first=False 時),其中 NN 是批次大小,LL 是目標序列長度,SS 是源序列長度。如果 average_attn_weights=False,則返回按 head 分佈的注意力權重,shape 為 (num_heads,L,S)(\text{num\_heads}, L, S)(當輸入未批處理時)或 (N,num_heads,L,S)(N, \text{num\_heads}, L, S)(當輸入批處理時)。

注意

batch_first 引數對於未批處理的輸入將被忽略。

merge_masks(attn_mask, key_padding_mask, query)[原始碼]#

確定掩碼型別併合並掩碼(如果需要)。

如果只提供一個掩碼,將返回該掩碼和相應的掩碼型別。如果同時提供兩個掩碼,它們都將被擴充套件到 shape (batch_size, num_heads, seq_len, seq_len),並使用邏輯 or 合併,並將返回掩碼型別 2 :param attn_mask: shape 為 (seq_len, seq_len) 的注意力掩碼,掩碼型別 0 :param key_padding_mask: shape 為 (batch_size, seq_len) 的 padding 掩碼,掩碼型別 1 :param query: shape 為 (batch_size, seq_len, embed_dim) 的 query 嵌入

返回

merged mask mask_type: 合併後的掩碼型別(0、1 或 2)

返回型別

merged_mask