diff --git a/CLAUDE.md b/CLAUDE.md index 23a0e97eaee..3cb67908076 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -90,6 +90,7 @@ LiteLLM is a unified interface for 100+ LLM providers with two main components: - Pydantic v2 for data validation - Async/await patterns throughout - Type hints required for all public APIs +- **Avoid imports within methods** — place all imports at the top of the file (module-level). Inline imports inside functions/methods make dependencies harder to trace and hurt readability. The only exception is avoiding circular imports where absolutely necessary. ### Testing Strategy - Unit tests in `tests/test_litellm/` diff --git a/litellm/proxy/_experimental/mcp_server/auth/litellm_auth_handler.py b/litellm/proxy/_experimental/mcp_server/auth/litellm_auth_handler.py index 081d83dd1c8..75b75d3ba44 100644 --- a/litellm/proxy/_experimental/mcp_server/auth/litellm_auth_handler.py +++ b/litellm/proxy/_experimental/mcp_server/auth/litellm_auth_handler.py @@ -27,6 +27,7 @@ def __init__( oauth2_headers: Optional[Dict[str, str]] = None, mcp_protocol_version: Optional[str] = None, raw_headers: Optional[Dict[str, str]] = None, + client_ip: Optional[str] = None, ): self.user_api_key_auth = user_api_key_auth self.mcp_auth_header = mcp_auth_header @@ -35,3 +36,4 @@ def __init__( self.mcp_protocol_version = mcp_protocol_version self.oauth2_headers = oauth2_headers self.raw_headers = raw_headers + self.client_ip = client_ip diff --git a/litellm/proxy/_experimental/mcp_server/discoverable_endpoints.py b/litellm/proxy/_experimental/mcp_server/discoverable_endpoints.py index 56feff548ad..8b052dd0da1 100644 --- a/litellm/proxy/_experimental/mcp_server/discoverable_endpoints.py +++ b/litellm/proxy/_experimental/mcp_server/discoverable_endpoints.py @@ -9,13 +9,14 @@ get_async_httpx_client, httpxSpecialProvider, ) +from litellm.proxy.auth.ip_address_utils import IPAddressUtils from litellm.proxy.common_utils.encrypt_decrypt_utils import ( decrypt_value_helper, encrypt_value_helper, ) from litellm.proxy.common_utils.http_parsing_utils import _read_request_body -from litellm.types.mcp_server.mcp_server_manager import MCPServer from litellm.proxy.utils import get_server_root_path +from litellm.types.mcp_server.mcp_server_manager import MCPServer router = APIRouter( tags=["mcp"], @@ -300,7 +301,10 @@ async def authorize( ) lookup_name = mcp_server_name or client_id - mcp_server = global_mcp_server_manager.get_mcp_server_by_name(lookup_name) + client_ip = IPAddressUtils.get_mcp_client_ip(request) + mcp_server = global_mcp_server_manager.get_mcp_server_by_name( + lookup_name, client_ip=client_ip + ) if mcp_server is None: raise HTTPException(status_code=404, detail="MCP server not found") return await authorize_with_server( @@ -342,7 +346,10 @@ async def token_endpoint( ) lookup_name = mcp_server_name or client_id - mcp_server = global_mcp_server_manager.get_mcp_server_by_name(lookup_name) + client_ip = IPAddressUtils.get_mcp_client_ip(request) + mcp_server = global_mcp_server_manager.get_mcp_server_by_name( + lookup_name, client_ip=client_ip + ) if mcp_server is None: raise HTTPException(status_code=404, detail="MCP server not found") return await exchange_token_with_server( @@ -425,7 +432,10 @@ def _build_oauth_protected_resource_response( request_base_url = get_request_base_url(request) mcp_server: Optional[MCPServer] = None if mcp_server_name: - mcp_server = global_mcp_server_manager.get_mcp_server_by_name(mcp_server_name) + client_ip = IPAddressUtils.get_mcp_client_ip(request) + mcp_server = global_mcp_server_manager.get_mcp_server_by_name( + mcp_server_name, client_ip=client_ip + ) # Build resource URL based on the pattern if mcp_server_name: @@ -538,7 +548,10 @@ def _build_oauth_authorization_server_response( mcp_server: Optional[MCPServer] = None if mcp_server_name: - mcp_server = global_mcp_server_manager.get_mcp_server_by_name(mcp_server_name) + client_ip = IPAddressUtils.get_mcp_client_ip(request) + mcp_server = global_mcp_server_manager.get_mcp_server_by_name( + mcp_server_name, client_ip=client_ip + ) return { "issuer": request_base_url, # point to your proxy @@ -629,7 +642,10 @@ async def register_client(request: Request, mcp_server_name: Optional[str] = Non if not mcp_server_name: return dummy_return - mcp_server = global_mcp_server_manager.get_mcp_server_by_name(mcp_server_name) + client_ip = IPAddressUtils.get_mcp_client_ip(request) + mcp_server = global_mcp_server_manager.get_mcp_server_by_name( + mcp_server_name, client_ip=client_ip + ) if mcp_server is None: return dummy_return return await register_client_with_server( diff --git a/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py b/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py index f5b0f152fd8..901a9b3c076 100644 --- a/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py +++ b/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py @@ -53,6 +53,7 @@ MCPTransportType, UserAPIKeyAuth, ) +from litellm.proxy.auth.ip_address_utils import IPAddressUtils from litellm.proxy.common_utils.encrypt_decrypt_utils import decrypt_value_helper from litellm.proxy.utils import ProxyLogging from litellm.types.llms.custom_http import httpxSpecialProvider @@ -66,7 +67,10 @@ try: from mcp.shared.tool_name_validation import ( - validate_tool_name, # type: ignore[reportAssignmentType] + validate_tool_name, # pyright: ignore[reportAssignmentType] + ) + from mcp.shared.tool_name_validation import ( + SEP_986_URL, ) from mcp.shared.tool_name_validation import SEP_986_URL except ImportError: @@ -74,12 +78,12 @@ SEP_986_URL = "https://github.com/modelcontextprotocol/protocol/blob/main/proposals/0001-tool-name-validation.md" - class ToolNameValidationResult(BaseModel): + class _ToolNameValidationResult(BaseModel): is_valid: bool = True warnings: list = [] - def validate_tool_name(name: str) -> ToolNameValidationResult: # type: ignore[misc] - return ToolNameValidationResult() + def validate_tool_name(name: str) -> _ToolNameValidationResult: + return _ToolNameValidationResult() # Probe includes characters on both sides of the separator to mimic real prefixed tool names. @@ -325,6 +329,9 @@ async def load_servers_from_config( access_groups=server_config.get("access_groups", None), static_headers=server_config.get("static_headers", None), allow_all_keys=bool(server_config.get("allow_all_keys", False)), + available_on_public_internet=bool( + server_config.get("available_on_public_internet", False) + ), ) self.config_mcp_servers[server_id] = new_server @@ -623,6 +630,9 @@ async def build_mcp_server_from_table( allowed_tools=getattr(mcp_server, "allowed_tools", None), disallowed_tools=getattr(mcp_server, "disallowed_tools", None), allow_all_keys=mcp_server.allow_all_keys, + available_on_public_internet=bool( + getattr(mcp_server, "available_on_public_internet", False) + ), updated_at=getattr(mcp_server, "updated_at", None), ) return new_server @@ -697,6 +707,23 @@ async def get_allowed_mcp_servers( verbose_logger.warning(f"Failed to get allowed MCP servers: {str(e)}.") return allow_all_server_ids + def filter_server_ids_by_ip( + self, server_ids: List[str], client_ip: Optional[str] + ) -> List[str]: + """ + Filter server IDs by client IP — external callers only see public servers. + + Returns server_ids unchanged when client_ip is None (no filtering). + """ + if client_ip is None: + return server_ids + return [ + sid + for sid in server_ids + if (s := self.get_mcp_server_by_id(sid)) is not None + and self._is_server_accessible_from_ip(s, client_ip) + ] + async def get_tools_for_server(self, server_id: str) -> List[MCPTool]: """ Get the tools for a given server @@ -2202,6 +2229,42 @@ def get_mcp_servers_from_ids(self, server_ids: List[str]) -> List[MCPServer]: servers.append(server) return servers + def _get_general_settings(self) -> Dict[str, Any]: + """Get general_settings, importing lazily to avoid circular imports.""" + try: + from litellm.proxy.proxy_server import ( + general_settings as proxy_general_settings, + ) + return proxy_general_settings + except ImportError: + # Fallback if proxy_server not available + return {} + + def _is_server_accessible_from_ip( + self, server: MCPServer, client_ip: Optional[str] + ) -> bool: + """ + Check if a server is accessible from the given client IP. + + - If client_ip is None, no IP filtering is applied (internal callers). + - If the server has available_on_public_internet=True, it's always accessible. + - Otherwise, only internal/private IPs can access it. + """ + if client_ip is None: + return True + if server.available_on_public_internet: + return True + # Check backwards compat: litellm.public_mcp_servers + public_ids = set(litellm.public_mcp_servers or []) + if server.server_id in public_ids: + return True + # Non-public server: only accessible from internal IPs + general_settings = self._get_general_settings() + internal_networks = IPAddressUtils.parse_internal_networks( + general_settings.get("mcp_internal_ip_ranges") + ) + return IPAddressUtils.is_internal_ip(client_ip, internal_networks) + def get_mcp_server_by_id(self, server_id: str) -> Optional[MCPServer]: """ Get the MCP Server from the server id @@ -2214,27 +2277,76 @@ def get_mcp_server_by_id(self, server_id: str) -> Optional[MCPServer]: def get_public_mcp_servers(self) -> List[MCPServer]: """ - Get the public MCP servers + Get the public MCP servers (available_on_public_internet=True flag on server). + Also includes servers from litellm.public_mcp_servers for backwards compat. """ servers: List[MCPServer] = [] - if litellm.public_mcp_servers is None: - return servers - for server_id in litellm.public_mcp_servers: - server = self.get_mcp_server_by_id(server_id) - if server: + public_ids = set(litellm.public_mcp_servers or []) + for server in self.get_registry().values(): + if server.available_on_public_internet or server.server_id in public_ids: servers.append(server) return servers - def get_mcp_server_by_name(self, server_name: str) -> Optional[MCPServer]: + def get_mcp_server_by_name( + self, server_name: str, client_ip: Optional[str] = None + ) -> Optional[MCPServer]: """ - Get the MCP Server from the server name + Get the MCP Server from the server name. + + Uses priority-based matching to avoid collisions: + 1. First pass: exact alias match (highest priority) + 2. Second pass: exact server_name match + 3. Third pass: exact name match (lowest priority) + + Args: + server_name: The server name to look up. + client_ip: Optional client IP for access control. When provided, + non-public servers are hidden from external IPs. """ registry = self.get_registry() + + # Pass 1: Match by alias (highest priority) + for server in registry.values(): + if server.alias == server_name: + if not self._is_server_accessible_from_ip(server, client_ip): + return None + return server + + # Pass 2: Match by server_name for server in registry.values(): if server.server_name == server_name: + if not self._is_server_accessible_from_ip(server, client_ip): + return None + return server + + # Pass 3: Match by name (lowest priority) + for server in registry.values(): + if server.name == server_name: + if not self._is_server_accessible_from_ip(server, client_ip): + return None return server + return None + def get_filtered_registry( + self, client_ip: Optional[str] = None + ) -> Dict[str, MCPServer]: + """ + Get registry filtered by client IP access control. + + Args: + client_ip: Optional client IP. When provided, non-public servers + are hidden from external IPs. When None, returns all servers. + """ + registry = self.get_registry() + if client_ip is None: + return registry + return { + k: v + for k, v in registry.items() + if self._is_server_accessible_from_ip(v, client_ip) + } + def _generate_stable_server_id( self, server_name: str, diff --git a/litellm/proxy/_experimental/mcp_server/rest_endpoints.py b/litellm/proxy/_experimental/mcp_server/rest_endpoints.py index eb47fb4f608..2cd56e0ff3f 100644 --- a/litellm/proxy/_experimental/mcp_server/rest_endpoints.py +++ b/litellm/proxy/_experimental/mcp_server/rest_endpoints.py @@ -10,6 +10,7 @@ ) from litellm.proxy._experimental.mcp_server.utils import merge_mcp_headers from litellm.proxy._types import UserAPIKeyAuth +from litellm.proxy.auth.ip_address_utils import IPAddressUtils from litellm.proxy.auth.user_api_key_auth import user_api_key_auth from litellm.types.mcp import MCPAuth from litellm.types.utils import CallTypes @@ -78,6 +79,26 @@ def _create_tool_response_objects(tools, server_mcp_info): for tool in tools ] + def _extract_mcp_headers_from_request( + request: Request, + mcp_request_handler_cls, + ) -> tuple: + """ + Extract MCP auth headers from HTTP request. + + Returns: + Tuple of (mcp_auth_header, mcp_server_auth_headers, raw_headers) + """ + headers = request.headers + raw_headers = dict(headers) + mcp_auth_header = mcp_request_handler_cls._get_mcp_auth_header_from_headers( + headers + ) + mcp_server_auth_headers = ( + mcp_request_handler_cls._get_mcp_server_auth_headers_from_headers(headers) + ) + return mcp_auth_header, mcp_server_auth_headers, raw_headers + async def _get_tools_for_single_server( server, server_auth_header, @@ -173,21 +194,25 @@ async def list_tool_rest_api( auth_contexts = await build_effective_auth_contexts(user_api_key_dict) + _rest_client_ip = IPAddressUtils.get_mcp_client_ip(request) + allowed_server_ids_set = set() for auth_context in auth_contexts: servers = await global_mcp_server_manager.get_allowed_mcp_servers( - user_api_key_auth=auth_context + user_api_key_auth=auth_context, ) allowed_server_ids_set.update(servers) - allowed_server_ids = list(allowed_server_ids_set) + allowed_server_ids = global_mcp_server_manager.filter_server_ids_by_ip( + list(allowed_server_ids_set), _rest_client_ip + ) list_tools_result = [] error_message = None # If server_id is specified, only query that specific server if server_id: - if server_id not in allowed_server_ids_set: + if server_id not in allowed_server_ids: raise HTTPException( status_code=403, detail={ @@ -195,7 +220,9 @@ async def list_tool_rest_api( "message": f"The key is not allowed to access server {server_id}", }, ) - server = global_mcp_server_manager.get_mcp_server_by_id(server_id) + server = global_mcp_server_manager.get_mcp_server_by_id( + server_id + ) if server is None: return { "tools": [], @@ -339,21 +366,10 @@ async def call_tool_rest_api( ) ) - # FIX: Extract MCP auth headers from request - # The UI sends bearer token in x-mcp-auth header and server-specific headers, - # but they weren't being extracted and passed to call_mcp_tool. - # This fix ensures auth headers are properly extracted from the HTTP request - # and passed through to the MCP server for authentication. - headers = request.headers - raw_headers_from_request = dict(headers) - mcp_auth_header = MCPRequestHandler._get_mcp_auth_header_from_headers( - headers - ) - mcp_server_auth_headers = ( - MCPRequestHandler._get_mcp_server_auth_headers_from_headers(headers) + # Extract MCP auth headers from request and add to data dict + mcp_auth_header, mcp_server_auth_headers, raw_headers_from_request = ( + _extract_mcp_headers_from_request(request, MCPRequestHandler) ) - - # Add extracted headers to data dict to pass to call_mcp_tool if mcp_auth_header: data["mcp_auth_header"] = mcp_auth_header if mcp_server_auth_headers: @@ -365,10 +381,43 @@ async def call_tool_rest_api( if "metadata" in data and "user_api_key_auth" in data["metadata"]: data["user_api_key_auth"] = data["metadata"]["user_api_key_auth"] - allowed_mcp_servers = await _resolve_allowed_mcp_servers_for_tool_call( - user_api_key_dict, server_id + # Get all auth contexts + auth_contexts = await build_effective_auth_contexts(user_api_key_dict) + + # Collect allowed server IDs from all contexts, then apply IP filtering + _rest_client_ip = IPAddressUtils.get_mcp_client_ip(request) + allowed_server_ids_set = set() + for auth_context in auth_contexts: + servers = await global_mcp_server_manager.get_allowed_mcp_servers( + user_api_key_auth=auth_context, + ) + allowed_server_ids_set.update(servers) + + allowed_server_ids_set = set( + global_mcp_server_manager.filter_server_ids_by_ip( + list(allowed_server_ids_set), _rest_client_ip + ) ) + # Check if the specified server_id is allowed + if server_id not in allowed_server_ids_set: + raise HTTPException( + status_code=403, + detail={ + "error": "access_denied", + "message": f"The key is not allowed to access server {server_id}", + }, + ) + + # Build allowed_mcp_servers list (only include allowed servers) + allowed_mcp_servers: List[MCPServer] = [] + for allowed_server_id in allowed_server_ids_set: + server = global_mcp_server_manager.get_mcp_server_by_id( + allowed_server_id + ) + if server is not None: + allowed_mcp_servers.append(server) + # Call execute_mcp_tool directly (permission checks already done) result = await execute_mcp_tool( name=tool_name, diff --git a/litellm/proxy/_experimental/mcp_server/server.py b/litellm/proxy/_experimental/mcp_server/server.py index 79cd88227a9..e357311680c 100644 --- a/litellm/proxy/_experimental/mcp_server/server.py +++ b/litellm/proxy/_experimental/mcp_server/server.py @@ -5,12 +5,24 @@ import asyncio import contextlib -from datetime import datetime import traceback import uuid -from typing import Any, AsyncIterator, Dict, List, Optional, Tuple, Union, cast, Callable +from datetime import datetime +from typing import ( + Any, + AsyncIterator, + Callable, + Dict, + List, + Optional, + Tuple, + Union, + cast, +) + from fastapi import FastAPI, HTTPException from pydantic import AnyUrl, ConfigDict +from starlette.requests import Request as StarletteRequest from starlette.types import Receive, Scope, Send from litellm._logging import verbose_logger @@ -25,6 +37,7 @@ LITELLM_MCP_SERVER_VERSION, ) from litellm.proxy._types import UserAPIKeyAuth +from litellm.proxy.auth.ip_address_utils import IPAddressUtils from litellm.types.mcp import MCPAuth from litellm.types.mcp_server.mcp_server_manager import MCPInfo, MCPServer from litellm.types.utils import CallTypes, StandardLoggingMCPToolCall @@ -213,6 +226,7 @@ async def list_tools() -> List[MCPTool]: mcp_server_auth_headers, oauth2_headers, raw_headers, + _client_ip, ) = get_auth_context() verbose_logger.debug( f"MCP list_tools - User API Key Auth from context: {user_api_key_auth}" @@ -276,6 +290,7 @@ async def mcp_server_tool_call( mcp_server_auth_headers, oauth2_headers, raw_headers, + _client_ip, ) = get_auth_context() verbose_logger.debug( @@ -387,6 +402,7 @@ async def list_prompts() -> List[Prompt]: mcp_server_auth_headers, oauth2_headers, raw_headers, + _client_ip, ) = get_auth_context() verbose_logger.debug( f"MCP list_prompts - User API Key Auth from context: {user_api_key_auth}" @@ -440,6 +456,7 @@ async def get_prompt( mcp_server_auth_headers, oauth2_headers, raw_headers, + _client_ip, ) = get_auth_context() verbose_logger.debug( @@ -467,6 +484,7 @@ async def list_resources() -> List[Resource]: mcp_server_auth_headers, oauth2_headers, raw_headers, + _client_ip, ) = get_auth_context() verbose_logger.debug( f"MCP list_resources - User API Key Auth from context: {user_api_key_auth}" @@ -505,6 +523,7 @@ async def list_resource_templates() -> List[ResourceTemplate]: mcp_server_auth_headers, oauth2_headers, raw_headers, + _client_ip, ) = get_auth_context() verbose_logger.debug( f"MCP list_resource_templates - User API Key Auth from context: {user_api_key_auth}" @@ -544,6 +563,7 @@ async def read_resource(url: AnyUrl) -> list[ReadResourceContents]: mcp_server_auth_headers, oauth2_headers, raw_headers, + _client_ip, ) = get_auth_context() read_resource_result = await mcp_read_resource( @@ -702,13 +722,59 @@ def filter_tools_by_allowed_tools( return tools_to_return + def _get_client_ip_from_context() -> Optional[str]: + """ + Extract client_ip from auth context. + + Returns None if context not set (caller should handle this as "no IP filtering"). + """ + try: + auth_user = auth_context_var.get() + if auth_user and isinstance(auth_user, MCPAuthenticatedUser): + return auth_user.client_ip + except Exception: + pass + return None + async def _get_allowed_mcp_servers( user_api_key_auth: Optional[UserAPIKeyAuth], mcp_servers: Optional[List[str]], + client_ip: Optional[str] = None, ) -> List[MCPServer]: - """Return allowed MCP servers for a request after applying filters.""" + """Return allowed MCP servers for a request after applying filters. + + Args: + user_api_key_auth: The authenticated user's API key info. + mcp_servers: Optional list of server names to filter to. + client_ip: Client IP for IP-based access control. If None, falls back to + auth context. Pass explicitly from request handlers for safety. + + Note: If client_ip is None and auth context is not set, IP filtering is skipped. + This is intentional for internal callers but may indicate a bug if called + from a request handler without proper context setup. + """ + # Use explicit client_ip if provided, otherwise try auth context + if client_ip is None: + client_ip = _get_client_ip_from_context() + if client_ip is None: + verbose_logger.debug( + "MCP _get_allowed_mcp_servers called without client_ip and no auth context. " + "IP filtering will be skipped. This is expected for internal calls." + ) + + allowed_mcp_server_ids = ( + await global_mcp_server_manager.get_allowed_mcp_servers( + user_api_key_auth + ) + ) allowed_mcp_server_ids = ( - await global_mcp_server_manager.get_allowed_mcp_servers(user_api_key_auth) + global_mcp_server_manager.filter_server_ids_by_ip( + allowed_mcp_server_ids, client_ip + ) + ) + verbose_logger.debug( + "MCP IP filter: client_ip=%s, allowed_server_ids=%s", + client_ip, allowed_mcp_server_ids, ) allowed_mcp_servers: List[MCPServer] = [] for allowed_mcp_server_id in allowed_mcp_server_ids: @@ -1891,6 +1957,10 @@ async def handle_streamable_http_mcp( oauth2_headers, raw_headers, ) = await extract_mcp_auth_context(scope, path) + + # Extract client IP for MCP access control + _client_ip = IPAddressUtils.get_mcp_client_ip(StarletteRequest(scope)) + verbose_logger.debug( f"MCP request mcp_servers (header/path): {mcp_servers}" ) @@ -1899,11 +1969,11 @@ async def handle_streamable_http_mcp( ) # https://datatracker.ietf.org/doc/html/rfc9728#name-www-authenticate-response for server_name in mcp_servers or []: - server = global_mcp_server_manager.get_mcp_server_by_name(server_name) + server = global_mcp_server_manager.get_mcp_server_by_name( + server_name, client_ip=_client_ip + ) if server and server.auth_type == MCPAuth.oauth2 and not oauth2_headers: - from starlette.requests import Request - - request = Request(scope) + request = StarletteRequest(scope) base_url = str(request.base_url).rstrip("/") authorization_uri = ( @@ -1925,6 +1995,7 @@ async def handle_streamable_http_mcp( mcp_server_auth_headers=mcp_server_auth_headers, oauth2_headers=oauth2_headers, raw_headers=raw_headers, + client_ip=_client_ip, ) # Ensure session managers are initialized @@ -1936,12 +2007,13 @@ async def handle_streamable_http_mcp( _strip_stale_mcp_session_header(scope, session_manager) await session_manager.handle_request(scope, receive, send) + except HTTPException: + # Re-raise HTTP exceptions to preserve status codes and details + raise except Exception as e: - raise e verbose_logger.exception(f"Error handling MCP request: {e}") - # Instead of re-raising, try to send a graceful error response + # Try to send a graceful error response for non-HTTP exceptions try: - # Send a proper HTTP error response instead of letting the exception bubble up from starlette.responses import JSONResponse from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR @@ -1969,6 +2041,10 @@ async def handle_sse_mcp(scope: Scope, receive: Receive, send: Send) -> None: oauth2_headers, raw_headers, ) = await extract_mcp_auth_context(scope, path) + + # Extract client IP for MCP access control + _sse_client_ip = IPAddressUtils.get_mcp_client_ip(StarletteRequest(scope)) + verbose_logger.debug( f"MCP request mcp_servers (header/path): {mcp_servers}" ) @@ -1982,6 +2058,7 @@ async def handle_sse_mcp(scope: Scope, receive: Receive, send: Send) -> None: mcp_server_auth_headers=mcp_server_auth_headers, oauth2_headers=oauth2_headers, raw_headers=raw_headers, + client_ip=_sse_client_ip, ) if not _SESSION_MANAGERS_INITIALIZED: @@ -2045,6 +2122,7 @@ def set_auth_context( mcp_server_auth_headers: Optional[Dict[str, Dict[str, str]]] = None, oauth2_headers: Optional[Dict[str, str]] = None, raw_headers: Optional[Dict[str, str]] = None, + client_ip: Optional[str] = None, ) -> None: """ Set the UserAPIKeyAuth in the auth context variable. @@ -2054,6 +2132,7 @@ def set_auth_context( mcp_auth_header: MCP auth header to be passed to the MCP server (deprecated) mcp_servers: Optional list of server names and access groups to filter by mcp_server_auth_headers: Optional dict of server-specific auth headers {server_alias: auth_value} + client_ip: Client IP address for MCP access control """ auth_user = MCPAuthenticatedUser( user_api_key_auth=user_api_key_auth, @@ -2062,6 +2141,7 @@ def set_auth_context( mcp_server_auth_headers=mcp_server_auth_headers, oauth2_headers=oauth2_headers, raw_headers=raw_headers, + client_ip=client_ip, ) auth_context_var.set(auth_user) @@ -2073,14 +2153,15 @@ def get_auth_context() -> ( Optional[Dict[str, Dict[str, str]]], Optional[Dict[str, str]], Optional[Dict[str, str]], + Optional[str], ] ): """ Get the UserAPIKeyAuth from the auth context variable. Returns: - Tuple[Optional[UserAPIKeyAuth], Optional[str], Optional[List[str]], Optional[Dict[str, str]]]: - UserAPIKeyAuth object, MCP auth header (deprecated), MCP servers (can include access groups), and server-specific auth headers + Tuple containing: UserAPIKeyAuth, MCP auth header (deprecated), + MCP servers, server-specific auth headers, OAuth2 headers, raw headers, client IP """ auth_user = auth_context_var.get() if auth_user and isinstance(auth_user, MCPAuthenticatedUser): @@ -2091,8 +2172,9 @@ def get_auth_context() -> ( auth_user.mcp_server_auth_headers, auth_user.oauth2_headers, auth_user.raw_headers, + auth_user.client_ip, ) - return None, None, None, None, None, None + return None, None, None, None, None, None, None ######################################################## ############ End of Auth Context Functions ############# diff --git a/litellm/proxy/auth/ip_address_utils.py b/litellm/proxy/auth/ip_address_utils.py new file mode 100644 index 00000000000..c8e936598e5 --- /dev/null +++ b/litellm/proxy/auth/ip_address_utils.py @@ -0,0 +1,146 @@ +""" +IP address utilities for MCP public/private access control. + +Internal callers (private IPs) see all MCP servers. +External callers (public IPs) only see servers with available_on_public_internet=True. +""" + +import ipaddress +from typing import Any, Dict, List, Optional, Union + +from fastapi import Request + +from litellm._logging import verbose_proxy_logger +from litellm.proxy.auth.auth_utils import _get_request_ip_address + + +class IPAddressUtils: + """Static utilities for IP-based MCP access control.""" + + _DEFAULT_INTERNAL_NETWORKS = [ + ipaddress.ip_network("10.0.0.0/8"), + ipaddress.ip_network("172.16.0.0/12"), + ipaddress.ip_network("192.168.0.0/16"), + ipaddress.ip_network("127.0.0.0/8"), + ipaddress.ip_network("::1/128"), + ipaddress.ip_network("fc00::/7"), + ] + + @staticmethod + def parse_internal_networks( + configured_ranges: Optional[List[str]], + ) -> List[Union[ipaddress.IPv4Network, ipaddress.IPv6Network]]: + """Parse configured CIDR ranges into network objects, falling back to defaults.""" + if not configured_ranges: + return IPAddressUtils._DEFAULT_INTERNAL_NETWORKS + networks: List[Union[ipaddress.IPv4Network, ipaddress.IPv6Network]] = [] + for cidr in configured_ranges: + try: + networks.append(ipaddress.ip_network(cidr, strict=False)) + except ValueError: + verbose_proxy_logger.warning( + "Invalid CIDR in mcp_internal_ip_ranges: %s, skipping", cidr + ) + return networks if networks else IPAddressUtils._DEFAULT_INTERNAL_NETWORKS + + @staticmethod + def parse_trusted_proxy_networks( + configured_ranges: Optional[List[str]], + ) -> List[Union[ipaddress.IPv4Network, ipaddress.IPv6Network]]: + """ + Parse trusted proxy CIDR ranges for XFF validation. + + Returns empty list if not configured (XFF will not be trusted). + """ + if not configured_ranges: + return [] + networks: List[Union[ipaddress.IPv4Network, ipaddress.IPv6Network]] = [] + for cidr in configured_ranges: + try: + networks.append(ipaddress.ip_network(cidr, strict=False)) + except ValueError: + verbose_proxy_logger.warning( + "Invalid CIDR in mcp_trusted_proxy_ranges: %s, skipping", cidr + ) + return networks + + @staticmethod + def is_trusted_proxy( + proxy_ip: Optional[str], + trusted_networks: List[Union[ipaddress.IPv4Network, ipaddress.IPv6Network]], + ) -> bool: + """Check if the direct connection IP is from a trusted proxy.""" + if not proxy_ip or not trusted_networks: + return False + try: + addr = ipaddress.ip_address(proxy_ip.strip()) + return any(addr in network for network in trusted_networks) + except ValueError: + return False + + @staticmethod + def is_internal_ip( + client_ip: Optional[str], + internal_networks: Optional[ + List[Union[ipaddress.IPv4Network, ipaddress.IPv6Network]] + ] = None, + ) -> bool: + """ + Check if a client IP is from an internal/private network. + + Handles X-Forwarded-For comma chains (takes leftmost = original client). + Fails closed: empty/invalid IPs are treated as external. + """ + if not client_ip: + return False + + # X-Forwarded-For may contain comma-separated chain; leftmost is original client + if "," in client_ip: + client_ip = client_ip.split(",")[0].strip() + + networks = internal_networks or IPAddressUtils._DEFAULT_INTERNAL_NETWORKS + + try: + addr = ipaddress.ip_address(client_ip.strip()) + except ValueError: + return False + + return any(addr in network for network in networks) + + @staticmethod + def get_mcp_client_ip( + request: Request, + general_settings: Optional[Dict[str, Any]] = None, + ) -> Optional[str]: + """ + Extract client IP from a FastAPI request for MCP access control. + + Security: Only trusts X-Forwarded-For if: + 1. use_x_forwarded_for is enabled in settings + 2. The direct connection is from a trusted proxy (if mcp_trusted_proxy_ranges configured) + + Args: + request: FastAPI request object + general_settings: Optional settings dict. If not provided, uses cached reference. + """ + from litellm.proxy.proxy_server import general_settings + + use_xff = general_settings.get("use_x_forwarded_for", False) + + # If XFF is enabled, validate the request comes from a trusted proxy + if use_xff and "x-forwarded-for" in request.headers: + trusted_ranges = general_settings.get("mcp_trusted_proxy_ranges") + if trusted_ranges: + # Validate direct connection is from trusted proxy + direct_ip = request.client.host if request.client else None + trusted_networks = IPAddressUtils.parse_trusted_proxy_networks( + trusted_ranges + ) + if not IPAddressUtils.is_trusted_proxy(direct_ip, trusted_networks): + # Untrusted source trying to set XFF - ignore XFF, use direct IP + verbose_proxy_logger.warning( + "XFF header from untrusted IP %s, ignoring", direct_ip + ) + return direct_ip + + return _get_request_ip_address(request, use_x_forwarded_for=use_xff) diff --git a/litellm/proxy/management_endpoints/mcp_management_endpoints.py b/litellm/proxy/management_endpoints/mcp_management_endpoints.py index 40f4745615e..5e59668fc0c 100644 --- a/litellm/proxy/management_endpoints/mcp_management_endpoints.py +++ b/litellm/proxy/management_endpoints/mcp_management_endpoints.py @@ -59,18 +59,18 @@ if MCP_AVAILABLE: try: - from mcp.shared.tool_name_validation import ( # type: ignore - validate_tool_name, + from mcp.shared.tool_name_validation import ( + validate_tool_name, # pyright: ignore[reportAssignmentType] ) except ImportError: from pydantic import BaseModel - class ToolNameValidationResult(BaseModel): + class _ToolNameValidationResult(BaseModel): is_valid: bool = True warnings: list = [] - def validate_tool_name(name: str) -> ToolNameValidationResult: # type: ignore[misc] - return ToolNameValidationResult() + def validate_tool_name(name: str) -> _ToolNameValidationResult: + return _ToolNameValidationResult() from litellm.proxy._experimental.mcp_server.db import ( create_mcp_server, @@ -433,11 +433,21 @@ async def get_mcp_registry(request: Request): detail="MCP registry is not enabled", ) + from litellm.proxy.auth.ip_address_utils import IPAddressUtils + + client_ip = IPAddressUtils.get_mcp_client_ip(request) + + verbose_proxy_logger.debug("MCP registry request from IP=%s", client_ip) + base_url = get_request_base_url(request) registry_servers: List[Dict[str, Any]] = [] registry_servers.append({"server": _build_builtin_registry_entry(base_url)}) - registered_servers = list(global_mcp_server_manager.get_registry().values()) + # Centralized IP-based filtering: external callers only see public servers + registered_servers = list( + global_mcp_server_manager.get_filtered_registry(client_ip).values() + ) + registered_servers.sort(key=_build_mcp_registry_server_name) for server in registered_servers: diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 60327d8da17..0a643da2833 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -11771,9 +11771,13 @@ async def dynamic_mcp_route(mcp_server_name: str, request: Request): from litellm.proxy._experimental.mcp_server.mcp_server_manager import ( global_mcp_server_manager, ) + from litellm.proxy.auth.ip_address_utils import IPAddressUtils from litellm.types.mcp import MCPAuth - mcp_server = global_mcp_server_manager.get_mcp_server_by_name(mcp_server_name) + client_ip = IPAddressUtils.get_mcp_client_ip(request) + mcp_server = global_mcp_server_manager.get_mcp_server_by_name( + mcp_server_name, client_ip=client_ip + ) if mcp_server is None: raise HTTPException( status_code=404, detail=f"MCP server '{mcp_server_name}' not found" diff --git a/litellm/types/mcp_server/mcp_server_manager.py b/litellm/types/mcp_server/mcp_server_manager.py index 94f33ffb297..4e6c1810255 100644 --- a/litellm/types/mcp_server/mcp_server_manager.py +++ b/litellm/types/mcp_server/mcp_server_manager.py @@ -51,5 +51,6 @@ class MCPServer(BaseModel): env: Optional[Dict[str, str]] = None access_groups: Optional[List[str]] = None allow_all_keys: bool = False + available_on_public_internet: bool = False updated_at: Optional[datetime] = None model_config = ConfigDict(arbitrary_types_allowed=True) diff --git a/tests/test_litellm/proxy/auth/test_mcp_ip_filtering.py b/tests/test_litellm/proxy/auth/test_mcp_ip_filtering.py new file mode 100644 index 00000000000..50e51fbe035 --- /dev/null +++ b/tests/test_litellm/proxy/auth/test_mcp_ip_filtering.py @@ -0,0 +1,93 @@ +""" +Unit tests for MCP IP-based access control. + +Tests that internal callers see all MCP servers while +external callers only see servers with available_on_public_internet=True. +""" + +import ipaddress +from unittest.mock import patch + +from litellm.proxy.auth.ip_address_utils import IPAddressUtils +from litellm.types.mcp_server.mcp_server_manager import MCPServer + + +def _make_server(server_id, available_on_public_internet=False): + return MCPServer( + server_id=server_id, + name=server_id, + server_name=server_id, + transport="http", + available_on_public_internet=available_on_public_internet, + ) + + +def _make_manager(servers): + from litellm.proxy._experimental.mcp_server.mcp_server_manager import ( + MCPServerManager, + ) + + manager = MCPServerManager() + for s in servers: + manager.config_mcp_servers[s.server_id] = s + return manager + + +class TestIsInternalIp: + """Tests that IP classification works for private, public, and edge cases.""" + + def test_private_ranges_are_internal(self): + assert IPAddressUtils.is_internal_ip("127.0.0.1") is True + assert IPAddressUtils.is_internal_ip("10.0.0.1") is True + assert IPAddressUtils.is_internal_ip("172.16.0.1") is True + assert IPAddressUtils.is_internal_ip("192.168.1.1") is True + assert IPAddressUtils.is_internal_ip("::1") is True + + def test_public_ips_are_external(self): + assert IPAddressUtils.is_internal_ip("8.8.8.8") is False + assert IPAddressUtils.is_internal_ip("1.1.1.1") is False + assert IPAddressUtils.is_internal_ip("172.32.0.1") is False + + def test_xff_chain_uses_leftmost_ip(self): + assert IPAddressUtils.is_internal_ip("8.8.8.8, 10.0.0.1") is False + assert IPAddressUtils.is_internal_ip("10.0.0.1, 8.8.8.8") is True + + def test_fails_closed_on_bad_input(self): + assert IPAddressUtils.is_internal_ip("") is False + assert IPAddressUtils.is_internal_ip(None) is False + assert IPAddressUtils.is_internal_ip("not-an-ip") is False + + +class TestMCPServerIPFiltering: + """Tests that external callers only see public MCP servers.""" + + @patch("litellm.public_mcp_servers", []) + @patch("litellm.proxy.proxy_server.general_settings", {}) + def test_external_ip_only_sees_public_servers(self): + pub = _make_server("pub", available_on_public_internet=True) + priv = _make_server("priv", available_on_public_internet=False) + manager = _make_manager([pub, priv]) + + result = manager.filter_server_ids_by_ip(["pub", "priv"], client_ip="8.8.8.8") + assert result == ["pub"] + + @patch("litellm.public_mcp_servers", []) + @patch("litellm.proxy.proxy_server.general_settings", {}) + def test_internal_ip_sees_all_servers(self): + pub = _make_server("pub", available_on_public_internet=True) + priv = _make_server("priv", available_on_public_internet=False) + manager = _make_manager([pub, priv]) + + result = manager.filter_server_ids_by_ip( + ["pub", "priv"], client_ip="192.168.1.1" + ) + assert result == ["pub", "priv"] + + @patch("litellm.public_mcp_servers", []) + @patch("litellm.proxy.proxy_server.general_settings", {}) + def test_no_ip_means_no_filtering(self): + priv = _make_server("priv", available_on_public_internet=False) + manager = _make_manager([priv]) + + result = manager.filter_server_ids_by_ip(["priv"], client_ip=None) + assert result == ["priv"]