GAILLoss¶
- class torchrl.objectives.GAILLoss(*args, **kwargs)[source]¶
TorchRL 實現的生成對抗模仿學習 (GAIL) 損失。
在 “Generative Adversarial Imitation Learning” <https://arxiv.org/pdf/1606.03476> 中提出
- 引數:
discriminator_network (TensorDictModule) – 隨機策略
- 關鍵字引數:
use_grad_penalty (bool, optional) – 是否使用梯度懲罰。預設值:
False。gp_lambda (
float, optional) – 梯度懲罰 lambda。預設值:10。reduction (str, optional) – 指定應用於輸出的約簡:
"none"|"mean"|"sum"。"none":不應用約簡,"mean":輸出的總和將除以輸出中的元素數量,"sum":將對輸出進行求和。預設為"mean"。
- default_keys¶
別名:
_AcceptedKeys