diff --git a/python/src/server/api_routes/knowledge_api.py b/python/src/server/api_routes/knowledge_api.py index 37eeffc4d3..96d503442f 100644 --- a/python/src/server/api_routes/knowledge_api.py +++ b/python/src/server/api_routes/knowledge_api.py @@ -31,6 +31,9 @@ from ..services.search.rag_service import RAGService from ..services.storage import DocumentStorageService from ..utils import get_supabase_client +from ..services.embeddings.embedding_exceptions import ( + EmbeddingAuthenticationError, +) from ..utils.document_processing import extract_text_from_document # Get logger for this module @@ -750,6 +753,9 @@ async def perform_rag_query(request: RagQueryRequest): ) except HTTPException: raise + except EmbeddingAuthenticationError as e: + safe_logfire_error(f"Authentication error in RAG query: {str(e)}") + raise HTTPException(status_code=401, detail={"error": "Invalid API key"}) except Exception as e: safe_logfire_error( f"RAG query failed | error={str(e)} | query={request.query[:50]} | source={request.source}" @@ -784,6 +790,9 @@ async def search_code_examples(request: RagQueryRequest): ) except HTTPException: raise + except EmbeddingAuthenticationError as e: + safe_logfire_error(f"Authentication error in code examples search: {str(e)}") + raise HTTPException(status_code=401, detail={"error": "Invalid API key"}) except Exception as e: safe_logfire_error( f"Code examples search failed | error={str(e)} | query={request.query[:50]} | source={request.source}" diff --git a/python/src/server/services/embeddings/embedding_exceptions.py b/python/src/server/services/embeddings/embedding_exceptions.py index 8f153c8c7d..965d522a43 100644 --- a/python/src/server/services/embeddings/embedding_exceptions.py +++ b/python/src/server/services/embeddings/embedding_exceptions.py @@ -83,6 +83,22 @@ class EmbeddingAsyncContextError(EmbeddingError): pass +class EmbeddingAuthenticationError(EmbeddingError): + """ + Raised when API authentication fails (invalid API key, expired key, etc). + + This is a CRITICAL error that should stop the entire process + as continuing would be pointless without valid authentication. + """ + + def __init__(self, message: str, api_key_prefix: str | None = None, **kwargs): + super().__init__(message, **kwargs) + masked = f"{api_key_prefix[:4]}…" if api_key_prefix else None + self.api_key_prefix = masked + if masked: + self.metadata["api_key_prefix"] = masked + + class EmbeddingAPIError(EmbeddingError): """ Raised for general API failures (network, invalid response, etc). diff --git a/python/src/server/services/embeddings/embedding_service.py b/python/src/server/services/embeddings/embedding_service.py index 47b552086b..7e1d65243c 100644 --- a/python/src/server/services/embeddings/embedding_service.py +++ b/python/src/server/services/embeddings/embedding_service.py @@ -17,6 +17,7 @@ from ..threading_service import get_threading_service from .embedding_exceptions import ( EmbeddingAPIError, + EmbeddingAuthenticationError, EmbeddingError, EmbeddingQuotaExhaustedError, EmbeddingRateLimitError, @@ -107,6 +108,9 @@ async def create_embedding(text: str, provider: str | None = None) -> list[float "No embeddings returned from batch creation", text_preview=text ) return result.embeddings[0] + except EmbeddingAuthenticationError: + # Let auth errors bubble so the HTTP layer can return 401 + raise except EmbeddingError: # Re-raise our custom exceptions raise @@ -226,6 +230,11 @@ async def create_embeddings_batch( break # Success, exit retry loop + except openai.AuthenticationError as e: + # Invalid API key - critical error, stop everything + search_logger.error("Authentication failed: Invalid API key", exc_info=True) + raise EmbeddingAuthenticationError("Invalid API key") from e + except openai.RateLimitError as e: error_message = str(e) if "insufficient_quota" in error_message: @@ -268,6 +277,9 @@ async def create_embeddings_batch( else: raise # Will be caught by outer try + except EmbeddingAuthenticationError: + # Auth errors must bubble up immediately for HTTP 401 + raise except Exception as e: # This batch failed - track failures but continue with next batch search_logger.error(f"Batch {batch_index} failed: {e}", exc_info=True) @@ -318,6 +330,9 @@ async def create_embeddings_batch( return result + except EmbeddingAuthenticationError: + # Auth errors must bubble up immediately for HTTP 401 + raise except Exception as e: # Catastrophic failure - return what we have span.set_attribute("catastrophic_failure", True) diff --git a/python/src/server/services/search/rag_service.py b/python/src/server/services/search/rag_service.py index cdc89c237f..0c0a1a1bf9 100644 --- a/python/src/server/services/search/rag_service.py +++ b/python/src/server/services/search/rag_service.py @@ -15,6 +15,7 @@ import os from typing import Any +from ..embeddings.embedding_exceptions import EmbeddingAuthenticationError from ...config.logfire_config import get_logger, safe_span from ...utils import get_supabase_client from ..embeddings.embedding_service import create_embedding @@ -112,8 +113,8 @@ async def search_documents( hybrid_enabled=use_hybrid_search, ) as span: try: - # Create embedding for the query - query_embedding = await create_embedding(query) + # Create embedding for the query using single-vector API + query_embedding = await create_embedding(text=query) if not query_embedding: logger.error("Failed to create embedding for query") @@ -140,6 +141,9 @@ async def search_documents( span.set_attribute("results_found", len(results)) return results + except EmbeddingAuthenticationError: + # Let auth failures bubble to API layer -> 401 + raise except Exception as e: logger.error(f"Document search failed: {e}") span.set_attribute("error", str(e)) @@ -202,7 +206,6 @@ async def perform_rag_query( # Check which strategies are enabled use_hybrid_search = self.get_bool_setting("USE_HYBRID_SEARCH", False) - use_reranking = self.get_bool_setting("USE_RERANKING", False) # Step 1 & 2: Get results (with hybrid search if enabled) results = await self.search_documents( @@ -230,9 +233,9 @@ async def perform_rag_query( logger.warning(f"Failed to format result {i}: {format_error}") continue - # Step 3: Apply reranking if we have a strategy or if enabled + # Step 3: Apply reranking if we have a strategy reranking_applied = False - if self.reranking_strategy and formatted_results: + if self.reranking_strategy is not None and formatted_results: try: formatted_results = await self.reranking_strategy.rerank_results( query, formatted_results, content_key="content" @@ -262,6 +265,9 @@ async def perform_rag_query( logger.info(f"RAG query completed - {len(formatted_results)} results found") return True, response_data + except EmbeddingAuthenticationError: + # Let 401 bubble to the API layer + raise except Exception as e: logger.error(f"RAG query failed: {e}") span.set_attribute("error", str(e)) @@ -311,7 +317,6 @@ async def search_code_examples_service( # Check which strategies are enabled use_hybrid_search = self.get_bool_setting("USE_HYBRID_SEARCH", False) - use_reranking = self.get_bool_setting("USE_RERANKING", False) # Prepare filter filter_metadata = {"source": source_id} if source_id and source_id.strip() else None @@ -334,11 +339,13 @@ async def search_code_examples_service( ) # Apply reranking if we have a strategy - if self.reranking_strategy and results: + reranking_applied = False + if self.reranking_strategy is not None and results: try: results = await self.reranking_strategy.rerank_results( query, results, content_key="content" ) + reranking_applied = True except Exception as e: logger.warning(f"Code reranking failed: {e}") @@ -362,17 +369,20 @@ async def search_code_examples_service( "query": query, "source_filter": source_id, "search_mode": "hybrid" if use_hybrid_search else "vector", - "reranking_applied": self.reranking_strategy is not None, + "reranking_applied": reranking_applied, "results": formatted_results, "count": len(formatted_results), } span.set_attribute("results_found", len(formatted_results)) span.set_attribute("hybrid_used", use_hybrid_search) - span.set_attribute("reranking_used", use_reranking) + span.set_attribute("reranking_used", reranking_applied) return True, response_data + except EmbeddingAuthenticationError: + # Let 401 bubble to the API layer + raise except Exception as e: logger.error(f"Code example search failed: {e}") span.set_attribute("error", str(e))