torch.library#
創建於:2022年6月13日 | 最後更新於:2025年8月13日
torch.library 是用於擴充套件 PyTorch 核心運算元庫的 API 集合。它包含用於測試自定義運算元、建立新的自定義運算元以及擴充套件使用 PyTorch C++ 運算元註冊 API(例如 aten 運算元)定義的運算元的實用工具。
有關有效使用這些 API 的詳細指南,請參閱 PyTorch 自定義運算元著陸頁,瞭解有關如何有效使用這些 API 的更多詳細資訊。
測試自定義運算元#
使用 torch.library.opcheck() 測試自定義運算元是否正確使用了 Python torch.library 和/或 C++ TORCH_LIBRARY API。此外,如果您的運算元支援訓練,請使用 torch.autograd.gradcheck() 來測試梯度是否在數學上是正確的。
- torch.library.opcheck(op, args, kwargs=None, *, test_utils=('test_schema', 'test_autograd_registration', 'test_faketensor', 'test_aot_dispatch_dynamic'), raise_exception=True, atol=None, rtol=None)[source]#
給定一個運算元和一些示例引數,測試該運算元是否已正確註冊。
也就是說,當您使用 torch.library/TORCH_LIBRARY API 建立自定義運算元時,您指定了自定義運算元的元資料(例如,可變性資訊),而這些 API 要求您傳遞的函式滿足特定屬性(例如,Fake/Meta/Abstract 核心中沒有資料指標訪問)`opcheck` 會測試這些元資料和屬性。
具體來說,我們測試以下內容:
test_schema: 如果模式與運算元的實現匹配。例如:如果模式指定了 Tensor 被修改,那麼我們檢查實現是否修改了 Tensor。如果模式指定我們返回一個新的 Tensor,那麼我們檢查實現是否返回了一個新的 Tensor(而不是一個現有的 Tensor 或現有 Tensor 的檢視)。
test_autograd_registration: 如果運算元支援訓練(autograd):我們檢查其 autograd 公式是否透過 torch.library.register_autograd 或手動註冊到一個或多個 DispatchKey::Autograd 鍵來註冊。任何其他基於 DispatchKey 的註冊都可能導致未定義行為。
test_faketensor: 如果運算元具有 FakeTensor 核心(並且是正確的)。FakeTensor 核心對於運算元與 PyTorch 編譯 API(torch.compile/export/FX)一起工作是必需的(但不是充分條件)。我們檢查是否為該運算元註冊了 FakeTensor 核心(也稱為 meta 核心)並且它是正確的。此測試會獲取在實際 Tensor 上執行運算元的結果以及在 FakeTensor 上執行運算元的結果,並檢查它們是否具有相同的 Tensor 元資料(大小/步幅/dtype/裝置/等)。
test_aot_dispatch_dynamic: 如果運算元在 PyTorch 編譯 API(torch.compile/export/FX)下表現正確。這會檢查在 eager 模式 PyTorch 和 torch.compile 下的輸出(以及適用的梯度)是否相同。此測試是 `test_faketensor` 的超集,並且是一個端到端測試;它還測試運算元是否支援功能化,以及反向傳播(如果存在)是否也支援 FakeTensor 和功能化。
為了獲得最佳結果,請使用代表性輸入集多次呼叫 `opcheck`。如果您的運算元支援 autograd,請使用 `requires_grad = True` 的輸入呼叫 `opcheck`;如果您的運算元支援多個裝置(例如 CPU 和 CUDA),請在所有支援的裝置上使用輸入呼叫 `opcheck`。
- 引數
op (Union[OpOverload, OpOverloadPacket, CustomOpDef]) – 運算元。必須是使用
torch.library.custom_op()裝飾的函式,或者是 torch.ops.* 中的 OpOverload/OpOverloadPacket(例如,torch.ops.aten.sin, torch.ops.mylib.foo)test_utils (Union[str, Sequence[str]]) – 我們應該執行的測試。預設值:全部。示例:("test_schema", "test_faketensor")
raise_exception (bool) – 我們是否應該在第一次出錯時丟擲異常。如果為 False,我們將返回一個包含每個測試是否透過的資訊的字典。
rtol (Optional[float]) – 浮點數比較的相對容差。如果指定,則
atol也必須指定。如果省略,則選擇基於 `dtype` 的預設值(請參閱torch.testing.assert_close()中的表格)。atol (Optional[float]) – 浮點數比較的絕對容差。如果指定,則
rtol也必須指定。如果省略,則選擇基於 `dtype` 的預設值(請參閱torch.testing.assert_close()中的表格)。
- 返回型別
警告
`opcheck` 和 `torch.autograd.gradcheck` 測試的內容不同;`opcheck` 測試您對 torch.library API 的使用是否正確,而 `torch.autograd.gradcheck` 測試您的 autograd 公式是否在數學上是正確的。對於支援梯度計算的自定義運算元,請同時使用兩者進行測試。
示例
>>> @torch.library.custom_op("mylib::numpy_mul", mutates_args=()) >>> def numpy_mul(x: Tensor, y: float) -> Tensor: >>> x_np = x.numpy(force=True) >>> z_np = x_np * y >>> return torch.from_numpy(z_np).to(x.device) >>> >>> @numpy_mul.register_fake >>> def _(x, y): >>> return torch.empty_like(x) >>> >>> def setup_context(ctx, inputs, output): >>> y, = inputs >>> ctx.y = y >>> >>> def backward(ctx, grad): >>> return grad * ctx.y, None >>> >>> numpy_mul.register_autograd(backward, setup_context=setup_context) >>> >>> sample_inputs = [ >>> (torch.randn(3), 3.14), >>> (torch.randn(2, 3, device='cuda'), 2.718), >>> (torch.randn(1, 10, requires_grad=True), 1.234), >>> (torch.randn(64, 64, device='cuda', requires_grad=True), 90.18), >>> ] >>> >>> for args in sample_inputs: >>> torch.library.opcheck(numpy_mul, args)
在 Python 中建立新的自定義運算元#
使用 torch.library.custom_op() 來建立新的自定義運算元。
- torch.library.custom_op(name, fn=None, /, *, mutates_args, device_types=None, schema=None, tags=None)[source]#
將一個函式包裝成自定義運算元。
您可能希望建立自定義運算元的一些原因包括:- 包裝第三方庫或自定義核心以使其與 Autograd 等 PyTorch 子系統協同工作。- 防止 torch.compile/export/FX 跟蹤窺探您的函式內部。
此 API 用作函式上方的裝飾器(請參閱示例)。提供的函式必須具有型別提示;這些對於與 PyTorch 的各種子系統進行介面是必需的。
- 引數
name (str) – 自定義運算元的名稱,格式為“{namespace}::{name}”,例如,“mylib::my_linear”。該名稱用作運算元在 PyTorch 子系統(例如,torch.export、FX 圖)中的穩定識別符號。為避免名稱衝突,請使用您的專案名稱作為名稱空間;例如,pytorch/fbgemm 中的所有自定義運算元都使用“fbgemm”作為名稱空間。
mutates_args (Iterable[str] or "unknown") – 函式修改的引數名稱。這必須是準確的,否則行為將是未定義的。如果為“unknown”,則悲觀地假定運算元的所有輸入都被修改。
device_types (None | str | Sequence[str]) – 函式有效的裝置型別。如果未提供裝置型別,則該函式用作所有裝置型別的預設實現。示例:“cpu”、“cuda”。當為接受無 Tensor 的運算元註冊特定裝置實現時,我們要求該運算元具有“device: torch.device 引數”。
schema (None | str) – 運算元的模式字串。如果為 None(推薦),我們將從其型別註解中推斷運算元的模式。我們建議讓 PyTorch 推斷模式,除非您有特定原因不這樣做。示例:“(Tensor x, int y) -> (Tensor, Tensor)”。
- 返回型別
Union[Callable[[Callable[[…], object]], CustomOpDef], CustomOpDef]
注意
我們建議不要傳遞 `schema` 引數,而是讓 PyTorch 從型別註解中推斷它。編寫自己的模式容易出錯。如果您希望的模式與我們從型別註解推斷的模式不同,您可能希望提供自己的模式。有關如何編寫模式字串的更多資訊,請參閱 此處。
- 示例:
>>> import torch >>> from torch import Tensor >>> from torch.library import custom_op >>> import numpy as np >>> >>> @custom_op("mylib::numpy_sin", mutates_args=()) >>> def numpy_sin(x: Tensor) -> Tensor: >>> x_np = x.cpu().numpy() >>> y_np = np.sin(x_np) >>> return torch.from_numpy(y_np).to(device=x.device) >>> >>> x = torch.randn(3) >>> y = numpy_sin(x) >>> assert torch.allclose(y, x.sin()) >>> >>> # Example of a custom op that only works for one device type. >>> @custom_op("mylib::numpy_sin_cpu", mutates_args=(), device_types="cpu") >>> def numpy_sin_cpu(x: Tensor) -> Tensor: >>> x_np = x.numpy() >>> y_np = np.sin(x_np) >>> return torch.from_numpy(y_np) >>> >>> x = torch.randn(3) >>> y = numpy_sin_cpu(x) >>> assert torch.allclose(y, x.sin()) >>> >>> # Example of a custom op that mutates an input >>> @custom_op("mylib::numpy_sin_inplace", mutates_args={"x"}, device_types="cpu") >>> def numpy_sin_inplace(x: Tensor) -> None: >>> x_np = x.numpy() >>> np.sin(x_np, out=x_np) >>> >>> x = torch.randn(3) >>> expected = x.sin() >>> numpy_sin_inplace(x) >>> assert torch.allclose(x, expected) >>> >>> # Example of a factory function >>> @torch.library.custom_op("mylib::bar", mutates_args={}, device_types="cpu") >>> def bar(device: torch.device) -> Tensor: >>> return torch.ones(3) >>> >>> bar("cpu")
- torch.library.triton_op(name, fn=None, /, *, mutates_args, schema=None)[source]#
建立一個實現由 1 個或多個 triton 核心支援的自定義運算元。
這是使用 triton 核心與 PyTorch 的一種更結構化的方法。優先使用沒有 `torch.library` 自定義運算元包裝器(如 `torch.library.custom_op()`、`torch.library.triton_op()`)的 triton 核心,因為那樣更簡單;僅當您想建立一個行為類似於 PyTorch 內建運算元的運算元時,才使用 `torch.library.custom_op()`/`torch.library.triton_op()`。例如,您可以使用 `torch.library` 包裝器 API 來定義 triton 核心在傳入張量子類或 TorchDispatchMode 下的行為。
當實現由 1 個或多個 triton 核心組成時,請使用 `torch.library.triton_op()` 而不是 `torch.library.custom_op()`。`torch.library.custom_op()` 將自定義運算元視為不透明(`torch.compile()` 和 `torch.export.export()` 永遠不會跟蹤它們),但 `triton_op` 使這些子系統可以看到實現,從而允許它們最佳化 triton 核心。
請注意,`fn` 只能由 PyTorch 可理解的運算元和 triton 核心組成。在 `fn` 中呼叫的任何 triton 核心都必須包裝在對 `torch.library.wrap_triton()` 的呼叫中。
- 引數
name (str) – 自定義運算元的名稱,格式為“{namespace}::{name}”,例如,“mylib::my_linear”。該名稱用作運算元在 PyTorch 子系統(例如,torch.export、FX 圖)中的穩定識別符號。為避免名稱衝突,請使用您的專案名稱作為名稱空間;例如,pytorch/fbgemm 中的所有自定義運算元都使用“fbgemm”作為名稱空間。
mutates_args (Iterable[str] or "unknown") – 函式修改的引數名稱。這必須是準確的,否則行為將是未定義的。如果為“unknown”,則悲觀地假定運算元的所有輸入都被修改。
schema (None | str) – 運算元的模式字串。如果為 None(推薦),我們將從其型別註解中推斷運算元的模式。我們建議讓 PyTorch 推斷模式,除非您有特定原因不這樣做。示例:“(Tensor x, int y) -> (Tensor, Tensor)”。
- 返回型別
示例
>>> import torch >>> from torch.library import triton_op, wrap_triton >>> >>> import triton >>> from triton import language as tl >>> >>> @triton.jit >>> def add_kernel( >>> in_ptr0, >>> in_ptr1, >>> out_ptr, >>> n_elements, >>> BLOCK_SIZE: "tl.constexpr", >>> ): >>> pid = tl.program_id(axis=0) >>> block_start = pid * BLOCK_SIZE >>> offsets = block_start + tl.arange(0, BLOCK_SIZE) >>> mask = offsets < n_elements >>> x = tl.load(in_ptr0 + offsets, mask=mask) >>> y = tl.load(in_ptr1 + offsets, mask=mask) >>> output = x + y >>> tl.store(out_ptr + offsets, output, mask=mask) >>> >>> @triton_op("mylib::add", mutates_args={}) >>> def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: >>> output = torch.empty_like(x) >>> n_elements = output.numel() >>> >>> def grid(meta): >>> return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) >>> >>> # NB: we need to wrap the triton kernel in a call to wrap_triton >>> wrap_triton(add_kernel)[grid](x, y, output, n_elements, 16) >>> return output >>> >>> @torch.compile >>> def f(x, y): >>> return add(x, y) >>> >>> x = torch.randn(3, device="cuda") >>> y = torch.randn(3, device="cuda") >>> >>> z = f(x, y) >>> assert torch.allclose(z, x + y)
- torch.library.wrap_triton(triton_kernel, /)[source]#
允許透過 make_fx 或非嚴格 `torch.export` 將 triton 核心捕獲到圖中。
這些技術執行基於 Dispatcher 的跟蹤(透過 `__torch_dispatch__`),無法看到對原始 triton 核心的呼叫。`wrap_triton` API 將 triton 核心包裝到一個可呼叫的物件中,該物件實際上可以被跟蹤到圖中。
請將此 API 與 `torch.library.triton_op()` 一起使用。
示例
>>> import torch >>> import triton >>> from triton import language as tl >>> from torch.fx.experimental.proxy_tensor import make_fx >>> from torch.library import wrap_triton >>> >>> @triton.jit >>> def add_kernel( >>> in_ptr0, >>> in_ptr1, >>> out_ptr, >>> n_elements, >>> BLOCK_SIZE: "tl.constexpr", >>> ): >>> pid = tl.program_id(axis=0) >>> block_start = pid * BLOCK_SIZE >>> offsets = block_start + tl.arange(0, BLOCK_SIZE) >>> mask = offsets < n_elements >>> x = tl.load(in_ptr0 + offsets, mask=mask) >>> y = tl.load(in_ptr1 + offsets, mask=mask) >>> output = x + y >>> tl.store(out_ptr + offsets, output, mask=mask) >>> >>> def add(x, y): >>> output = torch.empty_like(x) >>> n_elements = output.numel() >>> >>> def grid_fn(meta): >>> return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) >>> >>> wrap_triton(add_kernel)[grid_fn](x, y, output, n_elements, 16) >>> return output >>> >>> x = torch.randn(3, device="cuda") >>> y = torch.randn(3, device="cuda") >>> gm = make_fx(add)(x, y) >>> print(gm.code) >>> # def forward(self, x_1, y_1): >>> # empty_like = torch.ops.aten.empty_like.default(x_1, pin_memory = False) >>> # triton_kernel_wrapper_mutation_proxy = triton_kernel_wrapper_mutation( >>> # kernel_idx = 0, constant_args_idx = 0, >>> # grid = [(1, 1, 1)], kwargs = { >>> # 'in_ptr0': x_1, 'in_ptr1': y_1, 'out_ptr': empty_like, >>> # 'n_elements': 3, 'BLOCK_SIZE': 16 >>> # }) >>> # return empty_like
- 返回型別
擴充套件自定義運算元(由 Python 或 C++ 建立)#
使用 `register.*` 方法,例如 `torch.library.register_kernel()` 和 `torch.library.register_fake()`,為任何運算元(它們可能是使用 `torch.library.custom_op()` 或透過 PyTorch 的 C++ 運算元註冊 API 建立的)新增實現。
- torch.library.register_kernel(op, device_types, func=None, /, *, lib=None)[source]#
為該運算元的特定裝置型別註冊一個實現。
一些有效的 device_types 是:“cpu”、“cuda”、“xla”、“mps”、“ipu”、“xpu”。此 API 可用作裝飾器。
- 引數
- 示例:
>>> import torch >>> from torch import Tensor >>> from torch.library import custom_op >>> import numpy as np >>> >>> # Create a custom op that works on cpu >>> @custom_op("mylib::numpy_sin", mutates_args=(), device_types="cpu") >>> def numpy_sin(x: Tensor) -> Tensor: >>> x_np = x.numpy() >>> y_np = np.sin(x_np) >>> return torch.from_numpy(y_np) >>> >>> # Add implementations for the cuda device >>> @torch.library.register_kernel("mylib::numpy_sin", "cuda") >>> def _(x): >>> x_np = x.cpu().numpy() >>> y_np = np.sin(x_np) >>> return torch.from_numpy(y_np).to(device=x.device) >>> >>> x_cpu = torch.randn(3) >>> x_cuda = x_cpu.cuda() >>> assert torch.allclose(numpy_sin(x_cpu), x_cpu.sin()) >>> assert torch.allclose(numpy_sin(x_cuda), x_cuda.sin())
- torch.library.register_autocast(op, device_type, cast_inputs, /, *, lib=None)[source]#
為該自定義運算元註冊一個自動強制轉換分派規則。
有效的 `device_type` 包括:“cpu”和“cuda”。
- 引數
op (str | OpOverload) – 要註冊自動強制轉換分派規則的運算元。
device_type (str) – 要使用的裝置型別。“cuda”或“cpu”。該型別與 `torch.device` 的 `type` 屬性相同。因此,您可以使用 `Tensor.device.type` 來獲取張量的裝置型別。
cast_inputs (
torch.dtype) – 當自定義運算元在啟用了自動強制轉換的區域執行時,將輸入的浮點 Tensor 強制轉換為目標 dtype(非浮點 Tensor 不受影響),然後停用自動強制轉換來執行自定義運算元。lib (Optional[Library]) – 如果提供,此註冊的生命週期
- 示例:
>>> import torch >>> from torch import Tensor >>> from torch.library import custom_op >>> >>> # Create a custom op that works on cuda >>> @torch.library.custom_op("mylib::my_sin", mutates_args=()) >>> def my_sin(x: Tensor) -> Tensor: >>> return torch.sin(x) >>> >>> # Register autocast dispatch rule for the cuda device >>> torch.library.register_autocast("mylib::my_sin", "cuda", torch.float16) >>> >>> x = torch.randn(3, dtype=torch.float32, device="cuda") >>> with torch.autocast("cuda", dtype=torch.float16): >>> y = torch.ops.mylib.my_sin(x) >>> assert y.dtype == torch.float16
- torch.library.register_autograd(op, backward, /, *, setup_context=None, lib=None)[source]#
為該自定義運算元註冊一個反向傳播公式。
為了讓運算元能夠與 autograd 一起工作,您需要註冊一個反向傳播公式:1.您必須透過提供一個“backward”函式來告訴 PyTorch 如何在反向傳播過程中計算梯度。2.如果您在計算梯度時需要正向傳播的任何值,可以使用 `setup_context` 來為反向傳播儲存值。
`backward` 在反向傳播過程中執行。它接受 `(ctx, *grads)`:- `grads` 是一個或多個梯度。梯度的數量與運算元的輸出數量相匹配。`ctx` 物件與 `torch.autograd.Function` 使用的 `ctx` 物件相同。`backward_fn` 的語義與 `torch.autograd.Function.backward()` 相同。
`setup_context(ctx, inputs, output)` 在正向傳播過程中執行。請透過 `torch.autograd.function.FunctionCtx.save_for_backward()` 或將它們作為 `ctx` 的屬性來將反向傳播所需的值儲存到 `ctx` 物件上。如果您的自定義運算元有僅關鍵字引數,我們期望 `setup_context` 的簽名是 `setup_context(ctx, inputs, keyword_only_inputs, output)`。
`setup_context_fn` 和 `backward_fn` 都必須是可跟蹤的。也就是說,它們不能直接訪問 `torch.Tensor.data_ptr()`,並且它們不能依賴或修改全域性狀態。如果您需要一個不可跟蹤的反向傳播,您可以將其製作成一個單獨的自定義運算元,並在 `backward_fn` 中呼叫它。
如果您需要在不同裝置上具有不同的 autograd 行為,那麼我們建議建立兩個不同的自定義運算元,一個用於需要不同行為的裝置,並在執行時在它們之間切換。
示例
>>> import torch >>> import numpy as np >>> from torch import Tensor >>> >>> @torch.library.custom_op("mylib::numpy_sin", mutates_args=()) >>> def numpy_sin(x: Tensor) -> Tensor: >>> x_np = x.cpu().numpy() >>> y_np = np.sin(x_np) >>> return torch.from_numpy(y_np).to(device=x.device) >>> >>> def setup_context(ctx, inputs, output) -> Tensor: >>> x, = inputs >>> ctx.save_for_backward(x) >>> >>> def backward(ctx, grad): >>> x, = ctx.saved_tensors >>> return grad * x.cos() >>> >>> torch.library.register_autograd( ... "mylib::numpy_sin", backward, setup_context=setup_context ... ) >>> >>> x = torch.randn(3, requires_grad=True) >>> y = numpy_sin(x) >>> (grad_x,) = torch.autograd.grad(y, x, torch.ones_like(y)) >>> assert torch.allclose(grad_x, x.cos()) >>> >>> # Example with a keyword-only arg >>> @torch.library.custom_op("mylib::numpy_mul", mutates_args=()) >>> def numpy_mul(x: Tensor, *, val: float) -> Tensor: >>> x_np = x.cpu().numpy() >>> y_np = x_np * val >>> return torch.from_numpy(y_np).to(device=x.device) >>> >>> def setup_context(ctx, inputs, keyword_only_inputs, output) -> Tensor: >>> ctx.val = keyword_only_inputs["val"] >>> >>> def backward(ctx, grad): >>> return grad * ctx.val >>> >>> torch.library.register_autograd( ... "mylib::numpy_mul", backward, setup_context=setup_context ... ) >>> >>> x = torch.randn(3, requires_grad=True) >>> y = numpy_mul(x, val=3.14) >>> (grad_x,) = torch.autograd.grad(y, x, torch.ones_like(y)) >>> assert torch.allclose(grad_x, torch.full_like(x, 3.14))
- torch.library.register_fake(op, func=None, /, *, lib=None, _stacklevel=1, allow_override=False)[source]#
為該運算元註冊一個 FakeTensor 實現(“fake impl”)。
也稱為“meta kernel”、“abstract impl”。
“FakeTensor 實現”指定了該運算元在不包含資料的 Tensor(“FakeTensor”)上的行為。給定具有特定屬性(大小/步幅/儲存偏移量/裝置)的輸入 Tensor,它指定輸出 Tensor 的屬性。
FakeTensor 實現具有與運算元相同的簽名。它同時用於 FakeTensor 和 meta Tensor。要編寫 FakeTensor 實現,請假定運算元的所有 Tensor 輸入都是常規的 CPU/CUDA/Meta Tensor,但它們沒有儲存,並且您試圖返回常規的 CPU/CUDA/Meta Tensor 作為輸出。FakeTensor 實現只能由 PyTorch 操作組成(並且不能直接訪問任何輸入或中間 Tensor 的儲存或資料)。
此 API 可用作裝飾器(參見示例)。
有關自定義運算元的詳細指南,請參閱 https://pytorch.com.tw/tutorials/advanced/custom_ops_landing_page.html
- 引數
示例
>>> import torch >>> import numpy as np >>> from torch import Tensor >>> >>> # Example 1: an operator without data-dependent output shape >>> @torch.library.custom_op("mylib::custom_linear", mutates_args=()) >>> def custom_linear(x: Tensor, weight: Tensor, bias: Tensor) -> Tensor: >>> raise NotImplementedError("Implementation goes here") >>> >>> @torch.library.register_fake("mylib::custom_linear") >>> def _(x, weight, bias): >>> assert x.dim() == 2 >>> assert weight.dim() == 2 >>> assert bias.dim() == 1 >>> assert x.shape[1] == weight.shape[1] >>> assert weight.shape[0] == bias.shape[0] >>> assert x.device == weight.device >>> >>> return (x @ weight.t()) + bias >>> >>> with torch._subclasses.fake_tensor.FakeTensorMode(): >>> x = torch.randn(2, 3) >>> w = torch.randn(3, 3) >>> b = torch.randn(3) >>> y = torch.ops.mylib.custom_linear(x, w, b) >>> >>> assert y.shape == (2, 3) >>> >>> # Example 2: an operator with data-dependent output shape >>> @torch.library.custom_op("mylib::custom_nonzero", mutates_args=()) >>> def custom_nonzero(x: Tensor) -> Tensor: >>> x_np = x.numpy(force=True) >>> res = np.stack(np.nonzero(x_np), axis=1) >>> return torch.tensor(res, device=x.device) >>> >>> @torch.library.register_fake("mylib::custom_nonzero") >>> def _(x): >>> # Number of nonzero-elements is data-dependent. >>> # Since we cannot peek at the data in an fake impl, >>> # we use the ctx object to construct a new symint that >>> # represents the data-dependent size. >>> ctx = torch.library.get_ctx() >>> nnz = ctx.new_dynamic_size() >>> shape = [nnz, x.dim()] >>> result = x.new_empty(shape, dtype=torch.int64) >>> return result >>> >>> from torch.fx.experimental.proxy_tensor import make_fx >>> >>> x = torch.tensor([0, 1, 2, 3, 4, 0]) >>> trace = make_fx(torch.ops.mylib.custom_nonzero, tracing_mode="symbolic")(x) >>> trace.print_readable() >>> >>> assert torch.allclose(trace(x), torch.ops.mylib.custom_nonzero(x))
- torch.library.register_vmap(op, func=None, /, *, lib=None)[source]#
註冊一個 vmap 實現以支援該自定義運算元的 `torch.vmap()`。
此 API 可用作裝飾器(參見示例)。
為了讓運算元能夠與 `torch.vmap()` 一起工作,您可能需要註冊一個具有以下簽名的 vmap 實現:
vmap_func(info, in_dims: Tuple[Optional[int]], *args, **kwargs),其中 `*args` 和 `**kwargs` 是 `op` 的引數和關鍵字引數。我們不支援僅關鍵字引數的 Tensor 引數。
它指定了如何計算 `op` 的批次版本,給定輸入和額外的維度(由 `in_dims` 指定)。
對於 `args` 中的每個引數,`in_dims` 都有一個對應的 `Optional[int]`。如果引數不是 Tensor 或未對引數進行 vmap,則為 `None`,否則它是一個指定 Tensor 的哪個維度正在被 vmap 的整數。
`info` 是可能有助於此的附加元資料的集合:`info.batch_size` 指定被 vmap 的維度的尺寸,而 `info.randomness` 是傳遞給 `torch.vmap()` 的 `randomness` 選項。
函式 `func` 的返回值為 `(output, out_dims)` 元組。類似於 `in_dims`,`out_dims` 的結構應與 `output` 相同,並且每個輸出都包含一個 `out_dim`,指定輸出是否具有 vmap 維度以及它在其中的索引。
示例
>>> import torch >>> import numpy as np >>> from torch import Tensor >>> from typing import Tuple >>> >>> def to_numpy(tensor): >>> return tensor.cpu().numpy() >>> >>> lib = torch.library.Library("mylib", "FRAGMENT") >>> @torch.library.custom_op("mylib::numpy_cube", mutates_args=()) >>> def numpy_cube(x: Tensor) -> Tuple[Tensor, Tensor]: >>> x_np = to_numpy(x) >>> dx = torch.tensor(3 * x_np ** 2, device=x.device) >>> return torch.tensor(x_np ** 3, device=x.device), dx >>> >>> def numpy_cube_vmap(info, in_dims, x): >>> result = numpy_cube(x) >>> return result, (in_dims[0], in_dims[0]) >>> >>> torch.library.register_vmap(numpy_cube, numpy_cube_vmap) >>> >>> x = torch.randn(3) >>> torch.vmap(numpy_cube)(x) >>> >>> @torch.library.custom_op("mylib::numpy_mul", mutates_args=()) >>> def numpy_mul(x: Tensor, y: Tensor) -> Tensor: >>> return torch.tensor(to_numpy(x) * to_numpy(y), device=x.device) >>> >>> @torch.library.register_vmap("mylib::numpy_mul") >>> def numpy_mul_vmap(info, in_dims, x, y): >>> x_bdim, y_bdim = in_dims >>> x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1) >>> y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1) >>> result = x * y >>> result = result.movedim(-1, 0) >>> return result, 0 >>> >>> >>> x = torch.randn(3) >>> y = torch.randn(3) >>> torch.vmap(numpy_mul)(x, y)
注意
vmap 函式應旨在保留整個自定義運算元的語義。也就是說,`grad(vmap(op))` 應可替換為 `grad(map(op))`。
如果您的自定義運算元在反向傳播過程中有任何自定義行為,請牢記這一點。
- torch.library.impl_abstract(qualname, func=None, *, lib=None, _stacklevel=1)[source]#
此 API 在 PyTorch 2.4 中重新命名為 `torch.library.register_fake()`。請改用該 API。
- torch.library.get_ctx()[source]#
`get_ctx()` 返回當前的 AbstractImplCtx 物件。
呼叫 `get_ctx()` 僅在 fake impl 內部有效(有關更多用法詳細資訊,請參見 `torch.library.register_fake()`)。
- 返回型別
FakeImplCtx
- torch.library.register_torch_dispatch(op, torch_dispatch_class, func=None, /, *, lib=None)[source]#
為給定的運算元和 `torch_dispatch_class` 註冊一個 torch_dispatch 規則。
這允許開放式註冊來指定運算元與 `torch_dispatch_class` 之間的行為,而無需直接修改 `torch_dispatch_class` 或運算元。
`torch_dispatch_class` 要麼是具有 `__torch_dispatch__` 的 Tensor 子類,要麼是 TorchDispatchMode。
如果是 Tensor 子類,我們期望 `func` 具有以下簽名:` (cls, func: OpOverload, types: Tuple[type, ...], args, kwargs) -> Any`
如果是 TorchDispatchMode,我們期望 `func` 具有以下簽名:` (mode, func: OpOverload, types: Tuple[type, ...], args, kwargs) -> Any`
`args` 和 `kwargs` 將已在 `__torch_dispatch__` 中以相同方式進行規範化(請參閱 __torch_dispatch__ 呼叫約定)。
示例
>>> import torch >>> >>> @torch.library.custom_op("mylib::foo", mutates_args={}) >>> def foo(x: torch.Tensor) -> torch.Tensor: >>> return x.clone() >>> >>> class MyMode(torch.utils._python_dispatch.TorchDispatchMode): >>> def __torch_dispatch__(self, func, types, args=(), kwargs=None): >>> return func(*args, **kwargs) >>> >>> @torch.library.register_torch_dispatch("mylib::foo", MyMode) >>> def _(mode, func, types, args, kwargs): >>> x, = args >>> return x + 1 >>> >>> x = torch.randn(3) >>> y = foo(x) >>> assert torch.allclose(y, x) >>> >>> with MyMode(): >>> y = foo(x) >>> assert torch.allclose(y, x + 1)
- torch.library.infer_schema(prototype_function, /, *, mutates_args, op_name=None)[source]#
解析給定函式(帶型別提示)的模式。模式從函式的型別提示中推斷出來,並可用於定義新運算元。
我們做出以下假設:
無輸出會別名任何輸入或彼此。
- 字串型別註解“device, dtype, Tensor, types”,沒有庫規範,將被假定為 torch.*。類似地,字串型別註解“Optional, List, Sequence, Union”沒有庫規範,將被假定為 typing.*。
- 只有 `mutates_args` 中列出的引數才會被修改。如果 `mutates_args` 為“unknown”,則假定運算元的所有輸入都被修改。
呼叫者(例如,自定義運算元 API)負責檢查這些假設。
- 引數
- 返回
推斷出的模式。
- 返回型別
示例
>>> def foo_impl(x: torch.Tensor) -> torch.Tensor: >>> return x.sin() >>> >>> infer_schema(foo_impl, op_name="foo", mutates_args={}) foo(Tensor x) -> Tensor >>> >>> infer_schema(foo_impl, mutates_args={}) (Tensor x) -> Tensor
- class torch._library.custom_ops.CustomOpDef(namespace, name, schema, fn, tags=None)[source]#
`CustomOpDef` 是一個函式包裝器,它將函式轉換為自定義運算元。
它具有各種方法來為此自定義運算元註冊附加行為。
您不應直接例項化 `CustomOpDef`;而是使用 `torch.library.custom_op()` API。
- set_kernel_enabled(device_type, enabled=True)[source]#
停用或重新啟用此自定義運算元已註冊的核心。
如果核心已停用/啟用,則此操作無效。
注意
如果核心先被停用然後註冊,則它一直處於停用狀態,直到再次啟用。
示例
>>> inp = torch.randn(1) >>> >>> # define custom op `f`. >>> @custom_op("mylib::f", mutates_args=()) >>> def f(x: Tensor) -> Tensor: >>> return torch.zeros(1) >>> >>> print(f(inp)) # tensor([0.]), default kernel >>> >>> @f.register_kernel("cpu") >>> def _(x): >>> return torch.ones(1) >>> >>> print(f(inp)) # tensor([1.]), CPU kernel >>> >>> # temporarily disable the CPU kernel >>> with f.set_kernel_enabled("cpu", enabled = False): >>> print(f(inp)) # tensor([0.]) with CPU kernel disabled
- torch.library.get_kernel(op, dispatch_key)[source]#
返回給定運算元和分派鍵的計算出的核心。
此函式檢索將為給定的運算元和分派鍵組合執行的核心。返回的 SafeKernelFunction 可用於以打包方式呼叫核心。此函式 intended 用例是檢索給定分派鍵的原始核心,然後向同一分派鍵註冊另一個核心,該核心在某些情況下會呼叫原始核心。
- 引數
op (Union[str, OpOverload, CustomOpDef]) – 運算元名稱(連同過載)或 OpOverload 物件。可以是字串(例如,“aten::add.Tensor”)、OpOverload 或 CustomOpDef。
dispatch_key (str | torch.DispatchKey) – 要獲取核心的分派鍵。可以是字串(例如,“CPU”、“CUDA”)或 DispatchKey 列舉值。
- 返回
- 一個安全的核心函式,可用於
呼叫核心。
- 返回型別
torch._C._SafeKernelFunction
- 引發
RuntimeError – 如果運算元不存在。
示例
>>> # Get the CPU kernel for torch.add >>> kernel = torch.library.get_kernel("aten::add.Tensor", "CPU") >>> >>> # You can also use DispatchKey enum >>> kernel = torch.library.get_kernel("aten::add.Tensor", torch.DispatchKey.CPU) >>> >>> # Or use an OpOverload directly >>> kernel = torch.library.get_kernel(torch.ops.aten.add.Tensor, "CPU") >>> >>> # Example: Using get_kernel in a custom op with conditional dispatch >>> # Get the original kernel for torch.sin >>> original_sin_kernel = torch.library.get_kernel("aten::sin", "CPU") >>> >>> # If input has negative values, use original sin, otherwise return zeros >>> def conditional_sin_impl(dispatch_keys, x): >>> if (x < 0).any(): >>> return original_sin_kernel.call_boxed(dispatch_keys, x) >>> else: >>> return torch.zeros_like(x) >>> >>> lib = torch.library.Library("aten", "IMPL") >>> # with_keyset=True so the first argument to the impl is the current DispatchKeySet >>> which needs to be the first argument to ``kernel.call_boxed`` >>> lib.impl("sin", conditional_sin_impl, "CPU", with_keyset=True) >>> >>> # Test the conditional behavior >>> x_positive = torch.tensor([1.0, 2.0]) >>> x_mixed = torch.tensor([-1.0, 2.0]) >>> torch.sin(x_positive) tensor([0., 0.]) >>> torch.sin(x_mixed) tensor([-0.8415, 0.9093])
底層 API#
以下 API 是 PyTorch 底層運算元註冊 API 的直接繫結。
警告
底層運算元註冊 API 和 PyTorch Dispatcher 是一個複雜的 PyTorch 概念。我們建議您在可能的情況下使用上面的更高級別的 API(不需要 torch.library.Library 物件)。這篇博文是瞭解 PyTorch Dispatcher 的一個好起點。
本教程在 Google Colab 上提供了關於如何使用此 API 的一些示例。
- class torch.library.Library(ns, kind, dispatch_key='')[source]#
一個類,用於建立可用於在 Python 中註冊新運算元或覆蓋現有庫中的運算元的庫。使用者可以選擇傳入一個分派鍵名稱,如果他們只想註冊對應於單個特定分派鍵的核心。
要建立一個用於覆蓋現有庫(名稱為 ns)中的運算元的庫,請將 kind 設定為“IMPL”。要建立一個新庫(名稱為 ns)來註冊新運算元,請將 kind 設定為“DEF”。要建立一個可能存在的庫的片段來註冊運算元(並繞過一個名稱空間只有一個庫的限制),請將 kind 設定為“FRAGMENT”。
- 引數
ns – 庫名稱
kind – “DEF”、“IMPL”、“FRAGMENT”
dispatch_key – PyTorch 分派鍵(預設為“”)
- define(schema, alias_analysis='', *, tags=())[source]#
在 ns 名稱空間中定義新運算元及其語義。
- 引數
- 返回
從模式推斷出的運算元名稱。
示例
>>> my_lib = Library("mylib", "DEF") >>> my_lib.define("sum(Tensor self) -> Tensor")
- fallback(fn, dispatch_key='', *, with_keyset=False)[source]#
將函式實現註冊為給定鍵的後備。
此函式僅適用於具有全域性名稱空間(“_”)的庫。
- 引數
fn – 用作給定分派鍵的後備的函式,或 `torch.library.fallthrough_kernel()` 以註冊一個後落。
dispatch_key – 應為輸入函式註冊的分派鍵。預設情況下,它使用建立庫時使用的分派鍵。
with_keyset – 控制在呼叫 `fn` 時當前分派器呼叫鍵集是否應作為第一個引數傳遞的標誌。這應該用於建立鍵集的適當方法以進行重分派呼叫。
示例
>>> my_lib = Library("_", "IMPL") >>> def fallback_kernel(op, *args, **kwargs): >>> # Handle all autocast ops generically >>> # ... >>> my_lib.fallback(fallback_kernel, "Autocast")
- impl(op_name, fn, dispatch_key='', *, with_keyset=False, allow_override=False)[source]#
為在庫中定義的運算元註冊函式實現。
- 引數
op_name – 運算元名稱(連同過載)或 OpOverload 物件。
fn – 作為輸入分派鍵的運算元實現,或 `fallthrough_kernel()` 以註冊後落的函式。
dispatch_key – 應為輸入函式註冊的分派鍵。預設情況下,它使用建立庫時使用的分派鍵。
with_keyset – 控制在呼叫 `fn` 時當前分派器呼叫鍵集是否應作為第一個引數傳遞的標誌。這應該用於建立鍵集的適當方法以進行重分派呼叫。
allow_override – 控制是否覆蓋已註冊核心實現的標誌。預設情況下,此標誌處於關閉狀態,並且如果您嘗試向已註冊核心的排程鍵註冊核心,則會報錯。
示例
>>> my_lib = Library("aten", "IMPL") >>> def div_cpu(self, other): >>> return self * (1 / other) >>> my_lib.impl("div.Tensor", div_cpu, "CPU")
- torch.library.define(qualname, schema, *, lib=None, tags=())[source]#
- torch.library.define(lib, schema, alias_analysis='')
定義一個新運算元。
在 PyTorch 中,定義一個 op(“operator”的縮寫)是一個兩步過程:- 我們需要定義 op(透過提供運算元名稱和模式)- 我們需要實現運算元如何與各種 PyTorch 子系統(如 CPU/CUDA Tensor、Autograd 等)互動的行為。
此入口點定義自定義運算元(第一步),您必須透過呼叫各種 `impl_*` API(如 `torch.library.impl()` 或 `torch.library.register_fake()`)來執行第二步。
- 引數
qualname (str) – 運算元的限定名稱。應為格式為“namespace::name”的字串,例如“aten::sin”。PyTorch 中的運算元需要一個名稱空間來避免名稱衝突;一個給定的運算元只能建立一次。如果您正在編寫 Python 庫,我們建議名稱空間為您的頂級模組的名稱。
schema (str) – 運算元的模式。例如,“(Tensor x) -> Tensor”表示接受一個 Tensor 並返回一個 Tensor 的 op。它不包含運算元名稱(該名稱在 `qualname` 中傳遞)。
lib (Optional[Library]) – 如果提供,此運算元的生命週期將與 Library 物件的生命週期繫結。
tags (Tag | Sequence[Tag]) – 應用於此運算元的一個或多個 torch.Tag。標記運算元會改變運算元在各種 PyTorch 子系統下的行為;請在應用之前仔細閱讀 torch.Tag 的文件。
- 示例:
>>> import torch >>> import numpy as np >>> >>> # Define the operator >>> torch.library.define("mylib::sin", "(Tensor x) -> Tensor") >>> >>> # Add implementations for the operator >>> @torch.library.impl("mylib::sin", "cpu") >>> def f(x): >>> return torch.from_numpy(np.sin(x.numpy())) >>> >>> # Call the new operator from torch.ops. >>> x = torch.randn(3) >>> y = torch.ops.mylib.sin(x) >>> assert torch.allclose(y, x.sin())
- torch.library.impl(lib, name, dispatch_key='')[source]#
- torch.library.impl(qualname: str, types: Union[str, Sequence[str]], func: Literal[None] = None, *, lib: Optional[Library] = None) Callable[[Callable[..., object]], None]
- torch.library.impl(qualname: str, types: Union[str, Sequence[str]], func: Callable[..., object], *, lib: Optional[Library] = None) None
- torch.library.impl(lib: Library, name: str, dispatch_key: str = '') Callable[[Callable[_P, _T]], Callable[_P, _T]]
為該運算元的特定裝置型別註冊一個實現。
您可以為 `types` 傳遞“default”以將此實現註冊為所有裝置型別的預設實現。請僅在實現真正支援所有裝置型別時才使用此選項;例如,當它是內建 PyTorch 運算元的組合時,這是真的。
此 API 可用作裝飾器。您可以使用巢狀裝飾器與此 API,前提是它們返回一個函式並放置在此 API 內部(參見示例 2)。
一些有效的型別是:“cpu”、“cuda”、“xla”、“mps”、“ipu”、“xpu”。
- 引數
示例
>>> import torch >>> import numpy as np >>> # Example 1: Register function. >>> # Define the operator >>> torch.library.define("mylib::mysin", "(Tensor x) -> Tensor") >>> >>> # Add implementations for the cpu device >>> @torch.library.impl("mylib::mysin", "cpu") >>> def f(x): >>> return torch.from_numpy(np.sin(x.numpy())) >>> >>> x = torch.randn(3) >>> y = torch.ops.mylib.mysin(x) >>> assert torch.allclose(y, x.sin()) >>> >>> # Example 2: Register function with decorator. >>> def custom_decorator(func): >>> def wrapper(*args, **kwargs): >>> return func(*args, **kwargs) + 1 >>> return wrapper >>> >>> # Define the operator >>> torch.library.define("mylib::sin_plus_one", "(Tensor x) -> Tensor") >>> >>> # Add implementations for the operator >>> @torch.library.impl("mylib::sin_plus_one", "cpu") >>> @custom_decorator >>> def f(x): >>> return torch.from_numpy(np.sin(x.numpy())) >>> >>> # Call the new operator from torch.ops. >>> x = torch.randn(3) >>> >>> y1 = torch.ops.mylib.sin_plus_one(x) >>> y2 = torch.sin(x) + 1 >>> assert torch.allclose(y1, y2)