import fnmatch
from typing import Union
import torch
from .._utils import get_init_params
from ..layers import (MLP, Attention, ColumnLinear, Embedding, GatedMLP,
LayerNorm, RmsNorm, RowLinear)
from ..layers.moe import MixtureOfExperts
from ..models.modeling_utils import LayerQuantConfig, QuantConfig
from ..parameter import Parameter
from .layers import (FP8Linear, FP8RowLinear, Fp8RowwiseAttention,
Fp8RowwiseGatedMLP, Fp8RowwiseMLP, Fp8RowwiseRmsNorm,
Int8SmoothQuantLinear, Int8SmoothQuantRowLinear,
QServeAttention, QServeGatedMLP, QServeMLP, QServeRmsNorm,
SmoothQuantAttention, SmoothQuantGatedMLP,
SmoothQuantLayerNorm, SmoothQuantMLP, SmoothQuantRmsNorm,
WeightOnlyGroupwiseQuantColumnLinear,
WeightOnlyGroupwiseQuantRowLinear,
WeightOnlyQuantColumnLinear, WeightOnlyQuantEmbedding,
WeightOnlyQuantRowLinear)
from .mode import W8A8_SQ_PLUGIN_LIST, QuantAlgo, QuantMode
def quantize_layers(
model,
quant_config: QuantConfig,
quant_map,
preprocess_init_params=None,
):
exclude_modules = quant_config.exclude_modules or [
'*lm_head',
'*router',
'*vocab_embedding',
'*position_embedding',
'*block_embedding',
'*shared_expert_gate',
]
for name, module, parent in model.named_modules_with_parent():
module_name = name.rsplit('.', 1)[-1]
is_excluded = False
for exclude_module in exclude_modules:
if fnmatch.fnmatchcase(name, exclude_module):
is_excluded = True
# The MoE module is quantized at initialization time, so we need to
# re-initialize an FP (non-quantized) version of the MoE module here.
if isinstance(module, MixtureOfExperts):
init_params = get_init_params(module, MixtureOfExperts)
init_params["quant_mode"] = QuantMode(0)
original_layer = MixtureOfExperts(**init_params)
if parent is not None:
setattr(parent, module_name, original_layer)
else:
model = original_layer
break
if not is_excluded:
quant_cls = None
for cls in quant_map:
if isinstance(module, cls):
quant_cls = quant_map[cls]
break
if quant_cls is None:
continue
init_params = get_init_params(module, quant_cls)
if "bias" in init_params:
init_params["bias"] = init_params["bias"] is not None
if isinstance(module, ColumnLinear):
init_params["out_features"] = module.out_features * module.tp_size
elif isinstance(module, RowLinear):
init_params["in_features"] = module.in_features * module.tp_size
if preprocess_init_params is not None:
preprocess_init_params(init_params, name, module)
quant_layer = quant_cls(**init_params)
if parent is not None:
setattr(parent, module_name, quant_layer)
else:
model = quant_layer
setattr(model, 'quant_mode', quant_config.quant_mode)
return model
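# Illustrative note (added comment): the exclude patterns above are shell-style
# globs matched with fnmatch.fnmatchcase against the *full* module name. For a
# hypothetical module named 'transformer.layers.0.shared_expert_gate', the
# pattern '*shared_expert_gate' matches and the module keeps (or regains, for
# MoE) its full-precision class, while a hypothetical 'transformer.layers.0.mlp.fc'
# matches no default pattern and is swapped for its counterpart from quant_map.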
def weight_only_quantize(model, quant_config: QuantConfig, model_config=None):
assert quant_config.quant_mode.is_weight_only()
try:
model_cfg = model.config
except AttributeError:
model_cfg = model_config
quant_map = {
ColumnLinear: WeightOnlyQuantColumnLinear,
RowLinear: WeightOnlyQuantRowLinear,
Embedding: WeightOnlyQuantEmbedding,
}
def preprocess_init_params(init_params, name, module):
init_params["quant_mode"] = quant_config.quant_mode
if isinstance(module, ColumnLinear):
module_name = name.rsplit('.', 1)[-1]
init_params["transb"] = module_name == "lm_head"
init_params["tp_rank"] = model_cfg.mapping.tp_rank
model = quantize_layers(
model,
quant_config,
quant_map,
preprocess_init_params,
)
return model
def weight_only_groupwise_quantize(model,
quant_config: QuantConfig,
model_config=None):
assert quant_config.quant_mode.is_weight_only()
try:
model_cfg = model.config
except AttributeError:
model_cfg = model_config
quant_map = {
ColumnLinear: WeightOnlyGroupwiseQuantColumnLinear,
RowLinear: WeightOnlyGroupwiseQuantRowLinear,
}
def preprocess_init_params(init_params, name, module):
init_params["group_size"] = quant_config.group_size
init_params["pre_quant_scale"] = quant_config.pre_quant_scale
init_params["zero"] = quant_config.has_zero_point
init_params["use_w4a8_awq"] = quant_config.quant_algo == QuantAlgo.W4A8_AWQ
init_params["use_int8_weight"] = quant_config.quant_algo == QuantAlgo.W8A16_GPTQ
init_params["tp_rank"] = model_cfg.mapping.tp_rank
model = quantize_layers(
model,
quant_config,
quant_map,
preprocess_init_params,
)
return model
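# Hedged sketch (added for illustration, not part of the original module): a
# QuantConfig carrying the fields consumed by preprocess_init_params above.
def _example_groupwise_quant_config():
    """Return an example groupwise weight-only config.

    The algorithm choice (W4A16_AWQ) and the concrete values for group_size,
    pre_quant_scale and has_zero_point are arbitrary assumptions made for this
    sketch; pick them to match the calibrated checkpoint.
    """
    return QuantConfig(quant_algo=QuantAlgo.W4A16_AWQ,
                       group_size=128,
                       pre_quant_scale=True,
                       has_zero_point=False)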
def smooth_quantize_ootb(
model,
quant_config: QuantConfig,
):
quant_map = {
ColumnLinear: Int8SmoothQuantLinear,
RowLinear: Int8SmoothQuantRowLinear,
}
model = quantize_layers(
model,
quant_config,
quant_map,
)
return model
def smooth_quantize_plugin(model, quant_mode):
quant_map = {
RmsNorm: SmoothQuantRmsNorm,
LayerNorm: SmoothQuantLayerNorm,
GatedMLP: SmoothQuantGatedMLP,
MLP: SmoothQuantMLP,
Attention: SmoothQuantAttention,
}
for name, layer, parent in model.named_modules_with_parent():
layer_name = name.rsplit('.', 1)[-1]
if layer_name in ['ln_f', 'ln_embed']:
continue
quant_cls = None
for cls in quant_map:
if isinstance(layer, cls):
quant_cls = quant_map[cls]
break
if quant_cls is None:
continue
init_params = get_init_params(layer, quant_cls)
init_params["quant_mode"] = quant_mode
if isinstance(layer, Attention):
init_params["num_attention_heads"] = layer.num_attention_heads * layer.tp_size
quant_layer = quant_cls(**init_params)
if parent is not None:
setattr(parent, layer_name, quant_layer)
else:
model = quant_layer
setattr(model, 'quant_mode', quant_mode)
return model
def smooth_quantize(model, quant_config: QuantConfig):
assert quant_config.quant_mode.has_act_and_weight_quant()
if quant_config.quant_algo in W8A8_SQ_PLUGIN_LIST:
return smooth_quantize_plugin(model, quant_config.quant_mode)
else:
return smooth_quantize_ootb(model, quant_config)
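# Note (added comment): algorithms listed in W8A8_SQ_PLUGIN_LIST take the fused
# SmoothQuant plugin path above, which also replaces the normalization, MLP and
# Attention blocks; other W8A8 SmoothQuant algorithms fall back to the OOTB
# path, which only swaps ColumnLinear/RowLinear for their INT8 variants.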
def fp8_quantize(model, quant_config: QuantConfig):
assert quant_config.quant_mode.has_fp8_qdq()
quant_map = {
ColumnLinear: FP8Linear,
RowLinear: FP8RowLinear,
}
model = quantize_layers(
model,
quant_config,
quant_map,
)
return model
def fp8_rowwise_quantize(model, quant_config: QuantConfig):
assert quant_config.quant_mode.has_fp8_rowwise()
quant_cls_map = {
RmsNorm: Fp8RowwiseRmsNorm,
GatedMLP: Fp8RowwiseGatedMLP,
MLP: Fp8RowwiseMLP,
Attention: Fp8RowwiseAttention,
}
if quant_config.exclude_modules is None:
exclude_modules = ['*ln_f', '*ln_embed']
else:
exclude_modules = quant_config.exclude_modules
def extract_layer_idx(name):
ss = name.split('.')
for s in ss:
if s.isdigit():
return int(s)
return None
# Meta's LLaMA 3.1 recipe:
# (1) Skip quantization for the first and last Transformer layers
# (2) Skip quantization for the Attention layers
if quant_config.use_meta_recipe:
exclude_modules.extend(['*input_layernorm', '*attention'])
for name, layer, parent in model.named_modules_with_parent():
module_name = name.rsplit('.', 1)[-1]
if quant_config.use_meta_recipe:
local_layer_idx = extract_layer_idx(name)
mapping = model.config.mapping
layers_range = mapping.pp_layers(model.config.num_hidden_layers)
if mapping.is_first_pp_rank() and local_layer_idx == 0:
continue
if mapping.is_last_pp_rank() and local_layer_idx == len(layers_range) - 1:
continue
quant_cls = None
for cls in quant_cls_map:
if isinstance(layer, cls):
quant_cls = quant_cls_map[cls]
break
if quant_cls is None:
continue
is_excluded = False
for exclude_module in exclude_modules:
if fnmatch.fnmatchcase(name, exclude_module):
is_excluded = True
break
if is_excluded:
continue
init_params = get_init_params(layer, quant_cls)
init_params["quant_mode"] = quant_config.quant_mode
if isinstance(layer, Attention):
init_params["num_attention_heads"] = layer.num_attention_heads * layer.tp_size
quant_layer = quant_cls(**init_params, clamp_val=quant_config.clamp_val)
if parent is not None:
setattr(parent, module_name, quant_layer)
else:
model = quant_layer
setattr(model, 'quant_mode', quant_config.quant_mode)
return model
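# Note (added comment): besides the exclude patterns (default '*ln_f' and
# '*ln_embed'), the Meta recipe above additionally excludes '*input_layernorm'
# and '*attention', and leaves the first local layer on the first
# pipeline-parallel rank and the last local layer on the last one unquantized.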
# TODO: These functions should be moved to ModelOpt.
def qserve_quantize_weight_per_group(linear_weight: torch.HalfTensor,
s1_scales: torch.FloatTensor,
s2_scales: torch.FloatTensor,
s2_szeros: torch.FloatTensor,
group_size: int) -> torch.CharTensor:
out_features = linear_weight.shape[0]
in_features = linear_weight.shape[1]
# Step 1: Quantize the weights to int8
linear_weight = linear_weight.div(
s1_scales.reshape(out_features, 1).to(linear_weight.device))
linear_weight = linear_weight.round()
# assert linear_weight.min() >= -119 and linear_weight.max() <= 119, "Stage 1: Quantized weight out of range" # 119 is the "magic" number
assert (linear_weight.min() >= -128 and linear_weight.max()
<= 127), "Stage 1: Quantized weight out of range"
# Step 2: Quantize the weights to int4
linear_weight = linear_weight.reshape(out_features,
in_features // group_size, group_size)
s2_szeros = s2_szeros.reshape(out_features, in_features // group_size,
1).to(torch.float16).to(linear_weight.device)
s2_scales = s2_scales.reshape(out_features, in_features // group_size,
1).to(torch.float16).to(linear_weight.device)
linear_weight = linear_weight.add(s2_szeros).div(s2_scales).round()
assert (linear_weight.min() >= 0 and linear_weight.max()
<= 15), "Stage 2: Quantized weight out of range"
qweight = linear_weight.reshape(out_features, in_features).to(torch.int8)
return qweight
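# Hedged worked example (added for illustration, not part of the original
# module): derives self-consistent scales/zeros for a tiny random weight and
# feeds them through qserve_quantize_weight_per_group. Shapes, group size and
# the scale derivation are example assumptions, not DeepCompressor's
# calibration procedure.
def _example_qserve_per_group_quantization():
    out_features, in_features, group_size = 8, 32, 16
    weight = torch.randn(out_features, in_features, dtype=torch.float16)
    # Stage 1: per-output-channel scales mapping the weight into [-127, 127].
    s1_scales = (weight.abs().amax(dim=1).float() / 127.0).clamp(min=1e-8)
    w_int8 = (weight.float() / s1_scales.unsqueeze(1)).round()
    # Stage 2: per-group asymmetric scales/zeros mapping int8 values to [0, 15].
    groups = w_int8.reshape(out_features, in_features // group_size, group_size)
    g_min, g_max = groups.amin(dim=-1), groups.amax(dim=-1)
    s2_scales = ((g_max - g_min) / 15.0).clamp(min=1.0)
    s2_szeros = -g_min
    qweight = qserve_quantize_weight_per_group(weight, s1_scales, s2_scales,
                                               s2_szeros, group_size)
    assert qweight.dtype == torch.int8 and qweight.shape == weight.shape
    return qweight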
def qserve_quantize_weight_per_channel(
linear_weight: torch.HalfTensor, s1_scales: torch.FloatTensor,
s1_szeros: torch.FloatTensor) -> torch.CharTensor:
out_features = linear_weight.shape[0]
in_features = linear_weight.shape[1]
# Step 1: Quantize the weights to int4
s1_scales = s1_scales.reshape(out_features, 1).to(linear_weight.device)
s1_szeros = s1_szeros.reshape(out_features, 1).to(linear_weight.device)
qweight = linear_weight.add(s1_szeros).div(s1_scales).round()
assert (qweight.min() >= 0
and qweight.max() <= 15), "Quantized weight out of range"
return qweight.reshape(out_features, in_features).to(torch.int8)
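# Illustrative note (added comment): the per-channel path maps each output
# channel affinely into the unsigned int4 range [0, 15], so one consistent
# choice is s1_szeros = -min(w) and s1_scales = (max(w) - min(w)) / 15 per
# output channel; any szeros/scales pair that keeps
# round((w + szeros) / scales) inside [0, 15] satisfies the assertion above.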
# Pack the quantized weights, scales and zeros, and apply the reordering
# required by the QServe kernels.
# Returns the processed [qweight, s1_scales, s2_scales, s2_zeros].
def qserve_pack_reorder_per_group(qweight: torch.CharTensor,
s1_scales: torch.FloatTensor,
s2_scales: torch.FloatTensor,
s2_szeros: torch.FloatTensor, group_size):
out_features = qweight.shape[0]
in_features = qweight.shape[1]
outputs = []
s1_scales = s1_scales.reshape(out_features).to(torch.float16)
s2_szeros = s2_szeros.reshape(out_features,
in_features // group_size).to(torch.int8)
s2_scales = s2_scales.reshape(out_features,
in_features // group_size).to(torch.int8)
# Step 3: Pack the int4 values into the final packed weight layout
# ---- Repack the weight ---- #
assert qweight.dtype == torch.int8
# pack to M // 32, K // 32, (8, 4), ([2], 2, 2, 4)
W_unpack_reorder = (qweight.reshape(
out_features // 32,
2,
2,
8,
in_features // 32,
2,
4,
4,
).permute(0, 4, 3, 6, 1, 5, 2, 7).contiguous())
W_unpack_reorder = W_unpack_reorder.permute(0, 1, 2, 3, 5, 6, 7, 4).contiguous().to(torch.int8)
# B_fp16_reorder = B_fp16_reorder[:, :, :, :, :, :, [3, 2, 1, 0]].contiguous()
# [16, 0, 17, 1, ...]
W_unpack_repacked = (W_unpack_reorder[..., 1] << 4) + W_unpack_reorder[..., 0]
W_unpack_repacked = W_unpack_repacked.reshape(out_features // 32,
in_features // 32, 32, 16)
W_unpack_repacked = W_unpack_repacked.reshape(out_features,
in_features // 2)
outputs.append(W_unpack_repacked)
# For the last dimension, organize entries as 0, 8, 16, 24, 1, 9, 17, 25, ... following the requirement of the tensor core GEMM.
# ---- Pack the scales ---- #
outputs.append(s1_scales.reshape(out_features))
s2_scales = (s2_scales.reshape(out_features, in_features //
group_size).transpose(0, 1).contiguous())
s2_scales = s2_scales.reshape(in_features // group_size, out_features // 32,
32)
s2_scales = (s2_scales.reshape(in_features // group_size,
out_features // 32, 4,
8).transpose(-2, -1).contiguous())
s2_scales = s2_scales.reshape(in_features // group_size,
out_features).contiguous()
outputs.append(s2_scales)
# ---- Pack the zeros ---- #
s2_szeros = (s2_szeros.reshape(out_features, in_features //
group_size).transpose(0, 1).contiguous())
s2_szeros = s2_szeros.reshape(in_features // group_size, out_features // 32,
32)
s2_szeros = (s2_szeros.reshape(in_features // group_size,
out_features // 32, 4,
8).transpose(-2, -1).contiguous())
s2_szeros = (s2_szeros.reshape(in_features // group_size,
out_features).contiguous())
# Since (q - s2_zeros) * s2_scales = q * s2_scales - s2_zeros * s2_scales,
# we store s2_zeros as -s2_zeros * s2_scales.
s2_szeros = (-s2_szeros).int()  # it has already been pre-scaled in DeepCompressor
s2_szeros = s2_szeros.to(torch.int8)
outputs.append(s2_szeros)
return outputs
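# Layout summary (added comment) of the four outputs produced above:
#   [0] packed weights: int8 tensor of shape (out_features, in_features // 2),
#       two int4 values per byte (low nibble = W_unpack_reorder[..., 0],
#       high nibble = W_unpack_reorder[..., 1]).
#   [1] s1_scales: float16 tensor of shape (out_features,).
#   [2] s2_scales: int8 tensor of shape (in_features // group_size, out_features)
#       with each 32-channel block interleaved as 0, 8, 16, 24, 1, 9, ...
#   [3] s2_szeros: int8 tensor with the same layout, stored negated (the
#       zero * scale pre-scaling happens upstream in DeepCompressor).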
def qserve_pack_reorder_per_channel(qweight: torch.CharTensor,
s1_scales: torch.FloatTensor,
s1_szeros: torch.FloatTensor):
out_features = qweight.shape[0]
in_features = qweight.shape[1]
outputs = []
# ---- Repack the weight ---- #
assert qweight.dtype == torch.int8
# pack to M // 32, K // 32, (8, 4), ([2], 2, 2, 4)
W_unpack_reorder = (qweight.reshape(
out_features // 32,
2,
2,
8,
in_features // 32,
2,
4,
4,
).permute(0, 4, 3, 6, 1, 5, 2, 7).contiguous())
W_unpack_reorder = W_unpack_reorder.permute(0, 1, 2, 3, 5, 6, 7, 4).contiguous()
# B_fp16_reorder = B_fp16_reorder[:, :, :, :, :, :, [3, 2, 1, 0]].contiguous()
# [16, 0, 17, 1, ...]
W_unpack_repacked = (W_unpack_reorder[..., 1] << 4) + W_unpack_reorder[..., 0]
W_unpack_repacked = W_unpack_repacked.reshape(out_features // 32,
in_features // 32, 32, 16)
W_unpack_repacked = W_unpack_repacked.reshape(out_features, in_features //
2).contiguous()
outputs.append(W_unpack_repacked)
# ---- Pack the scales and zeros ---- #
s1_scales = s1_scales.reshape(out_features).contiguous()
outputs.append(s1_scales.half())
s1_szeros = s1_szeros.reshape(out_features).contiguous().half()
outputs.append(s1_szeros)
return outputs
# TODO: Duplicates smooth_quantize and quantize_layers
def qserve_quantize(model, quant_config: QuantConfig):
quant_mode = quant_config.quant_mode
assert quant_config.quant_mode.is_qserve_w4a8()
quant_map = {
RmsNorm: QServeRmsNorm,
LayerNorm: QServeRmsNorm,
GatedMLP: QServeGatedMLP,
MLP: QServeMLP,
Attention: QServeAttention,
}
for name, layer, parent in model.named_modules_with_parent():
layer_name = name.rsplit('.', 1)[-1]
if layer_name in ['ln_f', 'ln_embed']:
continue
quant_cls = None
for cls in quant_map:
if isinstance(layer, cls):
quant_cls = quant_map[cls]
break
if quant_cls is None:
continue
init_params = get_init_params(layer, quant_cls)
init_params["quant_mode"] = quant_mode
if isinstance(layer, Attention):
init_params["num_attention_heads"] = layer.num_attention_heads * layer.tp_size
quant_layer = quant_cls(**init_params)
if parent is not None:
setattr(parent, layer_name, quant_layer)
else:
model = quant_layer
setattr(model, 'quant_mode', quant_mode)
return model
# For now, assume KV-cache quantization is enabled for all layers.
def kv_cache_quantize(model):
for name, module in model.named_modules():
if isinstance(module,
(Attention, SmoothQuantAttention, Fp8RowwiseAttention)):
module.kv_cache_scaling_factor = Parameter(shape=(1, ),
dtype='float32')
return model
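# Note (added comment): kv_cache_quantize only attaches a per-tensor float32
# kv_cache_scaling_factor Parameter of shape (1,) to every Attention-like
# module; the actual scale value is expected to be filled in later, e.g. from
# calibrated checkpoint weights (assumption, not shown in this file).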
def quantize(model, quant_config: Union[QuantConfig, LayerQuantConfig]):
quant_mode = quant_config.layer_quant_mode
for name, module, parent in model.named_modules_with_parent():
if quant_config.quant_algo == QuantAlgo.MIXED_PRECISION:
if name in quant_mode.keys():
layer_quant_mode = quant_mode[name]
else:
continue
else:
layer_quant_mode = quant_mode
if layer_quant_mode == QuantMode(0):
continue
layer_quant_cfg = quant_config.get_quant_cfg(name)
if layer_quant_mode.has_fp8_qdq():
module = fp8_quantize(module, layer_quant_cfg)
elif layer_quant_mode.has_fp8_rowwise():
module = fp8_rowwise_quantize(module, layer_quant_cfg)
elif layer_quant_mode.is_qserve_w4a8():
model = qserve_quantize(model, quant_config)
elif layer_quant_mode.has_act_and_weight_quant():
module = smooth_quantize(module, layer_quant_cfg)
elif layer_quant_mode.is_weight_only():
if layer_quant_mode.has_per_group_scaling():
module = weight_only_groupwise_quantize(module, layer_quant_cfg,
model.config)
else:
module = weight_only_quantize(module, layer_quant_cfg,
model.config)
if parent is not None:  # per-layer case: swap the quantized sub-module into its parent
module_name = name.rsplit('.', 1)[-1]
setattr(parent, module_name, module)
else:  # whole-model case: the quantized module replaces the model itself
model = module
break
if quant_config.quant_mode.has_kv_cache_quant():
model = kv_cache_quantize(model)
setattr(model, 'quant_mode', quant_config.quant_mode)
return model
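# Hedged usage sketch (added for illustration, not part of the original
# module): quantize() applied with a single whole-model QuantConfig. With
# QuantAlgo.MIXED_PRECISION, a LayerQuantConfig would instead be consulted per
# named layer. The config values below are example assumptions.
def _example_quantize_usage(model):
    quant_config = QuantConfig(quant_algo=QuantAlgo.FP8,
                               kv_cache_quant_algo=QuantAlgo.FP8)
    return quantize(model, quant_config)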