Note
Jump to the end to download the full example code.
Using TensorDict to handle datasets
In this tutorial we demonstrate how TensorDict can be used to efficiently and transparently load and manage data inside a training pipeline. The tutorial is based heavily on the PyTorch Quickstart tutorial, modified to show the use of TensorDict.
import torch
import torch.nn as nn
from tensordict import MemoryMappedTensor, TensorDict
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
Using device: cpu
The torchvision.datasets module contains a number of convenient, pre-prepared datasets. In this tutorial we use the relatively simple FashionMNIST dataset. Each image is an item of clothing, and the goal is to classify the type of clothing in the image (e.g. "Bag", "Sneaker", etc.).
training_data = datasets.FashionMNIST(
root="data",
train=True,
download=True,
transform=ToTensor(),
)
test_data = datasets.FashionMNIST(
root="data",
train=False,
download=True,
transform=ToTensor(),
)
We create two tensordicts, one each for the training and test data, backed by memory-mapped tensors. This lets us efficiently load batches of transformed data from disk, rather than repeatedly loading and transforming individual images.
First we create the MemoryMappedTensor containers.
training_data_td = TensorDict(
{
"images": MemoryMappedTensor.empty(
(len(training_data), *training_data[0][0].squeeze().shape),
dtype=torch.float32,
),
"targets": MemoryMappedTensor.empty((len(training_data),), dtype=torch.int64),
},
batch_size=[len(training_data)],
device=device,
)
test_data_td = TensorDict(
{
"images": MemoryMappedTensor.empty(
(len(test_data), *test_data[0][0].squeeze().shape), dtype=torch.float32
),
"targets": MemoryMappedTensor.empty((len(test_data),), dtype=torch.int64),
},
batch_size=[len(test_data)],
device=device,
)
We can then iterate over the data to populate the memory-mapped tensors. This takes a little time, but performing the transforms up front saves repeated effort during training later.
DataLoaders
We create DataLoaders from the torchvision-provided Datasets, as well as from our memory-mapped TensorDicts.
Since TensorDict implements __len__ and __getitem__ (as well as __getitems__), we can use it like a map-style Dataset and create a DataLoader directly from it. Note that because TensorDict can already handle batched indices, there is no need for collation, so we pass the identity function as collate_fn.
batch_size = 64
train_dataloader = DataLoader(training_data, batch_size=batch_size) # noqa: TOR401
test_dataloader = DataLoader(test_data, batch_size=batch_size) # noqa: TOR401
train_dataloader_td = DataLoader( # noqa: TOR401
training_data_td, batch_size=batch_size, collate_fn=lambda x: x
)
test_dataloader_td = DataLoader( # noqa: TOR401
test_data_td, batch_size=batch_size, collate_fn=lambda x: x
)
Model
We use the same model as in the Quickstart tutorial.
class Net(nn.Module):
def __init__(self):
super().__init__()
self.flatten = nn.Flatten()
self.linear_relu_stack = nn.Sequential(
nn.Linear(28 * 28, 512),
nn.ReLU(),
nn.Linear(512, 512),
nn.ReLU(),
nn.Linear(512, 10),
)
def forward(self, x):
x = self.flatten(x)
logits = self.linear_relu_stack(x)
return logits
model = Net().to(device)
model_td = Net().to(device)
model, model_td
(Net(
(flatten): Flatten(start_dim=1, end_dim=-1)
(linear_relu_stack): Sequential(
(0): Linear(in_features=784, out_features=512, bias=True)
(1): ReLU()
(2): Linear(in_features=512, out_features=512, bias=True)
(3): ReLU()
(4): Linear(in_features=512, out_features=10, bias=True)
)
), Net(
(flatten): Flatten(start_dim=1, end_dim=-1)
(linear_relu_stack): Sequential(
(0): Linear(in_features=784, out_features=512, bias=True)
(1): ReLU()
(2): Linear(in_features=512, out_features=512, bias=True)
(3): ReLU()
(4): Linear(in_features=512, out_features=10, bias=True)
)
))
Optimizing parameters
We optimize the model's parameters using stochastic gradient descent and cross-entropy loss.
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
optimizer_td = torch.optim.SGD(model_td.parameters(), lr=1e-3)
def train(dataloader, model, loss_fn, optimizer):
size = len(dataloader.dataset)
model.train()
for batch, (X, y) in enumerate(dataloader):
X, y = X.to(device), y.to(device)
pred = model(X)
loss = loss_fn(pred, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
if batch % 100 == 0:
loss, current = loss.item(), batch * len(X)
print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")
The training loop for the TensorDict-based DataLoader is very similar; we just adjust how the data is unpacked, using the more explicit key-based retrieval offered by TensorDict. The `.contiguous()` method loads the data stored in the memmap tensors.
def train_td(dataloader, model, loss_fn, optimizer):
size = len(dataloader.dataset)
model.train()
for batch, data in enumerate(dataloader):
X, y = data["images"].contiguous(), data["targets"].contiguous()
pred = model(X)
loss = loss_fn(pred, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
if batch % 100 == 0:
loss, current = loss.item(), batch * len(X)
print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")
def test(dataloader, model, loss_fn):
size = len(dataloader.dataset)
num_batches = len(dataloader)
model.eval()
test_loss, correct = 0, 0
with torch.no_grad():
for X, y in dataloader:
X, y = X.to(device), y.to(device)
pred = model(X)
test_loss += loss_fn(pred, y).item()
correct += (pred.argmax(1) == y).type(torch.float).sum().item()
test_loss /= num_batches
correct /= size
print(
f"Test Error: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f} \n"
)
def test_td(dataloader, model, loss_fn):
size = len(dataloader.dataset)
num_batches = len(dataloader)
model.eval()
test_loss, correct = 0, 0
with torch.no_grad():
for batch in dataloader:
X, y = batch["images"].contiguous(), batch["targets"].contiguous()
pred = model(X)
test_loss += loss_fn(pred, y).item()
correct += (pred.argmax(1) == y).type(torch.float).sum().item()
test_loss /= num_batches
correct /= size
print(
f"Test Error: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f} \n"
)
for d in train_dataloader_td:
print(d)
break
import time
t0 = time.time()
epochs = 5
for t in range(epochs):
print(f"Epoch {t + 1}\n-------------------------")
train_td(train_dataloader_td, model_td, loss_fn, optimizer_td)
test_td(test_dataloader_td, model_td, loss_fn)
print(f"TensorDict training done! time: {time.time() - t0: 4.4f} s")
t0 = time.time()
epochs = 5
for t in range(epochs):
print(f"Epoch {t + 1}\n-------------------------")
train(train_dataloader, model, loss_fn, optimizer)
test(test_dataloader, model, loss_fn)
print(f"Training done! time: {time.time() - t0: 4.4f} s")
TensorDict(
fields={
images: Tensor(shape=torch.Size([64, 28, 28]), device=cpu, dtype=torch.float32, is_shared=False),
targets: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.int64, is_shared=False)},
batch_size=torch.Size([64]),
device=cpu,
is_shared=False)
Epoch 1
-------------------------
loss: 2.295151 [ 0/60000]
loss: 2.287965 [ 6400/60000]
loss: 2.263263 [12800/60000]
loss: 2.266189 [19200/60000]
loss: 2.250629 [25600/60000]
loss: 2.211262 [32000/60000]
loss: 2.238146 [38400/60000]
loss: 2.191725 [44800/60000]
loss: 2.188123 [51200/60000]
loss: 2.170130 [57600/60000]
Test Error:
Accuracy: 39.0%, Avg loss: 2.150614
Epoch 2
-------------------------
loss: 2.159945 [ 0/60000]
loss: 2.154417 [ 6400/60000]
loss: 2.091824 [12800/60000]
loss: 2.112408 [19200/60000]
loss: 2.068090 [25600/60000]
loss: 1.999538 [32000/60000]
loss: 2.045922 [38400/60000]
loss: 1.959862 [44800/60000]
loss: 1.966230 [51200/60000]
loss: 1.905551 [57600/60000]
Test Error:
Accuracy: 56.3%, Avg loss: 1.887337
Epoch 3
-------------------------
loss: 1.927120 [ 0/60000]
loss: 1.896457 [ 6400/60000]
loss: 1.773539 [12800/60000]
loss: 1.814664 [19200/60000]
loss: 1.712839 [25600/60000]
loss: 1.662633 [32000/60000]
loss: 1.699288 [38400/60000]
loss: 1.595371 [44800/60000]
loss: 1.624663 [51200/60000]
loss: 1.525723 [57600/60000]
Test Error:
Accuracy: 60.1%, Avg loss: 1.524303
Epoch 4
-------------------------
loss: 1.603213 [ 0/60000]
loss: 1.561387 [ 6400/60000]
loss: 1.406400 [12800/60000]
loss: 1.476314 [19200/60000]
loss: 1.357771 [25600/60000]
loss: 1.358757 [32000/60000]
loss: 1.380468 [38400/60000]
loss: 1.304820 [44800/60000]
loss: 1.337895 [51200/60000]
loss: 1.244839 [57600/60000]
Test Error:
Accuracy: 63.3%, Avg loss: 1.256509
Epoch 5
-------------------------
loss: 1.343973 [ 0/60000]
loss: 1.319552 [ 6400/60000]
loss: 1.153267 [12800/60000]
loss: 1.254595 [19200/60000]
loss: 1.123201 [25600/60000]
loss: 1.161435 [32000/60000]
loss: 1.184613 [38400/60000]
loss: 1.125289 [44800/60000]
loss: 1.156805 [51200/60000]
loss: 1.084467 [57600/60000]
Test Error:
Accuracy: 64.9%, Avg loss: 1.091612
TensorDict training done! time: 8.4945 s
Epoch 1
-------------------------
loss: 2.299966 [ 0/60000]
loss: 2.291062 [ 6400/60000]
loss: 2.265493 [12800/60000]
loss: 2.273356 [19200/60000]
loss: 2.247992 [25600/60000]
loss: 2.214662 [32000/60000]
loss: 2.228931 [38400/60000]
loss: 2.185137 [44800/60000]
loss: 2.188732 [51200/60000]
loss: 2.170628 [57600/60000]
Test Error:
Accuracy: 42.6%, Avg loss: 2.149621
Epoch 2
-------------------------
loss: 2.152856 [ 0/60000]
loss: 2.150230 [ 6400/60000]
loss: 2.082802 [12800/60000]
loss: 2.113469 [19200/60000]
loss: 2.062010 [25600/60000]
loss: 1.995835 [32000/60000]
loss: 2.027980 [38400/60000]
loss: 1.938653 [44800/60000]
loss: 1.948907 [51200/60000]
loss: 1.899682 [57600/60000]
Test Error:
Accuracy: 54.9%, Avg loss: 1.875343
Epoch 3
-------------------------
loss: 1.899561 [ 0/60000]
loss: 1.883063 [ 6400/60000]
loss: 1.748965 [12800/60000]
loss: 1.804443 [19200/60000]
loss: 1.698108 [25600/60000]
loss: 1.639669 [32000/60000]
loss: 1.662723 [38400/60000]
loss: 1.552907 [44800/60000]
loss: 1.583105 [51200/60000]
loss: 1.501382 [57600/60000]
Test Error:
Accuracy: 59.5%, Avg loss: 1.500250
Epoch 4
-------------------------
loss: 1.557939 [ 0/60000]
loss: 1.537942 [ 6400/60000]
loss: 1.371437 [12800/60000]
loss: 1.465200 [19200/60000]
loss: 1.346894 [25600/60000]
loss: 1.331310 [32000/60000]
loss: 1.351684 [38400/60000]
loss: 1.263354 [44800/60000]
loss: 1.307001 [51200/60000]
loss: 1.231721 [57600/60000]
Test Error:
Accuracy: 62.6%, Avg loss: 1.241147
Epoch 5
-------------------------
loss: 1.311746 [ 0/60000]
loss: 1.304771 [ 6400/60000]
loss: 1.124054 [12800/60000]
loss: 1.252720 [19200/60000]
loss: 1.126270 [25600/60000]
loss: 1.140746 [32000/60000]
loss: 1.170960 [38400/60000]
loss: 1.091258 [44800/60000]
loss: 1.138590 [51200/60000]
loss: 1.081223 [57600/60000]
Test Error:
Accuracy: 64.4%, Avg loss: 1.083943
Training done! time: 35.4081 s
Total running time of the script: (0 minutes 57.057 seconds)