
Training a Classifier

Created On: Mar 24, 2017 | Last Updated: Sep 30, 2025 | Last Verified: Not Verified

That's it. You have seen how to define neural networks, compute loss, and make updates to the weights of the network.

Now you might be thinking,

What about data?

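Generally, when you have to deal with image, text, audio, or video data, you can use standard Python packages that load data into a NumPy array. Then you can convert this array into a torch.*Tensor.

  • For images, packages such as Pillow and OpenCV are useful.

  • For audio, packages such as scipy and librosa are useful.

  • For text, either raw Python or Cython-based loading, or NLTK and SpaCy, are useful.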
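As a minimal sketch of the NumPy-to-Tensor step, assume an image has already been decoded into a `height x width x channels` `uint8` NumPy array; a random array stands in for a real file here, and the variable names are illustrative only.

```python
import numpy as np
import torch

# Stand-in for an image loaded with Pillow or OpenCV: height x width x channels, uint8
rng = np.random.default_rng(0)
np_image = rng.integers(0, 256, size=(32, 32, 3), dtype=np.uint8)

# torch.from_numpy shares memory with the NumPy array (no copy is made)
tensor_hwc = torch.from_numpy(np_image)

# Reorder to channels x height x width and scale to floats in [0, 1]
tensor_chw = tensor_hwc.permute(2, 0, 1).float() / 255.0

print(tensor_hwc.dtype, tensor_hwc.shape)  # torch.uint8 torch.Size([32, 32, 3])
print(tensor_chw.dtype, tensor_chw.shape)  # torch.float32 torch.Size([3, 32, 32])
```

The channel reordering matters because PyTorch convolution layers expect `C x H x W` inputs (plus a leading batch dimension), which is also the layout `transforms.ToTensor()` produces.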

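Specifically for vision, we have created a package called torchvision that provides data loaders for common datasets such as ImageNet, CIFAR10, and MNIST (`torchvision.datasets`, used together with `torch.utils.data.DataLoader`), as well as data transformers for images (`torchvision.transforms`).

This provides a huge convenience and avoids writing boilerplate code.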
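To see the kind of boilerplate that `torchvision.datasets` and `DataLoader` save you from, here is a minimal sketch of a hand-written dataset over in-memory tensors. The class name `RandomImageDataset` and its random contents are hypothetical; a real dataset would typically read files in `__getitem__`.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class RandomImageDataset(Dataset):
    """Hypothetical dataset: n random 3x32x32 images with labels in [0, 10)."""

    def __init__(self, n=20, seed=0):
        g = torch.Generator().manual_seed(seed)
        self.images = torch.rand(n, 3, 32, 32, generator=g)
        self.labels = torch.randint(0, 10, (n,), generator=g)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        # Return one (image, label) pair; DataLoader stacks pairs into batches
        return self.images[idx], self.labels[idx]


dataset = RandomImageDataset()
loader = DataLoader(dataset, batch_size=4, shuffle=True)

images, labels = next(iter(loader))
print(images.shape, labels.shape)  # torch.Size([4, 3, 32, 32]) torch.Size([4])
```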

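For this tutorial, we will use the CIFAR10 dataset. It has the classes: ‘airplane’, ‘automobile’, ‘bird’, ‘cat’, ‘deer’, ‘dog’, ‘frog’, ‘horse’, ‘ship’, ‘truck’. The images in CIFAR-10 are of size 3x32x32, i.e. 3-channel color images of 32x32 pixels in size.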

[Image: cifar10]

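Training an image classifier

We will do the following steps in order:

  1. Load and normalize the CIFAR10 training and test datasets using torchvision.

  2. Define a Convolutional Neural Network.

  3. Define a loss function.

  4. Train the network on the training data.

  5. Test the network on the test data.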

1. Load and normalize CIFAR10

Using torchvision, it's extremely easy to load CIFAR10.

import torch
import torchvision
import torchvision.transforms as transforms

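The torchvision datasets return PILImage images. `transforms.ToTensor()` converts them to Tensors with values in the range [0, 1], and `transforms.Normalize` with a mean and standard deviation of 0.5 per channel then maps those values to the normalized range [-1, 1].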
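The arithmetic behind this conversion (`transforms.ToTensor()` followed by `transforms.Normalize` with mean and std 0.5, as in the code below) can be checked by hand. This sketch uses plain NumPy on a few arbitrarily chosen pixel values rather than the torchvision transforms themselves.

```python
import numpy as np

# Raw 8-bit pixel values, as stored in a PIL image
pixels = np.array([0, 64, 128, 255], dtype=np.uint8)

# What ToTensor does to the values: scale to [0, 1]
scaled = pixels.astype(np.float32) / 255.0

# What Normalize(mean=0.5, std=0.5) does per channel: (x - mean) / std
normalized = (scaled - 0.5) / 0.5

# The inverse used later by imshow: x / 2 + 0.5
restored = normalized / 2 + 0.5

print(normalized.min(), normalized.max())  # -1.0 1.0
```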

Note

If running on Windows or macOS and you get a BrokenPipeError or a multiprocessing-related RuntimeError, try setting the `num_workers` argument of `torch.utils.data.DataLoader()` to 0.
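These errors typically come from worker processes being started with the `spawn` method (the default on Windows and macOS), which re-imports the main script in each worker. Besides `num_workers=0`, a common remedy is to guard the entry point with `if __name__ == '__main__':`. A minimal sketch, using a small synthetic `TensorDataset` in place of CIFAR10 and a hypothetical `main` function:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def main(num_workers=0):
    # Synthetic stand-in for CIFAR10: 8 blank images with integer labels
    dataset = TensorDataset(torch.zeros(8, 3, 32, 32), torch.arange(8))
    loader = DataLoader(dataset, batch_size=4, num_workers=num_workers)
    # Count how many labels the loader yields in total
    return sum(labels.numel() for _, labels in loader)


# Only the main process runs this block; spawned workers import the module
# without executing it
if __name__ == '__main__':
    print(main(num_workers=0))  # 8
```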

transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

batch_size = 4

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
100%|██████████| 170M/170M [00:34<00:00, 4.87MB/s]

Let us show some of the training images, for fun.

import matplotlib.pyplot as plt
import numpy as np

# functions to show an image


def imshow(img):
    img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()


# get some random training images
dataiter = iter(trainloader)
images, labels = next(dataiter)

# show images
imshow(torchvision.utils.make_grid(images))
# print labels
print(' '.join(f'{classes[labels[j]]:5s}' for j in range(batch_size)))
truck plane plane horse

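2. Define a Convolutional Neural Network

Copy the neural network from the Neural Networks section and modify it to take 3-channel images (instead of the 1-channel images it was defined for).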
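The `16 * 5 * 5` input size of `fc1` in the network below follows from the 32x32 CIFAR10 images. As a sketch, the usual output-size formula for a convolution or pooling layer without padding or dilation, `out = (in - kernel) // stride + 1`, gives (the helper `out_size` is illustrative only):

```python
def out_size(size, kernel, stride=1):
    # Spatial output size of a conv/pool layer without padding or dilation
    return (size - kernel) // stride + 1


size = 32                      # CIFAR10 images are 32x32
size = out_size(size, 5)       # conv1, 5x5 kernel  -> 28
size = out_size(size, 2, 2)    # 2x2 max pool       -> 14
size = out_size(size, 5)       # conv2, 5x5 kernel  -> 10
size = out_size(size, 2, 2)    # 2x2 max pool       -> 5
print(16 * size * size)        # 400 = 16 * 5 * 5, the in_features of fc1
```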

import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1) # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


net = Net()

3. Define a Loss function and optimizer

Let's use a classification cross-entropy loss and SGD with momentum.
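`nn.CrossEntropyLoss` takes the raw, unnormalized network outputs (logits) and integer class indices. As a small check with made-up logits, it matches the mean negative log-softmax of the correct class, which is why the network's last layer (`fc3`) has no softmax of its own:

```python
import torch
import torch.nn as nn

logits = torch.tensor([[2.0, 0.5, -1.0],
                       [0.1, 0.2, 3.0]])   # batch of 2, 3 classes (made-up values)
targets = torch.tensor([0, 2])             # correct class index per sample

loss = nn.CrossEntropyLoss()(logits, targets)

# Manual version: log-softmax over classes, pick the target entry, negate, average
log_probs = torch.log_softmax(logits, dim=1)
manual = -log_probs[torch.arange(len(targets)), targets].mean()

print(torch.allclose(loss, manual))  # True
```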

import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

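4. Train the network

This is when things start to get interesting. We simply have to loop over our data iterator, feed the inputs to the network, and optimize.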
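The `optimizer.zero_grad()` call in the loop below matters because `backward()` adds to the existing `.grad` values rather than overwriting them. A tiny sketch with a single scalar parameter:

```python
import torch

w = torch.tensor(2.0, requires_grad=True)

(3 * w).backward()
print(w.grad)  # tensor(3.)

# A second backward() without zeroing accumulates: 3 + 3
(3 * w).backward()
print(w.grad)  # tensor(6.)

# Roughly what optimizer.zero_grad() does for each parameter (by default it sets grads to None)
w.grad = None
(3 * w).backward()
print(w.grad)  # tensor(3.)
```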

for epoch in range(2):  # loop over the dataset multiple times

    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if i % 2000 == 1999:    # print every 2000 mini-batches
            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}')
            running_loss = 0.0

print('Finished Training')
[1,  2000] loss: 2.171
[1,  4000] loss: 1.850
[1,  6000] loss: 1.662
[1,  8000] loss: 1.594
[1, 10000] loss: 1.540
[1, 12000] loss: 1.500
[2,  2000] loss: 1.422
[2,  4000] loss: 1.386
[2,  6000] loss: 1.369
[2,  8000] loss: 1.329
[2, 10000] loss: 1.310
[2, 12000] loss: 1.307
Finished Training

Let's quickly save our trained model:

PATH = './cifar_net.pth'
torch.save(net.state_dict(), PATH)

See here for more details on saving PyTorch models.
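`torch.save(net.state_dict(), PATH)` stores only the parameter and buffer tensors, keyed by name, not the class itself, which is why `Net` must be constructed again before loading. A minimal round-trip sketch, writing to an in-memory buffer instead of a file and using a small `nn.Linear` in place of `Net`:

```python
import io

import torch
import torch.nn as nn

model = nn.Linear(4, 2)
print(list(model.state_dict().keys()))  # ['weight', 'bias']

buffer = io.BytesIO()
torch.save(model.state_dict(), buffer)
buffer.seek(0)

# A freshly constructed model of the same architecture receives the saved tensors
restored = nn.Linear(4, 2)
restored.load_state_dict(torch.load(buffer, weights_only=True))

print(torch.equal(model.weight, restored.weight))  # True
```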

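5. Test the network on the test data

We have trained the network for 2 passes (epochs) over the training dataset. But we need to check whether the network has learned anything at all.

We will check this by predicting the class label that the neural network outputs, and checking it against the ground truth. If the prediction is correct, we add the sample to the list of correct predictions.

Okay, first step. Let us display a batch of test images to get familiar with them.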

dataiter = iter(testloader)
images, labels = next(dataiter)

# show images
imshow(torchvision.utils.make_grid(images))
print('GroundTruth: ', ' '.join(f'{classes[labels[j]]:5s}' for j in range(4)))
GroundTruth:  cat   ship  ship  plane

Next, let's load back our saved model (note: saving and reloading the model wasn't necessary here; we only did it to illustrate how to do so):

net = Net()
net.load_state_dict(torch.load(PATH, weights_only=True))
<All keys matched successfully>

Okay, now let us see what the neural network thinks the examples above are:

outputs = net(images)

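The outputs are energies for the 10 classes. The higher the energy for a class, the more the network thinks that the image belongs to that class. So, let's get the index of the highest energy: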
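`torch.max(outputs, 1)` returns both the maximum value and its index along dimension 1 (the class dimension); the code below keeps only the index. A sketch with made-up scores for 2 images and 3 classes; applying softmax first, to read the scores as probabilities, would not change the predicted class:

```python
import torch

outputs = torch.tensor([[1.0, 3.0, 0.5],
                        [2.0, -1.0, 0.0]])  # made-up scores: 2 images, 3 classes

values, predicted = torch.max(outputs, 1)    # max over the class dimension
print(values)     # tensor([3., 2.])
print(predicted)  # tensor([1, 0])

# softmax turns scores into probabilities but keeps the same largest class
probs = torch.softmax(outputs, dim=1)
print(torch.equal(probs.argmax(dim=1), predicted))  # True
```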

_, predicted = torch.max(outputs, 1)

print('Predicted: ', ' '.join(f'{classes[predicted[j]]:5s}'
                              for j in range(4)))
Predicted:  cat   car   car   plane

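The results look reasonable: two of the four images (the cat and the plane) are classified correctly, while both ships are predicted as cars.

Let us look at how the network performs on the whole test set.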

correct = 0
total = 0
# since we're not training, we don't need to calculate the gradients for our outputs
with torch.no_grad():
    for data in testloader:
        images, labels = data
        # calculate outputs by running images through the network
        outputs = net(images)
        # the class with the highest energy is what we choose as prediction
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Accuracy of the network on the 10000 test images: {100 * correct // total} %')
Accuracy of the network on the 10000 test images: 56 %

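That looks far better than chance, which is 10% accuracy (randomly picking a class out of 10 classes). It seems the network has learned something.

Hmm, which classes performed well, and which did not?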

# prepare to count predictions for each class
correct_pred = {classname: 0 for classname in classes}
total_pred = {classname: 0 for classname in classes}

# again no gradients needed
with torch.no_grad():
    for data in testloader:
        images, labels = data
        outputs = net(images)
        _, predictions = torch.max(outputs, 1)
        # collect the correct predictions for each class
        for label, prediction in zip(labels, predictions):
            if label == prediction:
                correct_pred[classes[label]] += 1
            total_pred[classes[label]] += 1


# print accuracy for each class
for classname, correct_count in correct_pred.items():
    accuracy = 100 * float(correct_count) / total_pred[classname]
    print(f'Accuracy for class: {classname:5s} is {accuracy:.1f} %')
Accuracy for class: plane is 64.7 %
Accuracy for class: car   is 71.7 %
Accuracy for class: bird  is 42.9 %
Accuracy for class: cat   is 40.5 %
Accuracy for class: deer  is 52.4 %
Accuracy for class: dog   is 53.2 %
Accuracy for class: frog  is 61.4 %
Accuracy for class: horse is 65.0 %
Accuracy for class: ship  is 57.4 %
Accuracy for class: truck is 55.5 %
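The per-class counts can also be computed without a Python loop over individual samples. A sketch using `torch.bincount` on made-up label and prediction tensors (in the tutorial these would be concatenated over the whole `testloader`):

```python
import torch

num_classes = 3
labels = torch.tensor([0, 0, 1, 1, 1, 2])       # made-up ground truth
predictions = torch.tensor([0, 1, 1, 1, 0, 2])  # made-up predictions

# Samples per class, and correct predictions per class
total_per_class = torch.bincount(labels, minlength=num_classes)
correct_per_class = torch.bincount(labels[predictions == labels], minlength=num_classes)

# Percentage correct per class (a class with no samples would give nan here)
accuracy = 100.0 * correct_per_class / total_per_class
print(accuracy)
```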

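Okay, so what next?

How do we run these neural networks on the GPU?

Training on GPU

Just like how you transfer a Tensor onto the GPU, you transfer the neural net onto the GPU.

Let's first define our device as the first visible CUDA device if we have CUDA available: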

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Assuming that we are on a CUDA machine, this should print a CUDA device:

print(device)
cuda:0

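The rest of this section assumes that `device` is a CUDA device.

Calling `net.to(device)` then recursively goes over all modules and converts their parameters and buffers to CUDA tensors: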

net.to(device)

Remember that you will have to send the inputs and targets to the GPU at every step too:

inputs, labels = data[0].to(device), data[1].to(device)
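Putting these pieces together, one training step can be written so that it runs on either device. A minimal sketch, with a small `nn.Linear` model and random data standing in for `Net` and a CIFAR10 batch:

```python
import torch
import torch.nn as nn
import torch.optim as optim

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

model = nn.Linear(3 * 32 * 32, 10).to(device)   # stand-in for Net; Module.to moves it in place
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# Stand-in batch, created on the CPU as a DataLoader would return it
inputs = torch.randn(4, 3, 32, 32)
labels = torch.randint(0, 10, (4,))

# Tensor.to returns a new tensor, so the result must be reassigned;
# model, inputs, and targets must all live on the same device
inputs, labels = inputs.to(device), labels.to(device)

optimizer.zero_grad()
loss = criterion(model(inputs.flatten(1)), labels)
loss.backward()
optimizer.step()

print(next(model.parameters()).device, inputs.device)
```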

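Why don't I notice a massive speedup compared to the CPU? Because your network is really small.

Exercise: Try increasing the width of your network (argument 2 of the first `nn.Conv2d`, and argument 1 of the second `nn.Conv2d`; they need to be the same number), and see what kind of speedup you get.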
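For the exercise, one way to make the width adjustable is to pass it to the constructor. The class `WideNet` and its `width` argument are hypothetical (the tutorial's `Net` hard-codes 6); everything else mirrors `Net`:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class WideNet(nn.Module):
    def __init__(self, width=6):
        super().__init__()
        self.conv1 = nn.Conv2d(3, width, 5)      # out_channels of conv1 ...
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(width, 16, 5)     # ... must equal in_channels of conv2
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)


for width in (6, 64):
    model = WideNet(width)
    n_params = sum(p.numel() for p in model.parameters())
    print(width, n_params, model(torch.randn(2, 3, 32, 32)).shape)
```

With `width=6` this has the same parameter count as `Net` (62006); timing a few training steps for each width on CPU and GPU is one way to see where the GPU starts to pay off.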

Goals achieved:

  • Understanding PyTorch's Tensor library and neural networks at a high level.

  • Training a small neural network to classify images.

Training on multiple GPUs

If you want to see an even larger speedup using all of your GPUs, please check out Optional: Data Parallelism.
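The data parallelism tutorial is built around `nn.DataParallel`. As a rough sketch of its core pattern (wrapping the model only when more than one GPU is visible), with a small `nn.Linear` standing in for `Net`; note that PyTorch's own documentation recommends `DistributedDataParallel` over `DataParallel` for serious multi-GPU training:

```python
import torch
import torch.nn as nn

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = nn.Linear(10, 2)  # stand-in for Net

if torch.cuda.device_count() > 1:
    # Splits each input batch along dim 0 across the visible GPUs
    model = nn.DataParallel(model)

model.to(device)
outputs = model(torch.randn(8, 10, device=device))
print(outputs.shape)  # torch.Size([8, 2])
```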

Where do I go next?

del dataiter

Total running time of the script: (1 minute 57.226 seconds)