
Commit b52d85f

Improving hf_eval.py (pytorch#342)
Summary: Made it possible to quantize on CPU rather than CUDA. Added options to change batch_size and max_length, and added -q as a short flag for --quantization.

Test Plan:
python hf_eval.py --limit 8 -q int8wo --batch_size 8 --max_length 20 --compile
python hf_eval.py --limit 8 -q int8wo --batch_size 8 --max_length 200 --compile

Reviewers:
Subscribers:
Tasks:
Tags:
1 parent c32d1cf commit b52d85f
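
The core change, sketched below as a minimal example: int8 weight-only quantization is applied while the model is still on CPU, and the model is only moved to the evaluation device afterwards. This assumes the torchao helper is importable from torchao.quantization.quant_api (the exact path may vary by version); the model id is a placeholder, not from the commit.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# Assumed import path for the helper this script uses; may differ across torchao versions.
from torchao.quantization.quant_api import change_linear_weights_to_int8_woqtensors

repo_id = "facebook/opt-125m"  # placeholder model id
tokenizer = AutoTokenizer.from_pretrained(repo_id)
# Load and keep the model on CPU, matching this commit's change.
model = AutoModelForCausalLM.from_pretrained(repo_id).to(device="cpu", dtype=torch.bfloat16)

# int8 weight-only quantization runs against the CPU copy of the weights.
change_linear_weights_to_int8_woqtensors(model)

# Only after quantization is the model moved to the evaluation device, as the updated script does.
model = model.to("cuda")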

File tree: 1 file changed (+15, -9 lines)


scripts/hf_eval.py (+15, -9)
@@ -16,10 +16,10 @@
 torch._inductor.config.force_fuse_int_mm_with_mul = True
 torch._inductor.config.fx_graph_cache = True
 
-def run_evaluation(repo_id, task_list, limit, device, precision, quantization, compile):
+def run_evaluation(repo_id, task_list, limit, device, precision, quantization, compile, batch_size, max_length):
 
     tokenizer = AutoTokenizer.from_pretrained(repo_id)
-    model = AutoModelForCausalLM.from_pretrained(repo_id).to(device="cuda", dtype=precision)
+    model = AutoModelForCausalLM.from_pretrained(repo_id).to(device="cpu", dtype=precision)
 
     if compile:
         model = torch.compile(model, mode="max-autotune", fullgraph=True)
@@ -29,21 +29,25 @@ def run_evaluation(repo_id, task_list, limit, device, precision, quantization, compile):
     elif quantization == "int8wo":
         change_linear_weights_to_int8_woqtensors(model)
     elif quantization == "int4wo":
-        change_linear_weights_to_int4_woqtensors(model)
+        # note cannot quantize this model on cpu and run it on cuda at this time
+        change_linear_weights_to_int4_woqtensors(model.to(device=device))
     elif quantization == "autoquant":
-        model = autoquant(model)
+        model = autoquant(model.to(device=device))
 
     with torch.no_grad():
         result = evaluate(
-            HFLM(pretrained=model, tokenizer=tokenizer),
+            HFLM(
+                pretrained=model.to(device),
+                tokenizer=tokenizer,
+                batch_size=batch_size,
+                max_length=max_length),
             get_task_dict(task_list),
-            limit = limit
+            limit = limit,
         )
         for task, res in result["results"].items():
             print(f"{task}: {res}")
 
 
-
 if __name__ == '__main__':
     import argparse
     parser = argparse.ArgumentParser(description='Run HF Model Evaluation')
@@ -52,8 +56,10 @@ def run_evaluation(repo_id, task_list, limit, device, precision, quantization, compile):
     parser.add_argument('--limit', type=int, default=None, help='Number of eval samples to evaluate')
     parser.add_argument('--precision', type=lambda x: getattr(torch, x.split(".")[-1]), default=torch.bfloat16, help='dtype precision to use')
     parser.add_argument('--device', type=str, default="cuda", help='Device to use for evaluation')
-    parser.add_argument('--quantization', default = "None", choices=["int8dq", "int8wo", "int4wo","autoquant", "None"], help='Which quantization technique to apply')
+    parser.add_argument('-q', '--quantization', default = "None", choices=["int8dq", "int8wo", "int4wo","autoquant", "None"], help='Which quantization technique to apply')
     parser.add_argument('--compile', action='store_true', help='Whether to compile the model.')
+    parser.add_argument('--batch_size', type=int, default=1, help='Batch size to use for evaluation, note int8wo and int4wo work best with small batchsizes, int8dq works better with large batchsizes')
+    parser.add_argument('--max_length', type=int, default=None, help='Length of text to process at one time')
 
     args = parser.parse_args()
-    run_evaluation(args.repo_id, args.task_list, args.limit, args.device, args.precision, args.quantization, args.compile)
+    run_evaluation(args.repo_id, args.task_list, args.limit, args.device, args.precision, args.quantization, args.compile, args.batch_size, args.max_length)
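
For reference, a minimal sketch of calling the updated function directly with the new arguments (this assumes scripts/hf_eval.py is importable as a module; the model id and task are placeholders, not from the commit):

import torch
from hf_eval import run_evaluation  # assumes scripts/ is on the Python path

run_evaluation(
    repo_id="facebook/opt-125m",   # placeholder model id
    task_list=["wikitext"],        # placeholder task
    limit=8,
    device="cuda",
    precision=torch.bfloat16,
    quantization="int8wo",
    compile=False,
    batch_size=8,                  # new in this commit
    max_length=200,                # new in this commit
)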
