雙重反向傳播與自定義函式#
建立日期:2021 年 8 月 13 日 | 最後更新:2021 年 8 月 13 日 | 最後驗證:2024 年 11 月 5 日
有時需要對反向傳播圖進行兩次反向傳播,例如計算高階梯度。這需要對 autograd 的理解和一些注意事項來支援雙重反向傳播。並非所有能支援單次反向傳播的函式都一定能支援雙重反向傳播。在本教程中,我們將展示如何編寫支援雙重反向傳播的自定義 autograd 函式,並指出一些需要注意的地方。
在編寫用於雙重反向傳播的自定義 autograd 函式時,重要的是瞭解自定義函式中的操作何時會被 autograd 記錄,何時不會,以及最重要的是,save_for_backward 如何與所有這些協同工作。
自定義函式透過兩種方式隱式影響 grad 模式
在前向傳播過程中,autograd 不會記錄在 forward 函式內執行的任何操作的圖。當 forward 完成後,自定義函式的 backward 函式將成為 forward 輸出的 grad_fn。
在反向傳播過程中,如果指定了 create_graph,autograd 會記錄用於計算反向傳播的計算圖。
接下來,為了理解 save_for_backward 如何與上述機制互動,我們可以通過幾個例子來探討。
儲存輸入#
考慮這個簡單的平方函式。它會儲存一個輸入張量用於反向傳播。當 autograd 能夠記錄反向傳播中的操作時,雙重反向傳播會自動工作,因此當我們為反向傳播儲存輸入時,通常無需擔心,因為如果輸入是任何需要 grad 的張量的函式,它應該有一個 grad_fn。這使得梯度能夠正確傳播。
import torch
class Square(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
# Because we are saving one of the inputs use `save_for_backward`
# Save non-tensors and non-inputs/non-outputs directly on ctx
ctx.save_for_backward(x)
return x**2
@staticmethod
def backward(ctx, grad_out):
# A function support double backward automatically if autograd
# is able to record the computations performed in backward
x, = ctx.saved_tensors
return grad_out * 2 * x
# Use double precision because finite differencing method magnifies errors
x = torch.rand(3, 3, requires_grad=True, dtype=torch.double)
torch.autograd.gradcheck(Square.apply, x)
# Use gradcheck to verify second-order derivatives
torch.autograd.gradgradcheck(Square.apply, x)
我們可以使用 torchviz 來視覺化圖,以瞭解為什麼這會起作用。
import torchviz
x = torch.tensor(1., requires_grad=True).clone()
out = Square.apply(x)
grad_x, = torch.autograd.grad(out, x, create_graph=True)
torchviz.make_dot((grad_x, x, out), {"grad_x": grad_x, "x": x, "out": out})
我們可以看到,關於 x 的梯度本身就是 x 的函式(dout/dx = 2x),並且這個函式的圖已經正確構建。
儲存輸出#
與前一個示例略有不同的是,它儲存輸出而不是輸入。其機制是相似的,因為輸出也與 grad_fn 相關聯。
class Exp(torch.autograd.Function):
# Simple case where everything goes well
@staticmethod
def forward(ctx, x):
# This time we save the output
result = torch.exp(x)
# Note that we should use `save_for_backward` here when
# the tensor saved is an ouptut (or an input).
ctx.save_for_backward(result)
return result
@staticmethod
def backward(ctx, grad_out):
result, = ctx.saved_tensors
return result * grad_out
x = torch.tensor(1., requires_grad=True, dtype=torch.double).clone()
# Validate our gradients using gradcheck
torch.autograd.gradcheck(Exp.apply, x)
torch.autograd.gradgradcheck(Exp.apply, x)
使用 torchviz 視覺化圖。
out = Exp.apply(x)
grad_x, = torch.autograd.grad(out, x, create_graph=True)
torchviz.make_dot((grad_x, x, out), {"grad_x": grad_x, "x": x, "out": out})
儲存中間結果#
一個更棘手的情況是我們有時需要儲存一箇中間結果。我們透過實現以下函式來演示這種情況:
由於 sinh 的導數是 cosh,因此在反向傳播計算中重用 forward 中的兩個中間結果 exp(x) 和 exp(-x) 可能很有用。
然而,中間結果不應直接儲存並在反向傳播中使用。因為 forward 是在 no-grad 模式下執行的,如果 forward 過程的中間結果被用於計算 backward 中的梯度,那麼梯度的 backward 圖將不包含計算中間結果的操作。這會導致梯度不正確。
class Sinh(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
expx = torch.exp(x)
expnegx = torch.exp(-x)
ctx.save_for_backward(expx, expnegx)
# In order to be able to save the intermediate results, a trick is to
# include them as our outputs, so that the backward graph is constructed
return (expx - expnegx) / 2, expx, expnegx
@staticmethod
def backward(ctx, grad_out, _grad_out_exp, _grad_out_negexp):
expx, expnegx = ctx.saved_tensors
grad_input = grad_out * (expx + expnegx) / 2
# We cannot skip accumulating these even though we won't use the outputs
# directly. They will be used later in the second backward.
grad_input += _grad_out_exp * expx
grad_input -= _grad_out_negexp * expnegx
return grad_input
def sinh(x):
# Create a wrapper that only returns the first output
return Sinh.apply(x)[0]
x = torch.rand(3, 3, requires_grad=True, dtype=torch.double)
torch.autograd.gradcheck(sinh, x)
torch.autograd.gradgradcheck(sinh, x)
使用 torchviz 視覺化圖。
out = sinh(x)
grad_x, = torch.autograd.grad(out.sum(), x, create_graph=True)
torchviz.make_dot((grad_x, x, out), params={"grad_x": grad_x, "x": x, "out": out})
儲存中間結果:不該怎麼做#
現在我們展示當也沒有將中間結果作為輸出返回時會發生什麼:grad_x 甚至不會有 backward 圖,因為它僅僅是 exp 和 expnegx 的函式,而它們不需要 grad。
class SinhBad(torch.autograd.Function):
# This is an example of what NOT to do!
@staticmethod
def forward(ctx, x):
expx = torch.exp(x)
expnegx = torch.exp(-x)
ctx.expx = expx
ctx.expnegx = expnegx
return (expx - expnegx) / 2
@staticmethod
def backward(ctx, grad_out):
expx = ctx.expx
expnegx = ctx.expnegx
grad_input = grad_out * (expx + expnegx) / 2
return grad_input
使用 torchviz 視覺化圖。請注意,grad_x 不在圖的範圍內!
out = SinhBad.apply(x)
grad_x, = torch.autograd.grad(out.sum(), x, create_graph=True)
torchviz.make_dot((grad_x, x, out), params={"grad_x": grad_x, "x": x, "out": out})
當反向傳播未被跟蹤時#
最後,我們來考慮一個 autograd 完全無法跟蹤函式反向傳播梯度的示例。我們可以設想 cube_backward 是一個可能需要非 PyTorch 庫(如 SciPy 或 NumPy)的函式,或者它被寫成 C++ 擴充套件。這裡演示的解決方法是建立另一個自定義函式 CubeBackward,您也需要手動指定 cube_backward 的 backward!
def cube_forward(x):
return x**3
def cube_backward(grad_out, x):
return grad_out * 3 * x**2
def cube_backward_backward(grad_out, sav_grad_out, x):
return grad_out * sav_grad_out * 6 * x
def cube_backward_backward_grad_out(grad_out, x):
return grad_out * 3 * x**2
class Cube(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
ctx.save_for_backward(x)
return cube_forward(x)
@staticmethod
def backward(ctx, grad_out):
x, = ctx.saved_tensors
return CubeBackward.apply(grad_out, x)
class CubeBackward(torch.autograd.Function):
@staticmethod
def forward(ctx, grad_out, x):
ctx.save_for_backward(x, grad_out)
return cube_backward(grad_out, x)
@staticmethod
def backward(ctx, grad_out):
x, sav_grad_out = ctx.saved_tensors
dx = cube_backward_backward(grad_out, sav_grad_out, x)
dgrad_out = cube_backward_backward_grad_out(grad_out, x)
return dgrad_out, dx
x = torch.tensor(2., requires_grad=True, dtype=torch.double)
torch.autograd.gradcheck(Cube.apply, x)
torch.autograd.gradgradcheck(Cube.apply, x)
使用 torchviz 視覺化圖。
out = Cube.apply(x)
grad_x, = torch.autograd.grad(out, x, create_graph=True)
torchviz.make_dot((grad_x, x, out), params={"grad_x": grad_x, "x": x, "out": out})
總而言之,您的自定義函式是否支援雙重反向傳播,僅取決於反向傳播是否能被 autograd 跟蹤。透過前兩個示例,我們展示了雙重反向傳播開箱即用的情況。透過第三個和第四個示例,我們演示了在通常情況下反向傳播函式無法被跟蹤時,啟用其跟蹤的技術。