-
Notifications
You must be signed in to change notification settings - Fork 212
/
auto_scale.py
479 lines (427 loc) · 15.2 KB
/
auto_scale.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
import gc
import torch
import torch.nn as nn
from transformers.models.bloom.modeling_bloom import BloomBlock, BloomGelu
from transformers.models.opt.modeling_opt import OPTDecoderLayer
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaRMSNorm
from transformers.activations import GELUActivation
from .qmodule import ScaledActivation
from ..utils.module import get_op_by_name, get_op_name, set_op_by_name
# Public API of this module: the per-block scale search and the function that
# folds the searched scales back into the block's weights.
__all__ = ["auto_scale_block", "apply_scale"]
@torch.no_grad()
def get_weight_scale(weight, q_group_size=-1):
    """Return the per-input-channel mean of |weight| normalized by its group max.

    Each row (or, when ``q_group_size > 0``, each contiguous group of
    ``q_group_size`` elements) is divided by its own absolute maximum. The
    result is reshaped back to ``weight.shape`` and averaged over dim 0.
    """
    original_shape = weight.shape
    grouped = weight.view(-1, q_group_size) if q_group_size > 0 else weight
    magnitudes = grouped.abs()
    normalized = magnitudes / magnitudes.amax(dim=1, keepdim=True)
    return normalized.view(original_shape).mean(0)
@torch.no_grad()
def get_act_scale(x):
    """Return the mean absolute value of ``x`` per channel (last dimension).

    All leading dimensions are flattened together before averaging.
    """
    channels = x.shape[-1]
    flat_magnitudes = x.abs().view(-1, channels)
    return flat_magnitudes.mean(0)
@torch.no_grad()
def scale_ln_fcs(ln, fcs, scales):
    """Move per-channel ``scales`` from a norm layer into the following linear(s).

    The norm's weight (and bias, if present) is divided by ``scales``, and the
    input columns of every linear layer in ``fcs`` are multiplied by
    ``scales``. For a norm whose output is linear in its weight and bias, the
    composed output is therefore unchanged.

    Args:
        ln: Norm layer with a ``weight`` (and optionally ``bias``) parameter.
        fcs: A linear layer or a list of linear layers fed by ``ln``.
        scales: 1-D tensor with one entry per channel.

    Raises:
        AssertionError: if any updated parameter contains NaN.
    """
    if not isinstance(fcs, list):
        fcs = [fcs]

    ln_scales = scales.to(ln.weight.device)
    ln.weight.div_(ln_scales)
    if getattr(ln, "bias", None) is not None:
        ln.bias.div_(ln_scales)

    for fc in fcs:
        # Each linear layer may live on a different device than the norm,
        # so move the scales per layer (as scale_gelu_fc does).
        fc.weight.mul_(scales.view(1, -1).to(fc.weight.device))

    for p in ln.parameters():
        assert torch.isnan(p).sum() == 0
    for fc in fcs:
        for p in fc.parameters():
            assert torch.isnan(p).sum() == 0
@torch.no_grad()
def scale_fc_fc(fc1, fc2, scales):
    """Move per-channel ``scales`` from the output of ``fc1`` into ``fc2``.

    The last ``scales.size(0)`` output rows of ``fc1`` (weight and, if present,
    the matching bias entries) are divided by ``scales``. The input columns of
    ``fc2`` are multiplied by ``scales``.

    ``fc1`` may produce more features than ``fc2`` consumes, for example a
    fused QKV projection feeding the attention output projection. In that case
    only its trailing rows are rescaled, and the bias slice must match the
    weight slice.

    Raises:
        AssertionError: if either argument is not ``nn.Linear`` or an updated
            parameter contains NaN.
    """
    assert isinstance(fc1, nn.Linear)
    assert isinstance(fc2, nn.Linear)

    n_scaled = scales.size(0)
    fc1_scales = scales.to(fc1.weight.device)

    fc1.weight[-n_scaled:].div_(fc1_scales.view(-1, 1))
    if fc1.bias is not None:
        # Rescale the same trailing rows as the weight; dividing the full bias
        # fails (or is wrong) when fc1 is wider than ``scales``.
        fc1.bias[-n_scaled:].div_(fc1_scales.view(-1))

    # fc2 may live on a different device than fc1.
    fc2.weight.mul_(scales.view(1, -1).to(fc2.weight.device))

    for p in fc1.parameters():
        assert torch.isnan(p).sum() == 0
    for p in fc2.parameters():
        assert torch.isnan(p).sum() == 0
