From afdc0b9b6f601e056cc0a18eecd0debe5a8b7bab Mon Sep 17 00:00:00 2001
From: HDCharles
Date: Mon, 10 Jun 2024 19:26:10 -0700
Subject: [PATCH] Improving hf_eval.py

Summary: Made it so you can quantize on cpu rather than cuda. Added options to change batch_size and max_length, and added -q as a shorthand for --quantization.

Test Plan:
python hf_eval.py --limit 8 -q int8wo --batch_size 8 --max_length 20 --compile
python hf_eval.py --limit 8 -q int8wo --batch_size 8 --max_length 200 --compile

Reviewers:

Subscribers:

Tasks:

Tags:
---
 scripts/hf_eval.py | 24 +++++++++++++++---------
 1 file changed, 15 insertions(+), 9 deletions(-)

diff --git a/scripts/hf_eval.py b/scripts/hf_eval.py
index ab1a8adb17..5d2d1bb8f2 100644
--- a/scripts/hf_eval.py
+++ b/scripts/hf_eval.py
@@ -16,10 +16,10 @@
 torch._inductor.config.force_fuse_int_mm_with_mul = True
 torch._inductor.config.fx_graph_cache = True
 
-def run_evaluation(repo_id, task_list, limit, device, precision, quantization, compile):
+def run_evaluation(repo_id, task_list, limit, device, precision, quantization, compile, batch_size, max_length):
     tokenizer = AutoTokenizer.from_pretrained(repo_id)
 
-    model = AutoModelForCausalLM.from_pretrained(repo_id).to(device="cuda", dtype=precision)
+    model = AutoModelForCausalLM.from_pretrained(repo_id).to(device="cpu", dtype=precision)
 
     if compile:
         torch.compile(model, mode="max-autotune", fullgraph=True)
@@ -29,21 +29,25 @@ def run_evaluation(repo_id, task_list, limit, device, precision, quantization, c
     elif quantization == "int8wo":
         change_linear_weights_to_int8_woqtensors(model)
     elif quantization == "int4wo":
-        change_linear_weights_to_int4_woqtensors(model)
+        # note cannot quantize this model on cpu and run it on cuda at this time
+        change_linear_weights_to_int4_woqtensors(model.to(device=device))
     elif quantization == "autoquant":
-        model = autoquant(model)
+        model = autoquant(model.to(device=device))
 
     with torch.no_grad():
         result = evaluate(
-            HFLM(pretrained=model, tokenizer=tokenizer),
+            HFLM(
+                pretrained=model.to(device),
+                tokenizer=tokenizer,
+                batch_size=batch_size,
+                max_length=max_length),
             get_task_dict(task_list),
-            limit = limit
+            limit = limit,
         )
 
     for task, res in result["results"].items():
         print(f"{task}: {res}")
 
-
 if __name__ == '__main__':
     import argparse
     parser = argparse.ArgumentParser(description='Run HF Model Evaluation')
@@ -52,8 +56,10 @@ def run_evaluation(repo_id, task_list, limit, device, precision, quantization, c
     parser.add_argument('--limit', type=int, default=None, help='Number of eval samples to evaluate')
     parser.add_argument('--precision', type=lambda x: getattr(torch, x.split(".")[-1]), default=torch.bfloat16, help='dtype precision to use')
    parser.add_argument('--device', type=str, default="cuda", help='Device to use for evaluation')
-    parser.add_argument('--quantization', default = "None", choices=["int8dq", "int8wo", "int4wo","autoquant", "None"], help='Which quantization technique to apply')
+    parser.add_argument('-q', '--quantization', default = "None", choices=["int8dq", "int8wo", "int4wo","autoquant", "None"], help='Which quantization technique to apply')
     parser.add_argument('--compile', action='store_true', help='Whether to compile the model.')
+    parser.add_argument('--batch_size', type=int, default=1, help='Batch size to use for evaluation, note int8wo and int4wo work best with small batchsizes, int8dq works better with large batchsizes')
+    parser.add_argument('--max_length', type=int, default=None, help='Length of text to process at one time')
 
     args = parser.parse_args()
-    run_evaluation(args.repo_id, args.task_list, args.limit, args.device, args.precision, args.quantization, args.compile)
+    run_evaluation(args.repo_id, args.task_list, args.limit, args.device, args.precision, args.quantization, args.compile, args.batch_size, args.max_length)
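Not part of the original test plan, just an illustrative sketch of how the new flags could be combined with the other quantization choices; the batch sizes below are arbitrary examples guided by the --batch_size help text (weight-only int8wo/int4wo favor small batches, int8dq favors larger ones):

python hf_eval.py --limit 8 -q int8dq --batch_size 32 --max_length 200
python hf_eval.py --limit 8 -q int4wo --batch_size 1 --max_length 200 --compile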