261 changes: 114 additions & 147 deletions mteb/models/model_implementations/seed_1_6_embedding_models_1215.py
@@ -4,13 +4,15 @@
 import logging
 import os
 import time
-from concurrent.futures import ThreadPoolExecutor, as_completed
+from concurrent.futures import ThreadPoolExecutor
+from functools import partial
 from io import BytesIO
 from typing import TYPE_CHECKING, Any

 import requests
 import torch
+from torch.utils.data import DataLoader
+from tqdm import tqdm

 from mteb._requires_package import requires_package
 from mteb.abstasks.task_metadata import TaskMetadata
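
The import swap above tracks a behavioral change further down in this diff: results are now collected by iterating the futures list in submission order instead of via as_completed, which makes the old pre-allocated results buffer unnecessary. A minimal standalone sketch of the two collection patterns (work is a stand-in for an embedding request, not part of this file):

from concurrent.futures import ThreadPoolExecutor, as_completed

def work(i):
    return i * i  # stand-in for an embedding request

items = list(range(4))

with ThreadPoolExecutor(max_workers=2) as executor:
    # Old pattern: completion order is arbitrary, so each result must be
    # written back into a pre-allocated slot by its batch index.
    futures = {executor.submit(work, i): i for i in items}
    results = [None] * len(items)
    for future in as_completed(futures):
        results[futures[future]] = future.result()

with ThreadPoolExecutor(max_workers=2) as executor:
    # New pattern: iterating the futures list yields results in
    # submission order, so a plain append preserves ordering.
    futures = [executor.submit(work, i) for i in items]
    ordered = [f.result() for f in futures]
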
@@ -26,114 +28,6 @@

 logger = logging.getLogger(__name__)


-def pil_to_base64(image, format="jpeg"):
-    if image is None:
-        return None
-    buffer = BytesIO()
-    image.save(buffer, format=format)
-    img_bytes = buffer.getvalue()
-    encoded_bytes = base64.b64encode(img_bytes)
-    return encoded_bytes.decode("utf-8")
-
-
-def multimodal_embedding(image_base64=None, text_content=None):
-    auth_token = os.getenv("VOLCES_AUTH_TOKEN")
-    model_name = "doubao-embedding-vision-251215"
-    api_url = "https://ark.cn-beijing.volces.com/api/v3/embeddings/multimodal"
-
-    headers = {
-        "Authorization": f"Bearer {auth_token}",
-        "x-ark-vlm1": "true",
-        "Content-Type": "application/json",
-    }
-
-    if image_base64 is not None and text_content is None:
-        inputs = []
-        for image in image_base64:
-            image_format = "jpeg"
-            image_data = f"data:image/{image_format};base64,{image}"
-            inputs.append({"type": "image_url", "image_url": {"url": image_data}})
-
-        payload = {"model": model_name, "input": inputs}
-    elif image_base64 is None and text_content is not None:
-        payload = {
-            "model": model_name,
-            "input": [
-                {"type": "text", "text": text_content},
-            ],
-        }
-    else:
-        inputs = []
-        for image in image_base64:
-            image_format = "jpeg"
-            image_data = f"data:image/{image_format};base64,{image}"
-            inputs.append({"type": "image_url", "image_url": {"url": image_data}})
-        inputs.append({"type": "text", "text": text_content})
-        payload = {"model": model_name, "input": inputs}
-
-    try:
-        response = requests.post(url=api_url, headers=headers, json=payload, timeout=10)
-
-        response.raise_for_status()
-        return response.json()
-
-    except requests.exceptions.HTTPError as http_err:
-        logger.error(f"HTTP error ({http_err.response.status_code}): {http_err}")
-    except requests.exceptions.JSONDecodeError:
-        logger.error("Error:The response is not in valid JSON format")
-    except requests.exceptions.Timeout:
-        logger.error("Error:Request timeout")
-    except Exception as e:
-        logger.error(f"Unknown error: {str(e)}")
-
-    return None
-
-
-def multi_thread_encode(sentences, batch_size=1, max_workers=8):
-    batches = []
-    for idx in range(0, len(sentences), batch_size):
-        batches.append((idx // batch_size, sentences[idx : idx + batch_size]))
-
-    n_batches = len(batches)
-    results = [None] * n_batches  # Pre-allocated result list
-    all_embeddings = []  # Final ordered embeddings
-
-    def _process_batch(batch_idx, batch_sentences):
-        sentence = batch_sentences[0]
-
-        retries = 5
-        while retries > 0:
-            try:
-                resp = multimodal_embedding(text_content=sentence)
-                embedding = torch.tensor(resp["data"]["embedding"])
-                break
-            except Exception as e:
-                time.sleep(1)
-                logger.warning(f"Retrying... {retries} retries left. Error: {str(e)}")
-                retries -= 1
-                if retries == 0:
-                    raise e
-        return batch_idx, embedding
-
-    with ThreadPoolExecutor(max_workers=max_workers) as executor:
-        futures = {
-            executor.submit(_process_batch, idx, batch): idx for idx, batch in batches
-        }
-
-        for future in as_completed(futures):
-            batch_idx, embeddings = future.result()
-            results[batch_idx] = embeddings
-
-    for batch_embeddings in results:
-        all_embeddings.append(batch_embeddings)
-
-    all_embeddings = torch.stack(all_embeddings, dim=0)
-    all_embeddings = torch.nn.functional.normalize(all_embeddings, dim=-1)
-
-    return all_embeddings.float().cpu()
-
-
 doubao_embedding_training_data = (
     {
         "PawsXPairClassification",
@@ -166,25 +60,80 @@ def __init__(
"pip install mteb[ark]",
"tiktoken",
)
import tiktoken

self._model_name = model_name
self._max_tokens = 32768
self._embed_dim = embed_dim
self._available_embed_dims = [2048, 1024]
self._encoding = tiktoken.get_encoding(tokenizer_name)

def truncate_text_tokens(self, text: str) -> str:
"""Truncate a string to have `max_tokens` according to the given encoding.
def pil_to_base64(self, image, format="jpeg"):
if image is None:
return None
buffer = BytesIO()
image.save(buffer, format=format)
img_bytes = buffer.getvalue()
encoded_bytes = base64.b64encode(img_bytes)
return encoded_bytes.decode("utf-8")

def multimodal_embedding(self, instruction, image_base64, text_content):
auth_token = os.getenv("VOLCES_AUTH_TOKEN")
model_name = "doubao-embedding-vision-251215"
api_url = "https://ark.cn-beijing.volces.com/api/v3/embeddings/multimodal"

headers = {
"Authorization": f"Bearer {auth_token}",
"x-ark-vlm1": "true",
"Content-Type": "application/json",
}

Args:
text: The input string to be truncated.
if text_content is not None and len(text_content) > self._max_tokens:
text_content = text_content[: self._max_tokens]

if image_base64 is not None and text_content is None:
inputs = []
for image in image_base64:
image_format = "jpeg"
image_data = f"data:image/{image_format};base64,{image}"
inputs.append({"type": "image_url", "image_url": {"url": image_data}})

payload = {"model": model_name, "input": inputs}
elif image_base64 is None and text_content is not None:
payload = {
"model": model_name,
"instruction": instruction,
"input": [
{"type": "text", "text": text_content},
],
}
else:
inputs = []
for image in image_base64:
image_format = "jpeg"
image_data = f"data:image/{image_format};base64,{image}"
inputs.append({"type": "image_url", "image_url": {"url": image_data}})
inputs.append({"type": "text", "text": text_content})
payload = {"model": model_name, "input": inputs}

max_retries = 3
retry_count = 0

while retry_count < max_retries:
response = requests.post(
url=api_url, headers=headers, json=payload, timeout=30
)

Returns:
The truncated string.
"""
truncated_sentence = self._encoding.encode(text)[: self._max_tokens]
return self._encoding.decode(truncated_sentence)
if response.status_code != 200:
retry_count += 1
time.sleep(3)
continue

response_json = response.json()
return response_json

raise Exception(
f"Request failed with status code {response.status_code}. "
f"Response: {response.text}"
)

     def get_fused_embeddings(
         self,
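
The multimodal_embedding method added above retries up to three times on non-200 responses, sleeping 3 s between attempts, and raises once retries are exhausted. A hedged usage sketch (standalone; `model` being an instance of this wrapper class and an exported VOLCES_AUTH_TOKEN are assumptions, not shown in this diff):

# Assumes VOLCES_AUTH_TOKEN is exported and `model` is an instance of
# the wrapper class defined in this file.
resp = model.multimodal_embedding(
    instruction=None,
    image_base64=None,
    text_content="a photo of a cat",
)
# Call sites later in this file read the vector from resp["data"]["embedding"].
embedding = resp["data"]["embedding"]
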
@@ -204,59 +153,69 @@ def get_fused_embeddings
         if images is not None and texts is not None:
             assert len(texts) == len(images)
             batch_len = len(texts)
-            images_base64 = [pil_to_base64(image) for image in images]
+            images_base64 = [self.pil_to_base64(image) for image in images]
         elif images is None:
             batch_len = len(texts)
             images_base64 = [None for _ in range(batch_len)]
         elif texts is None:
             batch_len = len(images)
-            images_base64 = [pil_to_base64(image) for image in images]
+            images_base64 = [self.pil_to_base64(image) for image in images]
         else:
             raise ValueError("images and texts cannot be None at the same time")

-        outputs = []
-        for i in range(batch_len):
+        def process_item(
+            i, prompt_type, task_name, texts, images_base64, multimodal_embedding
+        ):
             if (
                 prompt_type == PromptType("query") or prompt_type is None
             ) and task_name in TASK_NAME_TO_INSTRUCTION:
                 instruction = TASK_NAME_TO_INSTRUCTION[task_name]
                 instruction = instruction.rstrip("{}").rstrip("\n")
-                if texts[i] != "":
-                    input_text = (
-                        "Target_modality:Text.\n Instruction:"
-                        + instruction
-                        + "\n Query:{}"
-                    ).format(texts[i])
-                else:
-                    input_text = (
-                        "Target_modality:Text.\n Instruction:"
-                        + instruction
-                        + "\n Query:"
-                    )
+                instruction = (
+                    "Target_modality:Text.\n Instruction:" + instruction + "\n Query:"
+                )
+                input_text = texts[i]
             else:
                 if texts[i] != "" and images_base64[i] is not None:
-                    instruction = "Instruction: Compress the the text and image into one word.\n Query: {}"
-                    input_text = instruction.format(texts[i])
+                    instruction = "Instruction: Compress the text and image into one word.\n Query:"
+                    input_text = texts[i]
                 elif texts[i] != "":
                     instruction = (
-                        "Instruction: Compress the the text into one word.\n Query: {}"
+                        "Instruction: Compress the text into one word.\n Query:"
                     )
-                    input_text = instruction.format(texts[i])
+                    input_text = texts[i]
                 elif images_base64[i] is not None:
                     instruction = (
-                        "Instruction: Compress the the image into one word.\n Query:"
+                        "Instruction: Compress the image into one word.\n Query:"
                    )
-                    input_text = instruction
+                    input_text = None
                 else:
                     raise ValueError("image and text are both None")

             resp = multimodal_embedding(
-                image_base64=[images_base64[i]], text_content=input_text
+                instruction=instruction,
+                image_base64=images_base64[i],
+                text_content=input_text,
             )
             embedding = torch.tensor(resp["data"]["embedding"])
             embedding = torch.reshape(embedding, (1, -1))
+            return embedding

+        outputs = []
+        process_partial = partial(
+            process_item,
+            prompt_type=prompt_type,
+            task_name=task_name,
+            texts=texts,
+            images_base64=images_base64,
+            multimodal_embedding=self.multimodal_embedding,
+        )
+        with ThreadPoolExecutor(max_workers=15) as executor:
+            futures = [executor.submit(process_partial, i) for i in range(batch_len)]
+            for future in tqdm(futures, total=batch_len, desc="Encoding"):
+                outputs.append(future.result())

-        outputs = torch.stack(outputs, dim=0)
+        outputs = torch.stack(outputs, dim=0).squeeze(1)

         if self._embed_dim is not None:
             outputs = outputs[:, : self._embed_dim]
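
Each process_item call returns a (1, dim) tensor, so stacking n results yields (n, 1, dim); the new .squeeze(1) collapses that to the (n, dim) matrix the dimension truncation expects. A shape-only sketch (dimensions chosen from self._available_embed_dims above):

import torch

dim = 2048
outputs = [torch.zeros(1, dim) for _ in range(3)]  # stand-ins for per-item embeddings

stacked = torch.stack(outputs, dim=0)
assert stacked.shape == (3, 1, dim)

flattened = stacked.squeeze(1)
assert flattened.shape == (3, dim)

truncated = flattened[:, :1024]  # truncation to a smaller supported dim
assert truncated.shape == (3, 1024)
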
@@ -273,13 +232,21 @@ def encode(
         prompt_type: PromptType | None = None,
         **kwargs: Any,
     ) -> Array:
-        sentences = [text for batch in inputs for text in batch["text"]]
-        images = [image for batch in inputs for image in batch["image"]]
+        if "text" in inputs.dataset.features:
+            sentences = [text for batch in inputs for text in batch["text"]]
+        else:
+            sentences = None
+
+        if "image" in inputs.dataset.features:
+            images = [image for batch in inputs for image in batch["image"]]
+        else:
+            images = None

         return self.get_fused_embeddings(
             texts=sentences,
             images=images,
             task_name=task_metadata.name,
+            prompt_type=prompt_type,
             **kwargs,
         )
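
encode now probes inputs.dataset.features before reading a column, so text-only and image-only batches no longer fail on a missing key. A minimal sketch of the same guard against a Hugging Face datasets object (toy data; the column names follow the code above, and applying the guard directly to a Dataset rather than a loader is an assumption):

from datasets import Dataset

ds = Dataset.from_dict({"text": ["a", "b"]})  # text-only batch, no "image" column

sentences = ds["text"] if "text" in ds.features else None
images = ds["image"] if "image" in ds.features else None
assert images is None  # guard avoids a KeyError on the absent column
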

Expand Down Expand Up @@ -648,7 +615,7 @@ def encode(
     n_parameters=None,
     memory_usage_mb=None,
     license=None,
-    reference="https://seed1-6-embedding-1215.github.io/",
+    reference="https://console.volcengine.com/ark/region:ark+cn-beijing/model/detail?Id=doubao-embedding-vision",
     similarity_fn_name="cosine",
     framework=["API"],
     use_instructions=True,