評價此頁

修剪教程#

創建於: 2019年7月22日 | 最後更新: 2023年11月02日 | 最後驗證: 2024年11月05日

作者: Michela Paganini

最先進的深度學習技術依賴於難以部署的過度引數化模型。相反,生物神經網路以使用高效的稀疏連線而聞名。識別透過減少模型中的引數數量來壓縮模型的最佳技術非常重要,以減少記憶體、電池和硬體消耗,而又不犧牲準確性。這反過來又允許您在裝置上部署輕量級模型,並透過私有的裝置內計算來保證隱私。在研究方面,修剪被用於研究過度引數化和欠引數化網路之間學習動態的差異,研究幸運稀疏子網路和初始化的作用(“彩票”),作為一種破壞性的神經架構搜尋技術,等等。

在本教程中,您將學習如何使用 torch.nn.utils.prune 來稀疏化您的神經網路,以及如何擴充套件它來實現您自己的自定義修剪技術。

要求#

"torch>=1.4.0a0+8e8a5e0"

import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F

建立模型#

在本教程中,我們使用了 LeCun 等人 (1998) 的 LeNet 架構。

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        # 1 input image channel, 6 output channels, 5x5 square conv kernel
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 5x5 image dimension
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, int(x.nelement() / x.shape[0]))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

model = LeNet().to(device=device)

檢查模組#

讓我們檢查一下 LeNet 模型中(未修剪的)conv1 層。它將包含兩個引數 weightbias,目前沒有緩衝區。

[('weight', Parameter containing:
tensor([[[[ 0.1080,  0.1777, -0.0219,  0.1556, -0.1139],
          [ 0.1726, -0.0548, -0.0320,  0.0731, -0.1833],
          [-0.1037,  0.1749,  0.0621,  0.1258, -0.0840],
          [ 0.1834,  0.0396,  0.1947, -0.0930, -0.1996],
          [-0.1061, -0.0661,  0.0420, -0.1807,  0.1205]]],


        [[[-0.1610, -0.1277,  0.0102, -0.0191, -0.0627],
          [-0.1233, -0.0103,  0.0556,  0.0748,  0.1583],
          [ 0.0701, -0.1686,  0.0733, -0.1530, -0.0384],
          [-0.1136, -0.0863,  0.0755, -0.1585, -0.1921],
          [-0.0318,  0.1514,  0.1999,  0.0979,  0.0559]]],


        [[[ 0.0826, -0.1019, -0.1807, -0.0031,  0.1562],
          [ 0.0134,  0.0204, -0.0599, -0.0034,  0.0462],
          [ 0.1143, -0.0257, -0.0628, -0.1107, -0.0187],
          [-0.1300, -0.1447,  0.0057, -0.0971, -0.1935],
          [-0.1217, -0.1738,  0.1224, -0.1521,  0.0138]]],


        [[[-0.0396, -0.1639,  0.1371, -0.1733, -0.0824],
          [-0.0278, -0.1693,  0.0440,  0.1116, -0.0702],
          [ 0.0930, -0.1650,  0.1249, -0.0173, -0.0074],
          [ 0.1675,  0.0054, -0.1918, -0.0846, -0.0560],
          [-0.1026,  0.1980, -0.1918,  0.0841,  0.1897]]],


        [[[-0.0385,  0.1232,  0.1315,  0.1062, -0.0976],
          [ 0.1838, -0.1291,  0.1153,  0.1173,  0.0644],
          [-0.1098, -0.1352,  0.1762,  0.0470,  0.1758],
          [ 0.1444, -0.1419,  0.1106,  0.0789,  0.0470],
          [ 0.0996,  0.0549,  0.0470,  0.1610,  0.1657]]],


        [[[ 0.0974, -0.1663, -0.1839,  0.1924, -0.0193],
          [ 0.0538,  0.0496, -0.1254,  0.0740, -0.1996],
          [-0.0378,  0.0121,  0.1558, -0.1539, -0.1766],
          [-0.1681,  0.0488,  0.1711,  0.1994, -0.0155],
          [-0.1179,  0.0486,  0.1481, -0.0658, -0.0872]]]], device='cuda:0',
       requires_grad=True)), ('bias', Parameter containing:
tensor([ 0.1364, -0.0281, -0.1993,  0.1291, -0.1555, -0.1203], device='cuda:0',
       requires_grad=True))]
print(list(module.named_buffers()))
[]

修剪模組#

要修剪一個模組(在本例中是 LeNet 架構的 conv1 層),首先從 torch.nn.utils.prune 中提供的修剪技術中選擇一種(或透過繼承 BasePruningMethod 實現 您自己的)。然後,指定要修剪的模組和該模組內參數的名稱。最後,使用所選修剪技術所需的適當關鍵字引數,指定修剪引數。

在此示例中,我們將隨機修剪 conv1 層中名為 weight 的引數的 30% 的連線。模組作為函式的第一個引數傳遞;name 使用其字串識別符號標識該模組內的引數;amount 指示要修剪的連線的百分比(如果它是介於 0. 和 1. 之間的浮點數),或者要修剪的連線的絕對數量(如果它是一個非負整數)。

