快捷方式

ReplayBufferTrainer

class torchrl.trainers.ReplayBufferTrainer(replay_buffer: TensorDictReplayBuffer, batch_size: int | None = None, memmap: bool = False, device: DEVICE_TYPING | None = None, flatten_tensordicts: bool = False, max_dims: Sequence[int] | None = None, iterate: bool = False)[原始碼]

回放緩衝區鉤子提供程式。

引數:
  • replay_buffer (TensorDictReplayBuffer) – 要使用的回放緩衝區。

  • batch_size (int, optional) – 從最新收集或從回放緩衝區取樣資料時的批次大小。如果未提供,則將使用回放緩衝區的批次大小(對於未更改的批次大小,這是首選選項)。

  • memmap (bool, optional) – 如果為 True,則建立 memmap tensordict。預設為 False

  • device (device, optional) – 必須放置樣本的裝置。預設為 None

  • flatten_tensordicts (bool, optional) – 如果為 True,則 tensordicts 將被展平(或等效地使用從收集器獲得的有效掩碼進行掩碼),然後傳遞給回放緩衝區。否則,除了填充外,不會進行其他轉換(請參閱下面的 max_dims 引數)。預設為 False

  • max_dims (sequence of int, optional) – 如果 flatten_tensordicts 設定為 False,這將是一個列表,其長度為提供的 tensordicts 的 batch_size,表示每個 tensordict 的最大大小。如果提供,此大小列表將用於填充 tensordict 並使其形狀匹配,然後將它們傳遞給回放緩衝區。如果沒有最大值,應提供 -1 值。

  • iterate (bool, optional) – 如果為 True,則回放緩衝區將迴圈迭代。預設為 False(將呼叫 sample())。

示例

>>> rb_trainer = ReplayBufferTrainer(replay_buffer=replay_buffer, batch_size=N)
>>> trainer.register_op("batch_process", rb_trainer.extend)
>>> trainer.register_op("process_optim_batch", rb_trainer.sample)
>>> trainer.register_op("post_loss", rb_trainer.update_priority)
register(trainer: Trainer, name: str = 'replay_buffer')[原始碼]

Registers the hook in the trainer at a default location.

引數:
  • trainer (Trainer) – the trainer where the hook must be registered.

  • name (str) – the name of the hook.

注意

To register the hook at another location than the default, use register_op().

文件

訪問全面的 PyTorch 開發者文件

檢視文件

教程

為初學者和高階開發者提供深入的教程

檢視教程

資源

查詢開發資源並讓您的問題得到解答

檢視資源