評價此頁

torch.library#

創建於:2022年6月13日 | 最後更新於:2025年8月13日

torch.library 是用於擴充套件 PyTorch 核心運算元庫的 API 集合。它包含用於測試自定義運算元、建立新的自定義運算元以及擴充套件使用 PyTorch C++ 運算元註冊 API(例如 aten 運算元)定義的運算元的實用工具。

有關有效使用這些 API 的詳細指南,請參閱 PyTorch 自定義運算元著陸頁,瞭解有關如何有效使用這些 API 的更多詳細資訊。

測試自定義運算元#

使用 torch.library.opcheck() 測試自定義運算元是否正確使用了 Python torch.library 和/或 C++ TORCH_LIBRARY API。此外,如果您的運算元支援訓練,請使用 torch.autograd.gradcheck() 來測試梯度是否在數學上是正確的。

torch.library.opcheck(op, args, kwargs=None, *, test_utils=('test_schema', 'test_autograd_registration', 'test_faketensor', 'test_aot_dispatch_dynamic'), raise_exception=True, atol=None, rtol=None)[source]#

給定一個運算元和一些示例引數,測試該運算元是否已正確註冊。

也就是說,當您使用 torch.library/TORCH_LIBRARY API 建立自定義運算元時,您指定了自定義運算元的元資料(例如,可變性資訊),而這些 API 要求您傳遞的函式滿足特定屬性(例如,Fake/Meta/Abstract 核心中沒有資料指標訪問)`opcheck` 會測試這些元資料和屬性。

具體來說,我們測試以下內容:

  • test_schema: 如果模式與運算元的實現匹配。例如:如果模式指定了 Tensor 被修改,那麼我們檢查實現是否修改了 Tensor。如果模式指定我們返回一個新的 Tensor,那麼我們檢查實現是否返回了一個新的 Tensor(而不是一個現有的 Tensor 或現有 Tensor 的檢視)。

  • test_autograd_registration: 如果運算元支援訓練(autograd):我們檢查其 autograd 公式是否透過 torch.library.register_autograd 或手動註冊到一個或多個 DispatchKey::Autograd 鍵來註冊。任何其他基於 DispatchKey 的註冊都可能導致未定義行為。

  • test_faketensor: 如果運算元具有 FakeTensor 核心(並且是正確的)。FakeTensor 核心對於運算元與 PyTorch 編譯 API(torch.compile/export/FX)一起工作是必需的(但不是充分條件)。我們檢查是否為該運算元註冊了 FakeTensor 核心(也稱為 meta 核心)並且它是正確的。此測試會獲取在實際 Tensor 上執行運算元的結果以及在 FakeTensor 上執行運算元的結果,並檢查它們是否具有相同的 Tensor 元資料(大小/步幅/dtype/裝置/等)。

  • test_aot_dispatch_dynamic: 如果運算元在 PyTorch 編譯 API(torch.compile/export/FX)下表現正確。這會檢查在 eager 模式 PyTorch 和 torch.compile 下的輸出(以及適用的梯度)是否相同。此測試是 `test_faketensor` 的超集,並且是一個端到端測試;它還測試運算元是否支援功能化,以及反向傳播(如果存在)是否也支援 FakeTensor 和功能化。

為了獲得最佳結果,請使用代表性輸入集多次呼叫 `opcheck`。如果您的運算元支援 autograd,請使用 `requires_grad = True` 的輸入呼叫 `opcheck`;如果您的運算元支援多個裝置(例如 CPU 和 CUDA),請在所有支援的裝置上使用輸入呼叫 `opcheck`。

引數
  • op (Union[OpOverload, OpOverloadPacket, CustomOpDef]) – 運算元。必須是使用 torch.library.custom_op() 裝飾的函式,或者是 torch.ops.* 中的 OpOverload/OpOverloadPacket(例如,torch.ops.aten.sin, torch.ops.mylib.foo)

  • args (tuple[Any, ...]) – 運算元的引數

  • kwargs (Optional[dict[str, Any]]) – 運算元的關鍵字引數

  • test_utils (Union[str, Sequence[str]]) – 我們應該執行的測試。預設值:全部。示例:("test_schema", "test_faketensor")

  • raise_exception (bool) – 我們是否應該在第一次出錯時丟擲異常。如果為 False,我們將返回一個包含每個測試是否透過的資訊的字典。

  • rtol (Optional[float]) – 浮點數比較的相對容差。如果指定,則 atol 也必須指定。如果省略,則選擇基於 `dtype` 的預設值(請參閱 torch.testing.assert_close() 中的表格)。

  • atol (Optional[float]) – 浮點數比較的絕對容差。如果指定,則 rtol 也必須指定。如果省略,則選擇基於 `dtype` 的預設值(請參閱 torch.testing.assert_close() 中的表格)。

返回型別

dict[str, str]

警告

`opcheck` 和 `torch.autograd.gradcheck` 測試的內容不同;`opcheck` 測試您對 torch.library API 的使用是否正確,而 `torch.autograd.gradcheck` 測試您的 autograd 公式是否在數學上是正確的。對於支援梯度計算的自定義運算元,請同時使用兩者進行測試。

示例

>>> @torch.library.custom_op("mylib::numpy_mul", mutates_args=())
>>> def numpy_mul(x: Tensor, y: float) -> Tensor:
>>>     x_np = x.numpy(force=True)
>>>     z_np = x_np * y
>>>     return torch.from_numpy(z_np).to(x.device)
>>>
>>> @numpy_mul.register_fake
>>> def _(x, y):
>>>     return torch.empty_like(x)
>>>
>>> def setup_context(ctx, inputs, output):
>>>     y, = inputs
>>>     ctx.y = y
>>>
>>> def backward(ctx, grad):
>>>     return grad * ctx.y, None
>>>
>>> numpy_mul.register_autograd(backward, setup_context=setup_context)
>>>
>>> sample_inputs = [
>>>     (torch.randn(3), 3.14),
>>>     (torch.randn(2, 3, device='cuda'), 2.718),
>>>     (torch.randn(1, 10, requires_grad=True), 1.234),
>>>     (torch.randn(64, 64, device='cuda', requires_grad=True), 90.18),
>>> ]
>>>
>>> for args in sample_inputs:
>>>     torch.library.opcheck(numpy_mul, args)