prune.random_unstructured(module, name="weight", amount=0.3)
Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))

修剪透過從引數中刪除 weight 並用名為 weight_orig 的新引數替換它(即在初始引數 name 後面追加 "_orig")來起作用。weight_orig 儲存了未修剪的張量版本。bias 未被修剪,因此它將保持不變。

print(list(module.named_parameters()))
[('bias', Parameter containing:
tensor([ 0.1364, -0.0281, -0.1993,  0.1291, -0.1555, -0.1203], device='cuda:0',
       requires_grad=True)), ('weight_orig', Parameter containing:
tensor([[[[ 0.1080,  0.1777, -0.0219,  0.1556, -0.1139],
          [ 0.1726, -0.0548, -0.0320,  0.0731, -0.1833],
          [-0.1037,  0.1749,  0.0621,  0.1258, -0.0840],
          [ 0.1834,  0.0396,  0.1947, -0.0930, -0.1996],
          [-0.1061, -0.0661,  0.0420, -0.1807,  0.1205]]],


        [[[-0.1610, -0.1277,  0.0102, -0.0191, -0.0627],
          [-0.1233, -0.0103,  0.0556,  0.0748,  0.1583],
          [ 0.0701, -0.1686,  0.0733, -0.1530, -0.0384],
          [-0.1136, -0.0863,  0.0755, -0.1585, -0.1921],
          [-0.0318,  0.1514,  0.1999,  0.0979,  0.0559]]],


        [[[ 0.0826, -0.1019, -0.1807, -0.0031,  0.1562],
          [ 0.0134,  0.0204, -0.0599, -0.0034,  0.0462],
          [ 0.1143, -0.0257, -0.0628, -0.1107, -0.0187],
          [-0.1300, -0.1447,  0.0057, -0.0971, -0.1935],
          [-0.1217, -0.1738,  0.1224, -0.1521,  0.0138]]],


        [[[-0.0396, -0.1639,  0.1371, -0.1733, -0.0824],
          [-0.0278, -0.1693,  0.0440,  0.1116, -0.0702],
          [ 0.0930, -0.1650,  0.1249, -0.0173, -0.0074],
          [ 0.1675,  0.0054, -0.1918, -0.0846, -0.0560],
          [-0.1026,  0.1980, -0.1918,  0.0841,  0.1897]]],


        [[[-0.0385,  0.1232,  0.1315,  0.1062, -0.0976],
          [ 0.1838, -0.1291,  0.1153,  0.1173,  0.0644],
          [-0.1098, -0.1352,  0.1762,  0.0470,  0.1758],
          [ 0.1444, -0.1419,  0.1106,  0.0789,  0.0470],
          [ 0.0996,  0.0549,  0.0470,  0.1610,  0.1657]]],


        [[[ 0.0974, -0.1663, -0.1839,  0.1924, -0.0193],
          [ 0.0538,  0.0496, -0.1254,  0.0740, -0.1996],
          [-0.0378,  0.0121,  0.1558, -0.1539, -0.1766],
          [-0.1681,  0.0488,  0.1711,  0.1994, -0.0155],
          [-0.1179,  0.0486,  0.1481, -0.0658, -0.0872]]]], device='cuda:0',
       requires_grad=True))]

由上述修剪技術生成的修剪掩碼將作為名為 weight_mask 的模組緩衝區儲存(即在初始引數 name 後面追加 "_mask")。

print(list(module.named_buffers()))
[('weight_mask', tensor([[[[0., 1., 1., 0., 1.],
          [1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 0.],
          [1., 0., 0., 1., 0.],
          [1., 1., 1., 1., 1.]]],


        [[[0., 0., 1., 1., 1.],
          [0., 1., 0., 1., 1.],
          [1., 1., 1., 1., 1.],
          [1., 0., 1., 1., 1.],
          [0., 0., 1., 1., 1.]]],


        [[[1., 0., 1., 1., 1.],
          [1., 1., 1., 1., 1.],
          [1., 0., 0., 0., 1.],
          [1., 1., 1., 0., 1.],
          [0., 0., 1., 1., 1.]]],


        [[[1., 0., 0., 1., 0.],
          [1., 1., 0., 0., 1.],
          [0., 0., 1., 0., 1.],
          [1., 1., 1., 1., 1.],
          [1., 1., 1., 0., 1.]]],


        [[[1., 0., 1., 0., 1.],
          [1., 0., 0., 1., 0.],
          [0., 1., 1., 1., 1.],
          [1., 1., 1., 1., 0.],
          [1., 1., 1., 1., 1.]]],


        [[[0., 0., 0., 1., 1.],
          [0., 0., 1., 1., 1.],
          [0., 1., 0., 0., 1.],
          [1., 1., 1., 1., 1.],
          [1., 1., 0., 1., 1.]]]], device='cuda:0'))]

