-
Notifications
You must be signed in to change notification settings - Fork 18
/
Copy pathexecution_eval.py
65 lines (57 loc) · 2 KB
/
execution_eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
"""
Execution-based evaluation: run generated code against unit tests and report pass@k.
"""
import argparse
import json
import sys
import evaluate
from datasets import load_metric
import numpy as np
from collections import defaultdict
import os
from py import test
# Opt in to executing untrusted, model-generated code: HF_ALLOW_CODE_EVAL
# gates the `code_eval` metric that pass_at_k loads.
os.environ.update(HF_ALLOW_CODE_EVAL="1")
def _load_selected_predictions(result_file, unittests):
    """Return the predictions in JSON-lines *result_file* whose question_id has a unit test."""
    selected = []
    with open(result_file, "r", encoding="utf-8") as f:
        for line in f:
            # Tolerate blank lines (e.g. a trailing newline at end of file).
            if not line.strip():
                continue
            pred = json.loads(line)
            if pred["question_id"] in unittests:
                selected.append(pred)
    return selected


def _build_programs(prediction, case):
    """Return (candidate programs, test program) for one prediction and its unit test."""
    test_program = f"\n{case['test']}\ncheck({case['entry_point']})"
    code = prediction["clean_code"]
    # clean_code is either a single completion or a list of completions.
    completions = code if isinstance(code, list) else [code]
    # Prepend the prompt so each completion becomes a runnable function.
    candidates = [f"{case['prompt']}{completion}" for completion in completions]
    return candidates, test_program


def pass_at_k(result_file, unittest_file, k=(1, 5, 10, 50, 100, 150, 200), num_workers=8):
    """Compute, print and return pass@k of generated code by executing unit tests.

    Args:
        result_file: JSON-lines file; each line is an object with at least
            ``question_id`` and ``clean_code`` (a string or a list of strings).
        unittest_file: JSON file holding an object that maps question_id to a
            dict with ``prompt``, ``test`` and ``entry_point`` keys.
        k: the k values passed to the ``code_eval`` metric.
            NOTE(review): code_eval appears to omit any k larger than the
            number of candidates of some example — confirm against its docs.
        num_workers: number of workers passed to ``code_eval``.

    Returns:
        The first element of the ``code_eval`` result (the pass@k scores),
        or an empty dict when no prediction has a unit test.
    """
    with open(unittest_file, "r", encoding="utf-8") as f:
        unittests = json.load(f)

    # select the examples which have unit test
    selected_predictions = _load_selected_predictions(result_file, unittests)
    print(f"selected {len(selected_predictions)} examples with unit test")
    if not selected_predictions:
        # Nothing to execute; skip loading and running the metric.
        return {}

    preds = []
    tests = []
    for prediction in selected_predictions:
        case = unittests[prediction["question_id"]]
        candidates, test_program = _build_programs(prediction, case)
        preds.append(candidates)
        tests.append(test_program)

    # load the metric from huggingface; it executes the candidate programs
    code_eval_metric = evaluate.load("code_eval")
    result = code_eval_metric.compute(
        predictions=preds,
        references=tests,
        k=list(k),
        num_workers=num_workers,
    )
    scores = result[0]
    print(scores)
    return scores
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Execution-based pass@k evaluation.")
    parser.add_argument(
        "--result_file",
        type=str,
        default="",
        help="JSON-lines file of predictions to evaluate (required)",
    )
    parser.add_argument(
        "--unittest_file",
        type=str,
        default="data/conala/unittest_docprompting_conala.json",
        help="JSON file mapping question_id to its unit test",
    )
    args = parser.parse_args()
    # Validate explicitly: an `assert` would be stripped under `python -O`.
    if not args.result_file:
        parser.error("--result_file is required")
    pass_at_k(args.result_file, args.unittest_file)