@torch.no_grad()
def scale_gelu_fc(gelu, fc, scales):
    """Multiply the input columns of ``fc`` (fed by ``gelu``) by ``scales``.

    Only ``fc`` is modified here; ``gelu`` is type-checked but left untouched.

    Raises:
        AssertionError: on unexpected module types or NaN in ``fc``'s parameters.
    """
    assert isinstance(gelu, (nn.GELU, BloomGelu, GELUActivation))
    assert isinstance(fc, nn.Linear)

    column_scales = scales.view(1, -1).to(fc.weight.device)
    fc.weight.mul_(column_scales)

    assert all(torch.isnan(param).sum() == 0 for param in fc.parameters())
@torch.no_grad()
def auto_scale_block(module, module_kwargs, w_bit, q_config, input_feat):
    """Search per-channel scales for each scalable linear group in ``module``.

    For every supported architecture, each group of linear layers that shares
    a preceding op (``prev_op``) is tried with input-channel scales
    ``x_max ** ratio``, for ``ratio`` in ``{0, 1/20, ..., 19/20}``.
    ``x_max`` is the per-channel mean absolute input. The ratio whose
    pseudo-quantized output is closest (MSE) to the original output is kept.

    Args:
        module: Decoder block (OPT, Llama, Bloom, MPT, Falcon, BigCode or
            GPT-NeoX). Its weights are modified during each trial and restored
            from a state-dict snapshot afterwards.
        module_kwargs: Keyword arguments passed when re-running the inspected
            (sub)module. NOTE: a ``use_cache`` entry is popped from this dict
            in place.
        w_bit: Bit-width forwarded to ``pseudo_quantize_tensor``; when ``None``
            the weights are left unquantized during the search.
        q_config: Extra keyword arguments for ``pseudo_quantize_tensor``.
        input_feat: Mapping from layer name (relative to ``module``) to the
            tensor fed as the search input for that group.

    Returns:
        List of ``(prev_op_name, layer_names, scales)`` tuples with ``scales``
        on the CPU, in the format consumed by ``apply_scale``.

    Raises:
        RuntimeError: if no grid point yields a finite loss.
        NotImplementedError: for unsupported architectures.
    """
    from .quantizer import pseudo_quantize_tensor

    # Weight quantization function used during the search (identity when no
    # bit-width is requested).
    if w_bit is not None:

        def w_quantize_func(p):
            return pseudo_quantize_tensor(
                p,
                n_bit=w_bit,
                **q_config,
            ).detach()

    else:

        def w_quantize_func(p):
            return p

    # Keep ``use_cache`` out of the kwargs forwarded to the block calls below.
    module_kwargs.pop("use_cache", None)

    def _search_module_scale(block, linears2scale: list, x, kwargs=None):
        """Grid-search the scale exponent for ``linears2scale`` inside ``block``.

        Returns the best 1-D scales (detached), on the device of ``x`` after it
        is moved to ``block``'s device.
        """
        if kwargs is None:
            kwargs = {}
        # w: co, ci
        # x: n, ci
        x = x.to(next(block.parameters()).device)
        org_out = block(x, **kwargs)
        if isinstance(org_out, tuple):
            org_out = org_out[0]

        x_max = get_act_scale(x)

        best_error = float("inf")
        best_ratio = -1
        best_scales = None

        n_grid = 20
        history = []

        # CPU snapshot so every trial starts from the original weights.
        org_sd = {k: v.cpu() for k, v in block.state_dict().items()}
        for step in range(n_grid):
            ratio = step / n_grid
            scales = x_max.pow(ratio).clamp(min=1e-4).view(-1)
            # Normalize so that max(scales) * min(scales) == 1.
            scales = scales / (scales.max() * scales.min()).sqrt()
            for fc in linears2scale:
                # Scale, pseudo-quantize, then unscale with the SAME tensor so
                # the block output reflects only the quantization error.
                fc_scales = scales.view(1, -1).to(fc.weight.device)
                fc.weight.mul_(fc_scales)
                fc.weight.data = w_quantize_func(fc.weight.data) / fc_scales
            out = block(x, **kwargs)
            if isinstance(out, tuple):
                out = out[0]

            # float() prevents overflow in low-precision dtypes.
            loss = (org_out - out).float().pow(2).mean().item()
            history.append(loss)
            if loss < best_error:
                best_error = loss
                best_ratio = ratio
                best_scales = scales
            block.load_state_dict(org_sd)

        if best_ratio == -1:
            # No loss was below +inf, i.e. every trial gave NaN or inf.
            raise RuntimeError(
                "Scale search found no finite loss; "
                f"losses per grid point: {history}"
            )

        best_scales = best_scales.view(-1)

        assert torch.isnan(best_scales).sum() == 0, best_scales
        return best_scales.detach()

    def _auto_get_scale(prev_op, layers, inp, module2inspect=None, kwargs=None):
        """Search scales for ``layers`` and return ``(prev_op_name, layer_names, scales)``.

        module2inspect: if given, the output difference of this module is
        measured instead of that of the single layer in ``layers``.
        """
        if kwargs is None:
            kwargs = {}
        if module2inspect is None:
            assert len(layers) == 1
            module2inspect = layers[0]

        scales = _search_module_scale(module2inspect, layers, inp, kwargs)
        scales = scales.detach().cpu()
        # prev_op_name, [layer_name], scale
        return (
            get_op_name(module, prev_op),
            tuple(get_op_name(module, m) for m in layers),
            scales,
        )

    scales_list = []  # return the searched scales
    # Used to recognize architectures that are matched by class-name substring.
    cls_name = str(module.__class__).lower()

    if isinstance(module, OPTDecoderLayer):
        # attention input
        scales_list.append(
            _auto_get_scale(
                prev_op=module.self_attn_layer_norm,
                layers=[
                    module.self_attn.q_proj,
                    module.self_attn.k_proj,
                    module.self_attn.v_proj,
                ],
                inp=input_feat["self_attn.q_proj"],
                module2inspect=module.self_attn,
                kwargs=module_kwargs,
            )
        )
        # attn out
        scales_list.append(
            _auto_get_scale(
                prev_op=module.self_attn.v_proj,
                layers=[module.self_attn.out_proj],
                inp=input_feat["self_attn.out_proj"],
            )
        )
        # fc1
        scales_list.append(
            _auto_get_scale(
                prev_op=module.final_layer_norm,
                layers=[module.fc1],
                inp=input_feat["fc1"],
            )
        )
        # fc2
        scales_list.append(
            _auto_get_scale(
                prev_op=module.fc1,
                layers=[module.fc2],
                inp=input_feat["fc2"],
            )
        )

    elif isinstance(module, LlamaDecoderLayer):
        # attention input
        scales_list.append(
            _auto_get_scale(
                prev_op=module.input_layernorm,
                layers=[
                    module.self_attn.q_proj,
                    module.self_attn.k_proj,
                    module.self_attn.v_proj,
                ],
                inp=input_feat["self_attn.q_proj"],
                module2inspect=module.self_attn,
                kwargs=module_kwargs,
            )
        )
        # attn out: only when v_proj and o_proj weights have the same shape.
        # Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696
        if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape:
            scales_list.append(
                _auto_get_scale(
                    prev_op=module.self_attn.v_proj,
                    layers=[module.self_attn.o_proj],
                    inp=input_feat["self_attn.o_proj"],
                )
            )
        # fc1
        scales_list.append(
            _auto_get_scale(
                prev_op=module.post_attention_layernorm,
                layers=[module.mlp.gate_proj, module.mlp.up_proj],
                inp=input_feat["mlp.gate_proj"],
                module2inspect=module.mlp,
            )
        )
        # fc2
        scales_list.append(
            _auto_get_scale(
                prev_op=module.mlp.up_proj,
                layers=[module.mlp.down_proj],
                inp=input_feat["mlp.down_proj"],
            )
        )

    elif isinstance(module, BloomBlock):
        # attention input
        scales_list.append(
            _auto_get_scale(
                prev_op=module.input_layernorm,
                layers=[module.self_attention.query_key_value],
                inp=input_feat["self_attention.query_key_value"],
                module2inspect=module,
                kwargs=module_kwargs,
            )
        )
        # attn out: intentionally disabled (query_key_value -> dense).
        # Please refer to https://github.com/mit-han-lab/llm-awq/issues/2#issuecomment-1606297469
        # fc1
        scales_list.append(
            _auto_get_scale(
                prev_op=module.post_attention_layernorm,
                layers=[module.mlp.dense_h_to_4h],
                inp=input_feat["mlp.dense_h_to_4h"],
                module2inspect=module,
                kwargs=module_kwargs,
            )
        )
        # fc2
        scales_list.append(
            _auto_get_scale(
                prev_op=module.mlp.gelu_impl,
                layers=[module.mlp.dense_4h_to_h],
                inp=input_feat["mlp.dense_4h_to_h"],
            )
        )

    elif "mpt" in cls_name:
        # attention input
        scales_list.append(
            _auto_get_scale(
                prev_op=module.norm_1,
                layers=[module.attn.Wqkv],
                inp=input_feat["attn.Wqkv"],
                module2inspect=module.attn,
                kwargs=module_kwargs,
            )
        )
        # attn out
        scales_list.append(
            _auto_get_scale(
                prev_op=module.attn.Wqkv,
                layers=[module.attn.out_proj],
                inp=input_feat["attn.out_proj"],
            )
        )
        # fc1
        scales_list.append(
            _auto_get_scale(
                prev_op=module.norm_2,
                layers=[module.ffn.up_proj],
                inp=input_feat["ffn.up_proj"],
                module2inspect=module.ffn,
            )
        )
        # fc2
        scales_list.append(
            _auto_get_scale(
                prev_op=module.ffn.act,
                layers=[module.ffn.down_proj],
                inp=input_feat["ffn.down_proj"],
            )
        )

    elif "falcon" in cls_name:
        # attn out: intentionally disabled (query_key_value -> dense).
        # Haotian: TBD: need to handle repeated scales for MQ
        # fc1, as long as it is scaled, everything is screwed up
        if "falcon-7b" in cls_name:
            scales_list.append(
                _auto_get_scale(
                    prev_op=module.input_layernorm,
                    layers=[
                        module.mlp.dense_h_to_4h,
                        module.self_attention.query_key_value,
                    ],
                    inp=input_feat["self_attention.query_key_value"],
                    module2inspect=module,
                    kwargs=module_kwargs,
                )
            )
        elif "falcon-40b" in cls_name:
            scales_list.append(
                _auto_get_scale(
                    prev_op=module.ln_attn,
                    layers=[module.self_attention.query_key_value],
                    inp=input_feat["self_attention.query_key_value"],
                    module2inspect=module,
                    kwargs=module_kwargs,
                )
            )
            scales_list.append(
                _auto_get_scale(
                    prev_op=module.ln_mlp,
                    layers=[module.mlp.dense_h_to_4h],
                    inp=input_feat["mlp.dense_h_to_4h"],
                    module2inspect=module,
                    kwargs=module_kwargs,
                )
            )
        else:
            raise NotImplementedError(
                "Unknown Falcon architecture, currently only falcon-7b and falcon-40b are supported"
            )
        # fc2
        scales_list.append(
            _auto_get_scale(
                prev_op=module.mlp.act,
                layers=[module.mlp.dense_4h_to_h],
                inp=input_feat["mlp.dense_4h_to_h"],
            )
        )

    elif "bigcode" in cls_name:
        # attention input
        scales_list.append(
            _auto_get_scale(
                prev_op=module.ln_1,
                layers=[module.attn.c_attn],
                inp=input_feat["attn.c_attn"],
                module2inspect=module.attn,
                kwargs=module_kwargs,
            )
        )
        # fc1
        scales_list.append(
            _auto_get_scale(
                prev_op=module.ln_2,
                layers=[module.mlp.c_fc],
                inp=input_feat["mlp.c_fc"],
                module2inspect=module.mlp,
            )
        )
        # fc2
        scales_list.append(
            _auto_get_scale(
                prev_op=module.mlp.act,
                layers=[module.mlp.c_proj],
                inp=input_feat["mlp.c_proj"],
            )
        )

    elif "neox" in cls_name:
        # attention input
        scales_list.append(
            _auto_get_scale(
                prev_op=module.input_layernorm,
                layers=[module.attention.query_key_value],
                inp=input_feat["attention.query_key_value"],
                module2inspect=module.attention,
                kwargs=module_kwargs,
            )
        )
        # fc1
        scales_list.append(
            _auto_get_scale(
                prev_op=module.post_attention_layernorm,
                layers=[module.mlp.dense_h_to_4h],
                inp=input_feat["mlp.dense_h_to_4h"],
                module2inspect=module.mlp,
            )
        )
        # fc2
        scales_list.append(
            _auto_get_scale(
                prev_op=module.mlp.act,
                layers=[module.mlp.dense_4h_to_h],
                inp=input_feat["mlp.dense_4h_to_h"],
            )
        )

    else:
        raise NotImplementedError(f"{type(module)} not supported yet!")

    return scales_list
