Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion litellm/experimental_mcp_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,11 @@
from mcp import ClientSession, ReadResourceResult, Resource, StdioServerParameters
from mcp.client.sse import sse_client
from mcp.client.stdio import stdio_client
from mcp.client.streamable_http import streamable_http_client

try:
from mcp.client.streamable_http import streamable_http_client # type: ignore
except ImportError:
streamable_http_client = None
from mcp.types import CallToolRequestParams as MCPCallToolRequestParams
from mcp.types import CallToolResult as MCPCallToolResult
from mcp.types import (
Expand Down
31 changes: 24 additions & 7 deletions litellm/proxy/_experimental/mcp_server/mcp_server_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,20 @@
MCPOAuthMetadata,
MCPServer,
)
from mcp.shared.tool_name_validation import SEP_986_URL, validate_tool_name

try:
from mcp.shared.tool_name_validation import SEP_986_URL, validate_tool_name # type: ignore
except ImportError:
SEP_986_URL = "https://github.com/modelcontextprotocol/protocol/blob/main/proposals/0001-tool-name-validation.md"

def validate_tool_name(name: str):
from pydantic import BaseModel

class MockResult(BaseModel):
is_valid: bool = True
warnings: list = []

return MockResult()


# Probe includes characters on both sides of the separator to mimic real prefixed tool names.
Expand All @@ -90,7 +103,9 @@ def _warn(field_name: str, value: Optional[str]) -> None:
if result.is_valid:
return

