ActionDiscretizer
- class torchrl.envs.transforms.ActionDiscretizer(num_intervals: int | torch.Tensor, action_key: NestedKey = 'action', out_action_key: NestedKey = None, sampling=None, categorical: bool = True)
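A transform that discretizes a continuous action space.
This transform makes it possible to use algorithms designed for discrete action spaces, such as DQN, on environments with a continuous action space.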
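- Parameters:
num_intervals (int or torch.Tensor) – the number of discrete values for each element of the action space. If a single integer is provided, every action element is sliced into the same number of intervals. If a tensor is provided, it must have as many elements as the action space (i.e., the length of the num_intervals tensor must match the last dimension of the action space).
action_key (NestedKey, optional) – the action key to use. Points to the action of the parent environment (the floating-point action). Defaults to "action".
out_action_key (NestedKey, optional) – the key where the discrete action should be written. If None is provided, it defaults to the value of action_key. If the two keys differ, the continuous action_spec is moved from the environment's full_action_spec attribute to the full_state_spec container, since only the discrete action should be sampled when an action is taken. Providing out_action_key ensures that the floating-point action remains available for logging.
sampling (ActionDiscretizer.SamplingStrategy, optional) – an element of the ActionDiscretizer.SamplingStrategy IntEnum (MEDIAN, LOW, HIGH or RANDOM). Indicates how the continuous action is sampled within the selected interval.
categorical (bool, optional) – if False, one-hot encoding is used. Defaults to True.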
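To make the roles of num_intervals and sampling more concrete, the following is a minimal pure-Python sketch of how a discrete index could be mapped back to a continuous value. It is not TorchRL's implementation: it assumes uniform-width intervals between the bounds, that LOW, HIGH and MEDIAN pick the left edge, right edge and midpoint of the chosen interval, and that RANDOM draws uniformly inside it. The names SamplingStrategy (a local stand-in), bin_to_continuous and decode_action are hypothetical, and only MEDIAN = 0 is suggested by the printed output on this page; the other enum values are assumed.

```python
import random
from enum import IntEnum


class SamplingStrategy(IntEnum):
    """Local stand-in for ActionDiscretizer.SamplingStrategy (values partly assumed)."""
    MEDIAN = 0
    LOW = 1
    HIGH = 2
    RANDOM = 3


def bin_to_continuous(index, low, high, num_intervals,
                      strategy=SamplingStrategy.MEDIAN, rng=None):
    """Map one discrete index to a continuous value inside [low, high]."""
    if not 0 <= index < num_intervals:
        raise ValueError(f"index {index} is outside [0, {num_intervals})")
    # Assumption: the range is split into equally wide intervals.
    width = (high - low) / num_intervals
    left = low + index * width
    right = left + width
    if strategy == SamplingStrategy.LOW:
        return left
    if strategy == SamplingStrategy.HIGH:
        return right
    if strategy == SamplingStrategy.MEDIAN:
        return (left + right) / 2
    # RANDOM: draw uniformly inside the selected interval.
    rng = rng or random
    return rng.uniform(left, right)


def decode_action(indices, lows, highs, num_intervals,
                  strategy=SamplingStrategy.MEDIAN, rng=None):
    """Decode a whole action vector, one index per action dimension."""
    if isinstance(num_intervals, int):
        # A single integer applies the same resolution to every dimension.
        num_intervals = [num_intervals] * len(indices)
    if not (len(indices) == len(lows) == len(highs) == len(num_intervals)):
        raise ValueError("all per-dimension sequences must have the same length")
    return [
        bin_to_continuous(i, lo, hi, n, strategy, rng)
        for i, lo, hi, n in zip(indices, lows, highs, num_intervals)
    ]


# Two action dimensions in [-1, 1] with 4 and 2 intervals respectively.
action = decode_action([0, 1], [-1.0, -1.0], [1.0, 1.0], [4, 2])
```

Under these assumptions, MEDIAN never returns a value on the boundary of the action range, which is consistent with the strict inequalities asserted in the example on this page.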
Examples
>>> from torchrl.envs import GymEnv, check_env_specs
>>> from torchrl.envs.transforms import ActionDiscretizer
>>> import torch
>>> base_env = GymEnv("HalfCheetah-v4")
>>> # One resolution per action dimension: 5, 6, ..., 10 intervals
>>> num_intervals = torch.arange(5, 11)
>>> categorical = True
>>> sampling = ActionDiscretizer.SamplingStrategy.MEDIAN
>>> t = ActionDiscretizer(
...     num_intervals=num_intervals,
...     categorical=categorical,
...     sampling=sampling,
...     out_action_key="action_disc",
... )
>>> env = base_env.append_transform(t)
>>> print(env)
TransformedEnv(
    env=GymEnv(env=HalfCheetah-v4, batch_size=torch.Size([]), device=cpu),
    transform=ActionDiscretizer(
        num_intervals=tensor([ 5,  6,  7,  8,  9, 10]),
        action_key=action,
        out_action_key=action_disc,,
        sampling=0,
        categorical=True))
>>> check_env_specs(env)
>>> # Produce a rollout
>>> r = env.rollout(4)
>>> print(r)
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([4, 6]), device=cpu, dtype=torch.float32, is_shared=False),
        action_disc: Tensor(shape=torch.Size([4, 6]), device=cpu, dtype=torch.int64, is_shared=False),
        done: Tensor(shape=torch.Size([4, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([4, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                observation: Tensor(shape=torch.Size([4, 17]), device=cpu, dtype=torch.float64, is_shared=False),
                reward: Tensor(shape=torch.Size([4, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                terminated: Tensor(shape=torch.Size([4, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                truncated: Tensor(shape=torch.Size([4, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([4]),
            device=cpu,
            is_shared=False),
        observation: Tensor(shape=torch.Size([4, 17]), device=cpu, dtype=torch.float64, is_shared=False),
        terminated: Tensor(shape=torch.Size([4, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        truncated: Tensor(shape=torch.Size([4, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([4]),
    device=cpu,
    is_shared=False)
>>> # The continuous action stays available next to the discrete one
>>> assert r["action"].dtype == torch.float
>>> assert r["action_disc"].dtype == torch.int64
>>> assert (r["action"] < base_env.action_spec.high).all()
>>> assert (r["action"] > base_env.action_spec.low).all()
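The categorical flag selects between integer indices (as in the action_disc entry above) and a one-hot representation. The sketch below illustrates the difference in plain Python. It assumes that, in the one-hot case, each action dimension contributes its own segment and that the segments are concatenated along the last dimension; this layout is an assumption made for illustration, not something this page confirms. The helper names to_one_hot and from_one_hot are hypothetical.

```python
def to_one_hot(indices, num_intervals):
    """Encode per-dimension indices as concatenated one-hot segments."""
    encoded = []
    for idx, n in zip(indices, num_intervals):
        if not 0 <= idx < n:
            raise ValueError(f"index {idx} is outside [0, {n})")
        segment = [0] * n
        segment[idx] = 1
        encoded.extend(segment)
    return encoded


def from_one_hot(vector, num_intervals):
    """Recover the per-dimension indices from concatenated one-hot segments."""
    if len(vector) != sum(num_intervals):
        raise ValueError("vector length must equal sum(num_intervals)")
    indices = []
    start = 0
    for n in num_intervals:
        segment = vector[start:start + n]
        indices.append(segment.index(1))  # position of the active entry
        start += n
    return indices


# Two action dimensions with 3 and 2 intervals.
categorical_action = [1, 0]
one_hot_action = to_one_hot(categorical_action, [3, 2])
```

With per-dimension resolutions, the categorical form keeps one integer per action dimension, while the one-hot form under this assumption has a total length equal to the sum of num_intervals.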
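- inv(tensordict)
Reads the input tensordict and applies the inverse transform to the selected keys.
By default, this method calls _inv_apply_transform() directly and does not call _inv_call().
Note
inv also works with regular keyword arguments, using dispatch to cast the argument names to keys.
Note
inv is called by extend().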
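- transform_input_spec(input_spec)
Transforms the input spec so that the resulting spec matches the transform mapping.
- Parameters:
input_spec (TensorSpec) – the spec before the transform
- Returns:
the expected spec after the transform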