Skip to content

Commit

Permalink
Cross Validation for Pseudo Relevance Feedback (#131)
Browse files Browse the repository at this point in the history
* Add SimpleSearch Pyserini End

* Add SimpleSearch pyserini end

* Add SimpleSearch pyserini end

* Add SimpleSearch pyserini end

* add simplesearch

* add simplesearcher

* move main function to

* move main function to

* move main function to __main__.py

* modify main function

* modify main function

* add rm3 and qld

* add rm3 and qld

* add rm3 and qld

* add cross validation

* add cross validation

* add cross validation

* add cross validation

* add cross validation

* add cross validation

* add cross validation

* add cross validation
  • Loading branch information
yuki617 authored May 24, 2020
1 parent 47d75fa commit 770dfdc
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 0 deletions.
87 changes: 87 additions & 0 deletions scripts/classifier_prf/cross-validate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import pandas as pd
import json
import os
import argparse


def read_topics_alpha_map(anserini_root, collection, run_file, classifier, rm3):
    """Run trec_eval over the 11 alpha-interpolated run files and save scores.

    One run file exists per alpha in {0.0, 0.1, ..., 1.0}; each is evaluated
    with ``trec_eval -q -m map -m P.30`` and the per-topic output is
    redirected into a score file under ``{run_file}/cv/{collection}/``.

    Args:
        anserini_root: path to the Anserini repo root (provides trec_eval
            and the topics-and-qrels resources).
        collection: collection name (robust04, robust05, core17, core18).
        run_file: root directory holding the run files.
        classifier: classifier tag in the run-file names (lr, svm, lr+svm).
        rm3: if True, evaluate the bm25+rm3 runs instead of the bm25 runs.

    Returns:
        List of the 11 score-file paths, ordered by alpha (0.0 first).
    """
    res_paths = []
    for num in range(0, 11):
        alpha = str(num / 10)
        ranker = 'bm25+rm3' if rm3 else 'bm25'
        file_path = f'{run_file}/{collection}/{collection}_{classifier}_A{alpha}_{ranker}.txt'
        # Include the ranker in the score filename so bm25 and bm25+rm3
        # results do not overwrite each other (previously both wrote
        # to the same *_bm25.txt score file).
        res_filename = f'{run_file}/cv/{collection}/scores_{collection}_{classifier}_A{alpha}_{ranker}.txt'
        res_paths.append(res_filename)
        # Ensure the output directory exists; otherwise the shell
        # redirection below fails silently.
        os.makedirs(os.path.dirname(res_filename), exist_ok=True)
        # Use the qrels matching the requested collection (was hard-coded
        # to qrels.core18.txt regardless of --collection).
        qrels = f'{anserini_root}/src/main/resources/topics-and-qrels/qrels.{collection}.txt'
        cmd = f'{anserini_root}/eval/trec_eval.9.0.4/trec_eval -q -m map -m P.30 {qrels} {file_path} > {res_filename}'
        if os.system(cmd) == 0:
            print(file_path + ' run successfully!')
            print('save result in ' + res_filename)
    return res_paths


def load_in_res(res_paths):
    """Collect the per-topic trec_eval outputs into a single DataFrame.

    Args:
        res_paths: the 11 score files produced by read_topics_alpha_map,
            ordered by alpha (index 0 is alpha=0.0, index 10 is alpha=1.0).

    Returns:
        DataFrame indexed by topicid (as str) with one float column per
        alpha ('0.0' .. '1.0') holding per-topic MAP scores; the trailing
        'all' summary row emitted by trec_eval is dropped.
    """
    # Read topicid as str: generate_run_file indexes this frame with
    # str(topic), so an int index would raise KeyError. (The original
    # only forced str on the files read in the loop below.)
    df = pd.read_csv(res_paths[0], sep=r'\s+', header=None,
                     names=['Type', 'topicid', '0.0'],
                     dtype={'topicid': str, '0.0': float})
    df.set_index('topicid', inplace=True)
    for num in range(1, 11):
        dataset = pd.read_csv(res_paths[num], sep=r'\s+', header=None,
                              names=['Type', 'topicid', 'score'],
                              dtype={'topicid': str, 'score': float})
        # All score files come from the same run set, so row order is
        # assumed to match file 0 exactly.
        df[str(num / 10)] = dataset.score.values
    # Keep only the per-topic MAP rows; the last 'map' row is the 'all'
    # summary row, hence the [:-1].
    df = df[df['Type'] == 'map'][:-1]
    df = df.drop(['Type'], axis=1)
    return df


def generate_run_file(folders, df, collection, run_file, classifier, rm3, output_path):
    """Write a cross-validated run file.

    For each fold, choose the alpha with the best mean MAP over the other
    folds' topics, then copy that alpha's run lines for the fold's topics
    into output_path, sorted by topic id then rank.

    Args:
        folders: list of folds, each a list of topic ids (ints).
        df: DataFrame from load_in_res (str topicid index, str alpha columns).
        collection: collection name used in the run-file paths.
        run_file: root directory holding the alpha run files.
        classifier: classifier tag in the run-file names.
        rm3: if True, read the bm25+rm3 runs; otherwise the bm25 runs.
        output_path: destination path of the merged run file.
    """
    highest_alpha_lst = []
    write_lst = []
    with open(output_path, 'w') as target_file:
        for folder in folders:
            # Train on every fold except the held-out one.
            train_topicids = [str(topic) for folder_i in folders
                              for topic in folder_i if folder_i != folder]
            train_df = df.loc[train_topicids, :]
            # Alpha column with the highest mean MAP over the training topics
            # (avoids mutating train_df and the deprecated Series idxmax axis arg).
            highest_alpha = train_df.mean(axis=0).idxmax()
            highest_alpha_lst.append(highest_alpha)
            # Bug fix: the rm3 branches were swapped — rm3=True read the
            # bm25-only run and rm3=False read the bm25+rm3 run.
            if rm3:
                alpha_run_file = f'{run_file}/{collection}/{collection}_{classifier}_A{highest_alpha}_bm25+rm3.txt'
            else:
                alpha_run_file = f'{run_file}/{collection}/{collection}_{classifier}_A{highest_alpha}_bm25.txt'
            # Read the chosen run once per fold instead of once per topic.
            with open(alpha_run_file) as fp:
                run_lines = fp.readlines()
            for topic in folder:
                topicid = str(topic)
                for line in run_lines:
                    # Exact match on the first field; startswith() would also
                    # match longer topic ids sharing the same prefix.
                    if line.split(' ', 1)[0] == topicid:
                        write_lst.append(line)
        # Bug fix: sort and write once after all folds. Previously this sat
        # inside the fold loop, so earlier folds' lines were re-written
        # (duplicated) for every subsequent fold.
        write_lst.sort(key=lambda x: (x.split(" ")[0], int(x.split(" ")[3])))
        target_file.write("".join(write_lst))
    print(highest_alpha_lst)


def _parse_args():
    """Build the command-line interface and parse sys.argv."""
    parser = argparse.ArgumentParser(
        description='Get Best alpha score for corresponding topics')
    parser.add_argument('--anserini', metavar='path', required=True,
                        help='the path to anserini root')
    parser.add_argument('--pyserini', metavar='path', required=True,
                        help='a path to the folder json file')
    parser.add_argument('--collection', metavar='collectionsname', required=True,
                        help='one of the collectionname in robust04,robust05, core17,core18')
    parser.add_argument('--run_file', metavar='path', required=True,
                        help='the path to run files root')
    parser.add_argument('--output', metavar='path', required=True,
                        help='the path to the output file')
    parser.add_argument('--classifier', metavar='name', required=True,
                        help='one of three classifers lr or svm or lr+svm')
    parser.add_argument('-rm3', action='store_true',
                        help='use rm3 ranker')
    return parser.parse_args()


if __name__ == '__main__':
    args = _parse_args()
    # Evaluate every alpha run and load the per-topic MAP scores.
    score_paths = read_topics_alpha_map(args.anserini, args.collection,
                                        args.run_file, args.classifier, args.rm3)
    score_df = load_in_res(score_paths)
    # Topic folds shipped with pyserini for this collection.
    folds_file = os.path.join(
        args.pyserini, f'scripts/classifier_prf/folds/{args.collection}.json')
    with open(folds_file) as fold_fp:
        topic_folds = json.load(fold_fp)
    generate_run_file(topic_folds, score_df, args.collection, args.run_file,
                      args.classifier, args.rm3, args.output)
    print("Successfully generated a trained runfile in " + args.output)
1 change: 1 addition & 0 deletions scripts/classifier_prf/folds/core17.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
[[426, 408, 354, 399, 310, 363, 393, 442, 433, 690], [445, 379, 355, 321, 422, 394, 350, 330, 345, 620], [356, 435, 400, 375, 439, 614, 404, 341, 427, 378], [325, 336, 372, 344, 419, 436, 353, 626, 423, 367], [362, 646, 677, 307, 416, 443, 414, 397, 389, 347]]
1 change: 1 addition & 0 deletions scripts/classifier_prf/folds/core18.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
[[805, 439, 341, 816, 690, 815, 422, 414, 442, 821], [808, 810, 367, 427, 362, 375, 626, 426, 812, 809], [820, 823, 807, 321, 433, 802, 347, 806, 400, 378], [408, 822, 813, 824, 819, 814, 818, 811, 646, 350], [336, 803, 393, 825, 801, 804, 363, 817, 445, 397]]
1 change: 1 addition & 0 deletions scripts/classifier_prf/folds/robust04.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
[[681, 357, 677, 692, 321, 349, 400, 362, 368, 670, 405, 696, 381, 449, 444, 340, 442, 643, 370, 700, 693, 432, 384, 395, 618, 419, 308, 418, 688, 327, 329, 649, 352, 678, 422, 650, 441, 375, 661, 379, 412, 636, 393, 450, 605, 684, 666, 623, 640, 676], [331, 617, 401, 690, 302, 631, 695, 635, 404, 372, 446, 658, 376, 664, 639, 383, 627, 655, 334, 668, 355, 657, 628, 332, 652, 359, 409, 685, 633, 392, 430, 396, 338, 323, 410, 389, 673, 346, 638, 653, 378, 683, 398, 385, 311, 629, 390, 620, 320, 336], [391, 667, 309, 609, 414, 682, 343, 614, 680, 669, 686, 621, 607, 361, 438, 672, 354, 602, 440, 344, 687, 645, 317, 436, 367, 691, 644, 420, 397, 421, 601, 314, 324, 399, 624, 350, 303, 433, 671, 611, 402, 351, 406, 679, 603, 619, 337, 358, 431, 646], [613, 648, 305, 659, 307, 634, 694, 345, 380, 437, 665, 428, 675, 637, 386, 426, 373, 326, 333, 377, 306, 388, 423, 622, 647, 663, 610, 651, 411, 415, 356, 360, 699, 413, 301, 318, 310, 312, 382, 348, 353, 632, 369, 689, 427, 654, 365, 339, 316, 697], [322, 616, 374, 641, 304, 330, 319, 342, 660, 366, 341, 625, 371, 604, 626, 387, 313, 425, 335, 363, 407, 698, 447, 612, 608, 674, 662, 408, 435, 615, 315, 445, 434, 656, 606, 443, 642, 630, 328, 424, 417, 347, 364, 448, 325, 439, 403, 416, 429, 394]]
1 change: 1 addition & 0 deletions scripts/classifier_prf/folds/robust05.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
[[622, 426, 393, 658, 639, 347, 363, 378, 638, 383], [404, 362, 689, 427, 397, 416, 651, 336, 625, 435], [648, 448, 389, 394, 310, 372, 307, 443, 344, 330], [436, 439, 408, 375, 401, 374, 322, 419, 341, 367], [650, 433, 399, 353, 345, 354, 314, 303, 409, 325]]

0 comments on commit 770dfdc

Please sign in to comment.