ValueEstimatorBase
- class torchrl.objectives.value.ValueEstimatorBase(*args, **kwargs)
An abstract parent class for value function modules.
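Its ValueFunctionBase.forward() method computes the value (given by the value network), the value estimate (given by the value estimator) and the advantage, and writes these values into the output tensordict. If only the value estimate is needed, use ValueFunctionBase.value_estimate() instead.
- default_keys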
alias of _AcceptedKeys
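To make the relationship between the value, the value estimate and the advantage concrete, here is a framework-free sketch in plain Python. It is not TorchRL code: nested dicts of lists stand in for a TensorDict, a single trajectory replaces the [*B, T] batch, and a one-step TD(0) target is assumed as the value estimate (ValueEstimatorBase itself is abstract, and its subclasses may use other estimators). The function name td0_forward and the key names "state_value", "value_target" and "advantage" are illustrative assumptions.

```python
from typing import Any, Dict, List

Trajectory = Dict[str, Any]


def td0_forward(td: Trajectory, gamma: float = 0.99) -> Trajectory:
    """Hypothetical TD(0) analogue of forward(): writes the estimate and advantage in place."""
    values: List[float] = td["state_value"]  # V(s_t) from the value network
    nxt = td["next"]
    # Value estimate: r_t + gamma * V(s_{t+1}), with no bootstrap after a terminal step.
    # "terminated" (not "done") decides whether to bootstrap.
    targets = [
        r + gamma * (0.0 if term else 1.0) * v_next
        for r, term, v_next in zip(nxt["reward"], nxt["terminated"], nxt["state_value"])
    ]
    td["value_target"] = targets
    # Advantage: value estimate minus the current value.
    td["advantage"] = [g - v for g, v in zip(targets, values)]
    return td


td = {
    "state_value": [1.0, 2.0, 3.0],
    "next": {
        "reward": [1.0, 1.0, 1.0],
        "terminated": [False, False, True],
        "state_value": [2.0, 3.0, 4.0],
    },
}
out = td0_forward(td, gamma=0.5)
print(out["value_target"])  # [2.0, 2.5, 1.0]
print(out["advantage"])     # [1.0, 0.5, -2.0]
```

Writing the results back into the input container mirrors the documented behaviour of forward(), which returns an updated tensordict rather than separate tensors.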
- abstract forward(tensordict: TensorDictBase, *, params: TensorDictBase | None = None, target_params: TensorDictBase | None = None) → TensorDictBase
Computes the advantage estimate given the data in tensordict.
If a functional module is provided, a nested TensorDict containing the parameters (and if relevant the target parameters) can be passed to the module.
- Parameters:
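tensordict (TensorDictBase) – A TensorDict containing the data (an observation key, "action", ("next", "reward"), ("next", "done"), ("next", "terminated"), and the "next" tensordict state as returned by the environment) needed to compute the value estimates and the TD estimate. The data passed to this module should be structured as [*B, T, *F], where B is the batch size, T is the time dimension and F is the feature dimension(s). The tensordict must have shape [*B, T].
- Keyword Arguments: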
params (TensorDictBase, optional) – A nested TensorDict containing the params to be passed to the functional value network module.
target_params (TensorDictBase, optional) – A nested TensorDict containing the target params to be passed to the functional value network module.
device (torch.device, optional) – The device on which the buffers will be instantiated. Defaults to torch.get_default_device().
- Returns:
An updated TensorDict with the advantage and value_error keys as defined in the constructor.
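The input layout described above can be checked with a small helper. The sketch below is hypothetical and does not use TorchRL or TensorDict: nested Python dicts and lists stand in for a TensorDict, tuples of strings stand in for nested keys, and "observation" is a placeholder for whatever observation key the environment provides. It only verifies that each required entry exists and that its leading dimensions equal the expected [*B, T] batch shape, leaving any trailing feature dimensions free.

```python
from typing import Any, Dict, List, Sequence, Tuple, Union

# A key is either a plain string or a tuple of strings for nested entries.
NestedKey = Union[str, Tuple[str, ...]]

# Entries listed in the forward() docstring. "observation" is a placeholder
# (assumption) for whatever observation key the environment actually uses.
REQUIRED_KEYS: List[NestedKey] = [
    "observation",
    "action",
    ("next", "reward"),
    ("next", "done"),
    ("next", "terminated"),
]


def get_nested(td: Dict[str, Any], key: NestedKey) -> Any:
    """Follow a string or tuple key through nested dicts; raises KeyError if absent."""
    parts = (key,) if isinstance(key, str) else key
    node: Any = td
    for part in parts:
        node = node[part]
    return node


def leading_shape(x: Any) -> Tuple[int, ...]:
    """Shape of a rectangular nested list, read along the first elements."""
    shape: List[int] = []
    while isinstance(x, list):
        shape.append(len(x))
        x = x[0] if x else None
    return tuple(shape)


def check_inputs(td: Dict[str, Any], batch_shape: Sequence[int]) -> List[str]:
    """List layout problems; an empty list means every entry starts with batch_shape."""
    expected = tuple(batch_shape)
    problems: List[str] = []
    for key in REQUIRED_KEYS:
        try:
            leaf = get_nested(td, key)
        except KeyError:
            problems.append(f"missing key {key!r}")
            continue
        if leading_shape(leaf)[: len(expected)] != expected:
            problems.append(f"{key!r} does not start with shape {expected}")
    return problems


B, T = 2, 3
td = {
    "observation": [[[0.0] * 4 for _ in range(T)] for _ in range(B)],  # [B, T, F=4]
    "action": [[0] * T for _ in range(B)],  # [B, T]
    "next": {
        "reward": [[0.0] * T for _ in range(B)],
        "done": [[False] * T for _ in range(B)],
        # ("next", "terminated") is deliberately missing
    },
}
print(check_inputs(td, (B, T)))  # ["missing key ('next', 'terminated')"]
```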
- value_estimate(tensordict, target_params: TensorDictBase | None = None, next_value: torch.Tensor | None = None, **kwargs)
Gets a value estimate, usually used as a target value for the value network.
If the state value key is present under tensordict.get(("next", self.tensor_keys.value)), this value is used without calling the value network.
- Parameters:
tensordict (TensorDictBase) – the tensordict containing the data to read.
target_params (TensorDictBase, optional) – A nested TensorDict containing the target params to be passed to the functional value network module.
next_value (torch.Tensor, optional) – The value of the next state or state-action pair. Mutually exclusive with target_params.
**kwargs – the keyword arguments to be passed to the value network.
- Returns:
A tensor corresponding to the state value.
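The lookup order described for value_estimate() can be illustrated with a plain-Python sketch that assumes a one-step TD(0) target, r + gamma * (1 - terminated) * V(next). This is not TorchRL's implementation: the function name, the value_fn callback standing in for the value network, the default key "state_value", and the choice to let an explicit next_value take precedence over a stored value are all assumptions made for this illustration.

```python
from typing import Any, Callable, Dict, List, Optional

# Hypothetical stand-in for the value network: (next_td, target_params) -> values.
ValueFn = Callable[[Dict[str, Any], Optional[Dict[str, Any]]], List[float]]


def value_estimate(
    td: Dict[str, Any],
    value_fn: ValueFn,
    target_params: Optional[Dict[str, Any]] = None,
    next_value: Optional[List[float]] = None,
    gamma: float = 0.99,
    value_key: str = "state_value",
) -> List[float]:
    """Hypothetical TD(0) value target following the documented lookup order."""
    if next_value is not None and target_params is not None:
        raise ValueError("next_value and target_params are mutually exclusive")
    nxt = td["next"]
    if next_value is None:  # assumption: an explicit next_value wins over a stored one
        if value_key in nxt:
            # Value already stored under ("next", value_key): reuse it, skip the network.
            next_value = nxt[value_key]
        else:
            # Otherwise query the value network on the next states.
            next_value = value_fn(nxt, target_params)
    return [
        r + gamma * (0.0 if term else 1.0) * v
        for r, term, v in zip(nxt["reward"], nxt["terminated"], next_value)
    ]


calls: List[Any] = []


def fake_value_net(next_td: Dict[str, Any], params: Optional[Dict[str, Any]]) -> List[float]:
    calls.append(params)  # record that the network was actually invoked
    return [10.0] * len(next_td["reward"])


batch = {"next": {"reward": [1.0, 1.0], "terminated": [False, True]}}
print(value_estimate(batch, fake_value_net, gamma=0.5))  # [6.0, 1.0]
batch["next"]["state_value"] = [2.0, 2.0]
print(value_estimate(batch, fake_value_net, gamma=0.5))  # [2.0, 1.0]
print(len(calls))  # 1
```

The second call never reaches fake_value_net because the next-state value is already stored, which is the shortcut the docstring describes.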