評價此頁

torch.Tensor.register_post_accumulate_grad_hook#

Tensor.register_post_accumulate_grad_hook(hook)[原始碼]#

註冊一個在梯度累積後執行的反向鉤子。

該 hook 將在所有梯度累積到某個張量之後被呼叫,這意味著該張量的 `.grad` 欄位已被更新。post accumulate grad hook **僅**適用於葉子張量(即沒有 `.grad_fn` 欄位的張量)。在非葉子張量上註冊此 hook 會報錯!

鉤子應具有以下簽名

hook(param: Tensor) -> None

請注意,與其他 autograd hook 不同,此 hook 操作的是需要梯度的張量本身,而不是梯度。hook 可以原地修改和訪問其張量引數,包括其 `.grad` 欄位。

此函式返回一個控制代碼,其中包含一個方法 handle.remove(),用於從模組中移除該鉤子。

注意

有關此 hook 何時執行以及其執行順序相對於其他 hook 的更多資訊,請參閱 Backward Hooks execution。由於此 hook 在反向傳播過程中執行,它將在 `no_grad` 模式下執行(除非 `create_graph` 為 `True`)。如果您需要在 hook 中重新啟用 autograd,可以使用 `torch.enable_grad()`。

示例

>>> v = torch.tensor([0., 0., 0.], requires_grad=True)
>>> lr = 0.01
>>> # simulate a simple SGD update
>>> h = v.register_post_accumulate_grad_hook(lambda p: p.add_(p.grad, alpha=-lr))
>>> v.backward(torch.tensor([1., 2., 3.]))
>>> v
tensor([-0.0100, -0.0200, -0.0300], requires_grad=True)

>>> h.remove()  # removes the hook