torch.Tensor.register_hook#
- Tensor.register_hook(hook)[原始碼]#
註冊一個反向鉤子。
每次計算關於 Tensor 的梯度時都會呼叫該 hook。該 hook 應具有以下簽名:
hook(grad) -> Tensor or None
該 hook 不應修改其引數,但可以選擇性地返回一個新的梯度,該梯度將用於替換
grad。此函式返回一個控制代碼,其中包含一個方法
handle.remove(),用於從模組中移除該鉤子。注意
有關此 hook 何時執行以及如何與其他 hook 排序執行的資訊,請參閱 反向傳播 hook 的執行。
示例
>>> v = torch.tensor([0., 0., 0.], requires_grad=True) >>> h = v.register_hook(lambda grad: grad * 2) # double the gradient >>> v.backward(torch.tensor([1., 2., 3.])) >>> v.grad 2 4 6 [torch.FloatTensor of size (3,)] >>> h.remove() # removes the hook