為了使前向傳播能夠正常工作而無需修改,需要存在 weight 屬性。在 torch.nn.utils.prune 中實現的修剪技術會計算權重的修剪版本(透過將掩碼與原始引數組合),並將它們儲存在 weight 屬性中。請注意,這不再是 module 的引數,現在它只是一個屬性。

tensor([[[[ 0.0000,  0.1777, -0.0219,  0.0000, -0.1139],
          [ 0.1726, -0.0548, -0.0320,  0.0731, -0.1833],
          [-0.1037,  0.1749,  0.0621,  0.1258, -0.0000],
          [ 0.1834,  0.0000,  0.0000, -0.0930, -0.0000],
          [-0.1061, -0.0661,  0.0420, -0.1807,  0.1205]]],


        [[[-0.0000, -0.0000,  0.0102, -0.0191, -0.0627],
          [-0.0000, -0.0103,  0.0000,  0.0748,  0.1583],
          [ 0.0701, -0.1686,  0.0733, -0.1530, -0.0384],
          [-0.1136, -0.0000,  0.0755, -0.1585, -0.1921],
          [-0.0000,  0.0000,  0.1999,  0.0979,  0.0559]]],


        [[[ 0.0826, -0.0000, -0.1807, -0.0031,  0.1562],
          [ 0.0134,  0.0204, -0.0599, -0.0034,  0.0462],
          [ 0.1143, -0.0000, -0.0000, -0.0000, -0.0187],
          [-0.1300, -0.1447,  0.0057, -0.0000, -0.1935],
          [-0.0000, -0.0000,  0.1224, -0.1521,  0.0138]]],


        [[[-0.0396, -0.0000,  0.0000, -0.1733, -0.0000],
          [-0.0278, -0.1693,  0.0000,  0.0000, -0.0702],
          [ 0.0000, -0.0000,  0.1249, -0.0000, -0.0074],
          [ 0.1675,  0.0054, -0.1918, -0.0846, -0.0560],
          [-0.1026,  0.1980, -0.1918,  0.0000,  0.1897]]],


        [[[-0.0385,  0.0000,  0.1315,  0.0000, -0.0976],
          [ 0.1838, -0.0000,  0.0000,  0.1173,  0.0000],
          [-0.0000, -0.1352,  0.1762,  0.0470,  0.1758],
          [ 0.1444, -0.1419,  0.1106,  0.0789,  0.0000],
          [ 0.0996,  0.0549,  0.0470,  0.1610,  0.1657]]],


        [[[ 0.0000, -0.0000, -0.0000,  0.1924, -0.0193],
          [ 0.0000,  0.0000, -0.1254,  0.0740, -0.1996],
          [-0.0000,  0.0121,  0.0000, -0.0000, -0.1766],
          [-0.1681,  0.0488,  0.1711,  0.1994, -0.0155],
          [-0.1179,  0.0486,  0.0000, -0.0658, -0.0872]]]], device='cuda:0',
       grad_fn=<MulBackward0>)

最後,使用 PyTorch 的 forward_pre_hooks 在每次前向傳播之前應用修剪。具體來說,當 module 被修剪時,正如我們在這裡所做的那樣,它將為與之關聯的每個被修剪的引數獲取一個 forward_pre_hook。在這種情況下,由於我們到目前為止只修剪了名為 weight 的原始引數,因此只會存在一個鉤子。

print(module._forward_pre_hooks)
OrderedDict([(0, <torch.nn.utils.prune.RandomUnstructured object at 0x7f5e194cb040>)])

為了完整起見,我們現在也可以修剪 bias,以瞭解模組的引數、緩衝區、鉤子和屬性如何變化。僅為了嘗試另一種修剪技術,這裡我們根據 L1 範數修剪 bias 中的 3 個最小的條目,這在 l1_unstructured 修剪函式中實現。

prune.l1_unstructured(module, name="bias", amount=3)
Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))

我們現在預計命名引數將包括 weight_orig(來自之前)和 bias_orig。緩衝區將包括 weight_maskbias_mask。兩個張量的修剪版本將作為模組屬性存在,並且該模組現在將具有兩個 forward_pre_hooks

