From 0c3f8476e718111bc67ae4039bdb9ee22e418d8e Mon Sep 17 00:00:00 2001
From: Ming Huang
Date: Tue, 28 Dec 2021 03:52:37 +0000
Subject: [PATCH 1/2] Dynamic graph support for Automatic SParsity.

1. Added step and clear_grad functions to OptimizerWithSparsityGuarantee.
2. Added a step function to ASPHelper.
3. Added prune_model_by_layer and renamed the original prune_model to
   prune_model_by_program.
4. Moved paddle.static.sparsity to paddle.sparsity.
---
 python/paddle/__init__.py                     |   1 +
 .../paddle/fluid/contrib/sparsity/__init__.py |   3 +-
 python/paddle/fluid/contrib/sparsity/asp.py   | 470 ++++++++++++++----
 .../contrib/sparsity/supported_layer_list.py  |  85 ++++
 .../paddle/{static => }/sparsity/__init__.py  |  10 +-
 python/paddle/static/__init__.py              |   1 -
 python/setup.py.in                            |   2 +-
 7 files changed, 469 insertions(+), 103 deletions(-)
 create mode 100644 python/paddle/fluid/contrib/sparsity/supported_layer_list.py
 rename python/paddle/{static => }/sparsity/__init__.py (70%)

diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py
index 77aceef50f8d5..7434a17008732 100755
--- a/python/paddle/__init__.py
+++ b/python/paddle/__init__.py
@@ -56,6 +56,7 @@
 import paddle.jit  # noqa: F401
 import paddle.amp  # noqa: F401
+import paddle.sparsity  # noqa: F401
 import paddle.dataset  # noqa: F401
 import paddle.inference  # noqa: F401
 import paddle.io  # noqa: F401
diff --git a/python/paddle/fluid/contrib/sparsity/__init__.py b/python/paddle/fluid/contrib/sparsity/__init__.py
index 9bf45f4272738..ec288a1287119 100644
--- a/python/paddle/fluid/contrib/sparsity/__init__.py
+++ b/python/paddle/fluid/contrib/sparsity/__init__.py
@@ -29,10 +29,11 @@
 from .asp import prune_model
 from .asp import set_excluded_layers
 from .asp import reset_excluded_layers
+from .supported_layer_list import add_supported_layer

 __all__ = [
     'calculate_density', 'check_mask_1d', 'get_mask_1d', 'check_mask_2d',
     'get_mask_2d_greedy', 'get_mask_2d_best', 'create_mask', 'check_sparsity',
     'MaskAlgo', 'CheckMethod', 'decorate', 'prune_model', 'set_excluded_layers',
-    'reset_excluded_layers'
+    'reset_excluded_layers', 'add_supported_layer'
 ]
diff --git a/python/paddle/fluid/contrib/sparsity/asp.py b/python/paddle/fluid/contrib/sparsity/asp.py
index 61e3a61fc9cd2..f26f080bfaef1 100644
--- a/python/paddle/fluid/contrib/sparsity/asp.py
+++ b/python/paddle/fluid/contrib/sparsity/asp.py
@@ -19,28 +19,61 @@
 import copy
 import numpy as np
 import paddle
+from paddle.fluid.framework import dygraph_only, in_dygraph_mode
 from paddle.fluid import global_scope, program_guard, layers
 from paddle.fluid.initializer import ConstantInitializer
 from paddle.fluid.contrib import sparsity
+from paddle.fluid.contrib.sparsity.supported_layer_list import supported_layers_and_prune_func_map, _default_pruning

 __all__ = [
     'decorate', 'prune_model', 'set_excluded_layers', 'reset_excluded_layers'
 ]


-def set_excluded_layers(main_program, param_names):
+def set_excluded_layers(param_names, main_program=None):
     r"""
     Set the names of layers whose parameters should not be pruned as sparse weights.

     Args:
+        param_names (list of string): A list containing the names of parameters.
         main_program (Program, optional): Program with model definition and its parameters.
-        param_names (list): A list contains names of parameters.
+            If None is given, then it would be set as `paddle.static.default_main_program()`.
+            Default is None.
     Examples:
        .. code-block:: python

             import paddle
-            from paddle.static import sparsity
+            from paddle import sparsity
+
+            # Dynamic Graph
+            class MyLayer(paddle.nn.Layer):
+                def __init__(self):
+                    super(MyLayer, self).__init__()
+                    self.conv1 = paddle.nn.Conv2D(
+                        in_channels=3, out_channels=4, kernel_size=3, padding=2)
+                    self.linear1 = paddle.nn.Linear(4624, 32)
+                    self.linear2 = paddle.nn.Linear(32, 32)
+                    self.linear3 = paddle.nn.Linear(32, 10)
+
+                def forward(self, img):
+                    hidden = self.conv1(img)
+                    hidden = paddle.flatten(hidden, start_axis=1)
+                    hidden = self.linear1(hidden)
+                    hidden = self.linear2(hidden)
+                    prediction = self.linear3(hidden)
+                    return prediction
+
+            my_layer = MyLayer()
+            optimizer = paddle.optimizer.SGD(
+                learning_rate=0.01, parameters=my_layer.parameters())
+
+            # Excluded layers must be set before calling decorate.
+            sparsity.set_excluded_layers(["linear_0"])
+
+            optimizer = sparsity.decorate(optimizer)
+
+
+            # Static Graph
             paddle.enable_static()

             main_program = paddle.static.Program()
@@ -56,7 +89,7 @@ def set_excluded_layers(main_program, param_names):

             # Set up layers excluded from the ASP workflow.
             # Please note, excluded_layers must be set before calling `optimizer.minimize()`.
-            sparsity.set_excluded_layers(main_program, ["need_dense_fc"])
+            sparsity.set_excluded_layers(["need_dense_fc"], main_program)

             optimizer = paddle.optimizer.SGD(learning_rate=0.1)
             optimizer = paddle.static.amp.decorate(optimizer)
             optimizer = sparsity.decorate(optimizer)
             optimizer.minimize(loss, startup_program)
     """
+    if main_program is None:
+        main_program = paddle.static.default_main_program()
     ASPHelper.set_excluded_layers(
-        main_program=main_program, param_names=param_names)
+        param_names=param_names, main_program=main_program)


@@ -76,12 +111,44 @@
 def reset_excluded_layers(main_program=None):
     r"""
     Reset excluded layers setting.

     Args:
         main_program (Program, optional): Program with model definition and its parameters.
+            If None is given, then this function would reset all excluded layers.
+            Default is None.
     Examples:
         .. code-block:: python

             import paddle
-            from paddle.static import sparsity
+            from paddle import sparsity
+
+            # Dynamic Graph
+            class MyLayer(paddle.nn.Layer):
+                def __init__(self):
+                    super(MyLayer, self).__init__()
+                    self.conv1 = paddle.nn.Conv2D(
+                        in_channels=3, out_channels=4, kernel_size=3, padding=2)
+                    self.linear1 = paddle.nn.Linear(4624, 32)
+                    self.linear2 = paddle.nn.Linear(32, 32)
+                    self.linear3 = paddle.nn.Linear(32, 10)
+
+                def forward(self, img):
+                    hidden = self.conv1(img)
+                    hidden = paddle.flatten(hidden, start_axis=1)
+                    hidden = self.linear1(hidden)
+                    hidden = self.linear2(hidden)
+                    prediction = self.linear3(hidden)
+                    return prediction
+
+            my_layer = MyLayer()
+            optimizer = paddle.optimizer.SGD(
+                learning_rate=0.01, parameters=my_layer.parameters())
+
+            # Exclude a layer, then reset; afterwards all supported layers
+            # would be pruned again.
+            sparsity.set_excluded_layers(["linear_0"])
+            sparsity.reset_excluded_layers()
+
+            optimizer = sparsity.decorate(optimizer)
+
+            # Static Graph
             paddle.enable_static()

             main_program = paddle.static.Program()
@@ -109,8 +176,10 @@ def reset_excluded_layers(main_program=None):

 def decorate(optimizer):
     r"""
-    Wrap the given optimizer as a OptimizerWithSparsityGuarantee,
-    which would insert necessary ops for ASP workflows when calling minimize()
+    Wrap the given optimizer as an OptimizerWithSparsityGuarantee.
+    In dynamic graph mode, ASP creates the mask variables for supported parameters
+    as soon as `decorate` is called. In static graph mode, ASP creates the mask
+    variables and inserts the necessary masking ops when minimize() is called.

     Args:
         optimizer (Optimizer): An Optimizer used for training.
@@ -120,13 +189,42 @@ def decorate(optimizer):
     Examples:
         .. code-block:: python

             import paddle
-            from paddle.static import sparsity
+            from paddle import sparsity
+
+            # Dynamic Graph
+            class MyLayer(paddle.nn.Layer):
+                def __init__(self):
+                    super(MyLayer, self).__init__()
+                    self.conv1 = paddle.nn.Conv2D(
+                        in_channels=3, out_channels=4, kernel_size=3, padding=2)
+                    self.linear1 = paddle.nn.Linear(4624, 32)
+                    self.linear2 = paddle.nn.Linear(32, 32)
+                    self.linear3 = paddle.nn.Linear(32, 10)
+
+                def forward(self, img):
+                    hidden = self.conv1(img)
+                    hidden = paddle.flatten(hidden, start_axis=1)
+                    hidden = self.linear1(hidden)
+                    hidden = self.linear2(hidden)
+                    prediction = self.linear3(hidden)
+                    return prediction
+
+            my_layer = MyLayer()
+            optimizer = paddle.optimizer.SGD(
+                learning_rate=0.01, parameters=my_layer.parameters())
+
+            # Excluded layers must be set before calling decorate.
+            sparsity.set_excluded_layers(["linear_0"])

-            main_program = paddle.static.Program()
-            startup_program = paddle.static.Program()
+            optimizer = sparsity.decorate(optimizer)

+            # Static Graph
             paddle.enable_static()

+            main_program = paddle.static.Program()
+            startup_program = paddle.static.Program()
+
             with paddle.static.program_guard(main_program, startup_program):
                 input_data = paddle.static.data(name='data', shape=[None, 128])
                 label = paddle.static.data(name='label', shape=[None, 10])
@@ -146,28 +244,21 @@ def decorate(optimizer):
     return ASPHelper.decorate(optimizer)


-def prune_model(main_program=None,
-                n=2,
-                m=4,
-                mask_algo='mask_1d',
-                with_mask=True):
+def prune_model(model, n=2, m=4, mask_algo='mask_1d', with_mask=True):
     r"""
-    Pruning parameters of supported layers in :attr:`main_program` via
+    Pruning parameters of supported layers in :attr:`model` via
     specified mask generation function given by :attr:`mask_algo`. This
     function supports both training and inference controlled by :attr:`with_mask`.
     If :attr:`with_mask` is True, it would also prune parameter related ASP mask Variables,
     else it only prunes parameters.

-    *Note*: If parameters are supported and in FP16, please set :attr:`n`=2, :attr:`m`=4,
-    if they in FP32, then :attr:`n`=1, :attr:`m`=2` to further enable Sparse Tensor Core acceleration.
-
-    *Note*: If calling this function with :attr:`with_mask`, it should call `OptimizerWithSparsityGuarantee.minimize`
+    *Note*: (Static graph mode) If calling this function with :attr:`with_mask` being True, `OptimizerWithSparsityGuarantee.minimize`
     and initialization (`exe.run(startup_program)`) should be called beforehand, so that the mask
     Variables can be found. Typically, set `with_mask` to True for training (after calling
     `OptimizerWithSparsityGuarantee.minimize`) and to False for inference only. To obtain
     OptimizerWithSparsityGuarantee, please see `sparsity.decorate()`.

     Args:
-        main_program (Program, optional): Program with model definition and its parameters. Default is `paddle.static.default_main_program()
+        model (Program|nn.Layer): Program with model definition and its parameters, or an object of `paddle.nn.Layer`.
         n (int): n of `n:m` sparse pattern.
         m (int): m of `n:m` sparse pattern.
         mask_algo (string, optional): The function name to generate sparse masks. Default is `mask_1d`.
     Examples:
        .. code-block:: python

             import paddle
-            from paddle.static import sparsity
+            from paddle import sparsity
+
+            # Dynamic Graph
+            class MyLayer(paddle.nn.Layer):
+                def __init__(self):
+                    super(MyLayer, self).__init__()
+                    self.conv1 = paddle.nn.Conv2D(
+                        in_channels=3, out_channels=4, kernel_size=3, padding=2)
+                    self.linear1 = paddle.nn.Linear(4624, 32)
+                    self.linear2 = paddle.nn.Linear(32, 32)
+                    self.linear3 = paddle.nn.Linear(32, 10)
+
+                def forward(self, img):
+                    hidden = self.conv1(img)
+                    hidden = paddle.flatten(hidden, start_axis=1)
+                    hidden = self.linear1(hidden)
+                    hidden = self.linear2(hidden)
+                    prediction = self.linear3(hidden)
+                    return prediction
+
+            my_layer = MyLayer()
+            optimizer = paddle.optimizer.SGD(
+                learning_rate=0.01, parameters=my_layer.parameters())
+
+            # Excluded layers must be set before calling decorate.
+            sparsity.set_excluded_layers(["linear_0"])
+
+            # Must call `sparsity.decorate` first before calling `sparsity.prune_model`.
+            optimizer = sparsity.decorate(optimizer)
+            sparsity.prune_model(my_layer, mask_algo='mask_2d_best')
+
+            # Static Graph
             paddle.enable_static()

             main_program = paddle.static.Program()
@@ -223,11 +345,21 @@ def prune_model(main_program=None,
         'mask_2d_best': sparsity.MaskAlgo.MASK_2D_BEST
     }
     assert (mask_algo in MaskAlgo_mapping), \
-        'The "mask_algo" should be one of ["mask_1d", "mask_2d_greedy", "mask_2d_best"]'
-
-    return ASPHelper.prune_model(
-        place=place,
-        main_program=main_program,
+        'The "mask_algo" should be one of ["mask_1d", "mask_2d_greedy", "mask_2d_best"]'
+
+    prune_func = None
+    if isinstance(model, paddle.nn.Layer):
+        prune_func = ASPHelper.prune_model_by_layer
+    elif isinstance(model, paddle.static.Program):
+        prune_func = ASPHelper.prune_model_by_program
+    else:
+        raise TypeError(
+            "model should be paddle.nn.Layer or paddle.static.Program, but got {}".
+            format(type(model)))
+
+    return prune_func(
+        place,
+        model,
         n=n,
         m=m,
         mask_algo=MaskAlgo_mapping[mask_algo],
@@ -281,12 +413,12 @@ class ASPHelper(object):
     """

     MASK_APPENDDED_NAME = '_asp_mask'
-    SUPPORTED_LAYERS = {'fc': 'w_0', 'linear': 'w_0', 'conv2d': 'w_0'}
+    PADDLE_WEIGHT_SUFFIX = "w_"

     __asp_info = {}

     @classmethod
-    def set_excluded_layers(cls, main_program, param_names):
+    def set_excluded_layers(cls, param_names, main_program):
         r"""
         This is the implementation of `sparsity.set_excluded_layers`, for details please see explanation in `sparsity.set_excluded_layers`.
         """
@@ -299,8 +431,8 @@ def reset_excluded_layers(cls, main_program=None):
         This is the implementation of `sparsity.reset_excluded_layers`, for details please see explanation in `sparsity.reset_excluded_layers`.
         """
         if main_program is None:
-            for asp_info in cls.__asp_info:
-                asp_info.reset_excluded_layers()
+            for prog in cls.__asp_info:
+                cls.__asp_info[prog].reset_excluded_layers()
         else:
             cls._get_program_asp_info(main_program).reset_excluded_layers()

@@ -309,20 +441,24 @@ def reset_excluded_layers(cls, main_program=None):
     @staticmethod
     def decorate(optimizer):
         r"""
         This is the implementation of `sparsity.decorate`, for details please see explanation in `sparsity.decorate`.
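+        In dynamic graph mode, the mask variables are created here, immediately,
+        for the optimizer's parameter list; in static graph mode, their creation
+        is deferred to `minimize`.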
""" + if in_dygraph_mode(): + main_prog = paddle.static.default_main_program() + startup_prog = paddle.static.default_startup_program() + ASPHelper._create_mask_variables(main_prog, startup_prog, + optimizer._parameter_list) return OptimizerWithSparsityGuarantee(optimizer) @classmethod - def prune_model(cls, - place, - main_program=None, - n=2, - m=4, - mask_algo=sparsity.MaskAlgo.MASK_1D, - with_mask=True): + def prune_model_by_program(cls, + place, + main_program=None, + n=2, + m=4, + mask_algo=sparsity.MaskAlgo.MASK_1D, + with_mask=True): r""" This is the implementation of `sparsity.prune_model`, for details please see explanation in `sparsity.prune_model`. """ - checked_func_name = sparsity.CheckMethod.get_checking_method(mask_algo) if main_program is None: main_program = paddle.static.default_main_program() @@ -333,35 +469,77 @@ def prune_model(cls, weight_tensor = global_scope().find_var(param.name).get_tensor() weight_nparray = np.array(weight_tensor) - # The double transpose ops here make sure pruning direction consistent with cuSparseLt. - # SPMMA in cuSparseLt: D = (AxB) + C, where matrix A (mxk) is sparse matrix. - # cuSparseLt would prune matrix A along k dimension. - # In sparse training, layer weight matriices is viewed sparse matrix A, so - # the math fomula should be 'Act(WX + b)'. However, default fomula in PaddlePaddle - # is 'Act(XW + b)'. For enabling SPMMA, weights and inputs should be transposed - # for computing, Act( (W^T X^T)^T + b). Therefore, we have to prune alog k dimension - # of W^T, which is m dimension of W. Moreove, all mask generating functions in - # sparsity/utils is row-major pruning. That is the reason we have to transpose weight - # matrices beforce invoking create_mask. Then we transpose the result maks to make - # sure its shape to be the same as the input weight. - weight_sparse_mask = sparsity.create_mask( - weight_nparray.T, func_name=mask_algo, n=n, m=m).T - weight_pruned_nparray = np.multiply(weight_nparray, - weight_sparse_mask) + prune_func = ASPHelper._get_prune_func_by_name(param.name) + + weight_pruned_nparray, weight_sparse_mask = \ + prune_func(weight_nparray, m, n, mask_algo, param.name) + weight_pruned_nparray = weight_pruned_nparray.astype( + weight_nparray.dtype) weight_tensor.set(weight_pruned_nparray, place) - assert sparsity.check_sparsity(weight_pruned_nparray.T, n=n, m=m, func_name=checked_func_name), \ - 'Pruning {} weight matrix failure!!!'.format(param.name) + if with_mask: weight_mask_param = global_scope().find_var( ASPHelper._get_mask_name(param.name)) assert weight_mask_param is not None, \ - 'Cannot find {} variable, please call ASPHelper.minimize' \ + 'Cannot find {} variable, please call optimizer.minimize (' \ + 'paddle.sparsity.decorate(optimizer).minimize(loss)' \ ' and initialization (exe.run(startup_program)) first!'.format(ASPHelper._get_mask_name(param.name)) weight_mask_tensor = weight_mask_param.get_tensor() + weight_sparse_mask = weight_sparse_mask.astype( + np.array(weight_mask_tensor).dtype) weight_mask_tensor.set(weight_sparse_mask, place) asp_info.update_masks(param.name, weight_sparse_mask) return asp_info.masks.copy() + @classmethod + def prune_model_by_layer(cls, + place, + layer, + n=2, + m=4, + mask_algo=sparsity.MaskAlgo.MASK_1D, + with_mask=True): + r""" + This is the implementation of `sparsity.prune_model`, for details please see explanation in `sparsity.prune_model`. 
+ """ + if in_dygraph_mode(): + main_program = paddle.static.default_main_program() + asp_info = cls._get_program_asp_info(main_program) + + for param in layer.parameters(): + if ASPHelper._is_supported_layer(main_program, param.name): + weight_nparray = param.numpy() + + prune_func = ASPHelper._get_prune_func_by_name(param.name) + + weight_pruned_nparray, weight_sparse_mask = \ + prune_func(weight_nparray, m, n, mask_algo, param.name) + + weight_pruned_nparray = weight_pruned_nparray.astype( + weight_nparray.dtype) + param.set_value(weight_pruned_nparray) + + if with_mask: + weight_mask_param = asp_info.mask_vars.get(param.name, + None) + assert weight_mask_param is not None, \ + 'Cannot find {} variable, please call sparsity.decorate() to' \ + ' decorate your optimizer first!'.format(ASPHelper._get_mask_name(param.name)) + weight_mask_param.set_value(weight_sparse_mask) + + asp_info.update_masks(param.name, weight_sparse_mask) + + return asp_info.masks.copy() + else: + for param in layer.parameters(): + return ASPHelper.prune_model_by_program( + place, + param.block.program, + n=n, + m=m, + mask_algo=mask_algo, + with_mask=with_mask) + @staticmethod def _get_mask_name(param_name): r""" @@ -372,7 +550,7 @@ def _get_mask_name(param_name): Returns: string: The mask name of :attr:`param_name`. """ - return param_name + ASPHelper.MASK_APPENDDED_NAME + return param_name + "." + ASPHelper.MASK_APPENDDED_NAME @staticmethod def _get_not_ASP_relevant_vars(main_program): @@ -386,7 +564,9 @@ def _get_not_ASP_relevant_vars(main_program): """ var_list = [] for param in main_program.global_block().all_parameters(): - if ASPHelper.MASK_APPENDDED_NAME not in param.name: + param_name_list = param.name.split('.') + + if ASPHelper.MASK_APPENDDED_NAME not in param_name_list: var_list.append(param) return var_list @@ -422,19 +602,46 @@ def _is_supported_layer(cls, main_program, param_name): # fc_0.w_0 -> True # fc_0.b_0 -> False """ - if ASPHelper.MASK_APPENDDED_NAME in param_name: + param_name_list = param_name.split('.') + + if ASPHelper.MASK_APPENDDED_NAME in param_name_list: return False for layer in cls._get_program_asp_info(main_program).excluded_layers: if layer in param_name: return False - for name in ASPHelper.SUPPORTED_LAYERS: - if name in param_name and \ - ASPHelper.SUPPORTED_LAYERS[name] in param_name: - return True + if param_name in supported_layers_and_prune_func_map: + return True + + param_name_no_weight_suffix = param_name_list[0] + param_type_suffix = param_name_list[1] + layer_name = param_name_no_weight_suffix[:param_name_no_weight_suffix. + rfind('_')] + if ASPHelper.PADDLE_WEIGHT_SUFFIX not in param_type_suffix: + return False + + if param_name_no_weight_suffix in supported_layers_and_prune_func_map or \ + layer_name in supported_layers_and_prune_func_map: + return True + return False + @classmethod + def _get_prune_func_by_name(cls, param_name): + func = supported_layers_and_prune_func_map.get(param_name, None) + param_name_no_weight_suffix = param_name.split('.')[0] + if func is None: + func = supported_layers_and_prune_func_map.get( + param_name_no_weight_suffix, None) + if func is None: + layer_name = param_name_no_weight_suffix[: + param_name_no_weight_suffix. 
+                                                     rfind('_')]
+            func = supported_layers_and_prune_func_map.get(layer_name,
+                                                           _default_pruning)
+        return func
+
     @classmethod
     def _minimize(cls,
                   optimizer,
@@ -474,14 +681,38 @@ def _minimize(cls,
         optimizer_ops, params_and_grads = optimizer.minimize(
             loss, startup_program, parameter_list, no_grad_set=no_grad_set)
-        cls._create_mask_variables(main_program, startup_program,
-                                   params_and_grads)
-        cls._insert_sparse_mask_ops(main_program, params_and_grads)
+
+        params_only = [pg[0] for pg in params_and_grads]
+        cls._create_mask_variables(main_program, startup_program, params_only)
+        cls._insert_sparse_mask_ops(main_program, params_only)
         return optimizer_ops, params_and_grads

     @classmethod
+    @dygraph_only
+    def _step(cls, optimizer):
+        r"""
+        This function is a decorator of the `step` function in `Optimizer`.
+        There are two steps:
+
+        1. Call :attr:`optimizer`.step()
+        2. Mask parameters with sparse masks.
+
+        *Note*: Please use `sparsity.decorate` together with `minimize` instead when applying
+        distributed training with `Fleet`, because an invisible graph optimization in
+        `Fleet.minimize()` makes the training graph unable to be modified afterwards.
+
+        Args:
+            optimizer (Optimizer): An Optimizer used for training.
+        """
+        optimizer.step()
+        main_prog = paddle.static.default_main_program()
+        startup_prog = paddle.static.default_startup_program()
+        with paddle.fluid.dygraph.no_grad():
+            ASPHelper._insert_sparse_mask_ops(main_prog,
+                                              optimizer._parameter_list)
+
+    @classmethod
-    def _create_mask_variables(cls, main_program, startup_program,
-                               params_and_grads):
+    def _create_mask_variables(cls, main_program, startup_program, params):
         r"""
         Create sparse mask Tensors according to supported layers in :attr:`main_program`.
         This function is called in the second step of `ASPHelper._minimize`

         Args:
             main_program (Program): Program with model definition and its parameters.
             startup_program (Program): Program for initializing parameters.
-            params_and_grads (list): Variable pairs of parameters and their gradients.
+            params (list): Variable parameters.
         """
         asp_info = cls._get_program_asp_info(main_program)
         with program_guard(main_program, startup_program):
-            for param_and_grad in params_and_grads:
-                if ASPHelper._is_supported_layer(main_program,
-                                                 param_and_grad[0].name):
-                    mask_param = layers.create_parameter(
-                        name=param_and_grad[0].name +
-                        ASPHelper.MASK_APPENDDED_NAME,
-                        shape=param_and_grad[0].shape,
-                        dtype=param_and_grad[0].dtype,
-                        default_initializer=ConstantInitializer(value=1.0))
-                    mask_param.stop_gradient = True
-                    mask_param.trainable = False
-                    asp_info.update_mask_vars(param_and_grad[0].name,
-                                              mask_param)
+            for param in params:
+                if ASPHelper._is_supported_layer(main_program, param.name):
+                    if param.name not in asp_info.mask_vars:
+                        mask_param = layers.create_parameter(
+                            name=ASPHelper._get_mask_name(param.name),
+                            shape=param.shape,
+                            dtype=param.dtype,
+                            default_initializer=ConstantInitializer(value=1.0))
+                        mask_param.stop_gradient = True
+                        mask_param.trainable = False
+                        asp_info.update_mask_vars(param.name, mask_param)

     @classmethod
-    def _insert_sparse_mask_ops(cls, main_program, param_grads):
+    def _insert_sparse_mask_ops(cls, main_program, params):
         r"""
         Insert masking ops at the end of the parameter update.
         This function is called in the third step of `ASPHelper._minimize`

         Args:
             main_program (Program): Program with model definition and its parameters.
-            params_and_grads (list): Variable pairs of parameters and their gradients.
+            params (list): Variable parameters.
         """
         block = main_program.global_block()
         asp_info = cls._get_program_asp_info(main_program)
-        for param_grad in param_grads:
-            if param_grad[0].name in asp_info.mask_vars:
+        for param in params:
+            if param.name in asp_info.mask_vars:
                 block.append_op(
                     type='elementwise_mul',
-                    inputs={
-                        "X": param_grad[0],
-                        'Y': asp_info.mask_vars[param_grad[0].name]
-                    },
-                    outputs={'Out': param_grad[0]},
+                    inputs={"X": param,
+                            'Y': asp_info.mask_vars[param.name]},
+                    outputs={'Out': param},
                     attrs={'axis': -1,
                            'use_mkldnn': False})

@@ -543,8 +770,9 @@ class OptimizerWithSparsityGuarantee(object):
     def __init__(self, optimizer):
         self._optimizer = optimizer
-        self._learning_rate = optimizer._learning_rate
-        self._learning_rate_map = optimizer._learning_rate_map
+
+    def __getattr__(self, item):
+        return getattr(self._optimizer, item)

     def minimize(self,
                  loss,
@@ -569,3 +797,55 @@ def minimize(self,
             startup_program=startup_program,
             parameter_list=parameter_list,
             no_grad_set=no_grad_set)
+
+    @dygraph_only
+    def step(self):
+        r"""
+        This function is a decorator of the `step` function in `Optimizer`.
+        There are two steps:
+
+        1. Call :attr:`optimizer`.step()
+        2. Mask parameters with sparse masks.
+
+        *Note*: Please use `sparsity.decorate` together with `minimize` instead when applying
+        distributed training with `Fleet`, because an invisible graph optimization in
+        `Fleet.minimize()` makes the training graph unable to be modified afterwards.
+        """
+        ASPHelper._step(self._optimizer)
+
+    @dygraph_only
+    def state_dict(self):
+        r"""
+        This function is a decorator of the `state_dict` function in `Optimizer`.
+
+        Returns:
+            state_dict(dict): A dict containing all the Tensors used by the optimizer,
+                including the ASP mask Variables.
+        """
+        state_dict = self._optimizer.state_dict()
+        asp_info = ASPHelper._get_program_asp_info(
+            paddle.static.default_main_program())
+        for param_name, var in asp_info.mask_vars.items():
+            state_dict.update({ASPHelper._get_mask_name(param_name): var})
+        return state_dict
+
+    @dygraph_only
+    def set_state_dict(self, state_dict):
+        r"""
+        This function is a decorator of the `set_state_dict` function in `Optimizer`.
+
+        Args:
+            state_dict(dict): A dict containing all the Tensors needed by the optimizer,
+                including the ASP mask Variables.
+        Returns:
+            None
+        """
+        asp_info = ASPHelper._get_program_asp_info(
+            paddle.static.default_main_program())
+        for param_name, var in asp_info.mask_vars.items():
+            param_mask_name = ASPHelper._get_mask_name(param_name)
+            assert param_mask_name in state_dict, \
+                "The {} is not found.".format(param_mask_name)
+            var.set_value(state_dict[param_mask_name])
+            asp_info.update_masks(param_name, var.numpy())
+        return self._optimizer.set_state_dict(state_dict)
diff --git a/python/paddle/fluid/contrib/sparsity/supported_layer_list.py b/python/paddle/fluid/contrib/sparsity/supported_layer_list.py
new file mode 100644
index 0000000000000..13502ddb82668
--- /dev/null
+++ b/python/paddle/fluid/contrib/sparsity/supported_layer_list.py
@@ -0,0 +1,85 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import paddle
+from paddle.fluid.contrib import sparsity
+import threading
+
+__all__ = ['add_supported_layer']
+
+
+def _default_pruning(weight_nparray, m, n, func_name, param_name):
+
+    checked_func_name = sparsity.CheckMethod.get_checking_method(func_name)
+
+    # The double transpose ops here make sure the pruning direction is consistent
+    # with cuSparseLt. SPMMA in cuSparseLt: D = (AxB) + C, where matrix A (mxk) is
+    # the sparse matrix. cuSparseLt would prune matrix A along the k dimension.
+    # In sparse training, a layer's weight matrix is viewed as the sparse matrix A,
+    # so the math formula should be 'Act(WX + b)'. However, the default formula in
+    # PaddlePaddle is 'Act(XW + b)'. For enabling SPMMA, weights and inputs should
+    # be transposed for computing, Act( (W^T X^T)^T + b). Therefore, we have to
+    # prune along the k dimension of W^T, which is the m dimension of W. Moreover,
+    # all mask generating functions in sparsity/utils do row-major pruning. That is
+    # the reason we have to transpose the weight matrix before invoking create_mask.
+    # Then we transpose the resulting mask to make sure its shape is the same as
+    # that of the input weight.
+    weight_sparse_mask = sparsity.create_mask(
+        weight_nparray.T, func_name=func_name, n=n, m=m).T
+    weight_pruned_nparray = np.multiply(weight_nparray, weight_sparse_mask)
+    assert sparsity.check_sparsity(weight_pruned_nparray.T, n=n, m=m, func_name=checked_func_name), \
+        'Pruning {} weight matrix failure!!!'.format(param_name)
+    return weight_pruned_nparray, weight_sparse_mask
+
+
+# Registry mapping supported layer names to their pruning functions.
+# add_supported_layer() substitutes _default_pruning when no pruning
+# function is given for a key.
+_supported_layers_and_prune_func_map_lock = threading.Lock()
+supported_layers_and_prune_func_map = {}
+
+
+def add_supported_layer(layer, pruning_func=None):
+    r"""
+    Add a supported layer and its corresponding pruning function.
+
+    Args:
+        layer (string|Layer): The name or type of the layer to be supported. If :attr:`layer`
+            is a `Layer`, it is converted to its snake-case name internally. ASP uses this
+            name to match parameter names and to call the corresponding pruning function.
+        pruning_func (function, optional): A function which receives five arguments
+            (weight_nparray, m, n, func_name, param_name). weight_nparray is a nparray of
+            the weight and param_name is the name of the weight. For m, n, and func_name,
+            please see `prune_model` for details.
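+
+    Examples:
+        A minimal sketch of registering a custom layer type together with its own
+        pruning function; `MyMaskedLinear` and `my_pruning` below are illustrative
+        names, not part of the ASP API:
+
+        .. code-block:: python
+
+            import numpy as np
+            import paddle
+            from paddle.fluid.contrib import sparsity
+
+            def my_pruning(weight_nparray, m, n, func_name, param_name):
+                # Zero the (m - n) smallest-magnitude entries in every group of
+                # m consecutive values along each row. Assumes the number of
+                # weight entries is divisible by m.
+                mask = np.ones_like(weight_nparray)
+                groups = weight_nparray.reshape(-1, m)
+                mask_groups = mask.reshape(-1, m)
+                drop_idx = np.argsort(np.abs(groups), axis=1)[:, :m - n]
+                np.put_along_axis(mask_groups, drop_idx, 0.0, axis=1)
+                return weight_nparray * mask, mask
+
+            class MyMaskedLinear(paddle.nn.Layer):
+                def __init__(self):
+                    super(MyMaskedLinear, self).__init__()
+                    self.linear = paddle.nn.Linear(32, 32)
+
+                def forward(self, x):
+                    return self.linear(x)
+
+            sparsity.add_supported_layer(MyMaskedLinear, my_pruning)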
+    """
+    name = None
+    if isinstance(layer, str):
+        name = layer
+    elif isinstance(layer, paddle.fluid.dygraph.layers.Layer):
+        name = paddle.fluid.dygraph.layers._convert_camel_to_snake(
+            type(layer).__name__)
+    elif issubclass(layer, paddle.fluid.dygraph.layers.Layer):
+        name = paddle.fluid.dygraph.layers._convert_camel_to_snake(
+            layer.__name__)
+    else:
+        raise TypeError(
+            "The type of layer should be string or Layer, but got {}!".format(
+                type(layer)))
+    if pruning_func is None:
+        pruning_func = _default_pruning
+    _supported_layers_and_prune_func_map_lock.acquire()
+    supported_layers_and_prune_func_map.update({name: pruning_func})
+    _supported_layers_and_prune_func_map_lock.release()
+
+
+add_supported_layer('fc')
+add_supported_layer('linear')
+add_supported_layer('conv2d')
diff --git a/python/paddle/static/sparsity/__init__.py b/python/paddle/sparsity/__init__.py
similarity index 70%
rename from python/paddle/static/sparsity/__init__.py
rename to python/paddle/sparsity/__init__.py
index 59f794ef28aa4..8b9f8326f5e76 100644
--- a/python/paddle/static/sparsity/__init__.py
+++ b/python/paddle/sparsity/__init__.py
@@ -13,11 +13,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from ...fluid.contrib.sparsity import calculate_density  #noqa: F401
-from ...fluid.contrib.sparsity import decorate  #noqa: F401
-from ...fluid.contrib.sparsity import prune_model  #noqa: F401
-from ...fluid.contrib.sparsity import set_excluded_layers  #noqa: F401
-from ...fluid.contrib.sparsity import reset_excluded_layers  #noqa: F401
+from ..fluid.contrib.sparsity import calculate_density  #noqa: F401
+from ..fluid.contrib.sparsity import decorate  #noqa: F401
+from ..fluid.contrib.sparsity import prune_model  #noqa: F401
+from ..fluid.contrib.sparsity import set_excluded_layers  #noqa: F401
+from ..fluid.contrib.sparsity import reset_excluded_layers  #noqa: F401

 __all__ = [  #noqa
     'calculate_density',
diff --git a/python/paddle/static/__init__.py b/python/paddle/static/__init__.py
index 92aa5000dfa58..992d2f1524f9b 100644
--- a/python/paddle/static/__init__.py
+++ b/python/paddle/static/__init__.py
@@ -14,7 +14,6 @@
 # limitations under the License.

 from . import amp  # noqa: F401
-from . import sparsity  # noqa: F401
 from . import nn  # noqa: F401
 from .io import save_inference_model  # noqa: F401
 from .io import load_inference_model  # noqa: F401
diff --git a/python/setup.py.in b/python/setup.py.in
index f14111c7dabb9..dd926bc7043c4 100644
--- a/python/setup.py.in
+++ b/python/setup.py.in
@@ -363,7 +363,7 @@ packages=['paddle',
           'paddle.static',
           'paddle.static.nn',
           'paddle.static.amp',
-          'paddle.static.sparsity',
+          'paddle.sparsity',
           'paddle.tensor',
           'paddle.onnx',
           'paddle.autograd',

From 5c639a89086083ebd502e20bc6935d6b1434560a Mon Sep 17 00:00:00 2001
From: Ming Huang
Date: Tue, 28 Dec 2021 05:00:29 +0000
Subject: [PATCH 2/2] Added ASP related unit-tests.
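
1. Added test_asp_customized_pruning for add_supported_layer and custom
   pruning functions.
2. Added dynamic graph variants of the optimizer, pruning, and Fleet tests,
   plus test_asp_save_load for optimizer state with sparse masks.
3. Renamed the existing static graph tests to *_static.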
--- .../fluid/tests/unittests/asp/CMakeLists.txt | 8 +- .../asp/test_asp_customized_pruning.py | 270 ++++++++++++++++++ .../asp/test_asp_optimize_dynamic.py | 176 ++++++++++++ ...ptimize.py => test_asp_optimize_static.py} | 6 +- .../unittests/asp/test_asp_pruning_1d.py | 37 --- .../unittests/asp/test_asp_pruning_2d_best.py | 37 --- .../asp/test_asp_pruning_2d_greedy.py | 39 --- .../unittests/asp/test_asp_pruning_dynamic.py | 122 ++++++++ ...ing_base.py => test_asp_pruning_static.py} | 40 ++- .../tests/unittests/asp/test_asp_save_load.py | 176 ++++++++++++ .../tests/unittests/asp/test_asp_utils.py | 6 +- .../unittests/asp/test_fleet_with_asp.py | 91 ------ .../asp/test_fleet_with_asp_dynamic.py | 163 +++++++++++ ...p_amp.py => test_fleet_with_asp_static.py} | 59 +++- 14 files changed, 1014 insertions(+), 216 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/asp/test_asp_customized_pruning.py create mode 100644 python/paddle/fluid/tests/unittests/asp/test_asp_optimize_dynamic.py rename python/paddle/fluid/tests/unittests/asp/{test_asp_optimize.py => test_asp_optimize_static.py} (98%) delete mode 100644 python/paddle/fluid/tests/unittests/asp/test_asp_pruning_1d.py delete mode 100644 python/paddle/fluid/tests/unittests/asp/test_asp_pruning_2d_best.py delete mode 100644 python/paddle/fluid/tests/unittests/asp/test_asp_pruning_2d_greedy.py create mode 100644 python/paddle/fluid/tests/unittests/asp/test_asp_pruning_dynamic.py rename python/paddle/fluid/tests/unittests/asp/{asp_pruning_base.py => test_asp_pruning_static.py} (72%) create mode 100644 python/paddle/fluid/tests/unittests/asp/test_asp_save_load.py delete mode 100644 python/paddle/fluid/tests/unittests/asp/test_fleet_with_asp.py create mode 100644 python/paddle/fluid/tests/unittests/asp/test_fleet_with_asp_dynamic.py rename python/paddle/fluid/tests/unittests/asp/{test_fleet_with_asp_amp.py => test_fleet_with_asp_static.py} (68%) diff --git a/python/paddle/fluid/tests/unittests/asp/CMakeLists.txt b/python/paddle/fluid/tests/unittests/asp/CMakeLists.txt index 364f17c2e0d0a..324466ef3d897 100644 --- a/python/paddle/fluid/tests/unittests/asp/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/asp/CMakeLists.txt @@ -1,14 +1,14 @@ file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py") string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}") -list(REMOVE_ITEM TEST_OPS "test_fleet_with_asp") -list(REMOVE_ITEM TEST_OPS "test_fleet_with_asp_amp") +list(REMOVE_ITEM TEST_OPS "test_fleet_with_asp_static") +list(REMOVE_ITEM TEST_OPS "test_fleet_with_asp_dynamic") foreach(TEST_OP ${TEST_OPS}) py_test_modules(${TEST_OP} MODULES ${TEST_OP}) endforeach(TEST_OP) if(WITH_DISTRIBUTE) - py_test_modules(test_fleet_with_asp MODULES test_fleet_with_asp ENVS ${dist_ENVS}) - py_test_modules(test_fleet_with_asp_amp MODULES test_fleet_with_asp_amp ENVS ${dist_ENVS}) + py_test_modules(test_fleet_with_asp_dynamic MODULES test_fleet_with_asp_dynamic ENVS ${dist_ENVS}) + py_test_modules(test_fleet_with_asp_static MODULES test_fleet_with_asp_static ENVS ${dist_ENVS}) endif() diff --git a/python/paddle/fluid/tests/unittests/asp/test_asp_customized_pruning.py b/python/paddle/fluid/tests/unittests/asp/test_asp_customized_pruning.py new file mode 100644 index 0000000000000..1622270d6ddd0 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/asp/test_asp_customized_pruning.py @@ -0,0 +1,270 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2021 NVIDIA Corporation. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +import paddle +import paddle.fluid as fluid +import paddle.fluid.core as core +from paddle.fluid.contrib import sparsity +from paddle.fluid.contrib.sparsity.supported_layer_list import supported_layers_and_prune_func_map +from paddle.fluid.dygraph.layers import Layer, _convert_camel_to_snake + + +class MyOwnLayer(Layer): + def __init__(self): + super(MyOwnLayer, self).__init__() + + def forward(self, x): + return x + + +static_tensor = None +static_tensor_mask = None + + +def my_own_pruning(tensor, m, n, mask_algo, param_name): + global static_tensor + global static_tensor_mask + if static_tensor is None: + static_tensor = np.random.rand(*tensor.shape).astype(np.float32) + if static_tensor_mask is None: + static_tensor_mask = np.random.rand(*tensor.shape).astype(np.float32) + return static_tensor, static_tensor_mask + + +class TestASPAddSupportedLayer(unittest.TestCase): + def test_add_supported_layer_via_name(self): + sparsity.add_supported_layer("test_supported_1") + sparsity.add_supported_layer("test_supported_2", my_own_pruning) + sparsity.add_supported_layer(MyOwnLayer) + my_own_layer_name = _convert_camel_to_snake(MyOwnLayer.__name__) + + self.assertTrue( + "test_supported_1" in supported_layers_and_prune_func_map) + self.assertTrue( + "test_supported_2" in supported_layers_and_prune_func_map) + self.assertTrue( + "test_supported_2" in supported_layers_and_prune_func_map) + self.assertTrue(supported_layers_and_prune_func_map["test_supported_2"] + == my_own_pruning) + self.assertTrue( + my_own_layer_name in supported_layers_and_prune_func_map) + + +class TestASPDynamicCustomerizedPruneFunc(unittest.TestCase): + def setUp(self): + paddle.disable_static() + + class CustomerLayer(paddle.nn.Layer): + def __init__(self): + super(CustomerLayer, self).__init__() + + self.weight = self.create_parameter( + shape=[32, 32], attr=None, dtype='float32', is_bias=False) + self.linear1 = paddle.nn.Linear(32, 32) + self.linear2 = paddle.nn.Linear(32, 10) + + def forward(self, input_): + hidden = paddle.nn.functional.linear( + x=input_, weight=self.weight) + hidden = self.linear1(hidden) + out = self.linear2(hidden) + return out + + sparsity.add_supported_layer(CustomerLayer, my_own_pruning) + + self.layer = CustomerLayer() + self.customer_prefix = paddle.fluid.dygraph.layers._convert_camel_to_snake( + CustomerLayer.__name__) + self.supported_layer_count_ref = 3 + + def test_inference_pruning(self): + + sparsity.prune_model(self.layer, mask_algo="mask_1d", with_mask=False) + + supported_layer_count = 0 + for param in self.layer.parameters(): + mat = param.numpy() + + if sparsity.asp.ASPHelper._is_supported_layer( + paddle.static.default_main_program(), param.name): + supported_layer_count += 1 + if (self.customer_prefix in param.name): + self.assertLessEqual( + np.sum(mat.flatten() - static_tensor.flatten()), 1e-4) + else: + self.assertTrue( + sparsity.check_sparsity( + mat.T, + 
func_name=sparsity.CheckMethod.CHECK_1D, + n=2, + m=4)) + self.assertEqual(supported_layer_count, self.supported_layer_count_ref) + + def test_training_pruning(self): + optimizer = paddle.optimizer.SGD(learning_rate=0.01, + parameters=self.layer.parameters()) + optimizer = sparsity.decorate(optimizer) + + sparsity.prune_model(self.layer, mask_algo="mask_1d", with_mask=True) + + supported_layer_count = 0 + for param in self.layer.parameters(): + mat = param.numpy() + + if sparsity.asp.ASPHelper._is_supported_layer( + paddle.static.default_main_program(), param.name): + + mat_mask = sparsity.asp.ASPHelper._get_program_asp_info( + paddle.static.default_main_program()).mask_vars[ + param.name].numpy() + + supported_layer_count += 1 + if (self.customer_prefix in param.name): + self.assertLessEqual( + np.sum(mat.flatten() - static_tensor.flatten()), 1e-4) + self.assertLessEqual( + np.sum(mat_mask.flatten() - static_tensor_mask.flatten( + )), 1e-4) + else: + self.assertTrue( + sparsity.check_sparsity( + mat.T, + func_name=sparsity.CheckMethod.CHECK_1D, + n=2, + m=4)) + self.assertTrue( + sparsity.check_sparsity( + mat_mask.T, + func_name=sparsity.CheckMethod.CHECK_1D, + n=2, + m=4)) + self.assertEqual(supported_layer_count, self.supported_layer_count_ref) + + +class TestASPStaticCustomerizedPruneFunc(unittest.TestCase): + def setUp(self): + paddle.enable_static() + + self.main_program = fluid.Program() + self.startup_program = fluid.Program() + + self.customer_prefix = "customer_layer" + + def build_model(): + img = fluid.data( + name='img', shape=[None, 3, 32, 32], dtype='float32') + label = fluid.data(name='label', shape=[None, 1], dtype='int64') + hidden = fluid.layers.conv2d( + input=img, num_filters=4, filter_size=3, padding=2, act="relu") + hidden = fluid.layers.fc(input=hidden, + size=32, + act='relu', + name=self.customer_prefix) + hidden = fluid.layers.fc(input=hidden, + size=32, + act='relu', + name=self.customer_prefix) + hidden = fluid.layers.fc(input=hidden, size=32, act='relu') + prediction = fluid.layers.fc(input=hidden, size=10, act='softmax') + return img, label, prediction + + with fluid.program_guard(self.main_program, self.startup_program): + self.img, self.label, self.predict = build_model() + self.supported_layer_count_ref = 5 + + self.place = paddle.CPUPlace() + if core.is_compiled_with_cuda(): + self.place = paddle.CUDAPlace(0) + self.exe = fluid.Executor(self.place) + + sparsity.add_supported_layer(self.customer_prefix, my_own_pruning) + + def test_inference_pruning(self): + self.exe.run(self.startup_program) + + sparsity.prune_model( + self.main_program, mask_algo="mask_1d", with_mask=False) + + supported_layer_count = 0 + for param in self.main_program.global_block().all_parameters(): + mat = np.array(fluid.global_scope().find_var(param.name).get_tensor( + )) + if sparsity.asp.ASPHelper._is_supported_layer(self.main_program, + param.name): + supported_layer_count += 1 + if (self.customer_prefix in param.name): + self.assertLessEqual( + np.sum(mat.flatten() - static_tensor.flatten()), 1e-4) + else: + self.assertTrue( + sparsity.check_sparsity( + mat.T, + func_name=sparsity.CheckMethod.CHECK_1D, + n=2, + m=4)) + self.assertEqual(supported_layer_count, self.supported_layer_count_ref) + + def test_training_pruning(self): + with fluid.program_guard(self.main_program, self.startup_program): + loss = fluid.layers.mean( + fluid.layers.cross_entropy( + input=self.predict, label=self.label)) + optimizer = sparsity.decorate( + fluid.optimizer.SGD(learning_rate=0.01)) + 
optimizer.minimize(loss, self.startup_program) + + self.exe.run(self.startup_program) + + sparsity.prune_model( + self.main_program, mask_algo="mask_1d", with_mask=True) + + supported_layer_count = 0 + for param in self.main_program.global_block().all_parameters(): + mat = np.array(fluid.global_scope().find_var(param.name).get_tensor( + )) + if sparsity.asp.ASPHelper._is_supported_layer(self.main_program, + param.name): + mat_mask = np.array(fluid.global_scope().find_var( + sparsity.asp.ASPHelper._get_mask_name(param.name)) + .get_tensor()) + supported_layer_count += 1 + if (self.customer_prefix in param.name): + self.assertLessEqual( + np.sum(mat.flatten() - static_tensor.flatten()), 1e-4) + self.assertLessEqual( + np.sum(mat_mask.flatten() - static_tensor_mask.flatten( + )), 1e-4) + else: + self.assertTrue( + sparsity.check_sparsity( + mat.T, + func_name=sparsity.CheckMethod.CHECK_1D, + n=2, + m=4)) + self.assertTrue( + sparsity.check_sparsity( + mat_mask.T, + func_name=sparsity.CheckMethod.CHECK_1D, + n=2, + m=4)) + self.assertEqual(supported_layer_count, self.supported_layer_count_ref) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/asp/test_asp_optimize_dynamic.py b/python/paddle/fluid/tests/unittests/asp/test_asp_optimize_dynamic.py new file mode 100644 index 0000000000000..5039e720a93bb --- /dev/null +++ b/python/paddle/fluid/tests/unittests/asp/test_asp_optimize_dynamic.py @@ -0,0 +1,176 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2021 NVIDIA Corporation. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
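+
+# Tests ASP optimizer decoration in dynamic graph mode: supported layer
+# detection, mask variable creation, and 2:4 sparsity of parameters after
+# optimizer steps (with and without AMP).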
+ +from __future__ import print_function + +import unittest +import paddle +import paddle.fluid as fluid +import paddle.fluid.core as core +from paddle import sparsity +from paddle.fluid.contrib.sparsity.asp import ASPHelper +import numpy as np + + +class MyLayer(paddle.nn.Layer): + def __init__(self): + super(MyLayer, self).__init__() + self.conv1 = paddle.nn.Conv2D( + in_channels=3, out_channels=4, kernel_size=3, padding=2) + self.linear1 = paddle.nn.Linear(4624, 32) + self.linear2 = paddle.nn.Linear(32, 32) + self.linear3 = paddle.nn.Linear(32, 10) + + def forward(self, img): + hidden = self.conv1(img) + hidden = paddle.flatten(hidden, start_axis=1) + hidden = self.linear1(hidden) + hidden = self.linear2(hidden) + prediction = self.linear3(hidden) + return prediction + + +class TestASPDynamicOptimize(unittest.TestCase): + def setUp(self): + + self.layer = MyLayer() + + self.place = paddle.CPUPlace() + if core.is_compiled_with_cuda(): + self.place = paddle.CUDAPlace(0) + + self.optimizer = paddle.optimizer.SGD( + learning_rate=0.01, parameters=self.layer.parameters()) + + def test_is_supported_layers(self): + program = paddle.static.default_main_program() + + names = [ + 'embedding_0.w_0', 'fack_layer_0.w_0', 'conv2d_0.w_0', + 'conv2d_0.b_0', 'conv2d_1.w_0', 'conv2d_1.b_0', 'fc_0.w_0', + 'fc_0.b_0', 'fc_1.w_0', 'fc_1.b_0', 'linear_2.w_0', 'linear_2.b_0' + ] + ref = [ + False, False, True, False, True, False, True, False, True, False, + True, False + ] + for i, name in enumerate(names): + self.assertTrue( + ref[i] == ASPHelper._is_supported_layer(program, name)) + + sparsity.set_excluded_layers(['fc_1', 'conv2d_0']) + ref = [ + False, False, False, False, True, False, True, False, False, False, + True, False + ] + for i, name in enumerate(names): + self.assertTrue( + ref[i] == ASPHelper._is_supported_layer(program, name)) + + sparsity.reset_excluded_layers() + ref = [ + False, False, True, False, True, False, True, False, True, False, + True, False + ] + for i, name in enumerate(names): + self.assertTrue( + ref[i] == ASPHelper._is_supported_layer(program, name)) + + def test_decorate(self): + param_names = [param.name for param in self.layer.parameters()] + self.optimizer = sparsity.decorate(self.optimizer) + + program = paddle.static.default_main_program() + + for name in param_names: + mask_var = ASPHelper._get_program_asp_info(program).mask_vars.get( + name, None) + if ASPHelper._is_supported_layer(program, name): + self.assertTrue(mask_var is not None) + else: + self.assertTrue(mask_var is None) + + def test_asp_training(self): + self.optimizer = sparsity.decorate(self.optimizer) + + sparsity.prune_model(self.layer) + + imgs = paddle.to_tensor( + np.random.randn(64, 3, 32, 32), + dtype='float32', + place=self.place, + stop_gradient=False) + labels = paddle.to_tensor( + np.random.randint( + 10, size=(64, 1)), + dtype='float32', + place=self.place, + stop_gradient=False) + + loss_fn = paddle.nn.MSELoss(reduction='mean') + + output = self.layer(imgs) + loss = loss_fn(output, labels) + loss.backward() + self.optimizer.step() + self.optimizer.clear_grad() + + for param in self.layer.parameters(): + if ASPHelper._is_supported_layer( + paddle.static.default_main_program(), param.name): + mat = param.numpy() + self.assertTrue( + paddle.fluid.contrib.sparsity.check_sparsity( + mat.T, n=2, m=4)) + + def test_asp_training_with_amp(self): + self.optimizer = sparsity.decorate(self.optimizer) + + sparsity.prune_model(self.layer) + + imgs = paddle.to_tensor( + np.random.randn(64, 3, 32, 32), + 
dtype='float32', + place=self.place, + stop_gradient=False) + labels = paddle.to_tensor( + np.random.randint( + 10, size=(64, 1)), + dtype='float32', + place=self.place, + stop_gradient=False) + + loss_fn = paddle.nn.MSELoss(reduction='mean') + scaler = paddle.amp.GradScaler(init_loss_scaling=1024) + + with paddle.amp.auto_cast(enable=True): + output = self.layer(imgs) + loss = loss_fn(output, labels) + scaled = scaler.scale(loss) + scaled.backward() + scaler.minimize(self.optimizer, scaled) + self.optimizer.clear_grad() + + for param in self.layer.parameters(): + if ASPHelper._is_supported_layer( + paddle.static.default_main_program(), param.name): + mat = param.numpy() + self.assertTrue( + paddle.fluid.contrib.sparsity.check_sparsity( + mat.T, n=2, m=4)) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/asp/test_asp_optimize.py b/python/paddle/fluid/tests/unittests/asp/test_asp_optimize_static.py similarity index 98% rename from python/paddle/fluid/tests/unittests/asp/test_asp_optimize.py rename to python/paddle/fluid/tests/unittests/asp/test_asp_optimize_static.py index 9e5e3c924f1a5..a4b9377983060 100644 --- a/python/paddle/fluid/tests/unittests/asp/test_asp_optimize.py +++ b/python/paddle/fluid/tests/unittests/asp/test_asp_optimize_static.py @@ -20,14 +20,14 @@ import paddle import paddle.fluid as fluid import paddle.fluid.core as core -from paddle.static import sparsity +from paddle import sparsity from paddle.fluid.contrib.sparsity.asp import ASPHelper import numpy as np paddle.enable_static() -class TestASPHelper(unittest.TestCase): +class TestASPStaticOptimize(unittest.TestCase): def setUp(self): self.main_program = fluid.Program() self.startup_program = fluid.Program() @@ -87,7 +87,7 @@ def test_is_supported_layers(self): self.assertTrue( ref[i] == ASPHelper._is_supported_layer(program, name)) - sparsity.set_excluded_layers(program, ['fc_1', 'conv2d_0']) + sparsity.set_excluded_layers(['fc_1', 'conv2d_0'], program) ref = [ False, False, False, False, True, False, True, False, False, False, True, False diff --git a/python/paddle/fluid/tests/unittests/asp/test_asp_pruning_1d.py b/python/paddle/fluid/tests/unittests/asp/test_asp_pruning_1d.py deleted file mode 100644 index 7a3fa0244930c..0000000000000 --- a/python/paddle/fluid/tests/unittests/asp/test_asp_pruning_1d.py +++ /dev/null @@ -1,37 +0,0 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# Copyright (c) 2021 NVIDIA Corporation. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from __future__ import print_function - -import unittest -import paddle -from paddle.static import sparsity -from paddle.fluid.tests.unittests.asp.asp_pruning_base import TestASPHelperPruningBase - -paddle.enable_static() - - -class TestASPHelperPruning1D(TestASPHelperPruningBase): - def test_1D_inference_pruning(self): - self.run_inference_pruning_test( - 'mask_1d', paddle.fluid.contrib.sparsity.CheckMethod.CHECK_1D) - - def test_1D_training_pruning(self): - self.run_training_pruning_test( - 'mask_1d', paddle.fluid.contrib.sparsity.CheckMethod.CHECK_1D) - - -if __name__ == '__main__': - unittest.main() diff --git a/python/paddle/fluid/tests/unittests/asp/test_asp_pruning_2d_best.py b/python/paddle/fluid/tests/unittests/asp/test_asp_pruning_2d_best.py deleted file mode 100644 index e99509187038c..0000000000000 --- a/python/paddle/fluid/tests/unittests/asp/test_asp_pruning_2d_best.py +++ /dev/null @@ -1,37 +0,0 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# Copyright (c) 2021 NVIDIA Corporation. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import print_function - -import paddle -import unittest -from paddle.static import sparsity -from paddle.fluid.tests.unittests.asp.asp_pruning_base import TestASPHelperPruningBase - -paddle.enable_static() - - -class TestASPHelperPruning2DBest(TestASPHelperPruningBase): - def test_2D_best_inference_pruning(self): - self.run_inference_pruning_test( - 'mask_2d_best', paddle.fluid.contrib.sparsity.CheckMethod.CHECK_2D) - - def test_2D_best_training_pruning(self): - self.run_training_pruning_test( - 'mask_2d_best', paddle.fluid.contrib.sparsity.CheckMethod.CHECK_2D) - - -if __name__ == '__main__': - unittest.main() diff --git a/python/paddle/fluid/tests/unittests/asp/test_asp_pruning_2d_greedy.py b/python/paddle/fluid/tests/unittests/asp/test_asp_pruning_2d_greedy.py deleted file mode 100644 index 7ad6c3ae02275..0000000000000 --- a/python/paddle/fluid/tests/unittests/asp/test_asp_pruning_2d_greedy.py +++ /dev/null @@ -1,39 +0,0 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# Copyright (c) 2021 NVIDIA Corporation. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from __future__ import print_function - -import unittest -import paddle -from paddle.static import sparsity -from paddle.fluid.tests.unittests.asp.asp_pruning_base import TestASPHelperPruningBase - -paddle.enable_static() - - -class TestASPHelperPruning2DGreedy(TestASPHelperPruningBase): - def test_2D_greedy_inference_pruning(self): - self.run_inference_pruning_test( - 'mask_2d_greedy', - paddle.fluid.contrib.sparsity.CheckMethod.CHECK_2D) - - def test_2D_greedy_training_pruning(self): - self.run_training_pruning_test( - 'mask_2d_greedy', - paddle.fluid.contrib.sparsity.CheckMethod.CHECK_2D) - - -if __name__ == '__main__': - unittest.main() diff --git a/python/paddle/fluid/tests/unittests/asp/test_asp_pruning_dynamic.py b/python/paddle/fluid/tests/unittests/asp/test_asp_pruning_dynamic.py new file mode 100644 index 0000000000000..ba9e1eb48a618 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/asp/test_asp_pruning_dynamic.py @@ -0,0 +1,122 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2021 NVIDIA Corporation. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np + +import paddle +from paddle.fluid import core +from paddle import sparsity +from paddle.fluid.contrib.sparsity.asp import ASPHelper + + +class MyLayer(paddle.nn.Layer): + def __init__(self): + super(MyLayer, self).__init__() + self.conv1 = paddle.nn.Conv2D( + in_channels=3, out_channels=4, kernel_size=3, padding=2) + self.linear1 = paddle.nn.Linear(32 * 32 * 4, 32) + self.linear2 = paddle.nn.Linear(32, 32) + self.linear3 = paddle.nn.Linear(32, 10) + + def forward(self, img): + hidden = self.conv1(img) + hidden = paddle.flatten(hidden, start_axis=1) + hidden = self.linear1(hidden) + hidden = self.linear2(hidden) + prediction = self.linear3(hidden) + return prediction + + +class TestASPDynamicPruningBase(unittest.TestCase): + def setUp(self): + self.layer = MyLayer() + + place = paddle.CPUPlace() + if core.is_compiled_with_cuda(): + place = paddle.CUDAPlace(0) + + self.img = paddle.to_tensor( + np.random.uniform( + low=-0.5, high=0.5, size=(64, 3, 32, 32)), + dtype=np.float32, + place=place, + stop_gradient=False) + + def run_inference_pruning_test(self, get_mask_gen_func, + get_mask_check_func): + self.__pruning_and_checking(get_mask_gen_func, get_mask_check_func, + False) + + def run_training_pruning_test(self, get_mask_gen_func, get_mask_check_func): + + optimizer = paddle.optimizer.SGD(learning_rate=0.01, + parameters=self.layer.parameters()) + optimizer = sparsity.decorate(optimizer) + + self.__pruning_and_checking(get_mask_gen_func, get_mask_check_func, + True) + + def __pruning_and_checking(self, mask_func_name, check_func_name, + with_mask): + + sparsity.prune_model( + self.layer, mask_algo=mask_func_name, with_mask=with_mask) + + for param in self.layer.parameters(): + if ASPHelper._is_supported_layer( + paddle.static.default_main_program(), param.name): + mat = param.numpy() + 
+                self.assertTrue(
+                    paddle.fluid.contrib.sparsity.check_sparsity(
+                        mat.T, func_name=check_func_name, n=2, m=4))
+
+
+class TestASPDynamicPruning1D(TestASPDynamicPruningBase):
+    def test_1D_inference_pruning(self):
+        self.run_inference_pruning_test(
+            'mask_1d', paddle.fluid.contrib.sparsity.CheckMethod.CHECK_1D)
+
+    def test_1D_training_pruning(self):
+        self.run_training_pruning_test(
+            'mask_1d', paddle.fluid.contrib.sparsity.CheckMethod.CHECK_1D)
+
+
+class TestASPDynamicPruning2DBest(TestASPDynamicPruningBase):
+    def test_2D_best_inference_pruning(self):
+        self.run_inference_pruning_test(
+            'mask_2d_best', paddle.fluid.contrib.sparsity.CheckMethod.CHECK_2D)
+
+    def test_2D_best_training_pruning(self):
+        self.run_training_pruning_test(
+            'mask_2d_best', paddle.fluid.contrib.sparsity.CheckMethod.CHECK_2D)
+
+
+class TestASPDynamicPruning2DGreedy(TestASPDynamicPruningBase):
+    def test_2D_greedy_inference_pruning(self):
+        self.run_inference_pruning_test(
+            'mask_2d_greedy',
+            paddle.fluid.contrib.sparsity.CheckMethod.CHECK_2D)
+
+    def test_2D_greedy_training_pruning(self):
+        self.run_training_pruning_test(
+            'mask_2d_greedy',
+            paddle.fluid.contrib.sparsity.CheckMethod.CHECK_2D)
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/asp/asp_pruning_base.py b/python/paddle/fluid/tests/unittests/asp/test_asp_pruning_static.py
similarity index 72%
rename from python/paddle/fluid/tests/unittests/asp/asp_pruning_base.py
rename to python/paddle/fluid/tests/unittests/asp/test_asp_pruning_static.py
index d41a7b2b842e8..8411d3fd6d77f 100644
--- a/python/paddle/fluid/tests/unittests/asp/asp_pruning_base.py
+++ b/python/paddle/fluid/tests/unittests/asp/test_asp_pruning_static.py
@@ -20,14 +20,14 @@
 import paddle
 import paddle.fluid as fluid
 import paddle.fluid.core as core
-from paddle.static import sparsity
+from paddle import sparsity
 from paddle.fluid.contrib.sparsity.asp import ASPHelper
 import numpy as np
 
 paddle.enable_static()
 
 
-class TestASPHelperPruningBase(unittest.TestCase):
+class TestASPStaticPruningBase(unittest.TestCase):
     def setUp(self):
         self.main_program = fluid.Program()
         self.startup_program = fluid.Program()
@@ -84,3 +84,39 @@ def __pruning_and_checking(self, exe, place, mask_func_name,
             self.assertTrue(
                 paddle.fluid.contrib.sparsity.check_sparsity(
                     mat.T, func_name=check_func_name, n=2, m=4))
+
+
+class TestASPStaticPruning1D(TestASPStaticPruningBase):
+    def test_1D_inference_pruning(self):
+        self.run_inference_pruning_test(
+            'mask_1d', paddle.fluid.contrib.sparsity.CheckMethod.CHECK_1D)
+
+    def test_1D_training_pruning(self):
+        self.run_training_pruning_test(
+            'mask_1d', paddle.fluid.contrib.sparsity.CheckMethod.CHECK_1D)
+
+
+class TestASPStaticPruning2DBest(TestASPStaticPruningBase):
+    def test_2D_best_inference_pruning(self):
+        self.run_inference_pruning_test(
+            'mask_2d_best', paddle.fluid.contrib.sparsity.CheckMethod.CHECK_2D)
+
+    def test_2D_best_training_pruning(self):
+        self.run_training_pruning_test(
+            'mask_2d_best', paddle.fluid.contrib.sparsity.CheckMethod.CHECK_2D)
+
+
+class TestASPStaticPruning2DGreedy(TestASPStaticPruningBase):
+    def test_2D_greedy_inference_pruning(self):
+        self.run_inference_pruning_test(
+            'mask_2d_greedy',
+            paddle.fluid.contrib.sparsity.CheckMethod.CHECK_2D)
+
+    def test_2D_greedy_training_pruning(self):
+        self.run_training_pruning_test(
+            'mask_2d_greedy',
+            paddle.fluid.contrib.sparsity.CheckMethod.CHECK_2D)
+
+
+if __name__ == '__main__':
+    unittest.main()
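The dynamic-graph flow these new tests exercise reduces to the following
sketch (the layer and shapes are illustrative only; the sparsity calls are
the APIs this patch adds or moves):

    import paddle
    from paddle import sparsity

    layer = paddle.nn.Linear(32, 32)
    optimizer = paddle.optimizer.SGD(
        learning_rate=0.01, parameters=layer.parameters())

    # Wrap the optimizer so its step() keeps applying the sparse masks.
    optimizer = sparsity.decorate(optimizer)
    # Generate 2:4 masks and prune the supported weights in place.
    sparsity.prune_model(layer, mask_algo='mask_1d', with_mask=True)

    out = layer(paddle.rand([4, 32]))
    loss = paddle.mean(out)
    loss.backward()
    optimizer.step()  # masked update: weights stay 2:4 sparse
    optimizer.clear_grad()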
diff --git a/python/paddle/fluid/tests/unittests/asp/test_asp_save_load.py b/python/paddle/fluid/tests/unittests/asp/test_asp_save_load.py
new file mode 100644
index 0000000000000..3721f4d029b37
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/asp/test_asp_save_load.py
@@ -0,0 +1,176 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2021 NVIDIA Corporation. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+import paddle
+import paddle.fluid as fluid
+import paddle.fluid.core as core
+from paddle import sparsity
+from paddle.fluid.contrib.sparsity.asp import ASPHelper
+import numpy as np
+
+
+class MyLayer(paddle.nn.Layer):
+    def __init__(self):
+        super(MyLayer, self).__init__()
+        self.conv1 = paddle.nn.Conv2D(
+            in_channels=3, out_channels=4, kernel_size=3, padding=2)
+        self.linear1 = paddle.nn.Linear(4624, 32)
+        self.linear2 = paddle.nn.Linear(32, 32)
+        self.linear3 = paddle.nn.Linear(32, 10)
+
+    def forward(self, img):
+        hidden = self.conv1(img)
+        hidden = paddle.flatten(hidden, start_axis=1)
+        hidden = self.linear1(hidden)
+        hidden = self.linear2(hidden)
+        prediction = self.linear3(hidden)
+        return prediction
+
+
+class TestASPDynamicSaveLoad(unittest.TestCase):
+    def setUp(self):
+        paddle.disable_static()
+
+        self.layer = MyLayer()
+
+        self.place = paddle.CPUPlace()
+        if core.is_compiled_with_cuda():
+            self.place = paddle.CUDAPlace(0)
+
+        self.optimizer = paddle.optimizer.SGD(
+            learning_rate=0.01, parameters=self.layer.parameters())
+        self.optimizer = sparsity.decorate(self.optimizer)
+        sparsity.prune_model(self.layer)
+
+    def test_save_and_load(self):
+        path = "/tmp/paddle_asp_save_dy/"
+        net_path = path + "asp_net.pdparams"
+        opt_path = path + "asp_opt.pdopt"
+
+        paddle.save(self.layer.state_dict(), net_path)
+        paddle.save(self.optimizer.state_dict(), opt_path)
+
+        # Clobber every ASP mask with all-ones; only a successful restore of
+        # the checkpoints below can bring the saved masks back.
+        asp_info = ASPHelper._get_program_asp_info(
+            paddle.static.default_main_program())
+        for param_name in asp_info.mask_vars:
+            mask = asp_info.mask_vars[param_name]
+            asp_info.update_mask_vars(
+                param_name, paddle.ones(
+                    shape=mask.shape, dtype=mask.dtype))
+            asp_info.update_masks(param_name, np.ones(shape=mask.shape))
+
+        net_state_dict = paddle.load(net_path)
+        opt_state_dict = paddle.load(opt_path)
+
+        self.layer.set_state_dict(net_state_dict)
+        self.optimizer.set_state_dict(opt_state_dict)
+
+        imgs = paddle.to_tensor(
+            np.random.randn(64, 3, 32, 32),
+            dtype='float32',
+            place=self.place,
+            stop_gradient=False)
+        labels = paddle.to_tensor(
+            np.random.randint(
+                10, size=(64, 1)),
+            dtype='float32',
+            place=self.place,
+            stop_gradient=False)
+
+        loss_fn = paddle.nn.MSELoss(reduction='mean')
+
+        output = self.layer(imgs)
+        loss = loss_fn(output, labels)
+        loss.backward()
+        self.optimizer.step()
+        self.optimizer.clear_grad()
+
+        for param in self.layer.parameters():
+            if ASPHelper._is_supported_layer(
+                    paddle.static.default_main_program(), param.name):
+                mat = param.numpy()
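+                # The restored masks must have kept the parameters 2:4
+                # sparse through the optimizer step above.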
+                self.assertTrue(
+                    paddle.fluid.contrib.sparsity.check_sparsity(
+                        mat.T, n=2, m=4))
+
+
+class TestASPStaticSaveLoad(unittest.TestCase):
+    def setUp(self):
+        paddle.enable_static()
+
+        self.main_program = fluid.Program()
+        self.startup_program = fluid.Program()
+
+        def build_model():
+            img = fluid.data(
+                name='img', shape=[None, 3, 32, 32], dtype='float32')
+            label = fluid.data(name='label', shape=[None, 1], dtype='int64')
+            hidden = fluid.layers.conv2d(
+                input=img, num_filters=4, filter_size=3, padding=2, act="relu")
+            hidden = fluid.layers.fc(input=hidden, size=32, act='relu')
+            prediction = fluid.layers.fc(input=hidden, size=10, act='softmax')
+            return img, label, prediction
+
+        with fluid.program_guard(self.main_program, self.startup_program):
+            self.img, self.label, predict = build_model()
+            self.loss = fluid.layers.mean(
+                fluid.layers.cross_entropy(
+                    input=predict, label=self.label))
+            self.optimizer = fluid.optimizer.SGD(learning_rate=0.01)
+            self.optimizer = sparsity.decorate(self.optimizer)
+            self.optimizer.minimize(self.loss, self.startup_program)
+
+        self.place = paddle.CPUPlace()
+        if core.is_compiled_with_cuda():
+            self.place = paddle.CUDAPlace(0)
+        self.exe = fluid.Executor(self.place)
+        self.exe.run(self.startup_program)
+
+        sparsity.prune_model(self.main_program)
+
+    def test_save_and_load(self):
+        path = "/tmp/paddle_asp_save_st/"
+        param_path = path + "asp.pdparams"
+        model_path = path + "asp.pdmodel"
+
+        paddle.save(self.main_program.state_dict(), param_path)
+        paddle.save(self.main_program, model_path)
+
+        prog = paddle.load(model_path)
+
+        state_dict = paddle.load(param_path)
+        prog.set_state_dict(state_dict)
+
+        feeder = fluid.DataFeeder(
+            feed_list=[self.img, self.label], place=self.place)
+
+        data = (np.random.randn(64, 3, 32, 32), np.random.randint(
+            10, size=(64, 1)))
+        self.exe.run(prog, feed=feeder.feed([data]))
+
+        for param in prog.global_block().all_parameters():
+            if ASPHelper._is_supported_layer(prog, param.name):
+                mat = np.array(fluid.global_scope().find_var(param.name)
+                               .get_tensor())
+                self.assertTrue(
+                    paddle.fluid.contrib.sparsity.check_sparsity(
+                        mat.T, n=2, m=4))
+
+
+if __name__ == '__main__':
+    unittest.main()
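The tests above assert that restoring the state dicts brings the ASP masks
back, so a plain checkpoint round-trip preserves the sparse pattern. A
minimal sketch of that round-trip (paths illustrative):

    paddle.save(layer.state_dict(), "/tmp/asp/net.pdparams")
    paddle.save(optimizer.state_dict(), "/tmp/asp/opt.pdopt")

    layer.set_state_dict(paddle.load("/tmp/asp/net.pdparams"))
    optimizer.set_state_dict(paddle.load("/tmp/asp/opt.pdopt"))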
diff --git a/python/paddle/fluid/tests/unittests/asp/test_asp_utils.py b/python/paddle/fluid/tests/unittests/asp/test_asp_utils.py
index 4aac878763b6f..db4d0aafc80bc 100644
--- a/python/paddle/fluid/tests/unittests/asp/test_asp_utils.py
+++ b/python/paddle/fluid/tests/unittests/asp/test_asp_utils.py
@@ -18,7 +18,7 @@
 import unittest
 import threading, time
 import paddle
-from paddle.static import sparsity
+from paddle import sparsity
 import numpy as np
 
 
@@ -219,3 +219,7 @@ def __test_1D_2D_sparse_mask_generation_methods(self, x):
             func_name=paddle.fluid.contrib.sparsity.CheckMethod.CHECK_2D,
             n=2,
             m=4))
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/asp/test_fleet_with_asp.py b/python/paddle/fluid/tests/unittests/asp/test_fleet_with_asp.py
deleted file mode 100644
index 074aedb947613..0000000000000
--- a/python/paddle/fluid/tests/unittests/asp/test_fleet_with_asp.py
+++ /dev/null
@@ -1,91 +0,0 @@
-# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
-# Copyright (c) 2021 NVIDIA Corporation. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import paddle.distributed.fleet as fleet
-import paddle.distributed.fleet.base.role_maker as role_maker
-import unittest
-import paddle
-import paddle.fluid as fluid
-import paddle.fluid.core as core
-import os
-from paddle.static import sparsity
-from paddle.fluid.contrib.sparsity.asp import ASPHelper
-import numpy as np
-cuda_visible_devices = os.getenv('CUDA_VISIBLE_DEVICES')
-if cuda_visible_devices is None or cuda_visible_devices == "":
-    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
-else:
-    os.environ['CUDA_VISIBLE_DEVICES'] = cuda_visible_devices.split(',')[0]
-
-paddle.enable_static()
-
-
-class TestFleetWithASP(unittest.TestCase):
-    def setUp(self):
-        os.environ["PADDLE_TRAINER_ENDPOINTS"] = "127.0.0.1:36213"
-        os.environ["PADDLE_CURRENT_ENDPOINTS"] = "127.0.0.1:36213"
-        os.environ["PADDLE_TRAINERS_NUM"] = "1"
-        os.environ["PADDLE_TRAINER_ID"] = "0"
-
-    def net(self, main_prog, startup_prog):
-        with fluid.program_guard(main_prog, startup_prog):
-            input_x = paddle.static.data(
-                name="x", shape=[-1, 32], dtype='float32')
-            input_y = paddle.static.data(name="y", shape=[-1, 1], dtype='int64')
-
-            fc_1 = fluid.layers.fc(input=input_x, size=64, act='tanh')
-            prediction = fluid.layers.fc(input=fc_1, size=2, act='softmax')
-            cost = fluid.layers.cross_entropy(input=prediction, label=input_y)
-            avg_cost = paddle.mean(x=cost)
-
-            strategy = paddle.distributed.fleet.DistributedStrategy()
-            strategy.asp = True
-        return avg_cost, strategy, input_x, input_y
-
-    def test_with_asp(self):
-        fleet.init(is_collective=True)
-        train_prog, startup_prog = fluid.Program(), fluid.Program()
-        avg_cost, strategy, input_x, input_y = self.net(train_prog,
-                                                        startup_prog)
-
-        with fluid.program_guard(train_prog, startup_prog):
-            optimizer = paddle.fluid.optimizer.SGD(learning_rate=0.01)
-            optimizer = fleet.distributed_optimizer(
-                optimizer, strategy=strategy)
-            optimizer.minimize(avg_cost)
-
-        place = fluid.CUDAPlace(0) if paddle.fluid.is_compiled_with_cuda(
-        ) else fluid.CPUPlace()
-
-        exe = fluid.Executor(place)
-        feeder = fluid.DataFeeder(feed_list=[input_x, input_y], place=place)
-        exe.run(startup_prog)
-
-        sparsity.prune_model(train_prog)
-
-        data = (np.random.randn(64, 32), np.random.randint(2, size=(64, 1)))
-        exe.run(train_prog, feed=feeder.feed([data]))
-
-        for param in train_prog.global_block().all_parameters():
-            if ASPHelper._is_supported_layer(train_prog, param.name):
-                mat = np.array(fluid.global_scope().find_var(param.name)
-                               .get_tensor())
-                self.assertTrue(
-                    paddle.fluid.contrib.sparsity.check_sparsity(
-                        mat.T, n=2, m=4))
-
-
-if __name__ == "__main__":
-    unittest.main()
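The deleted static-only fleet test is superseded by the dynamic variant
below and the static variant further down. The call order the new dynamic
tests use is: ASP decorates and prunes first, then fleet wraps both objects
(sketch; assumes a collective fleet environment is already configured, and
MyLayer stands for any layer with ASP-supported weights):

    import paddle
    import paddle.distributed.fleet as fleet
    from paddle import sparsity

    fleet.init(is_collective=True)

    layer = MyLayer()
    optimizer = paddle.optimizer.SGD(
        learning_rate=0.01, parameters=layer.parameters())

    optimizer = sparsity.decorate(optimizer)
    sparsity.prune_model(layer)

    optimizer = fleet.distributed_optimizer(optimizer)
    layer = fleet.distributed_model(layer)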
diff --git a/python/paddle/fluid/tests/unittests/asp/test_fleet_with_asp_dynamic.py b/python/paddle/fluid/tests/unittests/asp/test_fleet_with_asp_dynamic.py
new file mode 100644
index 0000000000000..d19d75c8efd07
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/asp/test_fleet_with_asp_dynamic.py
@@ -0,0 +1,163 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2021 NVIDIA Corporation. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle.distributed.fleet as fleet
+import paddle.distributed.fleet.base.role_maker as role_maker
+import unittest
+import paddle
+import paddle.fluid as fluid
+import paddle.fluid.core as core
+import os
+from paddle import sparsity
+from paddle.fluid.contrib.sparsity.asp import ASPHelper
+import numpy as np
+cuda_visible_devices = os.getenv('CUDA_VISIBLE_DEVICES')
+if cuda_visible_devices is None or cuda_visible_devices == "":
+    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
+else:
+    os.environ['CUDA_VISIBLE_DEVICES'] = cuda_visible_devices.split(',')[0]
+
+
+class MyLayer(paddle.nn.Layer):
+    def __init__(self):
+        super(MyLayer, self).__init__()
+        self.conv1 = paddle.nn.Conv2D(
+            in_channels=3, out_channels=4, kernel_size=3, padding=2)
+        self.linear1 = paddle.nn.Linear(4624, 32)
+        self.linear2 = paddle.nn.Linear(32, 32)
+        self.linear3 = paddle.nn.Linear(32, 10)
+
+    def forward(self, img):
+        hidden = self.conv1(img)
+        hidden = paddle.flatten(hidden, start_axis=1)
+        hidden = self.linear1(hidden)
+        hidden = self.linear2(hidden)
+        prediction = self.linear3(hidden)
+        return prediction
+
+
+class TestFleetWithASPDynamic(unittest.TestCase):
+    def setUp(self):
+        os.environ["PADDLE_TRAINER_ENDPOINTS"] = "127.0.0.1:36213"
+        os.environ["PADDLE_CURRENT_ENDPOINTS"] = "127.0.0.1:36213"
+        os.environ["PADDLE_TRAINERS_NUM"] = "1"
+        os.environ["PADDLE_TRAINER_ID"] = "0"
+
+        self.layer = MyLayer()
+
+        self.place = paddle.CPUPlace()
+        if core.is_compiled_with_cuda():
+            self.place = paddle.CUDAPlace(0)
+
+        self.optimizer = paddle.optimizer.SGD(
+            learning_rate=0.01, parameters=self.layer.parameters())
+
+    def test_with_asp(self):
+        fleet.init(is_collective=True)
+
+        self.optimizer = sparsity.decorate(self.optimizer)
+        sparsity.prune_model(self.layer)
+
+        self.optimizer = fleet.distributed_optimizer(self.optimizer)
+        self.layer = fleet.distributed_model(self.layer)
+
+        imgs = paddle.to_tensor(
+            np.random.randn(64, 3, 32, 32),
+            dtype='float32',
+            place=self.place,
+            stop_gradient=False)
+        labels = paddle.to_tensor(
+            np.random.randint(
+                10, size=(64, 1)),
+            dtype='float32',
+            place=self.place,
+            stop_gradient=False)
+
+        loss_fn = paddle.nn.MSELoss(reduction='mean')
+
+        output = self.layer(imgs)
+        loss = loss_fn(output, labels)
+        loss.backward()
+        self.optimizer.step()
+        self.optimizer.clear_grad()
+
+        for param in self.layer.parameters():
+            if ASPHelper._is_supported_layer(
+                    paddle.static.default_main_program(), param.name):
+                mat = param.numpy()
+                self.assertTrue(
+                    paddle.fluid.contrib.sparsity.check_sparsity(
+                        mat.T, n=2, m=4))
+
+
+class TestFleetWithASPAMPDynamic(unittest.TestCase):
+    def setUp(self):
+        os.environ["PADDLE_TRAINER_ENDPOINTS"] = "127.0.0.1:36213"
+        os.environ["PADDLE_CURRENT_ENDPOINTS"] = "127.0.0.1:36213"
+        os.environ["PADDLE_TRAINERS_NUM"] = "1"
+        os.environ["PADDLE_TRAINER_ID"] = "0"
+
+        self.layer = MyLayer()
+
+        self.place = paddle.CPUPlace()
+        if core.is_compiled_with_cuda():
+            self.place = paddle.CUDAPlace(0)
+
+        self.optimizer = paddle.optimizer.SGD(
+            learning_rate=0.01, parameters=self.layer.parameters())
+
+    def test_with_asp(self):
+        fleet.init(is_collective=True)
+
+        self.optimizer = sparsity.decorate(self.optimizer)
+        sparsity.prune_model(self.layer)
+
+        self.optimizer = fleet.distributed_optimizer(self.optimizer)
+        self.layer = fleet.distributed_model(self.layer)
+
+        imgs = paddle.to_tensor(
+            np.random.randn(64, 3, 32, 32),
+            dtype='float32',
+            place=self.place,
+            stop_gradient=False)
+        labels = paddle.to_tensor(
+            np.random.randint(
+                10, size=(64, 1)),
+            dtype='float32',
+            place=self.place,
+            stop_gradient=False)
+
+        loss_fn = paddle.nn.MSELoss(reduction='mean')
+        scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
+
+        with paddle.amp.auto_cast(enable=True):
+            output = self.layer(imgs)
+            loss = loss_fn(output, labels)
+        scaled = scaler.scale(loss)
+        scaled.backward()
+        scaler.minimize(self.optimizer, scaled)
+        self.optimizer.clear_grad()
+
+        for param in self.layer.parameters():
+            if ASPHelper._is_supported_layer(
+                    paddle.static.default_main_program(), param.name):
+                mat = param.numpy()
+                self.assertTrue(
+                    paddle.fluid.contrib.sparsity.check_sparsity(
+                        mat.T, n=2, m=4))
+
+
+if __name__ == "__main__":
+    unittest.main()
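ASP composes with AMP in dynamic mode exactly as the test above shows: the
decorated optimizer is driven through GradScaler, and the unscaled update is
still masked (sketch; layer, imgs, labels and loss_fn as in the test):

    scaler = paddle.amp.GradScaler(init_loss_scaling=1024)

    with paddle.amp.auto_cast(enable=True):
        output = layer(imgs)
        loss = loss_fn(output, labels)
    scaled = scaler.scale(loss)
    scaled.backward()
    scaler.minimize(optimizer, scaled)  # unscales, then masked step
    optimizer.clear_grad()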
diff --git a/python/paddle/fluid/tests/unittests/asp/test_fleet_with_asp_amp.py b/python/paddle/fluid/tests/unittests/asp/test_fleet_with_asp_static.py
similarity index 68%
rename from python/paddle/fluid/tests/unittests/asp/test_fleet_with_asp_amp.py
rename to python/paddle/fluid/tests/unittests/asp/test_fleet_with_asp_static.py
index a34d7e69872e2..3e735f452c094 100644
--- a/python/paddle/fluid/tests/unittests/asp/test_fleet_with_asp_amp.py
+++ b/python/paddle/fluid/tests/unittests/asp/test_fleet_with_asp_static.py
@@ -20,7 +20,7 @@
 import paddle.fluid as fluid
 import paddle.fluid.core as core
 import os
-from paddle.static import sparsity
+from paddle import sparsity
 from paddle.fluid.contrib.sparsity.asp import ASPHelper
 import numpy as np
 cuda_visible_devices = os.getenv('CUDA_VISIBLE_DEVICES')
@@ -32,7 +32,62 @@ paddle.enable_static()
 
-class TestFleetWithASP(unittest.TestCase):
+class TestFleetWithASPStatic(unittest.TestCase):
+    def setUp(self):
+        os.environ["PADDLE_TRAINER_ENDPOINTS"] = "127.0.0.1:36213"
+        os.environ["PADDLE_CURRENT_ENDPOINTS"] = "127.0.0.1:36213"
+        os.environ["PADDLE_TRAINERS_NUM"] = "1"
+        os.environ["PADDLE_TRAINER_ID"] = "0"
+
+    def net(self, main_prog, startup_prog):
+        with fluid.program_guard(main_prog, startup_prog):
+            input_x = paddle.static.data(
+                name="x", shape=[-1, 32], dtype='float32')
+            input_y = paddle.static.data(name="y", shape=[-1, 1], dtype='int64')
+
+            fc_1 = fluid.layers.fc(input=input_x, size=64, act='tanh')
+            prediction = fluid.layers.fc(input=fc_1, size=2, act='softmax')
+            cost = fluid.layers.cross_entropy(input=prediction, label=input_y)
+            avg_cost = paddle.mean(x=cost)
+
+            strategy = paddle.distributed.fleet.DistributedStrategy()
+            strategy.asp = True
+        return avg_cost, strategy, input_x, input_y
+
+    def test_with_asp(self):
+        fleet.init(is_collective=True)
+        train_prog, startup_prog = fluid.Program(), fluid.Program()
+        avg_cost, strategy, input_x, input_y = self.net(train_prog,
+                                                        startup_prog)
+
+        with fluid.program_guard(train_prog, startup_prog):
+            optimizer = paddle.fluid.optimizer.SGD(learning_rate=0.01)
+            optimizer = fleet.distributed_optimizer(
+                optimizer, strategy=strategy)
+            optimizer.minimize(avg_cost)
+
+        place = fluid.CUDAPlace(0) if paddle.fluid.is_compiled_with_cuda(
+        ) else fluid.CPUPlace()
+
+        exe = fluid.Executor(place)
+        feeder = fluid.DataFeeder(feed_list=[input_x, input_y], place=place)
+        exe.run(startup_prog)
+
+        sparsity.prune_model(train_prog)
+
+        data = (np.random.randn(64, 32), np.random.randint(2, size=(64, 1)))
+        exe.run(train_prog, feed=feeder.feed([data]))
+
+        for param in train_prog.global_block().all_parameters():
+            if ASPHelper._is_supported_layer(train_prog, param.name):
+                mat = np.array(fluid.global_scope().find_var(param.name)
+                               .get_tensor())
+                self.assertTrue(
+                    paddle.fluid.contrib.sparsity.check_sparsity(
+                        mat.T, n=2, m=4))
+
+
+class TestFleetWithASPAMPStatic(unittest.TestCase):
     def setUp(self):
         os.environ["PADDLE_TRAINER_ENDPOINTS"] = "127.0.0.1:36213"
         os.environ["PADDLE_CURRENT_ENDPOINTS"] = "127.0.0.1:36213"
         os.environ["PADDLE_TRAINERS_NUM"] = "1"
         os.environ["PADDLE_TRAINER_ID"] = "0"