Skip to content

Commit 31ff4d9

Browse files
authored
bayesian optimization tool for mixed precision quantization (#694)
* bayesian optimization tool for mixed precision quantization * refactor * code refactor * integer type parameters * improve multi-process * fix a bug in symmetric quant * refactor BO optimize for model accuracy * add BO for inference speed optimization * rename BO for inference speed * refactor code * add utils * add some TODOs * renamed BO scripts * renamed to BO_acc_throughput * add TODOs
1 parent 86c7b0d commit 31ff4d9

File tree

8 files changed

+950
-5
lines changed

8 files changed

+950
-5
lines changed
Lines changed: 348 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,348 @@
1+
import sys
2+
3+
import torch
4+
import torch.nn as nn
5+
from torchao.quantization import quantize_
6+
import random
7+
8+
from naive_intNwo import intN_weight_only
9+
10+
import copy
11+
from lm_eval.evaluator import evaluate
12+
from lm_eval.models.huggingface import HFLM
13+
from lm_eval.tasks import get_task_dict
14+
15+
from transformers import AutoModelForCausalLM, AutoTokenizer
16+
from ax.service.ax_client import AxClient, ObjectiveProperties
17+
import torch.multiprocessing as mp
18+
from ax.modelbridge.cross_validation import cross_validate
19+
from utils import write_history_to_csv, cal_wikitext_ppl, cal_model_size, load_model, quantize_by_fqn_to_config
20+
21+
# return evaluation results to complete BO trials
22+
def eval(model, tokenizer, num_PPL_eval_samples, fqn_to_config):
23+
return {
24+
"cal_PPL": (cal_wikitext_ppl(model, tokenizer, num_PPL_eval_samples), 0.0),
25+
"model_size": (cal_model_size(model, fqn_to_config), 0.0),
26+
}
27+
28+
# TODO: make it into a yaml or json file to enable users specify their custom model formats
29+
def define_parameter_list():
30+
31+
# define the search space for all layers
32+
parameters_list = []
33+
34+
for i in range(0, 3):
35+
parameters_list.append(
36+
{
37+
"name": f"bitwidth.{i}.",
38+
"type": "fixed",
39+
"value_type": "int",
40+
"value": 5,
41+
"is_ordered": True,
42+
"sort_values": True,
43+
}
44+
)
45+
46+
parameters_list.append(
47+
{
48+
"name": f"groupsize.{i}.",
49+
"type": "fixed",
50+
"value_type": "int",
51+
"value": 32,
52+
"is_ordered": True,
53+
"sort_values": True,
54+
}
55+
)
56+
57+
for i in range(3, 30):
58+
parameters_list.append(
59+
{
60+
"name": f"bitwidth.{i}.",
61+
"type": "choice",
62+
"value_type": "int",
63+
"values": [2,3,4,5,6,8],
64+
"is_ordered": True,
65+
"sort_values": True,
66+
}
67+
)
68+
69+
parameters_list.append(
70+
{
71+
"name": f"groupsize.{i}.",
72+
"type": "choice",
73+
"value_type": "int",
74+
"values": [32, 64, 128, 256],
75+
"is_ordered": True,
76+
"sort_values": True,
77+
}
78+
)
79+
80+
for i in range(30, 32):
81+
parameters_list.append(
82+
{
83+
"name": f"bitwidth.{i}.",
84+
"type": "fixed",
85+
"value_type": "int",
86+
"value": 5,
87+
"is_ordered": True,
88+
"sort_values": True,
89+
}
90+
)
91+
parameters_list.append(
92+
{
93+
"name": f"groupsize.{i}.",
94+
"type": "fixed",
95+
"value_type": "int",
96+
"value": 32,
97+
"is_ordered": True,
98+
"sort_values": True,
99+
}
100+
)
101+
102+
return parameters_list
103+
104+
# add initial search points based on the sensitivity score
105+
# TODO: automate the initial samples by better leverage the sensitivity scores
106+
def get_initial_samples(num_BO_initial_samples=50):
107+
initial_points_set = []
108+
109+
# auto sample the bit choices with random choice probability positive correlated to FIT score
110+
for _ in range(num_BO_initial_samples):
111+
initial_points = {}
112+
for i in range(0, 3):
113+
initial_points["bitwidth." + str(i) + "."] = 5
114+
initial_points["groupsize." + str(i) + "."] = 32
115+
116+
for i in range(3, 18):
117+
if i in [5,6,7,10,11,12,16]:
118+
initial_points["bitwidth." + str(i) + "."] = random.choices([5, 4], [20, 80])[0]
119+
initial_points["groupsize." + str(i) + "."] = random.choices([32, 64], [30, 70])[0]
120+
else:
121+
initial_points["bitwidth." + str(i) + "."] = random.choices([5, 4], [30, 70])[0]
122+
initial_points["groupsize." + str(i) + "."] = random.choices([32, 64], [40, 60])[0]
123+
124+
for i in range(18, 30):
125+
if i in [22,23,24]:
126+
initial_points["bitwidth." + str(i) + "."] = random.choices([5, 4, 3, 2], [20, 55, 20, 5])[0]
127+
initial_points["groupsize." + str(i) + "."] = random.choices([32, 64, 128, 256], [30, 40, 25, 5])[0]
128+
else:
129+
initial_points["bitwidth." + str(i) + "."] = random.choices([5, 4, 3, 2], [30, 55, 10, 5])[0]
130+
initial_points["groupsize." + str(i) + "."] = random.choices([32, 64, 128, 256], [40, 40, 15, 5])[0]
131+
132+
for i in range(30, 32):
133+
initial_points["bitwidth." + str(i) + "."] = 5
134+
initial_points["groupsize." + str(i) + "."] = 32
135+
136+
initial_points_set.append(initial_points)
137+
return initial_points_set
138+
139+
'''
140+
This function will run BO trials sequentially on a single GPU.
141+
Each time the BO gets one new trial, evaluates the trial on the GPU and return the evaluation results to update the BO.
142+
One trial, one BO update.
143+
TODO: refactor the sequential BO and parallel BO into a single function
144+
'''
145+
def run_sequential_BO(device, checkpoint, num_PPL_eval_samples, num_BO_initial_samples, num_trials, model_size_constraint, output_file):
146+
147+
parameters_list = define_parameter_list()
148+
initial_points_set = get_initial_samples(num_BO_initial_samples)
149+
150+
#initialize ax_client
151+
constraint="model_size <= "+str(model_size_constraint)
152+
ax_client = AxClient()
153+
ax_client.create_experiment(
154+
parameters = parameters_list,
155+
name = "test_quantize_BO",
156+
objectives = {"cal_PPL": ObjectiveProperties(minimize=True)},
157+
choose_generation_strategy_kwargs = {
158+
"num_initialization_trials": num_BO_initial_samples, # the number of trials to build generation strategy
159+
},
160+
outcome_constraints = [constraint],
161+
)
162+
163+
history=[]
164+
trial_id = 0
165+
166+
# add initial points into the BO trials
167+
for i in range(num_BO_initial_samples):
168+
169+
ax_client.attach_trial(parameters=initial_points_set[i])
170+
171+
m, tokenizer = load_model(checkpoint, device)
172+
quantize_by_fqn_to_config(m, device, initial_points_set[i])
173+
174+
eval_results = eval(m, tokenizer, num_PPL_eval_samples, initial_points_set[i])
175+
176+
print("------------")
177+
print(trial_id, initial_points_set[i], eval_results)
178+
179+
history.append((eval_results, initial_points_set[i]))
180+
ax_client.complete_trial(
181+
trial_index = trial_id,
182+
raw_data = eval_results,
183+
)
184+
trial_id += 1
185+
del m
186+
torch.cuda.empty_cache()
187+
188+
189+
# run new BO trials
190+
for k_ in range(num_trials):
191+
parameters, trial_idx = ax_client.get_next_trial()
192+
193+
m, tokenizer = load_model(checkpoint, device)
194+
195+
quantize_by_fqn_to_config(m, device, parameters)
196+
197+
eval_results = eval(m, tokenizer, num_PPL_eval_samples, parameters)
198+
199+
print("------------")
200+
print(trial_idx, parameters, eval_results)
201+
history.append((eval_results, parameters))
202+
203+
ax_client.complete_trial(
204+
trial_index=trial_idx,
205+
raw_data=eval_results,
206+
)
207+
208+
del m
209+
torch.cuda.empty_cache()
210+
211+
212+
print("------Finish BO------")
213+
for h in history:
214+
print(h)
215+
write_history_to_csv(history, output_file, ["cal_PPL", "model_size", "quant_config"])
216+
217+
print("------Best config------")
218+
best_parameters, values = ax_client.get_best_parameters()
219+
print(best_parameters, values)
220+
221+
# Worker function to perform BO trials on a specific GPU
222+
def eval_in_parallel(gpu_id, checkpoint, num_PPL_eval_samples, config, return_dict, proc_id, trial_id):
223+
224+
model, tokenizer = load_model(checkpoint, f'cuda:{gpu_id}')
225+
parameters_list = define_parameter_list()
226+
227+
print(f"Process {proc_id} on GPU {gpu_id} starts!")
228+
229+
quantize_by_fqn_to_config(model=model, device=f'cuda:{gpu_id}', fqn_to_config=dict(config))
230+
231+
eval_results = eval(model, tokenizer, num_PPL_eval_samples, config)
232+
233+
return_dict[proc_id] = (trial_id, config, eval_results)
234+
235+
del model
236+
torch.cuda.empty_cache()
237+
238+
'''
239+
This function will run BO trials in parallel on multiple GPUs.
240+
Each time the BO gets multiple new trials, evaluates the trials on the GPUs and return the evaluation results to update the BO.
241+
Multiple trials, one BO update.
242+
'''
243+
def run_parallel_BO(device, checkpoint, num_PPL_eval_samples, num_BO_initial_samples, num_trials, model_size_constraint, gpu_list):
244+
245+
parameters_list = define_parameter_list()
246+
initial_points_set = get_initial_samples(num_BO_initial_samples)
247+
248+
#initialize ax_client
249+
constraint="model_size <= "+str(model_size_constraint)
250+
ax_client = AxClient()
251+
ax_client.create_experiment(
252+
parameters = parameters_list,
253+
name = "test_quantize_BO",
254+
objectives = {"cal_PPL": ObjectiveProperties(minimize=True)},
255+
choose_generation_strategy_kwargs = {
256+
"num_initialization_trials": num_BO_initial_samples, # the number of trials to build generation strategy
257+
},
258+
outcome_constraints=[constraint],
259+
)
260+
261+
gpu_list = [int(i) for i in gpu_list.split(",")]
262+
263+
history=[]
264+
trial_id = 0
265+
266+
# Set the multiprocessing start method to 'spawn'
267+
mp.set_start_method("spawn", force=True)
268+
269+
# add initial points into the BO trials
270+
for id in range(num_BO_initial_samples//len(gpu_list)):
271+
processes = []
272+
manager = mp.Manager()
273+
return_dict = manager.dict()
274+
275+
# Start the worker processes
276+
for i, gpu_id in enumerate(gpu_list):
277+
ax_client.attach_trial(parameters=dict(initial_points_set[id*len(gpu_list)+i]))
278+
p = mp.Process(target=eval_in_parallel, args=(gpu_id, checkpoint, num_PPL_eval_samples, initial_points_set[id*len(gpu_list)+i], return_dict, i, trial_id))
279+
trial_id += 1
280+
p.start()
281+
processes.append(p)
282+
283+
# Wait for all processes to finish
284+
for p in processes:
285+
p.join()
286+
287+
# Print the results after all processes have finished
288+
print(return_dict)
289+
for i in range(len(gpu_list)):
290+
current_trial_id, config, eval_results = return_dict[i]
291+
history.append((eval_results, config))
292+
ax_client.complete_trial(trial_index = current_trial_id, raw_data = eval_results,)
293+
294+
# run new BO trials
295+
for id in range(num_trials//len(gpu_list)):
296+
processes = []
297+
manager = mp.Manager()
298+
return_dict = manager.dict()
299+
300+
# Start the worker processes
301+
for i, gpu_id in enumerate(gpu_list):
302+
parameters, trial_idx = ax_client.get_next_trial()
303+
parameter_tuple = []
304+
for k, v in parameters.items():
305+
parameter_tuple.append((k, v))
306+
p = mp.Process(target = eval_in_parallel, args = (gpu_id, checkpoint, num_PPL_eval_samples, parameter_tuple, return_dict, i, trial_idx))
307+
p.start()
308+
processes.append(p)
309+
310+
# Wait for all processes to finish
311+
for p in processes:
312+
p.join()
313+
314+
# Print the results after all processes have finished
315+
print(return_dict)
316+
for i in range(len(gpu_list)):
317+
current_trial_id, config, eval_results = return_dict[i]
318+
history.append((eval_results, config))
319+
ax_client.complete_trial(trial_index = current_trial_id, raw_data = eval_results,)
320+
321+
print("------Finish BO------")
322+
for h in history:
323+
print(h)
324+
write_history_to_csv(history, output_file, ["cal_PPL", "model_size", "quant_config"])
325+
326+
print("------Best config------")
327+
best_parameters, values = ax_client.get_best_parameters()
328+
print(best_parameters, values)
329+
330+
331+
if __name__ == '__main__':
332+
333+
import argparse
334+
parser = argparse.ArgumentParser(description='Bayesian optimization for mixed-precision quantization to optimize accuracy under model size constraint.')
335+
parser.add_argument('--device', type=str, default="cuda", help='Device to use for evaluation')
336+
parser.add_argument('--checkpoint', type=str, default="/tmp/Meta-Llama-3-8B", help='Path to load model')
337+
parser.add_argument('--num_PPL_eval_samples', type=int, default=None, help='Number of samples to evaluate ppl')
338+
parser.add_argument('--num_BO_initial_samples', type=int, default=50, help='Number of initial points sampled by sensitivity scores')
339+
parser.add_argument('--num_trials', type=int, default=150, help='Number of trials to run BO')
340+
parser.add_argument('--model_size_constraint', type=float, default=6.0, help='The model size (GB) constraint for BO')
341+
parser.add_argument('--gpu_list', type=str, default="", help="A list of gpus to run evaluation, separated by comma, e.g., --gpu_lists=0,1,2,3")
342+
parser.add_argument('--output_path', type=str, default="BO_acc_modelsize_output.csv", help="The file path to save the BO search trials")
343+
args = parser.parse_args()
344+
345+
if args.gpu_list != "":
346+
run_sequential_BO(device=args.device, checkpoint=args.checkpoint, num_PPL_eval_samples=args.num_PPL_eval_samples, num_BO_initial_samples=args.num_BO_initial_samples, num_trials=args.num_trials, model_size_constraint=args.model_size_constraint, output_path=args.output_path)
347+
else:
348+
run_parallel_BO(device=args.device, checkpoint=args.checkpoint, num_PPL_eval_samples=args.num_PPL_eval_samples, num_BO_initial_samples=args.num_BO_initial_samples, num_trials=args.num_trials, model_size_constraint=args.model_size_constraint, gpu_list=args.gpu_list, output_path=args.output_path)

0 commit comments

Comments
 (0)