快捷方式

GRUModule

class torchrl.modules.GRUModule(*args, **kwargs)[原始碼]

GRU 模組的嵌入器。

此類為 torch.nn.GRU 添加了以下功能:

  • 與 TensorDict 的相容性:隱藏狀態被重塑以匹配 tensordict 的批次大小。

  • 可選的多步執行:使用 torch.nn,必須在 torch.nn.GRUCelltorch.nn.GRU 之間進行選擇,前者相容單步輸入,後者相容多步。此類支援這兩種用法。

構造後,模組*不*處於迴圈模式,即它將期望單步輸入。

如果處於迴圈模式,預計 tensordict 的最後一個維度標記步數。tensordict 的維度沒有限制(除了對於時間輸入它必須大於一)。

引數:
  • input_size – 輸入 x 中預期特徵的數量

  • hidden_size – 隱藏狀態 h 中的特徵數量

  • num_layers – 迴圈層數。例如,設定 num_layers=2 將意味著堆疊兩個 GRU 來形成一個 堆疊 GRU,第二個 GRU 接收第一個 GRU 的輸出並計算最終結果。預設值:1

  • bias – 如果為 False,則該層不使用偏置權重。預設為 True

  • dropout – 如果非零,則在除最後一層外的每個 GRU 層的輸出上引入 Dropout 層,dropout 機率等於 dropout。預設值:0

  • python_based – 如果為 True,則將使用完整的 Python 實現 GRU 單元。預設為 False

關鍵字引數:
  • in_key (strtuple of str) – 模組的輸入鍵。與 in_keys 互斥使用。如果提供,則迴圈鍵假定為 [“recurrent_state”],並且 in_key 將在此之前追加。

  • in_keys (list of str) – 一對字串,對應於輸入值和迴圈條目。與 in_key 互斥。

  • out_key (strtuple of str) – 模組的輸出鍵。與 out_keys 互斥使用。如果提供,則迴圈鍵假定為 [(“recurrent_state”)],並且 out_key 將在此之前追加。

  • out_keys (list of str) –

    一對字串,對應於輸出值、第一個和第二個隱藏鍵。 .. note

    For a better integration with TorchRL's environments, the best naming
    for the output hidden key is ``("next", <custom_key>)``, such
    that the hidden values are passed from step to step during a rollout.
    

  • device (torch.devicecompatible) – 模組的裝置。

  • gru (torch.nn.GRU, optional) – 要包裝的 GRU 例項。與其他 nn.GRU 引數互斥。

  • default_recurrent_mode (bool, optional) – 如果提供,則為迴圈模式,如果尚未被 set_recurrent_mode 上下文管理器/裝飾器覆蓋。預設為 False

變數:

recurrent_mode – 返回模組的迴圈模式。

set_recurrent_mode()[原始碼]

控制模組是否應以迴圈模式執行。

make_tensordict_primer()[原始碼]

為環境建立 TensorDictPrimer 轉換,使其能夠感知 RNN 的迴圈狀態。

注意

此模組依賴於輸入 TensorDict 中存在的特定 recurrent_state 鍵。要生成一個 TensorDictPrimer 轉換,該轉換將自動向環境 TensorDict 新增隱藏狀態,請使用方法 make_tensordict_primer()。如果此類是更大模組的子模組,則可以透過父模組呼叫方法 get_primers_from_module() 來自動生成此模組(包括此模組)所需的所有 primer 轉換。

示例

>>> from torchrl.envs import TransformedEnv, InitTracker
>>> from torchrl.envs import GymEnv
>>> from torchrl.modules import MLP
>>> from torch import nn
>>> from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod
>>> env = TransformedEnv(GymEnv("Pendulum-v1"), InitTracker())
>>> gru_module = GRUModule(
...     input_size=env.observation_spec["observation"].shape[-1],
...     hidden_size=64,
...     in_keys=["observation", "rs"],
...     out_keys=["intermediate", ("next", "rs")])
>>> mlp = MLP(num_cells=[64], out_features=1)
>>> policy = Seq(gru_module, Mod(mlp, in_keys=["intermediate"], out_keys=["action"]))
>>> policy(env.reset())
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        intermediate: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.float32, is_shared=False),
        is_init: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                rs: Tensor(shape=torch.Size([1, 64]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([]),
            device=cpu,
            is_shared=False),
        observation: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([]),
    device=cpu,
    is_shared=False)
