快捷方式

TorchRL 簡介

此演示已在 ICML 2022 的行業演示日上展出。

它對 TorchRL 的功能進行了良好的概述。如果您有關於此演示的問題或評論,請隨時聯絡 vmoens@fb.com 或提交 issue。

TorchRL 是一個用於 PyTorch 的開源強化學習 (RL) 庫。

https://github.com/pytorch/rl

PyTorch 生態系統團隊 (Meta) 已決定投資此庫,以提供一個領先的平臺來開發研究環境中的 RL 解決方案。

它提供 PyTorch 和 **以 Python 為優先** 的低階和高階 **抽象** #,旨在高效、文件齊全且經過適當測試。程式碼旨在支援 RL 研究。其中大部分是用高度模組化的 Python 編寫的,以便研究人員可以輕鬆地交換元件、轉換它們或輕鬆編寫新元件。

此倉庫試圖與現有的 PyTorch 生態系統庫保持一致,因為它有一個數據集支柱 (torchrl/envs)、轉換、模型、資料實用程式 (例如收集器和容器) 等。TorchRL 旨在擁有儘可能少的依賴項 (Python 標準庫、numpy 和 pytorch)。常見的環境庫 (例如 OpenAI gym) 僅為可選。

內容:
../_images/aafig-7040c2adc5d51442b6d7aebd109ebfeb53d8fc90.svg

與許多其他領域不同,RL 更側重於 *演算法* 而非媒體。因此,很難建立真正獨立的元件。

TorchRL 不是什麼

  • 演算法集合:我們不打算提供 SOTA 的 RL 演算法實現,但我們僅提供這些演算法作為如何使用該庫的示例。

  • 研究框架:TorchRL 的模組化有兩種形式。首先,我們嘗試構建可重用元件,以便可以輕鬆地將它們相互交換。其次,我們盡最大努力確保元件可以獨立於庫的其他部分使用。

TorchRL 的核心依賴項非常少,主要是 PyTorch 和 numpy。所有其他依賴項 (gym、torchvision、wandb / tensorboard) 都是可選的。

資料

TensorDict

import torch
from tensordict import TensorDict

讓我們建立一個 TensorDict。建構函式接受許多不同的格式,例如透過字典或關鍵字引數傳遞

batch_size = 5
data = TensorDict(
    key1=torch.zeros(batch_size, 3),
    key2=torch.zeros(batch_size, 5, 6, dtype=torch.bool),
    batch_size=[batch_size],
)
print(data)

您可以沿其 batch_size 索引 TensorDict,還可以查詢鍵。

print(data[2])
print(data["key1"] is data.get("key1"))

以下展示瞭如何堆疊多個 TensorDict。在編寫 rollout 迴圈時,這尤其有用!

data1 = TensorDict(
    {
        "key1": torch.zeros(batch_size, 1),
        "key2": torch.zeros(batch_size, 5, 6, dtype=torch.bool),
    },
    batch_size=[batch_size],
)

data2 = TensorDict(
    {
        "key1": torch.ones(batch_size, 1),
        "key2": torch.ones(batch_size, 5, 6, dtype=torch.bool),
    },
    batch_size=[batch_size],
)

data = torch.stack([data1, data2], 0)
data.batch_size, data["key1"]

這裡有一些 TensorDict 的其他功能:檢視、置換、共享記憶體或展開。

print(
    "view(-1): ",
    data.view(-1).batch_size,
    data.view(-1).get("key1").shape,
)

print("to device: ", data.to("cpu"))

# print("pin_memory: ", data.pin_memory())

print("share memory: ", data.share_memory_())

print(
    "permute(1, 0): ",
    data.permute(1, 0).batch_size,
    data.permute(1, 0).get("key1").shape,
)

print(
    "expand: ",
    data.expand(3, *data.batch_size).batch_size,
    data.expand(3, *data.batch_size).get("key1").shape,
)

您也可以建立 **巢狀資料**。

