diff --git a/python/src/server/services/embeddings/contextual_embedding_service.py b/python/src/server/services/embeddings/contextual_embedding_service.py
index 76f3c59b31..6da657afb6 100644
--- a/python/src/server/services/embeddings/contextual_embedding_service.py
+++ b/python/src/server/services/embeddings/contextual_embedding_service.py
@@ -10,7 +10,7 @@
 import openai
 
 from ...config.logfire_config import search_logger
-from ..llm_provider_service import get_llm_client
+from ..llm_provider_service import get_llm_client, prepare_llm_params
 from ..threading_service import get_threading_service
 
 
@@ -65,6 +65,10 @@ async def generate_contextual_embedding(
         # Get model from provider configuration
         model = await _get_model_choice(provider)
 
+        # Prepare compatible parameters for the API call
+        params = prepare_llm_params(provider or "openai", model,
+                                    temperature=0.3, max_tokens=200)
+
         response = await client.chat.completions.create(
             model=model,
             messages=[
@@ -74,8 +78,7 @@
                 },
                 {"role": "user", "content": prompt},
             ],
-            temperature=0.3,
-            max_tokens=200,
+            **params
         )
 
         context = response.choices[0].message.content.strip()
@@ -122,7 +125,7 @@ async def _get_model_choice(provider: str | None = None) -> str:
     # Handle empty model case - fallback to provider-specific defaults or explicit config
     if not model:
         search_logger.warning(f"chat_model is empty for provider {provider_name}, using fallback logic")
-        
+
         if provider_name == "ollama":
             # Try to get OLLAMA_CHAT_MODEL specifically
             try:
@@ -143,7 +146,7 @@
         else:
             # OpenAI or other providers
             model = "gpt-4o-mini"
-        
+
     search_logger.debug(f"Using model from credential service: {model}")
     return model
 
@@ -187,6 +190,10 @@ async def generate_contextual_embeddings_batch(
         batch_prompt += "For each chunk, provide a short succinct context to situate it within the overall document for improving search retrieval. Format your response as:\\nCHUNK 1: [context]\\nCHUNK 2: [context]\\netc."
 
+        # Prepare compatible parameters for the API call
+        params = prepare_llm_params(provider or "openai", model_choice,
+                                    temperature=0, max_tokens=100 * len(chunks))
+
         # Make single API call for ALL chunks
        response = await client.chat.completions.create(
             model=model_choice,
@@ -197,8 +204,7 @@
                 },
                 {"role": "user", "content": batch_prompt},
             ],
-            temperature=0,
-            max_tokens=100 * len(chunks),  # Limit response size
+            **params
         )
 
         # Parse response
@@ -245,4 +251,4 @@
     except Exception as e:
         search_logger.error(f"Error in contextual embedding batch: {e}")
         # Return non-contextual for all chunks
-        return [(chunk, False) for chunk in chunks]
\ No newline at end of file
+        return [(chunk, False) for chunk in chunks]
diff --git a/python/src/server/services/llm_provider_service.py b/python/src/server/services/llm_provider_service.py
index f04f0741ba..645f956567 100644
--- a/python/src/server/services/llm_provider_service.py
+++ b/python/src/server/services/llm_provider_service.py
@@ -383,3 +383,47 @@ async def validate_provider_instance(provider: str, instance_url: str | None = N
             "error_message": str(e),
             "validation_timestamp": time.time()
         }
+
+
+def prepare_llm_params(provider: str, model: str, **kwargs) -> dict:
+    """
+    Prepare LLM API parameters with automatic compatibility handling.
+
+    Handles:
+    - OpenAI max_tokens → max_completion_tokens deprecation
+    - Reasoning model temperature exclusion (o1, gpt-5 series)
+
+    Args:
+        provider: LLM provider name (openai, ollama, google)
+        model: Model name to check for special requirements
+        **kwargs: Original API parameters
+
+    Returns:
+        dict: Compatible parameters ready for API call
+    """
+    params = kwargs.copy()
+
+    # Handle OpenAI parameter deprecation
+    if provider == "openai" and "max_tokens" in params:
+        params["max_completion_tokens"] = params.pop("max_tokens")
+
+    # Handle reasoning model restrictions
+    if model and _is_reasoning_model(model):
+        params.pop("temperature", None)
+
+    return params
+
+
+def _is_reasoning_model(model: str) -> bool:
+    """
+    Check if model is a reasoning model that doesn't support custom temperature.
+
+    Args:
+        model: Model name to check
+
+    Returns:
+        True if model is a reasoning model, False otherwise
+    """
+    reasoning_patterns = ["o1", "o1-preview", "o1-mini", "gpt-5"]
+    model_lower = model.lower()
+    return any(pattern in model_lower for pattern in reasoning_patterns)
diff --git a/python/src/server/services/storage/code_storage_service.py b/python/src/server/services/storage/code_storage_service.py
index ece5ea1007..390fb507e9 100644
--- a/python/src/server/services/storage/code_storage_service.py
+++ b/python/src/server/services/storage/code_storage_service.py
@@ -507,7 +507,7 @@ def generate_code_example_summary(
         A dictionary with 'summary' and 'example_name'
     """
     import asyncio
-    
+
     # Run the async version in the current thread
     return asyncio.run(_generate_code_example_summary_async(code, context_before, context_after, language, provider))
 
@@ -518,8 +518,8 @@ async def _generate_code_example_summary_async(
     """
     Async version of generate_code_example_summary using unified LLM provider service.
     """
-    from ..llm_provider_service import get_llm_client
-    
+    from ..llm_provider_service import get_llm_client, prepare_llm_params
+
     # Get model choice from credential service (RAG setting)
     model_choice = _get_model_choice()
 
@@ -555,7 +555,11 @@ async def _generate_code_example_summary_async(
         search_logger.info(
             f"Generating summary for {hash(code) & 0xffffff:06x} using model: {model_choice}"
         )
-        
+
+        # Prepare compatible parameters for the API call
+        params = prepare_llm_params(provider or "openai", model_choice,
+                                    max_tokens=500, temperature=0.3)
+
         response = await client.chat.completions.create(
             model=model_choice,
             messages=[
@@ -566,8 +570,7 @@
                 {"role": "user", "content": prompt},
             ],
             response_format={"type": "json_object"},
-            max_tokens=500,
-            temperature=0.3,
+            **params
         )
 
         response_content = response.choices[0].message.content.strip()
@@ -848,14 +851,14 @@ async def add_code_examples_to_supabase(
         # Use only successful embeddings
         valid_embeddings = result.embeddings
         successful_texts = result.texts_processed
-        
+
         # Get model information for tracking
-        from ..llm_provider_service import get_embedding_model
         from ..credential_service import credential_service
-        
+        from ..llm_provider_service import get_embedding_model
+
         # Get embedding model name
         embedding_model_name = await get_embedding_model(provider=provider)
-        
+
         # Get LLM chat model (used for code summaries and contextual embeddings if enabled)
         llm_chat_model = None
         try:
@@ -908,7 +911,7 @@
             # Determine the correct embedding column based on dimension
             embedding_dim = len(embedding) if isinstance(embedding, list) else len(embedding.tolist())
             embedding_column = None
-            
+
             if embedding_dim == 768:
                 embedding_column = "embedding_768"
             elif embedding_dim == 1024:
@@ -921,7 +924,7 @@
                 # Default to closest supported dimension
                 search_logger.warning(f"Unsupported embedding dimension {embedding_dim}, using embedding_1536")
                 embedding_column = "embedding_1536"
-            
+
             batch_data.append({
                 "url": urls[idx],
                 "chunk_number": chunk_numbers[idx],