• 文件 >
  • 使用 tensorclass 進行批次資料載入
快捷方式

使用 tensorclasses 進行批次資料載入

在本教程中,我們將演示如何將 tensorclasses 和記憶體對映張量結合使用,以便在模型訓練管道中高效且透明地從磁碟載入資料。

基本思路是我們將整個資料集預載入到記憶體對映張量中,在儲存到磁碟之前應用任何非隨機轉換。這意味著我們不僅避免了每次迭代資料時執行重複計算,而且還能夠有效地從記憶體對映張量中批次載入資料,而不是順序地從原始影像檔案中載入。

透過結合預處理、在連續的物理記憶體儲存上載入資料以及在裝置上進行批次轉換,我們的資料載入速度比常規的 torch + torchvision 管道快了約 10 倍。

我們將使用與 PyTorch 遷移學習教程中相同的 ImageNet 子集(hymenoptera_data),但我們也提供了在完整 ImageNet 上執行相同程式碼的實驗結果。

注意

請下載 hymenoptera_data 資料集並解壓。在本教程中,我們假設解壓後的資料位於子目錄 data/hymenoptera_data/ 中,其下包含 train/ 與 val/ 兩個資料夾。

import os
import time
from pathlib import Path

import torch
import torch.nn as nn
import tqdm

from tensordict import MemoryMappedTensor, tensorclass
from tensordict.utils import strtobool
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

