LSTMModule¶
- class torchrl.modules.LSTMModule(*args, **kwargs)[原始碼]¶
LSTM 模組的嵌入器。
此類為
torch.nn.LSTM添加了以下功能:與 TensorDict 的相容性:隱藏狀態將被重塑以匹配 tensordict 的批處理大小。
可選的多步執行:使用 torch.nn,必須在
torch.nn.LSTMCell和torch.nn.LSTM之間進行選擇,前者相容單步輸入,後者相容多步輸入。此類同時支援這兩種用法。
構造後,該模組不處於迴圈模式,即它將期望單步輸入。
如果處於迴圈模式,則 tensordict 的最後一個維度應標記為步數。tensordict 的維度沒有限制(除非對於時間輸入,維度必須大於一)。
注意
此類可以處理沿時間維度的多個連續軌跡,但是在這種情況下,不應信任最終隱藏值(即,不應將其重新用於連續軌跡)。原因是 LSTM 只返回最後一個隱藏值,對於我們提供的填充輸入,該值可能對應於一個 0 填充的輸入。
- 引數:
input_size – 輸入 x 中預期特徵的數量
hidden_size – 隱藏狀態 h 中的特徵數量
num_layers – 迴圈層數。例如,設定
num_layers=2意味著將兩個 LSTM 堆疊在一起形成一個“堆疊 LSTM”,第二個 LSTM 接收第一個 LSTM 的輸出並計算最終結果。預設為:1bias – 如果為
False,則該層不使用偏置權重 b_ih 和 b_hh。預設值:Truedropout – 如果非零,則在除最後一層外的每個 LSTM 層的輸出上引入“Dropout”層,其 dropout 機率等於
dropout。預設為:0python_based — 如果為
True,將使用完整的 Python 實現的 LSTM 單元。預設值:False
- 關鍵字引數:
in_key (str 或 tuple of str) – 模組的輸入鍵。與
in_keys互斥使用。如果提供,迴圈鍵假定為 [“recurrent_state_h”, “recurrent_state_c”],並且in_key將新增到它們之前。in_keys (list of str) – 一組三個字串,分別對應輸入值、第一個和第二個隱藏鍵。與
in_key互斥。out_key (str 或 tuple of str) – 模組的輸出鍵。與
out_keys互斥使用。如果提供,迴圈鍵假定為 [(“next”, “recurrent_state_h”), (“next”, “recurrent_state_c”)],並且out_key將新增到它們之前。out_keys (list of str) –
一組三個字串,分別對應輸出值、第一個和第二個隱藏鍵。.. 注意
For a better integration with TorchRL's environments, the best naming for the output hidden key is ``("next", <custom_key>)``, such that the hidden values are passed from step to step during a rollout.device (torch.device 或 compatible) – 模組的裝置。
lstm (torch.nn.LSTM, optional) – 要包裝的 LSTM 例項。與其他 nn.LSTM 引數互斥。
default_recurrent_mode (bool, optional) – 如果提供,則為迴圈模式,如果尚未被
set_recurrent_mode上下文管理器/裝飾器覆蓋。預設為False。
- 變數:
recurrent_mode – 返回模組的迴圈模式。
注意
此模組依賴於輸入 TensorDict 中存在的特定
recurrent_state鍵。要生成TensorDictPrimer轉換,該轉換將自動向環境 TensorDict 新增隱藏狀態,請使用方法make_tensordict_primer()。如果此類是更大模組的子模組,則可以對父模組呼叫方法get_primers_from_module(),以自動生成此模組所需的所有 primer 轉換。示例
>>> from torchrl.envs import TransformedEnv, InitTracker >>> from torchrl.envs import GymEnv >>> from torchrl.modules import MLP, LSTMModule >>> from torch import nn >>> from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod >>> env = TransformedEnv(GymEnv("Pendulum-v1"), InitTracker()) >>> lstm_module = LSTMModule( ... input_size=env.observation_spec["observation"].shape[-1], ... hidden_size=64, ... in_keys=["observation", "rs_h", "rs_c"], ... out_keys=["intermediate", ("next", "rs_h"), ("next", "rs_c")]) >>> mlp = MLP(num_cells=[64], out_features=1) >>> policy = Seq(lstm_module, Mod(mlp, in_keys=["intermediate"], out_keys=["action"])) >>> policy(env.reset()) TensorDict( fields={ action: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False), done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), intermediate: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.float32, is_shared=False), is_init: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), next: TensorDict( fields={ rs_c: Tensor(shape=torch.Size([1, 64]), device=cpu, dtype=torch.float32, is_shared=False), rs_h: Tensor(shape=torch.Size([1, 64]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([]), device=cpu, is_shared=False), observation: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)}, terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)}, batch_size=torch.Size([]), device=cpu, is_shared=False)
- forward(tensordict: TensorDictBase = None)[原始碼]¶
定義每次呼叫時執行的計算。
所有子類都應重寫此方法。
注意
儘管前向傳播的實現需要在此函式中定義,但您應該在之後呼叫
Module例項而不是此函式,因為前者會處理註冊的鉤子,而後者則會靜默忽略它們。
- make_cudnn_based() LSTMModule[原始碼]¶
將 LSTM 層轉換為其基於 CuDNN 的版本。
- 返回:
self
- make_python_based() LSTMModule[原始碼]¶
將 LSTM 層轉換為其基於 Python 的版本。
- 返回:
self
- make_tensordict_primer()[原始碼]¶
為環境建立一個 tensordict primer。
一個
TensorDictPrimer物件將確保策略在 Rollout 執行期間瞭解補充輸入和輸出(迴圈狀態)。這樣,資料就可以在程序之間共享並得到妥善處理。使用批處理環境(如
ParallelEnv)時,該轉換可以在單個環境例項級別(即,一組具有內部設定的 tensordict primers 的轉換後的環境)或在批處理環境例項級別(即,一組普通環境的轉換後的批處理)上使用。如果在環境中未包含
TensorDictPrimer,可能會導致行為不當,例如在並行設定中,一個步驟涉及將新的迴圈狀態從"next"複製到根 tensordict,而~torchrl.EnvBase.step_mdp方法將無法執行此操作,因為迴圈狀態未在環境規範中註冊。有關生成給定模組所有 primer 的方法,請參閱
torchrl.modules.utils.get_primers_from_module()。示例
>>> from torchrl.collectors import SyncDataCollector >>> from torchrl.envs import TransformedEnv, InitTracker >>> from torchrl.envs import GymEnv >>> from torchrl.modules import MLP, LSTMModule >>> from torch import nn >>> from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod >>> >>> env = TransformedEnv(GymEnv("Pendulum-v1"), InitTracker()) >>> lstm_module = LSTMModule( ... input_size=env.observation_spec["observation"].shape[-1], ... hidden_size=64, ... in_keys=["observation", "rs_h", "rs_c"], ... out_keys=["intermediate", ("next", "rs_h"), ("next", "rs_c")]) >>> mlp = MLP(num_cells=[64], out_features=1) >>> policy = Seq(lstm_module, Mod(mlp, in_keys=["intermediate"], out_keys=["action"])) >>> policy(env.reset()) >>> env = env.append_transform(lstm_module.make_tensordict_primer()) >>> data_collector = SyncDataCollector( ... env, ... policy, ... frames_per_batch=10 ... ) >>> for data in data_collector: ... print(data) ... break