評價此頁

torch.nn.utils.prune.global_unstructured#

torch.nn.utils.prune.global_unstructured(parameters, pruning_method, importance_scores=None, **kwargs)[原始碼]#

透過應用指定的 pruning_method,全域性性地修剪 parameters 中所有引數對應的張量。

透過以下方式就地修改模組:

  1. 新增一個名為 name+'_mask' 的命名緩衝區,對應於剪枝方法應用於引數 name 的二值掩碼。

  2. 用剪枝後的版本替換引數 name,同時將原始(未剪枝)引數儲存在一個名為 name+'_orig' 的新引數中。

引數
  • parameters (Iterable of (module, name) tuples) – 要以全域性方式進行剪枝的模型引數,即在決定剪枝哪些引數之前先聚合所有權重。module 必須是 nn.Module 型別,name 必須是字串。

  • pruning_method (function) – 來自此模組的有效剪枝函式,或者使用者實現的滿足實現指南並具有 PRUNING_TYPE='unstructured' 的自定義函式。

  • importance_scores (dict) – 一個字典,將 (module, name) 元組對映到對應的引數的重要性分數張量。該張量應與引數具有相同的形狀,並用於計算剪枝的掩碼。如果未指定或為 None,則將使用引數本身作為其重要性分數。

  • kwargs – 其他關鍵字引數,例如:amount (int 或 float):要在指定引數中進行剪枝的引數數量。如果為 float,應介於 0.0 和 1.0 之間,表示要剪枝的引數的比例。如果為 int,表示要剪枝的引數的絕對數量。

引發

TypeError – 如果 PRUNING_TYPE != 'unstructured'

注意

由於全域性結構化剪枝在引數範數未歸一化之前沒有太大意義,因此我們現在將全域性剪枝的範圍限制在非結構化方法。

示例

>>> from torch.nn.utils import prune
>>> from collections import OrderedDict
>>> net = nn.Sequential(
...     OrderedDict(
...         [
...             ("first", nn.Linear(10, 4)),
...             ("second", nn.Linear(4, 1)),
...         ]
...     )
... )
>>> parameters_to_prune = (
...     (net.first, "weight"),
...     (net.second, "weight"),
... )
>>> prune.global_unstructured(
...     parameters_to_prune,
...     pruning_method=prune.L1Unstructured,
...     amount=10,
... )
>>> print(sum(torch.nn.utils.parameters_to_vector(net.buffers()) == 0))
tensor(10)