Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CLAUDE.md
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ LiteLLM is a unified interface for 100+ LLM providers with two main components:
- Pydantic v2 for data validation
- Async/await patterns throughout
- Type hints required for all public APIs
- **Avoid imports within methods** — place all imports at the top of the file (module-level). Inline imports inside functions/methods make dependencies harder to trace and hurt readability. The only exception is avoiding circular imports where absolutely necessary.

### Testing Strategy
- Unit tests in `tests/test_litellm/`
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def __init__(
oauth2_headers: Optional[Dict[str, str]] = None,
mcp_protocol_version: Optional[str] = None,
raw_headers: Optional[Dict[str, str]] = None,
client_ip: Optional[str] = None,
):
self.user_api_key_auth = user_api_key_auth
self.mcp_auth_header = mcp_auth_header
Expand All @@ -35,3 +36,4 @@ def __init__(
self.mcp_protocol_version = mcp_protocol_version
self.oauth2_headers = oauth2_headers
self.raw_headers = raw_headers
self.client_ip = client_ip
28 changes: 22 additions & 6 deletions litellm/proxy/_experimental/mcp_server/discoverable_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,14 @@
get_async_httpx_client,
httpxSpecialProvider,
)
from litellm.proxy.auth.ip_address_utils import IPAddressUtils
from litellm.proxy.common_utils.encrypt_decrypt_utils import (
decrypt_value_helper,
encrypt_value_helper,
)
from litellm.proxy.common_utils.http_parsing_utils import _read_request_body
from litellm.types.mcp_server.mcp_server_manager import MCPServer
from litellm.proxy.utils import get_server_root_path
from litellm.types.mcp_server.mcp_server_manager import MCPServer

