評價此頁

torch.Tensor.register_hook#

Tensor.register_hook(hook)[原始碼]#

註冊一個反向鉤子。

每次計算關於 Tensor 的梯度時都會呼叫該 hook。該 hook 應具有以下簽名:

hook(grad) -> Tensor or None

該 hook 不應修改其引數,但可以選擇性地返回一個新的梯度,該梯度將用於替換 grad

此函式返回一個控制代碼,其中包含一個方法 handle.remove(),用於從模組中移除該鉤子。

注意

有關此 hook 何時執行以及如何與其他 hook 排序執行的資訊,請參閱 反向傳播 hook 的執行

示例

>>> v = torch.tensor([0., 0., 0.], requires_grad=True)
>>> h = v.register_hook(lambda grad: grad * 2)  # double the gradient
>>> v.backward(torch.tensor([1., 2., 3.]))
>>> v.grad

 2
 4
 6
[torch.FloatTensor of size (3,)]

>>> h.remove()  # removes the hook