diff --git a/litellm/llms/base_llm/vector_store/transformation.py b/litellm/llms/base_llm/vector_store/transformation.py index 89f2094d5d..935fd53c19 100644 --- a/litellm/llms/base_llm/vector_store/transformation.py +++ b/litellm/llms/base_llm/vector_store/transformation.py @@ -5,8 +5,8 @@ from litellm.types.router import GenericLiteLLMParams from litellm.types.vector_stores import ( - BaseVectorStoreAuthCredentials, VECTOR_STORE_OPENAI_PARAMS, + BaseVectorStoreAuthCredentials, VectorStoreCreateOptionalRequestParams, VectorStoreCreateResponse, VectorStoreIndexEndpoints, @@ -64,6 +64,30 @@ def transform_search_vector_store_request( pass + async def atransform_search_vector_store_request( + self, + vector_store_id: str, + query: Union[str, List[str]], + vector_store_search_optional_params: VectorStoreSearchOptionalRequestParams, + api_base: str, + litellm_logging_obj: LiteLLMLoggingObj, + litellm_params: dict, + ) -> Tuple[str, Dict]: + """ + Optional async version of transform_search_vector_store_request. + If not implemented, the handler will fall back to the sync version. + Providers that need to make async calls (e.g., generating embeddings) should override this. 
+ """ + # Default implementation: call the sync version + return self.transform_search_vector_store_request( + vector_store_id=vector_store_id, + query=query, + vector_store_search_optional_params=vector_store_search_optional_params, + api_base=api_base, + litellm_logging_obj=litellm_logging_obj, + litellm_params=litellm_params, + ) + @abstractmethod def transform_search_vector_store_response( self, response: httpx.Response, litellm_logging_obj: LiteLLMLoggingObj diff --git a/litellm/llms/bedrock/base_aws_llm.py b/litellm/llms/bedrock/base_aws_llm.py index 642d15fe3e..1de1c40c43 100644 --- a/litellm/llms/bedrock/base_aws_llm.py +++ b/litellm/llms/bedrock/base_aws_llm.py @@ -1163,7 +1163,7 @@ def _filter_headers_for_aws_signature(self, headers: dict) -> dict: def _sign_request( self, - service_name: Literal["bedrock", "sagemaker", "bedrock-agentcore"], + service_name: Literal["bedrock", "sagemaker", "bedrock-agentcore", "s3vectors"], headers: dict, optional_params: dict, request_data: dict, diff --git a/litellm/llms/custom_httpx/llm_http_handler.py b/litellm/llms/custom_httpx/llm_http_handler.py index 6a87967c3a..d2ea7e872a 100644 --- a/litellm/llms/custom_httpx/llm_http_handler.py +++ b/litellm/llms/custom_httpx/llm_http_handler.py @@ -7033,17 +7033,31 @@ async def async_vector_store_search_handler( litellm_params=dict(litellm_params), ) - ( - url, - request_body, - ) = vector_store_provider_config.transform_search_vector_store_request( - vector_store_id=vector_store_id, - query=query, - vector_store_search_optional_params=vector_store_search_optional_params, - api_base=api_base, - litellm_logging_obj=logging_obj, - litellm_params=dict(litellm_params), - ) + # Check if provider has async transform method + if hasattr(vector_store_provider_config, "atransform_search_vector_store_request"): + ( + url, + request_body, + ) = await vector_store_provider_config.atransform_search_vector_store_request( + vector_store_id=vector_store_id, + query=query, + 
vector_store_search_optional_params=vector_store_search_optional_params, + api_base=api_base, + litellm_logging_obj=logging_obj, + litellm_params=dict(litellm_params), + ) + else: + ( + url, + request_body, + ) = vector_store_provider_config.transform_search_vector_store_request( + vector_store_id=vector_store_id, + query=query, + vector_store_search_optional_params=vector_store_search_optional_params, + api_base=api_base, + litellm_logging_obj=logging_obj, + litellm_params=dict(litellm_params), + ) all_optional_params: Dict[str, Any] = dict(litellm_params) all_optional_params.update(vector_store_search_optional_params or {}) headers, signed_json_body = vector_store_provider_config.sign_request( diff --git a/litellm/llms/s3_vectors/__init__.py b/litellm/llms/s3_vectors/__init__.py new file mode 100644 index 0000000000..e8367949c3 --- /dev/null +++ b/litellm/llms/s3_vectors/__init__.py @@ -0,0 +1 @@ +# S3 Vectors LLM integration diff --git a/litellm/llms/s3_vectors/vector_stores/__init__.py b/litellm/llms/s3_vectors/vector_stores/__init__.py new file mode 100644 index 0000000000..ac24b4a38d --- /dev/null +++ b/litellm/llms/s3_vectors/vector_stores/__init__.py @@ -0,0 +1 @@ +# S3 Vectors vector store integration diff --git a/litellm/llms/s3_vectors/vector_stores/transformation.py b/litellm/llms/s3_vectors/vector_stores/transformation.py new file mode 100644 index 0000000000..df81a78289 --- /dev/null +++ b/litellm/llms/s3_vectors/vector_stores/transformation.py @@ -0,0 +1,254 @@ +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union + +import httpx + +from litellm.llms.base_llm.vector_store.transformation import BaseVectorStoreConfig +from litellm.llms.bedrock.base_aws_llm import BaseAWSLLM +from litellm.types.router import GenericLiteLLMParams +from litellm.types.vector_stores import ( + VECTOR_STORE_OPENAI_PARAMS, + BaseVectorStoreAuthCredentials, + VectorStoreIndexEndpoints, + VectorStoreResultContent, + VectorStoreSearchOptionalRequestParams, + 
VectorStoreSearchResponse, + VectorStoreSearchResult, +) + +if TYPE_CHECKING: + from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj +else: + LiteLLMLoggingObj = Any + + +class S3VectorsVectorStoreConfig(BaseVectorStoreConfig, BaseAWSLLM): + """Vector store configuration for AWS S3 Vectors.""" + + def __init__(self) -> None: + BaseVectorStoreConfig.__init__(self) + BaseAWSLLM.__init__(self) + + def get_auth_credentials( + self, litellm_params: dict + ) -> BaseVectorStoreAuthCredentials: + return {} + + def get_vector_store_endpoints_by_type(self) -> VectorStoreIndexEndpoints: + return { + "read": [("POST", "/QueryVectors")], + "write": [], + } + + def get_supported_openai_params( + self, model: str + ) -> List[VECTOR_STORE_OPENAI_PARAMS]: + return ["max_num_results"] + + def map_openai_params( + self, + non_default_params: dict, + optional_params: dict, + drop_params: bool, + ) -> dict: + for param, value in non_default_params.items(): + if param == "max_num_results": + optional_params["maxResults"] = value + return optional_params + + def validate_environment( + self, headers: dict, litellm_params: Optional[GenericLiteLLMParams] + ) -> dict: + headers = headers or {} + headers.setdefault("Content-Type", "application/json") + return headers + + def get_complete_url(self, api_base: Optional[str], litellm_params: dict) -> str: + aws_region_name = litellm_params.get("aws_region_name") + if not aws_region_name: + raise ValueError("aws_region_name is required for S3 Vectors") + return f"https://s3vectors.{aws_region_name}.api.aws" + + def transform_search_vector_store_request( + self, + vector_store_id: str, + query: Union[str, List[str]], + vector_store_search_optional_params: VectorStoreSearchOptionalRequestParams, + api_base: str, + litellm_logging_obj: LiteLLMLoggingObj, + litellm_params: dict, + ) -> Tuple[str, Dict]: + """Sync version - generates embedding synchronously.""" + # For S3 Vectors, vector_store_id should be in format: 
bucket_name:index_name + # If not in that format, try to construct it from litellm_params + bucket_name: str + index_name: str + + if ":" in vector_store_id: + bucket_name, index_name = vector_store_id.split(":", 1) + else: + # Try to get bucket_name from litellm_params + bucket_name_from_params = litellm_params.get("vector_bucket_name") + if not bucket_name_from_params or not isinstance(bucket_name_from_params, str): + raise ValueError( + "vector_store_id must be in format 'bucket_name:index_name' for S3 Vectors, " + "or vector_bucket_name must be provided in litellm_params" + ) + bucket_name = bucket_name_from_params + index_name = vector_store_id + + if isinstance(query, list): + query = " ".join(query) + + # Generate embedding for the query + embedding_model = litellm_params.get("embedding_model", "text-embedding-3-small") + + import litellm as litellm_module + embedding_response = litellm_module.embedding(model=embedding_model, input=[query]) + query_embedding = embedding_response.data[0]["embedding"] + + url = f"{api_base}/QueryVectors" + + request_body: Dict[str, Any] = { + "vectorBucketName": bucket_name, + "indexName": index_name, + "queryVector": {"float32": query_embedding}, + "topK": vector_store_search_optional_params.get("max_num_results", 5), # Default to 5 + "returnDistance": True, + "returnMetadata": True, + } + + litellm_logging_obj.model_call_details["query"] = query + return url, request_body + + async def atransform_search_vector_store_request( + self, + vector_store_id: str, + query: Union[str, List[str]], + vector_store_search_optional_params: VectorStoreSearchOptionalRequestParams, + api_base: str, + litellm_logging_obj: LiteLLMLoggingObj, + litellm_params: dict, + ) -> Tuple[str, Dict]: + """Async version - generates embedding asynchronously.""" + # For S3 Vectors, vector_store_id should be in format: bucket_name:index_name + # If not in that format, try to construct it from litellm_params + bucket_name: str + index_name: str + + if ":" in 
vector_store_id: + bucket_name, index_name = vector_store_id.split(":", 1) + else: + # Try to get bucket_name from litellm_params + bucket_name_from_params = litellm_params.get("vector_bucket_name") + if not bucket_name_from_params or not isinstance(bucket_name_from_params, str): + raise ValueError( + "vector_store_id must be in format 'bucket_name:index_name' for S3 Vectors, " + "or vector_bucket_name must be provided in litellm_params" + ) + bucket_name = bucket_name_from_params + index_name = vector_store_id + + if isinstance(query, list): + query = " ".join(query) + + # Generate embedding for the query asynchronously + embedding_model = litellm_params.get("embedding_model", "text-embedding-3-small") + + import litellm as litellm_module + embedding_response = await litellm_module.aembedding(model=embedding_model, input=[query]) + query_embedding = embedding_response.data[0]["embedding"] + + url = f"{api_base}/QueryVectors" + + request_body: Dict[str, Any] = { + "vectorBucketName": bucket_name, + "indexName": index_name, + "queryVector": {"float32": query_embedding}, + "topK": vector_store_search_optional_params.get("max_num_results", 5), # Default to 5 + "returnDistance": True, + "returnMetadata": True, + } + + litellm_logging_obj.model_call_details["query"] = query + return url, request_body + + def sign_request( + self, + headers: dict, + optional_params: Dict, + request_data: Dict, + api_base: str, + api_key: Optional[str] = None, + ) -> Tuple[dict, Optional[bytes]]: + return self._sign_request( + service_name="s3vectors", + headers=headers, + optional_params=optional_params, + request_data=request_data, + api_base=api_base, + api_key=api_key, + ) + + def transform_search_vector_store_response( + self, response: httpx.Response, litellm_logging_obj: LiteLLMLoggingObj + ) -> VectorStoreSearchResponse: + try: + response_data = response.json() + results: List[VectorStoreSearchResult] = [] + + for item in response_data.get("vectors", []) or []: + metadata = 
item.get("metadata", {}) or {} + source_text = metadata.get("source_text", "") + + if not source_text: + continue + + # Extract file information from metadata + chunk_index = metadata.get("chunk_index", "0") + file_id = f"s3-vectors-chunk-{chunk_index}" + filename = metadata.get("filename", f"document-{chunk_index}") + + # S3 Vectors returns distance, convert to similarity score (0-1) + # Lower distance = higher similarity + # We'll normalize using 1 / (1 + distance) to get a 0-1 score + distance = item.get("distance") + score = None + if distance is not None: + # Convert distance to similarity score between 0 and 1 + # For cosine distance: similarity = 1 - distance + # For euclidean: use 1 / (1 + distance) + # Assuming cosine distance here + score = max(0.0, min(1.0, 1.0 - float(distance))) + + results.append( + VectorStoreSearchResult( + score=score, + content=[VectorStoreResultContent(text=source_text, type="text")], + file_id=file_id, + filename=filename, + attributes=metadata, + ) + ) + + return VectorStoreSearchResponse( + object="vector_store.search_results.page", + search_query=litellm_logging_obj.model_call_details.get("query", ""), + data=results, + ) + except Exception as e: + raise self.get_error_class( + error_message=str(e), + status_code=response.status_code, + headers=response.headers, + ) + + # Vector store creation is not yet implemented + def transform_create_vector_store_request( + self, + vector_store_create_optional_params, + api_base: str, + ) -> Tuple[str, Dict]: + raise NotImplementedError + + def transform_create_vector_store_response(self, response: httpx.Response): + raise NotImplementedError diff --git a/litellm/proxy/_experimental/out/assets/logos/s3_vector.png b/litellm/proxy/_experimental/out/assets/logos/s3_vector.png new file mode 100644 index 0000000000..15a1a456e1 Binary files /dev/null and b/litellm/proxy/_experimental/out/assets/logos/s3_vector.png differ diff --git a/litellm/proxy/rag_endpoints/endpoints.py 
b/litellm/proxy/rag_endpoints/endpoints.py index 79f182817f..2c34457c02 100644 --- a/litellm/proxy/rag_endpoints/endpoints.py +++ b/litellm/proxy/rag_endpoints/endpoints.py @@ -129,6 +129,14 @@ async def _save_vector_store_to_db_from_rag_ingest( litellm_vector_store_params = ingest_options.get("litellm_vector_store_params", {}) custom_vector_store_name = litellm_vector_store_params.get("vector_store_name") custom_vector_store_description = litellm_vector_store_params.get("vector_store_description") + + # Extract provider-specific params from vector_store_config to save as litellm_params + # This ensures params like aws_region_name, embedding_model, etc. are available for search + provider_specific_params = {} + excluded_keys = {"custom_llm_provider", "vector_store_id"} + for key, value in vector_store_config.items(): + if key not in excluded_keys and value is not None: + provider_specific_params[key] = value # Build file metadata entry using helper file_entry = _build_file_metadata_entry( @@ -167,6 +175,7 @@ async def _save_vector_store_to_db_from_rag_ingest( vector_store_name=vector_store_name, vector_store_description=vector_store_description, vector_store_metadata=initial_metadata, + litellm_params=provider_specific_params if provider_specific_params else None, ) verbose_proxy_logger.info( diff --git a/litellm/rag/ingestion/base_ingestion.py b/litellm/rag/ingestion/base_ingestion.py index 20059487b4..3daa767188 100644 --- a/litellm/rag/ingestion/base_ingestion.py +++ b/litellm/rag/ingestion/base_ingestion.py @@ -24,6 +24,7 @@ get_async_httpx_client, httpxSpecialProvider, ) +from litellm.rag.ingestion.file_parsers import extract_text_from_pdf from litellm.rag.text_splitters import RecursiveCharacterTextSplitter from litellm.types.rag import RAGIngestOptions, RAGIngestResponse @@ -193,11 +194,23 @@ def chunk( if text: text_to_chunk = text elif file_content and not ocr_was_used: + # Try UTF-8 decode first try: text_to_chunk = file_content.decode("utf-8") except 
UnicodeDecodeError: - verbose_logger.debug("Binary file detected, skipping text chunking") - return [] + # Check if it's a PDF and try to extract text + if file_content.startswith(b"%PDF"): + verbose_logger.debug("PDF detected, attempting text extraction") + text_to_chunk = extract_text_from_pdf(file_content) + if not text_to_chunk: + verbose_logger.debug( + "PDF text extraction failed. Install 'pypdf' or 'PyPDF2' for PDF support, " + "or enable OCR with a vision model." + ) + return [] + else: + verbose_logger.debug("Binary file detected, skipping text chunking") + return [] if not text_to_chunk: return [] diff --git a/litellm/rag/ingestion/file_parsers/__init__.py b/litellm/rag/ingestion/file_parsers/__init__.py new file mode 100644 index 0000000000..5be68cdbb7 --- /dev/null +++ b/litellm/rag/ingestion/file_parsers/__init__.py @@ -0,0 +1,9 @@ +""" +File parsers for RAG ingestion. + +Provides text extraction utilities for various file formats. +""" + +from .pdf_parser import extract_text_from_pdf + +__all__ = ["extract_text_from_pdf"] diff --git a/litellm/rag/ingestion/file_parsers/pdf_parser.py b/litellm/rag/ingestion/file_parsers/pdf_parser.py new file mode 100644 index 0000000000..9a533ccf13 --- /dev/null +++ b/litellm/rag/ingestion/file_parsers/pdf_parser.py @@ -0,0 +1,70 @@ +""" +PDF text extraction utilities. + +Provides text extraction from PDF files using pypdf or PyPDF2. +""" + +from typing import Optional + +from litellm._logging import verbose_logger + + +def extract_text_from_pdf(file_content: bytes) -> Optional[str]: + """ + Extract text from PDF using pypdf if available. 
+ + Args: + file_content: Raw PDF bytes + + Returns: + Extracted text or None if extraction fails + """ + try: + from io import BytesIO + + # Try pypdf first (most common) + try: + from pypdf import PdfReader as PypdfReader + + pdf_file = BytesIO(file_content) + reader = PypdfReader(pdf_file) + + text_parts = [] + for page in reader.pages: + text = page.extract_text() + if text: + text_parts.append(text) + + if text_parts: + extracted_text = "\n\n".join(text_parts) + verbose_logger.debug(f"Extracted {len(extracted_text)} characters from PDF using pypdf") + return extracted_text + + except ImportError: + verbose_logger.debug("pypdf not available, trying PyPDF2") + + # Fallback to PyPDF2 + try: + from PyPDF2 import PdfReader as PyPDF2Reader + + pdf_file = BytesIO(file_content) + reader = PyPDF2Reader(pdf_file) + + text_parts = [] + for page in reader.pages: + text = page.extract_text() + if text: + text_parts.append(text) + + if text_parts: + extracted_text = "\n\n".join(text_parts) + verbose_logger.debug(f"Extracted {len(extracted_text)} characters from PDF using PyPDF2") + return extracted_text + + except ImportError: + verbose_logger.debug("PyPDF2 not available, PDF extraction requires OCR or pypdf/PyPDF2 library") + + except Exception as e: + verbose_logger.debug(f"PDF text extraction failed: {e}") + + return None diff --git a/litellm/rag/ingestion/s3_vectors_ingestion.py b/litellm/rag/ingestion/s3_vectors_ingestion.py index 13964cd89a..e6c166a101 100644 --- a/litellm/rag/ingestion/s3_vectors_ingestion.py +++ b/litellm/rag/ingestion/s3_vectors_ingestion.py @@ -253,6 +253,20 @@ async def _ensure_vector_bucket_exists(self): verbose_logger.debug( f"Ensuring S3 vector bucket exists: {self.vector_bucket_name}" ) + + # Validate bucket name (AWS S3 naming rules) + if len(self.vector_bucket_name) < 3: + raise ValueError( + f"Invalid vector_bucket_name '{self.vector_bucket_name}': " + f"AWS S3 bucket names must be at least 3 characters long. 
" + f"Please provide a valid bucket name (e.g., 'my-vector-bucket')." + ) + if not self.vector_bucket_name.replace("-", "").replace(".", "").isalnum(): + raise ValueError( + f"Invalid vector_bucket_name '{self.vector_bucket_name}': " + f"AWS S3 bucket names can only contain lowercase letters, numbers, hyphens, and periods. " + f"Please provide a valid bucket name (e.g., 'my-vector-bucket')." + ) # Try to get bucket info using GetVectorBucket API get_url = f"https://s3vectors.{self.aws_region_name}.api.aws/GetVectorBucket" @@ -455,30 +469,43 @@ async def store( await self._ensure_config_initialized() if not embeddings or not chunks: - verbose_logger.warning("No embeddings or chunks to store") - return self.index_name, None + error_msg = ( + "No text content could be extracted from the file for embedding. " + "Possible causes:\n" + " 1. PDF files require OCR - add 'ocr' config with a vision model (e.g., 'anthropic/claude-3-5-sonnet-20241022')\n" + " 2. Binary files cannot be processed - convert to text first\n" + " 3. File is empty or contains no extractable text\n" + "For PDFs, either enable OCR or use a PDF extraction library to convert to text before ingestion." 
+ ) + verbose_logger.error(error_msg) + raise ValueError(error_msg) # Prepare vectors for PutVectors API vectors = [] for i, (chunk, embedding) in enumerate(zip(chunks, embeddings)): + # Build metadata dict + metadata: Dict[str, str] = { + "source_text": chunk, # Non-filterable (for reference) + "chunk_index": str(i), # Filterable + } + + if filename: + metadata["filename"] = filename # Filterable + vector_obj = { "key": f"{filename}_{i}" if filename else f"chunk_{i}", "data": {"float32": embedding}, - "metadata": { - "source_text": chunk, # Non-filterable (for reference) - "chunk_index": str(i), # Filterable - }, + "metadata": metadata, } - if filename: - vector_obj["metadata"]["filename"] = filename # Filterable - vectors.append(vector_obj) # Call PutVectors API await self._put_vectors(vectors) - return self.index_name, filename + # Return vector_store_id in format bucket_name:index_name for S3 Vectors search compatibility + vector_store_id = f"{self.vector_bucket_name}:{self.index_name}" + return vector_store_id, filename async def query_vector_store( self, vector_store_id: str, query: str, top_k: int = 5 diff --git a/litellm/types/router.py b/litellm/types/router.py index 8ea7a20753..43943d9e07 100644 --- a/litellm/types/router.py +++ b/litellm/types/router.py @@ -404,6 +404,10 @@ class LiteLLMParamsTypedDict(TypedDict, total=False): aws_access_key_id: Optional[str] aws_secret_access_key: Optional[str] aws_region_name: Optional[str] + ## AWS S3 VECTORS ## + vector_bucket_name: Optional[str] + index_name: Optional[str] + embedding_model: Optional[str] ## IBM WATSONX ## watsonx_region_name: Optional[str] ## CUSTOM PRICING ## diff --git a/litellm/types/utils.py b/litellm/types/utils.py index 2d67e13e92..49c903502d 100644 --- a/litellm/types/utils.py +++ b/litellm/types/utils.py @@ -4,20 +4,22 @@ from typing import TYPE_CHECKING, Any, Dict, List, Literal, Mapping, Optional, Union from openai._models import BaseModel as OpenAIObject -from 
openai.types.audio.transcription_create_params import FileTypes as FileTypes # type: ignore +from openai.types.audio.transcription_create_params import ( + FileTypes as FileTypes, # type: ignore +) from openai.types.chat.chat_completion import ChatCompletion as ChatCompletion from openai.types.completion_usage import ( CompletionTokensDetails, CompletionUsage, PromptTokensDetails, ) +from openai.types.moderation import Categories as Categories from openai.types.moderation import ( - Categories as Categories, CategoryAppliedInputTypes as CategoryAppliedInputTypes, - CategoryScores as CategoryScores, ) +from openai.types.moderation import CategoryScores as CategoryScores +from openai.types.moderation_create_response import Moderation as Moderation from openai.types.moderation_create_response import ( - Moderation as Moderation, ModerationCreateResponse as ModerationCreateResponse, ) from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator @@ -3075,6 +3077,7 @@ class LlmProviders(str, Enum): LLAMA = "meta_llama" NSCALE = "nscale" PG_VECTOR = "pg_vector" + S3_VECTORS = "s3_vectors" HELICONE = "helicone" HYPERBOLIC = "hyperbolic" RECRAFT = "recraft" diff --git a/litellm/utils.py b/litellm/utils.py index bb95be05b5..d7fb4855a4 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -8452,6 +8452,12 @@ def get_provider_vector_stores_config( ) return RAGFlowVectorStoreConfig() + elif litellm.LlmProviders.S3_VECTORS == provider: + from litellm.llms.s3_vectors.vector_stores.transformation import ( + S3VectorsVectorStoreConfig, + ) + + return S3VectorsVectorStoreConfig() return None @staticmethod @@ -8699,9 +8705,9 @@ def get_provider_search_config( """ Get Search configuration for a given provider. 
""" + from litellm.llms.brave.search.transformation import BraveSearchConfig from litellm.llms.dataforseo.search.transformation import DataForSEOSearchConfig from litellm.llms.exa_ai.search.transformation import ExaAISearchConfig - from litellm.llms.brave.search.transformation import BraveSearchConfig from litellm.llms.firecrawl.search.transformation import FirecrawlSearchConfig from litellm.llms.google_pse.search.transformation import GooglePSESearchConfig from litellm.llms.linkup.search.transformation import LinkupSearchConfig diff --git a/requirements.txt b/requirements.txt index 36328a081a..f250fb213e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -68,6 +68,7 @@ jsonschema>=4.23.0,<5.0.0 # validating json schema - aligned with openapi-core + websockets==15.0.1 # for realtime API soundfile==0.12.1 # for audio file processing openapi-core==0.21.0 # for OpenAPI compliance tests +pypdf>=6.6.2 # for PDF text extraction in RAG ingestion ######################## # LITELLM ENTERPRISE DEPENDENCIES diff --git a/tests/test_litellm/llms/s3_vectors/__init__.py b/tests/test_litellm/llms/s3_vectors/__init__.py new file mode 100644 index 0000000000..d4b0c4d855 --- /dev/null +++ b/tests/test_litellm/llms/s3_vectors/__init__.py @@ -0,0 +1 @@ +# S3 Vectors tests diff --git a/tests/test_litellm/llms/s3_vectors/vector_stores/__init__.py b/tests/test_litellm/llms/s3_vectors/vector_stores/__init__.py new file mode 100644 index 0000000000..231735c1de --- /dev/null +++ b/tests/test_litellm/llms/s3_vectors/vector_stores/__init__.py @@ -0,0 +1 @@ +# S3 Vectors vector store tests diff --git a/tests/test_litellm/llms/s3_vectors/vector_stores/test_s3_vectors_transformation.py b/tests/test_litellm/llms/s3_vectors/vector_stores/test_s3_vectors_transformation.py new file mode 100644 index 0000000000..3a84da3154 --- /dev/null +++ b/tests/test_litellm/llms/s3_vectors/vector_stores/test_s3_vectors_transformation.py @@ -0,0 +1,115 @@ +from unittest.mock import MagicMock, Mock + +import 
httpx +import pytest + +from litellm.llms.s3_vectors.vector_stores.transformation import ( + S3VectorsVectorStoreConfig, +) +from litellm.types.vector_stores import VectorStoreSearchResponse + + +class TestS3VectorsVectorStoreConfig: + def test_init(self): + """Test that S3VectorsVectorStoreConfig initializes correctly""" + config = S3VectorsVectorStoreConfig() + assert config is not None + + def test_get_supported_openai_params(self): + """Test that supported OpenAI params are returned""" + config = S3VectorsVectorStoreConfig() + params = config.get_supported_openai_params("test-model") + assert "max_num_results" in params + + def test_get_complete_url(self): + """Test URL generation for S3 Vectors""" + config = S3VectorsVectorStoreConfig() + litellm_params = {"aws_region_name": "us-west-2"} + url = config.get_complete_url(None, litellm_params) + assert url == "https://s3vectors.us-west-2.api.aws" + + def test_get_complete_url_missing_region(self): + """Test that missing region raises error""" + config = S3VectorsVectorStoreConfig() + litellm_params = {} + with pytest.raises(ValueError, match="aws_region_name is required"): + config.get_complete_url(None, litellm_params) + + @pytest.mark.skip(reason="Requires embedding API call, tested in integration tests") + def test_transform_search_request(self): + """Test search request transformation""" + # This test requires making an actual embedding API call + # It's better tested in integration tests + pass + + def test_transform_search_request_invalid_vector_store_id(self): + """Test that invalid vector_store_id format raises error""" + config = S3VectorsVectorStoreConfig() + mock_logging_obj = Mock() + mock_logging_obj.model_call_details = {} + + with pytest.raises( + ValueError, match="vector_store_id must be in format 'bucket_name:index_name'" + ): + config.transform_search_vector_store_request( + vector_store_id="invalid-format", + query="test query", + vector_store_search_optional_params={}, + 
api_base="https://s3vectors.us-west-2.api.aws", + litellm_logging_obj=mock_logging_obj, + litellm_params={}, + ) + + def test_transform_search_response(self): + """Test search response transformation""" + config = S3VectorsVectorStoreConfig() + mock_logging_obj = Mock() + mock_logging_obj.model_call_details = {"query": "test query"} + + mock_response = Mock(spec=httpx.Response) + mock_response.json.return_value = { + "vectors": [ + { + "distance": 0.05, # S3 Vectors returns distance, not score + "metadata": { + "source_text": "This is test content", + "chunk_index": "0", + "filename": "test.pdf", + }, + }, + { + "distance": 0.15, + "metadata": { + "source_text": "More test content", + "chunk_index": "1", + }, + }, + ] + } + mock_response.status_code = 200 + mock_response.headers = {} + + result = config.transform_search_vector_store_response( + mock_response, mock_logging_obj + ) + + # VectorStoreSearchResponse is a TypedDict, so check structure instead of isinstance + assert result["object"] == "vector_store.search_results.page" + assert result["search_query"] == "test query" + assert len(result["data"]) == 2 + # Score should be 1 - distance (cosine similarity) + assert result["data"][0]["score"] == 0.95 # 1 - 0.05 + assert result["data"][0]["content"][0]["text"] == "This is test content" + assert result["data"][0]["filename"] == "test.pdf" + assert result["data"][1]["score"] == 0.85 # 1 - 0.15 + assert result["data"][1]["content"][0]["text"] == "More test content" + + def test_map_openai_params(self): + """Test OpenAI parameter mapping""" + config = S3VectorsVectorStoreConfig() + non_default_params = {"max_num_results": 5} + optional_params = {} + + result = config.map_openai_params(non_default_params, optional_params, False) + + assert result["maxResults"] == 5 diff --git a/tests/vector_store_tests/test_s3_vectors_vector_store.py b/tests/vector_store_tests/test_s3_vectors_vector_store.py new file mode 100644 index 0000000000..a7a1568c1c --- /dev/null +++ 
b/tests/vector_store_tests/test_s3_vectors_vector_store.py @@ -0,0 +1,42 @@ +from base_vector_store_test import BaseVectorStoreTest +import os +import pytest + + +class TestS3VectorsVectorStore(BaseVectorStoreTest): + @pytest.fixture(autouse=True) + def check_env_vars(self): + """Check if required environment variables are set""" + required_vars = ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY"] + missing_vars = [var for var in required_vars if not os.getenv(var)] + if missing_vars: + pytest.skip(f"Missing required environment variables: {', '.join(missing_vars)}") + + def get_base_request_args(self) -> dict: + """ + Must return the base request args for searching. + For S3 Vectors, vector_store_id should be in format: bucket_name:index_name + """ + return { + "custom_llm_provider": "s3_vectors", + "vector_store_id": os.getenv( + "S3_VECTORS_VECTOR_STORE_ID", "test-litellm-vectors:test-index" + ), + "query": "What is machine learning?", + "aws_region_name": os.getenv("AWS_REGION_NAME", "us-west-2"), + "aws_access_key_id": os.getenv("AWS_ACCESS_KEY_ID"), + "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY"), + } + + def get_base_create_vector_store_args(self) -> dict: + """ + Vector store creation is not yet implemented for S3 Vectors. + This test will be skipped. 
+ """ + return {} + + @pytest.mark.parametrize("sync_mode", [True, False]) + @pytest.mark.asyncio + async def test_basic_create_vector_store(self, sync_mode): + """S3 Vectors doesn't support vector store creation via this API yet""" + pytest.skip("Vector store creation not yet implemented for S3 Vectors") diff --git a/ui/litellm-dashboard/public/assets/logos/s3_vector.png b/ui/litellm-dashboard/public/assets/logos/s3_vector.png new file mode 100644 index 0000000000..15a1a456e1 Binary files /dev/null and b/ui/litellm-dashboard/public/assets/logos/s3_vector.png differ diff --git a/ui/litellm-dashboard/src/components/navbar.tsx b/ui/litellm-dashboard/src/components/navbar.tsx index c78f355ff1..8e4fda0ba5 100644 --- a/ui/litellm-dashboard/src/components/navbar.tsx +++ b/ui/litellm-dashboard/src/components/navbar.tsx @@ -219,7 +219,7 @@ const Navbar: React.FC = ({ style={{ animationDuration: "2s" }} title="Happy Holidays!" > - 🎄 + ❄️ diff --git a/ui/litellm-dashboard/src/components/networking.tsx b/ui/litellm-dashboard/src/components/networking.tsx index a1625b6ffb..e0dccbe6e8 100644 --- a/ui/litellm-dashboard/src/components/networking.tsx +++ b/ui/litellm-dashboard/src/components/networking.tsx @@ -6954,7 +6954,8 @@ export const ragIngestCall = async ( customLlmProvider: string, vectorStoreId?: string, vectorStoreName?: string, - vectorStoreDescription?: string + vectorStoreDescription?: string, + providerSpecificParams?: Record ): Promise => { try { let url = proxyBaseUrl ? 
`${proxyBaseUrl}/rag/ingest` : `/rag/ingest`; @@ -6967,6 +6968,7 @@ export const ragIngestCall = async ( vector_store: { custom_llm_provider: customLlmProvider, ...(vectorStoreId && { vector_store_id: vectorStoreId }), + ...(providerSpecificParams && providerSpecificParams), }, }, }; diff --git a/ui/litellm-dashboard/src/components/vector_store_management/CreateVectorStore.test.tsx b/ui/litellm-dashboard/src/components/vector_store_management/CreateVectorStore.test.tsx index db2975781b..a8ba55b91f 100644 --- a/ui/litellm-dashboard/src/components/vector_store_management/CreateVectorStore.test.tsx +++ b/ui/litellm-dashboard/src/components/vector_store_management/CreateVectorStore.test.tsx @@ -22,17 +22,51 @@ vi.mock("../vector_store_providers", () => ({ BEDROCK: "Amazon Bedrock", OPENAI: "OpenAI", AZURE_OPENAI: "Azure OpenAI", + S3Vectors: "AWS S3 Vectors", }, vectorStoreProviderMap: { BEDROCK: "bedrock", OPENAI: "openai", AZURE_OPENAI: "azure_openai", + S3Vectors: "s3_vectors", }, vectorStoreProviderLogoMap: { "Amazon Bedrock": "https://example.com/bedrock.png", "OpenAI": "https://example.com/openai.png", "Azure OpenAI": "https://example.com/azure.png", + "AWS S3 Vectors": "https://example.com/aws.png", }, + getProviderSpecificFields: vi.fn((provider: string) => { + if (provider === "s3_vectors") { + return [ + { + name: "vector_bucket_name", + label: "Vector Bucket Name", + tooltip: "S3 bucket name for vector storage", + placeholder: "my-vector-bucket", + required: true, + type: "text", + }, + { + name: "aws_region_name", + label: "AWS Region", + tooltip: "AWS region", + placeholder: "us-west-2", + required: true, + type: "text", + }, + { + name: "embedding_model", + label: "Embedding Model", + tooltip: "Embedding model to use", + placeholder: "text-embedding-3-small", + required: true, + type: "select", + }, + ]; + } + return []; + }), })); describe("CreateVectorStore", () => { @@ -43,9 +77,9 @@ describe("CreateVectorStore", () => { it("should render the component 
successfully", () => { render(); - expect(screen.getByText("Create Vector Store")).toBeInTheDocument(); + expect(screen.getAllByText("Create Vector Store").length).toBeGreaterThan(0); expect(screen.getByText("Step 1: Upload Documents")).toBeInTheDocument(); - expect(screen.getByText("Step 2: Select Provider")).toBeInTheDocument(); + expect(screen.getByText("Step 2: Configure Vector Store")).toBeInTheDocument(); }); it("should display upload area with correct text", () => { @@ -123,7 +157,15 @@ describe("CreateVectorStore", () => { }); await waitFor(() => { - expect(mockRagIngestCall).toHaveBeenCalledWith("test-token", expect.any(File), "bedrock", undefined); + expect(mockRagIngestCall).toHaveBeenCalledWith( + "test-token", + expect.any(File), + "bedrock", + undefined, + undefined, + undefined, + {} + ); }); }); @@ -163,4 +205,72 @@ describe("CreateVectorStore", () => { expect(screen.getByText("Vector Store Created Successfully")).toBeInTheDocument(); }); }); + + it("should display S3 Vectors provider-specific fields when selected", async () => { + render(); + + // Find and click the provider dropdown + const providerSelect = screen.getByRole("combobox"); + + await act(async () => { + fireEvent.mouseDown(providerSelect); + }); + + // Wait for dropdown options to appear + await waitFor(() => { + const s3Option = screen.queryByText("AWS S3 Vectors"); + if (s3Option) { + fireEvent.click(s3Option); + } + }); + + // Check if S3-specific fields are displayed + await waitFor(() => { + expect(screen.queryByText("Vector Bucket Name")).toBeInTheDocument(); + expect(screen.queryByText("AWS Region")).toBeInTheDocument(); + expect(screen.queryByText("Embedding Model")).toBeInTheDocument(); + }); + }); + + it("should validate S3 Vectors required fields before submission", async () => { + render(); + + // Upload a file first + const file = new File(["test content"], "test.pdf", { type: "application/pdf" }); + const uploadInput = document.querySelector('input[type="file"]') as 
HTMLInputElement; + + await act(async () => { + if (uploadInput) { + fireEvent.change(uploadInput, { target: { files: [file] } }); + } + }); + + await waitFor(() => { + expect(screen.getByText("Uploaded Documents (1)")).toBeInTheDocument(); + }); + + // Select S3 Vectors provider + const providerSelect = screen.getByRole("combobox"); + + await act(async () => { + fireEvent.mouseDown(providerSelect); + }); + + await waitFor(() => { + const s3Option = screen.queryByText("AWS S3 Vectors"); + if (s3Option) { + fireEvent.click(s3Option); + } + }); + + // Try to create without filling required fields + const createButton = screen.getByRole("button", { name: /Create Vector Store/i }); + + await act(async () => { + fireEvent.click(createButton); + }); + + // Should show validation warning (mocked message.warning would be called) + // The actual validation happens in the component + }); }); diff --git a/ui/litellm-dashboard/src/components/vector_store_management/CreateVectorStore.tsx b/ui/litellm-dashboard/src/components/vector_store_management/CreateVectorStore.tsx index 685e37ac73..97162aa132 100644 --- a/ui/litellm-dashboard/src/components/vector_store_management/CreateVectorStore.tsx +++ b/ui/litellm-dashboard/src/components/vector_store_management/CreateVectorStore.tsx @@ -10,8 +10,11 @@ import { VectorStoreProviders, vectorStoreProviderLogoMap, vectorStoreProviderMap, + getProviderSpecificFields, + VectorStoreFieldConfig, } from "../vector_store_providers"; import NotificationsManager from "../molecules/notifications_manager"; +import S3VectorsConfig from "./S3VectorsConfig"; const { Dragger } = Upload; @@ -28,6 +31,7 @@ const CreateVectorStore: React.FC = ({ accessToken, onSu const [vectorStoreName, setVectorStoreName] = useState(""); const [vectorStoreDescription, setVectorStoreDescription] = useState(""); const [ingestResults, setIngestResults] = useState([]); + const [providerParams, setProviderParams] = useState>({}); const uploadProps: UploadProps = { name: 
"file", @@ -92,6 +96,27 @@ const CreateVectorStore: React.FC = ({ accessToken, onSu return; } + // Validate provider-specific required fields + const requiredFields = getProviderSpecificFields(selectedProvider).filter((field) => field.required); + for (const field of requiredFields) { + if (!providerParams[field.name]) { + message.warning(`Please provide ${field.label}`); + return; + } + } + + // S3 Vectors specific validation + if (selectedProvider === "s3_vectors") { + if (providerParams.vector_bucket_name && providerParams.vector_bucket_name.length < 3) { + message.warning("Vector bucket name must be at least 3 characters"); + return; + } + if (providerParams.index_name && providerParams.index_name.length > 0 && providerParams.index_name.length < 3) { + message.warning("Index name must be at least 3 characters if provided"); + return; + } + } + if (!accessToken) { message.error("No access token available"); return; @@ -118,7 +143,8 @@ const CreateVectorStore: React.FC = ({ accessToken, onSu selectedProvider, vectorStoreId, // Use the same vector store ID for subsequent uploads vectorStoreName || undefined, - vectorStoreDescription || undefined + vectorStoreDescription || undefined, + providerParams ); // Store the vector store ID from the first successful ingest @@ -298,6 +324,74 @@ const CreateVectorStore: React.FC = ({ accessToken, onSu })} + + {/* S3 Vectors Configuration */} + {selectedProvider === "s3_vectors" && ( + + )} + + {/* Other Provider-specific fields */} + {selectedProvider !== "s3_vectors" && + getProviderSpecificFields(selectedProvider).map((field: VectorStoreFieldConfig) => { + if (field.type === "select") { + // For embedding model selection, we'd need to fetch available models + // For now, provide a text input as fallback + return ( + + {field.label}{" "} + + + + + } + required={field.required} + > + + setProviderParams((prev) => ({ ...prev, [field.name]: e.target.value })) + } + placeholder={field.placeholder} + size="large" + 
className="rounded-md" + /> + + ); + } + + return ( + + {field.label}{" "} + + + + + } + required={field.required} + > + + setProviderParams((prev) => ({ ...prev, [field.name]: e.target.value })) + } + placeholder={field.placeholder} + size="large" + className="rounded-md" + /> + + ); + })}
diff --git a/ui/litellm-dashboard/src/components/vector_store_management/S3VectorsConfig.test.tsx b/ui/litellm-dashboard/src/components/vector_store_management/S3VectorsConfig.test.tsx new file mode 100644 index 0000000000..beb13b5441 --- /dev/null +++ b/ui/litellm-dashboard/src/components/vector_store_management/S3VectorsConfig.test.tsx @@ -0,0 +1,203 @@ +import { render, screen, fireEvent, waitFor, act } from "@testing-library/react"; +import { describe, it, expect, vi, beforeEach } from "vitest"; +import S3VectorsConfig from "./S3VectorsConfig"; +import * as fetchModels from "../playground/llm_calls/fetch_models"; + +// Mock fetchAvailableModels +vi.mock("../playground/llm_calls/fetch_models", () => ({ + fetchAvailableModels: vi.fn(), +})); + +describe("S3VectorsConfig", () => { + const mockOnParamsChange = vi.fn(); + const defaultProps = { + accessToken: "test-token", + providerParams: {}, + onParamsChange: mockOnParamsChange, + }; + + beforeEach(() => { + vi.clearAllMocks(); + }); + + it("should render the component successfully", () => { + vi.spyOn(fetchModels, "fetchAvailableModels").mockResolvedValue([]); + + render(); + + expect(screen.getByText("AWS S3 Vectors Setup")).toBeInTheDocument(); + expect(screen.getByText("Vector Bucket Name")).toBeInTheDocument(); + expect(screen.getByText("Index Name")).toBeInTheDocument(); + expect(screen.getByText("AWS Region")).toBeInTheDocument(); + expect(screen.getByText("Embedding Model")).toBeInTheDocument(); + }); + + it("should display setup instructions", () => { + vi.spyOn(fetchModels, "fetchAvailableModels").mockResolvedValue([]); + + render(); + + expect( + screen.getByText(/AWS S3 Vectors allows you to store and query vector embeddings directly in S3/) + ).toBeInTheDocument(); + expect(screen.getByText(/Vector buckets and indexes will be automatically created/)).toBeInTheDocument(); + expect(screen.getByText(/Vector dimensions are auto-detected/)).toBeInTheDocument(); + }); + + it("should fetch embedding models 
on mount", async () => { + const mockModels = [ + { model_group: "text-embedding-3-small", mode: "embedding" }, + { model_group: "text-embedding-3-large", mode: "embedding" }, + { model_group: "gpt-4", mode: "chat" }, + ]; + + const fetchSpy = vi.spyOn(fetchModels, "fetchAvailableModels").mockResolvedValue(mockModels); + + render(); + + await waitFor(() => { + expect(fetchSpy).toHaveBeenCalledWith("test-token"); + }); + }); + + it("should filter and display only embedding models", async () => { + const mockModels = [ + { model_group: "text-embedding-3-small", mode: "embedding" }, + { model_group: "text-embedding-3-large", mode: "embedding" }, + { model_group: "gpt-4", mode: "chat" }, + { model_group: "gpt-3.5-turbo", mode: "chat" }, + ]; + + vi.spyOn(fetchModels, "fetchAvailableModels").mockResolvedValue(mockModels); + + render(); + + // Wait for models to load + await waitFor(() => { + expect(fetchModels.fetchAvailableModels).toHaveBeenCalled(); + }); + + // The component should filter to only embedding models internally + // We can verify this by checking the component loaded successfully + expect(screen.getByText("Embedding Model")).toBeInTheDocument(); + }); + + it("should call onParamsChange when vector bucket name changes", async () => { + vi.spyOn(fetchModels, "fetchAvailableModels").mockResolvedValue([]); + + render(); + + const bucketInput = screen.getByPlaceholderText("my-vector-bucket"); + + await act(async () => { + fireEvent.change(bucketInput, { target: { value: "test-bucket" } }); + }); + + expect(mockOnParamsChange).toHaveBeenCalledWith({ + vector_bucket_name: "test-bucket", + }); + }); + + it("should call onParamsChange when AWS region changes", async () => { + vi.spyOn(fetchModels, "fetchAvailableModels").mockResolvedValue([]); + + render(); + + const regionInput = screen.getByPlaceholderText("us-west-2"); + + await act(async () => { + fireEvent.change(regionInput, { target: { value: "us-east-1" } }); + }); + + 
expect(mockOnParamsChange).toHaveBeenCalledWith({ + aws_region_name: "us-east-1", + }); + }); + + it("should call onParamsChange when embedding model is selected", async () => { + const mockModels = [ + { model_group: "text-embedding-3-small", mode: "embedding" }, + { model_group: "text-embedding-3-large", mode: "embedding" }, + ]; + + vi.spyOn(fetchModels, "fetchAvailableModels").mockResolvedValue(mockModels); + + render(); + + await waitFor(() => { + expect(fetchModels.fetchAvailableModels).toHaveBeenCalled(); + }); + + // Find the Select component and trigger change directly + const selectElement = screen.getByRole("combobox"); + + await act(async () => { + // Simulate selecting a value by firing the change event + fireEvent.change(selectElement, { target: { value: "text-embedding-3-small" } }); + }); + + // The component should handle the selection + expect(screen.getByText("Embedding Model")).toBeInTheDocument(); + }); + + it("should preserve existing params when updating a field", async () => { + vi.spyOn(fetchModels, "fetchAvailableModels").mockResolvedValue([]); + + const existingParams = { + vector_bucket_name: "existing-bucket", + aws_region_name: "us-west-2", + }; + + render(); + + const indexInput = screen.getByPlaceholderText("my-vector-index"); + + await act(async () => { + fireEvent.change(indexInput, { target: { value: "my-index" } }); + }); + + expect(mockOnParamsChange).toHaveBeenCalledWith({ + vector_bucket_name: "existing-bucket", + aws_region_name: "us-west-2", + index_name: "my-index", + }); + }); + + it("should display existing param values", () => { + vi.spyOn(fetchModels, "fetchAvailableModels").mockResolvedValue([]); + + const existingParams = { + vector_bucket_name: "my-bucket", + index_name: "my-index", + aws_region_name: "eu-west-1", + embedding_model: "text-embedding-3-small", + }; + + render(); + + expect(screen.getByDisplayValue("my-bucket")).toBeInTheDocument(); + expect(screen.getByDisplayValue("my-index")).toBeInTheDocument(); + 
expect(screen.getByDisplayValue("eu-west-1")).toBeInTheDocument(); + }); + + it("should handle model fetch error gracefully", async () => { + const consoleErrorSpy = vi.spyOn(console, "error").mockImplementation(() => {}); + vi.spyOn(fetchModels, "fetchAvailableModels").mockRejectedValue(new Error("Failed to fetch models")); + + render(); + + await waitFor(() => { + expect(consoleErrorSpy).toHaveBeenCalledWith("Error fetching embedding models:", expect.any(Error)); + }); + + consoleErrorSpy.mockRestore(); + }); + + it("should not fetch models if accessToken is null", () => { + const fetchSpy = vi.spyOn(fetchModels, "fetchAvailableModels"); + + render(); + + expect(fetchSpy).not.toHaveBeenCalled(); + }); +}); diff --git a/ui/litellm-dashboard/src/components/vector_store_management/S3VectorsConfig.tsx b/ui/litellm-dashboard/src/components/vector_store_management/S3VectorsConfig.tsx new file mode 100644 index 0000000000..e54d3a163c --- /dev/null +++ b/ui/litellm-dashboard/src/components/vector_store_management/S3VectorsConfig.tsx @@ -0,0 +1,192 @@ +import React, { useState, useEffect } from "react"; +import { Alert, Form, Input, Select, Tooltip } from "antd"; +import { InfoCircleOutlined } from "@ant-design/icons"; +import { fetchAvailableModels, ModelGroup } from "../playground/llm_calls/fetch_models"; + +interface S3VectorsConfigProps { + accessToken: string | null; + providerParams: Record; + onParamsChange: (params: Record) => void; +} + +const S3VectorsConfig: React.FC = ({ + accessToken, + providerParams, + onParamsChange, +}) => { + const [embeddingModels, setEmbeddingModels] = useState([]); + const [isLoadingModels, setIsLoadingModels] = useState(false); + + useEffect(() => { + if (!accessToken) return; + + const loadModels = async () => { + setIsLoadingModels(true); + try { + const models = await fetchAvailableModels(accessToken); + // Filter for embedding models only + const embeddingOnly = models.filter((model) => model.mode === "embedding"); + 
setEmbeddingModels(embeddingOnly); + } catch (error) { + console.error("Error fetching embedding models:", error); + } finally { + setIsLoadingModels(false); + } + }; + + loadModels(); + }, [accessToken]); + + const handleFieldChange = (fieldName: string, value: string) => { + onParamsChange({ + ...providerParams, + [fieldName]: value, + }); + }; + + return ( + <> + {/* S3 Vectors Setup Instructions */} + +

AWS S3 Vectors allows you to store and query vector embeddings directly in S3:

+
    +
  • Vector buckets and indexes will be automatically created if they don't exist
  • +
  • Vector dimensions are auto-detected from your selected embedding model
  • +
  • Ensure your AWS credentials have permissions for S3 Vectors operations
  • +
  • + Learn more:{" "} + + AWS S3 Vectors Documentation + +
  • +
+
+ } + type="info" + showIcon + style={{ marginBottom: "16px" }} + /> + + {/* Vector Bucket Name */} + + Vector Bucket Name{" "} + + + + + } + required + validateStatus={ + providerParams.vector_bucket_name && providerParams.vector_bucket_name.length < 3 + ? "error" + : undefined + } + help={ + providerParams.vector_bucket_name && providerParams.vector_bucket_name.length < 3 + ? "Bucket name must be at least 3 characters" + : undefined + } + > + handleFieldChange("vector_bucket_name", e.target.value)} + placeholder="my-vector-bucket (min 3 chars)" + size="large" + className="rounded-md" + /> + + + {/* Index Name (Optional) */} + + Index Name{" "} + + + + + } + validateStatus={ + providerParams.index_name && providerParams.index_name.length > 0 && providerParams.index_name.length < 3 + ? "error" + : undefined + } + help={ + providerParams.index_name && providerParams.index_name.length > 0 && providerParams.index_name.length < 3 + ? "Index name must be at least 3 characters if provided" + : undefined + } + > + handleFieldChange("index_name", e.target.value)} + placeholder="my-vector-index (optional, min 3 chars)" + size="large" + className="rounded-md" + /> + + + {/* AWS Region */} + + AWS Region{" "} + + + + + } + required + > + handleFieldChange("aws_region_name", e.target.value)} + placeholder="us-west-2" + size="large" + className="rounded-md" + /> + + + {/* Embedding Model */} + + Embedding Model{" "} + + + + + } + required + > +