MultiheadAttention#
- class torch.nn.MultiheadAttention(embed_dim, num_heads, dropout=0.0, bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None, batch_first=False, device=None, dtype=None)[源]#
允許模型聯合關注來自不同表示子空間的資訊。
此 MultiheadAttention 層實現了 Attention Is All You Need 論文中描述的原始架構。該層旨在作為基礎理解的參考實現,因此其功能相對於較新的架構而言僅限於有限的功能。鑑於 Transformer 類架構的快速創新步伐,我們建議您探索此 教程,以從核心構建塊構建高效層,或使用 PyTorch 生態系統中的更高階庫。
多頭注意力 (Multi-Head Attention) 定義為:
其中 。
nn.MultiheadAttention將在可能的情況下使用scaled_dot_product_attention()的最佳化實現。除了支援新的
scaled_dot_product_attention()函式外,為了加速推理,MHA 還將使用支援 Nested Tensors 的 fastpath 推理,前提是:計算自注意力(即
query、key和value是同一個張量)。輸入是批處理的(3D),並且
batch_first==True。自動梯度被停用(使用
torch.inference_mode或torch.no_grad)或者沒有張量引數requires_grad訓練被停用(使用
.eval())add_bias_kv為False。add_zero_attn為False。kdim和vdim等於embed_dim。如果傳遞了 NestedTensor,則不傳遞
key_padding_mask和attn_mask。autocast 已停用。
如果使用了最佳化的推理 fastpath 實現,則可以為
query/key/value傳遞 NestedTensor,以比使用 padding mask 更有效地表示 padding。在這種情況下,將返回 NestedTensor,並且可以預期額外的加速與輸入中 padding 的比例成正比。- 引數
embed_dim – 模型的總維度。
num_heads – 並行注意力頭的數量。請注意,
embed_dim將被分割到num_heads中(即每個頭的維度為embed_dim // num_heads)。dropout –
attn_output_weights上的 Dropout 機率。預設值:0.0(無 dropout)。bias – 如果指定,則向輸入/輸出投影層新增偏置。預設值:
True。add_bias_kv – 如果指定,則向 key 和 value 序列在 dim=0 處新增偏置。預設值:
False。add_zero_attn – 如果指定,則向 key 和 value 序列在 dim=1 處新增新的零批次。預設值:
False。kdim – key 的總特徵數。預設值:
None(使用kdim=embed_dim)。vdim – value 的總特徵數。預設值:
None(使用vdim=embed_dim)。batch_first – 如果為
True,則輸入和輸出張量為 (batch, seq, feature)。預設值:False(seq, batch, feature)。
示例
>>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads) >>> attn_output, attn_output_weights = multihead_attn(query, key, value)
- forward(query, key, value, key_padding_mask=None, need_weights=True, attn_mask=None, average_attn_weights=True, is_causal=False)[源]#
使用 query、key 和 value 嵌入計算注意力輸出。
支援用於 padding、mask 和注意力權重的可選引數。
- 引數
query (Tensor) – shape 為 的未批處理輸入, 當
batch_first=False時,或 當batch_first=True時,其中 是目標序列長度, 是批次大小, 是查詢嵌入維度embed_dim。查詢與鍵值對進行比較以生成輸出。更多細節請參閱“Attention Is All You Need”。key (Tensor) – shape 為 的未批處理輸入, 當
batch_first=False時,或 當batch_first=True時,其中 是源序列長度, 是批次大小, 是鍵嵌入維度kdim。更多細節請參閱“Attention Is All You Need”。value (Tensor) – shape 為 的未批處理輸入, 當
batch_first=False時,或 當batch_first=True時,其中 是源序列長度, 是批次大小, 是值嵌入維度vdim。更多細節請參閱“Attention Is All You Need”。key_padding_mask (Optional[Tensor]) – 如果指定,則為 shape 為 的掩碼,指示
key中哪些元素在注意力計算中被忽略(即被視為“padding”)。對於未批處理的 query,shape 應為 。支援二進位制和浮點掩碼。對於二進位制掩碼,True值表示相應的key值在注意力計算中將被忽略。對於浮點掩碼,它將直接加到相應的key值上。need_weights (bool) – 如果指定,則除了
attn_outputs之外,還返回attn_output_weights。將need_weights=False以使用最佳化的scaled_dot_product_attention併為 MHA 獲得最佳效能。預設值:True。attn_mask (Optional[Tensor]) – 如果指定,則為 2D 或 3D 掩碼,可防止注意力關注某些位置。必須為 shape 或 ,其中 是批次大小, 是目標序列長度, 是源序列長度。2D 掩碼將廣播到整個批次,而 3D 掩碼則允許每個批次條目都有不同的掩碼。支援二進位制和浮點掩碼。對於二進位制掩碼,
True值表示不允許關注相應的位置。對於浮點掩碼,掩碼值將加到注意力權重上。如果同時提供了 attn_mask 和 key_padding_mask,它們的型別應匹配。average_attn_weights (bool) – 如果為 true,表示返回的
attn_weights應在各頭之間取平均值。否則,attn_weights按頭分開提供。請注意,此標誌僅在need_weights=True時生效。預設值:True(即平均各頭權重)。is_causal (bool) – 如果指定,則將因果掩碼作為注意力掩碼應用。預設值:
False。警告:is_causal提供了一個提示,即attn_mask是因果掩碼。提供錯誤的提示可能導致執行錯誤,包括向前和向後相容性問題。
- 返回型別
- 輸出
attn_output - shape 為 的注意力輸出,當輸入未批處理時; 當
batch_first=False時;或 當batch_first=True時,其中 是目標序列長度, 是批次大小, 是嵌入維度embed_dim。attn_output_weights - 僅當
need_weights=True時返回。如果average_attn_weights=True,則返回平均後的注意力權重,shape 為 ,當輸入未批處理時,或 ,當batch_first=False時,其中 是批次大小, 是目標序列長度, 是源序列長度。如果average_attn_weights=False,則返回各頭的注意力權重,shape 為 ,當輸入未批處理時,或 ,當batch_first=False時。
注意
batch_first 引數對於未批處理的輸入將被忽略。
- merge_masks(attn_mask, key_padding_mask, query)[源]#
確定掩碼型別並根據需要合併掩碼。
如果只提供一個掩碼,則返回該掩碼和對應的掩碼型別。如果同時提供兩個掩碼,它們將被擴充套件到 shape
(batch_size, num_heads, seq_len, seq_len),並透過邏輯or合併,並返回掩碼型別 2::param attn_mask: attention mask,shape 為(seq_len, seq_len),掩碼型別 0 :param key_padding_mask: padding mask,shape 為(batch_size, seq_len),掩碼型別 1 :param query: query embeddings,shape 為(batch_size, seq_len, embed_dim)- 返回
merged mask mask_type: 合併後的掩碼型別(0、1 或 2)。
- 返回型別
merged_mask