LnStructured#
- class torch.nn.utils.prune.LnStructured(amount, n, dim=-1)[source]#
根據 L
n範數修剪張量中的整個(當前未修剪的)通道。- 引數
- classmethod apply(module, name, amount, n, dim, importance_scores=None)[source]#
Add pruning on the fly and reparametrization of a tensor.
Adds the forward pre-hook that enables pruning on the fly and the reparametrization of a tensor in terms of the original tensor and the pruning mask.
- 引數
module (nn.Module) – module containing the tensor to prune
name (str) – 在
module中執行剪枝操作的引數名稱。amount (int 或 float) – 要剪枝的引數數量。如果是
float,則應介於 0.0 和 1.0 之間,表示要剪枝的引數的比例。如果是int,則表示要剪枝的引數的絕對數量。n (int, float, inf, -inf, 'fro', 'nuc') – 請參閱
torch.norm()中引數p的有效條目文件。dim (int) – 定義要修剪通道的維度的索引。
importance_scores (torch.Tensor) – 用於計算修剪掩碼的重要性分數張量(形狀與模組引數相同)。此張量中的值表示要修剪的引數中相應元素的重要性。如果未指定或為 None,則將使用模組引數本身。
- apply_mask(module)[source]#
Simply handles the multiplication between the parameter being pruned and the generated mask.
Fetches the mask and the original tensor from the module and returns the pruned version of the tensor.
- 引數
module (nn.Module) – module containing the tensor to prune
- 返回
pruned version of the input tensor
- 返回型別
pruned_tensor (torch.Tensor)
- compute_mask(t, default_mask)[source]#
計算並返回輸入張量
t的掩碼。從基礎
default_mask(如果張量尚未被修剪,則應為全 1 的掩碼)開始,生成一個掩碼,透過將具有最低 Ln-範數的通道沿指定維度置零,來應用於default_mask之上。- 引數
t (torch.Tensor) – 表示要修剪的引數的張量
default_mask (torch.Tensor) – 來自先前修剪迭代的基礎掩碼,在應用新掩碼後需要保留。與
t的維度相同。
- 返回
應用於
t的掩碼,維度與t相同- 返回型別
mask (torch.Tensor)
- 引發
IndexError – 如果
self.dim >= len(t.shape)
- prune(t, default_mask=None, importance_scores=None)[source]#
Compute and returns a pruned version of input tensor
t.根據
compute_mask()中指定的修剪規則進行操作。- 引數
t (torch.Tensor) – 要剪枝的張量(維度與
default_mask相同)。importance_scores (torch.Tensor) – 重要性分數張量(與
t形狀相同),用於計算剪枝t的掩碼。此張量中的值指示正在剪枝的t中相應元素的 গুরুত্ব。如果未指定或為 None,則將使用張量t本身。default_mask (torch.Tensor, optional) – 前一個剪枝迭代的掩碼(如果有)。在確定剪枝應作用於張量的哪個部分時需要考慮。如果為 None,則預設為一個全為 1 的掩碼。
- 返回
張量
t的修剪版本。