快捷方式

GAILLoss

class torchrl.objectives.GAILLoss(*args, **kwargs)[source]

TorchRL 實現的生成對抗模仿學習 (GAIL) 損失。

“Generative Adversarial Imitation Learning” <https://arxiv.org/pdf/1606.03476> 中提出

引數:

discriminator_network (TensorDictModule) – 隨機策略

關鍵字引數:
  • use_grad_penalty (bool, optional) – 是否使用梯度懲罰。預設值:False

  • gp_lambda (float, optional) – 梯度懲罰 lambda。預設值:10

  • reduction (str, optional) – 指定應用於輸出的約簡:"none" | "mean" | "sum""none":不應用約簡,"mean":輸出的總和將除以輸出中的元素數量,"sum":將對輸出進行求和。預設為 "mean"

default_keys

別名:_AcceptedKeys

forward(tensordict: TensorDictBase = None) TensorDictBase[source]

forward 方法。

計算判別器損失和梯度懲罰(如果 use_grad_penalty 設定為 True)。如果 use_grad_penalty 設定為 True,還將返回解耦的梯度懲罰損失以用於日誌記錄。要檢視輸入 tensordict 所期望的鍵以及輸出所期望的鍵,請檢視類的 “in_keys”“out_keys” 屬性。

文件

訪問全面的 PyTorch 開發者文件

檢視文件

教程

為初學者和高階開發者提供深入的教程

檢視教程

資源

查詢開發資源並讓您的問題得到解答

檢視資源