UpdateWeights¶
- class torchrl.trainers.UpdateWeights(collector: DataCollectorBase, update_weights_interval: int, policy_weights_getter: Callable[[Any], Any] | None = None)[原始碼]¶
一個收集器權重更新鉤子類。
當收集器的策略權重位於與 Trainer 正在訓練的策略權重不同的裝置上時,必須使用此鉤子。在這種情況下,這些權重必須定期同步。如果裝置匹配,則此操作將不執行任何操作。
- 引數:
collector (DataCollectorBase) – 需要同步策略權重的收集器。
update_weights_interval (int) – 同步必須發生的間隔(以收集的批次數計)。
示例
>>> update_weights = UpdateWeights(trainer.collector, T) >>> trainer.register_op("post_steps", update_weights)
- register(trainer: Trainer, name: str = 'update_weights')[原始碼]¶
Registers the hook in the trainer at a default location.
- 引數:
trainer (Trainer) – the trainer where the hook must be registered.
name (str) – the name of the hook.
注意
To register the hook at another location than the default, use
register_op().