Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Sensitivity pruner #2684

Merged
merged 14 commits into from
Aug 11, 2020
72 changes: 72 additions & 0 deletions docs/en_US/Compressor/Pruner.md
Original file line number Diff line number Diff line change
Expand Up @@ -684,3 +684,75 @@ We try to reproduce the experiment result of the fully connected network on MNIS
![](../../img/lottery_ticket_mnist_fc.png)

The above figure shows the result of the fully connected network. `round0-sparsity-0.0` is the performance without pruning. Consistent with the paper, pruning around 80% also obtain similar performance compared to non-pruning, and converges a little faster. If pruning too much, e.g., larger than 94%, the accuracy becomes lower and convergence becomes a little slower. A little different from the paper, the trend of the data in the paper is relatively more clear.


## Sensitivity Pruner
For each round, SensitivityPruner prunes the model based on the sensitivity to the accuracy of each layer until meeting the final configured sparsity of the whole model:
1. Analyze the sensitivity of each layer in the current state of the model.
2. Prune each layer according to the sensitivity.

For more details, please refer to [Learning both Weights and Connections for Efficient Neural Networks ](https://arxiv.org/abs/1506.02626).

#### Usage

PyTorch code

```python
from nni.compression.torch import SensitivityPruner
config_list = [{
'sparsity': 0.5,
'op_types': ['Conv2d']
}]
pruner = SensitivityPruner(model, config_list, finetuner=fine_tuner, evaluator=evaluator)
# eval_args and finetune_args are the parameters passed to the evaluator and finetuner respectively
pruner.compress(eval_args=[model], finetune_args=[model])
```


#### User configuration for Sensitivity Pruner

- **sparsity:** The target overall sparsity.
- **op_types:** The operation type to prune. If `base_algo` is `l1` or `l2`, then only `Conv2d` is supported as `op_types`.
zheng-ningxin marked this conversation as resolved.
Show resolved Hide resolved

- **evaluator:** Function to evaluate the masked model. This function should return a scalar value(such as the accuracy). The input parameters of the evaluator can be passed by the `eval_args` and `eval_kwargs` parameter in the compress function.
Example::
```python
>>> def evaluator(model):
>>> device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
>>> val_loader = ...
>>> model.eval()
>>> correct = 0
>>> with torch.no_grad():
>>> for data, target in val_loader:
>>> data, target = data.to(device), target.to(device)
>>> output = model(data)
>>> # get the index of the max log-probability
>>> pred = output.argmax(dim=1, keepdim=True)
>>> correct += pred.eq(target.view_as(pred)).sum().item()
>>> accuracy = correct / len(val_loader.dataset)
>>> return accuracy
```
- **finetuner:** Function used to finetune the model after each iteration if needed. If this parameter is not set, then SensitivityPruner will not finetune the model after each iteration. The input parameters of the finetuner can be passed by the `finetune_args` and `finetune_kwargs` in the `compression` function if needed.

Example:
```python
>>> def finetuner(model):
>>> device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
>>> train_loader = ...
>>> model.train()
>>> for batch_idx, (data, target) in enumerate(train_loader):
>>> data, target = data.to(device), target.to(device)
>>> optimizer.zero_grad()
>>> output = model(data)
>>> loss = criterion(output, target)
>>> loss.backward()
>>> optimizer.step()
```


- **base_algo:** Base pruning algorithm. `level`, `l1` or `l2`, by default `l1`.
Given the sparsity distribution among the ops, the assigned `base_algo` is used to decide which filters/channels/weights to prune.
- **sparsity_proportion_calc:** This function generate the sparsity proportion between the conv layers according to the sensitivity analysis results. We provide a default function to quantify the sparsity proportion according to the sensitivity analysis results. Users can also customize this function according to their needs. The input of this function is a dict, for example : {'conv1' : {0.1: 0.9, 0.2 : 0.8}, 'conv2' : {0.1: 0.9, 0.2 : 0.8}}, in which, 'conv1' and is the name of the conv layer, and 0.1:0.9 means when the sparsity of conv1 is 0.1 (10%), the model's val accuracy equals to 0.9.
- **sparsity_per_iter:** The sparsity of the model that the pruner try to prune in each iteration..
- **acc_drop_threshold:** The hyperparameter used to quantifiy the sensitivity for each layer.
- **checkpoint_dir:** The dir path to save the checkpoints during the pruning.
1 change: 1 addition & 0 deletions src/sdk/pynni/nni/compression/torch/pruning/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,4 @@
from .net_adapt_pruner import NetAdaptPruner
from .admm_pruner import ADMMPruner
from .auto_compress_pruner import AutoCompressPruner
from .sensitivity_pruner import SensitivityPruner
Loading