next_state_value¶
- class torchrl.objectives.next_state_value(tensordict: TensorDictBase, operator: TensorDictModule | None = None, next_val_key: str = 'state_action_value', gamma: float = 0.99, pred_next_val: Tensor | None = None, **kwargs)[原始碼]¶
計算下一個狀態的值(不帶梯度),用於計算目標值。
- 目標值通常用於計算距離損失(例如 MSE)。
L = Sum[ (q_value - target_value)^2 ]
- 目標值計算如下:
r + gamma ** n_steps_to_next * value_next_state
如果獎勵是即時獎勵,n_steps_to_next=1。如果使用 N 步獎勵,n_steps_to_next 將從輸入 tensordict 中收集。
- 引數:
tensordict (TensorDictBase) – 包含獎勵和完成鍵(以及 N 步獎勵的 n_steps_to_next 鍵)的 Tensordict。
operator (ProbabilisticTDModule, optional) – 值函式運算元。呼叫時應在輸入 tensordict 中寫入 ‘next_val_key’ 鍵值。如果提供了 pred_next_val,則無需提供此引數。
next_val_key (str, optional) – 將寫入下一個值的鍵。預設為 ‘state_action_value’。
gamma (
float, optional) – 回報折扣率。預設為 0.99。pred_next_val (Tensor, optional) – 如果下一個狀態值不是透過運算元計算的,則可以提供。
- 返回:
一個與輸入 tensordict 大小相同的張量,包含預測的值狀態。