data = TensorDict(
    source={
        "key1": torch.zeros(batch_size, 3),
        "key2": TensorDict(
            source={"sub_key1": torch.zeros(batch_size, 2, 1)},
            batch_size=[batch_size, 2],
        ),
    },
    batch_size=[batch_size],
)
data

回放緩衝區

回放緩衝區 是許多 RL 演算法中的關鍵組成部分。TorchRL 提供了一系列回放緩衝區實現。大多數基本功能將適用於任何資料結構 (列表、元組、字典),但要充分利用回放緩衝區並實現快速的讀寫訪問,應優先使用 TensorDict API。

from torchrl.data import PrioritizedReplayBuffer, ReplayBuffer

rb = ReplayBuffer(collate_fn=lambda x: x)

新增可以使用 add() (n=1) 或 extend() (n>1) 完成。

rb.add(1)
rb.sample(1)
rb.extend([2, 3])
rb.sample(3)

也可以使用優先順序回放緩衝區

rb = PrioritizedReplayBuffer(alpha=0.7, beta=1.1, collate_fn=lambda x: x)
rb.add(1)
rb.sample(1)
rb.update_priority(1, 0.5)

這裡有一些使用 replaybuffer 與 data_stack 的示例。使用它們可以輕鬆地為多種用例抽象回放緩衝區的行為。

collate_fn = torch.stack
rb = ReplayBuffer(collate_fn=collate_fn)
rb.add(TensorDict({"a": torch.randn(3)}, batch_size=[]))
len(rb)

rb.extend(TensorDict({"a": torch.randn(2, 3)}, batch_size=[2]))
print(len(rb))
print(rb.sample(10))
print(rb.sample(2).contiguous())

torch.manual_seed(0)
from torchrl.data import TensorDictPrioritizedReplayBuffer

rb = TensorDictPrioritizedReplayBuffer(alpha=0.7, beta=1.1, priority_key="td_error")
rb.extend(TensorDict({"a": torch.randn(2, 3)}, batch_size=[2]))
data_sample = rb.sample(2).contiguous()
print(data_sample)

print(data_sample["index"])

data_sample["td_error"] = torch.rand(2)
rb.update_tensordict_priority(data_sample)

for i, val in enumerate(rb._sampler._sum_tree):
    print(i, val)
    if i == len(rb):
        break

環境

TorchRL 提供了一系列 環境 包裝器和實用程式。

Gym 環境

try:
    import gymnasium as gym
except ModuleNotFoundError:
    import gym

from torchrl.envs.libs.gym import GymEnv, GymWrapper, set_gym_backend

gym_env = gym.make("Pendulum-v1")
env = GymWrapper(gym_env)
env = GymEnv("Pendulum-v1")

data = env.reset()
env.rand_step(data)

更改環境配置

env = GymEnv("Pendulum-v1", frame_skip=3, from_pixels=True, pixels_only=False)
env.reset()

env.close()
del env

from torchrl.envs import (
    Compose,
    NoopResetEnv,
    ObservationNorm,
    ToTensorImage,
    TransformedEnv,
)

base_env = GymEnv("Pendulum-v1", frame_skip=3, from_pixels=True, pixels_only=False)
env = TransformedEnv(base_env, Compose(NoopResetEnv(3), ToTensorImage()))
env.append_transform(ObservationNorm(in_keys=["pixels"], loc=2, scale=1))

環境轉換

轉換類似於 Gym 包裝器,但 API 更接近 torchvision 的 torch.distributions 轉換。有多種 轉換 可供選擇。

from torchrl.envs import (
    Compose,
    NoopResetEnv,
    ObservationNorm,
    StepCounter,
    ToTensorImage,
    TransformedEnv,
)

base_env = GymEnv("HalfCheetah-v4", frame_skip=3, from_pixels=True, pixels_only=False)
env = TransformedEnv(base_env, Compose(NoopResetEnv(3), ToTensorImage()))
env = env.append_transform(ObservationNorm(in_keys=["pixels"], loc=2, scale=1))

env.reset()

print("env: ", env)
print("last transform parent: ", env.transform[2].parent)

向量化環境

