評價此頁

MultiheadAttention#

class torch.ao.nn.quantizable.MultiheadAttention(embed_dim, num_heads, dropout=0.0, bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None, batch_first=False, device=None, dtype=None)[原始碼]#
dequantize()[原始碼]#

將量化後的MHA轉換回浮點數。

這樣做的動機是,將量化版本中使用的格式的權重轉換回浮點數並非易事。

forward(query, key, value, key_padding_mask=None, need_weights=True, attn_mask=None, average_attn_weights=True, is_causal=False)[原始碼]#
注意:

有關更多資訊,請參閱 forward()

引數
  • query (Tensor) – 將查詢和一組鍵值對對映到輸出。更多細節請參閱“Attention Is All You Need”。

  • key (Tensor) – 將查詢和一組鍵值對對映到輸出。更多細節請參閱“Attention Is All You Need”。

  • value (Tensor) – 將查詢和一組鍵值對對映到輸出。更多細節請參閱“Attention Is All You Need”。

  • key_padding_mask (Optional[Tensor]) – 如果提供,指定的鍵中的填充元素將被注意力忽略。當給定二進位制掩碼且值為 True 時,將忽略注意力層上的相應值。

  • need_weights (bool) – 輸出 attn_output_weights。

  • attn_mask (Optional[Tensor]) – 2D 或 3D 掩碼,可防止注意力指向特定位置。2D 掩碼將廣播到所有批次,而 3D 掩碼允許為每個批次的條目指定不同的掩碼。

返回型別

tuple[torch.Tensor, Optional[torch.Tensor]]

形狀
  • 輸入

  • query: (L,N,E)(L, N, E),其中 L 是目標序列長度,N 是批次大小,E 是嵌入維度。 如果 batch_firstTrue,則為 (N,L,E)(N, L, E)

  • key: (S,N,E)(S, N, E),其中 S 是源序列長度,N 是批次大小,E 是嵌入維度。 如果 batch_firstTrue,則為 (N,S,E)(N, S, E)

  • value: (S,N,E)(S, N, E),其中 S 是源序列長度,N 是批次大小,E 是嵌入維度。 如果 batch_firstTrue,則為 (N,S,E)(N, S, E)

  • key_padding_mask: (N,S)(N, S),其中 N 是批次大小,S 是源序列長度。如果提供 BoolTensor,值為 True 的位置將被忽略,而值為 False 的位置將保持不變。

  • attn_mask: 2D 掩碼 (L,S)(L, S),其中 L 是目標序列長度,S 是源序列長度。3D 掩碼 (Nnumheads,L,S)(N*num_heads, L, S),其中 N 是批次大小,L 是目標序列長度,S 是源序列長度。attn_mask 確保位置 i 可以注意力到未遮掩的位置。如果提供 BoolTensor,值為 True 的位置不允許注意力,而 False 值將保持不變。如果提供 FloatTensor,它將被新增到注意力權重中。

  • is_causal: 如果指定,則將因果掩碼用作注意力掩碼。與提供 attn_mask 互斥。預設值:False

  • average_attn_weights: 如果為 True,則表示返回的 attn_weights 應該跨頭平均。否則,attn_weights 將按頭單獨提供。請注意,此標誌僅在 need_weights=True 時有效。預設值:True(即平均跨頭權重)。

  • 輸出

  • attn_output: (L,N,E)(L, N, E),其中 L 是目標序列長度,N 是批次大小,E 是嵌入維度。 如果 batch_firstTrue,則為 (N,L,E)(N, L, E)

  • attn_output_weights: 如果 average_attn_weights=True,則返回跨頭平均的注意力權重,形狀為 (N,L,S)(N, L, S),其中 N 是批次大小,L 是目標序列長度,S 是源序列長度。 如果 average_attn_weights=False,則返回每個頭的注意力權重,形狀為 (N,numheads,L,S)(N, num_heads, L, S)