PruningContainer#
- class torch.nn.utils.prune.PruningContainer(*args)[source]#
用於迭代剪枝的剪枝方法序列的容器。
跟蹤應用剪枝方法的順序,並處理連續剪枝呼叫的合併。
接受 BasePruningMethod 的例項或其可迭代物件作為引數。
- add_pruning_method(method)[source]#
向容器新增一個子剪枝方法
method。- 引數
method (BasePruningMethod 的子類) – 要新增到容器的子剪枝方法。
- classmethod apply(module, name, *args, importance_scores=None, **kwargs)[source]#
Add pruning on the fly and reparametrization of a tensor.
Adds the forward pre-hook that enables pruning on the fly and the reparametrization of a tensor in terms of the original tensor and the pruning mask.
- apply_mask(module)[source]#
Simply handles the multiplication between the parameter being pruned and the generated mask.
Fetches the mask and the original tensor from the module and returns the pruned version of the tensor.
- 引數
module (nn.Module) – module containing the tensor to prune
- 返回
pruned version of the input tensor
- 返回型別
pruned_tensor (torch.Tensor)
- compute_mask(t, default_mask)[source]#
透過計算新的部分掩碼來應用最新的
method,並將其與default_mask組合返回。新的部分掩碼應在未被
default_mask清零的條目或通道上計算。新掩碼將根據PRUNING_TYPE(由型別處理程式處理)從張量t的哪些部分計算,這取決於。對於“非結構化”,掩碼將從非零條目的展平列表中計算;
對於“結構化”,掩碼將從張量中的非零通道計算;
對於“全域性”,掩碼將在所有條目中計算。
- 引數
t (torch.Tensor) – 表示要剪枝的引數的張量(與
default_mask的維度相同)。default_mask (torch.Tensor) – 前一個剪枝迭代的掩碼。
- 返回
組合了
default_mask和當前剪枝method的新掩碼的效果(與default_mask和t的維度相同)。- 返回型別
mask (torch.Tensor)
- prune(t, default_mask=None, importance_scores=None)[source]#
Compute and returns a pruned version of input tensor
t.根據
compute_mask()中指定的剪枝規則。- 引數
t (torch.Tensor) – 要剪枝的張量(維度與
default_mask相同)。importance_scores (torch.Tensor) – 重要性分數張量(與
t形狀相同),用於計算剪枝t的掩碼。此張量中的值指示正在剪枝的t中相應元素的 গুরুত্ব。如果未指定或為 None,則將使用張量t本身。default_mask (torch.Tensor, optional) – 前一個剪枝迭代的掩碼(如果有)。在確定剪枝應作用於張量的哪個部分時需要考慮。如果為 None,則預設為一個全為 1 的掩碼。
- 返回
張量
t的修剪版本。