From 9010f276420d953245fbb829e4717583b181beee Mon Sep 17 00:00:00 2001 From: Harshit Jain Date: Sun, 25 Jan 2026 07:59:24 +0530 Subject: [PATCH] fix: optimize logo fetching and resolve mcp import blockers --- litellm/experimental_mcp_client/client.py | 6 +- .../mcp_server/mcp_server_manager.py | 31 +++-- .../proxy/_experimental/mcp_server/server.py | 6 +- .../mcp_management_endpoints.py | 18 ++- litellm/proxy/proxy_server.py | 111 ++++++++++-------- tests/proxy_unit_tests/test_get_image.py | 89 ++++++++++++++ 6 files changed, 202 insertions(+), 59 deletions(-) create mode 100644 tests/proxy_unit_tests/test_get_image.py diff --git a/litellm/experimental_mcp_client/client.py b/litellm/experimental_mcp_client/client.py index 5ad2dd5485..6e377f02ab 100644 --- a/litellm/experimental_mcp_client/client.py +++ b/litellm/experimental_mcp_client/client.py @@ -10,7 +10,11 @@ from mcp import ClientSession, ReadResourceResult, Resource, StdioServerParameters from mcp.client.sse import sse_client from mcp.client.stdio import stdio_client -from mcp.client.streamable_http import streamable_http_client + +try: + from mcp.client.streamable_http import streamable_http_client # type: ignore +except ImportError: + streamable_http_client = None from mcp.types import CallToolRequestParams as MCPCallToolRequestParams from mcp.types import CallToolResult as MCPCallToolResult from mcp.types import ( diff --git a/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py b/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py index e0217cd9e0..7fceec005e 100644 --- a/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py +++ b/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py @@ -63,7 +63,20 @@ MCPOAuthMetadata, MCPServer, ) -from mcp.shared.tool_name_validation import SEP_986_URL, validate_tool_name + +try: + from mcp.shared.tool_name_validation import SEP_986_URL, validate_tool_name # type: ignore +except ImportError: + SEP_986_URL = "https://github.com/modelcontextprotocol/protocol/blob/main/proposals/0001-tool-name-validation.md" + + def validate_tool_name(name: str): + from pydantic import BaseModel + + class MockResult(BaseModel): + is_valid: bool = True + warnings: list = [] + + return MockResult() # Probe includes characters on both sides of the separator to mimic real prefixed tool names. @@ -90,7 +103,9 @@ def _warn(field_name: str, value: Optional[str]) -> None: if result.is_valid: return - warning_text = "; ".join(result.warnings) if result.warnings else "Validation failed" + warning_text = ( + "; ".join(result.warnings) if result.warnings else "Validation failed" + ) verbose_logger.warning( "MCP server '%s' has invalid %s '%s': %s", server_id, @@ -103,7 +118,6 @@ def _warn(field_name: str, value: Optional[str]) -> None: _warn("server_name", server_name) - def _deserialize_json_dict(data: Any) -> Optional[Dict[str, str]]: """ Deserialize optional JSON mappings stored in the database. @@ -391,10 +405,13 @@ def _register_openapi_tools(self, spec_path: str, server: MCPServer, base_url: s # Note: `extra_headers` on MCPServer is a List[str] of header names to forward # from the client request (not available in this OpenAPI tool generation step). # `static_headers` is a dict of concrete headers to always send. - headers = merge_mcp_headers( - extra_headers=headers, - static_headers=server.static_headers, - ) or {} + headers = ( + merge_mcp_headers( + extra_headers=headers, + static_headers=server.static_headers, + ) + or {} + ) verbose_logger.debug( f"Using headers for OpenAPI tools (excluding sensitive values): " diff --git a/litellm/proxy/_experimental/mcp_server/server.py b/litellm/proxy/_experimental/mcp_server/server.py index 03652ae155..bd7870a0fa 100644 --- a/litellm/proxy/_experimental/mcp_server/server.py +++ b/litellm/proxy/_experimental/mcp_server/server.py @@ -74,7 +74,11 @@ AuthContextMiddleware, auth_context_var, ) - from mcp.server.streamable_http_manager import StreamableHTTPSessionManager + + try: + from mcp.server.streamable_http_manager import StreamableHTTPSessionManager + except ImportError: + StreamableHTTPSessionManager = None # type: ignore from mcp.types import ( CallToolResult, EmbeddedResource, diff --git a/litellm/proxy/management_endpoints/mcp_management_endpoints.py b/litellm/proxy/management_endpoints/mcp_management_endpoints.py index a15f47d13b..83d7f3fde4 100644 --- a/litellm/proxy/management_endpoints/mcp_management_endpoints.py +++ b/litellm/proxy/management_endpoints/mcp_management_endpoints.py @@ -56,7 +56,19 @@ MCP_AVAILABLE = False if MCP_AVAILABLE: - from mcp.shared.tool_name_validation import validate_tool_name + try: + from mcp.shared.tool_name_validation import validate_tool_name # type: ignore + except ImportError: + + def validate_tool_name(name: str): + from pydantic import BaseModel + + class MockResult(BaseModel): + is_valid: bool = True + warnings: list = [] + + return MockResult() + from litellm.proxy._experimental.mcp_server.db import ( create_mcp_server, delete_mcp_server, @@ -122,9 +134,7 @@ def _validate_mcp_server_name_fields(payload: Any) -> None: ) if validation_result.warnings: error_messages_text = ( - error_messages_text - + "\n" - + "\n".join(validation_result.warnings) + error_messages_text + "\n" + "\n".join(validation_result.warnings) ) raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index f74a89054c..18dcff31a6 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -550,9 +550,9 @@ def generate_feedback_box(): server_root_path = get_server_root_path() _license_check = LicenseCheck() premium_user: bool = _license_check.is_premium() -premium_user_data: Optional["EnterpriseLicenseData"] = ( - _license_check.airgapped_license_data -) +premium_user_data: Optional[ + "EnterpriseLicenseData" +] = _license_check.airgapped_license_data global_max_parallel_request_retries_env: Optional[str] = os.getenv( "LITELLM_GLOBAL_MAX_PARALLEL_REQUEST_RETRIES" ) @@ -1209,9 +1209,9 @@ async def root_redirect(): config_agents: Optional[List[AgentConfig]] = None otel_logging = False prisma_client: Optional[PrismaClient] = None -shared_aiohttp_session: Optional["ClientSession"] = ( - None # Global shared session for connection reuse -) +shared_aiohttp_session: Optional[ + "ClientSession" +] = None # Global shared session for connection reuse user_api_key_cache = DualCache( default_in_memory_ttl=UserAPIKeyCacheTTLEnum.in_memory_cache_ttl.value ) @@ -1219,9 +1219,9 @@ async def root_redirect(): dual_cache=user_api_key_cache ) litellm.logging_callback_manager.add_litellm_callback(model_max_budget_limiter) -redis_usage_cache: Optional[RedisCache] = ( - None # redis cache used for tracking spend, tpm/rpm limits -) +redis_usage_cache: Optional[ + RedisCache +] = None # redis cache used for tracking spend, tpm/rpm limits polling_via_cache_enabled: Union[Literal["all"], List[str], bool] = False polling_cache_ttl: int = 3600 # Default 1 hour TTL for polling cache user_custom_auth = None @@ -1560,9 +1560,9 @@ async def _update_team_cache(): _id = "team_id:{}".format(team_id) try: # Fetch the existing cost for the given user - existing_spend_obj: Optional[LiteLLM_TeamTable] = ( - await user_api_key_cache.async_get_cache(key=_id) - ) + existing_spend_obj: Optional[ + LiteLLM_TeamTable + ] = await user_api_key_cache.async_get_cache(key=_id) if existing_spend_obj is None: # do nothing if team not in api key cache return @@ -2856,6 +2856,7 @@ async def _init_policy_engine( from litellm.proxy.policy_engine.init_policies import init_policies from litellm.proxy.policy_engine.policy_validator import PolicyValidator + if config is None: verbose_proxy_logger.debug("Policy engine: config is None, skipping") return @@ -2867,7 +2868,9 @@ async def _init_policy_engine( policy_attachments_config = config.get("policy_attachments", None) - verbose_proxy_logger.info(f"Policy engine: found {len(policies_config)} policies in config") + verbose_proxy_logger.info( + f"Policy engine: found {len(policies_config)} policies in config" + ) # Initialize policies await init_policies( @@ -4009,10 +4012,10 @@ async def _init_guardrails_in_db(self, prisma_client: PrismaClient): ) try: - guardrails_in_db: List[Guardrail] = ( - await GuardrailRegistry.get_all_guardrails_from_db( - prisma_client=prisma_client - ) + guardrails_in_db: List[ + Guardrail + ] = await GuardrailRegistry.get_all_guardrails_from_db( + prisma_client=prisma_client ) verbose_proxy_logger.debug( "guardrails from the DB %s", str(guardrails_in_db) @@ -4046,7 +4049,9 @@ async def _init_policies_in_db(self, prisma_client: PrismaClient): await policy_registry.sync_policies_from_db(prisma_client=prisma_client) # Sync attachments from DB to in-memory registry - await attachment_registry.sync_attachments_from_db(prisma_client=prisma_client) + await attachment_registry.sync_attachments_from_db( + prisma_client=prisma_client + ) verbose_proxy_logger.debug( "Successfully synced policies and attachments from DB" @@ -4369,9 +4374,9 @@ async def initialize( # noqa: PLR0915 user_api_base = api_base dynamic_config[user_model]["api_base"] = api_base if api_version: - os.environ["AZURE_API_VERSION"] = ( - api_version # set this for azure - litellm can read this from the env - ) + os.environ[ + "AZURE_API_VERSION" + ] = api_version # set this for azure - litellm can read this from the env if max_tokens: # model-specific param dynamic_config[user_model]["max_tokens"] = max_tokens if temperature: # model-specific param @@ -5217,7 +5222,9 @@ async def model_list( # Include model access groups if requested if include_model_access_groups: - proxy_model_list = list(set(proxy_model_list + list(model_access_groups.keys()))) + proxy_model_list = list( + set(proxy_model_list + list(model_access_groups.keys())) + ) # Get complete model list including wildcard routes if requested from litellm.proxy.auth.model_checks import get_complete_model_list @@ -7674,12 +7681,12 @@ def _enrich_model_info_with_litellm_data( """ Enrich a model dictionary with litellm model info (pricing, context window, etc.) and remove sensitive information. - + Args: model: Model dictionary to enrich debug: Whether to include debug information like openai_client llm_router: Optional router instance for debug info - + Returns: Enriched model dictionary with sensitive info removed """ @@ -7689,9 +7696,7 @@ def _enrich_model_info_with_litellm_data( _openai_client = "None" if llm_router is not None: _openai_client = ( - llm_router._get_client( - deployment=model, kwargs={}, client_type="async" - ) + llm_router._get_client(deployment=model, kwargs={}, client_type="async") or "None" ) else: @@ -8140,7 +8145,9 @@ async def model_info_v2( # This must happen before teamId filtering so that direct_access and access_via_team_ids are populated for i, _model in enumerate(all_models): all_models[i] = _enrich_model_info_with_litellm_data( - model=_model, debug=debug if debug is not None else False, llm_router=llm_router + model=_model, + debug=debug if debug is not None else False, + llm_router=llm_router, ) # Apply teamId filter if provided @@ -9623,7 +9630,7 @@ def get_logo_url(): @app.get("/get_image", include_in_schema=False) -def get_image(): +async def get_image(): """Get logo to show on admin UI""" # get current_dir @@ -9642,25 +9649,37 @@ def get_image(): if is_non_root and not os.path.exists(default_logo): default_logo = default_site_logo + cache_dir = assets_dir if is_non_root else current_dir + cache_path = os.path.join(cache_dir, "cached_logo.jpg") + + # [OPTIMIZATION] Check if the cached image exists first + if os.path.exists(cache_path): + return FileResponse(cache_path, media_type="image/jpeg") + logo_path = os.getenv("UI_LOGO_PATH", default_logo) verbose_proxy_logger.debug("Reading logo from path: %s", logo_path) # Check if the logo path is an HTTP/HTTPS URL if logo_path.startswith(("http://", "https://")): - # Download the image and cache it - client = HTTPHandler() - response = client.get(logo_path) - if response.status_code == 200: - # Save the image to a local file - cache_dir = assets_dir if is_non_root else current_dir - cache_path = os.path.join(cache_dir, "cached_logo.jpg") - with open(cache_path, "wb") as f: - f.write(response.content) - - # Return the cached image as a FileResponse - return FileResponse(cache_path, media_type="image/jpeg") - else: - # Handle the case when the image cannot be downloaded + try: + # Download the image and cache it + from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler + + async_client = AsyncHTTPHandler(timeout=5.0) + response = await async_client.get(logo_path) + if response.status_code == 200: + # Save the image to a local file + with open(cache_path, "wb") as f: + f.write(response.content) + + # Return the cached image as a FileResponse + return FileResponse(cache_path, media_type="image/jpeg") + else: + # Handle the case when the image cannot be downloaded + return FileResponse(default_logo, media_type="image/jpeg") + except Exception as e: + # Handle any exceptions during the download (e.g., timeout, connection error) + verbose_proxy_logger.debug(f"Error downloading logo from {logo_path}: {e}") return FileResponse(default_logo, media_type="image/jpeg") else: # Return the local image file if the logo path is not an HTTP/HTTPS URL @@ -10278,9 +10297,9 @@ async def get_config_list( hasattr(sub_field_info, "description") and sub_field_info.description is not None ): - nested_fields[idx].field_description = ( - sub_field_info.description - ) + nested_fields[ + idx + ].field_description = sub_field_info.description idx += 1 _stored_in_db = None diff --git a/tests/proxy_unit_tests/test_get_image.py b/tests/proxy_unit_tests/test_get_image.py new file mode 100644 index 0000000000..ad8c267275 --- /dev/null +++ b/tests/proxy_unit_tests/test_get_image.py @@ -0,0 +1,89 @@ +import os +import sys +from unittest import mock + +# Standard path insertion +sys.path.insert(0, os.path.abspath("../..")) + +import pytest +import httpx +from litellm.proxy.proxy_server import app + + +@pytest.mark.asyncio +async def test_get_image_error_handling(): + """ + Test that get_image handles network errors gracefully and doesn't hang. + """ + # Set an unreachable URL + os.environ["UI_LOGO_PATH"] = "http://invalid-url-12345.com/logo.jpg" + + # Clear cache + parent_dir = os.path.dirname( + os.path.dirname( + app.__file__ + if hasattr(app, "__file__") + else "litellm/proxy/proxy_server.py" + ) + ) + cache_path = os.path.join(parent_dir, "proxy", "cached_logo.jpg") + if os.path.exists(cache_path): + os.remove(cache_path) + + # Mock AsyncHTTPHandler to simulate a timeout or connection error + with mock.patch( + "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.get" + ) as mock_get: + mock_get.side_effect = httpx.ConnectError("Network is unreachable") + + async with httpx.AsyncClient( + transport=httpx.ASGITransport(app=app), base_url="http://testserver" + ) as ac: + response = await ac.get("/get_image") + + assert response.status_code == 200 + assert response.headers["content-type"] == "image/jpeg" + + +@pytest.mark.asyncio +async def test_get_image_cache_logic(): + """ + Test that once cached, get_image doesn't hit the network. + """ + os.environ["UI_LOGO_PATH"] = "http://example.com/logo.jpg" + + # Clear cache + parent_dir = os.path.dirname( + os.path.dirname( + app.__file__ + if hasattr(app, "__file__") + else "litellm/proxy/proxy_server.py" + ) + ) + cache_path = os.path.join(parent_dir, "proxy", "cached_logo.jpg") + if os.path.exists(cache_path): + os.remove(cache_path) + + # Mock response + mock_response = mock.Mock() + mock_response.status_code = 200 + mock_response.content = b"fake image data" + + with mock.patch( + "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.get" + ) as mock_get: + mock_get.return_value = mock_response + + async with httpx.AsyncClient( + transport=httpx.ASGITransport(app=app), base_url="http://testserver" + ) as ac: + # First call - should hit download logic + response1 = await ac.get("/get_image") + assert response1.status_code == 200 + assert mock_get.call_count == 1 + + # Second call - should hit cache + response2 = await ac.get("/get_image") + assert response2.status_code == 200 + # If cache works, mock_get shouldn't be called again + assert mock_get.call_count == 1