在 Python 中建立新的自定義運算元#

使用 torch.library.custom_op() 來建立新的自定義運算元。

torch.library.custom_op(name, fn=None, /, *, mutates_args, device_types=None, schema=None, tags=None)[source]#

將一個函式包裝成自定義運算元。

您可能希望建立自定義運算元的一些原因包括:- 包裝第三方庫或自定義核心以使其與 Autograd 等 PyTorch 子系統協同工作。- 防止 torch.compile/export/FX 跟蹤窺探您的函式內部。

此 API 用作函式上方的裝飾器(請參閱示例)。提供的函式必須具有型別提示;這些對於與 PyTorch 的各種子系統進行介面是必需的。

引數
  • name (str) – 自定義運算元的名稱,格式為“{namespace}::{name}”,例如,“mylib::my_linear”。該名稱用作運算元在 PyTorch 子系統(例如,torch.export、FX 圖)中的穩定識別符號。為避免名稱衝突,請使用您的專案名稱作為名稱空間;例如,pytorch/fbgemm 中的所有自定義運算元都使用“fbgemm”作為名稱空間。

  • mutates_args (Iterable[str] or "unknown") – 函式修改的引數名稱。這必須是準確的,否則行為將是未定義的。如果為“unknown”,則悲觀地假定運算元的所有輸入都被修改。

  • device_types (None | str | Sequence[str]) – 函式有效的裝置型別。如果未提供裝置型別,則該函式用作所有裝置型別的預設實現。示例:“cpu”、“cuda”。當為接受無 Tensor 的運算元註冊特定裝置實現時,我們要求該運算元具有“device: torch.device 引數”。

  • schema (None | str) – 運算元的模式字串。如果為 None(推薦),我們將從其型別註解中推斷運算元的模式。我們建議讓 PyTorch 推斷模式,除非您有特定原因不這樣做。示例:“(Tensor x, int y) -> (Tensor, Tensor)”。

返回型別

Union[Callable[[Callable[[…], object]], CustomOpDef], CustomOpDef]

注意

我們建議不要傳遞 `schema` 引數,而是讓 PyTorch 從型別註解中推斷它。編寫自己的模式容易出錯。如果您希望的模式與我們從型別註解推斷的模式不同,您可能希望提供自己的模式。有關如何編寫模式字串的更多資訊,請參閱 此處

示例:
>>> import torch
>>> from torch import Tensor
>>> from torch.library import custom_op
>>> import numpy as np
>>>
>>> @custom_op("mylib::numpy_sin", mutates_args=())
>>> def numpy_sin(x: Tensor) -> Tensor:
>>>     x_np = x.cpu().numpy()
>>>     y_np = np.sin(x_np)
>>>     return torch.from_numpy(y_np).to(device=x.device)
>>>
>>> x = torch.randn(3)
>>> y = numpy_sin(x)
>>> assert torch.allclose(y, x.sin())
>>>
>>> # Example of a custom op that only works for one device type.
>>> @custom_op("mylib::numpy_sin_cpu", mutates_args=(), device_types="cpu")
>>> def numpy_sin_cpu(x: Tensor) -> Tensor:
>>>     x_np = x.numpy()
>>>     y_np = np.sin(x_np)
>>>     return torch.from_numpy(y_np)
>>>
>>> x = torch.randn(3)
>>> y = numpy_sin_cpu(x)
>>> assert torch.allclose(y, x.sin())
>>>
>>> # Example of a custom op that mutates an input
>>> @custom_op("mylib::numpy_sin_inplace", mutates_args={"x"}, device_types="cpu")
>>> def numpy_sin_inplace(x: Tensor) -> None:
>>>     x_np = x.numpy()
>>>     np.sin(x_np, out=x_np)
>>>
>>> x = torch.randn(3)
>>> expected = x.sin()
>>> numpy_sin_inplace(x)
>>> assert torch.allclose(x, expected)
>>>
>>> # Example of a factory function
>>> @torch.library.custom_op("mylib::bar", mutates_args={}, device_types="cpu")
>>> def bar(device: torch.device) -> Tensor:
>>>     return torch.ones(3)
>>>
>>> bar("cpu")
torch.library.triton_op(name, fn=None, /, *, mutates_args, schema=None)[source]#

建立一個實現由 1 個或多個 triton 核心支援的自定義運算元。

這是使用 triton 核心與 PyTorch 的一種更結構化的方法。優先使用沒有 `torch.library` 自定義運算元包裝器(如 `torch.library.custom_op()`、`torch.library.triton_op()`)的 triton 核心,因為那樣更簡單;僅當您想建立一個行為類似於 PyTorch 內建運算元的運算元時,才使用 `torch.library.custom_op()`/`torch.library.triton_op()`。例如,您可以使用 `torch.library` 包裝器 API 來定義 triton 核心在傳入張量子類或 TorchDispatchMode 下的行為。

當實現由 1 個或多個 triton 核心組成時,請使用 `torch.library.triton_op()` 而不是 `torch.library.custom_op()`。`torch.library.custom_op()` 將自定義運算元視為不透明(`torch.compile()` 和 `torch.export.export()` 永遠不會跟蹤它們),但 `triton_op` 使這些子系統可以看到實現,從而允許它們最佳化 triton 核心。

請注意,`fn` 只能由 PyTorch 可理解的運算元和 triton 核心組成。在 `fn` 中呼叫的任何 triton 核心都必須包裝在對 `torch.library.wrap_triton()` 的呼叫中。

引數
  • name (str) – 自定義運算元的名稱,格式為“{namespace}::{name}”,例如,“mylib::my_linear”。該名稱用作運算元在 PyTorch 子系統(例如,torch.export、FX 圖)中的穩定識別符號。為避免名稱衝突,請使用您的專案名稱作為名稱空間;例如,pytorch/fbgemm 中的所有自定義運算元都使用“fbgemm”作為名稱空間。

  • mutates_args (Iterable[str] or "unknown") – 函式修改的引數名稱。這必須是準確的,否則行為將是未定義的。如果為“unknown”,則悲觀地假定運算元的所有輸入都被修改。

  • schema (None | str) – 運算元的模式字串。如果為 None(推薦),我們將從其型別註解中推斷運算元的模式。我們建議讓 PyTorch 推斷模式,除非您有特定原因不這樣做。示例:“(Tensor x, int y) -> (Tensor, Tensor)”。

