-
Notifications
You must be signed in to change notification settings - Fork 583
[mieb] add siglip, cohere multimodal & some fixes for final run #1357
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
7 commits
Select commit
Hold shift + click to select a range
4b19e2d
fix dataset type error
gowitheflow-1998 f89827f
fix clustering metrics
gowitheflow-1998 219378c
add siglip & cohere
gowitheflow-1998 812a0cd
update mieb run script
gowitheflow-1998 a0eb4ec
cohere-v import
gowitheflow-1998 edddc96
fix
gowitheflow-1998 2a4d9ce
api key name
gowitheflow-1998 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,202 @@ | ||
| from __future__ import annotations | ||
|
|
||
| from functools import partial | ||
| from typing import Any | ||
|
|
||
| import torch | ||
| from PIL import Image | ||
| from torch.utils.data import DataLoader | ||
| from torchvision import transforms | ||
| from tqdm import tqdm | ||
| import os | ||
| import io | ||
| import base64 | ||
| import mteb | ||
| import time | ||
| from mteb.model_meta import ModelMeta | ||
|
|
||
# Cohere API key is read once at import time; `cohere_v_loader` passes it to
# cohere.ClientV2. Do `export COHERE_API_KEY=<key>` before running eval scripts.
api_key = os.getenv("COHERE_API_KEY")
# DataLoader batches yield tensors; this converts them back to PIL images
# so they can be JPEG-encoded for the Cohere image-embedding endpoint.
tensor_to_image = transforms.Compose([transforms.ToPILImage()])
|
|
||
|
|
||
def cohere_v_loader(**kwargs):
    """Build a wrapper around Cohere's multimodal embedding API.

    Requires the ``cohere`` package and a ``COHERE_API_KEY`` environment
    variable (read at module import into the module-level ``api_key``).

    Returns:
        CohereMultiModalModelWrapper: object exposing ``get_text_embeddings``,
        ``get_image_embeddings``, ``calculate_probs`` and
        ``get_fused_embeddings``.

    Raises:
        ImportError: if the ``cohere`` package is not installed.
    """
    try:
        import cohere
    except ImportError:
        raise ImportError("To use cohere models, please run `pip install cohere`.")

    class CohereMultiModalModelWrapper:
        """Wrapper for Cohere multimodal embedding model.

        do `export COHERE_API_KEY=<Your_Cohere_API_KEY>` before running eval scripts.
        Cohere currently supports 40 images/min, thus time.sleep(1.5) is applied
        after each image. Remove or adjust this after Cohere API changes capacity.
        """

        def __init__(
            self,
            model_name: str,
            **kwargs: Any,
        ):
            self.model_name = model_name
            self.client = cohere.ClientV2(api_key)
            # Images are re-encoded to JPEG data URIs before upload.
            self.image_format = "JPEG"

        def get_text_embeddings(self, texts: list[str], batch_size: int = 32):
            """Embed `texts` in batches of `batch_size`.

            Returns a float tensor of shape (len(texts), embed_dim).
            """
            all_text_embeddings = []

            for i in tqdm(range(0, len(texts), batch_size)):
                batch_texts = texts[i : i + batch_size]
                response = self.client.embed(
                    texts=batch_texts,
                    model=self.model_name,
                    input_type="search_document",
                    # Fix: the image path requests float embeddings explicitly
                    # and both paths read `response.embeddings.float`; request
                    # the same type here so `.float` is populated for text too.
                    embedding_types=["float"],
                )
                all_text_embeddings.append(torch.tensor(response.embeddings.float))

            all_text_embeddings = torch.cat(all_text_embeddings, dim=0)
            return all_text_embeddings

        def _embed_single_image(self, image: Image.Image) -> torch.Tensor:
            """Embed one PIL image via the Cohere API.

            The API accepts only one image per call; the trailing
            time.sleep(1.5) keeps us under the current 40 images/min limit.
            """
            buffered = io.BytesIO()
            image.save(buffered, format=self.image_format)
            image_bytes = buffered.getvalue()
            stringified_buffer = base64.b64encode(image_bytes).decode("utf-8")
            content_type = f"image/{self.image_format.lower()}"
            image_base64 = f"data:{content_type};base64,{stringified_buffer}"
            response = self.client.embed(
                model=self.model_name,
                input_type="image",
                embedding_types=["float"],
                images=[image_base64],
            )
            time.sleep(1.5)
            return torch.tensor(response.embeddings.float)

        def get_image_embeddings(
            self, images: list[Image.Image] | DataLoader, batch_size: int = 32
        ):
            """Embed images given either as a list of PIL images or a DataLoader.

            Returns a float tensor of shape (n_images, embed_dim).
            """
            all_image_embeddings = []

            if isinstance(images, DataLoader):
                for batch in tqdm(images):
                    for image in batch:
                        # DataLoader batches yield tensors; convert to PIL first.
                        all_image_embeddings.append(
                            self._embed_single_image(tensor_to_image(image))
                        )
            else:
                for i in tqdm(range(0, len(images), batch_size)):
                    for image in images[i : i + batch_size]:
                        all_image_embeddings.append(self._embed_single_image(image))

            all_image_embeddings = torch.cat(all_image_embeddings, dim=0)
            return all_image_embeddings

        def calculate_probs(self, text_embeddings, image_embeddings):
            """Softmax over scaled cosine similarities: one row per image,
            one column per text candidate."""
            text_embeddings = text_embeddings / text_embeddings.norm(
                dim=-1, keepdim=True
            )
            image_embeddings = image_embeddings / image_embeddings.norm(
                dim=-1, keepdim=True
            )
            logits = torch.matmul(image_embeddings, text_embeddings.T)
            # 100 is the CLIP-style logit scale applied before softmax.
            probs = (logits * 100).softmax(dim=-1)
            return probs

        def get_fused_embeddings(
            self,
            texts: list[str] | None = None,
            images: list[Image.Image] | DataLoader | None = None,
            fusion_mode="sum",
            batch_size: int = 32,
        ):
            """Embed texts and/or images; fuse paired modalities.

            With both modalities present the two embedding tensors must align
            1:1 and are combined per `fusion_mode` (only "sum" implemented).
            With a single modality, its embeddings are returned unchanged.

            Raises:
                ValueError: if neither input is given, lengths mismatch, or
                    `fusion_mode` is unsupported.
            """
            if texts is None and images is None:
                raise ValueError("Either texts or images must be provided")

            text_embeddings = None
            image_embeddings = None

            if texts is not None:
                text_embeddings = self.get_text_embeddings(texts, batch_size)

            if images is not None:
                image_embeddings = self.get_image_embeddings(images, batch_size)

            if text_embeddings is not None and image_embeddings is not None:
                if len(text_embeddings) != len(image_embeddings):
                    raise ValueError(
                        "The number of texts and images must have the same length"
                    )
                if fusion_mode == "sum":
                    fused_embeddings = text_embeddings + image_embeddings
                else:
                    # to do: add other fusion mode
                    raise ValueError(
                        f"fusion mode {fusion_mode} hasn't been implemented"
                    )
                return fused_embeddings
            elif text_embeddings is not None:
                return text_embeddings
            elif image_embeddings is not None:
                return image_embeddings

    return CohereMultiModalModelWrapper(**kwargs)
