快捷方式

EGreedyModule

class torchrl.modules.EGreedyModule(*args, **kwargs)[source]

Epsilon-Greedy 探索模組。

此模組根據 epsilon-greedy 探索策略隨機更新 tensordict 中的動作。每次呼叫時,都會根據一定的機率閾值執行隨機抽取(每個動作一個)。如果成功,則相應的動作將被替換為從提供的動作規範中抽取的隨機樣本。其他動作保持不變。

引數:
  • spec (TensorSpec) – 用於取樣動作的規範。

  • eps_init (scalar, optional) – 初始 epsilon 值。預設值:1.0

  • eps_end (scalar, optional) – 最終 epsilon 值。預設值:0.1

  • annealing_num_steps (int, optional) – epsilon 達到 eps_end 值所需的步數。預設為 1000

關鍵字引數:
  • action_key (NestedKey, optional) – 在輸入 tensordict 中找到動作的鍵。預設為 "action"

  • action_mask_key (NestedKey, optional) – 在輸入 tensordict 中找到動作掩碼的鍵。預設為 None(表示沒有掩碼)。

  • device (torch.device, optional) – 探索模組的裝置。

注意

至關重要的是在訓練迴圈中加入對 step() 的呼叫,以更新探索因子。由於很難捕獲這種遺漏,因此如果省略此項,將不會發出警告或異常!

示例

>>> import torch
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictSequential
>>> from torchrl.modules import EGreedyModule, Actor
>>> from torchrl.data import Bounded
>>> torch.manual_seed(0)
>>> spec = Bounded(-1, 1, torch.Size([4]))
>>> module = torch.nn.Linear(4, 4, bias=False)
>>> policy = Actor(spec=spec, module=module)
>>> explorative_policy = TensorDictSequential(policy,  EGreedyModule(eps_init=0.2))
>>> td = TensorDict({"observation": torch.zeros(10, 4)}, batch_size=[10])
>>> print(explorative_policy(td).get("action"))
tensor([[ 0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.9055, -0.9277, -0.6295, -0.2532],
        [ 0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000]], grad_fn=<AddBackward0>)
forward(tensordict: TensorDictBase) TensorDictBase[source]

定義每次呼叫時執行的計算。

所有子類都應重寫此方法。

注意

儘管前向傳播的實現需要在此函式中定義,但您應該在之後呼叫 Module 例項而不是此函式,因為前者會處理註冊的鉤子,而後者則會靜默忽略它們。

step(frames: int = 1) None[source]

epsilon 衰減的一步。

在此方法呼叫 self.annealing_num_steps 次後,後續呼叫將無效。

引數:

frames (int, optional) – 自上次呼叫以來經過的幀數。預設為 1

文件

訪問全面的 PyTorch 開發者文件

檢視文件

教程

為初學者和高階開發者提供深入的教程

檢視教程

資源

查詢開發資源並讓您的問題得到解答

檢視資源