評價此頁

使用區域編譯減少 AoT 冷啟動編譯時間#

作者: Sayak Paul, Charles Bensimon, Angela Yi

區域編譯教程 中,我們展示瞭如何在保留(幾乎)全部編譯優勢的同時減少冷啟動編譯時間。這已針對即時 (JIT) 編譯進行了演示。

本教程展示瞭如何在提前 (AoT) 編譯模型時應用類似的原理。如果您不熟悉 AOTInductor 和 torch.export,我們建議您檢視 本教程

先決條件#

  • Pytorch 2.6 或更高版本

  • 熟悉區域編譯

  • 熟悉 AOTInductor 和 torch.export

設定#

在開始之前,我們需要安裝 torch,如果尚未安裝。

pip install torch

步驟#

在本教程中,我們將遵循上述區域編譯教程中的相同步驟

  1. 匯入所有必要的庫。

  2. 定義並初始化具有重複區域的神經網路。

  3. 測量完整模型和使用 AoT 進行區域編譯的編譯時間。

首先,讓我們匯入載入資料所需的庫

import torch
torch.set_grad_enabled(False)

from time import perf_counter

定義神經網路#

我們將使用與區域編譯教程相同的神經網路結構。

我們將使用一個由重複層組成的網路。這模擬了一個大型語言模型,它通常由許多 Transformer 塊組成。在本教程中,我們將使用 nn.Module 類建立一個 Layer 作為重複區域的代理。然後,我們將建立一個由 64 個此類 Layer 例項組成的 Model

class Layer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(10, 10)
        self.relu1 = torch.nn.ReLU()
        self.linear2 = torch.nn.Linear(10, 10)
        self.relu2 = torch.nn.ReLU()

    def forward(self, x):
        a = self.linear1(x)
        a = self.relu1(a)
        a = torch.sigmoid(a)
        b = self.linear2(a)
        b = self.relu2(b)
        return b


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(10, 10)
        self.layers = torch.nn.ModuleList([Layer() for _ in range(64)])

    def forward(self, x):
        # In regional compilation, the self.linear is outside of the scope of ``torch.compile``.
        x = self.linear(x)
        for layer in self.layers:
            x = layer(x)
        return x

提前編譯模型#

由於我們是提前編譯模型,因此需要準備代表性的輸入示例,我們期望模型在實際部署期間會看到這些示例。

讓我們建立一個 Model 例項,併為其提供一些樣本輸入資料。

model = Model().cuda()
input = torch.randn(10, 10, device="cuda")
output = model(input)
print(f"{output.shape=}")
output.shape=torch.Size([10, 10])

現在,讓我們提前編譯我們的模型。我們將使用上面建立的 input 傳遞給 torch.export。這將產生一個 torch.export.ExportedProgram,我們可以對其進行編譯。

/usr/local/lib/python3.10/dist-packages/torch/backends/cuda/__init__.py:131: UserWarning:

Please use the new API settings to control TF32 behavior, such as torch.backends.cudnn.conv.fp32_precision = 'tf32' or torch.backends.cuda.matmul.fp32_precision = 'ieee'. Old settings, e.g, torch.backends.cuda.matmul.allow_tf32 = True, torch.backends.cudnn.allow_tf32 = True, allowTF32CuDNN() and allowTF32CuBLAS() will be deprecated after Pytorch 2.9. Please see https://pytorch.com.tw/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices (Triggered internally at /pytorch/aten/src/ATen/Context.cpp:80.)

/usr/local/lib/python3.10/dist-packages/torch/_inductor/compile_fx.py:312: UserWarning:

TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. Consider setting `torch.set_float32_matmul_precision('high')` for better performance.

我們可以從該 path 載入並使用它來執行推理。

compiled_binary = torch._inductor.aoti_load_package(path)
output_compiled = compiled_binary(input)
print(f"{output_compiled.shape=}")
output_compiled.shape=torch.Size([10, 10])

提前編譯模型的 _區域_#