>>> gru_module_training = gru_module.set_recurrent_mode()
>>> policy_training = Seq(gru_module, Mod(mlp, in_keys=["intermediate"], out_keys=["action"]))
>>> traj_td = env.rollout(3) # some random temporal data
>>> traj_td = policy_training(traj_td)
>>> print(traj_td)
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        intermediate: Tensor(shape=torch.Size([3, 64]), device=cpu, dtype=torch.float32, is_shared=False),
        is_init: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                is_init: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                observation: Tensor(shape=torch.Size([3, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                reward: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                rs: Tensor(shape=torch.Size([3, 1, 64]), device=cpu, dtype=torch.float32, is_shared=False),
                terminated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                truncated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([3]),
            device=cpu,
            is_shared=False),
        observation: Tensor(shape=torch.Size([3, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        terminated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        truncated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([3]),
    device=cpu,
    is_shared=False)
forward(tensordict: TensorDictBase = None)[原始碼]

定義每次呼叫時執行的計算。

所有子類都應重寫此方法。

注意

儘管前向傳播的實現需要在此函式中定義,但您應該在之後呼叫 Module 例項而不是此函式,因為前者會處理註冊的鉤子,而後者則會靜默忽略它們。

make_cudnn_based() GRUModule[原始碼]

將 GRU 層轉換為其 CuDNN 版本。

返回:

self

make_python_based() GRUModule[原始碼]

將 GRU 層轉換為其 Python 版本。

返回:

self

make_tensordict_primer()[原始碼]

為環境建立一個 tensordict primer。

一個 TensorDictPrimer 物件將確保策略在 rollouts 執行期間能夠感知輔助輸入和輸出(迴圈狀態)。這樣,資料就可以在程序之間共享並得到妥善處理。

如果不在環境中包含 TensorDictPrimer,可能會導致行為不當,例如在並行設定中,一個步驟涉及將新的迴圈狀態從 "next" 複製到根 tensordict,而 ~torchrl.EnvBase.step_mdp 方法將無法執行此操作,因為迴圈狀態未在環境規範中註冊。

在使用 ParallelEnv 等批處理環境時,該轉換可以在單 env 例項級別(即,一批具有 tensordict primers 的轉換 envs)或在批處理 env 例項級別(即,一批常規 envs 的轉換)使用。

有關生成給定模組所有 primers 的方法,請參閱 torchrl.modules.utils.get_primers_from_module()

示例

>>> from torchrl.collectors import SyncDataCollector
>>> from torchrl.envs import TransformedEnv, InitTracker
>>> from torchrl.envs import GymEnv
>>> from torchrl.modules import MLP, LSTMModule
>>> from torch import nn
>>> from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod
>>>
>>> env = TransformedEnv(GymEnv("Pendulum-v1"), InitTracker())
>>> gru_module = GRUModule(
...     input_size=env.observation_spec["observation"].shape[-1],
...     hidden_size=64,
...     in_keys=["observation", "rs"],
...     out_keys=["intermediate", ("next", "rs")])
>>> mlp = MLP(num_cells=[64], out_features=1)
>>> policy = Seq(gru_module, Mod(mlp, in_keys=["intermediate"], out_keys=["action"]))
>>> policy(env.reset())
>>> env = env.append_transform(gru_module.make_tensordict_primer())
>>> data_collector = SyncDataCollector(
...     env,
...     policy,
...     frames_per_batch=10
... )
>>> for data in data_collector:
...     print(data)
...     break

文件

訪問全面的 PyTorch 開發者文件

檢視文件

教程

為初學者和高階開發者提供深入的教程

檢視教程

資源

查詢開發資源並讓您的問題得到解答

檢視資源