Pseudo Relevance Feedback Interpolation on pysearch (#123)
Initial implementation. Remaining issues still to be addressed; punting on them for now.
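In outline: for each topic, the top r retrieved documents are treated as pseudo-relevant (positive) training examples and the bottom n as negatives, a classifier is trained on their tf-idf vectors, and its probabilities are interpolated with the normalized retrieval scores. A minimal sketch of the interpolation step (illustrative names; the committed implementation is in the reranker below):

    # Both score lists are min-max normalized before interpolation.
    def interpolate(classifier_scores, search_scores, alpha=0.5):
        return [alpha * c + (1 - alpha) * s
                for c, s in zip(classifier_scores, search_scores)]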
Showing 5 changed files with 288 additions and 20 deletions.
@@ -1,33 +1,93 @@
-import re
+# -*- coding: utf-8 -*-
+#
+# Pyserini: Python interface to the Anserini IR toolkit built on Lucene
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.

+import argparse
+import re
+from pyserini.search.pysearch import get_topics, SimpleSearcher
+from pyserini.search.reranker import ClassifierType, PseudoRelevanceClassifierReranker

-from pyserini.search.pysearch import get_topics,SimpleSearcher

 parser = argparse.ArgumentParser(description='Create a input schema')
 parser.add_argument('-index', metavar='path', required=True,
                     help='the path to workspace')
 parser.add_argument('-topics', metavar='topicsname', required=True,
                     help='topicsname')
-parser.add_argument('-output', metavar='path', required=True,
-                    help='path to the output file')
+parser.add_argument('-output', metavar='path',
+                    help='path to the output file')
 parser.add_argument('-bm25', action='store_true', default=True,
                     help='use bm25 ranker')
 parser.add_argument('-rm3', action='store_true',
-                    help='take rm3 ranker')
+                    help='use rm3 ranker')
 parser.add_argument('-qld', action='store_true',
-                    help='take qld ranker')
+                    help='use qld ranker')
+parser.add_argument('-prf', type=ClassifierType, nargs='+',
+                    help='use pseudo relevance feedback ranker')
+parser.add_argument('-r', type=int, default=10,
+                    help='number of positive labels in pseudo relevance feedback')
+parser.add_argument('-n', type=int, default=100,
+                    help='number of negative labels in pseudo relevance feedback')
+parser.add_argument('-alpha', type=float, default=0.5,
+                    help='alpha value for interpolation in pseudo relevance feedback')
 args = parser.parse_args()

 searcher = SimpleSearcher(args.index)
 topics_dic = get_topics(args.topics)
+search_rankers = ['bm25']
 if args.rm3:
+    search_rankers.append('rm3')
     searcher.set_rm3()
 if args.qld:
+    search_rankers.append('qld')
     searcher.set_qld()
-if topics_dic != {}:
-    target_file = open(args.output, 'w')
-    for key, value in sorted(topics_dic.items()):
-        search = value.get('title')
+
+if topics_dic == {}:
+    print('Topic Not Found')
+    exit()
+
+output_path = args.output
+if output_path is None:
+    clf_rankers = []
+    for t in args.prf:
+        if t == ClassifierType.LR:
+            clf_rankers.append('lr')
+        elif t == ClassifierType.SVM:
+            clf_rankers.append('svm')
+
+    tokens = [args.topics, '+'.join(clf_rankers),
+              f'A{args.alpha}', '+'.join(search_rankers)]
+    output_path = '_'.join(tokens) + ".txt"
+
+need_classifier = args.prf and len(args.prf) > 0 and args.alpha > 0
+if need_classifier is True:
+    ranker = PseudoRelevanceClassifierReranker(
+        args.index, args.prf, r=args.r, n=args.n, alpha=args.alpha)
+
+print('Output ->', output_path)
+with open(output_path, 'w') as target_file:
+    for index, topic in enumerate(sorted(topics_dic.keys())):
+        print(f'Topic {topic}: {index + 1}/{len(topics_dic)}')
+        search = topics_dic[topic].get('title')
         hits = searcher.search(search, 1000)
-        for i in range(0, len(hits)):
-            target_file.write(f'{key} Q0 {hits[i].docid.strip()} {i + 1} {hits[i].score:.6f} Anserini\n')
-    target_file.close()
-else:
-    print('Topic Not Found')
+        doc_ids = [hit.docid.strip() for hit in hits]
+        scores = [hit.score for hit in hits]
+
+        if need_classifier and len(hits) > (args.r + args.n):
+            scores, doc_ids = ranker.rerank(doc_ids, scores)
+
+        tag = output_path[:-4] if args.output is None else 'anserini'
+        for i, (doc_id, score) in enumerate(zip(doc_ids, scores)):
+            target_file.write(
+                f'{topic} Q0 {doc_id} {i + 1} {score:.6f} {tag}\n')
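A hypothetical invocation of the updated script (the script's file name is not shown in this extract; the index path is a placeholder):

    python <script>.py -index /path/to/lucene-index -topics robust04 \
        -prf lr svm -r 10 -n 100 -alpha 0.5

When -output is omitted, the script derives a name such as robust04_lr+svm_A0.5_bm25.txt from the topics name, the classifier list, alpha, and the active search rankers.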
@@ -0,0 +1,123 @@
# -*- coding: utf-8 -*-
#
# Pyserini: Python interface to the Anserini IR toolkit built on Lucene
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import enum
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from typing import List
from ..vectorizer import TfidfVectorizer
import uuid
import os


class ClassifierType(enum.Enum):
    LR = 'lr'
    SVM = 'svm'


class FusionMethod(enum.Enum):
    AVG = 'avg'


class PseudoRelevanceClassifierReranker:
    def __init__(self, lucene_index: str, clf_type: List[ClassifierType], r=10, n=100, alpha=0.5):
        self.r = r
        self.n = n
        self.alpha = alpha
        self.clf_type = clf_type

        if len(clf_type) > 2:
            raise Exception('Re-ranker takes at most two classifiers')

        self.vectorizer = TfidfVectorizer(lucene_index, min_df=5)

    def _set_classifier(self, clf_type: ClassifierType):
        if clf_type == ClassifierType.LR:
            self.clf = LogisticRegression(random_state=42)
        elif clf_type == ClassifierType.SVM:
            self.clf = SVC(kernel='linear', probability=True, random_state=42)
        else:
            raise Exception("Invalid classifier type")

    def _get_prf_vectors(self, doc_ids: List[str]):
        train_docs = doc_ids[:self.r] + doc_ids[-self.n:]
        train_labels = [1] * self.r + [0] * self.n

        train_vecs = self.vectorizer.get_vectors(train_docs)
        test_vecs = self.vectorizer.get_vectors(doc_ids)

        return train_vecs, train_labels, test_vecs

    def _rerank_with_classifier(self, doc_ids: List[str], search_scores: List[float]):
        train_vecs, train_labels, test_vecs = self._get_prf_vectors(doc_ids)

        # classification
        self.clf.fit(train_vecs, train_labels)
        pred = self.clf.predict_proba(test_vecs)
        classifier_scores = self._normalize([p[1] for p in pred])
        search_scores = self._normalize(search_scores)

        # interpolation
        interpolated_scores = [a * self.alpha + b * (1 - self.alpha)
                               for a, b in zip(classifier_scores, search_scores)]

        return self._sort_dual_list(interpolated_scores, doc_ids)

    def rerank(self, doc_ids: List[str], search_scores: List[float]):
        # one classifier
        if len(self.clf_type) == 1:
            self._set_classifier(self.clf_type[0])
            return self._rerank_with_classifier(doc_ids, search_scores)

        # two classifiers with FusionMethod.AVG
        doc_score_dict = {}
        for i in range(2):
            self._set_classifier(self.clf_type[i])
            i_scores, i_doc_ids = self._rerank_with_classifier(
                doc_ids, search_scores)

            for score, doc_id in zip(i_scores, i_doc_ids):
                if doc_id not in doc_score_dict:
                    doc_score_dict[doc_id] = set()
                doc_score_dict[doc_id].add(score)

        r_scores, r_doc_ids = [], []
        for doc_id, score in doc_score_dict.items():
            avg = sum(score) / len(score)
            r_doc_ids.append(doc_id)
            r_scores.append(avg)

        return r_scores, r_doc_ids

    def _normalize(self, scores: List[float]):
        low = min(scores)
        high = max(scores)
        width = high - low

        return [(s - low) / width for s in scores]

    # Sort both lists in decreasing order, using list1 for comparison
    def _sort_dual_list(self, list1, list2):
        zipped_lists = zip(list1, list2)
        sorted_pairs = sorted(zipped_lists)

        tuples = zip(*sorted_pairs)
        list1, list2 = [list(t) for t in tuples]

        list1.reverse()
        list2.reverse()
        return list1, list2
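A minimal usage sketch of the new reranker class (the index path and query are placeholders; the retrieval step mirrors the script above, and the hit list must be longer than r + n):

    from pyserini.search.pysearch import SimpleSearcher
    from pyserini.search.reranker import ClassifierType, PseudoRelevanceClassifierReranker

    searcher = SimpleSearcher('/path/to/lucene-index')
    hits = searcher.search('black bear attacks', 1000)
    doc_ids = [hit.docid.strip() for hit in hits]
    scores = [hit.score for hit in hits]

    ranker = PseudoRelevanceClassifierReranker(
        '/path/to/lucene-index', [ClassifierType.LR], r=10, n=100, alpha=0.5)
    # Note the return order: scores first, then doc ids.
    scores, doc_ids = ranker.rerank(doc_ids, scores)

Passing two classifier types, e.g. [ClassifierType.LR, ClassifierType.SVM], reranks with each classifier in turn and averages the two interpolated scores per document.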
@@ -1,3 +1,3 @@
 pyjnius==1.2.1
 numpy==1.16.4
-scipy==1.4.1
+scikit-learn==0.19.2
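The new scikit-learn pin backs the LogisticRegression and SVC classifiers used by the reranker, and is picked up the usual way:

    pip install -r requirements.txt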
@@ -0,0 +1,84 @@
import argparse
from typing import Dict, List, Set


def get_topics(path: str) -> Set[str]:
    topics = set()
    with open(path, 'r') as f:
        for line in f:
            topic = line.strip().split(' ')[0]
            topics.add(topic)

    return topics


def get_doc_id_dict(topic: str, path: str, res=None) -> Dict[str, List[float]]:
    if res is None:
        res = {}

    with open(path, 'r') as f:
        for line in f:
            tokens = line.strip().split(' ')
            line_topic = tokens[0]
            if topic != line_topic:
                continue

            doc_id = tokens[2]
            score = float(tokens[4])
            if doc_id not in res:
                res[doc_id] = []

            res[doc_id].append(score)

    return res


# Sort both lists in decreasing order, using list1 for comparison
def sort_dual_list(list1, list2):
    zipped_lists = zip(list1, list2)
    sorted_pairs = sorted(zipped_lists)

    tuples = zip(*sorted_pairs)
    list1, list2 = [list(t) for t in tuples]

    list1.reverse()
    list2.reverse()
    return list1, list2


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Perform ensemble average on two Qrun files')
    parser.add_argument('-run1', type=str, required=True,
                        help='the path to the first qrun file')
    parser.add_argument('-run2', type=str, required=True,
                        help='the path to the second qrun file')
    parser.add_argument('-out', type=str, required=True,
                        help='the path to the final output file')
    args = parser.parse_args()

    # get topics
    topics = get_topics(args.run1)
    if topics != get_topics(args.run2):
        print('Topics mismatch')
        exit()

    topics = sorted(list(topics))
    with open(args.out, 'w+') as f:
        for index, topic in enumerate(topics):
            print(f'Topic {topic}: {index+1}/{len(topics)}')
            doc_id_score_dict = get_doc_id_dict(topic, args.run1)
            doc_id_score_dict = get_doc_id_dict(
                topic, args.run2, doc_id_score_dict)

            doc_ids, scores = [], []
            for doc_id, score in doc_id_score_dict.items():
                if len(score) != 2:
                    print(f"[Topic {topic}][id {doc_id}] should have exactly 2 scores")
                avg = sum(score) / len(score)
                scores.append(avg)
                doc_ids.append(doc_id)

            scores, doc_ids = sort_dual_list(scores, doc_ids)
            for i, (score, doc_id) in enumerate(zip(scores, doc_ids)):
                f.write(f"{topic} Q0 {doc_id} {i+1} {score:.6f} ensemble_avg\n")