評價此頁

運算元註冊#

創建於:2025年8月27日 | 最後更新於:2025年9月2日

對於新的加速器,整合中最重要和最基本方面之一就是支援高效能運算元。為了方便使用者和加速器開發者進行運算元適配,PyTorch 提供了多種方法來在 PythonC++ 中開發和註冊運算元。以下各節詳細介紹了 PyTorch 在運算元註冊方面的一些基本功能。

注意

Dispatch Key 用於唯一標識 PyTorch 中的加速器,例如 CPUCUDAMPSPrivateUse1。理論上,所有後續的新加速器都將共享 PrivateUse1,利用其內建的全面腳手架能力來完成新加速器的整合。如果您對 dispatcher 感興趣,請參閱 Let’s talk about the PyTorch dispatcher

運算元集#

PyTorch 目前擁有超過 3500 個內建運算元(包括相關的運算元變體)。這無論從哪個角度來看都是一項巨大的工作量,而且在短時間內支援如此龐大的運算元數量絕非易事。因此,作為開發新後端運算元的第一步,我們的目標應該是專注於核心運算元。對於其他運算元,我們可以首先使用社群的 fallback 機制作為優先事項來支援該功能。之後,我們可以逐步完成其他運算元,以提高新後端的效能。

所需的運算元集列在下面,主要包括工廠函式所需的低階運算元和 fallback 運算元。

運算元名稱

Dispatch Key

描述

empty.memory_format

PrivateUse1

使用指定的形狀和記憶體佈局(步幅自動計算)建立未初始化的 Tensor。

empty_strided

PrivateUse1

建立具有指定形狀和步幅的未初始化 Tensor(具有更大的自由度)。

as_strided

PrivateUse1

使用新的形狀、步幅和偏移量建立輸入 Tensor 的共享檢視(無需分配新記憶體)。

view

PrivateUse1

建立具有新形狀的輸入 Tensor 的共享檢視,但原始 Tensor 必須是記憶體連續的。

_reshape_alias

PrivateUse1

建立無安全檢查的共享檢視(reshape 的內部版本)。

resize_

PrivateUse1

就地修改 Tensor 的形狀,並在容量不足時重新分配記憶體。

_copy_from

PrivateUse1

Tensor.copy_ 的底層核心函式,負責實際的跨裝置資料複製。

_copy_from_and_resize

PrivateUse1

結合 resize__copy_from,先調整大小再複製。

_local_scalar_dense

PrivateUse1

.item() 的底層實現,將 Tensor 中的值提取為 CPU 標量。

set_.source_Tensor

PrivateUse1

使用指定的 Tensor 設定當前 Tensor。

set_.source_Storage

PrivateUse1

使用指定的 Storage 設定當前 Tensor。

set_.source_Storage_storage_offset

PrivateUse1

使用指定的 Storage 和儲存偏移量設定當前 Tensor。

fallback

PrivateUse1

回退到 CPU。

基礎#

現在我們已經定義了運算元支援的初始範圍,我們可以開始開發運算元適配。本節將根據實際場景,在 PythonC++ 中解釋這些實現。

第一步#

