快捷方式

LossModule

class torchrl.objectives.LossModule(*args, **kwargs)[原始碼]

RL 損失的父類。

LossModule 繼承自 nn.Module。它被設計用於讀取一個輸入的 TensorDict 並返回另一個 tensordict,其中包含名為 "loss_*" 的損失鍵。

將損失分解為其組成部分可以被訓練器用於在訓練過程中記錄各種損失值。輸出 tensordict 中存在的其他標量也將被記錄。

變數:

default_value_estimator – 類的預設值型別。需要值估計的損失會配備一個預設值指標。這個類屬性指示了將使用哪個值估計器,如果未指定其他值估計器的話。可以透過 make_value_estimator() 方法更改值估計器。

預設情況下,forward 方法始終使用 gh torchrl.envs.ExplorationType.MEAN 進行裝飾。

要利用透過 set_keys() 配置 tensordict 鍵的能力,子類必須定義一個 _AcceptedKeys dataclass。這個 dataclass 應包含所有打算可配置的鍵。此外,子類必須實現 :meth:._forward_value_estimator_keys() 方法。此函式對於將任何修改後的 tensordict 鍵轉發到底層 value_estimator 至關重要。

示例

>>> class MyLoss(LossModule):
>>>     @dataclass
>>>     class _AcceptedKeys:
>>>         action = "action"
>>>
>>>     def _forward_value_estimator_keys(self, **kwargs) -> None:
>>>         pass
>>>
>>> loss = MyLoss()
>>> loss.set_keys(action="action2")

注意

當將一個被包裝或增強了探索模組的策略傳遞給 loss 時,我們希望透過 set_exploration_type(<exploration>) 來停用探索,其中 <exploration> 可以是 ExplorationType.MEANExplorationType.MODEExplorationType.DETERMINISTIC。預設值是 DETERMINISTIC,它透過 deterministic_sampling_mode loss 屬性設定。如果需要其他探索模式(或者 DETERMINISTIC 不可用),可以更改此屬性的值,這將改變模式。

convert_to_functional(module: TensorDictModule, module_name: str, expand_dim: int | None = None, create_target_params: bool = False, compare_against: list[Parameter] | None = None, **kwargs) None[原始碼]

將模組轉換為函式式以在損失中使用。

引數:
  • module (TensorDictModule相容) – 一個有狀態的 tensordict 模組。來自此模組的引數將被隔離在 <module_name>_params 屬性中,而模組的無狀態版本將註冊在 module_name 屬性下。

  • module_name (str) – 模組將被找到的名稱。該模組的引數將在 loss_module.<module_name>_params 下找到,而模組本身將在 loss_module.<module_name> 下找到。

  • expand_dim (int, optional) –

    如果提供,模組的引數將沿第一個維度擴充套件 N 次,其中 N = expand_dim。當使用具有多個配置的目標網路時,應使用此選項。

    注意

    如果提供了 compare_against 值列表,則生成的引數將只是原始引數的解耦擴充套件。如果未提供 compare_against,則引數的值將在引數內容的最小值和最大值之間均勻重取樣。

  • create_target_params (bool, 可選) – 如果為 True,則引數的解耦副本將可用於為名稱為 loss_module.<module_name>_target_params 的目標網路提供輸入。如果為 False(預設),此屬性仍可用,但它將是引數的解耦例項,而不是副本。換句話說,引數值的任何修改將直接反映在目標引數中。

  • compare_against (引數的可迭代物件, 可選) – 如果提供,此引數列表將用作模組引數的比較集。如果引數被擴充套件(expand_dim > 0),則模組生成的引數將是原始引數的簡單擴充套件。否則,生成的引數將是原始引數的解耦版本。如果為 None,則生成的引數將按預期攜帶梯度。

forward(tensordict: TensorDictBase) TensorDictBase[原始碼]

它旨在讀取一個輸入的 TensorDict 並返回另一個包含名為“loss*”的損失鍵的 tensordict。

將損失分解為其組成部分可以被訓練器用於在訓練過程中記錄各種損失值。輸出 tensordict 中存在的其他標量也將被記錄。

引數:

tensordict – 一個輸入的 tensordict,包含計算損失所需的值。

返回:

一個沒有批處理維度的新 tensordict,其中包含各種損失標量,這些標量將被命名為“loss*”。重要的是,損失必須以這個名稱返回,因為它們將在反向傳播之前被訓練器讀取。

from_stateful_net(network_name: str, stateful_net: Module)[原始碼]

根據有狀態的網路版本填充模型的引數。

有關如何收集網路的狀態化版本,請參閱 get_stateful_net()

引數:
  • network_name (str) – 要重置的網路名稱。

  • stateful_net (nn.Module) – 應從中收集引數的狀態化網路。

