Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions python/src/server/api_routes/knowledge_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@
from ..services.search.rag_service import RAGService
from ..services.storage import DocumentStorageService
from ..utils import get_supabase_client
from ..services.embeddings.embedding_exceptions import (
EmbeddingAuthenticationError,
)
from ..utils.document_processing import extract_text_from_document

# Get logger for this module
Expand Down Expand Up @@ -750,6 +753,9 @@ async def perform_rag_query(request: RagQueryRequest):
)
except HTTPException:
raise
except EmbeddingAuthenticationError as e:
safe_logfire_error(f"Authentication error in RAG query: {str(e)}")
raise HTTPException(status_code=401, detail={"error": "Invalid API key"})
except Exception as e:
safe_logfire_error(
f"RAG query failed | error={str(e)} | query={request.query[:50]} | source={request.source}"
Expand Down Expand Up @@ -784,6 +790,9 @@ async def search_code_examples(request: RagQueryRequest):
)
except HTTPException:
raise
except EmbeddingAuthenticationError as e:
safe_logfire_error(f"Authentication error in code examples search: {str(e)}")
raise HTTPException(status_code=401, detail={"error": "Invalid API key"})
except Exception as e:
safe_logfire_error(
f"Code examples search failed | error={str(e)} | query={request.query[:50]} | source={request.source}"
Expand Down
16 changes: 16 additions & 0 deletions python/src/server/services/embeddings/embedding_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,22 @@ class EmbeddingAsyncContextError(EmbeddingError):
pass


class EmbeddingAuthenticationError(EmbeddingError):
    """
    Signals that the embedding provider rejected our credentials
    (for example an invalid or expired API key).

    Treated as a fatal condition: without valid authentication there is
    no point in carrying on with the rest of the work.

    Attributes:
        api_key_prefix: Redacted hint for the offending key -- at most the
            first four characters of the supplied prefix followed by an
            ellipsis -- or ``None`` when no prefix was supplied.
    """

    def __init__(self, message: str, api_key_prefix: str | None = None, **kwargs):
        """
        Args:
            message: Human-readable description of the failure.
            api_key_prefix: Optional leading portion of the API key; it is
                truncated and masked before being stored.
            **kwargs: Passed through unchanged to ``EmbeddingError.__init__``.
        """
        super().__init__(message, **kwargs)

        if api_key_prefix:
            # Retain only a short, masked fragment of the key material.
            redacted_prefix = f"{api_key_prefix[:4]}…"
            self.api_key_prefix = redacted_prefix
            # NOTE(review): presumably `self.metadata` is a dict created by
            # the EmbeddingError base initializer -- confirm.
            self.metadata["api_key_prefix"] = redacted_prefix
        else:
            self.api_key_prefix = None


class EmbeddingAPIError(EmbeddingError):
"""
Raised for general API failures (network, invalid response, etc).
Expand Down
15 changes: 15 additions & 0 deletions python/src/server/services/embeddings/embedding_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from ..threading_service import get_threading_service
from .embedding_exceptions import (
EmbeddingAPIError,
EmbeddingAuthenticationError,
EmbeddingError,
EmbeddingQuotaExhaustedError,
EmbeddingRateLimitError,
Expand Down Expand Up @@ -107,6 +108,9 @@ async def create_embedding(text: str, provider: str | None = None) -> list[float
"No embeddings returned from batch creation", text_preview=text
)
return result.embeddings[0]
except EmbeddingAuthenticationError:
# Let auth errors bubble so the HTTP layer can return 401
raise
except EmbeddingError:
# Re-raise our custom exceptions
raise
Expand Down Expand Up @@ -226,6 +230,11 @@ async def create_embeddings_batch(

break # Success, exit retry loop

except openai.AuthenticationError as e:
# Invalid API key - critical error, stop everything
search_logger.error("Authentication failed: Invalid API key", exc_info=True)
raise EmbeddingAuthenticationError("Invalid API key") from e

except openai.RateLimitError as e:
error_message = str(e)
if "insufficient_quota" in error_message:
Expand Down Expand Up @@ -268,6 +277,9 @@ async def create_embeddings_batch(
else:
raise # Will be caught by outer try

except EmbeddingAuthenticationError:
# Auth errors must bubble up immediately for HTTP 401
raise
except Exception as e:
# This batch failed - track failures but continue with next batch
search_logger.error(f"Batch {batch_index} failed: {e}", exc_info=True)
Expand Down Expand Up @@ -318,6 +330,9 @@ async def create_embeddings_batch(

return result

except EmbeddingAuthenticationError:
# Auth errors must bubble up immediately for HTTP 401
raise
except Exception as e:
# Catastrophic failure - return what we have
span.set_attribute("catastrophic_failure", True)
Expand Down
28 changes: 19 additions & 9 deletions python/src/server/services/search/rag_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import os
from typing import Any

from ..embeddings.embedding_exceptions import EmbeddingAuthenticationError
from ...config.logfire_config import get_logger, safe_span
from ...utils import get_supabase_client
from ..embeddings.embedding_service import create_embedding
Expand Down Expand Up @@ -112,8 +113,8 @@ async def search_documents(
hybrid_enabled=use_hybrid_search,
) as span:
try:
# Create embedding for the query
query_embedding = await create_embedding(query)
# Create embedding for the query using single-vector API
query_embedding = await create_embedding(text=query)

if not query_embedding:
logger.error("Failed to create embedding for query")
Expand All @@ -140,6 +141,9 @@ async def search_documents(
span.set_attribute("results_found", len(results))
return results

except EmbeddingAuthenticationError:
# Let auth failures bubble to API layer -> 401
raise
except Exception as e:
logger.error(f"Document search failed: {e}")
span.set_attribute("error", str(e))
Expand Down Expand Up @@ -202,7 +206,6 @@ async def perform_rag_query(

# Check which strategies are enabled
use_hybrid_search = self.get_bool_setting("USE_HYBRID_SEARCH", False)
use_reranking = self.get_bool_setting("USE_RERANKING", False)

# Step 1 & 2: Get results (with hybrid search if enabled)
results = await self.search_documents(
Expand Down Expand Up @@ -230,9 +233,9 @@ async def perform_rag_query(
logger.warning(f"Failed to format result {i}: {format_error}")
continue

# Step 3: Apply reranking if we have a strategy or if enabled
# Step 3: Apply reranking if we have a strategy
reranking_applied = False
if self.reranking_strategy and formatted_results:
if self.reranking_strategy is not None and formatted_results:
try:
formatted_results = await self.reranking_strategy.rerank_results(
query, formatted_results, content_key="content"
Expand Down Expand Up @@ -262,6 +265,9 @@ async def perform_rag_query(
logger.info(f"RAG query completed - {len(formatted_results)} results found")
return True, response_data

except EmbeddingAuthenticationError:
# Let 401 bubble to the API layer
raise
except Exception as e:
logger.error(f"RAG query failed: {e}")
span.set_attribute("error", str(e))
Expand Down Expand Up @@ -311,7 +317,6 @@ async def search_code_examples_service(

# Check which strategies are enabled
use_hybrid_search = self.get_bool_setting("USE_HYBRID_SEARCH", False)
use_reranking = self.get_bool_setting("USE_RERANKING", False)

# Prepare filter
filter_metadata = {"source": source_id} if source_id and source_id.strip() else None
Expand All @@ -334,11 +339,13 @@ async def search_code_examples_service(
)

# Apply reranking if we have a strategy
if self.reranking_strategy and results:
reranking_applied = False
if self.reranking_strategy is not None and results:
try:
results = await self.reranking_strategy.rerank_results(
query, results, content_key="content"
)
reranking_applied = True
except Exception as e:
logger.warning(f"Code reranking failed: {e}")

Expand All @@ -362,17 +369,20 @@ async def search_code_examples_service(
"query": query,
"source_filter": source_id,
"search_mode": "hybrid" if use_hybrid_search else "vector",
"reranking_applied": self.reranking_strategy is not None,
"reranking_applied": reranking_applied,
"results": formatted_results,
"count": len(formatted_results),
}

span.set_attribute("results_found", len(formatted_results))
span.set_attribute("hybrid_used", use_hybrid_search)
span.set_attribute("reranking_used", use_reranking)
span.set_attribute("reranking_used", reranking_applied)

return True, response_data

except EmbeddingAuthenticationError:
# Let 401 bubble to the API layer
raise
except Exception as e:
logger.error(f"Code example search failed: {e}")
span.set_attribute("error", str(e))
Expand Down