DreamerModelLoss¶
- class torchrl.objectives.DreamerModelLoss(*args, **kwargs)[source]¶
Dreamer 模型損失。
計算 Dreamer 世界模型的損失。該損失由 RSSM 的先驗和後驗之間的 KL 散度、重構觀察值的重構損失以及預測獎勵的獎勵損失組成。
參考: https://arxiv.org/abs/1912.01603。
- 引數:
world_model (TensorDictModule) – 世界模型。
lambda_kl (
float, optional) – KL 散度損失的權重。預設為:1.0。lambda_reco (
float, optional) – 重構損失的權重。預設為:1.0。lambda_reward (
float, optional) – 獎勵損失的權重。預設為:1.0。reco_loss (str, optional) – 重構損失。預設為:“l2”。
reward_loss (str, optional) – 獎勵損失。預設為:“l2”。
free_nats (int, optional) – free nats。預設為:3。
delayed_clamp (bool, optional) – 如果為
True,則 KL 鉗位在平均之後進行。如果為 False(預設值),則 KL 散度首先鉗位到 free nats 值,然後進行平均。global_average (bool, optional) – 如果為
True,則損失將針對所有維度進行平均。否則,將對所有非批處理/時間維度進行求和,然後對批處理和時間進行平均。預設為:False。
- default_keys¶
別名:
_AcceptedKeys
- forward(tensordict: TensorDict) Tensor[source]¶
它旨在讀取一個輸入的 TensorDict 並返回另一個包含名為“loss*”的損失鍵的 tensordict。
將損失分解為其組成部分可以被訓練器用於在訓練過程中記錄各種損失值。輸出 tensordict 中存在的其他標量也將被記錄。
- 引數:
tensordict – 一個輸入的 tensordict,包含計算損失所需的值。
- 返回:
一個沒有批處理維度的新 tensordict,其中包含各種損失標量,這些標量將被命名為“loss*”。重要的是,損失必須以這個名稱返回,因為它們將在反向傳播之前被訓練器讀取。