-
Notifications
You must be signed in to change notification settings - Fork 12
/
run_codegen_experiment.py
executable file
·207 lines (164 loc) · 7.54 KB
/
run_codegen_experiment.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
import os
import json
import argparse
import copy
from collections import defaultdict
from tqdm import tqdm
from utils.helper import PreviousStateRecorder
from utils.typo_fix import typo_fix
from config import CONFIG
from codegen_completion import codegen_check_over_length, codegen_completion
from utils.sql import sql_pred_parse, sv_dict_to_string
from prompting import get_prompt, conversion, table_prompt
from retriever.code.embed_based_retriever import EmbeddingRetriever
from evaluate_metrics import evaluate
# ---------------- command-line arguments ----------------
parser = argparse.ArgumentParser()
parser.add_argument('--train_fn', type=str, help="training data file (few-shot or full shot)", required=True)  # e.g. "./data/mw21_10p_train_v3.json"
parser.add_argument('--retriever_dir', type=str, required=True, help="sentence transformer saved path")  # "./retriever/expts/mw21_10p_v3_0304_400_20"
parser.add_argument('--output_dir', type=str, default="./expts/debug", help="directory to save running log and configs")
parser.add_argument('--mwz_ver', type=str, default="2.1", choices=['2.1', '2.4'], help="version of MultiWOZ")
parser.add_argument('--test_fn', type=str, default='',
                    help="file to evaluate on, empty means use the test set")
args = parser.parse_args()

# Persist the experiment configuration next to the run outputs.
os.makedirs(args.output_dir, exist_ok=True)
with open(os.path.join(args.output_dir, "exp_config.json"), 'w') as config_file:
    json.dump(vars(args), config_file, indent=4)

NUM_EXAMPLE = 5

# Completion backend: codegen model plus its prompt-length checker.
complete_fn = codegen_completion
check_overlen_fn = codegen_check_over_length

# In-context example selection pool (the few-shot training set).
with open(args.train_fn) as pool_file:
    train_set = json.load(pool_file)

# Pick the ontology and the default test split for the requested MultiWOZ version.
if args.mwz_ver == '2.1':
    ontology_path = CONFIG["ontology_21"]
    default_test_path = "./data/mw21_100p_test.json"
else:
    ontology_path = CONFIG["ontology_24"]
    default_test_path = "./data/mw24_100p_test.json"
# An explicit --test_fn overrides the version default.
test_set_path = args.test_fn if args.test_fn else default_test_path

with open(ontology_path) as ontology_file:
    ontology = json.load(ontology_file)
with open(test_set_path) as test_file:
    test_set = json.load(test_file)

# Embedding-based retriever over the selection pool, using a pre-assigned index.
retriever = EmbeddingRetriever(datasets=[train_set],
                               model_path=args.retriever_dir,
                               search_index_filename=os.path.join(args.retriever_dir, "train_index.npy"),
                               sampling_method="pre_assigned")