if __name__ == "__main__":
    # Worker processes used by the torchvision dataloaders; can be overridden
    # through the NUM_WORKERS environment variable.
    NUM_WORKERS = int(os.environ.get("NUM_WORKERS", "4"))
    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda:0"
    print(f"Using device: {device}")

    ##############################################################################
    # Transforms
    # ----------
    # First we define train and val transforms that will be applied to train and
    # val examples respectively. Note that there are random components in the
    # train transform to prevent overfitting to training data over multiple
    # epochs.

    # Per-channel normalization statistics (the standard ImageNet mean/std) for
    # images that ``ToTensor`` has scaled to the [0, 1] range.
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]

    train_transform = transforms.Compose(
        [
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(imagenet_mean, imagenet_std),
        ]
    )
    val_transform = transforms.Compose(
        [
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(imagenet_mean, imagenet_std),
        ]
    )

    ##############################################################################
    # We use ``torchvision.datasets.ImageFolder`` to conveniently load and
    # transform the data from disk.

    data_dir = Path("data", "hymenoptera_data")

    train_data = datasets.ImageFolder(root=data_dir / "train", transform=train_transform)
    val_data = datasets.ImageFolder(root=data_dir / "val", transform=val_transform)

    ##############################################################################
    # We'll also create a dataset of the raw training data that simply resizes
    # the image to a common size and converts to tensor. We'll use this to
    # load the data into memory-mapped tensors. The random transformations
    # need to be different each time we iterate through the data, so they
    # cannot be pre-computed. We also do not scale the data yet so that we can set the
    # ``dtype`` of the memory-mapped array to ``uint8`` and save space.

    # Fixed-size resize + conversion to an integer tensor (no [0, 1] scaling).
    raw_transform = transforms.Compose(
        [transforms.Resize((256, 256)), transforms.PILToTensor()]
    )
    train_data_raw = datasets.ImageFolder(root=data_dir / "train", transform=raw_transform)

    ##############################################################################
    # Since we'll be loading our data in batches, we write a few custom transformations
    # that take advantage of this, and apply the transformations in a vectorized way.
    #
    # First a transformation that can be used for normalization.
    class InvAffine(nn.Module):
        """Normalize a (batched) tensor as ``(x - loc) / scale``.

        ``loc`` and ``scale`` are broadcast against the input. They are kept as
        plain attributes rather than registered buffers, so ``module.to(...)``
        does not move them; create them on the device the inputs will use.
        """

        def __init__(self, loc, scale):
            super().__init__()
            self.loc, self.scale = loc, scale

        def forward(self, x):
            centered = x - self.loc
            return centered / self.scale

    ##############################################################################
    # Next two transformations that can be used to randomly crop and flip the images.

    class RandomHFlip(nn.Module):
        """Mirror each sample of a batch along its last dimension with probability 0.5.

        One Bernoulli draw is made per sample (every leading dimension in front
        of the trailing three). The decision is broadcast over the trailing three
        dimensions, so each sample is either returned unchanged or flipped as a
        whole.
        """

        def forward(self, x: torch.Tensor):
            decision_shape = (*x.shape[:-3], 1, 1, 1)
            should_flip = torch.zeros(decision_shape, device=x.device, dtype=torch.bool)
            should_flip.bernoulli_()
            return torch.where(should_flip, x.flip(-1), x)

    class RandomCrop(nn.Module):
        """Take an independent random ``h x w`` crop from every sample of a batch.

        Args:
            w: crop width (size of the output's last dimension).
            h: crop height (size of the output's second-to-last dimension).

        The input is ``(*batch, C, H, W)`` with ``H >= h`` and ``W >= w``;
        ``batch`` may be empty or contain any number of dimensions, and ``C``
        may be any channel count. The output is ``(*batch, C, h, w)``.
        """

        def __init__(self, w, h):
            super().__init__()
            self.w = w
            self.h = h

        def forward(self, x):
            batch = x.shape[:-3]
            channels, height, width = x.shape[-3:]

            # ``torch.randint``'s upper bound is exclusive: ``+ 1`` makes the last
            # valid offset (height - h) reachable and lets height == h work
            # instead of raising.
            top = torch.randint(height - self.h + 1, (*batch, 1), device=x.device)
            rows = top + torch.arange(self.h, device=x.device)  # (*batch, h)
            # Insert the channel dim right before ``h`` (and a trailing width dim)
            # by counting from the end, so any number of batch dims works.
            rows = rows.unsqueeze(-2).unsqueeze(-1).expand(*batch, channels, self.h, width)

            left = torch.randint(width - self.w + 1, (*batch, 1), device=x.device)
            cols = left + torch.arange(self.w, device=x.device)  # (*batch, w)
            cols = cols.unsqueeze(-2).unsqueeze(-2).expand(*batch, channels, self.h, self.w)

            # Select the rows first, then the columns of the row-cropped tensor.
            return x.gather(-2, rows).gather(-1, cols)

    ##############################################################################
    # When each batch is loaded, we will scale it, then randomly crop and flip. The random
    # transformations cannot be pre-applied as they must differ each time we iterate over
    # the data. The scaling could be pre-applied in principle, but by waiting until we load
    # the data into RAM, we are able to set the dtype of the memory-mapped array to
    # ``uint8``, a significant space saving over ``float32``.

    # Channel statistics rescaled from [0, 1] to the 0-255 pixel range and shaped
    # (3, 1, 1) so they broadcast over height and width.
    channel_shape = (3, 1, 1)
    pixel_mean = (
        torch.tensor([0.485, 0.456, 0.406], device=device).view(channel_shape) * 255
    )
    pixel_std = (
        torch.tensor([0.229, 0.224, 0.225], device=device).view(channel_shape) * 255
    )

    collate_transform = nn.Sequential(
        InvAffine(loc=pixel_mean, scale=pixel_std),
        RandomCrop(224, 224),
        RandomHFlip(),
    )

    ##############################################################################
    # Representing data with a TensorClass
    # ------------------------------------
    # Tensorclasses are a good choice when the structure of your data is known
    # a priori. They are dataclasses that expose dedicated tensor methods over
    # their contents much like a ``TensorDict``.
    #
    # As well as specifying the contents (in this case ``images`` and
    # ``targets``) we can also encapsulate related logic as custom methods
    # when defining the class. Here we add a classmethod that takes a dataset
    # and creates a tensorclass containing the data by iterating over the
    # dataset. We create memory-mapped tensors to hold the data so that they
    # can be efficiently loaded in batches later.

    @tensorclass
    class ImageNetData:
        # Stacked images; the first dimension indexes samples.
        images: torch.Tensor
        # Class labels, one per sample (``from_dataset`` stores them as int64).
        targets: torch.Tensor

        @classmethod
        def from_dataset(cls, dataset):
            """Copy every ``(image, target)`` pair of ``dataset`` into memory-mapped tensors.

            The per-sample image shape and dtype are taken from ``dataset[0]``.
            uint8 images (e.g. from ``PILToTensor``) are therefore stored as
            uint8, while already-normalized float images are stored as floats
            instead of being cast (and corrupted) to uint8. Samples are read in
            order, in batches of 64, using ``NUM_WORKERS`` worker processes.

            Args:
                dataset: a non-empty map-style dataset yielding
                    ``(image_tensor, integer_target)`` pairs.

            Returns:
                A locked, memory-mapped ``ImageNetData`` whose batch size is
                ``[len(dataset)]``.
            """
            sample_image = dataset[0][0]
            num_samples = len(dataset)
            data = cls(
                images=MemoryMappedTensor.empty(
                    (num_samples, *sample_image.shape),
                    dtype=sample_image.dtype,
                ),
                targets=MemoryMappedTensor.empty((num_samples,), dtype=torch.int64),
                batch_size=[num_samples],
            )
            # locks the tensorclass and ensures that is_memmap will return True.
            data.memmap_()

            loader = DataLoader(dataset, batch_size=64, num_workers=NUM_WORKERS)
            offset = 0
            # The context manager closes the progress bar even if copying fails.
            with tqdm.tqdm(total=num_samples) as pbar:
                for images, targets in loader:
                    n = images.shape[0]
                    # Write this batch into its slice of the preallocated storage.
                    data[offset : offset + n] = cls(
                        images=images, targets=targets, batch_size=[n]
                    )
                    offset += n
                    pbar.update(n)

            return data

    ##############################################################################
    # We create two tensorclasses, one for the training and one for the
    # validation data. Note that while this step can be slightly expensive, it
    # allows us to save repeated computation later during training.

    # Built left to right: the training set first, then the validation set.
    train_data_tc, val_data_tc = (
        ImageNetData.from_dataset(train_data_raw),
        ImageNetData.from_dataset(val_data),
    )

    ##############################################################################
    # DataLoaders
    # -----------
    #
    # We can create dataloaders both from the ``torchvision``-provided
    # Datasets, as well as from our memory-mapped tensorclasses.
    #
    # Since tensorclasses implement ``__len__`` and ``__getitem__`` (and also
    # ``__getitems__``) we can use them like a map-style Dataset and create a
    # ``DataLoader`` directly from them.
    #
    # Since the TensorClass data will be loaded in batches, we need to specify how these
    # batches should be collated. For this we write the following helper class.

    class Collate(nn.Module):
        """``collate_fn`` that optionally moves a batch to a device and transforms it.

        Args:
            transform: optional callable applied to ``batch.images`` after the
                batch has been moved.
            device: optional target device (anything ``torch.device`` accepts).
                ``None`` (the default) leaves the batch where it is.
        """

        def __init__(self, transform=None, device=None):
            super().__init__()
            self.transform = transform
            # ``torch.device(None)`` raises, so keep ``None`` to mean "do not move".
            self.device = torch.device(device) if device is not None else None

        def forward(self, x: "ImageNetData"):
            # Implemented as ``forward`` (not ``__call__``) so nn.Module hooks run.
            out = x
            if self.device is not None:
                if self.device.type == "cuda":
                    # Pinned (page-locked) host memory speeds up the host-to-GPU copy.
                    out = out.pin_memory()
                out = out.to(self.device)
            if self.transform is not None:
                # Batched transforms run on whatever device the batch now lives on.
                out.images = self.transform(out.images)
            return out

    ##############################################################################
    # ``DataLoader`` has support for multiple workers loading data in parallel. The
    # tensorclass dataloader will use just one worker, but load data in batches.
    #
    # Note that under this approach our ``collate_fn`` is essentially just an ``nn.Module``,
    # making it transparent and easy to implement. But this approach also offers
    # flexibility, for example, if needed we could move the collation step into the training
    # loop by considering the ``Collate`` module as part of the model.

    batch_size = 8

    # Baseline loaders: per-sample loading and transforms spread over worker processes.
    train_dataloader, val_dataloader = (
        DataLoader(dataset, batch_size=batch_size, num_workers=NUM_WORKERS)
        for dataset in (train_data, val_data)
    )

    # Tensorclass loaders (no extra workers): ``Collate`` moves each batch to
    # ``device`` and, for the training set, applies the batched random transforms.
    train_collate = Collate(collate_transform, device)
    val_collate = Collate(device=device)
    train_dataloader_tc = DataLoader(  # noqa: TOR401
        train_data_tc, batch_size=batch_size, collate_fn=train_collate
    )
    val_dataloader_tc = DataLoader(  # noqa: TOR401
        val_data_tc, batch_size=batch_size, collate_fn=val_collate
    )

    ##############################################################################
    # We can now compare how long it takes to iterate once over the data in
    # each case. The regular dataloader loads images one by one from disk,
    # applies the transform sequentially and then stacks the results
    # (note: we start measuring time a little after the first iteration, as
    # starting the dataloader can take some time).

    # Number of leading batches excluded from timing (dataloader start-up cost).
    WARMUP_BATCHES = 3

    def _synchronize():
        """Wait for queued CUDA work so that wall-clock timings include it."""
        # CUDA copies and kernels run asynchronously; without a sync the timer
        # could stop before the device has finished processing the batches.
        if torch.device(device).type == "cuda":
            torch.cuda.synchronize()

    def time_one_epoch(loader, unpack):
        """Iterate once over ``loader`` and time every batch after the warm-up.

        Args:
            loader: iterable of batches.
            unpack: callable mapping a batch to ``(images, targets)`` on ``device``.

        Returns:
            ``(num_samples, seconds)`` measured over the batches after the first
            ``WARMUP_BATCHES``. ``seconds`` is ``None`` when the loader yields too
            few batches to time (previously this raised a NameError).
        """
        total = 0
        start = None
        for i, batch in enumerate(loader):
            if i == WARMUP_BATCHES:
                _synchronize()
                start = time.perf_counter()
            images, _ = unpack(batch)
            if i >= WARMUP_BATCHES:
                total += images.shape[0]
        if start is None:
            return total, None
        _synchronize()
        return total, time.perf_counter() - start

    def report(message, total, elapsed):
        """Print the throughput line for one benchmark run."""
        if not elapsed:
            print(f"{message} Too few batches to time (need more than {WARMUP_BATCHES}).")
            return
        print(f"{message} Rate: {total / elapsed:4.4f} fps, time: {elapsed: 4.4f}s")

    report(
        "One iteration over dataloader done!",
        *time_one_epoch(
            train_dataloader,
            lambda batch: (batch[0].to(device), batch[1].to(device)),
        ),
    )

    ##############################################################################
    # Our tensorclass-based dataloader instead loads data from the
    # memory-mapped tensor in batches. We then apply the batched random
    # transformations to the batched images.

    # The ``Collate`` passed as ``collate_fn`` has already moved each batch to
    # ``device``, so there is nothing left to transfer here.
    report(
        "One iteration over tensorclass dataloader done!",
        *time_one_epoch(
            train_dataloader_tc,
            lambda batch: (batch.images, batch.targets),
        ),
    )

    ##############################################################################
    # In the case of the validation set, we see an even bigger performance
    # improvement, because there are no random transformations, so we can save
    # the fully transformed data in the memory-mapped tensor, eliminating the
    # need for additional transformations as we load from disk.

    report(
        "One iteration over val data done!",
        *time_one_epoch(
            val_dataloader,
            lambda batch: (batch[0].to(device), batch[1].to(device)),
        ),
    )

    report(
        "One iteration over tensorclass val data done!",
        *time_one_epoch(
            val_dataloader_tc,
            lambda batch: (batch.images, batch.targets),
        ),
    )

    ##############################################################################
    # Results from ImageNet
    # ---------------------
    #
    # We repeated the above on full-size ImageNet data, running on an AWS EC2 instance with
    # 32 cores and 1 A100 GPU. We compare against the regular ``DataLoader`` with different
    # numbers of workers. We found that our single-threaded TensorClass approach
    # outperformed the ``DataLoader`` even when we used a large number of workers.
    #
    # .. image:: /reference/generated/tutorials/media/imagenet-benchmark-time.png
    #    :alt: Bar chart showing runtimes of dataloaders compared with TensorClass
    #
    # .. image:: /reference/generated/tutorials/media/imagenet-benchmark-speed.png
    #    :alt: Bar chart showing collection rate of dataloaders compared with TensorClass

    ##############################################################################
    # This shows that much of the overhead is coming from I/O operations rather than the
    # transforms, and hence explains how the memory-mapped array helps us load data more
    # efficiently. Check out the `distributed example <https://github.com/pytorch/tensordict/blob/main/benchmarks/distributed/dataloading.py>`__
    # for more context about the other results from these charts.
    #
    # We can get even better performance with the TensorClass approach by using multiple
    # workers to load batches from the memory-mapped array, though this comes with some
    # added complexity. See `this example in our benchmarks
    # <https://github.com/pytorch/tensordict/blob/main/benchmarks/distributed/dataloading.py>`__
    # for an example of how this could work.
