快捷方式

ValueEstimatorBase

class torchrl.objectives.value.ValueEstimatorBase(*args, **kwargs)[原始碼]

價值函式模組的抽象父類。

ValueFunctionBase.forward() 方法將計算值(由價值網路給出)和價值估計(由價值估計器給出)以及優勢,並將這些值寫入輸出 tensordict。

如果只需要價值估計,則應改用 ValueFunctionBase.value_estimate()

default_keys

別名:_AcceptedKeys

abstract forward(tensordict: TensorDictBase, *, params: TensorDictBase | None = None, target_params: TensorDictBase | None = None) TensorDictBase[原始碼]

給定 tensordict 中的資料,計算優勢估計。

If a functional module is provided, a nested TensorDict containing the parameters (and if relevant the target parameters) can be passed to the module.

引數:

tensordict (TensorDictBase) – 一個包含資料的 TensorDict(一個觀測鍵,"action"("next", "reward")("next", "done")("next", "terminated"),以及由環境返回的 "next" tensordict 狀態),這些資料用於計算價值估計和 TD 估計。傳遞給此模組的資料應結構化為 [*B, T, *F],其中 B 是批次大小,T 是時間維度,F 是特徵維度。tensordict 的形狀必須為 [*B, T]

關鍵字引數:
  • params (TensorDictBase, optional) – A nested TensorDict containing the params to be passed to the functional value network module.

  • target_params (TensorDictBase, optional) – A nested TensorDict containing the target params to be passed to the functional value network module.

  • device (torch.device, optional) – 緩衝區將被例項化的裝置。預設為 torch.get_default_device()

返回:

An updated TensorDict with an advantage and a value_error keys as defined in the constructor.

set_keys(**kwargs) None[原始碼]

設定 tensordict 鍵名。

value_estimate(tensordict, target_params: TensorDictBase | None = None, next_value: torch.Tensor | None = None, **kwargs)[原始碼]

Gets a value estimate, usually used as a target value for the value network.

如果狀態值鍵存在於 tensordict.get(("next", self.tensor_keys.value)) 下,則將使用此值,而無需呼叫值網路。

引數:
  • tensordict (TensorDictBase) – the tensordict containing the data to read.

  • target_params (TensorDictBase, optional) – A nested TensorDict containing the target params to be passed to the functional value network module.

  • next_value (torch.Tensor, optional) – 下一個狀態或狀態-動作對的值。與 target_params 互斥。

  • **kwargs – the keyword arguments to be passed to the value network.

Returns: a tensor corresponding to the state value.

文件

訪問全面的 PyTorch 開發者文件

檢視文件

教程

為初學者和高階開發者提供深入的教程

檢視教程

資源

查詢開發資源並讓您的問題得到解答

檢視資源