Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
159 changes: 129 additions & 30 deletions mteb/models/cohere_models.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,20 @@
from __future__ import annotations

from functools import partial
import logging
import time
from functools import partial, wraps
from typing import Any, Literal, get_args

import numpy as np
import torch
import tqdm
from tqdm.auto import tqdm

from mteb.encoder_interface import PromptType
from mteb.model_meta import ModelMeta
from mteb.models.wrapper import Wrapper

logger = logging.getLogger(__name__)

supported_languages = [
"afr-Latn",
"amh-Ethi",
Expand Down Expand Up @@ -132,6 +136,76 @@
"binary",
]

# Cohere API limits used to size request batches below.
# NOTE(review): values taken as the documented embed-endpoint limits —
# confirm against the current Cohere API documentation before bumping.
COHERE_MAX_BATCH_SIZE = 96  # Maximum number of texts per API call
COHERE_MAX_TOKENS_PER_BATCH = 128_000  # Maximum total tokens per API call


def retry_with_rate_limit(
    max_retries: int = 5,
    max_rpm: int = 300,
    initial_delay: float = 1.0,
):
    """Combined retry and rate-limiting decorator for Cohere API methods.

    Proactively spaces successive calls so that at most ``max_rpm`` requests
    are issued per minute, and reactively retries failed calls with
    exponential backoff. Rate-limit errors (``TooManyRequestsError``) wait at
    least 30 seconds per attempt; all other exceptions use plain exponential
    backoff starting at ``initial_delay``.

    If the decorated method's instance defines ``self.max_retries`` or
    ``self.max_rpm``, those attributes override the decorator parameters
    (this matches the documented contract; the override was previously
    documented but not implemented).

    Args:
        max_retries: Default maximum number of attempts before re-raising
            the last error (default: 5).
        max_rpm: Default maximum requests per minute used for proactive
            request spacing (default: 300).
        initial_delay: Initial delay in seconds for exponential backoff
            (default: 1.0).

    Returns:
        A decorator for instance methods (the wrapped callable must take
        ``self`` as its first positional argument).
    """
    # Timestamp of the last *successful* call, shared across all invocations
    # of the decorated method so spacing applies globally to that method.
    previous_call_ts: float | None = None

    def decorator(func):
        @wraps(func)
        def wrapper(self, *args, **kwargs):
            nonlocal previous_call_ts

            # Honor per-instance overrides, as documented above.
            retries = getattr(self, "max_retries", max_retries)
            rpm = getattr(self, "max_rpm", max_rpm)
            request_interval = 60.0 / rpm

            # Proactive rate limiting: sleep until a full interval has
            # elapsed since the previous successful call.
            now = time.time()
            if (
                previous_call_ts is not None
                and now - previous_call_ts < request_interval
            ):
                time.sleep(request_interval - (now - previous_call_ts))

            # Reactive retry with exponential backoff.
            for attempt in range(retries):
                try:
                    result = func(self, *args, **kwargs)
                    previous_call_ts = time.time()
                    return result
                except Exception as e:
                    # Imported lazily so the happy path never needs the
                    # cohere package from inside the wrapper.
                    import cohere

                    if attempt == retries - 1:
                        raise
                    if isinstance(e, cohere.errors.TooManyRequestsError):
                        # Rate limits get a longer pause (30 s minimum) to
                        # respect the API's per-minute quota.
                        delay = max(30, initial_delay * (2**attempt))
                        logger.warning(
                            f"Cohere rate limit (attempt {attempt + 1}/{retries}): {e}. Retrying in {delay}s..."
                        )
                    else:
                        delay = initial_delay * (2**attempt)
                        logger.warning(
                            f"Cohere API error (attempt {attempt + 1}/{retries}): {e}. Retrying in {delay}s..."
                        )
                    time.sleep(delay)

        return wrapper

    return decorator


# Implementation follows https://github.com/KennethEnevoldsen/scandinavian-embedding-benchmark/blob/main/src/seb/registered_models/cohere_models.py
class CohereTextEmbeddingModel(Wrapper):
Expand All @@ -144,52 +218,74 @@ def __init__(
output_dimension: int | None = None,
**kwargs,
) -> None:
import cohere # type: ignore

self.model_name = model_name
self.sep = sep
self.model_prompts = self.validate_task_to_prompt_name(model_prompts)
assert embedding_type in get_args(EMBEDDING_TYPE)
self.embedding_type = embedding_type
self.output_dimension = output_dimension

self._client = cohere.Client()

@retry_with_rate_limit(max_retries=5, max_rpm=300)
def _embed_func(self, **kwargs):
    """Forward keyword arguments to the Cohere embed endpoint.

    Retry with exponential backoff and proactive request spacing are
    supplied by the ``retry_with_rate_limit`` decorator.
    """
    client = self._client
    return client.embed(**kwargs)

def _embed(
self,
sentences: list[str],
cohere_task_type: str,
show_progress_bar: bool = False,
retries: int = 5,
) -> torch.Tensor:
import cohere # type: ignore
all_embeddings = []
index = 0

max_batch_size = 256
pbar = tqdm(
total=len(sentences),
desc="Encoding sentences",
disable=not show_progress_bar,
)

batches = [
sentences[i : i + max_batch_size]
for i in range(0, len(sentences), max_batch_size)
]
while index < len(sentences):
# Build batch respecting both count and token limits
batch, batch_tokens = [], 0
while (
index < len(sentences)
and len(batch) < COHERE_MAX_BATCH_SIZE
and batch_tokens < COHERE_MAX_TOKENS_PER_BATCH
):
# Count tokens for current sentence
n_tokens = len(
self._client.tokenize(
text=sentences[index], model=self.model_name
).tokens
)