Using device: cpu

  0%|          | 0/244 [00:00<?, ?it/s]
 26%|██▌       | 64/244 [00:00<00:00, 238.76it/s]ImageNetData(
    images=MemoryMappedTensor(shape=torch.Size([244, 3, 256, 256]), device=cpu, dtype=torch.uint8, is_shared=False),
    targets=MemoryMappedTensor(shape=torch.Size([244]), device=cpu, dtype=torch.int64, is_shared=False),
    batch_size=torch.Size([244]),
    device=cpu,
    is_shared=False)
ImageNetData(
    images=Tensor(shape=torch.Size([64, 3, 256, 256]), device=cpu, dtype=torch.uint8, is_shared=True),
    targets=Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.int64, is_shared=True),
    batch_size=torch.Size([64]),
    device=None,
    is_shared=False)
ImageNetData(
    images=MemoryMappedTensor(shape=torch.Size([244, 3, 256, 256]), device=cpu, dtype=torch.uint8, is_shared=False),
    targets=MemoryMappedTensor(shape=torch.Size([244]), device=cpu, dtype=torch.int64, is_shared=False),
    batch_size=torch.Size([244]),
    device=cpu,
    is_shared=False)
ImageNetData(
    images=Tensor(shape=torch.Size([64, 3, 256, 256]), device=cpu, dtype=torch.uint8, is_shared=True),
    targets=Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.int64, is_shared=True),
    batch_size=torch.Size([64]),
    device=None,
    is_shared=False)
