diff --git a/litellm-proxy-extras/litellm_proxy_extras/schema.prisma b/litellm-proxy-extras/litellm_proxy_extras/schema.prisma index e713e6ca87a..c2d2299fb25 100644 --- a/litellm-proxy-extras/litellm_proxy_extras/schema.prisma +++ b/litellm-proxy-extras/litellm_proxy_extras/schema.prisma @@ -263,6 +263,7 @@ model LiteLLM_MCPServerTable { token_url String? registration_url String? allow_all_keys Boolean @default(false) + available_on_public_internet Boolean @default(false) } // Generate Tokens for Proxy diff --git a/litellm/model_prices_and_context_window_backup.json b/litellm/model_prices_and_context_window_backup.json index 0da47634a94..4704549e716 100644 --- a/litellm/model_prices_and_context_window_backup.json +++ b/litellm/model_prices_and_context_window_backup.json @@ -1263,6 +1263,36 @@ "supports_vision": true, "tool_use_system_prompt_tokens": 346 }, + "au.anthropic.claude-opus-4-6-v1:0": { + "cache_creation_input_token_cost": 6.875e-06, + "cache_creation_input_token_cost_above_200k_tokens": 1.375e-05, + "cache_read_input_token_cost": 5.5e-07, + "cache_read_input_token_cost_above_200k_tokens": 1.1e-06, + "input_cost_per_token": 5.5e-06, + "input_cost_per_token_above_200k_tokens": 1.1e-05, + "litellm_provider": "bedrock_converse", + "max_input_tokens": 200000, + "max_output_tokens": 128000, + "max_tokens": 128000, + "mode": "chat", + "output_cost_per_token": 2.75e-05, + "output_cost_per_token_above_200k_tokens": 4.125e-05, + "search_context_cost_per_query": { + "search_context_size_high": 0.01, + "search_context_size_low": 0.01, + "search_context_size_medium": 0.01 + }, + "supports_assistant_prefill": false, + "supports_computer_use": true, + "supports_function_calling": true, + "supports_pdf_input": true, + "supports_prompt_caching": true, + "supports_reasoning": true, + "supports_response_schema": true, + "supports_tool_choice": true, + "supports_vision": true, + "tool_use_system_prompt_tokens": 346 + }, "anthropic.claude-sonnet-4-20250514-v1:0": { "cache_creation_input_token_cost": 3.75e-06, "cache_read_input_token_cost": 3e-07, diff --git a/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py b/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py index 901a9b3c076..e7174f943b2 100644 --- a/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py +++ b/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py @@ -72,7 +72,6 @@ from mcp.shared.tool_name_validation import ( SEP_986_URL, ) - from mcp.shared.tool_name_validation import SEP_986_URL except ImportError: from pydantic import BaseModel @@ -82,7 +81,7 @@ class _ToolNameValidationResult(BaseModel): is_valid: bool = True warnings: list = [] - def validate_tool_name(name: str) -> _ToolNameValidationResult: + def validate_tool_name(name: str) -> _ToolNameValidationResult: # type: ignore[misc] return _ToolNameValidationResult() @@ -2304,28 +2303,24 @@ def get_mcp_server_by_name( non-public servers are hidden from external IPs. """ registry = self.get_registry() - # Pass 1: Match by alias (highest priority) for server in registry.values(): if server.alias == server_name: if not self._is_server_accessible_from_ip(server, client_ip): return None return server - # Pass 2: Match by server_name for server in registry.values(): if server.server_name == server_name: if not self._is_server_accessible_from_ip(server, client_ip): return None return server - # Pass 3: Match by name (lowest priority) for server in registry.values(): if server.name == server_name: if not self._is_server_accessible_from_ip(server, client_ip): return None return server - return None def get_filtered_registry( @@ -2576,6 +2571,7 @@ def _build_mcp_server_table(self, server: MCPServer) -> LiteLLM_MCPServerTable: token_url=server.token_url, registration_url=server.registration_url, allow_all_keys=server.allow_all_keys, + available_on_public_internet=server.available_on_public_internet, ) async def get_all_mcp_servers_unfiltered(self) -> List[LiteLLM_MCPServerTable]: diff --git a/litellm/proxy/_experimental/mcp_server/server.py b/litellm/proxy/_experimental/mcp_server/server.py index e357311680c..890c4ae8fb2 100644 --- a/litellm/proxy/_experimental/mcp_server/server.py +++ b/litellm/proxy/_experimental/mcp_server/server.py @@ -303,7 +303,7 @@ async def mcp_server_tool_call( host_token = getattr(host_ctx.meta, 'progressToken', None) if host_token and hasattr(host_ctx, 'session') and host_ctx.session: host_session = host_ctx.session - + async def forward_progress(progress: float, total: float | None): """Forward progress notifications from external MCP to Host""" try: @@ -315,7 +315,7 @@ async def forward_progress(progress: float, total: float | None): verbose_logger.debug(f"Forwarded progress {progress}/{total} to Host") except Exception as e: verbose_logger.error(f"Failed to forward progress to Host: {e}") - + host_progress_callback = forward_progress verbose_logger.debug(f"Host progressToken captured: {host_token[:8]}...") except Exception as e: @@ -725,7 +725,6 @@ def filter_tools_by_allowed_tools( def _get_client_ip_from_context() -> Optional[str]: """ Extract client_ip from auth context. - Returns None if context not set (caller should handle this as "no IP filtering"). """ try: @@ -748,7 +747,6 @@ async def _get_allowed_mcp_servers( mcp_servers: Optional list of server names to filter to. client_ip: Client IP for IP-based access control. If None, falls back to auth context. Pass explicitly from request handlers for safety. - Note: If client_ip is None and auth context is not set, IP filtering is skipped. This is intentional for internal callers but may indicate a bug if called from a request handler without proper context setup. @@ -1781,7 +1779,7 @@ async def _handle_managed_mcp_tool( oauth2_headers: Optional[Dict[str, str]] = None, raw_headers: Optional[Dict[str, str]] = None, litellm_logging_obj: Optional[Any] = None, - host_progress_callback: Optional[Callable] = None, + host_progress_callback: Optional[Callable] = None, ) -> CallToolResult: """Handle tool execution for managed server tools""" # Import here to avoid circular import diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index 501cfacaf91..6e4131c3ced 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -1071,6 +1071,7 @@ class NewMCPServerRequest(LiteLLMPydanticObjectBase): token_url: Optional[str] = None registration_url: Optional[str] = None allow_all_keys: bool = False + available_on_public_internet: bool = False @model_validator(mode="before") @classmethod @@ -1132,6 +1133,7 @@ class UpdateMCPServerRequest(LiteLLMPydanticObjectBase): token_url: Optional[str] = None registration_url: Optional[str] = None allow_all_keys: bool = False + available_on_public_internet: bool = False @model_validator(mode="before") @classmethod @@ -1185,6 +1187,7 @@ class LiteLLM_MCPServerTable(LiteLLMPydanticObjectBase): token_url: Optional[str] = None registration_url: Optional[str] = None allow_all_keys: bool = False + available_on_public_internet: bool = False class MakeMCPServersPublicRequest(LiteLLMPydanticObjectBase): @@ -2096,6 +2099,14 @@ class ConfigGeneralSettings(LiteLLMPydanticObjectBase): None, description="Maximum retention period for spend logs (e.g., '7d' for 7 days). Logs older than this will be deleted.", ) + mcp_internal_ip_ranges: Optional[List[str]] = Field( + None, + description="Custom CIDR ranges that define internal/private networks for MCP access control. When set, only these ranges are treated as internal. Defaults to RFC 1918 private ranges (10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16, 127.0.0.0/8).", + ) + mcp_trusted_proxy_ranges: Optional[List[str]] = Field( + None, + description="CIDR ranges of trusted reverse proxies. When set, X-Forwarded-For headers are only trusted from these IPs.", + ) class ConfigYAML(LiteLLMPydanticObjectBase): diff --git a/litellm/proxy/auth/ip_address_utils.py b/litellm/proxy/auth/ip_address_utils.py index c8e936598e5..651d6785333 100644 --- a/litellm/proxy/auth/ip_address_utils.py +++ b/litellm/proxy/auth/ip_address_utils.py @@ -49,7 +49,6 @@ def parse_trusted_proxy_networks( ) -> List[Union[ipaddress.IPv4Network, ipaddress.IPv6Network]]: """ Parse trusted proxy CIDR ranges for XFF validation. - Returns empty list if not configured (XFF will not be trusted). """ if not configured_ranges: @@ -121,12 +120,23 @@ def get_mcp_client_ip( Args: request: FastAPI request object - general_settings: Optional settings dict. If not provided, uses cached reference. + general_settings: Optional settings dict. If not provided, imports from proxy_server. """ - from litellm.proxy.proxy_server import general_settings + if general_settings is None: + try: + from litellm.proxy.proxy_server import ( + general_settings as proxy_general_settings, + ) + general_settings = proxy_general_settings + except ImportError: + general_settings = {} + + # Handle case where general_settings is still None after import + if general_settings is None: + general_settings = {} use_xff = general_settings.get("use_x_forwarded_for", False) - + # If XFF is enabled, validate the request comes from a trusted proxy if use_xff and "x-forwarded-for" in request.headers: trusted_ranges = general_settings.get("mcp_trusted_proxy_ranges") @@ -142,5 +152,4 @@ def get_mcp_client_ip( "XFF header from untrusted IP %s, ignoring", direct_ip ) return direct_ip - return _get_request_ip_address(request, use_x_forwarded_for=use_xff) diff --git a/litellm/proxy/management_endpoints/mcp_management_endpoints.py b/litellm/proxy/management_endpoints/mcp_management_endpoints.py index 5e59668fc0c..90c2f7fdf62 100644 --- a/litellm/proxy/management_endpoints/mcp_management_endpoints.py +++ b/litellm/proxy/management_endpoints/mcp_management_endpoints.py @@ -69,7 +69,7 @@ class _ToolNameValidationResult(BaseModel): is_valid: bool = True warnings: list = [] - def validate_tool_name(name: str) -> _ToolNameValidationResult: + def validate_tool_name(name: str) -> _ToolNameValidationResult: # type: ignore[misc] return _ToolNameValidationResult() from litellm.proxy._experimental.mcp_server.db import ( @@ -333,6 +333,7 @@ def _build_temporary_mcp_server_record( token_url=payload.token_url, registration_url=payload.registration_url, allow_all_keys=payload.allow_all_keys, + available_on_public_internet=payload.available_on_public_internet, ) def get_prisma_client_or_throw(message: str): @@ -421,6 +422,18 @@ async def get_mcp_access_groups( access_groups_list = sorted(list(access_groups)) return {"access_groups": access_groups_list} + @router.get( + "/network/client-ip", + tags=["mcp"], + dependencies=[Depends(user_api_key_auth)], + description="Returns the caller's IP address as seen by the proxy.", + ) + async def get_client_ip(request: Request): + from litellm.proxy.auth.ip_address_utils import IPAddressUtils + + client_ip = IPAddressUtils.get_mcp_client_ip(request) + return {"ip": client_ip} + @router.get( "/registry.json", tags=["mcp"], diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 0a643da2833..cf9e00512b1 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -10904,6 +10904,8 @@ async def get_config_list( "pass_through_endpoints": {"type": "PydanticModel"}, "store_prompts_in_spend_logs": {"type": "Boolean"}, "maximum_spend_logs_retention_period": {"type": "String"}, + "mcp_internal_ip_ranges": {"type": "List"}, + "mcp_trusted_proxy_ranges": {"type": "List"}, } return_val = [] @@ -10973,11 +10975,15 @@ async def get_config_list( elif field_name in general_settings: _stored_in_db = False + _field_value = general_settings.get(field_name, None) + if _field_value is None and field_name in db_general_settings_dict: + _field_value = db_general_settings_dict[field_name] + _response_obj = ConfigList( field_name=field_name, field_type=allowed_args[field_name]["type"], field_description=field_info.description or "", - field_value=general_settings.get(field_name, None), + field_value=_field_value, stored_in_db=_stored_in_db, field_default_value=field_info.default, nested_fields=nested_fields, diff --git a/litellm/proxy/schema.prisma b/litellm/proxy/schema.prisma index a6a573836b5..efced6f8749 100644 --- a/litellm/proxy/schema.prisma +++ b/litellm/proxy/schema.prisma @@ -261,6 +261,7 @@ model LiteLLM_MCPServerTable { token_url String? registration_url String? allow_all_keys Boolean @default(false) + available_on_public_internet Boolean @default(false) } // Generate Tokens for Proxy diff --git a/schema.prisma b/schema.prisma index 240e0dfea48..95fff722882 100644 --- a/schema.prisma +++ b/schema.prisma @@ -263,6 +263,7 @@ model LiteLLM_MCPServerTable { token_url String? registration_url String? allow_all_keys Boolean @default(false) + available_on_public_internet Boolean @default(false) } // Generate Tokens for Proxy diff --git a/tests/mcp_tests/test_mcp_server.py b/tests/mcp_tests/test_mcp_server.py index 5785ff750fd..ea823df1fb2 100644 --- a/tests/mcp_tests/test_mcp_server.py +++ b/tests/mcp_tests/test_mcp_server.py @@ -627,7 +627,10 @@ def test_generate_stable_server_id(): @pytest.mark.asyncio async def test_list_tools_rest_api_server_not_found(): """Test the list_tools REST API when server is not found""" - from litellm.proxy._experimental.mcp_server.rest_endpoints import list_tool_rest_api + from litellm.proxy._experimental.mcp_server.rest_endpoints import ( + list_tool_rest_api, + global_mcp_server_manager, + ) from fastapi import Query from litellm.proxy._types import UserAPIKeyAuth @@ -641,21 +644,38 @@ async def test_list_tools_rest_api_server_not_found(): ), ) - # Mock request + # Mock request with proper client attribute (internal IP for no filtering) mock_request = MagicMock() mock_request.headers = {} + mock_request.client = MagicMock() + mock_request.client.host = "127.0.0.1" # Internal IP to bypass IP filtering - # Test with non-existent server ID - response = await list_tool_rest_api( - request=mock_request, - server_id="non_existent_server_id", - user_api_key_dict=mock_user_auth, - ) + # Mock the global_mcp_server_manager to allow the server ID but return None for the server + with patch( + "litellm.proxy._experimental.mcp_server.rest_endpoints.global_mcp_server_manager" + ) as mock_manager: + # Allow the server ID in permissions + mock_manager.get_allowed_mcp_servers = AsyncMock( + return_value=["non_existent_server_id"] + ) + # Mock filter_server_ids_by_ip to return input unchanged (no IP filtering in test) + mock_manager.filter_server_ids_by_ip = MagicMock( + side_effect=lambda server_ids, client_ip: server_ids + ) + # Return None when trying to get the server (server doesn't exist) + mock_manager.get_mcp_server_by_id = MagicMock(return_value=None) + + # Test with non-existent server ID + response = await list_tool_rest_api( + request=mock_request, + server_id="non_existent_server_id", + user_api_key_dict=mock_user_auth, + ) - assert isinstance(response, dict) - assert response["tools"] == [] - assert response["error"] == "server_not_found" - assert "Server with id non_existent_server_id not found" in response["message"] + assert isinstance(response, dict) + assert response["tools"] == [] + assert response["error"] == "server_not_found" + assert "Server with id non_existent_server_id not found" in response["message"] @pytest.mark.asyncio @@ -663,90 +683,75 @@ async def test_list_tools_rest_api_success(): """Test the list_tools REST API successful case""" from litellm.proxy._experimental.mcp_server.rest_endpoints import ( list_tool_rest_api, - global_mcp_server_manager, + ) + from litellm.proxy._experimental.mcp_server.server import ( + ListMCPToolsRestAPIResponseObject, ) from fastapi import Query from litellm.proxy._types import UserAPIKeyAuth - # Store original registry to restore after test - original_registry = global_mcp_server_manager.get_registry().copy() - original_tool_mapping = ( - global_mcp_server_manager.tool_name_to_mcp_server_name_mapping.copy() - ) - try: - # Clear existing registry - global_mcp_server_manager.tool_name_to_mcp_server_name_mapping.clear() - global_mcp_server_manager.registry.clear() - global_mcp_server_manager.config_mcp_servers.clear() - - # Mock successful tools - mock_tools = [ - MCPTool( - name="test_tool", - description="A test tool", - inputSchema={"type": "object"}, - ) - ] - - # Create mock client - mock_client = AsyncMock() - mock_client.list_tools = AsyncMock(return_value=mock_tools) - mock_client.__aenter__ = AsyncMock(return_value=mock_client) - mock_client.__aexit__ = AsyncMock(return_value=None) + # Mock successful tools + mock_tools = [ + ListMCPToolsRestAPIResponseObject( + name="test_tool", + description="A test tool", + inputSchema={"type": "object"}, + mcp_info={"server_name": "test_server"}, + ) + ] - def mock_client_constructor(*args, **kwargs): - return mock_client + # Create a mock server + mock_server = MagicMock() + mock_server.server_id = "test-server-123" + mock_server.alias = "test_server" + mock_server.name = "test_server" + mock_server.mcp_info = {"server_name": "test_server"} - with patch( - "litellm.proxy._experimental.mcp_server.mcp_server_manager.MCPClient", - mock_client_constructor, - ): - # Load server config into global manager - await global_mcp_server_manager.load_servers_from_config( - { - "test_server": { - "url": "https://test-server.com/mcp", - "transport": MCPTransport.http, - } - } - ) + # Mock UserAPIKeyAuth + mock_user_auth = UserAPIKeyAuth( + api_key="test", + user_id="test", + object_permission=LiteLLM_ObjectPermissionTable( + object_permission_id="dummy", + mcp_servers=["test-server-123"], + ), + ) - # Mock UserAPIKeyAuth - mock_user_auth = UserAPIKeyAuth( - api_key="test", - user_id="test", - object_permission=LiteLLM_ObjectPermissionTable( - object_permission_id="dummy", - mcp_servers=list( - global_mcp_server_manager.get_all_mcp_server_ids() - ), - ), - ) + # Mock request with proper client attribute (internal IP for no filtering) + mock_request = MagicMock() + mock_request.headers = {} + mock_request.client = MagicMock() + mock_request.client.host = "127.0.0.1" # Internal IP to bypass IP filtering - # Get the server ID - server_id = list(global_mcp_server_manager.get_registry().keys())[0] + # Mock the global_mcp_server_manager + with patch( + "litellm.proxy._experimental.mcp_server.rest_endpoints.global_mcp_server_manager" + ) as mock_manager: + mock_manager.get_allowed_mcp_servers = AsyncMock( + return_value=["test-server-123"] + ) + mock_manager.get_mcp_server_by_id = MagicMock(return_value=mock_server) + # Mock filter_server_ids_by_ip to return input unchanged (no IP filtering in test) + mock_manager.filter_server_ids_by_ip = MagicMock( + side_effect=lambda server_ids, client_ip: server_ids + ) - # Mock request - mock_request = MagicMock() - mock_request.headers = {} + # Mock the _get_tools_for_single_server function + with patch( + "litellm.proxy._experimental.mcp_server.rest_endpoints._get_tools_for_single_server" + ) as mock_get_tools: + mock_get_tools.return_value = mock_tools # Test successful case response = await list_tool_rest_api( request=mock_request, - server_id=server_id, + server_id="test-server-123", user_api_key_dict=mock_user_auth, ) assert isinstance(response, dict) assert len(response["tools"]) == 1 assert response["tools"][0].name == "test_tool" - finally: - # Restore original state - global_mcp_server_manager.registry = {} - global_mcp_server_manager.config_mcp_servers = original_registry - global_mcp_server_manager.tool_name_to_mcp_server_name_mapping = ( - original_tool_mapping - ) @pytest.mark.asyncio @@ -808,6 +813,10 @@ def mock_get_server_by_id(server_id): ) mock_manager.get_mcp_server_by_id = lambda server_id: mock_server_1 if server_id == "server1_id" else mock_server_2 mock_manager._get_tools_from_server = AsyncMock(return_value=[mock_tool_1]) + # Mock filter_server_ids_by_ip to return input unchanged (no IP filtering in test) + mock_manager.filter_server_ids_by_ip = MagicMock( + side_effect=lambda server_ids, client_ip: server_ids + ) with patch( "litellm.proxy._experimental.mcp_server.server.global_mcp_server_manager", @@ -843,6 +852,10 @@ async def mock_get_tools_side_effect( mock_manager_2._get_tools_from_server = AsyncMock( side_effect=mock_get_tools_side_effect ) + # Mock filter_server_ids_by_ip to return input unchanged (no IP filtering in test) + mock_manager_2.filter_server_ids_by_ip = MagicMock( + side_effect=lambda server_ids, client_ip: server_ids + ) with patch( "litellm.proxy._experimental.mcp_server.server.global_mcp_server_manager", @@ -867,6 +880,10 @@ async def mock_get_tools_side_effect( ) mock_manager.get_mcp_server_by_id = lambda server_id: mock_server_1 if server_id == "server1_id" else (mock_server_2 if server_id == "server2_id" else mock_server_3) mock_manager._get_tools_from_server = AsyncMock(return_value=[mock_tool_1]) + # Mock filter_server_ids_by_ip to return input unchanged (no IP filtering in test) + mock_manager.filter_server_ids_by_ip = MagicMock( + side_effect=lambda server_ids, client_ip: server_ids + ) with patch( "litellm.proxy._experimental.mcp_server.server.global_mcp_server_manager", @@ -1795,6 +1812,10 @@ async def test_list_tool_rest_api_with_server_specific_auth(): mock_server.mcp_info = {"server_name": "zapier"} mock_manager.get_mcp_server_by_id.return_value = mock_server + # Mock filter_server_ids_by_ip to return input unchanged (no IP filtering in test) + mock_manager.filter_server_ids_by_ip = MagicMock( + side_effect=lambda server_ids, client_ip: server_ids + ) mock_user_api_key_dict = UserAPIKeyAuth( api_key="test", @@ -1885,6 +1906,10 @@ async def test_list_tool_rest_api_with_default_auth(): mock_server.mcp_info = {"server_name": "unknown_server"} mock_manager.get_mcp_server_by_id.return_value = mock_server + # Mock filter_server_ids_by_ip to return input unchanged (no IP filtering in test) + mock_manager.filter_server_ids_by_ip = MagicMock( + side_effect=lambda server_ids, client_ip: server_ids + ) mock_user_api_key_dict = UserAPIKeyAuth( api_key="test", @@ -1991,6 +2016,10 @@ async def test_list_tool_rest_api_all_servers_with_auth(): server_id ) ) + # Mock filter_server_ids_by_ip to return input unchanged (no IP filtering in test) + mock_manager.filter_server_ids_by_ip = MagicMock( + side_effect=lambda server_ids, client_ip: server_ids + ) mock_user_api_key_dict = UserAPIKeyAuth( api_key="test", @@ -2120,6 +2149,10 @@ def mock_client_constructor(*args, **kwargs): return_value=["test-server-123"] ) mock_manager.get_mcp_server_by_id = MagicMock(return_value=mock_server) + # Mock filter_server_ids_by_ip to return input unchanged (no IP filtering in test) + mock_manager.filter_server_ids_by_ip = MagicMock( + side_effect=lambda server_ids, client_ip: server_ids + ) # Mock the _get_tools_from_server method to return all tools mock_manager._get_tools_from_server = AsyncMock(return_value=mock_tools) @@ -2230,6 +2263,10 @@ def mock_client_constructor(*args, **kwargs): return_value=["test-server-456"] ) mock_manager.get_mcp_server_by_id = MagicMock(return_value=mock_server) + # Mock filter_server_ids_by_ip to return input unchanged (no IP filtering in test) + mock_manager.filter_server_ids_by_ip = MagicMock( + side_effect=lambda server_ids, client_ip: server_ids + ) # Mock the _get_tools_from_server method to return all tools mock_manager._get_tools_from_server = AsyncMock(return_value=mock_tools) @@ -2326,6 +2363,10 @@ def mock_client_constructor(*args, **kwargs): return_value=["test-server-000"] ) mock_manager.get_mcp_server_by_id = MagicMock(return_value=mock_server) + # Mock filter_server_ids_by_ip to return input unchanged (no IP filtering in test) + mock_manager.filter_server_ids_by_ip = MagicMock( + side_effect=lambda server_ids, client_ip: server_ids + ) # Mock the _get_tools_from_server method to return all tools mock_manager._get_tools_from_server = AsyncMock(return_value=mock_tools) diff --git a/ui/litellm-dashboard/src/components/mcp_tools/MCPNetworkSettings.tsx b/ui/litellm-dashboard/src/components/mcp_tools/MCPNetworkSettings.tsx new file mode 100644 index 00000000000..46fbbbcf2b0 --- /dev/null +++ b/ui/litellm-dashboard/src/components/mcp_tools/MCPNetworkSettings.tsx @@ -0,0 +1,152 @@ +import React, { useState, useEffect } from "react"; +import { Select, Button, Card, Typography, Spin, Tag } from "antd"; +import { SaveOutlined, PlusOutlined } from "@ant-design/icons"; +import { getGeneralSettingsCall, updateConfigFieldSetting, deleteConfigFieldSetting, fetchMCPClientIp } from "../networking"; + +const { Text } = Typography; + +interface MCPNetworkSettingsProps { + accessToken: string | null; +} + +/** + * Given an IP like "203.0.113.45", return "203.0.113.0/24". + */ +function ipToSlash24(ip: string): string { + const parts = ip.split("."); + if (parts.length !== 4) return ip + "/32"; + return `${parts[0]}.${parts[1]}.${parts[2]}.0/24`; +} + +const MCPNetworkSettings: React.FC = ({ accessToken }) => { + const [loading, setLoading] = useState(true); + const [saving, setSaving] = useState(false); + const [privateRanges, setPrivateRanges] = useState([]); + const [currentIp, setCurrentIp] = useState(null); + + useEffect(() => { + loadSettings(); + detectCurrentIp(); + }, [accessToken]); + + const loadSettings = async () => { + if (!accessToken) return; + setLoading(true); + try { + const settings = await getGeneralSettingsCall(accessToken); + for (const field of settings) { + if (field.field_name === "mcp_internal_ip_ranges" && field.field_value) { + setPrivateRanges(field.field_value); + } + } + } catch (error) { + console.error("Failed to load MCP network settings:", error); + } finally { + setLoading(false); + } + }; + + const detectCurrentIp = async () => { + if (!accessToken) return; + const ip = await fetchMCPClientIp(accessToken); + if (ip) { + setCurrentIp(ip); + } + }; + + const handleSave = async () => { + if (!accessToken) return; + setSaving(true); + try { + if (privateRanges.length > 0) { + await updateConfigFieldSetting(accessToken, "mcp_internal_ip_ranges", privateRanges); + } else { + await deleteConfigFieldSetting(accessToken, "mcp_internal_ip_ranges"); + } + } catch (error) { + console.error("Failed to save MCP network settings:", error); + } finally { + setSaving(false); + } + }; + + const addSuggestedRange = (range: string) => { + if (!privateRanges.includes(range)) { + setPrivateRanges([...privateRanges, range]); + } + }; + + if (loading) { + return ( +
+ +
+ ); + } + + const suggestedRange = currentIp ? ipToSlash24(currentIp) : null; + + return ( +
+
+ Private IP Ranges +

+ Define which IP ranges are part of your private network. Callers from these IPs can see all MCP servers. Callers from any other IP can only see servers marked "Available on Public Internet". +

+
+ + + {currentIp && ( +
+ + Your current IP: {currentIp} + + {suggestedRange && !privateRanges.includes(suggestedRange) && ( +
+ Suggested range: + } + onClick={() => addSuggestedRange(suggestedRange)} + > + {suggestedRange} + +
+ )} +
+ )} + +
+ Your Private Network Ranges +
+