diff --git a/mteb/evaluation/evaluators/RetrievalEvaluator.py b/mteb/evaluation/evaluators/RetrievalEvaluator.py
index 4b2596c4d5..20a29b3ad5 100644
--- a/mteb/evaluation/evaluators/RetrievalEvaluator.py
+++ b/mteb/evaluation/evaluators/RetrievalEvaluator.py
@@ -268,7 +268,7 @@ def search_cross_encoder(
         for qid in queries.keys():
             if self.previous_results is None:
                 # try to use all of them
-                logging.logging(
+                logging.info(
                     f"previous_results is None. Using all the documents to rerank: {len(corpus)}"
                 )
                 q_results = {doc_id: 0.0 for doc_id in corpus.keys()}
@@ -318,7 +318,9 @@
                 len(queries_in_pair) == len(corpus_in_pair) == len(instructions_in_pair)
             )
 
-            if isinstance(self.model.model, CrossEncoder):
+            if hasattr(self.model, "model") and isinstance(
+                self.model.model, CrossEncoder
+            ):
                 # can't take instructions, so add them here
                 queries_in_pair = [
                     f"{q} {i}".strip()
@@ -428,7 +430,7 @@ def encode(
 
 
 def is_cross_encoder_compatible(model) -> bool:
-    op = getattr(model.model, "predict", None)
+    op = getattr(model, "predict", None)
     return callable(op)
 
 
diff --git a/mteb/models/sentence_transformer_wrapper.py b/mteb/models/sentence_transformer_wrapper.py
index 5cc824fa82..13d39e4031 100644
--- a/mteb/models/sentence_transformer_wrapper.py
+++ b/mteb/models/sentence_transformer_wrapper.py
@@ -53,6 +53,9 @@ def __init__(
             self.model.prompts = model_prompts
         self.model_prompts = self.validate_task_to_prompt_name(model_prompts)
 
+        if isinstance(self.model, CrossEncoder):
+            self.predict = self._predict
+
     def encode(
         self,
         sentences: Sequence[str],
@@ -106,7 +109,7 @@ def encode(
             embeddings = embeddings.cpu().detach().float().numpy()
         return embeddings
 
-    def predict(
+    def _predict(
         self,
         sentences: Sequence[str],
         **kwargs: Any,