向量化/並行環境可以提供顯著的速度提升。

from torchrl.envs import ParallelEnv


def make_env():
    # You can control whether to use gym or gymnasium for your env
    with set_gym_backend("gym"):
        return GymEnv("Pendulum-v1", frame_skip=3, from_pixels=True, pixels_only=False)


base_env = ParallelEnv(
    4,
    make_env,
    mp_start_method="fork",  # This will break on Windows machines! Remove and decorate with if __name__ == "__main__"
)
env = TransformedEnv(
    base_env, Compose(StepCounter(), ToTensorImage())
)  # applies transforms on batch of envs
env.append_transform(ObservationNorm(in_keys=["pixels"], loc=2, scale=1))
env.reset()

print(env.action_spec)

env.close()
del env

模組

庫中可以找到多個 模組 (實用程式、模型和包裝器)。

模型

MLP 模型示例

from torch import nn
from torchrl.modules import ConvNet, MLP
from torchrl.modules.models.utils import SquashDims

net = MLP(num_cells=[32, 64], out_features=4, activation_class=nn.ELU)
print(net)
print(net(torch.randn(10, 3)).shape)

CNN 模型示例

cnn = ConvNet(
    num_cells=[32, 64],
    kernel_sizes=[8, 4],
    strides=[2, 1],
    aggregator_class=SquashDims,
)
print(cnn)
print(cnn(torch.randn(10, 3, 32, 32)).shape)  # last tensor is squashed

TensorDictModules

一些模組 專門設計用於處理 tensordict 輸入。

from tensordict.nn import TensorDictModule

data = TensorDict({"key1": torch.randn(10, 3)}, batch_size=[10])
module = nn.Linear(3, 4)
td_module = TensorDictModule(module, in_keys=["key1"], out_keys=["key2"])
td_module(data)
print(data)

模組序列

透過 TensorDictSequential,可以輕鬆建立模組序列。

from tensordict.nn import TensorDictSequential

backbone_module = nn.Linear(5, 3)
backbone = TensorDictModule(
    backbone_module, in_keys=["observation"], out_keys=["hidden"]
)
actor_module = nn.Linear(3, 4)
actor = TensorDictModule(actor_module, in_keys=["hidden"], out_keys=["action"])
value_module = MLP(out_features=1, num_cells=[4, 5])
value = TensorDictModule(value_module, in_keys=["hidden", "action"], out_keys=["value"])

sequence = TensorDictSequential(backbone, actor, value)
print(sequence)

print(sequence.in_keys, sequence.out_keys)

data = TensorDict(
    {"observation": torch.randn(3, 5)},
    [3],
)
backbone(data)
actor(data)
value(data)

data = TensorDict(
    {"observation": torch.randn(3, 5)},
    [3],
)
sequence(data)
print(data)

函數語言程式設計 (整合 / Meta-RL)

函式式呼叫從未如此簡單。使用 from_module() 提取引數,並使用 to_module() 替換它們。

from tensordict import from_module

params = from_module(sequence)
print("extracted params", params)

使用 tensordict 進行函式式呼叫

with params.to_module(sequence):
    data = sequence(data)

VMAP

快速執行相似架構的多個副本對於快速訓練模型至關重要。vmap() 正是為了實現這一點而量身定製的。

from torch import vmap

params_expand = params.expand(4)


def exec_sequence(params, data):
    with params.to_module(sequence):
        return sequence(data)


tensordict_exp = vmap(exec_sequence, (0, None))(params_expand, data)
print(tensordict_exp)

專用類

TorchRL 還提供了一些對輸出值進行檢查的專用模組。

torch.manual_seed(0)
from torchrl.data import Bounded
from torchrl.modules import SafeModule

spec = Bounded(-torch.ones(3), torch.ones(3))
base_module = nn.Linear(5, 3)
module = SafeModule(
    module=base_module, spec=spec, in_keys=["obs"], out_keys=["action"], safe=True
)
data = TensorDict({"obs": torch.randn(5)}, batch_size=[])
module(data)["action"]

