diff --git a/deepspeed/compression/__init__.py b/deepspeed/compression/__init__.py
new file mode 100644
index 000000000000..50446e42b5ed
--- /dev/null
+++ b/deepspeed/compression/__init__.py
@@ -0,0 +1,3 @@
+from .compress import init_compression, redundancy_clean
+from .scheduler import compression_scheduler
+from .helper import convert_conv1d_to_linear
diff --git a/deepspeed/compression/basic_layer.py b/deepspeed/compression/basic_layer.py
new file mode 100644
index 000000000000..5a5209ad7691
--- /dev/null
+++ b/deepspeed/compression/basic_layer.py
@@ -0,0 +1,924 @@
+import copy
+import torch
+import math
+from torch import nn
+from torch.nn import init
+import deepspeed.comm as dist
+from .utils import TopKBinarizer, SymQuantizer, AsymQuantizer, TernaryQuantizer, BinaryQuantizer
+from deepspeed.utils import logger
+
+g_mpu = None
+
+
+class QuantAct(nn.Module):
+    """
+    Class to quantize given activations. Note that when using this module, the input activation quantization range
+    is fixed for all tokens/images at inference time. This generally costs some accuracy but achieves better latency.
+    Parameters:
+    ----------
+    act_range_momentum : float, default 0.95
+        Momentum for updating the activation quantization range.
+    quant_mode : str, default 'symmetric'
+        The quantization mode, either 'symmetric' or 'asymmetric'.
+    """
+ def __init__(self, act_range_momentum=0.95, quant_mode='symmetric'):
+ super(QuantAct, self).__init__()
+
+ self.act_range_momentum = act_range_momentum
+ self.quant_mode = quant_mode
+ if quant_mode == 'symmetric':
+ self.act_function = SymQuantizer.apply
+ else:
+ self.act_function = AsymQuantizer.apply
+
+ self.register_buffer('x_min_max', torch.zeros(2))
+
+ def forward(self, x, num_bits, *args):
+ """
+ x: the activation that we need to quantize
+ num_bits: the number of bits we need to quantize the activation to
+        *args: extra arguments that are unused but kept to match the interface of the other quantization functions
+ """
+
+ if self.training:
+ x_min = x.data.min()
+ x_max = x.data.max()
+
+ # Initialization
+ if self.x_min_max[0] == self.x_min_max[1]:
+ self.x_min_max[0] = x_min
+ self.x_min_max[1] = x_max
+
+            # if momentum is not needed, set self.act_range_momentum = 0
+ self.x_min_max[0] = self.x_min_max[0] * self.act_range_momentum + x_min * (
+ 1 - self.act_range_momentum)
+ self.x_min_max[1] = self.x_min_max[1] * self.act_range_momentum + x_max * (
+ 1 - self.act_range_momentum)
+
+ x_q = self.act_function(x, num_bits, self.x_min_max[0], self.x_min_max[1])
+
+ return x_q
+
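+# A minimal usage sketch for QuantAct (illustrative only; in this module the class is
+# instantiated by enable_activation_quantization further below). In training mode it
+# calibrates x_min_max with momentum and returns a fake-quantized activation:
+#   quant_act = QuantAct(act_range_momentum=0.95, quant_mode='symmetric')
+#   x_q = quant_act(torch.randn(4, 16), 8)  # quantize activations to 8 bits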
+
+class Embedding_Compress(nn.Embedding):
+ def __init__(self, *kargs):
+ super(Embedding_Compress, self).__init__(*kargs)
+ self.weight.start_bits = None
+ self.weight.target_bits = None
+ self.weight.q_period = None
+ self.weight_quantization_enabled_in_forward = False
+ self.weight_quantization_enabled = False
+
+ def extra_repr(self):
+ return 'num_embeddings={}, embedding_dim={}, weight_quantization={}'.format(
+ self.num_embeddings,
+ self.embedding_dim,
+ self.weight.target_bits)
+
+ def enable_weight_quantization(self,
+ start_bits,
+ target_bits,
+ quantization_period,
+ weight_quantization_enabled_in_forward,
+ quantization_type,
+ num_groups):
+ self.weight.start_bits = start_bits
+ self.weight.target_bits = target_bits
+ self.weight.q_period = quantization_period
+ self.weight_quantization_enabled_in_forward = weight_quantization_enabled_in_forward
+ if self.weight_quantization_enabled_in_forward:
+ logger.warning(
+                "************ Many MoQ features are not supported in quantize_weight_in_forward mode; please consider using the DS-FP16 optimizer ************"
+ )
+ if self.weight.target_bits >= 3:
+ if quantization_type == 'symmetric':
+ self.weight_quantizer = SymQuantizer.apply
+ else:
+ self.weight_quantizer = AsymQuantizer.apply
+ elif self.weight.target_bits == 2:
+ assert quantization_type == 'symmetric', 'Only symmetric quantization is supported for ternary weight quantization'
+ self.weight_quantizer = TernaryQuantizer.apply
+ elif self.weight.target_bits == 1:
+ assert quantization_type == 'symmetric', 'Only symmetric quantization is supported for binary weight quantization'
+ self.weight_quantizer = BinaryQuantizer.apply
+ # for embedding, we always use token-wise quantization
+ self.weight_quantize_num_groups = self.weight.size(0)
+
+ def fix_weight_quantization(self):
+ self.weight.data = self.weight_quantizer(self.weight,
+ self.weight.target_bits,
+ None,
+ None,
+ self.weight_quantize_num_groups).data
+ self.weight_quantization_enabled_in_forward = False
+ return None
+
+ def forward(self, input):
+ if self.weight_quantization_enabled_in_forward and self.weight_quantization_enabled:
+ weight = self.weight_quantizer(self.weight,
+ self.weight.target_bits,
+ None,
+ None,
+ self.weight_quantize_num_groups)
+ else:
+ weight = self.weight
+
+ out = nn.functional.embedding(input,
+ weight,
+ self.padding_idx,
+ self.max_norm,
+ self.norm_type,
+ self.scale_grad_by_freq,
+ self.sparse)
+ return out
+
+
+class LinearLayer_Compress(nn.Linear):
+ """
+ Linear layer with compression.
+ """
+ def __init__(self, *kargs, bias=True):
+ super(LinearLayer_Compress, self).__init__(*kargs, bias=bias)
+ self.sparse_pruning_method = None
+ self.row_pruning_method = None
+ self.head_pruning_method = None
+ self.activation_quantization_method = None
+ self.weight.start_bits = None
+ self.weight.target_bits = None
+ self.weight.q_period = None
+ self.weight_quantization_enabled_in_forward = False
+ self.weight_quantization_enabled = False
+ self.sparse_pruning_enabled = False
+ self.row_pruning_enabled = False
+ self.head_pruning_enabled = False
+ self.activation_quantization_enabled = False
+
+ def extra_repr(self):
+ return 'in_features={}, out_features={}, bias={}, sparse pruning={}, row pruning={}, head pruning={}, activation quantization={}, weight_quantization={}'.format(
+ self.in_features, self.out_features, self.bias is not None, self.sparse_pruning_method is not None, \
+ self.row_pruning_method is not None, self.head_pruning_method is not None, self.activation_quantization_method is not None, self.weight.target_bits)
+
+ def enable_sparse_pruning(self, ratio, method):
+ # Here, we support two cases: L1 norm based pruning and topk based pruning
+ self.sparse_pruning_ratio = ratio
+ self.sparse_pruning_method = method
+ if method == 'l1':
+ weight_norm = torch.abs(self.weight.data)
+ mask = TopKBinarizer.apply(weight_norm, self.sparse_pruning_ratio, False)
+ mask = mask.view(self.weight.size())
+ mask = mask.to(self.weight.device)
+ elif method == 'topk':
+ self.sparse_mask_scores = nn.Parameter(torch.Tensor(self.weight.size()))
+ self.sparse_mask_scores.data = self.sparse_mask_scores.data.to(
+ self.weight.device)
+ init.kaiming_uniform_(self.sparse_mask_scores, a=math.sqrt(5))
+ mask = None
+ else:
+ raise NotImplementedError
+
+ self.register_buffer('sparse_pruning_mask', mask)
+
+ def enable_row_pruning(self, ratio, method):
+ # Here, we support two cases: L1 norm based pruning and topk based pruning
+ self.row_pruning_ratio = ratio
+ self.row_pruning_method = method
+
+ if method == 'l1':
+            # compute the L1 norm of each row
+ weight_norm = torch.norm(self.weight.data, p=1, dim=1)
+ mask = TopKBinarizer.apply(weight_norm, self.row_pruning_ratio, False)
+ mask = mask.view(-1, 1)
+ mask = mask.to(self.weight.device)
+ elif method == 'topk':
+ self.row_mask_scores = nn.Parameter(torch.Tensor(self.weight.size(0), 1))
+ self.row_mask_scores.data = self.row_mask_scores.data.to(self.weight.device)
+ init.kaiming_uniform_(self.row_mask_scores, a=math.sqrt(5))
+ mask = None
+ else:
+ raise NotImplementedError
+
+ self.register_buffer('row_pruning_mask', mask)
+
+ def enable_head_pruning(self, ratio, method, num_heads):
+ # Here, we support only topk based pruning
+ self.num_heads = num_heads
+ self.head_pruning_ratio = ratio
+ self.head_pruning_method = method
+
+ if method not in ['topk']:
+ raise NotImplementedError
+ else:
+ self.head_pruning_ratio = ratio
+ self.head_pruning_scores = nn.Parameter(torch.Tensor(
+ 1,
+ self.num_heads)) # we apply the pruning to O matrix
+ self.head_pruning_scores.data = self.head_pruning_scores.data.to(
+ self.weight.device)
+ init.kaiming_uniform_(self.head_pruning_scores, a=math.sqrt(5))
+
+ def fix_sparse_pruning_helper(self):
+ mask = self.get_mask(pruning_type='sparse')
+ self.weight.data = self.weight.data * mask
+ del self.sparse_pruning_mask
+ if self.sparse_pruning_method == 'topk':
+ del self.sparse_mask_scores
+ self.sparse_pruning_method = None
+ self.sparse_pruning_enabled = False
+ return None
+
+ def fix_row_col_pruning_helper(self, mask=None, dim_reduction=False):
+ # This function is used for row/col pruning
+ # particularly, if we have two back-to-back layers, F1 and F2; when
+ # we remove rows from F1, we also need to remove columns from F2
+ # However, if we only have one layer, F1, then we only need to mask pruned
+ # rows as 0 in F1
+ if mask is None:
+ mask = self.get_mask(pruning_type='row').bool()
+ if dim_reduction:
+ start_bits = self.weight.start_bits
+ target_bits = self.weight.target_bits
+ q_period = self.weight.q_period
+ self.weight = nn.Parameter(self.weight.data[mask.view(-1), :])
+ self.weight.start_bits = start_bits
+ self.weight.target_bits = target_bits
+ self.weight.q_period = q_period
+ if self.bias is not None:
+ self.bias = nn.Parameter(self.bias.data[mask.view(-1)])
+ self.out_features = self.weight.size(0)
+ else:
+ self.weight.data = self.weight.data * mask.view(-1, 1)
+ if self.bias is not None:
+ self.bias.data = self.bias.data * mask.view(-1)
+
+ del self.row_pruning_mask
+ if self.row_pruning_method == 'topk':
+ del self.row_mask_scores
+ self.row_pruning_method = None
+ else:
+ # this is generally for column pruning
+ start_bits = self.weight.start_bits
+ target_bits = self.weight.target_bits
+ q_period = self.weight.q_period
+ self.weight = nn.Parameter(self.weight.data[:, mask.view(-1)])
+ self.weight.start_bits = start_bits
+ self.weight.target_bits = target_bits
+ self.weight.q_period = q_period
+ self.in_features = self.weight.size(1)
+ mask = None
+ self.row_pruning_enabled = False
+ return mask
+
+ def fix_head_pruning_helper(self, mask=None, num_heads=None, dim_reduction=False):
+        # Similar to row/col pruning, head pruning also needs to prune the QKV matrices associated with the O matrix
+ num_heads = num_heads if num_heads else self.num_heads
+ if mask is None:
+ if self.head_pruning_method == 'topk':
+ mask = self.get_mask(pruning_type='head').bool()
+ if dim_reduction:
+ shape = self.weight.size(0)
+ start_bits = self.weight.start_bits
+ target_bits = self.weight.target_bits
+ q_period = self.weight.q_period
+ self.weight = nn.Parameter(self.weight.data.t().reshape(num_heads, -1)[mask.view(-1), :].reshape(-1, shape).t())
+ self.weight.start_bits = start_bits
+ self.weight.target_bits = target_bits
+ self.weight.q_period = q_period
+ else:
+
+ shape = self.weight.size()
+ self.weight.data = (self.weight.data.t().reshape(self.num_heads,
+ -1) *
+ mask.view(-1,
+ 1)).reshape(shape[1],
+ shape[0]).t()
+
+ if self.head_pruning_method == 'topk':
+ del self.head_pruning_scores
+ self.head_pruning_method = None
+ else:
+ raise NotImplementedError
+ else:
+ start_bits = self.weight.start_bits
+ target_bits = self.weight.target_bits
+ q_period = self.weight.q_period
+ shape = self.weight.size(1)
+ self.weight = nn.Parameter(self.weight.data.reshape(num_heads, -1)[mask.view(-1), :].reshape(-1, shape))
+ self.weight.start_bits = start_bits
+ self.weight.target_bits = target_bits
+ self.weight.q_period = q_period
+ if self.bias is not None:
+ self.bias = nn.Parameter(self.bias.data.reshape(num_heads, -1)[mask.view(-1), :].reshape(-1))
+ self.head_pruning_enabled = False
+ return mask
+
+ def get_mask(self, pruning_type='row'):
+ if pruning_type == 'sparse':
+ if self.sparse_pruning_method == 'l1':
+ return self.sparse_pruning_mask.to(self.weight.device)
+ elif self.sparse_pruning_method == 'topk':
+ return TopKBinarizer.apply(self.sparse_mask_scores,
+ self.sparse_pruning_ratio,
+ False)
+ else:
+ raise NotImplementedError
+ if pruning_type == 'row':
+ if self.row_pruning_method == 'l1':
+ return self.row_pruning_mask.to(self.weight.device)
+ elif self.row_pruning_method == 'topk':
+ return TopKBinarizer.apply(self.row_mask_scores,
+ self.row_pruning_ratio,
+ False)
+ else:
+ raise NotImplementedError
+ elif pruning_type == 'head':
+ if self.head_pruning_method == 'topk':
+ return TopKBinarizer.apply(self.head_pruning_scores,
+ self.head_pruning_ratio,
+ False)
+ else:
+ raise NotImplementedError
+ else:
+ raise NotImplementedError
+
+ def enable_weight_quantization(self,
+ start_bits,
+ target_bits,
+ quantization_period,
+ weight_quantization_enabled_in_forward,
+ quantization_type,
+ num_groups):
+ self.weight.start_bits = start_bits
+ self.weight.target_bits = target_bits
+ self.weight.q_period = quantization_period
+ self.weight_quantization_enabled_in_forward = weight_quantization_enabled_in_forward
+ if self.weight_quantization_enabled_in_forward:
+ logger.warning(
+                "************ Many MoQ features are not supported in quantize_weight_in_forward mode; please consider using the DS-FP16 optimizer ************"
+ )
+ if self.weight.target_bits >= 3:
+ if quantization_type == 'symmetric':
+ self.weight_quantizer = SymQuantizer.apply
+ else:
+ self.weight_quantizer = AsymQuantizer.apply
+ elif self.weight.target_bits == 2:
+ assert quantization_type == 'symmetric', 'Only symmetric quantization is supported for ternary weight quantization'
+ self.weight_quantizer = TernaryQuantizer.apply
+ elif self.weight.target_bits == 1:
+ assert quantization_type == 'symmetric', 'Only symmetric quantization is supported for binary weight quantization'
+ self.weight_quantizer = BinaryQuantizer.apply
+ self.weight_quantize_num_groups = num_groups
+
+ def fix_weight_quantization(self):
+ self.weight.data = self.weight_quantizer(self.weight,
+ self.weight.target_bits,
+ None,
+ None,
+ self.weight_quantize_num_groups).data
+ self.weight_quantization_enabled_in_forward = False
+ return None
+
+ def enable_activation_quantization(self, bits, quantization_type, range_calibration):
+        assert bits in [4, 8], 'Only 4/8-bit activation quantization is supported for now'
+ self.activation_quantization_bits = bits
+ self.activation_quantization_method = f"{quantization_type}_{range_calibration}"
+ if range_calibration == 'static':
+ self.activation_quantizer = QuantAct(quant_mode=quantization_type)
+ else:
+ if quantization_type == 'symmetric':
+ self.activation_quantizer = SymQuantizer.apply
+ else:
+ self.activation_quantizer = AsymQuantizer.apply
+
+ def head_pruning_reshape(self, w, mask):
+ shape = w.shape
+ return (w.t().reshape(self.num_heads,
+ -1) * mask.view(-1,
+ 1)).reshape(shape[1],
+ shape[0]).t()
+
+ def forward(self, input, skip_bias_add=False):
+
+ if self.weight_quantization_enabled_in_forward and self.weight_quantization_enabled:
+ weight = self.weight_quantizer(self.weight,
+ self.weight.target_bits,
+ None,
+ None,
+ self.weight_quantize_num_groups)
+ bias = self.bias
+ else:
+ weight = self.weight
+ bias = self.bias
+
+ if self.sparse_pruning_enabled and self.sparse_pruning_method:
+ mask = self.get_mask(pruning_type='sparse')
+ weight = weight * mask.view(self.weight.size())
+
+ if self.row_pruning_enabled and self.row_pruning_method:
+ mask = self.get_mask(pruning_type='row')
+ weight = weight * mask.view(-1, 1)
+ if bias is not None:
+ bias = bias * mask.view(-1)
+
+ if self.head_pruning_enabled and self.head_pruning_method:
+ mask = self.get_mask(pruning_type='head')
+ weight = self.head_pruning_reshape(weight, mask)
+
+ if self.activation_quantization_enabled:
+ if 'dynamic' in self.activation_quantization_method:
+ num_groups = input.numel() // input.size(-1)
+ else:
+ num_groups = 1
+ input = self.activation_quantizer(input,
+ self.activation_quantization_bits,
+ None,
+ None,
+ num_groups)
+
+ if skip_bias_add:
+ # used for mpu linear layers
+ output = nn.functional.linear(input, weight, None)
+ return output, bias
+ else:
+ output = nn.functional.linear(input, weight, bias)
+ return output
+
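+# Illustrative sketch of how a LinearLayer_Compress is exercised (hypothetical standalone
+# usage; in practice init_compression() installs these layers and the compression scheduler
+# toggles the *_enabled flags). The ratio below stands in for a group's dense ratio taken
+# from the DeepSpeed config:
+#   layer = LinearLayer_Compress(768, 768, bias=True)
+#   layer.enable_row_pruning(ratio=0.5, method='topk')
+#   layer.row_pruning_enabled = True                 # normally flipped by the scheduler
+#   y = layer(torch.randn(2, 768))                   # forward applies the learned row mask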
+
+class Conv2dLayer_Compress(nn.Conv2d):
+ """
+ Conv2D layer with compression.
+ """
+ def __init__(self, *kargs):
+ super(Conv2dLayer_Compress, self).__init__(*kargs)
+ self.sparse_pruning_method = None
+ self.channel_pruning_method = None
+ self.activation_quantization_method = None
+ self.weight.start_bits = None
+ self.weight.target_bits = None
+ self.weight.q_period = None
+ self.weight_quantization_enabled_in_forward = False
+ self.sparse_pruning_enabled = False
+ self.channel_pruning_enabled = False
+ self.activation_quantization_enabled = False
+
+ def __repr__(self):
+ s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}'
+ ', stride={stride}')
+ if self.padding != (0, ) * len(self.padding):
+ s += ', padding={padding}'
+ if self.dilation != (1, ) * len(self.dilation):
+ s += ', dilation={dilation}'
+ if self.output_padding != (0, ) * len(self.output_padding):
+ s += ', output_padding={output_padding}'
+ if self.groups != 1:
+ s += ', groups={groups}'
+ if self.bias is None:
+ s += ', bias=False'
+ if self.padding_mode != 'zeros':
+ s += ', padding_mode={padding_mode}'
+ output = s.format(**self.__dict__)
+
+ return output + ' sparse pruning={}, channel pruning={}, activation quantization={}, weight_quantization={}'.format(
+ self.sparse_pruning_method is not None,
+ self.channel_pruning_method is not None,
+ self.activation_quantization_method is not None,
+ self.weight.target_bits)
+
+ def enable_sparse_pruning(self, ratio, method):
+ self.sparse_pruning_ratio = ratio
+ self.sparse_pruning_method = method
+ if method == 'l1':
+ weight_norm = torch.abs(self.weight.data)
+ mask = TopKBinarizer.apply(weight_norm, self.sparse_pruning_ratio, False)
+ mask = mask.view(self.weight.size())
+ mask = mask.to(self.weight.device)
+ elif method == 'topk':
+ self.sparse_mask_scores = nn.Parameter(torch.Tensor(self.weight.size()))
+ self.sparse_mask_scores.data = self.sparse_mask_scores.data.to(
+ self.weight.device)
+ init.kaiming_uniform_(self.sparse_mask_scores, a=math.sqrt(5))
+ mask = None
+ else:
+ raise NotImplementedError
+
+ self.register_buffer('sparse_pruning_mask', mask)
+
+ def enable_channel_pruning(self, ratio, method):
+ # Here, we support two cases: L1 norm based pruning and topk based pruning
+ self.channel_pruning_ratio = ratio
+ self.channel_pruning_method = method
+
+ if method == 'l1':
+            # compute the L1 norm of each conv2d kernel (over the last three dimensions)
+ weight_norm = torch.norm(self.weight.data, p=1, dim=[1, 2, 3])
+ mask = TopKBinarizer.apply(weight_norm, self.channel_pruning_ratio, False)
+ mask = mask.view(-1, 1, 1, 1)
+ mask = mask.to(self.weight.device)
+ elif method == 'topk':
+ self.channel_mask_scores = nn.Parameter(
+ torch.Tensor(self.weight.size(0),
+ 1,
+ 1,
+ 1))
+ self.channel_mask_scores.data = self.channel_mask_scores.data.to(
+ self.weight.device)
+ init.kaiming_uniform_(self.channel_mask_scores, a=math.sqrt(5))
+ mask = None
+ else:
+ raise NotImplementedError
+
+ self.register_buffer('channel_pruning_mask', mask)
+
+ def fix_sparse_pruning_helper(self):
+ mask = self.get_mask(pruning_type='sparse')
+ self.weight.data = self.weight.data * mask
+ del self.sparse_pruning_mask
+ if self.sparse_pruning_method == 'topk':
+ del self.sparse_mask_scores
+ self.sparse_pruning_method = None
+ self.sparse_pruning_enabled = False
+ return None
+
+ def fix_channel_pruning_helper(self, mask=None, dim_reduction=False):
+ if mask is None:
+ if self.channel_pruning_method in ['l1', 'topk']:
+ mask = self.get_mask(pruning_type='channel').bool()
+ if dim_reduction:
+ start_bits = self.weight.start_bits
+ target_bits = self.weight.target_bits
+ q_period = self.weight.q_period
+ self.weight = nn.Parameter(self.weight.data[mask.view(-1), ...])
+ self.weight.start_bits = start_bits
+ self.weight.target_bits = target_bits
+ self.weight.q_period = q_period
+ if self.bias is not None:
+ self.bias = nn.Parameter(self.bias.data[mask.view(-1)])
+ else:
+ self.weight.data = self.weight.data * mask.view(-1, 1, 1, 1)
+ if self.bias is not None:
+ self.bias.data = self.bias.data * mask.view(-1)
+ del self.channel_pruning_mask
+ if self.channel_pruning_method == 'topk':
+ del self.channel_mask_scores
+ self.channel_pruning_method = None
+ else:
+ raise NotImplementedError
+ else:
+ start_bits = self.weight.start_bits
+ target_bits = self.weight.target_bits
+ q_period = self.weight.q_period
+ self.weight = nn.Parameter(self.weight.data[:, mask.view(-1), ...])
+ self.weight.start_bits = start_bits
+ self.weight.target_bits = target_bits
+ self.weight.q_period = q_period
+ mask = None
+ self.channel_pruning_enabled = False
+ return mask
+
+ def get_mask(self, pruning_type='sparse'):
+ if pruning_type == 'sparse':
+ if self.sparse_pruning_method == 'l1':
+ return self.sparse_pruning_mask.to(self.weight.device)
+ elif self.sparse_pruning_method == 'topk':
+ return TopKBinarizer.apply(self.sparse_mask_scores,
+ self.sparse_pruning_ratio,
+ False)
+ else:
+ raise NotImplementedError
+ elif pruning_type == 'channel':
+ if self.channel_pruning_method == 'l1':
+ return self.channel_pruning_mask.to(self.weight.device)
+ elif self.channel_pruning_method == 'topk':
+ return TopKBinarizer.apply(self.channel_mask_scores,
+ self.channel_pruning_ratio,
+ False)
+ else:
+ raise NotImplementedError
+ else:
+ raise NotImplementedError
+
+ def fix_weight_quantization(self):
+ self.weight.data = self.weight_quantizer(self.weight,
+ self.weight.target_bits,
+ None,
+ None,
+ self.weight_quantize_num_groups).data
+ self.weight_quantization_enabled_in_forward = False
+ return None
+
+ def enable_weight_quantization(self,
+ start_bits,
+ target_bits,
+ quantization_period,
+ weight_quantization_enabled_in_forward,
+ quantization_type,
+ num_groups):
+ self.weight.start_bits = start_bits
+ self.weight.target_bits = target_bits
+ self.weight.q_period = quantization_period
+ self.weight_quantization_enabled_in_forward = weight_quantization_enabled_in_forward
+ if self.weight_quantization_enabled_in_forward:
+            assert self.weight.target_bits >= 4, 'Only >=4-bit weight quantization is supported during the forward pass for now'
+ logger.warning(
+                "************ Many MoQ features are not supported in quantize_weight_in_forward mode; please consider using the DS-FP16 optimizer ************"
+ )
+ if quantization_type == 'symmetric':
+ self.weight_quantizer = SymQuantizer.apply
+ else:
+ self.weight_quantizer = AsymQuantizer.apply
+ self.weight_quantize_num_groups = num_groups
+
+ def enable_activation_quantization(self, bits, quantization_type, range_calibration):
+        assert bits in [4, 8], 'Only 4/8-bit activation quantization is supported for now'
+ self.activation_quantization_bits = bits
+ self.activation_quantization_method = f"{quantization_type}_{range_calibration}"
+ if range_calibration == 'static':
+ self.activation_quantizer = QuantAct(quant_mode=quantization_type)
+ else:
+ if quantization_type == 'symmetric':
+ self.activation_quantizer = SymQuantizer.apply
+ else:
+ self.activation_quantizer = AsymQuantizer.apply
+
+ def forward(self, input):
+
+ if self.weight_quantization_enabled_in_forward and self.weight_quantization_enabled:
+ weight = self.weight_quantizer(self.weight,
+ self.weight.target_bits,
+ None,
+ None,
+ self.weight_quantize_num_groups)
+ bias = self.bias
+ else:
+ weight = self.weight
+ bias = self.bias
+
+ if self.sparse_pruning_enabled and self.sparse_pruning_method:
+ mask = self.get_mask(pruning_type='sparse')
+ weight = weight * mask.view(self.weight.size())
+
+ if self.channel_pruning_enabled:
+ mask = self.get_mask(pruning_type='channel')
+ weight = weight * mask.view(-1, 1, 1, 1)
+ if bias is not None:
+ bias = bias * mask.view(-1)
+
+ if self.activation_quantization_enabled:
+ if 'dynamic' in self.activation_quantization_method:
+ num_groups = input.numel() // input[0].numel()
+ else:
+ num_groups = 1
+ input = self.activation_quantizer(input,
+ self.activation_quantization_bits,
+ None,
+ None,
+ num_groups)
+
+ return nn.functional.conv2d(input,
+ weight,
+ bias,
+ self.stride,
+ self.padding,
+ self.dilation,
+ self.groups)
+
+
+class BNLayer_Compress(nn.BatchNorm2d):
+ def fix_channel_pruning_helper(self, mask, dim_reduction=True):
+ self.weight = nn.Parameter(self.weight.data[mask.view(-1)])
+ self.bias = nn.Parameter(self.bias.data[mask.view(-1)])
+ self.running_mean = self.running_mean[mask.view(-1)]
+ self.running_var = self.running_var[mask.view(-1)]
+
+
+def _reduce(input_):
+    """All-reduce the input tensor across the model parallel group."""
+ group = g_mpu.get_model_parallel_group()
+
+ # Bypass the function if we are using only 1 GPU.
+ if dist.get_world_size(group=group) == 1:
+ return input_
+
+ # All-reduce.
+ dist.all_reduce(input_, group=group)
+
+ return input_
+
+
+def split_tensor_along_last_dim(tensor, num_partitions, contiguous_split_chunks=False):
+ """Split a tensor along its last dimension.
+ Arguments:
+ tensor: input tensor.
+ num_partitions: number of partitions to split the tensor
+ contiguous_split_chunks: If True, make each chunk contiguous
+ in memory.
+ """
+ # Get the size and dimension.
+ last_dim = tensor.dim() - 1
+ assert tensor.size()[last_dim] % num_partitions == 0
+ last_dim_size = tensor.size()[last_dim] // num_partitions
+ # Split.
+ tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
+ # Note: torch.split does not create contiguous tensors by default.
+ if contiguous_split_chunks:
+ return tuple(chunk.contiguous() for chunk in tensor_list)
+
+ return tensor_list
+
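+# For example (illustrative), splitting a [4, 6] tensor into 3 partitions along the last
+# dimension yields three [4, 2] chunks:
+#   chunks = split_tensor_along_last_dim(torch.arange(24).view(4, 6), 3)
+#   assert len(chunks) == 3 and chunks[0].shape == (4, 2)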
+
+def _split(input_):
+ """Split the tensor along its last dimension and keep the
+ corresponding slice."""
+ group = g_mpu.get_model_parallel_group()
+
+ # Bypass the function if we are using only 1 GPU.
+ if dist.get_world_size(group=group) == 1:
+ return input_
+
+ # Split along last dimension.
+ world_size = dist.get_world_size(group=group)
+ input_list = split_tensor_along_last_dim(input_, world_size)
+
+ # Note: torch.split does not create contiguous tensors by default.
+ rank = dist.get_rank(group=group)
+ output = input_list[rank].contiguous()
+
+ return output
+
+
+def _gather(input_):
+    """Gather tensors and concatenate along the last dimension."""
+ group = g_mpu.get_model_parallel_group()
+
+ # Bypass the function if we are using only 1 GPU.
+ if dist.get_world_size(group=group) == 1:
+ return input_
+
+ # Size and dimension.
+ last_dim = input_.dim() - 1
+ rank = dist.get_rank(group=group)
+ world_size = dist.get_world_size(group=group)
+
+ tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
+ tensor_list[rank] = input_
+ dist.all_gather(tensor_list, input_, group=group)
+
+ # Note: torch.cat already creates a contiguous tensor.
+ output = torch.cat(tensor_list, dim=last_dim).contiguous()
+
+ return output
+
+
+class _CopyToModelParallelRegion(torch.autograd.Function):
+ """Pass the input to the model parallel region."""
+ @staticmethod
+ def forward(ctx, input_):
+ return input_
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ return _reduce(grad_output)
+
+
+class _ReduceFromModelParallelRegion(torch.autograd.Function):
+    """All-reduce the input from the model parallel region."""
+ @staticmethod
+ def forward(ctx, input_):
+ return _reduce(input_)
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ return grad_output
+
+
+class _ScatterToModelParallelRegion(torch.autograd.Function):
+    """Split the input and keep only the chunk corresponding to the rank."""
+ @staticmethod
+ def forward(ctx, input_):
+ return _split(input_)
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ return _gather(grad_output)
+
+
+class _GatherFromModelParallelRegion(torch.autograd.Function):
+    """Gather the input from the model parallel region and concatenate."""
+ @staticmethod
+ def forward(ctx, input_):
+ return _gather(input_)
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ return _split(grad_output)
+
+
+# -----------------
+# Helper functions.
+# -----------------
+
+
+def copy_to_model_parallel_region(input_):
+ return _CopyToModelParallelRegion.apply(input_)
+
+
+def reduce_from_model_parallel_region(input_):
+ return _ReduceFromModelParallelRegion.apply(input_)
+
+
+def scatter_to_model_parallel_region(input_):
+ return _ScatterToModelParallelRegion.apply(input_)
+
+
+def gather_from_model_parallel_region(input_):
+ return _GatherFromModelParallelRegion.apply(input_)
+
+
+class ColumnParallelLinear_Compress(LinearLayer_Compress):
+ def __init__(self,
+ mpu,
+ input_size,
+ output_size,
+ bias=True,
+ gather_output=True,
+ skip_bias_add=False):
+ # Keep input parameters
+ global g_mpu
+ g_mpu = mpu
+ self.input_size = input_size
+ self.output_size = output_size
+ self.gather_output = gather_output
+ self.skip_bias_add = skip_bias_add
+
+ # Divide the weight matrix along the last dimension.
+ world_size = mpu.get_model_parallel_world_size()
+ assert output_size % world_size == 0
+ self.output_size_per_partition = output_size // world_size
+
+ super(ColumnParallelLinear_Compress,
+ self).__init__(self.input_size,
+ self.output_size_per_partition,
+ bias=bias)
+
+ def forward(self, input_):
+ # Set up backprop all-reduce.
+ input_parallel = copy_to_model_parallel_region(input_)
+ # Matrix multiply.
+ if self.skip_bias_add:
+ output_parallel, bias = super().forward(input_parallel, True)
+ else:
+ output_parallel = super().forward(input_parallel)
+ bias = None
+ if self.gather_output:
+ # All-gather across the partitions.
+ output = gather_from_model_parallel_region(output_parallel)
+ else:
+ output = output_parallel
+ return output, bias
+
+
+class RowParallelLinear_Compress(LinearLayer_Compress):
+ def __init__(self,
+ mpu,
+ input_size,
+ output_size,
+ bias=True,
+ input_is_parallel=False,
+ skip_bias_add=False):
+ # Keep input parameters
+ global g_mpu
+ g_mpu = mpu
+ self.input_size = input_size
+ self.output_size = output_size
+ self.input_is_parallel = input_is_parallel
+ self.skip_bias_add = skip_bias_add
+
+ # Divide the weight matrix along the last dimension.
+ world_size = mpu.get_model_parallel_world_size()
+ assert input_size % world_size == 0
+ self.input_size_per_partition = input_size // world_size
+
+ super(RowParallelLinear_Compress,
+ self).__init__(self.input_size_per_partition,
+ self.output_size,
+ bias=bias)
+
+ def forward(self, input_):
+ # Set up backprop all-reduce.
+ if self.input_is_parallel:
+ input_parallel = input_
+ else:
+ input_parallel = scatter_to_model_parallel_region(input_)
+ # Matrix multiply.
+ output_parallel, bias = super().forward(input_parallel, True)
+
+ # All-reduce across all the partitions.
+ output_ = reduce_from_model_parallel_region(output_parallel)
+ if not self.skip_bias_add:
+ if bias is not None:
+ output = output_ + bias
+ else:
+ output = output_
+ output_bias = None
+ else:
+ output = output_
+ output_bias = bias
+ return output, output_bias
diff --git a/deepspeed/compression/compress.py b/deepspeed/compression/compress.py
new file mode 100644
index 000000000000..459b17075526
--- /dev/null
+++ b/deepspeed/compression/compress.py
@@ -0,0 +1,234 @@
+import torch
+import re
+from .helper import compression_preparation, fix_compression, recursive_getattr, is_module_compressible
+from .config import get_compression_config
+from ..runtime.config_utils import dict_raise_error_on_duplicate_keys
+from .constants import *
+import os
+import json
+
+from deepspeed.utils import logger
+
+
+def check_deepspeed_config(config):
+ if isinstance(config, dict):
+ return config
+ elif os.path.exists(config):
+ return json.load(open(config,
+ "r"),
+ object_pairs_hook=dict_raise_error_on_duplicate_keys)
+ else:
+ raise ValueError(
+ f"Expected a string path to an existing deepspeed config, or a dictionary. Received: {config}"
+ )
+
+
+def get_module_name(group_name,
+ model,
+ key_word,
+ exist_module_name,
+ mpu=None,
+ verbose=True):
+ '''
+    Get the names of the modules in the model that match the key_word provided by the user
+ '''
+ return_module_name = []
+ for name, module in model.named_modules():
+
+ module_check = is_module_compressible(module, mpu)
+
+ if re.search(key_word, name) is not None and module_check:
+ if name in exist_module_name and verbose:
+ # logger.warning
+ raise ValueError(
+ f"{name} is already added to compression, please check your config file for {group_name}."
+ )
+ if name not in exist_module_name:
+ exist_module_name.add(name)
+ return_module_name.append(name)
+ return return_module_name, exist_module_name
+
+
+def get_compress_methods(model, compress_methods, mpu=None):
+ # extract the compression module for each method in compress_methods
+ layer_added_compress_methods = []
+ for method, method_content in compress_methods.items():
+ if LAYER_REDUCTION in method:
+ continue
+        # loop over the different methods, e.g., weight quantization, activation quantization, etc.
+ exist_module_name = set()
+ shared_parameters = method_content[
+ SHARED_PARAMETERS] # get all the shared parameters
+ for group_name, method_parameters in method_content[DIFFERENT_GROUPS].items():
+            # loop over the different groups, e.g., weight quantization group 1, weight quantization group 2, etc.
+ module_name_list = []
+ related_module_name_list = []
+ if method_parameters[DIFFERENT_GROUPS_RELATED_MODULE_SCOPE]:
+                # this is used for head/row/channel pruning: if users provide the related module scope,
+                # we can shrink the layer dimensions for them; otherwise we just mask the pruned entries as zeros
+ for key_word, related_key_words in zip(method_parameters[DIFFERENT_GROUPS_MODULE_SCOPE], method_parameters[DIFFERENT_GROUPS_RELATED_MODULE_SCOPE]):
+ module_name, exist_module_name = get_module_name(group_name, model, key_word, exist_module_name, mpu=mpu)
+ module_name_list.append(module_name)
+ tmp_related_module_name_list = []
+ for rkw in related_key_words:
+                        # the related key word can be a list, for instance QKV for the O matrix in attention
+ module_name, _ = get_module_name(group_name, model, rkw, set(), mpu=mpu)
+ tmp_related_module_name_list.append(module_name)
+ related_module_name_list.append(tmp_related_module_name_list)
+ else:
+ for key_word in method_parameters[DIFFERENT_GROUPS_MODULE_SCOPE]:
+ module_name, exist_module_name = get_module_name(group_name, model, key_word, exist_module_name, mpu=mpu)
+ module_name_list.append(module_name)
+
+ if module_name_list:
+ # combine shared parameters with each group
+ combined_method_parameters = {
+ **(method_parameters.copy().pop(DIFFERENT_GROUPS_PARAMETERS)),
+ **shared_parameters
+ }
+ compression_item = [
+ module_name_list,
+ related_module_name_list,
+ {
+ method: combined_method_parameters
+ }
+ ]
+ layer_added_compress_methods.append(compression_item)
+ return layer_added_compress_methods
+
+
+def init_compression(model, deepspeed_config, teacher_model=None, mpu=None):
+ """
+    Compress a model: replace linear/conv2d layers with DeepSpeed compression-aware modules
+    Args:
+        model (`torch.nn.Module`)
+            The model to compress.
+        deepspeed_config (`DeepSpeedConfig`)
+            The path of the DeepSpeed config (or the config dict itself)
+        teacher_model (`torch.nn.Module`, optional)
+            The teacher model; required when layer reduction is enabled
+        mpu
+            The mpu module for Row/Column parallelism
+ """
+ compress_methods = get_compression_config(check_deepspeed_config(deepspeed_config))
+ if hasattr(model, 'module'):
+ c_model = model.module
+ else:
+ c_model = model
+
+ # For layer reduction
+ if compress_methods[LAYER_REDUCTION][LAYER_REDUCTION_ENABLED]:
+ assert teacher_model is not None, "Teacher model is required for layer reduction"
+ student_initialization(c_model, teacher_model, deepspeed_config)
+
+ layer_added_compress_methods = get_compress_methods(c_model,
+ compress_methods,
+ mpu=mpu)
+ compression_preparation(c_model, layer_added_compress_methods, mpu)
+
+ return model
+
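+# A typical end-to-end workflow (hedged sketch; "ds_config.json" is a placeholder path):
+#   model = init_compression(model, "ds_config.json")   # swap in compression-aware layers
+#   ... fine-tune / train the model as usual ...
+#   model = redundancy_clean(model, "ds_config.json")   # fold masks in and shrink pruned dims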
+
+def redundancy_clean(model, deepspeed_config, mpu=None):
+ """
+    Remove the redundant (e.g., pruned) parts of a compressed model
+    Args:
+        model (`torch.nn.Module`)
+            The model to clean.
+        deepspeed_config (`DeepSpeedConfig`)
+            The path of the DeepSpeed config (or the config dict itself)
+        mpu
+            The mpu module for Row/Column parallelism
+ """
+ compress_methods = get_compression_config(check_deepspeed_config(deepspeed_config))
+ if hasattr(model, 'module'):
+ c_model = model.module
+ else:
+ c_model = model
+
+ layer_added_compress_methods_tmp = get_compress_methods(c_model,
+ compress_methods,
+ mpu=mpu)
+ # sort methods
+ order_list = [
+ WEIGHT_QUANTIZATION,
+ SPARSE_PRUNING,
+ ROW_PRUNING,
+ HEAD_PRUNING,
+ CHANNEL_PRUNING,
+ ACTIVATION_QUANTIZATION
+ ]
+ layer_added_compress_methods = sorted(
+ layer_added_compress_methods_tmp,
+ key=lambda x: order_list.index(list(x[2].keys())[0]))
+
+ for module_name_lists, related_module_name_lists, compression_technique in layer_added_compress_methods:
+ stored_mask = []
+ need_mask = True if related_module_name_lists else False
+ for i, mnl in enumerate(module_name_lists):
+ for module_name in mnl:
+ mask = fix_compression(c_model,
+ module_name,
+ compression_technique,
+ dim_reduction=need_mask)
+ if need_mask:
+ stored_mask.append(mask)
+ if need_mask:
+ for rmnl in related_module_name_lists[i]:
+ for j, module_name in enumerate(rmnl):
+ mask = fix_compression(c_model,
+ module_name,
+ compression_technique,
+ mask=stored_mask[j],
+ dim_reduction=True)
+ return model
+
+
+def student_initialization(student_model, teacher_model, deepspeed_config):
+ '''
+    Given a student model and a teacher model, initialize the student's weights from the selected teacher layers/modules
+    Args:
+        student_model (`torch.nn.Module`)
+            The model whose weights will be updated
+        teacher_model (`torch.nn.Module`)
+            The model that guides the student's learning
+        deepspeed_config (`DeepSpeedConfig`)
+            The path of the DeepSpeed config (or the config dict itself)
+ '''
+ config = get_compression_config(check_deepspeed_config(deepspeed_config))
+ compress_methods = config[LAYER_REDUCTION]
+
+ module_name_prefix = compress_methods[MODULE_NAME_PREFIX]
+ teacher_layer = compress_methods[TEACHER_LAYER]
+ student_layer = [i for i in range(len(teacher_layer))]
+ other_module_name = compress_methods[OTHER_MODULE_NAME]
+ '''
+        name_prefix (`str`)
+            The prefix name before the layer #.
+            Example 1: bert.encoder.layer, the prefix name for a BERT-base model
+            Example 2: transformer.h, the Hugging Face prefix name for GPT-2
+        teacher_layer (`list of integers`)
+            The teacher layers used for the student's reinitialization
+            Example 1: [1,3,5,7,9] means we want to match the 2nd/4th/6th/8th/10th layers of the teacher to the first 5 layers of the student
+        student_layer (`list` or None)
+            The student layers that need to be reinitialized
+            Example 1: None means we want to reinitialize all the layers
+            Example 2: [0,1,2,3,4] means we want to reinitialize the first 5 layers
+        other_module_name (`list of string`)
+            The modules used for the student's reinitialization
+            Example 1: ['bert.pooler', 'bert.embeddings', 'classifier'] means we want to apply the weights in the teacher's embedding/pooler/classifier modules to the student
+            Example 2: ['transformer.w', 'transformer.ln_f', 'lm_head'] means we want to apply the weights in the teacher's embedding/ln_f/lm_head modules to the student
+        Note that teacher_layer should match student_layer in length
+ '''
+ assert len(student_layer) == len(teacher_layer)
+ for s_name, t_name in zip(student_layer, teacher_layer):
+ s_module = recursive_getattr(student_model,
+ module_name_prefix + '.' + str(s_name))
+ t_module = recursive_getattr(teacher_model,
+ module_name_prefix + '.' + str(t_name))
+ for s_param, t_param in zip(s_module.parameters(), t_module.parameters()):
+ s_param.data.copy_(t_param.data)
+ for name in other_module_name:
+ s_module = recursive_getattr(student_model, name)
+ t_module = recursive_getattr(teacher_model, name)
+ print(name)
+ for s_param, t_param in zip(s_module.parameters(), t_module.parameters()):
+ s_param.data.copy_(t_param.data)
diff --git a/deepspeed/compression/config.py b/deepspeed/compression/config.py
new file mode 100644
index 000000000000..d53246e2ed87
--- /dev/null
+++ b/deepspeed/compression/config.py
@@ -0,0 +1,490 @@
+from .constants import *
+import copy
+from ..runtime.config_utils import get_scalar_param
+
+
+def get_compression_config(param_dict):
+ #
+ output = {}
+
+ if COMPRESSION_TRAINING not in param_dict.keys():
+ param_dict[COMPRESSION_TRAINING] = {}
+ sub_param_dict = param_dict[COMPRESSION_TRAINING]
+ output[WEIGHT_QUANTIZATION] = get_weight_quantization(sub_param_dict)
+ output[ACTIVATION_QUANTIZATION] = get_activation_quantization(sub_param_dict)
+ output[SPARSE_PRUNING] = get_sparse_pruning(sub_param_dict)
+ output[ROW_PRUNING] = get_row_pruning(sub_param_dict)
+ output[HEAD_PRUNING] = get_head_pruning(sub_param_dict)
+ output[CHANNEL_PRUNING] = get_channel_pruning(sub_param_dict)
+
+ output[LAYER_REDUCTION] = get_layer_reduction(sub_param_dict)
+
+ return output
+
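+# The returned dict is keyed by technique (WEIGHT_QUANTIZATION, ACTIVATION_QUANTIZATION,
+# SPARSE_PRUNING, ROW_PRUNING, HEAD_PRUNING, CHANNEL_PRUNING, plus LAYER_REDUCTION); each
+# technique entry carries its SHARED_PARAMETERS and, when enabled, its DIFFERENT_GROUPS.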
+
+def get_layer_reduction(param_dict):
+ output = {}
+ output[LAYER_REDUCTION_ENABLED] = LAYER_REDUCTION_ENABLED_DEFAULT
+ if get_layer_reduction_enabled(param_dict):
+ output[LAYER_REDUCTION_ENABLED] = get_layer_reduction_enabled(param_dict)
+ for key, val in get_layer_reduction_params(param_dict).items():
+ output[key] = val
+ return output
+
+
+def get_layer_reduction_enabled(param_dict):
+ if LAYER_REDUCTION in param_dict.keys():
+ return get_scalar_param(param_dict[LAYER_REDUCTION],
+ LAYER_REDUCTION_ENABLED,
+ LAYER_REDUCTION_ENABLED_DEFAULT)
+ else:
+ return False
+
+
+def get_layer_reduction_params(param_dict):
+ if LAYER_REDUCTION in param_dict.keys():
+ layer_reduction_params = copy.copy(param_dict[LAYER_REDUCTION])
+ layer_reduction_params.pop(LAYER_REDUCTION_ENABLED)
+ return layer_reduction_params
+ else:
+ return False
+
+
+def get_quantize_enabled(param_dict):
+ if COMPRESSION_TRAINING not in param_dict.keys():
+ return False
+
+ sub_param_dict = param_dict[COMPRESSION_TRAINING]
+ output = get_weight_quantization_shared_parameters(sub_param_dict)
+ return output[WEIGHT_QUANTIZE_ENABLED]
+
+
+def get_weight_quantization(param_dict):
+ output = {}
+ if WEIGHT_QUANTIZATION not in param_dict.keys():
+ param_dict[WEIGHT_QUANTIZATION] = {SHARED_PARAMETERS: {}, DIFFERENT_GROUPS: {}}
+ sub_param_dict = param_dict[WEIGHT_QUANTIZATION]
+ # shared parameters
+ output[SHARED_PARAMETERS] = get_weight_quantization_shared_parameters(sub_param_dict)
+ # each sub-groups
+ if output[SHARED_PARAMETERS][WEIGHT_QUANTIZE_ENABLED]:
+        assert DIFFERENT_GROUPS in sub_param_dict.keys(), f"Weight Quantization is enabled, {DIFFERENT_GROUPS} must be specified"
+ output[DIFFERENT_GROUPS] = get_weight_quantization_different_groups(sub_param_dict)
+ return output
+
+
+def get_weight_quantization_shared_parameters(param_dict):
+ output = {}
+ if SHARED_PARAMETERS in param_dict.keys():
+ sub_param_dict = param_dict[SHARED_PARAMETERS]
+ output[WEIGHT_QUANTIZE_ENABLED] = get_scalar_param(
+ sub_param_dict,
+ WEIGHT_QUANTIZE_ENABLED,
+ WEIGHT_QUANTIZE_ENABLED_DEFAULT)
+ output[WEIGHT_QUANTIZE_KERNEL] = get_scalar_param(
+ sub_param_dict,
+ WEIGHT_QUANTIZE_KERNEL,
+ WEIGHT_QUANTIZE_KERNEL_DEFAULT)
+ output[WEIGHT_QUANTIZE_SCHEDULE_OFFSET] = get_scalar_param(
+ sub_param_dict,
+ WEIGHT_QUANTIZE_SCHEDULE_OFFSET,
+ WEIGHT_QUANTIZE_SCHEDULE_OFFSET_DEFAULT)
+ output[WEIGHT_QUANTIZE_GROUPS] = get_scalar_param(
+ sub_param_dict,
+ WEIGHT_QUANTIZE_GROUPS,
+ WEIGHT_QUANTIZE_GROUPS_DEFAULT)
+ output[WEIGHT_QUANTIZE_VERBOSE] = get_scalar_param(
+ sub_param_dict,
+ WEIGHT_QUANTIZE_VERBOSE,
+ WEIGHT_QUANTIZE_VERBOSE_DEFAULT)
+ output[WEIGHT_QUANTIZE_TYPE] = get_scalar_param(sub_param_dict,
+ WEIGHT_QUANTIZE_TYPE,
+ WEIGHT_QUANTIZE_TYPE_DEFAULT)
+ output[WEIGHT_QUANTIZE_IN_FORWARD_ENABLED] = get_scalar_param(
+ sub_param_dict,
+ WEIGHT_QUANTIZE_IN_FORWARD_ENABLED,
+ WEIGHT_QUANTIZE_IN_FORWARD_ENABLED_DEFAULT)
+ assert output[WEIGHT_QUANTIZE_TYPE] in [WEIGHT_QUANTIZE_SYMMETRIC, WEIGHT_QUANTIZE_ASYMMETRIC], f"Invalid weight quantize type. Supported types: [{WEIGHT_QUANTIZE_SYMMETRIC}, {WEIGHT_QUANTIZE_ASYMMETRIC}]"
+ output[WEIGHT_QUANTIZE_ROUNDING] = get_scalar_param(
+ sub_param_dict,
+ WEIGHT_QUANTIZE_ROUNDING,
+ WEIGHT_QUANTIZE_ROUNDING_DEFAULT)
+ assert output[WEIGHT_QUANTIZE_ROUNDING] in [WEIGHT_QUANTIZE_NEAREST_ROUNDING, WEIGHT_QUANTIZE_STOCHASTIC_ROUNDING], f"Invalid weight quantize rounding. Supported types: [{WEIGHT_QUANTIZE_NEAREST_ROUNDING}, {WEIGHT_QUANTIZE_STOCHASTIC_ROUNDING}]"
+ if WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE in sub_param_dict.keys():
+ output[WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE] = get_scalar_param(
+ sub_param_dict[WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE],
+ WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE_ENABLED,
+ WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE_ENABLED_DEFAULT)
+ output[WEIGHT_QUANTIZE_CHANGE_RATIO] = get_scalar_param(
+ sub_param_dict[WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE],
+ WEIGHT_QUANTIZE_CHANGE_RATIO,
+ WEIGHT_QUANTIZE_CHANGE_RATIO_DEFAULT)
+ else:
+ output[
+ WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE] = WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE_ENABLED_DEFAULT
+ output[WEIGHT_QUANTIZE_CHANGE_RATIO] = WEIGHT_QUANTIZE_CHANGE_RATIO_DEFAULT
+ else:
+ output[WEIGHT_QUANTIZE_ENABLED] = WEIGHT_QUANTIZE_ENABLED_DEFAULT
+ output[WEIGHT_QUANTIZE_KERNEL] = WEIGHT_QUANTIZE_KERNEL_DEFAULT
+ output[WEIGHT_QUANTIZE_SCHEDULE_OFFSET] = WEIGHT_QUANTIZE_SCHEDULE_OFFSET_DEFAULT
+ output[WEIGHT_QUANTIZE_GROUPS] = WEIGHT_QUANTIZE_GROUPS_DEFAULT
+ output[WEIGHT_QUANTIZE_VERBOSE] = WEIGHT_QUANTIZE_VERBOSE_DEFAULT
+ output[WEIGHT_QUANTIZE_TYPE] = WEIGHT_QUANTIZE_TYPE_DEFAULT
+ output[WEIGHT_QUANTIZE_ROUNDING] = WEIGHT_QUANTIZE_ROUNDING_DEFAULT
+ output[
+ WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE] = WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE_ENABLED_DEFAULT
+ output[WEIGHT_QUANTIZE_CHANGE_RATIO] = WEIGHT_QUANTIZE_CHANGE_RATIO_DEFAULT
+ return output
+
+
+def get_weight_quantization_different_groups(param_dict):
+ output = {}
+ sub_param_dict = param_dict[DIFFERENT_GROUPS]
+
+ def get_params(name, group_dict):
+ assert WEIGHT_QUANTIZE_START_BITS in group_dict.keys(), f"{WEIGHT_QUANTIZE_START_BITS} must be specified for weight quantization group {name}"
+ assert WEIGHT_QUANTIZE_TARGET_BITS in group_dict.keys(), f"{WEIGHT_QUANTIZE_TARGET_BITS} must be specified for weight quantization group {name}"
+ group_dict[WEIGHT_QUANTIZATION_PERIOD] = get_scalar_param(
+ group_dict,
+ WEIGHT_QUANTIZATION_PERIOD,
+ WEIGHT_QUANTIZATION_PERIOD_DEFAULT)
+ return group_dict
+
+ for k, v in sub_param_dict.items():
+ output[k] = {}
+ output[k][DIFFERENT_GROUPS_PARAMETERS] = get_params(
+ k,
+ sub_param_dict[k][DIFFERENT_GROUPS_PARAMETERS])
+ output[k][DIFFERENT_GROUPS_MODULE_SCOPE] = get_scalar_param(
+ sub_param_dict[k],
+ DIFFERENT_GROUPS_MODULE_SCOPE,
+ DIFFERENT_GROUPS_MODULE_SCOPE_DEFAULT)
+ output[k][DIFFERENT_GROUPS_RELATED_MODULE_SCOPE] = get_scalar_param(
+ sub_param_dict[k],
+ DIFFERENT_GROUPS_RELATED_MODULE_SCOPE,
+ DIFFERENT_GROUPS_RELATED_MODULE_SCOPE_DEFAULT)
+
+ return output
+
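+# Each parsed group ends up with the following shape (illustrative, keyed by the constants above):
+#   output[group_name] = {
+#       DIFFERENT_GROUPS_PARAMETERS: {...},              # start/target bits, quantization period, ...
+#       DIFFERENT_GROUPS_MODULE_SCOPE: [...],            # module-name patterns matched with re.search
+#       DIFFERENT_GROUPS_RELATED_MODULE_SCOPE: [...],    # optional related scopes for dim reduction
+#   }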
+
+def get_activation_quantization(param_dict):
+ output = {}
+ if ACTIVATION_QUANTIZATION not in param_dict.keys():
+ param_dict[ACTIVATION_QUANTIZATION] = {
+ SHARED_PARAMETERS: {},
+ DIFFERENT_GROUPS: {}
+ }
+ sub_param_dict = param_dict[ACTIVATION_QUANTIZATION]
+ # shared parameters
+ output[SHARED_PARAMETERS] = get_activation_quantization_shared_parameters(
+ sub_param_dict)
+ # each sub-groups
+ if output[SHARED_PARAMETERS][ACTIVATION_QUANTIZATION_ENABLED]:
+ assert DIFFERENT_GROUPS in sub_param_dict.keys(), f"Activation Quantization is enabled, {DIFFERENT_GROUPS} must be specified"
+ output[DIFFERENT_GROUPS] = get_activation_quantization_different_groups(
+ sub_param_dict)
+ return output
+
+
+def get_activation_quantization_shared_parameters(param_dict):
+ output = {}
+ if SHARED_PARAMETERS in param_dict.keys():
+ sub_param_dict = param_dict[SHARED_PARAMETERS]
+ output[ACTIVATION_QUANTIZATION_ENABLED] = get_scalar_param(
+ sub_param_dict,
+ ACTIVATION_QUANTIZATION_ENABLED,
+ ACTIVATION_QUANTIZATION_ENABLED_DEFAULT)
+ output[ACTIVATION_QUANTIZE_TYPE] = get_scalar_param(
+ sub_param_dict,
+ ACTIVATION_QUANTIZE_TYPE,
+ ACTIVATION_QUANTIZE_TYPE_DEFAULT)
+ assert output[ACTIVATION_QUANTIZE_TYPE] in [ACTIVATION_QUANTIZE_SYMMETRIC, ACTIVATION_QUANTIZE_ASYMMETRIC], f"Invalid activation quantize type. Supported types: [{ACTIVATION_QUANTIZE_SYMMETRIC}, {ACTIVATION_QUANTIZE_ASYMMETRIC}]"
+ output[ACTIVATION_QUANTIZE_RANGE] = get_scalar_param(
+ sub_param_dict,
+ ACTIVATION_QUANTIZE_RANGE,
+ ACTIVATION_QUANTIZE_RANGE_DEFAULT)
+ assert output[ACTIVATION_QUANTIZE_RANGE] in [ACTIVATION_QUANTIZE_RANGE_DYNAMIC, ACTIVATION_QUANTIZE_RANGE_STATIC], f"Invalid activation quantize range calibration. Supported types: [{ACTIVATION_QUANTIZE_RANGE_DYNAMIC}, {ACTIVATION_QUANTIZE_RANGE_STATIC}]"
+ output[ACTIVATION_QUANTIZE_SCHEDULE_OFFSET] = get_scalar_param(
+ sub_param_dict,
+ ACTIVATION_QUANTIZE_SCHEDULE_OFFSET,
+ ACTIVATION_QUANTIZE_SCHEDULE_OFFSET_DEFAULT)
+ else:
+ output[ACTIVATION_QUANTIZATION_ENABLED] = ACTIVATION_QUANTIZATION_ENABLED_DEFAULT
+ output[ACTIVATION_QUANTIZE_TYPE] = ACTIVATION_QUANTIZE_TYPE_DEFAULT
+ output[ACTIVATION_QUANTIZE_RANGE] = ACTIVATION_QUANTIZE_RANGE_DEFAULT
+ output[
+ ACTIVATION_QUANTIZE_SCHEDULE_OFFSET] = ACTIVATION_QUANTIZE_SCHEDULE_OFFSET_DEFAULT
+ return output
+
+
+def get_activation_quantization_different_groups(param_dict):
+ output = {}
+ sub_param_dict = param_dict[DIFFERENT_GROUPS]
+
+ def get_params(name, group_dict):
+ assert ACTIVATION_QUANTIZE_BITS in group_dict.keys(), f"{ACTIVATION_QUANTIZE_BITS} must be specified for activation quantization group {name}"
+ return group_dict
+
+ for k, v in sub_param_dict.items():
+ output[k] = {}
+ output[k][DIFFERENT_GROUPS_PARAMETERS] = get_params(
+ k,
+ sub_param_dict[k][DIFFERENT_GROUPS_PARAMETERS])
+ output[k][DIFFERENT_GROUPS_MODULE_SCOPE] = get_scalar_param(
+ sub_param_dict[k],
+ DIFFERENT_GROUPS_MODULE_SCOPE,
+ DIFFERENT_GROUPS_MODULE_SCOPE_DEFAULT)
+ output[k][DIFFERENT_GROUPS_RELATED_MODULE_SCOPE] = get_scalar_param(
+ sub_param_dict[k],
+ DIFFERENT_GROUPS_RELATED_MODULE_SCOPE,
+ DIFFERENT_GROUPS_RELATED_MODULE_SCOPE_DEFAULT)
+
+ return output
+
+
+def get_sparse_pruning(param_dict):
+ output = {}
+ if SPARSE_PRUNING not in param_dict.keys():
+ param_dict[SPARSE_PRUNING] = {SHARED_PARAMETERS: {}, DIFFERENT_GROUPS: {}}
+ sub_param_dict = param_dict[SPARSE_PRUNING]
+ # shared parameters
+ output[SHARED_PARAMETERS] = get_sparse_pruning_shared_parameters(sub_param_dict)
+ # each sub-groups
+ if output[SHARED_PARAMETERS][SPARSE_PRUNING_ENABLED]:
+ assert DIFFERENT_GROUPS in sub_param_dict.keys(), f"Sparse Pruning is enabled, {DIFFERENT_GROUPS} must be specified"
+ output[DIFFERENT_GROUPS] = get_sparse_pruning_different_groups(sub_param_dict)
+ return output
+
+
+def get_sparse_pruning_shared_parameters(param_dict):
+ output = {}
+ if SHARED_PARAMETERS in param_dict.keys():
+ sub_param_dict = param_dict[SHARED_PARAMETERS]
+ output[SPARSE_PRUNING_ENABLED] = get_scalar_param(
+ sub_param_dict,
+ SPARSE_PRUNING_ENABLED,
+ SPARSE_PRUNING_ENABLED_DEFAULT)
+ output[SPARSE_PRUNING_METHOD] = get_scalar_param(sub_param_dict,
+ SPARSE_PRUNING_METHOD,
+ SPARSE_PRUNING_METHOD_DEFAULT)
+ assert output[SPARSE_PRUNING_METHOD] in [SPARSE_PRUNING_METHOD_L1, SPARSE_PRUNING_METHOD_TOPK], f"Invalid sparse pruning method. Supported types: [{SPARSE_PRUNING_METHOD_L1}, {SPARSE_PRUNING_METHOD_TOPK}]"
+ output[SPARSE_PRUNING_SCHEDULE_OFFSET] = get_scalar_param(
+ sub_param_dict,
+ SPARSE_PRUNING_SCHEDULE_OFFSET,
+ SPARSE_PRUNING_SCHEDULE_OFFSET_DEFAULT)
+ else:
+ output[SPARSE_PRUNING_ENABLED] = SPARSE_PRUNING_ENABLED_DEFAULT
+ output[SPARSE_PRUNING_METHOD] = SPARSE_PRUNING_METHOD_DEFAULT
+ output[SPARSE_PRUNING_SCHEDULE_OFFSET] = SPARSE_PRUNING_SCHEDULE_OFFSET_DEFAULT
+ return output
+
+
+def get_sparse_pruning_different_groups(param_dict):
+ output = {}
+ sub_param_dict = param_dict[DIFFERENT_GROUPS]
+
+ def get_params(name, group_dict):
+ assert SPARSE_PRUNING_DENSE_RATIO in group_dict.keys(), f"{SPARSE_PRUNING_DENSE_RATIO} must be specified for sparse pruning group {name}"
+ return group_dict
+
+ for k, v in sub_param_dict.items():
+ output[k] = {}
+ output[k][DIFFERENT_GROUPS_PARAMETERS] = get_params(
+ k,
+ sub_param_dict[k][DIFFERENT_GROUPS_PARAMETERS])
+ output[k][DIFFERENT_GROUPS_MODULE_SCOPE] = get_scalar_param(
+ sub_param_dict[k],
+ DIFFERENT_GROUPS_MODULE_SCOPE,
+ DIFFERENT_GROUPS_MODULE_SCOPE_DEFAULT)
+ output[k][DIFFERENT_GROUPS_RELATED_MODULE_SCOPE] = get_scalar_param(
+ sub_param_dict[k],
+ DIFFERENT_GROUPS_RELATED_MODULE_SCOPE,
+ DIFFERENT_GROUPS_RELATED_MODULE_SCOPE_DEFAULT)
+
+ return output
+
+
+def get_row_pruning(param_dict):
+ output = {}
+ if ROW_PRUNING not in param_dict.keys():
+ param_dict[ROW_PRUNING] = {SHARED_PARAMETERS: {}, DIFFERENT_GROUPS: {}}
+ sub_param_dict = param_dict[ROW_PRUNING]
+ # shared parameters
+ output[SHARED_PARAMETERS] = get_row_pruning_shared_parameters(sub_param_dict)
+ # each sub-groups
+ if output[SHARED_PARAMETERS][ROW_PRUNING_ENABLED]:
+ assert DIFFERENT_GROUPS in sub_param_dict.keys(), f"Row Pruning is enabled, {DIFFERENT_GROUPS} must be specified"
+ output[DIFFERENT_GROUPS] = get_row_pruning_different_groups(sub_param_dict)
+ return output
+
+
+def get_row_pruning_shared_parameters(param_dict):
+ output = {}
+ if SHARED_PARAMETERS in param_dict.keys():
+ sub_param_dict = param_dict[SHARED_PARAMETERS]
+ output[ROW_PRUNING_ENABLED] = get_scalar_param(sub_param_dict,
+ ROW_PRUNING_ENABLED,
+ ROW_PRUNING_ENABLED_DEFAULT)
+ output[ROW_PRUNING_METHOD] = get_scalar_param(sub_param_dict,
+ ROW_PRUNING_METHOD,
+ ROW_PRUNING_METHOD_DEFAULT)
+ assert output[ROW_PRUNING_METHOD] in [ROW_PRUNING_METHOD_L1, ROW_PRUNING_METHOD_TOPK], f"Invalid row pruning method. Supported types: [{ROW_PRUNING_METHOD_L1}, {ROW_PRUNING_METHOD_TOPK}]"
+ output[ROW_PRUNING_SCHEDULE_OFFSET] = get_scalar_param(
+ sub_param_dict,
+ ROW_PRUNING_SCHEDULE_OFFSET,
+ ROW_PRUNING_SCHEDULE_OFFSET_DEFAULT)
+ else:
+ output[ROW_PRUNING_ENABLED] = ROW_PRUNING_ENABLED_DEFAULT
+ output[ROW_PRUNING_METHOD] = ROW_PRUNING_METHOD_DEFAULT
+ output[ROW_PRUNING_SCHEDULE_OFFSET] = ROW_PRUNING_SCHEDULE_OFFSET_DEFAULT
+ return output
+
+
+def get_row_pruning_different_groups(param_dict):
+ output = {}
+ sub_param_dict = param_dict[DIFFERENT_GROUPS]
+
+ def get_params(name, group_dict):
+ assert ROW_PRUNING_DENSE_RATIO in group_dict.keys(), f"{ROW_PRUNING_DENSE_RATIO} must be specified for row pruning group {name}"
+ return group_dict
+
+ for k, v in sub_param_dict.items():
+ output[k] = {}
+ output[k][DIFFERENT_GROUPS_PARAMETERS] = get_params(
+ k,
+ sub_param_dict[k][DIFFERENT_GROUPS_PARAMETERS])
+ output[k][DIFFERENT_GROUPS_MODULE_SCOPE] = get_scalar_param(
+ sub_param_dict[k],
+ DIFFERENT_GROUPS_MODULE_SCOPE,
+ DIFFERENT_GROUPS_MODULE_SCOPE_DEFAULT)
+ output[k][DIFFERENT_GROUPS_RELATED_MODULE_SCOPE] = get_scalar_param(
+ sub_param_dict[k],
+ DIFFERENT_GROUPS_RELATED_MODULE_SCOPE,
+ DIFFERENT_GROUPS_RELATED_MODULE_SCOPE_DEFAULT)
+ return output
+
+
+def get_head_pruning(param_dict):
+ output = {}
+ if HEAD_PRUNING not in param_dict.keys():
+ param_dict[HEAD_PRUNING] = {SHARED_PARAMETERS: {}, DIFFERENT_GROUPS: {}}
+ sub_param_dict = param_dict[HEAD_PRUNING]
+ # shared parameters
+ output[SHARED_PARAMETERS] = get_head_pruning_shared_parameters(sub_param_dict)
+ # each sub-groups
+ if output[SHARED_PARAMETERS][HEAD_PRUNING_ENABLED]:
+ assert DIFFERENT_GROUPS in sub_param_dict.keys(), f"Head Pruning is enabled, {DIFFERENT_GROUPS} must be specified"
+ output[DIFFERENT_GROUPS] = get_head_pruning_different_groups(sub_param_dict)
+ return output
+
+
+def get_head_pruning_shared_parameters(param_dict):
+ output = {}
+ if SHARED_PARAMETERS in param_dict.keys():
+ sub_param_dict = param_dict[SHARED_PARAMETERS]
+ output[HEAD_PRUNING_ENABLED] = get_scalar_param(sub_param_dict,
+ HEAD_PRUNING_ENABLED,
+ HEAD_PRUNING_ENABLED_DEFAULT)
+ output[HEAD_PRUNING_METHOD] = get_scalar_param(sub_param_dict,
+ HEAD_PRUNING_METHOD,
+ HEAD_PRUNING_METHOD_DEFAULT)
+ assert output[HEAD_PRUNING_METHOD] in [HEAD_PRUNING_METHOD_L1, HEAD_PRUNING_METHOD_TOPK], f"Invalid head pruning method. Supported types: [{HEAD_PRUNING_METHOD_L1}, {HEAD_PRUNING_METHOD_TOPK}]"
+ output[HEAD_PRUNING_SCHEDULE_OFFSET] = get_scalar_param(
+ sub_param_dict,
+ HEAD_PRUNING_SCHEDULE_OFFSET,
+ HEAD_PRUNING_SCHEDULE_OFFSET_DEFAULT)
+ if output[HEAD_PRUNING_ENABLED]:
+ assert HEAD_PRUNING_NUM_HEADS in sub_param_dict.keys(), f"{HEAD_PRUNING_NUM_HEADS} must be specified for head pruning"
+ output[HEAD_PRUNING_NUM_HEADS] = sub_param_dict[HEAD_PRUNING_NUM_HEADS]
+ else:
+ output[HEAD_PRUNING_ENABLED] = HEAD_PRUNING_ENABLED_DEFAULT
+ output[HEAD_PRUNING_METHOD] = HEAD_PRUNING_METHOD_DEFAULT
+ output[HEAD_PRUNING_SCHEDULE_OFFSET] = HEAD_PRUNING_SCHEDULE_OFFSET_DEFAULT
+ return output
+
+
+def get_head_pruning_different_groups(param_dict):
+ output = {}
+ sub_param_dict = param_dict[DIFFERENT_GROUPS]
+
+ def get_params(name, group_dict):
+        assert HEAD_PRUNING_DENSE_RATIO in group_dict.keys(), f"{HEAD_PRUNING_DENSE_RATIO} must be specified for head pruning group {name}"
+ return group_dict
+
+ for k, v in sub_param_dict.items():
+ output[k] = {}
+ output[k][DIFFERENT_GROUPS_PARAMETERS] = get_params(
+ k,
+ sub_param_dict[k][DIFFERENT_GROUPS_PARAMETERS])
+ output[k][DIFFERENT_GROUPS_MODULE_SCOPE] = get_scalar_param(
+ sub_param_dict[k],
+ DIFFERENT_GROUPS_MODULE_SCOPE,
+ DIFFERENT_GROUPS_MODULE_SCOPE_DEFAULT)
+ output[k][DIFFERENT_GROUPS_RELATED_MODULE_SCOPE] = get_scalar_param(
+ sub_param_dict[k],
+ DIFFERENT_GROUPS_RELATED_MODULE_SCOPE,
+ DIFFERENT_GROUPS_RELATED_MODULE_SCOPE_DEFAULT)
+ return output
+
+
+def get_channel_pruning(param_dict):
+ output = {}
+ if CHANNEL_PRUNING not in param_dict.keys():
+ param_dict[CHANNEL_PRUNING] = {SHARED_PARAMETERS: {}, DIFFERENT_GROUPS: {}}
+ sub_param_dict = param_dict[CHANNEL_PRUNING]
+ # shared parameters
+ output[SHARED_PARAMETERS] = get_channel_pruning_shared_parameters(sub_param_dict)
+ # each sub-groups
+ if output[SHARED_PARAMETERS][CHANNEL_PRUNING_ENABLED]:
+ assert DIFFERENT_GROUPS in sub_param_dict.keys(), f"Channel Pruning is enabled, {DIFFERENT_GROUPS} must be specified"
+ output[DIFFERENT_GROUPS] = get_channel_pruning_different_groups(sub_param_dict)
+ return output
+
+
+def get_channel_pruning_shared_parameters(param_dict):
+ output = {}
+ if SHARED_PARAMETERS in param_dict.keys():
+ sub_param_dict = param_dict[SHARED_PARAMETERS]
+ output[CHANNEL_PRUNING_ENABLED] = get_scalar_param(
+ sub_param_dict,
+ CHANNEL_PRUNING_ENABLED,
+ CHANNEL_PRUNING_ENABLED_DEFAULT)
+ output[CHANNEL_PRUNING_METHOD] = get_scalar_param(
+ sub_param_dict,
+ CHANNEL_PRUNING_METHOD,
+ CHANNEL_PRUNING_METHOD_DEFAULT)
+ assert output[CHANNEL_PRUNING_METHOD] in [CHANNEL_PRUNING_METHOD_L1, CHANNEL_PRUNING_METHOD_TOPK], f"Invalid channel pruning method. Supported types: [{CHANNEL_PRUNING_METHOD_L1}, {CHANNEL_PRUNING_METHOD_TOPK}]"
+ output[CHANNEL_PRUNING_SCHEDULE_OFFSET] = get_scalar_param(
+ sub_param_dict,
+ CHANNEL_PRUNING_SCHEDULE_OFFSET,
+ CHANNEL_PRUNING_SCHEDULE_OFFSET_DEFAULT)
+ else:
+ output[CHANNEL_PRUNING_ENABLED] = CHANNEL_PRUNING_ENABLED_DEFAULT
+ output[CHANNEL_PRUNING_METHOD] = CHANNEL_PRUNING_METHOD_DEFAULT
+ output[CHANNEL_PRUNING_SCHEDULE_OFFSET] = CHANNEL_PRUNING_SCHEDULE_OFFSET_DEFAULT
+ return output
+
+
+def get_channel_pruning_different_groups(param_dict):
+ output = {}
+ sub_param_dict = param_dict[DIFFERENT_GROUPS]
+
+ def get_params(name, group_dict):
+ assert CHANNEL_PRUNING_DENSE_RATIO in group_dict.keys(), f"{CHANNEL_PRUNING_DENSE_RATIO} must be specified for channel pruning group {name}"
+ return group_dict
+
+ for k, v in sub_param_dict.items():
+ output[k] = {}
+ output[k][DIFFERENT_GROUPS_PARAMETERS] = get_params(
+ k,
+ sub_param_dict[k][DIFFERENT_GROUPS_PARAMETERS])
+ output[k][DIFFERENT_GROUPS_MODULE_SCOPE] = get_scalar_param(
+ sub_param_dict[k],
+ DIFFERENT_GROUPS_MODULE_SCOPE,
+ DIFFERENT_GROUPS_MODULE_SCOPE_DEFAULT)
+ output[k][DIFFERENT_GROUPS_RELATED_MODULE_SCOPE] = get_scalar_param(
+ sub_param_dict[k],
+ DIFFERENT_GROUPS_RELATED_MODULE_SCOPE,
+ DIFFERENT_GROUPS_RELATED_MODULE_SCOPE_DEFAULT)
+
+ return output
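+
+
+# Illustrative sketch only (not part of this file's logic): a minimal config dict of the
+# shape these getters parse. Key names come from ..compression.constants; the group names,
+# module names, and numeric values below are hypothetical, and in the DeepSpeed JSON this
+# block is assumed to live under "compression_training".
+#
+# example_compression_config = {
+#     "row_pruning": {
+#         "shared_parameters": {"enabled": True, "method": "topk", "schedule_offset": 1000},
+#         "different_groups": {
+#             "rp_group_1": {"params": {"dense_ratio": 0.5},
+#                            "modules": ["attention.output.dense"]}
+#         }
+#     },
+#     "head_pruning": {
+#         "shared_parameters": {"enabled": True, "method": "topk", "num_heads": 12,
+#                               "schedule_offset": 1000},
+#         "different_groups": {
+#             "hp_group_1": {"params": {"dense_ratio": 0.5},
+#                            "modules": ["attention.self.query", "attention.self.key",
+#                                        "attention.self.value"]}
+#         }
+#     }
+# }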
diff --git a/deepspeed/compression/constants.py b/deepspeed/compression/constants.py
new file mode 100644
index 000000000000..6507ae6fbccf
--- /dev/null
+++ b/deepspeed/compression/constants.py
@@ -0,0 +1,168 @@
+#########################################
+# Compression Methods
+# It has several sub-components
+#########################################
+COMPRESSION_TRAINING = "compression_training"
+SHARED_PARAMETERS = "shared_parameters"
+DIFFERENT_GROUPS = "different_groups"
+TECHNIQUE_ENABLED = "enabled"
+TECHNIQUE_SCHEDULE_OFFSET = "schedule_offset"
+DIFFERENT_GROUPS_PARAMETERS = "params"
+DIFFERENT_GROUPS_MODULE_SCOPE = "modules"
+DIFFERENT_GROUPS_MODULE_SCOPE_DEFAULT = "*"
+DIFFERENT_GROUPS_RELATED_MODULE_SCOPE = "related_modules"
+DIFFERENT_GROUPS_RELATED_MODULE_SCOPE_DEFAULT = None
+# COMPRESSION_TRAINING_ENABLED = "enabled"
+# COMPRESSION_TRAINING_ENABLED_DEFAULT = False
+
+####
+# Layer Reduction
+####
+LAYER_REDUCTION = "layer_reduction"
+LAYER_REDUCTION_ENABLED = "enabled"
+LAYER_REDUCTION_ENABLED_DEFAULT = False
+KEEP_NUMBER_LAYER = "keep_number_layer"
+MODULE_NAME_PREFIX = "module_name_prefix"
+TEACHER_LAYER = "teacher_layer"
+OTHER_MODULE_NAME = "other_module_name"
+
+####
+# Weight Quantization
+####
+WEIGHT_QUANTIZATION = "weight_quantization"
+
+WEIGHT_QUANTIZATION_PERIOD = "quantization_period"
+WEIGHT_QUANTIZATION_PERIOD_DEFAULT = 1
+
+WEIGHT_QUANTIZE_IN_FORWARD_ENABLED = "quantize_weight_in_forward"
+WEIGHT_QUANTIZE_IN_FORWARD_ENABLED_DEFAULT = False
+
+WEIGHT_QUANTIZE_ENABLED = TECHNIQUE_ENABLED
+WEIGHT_QUANTIZE_ENABLED_DEFAULT = False
+
+WEIGHT_QUANTIZE_KERNEL = "quantizer_kernel"
+WEIGHT_QUANTIZE_KERNEL_DEFAULT = False
+
+WEIGHT_QUANTIZE_SCHEDULE_OFFSET = TECHNIQUE_SCHEDULE_OFFSET
+WEIGHT_QUANTIZE_SCHEDULE_OFFSET_DEFAULT = 0
+
+WEIGHT_QUANTIZE_GROUPS = "quantize_groups"
+WEIGHT_QUANTIZE_GROUPS_DEFAULT = 1
+
+WEIGHT_QUANTIZE_VERBOSE = "quantize_verbose"
+WEIGHT_QUANTIZE_VERBOSE_DEFAULT = False
+
+WEIGHT_QUANTIZE_TYPE = "quantization_type"
+WEIGHT_QUANTIZE_TYPE_DEFAULT = "symmetric"
+WEIGHT_QUANTIZE_SYMMETRIC = "symmetric"
+WEIGHT_QUANTIZE_ASYMMETRIC = "asymmetric"
+
+WEIGHT_QUANTIZE_ROUNDING = "rounding"
+WEIGHT_QUANTIZE_ROUNDING_DEFAULT = "nearest"
+WEIGHT_QUANTIZE_STOCHASTIC_ROUNDING = "stochastic"
+WEIGHT_QUANTIZE_NEAREST_ROUNDING = "nearest"
+# may be removed in a future cleanup
+WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE = "fp16_mixed_quantize"
+
+WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE_ENABLED = "enabled"
+WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE_ENABLED_DEFAULT = False
+
+WEIGHT_QUANTIZE_CHANGE_RATIO = "quantize_change_ratio"
+WEIGHT_QUANTIZE_CHANGE_RATIO_DEFAULT = 0.001
+
+WEIGHT_QUANTIZE_START_BITS = "start_bits"
+WEIGHT_QUANTIZE_TARGET_BITS = "target_bits"
+###
+# Activation Quantization
+###
+ACTIVATION_QUANTIZATION = "activation_quantization"
+
+ACTIVATION_QUANTIZATION_ENABLED = TECHNIQUE_ENABLED
+ACTIVATION_QUANTIZATION_ENABLED_DEFAULT = False
+
+ACTIVATION_QUANTIZE_SCHEDULE_OFFSET = TECHNIQUE_SCHEDULE_OFFSET
+ACTIVATION_QUANTIZE_SCHEDULE_OFFSET_DEFAULT = 1000
+
+ACTIVATION_QUANTIZE_TYPE = "quantization_type"
+ACTIVATION_QUANTIZE_TYPE_DEFAULT = "symmetric"
+ACTIVATION_QUANTIZE_SYMMETRIC = "symmetric"
+ACTIVATION_QUANTIZE_ASYMMETRIC = "asymmetric"
+
+ACTIVATION_QUANTIZE_RANGE = 'range_calibration'
+ACTIVATION_QUANTIZE_RANGE_DEFAULT = 'dynamic'
+ACTIVATION_QUANTIZE_RANGE_STATIC = 'static'
+ACTIVATION_QUANTIZE_RANGE_DYNAMIC = 'dynamic'
+
+ACTIVATION_QUANTIZE_BITS = "bits"
+###
+# Sparse Pruning
+###
+SPARSE_PRUNING = "sparse_pruning"
+
+SPARSE_PRUNING_ENABLED = TECHNIQUE_ENABLED
+SPARSE_PRUNING_ENABLED_DEFAULT = False
+
+SPARSE_PRUNING_METHOD = "method"
+SPARSE_PRUNING_METHOD_DEFAULT = "l1"
+SPARSE_PRUNING_METHOD_L1 = "l1"
+SPARSE_PRUNING_METHOD_TOPK = "topk"
+
+SPARSE_PRUNING_SCHEDULE_OFFSET = TECHNIQUE_SCHEDULE_OFFSET
+SPARSE_PRUNING_SCHEDULE_OFFSET_DEFAULT = 1000
+
+SPARSE_PRUNING_DENSE_RATIO = "dense_ratio"
+###
+# Row Pruning
+###
+ROW_PRUNING = "row_pruning"
+
+ROW_PRUNING_ENABLED = TECHNIQUE_ENABLED
+ROW_PRUNING_ENABLED_DEFAULT = False
+
+ROW_PRUNING_METHOD = "method"
+ROW_PRUNING_METHOD_DEFAULT = "l1"
+ROW_PRUNING_METHOD_L1 = "l1"
+ROW_PRUNING_METHOD_TOPK = "topk"
+
+ROW_PRUNING_SCHEDULE_OFFSET = TECHNIQUE_SCHEDULE_OFFSET
+ROW_PRUNING_SCHEDULE_OFFSET_DEFAULT = 1000
+
+ROW_PRUNING_DENSE_RATIO = "dense_ratio"
+
+###
+# Head Pruning
+###
+HEAD_PRUNING = "head_pruning"
+
+HEAD_PRUNING_ENABLED = TECHNIQUE_ENABLED
+HEAD_PRUNING_ENABLED_DEFAULT = False
+
+HEAD_PRUNING_METHOD = "method"
+HEAD_PRUNING_METHOD_DEFAULT = "topk"
+HEAD_PRUNING_METHOD_L1 = "l1"
+HEAD_PRUNING_METHOD_TOPK = "topk"
+
+HEAD_PRUNING_SCHEDULE_OFFSET = TECHNIQUE_SCHEDULE_OFFSET
+HEAD_PRUNING_SCHEDULE_OFFSET_DEFAULT = 1000
+
+HEAD_PRUNING_NUM_HEADS = "num_heads"
+
+HEAD_PRUNING_DENSE_RATIO = "dense_ratio"
+
+###
+# Channel Pruning
+###
+CHANNEL_PRUNING = "channel_pruning"
+
+CHANNEL_PRUNING_ENABLED = TECHNIQUE_ENABLED
+CHANNEL_PRUNING_ENABLED_DEFAULT = False
+
+CHANNEL_PRUNING_METHOD = "method"
+CHANNEL_PRUNING_METHOD_DEFAULT = "l1"
+CHANNEL_PRUNING_METHOD_L1 = "l1"
+CHANNEL_PRUNING_METHOD_TOPK = "topk"
+
+CHANNEL_PRUNING_SCHEDULE_OFFSET = TECHNIQUE_SCHEDULE_OFFSET
+CHANNEL_PRUNING_SCHEDULE_OFFSET_DEFAULT = 1000
+
+CHANNEL_PRUNING_DENSE_RATIO = "dense_ratio"
diff --git a/deepspeed/compression/helper.py b/deepspeed/compression/helper.py
new file mode 100644
index 000000000000..66d612ba8e82
--- /dev/null
+++ b/deepspeed/compression/helper.py
@@ -0,0 +1,281 @@
+import torch
+from .basic_layer import Embedding_Compress, LinearLayer_Compress, Conv2dLayer_Compress, BNLayer_Compress, ColumnParallelLinear_Compress, RowParallelLinear_Compress
+from .constants import *
+
+
+def recursive_getattr(model, module_name):
+ """
+ Recursively get the attribute of a module.
+ Args:
+ model (`torch.nn.Module`)
+ The model to get the attribute from.
+ module_name (`str`)
+ The name of the module to get the attribute from.
+ """
+ split_list = module_name.split('.')
+ output = model
+ for name in split_list:
+ output = getattr(output, name)
+ return output
+
+
+def recursive_setattr(model, module_name, module):
+ """
+ Recursively set the attribute of a module.
+ Args:
+ model (`torch.nn.Module`)
+ The model to set the attribute in.
+ module_name (`str`)
+ The name of the module to set the attribute in.
+ module (`torch.nn.Module`)
+ The module to set the attribute to.
+ """
+ split_list = module_name.split('.')
+ output = model
+ for name in split_list[:-1]:
+ output = getattr(output, name)
+ output.__setattr__(split_list[-1], module)
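+
+# Illustrative usage (hypothetical module path, not taken from this codebase):
+#   layer = recursive_getattr(model, "encoder.layer.0.output.dense")
+#   recursive_setattr(model, "encoder.layer.0.output.dense", new_layer)
+# The dotted name is resolved one attribute at a time; ModuleList indices such as
+# "0" also work because nn.Module exposes child modules by name via getattr.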
+
+
+def module_replacement(model, module_name, compression_technique=None, mpu=None):
+ """
+ Replace a module with a new module.
+ Args:
+ model (`torch.nn.Module`)
+ The model to replace the module in.
+ module_name (`str`)
+ The name of the module to replace.
+ compression_technique (`dict`)
+ The compression technique to use for the new module.
+ """
+
+ # Get the old module
+ old_module = recursive_getattr(model, module_name)
+
+ need_bias = False
+ if hasattr(old_module, 'bias') and old_module.bias is not None:
+ need_bias = True
+
+ # Initialize the new module
+ if isinstance(old_module,
+ LinearLayer_Compress) or isinstance(old_module,
+ torch.nn.Linear):
+ if isinstance(old_module, LinearLayer_Compress):
+ new_module = old_module
+ else:
+ new_module = LinearLayer_Compress(old_module.in_features,
+ old_module.out_features,
+ bias=need_bias).to(
+ device=old_module.weight.device,
+ dtype=old_module.weight.dtype)
+ new_module.weight.data = old_module.weight.data
+ if need_bias:
+ new_module.bias.data = old_module.bias.data
+ elif isinstance(old_module,
+ Conv2dLayer_Compress) or isinstance(old_module,
+ torch.nn.Conv2d):
+ if isinstance(old_module, Conv2dLayer_Compress):
+ new_module = old_module
+ else:
+ new_module = Conv2dLayer_Compress(old_module.in_channels, old_module.out_channels, old_module.kernel_size, old_module.stride, old_module.padding, \
+ old_module.dilation, old_module.groups, need_bias, \
+ old_module.padding_mode).to(device=old_module.weight.device, dtype=old_module.weight.dtype)
+ new_module.weight.data = old_module.weight.data
+ if need_bias:
+ new_module.bias.data = old_module.bias.data
+ elif isinstance(old_module, torch.nn.BatchNorm2d):
+ new_module = BNLayer_Compress(old_module.num_features,
+ old_module.eps,
+ old_module.momentum,
+ old_module.affine,
+ old_module.track_running_stats).to(
+ old_module.weight.device,
+ old_module.weight.dtype)
+ new_module.weight.data = old_module.weight.data
+ if need_bias:
+ new_module.bias.data = old_module.bias.data
+ new_module.running_mean.data = old_module.running_mean.data
+ new_module.running_var.data = old_module.running_var.data
+ elif isinstance(old_module,
+ Embedding_Compress) or isinstance(old_module,
+ torch.nn.Embedding):
+ if isinstance(old_module, Embedding_Compress):
+ new_module = old_module
+ else:
+ new_module = Embedding_Compress(old_module.num_embeddings, old_module.embedding_dim, old_module.padding_idx, old_module.max_norm, old_module.norm_type, \
+ old_module.scale_grad_by_freq, old_module.sparse).to(device=old_module.weight.device, dtype=old_module.weight.dtype)
+ new_module.weight.data = old_module.weight.data
+ elif mpu is not None and (isinstance(old_module,
+ ColumnParallelLinear_Compress)
+ or isinstance(old_module,
+ mpu.ColumnParallelLinear)):
+ if isinstance(old_module, ColumnParallelLinear_Compress):
+ new_module = old_module
+ else:
+ new_module = ColumnParallelLinear_Compress(
+ mpu,
+ old_module.input_size,
+ old_module.output_size,
+ gather_output=old_module.gather_output,
+ skip_bias_add=old_module.skip_bias_add,
+ bias=need_bias).to(device=old_module.weight.device,
+ dtype=old_module.weight.dtype)
+ new_module.weight.data = old_module.weight.data
+ if need_bias:
+ new_module.bias.data = old_module.bias.data
+ elif mpu is not None and (isinstance(old_module,
+ RowParallelLinear_Compress)
+ or isinstance(old_module,
+ mpu.RowParallelLinear)):
+ if isinstance(old_module, RowParallelLinear_Compress):
+ new_module = old_module
+ else:
+ new_module = RowParallelLinear_Compress(
+ mpu,
+ old_module.input_size,
+ old_module.output_size,
+ input_is_parallel=old_module.input_is_parallel,
+ skip_bias_add=old_module.skip_bias_add,
+ bias=need_bias).to(device=old_module.weight.device,
+ dtype=old_module.weight.dtype)
+ new_module.weight.data = old_module.weight.data
+ if need_bias:
+ new_module.bias.data = old_module.bias.data
+ else:
+ new_module = None
+
+ if compression_technique is not None:
+ for k, v in compression_technique.items():
+ if k == SPARSE_PRUNING:
+ if v[SPARSE_PRUNING_ENABLED]:
+ new_module.enable_sparse_pruning(v[SPARSE_PRUNING_DENSE_RATIO],
+ v[SPARSE_PRUNING_METHOD])
+ elif k == ROW_PRUNING:
+ if v[ROW_PRUNING_ENABLED]:
+ new_module.enable_row_pruning(v[ROW_PRUNING_DENSE_RATIO],
+ v[ROW_PRUNING_METHOD])
+ elif k == HEAD_PRUNING:
+ if v[HEAD_PRUNING_ENABLED]:
+ new_module.enable_head_pruning(v[HEAD_PRUNING_DENSE_RATIO],
+ v[HEAD_PRUNING_METHOD],
+ v[HEAD_PRUNING_NUM_HEADS])
+ elif k == ACTIVATION_QUANTIZATION:
+ if v[ACTIVATION_QUANTIZATION_ENABLED]:
+ new_module.enable_activation_quantization(
+ v[ACTIVATION_QUANTIZE_BITS],
+ v[ACTIVATION_QUANTIZE_TYPE],
+ v[ACTIVATION_QUANTIZE_RANGE])
+ elif k == WEIGHT_QUANTIZATION:
+ if v[WEIGHT_QUANTIZE_ENABLED]:
+ new_module.enable_weight_quantization(
+ v[WEIGHT_QUANTIZE_START_BITS],
+ v[WEIGHT_QUANTIZE_TARGET_BITS],
+ v[WEIGHT_QUANTIZATION_PERIOD],
+ v[WEIGHT_QUANTIZE_IN_FORWARD_ENABLED],
+ v[WEIGHT_QUANTIZE_TYPE],
+ v[WEIGHT_QUANTIZE_GROUPS])
+ elif k == CHANNEL_PRUNING:
+ if v[CHANNEL_PRUNING_ENABLED]:
+ new_module.enable_channel_pruning(v[CHANNEL_PRUNING_DENSE_RATIO],
+ v[CHANNEL_PRUNING_METHOD])
+ else:
+ raise NotImplementedError(
+ 'Compression technique {} is not implemented'.format(k))
+
+ # Replace the old module with the new one
+ recursive_setattr(model, module_name, new_module)
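+
+# Illustrative sketch (hypothetical module name and ratio): wrap `model.classifier`
+# and enable L1 sparse pruning on the wrapped layer.
+#
+#   technique = {
+#       SPARSE_PRUNING: {
+#           SPARSE_PRUNING_ENABLED: True,
+#           SPARSE_PRUNING_DENSE_RATIO: 0.5,
+#           SPARSE_PRUNING_METHOD: SPARSE_PRUNING_METHOD_L1,
+#       }
+#   }
+#   module_replacement(model, "classifier", compression_technique=technique)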
+
+
+def is_module_compressible(module, mpu=None):
+ ret = isinstance(module, torch.nn.Linear) or \
+ isinstance(module, torch.nn.Conv2d) or \
+ isinstance(module, torch.nn.Embedding) or \
+ isinstance(module, torch.nn.BatchNorm2d)
+
+ if mpu is not None:
+ ret = ret or isinstance(module,
+ mpu.RowParallelLinear) or isinstance(
+ module,
+ mpu.ColumnParallelLinear)
+
+ return ret
+
+
+def compression_preparation(model, compression_techinique_list, mpu):
+ """
+ Prepare the compression techniques of a model.
+ Args:
+ model (`torch.nn.Module`)
+ The model to prepare the compression techniques of.
+ compression_techinique_list (`list`)
+ The list of compression techniques to apply to the model.
+ """
+ # Here we first replace all module with our linear wrapper
+ for module_name, module in model.named_modules():
+ if is_module_compressible(module, mpu):
+ module_replacement(model, module_name, mpu=mpu)
+ for module_name_lists, _, compression_technique in compression_techinique_list:
+ for mnl in module_name_lists:
+ for module_name in mnl:
+ module_replacement(model, module_name, compression_technique)
+
+ return model
+
+
+def fix_compression(model,
+ module_name,
+ compression_technique,
+ mask=None,
+ dim_reduction=False):
+ """
+ Fix the compression technique of a module.
+ Args:
+ model (`torch.nn.Module`)
+ The model to fix the compression technique of.
+ module_name (`str`)
+ The name of the module to fix the compression technique of.
+ compression_technique (`dict`)
+ The compression technique configuration to apply to the module.
+ """
+ # Here we can make things much simpler by just replacing the module
+ module = recursive_getattr(model, module_name)
+ for k, v in compression_technique.items():
+ if k == WEIGHT_QUANTIZATION and v[WEIGHT_QUANTIZE_IN_FORWARD_ENABLED] and v[
+ WEIGHT_QUANTIZE_ENABLED]:
+ return module.fix_weight_quantization()
+ elif k == SPARSE_PRUNING and v[SPARSE_PRUNING_ENABLED]:
+ return module.fix_sparse_pruning_helper()
+ elif k == ROW_PRUNING and (v[ROW_PRUNING_ENABLED] or mask is not None):
+ return module.fix_row_col_pruning_helper(mask, dim_reduction=dim_reduction)
+ elif k == HEAD_PRUNING and (v[HEAD_PRUNING_ENABLED] or mask is not None):
+ return module.fix_head_pruning_helper(mask,
+ v[HEAD_PRUNING_NUM_HEADS],
+ dim_reduction=dim_reduction)
+ elif k == CHANNEL_PRUNING and (v[CHANNEL_PRUNING_ENABLED] or mask is not None):
+ return module.fix_channel_pruning_helper(mask, dim_reduction=dim_reduction)
+
+
+def convert_conv1d_to_linear(model, convert_type):
+ '''
+ This is a helper function to convert Conv1D layers to Linear layers (e.g., for GPT-2 models from HuggingFace)
+ '''
+ if hasattr(model, 'module'):
+ c_model = model.module
+ else:
+ c_model = model
+
+ for name, module in c_model.named_modules():
+ if isinstance(module, convert_type):
+ old_module = recursive_getattr(c_model, name)
+ new_module = torch.nn.Linear(
+ old_module.weight.data.size(0),
+ old_module.weight.data.size(1),
+ bias=True if old_module.bias is not None else False)
+ new_module.weight.data = old_module.weight.data.t()
+ if new_module.bias is not None:
+ new_module.bias.data = old_module.bias.data.view(-1)
+
+ recursive_setattr(c_model, name, new_module)
+
+ return model
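+
+# Illustrative usage, assuming a HuggingFace GPT-2 style model whose projections use
+# transformers' Conv1D class (the import path below is an assumption and may differ
+# across transformers versions):
+#
+#   from transformers.pytorch_utils import Conv1D
+#   model = convert_conv1d_to_linear(model, Conv1D)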
diff --git a/deepspeed/compression/scheduler.py b/deepspeed/compression/scheduler.py
new file mode 100644
index 000000000000..02c8fb904af8
--- /dev/null
+++ b/deepspeed/compression/scheduler.py
@@ -0,0 +1,171 @@
+from .compress import get_module_name
+from .constants import *
+from .helper import recursive_getattr
+from deepspeed.utils import logger
+
+
+class compression_scheduler():
+ '''
+ Used to schedule different compression methods
+ '''
+ def __init__(self, model, compression_config):
+ self.model = model
+ self.compression_config = compression_config
+ self.make_init()
+ self.training_steps = 0
+ self.weight_quantization_enabled = False
+
+ self.verbose = {
+ WEIGHT_QUANTIZATION: False,
+ ACTIVATION_QUANTIZATION: False,
+ SPARSE_PRUNING: False,
+ HEAD_PRUNING: False,
+ ROW_PRUNING: False,
+ CHANNEL_PRUNING: False
+ }
+
+ def make_init(self):
+ self.different_compression_methods = {}
+ for method, method_content in self.compression_config.items():
+ if LAYER_REDUCTION in method:
+ continue
+ self.different_compression_methods[method] = {
+ TECHNIQUE_ENABLED: False,
+ SHARED_PARAMETERS: None,
+ DIFFERENT_GROUPS: []
+ }
+ exist_module_name = set()
+ shared_parameters = method_content[SHARED_PARAMETERS]
+ self.different_compression_methods[method][
+ TECHNIQUE_ENABLED] = shared_parameters[TECHNIQUE_ENABLED]
+ self.different_compression_methods[method][
+ SHARED_PARAMETERS] = shared_parameters
+
+ for group_name, method_parameters in method_content[DIFFERENT_GROUPS].items():
+ module_name_list = []
+ for key_word in method_parameters[DIFFERENT_GROUPS_MODULE_SCOPE]:
+ module_name, exist_module_name = get_module_name(group_name, self.model, key_word, exist_module_name, verbose=False)
+ module_name_list.extend(module_name)
+ if module_name_list:
+ self.different_compression_methods[method][DIFFERENT_GROUPS].append([
+ group_name,
+ module_name_list,
+ method_parameters.copy().pop('params')
+ ])
+
+ def check_weight_quantization(self):
+ # check weight quantization
+ wq = self.different_compression_methods[WEIGHT_QUANTIZATION]
+ if not wq[TECHNIQUE_ENABLED]:
+ return
+ else:
+ shared_parameters = wq[SHARED_PARAMETERS]
+ if self.training_steps >= shared_parameters[TECHNIQUE_SCHEDULE_OFFSET]:
+ for group_name, module_name_list, method_parameters in wq[DIFFERENT_GROUPS]:
+ for module_name in module_name_list:
+ module = recursive_getattr(self.model, module_name)
+ module.weight_quantization_enabled = True
+
+ if not self.verbose[WEIGHT_QUANTIZATION]:
+ logger.info(
+ f'Weight quantization is enabled at step {self.training_steps}')
+ self.weight_quantization_enabled = True
+ self.verbose[WEIGHT_QUANTIZATION] = True
+
+ def check_activation_quantization(self):
+ # check activation quantization
+ aq = self.different_compression_methods[ACTIVATION_QUANTIZATION]
+ if not aq[TECHNIQUE_ENABLED]:
+ return
+ else:
+ shared_parameters = aq[SHARED_PARAMETERS]
+ if self.training_steps >= shared_parameters[TECHNIQUE_SCHEDULE_OFFSET]:
+ for group_name, module_name_list, method_parameters in aq[DIFFERENT_GROUPS]:
+ for module_name in module_name_list:
+ module = recursive_getattr(self.model, module_name)
+ module.activation_quantization_enabled = True
+ if not self.verbose[ACTIVATION_QUANTIZATION]:
+ logger.info(
+ f'Activation quantization is enabled at step {self.training_steps}'
+ )
+ self.verbose[ACTIVATION_QUANTIZATION] = True
+
+ def check_sparse_pruning(self):
+ # check sparse pruning
+ sp = self.different_compression_methods[SPARSE_PRUNING]
+ if not sp[TECHNIQUE_ENABLED]:
+ return
+ else:
+ shared_parameters = sp[SHARED_PARAMETERS]
+ if self.training_steps >= shared_parameters[TECHNIQUE_SCHEDULE_OFFSET]:
+ for group_name, module_name_list, method_parameters in sp[DIFFERENT_GROUPS]:
+ for module_name in module_name_list:
+ module = recursive_getattr(self.model, module_name)
+ module.sparse_pruning_enabled = True
+ if not self.verbose[SPARSE_PRUNING]:
+ logger.info(
+ f'Sparse pruning is enabled at step {self.training_steps}')
+ self.verbose[SPARSE_PRUNING] = True
+
+ def check_head_pruning(self):
+ # check head pruning
+ hp = self.different_compression_methods[HEAD_PRUNING]
+ if not hp[TECHNIQUE_ENABLED]:
+ return
+ else:
+ shared_parameters = hp[SHARED_PARAMETERS]
+ if self.training_steps >= shared_parameters[TECHNIQUE_SCHEDULE_OFFSET]:
+ for group_name, module_name_list, method_parameters in hp[DIFFERENT_GROUPS]:
+ for module_name in module_name_list:
+ module = recursive_getattr(self.model, module_name)
+ module.head_pruning_enabled = True
+ if not self.verbose[HEAD_PRUNING]:
+ logger.info(f'Head pruning is enabled at step {self.training_steps}')
+ self.verbose[HEAD_PRUNING] = True
+
+ def check_row_pruning(self):
+ # check row pruning
+ rp = self.different_compression_methods[ROW_PRUNING]
+ if not rp[TECHNIQUE_ENABLED]:
+ return
+ else:
+ shared_parameters = rp[SHARED_PARAMETERS]
+ if self.training_steps >= shared_parameters[TECHNIQUE_SCHEDULE_OFFSET]:
+ for group_name, module_name_list, method_parameters in rp[DIFFERENT_GROUPS]:
+ for module_name in module_name_list:
+ module = recursive_getattr(self.model, module_name)
+ module.row_pruning_enabled = True
+ if not self.verbose[ROW_PRUNING]:
+ logger.info(f'Row pruning is enabled at step {self.training_steps}')
+ self.verbose[ROW_PRUNING] = True
+
+ def check_channel_pruning(self):
+ # check channel pruning
+ cp = self.different_compression_methods[CHANNEL_PRUNING]
+ if not cp[TECHNIQUE_ENABLED]:
+ return
+ else:
+ shared_parameters = cp[SHARED_PARAMETERS]
+ if self.training_steps >= shared_parameters[TECHNIQUE_SCHEDULE_OFFSET]:
+ for group_name, module_name_list, method_parameters in cp[DIFFERENT_GROUPS]:
+ for module_name in module_name_list:
+ module = recursive_getattr(self.model, module_name)
+ module.channel_pruning_enabled = True
+ if not self.verbose[CHANNEL_PRUNING]:
+ logger.info(
+ f'Channel pruning is enabled at step {self.training_steps}')
+ self.verbose[CHANNEL_PRUNING] = True
+
+ def check_all_modules(self):
+ # check all different compression methods we have
+ self.check_weight_quantization()
+ self.check_activation_quantization()
+ self.check_sparse_pruning()
+ self.check_head_pruning()
+ self.check_row_pruning()
+ self.check_channel_pruning()
+
+ def step(self, step_zero_check=False):
+ if not step_zero_check:
+ self.training_steps += 1
+ self.check_all_modules()
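+
+
+# Illustrative sketch of how this scheduler is driven (hypothetical training loop; it
+# assumes init_compression has already wrapped the target modules, and in practice the
+# DeepSpeed engine creates and steps the scheduler automatically):
+#
+#   scheduler = compression_scheduler(model, compression_config)
+#   for batch in dataloader:
+#       loss = model(batch)
+#       loss.backward()
+#       optimizer.step()
+#       scheduler.step()  # enables each technique once its schedule_offset is reached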
diff --git a/deepspeed/compression/utils.py b/deepspeed/compression/utils.py
new file mode 100644
index 000000000000..3d6724939c12
--- /dev/null
+++ b/deepspeed/compression/utils.py
@@ -0,0 +1,216 @@
+import torch
+from torch import autograd
+import math
+
+
+class TopKBinarizer(autograd.Function):
+ """
+ Top-k Binarizer.
+ Computes a binary mask M from a real value matrix S such that `M_{i,j} = 1` if and only if `S_{i,j}`
+ is among the k% highest values of S.
+ Implementation is inspired from:
+ https://github.com/yaozhewei/MLPruning
+ """
+ @staticmethod
+ def forward(ctx, inputs: torch.tensor, threshold: float, sigmoid: bool):
+ """
+ Args:
+ inputs (`torch.FloatTensor`)
+ The input matrix from which the binarizer computes the binary mask.
+ threshold (`float`)
+ The percentage of weights to keep (the rest is pruned).
+ `threshold` is a float between 0 and 1.
+ sigmoid (`bool`)
+ Whether to apply a sigmoid on the threshold
+ Returns:
+ mask (`torch.FloatTensor`)
+ Binary matrix of the same size as `inputs` acting as a mask (1 - the associated weight is
+ retained, 0 - the associated weight is pruned).
+ """
+ # Get the subnetwork by sorting the inputs and using the top threshold
+ if sigmoid:
+ threshold = torch.sigmoid(threshold).item()
+ ctx.sigmoid = sigmoid
+ mask = inputs.clone()
+
+ _, idx = inputs.flatten().sort(descending=True)
+ j = math.ceil(threshold * inputs.numel())
+
+ # flat_out and mask access the same memory.
+ flat_out = mask.flatten()
+ flat_out[idx[j:]] = 0.
+ flat_out[idx[:j]] = 1.
+ ctx.save_for_backward(mask)
+
+ return mask
+
+ @staticmethod
+ def backward(ctx, gradOutput):
+ mask, = ctx.saved_tensors
+ if ctx.sigmoid:
+ return gradOutput.clone(), ((gradOutput * mask).sum()).view(-1), None
+ else:
+ return gradOutput.clone(), None, None
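+
+# Illustrative sketch (made-up values): keep the top 30% of the scores as a binary mask.
+#
+#   scores = torch.randn(4, 4)
+#   mask = TopKBinarizer.apply(scores, 0.3, False)  # sigmoid=False -> 0.3 used as-is
+#   sparse_scores = scores * mask                   # 1 keeps an entry, 0 prunes it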
+
+
+class SymQuantizer(torch.autograd.Function):
+ """
+ Symmetric quantization
+ """
+ @staticmethod
+ def forward(ctx, input, num_bits, min_value=None, max_value=None, num_groups=1):
+ """
+ Args:
+ inputs (`torch.FloatTensor`)
+ The input which needs to be quantized
+ num_bits (int, >=4)
+ Number of bits to use for quantization
+ min_value/max_value (torch.FloatTensor)
+ Used for static activation quantization
+ num_groups (int)
+ How many groups to partition the quantization into
+ Returns:
+ quantized_input (`torch.FloatTensor`)
+ Quantized input
+ """
+ assert (min_value is None
+ and max_value is None) or (min_value is not None
+ and max_value is not None and num_groups == 1)
+ q_range = 2**num_bits
+ input_shape = input.shape
+ if min_value is None:
+ input = input.reshape(num_groups, -1)
+ max_input = torch.amax(torch.abs(input), dim=-1).view(num_groups, -1)
+ else:
+ max_input = torch.max(min_value.abs(), max_value).view(-1)
+
+ scale = 2 * max_input / q_range
+ output = (input / scale).round().clamp(-q_range // 2, q_range // 2 - 1) * scale
+ output = output.reshape(input_shape).contiguous()
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ grad_input = grad_output.clone()
+ return grad_input, None, None, None, None
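+
+# Illustrative sketch (made-up tensor): fake-quantize a weight to 8 bits with one
+# quantization group per row (16 rows -> num_groups=16).
+#
+#   w = torch.randn(16, 64)
+#   w_q = SymQuantizer.apply(w, 8, None, None, 16)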
+
+
+class AsymQuantizer(torch.autograd.Function):
+ """
+ Asymmetric quantization
+ """
+ @staticmethod
+ def forward(ctx, input, num_bits, min_value=None, max_value=None, num_groups=1):
+ """
+ Args:
+ inputs (`torch.FloatTensor`)
+ The input which needs to be quantized
+ num_bits (int, >=4)
+ Number of bits to use for quantization
+ min_value/max_value (torch.FloatTensor)
+ Used for static activation quantization
+ num_groups (int)
+ How many groups to partition the quantization into
+ Returns:
+ quantized_input (`torch.FloatTensor`)
+ Quantized input
+ """
+
+ assert (min_value is None
+ and max_value is None) or (min_value is not None
+ and max_value is not None and num_groups == 1)
+ q_range = 2**num_bits
+ input_shape = input.shape
+ if min_value is None:
+ input = input.reshape(num_groups, -1)
+ min_value = input.amin(dim=-1, keepdim=True)
+ max_value = input.amax(dim=-1, keepdim=True)
+
+ scale = (max_value - min_value) / q_range
+ zero_point = (min_value / scale).round() * scale
+
+ output = (
+ (input - zero_point) / scale).round().clamp(0,
+ q_range - 1) * scale + zero_point
+ output = output.reshape(input_shape).contiguous()
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ grad_input = grad_output.clone()
+ return grad_input, None, None, None, None
+
+
+class TernaryQuantizer(torch.autograd.Function):
+ """
+ Ternary quantization
+ """
+ @staticmethod
+ def forward(ctx, input, num_bits, min_value=None, max_value=None, num_groups=1):
+ """
+ Args:
+ inputs (`torch.FloatTensor`)
+ The input which needs to be quantized
+ num_bits (int)
+ Dummy variable
+ min_value/max_value (torch.FloatTensor)
+ Used for static activation quantization; for now they are dummy variables
+ num_groups (int)
+ How many groups to partition the quantization into
+ Returns:
+ quantized_input (`torch.FloatTensor`)
+ Quantized input
+ """
+
+ assert (min_value is None and max_value is None)
+ input_flat = input.reshape(num_groups, -1)
+ n = input_flat.shape[1]
+ m = input_flat.norm(p=1, dim=1).div(n)
+ thres = (0.7 * m).view(-1, 1)
+ pos = (input_flat > thres).type(input.type())
+ neg = (input_flat < -thres).type(input.type())
+ mask = (input_flat.abs() > thres).type(input.type())
+ alpha = ((mask * input_flat).abs().sum(dim=1) / mask.sum(dim=1)).view(-1, 1)
+ output = alpha * pos - alpha * neg
+ output = output.reshape(input.shape).contiguous()
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ grad_input = grad_output.clone()
+ return grad_input, None, None, None, None
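+
+# Illustrative sketch: ternary {-alpha, 0, +alpha} quantization per group; the
+# num_bits argument (2 below) is a dummy and is not used by this quantizer.
+#
+#   w_t = TernaryQuantizer.apply(torch.randn(8, 32), 2, None, None, 8)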
+
+
+class BinaryQuantizer(torch.autograd.Function):
+ """
+ Binary quantization
+ """
+ @staticmethod
+ def forward(ctx, input, num_bits, min_value=None, max_value=None, num_groups=1):
+ """
+ Args:
+ inputs (`torch.FloatTensor`)
+ The input which needs to be quantized
+ num_bits (int)
+ Dummy variable
+ min_value/max_value (torch.FloatTensor)
+ Used for static activation quantization; for now they are dummy variables
+ num_groups (int)
+ How many groups to partition the quantization into
+ Returns:
+ quantized_input (`torch.FloatTensor`)
+ Quantized input
+ """
+
+ assert (min_value is None and max_value is None)
+ input_flat = input.reshape(num_groups, -1)
+ n = input_flat.shape[1]
+ m = input_flat.norm(p=1, dim=1, keepdim=True).div(n)
+ output = input_flat.sign().mul(m)
+ output = output.reshape(input.shape).contiguous()
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ grad_input = grad_output.clone()
+ return grad_input, None, None, None, None
diff --git a/deepspeed/runtime/config.py b/deepspeed/runtime/config.py
index 4571cbdf7056..35d50d00f36b 100755
--- a/deepspeed/runtime/config.py
+++ b/deepspeed/runtime/config.py
@@ -45,7 +45,8 @@
from ..profiling.config import DeepSpeedFlopsProfilerConfig
from ..autotuning.config import DeepSpeedAutotuningConfig
-
+from ..compression.config import get_compression_config, get_quantize_enabled
+from ..compression.constants import *
from .swap_tensor.aio_config import get_aio_config
TENSOR_CORE_ALIGN_SIZE = 8
@@ -264,73 +265,6 @@ def get_gradient_predivide_factor(param_dict):
GRADIENT_PREDIVIDE_FACTOR_DEFAULT)
-def get_quantize_enabled(param_dict):
- if QUANTIZE_TRAINING in param_dict.keys():
- return get_scalar_param(
- param_dict[QUANTIZE_TRAINING],
- QUANTIZE_TRAINING_ENABLED,
- QUANTIZE_TRAINING_ENABLED_DEFAULT,
- )
- else:
- return False
-
-
-def get_quantize_training(param_dict):
- if QUANTIZE_TRAINING in param_dict.keys():
- return (
- (param_dict[QUANTIZE_TRAINING][QUANTIZE_BITS][TARGET_BITS]),
- (param_dict[QUANTIZE_TRAINING][QUANTIZE_BITS][START_BITS]
- if START_BITS in param_dict[QUANTIZE_TRAINING][QUANTIZE_BITS].keys() else
- QUANTIZE_START_BITS_DEFAULT),
- (param_dict[QUANTIZE_TRAINING][QUANTIZE_SCHEDULE][QUANTIZE_PERIOD]
- if QUANTIZE_SCHEDULE in param_dict[QUANTIZE_TRAINING].keys() else
- QUANTIZE_PERIOD_DEFAULT),
- (param_dict[QUANTIZE_TRAINING][QUANTIZE_SCHEDULE][SCHEDULE_OFFSET]
- if QUANTIZE_SCHEDULE in param_dict[QUANTIZE_TRAINING].keys() and
- SCHEDULE_OFFSET in param_dict[QUANTIZE_TRAINING][QUANTIZE_SCHEDULE].keys()
- else QUANTIZE_OFFSET_DEFAULT),
- (param_dict[QUANTIZE_TRAINING][QUANTIZE_GROUPS] if QUANTIZE_GROUPS
- in param_dict[QUANTIZE_TRAINING].keys() else QUANTIZE_GROUPS_DEFAULT),
- (param_dict[QUANTIZE_TRAINING][FP16_MIXED_QUANTIZE]
- [FP16_MIXED_QUANTIZE_ENABLED]
- if FP16_MIXED_QUANTIZE in param_dict[QUANTIZE_TRAINING].keys()
- and FP16_MIXED_QUANTIZE_ENABLED
- in param_dict[QUANTIZE_TRAINING][FP16_MIXED_QUANTIZE].keys() else
- FP16_MIXED_QUANTIZE_ENABLED_DEFAULT),
- (param_dict[QUANTIZE_TRAINING][FP16_MIXED_QUANTIZE][QUANTIZE_CHANGE_RATIO]
- if FP16_MIXED_QUANTIZE in param_dict[QUANTIZE_TRAINING].keys()
- and QUANTIZE_CHANGE_RATIO
- in param_dict[QUANTIZE_TRAINING][FP16_MIXED_QUANTIZE].keys() else
- QUANTIZE_CHANGE_RATIO_DEFAULT),
- (1 if QUANTIZE_ALGO in param_dict[QUANTIZE_TRAINING]
- and QUANTIZE_TYPE in param_dict[QUANTIZE_TRAINING][QUANTIZE_ALGO].keys()
- and param_dict[QUANTIZE_TRAINING][QUANTIZE_ALGO][QUANTIZE_TYPE]
- == QUANTIZE_ASYMMETRIC else QUANTIZE_TYPE_DEFAULT),
- (1 if QUANTIZE_ALGO in param_dict[QUANTIZE_TRAINING] and QUANTIZE_ROUNDING
- in param_dict[QUANTIZE_TRAINING][QUANTIZE_ALGO].keys()
- and param_dict[QUANTIZE_TRAINING][QUANTIZE_ALGO][QUANTIZE_ROUNDING]
- == STOCHASTIC_ROUNDING else QUANTIZE_ROUNDING_DEFAULT),
- (param_dict[QUANTIZE_TRAINING][QUANTIZE_VERBOSE] if QUANTIZE_VERBOSE
- in param_dict[QUANTIZE_TRAINING].keys() else QUANTIZE_VERBOSE_DEFAULT),
- (param_dict[QUANTIZE_TRAINING][QUANTIZER_KERNEL] if QUANTIZER_KERNEL
- in param_dict[QUANTIZE_TRAINING].keys() else QUANTIZER_KERNEL_DEFAULT),
- )
- else:
- return (
- QUANTIZE_TARGET_BITS_DEFAULT,
- QUANTIZE_START_BITS_DEFAULT,
- QUANTIZE_PERIOD_DEFAULT,
- QUANTIZE_OFFSET_DEFAULT,
- QUANTIZE_GROUPS_DEFAULT,
- FP16_MIXED_QUANTIZE_ENABLED_DEFAULT,
- QUANTIZE_CHANGE_RATIO_DEFAULT,
- QUANTIZE_TYPE_DEFAULT,
- QUANTIZE_ROUNDING_DEFAULT,
- QUANTIZE_VERBOSE_DEFAULT,
- QUANTIZER_KERNEL_DEFAULT,
- )
-
-
def get_steps_per_print(param_dict):
return get_scalar_param(param_dict, STEPS_PER_PRINT, STEPS_PER_PRINT_DEFAULT)
@@ -621,6 +555,7 @@ def get_memory_breakdown(param_dict):
def get_eigenvalue_config(param_dict):
if get_quantize_enabled(param_dict):
param_dict = param_dict[QUANTIZE_TRAINING]
+ assert not get_eigenvalue_enabled(param_dict), "Eigenvalue based MoQ is temporarily disabled"
return (
get_eigenvalue_enabled(param_dict),
get_eigenvalue_verbose(param_dict),
@@ -885,20 +820,7 @@ def _initialize_params(self, param_dict):
self.initial_dynamic_scale = get_initial_dynamic_scale(param_dict)
self.dynamic_loss_scale_args = get_dynamic_loss_scale_args(param_dict)
- self.quantize_training_enabled = get_quantize_enabled(param_dict)
- (
- self.quantize_target_bits,
- self.quantize_start_bits,
- self.quantize_period,
- self.quantize_offset,
- self.quantize_groups,
- self.fp16_mixed_quantize,
- self.quantize_change_rate,
- self.quantize_type,
- self.quantize_rounding,
- self.quantize_verbose,
- self.use_quantizer_kernel,
- ) = get_quantize_training(param_dict)
+ self.compression_config = get_compression_config(param_dict)
self.optimizer_name = get_optimizer_name(param_dict)
if (self.optimizer_name is not None
diff --git a/deepspeed/runtime/constants.py b/deepspeed/runtime/constants.py
index 88b055b3e210..250240089f89 100755
--- a/deepspeed/runtime/constants.py
+++ b/deepspeed/runtime/constants.py
@@ -370,44 +370,6 @@ class ValidationMode:
ValidationMode.FAIL
]
-#########################################
-# Quantization
-#########################################
-QUANTIZE_TRAINING = "quantize_training"
-QUANTIZE_BITS = "quantize_bits"
-START_BITS = "start_bits"
-TARGET_BITS = "target_bits"
-QUANTIZER_KERNEL = "quantizer_kernel"
-QUANTIZE_SCHEDULE = "quantize_schedule"
-QUANTIZE_PERIOD = "quantize_period"
-SCHEDULE_OFFSET = "schedule_offset"
-QUANTIZE_GROUPS = "quantize_groups"
-FP16_MIXED_QUANTIZE = "fp16_mixed_quantize"
-QUANTIZE_CHANGE_RATIO = "quantize_change_ratio"
-FP16_MIXED_QUANTIZE_ENABLED = "enabled"
-QUANTIZE_VERBOSE = "quantize_verbose"
-QUANTIZE_ALGO = "quantize_algo"
-QUANTIZE_TYPE = "q_type"
-QUANTIZE_SYMMETRIC = "symmetric"
-QUANTIZE_ASYMMETRIC = "asymmetric"
-STOCHASTIC_ROUNDING = "stochastic"
-NEAREST_ROUNDING = "nearest"
-QUANTIZE_ROUNDING = "rounding"
-QUANTIZE_TRAINING_ENABLED = "enabled"
-QUANTIZE_TRAINING_ENABLED_DEFAULT = False
-QUANTIZE_TRAINING_DEFAULT = False
-QUANTIZE_START_BITS_DEFAULT = 16
-QUANTIZE_TARGET_BITS_DEFAULT = 8
-QUANTIZER_KERNEL_DEFAULT = False
-QUANTIZE_PERIOD_DEFAULT = 1000
-QUANTIZE_OFFSET_DEFAULT = 1000
-QUANTIZE_GROUPS_DEFAULT = 1
-QUANTIZE_TYPE_DEFAULT = 0 #symmetric
-QUANTIZE_ROUNDING_DEFAULT = 0 #nearest
-FP16_MIXED_QUANTIZE_ENABLED_DEFAULT = False
-QUANTIZE_CHANGE_RATIO_DEFAULT = 0.001
-QUANTIZE_VERBOSE_DEFAULT = False
-
#########################################
# Drop the last incomplete Batch
# #########################################
diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py
index 7b6669d961fd..51ed1b8e5884 100644
--- a/deepspeed/runtime/engine.py
+++ b/deepspeed/runtime/engine.py
@@ -43,6 +43,27 @@
from deepspeed.runtime.constants import \
ROUTE_TRAIN, ROUTE_PREDICT, ROUTE_EVAL, \
PLD_THETA, PLD_GAMMA, BFLOAT16, FP16
+
+from deepspeed.compression import compression_scheduler
+from deepspeed.compression.constants import \
+ SHARED_PARAMETERS, \
+ WEIGHT_QUANTIZE_IN_FORWARD_ENABLED, \
+ WEIGHT_QUANTIZATION, WEIGHT_QUANTIZE_ENABLED, \
+ WEIGHT_QUANTIZE_SCHEDULE_OFFSET, \
+ WEIGHT_QUANTIZE_GROUPS, \
+ WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE, \
+ WEIGHT_QUANTIZE_CHANGE_RATIO, \
+ WEIGHT_QUANTIZE_TYPE, \
+ WEIGHT_QUANTIZE_ROUNDING, \
+ WEIGHT_QUANTIZE_VERBOSE, \
+ WEIGHT_QUANTIZE_KERNEL, \
+ ACTIVATION_QUANTIZATION, \
+ SPARSE_PRUNING, \
+ ROW_PRUNING, \
+ HEAD_PRUNING, \
+ CHANNEL_PRUNING
+
from deepspeed.runtime.zero.constants import \
ZERO_OPTIMIZATION_OPTIMIZER_STATES, ZERO_OPTIMIZATION_GRADIENTS, ZERO_OPTIMIZATION_WEIGHTS
from deepspeed.checkpoint.constants import OPTIMIZER_STATE_DICT
@@ -390,7 +411,6 @@ def _get_model_parameters(self):
def get_batch_info(self):
"""Get all training batch related settings.
-
Returns:
train_batch_size (int): The effective training batch size. This is the amount of data
samples that leads to one step of model update.
@@ -429,7 +449,6 @@ def set_train_batch_size(self, train_batch_size):
def get_global_grad_norm(self) -> float:
"""Return the 2-norm of all gradients. If there is model parallelism,
the norm will be global.
-
The computed norm will be cached and reused until the next step() pass.
.. note::
In the presence of model parallelism, this is a collective call
@@ -592,18 +611,24 @@ def scheduler_params(self):
def quantize_training(self):
return (
- self._config.quantize_training_enabled,
- self._config.quantize_target_bits,
- self._config.quantize_start_bits,
- self._config.quantize_period,
- self._config.quantize_offset,
- self._config.quantize_groups,
- self._config.fp16_mixed_quantize,
- self._config.quantize_change_rate,
- self._config.quantize_type,
- self._config.quantize_rounding,
- self._config.quantize_verbose,
- self._config.use_quantizer_kernel,
+ self._config.compression_config[WEIGHT_QUANTIZATION][SHARED_PARAMETERS]
+ [WEIGHT_QUANTIZE_IN_FORWARD_ENABLED],
+ self._config.compression_config[WEIGHT_QUANTIZATION][SHARED_PARAMETERS]
+ [WEIGHT_QUANTIZE_ENABLED],
+ self._config.compression_config[WEIGHT_QUANTIZATION][SHARED_PARAMETERS]
+ [WEIGHT_QUANTIZE_GROUPS],
+ self._config.compression_config[WEIGHT_QUANTIZATION][SHARED_PARAMETERS]
+ [WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE],
+ self._config.compression_config[WEIGHT_QUANTIZATION][SHARED_PARAMETERS]
+ [WEIGHT_QUANTIZE_CHANGE_RATIO],
+ self._config.compression_config[WEIGHT_QUANTIZATION][SHARED_PARAMETERS]
+ [WEIGHT_QUANTIZE_TYPE],
+ self._config.compression_config[WEIGHT_QUANTIZATION][SHARED_PARAMETERS]
+ [WEIGHT_QUANTIZE_ROUNDING],
+ self._config.compression_config[WEIGHT_QUANTIZATION][SHARED_PARAMETERS]
+ [WEIGHT_QUANTIZE_VERBOSE],
+ self._config.compression_config[WEIGHT_QUANTIZATION][SHARED_PARAMETERS]
+ [WEIGHT_QUANTIZE_KERNEL],
)
def zero_optimization(self):
@@ -1120,6 +1145,7 @@ def _configure_optimizer(self, client_optimizer, model_parameters):
log_dist("DeepSpeed Final Optimizer = {}".format(self.optimizer_name()),
ranks=[0])
+ self.compression_scheduler = self._configure_compression_scheduler()
self.quantizer = self._configure_quantization()
def _configure_basic_optimizer(self, model_parameters):
@@ -1202,13 +1228,13 @@ def _configure_basic_optimizer(self, model_parameters):
optimizer = torch_optimizer(model_parameters, **optimizer_parameters)
return optimizer
+ def _configure_compression_scheduler(self):
+ return compression_scheduler(self.module, self._config.compression_config)
+
def _configure_quantization(self):
(
+ quantize_weight_in_forward,
quantize_enabled,
- q_target_bits,
- q_start_bits,
- q_period,
- q_offset,
q_groups,
q_mixed_fp16,
q_change_ratio,
@@ -1217,15 +1243,13 @@ def _configure_quantization(self):
q_verbose,
use_quantizer_kernel,
) = self.quantize_training()
+ if quantize_enabled and not quantize_weight_in_forward:
+ assert self.fp16_enabled(), "MoQ (quantize in optimization step) weight quantization is only supported for FP16"
quantizer = None
- if quantize_enabled:
+ if quantize_enabled and not quantize_weight_in_forward:
from deepspeed.runtime.quantize import Quantizer
quantizer = Quantizer(
- q_target_bits,
- q_start_bits,
- q_period,
- q_offset,
q_groups,
q_mixed_fp16,
q_change_ratio,
@@ -1465,11 +1489,9 @@ def dataloader_drop_last(self):
def was_step_applied(self) -> bool:
"""Returns True if the latest ``step()`` produced in parameter updates.
-
Note that a ``False`` return is not an error condition. Steps are frequently
no-ops, such as between gradient accumulation boundaries or when overflows
occur.
-
Returns:
bool: Whether the latest ``step()`` modified model parameters.
"""
@@ -1569,11 +1591,26 @@ def forward(self, *inputs, **kwargs):
== self.flops_profiler_profile_step()
and self.global_rank == 0)
+ # Check at step 0 whether any compression technique (e.g., weight quantization) should already be enabled
+ if self.global_steps == 0 and hasattr(self, "compression_scheduler"):
+ self.compression_scheduler.step(step_zero_check=True)
+ if self.quantizer:
+ tensor_to_quantize = self.optimizer.bit16_groups if self.zero_optimization_stage(
+ ) == 2 else self.optimizer.fp16_groups
+ if self.compression_scheduler.weight_quantization_enabled:
+ self.quantizer.quantize(
+ tensor_to_quantize,
+ (self.optimizer.overflow if self.fp16_enabled() else False),
+ self.eigenvalue_enabled(),
+ None,
+ )
+
if flops_profiler_active:
self.flops_profiler.start_profile(ignore_list=None)
- if self.module.training and self.progressive_layer_drop:
- kwargs.update(self.progressive_layer_drop.get_state())
+ if self.module.training:
+ if self.progressive_layer_drop:
+ kwargs.update(self.progressive_layer_drop.get_state())
if self.__class__.__name__ != "PipelineEngine":
# TODO: The above if condition is a HACK since for PipelineEngine
@@ -1668,7 +1705,6 @@ def allreduce_gradients(self, bucket_size=MEMORY_OPT_ALLREDUCE_SIZE):
@instrument_w_nvtx
def backward(self, loss, allreduce_gradients=True, release_loss=False):
r"""Execute backward pass on the loss
-
Arguments:
loss: Torch tensor on which to execute backward propagation
allreduce_gradients: is deprecated, ignored, and will soon be removed'
@@ -1767,9 +1803,7 @@ def set_gradient_accumulation_boundary(self, is_boundary):
value before each forward/backward. The final forward/backward should have the
boundary state set to True. This style allows client code to only call engine.step() once after all
the gradient accumulation passes are complete. See example below:
-
.. code-block:: python
-
engine.set_gradient_accumulation_boundary(False)
for _ in range(gradient_accumulation_steps - 1):
micro_batch = next(data_loader)
@@ -1780,7 +1814,6 @@ def set_gradient_accumulation_boundary(self, is_boundary):
loss = engine(micro_batch)
engine.backward(loss)
engine.step()
-
Arguments:
is_boundary (bool): are we at a gradient accumulation boundary or not?
"""
@@ -1818,17 +1851,15 @@ def _take_model_step(self, lr_kwargs, block_eigenvalue={}):
# Quantize the updated parameter if there is no overflow
if self.quantizer:
- if self.fp16_enabled():
- tensor_to_quantize = self.optimizer.bit16_groups if self.zero_optimization_stage(
- ) == 2 else self.optimizer.fp16_groups
- else:
- tensor_to_quantize = self.optimizer.param_groups
- self.quantizer.quantize(
- tensor_to_quantize,
- (self.optimizer.overflow if self.fp16_enabled() else False),
- self.eigenvalue_enabled(),
- block_eigenvalue,
- )
+ tensor_to_quantize = self.optimizer.bit16_groups if self.zero_optimization_stage(
+ ) == 2 else self.optimizer.fp16_groups
+ if self.compression_scheduler.weight_quantization_enabled:
+ self.quantizer.quantize(
+ tensor_to_quantize,
+ (self.optimizer.overflow if self.fp16_enabled() else False),
+ self.eigenvalue_enabled(),
+ block_eigenvalue,
+ )
# zero grad in basic optimizer could be unreliable and may not exhibit
# the behaviour that we want
if self.bfloat16_enabled():
@@ -1853,6 +1884,7 @@ def _take_model_step(self, lr_kwargs, block_eigenvalue={}):
if overflow:
self.skipped_steps += 1
else:
+ self.compression_scheduler.step()
if self.lr_scheduler is not None:
try:
self.lr_scheduler.step(**(lr_kwargs or {}))
@@ -2448,7 +2480,6 @@ def load_checkpoint(self,
load_module_only=False,
custom_load_fn=None):
"""Load training checkpoint
-
Arguments:
load_dir: Required. Directory to load the checkpoint from
tag: Checkpoint tag used as a unique identifier for checkpoint, if not provided will attempt to load tag in 'latest' file
@@ -2459,11 +2490,8 @@ def load_checkpoint(self,
custom_load_fn: Optional. Custom model load function.
Returns:
A tuple of ``load_path`` and ``client_state``.
-
*``load_path``: Path of the loaded checkpoint. ``None`` if loading the checkpoint failed.
-
*``client_state``: State dictionary used for loading required training states in the client code.
-
Important: under ZeRO3, one cannot load checkpoint with ``engine.load_checkpoint()`` right
after ``engine.save_checkpoint()``. It is because ``engine.module`` is partitioned, and
``load_checkpoint()`` wants a pristine model. If insisting to do so, please reinitialize engine
@@ -2771,14 +2799,12 @@ def _checkpoint_tag_validation(self, tag):
def save_checkpoint(self, save_dir, tag=None, client_state={}, save_latest=True):
r"""Save training checkpoint
-
Arguments:
save_dir: Required. Directory for saving the checkpoint
tag: Optional. Checkpoint tag used as a unique identifier for the checkpoint, global step is
used if not provided. Tag name must be the same across all ranks.
client_state: Optional. State dictionary used for saving required training states in the client code.
save_latest: Optional. Save a file 'latest' pointing to the latest saved checkpoint.
-
Important: all processes must call this method and not just the process with rank 0. It is
because each process needs to save its master weights and scheduler+optimizer states. This
method will hang waiting to synchronize with other processes if it's called just for the
@@ -3028,11 +3054,9 @@ def _get_zero_param_shapes(self):
optimizer. the names are exactly as in state_dict. The order is absolutely important, since
the saved data is just flattened data with no identifiers and requires reconstruction in the
same order it was saved.
-
We can't rely on self.module.named_parameters() to get the saved tensors, as some params
will be missing and others unsaved and then it'd be impossible to reconstruct state_dict
from the flattened weights.
-
optimizer.bit16_groups seems to be the easiest to use as it's in all zeroX versions.
"""
param_group_shapes = []
@@ -3092,21 +3116,15 @@ def _save_zero_checkpoint(self, save_path, tag):
def _zero3_consolidated_16bit_state_dict(self):
"""
-
Get a full non-partitioned state_dict with fp16 weights on cpu.
-
Important: this function must be called on all ranks and not just rank 0.
-
This is similar to nn.Module.state_dict (modelled after _save_to_state_dict), but:
-
1. consolidates the weights from different partitions on gpu0
2. works on one layer at a time to require as little gpu0 memory as possible, by
moving the already consolidated weights to cpu
3. takes care to keep the shared params shared when gradually copying the params to cpu
-
Returns:
a consolidated fp16 ``state_dict`` on cpu on rank 0, ``None`` on other ranks
-
"""
if not self.zero_optimization_partition_weights():
raise ValueError("this function requires ZeRO-3 mode")
@@ -3168,21 +3186,16 @@ def save_fp16_model(self, save_dir, save_filename="pytorch_model.bin"):
def save_16bit_model(self, save_dir, save_filename="pytorch_model.bin"):
r"""Save 16bit model weights
-
This method saves the 16bit model weights at the desired destination.
-
Arguments:
save_dir: Required. Directory for saving the model
save_filename: Optional. Filename to save to. Defaults to ``pytorch_model.bin``
-
Returns:
``True`` when a model has been saved, ``False`` otherwise. It will not be saved if
stage3_gather_16bit_weights_on_model_save is ``False``.
-
Important: all processes must call this method and not just the process with rank 0. It is
because the processes need to work in sync to gather the weights. This method will hang
waiting to synchronize with other processes if it's called just for the process with rank 0.
-
"""
path = os.path.join(save_dir, save_filename)
diff --git a/deepspeed/runtime/quantize.py b/deepspeed/runtime/quantize.py
index 05fc50201b77..d264ddf2a459 100755
--- a/deepspeed/runtime/quantize.py
+++ b/deepspeed/runtime/quantize.py
@@ -1,20 +1,15 @@
import torch
import math
from deepspeed.utils import log_dist
from deepspeed.utils import logger
from deepspeed.ops.quantizer import ds_quantizer
-# number of 2-dimensional parameters in a layer
-# this is set for transformer-based models
TWO_D_PARAMS = 6
class Quantizer(object):
def __init__(self,
- q_target_bits=8,
- q_start_bits=16,
- q_period=100,
- q_offset=100,
q_groups=1,
q_mixed_fp16=False,
q_change_ratio=0.01,
@@ -25,17 +20,11 @@ def __init__(self,
use_quantizer_kernel=False,
layer_num=0):
- self.q_target_bits = q_target_bits
-
- self.q_start_bits = [q_start_bits] * (layer_num if layer_num != 0 else 1)
- self.q_period = [q_period] * (layer_num if layer_num != 0 else 1)
- self.q_offset = q_offset
self.q_groups = q_groups
self.q_mixed_fp16 = q_mixed_fp16
self.q_change_ratio = q_change_ratio
self.q_type = q_type
self.qsteps = 0
- self.q_init_period = q_period
self.quantize_real_ratio = 1.000
self.q_verbose = q_verbose
self.q_eigenvalue = q_eigenvalue
@@ -44,6 +33,7 @@ def __init__(self,
self.layer_num = layer_num
def any_precision_switch(self):
+ # Temporarily disabled functionality
if self.layer_num == 0:
return True
result = False
@@ -70,54 +60,69 @@ def quantize(self,
for i in range(len(parameter_group)):
for p in parameter_group[i]:
- if len(p.size()) > 1:
+ if len(p.size()) > 1 and hasattr(p, "start_bits") and p.start_bits:
param_id = id(p)
- eigenvalue, layer_id = block_eigenvalue[param_id] if param_id in block_eigenvalue else (None, 0)
+ if block_eigenvalue is None:
+ eigenvalue, layer_id = None, 0
+ else:
+ eigenvalue, layer_id = block_eigenvalue[param_id] if param_id in block_eigenvalue else (None, 0)
if eigenvalue is not None:
factor = 1 + math.floor(eigenvalue * 4)
p.data = self.compute_quantization(p.data, layer_id, factor)
else:
- p.data = self.compute_quantization(p.data, layer_id)
+ p.data = self.compute_quantization(p, layer_id)
def step(self):
- self.qsteps += (TWO_D_PARAMS * (self.layer_num if self.layer_num != 0 else 1))
+ self.qsteps += 1
+
+ def quantize_highbit(self, inputs, num_bits):
+
+ q_range = 2**num_bits
+ input_flat = inputs.reshape(self.q_groups, -1)
+ g_min = input_flat.amin(dim=-1, keepdim=True)
+ g_max = input_flat.amax(dim=-1, keepdim=True)
- def sr_quantize(self, input_flat, input_g, scale):
# Random number generator (Uniform)
- p = torch.cuda.FloatTensor(input_flat.size(),
- device=input_flat.device).uniform_()
- p = torch.split(p, p.size(0) // self.q_groups)
- add_s = torch.zeros_like(input_flat)
- add_s = torch.split(add_s, add_s.size(0) // self.q_groups)
-
- scale = [q_range / (2 * max(g.max(), g.min().abs())) for g in input_g]
- # Quantize with INT rounding
- input_flat = [(g * s).int().float() / s for (g, s) in zip(input_g, scale)]
- # Compute the error
- error = [((g - q).abs() / s) for (g, s, q) in zip(input_g, scale, input_flat)]
- # Stochastic Rounding
- add_s = [
- a_s.masked_fill_(pg < err_g,
- 1 / s) for (a_s,
- pg,
- err_g,
- s) in zip(add_s,
- p,
- error,
- scale)
- ]
- add_s = [
- a_s * (g > 0).float() - a_s * (g < 0).float() for a_s,
- g in zip(add_s,
- input_flat)
- ]
- input_flat = [((q + a_s) * s).clamp(-(q_range >> 1),
- (q_range >> 1) - 1) / s for q,
- a_s,
- s in zip(input_flat,
- add_s,
- scale)]
- return input_flat
+ if self.q_rounding == 'nearest':
+ p = 0.
+ else:
+ p = input_flat.new(input_flat.shape).uniform_(-0.5, 0.5)
+
+ if self.q_type == 'symmetric':
+ scale = 2 * torch.max(torch.abs(g_min), torch.abs(g_max)) / q_range
+ zero_point = 0.
+ input_flat = (input_flat / scale + p).round().clamp(
+ -(q_range >> 1),
+ (q_range >> 1) - 1) * scale
+ elif self.q_type == 'asymmetric':
+ scale = (g_max - g_min) / q_range
+ zero_point = (g_min / scale).round() * scale
+ input_flat = ((input_flat - zero_point) / scale + p).round().clamp(
+ 0,
+ (q_range - 1)) * scale + zero_point
+ output = input_flat.reshape(inputs.shape).contiguous()
+ return output
+
+ def quantize_tenary(self, inputs):
+ input_flat = inputs.reshape(self.q_groups, -1)
+ n = input_flat.shape[1]
+ m = input_flat.norm(p=1, dim=1).div(n)
+ thres = (0.7 * m).view(-1, 1) #.expand_as(input_flat)
+ pos = (input_flat > thres).type(inputs.type())
+ neg = (input_flat < -thres).type(inputs.type())
+ mask = (input_flat.abs() > thres).type(inputs.type())
+ alpha = ((mask * input_flat).abs().sum(dim=1) / mask.sum(dim=1)).view(-1, 1)
+ output = alpha * pos - alpha * neg
+ output = output.reshape(inputs.shape).contiguous()
+ return output
+
+ def quantize_binary(self, inputs):
+ input_flat = inputs.reshape(self.q_groups, -1)
+ n = input_flat.shape[1]
+ m = input_flat.norm(p=1, dim=1, keepdim=True).div(n)
+ output = input_flat.sign().mul(m)
+ output = output.reshape(inputs.shape).contiguous()
+ return output
def mixed_fp16_quantize(self, input, input_q, index):
if self.q_mixed_fp16 and self.q_start_bits[index] >= (self.q_target_bits - 1):
@@ -131,90 +136,49 @@ def compute_quantization(self, input, index=0, factor=1):
# when reducing 1 bit at each period, we increase the period
# to go slowly toward the target quantization bits
# the period and starting bit can be configured
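+ # e.g. (illustrative numbers, factor=1): with start_bits=12, target_bits=8 and an
+ # initial q_period of 100, the precision drops 12->11 at step 100, 11->10 at step 200,
+ # 10->9 at step 400, and 9->8 at step 800, since q_period doubles after each drop.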
- if self.q_offset > 0:
- if self.qsteps >= self.q_offset:
- self.q_offset = 0
- self.qsteps = 0
- else:
- return input
- if self.q_start_bits[index] != self.q_target_bits:
- if self.qsteps >= self.q_period[index]:
+ if input.start_bits != input.target_bits:
+ if self.qsteps >= input.q_period:
self.quantize_real_ratio = 1.0
- if self.q_eigenvalue:
- self.q_period[index] <<= 1
- self.q_period[index] *= factor
- self.q_start_bits[index] -= 1
- else:
- for i in range(len(self.q_start_bits)):
- self.q_start_bits[i] -= 1
- self.q_period[i] <<= 1
+ input.q_period <<= 1
+ input.q_period *= factor
+ input.start_bits -= 1
if self.q_verbose:
logger.info(
- f'Quantization settings: current bit-precision = {self.q_start_bits[index]}, step = {self.qsteps}, quantization period = {self.q_period[index]}, index = {index}'
+ f'Quantization settings: current bit-precision = {input.start_bits}, step = {self.qsteps}, quantization period = {input.q_period}, index = {index}'
)
- assert (self.q_start_bits[index] >= self.q_target_bits), \
+ assert (input.start_bits >= input.target_bits), \
'Quantization bit is lower than target precision bits!'
- # quantize the weights base on the selected bits and the value-range
- if not self.use_quantizer_kernel:
- q_range = 2**self.q_start_bits[index]
- input_flat = input.view(-1)
- input_g = torch.split(input_flat, input_flat.size(0) // self.q_groups)
- if self.q_type == 0: #symmetric
- if self.use_quantizer_kernel:
- input_q = ds_quantizer(input.clone(),
- self.q_groups,
- self.q_start_bits[index])
- else:
- scale = [q_range / (2 * max(g.max(), g.min().abs())) for g in input_g]
- if self.q_rounding == 0: # Nearest value rounding
- input_flat = [(g * s).round().clamp(-(q_range >> 1),
- (q_range >> 1) - 1) / s for g,
- s in zip(input_g,
- scale)]
- else: # Stochastic Rounding
- if self.use_quantizer_kernel:
- input_q = ds_quantizer(input.clone(),
- self.q_groups,
- self.q_start_bits[index],
- sr=True)
- else:
- input_flat = self.sr_quantize(input_flat, input_g)
- else: #asymmetric
- if self.q_rounding == 0:
- if self.use_quantizer_kernel:
- input_q = ds_quantizer(input.clone(),
- self.q_groups,
- self.q_start_bits[index],
- asym=True)
- else:
- scale = [(g.max() - g.min()) / q_range for g in input_g]
- input_flat = [
- ((g - g.min()) / s).round().clamp(0,
- (q_range - 1)) * s + g.min()
- for g,
- s in zip(input_g,
- scale)
- ]
- else:
- input_q = ds_quantizer(input.clone(),
- self.q_groups,
- self.q_start_bits[index],
- asym=True)
-
- if self.use_quantizer_kernel or (self.q_type and self.q_rounding):
- return self.mixed_fp16_quantize(input, input_q, index)
+ if self.use_quantizer_kernel:
+ if input.start_bits <= 2:
+ raise ValueError(
+ 'Quantization bit is too low, please do it without quantization kernel!'
+ )
+ input_q = ds_quantizer(
+ input.data.clone(),
+ self.q_groups,
+ input.start_bits,
+ asym=False if self.q_type == 'symmetric' else True,
+            sr=False if self.q_rounding == 'nearest' else True)
else:
- if self.q_mixed_fp16 and self.q_start_bits[index] >= (self.q_target_bits -
- 1):
- input_flat = [(self.quantize_real_ratio * g) +
- ((1 - self.quantize_real_ratio) * g_q) for g,
- g_q in zip(input_g,
- input_flat)]
- input_q = torch.cat(input_flat)
- input_q = input_q.reshape(input.size())
- return input_q
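+            # Dispatch on the current bit-width: >= 3 bits uses uniform quantization,
+            # 2 bits uses ternary quantization, and 1 bit uses binary quantization.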
+ if input.start_bits >= 3:
+ input_flat = self.quantize_highbit(input.data, input.start_bits)
+ elif input.start_bits == 2:
+ assert self.q_type == 'symmetric', 'Quantization type is not symmetric!'
+                assert self.q_rounding == 'nearest', 'Quantization rounding is not nearest!'
+ input_flat = self.quantize_tenary(input.data)
+ elif input.start_bits == 1:
+ assert self.q_type == 'symmetric', 'Quantization type is not symmetric!'
+                assert self.q_rounding == 'nearest', 'Quantization rounding is not nearest!'
+ input_flat = self.quantize_binary(input.data)
+ if self.use_quantizer_kernel:
+ return self.mixed_fp16_quantize(input.data, input_q, index)
+ else:
+ if self.q_mixed_fp16 and input.start_bits >= input.target_bits - 1:
+ input_flat = self.quantize_real_ratio * input.data + \
+ (1 - self.quantize_real_ratio) * input_flat
+ return input_flat
def update_fp16_ratio(self):
if self.q_mixed_fp16:
diff --git a/docs/_config.yml b/docs/_config.yml
index 456b16ff1d16..fff37da306d5 100644
--- a/docs/_config.yml
+++ b/docs/_config.yml
@@ -48,6 +48,7 @@ collections:
- mixture-of-experts.md
- mixture-of-experts-nlg.md
- mixture-of-experts-inference.md
+ - model-compression.md
- monitor.md
- one-cycle.md
- onebit-adam.md
diff --git a/docs/_data/navigation.yml b/docs/_data/navigation.yml
index babcb8da2283..4c9dcf22e19e 100755
--- a/docs/_data/navigation.yml
+++ b/docs/_data/navigation.yml
@@ -63,6 +63,8 @@ lnav:
url: /docs/config-json/#sparse-attention
- title: 'Monitoring'
url: /docs/config-json/#monitoring-module-tensorboard-wandb-csv
+ - title: 'Model Compression'
+ url: /docs/config-json/#compression
- title: 'Tutorials'
url: /tutorials/
children:
@@ -98,6 +100,8 @@ lnav:
url: /tutorials/mixture-of-experts-nlg/
- title: 'MoE Inference'
url: /tutorials/mixture-of-experts-inference/
+ - title: 'Model Compression'
+ url: /tutorials/model-compression/
- title: 'Mixture-of-Quantization'
url: /tutorials/MoQ-tutorial/
- title: 'Monitoring'
diff --git a/docs/_pages/config-json.md b/docs/_pages/config-json.md
index 412858437606..bc0fd47663f4 100755
--- a/docs/_pages/config-json.md
+++ b/docs/_pages/config-json.md
@@ -1045,3 +1045,349 @@ Example of **csv_monitor** configuration:
"job_name": "train_bert"
}
```
+### Compression
+**Note:** **Compression** has seven different components, including layer reduction, weight quantization, activation quantization, sparse pruning, row pruning, head pruning, and channel pruning. We explain them one by one with simple json examples. Read more about how to use the DeepSpeed Compression library in our [tutorial](/tutorials/model-compression/).
+
+#### Layer Reduction
+**Note:** Layer reduction works much better when using knowledge distillation (learn more in our [tutorial](/tutorials/model-compression/)):
+
+```json
+"compression_training": {
+ "layer_reduction": {
+ "enabled": true,
+ "keep_number_layer": 5,
+ "module_name_prefix": "bert.encoder.layer",
+ "teacher_layer": [
+ 2,
+ 4,
+ 6,
+ 8,
+ 10
+ ],
+ "other_module_name": [
+ "bert.pooler",
+ "bert.embeddings",
+ "classifier"
+ ]
+ }
+ }
+```
+
+**layer_reduction**: [dictionary]
+
+| Fields | Value | Default |
+| ----- | ----- | ----- |
+| **enabled**: [boolean] | Enable layer reduction or not. | `false` |
+| **keep_number_layer**: [integer] | The number of layers to keep in the model. | N/A |
+| **module_name_prefix**: [str] | The (uniform) name prefix of the model's modules of which the associated weight parameters are to be reinitialized. | N/A |
+| **teacher_layer**: [list] | The teacher layers whose weight parameters are used to reinitialize the kept student layers. The length of the list equals `keep_number_layer`. | N/A |
+| **other_module_name**: [list] | The names of modules whose associated weight parameters are to be reinitialized. It is a complement or an alternative to module_name_prefix. For instance, "other_module_name": ["bert.encoder.layer.2","bert.encoder.layer.4"] is equivalent to "module_name_prefix":"bert.encoder.layer" and "teacher_layer": [2,4]. | N/A |
+
+#### Weight Quantization
+```json
+ "compression_training": {
+ "weight_quantization": {
+ "shared_parameters":{
+ "enabled": true,
+ "quantizer_kernel": false,
+ "schedule_offset": 0,
+ "quantize_groups": 1,
+ "quantize_verbose": false,
+ "quantization_type": "symmetric",
+ "rounding": "nearest",
+ "quantize_weight_in_forward": false,
+ "fp16_mixed_quantize":{
+ "enabled": false,
+ "quantize_change_ratio": 0.001
+ }
+ },
+ "different_groups":{
+ "wq1": {
+ "params": {
+ "start_bits": 8,
+ "target_bits": 8,
+ "quantization_period": 50
+ },
+ "modules": [
+ "attention.self",
+ "intermediate"
+ ]
+ },
+ "wq2": {
+ "params": {
+ "start_bits": 4,
+ "target_bits": 4,
+ "quantization_period": 50
+ },
+ "modules": [
+ "attention.output"
+ ]
+ }
+ }
+ }
+ }
+```
+
+**shared_parameters**: [dictionary]
+
+Shared parameters for all weight quantization groups.
+
+| Fields | Value | Default |
+| ----- | ----- | ----- |
+| **enabled**: [boolean] | Enable weight quantization or not. | `false` |
+| **quantizer_kernel**: [boolean] | Use DeepSpeed quantization kernel for >=4 bit quantization. This can only be enabled when using DeepSpeed FP16 optimizer. | `false` |
+| **schedule_offset**: [integer] | Enable weight quantization after scheduled steps (can be treated as warmup steps). | `0` |
+| **quantize_groups**: [integer] | Split the weight matrix into different number of groups, and each of them has its own scaling factor. | `1` |
+| **quantize_verbose**: [boolean] | Print the quantization related logs. | `false` |
+| **quantization_type**: [string] | Choose the quantization algorithm, symmetric or asymmetric. | `"symmetric"` |
+| **rounding**: [string] | Rounding algorithm associated with quantization, nearest or stochastic. | `"nearest"` |
+| **quantize_weight_in_forward**: [boolean] | Quantize weight in the optimizer or in the forward step; must be set to true for FP32 optimizer training. | `false` |
+| **fp16_mixed_quantize**: [dictionary] | Use a value mixed from the FP16 value and the quantized value. | N/A |
+| **enabled**: [boolean] | Whether fp16 mixed quantization is enabled. | `false` |
+| **quantize_change_ratio**: [float] | Initial quantize value ratio, will gradually increase to 1. | `0.001` |
+
+**different_groups**: [dictionary]
+
+Different quantization sets; this is used to specify different quantization parameters for different groups of modules. In this example, we give two different sets. In practice, you can choose the number of sets based on your requirements.
+
+| Fields | Value | Default |
+| ----- | ----- | ----- |
+| **params**: [dictionary] | | |
+| **start_bits**: [integer] | Quantization starting bits, will gradually reduce to the target bits. | `8` |
+| **target_bits**: [integer] | Quantization target bits, needs to be <= start_bits. | `8` |
+| **quantization_period**: [integer] | For every n steps, the quantization bits will be reduced by 1. | `1` |
+| **modules**: [list] | Scope of weight parameters associated with the params setting. | `"All Linear and CONV2D layers"` |
+
+#### Activation Quantization
+```json
+"compression_training": {
+ "activation_quantization": {
+ "shared_parameters":{
+ "enabled": true,
+ "quantization_type": "asymmetric",
+ "range_calibration": "dynamic",
+ "schedule_offset": 50
+ },
+ "different_groups":{
+ "aq1": {
+ "params": {
+ "bits": 8
+ },
+ "modules": [
+ "attention.output"
+ ]
+ }
+ }
+ }
+```
+
+**shared_parameters**: [dictionary]
+
+Shared parameters for all activation quantization groups.
+
+| Fields | Value | Default |
+| ----- | ----- | ----- |
+| **enabled**: [boolean] | Enable activation quantization or not. | `false` |
+| **quantization_type**: [string] | Choose the quantization algorithm, symmetric or asymmetric. | `"symmetric"` |
+| **range_calibration**: [string] | Use dynamic (per-token or per-image) or static (fixed min/max using momentum) range calibration for inference. | `"static"` |
+| **schedule_offset**: [integer] | Enable activation quantization after scheduled steps (can be treated as warmup steps). | `0` |
+
+**different_groups**: [dictionary]
+
+Different quantization sets; this is used to specify different quantization parameters for different groups of modules. In this example, we give one set. In practice, you can choose the number of sets based on your requirements.
+
+| Fields | Value | Default |
+| ----- | ----- | ----- |
+| **params**: [dictionary] | | |
+| **bits**: [integer] | Number of bits used to quantize the activations; needs to be >= 4. | `8` |
+| **modules**: [list] | Scope of weight parameters associated with the params setting. | `"All Linear and CONV2D layers"` |
+
+#### Sparse Pruning
+```json
+"compression_training": {
+ "sparse_pruning":{
+ "shared_parameters":{
+ "enabled": true,
+ "schedule_offset": 30,
+ "method": "l1"
+ },
+ "different_groups":{
+ "sp1": {
+ "params": {
+ "dense_ratio": 0.5
+ },
+ "modules": [
+ "attention.self"
+ ]
+ }
+ }
+ }
+}
+```
+
+**shared_parameters**: [dictionary]
+
+Shared parameters for all sparse pruning groups.
+
+| Fields | Value | Default |
+| ----- | ----- | ----- |
+| **enabled**: [boolean] | Enable sparse pruning or not. | `false` |
+| **schedule_offset**: [integer] | Enable sparse pruning after scheduled steps (can be treated as warmup steps). | `0` |
+| **method**: [string] | Choose different pruning methods, l1 (static, magnitude based) or topk (dynamic, learnable). | `"l1"` |
+
+**different_groups**: [dictionary]
+
+Different pruning sets; this is used to specify different pruning parameters for different groups of modules. In this example, we give one set. In practice, you can choose the number of sets based on your requirements.
+
+| Fields | Value | Default |
+| ----- | ----- | ----- |
+| **params**: [dictionary] | | |
+| **dense_ratio**: [float] | The percentage of weights to keep after pruning. | `0.5` |
+| **modules**: [list] | Scope of weight parameters associated with the params setting. | `"All Linear and CONV2D layers"` |
+
+#### Row Pruning
+**Note:** **Row Pruning** is a feature designed for two back-to-back linear layers (e.g., the Feed Forward Network in Transformers). As such, we suggest using row pruning for the first linear layer (i.e., the `intermediate.dense` layer for BERT). Reducing the row dimension of this matrix can help reduce the column dimension of the follow-up matrix (i.e., the `layer.\\w+.output.dense` layer for BERT). It should also work for other linear layers.
+```json
+"compression_training": {
+ "row_pruning":{
+ "shared_parameters":{
+ "enabled": true,
+ "schedule_offset": 20,
+ "method": "topk"
+ },
+ "different_groups":{
+ "rp1": {
+ "params": {
+ "dense_ratio": 0.5
+ },
+ "modules": [
+ "intermediate.dense"
+ ],
+ "related_modules":[
+ ["layer.\\w+.output.dense"]
+ ]
+ }
+ }
+ }
+}
+```
+
+**shared_parameters**: [dictionary]
+
+Shared parameters for all row pruning groups.
+
+| Fields | Value | Default |
+| ----- | ----- | ----- |
+| **enabled**: [boolean] | Enable row pruning or not. | `false` |
+| **schedule_offset**: [integer] | Enable row pruning after scheduled steps (can be treated as warmup steps). | `0` |
+| **method**: [string] | Choose different pruning methods, l1 (static, magnitude based) or topk (dynamic, learnable). | `"l1"` |
+
+**different_groups**: [dictionary]
+
+Different pruning sets; this is used to specify different pruning parameters for different groups of modules. In this example, we give one set. In practice, you can choose the number of sets based on your requirements.
+
+| Fields | Value | Default |
+| ----- | ----- | ----- |
+| **params**: [dictionary] | | |
+| **dense_ratio**: [float] | The percentage of weights to keep after pruning. | `0.5` |
+| **modules**: [list] | Scope of weight parameters associated with the params setting. | `"All Linear and CONV2D layers"` |
+| **related_modules**: [list[list]] | Modules related to the row-pruned module, on which column pruning is performed. | `None` |
+
+#### Head Pruning
+**Note:** **Head Pruning** is a feature designed for the attention layers (e.g., Multi Head Attention in Transformers). For now, it can only be applied to the output matrix of the Transformer (i.e., `attention.output.dense` in BERT). Pruning the output matrix can lead to the pruning of the Query/Key/Value matrices as well.
+```json
+"compression_training": {
+ "head_pruning":{
+ "shared_parameters":{
+ "enabled": true,
+ "schedule_offset": 10,
+ "method": "topk",
+ "num_heads": 12
+ },
+ "different_groups":{
+ "rp1": {
+ "params": {
+ "dense_ratio": 0.5
+ },
+ "modules": [
+ "attention.output.dense"
+ ],
+ "related_modules":[
+ ["self.query", "self.key", "self.value"]
+ ]
+ }
+ }
+ }
+}
+
+```
+
+**shared_parameters**: [dictionary]
+
+Shared parameters for all head pruning groups.
+
+| Fields | Value | Default |
+| ----- | ----- | ----- |
+| **enabled**: [boolean] | Enable head pruning or not. | `false` |
+| **schedule_offset**: [integer] | Enable head pruning after scheduled steps (can be treated as warmup steps). | `0` |
+| **method**: [string] | Choose different pruning methods. For now, we only support topk (dynamic, learnable). | `"topk"` |
+| **num_heads**: [int] | Number of heads (must be provided by user). | N/A |
+
+**different_groups**: [dictionary]
+
+Different pruning sets; this is used to specify different pruning parameters for different groups of modules. In this example, we give one set. In practice, you can choose the number of sets based on your requirements.
+
+| Fields | Value | Default |
+| ----- | ----- | ----- |
+| **params**: [dictionary] | | |
+| **dense_ratio**: [float] | The percentage of weights to keep after pruning. | `0.5` |
+| **modules**: [list] | Scope of weight parameters associated with the params setting. | `"All Linear and CONV2D layers"` |
+| **related_modules**: [list[list]] | Modules related to the head-pruned module (usually Q/K/V), i.e., to the output matrix. For now, this feature only works for BERT. | `None` |
+
+#### Channel Pruning
+**Note:** **Channel Pruning** is a feature designed for two back-to-back CONV2d layers (e.g., the residual connection in ResNet). As such, we suggest using channel pruning for the first CONV2d layer. Reducing the number of output channels of this layer can help reduce the number of input channels of the follow-up layer. It should also work for other CONV2d layers.
+```json
+"compression_training": {
+"channel_pruning":{
+ "shared_parameters":{
+ "enabled": true,
+ "schedule_offset": 0,
+ "method": "topk"
+ },
+ "different_groups":{
+ "cp1": {
+ "params": {
+ "dense_ratio": 0.5
+ },
+ "modules": [
+ "layer....conv1"
+ ],
+ "related_modules": [
+ ["layer....conv2", "layer....bn1"]
+ ]
+ }
+ }
+ }
+}
+```
+
+**shared_parameters**: [dictionary]
+
+Shared parameters for all channel pruning groups.
+
+| Fields | Value | Default |
+| ----- | ----- | ----- |
+| **enabled**: [boolean] | Enable channel pruning or not. | `false` |
+| **schedule_offset**: [integer] | Enable channel pruning after scheduled steps (can be treated as warmup steps). | `0` |
+| **method**: [string] | Choose different pruning methods, l1 (static, magnitude based) or topk (dynamic, learnable). | `"l1"` |
+
+**different_groups**: [dictionary]
+
+Different pruning sets; this is used to specify different pruning parameters for different groups of modules. In this example, we give one set. In practice, you can choose the number of sets based on your requirements.
+
+| Fields | Value | Default |
+| ----- | ----- | ----- |
+| **params**: [dictionary] | | |
+| **dense_ratio**: [float] | The percentage of weights to keep after pruning. | `0.5` |
+| **modules**: [list] | Scope of weight parameters associated with the params setting. | `"All CONV2D layers"` |
+| **related_modules**: [list[list]] | Modules related to the channel-pruned module. | `None` |
diff --git a/docs/_tutorials/model-compression.md b/docs/_tutorials/model-compression.md
new file mode 100644
index 000000000000..f06fd2c23e3b
--- /dev/null
+++ b/docs/_tutorials/model-compression.md
@@ -0,0 +1,446 @@
+---
+title: "DeepSpeed Model Compression Library"
+tags: model-compression
+---
+
+**What is DeepSpeed Compression:** DeepSpeed Compression is a library purposely built to make it easy to compress models for researchers and practitioners while delivering faster speed, smaller model size, and significantly reduced compression cost.
+
+**Why use DeepSpeed Compression:** DeepSpeed Compression offers novel state-of-the-art compression techniques to achieve faster model compression with better model quality and lower compression cost. DeepSpeed Compression also takes an end-to-end approach to improve the computation efficiency of compressed models via a highly optimized inference engine. Furthermore, our library has multiple built-in state-of-the-art compression methods. It supports the synergistic composition of these methods and the system optimizations, offering the best of both worlds while allowing a seamless and easy-to-use pipeline for efficient DL model inference. We highly recommend that you also read our blog to learn more (at a high level) about why we built DeepSpeed Compression and what benefits it provides to users.
+
+
+**How to use DeepSpeed Compression:** The first section, General Tutorial, describes the compression methods supported by the library. The following sections describe our research work on how to compose different compression methods to perform [zero-cost quantization (ZeroQuant)](#2-tutorial-for-zeroquant-efficient-and-affordable-post-training-quantization) and [extreme compression (XTC)](#3-tutorial-for-xtc-simple-yet-effective-compression-pipeline-for-extreme-compression). Unless otherwise stated, the experiment results listed below are based on an NVIDIA A100 GPU, and we observed slightly different result numbers when using different GPU hardware.
+
+## 1. General Tutorial
+To use the DeepSpeed Compression library, you need to install DeepSpeed >= 0.7.0 following the [installation guide](/tutorials/advanced-install/). Currently, DeepSpeed Compression includes seven compression methods: layer reduction via knowledge distillation, weight quantization, activation quantization, sparse pruning, row pruning, head pruning, and channel pruning. In the following subsections, we will describe what these methods are, when to use them, and how to use them via our library.
+
+### 1.1 Layer Reduction
+**What is layer reduction**
+
+Neural networks are constructed from an input layer, an output layer, and hidden layers. For example, the BERT-base language model consists of an embedding layer (input layer), a classification layer (output layer), and 12 hidden layers. Layer reduction means reducing the number of hidden layers while keeping the width of the network intact (i.e., it does not reduce the dimension of the hidden layers). This method can linearly reduce the inference latency of the hidden layers regardless of the hardware and/or scenarios.
+
+**When to use layer reduction**
+
+If the model is very deep, you may consider using this method. It works much better when applying knowledge distillation. Layer reduction can be applied in both the pre-training and fine-tuning stages. The former generates a distilled task-agnostic model, while the latter generates a task-specific distilled model. In our XTC work ([paper](https://arxiv.org/abs/2206.01859), [tutorial](#3-tutorial-for-xtc-simple-yet-effective-compression-pipeline-for-extreme-compression)), we also discuss when to apply layer reduction.
+
+**How to use layer reduction**
+
+Layer reduction can be enabled and configured using the DeepSpeed config JSON file ([configuration details](/docs/config-json/#layer-reduction)). Users have the freedom to select any depth by `keep_number_layer` and any subset of the network layers by `teacher_layer`. In addition, users also can choose whether to reinitialize the input/output layers from the given model (teacher model) by `other_module_name`.
+
+To apply layer reduction for task-specific compression, we provide an example of how to do so for BERT fine-tuning. Layer reduction amounts to resetting the depth of the network architecture and reinitializing the weight parameters, which happens before the training process. The example includes the following changes to the client code (`model_compression/bert/run_glue_no_trainer.py` in [DeepSpeedExamples](https://github.com/microsoft/DeepSpeedExamples)):
+
+(1) When initializing the model, the number of layers in the model config should be the same as `keep_number_layer` in the DeepSpeed config JSON file. For the Hugging Face BERT example, set `config.num_hidden_layers = ds_config["compression_training"]["layer_reduction"]["keep_number_layer"]`.
+
+(2) Then we need to re-initialize the model based on the DeepSpeed JSON configurations using the function `init_compression` imported from `deepspeed.compression.compress` (see the sketch after these steps).
+
+(3) During training, if KD is not used, nothing needs to be done. Otherwise, one needs to consider applying KD with the `teacher_layer` JSON configuration when calculating the difference between teacher’s and student’s output.
+
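+The snippet below is a minimal sketch of steps (1) and (2), assuming a Hugging Face BERT student; the checkpoint name and the `ds_config.json` path are placeholders for whatever your own fine-tuning script already provides:
+
+```python
+import json
+
+from transformers import AutoConfig, AutoModelForSequenceClassification
+from deepspeed.compression.compress import init_compression
+
+with open("ds_config.json") as f:  # placeholder path for the DeepSpeed config JSON
+    ds_config = json.load(f)
+layer_reduction = ds_config["compression_training"]["layer_reduction"]
+
+# Step (1): make the student's depth match "keep_number_layer" in the config.
+model_config = AutoConfig.from_pretrained("bert-base-uncased")
+model_config.num_hidden_layers = layer_reduction["keep_number_layer"]
+model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased",
+                                                           config=model_config)
+
+# Step (2): re-initialize the model based on the DeepSpeed JSON configuration
+# (this happens before the training loop starts).
+model = init_compression(model, ds_config)
+```
+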
+One can run our layer reduction example in [DeepSpeedExamples](https://github.com/microsoft/DeepSpeedExamples) by:
+
+```shell
+DeepSpeedExamples/model_compression/bert$ pip install -r requirements.txt
+DeepSpeedExamples/model_compression/bert$ bash bash_script/layer_reduction.sh
+```
+
+And the final result is:
+
+```shell
+Epoch: 18 | Time: 12m 38s
+Clean the best model, and the accuracy of the clean model is acc/mm-acc:0.8340295466123281/0.8339096826688365
+```
+
+
+
+
+
+
+
+### 1.2 Weight Quantization
+**What is weight quantization**
+
+Weight quantization maps the full precision weight (FP32/FP16) to the low bit ones, like INT8 and INT4. Quoted from [this Coursera lecture](https://www.coursera.org/lecture/machine-learning-modeling-pipelines-in-production/benefits-and-process-of-quantization-WAjyJ): “Quantization involves transforming a model into an equivalent representation that uses parameters and computations at a lower precision. This improves the model's execution performance and efficiency, but it can often result in lower model accuracy”.
+
+**When to use weight quantization**
+
+On the one hand, again quoting from [this Coursera lecture](https://www.coursera.org/lecture/machine-learning-modeling-pipelines-in-production/benefits-and-process-of-quantization-WAjyJ): “Mobile and embedded devices have limited computational resources, so it's important to keep your application resource efficient. Depending on the task, you will need to make a trade-off between model accuracy and model complexity. If your task requires high accuracy, then you may need a large and complex model. For tasks that require less precision, it's better to use a smaller, less complex model.”. On the other hand, recent server accelerators, like GPU, support low-precision arithmetic. Therefore, combining weight quantization with activation quantization (introduced in a later section) can offer better efficiency as well.
+
+**How to use weight quantization**
+
+Weight quantization can be enabled and configured using the DeepSpeed config JSON file ([configuration details](/docs/config-json/#weight-quantization)). The key configurations we would like to point out are:
+
+(1)`quantize_groups`, a group-wise weight matrix quantization: a weight matrix W is partitioned into multiple groups, and each group is quantized separately (a small sketch of this idea follows this list). See more details in [this paper](https://ojs.aaai.org/index.php/AAAI/article/view/6409).
+
+(2)`quantize_weight_in_forward` must be set to true for FP32 optimizer training and false for FP16.
+
+(3)`wq1`/`wq2`, users can expand more groups such as `wq3`, `wq4`, etc.
+
+(4)`start_bits` and `target_bits`, to simplify the first experiment we suggest setting them to the same value, such that quantization is applied to the target bits once the iteration reaches `schedule_offset`.
+
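+To make point (1) above concrete, here is a toy sketch of group-wise symmetric quantization (quantize-dequantize), mirroring the symmetric scheme used elsewhere in this library; the tensor shape, bit-width, and group counts are made up for illustration:
+
+```python
+import torch
+
+
+def groupwise_symmetric_quantize(weight, num_bits, num_groups):
+    """Fake-quantize `weight` using one scaling factor per group of elements."""
+    q_range = 2 ** num_bits
+    w = weight.reshape(num_groups, -1)
+    # One scale per group: the largest magnitude in the group maps to the grid edge.
+    scale = 2 * w.abs().max(dim=1, keepdim=True).values / q_range
+    w_q = (w / scale).round().clamp(-(q_range // 2), q_range // 2 - 1) * scale
+    return w_q.reshape(weight.shape)
+
+
+w = torch.randn(768, 768)
+for groups in (1, 64):
+    err = (w - groupwise_symmetric_quantize(w, num_bits=8, num_groups=groups)).abs().max().item()
+    print(f"groups={groups}: max quantization error {err:.5f}")
+```
+
+More groups means each scaling factor only has to cover a narrower value range, which usually lowers the quantization error at the cost of storing more scaling factors.
+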
+There are two changes to the client code (`model_compression/bert/run_glue_no_trainer.py` in [DeepSpeedExamples](https://github.com/microsoft/DeepSpeedExamples)), sketched after the two steps below:
+
+(1) After initialization of the model, apply `init_compression` function to the model with DeepSpeed JSON configurations.
+
+(2) After training, apply `redundancy_clean` function to save the quantized weight.
+
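+A rough sketch of where the two calls go is shown below; `build_model()`, the config path, and the output file name are placeholders for what `run_glue_no_trainer.py` already provides:
+
+```python
+import json
+
+import torch
+from deepspeed.compression.compress import init_compression, redundancy_clean
+
+with open("ds_config.json") as f:  # placeholder: the JSON config shown above
+    ds_config = json.load(f)
+model = build_model()              # placeholder: your task model, e.g. a BERT classifier
+
+# (1) Right after the model is created: wrap the targeted modules with their
+#     compression counterparts according to the JSON configuration.
+model = init_compression(model, ds_config)
+
+# ... run the usual fine-tuning loop here ...
+
+# (2) After training: fold the quantization into the weights, remove the
+#     compression scaffolding, and save the cleaned model.
+model = redundancy_clean(model, ds_config)
+torch.save(model.state_dict(), "pytorch_model.bin")
+```
+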
+One can run our weight quantization example in [DeepSpeedExamples](https://github.com/microsoft/DeepSpeedExamples) by:
+
+```shell
+DeepSpeedExamples/model_compression/bert$ pip install -r requirements.txt
+DeepSpeedExamples/model_compression/bert$ bash bash_script/quant_weight.sh
+```
+
+And the final result is:
+
+```shell
+Epoch: 09 | Time: 27m 10s
+Clean the best model, and the accuracy of the clean model is acc/mm-acc:0.8414671421293938/0.8422497965825875
+```
+
+### 1.3 Activation Quantization
+**What is activation quantization**
+
+Activation means the input to each layer. Activation quantization maps the input from full/half precision to low precision. See more in [this blog](https://medium.com/@joel_34050/quantization-in-deep-learning-478417eab72b).
+
+**When to use activation quantization**
+
+It can improve computation efficiency similar to [weight quantization](#12-weight-quantization).
+
+**How to use activation quantization**
+
+Activation quantization can be enabled and configured using the DeepSpeed config JSON file ([configuration details](/docs/config-json/#activation-quantization)). Some of the components are the same as in weight quantization, such as `schedule_offset` and `quantization_type`. The key configurations we would like to point out are:
+
+(1)`range_calibration`, users have the option to set it to dynamic or static. When using “dynamic”, the activation quantization groups will be automatically set to be token-wise (for Transformer-based models) or image-wise (for CNN-based models); a toy sketch of per-token quantization follows this list. See more in [our ZeroQuant paper](https://arxiv.org/abs/2206.01861) and the code (`deepspeed/compression/basic_layer.py` in [DeepSpeed](https://github.com/microsoft/DeepSpeed)).
+
+(2)`aq1`/`aq2`, users can expand more groups such as `aq3`, `aq4`, etc.
+
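+As a toy illustration of the “dynamic” option, the sketch below fake-quantizes each token's activation vector with its own asymmetric min/max range; the shapes and bit-width are made up, and the library performs this inside its compressed layer wrappers rather than as a standalone function:
+
+```python
+import torch
+
+
+def per_token_asym_fake_quant(x, num_bits=8):
+    """Quantize-dequantize activations of shape (batch, tokens, hidden), one range per token."""
+    q_levels = 2 ** num_bits - 1
+    x_min = x.min(dim=-1, keepdim=True).values
+    x_max = x.max(dim=-1, keepdim=True).values
+    scale = (x_max - x_min).clamp(min=1e-8) / q_levels
+    return ((x - x_min) / scale).round().clamp(0, q_levels) * scale + x_min
+
+
+acts = torch.randn(2, 4, 16)  # (batch, tokens, hidden); toy sizes
+print((acts - per_token_asym_fake_quant(acts)).abs().max())
+```
+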
+The client code change is the same as [weight quantization](#12-weight-quantization).
+
+One can run our activation quantization example in [DeepSpeedExamples](https://github.com/microsoft/DeepSpeedExamples) by:
+
+```shell
+DeepSpeedExamples/model_compression/bert$ pip install -r requirements.txt
+DeepSpeedExamples/model_compression/bert$ bash bash_script/quant_activation.sh
+```
+
+And the final result is:
+
+```shell
+Epoch: 02 | Time: 28m 50s
+Clean the best model, and the accuracy of the clean model is acc/mm-acc:0.8375955170657158/0.8422497965825875
+```
+
+### 1.4 Pruning
+**What is pruning**
+
+Pruning aims to reduce the number of parameters and operations involved in generating a prediction by removing network connections. With pruning, you can lower the overall parameter count in the network (see more in [this Coursera lecture](https://www.coursera.org/lecture/machine-learning-modeling-pipelines-in-production/pruning-uNSOG)). We can divide the pruning strategy into two types: structured and unstructured pruning (see more in [this paper](https://arxiv.org/abs/1506.02626)).
+
+
+| **Method** | **Type** |
+| --------------------- | ------------ |
+| [Sparse pruning](#141-sparse-pruning) | Unstructured |
+| [Row pruning](#142-row-pruning) | Structured |
+| [Head pruning](#143-head-pruning) | Structured |
+| [Channel pruning](#144-channel-pruning) | Structured |
+
+#### 1.4.1 Sparse Pruning
+**What is sparse pruning**
+
+Sparse pruning means we set some of the elements in each weight matrix to zero. There is no structural pattern to the zero values. One way to perform pruning is based on the absolute value of the weight parameters; see for instance [this paper](https://arxiv.org/abs/1506.02626).
+
+**When to use sparse pruning**
+
+If your model is significantly over-parameterized, you may consider using sparse pruning. However, to see the real benefit of hardware computation efficiency, the density ratio (percentage of weights to keep after pruning) must be considerably low.
+
+**How to use sparse pruning**
+
+Sparse pruning can be enabled and configured using the DeepSpeed config JSON file ([configuration details](/docs/config-json/#sparse-pruning)). The key configurations we would like to point out are:
+
+(1)`schedule_offset`, we empirically find that when using `method: topk`, it’s better to set the `schedule_offset` to a large value such as 10% of the total training steps.
+
+(2)`method`, we support L1 norm and topk methods (a toy sketch of magnitude-based masking follows this list). Users are welcome to contribute more methods.
+
+(3)`sp1`, users can expand more groups such as `sp2`, `sp3`, etc.
+
+(4)`dense_ratio`, for unstructured sparse pruning, the dense ratio could be less than 0.1 for the BERT-base model while still yielding a good accuracy. For ResNet-50, the dense ratio could be as low as 0.3 while still having good accuracy on ImageNet.
+
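+For intuition on the `l1` (magnitude-based) method and `dense_ratio`, here is a toy sketch of the kind of binary mask such pruning produces; it is illustrative only and not the library's internal implementation:
+
+```python
+import torch
+
+
+def magnitude_prune_mask(weight, dense_ratio):
+    """Return a 0/1 mask keeping roughly the `dense_ratio` fraction of largest-|w| entries."""
+    num_keep = max(1, int(dense_ratio * weight.numel()))
+    # k-th smallest absolute value such that about `num_keep` values are >= threshold.
+    threshold = weight.abs().flatten().kthvalue(weight.numel() - num_keep + 1).values
+    return (weight.abs() >= threshold).to(weight.dtype)
+
+
+w = torch.randn(8, 8)
+mask = magnitude_prune_mask(w, dense_ratio=0.5)
+print(mask.mean())     # close to dense_ratio
+print((w * mask)[:2])  # unstructured sparsity: zeros can appear anywhere
+```
+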
+The client code change is the same as [weight quantization](#12-weight-quantization).
+
+One can run our sparse pruning example in [DeepSpeedExamples](https://github.com/microsoft/DeepSpeedExamples) by:
+
+```shell
+DeepSpeedExamples/model_compression/bert$ pip install -r requirements.txt
+DeepSpeedExamples/model_compression/bert$ bash bash_script/pruning_sparse.sh
+```
+
+And the final result is:
+
+```shell
+Epoch: 02 | Time: 26m 14s
+Clean the best model, and the accuracy of the clean model is acc/mm-acc:0.8416709118695873/0.8447925142392189
+```
+
+#### 1.4.2 Row Pruning
+**What is row pruning**
+
+Row pruning sets all the elements in certain rows of the weight matrix to zero: if a row is pruned, every element in that row is zeroed out.
+
+**When to use row pruning**
+
+Row pruning can be beneficial to hardware speedup, much more so than sparse pruning (but may result in larger accuracy loss compared to sparse pruning). It is a feature designed for two back-to-back linear layers (e.g., the Feed Forward Network in Transformers). As such, we suggest using row pruning for the first linear layer (i.e., the `intermediate.dense` layer for BERT). Reducing the row dimension of this matrix can help reduce the column dimension of the follow-up matrix (i.e., the `layer.\\w+.output.dense` layer for BERT). Row pruning would also work for other kinds of linear layers.
+
+**How to use row pruning**
+
+Row pruning can be enabled and configured using the DeepSpeed config JSON file ([configuration details](/docs/config-json/#row-pruning)). The key configurations we would like to point out are:
+
+(1)`method`, only `topk` method is supported currently. Users are welcome to contribute more methods.
+
+(2)`rp1`, users can expand more groups such as `rp2`, `rp3`, etc.
+
+(3)`related_modules`, as mentioned in “when to use row pruning”, if we do row pruning, the follow-up matrix will be affected. Thus, one needs to know the connection between the modules.
+
+The client code change is the same as [weight quantization](#12-weight-quantization).
+
+One can run our row pruning example in [DeepSpeedExamples](https://github.com/microsoft/DeepSpeedExamples) by:
+
+```shell
+DeepSpeedExamples/model_compression/bert$ pip install -r requirements.txt
+DeepSpeedExamples/model_compression/bert$ bash bash_script/pruning_row.sh
+```
+
+And the final result is:
+
+```shell
+Epoch: 02 | Time: 27m 43s
+Clean the best model, and the accuracy of the clean model is acc/mm-acc:0.8440142638818136/0.8425549227013832
+```
+
+#### 1.4.3 Head Pruning
+**What is head pruning**
+
+Head pruning is designed specifically for networks with multi-head attention, such as transformer-based models (see more in [this blog](https://towardsdatascience.com/transformers-explained-visually-part-3-multi-head-attention-deep-dive-1c1ff1024853)). For example, the BERT-base (BERT-large) model has 12 heads (24 heads).
+
+**When to use head pruning**
+
+Head pruning is beneficial to hardware speedup. Moreover, as stated in [this blog](https://towardsdatascience.com/head-pruning-in-transformer-models-ec222ca9ece7): “Surprising observations are made in the [paper](https://arxiv.org/abs/1905.09418), that even after training models normally (with all heads), many heads can be removed at a test time and it will not significantly affect the BLEU score, in fact, some cases removing few heads led to improving BLEU scores.”.
+
+NOTE: Head pruning is a feature designed for the attention layers (e.g., Multi Head Attention in Transformers). For now, it can only be applied to the output matrix of the Transformer (i.e., `attention.output.dense` in BERT). Pruning the output matrix can lead to the pruning of the Query/Key/Value matrices as well.
+
+**How to use head pruning**
+
+Head pruning can be enabled and configured using the DeepSpeed config JSON file ([configuration details](/docs/config-json/#head-pruning)). The key configurations we would like to point out are:
+
+(1)`num_heads`: users need to provide the correct number of heads for their models.
+
+(2)`modules`: the module `attention.output.dense` is specific to the Hugging Face BERT model. Currently, we only support the case where Query/Key/Value are separate matrices followed by `attention.output.dense`. We are happy to assist and welcome contributions on variants of attention models.
+
+(3)`related_modules`: as mentioned in “when to use head pruning”, pruning the attention output matrix can lead to pruning QKV matrices as well. Thus, the input here is [“self.query”, “self.key”, “self.value”].
+
+The client code change is the same as [weight quantization](#12-weight-quantization).
+
+One can run our head pruning example in [DeepSpeedExamples](https://github.com/microsoft/DeepSpeedExamples) by:
+
+```shell
+DeepSpeedExamples/model_compression/bert$ pip install -r requirements.txt
+DeepSpeedExamples/model_compression/bert$ bash bash_script/pruning_head.sh
+```
+
+And the final result is:
+
+```shell
+Clean the best model, and the accuracy of the clean model is acc/mm-acc:0.8397350993377484/0.8377746135069162
+```
+
+#### 1.4.4 Channel Pruning
+**What is channel pruning**
+
+Channel pruning is made specifically for convolutional layers and computer vision. According to wikipedia.org, “The color data of an image is stored in three arrays of values, known as channels.”. For example, an image with three channels passing through ResNet-18 produces 64 channels after the first layer.
+
+**When to use channel pruning**
+
+Channel pruning is a feature designed for two back-to-back CONV2d layers (e.g., residual connection in ResNet). As such, we suggest using channel pruning for the first CONV2d layer. Reducing the number of output channels of this layer can help reduce the number of input channels of the next layer. Channel pruning would also work for other kinds of CONV2d layers.
+
+**How to use channel pruning**
+
+Channel pruning can be enabled and configured using the DeepSpeed config JSON file ([configuration details](/docs/config-json/#channel-pruning)).
+
+One can run our channel pruning example in [DeepSpeedExamples](https://github.com/microsoft/DeepSpeedExamples) by:
+
+```shell
+pip install torch torchvision
+DeepSpeedExamples/model_compression/cifar$ bash run_compress.sh
+```
+
+And the final result is:
+
+```shell
+after_clean
+epoch 10 testing_correct: 0.7664
+```
+
+Note that the above result is obtained without using batch normalization (BN) in the “ResNet” model. If you use BN in the model and apply channel pruning, the validation accuracy after cleaning the model will differ from that of the model before cleaning. We suggest users further fine-tune the model after applying `redundancy_clean` in such cases.
+
+## 2. Tutorial for ZeroQuant: efficient and affordable post-training quantization
+In this section, we introduce how to apply DS-Compression to perform cost-free INT8 quantization and lightweight INT4/INT8 mixed-precision quantization. For more details, please refer to [our paper](https://arxiv.org/abs/2206.01861).
+
+**What is ZeroQuant**
+
+ZeroQuant is an efficient Post-Training Quantization method that includes (1) a fine-grained, hardware-friendly quantization scheme for both weights and activations, which can significantly reduce the quantization error; (2) a novel, affordable layer-by-layer knowledge distillation algorithm (LKD) that works even without access to the original training data; and (3) highly optimized quantization system backend support to remove the quantization/dequantization overhead. With these techniques, ZeroQuant is able to (1) quantize models to INT8 without any cost and (2) quantize models to INT4/INT8 mixed precision with minimal resource requirements (e.g., 31s for BERT-base quantization).
+
+**When to use ZeroQuant**
+
+When you want to quantize a transformer-based model to INT8 or INT4/INT8 format, it is always a good idea to try ZeroQuant first, especially when quantization-aware training of the model would be very resource-hungry (GPUs and/or time) and/or when the original training data is not accessible.
+
+**How to use ZeroQuant**
+
+One can run our BERT example in [DeepSpeedExamples](https://github.com/microsoft/DeepSpeedExamples) by:
+
+```shell
+DeepSpeedExamples/model_compression/bert$ pip install -r requirements.txt
+DeepSpeedExamples/model_compression/bert$ bash bash_script/ZeroQuant/zero_quant.sh
+```
+
+And the final result is:
+
+```shell
+Clean the best model, and the accuracy of the clean model is acc/mm-acc:0.8427916454406521/0.8453010577705452
+```
+
+One can run our GPT example by:
+
+```shell
+DeepSpeedExamples/model_compression/gpt2$ pip install -r requirements.txt
+DeepSpeedExamples/model_compression/gpt2$ bash bash_script/run_zero_quant.sh
+```
+
+And the final result is:
+
+```shell
+Before converting the module COVN1D to linear and init_compression: 19.371443732303174
+Before cleaning, Epoch at 0 with Perplexity: 19.47031304212775
+After cleaning with Perplexity: 19.47031304212775
+```
+
+NOTE: right now, we only support zero-cost quantization. Stay tuned for the code release of the layer-by-layer knowledge distillation proposed in the ZeroQuant paper.
+
+## 3. Tutorial for XTC: simple yet effective compression pipeline for extreme compression
+In this section, we introduce how to apply the DeepSpeed Compression library to perform light-weight layer reduction and ultra-low-bit precision (binary/ternary) quantization. In particular, we will guide you through implementing the [XTC methods](https://arxiv.org/abs/2206.01859), namely:
+
+(1) Obtaining a 1-bit or 2-bit BERT-base (12-layer) with 8-bit activation quantization.
+
+(2) Reducing the 12-layer BERT-base to a 5-layer one and then obtaining its 1-bit or 2-bit counterparts.
+
+**What is XTC**
+
+XTC (short for eXTreme Compression) is our new simple yet efficient method that compresses a model to its limit with lightweight layer reduction and robust binarization. XTC reduces the model size by 32x with almost no loss in the average score on the GLUE tasks via a simple yet effective binarization technique. By combining extreme quantization and lightweight layer reduction, we can further improve the binarized model, achieving a 50x model size reduction while keeping 97% of the accuracy.
+For more details, see how we derive our method in [our paper](https://arxiv.org/abs/2206.01859) where we perform a systematic study on the impacts of various techniques currently used for extreme compression.
+
+**When to use XTC**
+
+If you want to significantly compress your models while retaining competitive performance, XTC could be a desirable choice. It is a simple and hyper-parameter tuning friendly method.
+
+**How to use XTC**
+
+**Installation:** Examples of XTC extreme compression for BERT models are at `model_compression/bert/bash_script/XTC` in [DeepSpeedExamples](https://github.com/microsoft/DeepSpeedExamples). You will need to install the requirements by:
+
+```shell
+DeepSpeedExamples/model_compression/bert$ pip install -r requirements.txt
+```
+
+**Implementation of XTC methods:**
+To accommodate users who do not have a fine-tuned or task-specific model for compression, our Python script `run_glue_no_trainer.py` automatically downloads the models from Hugging Face when given the arg `--model_name_or_path yoshitomo-matsubara/bert-base-uncased-${TASK_NAME}`. Users can also use their own models with better accuracy as the teacher and for the student model initialization.
+
+### 3.1 One-bit or Two-bit BERT-base (12-layer) with 8-bit activation quantization
+For the configurations, see `model_compression/bert/config/XTC/ds_config_W1A8_Qgroup1_fp32.json` in [DeepSpeedExamples](https://github.com/microsoft/DeepSpeedExamples). In our paper, we used FP32 (`"fp16": {"enabled": false}`) to perform training, while directly applying 8-bit quantization (`"bits": 8`) to the activations and 1-bit quantization (`"start_bits": 1, "target_bits": 1`) to the attention (query, key, val) and feedforward weight matrices (`"modules": ["attention.self", "intermediate", "output.dense"]`) at the beginning of the training (`"schedule_offset": 0`). In addition, we also apply 1-bit quantization to `word_embeddings` as weight quantization.
+
+One can run this example by:
+
+```shell
+DeepSpeedExamples/model_compression/bert$ bash bash_script/XTC/quant_1bit.sh
+```
+
+And the final result is:
+
+```shell
+Clean the best model, and the accuracy of the clean model is acc/mm-acc:0.8293428425878757/0.8396053702196908
+```
+
+The other important feature we would like to mention is `quantize_groups` inside `weight_quantization`, which is set to 1 here to match our XTC paper's FP32 training setup. We find that under FP16 training, a smaller number of quantization groups (e.g., 1 or 2) can lead to unstable training. Thus, we recommend using a larger number of groups (e.g., 64) under FP16. `model_compression/bert/config/ds_config_W1A8_Qgroup64_fp16.json` in [DeepSpeedExamples](https://github.com/microsoft/DeepSpeedExamples) is the example FP16 configuration, where `"fp16": {"enabled": true}` and `"weight_quantization": {"shared_parameters": {"quantize_weight_in_forward": false}}` differ from the FP32 case.
+
+With this config, we quantize the existing fine-tuned models downloaded from Hugging Face. For 2-bit weight quantization, users need to update the ds_config JSON file. To give a sense of the compression performance of the downloaded models compared to our paper, we collect the results (1/2-bit BERT on MNLI and QQP with 18 training epochs) in the table below. The difference between this tutorial and the paper comes from the use of different checkpoints. Data augmentation introduced in [TinyBERT](https://github.com/huawei-noah/Pretrained-Language-Model/tree/master/TinyBERT) helps significantly for smaller tasks (such as MRPC, RTE, STS-B and CoLA). See more details in [our paper](https://arxiv.org/abs/2206.01859).
+
+*(Results figure: 1/2-bit BERT on MNLI and QQP with 18 training epochs.)*{: .align-center}
+
+### 3.2 Compressing the 12-layer BERT-base to 1-bit or 2-bit 6/5-layer BERT
+
+This section consists of two parts: (a) we first perform a light-weight layer reduction, and (b) based on the model in (a), we perform 1-bit or 2-bit quantization.
+
+**3.2.1 Light-weight Layer Reduction**
+
+`model_compression/bert/config/XTC/ds_config_layer_reduction_fp16.json` in [DeepSpeedExamples](https://github.com/microsoft/DeepSpeedExamples) is the example configuration for reducing the 12-layer BERT-base to a 6-layer one. The student’s layers are initialized from layers i = [1, 3, 5, 7, 9, 11] of the teacher (note that layer indices start from 0), which is called `Skip-BERT_5` in our XTC paper. In addition, the student’s modules including the embedding, pooler and classifier are also initialized from the teacher. For 5-layer reduction, one needs to change the configs in `ds_config_layer_reduction_fp16.json` to `"keep_number_layer": 5`, `"teacher_layer": [2, 4, 6, 8, 10]` (as in `model_compression/bert/config/ds_config_TEMPLATE.json`).
+
+One can run this example by:
+
+```shell
+DeepSpeedExamples/model_compression/bert$ bash bash_script/XTC/layer_reduction.sh
+```
+
+And the final result is:
+
+```shell
+Clean the best model, and the accuracy of the clean model is acc/mm-acc:0.8377992868059093/0.8365541090317331
+```
+
+Notably, when using one-stage knowledge distillation (`--distill_method one_stage`), the difference between the outputs of the teacher and student models (att_loss and rep_loss) also needs to be consistent with the initialization. See the function `_kd_function` under `forward_loss` in `model_compression/bert/util.py`.
+
+For MNLI/QQP, we set `--num_train_epochs 36` and `--learning_rate 5e-5`, and use the JSON config above. The results are given below (we also include the FP16 training results). Using FP32 clearly results in more stable performance than FP16, although FP16 can speed up the training time.
+
+*(Results figure: layer-reduced BERT trained with FP32 vs. FP16.)*{: .align-center}
+
+**3.2.2 One-bit or Two-bit quantization for 6-layer (5-layer) BERT**
+
+Given the layer-reduced model above, we now continue to compress it with 1/2-bit quantization. `model_compression/bert/config/XTC/ds_config_layer_reduction_W1Q8_fp32.json` in [DeepSpeedExamples](https://github.com/microsoft/DeepSpeedExamples) is the example configuration, where we set layer reduction to true on top of `model_compression/bert/config/XTC/ds_config_W1A8_Qgroup1_fp32.json`. In addition to the configuration, we need to update the path for the student model using `--pretrained_dir_student` in the script `model_compression/bert/bash_script/XTC/layer_reduction_1bit.sh`. Users can train with a different teacher model by adding `--pretrained_dir_teacher`.
+
+One can run this example by:
+
+```shell
+DeepSpeedExamples/model_compression/bert$ bash bash_script/XTC/layer_reduction_1bit.sh
+```
+
+And the final result is:
+
+```shell
+Epoch: 18 | Time: 18m 11s
+Clean the best model, and the accuracy of the clean model is acc/mm-acc:0.8140601120733572/0.8199755899104963
+```
+
+With the command above, one can obtain the results of the 1-bit 6-layer model. We list more results for 2-/1-bit 6/5-layer models in the following table. Note that the checkpoints used for the compression below are from the table in section 3.2.1.
+
+*(Results figures: 2-/1-bit 6-layer and 5-layer BERT models.)*{: .align-center}
diff --git a/docs/assets/images/xtc-1.png b/docs/assets/images/xtc-1.png
new file mode 100644
index 000000000000..31fde8f3e623
Binary files /dev/null and b/docs/assets/images/xtc-1.png differ
diff --git a/docs/assets/images/xtc-2.png b/docs/assets/images/xtc-2.png
new file mode 100644
index 000000000000..27f57ca33857
Binary files /dev/null and b/docs/assets/images/xtc-2.png differ
diff --git a/docs/assets/images/xtc-3.png b/docs/assets/images/xtc-3.png
new file mode 100644
index 000000000000..2bb9d8813584
Binary files /dev/null and b/docs/assets/images/xtc-3.png differ
diff --git a/docs/assets/images/xtc-4.png b/docs/assets/images/xtc-4.png
new file mode 100644
index 000000000000..d4946e811a4a
Binary files /dev/null and b/docs/assets/images/xtc-4.png differ
diff --git a/tests/unit/test_compression.py b/tests/unit/test_compression.py
new file mode 100755
index 000000000000..f00aafaca1ba
--- /dev/null
+++ b/tests/unit/test_compression.py
@@ -0,0 +1,262 @@
+import torch
+import pytest
+import random
+import numpy as np
+from .megatron_model import get_gpt2_model
+from deepspeed.compression.compress import init_compression
+from .modeling import BertConfig
+from .modelingpreln import BertEncoder as BertEncoderPreln
+from deepspeed.compression.basic_layer import LinearLayer_Compress, ColumnParallelLinear_Compress, RowParallelLinear_Compress
+from deepspeed.compression.helper import convert_conv1d_to_linear
+
+TORCH_MAJOR = int(torch.__version__.split('.')[0])
+TORCH_MINOR = int(torch.__version__.split('.')[1])
+pytestmark = pytest.mark.skipif(
+ TORCH_MAJOR < 1 or (TORCH_MAJOR == 1 and TORCH_MINOR < 5),
+ reason='Megatron-LM package requires Pytorch version 1.5 or above')
+
+
+def reset_random(seed=1234):
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+
+
+def create_bert_model():
+ hidden_size = 384
+ num_layers = 2
+ heads = 12
+ dropout_ratio = 0.1
+ bert_config = BertConfig(vocab_size_or_config_json_file=119547,
+ hidden_size=hidden_size,
+ num_hidden_layers=num_layers,
+ num_attention_heads=heads,
+ intermediate_size=hidden_size * 4,
+ hidden_act="gelu",
+ hidden_dropout_prob=dropout_ratio,
+ attention_probs_dropout_prob=dropout_ratio,
+ max_position_embeddings=512,
+ type_vocab_size=2,
+ initializer_range=0.2)
+
+ weights = []
+ biases = []
+
+ for i in range(4):
+ weights.append(torch.nn.Parameter(torch.Tensor(hidden_size, hidden_size)))
+
+ weights.append(torch.nn.Parameter(torch.Tensor(hidden_size)))
+ weights.append(torch.nn.Parameter(torch.Tensor(hidden_size * 4, hidden_size)))
+ weights.append(torch.nn.Parameter(torch.Tensor(hidden_size, hidden_size * 4)))
+ weights.append(torch.nn.Parameter(torch.Tensor(hidden_size)))
+
+ biases.append(torch.nn.Parameter(torch.Tensor(hidden_size)))
+ for i in range(4):
+ biases.append(torch.nn.Parameter(torch.Tensor(hidden_size)))
+ biases.append(torch.nn.Parameter(torch.Tensor(hidden_size * 4)))
+ biases.append(torch.nn.Parameter(torch.Tensor(hidden_size)))
+ biases.append(torch.nn.Parameter(torch.Tensor(hidden_size)))
+
+ return BertEncoderPreln(bert_config, weights, biases)
+
+
+class Conv1D(torch.nn.Module):
+ """
+ 1D-convolutional layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2).
+ Basically works like a linear layer but the weights are transposed.
+ Args:
+ nf (`int`): The number of output features.
+ nx (`int`): The number of input features.
+ """
+ def __init__(self, nf, nx):
+ super().__init__()
+ self.nf = nf
+ w = torch.empty(nx, nf)
+ self.weight = torch.nn.Parameter(w)
+ self.bias = torch.nn.Parameter(torch.zeros(nf))
+
+ def forward(self, x):
+ size_out = x.size()[:-1] + (self.nf, )
+ x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
+ x = x.view(size_out)
+ return x
+
+
+def create_conv1d_model():
+ nf = 128
+ nx = 128
+
+ return torch.nn.ModuleList([Conv1D(nf, nx) for i in range(4)])
+
+
+class TestCompression:
+ def setup_method(self, method):
+ reset_random()
+
+ def get_ds_config(self):
+ ds_config_dict = {
+ "train_micro_batch_size_per_gpu": 1,
+ "optimizer": {
+ "type": "Lamb",
+ "params": {
+ "lr": 0.00015
+ }
+ },
+ "fp16": {
+ "enabled": True
+ },
+ "compression_training": {
+ "weight_quantization": {
+ "shared_parameters": {
+ "enabled": True,
+ "quantizer_kernel": False,
+ "schedule_offset": 50,
+ "quantize_groups": 1,
+ "quantize_verbose": False,
+ "quantization_type": "asymmetric",
+ "rounding": "nearest",
+ "fp16_mixed_quantize": {
+ "enabled": False,
+ "quantize_change_ratio": 0.001
+ }
+ },
+ "different_groups": {
+ "wq1": {
+ "params": {
+ "start_bits": 12,
+ "target_bits": 8,
+ "quantization_period": 50
+ },
+ "modules": ["attention.self",
+ "intermediate"]
+ },
+ "wq2": {
+ "params": {
+ "start_bits": 12,
+ "target_bits": 4,
+ "quantization_period": 50
+ },
+ "modules": ["attention.output"]
+ }
+ }
+ },
+ "activation_quantization": {
+ "shared_parameters": {
+ "enabled": True,
+ "quantization_type": "asymmetric",
+ "range_calibration": "dynamic",
+ "schedule_offset": 50
+ },
+ "different_groups": {
+ "aq1": {
+ "params": {
+ "bits": 8
+ },
+ "modules": ["attention.output"]
+ }
+ }
+ },
+ "sparse_pruning": {
+ "shared_parameters": {
+ "enabled": True,
+ "schedule_offset": 30,
+ "method": "l1"
+ },
+ "different_groups": {
+ "sp1": {
+ "params": {
+ "dense_ratio": 0.5
+ },
+ "modules": ["attention.self"]
+ }
+ }
+ },
+ "row_pruning": {
+ "shared_parameters": {
+ "enabled": True,
+ "schedule_offset": 20,
+ "method": "topk"
+ },
+ "different_groups": {
+ "rp1": {
+ "params": {
+ "dense_ratio": 0.5
+ },
+ "modules": ["intermediate.dense"],
+ "related_modules": [["layer.\\w+.output.dense"]]
+ }
+ }
+ },
+ "head_pruning": {
+ "shared_parameters": {
+ "enabled": True,
+ "schedule_offset": 10,
+ "method": "topk",
+ "num_heads": 12
+ },
+ "different_groups": {
+ "rp1": {
+ "params": {
+ "dense_ratio": 0.5
+ },
+ "modules": ["attention.output.dense"],
+ "related_modules": [["self.query",
+ "self.key",
+ "self.value"]]
+ }
+ }
+ }
+ }
+ }
+
+ return ds_config_dict
+
+ def test_linear_layer_compress(self, tmpdir):
+ model = create_bert_model()
+ compressed_model = init_compression(model, self.get_ds_config())
+
+ assert isinstance(compressed_model.layer[0].attention.self.query,
+ LinearLayer_Compress)
+ assert isinstance(compressed_model.layer[0].attention.self.key,
+ LinearLayer_Compress)
+ assert isinstance(compressed_model.layer[0].attention.self.value,
+ LinearLayer_Compress)
+
+ def test_mpu_compress(self, tmpdir):
+ from megatron import mpu
+ args_defaults = {
+ 'num_layers': 2,
+ 'hidden_size': 128,
+ 'num_attention_heads': 8,
+ 'max_position_embeddings': 128,
+ }
+
+ model = get_gpt2_model(args_defaults)
+ compressed_model = init_compression(model, self.get_ds_config(), mpu=mpu)
+
+ assert isinstance(
+ compressed_model.module.language_model.transformer.layers[0].attention.
+ query_key_value,
+ ColumnParallelLinear_Compress)
+ assert isinstance(
+ compressed_model.module.language_model.transformer.layers[0].attention.dense,
+ RowParallelLinear_Compress)
+ assert isinstance(
+ compressed_model.module.language_model.transformer.layers[0].mlp.
+ dense_h_to_4h,
+ ColumnParallelLinear_Compress)
+ assert isinstance(
+ compressed_model.module.language_model.transformer.layers[0].mlp.
+ dense_4h_to_h,
+ RowParallelLinear_Compress)
+
+ def test_conv1d_convertion(self, tmpdir):
+ model = create_conv1d_model()
+ compressed_model = convert_conv1d_to_linear(model, Conv1D)
+
+ assert isinstance(compressed_model[0], torch.nn.Linear)
+ assert isinstance(compressed_model[1], torch.nn.Linear)
+ assert isinstance(compressed_model[2], torch.nn.Linear)
+ assert isinstance(compressed_model[3], torch.nn.Linear)