擴充套件 PyTorch#
建立時間: 2017 年 1 月 16 日 | 最後更新時間: 2025 年 5 月 7 日
在本指南中,我們將介紹擴充套件 torch.nn、torch.autograd、torch 以及編寫自定義 C++ 擴充套件的方法。
新增新運算子#
PyTorch 提供了一個大型的運算子庫,這些運算子可用於張量(例如 torch.add()、torch.sum() 等)。然而,您可能希望為 PyTorch 新增自定義操作,並使其行為類似於 PyTorch 的內建運算子。要實現這一點,您必須透過 Python torch.library 或 C++ TORCH_LIBRARY API 將自定義操作註冊到 PyTorch。
有關更多詳細資訊,請參閱 PyTorch 自定義運算子著陸頁。
擴充套件 torch.autograd#
向 autograd 新增操作需要為每個操作實現一個新的 Function 子類。請注意,Functions 是 autograd 用來編碼操作歷史和計算梯度的內容。
本指南的第一個部分側重於後向模式 AD,因為它是最廣泛使用的功能。結尾部分討論了前向模式 AD 的擴充套件。
何時使用#
通常,如果您想在模型中執行不可微分的計算,或者依賴非 PyTorch 庫(例如 NumPy),但仍希望您的操作能與其他操作連結並與 autograd 引擎協同工作,則應實現自定義函式。
在某些情況下,自定義函式還可以用於提高效能和記憶體使用量:如果您使用 C++ 擴充套件 實現前向和後向傳遞,您可以將其包裝在 Function 中以與 autograd 引擎進行互動。如果您想減少後向傳遞所需的緩衝區數量,可以使用自定義函式將操作組合在一起。
何時不使用#
如果您已經能夠用 PyTorch 的內建操作編寫函式,那麼 autograd (很可能)已經能夠記錄其後向圖。在這種情況下,您無需自己實現後向函式。請考慮使用普通的 Python 函式。
如果您需要維護狀態(即可訓練引數),則應(也)使用自定義模組。有關擴充套件 torch.nn 的更多資訊,請參閱下面的部分。
如何使用#
請執行以下步驟:1. 繼承 Function 並實現 forward()、(可選)setup_context() 和 backward() 方法。2. 呼叫 ctx 引數上的適當方法。3. 宣告您的函式是否支援 雙後向。4. 使用 gradcheck 驗證您的梯度是否正確。
步驟 1: 繼承 Function 後,您需要定義 3 個方法
forward()是執行操作的程式碼。它可以接受任意數量的引數,其中一些引數是可選的(如果指定了預設值)。這裡接受所有種類的 Python 物件。跟蹤歷史的Tensor引數(即requires_grad=True)將在呼叫前被轉換為不跟蹤歷史的張量,並且它們的使用將被記錄在圖中。請注意,此邏輯不會遍歷列表/字典/任何其他資料結構,只會考慮直接作為呼叫引數的張量。您可以返回單個Tensor輸出,或者返回張量tuple(如果有多個輸出)。此外,請參考Function的文件,查詢可以僅從forward()呼叫的一些有用的方法。setup_context()(可選)。您可以編寫一個“合併”的forward(),它接受一個ctx物件,或者(從 PyTorch 2.0 開始)一個不接受ctx的獨立forward()和一個setup_context()方法,其中ctx的修改會發生。forward()應該包含計算邏輯,而setup_context()應該只負責ctx的修改(不包含任何計算)。通常,獨立的forward()和setup_context()更接近 PyTorch 原生操作的工作方式,因此更易於與各種 PyTorch 子系統組合。有關更多詳細資訊,請參閱 合併或分離 forward() 和 setup_context()。backward()(或vjp())定義了梯度公式。它將接收與輸出數量相同的Tensor引數,每個引數代表相對於該輸出的梯度。絕對不要就地修改這些引數,這一點很重要。它應該返回與輸入數量相同的張量,每個張量包含相對於其對應輸入的梯度。如果您的輸入不需要梯度(needs_input_grad是一個布林元組,指示每個輸入是否需要梯度計算),或者是非Tensor物件,您可以返回python:None。此外,如果forward()有可選引數,您可以返回比輸入更多的梯度,只要它們都是None。
步驟 2: 您有責任正確使用 ctx 中的函式,以確保新的 Function 能與 autograd 引擎正常工作。
save_for_backward()應用於儲存後向傳遞所需的任何張量(而不是直接儲存在ctx上)。您不能對非張量使用save_for_backward;您應該直接將它們儲存在ctx上。透過
save_for_backward儲存張量:1. 允許 autograd 引擎在autograd.Function的後向計算完成後立即將其清除。(如果張量直接儲存在ctx上,它將不必要地保留到 autograd 圖的生命週期結束——通常是到迭代結束。)2. 有助於避免某些引用迴圈(例如,由於autograd.Function的張量輸出本身會保留對 ctx 的引用)。3. 對於像啟用檢查點和解除安裝這樣的功能很重要,這些功能依賴於torch.autograd.graph.saved_tensors_hooks。如果儲存的張量既不是輸入也不是輸出,那麼您的
Function可能不支援雙後向(請參見步驟 3)。mark_dirty()必須用於標記前向函式就地修改的任何輸入。mark_non_differentiable()必須用於告知引擎輸出是否可微分。預設情況下,所有可微分型別的輸出張量都將被設定為要求梯度。不可微分型別的張量(即整數型別)永遠不會被標記為需要梯度。set_materialize_grads()可用於告訴 autograd 引擎在輸出不依賴於輸入的情況下最佳化梯度計算,方法是不具體化傳遞給後向函式的 grad 張量。也就是說,如果設定為 False,Python 中的 None 物件或 C++ 中的“未定義張量”(x.defined() 為 False 的張量 x)將不會在呼叫後向之前轉換為填充零的張量,因此您的程式碼需要像處理填充零的張量一樣處理這些物件。此設定的預設值為 True。
步驟 3: 如果您的 Function 不支援雙後向,您應透過使用 once_differentiable() 裝飾器顯式宣告。使用此裝飾器,嘗試透過您的函式進行雙後向操作將產生錯誤。有關雙後向的更多資訊,請參閱我們的雙後向教程。
步驟 4: 建議您使用 torch.autograd.gradcheck() 來檢查您的後向函式是否透過使用您的後向函式計算雅可比矩陣並將值與使用有限差分法數值計算的雅可比矩陣進行逐元素比較來正確計算前向函式的梯度。
示例#
下面是 Linear 函式的程式碼,附帶了額外的註釋
# Inherit from Function
class LinearFunction(Function):
# Note that forward, setup_context, and backward are @staticmethods
@staticmethod
def forward(input, weight, bias):
output = input.mm(weight.t())
if bias is not None:
output += bias.unsqueeze(0).expand_as(output)
return output
@staticmethod
# inputs is a Tuple of all of the inputs passed to forward.
# output is the output of the forward().
def setup_context(ctx, inputs, output):
input, weight, bias = inputs
ctx.save_for_backward(input, weight, bias)
# This function has only a single output, so it gets only one gradient
@staticmethod
def backward(ctx, grad_output):
# This is a pattern that is very convenient - at the top of backward
# unpack saved_tensors and initialize all gradients w.r.t. inputs to
# None. Thanks to the fact that additional trailing Nones are
# ignored, the return statement is simple even when the function has
# optional inputs.
input, weight, bias = ctx.saved_tensors
grad_input = grad_weight = grad_bias = None
# These needs_input_grad checks are optional and there only to
# improve efficiency. If you want to make your code simpler, you can
# skip them. Returning gradients for inputs that don't require it is
# not an error.
if ctx.needs_input_grad[0]:
grad_input = grad_output.mm(weight)
if ctx.needs_input_grad[1]:
grad_weight = grad_output.t().mm(input)
if bias is not None and ctx.needs_input_grad[2]:
grad_bias = grad_output.sum(0)
return grad_input, grad_weight, grad_bias
現在,為了更輕鬆地使用這些自定義操作,我們建議將其別名化或將其包裝在函式中。包裝在函式中使我們能夠支援預設引數和關鍵字引數
# Option 1: alias
linear = LinearFunction.apply
# Option 2: wrap in a function, to support default args and keyword args.
def linear(input, weight, bias=None):
return LinearFunction.apply(input, weight, bias)
在這裡,我們提供了另一個由非張量引數引數化的函式的示例
class MulConstant(Function):
@staticmethod
def forward(tensor, constant):
return tensor * constant
@staticmethod
def setup_context(ctx, inputs, output):
# ctx is a context object that can be used to stash information
# for backward computation
tensor, constant = inputs
ctx.constant = constant
@staticmethod
def backward(ctx, grad_output):
# We return as many input gradients as there were arguments.
# Gradients of non-Tensor arguments to forward must be None.
return grad_output * ctx.constant, None
在這裡,我們透過呼叫 set_materialize_grads(False) 來最佳化上面的示例
class MulConstant(Function):
@staticmethod
def forward(tensor, constant):
return tensor * constant
@staticmethod
def setup_context(ctx, inputs, output):
tensor, constant = inputs
ctx.set_materialize_grads(False)
ctx.constant = constant
@staticmethod
def backward(ctx, grad_output):
# Here we must handle None grad_output tensor. In this case we
# can skip unnecessary computations and just return None.
if grad_output is None:
return None, None
# We return as many input gradients as there were arguments.
# Gradients of non-Tensor arguments to forward must be None.
return grad_output * ctx.constant, None
如果您需要儲存前向傳遞中計算的任何“中間”張量,它們要麼必須作為輸出返回,要麼需要合併 forward 和 setup_context()(參見 合併或分離 forward() 和 setup_context())。請注意,這意味著如果您希望梯度流過這些中間值,您需要為它們定義梯度公式(另請參見 雙後向教程)。
class MyCube(torch.autograd.Function):
@staticmethod
def forward(x):
# We wish to save dx for backward. In order to do so, it must
# be returned as an output.
dx = 3 * x ** 2
result = x ** 3
return result, dx
@staticmethod
def setup_context(ctx, inputs, output):
x, = inputs
result, dx = output
ctx.save_for_backward(x, dx)
@staticmethod
def backward(ctx, grad_output, grad_dx):
x, dx = ctx.saved_tensors
# In order for the autograd.Function to work with higher-order
# gradients, we must add the gradient contribution of `dx`,
# which is grad_dx * 6 * x.
result = grad_output * dx + grad_dx * 6 * x
return result
# Wrap MyCube in a function so that it is clearer what the output is
def my_cube(x):
result, dx = MyCube.apply(x)
return result
注意
傳遞給 backward 的輸入,即 grad_output,也可以是跟蹤歷史的張量。因此,如果 backward 是使用可微分操作實現的(例如,呼叫另一個自定義 Function),則高階導數將正常工作。在這種情況下,使用 save_for_backward 儲存的張量也可以在後向中使用並具有反向傳播的梯度,但儲存在 ctx 中的張量將沒有反向傳播的梯度。如果您需要反向傳播張量在 ctx 中儲存,您應該將其設為自定義 Function 的輸出並使用 save_for_backward 儲存。
您可能需要檢查實現的後向方法是否實際計算了函式的導數。這可以透過將導數與使用小有限差分的數值近似進行比較來實現。
from torch.autograd import gradcheck
# gradcheck takes a tuple of tensors as input, check if your gradient
# evaluated with these tensors are close enough to numerical
# approximations and returns True if they all verify this condition.
input = (torch.randn(20,20,dtype=torch.double,requires_grad=True), torch.randn(30,20,dtype=torch.double,requires_grad=True))
test = gradcheck(linear, input, eps=1e-6, atol=1e-4)
print(test)
有關有限差分梯度比較的更多詳細資訊,請參閱 數值梯度檢查。如果您的函式用於高階導數(對後向傳遞進行微分),您可以使用同一包中的 gradgradcheck 函式來檢查高階導數。
合併或分離 forward() 和 setup_context()#
定義 Function 有兩種主要方式:
我們推薦第二種方法(獨立的 forward() 和 setup_context()),因為這更接近 PyTorch 原生操作的實現方式,並且它能與 torch.func 變換組合。然而,我們計劃在未來同時支援這兩種方法;合併 forward() 和 setup_context():允許更大的靈活性,因為您可以在不將中間結果作為輸出返回的情況下儲存它們。
請參閱上一節,瞭解如何使用獨立的 forward() 和 setup_context() 定義 Function。
下面是一個如何定義包含合併的 forward() 和 setup_context() 的 Function 的示例
class LinearFunction(Function):
@staticmethod
# ctx is the first argument to forward
def forward(ctx, input, weight, bias=None):
# The forward pass can use ctx.
ctx.save_for_backward(input, weight, bias)
output = input.mm(weight.t())
if bias is not None:
output += bias.unsqueeze(0).expand_as(output)
return output
@staticmethod
def backward(ctx, grad_output):
input, weight, bias = ctx.saved_tensors
grad_input = grad_weight = grad_bias = None
if ctx.needs_input_grad[0]:
grad_input = grad_output.mm(weight)
if ctx.needs_input_grad[1]:
grad_weight = grad_output.t().mm(input)
if bias is not None and ctx.needs_input_grad[2]:
grad_bias = grad_output.sum(0)
return grad_input, grad_weight, grad_bias
前向模式 AD#
重寫前向模式 AD 公式具有非常相似的 API,但存在一些細微差別。您可以實現 jvp() 函式。
它將接收與輸入數量相同的 Tensor 引數,每個引數代表相對於該輸入的梯度。它應該返回與輸出數量相同的張量,每個張量包含相對於其對應輸出的梯度。 jvp() 將在 forward() 方法之後、apply() 返回之前呼叫。
jvp() 與 backward() 函式有一些細微的差別
您可以使用 ctx 將資料從
forward()傳遞給jvp()函式。如果該狀態對於backward()不是必需的,您可以透過在jvp()函式末尾執行del ctx.foo來顯式釋放它。jvp()的實現必須是後向可微分的,或者顯式檢查前向模式梯度中沒有任何一個具有requires_grad設定。jvp()函式必須匹配forward()的檢視/就地行為。例如,如果第i個輸入被就地修改,那麼第i個梯度必須被就地修改。類似地,如果第j個輸出是第k個輸入的檢視。那麼返回的第j個輸出梯度必須是給定第k個輸入梯度的檢視。由於使用者無法指定需要計算哪個梯度,
jvp()函式應始終計算所有輸出的梯度。前向模式梯度會遵守
set_materialize_grads()設定的標誌,並且當該標誌停用時,您可能會收到 None 輸入梯度。
torch.func 變換和/或 torch.vmap()#
有關詳細資訊,請參閱 使用 autograd.Function 擴充套件 torch.func。
擴充套件 torch.nn#
nn 匯出了兩種介面:模組及其函式版本。您可以以兩種方式擴充套件它,但我們建議對所有持有引數或緩衝區的層使用模組,並建議對無引數操作(如啟用函式、池化等)使用函式形式。
上面部分已完全涵蓋了新增操作的函式版本。
新增 Module#
由於 nn 大量使用了 autograd,新增一個新的 Module 需要實現一個執行操作並能計算梯度的 Function。從現在開始,我們假設我們要實現一個 Linear 模組,並且我們已經按照上面的列表實現了該函式。所需程式碼很少。現在,需要實現兩個函式:
以下是如何實現 Linear 模組
class Linear(nn.Module):
def __init__(self, input_features, output_features, bias=True):
super().__init__()
self.input_features = input_features
self.output_features = output_features
# nn.Parameter is a special kind of Tensor, that will get
# automatically registered as Module's parameter once it's assigned
# as an attribute. Parameters and buffers need to be registered, or
# they won't appear in .parameters() (doesn't apply to buffers), and
# won't be converted when e.g. .cuda() is called. You can use
# .register_buffer() to register buffers.
# nn.Parameters require gradients by default.
self.weight = nn.Parameter(torch.empty(output_features, input_features))
if bias:
self.bias = nn.Parameter(torch.empty(output_features))
else:
# You should always register all possible parameters, but the
# optional ones can be None if you want.
self.register_parameter('bias', None)
# Not a very smart way to initialize weights
nn.init.uniform_(self.weight, -0.1, 0.1)
if self.bias is not None:
nn.init.uniform_(self.bias, -0.1, 0.1)
def forward(self, input):
# See the autograd section for explanation of what happens here.
return LinearFunction.apply(input, self.weight, self.bias)
def extra_repr(self):
# (Optional)Set the extra information about this module. You can test
# it by printing an object of this class.
return 'input_features={}, output_features={}, bias={}'.format(
self.input_features, self.output_features, self.bias is not None
)
擴充套件 torch Python API#
您可以建立自定義型別來模擬 Tensor,方法是定義一個具有與 Tensor 匹配的方法的自定義類。但是,如果您希望能夠將這些型別傳遞給 torch 頂級名稱空間中接受 Tensor 運算元的函式(如 torch.add())呢?
如果您的自定義 Python 型別定義了一個名為 __torch_function__ 的方法,當您的自定義類的例項傳遞給 torch 名稱空間中的函式時,PyTorch 將呼叫您的 __torch_function__ 實現。這使得您可以為 torch 名稱空間中的任何函式定義自定義實現,您的 __torch_function__ 實現可以呼叫它們,從而允許您的使用者使用他們已為 Tensor 編寫的現有 PyTorch 工作流來利用您的自定義型別。這適用於與 Tensor 無關的“鴨子”型別以及 Tensor 的使用者定義的子類。
使用 Tensor 型別擴充套件 torch#
為了具體化,讓我們從一個簡單的示例開始,該示例說明了 API 分派機制。我們將建立一個自定義型別來表示一個二維標量張量,該張量由階數 N 和對角線項上的值 value 引數化
class ScalarTensor(object):
def __init__(self, N, value):
self._N = N
self._value = value
def __repr__(self):
return "ScalarTensor(N={}, value={})".format(self._N, self._value)
def tensor(self):
return self._value * torch.eye(self._N)
該設計的第一個迭代版本不是非常有用的。 ScalarTensor 的主要功能是提供比基張量類更緊湊的標量張量字串表示形式
>>> d = ScalarTensor(5, 2)
>>> d
ScalarTensor(N=5, value=2)
>>> d.tensor()
tensor([[2., 0., 0., 0., 0.],
[0., 2., 0., 0., 0.],
[0., 0., 2., 0., 0.],
[0., 0., 0., 2., 0.],
[0., 0., 0., 0., 2.]])
如果我們嘗試使用此物件與 torch API,我們將遇到問題
>>> import torch
>>> torch.mean(d)
TypeError: mean(): argument 'input' (position 1) must be Tensor, not ScalarTensor
向 ScalarTensor 新增 __torch_function__ 實現使其能夠使上述操作成功。讓我們重新進行實現,這次新增一個 __torch_function__ 實現
HANDLED_FUNCTIONS = {}
class ScalarTensor(object):
def __init__(self, N, value):
self._N = N
self._value = value
def __repr__(self):
return "ScalarTensor(N={}, value={})".format(self._N, self._value)
def tensor(self):
return self._value * torch.eye(self._N)
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
if func not in HANDLED_FUNCTIONS or not all(
issubclass(t, (torch.Tensor, ScalarTensor))
for t in types
):
return NotImplemented
return HANDLED_FUNCTIONS[func](*args, **kwargs)
__torch_function__ 方法接受四個引數:func,它是被重寫的 torch API 函式的引用;types,實現 __torch_function__ 的 Tensor-like 型別列表;args,傳遞給函式的引數元組;以及 kwargs,傳遞給函式的關鍵字引數字典。它使用一個名為 HANDLED_FUNCTIONS 的全域性分派表來儲存自定義實現。此字典的鍵是 torch 名稱空間中的函式,值是 ScalarTensor 的實現。
注意
使用全域性分派表不是 __torch_function__ API 的強制要求,它只是用於組織重寫實現的一種有用的設計模式。
此類的定義不足以讓 torch.mean 在我們傳遞 ScalarTensor 時執行正確操作——我們還需要為 ScalarTensor 運算元定義 torch.mean 的實現,並將實現新增到 HANDLED_FUNCTIONS 分派表字典中。一種方法是定義一個裝飾器
import functools
def implements(torch_function):
"""Register a torch function override for ScalarTensor"""
def decorator(func):
functools.update_wrapper(func, torch_function)
HANDLED_FUNCTIONS[torch_function] = func
return func
return decorator
該裝飾器可以應用於我們重寫的實現
@implements(torch.mean)
def mean(input):
return float(input._value) / input._N
透過此更改,我們現在可以在 ScalarTensor 中使用 torch.mean
>>> d = ScalarTensor(5, 2)
>>> torch.mean(d)
0.4
當然,torch.mean 是最簡單的重寫函式型別的一個示例,因為它只接受一個運算元。我們可以使用相同的機制來重寫接受多個運算元的函式,其中任何一個都可能是定義了 __torch_function__ 的張量或類張量,例如對於 torch.add()
def ensure_tensor(data):
if isinstance(data, ScalarTensor):
return data.tensor()
return torch.as_tensor(data)
@implements(torch.add)
def add(input, other):
try:
if input._N == other._N:
return ScalarTensor(input._N, input._value + other._value)
else:
raise ValueError("Shape mismatch!")
except AttributeError:
return torch.add(ensure_tensor(input), ensure_tensor(other))
此版本在兩個運算元都是 ScalarTensor 例項時有一個快速路徑,還有一個較慢的路徑,當任一運算元不是 ScalarTensor 時,它會退化為將資料轉換為張量。這使得重寫函式在任一運算元是 ScalarTensor 或常規 Tensor 時都能正確工作。
>>> s = ScalarTensor(2, 2)
>>> torch.add(s, s)
ScalarTensor(N=2, value=4)
>>> t = torch.tensor([[1, 1,], [1, 1]])
>>> torch.add(s, t)
tensor([[3., 1.],
[1., 3.]])
請注意,我們對 add 的實現不像 torch.add() 那樣接受 alpha 或 out 作為關鍵字引數
>>> torch.add(s, s, alpha=2)
TypeError: add() got an unexpected keyword argument 'alpha'
為了速度和靈活性,__torch_function__ 分派機制不會檢查重寫函式的簽名是否與 torch API 中被重寫的函式的簽名是否匹配。對於某些應用程式,忽略可選引數是可以的,但為了確保與 Tensor 的完全相容性,torch API 函式的使用者實現應仔細地精確模擬被重寫的函式的 API。
torch API 中沒有顯式重寫的函式將從 __torch_function__ 返回 NotImplemented。如果所有具有 __torch_function__ 定義的運算元都返回 NotImplemented,PyTorch 將引發 TypeError。這意味著在大多數情況下,當一個型別的例項被傳遞時,沒有為該型別重寫的操作將引發 TypeError。
>>> torch.mul(s, 3)
TypeError: no implementation found for 'torch.mul' on types that
implement __torch_function__: [ScalarTensor]
在實踐中,這意味著如果您希望以這種方式實現 __torch_function__ 實現,您需要顯式實現完整的 torch API 或您關心的 API 的完整子集。這可能是一項艱鉅的任務,因為完整的 torch API 非常龐大。
另一種選擇是,對於未處理的操作,不返回 NotImplemented,而是當沒有重寫可用時,將 Tensor 傳遞給原始 torch 函式。例如,如果我們修改 ScalarTensor 的 __torch_function__ 實現如下
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
if func not in HANDLED_FUNCTIONS or not all(
issubclass(t, (torch.Tensor, ScalarTensor))
for t in types
):
args = [a.tensor() if hasattr(a, 'tensor') else a for a in args]
return func(*args, **kwargs)
return HANDLED_FUNCTIONS[func](*args, **kwargs)
那麼 torch.mul() 將正常工作,儘管即使兩個運算元都是 ScalarTensor 例項,返回型別也將始終是 Tensor 而不是 ScalarTensor。
>>> s = ScalarTensor(2, 2)
>>> torch.mul(s, s)
tensor([[4., 0.],
[0., 4.]])
另請參閱下面的 MetadataTensor 示例,瞭解此模式的另一種變體,但它始終返回 MetadataTensor 以便透過 torch API 中的操作傳播元資料。
__torch_function__ 協議旨在覆蓋整個 API,部分覆蓋可能會導致不良結果,特別是某些函式引發 TypeError。這對於子類尤其如此,其中 torch.add、torch.Tensor.__add__ 和 torch.Tensor.add 三者都必須被覆蓋,即使它們返回完全相同的結果。未能這樣做也可能導致無限遞迴。如果一個人需要實現 torch.Tensor 子類中的函式,他們必須在實現中使用 super().__torch_function__。
繼承 torch.Tensor#
從 1.7.0 版本開始,應用於 torch.Tensor 子類上的 torch.Tensor 方法和公共 torch.* 名稱空間中的函式將返回子類例項而不是 torch.Tensor 例項。
>>> class SubTensor(torch.Tensor):
... pass
>>> type(torch.add(SubTensor([0]), SubTensor([1]))).__name__
'SubTensor'
>>> type(torch.add(SubTensor([0]), torch.tensor([1]))).__name__
'SubTensor'
如果存在多個子類,預設情況下將選擇層次結構中最底層的一個。如果沒有唯一確定的方式來確定這種情況,則會引發 TypeError。
>>> type(torch.add(SubTensor2([0]), SubTensor([1]))).__name__
'SubTensor2'
>>> type(torch.add(SubTensor2([0]), torch.tensor([1]))).__name__
'SubTensor2'
>>> torch.add(SubTensor([0]), OtherSubTensor([1]))
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
TypeError: no implementation found for 'torch.add' on types that implement __torch_function__: [SubTensor, OtherSubTensor]
如果希望對所有張量方法進行全域性重寫,可以使用 __torch_function__。下面是一個記錄所有函式/方法呼叫的示例
class LoggingTensor(torch.Tensor):
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
# NOTE: Logging calls Tensor.__repr__, so we can't log __repr__ without infinite recursion
if func is not torch.Tensor.__repr__:
logging.info(f"func: {func.__name__}, args: {args!r}, kwargs: {kwargs!r}")
if kwargs is None:
kwargs = {}
return super().__torch_function__(func, types, args, kwargs)
然而,如果希望重寫 Tensor 子類上的方法,則可以透過直接重寫該方法(透過為子類定義它)來實現,或者使用 __torch_function__ 並與 func 匹配來實現。
在子類的 __torch_function__ 中,應該始終呼叫 super().__torch_function__(func, ...) 而不是直接呼叫 func,正如 1.7.0 版本之前的程式碼一樣。未能這樣做可能會導致 func 遞歸回 __torch_function__,從而導致無限遞迴。
使用 Tensor 包裝器型別擴充套件 torch#
另一種有用的情況是包裝 Tensor 的型別,無論是作為屬性還是透過繼承。下面我們實現這種型別的特例,即 MetadataTensor,它將元資料字典附加到 Tensor 上,並透過 torch 操作傳播。由於這是針對整個 torch API 的通用包裝,我們不需要單獨實現每個重寫,因此我們可以使 __torch_function__ 實現對允許的操作更加寬容。
class MetadataTensor(object):
def __init__(self, data, metadata=None, **kwargs):
self._t = torch.as_tensor(data, **kwargs)
self._metadata = metadata
def __repr__(self):
return "Metadata:\n{}\n\ndata:\n{}".format(self._metadata, self._t)
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
metadatas = tuple(a._metadata for a in args if hasattr(a, '_metadata'))
args = [getattr(a, '_t', a) for a in args]
assert len(metadatas) > 0
ret = func(*args, **kwargs)
return MetadataTensor(ret, metadata=metadatas[0])
這個簡單的實現不一定適用於 torch API 中的每個函式,但它足以處理大多數常見操作。
>>> metadata = {'owner': 'Ministry of Silly Walks'}
>>> m = MetadataTensor([[1, 2], [3, 4]], metadata=metadata)
>>> t = torch.tensor([[1, 2], [1, 2]])
>>> torch.add(t, m)
Metadata:
{'owner': 'Ministry of Silly Walks'}
data:
tensor([[2, 4],
[4, 6]])
>>> torch.mul(t, m)
Metadata:
{'owner': 'Ministry of Silly Walks'}
data:
tensor([[1, 4],
[3, 8]])
定義了 __torch_function__ 的多個型別上的操作#
使用具有 __torch_function__ 實現的多個不同型別的 torch API 是可能的,但需要特別小心。在這種情況下,規則是:
分派操作會收集每個運算元的所有不同的
__torch_function__實現,並按順序呼叫它們:子類在超類之前,否則按操作表示式中的從左到右的順序。如果返回的值不是
NotImplemented,則該值將作為結果返回。實現可以透過返回NotImplemented來註冊它們不實現操作。如果所有
__torch_function__實現都返回NotImplemented,PyTorch 將引發TypeError。
PyTorch API 重寫覆蓋率測試#
實現 __torch_function__ 的一個令人頭疼的方面是,如果某些操作有重寫而其他操作沒有,使用者最多隻會看到不一致的體驗,或者最壞的情況下,在使用沒有重寫的函式時會在執行時引發錯誤。為了簡化這個過程,PyTorch 提供了一個面向開發者的 API,用於確保對 __torch_function__ 重寫的全面支援。此 API 是私有的,並可能在未來進行更改,恕不另行通知。
首先,要獲取所有可重寫函式的列表,請使用 torch.overrides._get_overridable_functions。這將返回一個字典,其鍵是 PyTorch Python API 中的名稱空間,其值是該名稱空間中可以重寫的函式列表。例如,讓我們列印 torch.nn.functional 中前 5 個可重寫函式的名稱。
>>> from torch.overrides import get_overridable_functions
>>> func_dict = get_overridable_functions()
>>> nn_funcs = func_dict[torch.nn.functional]
>>> print([f.__name__ for f in nn_funcs[:5])
['adaptive_avg_pool1d', 'adaptive_avg_pool2d', 'adaptive_avg_pool3d',
'adaptive_max_pool1d', 'adaptive_max_pool1d_with_indices']
此函式列表使得可以遍歷所有可重寫函式,但實際上,僅憑此不足以編寫所有這些函式的測試,因為需要費力且手動地為每個測試複製每個函式的簽名。為了簡化這個過程,torch.overrides._get_testing_overrides 函式返回一個字典,該字典將 PyTorch API 中可重寫的函式對映到虛擬 lambda 函式,這些函式具有與原始函式相同的簽名,但無條件地返回 -1。這些函式最適合與 inspect 一起使用,以分析原始 PyTorch 函式的函式簽名。
>>> import inspect
>>> from torch.overrides import get_testing_overrides
>>> override_dict = get_testing_overrides()
>>> dummy_add = override_dict[torch.add]
>>> inspect.signature(dummy_add)
<Signature (input, other, out=None)>
最後,torch.overrides.get_ignored_functions 返回一個函式元組,這些函式明確不能被 __torch_function__ 重寫。此列表有助於確認 get_overridable_functions 返回的字典中不存在的函式不能被重寫。
擴充套件 torch 原生 API#
雖然 __torch_function__ 允許有效地擴充套件 PyTorch 的純 Python 元件的行為,但它不允許擴充套件用 C++ 實現的 PyTorch 部分。為此,Tensor 子類還可以定義 __torch_dispatch__,它能夠在 C++ 級別重寫行為。
為了有效地使用此功能,瞭解 PyTorch 的原生部分是如何實現的非常重要。其中最重要的組成部分是我們稱之為“排程器”的東西(您可以在這篇 部落格文章 中找到最佳描述,儘管它已略微過時)。正如其名稱所暗示的那樣,排程器負責為一個特定函式呼叫呼叫正確的後端函式。例如,當呼叫 torch.add(a, b) 時,排程器將檢查兩個引數,找出應該為該特定呼叫使用哪個“功能”(autograd、autocast、functionalization 等)和哪個“後端”(CPU、CUDA、MPS 等),最後呼叫所有正確的核心。核心的一個非常常見的操作是“重新排程”。例如,當在 GPU 上使用 autocast 執行神經網路時,第一次呼叫將是 autocast 核心,它將處理任何潛在的 autocast 邏輯並向下重新排程。下一項功能將是 autograd,它將正確建立 autograd 圖,然後向下重新排程。最後,我們到達 CUDA 的後端核心,它將啟動正確的 CUDA 核心並返回最終結果。在退出時,autograd 會將圖附加到輸出,最後,autocast 將有機會在退出時進行任何必要的更新。
排程器的一個配置是所有這些功能和後端鍵被呼叫的順序。最新的列表及其順序可以在 DispatchKey.h 中的 DispatchKey 列舉中找到。就擴充套件 torch 而言,用於此討論的重要子集是:
vmap -> Autocast -> Autograd -> ZeroTensor -> Neg/Conj -> Functionalize -> Python -> Backends
對於此討論最重要的鍵是 Python,因為每個具有 __torch_dispatch__ 方法定義的張量子類都將呼叫此功能。使用者定義的方法就是從那裡呼叫的,行為可以在那裡任意重寫。從那裡開始,再次呼叫提供的 func 將執行“重新排程”。
此實現的一些重要含義是:
此程式碼在“所有功能之下”執行。因此,它僅負責像常規後端一樣生成每個張量的輸出值(並且可以,也應該忽略所有高階功能,如 autograd、autocast 等)。
如果任何高階功能在沒有重新排程的情況下實現了給定函式,它將永遠不會到達
Python鍵,因此__torch_dispatch__回撥將永遠不會被觸發。這尤其發生在 CompositeImplicitAutograd 函式中,這些函式在 Autograd 級別進行評估而無需重新排程。這是因為 CompositeImplicitAutograd 函式透過隱式呼叫其他原生操作來指定其 autograd 公式,因此在 Autograd 級別,該函式被分解為其原生操作,然後對這些操作進行評估。在回撥到 Python 和包裝結果時,使用的轉換與常規 PyTorch Python/C++ 繫結相同。特別是,某些物件無法在 Python 中表示,需要特殊處理(例如,未定義張量變成 None)。
我們的原生函式被惰性填充為
torch.ops.{namespace}.{func_name}.{overload_name},作為可呼叫的 Python 物件,以便於從 Python 進行互動。傳遞給__torch_dispatch__的func物件始終是此名稱空間中的一個條目。此名稱空間可用於直接呼叫原生操作並繞過常規 Python API 和繫結程式碼。
類似於 __torch_function__ 能夠攔截 torch 的所有 Python API 和 Tensor 方法,__torch_dispatch__ 能夠攔截所有進入 aten 原生 API 的呼叫。請注意,Tensor 上的所有方法在進入排程器之前都會被轉換為函式呼叫,因此它們將在此處顯示為函式呼叫:torch.add(a, 2) 和 a + 2 將導致完全相同的 aten 呼叫。其中許多函式定義在 native_functions.yaml 中,該檔案指定了這些函式的屬性以及它們的後端實現。它們的實現以及指定的特性隨後透過 codegen 自動註冊。一些更奇特的函式或特性也在 C++ 程式碼庫中的其他地方或使用者定義的 C++ 擴充套件中註冊。
還可以使用 torch.library 新增 新 的原生函式。此 Python 功能允許定義和/或向原生函式新增新實現。這可用於新增丟失的核心、替換現有的核心或定義全新的原生函式。
您可以在 subclass zoo 儲存庫中找到許多基於 __torch_dispatch__ 的子類的示例。
__torch_dispatch__ 呼叫約定#
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
pass
當用戶呼叫具有定義了 __torch_dispatch__ 的輸入的運算子時,該呼叫可能會轉發到 __torch_dispatch__。在呼叫 __torch_dispatch__ 之前,args 和 kwargs 會被規範化,即:
kwargs由運算子模式中的關鍵字引數組成。如果一個關鍵字引數等於其預設值(在模式中),它將不會被傳遞。args由所有其他引數組成,無論它們是如何傳遞給運算子的(位置引數 vs 關鍵字引數)。如果一個引數等於其預設值,並且它是最右邊的位置引數,或者其右邊的所有引數都未傳遞,則它不會被傳遞。
使用模式擴充套件所有 torch API#
不幸的是,有些函式不接受 Tensor 輸入。這意味著上述子類方法不能用於重寫 PyTorch 所有函式行為。此外,如果用例需要攔截每個函式呼叫,更改每個 Tensor 以成為子類可能會過於侵入。
為了解決這個用例,我們引入了“模式”的概念。這些模式用於 __torch_function__ 和 __torch_dispatch__ 重寫,透過分別繼承 torch.overrides.TorchFunctionMode 和 torch.utils._python_dispatch.TorchDispatchMode 來建立,並用作上下文管理器。
為了簡化其與子類和其他模式互動的描述,每當進入模式的上下文管理器時,所有函式都會表現得好像引數列表的開頭有一個額外的 Tensor 引數,該引數具有該模式作為子類。這意味著尤其所有模式處理程式都將在任何子類處理程式之前被呼叫,並且對應於內部上下文管理器的模式將始終首先執行。
同樣重要的是要注意,在給定的模式處理程式內,此特定模式將被停用,並且可以透過執行 with self: 來手動重新啟用它。
下面是一個示例,展示了每種型別的日誌模式
import torch
from torch.overrides import TorchFunctionMode, resolve_name
from torch.utils._python_dispatch import TorchDispatchMode
class FunctionLog(TorchFunctionMode):
def __torch_function__(self, func, types, args, kwargs=None):
print(f"Function Log: {resolve_name(func)}(*{args}, **{kwargs})")
return func(*args, **(kwargs or {}))
class DispatchLog(TorchDispatchMode):
def __torch_dispatch__(self, func, types, args, kwargs=None):
print(f"Dispatch Log: {func}(*{args}, **{kwargs})")
return func(*args, **(kwargs or {}))
def f():
a = torch.rand(10, requires_grad=True)
b = a * 2
b.sum().backward()
print("TorchFunctionMode logging:")
with FunctionLog():
f()
print("TorchDispatchMode logging:")
with DispatchLog():
f()
這會列印以下內容,並附帶額外註釋
TorchFunctionMode logging:
Function Log: torch.rand(*(10,), **{'requires_grad': True})
Function Log: torch.Tensor.mul(*(tensor([0.7164, 0.9897, 0.1745, 0.9336, 0.4287, 0.7989, 0.2169, 0.7474, 0.5624,
0.5970], requires_grad=True), 2), **None)
Function Log: torch.Tensor.sum(*(tensor([1.4328, 1.9794, 0.3490, 1.8671, 0.8573, 1.5977, 0.4338, 1.4948, 1.1249,
1.1939], grad_fn=<MulBackward0>),), **None)
# Note that at the python level, we only see the call to backward but not what happens in the autograd engine.
Function Log: torch.Tensor.backward(*(tensor(12.3307, grad_fn=<SumBackward0>),), **{'gradient': None, 'retain_graph': None, 'create_graph': False, 'inputs': None})
TorchDispatchMode logging:
# Here the requires_grad flag from autograd is removed while default arguments were populated.
Dispatch Log: aten.rand.default(*([10],), **{'device': device(type='cpu'), 'pin_memory': False})
Dispatch Log: aten.mul.Tensor(*(tensor([0.2151, 0.6018, 0.8415, 0.9060, 0.2974, 0.7708, 0.6668, 0.0352, 0.7948,
0.6023], requires_grad=True), 2), **{})
Dispatch Log: aten.sum.default(*(tensor([0.4303, 1.2036, 1.6831, 1.8120, 0.5949, 1.5416, 1.3335, 0.0705, 1.5897,
1.2046], grad_fn=<MulBackward0>),), **{})
# Here we don't see the call to backward itself, but its constituents. Starting here with the factory function that creates the initial gradient.
Dispatch Log: aten.ones_like.default(*(tensor(11.4637, grad_fn=<SumBackward0>),), **{'pin_memory': False, 'memory_format': torch.preserve_format})
# This is the backward of the sum
Dispatch Log: aten.expand.default(*(tensor(1.), [10]), **{})
Dispatch Log: aten.mul.Tensor(*(tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]), 2), **{})
Dispatch Log: aten.detach.default(*(tensor([2., 2., 2., 2., 2., 2., 2., 2., 2., 2.]),), **{})
Dispatch Log: aten.detach.default(*(tensor([2., 2., 2., 2., 2., 2., 2., 2., 2., 2.]),), **{})