返回型別

Callable

示例

>>> import torch
>>> from torch.library import triton_op, wrap_triton
>>>
>>> import triton
>>> from triton import language as tl
>>>
>>> @triton.jit
>>> def add_kernel(
>>>     in_ptr0,
>>>     in_ptr1,
>>>     out_ptr,
>>>     n_elements,
>>>     BLOCK_SIZE: "tl.constexpr",
>>> ):
>>>     pid = tl.program_id(axis=0)
>>>     block_start = pid * BLOCK_SIZE
>>>     offsets = block_start + tl.arange(0, BLOCK_SIZE)
>>>     mask = offsets < n_elements
>>>     x = tl.load(in_ptr0 + offsets, mask=mask)
>>>     y = tl.load(in_ptr1 + offsets, mask=mask)
>>>     output = x + y
>>>     tl.store(out_ptr + offsets, output, mask=mask)
>>>
>>> @triton_op("mylib::add", mutates_args={})
>>> def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
>>>     output = torch.empty_like(x)
>>>     n_elements = output.numel()
>>>
>>>     def grid(meta):
>>>         return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
>>>
>>>     # NB: we need to wrap the triton kernel in a call to wrap_triton
>>>     wrap_triton(add_kernel)[grid](x, y, output, n_elements, 16)
>>>     return output
>>>
>>> @torch.compile
>>> def f(x, y):
>>>     return add(x, y)
>>>
>>> x = torch.randn(3, device="cuda")
>>> y = torch.randn(3, device="cuda")
>>>
>>> z = f(x, y)
>>> assert torch.allclose(z, x + y)
torch.library.wrap_triton(triton_kernel, /)[source]#

允許透過 make_fx 或非嚴格 `torch.export` 將 triton 核心捕獲到圖中。

這些技術執行基於 Dispatcher 的跟蹤(透過 `__torch_dispatch__`),無法看到對原始 triton 核心的呼叫。`wrap_triton` API 將 triton 核心包裝到一個可呼叫的物件中,該物件實際上可以被跟蹤到圖中。

請將此 API 與 `torch.library.triton_op()` 一起使用。

示例

>>> import torch
>>> import triton
>>> from triton import language as tl
>>> from torch.fx.experimental.proxy_tensor import make_fx
>>> from torch.library import wrap_triton
>>>
>>> @triton.jit
>>> def add_kernel(
>>>     in_ptr0,
>>>     in_ptr1,
>>>     out_ptr,
>>>     n_elements,
>>>     BLOCK_SIZE: "tl.constexpr",
>>> ):
>>>     pid = tl.program_id(axis=0)
>>>     block_start = pid * BLOCK_SIZE
>>>     offsets = block_start + tl.arange(0, BLOCK_SIZE)
>>>     mask = offsets < n_elements
>>>     x = tl.load(in_ptr0 + offsets, mask=mask)
>>>     y = tl.load(in_ptr1 + offsets, mask=mask)
>>>     output = x + y
>>>     tl.store(out_ptr + offsets, output, mask=mask)
>>>
>>> def add(x, y):
>>>     output = torch.empty_like(x)
>>>     n_elements = output.numel()
>>>
>>>     def grid_fn(meta):
>>>         return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
>>>
>>>     wrap_triton(add_kernel)[grid_fn](x, y, output, n_elements, 16)
>>>     return output
>>>
>>> x = torch.randn(3, device="cuda")
>>> y = torch.randn(3, device="cuda")
>>> gm = make_fx(add)(x, y)
>>> print(gm.code)
>>> # def forward(self, x_1, y_1):
>>> #     empty_like = torch.ops.aten.empty_like.default(x_1, pin_memory = False)
>>> #     triton_kernel_wrapper_mutation_proxy = triton_kernel_wrapper_mutation(
>>> #         kernel_idx = 0, constant_args_idx = 0,
>>> #         grid = [(1, 1, 1)], kwargs = {
>>> #             'in_ptr0': x_1, 'in_ptr1': y_1, 'out_ptr': empty_like,
>>> #             'n_elements': 3, 'BLOCK_SIZE': 16
>>> #         })
>>> #     return empty_like
返回型別

任何

擴充套件自定義運算元(由 Python 或 C++ 建立)#

使用 `register.*` 方法,例如 `torch.library.register_kernel()` 和 `torch.library.register_fake()`,為任何運算元(它們可能是使用 `torch.library.custom_op()` 或透過 PyTorch 的 C++ 運算元註冊 API 建立的)新增實現。

torch.library.register_kernel(op, device_types, func=None, /, *, lib=None)[source]#

為該運算元的特定裝置型別註冊一個實現。

一些有效的 device_types 是:“cpu”、“cuda”、“xla”、“mps”、“ipu”、“xpu”。此 API 可用作裝飾器。

引數
  • op (str | OpOverload) – 要註冊實現的運算元。

  • device_types (None | str | Sequence[str]) – 要註冊實現的裝置型別。如果為 None,我們將註冊到所有裝置型別 - 請僅在您的實現真正與裝置型別無關時才使用此選項。

  • func (Callable) – 註冊為給定裝置型別實現的函式。

  • lib (Optional[Library]) – 如果提供,此註冊的生命週期

示例:
>>> import torch
>>> from torch import Tensor
>>> from torch.library import custom_op
>>> import numpy as np
>>>
>>> # Create a custom op that works on cpu
>>> @custom_op("mylib::numpy_sin", mutates_args=(), device_types="cpu")
>>> def numpy_sin(x: Tensor) -> Tensor:
>>>     x_np = x.numpy()
>>>     y_np = np.sin(x_np)
>>>     return torch.from_numpy(y_np)
>>>
>>> # Add implementations for the cuda device
>>> @torch.library.register_kernel("mylib::numpy_sin", "cuda")
>>> def _(x):
>>>     x_np = x.cpu().numpy()
>>>     y_np = np.sin(x_np)
>>>     return torch.from_numpy(y_np).to(device=x.device)
>>>
>>> x_cpu = torch.randn(3)
>>> x_cuda = x_cpu.cuda()
>>> assert torch.allclose(numpy_sin(x_cpu), x_cpu.sin())
>>> assert torch.allclose(numpy_sin(x_cuda), x_cuda.sin())
torch.library.register_autocast(op, device_type, cast_inputs, /, *, lib=None)[source]#

