快捷方式

ConsistentDropoutModule

class torchrl.modules.ConsistentDropoutModule(*args, **kwargs)[原始碼]

用於 ConsistentDropout 的 TensorDictModule 包裝器。

引數:
  • p (float, optional) – Dropout 機率。預設為 0.5

  • in_keys (NestedKeylist of NestedKeys) – 將從輸入 tensordict 讀取並傳遞給此模組的鍵。

  • out_keys (NestedKeyiterable of NestedKeys) – 將寫入輸入 tensordict 的鍵。預設為 in_keys 值。

關鍵字引數:
  • input_shape (tuple, optional) – 輸入(非批處理)的形狀,用於透過 make_tensordict_primer() 生成 tensordict primer。

  • input_dtype (torch.dtype, optional) – primer 的輸入資料型別。如果未提供,則假定為 torch.get_default_dtype

注意

要在策略中使用此類,需要在重置時重置掩碼。這可以透過 TensorDictPrimer 轉換來實現,該轉換可以透過 make_tensordict_primer() 獲取。有關更多資訊,請參閱該方法。

示例

>>> from tensordict import TensorDict
>>> module = ConsistentDropoutModule(p = 0.1)
>>> td = TensorDict({"x": torch.randn(3, 4)}, [3])
>>> module(td)
TensorDict(
    fields={
        mask_6127171760: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.bool, is_shared=False),
        x: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)
forward(tensordict)[原始碼]

定義每次呼叫時執行的計算。

所有子類都應重寫此方法。

注意

儘管前向傳播的實現需要在此函式中定義,但您應該在之後呼叫 Module 例項而不是此函式,因為前者會處理註冊的鉤子,而後者則會靜默忽略它們。

make_tensordict_primer()[原始碼]

建立一個 tensordict primer,供環境在重置呼叫期間生成隨機掩碼。

另請參閱

torchrl.modules.utils.get_primers_from_module() 用於生成給定模組的所有 primer 的方法。

模組。

示例

>>> from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod
>>> from torchrl.envs import GymEnv, StepCounter, SerialEnv
>>> m = Seq(
...     Mod(torch.nn.Linear(7, 4), in_keys=["observation"], out_keys=["intermediate"]),
...     ConsistentDropoutModule(
...         p=0.5,
...         input_shape=(2, 4),
...         in_keys="intermediate",
...     ),
...     Mod(torch.nn.Linear(4, 7), in_keys=["intermediate"], out_keys=["action"]),
... )
>>> primer = get_primers_from_module(m)
>>> env0 = GymEnv("Pendulum-v1").append_transform(StepCounter(5))
>>> env1 = GymEnv("Pendulum-v1").append_transform(StepCounter(6))
>>> env = SerialEnv(2, [lambda env=env0: env, lambda env=env1: env])
>>> env = env.append_transform(primer)
>>> r = env.rollout(10, m, break_when_any_done=False)
>>> mask = [k for k in r.keys() if k.startswith("mask")][0]
>>> assert (r[mask][0, :5] != r[mask][0, 5:6]).any()
>>> assert (r[mask][0, :4] == r[mask][0, 4:5]).all()

文件

訪問全面的 PyTorch 開發者文件

檢視文件

教程

為初學者和高階開發者提供深入的教程

檢視教程

資源

查詢開發資源並讓您的問題得到解答

檢視資源