diff --git a/docs/en_US/Compression/DependencyAware.rst b/docs/en_US/Compression/DependencyAware.rst index d68a2b22c0..05daafe7f2 100644 --- a/docs/en_US/Compression/DependencyAware.rst +++ b/docs/en_US/Compression/DependencyAware.rst @@ -54,11 +54,11 @@ To enable the dependency-aware mode for ``L1FilterPruner``\ : # for FPGMPruner # pruner = FPGMPruner(model, config_list, dependency_aware=True, dummy_input=dummy_input) # for ActivationAPoZRankFilterPruner - # pruner = ActivationAPoZRankFilterPruner(model, config_list, statistics_batch_num=1, , dependency_aware=True, dummy_input=dummy_input) + # pruner = ActivationAPoZRankFilterPruner(model, config_list, optimizer, trainer, criterion, sparsifying_training_batches=1, dependency_aware=True, dummy_input=dummy_input) # for ActivationMeanRankFilterPruner - # pruner = ActivationMeanRankFilterPruner(model, config_list, statistics_batch_num=1, dependency_aware=True, dummy_input=dummy_input) + # pruner = ActivationMeanRankFilterPruner(model, config_list, optimizer, trainer, criterion, sparsifying_training_batches=1, dependency_aware=True, dummy_input=dummy_input) # for TaylorFOWeightFilterPruner - # pruner = TaylorFOWeightFilterPruner(model, config_list, statistics_batch_num=1, dependency_aware=True, dummy_input=dummy_input) + # pruner = TaylorFOWeightFilterPruner(model, config_list, optimizer, trainer, criterion, sparsifying_training_batches=1, dependency_aware=True, dummy_input=dummy_input) pruner.compress() diff --git a/docs/en_US/Compression/Framework.rst b/docs/en_US/Compression/Framework.rst index fa46b60230..453163b5ef 100644 --- a/docs/en_US/Compression/Framework.rst +++ b/docs/en_US/Compression/Framework.rst @@ -29,8 +29,7 @@ Compressor is the base class for pruner and quntizer, it provides a unified inte 'op_types': ['Conv2d', 'Linear'], }] - optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-4) - pruner = LevelPruner(model, configure_list, optimizer) + pruner = LevelPruner(model, configure_list) model = pruner.compress() # model is ready for pruning, now start finetune the model, @@ -103,7 +102,8 @@ Users can also remove this collector like this: Pruner ------ -A pruner receives ``model``\ , ``config_list`` and ``optimizer`` as arguments. It prunes the model per the ``config_list`` during training loop by adding a hook on ``optimizer.step()``. +A pruner receives ``model`` , ``config_list`` as arguments. +Some pruners like ``TaylorFOWeightFilter Pruner`` prune the model per the ``config_list`` during training loop by adding a hook on ``optimizer.step()``. Pruner class is a subclass of Compressor, so it contains everything in the Compressor class and some additional components only for pruning, it contains: diff --git a/docs/en_US/Compression/Pruner.rst b/docs/en_US/Compression/Pruner.rst index eb9c32c875..3e6ba420ce 100644 --- a/docs/en_US/Compression/Pruner.rst +++ b/docs/en_US/Compression/Pruner.rst @@ -71,7 +71,7 @@ PyTorch code from nni.algorithms.compression.pytorch.pruning import SlimPruner config_list = [{ 'sparsity': 0.8, 'op_types': ['BatchNorm2d'] }] - pruner = SlimPruner(model, config_list) + pruner = SlimPruner(model, config_list, optimizer, trainer, criterion) pruner.compress() User configuration for Slim Pruner @@ -269,7 +269,7 @@ PyTorch code 'sparsity': 0.5, 'op_types': ['Conv2d'] }] - pruner = ActivationAPoZRankFilterPruner(model, config_list, statistics_batch_num=1) + pruner = ActivationAPoZRankFilterPruner(model, config_list, optimizer, trainer, criterion, sparsifying_training_batches=1) pruner.compress() Note: ActivationAPoZRankFilterPruner is used to prune convolutional layers within deep neural networks, therefore the ``op_types`` field supports only convolutional layers. @@ -304,7 +304,7 @@ PyTorch code 'sparsity': 0.5, 'op_types': ['Conv2d'] }] - pruner = ActivationMeanRankFilterPruner(model, config_list, statistics_batch_num=1) + pruner = ActivationMeanRankFilterPruner(model, config_list, optimizer, trainer, criterion, sparsifying_training_batches=1) pruner.compress() Note: ActivationMeanRankFilterPruner is used to prune convolutional layers within deep neural networks, therefore the ``op_types`` field supports only convolutional layers. @@ -344,7 +344,7 @@ PyTorch code 'sparsity': 0.5, 'op_types': ['Conv2d'] }] - pruner = TaylorFOWeightFilterPruner(model, config_list, statistics_batch_num=1) + pruner = TaylorFOWeightFilterPruner(model, config_list, optimizer, trainer, criterion, sparsifying_training_batches=1) pruner.compress() User configuration for TaylorFOWeightFilter Pruner @@ -389,7 +389,7 @@ PyTorch code # optimizer.step(), so an optimizer is required to prune the model. optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-4) - pruner = AGPPruner(model, config_list, optimizer, pruning_algorithm='level') + pruner = AGPPruner(model, config_list, optimizer, trainer, criterion, pruning_algorithm='level') pruner.compress() AGP pruner uses ``LevelPruner`` algorithms to prune the weight by default, however you can set ``pruning_algorithm`` parameter to other values to use other pruning algorithms: @@ -404,14 +404,6 @@ AGP pruner uses ``LevelPruner`` algorithms to prune the weight by default, howev * ``apoz``\ : ActivationAPoZRankFilterPruner * ``mean_activation``\ : ActivationMeanRankFilterPruner -You should add code below to update epoch number when you finish one epoch in your training code. - -PyTorch code - -.. code-block:: python - - pruner.update_epoch(epoch) - User configuration for AGP Pruner ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -620,7 +612,7 @@ PyTorch code 'op_types': ['Conv2d'], 'op_names': ['conv2'] }] - pruner = ADMMPruner(model, config_list, trainer=trainer, num_iterations=30, epochs=5) + pruner = ADMMPruner(model, config_list, trainer, num_iterations=30, epochs_per_iteration=5) pruner.compress() You can view :githublink:`example ` for more information. diff --git a/docs/en_US/Compression/QuickStart.rst b/docs/en_US/Compression/QuickStart.rst index 1113e5a753..0e4b33b692 100644 --- a/docs/en_US/Compression/QuickStart.rst +++ b/docs/en_US/Compression/QuickStart.rst @@ -31,17 +31,16 @@ The specification of configuration can be found `here <./Tutorial.rst#specify-th Step2. Choose a pruner and compress the model ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -First instantiate the chosen pruner with your model and configuration as arguments, then invoke ``compress()`` to compress your model. Note that, some algorithms may check gradients for compressing, so we also define an optimizer and pass it to the pruner. +First instantiate the chosen pruner with your model and configuration as arguments, then invoke ``compress()`` to compress your model. Note that, some algorithms may check gradients for compressing, so we may also define an optimizer and pass it to the pruner. .. code-block:: python from nni.algorithms.compression.pytorch.pruning import LevelPruner - optimizer_finetune = torch.optim.SGD(model.parameters(), lr=0.01) - pruner = LevelPruner(model, config_list, optimizer_finetune) + pruner = LevelPruner(model, config_list) model = pruner.compress() -Then, you can train your model using traditional training approach (e.g., SGD), pruning is applied transparently during the training. Some pruners (e.g., L1FilterPruner, FPGMPruner) prune once at the beginning, the following training can be seen as fine-tune. Some pruners (e.g., AGPPruner) prune your model iteratively, the masks are adjusted epoch by epoch during training. +Some pruners (e.g., L1FilterPruner, FPGMPruner) prune once, some pruners (e.g., AGPPruner) prune your model iteratively, the masks are adjusted epoch by epoch during training. Note that, ``pruner.compress`` simply adds masks on model weights, it does not include fine-tuning logic. If users want to fine tune the compressed model, they need to write the fine tune logic by themselves after ``pruner.compress``. diff --git a/examples/model_compress/pruning/basic_pruners_torch.py b/examples/model_compress/pruning/basic_pruners_torch.py index 99cb661c7f..7c593e5c7e 100644 --- a/examples/model_compress/pruning/basic_pruners_torch.py +++ b/examples/model_compress/pruning/basic_pruners_torch.py @@ -231,10 +231,10 @@ def trainer(model, optimizer, criterion, epoch): kw_args['criterion'] = criterion if args.pruner in ('mean_activation', 'apoz', 'taylorfo'): - kw_args['sparsity_training_epochs'] = 1 + kw_args['sparsifying_training_batches'] = 1 if args.pruner == 'slim': - kw_args['sparsity_training_epochs'] = 5 + kw_args['sparsifying_training_epochs'] = 5 if args.pruner == 'agp': kw_args['pruning_algorithm'] = 'l1' diff --git a/nni/algorithms/compression/pytorch/pruning/auto_compress_pruner.py b/nni/algorithms/compression/pytorch/pruning/auto_compress_pruner.py index 82a8f1cb98..207a8aa2f9 100644 --- a/nni/algorithms/compression/pytorch/pruning/auto_compress_pruner.py +++ b/nni/algorithms/compression/pytorch/pruning/auto_compress_pruner.py @@ -34,6 +34,9 @@ class AutoCompressPruner(Pruner): Function used for the first subproblem of ADMM Pruner. Users should write this function as a normal function to train the Pytorch model and include `model, optimizer, criterion, epoch` as function arguments. + criterion: function + Function used to calculate the loss between the target and the output. By default, we use CrossEntropyLoss. + For example, you can use ``torch.nn.CrossEntropyLoss()`` as input. evaluator : function function to evaluate the pruned model. This function should include `model` as the only parameter, and returns a scalar value. @@ -80,7 +83,7 @@ def evaluator(model): PATH to store temporary experiment data. """ - def __init__(self, model, config_list, trainer, criterion, evaluator, dummy_input, + def __init__(self, model, config_list, trainer, evaluator, dummy_input, criterion=torch.nn.CrossEntropyLoss(), num_iterations=3, optimize_mode='maximize', base_algo='l1', # SimulatedAnnealing related start_temperature=100, stop_temperature=20, cool_down_rate=0.9, perturbation_magnitude=0.35, diff --git a/nni/algorithms/compression/pytorch/pruning/iterative_pruner.py b/nni/algorithms/compression/pytorch/pruning/iterative_pruner.py index 18e1755b32..d5cd5dfccb 100644 --- a/nni/algorithms/compression/pytorch/pruning/iterative_pruner.py +++ b/nni/algorithms/compression/pytorch/pruning/iterative_pruner.py @@ -40,6 +40,7 @@ def __init__(self, model, config_list, optimizer=None, pruning_algorithm='slim', and include `model, optimizer, criterion, epoch` as function arguments. criterion: function Function used to calculate the loss between the target and the output. + For example, you can use ``torch.nn.CrossEntropyLoss()`` as input. num_iterations: int Total number of iterations in pruning process. We will calculate mask at the end of an iteration. epochs_per_iteration: Union[int, list] @@ -59,8 +60,11 @@ def __init__(self, model, config_list, optimizer=None, pruning_algorithm='slim', assert len(epochs_per_iteration) == num_iterations, 'num_iterations should equal to the length of epochs_per_iteration' self.epochs_per_iteration = epochs_per_iteration else: + assert num_iterations > 0, 'num_iterations should >= 1' self.epochs_per_iteration = [epochs_per_iteration] * num_iterations + self._validate_iteration_params() + self._trainer = trainer self._criterion = criterion @@ -68,6 +72,9 @@ def _fresh_calculated(self): for wrapper in self.get_modules_wrapper(): wrapper.if_calculated = False + def _validate_iteration_params(self): + assert all(num >= 0 for num in self.epochs_per_iteration), 'all epoch number need >= 0' + def compress(self): training = self.bound_model.training self.bound_model.train() @@ -75,6 +82,10 @@ def compress(self): self._fresh_calculated() for epoch in range(epochs_num): self._trainer(self.bound_model, optimizer=self.optimizer, criterion=self._criterion, epoch=epoch) + # NOTE: workaround for statistics_batch_num bigger than max batch number in one epoch, need refactor + if hasattr(self.masker, 'statistics_batch_num') and hasattr(self, 'iterations'): + if self.iterations < self.masker.statistics_batch_num: + self.iterations = self.masker.statistics_batch_num self.update_mask() self.bound_model.train(training) @@ -97,6 +108,7 @@ class AGPPruner(IterativePruner): Function to train the model criterion: function Function used to calculate the loss between the target and the output. + For example, you can use ``torch.nn.CrossEntropyLoss()`` as input. num_iterations: int Total number of iterations in pruning process. We will calculate mask at the end of an iteration. epochs_per_iteration: int @@ -245,6 +257,7 @@ class ADMMPruner(IterativePruner): and include `model, optimizer, criterion, epoch` as function arguments. criterion: function Function used to calculate the loss between the target and the output. By default, we use CrossEntropyLoss in ADMMPruner. + For example, you can use ``torch.nn.CrossEntropyLoss()`` as input. num_iterations: int Total number of iterations in pruning process. We will calculate mask after we finish all iterations in ADMMPruner. epochs_per_iteration: int @@ -254,7 +267,6 @@ class ADMMPruner(IterativePruner): base_algo : str Base pruning algorithm. `level`, `l1`, `l2` or `fpgm`, by default `l1`. Given the sparsity distribution among the ops, the assigned `base_algo` is used to decide which filters/channels/weights to prune. - """ def __init__(self, model, config_list, trainer, criterion=torch.nn.CrossEntropyLoss(), @@ -396,7 +408,8 @@ class SlimPruner(IterativePruner): and include `model, optimizer, criterion, epoch` as function arguments. criterion : function Function used to calculate the loss between the target and the output. - sparsity_training_epochs: int + For example, you can use ``torch.nn.CrossEntropyLoss()`` as input. + sparsifying_training_epochs: int The number of channel sparsity regularization training epochs before pruning. scale : float Penalty parameters for sparsification. @@ -413,10 +426,10 @@ class SlimPruner(IterativePruner): should on the same device with the model. """ - def __init__(self, model, config_list, optimizer, trainer, criterion, sparsity_training_epochs=10, scale=0.0001, + def __init__(self, model, config_list, optimizer, trainer, criterion, sparsifying_training_epochs=10, scale=0.0001, dependency_aware=False, dummy_input=None): super().__init__(model, config_list, optimizer=optimizer, pruning_algorithm='slim', trainer=trainer, criterion=criterion, - num_iterations=1, epochs_per_iteration=sparsity_training_epochs, dependency_aware=dependency_aware, + num_iterations=1, epochs_per_iteration=sparsifying_training_epochs, dependency_aware=dependency_aware, dummy_input=dummy_input) self.scale = scale self.patch_optimizer_before(self._callback) @@ -459,8 +472,9 @@ class TaylorFOWeightFilterPruner(IterativePruner): and include `model, optimizer, criterion, epoch` as function arguments. criterion : function Function used to calculate the loss between the target and the output. - sparsity_training_epochs: int - The number of epochs to collect the contributions. + For example, you can use ``torch.nn.CrossEntropyLoss()`` as input. + sparsifying_training_batches: int + The number of batches to collect the contributions. Note that the number need to be less than the maximum batch number in one epoch. dependency_aware: bool If prune the model in a dependency-aware way. If it is `True`, this pruner will prune the model according to the l2-norm of weights and the channel-dependency or @@ -472,14 +486,14 @@ class TaylorFOWeightFilterPruner(IterativePruner): dummy_input : torch.Tensor The dummy input to analyze the topology constraints. Note that, the dummy_input should on the same device with the model. - """ - def __init__(self, model, config_list, optimizer, trainer, criterion, sparsity_training_epochs=1, dependency_aware=False, - dummy_input=None): + def __init__(self, model, config_list, optimizer, trainer, criterion, sparsifying_training_batches=1, + dependency_aware=False, dummy_input=None): super().__init__(model, config_list, optimizer=optimizer, pruning_algorithm='taylorfo', trainer=trainer, - criterion=criterion, num_iterations=1, epochs_per_iteration=sparsity_training_epochs, - dependency_aware=dependency_aware, dummy_input=dummy_input) + criterion=criterion, statistics_batch_num=sparsifying_training_batches, num_iterations=1, + epochs_per_iteration=1, dependency_aware=dependency_aware, + dummy_input=dummy_input) def _supported_dependency_aware(self): return True @@ -503,10 +517,11 @@ class ActivationAPoZRankFilterPruner(IterativePruner): and include `model, optimizer, criterion, epoch` as function arguments. criterion : function Function used to calculate the loss between the target and the output. + For example, you can use ``torch.nn.CrossEntropyLoss()`` as input. activation: str The activation type. - sparsity_training_epochs: int - The number of epochs to statistic the activation. + sparsifying_training_batches: int + The number of batches to collect the contributions. Note that the number need to be less than the maximum batch number in one epoch. dependency_aware: bool If prune the model in a dependency-aware way. If it is `True`, this pruner will prune the model according to the l2-norm of weights and the channel-dependency or @@ -522,10 +537,11 @@ class ActivationAPoZRankFilterPruner(IterativePruner): """ def __init__(self, model, config_list, optimizer, trainer, criterion, activation='relu', - sparsity_training_epochs=1, dependency_aware=False, dummy_input=None): + sparsifying_training_batches=1, dependency_aware=False, dummy_input=None): super().__init__(model, config_list, pruning_algorithm='apoz', optimizer=optimizer, trainer=trainer, criterion=criterion, dependency_aware=dependency_aware, dummy_input=dummy_input, - activation=activation, num_iterations=1, epochs_per_iteration=sparsity_training_epochs) + activation=activation, statistics_batch_num=sparsifying_training_batches, num_iterations=1, + epochs_per_iteration=1) self.patch_optimizer(self.update_mask) def _supported_dependency_aware(self): @@ -550,10 +566,11 @@ class ActivationMeanRankFilterPruner(IterativePruner): and include `model, optimizer, criterion, epoch` as function arguments. criterion : function Function used to calculate the loss between the target and the output. + For example, you can use ``torch.nn.CrossEntropyLoss()`` as input. activation: str The activation type. - sparsity_training_epochs: int - The number of batches to statistic the activation. + sparsifying_training_batches: int + The number of batches to collect the contributions. Note that the number need to be less than the maximum batch number in one epoch. dependency_aware: bool If prune the model in a dependency-aware way. If it is `True`, this pruner will prune the model according to the l2-norm of weights and the channel-dependency or @@ -568,10 +585,11 @@ class ActivationMeanRankFilterPruner(IterativePruner): """ def __init__(self, model, config_list, optimizer, trainer, criterion, activation='relu', - sparsity_training_epochs=1, dependency_aware=False, dummy_input=None): + sparsifying_training_batches=1, dependency_aware=False, dummy_input=None): super().__init__(model, config_list, pruning_algorithm='mean_activation', optimizer=optimizer, trainer=trainer, criterion=criterion, dependency_aware=dependency_aware, dummy_input=dummy_input, - activation=activation, num_iterations=1, epochs_per_iteration=sparsity_training_epochs) + activation=activation, statistics_batch_num=sparsifying_training_batches, num_iterations=1, + epochs_per_iteration=1) self.patch_optimizer(self.update_mask) def _supported_dependency_aware(self): diff --git a/nni/algorithms/compression/pytorch/pruning/structured_pruning_masker.py b/nni/algorithms/compression/pytorch/pruning/structured_pruning_masker.py index 671811e138..eb3cc06ebd 100644 --- a/nni/algorithms/compression/pytorch/pruning/structured_pruning_masker.py +++ b/nni/algorithms/compression/pytorch/pruning/structured_pruning_masker.py @@ -473,7 +473,7 @@ class TaylorFOWeightFilterPrunerMasker(StructuredWeightMasker): def __init__(self, model, pruner, statistics_batch_num=1): super().__init__(model, pruner) - self.pruner.statistics_batch_num = statistics_batch_num + self.statistics_batch_num = statistics_batch_num self.pruner.iterations = 0 self.pruner.set_wrappers_attribute("contribution", None) self.pruner.patch_optimizer(self.calc_contributions) @@ -497,13 +497,13 @@ def calc_contributions(self): Calculate the estimated importance of filters as a sum of individual contribution based on the first order taylor expansion. """ - if self.pruner.iterations >= self.pruner.statistics_batch_num: + if self.pruner.iterations >= self.statistics_batch_num: return for wrapper in self.pruner.get_modules_wrapper(): filters = wrapper.module.weight.size(0) contribution = ( - wrapper.module.weight*wrapper.module.weight.grad).data.pow(2).view(filters, -1).sum(dim=1) + wrapper.module.weight * wrapper.module.weight.grad).data.pow(2).view(filters, -1).sum(dim=1) if wrapper.contribution is None: wrapper.contribution = contribution else: @@ -512,7 +512,7 @@ def calc_contributions(self): self.pruner.iterations += 1 def get_channel_sum(self, wrapper, wrapper_idx): - if self.pruner.iterations < self.pruner.statistics_batch_num: + if self.pruner.iterations < self.statistics_batch_num: return None if wrapper.contribution is None: return None @@ -524,6 +524,8 @@ def __init__(self, model, pruner, statistics_batch_num=1, activation='relu'): super().__init__(model, pruner) self.statistics_batch_num = statistics_batch_num self.pruner.hook_id = self._add_activation_collector(self.pruner) + self.pruner.iterations = 0 + self.pruner.patch_optimizer(self._iteration_counter) assert activation in ['relu', 'relu6'] if activation == 'relu': @@ -533,6 +535,9 @@ def __init__(self, model, pruner, statistics_batch_num=1, activation='relu'): else: self.pruner.activation = None + def _iteration_counter(self): + self.pruner.iterations += 1 + def _add_activation_collector(self, pruner): def collector(collected_activation): def hook(module_, input_, output): diff --git a/test/ut/sdk/test_compressor_torch.py b/test/ut/sdk/test_compressor_torch.py index 350e46025e..e6e34d30e6 100644 --- a/test/ut/sdk/test_compressor_torch.py +++ b/test/ut/sdk/test_compressor_torch.py @@ -201,7 +201,7 @@ def test_torch_taylorFOweight_pruner(self): model = TorchModel() optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5) - pruner = torch_pruner.TaylorFOWeightFilterPruner(model, config_list, optimizer, trainer=None, criterion=None, sparsity_training_epochs=1) + pruner = torch_pruner.TaylorFOWeightFilterPruner(model, config_list, optimizer, trainer=None, criterion=None, sparsifying_training_batches=1) x = torch.rand((1, 1, 28, 28), requires_grad=True) model.conv1.module.weight.data = torch.tensor(w1).float()