-
Notifications
You must be signed in to change notification settings - Fork 371
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Cross Validation for Pseudo Relevance Feedback (#131)
* Add SimpleSearch Pyserini End * Add SimpleSearch pyserini end * Add SimpleSearch pyserini end * Add SimpleSearch pyserini end * add simplesearch * add simplesearcher * move main function to * move main function to * move main function to __main__.py * modify main function * modify main function * add rm3 and qld * add rm3 and qld * add rm3 and qld * add cross validation * add cross validation * add cross validation * add cross validation * add cross validation * add cross validation * add cross validation * add cross validation
- Loading branch information
Showing
5 changed files
with
91 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
import pandas as pd | ||
import json | ||
import os | ||
import argparse | ||
|
||
|
||
def read_topics_alpha_map(anserini_root, collection, run_file, classifier, rm3):
    """Score each alpha's run file with trec_eval and collect the score-file paths.

    For every alpha in {0.0, 0.1, ..., 1.0} this shells out to Anserini's
    trec_eval (per-query MAP and P@30) on the corresponding run file and
    redirects the output into a score file under ``{run_file}/cv/{collection}/``.

    Args:
        anserini_root: path to the Anserini repository root.
        collection: collection name (robust04, robust05, core17, core18).
        run_file: root directory holding the run files.
        classifier: classifier tag in the run-file names (lr, svm, lr+svm).
        rm3: if True, evaluate the bm25+rm3 run files instead of plain bm25.

    Returns:
        List of the 11 score-file paths, ordered by alpha.
    """
    run_suffix = '_bm25+rm3.txt' if rm3 else '_bm25.txt'
    # BUG FIX: the qrels file was hard-coded to qrels.core18.txt for every
    # collection; pick the qrels that matches the requested collection.
    qrels = f'{anserini_root}/src/main/resources/topics-and-qrels/qrels.{collection}.txt'
    res_paths = []
    for num in range(0, 11):
        alpha = str(num / 10)
        file_path = f'{run_file}/{collection}/{collection}_{classifier}_A{alpha}{run_suffix}'
        res_filename = f'{run_file}/cv/{collection}/scores_{collection}_{classifier}_A{alpha}_bm25.txt'
        res_paths.append(res_filename)
        cmd = (f'{anserini_root}/eval/trec_eval.9.0.4/trec_eval -q -m map -m P.30 '
               f'{qrels} {file_path} > {res_filename}')
        # os.system returns the shell exit status; 0 means trec_eval succeeded.
        if os.system(cmd) == 0:
            print(file_path + ' run successfully!')
            print('save result in ' + res_filename)
    return res_paths
|
||
|
||
def load_in_res(res_paths):
    """Load per-query MAP scores from the 11 trec_eval score files into one frame.

    Args:
        res_paths: score-file paths ordered by alpha (0.0 ... 1.0), as
            produced by read_topics_alpha_map. Each file holds whitespace-
            separated ``measure topicid score`` rows.

    Returns:
        DataFrame indexed by topicid (str) with one float column per alpha
        ('0.0' ... '1.0'). Only 'map' rows are kept, and the trailing 'all'
        summary row that trec_eval appends is dropped.
    """
    # Raw strings avoid the invalid '\s' escape warning on modern Python.
    # topicid is read as str here too (the original omitted it for the first
    # file), so the index matches the string topic ids used downstream.
    df = pd.read_csv(res_paths[0], sep=r'\s+', header=None,
                     names=['Type', 'topicid', '0.0'],
                     dtype={'topicid': str, '0.0': float})
    df.set_index('topicid', inplace=True)
    for num in range(1, 11):
        dataset = pd.read_csv(res_paths[num], sep=r'\s+', header=None,
                              names=['Type', 'topicid', 'score'],
                              dtype={'topicid': str, 'score': float})
        # Files share row order, so a positional assignment lines up by topic.
        df[str(num / 10)] = dataset.score.values
    # Keep only MAP rows; [:-1] drops trec_eval's final 'all' aggregate row.
    df = df[df['Type'] == 'map'][:-1]
    df = df.drop(['Type'], axis=1)
    return df
|
||
|
||
def generate_run_file(folders, df, collection, run_file, classifier, rm3, output_path):
    """Build a cross-validated run file from per-fold best-alpha runs.

    For each fold, the best alpha is picked by mean MAP over the *other*
    folds' topics (the train split); that alpha's run lines for the fold's
    own topics are then merged, sorted by (topicid, rank), and written to
    ``output_path``.

    Args:
        folders: list of folds, each a list of topic ids.
        df: DataFrame from load_in_res (str topicid index, one column per alpha).
        collection, run_file, classifier: locate the alpha run files.
        rm3: if True, read the bm25+rm3 run files.
        output_path: path of the merged run file to write.
    """
    # BUG FIX: the suffixes were swapped relative to read_topics_alpha_map,
    # so the wrong ranker's run files were merged.
    run_suffix = '_bm25+rm3.txt' if rm3 else '_bm25.txt'
    highest_alpha_lst = []
    write_lst = []
    with open(output_path, 'w') as target_file:
        for folder in folders:
            # Train on every topic outside the current fold.
            train_topicids = [str(topic) for folder_i in folders
                              for topic in folder_i if folder_i != folder]
            train_df = df.loc[train_topicids, :]
            # idxmax over the column-wise mean picks the best alpha. This also
            # fixes the original Series.idxmax(axis=1) call, which raises
            # ValueError on modern pandas, and avoids mutating the slice.
            highest_alpha = train_df.mean(axis=0).idxmax()
            highest_alpha_lst.append(highest_alpha)
            alpha_run_file = (f'{run_file}/{collection}/'
                              f'{collection}_{classifier}_A{highest_alpha}{run_suffix}')
            # One alpha per fold, so read the run file once per fold.
            with open(alpha_run_file) as fp:
                run_lines = fp.readlines()
            for topic in folder:
                topic_id = str(topic)
                for line in run_lines:
                    # BUG FIX: exact first-token match; startswith() also
                    # matched topics sharing a prefix (e.g. 62 vs 620).
                    if line.split(' ')[0] == topic_id:
                        write_lst.append(line)
        write_lst.sort(key=lambda x: (x.split(" ")[0], int(x.split(" ")[3])))
        target_file.write("".join(write_lst))
    print(highest_alpha_lst)
|
||
|
||
if __name__ == '__main__':
    # Command-line driver: evaluate every alpha run with trec_eval, load the
    # per-topic MAP scores, then cross-validate over the collection's
    # predefined folds and emit the merged run file.
    arg_parser = argparse.ArgumentParser(
        description='Get Best alpha score for corresponding topics')
    arg_parser.add_argument('--anserini', metavar='path', required=True,
                            help='the path to anserini root')
    arg_parser.add_argument('--pyserini', metavar='path', required=True,
                            help='a path to the folder json file')
    arg_parser.add_argument('--collection', metavar='collectionsname', required=True,
                            help='one of the collectionname in robust04,robust05, core17,core18')
    arg_parser.add_argument('--run_file', metavar='path', required=True,
                            help='the path to run files root')
    arg_parser.add_argument('--output', metavar='path', required=True,
                            help='the path to the output file')
    arg_parser.add_argument('--classifier', metavar='name', required=True,
                            help='one of three classifers lr or svm or lr+svm')
    arg_parser.add_argument('-rm3', action='store_true',
                            help='use rm3 ranker')
    args = arg_parser.parse_args()

    score_paths = read_topics_alpha_map(
        args.anserini, args.collection, args.run_file, args.classifier, args.rm3)
    score_df = load_in_res(score_paths)

    # The fold definitions live alongside the pyserini scripts as JSON.
    folds_file = os.path.join(
        args.pyserini, f'scripts/classifier_prf/folds/{args.collection}.json')
    with open(folds_file) as fold_fp:
        folds = json.load(fold_fp)

    generate_run_file(folds, score_df, args.collection, args.run_file,
                      args.classifier, args.rm3, args.output)
    print("Successfully generated a trained runfile in " + args.output)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
[[426, 408, 354, 399, 310, 363, 393, 442, 433, 690], [445, 379, 355, 321, 422, 394, 350, 330, 345, 620], [356, 435, 400, 375, 439, 614, 404, 341, 427, 378], [325, 336, 372, 344, 419, 436, 353, 626, 423, 367], [362, 646, 677, 307, 416, 443, 414, 397, 389, 347]] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
[[805, 439, 341, 816, 690, 815, 422, 414, 442, 821], [808, 810, 367, 427, 362, 375, 626, 426, 812, 809], [820, 823, 807, 321, 433, 802, 347, 806, 400, 378], [408, 822, 813, 824, 819, 814, 818, 811, 646, 350], [336, 803, 393, 825, 801, 804, 363, 817, 445, 397]] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
[[681, 357, 677, 692, 321, 349, 400, 362, 368, 670, 405, 696, 381, 449, 444, 340, 442, 643, 370, 700, 693, 432, 384, 395, 618, 419, 308, 418, 688, 327, 329, 649, 352, 678, 422, 650, 441, 375, 661, 379, 412, 636, 393, 450, 605, 684, 666, 623, 640, 676], [331, 617, 401, 690, 302, 631, 695, 635, 404, 372, 446, 658, 376, 664, 639, 383, 627, 655, 334, 668, 355, 657, 628, 332, 652, 359, 409, 685, 633, 392, 430, 396, 338, 323, 410, 389, 673, 346, 638, 653, 378, 683, 398, 385, 311, 629, 390, 620, 320, 336], [391, 667, 309, 609, 414, 682, 343, 614, 680, 669, 686, 621, 607, 361, 438, 672, 354, 602, 440, 344, 687, 645, 317, 436, 367, 691, 644, 420, 397, 421, 601, 314, 324, 399, 624, 350, 303, 433, 671, 611, 402, 351, 406, 679, 603, 619, 337, 358, 431, 646], [613, 648, 305, 659, 307, 634, 694, 345, 380, 437, 665, 428, 675, 637, 386, 426, 373, 326, 333, 377, 306, 388, 423, 622, 647, 663, 610, 651, 411, 415, 356, 360, 699, 413, 301, 318, 310, 312, 382, 348, 353, 632, 369, 689, 427, 654, 365, 339, 316, 697], [322, 616, 374, 641, 304, 330, 319, 342, 660, 366, 341, 625, 371, 604, 626, 387, 313, 425, 335, 363, 407, 698, 447, 612, 608, 674, 662, 408, 435, 615, 315, 445, 434, 656, 606, 443, 642, 630, 328, 424, 417, 347, 364, 448, 325, 439, 403, 416, 429, 394]] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
[[622, 426, 393, 658, 639, 347, 363, 378, 638, 383], [404, 362, 689, 427, 397, 416, 651, 336, 625, 435], [648, 448, 389, 394, 310, 372, 307, 443, 344, 330], [436, 439, 408, 375, 401, 374, 322, 419, 341, 367], [650, 433, 399, 353, 345, 354, 314, 303, 409, 325]] |