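Training a Classifier#
Created On: Mar 24, 2017 | Last Updated: Sep 30, 2025 | Last Verified: Not Verified
This is it. You have seen how to define neural networks, compute the loss, and update the weights of the network.
Now you might be thinking,
What about data?#
Generally, when you have to deal with image, text, audio, or video data, you can use standard Python packages that load the data into a NumPy array. You can then convert this array into a torch.*Tensor.
For images, packages such as Pillow and OpenCV are useful.
For audio, packages such as scipy and librosa are useful.
For text, either raw Python or Cython-based loading works, as do NLTK and SpaCy.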
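As a minimal sketch of the NumPy-to-tensor step described above, the snippet below uses a random array as a stand-in for an image decoded by Pillow or OpenCV. The array contents and the variable names are illustrative assumptions, not part of the tutorial.

```python
import numpy as np
import torch

# Stand-in for a decoded image: height x width x channels, 8-bit values.
rng = np.random.default_rng(0)
image_hwc = rng.integers(0, 256, size=(32, 32, 3), dtype=np.uint8)

# torch.from_numpy wraps the array without copying; both share memory.
tensor_hwc = torch.from_numpy(image_hwc)

# Convolution layers expect channels first and floating-point values,
# so reorder the axes to C x H x W and rescale to [0, 1].
tensor_chw = tensor_hwc.permute(2, 0, 1).float() / 255.0

print(tensor_chw.shape, tensor_chw.dtype)
```

For images, `transforms.ToTensor()` from torchvision performs the same reordering and rescaling, which is why the rest of this tutorial never does it by hand.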
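Specifically for vision, we have created a package called torchvision. It provides data loaders for common datasets such as ImageNet, CIFAR10, and MNIST, as well as data transformers for images, viz., torchvision.datasets and torch.utils.data.DataLoader.
This is a huge convenience and avoids writing boilerplate code.
For this tutorial, we will use the CIFAR10 dataset. It has the classes: 'airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'. The images in CIFAR-10 are of size 3x32x32, i.e. 3-channel color images of 32x32 pixels.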
[Figure: cifar10]
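Training an image classifier#
We will do the following steps in order:
1. Load and normalize the CIFAR10 training and test datasets using torchvision.
2. Define a Convolutional Neural Network.
3. Define a loss function.
4. Train the network on the training data.
5. Test the network on the test data.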
1. Load and normalize CIFAR10#
Using torchvision, it is extremely easy to load CIFAR10.
import torch
import torchvision
import torchvision.transforms as transforms
The output of torchvision datasets are PILImage images of range [0, 1]. We transform them to Tensors of normalized range [-1, 1].
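To make the [0, 1] to [-1, 1] mapping concrete, here is a small sketch of the per-channel arithmetic, (x - mean) / std with mean and std both 0.5. It uses a random tensor as a stand-in for a converted image and plain tensor operations instead of torchvision; the variable names are assumptions for illustration.

```python
import torch

# Stand-in for one converted image: 3 x 32 x 32 with values in [0, 1].
img = torch.rand(3, 32, 32)

# One mean and one standard deviation per channel, shaped to broadcast.
mean = torch.tensor([0.5, 0.5, 0.5]).view(3, 1, 1)
std = torch.tensor([0.5, 0.5, 0.5]).view(3, 1, 1)

# Normalization: 0 maps to -1, 0.5 maps to 0, and 1 maps to 1.
normalized = (img - mean) / std

# Inverting it gives back the original values (x * 0.5 + 0.5).
restored = normalized * std + mean

print(normalized.min().item(), normalized.max().item())
```

The inverse step is the same computation as `img / 2 + 0.5`, which is useful whenever normalized images need to be displayed.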
Note
If you are running this tutorial on Windows or macOS and get a BrokenPipeError or RuntimeError related to multiprocessing, try setting the `num_workers` argument of `torch.utils.data.DataLoader()` to 0.
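The following sketch shows what a single-process loader looks like. It uses a small synthetic `TensorDataset` in place of CIFAR10 so that it runs without a download; the fake data and names are assumptions for illustration only.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Synthetic stand-in for CIFAR10: 16 random "images" with labels 0..9.
fake_images = torch.randn(16, 3, 32, 32)
fake_labels = torch.randint(0, 10, (16,))
dataset = TensorDataset(fake_images, fake_labels)

# num_workers=0 loads every batch in the main process, so no worker
# processes are spawned and multiprocessing errors cannot occur.
loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=0)

images, labels = next(iter(loader))
print(images.shape, labels.shape, len(loader))
```

Loading in the main process is usually slower for large datasets, so this is best treated as a workaround rather than a default.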
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

batch_size = 4

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
Let us show some of the training images, for fun.
import matplotlib.pyplot as plt
import numpy as np

# functions to show an image
def imshow(img):
    img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()

# get some random training images
dataiter = iter(trainloader)
images, labels = next(dataiter)

# show images
imshow(torchvision.utils.make_grid(images))
# print labels
print(' '.join(f'{classes[labels[j]]:5s}' for j in range(batch_size)))

truck plane plane horse
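2. Define a Convolutional Neural Network#
Copy the neural network from the earlier Neural Networks section and modify it to take 3-channel images (instead of the 1-channel images it was originally defined for).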
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

net = Net()
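The `16 * 5 * 5` input size of the first linear layer follows from how each layer shrinks a 32x32 image. The short sketch below works through that arithmetic in plain Python; the helper name `conv_output_size` is hypothetical and only covers the unpadded, undilated case used here.

```python
def conv_output_size(size, kernel_size, stride=1, padding=0):
    """Spatial output size of a convolution or pooling layer (no dilation)."""
    return (size + 2 * padding - kernel_size) // stride + 1

size = 32                                   # CIFAR10 images are 32x32
size = conv_output_size(size, 5)            # conv1, 5x5 kernel -> 28
size = conv_output_size(size, 2, stride=2)  # 2x2 max pool      -> 14
size = conv_output_size(size, 5)            # conv2, 5x5 kernel -> 10
size = conv_output_size(size, 2, stride=2)  # 2x2 max pool      -> 5

flat_features = 16 * size * size            # conv2 produces 16 channels
print(size, flat_features)                  # prints: 5 400
```

If the kernel sizes or the input resolution change, this number changes with them, and the first linear layer has to be resized to match.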
3. Define a loss function and optimizer#
Let's use a classification cross-entropy loss and SGD with momentum.
import torch.optim as optim
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
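`nn.CrossEntropyLoss` takes the raw network outputs (logits) directly, so the network does not need a softmax layer of its own. The sketch below checks this on hypothetical logits by comparing the loss with a manual log-softmax followed by the mean negative log-likelihood of the target classes; the numbers are made up for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical logits for two samples and three classes.
logits = torch.tensor([[2.0, 0.5, -1.0],
                       [0.1, 0.2, 3.0]])
targets = torch.tensor([0, 2])

# Built-in loss, applied to raw logits.
loss = nn.CrossEntropyLoss()(logits, targets)

# Manual equivalent: log-softmax, pick the target entry, negate, average.
log_probs = F.log_softmax(logits, dim=1)
manual = -log_probs[torch.arange(len(targets)), targets].mean()

print(loss.item(), manual.item())
```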
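4. Train the network#
This is when things start to get interesting. We simply loop over our data iterator, feed the inputs to the network, and optimize.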
for epoch in range(2):  # loop over the dataset multiple times

    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if i % 2000 == 1999:    # print every 2000 mini-batches
            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}')
            running_loss = 0.0

print('Finished Training')
[1, 2000] loss: 2.171
[1, 4000] loss: 1.850
[1, 6000] loss: 1.662
[1, 8000] loss: 1.594
[1, 10000] loss: 1.540
[1, 12000] loss: 1.500
[2, 2000] loss: 1.422
[2, 4000] loss: 1.386
[2, 6000] loss: 1.369
[2, 8000] loss: 1.329
[2, 10000] loss: 1.310
[2, 12000] loss: 1.307
Finished Training
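Each pass through the inner loop performs the same four calls: clear the gradients, run the forward pass, backpropagate, and step the optimizer. As a rough sketch of what one such step does, the snippet below runs it once on a tiny hypothetical linear model with random data and confirms that the weights moved. The model, data, and names are assumptions, not the tutorial's network.

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)

# Tiny stand-in model and one random mini-batch of 4 samples.
model = nn.Linear(8, 3)
inputs = torch.randn(4, 8)
labels = torch.randint(0, 3, (4,))

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

weights_before = model.weight.detach().clone()

optimizer.zero_grad()                     # clear gradients from earlier steps
loss = criterion(model(inputs), labels)   # forward pass and loss
loss.backward()                           # fill .grad for every parameter
optimizer.step()                          # update parameters from .grad

weights_changed = not torch.equal(weights_before, model.weight)
print(loss.item(), weights_changed)
```

Skipping `optimizer.zero_grad()` would make gradients accumulate across mini-batches, because `backward()` adds to `.grad` instead of overwriting it.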
Let's quickly save our trained model:
PATH = './cifar_net.pth'
torch.save(net.state_dict(), PATH)
See the PyTorch documentation on saving and loading models for more details.
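A `state_dict` holds only the parameters and buffers, so loading it requires first constructing a model with the same architecture. The sketch below shows that round trip on a small hypothetical model, using an in-memory buffer in place of a file so that nothing is written to disk.

```python
import io

import torch
import torch.nn as nn

# Hypothetical stand-in for a trained network.
model = nn.Linear(4, 2)

# Save only the parameters, here to an in-memory buffer.
buffer = io.BytesIO()
torch.save(model.state_dict(), buffer)
buffer.seek(0)

# Rebuild the same architecture, then load the saved parameters into it.
restored = nn.Linear(4, 2)
result = restored.load_state_dict(torch.load(buffer, weights_only=True))

# Both models now produce identical outputs for the same input.
x = torch.randn(3, 4)
same_output = torch.equal(model(x), restored(x))
print(result, same_output)
```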
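5. Test the network on the test data#
We have trained the network for 2 passes over the training dataset. But we need to check whether the network has learnt anything at all.
We will check this by predicting the class label that the neural network outputs and comparing it against the ground truth. If the prediction is correct, we add the sample to the list of correct predictions.
Okay, first step. Let us display an image from the test set to get familiar.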
dataiter = iter(testloader)
images, labels = next(dataiter)
# print images
imshow(torchvision.utils.make_grid(images))
print('GroundTruth: ', ' '.join(f'{classes[labels[j]]:5s}' for j in range(4)))

GroundTruth: cat ship ship plane
Next, let's load back in our saved model (note: saving and re-loading the model wasn't necessary here; we only did it to illustrate how to do so):
net = Net()
net.load_state_dict(torch.load(PATH, weights_only=True))
<All keys matched successfully>
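Okay, now let us see what the neural network thinks these examples above are:
The outputs are energies for the 10 classes. The higher the energy for a class, the more the network thinks that the image belongs to that class. So, let's get the index of the highest energy: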
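As a sketch of this step, the snippet below applies `torch.max` along the class dimension to a hypothetical `outputs` tensor for two images and maps the winning indices to class names. The scores are made up; in the tutorial's setting, `outputs` would come from `net(images)`, as in the evaluation loops later in this section.

```python
import torch

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

# Hypothetical energies: one row per image, one column per class.
outputs = torch.tensor([
    [0.1, -1.2, 0.3, 2.5, 0.0, 1.1, 0.4, -0.3, -0.8, -1.0],
    [1.9, 0.2, -0.5, -0.7, -1.1, -0.9, -1.3, -0.4, 2.2, 0.6],
])

# torch.max over dim 1 returns (max values, indices); the indices are
# the predicted class labels.
_, predicted = torch.max(outputs, 1)

predicted_names = [classes[i] for i in predicted.tolist()]
print('Predicted: ', ' '.join(f'{name:5s}' for name in predicted_names))
```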
Predicted: cat car car plane
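The results look reasonable: two of the four predictions (cat and plane) match the ground truth, while both ships were classified as cars.
Let us look at how the network performs on the whole dataset.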
correct = 0
total = 0
# since we're not training, we don't need to calculate the gradients for our outputs
with torch.no_grad():
    for data in testloader:
        images, labels = data
        # calculate outputs by running images through the network
        outputs = net(images)
        # the class with the highest energy is what we choose as prediction
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Accuracy of the network on the 10000 test images: {100 * correct // total} %')
Accuracy of the network on the 10000 test images: 56 %
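That looks way better than chance, which is 10% accuracy (randomly picking one class out of 10). It seems the network learnt something.
Hmmm, which classes performed well, and which did not?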
# prepare to count predictions for each class
correct_pred = {classname: 0 for classname in classes}
total_pred = {classname: 0 for classname in classes}

# again no gradients needed
with torch.no_grad():
    for data in testloader:
        images, labels = data
        outputs = net(images)
        _, predictions = torch.max(outputs, 1)
        # collect the correct predictions for each class
        for label, prediction in zip(labels, predictions):
            if label == prediction:
                correct_pred[classes[label]] += 1
            total_pred[classes[label]] += 1

# print accuracy for each class
for classname, correct_count in correct_pred.items():
    accuracy = 100 * float(correct_count) / total_pred[classname]
    print(f'Accuracy for class: {classname:5s} is {accuracy:.1f} %')
Accuracy for class: plane is 64.7 %
Accuracy for class: car is 71.7 %
Accuracy for class: bird is 42.9 %
Accuracy for class: cat is 40.5 %
Accuracy for class: deer is 52.4 %
Accuracy for class: dog is 53.2 %
Accuracy for class: frog is 61.4 %
Accuracy for class: horse is 65.0 %
Accuracy for class: ship is 57.4 %
Accuracy for class: truck is 55.5 %
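Okay, so what next?
How do we run these neural networks on the GPU?
Training on GPU#
Just like how you transfer a Tensor onto the GPU, you transfer the neural net onto the GPU.
Let's first define our device as the first visible CUDA device, if CUDA is available: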
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# Assuming that we are on a CUDA machine, this should print a CUDA device:
print(device)
cuda:0
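The rest of this section assumes that `device` is a CUDA device.
Calling `net.to(device)` then recursively goes over all modules and converts their parameters and buffers to CUDA tensors: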
net.to(device)
Remember that you will have to send the inputs and targets to the GPU at every step too.
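A minimal sketch of that pattern is shown below, with a small hypothetical model and a random batch standing in for the tutorial's network and data loader. It falls back to the CPU when CUDA is unavailable, so it runs on any machine.

```python
import torch
import torch.nn as nn

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Hypothetical stand-in model, moved to the chosen device once.
model = nn.Linear(4, 2).to(device)

# Stand-in for one batch from a data loader; batches start on the CPU.
inputs = torch.randn(3, 4)
labels = torch.tensor([0, 1, 0])

# Tensor.to returns a new tensor, so the result has to be reassigned.
inputs, labels = inputs.to(device), labels.to(device)

outputs = model(inputs)
print(outputs.device, labels.device)
```

Unlike `nn.Module.to`, which moves a module's parameters in place, `Tensor.to` does not modify the original tensor. Forgetting the reassignment typically leads to a device-mismatch error in the forward pass.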
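Why don't I notice a MASSIVE speedup compared to CPU? Because your network is really small.
Exercise: Try increasing the width of your network (argument 2 of the first `nn.Conv2d` and argument 1 of the second `nn.Conv2d`; they need to be the same number) and see what kind of speedup you get.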
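One possible way to set up the exercise is to expose the shared channel count as a constructor argument, so the two values cannot drift apart. The class name `WideNet` and the `width` parameter below are hypothetical; the rest mirrors the layer sizes used in this tutorial.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class WideNet(nn.Module):
    def __init__(self, width=64):
        super().__init__()
        # conv1's output channels and conv2's input channels must match.
        self.conv1 = nn.Conv2d(3, width, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(width, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

wide_net = WideNet(width=64)
out = wide_net(torch.randn(2, 3, 32, 32))  # random stand-in batch
print(out.shape)
```

Because only the channel count between the two convolutions changes, the spatial sizes and the `16 * 5 * 5` input of the first linear layer stay the same.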
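Goals achieved:
Understanding PyTorch's Tensor library and neural networks at a high level.
Train a small neural network to classify images.
Training on multiple GPUs#
If you want to see even more MASSIVE speedup using all of your GPUs, please check out Optional: Data Parallelism.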
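As a rough sketch of the idea, `nn.DataParallel` wraps a module so that each input batch is split across the visible GPUs. The model below is a hypothetical stand-in, and the wrapper is only applied when more than one GPU is present, so the snippet also runs on a CPU-only machine.

```python
import torch
import torch.nn as nn

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Hypothetical stand-in for the network.
model = nn.Linear(8, 2)

# Wrap the model only when several GPUs are visible.
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)

model.to(device)

# The batch dimension (6 here) is what gets split across the GPUs.
outputs = model(torch.randn(6, 8).to(device))
print(outputs.shape)
```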
Where do I go next?#
del dataiter
Total running time of the script: (1 minute 57.226 seconds)