diff --git a/deepspeed/compression/compress.py b/deepspeed/compression/compress.py index bf3b6c2760fa..e12e88438560 100644 --- a/deepspeed/compression/compress.py +++ b/deepspeed/compression/compress.py @@ -4,7 +4,6 @@ from .helper import compression_preparation, fix_compression, recursive_getattr, is_module_compressible from .config import get_compression_config from ..runtime.config_utils import dict_raise_error_on_duplicate_keys -from .constants import * import os import json @@ -51,21 +50,20 @@ def get_module_name(group_name, def get_compress_methods(model, compress_methods, mpu=None): # extract the compression module for each method in compress_methods layer_added_compress_methods = [] - for method, method_content in compress_methods.items(): - if LAYER_REDUCTION in method: + for method, method_content in compress_methods: + if "layer_reduction" in method: continue # for loop different methods, i.e., weight quantization, activation quantization etc exist_module_name = set() - shared_parameters = method_content[ - SHARED_PARAMETERS] # get all the shared parameters - for group_name, method_parameters in method_content[DIFFERENT_GROUPS].items(): + shared_parameters = method_content.shared_parameters # get all the shared parameters + for group_name, method_parameters in method_content.different_groups.items(): # for loop different groups, i.e., weight quantization group 1, weight quantization group 2 etc module_name_list = [] related_module_name_list = [] - if method_parameters[DIFFERENT_GROUPS_RELATED_MODULE_SCOPE]: + if method_parameters.related_modules: # this is used for head/row/channel pruning, if users provide the related module scope, we can shrink the layer dim for them # otherwise we just mask those as zeros - for key_word, related_key_words in zip(method_parameters[DIFFERENT_GROUPS_MODULE_SCOPE], method_parameters[DIFFERENT_GROUPS_RELATED_MODULE_SCOPE]): + for key_word, related_key_words in zip(method_parameters.modules, method_parameters.related_modules): module_name, exist_module_name = get_module_name(group_name, model, key_word, exist_module_name, mpu=mpu) module_name_list.append(module_name) tmp_related_module_name_list = [] @@ -75,15 +73,15 @@ def get_compress_methods(model, compress_methods, mpu=None): tmp_related_module_name_list.append(module_name) related_module_name_list.append(tmp_related_module_name_list) else: - for key_word in method_parameters[DIFFERENT_GROUPS_MODULE_SCOPE]: + for key_word in method_parameters.modules: module_name, exist_module_name = get_module_name(group_name, model, key_word, exist_module_name, mpu=mpu) module_name_list.append(module_name) if module_name_list: # combine shared parameters with each group combined_method_parameters = { - **(method_parameters.copy().pop(DIFFERENT_GROUPS_PARAMETERS)), - **shared_parameters + **(method_parameters.params.dict()), + **shared_parameters.dict() } compression_item = [ module_name_list, @@ -114,7 +112,7 @@ def init_compression(model, deepspeed_config, teacher_model=None, mpu=None): c_model = model # For layer reduction - if compress_methods[LAYER_REDUCTION][LAYER_REDUCTION_ENABLED]: + if compress_methods.layer_reduction.enabled: assert teacher_model is not None, "Teacher model is required for layer reduction" student_initialization(c_model, teacher_model, deepspeed_config) @@ -148,12 +146,12 @@ def redundancy_clean(model, deepspeed_config, mpu=None): mpu=mpu) # sort methods order_list = [ - WEIGHT_QUANTIZATION, - SPARSE_PRUNING, - ROW_PRUNING, - HEAD_PRUNING, - CHANNEL_PRUNING, - ACTIVATION_QUANTIZATION 
+ "weight_quantization", + "sparse_pruning", + "row_pruning", + "head_pruning", + "channel_pruning", + "activation_quantization" ] layer_added_compress_methods = sorted( layer_added_compress_methods_tmp, @@ -193,12 +191,12 @@ def student_initialization(student_model, teacher_model, deepspeed_config): The path of ds_config ''' config = get_compression_config(check_deepspeed_config(deepspeed_config)) - compress_methods = config[LAYER_REDUCTION] + compress_methods = config.layer_reduction - module_name_prefix = compress_methods[MODULE_NAME_PREFIX] - teacher_layer = compress_methods[TEACHER_LAYER] + module_name_prefix = compress_methods.module_name_prefix + teacher_layer = compress_methods.teacher_layer student_layer = [i for i in range(len(teacher_layer))] - other_module_name = compress_methods[OTHER_MODULE_NAME] + other_module_name = compress_methods.other_module_name ''' name_prefix (`str`) The prefix name before the layer #. diff --git a/deepspeed/compression/config.py b/deepspeed/compression/config.py index e6a710dfa3ea..4708da5e1a36 100644 --- a/deepspeed/compression/config.py +++ b/deepspeed/compression/config.py @@ -1,492 +1,212 @@ '''Copyright The Microsoft DeepSpeed Team''' +from deepspeed.runtime.config_utils import DeepSpeedConfigModel +from enum import Enum +from pydantic import root_validator, validator, Field +from typing import Dict, List -from .constants import * -import copy -from ..runtime.config_utils import get_scalar_param +COMPRESSION_TRAINING = "compression_training" def get_compression_config(param_dict): - # - output = {} - - if COMPRESSION_TRAINING not in param_dict.keys(): + if COMPRESSION_TRAINING not in param_dict: param_dict[COMPRESSION_TRAINING] = {} - sub_param_dict = param_dict[COMPRESSION_TRAINING] - output[WEIGHT_QUANTIZATION] = get_weight_quantization(sub_param_dict) - output[ACTIVATION_QUANTIZATION] = get_activation_quantization(sub_param_dict) - output[SPARSE_PRUNING] = get_sparse_pruning(sub_param_dict) - output[ROW_PRUNING] = get_row_pruning(sub_param_dict) - output[HEAD_PRUNING] = get_head_pruning(sub_param_dict) - output[CHANNEL_PRUNING] = get_channel_pruning(sub_param_dict) - - output[LAYER_REDUCTION] = get_layer_reduction(sub_param_dict) - - return output - - -def get_layer_reduction(param_dict): - output = {} - output[LAYER_REDUCTION_ENABLED] = LAYER_REDUCTION_ENABLED_DEFAULT - if get_layer_reduction_enabled(param_dict): - output[LAYER_REDUCTION_ENABLED] = get_layer_reduction_enabled(param_dict) - for key, val in get_layer_reduction_params(param_dict).items(): - output[key] = val - return output - - -def get_layer_reduction_enabled(param_dict): - if LAYER_REDUCTION in param_dict.keys(): - return get_scalar_param(param_dict[LAYER_REDUCTION], - LAYER_REDUCTION_ENABLED, - LAYER_REDUCTION_ENABLED_DEFAULT) - else: - return False - - -def get_layer_reduction_params(param_dict): - if LAYER_REDUCTION in param_dict.keys(): - layer_reduction_params = copy.copy(param_dict[LAYER_REDUCTION]) - layer_reduction_params.pop(LAYER_REDUCTION_ENABLED) - return layer_reduction_params - else: - return False - - -def get_quantize_enabled(param_dict): - if COMPRESSION_TRAINING not in param_dict.keys(): - return False - - sub_param_dict = param_dict[COMPRESSION_TRAINING] - output = get_weight_quantization_shared_parameters(sub_param_dict) - return output[WEIGHT_QUANTIZE_ENABLED] - - -def get_weight_quantization(param_dict): - output = {} - if WEIGHT_QUANTIZATION not in param_dict.keys(): - param_dict[WEIGHT_QUANTIZATION] = {SHARED_PARAMETERS: {}, DIFFERENT_GROUPS: {}} - 
sub_param_dict = param_dict[WEIGHT_QUANTIZATION] - # shared parameters - output[SHARED_PARAMETERS] = get_weight_quantization_shared_parameters(sub_param_dict) - # each sub-groups - if output[SHARED_PARAMETERS][WEIGHT_QUANTIZE_ENABLED]: - assert DIFFERENT_GROUPS in sub_param_dict.keys(), f"Weigh Quantization is enabled, {DIFFERENT_GROUPS} must be specified" - output[DIFFERENT_GROUPS] = get_weight_quantization_different_groups(sub_param_dict) - return output - - -def get_weight_quantization_shared_parameters(param_dict): - output = {} - if SHARED_PARAMETERS in param_dict.keys(): - sub_param_dict = param_dict[SHARED_PARAMETERS] - output[WEIGHT_QUANTIZE_ENABLED] = get_scalar_param( - sub_param_dict, - WEIGHT_QUANTIZE_ENABLED, - WEIGHT_QUANTIZE_ENABLED_DEFAULT) - output[WEIGHT_QUANTIZE_KERNEL] = get_scalar_param( - sub_param_dict, - WEIGHT_QUANTIZE_KERNEL, - WEIGHT_QUANTIZE_KERNEL_DEFAULT) - output[WEIGHT_QUANTIZE_SCHEDULE_OFFSET] = get_scalar_param( - sub_param_dict, - WEIGHT_QUANTIZE_SCHEDULE_OFFSET, - WEIGHT_QUANTIZE_SCHEDULE_OFFSET_DEFAULT) - output[WEIGHT_QUANTIZE_GROUPS] = get_scalar_param( - sub_param_dict, - WEIGHT_QUANTIZE_GROUPS, - WEIGHT_QUANTIZE_GROUPS_DEFAULT) - output[WEIGHT_QUANTIZE_VERBOSE] = get_scalar_param( - sub_param_dict, - WEIGHT_QUANTIZE_VERBOSE, - WEIGHT_QUANTIZE_VERBOSE_DEFAULT) - output[WEIGHT_QUANTIZE_TYPE] = get_scalar_param(sub_param_dict, - WEIGHT_QUANTIZE_TYPE, - WEIGHT_QUANTIZE_TYPE_DEFAULT) - output[WEIGHT_QUANTIZE_IN_FORWARD_ENABLED] = get_scalar_param( - sub_param_dict, - WEIGHT_QUANTIZE_IN_FORWARD_ENABLED, - WEIGHT_QUANTIZE_IN_FORWARD_ENABLED_DEFAULT) - assert output[WEIGHT_QUANTIZE_TYPE] in [WEIGHT_QUANTIZE_SYMMETRIC, WEIGHT_QUANTIZE_ASYMMETRIC], f"Invalid weight quantize type. Supported types: [{WEIGHT_QUANTIZE_SYMMETRIC}, {WEIGHT_QUANTIZE_ASYMMETRIC}]" - output[WEIGHT_QUANTIZE_ROUNDING] = get_scalar_param( - sub_param_dict, - WEIGHT_QUANTIZE_ROUNDING, - WEIGHT_QUANTIZE_ROUNDING_DEFAULT) - assert output[WEIGHT_QUANTIZE_ROUNDING] in [WEIGHT_QUANTIZE_NEAREST_ROUNDING, WEIGHT_QUANTIZE_STOCHASTIC_ROUNDING], f"Invalid weight quantize rounding. 
Supported types: [{WEIGHT_QUANTIZE_NEAREST_ROUNDING}, {WEIGHT_QUANTIZE_STOCHASTIC_ROUNDING}]" - if WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE in sub_param_dict.keys(): - output[WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE] = get_scalar_param( - sub_param_dict[WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE], - WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE_ENABLED, - WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE_ENABLED_DEFAULT) - output[WEIGHT_QUANTIZE_CHANGE_RATIO] = get_scalar_param( - sub_param_dict[WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE], - WEIGHT_QUANTIZE_CHANGE_RATIO, - WEIGHT_QUANTIZE_CHANGE_RATIO_DEFAULT) - else: - output[ - WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE] = WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE_ENABLED_DEFAULT - output[WEIGHT_QUANTIZE_CHANGE_RATIO] = WEIGHT_QUANTIZE_CHANGE_RATIO_DEFAULT - else: - output[WEIGHT_QUANTIZE_ENABLED] = WEIGHT_QUANTIZE_ENABLED_DEFAULT - output[WEIGHT_QUANTIZE_KERNEL] = WEIGHT_QUANTIZE_KERNEL_DEFAULT - output[WEIGHT_QUANTIZE_SCHEDULE_OFFSET] = WEIGHT_QUANTIZE_SCHEDULE_OFFSET_DEFAULT - output[WEIGHT_QUANTIZE_GROUPS] = WEIGHT_QUANTIZE_GROUPS_DEFAULT - output[WEIGHT_QUANTIZE_VERBOSE] = WEIGHT_QUANTIZE_VERBOSE_DEFAULT - output[WEIGHT_QUANTIZE_TYPE] = WEIGHT_QUANTIZE_TYPE_DEFAULT - output[WEIGHT_QUANTIZE_ROUNDING] = WEIGHT_QUANTIZE_ROUNDING_DEFAULT - output[ - WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE] = WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE_ENABLED_DEFAULT - output[WEIGHT_QUANTIZE_CHANGE_RATIO] = WEIGHT_QUANTIZE_CHANGE_RATIO_DEFAULT - return output - - -def get_weight_quantization_different_groups(param_dict): - output = {} - sub_param_dict = param_dict[DIFFERENT_GROUPS] - - def get_params(name, group_dict): - assert WEIGHT_QUANTIZE_START_BITS in group_dict.keys(), f"{WEIGHT_QUANTIZE_START_BITS} must be specified for weight quantization group {name}" - assert WEIGHT_QUANTIZE_TARGET_BITS in group_dict.keys(), f"{WEIGHT_QUANTIZE_TARGET_BITS} must be specified for weight quantization group {name}" - group_dict[WEIGHT_QUANTIZATION_PERIOD] = get_scalar_param( - group_dict, - WEIGHT_QUANTIZATION_PERIOD, - WEIGHT_QUANTIZATION_PERIOD_DEFAULT) - return group_dict - - for k, v in sub_param_dict.items(): - output[k] = {} - output[k][DIFFERENT_GROUPS_PARAMETERS] = get_params( - k, - sub_param_dict[k][DIFFERENT_GROUPS_PARAMETERS]) - output[k][DIFFERENT_GROUPS_MODULE_SCOPE] = get_scalar_param( - sub_param_dict[k], - DIFFERENT_GROUPS_MODULE_SCOPE, - DIFFERENT_GROUPS_MODULE_SCOPE_DEFAULT) - output[k][DIFFERENT_GROUPS_RELATED_MODULE_SCOPE] = get_scalar_param( - sub_param_dict[k], - DIFFERENT_GROUPS_RELATED_MODULE_SCOPE, - DIFFERENT_GROUPS_RELATED_MODULE_SCOPE_DEFAULT) - - return output - - -def get_activation_quantization(param_dict): - output = {} - if ACTIVATION_QUANTIZATION not in param_dict.keys(): - param_dict[ACTIVATION_QUANTIZATION] = { - SHARED_PARAMETERS: {}, - DIFFERENT_GROUPS: {} - } - sub_param_dict = param_dict[ACTIVATION_QUANTIZATION] - # shared parameters - output[SHARED_PARAMETERS] = get_activation_quantization_shared_parameters( - sub_param_dict) - # each sub-groups - if output[SHARED_PARAMETERS][ACTIVATION_QUANTIZATION_ENABLED]: - assert DIFFERENT_GROUPS in sub_param_dict.keys(), f"Activation Quantization is enabled, {DIFFERENT_GROUPS} must be specified" - output[DIFFERENT_GROUPS] = get_activation_quantization_different_groups( - sub_param_dict) - return output - - -def get_activation_quantization_shared_parameters(param_dict): - output = {} - if SHARED_PARAMETERS in param_dict.keys(): - sub_param_dict = param_dict[SHARED_PARAMETERS] - output[ACTIVATION_QUANTIZATION_ENABLED] = get_scalar_param( - sub_param_dict, 
- ACTIVATION_QUANTIZATION_ENABLED, - ACTIVATION_QUANTIZATION_ENABLED_DEFAULT) - output[ACTIVATION_QUANTIZE_TYPE] = get_scalar_param( - sub_param_dict, - ACTIVATION_QUANTIZE_TYPE, - ACTIVATION_QUANTIZE_TYPE_DEFAULT) - assert output[ACTIVATION_QUANTIZE_TYPE] in [ACTIVATION_QUANTIZE_SYMMETRIC, ACTIVATION_QUANTIZE_ASYMMETRIC], f"Invalid activation quantize type. Supported types: [{ACTIVATION_QUANTIZE_SYMMETRIC}, {ACTIVATION_QUANTIZE_ASYMMETRIC}]" - output[ACTIVATION_QUANTIZE_RANGE] = get_scalar_param( - sub_param_dict, - ACTIVATION_QUANTIZE_RANGE, - ACTIVATION_QUANTIZE_RANGE_DEFAULT) - assert output[ACTIVATION_QUANTIZE_RANGE] in [ACTIVATION_QUANTIZE_RANGE_DYNAMIC, ACTIVATION_QUANTIZE_RANGE_STATIC], f"Invalid activation quantize range calibration. Supported types: [{ACTIVATION_QUANTIZE_RANGE_DYNAMIC}, {ACTIVATION_QUANTIZE_RANGE_STATIC}]" - output[ACTIVATION_QUANTIZE_SCHEDULE_OFFSET] = get_scalar_param( - sub_param_dict, - ACTIVATION_QUANTIZE_SCHEDULE_OFFSET, - ACTIVATION_QUANTIZE_SCHEDULE_OFFSET_DEFAULT) - else: - output[ACTIVATION_QUANTIZATION_ENABLED] = ACTIVATION_QUANTIZATION_ENABLED_DEFAULT - output[ACTIVATION_QUANTIZE_TYPE] = ACTIVATION_QUANTIZE_TYPE_DEFAULT - output[ACTIVATION_QUANTIZE_RANGE] = ACTIVATION_QUANTIZE_RANGE_DEFAULT - output[ - ACTIVATION_QUANTIZE_SCHEDULE_OFFSET] = ACTIVATION_QUANTIZE_SCHEDULE_OFFSET_DEFAULT - return output - - -def get_activation_quantization_different_groups(param_dict): - output = {} - sub_param_dict = param_dict[DIFFERENT_GROUPS] - - def get_params(name, group_dict): - assert ACTIVATION_QUANTIZE_BITS in group_dict.keys(), f"{ACTIVATION_QUANTIZE_BITS} must be specified for activation quantization group {name}" - return group_dict - - for k, v in sub_param_dict.items(): - output[k] = {} - output[k][DIFFERENT_GROUPS_PARAMETERS] = get_params( - k, - sub_param_dict[k][DIFFERENT_GROUPS_PARAMETERS]) - output[k][DIFFERENT_GROUPS_MODULE_SCOPE] = get_scalar_param( - sub_param_dict[k], - DIFFERENT_GROUPS_MODULE_SCOPE, - DIFFERENT_GROUPS_MODULE_SCOPE_DEFAULT) - output[k][DIFFERENT_GROUPS_RELATED_MODULE_SCOPE] = get_scalar_param( - sub_param_dict[k], - DIFFERENT_GROUPS_RELATED_MODULE_SCOPE, - DIFFERENT_GROUPS_RELATED_MODULE_SCOPE_DEFAULT) - - return output - - -def get_sparse_pruning(param_dict): - output = {} - if SPARSE_PRUNING not in param_dict.keys(): - param_dict[SPARSE_PRUNING] = {SHARED_PARAMETERS: {}, DIFFERENT_GROUPS: {}} - sub_param_dict = param_dict[SPARSE_PRUNING] - # shared parameters - output[SHARED_PARAMETERS] = get_sparse_pruning_shared_parameters(sub_param_dict) - # each sub-groups - if output[SHARED_PARAMETERS][SPARSE_PRUNING_ENABLED]: - assert DIFFERENT_GROUPS in sub_param_dict.keys(), f"Sparse Pruning is enabled, {DIFFERENT_GROUPS} must be specified" - output[DIFFERENT_GROUPS] = get_sparse_pruning_different_groups(sub_param_dict) - return output - - -def get_sparse_pruning_shared_parameters(param_dict): - output = {} - if SHARED_PARAMETERS in param_dict.keys(): - sub_param_dict = param_dict[SHARED_PARAMETERS] - output[SPARSE_PRUNING_ENABLED] = get_scalar_param( - sub_param_dict, - SPARSE_PRUNING_ENABLED, - SPARSE_PRUNING_ENABLED_DEFAULT) - output[SPARSE_PRUNING_METHOD] = get_scalar_param(sub_param_dict, - SPARSE_PRUNING_METHOD, - SPARSE_PRUNING_METHOD_DEFAULT) - assert output[SPARSE_PRUNING_METHOD] in [SPARSE_PRUNING_METHOD_L1, SPARSE_PRUNING_METHOD_TOPK], f"Invalid sparse pruning method. 
Supported types: [{SPARSE_PRUNING_METHOD_L1}, {SPARSE_PRUNING_METHOD_TOPK}]" - output[SPARSE_PRUNING_SCHEDULE_OFFSET] = get_scalar_param( - sub_param_dict, - SPARSE_PRUNING_SCHEDULE_OFFSET, - SPARSE_PRUNING_SCHEDULE_OFFSET_DEFAULT) - else: - output[SPARSE_PRUNING_ENABLED] = SPARSE_PRUNING_ENABLED_DEFAULT - output[SPARSE_PRUNING_METHOD] = SPARSE_PRUNING_METHOD_DEFAULT - output[SPARSE_PRUNING_SCHEDULE_OFFSET] = SPARSE_PRUNING_SCHEDULE_OFFSET_DEFAULT - return output - - -def get_sparse_pruning_different_groups(param_dict): - output = {} - sub_param_dict = param_dict[DIFFERENT_GROUPS] - - def get_params(name, group_dict): - assert SPARSE_PRUNING_DENSE_RATIO in group_dict.keys(), f"{SPARSE_PRUNING_DENSE_RATIO} must be specified for sparse pruning group {name}" - return group_dict - - for k, v in sub_param_dict.items(): - output[k] = {} - output[k][DIFFERENT_GROUPS_PARAMETERS] = get_params( - k, - sub_param_dict[k][DIFFERENT_GROUPS_PARAMETERS]) - output[k][DIFFERENT_GROUPS_MODULE_SCOPE] = get_scalar_param( - sub_param_dict[k], - DIFFERENT_GROUPS_MODULE_SCOPE, - DIFFERENT_GROUPS_MODULE_SCOPE_DEFAULT) - output[k][DIFFERENT_GROUPS_RELATED_MODULE_SCOPE] = get_scalar_param( - sub_param_dict[k], - DIFFERENT_GROUPS_RELATED_MODULE_SCOPE, - DIFFERENT_GROUPS_RELATED_MODULE_SCOPE_DEFAULT) - - return output - - -def get_row_pruning(param_dict): - output = {} - if ROW_PRUNING not in param_dict.keys(): - param_dict[ROW_PRUNING] = {SHARED_PARAMETERS: {}, DIFFERENT_GROUPS: {}} - sub_param_dict = param_dict[ROW_PRUNING] - # shared parameters - output[SHARED_PARAMETERS] = get_row_pruning_shared_parameters(sub_param_dict) - # each sub-groups - if output[SHARED_PARAMETERS][ROW_PRUNING_ENABLED]: - assert DIFFERENT_GROUPS in sub_param_dict.keys(), f"Row Pruning is enabled, {DIFFERENT_GROUPS} must be specified" - output[DIFFERENT_GROUPS] = get_row_pruning_different_groups(sub_param_dict) - return output - - -def get_row_pruning_shared_parameters(param_dict): - output = {} - if SHARED_PARAMETERS in param_dict.keys(): - sub_param_dict = param_dict[SHARED_PARAMETERS] - output[ROW_PRUNING_ENABLED] = get_scalar_param(sub_param_dict, - ROW_PRUNING_ENABLED, - ROW_PRUNING_ENABLED_DEFAULT) - output[ROW_PRUNING_METHOD] = get_scalar_param(sub_param_dict, - ROW_PRUNING_METHOD, - ROW_PRUNING_METHOD_DEFAULT) - assert output[ROW_PRUNING_METHOD] in [ROW_PRUNING_METHOD_L1, ROW_PRUNING_METHOD_TOPK], f"Invalid row pruning method. 
Supported types: [{ROW_PRUNING_METHOD_L1}, {ROW_PRUNING_METHOD_TOPK}]" - output[ROW_PRUNING_SCHEDULE_OFFSET] = get_scalar_param( - sub_param_dict, - ROW_PRUNING_SCHEDULE_OFFSET, - ROW_PRUNING_SCHEDULE_OFFSET_DEFAULT) - else: - output[ROW_PRUNING_ENABLED] = ROW_PRUNING_ENABLED_DEFAULT - output[ROW_PRUNING_METHOD] = ROW_PRUNING_METHOD_DEFAULT - output[ROW_PRUNING_SCHEDULE_OFFSET] = ROW_PRUNING_SCHEDULE_OFFSET_DEFAULT - return output - - -def get_row_pruning_different_groups(param_dict): - output = {} - sub_param_dict = param_dict[DIFFERENT_GROUPS] - - def get_params(name, group_dict): - assert ROW_PRUNING_DENSE_RATIO in group_dict.keys(), f"{ROW_PRUNING_DENSE_RATIO} must be specified for row pruning group {name}" - return group_dict - - for k, v in sub_param_dict.items(): - output[k] = {} - output[k][DIFFERENT_GROUPS_PARAMETERS] = get_params( - k, - sub_param_dict[k][DIFFERENT_GROUPS_PARAMETERS]) - output[k][DIFFERENT_GROUPS_MODULE_SCOPE] = get_scalar_param( - sub_param_dict[k], - DIFFERENT_GROUPS_MODULE_SCOPE, - DIFFERENT_GROUPS_MODULE_SCOPE_DEFAULT) - output[k][DIFFERENT_GROUPS_RELATED_MODULE_SCOPE] = get_scalar_param( - sub_param_dict[k], - DIFFERENT_GROUPS_RELATED_MODULE_SCOPE, - DIFFERENT_GROUPS_RELATED_MODULE_SCOPE_DEFAULT) - return output - - -def get_head_pruning(param_dict): - output = {} - if HEAD_PRUNING not in param_dict.keys(): - param_dict[HEAD_PRUNING] = {SHARED_PARAMETERS: {}, DIFFERENT_GROUPS: {}} - sub_param_dict = param_dict[HEAD_PRUNING] - # shared parameters - output[SHARED_PARAMETERS] = get_head_pruning_shared_parameters(sub_param_dict) - # each sub-groups - if output[SHARED_PARAMETERS][HEAD_PRUNING_ENABLED]: - assert DIFFERENT_GROUPS in sub_param_dict.keys(), f"Head Pruning is enabled, {DIFFERENT_GROUPS} must be specified" - output[DIFFERENT_GROUPS] = get_head_pruning_different_groups(sub_param_dict) - return output - - -def get_head_pruning_shared_parameters(param_dict): - output = {} - if SHARED_PARAMETERS in param_dict.keys(): - sub_param_dict = param_dict[SHARED_PARAMETERS] - output[HEAD_PRUNING_ENABLED] = get_scalar_param(sub_param_dict, - HEAD_PRUNING_ENABLED, - HEAD_PRUNING_ENABLED_DEFAULT) - output[HEAD_PRUNING_METHOD] = get_scalar_param(sub_param_dict, - HEAD_PRUNING_METHOD, - HEAD_PRUNING_METHOD_DEFAULT) - assert output[HEAD_PRUNING_METHOD] in [HEAD_PRUNING_METHOD_L1, HEAD_PRUNING_METHOD_TOPK], f"Invalid head pruning method. 
Supported types: [{HEAD_PRUNING_METHOD_L1}, {HEAD_PRUNING_METHOD_TOPK}]" - output[HEAD_PRUNING_SCHEDULE_OFFSET] = get_scalar_param( - sub_param_dict, - HEAD_PRUNING_SCHEDULE_OFFSET, - HEAD_PRUNING_SCHEDULE_OFFSET_DEFAULT) - if output[HEAD_PRUNING_ENABLED]: - assert HEAD_PRUNING_NUM_HEADS in sub_param_dict.keys(), f"{HEAD_PRUNING_NUM_HEADS} must be specified for head pruning" - output[HEAD_PRUNING_NUM_HEADS] = sub_param_dict[HEAD_PRUNING_NUM_HEADS] - else: - output[HEAD_PRUNING_ENABLED] = HEAD_PRUNING_ENABLED_DEFAULT - output[HEAD_PRUNING_METHOD] = HEAD_PRUNING_METHOD_DEFAULT - output[HEAD_PRUNING_SCHEDULE_OFFSET] = HEAD_PRUNING_SCHEDULE_OFFSET_DEFAULT - return output - - -def get_head_pruning_different_groups(param_dict): - output = {} - sub_param_dict = param_dict[DIFFERENT_GROUPS] - - def get_params(name, group_dict): - assert HEAD_PRUNING_DENSE_RATIO in group_dict.keys(), f"dense_ratio must be specified for head pruning group {name}" - return group_dict - - for k, v in sub_param_dict.items(): - output[k] = {} - output[k][DIFFERENT_GROUPS_PARAMETERS] = get_params( - k, - sub_param_dict[k][DIFFERENT_GROUPS_PARAMETERS]) - output[k][DIFFERENT_GROUPS_MODULE_SCOPE] = get_scalar_param( - sub_param_dict[k], - DIFFERENT_GROUPS_MODULE_SCOPE, - DIFFERENT_GROUPS_MODULE_SCOPE_DEFAULT) - output[k][DIFFERENT_GROUPS_RELATED_MODULE_SCOPE] = get_scalar_param( - sub_param_dict[k], - DIFFERENT_GROUPS_RELATED_MODULE_SCOPE, - DIFFERENT_GROUPS_RELATED_MODULE_SCOPE_DEFAULT) - return output - - -def get_channel_pruning(param_dict): - output = {} - if CHANNEL_PRUNING not in param_dict.keys(): - param_dict[CHANNEL_PRUNING] = {SHARED_PARAMETERS: {}, DIFFERENT_GROUPS: {}} - sub_param_dict = param_dict[CHANNEL_PRUNING] - # shared parameters - output[SHARED_PARAMETERS] = get_channel_pruning_shared_parameters(sub_param_dict) - # each sub-groups - if output[SHARED_PARAMETERS][CHANNEL_PRUNING_ENABLED]: - assert DIFFERENT_GROUPS in sub_param_dict.keys(), f"Sparse Pruning is enabled, {DIFFERENT_GROUPS} must be specified" - output[DIFFERENT_GROUPS] = get_channel_pruning_different_groups(sub_param_dict) - return output - - -def get_channel_pruning_shared_parameters(param_dict): - output = {} - if SHARED_PARAMETERS in param_dict.keys(): - sub_param_dict = param_dict[SHARED_PARAMETERS] - output[CHANNEL_PRUNING_ENABLED] = get_scalar_param( - sub_param_dict, - CHANNEL_PRUNING_ENABLED, - CHANNEL_PRUNING_ENABLED_DEFAULT) - output[CHANNEL_PRUNING_METHOD] = get_scalar_param( - sub_param_dict, - CHANNEL_PRUNING_METHOD, - CHANNEL_PRUNING_METHOD_DEFAULT) - assert output[CHANNEL_PRUNING_METHOD] in [CHANNEL_PRUNING_METHOD_L1, CHANNEL_PRUNING_METHOD_TOPK], f"Invalid channel pruning method. 
Supported types: [{CHANNEL_PRUNING_METHOD_L1}, {CHANNEL_PRUNING_METHOD_TOPK}]" - output[CHANNEL_PRUNING_SCHEDULE_OFFSET] = get_scalar_param( - sub_param_dict, - CHANNEL_PRUNING_SCHEDULE_OFFSET, - CHANNEL_PRUNING_SCHEDULE_OFFSET_DEFAULT) - else: - output[CHANNEL_PRUNING_ENABLED] = CHANNEL_PRUNING_ENABLED_DEFAULT - output[CHANNEL_PRUNING_METHOD] = CHANNEL_PRUNING_METHOD_DEFAULT - output[CHANNEL_PRUNING_SCHEDULE_OFFSET] = CHANNEL_PRUNING_SCHEDULE_OFFSET_DEFAULT - return output - - -def get_channel_pruning_different_groups(param_dict): - output = {} - sub_param_dict = param_dict[DIFFERENT_GROUPS] - - def get_params(name, group_dict): - assert CHANNEL_PRUNING_DENSE_RATIO in group_dict.keys(), f"{CHANNEL_PRUNING_DENSE_RATIO} must be specified for channel pruning group {name}" - return group_dict - - for k, v in sub_param_dict.items(): - output[k] = {} - output[k][DIFFERENT_GROUPS_PARAMETERS] = get_params( - k, - sub_param_dict[k][DIFFERENT_GROUPS_PARAMETERS]) - output[k][DIFFERENT_GROUPS_MODULE_SCOPE] = get_scalar_param( - sub_param_dict[k], - DIFFERENT_GROUPS_MODULE_SCOPE, - DIFFERENT_GROUPS_MODULE_SCOPE_DEFAULT) - output[k][DIFFERENT_GROUPS_RELATED_MODULE_SCOPE] = get_scalar_param( - sub_param_dict[k], - DIFFERENT_GROUPS_RELATED_MODULE_SCOPE, - DIFFERENT_GROUPS_RELATED_MODULE_SCOPE_DEFAULT) - - return output + return DeepSpeedCompressionConfig(**param_dict[COMPRESSION_TRAINING]) + + +# Enum classes for pydantic models +class QuantizationTypeEnum(str, Enum): + symmetric = "symmetric" + asymmetric = "asymmetric" + + +class QuantizationRoundingEnum(str, Enum): + nearest = "nearest" + stochastic = "stochastic" + + +class QuantizationRangeEnum(str, Enum): + dynamic = "dynamic" + static = "static" + + +class PruningMethodEnum(str, Enum): + l1 = "l1" + topk = "topk" + + +class DifferentGroupsParamsConfig(DeepSpeedConfigModel): + start_bits: int + target_bits: int + quantization_period: int = Field(1, ge=0) + + +class DifferentGroupsConfig(DeepSpeedConfigModel): + params: DifferentGroupsParamsConfig = {} + modules: List[str] = ["*"] + related_modules: List[str] = None + + +class ActivationDifferentGroupsParamsConfig(DeepSpeedConfigModel): + bits: int + + +class ActivationDifferentGroupsConfig(DeepSpeedConfigModel): + params: ActivationDifferentGroupsParamsConfig = {} + modules: List[str] = ["*"] + related_modules: List[str] = None + + +class PruningDifferentGroupsParamsConfig(DeepSpeedConfigModel): + dense_ratio: float + + +class PruningDifferentGroupsConfig(DeepSpeedConfigModel): + params: PruningDifferentGroupsParamsConfig = {} + modules: List[str] = ["*"] + related_modules: List[List[str]] = None + + +class FP16MixedQuantizeConfig(DeepSpeedConfigModel): + enabled: bool = False + quantize_change_ratio: float = Field(0.001, ge=0) + + +class WeightQuantizationSharedParamsConfig(DeepSpeedConfigModel): + enabled: bool = False + quantizer_kernel: bool = False + schedule_offset: int = Field(0, ge=0) + quantize_groups: int = Field(1, ge=1) + quantize_verbose: bool = False + quantize_weight_in_forward: bool = False + quantization_type: QuantizationTypeEnum = QuantizationTypeEnum.symmetric + rounding: QuantizationRoundingEnum = QuantizationRoundingEnum.nearest + fp16_mixed_quantize: FP16MixedQuantizeConfig = {} + + +class ActivationQuantizationSharedParamsConfig(DeepSpeedConfigModel): + enabled: bool = False + quantization_type: QuantizationTypeEnum = QuantizationTypeEnum.symmetric + range_calibration: QuantizationRangeEnum = QuantizationRangeEnum.dynamic + schedule_offset: int = Field(1000, ge=0) + + +class 
PruningSharedParamsConfig(DeepSpeedConfigModel): + enabled: bool = False + method: PruningMethodEnum = PruningMethodEnum.l1 + schedule_offset: int = Field(1000, ge=0) + + +class HeadPruningSharedParamsConfig(DeepSpeedConfigModel): + enabled: bool = False + method: PruningMethodEnum = PruningMethodEnum.l1 + schedule_offset: int = Field(1000, ge=0) + num_heads: int = Field(None, ge=0) + + @root_validator + def assert_num_heads(cls, values): + if values.get("enabled"): + assert values.get("num_heads") != None, "'num_heads' must be specified for head pruning" + return values + + +class WeightQuantizationConfig(DeepSpeedConfigModel): + different_groups: Dict[str, DifferentGroupsConfig] = {} + shared_parameters: WeightQuantizationSharedParamsConfig = {} + + @validator("shared_parameters") + def set_enabled(cls, field_value, values): + values["enabled"] = field_value.enabled + return field_value + + @root_validator + def assert_different_groups(cls, values): + if values.get("enabled"): + assert values.get("different_groups"), "Weight Quantization is enabled, 'different_groups' must be specified" + return values + + +class ActivationQuantizationConfig(DeepSpeedConfigModel): + different_groups: Dict[str, ActivationDifferentGroupsConfig] = {} + shared_parameters: ActivationQuantizationSharedParamsConfig = {} + + @validator("shared_parameters") + def set_enabled(cls, field_value, values): + values["enabled"] = field_value.enabled + return field_value + + @root_validator + def assert_different_groups(cls, values): + if values.get("enabled"): + assert values.get("different_groups"), "Activation Quantization is enabled, 'different_groups' must be specified" + return values + + +class SparsePruningConfig(DeepSpeedConfigModel): + different_groups: Dict[str, PruningDifferentGroupsConfig] = {} + shared_parameters: PruningSharedParamsConfig = {} + + @validator("shared_parameters") + def set_enabled(cls, field_value, values): + values["enabled"] = field_value.enabled + return field_value + + @root_validator + def assert_different_groups(cls, values): + if values.get("enabled"): + assert values.get("different_groups"), "Sparse Pruning is enabled, 'different_groups' must be specified" + return values + + +class RowPruningConfig(DeepSpeedConfigModel): + different_groups: Dict[str, PruningDifferentGroupsConfig] = {} + shared_parameters: PruningSharedParamsConfig = {} + + @validator("shared_parameters") + def set_enabled(cls, field_value, values): + values["enabled"] = field_value.enabled + return field_value + + @root_validator + def assert_different_groups(cls, values): + if values.get("enabled"): + assert values.get("different_groups"), "Row Pruning is enabled, 'different_groups' must be specified" + return values + + +class HeadPruningConfig(DeepSpeedConfigModel): + different_groups: Dict[str, PruningDifferentGroupsConfig] = {} + shared_parameters: HeadPruningSharedParamsConfig = {} + + +class ChannelPruningConfig(DeepSpeedConfigModel): + different_groups: Dict[str, PruningDifferentGroupsConfig] = {} + shared_parameters: PruningSharedParamsConfig = {} + + @validator("shared_parameters") + def set_enabled(cls, field_value, values): + values["enabled"] = field_value.enabled + return field_value + + @root_validator + def assert_different_groups(cls, values): + if values.get("enabled"): + assert values.get("different_groups"), "Channel Pruning is enabled, 'different_groups' must be specified" + return values + + +class LayerReductionConfig(DeepSpeedConfigModel): + enabled: bool = False + keep_number_layer: int = 
Field(None, ge=0) + module_name_prefix: str = "" + teacher_layer: List[int] = [] + other_module_name: List[str] = [] + + +class DeepSpeedCompressionConfig(DeepSpeedConfigModel): + weight_quantization: WeightQuantizationConfig = {} + activation_quantization: ActivationQuantizationConfig = {} + sparse_pruning: SparsePruningConfig = {} + row_pruning: RowPruningConfig = {} + head_pruning: HeadPruningConfig = {} + channel_pruning: ChannelPruningConfig = {} + layer_reduction: LayerReductionConfig = {} diff --git a/deepspeed/compression/constants.py b/deepspeed/compression/constants.py deleted file mode 100644 index 593b86e5f5c9..000000000000 --- a/deepspeed/compression/constants.py +++ /dev/null @@ -1,170 +0,0 @@ -'''Copyright The Microsoft DeepSpeed Team''' - -######################################### -# Compression Methods -# It has several sub-components -# ######################################### -COMPRESSION_TRAINING = "compression_training" -SHARED_PARAMETERS = "shared_parameters" -DIFFERENT_GROUPS = "different_groups" -TECHNIQUE_ENABLED = "enabled" -TECHNIQUE_SCHEDULE_OFFSET = "schedule_offset" -DIFFERENT_GROUPS_PARAMETERS = "params" -DIFFERENT_GROUPS_MODULE_SCOPE = "modules" -DIFFERENT_GROUPS_MODULE_SCOPE_DEFAULT = "*" -DIFFERENT_GROUPS_RELATED_MODULE_SCOPE = "related_modules" -DIFFERENT_GROUPS_RELATED_MODULE_SCOPE_DEFAULT = None -# COMPRESSION_TRAINING_ENABLED = "enabled" -# COMPRESSION_TRAINING_ENABLED_DEFAULT = False - -#### -# Layer Reduction -#### -LAYER_REDUCTION = "layer_reduction" -LAYER_REDUCTION_ENABLED = "enabled" -LAYER_REDUCTION_ENABLED_DEFAULT = False -KEEP_NUMBER_LAYER = "keep_number_layer" -MODULE_NAME_PREFIX = "module_name_prefix" -TEACHER_LAYER = "teacher_layer" -OTHER_MODULE_NAME = "other_module_name" - -#### -# Weight Quantzation -#### -WEIGHT_QUANTIZATION = "weight_quantization" - -WEIGHT_QUANTIZATION_PERIOD = "quantization_period" -WEIGHT_QUANTIZATION_PERIOD_DEFAULT = 1 - -WEIGHT_QUANTIZE_IN_FORWARD_ENABLED = "quantize_weight_in_forward" -WEIGHT_QUANTIZE_IN_FORWARD_ENABLED_DEFAULT = False - -WEIGHT_QUANTIZE_ENABLED = TECHNIQUE_ENABLED -WEIGHT_QUANTIZE_ENABLED_DEFAULT = False - -WEIGHT_QUANTIZE_KERNEL = "quantizer_kernel" -WEIGHT_QUANTIZE_KERNEL_DEFAULT = False - -WEIGHT_QUANTIZE_SCHEDULE_OFFSET = TECHNIQUE_SCHEDULE_OFFSET -WEIGHT_QUANTIZE_SCHEDULE_OFFSET_DEFAULT = 0 - -WEIGHT_QUANTIZE_GROUPS = "quantize_groups" -WEIGHT_QUANTIZE_GROUPS_DEFAULT = 1 - -WEIGHT_QUANTIZE_VERBOSE = "quantize_verbose" -WEIGHT_QUANTIZE_VERBOSE_DEFAULT = False - -WEIGHT_QUANTIZE_TYPE = "quantization_type" -WEIGHT_QUANTIZE_TYPE_DEFAULT = "symmetric" -WEIGHT_QUANTIZE_SYMMETRIC = "symmetric" -WEIGHT_QUANTIZE_ASYMMETRIC = "asymmetric" - -WEIGHT_QUANTIZE_ROUNDING = "rounding" -WEIGHT_QUANTIZE_ROUNDING_DEFAULT = "nearest" -WEIGHT_QUANTIZE_STOCHASTIC_ROUNDING = "stochastic" -WEIGHT_QUANTIZE_NEAREST_ROUNDING = "nearest" -# maybe deleted for a cleaner version -WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE = "fp16_mixed_quantize" - -WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE_ENABLED = "enabled" -WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE_ENABLED_DEFAULT = False - -WEIGHT_QUANTIZE_CHANGE_RATIO = "quantize_change_ratio" -WEIGHT_QUANTIZE_CHANGE_RATIO_DEFAULT = 0.001 - -WEIGHT_QUANTIZE_START_BITS = "start_bits" -WEIGHT_QUANTIZE_TARGET_BITS = "target_bits" -### -# Activation Quantization -### -ACTIVATION_QUANTIZATION = "activation_quantization" - -ACTIVATION_QUANTIZATION_ENABLED = TECHNIQUE_ENABLED -ACTIVATION_QUANTIZATION_ENABLED_DEFAULT = False - -ACTIVATION_QUANTIZE_SCHEDULE_OFFSET = TECHNIQUE_SCHEDULE_OFFSET 
-ACTIVATION_QUANTIZE_SCHEDULE_OFFSET_DEFAULT = 1000 - -ACTIVATION_QUANTIZE_TYPE = "quantization_type" -ACTIVATION_QUANTIZE_TYPE_DEFAULT = "symmetric" -ACTIVATION_QUANTIZE_SYMMETRIC = "symmetric" -ACTIVATION_QUANTIZE_ASYMMETRIC = "asymmetric" - -ACTIVATION_QUANTIZE_RANGE = 'range_calibration' -ACTIVATION_QUANTIZE_RANGE_DEFAULT = 'dynamic' -ACTIVATION_QUANTIZE_RANGE_STATIC = 'static' -ACTIVATION_QUANTIZE_RANGE_DYNAMIC = 'dynamic' - -ACTIVATION_QUANTIZE_BITS = "bits" -### -# Sparse Pruning -### -SPARSE_PRUNING = "sparse_pruning" - -SPARSE_PRUNING_ENABLED = TECHNIQUE_ENABLED -SPARSE_PRUNING_ENABLED_DEFAULT = False - -SPARSE_PRUNING_METHOD = "method" -SPARSE_PRUNING_METHOD_DEFAULT = "l1" -SPARSE_PRUNING_METHOD_L1 = "l1" -SPARSE_PRUNING_METHOD_TOPK = "topk" - -SPARSE_PRUNING_SCHEDULE_OFFSET = TECHNIQUE_SCHEDULE_OFFSET -SPARSE_PRUNING_SCHEDULE_OFFSET_DEFAULT = 1000 - -SPARSE_PRUNING_DENSE_RATIO = "dense_ratio" -### -# Row Pruning -### -ROW_PRUNING = "row_pruning" - -ROW_PRUNING_ENABLED = TECHNIQUE_ENABLED -ROW_PRUNING_ENABLED_DEFAULT = False - -ROW_PRUNING_METHOD = "method" -ROW_PRUNING_METHOD_DEFAULT = "l1" -ROW_PRUNING_METHOD_L1 = "l1" -ROW_PRUNING_METHOD_TOPK = "topk" - -ROW_PRUNING_SCHEDULE_OFFSET = TECHNIQUE_SCHEDULE_OFFSET -ROW_PRUNING_SCHEDULE_OFFSET_DEFAULT = 1000 - -ROW_PRUNING_DENSE_RATIO = "dense_ratio" - -### -# Head Pruning -### -HEAD_PRUNING = "head_pruning" - -HEAD_PRUNING_ENABLED = TECHNIQUE_ENABLED -HEAD_PRUNING_ENABLED_DEFAULT = False - -HEAD_PRUNING_METHOD = "method" -HEAD_PRUNING_METHOD_DEFAULT = "topk" -HEAD_PRUNING_METHOD_L1 = "l1" -HEAD_PRUNING_METHOD_TOPK = "topk" - -HEAD_PRUNING_SCHEDULE_OFFSET = TECHNIQUE_SCHEDULE_OFFSET -HEAD_PRUNING_SCHEDULE_OFFSET_DEFAULT = 1000 - -HEAD_PRUNING_NUM_HEADS = "num_heads" - -HEAD_PRUNING_DENSE_RATIO = "dense_ratio" - -### -# Channel Pruning -### -CHANNEL_PRUNING = "channel_pruning" - -CHANNEL_PRUNING_ENABLED = TECHNIQUE_ENABLED -CHANNEL_PRUNING_ENABLED_DEFAULT = False - -CHANNEL_PRUNING_METHOD = "method" -CHANNEL_PRUNING_METHOD_DEFAULT = "l1" -CHANNEL_PRUNING_METHOD_L1 = "l1" -CHANNEL_PRUNING_METHOD_TOPK = "topk" - -CHANNEL_PRUNING_SCHEDULE_OFFSET = TECHNIQUE_SCHEDULE_OFFSET -CHANNEL_PRUNING_SCHEDULE_OFFSET_DEFAULT = 1000 - -CHANNEL_PRUNING_DENSE_RATIO = "dense_ratio" diff --git a/deepspeed/compression/helper.py b/deepspeed/compression/helper.py index e839a5d03582..03cb20423f83 100644 --- a/deepspeed/compression/helper.py +++ b/deepspeed/compression/helper.py @@ -2,7 +2,6 @@ import torch from .basic_layer import Embedding_Compress, LinearLayer_Compress, Conv2dLayer_Compress, BNLayer_Compress, ColumnParallelLinear_Compress, RowParallelLinear_Compress -from .constants import * def recursive_getattr(model, module_name): @@ -148,38 +147,36 @@ def module_replacement(model, module_name, compression_technique=None, mpu=None) if compression_technique is not None: for k, v in compression_technique.items(): - if k == SPARSE_PRUNING: - if v[SPARSE_PRUNING_ENABLED]: - new_module.enable_sparse_pruning(v[SPARSE_PRUNING_DENSE_RATIO], - v[SPARSE_PRUNING_METHOD]) - elif k == ROW_PRUNING: - if v[ROW_PRUNING_ENABLED]: - new_module.enable_row_pruning(v[ROW_PRUNING_DENSE_RATIO], - v[ROW_PRUNING_METHOD]) - elif k == HEAD_PRUNING: - if v[HEAD_PRUNING_ENABLED]: - new_module.enable_head_pruning(v[HEAD_PRUNING_DENSE_RATIO], - v[HEAD_PRUNING_METHOD], - v[HEAD_PRUNING_NUM_HEADS]) - elif k == ACTIVATION_QUANTIZATION: - if v[ACTIVATION_QUANTIZATION_ENABLED]: - new_module.enable_activation_quantization( - v[ACTIVATION_QUANTIZE_BITS], - v[ACTIVATION_QUANTIZE_TYPE], - 
v[ACTIVATION_QUANTIZE_RANGE]) - elif k == WEIGHT_QUANTIZATION: - if v[WEIGHT_QUANTIZE_ENABLED]: + if k == "sparse_pruning": + if v.get("enabled"): + new_module.enable_sparse_pruning(v.get("dense_ratio"), + v.get("method")) + elif k == "row_pruning": + if v.get("enabled"): + new_module.enable_row_pruning(v.get("dense_ratio"), v.get("method")) + elif k == "head_pruning": + if v.get("enabled"): + new_module.enable_head_pruning(v.get("dense_ratio"), + v.get("method"), + v.get("num_heads")) + elif k == "activation_quantization": + if v.get("enabled"): + new_module.enable_activation_quantization(v.get("bits"), + v.get("quantization_type"), + v.get("range_calibration")) + elif k == "weight_quantization": + if v.get("enabled"): new_module.enable_weight_quantization( - v[WEIGHT_QUANTIZE_START_BITS], - v[WEIGHT_QUANTIZE_TARGET_BITS], - v[WEIGHT_QUANTIZATION_PERIOD], - v[WEIGHT_QUANTIZE_IN_FORWARD_ENABLED], - v[WEIGHT_QUANTIZE_TYPE], - v[WEIGHT_QUANTIZE_GROUPS]) - elif k == CHANNEL_PRUNING: - if v[CHANNEL_PRUNING_ENABLED]: - new_module.enable_channel_pruning(v[CHANNEL_PRUNING_DENSE_RATIO], - v[CHANNEL_PRUNING_METHOD]) + v.get("start_bits"), + v.get("target_bits"), + v.get("quantization_period"), + v.get("quantize_weight_in_forward"), + v.get("quantization_type"), + v.get("quantize_groups")) + elif k == "channel_pruning": + if v.get("enabled"): + new_module.enable_channel_pruning(v.get("dense_ratio"), + v.get("method")) else: raise NotImplementedError( 'Compression technique {} is not implemented'.format(k)) @@ -243,18 +240,17 @@ def fix_compression(model, # Here we can make things much simpler by just replacing the module module = recursive_getattr(model, module_name) for k, v in compression_technique.items(): - if k == WEIGHT_QUANTIZATION and v[WEIGHT_QUANTIZE_IN_FORWARD_ENABLED] and v[ - WEIGHT_QUANTIZE_ENABLED]: + if k == "weight_quantization" and v.quantize_weight_in_forward and v.enabled: return module.fix_weight_quantization() - elif k == SPARSE_PRUNING and v[SPARSE_PRUNING_ENABLED]: + elif k == "sparse_pruning" and v.enabled: return module.fix_sparse_pruning_helper() - elif k == ROW_PRUNING and (v[ROW_PRUNING_ENABLED] or mask is not None): + elif k == "row_pruning" and (v.enabled or mask is not None): return module.fix_row_col_pruning_helper(mask, dim_reduction=dim_reduction) - elif k == HEAD_PRUNING and (v[HEAD_PRUNING_ENABLED] or mask is not None): + elif k == "head_pruning" and (v.enabled or mask is not None): return module.fix_head_pruning_helper(mask, - v[HEAD_PRUNING_NUM_HEADS], + v.num_heads, dim_reduction=dim_reduction) - elif k == CHANNEL_PRUNING and (v[CHANNEL_PRUNING_ENABLED] or mask is not None): + elif k == "channel_pruning" and (v.enabled or mask is not None): return module.fix_channel_pruning_helper(mask, dim_reduction=dim_reduction) diff --git a/deepspeed/compression/scheduler.py b/deepspeed/compression/scheduler.py index 67955a825251..927f79a3880e 100644 --- a/deepspeed/compression/scheduler.py +++ b/deepspeed/compression/scheduler.py @@ -1,7 +1,6 @@ '''Copyright The Microsoft DeepSpeed Team''' from .compress import get_module_name -from .constants import * from .helper import recursive_getattr from deepspeed.utils import logger @@ -18,145 +17,131 @@ def __init__(self, model, compression_config): self.weight_quantization_enabled = False self.verbose = { - WEIGHT_QUANTIZATION: False, - ACTIVATION_QUANTIZATION: False, - SPARSE_PRUNING: False, - HEAD_PRUNING: False, - ROW_PRUNING: False, - CHANNEL_PRUNING: False + "weight_quantization": False, + "activation_quantization": False, 
+ "sparse_pruning": False, + "head_pruning": False, + "row_pruning": False, + "channel_pruning": False, } def make_init(self): self.different_compression_methods = {} - for method, method_content in self.compression_config.items(): - if LAYER_REDUCTION in method: + for method, method_content in self.compression_config: + if "layer_reduction" in method: continue - self.different_compression_methods[method] = { - TECHNIQUE_ENABLED: False, - SHARED_PARAMETERS: None, - DIFFERENT_GROUPS: [] - } - exist_module_name = set() - shared_parameters = method_content[SHARED_PARAMETERS] - self.different_compression_methods[method][ - TECHNIQUE_ENABLED] = shared_parameters[TECHNIQUE_ENABLED] - self.different_compression_methods[method][ - SHARED_PARAMETERS] = shared_parameters - - for group_name, method_parameters in method_content[DIFFERENT_GROUPS].items(): + + for group_name, method_parameters in method_content.different_groups.items(): module_name_list = [] - for key_word in method_parameters[DIFFERENT_GROUPS_MODULE_SCOPE]: - module_name, exist_module_name = get_module_name(group_name, self.model, key_word, exist_module_name, verbose=False) + for key_word in method_parameters.modules: + module_name, exist_module_name = get_module_name(group_name, self.model, key_word, set(), verbose=False) module_name_list.extend(module_name) - if module_name_list: - self.different_compression_methods[method][DIFFERENT_GROUPS].append([ - group_name, - module_name_list, - method_parameters.copy().pop('params') - ]) + method_parameters.modules = module_name_list def check_weight_quantization(self): - # check weight quantization - wq = self.different_compression_methods[WEIGHT_QUANTIZATION] - if not wq[TECHNIQUE_ENABLED]: + wq = self.compression_config.weight_quantization + if not wq.enabled: return - else: - shared_parameters = wq[SHARED_PARAMETERS] - if self.training_steps >= shared_parameters[TECHNIQUE_SCHEDULE_OFFSET]: - for group_name, module_name_list, method_parameters in wq[DIFFERENT_GROUPS]: - for module_name in module_name_list: - module = recursive_getattr(self.model, module_name) - module.weight_quantization_enabled = True - - if not self.verbose[WEIGHT_QUANTIZATION]: - logger.info( - f'Weight quantization is enabled at step {self.training_steps}') - self.weight_quantization_enabled = True - self.verbose[WEIGHT_QUANTIZATION] = True + + shared_parameters = wq.shared_parameters + if self.training_steps >= shared_parameters.schedule_offset: + for group_name, method_parameters in wq.different_groups.items(): + module_name_list = method_parameters.modules + for module_name in module_name_list: + module = recursive_getattr(self.model, module_name) + module.weight_quantization_enabled = True + + if not self.verbose["weight_quantization"]: + logger.info( + f'Weight quantization is enabled at step {self.training_steps}') + self.weight_quantization_enabled = True + self.verbose["weight_quantization"] = True def check_activation_quantization(self): # check activation quantization - aq = self.different_compression_methods[ACTIVATION_QUANTIZATION] - if not aq[TECHNIQUE_ENABLED]: + aq = self.compression_config.activation_quantization + if not aq.enabled: return - else: - shared_parameters = aq[SHARED_PARAMETERS] - if self.training_steps >= shared_parameters[TECHNIQUE_SCHEDULE_OFFSET]: - for group_name, module_name_list, method_parameters in aq[DIFFERENT_GROUPS]: - for module_name in module_name_list: - module = recursive_getattr(self.model, module_name) - module.activation_quantization_enabled = True - if not 
self.verbose[ACTIVATION_QUANTIZATION]: - logger.info( - f'Activation quantization is enabled at step {self.training_steps}' - ) - self.verbose[ACTIVATION_QUANTIZATION] = True + + shared_parameters = aq.shared_parameters + if self.training_steps >= shared_parameters.schedule_offset: + for group_name, method_parameters in aq.different_groups.items(): + module_name_list = method_parameters.modules + for module_name in module_name_list: + module = recursive_getattr(self.model, module_name) + module.activation_quantization_enabled = True + if not self.verbose["activation_quantization"]: + logger.info( + f'Activation quantization is enabled at step {self.training_steps}') + self.verbose["activation_quantization"] = True def check_sparse_pruning(self): # check sparse pruning - sp = self.different_compression_methods[SPARSE_PRUNING] - if not sp[TECHNIQUE_ENABLED]: + sp = self.compression_config.sparse_pruning + if not sp.enabled: return - else: - shared_parameters = sp[SHARED_PARAMETERS] - if self.training_steps >= shared_parameters[TECHNIQUE_SCHEDULE_OFFSET]: - for group_name, module_name_list, method_parameters in sp[DIFFERENT_GROUPS]: - for module_name in module_name_list: - module = recursive_getattr(self.model, module_name) - module.sparse_pruning_enabled = True - if not self.verbose[SPARSE_PRUNING]: - logger.info( - f'Sparse pruning is enabled at step {self.training_steps}') - self.verbose[SPARSE_PRUNING] = True + + shared_parameters = sp.shared_parameters + if self.training_steps >= shared_parameters.schedule_offset: + for group_name, method_parameters in sp.different_groups.items(): + module_name_list = method_parameters.modules + for module_name in module_name_list: + module = recursive_getattr(self.model, module_name) + module.sparse_pruning_enabled = True + if not self.verbose["sparse_pruning"]: + logger.info(f'Sparse pruning is enabled at step {self.training_steps}') + self.verbose["sparse_pruning"] = True def check_head_pruning(self): # check head pruning - hp = self.different_compression_methods[HEAD_PRUNING] - if not hp[TECHNIQUE_ENABLED]: + hp = self.compression_config.head_pruning + if not hp.enabled: return - else: - shared_parameters = hp[SHARED_PARAMETERS] - if self.training_steps >= shared_parameters[TECHNIQUE_SCHEDULE_OFFSET]: - for group_name, module_name_list, method_parameters in hp[DIFFERENT_GROUPS]: - for module_name in module_name_list: - module = recursive_getattr(self.model, module_name) - module.head_pruning_enabled = True - if not self.verbose[HEAD_PRUNING]: - logger.info(f'Head pruning is enabled at step {self.training_steps}') - self.verbose[HEAD_PRUNING] = True + + shared_parameters = hp.shared_parameters + if self.training_steps >= shared_parameters.schedule_offset: + for group_name, method_parameters in hp.different_groups.items(): + module_name_list = method_parameters.modules + for module_name in module_name_list: + module = recursive_getattr(self.model, module_name) + module.head_pruning_enabled = True + if not self.verbose["head_pruning"]: + logger.info(f'Head pruning is enabled at step {self.training_steps}') + self.verbose["head_pruning"] = True def check_row_pruning(self): # check row pruning - rp = self.different_compression_methods[ROW_PRUNING] - if not rp[TECHNIQUE_ENABLED]: + rp = self.compression_config.row_pruning + if not rp.enabled: return - else: - shared_parameters = rp[SHARED_PARAMETERS] - if self.training_steps >= shared_parameters[TECHNIQUE_SCHEDULE_OFFSET]: - for group_name, module_name_list, method_parameters in rp[DIFFERENT_GROUPS]: - for 
module_name in module_name_list: - module = recursive_getattr(self.model, module_name) - module.row_pruning_enabled = True - if not self.verbose[ROW_PRUNING]: - logger.info(f'Row pruning is enabled at step {self.training_steps}') - self.verbose[ROW_PRUNING] = True + + shared_parameters = rp.shared_parameters + if self.training_steps >= shared_parameters.schedule_offset: + for group_name, method_parameters in rp.different_groups.items(): + module_name_list = method_parameters.modules + for module_name in module_name_list: + module = recursive_getattr(self.model, module_name) + module.row_pruning_enabled = True + if not self.verbose["row_pruning"]: + logger.info(f'Row pruning is enabled at step {self.training_steps}') + self.verbose["row_pruning"] = True def check_channel_pruning(self): # check channel pruning - cp = self.different_compression_methods[CHANNEL_PRUNING] - if not cp[TECHNIQUE_ENABLED]: + cp = self.compression_config.channel_pruning + if not cp.enabled: return - else: - shared_parameters = cp[SHARED_PARAMETERS] - if self.training_steps >= shared_parameters[TECHNIQUE_SCHEDULE_OFFSET]: - for group_name, module_name_list, method_parameters in cp[DIFFERENT_GROUPS]: - for module_name in module_name_list: - module = recursive_getattr(self.model, module_name) - module.channel_pruning_enabled = True - if not self.verbose[CHANNEL_PRUNING]: - logger.info( - f'Channel pruning is enabled at step {self.training_steps}') - self.verbose[CHANNEL_PRUNING] = True + + shared_parameters = cp.shared_parameters + if self.training_steps >= shared_parameters.schedule_offset: + for group_name, method_parameters in cp.different_groups.items(): + module_name_list = method_parameters.modules + for module_name in module_name_list: + module = recursive_getattr(self.model, module_name) + module.channel_pruning_enabled = True + if not self.verbose["channel_pruning"]: + logger.info(f'Channel pruning is enabled at step {self.training_steps}') + self.verbose["channel_pruning"] = True def check_all_modules(self): # check all different compression methods we have diff --git a/deepspeed/runtime/config.py b/deepspeed/runtime/config.py index 9da1058f2c8e..72faadd3fb8d 100755 --- a/deepspeed/runtime/config.py +++ b/deepspeed/runtime/config.py @@ -54,8 +54,7 @@ from ..autotuning.config import DeepSpeedAutotuningConfig from ..nebula.config import DeepSpeedNebulaConfig -from ..compression.config import get_compression_config, get_quantize_enabled -from ..compression.constants import * +from ..compression.config import get_compression_config from .swap_tensor.aio_config import get_aio_config from .data_pipeline.config import get_data_efficiency_enabled, get_data_efficiency_config, get_curriculum_enabled_legacy, get_curriculum_params_legacy @@ -539,8 +538,8 @@ def get_memory_breakdown(param_dict): return get_scalar_param(param_dict, MEMORY_BREAKDOWN, MEMORY_BREAKDOWN_DEFAULT) -def get_eigenvalue_config(param_dict): - if get_quantize_enabled(param_dict): +def get_eigenvalue_config(param_dict, quantize_enabled=False): + if quantize_enabled: param_dict = param_dict[QUANTIZE_TRAINING] assert not get_eigenvalue_enabled(param_dict), "Eigenvalue based MoQ is temporarily disabled" return ( @@ -877,7 +876,9 @@ def _initialize_params(self, param_dict): self.eigenvalue_gas_boundary_resolution, self.eigenvalue_layer_name, self.eigenvalue_layer_num, - ) = get_eigenvalue_config(param_dict) + ) = get_eigenvalue_config( + param_dict, + quantize_enabled=self.compression_config.weight_quantization.enabled) self.sparse_attention = 
get_sparse_attention(param_dict) self.pipeline = get_pipeline_config(param_dict) diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 08cb1fd7276a..306cd6658207 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -42,17 +42,6 @@ DATA_PARALLEL_GROUP, GLOBAL_RANK from deepspeed.runtime.zero.config import ZeroStageEnum from deepspeed.compression import compression_scheduler -from deepspeed.compression.constants import \ - WEIGHT_QUANTIZE_IN_FORWARD_ENABLED, \ - WEIGHT_QUANTIZATION, SHARED_PARAMETERS, \ - WEIGHT_QUANTIZE_ENABLED, \ - WEIGHT_QUANTIZE_GROUPS, \ - WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE, \ - WEIGHT_QUANTIZE_CHANGE_RATIO, \ - WEIGHT_QUANTIZE_TYPE, \ - WEIGHT_QUANTIZE_ROUNDING, \ - WEIGHT_QUANTIZE_VERBOSE, \ - WEIGHT_QUANTIZE_KERNEL from deepspeed.checkpoint.constants import OPTIMIZER_STATE_DICT from deepspeed.runtime.sparse_tensor import SparseTensor @@ -693,24 +682,24 @@ def scheduler_params(self): def quantize_training(self): return ( - self._config.compression_config[WEIGHT_QUANTIZATION][SHARED_PARAMETERS] - [WEIGHT_QUANTIZE_IN_FORWARD_ENABLED], - self._config.compression_config[WEIGHT_QUANTIZATION][SHARED_PARAMETERS] - [WEIGHT_QUANTIZE_ENABLED], - self._config.compression_config[WEIGHT_QUANTIZATION][SHARED_PARAMETERS] - [WEIGHT_QUANTIZE_GROUPS], - self._config.compression_config[WEIGHT_QUANTIZATION][SHARED_PARAMETERS] - [WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE], - self._config.compression_config[WEIGHT_QUANTIZATION][SHARED_PARAMETERS] - [WEIGHT_QUANTIZE_CHANGE_RATIO], - self._config.compression_config[WEIGHT_QUANTIZATION][SHARED_PARAMETERS] - [WEIGHT_QUANTIZE_TYPE], - self._config.compression_config[WEIGHT_QUANTIZATION][SHARED_PARAMETERS] - [WEIGHT_QUANTIZE_ROUNDING], - self._config.compression_config[WEIGHT_QUANTIZATION][SHARED_PARAMETERS] - [WEIGHT_QUANTIZE_VERBOSE], - self._config.compression_config[WEIGHT_QUANTIZATION][SHARED_PARAMETERS] - [WEIGHT_QUANTIZE_KERNEL], + self._config.compression_config.weight_quantization.shared_parameters. + quantize_weight_in_forward, + self._config.compression_config.weight_quantization.shared_parameters. + enabled, + self._config.compression_config.weight_quantization.shared_parameters. + quantize_groups, + self._config.compression_config.weight_quantization.shared_parameters. + fp16_mixed_quantize.enabled, + self._config.compression_config.weight_quantization.shared_parameters. + fp16_mixed_quantize.quantize_change_ratio, + self._config.compression_config.weight_quantization.shared_parameters. + quantization_type, + self._config.compression_config.weight_quantization.shared_parameters. + rounding, + self._config.compression_config.weight_quantization.shared_parameters. + quantize_verbose, + self._config.compression_config.weight_quantization.shared_parameters. + quantizer_kernel, ) def zero_optimization(self):
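
For readers following the constants-to-pydantic migration above, the sketch below shows how a user-facing ds_config dict maps onto the new models and how downstream code reads it. This is a minimal, hedged example and not part of the patch: it assumes DeepSpeedConfigModel behaves like a standard pydantic v1 BaseModel (nested models accept plain dicts, .dict() is available, and iterating a model yields (field_name, value) pairs, which is what the new `for method, method_content in compress_methods` loop relies on). Model and field names are taken from the diff; the group name "wq_group_1", the module patterns, and the concrete values are made-up illustrations.

# Hypothetical usage sketch of the pydantic-based compression config added in
# deepspeed/compression/config.py. Values below are illustrative only.
from deepspeed.compression.config import get_compression_config

ds_config = {
    "compression_training": {
        "weight_quantization": {
            "shared_parameters": {
                "enabled": True,
                "quantization_type": "symmetric",  # checked against QuantizationTypeEnum
                "rounding": "nearest",             # checked against QuantizationRoundingEnum
                "quantize_groups": 1,
            },
            "different_groups": {
                "wq_group_1": {
                    "params": {"start_bits": 8, "target_bits": 4, "quantization_period": 50},
                    "modules": ["attention.self", "intermediate"],
                },
            },
        },
    },
}

cfg = get_compression_config(ds_config)

# Dict-style lookups such as config[WEIGHT_QUANTIZATION][SHARED_PARAMETERS][...]
# become attribute access on validated models:
shared = cfg.weight_quantization.shared_parameters
print(shared.enabled, shared.quantization_type, shared.schedule_offset)

# different_groups stays a Dict[str, ...], so the per-group loops in compress.py
# and scheduler.py iterate .items() and read attributes instead of constant keys:
for group_name, group in cfg.weight_quantization.different_groups.items():
    combined = {**group.params.dict(), **shared.dict()}  # mirrors get_compress_methods
    print(group_name, group.modules, combined["target_bits"])

One practical effect of this change, visible in the models above, is that validation moves from scattered asserts in the old config.py to config-parse time: unsupported quantization_type or rounding values are rejected by the Enum fields, numeric bounds are enforced via Field(..., ge=...), and the root_validators still require different_groups whenever a technique is enabled.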