Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 21 additions & 11 deletions mteb/models/colpali_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import torch
from PIL import Image
from torch.utils.data import DataLoader
from tqdm import tqdm

from mteb.encoder_interface import PromptType
from mteb.model_meta import ModelMeta
Expand Down Expand Up @@ -39,7 +40,10 @@ def __init__(

# Load model
self.mdl = model_class.from_pretrained(
model_name, revision=revision, device_map=self.device, **kwargs
model_name,
device_map=self.device,
adapter_kwargs={"revision": revision},
**kwargs,
)
self.mdl.eval()

Expand Down Expand Up @@ -68,14 +72,13 @@ def get_image_embeddings(
iterator = DataLoader(images, batch_size=batch_size)

with torch.no_grad():
for batch in iterator:
for batch in tqdm(iterator):
# batch may be list of tensors or PIL
imgs = [
F.to_pil_image(b.to("cpu")) if not isinstance(b, Image.Image) else b
for b in batch
]
inputs = self.processor.process_images(imgs)
inputs = {k: v.to(self.device) for k, v in inputs.items()}
inputs = self.processor.process_images(imgs).to(self.device)
outs = self.encode_input(inputs)
all_embeds.extend(outs.cpu().to(torch.float32))

Expand All @@ -92,10 +95,14 @@ def get_text_embeddings(
):
all_embeds = []
with torch.no_grad():
for i in range(0, len(texts), batch_size):
batch = texts[i : i + batch_size]
inputs = self.processor.process_queries(batch)
inputs = {k: v.to(self.device) for k, v in inputs.items()}
for i in tqdm(range(0, len(texts), batch_size)):
batch = [
self.processor.query_prefix
+ t
+ self.processor.query_augmentation_token * 10
for t in texts[i : i + batch_size]
]
inputs = self.processor.process_texts(batch).to(self.device)
outs = self.encode_input(inputs)
all_embeds.extend(outs.cpu().to(torch.float32))

Expand All @@ -120,11 +127,11 @@ def get_fused_embeddings(
)

def calculate_probs(self, text_embeddings, image_embeddings):
scores = self.similarity(text_embeddings, image_embeddings)
return (scores * 100).softmax(dim=-1)
scores = self.similarity(text_embeddings, image_embeddings).T
return scores.softmax(dim=-1)

def similarity(self, a, b):
return self.processor.score_multi_vector(a, b)
return self.processor.score(a, b)


class ColPaliWrapper(ColPaliEngineWrapper):
Expand Down Expand Up @@ -164,6 +171,7 @@ def __init__(
loader=partial(
ColPaliWrapper,
model_name="vidore/colpali-v1.1",
revision="a0f15e3bcf97110e7ac1bb4be4bcd30eeb31992a",
torch_dtype=torch.float16,
),
name="vidore/colpali-v1.1",
Expand All @@ -190,6 +198,7 @@ def __init__(
loader=partial(
ColPaliWrapper,
model_name="vidore/colpali-v1.2",
revision="6b89bc63c16809af4d111bfe412e2ac6bc3c9451",
torch_dtype=torch.float16,
),
name="vidore/colpali-v1.2",
Expand All @@ -216,6 +225,7 @@ def __init__(
loader=partial(
ColPaliWrapper,
model_name="vidore/colpali-v1.3",
revision="1b5c8929330df1a66de441a9b5409a878f0de5b0",
torch_dtype=torch.float16,
),
name="vidore/colpali-v1.3",
Expand Down
37 changes: 6 additions & 31 deletions mteb/models/colqwen_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def __init__(
loader=partial(
ColQwen2Wrapper,
model_name="vidore/colqwen2-v1.0",
revision="530094e83a40ca4edcb5c9e5ddfa61a4b5ea0d2f",
torch_dtype=torch.float16,
attn_implementation="flash_attention_2"
if is_flash_attn_2_available()
Expand Down Expand Up @@ -98,14 +99,15 @@ def __init__(
loader=partial(
ColQwen2_5Wrapper,
model_name="vidore/colqwen2.5-v0.2",
revision="6f6fcdfd1a114dfe365f529701b33d66b9349014",
torch_dtype=torch.float16,
attn_implementation="flash_attention_2"
if is_flash_attn_2_available()
else None,
),
name="vidore/colqwen2.5-v0.2",
languages=["eng-Latn"],
revision="530094e83a40ca4edcb5c9e5ddfa61a4b5ea0d2f",
revision="6f6fcdfd1a114dfe365f529701b33d66b9349014",
release_date="2025-01-31",
modalities=["image", "text"],
n_parameters=3_000_000_000,
Expand All @@ -123,35 +125,6 @@ def __init__(
training_datasets=COLPALI_TRAINING_DATA,
)

colnomic_7b = ModelMeta(
loader=partial(
ColQwen2_5Wrapper,
model_name="nomic-ai/colnomic-embed-multimodal-7b",
torch_dtype=torch.float16,
attn_implementation="flash_attention_2"
if is_flash_attn_2_available()
else None,
),
name="nomic-ai/colnomic-embed-multimodal-7b",
languages=["eng-Latn"],
revision="530094e83a40ca4edcb5c9e5ddfa61a4b5ea0d2f",
release_date="2025-03-31",
modalities=["image", "text"],
n_parameters=7_000_000_000,
memory_usage_mb=14400,
max_tokens=128000,
embed_dim=128,
license="apache-2.0",
open_weights=True,
public_training_code="https://github.com/nomic-ai/colpali",
public_training_data="https://huggingface.co/datasets/vidore/colpali_train_set",
framework=["ColPali"],
reference="https://huggingface.co/nomic-ai/colnomic-embed-multimodal-7b",
similarity_fn_name="max_sim",
use_instructions=True,
training_datasets=COLPALI_TRAINING_DATA,
)

COLNOMIC_TRAINING_DATA = {"VDRMultilingual": ["Train"], **COLPALI_TRAINING_DATA}
COLNOMIC_LANGUAGES = [
"deu-Latn", # German
Expand All @@ -165,14 +138,15 @@ def __init__(
loader=partial(
ColQwen2_5Wrapper,
model_name="nomic-ai/colnomic-embed-multimodal-3b",
revision="86627b4a9b0cade577851a70afa469084f9863a4",
torch_dtype=torch.float16,
attn_implementation="flash_attention_2"
if is_flash_attn_2_available()
else None,
),
name="nomic-ai/colnomic-embed-multimodal-3b",
languages=COLNOMIC_LANGUAGES,
revision="530094e83a40ca4edcb5c9e5ddfa61a4b5ea0d2f",
revision="86627b4a9b0cade577851a70afa469084f9863a4",
release_date="2025-03-31",
modalities=["image", "text"],
n_parameters=3_000_000_000,
Expand All @@ -194,6 +168,7 @@ def __init__(
loader=partial(
ColQwen2_5Wrapper,
model_name="nomic-ai/colnomic-embed-multimodal-7b",
revision="09dbc9502b66605d5be56d2226019b49c9fd3293",
torch_dtype=torch.float16,
attn_implementation="flash_attention_2"
if is_flash_attn_2_available()
Expand Down
6 changes: 4 additions & 2 deletions mteb/models/colsmol_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,15 @@ def __init__(
loader=partial(
ColSmolWrapper,
model_name="vidore/colSmol-256M",
revision="a59110fdf114638b8018e6c9a018907e12f14855",
torch_dtype=torch.float16,
attn_implementation="flash_attention_2"
if is_flash_attn_2_available()
else None,
),
name="vidore/colSmol-256M",
languages=["eng-Latn"],
revision="530094e83a40ca4edcb5c9e5ddfa61a4b5ea0d2f",
revision="a59110fdf114638b8018e6c9a018907e12f14855",
release_date="2025-01-22",
modalities=["image", "text"],
n_parameters=256_000_000,
Expand All @@ -73,14 +74,15 @@ def __init__(
loader=partial(
ColSmolWrapper,
model_name="vidore/colSmol-500M",
revision="1aa9325cba7ed2b3b9b97ede4d55026322504902",
torch_dtype=torch.float16,
attn_implementation="flash_attention_2"
if is_flash_attn_2_available()
else None,
),
name="vidore/colSmol-500M",
languages=["eng-Latn"],
revision="530094e83a40ca4edcb5c9e5ddfa61a4b5ea0d2f",
revision="1aa9325cba7ed2b3b9b97ede4d55026322504902",
release_date="2025-01-22",
modalities=["image", "text"],
n_parameters=500_000_000,
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ llm2vec = ["llm2vec>=0.2.3,<0.3.0"]
timm = ["timm>=1.0.15,<1.1.0"]
open_clip_torch = ["open_clip_torch==2.31.0"]
ark = ["volcengine-python-sdk[ark]==3.0.2", "tiktoken>=0.8.0"]
colpali_engine = ["colpali_engine>=0.3.10"]
colpali_engine = ["colpali_engine>=0.3.12"]

[tool.coverage.report]

Expand Down