router = APIRouter(
tags=["mcp"],
Expand Down Expand Up @@ -300,7 +301,10 @@ async def authorize(
)

lookup_name = mcp_server_name or client_id
mcp_server = global_mcp_server_manager.get_mcp_server_by_name(lookup_name)
client_ip = IPAddressUtils.get_mcp_client_ip(request)
mcp_server = global_mcp_server_manager.get_mcp_server_by_name(
lookup_name, client_ip=client_ip
)
if mcp_server is None:
raise HTTPException(status_code=404, detail="MCP server not found")
return await authorize_with_server(
Expand Down Expand Up @@ -342,7 +346,10 @@ async def token_endpoint(
)

lookup_name = mcp_server_name or client_id
mcp_server = global_mcp_server_manager.get_mcp_server_by_name(lookup_name)
client_ip = IPAddressUtils.get_mcp_client_ip(request)
mcp_server = global_mcp_server_manager.get_mcp_server_by_name(
lookup_name, client_ip=client_ip
)
if mcp_server is None:
raise HTTPException(status_code=404, detail="MCP server not found")
return await exchange_token_with_server(
Expand Down Expand Up @@ -425,7 +432,10 @@ def _build_oauth_protected_resource_response(
request_base_url = get_request_base_url(request)
mcp_server: Optional[MCPServer] = None
if mcp_server_name:
mcp_server = global_mcp_server_manager.get_mcp_server_by_name(mcp_server_name)
client_ip = IPAddressUtils.get_mcp_client_ip(request)
mcp_server = global_mcp_server_manager.get_mcp_server_by_name(
mcp_server_name, client_ip=client_ip
)

# Build resource URL based on the pattern
if mcp_server_name:
Expand Down Expand Up @@ -538,7 +548,10 @@ def _build_oauth_authorization_server_response(

mcp_server: Optional[MCPServer] = None
if mcp_server_name:
mcp_server = global_mcp_server_manager.get_mcp_server_by_name(mcp_server_name)
client_ip = IPAddressUtils.get_mcp_client_ip(request)
mcp_server = global_mcp_server_manager.get_mcp_server_by_name(
mcp_server_name, client_ip=client_ip
)

return {
"issuer": request_base_url, # point to your proxy
Expand Down Expand Up @@ -629,7 +642,10 @@ async def register_client(request: Request, mcp_server_name: Optional[str] = Non
if not mcp_server_name:
return dummy_return

mcp_server = global_mcp_server_manager.get_mcp_server_by_name(mcp_server_name)
client_ip = IPAddressUtils.get_mcp_client_ip(request)
mcp_server = global_mcp_server_manager.get_mcp_server_by_name(
mcp_server_name, client_ip=client_ip
)
if mcp_server is None:
return dummy_return
return await register_client_with_server(
Expand Down
136 changes: 124 additions & 12 deletions litellm/proxy/_experimental/mcp_server/mcp_server_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
MCPTransportType,
UserAPIKeyAuth,
)
from litellm.proxy.auth.ip_address_utils import IPAddressUtils
from litellm.proxy.common_utils.encrypt_decrypt_utils import decrypt_value_helper
from litellm.proxy.utils import ProxyLogging
from litellm.types.llms.custom_http import httpxSpecialProvider
Expand All @@ -66,20 +67,23 @@

try:
from mcp.shared.tool_name_validation import (
validate_tool_name, # type: ignore[reportAssignmentType]
validate_tool_name, # pyright: ignore[reportAssignmentType]
)
from mcp.shared.tool_name_validation import (
SEP_986_URL,
)
from mcp.shared.tool_name_validation import SEP_986_URL
except ImportError:
from pydantic import BaseModel

SEP_986_URL = "https://github.com/modelcontextprotocol/protocol/blob/main/proposals/0001-tool-name-validation.md"

class ToolNameValidationResult(BaseModel):
class _ToolNameValidationResult(BaseModel):
is_valid: bool = True
warnings: list = []

def validate_tool_name(name: str) -> ToolNameValidationResult: # type: ignore[misc]
return ToolNameValidationResult()
def validate_tool_name(name: str) -> _ToolNameValidationResult:
return _ToolNameValidationResult()


# Probe includes characters on both sides of the separator to mimic real prefixed tool names.
Expand Down Expand Up @@ -325,6 +329,9 @@ async def load_servers_from_config(
access_groups=server_config.get("access_groups", None),
static_headers=server_config.get("static_headers", None),
allow_all_keys=bool(server_config.get("allow_all_keys", False)),
available_on_public_internet=bool(
server_config.get("available_on_public_internet", False)
),
)
self.config_mcp_servers[server_id] = new_server

Expand Down Expand Up @@ -623,6 +630,9 @@ async def build_mcp_server_from_table(
allowed_tools=getattr(mcp_server, "allowed_tools", None),
disallowed_tools=getattr(mcp_server, "disallowed_tools", None),
allow_all_keys=mcp_server.allow_all_keys,
available_on_public_internet=bool(
getattr(mcp_server, "available_on_public_internet", False)
),
updated_at=getattr(mcp_server, "updated_at", None),
)
return new_server
Expand Down Expand Up @@ -697,6 +707,23 @@ async def get_allowed_mcp_servers(
verbose_logger.warning(f"Failed to get allowed MCP servers: {str(e)}.")
return allow_all_server_ids

def filter_server_ids_by_ip(
self, server_ids: List[str], client_ip: Optional[str]
) -> List[str]:
"""
Filter server IDs by client IP — external callers only see public servers.

Returns server_ids unchanged when client_ip is None (no filtering).
"""
if client_ip is None:
return server_ids
return [
sid
for sid in server_ids
if (s := self.get_mcp_server_by_id(sid)) is not None
and self._is_server_accessible_from_ip(s, client_ip)
]

async def get_tools_for_server(self, server_id: str) -> List[MCPTool]:
"""
Get the tools for a given server
Expand Down Expand Up @@ -2202,6 +2229,42 @@ def get_mcp_servers_from_ids(self, server_ids: List[str]) -> List[MCPServer]:
servers.append(server)
return servers

def _get_general_settings(self) -> Dict[str, Any]:
"""Get general_settings, importing lazily to avoid circular imports."""
try:
from litellm.proxy.proxy_server import (
general_settings as proxy_general_settings,
)
return proxy_general_settings
except ImportError:
# Fallback if proxy_server not available
return {}

def _is_server_accessible_from_ip(
self, server: MCPServer, client_ip: Optional[str]
) -> bool:
"""
Check if a server is accessible from the given client IP.

- If client_ip is None, no IP filtering is applied (internal callers).
- If the server has available_on_public_internet=True, it's always accessible.
- Otherwise, only internal/private IPs can access it.
"""
if client_ip is None:
return True
if server.available_on_public_internet:
return True
# Check backwards compat: litellm.public_mcp_servers
public_ids = set(litellm.public_mcp_servers or [])
if server.server_id in public_ids:
return True
# Non-public server: only accessible from internal IPs
general_settings = self._get_general_settings()
internal_networks = IPAddressUtils.parse_internal_networks(
general_settings.get("mcp_internal_ip_ranges")
)
return IPAddressUtils.is_internal_ip(client_ip, internal_networks)

def get_mcp_server_by_id(self, server_id: str) -> Optional[MCPServer]:
"""
Get the MCP Server from the server id
Expand All @@ -2214,27 +2277,76 @@ def get_mcp_server_by_id(self, server_id: str) -> Optional[MCPServer]:

def get_public_mcp_servers(self) -> List[MCPServer]:
"""
Get the public MCP servers
Get the public MCP servers (available_on_public_internet=True flag on server).
Also includes servers from litellm.public_mcp_servers for backwards compat.
"""
servers: List[MCPServer] = []
if litellm.public_mcp_servers is None:
return servers
for server_id in litellm.public_mcp_servers:
server = self.get_mcp_server_by_id(server_id)
if server:
public_ids = set(litellm.public_mcp_servers or [])
for server in self.get_registry().values():
if server.available_on_public_internet or server.server_id in public_ids:
servers.append(server)
return servers

def get_mcp_server_by_name(self, server_name: str) -> Optional[MCPServer]:
def get_mcp_server_by_name(
self, server_name: str, client_ip: Optional[str] = None
) -> Optional[MCPServer]:
"""
Comment on lines +2290 to 2293
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Server alias lookup regression

get_mcp_server_by_name() only matches server.server_name == server_name and ignores alias even though servers can be addressed via alias elsewhere. After this PR, multiple request paths (e.g., OAuth endpoints and dynamic MCP routes) rely on get_mcp_server_by_name(..., client_ip=...), so any client using an alias will now get a 404 even if the server exists and is allowed.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what do you need me to change / edit @greptile ?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You need to add alias matching to get_mcp_server_by_name(). Change:

Suggested change
def get_mcp_server_by_name(
self, server_name: str, client_ip: Optional[str] = None
) -> Optional[MCPServer]:
"""
def get_mcp_server_by_name(
self, server_name: str, client_ip: Optional[str] = None
) -> Optional[MCPServer]:
"""
Get an MCP server by name or alias, with optional IP filtering.
"""
registry = self.get_registry()
for server in registry.values():
if server.server_name == server_name or server_name in (server.alias or []):
if client_ip and not IPAddressUtils.is_private_ip(client_ip):
if not server.available_on_public_internet:
return None
return server
return None

This preserves the existing alias lookup behavior while adding IP filtering.

Get the MCP Server from the server name
Get the MCP Server from the server name.

Uses priority-based matching to avoid collisions:
1. First pass: exact alias match (highest priority)
2. Second pass: exact server_name match
3. Third pass: exact name match (lowest priority)

Args:
server_name: The server name to look up.
client_ip: Optional client IP for access control. When provided,
non-public servers are hidden from external IPs.
"""
registry = self.get_registry()

# Pass 1: Match by alias (highest priority)
for server in registry.values():
if server.alias == server_name:
if not self._is_server_accessible_from_ip(server, client_ip):
return None
return server

# Pass 2: Match by server_name
for server in registry.values():
if server.server_name == server_name:
if not self._is_server_accessible_from_ip(server, client_ip):
return None
return server

# Pass 3: Match by name (lowest priority)
for server in registry.values():
if server.name == server_name:
if not self._is_server_accessible_from_ip(server, client_ip):
return None
return server

return None

def get_filtered_registry(
self, client_ip: Optional[str] = None
) -> Dict[str, MCPServer]:
"""
Get registry filtered by client IP access control.

Args:
client_ip: Optional client IP. When provided, non-public servers
are hidden from external IPs. When None, returns all servers.
"""
registry = self.get_registry()
if client_ip is None:
return registry
return {
k: v
for k, v in registry.items()
if self._is_server_accessible_from_ip(v, client_ip)
}

def _generate_stable_server_id(
self,
server_name: str,
Expand Down
Loading
Loading