FlattenObservation¶
- class torchrl.envs.transforms.FlattenObservation(first_dim: int, last_dim: int, in_keys: Sequence[NestedKey] | None = None, out_keys: Sequence[NestedKey] | None = None, allow_positive_dim: bool = False)[原始碼]¶
展平張量的相鄰維度。
- 引數:
first_dim (int) – 要展平的維度中的第一個維度。
last_dim (int) – 要展平的維度中的最後一個維度。
in_keys (Sequence[NestedKey], optional) – 要展平的條目。如果未提供,則假定為
["pixels"]。out_keys (Sequence[NestedKey], optional) – 展平後的觀察鍵。如果未提供,則假定為
in_keys。allow_positive_dim (bool, optional) – 如果為
True,則接受正維度。FlattenObservation會將這些維度對映到輸入張量的 n 次特徵維度(即父環境批處理大小之後的第 n 個維度)。預設為 False,即不允許使用非負維度。
- forward(next_tensordict: TensorDictBase) TensorDictBase¶
讀取輸入 tensordict,並對選定的鍵應用轉換。
_call可以在每次修改 env.step 的輸出時被重寫,而無需考慮前一步收集的資料(包括動作和狀態)。對於任何僅與父環境相關的操作(例如
FrameSkip),請改用修改_step()方法。只有當需要修改輸入 tensordict 時,才應重寫_call()。_call()將被step()和reset()呼叫,但不會在forward()期間呼叫。
- transform_observation_spec(observation_spec: TensorSpec) TensorSpec[原始碼]¶
轉換觀察規範,使結果規範與轉換對映匹配。
- 引數:
observation_spec (TensorSpec) – 轉換前的規範
- 返回:
轉換後的預期規範