diff --git a/llama-index-integrations/embeddings/llama-index-embeddings-cohere/llama_index/embeddings/cohere/base.py b/llama-index-integrations/embeddings/llama-index-embeddings-cohere/llama_index/embeddings/cohere/base.py
index 881c27b722..bee1b7b159 100644
--- a/llama-index-integrations/embeddings/llama-index-embeddings-cohere/llama_index/embeddings/cohere/base.py
+++ b/llama-index-integrations/embeddings/llama-index-embeddings-cohere/llama_index/embeddings/cohere/base.py
@@ -117,6 +117,9 @@ class CohereAITruncate(str, Enum):
 # supported image formats
 SUPPORTED_IMAGE_FORMATS = {"png", "jpeg", "jpg", "webp", "gif"}
 
+# Maximum batch size for Cohere API
+MAX_EMBED_BATCH_SIZE = 96
+
 
 # Assuming BaseEmbedding is a Pydantic model and handles its own initializations
 class CohereEmbedding(MultiModalEmbedding):
@@ -171,6 +174,8 @@ def __init__(
                 'search_document', 'classification', or 'clustering'.
             model_name (str): The name of the model to be used for generating embeddings.
                 The class ensures that this model is supported and that the input type provided is compatible with the model.
+            embed_batch_size (int): The batch size for embedding generation. Maximum allowed value is 96 (MAX_EMBED_BATCH_SIZE)
+                due to Cohere API limitations. Defaults to DEFAULT_EMBED_BATCH_SIZE.
         """
 
         # Validate model_name and input_type
@@ -189,6 +194,12 @@ def __init__(
         if truncate not in VALID_TRUNCATE_OPTIONS:
             raise ValueError(f"truncate must be one of {VALID_TRUNCATE_OPTIONS}")
 
+        # Validate embed_batch_size
+        if embed_batch_size > MAX_EMBED_BATCH_SIZE:
+            raise ValueError(
+                f"embed_batch_size {embed_batch_size} exceeds the maximum allowed value of {MAX_EMBED_BATCH_SIZE} for Cohere API"
+            )
+
         super().__init__(
             api_key=api_key or cohere_api_key,
             model_name=model_name,
diff --git a/llama-index-integrations/embeddings/llama-index-embeddings-cohere/pyproject.toml b/llama-index-integrations/embeddings/llama-index-embeddings-cohere/pyproject.toml
index 5f31236a4d..b1774e7a6e 100644
--- a/llama-index-integrations/embeddings/llama-index-embeddings-cohere/pyproject.toml
+++ b/llama-index-integrations/embeddings/llama-index-embeddings-cohere/pyproject.toml
@@ -27,7 +27,7 @@ dev = [
 
 [project]
 name = "llama-index-embeddings-cohere"
-version = "0.5.0"
+version = "0.5.1"
 description = "llama-index embeddings cohere integration"
 authors = [{name = "Your Name", email = "you@example.com"}]
 requires-python = ">=3.9,<4.0"
diff --git a/llama-index-integrations/embeddings/llama-index-embeddings-cohere/tests/test_embeddings.py b/llama-index-integrations/embeddings/llama-index-embeddings-cohere/tests/test_embeddings.py
index e277abd8fb..56f7c4bb20 100644
--- a/llama-index-integrations/embeddings/llama-index-embeddings-cohere/tests/test_embeddings.py
+++ b/llama-index-integrations/embeddings/llama-index-embeddings-cohere/tests/test_embeddings.py
@@ -29,6 +29,25 @@ def test_sync_embedding():
     emb.get_query_embedding("I love Cohere!")
 
 
+@pytest.mark.skipif(
+    os.environ.get("CO_API_KEY") is None, reason="Cohere API key required"
+)
+def test_batch_size_validation():
+    """Test that batch size validation works correctly."""
+    # Test batch size exceeding the limit
+    with pytest.raises(ValueError) as exc_info:
+        CohereEmbedding(api_key=os.environ["CO_API_KEY"], embed_batch_size=97)
+    assert "exceeds the maximum allowed value of 96" in str(exc_info.value)
+
+    # Test batch size at the limit (should not raise)
+    emb = CohereEmbedding(api_key=os.environ["CO_API_KEY"], embed_batch_size=96)
+    assert emb.embed_batch_size == 96
+
+    # Test batch size below the limit (should not raise)
+    emb = CohereEmbedding(api_key=os.environ["CO_API_KEY"], embed_batch_size=50)
+    assert emb.embed_batch_size == 50
+
+
 @pytest.mark.skipif(
     os.environ.get("CO_API_KEY") is None, reason="Cohere API key required"
 )
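For reviewers, a minimal sketch of how the new guard behaves from the caller's side (assuming `CO_API_KEY` is set in the environment; the over-limit value 128 is an arbitrary example, not part of the change):

```python
import os

from llama_index.embeddings.cohere import CohereEmbedding

# At or below the limit: construction succeeds and the value is kept.
emb = CohereEmbedding(api_key=os.environ["CO_API_KEY"], embed_batch_size=96)
assert emb.embed_batch_size == 96

# Above the limit: the constructor now fails fast, before any API call is made.
try:
    CohereEmbedding(api_key=os.environ["CO_API_KEY"], embed_batch_size=128)
except ValueError as err:
    # embed_batch_size 128 exceeds the maximum allowed value of 96 for Cohere API
    print(err)
```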