ImageNetData(
    images=MemoryMappedTensor(shape=torch.Size([244, 3, 256, 256]), device=cpu, dtype=torch.uint8, is_shared=False),
    targets=MemoryMappedTensor(shape=torch.Size([244]), device=cpu, dtype=torch.int64, is_shared=False),
    batch_size=torch.Size([244]),
    device=cpu,
    is_shared=False)
ImageNetData(
    images=Tensor(shape=torch.Size([64, 3, 256, 256]), device=cpu, dtype=torch.uint8, is_shared=True),
    targets=Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.int64, is_shared=True),
    batch_size=torch.Size([64]),
    device=None,
    is_shared=False)
ImageNetData(
    images=MemoryMappedTensor(shape=torch.Size([244, 3, 256, 256]), device=cpu, dtype=torch.uint8, is_shared=False),
    targets=MemoryMappedTensor(shape=torch.Size([244]), device=cpu, dtype=torch.int64, is_shared=False),
    batch_size=torch.Size([244]),
    device=cpu,
    is_shared=False)
ImageNetData(
    images=Tensor(shape=torch.Size([52, 3, 256, 256]), device=cpu, dtype=torch.uint8, is_shared=True),
    targets=Tensor(shape=torch.Size([52]), device=cpu, dtype=torch.int64, is_shared=True),
    batch_size=torch.Size([52]),
    device=None,
    is_shared=False)