client = cohere.Client()
# Check if adding this sentence would exceed token limit
if (
batch_tokens + n_tokens > COHERE_MAX_TOKENS_PER_BATCH
and len(batch) > 0
):
break

all_embeddings = []
batch_tokens += n_tokens
batch.append(sentences[index])
index += 1

for batch in tqdm.tqdm(batches, leave=False, disable=not show_progress_bar):
while retries > 0: # Cohere's API is not always reliable
try:
embed_kwargs = {
"texts": batch,
"model": self.model_name,
"input_type": cohere_task_type,
"embedding_types": [self.embedding_type],
}
if self.output_dimension is not None:
embed_kwargs["output_dimension"] = self.output_dimension

response = client.embed(**embed_kwargs)
break
except Exception as e:
print(f"Retrying... {retries} retries left.")
retries -= 1
if retries == 0:
raise e
# Embed the batch with retry logic handled by client
embed_kwargs = {
"texts": batch,
"model": self.model_name,
"input_type": cohere_task_type,
"embedding_types": [self.embedding_type],
}
if self.output_dimension is not None:
embed_kwargs["output_dimension"] = self.output_dimension

response = self._embed_func(**embed_kwargs)

# Get embeddings based on requested type
if self.embedding_type == "float":
Expand All @@ -202,8 +298,11 @@ def _embed(
embeddings = response.embeddings.binary
else:
raise ValueError(f"Embedding type {self.embedding_type} not allowed")

all_embeddings.extend(torch.tensor(embeddings).numpy())
pbar.update(len(batch))

pbar.close()
embeddings_array = np.array(all_embeddings)

# Post-process embeddings based on type (similar to voyage_models.py)
Expand Down
55 changes: 48 additions & 7 deletions mteb/models/cohere_v.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,15 @@
import torch
from PIL import Image
from torch.utils.data import DataLoader
from tqdm import tqdm
from tqdm.auto import tqdm

from mteb.encoder_interface import PromptType
from mteb.model_meta import ModelMeta
from mteb.models.cohere_models import (
COHERE_MAX_BATCH_SIZE,
COHERE_MAX_TOKENS_PER_BATCH,
retry_with_rate_limit,
)
from mteb.requires_package import requires_image_dependencies, requires_package


Expand Down Expand Up @@ -194,10 +199,16 @@ def __init__(
self.embedding_type = embedding_type
self.output_dimension = output_dimension
api_key = os.getenv("COHERE_API_KEY")

self.client = cohere.ClientV2(api_key)
self.image_format = "JPEG"
self.transform = transforms.Compose([transforms.PILToTensor()])

@retry_with_rate_limit(max_retries=5, max_rpm=300)
def _embed_func(self, **kwargs):
    """Forward keyword arguments to the Cohere v2 embed endpoint.

    Retry with exponential backoff and proactive request spacing are
    supplied by the ``retry_with_rate_limit`` decorator.
    """
    client = self.client
    return client.embed(**kwargs)

def get_text_embeddings(
self,
texts: list[str],
Expand All @@ -208,19 +219,47 @@ def get_text_embeddings(
**kwargs: Any,
):
all_text_embeddings = []
index = 0

pbar = tqdm(total=len(texts), desc="Encoding text sentences")

while index < len(texts):
# Build batch respecting both count and token limits
batch, batch_tokens = [], 0
while (
index < len(texts)
and len(batch) < COHERE_MAX_BATCH_SIZE
and batch_tokens < COHERE_MAX_TOKENS_PER_BATCH
):
# Count tokens for current sentence
n_tokens = len(
self.client.tokenize(
text=texts[index], model=self.model_name
).tokens
)

# Check if adding this sentence would exceed token limit
if (
batch_tokens + n_tokens > COHERE_MAX_TOKENS_PER_BATCH
and len(batch) > 0
):
break

batch_tokens += n_tokens
batch.append(texts[index])
index += 1

for i in tqdm(range(0, len(texts), batch_size)):
batch_texts = texts[i : i + batch_size]
# Embed the batch with retry logic handled by client
embed_kwargs = {
"texts": batch_texts,
"texts": batch,
"model": self.model_name,
"input_type": "search_document",
"embedding_types": [self.embedding_type],
}
if self.output_dimension is not None:
embed_kwargs["output_dimension"] = self.output_dimension

response = self.client.embed(**embed_kwargs)
response = self._embed_func(**embed_kwargs)

# Get embeddings based on requested type
if self.embedding_type == "float":
Expand All @@ -236,7 +275,9 @@ def get_text_embeddings(
f"Embedding type {self.embedding_type} not allowed"
)
all_text_embeddings.append(torch.tensor(embeddings))
pbar.update(len(batch))

pbar.close()
all_text_embeddings = torch.cat(all_text_embeddings, dim=0)

# Post-process embeddings based on type
Expand Down Expand Up @@ -281,7 +322,7 @@ def get_image_embeddings(
if self.output_dimension is not None:
embed_kwargs["output_dimension"] = self.output_dimension

response = self.client.embed(**embed_kwargs)
response = self._embed_func(**embed_kwargs)

# Get embeddings based on requested type
if self.embedding_type == "float":
Expand Down Expand Up @@ -322,7 +363,7 @@ def get_image_embeddings(
if self.output_dimension is not None:
embed_kwargs["output_dimension"] = self.output_dimension

response = self.client.embed(**embed_kwargs)
response = self._embed_func(**embed_kwargs)

# Get embeddings based on requested type
if self.embedding_type == "float":
Expand Down