ActorValueOperator

class torchrl.modules.tensordict_module.ActorValueOperator(*args, **kwargs)[source]

Actor-value operator.

This class wraps together an actor and a value model that share a common observation embedding network.

[Figure: architecture of the actor-value operator, with a common observation embedding feeding the policy and value heads.]

Note

For a similar class that returns an action and a quality value \(Q(s, a)\), see ActorCriticOperator. For a version without a common embedding, refer to ActorCriticWrapper.

To facilitate the workflow, this class comes with get_policy_operator() and get_value_operator() methods, which both return a standalone TDModule with the dedicated functionality.

Parameters:
  • common_operator (TensorDictModule) – a common operator that reads observations and produces a hidden variable.

  • policy_operator (TensorDictModule) – a policy operator that reads the hidden variable and returns an action.

  • value_operator (TensorDictModule) – a value operator that reads the hidden variable and returns a value.

Examples

>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.modules import ProbabilisticActor, SafeModule
>>> from torchrl.modules import ValueOperator, TanhNormal, ActorValueOperator, NormalParamExtractor
>>> from tensordict.nn import TensorDictModule
>>> module_hidden = torch.nn.Linear(4, 4)
>>> td_module_hidden = SafeModule(
...    module=module_hidden,
...    in_keys=["observation"],
...    out_keys=["hidden"],
...    )
>>> module_action = TensorDictModule(
...     torch.nn.Sequential(torch.nn.Linear(4, 8), NormalParamExtractor()),
...     in_keys=["hidden"],
...     out_keys=["loc", "scale"],
...     )
>>> td_module_action = ProbabilisticActor(
...    module=module_action,
...    in_keys=["loc", "scale"],
...    out_keys=["action"],
...    distribution_class=TanhNormal,
...    return_log_prob=True,
...    )
>>> module_value = torch.nn.Linear(4, 1)
>>> td_module_value = ValueOperator(
...    module=module_value,
...    in_keys=["hidden"],
...    )
>>> td_module = ActorValueOperator(td_module_hidden, td_module_action, td_module_value)
>>> td = TensorDict({"observation": torch.randn(3, 4)}, [3,])
>>> td_clone = td_module(td.clone())
>>> print(td_clone)
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        hidden: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        loc: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        observation: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        sample_log_prob: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        scale: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        state_value: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)
>>> td_clone = td_module.get_policy_operator()(td.clone())
>>> print(td_clone)  # no value
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        hidden: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        loc: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        observation: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        sample_log_prob: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        scale: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)
>>> td_clone = td_module.get_value_operator()(td.clone())
>>> print(td_clone)  # no action
TensorDict(
    fields={
        hidden: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        observation: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        state_value: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)
get_policy_head() → SafeSequential[source]

Returns the policy head.

get_policy_operator() → SafeSequential[source]

Returns a standalone policy operator that maps an observation to an action.

get_value_head() → SafeSequential[source]

Returns the value head.

get_value_operator() → SafeSequential[source]

Returns a standalone value network operator that maps an observation to a value estimate.
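
The head methods differ from the operator methods in that they do not include the common embedding module. The following is a minimal sketch building on the example above, assuming the heads read the "hidden" entry written by the common module rather than the raw observation:

>>> policy_head = td_module.get_policy_head()
>>> value_head = td_module.get_value_head()
>>> td_hidden = td_module_hidden(td.clone())  # run the common embedding to produce "hidden"
>>> print(policy_head(td_hidden.clone())["action"].shape)  # policy head: "hidden" -> "action"
torch.Size([3, 4])
>>> print(value_head(td_hidden.clone())["state_value"].shape)  # value head: "hidden" -> "state_value"
torch.Size([3, 1])

By contrast, get_policy_operator() and get_value_operator() chain the common module with the corresponding head, so they can be applied directly to a tensordict that only contains the "observation" entry, as shown in the Examples section.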
