Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
a8a0c83
fix(common_daily_activity.py): initial commit with working mock BE en…
Jul 7, 2025
6c1b72b
feat(ui/): show mcp server activity on UI
Jul 7, 2025
2e520ed
feat(common_daily_activity.py): return activity by key
Jul 7, 2025
a6afa31
feat(ui/): show top api keys for a given model / mcp server
Jul 7, 2025
929b69b
fix(common_daily_activity.py): use known mcp server names
Jul 7, 2025
2725c46
feat(server.py): log the namespaced tool name (includes server prefix)
Jul 7, 2025
0e0efc2
feat(db_spend_update_writer.py): log by mcp_namespaced_tool_name
Jul 8, 2025
cacd89c
fix(server.py): add key/user metadata to mcp calls
Jul 8, 2025
faec6cc
refactor(common_daily_activity.py): update to return mcp activity in API
Jul 8, 2025
99cf1a2
fix(common_daily_activity.py): handle empty key
Jul 8, 2025
ec93847
fix(common_daily_activity.py): track when api key is empty
Jul 8, 2025
aabef51
test(test_spend_management_endpoints.py): update tests
Jul 8, 2025
f08a8bd
fix: fix ui linting error
Jul 8, 2025
8e5a386
fix: fix linting errors
Jul 8, 2025
986f98f
test: add missing key
Jul 8, 2025
84bb228
build(schema.prisma): add mcp tool tracking
Jul 8, 2025
5fe2d7f
fix(migration.sql): add schema migration file
Jul 8, 2025
ee967ef
feat(server.py): add request logging for mcp calls
Jul 9, 2025
c01b671
fix(new_usage.tsx): fix linting errors
Jul 9, 2025
336f072
fix: fix code qa errors
Jul 9, 2025
70e2f59
Merge branch 'main' into litellm_mcp_demo
Jul 9, 2025
374acb4
fix(activity_metrics.tsx): fix ui linting errors post-merge
Jul 9, 2025
b4e75f9
fix(types/utils.py): fix linting error
Jul 9, 2025
ff4d10e
fix(server.py): always have name
Jul 9, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
-- DropIndex
DROP INDEX "LiteLLM_DailyTagSpend_tag_date_api_key_model_custom_llm_pro_key";

-- DropIndex
DROP INDEX "LiteLLM_DailyTeamSpend_team_id_date_api_key_model_custom_ll_key";

-- DropIndex
DROP INDEX "LiteLLM_DailyUserSpend_user_id_date_api_key_model_custom_ll_key";

-- AlterTable
ALTER TABLE "LiteLLM_DailyTagSpend" ADD COLUMN "mcp_namespaced_tool_name" TEXT,
ALTER COLUMN "model" DROP NOT NULL;

-- AlterTable
ALTER TABLE "LiteLLM_DailyTeamSpend" ADD COLUMN "mcp_namespaced_tool_name" TEXT,
ALTER COLUMN "model" DROP NOT NULL;

-- AlterTable
ALTER TABLE "LiteLLM_DailyUserSpend" ADD COLUMN "mcp_namespaced_tool_name" TEXT,
ALTER COLUMN "model" DROP NOT NULL;

-- AlterTable
ALTER TABLE "LiteLLM_SpendLogs" ADD COLUMN "mcp_namespaced_tool_name" TEXT;

-- CreateIndex
CREATE INDEX "LiteLLM_DailyTagSpend_mcp_namespaced_tool_name_idx" ON "LiteLLM_DailyTagSpend"("mcp_namespaced_tool_name");

-- CreateIndex
CREATE UNIQUE INDEX "LiteLLM_DailyTagSpend_tag_date_api_key_model_custom_llm_pro_key" ON "LiteLLM_DailyTagSpend"("tag", "date", "api_key", "model", "custom_llm_provider", "mcp_namespaced_tool_name");

-- CreateIndex
CREATE INDEX "LiteLLM_DailyTeamSpend_mcp_namespaced_tool_name_idx" ON "LiteLLM_DailyTeamSpend"("mcp_namespaced_tool_name");

-- CreateIndex
CREATE UNIQUE INDEX "LiteLLM_DailyTeamSpend_team_id_date_api_key_model_custom_ll_key" ON "LiteLLM_DailyTeamSpend"("team_id", "date", "api_key", "model", "custom_llm_provider", "mcp_namespaced_tool_name");

-- CreateIndex
CREATE INDEX "LiteLLM_DailyUserSpend_mcp_namespaced_tool_name_idx" ON "LiteLLM_DailyUserSpend"("mcp_namespaced_tool_name");