100%|██████████| 244/244 [00:00<00:00, 720.53it/s]

  0%|          | 0/153 [00:00<?, ?it/s]
 42%|████▏     | 64/153 [00:00<00:00, 179.83it/s]ImageNetData(
    images=MemoryMappedTensor(shape=torch.Size([153, 3, 224, 224]), device=cpu, dtype=torch.uint8, is_shared=False),
    targets=MemoryMappedTensor(shape=torch.Size([153]), device=cpu, dtype=torch.int64, is_shared=False),
    batch_size=torch.Size([153]),
    device=cpu,
    is_shared=False)
ImageNetData(
    images=Tensor(shape=torch.Size([64, 3, 224, 224]), device=cpu, dtype=torch.float32, is_shared=True),
    targets=Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.int64, is_shared=True),
    batch_size=torch.Size([64]),
    device=None,
    is_shared=False)
ImageNetData(
    images=MemoryMappedTensor(shape=torch.Size([153, 3, 224, 224]), device=cpu, dtype=torch.uint8, is_shared=False),
    targets=MemoryMappedTensor(shape=torch.Size([153]), device=cpu, dtype=torch.int64, is_shared=False),
    batch_size=torch.Size([153]),
    device=cpu,
    is_shared=False)
ImageNetData(
    images=Tensor(shape=torch.Size([64, 3, 224, 224]), device=cpu, dtype=torch.float32, is_shared=True),
    targets=Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.int64, is_shared=True),
    batch_size=torch.Size([64]),
    device=None,
    is_shared=False)