為該自定義運算元註冊一個自動強制轉換分派規則。

有效的 `device_type` 包括:“cpu”和“cuda”。

引數
  • op (str | OpOverload) – 要註冊自動強制轉換分派規則的運算元。

  • device_type (str) – 要使用的裝置型別。“cuda”或“cpu”。該型別與 `torch.device` 的 `type` 屬性相同。因此,您可以使用 `Tensor.device.type` 來獲取張量的裝置型別。

  • cast_inputs (torch.dtype) – 當自定義運算元在啟用了自動強制轉換的區域執行時,將輸入的浮點 Tensor 強制轉換為目標 dtype(非浮點 Tensor 不受影響),然後停用自動強制轉換來執行自定義運算元。

  • lib (Optional[Library]) – 如果提供,此註冊的生命週期

示例:
>>> import torch
>>> from torch import Tensor
>>> from torch.library import custom_op
>>>
>>> # Create a custom op that works on cuda
>>> @torch.library.custom_op("mylib::my_sin", mutates_args=())
>>> def my_sin(x: Tensor) -> Tensor:
>>>     return torch.sin(x)
>>>
>>> # Register autocast dispatch rule for the cuda device
>>> torch.library.register_autocast("mylib::my_sin", "cuda", torch.float16)
>>>
>>> x = torch.randn(3, dtype=torch.float32, device="cuda")
>>> with torch.autocast("cuda", dtype=torch.float16):
>>>     y = torch.ops.mylib.my_sin(x)
>>> assert y.dtype == torch.float16
torch.library.register_autograd(op, backward, /, *, setup_context=None, lib=None)[source]#

為該自定義運算元註冊一個反向傳播公式。

為了讓運算元能夠與 autograd 一起工作,您需要註冊一個反向傳播公式:1.您必須透過提供一個“backward”函式來告訴 PyTorch 如何在反向傳播過程中計算梯度。2.如果您在計算梯度時需要正向傳播的任何值,可以使用 `setup_context` 來為反向傳播儲存值。

`backward` 在反向傳播過程中執行。它接受 `(ctx, *grads)`:- `grads` 是一個或多個梯度。梯度的數量與運算元的輸出數量相匹配。`ctx` 物件與 `torch.autograd.Function` 使用的 `ctx` 物件相同。`backward_fn` 的語義與 `torch.autograd.Function.backward()` 相同。

`setup_context(ctx, inputs, output)` 在正向傳播過程中執行。請透過 `torch.autograd.function.FunctionCtx.save_for_backward()` 或將它們作為 `ctx` 的屬性來將反向傳播所需的值儲存到 `ctx` 物件上。如果您的自定義運算元有僅關鍵字引數,我們期望 `setup_context` 的簽名是 `setup_context(ctx, inputs, keyword_only_inputs, output)`。

`setup_context_fn` 和 `backward_fn` 都必須是可跟蹤的。也就是說,它們不能直接訪問 `torch.Tensor.data_ptr()`,並且它們不能依賴或修改全域性狀態。如果您需要一個不可跟蹤的反向傳播,您可以將其製作成一個單獨的自定義運算元,並在 `backward_fn` 中呼叫它。

如果您需要在不同裝置上具有不同的 autograd 行為,那麼我們建議建立兩個不同的自定義運算元,一個用於需要不同行為的裝置,並在執行時在它們之間切換。

示例

>>> import torch
>>> import numpy as np
>>> from torch import Tensor
>>>
>>> @torch.library.custom_op("mylib::numpy_sin", mutates_args=())
>>> def numpy_sin(x: Tensor) -> Tensor:
>>>     x_np = x.cpu().numpy()
>>>     y_np = np.sin(x_np)
>>>     return torch.from_numpy(y_np).to(device=x.device)
>>>
>>> def setup_context(ctx, inputs, output) -> Tensor:
>>>     x, = inputs
>>>     ctx.save_for_backward(x)
>>>
>>> def backward(ctx, grad):
>>>     x, = ctx.saved_tensors
>>>     return grad * x.cos()
>>>
>>> torch.library.register_autograd(
...     "mylib::numpy_sin", backward, setup_context=setup_context
... )
>>>
>>> x = torch.randn(3, requires_grad=True)
>>> y = numpy_sin(x)
>>> (grad_x,) = torch.autograd.grad(y, x, torch.ones_like(y))
>>> assert torch.allclose(grad_x, x.cos())
>>>
>>> # Example with a keyword-only arg
>>> @torch.library.custom_op("mylib::numpy_mul", mutates_args=())
>>> def numpy_mul(x: Tensor, *, val: float) -> Tensor:
>>>     x_np = x.cpu().numpy()
>>>     y_np = x_np * val
>>>     return torch.from_numpy(y_np).to(device=x.device)
>>>
>>> def setup_context(ctx, inputs, keyword_only_inputs, output) -> Tensor:
>>>     ctx.val = keyword_only_inputs["val"]
>>>
>>> def backward(ctx, grad):
>>>     return grad * ctx.val
>>>
>>> torch.library.register_autograd(
...     "mylib::numpy_mul", backward, setup_context=setup_context
... )
>>>
>>> x = torch.randn(3, requires_grad=True)
>>> y = numpy_mul(x, val=3.14)
>>> (grad_x,) = torch.autograd.grad(y, x, torch.ones_like(y))
>>> assert torch.allclose(grad_x, torch.full_like(x, 3.14))
torch.library.register_fake(op, func=None, /, *, lib=None, _stacklevel=1, allow_override=False)[source]#

為該運算元註冊一個 FakeTensor 實現(“fake impl”)。

