快捷方式

BurnInTransform

class torchrl.envs.transforms.BurnInTransform(modules: Sequence[TensorDictModuleBase], burn_in: int, out_keys: Sequence[NestedKey] | None = None)[source]

用於部分燒入資料序列的轉換。

此轉換器對於在無法獲得迴圈狀態時獲取最新的迴圈狀態很有用。它沿著時間維度燒入從取樣的順序資料切片中獲得的若干步,並返回具有燒入資料作為其初始時間步的剩餘資料序列。此轉換器旨在用作回放緩衝區轉換器,而不是環境轉換器。

引數:
  • modules (TensorDictModule 序列) – 用於燒入資料序列的模組列表。

  • burn_in (int) – 要燒入的時間步數。

  • out_keys (NestedKey 序列, 可選) – 目標鍵。預設為

  • ` (所有指向下一個時間步的模組的 out_keys例如“hidden”,如果) –

  • (“next”

  • module) (“hidden”是其中一個模組的 out_keys) –

注意

此轉換器期望輸入的 TensorDict 的最後一個維度是時間維度。它還假定所有提供的模組都可以處理順序資料。

示例

>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.envs.transforms import BurnInTransform
>>> from torchrl.modules import GRUModule
>>> gru_module = GRUModule(
...     input_size=10,
...     hidden_size=10,
...     in_keys=["observation", "hidden"],
...     out_keys=["intermediate", ("next", "hidden")],
...     default_recurrent_mode=True,
... )
>>> burn_in_transform = BurnInTransform(
...     modules=[gru_module],
...     burn_in=5,
... )
>>> td = TensorDict({
...     "observation": torch.randn(2, 10, 10),
...      "hidden": torch.randn(2, 10, gru_module.gru.num_layers, 10),
...      "is_init": torch.zeros(2, 10, 1),
... }, batch_size=[2, 10])
>>> td = burn_in_transform(td)
>>> td.shape
torch.Size([2, 5])
>>> td.get("hidden").abs().sum()
tensor(86.3008)
>>> from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
>>> buffer = TensorDictReplayBuffer(
...     storage=LazyMemmapStorage(2),
...     batch_size=1,
... )
>>> buffer.append_transform(burn_in_transform)
>>> td = TensorDict({
...     "observation": torch.randn(2, 10, 10),
...      "hidden": torch.randn(2, 10, gru_module.gru.num_layers, 10),
...      "is_init": torch.zeros(2, 10, 1),
... }, batch_size=[2, 10])
>>> buffer.extend(td)
>>> td = buffer.sample(1)
>>> td.shape
torch.Size([1, 5])
>>> td.get("hidden").abs().sum()
tensor(37.0344)
forward(tensordict: TensorDictBase) TensorDictBase[source]

讀取輸入 tensordict,並對選定的鍵應用轉換。

預設情況下,此方法

  • 直接呼叫 _apply_transform()

  • 不呼叫 _step()_call()

此方法不會在任何時候在 env.step 中呼叫。但是,它會在 sample() 中呼叫。

注意

forward 也可以使用 dispatch 將引數名稱轉換為鍵,並使用常規關鍵字引數。

示例

>>> class TransformThatMeasuresBytes(Transform):
...     '''Measures the number of bytes in the tensordict, and writes it under `"bytes"`.'''
...     def __init__(self):
...         super().__init__(in_keys=[], out_keys=["bytes"])
...
...     def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
...         bytes_in_td = tensordict.bytes()
...         tensordict["bytes"] = bytes
...         return tensordict
>>> t = TransformThatMeasuresBytes()
>>> env = env.append_transform(t) # works within envs
>>> t(TensorDict(a=0))  # Works offline too.

文件

訪問全面的 PyTorch 開發者文件

檢視文件

教程

為初學者和高階開發者提供深入的教程

檢視教程

資源

查詢開發資源並讓您的問題得到解答

檢視資源