ImageNetData(
    images=MemoryMappedTensor(shape=torch.Size([153, 3, 224, 224]), device=cpu, dtype=torch.uint8, is_shared=False),
    targets=MemoryMappedTensor(shape=torch.Size([153]), device=cpu, dtype=torch.int64, is_shared=False),
    batch_size=torch.Size([153]),
    device=cpu,
    is_shared=False)
ImageNetData(
    images=Tensor(shape=torch.Size([25, 3, 224, 224]), device=cpu, dtype=torch.float32, is_shared=True),
    targets=Tensor(shape=torch.Size([25]), device=cpu, dtype=torch.int64, is_shared=True),
    batch_size=torch.Size([25]),
    device=None,
    is_shared=False)

100%|██████████| 153/153 [00:00<00:00, 319.30it/s]
One iteration over dataloader done! Rate: 864.4558 fps, time:  0.2545s
One iteration over tensorclass dataloader done! Rate: 1843.0131 fps, time:  0.1194s
One iteration over val data done! Rate: 526.8308 fps, time:  0.2449s
One iteration over tensorclass val data done! Rate: 21398.6639 fps, time:  0.0060s

指令碼總執行時間: (0 分鐘 1.672 秒)

由 Sphinx-Gallery 生成的畫廊

文件

訪問全面的 PyTorch 開發者文件

檢視文件

教程

為初學者和高階開發者提供深入的教程

檢視教程

資源

查詢開發資源並讓您的問題得到解答

檢視資源