data = TensorDict({"obs": torch.randn(5) * 100}, batch_size=[])
module(data)["action"]  # safe=True projects the result within the set

Actor 類具有預定義的輸出鍵 ("action")。

from torchrl.modules import Actor

base_module = nn.Linear(5, 3)
actor = Actor(base_module, in_keys=["obs"])
data = TensorDict({"obs": torch.randn(5)}, batch_size=[])
actor(data)  # action is the default value

from tensordict.nn import (
    ProbabilisticTensorDictModule,
    ProbabilisticTensorDictSequential,
)

藉助 tensordict.nn API,使用機率模型也變得更加容易。

from torchrl.modules import NormalParamExtractor, TanhNormal

td = TensorDict({"input": torch.randn(3, 5)}, [3])
net = nn.Sequential(
    nn.Linear(5, 4), NormalParamExtractor()
)  # splits the output in loc and scale
module = TensorDictModule(net, in_keys=["input"], out_keys=["loc", "scale"])
td_module = ProbabilisticTensorDictSequential(
    module,
    ProbabilisticTensorDictModule(
        in_keys=["loc", "scale"],
        out_keys=["action"],
        distribution_class=TanhNormal,
        return_log_prob=False,
    ),
)
td_module(td)
print(td)
# returning the log-probability
td = TensorDict({"input": torch.randn(3, 5)}, [3])
td_module = ProbabilisticTensorDictSequential(
    module,
    ProbabilisticTensorDictModule(
        in_keys=["loc", "scale"],
        out_keys=["action"],
        distribution_class=TanhNormal,
        return_log_prob=True,
    ),
)
td_module(td)
print(td)

透過上下文管理器 set_exploration_type 可以實現對隨機性和取樣策略的控制。

from torchrl.envs.utils import ExplorationType, set_exploration_type

td = TensorDict({"input": torch.randn(3, 5)}, [3])

torch.manual_seed(0)
with set_exploration_type(ExplorationType.RANDOM):
    td_module(td)
    print("random:", td["action"])

with set_exploration_type(ExplorationType.DETERMINISTIC):
    td_module(td)
    print("mode:", td["action"])

使用環境和模組

讓我們看看如何結合使用環境和模組。

from torchrl.envs.utils import step_mdp

env = GymEnv("Pendulum-v1")

action_spec = env.action_spec
actor_module = nn.Linear(3, 1)
actor = SafeModule(
    actor_module, spec=action_spec, in_keys=["observation"], out_keys=["action"]
)

torch.manual_seed(0)
env.set_seed(0)

max_steps = 100
data = env.reset()
data_stack = TensorDict(batch_size=[max_steps])
for i in range(max_steps):
    actor(data)
    data_stack[i] = env.step(data)
    if data["done"].any():
        break
    data = step_mdp(data)  # roughly equivalent to obs = next_obs

tensordicts_prealloc = data_stack.clone()
print("total steps:", i)
print(data_stack)
# equivalent
torch.manual_seed(0)
env.set_seed(0)

max_steps = 100
data = env.reset()
data_stack = []
for _ in range(max_steps):
    actor(data)
    data_stack.append(env.step(data))
    if data["done"].any():
        break
    data = step_mdp(data)  # roughly equivalent to obs = next_obs
tensordicts_stack = torch.stack(data_stack, 0)
print("total steps:", i)
print(tensordicts_stack)
(tensordicts_stack == tensordicts_prealloc).all()
torch.manual_seed(0)
env.set_seed(0)
tensordict_rollout = env.rollout(policy=actor, max_steps=max_steps)
tensordict_rollout


(tensordict_rollout == tensordicts_prealloc).all()

from tensordict.nn import TensorDictModule

收集器

我們還提供了一套 資料收集器,它們可以自動收集每批所需數量的幀。它們適用於從單節點、單工作程序到多節點、多工作程序的各種設定。

from torchrl.collectors import MultiaSyncDataCollector, MultiSyncDataCollector

from torchrl.envs import EnvCreator, SerialEnv
from torchrl.envs.libs.gym import GymEnv