print(list(module.named_parameters()))
[('weight_orig', Parameter containing:
tensor([[[[ 0.1080,  0.1777, -0.0219,  0.1556, -0.1139],
          [ 0.1726, -0.0548, -0.0320,  0.0731, -0.1833],
          [-0.1037,  0.1749,  0.0621,  0.1258, -0.0840],
          [ 0.1834,  0.0396,  0.1947, -0.0930, -0.1996],
          [-0.1061, -0.0661,  0.0420, -0.1807,  0.1205]]],


        [[[-0.1610, -0.1277,  0.0102, -0.0191, -0.0627],
          [-0.1233, -0.0103,  0.0556,  0.0748,  0.1583],
          [ 0.0701, -0.1686,  0.0733, -0.1530, -0.0384],
          [-0.1136, -0.0863,  0.0755, -0.1585, -0.1921],
          [-0.0318,  0.1514,  0.1999,  0.0979,  0.0559]]],


        [[[ 0.0826, -0.1019, -0.1807, -0.0031,  0.1562],
          [ 0.0134,  0.0204, -0.0599, -0.0034,  0.0462],
          [ 0.1143, -0.0257, -0.0628, -0.1107, -0.0187],
          [-0.1300, -0.1447,  0.0057, -0.0971, -0.1935],
          [-0.1217, -0.1738,  0.1224, -0.1521,  0.0138]]],


        [[[-0.0396, -0.1639,  0.1371, -0.1733, -0.0824],
          [-0.0278, -0.1693,  0.0440,  0.1116, -0.0702],
          [ 0.0930, -0.1650,  0.1249, -0.0173, -0.0074],
          [ 0.1675,  0.0054, -0.1918, -0.0846, -0.0560],
          [-0.1026,  0.1980, -0.1918,  0.0841,  0.1897]]],


        [[[-0.0385,  0.1232,  0.1315,  0.1062, -0.0976],
          [ 0.1838, -0.1291,  0.1153,  0.1173,  0.0644],
          [-0.1098, -0.1352,  0.1762,  0.0470,  0.1758],
          [ 0.1444, -0.1419,  0.1106,  0.0789,  0.0470],
          [ 0.0996,  0.0549,  0.0470,  0.1610,  0.1657]]],


        [[[ 0.0974, -0.1663, -0.1839,  0.1924, -0.0193],
          [ 0.0538,  0.0496, -0.1254,  0.0740, -0.1996],
          [-0.0378,  0.0121,  0.1558, -0.1539, -0.1766],
          [-0.1681,  0.0488,  0.1711,  0.1994, -0.0155],
          [-0.1179,  0.0486,  0.1481, -0.0658, -0.0872]]]], device='cuda:0',
       requires_grad=True)), ('bias_orig', Parameter containing:
tensor([ 0.1364, -0.0281, -0.1993,  0.1291, -0.1555, -0.1203], device='cuda:0',
       requires_grad=True))]
print(list(module.named_buffers()))
[('weight_mask', tensor([[[[0., 1., 1., 0., 1.],
          [1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 0.],
          [1., 0., 0., 1., 0.],
          [1., 1., 1., 1., 1.]]],


        [[[0., 0., 1., 1., 1.],
          [0., 1., 0., 1., 1.],
          [1., 1., 1., 1., 1.],
          [1., 0., 1., 1., 1.],
          [0., 0., 1., 1., 1.]]],


        [[[1., 0., 1., 1., 1.],
          [1., 1., 1., 1., 1.],
          [1., 0., 0., 0., 1.],
          [1., 1., 1., 0., 1.],
          [0., 0., 1., 1., 1.]]],


        [[[1., 0., 0., 1., 0.],
          [1., 1., 0., 0., 1.],
          [0., 0., 1., 0., 1.],
          [1., 1., 1., 1., 1.],
          [1., 1., 1., 0., 1.]]],


        [[[1., 0., 1., 0., 1.],
          [1., 0., 0., 1., 0.],
          [0., 1., 1., 1., 1.],
          [1., 1., 1., 1., 0.],
          [1., 1., 1., 1., 1.]]],


        [[[0., 0., 0., 1., 1.],
          [0., 0., 1., 1., 1.],
          [0., 1., 0., 0., 1.],
          [1., 1., 1., 1., 1.],
          [1., 1., 0., 1., 1.]]]], device='cuda:0')), ('bias_mask', tensor([1., 0., 1., 0., 1., 0.], device='cuda:0'))]
print(module.bias)
tensor([ 0.1364, -0.0000, -0.1993,  0.0000, -0.1555, -0.0000], device='cuda:0',
       grad_fn=<MulBackward0>)
print(module._forward_pre_hooks)
OrderedDict([(0, <torch.nn.utils.prune.RandomUnstructured object at 0x7f5e194cb040>), (1, <torch.nn.utils.prune.L1Unstructured object at 0x7f5e194cb0d0>)])

迭代修剪#

模組中的同一引數可以被多次修剪,各種修剪呼叫的效果等同於一系列應用各種掩碼的組合。新掩碼與舊掩碼的組合由 PruningContainercompute_mask 方法處理。

例如,假設我們現在想進一步修剪 module.weight,這次使用沿張量第 0 軸的結構化修剪(第 0 軸對應卷積層的輸出通道,對於 conv1 維度為 6),基於通道的 L2 範數。這可以透過 ln_structured 函式實現,其中 n=2dim=0

prune.ln_structured(module, name="weight", amount=0.5, n=2, dim=0)