也稱為“meta kernel”、“abstract impl”。

“FakeTensor 實現”指定了該運算元在不包含資料的 Tensor(“FakeTensor”)上的行為。給定具有特定屬性(大小/步幅/儲存偏移量/裝置)的輸入 Tensor,它指定輸出 Tensor 的屬性。

FakeTensor 實現具有與運算元相同的簽名。它同時用於 FakeTensor 和 meta Tensor。要編寫 FakeTensor 實現,請假定運算元的所有 Tensor 輸入都是常規的 CPU/CUDA/Meta Tensor,但它們沒有儲存,並且您試圖返回常規的 CPU/CUDA/Meta Tensor 作為輸出。FakeTensor 實現只能由 PyTorch 操作組成(並且不能直接訪問任何輸入或中間 Tensor 的儲存或資料)。

此 API 可用作裝飾器(參見示例)。

有關自定義運算元的詳細指南,請參閱 https://pytorch.com.tw/tutorials/advanced/custom_ops_landing_page.html

引數
  • op_name – 運算元名稱(連同過載)或 OpOverload 物件。

  • func (Optional[Callable]) – Fake Tensor 實現。

  • lib (Optional[Library]) – 要將 fake tensor 註冊到的庫。

  • allow_override (bool) – 控制是否覆蓋已註冊的 fake impl 的標誌。預設情況下,此標誌處於關閉狀態,並且如果您嘗試向已具有 fake impl 的運算元註冊 fake impl,則會報錯。這也僅適用於未使用 torch.library.custom_op 建立的自定義運算元,因為覆蓋已有的 fake impl 已經允許。

示例

>>> import torch
>>> import numpy as np
>>> from torch import Tensor
>>>
>>> # Example 1: an operator without data-dependent output shape
>>> @torch.library.custom_op("mylib::custom_linear", mutates_args=())
>>> def custom_linear(x: Tensor, weight: Tensor, bias: Tensor) -> Tensor:
>>>     raise NotImplementedError("Implementation goes here")
>>>
>>> @torch.library.register_fake("mylib::custom_linear")
>>> def _(x, weight, bias):
>>>     assert x.dim() == 2
>>>     assert weight.dim() == 2
>>>     assert bias.dim() == 1
>>>     assert x.shape[1] == weight.shape[1]
>>>     assert weight.shape[0] == bias.shape[0]
>>>     assert x.device == weight.device
>>>
>>>     return (x @ weight.t()) + bias
>>>
>>> with torch._subclasses.fake_tensor.FakeTensorMode():
>>>     x = torch.randn(2, 3)
>>>     w = torch.randn(3, 3)
>>>     b = torch.randn(3)
>>>     y = torch.ops.mylib.custom_linear(x, w, b)
>>>
>>> assert y.shape == (2, 3)
>>>
>>> # Example 2: an operator with data-dependent output shape
>>> @torch.library.custom_op("mylib::custom_nonzero", mutates_args=())
>>> def custom_nonzero(x: Tensor) -> Tensor:
>>>     x_np = x.numpy(force=True)
>>>     res = np.stack(np.nonzero(x_np), axis=1)
>>>     return torch.tensor(res, device=x.device)
>>>
>>> @torch.library.register_fake("mylib::custom_nonzero")
>>> def _(x):
>>> # Number of nonzero-elements is data-dependent.
>>> # Since we cannot peek at the data in an fake impl,
>>> # we use the ctx object to construct a new symint that
>>> # represents the data-dependent size.
>>>     ctx = torch.library.get_ctx()
>>>     nnz = ctx.new_dynamic_size()
>>>     shape = [nnz, x.dim()]
>>>     result = x.new_empty(shape, dtype=torch.int64)
>>>     return result
>>>
>>> from torch.fx.experimental.proxy_tensor import make_fx
>>>
>>> x = torch.tensor([0, 1, 2, 3, 4, 0])
>>> trace = make_fx(torch.ops.mylib.custom_nonzero, tracing_mode="symbolic")(x)
>>> trace.print_readable()
>>>
>>> assert torch.allclose(trace(x), torch.ops.mylib.custom_nonzero(x))
torch.library.register_vmap(op, func=None, /, *, lib=None)[source]#

註冊一個 vmap 實現以支援該自定義運算元的 `torch.vmap()`。

此 API 可用作裝飾器(參見示例)。

為了讓運算元能夠與 `torch.vmap()` 一起工作,您可能需要註冊一個具有以下簽名的 vmap 實現:

vmap_func(info, in_dims: Tuple[Optional[int]], *args, **kwargs),

其中 `*args` 和 `**kwargs` 是 `op` 的引數和關鍵字引數。我們不支援僅關鍵字引數的 Tensor 引數。

它指定了如何計算 `op` 的批次版本,給定輸入和額外的維度(由 `in_dims` 指定)。

對於 `args` 中的每個引數,`in_dims` 都有一個對應的 `Optional[int]`。如果引數不是 Tensor 或未對引數進行 vmap,則為 `None`,否則它是一個指定 Tensor 的哪個維度正在被 vmap 的整數。

`info` 是可能有助於此的附加元資料的集合:`info.batch_size` 指定被 vmap 的維度的尺寸,而 `info.randomness` 是傳遞給 `torch.vmap()` 的 `randomness` 選項。

函式 `func` 的返回值為 `(output, out_dims)` 元組。類似於 `in_dims`,`out_dims` 的結構應與 `output` 相同,並且每個輸出都包含一個 `out_dim`,指定輸出是否具有 vmap 維度以及它在其中的索引。

示例

