diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 4d971e8ce42..09512ac5fd1 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -1115,6 +1115,8 @@ def swagger_monkey_patch(*args, **kwargs): redis_usage_cache: Optional[RedisCache] = ( None # redis cache used for tracking spend, tpm/rpm limits ) +polling_via_cache_enabled: Union[Literal["all"], List[str], bool] = False +polling_cache_ttl: int = 3600 # Default 1 hour TTL for polling cache user_custom_auth = None user_custom_key_generate = None user_custom_sso = None @@ -2317,6 +2319,15 @@ async def load_config( # noqa: PLR0915 # this is set in the cache branch # see usage here: https://docs.litellm.ai/docs/proxy/caching pass + elif key == "responses": + # Initialize global polling via cache settings + global polling_via_cache_enabled, polling_cache_ttl + background_mode = value.get("background_mode", {}) + polling_via_cache_enabled = background_mode.get("polling_via_cache", False) + polling_cache_ttl = background_mode.get("ttl", 3600) + verbose_proxy_logger.debug( + f"{blue_color_code} Initialized polling via cache: enabled={polling_via_cache_enabled}, ttl={polling_cache_ttl}{reset_color_code}" + ) elif key == "default_team_settings": for idx, team_setting in enumerate( value diff --git a/litellm/proxy/response_api_endpoints/endpoints.py b/litellm/proxy/response_api_endpoints/endpoints.py index 26d10c1ac47..01e70298ded 100644 --- a/litellm/proxy/response_api_endpoints/endpoints.py +++ b/litellm/proxy/response_api_endpoints/endpoints.py @@ -1,8 +1,12 @@ -from fastapi import APIRouter, Depends, Request, Response +import asyncio +from fastapi import APIRouter, Depends, HTTPException, Request, Response + +from litellm._logging import verbose_proxy_logger from litellm.proxy._types import * from litellm.proxy.auth.user_api_key_auth import UserAPIKeyAuth, user_api_key_auth from litellm.proxy.common_request_processing import ProxyBaseLLMRequestProcessing +from 
litellm.types.responses.main import DeleteResponseResult router = APIRouter() @@ -30,7 +34,12 @@ async def responses_api( """ Follows the OpenAI Responses API spec: https://platform.openai.com/docs/api-reference/responses + Supports background mode with polling_via_cache for partial response retrieval. + When background=true and polling_via_cache is enabled, returns a polling_id immediately + and streams the response in the background, updating Redis cache. + ```bash + # Normal request curl -X POST http://localhost:4000/v1/responses \ -H "Content-Type: application/json" \ -H "Authorization: Bearer sk-1234" \ @@ -38,14 +47,27 @@ async def responses_api( "model": "gpt-4o", "input": "Tell me about AI" }' + + # Background request with polling + curl -X POST http://localhost:4000/v1/responses \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer sk-1234" \ + -d '{ + "model": "gpt-4o", + "input": "Tell me about AI", + "background": true + }' ``` """ from litellm.proxy.proxy_server import ( _read_request_body, general_settings, llm_router, + polling_cache_ttl, + polling_via_cache_enabled, proxy_config, proxy_logging_obj, + redis_usage_cache, select_data_generator, user_api_base, user_max_tokens, @@ -56,6 +78,74 @@ async def responses_api( ) data = await _read_request_body(request=request) + + # Check if polling via cache should be used for this request + from litellm.proxy.response_polling.polling_handler import should_use_polling_for_request + + should_use_polling = should_use_polling_for_request( + background_mode=data.get("background", False), + polling_via_cache_enabled=polling_via_cache_enabled, + redis_cache=redis_usage_cache, + model=data.get("model", ""), + llm_router=llm_router, + ) + + # If polling is enabled, use polling mode + if should_use_polling: + from litellm.proxy.response_polling.polling_handler import ( + ResponsePollingHandler, + ) + from litellm.proxy.response_polling.background_streaming import ( + background_streaming_task, + ) + + 
verbose_proxy_logger.info( + f"Starting background response with polling for model={data.get('model')}" + ) + + # Initialize polling handler with configured TTL (from global config) + polling_handler = ResponsePollingHandler( + redis_cache=redis_usage_cache, + ttl=polling_cache_ttl # Global var set at startup + ) + + # Generate polling ID + polling_id = ResponsePollingHandler.generate_polling_id() + + # Create initial state in Redis + initial_state = await polling_handler.create_initial_state( + polling_id=polling_id, + request_data=data, + ) + + # Start background task to stream and update cache + asyncio.create_task( + background_streaming_task( + polling_id=polling_id, + data=data.copy(), + polling_handler=polling_handler, + request=request, + fastapi_response=fastapi_response, + user_api_key_dict=user_api_key_dict, + general_settings=general_settings, + llm_router=llm_router, + proxy_config=proxy_config, + proxy_logging_obj=proxy_logging_obj, + select_data_generator=select_data_generator, + user_model=user_model, + user_temperature=user_temperature, + user_request_timeout=user_request_timeout, + user_max_tokens=user_max_tokens, + user_api_base=user_api_base, + version=version, + ) + ) + + # Return OpenAI Response object format (initial state) + # https://platform.openai.com/docs/api-reference/responses/object + return initial_state + + # Normal response flow processor = ProxyBaseLLMRequestProcessing(data=data) try: return await processor.base_process_llm_request( @@ -109,9 +199,18 @@ async def get_response( """ Get a response by ID. 
+ Supports both: + - Polling IDs (litellm_poll_*): Returns cumulative cached content from background responses + - Provider response IDs: Passes through to provider API + Follows the OpenAI Responses API spec: https://platform.openai.com/docs/api-reference/responses/get ```bash + # Get polling response + curl -X GET http://localhost:4000/v1/responses/litellm_poll_abc123 \ + -H "Authorization: Bearer sk-1234" + + # Get provider response curl -X GET http://localhost:4000/v1/responses/resp_abc123 \ -H "Authorization: Bearer sk-1234" ``` @@ -122,6 +221,7 @@ async def get_response( llm_router, proxy_config, proxy_logging_obj, + redis_usage_cache, select_data_generator, user_api_base, user_max_tokens, @@ -130,7 +230,33 @@ async def get_response( user_temperature, version, ) - + from litellm.proxy.response_polling.polling_handler import ResponsePollingHandler + + # Check if this is a polling ID + if ResponsePollingHandler.is_polling_id(response_id): + # Handle polling response + if not redis_usage_cache: + raise HTTPException( + status_code=500, + detail="Redis cache not configured. Polling requires Redis." + ) + + polling_handler = ResponsePollingHandler(redis_cache=redis_usage_cache) + + # Get current state from cache + state = await polling_handler.get_state(response_id) + + if not state: + raise HTTPException( + status_code=404, + detail=f"Polling response {response_id} not found or expired" + ) + + # Return the whole state directly (OpenAI Response object format) + # https://platform.openai.com/docs/api-reference/responses/object + return state + + # Normal provider response flow data = await _read_request_body(request=request) data["response_id"] = response_id processor = ProxyBaseLLMRequestProcessing(data=data) @@ -186,6 +312,10 @@ async def delete_response( """ Delete a response by ID. 
+ Supports both: + - Polling IDs (litellm_poll_*): Deletes from Redis cache + - Provider response IDs: Passes through to provider API + Follows the OpenAI Responses API spec: https://platform.openai.com/docs/api-reference/responses/delete ```bash @@ -199,6 +329,7 @@ async def delete_response( llm_router, proxy_config, proxy_logging_obj, + redis_usage_cache, select_data_generator, user_api_base, user_max_tokens, @@ -207,7 +338,44 @@ async def delete_response( user_temperature, version, ) - + from litellm.proxy.response_polling.polling_handler import ResponsePollingHandler + + # Check if this is a polling ID + if ResponsePollingHandler.is_polling_id(response_id): + # Handle polling response deletion + if not redis_usage_cache: + raise HTTPException( + status_code=500, + detail="Redis cache not configured." + ) + + polling_handler = ResponsePollingHandler(redis_cache=redis_usage_cache) + + # Get state to verify access + state = await polling_handler.get_state(response_id) + + if not state: + raise HTTPException( + status_code=404, + detail=f"Polling response {response_id} not found" + ) + + # Delete from cache + success = await polling_handler.delete_polling(response_id) + + if success: + return DeleteResponseResult( + id=response_id, + object="response", + deleted=True + ) + else: + raise HTTPException( + status_code=500, + detail="Failed to delete polling response" + ) + + # Normal provider response flow data = await _read_request_body(request=request) data["response_id"] = response_id processor = ProxyBaseLLMRequestProcessing(data=data) @@ -331,9 +499,18 @@ async def cancel_response( """ Cancel a response by ID. 
+ Supports both: + - Polling IDs (litellm_poll_*): Cancels background response and updates status in Redis + - Provider response IDs: Passes through to provider API + Follows the OpenAI Responses API spec: https://platform.openai.com/docs/api-reference/responses/cancel ```bash + # Cancel polling response + curl -X POST http://localhost:4000/v1/responses/litellm_poll_abc123/cancel \ + -H "Authorization: Bearer sk-1234" + + # Cancel provider response curl -X POST http://localhost:4000/v1/responses/resp_abc123/cancel \ -H "Authorization: Bearer sk-1234" ``` @@ -344,6 +521,7 @@ async def cancel_response( llm_router, proxy_config, proxy_logging_obj, + redis_usage_cache, select_data_generator, user_api_base, user_max_tokens, @@ -352,7 +530,44 @@ async def cancel_response( user_temperature, version, ) - + from litellm.proxy.response_polling.polling_handler import ResponsePollingHandler + + # Check if this is a polling ID + if ResponsePollingHandler.is_polling_id(response_id): + # Handle polling response cancellation + if not redis_usage_cache: + raise HTTPException( + status_code=500, + detail="Redis cache not configured." 
+ ) + + polling_handler = ResponsePollingHandler(redis_cache=redis_usage_cache) + + # Get current state to verify it exists + state = await polling_handler.get_state(response_id) + + if not state: + raise HTTPException( + status_code=404, + detail=f"Polling response {response_id} not found" + ) + + # Cancel the polling response (sets status to "cancelled") + success = await polling_handler.cancel_polling(response_id) + + if success: + # Fetch the updated state with cancelled status + updated_state = await polling_handler.get_state(response_id) + + # Return the whole state directly (now with status="cancelled") + return updated_state + else: + raise HTTPException( + status_code=500, + detail="Failed to cancel polling response" + ) + + # Normal provider response flow data = await _read_request_body(request=request) data["response_id"] = response_id processor = ProxyBaseLLMRequestProcessing(data=data) diff --git a/litellm/proxy/response_polling/__init__.py b/litellm/proxy/response_polling/__init__.py new file mode 100644 index 00000000000..b500354c373 --- /dev/null +++ b/litellm/proxy/response_polling/__init__.py @@ -0,0 +1,16 @@ +""" +Response Polling Module for Background Responses with Cache +""" +from litellm.proxy.response_polling.background_streaming import ( + background_streaming_task, +) +from litellm.proxy.response_polling.polling_handler import ( + ResponsePollingHandler, + should_use_polling_for_request, +) + +__all__ = [ + "ResponsePollingHandler", + "background_streaming_task", + "should_use_polling_for_request", +] diff --git a/litellm/proxy/response_polling/background_streaming.py b/litellm/proxy/response_polling/background_streaming.py new file mode 100644 index 00000000000..b0dcb69a82e --- /dev/null +++ b/litellm/proxy/response_polling/background_streaming.py @@ -0,0 +1,306 @@ +""" +Background Streaming Task for Polling Via Cache Feature + +Handles streaming responses from LLM providers and updates Redis cache +with partial results for polling. 
# Terminal streaming events mapped to the Response status each one implies.
# https://platform.openai.com/docs/api-reference/responses-streaming
_TERMINAL_EVENT_STATUS = {
    "response.completed": "completed",
    "response.failed": "failed",
    "response.incomplete": "incomplete",
}

# Fields copied from the terminal event's ``response`` object into the cached
# state; every name here is a keyword accepted by the handler's update_state.
_FINAL_RESPONSE_FIELDS = (
    "usage",
    "reasoning",
    "tool_choice",
    "tools",
    "model",
    "instructions",
    "temperature",
    "top_p",
    "max_output_tokens",
    "previous_response_id",
    "text",
    "truncation",
    "parallel_tool_calls",
    "user",
    "store",
    "incomplete_details",
)

# Minimum number of seconds between two cache writes of partial output.
_UPDATE_INTERVAL = 0.150


def _iter_sse_payloads(chunk):
    """Yield the payload of every ``data:`` line found in one streamed chunk.

    A chunk may be bytes or str and may carry several lines (for example an
    ``event:`` line followed by a ``data:`` line); anything else is ignored.
    """
    if isinstance(chunk, bytes):
        chunk = chunk.decode("utf-8")
    if not isinstance(chunk, str):
        return
    for line in chunk.splitlines():
        if line.startswith("data:"):
            yield line[len("data:"):].strip()


def _apply_output_event(event, output_items, accumulated_text):
    """Fold one output-related streaming event into ``output_items``.

    Args:
        event: Decoded streaming event (a dict with a ``type`` key).
        output_items: Output items keyed by item id; mutated in place.
        accumulated_text: Text accumulated so far, keyed by
            ``(item_id, content_index)``; mutated in place.

    Returns:
        True when ``output_items`` changed and should be flushed to the cache.
    """
    event_type = event.get("type", "")

    if event_type in ("response.output_item.added", "response.output_item.done"):
        item = event.get("item") or {}
        item_id = item.get("id")
        if not item_id:
            return False
        # "done" carries the final item, so it simply replaces the partial one.
        output_items[item_id] = item
        return True

    if event_type not in (
        "response.content_part.added",
        "response.output_text.delta",
        "response.content_part.done",
    ):
        return False

    item_id = event.get("item_id")
    if not item_id or item_id not in output_items:
        return False
    item = output_items[item_id]

    if event_type == "response.content_part.added":
        content = item.get("content")
        if not isinstance(content, list):
            content = item["content"] = []
        content.append(event.get("part", {}))
        return True

    content_index = event.get("content_index", 0)

    if event_type == "response.output_text.delta":
        # Keep accumulating even if the content part is not registered yet.
        key = (item_id, content_index)
        accumulated_text[key] = accumulated_text.get(key, "") + event.get("delta", "")

    content = item.get("content")
    if not isinstance(content, list) or not 0 <= content_index < len(content):
        return False

    if event_type == "response.output_text.delta":
        if not isinstance(content[content_index], dict):
            return False
        content[content_index]["text"] = accumulated_text[(item_id, content_index)]
        return True

    # response.content_part.done: replace the partial part with the final one.
    content[content_index] = event.get("part", {})
    return True


async def background_streaming_task(  # noqa: PLR0915
    polling_id: str,
    data: dict,
    polling_handler: "ResponsePollingHandler",
    request: "Request",
    fastapi_response: "Response",
    user_api_key_dict: "UserAPIKeyAuth",
    general_settings: dict,
    llm_router,
    proxy_config,
    proxy_logging_obj,
    select_data_generator,
    user_model,
    user_temperature,
    user_request_timeout,
    user_max_tokens,
    user_api_base,
    version,
):
    """
    Stream a response in the background and mirror its progress into the cache.

    Follows OpenAI Response Streaming format:
    https://platform.openai.com/docs/api-reference/responses-streaming

    Processes streaming events and builds a Response object:
    https://platform.openai.com/docs/api-reference/responses/object

    The request is forced into streaming mode, output items are rebuilt from
    the streamed events, and partial output is written to the cache at most
    once per ``_UPDATE_INTERVAL`` seconds. The final status comes from the
    terminal event (``completed``, ``failed`` or ``incomplete``) and defaults
    to ``completed`` when the stream ends without one.

    Once the cached state reports ``cancelled`` the task stops consuming the
    stream and performs no further writes, so a cancellation is never
    overwritten. The task never raises ordinary exceptions: failures are
    logged and recorded as status ``failed`` on a best-effort basis.

    Args:
        polling_id: Identifier of the cached polling state to update.
        data: Request body; mutated (``stream`` set, ``background`` removed).
        polling_handler: Handler used to read and write the cached state.

    The remaining arguments are forwarded unchanged to
    ``ProxyBaseLLMRequestProcessing.base_process_llm_request``.
    """
    cancelled = False

    async def write_state(**fields) -> bool:
        """Write ``fields`` unless cancelled; return False once cancellation is seen."""
        nonlocal cancelled
        if not cancelled:
            current = await polling_handler.get_state(polling_id)
            cancelled = bool(current) and current.get("status") == "cancelled"
        if cancelled:
            return False
        await polling_handler.update_state(polling_id=polling_id, **fields)
        return True

    try:
        verbose_proxy_logger.info(f"Starting background streaming for {polling_id}")

        # Update status to in_progress (OpenAI format)
        if not await write_state(status="in_progress"):
            verbose_proxy_logger.info(
                f"Background streaming for {polling_id} cancelled before start"
            )
            return

        # Force streaming mode and remove background flag
        data["stream"] = True
        data.pop("background", None)

        processor = ProxyBaseLLMRequestProcessing(data=data)
        response = await processor.base_process_llm_request(
            request=request,
            fastapi_response=fastapi_response,
            user_api_key_dict=user_api_key_dict,
            route_type="aresponses",
            proxy_logging_obj=proxy_logging_obj,
            llm_router=llm_router,
            general_settings=general_settings,
            proxy_config=proxy_config,
            select_data_generator=select_data_generator,
            model=None,
            user_model=user_model,
            user_temperature=user_temperature,
            user_request_timeout=user_request_timeout,
            user_max_tokens=user_max_tokens,
            user_api_base=user_api_base,
            version=version,
        )

        output_items = {}  # Output items by ID
        accumulated_text = {}  # Text deltas by (item_id, content_index)
        final_fields = {}  # Fields taken from the terminal event
        final_status = "completed"
        final_error = None

        clock = asyncio.get_running_loop().time
        state_dirty = False  # True when output_items has unsaved changes
        last_flush = clock()

        async def flush_state_if_needed(force: bool = False) -> None:
            """Flush accumulated output to the cache if the interval elapsed or forced."""
            nonlocal state_dirty, last_flush

            now = clock()
            if not state_dirty:
                return
            if not force and (now - last_flush) < _UPDATE_INTERVAL:
                return
            await write_state(output=list(output_items.values()))
            state_dirty = False
            last_flush = now

        if hasattr(response, "body_iterator"):
            stream_finished = False
            async for chunk in response.body_iterator:
                for payload in _iter_sse_payloads(chunk):
                    if payload == "[DONE]":
                        stream_finished = True
                        break

                    try:
                        event = json.loads(payload)
                    except json.JSONDecodeError as e:
                        verbose_proxy_logger.warning(
                            f"Failed to parse streaming chunk: {e}"
                        )
                        continue
                    if not isinstance(event, dict):
                        continue

                    event_type = event.get("type", "")
                    if event_type == "response.in_progress":
                        await write_state(status="in_progress")
                    elif event_type in _TERMINAL_EVENT_STATUS:
                        response_data = event.get("response") or {}
                        final_status = _TERMINAL_EVENT_STATUS[event_type]
                        final_error = response_data.get("error")
                        final_fields = {
                            name: response_data.get(name)
                            for name in _FINAL_RESPONSE_FIELDS
                        }
                        # The terminal event carries the authoritative output.
                        for item in response_data.get("output") or []:
                            item_id = item.get("id")
                            if item_id:
                                output_items[item_id] = item
                                state_dirty = True
                    elif _apply_output_event(event, output_items, accumulated_text):
                        state_dirty = True

                    await flush_state_if_needed()
                    if cancelled:
                        break

                if stream_finished or cancelled:
                    break
        else:
            verbose_proxy_logger.warning(
                f"Response for {polling_id} is not a streaming response; no output captured"
            )

        # Final flush to ensure all accumulated state is saved
        await flush_state_if_needed(force=True)

        # Record the terminal status with all ResponsesAPIResponse fields
        if not await write_state(
            status=final_status,
            error=final_error,
            **final_fields,
        ):
            verbose_proxy_logger.info(
                f"Background streaming for {polling_id} stopped: response was cancelled"
            )
            return

        verbose_proxy_logger.info(
            f"Completed background streaming for {polling_id}, output_items={len(output_items)}"
        )

    except Exception as e:
        verbose_proxy_logger.exception(
            f"Error in background streaming task for {polling_id}: {str(e)}"
        )
        if cancelled:
            # Never replace a user-requested cancellation with "failed".
            return

        try:
            await polling_handler.update_state(
                polling_id=polling_id,
                status="failed",
                error={
                    "type": "internal_error",
                    "message": str(e),
                    "code": "background_streaming_error",
                },
            )
        except Exception as report_error:
            # This coroutine runs as a fire-and-forget task: nothing awaits it,
            # so an exception escaping here would only surface as noise.
            verbose_proxy_logger.error(
                f"Could not record failure for {polling_id}: {report_error}"
            )
class ResponsePollingHandler:
    """Store, update and retrieve background-response polling state in Redis.

    Each polling request is one cache entry holding a Response object in the
    OpenAI format:
    https://platform.openai.com/docs/api-reference/responses/object

    Entries are keyed by ``CACHE_KEY_PREFIX + polling_id`` and rewritten with
    the configured TTL on every update. Project types in the annotations are
    quoted so they are not evaluated when the module is imported.
    """

    CACHE_KEY_PREFIX = "litellm:polling:response:"
    POLLING_ID_PREFIX = "litellm_poll_"  # Clear prefix to identify polling IDs

    def __init__(self, redis_cache: "Optional[RedisCache]" = None, ttl: int = 3600):
        self.redis_cache = redis_cache
        self.ttl = ttl  # Time-to-live for cache entries (default: 1 hour)

    @classmethod
    def generate_polling_id(cls) -> str:
        """Generate a unique polling ID: the polling prefix followed by a UUID."""
        return f"{cls.POLLING_ID_PREFIX}{uuid4()}"

    @classmethod
    def is_polling_id(cls, response_id: str) -> bool:
        """Return True when ``response_id`` starts with the polling prefix."""
        return response_id.startswith(cls.POLLING_ID_PREFIX)

    @classmethod
    def get_cache_key(cls, polling_id: str) -> str:
        """Return the Redis cache key for a polling ID."""
        return f"{cls.CACHE_KEY_PREFIX}{polling_id}"

    @staticmethod
    def _decode_state(cached_state: Any) -> Dict[str, Any]:
        """Return the cached state as a dict.

        Accepts both a JSON document (str/bytes) and a value the cache layer
        has already decoded into a dict.
        """
        if isinstance(cached_state, dict):
            return cached_state
        return json.loads(cached_state)

    async def create_initial_state(
        self,
        polling_id: str,
        request_data: Dict[str, Any],
    ) -> "ResponsesAPIResponse":
        """
        Create the initial ``queued`` state for a polling request.

        Uses OpenAI ResponsesAPIResponse object:
        https://platform.openai.com/docs/api-reference/responses/object

        Args:
            polling_id: Unique identifier for this polling request
            request_data: Original request data; only ``metadata`` is read

        Returns:
            The ResponsesAPIResponse object. It is also stored in Redis when a
            cache is configured.
        """
        created_timestamp = int(datetime.now(timezone.utc).timestamp())

        # Create OpenAI-compliant response object
        response = ResponsesAPIResponse(
            id=polling_id,
            object="response",
            status="queued",  # OpenAI native status
            created_at=created_timestamp,
            output=[],
            metadata=request_data.get("metadata", {}),
            usage=None,
        )

        if self.redis_cache:
            await self.redis_cache.async_set_cache(
                key=self.get_cache_key(polling_id),
                value=response.model_dump_json(),  # Pydantic v2 method
                ttl=self.ttl,
            )
            verbose_proxy_logger.debug(
                f"Created initial polling state for {polling_id} with TTL={self.ttl}s"
            )

        return response

    async def update_state(
        self,
        polling_id: str,
        status: "Optional[ResponsesAPIStatus]" = None,
        usage: Optional[Dict] = None,
        error: Optional[Dict] = None,
        incomplete_details: Optional[Dict] = None,
        reasoning: Optional[Dict] = None,
        tool_choice: Optional[Any] = None,
        tools: Optional[list] = None,
        output: Optional[list] = None,
        # Additional ResponsesAPIResponse fields
        model: Optional[str] = None,
        instructions: Optional[str] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        max_output_tokens: Optional[int] = None,
        previous_response_id: Optional[str] = None,
        text: Optional[Dict] = None,
        truncation: Optional[str] = None,
        parallel_tool_calls: Optional[bool] = None,
        user: Optional[str] = None,
        store: Optional[bool] = None,
    ) -> None:
        """
        Merge the given fields into the cached polling state.

        Uses OpenAI Response object format with native status types:
        https://platform.openai.com/docs/api-reference/responses/object

        Does nothing when no cache is configured, when no state is stored for
        ``polling_id``, or when the stored state is already ``cancelled`` (so
        a background writer cannot undo a cancellation).

        Args:
            polling_id: Unique identifier for this polling request
            status: OpenAI ResponsesAPIStatus value; ignored when falsy
            usage: Usage information; ignored when falsy
            error: Error dict (also sets status to "failed"); ignored when falsy
            incomplete_details: Details for incomplete responses; ignored when falsy
            output: Full output list that replaces the current output

        All remaining keyword arguments are stored under the key of the same
        name whenever they are not None.
        """
        if not self.redis_cache:
            return

        cache_key = self.get_cache_key(polling_id)

        cached_state = await self.redis_cache.async_get_cache(cache_key)
        if not cached_state:
            verbose_proxy_logger.warning(
                f"No cached state found for polling_id: {polling_id}"
            )
            return

        state = self._decode_state(cached_state)

        if state.get("status") == "cancelled":
            verbose_proxy_logger.debug(
                f"Ignoring update for cancelled polling_id: {polling_id}"
            )
            return

        # Update status (using OpenAI native status values)
        if status:
            state["status"] = status

        # Replace full output list if provided
        if output is not None:
            state["output"] = output

        if usage:
            state["usage"] = usage

        # An error always wins over the requested status
        if error:
            state["status"] = "failed"
            state["error"] = error  # Use OpenAI's 'error' field

        if incomplete_details:
            state["incomplete_details"] = incomplete_details

        optional_fields = {
            "reasoning": reasoning,
            "tool_choice": tool_choice,
            "tools": tools,
            "model": model,
            "instructions": instructions,
            "temperature": temperature,
            "top_p": top_p,
            "max_output_tokens": max_output_tokens,
            "previous_response_id": previous_response_id,
            "text": text,
            "truncation": truncation,
            "parallel_tool_calls": parallel_tool_calls,
            "user": user,
            "store": store,
        }
        state.update(
            {name: value for name, value in optional_fields.items() if value is not None}
        )

        # Update cache with configured TTL
        await self.redis_cache.async_set_cache(
            key=cache_key,
            value=json.dumps(state),
            ttl=self.ttl,
        )

        output_count = len(state.get("output") or [])
        verbose_proxy_logger.debug(
            f"Updated polling state for {polling_id}: status={state.get('status')}, output_items={output_count}"
        )

    async def get_state(self, polling_id: str) -> Optional[Dict[str, Any]]:
        """Return the cached polling state, or None when unavailable."""
        if not self.redis_cache:
            return None

        cached_state = await self.redis_cache.async_get_cache(
            self.get_cache_key(polling_id)
        )
        if not cached_state:
            return None

        return self._decode_state(cached_state)

    async def cancel_polling(self, polling_id: str) -> bool:
        """
        Mark a polling request as cancelled.

        Following OpenAI Response object format for cancelled status.

        Returns:
            True when the stored state is cancelled afterwards; False when
            there is no cache or no state to cancel.
        """
        state = await self.get_state(polling_id)
        if state is None:
            return False

        if state.get("status") != "cancelled":
            await self.update_state(
                polling_id=polling_id,
                status="cancelled",
            )
        return True

    async def delete_polling(self, polling_id: str) -> bool:
        """Delete a polling request from the cache.

        Returns:
            True when the delete command was issued; False when no cache (or
            no async Redis client) is available.
        """
        if not self.redis_cache:
            return False

        cache_key = self.get_cache_key(polling_id)
        # Redis client's delete method
        if hasattr(self.redis_cache, "redis_async_client"):
            async_client = self.redis_cache.init_async_client()
            await async_client.delete(cache_key)
            return True

        return False


def should_use_polling_for_request(
    background_mode: bool,
    polling_via_cache_enabled,  # Can be False, True, "all", or List[str]
    redis_cache,  # RedisCache or None
    model: str,
    llm_router,  # Router instance or None
) -> bool:
    """
    Determine if polling via cache should be used for a request.

    Args:
        background_mode: Whether background=true was set in the request
        polling_via_cache_enabled: Config value - False, True/"all" (every
            provider), or a list of provider names
        redis_cache: Redis cache instance (required for polling)
        model: Model name from the request (e.g., "gpt-5" or "openai/gpt-4o")
        llm_router: LiteLLM router instance for looking up model deployments

    Returns:
        True if polling should be used, False otherwise
    """
    # All conditions must be met
    if not (background_mode and polling_via_cache_enabled and redis_cache):
        return False

    # True and "all" both enable polling for every provider
    if polling_via_cache_enabled is True or polling_via_cache_enabled == "all":
        return True

    if not isinstance(polling_via_cache_enabled, list) or not isinstance(model, str):
        return False

    # First, try the "provider/model" form of the requested model
    if "/" in model and model.split("/")[0] in polling_via_cache_enabled:
        return True

    if llm_router is None:
        return False

    # Otherwise the name may be a router alias: check ALL of its deployments
    try:
        indices = llm_router.model_name_to_deployment_indices.get(model, [])
        for idx in indices:
            litellm_params = llm_router.model_list[idx].get("litellm_params", {})

            # Check custom_llm_provider first
            dep_provider = litellm_params.get("custom_llm_provider")

            # Then try to extract from model (e.g., "openai/gpt-5")
            if not dep_provider:
                dep_model = litellm_params.get("model", "")
                if "/" in dep_model:
                    dep_provider = dep_model.split("/")[0]

            # If ANY deployment's provider matches, enable polling
            if dep_provider and dep_provider in polling_via_cache_enabled:
                verbose_proxy_logger.debug(
                    f"Polling enabled for model={model}, provider={dep_provider}"
                )
                return True
    except Exception as e:
        verbose_proxy_logger.debug(
            f"Could not resolve provider for model {model}: {e}"
        )

    return False
"""
Unit tests for ResponsePollingHandler and the polling decision helper.

Covered behaviour:
1. Polling ID generation and detection
2. Initial state creation (queued status)
3. State updates with batched output
4. Status transitions (queued -> in_progress -> completed)
5. Response completion with reasoning, tools, tool_choice
6. Error handling and cancellation
7. Cache key generation
8. ``should_use_polling_for_request`` (config / provider / router resolution)

These tests ensure the polling handler correctly manages response state
following the OpenAI Response API format.
"""

import json
import os
import sys
from datetime import datetime, timezone
from typing import Any, Dict, Optional
from unittest.mock import AsyncMock, Mock, patch

import pytest

sys.path.insert(0, os.path.abspath("../.."))

from litellm.proxy.response_polling.polling_handler import (
    ResponsePollingHandler,
    should_use_polling_for_request,
)

# Polling ID shared by every test that pre-seeds a cached state.
_POLLING_ID = "litellm_poll_test"


def _cached_state(
    status: str = "in_progress",
    output: Optional[list] = None,
    **extra: Any,
) -> Dict[str, Any]:
    """Return a minimal cached response state, optionally extended with *extra* keys."""
    state: Dict[str, Any] = {
        "id": _POLLING_ID,
        "object": "response",
        "status": status,
        "output": [] if output is None else output,
        "created_at": 1234567890,
    }
    state.update(extra)
    return state


def _redis_with_state(**state_kwargs: Any) -> AsyncMock:
    """Return a mocked Redis cache whose ``async_get_cache`` yields a JSON state."""
    mock_redis = AsyncMock()
    mock_redis.async_get_cache.return_value = json.dumps(
        _cached_state(**state_kwargs)
    )
    return mock_redis


def _stored_state(mock_redis: AsyncMock) -> Dict[str, Any]:
    """Decode the JSON payload of the most recent ``async_set_cache`` call."""
    call_args = mock_redis.async_set_cache.call_args
    return json.loads(call_args.kwargs["value"])


def _router_stub(*deployments: Dict[str, Any]) -> Mock:
    """
    Build a router double exposing only the two attributes that
    ``should_use_polling_for_request`` reads: ``model_list`` and
    ``model_name_to_deployment_indices``.
    """
    router = Mock()
    router.model_list = list(deployments)
    indices: Dict[str, list] = {}
    for idx, deployment in enumerate(deployments):
        indices.setdefault(deployment["model_name"], []).append(idx)
    router.model_name_to_deployment_indices = indices
    return router


def _should_poll(**overrides: Any) -> bool:
    """
    Call ``should_use_polling_for_request`` with "everything enabled" defaults.

    Each test overrides only the argument it is exercising, so the intent of
    the test is visible at the call site.
    """
    kwargs: Dict[str, Any] = {
        "background_mode": True,
        "polling_via_cache_enabled": "all",
        "redis_cache": Mock(),
        "model": "gpt-4o",
        "llm_router": None,
    }
    kwargs.update(overrides)
    return should_use_polling_for_request(**kwargs)


class TestResponsePollingHandler:
    """Test cases for ResponsePollingHandler"""

    # ==================== Polling ID Tests ====================

    def test_generate_polling_id_has_correct_prefix(self):
        """Test that generated polling IDs have the correct prefix"""
        polling_id = ResponsePollingHandler.generate_polling_id()

        assert polling_id.startswith("litellm_poll_")
        assert len(polling_id) > len("litellm_poll_")  # Has UUID after prefix

    def test_generate_polling_id_is_unique(self):
        """Test that each generated polling ID is unique"""
        ids = [ResponsePollingHandler.generate_polling_id() for _ in range(100)]

        assert len(ids) == len(set(ids))  # All unique

    def test_is_polling_id_returns_true_for_polling_ids(self):
        """Test that is_polling_id correctly identifies polling IDs"""
        polling_id = ResponsePollingHandler.generate_polling_id()

        assert ResponsePollingHandler.is_polling_id(polling_id) is True

    def test_is_polling_id_returns_false_for_provider_ids(self):
        """Test that is_polling_id returns False for provider response IDs"""
        # OpenAI format
        assert ResponsePollingHandler.is_polling_id("resp_abc123") is False
        # Anthropic format
        assert (
            ResponsePollingHandler.is_polling_id("msg_01XFDUDYJgAACzvnptvVoYEL")
            is False
        )
        # Generic UUID
        assert (
            ResponsePollingHandler.is_polling_id(
                "550e8400-e29b-41d4-a716-446655440000"
            )
            is False
        )

    def test_get_cache_key_format(self):
        """Test that cache keys have the correct format"""
        polling_id = "litellm_poll_abc123"
        cache_key = ResponsePollingHandler.get_cache_key(polling_id)

        assert cache_key == "litellm:polling:response:litellm_poll_abc123"

    # ==================== Initial State Tests ====================

    @pytest.mark.asyncio
    async def test_create_initial_state_returns_queued_status(self):
        """Test that create_initial_state returns response with queued status"""
        mock_redis = AsyncMock()
        handler = ResponsePollingHandler(redis_cache=mock_redis, ttl=3600)

        polling_id = "litellm_poll_test123"
        request_data = {
            "model": "gpt-4o",
            "input": "Hello",
            "metadata": {"test": "value"},
        }

        response = await handler.create_initial_state(
            polling_id=polling_id,
            request_data=request_data,
        )

        assert response.id == polling_id
        assert response.object == "response"
        assert response.status == "queued"
        assert response.output == []
        assert response.usage is None
        assert response.metadata == {"test": "value"}

    @pytest.mark.asyncio
    async def test_create_initial_state_stores_in_redis(self):
        """Test that create_initial_state stores state in Redis with correct TTL"""
        mock_redis = AsyncMock()
        handler = ResponsePollingHandler(redis_cache=mock_redis, ttl=7200)

        polling_id = "litellm_poll_test123"
        request_data = {"model": "gpt-4o", "input": "Hello"}

        await handler.create_initial_state(
            polling_id=polling_id,
            request_data=request_data,
        )

        # Verify Redis was called with correct parameters
        mock_redis.async_set_cache.assert_called_once()
        call_args = mock_redis.async_set_cache.call_args

        assert (
            call_args.kwargs["key"]
            == "litellm:polling:response:litellm_poll_test123"
        )
        assert call_args.kwargs["ttl"] == 7200

        # Verify the stored value is valid JSON
        parsed = _stored_state(mock_redis)
        assert parsed["id"] == polling_id
        assert parsed["status"] == "queued"

    @pytest.mark.asyncio
    async def test_create_initial_state_sets_created_at_timestamp(self):
        """Test that create_initial_state sets a valid created_at timestamp"""
        mock_redis = AsyncMock()
        handler = ResponsePollingHandler(redis_cache=mock_redis)

        before_time = int(datetime.now(timezone.utc).timestamp())

        response = await handler.create_initial_state(
            polling_id=_POLLING_ID,
            request_data={},
        )

        after_time = int(datetime.now(timezone.utc).timestamp())

        assert before_time <= response.created_at <= after_time

    # ==================== State Update Tests ====================

    @pytest.mark.asyncio
    async def test_update_state_changes_status_to_in_progress(self):
        """Test that update_state can change status to in_progress"""
        mock_redis = _redis_with_state(status="queued")
        handler = ResponsePollingHandler(redis_cache=mock_redis, ttl=3600)

        await handler.update_state(
            polling_id=_POLLING_ID,
            status="in_progress",
        )

        # Verify the update was saved
        mock_redis.async_set_cache.assert_called_once()
        assert _stored_state(mock_redis)["status"] == "in_progress"

    @pytest.mark.asyncio
    async def test_update_state_replaces_full_output_list(self):
        """Test that update_state replaces the full output list"""
        mock_redis = _redis_with_state(
            output=[{"id": "old_item", "type": "message"}]
        )
        handler = ResponsePollingHandler(redis_cache=mock_redis, ttl=3600)

        new_output = [
            {
                "id": "item_1",
                "type": "message",
                "content": [{"type": "text", "text": "Hello"}],
            },
            {
                "id": "item_2",
                "type": "message",
                "content": [{"type": "text", "text": "World"}],
            },
        ]

        await handler.update_state(
            polling_id=_POLLING_ID,
            output=new_output,
        )

        stored = _stored_state(mock_redis)

        assert len(stored["output"]) == 2
        assert stored["output"][0]["id"] == "item_1"
        assert stored["output"][1]["id"] == "item_2"

    @pytest.mark.asyncio
    async def test_update_state_with_usage(self):
        """Test that update_state correctly stores usage data"""
        mock_redis = _redis_with_state()
        handler = ResponsePollingHandler(redis_cache=mock_redis)

        usage_data = {
            "input_tokens": 10,
            "output_tokens": 50,
            "total_tokens": 60,
        }

        await handler.update_state(
            polling_id=_POLLING_ID,
            status="completed",
            usage=usage_data,
        )

        stored = _stored_state(mock_redis)

        assert stored["status"] == "completed"
        assert stored["usage"] == usage_data

    @pytest.mark.asyncio
    async def test_update_state_with_reasoning_tools_tool_choice(self):
        """Test that update_state stores reasoning, tools, and tool_choice from response.completed"""
        mock_redis = _redis_with_state()
        handler = ResponsePollingHandler(redis_cache=mock_redis)

        reasoning_data = {"effort": "medium", "summary": "Step by step analysis"}
        tool_choice_data = {"type": "function", "function": {"name": "get_weather"}}
        tools_data = [
            {"type": "function", "function": {"name": "get_weather", "parameters": {}}}
        ]

        await handler.update_state(
            polling_id=_POLLING_ID,
            status="completed",
            reasoning=reasoning_data,
            tool_choice=tool_choice_data,
            tools=tools_data,
        )

        stored = _stored_state(mock_redis)

        assert stored["reasoning"] == reasoning_data
        assert stored["tool_choice"] == tool_choice_data
        assert stored["tools"] == tools_data

    @pytest.mark.asyncio
    async def test_update_state_with_all_responses_api_fields(self):
        """Test that update_state stores all ResponsesAPIResponse fields from response.completed"""
        mock_redis = _redis_with_state()
        handler = ResponsePollingHandler(redis_cache=mock_redis)

        # All ResponsesAPIResponse fields that can be updated. Keeping them in
        # one mapping guarantees every field passed in is also asserted on.
        completed_fields: Dict[str, Any] = {
            "usage": {"input_tokens": 10, "output_tokens": 50, "total_tokens": 60},
            "reasoning": {"effort": "medium"},
            "tool_choice": {"type": "auto"},
            "tools": [{"type": "function", "function": {"name": "test"}}],
            "model": "gpt-4o",
            "instructions": "You are a helpful assistant",
            "temperature": 0.7,
            "top_p": 0.9,
            "max_output_tokens": 1000,
            "previous_response_id": "resp_prev_123",
            "text": {"format": {"type": "text"}},
            "truncation": "auto",
            "parallel_tool_calls": True,
            "user": "user_123",
            "store": True,
            "incomplete_details": {"reason": "max_output_tokens"},
        }

        await handler.update_state(
            polling_id=_POLLING_ID,
            status="completed",
            **completed_fields,
        )

        stored = _stored_state(mock_redis)

        # Verify all fields are stored correctly
        assert stored["status"] == "completed"
        for field, expected in completed_fields.items():
            assert stored[field] == expected, field
        # Booleans must survive the JSON round trip as real booleans.
        assert stored["parallel_tool_calls"] is True
        assert stored["store"] is True

    @pytest.mark.asyncio
    async def test_update_state_preserves_existing_fields(self):
        """Test that update_state preserves fields not being updated"""
        mock_redis = _redis_with_state(
            output=[{"id": "item_1", "type": "message"}],
            model="gpt-4o",
            temperature=0.5,
        )
        handler = ResponsePollingHandler(redis_cache=mock_redis)

        # Only update status
        await handler.update_state(
            polling_id=_POLLING_ID,
            status="completed",
        )

        stored = _stored_state(mock_redis)

        # Verify existing fields are preserved
        assert stored["status"] == "completed"
        assert stored["model"] == "gpt-4o"
        assert stored["temperature"] == 0.5
        assert stored["output"] == [{"id": "item_1", "type": "message"}]

    @pytest.mark.asyncio
    async def test_update_state_with_error_sets_failed_status(self):
        """Test that providing an error automatically sets status to failed"""
        mock_redis = _redis_with_state()
        handler = ResponsePollingHandler(redis_cache=mock_redis)

        error_data = {
            "type": "internal_error",
            "message": "Something went wrong",
            "code": "server_error",
        }

        await handler.update_state(
            polling_id=_POLLING_ID,
            error=error_data,
        )

        stored = _stored_state(mock_redis)

        assert stored["status"] == "failed"
        assert stored["error"] == error_data

    @pytest.mark.asyncio
    async def test_update_state_with_incomplete_details(self):
        """Test that update_state stores incomplete_details"""
        mock_redis = _redis_with_state()
        handler = ResponsePollingHandler(redis_cache=mock_redis)

        incomplete_details = {"reason": "max_output_tokens"}

        await handler.update_state(
            polling_id=_POLLING_ID,
            status="incomplete",
            incomplete_details=incomplete_details,
        )

        stored = _stored_state(mock_redis)

        assert stored["status"] == "incomplete"
        assert stored["incomplete_details"] == incomplete_details

    @pytest.mark.asyncio
    async def test_update_state_does_nothing_without_redis(self):
        """Test that update_state gracefully handles no Redis cache"""
        handler = ResponsePollingHandler(redis_cache=None)

        # Should not raise an exception
        await handler.update_state(
            polling_id=_POLLING_ID,
            status="in_progress",
        )

    @pytest.mark.asyncio
    async def test_update_state_handles_missing_cached_state(self):
        """Test that update_state handles case when cached state doesn't exist"""
        mock_redis = AsyncMock()
        mock_redis.async_get_cache.return_value = None  # Cache miss

        handler = ResponsePollingHandler(redis_cache=mock_redis)

        # Should not raise an exception
        await handler.update_state(
            polling_id=_POLLING_ID,
            status="in_progress",
        )

        # Should not try to set cache if nothing was found
        mock_redis.async_set_cache.assert_not_called()

    # ==================== Get State Tests ====================

    @pytest.mark.asyncio
    async def test_get_state_returns_cached_state(self):
        """Test that get_state returns the cached state"""
        cached_state = _cached_state(
            output=[{"id": "item_1", "type": "message"}],
            usage={"input_tokens": 10, "output_tokens": 20},
        )
        mock_redis = AsyncMock()
        mock_redis.async_get_cache.return_value = json.dumps(cached_state)

        handler = ResponsePollingHandler(redis_cache=mock_redis)

        result = await handler.get_state(_POLLING_ID)

        assert result == cached_state

    @pytest.mark.asyncio
    async def test_get_state_returns_none_for_missing_state(self):
        """Test that get_state returns None when state doesn't exist"""
        mock_redis = AsyncMock()
        mock_redis.async_get_cache.return_value = None

        handler = ResponsePollingHandler(redis_cache=mock_redis)

        result = await handler.get_state("litellm_poll_nonexistent")

        assert result is None

    @pytest.mark.asyncio
    async def test_get_state_returns_none_without_redis(self):
        """Test that get_state returns None when Redis is not configured"""
        handler = ResponsePollingHandler(redis_cache=None)

        result = await handler.get_state(_POLLING_ID)

        assert result is None

    # ==================== Cancel Polling Tests ====================

    @pytest.mark.asyncio
    async def test_cancel_polling_updates_status_to_cancelled(self):
        """Test that cancel_polling sets status to cancelled"""
        mock_redis = _redis_with_state()
        handler = ResponsePollingHandler(redis_cache=mock_redis)

        result = await handler.cancel_polling(_POLLING_ID)

        assert result is True
        assert _stored_state(mock_redis)["status"] == "cancelled"

    # ==================== Delete Polling Tests ====================

    @pytest.mark.asyncio
    async def test_delete_polling_removes_from_cache(self):
        """Test that delete_polling removes the entry from Redis"""
        mock_redis = AsyncMock()
        mock_async_client = AsyncMock()
        mock_redis.redis_async_client = True  # hasattr check
        # init_async_client is a sync method that returns an async client
        mock_redis.init_async_client = Mock(return_value=mock_async_client)

        handler = ResponsePollingHandler(redis_cache=mock_redis)

        result = await handler.delete_polling(_POLLING_ID)

        assert result is True
        mock_async_client.delete.assert_called_once_with(
            "litellm:polling:response:litellm_poll_test"
        )

    @pytest.mark.asyncio
    async def test_delete_polling_returns_false_without_redis(self):
        """Test that delete_polling returns False when Redis is not configured"""
        handler = ResponsePollingHandler(redis_cache=None)

        result = await handler.delete_polling(_POLLING_ID)

        assert result is False

    # ==================== TTL Tests ====================

    def test_default_ttl_is_one_hour(self):
        """Test that default TTL is 3600 seconds (1 hour)"""
        handler = ResponsePollingHandler(redis_cache=None)

        assert handler.ttl == 3600

    def test_custom_ttl_is_respected(self):
        """Test that custom TTL is stored correctly"""
        handler = ResponsePollingHandler(redis_cache=None, ttl=7200)

        assert handler.ttl == 7200

    @pytest.mark.asyncio
    async def test_update_state_uses_configured_ttl(self):
        """Test that update_state uses the configured TTL"""
        mock_redis = _redis_with_state(status="queued")
        handler = ResponsePollingHandler(redis_cache=mock_redis, ttl=1800)

        await handler.update_state(
            polling_id=_POLLING_ID,
            status="in_progress",
        )

        call_args = mock_redis.async_set_cache.call_args
        assert call_args.kwargs["ttl"] == 1800


class TestStreamingEventProcessing:
    """
    Test cases for streaming event processing logic.

    These tests verify the expected behavior when processing different
    OpenAI streaming event types.

    NOTE(review): every test here operates on local data structures only and
    does not import the background streaming module, so it documents the
    intended bookkeeping shapes rather than guarding the production code.
    """

    def test_accumulated_text_structure(self):
        """Test the structure used for accumulating text deltas"""
        accumulated_text = {}

        # Simulate accumulating deltas for (item_id, content_index)
        key = ("item_123", 0)
        accumulated_text[key] = ""
        accumulated_text[key] += "Hello "
        accumulated_text[key] += "World"

        assert accumulated_text[key] == "Hello World"
        assert ("item_123", 0) in accumulated_text
        assert ("item_123", 1) not in accumulated_text

    def test_output_items_tracking_structure(self):
        """Test the structure used for tracking output items by ID"""
        output_items = {}

        # Simulate adding output items
        item1 = {"id": "item_1", "type": "message", "content": []}
        item2 = {"id": "item_2", "type": "function_call", "name": "get_weather"}

        output_items[item1["id"]] = item1
        output_items[item2["id"]] = item2

        assert len(output_items) == 2
        assert output_items["item_1"]["type"] == "message"
        assert output_items["item_2"]["type"] == "function_call"

    def test_150ms_batch_interval_constant(self):
        """Test that the batch interval is 150ms"""
        # NOTE(review): this pins a locally defined value; it does not read the
        # interval actually used by the streaming task.
        UPDATE_INTERVAL = 0.150  # 150ms

        assert UPDATE_INTERVAL == 0.150
        assert UPDATE_INTERVAL * 1000 == 150  # 150 milliseconds


class TestBackgroundStreamingModule:
    """Test cases for background_streaming module imports and structure"""

    # Imports are deliberately function-local: importability is the behaviour
    # under test, so a failure must surface in the individual test.

    def test_background_streaming_task_can_be_imported(self):
        """Test that background_streaming_task can be imported from the module"""
        from litellm.proxy.response_polling.background_streaming import (
            background_streaming_task,
        )

        assert background_streaming_task is not None
        assert callable(background_streaming_task)

    def test_module_exports_from_init(self):
        """Test that the module exports are available from __init__"""
        from litellm.proxy.response_polling import (
            ResponsePollingHandler,
            background_streaming_task,
        )

        assert ResponsePollingHandler is not None
        assert background_streaming_task is not None

    def test_background_streaming_task_is_async(self):
        """Test that background_streaming_task is an async function"""
        import asyncio
        from litellm.proxy.response_polling.background_streaming import (
            background_streaming_task,
        )

        assert asyncio.iscoroutinefunction(background_streaming_task)


class TestProviderResolutionForPolling:
    """
    Test cases for provider resolution logic used to determine
    if polling_via_cache should be enabled for a given model.

    The logic under test is ``should_use_polling_for_request`` in
    ``litellm/proxy/response_polling/polling_handler.py``. Every test calls
    that function directly (through ``_should_poll``) with a router double, so
    a regression in the production resolution code fails these tests.
    """

    def test_provider_from_model_string_with_slash(self):
        """Test extracting provider from 'provider/model' format"""
        result = _should_poll(
            polling_via_cache_enabled=["openai"],
            model="openai/gpt-4o",
        )

        assert result is True

    def test_provider_from_model_string_without_slash(self):
        """Test that model without slash doesn't extract provider directly"""
        # No slash and no router: the provider cannot be resolved at all.
        result = _should_poll(
            polling_via_cache_enabled=["openai"],
            model="gpt-5",
            llm_router=None,
        )

        assert result is False

    def test_provider_resolution_from_router_single_deployment(self):
        """Test resolving provider from router with single deployment"""
        router = _router_stub(
            {
                "model_name": "gpt-5",
                "litellm_params": {
                    "model": "openai/gpt-5",
                    "api_key": "sk-test",
                },
            }
        )

        result = _should_poll(
            polling_via_cache_enabled=["openai"],
            model="gpt-5",
            llm_router=router,
        )

        assert result is True

    def test_provider_resolution_from_router_multiple_deployments_match(self):
        """Test resolving provider when multiple deployments exist and one matches"""
        router = _router_stub(
            {
                "model_name": "gpt-4o",
                "litellm_params": {"model": "openai/gpt-4o"},
            },
            {
                "model_name": "gpt-4o",
                "litellm_params": {"model": "azure/gpt-4o-deployment"},
            },
        )

        # A match on ANY deployment is enough, regardless of its position.
        assert (
            _should_poll(
                polling_via_cache_enabled=["openai"],
                model="gpt-4o",
                llm_router=router,
            )
            is True
        )
        assert (
            _should_poll(
                polling_via_cache_enabled=["azure"],
                model="gpt-4o",
                llm_router=router,
            )
            is True
        )

    def test_provider_resolution_from_router_no_match(self):
        """Test that polling is disabled when no deployment provider matches"""
        router = _router_stub(
            {
                "model_name": "claude-3",
                "litellm_params": {"model": "anthropic/claude-3-sonnet"},
            }
        )

        result = _should_poll(
            polling_via_cache_enabled=["openai", "bedrock"],  # anthropic not in list
            model="claude-3",
            llm_router=router,
        )

        assert result is False

    def test_provider_resolution_with_custom_llm_provider(self):
        """Test that custom_llm_provider takes precedence over model string"""
        router = _router_stub(
            {
                "model_name": "my-model",
                "litellm_params": {
                    "model": "some-custom-model",
                    "custom_llm_provider": "openai",  # Explicit provider
                },
            }
        )

        result = _should_poll(
            polling_via_cache_enabled=["openai"],
            model="my-model",
            llm_router=router,
        )

        assert result is True

    def test_provider_resolution_model_not_in_router(self):
        """Test that unknown model doesn't enable polling"""
        router = _router_stub(
            {
                "model_name": "gpt-5",
                "litellm_params": {"model": "openai/gpt-5"},
            }
        )

        result = _should_poll(
            polling_via_cache_enabled=["openai"],
            model="unknown-model",  # Not in router
            llm_router=router,
        )

        assert result is False


class TestPollingConditionChecks:
    """
    Test cases for the conditions that determine whether polling should be enabled.
    Tests the should_use_polling_for_request function.
    """

    def test_polling_enabled_when_all_conditions_met(self):
        """Test polling is enabled when background=true, polling_via_cache="all", and redis is available"""
        assert _should_poll() is True

    def test_polling_disabled_when_background_false(self):
        """Test polling is disabled when background=false"""
        assert _should_poll(background_mode=False) is False

    def test_polling_disabled_when_config_false(self):
        """Test polling is disabled when polling_via_cache is False"""
        assert _should_poll(polling_via_cache_enabled=False) is False

    def test_polling_disabled_when_redis_not_configured(self):
        """Test polling is disabled when Redis is not configured"""
        assert _should_poll(redis_cache=None) is False

    def test_polling_enabled_with_provider_list_match(self):
        """Test polling is enabled when provider list matches"""
        result = _should_poll(
            polling_via_cache_enabled=["openai", "anthropic"],
            model="openai/gpt-4o",
        )

        assert result is True

    def test_polling_disabled_with_provider_list_no_match(self):
        """Test polling is disabled when provider not in list"""
        result = _should_poll(
            polling_via_cache_enabled=["openai"],
            model="anthropic/claude-3",
        )

        assert result is False

    def test_polling_with_router_lookup(self):
        """Test polling uses router to resolve model name to provider"""
        mock_router = _router_stub(
            {
                "model_name": "gpt-5",
                "litellm_params": {"model": "openai/gpt-5"},
            }
        )

        result = _should_poll(
            polling_via_cache_enabled=["openai"],
            model="gpt-5",  # No slash, needs router lookup
            llm_router=mock_router,
        )

        assert result is True

    def test_polling_with_router_lookup_no_match(self):
        """Test polling returns False when router lookup finds non-matching provider"""
        mock_router = _router_stub(
            {
                "model_name": "claude-3",
                "litellm_params": {"model": "anthropic/claude-3-sonnet"},
            }
        )

        result = _should_poll(
            polling_via_cache_enabled=["openai"],
            model="claude-3",
            llm_router=mock_router,
        )

        assert result is False


class TestStreamingEventParsing:
    """
    Test cases for parsing OpenAI streaming events in the background task.

    NOTE(review): these tests re-implement the event handling inline instead
    of driving the code in background_streaming.py, so they document the
    expected event shapes but cannot catch regressions in the real parser.
    """

    def test_parse_response_output_item_added_event(self):
        """Test parsing response.output_item.added event"""
        event = {
            "type": "response.output_item.added",
            "item": {
                "id": "item_123",
                "type": "message",
                "role": "assistant",
                "content": [],
            },
        }

        output_items = {}
        event_type = event.get("type", "")

        if event_type == "response.output_item.added":
            item = event.get("item", {})
            item_id = item.get("id")
            if item_id:
                output_items[item_id] = item

        assert "item_123" in output_items
        assert output_items["item_123"]["type"] == "message"

    def test_parse_response_output_text_delta_event(self):
        """Test parsing response.output_text.delta event and accumulating text"""
        output_items = {
            "item_123": {
                "id": "item_123",
                "type": "message",
                "content": [{"type": "text", "text": ""}],
            }
        }
        accumulated_text = {}

        # Simulate receiving multiple delta events
        delta_events = [
            {
                "type": "response.output_text.delta",
                "item_id": "item_123",
                "content_index": 0,
                "delta": "Hello ",
            },
            {
                "type": "response.output_text.delta",
                "item_id": "item_123",
                "content_index": 0,
                "delta": "World!",
            },
        ]

        for event in delta_events:
            event_type = event.get("type", "")
            if event_type == "response.output_text.delta":
                item_id = event.get("item_id")
                content_index = event.get("content_index", 0)
                delta = event.get("delta", "")

                if item_id and item_id in output_items:
                    key = (item_id, content_index)
                    if key not in accumulated_text:
                        accumulated_text[key] = ""
                    accumulated_text[key] += delta

                    # Update content
                    if "content" in output_items[item_id]:
                        content_list = output_items[item_id]["content"]
                        if content_index < len(content_list):
                            if isinstance(content_list[content_index], dict):
                                content_list[content_index]["text"] = accumulated_text[key]

        assert accumulated_text[("item_123", 0)] == "Hello World!"
        assert output_items["item_123"]["content"][0]["text"] == "Hello World!"

    def test_parse_response_completed_event(self):
        """Test parsing response.completed event extracts all fields"""
        event = {
            "type": "response.completed",
            "response": {
                "id": "resp_123",
                "status": "completed",
                "usage": {"input_tokens": 10, "output_tokens": 50},
                "reasoning": {"effort": "medium"},
                "tool_choice": {"type": "auto"},
                "tools": [{"type": "function", "function": {"name": "test"}}],
                "model": "gpt-4o",
                "output": [{"id": "item_1", "type": "message"}],
            },
        }

        event_type = event.get("type", "")
        usage_data = None
        reasoning_data = None
        tool_choice_data = None
        tools_data = None
        model_data = None

        if event_type == "response.completed":
            response_data = event.get("response", {})
            usage_data = response_data.get("usage")
            reasoning_data = response_data.get("reasoning")
            tool_choice_data = response_data.get("tool_choice")
            tools_data = response_data.get("tools")
            model_data = response_data.get("model")

        assert usage_data == {"input_tokens": 10, "output_tokens": 50}
        assert reasoning_data == {"effort": "medium"}
        assert tool_choice_data == {"type": "auto"}
        assert tools_data == [{"type": "function", "function": {"name": "test"}}]
        assert model_data == "gpt-4o"

    def test_parse_done_marker(self):
        """Test that [DONE] marker is detected correctly"""
        chunks = [
            "data: {\"type\": \"response.in_progress\"}",
            "data: {\"type\": \"response.completed\"}",
            "data: [DONE]",
        ]

        done_received = False
        for chunk in chunks:
            if chunk.startswith("data: "):
                chunk_data = chunk[6:].strip()
                if chunk_data == "[DONE]":
                    done_received = True
                    break

        assert done_received is True

    def test_parse_sse_format(self):
        """Test parsing Server-Sent Events format"""
        raw_chunk = b"data: {\"type\": \"response.output_item.added\", \"item\": {\"id\": \"123\"}}"

        # Decode bytes to string
        if isinstance(raw_chunk, bytes):
            chunk = raw_chunk.decode('utf-8')
        else:
            chunk = raw_chunk

        # Assert the SSE prefix explicitly: if it were missing, the test must
        # fail here with a clear message instead of a NameError further down.
        assert isinstance(chunk, str) and chunk.startswith("data: ")

        # Extract JSON from SSE format
        chunk_data = chunk[6:].strip()
        event = json.loads(chunk_data)

        assert event["type"] == "response.output_item.added"
        assert event["item"]["id"] == "123"

    def test_content_part_added_event(self):
        """Test parsing response.content_part.added event"""
        output_items = {
            "item_123": {
                "id": "item_123",
                "type": "message",
            }
        }

        event = {
            "type": "response.content_part.added",
            "item_id": "item_123",
            "part": {"type": "text", "text": ""},
        }

        event_type = event.get("type", "")
        if event_type == "response.content_part.added":
            item_id = event.get("item_id")
            content_part = event.get("part", {})

            if item_id and item_id in output_items:
                if "content" not in output_items[item_id]:
                    output_items[item_id]["content"] = []
                output_items[item_id]["content"].append(content_part)

        assert "content" in output_items["item_123"]
        assert len(output_items["item_123"]["content"]) == 1
        assert output_items["item_123"]["content"][0]["type"] == "text"


class TestEdgeCases:
    """Test edge cases and error scenarios"""

    def test_empty_model_string(self):
        """Test handling of empty model string"""
        result = _should_poll(
            polling_via_cache_enabled=["openai"],
            model="",
        )

        assert result is False

    def test_model_with_multiple_slashes(self):
        """Test handling model with multiple slashes (e.g., bedrock ARN)"""
        # Only the segment before the first slash names the provider.
        result = _should_poll(
            polling_via_cache_enabled=["bedrock"],
            model="bedrock/arn:aws:bedrock:us-east-1:123456:model/my-model",
        )

        assert result is True

    def test_polling_id_detection_edge_cases(self):
        """Test polling ID detection with edge cases"""
        # Empty string
        assert ResponsePollingHandler.is_polling_id("") is False

        # Just prefix without UUID
        assert ResponsePollingHandler.is_polling_id("litellm_poll_") is True

        # Similar but different prefix
        assert ResponsePollingHandler.is_polling_id("litellm_polling_abc") is False

        # Case sensitivity
        assert ResponsePollingHandler.is_polling_id("LITELLM_POLL_abc") is False

    @pytest.mark.asyncio
    async def test_create_initial_state_with_empty_metadata(self):
        """Test create_initial_state handles missing metadata gracefully"""
        mock_redis = AsyncMock()
        handler = ResponsePollingHandler(redis_cache=mock_redis)

        response = await handler.create_initial_state(
            polling_id=_POLLING_ID,
            request_data={"model": "gpt-4o"},  # No metadata field
        )

        assert response.metadata == {}

    @pytest.mark.asyncio
    async def test_update_state_with_none_output_clears_output(self):
        """
        Test that passing output=[] explicitly clears previously stored output.

        Despite the test name, the value passed is an empty list, not None.
        """
        mock_redis = _redis_with_state(
            output=[{"id": "item_1"}],  # Has existing output
        )
        handler = ResponsePollingHandler(redis_cache=mock_redis)

        await handler.update_state(
            polling_id=_POLLING_ID,
            output=[],  # Explicitly set empty
        )

        assert _stored_state(mock_redis)["output"] == []