mteb/evaluation/evaluators/Image/Any2AnyMultiChoiceEvaluator.py (14 changes: 11 additions & 3 deletions)
@@ -57,7 +57,8 @@ def __getitem__(self, idx):
         image = image
         if image.mode != "RGB":
             image = image.convert("RGB")
-        image = self.transform(image)
+        if self.transform is not None:
+            image = self.transform(image)
         return image


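Note on this hunk: with the None check, the dataset can be constructed without a transform and will simply return the RGB-converted PIL image, whereas previously a None transform would raise a TypeError when called. A minimal sketch of the resulting pattern; the constructor below is invented for illustration, and only __getitem__ mirrors the hunk:

    # Sketch of the guarded-transform pattern, assuming a PIL-based dataset.
    from PIL import Image

    class SketchImageDataset:
        def __init__(self, images: list[Image.Image], transform=None):
            self.images = images
            self.transform = transform  # may now be None

        def __getitem__(self, idx):
            image = self.images[idx]
            if image.mode != "RGB":
                image = image.convert("RGB")
            if self.transform is not None:  # the fix: skip when no transform is set
                image = self.transform(image)
            return image

Callers that do pass a transform are unaffected, since a non-None transform is still applied unconditionally.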
@@ -105,6 +106,7 @@ def search(
         qrels: Dataset,
         top_k: int,
         score_function: str,
+        task_name: str | None = None,
         return_sorted: bool = False,
         **kwargs,
     ) -> dict[str, dict[str, float]]:
@@ -122,7 +124,9 @@
         if q_modality == "text":
             query_texts = queries["text"]
             query_embeddings = self.model.get_text_embeddings(
-                texts=query_texts, batch_size=self.encode_kwargs["batch_size"]
+                texts=query_texts,
+                task_name=task_name,
+                batch_size=self.encode_kwargs["batch_size"],
             )
         else:
             queries_dataset = ImageDataset(
@@ -139,13 +143,15 @@
             query_embeddings = self.model.get_image_embeddings(
                 images=query_image_dataloader,
                 batch_size=self.encode_kwargs["batch_size"],
+                task_name=task_name,
             )
         elif q_modality == "image,text":
             query_texts = queries["text"]
             query_embeddings = self.model.get_fused_embeddings(
                 texts=query_texts,
                 images=query_image_dataloader,
                 batch_size=self.encode_kwargs["batch_size"],
+                task_name=task_name,
             )
         else:
             raise ValueError(f"Unsupported modality: {q_modality}")
@@ -189,13 +195,15 @@ def search(
                 sub_corpus_embeddings = self.model.get_image_embeddings(
                     images=corpus_image_dataloader,
                     batch_size=self.encode_kwargs["batch_size"],
+                    task_name=task_name,
                 )
             elif corpus_modality == "image,text":
                 corpus_texts = chunk["text"]
                 sub_corpus_embeddings = self.model.get_fused_embeddings(
                     texts=corpus_texts,
                     images=corpus_image_dataloader,
                     batch_size=self.encode_kwargs["batch_size"],
+                    task_name=task_name,
                 )
             else:
                 raise ValueError(f"Unsupported modality: {corpus_modality}")
@@ -301,7 +309,7 @@ def __call__(
             qrels,
             self.top_k,
             self.score_function,
-            prompt_name=self.task_name,  # type: ignore
+            task_name=self.task_name,  # type: ignore
         )
 
     @staticmethod
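Taken together, the changes in this file thread task_name from the evaluator's __call__ (which previously passed prompt_name) through search and into every embedding call, for text, image, and fused queries and corpus chunks alike. A model wrapper compatible with these call sites might look like the sketch below; the method names and keyword arguments mirror the diff, while the bodies are placeholders rather than mteb's real implementation:

    import numpy as np

    class SketchModelWrapper:
        # Illustrative only; not mteb's actual model-wrapper API.

        def get_text_embeddings(self, texts, task_name=None, batch_size=32, **kwargs):
            # task_name can now select a task-specific prompt or instruction
            return np.zeros((len(texts), 512))

        def get_image_embeddings(self, images, task_name=None, batch_size=32, **kwargs):
            # the evaluator passes a DataLoader; a real wrapper would encode its batches
            n = sum(len(batch) for batch in images)
            return np.zeros((n, 512))

        def get_fused_embeddings(self, texts=None, images=None, task_name=None, batch_size=32, **kwargs):
            return np.zeros((len(texts), 512))

Because the new parameter defaults to None, existing callers of search and of the embedding methods keep working unchanged.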
@@ -134,11 +134,15 @@ def __call__(
             images = [img for images in images_list for img in images]
             texts = [txt for texts in texts_list for txt in texts]
             images_emb = F.normalize(
-                model.get_image_embeddings(images, batch_size=len(images)),
+                model.get_image_embeddings(
+                    images, batch_size=len(images), task_name=self.task_name
+                ),
                 dim=-1,
             ).view(len(batch), num_images_per_sample, -1)
             texts_emb = F.normalize(
-                model.get_text_embeddings(texts, batch_size=len(texts)),
+                model.get_text_embeddings(
+                    texts, batch_size=len(texts), task_name=self.task_name
+                ),
                 dim=-1,
             ).view(len(batch), num_texts_per_sample, -1)
             for i in range(len(batch)):
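The second file's hunk makes the same kind of change (forwarding task_name). For reference, the surrounding normalize-and-reshape step embeds the flattened batch once, L2-normalizes, and restores a (batch, per-sample, dim) layout so per-sample similarities can be computed. A toy shape walk-through with invented sizes:

    import torch
    import torch.nn.functional as F

    batch_size, num_images_per_sample, dim = 4, 2, 8
    flat = torch.randn(batch_size * num_images_per_sample, dim)  # stand-in for model output
    images_emb = F.normalize(flat, dim=-1).view(batch_size, num_images_per_sample, -1)
    assert images_emb.shape == (4, 2, 8)
    # rows are unit vectors, so later dot products are cosine similarities
    assert torch.allclose(images_emb.norm(dim=-1), torch.ones(batch_size, num_images_per_sample))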
tests/test_benchmark/test_benchmark.py (28 changes: 27 additions & 1 deletion)
@@ -35,7 +35,7 @@
     MockRerankingTask,
     MockRetrievalTask,
 )
-from .task_grid import MOCK_TASK_TEST_GRID
+from .task_grid import MOCK_MIEB_TASK_GRID, MOCK_TASK_TEST_GRID
 
 logging.basicConfig(level=logging.INFO)
 
@@ -175,6 +175,32 @@ def encode(self, sentences, prompt_name: str | None = None, **kwargs):
 )
 
 
+@pytest.mark.parametrize("task_name", MOCK_TASK_TEST_GRID + MOCK_MIEB_TASK_GRID)
+def test_task_name_passed_encoder(task_name: mteb.AbsTask, tmp_path: Path):
+    """Test that all tasks correctly pass down the task_name to the encoder."""
+    _task_name = (
+        task_name.metadata.name if isinstance(task_name, mteb.AbsTask) else task_name
+    )
+
+    class MockEncoderWithInstructions(mteb.Encoder):
+        def encode(self, sentences, task_name: str | None = None, **kwargs):
+            assert task_name == _task_name
+            return np.zeros((len(sentences), 10))
+
+    if isinstance(task_name, mteb.AbsTask):
+        tasks = [task_name]
+    else:
+        tasks = mteb.get_tasks(tasks=[task_name])
+
+    eval = mteb.MTEB(tasks=tasks)
+
+    eval.run(
+        MockEncoderWithInstructions(),
+        output_folder=tmp_path.as_posix(),
+        overwrite_results=True,
+    )
+
+
 @pytest.mark.parametrize("model", [MockNumpyEncoder()])
 def test_run_using_benchmark(model: mteb.Encoder, tmp_path: Path):
     """Test that a benchmark object can be run using the MTEB class."""
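The new test parametrizes over both the standard mock tasks and the MIEB (image) mock tasks, asserting that each task's metadata.name reaches the encoder. Downstream, this is what lets instruction-style encoders condition on the task; a hypothetical example, where the prompt table is invented and only the encode signature comes from the test above:

    import numpy as np
    import mteb

    class PromptedEncoder(mteb.Encoder):
        PROMPTS = {"MockRetrievalTask": "query: "}  # invented mapping for illustration

        def encode(self, sentences, task_name: str | None = None, **kwargs):
            prefix = self.PROMPTS.get(task_name, "")
            sentences = [prefix + s for s in sentences]
            return np.zeros((len(sentences), 10))  # placeholder embeddings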