上面提到的運算元 具有一個共同點:它們是內建的 PyTorch 運算元,具有定義的 名稱空間Schema,並且這些運算元的內建加速器(CPUCUDA 等)已經實現。我們接下來要做的是為新加速器實現這些運算元。

 1at::Tensor empty_memory_format(
 2    c10::IntArrayRef size,
 3    std::optional<c10::ScalarType> dtype_opt,
 4    std::optional<c10::Layout> layout_opt,
 5    std::optional<c10::Device> device_opt,
 6    std::optional<bool> pin_memory_opt,
 7    std::optional<c10::MemoryFormat> memory_format_opt) {
 8  const auto device = c10::device_or_default(device_opt);
 9  const auto dtype = c10::dtype_or_default(dtype_opt);
10  TORCH_CHECK(device.is_privateuseone());
11  TORCH_CHECK(
12      c10::layout_or_default(layout_opt) == c10::Layout::Strided,
13      "Non strided layout not supported");
14  TORCH_CHECK(
15      !c10::pinned_memory_or_default(pin_memory_opt),
16      "Pin memory can only be on CPU");
17  const c10::DeviceGuard device_guard(device);
18  constexpr c10::DispatchKeySet pu1_dks(c10::DispatchKey::PrivateUse1);
19  auto allocator = at::GetAllocator(at::kPrivateUse1);
20  return at::detail::empty_generic(
21      size, allocator, pu1_dks, dtype, memory_format_opt);
22}
 1at::Tensor wrapper_empty_memory_format(
 2    c10::IntArrayRef size,
 3    std::optional<c10::ScalarType> dtype_opt,
 4    std::optional<c10::Layout> layout_opt,
 5    std::optional<c10::Device> device_opt,
 6    std::optional<bool> pin_memory_opt,
 7    std::optional<c10::MemoryFormat> memory_format_opt) {
 8  return at::native::openreg::empty_memory_format(
 9      size,
10      dtype_opt,
11      layout_opt,
12      device_opt,
13      pin_memory_opt,
14      memory_format_opt);
15}

empty.memory_format 運算元為例,我們首先需要在 native_functions.yaml 中查詢運算元的 schema 資訊,其中包含詳細的簽名信息。然後,我們可以根據新加速器的能力來實現該運算元。

- func: empty.memory_format(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
dispatch:
    CPU: empty_cpu
    CUDA: empty_cuda
    ...
 1TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
 2  m.impl("empty.memory_format", wrapper_empty_memory_format);
 3  m.impl("empty_strided", wrapper_empty_strided);
 4  m.impl("as_strided", wrapper_as_strided);
 5  m.impl("resize_", wrapper_resize_);
 6  m.impl("_reshape_alias", wrapper__reshape_alias);
 7  m.impl("_copy_from", wrapper__copy_from);
 8  m.impl("_copy_from_and_resize", wrapper__copy_from_and_resize);
 9  m.impl("_local_scalar_dense", wrapper__local_scalar_densor);
10  m.impl("set_.source_Tensor", wrapper_set_source_Tensor_);
11  m.impl("set_.source_Storage", wrapper_set_source_Storage_);
12  m.impl(
13      "set_.source_Storage_storage_offset",
14      wrapper_set_source_Storage_storage_offsetset_);
15  m.impl("view", wrapper_view);
16}

完成 wrapper_empty_memory_format 後,我們可以透過 TORCH_LIBRARY_IMPLPrivateUse1 註冊 aten::empty.memory_format

第二步#

按照 第一步,我們可以完成除 fallback 之外所有運算元的開發和註冊。接下來,為了支援與運算相關的運算元(例如數學運算和卷積運算),我們需要實現 fallback 語義的註冊。這是 PyTorch 框架提供的內建功能,可以將新加速器不支援的某些運算回退到 CPU 執行。對於正在開發的新後端,這是確保功能性的極其有效的方法,但會犧牲效能。

 1void cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
 2  static const std::unordered_set<c10::OperatorName> cpu_fallback_blocklist = {
 3      c10::OperatorName("aten::abs", ""),
 4      c10::OperatorName("aten::abs", "out"),
 5  };
 6
 7  const auto& op_name = op.schema().operator_name();
 8  if (cpu_fallback_blocklist.count(op_name)) {
 9    TORCH_CHECK(
10        false,
11        "Operator '",
12        op_name,
13        "' is not implemented for device openreg.");
14  } else {
15    at::native::cpu_fallback(op, stack);
16  }
17}
1void wrapper_cpu_fallback(
2    const c10::OperatorHandle& op,
3    torch::jit::Stack* stack) {
4  at::native::openreg::cpu_fallback(op, stack);
5}
1TORCH_LIBRARY_IMPL(_, PrivateUse1, m) {
2  m.fallback(
3      torch::CppFunction::makeFromBoxedFunction<&wrapper_cpu_fallback>());
4}

wrapper_cpu_fallback 封裝了 PyTorch 提供的 at::native::cpu_fallback 方法,並透過 TORCH_LIBRARY_IMPL 在 PyTorch 中註冊到 PrivateUse1。後續新後端不支援的操作將自動回退到 CPU 執行,執行完成後結果將傳回新後端。

