將 torch.func 與 autograd.Function 擴充套件#
創建於: 2023年01月03日 | 最後更新於: 2023年09月14日
您想將 torch.autograd.Function 與 torch.func 變換(如 torch.vmap()、torch.func.grad() 等)一起使用。
有兩種主要用例:
您希望呼叫不包含 PyTorch 操作的程式碼,並使其能夠與函式變換一起工作。也就是說,
torch.autograd.Function的 forward/backward/etc 呼叫會指向其他系統(如 C++、CUDA、NumPy)的函式。您希望指定自定義梯度規則,類似於 JAX 的 custom_vjp/custom_jvp。
PyTorch 將這兩個概念結合到了 torch.autograd.Function 中。
基本用法#
本指南假設您已熟悉 擴充套件 torch.autograd,其中解釋瞭如何使用 torch.autograd.Function。
torch.autograd.Function 可以有一個接受 ctx 物件的 forward() 方法,或者可以有單獨的 forward() 方法(不接受 ctx)和一個 setup_context() 靜態方法,後者會修改 ctx 物件。
只有後者才支援函式變換。
forward()是執行操作的程式碼,它不應接受ctx物件。setup_context(ctx, inputs, output)是您可以呼叫ctx方法的程式碼。在這裡,您應該儲存用於反向傳播的 Tensor(透過呼叫ctx.save_for_backward(*tensors)),或者儲存非 Tensor 物件(透過將它們賦值給ctx物件)。
由於 setup_context() 只接受 inputs 和 output,因此只能儲存輸入或輸出中的物件(如 Tensor)或從它們派生的數量(如 Tensor.shape)。如果您希望為反向傳播儲存 Function.forward() 的非輸入中間啟用,則需要將其作為 forward() 的輸出返回,以便傳遞給 setup_context()。
根據變換的不同:
為了支援反向模式 AD(
torch.func.grad()、torch.func.vjp()),torch.autograd.Function需要一個backward()靜態方法。為了支援
torch.vmap(),torch.autograd.Function需要一個vmap()靜態方法。為了支援
torch.func.jvp(),torch.autograd.Function需要一個jvp()靜態方法。為了支援變換的組合(如
torch.func.jacrev()、torch.func.jacfwd()、torch.func.hessian())——您可能需要以上多種方法。
為了使 torch.autograd.Function 能夠與函式變換任意組合,我們建議除 forward() 和 setup_context() 之外的所有其他靜態方法都必須是可變換的:也就是說,它們必須只由 PyTorch 操作組成,或者呼叫其他 torch.autograd.Function(這些 torch.autograd.Function 可能呼叫 C++/CUDA/etc)。
下面我們來看一些常見用例的示例。
示例 1:autograd.Function 呼叫另一個系統#
一種常見情況是 torch.autograd.Function 同時在 forward() 和 backward() 中呼叫另一個系統(如 C++、CUDA、NumPy、Triton)。
import torch
import numpy as np
def to_numpy(tensor):
return tensor.cpu().numpy()
class NumpySort(torch.autograd.Function):
# Note that forward does not take ctx
@staticmethod
def forward(x, dim):
device = x.device
x = to_numpy(x)
ind = np.argsort(x, axis=dim)
ind_inv = np.argsort(ind, axis=dim)
result = np.take_along_axis(x, ind, axis=dim)
# Any intermediates to be saved in backward must be returned as
# outputs.
return (
# The desired output
torch.tensor(result, device=device),
# intermediate to save for backward
torch.tensor(ind, device=device),
# intermediate to save for backward
torch.tensor(ind_inv, device=device),
)
# setup_context is responsible for calling methods and/or assigning to
# the ctx object. Please do not do additional compute (e.g. add
# Tensors together) in setup_context.
@staticmethod
def setup_context(ctx, inputs, output):
x, dim = inputs
# Note that output is whatever you returned from forward.
# If you returned multiple values, then output is a Tuple of multiple values.
# If you returned a single Tensor, then output is a Tensor.
# If you returned a Tuple with a single Tensor, then output is a
# Tuple with a single Tensor.
_, ind, ind_inv = output
ctx.mark_non_differentiable(ind, ind_inv)
# Tensors must be saved via ctx.save_for_backward. Please do not
# assign them directly onto the ctx object.
ctx.save_for_backward(ind, ind_inv)
# Non-tensors may be saved by assigning them as attributes on the ctx object.
ctx.dim = dim
@staticmethod
def backward(ctx, grad_output, _0, _1):
# For the autograd.Function to be arbitrarily composable with function
# transforms, all staticmethod other than forward and setup_context
# must be implemented in a "transformable" way; that is, they must
# only consist of PyTorch operations or autograd.Function.
#
# For example, this allows us to do double backwards and/or compute
# second order gradients.
#
# We've written the backward pass of NumpySort in terms of another
# autograd.Function, NumpyTake.
ind, ind_inv = ctx.saved_tensors
return NumpyTake.apply(grad_output, ind_inv, ind, ctx.dim), None
class NumpyTake(torch.autograd.Function):
@staticmethod
def forward(x, ind, ind_inv, dim):
device = x.device
x = to_numpy(x)
ind = to_numpy(ind)
return torch.tensor(np.take_along_axis(x, ind, dim), device=device)
@staticmethod
def setup_context(ctx, inputs, output):
x, ind, ind_inv, dim = inputs
ctx.save_for_backward(ind, ind_inv)
ctx.dim = dim
@staticmethod
def backward(ctx, grad_output):
ind, ind_inv = ctx.saved_tensors
result = NumpyTake.apply(grad_output, ind_inv, ind, ctx.dim)
return result, None, None, None
現在,為了更方便地使用 NumpySort(隱藏我們作為輸出返回的中間變數,並允許預設的 args 和 kwargs),我們建立了一個新函式來呼叫它。
def numpy_sort(x, dim=-1):
result, _, _ = NumpySort.apply(x, dim)
return result
這是一個健全性檢查。
x = torch.randn(2, 3)
grad_x = torch.func.grad(lambda x: numpy_sort(x).sum())(x)
assert torch.allclose(grad_x, torch.ones_like(x))
示例 2:autograd.Function 指定自定義梯度規則#
另一種常見情況是使用 PyTorch 操作實現的 torch.autograd.Function。PyTorch 能夠自動計算 PyTorch 操作的梯度,但我們可能希望自定義梯度的計算方式。我們可能希望自定義 backward 的原因包括:
提高數值穩定性
改變 backward 的效能特徵
改變邊緣情況的處理方式(例如,NaN、Inf)
修改梯度(例如,梯度裁剪)
下面是一個函式 y = x ** 3 的 torch.autograd.Function 示例,其中我們改變了效能特徵(一些通常在 backward 傳遞中進行的計算,計算 dx,現在在 forward 傳遞中完成)。
class MyCube(torch.autograd.Function):
@staticmethod
def forward(x):
result = x ** 3
# In regular PyTorch, if we had just run y = x ** 3, then the backward
# pass computes dx = 3 * x ** 2. In this autograd.Function, we've done
# that computation here in the forward pass instead.
dx = 3 * x ** 2
return result, dx
@staticmethod
def setup_context(ctx, inputs, output):
x, = inputs
result, dx = output
ctx.save_for_backward(x, dx)
@staticmethod
def backward(ctx, grad_output, grad_dx):
x, dx = ctx.saved_tensors
# In order for the autograd.Function to work with higher-order
# gradients, we must add the gradient contribution of `dx`.
result = grad_output * dx + grad_dx * 6 * x
return result
現在,為了更方便地使用 NumpySort(並隱藏我們作為輸出返回的中間變數),我們建立了一個新函式來呼叫它。
def my_cube(x):
result, _ = MyCube.apply(x)
return result
這是一個計算二階梯度的健全性檢查。
x = torch.randn([])
ggx = torch.func.grad(torch.func.grad(my_cube))(x)
assert torch.allclose(ggx, 6 * x)
限制和注意事項#
警告
請仔細閱讀 torch.autograd.Function 與 torch.func 變換結合使用的限制。我們無法優雅地捕獲許多這種情況,它們會導致未定義的行為。
請不要將正在被變換的 Tensor、requires_grad=True 的 Tensor 或雙重 Tensor 捕獲到 torch.autograd.Function 的方法中。完全安全的方法是確保 torch.autograd.Function 的任何方法內部使用的唯一 Tensor 必須直接作為輸入(或透過 ctx 物件)傳遞,而不是來自 torch.autograd.Function 外部。
torch.autograd.Function 不處理 PyTree 中的 Tensor(可能包含 Tensor 的任意巢狀 Python 資料結構)。為了讓這些 Tensor 被 autograd 跟蹤,它們必須直接作為引數傳遞給 torch.autograd.Function。這與 jax.{custom_vjp, custom_jvp} 不同,後者接受 PyTree。
請僅使用 save_for_backward() 或 save_for_forward() 來儲存 Tensor。請不要直接將 Tensor 或 Tensor 集合賦值給 ctx 物件——這些 Tensor 將不會被跟蹤。
torch.vmap() 支援#
要將 torch.autograd.Function 與 torch.vmap() 一起使用,您必須執行以下操作之一:
提供一個
vmap()靜態方法,告訴我們torch.autograd.Function在torch.vmap()下的行為。透過設定
generate_vmap_rule=True來請求我們自動生成它。
自動生成 vmap 規則#
如果您的 torch.autograd.Function 滿足以下附加約束,我們就能為其生成 vmap 規則。如果它不滿足約束,或者您希望在 vmap 下具有自定義行為,請手動定義 vmap 靜態方法(請參閱下一節)。
警告
我們無法輕鬆檢查以下約束並優雅地報錯。違反約束可能導致未定義的行為。
torch.autograd.Function的forward()、backward()(如果存在)和jvp()(如果存在)靜態方法必須可以透過torch.vmap()進行變換。也就是說,它們必須只由 PyTorch 操作組成(而不是例如 NumPy 或自定義 CUDA 核心)。
示例
class MyCube(torch.autograd.Function):
# Set generate_vmap_rule to True to ask PyTorch to automatically generate
# a vmap rule.
generate_vmap_rule = True
@staticmethod
def forward(x):
result = x ** 3
dx = 3 * x ** 2
return result, dx
@staticmethod
def setup_context(ctx, inputs, output):
x, = inputs
result, dx = output
ctx.save_for_backward(x, dx)
@staticmethod
def backward(ctx, grad_output, grad_dx):
x, dx = ctx.saved_tensors
result = grad_output * dx + grad_dx * 6 * x
return result
def my_cube(x):
result, dx = MyCube.apply(x)
return result
x = torch.randn(3)
result = torch.vmap(my_cube)(x)
assert torch.allclose(result, x ** 3)
定義 vmap 靜態方法#
如果您的 torch.autograd.Function 呼叫了另一個系統(如 NumPy、C++、CUDA、Triton),那麼為了使其能夠與 torch.vmap() 或使用它的變換一起工作,您需要手動定義一個 vmap() 靜態方法。
根據您想要使用的變換以及您的用例,您可能不需要在所有 torch.autograd.Function 中新增 vmap() 靜態方法。
例如,
torch.func.jacrev()在 backward 傳遞上執行vmap()。因此,如果您只對使用torch.func.jacrev()感興趣,那麼只有backward()靜態方法需要是可 vmap 的。
我們建議確保您的所有 torch.autograd.Function 都支援 torch.vmap(),特別是如果您正在編寫第三方庫,並希望您的 torch.autograd.Function 能夠與所有 torch.func() 變換的組合一起工作。
概念上,vmap 靜態方法負責定義 forward() 在 torch.vmap() 下的行為。也就是說,它定義瞭如何變換 forward() 以在具有附加維度(被 vmap 的維度)的輸入上執行。這類似於 torch.vmap() 如何在 PyTorch 操作上實現:對於每個操作,我們定義一個 vmap 規則(有時也稱為“批處理規則”)。
以下是如何定義 vmap() 靜態方法:
簽名是
vmap(info, in_dims: Tuple[Optional[int]], *args),其中*args與forward()的 args 相同。vmap 靜態方法負責定義
forward()在torch.vmap()下的行為。也就是說,給定帶有附加維度(由in_dims指定)的輸入,我們如何計算forward()的批處理版本?對於
args中的每個引數,in_dims都有一個對應的Optional[int]。如果引數不是 Tensor,或者引數沒有被 vmap,則為None;否則,它是一個整數,指定 Tensor 被 vmap 的哪個維度。info是一個包含附加元資料的集合,這些元資料可能很有用:info.batch_size指定了被 vmap 的維度的大小,而info.randomness是傳遞給torch.vmap()的randomness選項。vmap 靜態方法的返回值是一個元組
(output, out_dims)。與in_dims類似,out_dims的結構應與output相同,並且每個輸出都包含一個out_dim,指定輸出是否具有 vmap 的維度以及在該維度中的索引。
示例
def to_numpy(tensor):
return tensor.cpu().numpy()
class NumpySort(torch.autograd.Function):
@staticmethod
def forward(x, dim):
device = x.device
x = to_numpy(x)
ind = np.argsort(x, axis=dim)
ind_inv = np.argsort(ind, axis=dim)
result = np.take_along_axis(x, ind, axis=dim)
return (
torch.tensor(result, device=device),
torch.tensor(ind, device=device),
torch.tensor(ind_inv, device=device),
)
@staticmethod
def setup_context(ctx, inputs, output):
x, dim = inputs
_, ind, ind_inv = output
ctx.mark_non_differentiable(ind, ind_inv)
ctx.save_for_backward(ind, ind_inv)
ctx.dim = dim
@staticmethod
def backward(ctx, grad_output, _0, _1):
ind, ind_inv = ctx.saved_tensors
return NumpyTake.apply(grad_output, ind_inv, ind, ctx.dim), None
# The signature of the vmap staticmethod is:
# vmap(info, in_dims: Tuple[Optional[int]], *args)
# where *args is the same as the arguments to `forward`.
@staticmethod
def vmap(info, in_dims, x, dim):
# For every input (x and dim), in_dims stores an Optional[int]
# that is:
# - None if the input is not being vmapped over or if the input
# is not a Tensor
# - an integer if the input is being vmapped over that represents
# the index of the dimension being vmapped over.
x_bdim, _ = in_dims
# A "vmap rule" is the logic of how to perform the operation given
# inputs with one additional dimension. In NumpySort, x has an
# additional dimension (x_bdim). The vmap rule is simply
# to call NumpySort again but pass it a different `dim`.
x = x.movedim(x_bdim, 0)
# Handle negative dims correctly
dim = dim if dim >= 0 else dim + x.dim() - 1
result = NumpySort.apply(x, dim + 1)
# The vmap rule must return a tuple of two things
# 1. the output. Should be the same amount of things
# as returned by the forward().
# 2. one Optional[int] for each output specifying if each output
# is being vmapped over, and if so, the index of the
# dimension being vmapped over.
#
# NumpySort.forward returns a Tuple of 3 Tensors. Since we moved the
# dimension being vmapped over to the front of `x`, that appears at
# dimension 0 of all outputs.
# The return is (output, out_dims) -- output is a tuple of 3 Tensors
# and out_dims is a Tuple of 3 Optional[int]
return NumpySort.apply(x, dim + 1), (0, 0, 0)
class NumpyTake(torch.autograd.Function):
@staticmethod
def forward(x, ind, ind_inv, dim):
device = x.device
x = to_numpy(x)
ind = to_numpy(ind)
return torch.tensor(np.take_along_axis(x, ind, dim), device=device)
@staticmethod
def setup_context(ctx, inputs, output):
x, ind, ind_inv, dim = inputs
ctx.save_for_backward(ind, ind_inv)
ctx.dim = dim
@staticmethod
def backward(ctx, grad_output):
ind, ind_inv = ctx.saved_tensors
result = NumpyTake.apply(grad_output, ind_inv, ind, ctx.dim)
return result, None, None, None
@staticmethod
def vmap(info, in_dims, x, ind, ind_inv, dim):
x_bdim, ind_bdim, ind_inv_bdim, _ = in_dims
# The strategy is: expand {x, ind, ind_inv} to all have the dimension
# being vmapped over.
# Then, call back into NumpyTake(expanded_x, expanded_ind, expanded_ind_inv, new_dim).
# Handle negative dims by wrapping them to be positive
logical_dim = x.dim() if x_bdim is None else x_bdim - 1
dim = dim if dim >= 0 else dim + logical_dim
def maybe_expand_bdim_at_front(x, x_bdim):
if x_bdim is None:
return x.expand(info.batch_size, *x.shape)
return x.movedim(x_bdim, 0)
# If the Tensor doesn't have the dimension being vmapped over,
# expand it out. Otherwise, move it to the front of the Tensor
x = maybe_expand_bdim_at_front(x, x_bdim)
ind = maybe_expand_bdim_at_front(ind, ind_bdim)
ind_inv = maybe_expand_bdim_at_front(ind_inv, ind_inv_bdim)
# The return is a tuple (output, out_dims). Since output is a Tensor,
# then out_dims is an Optional[int] (instead of being a Tuple).
return NumpyTake.apply(x, ind, ind_inv, dim + 1), 0
def numpy_sort(x, dim=-1):
result, _, _ = NumpySort.apply(x, dim)
return result
x = torch.randn(2, 3)
result = torch.vmap(numpy_sort)(x)
assert torch.allclose(result, numpy_sort(result, 1))
注意
vmap 靜態方法應旨在保留整個 Function 的語義。也就是說,(虛擬碼)grad(vmap(MyFunc)) 應該可以被 grad(map(MyFunc)) 替換。
如果您的 autograd.Function 在 backward 傳遞中具有任何自定義行為,請牢記這一點。
torch.func.jvp() 支援#
為了支援前向模式 AD,torch.autograd.Function 必須有一個 jvp() 靜態方法。詳情請參閱 前向 AD autograd.Function。