Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions mteb/evaluation/evaluators/RetrievalEvaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,15 +307,17 @@ def search_cross_encoder(
assert (
len(queries_in_pair) == len(corpus_in_pair) == len(instructions_in_pair)
)
corpus_in_pair = corpus_to_str(list(corpus_in_pair))

if hasattr(self.model, "model") and isinstance(
self.model.model, CrossEncoder
):
# can't take instructions, so add them here
queries_in_pair = [
f"{q} {i}".strip()
for i, q in zip(instructions_in_pair, queries_in_pair)
]
if instructions_in_pair[0] is not None:
queries_in_pair = [
f"{q} {i}".strip()
for i, q in zip(instructions_in_pair, queries_in_pair)
]
scores = self.model.predict(list(zip(queries_in_pair, corpus_in_pair))) # type: ignore
else:
# may use the instructions in a unique way, so give them also
Expand Down
32 changes: 25 additions & 7 deletions mteb/models/rerankers_custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def __init__(
batch_size: int = 4,
fp_options: bool = None,
silent: bool = False,
**kwargs,
):
self.model_name_or_path = model_name_or_path
self.batch_size = batch_size
Expand All @@ -34,7 +35,7 @@ def __init__(
self.fp_options = torch.float32
elif self.fp_options == "bfloat16":
self.fp_options = torch.bfloat16
print(f"Using fp_options of {self.fp_options}")
logger.info(f"Using fp_options of {self.fp_options}")
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.silent = silent
self.first_print = True # for debugging
Expand Down Expand Up @@ -70,7 +71,12 @@ def __init__(

@torch.inference_mode()
def predict(self, input_to_rerank, **kwargs):
queries, passages, instructions = list(zip(*input_to_rerank))
inputs = list(zip(*input_to_rerank))
if len(input_to_rerank[0]) == 2:
queries, passages = inputs
instructions = None
else:
queries, passages, instructions = inputs
if instructions is not None and instructions[0] is not None:
assert len(instructions) == len(queries)
queries = [f"{q} {i}".strip() for i, q in zip(instructions, queries)]
Expand Down Expand Up @@ -112,7 +118,13 @@ def __init__(

@torch.inference_mode()
def predict(self, input_to_rerank, **kwargs):
queries, passages, instructions = list(zip(*input_to_rerank))
inputs = list(zip(*input_to_rerank))
if len(input_to_rerank[0]) == 2:
queries, passages = inputs
instructions = None
else:
queries, passages, instructions = inputs

if instructions is not None and instructions[0] is not None:
queries = [f"{q} {i}".strip() for i, q in zip(instructions, queries)]

Expand Down Expand Up @@ -152,7 +164,13 @@ def __init__(
)

def predict(self, input_to_rerank, **kwargs):
queries, passages, instructions = list(zip(*input_to_rerank))
inputs = list(zip(*input_to_rerank))
if len(input_to_rerank[0]) == 2:
queries, passages = inputs
instructions = None
else:
queries, passages, instructions = inputs

if instructions is not None and instructions[0] is not None:
queries = [f"{q} {i}".strip() for i, q in zip(instructions, queries)]

Expand All @@ -179,7 +197,7 @@ def loader_inner(**kwargs: Any) -> Encoder:
_loader,
wrapper=MonoBERTReranker,
model_name_or_path="castorini/monobert-large-msmarco",
fp_options="float1616",
fp_options="float16",
),
name="castorini/monobert-large-msmarco",
languages=["eng_Latn"],
Expand All @@ -194,7 +212,7 @@ def loader_inner(**kwargs: Any) -> Encoder:
_loader,
wrapper=JinaReranker,
model_name_or_path="jinaai/jina-reranker-v2-base-multilingual",
fp_options="float1616",
fp_options="float16",
),
name="jinaai/jina-reranker-v2-base-multilingual",
languages=["eng_Latn"],
Expand All @@ -208,7 +226,7 @@ def loader_inner(**kwargs: Any) -> Encoder:
_loader,
wrapper=BGEReranker,
model_name_or_path="BAAI/bge-reranker-v2-m3",
fp_options="float1616",
fp_options="float16",
),
name="BAAI/bge-reranker-v2-m3",
languages=[
Expand Down
15 changes: 13 additions & 2 deletions mteb/models/rerankers_monot5_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,12 @@ def get_prediction_tokens(

@torch.inference_mode()
def predict(self, input_to_rerank, **kwargs):
queries, passages, instructions = list(zip(*input_to_rerank))
inputs = list(zip(*input_to_rerank))
if len(input_to_rerank[0]) == 2:
queries, passages = inputs
instructions = None
else:
queries, passages, instructions = inputs

if instructions is not None and instructions[0] is not None:
queries = [f"{q} {i}".strip() for i, q in zip(instructions, queries)]
Expand Down Expand Up @@ -194,7 +199,13 @@ def __init__(

@torch.inference_mode()
def predict(self, input_to_rerank, **kwargs):
queries, passages, instructions = list(zip(*input_to_rerank))
inputs = list(zip(*input_to_rerank))
if len(input_to_rerank[0]) == 2:
queries, passages = inputs
instructions = None
else:
queries, passages, instructions = inputs

if instructions is not None and instructions[0] is not None:
# logger.info(f"Adding instructions to LLAMA queries")
queries = [
Expand Down
Loading