torch.nn.attention.flex_attention#
創建於: 2024年7月16日 | 最後更新於: 2025年9月8日
- torch.nn.attention.flex_attention.flex_attention(query, key, value, score_mod=None, block_mask=None, scale=None, enable_gqa=False, return_lse=False, kernel_options=None, *, return_aux=None)[source]#
該函式實現了具有任意注意力分數修改函式的縮放點積注意力。
該函式在查詢、鍵和值張量之間計算縮放點積注意力,並使用使用者定義的注意力分數修改函式。注意力分數修改函式將在查詢和鍵張量之間的注意力分數計算完成後應用。注意力分數的計算方式如下:
score_mod函式應具有以下簽名:def score_mod( score: Tensor, batch: Tensor, head: Tensor, q_idx: Tensor, k_idx: Tensor ) -> Tensor:
- 其中
score:一個標量張量,表示注意力分數,其資料型別和裝置與查詢、鍵和值張量相同。batch、head、q_idx、k_idx:標量張量,分別指示批次索引、查詢頭索引、查詢索引和鍵/值索引。這些應具有torch.int資料型別,並位於與分數張量相同的裝置上。
- 引數
query (Tensor) – 查詢張量;形狀為 。對於 FP8 資料型別,應採用行主記憶體佈局以獲得最佳效能。
key (Tensor) – 鍵張量;形狀為 。對於 FP8 資料型別,應採用行主記憶體佈局以獲得最佳效能。
value (Tensor) – 值張量;形狀為 。對於 FP8 資料型別,應採用列主記憶體佈局以獲得最佳效能。
score_mod (Optional[Callable]) – 用於修改注意力分數的函式。預設情況下,不應用 score_mod。
block_mask (Optional[BlockMask]) – BlockMask 物件,用於控制注意力的塊稀疏性模式。
scale (Optional[float]) – 在 softmax 之前應用的縮放因子。如果為 None,則預設值為 。
enable_gqa (bool) – 如果設定為 True,則啟用分組查詢注意力(GQA)並向查詢頭廣播鍵/值頭。
return_lse (bool) – 是否返回注意力分數的對數和(logsumexp)。預設為 False。已棄用:請改用
return_aux=AuxRequest(lse=True)。kernel_options (Optional[FlexKernelOptions]) – 用於控制底層 Triton 核心行為的選項。有關可用選項和用法示例,請參閱
FlexKernelOptions。return_aux (Optional[AuxRequest]) – 指定要計算和返回的輔助輸出。如果為 None,則只返回注意力輸出。使用
AuxRequest(lse=True, max_scores=True)來請求兩個輔助輸出。
- 返回
注意力輸出;形狀為 。
- 當
return_aux不為 None 時 aux (AuxOutput): 包含已請求欄位的輔助輸出。
- 當
return_aux為 None 時(已棄用路徑) lse (Tensor): 注意力分數的對數和;形狀為 。僅當
return_lse=True時返回。
- 當
- 返回型別
output (Tensor)
- 形狀說明
警告
torch.nn.attention.flex_attention 是 PyTorch 中的一個原型功能。請期待 PyTorch 未來版本中更穩定的實現。有關功能分類的更多資訊,請訪問:https://pytorch.com.tw/blog/pytorch-feature-classification-changes/#prototype
- class torch.nn.attention.flex_attention.AuxOutput(lse=None, max_scores=None)[source]#
flex_attention 操作的輔助輸出。
如果未請求,欄位將為 None;如果已請求,則包含張量。
- class torch.nn.attention.flex_attention.AuxRequest(lse=False, max_scores=False)[source]#
請求從 flex_attention 計算哪些輔助輸出。
每個欄位都是一個布林值,指示是否應計算該輔助輸出。
BlockMask 工具#
- torch.nn.attention.flex_attention.create_block_mask(mask_mod, B, H, Q_LEN, KV_LEN, device='cuda', BLOCK_SIZE=128, _compile=False)[source]#
此函式從 mask_mod 函式建立塊掩碼元組。
- 引數
mask_mod (Callable) – mask_mod 函式。這是一個可呼叫物件,用於定義注意力機制的掩碼模式。它接受四個引數:b(批次大小)、h(頭數)、q_idx(查詢索引)和 kv_idx(鍵/值索引)。它應返回一個布林張量,指示哪些注意力連線是允許的(True)或被掩碼掉的(False)。
B (int) – 批次大小。
H (int) – 查詢頭數。
Q_LEN (int) – 查詢的序列長度。
KV_LEN (int) – 鍵/值的序列長度。
device (str) – 用於執行掩碼建立的裝置。
BLOCK_SIZE (int 或 tuple[int, int]) – 塊掩碼的塊大小。如果提供單個整數,則同時用於查詢和鍵/值。
- 返回
一個 BlockMask 物件,其中包含塊掩碼資訊。
- 返回型別
- 示例用法
def causal_mask(b, h, q_idx, kv_idx): return q_idx >= kv_idx block_mask = create_block_mask(causal_mask, 1, 1, 8192, 8192, device="cuda") query = torch.randn(1, 1, 8192, 64, device="cuda", dtype=torch.float16) key = torch.randn(1, 1, 8192, 64, device="cuda", dtype=torch.float16) value = torch.randn(1, 1, 8192, 64, device="cuda", dtype=torch.float16) output = flex_attention(query, key, value, block_mask=block_mask)
- torch.nn.attention.flex_attention.create_mask(mod_fn, B, H, Q_LEN, KV_LEN, device='cuda')[source]#
此函式從 mod_fn 函式建立掩碼張量。
- torch.nn.attention.flex_attention.create_nested_block_mask(mask_mod, B, H, q_nt, kv_nt=None, BLOCK_SIZE=128, _compile=False)[source]#
此函式從 mask_mod 函式建立與巢狀張量相容的塊掩碼元組。返回的 BlockMask 將位於輸入巢狀張量指定的裝置上。
- 引數
mask_mod (Callable) – mask_mod 函式。這是一個可呼叫物件,用於定義注意力機制的掩碼模式。它接受四個引數:b(批次大小)、h(頭數)、q_idx(查詢索引)和 kv_idx(鍵/值索引)。它應返回一個布林張量,指示哪些注意力連線是允許的(True)或被掩碼掉的(False)。
B (int) – 批次大小。
H (int) – 查詢頭數。
q_nt (torch.Tensor) – 鋸齒狀佈局巢狀張量(NJT),用於定義查詢的序列長度結構。塊掩碼將構造為作用於 NJT 中序列長度
S的“堆疊序列”的長度sum(S)。kv_nt (torch.Tensor) – 鋸齒狀佈局巢狀張量(NJT),用於定義鍵/值的序列長度結構,允許交叉注意力。塊掩碼將構造為作用於 NJT 中序列長度
S的“堆疊序列”的長度sum(S)。如果此引數為 None,則q_nt也將用於定義鍵/值的結構。預設為 NoneBLOCK_SIZE (int 或 tuple[int, int]) – 塊掩碼的塊大小。如果提供單個整數,則同時用於查詢和鍵/值。
- 返回
一個 BlockMask 物件,其中包含塊掩碼資訊。
- 返回型別
- 示例用法
# shape (B, num_heads, seq_len*, D) where seq_len* varies across the batch query = torch.nested.nested_tensor(..., layout=torch.jagged) key = torch.nested.nested_tensor(..., layout=torch.jagged) value = torch.nested.nested_tensor(..., layout=torch.jagged) def causal_mask(b, h, q_idx, kv_idx): return q_idx >= kv_idx block_mask = create_nested_block_mask( causal_mask, 1, 1, query, _compile=True ) output = flex_attention(query, key, value, block_mask=block_mask)
# shape (B, num_heads, seq_len*, D) where seq_len* varies across the batch query = torch.nested.nested_tensor(..., layout=torch.jagged) key = torch.nested.nested_tensor(..., layout=torch.jagged) value = torch.nested.nested_tensor(..., layout=torch.jagged) def causal_mask(b, h, q_idx, kv_idx): return q_idx >= kv_idx # cross attention case: pass both query and key/value NJTs block_mask = create_nested_block_mask( causal_mask, 1, 1, query, key, _compile=True ) output = flex_attention(query, key, value, block_mask=block_mask)
FlexKernelOptions#
- class torch.nn.attention.flex_attention.FlexKernelOptions[source]#
FlexAttention 核心的行為控制選項。
這些選項將傳遞給底層 Triton 核心,以控制性能和數值行為。大多數使用者不需要指定這些選項,因為預設的自動調整提供了良好的效能。
選項可以加上
fwd_或bwd_字首,以便分別僅應用於前向或後向傳遞。例如:fwd_BLOCK_M和bwd_BLOCK_M1。注意
目前我們不為這些選項提供任何向後相容性保證。儘管如此,自引入以來,其中大部分選項都相當穩定。但我們暫時不認為這是公共 API 的一部分。我們認為文件比隱藏的秘密標誌更好,但我們將來可能會更改這些選項。
- 示例用法
# Using dictionary (backward compatible) kernel_opts = {"BLOCK_M": 64, "BLOCK_N": 64, "PRESCALE_QK": True} output = flex_attention(q, k, v, kernel_options=kernel_opts) # Using TypedDict (recommended for type safety) from torch.nn.attention.flex_attention import FlexKernelOptions kernel_opts: FlexKernelOptions = { "BLOCK_M": 64, "BLOCK_N": 64, "PRESCALE_QK": True, } output = flex_attention(q, k, v, kernel_options=kernel_opts) # Forward/backward specific options kernel_opts: FlexKernelOptions = { "fwd_BLOCK_M": 64, "bwd_BLOCK_M1": 32, "PRESCALE_QK": False, } output = flex_attention(q, k, v, kernel_options=kernel_opts)
- BLOCKS_ARE_CONTIGUOUS: NotRequired[bool]#
如果為 True,則保證掩碼中的所有塊都是連續的。允許最佳化塊遍歷。例如,因果掩碼會滿足此條件,但字首 LM + 滑動視窗則不會。預設為 False。
- FORCE_USE_FLEX_ATTENTION: NotRequired[bool]#
如果為 True,則強制使用 flex attention 核心,而不是可能為短序列使用更最佳化的 flex-decoding 核心。這對於除錯來說是一個有用的選項。預設為 False。
- ROWS_GUARANTEED_SAFE: NotRequired[bool]#
如果為 True,則保證每行至少有一個值未被掩碼掉。允許跳過安全檢查以獲得更好的效能。只有當您確定掩碼保證此屬性時才設定此項。例如,因果注意力被保證是安全的,因為每個查詢至少有 1 個鍵-值可以關注。預設為 False。
- USE_TMA: NotRequired[bool]#
是否在支援的硬體上使用 Tensor Memory Accelerator (TMA)。這處於實驗階段,可能無法在所有硬體上執行,目前僅限於 NVIDIA GPU Hopper+。預設為 False。
BlockMask#
- class torch.nn.attention.flex_attention.BlockMask(seq_lengths, kv_num_blocks, kv_indices, full_kv_num_blocks, full_kv_indices, q_num_blocks, q_indices, full_q_num_blocks, full_q_indices, BLOCK_SIZE, mask_mod)[source]#
BlockMask 是我們用於表示塊稀疏注意力掩碼的格式。它在某種程度上介於 BCSR 和非稀疏格式之間。
基礎知識
塊稀疏掩碼意味著,與其表示掩碼中單個元素的稀疏性,不如將 KV_BLOCK_SIZE x Q_BLOCK_SIZE 塊視為稀疏,僅當該塊內的每個元素都稀疏時。這與硬體的期望非常吻合,硬體通常期望進行連續的載入和計算。
此格式主要針對 1. 簡單性;2. 核心效率進行了最佳化。值得注意的是,它 *不* 針對大小進行最佳化,因為此掩碼的大小總是除以 KV_BLOCK_SIZE * Q_BLOCK_SIZE。如果大小是問題,可以透過增加塊大小來減小張量的大小。
我們格式的關鍵點是:
num_blocks_in_row: Tensor[ROWS]: 描述每行中存在的塊數。
col_indices: Tensor[ROWS, MAX_BLOCKS_IN_COL]:
col_indices[i]是第 i 行的塊位置序列。此行中col_indices[i][num_blocks_in_row[i]]之後的值未定義。例如,要從該格式中恢復原始張量:
dense_mask = torch.zeros(ROWS, COLS) for row in range(ROWS): for block_idx in range(num_blocks_in_row[row]): dense_mask[row, col_indices[row, block_idx]] = 1
值得注意的是,此格式使得沿著掩碼的*行*進行歸約操作更容易。
詳細資訊
我們格式的基本要求是僅 kv_num_blocks 和 kv_indices。但是,我們在此物件上有多達 8 個張量。這代表 4 對:
1. (kv_num_blocks, kv_indices): 用於注意力的前向傳遞,因為我們沿著 KV 維度進行歸約。
2. [可選] (full_kv_num_blocks, full_kv_indices): 這是可選的,純粹是為了最佳化。事實證明,對每個塊應用掩碼成本很高!如果我們特別知道哪些塊是“完整的”並且不需要應用掩碼,那麼我們可以跳過將 mask_mod 應用於這些塊。這要求使用者將 mask_mod 分開,而不是從 score_mod 中分離。對於因果掩碼,這可以帶來約 15% 的速度提升。
3. [生成] (q_num_blocks, q_indices): 後向傳遞需要,因為計算 dKV 需要沿著 Q 維度沿掩碼進行迭代。這些是根據 1 自動生成的。
4. [生成] (full_q_num_blocks, full_q_indices): 與上面相同,但用於後向傳遞。這些是根據 2 自動生成的。
- as_tuple(flatten=True)[source]#
返回 BlockMask 屬性的元組。
- 引數
flatten (bool) – 如果為 True,則將 (KV_BLOCK_SIZE, Q_BLOCK_SIZE) 的元組展平。
- classmethod from_kv_blocks(kv_num_blocks, kv_indices, full_kv_num_blocks=None, full_kv_indices=None, BLOCK_SIZE=128, mask_mod=None, seq_lengths=None, compute_q_blocks=True)[source]#
從鍵值塊資訊建立 BlockMask 例項。
- 引數
kv_num_blocks (Tensor) – 每個 Q_BLOCK_SIZE 行塊的 kv_blocks 數量。
kv_indices (Tensor) – 每個 Q_BLOCK_SIZE 行塊的鍵值塊索引。
full_kv_num_blocks (Optional[Tensor]) – 每個 Q_BLOCK_SIZE 行塊中的完整 kv_blocks 數量。
full_kv_indices (Optional[Tensor]) – 每個 Q_BLOCK_SIZE 行塊中的完整鍵值塊索引。
BLOCK_SIZE (Union[int, tuple[int, int]]) – KV_BLOCK_SIZE x Q_BLOCK_SIZE 塊的大小。
mask_mod (Optional[Callable]) – 用於修改掩碼的函式。
- 返回
透過 _transposed_ordered 生成完整 Q 資訊的例項。
- 返回型別
- 引發
RuntimeError – 如果 kv_indices 的維度小於 2。
AssertionError – 如果只提供 full_kv_* 引數中的一個。
- property shape#
- to(device)[source]#
將 BlockMask 移動到指定的裝置。
- 引數
device (torch.device 或 str) – 要將 BlockMask 移動到的目標裝置。可以是 torch.device 物件或字串(例如,‘cpu’、‘cuda:0’)。
- 返回
一個將所有張量元件移動到指定裝置的新 BlockMask 例項。
- 返回型別
注意
此方法不會就地修改原始 BlockMask。相反,它返回一個新的 BlockMask 例項,其中各個張量屬性可能會或可能不會移動到指定裝置,具體取決於它們當前的裝置放置。