Pruning schedule supports fpgm (#3110)
chicm-ms authored Nov 25, 2020
1 parent 055885d commit 9d0b6fa
Showing 7 changed files with 53 additions and 40 deletions.
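Taken together, these changes let every schedule-based pruner that takes a `base_algo` argument (ADMM, AutoCompress, NetAdapt, Sensitivity, SimulatedAnnealing) accept `fpgm` in addition to `level`, `l1`, and `l2`. A minimal usage sketch against this revision; the model and evaluator below are hypothetical placeholders, not part of the commit:

    import torch
    import torch.nn as nn
    from nni.algorithms.compression.pytorch.pruning import SimulatedAnnealingPruner

    # Hypothetical two-layer Conv2d model and a dummy evaluator; substitute
    # your own network and validation-accuracy function.
    model = nn.Sequential(nn.Conv2d(1, 8, 3), nn.ReLU(), nn.Conv2d(8, 16, 3))

    def evaluator(model):
        return 0.9  # placeholder accuracy

    config_list = [{'sparsity': 0.5, 'op_types': ['Conv2d']}]

    # base_algo='fpgm' is the new option introduced by this commit
    pruner = SimulatedAnnealingPruner(model, config_list, evaluator=evaluator,
                                      base_algo='fpgm')
    pruner.compress()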
33 changes: 10 additions & 23 deletions nni/algorithms/compression/pytorch/pruning/admm_pruner.py
@@ -4,6 +4,7 @@
 import logging
 import torch
 from schema import And, Optional
+import copy
 
 from nni.compression.pytorch.utils.config_validation import CompressorSchema
 from .constants import MASKER_DICT
@@ -53,7 +54,7 @@ def trainer(model, criterion, optimizer, epoch, callback):
     row : float
         Penalty parameters for ADMM training.
     base_algo : str
-        Base pruning algorithm. `level`, `l1` or `l2`, by default `l1`. Given the sparsity distribution among the ops,
+        Base pruning algorithm. `level`, `l1`, `l2` or `fpgm`, by default `l1`. Given the sparsity distribution among the ops,
         the assigned `base_algo` is used to decide which filters/channels/weights to prune.
     """
@@ -87,7 +88,7 @@ def validate_config(self, model, config_list):
                 Optional('op_types'): [str],
                 Optional('op_names'): [str],
             }], model, _logger)
-        elif self._base_algo in ['l1', 'l2']:
+        elif self._base_algo in ['l1', 'l2', 'fpgm']:
             schema = CompressorSchema([{
                 'sparsity': And(float, lambda n: 0 < n < 1),
                 'op_types': ['Conv2d'],
@@ -96,7 +97,7 @@ def validate_config(self, model, config_list):

         schema.validate(config_list)

-    def _projection(self, weight, sparsity):
+    def _projection(self, weight, sparsity, wrapper):
         '''
         Return the Euclidean projection of the weight matrix according to the pruning mode.
@@ -106,31 +107,17 @@ def _projection(self, weight, sparsity):
             original matrix
         sparsity : float
             the ratio of parameters which need to be set to zero
+        wrapper: PrunerModuleWrapper
+            layer wrapper of this layer
 
         Returns
         -------
         tensor
             the projected matrix
         '''
-        w_abs = weight.abs()
-        if self._base_algo == 'level':
-            k = int(weight.numel() * sparsity)
-            if k == 0:
-                mask_weight = torch.ones(weight.shape).type_as(weight)
-            else:
-                threshold = torch.topk(w_abs.view(-1), k, largest=False)[0].max()
-                mask_weight = torch.gt(w_abs, threshold).type_as(weight)
-        elif self._base_algo in ['l1', 'l2']:
-            filters = weight.size(0)
-            num_prune = int(filters * sparsity)
-            if filters < 2 or num_prune < 1:
-                mask_weight = torch.ones(weight.size()).type_as(weight).detach()
-            else:
-                w_abs_structured = w_abs.view(filters, -1).sum(dim=1)
-                threshold = torch.topk(w_abs_structured.view(-1), num_prune, largest=False)[0].max()
-                mask_weight = torch.gt(w_abs_structured, threshold)[:, None, None, None].expand_as(weight).type_as(weight)
-
-        return weight.data.mul(mask_weight)
+        wrapper_copy = copy.deepcopy(wrapper)
+        wrapper_copy.module.weight.data = weight
+        return weight.data.mul(self.masker.calc_mask(sparsity, wrapper_copy)['weight_mask'])
 
     def compress(self):
         """
@@ -179,7 +166,7 @@ def callback():
             # U_i^{k+1} = U^k + W_i^{k+1} - Z_i^{k+1}
             for i, wrapper in enumerate(self.get_modules_wrapper()):
                 z = wrapper.module.weight.data + U[i]
-                Z[i] = self._projection(z, wrapper.config['sparsity'])
+                Z[i] = self._projection(z, wrapper.config['sparsity'], wrapper)
                 U[i] = U[i] + wrapper.module.weight.data - Z[i]
 
         # apply prune
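The rewrite above is the core of the commit: instead of re-implementing `level`/`l1`/`l2` masking inline (which is why `fpgm` could not be supported before), `_projection` now deep-copies the layer wrapper, swaps in the ADMM-updated weight, and asks the pruner's masker for the mask. A self-contained sketch of that pattern, with a toy L1 masker and wrapper standing in for NNI's `WeightMasker` and `PrunerModuleWrapper`:

    import copy
    import torch

    class ToyWrapper:
        """Minimal stand-in for PrunerModuleWrapper: just holds a module."""
        def __init__(self, module):
            self.module = module

    class ToyL1Masker:
        """Stand-in for an NNI masker: keep filters with the largest L1 norms."""
        def calc_mask(self, sparsity, wrapper):
            weight = wrapper.module.weight.data
            filters = weight.size(0)
            num_prune = int(filters * sparsity)
            if filters < 2 or num_prune < 1:
                return {'weight_mask': torch.ones_like(weight)}
            norms = weight.abs().view(filters, -1).sum(dim=1)
            threshold = torch.topk(norms, num_prune, largest=False)[0].max()
            mask = torch.gt(norms, threshold)[:, None, None, None].expand_as(weight)
            return {'weight_mask': mask.type_as(weight)}

    def projection(weight, sparsity, wrapper, masker):
        # Mirrors the new _projection: compute the mask on a deep copy so the
        # live wrapper's weight stays untouched, then apply the mask.
        wrapper_copy = copy.deepcopy(wrapper)
        wrapper_copy.module.weight.data = weight
        return weight.data.mul(masker.calc_mask(sparsity, wrapper_copy)['weight_mask'])

    conv = torch.nn.Conv2d(4, 8, 3)
    projected = projection(conv.weight.data, 0.5, ToyWrapper(conv), ToyL1Masker())

Because the masker is looked up from MASKER_DICT at construction time, any registered algorithm (including `fpgm`) now works here without further changes to the ADMM loop.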
nni/algorithms/compression/pytorch/pruning/auto_compress_pruner.py
@@ -80,7 +80,7 @@ def evaluator(model):
     optimize_mode : str
         optimize mode, `maximize` or `minimize`, by default `maximize`.
     base_algo : str
-        Base pruning algorithm. `level`, `l1` or `l2`, by default `l1`. Given the sparsity distribution among the ops,
+        Base pruning algorithm. `level`, `l1`, `l2` or `fpgm`, by default `l1`. Given the sparsity distribution among the ops,
         the assigned `base_algo` is used to decide which filters/channels/weights to prune.
     start_temperature : float
         Start temperature of the simulated annealing process.
@@ -151,7 +151,7 @@ def validate_config(self, model, config_list):
                 Optional('op_types'): [str],
                 Optional('op_names'): [str],
             }], model, _logger)
-        elif self._base_algo in ['l1', 'l2']:
+        elif self._base_algo in ['l1', 'l2', 'fpgm']:
             schema = CompressorSchema([{
                 'sparsity': And(float, lambda n: 0 < n < 1),
                 'op_types': ['Conv2d'],
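For the `l1`, `l2`, and now `fpgm` branches, validation requires a `sparsity` strictly between 0 and 1 and pins `op_types` to `Conv2d`, since these base algorithms prune whole filters. An illustrative config that passes this schema (the layer name is hypothetical):

    config_list = [{
        'sparsity': 0.5,         # And(float, lambda n: 0 < n < 1)
        'op_types': ['Conv2d'],  # must be exactly ['Conv2d'] for l1/l2/fpgm
        'op_names': ['conv1'],   # optional; 'conv1' is a made-up layer name
    }]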
nni/algorithms/compression/pytorch/pruning/constants_pruner.py
@@ -2,10 +2,11 @@
 # Licensed under the MIT license.
 
 
-from .one_shot import LevelPruner, L1FilterPruner, L2FilterPruner
+from .one_shot import LevelPruner, L1FilterPruner, L2FilterPruner, FPGMPruner
 
 PRUNER_DICT = {
     'level': LevelPruner,
     'l1': L1FilterPruner,
-    'l2': L2FilterPruner
+    'l2': L2FilterPruner,
+    'fpgm': FPGMPruner
 }
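`PRUNER_DICT` is where a schedule pruner's `base_algo` string is resolved to a one-shot pruner class, so registering `FPGMPruner` under `'fpgm'` is what wires the new value through end to end. Roughly (the surrounding model and config are assumed, not part of this file):

    base_algo = 'fpgm'
    pruner_cls = PRUNER_DICT[base_algo]  # -> FPGMPruner
    # The schedule pruners instantiate it on their current model/config, e.g.:
    one_shot_pruner = pruner_cls(model, [{'sparsity': 0.5, 'op_types': ['Conv2d']}])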
nni/algorithms/compression/pytorch/pruning/net_adapt_pruner.py
@@ -73,7 +73,7 @@ def evaluator(model):
     optimize_mode : str
         optimize mode, `maximize` or `minimize`, by default `maximize`.
     base_algo : str
-        Base pruning algorithm. `level`, `l1` or `l2`, by default `l1`. Given the sparsity distribution among the ops,
+        Base pruning algorithm. `level`, `l1`, `l2` or `fpgm`, by default `l1`. Given the sparsity distribution among the ops,
         the assigned `base_algo` is used to decide which filters/channels/weights to prune.
     sparsity_per_iteration : float
         sparsity to prune in each iteration.
@@ -125,7 +125,7 @@ def validate_config(self, model, config_list):
                 Optional('op_types'): [str],
                 Optional('op_names'): [str],
             }], model, _logger)
-        elif self._base_algo in ['l1', 'l2']:
+        elif self._base_algo in ['l1', 'l2', 'fpgm']:
             schema = CompressorSchema([{
                 'sparsity': And(float, lambda n: 0 < n < 1),
                 'op_types': ['Conv2d'],
@@ -149,7 +149,7 @@ def _update_config_list(self, config_list, op_name, sparsity):
                 return config_list_updated
 
         # if op_name is not in self._config_list_generated, create a new json item
-        if self._base_algo in ['l1', 'l2']:
+        if self._base_algo in ['l1', 'l2', 'fpgm']:
             config_list_updated.append(
                 {'sparsity': sparsity, 'op_types': ['Conv2d'], 'op_names': [op_name]})
         elif self._base_algo == 'level':
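So when NetAdapt generates a fresh config item for a single op, `fpgm` now takes the same structured-pruning branch as `l1`/`l2`. For example (the op name and sparsity are illustrative):

    # Item appended by _update_config_list for base_algo in ['l1', 'l2', 'fpgm']:
    {'sparsity': 0.75, 'op_types': ['Conv2d'], 'op_names': ['conv2']}
    # The 'level' branch, folded out of this diff, presumably omits the
    # op_types pin, since level pruning is unstructured.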
nni/algorithms/compression/pytorch/pruning/sensitivity_pruner.py
@@ -68,7 +68,7 @@ class SensitivityPruner(Pruner):
         >>> loss.backward()
         >>> optimizer.step()
     base_algo: str
-        base pruning algorithm. `level`, `l1` or `l2`, by default `l1`.
+        base pruning algorithm. `level`, `l1`, `l2` or `fpgm`, by default `l1`.
     sparsity_proportion_calc: function
         This function generate the sparsity proportion between the conv layers according to the
         sensitivity analysis results. We provide a default function to quantify the sparsity
@@ -150,7 +150,7 @@ def validate_config(self, model, config_list):
                 Optional('op_types'): [str],
                 Optional('op_names'): [str],
             }], model, _logger)
-        elif self.base_algo in ['l1', 'l2']:
+        elif self.base_algo in ['l1', 'l2', 'fpgm']:
             schema = CompressorSchema([{
                 'sparsity': And(float, lambda n: 0 < n < 1),
                 'op_types': ['Conv2d'],
nni/algorithms/compression/pytorch/pruning/simulated_annealing_pruner.py
@@ -54,7 +54,7 @@ def evaluator(model):
     optimize_mode : str
         Optimize mode, `maximize` or `minimize`, by default `maximize`.
     base_algo : str
-        Base pruning algorithm. `level`, `l1` or `l2`, by default `l1`. Given the sparsity distribution among the ops,
+        Base pruning algorithm. `level`, `l1`, `l2` or `fpgm`, by default `l1`. Given the sparsity distribution among the ops,
         the assigned `base_algo` is used to decide which filters/channels/weights to prune.
     start_temperature : float
         Start temperature of the simulated annealing process.
@@ -120,7 +120,7 @@ def validate_config(self, model, config_list):
                 Optional('op_types'): [str],
                 Optional('op_names'): [str],
             }], model, _logger)
-        elif self._base_algo in ['l1', 'l2']:
+        elif self._base_algo in ['l1', 'l2', 'fpgm']:
             schema = CompressorSchema([{
                 'sparsity': And(float, lambda n: 0 < n < 1),
                 'op_types': ['Conv2d'],
@@ -152,7 +152,7 @@ def _sparsities_2_config_list(self, sparsities):
         # a layer with more weights will have no less pruning rate
         for idx, wrapper in enumerate(self.get_modules_wrapper()):
             # L1Filter Pruner requires to specify op_types
-            if self._base_algo in ['l1', 'l2']:
+            if self._base_algo in ['l1', 'l2', 'fpgm']:
                 config_list.append(
                     {'sparsity': sparsities[idx], 'op_types': ['Conv2d'], 'op_names': [wrapper.name]})
             elif self._base_algo == 'level':
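With base_algo `'fpgm'`, SimulatedAnnealing's rescaled per-layer sparsities therefore turn into per-op config items of the same shape as the other structured algorithms. An illustrative result for a model with two Conv2d layers (names and values made up):

    # Illustrative return value of _sparsities_2_config_list(sparsities):
    [{'sparsity': 0.4, 'op_types': ['Conv2d'], 'op_names': ['conv1']},
     {'sparsity': 0.6, 'op_types': ['Conv2d'], 'op_names': ['conv2']}]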
35 changes: 30 additions & 5 deletions test/ut/sdk/test_pruners.py
@@ -151,12 +151,37 @@ def validate_sparsity(wrapper, sparsity, bias=False):
             lambda model: validate_sparsity(model.conv1, 0.5, model.bias)
         ]
     },
-    'autocompress': {
+    'autocompress_l1': {
         'pruner_class': AutoCompressPruner,
         'config_list': [{
             'sparsity': 0.5,
             'op_types': ['Conv2d'],
         }],
+        'base_algo': 'l1',
         'trainer': lambda model, optimizer, criterion, epoch, callback : model,
         'evaluator': lambda model: 0.9,
         'dummy_input': torch.randn([64, 1, 28, 28]),
         'validators': []
     },
+    'autocompress_l2': {
+        'pruner_class': AutoCompressPruner,
+        'config_list': [{
+            'sparsity': 0.5,
+            'op_types': ['Conv2d'],
+        }],
+        'base_algo': 'l2',
+        'trainer': lambda model, optimizer, criterion, epoch, callback : model,
+        'evaluator': lambda model: 0.9,
+        'dummy_input': torch.randn([64, 1, 28, 28]),
+        'validators': []
+    },
+    'autocompress_fpgm': {
+        'pruner_class': AutoCompressPruner,
+        'config_list': [{
+            'sparsity': 0.5,
+            'op_types': ['Conv2d'],
+        }],
+        'base_algo': 'fpgm',
+        'trainer': lambda model, optimizer, criterion, epoch, callback : model,
+        'evaluator': lambda model: 0.9,
+        'dummy_input': torch.randn([64, 1, 28, 28]),
@@ -181,7 +206,7 @@ def __init__(self, bias=True):
     def forward(self, x):
         return self.fc(self.pool(self.bn1(self.conv1(x))).view(x.size(0), -1))
 
-def pruners_test(pruner_names=['level', 'agp', 'slim', 'fpgm', 'l1', 'l2', 'taylorfo', 'mean_activation', 'apoz', 'netadapt', 'simulatedannealing', 'admm', 'autocompress'], bias=True):
+def pruners_test(pruner_names=['level', 'agp', 'slim', 'fpgm', 'l1', 'l2', 'taylorfo', 'mean_activation', 'apoz', 'netadapt', 'simulatedannealing', 'admm', 'autocompress_l1', 'autocompress_l2', 'autocompress_fpgm',], bias=True):
     for pruner_name in pruner_names:
         print('testing {}...'.format(pruner_name))
         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -203,8 +228,8 @@ def pruners_test(pruner_names=['level', 'agp', 'slim', 'fpgm', 'l1', 'l2', 'tayl
             pruner = prune_config[pruner_name]['pruner_class'](model, config_list, evaluator=prune_config[pruner_name]['evaluator'])
         elif pruner_name == 'admm':
             pruner = prune_config[pruner_name]['pruner_class'](model, config_list, trainer=prune_config[pruner_name]['trainer'])
-        elif pruner_name == 'autocompress':
-            pruner = prune_config[pruner_name]['pruner_class'](model, config_list, trainer=prune_config[pruner_name]['trainer'], evaluator=prune_config[pruner_name]['evaluator'], dummy_input=x)
+        elif pruner_name.startswith('autocompress'):
+            pruner = prune_config[pruner_name]['pruner_class'](model, config_list, trainer=prune_config[pruner_name]['trainer'], evaluator=prune_config[pruner_name]['evaluator'], dummy_input=x, base_algo=prune_config[pruner_name]['base_algo'])
         else:
             pruner = prune_config[pruner_name]['pruner_class'](model, config_list, optimizer)
         pruner.compress()
pruner.compress()
Expand Down Expand Up @@ -272,7 +297,7 @@ def test_pruners_no_bias(self):
pruners_test(bias=False)

def test_agp_pruner(self):
for pruning_algorithm in ['l1', 'l2', 'taylorfo', 'apoz']:
for pruning_algorithm in ['l1', 'l2', 'fpgm', 'taylorfo', 'apoz']:
_test_agp(pruning_algorithm)

for pruning_algorithm in ['level']:
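The AGP test now cycles `fpgm` through `_test_agp` as well. A hedged sketch of the corresponding user-facing call, assuming `model` and `optimizer` already exist and using the AGP config keys of this NNI generation:

    from nni.algorithms.compression.pytorch.pruning import AGPPruner

    config_list = [{
        'initial_sparsity': 0.0,
        'final_sparsity': 0.5,
        'start_epoch': 0,
        'end_epoch': 10,
        'frequency': 1,
        'op_types': ['Conv2d'],
    }]
    # pruning_algorithm='fpgm' is newly exercised by this commit's tests
    pruner = AGPPruner(model, config_list, optimizer, pruning_algorithm='fpgm')
    pruner.compress()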
