Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,9 @@ class CohereAITruncate(str, Enum):
# Supported image formats
SUPPORTED_IMAGE_FORMATS = {"png", "jpeg", "jpg", "webp", "gif"}

# Upper bound on `embed_batch_size` for CohereEmbedding. The constructor
# raises ValueError for larger values; the limit is attributed to the Cohere API.
MAX_EMBED_BATCH_SIZE = 96


# Assuming BaseEmbedding is a Pydantic model and handles its own initializations
class CohereEmbedding(MultiModalEmbedding):
Expand Down Expand Up @@ -171,6 +174,8 @@ def __init__(
'search_document', 'classification', or 'clustering'.
model_name (str): The name of the model to be used for generating embeddings. The class ensures that
this model is supported and that the input type provided is compatible with the model.
embed_batch_size (int): The batch size for embedding generation. Maximum allowed value is 96 (MAX_EMBED_BATCH_SIZE)
due to Cohere API limitations. Defaults to DEFAULT_EMBED_BATCH_SIZE.

"""
# Validate model_name and input_type
Expand All @@ -189,6 +194,12 @@ def __init__(
if truncate not in VALID_TRUNCATE_OPTIONS:
raise ValueError(f"truncate must be one of {VALID_TRUNCATE_OPTIONS}")

# Validate embed_batch_size
if embed_batch_size > MAX_EMBED_BATCH_SIZE:
raise ValueError(
f"embed_batch_size {embed_batch_size} exceeds the maximum allowed value of {MAX_EMBED_BATCH_SIZE} for Cohere API"
)

super().__init__(
api_key=api_key or cohere_api_key,
model_name=model_name,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ dev = [

[project]
name = "llama-index-embeddings-cohere"
version = "0.5.0"
version = "0.5.1"
description = "llama-index embeddings cohere integration"
authors = [{name = "Your Name", email = "[email protected]"}]
requires-python = ">=3.9,<4.0"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,25 @@ def test_sync_embedding():
emb.get_query_embedding("I love Cohere!")


@pytest.mark.skipif(
    "CO_API_KEY" not in os.environ, reason="Cohere API key required"
)
def test_batch_size_validation():
    """Verify that CohereEmbedding enforces the upper bound on embed_batch_size."""
    key = os.environ["CO_API_KEY"]

    # One past the limit must be rejected with a descriptive error.
    with pytest.raises(ValueError) as rejected:
        CohereEmbedding(api_key=key, embed_batch_size=97)
    assert "exceeds the maximum allowed value of 96" in str(rejected.value)

    # The limit itself, and a value below it, are accepted and stored as given.
    for accepted_size in (96, 50):
        model = CohereEmbedding(api_key=key, embed_batch_size=accepted_size)
        assert model.embed_batch_size == accepted_size


@pytest.mark.skipif(
os.environ.get("CO_API_KEY") is None, reason="Cohere API key required"
)
Expand Down
Loading