Pseudo Relevance Feedback Interpolation on pysearch (#123)
Initial implementation. Remaining issues to be addressed - punting for now.
x65han authored May 22, 2020
1 parent 481d323 commit 9487191
Showing 5 changed files with 288 additions and 20 deletions.
94 changes: 77 additions & 17 deletions pyserini/search/__main__.py
@@ -1,33 +1,93 @@
# -*- coding: utf-8 -*-
#
# Pyserini: Python interface to the Anserini IR toolkit built on Lucene
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import re

from pyserini.search.pysearch import get_topics, SimpleSearcher
from pyserini.search.reranker import ClassifierType, PseudoRelevanceClassifierReranker

parser = argparse.ArgumentParser(description='Create a input schema')
parser.add_argument('-index', metavar='path', required=True,
                    help='the path to workspace')
parser.add_argument('-topics', metavar='topicsname', required=True,
                    help='topicsname')
parser.add_argument('-output', metavar='path',
                    help='path to the output file')
parser.add_argument('-bm25', action='store_true', default=True,
                    help='use bm25 ranker')
parser.add_argument('-rm3', action='store_true',
                    help='use rm3 ranker')
parser.add_argument('-qld', action='store_true',
                    help='use qld ranker')
parser.add_argument('-prf', type=ClassifierType, nargs='+',
                    help='use pseudo relevance feedback ranker')
parser.add_argument('-r', type=int, default=10,
                    help='number of positive labels in pseudo relevance feedback')
parser.add_argument('-n', type=int, default=100,
                    help='number of negative labels in pseudo relevance feedback')
parser.add_argument('-alpha', type=float, default=0.5,
                    help='alpha value for interpolation in pseudo relevance feedback')
args = parser.parse_args()

searcher = SimpleSearcher(args.index)
topics_dic = get_topics(args.topics)
search_rankers = ['bm25']
if args.rm3:
    search_rankers.append('rm3')
    searcher.set_rm3()
if args.qld:
    search_rankers.append('qld')
    searcher.set_qld()

if topics_dic == {}:
    print('Topic Not Found')
    exit()

output_path = args.output
if output_path is None:
    clf_rankers = []
    for t in args.prf:
        if t == ClassifierType.LR:
            clf_rankers.append('lr')
        elif t == ClassifierType.SVM:
            clf_rankers.append('svm')

    tokens = [args.topics, '+'.join(clf_rankers),
              f'A{args.alpha}', '+'.join(search_rankers)]
    output_path = '_'.join(tokens) + ".txt"

need_classifier = args.prf and len(args.prf) > 0 and args.alpha > 0
if need_classifier is True:
    ranker = PseudoRelevanceClassifierReranker(
        args.index, args.prf, r=args.r, n=args.n, alpha=args.alpha)

print('Output ->', output_path)
with open(output_path, 'w') as target_file:
    for index, topic in enumerate(sorted(topics_dic.keys())):
        print(f'Topic {topic}: {index + 1}/{len(topics_dic)}')
        search = topics_dic[topic].get('title')
        hits = searcher.search(search, 1000)
        doc_ids = [hit.docid.strip() for hit in hits]
        scores = [hit.score for hit in hits]

        if need_classifier and len(hits) > (args.r + args.n):
            scores, doc_ids = ranker.rerank(doc_ids, scores)

        tag = output_path[:-4] if args.output is None else 'anserini'
        for i, (doc_id, score) in enumerate(zip(doc_ids, scores)):
            target_file.write(
                f'{topic} Q0 {doc_id} {i + 1} {score:.6f} {tag}\n')
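
For reference, a run exercising the new pseudo relevance feedback options might look like the following; the index path and topic set name are placeholders rather than values taken from this commit:

python -m pyserini.search -index lucene-index.robust04 -topics robust04 \
    -prf lr svm -r 10 -n 100 -alpha 0.5 -rm3 -output run.bm25+rm3.prf.txt

If -output is omitted, the script derives a file name from the topic set, the classifier list, alpha, and the active search rankers.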
123 changes: 123 additions & 0 deletions pyserini/search/reranker.py
@@ -0,0 +1,123 @@
# -*- coding: utf-8 -*-
#
# Pyserini: Python interface to the Anserini IR toolkit built on Lucene
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import enum
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from typing import List
from ..vectorizer import TfidfVectorizer
import uuid
import os


class ClassifierType(enum.Enum):
    LR = 'lr'
    SVM = 'svm'


class FusionMethod(enum.Enum):
    AVG = 'avg'


class PseudoRelevanceClassifierReranker:
    def __init__(self, lucene_index: str, clf_type: List[ClassifierType], r=10, n=100, alpha=0.5):
        self.r = r
        self.n = n
        self.alpha = alpha
        self.clf_type = clf_type

        if len(clf_type) > 2:
            raise Exception('Re-ranker takes at most two classifiers')

        self.vectorizer = TfidfVectorizer(lucene_index, min_df=5)

    def _set_classifier(self, clf_type: ClassifierType):
        if clf_type == ClassifierType.LR:
            self.clf = LogisticRegression(random_state=42)
        elif clf_type == ClassifierType.SVM:
            self.clf = SVC(kernel='linear', probability=True, random_state=42)
        else:
            raise Exception("Invalid classifier type")

    def _get_prf_vectors(self, doc_ids: List[str]):
        train_docs = doc_ids[:self.r] + doc_ids[-self.n:]
        train_labels = [1] * self.r + [0] * self.n

        train_vecs = self.vectorizer.get_vectors(train_docs)
        test_vecs = self.vectorizer.get_vectors(doc_ids)

        return train_vecs, train_labels, test_vecs

    def _rerank_with_classifier(self, doc_ids: List[str], search_scores: List[float]):
        train_vecs, train_labels, test_vecs = self._get_prf_vectors(
            doc_ids)

        # classification
        self.clf.fit(train_vecs, train_labels)
        pred = self.clf.predict_proba(test_vecs)
        classifier_scores = self._normalize([p[1] for p in pred])
        search_scores = self._normalize(search_scores)

        # interpolation
        interpolated_scores = [a * self.alpha + b * (1-self.alpha)
                               for a, b in zip(classifier_scores, search_scores)]

        return self._sort_dual_list(interpolated_scores, doc_ids)

    def rerank(self, doc_ids: List[str], search_scores: List[float]):
        # one classifier
        if len(self.clf_type) == 1:
            self._set_classifier(self.clf_type[0])
            return self._rerank_with_classifier(doc_ids, search_scores)

        # two classifiers with FusionMethod.AVG
        doc_score_dict = {}
        for i in range(2):
            self._set_classifier(self.clf_type[i])
            i_scores, i_doc_ids = self._rerank_with_classifier(
                doc_ids, search_scores)

            for score, doc_id in zip(i_scores, i_doc_ids):
                if doc_id not in doc_score_dict:
                    doc_score_dict[doc_id] = set()
                doc_score_dict[doc_id].add(score)

        r_scores, r_doc_ids = [], []
        for doc_id, score in doc_score_dict.items():
            avg = sum(score) / len(score)
            r_doc_ids.append(doc_id)
            r_scores.append(avg)

        return r_scores, r_doc_ids

    def _normalize(self, scores: List[float]):
        low = min(scores)
        high = max(scores)
        width = high - low

        return [(s-low)/width for s in scores]

    # sort both lists in decreasing order, using list1 to compare
    def _sort_dual_list(self, list1, list2):
        zipped_lists = zip(list1, list2)
        sorted_pairs = sorted(zipped_lists)

        tuples = zip(*sorted_pairs)
        list1, list2 = [list(tuple) for tuple in tuples]

        list1.reverse()
        list2.reverse()
        return list1, list2
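
As a rough usage sketch of the class above (the index path and query are placeholders; the candidate list is assumed to come from SimpleSearcher, mirroring __main__.py):

from pyserini.search.pysearch import SimpleSearcher
from pyserini.search.reranker import ClassifierType, PseudoRelevanceClassifierReranker

searcher = SimpleSearcher('lucene-index.robust04')  # placeholder index path
hits = searcher.search('international organized crime', 1000)  # placeholder query
doc_ids = [hit.docid.strip() for hit in hits]
scores = [hit.score for hit in hits]

# Top r hits act as pseudo-positive labels, bottom n as pseudo-negative labels;
# alpha interpolates the classifier score with the normalized retrieval score.
reranker = PseudoRelevanceClassifierReranker(
    'lucene-index.robust04', [ClassifierType.LR, ClassifierType.SVM],
    r=10, n=100, alpha=0.5)
if len(hits) > 10 + 100:
    scores, doc_ids = reranker.rerank(doc_ids, scores)

With two classifiers, rerank re-scores the candidates once per classifier and averages the two scores per document (FusionMethod.AVG).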
5 changes: 3 additions & 2 deletions pyserini/vectorizer.py
@@ -38,9 +38,10 @@ def l2norm(self, a):

    def get_vectors(self, doc_ids: List[str]):
        matrix_row, matrix_col, matrix_data = [], [], []
        num_docs = len(doc_ids)

        for index, doc_id in enumerate(doc_ids):
            if index % 1000 == 0 and num_docs > 1000:
                print(f'Vectorizing: {index}/{len(doc_ids)}')

            # Term Frequency
@@ -59,5 +60,5 @@ def get_vectors(self, doc_ids: List[str]):
                matrix_data.append(tfidf)

        vectors = csr_matrix((matrix_data, (matrix_row, matrix_col)), shape=(
            num_docs, self.vocabulary_size))
        return self.l2norm(vectors)
2 changes: 1 addition & 1 deletion requirements.txt
@@ -1,3 +1,3 @@
pyjnius==1.2.1
numpy==1.16.4
scipy==1.4.1
scikit-learn==0.19.2
84 changes: 84 additions & 0 deletions scripts/ensemble_average.py
@@ -0,0 +1,84 @@
import argparse
from typing import Dict, List, Set


def get_topics(path: str) -> Set[str]:
    topics = set()
    with open(path, 'r') as f:
        for line in f:
            topic = line.strip().split(' ')[0]
            topics.add(topic)

    return topics


def get_doc_id_dict(topic: str, path: str, res=None) -> Dict[str, List[float]]:
    if res is None:
        res = {}

    with open(path, 'r') as f:
        for line in f:
            tokens = line.strip().split(' ')
            line_topic = tokens[0]
            if topic != line_topic:
                continue

            doc_id = tokens[2]
            score = float(tokens[4])
            if doc_id not in res:
                res[doc_id] = []

            res[doc_id].append(score)

    return res


# sort both lists in decreasing order, using list1 to compare
def sort_dual_list(list1, list2):
    zipped_lists = zip(list1, list2)
    sorted_pairs = sorted(zipped_lists)

    tuples = zip(*sorted_pairs)
    list1, list2 = [list(tuple) for tuple in tuples]

    list1.reverse()
    list2.reverse()
    return list1, list2


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Perform ensemble average on two Qrun files')
    parser.add_argument('-run1', type=str, required=True,
                        help='the path to the first qrun file')
    parser.add_argument('-run2', type=str, required=True,
                        help='the path to the second qrun file')
    parser.add_argument('-out', type=str, required=True,
                        help='the path to the final output file')
    args = parser.parse_args()

    # get topics
    topics = get_topics(args.run1)
    if topics != get_topics(args.run2):
        print('Topics mismatch')
        exit()

    topics = sorted(list(topics))
    with open(args.out, 'w+') as f:
        for index, topic in enumerate(topics):
            print(f'Topic {topic}: {index+1}/{len(topics)}')
            doc_id_score_dict = get_doc_id_dict(topic, args.run1)
            doc_id_score_dict = get_doc_id_dict(
                topic, args.run2, doc_id_score_dict)

            doc_ids, scores = [], []
            for doc_id, score in doc_id_score_dict.items():
                if len(score) != 2:
                    print(f"[Topic {topic}][id {doc_id}] should have exactly 2 scores")
                avg = sum(score) / len(score)
                scores.append(avg)
                doc_ids.append(doc_id)

            scores, doc_ids = sort_dual_list(scores, doc_ids)
            for i, (score, doc_id) in enumerate(zip(scores, doc_ids)):
                f.write(f"{topic} Q0 {doc_id} {i+1} {score:.6f} ensemble_avg\n")
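
A possible invocation, with hypothetical run file names (any two runs over the same topics in the TREC format written by pyserini/search/__main__.py would do):

python scripts/ensemble_average.py -run1 run.robust04.lr.txt \
    -run2 run.robust04.svm.txt -out run.robust04.lr+svm.avg.txt

Each input line is parsed positionally as topic Q0 docid rank score tag, and the averaged output is written in the same format with the tag ensemble_avg.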
