diff --git a/mteb/evaluation/evaluators/Image/Any2AnyMultiChoiceEvaluator.py b/mteb/evaluation/evaluators/Image/Any2AnyMultiChoiceEvaluator.py
index 5fdbb112f3..c69f0153a2 100644
--- a/mteb/evaluation/evaluators/Image/Any2AnyMultiChoiceEvaluator.py
+++ b/mteb/evaluation/evaluators/Image/Any2AnyMultiChoiceEvaluator.py
@@ -57,7 +57,8 @@ def __getitem__(self, idx):
             image = image
         if image.mode != "RGB":
             image = image.convert("RGB")
-        image = self.transform(image)
+        if self.transform is not None:
+            image = self.transform(image)
         return image


@@ -105,6 +106,7 @@ def search(
         qrels: Dataset,
         top_k: int,
         score_function: str,
+        task_name: str | None = None,
         return_sorted: bool = False,
         **kwargs,
     ) -> dict[str, dict[str, float]]:
@@ -122,7 +124,9 @@ def search(
         if q_modality == "text":
             query_texts = queries["text"]
             query_embeddings = self.model.get_text_embeddings(
-                texts=query_texts, batch_size=self.encode_kwargs["batch_size"]
+                texts=query_texts,
+                task_name=task_name,
+                batch_size=self.encode_kwargs["batch_size"],
             )
         else:
             queries_dataset = ImageDataset(
@@ -139,6 +143,7 @@ def search(
                 query_embeddings = self.model.get_image_embeddings(
                     images=query_image_dataloader,
                     batch_size=self.encode_kwargs["batch_size"],
+                    task_name=task_name,
                 )
             elif q_modality == "image,text":
                 query_texts = queries["text"]
@@ -146,6 +151,7 @@ def search(
                     texts=query_texts,
                     images=query_image_dataloader,
                     batch_size=self.encode_kwargs["batch_size"],
+                    task_name=task_name,
                 )
             else:
                 raise ValueError(f"Unsupported modality: {q_modality}")
@@ -189,6 +195,7 @@ def search(
                 sub_corpus_embeddings = self.model.get_image_embeddings(
                     images=corpus_image_dataloader,
                     batch_size=self.encode_kwargs["batch_size"],
+                    task_name=task_name,
                 )
             elif corpus_modality == "image,text":
                 corpus_texts = chunk["text"]
@@ -196,6 +203,7 @@ def search(
                     texts=corpus_texts,
                     images=corpus_image_dataloader,
                     batch_size=self.encode_kwargs["batch_size"],
+                    task_name=task_name,
                 )
             else:
                 raise ValueError(f"Unsupported modality: {corpus_modality}")
@@ -301,7 +309,7 @@ def __call__(
             qrels,
             self.top_k,
             self.score_function,
-            prompt_name=self.task_name,  # type: ignore
+            task_name=self.task_name,  # type: ignore
         )

     @staticmethod
diff --git a/mteb/evaluation/evaluators/Image/ImageTextPairClassificationEvaluator.py b/mteb/evaluation/evaluators/Image/ImageTextPairClassificationEvaluator.py
index 7e3d84bb87..f3188f7753 100644
--- a/mteb/evaluation/evaluators/Image/ImageTextPairClassificationEvaluator.py
+++ b/mteb/evaluation/evaluators/Image/ImageTextPairClassificationEvaluator.py
@@ -134,11 +134,15 @@ def __call__(
             images = [img for images in images_list for img in images]
             texts = [txt for texts in texts_list for txt in texts]
             images_emb = F.normalize(
-                model.get_image_embeddings(images, batch_size=len(images)),
+                model.get_image_embeddings(
+                    images, batch_size=len(images), task_name=self.task_name
+                ),
                 dim=-1,
             ).view(len(batch), num_images_per_sample, -1)
             texts_emb = F.normalize(
-                model.get_text_embeddings(texts, batch_size=len(texts)),
+                model.get_text_embeddings(
+                    texts, batch_size=len(texts), task_name=self.task_name
+                ),
                 dim=-1,
             ).view(len(batch), num_texts_per_sample, -1)
             for i in range(len(batch)):
diff --git a/tests/test_benchmark/test_benchmark.py b/tests/test_benchmark/test_benchmark.py
index d7357664fe..826610657f 100644
--- a/tests/test_benchmark/test_benchmark.py
+++ b/tests/test_benchmark/test_benchmark.py
@@ -35,7 +35,7 @@
     MockRerankingTask,
     MockRetrievalTask,
 )
-from .task_grid import MOCK_TASK_TEST_GRID
+from .task_grid import MOCK_MIEB_TASK_GRID, MOCK_TASK_TEST_GRID

 logging.basicConfig(level=logging.INFO)

@@ -175,6 +175,32 @@ def encode(self, sentences, prompt_name: str | None = None, **kwargs):
         )


+@pytest.mark.parametrize("task_name", MOCK_TASK_TEST_GRID + MOCK_MIEB_TASK_GRID)
+def test_task_name_passed_encoder(task_name: mteb.AbsTask, tmp_path: Path):
+    """Test that all tasks correctly pass down the task_name to the encoder."""
+    _task_name = (
+        task_name.metadata.name if isinstance(task_name, mteb.AbsTask) else task_name
+    )
+
+    class MockEncoderWithInstructions(mteb.Encoder):
+        def encode(self, sentences, task_name: str | None = None, **kwargs):
+            assert task_name == _task_name
+            return np.zeros((len(sentences), 10))
+
+    if isinstance(task_name, mteb.AbsTask):
+        tasks = [task_name]
+    else:
+        tasks = mteb.get_tasks(tasks=[task_name])
+
+    eval = mteb.MTEB(tasks=tasks)
+
+    eval.run(
+        MockEncoderWithInstructions(),
+        output_folder=tmp_path.as_posix(),
+        overwrite_results=True,
+    )
+
+
 @pytest.mark.parametrize("model", [MockNumpyEncoder()])
 def test_run_using_benchmark(model: mteb.Encoder, tmp_path: Path):
     """Test that a benchmark object can be run using the MTEB class."""