評價此頁

torch.func.grad_and_value#

torch.func.grad_and_value(func, argnums=0, has_aux=False)[原始碼]#

返回一個用於計算梯度和原始值(或前向計算)的元組的函式。

引數
  • func (Callable) – 一個接受一個或多個引數的 Python 函式。必須返回一個單元素 Tensor。如果指定了 has_aux 等於 True,則函式可以返回一個單元素 Tensor 和其他輔助物件的元組:(output, aux)

  • argnums (intTuple[int]) – 指定需要計算梯度的引數。 argnums 可以是單個整數或整數元組。預設為:0。

  • has_aux (bool) – 標誌,指示 func 返回一個張量和其他輔助物件: (output, aux)。預設為:False。

返回

用於計算其輸入和前向計算的梯度元組的函式。預設情況下,函式輸出是相對於第一個引數的梯度張量和原始計算的元組。如果指定 has_auxTrue,則返回梯度元組和帶有輸出輔助物件的前向計算元組。如果 argnums 是整數元組,則返回一個元組,其中包含相對於每個 argnums 值的輸出梯度元組以及前向計算。

返回型別

Callable

請參閱 grad() 檢視示例。