# As we can verify, this will zero out all the connections corresponding to
# 50% (3 out of 6) of the channels, while preserving the action of the
# previous mask.
print(module.weight)
tensor([[[[ 0.0000,  0.1777, -0.0219,  0.0000, -0.1139],
          [ 0.1726, -0.0548, -0.0320,  0.0731, -0.1833],
          [-0.1037,  0.1749,  0.0621,  0.1258, -0.0000],
          [ 0.1834,  0.0000,  0.0000, -0.0930, -0.0000],
          [-0.1061, -0.0661,  0.0420, -0.1807,  0.1205]]],


        [[[-0.0000, -0.0000,  0.0000, -0.0000, -0.0000],
          [-0.0000, -0.0000,  0.0000,  0.0000,  0.0000],
          [ 0.0000, -0.0000,  0.0000, -0.0000, -0.0000],
          [-0.0000, -0.0000,  0.0000, -0.0000, -0.0000],
          [-0.0000,  0.0000,  0.0000,  0.0000,  0.0000]]],


        [[[ 0.0000, -0.0000, -0.0000, -0.0000,  0.0000],
          [ 0.0000,  0.0000, -0.0000, -0.0000,  0.0000],
          [ 0.0000, -0.0000, -0.0000, -0.0000, -0.0000],
          [-0.0000, -0.0000,  0.0000, -0.0000, -0.0000],
          [-0.0000, -0.0000,  0.0000, -0.0000,  0.0000]]],


        [[[-0.0396, -0.0000,  0.0000, -0.1733, -0.0000],
          [-0.0278, -0.1693,  0.0000,  0.0000, -0.0702],
          [ 0.0000, -0.0000,  0.1249, -0.0000, -0.0074],
          [ 0.1675,  0.0054, -0.1918, -0.0846, -0.0560],
          [-0.1026,  0.1980, -0.1918,  0.0000,  0.1897]]],


        [[[-0.0385,  0.0000,  0.1315,  0.0000, -0.0976],
          [ 0.1838, -0.0000,  0.0000,  0.1173,  0.0000],
          [-0.0000, -0.1352,  0.1762,  0.0470,  0.1758],
          [ 0.1444, -0.1419,  0.1106,  0.0789,  0.0000],
          [ 0.0996,  0.0549,  0.0470,  0.1610,  0.1657]]],


        [[[ 0.0000, -0.0000, -0.0000,  0.0000, -0.0000],
          [ 0.0000,  0.0000, -0.0000,  0.0000, -0.0000],
          [-0.0000,  0.0000,  0.0000, -0.0000, -0.0000],
          [-0.0000,  0.0000,  0.0000,  0.0000, -0.0000],
          [-0.0000,  0.0000,  0.0000, -0.0000, -0.0000]]]], device='cuda:0',
       grad_fn=<MulBackward0>)

相應的鉤子現在將是 torch.nn.utils.prune.PruningContainer 型別,並將儲存應用於 weight 引數的修剪歷史。

for hook in module._forward_pre_hooks.values():
    if hook._tensor_name == "weight":  # select out the correct hook
        break

print(list(hook))  # pruning history in the container
[<torch.nn.utils.prune.RandomUnstructured object at 0x7f5e194cb040>, <torch.nn.utils.prune.LnStructured object at 0x7f5e194c88e0>]

序列化修剪後的模型#

所有相關的張量,包括掩碼緩衝區和用於計算修剪後的張量的原始引數,都儲存在模型的 state_dict 中,因此如果需要,可以輕鬆地進行序列化和儲存。

