Hash¶
- class torchrl.envs.transforms.Hash(in_keys: Sequence[NestedKey], out_keys: Sequence[NestedKey], in_keys_inv: Sequence[NestedKey] = None, out_keys_inv: Sequence[NestedKey] = None, *, hash_fn: Callable = None, seed: Any | None = None, use_raw_nontensor: bool = False, repertoire: tuple[tuple[int], Any] = None)[source]¶
將雜湊值新增到 tensordict 中。
- 引數:
in_keys (Sequence of NestedKey) – 要雜湊的值的鍵。
out_keys (Sequence of NestedKey) – 結果雜湊的鍵。
in_keys_inv (Sequence of NestedKey, optional) – 在 inv 呼叫期間要雜湊的值的鍵。
out_keys_inv (Sequence of NestedKey, optional) – 在 inv 呼叫期間結果雜湊的鍵。
- 關鍵字引數:
hash_fn (Callable, optional) – 要使用的雜湊函式。函式簽名必須是
(input: Any, seed: Any | None) -> torch.Tensor。如果此轉換使用seed引數初始化,則seed才會被使用。預設為Hash.reproducible_hash。seed (optional) – 如果需要,用於雜湊函式的種子。
use_raw_nontensor (bool, optional) – 如果為
False,則在呼叫fn之前,會從NonTensorData/NonTensorStack輸入中提取資料。如果為True,則直接將原始NonTensorData/NonTensorStack輸入提供給fn,此時fn必須支援這些輸入。預設為False。repertoire (Dict[Tuple[int], Any], optional) – 如果提供,此字典將儲存從雜湊到輸入的逆對映。此 repertoire 不會被複制,因此可以在轉換例項化後在同一工作空間中修改它,這些修改將反映在對映中。缺少雜湊將對映到
None。預設值:NoneHash (>>> from torchrl.envs import GymEnv, UnaryTransform,) –
GymEnv (>>> env =) –
output (>>> # process the string) –
env.append_transform( (>>> env =) –
UnaryTransform( (...) –
in_keys=["observation"], (...) –
out_keys=["observation_str"], (...) –
tensor (... fn=lambda) – str(tensor.numpy().tobytes())))
output –
env.append_transform( –
Hash( (...) –
in_keys=["observation_str"], (...) –
out_keys=["observation_hash"],) (...) –
) (...) –
env.observation_spec (>>>) –
Composite( –
- observation: BoundedContinuous(
shape=torch.Size([3]), space=ContinuousBox(
low=Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, contiguous=True), high=Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, contiguous=True)),
device=cpu, dtype=torch.float32, domain=continuous),
- observation_str: NonTensor(
shape=torch.Size([]), space=None, device=cpu, dtype=None, domain=None),
- observation_hash: UnboundedDiscrete(
shape=torch.Size([32]), space=ContinuousBox(
low=Tensor(shape=torch.Size([32]), device=cpu, dtype=torch.uint8, contiguous=True), high=Tensor(shape=torch.Size([32]), device=cpu, dtype=torch.uint8, contiguous=True)),
device=cpu, dtype=torch.uint8, domain=discrete),
device=None, shape=torch.Size([]))
env.rollout (>>>) –
TensorDict( –
- fields={
action: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False), done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False), next: TensorDict(
- fields={
done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False), observation: Tensor(shape=torch.Size([3, 3]), device=cpu, dtype=torch.float32, is_shared=False), observation_hash: Tensor(shape=torch.Size([3, 32]), device=cpu, dtype=torch.uint8, is_shared=False), observation_str: NonTensorStack(
[“b’g\x08\x8b\xbexav\xbf\x00\xee(>’”, “b’\x…, batch_size=torch.Size([3]), device=None),
reward: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False), terminated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False), truncated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
batch_size=torch.Size([3]), device=None, is_shared=False),
observation: Tensor(shape=torch.Size([3, 3]), device=cpu, dtype=torch.float32, is_shared=False), observation_hash: Tensor(shape=torch.Size([3, 32]), device=cpu, dtype=torch.uint8, is_shared=False), observation_str: NonTensorStack(
[“b’\xb5\x17\x8f\xbe\x88\xccu\xbf\xc0Vr?’”…, batch_size=torch.Size([3]), device=None),
terminated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False), truncated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
batch_size=torch.Size([3]), device=None, is_shared=False)
env.check_env_specs() (>>>) –
succeeded! ([torchrl][INFO] check_env_specs) –
- get_input_from_hash(hash_tensor) Any[source]¶
查詢給定特定雜湊輸出的輸入。
此功能僅在初始化期間提供了 :arg:`repertoire` 引數,或者同時提供了 :arg:`in_keys_inv` 和 :arg:`out_keys_inv` 引數時才可用。
- 引數:
hash_tensor (Tensor) – 雜湊輸出。
- 返回:
生成雜湊的輸入。
- 返回型別:
任何
- classmethod reproducible_hash(string, seed=None)[source]¶
使用種子從字串建立可復現的 256 位雜湊。
- 引數:
string (str or None) – 輸入字串。如果為
None,則使用空字串""。seed (str, optional) – 種子值。預設為
None。
- 返回:
形狀為
(32,),dtype 為torch.uint8。- 返回型別:
張量
- state_dict(*args, destination=None, prefix='', keep_vars=False)[source]¶
返回一個字典,其中包含對模組整個狀態的引用。
引數和持久緩衝區(例如,執行平均值)都包含在內。鍵是相應的引數和緩衝區名稱。設定為
None的引數和緩衝區不包含在內。注意
返回的物件是淺複製。它包含對模組引數和緩衝區的引用。
警告
當前
state_dict()還接受destination、prefix和keep_vars的位置引數,順序為。但是,這正在被棄用,並且在未來的版本中將強制使用關鍵字引數。警告
請避免使用引數
destination,因為它不是為終端使用者設計的。- 引數:
destination (dict, optional) – 如果提供,模組的狀態將更新到 dict 中,並返回相同的物件。否則,將建立一個
OrderedDict並返回。預設為None。prefix (str, optional) – a prefix added to parameter and buffer names to compose the keys in state_dict. Default:
''。keep_vars (bool, optional) – 預設情況下,state dict 中返回的
Tensors 會從 autograd 中分離。如果設定為True,則不會執行分離。預設為False。
- 返回:
包含模組整體狀態的字典
- 返回型別:
dict
示例
>>> # xdoctest: +SKIP("undefined vars") >>> module.state_dict().keys() ['bias', 'weight']