>>> import torch
>>> import numpy as np
>>> from torch import Tensor
>>> from typing import Tuple
>>>
>>> def to_numpy(tensor):
>>>     return tensor.cpu().numpy()
>>>
>>> lib = torch.library.Library("mylib", "FRAGMENT")
>>> @torch.library.custom_op("mylib::numpy_cube", mutates_args=())
>>> def numpy_cube(x: Tensor) -> Tuple[Tensor, Tensor]:
>>>     x_np = to_numpy(x)
>>>     dx = torch.tensor(3 * x_np ** 2, device=x.device)
>>>     return torch.tensor(x_np ** 3, device=x.device), dx
>>>
>>> def numpy_cube_vmap(info, in_dims, x):
>>>     result = numpy_cube(x)
>>>     return result, (in_dims[0], in_dims[0])
>>>
>>> torch.library.register_vmap(numpy_cube, numpy_cube_vmap)
>>>
>>> x = torch.randn(3)
>>> torch.vmap(numpy_cube)(x)
>>>
>>> @torch.library.custom_op("mylib::numpy_mul", mutates_args=())
>>> def numpy_mul(x: Tensor, y: Tensor) -> Tensor:
>>>     return torch.tensor(to_numpy(x) * to_numpy(y), device=x.device)
>>>
>>> @torch.library.register_vmap("mylib::numpy_mul")
>>> def numpy_mul_vmap(info, in_dims, x, y):
>>>     x_bdim, y_bdim = in_dims
>>>     x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1)
>>>     y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1)
>>>     result = x * y
>>>     result = result.movedim(-1, 0)
>>>     return result, 0
>>>
>>>
>>> x = torch.randn(3)
>>> y = torch.randn(3)
>>> torch.vmap(numpy_mul)(x, y)

注意

vmap 函式應旨在保留整個自定義運算元的語義。也就是說,`grad(vmap(op))` 應可替換為 `grad(map(op))`。

如果您的自定義運算元在反向傳播過程中有任何自定義行為,請牢記這一點。

torch.library.impl_abstract(qualname, func=None, *, lib=None, _stacklevel=1)[source]#

此 API 在 PyTorch 2.4 中重新命名為 `torch.library.register_fake()`。請改用該 API。

torch.library.get_ctx()[source]#

`get_ctx()` 返回當前的 AbstractImplCtx 物件。

呼叫 `get_ctx()` 僅在 fake impl 內部有效(有關更多用法詳細資訊,請參見 `torch.library.register_fake()`)。

返回型別

FakeImplCtx

torch.library.register_torch_dispatch(op, torch_dispatch_class, func=None, /, *, lib=None)[source]#

為給定的運算元和 `torch_dispatch_class` 註冊一個 torch_dispatch 規則。

這允許開放式註冊來指定運算元與 `torch_dispatch_class` 之間的行為,而無需直接修改 `torch_dispatch_class` 或運算元。

`torch_dispatch_class` 要麼是具有 `__torch_dispatch__` 的 Tensor 子類,要麼是 TorchDispatchMode。

如果是 Tensor 子類,我們期望 `func` 具有以下簽名:` (cls, func: OpOverload, types: Tuple[type, ...], args, kwargs) -> Any`

如果是 TorchDispatchMode,我們期望 `func` 具有以下簽名:` (mode, func: OpOverload, types: Tuple[type, ...], args, kwargs) -> Any`

`args` 和 `kwargs` 將已在 `__torch_dispatch__` 中以相同方式進行規範化(請參閱 __torch_dispatch__ 呼叫約定)。

示例

>>> import torch
>>>
>>> @torch.library.custom_op("mylib::foo", mutates_args={})
>>> def foo(x: torch.Tensor) -> torch.Tensor:
>>>     return x.clone()
>>>
>>> class MyMode(torch.utils._python_dispatch.TorchDispatchMode):
>>>     def __torch_dispatch__(self, func, types, args=(), kwargs=None):
>>>         return func(*args, **kwargs)
>>>
>>> @torch.library.register_torch_dispatch("mylib::foo", MyMode)
>>> def _(mode, func, types, args, kwargs):
>>>     x, = args
>>>     return x + 1
>>>
>>> x = torch.randn(3)
>>> y = foo(x)
>>> assert torch.allclose(y, x)
>>>
>>> with MyMode():
>>>     y = foo(x)
>>> assert torch.allclose(y, x + 1)
torch.library.infer_schema(prototype_function, /, *, mutates_args, op_name=None)[source]#

解析給定函式(帶型別提示)的模式。模式從函式的型別提示中推斷出來,並可用於定義新運算元。

我們做出以下假設:

  • 無輸出會別名任何輸入或彼此。

  • 字串型別註解“device, dtype, Tensor, types”,沒有庫規範,將被
    假定為 torch.*。類似地,字串型別註解“Optional, List, Sequence, Union”
    沒有庫規範,將被假定為 typing.*。
  • 只有 `mutates_args` 中列出的引數才會被修改。如果 `mutates_args` 為“unknown”,
    則假定運算元的所有輸入都被修改。

呼叫者(例如,自定義運算元 API)負責檢查這些假設。

引數
  • prototype_function (Callable) – 用於從其型別註解推斷模式的函式。

  • op_name (Optional[str]) – 模式中運算元的名稱。如果 `name` 為 None,則名稱不包含在推斷的模式中。請注意,`torch.library.Library.define` 的輸入模式需要一個運算元名稱。

  • mutates_args ("unknown" | Iterable[str]) – 函式中修改的引數。

返回

推斷出的模式。

返回型別

str

示例

>>> def foo_impl(x: torch.Tensor) -> torch.Tensor:
>>>     return x.sin()
>>>
>>> infer_schema(foo_impl, op_name="foo", mutates_args={})
foo(Tensor x) -> Tensor
>>>
>>> infer_schema(foo_impl, mutates_args={})
(Tensor x) -> Tensor
class torch._library.custom_ops.CustomOpDef(namespace, name, schema, fn, tags=None)[source]#

`CustomOpDef` 是一個函式包裝器,它將函式轉換為自定義運算元。

它具有各種方法來為此自定義運算元註冊附加行為。

您不應直接例項化 `CustomOpDef`;而是使用 `torch.library.custom_op()` API。

set_kernel_enabled(device_type, enabled=True)[source]#

停用或重新啟用此自定義運算元已註冊的核心。

如果核心已停用/啟用,則此操作無效。

注意

如果核心先被停用然後註冊,則它一直處於停用狀態,直到再次啟用。

引數
  • device_type (str) – 要為此停用/啟用核心的裝置型別。

  • disable (bool) – 是停用還是啟用核心。

示例

