torch.nn.utils.prune.global_unstructured#
- torch.nn.utils.prune.global_unstructured(parameters, pruning_method, importance_scores=None, **kwargs)[原始碼]#
透過應用指定的
pruning_method,全域性性地修剪parameters中所有引數對應的張量。透過以下方式就地修改模組:
新增一個名為
name+'_mask'的命名緩衝區,對應於剪枝方法應用於引數name的二值掩碼。用剪枝後的版本替換引數
name,同時將原始(未剪枝)引數儲存在一個名為name+'_orig'的新引數中。
- 引數
parameters (Iterable of (module, name) tuples) – 要以全域性方式進行剪枝的模型引數,即在決定剪枝哪些引數之前先聚合所有權重。module 必須是
nn.Module型別,name 必須是字串。pruning_method (function) – 來自此模組的有效剪枝函式,或者使用者實現的滿足實現指南並具有
PRUNING_TYPE='unstructured'的自定義函式。importance_scores (dict) – 一個字典,將 (module, name) 元組對映到對應的引數的重要性分數張量。該張量應與引數具有相同的形狀,並用於計算剪枝的掩碼。如果未指定或為 None,則將使用引數本身作為其重要性分數。
kwargs – 其他關鍵字引數,例如:amount (int 或 float):要在指定引數中進行剪枝的引數數量。如果為
float,應介於 0.0 和 1.0 之間,表示要剪枝的引數的比例。如果為int,表示要剪枝的引數的絕對數量。
- 引發
TypeError – 如果
PRUNING_TYPE != 'unstructured'
注意
由於全域性結構化剪枝在引數範數未歸一化之前沒有太大意義,因此我們現在將全域性剪枝的範圍限制在非結構化方法。
示例
>>> from torch.nn.utils import prune >>> from collections import OrderedDict >>> net = nn.Sequential( ... OrderedDict( ... [ ... ("first", nn.Linear(10, 4)), ... ("second", nn.Linear(4, 1)), ... ] ... ) ... ) >>> parameters_to_prune = ( ... (net.first, "weight"), ... (net.second, "weight"), ... ) >>> prune.global_unstructured( ... parameters_to_prune, ... pruning_method=prune.L1Unstructured, ... amount=10, ... ) >>> print(sum(torch.nn.utils.parameters_to_vector(net.buffers()) == 0)) tensor(10)