EnvCreator 確保我們可以從一個程序將 lambda 函式傳送到另一個程序。我們使用 SerialEnv 以簡化 (單工作程序),但對於較大的任務,ParallelEnv (多工作程序) 會更合適。

注意

多程序環境和多程序收集器可以結合使用!

parallel_env = SerialEnv(
    3,
    EnvCreator(lambda: GymEnv("Pendulum-v1")),
)
create_env_fn = [parallel_env, parallel_env]

actor_module = nn.Linear(3, 1)
actor = TensorDictModule(actor_module, in_keys=["observation"], out_keys=["action"])

同步多程序資料收集器

devices = ["cpu", "cpu"]

collector = MultiSyncDataCollector(
    create_env_fn=create_env_fn,  # either a list of functions or a ParallelEnv
    policy=actor,
    total_frames=240,
    max_frames_per_traj=-1,  # envs are terminating, we don't need to stop them early
    frames_per_batch=60,  # we want 60 frames at a time (we have 3 envs per sub-collector)
    device=devices,
)
for i, d in enumerate(collector):
    if i == 0:
        print(d)  # trajectories are split automatically in [6 workers x 10 steps]
    collector.update_policy_weights_()  # make sure that our policies have the latest weights if working on multiple devices
print(i)
collector.shutdown()
del collector

非同步多程序資料收集器

此類允許您在模型訓練時收集資料。這在離策略設定中尤其有用,因為它將推理與模型訓練解耦。資料以先到先得的方式交付 (工作程序將排隊等待其結果)。

collector = MultiaSyncDataCollector(
    create_env_fn=create_env_fn,  # either a list of functions or a ParallelEnv
    policy=actor,
    total_frames=240,
    max_frames_per_traj=-1,  # envs are terminating, we don't need to stop them early
    frames_per_batch=60,  # we want 60 frames at a time (we have 3 envs per sub-collector)
    device=devices,
)

for i, d in enumerate(collector):
    if i == 0:
        print(d)  # trajectories are split automatically in [6 workers x 10 steps]
    collector.update_policy_weights_()  # make sure that our policies have the latest weights if working on multiple devices
print(i)
collector.shutdown()
del collector
del create_env_fn
del parallel_env

目標

目標 是編寫新演算法時的主要入口點。

from torchrl.objectives import DDPGLoss

actor_module = nn.Linear(3, 1)
actor = TensorDictModule(actor_module, in_keys=["observation"], out_keys=["action"])


class ConcatModule(nn.Linear):
    def forward(self, obs, action):
        return super().forward(torch.cat([obs, action], -1))


value_module = ConcatModule(4, 1)
value = TensorDictModule(
    value_module, in_keys=["observation", "action"], out_keys=["state_action_value"]
)

loss_fn = DDPGLoss(actor, value)
loss_fn.make_value_estimator(loss_fn.default_value_estimator, gamma=0.99)
data = TensorDict(
    {
        "observation": torch.randn(10, 3),
        "next": {
            "observation": torch.randn(10, 3),
            "reward": torch.randn(10, 1),
            "done": torch.zeros(10, 1, dtype=torch.bool),
        },
        "action": torch.randn(10, 1),
    },
    batch_size=[10],
    device="cpu",
)
loss_td = loss_fn(data)

print(loss_td)

print(data)

安裝庫

該庫已在 PyPI 上釋出:pip install torchrl 更多資訊請參閱 README

貢獻

我們正在積極尋找貢獻者和早期使用者。如果您正在從事 RL 工作 (或者只是好奇),請嘗試一下!給我們反饋:TorchRL 的成功取決於它在多大程度上滿足研究人員的需求。為此,我們需要他們的投入!由於該庫尚處於起步階段,現在是塑造它的絕佳時機!

更多資訊請參閱 貢獻指南

由 Sphinx-Gallery 生成的畫廊

文件

訪問全面的 PyTorch 開發者文件

檢視文件

教程

為初學者和高階開發者提供深入的教程

檢視教程

資源

查詢開發資源並讓您的問題得到解答

檢視資源