Skip to content

Commit 89bc2e5

Browse files
authored
task_list to tasks (pytorch#343)
1 parent 67d2830 commit 89bc2e5

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

scripts/hf_eval.py

+4 -4
Original file line number | Diff line number | Diff line change
@@ -16,7 +16,7 @@
1616
torch._inductor.config.force_fuse_int_mm_with_mul = True
1717
torch._inductor.config.fx_graph_cache = True
1818

19-
def run_evaluation(repo_id, task_list, limit, device, precision, quantization, compile, batch_size, max_length):
19+
def run_evaluation(repo_id, tasks, limit, device, precision, quantization, compile, batch_size, max_length):
2020

2121
tokenizer = AutoTokenizer.from_pretrained(repo_id)
2222
model = AutoModelForCausalLM.from_pretrained(repo_id).to(device="cpu", dtype=precision)
@@ -41,7 +41,7 @@ def run_evaluation(repo_id, task_list, limit, device, precision, quantization, c
4141
tokenizer=tokenizer,
4242
batch_size=batch_size,
4343
max_length=max_length),
44-
get_task_dict(task_list),
44+
get_task_dict(tasks),
4545
limit = limit,
4646
)
4747
for task, res in result["results"].items():
@@ -52,7 +52,7 @@ def run_evaluation(repo_id, task_list, limit, device, precision, quantization, c
5252
import argparse
5353
parser = argparse.ArgumentParser(description='Run HF Model Evaluation')
5454
parser.add_argument('--repo_id', type=str, default="meta-llama/Meta-Llama-3-8B", help='Repository ID to download from HF.')
55-
parser.add_argument('--task_list', nargs='+', type=str, default=["wikitext"], help='List of lm-eluther tasks to evaluate usage: --tasks task1 task2')
55+
parser.add_argument('--tasks', nargs='+', type=str, default=["wikitext"], help='List of lm-eluther tasks to evaluate usage: --tasks task1 task2')
5656
parser.add_argument('--limit', type=int, default=None, help='Number of eval samples to evaluate')
5757
parser.add_argument('--precision', type=lambda x: getattr(torch, x.split(".")[-1]), default=torch.bfloat16, help='dtype precision to use')
5858
parser.add_argument('--device', type=str, default="cuda", help='Device to use for evaluation')
@@ -62,4 +62,4 @@ def run_evaluation(repo_id, task_list, limit, device, precision, quantization, c
6262
parser.add_argument('--max_length', type=int, default=None, help='Length of text to process at one time')
6363

6464
args = parser.parse_args()
65-
run_evaluation(args.repo_id, args.task_list, args.limit, args.device, args.precision, args.quantization, args.compile, args.batch_size, args.max_length)
65+
run_evaluation(args.repo_id, args.tasks, args.limit, args.device, args.precision, args.quantization, args.compile, args.batch_size, args.max_length)

0 commit comments

Comments
 (0)