def run(test_set, turn=-1, use_gold=False):
    """Evaluate the codegen DST pipeline on ``test_set`` and return the results.

    Args:
        test_set: list of annotated dialogue-turn items to evaluate.
        turn: -1 evaluates all dialogues; turn = 0 evaluates single-turn
            dialogues, turn = 1 two-turn dialogues, etc. (analysis only,
            requires ``use_gold=True``).
        use_gold: when True, prompts use the gold dialogue context instead of
            the model's own previous-turn predictions (analysis purpose).

    Returns:
        The evaluated items, each augmented in place with the prompt actually
        sent ('prompt'), the raw model output ('completion'), the predicted
        state ('pred'), and the ontology path.

    Raises:
        ValueError: if ``turn >= 0`` is requested without ``use_gold=True``.
    """
    result_dict = defaultdict(list)  # per turn index -> list of 0/1 correctness
    selected_set = test_set
    # If needed, only evaluate dialogues of a particular length (analysis purpose).
    if turn >= 0:
        if not use_gold:
            raise ValueError("can only evaluate particular turn when using gold context")
        selected_set = [d for d in test_set if len(d['dialog']['usr']) == turn + 1]

    prediction_recorder = PreviousStateRecorder()  # tracks our own predicted states

    # start experiment
    all_result = []
    n_total = 0
    n_correct = 0
    total_acc = 0
    total_f1 = 0
    for data_item in tqdm(selected_set):
        n_total += 1
        completion = ""
        if use_gold:
            # Gold context: retrieve examples from the unmodified item.
            # Fix: `examples` and `predicted_context` were previously left
            # undefined in this branch, crashing the over-length retry loop.
            predicted_context = None
            examples = retriever.item_to_nearest_examples(data_item, k=NUM_EXAMPLE)
            prompt_text = get_prompt(data_item, examples=examples)
        else:
            # Use our own previous-turn prediction as the dialogue context,
            # both for retrieval and for prompt construction.
            predicted_context = prediction_recorder.state_retrieval(data_item)
            modified_item = copy.deepcopy(data_item)
            modified_item['last_slot_values'] = predicted_context
            examples = retriever.item_to_nearest_examples(
                modified_item, k=NUM_EXAMPLE)
            prompt_text = get_prompt(
                data_item, examples=examples, given_context=predicted_context)

        # print the retrieved examples (without the sql table)
        print(prompt_text.replace(conversion(table_prompt), ""))

        # Drop the farthest retrieved example until the prompt fits the model.
        while check_overlen_fn(prompt_text):
            print("prompt overlength")
            examples = examples[1:]
            if use_gold:
                prompt_text = get_prompt(data_item, examples=examples)
            else:
                prompt_text = get_prompt(
                    data_item, examples=examples, given_context=predicted_context)

        # Record the prompt actually sent (after any over-length truncation);
        # previously the pre-truncation prompt was logged.
        data_item['prompt'] = prompt_text

        # completion
        completion = complete_fn(prompt_text)
        completion = conversion(completion, reverse=True)

        # Parse the SQL-style completion into a slot-value dict; an unparsable
        # completion contributes an empty turn-level prediction.
        predicted_slot_values = {}
        try:
            predicted_slot_values = sql_pred_parse(completion)  # a dictionary
        except Exception:  # narrowed from a bare except; parse failure only
            print("the output is not a valid SQL query")
            data_item['not_valid'] = 1
        predicted_slot_values = typo_fix(predicted_slot_values, ontology=ontology, version=args.mwz_ver)

        context_slot_values = data_item['last_slot_values']  # gold previous state

        # Merge context and prediction into the accumulated dialogue state.
        if use_gold:
            all_slot_values = context_slot_values.copy()
        else:
            all_slot_values = prediction_recorder.state_retrieval(
                data_item).copy()
        for s, v in predicted_slot_values.items():
            # "[DELETE]" removes a slot; anything else sets/overwrites it.
            if s in all_slot_values and v == "[DELETE]":
                del all_slot_values[s]
            elif v != "[DELETE]":
                all_slot_values[s] = v

        # some slots may contain multiple values; keep only the first one
        all_slot_values = {k: v.split('|')[0] for k, v in all_slot_values.items()}

        # record current turn prediction for the next turn's context retrieval
        prediction_recorder.add_state(data_item, all_slot_values)

        # record the predictions
        data_item['pred'] = all_slot_values
        data_item['ontology_path'] = ontology_path
        data_item['completion'] = completion
        all_result.append(data_item)

        # print the result
        print(completion)
        print(f"this is the {n_total - 1}th example. {data_item['ID']}_turn_{data_item['turn_id']}")
        print(f"pred turn change: {sv_dict_to_string(predicted_slot_values, sep='-')}")
        print(f"gold turn change: {sv_dict_to_string(data_item['turn_slot_values'], sep='-')}")
        print(f"pred states: {sv_dict_to_string(all_slot_values, sep='-')}")
        print(f"gold states: {sv_dict_to_string(data_item['slot_values'], sep='-')}")

        this_jga, this_acc, this_f1 = evaluate(all_slot_values, data_item['slot_values'])
        total_acc += this_acc
        total_f1 += this_f1
        if this_jga:
            n_correct += 1
            result_dict[data_item['turn_id']].append(1)
            print("\n=====================correct!=======================")
        else:
            result_dict[data_item['turn_id']].append(0)
            print("\n=====================wrong!=======================")
        print("\n")
        print(f"correct {n_correct}/{n_total} = {n_correct / n_total}")
        print(f"Slot Acc {total_acc/n_total}")
        print(f"Joint F1 {total_f1/n_total}")
        print()

    # calculate the accuracy of each turn index
    for k, v in result_dict.items():
        print(f"accuracy of turn {k} is {sum(v)}/{len(v)} = {sum(v) / len(v)}")

    return all_result
if __name__ == "__main__":
    # Run the full experiment, then persist the per-turn log for later analysis.
    results = run(test_set)
    log_path = os.path.join(args.output_dir, "running_log.json")
    with open(log_path, 'w') as log_file:
        json.dump(results, log_file, indent=4)