property functional

模組是否功能化。

除非經過專門設計使其不具有功能性,否則所有損失都具有功能性。

get_stateful_net(network_name: str, copy: bool | None = None)[原始碼]

返回網路的狀態化版本。

這可用於初始化引數。

這些網路通常開箱即用,無法呼叫,需要呼叫 vmap 才能執行。

引數:
  • network_name (str) – 要收集的網路名稱。

  • copy (bool, optional) –

    如果為 True,則會進行網路的深複製。預設為 True

    注意

    如果模組不是函式式的,則不會進行復制。

make_value_estimator(value_type: Optional[ValueEstimators] = None, **hyperparams)[原始碼]

值函式建構函式。

如果需要非預設值函式,必須使用此方法構建。

引數:
  • value_type (ValueEstimators) – 一個 ValueEstimators 列舉型別,指示要使用的值函式。如果未提供,將使用儲存在 default_value_estimator 屬性中的預設值。生成的估值器類將註冊在 self.value_type 中,以便將來進行改進。

  • **hyperparams – 用於值函式的超引數。如果未提供,將使用 default_value_kwargs() 中指示的值。

示例

>>> from torchrl.objectives import DQNLoss
>>> # initialize the DQN loss
>>> actor = torch.nn.Linear(3, 4)
>>> dqn_loss = DQNLoss(actor, action_space="one-hot")
>>> # updating the parameters of the default value estimator
>>> dqn_loss.make_value_estimator(gamma=0.9)
>>> dqn_loss.make_value_estimator(
...     ValueEstimators.TD1,
...     gamma=0.9)
>>> # if we want to change the gamma value
>>> dqn_loss.make_value_estimator(dqn_loss.value_type, gamma=0.9)
named_parameters(prefix: str = '', recurse: bool = True) Iterator[tuple[str, torch.nn.parameter.Parameter]][原始碼]

返回模組引數的迭代器,同時生成引數的名稱和引數本身。

引數:
  • prefix (str) – 為所有引數名稱新增字首。

  • recurse (bool) – 如果為 True,則會生成此模組及其所有子模組的引數。否則,僅生成此模組直接成員的引數。

  • remove_duplicate (bool, optional) – 是否在結果中刪除重複的引數。預設為 True。

產生:

(str, Parameter) – 包含名稱和引數的元組

示例

>>> # xdoctest: +SKIP("undefined vars")
>>> for name, param in self.named_parameters():
>>>     if name in ['bias']:
>>>         print(param.size())
parameters(recurse: bool = True) Iterator[Parameter][原始碼]

返回模組引數的迭代器。

這通常傳遞給最佳化器。

引數:

recurse (bool) – 如果為 True,則會生成此模組及其所有子模組的引數。否則,僅生成此模組直接成員的引數。

產生:

Parameter – 模組引數

示例

>>> # xdoctest: +SKIP("undefined vars")
>>> for param in model.parameters():
>>>     print(type(param), param.size())
<class 'torch.Tensor'> (20L,)
<class 'torch.Tensor'> (20L, 1L, 5L, 5L)
reset_parameters_recursive()[原始碼]

重置模組的引數。

set_keys(**kwargs) None[原始碼]

設定 tensordict 鍵名。

示例

>>> from torchrl.objectives import DQNLoss
>>> # initialize the DQN loss
>>> actor = torch.nn.Linear(3, 4)
>>> dqn_loss = DQNLoss(actor, action_space="one-hot")
>>> dqn_loss.set_keys(priority_key="td_error", action_value_key="action_value")
property value_estimator: ValueEstimatorBase

價值函式將獎勵和即將到來的狀態/狀態-動作對的價值估計值融合到價值網路的目標價值估計中。

property vmap_randomness

Vmap 隨機模式。

vmap 的隨機性模式控制當處理具有隨機結果的函式(如 randn()rand())時,vmap() 應該如何執行。如果設定為 “error”,任何隨機函式都將引發異常,表明 vmap 不知道如何處理該隨機呼叫。

如果設定為 “different”,則 vmap 正在呼叫的批次中的每個元素將表現不同。如果設定為 “same”,則 vmap 會將相同的結果複製到所有元素。

vmap_randomness 預設情況下是 “error”(如果未檢測到任何隨機模組),而在其他情況下為 “different”。預設情況下,只有有限數量的模組被列為隨機模組,但可以使用 add_random_module() 函式來擴充套件此列表。

此屬性支援設定其值。

文件

訪問全面的 PyTorch 開發者文件

檢視文件

教程

為初學者和高階開發者提供深入的教程

檢視教程

資源

查詢開發資源並讓您的問題得到解答

檢視資源