print(model.state_dict().keys())
odict_keys(['conv1.weight_orig', 'conv1.bias_orig', 'conv1.weight_mask', 'conv1.bias_mask', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias', 'fc3.weight', 'fc3.bias'])

移除修剪重引數化#

為了使修剪永久化,移除關於 weight_origweight_mask 的重引數化,並移除 forward_pre_hook,我們可以使用 torch.nn.utils.prune 中的 remove 功能。請注意,這並不會撤消修剪,好像它從未發生過一樣。相反,它透過將 weight 引數重新分配給模型引數(在其修剪後的版本中)來使其永久化。

在移除重引數化之前

print(list(module.named_parameters()))
[('weight_orig', Parameter containing:
tensor([[[[ 0.1080,  0.1777, -0.0219,  0.1556, -0.1139],
          [ 0.1726, -0.0548, -0.0320,  0.0731, -0.1833],
          [-0.1037,  0.1749,  0.0621,  0.1258, -0.0840],
          [ 0.1834,  0.0396,  0.1947, -0.0930, -0.1996],
          [-0.1061, -0.0661,  0.0420, -0.1807,  0.1205]]],


        [[[-0.1610, -0.1277,  0.0102, -0.0191, -0.0627],
          [-0.1233, -0.0103,  0.0556,  0.0748,  0.1583],
          [ 0.0701, -0.1686,  0.0733, -0.1530, -0.0384],
          [-0.1136, -0.0863,  0.0755, -0.1585, -0.1921],
          [-0.0318,  0.1514,  0.1999,  0.0979,  0.0559]]],


        [[[ 0.0826, -0.1019, -0.1807, -0.0031,  0.1562],
          [ 0.0134,  0.0204, -0.0599, -0.0034,  0.0462],
          [ 0.1143, -0.0257, -0.0628, -0.1107, -0.0187],
          [-0.1300, -0.1447,  0.0057, -0.0971, -0.1935],
          [-0.1217, -0.1738,  0.1224, -0.1521,  0.0138]]],


        [[[-0.0396, -0.1639,  0.1371, -0.1733, -0.0824],
          [-0.0278, -0.1693,  0.0440,  0.1116, -0.0702],
          [ 0.0930, -0.1650,  0.1249, -0.0173, -0.0074],
          [ 0.1675,  0.0054, -0.1918, -0.0846, -0.0560],
          [-0.1026,  0.1980, -0.1918,  0.0841,  0.1897]]],


        [[[-0.0385,  0.1232,  0.1315,  0.1062, -0.0976],
          [ 0.1838, -0.1291,  0.1153,  0.1173,  0.0644],
          [-0.1098, -0.1352,  0.1762,  0.0470,  0.1758],
          [ 0.1444, -0.1419,  0.1106,  0.0789,  0.0470],
          [ 0.0996,  0.0549,  0.0470,  0.1610,  0.1657]]],


        [[[ 0.0974, -0.1663, -0.1839,  0.1924, -0.0193],
          [ 0.0538,  0.0496, -0.1254,  0.0740, -0.1996],
          [-0.0378,  0.0121,  0.1558, -0.1539, -0.1766],
          [-0.1681,  0.0488,  0.1711,  0.1994, -0.0155],
          [-0.1179,  0.0486,  0.1481, -0.0658, -0.0872]]]], device='cuda:0',
       requires_grad=True)), ('bias_orig', Parameter containing:
tensor([ 0.1364, -0.0281, -0.1993,  0.1291, -0.1555, -0.1203], device='cuda:0',
       requires_grad=True))]
print(list(module.named_buffers()))
[('weight_mask', tensor([[[[0., 1., 1., 0., 1.],
          [1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 0.],
          [1., 0., 0., 1., 0.],
          [1., 1., 1., 1., 1.]]],


        [[[0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.]]],


        [[[0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.]]],


        [[[1., 0., 0., 1., 0.],
          [1., 1., 0., 0., 1.],
          [0., 0., 1., 0., 1.],
          [1., 1., 1., 1., 1.],
          [1., 1., 1., 0., 1.]]],


        [[[1., 0., 1., 0., 1.],
          [1., 0., 0., 1., 0.],
          [0., 1., 1., 1., 1.],
          [1., 1., 1., 1., 0.],
          [1., 1., 1., 1., 1.]]],


        [[[0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.]]]], device='cuda:0')), ('bias_mask', tensor([1., 0., 1., 0., 1., 0.], device='cuda:0'))]
tensor([[[[ 0.0000,  0.1777, -0.0219,  0.0000, -0.1139],
          [ 0.1726, -0.0548, -0.0320,  0.0731, -0.1833],
          [-0.1037,  0.1749,  0.0621,  0.1258, -0.0000],
          [ 0.1834,  0.0000,  0.0000, -0.0930, -0.0000],
          [-0.1061, -0.0661,  0.0420, -0.1807,  0.1205]]],


        [[[-0.0000, -0.0000,  0.0000, -0.0000, -0.0000],
          [-0.0000, -0.0000,  0.0000,  0.0000,  0.0000],
          [ 0.0000, -0.0000,  0.0000, -0.0000, -0.0000],
          [-0.0000, -0.0000,  0.0000, -0.0000, -0.0000],
          [-0.0000,  0.0000,  0.0000,  0.0000,  0.0000]]],


        [[[ 0.0000, -0.0000, -0.0000, -0.0000,  0.0000],
          [ 0.0000,  0.0000, -0.0000, -0.0000,  0.0000],
          [ 0.0000, -0.0000, -0.0000, -0.0000, -0.0000],
          [-0.0000, -0.0000,  0.0000, -0.0000, -0.0000],
          [-0.0000, -0.0000,  0.0000, -0.0000,  0.0000]]],


        [[[-0.0396, -0.0000,  0.0000, -0.1733, -0.0000],
          [-0.0278, -0.1693,  0.0000,  0.0000, -0.0702],
          [ 0.0000, -0.0000,  0.1249, -0.0000, -0.0074],
          [ 0.1675,  0.0054, -0.1918, -0.0846, -0.0560],
          [-0.1026,  0.1980, -0.1918,  0.0000,  0.1897]]],


        [[[-0.0385,  0.0000,  0.1315,  0.0000, -0.0976],
          [ 0.1838, -0.0000,  0.0000,  0.1173,  0.0000],
          [-0.0000, -0.1352,  0.1762,  0.0470,  0.1758],
          [ 0.1444, -0.1419,  0.1106,  0.0789,  0.0000],
          [ 0.0996,  0.0549,  0.0470,  0.1610,  0.1657]]],


        [[[ 0.0000, -0.0000, -0.0000,  0.0000, -0.0000],
          [ 0.0000,  0.0000, -0.0000,  0.0000, -0.0000],
          [-0.0000,  0.0000,  0.0000, -0.0000, -0.0000],
          [-0.0000,  0.0000,  0.0000,  0.0000, -0.0000],
          [-0.0000,  0.0000,  0.0000, -0.0000, -0.0000]]]], device='cuda:0',
       grad_fn=<MulBackward0>)

移除重引數化之後

prune.remove(module, 'weight')
print(list(module.named_parameters()))
[('bias_orig', Parameter containing:
tensor([ 0.1364, -0.0281, -0.1993,  0.1291, -0.1555, -0.1203], device='cuda:0',
       requires_grad=True)), ('weight', Parameter containing:
tensor([[[[ 0.0000,  0.1777, -0.0219,  0.0000, -0.1139],
          [ 0.1726, -0.0548, -0.0320,  0.0731, -0.1833],
          [-0.1037,  0.1749,  0.0621,  0.1258, -0.0000],
          [ 0.1834,  0.0000,  0.0000, -0.0930, -0.0000],
          [-0.1061, -0.0661,  0.0420, -0.1807,  0.1205]]],


        [[[-0.0000, -0.0000,  0.0000, -0.0000, -0.0000],
          [-0.0000, -0.0000,  0.0000,  0.0000,  0.0000],
          [ 0.0000, -0.0000,  0.0000, -0.0000, -0.0000],
          [-0.0000, -0.0000,  0.0000, -0.0000, -0.0000],
          [-0.0000,  0.0000,  0.0000,  0.0000,  0.0000]]],


        [[[ 0.0000, -0.0000, -0.0000, -0.0000,  0.0000],
          [ 0.0000,  0.0000, -0.0000, -0.0000,  0.0000],
          [ 0.0000, -0.0000, -0.0000, -0.0000, -0.0000],
          [-0.0000, -0.0000,  0.0000, -0.0000, -0.0000],
          [-0.0000, -0.0000,  0.0000, -0.0000,  0.0000]]],


        [[[-0.0396, -0.0000,  0.0000, -0.1733, -0.0000],
          [-0.0278, -0.1693,  0.0000,  0.0000, -0.0702],
          [ 0.0000, -0.0000,  0.1249, -0.0000, -0.0074],
          [ 0.1675,  0.0054, -0.1918, -0.0846, -0.0560],
          [-0.1026,  0.1980, -0.1918,  0.0000,  0.1897]]],


        [[[-0.0385,  0.0000,  0.1315,  0.0000, -0.0976],
          [ 0.1838, -0.0000,  0.0000,  0.1173,  0.0000],
          [-0.0000, -0.1352,  0.1762,  0.0470,  0.1758],
          [ 0.1444, -0.1419,  0.1106,  0.0789,  0.0000],
          [ 0.0996,  0.0549,  0.0470,  0.1610,  0.1657]]],


        [[[ 0.0000, -0.0000, -0.0000,  0.0000, -0.0000],
          [ 0.0000,  0.0000, -0.0000,  0.0000, -0.0000],
          [-0.0000,  0.0000,  0.0000, -0.0000, -0.0000],
          [-0.0000,  0.0000,  0.0000,  0.0000, -0.0000],
          [-0.0000,  0.0000,  0.0000, -0.0000, -0.0000]]]], device='cuda:0',
       requires_grad=True))]
print(list(module.named_buffers()))
[('bias_mask', tensor([1., 0., 1., 0., 1., 0.], device='cuda:0'))]

修剪模型中的多個引數#

透過指定所需的修剪技術和引數,我們可以輕鬆地修剪網路中的多個張量,也許根據它們的型別,正如我們將在本例中看到的。

new_model = LeNet()
for name, module in new_model.named_modules():
    # prune 20% of connections in all 2D-conv layers
    if isinstance(module, torch.nn.Conv2d):
        prune.l1_unstructured(module, name='weight', amount=0.2)
    # prune 40% of connections in all linear layers
    elif isinstance(module, torch.nn.Linear):
        prune.l1_unstructured(module, name='weight', amount=0.4)

print(dict(new_model.named_buffers()).keys())  # to verify that all masks exist
dict_keys(['conv1.weight_mask', 'conv2.weight_mask', 'fc1.weight_mask', 'fc2.weight_mask', 'fc3.weight_mask'])

全域性修剪#

到目前為止,我們只看了通常被稱為“區域性”修剪的內容,即逐個修剪模型中的張量,僅將每個條目的統計資料(權重幅度、啟用、梯度等)與其在張量中的其他條目進行比較。然而,一種常見且可能更強大的技術是同時修剪模型,例如,透過移除整個模型中最低的 20% 的連線,而不是移除每個層中最低的 20% 的連線。這可能會導致每層的修剪百分比不同。讓我們看看如何使用 torch.nn.utils.prune 中的 global_unstructured 來實現這一點。

model = LeNet()

parameters_to_prune = (
    (model.conv1, 'weight'),
    (model.conv2, 'weight'),
    (model.fc1, 'weight'),
    (model.fc2, 'weight'),
    (model.fc3, 'weight'),
)

prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.2,
)

現在我們可以檢查每個修剪引數中引起的稀疏性,它不會等於每層的 20%。然而,全域性稀疏性將是(大約)20%。

print(
    "Sparsity in conv1.weight: {:.2f}%".format(
        100. * float(torch.sum(model.conv1.weight == 0))
        / float(model.conv1.weight.nelement())
    )
)
print(
    "Sparsity in conv2.weight: {:.2f}%".format(
        100. * float(torch.sum(model.conv2.weight == 0))
        / float(model.conv2.weight.nelement())
    )
)
print(
    "Sparsity in fc1.weight: {:.2f}%".format(
        100. * float(torch.sum(model.fc1.weight == 0))
        / float(model.fc1.weight.nelement())
    )
)
print(
    "Sparsity in fc2.weight: {:.2f}%".format(
        100. * float(torch.sum(model.fc2.weight == 0))
        / float(model.fc2.weight.nelement())
    )
)
print(
    "Sparsity in fc3.weight: {:.2f}%".format(
        100. * float(torch.sum(model.fc3.weight == 0))
        / float(model.fc3.weight.nelement())
    )
)
print(
    "Global sparsity: {:.2f}%".format(
        100. * float(
            torch.sum(model.conv1.weight == 0)
            + torch.sum(model.conv2.weight == 0)
            + torch.sum(model.fc1.weight == 0)
            + torch.sum(model.fc2.weight == 0)
            + torch.sum(model.fc3.weight == 0)
        )
        / float(
            model.conv1.weight.nelement()
            + model.conv2.weight.nelement()
            + model.fc1.weight.nelement()
            + model.fc2.weight.nelement()
            + model.fc3.weight.nelement()
        )
    )
)
Sparsity in conv1.weight: 6.67%
Sparsity in conv2.weight: 13.96%
Sparsity in fc1.weight: 22.14%
Sparsity in fc2.weight: 12.31%
Sparsity in fc3.weight: 9.64%
Global sparsity: 20.00%

使用自定義修剪函式擴充套件 torch.nn.utils.prune#

要實現自己的修剪函式,您可以像其他所有修剪方法一樣,透過繼承 BasePruningMethod 基類來擴充套件 nn.utils.prune 模組。基類為您實現了以下方法:__call__apply_maskapplypruneremove。除了某些特殊情況,您無需為新的修剪技術重新實現這些方法。但是,您將需要實現 __init__(建構函式)和 compute_mask(根據您的修剪技術的邏輯計算給定張量掩碼的說明)。此外,您必須指定此技術實現哪種型別的修剪(支援的選項是 globalstructuredunstructured)。這對於確定如何在迭代應用修剪時組合掩碼是必需的。換句話說,當修剪一個預先修剪過的引數時,當前修剪技術應該作用於引數中未修剪的部分。指定 PRUNING_TYPE 將使 PruningContainer(負責迭代應用修剪掩碼)能夠正確識別要修剪的引數切片。

例如,假設您想實現一種修剪技術,該技術修剪張量中的每隔一個條目(或者——如果張量先前已被修剪——則修剪張量中剩餘未修剪部分中的每隔一個條目)。這將是 PRUNING_TYPE='unstructured',因為它作用於層中的單個連線,而不是作用於整個單元/通道('structured')或跨不同引數('global')。

class FooBarPruningMethod(prune.BasePruningMethod):
    """Prune every other entry in a tensor
    """
    PRUNING_TYPE = 'unstructured'

    def compute_mask(self, t, default_mask):
        mask = default_mask.clone()
        mask.view(-1)[::2] = 0
        return mask

現在,要將其應用於 nn.Module 中的引數,您還應該提供一個簡單的函式來例項化該方法並應用它。

def foobar_unstructured(module, name):
    """Prunes tensor corresponding to parameter called `name` in `module`
    by removing every other entry in the tensors.
    Modifies module in place (and also return the modified module)
    by:
    1) adding a named buffer called `name+'_mask'` corresponding to the
    binary mask applied to the parameter `name` by the pruning method.
    The parameter `name` is replaced by its pruned version, while the
    original (unpruned) parameter is stored in a new parameter named
    `name+'_orig'`.

    Args:
        module (nn.Module): module containing the tensor to prune
        name (string): parameter name within `module` on which pruning
                will act.

    Returns:
        module (nn.Module): modified (i.e. pruned) version of the input
            module

    Examples:
        >>> m = nn.Linear(3, 4)
        >>> foobar_unstructured(m, name='bias')
    """
    FooBarPruningMethod.apply(module, name)
    return module

讓我們試試看!

model = LeNet()
foobar_unstructured(model.fc3, name='bias')

print(model.fc3.bias_mask)
tensor([0., 1., 0., 1., 0., 1., 0., 1., 0., 1.])

指令碼總執行時間: (0 分鐘 0.559 秒)