快捷方式

TanhModule

class torchrl.modules.tensordict_module.TanhModule(*args, **kwargs)[原始碼]

用於具有有界動作空間的確定性策略的 Tanh 模組。

此變換用作 TensorDictModule 層,將網路輸出對映到有界空間。

引數:
  • in_keys (list of strtuples of str) – 模組的輸入鍵。

  • out_keys (list of strtuples of str, optional) – 模組的輸出鍵。如果未提供,則假定與 in_keys 相同的鍵。

關鍵字引數:
  • spec (TensorSpec, optional) – 如果提供,則為輸出的 spec。如果提供 Composite,則其鍵必須與 out_keys 中的鍵匹配。否則,將假定 out_keys 的鍵,並對所有輸出使用相同的 spec。

  • low (float, np.ndarray or torch.Tensor) – 空間的下界。如果未提供且未提供 spec,則假定為 -1。如果提供了 spec,則將檢索 spec 的最小值。

  • high (float, np.ndarray or torch.Tensor) – 空間的上限。如果未提供且未提供 spec,則假定為 1。如果提供了 spec,則將檢索 spec 的最大值。

  • clamp (bool, optional) – 如果為 True,則輸出將被限制在邊界內,但與邊界至少有一個最小解析度。預設為 False

示例

>>> from tensordict import TensorDict
>>> # simplest use case: -1 - 1 boundaries
>>> torch.manual_seed(0)
>>> in_keys = ["action"]
>>> mod = TanhModule(
...     in_keys=in_keys,
... )
>>> data = TensorDict({"action": torch.randn(5) * 10}, [])
>>> data = mod(data)
>>> data['action']
tensor([ 1.0000, -0.9944, -1.0000,  1.0000, -1.0000])
>>> # low and high can be customized
>>> low = -2
>>> high = 1
>>> mod = TanhModule(
...     in_keys=in_keys,
...     low=low,
...     high=high,
... )
>>> data = TensorDict({"action": torch.randn(5) * 10}, [])
>>> data = mod(data)
>>> data['action']
tensor([-2.0000,  0.9991,  1.0000, -2.0000, -1.9991])
>>> # A spec can be provided
>>> from torchrl.data import Bounded
>>> spec = Bounded(low, high, shape=())
>>> mod = TanhModule(
...     in_keys=in_keys,
...     low=low,
...     high=high,
...     spec=spec,
...     clamp=False,
... )
>>> # One can also work with multiple keys
>>> in_keys = ['a', 'b']
>>> spec = Composite(
...     a=Bounded(-3, 0, shape=()),
...     b=Bounded(0, 3, shape=()))
>>> mod = TanhModule(
...     in_keys=in_keys,
...     spec=spec,
... )
>>> data = TensorDict(
...     {'a': torch.randn(10), 'b': torch.randn(10)}, batch_size=[])
>>> data = mod(data)
>>> data['a']
tensor([-2.3020, -1.2299, -2.5418, -0.2989, -2.6849, -1.3169, -2.2690, -0.9649,
        -2.5686, -2.8602])
>>> data['b']
tensor([2.0315, 2.8455, 2.6027, 2.4746, 1.7843, 2.7782, 0.2111, 0.5115, 1.4687,
        0.5760])
forward(tensordict=None)[原始碼]

定義每次呼叫時執行的計算。

所有子類都應重寫此方法。

注意

儘管前向傳播的實現需要在此函式中定義,但您應該在之後呼叫 Module 例項而不是此函式,因為前者會處理註冊的鉤子,而後者則會靜默忽略它們。

文件

訪問全面的 PyTorch 開發者文件

檢視文件

教程

為初學者和高階開發者提供深入的教程

檢視教程

資源

查詢開發資源並讓您的問題得到解答

檢視資源