快捷方式

UpdateWeights

class torchrl.trainers.UpdateWeights(collector: DataCollectorBase, update_weights_interval: int, policy_weights_getter: Callable[[Any], Any] | None = None)[原始碼]

一個收集器權重更新鉤子類。

當收集器的策略權重位於與 Trainer 正在訓練的策略權重不同的裝置上時,必須使用此鉤子。在這種情況下,這些權重必須定期同步。如果裝置匹配,則此操作將不執行任何操作。

引數:
  • collector (DataCollectorBase) – 需要同步策略權重的收集器。

  • update_weights_interval (int) – 同步必須發生的間隔(以收集的批次數計)。

示例

>>> update_weights = UpdateWeights(trainer.collector, T)
>>> trainer.register_op("post_steps", update_weights)
register(trainer: Trainer, name: str = 'update_weights')[原始碼]

Registers the hook in the trainer at a default location.

引數:
  • trainer (Trainer) – the trainer where the hook must be registered.

  • name (str) – the name of the hook.

注意

To register the hook at another location than the default, use register_op().

文件

訪問全面的 PyTorch 開發者文件

檢視文件

教程

為初學者和高階開發者提供深入的教程

檢視教程

資源

查詢開發資源並讓您的問題得到解答

檢視資源