19 changes: 16 additions & 3 deletions lib/crewai-tools/src/crewai_tools/adapters/crewai_rag_adapter.py
@@ -3,13 +3,16 @@
 import hashlib
 from pathlib import Path
 from typing import Any, TypeAlias, TypedDict
+import uuid

 from crewai.rag.config.types import RagConfigType
 from crewai.rag.config.utils import get_rag_client
 from crewai.rag.core.base_client import BaseClient
 from crewai.rag.factory import create_client
+from crewai.rag.qdrant.config import QdrantConfig
 from crewai.rag.types import BaseRecord, SearchResult
 from pydantic import PrivateAttr
+from qdrant_client.models import VectorParams
 from typing_extensions import Unpack

 from crewai_tools.rag.data_types import DataType
@@ -52,7 +55,11 @@ def model_post_init(self, __context: Any) -> None:
             self._client = create_client(self.config)
         else:
             self._client = get_rag_client()
-        self._client.get_or_create_collection(collection_name=self.collection_name)
+        collection_params: dict[str, Any] = {"collection_name": self.collection_name}
+        if isinstance(self.config, QdrantConfig) and self.config.vectors_config:
+            if isinstance(self.config.vectors_config, VectorParams):
+                collection_params["vectors_config"] = self.config.vectors_config
+        self._client.get_or_create_collection(**collection_params)

     def query(
         self,
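For context, a minimal sketch of how a caller could supply the vectors_config that the hunk above forwards to get_or_create_collection. VectorParams and Distance are real qdrant_client types; the adapter class name, the keyword-style QdrantConfig construction, and the size/distance values are illustrative assumptions, not details taken from this PR.

from crewai.rag.qdrant.config import QdrantConfig
from qdrant_client.models import Distance, VectorParams

# Assumed: QdrantConfig accepts vectors_config as a constructor argument,
# and the vector size matches the embedding model in use.
config = QdrantConfig(vectors_config=VectorParams(size=1536, distance=Distance.COSINE))

# Hypothetical name for the adapter class defined in crewai_rag_adapter.py;
# when a VectorParams instance is set, the collection is created with it.
adapter = CrewAIRagAdapter(config=config, collection_name="my_collection")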
@@ -76,6 +83,8 @@
             if similarity_threshold is not None
             else self.similarity_threshold
         )
+        if self._client is None:
+            raise ValueError("Client is not initialized")

         results: list[SearchResult] = self._client.search(
             collection_name=self.collection_name,
@@ -201,9 +210,10 @@ def add(self, *args: ContentItem, **kwargs: Unpack[AddDocumentParams]) -> None:
                 if isinstance(arg, dict):
                     file_metadata.update(arg.get("metadata", {}))

-                chunk_id = hashlib.sha256(
+                chunk_hash = hashlib.sha256(
                     f"{file_result.doc_id}_{chunk_idx}_{file_chunk}".encode()
                 ).hexdigest()
+                chunk_id = str(uuid.UUID(chunk_hash[:32]))

                 documents.append(
                     {
@@ -251,9 +261,10 @@ def add(self, *args: ContentItem, **kwargs: Unpack[AddDocumentParams]) -> None:
                 if isinstance(arg, dict):
                     chunk_metadata.update(arg.get("metadata", {}))

-                chunk_id = hashlib.sha256(
+                chunk_hash = hashlib.sha256(
                     f"{loader_result.doc_id}_{i}_{chunk}".encode()
                 ).hexdigest()
+                chunk_id = str(uuid.UUID(chunk_hash[:32]))

                 documents.append(
                     {
@@ -264,6 +275,8 @@ def add(self, *args: ContentItem, **kwargs: Unpack[AddDocumentParams]) -> None:
                 )

         if documents:
+            if self._client is None:
+                raise ValueError("Client is not initialized")
             self._client.add_documents(
                 collection_name=self.collection_name, documents=documents
             )
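The recurring chunk-ID change in this file is easiest to read in isolation: the SHA-256 digest is still computed from the same "{doc_id}_{index}_{chunk}" string, but its first 32 hex characters (128 bits) are now reinterpreted as a UUID, which keeps IDs deterministic while fitting the UUID/unsigned-integer point-ID formats that stores such as Qdrant accept. A standalone sketch of that derivation (the function name and sample inputs are illustrative):

import hashlib
import uuid

def deterministic_chunk_id(doc_id: str, chunk_idx: int, chunk: str) -> str:
    # Same hash input shape as the adapter: "<doc_id>_<chunk_idx>_<chunk>".
    chunk_hash = hashlib.sha256(f"{doc_id}_{chunk_idx}_{chunk}".encode()).hexdigest()
    # 32 hex characters = 128 bits, exactly the width of a UUID.
    return str(uuid.UUID(chunk_hash[:32]))

# Re-adding the same chunk always produces the same ID, so repeated adds stay idempotent.
print(deterministic_chunk_id("doc-1", 0, "hello world"))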
34 changes: 6 additions & 28 deletions lib/crewai-tools/src/crewai_tools/rag/core.py
@@ -4,12 +4,12 @@
 from uuid import uuid4

 import chromadb
-import litellm
 from pydantic import BaseModel, Field, PrivateAttr

 from crewai_tools.rag.base_loader import BaseLoader
 from crewai_tools.rag.chunkers.base_chunker import BaseChunker
 from crewai_tools.rag.data_types import DataType
+from crewai_tools.rag.embedding_service import EmbeddingService
 from crewai_tools.rag.misc import compute_sha256
 from crewai_tools.rag.source_content import SourceContent
 from crewai_tools.tools.rag.rag_tool import Adapter
@@ -18,31 +18,6 @@
 logger = logging.getLogger(__name__)


-class EmbeddingService:
-    def __init__(self, model: str = "text-embedding-3-small", **kwargs):
-        self.model = model
-        self.kwargs = kwargs
-
-    def embed_text(self, text: str) -> list[float]:
-        try:
-            response = litellm.embedding(model=self.model, input=[text], **self.kwargs)
-            return response.data[0]["embedding"]
-        except Exception as e:
-            logger.error(f"Error generating embedding: {e}")
-            raise
-
-    def embed_batch(self, texts: list[str]) -> list[list[float]]:
-        if not texts:
-            return []
-
-        try:
-            response = litellm.embedding(model=self.model, input=texts, **self.kwargs)
-            return [data["embedding"] for data in response.data]
-        except Exception as e:
-            logger.error(f"Error generating batch embeddings: {e}")
-            raise
-
-
 class Document(BaseModel):
     id: str = Field(default_factory=lambda: str(uuid4()))
     content: str
@@ -54,6 +29,7 @@ class Document(BaseModel):
 class RAG(Adapter):
     collection_name: str = "crewai_knowledge_base"
     persist_directory: str | None = None
+    embedding_provider: str = "openai"
     embedding_model: str = "text-embedding-3-large"
     summarize: bool = False
     top_k: int = 5
@@ -79,7 +55,9 @@ def model_post_init(self, __context: Any) -> None:
             )

             self._embedding_service = EmbeddingService(
-                model=self.embedding_model, **self.embedding_config
+                provider=self.embedding_provider,
+                model=self.embedding_model,
+                **self.embedding_config,
             )
         except Exception as e:
             logger.error(f"Failed to initialize ChromaDB: {e}")
@@ -181,7 +159,7 @@ def add(
         except Exception as e:
             logger.error(f"Failed to add documents to ChromaDB: {e}")

-    def query(self, question: str, where: dict[str, Any] | None = None) -> str:
+    def query(self, question: str, where: dict[str, Any] | None = None) -> str:  # type: ignore
         try:
             question_embedding = self._embedding_service.embed_text(question)

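The core.py changes mostly relocate EmbeddingService into crewai_tools.rag.embedding_service and thread an explicit provider through to it. A minimal usage sketch, assuming the relocated class keeps the embed_text/embed_batch methods the removed in-file version exposed; the provider and model values mirror the RAG defaults declared above:

from crewai_tools.rag.embedding_service import EmbeddingService

# Assumed to accept the same provider/model keywords that RAG.model_post_init passes.
service = EmbeddingService(provider="openai", model="text-embedding-3-large")

query_vector = service.embed_text("How does the adapter build chunk IDs?")
chunk_vectors = service.embed_batch(["first chunk", "second chunk"])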