評價此頁

torch.autograd.function.FunctionCtx.save_for_backward#

FunctionCtx.save_for_backward(*tensors)[source]#

為未來的 backward() 呼叫儲存給定的張量。

save_for_backward 最多隻能呼叫一次,可以在 setup_context()forward() 方法中呼叫,並且只能使用 tensors。

所有打算在 backward 傳播中使用但不是 forward 函式的輸入或輸出的 tensor 都應使用 save_for_backward 儲存(而不是直接儲存在 ctx 上),以防止梯度不正確和記憶體洩漏,並啟用已儲存 tensor hook 的應用。請參閱 torch.autograd.graph.saved_tensors_hooks。有關更多詳細資訊,請參閱 擴充套件 torch.autograd

請注意,如果儲存了中間 tensors(即既不是 forward() 的輸入也不是輸出的 tensors),那麼自定義 Function 可能不支援二階反向傳播。不支援二階反向傳播的自定義 Function 應該用 @once_differentiable 裝飾其 backward() 方法,這樣執行二階反向傳播時會報錯。如果您希望支援二階反向傳播,可以根據反向傳播過程中的輸入重新計算中間值,或者將中間值作為自定義 Function 的輸出返回。有關更多詳細資訊,請參閱 二階反向傳播教程

backward() 中,可以透過 saved_tensors 屬性訪問已儲存的 tensors。在將它們返回給使用者之前,會進行檢查以確保它們沒有被用於任何就地修改其內容的運算。

引數也可以是 None。這不會執行任何操作。

有關如何使用此方法的更多詳細資訊,請參閱 擴充套件 torch.autograd

示例

>>> class Func(Function):
>>>     @staticmethod
>>>     def forward(ctx, x: torch.Tensor, y: torch.Tensor, z: int):
>>>         w = x * z
>>>         out = x * y + y * z + w * y
>>>         ctx.save_for_backward(x, y, w, out)
>>>         ctx.z = z  # z is not a tensor
>>>         return out
>>>
>>>     @staticmethod
>>>     @once_differentiable
>>>     def backward(ctx, grad_out):
>>>         x, y, w, out = ctx.saved_tensors
>>>         z = ctx.z
>>>         gx = grad_out * (y + y * z)
>>>         gy = grad_out * (x + z + w)
>>>         gz = None
>>>         return gx, gy, gz
>>>
>>> a = torch.tensor(1., requires_grad=True, dtype=torch.double)
>>> b = torch.tensor(2., requires_grad=True, dtype=torch.double)
>>> c = 4
>>> d = Func.apply(a, b, c)