>>> inp = torch.randn(1)
>>>
>>> # define custom op `f`.
>>> @custom_op("mylib::f", mutates_args=())
>>> def f(x: Tensor) -> Tensor:
>>>     return torch.zeros(1)
>>>
>>> print(f(inp))  # tensor([0.]), default kernel
>>>
>>> @f.register_kernel("cpu")
>>> def _(x):
>>>     return torch.ones(1)
>>>
>>> print(f(inp))  # tensor([1.]), CPU kernel
>>>
>>> # temporarily disable the CPU kernel
>>> with f.set_kernel_enabled("cpu", enabled = False):
>>>     print(f(inp))  # tensor([0.]) with CPU kernel disabled
torch.library.get_kernel(op, dispatch_key)[source]#

返回給定運算元和分派鍵的計算出的核心。

此函式檢索將為給定的運算元和分派鍵組合執行的核心。返回的 SafeKernelFunction 可用於以打包方式呼叫核心。此函式 intended 用例是檢索給定分派鍵的原始核心,然後向同一分派鍵註冊另一個核心,該核心在某些情況下會呼叫原始核心。

引數
  • op (Union[str, OpOverload, CustomOpDef]) – 運算元名稱(連同過載)或 OpOverload 物件。可以是字串(例如,“aten::add.Tensor”)、OpOverload 或 CustomOpDef。

  • dispatch_key (str | torch.DispatchKey) – 要獲取核心的分派鍵。可以是字串(例如,“CPU”、“CUDA”)或 DispatchKey 列舉值。

返回

一個安全的核心函式,可用於

呼叫核心。

返回型別

torch._C._SafeKernelFunction

引發

RuntimeError – 如果運算元不存在。

示例

>>> # Get the CPU kernel for torch.add
>>> kernel = torch.library.get_kernel("aten::add.Tensor", "CPU")
>>>
>>> # You can also use DispatchKey enum
>>> kernel = torch.library.get_kernel("aten::add.Tensor", torch.DispatchKey.CPU)
>>>
>>> # Or use an OpOverload directly
>>> kernel = torch.library.get_kernel(torch.ops.aten.add.Tensor, "CPU")
>>>
>>> # Example: Using get_kernel in a custom op with conditional dispatch
>>> # Get the original kernel for torch.sin
>>> original_sin_kernel = torch.library.get_kernel("aten::sin", "CPU")
>>>
>>> # If input has negative values, use original sin, otherwise return zeros
>>> def conditional_sin_impl(dispatch_keys, x):
>>>     if (x < 0).any():
>>>         return original_sin_kernel.call_boxed(dispatch_keys, x)
>>>     else:
>>>         return torch.zeros_like(x)
>>>
>>> lib = torch.library.Library("aten", "IMPL")
>>> # with_keyset=True so the first argument to the impl is the current DispatchKeySet
>>> which needs to be the first argument to ``kernel.call_boxed``
>>> lib.impl("sin", conditional_sin_impl, "CPU", with_keyset=True)
>>>
>>> # Test the conditional behavior
>>> x_positive = torch.tensor([1.0, 2.0])
>>> x_mixed = torch.tensor([-1.0, 2.0])
>>> torch.sin(x_positive)
tensor([0., 0.])
>>> torch.sin(x_mixed)
tensor([-0.8415, 0.9093])

底層 API#

以下 API 是 PyTorch 底層運算元註冊 API 的直接繫結。

警告

底層運算元註冊 API 和 PyTorch Dispatcher 是一個複雜的 PyTorch 概念。我們建議您在可能的情況下使用上面的更高級別的 API(不需要 torch.library.Library 物件)。這篇博文是瞭解 PyTorch Dispatcher 的一個好起點。

本教程在 Google Colab 上提供了關於如何使用此 API 的一些示例。

class torch.library.Library(ns, kind, dispatch_key='')[source]#

一個類,用於建立可用於在 Python 中註冊新運算元或覆蓋現有庫中的運算元的庫。使用者可以選擇傳入一個分派鍵名稱,如果他們只想註冊對應於單個特定分派鍵的核心。

要建立一個用於覆蓋現有庫(名稱為 ns)中的運算元的庫,請將 kind 設定為“IMPL”。要建立一個新庫(名稱為 ns)來註冊新運算元,請將 kind 設定為“DEF”。要建立一個可能存在的庫的片段來註冊運算元(並繞過一個名稱空間只有一個庫的限制),請將 kind 設定為“FRAGMENT”。

引數
  • ns – 庫名稱

  • kind – “DEF”、“IMPL”、“FRAGMENT”

  • dispatch_key – PyTorch 分派鍵(預設為“”)

define(schema, alias_analysis='', *, tags=())[source]#

在 ns 名稱空間中定義新運算元及其語義。

引數
  • schema – 用於定義新運算元的函式模式。

  • alias_analysis (optional) – 指示是否可以從模式(預設行為)推斷運算元引數的別名屬性,或者“CONSERVATIVE”。

  • tags (Tag | Sequence[Tag]) – 應用於此運算元的一個或多個 torch.Tag。標記運算元會改變運算元在各種 PyTorch 子系統下的行為;請在應用之前仔細閱讀 torch.Tag 的文件。

返回

從模式推斷出的運算元名稱。

示例

>>> my_lib = Library("mylib", "DEF")
>>> my_lib.define("sum(Tensor self) -> Tensor")
fallback(fn, dispatch_key='', *, with_keyset=False)[source]#

將函式實現註冊為給定鍵的後備。

此函式僅適用於具有全域性名稱空間(“_”)的庫。

引數
  • fn – 用作給定分派鍵的後備的函式,或 `torch.library.fallthrough_kernel()` 以註冊一個後落。

  • dispatch_key – 應為輸入函式註冊的分派鍵。預設情況下,它使用建立庫時使用的分派鍵。

  • with_keyset – 控制在呼叫 `fn` 時當前分派器呼叫鍵集是否應作為第一個引數傳遞的標誌。這應該用於建立鍵集的適當方法以進行重分派呼叫。

示例

>>> my_lib = Library("_", "IMPL")
>>> def fallback_kernel(op, *args, **kwargs):
>>>     # Handle all autocast ops generically
>>>     # ...
>>> my_lib.fallback(fallback_kernel, "Autocast")
impl(op_name, fn, dispatch_key='', *, with_keyset=False, allow_override=False)[source]#

