10 changes: 8 additions & 2 deletions mteb/abstasks/TaskMetadata.py
@@ -98,9 +98,15 @@
     "Summarization",
     "InstructionRetrieval",
     "Speed",
-    "ZeroShotClassification",
-    "ImageTextPairClassification",
+    "Any2AnyMultiChoice",
+    "Any2AnyRetrieval",
+    "Any2TextMutipleChoice",
+    "ImageClustering",
+    "ImageClassification",
+    "ImageMultilabelClassification",
+    "ImageTextPairClassification",
+    "VisualSTS",
+    "ZeroShotClassification",
 ]

 TASK_CATEGORY = Literal[
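For context on why these entries matter: TaskMetadata is a pydantic model, so its `type` field is validated against this Literal, and the task files further down can only declare the new image task types once they are listed here. A minimal sketch of the mechanism, using a stand-in model for illustration rather than mteb's actual TaskMetadata:

from typing import Literal

from pydantic import BaseModel, ValidationError

# Stand-in Literal for illustration; the real TASK_TYPE is the one above.
TASK_TYPE = Literal["Any2AnyRetrieval", "Any2AnyMultiChoice", "VisualSTS"]


class Meta(BaseModel):
    type: TASK_TYPE


Meta(type="Any2AnyRetrieval")  # accepted: the value is in the Literal
try:
    Meta(type="NotATaskType")  # rejected: not a permitted literal value
except ValidationError as err:
    print(err)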
70 changes: 51 additions & 19 deletions mteb/models/e5_v.py
@@ -5,6 +5,7 @@

 import torch
 from PIL import Image
+from torch.utils.data import DataLoader
 from tqdm import tqdm
 from transformers import LlavaNextForConditionalGeneration, LlavaNextProcessor

@@ -56,9 +57,23 @@ def get_text_embeddings(self, texts: list[str], batch_size: int = 8):
             all_text_embeddings.append(text_outputs.cpu())
         return torch.cat(all_text_embeddings, dim=0)

-    def get_image_embeddings(self, images: list[Image.Image], batch_size: int = 8):
+    def get_image_embeddings(
+        self, images: list[Image.Image] | DataLoader, batch_size: int = 8
+    ):
         all_image_embeddings = []

+        if isinstance(images, DataLoader):
+            for batch_images in tqdm(images):
+                img_inputs = self.processor(
+                    [self.img_prompt] * len(batch_images),
+                    batch_images,
+                    return_tensors="pt",
+                    padding=True,
+                ).to("cuda")
+                image_outputs = self.model(
+                    **img_inputs, output_hidden_states=True, return_dict=True
+                ).hidden_states[-1][:, -1, :]
+                all_image_embeddings.append(image_outputs.cpu())
         with torch.no_grad():
             for i in tqdm(range(0, len(images), batch_size)):
                 batch_images = images[i : i + batch_size]
@@ -95,24 +110,41 @@ def get_fused_embeddings(
         all_fused_embeddings = []

         if texts is not None and images is not None:
-            if len(texts) != len(images):
-                raise ValueError(
-                    "The number of texts and images must have the same length"
-                )
-            with torch.no_grad():
-                for i in tqdm(range(0, len(images), batch_size)):
-                    batch_texts = texts[i : i + batch_size]
-                    batch_images = images[i : i + batch_size]
-                    prompts = [
-                        self.composed_prompt.format(text) for text in batch_texts
-                    ]
-                    inputs = self.processor(
-                        prompts, batch_images, return_tensors="pt", padding=True
-                    ).to("cuda")
-                    outputs = self.model(
-                        **inputs, output_hidden_states=True, return_dict=True
-                    ).hidden_states[-1][:, -1, :]
-                    all_fused_embeddings.append(outputs.cpu())
+            if isinstance(images, DataLoader):
+                with torch.no_grad():
+                    for index, batch_images in enumerate(tqdm(images)):
+                        batch_texts = texts[
+                            index * batch_size : (index + 1) * batch_size
+                        ]
+                        prompts = [
+                            self.composed_prompt.format(text) for text in batch_texts
+                        ]
+                        inputs = self.processor(
+                            prompts, batch_images, return_tensors="pt", padding=True
+                        ).to("cuda")
+                        outputs = self.model(
+                            **inputs, output_hidden_states=True, return_dict=True
+                        ).hidden_states[-1][:, -1, :]
+                        all_fused_embeddings.append(outputs.cpu())
+            else:
+                if len(texts) != len(images):
+                    raise ValueError(
+                        "The number of texts and images must have the same length"
+                    )
+                with torch.no_grad():
+                    for i in tqdm(range(0, len(images), batch_size)):
+                        batch_texts = texts[i : i + batch_size]
+                        batch_images = images[i : i + batch_size]
+                        prompts = [
+                            self.composed_prompt.format(text) for text in batch_texts
+                        ]
+                        inputs = self.processor(
+                            prompts, batch_images, return_tensors="pt", padding=True
+                        ).to("cuda")
+                        outputs = self.model(
+                            **inputs, output_hidden_states=True, return_dict=True
+                        ).hidden_states[-1][:, -1, :]
+                        all_fused_embeddings.append(outputs.cpu())
             return torch.cat(all_fused_embeddings, dim=0)

         elif texts is not None:
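For context, a minimal usage sketch of the two input forms get_image_embeddings and get_fused_embeddings now accept; the loaded `model` instance and the dummy images are assumptions for illustration, not part of this diff:

from PIL import Image
from torch.utils.data import DataLoader

images = [Image.new("RGB", (336, 336)) for _ in range(16)]  # dummy inputs
texts = [f"a photo, number {i}" for i in range(16)]

# Plain list: the method slices it into chunks of `batch_size` itself.
list_embeddings = model.get_image_embeddings(images, batch_size=8)

# DataLoader: batches arrive pre-formed. In the fused path the texts are
# sliced with `batch_size`, so it must match the loader's own batch size
# for text and image batches to stay aligned.
loader = DataLoader(images, batch_size=8, collate_fn=lambda batch: batch)
fused_embeddings = model.get_fused_embeddings(
    texts=texts, images=loader, batch_size=8
)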
75 changes: 54 additions & 21 deletions mteb/models/vista_models.py
@@ -1,13 +1,18 @@
 from __future__ import annotations

 from functools import partial
+from typing import Any

 import torch
 from PIL import Image
+from torch.utils.data import DataLoader
+from torchvision import transforms
 from tqdm import tqdm

 from mteb.model_meta import ModelMeta

+tensor_to_image = transforms.Compose([transforms.ToPILImage()])
+

 def vista_loader(**kwargs):
     try:  # a temporal fix for the dependency issues of vista models.
@@ -27,6 +32,7 @@ def __init__(
         negatives_cross_device: bool = False,
         temperature: float = 0.02,
         from_pretrained=None,
+        **kwargs: Any,
     ):
         super().__init__(
             model_name_bge=model_name_bge,
@@ -88,15 +94,21 @@ def encode_text(self, texts):
         t_reps = torch.nn.functional.normalize(t_reps, dim=-1)
         return t_reps.contiguous()

-    def encode(self, images=None, texts=None):
+    def encode(self, images=None, texts=None, tensors=False):
         if images is not None:
             if isinstance(images, list):
-                images = [
-                    self.preprocess_val(
-                        img if isinstance(img, Image.Image) else Image.open(img)
-                    )
-                    for img in images
-                ]
+                if not tensors:
+                    images = [
+                        self.preprocess_val(
+                            img if isinstance(img, Image.Image) else Image.open(img)
+                        )
+                        for img in images
+                    ]
+                else:
+                    images = [
+                        self.preprocess_val(tensor_to_image(image))
+                        for image in images
+                    ]
             images = torch.stack(images)
         if texts is not None:
             texts = self.tokenizer(texts, return_tensors="pt", padding=True)
@@ -119,31 +131,52 @@ def get_text_embeddings(self, texts: list[str], batch_size: int = 32):
             all_text_embeddings.append(batch_embeddings.cpu())
         return torch.cat(all_text_embeddings, dim=0)

-    def get_image_embeddings(self, images: list[Image.Image], batch_size: int = 32):
+    def get_image_embeddings(
+        self, images: list[Image.Image] | DataLoader, batch_size: int = 32
+    ):
         all_image_embeddings = []
-        for i in tqdm(range(0, len(images), batch_size)):
-            batch_images = images[i : i + batch_size]
+
+        if isinstance(images, DataLoader):
             with torch.no_grad():
-                batch_embeddings = self.encode(images=batch_images)
-                all_image_embeddings.append(batch_embeddings.cpu())
+                for batch in tqdm(images):
+                    batch_embeddings = self.encode(images=batch, tensors=True)
+                    all_image_embeddings.append(batch_embeddings.cpu())
+        else:
+            with torch.no_grad():
+                for i in tqdm(range(0, len(images), batch_size)):
+                    batch_images = images[i : i + batch_size]
+                    batch_embeddings = self.encode(images=batch_images)
+                    all_image_embeddings.append(batch_embeddings.cpu())
         return torch.cat(all_image_embeddings, dim=0)

     def get_fused_embeddings(
         self,
         texts: list[str] = None,
-        images: list[Image.Image] = None,
+        images: list[Image.Image] | DataLoader = None,
         batch_size: int = 32,
     ):
         all_embeddings = []
-        assert len(texts) == len(images)
-        for i in tqdm(range(0, len(texts), batch_size)):
-            batch_texts = texts[i : i + batch_size]
-            batch_images = images[i : i + batch_size]
+
+        if isinstance(images, DataLoader):
+            with torch.no_grad():
+                for index, batch_images in enumerate(tqdm(images)):
+                    batch_texts = texts[
+                        index * batch_size : (index + 1) * batch_size
+                    ]
+                    batch_embeddings = self.encode(
+                        images=batch_images, texts=batch_texts, tensors=True
+                    )
+                    all_embeddings.append(batch_embeddings.cpu())
+        else:
+            assert len(texts) == len(images)
             with torch.no_grad():
-                batch_embeddings = self.encode(
-                    images=batch_images, texts=batch_texts
-                )
-                all_embeddings.append(batch_embeddings.cpu())
+                for i in tqdm(range(0, len(texts), batch_size)):
+                    batch_texts = texts[i : i + batch_size]
+                    batch_images = images[i : i + batch_size]
+                    batch_embeddings = self.encode(
+                        images=batch_images, texts=batch_texts
+                    )
+                    all_embeddings.append(batch_embeddings.cpu())
         return torch.cat(all_embeddings, dim=0)

     def calculate_probs(self, text_embeddings, image_embeddings):
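For context, a sketch of the new VISTA DataLoader path: the loader is expected to yield image tensors, and encode(..., tensors=True) converts each tensor back to a PIL image via tensor_to_image before preprocess_val. The `model` instance and the dummy tensors are assumptions for illustration, not part of this diff:

import torch
from torch.utils.data import DataLoader

dummy_images = [torch.rand(3, 224, 224) for _ in range(8)]  # CHW floats in [0, 1]
loader = DataLoader(dummy_images, batch_size=4, collate_fn=lambda batch: batch)

# `model` is assumed to be a wrapper instance returned by vista_loader(...).
image_embeddings = model.get_image_embeddings(loader)

# Fused path: texts are sliced with the same batch_size the loader uses,
# so the two batch sizes must agree to keep text/image pairs aligned.
texts = [f"caption {i}" for i in range(8)]
fused = model.get_fused_embeddings(texts=texts, images=loader, batch_size=4)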
@@ -13,7 +13,7 @@ class ImageCoDeT2IMultiChoice(AbsTaskAny2AnyMultiChoice):
             "path": "JamieSJS/imagecode-multi",
             "revision": "d28adfd8b34fefa546fdf94bdc352622b2575f6c",
         },
-        type="Retrieval",
+        type="Any2AnyMultiChoice",
         category="t2i",
         eval_splits=["test"],
         eval_langs=["eng-Latn"],

@@ -14,7 +14,7 @@ class BLINKIT2IRetrieval(AbsTaskAny2AnyRetrieval):
             "revision": "359b66f11c25d19bc8f7108d98e660a5857f3d26",
             "trust_remote_code": True,
         },
-        type="Retrieval",
+        type="Any2AnyRetrieval",
         category="it2i",
         eval_splits=["test"],
         eval_langs=["eng-Latn"],

@@ -14,7 +14,7 @@ class BLINKIT2TRetrieval(AbsTaskAny2AnyRetrieval):
             "revision": "4ab83c87ac5b24e3b730f86d585671493a3a423c",
             "trust_remote_code": True,
         },
-        type="Retrieval",
+        type="Any2AnyRetrieval",
         category="it2t",
         eval_splits=["test"],
         eval_langs=["eng-Latn"],
2 changes: 1 addition & 1 deletion mteb/tasks/Image/Any2AnyRetrieval/eng/CIRRIT2IRetrieval.py
@@ -14,7 +14,7 @@ class CIRRIT2IRetrieval(AbsTaskAny2AnyRetrieval):
             "revision": "503301cd99348035b9675883a543aa1ded0cf07c",
             "trust_remote_code": True,
         },
-        type="Retrieval",
+        type="Any2AnyRetrieval",
         category="it2i",
         eval_splits=["test"],
         eval_langs=["eng-Latn"],

@@ -13,7 +13,7 @@ class CUB200I2I(AbsTaskAny2AnyRetrieval):
             "path": "isaacchung/cub200_retrieval",
             "revision": "ad08c1307b15a226bf1b64e62656a17f1f85f7ec",
         },
-        type="Retrieval",
+        type="Any2AnyRetrieval",
         category="i2i",
         eval_splits=["test"],
         eval_langs=["eng-Latn"],
4 changes: 2 additions & 2 deletions mteb/tasks/Image/Any2AnyRetrieval/eng/FORBI2IRetrieval.py
@@ -11,9 +11,9 @@ class FORBI2I(AbsTaskAny2AnyRetrieval):
         reference="https://github.com/pxiangwu/FORB",
         dataset={
             "path": "isaacchung/forb_retrieval",
-            "revision": "336607d5bcc853fb7f7276c2c9721d4b5b1ca8e4",
+            "revision": "26ab4bd972854becada339afc80f5f3ffc047e2b",
         },
-        type="Retrieval",
+        type="Any2AnyRetrieval",
         category="i2i",
         eval_splits=["test"],
         eval_langs=["eng-Latn"],

@@ -14,7 +14,7 @@ class Fashion200kI2TRetrieval(AbsTaskAny2AnyRetrieval):
             "revision": "96a313715ecf67f5dfe70c4fa52406bc7bdfbeee",
             # "trust_remote_code": True,
         },
-        type="Retrieval",
+        type="Any2AnyRetrieval",
         category="i2t",
         eval_splits=["test"],
         eval_langs=["eng-Latn"],

@@ -14,7 +14,7 @@ class Fashion200kT2IRetrieval(AbsTaskAny2AnyRetrieval):
             "revision": "1b86e2dde50e671d5c83d07a79e8b1d8c696964b",
             # "trust_remote_code": True,
         },
-        type="Retrieval",
+        type="Any2AnyRetrieval",
         category="t2i",
         eval_splits=["test"],
         eval_langs=["eng-Latn"],

@@ -14,7 +14,7 @@ class FashionIQIT2IRetrieval(AbsTaskAny2AnyRetrieval):
             "revision": "e6f0ec70becc413d940cd62b2cfa3b1d3a08c31a",
             # "trust_remote_code": True,
         },
-        type="Retrieval",
+        type="Any2AnyRetrieval",
         category="it2i",
         eval_splits=["test"],
         eval_langs=["eng-Latn"],

@@ -13,7 +13,7 @@ class Flickr30kI2TRetrieval(AbsTaskAny2AnyRetrieval):
             "path": "JamieSJS/flickr30k",
             "revision": "a4cf34ac79215f9e2cd6a10342d84f606fc41cc3",
         },
-        type="Retrieval",
+        type="Any2AnyRetrieval",
         category="i2t",
         eval_splits=["test"],
         eval_langs=["eng-Latn"],

@@ -13,7 +13,7 @@ class Flickr30kT2IRetrieval(AbsTaskAny2AnyRetrieval):
             "path": "JamieSJS/flickr30k",
             "revision": "a4cf34ac79215f9e2cd6a10342d84f606fc41cc3",
         },
-        type="Retrieval",
+        type="Any2AnyRetrieval",
         category="t2i",
         eval_splits=["test"],
         eval_langs=["eng-Latn"],
2 changes: 1 addition & 1 deletion mteb/tasks/Image/Any2AnyRetrieval/eng/GLDv2I2TRetrieval.py
@@ -13,7 +13,7 @@ class GLDv2I2TRetrieval(AbsTaskAny2AnyRetrieval):
             "path": "JamieSJS/gld-v2-i2t",
             "revision": "d8c3e53160860f76de73ed3041a8593672fe5928",
         },
-        type="Retrieval",
+        type="Any2AnyRetrieval",
         category="i2t",
         eval_splits=["test"],
         eval_langs=["eng-Latn"],

@@ -68,7 +68,7 @@ class HatefulMemesI2TRetrieval(AbsTaskAny2AnyRetrieval):
             "revision": "c9a9a6c3ef0765622a6de0af6ebb68f323ad73ba",
             # "trust_remote_code": True,
         },
-        type="Retrieval",
+        type="Any2AnyRetrieval",
         category="i2t",
         eval_splits=["test"],
         eval_langs=["eng-Latn"],

@@ -68,7 +68,7 @@ class HatefulMemesT2IRetrieval(AbsTaskAny2AnyRetrieval):
             "revision": "c9a9a6c3ef0765622a6de0af6ebb68f323ad73ba",
             # "trust_remote_code": True,
         },
-        type="Retrieval",
+        type="Any2AnyRetrieval",
         category="t2i",
         eval_splits=["test"],
         eval_langs=["eng-Latn"],

@@ -13,7 +13,7 @@ class ImageCoDeT2IRetrieval(AbsTaskAny2AnyRetrieval):
             "path": "JamieSJS/imagecode",
             "revision": "a424cd523ffb157b69a875fb5e71c1d51be54089",
         },
-        type="Retrieval",
+        type="Any2AnyRetrieval",
         category="t2i",
         eval_splits=["test"],
         eval_langs=["eng-Latn"],

@@ -14,7 +14,7 @@ class InfoSeekIT2ITRetrieval(AbsTaskAny2AnyRetrieval):
             "revision": "78ee7f7708aac75d3afac5dcab1c9e03cb62664c",
             "trust_remote_code": True,
         },
-        type="Retrieval",
+        type="Any2AnyRetrieval",
         category="it2it",
         eval_splits=["test"],
         eval_langs=["eng-Latn"],

@@ -14,7 +14,7 @@ class InfoSeekIT2TRetrieval(AbsTaskAny2AnyRetrieval):
             "revision": "d4f4606f7a42bbf311c2957419ef3734fe81c47f",
             "trust_remote_code": True,
         },
-        type="Retrieval",
+        type="Any2AnyRetrieval",
         category="it2t",
         eval_splits=["test"],
         eval_langs=["eng-Latn"],
2 changes: 1 addition & 1 deletion mteb/tasks/Image/Any2AnyRetrieval/eng/METI2IRetrieval.py
@@ -13,7 +13,7 @@ class METI2IRetrieval(AbsTaskAny2AnyRetrieval):
             "path": "JamieSJS/met",
             "revision": "08ceaa61c0d172214abb3b8e82971d8f69d2aec0",
         },
-        type="Retrieval",
+        type="Any2AnyRetrieval",
         category="i2i",
         eval_splits=["test"],
         eval_langs=["eng-Latn"],