另一方面,提前編譯模型區域需要一些關鍵的更改。

由於計算模式被模型中所有重複的塊(在本例中為 Layer 例項)共享,因此我們可以僅編譯一個塊,然後讓 inductor 重用它。

model = Model().cuda()
path = torch._inductor.aoti_compile_and_package(
    torch.export.export(model.layers[0], args=(input,)),
    inductor_configs={
        # compile artifact w/o saving params in the artifact
        "aot_inductor.package_constants_in_so": False,
    }
)

匯出的程式(torch.export.ExportedProgram)包含張量計算、一個 state_dict,其中包含所有提升的引數和緩衝區的張量值以及其他元資料。我們將 aot_inductor.package_constants_in_so 設定為 False,以避免在生成的工件中序列化模型引數。

現在,在載入編譯後的二進位制檔案時,我們可以重用每個塊的現有引數。這使我們能夠利用上面獲得的編譯後的二進位制檔案。

for layer in model.layers:
    compiled_layer = torch._inductor.aoti_load_package(path)
    compiled_layer.load_constants(
        layer.state_dict(), check_full_update=True, user_managed=True
    )
    layer.forward = compiled_layer

output_regional_compiled = model(input)
print(f"{output_regional_compiled.shape=}")
output_regional_compiled.shape=torch.Size([10, 10])

與 JIT 區域編譯一樣,在模型內部提前編譯區域可以顯著減少冷啟動時間。實際數字會因模型而異。

儘管完整模型編譯提供了最廣泛的最佳化範圍,但出於實際目的,並且取決於模型型別,我們已經看到區域編譯(JIT 和 AoT)提供了相似的速度優勢,同時極大地減少了冷啟動時間。

測量編譯時間#

接下來,讓我們測量完整模型和區域編譯的編譯時間。

def measure_compile_time(input, regional=False):
    start = perf_counter()
    model = aot_compile_load_model(regional=regional)
    torch.cuda.synchronize()
    end = perf_counter()
    # make sure the model works.
    _ = model(input)
    return end - start

def aot_compile_load_model(regional=False) -> torch.nn.Module:
    input = torch.randn(10, 10, device="cuda")
    model = Model().cuda()

    inductor_configs = {}
    if regional:
        inductor_configs = {"aot_inductor.package_constants_in_so": False}

    # Reset the compiler caches to ensure no reuse between different runs
    torch.compiler.reset()
    with torch._inductor.utils.fresh_inductor_cache():
        path = torch._inductor.aoti_compile_and_package(
            torch.export.export(
                model.layers[0] if regional else model,
                args=(input,)
            ),
            inductor_configs=inductor_configs,
        )

        if regional:
            for layer in model.layers:
                compiled_layer = torch._inductor.aoti_load_package(path)
                compiled_layer.load_constants(
                    layer.state_dict(), check_full_update=True, user_managed=True
                )
                layer.forward = compiled_layer
        else:
            model = torch._inductor.aoti_load_package(path)
    return model

input = torch.randn(10, 10, device="cuda")
full_model_compilation_latency = measure_compile_time(input, regional=False)
print(f"Full model compilation time = {full_model_compilation_latency:.2f} seconds")

regional_compilation_latency = measure_compile_time(input, regional=True)
print(f"Regional compilation time = {regional_compilation_latency:.2f} seconds")

assert regional_compilation_latency < full_model_compilation_latency
Full model compilation time = 11.46 seconds
Regional compilation time = 4.76 seconds

模型中也可能存在與編譯不相容的層。因此,完整編譯將導致計算圖碎片化,從而可能導致延遲下降。在這種情況下,區域編譯可能是有益的。

結論#

本教程展示瞭如何在提前編譯模型時控制冷啟動時間。當模型具有重複塊時,這會變得有效,這在大型生成模型中通常會看到。我們在各種模型上使用了此教程來加速即時效能。在此處瞭解更多資訊:here

指令碼總執行時間: (0 分 42.500 秒)