diff --git a/lib/crewai-tools/src/crewai_tools/adapters/crewai_rag_adapter.py b/lib/crewai-tools/src/crewai_tools/adapters/crewai_rag_adapter.py index 1e719ed373..f23cea85ca 100644 --- a/lib/crewai-tools/src/crewai_tools/adapters/crewai_rag_adapter.py +++ b/lib/crewai-tools/src/crewai_tools/adapters/crewai_rag_adapter.py @@ -3,13 +3,16 @@ import hashlib from pathlib import Path from typing import Any, TypeAlias, TypedDict +import uuid from crewai.rag.config.types import RagConfigType from crewai.rag.config.utils import get_rag_client from crewai.rag.core.base_client import BaseClient from crewai.rag.factory import create_client +from crewai.rag.qdrant.config import QdrantConfig from crewai.rag.types import BaseRecord, SearchResult from pydantic import PrivateAttr +from qdrant_client.models import VectorParams from typing_extensions import Unpack from crewai_tools.rag.data_types import DataType @@ -52,7 +55,11 @@ def model_post_init(self, __context: Any) -> None: self._client = create_client(self.config) else: self._client = get_rag_client() - self._client.get_or_create_collection(collection_name=self.collection_name) + collection_params: dict[str, Any] = {"collection_name": self.collection_name} + if isinstance(self.config, QdrantConfig) and self.config.vectors_config: + if isinstance(self.config.vectors_config, VectorParams): + collection_params["vectors_config"] = self.config.vectors_config + self._client.get_or_create_collection(**collection_params) def query( self, @@ -76,6 +83,8 @@ def query( if similarity_threshold is not None else self.similarity_threshold ) + if self._client is None: + raise ValueError("Client is not initialized") results: list[SearchResult] = self._client.search( collection_name=self.collection_name, @@ -201,9 +210,10 @@ def add(self, *args: ContentItem, **kwargs: Unpack[AddDocumentParams]) -> None: if isinstance(arg, dict): file_metadata.update(arg.get("metadata", {})) - chunk_id = hashlib.sha256( + chunk_hash = hashlib.sha256( f"{file_result.doc_id}_{chunk_idx}_{file_chunk}".encode() ).hexdigest() + chunk_id = str(uuid.UUID(chunk_hash[:32])) documents.append( { @@ -251,9 +261,10 @@ def add(self, *args: ContentItem, **kwargs: Unpack[AddDocumentParams]) -> None: if isinstance(arg, dict): chunk_metadata.update(arg.get("metadata", {})) - chunk_id = hashlib.sha256( + chunk_hash = hashlib.sha256( f"{loader_result.doc_id}_{i}_{chunk}".encode() ).hexdigest() + chunk_id = str(uuid.UUID(chunk_hash[:32])) documents.append( { @@ -264,6 +275,8 @@ def add(self, *args: ContentItem, **kwargs: Unpack[AddDocumentParams]) -> None: ) if documents: + if self._client is None: + raise ValueError("Client is not initialized") self._client.add_documents( collection_name=self.collection_name, documents=documents ) diff --git a/lib/crewai-tools/src/crewai_tools/rag/core.py b/lib/crewai-tools/src/crewai_tools/rag/core.py index 9c731c223b..9b4b64d036 100644 --- a/lib/crewai-tools/src/crewai_tools/rag/core.py +++ b/lib/crewai-tools/src/crewai_tools/rag/core.py @@ -4,12 +4,12 @@ from uuid import uuid4 import chromadb -import litellm from pydantic import BaseModel, Field, PrivateAttr from crewai_tools.rag.base_loader import BaseLoader from crewai_tools.rag.chunkers.base_chunker import BaseChunker from crewai_tools.rag.data_types import DataType +from crewai_tools.rag.embedding_service import EmbeddingService from crewai_tools.rag.misc import compute_sha256 from crewai_tools.rag.source_content import SourceContent from crewai_tools.tools.rag.rag_tool import Adapter @@ -18,31 +18,6 @@ logger = logging.getLogger(__name__) -class EmbeddingService: - def __init__(self, model: str = "text-embedding-3-small", **kwargs): - self.model = model - self.kwargs = kwargs - - def embed_text(self, text: str) -> list[float]: - try: - response = litellm.embedding(model=self.model, input=[text], **self.kwargs) - return response.data[0]["embedding"] - except Exception as e: - logger.error(f"Error generating embedding: {e}") - raise - - def embed_batch(self, texts: list[str]) -> list[list[float]]: - if not texts: - return [] - - try: - response = litellm.embedding(model=self.model, input=texts, **self.kwargs) - return [data["embedding"] for data in response.data] - except Exception as e: - logger.error(f"Error generating batch embeddings: {e}") - raise - - class Document(BaseModel): id: str = Field(default_factory=lambda: str(uuid4())) content: str @@ -54,6 +29,7 @@ class Document(BaseModel): class RAG(Adapter): collection_name: str = "crewai_knowledge_base" persist_directory: str | None = None + embedding_provider: str = "openai" embedding_model: str = "text-embedding-3-large" summarize: bool = False top_k: int = 5 @@ -79,7 +55,9 @@ def model_post_init(self, __context: Any) -> None: ) self._embedding_service = EmbeddingService( - model=self.embedding_model, **self.embedding_config + provider=self.embedding_provider, + model=self.embedding_model, + **self.embedding_config, ) except Exception as e: logger.error(f"Failed to initialize ChromaDB: {e}") @@ -181,7 +159,7 @@ def add( except Exception as e: logger.error(f"Failed to add documents to ChromaDB: {e}") - def query(self, question: str, where: dict[str, Any] | None = None) -> str: + def query(self, question: str, where: dict[str, Any] | None = None) -> str: # type: ignore try: question_embedding = self._embedding_service.embed_text(question) diff --git a/lib/crewai-tools/src/crewai_tools/rag/embedding_service.py b/lib/crewai-tools/src/crewai_tools/rag/embedding_service.py new file mode 100644 index 0000000000..9dd146668c --- /dev/null +++ b/lib/crewai-tools/src/crewai_tools/rag/embedding_service.py @@ -0,0 +1,508 @@ +""" +Enhanced embedding service that leverages CrewAI's existing embedding providers. +This replaces the litellm-based EmbeddingService with a more flexible architecture. +""" + +import logging +import os +from typing import Any + +from pydantic import BaseModel, Field + + +logger = logging.getLogger(__name__) + + +class EmbeddingConfig(BaseModel): + """Configuration for embedding providers.""" + + provider: str = Field(description="Embedding provider name") + model: str = Field(description="Model name to use") + api_key: str | None = Field(default=None, description="API key for the provider") + timeout: float | None = Field( + default=30.0, description="Request timeout in seconds" + ) + max_retries: int = Field(default=3, description="Maximum number of retries") + batch_size: int = Field( + default=100, description="Batch size for processing multiple texts" + ) + extra_config: dict[str, Any] = Field( + default_factory=dict, description="Additional provider-specific configuration" + ) + + +class EmbeddingService: + """ + Enhanced embedding service that uses CrewAI's existing embedding providers. + + Supports multiple providers: + - openai: OpenAI embeddings (text-embedding-3-small, text-embedding-3-large, etc.) + - voyageai: Voyage AI embeddings (voyage-2, voyage-large-2, etc.) + - cohere: Cohere embeddings (embed-english-v3.0, embed-multilingual-v3.0, etc.) + - google-generativeai: Google Gemini embeddings (models/embedding-001, etc.) + - google-vertex: Google Vertex embeddings (models/embedding-001, etc.) + - huggingface: Hugging Face embeddings (sentence-transformers/all-MiniLM-L6-v2, etc.) + - jina: Jina embeddings (jina-embeddings-v2-base-en, etc.) + - ollama: Ollama embeddings (nomic-embed-text, etc.) + - openai: OpenAI embeddings (text-embedding-3-small, text-embedding-3-large, etc.) + - roboflow: Roboflow embeddings (roboflow-embeddings-v2-base-en, etc.) + - voyageai: Voyage AI embeddings (voyage-2, voyage-large-2, etc.) + - watsonx: Watson X embeddings (ibm/slate-125m-english-rtrvr, etc.) + - custom: Custom embeddings (embedding_callable, etc.) + - sentence-transformer: Sentence Transformers embeddings (all-MiniLM-L6-v2, etc.) + - text2vec: Text2Vec embeddings (text2vec-base-en, etc.) + - openclip: OpenClip embeddings (openclip-large-v2, etc.) + - instructor: Instructor embeddings (hkunlp/instructor-large, etc.) + - onnx: ONNX embeddings (onnx-large-v2, etc.) + """ + + def __init__( + self, + provider: str = "openai", + model: str = "text-embedding-3-small", + api_key: str | None = None, + **kwargs: Any, + ): + """ + Initialize the embedding service. + + Args: + provider: The embedding provider to use + model: The model name + api_key: API key (if not provided, will look for environment variables) + **kwargs: Additional configuration options + """ + self.config = EmbeddingConfig( + provider=provider, + model=model, + api_key=api_key or self._get_default_api_key(provider), + **kwargs, + ) + + self._embedding_function = None + self._initialize_embedding_function() + + def _get_default_api_key(self, provider: str) -> str | None: + """Get default API key from environment variables.""" + env_key_map = { + "azure": "AZURE_OPENAI_API_KEY", + "amazon-bedrock": "AWS_ACCESS_KEY_ID", # or AWS_PROFILE + "cohere": "COHERE_API_KEY", + "google-generativeai": "GOOGLE_API_KEY", + "google-vertex": "GOOGLE_APPLICATION_CREDENTIALS", + "huggingface": "HUGGINGFACE_API_KEY", + "jina": "JINA_API_KEY", + "ollama": None, # Ollama typically runs locally without API key + "openai": "OPENAI_API_KEY", + "roboflow": "ROBOFLOW_API_KEY", + "voyageai": "VOYAGE_API_KEY", + "watsonx": "WATSONX_API_KEY", + } + + env_key = env_key_map.get(provider) + if env_key: + return os.getenv(env_key) + return None + + def _initialize_embedding_function(self): + """Initialize the embedding function using CrewAI's factory.""" + try: + from crewai.rag.embeddings.factory import build_embedder + + # Build the configuration for CrewAI's factory + config = self._build_provider_config() + + # Create the embedding function + self._embedding_function = build_embedder(config) + + logger.info( + f"Initialized {self.config.provider} embedding service with model " + f"{self.config.model}" + ) + + except ImportError as e: + raise ImportError( + f"CrewAI embedding providers not available. " + f"Make sure crewai is installed: {e}" + ) from e + except Exception as e: + logger.error(f"Failed to initialize embedding function: {e}") + raise RuntimeError( + f"Failed to initialize {self.config.provider} embedding service: {e}" + ) from e + + def _build_provider_config(self) -> dict[str, Any]: + """Build configuration dictionary for CrewAI's embedding factory.""" + base_config = {"provider": self.config.provider, "config": {}} + + # Provider-specific configuration mapping + if self.config.provider == "openai": + base_config["config"] = { + "api_key": self.config.api_key, + "model_name": self.config.model, + **self.config.extra_config, + } + elif self.config.provider == "azure": + base_config["config"] = { + "api_key": self.config.api_key, + "model_name": self.config.model, + **self.config.extra_config, + } + elif self.config.provider == "voyageai": + base_config["config"] = { + "api_key": self.config.api_key, + "model": self.config.model, + "max_retries": self.config.max_retries, + "timeout": self.config.timeout, + **self.config.extra_config, + } + elif self.config.provider == "cohere": + base_config["config"] = { + "api_key": self.config.api_key, + "model_name": self.config.model, + **self.config.extra_config, + } + elif self.config.provider in ["google-generativeai", "google-vertex"]: + base_config["config"] = { + "api_key": self.config.api_key, + "model_name": self.config.model, + **self.config.extra_config, + } + elif self.config.provider == "amazon-bedrock": + base_config["config"] = { + "aws_access_key_id": self.config.api_key, + "model_name": self.config.model, + **self.config.extra_config, + } + elif self.config.provider == "huggingface": + base_config["config"] = { + "api_key": self.config.api_key, + "model_name": self.config.model, + **self.config.extra_config, + } + elif self.config.provider == "jina": + base_config["config"] = { + "api_key": self.config.api_key, + "model_name": self.config.model, + **self.config.extra_config, + } + elif self.config.provider == "ollama": + base_config["config"] = { + "model": self.config.model, + **self.config.extra_config, + } + elif self.config.provider == "sentence-transformer": + base_config["config"] = { + "model_name": self.config.model, + **self.config.extra_config, + } + elif self.config.provider == "instructor": + base_config["config"] = { + "model_name": self.config.model, + **self.config.extra_config, + } + elif self.config.provider == "onnx": + base_config["config"] = { + **self.config.extra_config, + } + elif self.config.provider == "roboflow": + base_config["config"] = { + "api_key": self.config.api_key, + **self.config.extra_config, + } + elif self.config.provider == "openclip": + base_config["config"] = { + "model_name": self.config.model, + **self.config.extra_config, + } + elif self.config.provider == "text2vec": + base_config["config"] = { + "model_name": self.config.model, + **self.config.extra_config, + } + elif self.config.provider == "watsonx": + base_config["config"] = { + "api_key": self.config.api_key, + "model_name": self.config.model, + **self.config.extra_config, + } + elif self.config.provider == "custom": + # Custom provider requires embedding_callable in extra_config + base_config["config"] = { + **self.config.extra_config, + } + else: + # Generic configuration for any unlisted providers + base_config["config"] = { + "api_key": self.config.api_key, + "model": self.config.model, + **self.config.extra_config, + } + + return base_config + + def embed_text(self, text: str) -> list[float]: + """ + Generate embedding for a single text. + + Args: + text: Text to embed + + Returns: + List of floats representing the embedding + + Raises: + RuntimeError: If embedding generation fails + """ + if not text or not text.strip(): + logger.warning("Empty text provided for embedding") + return [] + + try: + # Use ChromaDB's embedding function interface + embeddings = self._embedding_function([text]) # type: ignore + return embeddings[0] if embeddings else [] + + except Exception as e: + logger.error(f"Error generating embedding for text: {e}") + raise RuntimeError(f"Failed to generate embedding: {e}") from e + + def embed_batch(self, texts: list[str]) -> list[list[float]]: + """ + Generate embeddings for multiple texts. + + Args: + texts: List of texts to embed + + Returns: + List of embedding vectors + + Raises: + RuntimeError: If embedding generation fails + """ + if not texts: + return [] + + # Filter out empty texts + valid_texts = [text for text in texts if text and text.strip()] + if not valid_texts: + logger.warning("No valid texts provided for batch embedding") + return [] + + try: + # Process in batches to avoid API limits + all_embeddings = [] + + for i in range(0, len(valid_texts), self.config.batch_size): + batch = valid_texts[i : i + self.config.batch_size] + batch_embeddings = self._embedding_function(batch) # type: ignore + all_embeddings.extend(batch_embeddings) + + return all_embeddings + + except Exception as e: + logger.error(f"Error generating batch embeddings: {e}") + raise RuntimeError(f"Failed to generate batch embeddings: {e}") from e + + def get_embedding_dimension(self) -> int | None: + """ + Get the dimension of embeddings produced by this service. + + Returns: + Embedding dimension or None if unknown + """ + # Try to get dimension by generating a test embedding + try: + test_embedding = self.embed_text("test") + return len(test_embedding) if test_embedding else None + except Exception: + logger.warning("Could not determine embedding dimension") + return None + + def validate_connection(self) -> bool: + """ + Validate that the embedding service is working correctly. + + Returns: + True if the service is working, False otherwise + """ + try: + test_embedding = self.embed_text("test connection") + return len(test_embedding) > 0 + except Exception as e: + logger.error(f"Connection validation failed: {e}") + return False + + def get_service_info(self) -> dict[str, Any]: + """ + Get information about the current embedding service. + + Returns: + Dictionary with service information + """ + return { + "provider": self.config.provider, + "model": self.config.model, + "embedding_dimension": self.get_embedding_dimension(), + "batch_size": self.config.batch_size, + "is_connected": self.validate_connection(), + } + + @classmethod + def list_supported_providers(cls) -> list[str]: + """ + List all supported embedding providers. + + Returns: + List of supported provider names + """ + return [ + "azure", + "amazon-bedrock", + "cohere", + "custom", + "google-generativeai", + "google-vertex", + "huggingface", + "instructor", + "jina", + "ollama", + "onnx", + "openai", + "openclip", + "roboflow", + "sentence-transformer", + "text2vec", + "voyageai", + "watsonx", + ] + + @classmethod + def create_openai_service( + cls, + model: str = "text-embedding-3-small", + api_key: str | None = None, + **kwargs: Any, + ) -> "EmbeddingService": + """Create an OpenAI embedding service.""" + return cls(provider="openai", model=model, api_key=api_key, **kwargs) + + @classmethod + def create_voyage_service( + cls, model: str = "voyage-2", api_key: str | None = None, **kwargs: Any + ) -> "EmbeddingService": + """Create a Voyage AI embedding service.""" + return cls(provider="voyageai", model=model, api_key=api_key, **kwargs) + + @classmethod + def create_cohere_service( + cls, + model: str = "embed-english-v3.0", + api_key: str | None = None, + **kwargs: Any, + ) -> "EmbeddingService": + """Create a Cohere embedding service.""" + return cls(provider="cohere", model=model, api_key=api_key, **kwargs) + + @classmethod + def create_gemini_service( + cls, + model: str = "models/embedding-001", + api_key: str | None = None, + **kwargs: Any, + ) -> "EmbeddingService": + """Create a Google Gemini embedding service.""" + return cls( + provider="google-generativeai", model=model, api_key=api_key, **kwargs + ) + + @classmethod + def create_azure_service( + cls, + model: str = "text-embedding-ada-002", + api_key: str | None = None, + **kwargs: Any, + ) -> "EmbeddingService": + """Create an Azure OpenAI embedding service.""" + return cls(provider="azure", model=model, api_key=api_key, **kwargs) + + @classmethod + def create_bedrock_service( + cls, + model: str = "amazon.titan-embed-text-v1", + api_key: str | None = None, + **kwargs: Any, + ) -> "EmbeddingService": + """Create an Amazon Bedrock embedding service.""" + return cls(provider="amazon-bedrock", model=model, api_key=api_key, **kwargs) + + @classmethod + def create_huggingface_service( + cls, + model: str = "sentence-transformers/all-MiniLM-L6-v2", + api_key: str | None = None, + **kwargs: Any, + ) -> "EmbeddingService": + """Create a Hugging Face embedding service.""" + return cls(provider="huggingface", model=model, api_key=api_key, **kwargs) + + @classmethod + def create_sentence_transformer_service( + cls, + model: str = "all-MiniLM-L6-v2", + **kwargs: Any, + ) -> "EmbeddingService": + """Create a Sentence Transformers embedding service (local).""" + return cls(provider="sentence-transformer", model=model, **kwargs) + + @classmethod + def create_ollama_service( + cls, + model: str = "nomic-embed-text", + **kwargs: Any, + ) -> "EmbeddingService": + """Create an Ollama embedding service (local).""" + return cls(provider="ollama", model=model, **kwargs) + + @classmethod + def create_jina_service( + cls, + model: str = "jina-embeddings-v2-base-en", + api_key: str | None = None, + **kwargs: Any, + ) -> "EmbeddingService": + """Create a Jina AI embedding service.""" + return cls(provider="jina", model=model, api_key=api_key, **kwargs) + + @classmethod + def create_instructor_service( + cls, + model: str = "hkunlp/instructor-large", + **kwargs: Any, + ) -> "EmbeddingService": + """Create an Instructor embedding service.""" + return cls(provider="instructor", model=model, **kwargs) + + @classmethod + def create_watsonx_service( + cls, + model: str = "ibm/slate-125m-english-rtrvr", + api_key: str | None = None, + **kwargs: Any, + ) -> "EmbeddingService": + """Create a Watson X embedding service.""" + return cls(provider="watsonx", model=model, api_key=api_key, **kwargs) + + @classmethod + def create_custom_service( + cls, + embedding_callable: Any, + **kwargs: Any, + ) -> "EmbeddingService": + """Create a custom embedding service with your own embedding function.""" + return cls( + provider="custom", + model="custom", + extra_config={"embedding_callable": embedding_callable}, + **kwargs, + ) diff --git a/lib/crewai-tools/tests/rag/test_embedding_service.py b/lib/crewai-tools/tests/rag/test_embedding_service.py new file mode 100644 index 0000000000..c6c74fdf12 --- /dev/null +++ b/lib/crewai-tools/tests/rag/test_embedding_service.py @@ -0,0 +1,342 @@ +""" +Tests for the enhanced embedding service. +""" + +import os +import pytest +from unittest.mock import Mock, patch + +from crewai_tools.rag.embedding_service import EmbeddingService, EmbeddingConfig + + +class TestEmbeddingConfig: + """Test the EmbeddingConfig model.""" + + def test_default_config(self): + """Test default configuration values.""" + config = EmbeddingConfig(provider="openai", model="text-embedding-3-small") + + assert config.provider == "openai" + assert config.model == "text-embedding-3-small" + assert config.api_key is None + assert config.timeout == 30.0 + assert config.max_retries == 3 + assert config.batch_size == 100 + assert config.extra_config == {} + + def test_custom_config(self): + """Test custom configuration values.""" + config = EmbeddingConfig( + provider="voyageai", + model="voyage-2", + api_key="test-key", + timeout=60.0, + max_retries=5, + batch_size=50, + extra_config={"input_type": "document"} + ) + + assert config.provider == "voyageai" + assert config.model == "voyage-2" + assert config.api_key == "test-key" + assert config.timeout == 60.0 + assert config.max_retries == 5 + assert config.batch_size == 50 + assert config.extra_config == {"input_type": "document"} + + +class TestEmbeddingService: + """Test the EmbeddingService class.""" + + def test_list_supported_providers(self): + """Test listing supported providers.""" + providers = EmbeddingService.list_supported_providers() + expected_providers = [ + "openai", "azure", "voyageai", "cohere", "google-generativeai", + "amazon-bedrock", "huggingface", "jina", "ollama", "sentence-transformer", + "instructor", "watsonx", "custom" + ] + + assert isinstance(providers, list) + assert len(providers) >= 15 # Should have at least 15 providers + assert all(provider in providers for provider in expected_providers) + + def test_get_default_api_key(self): + """Test getting default API keys from environment.""" + service = EmbeddingService.__new__(EmbeddingService) # Create without __init__ + + # Test with environment variable set + with patch.dict(os.environ, {"OPENAI_API_KEY": "test-openai-key"}): + api_key = service._get_default_api_key("openai") + assert api_key == "test-openai-key" + + # Test with no environment variable + with patch.dict(os.environ, {}, clear=True): + api_key = service._get_default_api_key("openai") + assert api_key is None + + # Test unknown provider + api_key = service._get_default_api_key("unknown-provider") + assert api_key is None + + @patch('crewai.rag.embeddings.factory.build_embedder') + def test_initialization_success(self, mock_build_embedder): + """Test successful initialization.""" + # Mock the embedding function + mock_embedding_function = Mock() + mock_build_embedder.return_value = mock_embedding_function + + service = EmbeddingService( + provider="openai", + model="text-embedding-3-small", + api_key="test-key" + ) + + assert service.config.provider == "openai" + assert service.config.model == "text-embedding-3-small" + assert service.config.api_key == "test-key" + assert service._embedding_function == mock_embedding_function + + # Verify build_embedder was called with correct config + mock_build_embedder.assert_called_once() + call_args = mock_build_embedder.call_args[0][0] + assert call_args["provider"] == "openai" + assert call_args["config"]["api_key"] == "test-key" + assert call_args["config"]["model_name"] == "text-embedding-3-small" + + @patch('crewai.rag.embeddings.factory.build_embedder') + def test_initialization_import_error(self, mock_build_embedder): + """Test initialization with import error.""" + mock_build_embedder.side_effect = ImportError("CrewAI not installed") + + with pytest.raises(ImportError, match="CrewAI embedding providers not available"): + EmbeddingService(provider="openai", model="test-model", api_key="test-key") + + @patch('crewai.rag.embeddings.factory.build_embedder') + def test_embed_text_success(self, mock_build_embedder): + """Test successful text embedding.""" + # Mock the embedding function + mock_embedding_function = Mock() + mock_embedding_function.return_value = [[0.1, 0.2, 0.3]] + mock_build_embedder.return_value = mock_embedding_function + + service = EmbeddingService(provider="openai", model="test-model", api_key="test-key") + + result = service.embed_text("test text") + + assert result == [0.1, 0.2, 0.3] + mock_embedding_function.assert_called_once_with(["test text"]) + + @patch('crewai.rag.embeddings.factory.build_embedder') + def test_embed_text_empty_input(self, mock_build_embedder): + """Test embedding empty text.""" + mock_embedding_function = Mock() + mock_build_embedder.return_value = mock_embedding_function + + service = EmbeddingService(provider="openai", model="test-model", api_key="test-key") + + result = service.embed_text("") + assert result == [] + + result = service.embed_text(" ") + assert result == [] + + # Embedding function should not be called for empty text + mock_embedding_function.assert_not_called() + + @patch('crewai.rag.embeddings.factory.build_embedder') + def test_embed_batch_success(self, mock_build_embedder): + """Test successful batch embedding.""" + # Mock the embedding function + mock_embedding_function = Mock() + mock_embedding_function.return_value = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]] + mock_build_embedder.return_value = mock_embedding_function + + service = EmbeddingService(provider="openai", model="test-model", api_key="test-key") + + texts = ["text1", "text2", "text3"] + result = service.embed_batch(texts) + + assert result == [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]] + mock_embedding_function.assert_called_once_with(texts) + + @patch('crewai.rag.embeddings.factory.build_embedder') + def test_embed_batch_empty_input(self, mock_build_embedder): + """Test batch embedding with empty input.""" + mock_embedding_function = Mock() + mock_build_embedder.return_value = mock_embedding_function + + service = EmbeddingService(provider="openai", model="test-model", api_key="test-key") + + # Empty list + result = service.embed_batch([]) + assert result == [] + + # List with empty strings + result = service.embed_batch(["", " ", ""]) + assert result == [] + + # Embedding function should not be called for empty input + mock_embedding_function.assert_not_called() + + @patch('crewai.rag.embeddings.factory.build_embedder') + def test_validate_connection(self, mock_build_embedder): + """Test connection validation.""" + # Mock successful embedding + mock_embedding_function = Mock() + mock_embedding_function.return_value = [[0.1, 0.2, 0.3]] + mock_build_embedder.return_value = mock_embedding_function + + service = EmbeddingService(provider="openai", model="test-model", api_key="test-key") + + assert service.validate_connection() is True + + # Mock failed embedding + mock_embedding_function.side_effect = Exception("Connection failed") + assert service.validate_connection() is False + + @patch('crewai.rag.embeddings.factory.build_embedder') + def test_get_service_info(self, mock_build_embedder): + """Test getting service information.""" + # Mock the embedding function + mock_embedding_function = Mock() + mock_embedding_function.return_value = [[0.1, 0.2, 0.3]] + mock_build_embedder.return_value = mock_embedding_function + + service = EmbeddingService(provider="openai", model="test-model", api_key="test-key") + + info = service.get_service_info() + + assert info["provider"] == "openai" + assert info["model"] == "test-model" + assert info["embedding_dimension"] == 3 + assert info["batch_size"] == 100 + assert info["is_connected"] is True + + def test_create_openai_service(self): + """Test OpenAI service creation.""" + with patch('crewai.rag.embeddings.factory.build_embedder'): + service = EmbeddingService.create_openai_service( + model="text-embedding-3-large", + api_key="test-key" + ) + + assert service.config.provider == "openai" + assert service.config.model == "text-embedding-3-large" + assert service.config.api_key == "test-key" + + def test_create_voyage_service(self): + """Test Voyage AI service creation.""" + with patch('crewai.rag.embeddings.factory.build_embedder'): + service = EmbeddingService.create_voyage_service( + model="voyage-large-2", + api_key="test-key" + ) + + assert service.config.provider == "voyageai" + assert service.config.model == "voyage-large-2" + assert service.config.api_key == "test-key" + + def test_create_cohere_service(self): + """Test Cohere service creation.""" + with patch('crewai.rag.embeddings.factory.build_embedder'): + service = EmbeddingService.create_cohere_service( + model="embed-multilingual-v3.0", + api_key="test-key" + ) + + assert service.config.provider == "cohere" + assert service.config.model == "embed-multilingual-v3.0" + assert service.config.api_key == "test-key" + + def test_create_gemini_service(self): + """Test Gemini service creation.""" + with patch('crewai.rag.embeddings.factory.build_embedder'): + service = EmbeddingService.create_gemini_service( + model="models/embedding-001", + api_key="test-key" + ) + + assert service.config.provider == "google-generativeai" + assert service.config.model == "models/embedding-001" + assert service.config.api_key == "test-key" + + +class TestProviderConfigurations: + """Test provider-specific configurations.""" + + @patch('crewai.rag.embeddings.factory.build_embedder') + def test_openai_config(self, mock_build_embedder): + """Test OpenAI configuration mapping.""" + mock_build_embedder.return_value = Mock() + + service = EmbeddingService( + provider="openai", + model="text-embedding-3-small", + api_key="test-key", + extra_config={"dimensions": 1024} + ) + + # Check the configuration passed to build_embedder + call_args = mock_build_embedder.call_args[0][0] + assert call_args["provider"] == "openai" + assert call_args["config"]["api_key"] == "test-key" + assert call_args["config"]["model_name"] == "text-embedding-3-small" + assert call_args["config"]["dimensions"] == 1024 + + @patch('crewai.rag.embeddings.factory.build_embedder') + def test_voyageai_config(self, mock_build_embedder): + """Test Voyage AI configuration mapping.""" + mock_build_embedder.return_value = Mock() + + service = EmbeddingService( + provider="voyageai", + model="voyage-2", + api_key="test-key", + timeout=60.0, + max_retries=5, + extra_config={"input_type": "document"} + ) + + # Check the configuration passed to build_embedder + call_args = mock_build_embedder.call_args[0][0] + assert call_args["provider"] == "voyageai" + assert call_args["config"]["api_key"] == "test-key" + assert call_args["config"]["model"] == "voyage-2" + assert call_args["config"]["timeout"] == 60.0 + assert call_args["config"]["max_retries"] == 5 + assert call_args["config"]["input_type"] == "document" + + @patch('crewai.rag.embeddings.factory.build_embedder') + def test_cohere_config(self, mock_build_embedder): + """Test Cohere configuration mapping.""" + mock_build_embedder.return_value = Mock() + + service = EmbeddingService( + provider="cohere", + model="embed-english-v3.0", + api_key="test-key" + ) + + # Check the configuration passed to build_embedder + call_args = mock_build_embedder.call_args[0][0] + assert call_args["provider"] == "cohere" + assert call_args["config"]["api_key"] == "test-key" + assert call_args["config"]["model_name"] == "embed-english-v3.0" + + @patch('crewai.rag.embeddings.factory.build_embedder') + def test_gemini_config(self, mock_build_embedder): + """Test Gemini configuration mapping.""" + mock_build_embedder.return_value = Mock() + + service = EmbeddingService( + provider="google-generativeai", + model="models/embedding-001", + api_key="test-key" + ) + + # Check the configuration passed to build_embedder + call_args = mock_build_embedder.call_args[0][0] + assert call_args["provider"] == "google-generativeai" + assert call_args["config"]["api_key"] == "test-key" + assert call_args["config"]["model_name"] == "models/embedding-001" diff --git a/lib/crewai/src/crewai/rag/qdrant/config.py b/lib/crewai/src/crewai/rag/qdrant/config.py index 316708b807..0926c33850 100644 --- a/lib/crewai/src/crewai/rag/qdrant/config.py +++ b/lib/crewai/src/crewai/rag/qdrant/config.py @@ -4,6 +4,7 @@ from typing import Literal, cast from pydantic.dataclasses import dataclass as pyd_dataclass +from qdrant_client.models import VectorParams from crewai.rag.config.base import BaseRagConfig from crewai.rag.qdrant.constants import DEFAULT_EMBEDDING_MODEL, DEFAULT_STORAGE_PATH @@ -53,3 +54,4 @@ class QdrantConfig(BaseRagConfig): embedding_function: QdrantEmbeddingFunctionWrapper = field( default_factory=_default_embedding_function ) + vectors_config: VectorParams | None = field(default=None) diff --git a/lib/crewai/src/crewai/rag/qdrant/utils.py b/lib/crewai/src/crewai/rag/qdrant/utils.py index 01afd31efd..a535fa9a40 100644 --- a/lib/crewai/src/crewai/rag/qdrant/utils.py +++ b/lib/crewai/src/crewai/rag/qdrant/utils.py @@ -4,8 +4,8 @@ from typing import TypeGuard from uuid import uuid4 -from qdrant_client import AsyncQdrantClient # type: ignore[import-not-found] from qdrant_client import ( + AsyncQdrantClient, # type: ignore[import-not-found] QdrantClient as SyncQdrantClient, # type: ignore[import-not-found] ) from qdrant_client.models import ( # type: ignore[import-not-found]