評價此頁

PyTorch 自定義運算元#

建立時間: 2024年6月18日 | 最後更新: 2025年7月31日 | 最後驗證: 2024年11月5日

PyTorch 提供了大量可在 Tensor 上操作的運算元(例如 torch.add, torch.sum 等)。然而,您可能希望將新的自定義操作引入 PyTorch,並使其能夠與 torch.compile、autograd 和 torch.vmap 等子系統協同工作。要做到這一點,您必須透過 Python 的 torch.library 文件 或 C++ 的 TORCH_LIBRARY API 將自定義操作註冊到 PyTorch。

從 Python 編寫自定義運算元#

請參閱 自定義 Python 運算元

如果您希望將一個 Python 函式作為不透明的可呼叫物件由 PyTorch 處理,特別是在 torch.compiletorch.export 方面,那麼您可能希望從 Python 編寫自定義運算元(而不是 C++)。

  • 您有一個希望 PyTorch 將其視為不透明的可呼叫物件的 Python 函式,特別是在 torch.compiletorch.export 方面。

  • 您有一些到 C++/CUDA 核心的 Python 繫結,並希望這些繫結能夠與 PyTorch 子系統(如 torch.compiletorch.autograd)組合。

  • 您正在使用 Python(而不是 AOTInductor 等純 C++ 環境)。

將自定義 C++ 和/或 CUDA 程式碼整合到 PyTorch#

請參閱 自定義 C++ 和 CUDA 運算元

注意

SYCL 是 Intel GPU 的後端程式語言。整合自定義 Sycl 程式碼請參考 自定義 SYCL 運算元

如果您希望從 C++ 編寫自定義運算元(而不是 Python),那麼您可能希望這樣做,如果

  • 您有自定義的 C++ 和/或 CUDA 程式碼。

  • 您計劃使用此程式碼與 AOTInductor 進行無 Python 推理。

自定義運算元手冊#

有關教程和本頁面未涵蓋的資訊,請參閱 自定義運算元手冊(我們正在努力將資訊遷移到我們的文件網站)。我們建議您首先閱讀以上任一教程,然後使用自定義運算元手冊作為參考;它不適合從頭到尾閱讀。

何時應該建立自定義運算元?#

如果您的操作可以表示為內建 PyTorch 操作的組合,那麼請將其寫成一個 Python 函式並呼叫它,而不是建立自定義運算元。如果您正在呼叫 PyTorch 無法識別的某個庫(例如,自定義 C/C++ 程式碼、自定義 CUDA 核心或 C/C++/CUDA 擴充套件的 Python 繫結),請使用運算元註冊 API 來建立自定義運算元。

為什麼應該建立自定義運算元?#

可以透過獲取 Tensor 的資料指標並將其傳遞給 pybind 核心來使用 C/C++/CUDA 核心。但是,這種方法無法與 autograd、torch.compile、vmap 等 PyTorch 子系統組合。為了使一個操作能夠與 PyTorch 子系統組合,它必須透過運算元註冊 API 進行註冊。