|
|
||
|
|
||
# Registry metadata for Cohere's multilingual embed model used as a
# vision-capable ("-v") entry; the loader builds the multimodal wrapper.
cohere_mult_3 = ModelMeta(
    loader=partial(cohere_v_loader, model_name="embed-multilingual-v3.0"),
    name="embed-multilingual-v3.0-v",
    languages=[],  # Unknown, but support >100 languages
    open_source=False,
    revision="1",  # hosted API model: no git revision, pinned as "1"
    release_date="2024-10-24",
    n_parameters=None,  # not disclosed by Cohere
    memory_usage=None,
    max_tokens=None,
    embed_dim=1024,
    license=None,
    similarity_fn_name="cosine",
    framework=[],
)
|
|
||
# Registry metadata for Cohere's English-only embed model, exposed as a
# vision-capable ("-v") entry via the same multimodal wrapper loader.
cohere_eng_3 = ModelMeta(
    loader=partial(cohere_v_loader, model_name="embed-english-v3.0"),
    name="embed-english-v3.0-v",
    languages=["eng-Latn"],
    open_source=False,
    revision="1",  # hosted API model: no git revision, pinned as "1"
    release_date="2024-10-24",
    n_parameters=None,  # not disclosed by Cohere
    memory_usage=None,
    max_tokens=None,
    embed_dim=1024,
    license=None,
    similarity_fn_name="cosine",
    framework=[],
)
|
|
||
if __name__ == "__main__":
    # Smoke test: resolve the registered model through mteb and embed one
    # string (requires COHERE_API_KEY and network access).
    mdl = mteb.get_model(cohere_mult_3.name, cohere_mult_3.revision)
    emb = mdl.encode(["Hello, world!"])
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.