recipes/torch_compile_torch_function_modes
在 Google Colab 中執行
Colab
下載 Notebook
Notebook
在 GitHub 上檢視
GitHub
注意
轉到底部 下載完整的示例程式碼。
(beta) 將 Torch Function 模式與 torch.compile 結合使用#
作者: Michael Lazos
- 本教程介紹瞭如何使用一個關鍵的 Torch 擴充套件點,
Torch Function 模式,與
torch.compile結合使用,在跟蹤時覆蓋 Torch 運算元(也稱為 **ops**)的行為,且沒有執行時開銷。
注意
本教程要求 PyTorch 2.7.0 或更高版本。
重寫一個 Torch op (torch.add -> torch.mul)#
在這個示例中,我們將使用 Torch Function 模式將加法運算替換為乘法運算。這種覆蓋在某個後端為給定 op 提供了自定義實現時很常見。
import torch
# exit cleanly if we are on a device that doesn't support ``torch.compile``
if torch.cuda.get_device_capability() < (7, 0):
print("Exiting because torch.compile is not supported on this device.")
import sys
sys.exit(0)
from torch.overrides import BaseTorchFunctionMode
# Define our mode, Note: ``BaseTorchFunctionMode``
# implements the actual invocation of func(..)
class AddToMultiplyMode(BaseTorchFunctionMode):
def __torch_function__(self, func, types, args=(), kwargs=None):
if func == torch.Tensor.add:
func = torch.mul
return super().__torch_function__(func, types, args, kwargs)
@torch.compile()
def test_fn(x, y):
return x + y * x # Note: infix operators map to torch.Tensor.* methods
x = torch.rand(2, 2)
y = torch.rand_like(x)
with AddToMultiplyMode():
z = test_fn(x, y)
assert torch.allclose(z, x * y * x)
# The mode can also be used within the compiled region as well like this:
@torch.compile()
def test_fn(x, y):
with AddToMultiplyMode():
return x + y * x # Note: infix operators map to torch.Tensor.* methods
x = torch.rand(2, 2)
y = torch.rand_like(x)
z = test_fn(x, y)
assert torch.allclose(z, x * y * x)
結論#
在本教程中,我們演示瞭如何從 torch.compile 內部使用 Torch Function 模式來覆蓋 torch.* 運算元的行為。這使得使用者可以在不產生每次呼叫 op 時呼叫 Torch Function 的執行時開銷的情況下,利用 Torch Function 模式的擴充套件性優勢。
有關 Torch Function 模式的其他示例和背景資訊,請參閱 使用模式擴充套件 Torch API。
指令碼總執行時間: (0 分鐘 10.322 秒)