TransformedEnv¶
- class torchrl.envs.transforms.TransformedEnv(*args, **kwargs)[原始碼]¶
一個轉換後的環境。
- 引數:
base_env (EnvBase) – 要轉換的原始環境。
transform (Transform 或 callable, 可選) –
應用於
base_env.step(td)產生的 tensordict 的轉換。如果未提供,則使用空 Compose 佔位符進行評估模式。注意
如果
transform是一個可呼叫物件,它必須接收一個 tensordict 作為輸入並輸出一個 tensordict。該可呼叫物件將在step和reset時被呼叫:如果它作用於獎勵(在重置時不存在),則需要實現一個檢查以確保轉換能夠順利執行。>>> def add_1(data): ... if "reward" in data.keys(): ... return data.set("reward", data.get("reward") + 1) ... return data >>> env = TransformedEnv(base_env, add_1)
cache_specs (bool, 可選) – 如果為
True,則在第一次呼叫後將快取規範(即,規範只轉換一次)。如果在訓練過程中轉換髮生變化,原始規範轉換可能不再有效,在這種情況下,此值應設定為 False。預設為 True。
- 關鍵字引數:
auto_unwrap (bool, 可選) –
如果為
True,則將一個轉換後的環境包裝到另一個轉換後的環境時,會將在內部 TransformedEnv 的轉換解包到外部的(新例項)。預設為True。注意
此行為將在 v0.9 中切換為
False。
示例
>>> env = GymEnv("Pendulum-v0") >>> transform = RewardScaling(0.0, 1.0) >>> transformed_env = TransformedEnv(env, transform) >>> # check auto-unwrap >>> transformed_env = TransformedEnv(transformed_env, StepCounter()) >>> # The inner env has been unwrapped >>> assert isinstance(transformed_env.base_env, GymEnv)
注意
第一個引數已從
env重新命名為base_env以便更清晰。為了向後相容,仍然支援舊的env引數,但將在 v0.12 中刪除。使用舊引數名稱時會顯示棄用警告。- add_truncated_keys() TransformedEnv[原始碼]¶
將截斷鍵新增到環境中。
- append_transform(transform: Transform | Callable[[TensorDictBase], TensorDictBase]) TransformedEnv[原始碼]¶
向環境追加一個轉換。
接受
Transform或可呼叫物件。
- property batch_locked: bool¶
環境是否可以用於與初始化時不同的批次大小。
如果為 True,則需要在與環境相同批次大小的 tensordict 上使用該環境。batch_locked 是一個不可變屬性。
- property batch_size: Size¶
此環境例項中批次化環境的數量,組織為 torch.Size() 物件。
環境可能相似或不同,但假定它們之間幾乎沒有(如果有的話)互動(例如,多工或並行批處理執行)。
- eval() TransformedEnv[原始碼]¶
將模組設定為評估模式。
這僅對某些模組有影響。有關模組在訓練/評估模式下的行為,例如它們是否受影響(如
Dropout、BatchNorm等),請參閱具體模組的文件。這等同於
self.train(False)。有關 .eval() 和幾種可能與之混淆的類似機制之間的比較,請參閱 區域性停用梯度計算。
- 返回:
self
- 返回型別:
模組
- property input_spec: TensorSpec¶
轉換環境的觀測規格。
- insert_transform(index: int, transform: Transform) TransformedEnv[原始碼]¶
將轉換插入到指定索引的環境中。
接受
Transform或可呼叫物件。
- load_state_dict(state_dict: OrderedDict, **kwargs) None[原始碼]¶
將狀態字典中的引數和緩衝區複製到此模組及其子模組中。
如果
strict為True,則state_dict的鍵必須與此模組的state_dict()函式返回的鍵完全匹配。警告
如果
assign為True,則最佳化器必須在呼叫load_state_dict之後建立,除非get_swap_module_params_on_conversion()為True。- 引數:
state_dict (dict) – 包含引數和持久 buffer 的字典。
strict (bool, 可選) – 是否嚴格執行
state_dict中的鍵是否與此模組的state_dict()函式返回的鍵匹配。預設值:Trueassign (bool, optional) – 當設定為
False時,將保留當前模組中張量的屬性;當設定為True時,將保留 state_dict 中張量的屬性。唯一的例外是Parameter的requires_grad欄位,此時將保留模組的值。預設值:False
- 返回:
missing_keys是一個包含此模組期望但在提供的
state_dict中缺失的任何鍵的字串列表。
unexpected_keys是一個字串列表,包含此模組不期望但在提供的
state_dict中存在的鍵。
- 返回型別:
NamedTuple,包含missing_keys和unexpected_keys欄位。
注意
如果引數或緩衝區被註冊為
None且其對應的鍵存在於state_dict中,load_state_dict()將引發RuntimeError。
- property output_spec: TensorSpec¶
轉換環境的觀測規格。
- rand_action(tensordict: TensorDictBase | None = None) TensorDict[原始碼]¶
根據 action_spec 屬性執行隨機動作。
- 引數:
tensordict (TensorDictBase, optional) – 要將生成的動作寫入的 tensordict。
- 返回:
一個 tensordict 物件,其“action”條目已用從 action-spec 中隨機抽取的樣本更新。
- state_dict(*args, **kwargs) OrderedDict[原始碼]¶
返回一個字典,其中包含對模組整個狀態的引用。
引數和持久緩衝區(例如,執行平均值)都包含在內。鍵是相應的引數和緩衝區名稱。設定為
None的引數和緩衝區不包含在內。注意
返回的物件是淺複製。它包含對模組引數和緩衝區的引用。
警告
當前
state_dict()還接受destination、prefix和keep_vars的位置引數,順序為。但是,這正在被棄用,並且在未來的版本中將強制使用關鍵字引數。警告
請避免使用引數
destination,因為它不是為終端使用者設計的。- 引數:
destination (dict, optional) – 如果提供,模組的狀態將更新到 dict 中,並返回相同的物件。否則,將建立一個
OrderedDict並返回。預設為None。prefix (str, optional) – a prefix added to parameter and buffer names to compose the keys in state_dict. Default:
''。keep_vars (bool, optional) – 預設情況下,state dict 中返回的
Tensors 會從 autograd 中分離。如果設定為True,則不會執行分離。預設為False。
- 返回:
包含模組整體狀態的字典
- 返回型別:
dict
示例
>>> # xdoctest: +SKIP("undefined vars") >>> module.state_dict().keys() ['bias', 'weight']
- to(*args, **kwargs) TransformedEnv[原始碼]¶
移動和/或轉換引數和緩衝區。
這可以這樣呼叫
- to(device=None, dtype=None, non_blocking=False)[原始碼]
- to(dtype, non_blocking=False)[原始碼]
- to(tensor, non_blocking=False)[原始碼]
- to(memory_format=torch.channels_last)[原始碼]
其簽名類似於
torch.Tensor.to(),但僅接受浮點或複數dtype。此外,此方法只會將浮點或複數引數和緩衝區轉換為dtype(如果給出)。整數引數和緩衝區將被移動到device(如果給出),但dtype保持不變。當設定non_blocking時,它會嘗試與主機非同步地進行轉換/移動(如果可能),例如,將具有已固定記憶體的 CPU Tensor 移動到 CUDA 裝置。有關示例,請參閱下文。
注意
此方法就地修改模組。
- 引數:
device (
torch.device) – the desired device of the parameters and buffers in this module – 此模組中引數和緩衝區的目標裝置。dtype (
torch.dtype) – the desired floating point or complex dtype of the parameters and buffers in this module – 此模組中引數和緩衝區的目標浮點數或複數 dtype。tensor (torch.Tensor) – Tensor whose dtype and device are the desired dtype and device for all parameters and buffers in this module – 其 dtype 和 device 是此模組中所有引數和緩衝區的目標 dtype 和 device 的 Tensor。
memory_format (
torch.memory_format) – the desired memory format for 4D parameters and buffers in this module (keyword only argument) – 此模組中 4D 引數和緩衝區的目標記憶體格式(僅關鍵字引數)。
- 返回:
self
- 返回型別:
模組
示例
>>> # xdoctest: +IGNORE_WANT("non-deterministic") >>> linear = nn.Linear(2, 2) >>> linear.weight Parameter containing: tensor([[ 0.1913, -0.3420], [-0.5113, -0.2325]]) >>> linear.to(torch.double) Linear(in_features=2, out_features=2, bias=True) >>> linear.weight Parameter containing: tensor([[ 0.1913, -0.3420], [-0.5113, -0.2325]], dtype=torch.float64) >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA1) >>> gpu1 = torch.device("cuda:1") >>> linear.to(gpu1, dtype=torch.half, non_blocking=True) Linear(in_features=2, out_features=2, bias=True) >>> linear.weight Parameter containing: tensor([[ 0.1914, -0.3420], [-0.5112, -0.2324]], dtype=torch.float16, device='cuda:1') >>> cpu = torch.device("cpu") >>> linear.to(cpu) Linear(in_features=2, out_features=2, bias=True) >>> linear.weight Parameter containing: tensor([[ 0.1914, -0.3420], [-0.5112, -0.2324]], dtype=torch.float16) >>> linear = nn.Linear(2, 2, bias=None).to(torch.cdouble) >>> linear.weight Parameter containing: tensor([[ 0.3741+0.j, 0.2382+0.j], [ 0.5593+0.j, -0.4443+0.j]], dtype=torch.complex128) >>> linear(torch.ones(3, 2, dtype=torch.cdouble)) tensor([[0.6122+0.j, 0.1150+0.j], [0.6122+0.j, 0.1150+0.j], [0.6122+0.j, 0.1150+0.j]], dtype=torch.complex128)
- train(mode: bool = True) TransformedEnv[原始碼]¶
將模組設定為訓練模式。
This has an effect only on certain modules. See the documentation of particular modules for details of their behaviors in training/evaluation mode, i.e., whether they are affected, e.g.
Dropout,BatchNorm, etc. – 這隻對某些模組有影響。有關其在訓練/評估模式下的行為的詳細資訊,例如它們是否受影響,請參閱特定模組的文件,例如Dropout、BatchNorm等。- 引數:
mode (bool) – whether to set training mode (
True) or evaluation mode (False). Default:True. – 設定訓練模式(True)或評估模式(False)。預設值:True。- 返回:
self
- 返回型別:
模組