快捷方式

ValueOperator

class torchrl.modules.tensordict_module.ValueOperator(*args, **kwargs)[原始碼]

RL 中價值函式的一般類。

ValueOperator 類附帶 in_keys 和 out_keys 引數的預設值(分別為 ["observation"] 和 ["state_value"] 或 ["state_action_value"],具體取決於“action”鍵是否包含在 in_keys 列表中)。

引數:
  • module (nn.Module) – 一個 torch.nn.Module,用於將輸入對映到輸出引數空間。

  • in_keys (iterable of str, optional) – 要從輸入 tensordict 中讀取並傳遞給 module 的鍵。如果包含多個元素,則值將按 in_keys 可迭代物件給出的順序傳遞。預設為 ["observation"]

  • out_keys (iterable of str) – 要寫入輸入 tensordict 的鍵。out_keys 的長度必須與嵌入式模組返回的張量數量匹配。“_”作為鍵可以避免將張量寫入輸出。預設為 ["state_value"]["state_action_value"](如果 "action"in_keys 的一部分)。

示例

>>> import torch
>>> from tensordict import TensorDict
>>> from torch import nn
>>> from torchrl.data import Unbounded
>>> from torchrl.modules import ValueOperator
>>> td = TensorDict({"observation": torch.randn(3, 4), "action": torch.randn(3, 2)}, [3,])
>>> class CustomModule(nn.Module):
...     def __init__(self):
...         super().__init__()
...         self.linear = torch.nn.Linear(6, 1)
...     def forward(self, obs, action):
...         return self.linear(torch.cat([obs, action], -1))
>>> module = CustomModule()
>>> td_module = ValueOperator(
...    in_keys=["observation", "action"], module=module
... )
>>> td = td_module(td)
>>> print(td)
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([3, 2]), device=cpu, dtype=torch.float32, is_shared=False),
        observation: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        state_action_value: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)

文件

訪問全面的 PyTorch 開發者文件

檢視文件

教程

為初學者和高階開發者提供深入的教程

檢視教程

資源

查詢開發資源並讓您的問題得到解答

檢視資源