• 文件 >
  • 使用預訓練模型
快捷方式

使用預訓練模型

本教程將解釋如何在 TorchRL 中使用預訓練模型。

import tempfile

在本教程結束時,您將能夠使用預訓練模型進行高效的影像表示,並對其進行微調。

TorchRL 提供預訓練模型,這些模型可用作變換(transforms)或策略(policy)的元件。由於語義相同,它們可以互換地用於其中一種或另一種上下文。在本教程中,我們將使用 R3M(https://arxiv.org/abs/2203.12601),但其他模型(例如 VIP)同樣有效。

import torch.cuda
from tensordict.nn import TensorDictSequential
from torch import nn
from torchrl.envs import Compose, R3MTransform, TransformedEnv
from torchrl.envs.libs.gym import GymEnv
from torchrl.modules import Actor

is_fork = multiprocessing.get_start_method() == "fork"
device = (
    torch.device(0)
    if torch.cuda.is_available() and not is_fork
    else torch.device("cpu")
)

讓我們先建立一個環境。為了簡單起見,我們將使用一個通用的 gym 環境。實際上,這在更具挑戰性的具身 AI 環境中也能正常工作(例如,看看我們的 Habitat 包裝器)。

base_env = GymEnv("Ant-v4", from_pixels=True, device=device)

讓我們獲取預訓練模型。我們透過 `download=True` 標誌來請求模型的預訓練版本。預設情況下,此選項是關閉的。接下來,我們將變換新增到環境中。實際上,發生的情況是,收集到的每個資料批次都將透過該變換,並在輸出 tensordict 的“r3m_vec”條目中對映。然後,我們的策略(由單個 MLP 層組成)將讀取此向量並計算相應的動作。

r3m = R3MTransform(
    "resnet50",
    in_keys=["pixels"],
    download=False,  # Turn to true for real-life testing
)
env_transformed = TransformedEnv(base_env, r3m)
net = nn.Sequential(
    nn.LazyLinear(128, device=device),
    nn.Tanh(),
    nn.Linear(128, base_env.action_spec.shape[-1], device=device),
)
policy = Actor(net, in_keys=["r3m_vec"])

讓我們檢查策略的引數數量

print("number of params:", len(list(policy.parameters())))

我們收集 32 步的 rollout 並列印其輸出

rollout = env_transformed.rollout(32, policy)
print("rollout with transform:", rollout)

對於微調,我們將變換整合到策略中,並將引數設定為可訓練。實際上,限制在引數的子集(例如 MLP 的最後一層)上可能會更明智。

r3m.train()
policy = TensorDictSequential(r3m, policy)
print("number of params after r3m is integrated:", len(list(policy.parameters())))

再次,我們使用 R3M 收集一個 rollout。輸出的結構略有變化,因為現在環境返回的是畫素(而不是嵌入)。嵌入“r3m_vec”是策略的中間結果。

rollout = base_env.rollout(32, policy)
print("rollout, fine tuning:", rollout)

我們將變換從環境切換到策略的輕鬆程度,得益於兩者都表現得像 `TensorDictModule`:它們都有一個 `“in_keys”` 和 `“out_keys”` 集合,使得在不同上下文中輕鬆讀寫輸出成為可能。

為了結束本教程,讓我們看看如何使用 R3M 讀取儲存在回放緩衝區中的影像(例如,在離線 RL 上下文中)。首先,讓我們構建我們的資料集

from torchrl.data import LazyMemmapStorage, ReplayBuffer

buffer_scratch_dir = tempfile.TemporaryDirectory().name
storage = LazyMemmapStorage(1000, scratch_dir=buffer_scratch_dir)
rb = ReplayBuffer(storage=storage, transform=Compose(lambda td: td.to(device), r3m))

我們現在可以收集資料(為了我們的目的,是隨機 rollout)並用它來填充回放緩衝區

total = 0
while total < 1000:
    tensordict = base_env.rollout(1000)
    rb.extend(tensordict)
    total += tensordict.numel()

讓我們檢查一下我們的回放緩衝區儲存是什麼樣的。它不應該包含“r3m_vec”條目,因為我們還沒有使用它

print("stored data:", storage._storage)

在取樣時,資料將透過 R3M 變換,為我們提供所需的處理後的資料。透過這種方式,我們可以使用由影像組成的資料集對演算法進行離線訓練

batch = rb.sample(32)
print("data after sampling:", batch)

由 Sphinx-Gallery 生成的畫廊

文件

訪問全面的 PyTorch 開發者文件

檢視文件

教程

為初學者和高階開發者提供深入的教程

檢視教程

資源

查詢開發資源並讓您的問題得到解答

檢視資源