WandaSparsifier
- class torchao.sparsity.WandaSparsifier(sparsity_level: float = 0.5, semi_structured_block_size: Optional[int] = None)[source]
Wanda sparsifier
Wanda (Pruning by Weights and Activations), proposed in https://arxiv.org/abs/2306.11695, is an activation-aware pruning method. The sparsifier removes weights based on the product of the input activation norm and the weight magnitude.
This sparsifier is controlled by three variables: 1. sparsity_level defines the number of sparse blocks that are zeroed out;
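The scoring rule is compact enough to state in plain tensor code. Below is a minimal sketch of the Wanda metric for unstructured 50% sparsity, not the torchao implementation; the weight shape, calibration tensor, and per-row pruning group are illustrative assumptions based on the paper.

import torch

# Minimal sketch of the Wanda score (illustrative, not the torchao code).
# W: Linear weight of shape (out_features, in_features);
# X: calibration activations of shape (num_tokens, in_features).
W = torch.randn(8, 16)
X = torch.randn(128, 16)

# Score each weight by |W_ij| * ||X_j||_2: weight magnitude times the
# L2 norm of the matching input feature across the calibration tokens.
scores = W.abs() * X.norm(p=2, dim=0)

# Zero out the lowest-scoring half of each output row (sparsity_level=0.5).
k = int(W.shape[1] * 0.5)
_, prune_idx = torch.topk(scores, k, dim=1, largest=False)
mask = torch.ones_like(W, dtype=torch.bool)
mask.scatter_(1, prune_idx, False)
W_pruned = W * mask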
- Parameters:
sparsity_level – the target level of sparsity;
model – the model to be sparsified;
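Before the per-method docs below, a rough end-to-end sketch of how the sparsifier is typically driven may help. The config format (tensor_fqn entries) follows the torch.ao.pruning BaseSparsifier convention that this class builds on; the toy model and calibration loop are placeholders.

import torch
from torchao.sparsity import WandaSparsifier

model = torch.nn.Sequential(torch.nn.Linear(16, 16))
sparsifier = WandaSparsifier(sparsity_level=0.5)

# Attach the sparsity parametrizations to the tensors named in the config;
# "0.weight" is the fqn of the Linear weight in the toy model above.
sparsifier.prepare(model, config=[{"tensor_fqn": "0.weight"}])

# Wanda is activation-aware: run calibration batches so that input
# activation norms are observed before the masks are computed.
for _ in range(8):
    model(torch.randn(4, 16))

sparsifier.step()         # compute the sparse masks
sparsifier.squash_mask()  # fold the masks into the weights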
- prepare(model: Module, config: List[Dict]) → None[source]
Prepares the model by adding parametrizations.
Note
The model is modified in place. If you need to preserve the original model, use copy.deepcopy.
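A brief sketch of the copy.deepcopy pattern the note refers to (the config entry is a placeholder):

import copy

# prepare() modifies the model in place; keep a deep copy if the original
# weights are still needed afterwards.
model_copy = copy.deepcopy(model)
sparsifier.prepare(model_copy, config=[{"tensor_fqn": "0.weight"}])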
- squash_mask(params_to_keep: Optional[Tuple[str, ...]] = None, params_to_keep_per_layer: Optional[Dict[str, Tuple[str, ...]]] = None, *args, **kwargs)[source]
Squashes the sparse masks into the appropriate tensors.
If either params_to_keep or params_to_keep_per_layer is set, the module will have a sparse_params dict attached to it.
- Parameters:
params_to_keep – list of keys to save in the module, or a dict representing the modules and keys that will have sparsity parameters saved
params_to_keep_per_layer – dict to specify the params that should be saved for specific layers. The keys in the dict should be the module fqn, while the values should be a list of strings with the names of the variables to save in sparse_params
Examples
>>> # xdoctest: +SKIP("locals are undefined")
>>> # Don't save any sparse params
>>> sparsifier.squash_mask()
>>> hasattr(model.submodule1, "sparse_params")
False
>>> # Keep sparse params per layer
>>> sparsifier.squash_mask(
...     params_to_keep_per_layer={
...         "submodule1.linear1": ("foo", "bar"),
...         "submodule2.linear42": ("baz",),
...     }
... )
>>> print(model.submodule1.linear1.sparse_params)
{'foo': 42, 'bar': 24}
>>> print(model.submodule2.linear42.sparse_params)
{'baz': 0.1}
>>> # Keep sparse params for all layers
>>> sparsifier.squash_mask(params_to_keep=("foo", "bar"))
>>> print(model.submodule1.linear1.sparse_params)
{'foo': 42, 'bar': 24}
>>> print(model.submodule2.linear42.sparse_params)
{'foo': 42, 'bar': 24}
>>> # Keep some sparse params for all layers, and specific ones for
>>> # some other layers
>>> sparsifier.squash_mask(
...     params_to_keep=("foo", "bar"),
...     params_to_keep_per_layer={"submodule2.linear42": ("baz",)},
... )
>>> print(model.submodule1.linear1.sparse_params)
{'foo': 42, 'bar': 24}
>>> print(model.submodule2.linear42.sparse_params)
{'foo': 42, 'bar': 24, 'baz': 0.1}