warning_text = "; ".join(result.warnings) if result.warnings else "Validation failed"
warning_text = (
"; ".join(result.warnings) if result.warnings else "Validation failed"
)
verbose_logger.warning(
"MCP server '%s' has invalid %s '%s': %s",
server_id,
Expand All @@ -103,7 +118,6 @@ def _warn(field_name: str, value: Optional[str]) -> None:
_warn("server_name", server_name)



def _deserialize_json_dict(data: Any) -> Optional[Dict[str, str]]:
"""
Deserialize optional JSON mappings stored in the database.
Expand Down Expand Up @@ -391,10 +405,13 @@ def _register_openapi_tools(self, spec_path: str, server: MCPServer, base_url: s
# Note: `extra_headers` on MCPServer is a List[str] of header names to forward
# from the client request (not available in this OpenAPI tool generation step).
# `static_headers` is a dict of concrete headers to always send.
headers = merge_mcp_headers(
extra_headers=headers,
static_headers=server.static_headers,
) or {}
headers = (
merge_mcp_headers(
extra_headers=headers,
static_headers=server.static_headers,
)
or {}
)

verbose_logger.debug(
f"Using headers for OpenAPI tools (excluding sensitive values): "
Expand Down
6 changes: 5 additions & 1 deletion litellm/proxy/_experimental/mcp_server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,11 @@
AuthContextMiddleware,
auth_context_var,
)
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager

try:
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
except ImportError:
StreamableHTTPSessionManager = None # type: ignore
from mcp.types import (
CallToolResult,
EmbeddedResource,
Expand Down
18 changes: 14 additions & 4 deletions litellm/proxy/management_endpoints/mcp_management_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,19 @@
MCP_AVAILABLE = False

if MCP_AVAILABLE:
from mcp.shared.tool_name_validation import validate_tool_name
try:
from mcp.shared.tool_name_validation import validate_tool_name # type: ignore
except ImportError:

def validate_tool_name(name: str):
from pydantic import BaseModel

class MockResult(BaseModel):
is_valid: bool = True
warnings: list = []

return MockResult()

from litellm.proxy._experimental.mcp_server.db import (
create_mcp_server,
delete_mcp_server,
Expand Down Expand Up @@ -122,9 +134,7 @@ def _validate_mcp_server_name_fields(payload: Any) -> None:
)
if validation_result.warnings:
error_messages_text = (
error_messages_text
+ "\n"
+ "\n".join(validation_result.warnings)
error_messages_text + "\n" + "\n".join(validation_result.warnings)
)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
Expand Down
111 changes: 65 additions & 46 deletions litellm/proxy/proxy_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,9 +550,9 @@ def generate_feedback_box():
server_root_path = get_server_root_path()
_license_check = LicenseCheck()
premium_user: bool = _license_check.is_premium()
premium_user_data: Optional["EnterpriseLicenseData"] = (
_license_check.airgapped_license_data
)
premium_user_data: Optional[
"EnterpriseLicenseData"
] = _license_check.airgapped_license_data
global_max_parallel_request_retries_env: Optional[str] = os.getenv(
"LITELLM_GLOBAL_MAX_PARALLEL_REQUEST_RETRIES"
)
Expand Down Expand Up @@ -1209,19 +1209,19 @@ async def root_redirect():
config_agents: Optional[List[AgentConfig]] = None
otel_logging = False
prisma_client: Optional[PrismaClient] = None
shared_aiohttp_session: Optional["ClientSession"] = (
None # Global shared session for connection reuse
)
shared_aiohttp_session: Optional[
"ClientSession"
] = None # Global shared session for connection reuse
user_api_key_cache = DualCache(
default_in_memory_ttl=UserAPIKeyCacheTTLEnum.in_memory_cache_ttl.value
)
model_max_budget_limiter = _PROXY_VirtualKeyModelMaxBudgetLimiter(
dual_cache=user_api_key_cache
)
litellm.logging_callback_manager.add_litellm_callback(model_max_budget_limiter)
redis_usage_cache: Optional[RedisCache] = (
None # redis cache used for tracking spend, tpm/rpm limits
)
redis_usage_cache: Optional[
RedisCache
] = None # redis cache used for tracking spend, tpm/rpm limits
polling_via_cache_enabled: Union[Literal["all"], List[str], bool] = False
polling_cache_ttl: int = 3600 # Default 1 hour TTL for polling cache
user_custom_auth = None
Expand Down Expand Up @@ -1560,9 +1560,9 @@ async def _update_team_cache():
_id = "team_id:{}".format(team_id)
try:
# Fetch the existing cost for the given user
existing_spend_obj: Optional[LiteLLM_TeamTable] = (
await user_api_key_cache.async_get_cache(key=_id)
)
existing_spend_obj: Optional[
LiteLLM_TeamTable
] = await user_api_key_cache.async_get_cache(key=_id)
if existing_spend_obj is None:
# do nothing if team not in api key cache
return
Expand Down Expand Up @@ -2856,6 +2856,7 @@ async def _init_policy_engine(

from litellm.proxy.policy_engine.init_policies import init_policies
from litellm.proxy.policy_engine.policy_validator import PolicyValidator

if config is None:
verbose_proxy_logger.debug("Policy engine: config is None, skipping")
return
Expand All @@ -2867,7 +2868,9 @@ async def _init_policy_engine(

policy_attachments_config = config.get("policy_attachments", None)

verbose_proxy_logger.info(f"Policy engine: found {len(policies_config)} policies in config")
verbose_proxy_logger.info(
f"Policy engine: found {len(policies_config)} policies in config"
)

# Initialize policies
await init_policies(
Expand Down Expand Up @@ -4009,10 +4012,10 @@ async def _init_guardrails_in_db(self, prisma_client: PrismaClient):
)

try:
guardrails_in_db: List[Guardrail] = (
await GuardrailRegistry.get_all_guardrails_from_db(
prisma_client=prisma_client
)
guardrails_in_db: List[
Guardrail
] = await GuardrailRegistry.get_all_guardrails_from_db(
prisma_client=prisma_client
)
verbose_proxy_logger.debug(
"guardrails from the DB %s", str(guardrails_in_db)
Expand Down Expand Up @@ -4046,7 +4049,9 @@ async def _init_policies_in_db(self, prisma_client: PrismaClient):
await policy_registry.sync_policies_from_db(prisma_client=prisma_client)

# Sync attachments from DB to in-memory registry
await attachment_registry.sync_attachments_from_db(prisma_client=prisma_client)
await attachment_registry.sync_attachments_from_db(
prisma_client=prisma_client
)

verbose_proxy_logger.debug(
"Successfully synced policies and attachments from DB"
Expand Down Expand Up @@ -4369,9 +4374,9 @@ async def initialize( # noqa: PLR0915
user_api_base = api_base
dynamic_config[user_model]["api_base"] = api_base
if api_version:
os.environ["AZURE_API_VERSION"] = (
api_version # set this for azure - litellm can read this from the env
)
os.environ[
"AZURE_API_VERSION"
] = api_version # set this for azure - litellm can read this from the env
if max_tokens: # model-specific param
dynamic_config[user_model]["max_tokens"] = max_tokens
if temperature: # model-specific param
Expand Down Expand Up @@ -5217,7 +5222,9 @@ async def model_list(

# Include model access groups if requested
if include_model_access_groups:
proxy_model_list = list(set(proxy_model_list + list(model_access_groups.keys())))
proxy_model_list = list(
set(proxy_model_list + list(model_access_groups.keys()))
)

# Get complete model list including wildcard routes if requested
from litellm.proxy.auth.model_checks import get_complete_model_list
Expand Down Expand Up @@ -7674,12 +7681,12 @@ def _enrich_model_info_with_litellm_data(
"""
Enrich a model dictionary with litellm model info (pricing, context window, etc.)
and remove sensitive information.

Args:
model: Model dictionary to enrich
debug: Whether to include debug information like openai_client
llm_router: Optional router instance for debug info

Returns:
Enriched model dictionary with sensitive info removed
"""
Expand All @@ -7689,9 +7696,7 @@ def _enrich_model_info_with_litellm_data(
_openai_client = "None"
if llm_router is not None:
_openai_client = (
llm_router._get_client(
deployment=model, kwargs={}, client_type="async"
)
llm_router._get_client(deployment=model, kwargs={}, client_type="async")
or "None"
)
else:
Expand Down Expand Up @@ -8140,7 +8145,9 @@ async def model_info_v2(
# This must happen before teamId filtering so that direct_access and access_via_team_ids are populated
for i, _model in enumerate(all_models):
all_models[i] = _enrich_model_info_with_litellm_data(
model=_model, debug=debug if debug is not None else False, llm_router=llm_router
model=_model,
debug=debug if debug is not None else False,
llm_router=llm_router,
)

# Apply teamId filter if provided
Expand Down Expand Up @@ -9623,7 +9630,7 @@ def get_logo_url():


@app.get("/get_image", include_in_schema=False)
def get_image():
async def get_image():
"""Get logo to show on admin UI"""

# get current_dir
Expand All @@ -9642,25 +9649,37 @@ def get_image():
if is_non_root and not os.path.exists(default_logo):
default_logo = default_site_logo

cache_dir = assets_dir if is_non_root else current_dir
cache_path = os.path.join(cache_dir, "cached_logo.jpg")

# [OPTIMIZATION] Check if the cached image exists first
if os.path.exists(cache_path):
return FileResponse(cache_path, media_type="image/jpeg")

logo_path = os.getenv("UI_LOGO_PATH", default_logo)
verbose_proxy_logger.debug("Reading logo from path: %s", logo_path)

# Check if the logo path is an HTTP/HTTPS URL
if logo_path.startswith(("http://", "https://")):
# Download the image and cache it
client = HTTPHandler()
response = client.get(logo_path)
if response.status_code == 200:
# Save the image to a local file
cache_dir = assets_dir if is_non_root else current_dir
cache_path = os.path.join(cache_dir, "cached_logo.jpg")
with open(cache_path, "wb") as f:
f.write(response.content)

# Return the cached image as a FileResponse
return FileResponse(cache_path, media_type="image/jpeg")
else:
# Handle the case when the image cannot be downloaded
try:
# Download the image and cache it
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler

async_client = AsyncHTTPHandler(timeout=5.0)
response = await async_client.get(logo_path)
if response.status_code == 200:
# Save the image to a local file
with open(cache_path, "wb") as f:
f.write(response.content)

# Return the cached image as a FileResponse
return FileResponse(cache_path, media_type="image/jpeg")
else:
# Handle the case when the image cannot be downloaded
return FileResponse(default_logo, media_type="image/jpeg")
except Exception as e:
# Handle any exceptions during the download (e.g., timeout, connection error)
verbose_proxy_logger.debug(f"Error downloading logo from {logo_path}: {e}")
return FileResponse(default_logo, media_type="image/jpeg")
else:
# Return the local image file if the logo path is not an HTTP/HTTPS URL
Expand Down Expand Up @@ -10278,9 +10297,9 @@ async def get_config_list(
hasattr(sub_field_info, "description")
and sub_field_info.description is not None
):
nested_fields[idx].field_description = (
sub_field_info.description
)
nested_fields[
idx
].field_description = sub_field_info.description
idx += 1

_stored_in_db = None
Expand Down
Loading
Loading