Skip to content

Commit

Permalink
update hybrid for mrtydi (castorini#739)
Browse files Browse the repository at this point in the history
  • Loading branch information
MXueguang committed Nov 5, 2021
1 parent 211a737 commit 6c14c50
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 15 deletions.
8 changes: 5 additions & 3 deletions pyserini/hsearch/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,9 @@
def define_fusion_args(parser):
parser.add_argument('--alpha', type=float, metavar='num', required=False, default=0.1,
help="alpha for hybrid search")
parser.add_argument('--hits', type=int, required=False, default=10, help='number of hits from dense and sparse')
parser.add_argument('--normalization', action='store_true', required=False, help='hybrid score with normalization')
parser.add_argument('--weight-on-dense', action='store_true', required=False, help='weight on dense part')


def parse_args(parser, commands):
Expand Down Expand Up @@ -160,16 +162,16 @@ def parse_args(parser, commands):
batch_topic_ids = list()
for index, (topic_id, text) in enumerate(tqdm(query_iterator, total=len(topics.keys()))):
if args.run.batch_size <= 1 and args.run.threads <= 1:
hits = hsearcher.search(text, args.run.hits, args.fusion.alpha, args.fusion.normalization)
hits = hsearcher.search(text, args.fusion.hits, args.run.hits, args.fusion.alpha, args.fusion.normalization, args.fusion.weight_on_dense)
results = [(topic_id, hits)]
else:
batch_topic_ids.append(str(topic_id))
batch_topics.append(text)
if (index + 1) % args.run.batch_size == 0 or \
index == len(topics.keys()) - 1:
results = hsearcher.batch_search(
batch_topics, batch_topic_ids, args.run.hits, args.run.threads,
args.fusion.alpha, args.fusion.normalization)
batch_topics, batch_topic_ids, args.fusion.hits, args.run.hits, args.run.threads,
args.fusion.alpha, args.fusion.normalization, args.fusion.weight_on_dense)
results = [(id_, results[id_]) for id_ in batch_topic_ids]
batch_topic_ids.clear()
batch_topics.clear()
Expand Down
22 changes: 11 additions & 11 deletions pyserini/hsearch/_hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,24 +36,24 @@ def __init__(self, dense_searcher, sparse_searcher):
self.dense_searcher = dense_searcher
self.sparse_searcher = sparse_searcher

def search(self, query: str, k: int = 10, alpha: float = 0.1, normalization: bool = False) -> List[DenseSearchResult]:
dense_hits = self.dense_searcher.search(query, k)
sparse_hits = self.sparse_searcher.search(query, k)
return self._hybrid_results(dense_hits, sparse_hits, alpha, k, normalization)
def search(self, query: str, k0: int = 10, k: int = 10, alpha: float = 0.1, normalization: bool = False, weight_on_dense: bool = False) -> List[DenseSearchResult]:
dense_hits = self.dense_searcher.search(query, k0)
sparse_hits = self.sparse_searcher.search(query, k0)
return self._hybrid_results(dense_hits, sparse_hits, alpha, k, normalization, weight_on_dense)

def batch_search(self, queries: List[str], q_ids: List[str], k: int = 10, threads: int = 1,
alpha: float = 0.1, normalization: bool = False) \
def batch_search(self, queries: List[str], q_ids: List[str], k0: int = 10, k: int = 10, threads: int = 1,
alpha: float = 0.1, normalization: bool = False, weight_on_dense: bool = False) \
-> Dict[str, List[DenseSearchResult]]:
dense_result = self.dense_searcher.batch_search(queries, q_ids, k, threads)
sparse_result = self.sparse_searcher.batch_search(queries, q_ids, k, threads)
dense_result = self.dense_searcher.batch_search(queries, q_ids, k0, threads)
sparse_result = self.sparse_searcher.batch_search(queries, q_ids, k0, threads)
hybrid_result = {
key: self._hybrid_results(dense_result[key], sparse_result[key], alpha, k, normalization)
key: self._hybrid_results(dense_result[key], sparse_result[key], alpha, k, normalization, weight_on_dense)
for key in dense_result
}
return hybrid_result

@staticmethod
def _hybrid_results(dense_results, sparse_results, alpha, k, normalization=False):
def _hybrid_results(dense_results, sparse_results, alpha, k, normalization=False, weight_on_dense=False):
dense_hits = {hit.docid: hit.score for hit in dense_results}
sparse_hits = {hit.docid: hit.score for hit in sparse_results}
hybrid_result = []
Expand All @@ -76,6 +76,6 @@ def _hybrid_results(dense_results, sparse_results, alpha, k, normalization=False
/ (max_sparse_score - min_sparse_score)
dense_score = (dense_score - (min_dense_score + max_dense_score) / 2) \
/ (max_dense_score - min_dense_score)
score = alpha * sparse_score + dense_score
score = alpha * sparse_score + dense_score if not weight_on_dense else sparse_score + alpha * dense_score
hybrid_result.append(DenseSearchResult(doc, score))
return sorted(hybrid_result, key=lambda x: x.score, reverse=True)[:k]
2 changes: 1 addition & 1 deletion pyserini/search/_impact_searcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
from typing import Dict, List, Optional, Union
import numpy as np
from ._base import Document
from pyserini.index import IndexReader
from pyserini.pyclass import autoclass, JFloat, JArrayList, JHashMap, JString
from pyserini.util import download_prebuilt_index
from pyserini.encode import QueryEncoder, TokFreqQueryEncoder, UniCoilQueryEncoder, CachedDataQueryEncoder
Expand Down Expand Up @@ -230,6 +229,7 @@ def _init_query_encoder_from_str(query_encoder):

@staticmethod
def _compute_idf(index_path):
from pyserini.index import IndexReader
index_reader = IndexReader(index_path)
tokens = []
dfs = []
Expand Down

0 comments on commit 6c14c50

Please sign in to comment.