diff --git a/litellm-proxy-extras/litellm_proxy_extras/migrations/20250707230009_add_mcp_namespaced_tool_name/migration.sql b/litellm-proxy-extras/litellm_proxy_extras/migrations/20250707230009_add_mcp_namespaced_tool_name/migration.sql new file mode 100644 index 0000000000..3130619a77 --- /dev/null +++ b/litellm-proxy-extras/litellm_proxy_extras/migrations/20250707230009_add_mcp_namespaced_tool_name/migration.sql @@ -0,0 +1,42 @@ +-- DropIndex +DROP INDEX "LiteLLM_DailyTagSpend_tag_date_api_key_model_custom_llm_pro_key"; + +-- DropIndex +DROP INDEX "LiteLLM_DailyTeamSpend_team_id_date_api_key_model_custom_ll_key"; + +-- DropIndex +DROP INDEX "LiteLLM_DailyUserSpend_user_id_date_api_key_model_custom_ll_key"; + +-- AlterTable +ALTER TABLE "LiteLLM_DailyTagSpend" ADD COLUMN "mcp_namespaced_tool_name" TEXT, +ALTER COLUMN "model" DROP NOT NULL; + +-- AlterTable +ALTER TABLE "LiteLLM_DailyTeamSpend" ADD COLUMN "mcp_namespaced_tool_name" TEXT, +ALTER COLUMN "model" DROP NOT NULL; + +-- AlterTable +ALTER TABLE "LiteLLM_DailyUserSpend" ADD COLUMN "mcp_namespaced_tool_name" TEXT, +ALTER COLUMN "model" DROP NOT NULL; + +-- AlterTable +ALTER TABLE "LiteLLM_SpendLogs" ADD COLUMN "mcp_namespaced_tool_name" TEXT; + +-- CreateIndex +CREATE INDEX "LiteLLM_DailyTagSpend_mcp_namespaced_tool_name_idx" ON "LiteLLM_DailyTagSpend"("mcp_namespaced_tool_name"); + +-- CreateIndex +CREATE UNIQUE INDEX "LiteLLM_DailyTagSpend_tag_date_api_key_model_custom_llm_pro_key" ON "LiteLLM_DailyTagSpend"("tag", "date", "api_key", "model", "custom_llm_provider", "mcp_namespaced_tool_name"); + +-- CreateIndex +CREATE INDEX "LiteLLM_DailyTeamSpend_mcp_namespaced_tool_name_idx" ON "LiteLLM_DailyTeamSpend"("mcp_namespaced_tool_name"); + +-- CreateIndex +CREATE UNIQUE INDEX "LiteLLM_DailyTeamSpend_team_id_date_api_key_model_custom_ll_key" ON "LiteLLM_DailyTeamSpend"("team_id", "date", "api_key", "model", "custom_llm_provider", "mcp_namespaced_tool_name"); + +-- CreateIndex +CREATE INDEX "LiteLLM_DailyUserSpend_mcp_namespaced_tool_name_idx" ON "LiteLLM_DailyUserSpend"("mcp_namespaced_tool_name"); + +-- CreateIndex +CREATE UNIQUE INDEX "LiteLLM_DailyUserSpend_user_id_date_api_key_model_custom_ll_key" ON "LiteLLM_DailyUserSpend"("user_id", "date", "api_key", "model", "custom_llm_provider", "mcp_namespaced_tool_name"); + diff --git a/litellm-proxy-extras/litellm_proxy_extras/schema.prisma b/litellm-proxy-extras/litellm_proxy_extras/schema.prisma index 32c50e7ffe..74fd7157fd 100644 --- a/litellm-proxy-extras/litellm_proxy_extras/schema.prisma +++ b/litellm-proxy-extras/litellm_proxy_extras/schema.prisma @@ -262,6 +262,7 @@ model LiteLLM_SpendLogs { response Json? @default("{}") session_id String? status String? + mcp_namespaced_tool_name String? proxy_server_request Json? @default("{}") @@index([startTime]) @@index([end_user]) @@ -360,9 +361,10 @@ model LiteLLM_DailyUserSpend { user_id String? date String api_key String - model String + model String? model_group String? - custom_llm_provider String? + custom_llm_provider String? + mcp_namespaced_tool_name String? prompt_tokens BigInt @default(0) completion_tokens BigInt @default(0) cache_read_input_tokens BigInt @default(0) @@ -374,11 +376,12 @@ model LiteLLM_DailyUserSpend { created_at DateTime @default(now()) updated_at DateTime @updatedAt - @@unique([user_id, date, api_key, model, custom_llm_provider]) + @@unique([user_id, date, api_key, model, custom_llm_provider, mcp_namespaced_tool_name]) @@index([date]) @@index([user_id]) @@index([api_key]) @@index([model]) + @@index([mcp_namespaced_tool_name]) } // Track daily team spend metrics per model and key @@ -387,9 +390,10 @@ model LiteLLM_DailyTeamSpend { team_id String? date String api_key String - model String + model String? model_group String? - custom_llm_provider String? + custom_llm_provider String? + mcp_namespaced_tool_name String? prompt_tokens BigInt @default(0) completion_tokens BigInt @default(0) cache_read_input_tokens BigInt @default(0) @@ -401,11 +405,12 @@ model LiteLLM_DailyTeamSpend { created_at DateTime @default(now()) updated_at DateTime @updatedAt - @@unique([team_id, date, api_key, model, custom_llm_provider]) + @@unique([team_id, date, api_key, model, custom_llm_provider, mcp_namespaced_tool_name]) @@index([date]) @@index([team_id]) @@index([api_key]) @@index([model]) + @@index([mcp_namespaced_tool_name]) } // Track daily team spend metrics per model and key @@ -414,9 +419,10 @@ model LiteLLM_DailyTagSpend { tag String? date String api_key String - model String + model String? model_group String? - custom_llm_provider String? + custom_llm_provider String? + mcp_namespaced_tool_name String? prompt_tokens BigInt @default(0) completion_tokens BigInt @default(0) cache_read_input_tokens BigInt @default(0) @@ -428,11 +434,12 @@ model LiteLLM_DailyTagSpend { created_at DateTime @default(now()) updated_at DateTime @updatedAt - @@unique([tag, date, api_key, model, custom_llm_provider]) + @@unique([tag, date, api_key, model, custom_llm_provider, mcp_namespaced_tool_name]) @@index([date]) @@index([tag]) @@index([api_key]) @@index([model]) + @@index([mcp_namespaced_tool_name]) } diff --git a/litellm/proxy/_experimental/mcp_server/server.py b/litellm/proxy/_experimental/mcp_server/server.py index 3ee9817524..6ec68aba1f 100644 --- a/litellm/proxy/_experimental/mcp_server/server.py +++ b/litellm/proxy/_experimental/mcp_server/server.py @@ -11,7 +11,6 @@ from starlette.types import Receive, Scope, Send from litellm._logging import verbose_logger -from litellm.constants import MCP_TOOL_NAME_PREFIX from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj from litellm.proxy._experimental.mcp_server.auth.user_api_key_auth_mcp import ( MCPRequestHandler, @@ -130,7 +129,9 @@ async def initialize_session_managers(): await _sse_session_manager_cm.__aenter__() _SESSION_MANAGERS_INITIALIZED = True - verbose_logger.info("MCP Server started with StreamableHTTP and SSE session managers!") + verbose_logger.info( + "MCP Server started with StreamableHTTP and SSE session managers!" + ) async def shutdown_session_managers(): """Shutdown the session managers.""" @@ -198,17 +199,48 @@ async def mcp_server_tool_call( Raises: HTTPException: If tool not found or arguments missing """ + from fastapi import Request + + from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request + from litellm.proxy.proxy_server import proxy_config + # Validate arguments user_api_key_auth, mcp_auth_header, _ = get_auth_context() + verbose_logger.debug( f"MCP mcp_server_tool_call - User API Key Auth from context: {user_api_key_auth}" ) - response = await call_mcp_tool( - name=name, - arguments=arguments, - user_api_key_auth=user_api_key_auth, - mcp_auth_header=mcp_auth_header, - ) + try: + # Create a body date for logging + body_data = {"name": name, "arguments": arguments} + + request = Request( + scope={ + "type": "http", + "method": "POST", + "path": "/mcp/tools/call", + "headers": [(b"content-type", b"application/json")], + } + ) + if user_api_key_auth is not None: + data = await add_litellm_data_to_request( + data=body_data, + request=request, + user_api_key_dict=user_api_key_auth, + proxy_config=proxy_config, + ) + else: + data = body_data + + response = await call_mcp_tool( + user_api_key_auth=user_api_key_auth, + mcp_auth_header=mcp_auth_header, + **data, # for logging + ) + except Exception as e: + verbose_logger.exception(f"MCP mcp_server_tool_call - error: {e}") + raise e + return response ######################################################## @@ -222,7 +254,7 @@ async def mcp_server_tool_call( async def _get_tools_from_mcp_servers( user_api_key_auth: Optional[UserAPIKeyAuth], mcp_auth_header: Optional[str], - mcp_servers: Optional[List[str]] + mcp_servers: Optional[List[str]], ) -> List[MCPTool]: """ Helper method to fetch tools from MCP servers based on server filtering criteria. @@ -238,12 +270,19 @@ async def _get_tools_from_mcp_servers( if mcp_servers: # If mcp_servers header is present, only get tools from specified servers tools = [] - for server_id in await global_mcp_server_manager.get_allowed_mcp_servers(user_api_key_auth): + for server_id in await global_mcp_server_manager.get_allowed_mcp_servers( + user_api_key_auth + ): server = global_mcp_server_manager.get_mcp_server_by_id(server_id) - if server and any(normalize_server_name(server.name) == normalize_server_name(s) for s in mcp_servers): - server_tools = await global_mcp_server_manager._get_tools_from_server( - server=server, - mcp_auth_header=mcp_auth_header, + if server and any( + normalize_server_name(server.name) == normalize_server_name(s) + for s in mcp_servers + ): + server_tools = ( + await global_mcp_server_manager._get_tools_from_server( + server=server, + mcp_auth_header=mcp_auth_header, + ) ) tools.extend(server_tools) return tools @@ -284,7 +323,7 @@ async def _list_mcp_tools( tools_from_mcp_servers = await _get_tools_from_mcp_servers( user_api_key_auth=user_api_key_auth, mcp_auth_header=mcp_auth_header, - mcp_servers=mcp_servers + mcp_servers=mcp_servers, ) verbose_logger.debug("TOOLS FROM MCP SERVERS: %s", tools_from_mcp_servers) @@ -294,28 +333,31 @@ async def _list_mcp_tools( @client async def call_mcp_tool( - name: str, - arguments: Optional[Dict[str, Any]] = None, - user_api_key_auth: Optional[UserAPIKeyAuth] = None, - mcp_auth_header: Optional[str] = None, - **kwargs: Any + name: str, + arguments: Optional[Dict[str, Any]] = None, + user_api_key_auth: Optional[UserAPIKeyAuth] = None, + mcp_auth_header: Optional[str] = None, + **kwargs: Any, ) -> List[Union[MCPTextContent, MCPImageContent, MCPEmbeddedResource]]: """ Call a specific tool with the provided arguments (handles prefixed tool names) """ + if arguments is None: raise HTTPException( status_code=400, detail="Request arguments are required" ) # Remove prefix from tool name for logging and processing - original_tool_name, _ = get_server_name_prefix_tool_mcp( - name) + original_tool_name, server_name_from_prefix = get_server_name_prefix_tool_mcp( + name + ) standard_logging_mcp_tool_call: StandardLoggingMCPToolCall = ( _get_standard_logging_mcp_tool_call( name=original_tool_name, # Use original name for logging arguments=arguments, + server_name=server_name_from_prefix, ) ) litellm_logging_obj: Optional[LiteLLMLoggingObj] = kwargs.get( @@ -325,20 +367,16 @@ async def call_mcp_tool( litellm_logging_obj.model_call_details["mcp_tool_call_metadata"] = ( standard_logging_mcp_tool_call ) - model_name = f"MCP: {MCP_TOOL_NAME_PREFIX}: {standard_logging_mcp_tool_call.get('name') or ''}" - litellm_logging_obj.model = model_name - litellm_logging_obj.model_call_details["model"] = model_name - litellm_logging_obj.model_call_details["custom_llm_provider"] = ( - standard_logging_mcp_tool_call.get("mcp_server_name") - ) - ######################################################### - # Managed MCP Server Tool # Try managed server tool first (pass the full prefixed name) # Primary and recommended way to use MCP servers ######################################################### - mcp_server: Optional[MCPServer] = global_mcp_server_manager._get_mcp_server_from_tool_name(name) + mcp_server: Optional[MCPServer] = ( + global_mcp_server_manager._get_mcp_server_from_tool_name(name) + ) if mcp_server: - standard_logging_mcp_tool_call["mcp_server_cost_info"] = (mcp_server.mcp_info or {}).get("mcp_server_cost_info") + standard_logging_mcp_tool_call["mcp_server_cost_info"] = ( + mcp_server.mcp_info or {} + ).get("mcp_server_cost_info") return await _handle_managed_mcp_tool( name=name, # Pass the full name (potentially prefixed) arguments=arguments, @@ -355,6 +393,7 @@ async def call_mcp_tool( def _get_standard_logging_mcp_tool_call( name: str, arguments: Dict[str, Any], + server_name: Optional[str], ) -> StandardLoggingMCPToolCall: mcp_server = global_mcp_server_manager._get_mcp_server_from_tool_name(name) if mcp_server: @@ -364,15 +403,17 @@ def _get_standard_logging_mcp_tool_call( arguments=arguments, mcp_server_name=mcp_info.get("server_name"), mcp_server_logo_url=mcp_info.get("logo_url"), + namespaced_tool_name=f"{server_name}/{name}" if server_name else name, ) else: return StandardLoggingMCPToolCall( name=name, arguments=arguments, + namespaced_tool_name=f"{server_name}/{name}" if server_name else name, ) async def _handle_managed_mcp_tool( - name: str, + name: str, arguments: Dict[str, Any], user_api_key_auth: Optional[UserAPIKeyAuth] = None, mcp_auth_header: Optional[str] = None, @@ -388,7 +429,7 @@ async def _handle_managed_mcp_tool( return call_tool_result.content # type: ignore[return-value] async def _handle_local_mcp_tool( - name: str, arguments: Dict[str, Any] + name: str, arguments: Dict[str, Any] ) -> List[Union[MCPTextContent, MCPImageContent, MCPEmbeddedResource]]: """ Handle tool execution for local registry tools @@ -404,7 +445,6 @@ async def _handle_local_mcp_tool( except Exception as e: return [MCPTextContent(text=f"Error: {str(e)}", type="text")] - async def handle_streamable_http_mcp( scope: Scope, receive: Receive, send: Send ) -> None: @@ -486,7 +526,7 @@ def get_mcp_server_enabled() -> Dict[str, bool]: ######################################################## def set_auth_context( - user_api_key_auth: UserAPIKeyAuth, + user_api_key_auth: UserAPIKeyAuth, mcp_auth_header: Optional[str] = None, mcp_servers: Optional[List[str]] = None, ) -> None: @@ -505,7 +545,9 @@ def set_auth_context( ) auth_context_var.set(auth_user) - def get_auth_context() -> Tuple[Optional[UserAPIKeyAuth], Optional[str], Optional[List[str]]]: + def get_auth_context() -> ( + Tuple[Optional[UserAPIKeyAuth], Optional[str], Optional[List[str]]] + ): """ Get the UserAPIKeyAuth from the auth context variable. @@ -514,7 +556,11 @@ def get_auth_context() -> Tuple[Optional[UserAPIKeyAuth], Optional[str], Optiona """ auth_user = auth_context_var.get() if auth_user and isinstance(auth_user, MCPAuthenticatedUser): - return auth_user.user_api_key_auth, auth_user.mcp_auth_header, auth_user.mcp_servers + return ( + auth_user.user_api_key_auth, + auth_user.mcp_auth_header, + auth_user.mcp_servers, + ) return None, None, None ######################################################## diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index 0cee7907c8..1d25faf359 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -2176,6 +2176,7 @@ class SpendLogsPayload(TypedDict): model: str model_id: Optional[str] model_group: Optional[str] + mcp_namespaced_tool_name: Optional[str] api_base: str user: str metadata: str # json str @@ -3056,8 +3057,9 @@ class DefaultInternalUserParams(LiteLLMPydanticObjectBase): class BaseDailySpendTransaction(TypedDict): date: str api_key: str - model: str + model: Optional[str] model_group: Optional[str] + mcp_namespaced_tool_name: Optional[str] custom_llm_provider: Optional[str] # token count metrics diff --git a/litellm/proxy/db/db_spend_update_writer.py b/litellm/proxy/db/db_spend_update_writer.py index 9a7b14272d..a5398fa7f2 100644 --- a/litellm/proxy/db/db_spend_update_writer.py +++ b/litellm/proxy/db/db_spend_update_writer.py @@ -775,6 +775,8 @@ async def _commit_spend_updates_to_db( # noqa: PLR0915 e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj ) + # fmt: off + @overload @staticmethod async def _update_daily_spend( @@ -786,7 +788,7 @@ async def _update_daily_spend( entity_id_field: str, table_name: str, unique_constraint_name: str, - ) -> None: + ) -> None: ... @overload @@ -814,8 +816,9 @@ async def _update_daily_spend( entity_id_field: str, table_name: str, unique_constraint_name: str, - ) -> None: + ) -> None: ... + # fmt: on @staticmethod async def _update_daily_spend( @@ -870,6 +873,9 @@ async def _update_daily_spend( "custom_llm_provider": transaction.get( "custom_llm_provider" ), + "mcp_namespaced_tool_name": transaction.get( + "mcp_namespaced_tool_name" + ), } } @@ -881,8 +887,11 @@ async def _update_daily_spend( entity_id_field: entity_id, "date": transaction["date"], "api_key": transaction["api_key"], - "model": transaction["model"], + "model": transaction.get("model"), "model_group": transaction.get("model_group"), + "mcp_namespaced_tool_name": transaction.get( + "mcp_namespaced_tool_name" + ), "custom_llm_provider": transaction.get( "custom_llm_provider" ), @@ -898,13 +907,13 @@ async def _update_daily_spend( # Add cache-related fields if they exist if "cache_read_input_tokens" in transaction: - common_data[ - "cache_read_input_tokens" - ] = transaction.get("cache_read_input_tokens", 0) + common_data["cache_read_input_tokens"] = ( + transaction.get("cache_read_input_tokens", 0) + ) if "cache_creation_input_tokens" in transaction: - common_data[ - "cache_creation_input_tokens" - ] = transaction.get("cache_creation_input_tokens", 0) + common_data["cache_creation_input_tokens"] = ( + transaction.get("cache_creation_input_tokens", 0) + ) # Create update data structure update_data = { @@ -993,7 +1002,7 @@ async def update_daily_user_spend( entity_type="user", entity_id_field="user_id", table_name="litellm_dailyuserspend", - unique_constraint_name="user_id_date_api_key_model_custom_llm_provider", + unique_constraint_name="user_id_date_api_key_model_custom_llm_provider_mcp_namespaced_tool_name", ) @staticmethod @@ -1014,7 +1023,7 @@ async def update_daily_team_spend( entity_type="team", entity_id_field="team_id", table_name="litellm_dailyteamspend", - unique_constraint_name="team_id_date_api_key_model_custom_llm_provider", + unique_constraint_name="team_id_date_api_key_model_custom_llm_provider_mcp_namespaced_tool_name", ) @staticmethod @@ -1035,7 +1044,7 @@ async def update_daily_tag_spend( entity_type="tag", entity_id_field="tag", table_name="litellm_dailytagspend", - unique_constraint_name="tag_date_api_key_model_custom_llm_provider", + unique_constraint_name="tag_date_api_key_model_custom_llm_provider_mcp_namespaced_tool_name", ) async def _common_add_spend_log_transaction_to_daily_transaction( @@ -1044,7 +1053,7 @@ async def _common_add_spend_log_transaction_to_daily_transaction( prisma_client: PrismaClient, type: Literal["user", "team", "request_tags"] = "user", ) -> Optional[BaseDailySpendTransaction]: - common_expected_keys = ["startTime", "api_key", "model", "custom_llm_provider"] + common_expected_keys = ["startTime", "api_key"] if type == "user": expected_keys = ["user", *common_expected_keys] elif type == "team": @@ -1053,13 +1062,28 @@ async def _common_add_spend_log_transaction_to_daily_transaction( expected_keys = ["request_tags", *common_expected_keys] else: raise ValueError(f"Invalid type: {type}") - if not all(key in payload for key in expected_keys): verbose_proxy_logger.debug( f"Missing expected keys: {expected_keys}, in payload, skipping from daily_user_spend_transactions" ) return None + any_expected_keys = ["model", "mcp_namespaced_tool_name"] + if not any(key in payload for key in any_expected_keys): + verbose_proxy_logger.debug( + f"Missing any expected keys: {any_expected_keys}, in payload, skipping from daily_user_spend_transactions" + ) + return None + elif "mcp_namespaced_tool_name" in payload: + pass + elif "model" in payload and ( + "custom_llm_provider" not in payload or "model_group" not in payload + ): + verbose_proxy_logger.debug( + "Missing custom_llm_provider or model_group in payload, skipping from daily_user_spend_transactions" + ) + return None + request_status = prisma_client.get_request_status(payload) verbose_proxy_logger.info(f"Logged request status: {request_status}") _metadata: SpendLogsMetadata = json.loads(payload["metadata"]) @@ -1078,9 +1102,10 @@ async def _common_add_spend_log_transaction_to_daily_transaction( daily_transaction = BaseDailySpendTransaction( date=date, api_key=payload["api_key"], - model=payload["model"], - model_group=payload["model_group"], - custom_llm_provider=payload["custom_llm_provider"], + model=payload.get("model", None), + model_group=payload.get("model_group", None), + mcp_namespaced_tool_name=payload.get("mcp_namespaced_tool_name", None), + custom_llm_provider=payload.get("custom_llm_provider", None), prompt_tokens=payload["prompt_tokens"], completion_tokens=payload["completion_tokens"], spend=payload["spend"], diff --git a/litellm/proxy/management_endpoints/common_daily_activity.py b/litellm/proxy/management_endpoints/common_daily_activity.py index 5efd31b232..754169a819 100644 --- a/litellm/proxy/management_endpoints/common_daily_activity.py +++ b/litellm/proxy/management_endpoints/common_daily_activity.py @@ -44,16 +44,81 @@ def update_breakdown_metrics( """Updates breakdown metrics for a single record using the existing update_metrics function""" # Update model breakdown - if record.model not in breakdown.models: + if record.model and record.model not in breakdown.models: breakdown.models[record.model] = MetricWithMetadata( metrics=SpendMetrics(), metadata=model_metadata.get( record.model, {} ), # Add any model-specific metadata here ) - breakdown.models[record.model].metrics = update_metrics( - breakdown.models[record.model].metrics, record - ) + if record.model: + breakdown.models[record.model].metrics = update_metrics( + breakdown.models[record.model].metrics, record + ) + + # Update API key breakdown for this model + if record.api_key not in breakdown.models[record.model].api_key_breakdown: + breakdown.models[record.model].api_key_breakdown[record.api_key] = ( + KeyMetricWithMetadata( + metrics=SpendMetrics(), + metadata=KeyMetadata( + key_alias=api_key_metadata.get(record.api_key, {}).get( + "key_alias", None + ), + team_id=api_key_metadata.get(record.api_key, {}).get( + "team_id", None + ), + ), + ) + ) + breakdown.models[record.model].api_key_breakdown[record.api_key].metrics = ( + update_metrics( + breakdown.models[record.model] + .api_key_breakdown[record.api_key] + .metrics, + record, + ) + ) + + if record.mcp_namespaced_tool_name: + if record.mcp_namespaced_tool_name not in breakdown.mcp_servers: + breakdown.mcp_servers[record.mcp_namespaced_tool_name] = MetricWithMetadata( + metrics=SpendMetrics(), + metadata={}, + ) + breakdown.mcp_servers[record.mcp_namespaced_tool_name].metrics = update_metrics( + breakdown.mcp_servers[record.mcp_namespaced_tool_name].metrics, record + ) + + # Update API key breakdown for this MCP server + if ( + record.api_key + not in breakdown.mcp_servers[ + record.mcp_namespaced_tool_name + ].api_key_breakdown + ): + breakdown.mcp_servers[record.mcp_namespaced_tool_name].api_key_breakdown[ + record.api_key + ] = KeyMetricWithMetadata( + metrics=SpendMetrics(), + metadata=KeyMetadata( + key_alias=api_key_metadata.get(record.api_key, {}).get( + "key_alias", None + ), + team_id=api_key_metadata.get(record.api_key, {}).get( + "team_id", None + ), + ), + ) + + breakdown.mcp_servers[record.mcp_namespaced_tool_name].api_key_breakdown[ + record.api_key + ].metrics = update_metrics( + breakdown.mcp_servers[record.mcp_namespaced_tool_name] + .api_key_breakdown[record.api_key] + .metrics, + record, + ) # Update provider breakdown provider = record.custom_llm_provider or "unknown" @@ -68,6 +133,28 @@ def update_breakdown_metrics( breakdown.providers[provider].metrics, record ) + # Update API key breakdown for this provider + if record.api_key not in breakdown.providers[provider].api_key_breakdown: + breakdown.providers[provider].api_key_breakdown[record.api_key] = ( + KeyMetricWithMetadata( + metrics=SpendMetrics(), + metadata=KeyMetadata( + key_alias=api_key_metadata.get(record.api_key, {}).get( + "key_alias", None + ), + team_id=api_key_metadata.get(record.api_key, {}).get( + "team_id", None + ), + ), + ) + ) + breakdown.providers[provider].api_key_breakdown[record.api_key].metrics = ( + update_metrics( + breakdown.providers[provider].api_key_breakdown[record.api_key].metrics, + record, + ) + ) + # Update api key breakdown if record.api_key not in breakdown.api_keys: breakdown.api_keys[record.api_key] = KeyMetricWithMetadata( @@ -92,14 +179,40 @@ def update_breakdown_metrics( if entity_value not in breakdown.entities: breakdown.entities[entity_value] = MetricWithMetadata( metrics=SpendMetrics(), - metadata=entity_metadata_field.get(entity_value, {}) - if entity_metadata_field - else {}, + metadata=( + entity_metadata_field.get(entity_value, {}) + if entity_metadata_field + else {} + ), ) breakdown.entities[entity_value].metrics = update_metrics( breakdown.entities[entity_value].metrics, record ) + # Update API key breakdown for this entity + if record.api_key not in breakdown.entities[entity_value].api_key_breakdown: + breakdown.entities[entity_value].api_key_breakdown[record.api_key] = ( + KeyMetricWithMetadata( + metrics=SpendMetrics(), + metadata=KeyMetadata( + key_alias=api_key_metadata.get(record.api_key, {}).get( + "key_alias", None + ), + team_id=api_key_metadata.get(record.api_key, {}).get( + "team_id", None + ), + ), + ) + ) + breakdown.entities[entity_value].api_key_breakdown[record.api_key].metrics = ( + update_metrics( + breakdown.entities[entity_value] + .api_key_breakdown[record.api_key] + .metrics, + record, + ) + ) + return breakdown @@ -131,6 +244,7 @@ async def get_daily_activity( exclude_entity_ids: Optional[List[str]] = None, ) -> SpendAnalyticsPaginatedResponse: """Common function to get daily activity for any entity type.""" + if prisma_client is None: raise HTTPException( status_code=500, @@ -181,6 +295,22 @@ async def get_daily_activity( take=page_size, ) + # # for 50% of the records, set the mcp_server_id to a random value + # mcp_server_dict = {"Zapier_Gmail_MCP", "Stripe_MCP"} + # import random + + # for idx, record in enumerate(daily_spend_data): + # record = LiteLLM_DailyUserSpend(**record.model_dump()) + # if random.random() < 0.5: + # record.mcp_server_id = random.choice(list(mcp_server_dict)) + # record.model = None + # record.model_group = None + # record.prompt_tokens = 0 + # record.completion_tokens = 0 + # record.cache_read_input_tokens = 0 + # record.cache_creation_input_tokens = 0 + # daily_spend_data[idx] = record + # Get all unique API keys from the spend data api_keys = set() for record in daily_spend_data: diff --git a/litellm/proxy/management_endpoints/internal_user_endpoints.py b/litellm/proxy/management_endpoints/internal_user_endpoints.py index 15cb3d2ad1..1549bd9e20 100644 --- a/litellm/proxy/management_endpoints/internal_user_endpoints.py +++ b/litellm/proxy/management_endpoints/internal_user_endpoints.py @@ -1451,16 +1451,17 @@ def update_breakdown_metrics( """Updates breakdown metrics for a single record using the existing update_metrics function""" # Update model breakdown - if record.model not in breakdown.models: - breakdown.models[record.model] = MetricWithMetadata( - metrics=SpendMetrics(), - metadata=model_metadata.get( - record.model, {} - ), # Add any model-specific metadata here + if record.model: + if record.model not in breakdown.models: + breakdown.models[record.model] = MetricWithMetadata( + metrics=SpendMetrics(), + metadata=model_metadata.get( + record.model, {} + ), # Add any model-specific metadata here + ) + breakdown.models[record.model].metrics = update_metrics( + breakdown.models[record.model].metrics, record ) - breakdown.models[record.model].metrics = update_metrics( - breakdown.models[record.model].metrics, record - ) # Update provider breakdown provider = record.custom_llm_provider or "unknown" diff --git a/litellm/proxy/schema.prisma b/litellm/proxy/schema.prisma index 32c50e7ffe..74fd7157fd 100644 --- a/litellm/proxy/schema.prisma +++ b/litellm/proxy/schema.prisma @@ -262,6 +262,7 @@ model LiteLLM_SpendLogs { response Json? @default("{}") session_id String? status String? + mcp_namespaced_tool_name String? proxy_server_request Json? @default("{}") @@index([startTime]) @@index([end_user]) @@ -360,9 +361,10 @@ model LiteLLM_DailyUserSpend { user_id String? date String api_key String - model String + model String? model_group String? - custom_llm_provider String? + custom_llm_provider String? + mcp_namespaced_tool_name String? prompt_tokens BigInt @default(0) completion_tokens BigInt @default(0) cache_read_input_tokens BigInt @default(0) @@ -374,11 +376,12 @@ model LiteLLM_DailyUserSpend { created_at DateTime @default(now()) updated_at DateTime @updatedAt - @@unique([user_id, date, api_key, model, custom_llm_provider]) + @@unique([user_id, date, api_key, model, custom_llm_provider, mcp_namespaced_tool_name]) @@index([date]) @@index([user_id]) @@index([api_key]) @@index([model]) + @@index([mcp_namespaced_tool_name]) } // Track daily team spend metrics per model and key @@ -387,9 +390,10 @@ model LiteLLM_DailyTeamSpend { team_id String? date String api_key String - model String + model String? model_group String? - custom_llm_provider String? + custom_llm_provider String? + mcp_namespaced_tool_name String? prompt_tokens BigInt @default(0) completion_tokens BigInt @default(0) cache_read_input_tokens BigInt @default(0) @@ -401,11 +405,12 @@ model LiteLLM_DailyTeamSpend { created_at DateTime @default(now()) updated_at DateTime @updatedAt - @@unique([team_id, date, api_key, model, custom_llm_provider]) + @@unique([team_id, date, api_key, model, custom_llm_provider, mcp_namespaced_tool_name]) @@index([date]) @@index([team_id]) @@index([api_key]) @@index([model]) + @@index([mcp_namespaced_tool_name]) } // Track daily team spend metrics per model and key @@ -414,9 +419,10 @@ model LiteLLM_DailyTagSpend { tag String? date String api_key String - model String + model String? model_group String? - custom_llm_provider String? + custom_llm_provider String? + mcp_namespaced_tool_name String? prompt_tokens BigInt @default(0) completion_tokens BigInt @default(0) cache_read_input_tokens BigInt @default(0) @@ -428,11 +434,12 @@ model LiteLLM_DailyTagSpend { created_at DateTime @default(now()) updated_at DateTime @updatedAt - @@unique([tag, date, api_key, model, custom_llm_provider]) + @@unique([tag, date, api_key, model, custom_llm_provider, mcp_namespaced_tool_name]) @@index([date]) @@index([tag]) @@index([api_key]) @@index([model]) + @@index([mcp_namespaced_tool_name]) } diff --git a/litellm/proxy/spend_tracking/spend_tracking_utils.py b/litellm/proxy/spend_tracking/spend_tracking_utils.py index 98ec274380..9f2c7772e8 100644 --- a/litellm/proxy/spend_tracking/spend_tracking_utils.py +++ b/litellm/proxy/spend_tracking/spend_tracking_utils.py @@ -286,6 +286,13 @@ def get_logging_payload( # noqa: PLR0915 id = f"{id}_cache_hit{time.time()}" # SpendLogs does not allow duplicate request_id + mcp_namespaced_tool_name = None + mcp_tool_call_metadata = clean_metadata.get("mcp_tool_call_metadata", {}) + if mcp_tool_call_metadata is not None: + mcp_namespaced_tool_name = mcp_tool_call_metadata.get( + "namespaced_tool_name", None + ) + try: payload: SpendLogsPayload = SpendLogsPayload( request_id=str(id), @@ -311,6 +318,7 @@ def get_logging_payload( # noqa: PLR0915 api_base=litellm_params.get("api_base", ""), model_group=_model_group, model_id=_model_id, + mcp_namespaced_tool_name=mcp_namespaced_tool_name, requester_ip_address=clean_metadata.get("requester_ip_address", None), custom_llm_provider=kwargs.get("custom_llm_provider", ""), messages=_get_messages_for_spend_logs_payload( diff --git a/litellm/types/proxy/management_endpoints/common_daily_activity.py b/litellm/types/proxy/management_endpoints/common_daily_activity.py index 6213087f64..fc87f63fdf 100644 --- a/litellm/types/proxy/management_endpoints/common_daily_activity.py +++ b/litellm/types/proxy/management_endpoints/common_daily_activity.py @@ -31,10 +31,6 @@ class MetricBase(BaseModel): metrics: SpendMetrics -class MetricWithMetadata(MetricBase): - metadata: Dict[str, Any] = Field(default_factory=dict) - - class KeyMetadata(BaseModel): """Metadata for a key""" @@ -48,9 +44,20 @@ class KeyMetricWithMetadata(MetricBase): metadata: KeyMetadata = Field(default_factory=KeyMetadata) +class MetricWithMetadata(MetricBase): + metadata: Dict[str, Any] = Field(default_factory=dict) + # API key breakdown for this metric (e.g., which API keys are using this MCP server) + api_key_breakdown: Dict[str, KeyMetricWithMetadata] = Field( + default_factory=dict + ) # api_key -> {metrics, metadata} + + class BreakdownMetrics(BaseModel): """Breakdown of spend by different dimensions""" + mcp_servers: Dict[str, MetricWithMetadata] = Field( + default_factory=dict + ) # mcp_server -> {metrics, metadata} models: Dict[str, MetricWithMetadata] = Field( default_factory=dict ) # model -> {metrics, metadata} @@ -96,7 +103,8 @@ class LiteLLM_DailyUserSpend(BaseModel): user_id: str date: str api_key: str - model: str + mcp_server_id: Optional[str] = None + model: Optional[str] = None model_group: Optional[str] = None custom_llm_provider: Optional[str] = None prompt_tokens: int = 0 diff --git a/litellm/types/utils.py b/litellm/types/utils.py index 515eca1e06..70b6c97a05 100644 --- a/litellm/types/utils.py +++ b/litellm/types/utils.py @@ -1821,7 +1821,7 @@ class StandardLoggingUserAPIKeyMetadata(TypedDict): user_api_key_request_route: Optional[str] -class StandardLoggingMCPToolCall(TypedDict, total=False): +class StandardLoggingMCPToolCall(TypedDict, total=False): name: str """ Name of the tool to call @@ -1847,6 +1847,13 @@ class StandardLoggingMCPToolCall(TypedDict, total=False): (this is to render the logo on the logs page on litellm ui) """ + namespaced_tool_name: Optional[str] + """ + Namespaced tool name of the MCP tool that the tool call was made to + + Includes the server name prefix if it exists - eg. `deepwiki-mcp/get_page_content` + """ + mcp_server_cost_info: Optional[MCPServerCostInfo] """ Cost per query for the MCP server tool call diff --git a/schema.prisma b/schema.prisma index 32c50e7ffe..74fd7157fd 100644 --- a/schema.prisma +++ b/schema.prisma @@ -262,6 +262,7 @@ model LiteLLM_SpendLogs { response Json? @default("{}") session_id String? status String? + mcp_namespaced_tool_name String? proxy_server_request Json? @default("{}") @@index([startTime]) @@index([end_user]) @@ -360,9 +361,10 @@ model LiteLLM_DailyUserSpend { user_id String? date String api_key String - model String + model String? model_group String? - custom_llm_provider String? + custom_llm_provider String? + mcp_namespaced_tool_name String? prompt_tokens BigInt @default(0) completion_tokens BigInt @default(0) cache_read_input_tokens BigInt @default(0) @@ -374,11 +376,12 @@ model LiteLLM_DailyUserSpend { created_at DateTime @default(now()) updated_at DateTime @updatedAt - @@unique([user_id, date, api_key, model, custom_llm_provider]) + @@unique([user_id, date, api_key, model, custom_llm_provider, mcp_namespaced_tool_name]) @@index([date]) @@index([user_id]) @@index([api_key]) @@index([model]) + @@index([mcp_namespaced_tool_name]) } // Track daily team spend metrics per model and key @@ -387,9 +390,10 @@ model LiteLLM_DailyTeamSpend { team_id String? date String api_key String - model String + model String? model_group String? - custom_llm_provider String? + custom_llm_provider String? + mcp_namespaced_tool_name String? prompt_tokens BigInt @default(0) completion_tokens BigInt @default(0) cache_read_input_tokens BigInt @default(0) @@ -401,11 +405,12 @@ model LiteLLM_DailyTeamSpend { created_at DateTime @default(now()) updated_at DateTime @updatedAt - @@unique([team_id, date, api_key, model, custom_llm_provider]) + @@unique([team_id, date, api_key, model, custom_llm_provider, mcp_namespaced_tool_name]) @@index([date]) @@index([team_id]) @@index([api_key]) @@index([model]) + @@index([mcp_namespaced_tool_name]) } // Track daily team spend metrics per model and key @@ -414,9 +419,10 @@ model LiteLLM_DailyTagSpend { tag String? date String api_key String - model String + model String? model_group String? - custom_llm_provider String? + custom_llm_provider String? + mcp_namespaced_tool_name String? prompt_tokens BigInt @default(0) completion_tokens BigInt @default(0) cache_read_input_tokens BigInt @default(0) @@ -428,11 +434,12 @@ model LiteLLM_DailyTagSpend { created_at DateTime @default(now()) updated_at DateTime @updatedAt - @@unique([tag, date, api_key, model, custom_llm_provider]) + @@unique([tag, date, api_key, model, custom_llm_provider, mcp_namespaced_tool_name]) @@index([date]) @@index([tag]) @@index([api_key]) @@index([model]) + @@index([mcp_namespaced_tool_name]) } diff --git a/tests/logging_callback_tests/gcs_pub_sub_body/spend_logs_payload.json b/tests/logging_callback_tests/gcs_pub_sub_body/spend_logs_payload.json index 0ee27de808..b25080df0d 100644 --- a/tests/logging_callback_tests/gcs_pub_sub_body/spend_logs_payload.json +++ b/tests/logging_callback_tests/gcs_pub_sub_body/spend_logs_payload.json @@ -26,5 +26,6 @@ "messages": "{}", "response": "{}", "proxy_server_request": "{}", - "status": "success" + "status": "success", + "mcp_namespaced_tool_name": null } \ No newline at end of file diff --git a/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_server.py b/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_server.py new file mode 100644 index 0000000000..c1114d1443 --- /dev/null +++ b/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_server.py @@ -0,0 +1,131 @@ +import asyncio +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from litellm.proxy._types import UserAPIKeyAuth + + +@pytest.mark.asyncio +async def test_mcp_server_tool_call_body_contains_request_data(): + """Test that proxy_server_request body contains name and arguments""" + try: + from litellm.proxy._experimental.mcp_server.server import ( + mcp_server_tool_call, + set_auth_context, + ) + except ImportError: + pytest.skip("MCP server not available") + + # Setup test data + tool_name = "test_tool" + tool_arguments = {"param1": "value1", "param2": 123} + + # Mock user auth + user_api_key_auth = UserAPIKeyAuth(api_key="test_key", user_id="test_user") + set_auth_context(user_api_key_auth) + + # Mock the add_litellm_data_to_request function to capture the data + captured_data = {} + + async def mock_add_litellm_data_to_request( + data, request, user_api_key_dict, proxy_config + ): + captured_data.update(data) + # Simulate the proxy_server_request creation + captured_data["proxy_server_request"] = { + "url": str(request.url), + "method": request.method, + "headers": {}, + "body": data.copy(), # This is what we want to test + } + return captured_data + + # Mock the call_mcp_tool function to avoid actual tool execution + async def mock_call_mcp_tool(*args, **kwargs): + return [{"type": "text", "text": "mocked response"}] + + with patch( + "litellm.proxy.litellm_pre_call_utils.add_litellm_data_to_request", + mock_add_litellm_data_to_request, + ): + with patch( + "litellm.proxy._experimental.mcp_server.server.call_mcp_tool", + mock_call_mcp_tool, + ): + with patch( + "litellm.proxy.proxy_server.proxy_config", + MagicMock(), + ): + # Call the function + await mcp_server_tool_call(tool_name, tool_arguments) + + # Verify the body contains the expected data + assert "proxy_server_request" in captured_data + assert "body" in captured_data["proxy_server_request"] + + body = captured_data["proxy_server_request"]["body"] + assert body["name"] == tool_name + assert body["arguments"] == tool_arguments + + +@pytest.mark.asyncio +async def test_mcp_server_tool_call_body_with_none_arguments(): + """Test that proxy_server_request body handles None arguments correctly""" + try: + from litellm.proxy._experimental.mcp_server.server import ( + mcp_server_tool_call, + set_auth_context, + ) + except ImportError: + pytest.skip("MCP server not available") + + # Setup test data + tool_name = "test_tool_no_args" + tool_arguments = None + + # Mock user auth + user_api_key_auth = UserAPIKeyAuth(api_key="test_key", user_id="test_user") + set_auth_context(user_api_key_auth) + + # Mock the add_litellm_data_to_request function to capture the data + captured_data = {} + + async def mock_add_litellm_data_to_request( + data, request, user_api_key_dict, proxy_config + ): + captured_data.update(data) + captured_data["proxy_server_request"] = { + "url": str(request.url), + "method": request.method, + "headers": {}, + "body": data.copy(), + } + return captured_data + + # Mock the call_mcp_tool function + async def mock_call_mcp_tool(*args, **kwargs): + return [{"type": "text", "text": "mocked response"}] + + with patch( + "litellm.proxy.litellm_pre_call_utils.add_litellm_data_to_request", + mock_add_litellm_data_to_request, + ): + with patch( + "litellm.proxy._experimental.mcp_server.server.call_mcp_tool", + mock_call_mcp_tool, + ): + with patch( + "litellm.proxy.proxy_server.proxy_config", + MagicMock(), + ): + # Call the function + await mcp_server_tool_call(tool_name, tool_arguments) + + # Verify the body contains the expected data + assert "proxy_server_request" in captured_data + assert "body" in captured_data["proxy_server_request"] + + body = captured_data["proxy_server_request"]["body"] + assert body["name"] == tool_name + assert body["arguments"] == tool_arguments # Should be None diff --git a/tests/test_litellm/proxy/spend_tracking/test_spend_management_endpoints.py b/tests/test_litellm/proxy/spend_tracking/test_spend_management_endpoints.py index 2af9588337..93be9717e9 100644 --- a/tests/test_litellm/proxy/spend_tracking/test_spend_management_endpoints.py +++ b/tests/test_litellm/proxy/spend_tracking/test_spend_management_endpoints.py @@ -763,6 +763,7 @@ async def test_spend_logs_payload_e2e(self): "response": "{}", "proxy_server_request": "{}", "status": "success", + "mcp_namespaced_tool_name": None, } ) @@ -855,6 +856,7 @@ async def test_spend_logs_payload_success_log_with_api_base(self, monkeypatch): "response": "{}", "proxy_server_request": "{}", "status": "success", + "mcp_namespaced_tool_name": None, } ) @@ -945,9 +947,13 @@ async def test_spend_logs_payload_success_log_with_router(self): "response": "{}", "proxy_server_request": "{}", "status": "success", + "mcp_namespaced_tool_name": None, } ) + print(f"payload: {payload}") + print(f"expected_payload: {expected_payload}") + differences = _compare_nested_dicts( payload, expected_payload, ignore_keys=ignored_keys ) @@ -1085,8 +1091,8 @@ async def test_global_spend_keys_endpoint_limit_validation(client, monkeypatch): async def test_view_spend_logs_summarize_parameter(client, monkeypatch): """Test the new summarize parameter in the /spend/logs endpoint""" import datetime - from datetime import timezone, timedelta - + from datetime import timedelta, timezone + # Mock spend logs data mock_spend_logs = [ { @@ -1096,7 +1102,9 @@ async def test_view_spend_logs_summarize_parameter(client, monkeypatch): "user": "test_user_1", "team_id": "team1", "spend": 0.05, - "startTime": (datetime.datetime.now(timezone.utc) - timedelta(days=1)).isoformat(), + "startTime": ( + datetime.datetime.now(timezone.utc) - timedelta(days=1) + ).isoformat(), "model": "gpt-3.5-turbo", "prompt_tokens": 100, "completion_tokens": 50, @@ -1109,23 +1117,25 @@ async def test_view_spend_logs_summarize_parameter(client, monkeypatch): "user": "test_user_1", "team_id": "team1", "spend": 0.10, - "startTime": (datetime.datetime.now(timezone.utc) - timedelta(days=1)).isoformat(), + "startTime": ( + datetime.datetime.now(timezone.utc) - timedelta(days=1) + ).isoformat(), "model": "gpt-4", "prompt_tokens": 200, "completion_tokens": 100, "total_tokens": 300, }, ] - + # Mock for unsummarized data (summarize=false) class MockDB: def __init__(self): self.litellm_spendlogs = self - + async def find_many(self, *args, **kwargs): # Return individual log entries when summarize=false return mock_spend_logs - + async def group_by(self, *args, **kwargs): # Return grouped data when summarize=true # Simplified mock response for grouped data @@ -1134,17 +1144,17 @@ async def group_by(self, *args, **kwargs): { "api_key": "sk-test-key", "user": "test_user_1", - "model": "gpt-3.5-turbo", + "model": "gpt-3.5-turbo", "startTime": yesterday.strftime("%Y-%m-%dT%H:%M:%S.%fZ"), - "_sum": {"spend": 0.05} + "_sum": {"spend": 0.05}, }, { "api_key": "sk-test-key", - "user": "test_user_1", + "user": "test_user_1", "model": "gpt-4", "startTime": yesterday.strftime("%Y-%m-%dT%H:%M:%S.%fZ"), - "_sum": {"spend": 0.10} - } + "_sum": {"spend": 0.10}, + }, ] class MockPrismaClient: @@ -1156,7 +1166,9 @@ def __init__(self): monkeypatch.setattr("litellm.proxy.proxy_server.prisma_client", mock_prisma_client) # Set up test dates - start_date = (datetime.datetime.now(timezone.utc) - timedelta(days=2)).strftime("%Y-%m-%d") + start_date = (datetime.datetime.now(timezone.utc) - timedelta(days=2)).strftime( + "%Y-%m-%d" + ) end_date = datetime.datetime.now(timezone.utc).strftime("%Y-%m-%d") # Test 1: summarize=false should return individual log entries @@ -1169,10 +1181,10 @@ def __init__(self): }, headers={"Authorization": "Bearer sk-test"}, ) - + assert response.status_code == 200 data = response.json() - + # Should return the raw log entries assert isinstance(data, list) assert len(data) == 2 @@ -1180,7 +1192,7 @@ def __init__(self): assert data[1]["id"] == "log2" assert data[0]["request_id"] == "req1" assert data[1]["request_id"] == "req2" - + # Test 2: summarize=true should return grouped data response = client.get( "/spend/logs", @@ -1191,10 +1203,10 @@ def __init__(self): }, headers={"Authorization": "Bearer sk-test"}, ) - + assert response.status_code == 200 data = response.json() - + # Should return grouped/summarized data assert isinstance(data, list) # The structure should be different - grouped by date with aggregated spend @@ -1202,7 +1214,7 @@ def __init__(self): assert "spend" in data[0] assert "users" in data[0] assert "models" in data[0] - + # Test 3: default behavior (no summarize parameter) should maintain backward compatibility response = client.get( "/spend/logs", @@ -1212,10 +1224,10 @@ def __init__(self): }, headers={"Authorization": "Bearer sk-test"}, ) - + assert response.status_code == 200 data = response.json() - + # Should return grouped/summarized data (same as summarize=true) assert isinstance(data, list) assert "startTime" in data[0] diff --git a/ui/litellm-dashboard/src/components/activity_metrics.tsx b/ui/litellm-dashboard/src/components/activity_metrics.tsx index 1c4fbfa480..680c600419 100644 --- a/ui/litellm-dashboard/src/components/activity_metrics.tsx +++ b/ui/litellm-dashboard/src/components/activity_metrics.tsx @@ -1,20 +1,10 @@ -import React from "react"; -import { - Card, - Grid, - Text, - Title, -} from "@tremor/react"; -import { AreaChart, BarChart } from "@tremor/react"; +import React from 'react'; +import { Card, Grid, Text, Title, Accordion, AccordionHeader, AccordionBody } from '@tremor/react'; +import { AreaChart, BarChart } from '@tremor/react'; +import { SpendMetrics, DailyData, ModelActivityData, MetricWithMetadata, KeyMetricWithMetadata, TopApiKeyData } from './usage/types'; +import { Collapse } from 'antd'; +import { formatNumberWithCommas } from '@/utils/dataUtils'; import type { CustomTooltipProps } from "@tremor/react"; -import { - SpendMetrics, - DailyData, - ModelActivityData, - KeyMetricWithMetadata, -} from "./usage/types"; -import { Collapse } from "antd"; -import { formatNumberWithCommas } from "@/utils/dataUtils"; import { valueFormatter, valueFormatterSpend, @@ -191,6 +181,35 @@ const ModelSection = ({ + {/* Top API Keys Section */} + {metrics.top_api_keys && metrics.top_api_keys.length > 0 && ( + + Top API Keys by Spend +
+
+ {metrics.top_api_keys.map((keyData, index) => ( +
+
+ + {keyData.key_alias || `${keyData.api_key.substring(0, 10)}...`} + + {keyData.team_id && ( + Team: {keyData.team_id} + )} +
+
+ ${formatNumberWithCommas(keyData.spend, 2)} + + {keyData.requests.toLocaleString()} requests | {keyData.tokens.toLocaleString()} tokens + +
+
+ ))} +
+
+
+ )} + {/* Charts */} @@ -542,10 +561,7 @@ const formatKeyLabel = ( }; // Process data function -export const processActivityData = ( - dailyActivity: { results: DailyData[] }, - key: "models" | "api_keys" -): Record => { +export const processActivityData = (dailyActivity: { results: DailyData[] }, key: "models" | "api_keys" | "mcp_servers"): Record => { const modelMetrics: Record = {}; dailyActivity.results.forEach((day) => { @@ -565,7 +581,8 @@ export const processActivityData = ( total_spend: 0, total_cache_read_input_tokens: 0, total_cache_creation_input_tokens: 0, - daily_data: [], + top_api_keys: [], + daily_data: [] }; } @@ -605,6 +622,41 @@ export const processActivityData = ( }); }); + // Process API key breakdowns for each metric (skip if key is 'api_keys' to avoid duplication) + if (key !== 'api_keys') { + Object.entries(modelMetrics).forEach(([model, _]) => { + const apiKeyBreakdown: Record = {}; + + // Aggregate API key data across all days + dailyActivity.results.forEach((day) => { + const modelData = day.breakdown[key]?.[model]; + if (modelData && 'api_key_breakdown' in modelData) { + Object.entries(modelData.api_key_breakdown || {}).forEach(([apiKey, keyData]) => { + if (!apiKeyBreakdown[apiKey]) { + apiKeyBreakdown[apiKey] = { + api_key: apiKey, + key_alias: keyData.metadata.key_alias, + team_id: keyData.metadata.team_id, + spend: 0, + requests: 0, + tokens: 0, + }; + } + + apiKeyBreakdown[apiKey].spend += keyData.metrics.spend; + apiKeyBreakdown[apiKey].requests += keyData.metrics.api_requests; + apiKeyBreakdown[apiKey].tokens += keyData.metrics.total_tokens; + }); + } + }); + + // Sort by spend and take top 5 + modelMetrics[model].top_api_keys = Object.values(apiKeyBreakdown) + .sort((a, b) => b.spend - a.spend) + .slice(0, 5); + }); + } + // Sort daily data Object.values(modelMetrics).forEach((metrics) => { metrics.daily_data.sort( diff --git a/ui/litellm-dashboard/src/components/entity_usage.tsx b/ui/litellm-dashboard/src/components/entity_usage.tsx index 4b43b3293a..9f5ec20282 100644 --- a/ui/litellm-dashboard/src/components/entity_usage.tsx +++ b/ui/litellm-dashboard/src/components/entity_usage.tsx @@ -22,14 +22,10 @@ import { Subtitle, } from "@tremor/react"; import UsageDatePicker from "./shared/usage_date_picker"; -import { Select } from "antd"; -import { ActivityMetrics, processActivityData } from "./activity_metrics"; -import { - DailyData, - KeyMetricWithMetadata, - EntityMetricWithMetadata, -} from "./usage/types"; -import { tagDailyActivityCall, teamDailyActivityCall } from "./networking"; +import { Select } from 'antd'; +import { ActivityMetrics, processActivityData } from './activity_metrics'; +import { DailyData, BreakdownMetrics, KeyMetricWithMetadata, EntityMetricWithMetadata } from './usage/types'; +import { tagDailyActivityCall, teamDailyActivityCall } from './networking'; import TopKeyView from "./top_key_view"; import { formatNumberWithCommas } from "@/utils/dataUtils"; import { valueFormatterSpend } from "./usage/utils/value_formatters"; @@ -49,13 +45,6 @@ interface EntityMetrics { metadata: Record; } -interface BreakdownMetrics { - models: Record; - providers: Record; - api_keys: Record; - entities: Record; -} - interface ExtendedDailyData extends DailyData { breakdown: BreakdownMetrics; } @@ -203,7 +192,8 @@ const EntityUsage: React.FC = ({ }, metadata: { key_alias: metrics.metadata.key_alias, - }, + team_id: metrics.metadata.team_id || null + } }; } keySpend[key].metrics.spend += metrics.metrics.spend; @@ -297,9 +287,9 @@ const EntityUsage: React.FC = ({ cache_creation_input_tokens: 0, }, metadata: { - alias: data.metadata.team_alias || entity, - id: entity, - }, + alias: (data.metadata as any).team_alias || entity, + id: entity + } }; } entitySpend[entity].metrics.spend += data.metrics.spend; diff --git a/ui/litellm-dashboard/src/components/new_usage.tsx b/ui/litellm-dashboard/src/components/new_usage.tsx index 180a99bef3..7a7d437ab1 100644 --- a/ui/litellm-dashboard/src/components/new_usage.tsx +++ b/ui/litellm-dashboard/src/components/new_usage.tsx @@ -126,6 +126,7 @@ const NewUsagePage: React.FC = ({ cache_creation_input_tokens: 0, }, metadata: {}, + api_key_breakdown: {} }; } modelSpend[model].metrics.spend += metrics.metrics.spend; @@ -162,24 +163,24 @@ const NewUsagePage: React.FC = ({ // Calculate provider spend from the breakdown data const getProviderSpend = () => { const providerSpend: { [key: string]: MetricWithMetadata } = {}; - userSpendData.results.forEach((day) => { - Object.entries(day.breakdown.providers || {}).forEach( - ([provider, metrics]) => { - if (!providerSpend[provider]) { - providerSpend[provider] = { - metrics: { - spend: 0, - prompt_tokens: 0, - completion_tokens: 0, - total_tokens: 0, - api_requests: 0, - successful_requests: 0, - failed_requests: 0, - cache_read_input_tokens: 0, - cache_creation_input_tokens: 0, - }, - metadata: {}, - }; + userSpendData.results.forEach(day => { + Object.entries(day.breakdown.providers || {}).forEach(([provider, metrics]) => { + if (!providerSpend[provider]) { + providerSpend[provider] = { + metrics: { + spend: 0, + prompt_tokens: 0, + completion_tokens: 0, + total_tokens: 0, + api_requests: 0, + successful_requests: 0, + failed_requests: 0, + cache_read_input_tokens: 0, + cache_creation_input_tokens: 0 + }, + metadata: {}, + api_key_breakdown: {} + }; } providerSpend[provider].metrics.spend += metrics.metrics.spend; providerSpend[provider].metrics.prompt_tokens += @@ -198,8 +199,7 @@ const NewUsagePage: React.FC = ({ metrics.metrics.cache_read_input_tokens || 0; providerSpend[provider].metrics.cache_creation_input_tokens += metrics.metrics.cache_creation_input_tokens || 0; - } - ); + }); }); return Object.entries(providerSpend).map(([provider, metrics]) => ({ @@ -232,7 +232,8 @@ const NewUsagePage: React.FC = ({ }, metadata: { key_alias: metrics.metadata.key_alias, - }, + team_id: null + } }; } keySpend[key].metrics.spend += metrics.metrics.spend; @@ -331,6 +332,7 @@ const NewUsagePage: React.FC = ({ const modelMetrics = processActivityData(userSpendData, "models"); const keyMetrics = processActivityData(userSpendData, "api_keys"); + const mcpServerMetrics = processActivityData(userSpendData, "mcp_servers"); return (
@@ -380,6 +382,7 @@ const NewUsagePage: React.FC = ({ Cost Model Activity Key Activity + MCP Server Activity {/* Cost Panel */} @@ -644,6 +647,9 @@ const NewUsagePage: React.FC = ({ + + + diff --git a/ui/litellm-dashboard/src/components/usage/types.ts b/ui/litellm-dashboard/src/components/usage/types.ts index 46d96b4aa1..3d76517d9c 100644 --- a/ui/litellm-dashboard/src/components/usage/types.ts +++ b/ui/litellm-dashboard/src/components/usage/types.ts @@ -18,21 +18,35 @@ export interface DailyData { export interface BreakdownMetrics { models: { [key: string]: MetricWithMetadata }; + mcp_servers: { [key: string]: MetricWithMetadata }; providers: { [key: string]: MetricWithMetadata }; api_keys: { [key: string]: KeyMetricWithMetadata }; + entities: { [key: string]: MetricWithMetadata }; } export interface MetricWithMetadata { metrics: SpendMetrics; metadata: object; + api_key_breakdown: { [key: string]: KeyMetricWithMetadata }; } export interface KeyMetricWithMetadata { metrics: SpendMetrics; - metadata: { - key_alias: string | null; - team_id?: string | null; - }; + metadata: KeyMetadata; +} + +export interface KeyMetadata { + key_alias: string | null; + team_id: string | null; +} + +export interface TopApiKeyData { + api_key: string; + key_alias: string | null; + team_id: string | null; + spend: number; + requests: number; + tokens: number; } export interface ModelActivityData { @@ -46,6 +60,7 @@ export interface ModelActivityData { prompt_tokens: number; completion_tokens: number; total_spend: number; + top_api_keys: TopApiKeyData[]; daily_data: { date: string; metrics: { @@ -62,11 +77,6 @@ export interface ModelActivityData { }[]; } -export interface KeyMetadata { - key_alias: string | null; - team_id: string | null; -} - export interface EntityMetadata { alias: string; id: string; @@ -76,14 +86,3 @@ export interface EntityMetricWithMetadata { metrics: SpendMetrics; metadata: EntityMetadata; } - -export interface MetricWithMetadata { - metrics: SpendMetrics; - metadata: object; -} - -export interface BreakdownMetrics { - models: { [key: string]: MetricWithMetadata }; - providers: { [key: string]: MetricWithMetadata }; - api_keys: { [key: string]: KeyMetricWithMetadata }; -}