評價此頁

BasePruningMethod#

class torch.nn.utils.prune.BasePruningMethod[source]#

用於建立新剪枝技術的抽象基類。

提供了一個用於自定義的骨架,需要重寫諸如 compute_mask()apply() 等方法。

classmethod apply(module, name, *args, importance_scores=None, **kwargs)[source]#

Add pruning on the fly and reparametrization of a tensor.

Adds the forward pre-hook that enables pruning on the fly and the reparametrization of a tensor in terms of the original tensor and the pruning mask.

引數
  • module (nn.Module) – module containing the tensor to prune

  • name (str) – 在 module 中執行剪枝操作的引數名稱。

  • args – 傳遞給 BasePruningMethod 子類的引數

  • importance_scores (torch.Tensor) – 重要性得分張量(形狀與模組引數相同),用於計算剪枝的掩碼。此張量中的值表示正在剪枝的引數中相應元素的重要性。如果未指定或為 None,則將使用引數本身。

  • kwargs – 傳遞給 BasePruningMethod 子類的關鍵字引數

apply_mask(module)[source]#

Simply handles the multiplication between the parameter being pruned and the generated mask.

Fetches the mask and the original tensor from the module and returns the pruned version of the tensor.

引數

module (nn.Module) – module containing the tensor to prune

返回

pruned version of the input tensor

返回型別

pruned_tensor (torch.Tensor)

abstract compute_mask(t, default_mask)[source]#

計算並返回輸入張量 t 的掩碼。

從一個基礎的 default_mask(如果張量尚未被剪枝,則應為全為 1 的掩碼)開始,根據特定的剪枝方法規則生成一個隨機掩碼,以應用於 default_mask 之上。

引數
  • t (torch.Tensor) – 表示待剪枝引數重要性得分的張量

  • prune. (parameter to) –

  • default_mask (torch.Tensor) – 上一次迭代剪枝的基礎掩碼

  • iterations

  • is (that need to be respected after the new mask) –

  • t. (applied. Same dims as) –

返回

應用於 t 的掩碼,維度與 t 相同

返回型別

mask (torch.Tensor)

prune(t, default_mask=None, importance_scores=None)[source]#

Compute and returns a pruned version of input tensor t.

根據 compute_mask() 中指定的剪枝規則。

引數
  • t (torch.Tensor) – 要剪枝的張量(維度與 default_mask 相同)。

  • importance_scores (torch.Tensor) – 重要性分數張量(與 t 形狀相同),用於計算剪枝 t 的掩碼。此張量中的值指示正在剪枝的 t 中相應元素的 গুরুত্ব。如果未指定或為 None,則將使用張量 t 本身。

  • default_mask (torch.Tensor, optional) – 前一個剪枝迭代的掩碼(如果有)。在確定剪枝應作用於張量的哪個部分時需要考慮。如果為 None,則預設為一個全為 1 的掩碼。

返回

張量 t 的修剪版本。

remove(module)[source]#

從模組中移除修剪重引數化。

名為 name 的已剪枝引數將永久保持剪枝狀態,名為 name+'_orig' 的引數將從引數列表中移除。類似地,名為 name+'_mask' 的緩衝區也將從緩衝區中移除。

注意

修剪本身**不會**被撤銷或恢復!