Skip to content

Commit 0f376d1

Browse files
committed
add candidate selection script
1 parent 3508340 commit 0f376d1

File tree

1 file changed

+192
-0
lines changed

1 file changed

+192
-0
lines changed

evaluation/candidate_selection.py

+192
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,192 @@
1+
import numpy as np
2+
import pandas as pd
3+
import argparse
4+
import time
5+
import os
6+
import nltk
7+
from nltk.tokenize import word_tokenize
8+
import editdistance
9+
import rouge
10+
from bleu import *
11+
from nltk.translate.bleu_score import corpus_bleu
12+
13+
def slice_11(my_list, n):
14+
composite_list = [my_list[x:x+n] for x in range(0, len(my_list),n)]
15+
return composite_list
16+
17+
18+
def bleu_scorer(ref, hyp, script='default'):
19+
refsend = []
20+
for i in range(len(ref)):
21+
refsi = []
22+
for j in range(len(ref[i])):
23+
refsi.append(ref[i][j].split())
24+
refsend.append(refsi)
25+
26+
gensend = []
27+
for i in range(len(hyp)):
28+
gensend.append(hyp[i].split())
29+
30+
if script == 'nltk':
31+
metrics = corpus_bleu(refsend, gensend)
32+
return [metrics]
33+
34+
metrics = compute_bleu(refsend, gensend)
35+
return metrics
36+
37+
rouge_eval = rouge.Rouge(metrics=['rouge-1', 'rouge-2', 'rouge-l'])
38+
39+
def select_posed_bleu(src, df_sub):
40+
poseds = []
41+
for idx in list(df_sub.index):
42+
syn = df_sub.loc[idx, 'syn_paraphrase']
43+
temp = df_sub.loc[idx, 'template']
44+
syn_tags = list(zip(*nltk.pos_tag(word_tokenize(syn))))[1]
45+
temp_tags = list(zip(*nltk.pos_tag(word_tokenize(temp))))[1]
46+
posed = editdistance.eval(syn_tags, temp_tags)
47+
poseds.append(posed)
48+
49+
min_posed = min(poseds)
50+
posed_idx = [i for i in range(len(poseds)) if poseds[i] == min_posed]
51+
max_bleu = -1
52+
final_idx = None
53+
id_start = list(df_sub.index)[0]
54+
for idx in posed_idx:
55+
syn = df_sub.loc[id_start + idx, 'syn_paraphrase']
56+
bleu = bleu_scorer([[src]], [syn])[0]
57+
if bleu > max_bleu:
58+
max_bleu = bleu
59+
final_idx = id_start + idx
60+
61+
return final_idx
62+
63+
def select_rouge(src, df_sub):
64+
max_rouge = -1
65+
max_idx = None
66+
for idx in list(df_sub.index):
67+
syn = df_sub.loc[idx, 'syn_paraphrase']
68+
rouge = rouge_eval.get_scores([syn], [src])[0]['rouge-1']['f']
69+
if rouge > max_rouge:
70+
max_rouge = rouge
71+
max_idx = idx
72+
return max_idx
73+
74+
def ranker_select_rouge(src, df_sub):
75+
max_rouge = -1
76+
max_idx = None
77+
for idx in list(df_sub.index):
78+
syn = df_sub.loc[idx, 'syn_paraphrase']
79+
rouge1 = rouge_eval.get_scores([syn], [src])[0]['rouge-1']['f']
80+
rouge2 = rouge_eval.get_scores([syn], [src])[0]['rouge-2']['f']
81+
rougel = rouge_eval.get_scores([syn], [src])[0]['rouge-l']['f']
82+
rouge_general = 0.2 * rouge1 + 0.3 * rouge2 + 0.5 * rougel
83+
if rouge_general > max_rouge:
84+
max_rouge = rouge_general
85+
max_idx = idx
86+
return max_idx
87+
88+
def select_bleu(src, df_sub):
89+
max_bleu = -1
90+
max_idx = None
91+
for idx in list(df_sub.index):
92+
syn = df_sub.loc[idx, 'syn_paraphrase']
93+
bleu = bleu_scorer([[src]], [syn])[1][0]
94+
if bleu > max_bleu:
95+
max_bleu = bleu
96+
max_idx = idx
97+
return max_idx
98+
99+
def select_maxht(df_sub):
100+
max_ht = -1
101+
max_idx = None
102+
for idx in list(df_sub.index):
103+
ht = int(df_sub.loc[idx, 'height'])
104+
if ht > max_ht:
105+
max_ht = ht
106+
max_idx = idx
107+
108+
return max_idx
109+
110+
if __name__ == "__main__":
111+
112+
parser = argparse.ArgumentParser('Convert trees file to sentence file')
113+
parser.add_argument('-mode', default = 'test', help = '')
114+
parser.add_argument('-gen_dir', help = ' ', default="./")
115+
parser.add_argument('-output_file', help="the name of the output_file")
116+
# parser.add_argument('-clean_gen_file', required = True, help = 'name of the file')
117+
# parser.add_argument('-res_file', required = True, help = 'name of the file')
118+
parser.add_argument('-crt', choices = ['posed','rouge', 'bleu', 'maxht', 'rouge-general'],
119+
default ='bleu',
120+
help = "Criteria to select best generation")
121+
parser.add_argument('-sample', type=int, default=10)
122+
parser.add_argument('-scbart_generate', help="the file scbart generated", default="output/template-based-diverse-wr.txt")
123+
parser.add_argument('-target', help="the target file", default="eval_data/template-based3-set1.target")
124+
args = parser.parse_args()
125+
126+
generate_lines = open(args.scbart_generate, "r").readlines()
127+
target_lines = open(args.target,"r").readlines()
128+
target_lines = [line.split("<sep>")[0].strip() for line in target_lines]
129+
assert len(generate_lines) == len(target_lines)
130+
df_ls = []
131+
for i in range(0, len(generate_lines)):
132+
generate = generate_lines[i]
133+
target = target_lines[i]
134+
df_ls.append({
135+
"source": target,
136+
"syn_paraphrase": generate
137+
})
138+
df = pd.DataFrame(df_ls)
139+
140+
# df = pd.read_csv(os.path.join(args.gen_dir, args.clean_gen_file))
141+
srcs_unq = []
142+
idss = []
143+
ids = []
144+
prev_src = None
145+
prev_temp = None
146+
it = 0
147+
148+
srcs_unq = [ls[0].strip("\n") for ls in slice_11(df["source"].values, args.sample)]
149+
idss = slice_11(df["source"].index, args.sample)
150+
151+
assert len(idss) == len(srcs_unq)
152+
elites = []
153+
for src, ids in zip(srcs_unq, idss):
154+
df_sub = df.loc[ids]
155+
156+
if args.crt == 'posed':
157+
final_idx = select_posed_bleu(src, df_sub)
158+
elif args.crt == 'bleu':
159+
final_idx = select_bleu(src, df_sub)
160+
elif args.crt == 'maxht':
161+
final_idx = select_maxht(df_sub)
162+
elif args.crt == 'rouge-general':
163+
final_idx = ranker_select_rouge(src, df_sub)
164+
else:
165+
final_idx = select_rouge(src, df_sub)
166+
elites.append(final_idx)
167+
168+
df_elite = df[df.index.isin(elites)]
169+
170+
assert len(df_elite) == len(srcs_unq)
171+
try:
172+
references = df_elite['reference'].values
173+
except:
174+
references = []
175+
syn_paras = df_elite['syn_paraphrase'].values
176+
sources = df_elite['source'].values
177+
178+
# para_f, source_f = open(os.path.join(args.gen_dir, 'para.txt'), "w+"), \
179+
# open(os.path.join(args.gen_dir, 'source.txt'), "w+")
180+
# para_f = open(os.path.join(args.gen_dir, 'QQPPos-para.txt'), "w+")
181+
para_f = open(os.path.join(args.gen_dir, args.output_file), "w+")
182+
for i, row in df_elite.iterrows():
183+
syn_para, source = row["syn_paraphrase"].strip("\n").strip(), row["source"].strip("\n").strip()
184+
para_f.write(syn_para + "\n")
185+
# source_f.write(source + "\n")
186+
187+
188+
189+
190+
191+
192+

0 commit comments

Comments
 (0)