torch.autograd.function.FunctionCtx.save_for_backward#
- FunctionCtx.save_for_backward(*tensors)[source]#
為未來的
backward()呼叫儲存給定的張量。save_for_backward最多隻能呼叫一次,可以在setup_context()或forward()方法中呼叫,並且只能使用 tensors。所有打算在 backward 傳播中使用但不是 forward 函式的輸入或輸出的 tensor 都應使用
save_for_backward儲存(而不是直接儲存在ctx上),以防止梯度不正確和記憶體洩漏,並啟用已儲存 tensor hook 的應用。請參閱torch.autograd.graph.saved_tensors_hooks。有關更多詳細資訊,請參閱 擴充套件 torch.autograd。請注意,如果儲存了中間 tensors(即既不是
forward()的輸入也不是輸出的 tensors),那麼自定義 Function 可能不支援二階反向傳播。不支援二階反向傳播的自定義 Function 應該用@once_differentiable裝飾其backward()方法,這樣執行二階反向傳播時會報錯。如果您希望支援二階反向傳播,可以根據反向傳播過程中的輸入重新計算中間值,或者將中間值作為自定義 Function 的輸出返回。有關更多詳細資訊,請參閱 二階反向傳播教程。在
backward()中,可以透過saved_tensors屬性訪問已儲存的 tensors。在將它們返回給使用者之前,會進行檢查以確保它們沒有被用於任何就地修改其內容的運算。引數也可以是
None。這不會執行任何操作。有關如何使用此方法的更多詳細資訊,請參閱 擴充套件 torch.autograd。
示例
>>> class Func(Function): >>> @staticmethod >>> def forward(ctx, x: torch.Tensor, y: torch.Tensor, z: int): >>> w = x * z >>> out = x * y + y * z + w * y >>> ctx.save_for_backward(x, y, w, out) >>> ctx.z = z # z is not a tensor >>> return out >>> >>> @staticmethod >>> @once_differentiable >>> def backward(ctx, grad_out): >>> x, y, w, out = ctx.saved_tensors >>> z = ctx.z >>> gx = grad_out * (y + y * z) >>> gy = grad_out * (x + z + w) >>> gz = None >>> return gx, gy, gz >>> >>> a = torch.tensor(1., requires_grad=True, dtype=torch.double) >>> b = torch.tensor(2., requires_grad=True, dtype=torch.double) >>> c = 4 >>> d = Func.apply(a, b, c)