|
| 1 | +import sys |
| 2 | + |
| 3 | +import torch |
| 4 | +import torch.nn as nn |
| 5 | +from torchao.quantization import quantize_ |
| 6 | +import random |
| 7 | + |
| 8 | +from naive_intNwo import intN_weight_only |
| 9 | + |
| 10 | +import copy |
| 11 | +from lm_eval.evaluator import evaluate |
| 12 | +from lm_eval.models.huggingface import HFLM |
| 13 | +from lm_eval.tasks import get_task_dict |
| 14 | + |
| 15 | +from transformers import AutoModelForCausalLM, AutoTokenizer |
| 16 | +from ax.service.ax_client import AxClient, ObjectiveProperties |
| 17 | +import torch.multiprocessing as mp |
| 18 | +from ax.modelbridge.cross_validation import cross_validate |
| 19 | +from utils import write_history_to_csv, cal_wikitext_ppl, cal_model_size, load_model, quantize_by_fqn_to_config |
| 20 | + |
| 21 | +# return evaluation results to complete BO trials |
| 22 | +def eval(model, tokenizer, num_PPL_eval_samples, fqn_to_config): |
| 23 | + return { |
| 24 | + "cal_PPL": (cal_wikitext_ppl(model, tokenizer, num_PPL_eval_samples), 0.0), |
| 25 | + "model_size": (cal_model_size(model, fqn_to_config), 0.0), |
| 26 | + } |
| 27 | + |
| 28 | +# TODO: make it into a yaml or json file to enable users specify their custom model formats |
| 29 | +def define_parameter_list(): |
| 30 | + |
| 31 | + # define the search space for all layers |
| 32 | + parameters_list = [] |
| 33 | + |
| 34 | + for i in range(0, 3): |
| 35 | + parameters_list.append( |
| 36 | + { |
| 37 | + "name": f"bitwidth.{i}.", |
| 38 | + "type": "fixed", |
| 39 | + "value_type": "int", |
| 40 | + "value": 5, |
| 41 | + "is_ordered": True, |
| 42 | + "sort_values": True, |
| 43 | + } |
| 44 | + ) |
| 45 | + |
| 46 | + parameters_list.append( |
| 47 | + { |
| 48 | + "name": f"groupsize.{i}.", |
| 49 | + "type": "fixed", |
| 50 | + "value_type": "int", |
| 51 | + "value": 32, |
| 52 | + "is_ordered": True, |
| 53 | + "sort_values": True, |
| 54 | + } |
| 55 | + ) |
| 56 | + |
| 57 | + for i in range(3, 30): |
| 58 | + parameters_list.append( |
| 59 | + { |
| 60 | + "name": f"bitwidth.{i}.", |
| 61 | + "type": "choice", |
| 62 | + "value_type": "int", |
| 63 | + "values": [2,3,4,5,6,8], |
| 64 | + "is_ordered": True, |
| 65 | + "sort_values": True, |
| 66 | + } |
| 67 | + ) |
| 68 | + |
| 69 | + parameters_list.append( |
| 70 | + { |
| 71 | + "name": f"groupsize.{i}.", |
| 72 | + "type": "choice", |
| 73 | + "value_type": "int", |
| 74 | + "values": [32, 64, 128, 256], |
| 75 | + "is_ordered": True, |
| 76 | + "sort_values": True, |
| 77 | + } |
| 78 | + ) |
| 79 | + |
| 80 | + for i in range(30, 32): |
| 81 | + parameters_list.append( |
| 82 | + { |
| 83 | + "name": f"bitwidth.{i}.", |
| 84 | + "type": "fixed", |
| 85 | + "value_type": "int", |
| 86 | + "value": 5, |
| 87 | + "is_ordered": True, |
| 88 | + "sort_values": True, |
| 89 | + } |
| 90 | + ) |
| 91 | + parameters_list.append( |
| 92 | + { |
| 93 | + "name": f"groupsize.{i}.", |
| 94 | + "type": "fixed", |
| 95 | + "value_type": "int", |
| 96 | + "value": 32, |
| 97 | + "is_ordered": True, |
| 98 | + "sort_values": True, |
| 99 | + } |
| 100 | + ) |
| 101 | + |
| 102 | + return parameters_list |
| 103 | + |
| 104 | +# add initial search points based on the sensitivity score |
| 105 | +# TODO: automate the initial samples by better leverage the sensitivity scores |
| 106 | +def get_initial_samples(num_BO_initial_samples=50): |
| 107 | + initial_points_set = [] |
| 108 | + |
| 109 | + # auto sample the bit choices with random choice probability positive correlated to FIT score |
| 110 | + for _ in range(num_BO_initial_samples): |
| 111 | + initial_points = {} |
| 112 | + for i in range(0, 3): |
| 113 | + initial_points["bitwidth." + str(i) + "."] = 5 |
| 114 | + initial_points["groupsize." + str(i) + "."] = 32 |
| 115 | + |
| 116 | + for i in range(3, 18): |
| 117 | + if i in [5,6,7,10,11,12,16]: |
| 118 | + initial_points["bitwidth." + str(i) + "."] = random.choices([5, 4], [20, 80])[0] |
| 119 | + initial_points["groupsize." + str(i) + "."] = random.choices([32, 64], [30, 70])[0] |
| 120 | + else: |
| 121 | + initial_points["bitwidth." + str(i) + "."] = random.choices([5, 4], [30, 70])[0] |
| 122 | + initial_points["groupsize." + str(i) + "."] = random.choices([32, 64], [40, 60])[0] |
| 123 | + |
| 124 | + for i in range(18, 30): |
| 125 | + if i in [22,23,24]: |
| 126 | + initial_points["bitwidth." + str(i) + "."] = random.choices([5, 4, 3, 2], [20, 55, 20, 5])[0] |
| 127 | + initial_points["groupsize." + str(i) + "."] = random.choices([32, 64, 128, 256], [30, 40, 25, 5])[0] |
| 128 | + else: |
| 129 | + initial_points["bitwidth." + str(i) + "."] = random.choices([5, 4, 3, 2], [30, 55, 10, 5])[0] |
| 130 | + initial_points["groupsize." + str(i) + "."] = random.choices([32, 64, 128, 256], [40, 40, 15, 5])[0] |
| 131 | + |
| 132 | + for i in range(30, 32): |
| 133 | + initial_points["bitwidth." + str(i) + "."] = 5 |
| 134 | + initial_points["groupsize." + str(i) + "."] = 32 |
| 135 | + |
| 136 | + initial_points_set.append(initial_points) |
| 137 | + return initial_points_set |
| 138 | + |
| 139 | +''' |
| 140 | +This function will run BO trials sequentially on a single GPU. |
| 141 | +Each time the BO gets one new trial, evaluates the trial on the GPU and return the evaluation results to update the BO. |
| 142 | +One trial, one BO update. |
| 143 | +TODO: refactor the sequential BO and parallel BO into a single function |
| 144 | +''' |
| 145 | +def run_sequential_BO(device, checkpoint, num_PPL_eval_samples, num_BO_initial_samples, num_trials, model_size_constraint, output_file): |
| 146 | + |
| 147 | + parameters_list = define_parameter_list() |
| 148 | + initial_points_set = get_initial_samples(num_BO_initial_samples) |
| 149 | + |
| 150 | + #initialize ax_client |
| 151 | + constraint="model_size <= "+str(model_size_constraint) |
| 152 | + ax_client = AxClient() |
| 153 | + ax_client.create_experiment( |
| 154 | + parameters = parameters_list, |
| 155 | + name = "test_quantize_BO", |
| 156 | + objectives = {"cal_PPL": ObjectiveProperties(minimize=True)}, |
| 157 | + choose_generation_strategy_kwargs = { |
| 158 | + "num_initialization_trials": num_BO_initial_samples, # the number of trials to build generation strategy |
| 159 | + }, |
| 160 | + outcome_constraints = [constraint], |
| 161 | + ) |
| 162 | + |
| 163 | + history=[] |
| 164 | + trial_id = 0 |
| 165 | + |
| 166 | + # add initial points into the BO trials |
| 167 | + for i in range(num_BO_initial_samples): |
| 168 | + |
| 169 | + ax_client.attach_trial(parameters=initial_points_set[i]) |
| 170 | + |
| 171 | + m, tokenizer = load_model(checkpoint, device) |
| 172 | + quantize_by_fqn_to_config(m, device, initial_points_set[i]) |
| 173 | + |
| 174 | + eval_results = eval(m, tokenizer, num_PPL_eval_samples, initial_points_set[i]) |
| 175 | + |
| 176 | + print("------------") |
| 177 | + print(trial_id, initial_points_set[i], eval_results) |
| 178 | + |
| 179 | + history.append((eval_results, initial_points_set[i])) |
| 180 | + ax_client.complete_trial( |
| 181 | + trial_index = trial_id, |
| 182 | + raw_data = eval_results, |
| 183 | + ) |
| 184 | + trial_id += 1 |
| 185 | + del m |
| 186 | + torch.cuda.empty_cache() |
| 187 | + |
| 188 | + |
| 189 | + # run new BO trials |
| 190 | + for k_ in range(num_trials): |
| 191 | + parameters, trial_idx = ax_client.get_next_trial() |
| 192 | + |
| 193 | + m, tokenizer = load_model(checkpoint, device) |
| 194 | + |
| 195 | + quantize_by_fqn_to_config(m, device, parameters) |
| 196 | + |
| 197 | + eval_results = eval(m, tokenizer, num_PPL_eval_samples, parameters) |
| 198 | + |
| 199 | + print("------------") |
| 200 | + print(trial_idx, parameters, eval_results) |
| 201 | + history.append((eval_results, parameters)) |
| 202 | + |
| 203 | + ax_client.complete_trial( |
| 204 | + trial_index=trial_idx, |
| 205 | + raw_data=eval_results, |
| 206 | + ) |
| 207 | + |
| 208 | + del m |
| 209 | + torch.cuda.empty_cache() |
| 210 | + |
| 211 | + |
| 212 | + print("------Finish BO------") |
| 213 | + for h in history: |
| 214 | + print(h) |
| 215 | + write_history_to_csv(history, output_file, ["cal_PPL", "model_size", "quant_config"]) |
| 216 | + |
| 217 | + print("------Best config------") |
| 218 | + best_parameters, values = ax_client.get_best_parameters() |
| 219 | + print(best_parameters, values) |
| 220 | + |
| 221 | +# Worker function to perform BO trials on a specific GPU |
| 222 | +def eval_in_parallel(gpu_id, checkpoint, num_PPL_eval_samples, config, return_dict, proc_id, trial_id): |
| 223 | + |
| 224 | + model, tokenizer = load_model(checkpoint, f'cuda:{gpu_id}') |
| 225 | + parameters_list = define_parameter_list() |
| 226 | + |
| 227 | + print(f"Process {proc_id} on GPU {gpu_id} starts!") |
| 228 | + |
| 229 | + quantize_by_fqn_to_config(model=model, device=f'cuda:{gpu_id}', fqn_to_config=dict(config)) |
| 230 | + |
| 231 | + eval_results = eval(model, tokenizer, num_PPL_eval_samples, config) |
| 232 | + |
| 233 | + return_dict[proc_id] = (trial_id, config, eval_results) |
| 234 | + |
| 235 | + del model |
| 236 | + torch.cuda.empty_cache() |
| 237 | + |
| 238 | +''' |
| 239 | +This function will run BO trials in parallel on multiple GPUs. |
| 240 | +Each time the BO gets multiple new trials, evaluates the trials on the GPUs and return the evaluation results to update the BO. |
| 241 | +Multiple trials, one BO update. |
| 242 | +''' |
| 243 | +def run_parallel_BO(device, checkpoint, num_PPL_eval_samples, num_BO_initial_samples, num_trials, model_size_constraint, gpu_list): |
| 244 | + |
| 245 | + parameters_list = define_parameter_list() |
| 246 | + initial_points_set = get_initial_samples(num_BO_initial_samples) |
| 247 | + |
| 248 | + #initialize ax_client |
| 249 | + constraint="model_size <= "+str(model_size_constraint) |
| 250 | + ax_client = AxClient() |
| 251 | + ax_client.create_experiment( |
| 252 | + parameters = parameters_list, |
| 253 | + name = "test_quantize_BO", |
| 254 | + objectives = {"cal_PPL": ObjectiveProperties(minimize=True)}, |
| 255 | + choose_generation_strategy_kwargs = { |
| 256 | + "num_initialization_trials": num_BO_initial_samples, # the number of trials to build generation strategy |
| 257 | + }, |
| 258 | + outcome_constraints=[constraint], |
| 259 | + ) |
| 260 | + |
| 261 | + gpu_list = [int(i) for i in gpu_list.split(",")] |
| 262 | + |
| 263 | + history=[] |
| 264 | + trial_id = 0 |
| 265 | + |
| 266 | + # Set the multiprocessing start method to 'spawn' |
| 267 | + mp.set_start_method("spawn", force=True) |
| 268 | + |
| 269 | + # add initial points into the BO trials |
| 270 | + for id in range(num_BO_initial_samples//len(gpu_list)): |
| 271 | + processes = [] |
| 272 | + manager = mp.Manager() |
| 273 | + return_dict = manager.dict() |
| 274 | + |
| 275 | + # Start the worker processes |
| 276 | + for i, gpu_id in enumerate(gpu_list): |
| 277 | + ax_client.attach_trial(parameters=dict(initial_points_set[id*len(gpu_list)+i])) |
| 278 | + p = mp.Process(target=eval_in_parallel, args=(gpu_id, checkpoint, num_PPL_eval_samples, initial_points_set[id*len(gpu_list)+i], return_dict, i, trial_id)) |
| 279 | + trial_id += 1 |
| 280 | + p.start() |
| 281 | + processes.append(p) |
| 282 | + |
| 283 | + # Wait for all processes to finish |
| 284 | + for p in processes: |
| 285 | + p.join() |
| 286 | + |
| 287 | + # Print the results after all processes have finished |
| 288 | + print(return_dict) |
| 289 | + for i in range(len(gpu_list)): |
| 290 | + current_trial_id, config, eval_results = return_dict[i] |
| 291 | + history.append((eval_results, config)) |
| 292 | + ax_client.complete_trial(trial_index = current_trial_id, raw_data = eval_results,) |
| 293 | + |
| 294 | + # run new BO trials |
| 295 | + for id in range(num_trials//len(gpu_list)): |
| 296 | + processes = [] |
| 297 | + manager = mp.Manager() |
| 298 | + return_dict = manager.dict() |
| 299 | + |
| 300 | + # Start the worker processes |
| 301 | + for i, gpu_id in enumerate(gpu_list): |
| 302 | + parameters, trial_idx = ax_client.get_next_trial() |
| 303 | + parameter_tuple = [] |
| 304 | + for k, v in parameters.items(): |
| 305 | + parameter_tuple.append((k, v)) |
| 306 | + p = mp.Process(target = eval_in_parallel, args = (gpu_id, checkpoint, num_PPL_eval_samples, parameter_tuple, return_dict, i, trial_idx)) |
| 307 | + p.start() |
| 308 | + processes.append(p) |
| 309 | + |
| 310 | + # Wait for all processes to finish |
| 311 | + for p in processes: |
| 312 | + p.join() |
| 313 | + |
| 314 | + # Print the results after all processes have finished |
| 315 | + print(return_dict) |
| 316 | + for i in range(len(gpu_list)): |
| 317 | + current_trial_id, config, eval_results = return_dict[i] |
| 318 | + history.append((eval_results, config)) |
| 319 | + ax_client.complete_trial(trial_index = current_trial_id, raw_data = eval_results,) |
| 320 | + |
| 321 | + print("------Finish BO------") |
| 322 | + for h in history: |
| 323 | + print(h) |
| 324 | + write_history_to_csv(history, output_file, ["cal_PPL", "model_size", "quant_config"]) |
| 325 | + |
| 326 | + print("------Best config------") |
| 327 | + best_parameters, values = ax_client.get_best_parameters() |
| 328 | + print(best_parameters, values) |
| 329 | + |
| 330 | + |
| 331 | +if __name__ == '__main__': |
| 332 | + |
| 333 | + import argparse |
| 334 | + parser = argparse.ArgumentParser(description='Bayesian optimization for mixed-precision quantization to optimize accuracy under model size constraint.') |
| 335 | + parser.add_argument('--device', type=str, default="cuda", help='Device to use for evaluation') |
| 336 | + parser.add_argument('--checkpoint', type=str, default="/tmp/Meta-Llama-3-8B", help='Path to load model') |
| 337 | + parser.add_argument('--num_PPL_eval_samples', type=int, default=None, help='Number of samples to evaluate ppl') |
| 338 | + parser.add_argument('--num_BO_initial_samples', type=int, default=50, help='Number of initial points sampled by sensitivity scores') |
| 339 | + parser.add_argument('--num_trials', type=int, default=150, help='Number of trials to run BO') |
| 340 | + parser.add_argument('--model_size_constraint', type=float, default=6.0, help='The model size (GB) constraint for BO') |
| 341 | + parser.add_argument('--gpu_list', type=str, default="", help="A list of gpus to run evaluation, separated by comma, e.g., --gpu_lists=0,1,2,3") |
| 342 | + parser.add_argument('--output_path', type=str, default="BO_acc_modelsize_output.csv", help="The file path to save the BO search trials") |
| 343 | + args = parser.parse_args() |
| 344 | + |
| 345 | + if args.gpu_list != "": |
| 346 | + run_sequential_BO(device=args.device, checkpoint=args.checkpoint, num_PPL_eval_samples=args.num_PPL_eval_samples, num_BO_initial_samples=args.num_BO_initial_samples, num_trials=args.num_trials, model_size_constraint=args.model_size_constraint, output_path=args.output_path) |
| 347 | + else: |
| 348 | + run_parallel_BO(device=args.device, checkpoint=args.checkpoint, num_PPL_eval_samples=args.num_PPL_eval_samples, num_BO_initial_samples=args.num_BO_initial_samples, num_trials=args.num_trials, model_size_constraint=args.model_size_constraint, gpu_list=args.gpu_list, output_path=args.output_path) |
0 commit comments