torch.nn.functional.scaled_dot_product_attention#
- torch.nn.functional.scaled_dot_product_attention()#
- scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0,
is_causal=False, scale=None, enable_gqa=False) -> Tensor
計算 query、key 和 value 張量的縮放點積注意力,如果提供了 attention mask,則使用它,如果指定了大於 0.0 的機率,則應用 dropout。可選的 scale 引數只能作為關鍵字引數指定。
# Efficient implementation equivalent to the following: def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None, enable_gqa=False) -> torch.Tensor: L, S = query.size(-2), key.size(-2) scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device) if is_causal: assert attn_mask is None temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0) attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) if attn_mask is not None: if attn_mask.dtype == torch.bool: attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf")) else: attn_bias = attn_mask + attn_bias if enable_gqa: key = key.repeat_interleave(query.size(-3)//key.size(-3), -3) value = value.repeat_interleave(query.size(-3)//value.size(-3), -3) attn_weight = query @ key.transpose(-2, -1) * scale_factor attn_weight += attn_bias attn_weight = torch.softmax(attn_weight, dim=-1) attn_weight = torch.dropout(attn_weight, dropout_p, train=True) return attn_weight @ value
警告
此函式處於 beta 階段,可能會發生更改。
警告
此函式始終根據指定的
dropout_p引數應用 dropout。要在評估期間停用 dropout,請確保在呼叫該函式的模組不在訓練模式下時將0.0傳遞給它。例如
class MyModel(nn.Module): def __init__(self, p=0.5): super().__init__() self.p = p def forward(self, ...): return F.scaled_dot_product_attention(..., dropout_p=(self.p if self.training else 0.0))
注意
目前支援三種縮放點積注意力實現:
C++ 實現的 PyTorch 版本,匹配上述公式
在使用 CUDA 後端時,該函式可能會呼叫最佳化核心以提高效能。對於所有其他後端,將使用 PyTorch 實現。
所有實現預設都已啟用。縮放點積注意力會嘗試根據輸入自動選擇最優的實現。為了對使用的實現進行更精細化的控制,提供了以下函式來啟用和停用實現。上下文管理器是首選機制。
torch.nn.attention.sdpa_kernel(): 一個上下文管理器,用於啟用或停用任何實現。torch.backends.cuda.enable_flash_sdp(): 全域性啟用或停用 FlashAttention。torch.backends.cuda.enable_mem_efficient_sdp(): 全域性啟用或停用記憶體高效注意力。torch.backends.cuda.enable_math_sdp(): 全域性啟用或停用 PyTorch C++ 實現。
每個融合核心都有特定的輸入限制。如果使用者需要使用特定的融合實現,請使用
torch.nn.attention.sdpa_kernel()停用 PyTorch C++ 實現。如果融合實現不可用,將引發警告,說明無法執行融合實現的原因。由於浮點運算融合的性質,此函式的輸出可能因選擇的後端核心而異。C++ 實現支援 torch.float64,在需要更高精度時可以使用。對於 math 後端,如果輸入為 torch.half 或 torch.bfloat16,則所有中間值都保留為 torch.float。
更多資訊請參閱 數值精度。
分組查詢注意力 (GQA) 是一項實驗性功能。目前它僅適用於 CUDA 張量上的 Flash_attention 和 math 核心,不支援 Nested tensor。GQA 的約束條件:
number_of_heads_query % number_of_heads_key_value == 0 且,
number_of_heads_key == number_of_heads_value
注意
在某些情況下,當在 CUDA 裝置上使用張量並利用 CuDNN 時,此運算元可能會選擇一個非確定性演算法來提高效能。如果這不可取,你可以嘗試將操作設定為確定性的(可能以效能為代價),方法是設定
torch.backends.cudnn.deterministic = True。有關更多資訊,請參閱 可復現性。- 引數
query (Tensor) – 查詢張量;形狀為 。
key (Tensor) – 鍵張量;形狀為 。
value (Tensor) – 值張量;形狀為 。
attn_mask (optional Tensor) – 注意力掩碼;形狀必須可廣播到注意力權重形狀,即 。支援兩種型別的掩碼。布林掩碼,其中 True 值表示該元素應該參與注意力。與 query、key、value 型別相同的浮點掩碼,該掩碼會加到注意力分數上。
dropout_p (float) – Dropout 機率;如果大於 0.0,則應用 dropout。
is_causal (bool) – 如果設定為 true,則在掩碼為方陣時,注意力掩碼為下三角矩陣。當掩碼為非方陣時,注意力掩碼的形式是由於對齊而產生的左上因果偏差(參見
torch.nn.attention.bias.CausalBias)。如果同時設定了 attn_mask 和 is_causal,則會引發錯誤。scale (optional python:float, keyword-only) – Softmax 之前應用的縮放因子。如果為 None,則預設值為 。
enable_gqa (bool) – 如果設定為 True,則啟用分組查詢注意力 (GQA),預設設定為 False。
- 返回
注意力輸出;形狀為 。
- 返回型別
output (Tensor)
- 形狀說明
示例
>>> # Optionally use the context manager to ensure one of the fused kernels is run >>> query = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda") >>> key = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda") >>> value = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda") >>> with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]): >>> F.scaled_dot_product_attention(query,key,value)
>>> # Sample for GQA for llama3 >>> query = torch.rand(32, 32, 128, 64, dtype=torch.float16, device="cuda") >>> key = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda") >>> value = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda") >>> with sdpa_kernel(backends=[SDPBackend.MATH]): >>> F.scaled_dot_product_attention(query,key,value,enable_gqa=True)