def apply_scale(module, scales_list, input_feat_dict=None):
    """Fold searched scales into ``module``'s weights in place.

    For each entry, the scale is moved out of the preceding op and into the
    listed layers:
    - a Linear ``prev_op`` uses ``scale_fc_fc``;
    - a LayerNorm/RMSNorm ``prev_op`` uses ``scale_ln_fcs``;
    - a GELU ``prev_op`` is replaced by ``ScaledActivation`` and then handled
      by ``scale_gelu_fc``.

    The ops are moved to CUDA for the update and to the CPU afterwards.

    Args:
        module: The block the scales were searched on.
        scales_list: ``(prev_op_name, layer_names, scales)`` tuples, as
            returned by ``auto_scale_block``.
        input_feat_dict: Optional mapping from layer name to cached input
            features. When given, each listed layer's features are divided
            in place by ``scales`` to prepare them for clipping.

    Raises:
        NotImplementedError: if ``prev_op`` is of an unsupported type.
        AssertionError: if a Linear or GELU ``prev_op`` has more than one
            following layer.
    """
    for prev_op_name, layer_names, scales in scales_list:
        prev_op = get_op_by_name(module, prev_op_name)
        layers = [get_op_by_name(module, name) for name in layer_names]

        # Modules are moved in place. ``scales`` is not: ``Tensor.cuda()``
        # returns a copy, and the scale_* helpers already move ``scales`` to
        # the right device themselves.
        prev_op.cuda()
        for layer in layers:
            layer.cuda()

        if isinstance(prev_op, nn.Linear):
            assert len(layers) == 1
            scale_fc_fc(prev_op, layers[0], scales)
        elif isinstance(prev_op, (nn.LayerNorm, LlamaRMSNorm)):
            scale_ln_fcs(prev_op, layers, scales)
        elif isinstance(prev_op, (nn.GELU, BloomGelu, GELUActivation)):
            # Only layers[0] would be rescaled, so reject multi-layer groups
            # instead of silently leaving the others unscaled.
            assert len(layers) == 1
            new_module = ScaledActivation(prev_op, scales)
            set_op_by_name(module, prev_op_name, new_module)
            scale_gelu_fc(prev_op, layers[0], scales)
        else:
            raise NotImplementedError(f"prev_op {type(prev_op)} not supported yet!")

        # apply the scaling to input feat if given; prepare it for clipping
        if input_feat_dict is not None:
            for layer_name in layer_names:
                inp = input_feat_dict[layer_name]
                inp.div_(scales.view(1, -1).to(inp.device))

        prev_op.cpu()
        for layer in layers:
            layer.cpu()