(Beta) Implementing High-Performance Transformers with Scaled Dot Product Attention (SDPA)#

Created On: Mar 15, 2023 | Last Updated: Oct 09, 2024 | Last Verified: Nov 05, 2024

Author: Driss Guessous

Summary#

In this tutorial, we want to highlight a new torch.nn.functional function that can be helpful for implementing transformer architectures. The function is named torch.nn.functional.scaled_dot_product_attention. For a detailed description of the function, see the PyTorch documentation. This function has already been incorporated into torch.nn.MultiheadAttention and torch.nn.TransformerEncoderLayer.
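
To give a sense of how that integration is used in practice, here is a minimal sketch (the embedding size, head count, and tensor shapes below are illustrative assumptions, not values from this tutorial):

import torch
import torch.nn as nn

# On recent PyTorch versions both of these modules call
# scaled_dot_product_attention internally, so they pick up the fused kernels.
mha = nn.MultiheadAttention(embed_dim=64, num_heads=8, batch_first=True)
encoder_layer = nn.TransformerEncoderLayer(d_model=64, nhead=8, batch_first=True)

x = torch.randn(2, 16, 64)   # (batch, sequence, embedding)
attn_out, _ = mha(x, x, x)   # self-attention
enc_out = encoder_layer(x)
print(attn_out.shape, enc_out.shape)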

Overview#

At a high level, this PyTorch function computes the scaled dot product attention (SDPA) between query, key, and value according to the definition found in the paper Attention is all you need. While this function can be written using existing PyTorch functions, a fused implementation can provide large performance benefits over a naive one.
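
As a point of reference, here is a rough, non-fused sketch of the computation the function performs (ignoring masking and dropout; the helper name naive_sdpa is ours, not part of the PyTorch API):

import math
import torch

def naive_sdpa(query, key, value):
    # Attention scores: (..., L, E) @ (..., E, S) -> (..., L, S), scaled by 1/sqrt(E)
    scale = 1.0 / math.sqrt(query.size(-1))
    attn_weight = torch.softmax(query @ key.transpose(-2, -1) * scale, dim=-1)
    # Weighted sum of values: (..., L, S) @ (..., S, E) -> (..., L, E)
    return attn_weight @ value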

Fused implementations#

For CUDA tensor inputs, the function will dispatch into one of the following implementations:

- FlashAttention
- Memory-Efficient Attention
- A PyTorch implementation defined in C++ (the math backend)
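
If you want a quick look at which fused backends are globally enabled before benchmarking, torch.backends.cuda exposes query functions; a small sketch (whether a backend is actually used for a given call still depends on dtype, shapes, and hardware):

import torch

# These report global enable flags, not per-input support.
print("Flash SDP enabled:           ", torch.backends.cuda.flash_sdp_enabled())
print("Memory-efficient SDP enabled:", torch.backends.cuda.mem_efficient_sdp_enabled())
print("Math SDP enabled:            ", torch.backends.cuda.math_sdp_enabled())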

Note

This tutorial requires PyTorch 2.0.0 or later.

import torch
import torch.nn as nn
import torch.nn.functional as F
device = "cuda" if torch.cuda.is_available() else "cpu"

# Example Usage:
query, key, value = torch.randn(2, 3, 8, device=device), torch.randn(2, 3, 8, device=device), torch.randn(2, 3, 8, device=device)
F.scaled_dot_product_attention(query, key, value)
tensor([[[-0.1471,  0.0784, -0.0581, -0.5448,  0.0610, -0.4824,  0.0488,
          -0.4969],
         [-0.3822, -0.5073, -0.2710, -0.7289,  0.1801, -0.2160, -0.0845,
          -0.1191],
         [-0.3038, -0.4056, -0.3013, -0.4887,  0.2677, -0.2204, -0.0220,
          -0.0686]],

        [[ 0.1158, -0.4914,  1.2867, -0.2343,  0.2195, -0.3615,  0.2703,
          -1.0827],
         [ 0.1269, -0.4285,  1.2088, -0.3356,  0.1363, -0.2540,  0.3196,
          -0.9992],
         [ 0.1568, -0.4041,  1.2502, -0.3362,  0.1297, -0.2964,  0.3251,
          -0.9878]]], device='cuda:0')

Explicit Dispatcher Control#

While the function will implicitly dispatch to one of the three implementations, the user can also explicitly control the dispatch by using a context manager. This context manager allows users to explicitly disable certain implementations. If users want to make sure the function is indeed using the fastest implementation for their specific inputs, the context manager can be used to sweep through and measure performance.
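
On recent releases (roughly PyTorch 2.3 and later, as an assumption) sdpa_kernel also accepts a list of backends, which restricts dispatch to any of the listed implementations rather than a single one; a minimal sketch:

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

device = "cuda" if torch.cuda.is_available() else "cpu"
q = k = v = torch.randn(2, 8, 64, 16, device=device)

# The math fallback is excluded here, so inputs that neither fused backend
# supports raise a RuntimeError instead of silently falling back.
try:
    with sdpa_kernel([SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]):
        out = F.scaled_dot_product_attention(q, k, v)
except RuntimeError:
    print("Neither fused backend supports these inputs on this machine.")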

# Let's define a helpful benchmarking function:
import torch.utils.benchmark as benchmark
def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
    t0 = benchmark.Timer(
        stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
    )
    return t0.blocked_autorange().mean * 1e6

# Let's define the hyper-parameters of our input
batch_size = 32
max_sequence_len = 1024
num_heads = 32
embed_dimension = 32

dtype = torch.float16

query = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)
key = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)
value = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)

print(f"The default implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")

# Let's explore the speed of each of the 3 implementations
from torch.nn.attention import SDPBackend, sdpa_kernel


with sdpa_kernel(SDPBackend.MATH):
    math_time=benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value)
    print(f"The math implementation runs in {math_time:.3f} microseconds")

with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    try:
        flash_time=benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value)
        print(f"The flash attention implementation runs in {flash_time:.3f} microseconds")
    except RuntimeError:
        print("FlashAttention is not supported. See warnings for reasons.")

with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
    try:
        efficient_time=benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value)
        print(f"The memory efficient implementation runs in {efficient_time:.3f} microseconds")
    except RuntimeError:
        print("EfficientAttention is not supported. See warnings for reasons.")
The default implementation runs in 2273.824 microseconds
The math implementation runs in 87451.818 microseconds
The flash attention implementation runs in 2281.270 microseconds
The memory efficient implementation runs in 4357.890 microseconds

Hardware dependence#

Depending on what machine you ran the above cell on and what hardware is available, your results might be different.

- If you don't have a GPU and are running on CPU, then with FP32 the context manager will have no effect and all three runs should return similar timings.
- Depending on what compute capability your graphics card supports, flash attention or memory-efficient attention might have failed; a quick way to check is sketched below.
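
As a rough, hedged check (the compute-capability threshold below is a heuristic, not an official support matrix), you can inspect what your GPU reports:

import torch

if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability()
    print(f"{torch.cuda.get_device_name()}: compute capability sm_{major}{minor}")
    # FlashAttention generally targets newer architectures (roughly sm_80 and up);
    # treat this as a rule of thumb rather than a guarantee.
    if major < 8:
        print("Flash attention may be unavailable; expect the memory-efficient or math path.")
else:
    print("No GPU detected; SDPA will run the CPU implementations.")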

Causal Self Attention#

Below is an example implementation of a multi-headed causal self attention block inspired by Andrej Karpathy's NanoGPT repository.

class CausalSelfAttention(nn.Module):

    def __init__(self, num_heads: int, embed_dimension: int, bias: bool=False, is_causal: bool=False, dropout:float=0.0):
        super().__init__()
        assert embed_dimension % num_heads == 0
        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(embed_dimension, 3 * embed_dimension, bias=bias)
        # output projection
        self.c_proj = nn.Linear(embed_dimension, embed_dimension, bias=bias)
        # regularization
        self.dropout = dropout
        self.resid_dropout = nn.Dropout(dropout)
        self.num_heads = num_heads
        self.embed_dimension = embed_dimension
        # Perform causal masking
        self.is_causal = is_causal

    def forward(self, x):
        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        query_projected = self.c_attn(x)

        batch_size = query_projected.size(0)
        embed_dim = query_projected.size(2)
        head_dim = embed_dim // (self.num_heads * 3)

        query, key, value = query_projected.chunk(3, -1)
        query = query.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)

        if self.training:
            dropout = self.dropout
            is_causal = self.is_causal
        else:
            dropout = 0.0
            is_causal = False

        y = F.scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=dropout, is_causal=is_causal)
        y = y.transpose(1, 2).view(batch_size, -1, self.num_heads * head_dim)

        y = self.resid_dropout(self.c_proj(y))
        return y


num_heads = 8
heads_per_dim = 64
embed_dimension = num_heads * heads_per_dim
dtype = torch.float16
model = CausalSelfAttention(num_heads=num_heads, embed_dimension=embed_dimension, bias=False, is_causal=True, dropout=0.1).to("cuda").to(dtype).eval()
print(model)
CausalSelfAttention(
  (c_attn): Linear(in_features=512, out_features=1536, bias=False)
  (c_proj): Linear(in_features=512, out_features=512, bias=False)
  (resid_dropout): Dropout(p=0.1, inplace=False)
)

NestedTensor and Dense tensor support#

SDPA supports both NestedTensor and Dense tensor inputs. NestedTensors handle the case where the input is a batch of variable-length sequences, without needing to pad every sequence to the maximum length in the batch. For more information about NestedTensors, see torch.nested and the NestedTensors tutorial.

import random
def generate_rand_batch(
    batch_size,
    max_sequence_len,
    embed_dimension,
    pad_percentage=None,
    dtype=torch.float16,
    device="cuda",
):
    if not pad_percentage:
        return (
            torch.randn(
                batch_size,
                max_sequence_len,
                embed_dimension,
                dtype=dtype,
                device=device,
            ),
            None,
        )
    # Random sequence lengths
    seq_len_list = [
        int(max_sequence_len * (1 - random.gauss(pad_percentage, 0.01)))
        for _ in range(batch_size)
    ]
    # Make random entry in the batch have max sequence length
    seq_len_list[random.randint(0, batch_size - 1)] = max_sequence_len
    return (
        torch.nested.nested_tensor(
            [
                torch.randn(seq_len, embed_dimension,
                            dtype=dtype, device=device)
                for seq_len in seq_len_list
            ]
        ),
        seq_len_list,
    )

random_nt, _ = generate_rand_batch(32, 512, embed_dimension, pad_percentage=0.5, dtype=dtype, device=device)
random_dense, _ = generate_rand_batch(32, 512, embed_dimension, pad_percentage=None, dtype=dtype, device=device)

# Currently the fused implementations don't support ``NestedTensor`` for training
model.eval()

with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    try:
        print(f"Random NT runs in {benchmark_torch_function_in_microseconds(model, random_nt):.3f} microseconds")
        print(f"Random Dense runs in {benchmark_torch_function_in_microseconds(model, random_dense):.3f} microseconds")
    except RuntimeError:
        print("FlashAttention is not supported. See warnings for reasons.")
/usr/local/lib/python3.10/dist-packages/torch/nested/__init__.py:250: UserWarning:

The PyTorch API of nested tensors is in prototype stage and will change in the near future. We recommend specifying layout=torch.jagged when constructing a nested tensor, as this layout receives active development, has better operator coverage, and works with torch.compile. (Triggered internally at /pytorch/aten/src/ATen/NestedTensorImpl.cpp:178.)

Random NT runs in 606.955 microseconds
Random Dense runs in 952.497 microseconds
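
The warning above recommends the torch.jagged layout when constructing nested tensors. A minimal sketch of what that looks like (the shapes are illustrative only, unrelated to the benchmark above):

import torch

# Two sequences of different lengths stored without padding; layout=torch.jagged
# is the layout the warning recommends for better operator coverage and
# torch.compile support.
nt_jagged = torch.nested.nested_tensor(
    [torch.randn(3, 8), torch.randn(5, 8)],
    layout=torch.jagged,
)
print(nt_jagged.is_nested, nt_jagged.layout)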

Using SDPA with torch.compile#

With the release of PyTorch 2.0, a new feature called torch.compile() was introduced, which can provide significant performance improvements over eager mode. Scaled dot product attention is fully composable with torch.compile(). To demonstrate this, let's compile the CausalSelfAttention module using torch.compile() and observe the resulting performance improvements.

batch_size = 32
max_sequence_len = 256
x = torch.rand(batch_size, max_sequence_len,
               embed_dimension, device=device, dtype=dtype)
print(
    f"The non compiled module runs in  {benchmark_torch_function_in_microseconds(model, x):.3f} microseconds")


compiled_model = torch.compile(model)
# Warm up: the first call triggers the actual compilation
compiled_model(x)
print(
    f"The compiled module runs in  {benchmark_torch_function_in_microseconds(compiled_model, x):.3f} microseconds")
The non compiled module runs in  425.073 microseconds
The compiled module runs in  544.064 microseconds

The exact execution time is dependent on the machine; however, the results for mine: the non-compiled module ran in 166.616 microseconds and the compiled module ran in 166.726 microseconds. That is not what we were expecting. Let's dig a little deeper. PyTorch comes with an amazing built-in profiler that you can use to inspect the performance characteristics of your code.

from torch.profiler import profile, record_function, ProfilerActivity
activities = [ProfilerActivity.CPU]
if device == 'cuda':
    activities.append(ProfilerActivity.CUDA)

with profile(activities=activities, record_shapes=False) as prof:
    with record_function(" Non-Compilied Causal Attention"):
        for _ in range(25):
            model(x)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))


with profile(activities=activities, record_shapes=False) as prof:
    with record_function("Compiled Causal Attention"):
        for _ in range(25):
            compiled_model(x)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

# For even more insights, you can export the trace and use ``chrome://tracing`` to view the results
#
# .. code-block:: python
#
#    prof.export_chrome_trace("compiled_causal_attention_trace.json")
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
                         Non-Compilied Causal Attention        16.90%       2.143ms        76.94%       9.754ms       9.754ms       0.000us         0.00%      10.836ms      10.836ms             1
                         Non-Compilied Causal Attention         0.00%       0.000us         0.00%       0.000us       0.000us      10.734ms       101.14%      10.734ms      10.734ms             1
                                           aten::linear         1.05%     133.531us        35.29%       4.474ms      89.487us       0.000us         0.00%       8.012ms     160.232us            50
                                           aten::matmul         2.01%     254.573us        31.58%       4.003ms      80.066us       0.000us         0.00%       8.012ms     160.232us            50
                                               aten::mm         9.80%       1.242ms        27.24%       3.453ms      69.057us       7.789ms        73.39%       8.012ms     160.232us            50
         ampere_fp16_s1688gemm_fp16_128x128_ldg8_f2f_tn         0.00%       0.000us         0.00%       0.000us       0.000us       5.572ms        52.51%       5.572ms     222.894us            25
                     aten::scaled_dot_product_attention         1.63%     207.093us        15.32%       1.943ms      77.701us       0.000us         0.00%       2.824ms     112.966us            25
              aten::_scaled_dot_product_flash_attention         2.35%     298.135us        13.69%       1.735ms      69.417us       0.000us         0.00%       2.824ms     112.966us            25
                         aten::_flash_attention_forward         2.38%     301.874us         9.51%       1.206ms      48.225us       2.824ms        26.61%       2.824ms     112.966us            25
void pytorch_flash::flash_fwd_kernel<Flash_fwd_kerne...         0.00%       0.000us         0.00%       0.000us       0.000us       2.824ms        26.61%       2.824ms     112.966us            25
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
Self CPU time total: 12.678ms
Self CUDA time total: 10.613ms

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
## Call CompiledFxGraph fi6oafvta3xbcesbp2mjppjhzijz...         0.00%       0.000us         0.00%       0.000us       0.000us      10.648ms       100.38%      10.648ms     425.919us            25
                              Compiled Causal Attention         7.05%     917.762us        86.03%      11.205ms      11.205ms       0.000us         0.00%      10.608ms      10.608ms             1
                             Torch-Compiled Region: 0/0         7.21%     938.613us        75.92%       9.888ms     395.521us       0.000us         0.00%      10.608ms     424.325us            25
                                       CompiledFunction         8.70%       1.133ms        66.38%       8.645ms     345.804us       0.000us         0.00%      10.608ms     424.325us            25
## Call CompiledFxGraph fi6oafvta3xbcesbp2mjppjhzijz...        19.62%       2.555ms        57.68%       7.513ms     300.503us       0.000us         0.00%      10.608ms     424.325us            25
                                               aten::mm         7.55%     982.923us        11.56%       1.506ms      30.115us       7.786ms        73.39%       7.786ms     155.715us            50
         ampere_fp16_s1688gemm_fp16_128x128_ldg8_f2f_tn         0.00%       0.000us         0.00%       0.000us       0.000us       5.570ms        52.51%       5.570ms     222.808us            25
              aten::_scaled_dot_product_flash_attention         1.79%     232.953us        12.95%       1.686ms      67.447us       0.000us         0.00%       2.822ms     112.896us            25
                         aten::_flash_attention_forward         2.40%     313.225us         8.96%       1.166ms      46.660us       2.822ms        26.61%       2.822ms     112.896us            25
void pytorch_flash::flash_fwd_kernel<Flash_fwd_kerne...         0.00%       0.000us         0.00%       0.000us       0.000us       2.822ms        26.61%       2.822ms     112.896us            25
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
Self CPU time total: 13.024ms
Self CUDA time total: 10.608ms

The previous code snippet generates a report of the top 10 PyTorch functions that consumed the most GPU execution time, for both the compiled and non-compiled module. The analysis shows that the majority of GPU time is concentrated in the same set of functions for both modules. The reason is that torch.compile is very good at removing the framework overhead associated with PyTorch. If your model is launching large, efficient CUDA kernels, which in this case CausalSelfAttention is, then the overhead of PyTorch can be hidden.

In reality, your module does not normally consist of a single CausalSelfAttention block. When experimenting with Andrej Karpathy's NanoGPT repository, compiling the module brought the time per train step from 6090.49ms down to 3273.17ms! This was done on commit ae3a8d5 of NanoGPT, training on the Shakespeare dataset.

Using SDPA with attn_bias subclasses#

# As of PyTorch 2.3, we have added a new submodule that contains tensor subclasses
# designed to be used with ``torch.nn.functional.scaled_dot_product_attention``.
# The module is named ``torch.nn.attention.bias`` and contains the following two
# utilities for generating causal attention variants:
#
# - ``torch.nn.attention.bias.causal_upper_left``
# - ``torch.nn.attention.bias.causal_lower_right``
#
# .. note::
#    The current argument ``is_causal`` in ``torch.nn.functional.scaled_dot_product_attention``
#    is the same as using ``torch.nn.attention.bias.causal_upper_left``.
#

from torch.nn.attention.bias import causal_lower_right, causal_upper_left

batch_size = 32
sequence_length_q = 2
sequence_length_kv = 10
num_heads = 16
embed_dimension = 32

dtype = torch.float16

query = torch.rand(batch_size, num_heads, sequence_length_q, embed_dimension, device=device, dtype=dtype)
key = torch.rand(batch_size, num_heads, sequence_length_kv, embed_dimension, device=device, dtype=dtype)
value = torch.rand(batch_size, num_heads, sequence_length_kv, embed_dimension, device=device, dtype=dtype)

upper_left_bias = causal_upper_left(sequence_length_q, sequence_length_kv)
lower_right_bias = causal_lower_right(sequence_length_q, sequence_length_kv)

print(type(upper_left_bias))
print(type(lower_right_bias))

assert type(upper_left_bias) == type(lower_right_bias)
assert issubclass(type(upper_left_bias), torch.Tensor)

# As you can see from the previous output, both biases are of the same type,
# ``torch.nn.attention.bias.CausalBias``, and both subclass ``torch.Tensor``.

# Let's see what these tensors look like
print(upper_left_bias)
print(lower_right_bias)

# Upper Left Bias aligns the causal attention mask to the upper left corner of the attention scores matrix.
# This only has an impact when the attention scores matrix is not square, which is common for decoding use cases.
# Assuming the attention score matrix is two dimensional, ``attn_score[0][0]`` is the attention score
# between the 0th token in the query and the 0th token in the key.
# Another way of thinking about this concept is that with upper left bias the 0th token in the query
# is aligned to the 0th token in the key, while with lower right bias the sequence of q is aligned
# so that the last token in q is aligned to the last token in k
# (for example, ``attn_score[-1][-1]`` is all True since the last token in q is at the same position
# as the last token in k, even if the sequence lengths of q and k are different).

# These objects are intended to be used with sdpa
out_upper_left = F.scaled_dot_product_attention(query, key, value, upper_left_bias)
out_lower_right = F.scaled_dot_product_attention(query, key, value, lower_right_bias)
out_is_causal = F.scaled_dot_product_attention(query, key, value, is_causal=True)

assert torch.allclose(out_upper_left, out_is_causal)
assert not torch.allclose(out_upper_left, out_lower_right)

# These attention biases should also be compatible with torch.compile
compiled_sdpa = torch.compile(F.scaled_dot_product_attention, fullgraph=True)
out_upper_left = compiled_sdpa(query, key, value, upper_left_bias)
<class 'torch.nn.attention.bias.CausalBias'>
<class 'torch.nn.attention.bias.CausalBias'>
tensor([[ True, False, False, False, False, False, False, False, False, False],
        [ True,  True, False, False, False, False, False, False, False, False]])
tensor([[ True,  True,  True,  True,  True,  True,  True,  True,  True, False],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True]])

Conclusion#

In this tutorial, we demonstrated the basic usage of torch.nn.functional.scaled_dot_product_attention. We showed how the sdpa_kernel context manager can be used to assert that a certain implementation is used on GPU. In addition, we built a simple CausalSelfAttention module that works with NestedTensor and is torch compilable. Along the way, we showed how profiling tools can be used to explore the performance characteristics of a user-defined module.

Total running time of the script: (0 minutes 7.096 seconds)