
Using tensorclasses to load datasets

In this tutorial we demonstrate how tensorclasses can be used to efficiently and transparently load and manage data inside a training pipeline. The tutorial borrows heavily from the PyTorch Quickstart tutorial, modified to demonstrate the use of tensorclass. See the related tutorial that uses TensorDict.

import torch
import torch.nn as nn

from tensordict import MemoryMappedTensor, tensorclass
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
Using device: cpu

The torchvision.datasets module contains a number of convenient pre-prepared datasets. In this tutorial we use the relatively simple FashionMNIST dataset. Each image is an item of clothing, and the goal is to classify the type of clothing in the image (e.g. "Bag", "Sneaker", etc.).

training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
)
test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor(),
)
100%|██████████| 26.4M/26.4M [00:01<00:00, 19.4MB/s]
100%|██████████| 29.5k/29.5k [00:00<00:00, 326kB/s]
100%|██████████| 4.42M/4.42M [00:00<00:00, 6.08MB/s]
100%|██████████| 5.15k/5.15k [00:00<00:00, 59.8MB/s]

Tensorclasses are dataclasses that, like TensorDict, expose dedicated tensor methods on their contents. They are a good choice when the structure of the data you want to store is fixed and predictable.
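As a rough analogy of what the @tensorclass decorator generates, here is a hypothetical hand-rolled dataclass (for illustration only, not tensordict's actual implementation) whose indexing is applied field-wise, so slicing the object slices every tensor field in lockstep:

```python
from dataclasses import dataclass

import torch


# Hypothetical analogy of a tensorclass: indexing the object indexes
# every tensor field with the same index, preserving the structure.
@dataclass
class ToyData:
    images: torch.Tensor
    targets: torch.Tensor

    def __getitem__(self, index):
        return ToyData(images=self.images[index], targets=self.targets[index])


data = ToyData(
    images=torch.zeros(10, 28, 28),
    targets=torch.zeros(10, dtype=torch.int64),
)
batch = data[:4]
print(batch.images.shape)  # torch.Size([4, 28, 28])
```

A real tensorclass adds much more on top of this (a batch_size, device handling, memory mapping, and the full suite of tensor operations), but field-wise indexing is the core idea.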

Along with specifying the contents, we can also encapsulate related logic as custom methods when defining the class. Here we write a from_dataset classmethod that takes a dataset as input and creates a tensorclass containing the dataset's data. We create memory-mapped tensors to hold the data. This will let us efficiently load batches of transformed data from disk rather than repeatedly loading and transforming individual images.

@tensorclass
class FashionMNISTData:
    images: torch.Tensor
    targets: torch.Tensor

    @classmethod
    def from_dataset(cls, dataset, device=None):
        data = cls(
            images=MemoryMappedTensor.empty(
                (len(dataset), *dataset[0][0].squeeze().shape), dtype=torch.float32
            ),
            targets=MemoryMappedTensor.empty((len(dataset),), dtype=torch.int64),
            batch_size=[len(dataset)],
            device=device,
        )
        for i, (image, target) in enumerate(dataset):
            data[i] = cls(images=image, targets=torch.tensor(target), batch_size=[])
        return data
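The idea behind memory-mapped storage can be sketched with the standard library alone. This is not MemoryMappedTensor's implementation, just an illustration of why memmapped data is cheap to slice: the file is mapped into the address space, and only the bytes for the requested slice are actually read from disk.

```python
import mmap
import os
import struct
import tempfile

# Write 1000 float32 values to a file, then memory-map it so that arbitrary
# slices can be read lazily, without loading the whole file into memory.
with tempfile.NamedTemporaryFile(delete=False) as f:
    f.write(struct.pack("1000f", *range(1000)))
    path = f.name

with open(path, "rb") as f:
    with mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as mm:
        # Read elements 64..127 only: each float32 occupies 4 bytes.
        chunk = struct.unpack("64f", mm[64 * 4 : 128 * 4])

os.unlink(path)
print(chunk[0], len(chunk))  # 64.0 64
```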

We will create two tensorclasses, one each for the training and test data. Note that we incur some overhead here, as we loop over the entire dataset, transforming it and saving it to disk.

training_data_tc = FashionMNISTData.from_dataset(training_data, device=device)
test_data_tc = FashionMNISTData.from_dataset(test_data, device=device)

DataLoaders

We'll create DataLoaders from the torchvision-provided Datasets, as well as from our memory-mapped tensorclasses.

Because TensorDict implements __len__ and __getitem__ (as well as __getitems__), we can use it like a map-style Dataset and create a DataLoader directly from it. Note that because TensorDict can already handle batched indices, there is no need for collation, so we pass the identity function as collate_fn.

batch_size = 64

train_dataloader = DataLoader(training_data, batch_size=batch_size)  # noqa: TOR401
test_dataloader = DataLoader(test_data, batch_size=batch_size)  # noqa: TOR401

train_dataloader_tc = DataLoader(  # noqa: TOR401
    training_data_tc, batch_size=batch_size, collate_fn=lambda x: x
)
test_dataloader_tc = DataLoader(  # noqa: TOR401
    test_data_tc, batch_size=batch_size, collate_fn=lambda x: x
)
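Why an identity collate_fn suffices can be sketched without torch: if a container's indexing already accepts a batch of indices and returns a ready-made batch, the loader has nothing left to collate. The class below is a hypothetical illustration of that contract, not part of tensordict:

```python
class BatchIndexable:
    """A map-style container whose indexing already returns batches."""

    def __init__(self, items):
        self.items = items

    def __len__(self):
        return len(self.items)

    def __getitem__(self, index):
        if isinstance(index, list):
            # A batch of indices yields a batch directly -- nothing to
            # collate, so a DataLoader can use the identity collate_fn.
            return [self.items[i] for i in index]
        return self.items[index]


data = BatchIndexable(list(range(100)))
print(data[[3, 1, 4]])  # [3, 1, 4]
```

With an ordinary map-style Dataset, by contrast, each index yields a single sample, and the default collate_fn is what stacks samples into a batch.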

Model

We use the same model as in the Quickstart tutorial.

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28 * 28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits


model = Net().to(device)
model_tc = Net().to(device)
model, model_tc
(Net(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=10, bias=True)
  )
), Net(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=10, bias=True)
  )
))

Optimizing parameters

We'll optimize the model's parameters using stochastic gradient descent and cross-entropy loss.

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
optimizer_tc = torch.optim.SGD(model_tc.parameters(), lr=1e-3)


def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()

    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        pred = model(X)
        loss = loss_fn(pred, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")

The training loop for our tensorclass-based DataLoader is very similar; we just adjust how the data is unpacked, using the more explicit attribute-based retrieval offered by the tensorclass. The .contiguous() method loads the data stored in the memmap tensor into memory.

def train_tc(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()

    for batch, data in enumerate(dataloader):
        X, y = data.images.contiguous(), data.targets.contiguous()

        pred = model(X)
        loss = loss_fn(pred, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")


def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)

            pred = model(X)

            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()

    test_loss /= num_batches
    correct /= size

    print(
        f"Test Error: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f} \n"
    )


def test_tc(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for batch in dataloader:
            X, y = batch.images.contiguous(), batch.targets.contiguous()

            pred = model(X)

            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()

    test_loss /= num_batches
    correct /= size

    print(
        f"Test Error: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f} \n"
    )


for d in train_dataloader_tc:
    print(d)
    break

import time

t0 = time.time()
epochs = 5
for t in range(epochs):
    print(f"Epoch {t + 1}\n-------------------------")
    train_tc(train_dataloader_tc, model_tc, loss_fn, optimizer_tc)
    test_tc(test_dataloader_tc, model_tc, loss_fn)
print(f"Tensorclass training done! time: {time.time() - t0: 4.4f} s")

t0 = time.time()
epochs = 5
for t in range(epochs):
    print(f"Epoch {t + 1}\n-------------------------")
    train(train_dataloader, model, loss_fn, optimizer)
    test(test_dataloader, model, loss_fn)
print(f"Training done! time: {time.time() - t0: 4.4f} s")
FashionMNISTData(
    images=Tensor(shape=torch.Size([64, 28, 28]), device=cpu, dtype=torch.float32, is_shared=False),
    targets=Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.int64, is_shared=False),
    batch_size=torch.Size([64]),
    device=cpu,
    is_shared=False)
Epoch 1
-------------------------
loss: 2.319664 [    0/60000]
loss: 2.303494 [ 6400/60000]
loss: 2.282967 [12800/60000]
loss: 2.270802 [19200/60000]
loss: 2.254053 [25600/60000]
loss: 2.228572 [32000/60000]
loss: 2.237088 [38400/60000]
loss: 2.206554 [44800/60000]
loss: 2.196055 [51200/60000]
loss: 2.173318 [57600/60000]
Test Error:
 Accuracy: 44.4%, Avg loss: 2.165061

Epoch 2
-------------------------
loss: 2.176356 [    0/60000]
loss: 2.166144 [ 6400/60000]
loss: 2.111832 [12800/60000]
loss: 2.126604 [19200/60000]
loss: 2.076978 [25600/60000]
loss: 2.019964 [32000/60000]
loss: 2.044961 [38400/60000]
loss: 1.974381 [44800/60000]
loss: 1.967282 [51200/60000]
loss: 1.902837 [57600/60000]
Test Error:
 Accuracy: 57.9%, Avg loss: 1.903539

Epoch 3
-------------------------
loss: 1.934818 [    0/60000]
loss: 1.904432 [ 6400/60000]
loss: 1.797030 [12800/60000]
loss: 1.834745 [19200/60000]
loss: 1.721624 [25600/60000]
loss: 1.675833 [32000/60000]
loss: 1.687455 [38400/60000]
loss: 1.597601 [44800/60000]
loss: 1.613010 [51200/60000]
loss: 1.502748 [57600/60000]
Test Error:
 Accuracy: 59.6%, Avg loss: 1.531052

Epoch 4
-------------------------
loss: 1.599049 [    0/60000]
loss: 1.560516 [ 6400/60000]
loss: 1.420832 [12800/60000]
loss: 1.487164 [19200/60000]
loss: 1.363405 [25600/60000]
loss: 1.365580 [32000/60000]
loss: 1.366614 [38400/60000]
loss: 1.299678 [44800/60000]
loss: 1.330469 [51200/60000]
loss: 1.224487 [57600/60000]
Test Error:
 Accuracy: 61.7%, Avg loss: 1.259897

Epoch 5
-------------------------
loss: 1.340469 [    0/60000]
loss: 1.316975 [ 6400/60000]
loss: 1.159445 [12800/60000]
loss: 1.261730 [19200/60000]
loss: 1.131677 [25600/60000]
loss: 1.167898 [32000/60000]
loss: 1.174967 [38400/60000]
loss: 1.118451 [44800/60000]
loss: 1.154713 [51200/60000]
loss: 1.065566 [57600/60000]
Test Error:
 Accuracy: 63.9%, Avg loss: 1.093563

Tensorclass training done! time:  8.5377 s
Epoch 1
-------------------------
loss: 2.299644 [    0/60000]
loss: 2.293140 [ 6400/60000]
loss: 2.271977 [12800/60000]
loss: 2.273217 [19200/60000]
loss: 2.250980 [25600/60000]
loss: 2.225316 [32000/60000]
loss: 2.230843 [38400/60000]
loss: 2.195421 [44800/60000]
loss: 2.187299 [51200/60000]
loss: 2.160407 [57600/60000]
Test Error:
 Accuracy: 44.1%, Avg loss: 2.152778

Epoch 2
-------------------------
loss: 2.156122 [    0/60000]
loss: 2.149637 [ 6400/60000]
loss: 2.088117 [12800/60000]
loss: 2.110354 [19200/60000]
loss: 2.059621 [25600/60000]
loss: 2.002847 [32000/60000]
loss: 2.023556 [38400/60000]
loss: 1.943443 [44800/60000]
loss: 1.946714 [51200/60000]
loss: 1.874386 [57600/60000]
Test Error:
 Accuracy: 54.3%, Avg loss: 1.872510

Epoch 3
-------------------------
loss: 1.902230 [    0/60000]
loss: 1.873018 [ 6400/60000]
loss: 1.749329 [12800/60000]
loss: 1.796782 [19200/60000]
loss: 1.693456 [25600/60000]
loss: 1.648093 [32000/60000]
loss: 1.665045 [38400/60000]
loss: 1.566518 [44800/60000]
loss: 1.593781 [51200/60000]
loss: 1.490483 [57600/60000]
Test Error:
 Accuracy: 60.5%, Avg loss: 1.506298

Epoch 4
-------------------------
loss: 1.570804 [    0/60000]
loss: 1.537479 [ 6400/60000]
loss: 1.380056 [12800/60000]
loss: 1.460477 [19200/60000]
loss: 1.355020 [25600/60000]
loss: 1.349743 [32000/60000]
loss: 1.361715 [38400/60000]
loss: 1.282459 [44800/60000]
loss: 1.317743 [51200/60000]
loss: 1.225052 [57600/60000]
Test Error:
 Accuracy: 63.6%, Avg loss: 1.246684

Epoch 5
-------------------------
loss: 1.317353 [    0/60000]
loss: 1.304559 [ 6400/60000]
loss: 1.129192 [12800/60000]
loss: 1.247038 [19200/60000]
loss: 1.133444 [25600/60000]
loss: 1.153610 [32000/60000]
loss: 1.174721 [38400/60000]
loss: 1.104254 [44800/60000]
loss: 1.141913 [51200/60000]
loss: 1.068175 [57600/60000]
Test Error:
 Accuracy: 64.9%, Avg loss: 1.084284

Training done! time:  35.1946 s

Total running time of the script: (1 minute 2.380 seconds)

Gallery generated by Sphinx-Gallery
