Skip to content

Commit

Permalink
Add Support Layer List to ASP (PaddlePaddle#40253)
Browse files Browse the repository at this point in the history
  • Loading branch information
mingxu1067 authored and liqitong-a committed Mar 17, 2022
1 parent 6794791 commit e0d10d8
Show file tree
Hide file tree
Showing 5 changed files with 320 additions and 33 deletions.
3 changes: 2 additions & 1 deletion python/paddle/fluid/contrib/sparsity/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,11 @@
from .asp import prune_model
from .asp import set_excluded_layers
from .asp import reset_excluded_layers
from .supported_layer_list import add_supported_layer

__all__ = [
'calculate_density', 'check_mask_1d', 'get_mask_1d', 'check_mask_2d',
'get_mask_2d_greedy', 'get_mask_2d_best', 'create_mask', 'check_sparsity',
'MaskAlgo', 'CheckMethod', 'decorate', 'prune_model', 'set_excluded_layers',
'reset_excluded_layers'
'reset_excluded_layers', 'add_supported_layer'
]
79 changes: 50 additions & 29 deletions python/paddle/fluid/contrib/sparsity/asp.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
from paddle.fluid import global_scope, program_guard, layers
from paddle.fluid.initializer import ConstantInitializer
from paddle.fluid.contrib import sparsity
from paddle.fluid.contrib.sparsity.supported_layer_list import supported_layers_and_prune_func_map
from paddle.fluid.contrib.sparsity.supported_layer_list import _default_pruning
from paddle.fluid import core

OpRole = core.op_proto_and_checker_maker.OpRole
Expand Down Expand Up @@ -292,8 +294,8 @@ class ASPHelper(object):
2. pruning well-trained models into 2:4 sparse pattern on FP16 or 1:2 sparse pattern on FP32 for fine-tuning.
"""

MASK_APPENDDED_NAME = '_asp_mask'
SUPPORTED_LAYERS = {'fc': 'w_0', 'linear': 'w_0', 'conv2d': 'w_0'}
MASK_APPENDDED_NAME = 'asp_mask'
PADDLE_WEIGHT_SUFFIX = "w_"

__asp_info = {}

Expand Down Expand Up @@ -334,7 +336,6 @@ def prune_model(cls,
r"""
This is the implementation of `sparsity.prune_model`, for details please see explanation in `sparsity.prune_model`.
"""
checked_func_name = sparsity.CheckMethod.get_checking_method(mask_algo)

if main_program is None:
main_program = paddle.static.default_main_program()
Expand All @@ -345,33 +346,27 @@ def prune_model(cls,
weight_tensor = global_scope().find_var(param.name).get_tensor()
weight_nparray = np.array(weight_tensor)

# The double transpose ops here make sure pruning direction consistent with cuSparseLt.
# SPMMA in cuSparseLt: D = (AxB) + C, where matrix A (mxk) is sparse matrix.
# cuSparseLt would prune matrix A along k dimension.
# In sparse training, layer weight matriices is viewed sparse matrix A, so
# the math fomula should be 'Act(WX + b)'. However, default fomula in PaddlePaddle
# is 'Act(XW + b)'. For enabling SPMMA, weights and inputs should be transposed
# for computing, Act( (W^T X^T)^T + b). Therefore, we have to prune alog k dimension
# of W^T, which is m dimension of W. Moreove, all mask generating functions in
# sparsity/utils is row-major pruning. That is the reason we have to transpose weight
# matrices beforce invoking create_mask. Then we transpose the result maks to make
# sure its shape to be the same as the input weight.
weight_sparse_mask = sparsity.create_mask(
weight_nparray.T, func_name=mask_algo, n=n, m=m).T
weight_pruned_nparray = np.multiply(weight_nparray,
weight_sparse_mask)
prune_func = ASPHelper._get_prune_func_by_name(param.name)

weight_pruned_nparray, weight_sparse_mask = \
prune_func(weight_nparray, m, n, mask_algo, param.name)
weight_pruned_nparray = weight_pruned_nparray.astype(
weight_nparray.dtype)
weight_tensor.set(weight_pruned_nparray, place)
assert sparsity.check_sparsity(weight_pruned_nparray.T, n=n, m=m, func_name=checked_func_name), \
'Pruning {} weight matrix failure!!!'.format(param.name)

if with_mask:
weight_mask_param = global_scope().find_var(
ASPHelper._get_mask_name(param.name))
assert weight_mask_param is not None, \
'Cannot find {} variable, please call ASPHelper.minimize' \
'Cannot find {} variable, please call optimizer.minimize (' \
'paddle.sparsity.decorate(optimizer).minimize(loss)' \
' and initialization (exe.run(startup_program)) first!'.format(ASPHelper._get_mask_name(param.name))
weight_mask_tensor = weight_mask_param.get_tensor()
weight_sparse_mask = weight_sparse_mask.astype(
np.array(weight_mask_tensor).dtype)
weight_mask_tensor.set(weight_sparse_mask, place)
asp_info.update_masks(param.name, weight_sparse_mask)

return asp_info.masks.copy()

@staticmethod
Expand All @@ -384,7 +379,7 @@ def _get_mask_name(param_name):
Returns:
string: The mask name of :attr:`param_name`.
"""
return param_name + ASPHelper.MASK_APPENDDED_NAME
return param_name + "." + ASPHelper.MASK_APPENDDED_NAME

@staticmethod
def _get_not_ASP_relevant_vars(main_program):
Expand Down Expand Up @@ -434,19 +429,46 @@ def _is_supported_layer(cls, main_program, param_name):
# fc_0.w_0 -> True
# fc_0.b_0 -> False
"""
if ASPHelper.MASK_APPENDDED_NAME in param_name:
param_name_list = param_name.split('.')

if ASPHelper.MASK_APPENDDED_NAME in param_name_list:
return False

for layer in cls._get_program_asp_info(main_program).excluded_layers:
if layer in param_name:
return False

for name in ASPHelper.SUPPORTED_LAYERS:
if name in param_name and \
ASPHelper.SUPPORTED_LAYERS[name] in param_name:
return True
if param_name in supported_layers_and_prune_func_map:
return True

param_name_no_weight_suffix = param_name_list[0]
param_type_suffix = param_name_list[1]
layer_name = param_name_no_weight_suffix[:param_name_no_weight_suffix.
rfind('_')]
if ASPHelper.PADDLE_WEIGHT_SUFFIX not in param_type_suffix:
return False

if param_name_no_weight_suffix in supported_layers_and_prune_func_map or \
layer_name in supported_layers_and_prune_func_map:
return True

return False

@classmethod
def _get_prune_func_by_name(cls, param_name):
func = supported_layers_and_prune_func_map.get(param_name, None)
param_name_no_weight_suffix = param_name.split('.')[0]
if func is None:
func = supported_layers_and_prune_func_map.get(
param_name_no_weight_suffix, None)
if func is None:
layer_name = param_name_no_weight_suffix[:
param_name_no_weight_suffix.
rfind('_')]
func = supported_layers_and_prune_func_map.get(layer_name,
_default_pruning)
return func

@classmethod
def _minimize(cls,
optimizer,
Expand Down Expand Up @@ -509,8 +531,7 @@ def _create_mask_variables(cls, main_program, startup_program,
if ASPHelper._is_supported_layer(main_program,
param_and_grad[0].name):
mask_param = layers.create_parameter(
name=param_and_grad[0].name +
ASPHelper.MASK_APPENDDED_NAME,
name=ASPHelper._get_mask_name(param_and_grad[0].name),
shape=param_and_grad[0].shape,
dtype=param_and_grad[0].dtype,
default_initializer=ConstantInitializer(value=1.0))
Expand Down
86 changes: 86 additions & 0 deletions python/paddle/fluid/contrib/sparsity/supported_layer_list.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2022 NVIDIA Corporation. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import paddle
from paddle.fluid.contrib import sparsity
import threading

__all__ = ['add_supported_layer']


def _default_pruning(weight_nparray, m, n, func_name, param_name):

checked_func_name = sparsity.CheckMethod.get_checking_method(func_name)

# The double transpose ops here make sure pruning direction consistent with cuSparseLt.
# SPMMA in cuSparseLt: D = (AxB) + C, where matrix A (mxk) is sparse matrix.
# cuSparseLt would prune matrix A along k dimension.
# In sparse training, layer weight matrices is viewed sparse matrix A, so
# the math fomula should be 'Act(WX + b)'. However, default fomula in PaddlePaddle
# is 'Act(XW + b)'. For enabling SPMMA, weights and inputs should be transposed
# for computing, Act( (W^T X^T)^T + b). Therefore, we have to prune alog k dimension
# of W^T, which is m dimension of W. Moreove, all mask generating functions in
# sparsity/utils is row-major pruning. That is the reason we have to transpose weight
# matrices beforce invoking create_mask. Then we transpose the result mask to make
# sure its shape to be the same as the input weight.
weight_sparse_mask = sparsity.create_mask(
weight_nparray.T, func_name=func_name, n=n, m=m).T
weight_pruned_nparray = np.multiply(weight_nparray, weight_sparse_mask)
assert sparsity.check_sparsity(weight_pruned_nparray.T, n=n, m=m, func_name=checked_func_name), \
'Pruning {} weight matrix failure!!!'.format(param_name)
return weight_pruned_nparray, weight_sparse_mask


# When value of given key in this DICT is None,
# ASP will call default pruning function in pruning stage.
_supported_layers_and_prune_func_map_lock = threading.Lock()
supported_layers_and_prune_func_map = {}


def add_supported_layer(layer, pruning_func=None):
r"""
Add supported layers and its corresponding pruning function.
Args:
name (string|Layer): The name or type of layer, needed to support. If layer is `Layer` then
it would be turn to string internally. ASP would use this name to match parameter's name and call
its the corresponding pruning function.
pruning_func (function, optional): a function type which receives five argument (weight_nparray,
m, n, func_name, param_name), weight_nparray is a nparray of weight, param_name is the name of weight,
m, n, and func_name, please see `prune_model` for details.
"""
name = None
if isinstance(layer, str):
name = layer
elif isinstance(layer, paddle.fluid.dygraph.layers.Layer):
name = paddle.fluid.dygraph.layers._convert_camel_to_snake(
type(layer).__name__)
elif issubclass(layer, paddle.fluid.dygraph.layers.Layer):
name = paddle.fluid.dygraph.layers._convert_camel_to_snake(
layer.__name__)
else:
assert "The type of layer should be string of Layer, but got {}!".format(
type(layer))
if pruning_func is None:
pruning_func = _default_pruning
_supported_layers_and_prune_func_map_lock.acquire()
supported_layers_and_prune_func_map.update({name: pruning_func})
_supported_layers_and_prune_func_map_lock.release()


add_supported_layer('fc')
add_supported_layer('linear')
add_supported_layer('conv2d')
Loading

0 comments on commit e0d10d8

Please sign in to comment.