高階#

選擇性回退#

僅為某些運算元啟用回退機制,而其他運算元則遵循 PyTorch 的預設行為(如果加速器沒有相應的運算元實現,則會報錯),這是一種非常合理的場景。

1void wrapper_cpu_fallback(
2    const c10::OperatorHandle& op,
3    torch::jit::Stack* stack) {
4  at::native::openreg::cpu_fallback(op, stack);
5}
1TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
2  m.impl(
3      "sub.Tensor",
4      torch::CppFunction::makeFromBoxedFunction<&wrapper_cpu_fallback>());
5}

每個運算元的回退與全域性回退非常相似,唯一的區別在於註冊方法:呼叫 m.impl 為特定運算元註冊實現,而 m.fallback 為所有運算元註冊預設實現。

 1void cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
 2  static const std::unordered_set<c10::OperatorName> cpu_fallback_blocklist = {
 3      c10::OperatorName("aten::abs", ""),
 4      c10::OperatorName("aten::abs", "out"),
 5  };
 6
 7  const auto& op_name = op.schema().operator_name();
 8  if (cpu_fallback_blocklist.count(op_name)) {
 9    TORCH_CHECK(
10        false,
11        "Operator '",
12        op_name,
13        "' is not implemented for device openreg.");
14  } else {
15    at::native::cpu_fallback(op, stack);
16  }
17}

當然,全域性回退也可以與回退黑名單結合使用,這是一種常見的方法,尤其是在只有少數運算元不支援回退的情況下。

PyTorch STUB#

PyTorch 還為內建運算元提供了另一種方法:STUB。此方法本質上基於 第一步<step-one> 方法,但增加了二次排程功能(例如,基於 CPU 特性的排程)。

注意

STUB 方法目前僅支援有限的運算元集。對於新的加速器裝置,STUB 方法的優勢在於它以少量的效能開銷為代價,大大降低了開發成本。PyTorch 目前沒有明確列出可以透過 STUB 註冊的運算元集。由於相關運算元數量龐大,此處僅提供支援運算元列表的查詢方法。

pushd ${TORCH_ROOT}

find aten -type f -a -name "*.h" | xargs -I {} grep -wl "^DECLARE_DISPATCH" {}

popd

DECLARE_DISPATCH 是一個宏,用於顯式宣告 STUB。它目前分佈在 aten 目錄中。基於此宏,您可以找到所有可以使用 STUB 方法整合的運算元。

...
aten/src/ATen/native/Activation.h
aten/src/ATen/native/FusedSGD.h
aten/src/ATen/native/nested/NestedTensorBinaryOps.h
aten/src/ATen/native/TensorCompare.h
aten/src/ATen/native/Sorting.h
...
using unary_fn = void(*)(TensorIteratorBase&);

DECLARE_DISPATCH(unary_fn, abs_stub)

