BurnInTransform
- class torchrl.envs.transforms.BurnInTransform(modules: Sequence[TensorDictModuleBase], burn_in: int, out_keys: Sequence[NestedKey] | None = None)
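Transform to partially burn in data sequences.
This transform is useful for obtaining up-to-date recurrent states when they are not available. It burns in a number of steps along the time dimension of sampled sequential data slices and returns the remaining data sequence, whose initial time step carries the burnt-in data. This transform is intended to be used as a replay buffer transform, not as an environment transform.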
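The burn-in idea can be sketched without torchrl. The snippet below is a minimal illustration under stated assumptions, not torchrl's implementation: it uses a hypothetical scalar recurrent cell `toy_cell` (a leaky accumulator standing in for a GRU) and a single sequence stored as a Python list. It unrolls the cell over the first `burn_in` steps, drops those steps, and returns the remaining steps together with the hidden state reached at the end of the burn-in.

```python
from typing import Callable, List, Tuple


def toy_cell(x: float, h: float) -> float:
    """Hypothetical recurrent cell: a leaky accumulator standing in for a GRU."""
    return 0.5 * h + x


def burn_in_sequence(
    observations: List[float],
    initial_hidden: float,
    burn_in: int,
    cell: Callable[[float, float], float] = toy_cell,
) -> Tuple[List[float], float]:
    """Run `cell` over the first `burn_in` steps; return (remaining steps, refreshed state)."""
    if not 0 < burn_in < len(observations):
        raise ValueError("burn_in must burn at least one step and leave at least one")
    hidden = initial_hidden  # possibly stale state stored alongside the data
    for x in observations[:burn_in]:
        hidden = cell(x, hidden)  # unroll the recurrence over the burn-in steps
    # The refreshed state becomes the initial recurrent state of what is left.
    return observations[burn_in:], hidden


remaining, hidden = burn_in_sequence([1.0, 2.0, 3.0, 4.0], initial_hidden=0.0, burn_in=2)
print(remaining, hidden)  # [3.0, 4.0] 2.5
```

The stored `initial_hidden` is only used as a starting point; after the burn-in the remaining steps start from a state computed by the current cell, which is what makes the state "up to date".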
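- Parameters:
modules (sequence of TensorDictModule) – A list of modules used to burn in data sequences.
burn_in (int) – The number of time steps to burn in.
out_keys (sequence of NestedKey, optional) – Destination keys. Defaults to all the modules' out_keys that point to the next time step (e.g. "hidden" if ("next", "hidden") is one of a module's out_keys).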
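The default for `out_keys` can be read as a simple rule over the modules' own `out_keys`. The helper below is a hypothetical sketch of that rule (`default_burn_in_out_keys` is not a torchrl function): nested keys are modelled as strings or tuples of strings, and every key of the form `("next", ...)` contributes the same key with the `"next"` prefix removed.

```python
from typing import List, Sequence, Tuple, Union

NestedKey = Union[str, Tuple[str, ...]]


def default_burn_in_out_keys(modules_out_keys: Sequence[Sequence[NestedKey]]) -> List[NestedKey]:
    """Hypothetical sketch: collect every out_key that points to the next time step."""
    result: List[NestedKey] = []
    for out_keys in modules_out_keys:
        for key in out_keys:
            if isinstance(key, tuple) and len(key) > 1 and key[0] == "next":
                stripped = key[1:]
                # A single remaining element is written as a plain string, e.g. "hidden".
                short = stripped[0] if len(stripped) == 1 else stripped
                if short not in result:  # avoid duplicates across modules
                    result.append(short)
    return result


# out_keys of the GRUModule used in this page's examples
print(default_burn_in_out_keys([["intermediate", ("next", "hidden")]]))  # ['hidden']
```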
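Note
This transform expects the last dimension of the input TensorDict to be the time dimension. It also assumes that all provided modules can process sequential data.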
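Because time is the last batch dimension, burning in `burn_in` steps shortens only that dimension. The hypothetical helper below (not a torchrl function) computes the batch size left after the burn-in; for `burn_in=5` it reproduces the `[2, 10]` to `[2, 5]` change shown in the examples. Requiring at least one remaining step is an assumption of this sketch.

```python
from typing import Sequence, Tuple


def burnt_in_batch_size(batch_size: Sequence[int], burn_in: int) -> Tuple[int, ...]:
    """Hypothetical helper: batch size left after burning in along the last (time) dim."""
    if len(batch_size) == 0:
        raise ValueError("the input needs at least one batch dimension (time)")
    *leading, time = batch_size  # leading dims are untouched, time is last
    if not 0 < burn_in < time:
        raise ValueError(f"burn_in={burn_in} must be in (0, {time})")
    return (*leading, time - burn_in)


print(burnt_in_batch_size([2, 10], burn_in=5))  # (2, 5)
print(burnt_in_batch_size([1, 10], burn_in=5))  # (1, 5)
```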
Examples
>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.envs.transforms import BurnInTransform
>>> from torchrl.modules import GRUModule
>>> gru_module = GRUModule(
...     input_size=10,
...     hidden_size=10,
...     in_keys=["observation", "hidden"],
...     out_keys=["intermediate", ("next", "hidden")],
...     default_recurrent_mode=True,
... )
>>> burn_in_transform = BurnInTransform(
...     modules=[gru_module],
...     burn_in=5,
... )
>>> td = TensorDict({
...     "observation": torch.randn(2, 10, 10),
...     "hidden": torch.randn(2, 10, gru_module.gru.num_layers, 10),
...     "is_init": torch.zeros(2, 10, 1),
... }, batch_size=[2, 10])
>>> td = burn_in_transform(td)
>>> td.shape
torch.Size([2, 5])
>>> td.get("hidden").abs().sum()
tensor(86.3008)
>>> from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
>>> buffer = TensorDictReplayBuffer(
...     storage=LazyMemmapStorage(2),
...     batch_size=1,
... )
>>> buffer.append_transform(burn_in_transform)
>>> td = TensorDict({
...     "observation": torch.randn(2, 10, 10),
...     "hidden": torch.randn(2, 10, gru_module.gru.num_layers, 10),
...     "is_init": torch.zeros(2, 10, 1),
... }, batch_size=[2, 10])
>>> buffer.extend(td)
>>> td = buffer.sample(1)
>>> td.shape
torch.Size([1, 5])
>>> td.get("hidden").abs().sum()
tensor(37.0344)
- forward(tensordict: TensorDictBase) → TensorDictBase
Reads the input tensordict and applies the transform to the selected keys.
By default, this method:
- calls _apply_transform() directly;
- does not call _step() or _call().
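The default control flow described above can be mimicked with plain dicts. The class below is a hypothetical, simplified stand-in for a torchrl Transform (`SketchTransform`, the doubling in `_apply_transform` and the `calls` log are illustration only): `forward` reads each selected key, passes its value through `_apply_transform`, and writes the result under the matching out key, without ever going through `_call`.

```python
from typing import Any, Dict, List, Optional


class SketchTransform:
    """Hypothetical, simplified stand-in for a Transform (plain dicts instead of TensorDicts)."""

    def __init__(self, in_keys: List[str], out_keys: Optional[List[str]] = None):
        self.in_keys = in_keys
        self.out_keys = out_keys if out_keys is not None else list(in_keys)
        self.calls: List[str] = []  # records which hooks were used

    def _apply_transform(self, value: Any) -> Any:
        self.calls.append("_apply_transform")
        return value * 2  # placeholder transformation

    def _call(self, data: Dict[str, Any]) -> Dict[str, Any]:
        self.calls.append("_call")  # environment-side hook in this sketch; forward never uses it
        return data

    def forward(self, data: Dict[str, Any]) -> Dict[str, Any]:
        # Read each selected key, transform it, and write it under the matching out key.
        for in_key, out_key in zip(self.in_keys, self.out_keys):
            if in_key in data:
                data[out_key] = self._apply_transform(data[in_key])
        return data


t = SketchTransform(in_keys=["reward"], out_keys=["scaled_reward"])
out = t.forward({"reward": 3})
print(out, t.calls)  # {'reward': 3, 'scaled_reward': 6} ['_apply_transform']
```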
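This method is never called within env.step. It is, however, called within sample().
Note
forward also works with regular keyword arguments, using dispatch to cast the argument names to keys.
Examples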
>>> from tensordict import TensorDict, TensorDictBase
>>> from torchrl.envs.transforms import Transform
>>> class TransformThatMeasuresBytes(Transform):
...     '''Measures the number of bytes in the tensordict, and writes it under `"bytes"`.'''
...     def __init__(self):
...         super().__init__(in_keys=[], out_keys=["bytes"])
...
...     def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
...         bytes_in_td = tensordict.bytes()
...         tensordict["bytes"] = bytes_in_td
...         return tensordict
>>> t = TransformThatMeasuresBytes()
>>> env = env.append_transform(t)  # works within envs (assumes an existing env)
>>> t(TensorDict(a=0))  # works offline too
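The keyword-argument behaviour mentioned in the note on forward can be imitated in plain Python. The decorator below is a hypothetical, simplified stand-in for tensordict's `dispatch` and is not its implementation: when the wrapped `forward` receives keyword arguments instead of a mapping, the argument names become keys, the call runs on a dict built from them, and the values stored under the output keys are returned (a single value when there is one out key, which is a choice of this sketch).

```python
import functools
from typing import Any, Callable, Dict, Optional


def dispatch_sketch(forward: Callable) -> Callable:
    """Hypothetical: let forward(self, data) also accept keyword arguments as keys."""

    @functools.wraps(forward)
    def wrapper(self, data: Optional[Dict[str, Any]] = None, **kwargs: Any) -> Any:
        if data is not None:
            return forward(self, data)  # regular call with a mapping
        out = forward(self, dict(kwargs))  # argument names become keys
        values = tuple(out[key] for key in self.out_keys)
        return values[0] if len(values) == 1 else values

    return wrapper


class CountKeys:
    """Hypothetical transform that writes the number of entries under "n_keys"."""

    out_keys = ["n_keys"]

    @dispatch_sketch
    def forward(self, data: Dict[str, Any]) -> Dict[str, Any]:
        data["n_keys"] = len(data)  # counted before "n_keys" is added
        return data


t = CountKeys()
print(t.forward({"a": 0, "b": 1}))  # {'a': 0, 'b': 1, 'n_keys': 2}
print(t.forward(a=0, b=1))  # 2
```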