-- CreateIndex
CREATE UNIQUE INDEX "LiteLLM_DailyUserSpend_user_id_date_api_key_model_custom_ll_key" ON "LiteLLM_DailyUserSpend"("user_id", "date", "api_key", "model", "custom_llm_provider", "mcp_namespaced_tool_name");

25 changes: 16 additions & 9 deletions litellm-proxy-extras/litellm_proxy_extras/schema.prisma
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,7 @@ model LiteLLM_SpendLogs {
response Json? @default("{}")
session_id String?
status String?
mcp_namespaced_tool_name String?
proxy_server_request Json? @default("{}")
@@index([startTime])
@@index([end_user])
Expand Down Expand Up @@ -360,9 +361,10 @@ model LiteLLM_DailyUserSpend {
user_id String?
date String
api_key String
model String
model String?
model_group String?
custom_llm_provider String?
custom_llm_provider String?
mcp_namespaced_tool_name String?
prompt_tokens BigInt @default(0)
completion_tokens BigInt @default(0)
cache_read_input_tokens BigInt @default(0)
Expand All @@ -374,11 +376,12 @@ model LiteLLM_DailyUserSpend {
created_at DateTime @default(now())
updated_at DateTime @updatedAt

@@unique([user_id, date, api_key, model, custom_llm_provider])
@@unique([user_id, date, api_key, model, custom_llm_provider, mcp_namespaced_tool_name])
@@index([date])
@@index([user_id])
@@index([api_key])
@@index([model])
@@index([mcp_namespaced_tool_name])
}

// Track daily team spend metrics per model and key
Expand All @@ -387,9 +390,10 @@ model LiteLLM_DailyTeamSpend {
team_id String?
date String
api_key String
model String
model String?
model_group String?
custom_llm_provider String?
custom_llm_provider String?
mcp_namespaced_tool_name String?
prompt_tokens BigInt @default(0)
completion_tokens BigInt @default(0)
cache_read_input_tokens BigInt @default(0)
Expand All @@ -401,11 +405,12 @@ model LiteLLM_DailyTeamSpend {
created_at DateTime @default(now())
updated_at DateTime @updatedAt

@@unique([team_id, date, api_key, model, custom_llm_provider])
@@unique([team_id, date, api_key, model, custom_llm_provider, mcp_namespaced_tool_name])
@@index([date])
@@index([team_id])
@@index([api_key])
@@index([model])
@@index([mcp_namespaced_tool_name])
}

// Track daily team spend metrics per model and key
Expand All @@ -414,9 +419,10 @@ model LiteLLM_DailyTagSpend {
tag String?
date String
api_key String
model String
model String?
model_group String?
custom_llm_provider String?
custom_llm_provider String?
mcp_namespaced_tool_name String?
prompt_tokens BigInt @default(0)
completion_tokens BigInt @default(0)
cache_read_input_tokens BigInt @default(0)
Expand All @@ -428,11 +434,12 @@ model LiteLLM_DailyTagSpend {
created_at DateTime @default(now())
updated_at DateTime @updatedAt

@@unique([tag, date, api_key, model, custom_llm_provider])
@@unique([tag, date, api_key, model, custom_llm_provider, mcp_namespaced_tool_name])
@@index([date])
@@index([tag])
@@index([api_key])
@@index([model])
@@index([mcp_namespaced_tool_name])
}


Expand Down
122 changes: 84 additions & 38 deletions litellm/proxy/_experimental/mcp_server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from starlette.types import Receive, Scope, Send

from litellm._logging import verbose_logger
from litellm.constants import MCP_TOOL_NAME_PREFIX
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.proxy._experimental.mcp_server.auth.user_api_key_auth_mcp import (
MCPRequestHandler,
Expand Down Expand Up @@ -130,7 +129,9 @@ async def initialize_session_managers():
await _sse_session_manager_cm.__aenter__()

_SESSION_MANAGERS_INITIALIZED = True
verbose_logger.info("MCP Server started with StreamableHTTP and SSE session managers!")
verbose_logger.info(
"MCP Server started with StreamableHTTP and SSE session managers!"
)

async def shutdown_session_managers():
"""Shutdown the session managers."""
Expand Down Expand Up @@ -198,17 +199,48 @@ async def mcp_server_tool_call(
Raises:
HTTPException: If tool not found or arguments missing
"""
from fastapi import Request

from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request
from litellm.proxy.proxy_server import proxy_config

# Validate arguments
user_api_key_auth, mcp_auth_header, _ = get_auth_context()

verbose_logger.debug(
f"MCP mcp_server_tool_call - User API Key Auth from context: {user_api_key_auth}"
)
response = await call_mcp_tool(
name=name,
arguments=arguments,
user_api_key_auth=user_api_key_auth,
mcp_auth_header=mcp_auth_header,
)
try:
# Create a body date for logging
body_data = {"name": name, "arguments": arguments}

request = Request(
scope={
"type": "http",
"method": "POST",
"path": "/mcp/tools/call",
"headers": [(b"content-type", b"application/json")],
}
)
if user_api_key_auth is not None:
data = await add_litellm_data_to_request(
data=body_data,
request=request,
user_api_key_dict=user_api_key_auth,
proxy_config=proxy_config,
)
else:
data = body_data

response = await call_mcp_tool(
user_api_key_auth=user_api_key_auth,
mcp_auth_header=mcp_auth_header,
**data, # for logging
)
except Exception as e:
verbose_logger.exception(f"MCP mcp_server_tool_call - error: {e}")
raise e

return response

########################################################
Expand All @@ -222,7 +254,7 @@ async def mcp_server_tool_call(
async def _get_tools_from_mcp_servers(
user_api_key_auth: Optional[UserAPIKeyAuth],
mcp_auth_header: Optional[str],
mcp_servers: Optional[List[str]]
mcp_servers: Optional[List[str]],
) -> List[MCPTool]:
"""
Helper method to fetch tools from MCP servers based on server filtering criteria.
Expand All @@ -238,12 +270,19 @@ async def _get_tools_from_mcp_servers(
if mcp_servers:
# If mcp_servers header is present, only get tools from specified servers
tools = []
for server_id in await global_mcp_server_manager.get_allowed_mcp_servers(user_api_key_auth):
for server_id in await global_mcp_server_manager.get_allowed_mcp_servers(
user_api_key_auth
):
server = global_mcp_server_manager.get_mcp_server_by_id(server_id)
if server and any(normalize_server_name(server.name) == normalize_server_name(s) for s in mcp_servers):
server_tools = await global_mcp_server_manager._get_tools_from_server(
server=server,
mcp_auth_header=mcp_auth_header,
if server and any(
normalize_server_name(server.name) == normalize_server_name(s)
for s in mcp_servers
):
server_tools = (
await global_mcp_server_manager._get_tools_from_server(
server=server,
mcp_auth_header=mcp_auth_header,
)
)
tools.extend(server_tools)
return tools
Expand Down Expand Up @@ -284,7 +323,7 @@ async def _list_mcp_tools(
tools_from_mcp_servers = await _get_tools_from_mcp_servers(
user_api_key_auth=user_api_key_auth,
mcp_auth_header=mcp_auth_header,
mcp_servers=mcp_servers
mcp_servers=mcp_servers,
)

verbose_logger.debug("TOOLS FROM MCP SERVERS: %s", tools_from_mcp_servers)
Expand All @@ -294,28 +333,31 @@ async def _list_mcp_tools(

@client
async def call_mcp_tool(
name: str,
arguments: Optional[Dict[str, Any]] = None,
user_api_key_auth: Optional[UserAPIKeyAuth] = None,
mcp_auth_header: Optional[str] = None,
**kwargs: Any
name: str,
arguments: Optional[Dict[str, Any]] = None,
user_api_key_auth: Optional[UserAPIKeyAuth] = None,
mcp_auth_header: Optional[str] = None,
**kwargs: Any,
) -> List[Union[MCPTextContent, MCPImageContent, MCPEmbeddedResource]]:
"""
Call a specific tool with the provided arguments (handles prefixed tool names)
"""

if arguments is None:
raise HTTPException(
status_code=400, detail="Request arguments are required"
)

# Remove prefix from tool name for logging and processing
original_tool_name, _ = get_server_name_prefix_tool_mcp(
name)
original_tool_name, server_name_from_prefix = get_server_name_prefix_tool_mcp(
name
)

standard_logging_mcp_tool_call: StandardLoggingMCPToolCall = (
_get_standard_logging_mcp_tool_call(
name=original_tool_name, # Use original name for logging
arguments=arguments,
server_name=server_name_from_prefix,
)
)
litellm_logging_obj: Optional[LiteLLMLoggingObj] = kwargs.get(
Expand All @@ -325,20 +367,16 @@ async def call_mcp_tool(
litellm_logging_obj.model_call_details["mcp_tool_call_metadata"] = (
standard_logging_mcp_tool_call
)
model_name = f"MCP: {MCP_TOOL_NAME_PREFIX}: {standard_logging_mcp_tool_call.get('name') or ''}"
litellm_logging_obj.model = model_name
litellm_logging_obj.model_call_details["model"] = model_name
litellm_logging_obj.model_call_details["custom_llm_provider"] = (
standard_logging_mcp_tool_call.get("mcp_server_name")
)
#########################################################
# Managed MCP Server Tool
# Try managed server tool first (pass the full prefixed name)
# Primary and recommended way to use MCP servers
#########################################################
mcp_server: Optional[MCPServer] = global_mcp_server_manager._get_mcp_server_from_tool_name(name)
mcp_server: Optional[MCPServer] = (
global_mcp_server_manager._get_mcp_server_from_tool_name(name)
)
if mcp_server:
standard_logging_mcp_tool_call["mcp_server_cost_info"] = (mcp_server.mcp_info or {}).get("mcp_server_cost_info")
standard_logging_mcp_tool_call["mcp_server_cost_info"] = (
mcp_server.mcp_info or {}
).get("mcp_server_cost_info")
return await _handle_managed_mcp_tool(
name=name, # Pass the full name (potentially prefixed)
arguments=arguments,
Expand All @@ -355,6 +393,7 @@ async def call_mcp_tool(
def _get_standard_logging_mcp_tool_call(
name: str,
arguments: Dict[str, Any],
server_name: Optional[str],
) -> StandardLoggingMCPToolCall:
mcp_server = global_mcp_server_manager._get_mcp_server_from_tool_name(name)
if mcp_server:
Expand All @@ -364,15 +403,17 @@ def _get_standard_logging_mcp_tool_call(
arguments=arguments,
mcp_server_name=mcp_info.get("server_name"),
mcp_server_logo_url=mcp_info.get("logo_url"),
namespaced_tool_name=f"{server_name}/{name}" if server_name else name,
)
else:
return StandardLoggingMCPToolCall(
name=name,
arguments=arguments,
namespaced_tool_name=f"{server_name}/{name}" if server_name else name,
)

async def _handle_managed_mcp_tool(
name: str,
name: str,
arguments: Dict[str, Any],
user_api_key_auth: Optional[UserAPIKeyAuth] = None,
mcp_auth_header: Optional[str] = None,
Expand All @@ -388,7 +429,7 @@ async def _handle_managed_mcp_tool(
return call_tool_result.content # type: ignore[return-value]

async def _handle_local_mcp_tool(
name: str, arguments: Dict[str, Any]
name: str, arguments: Dict[str, Any]
) -> List[Union[MCPTextContent, MCPImageContent, MCPEmbeddedResource]]:
"""
Handle tool execution for local registry tools
Expand All @@ -404,7 +445,6 @@ async def _handle_local_mcp_tool(
except Exception as e:
return [MCPTextContent(text=f"Error: {str(e)}", type="text")]


async def handle_streamable_http_mcp(
scope: Scope, receive: Receive, send: Send
) -> None:
Expand Down Expand Up @@ -486,7 +526,7 @@ def get_mcp_server_enabled() -> Dict[str, bool]:
########################################################

def set_auth_context(
user_api_key_auth: UserAPIKeyAuth,
user_api_key_auth: UserAPIKeyAuth,
mcp_auth_header: Optional[str] = None,
mcp_servers: Optional[List[str]] = None,
) -> None:
Expand All @@ -505,7 +545,9 @@ def set_auth_context(
)
auth_context_var.set(auth_user)

def get_auth_context() -> Tuple[Optional[UserAPIKeyAuth], Optional[str], Optional[List[str]]]:
def get_auth_context() -> (
Tuple[Optional[UserAPIKeyAuth], Optional[str], Optional[List[str]]]
):
"""
Get the UserAPIKeyAuth from the auth context variable.

Expand All @@ -514,7 +556,11 @@ def get_auth_context() -> Tuple[Optional[UserAPIKeyAuth], Optional[str], Optiona
"""
auth_user = auth_context_var.get()
if auth_user and isinstance(auth_user, MCPAuthenticatedUser):
return auth_user.user_api_key_auth, auth_user.mcp_auth_header, auth_user.mcp_servers
return (
auth_user.user_api_key_auth,
auth_user.mcp_auth_header,
auth_user.mcp_servers,
)
return None, None, None

########################################################
Expand Down
Loading
Loading