Torch 庫 API#
PyTorch C++ API 提供了擴充套件 PyTorch 核心運算元庫以支援使用者自定義運算元和資料型別的能力。使用 Torch Library API 實現的擴充套件可以同時用於 PyTorch 的 eager API 和 TorchScript。
要了解該庫 API 的教程風格介紹,請參閱 使用自定義 C++ 運算元擴充套件 TorchScript 教程。
宏#
-
TORCH_LIBRARY(ns, m) static void TORCH_LIBRARY_init_##ns(torch::Library
&); \
static const
torch::detail::TorchLibraryInitTORCH_LIBRARY_static_init_##ns(
torch::Library::DEF, \
&TORCH_LIBRARY_init_##ns, \
#ns, \
std::nullopt, \
__FILE__, \
__LINE__); \
void TORCH_LIBRARY_init_##ns(
torch::Library& m) 用於定義一個將在靜態初始化時執行的函式,以在名稱空間
ns(必須是有效的 C++ 識別符號,不帶引號)中定義運算元庫。當您想定義一組 PyTorch 中尚不存在的新的自定義運算元時,請使用此宏。
使用示例
TORCH_LIBRARY(myops, m) { // m is a torch::Library; methods on it will define // operators in the myops namespace m.def("add", add_impl); }
m引數繫結到一個 torch::Library 物件,該物件用於註冊運算元。對於任何給定的名稱空間,只能有一個 TORCH_LIBRARY()。
-
TORCH_LIBRARY_IMPL(ns, k, m) _TORCH_LIBRARY_IMPL(ns, k, m, C10_UID)
用於定義一個將在靜態初始化時執行的函式,以在名稱空間
ns(必須是有效的 C++ 識別符號,不帶引號)中為排程鍵k(必須是 c10::DispatchKey 的未限定列舉成員)定義運算元過載。當您想為預先存在的自定義運算元集提供新的排程鍵實現時(例如,您想為已存在的運算元提供 CUDA 實現),請使用此宏。一種常見的用法是使用 TORCH_LIBRARY() 來定義您想定義的所有新運算元的 schema,然後使用多個 TORCH_LIBRARY_IMPL() 塊來為 CPU、CUDA 和 Autograd 提供運算元實現。
在某些情況下,您需要定義適用於所有名稱空間(而不僅僅是單個名稱空間)的內容(通常是回退)。在這種情況下,請使用保留的名稱空間 _,例如:
TORCH_LIBRARY_IMPL(_, XLA, m) { m.fallback(xla_fallback); }
使用示例
TORCH_LIBRARY_IMPL(myops, CPU, m) { // m is a torch::Library; methods on it will define // CPU implementations of operators in the myops namespace. // It is NOT valid to call torch::Library::def() // in this context. m.impl("add", add_cpu_impl); }
如果
add_cpu_impl是一個過載函式,請使用static_cast來指定您想要的過載(透過提供完整型別)。
類#
-
class Library
此物件提供了用於定義運算元和在排程鍵處提供實現的 API。
通常,torch::Library 物件不會被直接分配;而是由 TORCH_LIBRARY() 或 TORCH_LIBRARY_IMPL() 宏建立。
torch::Library 的大多數方法都返回其自身的引用,支援方法鏈式呼叫。
// Examples: TORCH_LIBRARY(torchvision, m) { // m is a torch::Library m.def("roi_align", ...); ... } TORCH_LIBRARY_IMPL(aten, XLA, m) { // m is a torch::Library m.impl("add", ...); ... }
公共函式
-
Library(const Library&) = delete
-
Library(Library&&) = default
-
~Library() = default
-
inline Library &def(c10::FunctionSchema &&s, const std::vector<at::Tag> &tags = {}, _RegisterOrVerify rv = _RegisterOrVerify::REGISTER) &
宣告一個運算元,但不提供任何實現。
您需要稍後使用 impl() 方法提供實現。所有模板引數都將自動推斷。
// Example: TORCH_LIBRARY(myops, m) { m.def("add(Tensor self, Tensor other) -> Tensor"); }
- 引數
raw_schema – 要定義的運算元的 schema。通常是一個
const char*字串字面量,但 torch::schema() 接受的任何型別都可以在這裡使用。
-
inline Library &def(const char *raw_schema, const std::vector<at::Tag> &tags = {}, _RegisterOrVerify rv = _RegisterOrVerify::REGISTER) &
-
inline Library &set_python_module(const char *pymodule, const char *context = "")
宣告對於後續定義的所有運算元,其偽實現(fake impls)可以在給定的 Python 模組 (pymodule) 中找到。
這會註冊一些幫助文字,在找不到偽實現時會用到。
引數 (Args)
pymodule: python 模組
context: 我們可以將其包含在錯誤訊息中。
-
inline Library &impl_abstract_pystub(const char *pymodule, const char *context = "")
已棄用;請使用 set_python_module。
-
template<typename NameOrSchema, typename Func>
inline Library &def(NameOrSchema &&raw_name_or_schema, Func &&raw_f, const std::vector<at::Tag> &tags = {}) & 定義一個運算元,然後為其註冊一個實現。
這通常是您在不打算利用排程器來組織運算元實現時使用的。它大致等同於呼叫 def() 然後呼叫 impl(),但如果您省略了運算元的 schema,我們將從 C++ 函式的型別中推斷出來。所有模板引數都將自動推斷。
// Example: TORCH_LIBRARY(myops, m) { m.def("add", add_fn); }
- 引數
raw_name_or_schema – 要定義的運算元的 schema,或者只是運算元的名稱(如果 schema 是從
raw_f推斷的)。通常是一個const char*字面量。raw_f – 實現此運算元的 C++ 函式。這裡接受任何有效的 torch::CppFunction 建構函式;通常您提供一個函式指標或 lambda。
-
template<typename Name, typename Func>
inline Library &impl(Name name, Func &&raw_f, _RegisterOrVerify rv = _RegisterOrVerify::REGISTER) & 註冊一個運算元的實現。
您可以為單個運算元在不同的排程鍵處註冊多個實現(請參閱 torch::dispatch())。實現必須有一個對應的宣告(來自 def()),否則它們是無效的。如果您計劃註冊多個實現,請不要在 def() 運算元時提供函式實現。
// Example: TORCH_LIBRARY_IMPL(myops, CUDA, m) { m.impl("add", add_cuda); }
- 引數
name – 要實現的運算元的名稱。請勿在此處提供 schema。
raw_f – 實現此運算元的 C++ 函式。這裡接受任何有效的 torch::CppFunction 建構函式;通常您提供一個函式指標或 lambda。
-
c10::OperatorName _resolve(const char *name) const
-
inline Library &def(detail::SelectiveStr<false>, const std::vector<at::Tag> &tags[
[maybe_unused]] = {}) &
-
inline Library &def(detail::SelectiveStr<true> raw_schema, const std::vector<at::Tag> &tags = {}) &
-
template<typename Func>
inline Library &def(detail::SelectiveStr<false>, Func&&, const std::vector<at::Tag> &tags[[maybe_unused]] = {}) &
-
template<typename Func>
inline Library &def(detail::SelectiveStr<true> raw_name_or_schema, Func &&raw_f, const std::vector<at::Tag> &tags = {}) &
-
template<typename Func>
inline Library &impl(detail::SelectiveStr<false>, Func&&) &
-
template<typename Dispatch, typename Func>
inline Library &impl(detail::SelectiveStr<false>, Dispatch&&key, Func&&raw_f) &
-
template<typename Func>
inline Library &impl_UNBOXED(detail::SelectiveStr<false>, Func*) &
-
template<typename Func>
inline Library &impl(detail::SelectiveStr<true> name, Func &&raw_f) &
-
template<typename Dispatch, typename Func>
inline Library &impl(detail::SelectiveStr<true> name, Dispatch &&key, Func &&raw_f) &
-
template<typename Func>
inline Library &impl_UNBOXED(detail::SelectiveStr<true>, Func*) &
-
template<typename Func>
inline Library &fallback(Func &&raw_f) & 註冊一個運算元的回退實現,當沒有特定運算元實現可用時將被使用。
回退必須關聯一個排程鍵;例如,只能從名稱空間為
_的 TORCH_LIBRARY_IMPL() 呼叫此函式;例如,只有當名稱空間為_的 TORCH_LIBRARY_IMPL() 呼叫此函式。// Example: TORCH_LIBRARY_IMPL(_, AutogradXLA, m) { // If there is not a kernel explicitly registered // for AutogradXLA, fallthrough to the next // available kernel m.fallback(torch::CppFunction::makeFallthrough()); } // See aten/src/ATen/core/dispatch/backend_fallback_test.cpp // for a full example of boxed fallback
- 引數
raw_f – 實現回退的函式。未打包的函式(Unboxed functions)通常不能用作回退函式,因為回退函式必須適用於每個運算元(即使它們具有不同的型別簽名)。常見的引數是 CppFunction::makeFallthrough() 或 CppFunction::makeFromBoxedFunction()。
-
template<class CurClass>
inline torch::class_<CurClass> class_(detail::SelectiveStr<true> className)
-
template<class CurClass>
inline detail::ClassNotSelected class_(detail::SelectiveStr<false> className)
-
void reset()
-
template<class CurClass>
inline class_<CurClass> class_(detail::SelectiveStr<true> className)
友元
- friend class detail::TorchLibraryInit
-
Library(const Library&) = delete
-
class CppFunction
表示實現運算元的 C++ 函式。
大多數使用者不會直接與此類互動,除了在錯誤訊息中:此函式提供的建構函式定義了可以透過介面繫結的“類函式”的允許集合。
此類擦除了所傳遞函式的型別,但透過從函式推斷出的 schema 來持久記錄型別。
公共函式
-
template<typename Func>
inline explicit CppFunction(Func *f, std::enable_if_t<c10::guts::is_function_type<Func>::value, std::nullptr_t> = nullptr) 此過載接受函式指標,例如
CppFunction(&add_impl)。
-
template<typename FuncPtr>
inline explicit CppFunction(FuncPtr f, std::enable_if_t<c10::is_compile_time_function_pointer<FuncPtr>::value, std::nullptr_t> = nullptr) 此過載接受編譯時函式指標,例如
CppFunction(TORCH_FN(add_impl))。
-
template<typename Lambda>
inline explicit CppFunction(Lambda &&f, std::enable_if_t<c10::guts::is_functor<std::decay_t<Lambda>>::value, std::nullptr_t> = nullptr) 此過載接受 lambda,例如
CppFunction([](const Tensor& self) { ...。})
-
~CppFunction()
-
CppFunction(const CppFunction&) = delete
-
CppFunction &operator=(const CppFunction&) = delete
-
CppFunction(CppFunction&&) noexcept = default
-
CppFunction &operator=(CppFunction&&) = default
-
inline CppFunction &&debug(std::string d) &&
公共靜態函式
-
static inline CppFunction makeFallthrough()
建立回退函式。
回退函式會立即重新排程到下一個可用的排程鍵,但其實現比手動編寫的相同功能的函式更有效率。
-
template<c10::BoxedKernel::BoxedKernelFunction *func>
static inline CppFunction makeFromBoxedFunction() 從簽名
void(const OperatorHandle&, Stack*)的打包核心函式(boxed kernel function)建立函式;即,它們以打包的呼叫約定接收引數,而不是以原生的 C++ 呼叫約定。打包函式通常只用於透過 torch::Library::fallback() 註冊後端回退。
-
template<c10::BoxedKernel::BoxedKernelFunction_withDispatchKeys *func>
static inline CppFunction makeFromBoxedFunction()
-
template<class KernelFunctor>
static inline CppFunction makeFromBoxedFunctor(std::unique_ptr<KernelFunctor> kernelFunctor) 從打包的核心函式(boxed kernel functor)建立函式,該函式定義了
operator()(const OperatorHandle&, DispatchKeySet, Stack*)(以打包的呼叫約定接收引數),並繼承自c10::OperatorKernel。與makeFromBoxedFunction不同,以這種方式註冊的函式還可以攜帶由 functor 管理的附加狀態;如果您正在為其他實現(例如 Python 可呼叫物件)編寫介面卡,並且該介面卡已動態關聯到已註冊的核心,則這會很有用。
-
template<typename FuncPtr, std::enable_if_t<c10::guts::is_function_type<FuncPtr>::value, std::nullptr_t> = nullptr>
static inline CppFunction makeFromUnboxedFunction(FuncPtr *f) 從非裝箱核心函式建立函式。
這通常用於註冊常用運算子。
-
template<typename FuncPtr, std::enable_if_t<c10::is_compile_time_function_pointer<FuncPtr>::value, std::nullptr_t> = nullptr>
static inline CppFunction makeFromUnboxedFunction(FuncPtr f) 從編譯時非裝箱核心函式指標建立函式。
這通常用於註冊常用運算子。編譯時函式指標可用於讓編譯器最佳化(例如內聯)對其的呼叫。
-
template<typename Func>
函式#
-
template<typename Func>
inline CppFunction dispatch(c10::DispatchKey k, Func &&raw_f)# 建立一個與特定 dispatch key 關聯的 torch::CppFunction。
torch::CppFunctions 標記了 c10::DispatchKey,除非排程器確定應該排程此特定的 c10::DispatchKey,否則它們不會被呼叫。
此函式通常不直接使用,而是首選使用 TORCH_LIBRARY_IMPL(),它將隱式設定其正文中所有註冊呼叫的 c10::DispatchKey。
-
template<typename Func>
inline CppFunction dispatch(c10::DeviceType type, Func &&raw_f)# 接受 c10::DeviceType 的 dispatch() 的便利過載。
-
inline c10::FunctionSchema schema(const char *str, c10::AliasAnalysisKind k, bool allow_typevars = false)#
從字串構造 c10::FunctionSchema,並顯式指定 c10::AliasAnalysisKind。
通常,schema 只需作為字串傳遞,但如果您需要指定自定義別名分析,則可以用對此函式的呼叫替換字串。
// Default alias analysis (FROM_SCHEMA) m.def("def3(Tensor self) -> Tensor"); // Pure function alias analysis m.def(torch::schema("def3(Tensor self) -> Tensor", c10::AliasAnalysisKind::PURE_FUNCTION));
-
inline c10::FunctionSchema schema(const char *s, bool allow_typevars = false)#
函式 schema 可以直接從字串字面量構造。