MultiheadAttention#
- class torch.ao.nn.quantizable.MultiheadAttention(embed_dim, num_heads, dropout=0.0, bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None, batch_first=False, device=None, dtype=None)[原始碼]#
- forward(query, key, value, key_padding_mask=None, need_weights=True, attn_mask=None, average_attn_weights=True, is_causal=False)[原始碼]#
- 注意:
有關更多資訊,請參閱
forward()。
- 引數
query (Tensor) – 將查詢和一組鍵值對對映到輸出。更多細節請參閱“Attention Is All You Need”。
key (Tensor) – 將查詢和一組鍵值對對映到輸出。更多細節請參閱“Attention Is All You Need”。
value (Tensor) – 將查詢和一組鍵值對對映到輸出。更多細節請參閱“Attention Is All You Need”。
key_padding_mask (Optional[Tensor]) – 如果提供,指定的鍵中的填充元素將被注意力忽略。當給定二進位制掩碼且值為 True 時,將忽略注意力層上的相應值。
need_weights (bool) – 輸出 attn_output_weights。
attn_mask (Optional[Tensor]) – 2D 或 3D 掩碼,可防止注意力指向特定位置。2D 掩碼將廣播到所有批次,而 3D 掩碼允許為每個批次的條目指定不同的掩碼。
- 返回型別
- 形狀
輸入
query: ,其中 L 是目標序列長度,N 是批次大小,E 是嵌入維度。 如果
batch_first為True,則為 。key: ,其中 S 是源序列長度,N 是批次大小,E 是嵌入維度。 如果
batch_first為True,則為 。value: ,其中 S 是源序列長度,N 是批次大小,E 是嵌入維度。 如果
batch_first為True,則為 。key_padding_mask: ,其中 N 是批次大小,S 是源序列長度。如果提供 BoolTensor,值為
True的位置將被忽略,而值為False的位置將保持不變。attn_mask: 2D 掩碼 ,其中 L 是目標序列長度,S 是源序列長度。3D 掩碼 ,其中 N 是批次大小,L 是目標序列長度,S 是源序列長度。attn_mask 確保位置 i 可以注意力到未遮掩的位置。如果提供 BoolTensor,值為
True的位置不允許注意力,而False值將保持不變。如果提供 FloatTensor,它將被新增到注意力權重中。is_causal: 如果指定,則將因果掩碼用作注意力掩碼。與提供 attn_mask 互斥。預設值:
False。average_attn_weights: 如果為 True,則表示返回的
attn_weights應該跨頭平均。否則,attn_weights將按頭單獨提供。請注意,此標誌僅在need_weights=True時有效。預設值:True(即平均跨頭權重)。輸出
attn_output: ,其中 L 是目標序列長度,N 是批次大小,E 是嵌入維度。 如果
batch_first為True,則為 。attn_output_weights: 如果
average_attn_weights=True,則返回跨頭平均的注意力權重,形狀為 ,其中 N 是批次大小,L 是目標序列長度,S 是源序列長度。 如果average_attn_weights=False,則返回每個頭的注意力權重,形狀為 。