為在庫中定義的運算元註冊函式實現。

引數
  • op_name – 運算元名稱(連同過載)或 OpOverload 物件。

  • fn – 作為輸入分派鍵的運算元實現,或 `fallthrough_kernel()` 以註冊後落的函式。

  • dispatch_key – 應為輸入函式註冊的分派鍵。預設情況下,它使用建立庫時使用的分派鍵。

  • with_keyset – 控制在呼叫 `fn` 時當前分派器呼叫鍵集是否應作為第一個引數傳遞的標誌。這應該用於建立鍵集的適當方法以進行重分派呼叫。

  • allow_override – 控制是否覆蓋已註冊核心實現的標誌。預設情況下,此標誌處於關閉狀態,並且如果您嘗試向已註冊核心的排程鍵註冊核心,則會報錯。

示例

>>> my_lib = Library("aten", "IMPL")
>>> def div_cpu(self, other):
>>>     return self * (1 / other)
>>> my_lib.impl("div.Tensor", div_cpu, "CPU")
torch.library.fallthrough_kernel()[source]#

一個虛擬函式,可以傳遞給 `Library.impl` 以註冊一個後落。

torch.library.define(qualname, schema, *, lib=None, tags=())[source]#
torch.library.define(lib, schema, alias_analysis='')

定義一個新運算元。

在 PyTorch 中,定義一個 op(“operator”的縮寫)是一個兩步過程:- 我們需要定義 op(透過提供運算元名稱和模式)- 我們需要實現運算元如何與各種 PyTorch 子系統(如 CPU/CUDA Tensor、Autograd 等)互動的行為。

此入口點定義自定義運算元(第一步),您必須透過呼叫各種 `impl_*` API(如 `torch.library.impl()` 或 `torch.library.register_fake()`)來執行第二步。

引數
  • qualname (str) – 運算元的限定名稱。應為格式為“namespace::name”的字串,例如“aten::sin”。PyTorch 中的運算元需要一個名稱空間來避免名稱衝突;一個給定的運算元只能建立一次。如果您正在編寫 Python 庫,我們建議名稱空間為您的頂級模組的名稱。

  • schema (str) – 運算元的模式。例如,“(Tensor x) -> Tensor”表示接受一個 Tensor 並返回一個 Tensor 的 op。它不包含運算元名稱(該名稱在 `qualname` 中傳遞)。

  • lib (Optional[Library]) – 如果提供,此運算元的生命週期將與 Library 物件的生命週期繫結。

  • tags (Tag | Sequence[Tag]) – 應用於此運算元的一個或多個 torch.Tag。標記運算元會改變運算元在各種 PyTorch 子系統下的行為;請在應用之前仔細閱讀 torch.Tag 的文件。

示例:
>>> import torch
>>> import numpy as np
>>>
>>> # Define the operator
>>> torch.library.define("mylib::sin", "(Tensor x) -> Tensor")
>>>
>>> # Add implementations for the operator
>>> @torch.library.impl("mylib::sin", "cpu")
>>> def f(x):
>>>     return torch.from_numpy(np.sin(x.numpy()))
>>>
>>> # Call the new operator from torch.ops.
>>> x = torch.randn(3)
>>> y = torch.ops.mylib.sin(x)
>>> assert torch.allclose(y, x.sin())
torch.library.impl(lib, name, dispatch_key='')[source]#
torch.library.impl(qualname: str, types: Union[str, Sequence[str]], func: Literal[None] = None, *, lib: Optional[Library] = None) Callable[[Callable[..., object]], None]
torch.library.impl(qualname: str, types: Union[str, Sequence[str]], func: Callable[..., object], *, lib: Optional[Library] = None) None
torch.library.impl(lib: Library, name: str, dispatch_key: str = '') Callable[[Callable[_P, _T]], Callable[_P, _T]]

為該運算元的特定裝置型別註冊一個實現。

您可以為 `types` 傳遞“default”以將此實現註冊為所有裝置型別的預設實現。請僅在實現真正支援所有裝置型別時才使用此選項;例如,當它是內建 PyTorch 運算元的組合時,這是真的。

此 API 可用作裝飾器。您可以使用巢狀裝飾器與此 API,前提是它們返回一個函式並放置在此 API 內部(參見示例 2)。

一些有效的型別是:“cpu”、“cuda”、“xla”、“mps”、“ipu”、“xpu”。

引數
  • qualname (str) – 應為格式為“namespace::operator_name”的字串。

  • types (str | Sequence[str]) – 要註冊實現的裝置型別。

  • lib (Optional[Library]) – 如果提供,此註冊的生命週期將與 Library 物件的生命週期繫結。

示例

>>> import torch
>>> import numpy as np
>>> # Example 1: Register function.
>>> # Define the operator
>>> torch.library.define("mylib::mysin", "(Tensor x) -> Tensor")
>>>
>>> # Add implementations for the cpu device
>>> @torch.library.impl("mylib::mysin", "cpu")
>>> def f(x):
>>>     return torch.from_numpy(np.sin(x.numpy()))
>>>
>>> x = torch.randn(3)
>>> y = torch.ops.mylib.mysin(x)
>>> assert torch.allclose(y, x.sin())
>>>
>>> # Example 2: Register function with decorator.
>>> def custom_decorator(func):
>>>     def wrapper(*args, **kwargs):
>>>         return func(*args, **kwargs) + 1
>>>     return wrapper
>>>
>>> # Define the operator
>>> torch.library.define("mylib::sin_plus_one", "(Tensor x) -> Tensor")
>>>
>>> # Add implementations for the operator
>>> @torch.library.impl("mylib::sin_plus_one", "cpu")
>>> @custom_decorator
>>> def f(x):
>>>     return torch.from_numpy(np.sin(x.numpy()))
>>>
>>> # Call the new operator from torch.ops.
>>> x = torch.randn(3)
>>>
>>> y1 = torch.ops.mylib.sin_plus_one(x)
>>> y2 = torch.sin(x) + 1
>>> assert torch.allclose(y1, y2)