上面的列表包含宣告 STUB 運算元的檔案,您可以在其中清楚地看到 STUB 名稱和相關的函式簽名。接下來,我們將以 abs_stub 為例,簡要介紹透過 STUB 支援運算元的路徑。

 1void abs_kernel(at::TensorIteratorBase& iter) {
 2  TORCH_CHECK(iter.ntensors() == 2, "Abs kernel expects 2 tensors");
 3  TORCH_CHECK(
 4      iter.common_dtype() == at::ScalarType::Float,
 5      "Abs kernel only supports float type");
 6
 7  auto& output_tensor = iter.tensor(0);
 8  auto& input_tensor = iter.tensor(1);
 9
10  TORCH_CHECK(
11      input_tensor.sizes() == output_tensor.sizes(),
12      "Input and output tensor sizes must match.");
13
14  auto abs_loop = [](float* out_ptr, const float* in_ptr, int64_t n) {
15    for (int64_t i = 0; i < n; ++i) {
16      out_ptr[i] = std::abs(in_ptr[i]);
17    }
18  };
19
20  MemoryGuard guard(input_tensor, output_tensor);
21
22  if (iter.is_contiguous()) {
23    abs_loop(
24        static_cast<float*>(iter.data_ptr(0)),
25        static_cast<float*>(iter.data_ptr(1)),
26        iter.numel());
27  } else {
28    TORCH_CHECK(
29        input_tensor.is_contiguous(), "Input tensor must be contiguous.")
30
31    auto output = at::empty(
32        input_tensor.sizes(),
33        input_tensor.options().memory_format(
34            input_tensor.suggest_memory_format()));
35
36    MemoryGuard guard(output);
37
38    abs_loop(
39        static_cast<float*>(output.data_ptr()),
40        static_cast<float*>(iter.data_ptr(1)),
41        iter.numel());
42
43    output_tensor.copy_(output);
44  }
45}
1REGISTER_PRIVATEUSE1_DISPATCH(abs_stub, &wrapper_abs_stub);
2REGISTER_PRIVATEUSE1_DISPATCH(
3    quantize_tensor_per_tensor_affine_stub,
4    &wrapper_quantize_tensor_per_tensor_affine_stub);
5REGISTER_PRIVATEUSE1_DISPATCH(
6    _fused_sdp_choice_stub,
7    &wrapper__fused_sdp_choice);

從簽名可以看出,abs_stub 的輸入是 TensorIteratorBase,這是 PyTorch 提供的一個強大的輔助類,包含所有輸入和輸出運算元以及一些其他輔助方法。基於此,我們可以開發 abs_kernel 運算元,然後呼叫 REGISTER_PRIVATEUSE1_DISPATCH 來指定 abs_stub 以完成註冊。

自定義運算元#

除了 PyTorch 的內建運算元,自定義加速器運算元在提高特定場景下的效能方面也非常常見。這些可以分為三種主要方法:

  • 僅前向傳播

  • 前向和後向傳播:分開註冊

  • 前向和後向傳播:使用 torch.autograd.Function 實現

注意

PyTorch 教程中有更多細節,如果您有興趣,請參考 PyTorch Custom Operators

僅前向傳播#

這裡,我們將簡要介紹自定義運算元的實現過程,重點關注僅前向傳播的方法。實現可以概括為以下三點:

  1. 定義 Schema

    1TORCH_LIBRARY(openreg, m) {
    2  m.def("custom_abs(Tensor input)-> Tensor");
    3}
    
    • 名稱空間名稱:openreg

    • 函式名稱:custom_abs

    • 輸入引數

      • 型別:Tensor

      • 名稱:input

    • 輸出型別:Tensor

  2. 註冊運算元與自動求導回退

    1TORCH_LIBRARY_IMPL(openreg, PrivateUse1, m) {
    2  m.impl("custom_abs", &wrapper_custom_abs);
    3}
    
    1TORCH_LIBRARY_IMPL(_, AutogradPrivateUse1, m) {
    2  m.fallback(torch::autograd::autogradNotImplementedFallback());
    3}
    

    使用 TORCH_LIBRARY_IMPLPrivateUse1 中的 custom_abs 運算元註冊 wrapper_custom_abs 實現。但是,由於 PyTorch 中始終啟用 Autograd,即使只需要前向計算,PyTorch 也會預設查詢並執行相應的後向實現(將在後向實現中 fallthrough)。因此,我們還需要為 custom_abs 運算元的 AutogradPrivateUse1 註冊相應的實現。幸運的是,PyTorch 還提供了一個通用的 Autograd Fallback 機制,名為 torch::autograd::autogradNotImplementedFallback,如果僅涉及前向計算,它相當於一個 fallthrough 操作,選擇下一個 DispatchKey 進行計算;如果涉及後向計算,則會丟擲錯誤。

  3. 註冊元資料(可選,但圖模式等需要)

    1lib = torch.library.Library("openreg", "IMPL", "Meta")  # noqa: TOR901
    2
    3
    4@torch.library.impl(lib, "custom_abs")
    5def custom_abs(self):
    6    return torch.empty_like(self)
    7
    8
    

    PyTorch 支援在 C++ 和 Python 中註冊 Meta。由於 Python 註冊更簡單,因此此處以 Python 為例。與 C++ 中的 TORCH_LIBRARY_IMPL 函式類似,Python 提供了更友好的 torch.library.impl 裝飾器。

工具#

PyTorch 中的運算元註冊很複雜,註冊方法多樣,場景眾多。因此,PyTorch 社群提供了一些工具來幫助開發者快速理解底層原理並協助故障排除。在此我們簡要介紹幾種常用工具:

命令#

PyTorch 提供了圍繞其 Dispatch 功能的一系列以 torch._C._dispatch_ 開頭的命令。您可以使用以下命令查詢所有相關的介面。

python -c 'import torch; print("\n".join([x for x in dir(torch._C) if x.startswith("_dispatch_")]))'

...
_dispatch_dump
_dispatch_dump_table
_dispatch_has_kernel
_dispatch_has_kernel_for_any_dispatch_key
_dispatch_has_kernel_for_dispatch_key
_dispatch_isTensorSubclassLike
_dispatch_is_alias_key
_dispatch_is_included_in_alias
_dispatch_is_main_interpreter
_dispatch_kernel_for_dispatch_key_is_fallthrough
_dispatch_key_for_device
_dispatch_key_name
_dispatch_key_parse
_dispatch_key_set
...

以下是幾個常用命令的解釋:

  • torch._C._dispatch_key_set:

    顯示當前 Tensor 的 DispatchKey,優先順序從左到右遞增。

    >>> import torch
    >>> a = torch.randn(3,3,device="cuda")
    >>> torch._C._dispatch_key_set(a)
    'DispatchKeySet(CUDA, ADInplaceOrView, AutogradCUDA, AutocastCUDA)'
    
  • torch._C._dispatch_dump_table:

    查詢給定運算元在不同 Dispatch Keys 下的支援狀態,方便定位對應的實現程式碼。

    >>> import torch
    >>> print(torch._C._dispatch_dump_table("aten::add.Tensor"))
    >>> ...
        CPU: registered at ./build/aten/src/ATen/RegisterCPU_0.cpp:1309 [kernel]
        CUDA: registered at ./build/aten/src/ATen/RegisterCUDA_0.cpp:2420 [kernel]
        HIP: registered at ./build/aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional_0.cpp:1373 [default backend kernel]
        MPS: registered at ./build/aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional_0.cpp:1373 [default backend kernel]
        IPU: registered at ./build/aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional_0.cpp:1373 [default backend kernel]
        XPU: registered at ./build/aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional_0.cpp:1373 [default backend kernel]
        HPU: registered at ./build/aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional_0.cpp:1373 [default backend kernel]
        VE: registered at ./build/aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional_0.cpp:1373 [default backend kernel]
        MTIA: registered at ./build/aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional_0.cpp:1373 [default backend kernel]
        MAIA: registered at ./build/aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional_0.cpp:1373 [default backend kernel]
        PrivateUse1: registered at ./build/aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional_0.cpp:1373 [default backend kernel]
        ...
    

    您可以輕鬆地查詢其他平臺上 aten::add.Tensor 運算元的對應實現,從而可以從原始碼級別跟蹤整個運算元呼叫過程。

環境變數#

PyTorch 還提供了一些與 dispatcher 相關的環境變數,有助於學習和快速定位問題。

  • TORCH_SHOW_DISPATCH_TRACE

    顯示 PyTorch 執行過程中的詳細內部 dispatch key 排程。

    export TORCH_SHOW_DISPATCH_TRACE=1
    
    >>> import torch
    >>> a = torch.randn(3,3)
     [call] op=[aten::randn], key=[BackendSelect]
       [redispatch] op=[aten::randn], key=[CPU]
         [call] op=[aten::empty.memory_format], key=[BackendSelect]
           [redispatch] op=[aten::empty.memory_format], key=[CPU]
         [call] op=[aten::normal_], key=[CPU]
    

    您可以清楚地看到 PyTorch 中 Python 級運算元呼叫的所有底層運算元:包括運算元名稱、呼叫層級以及對應的 Dispatch Key