diff --git a/ci_cd/security_scans.sh b/ci_cd/security_scans.sh index 3a212a56f64..340f8e96063 100755 --- a/ci_cd/security_scans.sh +++ b/ci_cd/security_scans.sh @@ -154,6 +154,7 @@ run_grype_scans() { "CVE-2025-15367" # No fix available yet "CVE-2025-12781" # No fix available yet "CVE-2025-11468" # No fix available yet + "CVE-2026-1299" # Python 3.13 email module header injection - not applicable, LiteLLM doesn't use BytesGenerator for email serialization ) # Build JSON array of allowlisted CVE IDs for jq diff --git a/docs/my-website/blog/sub_millisecond_proxy_overhead/index.md b/docs/my-website/blog/sub_millisecond_proxy_overhead/index.md new file mode 100644 index 00000000000..1857383363c --- /dev/null +++ b/docs/my-website/blog/sub_millisecond_proxy_overhead/index.md @@ -0,0 +1,92 @@ +--- +slug: sub-millisecond-proxy-overhead +title: "Achieving Sub-Millisecond Proxy Overhead" +date: 2026-02-02T10:00:00 +authors: + - name: Alexsander Hamir + title: "Performance Engineer, LiteLLM" + url: https://www.linkedin.com/in/alexsander-baptista/ + image_url: https://github.com/AlexsanderHamir.png + - name: Krrish Dholakia + title: "CEO, LiteLLM" + url: https://www.linkedin.com/in/krish-d/ + image_url: https://pbs.twimg.com/profile_images/1298587542745358340/DZv3Oj-h_400x400.jpg + - name: Ishaan Jaff + title: "CTO, LiteLLM" + url: https://www.linkedin.com/in/reffajnaahsi/ + image_url: https://pbs.twimg.com/profile_images/1613813310264340481/lz54oEiB_400x400.jpg +description: "Our Q1 performance target and architectural direction for achieving sub-millisecond proxy overhead on modest hardware." +tags: [performance, architecture] +hide_table_of_contents: false +--- + +![Sidecar architecture: Python control plane vs. 
sidecar hot path](https://raw.githubusercontent.com/AlexsanderHamir/assets/main/Screenshot%202026-02-02%20172554.png) + +# Achieving Sub-Millisecond Proxy Overhead + +## Introduction + +Our Q1 performance target is to aggressively move toward sub-millisecond proxy overhead on a single instance with 4 CPUs and 8 GB of RAM, and to continue pushing that boundary over time. Our broader goal is to make LiteLLM inexpensive to deploy, lightweight, and fast. This post outlines the architectural direction behind that effort. + +Proxy overhead refers to the latency introduced by LiteLLM itself, independent of the upstream provider. + +To measure it, we run the same workload directly against the provider and through LiteLLM at identical QPS (for example, 1,000 QPS) and compare the latency delta. To reduce noise, the load generator, LiteLLM, and a mock LLM endpoint all run on the same machine, ensuring the difference reflects proxy overhead rather than network latency. + +--- + +## Where We're Coming From + +Under the same benchmark originally conducted by [TensorZero](https://www.tensorzero.com/docs/gateway/benchmarks), LiteLLM previously failed at around 1,000 QPS. + +That is no longer the case. Today, LiteLLM can be stress-tested at 1,000 QPS with no failures and can scale up to 5,000 QPS without failures on a 4-CPU, 8-GB RAM single instance setup. + +This establishes a more up to date baseline and provides useful context as we continue working on proxy overhead and overall performance. + +--- + +## Design Choice + +Achieving sub-millisecond proxy overhead with a Python-based system requires being deliberate about where work happens. + +Python is a strong fit for flexibility and extensibility: provider abstraction, configuration-driven routing, and a rich callback ecosystem. These are areas where development velocity and correctness matter more than raw throughput. 
+ +At higher request rates, however, certain classes of work become expensive when executed inside the Python process on every request. Rather than rewriting LiteLLM or introducing complex deployment requirements, we adopt an optional **sidecar architecture**. + +This architectural change is how we intend to make LiteLLM **permanently fast**. While it supports our near-term performance targets, it is a long-term investment. + +Python continues to own: + +- Request validation and normalization +- Model and provider selection +- Callbacks and integrations + +The sidecar owns **performance-critical execution**, such as: + +- Efficient request forwarding +- Connection reuse and pooling +- Enforcing timeouts and limits +- Aggregating high-frequency metrics + +This separation allows each component to focus on what it does best: Python acts as the control plane, while the sidecar handles the hot path. + +--- + +### Why the Sidecar Is Optional + +The sidecar is intentionally **optional**. + +This allows us to ship it incrementally, validate it under real-world workloads, and avoid making it a hard dependency before it is fully battle-tested across all LiteLLM features. + +Just as importantly, this ensures that self-hosting LiteLLM remains simple. The sidecar is bundled and started automatically, requires no additional infrastructure, and can be disabled entirely. From a user's perspective, LiteLLM continues to behave like a single service. + +As of today, the sidecar is an optimization, not a requirement. + +--- + +## Conclusion + +Sub-millisecond proxy overhead is not achieved through a single optimization, but through architectural changes. + +By keeping Python focused on orchestration and extensibility, and offloading performance-critical execution to a sidecar, we establish a foundation for making LiteLLM **permanently fast over time**—even on modest hardware such as a 1-CPU, 2-GB RAM instance, while keeping deployment and self-hosting simple. 
+ +This work extends beyond Q1, and we will continue sharing benchmarks and updates as the architecture evolves. diff --git a/docs/my-website/docs/mcp_semantic_filter.md b/docs/my-website/docs/mcp_semantic_filter.md new file mode 100644 index 00000000000..c58be80a680 --- /dev/null +++ b/docs/my-website/docs/mcp_semantic_filter.md @@ -0,0 +1,158 @@ +import Tabs from '@theme/Tabs'; +import TabItem from '@theme/TabItem'; + +# MCP Semantic Tool Filter + +Automatically filter MCP tools by semantic relevance. When you have many MCP tools registered, LiteLLM semantically matches the user's query against tool descriptions and sends only the most relevant tools to the LLM. + +## How It Works + +Tool search shifts tool selection from a prompt-engineering problem to a retrieval problem. Instead of injecting a large static list of tools into every prompt, the semantic filter: + +1. Builds a semantic index of all available MCP tools on startup +2. On each request, semantically matches the user's query against tool descriptions +3. Returns only the top-K most relevant tools to the LLM + +This approach improves context efficiency, increases reliability by reducing tool confusion, and enables scalability to ecosystems with hundreds or thousands of MCP tools. 
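
The matching step itself is simple. The toy sketch below shows the index-and-match flow, with a bag-of-words vector standing in for the real embedding model; the tool names, scores, and helper functions here are illustrative only, not LiteLLM's implementation:

```python
import math
from collections import Counter

def embed(text: str) -> Counter:
    # Toy embedding: bag-of-words term counts. The real filter uses an
    # embedding model such as text-embedding-3-small.
    return Counter(text.lower().split())

def cosine(a: Counter, b: Counter) -> float:
    # Cosine similarity over sparse term-count vectors.
    dot = sum(a[t] * b[t] for t in a)
    norm = math.sqrt(sum(v * v for v in a.values())) * math.sqrt(sum(v * v for v in b.values()))
    return dot / norm if norm else 0.0

def filter_tools(query: str, tools: dict, top_k: int = 2, threshold: float = 0.1) -> list:
    # Index: one vector per tool description (built once, at startup).
    index = {name: embed(desc) for name, desc in tools.items()}
    q = embed(query)
    scored = sorted(((cosine(q, vec), name) for name, vec in index.items()), reverse=True)
    # Keep the top-K matches that clear the similarity threshold.
    return [name for score, name in scored[:top_k] if score >= threshold]

tools = {
    "wikipedia-search": "search wikipedia articles by keyword",
    "github-search": "search github repositories and issues",
    "slack-post": "post a message to a slack channel",
}
print(filter_tools("search wikipedia for litellm", tools))
# → ['wikipedia-search', 'github-search']
```

The unrelated `slack-post` tool never reaches the LLM, which is the whole point: irrelevant schemas stay out of the prompt.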
+ +```mermaid +sequenceDiagram + participant Client + participant LiteLLM as LiteLLM Proxy + participant SemanticFilter as Semantic Filter + participant MCP as MCP Registry + participant LLM as LLM Provider + + Note over LiteLLM,MCP: Startup: Build Semantic Index + LiteLLM->>MCP: Fetch all registered MCP tools + MCP->>LiteLLM: Return all tools (e.g., 50 tools) + LiteLLM->>SemanticFilter: Build semantic router with embeddings + SemanticFilter->>LLM: Generate embeddings for tool descriptions + LLM->>SemanticFilter: Return embeddings + Note over SemanticFilter: Index ready for fast lookup + + Note over Client,LLM: Request: Semantic Tool Filtering + Client->>LiteLLM: POST /v1/responses with MCP tools + LiteLLM->>SemanticFilter: Expand MCP references (50 tools available) + SemanticFilter->>SemanticFilter: Extract user query from request + SemanticFilter->>LLM: Generate query embedding + LLM->>SemanticFilter: Return query embedding + SemanticFilter->>SemanticFilter: Match query against tool embeddings + SemanticFilter->>LiteLLM: Return top-K tools (e.g., 3 most relevant) + LiteLLM->>LLM: Forward request with filtered tools (3 tools) + LLM->>LiteLLM: Return response + LiteLLM->>Client: Response with headers
x-litellm-semantic-filter: 50->3
x-litellm-semantic-filter-tools: tool1,tool2,tool3 +``` + +## Configuration + +Enable semantic filtering in your LiteLLM config: + +```yaml title="config.yaml" showLineNumbers +litellm_settings: + mcp_semantic_tool_filter: + enabled: true + embedding_model: "text-embedding-3-small" # Model for semantic matching + top_k: 5 # Max tools to return + similarity_threshold: 0.3 # Min similarity score +``` + +**Configuration Options:** +- `enabled` - Enable/disable semantic filtering (default: `false`) +- `embedding_model` - Model for generating embeddings (default: `"text-embedding-3-small"`) +- `top_k` - Maximum number of tools to return (default: `10`) +- `similarity_threshold` - Minimum similarity score for matches (default: `0.3`) + +## Usage + +Use MCP tools normally with the Responses API or Chat Completions. The semantic filter runs automatically: + + + + +```bash title="Responses API with Semantic Filtering" showLineNumbers +curl --location 'http://localhost:4000/v1/responses' \ +--header 'Content-Type: application/json' \ +--header "Authorization: Bearer sk-1234" \ +--data '{ + "model": "gpt-4o", + "input": [ + { + "role": "user", + "content": "give me TLDR of what BerriAI/litellm repo is about", + "type": "message" + } + ], + "tools": [ + { + "type": "mcp", + "server_url": "litellm_proxy", + "require_approval": "never" + } + ], + "tool_choice": "required" +}' +``` + + + + +```bash title="Chat Completions with Semantic Filtering" showLineNumbers +curl --location 'http://localhost:4000/v1/chat/completions' \ +--header 'Content-Type: application/json' \ +--header "Authorization: Bearer sk-1234" \ +--data '{ + "model": "gpt-4o", + "messages": [ + {"role": "user", "content": "Search Wikipedia for LiteLLM"} + ], + "tools": [ + { + "type": "mcp", + "server_url": "litellm_proxy" + } + ] +}' +``` + + + + +## Response Headers + +The semantic filter adds diagnostic headers to every response: + +``` +x-litellm-semantic-filter: 10->3 +x-litellm-semantic-filter-tools: 
wikipedia-fetch,github-search,slack-post +``` + +- **`x-litellm-semantic-filter`** - Shows before→after tool count (e.g., `10->3` means 10 tools were filtered down to 3) +- **`x-litellm-semantic-filter-tools`** - CSV list of the filtered tool names (max 150 chars, clipped with `...` if longer) + +These headers help you understand which tools were selected for each request and verify the filter is working correctly. + +## Example + +If you have 50 MCP tools registered and make a request asking about Wikipedia, the semantic filter will: + +1. Semantically match your query `"Search Wikipedia for LiteLLM"` against all 50 tool descriptions +2. Select the top 5 most relevant tools (e.g., `wikipedia-fetch`, `wikipedia-search`, etc.) +3. Pass only those 5 tools to the LLM +4. Add headers showing `x-litellm-semantic-filter: 50->5` + +This dramatically reduces prompt size while ensuring the LLM has access to the right tools for the task. + +## Performance + +The semantic filter is optimized for production: +- Router builds once on startup (no per-request overhead) +- Semantic matching typically takes under 50ms +- Fails gracefully - returns all tools if filtering fails +- No impact on latency for requests without MCP tools + +## Related + +- [MCP Overview](./mcp.md) - Learn about MCP in LiteLLM +- [MCP Permission Management](./mcp_control.md) - Control tool access by key/team +- [Using MCP](./mcp_usage.md) - Complete MCP usage guide diff --git a/docs/my-website/docs/providers/bedrock.md b/docs/my-website/docs/providers/bedrock.md index 487212ad655..e546ed97656 100644 --- a/docs/my-website/docs/providers/bedrock.md +++ b/docs/my-website/docs/providers/bedrock.md @@ -9,7 +9,7 @@ ALL Bedrock models (Anthropic, Meta, Deepseek, Mistral, Amazon, etc.) are Suppor | Description | Amazon Bedrock is a fully managed service that offers a choice of high-performing foundation models (FMs). 
| | Provider Route on LiteLLM | `bedrock/`, [`bedrock/converse/`](#set-converse--invoke-route), [`bedrock/invoke/`](#set-invoke-route), [`bedrock/converse_like/`](#calling-via-internal-proxy), [`bedrock/llama/`](#deepseek-not-r1), [`bedrock/deepseek_r1/`](#deepseek-r1), [`bedrock/qwen3/`](#qwen3-imported-models), [`bedrock/qwen2/`](./bedrock_imported.md#qwen2-imported-models), [`bedrock/openai/`](./bedrock_imported.md#openai-compatible-imported-models-qwen-25-vl-etc), [`bedrock/moonshot`](./bedrock_imported.md#moonshot-kimi-k2-thinking) | | Provider Doc | [Amazon Bedrock ↗](https://docs.aws.amazon.com/bedrock/latest/userguide/what-is-bedrock.html) | -| Supported OpenAI Endpoints | `/chat/completions`, `/completions`, `/embeddings`, `/images/generations` | +| Supported OpenAI Endpoints | `/chat/completions`, `/completions`, `/embeddings`, `/images/generations`, `/v1/realtime`| | Rerank Endpoint | `/rerank` | | Pass-through Endpoint | [Supported](../pass_through/bedrock.md) | diff --git a/docs/my-website/docs/tutorials/bedrock_realtime_with_audio.md b/docs/my-website/docs/providers/bedrock_realtime_with_audio.md similarity index 98% rename from docs/my-website/docs/tutorials/bedrock_realtime_with_audio.md rename to docs/my-website/docs/providers/bedrock_realtime_with_audio.md index 07e29af5320..a2d9813ffd9 100644 --- a/docs/my-website/docs/tutorials/bedrock_realtime_with_audio.md +++ b/docs/my-website/docs/providers/bedrock_realtime_with_audio.md @@ -1,8 +1,4 @@ -# Call Bedrock Nova Sonic Realtime API with Audio Input/Output - -:::info -Requires LiteLLM Proxy v1.70.1+ -::: +# Bedrock Realtime API ## Overview diff --git a/docs/my-website/docs/proxy/config_settings.md b/docs/my-website/docs/proxy/config_settings.md index 264c7d765b3..385b4b0de32 100644 --- a/docs/my-website/docs/proxy/config_settings.md +++ b/docs/my-website/docs/proxy/config_settings.md @@ -545,6 +545,9 @@ router_settings: | DEFAULT_MAX_TOKENS | Default maximum tokens for LLM calls. 
Default is 4096 | DEFAULT_MAX_TOKENS_FOR_TRITON | Default maximum tokens for Triton models. Default is 2000 | DEFAULT_MAX_REDIS_BATCH_CACHE_SIZE | Default maximum size for redis batch cache. Default is 1000 +| DEFAULT_MCP_SEMANTIC_FILTER_EMBEDDING_MODEL | Default embedding model for MCP semantic tool filtering. Default is "text-embedding-3-small" +| DEFAULT_MCP_SEMANTIC_FILTER_SIMILARITY_THRESHOLD | Default similarity threshold for MCP semantic tool filtering. Default is 0.3 +| DEFAULT_MCP_SEMANTIC_FILTER_TOP_K | Default number of top results to return for MCP semantic tool filtering. Default is 10 | DEFAULT_MOCK_RESPONSE_COMPLETION_TOKEN_COUNT | Default token count for mock response completions. Default is 20 | DEFAULT_MOCK_RESPONSE_PROMPT_TOKEN_COUNT | Default token count for mock response prompts. Default is 10 | DEFAULT_MODEL_CREATED_AT_TIME | Default creation timestamp for models. Default is 1677610602 @@ -802,6 +805,7 @@ router_settings: | MAXIMUM_TRACEBACK_LINES_TO_LOG | Maximum number of lines to log in traceback in LiteLLM Logs UI. Default is 100 | MAX_RETRY_DELAY | Maximum delay in seconds for retrying requests. Default is 8.0 | MAX_LANGFUSE_INITIALIZED_CLIENTS | Maximum number of Langfuse clients to initialize on proxy. Default is 50. This is set since langfuse initializes 1 thread everytime a client is initialized. We've had an incident in the past where we reached 100% cpu utilization because Langfuse was initialized several times. +| MAX_MCP_SEMANTIC_FILTER_TOOLS_HEADER_LENGTH | Maximum header length for MCP semantic filter tools. Default is 150 | MIN_NON_ZERO_TEMPERATURE | Minimum non-zero temperature value. Default is 0.0001 | MINIMUM_PROMPT_CACHE_TOKEN_COUNT | Minimum token count for caching a prompt. Default is 1024 | MISTRAL_API_BASE | Base URL for Mistral API. 
Default is https://api.mistral.ai diff --git a/docs/my-website/docs/proxy/ui_logs.md b/docs/my-website/docs/proxy/ui_logs.md index b6d3d2ae7ca..8cfe818ebfd 100644 --- a/docs/my-website/docs/proxy/ui_logs.md +++ b/docs/my-website/docs/proxy/ui_logs.md @@ -37,6 +37,40 @@ general_settings: +## Tracing Tools + +View which tools were provided and called in your completion requests. + + + +**Example:** Make a completion request with tools: + +```bash +curl -X POST 'http://localhost:4000/chat/completions' \ + -H 'Authorization: Bearer sk-1234' \ + -H 'Content-Type: application/json' \ + -d '{ + "model": "gpt-4", + "messages": [{"role": "user", "content": "What is the weather?"}], + "tools": [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get the current weather", + "parameters": { + "type": "object", + "properties": { + "location": {"type": "string"} + } + } + } + } + ] + }' +``` + +Check the Logs page to see all tools provided and which ones were called. 
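
The equivalent request can be sent from Python using only the standard library. This is a sketch: the key, port, and payload mirror the curl example above, so adjust them for your deployment before calling `send()`:

```python
import json
import urllib.request

# Same payload as the curl example above.
payload = {
    "model": "gpt-4",
    "messages": [{"role": "user", "content": "What is the weather?"}],
    "tools": [
        {
            "type": "function",
            "function": {
                "name": "get_weather",
                "description": "Get the current weather",
                "parameters": {
                    "type": "object",
                    "properties": {"location": {"type": "string"}},
                },
            },
        }
    ],
}

def send(base_url: str = "http://localhost:4000") -> dict:
    # POSTs the payload to a locally running proxy; the demo key below
    # matches the curl example and is a placeholder.
    req = urllib.request.Request(
        f"{base_url}/chat/completions",
        data=json.dumps(payload).encode(),
        headers={
            "Authorization": "Bearer sk-1234",
            "Content-Type": "application/json",
        },
    )
    with urllib.request.urlopen(req) as resp:
        return json.load(resp)
```

After calling `send()`, the Logs page shows the `get_weather` tool as provided, and marks it as called if the model invoked it.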
## Stop storing Error Logs in DB diff --git a/docs/my-website/img/ui_tools.png b/docs/my-website/img/ui_tools.png new file mode 100644 index 00000000000..6f4d0f87410 Binary files /dev/null and b/docs/my-website/img/ui_tools.png differ diff --git a/docs/my-website/sidebars.js b/docs/my-website/sidebars.js index e533665032e..d932b6af250 100644 --- a/docs/my-website/sidebars.js +++ b/docs/my-website/sidebars.js @@ -538,6 +538,7 @@ const sidebars = { items: [ "mcp", "mcp_usage", + "mcp_semantic_filter", "mcp_control", "mcp_cost", "mcp_guardrail", @@ -716,6 +717,7 @@ const sidebars = { "providers/bedrock_agents", "providers/bedrock_writer", "providers/bedrock_batches", + "providers/bedrock_realtime_with_audio", "providers/aws_polly", "providers/bedrock_vector_store", ] diff --git a/litellm-proxy-extras/litellm_proxy_extras/migrations/20260129103648_add_verificationtoken_indexes/migration.sql b/litellm-proxy-extras/litellm_proxy_extras/migrations/20260129103648_add_verificationtoken_indexes/migration.sql deleted file mode 100644 index 572eea9b529..00000000000 --- a/litellm-proxy-extras/litellm_proxy_extras/migrations/20260129103648_add_verificationtoken_indexes/migration.sql +++ /dev/null @@ -1,8 +0,0 @@ --- CreateIndex -CREATE INDEX "LiteLLM_VerificationToken_user_id_team_id_idx" ON "LiteLLM_VerificationToken"("user_id", "team_id"); - --- CreateIndex -CREATE INDEX "LiteLLM_VerificationToken_team_id_idx" ON "LiteLLM_VerificationToken"("team_id"); - --- CreateIndex -CREATE INDEX "LiteLLM_VerificationToken_budget_reset_at_expires_idx" ON "LiteLLM_VerificationToken"("budget_reset_at", "expires"); diff --git a/litellm-proxy-extras/litellm_proxy_extras/schema.prisma b/litellm-proxy-extras/litellm_proxy_extras/schema.prisma index 3b81da10923..b118400b620 100644 --- a/litellm-proxy-extras/litellm_proxy_extras/schema.prisma +++ b/litellm-proxy-extras/litellm_proxy_extras/schema.prisma @@ -305,16 +305,6 @@ model LiteLLM_VerificationToken { litellm_budget_table 
LiteLLM_BudgetTable? @relation(fields: [budget_id], references: [budget_id]) litellm_organization_table LiteLLM_OrganizationTable? @relation(fields: [organization_id], references: [organization_id]) object_permission LiteLLM_ObjectPermissionTable? @relation(fields: [object_permission_id], references: [object_permission_id]) - - // SELECT COUNT(*) FROM (SELECT "public"."LiteLLM_VerificationToken"."token" FROM "public"."LiteLLM_VerificationToken" WHERE ("public"."LiteLLM_VerificationToken"."user_id" = $1 AND ("public"."LiteLLM_VerificationToken"."team_id" IS NULL OR "public"."LiteLLM_VerificationToken"."team_id" <> $2)) OFFSET $3 ) AS "sub" - // SELECT ... FROM "public"."LiteLLM_VerificationToken" WHERE "public"."LiteLLM_VerificationToken"."user_id" = $1 OFFSET $2 - @@index([user_id, team_id]) - - // SELECT ... FROM "public"."LiteLLM_VerificationToken" WHERE "public"."LiteLLM_VerificationToken"."team_id" = $1 OFFSET $2 - @@index([team_id]) - - // SELECT ... FROM "public"."LiteLLM_VerificationToken" WHERE (("public"."LiteLLM_VerificationToken"."expires" IS NULL OR "public"."LiteLLM_VerificationToken"."expires" > $1) AND "public"."LiteLLM_VerificationToken"."budget_reset_at" < $2) OFFSET $3 - @@index([budget_reset_at, expires]) } // Audit table for deleted keys - preserves spend and key information for historical tracking diff --git a/litellm/constants.py b/litellm/constants.py index 3c84547d7ce..6427c367924 100644 --- a/litellm/constants.py +++ b/litellm/constants.py @@ -67,6 +67,20 @@ os.getenv("DEFAULT_REASONING_EFFORT_DISABLE_THINKING_BUDGET", 0) ) +# MCP Semantic Tool Filter Defaults +DEFAULT_MCP_SEMANTIC_FILTER_EMBEDDING_MODEL = str( + os.getenv("DEFAULT_MCP_SEMANTIC_FILTER_EMBEDDING_MODEL", "text-embedding-3-small") +) +DEFAULT_MCP_SEMANTIC_FILTER_TOP_K = int( + os.getenv("DEFAULT_MCP_SEMANTIC_FILTER_TOP_K", 10) +) +DEFAULT_MCP_SEMANTIC_FILTER_SIMILARITY_THRESHOLD = float( + os.getenv("DEFAULT_MCP_SEMANTIC_FILTER_SIMILARITY_THRESHOLD", 0.3) +) 
+MAX_MCP_SEMANTIC_FILTER_TOOLS_HEADER_LENGTH = int( + os.getenv("MAX_MCP_SEMANTIC_FILTER_TOOLS_HEADER_LENGTH", 150) +) + # Gemini model-specific minimal thinking budget constants DEFAULT_REASONING_EFFORT_MINIMAL_THINKING_BUDGET_GEMINI_2_5_FLASH = int( os.getenv("DEFAULT_REASONING_EFFORT_MINIMAL_THINKING_BUDGET_GEMINI_2_5_FLASH", 1) diff --git a/litellm/llms/vertex_ai/gemini/vertex_and_google_ai_studio_gemini.py b/litellm/llms/vertex_ai/gemini/vertex_and_google_ai_studio_gemini.py index a9ac21bb56f..b5a6949f272 100644 --- a/litellm/llms/vertex_ai/gemini/vertex_and_google_ai_studio_gemini.py +++ b/litellm/llms/vertex_ai/gemini/vertex_and_google_ai_studio_gemini.py @@ -478,6 +478,13 @@ def _map_function( # noqa: PLR0915 if "type" in tool and tool["type"] == "computer_use": computer_use_config = {k: v for k, v in tool.items() if k != "type"} tool = {VertexToolName.COMPUTER_USE.value: computer_use_config} + # Handle OpenAI-style web_search and web_search_preview tools + # Transform them to Gemini's googleSearch tool + elif "type" in tool and tool["type"] in ("web_search", "web_search_preview"): + verbose_logger.info( + f"Gemini: Transforming OpenAI-style '{tool['type']}' tool to googleSearch" + ) + tool = {VertexToolName.GOOGLE_SEARCH.value: {}} # Handle tools with 'type' field (OpenAI spec compliance) Ignore this field -> https://github.com/BerriAI/litellm/issues/14644#issuecomment-3342061838 elif "type" in tool: tool = {k: tool[k] for k in tool if k != "type"} diff --git a/litellm/model_prices_and_context_window_backup.json b/litellm/model_prices_and_context_window_backup.json index 6aeb51d5817..485bee4f191 100644 --- a/litellm/model_prices_and_context_window_backup.json +++ b/litellm/model_prices_and_context_window_backup.json @@ -21488,6 +21488,20 @@ "supports_tool_choice": true, "supports_web_search": true }, + "moonshot/kimi-k2.5": { + "cache_read_input_token_cost": 1e-07, + "input_cost_per_token": 6e-07, + "litellm_provider": "moonshot", + "max_input_tokens": 
262144, + "max_output_tokens": 262144, + "max_tokens": 262144, + "mode": "chat", + "output_cost_per_token": 3e-06, + "source": "https://platform.moonshot.ai/docs/pricing/chat", + "supports_function_calling": true, + "supports_tool_choice": true, + "supports_vision": true + }, "moonshot/kimi-latest": { "cache_read_input_token_cost": 1.5e-07, "input_cost_per_token": 2e-06, diff --git a/litellm/proxy/_experimental/mcp_server/semantic_tool_filter.py b/litellm/proxy/_experimental/mcp_server/semantic_tool_filter.py new file mode 100644 index 00000000000..e5cb6a0098d --- /dev/null +++ b/litellm/proxy/_experimental/mcp_server/semantic_tool_filter.py @@ -0,0 +1,250 @@ +""" +Semantic MCP Tool Filtering using semantic-router + +Filters MCP tools semantically for /chat/completions and /responses endpoints. +""" +from typing import TYPE_CHECKING, Any, Dict, List, Optional + +from litellm._logging import verbose_logger + +if TYPE_CHECKING: + from semantic_router.routers import SemanticRouter + + from litellm.router import Router + + +class SemanticMCPToolFilter: + """Filters MCP tools using semantic similarity to reduce context window size.""" + + def __init__( + self, + embedding_model: str, + litellm_router_instance: "Router", + top_k: int = 10, + similarity_threshold: float = 0.3, + enabled: bool = True, + ): + """ + Initialize the semantic tool filter. 
+ + Args: + embedding_model: Model to use for embeddings (e.g., "text-embedding-3-small") + litellm_router_instance: Router instance for embedding generation + top_k: Maximum number of tools to return + similarity_threshold: Minimum similarity score for filtering + enabled: Whether filtering is enabled + """ + self.enabled = enabled + self.top_k = top_k + self.similarity_threshold = similarity_threshold + self.embedding_model = embedding_model + self.router_instance = litellm_router_instance + self.tool_router: Optional["SemanticRouter"] = None + self._tool_map: Dict[str, Any] = {} # MCPTool objects or OpenAI function dicts + + async def build_router_from_mcp_registry(self) -> None: + """Build semantic router from all MCP tools in the registry (no auth checks).""" + from litellm.proxy._experimental.mcp_server.mcp_server_manager import ( + global_mcp_server_manager, + ) + + try: + # Get all servers from registry without auth checks + registry = global_mcp_server_manager.get_registry() + if not registry: + verbose_logger.warning("MCP registry is empty") + self.tool_router = None + return + + # Fetch tools from all servers in parallel + all_tools = [] + for server_id, server in registry.items(): + try: + tools = await global_mcp_server_manager.get_tools_for_server(server_id) + all_tools.extend(tools) + except Exception as e: + verbose_logger.warning(f"Failed to fetch tools from server {server_id}: {e}") + continue + + if not all_tools: + verbose_logger.warning("No MCP tools found in registry") + self.tool_router = None + return + + verbose_logger.info(f"Fetched {len(all_tools)} tools from {len(registry)} MCP servers") + self._build_router(all_tools) + + except Exception as e: + verbose_logger.error(f"Failed to build router from MCP registry: {e}") + self.tool_router = None + raise + + def _extract_tool_info(self, tool) -> tuple[str, str]: + """Extract name and description from MCP tool or OpenAI function dict.""" + name: str + description: str + + if isinstance(tool, 
dict): + # OpenAI function format + name = tool.get("name", "") + description = tool.get("description", name) + else: + # MCPTool object + name = str(tool.name) + description = str(tool.description) if tool.description else str(tool.name) + + return name, description + + def _build_router(self, tools: List) -> None: + """Build semantic router with tools (MCPTool objects or OpenAI function dicts).""" + from semantic_router.routers import SemanticRouter + from semantic_router.routers.base import Route + + from litellm.router_strategy.auto_router.litellm_encoder import ( + LiteLLMRouterEncoder, + ) + + if not tools: + self.tool_router = None + return + + try: + # Convert tools to routes + routes = [] + self._tool_map = {} + + for tool in tools: + name, description = self._extract_tool_info(tool) + self._tool_map[name] = tool + + routes.append( + Route( + name=name, + description=description, + utterances=[description], + score_threshold=self.similarity_threshold, + ) + ) + + self.tool_router = SemanticRouter( + routes=routes, + encoder=LiteLLMRouterEncoder( + litellm_router_instance=self.router_instance, + model_name=self.embedding_model, + score_threshold=self.similarity_threshold, + ), + auto_sync="local", + ) + + verbose_logger.info( + f"Built semantic router with {len(routes)} tools" + ) + + except Exception as e: + verbose_logger.error(f"Failed to build semantic router: {e}") + self.tool_router = None + raise + + async def filter_tools( + self, + query: str, + available_tools: List[Any], + top_k: Optional[int] = None, + ) -> List[Any]: + """ + Filter tools semantically based on query. 
+ + Args: + query: User query to match against tools + available_tools: Full list of available MCP tools + top_k: Override default top_k (optional) + + Returns: + Filtered and ordered list of tools (up to top_k) + """ + # Early returns for cases where we can't/shouldn't filter + if not self.enabled: + return available_tools + + if not available_tools: + return available_tools + + if not query or not query.strip(): + return available_tools + + # Router should be built on startup - if not, something went wrong + if self.tool_router is None: + verbose_logger.warning("Router not initialized - was build_router_from_mcp_registry() called on startup?") + return available_tools + + # Run semantic filtering + try: + limit = top_k or self.top_k + matches = self.tool_router(text=query, limit=limit) + matched_tool_names = self._extract_tool_names_from_matches(matches) + + if not matched_tool_names: + return available_tools + + return self._get_tools_by_names(matched_tool_names, available_tools) + + except Exception as e: + verbose_logger.error(f"Semantic tool filter failed: {e}", exc_info=True) + return available_tools + + def _extract_tool_names_from_matches(self, matches) -> List[str]: + """Extract tool names from semantic router match results.""" + if not matches: + return [] + + # Handle single match + if hasattr(matches, "name") and matches.name: + return [matches.name] + + # Handle list of matches + if isinstance(matches, list): + return [m.name for m in matches if hasattr(m, "name") and m.name] + + return [] + + def _get_tools_by_names( + self, tool_names: List[str], available_tools: List[Any] + ) -> List[Any]: + """Get tools from available_tools by their names, preserving order.""" + # Match tools from available_tools (preserves format - dict or MCPTool) + matched_tools = [] + for tool in available_tools: + tool_name, _ = self._extract_tool_info(tool) + if tool_name in tool_names: + matched_tools.append(tool) + + # Reorder to match semantic router's ordering + tool_map 
= {self._extract_tool_info(t)[0]: t for t in matched_tools} + return [tool_map[name] for name in tool_names if name in tool_map] + + def extract_user_query(self, messages: List[Dict[str, Any]]) -> str: + """ + Extract user query from messages for /chat/completions or /responses. + + Args: + messages: List of message dictionaries (from 'messages' or 'input' field) + + Returns: + Extracted query string + """ + for msg in reversed(messages): + if msg.get("role") == "user": + content = msg.get("content", "") + + if isinstance(content, str): + return content + + if isinstance(content, list): + texts = [ + block.get("text", "") if isinstance(block, dict) else str(block) + for block in content + if isinstance(block, (dict, str)) + ] + return " ".join(texts) + + return "" diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index 045d2fd5f14..9ae95085f55 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -228,6 +228,7 @@ class KeyManagementRoutes(str, enum.Enum): KEY_BLOCK = "/key/block" KEY_UNBLOCK = "/key/unblock" KEY_BULK_UPDATE = "/key/bulk_update" + KEY_RESET_SPEND = "/key/{key_id}/reset_spend" # info and health routes KEY_INFO = "/key/info" @@ -987,6 +988,10 @@ class RegenerateKeyRequest(GenerateKeyRequest): new_master_key: Optional[str] = None +class ResetSpendRequest(LiteLLMPydanticObjectBase): + reset_to: float + + class KeyRequest(LiteLLMPydanticObjectBase): keys: Optional[List[str]] = None key_aliases: Optional[List[str]] = None diff --git a/litellm/proxy/auth/model_checks.py b/litellm/proxy/auth/model_checks.py index af2574d88ee..71ae1348f39 100644 --- a/litellm/proxy/auth/model_checks.py +++ b/litellm/proxy/auth/model_checks.py @@ -64,27 +64,6 @@ def _get_models_from_access_groups( return all_models -def get_access_groups_from_models( - model_access_groups: Dict[str, List[str]], - models: List[str], -) -> List[str]: - """ - Extract access group names from a models list. 
- - Given a models list like ["gpt-4", "beta-models", "claude-v1"] - and access groups like {"beta-models": ["gpt-5", "gpt-6"]}, - returns ["beta-models"]. - - This is used to pass allowed access groups to the router for filtering - deployments during load balancing (GitHub issue #18333). - """ - access_groups = [] - for model in models: - if model in model_access_groups: - access_groups.append(model) - return access_groups - - async def get_mcp_server_ids( user_api_key_dict: UserAPIKeyAuth, ) -> List[str]: @@ -101,6 +80,7 @@ async def get_mcp_server_ids( # Make a direct SQL query to get just the mcp_servers try: + result = await prisma_client.db.litellm_objectpermissiontable.find_unique( where={"object_permission_id": user_api_key_dict.object_permission_id}, ) @@ -196,7 +176,6 @@ def get_complete_model_list( """ unique_models = [] - def append_unique(models): for model in models: if model not in unique_models: @@ -209,7 +188,7 @@ def append_unique(models): else: append_unique(proxy_model_list) if include_model_access_groups: - append_unique(list(model_access_groups.keys())) # TODO: keys order + append_unique(list(model_access_groups.keys())) # TODO: keys order if user_model: append_unique([user_model]) diff --git a/litellm/proxy/hooks/mcp_semantic_filter/ARCHITECTURE.md b/litellm/proxy/hooks/mcp_semantic_filter/ARCHITECTURE.md new file mode 100644 index 00000000000..f2f9a1d4856 --- /dev/null +++ b/litellm/proxy/hooks/mcp_semantic_filter/ARCHITECTURE.md @@ -0,0 +1,96 @@ +# MCP Semantic Tool Filter Architecture + +## Why Filter MCP Tools + +When multiple MCP servers are connected, the proxy may expose hundreds of tools. Sending all tools in every request wastes context window tokens and increases cost. The semantic filter keeps only the top-K most relevant tools based on embedding similarity. 
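
As a back-of-envelope illustration of the context savings, compare the serialized size of the full tool list against a filtered top-K subset. The ~4-chars-per-token ratio and the toy schema below are assumptions for illustration only:

```python
import json

def estimate_tool_tokens(tools: list, chars_per_token: int = 4) -> int:
    # Rough heuristic: serialized schema length divided by ~4 chars per token.
    return sum(len(json.dumps(tool)) // chars_per_token for tool in tools)

# Hypothetical tool schema, repeated to simulate a large registry.
schema = {
    "type": "function",
    "function": {
        "name": "get_issue",
        "description": "Fetch a single issue from the tracker by id",
        "parameters": {
            "type": "object",
            "properties": {"issue_id": {"type": "string"}},
        },
    },
}
all_tools = [schema] * 100   # everything the registry exposes
filtered = [schema] * 10     # the top-K subset the filter keeps

print(estimate_tool_tokens(all_tools), "->", estimate_tool_tokens(filtered))
```

Even with identical schemas, sending 10 tools instead of 100 cuts the tool portion of the prompt by an order of magnitude on every request.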
+ +```mermaid +sequenceDiagram + participant Client + participant Hook as SemanticToolFilterHook + participant Filter as SemanticMCPToolFilter + participant Router as semantic-router + participant LLM + + Client->>Hook: POST /chat/completions + Note over Client,Hook: tools: [100+ MCP tools] + Note over Client,Hook: messages: [{"role": "user", "content": "Get my Jira issues"}] + + rect rgb(240, 240, 240) + Note over Hook: 1. Extract User Query + Hook->>Filter: filter_tools("Get my Jira issues", tools) + end + + rect rgb(240, 240, 240) + Note over Filter: 2. Convert Tools → Routes + Note over Filter: Tool name + description → Route + end + + rect rgb(240, 240, 240) + Note over Filter: 3. Semantic Matching + Filter->>Router: router(query) + Router->>Router: Embeddings + similarity + Router-->>Filter: [top 10 matches] + end + + rect rgb(240, 240, 240) + Note over Filter: 4. Return Filtered Tools + Filter-->>Hook: [10 relevant tools] + end + + Hook->>LLM: POST /chat/completions + Note over Hook,LLM: tools: [10 Jira-related tools] ← FILTERED + Note over Hook,LLM: messages: [...] 
← UNCHANGED + + LLM-->>Client: Response (unchanged) +``` + +## Filter Operations + +The hook intercepts requests before they reach the LLM: + +| Operation | Description | +|-----------|-------------| +| **Extract query** | Get the most recent user message from `messages` (or `input`) | +| **Convert to Routes** | Transform MCP tools into semantic-router Routes | +| **Semantic match** | Use `semantic-router` to find top-K similar tools | +| **Filter tools** | Replace request `tools` with filtered subset | + +## Trigger Conditions + +The filter only runs when: +- Call type is `completion`, `acompletion`, or `aresponses` +- Request contains `tools` field +- Request contains `messages` (or `input`) field +- Filter is enabled in config + +## What Does NOT Change + +- Request messages +- Response body +- Non-tool parameters + +## Integration with semantic-router + +Reuses existing LiteLLM infrastructure: +- `semantic-router` - Already an optional dependency +- `LiteLLMRouterEncoder` - Wraps `Router.aembedding()` for embeddings +- `SemanticRouter` - Handles similarity calculation and top-K selection + +## Configuration + +```yaml +litellm_settings: + mcp_semantic_tool_filter: + enabled: true + embedding_model: "openai/text-embedding-3-small" + top_k: 10 + similarity_threshold: 0.3 +``` + +## Error Handling + +The filter fails gracefully: +- If filtering fails → Return all tools (no impact on functionality) +- If query extraction fails → Skip filtering +- If no matches found → Return all tools diff --git a/litellm/proxy/hooks/mcp_semantic_filter/__init__.py b/litellm/proxy/hooks/mcp_semantic_filter/__init__.py new file mode 100644 index 00000000000..36d357d560f --- /dev/null +++ b/litellm/proxy/hooks/mcp_semantic_filter/__init__.py @@ -0,0 +1,9 @@ +""" +MCP Semantic Tool Filter Hook + +Semantic filtering for MCP tools to reduce context window size +and improve tool selection accuracy.
+""" +from litellm.proxy.hooks.mcp_semantic_filter.hook import SemanticToolFilterHook + +__all__ = ["SemanticToolFilterHook"] diff --git a/litellm/proxy/hooks/mcp_semantic_filter/hook.py b/litellm/proxy/hooks/mcp_semantic_filter/hook.py new file mode 100644 index 00000000000..fc9349c2a42 --- /dev/null +++ b/litellm/proxy/hooks/mcp_semantic_filter/hook.py @@ -0,0 +1,353 @@ +""" +Semantic Tool Filter Hook + +Pre-call hook that filters MCP tools semantically before LLM inference. +Reduces context window size and improves tool selection accuracy. +""" +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union + +from litellm._logging import verbose_proxy_logger +from litellm.constants import ( + DEFAULT_MCP_SEMANTIC_FILTER_EMBEDDING_MODEL, + DEFAULT_MCP_SEMANTIC_FILTER_SIMILARITY_THRESHOLD, + DEFAULT_MCP_SEMANTIC_FILTER_TOP_K, +) +from litellm.integrations.custom_logger import CustomLogger + +if TYPE_CHECKING: + from litellm.caching.caching import DualCache + from litellm.proxy._experimental.mcp_server.semantic_tool_filter import ( + SemanticMCPToolFilter, + ) + from litellm.proxy._types import UserAPIKeyAuth + from litellm.router import Router + + +class SemanticToolFilterHook(CustomLogger): + """ + Pre-call hook that filters MCP tools semantically. + + This hook: + 1. Extracts the user query from messages + 2. Filters tools based on semantic similarity to the query + 3. Returns only the top-k most relevant tools to the LLM + """ + + def __init__(self, semantic_filter: "SemanticMCPToolFilter"): + """ + Initialize the hook. + + Args: + semantic_filter: SemanticMCPToolFilter instance + """ + super().__init__() + self.filter = semantic_filter + + verbose_proxy_logger.debug( + f"Initialized SemanticToolFilterHook with filter: " + f"enabled={semantic_filter.enabled}, top_k={semantic_filter.top_k}" + ) + + def _should_expand_mcp_tools(self, tools: List[Any]) -> bool: + """ + Check if tools contain MCP references with server_url="litellm_proxy". 
+ + Only expands MCP tools pointing to litellm proxy, not external MCP servers. + """ + from litellm.responses.mcp.litellm_proxy_mcp_handler import ( + LiteLLM_Proxy_MCP_Handler, + ) + + return LiteLLM_Proxy_MCP_Handler._should_use_litellm_mcp_gateway(tools) + + async def _expand_mcp_tools( + self, + tools: List[Any], + user_api_key_dict: "UserAPIKeyAuth", + ) -> List[Dict[str, Any]]: + """ + Expand MCP references to actual tool definitions. + + Reuses LiteLLM_Proxy_MCP_Handler._process_mcp_tools_to_openai_format + which internally does: parse -> fetch -> filter -> deduplicate -> transform + """ + from litellm.responses.mcp.litellm_proxy_mcp_handler import ( + LiteLLM_Proxy_MCP_Handler, + ) + + # Parse to separate MCP tools from other tools + mcp_tools, _ = LiteLLM_Proxy_MCP_Handler._parse_mcp_tools(tools) + + if not mcp_tools: + return [] + + # Use single combined method instead of 3 separate calls + # This already handles: fetch -> filter by allowed_tools -> deduplicate -> transform + openai_tools, _ = await LiteLLM_Proxy_MCP_Handler._process_mcp_tools_to_openai_format( + user_api_key_auth=user_api_key_dict, + mcp_tools_with_litellm_proxy=mcp_tools + ) + + # Convert Pydantic models to dicts for compatibility + openai_tools_as_dicts = [] + for tool in openai_tools: + if hasattr(tool, "model_dump"): + tool_dict = tool.model_dump(exclude_none=True) + verbose_proxy_logger.debug(f"Converted Pydantic tool to dict: {type(tool).__name__} -> dict with keys: {list(tool_dict.keys())}") + openai_tools_as_dicts.append(tool_dict) + elif hasattr(tool, "dict"): + tool_dict = tool.dict(exclude_none=True) + verbose_proxy_logger.debug(f"Converted Pydantic tool (v1) to dict: {type(tool).__name__} -> dict") + openai_tools_as_dicts.append(tool_dict) + elif isinstance(tool, dict): + verbose_proxy_logger.debug(f"Tool is already a dict with keys: {list(tool.keys())}") + openai_tools_as_dicts.append(tool) + else: + verbose_proxy_logger.warning(f"Tool is unknown type: {type(tool)}, passing 
as-is") + openai_tools_as_dicts.append(tool) + + verbose_proxy_logger.debug( + f"Expanded {len(mcp_tools)} MCP reference(s) to {len(openai_tools_as_dicts)} tools (all as dicts)" + ) + + return openai_tools_as_dicts + + def _get_metadata_variable_name(self, data: dict) -> str: + if "litellm_metadata" in data: + return "litellm_metadata" + return "metadata" + + async def async_pre_call_hook( + self, + user_api_key_dict: "UserAPIKeyAuth", + cache: "DualCache", + data: dict, + call_type: str, + ) -> Optional[Union[Exception, str, dict]]: + """ + Filter tools before LLM call based on user query. + + This hook is called before the LLM request is made. It filters the + tools list to only include semantically relevant tools. + + Args: + user_api_key_dict: User authentication + cache: Cache instance + data: Request data containing messages and tools + call_type: Type of call (completion, acompletion, etc.) + + Returns: + Modified data dict with filtered tools, or None if no changes + """ + # Only filter endpoints that support tools + if call_type not in ("completion", "acompletion", "aresponses"): + verbose_proxy_logger.debug( + f"Skipping semantic filter for call_type={call_type}" + ) + return None + + # Check if tools are present + tools = data.get("tools") + if not tools: + verbose_proxy_logger.debug("No tools in request, skipping semantic filter") + return None + + original_tool_count = len(tools) + + # Check for MCP references (server_url="litellm_proxy") and expand them + if self._should_expand_mcp_tools(tools): + verbose_proxy_logger.debug( + "Detected litellm_proxy MCP references, expanding before semantic filtering" + ) + + try: + expanded_tools = await self._expand_mcp_tools( + tools, user_api_key_dict + ) + + if not expanded_tools: + verbose_proxy_logger.warning( + "No tools expanded from MCP references" + ) + return None + + verbose_proxy_logger.info( + f"Expanded {len(tools)} MCP reference(s) to {len(expanded_tools)} tools" + ) + + # Update tools for filtering 
+ tools = expanded_tools + original_tool_count = len(tools) + + except Exception as e: + verbose_proxy_logger.error( + f"Failed to expand MCP references: {e}", exc_info=True + ) + return None + + # Check if messages are present (try both "messages" and "input" for responses API) + messages = data.get("messages", []) + if not messages: + messages = data.get("input", []) + if not messages: + verbose_proxy_logger.debug("No messages in request, skipping semantic filter") + return None + + # Check if filter is enabled + if not self.filter.enabled: + verbose_proxy_logger.debug("Semantic filter disabled, skipping") + return None + + try: + # Extract user query from messages + user_query = self.filter.extract_user_query(messages) + if not user_query: + verbose_proxy_logger.debug("No user query found, skipping semantic filter") + return None + + verbose_proxy_logger.debug( + f"Applying semantic filter to {len(tools)} tools " + f"with query: '{user_query[:50]}...'" + ) + + # Filter tools semantically + filtered_tools = await self.filter.filter_tools( + query=user_query, + available_tools=tools, # type: ignore + ) + + # Always update tools and emit header (even if count unchanged) + data["tools"] = filtered_tools + + # Store filter stats and tool names for response header + filter_stats = f"{original_tool_count}->{len(filtered_tools)}" + tool_names_csv = self._get_tool_names_csv(filtered_tools) + + _metadata_variable_name = self._get_metadata_variable_name(data) + data[_metadata_variable_name]["litellm_semantic_filter_stats"] = filter_stats + data[_metadata_variable_name]["litellm_semantic_filter_tools"] = tool_names_csv + + verbose_proxy_logger.info( + f"Semantic tool filter: {filter_stats} tools" + ) + + return data + + except Exception as e: + verbose_proxy_logger.warning( + f"Semantic tool filter hook failed: {e}. Proceeding with all tools." 
+ ) + return None + + async def async_post_call_response_headers_hook( + self, + data: dict, + user_api_key_dict: "UserAPIKeyAuth", + response: Any, + request_headers: Optional[Dict[str, str]] = None, + ) -> Optional[Dict[str, str]]: + """Add semantic filter stats and tool names to response headers.""" + from litellm.constants import MAX_MCP_SEMANTIC_FILTER_TOOLS_HEADER_LENGTH + + _metadata_variable_name = self._get_metadata_variable_name(data) + metadata = data[_metadata_variable_name] + + filter_stats = metadata.get("litellm_semantic_filter_stats") + if not filter_stats: + return None + + headers = {"x-litellm-semantic-filter": filter_stats} + + # Add CSV of filtered tool names (nginx-safe length) + tool_names_csv = metadata.get("litellm_semantic_filter_tools", "") + if tool_names_csv: + if len(tool_names_csv) > MAX_MCP_SEMANTIC_FILTER_TOOLS_HEADER_LENGTH: + tool_names_csv = tool_names_csv[:MAX_MCP_SEMANTIC_FILTER_TOOLS_HEADER_LENGTH - 3] + "..." + + headers["x-litellm-semantic-filter-tools"] = tool_names_csv + + return headers + + def _get_tool_names_csv(self, tools: List[Any]) -> str: + """Extract tool names and return as CSV string.""" + if not tools: + return "" + + tool_names = [] + for tool in tools: + name = tool.get("name", "") if isinstance(tool, dict) else getattr(tool, "name", "") + if name: + tool_names.append(name) + + return ",".join(tool_names) + + @staticmethod + async def initialize_from_config( + config: Optional[Dict[str, Any]], + llm_router: Optional["Router"], + ) -> Optional["SemanticToolFilterHook"]: + """ + Initialize semantic tool filter from proxy config. 
+ + Args: + config: Proxy configuration dict (litellm_settings.mcp_semantic_tool_filter) + llm_router: LiteLLM router instance for embeddings + + Returns: + SemanticToolFilterHook instance if enabled, None otherwise + """ + from litellm.proxy._experimental.mcp_server.semantic_tool_filter import ( + SemanticMCPToolFilter, + ) + if not config or not config.get("enabled", False): + verbose_proxy_logger.debug("Semantic tool filter not enabled in config") + return None + + if llm_router is None: + verbose_proxy_logger.warning( + "Cannot initialize semantic filter: llm_router is None" + ) + return None + + try: + + embedding_model = config.get( + "embedding_model", DEFAULT_MCP_SEMANTIC_FILTER_EMBEDDING_MODEL + ) + top_k = config.get("top_k", DEFAULT_MCP_SEMANTIC_FILTER_TOP_K) + similarity_threshold = config.get( + "similarity_threshold", DEFAULT_MCP_SEMANTIC_FILTER_SIMILARITY_THRESHOLD + ) + + semantic_filter = SemanticMCPToolFilter( + embedding_model=embedding_model, + litellm_router_instance=llm_router, + top_k=top_k, + similarity_threshold=similarity_threshold, + enabled=True, + ) + + # Build router from MCP registry on startup + await semantic_filter.build_router_from_mcp_registry() + + hook = SemanticToolFilterHook(semantic_filter) + + verbose_proxy_logger.info( + f"✅ MCP Semantic Tool Filter enabled: " + f"embedding_model={embedding_model}, top_k={top_k}, " + f"similarity_threshold={similarity_threshold}" + ) + + return hook + + except ImportError as e: + verbose_proxy_logger.warning( + f"semantic-router not installed. Install with: " + f"pip install 'litellm[semantic-router]'. 
Error: {e}" + ) + return None + except Exception as e: + verbose_proxy_logger.exception( + f"Failed to initialize MCP semantic tool filter: {e}" + ) + return None diff --git a/litellm/proxy/litellm_pre_call_utils.py b/litellm/proxy/litellm_pre_call_utils.py index 72f23e609ab..9be78264e85 100644 --- a/litellm/proxy/litellm_pre_call_utils.py +++ b/litellm/proxy/litellm_pre_call_utils.py @@ -1021,37 +1021,6 @@ async def add_litellm_data_to_request( # noqa: PLR0915 "user_api_key_user_max_budget" ] = user_api_key_dict.user_max_budget - # Extract allowed access groups for router filtering (GitHub issue #18333) - # This allows the router to filter deployments based on key's and team's access groups - # NOTE: We keep key and team access groups SEPARATE because a key doesn't always - # inherit all team access groups (per maintainer feedback). - if llm_router is not None: - from litellm.proxy.auth.model_checks import get_access_groups_from_models - - model_access_groups = llm_router.get_model_access_groups() - - # Key-level access groups (from user_api_key_dict.models) - key_models = list(user_api_key_dict.models) if user_api_key_dict.models else [] - key_allowed_access_groups = get_access_groups_from_models( - model_access_groups=model_access_groups, models=key_models - ) - if key_allowed_access_groups: - data[_metadata_variable_name][ - "user_api_key_allowed_access_groups" - ] = key_allowed_access_groups - - # Team-level access groups (from user_api_key_dict.team_models) - team_models = ( - list(user_api_key_dict.team_models) if user_api_key_dict.team_models else [] - ) - team_allowed_access_groups = get_access_groups_from_models( - model_access_groups=model_access_groups, models=team_models - ) - if team_allowed_access_groups: - data[_metadata_variable_name][ - "user_api_key_team_allowed_access_groups" - ] = team_allowed_access_groups - data[_metadata_variable_name]["user_api_key_metadata"] = user_api_key_dict.metadata _headers = dict(request.headers) _headers.pop( diff 
--git a/litellm/proxy/management_endpoints/key_management_endpoints.py b/litellm/proxy/management_endpoints/key_management_endpoints.py index 278971a91a5..d1840363009 100644 --- a/litellm/proxy/management_endpoints/key_management_endpoints.py +++ b/litellm/proxy/management_endpoints/key_management_endpoints.py @@ -3373,6 +3373,163 @@ async def regenerate_key_fn( raise handle_exception_on_proxy(e) +async def _check_proxy_or_team_admin_for_key( + key_in_db: LiteLLM_VerificationToken, + user_api_key_dict: UserAPIKeyAuth, + prisma_client: PrismaClient, + user_api_key_cache: DualCache, +) -> None: + if user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN.value: + return + + if key_in_db.team_id is not None: + team_table = await get_team_object( + team_id=key_in_db.team_id, + prisma_client=prisma_client, + user_api_key_cache=user_api_key_cache, + check_db_only=True, + ) + if team_table is not None: + if _is_user_team_admin( + user_api_key_dict=user_api_key_dict, + team_obj=team_table, + ): + return + + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail={"error": "You must be a proxy admin or team admin to reset key spend"}, + ) + + +def _validate_reset_spend_value( + reset_to: Any, key_in_db: LiteLLM_VerificationToken +) -> float: + if not isinstance(reset_to, (int, float)): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail={"error": "reset_to must be a float"}, + ) + + reset_to = float(reset_to) + + if reset_to < 0: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail={"error": "reset_to must be >= 0"}, + ) + + current_spend = key_in_db.spend or 0.0 + if reset_to > current_spend: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail={"error": f"reset_to ({reset_to}) must be <= current spend ({current_spend})"}, + ) + + max_budget = key_in_db.max_budget + if key_in_db.litellm_budget_table is not None: + budget_max_budget = getattr(key_in_db.litellm_budget_table, "max_budget", 
None) + if budget_max_budget is not None: + if max_budget is None or budget_max_budget < max_budget: + max_budget = budget_max_budget + + if max_budget is not None and reset_to > max_budget: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail={"error": f"reset_to ({reset_to}) must be <= budget ({max_budget})"}, + ) + + return reset_to + + +@router.post( + "/key/{key:path}/reset_spend", + tags=["key management"], + dependencies=[Depends(user_api_key_auth)], +) +@management_endpoint_wrapper +async def reset_key_spend_fn( + key: str, + data: ResetSpendRequest, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), + litellm_changed_by: Optional[str] = Header( + None, + description="The litellm-changed-by header enables tracking of actions performed by authorized users on behalf of other users, providing an audit trail for accountability", + ), +) -> Dict[str, Any]: + try: + from litellm.proxy.proxy_server import ( + hash_token, + prisma_client, + proxy_logging_obj, + user_api_key_cache, + ) + + if prisma_client is None: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail={"error": "DB not connected. 
prisma_client is None"}, + ) + + if "sk" not in key: + hashed_api_key = key + else: + hashed_api_key = hash_token(key) + + _key_in_db = await prisma_client.db.litellm_verificationtoken.find_unique( + where={"token": hashed_api_key}, + include={"litellm_budget_table": True}, + ) + if _key_in_db is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail={"error": f"Key {key} not found."}, + ) + + current_spend = _key_in_db.spend or 0.0 + reset_to = _validate_reset_spend_value(data.reset_to, _key_in_db) + + await _check_proxy_or_team_admin_for_key( + key_in_db=_key_in_db, + user_api_key_dict=user_api_key_dict, + prisma_client=prisma_client, + user_api_key_cache=user_api_key_cache, + ) + + updated_key = await prisma_client.db.litellm_verificationtoken.update( + where={"token": hashed_api_key}, + data={"spend": reset_to}, + ) + + if updated_key is None: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail={"error": "Failed to update key spend"}, + ) + + await _delete_cache_key_object( + hashed_token=hashed_api_key, + user_api_key_cache=user_api_key_cache, + proxy_logging_obj=proxy_logging_obj, + ) + + max_budget = updated_key.max_budget + budget_reset_at = updated_key.budget_reset_at + + return { + "key_hash": hashed_api_key, + "spend": reset_to, + "previous_spend": current_spend, + "max_budget": max_budget, + "budget_reset_at": budget_reset_at, + } + except HTTPException: + raise + except Exception as e: + verbose_proxy_logger.exception("Error resetting key spend: %s", e) + raise handle_exception_on_proxy(e) + + async def validate_key_list_check( user_api_key_dict: UserAPIKeyAuth, user_id: Optional[str], diff --git a/litellm/proxy/proxy_config.yaml b/litellm/proxy/proxy_config.yaml index e12e75b54ff..d87ae8b14ca 100644 --- a/litellm/proxy/proxy_config.yaml +++ b/litellm/proxy/proxy_config.yaml @@ -1,4 +1,14 @@ model_list: + - model_name: gpt-4o + litellm_params: + model: openai/gpt-4o + api_key: 
os.environ/OPENAI_API_KEY + + - model_name: text-embedding-3-small + litellm_params: + model: openai/text-embedding-3-small + api_key: os.environ/OPENAI_API_KEY + - model_name: bedrock-claude-sonnet-3.5 litellm_params: model: "bedrock/us.anthropic.claude-3-5-sonnet-20240620-v1:0" @@ -22,4 +32,31 @@ model_list: - model_name: bedrock-nova-premier litellm_params: model: "bedrock/us.amazon.nova-premier-v1:0" - aws_region_name: "us-east-1" \ No newline at end of file + aws_region_name: "us-east-1" + +# MCP Server Configuration +mcp_servers: + # Wikipedia MCP - reliable and works without external deps + wikipedia: + transport: "stdio" + command: "uvx" + args: ["mcp-server-fetch"] + description: "Fetch web pages and Wikipedia content" + deepwiki: + transport: "http" + url: "https://mcp.deepwiki.com/mcp" + +# General Settings +general_settings: + master_key: sk-1234 + store_model_in_db: false + +# LiteLLM Settings +litellm_settings: + # Enable MCP Semantic Tool Filter + mcp_semantic_tool_filter: + enabled: true + embedding_model: "text-embedding-3-small" + top_k: 5 + similarity_threshold: 0.3 + diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 59e5fdc56af..8f433bfa486 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -793,6 +793,21 @@ async def proxy_startup_event(app: FastAPI): # noqa: PLR0915 redis_usage_cache=redis_usage_cache, ) + ## SEMANTIC TOOL FILTER ## + # Read litellm_settings from config for semantic filter initialization + try: + verbose_proxy_logger.debug("About to initialize semantic tool filter") + _config = proxy_config.get_config_state() + _litellm_settings = _config.get("litellm_settings", {}) + verbose_proxy_logger.debug(f"litellm_settings keys = {list(_litellm_settings.keys())}") + await ProxyStartupEvent._initialize_semantic_tool_filter( + llm_router=llm_router, + litellm_settings=_litellm_settings, + ) + verbose_proxy_logger.debug("After semantic tool filter initialization") + except 
Exception as e: + verbose_proxy_logger.error(f"Semantic filter init failed: {e}", exc_info=True) + ## JWT AUTH ## ProxyStartupEvent._initialize_jwt_auth( general_settings=general_settings, @@ -4742,6 +4757,34 @@ def _initialize_startup_logging( llm_router=llm_router, redis_usage_cache=redis_usage_cache ) + @classmethod + async def _initialize_semantic_tool_filter( + cls, + llm_router: Optional[Router], + litellm_settings: Dict[str, Any], + ): + """Initialize MCP semantic tool filter if configured""" + from litellm.proxy.hooks.mcp_semantic_filter import SemanticToolFilterHook + + verbose_proxy_logger.info( + f"Initializing semantic tool filter: llm_router={llm_router is not None}, " + f"litellm_settings keys={list(litellm_settings.keys())}" + ) + + mcp_semantic_filter_config = litellm_settings.get("mcp_semantic_tool_filter", None) + verbose_proxy_logger.debug(f"Semantic filter config: {mcp_semantic_filter_config}") + + hook = await SemanticToolFilterHook.initialize_from_config( + config=mcp_semantic_filter_config, + llm_router=llm_router, + ) + + if hook: + verbose_proxy_logger.debug("✅ Semantic tool filter hook registered") + litellm.logging_callback_manager.add_litellm_callback(hook) + else: + verbose_proxy_logger.warning("❌ Semantic tool filter hook not initialized") + @classmethod def _initialize_jwt_auth( cls, diff --git a/litellm/proxy/schema.prisma b/litellm/proxy/schema.prisma index 3b81da10923..b118400b620 100644 --- a/litellm/proxy/schema.prisma +++ b/litellm/proxy/schema.prisma @@ -305,16 +305,6 @@ model LiteLLM_VerificationToken { litellm_budget_table LiteLLM_BudgetTable? @relation(fields: [budget_id], references: [budget_id]) litellm_organization_table LiteLLM_OrganizationTable? @relation(fields: [organization_id], references: [organization_id]) object_permission LiteLLM_ObjectPermissionTable? 
@relation(fields: [object_permission_id], references: [object_permission_id]) - - // SELECT COUNT(*) FROM (SELECT "public"."LiteLLM_VerificationToken"."token" FROM "public"."LiteLLM_VerificationToken" WHERE ("public"."LiteLLM_VerificationToken"."user_id" = $1 AND ("public"."LiteLLM_VerificationToken"."team_id" IS NULL OR "public"."LiteLLM_VerificationToken"."team_id" <> $2)) OFFSET $3 ) AS "sub" - // SELECT ... FROM "public"."LiteLLM_VerificationToken" WHERE "public"."LiteLLM_VerificationToken"."user_id" = $1 OFFSET $2 - @@index([user_id, team_id]) - - // SELECT ... FROM "public"."LiteLLM_VerificationToken" WHERE "public"."LiteLLM_VerificationToken"."team_id" = $1 OFFSET $2 - @@index([team_id]) - - // SELECT ... FROM "public"."LiteLLM_VerificationToken" WHERE (("public"."LiteLLM_VerificationToken"."expires" IS NULL OR "public"."LiteLLM_VerificationToken"."expires" > $1) AND "public"."LiteLLM_VerificationToken"."budget_reset_at" < $2) OFFSET $3 - @@index([budget_reset_at, expires]) } // Audit table for deleted keys - preserves spend and key information for historical tracking diff --git a/litellm/router.py b/litellm/router.py index 65445e29c41..d01c8443dab 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -88,7 +88,6 @@ is_clientside_credential, ) from litellm.router_utils.common_utils import ( - filter_deployments_by_access_groups, filter_team_based_models, filter_web_search_deployments, ) @@ -8088,17 +8087,10 @@ async def async_get_healthy_deployments( request_kwargs=request_kwargs, ) - verbose_router_logger.debug(f"healthy_deployments after web search filter: {healthy_deployments}") - - # Filter by allowed access groups (GitHub issue #18333) - # This prevents cross-team load balancing when teams have models with same name in different access groups - healthy_deployments = filter_deployments_by_access_groups( - healthy_deployments=healthy_deployments, - request_kwargs=request_kwargs, + verbose_router_logger.debug( + f"healthy_deployments after web search 
filter: {healthy_deployments}" ) - verbose_router_logger.debug(f"healthy_deployments after access group filter: {healthy_deployments}") - if isinstance(healthy_deployments, dict): return healthy_deployments diff --git a/litellm/router_utils/common_utils.py b/litellm/router_utils/common_utils.py index 2c0ea5976d6..10acc343abd 100644 --- a/litellm/router_utils/common_utils.py +++ b/litellm/router_utils/common_utils.py @@ -75,7 +75,6 @@ def filter_team_based_models( if deployment.get("model_info", {}).get("id") not in ids_to_remove ] - def _deployment_supports_web_search(deployment: Dict) -> bool: """ Check if a deployment supports web search. @@ -113,7 +112,7 @@ def filter_web_search_deployments( is_web_search_request = False tools = request_kwargs.get("tools") or [] for tool in tools: - # These are the two websearch tools for OpenAI / Azure. + # These are the two websearch tools for OpenAI / Azure. if tool.get("type") == "web_search" or tool.get("type") == "web_search_preview": is_web_search_request = True break @@ -122,82 +121,8 @@ def filter_web_search_deployments( return healthy_deployments # Filter out deployments that don't support web search - final_deployments = [ - d for d in healthy_deployments if _deployment_supports_web_search(d) - ] + final_deployments = [d for d in healthy_deployments if _deployment_supports_web_search(d)] if len(healthy_deployments) > 0 and len(final_deployments) == 0: verbose_logger.warning("No deployments support web search for request") return final_deployments - -def filter_deployments_by_access_groups( - healthy_deployments: Union[List[Dict], Dict], - request_kwargs: Optional[Dict] = None, -) -> Union[List[Dict], Dict]: - """ - Filter deployments to only include those matching the user's allowed access groups. - - Reads from TWO separate metadata fields (per maintainer feedback): - - `user_api_key_allowed_access_groups`: Access groups from the API Key's models. 
-    - `user_api_key_team_allowed_access_groups`: Access groups from the Team's models.
-
-    A deployment is included if its access_groups overlap with EITHER the key's
-    or the team's allowed access groups. Deployments with no access_groups are
-    always included (not restricted).
-
-    This prevents cross-team load balancing when multiple teams have models with
-    the same name but in different access groups (GitHub issue #18333).
-    """
-    if request_kwargs is None:
-        return healthy_deployments
-
-    if isinstance(healthy_deployments, dict):
-        return healthy_deployments
-
-    metadata = request_kwargs.get("metadata") or {}
-    litellm_metadata = request_kwargs.get("litellm_metadata") or {}
-
-    # Gather key-level allowed access groups
-    key_allowed_access_groups = (
-        metadata.get("user_api_key_allowed_access_groups")
-        or litellm_metadata.get("user_api_key_allowed_access_groups")
-        or []
-    )
-
-    # Gather team-level allowed access groups
-    team_allowed_access_groups = (
-        metadata.get("user_api_key_team_allowed_access_groups")
-        or litellm_metadata.get("user_api_key_team_allowed_access_groups")
-        or []
-    )
-
-    # Combine both for the final allowed set
-    combined_allowed_access_groups = list(key_allowed_access_groups) + list(
-        team_allowed_access_groups
-    )
-
-    # If no access groups specified from either source, return all deployments (backwards compatible)
-    if not combined_allowed_access_groups:
-        return healthy_deployments
-
-    allowed_set = set(combined_allowed_access_groups)
-    filtered = []
-    for deployment in healthy_deployments:
-        model_info = deployment.get("model_info") or {}
-        deployment_access_groups = model_info.get("access_groups") or []
-
-        # If deployment has no access groups, include it (not restricted)
-        if not deployment_access_groups:
-            filtered.append(deployment)
-            continue
-
-        # Include if any of deployment's groups overlap with allowed groups
-        if set(deployment_access_groups) & allowed_set:
-            filtered.append(deployment)
-
-    if len(healthy_deployments) > 0 and len(filtered) == 0:
-        verbose_logger.warning(
-            f"No deployments match allowed access groups {combined_allowed_access_groups}"
-        )
-
-    return filtered
diff --git a/litellm/router_utils/fallback_event_handlers.py b/litellm/router_utils/fallback_event_handlers.py
index 738b82d7023..62e706a0cf5 100644
--- a/litellm/router_utils/fallback_event_handlers.py
+++ b/litellm/router_utils/fallback_event_handlers.py
@@ -113,16 +113,8 @@ async def run_async_fallback(
         The most recent exception if all fallback model groups fail.
     """
-    ### BASE CASE ### MAX FALLBACK DEPTH REACHED
-    if fallback_depth >= max_fallbacks:
-        raise original_exception
-
-    ### CHECK IF MODEL GROUP LIST EXHAUSTED
-    if original_model_group in fallback_model_group:
-        fallback_group_length = len(fallback_model_group) - 1
-    else:
-        fallback_group_length = len(fallback_model_group)
-    if fallback_depth >= fallback_group_length:
+    ### BASE CASE ### MAX FALLBACK DEPTH REACHED
+    if fallback_depth >= max_fallbacks:
         raise original_exception
 
     error_from_fallbacks = original_exception
diff --git a/pyproject.toml b/pyproject.toml
index 450dadac930..9832ca483dc 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "litellm"
-version = "1.81.6"
+version = "1.81.7"
 description = "Library to easily interface with LLM API providers"
 authors = ["BerriAI"]
 license = "MIT"
@@ -174,7 +174,7 @@ requires = ["poetry-core", "wheel"]
 build-backend = "poetry.core.masonry.api"
 
 [tool.commitizen]
-version = "1.81.6"
+version = "1.81.7"
 version_files = [
     "pyproject.toml:^version"
 ]
diff --git a/schema.prisma b/schema.prisma
index 3b81da10923..b118400b620 100644
--- a/schema.prisma
+++ b/schema.prisma
@@ -305,16 +305,6 @@ model LiteLLM_VerificationToken {
   litellm_budget_table LiteLLM_BudgetTable? @relation(fields: [budget_id], references: [budget_id])
   litellm_organization_table LiteLLM_OrganizationTable? @relation(fields: [organization_id], references: [organization_id])
   object_permission LiteLLM_ObjectPermissionTable? @relation(fields: [object_permission_id], references: [object_permission_id])
-
-  // SELECT COUNT(*) FROM (SELECT "public"."LiteLLM_VerificationToken"."token" FROM "public"."LiteLLM_VerificationToken" WHERE ("public"."LiteLLM_VerificationToken"."user_id" = $1 AND ("public"."LiteLLM_VerificationToken"."team_id" IS NULL OR "public"."LiteLLM_VerificationToken"."team_id" <> $2)) OFFSET $3 ) AS "sub"
-  // SELECT ... FROM "public"."LiteLLM_VerificationToken" WHERE "public"."LiteLLM_VerificationToken"."user_id" = $1 OFFSET $2
-  @@index([user_id, team_id])
-
-  // SELECT ... FROM "public"."LiteLLM_VerificationToken" WHERE "public"."LiteLLM_VerificationToken"."team_id" = $1 OFFSET $2
-  @@index([team_id])
-
-  // SELECT ... FROM "public"."LiteLLM_VerificationToken" WHERE (("public"."LiteLLM_VerificationToken"."expires" IS NULL OR "public"."LiteLLM_VerificationToken"."expires" > $1) AND "public"."LiteLLM_VerificationToken"."budget_reset_at" < $2) OFFSET $3
-  @@index([budget_reset_at, expires])
 }
 
 // Audit table for deleted keys - preserves spend and key information for historical tracking
diff --git a/tests/llm_translation/test_gemini.py b/tests/llm_translation/test_gemini.py
index e3e05786449..c1c52757cf0 100644
--- a/tests/llm_translation/test_gemini.py
+++ b/tests/llm_translation/test_gemini.py
@@ -1435,3 +1435,20 @@ def test_gemini_image_size_limit_exceeded():
     error_message = str(excinfo.value)
     assert "Image size" in error_message
     assert "exceeds maximum allowed size" in error_message
+
+@pytest.mark.asyncio
+async def test_gemini_openai_web_search_tool_to_google_search():
+    """
+    Test that OpenAI-style web_search tools are transformed to Gemini's googleSearch.
+
+    When passing {"type": "web_search"} or {"type": "web_search_preview"} to Gemini,
+    these should be transformed to googleSearch, not silently ignored.
+ """ + response = await litellm.acompletion( + model="gemini/gemini-2.5-flash", + messages=[{"role": "user", "content": "What is the capital of France?"}], + tools=[{"type": "web_search"}], + ) + print("response: ", response.model_dump_json(indent=4)) + assert hasattr(response, "vertex_ai_grounding_metadata") + assert getattr(response, "vertex_ai_grounding_metadata") is not None diff --git a/tests/mcp_tests/test_semantic_tool_filter_e2e.py b/tests/mcp_tests/test_semantic_tool_filter_e2e.py new file mode 100644 index 00000000000..91c072ae8a3 --- /dev/null +++ b/tests/mcp_tests/test_semantic_tool_filter_e2e.py @@ -0,0 +1,89 @@ +""" +End-to-end test for MCP Semantic Tool Filtering +""" +import asyncio +import os +import sys +from unittest.mock import Mock + +import pytest + +sys.path.insert(0, os.path.abspath("../..")) + +from mcp.types import Tool as MCPTool + +# Check if semantic-router is available +try: + import semantic_router + SEMANTIC_ROUTER_AVAILABLE = True +except ImportError: + SEMANTIC_ROUTER_AVAILABLE = False + + +@pytest.mark.asyncio +@pytest.mark.skipif( + not SEMANTIC_ROUTER_AVAILABLE, + reason="semantic-router not installed. 
Install with: pip install 'litellm[semantic-router]'" +) +async def test_e2e_semantic_filter(): + """E2E: Load router/filter and verify hook filters tools.""" + from litellm import Router + from litellm.proxy.hooks.mcp_semantic_filter import SemanticToolFilterHook + from litellm.proxy._experimental.mcp_server.semantic_tool_filter import ( + SemanticMCPToolFilter, + ) + + # Create router and filter + router = Router( + model_list=[{ + "model_name": "text-embedding-3-small", + "litellm_params": {"model": "openai/text-embedding-3-small"}, + }] + ) + + filter_instance = SemanticMCPToolFilter( + embedding_model="text-embedding-3-small", + litellm_router_instance=router, + top_k=3, + enabled=True, + ) + + # Create 10 tools + tools = [ + MCPTool(name="gmail_send", description="Send an email via Gmail", inputSchema={"type": "object"}), + MCPTool(name="calendar_create", description="Create a calendar event", inputSchema={"type": "object"}), + MCPTool(name="file_upload", description="Upload a file", inputSchema={"type": "object"}), + MCPTool(name="web_search", description="Search the web", inputSchema={"type": "object"}), + MCPTool(name="slack_send", description="Send Slack message", inputSchema={"type": "object"}), + MCPTool(name="doc_read", description="Read document", inputSchema={"type": "object"}), + MCPTool(name="db_query", description="Query database", inputSchema={"type": "object"}), + MCPTool(name="api_call", description="Make API call", inputSchema={"type": "object"}), + MCPTool(name="task_create", description="Create task", inputSchema={"type": "object"}), + MCPTool(name="note_add", description="Add note", inputSchema={"type": "object"}), + ] + + # Build router with test tools + filter_instance._build_router(tools) + + hook = SemanticToolFilterHook(filter_instance) + + data = { + "model": "gpt-4", + "messages": [{"role": "user", "content": "Send an email and create a calendar event"}], + "tools": tools, + "metadata": {}, # Initialize metadata dict for hook to 
store filter stats + } + + # Call hook + result = await hook.async_pre_call_hook( + user_api_key_dict=Mock(), + cache=Mock(), + data=data, + call_type="completion", + ) + + # Single assertion: hook filtered tools + assert result and len(result["tools"]) < len(tools), f"Expected filtered tools, got {len(result['tools'])} tools (original: {len(tools)})" + + print(f"✅ E2E test passed: Filtering reduced tools from {len(tools)} to {len(result['tools'])}") + print(f" Filtered tools: {[t.name for t in result['tools']]}") diff --git a/tests/test_fallbacks.py b/tests/test_fallbacks.py index c22cefa6be6..bc9aa4c64c8 100644 --- a/tests/test_fallbacks.py +++ b/tests/test_fallbacks.py @@ -336,45 +336,3 @@ async def test_chat_completion_bad_and_good_model(): f"Iteration {iteration + 1}: {'✓' if success else '✗'} ({time.time() - start_time:.2f}s)" ) assert success, "Not all good model requests succeeded" - - -@pytest.mark.asyncio -async def test_router_fallback_exhaustion(): - """ - Test for Bug 19985: - """ - from litellm import Router - import pytest - - # Setup: Only ONE fallback model available - model_list = [ - { - "model_name": "gpt-3.5-turbo", - "litellm_params": {"model": "openai/fake", "api_key": "bad-key"}, - }, - { - "model_name": "bad-model-1", - "litellm_params": {"model": "azure/fake", "api_key": "bad-key"}, - } - ] - - # max_fallbacks=10 is much larger than the 1 fallback provided in the list - router = Router( - model_list=model_list, - fallbacks=[{"gpt-3.5-turbo": ["bad-model-1"]}], - max_fallbacks=10 - ) - - try: - # This will fail and attempt to fallback - await router.acompletion( - model="gpt-3.5-turbo", - messages=[{"role": "user", "content": "test"}] - ) - except Exception as e: - # The success criteria is that we DON'T get an IndexError - assert not isinstance(e, IndexError), f"Expected API error, but got IndexError: {e}" - # Also ensure we actually hit a fallback attempt - print(f"Caught expected exception: {type(e).__name__}") - - diff --git 
a/tests/test_litellm/llms/vertex_ai/gemini/test_vertex_and_google_ai_studio_gemini.py b/tests/test_litellm/llms/vertex_ai/gemini/test_vertex_and_google_ai_studio_gemini.py index ac099a0168c..cb3b51acd69 100644 --- a/tests/test_litellm/llms/vertex_ai/gemini/test_vertex_and_google_ai_studio_gemini.py +++ b/tests/test_litellm/llms/vertex_ai/gemini/test_vertex_and_google_ai_studio_gemini.py @@ -2663,6 +2663,111 @@ def test_vertex_ai_single_tool_type_still_works(): assert tools[0]["code_execution"] == {} +def test_vertex_ai_openai_web_search_tool_transformation(): + """ + Test that OpenAI-style web_search and web_search_preview tools are transformed to googleSearch. + + This fixes the issue where passing OpenAI-style web search tools like: + {"type": "web_search"} or {"type": "web_search_preview"} + would be silently ignored (the request succeeds but grounding is not applied). + + The fix transforms these to Gemini's googleSearch tool. + + Input: + value=[{"type": "web_search"}] + + Expected Output: + tools=[{"googleSearch": {}}] + """ + v = VertexGeminiConfig() + optional_params = {} + + # Test web_search transformation + tools = v._map_function( + value=[{"type": "web_search"}], + optional_params=optional_params + ) + + assert len(tools) == 1, f"Expected 1 Tool object, got {len(tools)}" + assert "googleSearch" in tools[0], f"Expected googleSearch in tool, got {tools[0].keys()}" + assert tools[0]["googleSearch"] == {}, f"Expected empty googleSearch config, got {tools[0]['googleSearch']}" + + +def test_vertex_ai_openai_web_search_preview_tool_transformation(): + """ + Test that OpenAI-style web_search_preview tool is transformed to googleSearch. 
+
+    Input:
+        value=[{"type": "web_search_preview"}]
+
+    Expected Output:
+        tools=[{"googleSearch": {}}]
+    """
+    v = VertexGeminiConfig()
+    optional_params = {}
+
+    # Test web_search_preview transformation
+    tools = v._map_function(
+        value=[{"type": "web_search_preview"}],
+        optional_params=optional_params
+    )
+
+    assert len(tools) == 1, f"Expected 1 Tool object, got {len(tools)}"
+    assert "googleSearch" in tools[0], f"Expected googleSearch in tool, got {tools[0].keys()}"
+    assert tools[0]["googleSearch"] == {}, f"Expected empty googleSearch config, got {tools[0]['googleSearch']}"
+
+
+def test_vertex_ai_openai_web_search_with_function_tools():
+    """
+    Test that OpenAI-style web_search tool works alongside function tools.
+
+    Input:
+        value=[
+            {"type": "web_search"},
+            {"type": "function", "function": {"name": "get_weather", "description": "Get weather"}},
+        ]
+
+    Expected Output:
+        tools=[
+            {"googleSearch": {}},
+            {"function_declarations": [{"name": "get_weather", "description": "Get weather"}]},
+        ]
+    """
+    v = VertexGeminiConfig()
+    optional_params = {}
+
+    tools = v._map_function(
+        value=[
+            {"type": "web_search"},
+            {"type": "function", "function": {"name": "get_weather", "description": "Get weather"}},
+        ],
+        optional_params=optional_params
+    )
+
+    # Should have 2 separate Tool objects
+    assert len(tools) == 2, f"Expected 2 Tool objects, got {len(tools)}"
+
+    # Find each tool type
+    search_tool = None
+    func_tool = None
+
+    for tool in tools:
+        if "googleSearch" in tool:
+            search_tool = tool
+        elif "function_declarations" in tool:
+            func_tool = tool
+
+    # Verify both tools are present
+    assert search_tool is not None, "googleSearch Tool should be present"
+    assert func_tool is not None, "function_declarations Tool should be present"
+
+    # Verify googleSearch is empty config
+    assert search_tool["googleSearch"] == {}
+
+    # Verify function declaration content
+    assert func_tool["function_declarations"][0]["name"] == "get_weather"
+
+
 def test_vertex_ai_multiple_function_declarations_grouped():
     """
     Test that multiple function declarations are grouped in ONE Tool object.
diff --git a/tests/test_litellm/proxy/_experimental/mcp_server/test_semantic_tool_filter.py b/tests/test_litellm/proxy/_experimental/mcp_server/test_semantic_tool_filter.py
new file mode 100644
index 00000000000..87c597c659b
--- /dev/null
+++ b/tests/test_litellm/proxy/_experimental/mcp_server/test_semantic_tool_filter.py
@@ -0,0 +1,394 @@
+"""
+Unit tests for MCP Semantic Tool Filtering
+
+Tests the core filtering logic that takes a long list of tools and returns
+an ordered set of top K tools based on semantic similarity.
+"""
+import asyncio
+import os
+import sys
+from unittest.mock import AsyncMock, Mock, patch
+
+import pytest
+
+sys.path.insert(0, os.path.abspath("../.."))
+
+from mcp.types import Tool as MCPTool
+
+
+@pytest.mark.asyncio
+async def test_semantic_filter_basic_filtering():
+    """
+    Test that the semantic filter correctly filters tools based on query.
+
+    Given: 10 email/calendar tools
+    When: Query is "send an email"
+    Then: Email tools should rank higher than calendar tools
+    """
+    from litellm.proxy._experimental.mcp_server.semantic_tool_filter import (
+        SemanticMCPToolFilter,
+    )
+
+    # Create mock tools - mix of email and calendar tools
+    tools = [
+        MCPTool(name="gmail_send", description="Send an email via Gmail", inputSchema={"type": "object"}),
+        MCPTool(name="outlook_send", description="Send an email via Outlook", inputSchema={"type": "object"}),
+        MCPTool(name="calendar_create", description="Create a calendar event", inputSchema={"type": "object"}),
+        MCPTool(name="calendar_update", description="Update a calendar event", inputSchema={"type": "object"}),
+        MCPTool(name="email_read", description="Read emails from inbox", inputSchema={"type": "object"}),
+        MCPTool(name="email_delete", description="Delete an email", inputSchema={"type": "object"}),
+        MCPTool(name="calendar_delete", description="Delete a calendar event", inputSchema={"type": "object"}),
+        MCPTool(name="email_search", description="Search for emails", inputSchema={"type": "object"}),
+        MCPTool(name="calendar_list", description="List calendar events", inputSchema={"type": "object"}),
+        MCPTool(name="email_forward", description="Forward an email to someone", inputSchema={"type": "object"}),
+    ]
+
+    # Mock router that returns mock embeddings
+    from litellm.types.utils import Embedding, EmbeddingResponse
+
+    mock_router = Mock()
+
+    def mock_embedding_sync(*args, **kwargs):
+        return EmbeddingResponse(
+            data=[Embedding(embedding=[0.1] * 1536, index=0, object="embedding")],
+            model="text-embedding-3-small",
+            object="list",
+            usage={"prompt_tokens": 10, "total_tokens": 10}
+        )
+
+    async def mock_embedding_async(*args, **kwargs):
+        return mock_embedding_sync()
+
+    mock_router.embedding = mock_embedding_sync
+    mock_router.aembedding = mock_embedding_async
+
+    # Create filter
+    filter_instance = SemanticMCPToolFilter(
+        embedding_model="text-embedding-3-small",
+        litellm_router_instance=mock_router,
+        top_k=3,
+        similarity_threshold=0.3,
+        enabled=True,
+    )
+
+    # Build router with the tools before filtering
+    filter_instance._build_router(tools)
+
+    # Filter tools with email-related query
+    filtered = await filter_instance.filter_tools(
+        query="send an email to john@example.com",
+        available_tools=tools,
+    )
+
+    # Assertions - validate filtering mechanics work
+    assert len(filtered) <= 3, f"Should return at most 3 tools (top_k), got {len(filtered)}"
+    assert len(filtered) > 0, "Should return at least some tools"
+    assert len(filtered) < len(tools), f"Should filter down from {len(tools)} tools, got {len(filtered)}"
+
+    # Validate tools are actual MCPTool objects
+    for tool in filtered:
+        assert hasattr(tool, 'name'), "Filtered result should be MCPTool with name"
+        assert hasattr(tool, 'description'), "Filtered result should be MCPTool with description"
+
+    filtered_names = [t.name for t in filtered]
+    print(f"✅ Successfully filtered {len(tools)} tools down to top {len(filtered)}: {filtered_names}")
+    print(f"   Filter respects top_k parameter correctly")
+
+
+@pytest.mark.asyncio
+async def test_semantic_filter_top_k_limiting():
+    """
+    Test that the filter respects top_k parameter.
+
+    Given: 20 tools
+    When: top_k=5
+    Then: Should return at most 5 tools
+    """
+    from litellm.proxy._experimental.mcp_server.semantic_tool_filter import (
+        SemanticMCPToolFilter,
+    )
+
+    # Create 20 tools
+    tools = [
+        MCPTool(name=f"tool_{i}", description=f"Tool number {i} for testing", inputSchema={"type": "object"})
+        for i in range(20)
+    ]
+
+    # Mock router
+    from litellm.types.utils import Embedding, EmbeddingResponse
+
+    mock_router = Mock()
+
+    def mock_embedding_sync(*args, **kwargs):
+        return EmbeddingResponse(
+            data=[Embedding(embedding=[0.1] * 1536, index=0, object="embedding")],
+            model="text-embedding-3-small",
+            object="list",
+            usage={"prompt_tokens": 10, "total_tokens": 10}
+        )
+
+    async def mock_embedding_async(*args, **kwargs):
+        return mock_embedding_sync()
+
+    mock_router.embedding = mock_embedding_sync
+    mock_router.aembedding = mock_embedding_async
+
+    # Create filter with top_k=5
+    filter_instance = SemanticMCPToolFilter(
+        embedding_model="text-embedding-3-small",
+        litellm_router_instance=mock_router,
+        top_k=5,
+        similarity_threshold=0.3,
+        enabled=True,
+    )
+
+    # Build router with the tools before filtering
+    filter_instance._build_router(tools)
+
+    # Filter tools
+    filtered = await filter_instance.filter_tools(
+        query="test query",
+        available_tools=tools,
+    )
+
+    # Should return at most 5 tools
+    assert len(filtered) <= 5, f"Expected at most 5 tools, got {len(filtered)}"
+    print(f"Returned {len(filtered)} tools out of {len(tools)} (top_k=5)")
+
+
+@pytest.mark.asyncio
+async def test_semantic_filter_disabled():
+    """
+    Test that when filter is disabled, all tools are returned.
+ """ + from litellm.proxy._experimental.mcp_server.semantic_tool_filter import ( + SemanticMCPToolFilter, + ) + + tools = [ + MCPTool(name=f"tool_{i}", description=f"Tool {i}", inputSchema={"type": "object"}) + for i in range(10) + ] + + mock_router = Mock() + + # Create disabled filter + filter_instance = SemanticMCPToolFilter( + embedding_model="text-embedding-3-small", + litellm_router_instance=mock_router, + top_k=3, + similarity_threshold=0.3, + enabled=False, # Disabled + ) + + # Filter tools + filtered = await filter_instance.filter_tools( + query="test query", + available_tools=tools, + ) + + # Should return all tools when disabled + assert len(filtered) == len(tools), f"Expected all {len(tools)} tools, got {len(filtered)}" + + +@pytest.mark.asyncio +async def test_semantic_filter_empty_tools(): + """ + Test that filter handles empty tool list gracefully. + """ + from litellm.proxy._experimental.mcp_server.semantic_tool_filter import ( + SemanticMCPToolFilter, + ) + + mock_router = Mock() + + filter_instance = SemanticMCPToolFilter( + embedding_model="text-embedding-3-small", + litellm_router_instance=mock_router, + top_k=3, + similarity_threshold=0.3, + enabled=True, + ) + + # Filter empty list + filtered = await filter_instance.filter_tools( + query="test query", + available_tools=[], + ) + + assert len(filtered) == 0, "Should return empty list for empty input" + + +@pytest.mark.asyncio +async def test_semantic_filter_extract_user_query(): + """ + Test that user query extraction works correctly from messages. 
+ """ + from litellm.proxy._experimental.mcp_server.semantic_tool_filter import ( + SemanticMCPToolFilter, + ) + + mock_router = Mock() + + filter_instance = SemanticMCPToolFilter( + embedding_model="text-embedding-3-small", + litellm_router_instance=mock_router, + top_k=3, + similarity_threshold=0.3, + enabled=True, + ) + + # Test string content + messages = [ + {"role": "system", "content": "You are a helpful assistant"}, + {"role": "user", "content": "Send an email to john@example.com"}, + ] + + query = filter_instance.extract_user_query(messages) + assert query == "Send an email to john@example.com" + + # Test list content blocks + messages_with_blocks = [ + {"role": "user", "content": [ + {"type": "text", "text": "Hello, "}, + {"type": "text", "text": "send email please"}, + ]}, + ] + + query2 = filter_instance.extract_user_query(messages_with_blocks) + assert "Hello" in query2 and "send email" in query2 + + # Test no user messages + messages_no_user = [ + {"role": "system", "content": "System message only"}, + ] + + query3 = filter_instance.extract_user_query(messages_no_user) + assert query3 == "" + + +@pytest.mark.asyncio +async def test_semantic_filter_hook_triggers_on_completion(): + """ + Test that the hook triggers for completion requests with tools. 
+ """ + from litellm.proxy._experimental.mcp_server.semantic_tool_filter import ( + SemanticMCPToolFilter, + ) + from litellm.proxy.hooks.mcp_semantic_filter import SemanticToolFilterHook + from litellm.types.utils import Embedding, EmbeddingResponse + + # Create mock filter + mock_router = Mock() + + def mock_embedding_sync(*args, **kwargs): + return EmbeddingResponse( + data=[Embedding(embedding=[0.1] * 1536, index=0, object="embedding")], + model="text-embedding-3-small", + object="list", + usage={"prompt_tokens": 10, "total_tokens": 10} + ) + + async def mock_embedding_async(*args, **kwargs): + return mock_embedding_sync() + + mock_router.embedding = mock_embedding_sync + mock_router.aembedding = mock_embedding_async + + filter_instance = SemanticMCPToolFilter( + embedding_model="text-embedding-3-small", + litellm_router_instance=mock_router, + top_k=3, + similarity_threshold=0.3, + enabled=True, + ) + + # Prepare data - completion request with tools + tools = [ + MCPTool(name=f"tool_{i}", description=f"Tool {i}", inputSchema={"type": "object"}) + for i in range(10) + ] + + # Build router with the tools before filtering + filter_instance._build_router(tools) + + # Create hook + hook = SemanticToolFilterHook(filter_instance) + + data = { + "model": "gpt-4", + "messages": [ + {"role": "user", "content": "Send an email"} + ], + "tools": tools, + "metadata": {}, # Hook needs metadata field to store filter stats + } + + # Mock user API key dict and cache + mock_user_api_key_dict = Mock() + mock_cache = Mock() + + # Call hook + result = await hook.async_pre_call_hook( + user_api_key_dict=mock_user_api_key_dict, + cache=mock_cache, + data=data, + call_type="completion", + ) + + # Assertions + assert result is not None, "Hook should return modified data" + assert "tools" in result, "Result should contain tools" + assert len(result["tools"]) < len(tools), f"Hook should filter tools, got {len(result['tools'])}/{len(tools)}" + + print(f"✅ Hook triggered correctly: 
{len(tools)} -> {len(result['tools'])} tools") + + + +@pytest.mark.asyncio +async def test_semantic_filter_hook_skips_no_tools(): + """ + Test that the hook does NOT trigger when there are no tools. + """ + from litellm.proxy._experimental.mcp_server.semantic_tool_filter import ( + SemanticMCPToolFilter, + ) + from litellm.proxy.hooks.mcp_semantic_filter import SemanticToolFilterHook + + # Create mock filter + mock_router = Mock() + filter_instance = SemanticMCPToolFilter( + embedding_model="text-embedding-3-small", + litellm_router_instance=mock_router, + top_k=3, + similarity_threshold=0.3, + enabled=True, + ) + + # Create hook + hook = SemanticToolFilterHook(filter_instance) + + # Prepare data - completion without tools + data = { + "model": "gpt-4", + "messages": [ + {"role": "user", "content": "Hello"} + ], + } + + # Mock user API key dict and cache + mock_user_api_key_dict = Mock() + mock_cache = Mock() + + # Call hook + result = await hook.async_pre_call_hook( + user_api_key_dict=mock_user_api_key_dict, + cache=mock_cache, + data=data, + call_type="completion", + ) + + # Should return None (no modification) + assert result is None, "Hook should skip requests without tools" + print("✅ Hook correctly skips requests without tools") + diff --git a/tests/test_litellm/proxy/management_endpoints/test_key_management_endpoints.py b/tests/test_litellm/proxy/management_endpoints/test_key_management_endpoints.py index 3638fd7e2c9..e90fb277eed 100644 --- a/tests/test_litellm/proxy/management_endpoints/test_key_management_endpoints.py +++ b/tests/test_litellm/proxy/management_endpoints/test_key_management_endpoints.py @@ -19,10 +19,12 @@ LiteLLM_BudgetTable, LiteLLM_OrganizationTable, LiteLLM_TeamTableCachedObj, + LiteLLM_UserTable, LiteLLM_VerificationToken, LitellmUserRoles, Member, ProxyException, + ResetSpendRequest, UpdateKeyRequest, ) from litellm.proxy.auth.user_api_key_auth import UserAPIKeyAuth @@ -37,6 +39,7 @@ _save_deleted_verification_token_records, 
     _transform_verification_tokens_to_deleted_records,
     _validate_max_budget,
+    _validate_reset_spend_value,
     can_modify_verification_token,
     check_org_key_model_specific_limits,
     check_team_key_model_specific_limits,
@@ -44,6 +47,8 @@
     generate_key_helper_fn,
     list_keys,
     prepare_key_update_data,
+    reset_key_spend_fn,
+    validate_key_list_check,
     validate_key_team_change,
 )
 from litellm.proxy.proxy_server import app
@@ -4690,3 +4695,699 @@ async def test_bulk_update_keys_partial_failures(monkeypatch):
     assert response.successful_updates[0].key == "test-key-1"
     assert response.failed_updates[0].key == "non-existent-key"
     assert "Key not found" in response.failed_updates[0].failed_reason
+
+
+@pytest.mark.parametrize(
+    "reset_to,key_spend,key_max_budget,budget_max_budget,expected_error",
+    [
+        ("not_a_number", 100.0, None, None, "reset_to must be a float"),
+        (None, 100.0, None, None, "reset_to must be a float"),
+        ([], 100.0, None, None, "reset_to must be a float"),
+        ({}, 100.0, None, None, "reset_to must be a float"),
+        (-1.0, 100.0, None, None, "reset_to must be >= 0"),
+        (-0.1, 100.0, None, None, "reset_to must be >= 0"),
+        (101.0, 100.0, None, None, "reset_to (101.0) must be <= current spend (100.0)"),
+        (150.0, 100.0, None, None, "reset_to (150.0) must be <= current spend (100.0)"),
+        (50.0, 100.0, 30.0, None, "reset_to (50.0) must be <= budget (30.0)"),
+    ],
+)
+def test_validate_reset_spend_value_invalid(
+    reset_to, key_spend, key_max_budget, budget_max_budget, expected_error
+):
+    key_in_db = LiteLLM_VerificationToken(
+        token="test-token",
+        user_id="test-user",
+        spend=key_spend,
+        max_budget=key_max_budget,
+        litellm_budget_table=LiteLLM_BudgetTable(
+            budget_id="test-budget", max_budget=budget_max_budget
+        ).dict()
+        if budget_max_budget is not None
+        else None,
+    )
+
+    with pytest.raises(HTTPException) as exc_info:
+        _validate_reset_spend_value(reset_to, key_in_db)
+
+    assert exc_info.value.status_code == 400
+    assert expected_error in str(exc_info.value.detail)
+
+
+@pytest.mark.parametrize(
+    "reset_to,key_spend,key_max_budget,budget_max_budget",
+    [
+        (0.0, 100.0, None, None),
+        (0, 100.0, None, None),
+        (50.0, 100.0, None, None),
+        (100.0, 100.0, None, None),
+        (25.0, 100.0, 50.0, None),
+        (0.0, 0.0, None, None),
+        (10.5, 50.0, 20.0, None),
+    ],
+)
+def test_validate_reset_spend_value_valid(
+    reset_to, key_spend, key_max_budget, budget_max_budget
+):
+    key_in_db = LiteLLM_VerificationToken(
+        token="test-token",
+        user_id="test-user",
+        spend=key_spend,
+        max_budget=key_max_budget,
+        litellm_budget_table=LiteLLM_BudgetTable(
+            budget_id="test-budget", max_budget=budget_max_budget
+        ).dict()
+        if budget_max_budget is not None
+        else None,
+    )
+
+    result = _validate_reset_spend_value(reset_to, key_in_db)
+    assert result == float(reset_to)
+
+
+def test_validate_reset_spend_value_no_budget_table():
+    key_in_db = LiteLLM_VerificationToken(
+        token="test-token",
+        user_id="test-user",
+        spend=100.0,
+        max_budget=50.0,
+        litellm_budget_table=None,
+    )
+
+    result = _validate_reset_spend_value(25.0, key_in_db)
+    assert result == 25.0
+
+
+def test_validate_reset_spend_value_none_spend():
+    key_in_db = LiteLLM_VerificationToken(
+        token="test-token",
+        user_id="test-user",
+        spend=0.0,
+        max_budget=None,
+        litellm_budget_table=None,
+    )
+
+    result = _validate_reset_spend_value(0.0, key_in_db)
+    assert result == 0.0
+
+    with pytest.raises(HTTPException) as exc_info:
+        _validate_reset_spend_value(1.0, key_in_db)
+    assert exc_info.value.status_code == 400
+    assert "must be <= current spend" in str(exc_info.value.detail)
+
+
+@pytest.mark.asyncio
+async def test_reset_key_spend_success(monkeypatch):
+    mock_prisma_client = MagicMock()
+    mock_user_api_key_cache = MagicMock()
+    mock_proxy_logging_obj = MagicMock()
+
+    hashed_key = "hashed-test-key"
+    key_in_db = LiteLLM_VerificationToken(
+        token=hashed_key,
+        user_id="test-user",
+        spend=100.0,
+        max_budget=200.0,
+        litellm_budget_table=None,
+    )
+
+    updated_key = LiteLLM_VerificationToken(
+        token=hashed_key,
+        user_id="test-user",
+        spend=50.0,
+        max_budget=200.0,
+        budget_reset_at=None,
+    )
+
+    mock_prisma_client.db.litellm_verificationtoken.find_unique = AsyncMock(
+        return_value=key_in_db
+    )
+    mock_prisma_client.db.litellm_verificationtoken.update = AsyncMock(
+        return_value=updated_key
+    )
+
+    monkeypatch.setattr(
+        "litellm.proxy.proxy_server.prisma_client", mock_prisma_client
+    )
+    monkeypatch.setattr(
+        "litellm.proxy.proxy_server.user_api_key_cache", mock_user_api_key_cache
+    )
+    monkeypatch.setattr(
+        "litellm.proxy.proxy_server.proxy_logging_obj", mock_proxy_logging_obj
+    )
+
+    with patch(
+        "litellm.proxy.proxy_server.hash_token"
+    ) as mock_hash_token, patch(
+        "litellm.proxy.management_endpoints.key_management_endpoints._check_proxy_or_team_admin_for_key"
+    ) as mock_check_admin, patch(
+        "litellm.proxy.management_endpoints.key_management_endpoints._delete_cache_key_object"
+    ) as mock_delete_cache:
+        mock_hash_token.return_value = hashed_key
+        mock_check_admin.return_value = None
+        mock_delete_cache.return_value = None
+
+        user_api_key_dict = UserAPIKeyAuth(
+            user_role=LitellmUserRoles.PROXY_ADMIN,
+            api_key="sk-admin",
+            user_id="admin-user",
+        )
+
+        response = await reset_key_spend_fn(
+            key="sk-test-key",
+            data=ResetSpendRequest(reset_to=50.0),
+            user_api_key_dict=user_api_key_dict,
+            litellm_changed_by=None,
+        )
+
+        assert response["spend"] == 50.0
+        assert response["previous_spend"] == 100.0
+        assert response["key_hash"] == hashed_key
+        assert response["max_budget"] == 200.0
+        mock_prisma_client.db.litellm_verificationtoken.update.assert_called_once()
+        mock_delete_cache.assert_awaited_once()
+
+
+@pytest.mark.asyncio
+async def test_reset_key_spend_success_team_admin(monkeypatch):
+    """Test that team admin can reset key spend for keys in their team."""
+    mock_prisma_client = MagicMock()
+    mock_user_api_key_cache = MagicMock()
+    mock_proxy_logging_obj = MagicMock()
+
+    hashed_key = "hashed-test-key"
team_id = "test-team-123" + key_in_db = LiteLLM_VerificationToken( + token=hashed_key, + user_id="test-user", + team_id=team_id, + spend=100.0, + max_budget=200.0, + litellm_budget_table=None, + ) + + updated_key = LiteLLM_VerificationToken( + token=hashed_key, + user_id="test-user", + team_id=team_id, + spend=50.0, + max_budget=200.0, + budget_reset_at=None, + ) + + mock_prisma_client.db.litellm_verificationtoken.find_unique = AsyncMock( + return_value=key_in_db + ) + mock_prisma_client.db.litellm_verificationtoken.update = AsyncMock( + return_value=updated_key + ) + + # Set up team table with user as admin + team_table = LiteLLM_TeamTableCachedObj( + team_id=team_id, + team_alias="test-team", + tpm_limit=None, + rpm_limit=None, + max_budget=None, + spend=0.0, + models=[], + blocked=False, + members_with_roles=[ + Member(user_id="team-admin-user", role="admin"), + Member(user_id="test-user", role="user"), + ], + ) + + async def mock_get_team_object(*args, **kwargs): + return team_table + + monkeypatch.setattr( + "litellm.proxy.proxy_server.prisma_client", mock_prisma_client + ) + monkeypatch.setattr( + "litellm.proxy.proxy_server.user_api_key_cache", mock_user_api_key_cache + ) + monkeypatch.setattr( + "litellm.proxy.proxy_server.proxy_logging_obj", mock_proxy_logging_obj + ) + monkeypatch.setattr( + "litellm.proxy.management_endpoints.key_management_endpoints.get_team_object", + mock_get_team_object, + ) + + with patch( + "litellm.proxy.proxy_server.hash_token" + ) as mock_hash_token, patch( + "litellm.proxy.management_endpoints.key_management_endpoints._delete_cache_key_object" + ) as mock_delete_cache: + mock_hash_token.return_value = hashed_key + mock_delete_cache.return_value = None + + user_api_key_dict = UserAPIKeyAuth( + user_role=LitellmUserRoles.INTERNAL_USER, + api_key="sk-team-admin", + user_id="team-admin-user", + ) + + response = await reset_key_spend_fn( + key="sk-test-key", + data=ResetSpendRequest(reset_to=50.0), + 
+            user_api_key_dict=user_api_key_dict,
+            litellm_changed_by=None,
+        )
+
+        assert response["spend"] == 50.0
+        assert response["previous_spend"] == 100.0
+        assert response["key_hash"] == hashed_key
+        assert response["max_budget"] == 200.0
+        mock_prisma_client.db.litellm_verificationtoken.update.assert_called_once()
+        mock_delete_cache.assert_awaited_once()
+
+
+@pytest.mark.asyncio
+async def test_reset_key_spend_key_not_found(monkeypatch):
+    mock_prisma_client = MagicMock()
+    mock_prisma_client.db.litellm_verificationtoken.find_unique = AsyncMock(
+        return_value=None
+    )
+
+    monkeypatch.setattr(
+        "litellm.proxy.proxy_server.prisma_client", mock_prisma_client
+    )
+
+    with patch("litellm.proxy.proxy_server.hash_token") as mock_hash_token:
+        mock_hash_token.return_value = "hashed-key"
+
+        user_api_key_dict = UserAPIKeyAuth(
+            user_role=LitellmUserRoles.PROXY_ADMIN,
+            api_key="sk-admin",
+            user_id="admin-user",
+        )
+
+        with pytest.raises(HTTPException) as exc_info:
+            await reset_key_spend_fn(
+                key="sk-test-key",
+                data=ResetSpendRequest(reset_to=50.0),
+                user_api_key_dict=user_api_key_dict,
+                litellm_changed_by=None,
+            )
+
+        assert exc_info.value.status_code == 404
+        assert "Key not found" in str(exc_info.value.detail) or "Key sk-test-key not found" in str(exc_info.value.detail)
+
+
+@pytest.mark.asyncio
+async def test_reset_key_spend_db_not_connected(monkeypatch):
+    monkeypatch.setattr("litellm.proxy.proxy_server.prisma_client", None)
+
+    user_api_key_dict = UserAPIKeyAuth(
+        user_role=LitellmUserRoles.PROXY_ADMIN,
+        api_key="sk-admin",
+        user_id="admin-user",
+    )
+
+    with pytest.raises(HTTPException) as exc_info:
+        await reset_key_spend_fn(
+            key="sk-test-key",
+            data=ResetSpendRequest(reset_to=50.0),
+            user_api_key_dict=user_api_key_dict,
+            litellm_changed_by=None,
+        )
+
+    assert exc_info.value.status_code == 500
+    assert "DB not connected" in str(exc_info.value.detail)
+
+
+@pytest.mark.asyncio
+async def test_reset_key_spend_validation_error(monkeypatch):
mock_prisma_client = MagicMock() + key_in_db = LiteLLM_VerificationToken( + token="hashed-key", + user_id="test-user", + spend=100.0, + max_budget=None, + litellm_budget_table=None, + ) + + mock_prisma_client.db.litellm_verificationtoken.find_unique = AsyncMock( + return_value=key_in_db + ) + + monkeypatch.setattr( + "litellm.proxy.proxy_server.prisma_client", mock_prisma_client + ) + + with patch("litellm.proxy.proxy_server.hash_token") as mock_hash_token: + mock_hash_token.return_value = "hashed-key" + + user_api_key_dict = UserAPIKeyAuth( + user_role=LitellmUserRoles.PROXY_ADMIN, + api_key="sk-admin", + user_id="admin-user", + ) + + with pytest.raises(HTTPException) as exc_info: + await reset_key_spend_fn( + key="sk-test-key", + data=ResetSpendRequest(reset_to=150.0), + user_api_key_dict=user_api_key_dict, + litellm_changed_by=None, + ) + + assert exc_info.value.status_code == 400 + assert "must be <= current spend" in str(exc_info.value.detail) + + +@pytest.mark.asyncio +async def test_reset_key_spend_authorization_failure(monkeypatch): + mock_prisma_client = MagicMock() + mock_user_api_key_cache = MagicMock() + + hashed_key = "hashed-test-key" + key_in_db = LiteLLM_VerificationToken( + token=hashed_key, + user_id="test-user", + team_id="team-1", + spend=100.0, + max_budget=None, + litellm_budget_table=None, + ) + + mock_prisma_client.db.litellm_verificationtoken.find_unique = AsyncMock( + return_value=key_in_db + ) + + monkeypatch.setattr( + "litellm.proxy.proxy_server.prisma_client", mock_prisma_client + ) + monkeypatch.setattr( + "litellm.proxy.proxy_server.user_api_key_cache", mock_user_api_key_cache + ) + + with patch("litellm.proxy.proxy_server.hash_token") as mock_hash_token, patch( + "litellm.proxy.management_endpoints.key_management_endpoints._check_proxy_or_team_admin_for_key" + ) as mock_check_admin: + mock_hash_token.return_value = hashed_key + mock_check_admin.side_effect = HTTPException( + status_code=403, detail={"error": "Not authorized"} + ) + 
+ user_api_key_dict = UserAPIKeyAuth( + user_role=LitellmUserRoles.INTERNAL_USER, + api_key="sk-user", + user_id="user-1", + ) + + with pytest.raises(HTTPException) as exc_info: + await reset_key_spend_fn( + key="sk-test-key", + data=ResetSpendRequest(reset_to=50.0), + user_api_key_dict=user_api_key_dict, + litellm_changed_by=None, + ) + + assert exc_info.value.status_code == 403 + + +@pytest.mark.asyncio +async def test_reset_key_spend_hashed_key(monkeypatch): + mock_prisma_client = MagicMock() + mock_user_api_key_cache = MagicMock() + mock_proxy_logging_obj = MagicMock() + + hashed_key = "already-hashed-key" + key_in_db = LiteLLM_VerificationToken( + token=hashed_key, + user_id="test-user", + spend=100.0, + max_budget=None, + litellm_budget_table=None, + ) + + updated_key = LiteLLM_VerificationToken( + token=hashed_key, + user_id="test-user", + spend=50.0, + max_budget=None, + budget_reset_at=None, + ) + + mock_prisma_client.db.litellm_verificationtoken.find_unique = AsyncMock( + return_value=key_in_db + ) + mock_prisma_client.db.litellm_verificationtoken.update = AsyncMock( + return_value=updated_key + ) + + monkeypatch.setattr( + "litellm.proxy.proxy_server.prisma_client", mock_prisma_client + ) + monkeypatch.setattr( + "litellm.proxy.proxy_server.user_api_key_cache", mock_user_api_key_cache + ) + monkeypatch.setattr( + "litellm.proxy.proxy_server.proxy_logging_obj", mock_proxy_logging_obj + ) + + with patch( + "litellm.proxy.management_endpoints.key_management_endpoints._check_proxy_or_team_admin_for_key" + ) as mock_check_admin, patch( + "litellm.proxy.management_endpoints.key_management_endpoints._delete_cache_key_object" + ) as mock_delete_cache: + mock_check_admin.return_value = None + mock_delete_cache.return_value = None + + user_api_key_dict = UserAPIKeyAuth( + user_role=LitellmUserRoles.PROXY_ADMIN, + api_key="sk-admin", + user_id="admin-user", + ) + + response = await reset_key_spend_fn( + key=hashed_key, + data=ResetSpendRequest(reset_to=50.0), + 
user_api_key_dict=user_api_key_dict, + litellm_changed_by=None, + ) + + assert response["spend"] == 50.0 + mock_prisma_client.db.litellm_verificationtoken.find_unique.assert_called_once_with( + where={"token": hashed_key}, include={"litellm_budget_table": True} + ) + + +@pytest.mark.asyncio +async def test_validate_key_list_check_proxy_admin(): + mock_prisma_client = AsyncMock() + user_api_key_dict = UserAPIKeyAuth( + user_role=LitellmUserRoles.PROXY_ADMIN, + user_id="admin-user", + ) + + result = await validate_key_list_check( + user_api_key_dict=user_api_key_dict, + user_id=None, + team_id=None, + organization_id=None, + key_alias=None, + key_hash=None, + prisma_client=mock_prisma_client, + ) + + assert result is None + + +@pytest.mark.asyncio +async def test_validate_key_list_check_team_admin_success(): + mock_prisma_client = AsyncMock() + user_info = LiteLLM_UserTable( + user_id="test-user", + user_email="test@example.com", + teams=["team-1"], + organization_memberships=[], + ) + + mock_prisma_client.db.litellm_usertable.find_unique = AsyncMock( + return_value=user_info + ) + + user_api_key_dict = UserAPIKeyAuth( + user_role=LitellmUserRoles.INTERNAL_USER, + user_id="test-user", + ) + + result = await validate_key_list_check( + user_api_key_dict=user_api_key_dict, + user_id=None, + team_id="team-1", + organization_id=None, + key_alias=None, + key_hash=None, + prisma_client=mock_prisma_client, + ) + + assert result is not None + assert result.user_id == "test-user" + + +@pytest.mark.asyncio +async def test_validate_key_list_check_team_admin_fail(): + mock_prisma_client = AsyncMock() + user_info = LiteLLM_UserTable( + user_id="test-user", + user_email="test@example.com", + teams=["team-1"], + organization_memberships=[], + ) + + mock_prisma_client.db.litellm_usertable.find_unique = AsyncMock( + return_value=user_info + ) + + user_api_key_dict = UserAPIKeyAuth( + user_role=LitellmUserRoles.INTERNAL_USER, + user_id="test-user", + ) + + with 
pytest.raises(ProxyException) as exc_info: + await validate_key_list_check( + user_api_key_dict=user_api_key_dict, + user_id=None, + team_id="team-2", + organization_id=None, + key_alias=None, + key_hash=None, + prisma_client=mock_prisma_client, + ) + + assert exc_info.value.code == "403" or exc_info.value.code == 403 + assert "not authorized to check this team's keys" in exc_info.value.message + + +@pytest.mark.asyncio +async def test_validate_key_list_check_key_hash_authorized(): + mock_prisma_client = AsyncMock() + user_info = LiteLLM_UserTable( + user_id="test-user", + user_email="test@example.com", + teams=[], + organization_memberships=[], + ) + + key_info = LiteLLM_VerificationToken( + token="hashed-key", + user_id="test-user", + ) + + mock_prisma_client.db.litellm_usertable.find_unique = AsyncMock( + return_value=user_info + ) + mock_prisma_client.db.litellm_verificationtoken.find_unique = AsyncMock( + return_value=key_info + ) + + user_api_key_dict = UserAPIKeyAuth( + user_role=LitellmUserRoles.INTERNAL_USER, + user_id="test-user", + ) + + with patch( + "litellm.proxy.management_endpoints.key_management_endpoints._can_user_query_key_info" + ) as mock_can_query: + mock_can_query.return_value = True + + result = await validate_key_list_check( + user_api_key_dict=user_api_key_dict, + user_id=None, + team_id=None, + organization_id=None, + key_alias=None, + key_hash="hashed-key", + prisma_client=mock_prisma_client, + ) + + assert result is not None + assert result.user_id == "test-user" + + +@pytest.mark.asyncio +async def test_validate_key_list_check_key_hash_unauthorized(): + mock_prisma_client = AsyncMock() + user_info = LiteLLM_UserTable( + user_id="test-user", + user_email="test@example.com", + teams=[], + organization_memberships=[], + ) + + key_info = LiteLLM_VerificationToken( + token="hashed-key", + user_id="other-user", + ) + + mock_prisma_client.db.litellm_usertable.find_unique = AsyncMock( + return_value=user_info + ) + 
mock_prisma_client.db.litellm_verificationtoken.find_unique = AsyncMock( + return_value=key_info + ) + + user_api_key_dict = UserAPIKeyAuth( + user_role=LitellmUserRoles.INTERNAL_USER, + user_id="test-user", + ) + + with patch( + "litellm.proxy.management_endpoints.key_management_endpoints._can_user_query_key_info" + ) as mock_can_query: + mock_can_query.return_value = False + + with pytest.raises(HTTPException) as exc_info: + await validate_key_list_check( + user_api_key_dict=user_api_key_dict, + user_id=None, + team_id=None, + organization_id=None, + key_alias=None, + key_hash="hashed-key", + prisma_client=mock_prisma_client, + ) + + assert exc_info.value.status_code == 403 + assert "not allowed to access this key's info" in str(exc_info.value.detail) + + +@pytest.mark.asyncio +async def test_validate_key_list_check_key_hash_not_found(): + mock_prisma_client = AsyncMock() + user_info = LiteLLM_UserTable( + user_id="test-user", + user_email="test@example.com", + teams=[], + organization_memberships=[], + ) + + mock_prisma_client.db.litellm_usertable.find_unique = AsyncMock( + return_value=user_info + ) + mock_prisma_client.db.litellm_verificationtoken.find_unique = AsyncMock( + side_effect=Exception("Key not found") + ) + + user_api_key_dict = UserAPIKeyAuth( + user_role=LitellmUserRoles.INTERNAL_USER, + user_id="test-user", + ) + + with pytest.raises(ProxyException) as exc_info: + await validate_key_list_check( + user_api_key_dict=user_api_key_dict, + user_id=None, + team_id=None, + organization_id=None, + key_alias=None, + key_hash="non-existent-key", + prisma_client=mock_prisma_client, + ) + + assert exc_info.value.code == "403" or exc_info.value.code == 403 + assert "Key Hash not found" in exc_info.value.message diff --git a/tests/test_litellm/router_unit_tests/test_filter_deployments_by_access_groups.py b/tests/test_litellm/router_unit_tests/test_filter_deployments_by_access_groups.py deleted file mode 100644 index 9ac5072c5d8..00000000000 --- 
a/tests/test_litellm/router_unit_tests/test_filter_deployments_by_access_groups.py +++ /dev/null @@ -1,227 +0,0 @@ -""" -Unit tests for filter_deployments_by_access_groups function. - -Tests the fix for GitHub issue #18333: Models loadbalanced outside of Model Access Group. -""" - -import pytest - -from litellm.router_utils.common_utils import filter_deployments_by_access_groups - - -class TestFilterDeploymentsByAccessGroups: - """Tests for the filter_deployments_by_access_groups function.""" - - def test_no_filter_when_no_access_groups_in_metadata(self): - """When no allowed_access_groups in metadata, return all deployments.""" - deployments = [ - {"model_info": {"id": "1", "access_groups": ["AG1"]}}, - {"model_info": {"id": "2", "access_groups": ["AG2"]}}, - ] - request_kwargs = {"metadata": {"user_api_key_team_id": "team-1"}} - - result = filter_deployments_by_access_groups( - healthy_deployments=deployments, - request_kwargs=request_kwargs, - ) - - assert len(result) == 2 # All deployments returned - - def test_filter_to_single_access_group(self): - """Filter to only deployments matching allowed access group.""" - deployments = [ - {"model_info": {"id": "1", "access_groups": ["AG1"]}}, - {"model_info": {"id": "2", "access_groups": ["AG2"]}}, - ] - request_kwargs = {"metadata": {"user_api_key_allowed_access_groups": ["AG2"]}} - - result = filter_deployments_by_access_groups( - healthy_deployments=deployments, - request_kwargs=request_kwargs, - ) - - assert len(result) == 1 - assert result[0]["model_info"]["id"] == "2" - - def test_filter_with_multiple_allowed_groups(self): - """Filter with multiple allowed access groups.""" - deployments = [ - {"model_info": {"id": "1", "access_groups": ["AG1"]}}, - {"model_info": {"id": "2", "access_groups": ["AG2"]}}, - {"model_info": {"id": "3", "access_groups": ["AG3"]}}, - ] - request_kwargs = { - "metadata": {"user_api_key_allowed_access_groups": ["AG1", "AG2"]} - } - - result = filter_deployments_by_access_groups( - 
healthy_deployments=deployments, - request_kwargs=request_kwargs, - ) - - assert len(result) == 2 - ids = [d["model_info"]["id"] for d in result] - assert "1" in ids - assert "2" in ids - assert "3" not in ids - - def test_deployment_with_multiple_access_groups(self): - """Deployment with multiple access groups should match if any overlap.""" - deployments = [ - {"model_info": {"id": "1", "access_groups": ["AG1", "AG2"]}}, - {"model_info": {"id": "2", "access_groups": ["AG3"]}}, - ] - request_kwargs = {"metadata": {"user_api_key_allowed_access_groups": ["AG2"]}} - - result = filter_deployments_by_access_groups( - healthy_deployments=deployments, - request_kwargs=request_kwargs, - ) - - assert len(result) == 1 - assert result[0]["model_info"]["id"] == "1" - - def test_deployment_without_access_groups_included(self): - """Deployments without access groups should be included (not restricted).""" - deployments = [ - {"model_info": {"id": "1", "access_groups": ["AG1"]}}, - {"model_info": {"id": "2"}}, # No access_groups - {"model_info": {"id": "3", "access_groups": []}}, # Empty access_groups - ] - request_kwargs = {"metadata": {"user_api_key_allowed_access_groups": ["AG2"]}} - - result = filter_deployments_by_access_groups( - healthy_deployments=deployments, - request_kwargs=request_kwargs, - ) - - # Should include deployments 2 and 3 (no restrictions) - assert len(result) == 2 - ids = [d["model_info"]["id"] for d in result] - assert "2" in ids - assert "3" in ids - - def test_dict_deployment_passes_through(self): - """When deployment is a dict (specific deployment), pass through.""" - deployment = {"model_info": {"id": "1", "access_groups": ["AG1"]}} - request_kwargs = {"metadata": {"user_api_key_allowed_access_groups": ["AG2"]}} - - result = filter_deployments_by_access_groups( - healthy_deployments=deployment, - request_kwargs=request_kwargs, - ) - - assert result == deployment # Unchanged - - def test_none_request_kwargs_passes_through(self): - """When 
request_kwargs is None, return deployments unchanged.""" - deployments = [ - {"model_info": {"id": "1", "access_groups": ["AG1"]}}, - ] - - result = filter_deployments_by_access_groups( - healthy_deployments=deployments, - request_kwargs=None, - ) - - assert result == deployments - - def test_litellm_metadata_fallback(self): - """Should also check litellm_metadata for allowed access groups.""" - deployments = [ - {"model_info": {"id": "1", "access_groups": ["AG1"]}}, - {"model_info": {"id": "2", "access_groups": ["AG2"]}}, - ] - request_kwargs = { - "litellm_metadata": {"user_api_key_allowed_access_groups": ["AG1"]} - } - - result = filter_deployments_by_access_groups( - healthy_deployments=deployments, - request_kwargs=request_kwargs, - ) - - assert len(result) == 1 - assert result[0]["model_info"]["id"] == "1" - - -def test_filter_deployments_by_access_groups_issue_18333(): - """ - Regression test for GitHub issue #18333. - - Scenario: Two models named 'gpt-5' in different access groups (AG1, AG2). - Team2 has access to AG2 only. When Team2 requests 'gpt-5', only the AG2 - deployment should be available for load balancing. 
- """ - deployments = [ - { - "model_name": "gpt-5", - "litellm_params": {"model": "gpt-4.1", "api_key": "key-1"}, - "model_info": {"id": "ag1-deployment", "access_groups": ["AG1"]}, - }, - { - "model_name": "gpt-5", - "litellm_params": {"model": "gpt-4o", "api_key": "key-2"}, - "model_info": {"id": "ag2-deployment", "access_groups": ["AG2"]}, - }, - ] - - # Team2's request with allowed access groups - request_kwargs = { - "metadata": { - "user_api_key_team_id": "team-2", - "user_api_key_allowed_access_groups": ["AG2"], - } - } - - result = filter_deployments_by_access_groups( - healthy_deployments=deployments, - request_kwargs=request_kwargs, - ) - - # Only AG2 deployment should be returned - assert len(result) == 1 - assert result[0]["model_info"]["id"] == "ag2-deployment" - assert result[0]["litellm_params"]["model"] == "gpt-4o" - - -def test_get_access_groups_from_models(): - """ - Test the helper function that extracts access group names from models list. - This is used by the proxy to populate user_api_key_allowed_access_groups. 
- """ - from litellm.proxy.auth.model_checks import get_access_groups_from_models - - # Setup: access groups definition - model_access_groups = { - "AG1": ["gpt-4", "gpt-5"], - "AG2": ["claude-v1", "claude-v2"], - "beta-models": ["gpt-5-turbo"], - } - - # Test 1: Extract access groups from models list - models = ["gpt-4", "AG1", "AG2", "some-other-model"] - result = get_access_groups_from_models( - model_access_groups=model_access_groups, models=models - ) - assert set(result) == {"AG1", "AG2"} - - # Test 2: No access groups in models list - models = ["gpt-4", "claude-v1", "some-model"] - result = get_access_groups_from_models( - model_access_groups=model_access_groups, models=models - ) - assert result == [] - - # Test 3: Empty models list - result = get_access_groups_from_models( - model_access_groups=model_access_groups, models=[] - ) - assert result == [] - - # Test 4: All access groups - models = ["AG1", "AG2", "beta-models"] - result = get_access_groups_from_models( - model_access_groups=model_access_groups, models=models - ) - assert set(result) == {"AG1", "AG2", "beta-models"} diff --git a/ui/litellm-dashboard/src/app/(dashboard)/hooks/sso/useSSOSettings.ts b/ui/litellm-dashboard/src/app/(dashboard)/hooks/sso/useSSOSettings.ts index f03f3977115..0431a8d39f7 100644 --- a/ui/litellm-dashboard/src/app/(dashboard)/hooks/sso/useSSOSettings.ts +++ b/ui/litellm-dashboard/src/app/(dashboard)/hooks/sso/useSSOSettings.ts @@ -28,6 +28,7 @@ export interface SSOSettingsValues { user_email: string | null; ui_access_mode: string | null; role_mappings: RoleMappings; + team_mappings: TeamMappings; } export interface RoleMappings { @@ -39,6 +40,10 @@ export interface RoleMappings { }; } +export interface TeamMappings { + team_ids_jwt_field: string; +} + export interface SSOSettingsResponse { values: SSOSettingsValues; field_schema: SSOFieldSchema; diff --git a/ui/litellm-dashboard/src/components/ModelSelect/ModelSelect.test.tsx 
b/ui/litellm-dashboard/src/components/ModelSelect/ModelSelect.test.tsx index 3052f790098..6da2f82a2f1 100644 --- a/ui/litellm-dashboard/src/components/ModelSelect/ModelSelect.test.tsx +++ b/ui/litellm-dashboard/src/components/ModelSelect/ModelSelect.test.tsx @@ -37,12 +37,19 @@ vi.mock("antd", async (importOriginal) => { mode, ...props }: any) => { + // Simulate maxTagCount responsive behavior - if value length > 5, call maxTagPlaceholder + const shouldShowPlaceholder = maxTagCount === "responsive" && Array.isArray(value) && value.length > 5; + const visibleValues = shouldShowPlaceholder ? value.slice(0, 5) : value; + const omittedValues = shouldShowPlaceholder + ? value.slice(5).map((v: string) => ({ value: v, label: v })) + : []; + return (
+ {shouldShowPlaceholder && maxTagPlaceholder && ( +
+          {maxTagPlaceholder(omittedValues)}
+ )}
); }, @@ -82,6 +92,24 @@ const mockUseTeam = vi.mocked(useTeam); const mockUseOrganization = vi.mocked(useOrganization); const mockUseCurrentUser = vi.mocked(useCurrentUser); +const createMockOrganization = (models: string[]): Organization => ({ + organization_id: "org-1", + organization_alias: "Test Org", + budget_id: "budget-1", + metadata: {}, + models, + spend: 0, + model_spend: {}, + created_at: "2024-01-01", + created_by: "user-1", + updated_at: "2024-01-01", + updated_by: "user-1", + litellm_budget_table: null, + teams: null, + users: null, + members: null, +}); + describe("ModelSelect", () => { const mockProxyModels: ProxyModel[] = [ { id: "gpt-4", object: "model", created: 1234567890, owned_by: "openai" }, @@ -112,125 +140,44 @@ describe("ModelSelect", () => { } as any); }); - it("should render", async () => { + it("should render with all option groups", async () => { renderWithProviders( , ); await waitFor(() => { expect(screen.getByTestId("model-select")).toBeInTheDocument(); - }); - }); - - it("should show skeleton loader when loading", () => { - mockUseAllProxyModels.mockReturnValue({ - data: undefined, - isLoading: true, - } as any); - - renderWithProviders(); - - expect(screen.getByTestId("skeleton-input")).toBeInTheDocument(); - expect(screen.queryByTestId("model-select")).not.toBeInTheDocument(); - }); - - it("should show skeleton loader when team is loading", () => { - mockUseTeam.mockReturnValue({ - data: undefined, - isLoading: true, - } as any); - - renderWithProviders(); - - expect(screen.getByTestId("skeleton-input")).toBeInTheDocument(); - }); - - it("should show skeleton loader when organization is loading", () => { - mockUseOrganization.mockReturnValue({ - data: undefined, - isLoading: true, - } as any); - - renderWithProviders(); - - expect(screen.getByTestId("skeleton-input")).toBeInTheDocument(); - }); - - it("should show skeleton loader when current user is loading", () => { - mockUseCurrentUser.mockReturnValue({ - data: undefined, - 
isLoading: true, - } as any); - - renderWithProviders(); - - expect(screen.getByTestId("skeleton-input")).toBeInTheDocument(); - }); - - it("should render special options group", async () => { - const mockOrganization: Organization = { - organization_id: "org-1", - organization_alias: "Test Org", - budget_id: "budget-1", - metadata: {}, - models: ["all-proxy-models"], - spend: 0, - model_spend: {}, - created_at: "2024-01-01", - created_by: "user-1", - updated_at: "2024-01-01", - updated_by: "user-1", - litellm_budget_table: null, - teams: null, - users: null, - members: null, - }; - - mockUseOrganization.mockReturnValue({ - data: mockOrganization, - isLoading: false, - } as any); - - renderWithProviders( - , - ); - - await waitFor(() => { - const select = screen.getByTestId("model-select"); - expect(select).toBeInTheDocument(); - expect(screen.getByText("All Proxy Models")).toBeInTheDocument(); - expect(screen.getByText("No Default Models")).toBeInTheDocument(); - }); - }); - - it("should render wildcard options group", async () => { - renderWithProviders( - , - ); - - await waitFor(() => { + expect(screen.getByText("gpt-4")).toBeInTheDocument(); + expect(screen.getByText("claude-3")).toBeInTheDocument(); expect(screen.getByText("All Openai models")).toBeInTheDocument(); expect(screen.getByText("All Anthropic models")).toBeInTheDocument(); }); }); - it("should render regular models group", async () => { - renderWithProviders( - , - ); + it("should show skeleton loader when any data is loading", () => { + const loadingScenarios = [ + { hook: mockUseAllProxyModels, context: "user" as const }, + { hook: mockUseTeam, context: "team" as const, props: { teamID: "team-1" } }, + { hook: mockUseOrganization, context: "organization" as const, props: { organizationID: "org-1" } }, + { hook: mockUseCurrentUser, context: "user" as const }, + ]; - await waitFor(() => { - expect(screen.getByText("gpt-4")).toBeInTheDocument(); - 
expect(screen.getByText("claude-3")).toBeInTheDocument(); + loadingScenarios.forEach(({ hook, context, props = {} }) => { + hook.mockReturnValue({ + data: undefined, + isLoading: true, + } as any); + + const { unmount } = renderWithProviders( + , + ); + + expect(screen.getByTestId("skeleton-input")).toBeInTheDocument(); + unmount(); }); }); - it("should call onChange when selecting a regular model", async () => { + it("should handle model selection and onChange", async () => { const user = userEvent.setup(); renderWithProviders( , @@ -242,32 +189,16 @@ describe("ModelSelect", () => { const select = screen.getByRole("listbox"); await user.selectOptions(select, "gpt-4"); - expect(mockOnChange).toHaveBeenCalledWith(["gpt-4"]); + + await user.selectOptions(select, ["gpt-4", "claude-3"]); + expect(mockOnChange).toHaveBeenCalled(); }); - it("should call onChange with only last special option when multiple special options are selected", async () => { + it("should handle special options correctly", async () => { const user = userEvent.setup(); - const mockOrganization: Organization = { - organization_id: "org-1", - organization_alias: "Test Org", - budget_id: "budget-1", - metadata: {}, - models: ["all-proxy-models"], - spend: 0, - model_spend: {}, - created_at: "2024-01-01", - created_by: "user-1", - updated_at: "2024-01-01", - updated_by: "user-1", - litellm_budget_table: null, - teams: null, - users: null, - members: null, - }; - mockUseOrganization.mockReturnValue({ - data: mockOrganization, + data: createMockOrganization(["all-proxy-models"]), isLoading: false, } as any); @@ -281,16 +212,16 @@ describe("ModelSelect", () => { ); await waitFor(() => { - expect(screen.getByTestId("model-select")).toBeInTheDocument(); + expect(screen.getByText("All Proxy Models")).toBeInTheDocument(); + expect(screen.getByText("No Default Models")).toBeInTheDocument(); }); const select = screen.getByRole("listbox"); await user.selectOptions(select, ["all-proxy-models", 
"no-default-models"]); - expect(mockOnChange).toHaveBeenCalledWith(["no-default-models"]); }); - it("should disable regular models when special option is selected", async () => { + it("should disable models when special option is selected", async () => { renderWithProviders( { ); await waitFor(() => { - const gpt4Option = screen.getByRole("option", { name: "gpt-4" }); - expect(gpt4Option).toBeDisabled(); + expect(screen.getByRole("option", { name: "gpt-4" })).toBeDisabled(); + expect(screen.getByRole("option", { name: "All Openai models" })).toBeDisabled(); }); }); - it("should disable wildcard models when special option is selected", async () => { - renderWithProviders( - , - ); - - await waitFor(() => { - const openaiWildcardOption = screen.getByRole("option", { name: "All Openai models" }); - expect(openaiWildcardOption).toBeDisabled(); - }); - }); - - it("should disable other special options when one special option is selected", async () => { - const mockOrganization: Organization = { - organization_id: "org-1", - organization_alias: "Test Org", - budget_id: "budget-1", - metadata: {}, - models: ["all-proxy-models"], - spend: 0, - model_spend: {}, - created_at: "2024-01-01", - created_by: "user-1", - updated_at: "2024-01-01", - updated_by: "user-1", - litellm_budget_table: null, - teams: null, - users: null, - members: null, - }; - - mockUseOrganization.mockReturnValue({ - data: mockOrganization, - isLoading: false, - } as any); - - renderWithProviders( - , - ); - - await waitFor(() => { - const noDefaultOption = screen.getByRole("option", { name: "No Default Models" }); - expect(noDefaultOption).toBeDisabled(); - }); - }); + it("should filter models based on context", async () => { + const testCases = [ + { + name: "user context with includeUserModels", + context: "user" as const, + options: { includeUserModels: true }, + setup: () => { + mockUseCurrentUser.mockReturnValue({ + data: { models: ["gpt-4"] }, + isLoading: false, + } as any); + }, + 
expectedVisible: ["gpt-4"], + expectedHidden: ["claude-3"], + }, + { + name: "user context without includeUserModels", + context: "user" as const, + options: {}, + setup: () => { + mockUseCurrentUser.mockReturnValue({ + data: { models: ["gpt-4"] }, + isLoading: false, + } as any); + }, + expectedVisible: [], + expectedHidden: ["gpt-4", "claude-3"], + }, + { + name: "team context without organization", + context: "team" as const, + options: {}, + props: { teamID: "team-1" }, + setup: () => { + mockUseTeam.mockReturnValue({ + data: { team_id: "team-1", team_alias: "Test Team", models: [] }, + isLoading: false, + } as any); + mockUseOrganization.mockReturnValue({ + data: undefined, + isLoading: false, + } as any); + }, + expectedVisible: ["gpt-4", "claude-3"], + expectedHidden: [], + }, + { + name: "team context with organization having all-proxy-models", + context: "team" as const, + options: {}, + props: { teamID: "team-1", organizationID: "org-1" }, + setup: () => { + mockUseTeam.mockReturnValue({ + data: { team_id: "team-1", team_alias: "Test Team", models: [] }, + isLoading: false, + } as any); + mockUseOrganization.mockReturnValue({ + data: createMockOrganization(["all-proxy-models"]), + isLoading: false, + } as any); + }, + expectedVisible: ["gpt-4", "claude-3"], + expectedHidden: [], + }, + { + name: "team context with organization filtering models", + context: "team" as const, + options: {}, + props: { teamID: "team-1", organizationID: "org-1" }, + setup: () => { + mockUseTeam.mockReturnValue({ + data: { team_id: "team-1", team_alias: "Test Team", models: [] }, + isLoading: false, + } as any); + mockUseOrganization.mockReturnValue({ + data: createMockOrganization(["gpt-4"]), + isLoading: false, + } as any); + }, + expectedVisible: ["gpt-4"], + expectedHidden: ["claude-3"], + }, + { + name: "organization context", + context: "organization" as const, + options: {}, + props: { organizationID: "org-1" }, + setup: () => { + mockUseOrganization.mockReturnValue({ + 
+          data: createMockOrganization(["gpt-4"]),
+          isLoading: false,
+        } as any);
+      },
+      expectedVisible: ["gpt-4", "claude-3"],
+      expectedHidden: [],
+    },
+    {
+      name: "global context",
+      context: "global" as const,
+      options: {},
+      setup: () => { },
+      expectedVisible: ["gpt-4", "claude-3"],
+      expectedHidden: [],
+    },
+  ];
-  it("should filter models when showAllProxyModelsOverride is true", async () => {
-    renderWithProviders(
-      ,
-    );
+    for (const testCase of testCases) {
+      testCase.setup();
+      const { unmount } = renderWithProviders(
+        ,
+      );
-    await waitFor(() => {
-      expect(screen.getByText("gpt-4")).toBeInTheDocument();
-      expect(screen.getByText("claude-3")).toBeInTheDocument();
-    });
+      await waitFor(() => {
+        testCase.expectedVisible.forEach((model) => {
+          expect(screen.getByText(model)).toBeInTheDocument();
+        });
+        testCase.expectedHidden.forEach((model) => {
+          expect(screen.queryByText(model)).not.toBeInTheDocument();
+        });
+      });
+
+      unmount();
+      vi.clearAllMocks();
+      mockUseAllProxyModels.mockReturnValue({
+        data: { data: mockProxyModels },
+        isLoading: false,
+      } as any);
+    }
   });

-  it("should filter models when organization has all-proxy-models in models array", async () => {
-    const mockOrganization: Organization = {
-      organization_id: "org-1",
-      organization_alias: "Test Org",
-      budget_id: "budget-1",
-      metadata: {},
-      models: ["all-proxy-models"],
-      spend: 0,
-      model_spend: {},
-      created_at: "2024-01-01",
-      created_by: "user-1",
-      updated_at: "2024-01-01",
-      updated_by: "user-1",
-      litellm_budget_table: null,
-      teams: null,
-      users: null,
-      members: null,
-    };
-
-    mockUseOrganization.mockReturnValue({
-      data: mockOrganization,
-      isLoading: false,
-    } as any);
+  it("should show All Proxy Models option based on conditions", async () => {
+    const testCases = [
+      {
+        name: "when showAllProxyModelsOverride is true",
+        context: "user" as const,
+        options: { showAllProxyModelsOverride: true, includeSpecialOptions: true },
+        setup: () => { },
+        shouldShow: true,
+      },
+      {
+        name: "when organization has all-proxy-models",
+        context: "organization" as const,
+        options: { includeSpecialOptions: true },
+        props: { organizationID: "org-1" },
+        setup: () => {
+          mockUseOrganization.mockReturnValue({
+            data: createMockOrganization(["all-proxy-models"]),
+            isLoading: false,
+          } as any);
+        },
+        shouldShow: true,
+      },
+      {
+        name: "when organization has empty models array",
+        context: "organization" as const,
+        options: { includeSpecialOptions: true },
+        props: { organizationID: "org-1" },
+        setup: () => {
+          mockUseOrganization.mockReturnValue({
+            data: createMockOrganization([]),
+            isLoading: false,
+          } as any);
+        },
+        shouldShow: true,
+      },
+      {
+        name: "when context is global",
+        context: "global" as const,
+        options: { includeSpecialOptions: true },
+        setup: () => { },
+        shouldShow: true,
+      },
+      {
+        name: "when organization has specific models",
+        context: "organization" as const,
+        options: { includeSpecialOptions: true },
+        props: { organizationID: "org-1" },
+        setup: () => {
+          mockUseOrganization.mockReturnValue({
+            data: createMockOrganization(["gpt-4"]),
+            isLoading: false,
+          } as any);
+        },
+        shouldShow: false,
+      },
+    ];

-    renderWithProviders();
+    for (const testCase of testCases) {
+      testCase.setup();
+      const { unmount } = renderWithProviders(
+        ,
+      );

-    await waitFor(() => {
-      expect(screen.getByText("gpt-4")).toBeInTheDocument();
-      expect(screen.getByText("claude-3")).toBeInTheDocument();
-    });
+      await waitFor(() => {
+        if (testCase.shouldShow) {
+          expect(screen.getByText("All Proxy Models")).toBeInTheDocument();
+        } else {
+          expect(screen.queryByText("All Proxy Models")).not.toBeInTheDocument();
+          expect(screen.getByText("No Default Models")).toBeInTheDocument();
+        }
+      });
+
+      unmount();
+      vi.clearAllMocks();
+      mockUseAllProxyModels.mockReturnValue({
+        data: { data: mockProxyModels },
+        isLoading: false,
+      } as any);
+    }
   });

-  it("should show all models when organization context is used", async () => {
-    const mockOrganization: Organization = {
-      organization_id: "org-1",
-      organization_alias: "Test Org",
-      budget_id: "budget-1",
-      metadata: {},
-      models: ["gpt-4"],
-      spend: 0,
-      model_spend: {},
-      created_at: "2024-01-01",
-      created_by: "user-1",
-      updated_at: "2024-01-01",
-      updated_by: "user-1",
-      litellm_budget_table: null,
-      teams: null,
-      users: null,
-      members: null,
-    };
+  it("should deduplicate models with same id", async () => {
+    const duplicateModels: ProxyModel[] = [
+      { id: "gpt-4", object: "model", created: 1234567890, owned_by: "openai" },
+      { id: "gpt-4", object: "model", created: 1234567890, owned_by: "openai" },
+    ];

-    mockUseOrganization.mockReturnValue({
-      data: mockOrganization,
+    mockUseAllProxyModels.mockReturnValue({
+      data: { data: duplicateModels },
       isLoading: false,
     } as any);

-    renderWithProviders();
+    renderWithProviders(
+      ,
+    );

     await waitFor(() => {
-      expect(screen.getByText("gpt-4")).toBeInTheDocument();
-      expect(screen.getByText("claude-3")).toBeInTheDocument();
+      const gpt4Options = screen.getAllByText("gpt-4");
+      expect(gpt4Options.length).toBeGreaterThan(0);
     });
   });
@@ -452,113 +494,77 @@ describe("ModelSelect", () => {
     });
   });

-  it("should handle multiple model selections", async () => {
-    const user = userEvent.setup();
-    renderWithProviders(
-      ,
-    );
-
-    await waitFor(() => {
-      expect(screen.getByTestId("model-select")).toBeInTheDocument();
-    });
-
-    const select = screen.getByRole("listbox");
-    await user.selectOptions(select, "gpt-4");
-    expect(mockOnChange).toHaveBeenCalledWith(["gpt-4"]);
-
-    await user.selectOptions(select, "claude-3");
-    expect(mockOnChange).toHaveBeenCalled();
-    const allCalls = mockOnChange.mock.calls.map((call) => call[0]);
-    expect(allCalls.some((call) => Array.isArray(call) && call.includes("gpt-4"))).toBe(true);
-    expect(allCalls.some((call) => Array.isArray(call) && call.includes("claude-3"))).toBe(true);
-  });
-
-  it("should capitalize provider name in wildcard options", async () => {
-    renderWithProviders(
-      ,
-    );
-
-    await waitFor(() => {
-      expect(screen.getByText("All Openai models")).toBeInTheDocument();
-      expect(screen.getByText("All Anthropic models")).toBeInTheDocument();
-    });
-  });
-
-  it("should deduplicate models with same id", async () => {
-    const duplicateModels: ProxyModel[] = [
-      { id: "gpt-4", object: "model", created: 1234567890, owned_by: "openai" },
-      { id: "gpt-4", object: "model", created: 1234567890, owned_by: "openai" },
-    ];
+  it("should return all proxy models for team context when organization has empty models array", async () => {
+    mockUseTeam.mockReturnValue({
+      data: { team_id: "team-1", team_alias: "Test Team", models: [] },
+      isLoading: false,
+    } as any);

-    mockUseAllProxyModels.mockReturnValue({
-      data: { data: duplicateModels },
+    mockUseOrganization.mockReturnValue({
+      data: createMockOrganization([]),
       isLoading: false,
     } as any);

-    renderWithProviders(
-      ,
-    );
+    renderWithProviders();

     await waitFor(() => {
-      const gpt4Options = screen.getAllByText("gpt-4");
-      expect(gpt4Options.length).toBeGreaterThan(0);
+      expect(screen.getByText("gpt-4")).toBeInTheDocument();
+      expect(screen.getByText("claude-3")).toBeInTheDocument();
     });
   });

-  it("should filter models based on user context with includeUserModels option", async () => {
-    mockUseCurrentUser.mockReturnValue({
-      data: { models: ["gpt-4"] },
+  it("should disable No Default Models when all-proxy-models is selected", async () => {
+    mockUseOrganization.mockReturnValue({
+      data: createMockOrganization(["all-proxy-models"]),
       isLoading: false,
     } as any);

-    renderWithProviders();
+    renderWithProviders(
+      ,
+    );

     await waitFor(() => {
-      expect(screen.getByText("gpt-4")).toBeInTheDocument();
-      expect(screen.queryByText("claude-3")).not.toBeInTheDocument();
+      const noDefaultOption = screen.getByRole("option", { name: "No Default Models" });
+      expect(noDefaultOption).toBeDisabled();
     });
   });

-  it("should filter models based on team context", async () => {
-    const mockTeam = {
-      team_id: "team-1",
-      team_alias: "Test Team",
-      models: ["gpt-4"],
-    };
-
-    const mockOrganization: Organization = {
-      organization_id: "org-1",
-      organization_alias: "Test Org",
-      budget_id: "budget-1",
-      metadata: {},
-      models: ["gpt-4"],
-      spend: 0,
-      model_spend: {},
-      created_at: "2024-01-01",
-      created_by: "user-1",
-      updated_at: "2024-01-01",
-      updated_by: "user-1",
-      litellm_budget_table: null,
-      teams: null,
-      users: null,
-      members: null,
-    };
+  it("should render maxTagPlaceholder when many items are selected", async () => {
+    // Create many models to trigger maxTagCount responsive behavior
+    const manyModels: ProxyModel[] = Array.from({ length: 20 }, (_, i) => ({
+      id: `model-${i}`,
+      object: "model",
+      created: 1234567890,
+      owned_by: "test",
+    }));

-    mockUseTeam.mockReturnValue({
-      data: mockTeam,
+    mockUseAllProxyModels.mockReturnValue({
+      data: { data: manyModels },
       isLoading: false,
     } as any);

-    mockUseOrganization.mockReturnValue({
-      data: mockOrganization,
-      isLoading: false,
-    } as any);
+    const selectedValues = manyModels.slice(0, 10).map((m) => m.id);

-    renderWithProviders();
+    renderWithProviders(
+      ,
+    );

     await waitFor(() => {
-      expect(screen.getByText("gpt-4")).toBeInTheDocument();
-      expect(screen.queryByText("claude-3")).not.toBeInTheDocument();
+      expect(screen.getByTestId("model-select")).toBeInTheDocument();
+      // Verify maxTagPlaceholder is rendered with omitted values
+      expect(screen.getByTestId("max-tag-placeholder")).toBeInTheDocument();
+      expect(screen.getByText(/\+5 more/)).toBeInTheDocument();
     });
   });
 });
diff --git a/ui/litellm-dashboard/src/components/ModelSelect/ModelSelect.tsx b/ui/litellm-dashboard/src/components/ModelSelect/ModelSelect.tsx
index 78ccdddd81b..2b7399c4565 100644
--- a/ui/litellm-dashboard/src/components/ModelSelect/ModelSelect.tsx
+++ b/ui/litellm-dashboard/src/components/ModelSelect/ModelSelect.tsx
@@ -30,10 +30,11 @@ export interface ModelSelectProps {
     showAllProxyModelsOverride?: boolean;
     includeSpecialOptions?: boolean;
   };
-  context: "team" | "organization" | "user";
+  context: "team" | "organization" | "user" | "global";
   dataTestId?: string;
   value?: string[];
   onChange: (values: string[]) => void;
+  style?: React.CSSProperties;
 }

 type FilterContextArgs = {
@@ -65,6 +66,10 @@ const contextFilters: Record {
     return allProxyModels;
   },
+
+  global: ({ allProxyModels }) => {
+    return allProxyModels;
+  },
 };

 const filterModels = (
@@ -84,7 +89,7 @@
 };

 export const ModelSelect = (props: ModelSelectProps) => {
-  const { teamID, organizationID, options, context, dataTestId, value = [], onChange } = props;
+  const { teamID, organizationID, options, context, dataTestId, value = [], onChange, style } = props;
   const { includeUserModels, showAllTeamModelsOption, showAllProxyModelsOverride, includeSpecialOptions } =
     options || {};

   const { data: allProxyModels, isLoading: isLoadingAllProxyModels } = useAllProxyModels();
@@ -98,7 +103,7 @@ export const ModelSelect = (props: ModelSelectProps) => {
   const organizationHasAllProxyModels =
     organization?.models.includes(MODEL_SELECT_ALL_PROXY_MODELS_SPECIAL_VALUE.value) || organization?.models.length === 0;
   const shouldShowAllProxyModels = showAllProxyModelsOverride ||
-    (organizationHasAllProxyModels && includeSpecialOptions);
+    (organizationHasAllProxyModels && includeSpecialOptions) || context === "global";

   if (isLoading) {
     return ;
   }
@@ -134,6 +139,7 @@ export const ModelSelect = (props: ModelSelectProps) => {
       data-testid={dataTestId}
       value={value}
       onChange={handleChange}
+      style={style}
       options={[
         includeSpecialOptions
           ? {
diff --git a/ui/litellm-dashboard/src/components/Navbar/CommunityEngagementButtons/CommunityEngagementButtons.test.tsx b/ui/litellm-dashboard/src/components/Navbar/CommunityEngagementButtons/CommunityEngagementButtons.test.tsx
new file mode 100644
index 00000000000..6994def858b
--- /dev/null
+++ b/ui/litellm-dashboard/src/components/Navbar/CommunityEngagementButtons/CommunityEngagementButtons.test.tsx
@@ -0,0 +1,50 @@
+import { beforeEach, describe, expect, it, vi } from "vitest";
+import { renderWithProviders, screen } from "../../../../tests/test-utils";
+import { CommunityEngagementButtons } from "./CommunityEngagementButtons";
+
+let mockUseDisableShowPromptsImpl = () => false;
+
+vi.mock("@/app/(dashboard)/hooks/useDisableShowPrompts", () => ({
+  useDisableShowPrompts: () => mockUseDisableShowPromptsImpl(),
+}));
+
+describe("CommunityEngagementButtons", () => {
+  beforeEach(() => {
+    vi.clearAllMocks();
+    mockUseDisableShowPromptsImpl = () => false;
+  });
+
+  it("should render", () => {
+    renderWithProviders();
+    expect(screen.getByRole("link", { name: /join slack/i })).toBeInTheDocument();
+  });
+
+  it("should render Join Slack button with correct link", () => {
+    renderWithProviders();
+
+    const joinSlackLink = screen.getByRole("link", { name: /join slack/i });
+    expect(joinSlackLink).toBeInTheDocument();
+    expect(joinSlackLink).toHaveAttribute("href", "https://www.litellm.ai/support");
+    expect(joinSlackLink).toHaveAttribute("target", "_blank");
+    expect(joinSlackLink).toHaveAttribute("rel", "noopener noreferrer");
+  });
+
+  it("should render Star us on GitHub button with correct link", () => {
+    renderWithProviders();
+
+    const starOnGithubLink = screen.getByRole("link", { name: /star us on github/i });
+    expect(starOnGithubLink).toBeInTheDocument();
+    expect(starOnGithubLink).toHaveAttribute("href", "https://github.com/BerriAI/litellm");
+    expect(starOnGithubLink).toHaveAttribute("target", "_blank");
+    expect(starOnGithubLink).toHaveAttribute("rel", "noopener noreferrer");
+  });
+
+  it("should not render buttons when prompts are disabled", () => {
+    mockUseDisableShowPromptsImpl = () => true;
+
+    renderWithProviders();
+
+    expect(screen.queryByRole("link", { name: /join slack/i })).not.toBeInTheDocument();
+    expect(screen.queryByRole("link", { name: /star us on github/i })).not.toBeInTheDocument();
+  });
+});
diff --git a/ui/litellm-dashboard/src/components/Navbar/CommunityEngagementButtons/CommunityEngagementButtons.tsx b/ui/litellm-dashboard/src/components/Navbar/CommunityEngagementButtons/CommunityEngagementButtons.tsx
new file mode 100644
index 00000000000..649bcc0b589
--- /dev/null
+++ b/ui/litellm-dashboard/src/components/Navbar/CommunityEngagementButtons/CommunityEngagementButtons.tsx
@@ -0,0 +1,36 @@
+import { useDisableShowPrompts } from "@/app/(dashboard)/hooks/useDisableShowPrompts";
+import { GithubOutlined, SlackOutlined } from "@ant-design/icons";
+import { Button } from "antd";
+import React from "react";
+
+export const CommunityEngagementButtons: React.FC = () => {
+  const disableShowPrompts = useDisableShowPrompts();
+
+  // Hide buttons if prompts are disabled
+  if (disableShowPrompts) {
+    return null;
+  }
+
+  return (
+    <>
+
+
+
+  );
+};
diff --git a/ui/litellm-dashboard/src/components/Settings/AdminSettings/SSOSettings/Modals/BaseSSOSettingsForm.test.tsx b/ui/litellm-dashboard/src/components/Settings/AdminSettings/SSOSettings/Modals/BaseSSOSettingsForm.test.tsx
index a885bffa710..c68e2716f5b 100644
--- a/ui/litellm-dashboard/src/components/Settings/AdminSettings/SSOSettings/Modals/BaseSSOSettingsForm.test.tsx
+++ b/ui/litellm-dashboard/src/components/Settings/AdminSettings/SSOSettings/Modals/BaseSSOSettingsForm.test.tsx
@@ -151,6 +151,117 @@ describe("BaseSSOSettingsForm", () => {
       expect(screen.getByText("Default Role")).toBeInTheDocument();
     });
   });
+
+  it("should show team mappings checkbox for okta provider", async () => {
+    const TestWrapper = () => {
+      const [form] = Form.useForm();
+      const handleSubmit = vi.fn();
+
+      return ;
+    };
+
+    renderWithProviders();
+
+    const providerSelect = screen.getByLabelText("SSO Provider");
+    await act(async () => {
+      fireEvent.mouseDown(providerSelect);
+    });
+
+    await waitFor(() => {
+      const oktaOption = screen.getByText(/okta/i);
+      fireEvent.click(oktaOption);
+    });
+
+    await waitFor(() => {
+      expect(screen.getByText("Use Team Mappings")).toBeInTheDocument();
+    });
+  });
+
+  it("should show team mappings checkbox for generic provider", async () => {
+    const TestWrapper = () => {
+      const [form] = Form.useForm();
+      const handleSubmit = vi.fn();
+
+      return ;
+    };
+
+    renderWithProviders();
+
+    const providerSelect = screen.getByLabelText("SSO Provider");
+    await act(async () => {
+      fireEvent.mouseDown(providerSelect);
+    });
+
+    await waitFor(() => {
+      const genericOption = screen.getByText(/generic sso/i);
+      fireEvent.click(genericOption);
+    });
+
+    await waitFor(() => {
+      expect(screen.getByText("Use Team Mappings")).toBeInTheDocument();
+    });
+  });
+
+  it("should show team IDs JWT field when use_team_mappings is checked for okta provider", async () => {
+    const TestWrapper = () => {
+      const [form] = Form.useForm();
+      const handleSubmit = vi.fn();
+
+      return ;
+    };
+
+    renderWithProviders();
+
+    const providerSelect = screen.getByLabelText("SSO Provider");
+    await act(async () => {
+      fireEvent.mouseDown(providerSelect);
+    });
+
+    await waitFor(() => {
+      const oktaOption = screen.getByText(/okta/i);
+      fireEvent.click(oktaOption);
+    });
+
+    await waitFor(() => {
+      expect(screen.getByText("Use Team Mappings")).toBeInTheDocument();
+    });
+
+    const checkbox = screen.getByLabelText("Use Team Mappings");
+    await act(async () => {
+      fireEvent.click(checkbox);
+    });
+
+    await waitFor(() => {
+      expect(screen.getByText("Team IDs JWT Field")).toBeInTheDocument();
+    });
+  });
+
+  it("should not show team mappings checkbox for google provider", async () => {
+    const TestWrapper = () => {
+      const [form] = Form.useForm();
+      const handleSubmit = vi.fn();
+
+      return ;
+    };
+
+    renderWithProviders();
+
+    const providerSelect = screen.getByLabelText("SSO Provider");
+    await act(async () => {
+      fireEvent.mouseDown(providerSelect);
+    });
+
+    await waitFor(() => {
+      const googleOption = screen.getByText(/google sso/i);
+      fireEvent.click(googleOption);
+    });
+
+    await waitFor(() => {
+      expect(screen.getByText("Google Client ID")).toBeInTheDocument();
+    });
+
+    expect(screen.queryByText("Use Team Mappings")).not.toBeInTheDocument();
+  });
 });

 describe("renderProviderFields", () => {
diff --git a/ui/litellm-dashboard/src/components/Settings/AdminSettings/SSOSettings/Modals/BaseSSOSettingsForm.tsx b/ui/litellm-dashboard/src/components/Settings/AdminSettings/SSOSettings/Modals/BaseSSOSettingsForm.tsx
index 6431b2dd3ac..d16b04466e0 100644
--- a/ui/litellm-dashboard/src/components/Settings/AdminSettings/SSOSettings/Modals/BaseSSOSettingsForm.tsx
+++ b/ui/litellm-dashboard/src/components/Settings/AdminSettings/SSOSettings/Modals/BaseSSOSettingsForm.tsx
@@ -251,6 +251,43 @@ const BaseSSOSettingsForm: React.FC = ({ form, onFormS
           ) : null;
         }}
+
+
+        prevValues.sso_provider !== currentValues.sso_provider}
+      >
+        {({ getFieldValue }) => {
+          const provider = getFieldValue("sso_provider");
+          return provider === "okta" || provider === "generic" ? (
+
+
+
+          ) : null;
+        }}
+
+
+
+          prevValues.use_team_mappings !== currentValues.use_team_mappings ||
+          prevValues.sso_provider !== currentValues.sso_provider
+        }
+      >
+        {({ getFieldValue }) => {
+          const useTeamMappings = getFieldValue("use_team_mappings");
+          const provider = getFieldValue("sso_provider");
+          const supportsTeamMappings = provider === "okta" || provider === "generic";
+          return useTeamMappings && supportsTeamMappings ? (
+
+
+
+          ) : null;
+        }}
+
+
     );
diff --git a/ui/litellm-dashboard/src/components/Settings/AdminSettings/SSOSettings/Modals/DeleteSSOSettingsModal.tsx b/ui/litellm-dashboard/src/components/Settings/AdminSettings/SSOSettings/Modals/DeleteSSOSettingsModal.tsx
index 44cbf0020eb..2656c861aa8 100644
--- a/ui/litellm-dashboard/src/components/Settings/AdminSettings/SSOSettings/Modals/DeleteSSOSettingsModal.tsx
+++ b/ui/litellm-dashboard/src/components/Settings/AdminSettings/SSOSettings/Modals/DeleteSSOSettingsModal.tsx
@@ -33,6 +33,7 @@ const DeleteSSOSettingsModal: React.FC = ({ isVisib
       user_email: null,
       sso_provider: null,
       role_mappings: null,
+      team_mappings: null,
     };

     await editSSOSettings(clearSettings, {
diff --git a/ui/litellm-dashboard/src/components/Settings/AdminSettings/SSOSettings/Modals/EditSSOSettingsModal.test.tsx b/ui/litellm-dashboard/src/components/Settings/AdminSettings/SSOSettings/Modals/EditSSOSettingsModal.test.tsx
index 559d837b409..d2d54033395 100644
--- a/ui/litellm-dashboard/src/components/Settings/AdminSettings/SSOSettings/Modals/EditSSOSettingsModal.test.tsx
+++ b/ui/litellm-dashboard/src/components/Settings/AdminSettings/SSOSettings/Modals/EditSSOSettingsModal.test.tsx
@@ -105,6 +105,14 @@ const createRoleMappingsSSOData = (overrides: Record = {}) =>
     ...overrides,
   });

+const createTeamMappingsSSOData = (overrides: Record = {}) =>
+  createGenericSSOData({
+    team_mappings: {
+      team_ids_jwt_field: overrides.team_ids_jwt_field || "teams",
+    },
+    ...overrides,
+  });
+
 // Mock utilities
 const createMockHooks = (): {
   useSSOSettings: SSOSettingsHookReturn;
@@ -577,6 +585,104 @@ describe("EditSSOSettingsModal", () => {
         });
       });
     });
+  });
+
+  describe("Team Mappings", () => {
+    it("processes team mappings when team_mappings exists", async () => {
+      const ssoData = createTeamMappingsSSOData();
+
+      setupMocks({
+        useSSOSettings: { data: ssoData, isLoading: false, error: null },
+      });
+
+      renderComponent();
+
+      await waitFor(() => {
+        expect(mockForm.setFieldsValue).toHaveBeenCalledWith({
+          sso_provider: SSO_PROVIDERS.GENERIC,
+          ...ssoData.values,
+          use_team_mappings: true,
+          team_ids_jwt_field: "teams",
+        });
+      });
+    });
+
+    it("handles team mappings with custom JWT field name", async () => {
+      const ssoData = createTeamMappingsSSOData({
+        team_ids_jwt_field: "custom_teams_field",
+      });
+
+      setupMocks({
+        useSSOSettings: { data: ssoData, isLoading: false, error: null },
+      });
+
+      renderComponent();
+
+      await waitFor(() => {
+        expect(mockForm.setFieldsValue).toHaveBeenCalledWith({
+          sso_provider: SSO_PROVIDERS.GENERIC,
+          ...ssoData.values,
+          use_team_mappings: true,
+          team_ids_jwt_field: "custom_teams_field",
+        });
+      });
+    });
+
+    it("handles team mappings and role mappings together", async () => {
+      const ssoData = createGenericSSOData({
+        role_mappings: {
+          group_claim: "groups",
+          default_role: "internal_user",
+          roles: {
+            proxy_admin: ["admin-group"],
+            proxy_admin_viewer: [],
+            internal_user: [],
+            internal_user_viewer: [],
+          },
+        },
+        team_mappings: {
+          team_ids_jwt_field: "teams",
+        },
+      });
+
+      setupMocks({
+        useSSOSettings: { data: ssoData, isLoading: false, error: null },
+      });
+
+      renderComponent();
+
+      await waitFor(() => {
+        expect(mockForm.setFieldsValue).toHaveBeenCalledWith({
+          sso_provider: SSO_PROVIDERS.GENERIC,
+          ...ssoData.values,
+          use_role_mappings: true,
+          group_claim: "groups",
+          default_role: "internal_user",
+          proxy_admin_teams: "admin-group",
+          admin_viewer_teams: "",
+          internal_user_teams: "",
+          internal_viewer_teams: "",
+          use_team_mappings: true,
+          team_ids_jwt_field: "teams",
+        });
+      });
+    });
+
+    it("does not set team mapping fields when team_mappings is not present", async () => {
+      const ssoData = createGenericSSOData();
+
+      setupMocks({
+        useSSOSettings: { data: ssoData, isLoading: false, error: null },
+      });
+
+      renderComponent();
+
+      await waitFor(() => {
+        const callArgs = mockForm.setFieldsValue.mock.calls[0][0];
+        expect(callArgs.use_team_mappings).toBeUndefined();
+        expect(callArgs.team_ids_jwt_field).toBeUndefined();
+      });
+    });

     it("handles provider detection with partial SSO data", async () => {
       const ssoData = createSSOData({
diff --git a/ui/litellm-dashboard/src/components/Settings/AdminSettings/SSOSettings/Modals/EditSSOSettingsModal.tsx b/ui/litellm-dashboard/src/components/Settings/AdminSettings/SSOSettings/Modals/EditSSOSettingsModal.tsx
index a731af68ff1..bbae8f1451a 100644
--- a/ui/litellm-dashboard/src/components/Settings/AdminSettings/SSOSettings/Modals/EditSSOSettingsModal.tsx
+++ b/ui/litellm-dashboard/src/components/Settings/AdminSettings/SSOSettings/Modals/EditSSOSettingsModal.tsx
@@ -68,11 +68,22 @@ const EditSSOSettingsModal: React.FC = ({ isVisible,
         };
       }

+      // Extract team mappings if they exist
+      let teamMappingFields = {};
+      if (ssoData.values.team_mappings) {
+        const teamMappings = ssoData.values.team_mappings;
+        teamMappingFields = {
+          use_team_mappings: true,
+          team_ids_jwt_field: teamMappings.team_ids_jwt_field,
+        };
+      }
+
       // Set form values with existing data (excluding UI access control fields)
       const formValues = {
         sso_provider: selectedProvider,
         ...ssoData.values,
         ...roleMappingFields,
+        ...teamMappingFields,
       };

       console.log("Setting form values:", formValues); // Debug log
diff --git a/ui/litellm-dashboard/src/components/Settings/AdminSettings/SSOSettings/SSOSettings.tsx b/ui/litellm-dashboard/src/components/Settings/AdminSettings/SSOSettings/SSOSettings.tsx
index adc1251cde2..fdeda0ece3e 100644
--- a/ui/litellm-dashboard/src/components/Settings/AdminSettings/SSOSettings/SSOSettings.tsx
+++ b/ui/litellm-dashboard/src/components/Settings/AdminSettings/SSOSettings/SSOSettings.tsx
@@ -1,7 +1,7 @@
 "use client";

 import { useSSOSettings, type SSOSettingsValues } from "@/app/(dashboard)/hooks/sso/useSSOSettings";
-import { Button, Card, Descriptions, Space, Typography } from "antd";
+import { Button, Card, Descriptions, Space, Tag, Typography } from "antd";
 import { Edit, Shield, Trash2 } from "lucide-react";
 import { useState } from "react";

 import { ssoProviderDisplayNames, ssoProviderLogoMap } from "./constants";
@@ -28,6 +28,7 @@ export default function SSOSettings() {
   const selectedProvider = ssoSettings?.values ? detectSSOProvider(ssoSettings.values) : null;

   const isRoleMappingsEnabled = Boolean(ssoSettings?.values.role_mappings);
+  const isTeamMappingsEnabled = Boolean(ssoSettings?.values.team_mappings);

   const renderEndpointValue = (value?: string | null) => (
@@ -38,6 +39,15 @@ export default function SSOSettings() {
   const renderSimpleValue = (value?: string | null) =>
     value ? value : Not configured;

+  const renderTeamMappingsField = (values: SSOSettingsValues) => {
+    if (!values.team_mappings?.team_ids_jwt_field) {
+      return Not configured;
+    }
+    return (
+      {values.team_mappings.team_ids_jwt_field}
+    );
+  };
+
   const descriptionsConfig = {
     column: {
       xxl: 1,
@@ -103,6 +113,10 @@ export default function SSOSettings() {
         render: (values: SSOSettingsValues) => renderEndpointValue(values.generic_userinfo_endpoint),
       },
       { label: "Proxy Base URL", render: (values: SSOSettingsValues) => renderSimpleValue(values.proxy_base_url) },
+      isTeamMappingsEnabled ? {
+        label: "Team IDs JWT Field",
+        render: (values: SSOSettingsValues) => renderTeamMappingsField(values),
+      } : null,
     ],
   },
   generic: {
@@ -129,6 +143,10 @@ export default function SSOSettings() {
         render: (values: SSOSettingsValues) => renderEndpointValue(values.generic_userinfo_endpoint),
       },
       { label: "Proxy Base URL", render: (values: SSOSettingsValues) => renderSimpleValue(values.proxy_base_url) },
+      isTeamMappingsEnabled ? {
+        label: "Team IDs JWT Field",
+        render: (values: SSOSettingsValues) => renderTeamMappingsField(values),
+      } : null,
     ],
   },
 };
@@ -155,7 +173,7 @@ export default function SSOSettings() {
             {config.providerText}

-          {config.fields.map((field, index) => (
+          {config.fields.map((field, index) => field && (
               {field.render(values)}
diff --git a/ui/litellm-dashboard/src/components/Settings/AdminSettings/SSOSettings/utils.test.ts b/ui/litellm-dashboard/src/components/Settings/AdminSettings/SSOSettings/utils.test.ts
index 718302d35fe..722d52d64f9 100644
--- a/ui/litellm-dashboard/src/components/Settings/AdminSettings/SSOSettings/utils.test.ts
+++ b/ui/litellm-dashboard/src/components/Settings/AdminSettings/SSOSettings/utils.test.ts
@@ -12,6 +12,8 @@ describe("processSSOSettingsPayload", () => {
       default_role: "proxy_admin",
       group_claim: "groups",
       use_role_mappings: false,
+      use_team_mappings: false,
+      team_ids_jwt_field: "teams",
       other_field: "value",
       another_field: 123,
     };
@@ -23,6 +25,7 @@
       another_field: 123,
     });
     expect(result.role_mappings).toBeUndefined();
+    expect(result.team_mappings).toBeUndefined();
   });

   it("should return all fields except role mapping fields when use_role_mappings is not present", () => {
@@ -33,6 +36,8 @@
       internal_viewer_teams: "viewer1",
       default_role: "proxy_admin",
       group_claim: "groups",
+      use_team_mappings: false,
+      team_ids_jwt_field: "teams",
       other_field: "value",
     };
@@ -42,6 +47,7 @@
       other_field: "value",
     });
     expect(result.role_mappings).toBeUndefined();
+    expect(result.team_mappings).toBeUndefined();
   });
 });
@@ -253,6 +259,143 @@
     });
   });

+  describe("without team mappings", () => {
+    it("should return all fields except team mapping fields when use_team_mappings is false", () => {
+      const formValues = {
+        use_team_mappings: false,
+        team_ids_jwt_field: "teams",
+        sso_provider: "okta",
+        other_field: "value",
+      };
+
+      const result = processSSOSettingsPayload(formValues);
+
+      expect(result).toEqual({
+        sso_provider: "okta",
+        other_field: "value",
+      });
+      expect(result.team_mappings).toBeUndefined();
+    });
+
+    it("should return all fields except team mapping fields when use_team_mappings is not present", () => {
+      const formValues = {
+        team_ids_jwt_field: "teams",
+        sso_provider: "generic",
+        other_field: "value",
+      };
+
+      const result = processSSOSettingsPayload(formValues);
+
+      expect(result).toEqual({
+        sso_provider: "generic",
+        other_field: "value",
+      });
+      expect(result.team_mappings).toBeUndefined();
+    });
+
+    it("should not include team mappings for unsupported providers even when use_team_mappings is true", () => {
+      const formValues = {
+        use_team_mappings: true,
+        team_ids_jwt_field: "teams",
+        sso_provider: "google",
+        other_field: "value",
+      };
+
+      const result = processSSOSettingsPayload(formValues);
+
+      expect(result).toEqual({
+        sso_provider: "google",
+        other_field: "value",
+      });
+      expect(result.team_mappings).toBeUndefined();
+    });
+
+    it("should not include team mappings for microsoft provider even when use_team_mappings is true", () => {
+      const formValues = {
+        use_team_mappings: true,
+        team_ids_jwt_field: "teams",
+        sso_provider: "microsoft",
+        other_field: "value",
+      };
+
+      const result = processSSOSettingsPayload(formValues);
+
+      expect(result).toEqual({
+        sso_provider: "microsoft",
+        other_field: "value",
+      });
+      expect(result.team_mappings).toBeUndefined();
+    });
+  });
+
+  describe("with team mappings enabled", () => {
+    it("should create team mappings for okta provider when use_team_mappings is true", () => {
+      const formValues = {
+        use_team_mappings: true,
+        team_ids_jwt_field: "teams",
+        sso_provider: "okta",
+        other_field: "value",
+      };
+
+      const result = processSSOSettingsPayload(formValues);
+
+      expect(result.other_field).toBe("value");
+      expect(result.team_mappings).toEqual({
+        team_ids_jwt_field: "teams",
+      });
+    });
+
+    it("should create team mappings for generic provider when use_team_mappings is true", () => {
+      const formValues = {
+        use_team_mappings: true,
+        team_ids_jwt_field: "custom_teams",
+        sso_provider: "generic",
+        other_field: "value",
+      };
+
+      const result = processSSOSettingsPayload(formValues);
+
+      expect(result.other_field).toBe("value");
+      expect(result.team_mappings).toEqual({
+        team_ids_jwt_field: "custom_teams",
+      });
+    });
+
+    it("should exclude team mapping fields from payload when team mappings are included", () => {
+      const formValues = {
+        use_team_mappings: true,
+        team_ids_jwt_field: "teams",
+        sso_provider: "okta",
+        other_field: "value",
+      };
+
+      const result = processSSOSettingsPayload(formValues);
+
+      expect(result.use_team_mappings).toBeUndefined();
+      expect(result.team_ids_jwt_field).toBeUndefined();
+    });
+
+    it("should handle team mappings and role mappings together", () => {
+      const formValues = {
+        use_team_mappings: true,
+        team_ids_jwt_field: "teams",
+        use_role_mappings: true,
+        group_claim: "groups",
+        default_role: "internal_user",
+        sso_provider: "okta",
+        other_field: "value",
+      };
+
+      const result = processSSOSettingsPayload(formValues);
+
+      expect(result.team_mappings).toEqual({
+        team_ids_jwt_field: "teams",
+      });
+      expect(result.role_mappings).toBeDefined();
+      expect(result.role_mappings.group_claim).toBe("groups");
+    });
+  });
+
   describe("edge cases", () => {
     it("should handle empty form values", () => {
       const result = processSSOSettingsPayload({});
@@ -263,6 +406,7 @@
     it("should preserve other fields in the payload", () => {
       const formValues = {
         use_role_mappings: false,
+        use_team_mappings: false,
         sso_provider: "google",
         client_id: "123",
         client_secret: "secret",
diff --git a/ui/litellm-dashboard/src/components/Settings/AdminSettings/SSOSettings/utils.ts b/ui/litellm-dashboard/src/components/Settings/AdminSettings/SSOSettings/utils.ts
index c199048df3e..948ed4d2bfe 100644
--- a/ui/litellm-dashboard/src/components/Settings/AdminSettings/SSOSettings/utils.ts
+++ b/ui/litellm-dashboard/src/components/Settings/AdminSettings/SSOSettings/utils.ts
@@ -13,6 +13,8 @@ export const processSSOSettingsPayload = (formValues: Record): Reco
     default_role,
     group_claim,
     use_role_mappings,
+    use_team_mappings,
+    team_ids_jwt_field,
     ...rest
   } = formValues;
@@ -21,7 +23,9 @@
   };

   // Add role mappings only if use_role_mappings is checked AND provider supports role mappings
-  if (use_role_mappings) {
+  const provider = rest.sso_provider;
+  const supportsRoleMappings = provider === "okta" || provider === "generic";
+  if (use_role_mappings && supportsRoleMappings) {
     // Helper function to split comma-separated string into array
     const splitTeams = (teams: string | undefined): string[] => {
       if (!teams || teams.trim() === "") return [];
@@ -52,6 +56,14 @@
     };
   }

+  // Add team mappings only if use_team_mappings is checked AND provider supports team mappings
+  const supportsTeamMappings = provider === "okta" || provider === "generic";
+  if (use_team_mappings && supportsTeamMappings) {
+    payload.team_mappings = {
+      team_ids_jwt_field: team_ids_jwt_field,
+    };
+  }
+
   return payload;
 };
diff --git a/ui/litellm-dashboard/src/components/TeamSSOSettings.test.tsx b/ui/litellm-dashboard/src/components/TeamSSOSettings.test.tsx
index f5e43fc3d5f..34085df8f10 100644
--- a/ui/litellm-dashboard/src/components/TeamSSOSettings.test.tsx
+++ b/ui/litellm-dashboard/src/components/TeamSSOSettings.test.tsx
@@ -1,63 +1,653 @@
-import { screen } from "@testing-library/react";
+import React from "react";
+import { screen, waitFor } from "@testing-library/react";
+import userEvent from "@testing-library/user-event";
 import { beforeEach, describe, expect, it, vi } from "vitest";
 import { renderWithProviders } from "../../tests/test-utils";
 import TeamSSOSettings from "./TeamSSOSettings";
 import * as networking from "./networking";
+import NotificationsManager from "./molecules/notifications_manager";

-// Mock the networking functions
 vi.mock("./networking");

-// Mock the budget duration dropdown
+vi.mock("@tremor/react", async (importOriginal) => {
+  const actual = await importOriginal();
+  const React = await import("react");
+  return {
+    ...actual,
+    Card: ({ children }: { children: React.ReactNode }) => React.createElement("div", { "data-testid": "card" }, children),
+    Title: ({ children }: { children: React.ReactNode }) => React.createElement("h2", {}, children),
+    Text: ({ children }: { children: React.ReactNode }) => React.createElement("span", {}, children),
+    Divider: () => React.createElement("hr", {}),
+    TextInput: ({ value, onChange, placeholder, className }: any) =>
+      React.createElement("input", {
+        type: "text",
+        value: value || "",
+        onChange,
+        placeholder,
+        className,
+      }),
+  };
+});
+
 vi.mock("./common_components/budget_duration_dropdown", () => ({
   default: ({ value, onChange }: { value: string | null; onChange: (value: string) => void }) => (
-
     onChange(e.target.value)}
+      aria-label="Budget duration"
+    >
   ),
-  getBudgetDurationLabel: vi.fn((value: string) => value),
+  getBudgetDurationLabel: vi.fn((value: string) => `Budget: ${value}`),
 }));

-// Mock the model display name helper
 vi.mock("./key_team_helpers/fetch_available_models_team_key", () => ({
   getModelDisplayName: vi.fn((model: string) => model),
 }));

+vi.mock("./ModelSelect/ModelSelect", () => ({
+  ModelSelect: ({ value, onChange }: { value: string[]; onChange: (value: string[]) => void }) => (
+
+  ),
+}));
+
+vi.mock("antd", async (importOriginal) => {
+  const actual = await importOriginal();
+  const React = await import("react");
+  const SelectComponent = ({
+    value,
+    onChange,
+    mode,
+    children,
+    className,
+  }: {
+    value: any;
+    onChange: (value: any) => void;
+    mode?: string;
+    children: React.ReactNode;
+    className?: string;
+  }) => {
+    const isMultiple = mode === "multiple";
+    const selectValue = isMultiple ? (Array.isArray(value) ? value : []) : value || "";
+    return React.createElement(
+      "select",
+      {
+        multiple: isMultiple,
+        value: selectValue,
+        onChange: (e: React.ChangeEvent) => {
+          const selectedValues = Array.from(e.target.selectedOptions, (option) => option.value);
+          onChange(isMultiple ? selectedValues : selectedValues[0] || undefined);
+        },
+        className,
+        "aria-label": "Select",
+        role: "listbox",
+      },
+      children,
+    );
+  };
+  SelectComponent.Option = ({ value: optionValue, children: optionChildren }: { value: string; children: React.ReactNode }) =>
+    React.createElement("option", { value: optionValue }, optionChildren);
+  return {
+    ...actual,
+    Spin: ({ size }: { size?: string }) => React.createElement("div", { "data-testid": "spinner", "data-size": size }),
+    Switch: ({ checked, onChange }: { checked: boolean; onChange: (checked: boolean) => void }) =>
+      React.createElement("input", {
+        type: "checkbox",
+        role: "switch",
+        checked: checked,
+        onChange: (e) => onChange(e.target.checked),
+        "aria-label": "Toggle switch",
+      }),
+    Select: SelectComponent,
+    Typography: {
+      Paragraph: ({ children }: { children: React.ReactNode }) => React.createElement("p", {}, children),
+    },
+  };
+});
+
+const mockGetDefaultTeamSettings = vi.mocked(networking.getDefaultTeamSettings);
+const mockUpdateDefaultTeamSettings = vi.mocked(networking.updateDefaultTeamSettings);
+const mockModelAvailableCall = vi.mocked(networking.modelAvailableCall);
+const mockNotificationsManager = vi.mocked(NotificationsManager);
+
 describe("TeamSSOSettings", () => {
+  const defaultProps = {
+    accessToken: "test-token",
+    userID: "test-user",
+    userRole: "admin",
+  };
+
+  const mockSettings = {
+    values: {
+      budget_duration: "monthly",
+      max_budget: 1000,
+      enabled: true,
+      allowed_models: ["gpt-4", "claude-3"],
+      models: ["gpt-4"],
+      status: "active",
+    },
+    field_schema: {
+      description: "Default team settings schema",
+      properties: {
+        budget_duration: {
+          type: "string",
+          description: "Budget duration setting",
+        },
+        max_budget: {
+          type: "number",
+          description: "Maximum budget amount",
+        },
+        enabled: {
+          type: "boolean",
+          description: "Enable feature",
+        },
+        allowed_models: {
+          type: "array",
+          items: {
+            enum: ["gpt-4", "claude-3", "gpt-3.5-turbo"],
+          },
+          description: "Allowed models",
+        },
+        models: {
+          type: "array",
+          description: "Selected models",
+        },
+        status: {
+          type: "string",
+          enum: ["active", "inactive", "pending"],
+          description: "Status",
+        },
+      },
+    },
+  };
+
   beforeEach(() => {
     vi.clearAllMocks();
+    mockModelAvailableCall.mockResolvedValue({
+      data: [{ id: "gpt-4" }, { id: "claude-3" }],
+    });
+  });
+
+  it("should render", async () => {
+    mockGetDefaultTeamSettings.mockResolvedValue(mockSettings);
+
+    renderWithProviders();
+
+    await waitFor(() => {
+      expect(screen.getByText("Default Team Settings")).toBeInTheDocument();
+    });
+  });
+
+  it("should show loading spinner while fetching settings", () => {
+    mockGetDefaultTeamSettings.mockImplementation(() => new Promise(() => { }));
+
+    renderWithProviders();
+
+    expect(screen.getByTestId("spinner")).toBeInTheDocument();
+  });
+
+  it("should display message when no settings are available", async () => {
+    mockGetDefaultTeamSettings.mockResolvedValue(null as any);
+
+    renderWithProviders();
+
+    await waitFor(() => {
+      expect(
+        screen.getByText("No team settings available or you do not have permission to view them."),
+      ).toBeInTheDocument();
+    });
+  });
+
+  it("should not fetch settings when access token is null", async () => {
+    renderWithProviders();
+
+    await waitFor(() => {
+      expect(mockGetDefaultTeamSettings).not.toHaveBeenCalled();
+    });
   });

-  it("renders the component", async () => {
-    // Mock successful API responses
-    vi.mocked(networking.getDefaultTeamSettings).mockResolvedValue({
+  it("should display settings fields with correct
values", async () => { + mockGetDefaultTeamSettings.mockResolvedValue(mockSettings); + + renderWithProviders(); + + await waitFor(() => { + expect(screen.getByText("Budget Duration")).toBeInTheDocument(); + expect(screen.getByText("Max Budget")).toBeInTheDocument(); + }); + + expect(screen.getByText("Budget: monthly")).toBeInTheDocument(); + expect(screen.getByText("1000")).toBeInTheDocument(); + const enabledTexts = screen.getAllByText("Enabled"); + expect(enabledTexts.length).toBeGreaterThan(0); + }); + + it("should display 'Not set' for null values", async () => { + const settingsWithNulls = { + ...mockSettings, values: { - budget_duration: "monthly", - max_budget: 1000, + ...mockSettings.values, + max_budget: null, }, + }; + mockGetDefaultTeamSettings.mockResolvedValue(settingsWithNulls); + + renderWithProviders(); + + await waitFor(() => { + expect(screen.getByText("Not set")).toBeInTheDocument(); + }); + }); + + it("should toggle edit mode when edit button is clicked", async () => { + mockGetDefaultTeamSettings.mockResolvedValue(mockSettings); + + renderWithProviders(); + + await waitFor(() => { + expect(screen.getByRole("button", { name: "Edit Settings" })).toBeInTheDocument(); + }); + + const editButton = screen.getByRole("button", { name: "Edit Settings" }); + await userEvent.click(editButton); + + expect(screen.getByRole("button", { name: "Cancel" })).toBeInTheDocument(); + expect(screen.getByRole("button", { name: "Save Changes" })).toBeInTheDocument(); + expect(screen.queryByRole("button", { name: "Edit Settings" })).not.toBeInTheDocument(); + }); + + it("should cancel edit mode and reset values", async () => { + mockGetDefaultTeamSettings.mockResolvedValue(mockSettings); + + renderWithProviders(); + + await waitFor(() => { + expect(screen.getByRole("button", { name: "Edit Settings" })).toBeInTheDocument(); + }); + + const editButton = screen.getByRole("button", { name: "Edit Settings" }); + await userEvent.click(editButton); + + const cancelButton = 
screen.getByRole("button", { name: "Cancel" }); + await userEvent.click(cancelButton); + + expect(screen.getByRole("button", { name: "Edit Settings" })).toBeInTheDocument(); + expect(screen.queryByRole("button", { name: "Cancel" })).not.toBeInTheDocument(); + }); + + it("should save settings when save button is clicked", async () => { + mockGetDefaultTeamSettings.mockResolvedValue(mockSettings); + mockUpdateDefaultTeamSettings.mockResolvedValue({ + settings: mockSettings.values, + }); + + renderWithProviders(); + + await waitFor(() => { + expect(screen.getByRole("button", { name: "Edit Settings" })).toBeInTheDocument(); + }); + + const editButton = screen.getByRole("button", { name: "Edit Settings" }); + await userEvent.click(editButton); + + await waitFor(() => { + expect(screen.getByRole("button", { name: "Save Changes" })).toBeInTheDocument(); + }); + + const saveButton = screen.getByRole("button", { name: "Save Changes" }); + await userEvent.click(saveButton); + + await waitFor(() => { + expect(mockUpdateDefaultTeamSettings).toHaveBeenCalledWith("test-token", mockSettings.values); + }); + + expect(mockNotificationsManager.success).toHaveBeenCalledWith("Default team settings updated successfully"); + }); + + it("should show error notification when save fails", async () => { + mockGetDefaultTeamSettings.mockResolvedValue(mockSettings); + mockUpdateDefaultTeamSettings.mockRejectedValue(new Error("Save failed")); + + renderWithProviders(); + + await waitFor(() => { + expect(screen.getByRole("button", { name: "Edit Settings" })).toBeInTheDocument(); + }); + + const editButton = screen.getByRole("button", { name: "Edit Settings" }); + await userEvent.click(editButton); + + await waitFor(() => { + expect(screen.getByRole("button", { name: "Save Changes" })).toBeInTheDocument(); + }); + + const saveButton = screen.getByRole("button", { name: "Save Changes" }); + await userEvent.click(saveButton); + + await waitFor(() => { + 
expect(mockNotificationsManager.fromBackend).toHaveBeenCalledWith("Failed to update team settings"); + }); + }); + + it("should render boolean field as switch in edit mode", async () => { + mockGetDefaultTeamSettings.mockResolvedValue(mockSettings); + + renderWithProviders(); + + await waitFor(() => { + expect(screen.getByRole("button", { name: "Edit Settings" })).toBeInTheDocument(); + }); + + const editButton = screen.getByRole("button", { name: "Edit Settings" }); + await userEvent.click(editButton); + + await waitFor(() => { + const switchElement = screen.getByRole("switch"); + expect(switchElement).toBeInTheDocument(); + expect(switchElement).toBeChecked(); + }); + }); + + it("should update boolean value when switch is toggled", async () => { + mockGetDefaultTeamSettings.mockResolvedValue(mockSettings); + + renderWithProviders(); + + await waitFor(() => { + expect(screen.getByRole("button", { name: "Edit Settings" })).toBeInTheDocument(); + }); + + const editButton = screen.getByRole("button", { name: "Edit Settings" }); + await userEvent.click(editButton); + + await waitFor(() => { + expect(screen.getByRole("switch")).toBeInTheDocument(); + }); + + const switchElement = screen.getByRole("switch"); + await userEvent.click(switchElement); + + expect(switchElement).not.toBeChecked(); + }); + + it("should render budget duration dropdown in edit mode", async () => { + mockGetDefaultTeamSettings.mockResolvedValue(mockSettings); + + renderWithProviders(); + + await waitFor(() => { + expect(screen.getByRole("button", { name: "Edit Settings" })).toBeInTheDocument(); + }); + + const editButton = screen.getByRole("button", { name: "Edit Settings" }); + await userEvent.click(editButton); + + await waitFor(() => { + expect(screen.getByLabelText("Budget duration")).toBeInTheDocument(); + }); + }); + + it("should update budget duration when dropdown value changes", async () => { + mockGetDefaultTeamSettings.mockResolvedValue(mockSettings); + + renderWithProviders(); + + 
await waitFor(() => { + expect(screen.getByRole("button", { name: "Edit Settings" })).toBeInTheDocument(); + }); + + const editButton = screen.getByRole("button", { name: "Edit Settings" }); + await userEvent.click(editButton); + + await waitFor(() => { + expect(screen.getByLabelText("Budget duration")).toBeInTheDocument(); + }); + + const dropdown = screen.getByLabelText("Budget duration"); + await userEvent.selectOptions(dropdown, "daily"); + + expect(dropdown).toHaveValue("daily"); + }); + + it("should render text input for string fields in edit mode", async () => { + const settingsWithString = { + ...mockSettings, field_schema: { - description: "Default team settings", + ...mockSettings.field_schema, properties: { - budget_duration: { + ...mockSettings.field_schema.properties, + team_name: { type: "string", - description: "Budget duration", + description: "Team name", }, - max_budget: { + }, + }, + values: { + ...mockSettings.values, + team_name: "Test Team", + }, + }; + mockGetDefaultTeamSettings.mockResolvedValue(settingsWithString); + + renderWithProviders(); + + await waitFor(() => { + expect(screen.getByRole("button", { name: "Edit Settings" })).toBeInTheDocument(); + }); + + const editButton = screen.getByRole("button", { name: "Edit Settings" }); + await userEvent.click(editButton); + + await waitFor(() => { + const textInput = screen.getByDisplayValue("Test Team"); + expect(textInput).toBeInTheDocument(); + }); + }); + + it("should render enum select for string enum fields in edit mode", async () => { + mockGetDefaultTeamSettings.mockResolvedValue(mockSettings); + + renderWithProviders(); + + await waitFor(() => { + expect(screen.getByRole("button", { name: "Edit Settings" })).toBeInTheDocument(); + }); + + const editButton = screen.getByRole("button", { name: "Edit Settings" }); + await userEvent.click(editButton); + + await waitFor(() => { + const statusSelect = screen.getAllByRole("listbox")[0]; + expect(statusSelect).toBeInTheDocument(); + }); + }); 
+ + it("should render multi-select for array enum fields in edit mode", async () => { + mockGetDefaultTeamSettings.mockResolvedValue(mockSettings); + + renderWithProviders(); + + await waitFor(() => { + expect(screen.getByRole("button", { name: "Edit Settings" })).toBeInTheDocument(); + }); + + const editButton = screen.getByRole("button", { name: "Edit Settings" }); + await userEvent.click(editButton); + + await waitFor(() => { + const multiSelects = screen.getAllByRole("listbox"); + expect(multiSelects.length).toBeGreaterThan(0); + }); + }); + + it("should render ModelSelect for models field in edit mode", async () => { + mockGetDefaultTeamSettings.mockResolvedValue(mockSettings); + + renderWithProviders(); + + await waitFor(() => { + expect(screen.getByRole("button", { name: "Edit Settings" })).toBeInTheDocument(); + }); + + const editButton = screen.getByRole("button", { name: "Edit Settings" }); + await userEvent.click(editButton); + + await waitFor(() => { + expect(screen.getByTestId("model-select")).toBeInTheDocument(); + }); + }); + + it("should display models as badges in view mode", async () => { + mockGetDefaultTeamSettings.mockResolvedValue(mockSettings); + + renderWithProviders(); + + await waitFor(() => { + const gpt4Elements = screen.getAllByText("gpt-4"); + expect(gpt4Elements.length).toBeGreaterThan(0); + }); + }); + + it("should display 'None' for empty arrays in view mode", async () => { + const settingsWithEmptyArray = { + ...mockSettings, + values: { + ...mockSettings.values, + models: [], + }, + }; + mockGetDefaultTeamSettings.mockResolvedValue(settingsWithEmptyArray); + + renderWithProviders(); + + await waitFor(() => { + const noneTexts = screen.getAllByText("None"); + expect(noneTexts.length).toBeGreaterThan(0); + }); + }); + + it("should display schema description when available", async () => { + mockGetDefaultTeamSettings.mockResolvedValue(mockSettings); + + renderWithProviders(); + + await waitFor(() => { + 
expect(screen.getByText("Default team settings schema")).toBeInTheDocument(); + }); + }); + + it("should show error notification when fetching settings fails", async () => { + mockGetDefaultTeamSettings.mockRejectedValue(new Error("Fetch failed")); + + renderWithProviders(); + + await waitFor(() => { + expect(mockNotificationsManager.fromBackend).toHaveBeenCalledWith("Failed to fetch team settings"); + }); + }); + + it("should handle model fetch error gracefully", async () => { + mockGetDefaultTeamSettings.mockResolvedValue(mockSettings); + mockModelAvailableCall.mockRejectedValue(new Error("Model fetch failed")); + + renderWithProviders(); + + await waitFor(() => { + expect(screen.getByText("Default Team Settings")).toBeInTheDocument(); + }); + }); + + it("should disable cancel button while saving", async () => { + mockGetDefaultTeamSettings.mockResolvedValue(mockSettings); + mockUpdateDefaultTeamSettings.mockImplementation( + () => new Promise((resolve) => setTimeout(() => resolve({ settings: mockSettings.values }), 100)), + ); + + renderWithProviders(); + + await waitFor(() => { + expect(screen.getByRole("button", { name: "Edit Settings" })).toBeInTheDocument(); + }); + + const editButton = screen.getByRole("button", { name: "Edit Settings" }); + await userEvent.click(editButton); + + await waitFor(() => { + expect(screen.getByRole("button", { name: "Save Changes" })).toBeInTheDocument(); + }); + + const saveButton = screen.getByRole("button", { name: "Save Changes" }); + await userEvent.click(saveButton); + + const cancelButton = screen.getByRole("button", { name: "Cancel" }); + expect(cancelButton).toBeDisabled(); + }); + + it("should display field descriptions", async () => { + mockGetDefaultTeamSettings.mockResolvedValue(mockSettings); + + renderWithProviders(); + + await waitFor(() => { + expect(screen.getByText("Budget duration setting")).toBeInTheDocument(); + expect(screen.getByText("Maximum budget amount")).toBeInTheDocument(); + }); + }); + + 
it("should format field names by replacing underscores and capitalizing", async () => { + const settingsWithUnderscores = { + ...mockSettings, + field_schema: { + ...mockSettings.field_schema, + properties: { + ...mockSettings.field_schema.properties, + max_budget_per_user: { type: "number", - description: "Maximum budget", + description: "Max budget per user", }, }, }, - }); + values: { + ...mockSettings.values, + max_budget_per_user: 500, + }, + }; + mockGetDefaultTeamSettings.mockResolvedValue(settingsWithUnderscores); - vi.mocked(networking.modelAvailableCall).mockResolvedValue({ - data: [{ id: "gpt-4" }, { id: "claude-3" }], + renderWithProviders(); + + await waitFor(() => { + expect(screen.getByText("Max Budget Per User")).toBeInTheDocument(); }); + }); + + it("should display 'No schema information available' when schema is missing", async () => { + const settingsWithoutSchema = { + values: {}, + field_schema: null, + }; + mockGetDefaultTeamSettings.mockResolvedValue(settingsWithoutSchema); - renderWithProviders(); + renderWithProviders(); - const container = await screen.findByText("Default Team Settings"); - expect(container).toBeInTheDocument(); + await waitFor(() => { + expect(screen.getByText("No schema information available")).toBeInTheDocument(); + }); }); }); diff --git a/ui/litellm-dashboard/src/components/TeamSSOSettings.tsx b/ui/litellm-dashboard/src/components/TeamSSOSettings.tsx index 8537b108cdc..33bfc783afd 100644 --- a/ui/litellm-dashboard/src/components/TeamSSOSettings.tsx +++ b/ui/litellm-dashboard/src/components/TeamSSOSettings.tsx @@ -5,6 +5,7 @@ import { getDefaultTeamSettings, updateDefaultTeamSettings, modelAvailableCall } import BudgetDurationDropdown, { getBudgetDurationLabel } from "./common_components/budget_duration_dropdown"; import { getModelDisplayName } from "./key_team_helpers/fetch_available_models_team_key"; import NotificationsManager from "./molecules/notifications_manager"; +import { ModelSelect } from 
"./ModelSelect/ModelSelect"; interface TeamSSOSettingsProps { accessToken: string | null; @@ -116,22 +117,15 @@ const TeamSSOSettings: React.FC = ({ accessToken, userID, ); } else if (key === "models") { return ( - + context="global" + style={{ width: "100%" }} + options={{ + includeSpecialOptions: true, + }} + /> ); } else if (type === "string" && property.enum) { return ( diff --git a/ui/litellm-dashboard/src/components/navbar.test.tsx b/ui/litellm-dashboard/src/components/navbar.test.tsx index a2996f70587..125187e2340 100644 --- a/ui/litellm-dashboard/src/components/navbar.test.tsx +++ b/ui/litellm-dashboard/src/components/navbar.test.tsx @@ -12,11 +12,24 @@ vi.mock("@/utils/proxyUtils", () => ({ fetchProxySettings: vi.fn(), })); +// Mock CommunityEngagementButtons component +vi.mock("./Navbar/CommunityEngagementButtons/CommunityEngagementButtons", () => ({ + CommunityEngagementButtons: () => ( + + ), +})); + // Create mock functions that can be controlled in tests let mockUseThemeImpl = () => ({ logoUrl: null as string | null }); let mockUseHealthReadinessImpl = () => ({ data: null as any }); let mockGetLocalStorageItemImpl = (key: string) => null as string | null; -let mockUseDisableShowPromptsImpl = () => false; let mockUseAuthorizedImpl = () => ({ userId: "test-user", userEmail: "test@example.com", @@ -32,10 +45,6 @@ vi.mock("@/app/(dashboard)/hooks/healthReadiness/useHealthReadiness", () => ({ useHealthReadiness: () => mockUseHealthReadinessImpl(), })); -vi.mock("@/app/(dashboard)/hooks/useDisableShowPrompts", () => ({ - useDisableShowPrompts: () => mockUseDisableShowPromptsImpl(), -})); - vi.mock("@/app/(dashboard)/hooks/useAuthorized", () => ({ default: () => mockUseAuthorizedImpl(), })); @@ -79,26 +88,6 @@ describe("Navbar", () => { expect(screen.getByText("User")).toBeInTheDocument(); }); - it("should render Join Slack button with correct link", () => { - renderWithProviders(); - - const joinSlackLink = screen.getByRole("link", { name: /join slack/i }); 
- expect(joinSlackLink).toBeInTheDocument(); - expect(joinSlackLink).toHaveAttribute("href", "https://www.litellm.ai/support"); - expect(joinSlackLink).toHaveAttribute("target", "_blank"); - expect(joinSlackLink).toHaveAttribute("rel", "noopener noreferrer"); - }); - - it("should render Star us on GitHub button with correct link", () => { - renderWithProviders(); - - const starOnGithubLink = screen.getByRole("link", { name: /star us on github/i }); - expect(starOnGithubLink).toBeInTheDocument(); - expect(starOnGithubLink).toHaveAttribute("href", "https://github.com/BerriAI/litellm"); - expect(starOnGithubLink).toHaveAttribute("target", "_blank"); - expect(starOnGithubLink).toHaveAttribute("rel", "noopener noreferrer"); - }); - it("should display user information in dropdown", async () => { const user = userEvent.setup(); renderWithProviders(); diff --git a/ui/litellm-dashboard/src/components/navbar.tsx b/ui/litellm-dashboard/src/components/navbar.tsx index 3649ca76238..2ffa0632f27 100644 --- a/ui/litellm-dashboard/src/components/navbar.tsx +++ b/ui/litellm-dashboard/src/components/navbar.tsx @@ -4,16 +4,15 @@ import { useTheme } from "@/contexts/ThemeContext"; import { clearTokenCookies } from "@/utils/cookieUtils"; import { fetchProxySettings } from "@/utils/proxyUtils"; import { - GithubOutlined, MenuFoldOutlined, MenuUnfoldOutlined, MoonOutlined, - SlackOutlined, SunOutlined, } from "@ant-design/icons"; -import { Button, Switch, Tag } from "antd"; +import { Switch, Tag } from "antd"; import Link from "next/link"; import React, { useEffect, useState } from "react"; +import { CommunityEngagementButtons } from "./Navbar/CommunityEngagementButtons/CommunityEngagementButtons"; import UserDropdown from "./Navbar/UserDropdown/UserDropdown"; interface NavbarProps { @@ -129,24 +128,7 @@ const Navbar: React.FC = ({ {/* Right side nav items */}
- - + {/* Dark mode is currently a work in progress. To test, you can change 'false' to 'true' below. Do not set this to true by default until all components are confirmed to support dark mode styles. */} {false && ({ teamInfoCall: vi.fn(), teamMemberDeleteCall: vi.fn(), @@ -12,12 +12,18 @@ vi.mock("@/components/networking", () => ({ teamMemberUpdateCall: vi.fn(), teamUpdateCall: vi.fn(), getGuardrailsList: vi.fn(), + getPoliciesList: vi.fn(), + getPolicyInfoWithGuardrails: vi.fn(), fetchMCPAccessGroups: vi.fn(), getTeamPermissionsCall: vi.fn(), organizationInfoCall: vi.fn(), })); -// Mock hooks used by ModelSelect +vi.mock("@/components/utils/dataUtils", () => ({ + copyToClipboard: vi.fn().mockResolvedValue(true), + formatNumberWithCommas: vi.fn((value: number) => value.toLocaleString()), +})); + vi.mock("@/app/(dashboard)/hooks/models/useModels", () => ({ useAllProxyModels: vi.fn(), })); @@ -34,6 +40,59 @@ vi.mock("@/app/(dashboard)/hooks/users/useCurrentUser", () => ({ useCurrentUser: vi.fn(), })); +vi.mock("@/components/team/team_member_view", () => ({ + default: vi.fn(({ setIsAddMemberModalVisible }) => ( +
+ +
+ )), +})); + +vi.mock("@/components/common_components/user_search_modal", () => ({ + default: vi.fn(({ isVisible, onCancel, onSubmit }) => + isVisible ? ( +
+ + +
+ ) : null + ), +})); + +vi.mock("@/components/team/EditMembership", () => ({ + default: vi.fn(({ visible, onCancel, onSubmit }) => + visible ? ( +
+ + +
+ ) : null + ), +})); + +vi.mock("@/components/common_components/DeleteResourceModal", () => ({ + default: vi.fn(({ isOpen, onCancel, onOk }) => + isOpen ? ( +
+ + +
+ ) : null + ), +})); + +vi.mock("@/components/team/member_permissions", () => ({ + default: vi.fn(() =>
Member Permissions
), +})); +
), +})); + import { useAllProxyModels } from "@/app/(dashboard)/hooks/models/useModels"; import { useOrganization } from "@/app/(dashboard)/hooks/organizations/useOrganizations"; import { useTeam } from "@/app/(dashboard)/hooks/teams/useTeams"; @@ -44,9 +103,60 @@ const mockUseTeam = vi.mocked(useTeam); const mockUseOrganization = vi.mocked(useOrganization); const mockUseCurrentUser = vi.mocked(useCurrentUser); +const createMockTeamData = (overrides = {}) => ({ + team_id: "123", + team_info: { + team_alias: "Test Team", + team_id: "123", + organization_id: null, + admins: ["admin@test.com"], + members: ["user1@test.com"], + members_with_roles: [ + { + user_id: "user1@test.com", + user_email: "user1@test.com", + role: "member", + spend: 0, + budget_id: "budget1", + }, + ], + metadata: {}, + tpm_limit: null, + rpm_limit: null, + max_budget: null, + budget_duration: null, + models: [], + blocked: false, + spend: 0, + max_parallel_requests: null, + budget_reset_at: null, + model_id: null, + litellm_model_table: null, + created_at: "2024-01-01T00:00:00Z", + team_member_budget_table: null, + guardrails: [], + policies: [], + object_permission: null, + ...overrides, + }, + keys: [], + team_memberships: [], +}); + describe("TeamInfoView", () => { + const defaultProps = { + teamId: "123", + onUpdate: vi.fn(), + onClose: vi.fn(), + accessToken: "test-token", + is_team_admin: true, + is_proxy_admin: true, + userModels: ["gpt-4", "gpt-3.5-turbo"], + editTeam: false, + premiumUser: false, + }; + beforeEach(() => { - // Set up default mock implementations mockUseAllProxyModels.mockReturnValue({ data: { data: [] }, isLoading: false, @@ -63,6 +173,14 @@ describe("TeamInfoView", () => { data: { models: [] }, isLoading: false, } as any); + + vi.mocked(networking.getGuardrailsList).mockResolvedValue({ guardrails: [] }); + vi.mocked(networking.getPoliciesList).mockResolvedValue({ policies: [] }); + vi.mocked(networking.fetchMCPAccessGroups).mockResolvedValue([]); + 
vi.mocked(networking.getTeamPermissionsCall).mockResolvedValue({ + all_available_permissions: [], + team_member_permissions: [], + }); }); afterEach(() => { @@ -70,503 +188,389 @@ describe("TeamInfoView", () => { }); it("should render", async () => { - // Mock the team info response - vi.mocked(networking.teamInfoCall).mockResolvedValue({ - team_id: "123", - team_info: { - team_alias: "Test Team", - team_id: "123", - organization_id: null, - admins: ["admin@test.com"], - members: ["user1@test.com", "user2@test.com"], - members_with_roles: [ - { - user_id: "user1@test.com", - user_email: "user1@test.com", - role: "member", - spend: 0, - budget_id: "budget1", - }, - ], - metadata: {}, - tpm_limit: null, - rpm_limit: null, - max_budget: null, - budget_duration: null, - models: [], - blocked: false, - spend: 0, - max_parallel_requests: null, - budget_reset_at: null, - model_id: null, - litellm_model_table: null, - created_at: "2024-01-01T00:00:00Z", - team_member_budget_table: null, - }, - keys: [], - team_memberships: [], + vi.mocked(networking.teamInfoCall).mockResolvedValue(createMockTeamData()); + + renderWithProviders(); + + await waitFor(() => { + const teamNameElements = screen.queryAllByText("Test Team"); + expect(teamNameElements.length).toBeGreaterThan(0); }); + }); - vi.mocked(networking.getGuardrailsList).mockResolvedValue({ guardrails: [] }); - vi.mocked(networking.fetchMCPAccessGroups).mockResolvedValue([]); + it("should display loading state while fetching team data", () => { + vi.mocked(networking.teamInfoCall).mockImplementation(() => new Promise(() => { })); - renderWithProviders( - {}} - onClose={() => {}} - accessToken="123" - is_team_admin={true} - is_proxy_admin={true} - userModels={[]} - editTeam={false} - premiumUser={false} - />, - ); - await waitFor( - () => { - expect(screen.queryByText("User ID")).not.toBeNull(); - }, - // This is a workaround to fix the flaky test issue. TODO: Remove this once we have a better solution. 
- { timeout: 10000 }, - ); + renderWithProviders(); + + expect(screen.getByText("Loading...")).toBeInTheDocument(); }); - it("should not show all-proxy-models option when user has no access to it", async () => { + it("should display error message when team is not found", async () => { vi.mocked(networking.teamInfoCall).mockResolvedValue({ team_id: "123", - team_info: { - team_alias: "Test Team", - team_id: "123", - organization_id: null, - admins: ["admin@test.com"], - members: ["user1@test.com", "user2@test.com"], - members_with_roles: [ - { - user_id: "user1@test.com", - user_email: "user1@test.com", - role: "member", - spend: 0, - budget_id: "budget1", - }, - ], - metadata: {}, - tpm_limit: null, - rpm_limit: null, - max_budget: null, - budget_duration: null, - models: ["gpt-4"], - blocked: false, - spend: 0, - max_parallel_requests: null, - budget_reset_at: null, - model_id: null, - litellm_model_table: null, - created_at: "2024-01-01T00:00:00Z", - team_member_budget_table: null, - }, + team_info: null as any, keys: [], team_memberships: [], }); - vi.mocked(networking.getGuardrailsList).mockResolvedValue({ guardrails: [] }); - vi.mocked(networking.fetchMCPAccessGroups).mockResolvedValue([]); + renderWithProviders(); + + await waitFor(() => { + expect(screen.getByText("Team not found")).toBeInTheDocument(); + }); + }); - renderWithProviders( - {}} - onClose={() => {}} - accessToken="123" - is_team_admin={true} - is_proxy_admin={true} - userModels={["gpt-4", "gpt-3.5-turbo"]} - editTeam={false} - premiumUser={false} - />, + it("should display budget information in overview", async () => { + vi.mocked(networking.teamInfoCall).mockResolvedValue( + createMockTeamData({ + max_budget: 1000, + spend: 250.5, + budget_duration: "30d", + }) ); + renderWithProviders(); + await waitFor(() => { - expect(screen.getAllByText("Test Team")).not.toBeNull(); + expect(screen.getByText("Budget Status")).toBeInTheDocument(); }); + }); - const settingsTab = screen.getByRole("tab", { 
      name: "Settings" });
-      act(() => {
-        fireEvent.click(settingsTab);
-      });
+  it("should display guardrails in overview when present", async () => {
+    vi.mocked(networking.teamInfoCall).mockResolvedValue(
+      createMockTeamData({
+        guardrails: ["guardrail1", "guardrail2"],
+      })
+    );
+
+    renderWithProviders();

     await waitFor(() => {
-      expect(screen.getByText("Team Settings")).toBeInTheDocument();
+      expect(screen.getByText("Guardrails")).toBeInTheDocument();
     });
+  });

-    const editButton = screen.getByRole("button", { name: "Edit Settings" });
-    act(() => {
-      fireEvent.click(editButton);
+  it("should display policies in overview when present", async () => {
+    vi.mocked(networking.teamInfoCall).mockResolvedValue(
+      createMockTeamData({
+        policies: ["policy1"],
+      })
+    );
+    vi.mocked(networking.getPolicyInfoWithGuardrails).mockResolvedValue({
+      resolved_guardrails: ["guardrail1"],
     });
+    renderWithProviders();
+
     await waitFor(() => {
-      expect(screen.getByTestId("models-select")).toBeInTheDocument();
+      expect(screen.getByText("Policies")).toBeInTheDocument();
     });
+  });

-    const allProxyModelsOption = screen.queryByText("All Proxy Models");
-    expect(allProxyModelsOption).not.toBeInTheDocument();
-  }, 10000); // This is a workaround to fix the flaky test issue. TODO: Remove this once we have a better solution.
+  it("should show members tab when user can edit team", async () => {
+    vi.mocked(networking.teamInfoCall).mockResolvedValue(createMockTeamData());

-  it("should only show organization models in dropdown when team is in organization with limited models", async () => {
-    const organizationId = "org-123";
-    const organizationModels = ["gpt-4", "claude-3-opus"];
-    const userModels = ["gpt-4", "gpt-3.5-turbo", "claude-3-opus", "claude-2"];
+    renderWithProviders();

-    // Mock all proxy models - should include all user models
-    const allProxyModels = userModels.map((id) => ({
-      id,
-      object: "model",
-      created: 1234567890,
-      owned_by: "openai",
-    }));
+    await waitFor(() => {
+      expect(screen.getByRole("tab", { name: "Members" })).toBeInTheDocument();
+    });
+  });

-    mockUseAllProxyModels.mockReturnValue({
-      data: { data: allProxyModels },
-      isLoading: false,
-    } as any);
+  it("should not show members tab when user cannot edit team", async () => {
+    vi.mocked(networking.teamInfoCall).mockResolvedValue(createMockTeamData());

-    mockUseCurrentUser.mockReturnValue({
-      data: { models: userModels },
-      isLoading: false,
-    } as any);
+    renderWithProviders();

-    const organizationData = {
-      organization_id: organizationId,
-      organization_name: "Test Organization",
-      spend: 0,
-      max_budget: null,
-      models: organizationModels,
-      tpm_limit: null,
-      rpm_limit: null,
-      members: null,
-    };
+    await waitFor(() => {
+      const teamNameElements = screen.queryAllByText("Test Team");
+      expect(teamNameElements.length).toBeGreaterThan(0);
+    });

-    mockUseOrganization.mockReturnValue({
-      data: organizationData,
-      isLoading: false,
-    } as any);
+    expect(screen.queryByRole("tab", { name: "Members" })).not.toBeInTheDocument();
+  });

-    vi.mocked(networking.teamInfoCall).mockResolvedValue({
-      team_id: "123",
-      team_info: {
-        team_alias: "Test Team",
-        team_id: "123",
-        organization_id: organizationId,
-        admins: ["admin@test.com"],
-        members: ["user1@test.com"],
-        members_with_roles: [
-          {
-            user_id: "user1@test.com",
-            user_email: "user1@test.com",
-            role: "member",
-            spend: 0,
-            budget_id: "budget1",
-          },
-        ],
-        metadata: {},
-        tpm_limit: null,
-        rpm_limit: null,
-        max_budget: null,
-        budget_duration: null,
-        models: ["gpt-4"],
-        blocked: false,
-        spend: 0,
-        max_parallel_requests: null,
-        budget_reset_at: null,
-        model_id: null,
-        litellm_model_table: null,
-        created_at: "2024-01-01T00:00:00Z",
-        team_member_budget_table: null,
-      },
-      keys: [],
-      team_memberships: [],
-    });
+  it("should show settings tab when user can edit team", async () => {
+    vi.mocked(networking.teamInfoCall).mockResolvedValue(createMockTeamData());

-    vi.mocked(networking.organizationInfoCall).mockResolvedValue(organizationData);
+    renderWithProviders();

-    vi.mocked(networking.getGuardrailsList).mockResolvedValue({ guardrails: [] });
-    vi.mocked(networking.fetchMCPAccessGroups).mockResolvedValue([]);
+    await waitFor(() => {
+      expect(screen.getByRole("tab", { name: "Settings" })).toBeInTheDocument();
+    });
+  });

-    renderWithProviders(
-      {}}
-        onClose={() => {}}
-        accessToken="123"
-        is_team_admin={true}
-        is_proxy_admin={true}
-        userModels={userModels}
-        editTeam={false}
-        premiumUser={false}
-      />,
-    );
+  it("should navigate to settings tab when clicked", async () => {
+    const user = userEvent.setup();
+    vi.mocked(networking.teamInfoCall).mockResolvedValue(createMockTeamData());
+
+    renderWithProviders();

     await waitFor(() => {
-      expect(screen.getAllByText("Test Team")).not.toBeNull();
+      const teamNameElements = screen.queryAllByText("Test Team");
+      expect(teamNameElements.length).toBeGreaterThan(0);
     });

     const settingsTab = screen.getByRole("tab", { name: "Settings" });
-    act(() => {
-      fireEvent.click(settingsTab);
-    });
+    await user.click(settingsTab);

     await waitFor(() => {
       expect(screen.getByText("Team Settings")).toBeInTheDocument();
     });
+  });
+
+  it("should open edit mode when edit button is clicked", async () => {
+    const user = userEvent.setup();
+    vi.mocked(networking.teamInfoCall).mockResolvedValue(createMockTeamData());
+
+    renderWithProviders();
+
+    await waitFor(() => {
+      const teamNameElements = screen.queryAllByText("Test Team");
+      expect(teamNameElements.length).toBeGreaterThan(0);
+    });
+
+    const settingsTab = screen.getByRole("tab", { name: "Settings" });
+    await user.click(settingsTab);
+
+    await waitFor(() => {
+      expect(screen.getByRole("button", { name: "Edit Settings" })).toBeInTheDocument();
+    });

     const editButton = screen.getByRole("button", { name: "Edit Settings" });
-    act(() => {
-      fireEvent.click(editButton);
+    await user.click(editButton);
+
+    await waitFor(() => {
+      expect(screen.getByLabelText("Team Name")).toBeInTheDocument();
     });
+  });
+
+  it("should close edit mode when cancel button is clicked", async () => {
+    const user = userEvent.setup();
+    vi.mocked(networking.teamInfoCall).mockResolvedValue(createMockTeamData());
+
+    renderWithProviders();

     await waitFor(() => {
-      expect(screen.getByTestId("models-select")).toBeInTheDocument();
+      const teamNameElements = screen.queryAllByText("Test Team");
+      expect(teamNameElements.length).toBeGreaterThan(0);
     });

-    // Find the Ant Design Select selector element to open the dropdown
-    // The data-testid is on the Select component, we need to find the selector inside it
-    const modelsSelectElement = screen.getByTestId("models-select");
-    const selectSelector = modelsSelectElement.querySelector(".ant-select-selector");
-    expect(selectSelector).toBeTruthy();
+    const settingsTab = screen.getByRole("tab", { name: "Settings" });
+    await user.click(settingsTab);

-    // Open the dropdown by clicking on the selector
-    act(() => {
-      fireEvent.mouseDown(selectSelector!);
+    await waitFor(() => {
+      expect(screen.getByRole("button", { name: "Edit Settings" })).toBeInTheDocument();
     });

-    // Wait for dropdown to open - Ant Design renders options in a portal
-    await waitFor(
-      () => {
-        const dropdownOptions = document.querySelectorAll(".ant-select-item-option");
-        expect(dropdownOptions.length).toBeGreaterThan(0);
-      },
-      { timeout: 5000 },
-    );
+    const editButton = screen.getByRole("button", { name: "Edit Settings" });
+    await user.click(editButton);

-    const dropdownOptions = document.querySelectorAll(".ant-select-item-option");
-    const optionTexts = Array.from(dropdownOptions).map((option) => option.textContent?.trim() || "");
+    await waitFor(() => {
+      expect(screen.getByLabelText("Team Name")).toBeInTheDocument();
+    });
+
+    const cancelButton = screen.getByRole("button", { name: "Cancel" });
+    await user.click(cancelButton);
+
+    await waitFor(() => {
+      expect(screen.queryByLabelText("Team Name")).not.toBeInTheDocument();
+    });
+  });

-    organizationModels.forEach((model) => {
-      expect(optionTexts).toContain(model);
+  it("should call onClose when back button is clicked", async () => {
+    const user = userEvent.setup();
+    const onClose = vi.fn();
+    vi.mocked(networking.teamInfoCall).mockResolvedValue(createMockTeamData());
+
+    renderWithProviders();
+
+    await waitFor(() => {
+      const teamNameElements = screen.queryAllByText("Test Team");
+      expect(teamNameElements.length).toBeGreaterThan(0);
     });

-    const modelsNotInOrganization = userModels.filter((m) => !organizationModels.includes(m));
-    modelsNotInOrganization.forEach((model) => {
-      expect(optionTexts).not.toContain(model);
+    const backButton = screen.getByRole("button", { name: /back to teams/i });
+    await user.click(backButton);
+
+    expect(onClose).toHaveBeenCalled();
+  });
+
+  it("should copy team ID to clipboard when copy button is clicked", async () => {
+    const user = userEvent.setup();
+    vi.mocked(networking.teamInfoCall).mockResolvedValue(createMockTeamData());
+
+    renderWithProviders();
+
+    await waitFor(() => {
+      const teamNameElements = screen.queryAllByText("Test Team");
+      expect(teamNameElements.length).toBeGreaterThan(0);
     });
-  }, 10000);
+
+    const copyButtons = screen.getAllByRole("button");
+    const copyButton = copyButtons.find((btn) => btn.querySelector("svg"));
+    expect(copyButton).toBeTruthy();
+
+    if (copyButton) {
+      await user.click(copyButton);
+    }
+  });

   it("should disable secret manager settings for non-premium users", async () => {
-    const teamResponse = {
-      team_id: "123",
-      team_info: {
-        team_alias: "Test Team",
-        team_id: "123",
-        organization_id: null,
-        admins: ["admin@test.com"],
-        members: [],
-        members_with_roles: [],
+    const user = userEvent.setup();
+    vi.mocked(networking.teamInfoCall).mockResolvedValue(
+      createMockTeamData({
         metadata: {
           secret_manager_settings: { provider: "aws", secret_id: "abc" },
         },
-        tpm_limit: null,
-        rpm_limit: null,
-        max_budget: null,
-        budget_duration: null,
-        models: ["gpt-4"],
-        blocked: false,
-        spend: 0,
-        max_parallel_requests: null,
-        budget_reset_at: null,
-        model_id: null,
-        litellm_model_table: null,
-        created_at: "2024-01-01T00:00:00Z",
-        team_member_budget_table: null,
-      },
-      keys: [],
-      team_memberships: [],
-    };
+      })
+    );

-    vi.mocked(networking.teamInfoCall).mockResolvedValue(teamResponse as any);
-    vi.mocked(networking.getGuardrailsList).mockResolvedValue({ guardrails: [] });
-    vi.mocked(networking.fetchMCPAccessGroups).mockResolvedValue([]);
+    renderWithProviders();

-    renderWithProviders(
-      {}}
-        onClose={() => {}}
-        accessToken="123"
-        is_team_admin={true}
-        is_proxy_admin={true}
-        userModels={["gpt-4"]}
-        editTeam={false}
-        premiumUser={false}
-      />,
-    );
+    await waitFor(() => {
+      const teamNameElements = screen.queryAllByText("Test Team");
+      expect(teamNameElements.length).toBeGreaterThan(0);
+    });
+
+    const settingsTab = screen.getByRole("tab", { name: "Settings" });
+    await user.click(settingsTab);

-    const settingsTab = await screen.findByRole("tab", { name: "Settings" });
-    act(() => fireEvent.click(settingsTab));
+    await waitFor(() => {
+      expect(screen.getByRole("button", { name: "Edit Settings" })).toBeInTheDocument();
+    });

-    const editButton = await screen.findByRole("button", { name: "Edit Settings" });
-    act(() => fireEvent.click(editButton));
+    const editButton = screen.getByRole("button", { name: "Edit Settings" });
+    await user.click(editButton);

     const secretField = await screen.findByPlaceholderText(
-      '{"namespace": "admin", "mount": "secret", "path_prefix": "litellm"}',
+      '{"namespace": "admin", "mount": "secret", "path_prefix": "litellm"}'
     );

     expect(secretField).toBeDisabled();
-    expect(secretField).toHaveValue(JSON.stringify(teamResponse.team_info.metadata.secret_manager_settings, null, 2));
-  }, 10000);
+  });

-  it("should allow premium users to update secret manager settings", async () => {
-    const teamResponse = {
-      team_id: "123",
-      team_info: {
-        team_alias: "Test Team",
-        team_id: "123",
-        organization_id: null,
-        admins: ["admin@test.com"],
-        members: [],
-        members_with_roles: [],
+  it("should allow premium users to edit secret manager settings", async () => {
+    const user = userEvent.setup();
+    vi.mocked(networking.teamInfoCall).mockResolvedValue(
+      createMockTeamData({
         metadata: {
           secret_manager_settings: { provider: "aws", secret_id: "abc" },
         },
-        tpm_limit: null,
-        rpm_limit: null,
-        max_budget: null,
-        budget_duration: null,
-        models: ["gpt-4"],
-        blocked: false,
-        spend: 0,
-        max_parallel_requests: null,
-        budget_reset_at: null,
-        model_id: null,
-        litellm_model_table: null,
-        created_at: "2024-01-01T00:00:00Z",
-        team_member_budget_table: null,
-      },
-      keys: [],
-      team_memberships: [],
-    };
-
-    vi.mocked(networking.teamInfoCall).mockResolvedValue(teamResponse as any);
-    vi.mocked(networking.getGuardrailsList).mockResolvedValue({ guardrails: [] });
-    vi.mocked(networking.fetchMCPAccessGroups).mockResolvedValue([]);
-    vi.mocked(networking.teamUpdateCall).mockResolvedValue({ data: teamResponse.team_info, team_id: "123" } as any);
-
-    renderWithProviders(
-      {}}
-        onClose={() => {}}
-        accessToken="123"
-        is_team_admin={true}
-        is_proxy_admin={true}
-        userModels={["gpt-4"]}
-        editTeam={false}
-        premiumUser={true}
-      />,
+      })
     );
+
+    vi.mocked(networking.teamUpdateCall).mockResolvedValue({ data: {}, team_id: "123" } as any);

-    const settingsTab = await screen.findByRole("tab", { name: "Settings" });
-    act(() => fireEvent.click(settingsTab));
+    renderWithProviders();

-    const editButton = await screen.findByRole("button", { name: "Edit Settings" });
-    act(() => fireEvent.click(editButton));
+    await waitFor(() => {
+      const teamNameElements = screen.queryAllByText("Test Team");
+      expect(teamNameElements.length).toBeGreaterThan(0);
+    });
+
+    const settingsTab = screen.getByRole("tab", { name: "Settings" });
+    await user.click(settingsTab);
+
+    await waitFor(() => {
+      expect(screen.getByRole("button", { name: "Edit Settings" })).toBeInTheDocument();
+    });
+
+    const editButton = screen.getByRole("button", { name: "Edit Settings" });
+    await user.click(editButton);

     const secretField = await screen.findByPlaceholderText(
-      '{"namespace": "admin", "mount": "secret", "path_prefix": "litellm"}',
+      '{"namespace": "admin", "mount": "secret", "path_prefix": "litellm"}'
     );

     expect(secretField).not.toBeDisabled();
+  });

-    act(() => {
-      fireEvent.change(secretField, { target: { value: '{"provider":"azure","secret_id":"xyz"}' } });
+  it("should add team member when form is submitted", async () => {
+    const user = userEvent.setup();
+    const onUpdate = vi.fn();
+    const teamData = createMockTeamData();
+    vi.mocked(networking.teamInfoCall).mockResolvedValue(teamData);
+    vi.mocked(networking.teamMemberAddCall).mockResolvedValue({} as any);
+
+    renderWithProviders();
+
+    await waitFor(() => {
+      const teamNameElements = screen.queryAllByText("Test Team");
+      expect(teamNameElements.length).toBeGreaterThan(0);
     });

-    const saveButton = await screen.findByRole("button", { name: "Save Changes" });
-    act(() => fireEvent.click(saveButton));
+    const membersTab = screen.getByRole("tab", { name: "Members" });
+    await user.click(membersTab);

     await waitFor(() => {
-      expect(networking.teamUpdateCall).toHaveBeenCalled();
+      expect(screen.getByRole("button", { name: "Add Member" })).toBeInTheDocument();
     });

-    const payload = vi.mocked(networking.teamUpdateCall).mock.calls[0][1];
-    expect(payload.metadata.secret_manager_settings).toEqual({ provider: "azure", secret_id: "xyz" });
-  }, 10000);
+    const addButton = screen.getByRole("button", { name: "Add Member" });
+    await user.click(addButton);

-  it("should include vector stores in object_permission when updating team", async () => {
-    const teamResponse = {
-      team_id: "123",
-      team_info: {
-        team_alias: "Test Team",
-        team_id: "123",
-        organization_id: null,
-        admins: ["admin@test.com"],
-        members: [],
-        members_with_roles: [],
-        metadata: {},
-        tpm_limit: null,
-        rpm_limit: null,
-        max_budget: null,
-        budget_duration: null,
-        models: ["gpt-4"],
-        blocked: false,
-        spend: 0,
-        max_parallel_requests: null,
-        budget_reset_at: null,
-        model_id: null,
-        litellm_model_table: null,
-        created_at: "2024-01-01T00:00:00Z",
-        team_member_budget_table: null,
-        object_permission: {
-          vector_stores: ["store1", "store2"],
-        },
-      },
-      keys: [],
-      team_memberships: [],
-    };
+    await waitFor(() => {
+      expect(screen.getByRole("button", { name: "Submit" })).toBeInTheDocument();
+    });

-    vi.mocked(networking.teamInfoCall).mockResolvedValue(teamResponse as any);
-    vi.mocked(networking.getGuardrailsList).mockResolvedValue({ guardrails: [] });
-    vi.mocked(networking.fetchMCPAccessGroups).mockResolvedValue([]);
-    vi.mocked(networking.teamUpdateCall).mockResolvedValue({ data: teamResponse.team_info, team_id: "123" } as any);
-
-    renderWithProviders(
-      {}}
-        onClose={() => {}}
-        accessToken="123"
-        is_team_admin={true}
-        is_proxy_admin={true}
-        userModels={["gpt-4"]}
-        editTeam={false}
-        premiumUser={true}
-      />,
+    const submitButton = screen.getByRole("button", { name: "Submit" });
+    await user.click(submitButton);
+
+    await waitFor(() => {
+      expect(networking.teamMemberAddCall).toHaveBeenCalled();
+    });
+  });
+
+  it("should display team member budget information when present", async () => {
+    vi.mocked(networking.teamInfoCall).mockResolvedValue(
+      createMockTeamData({
+        team_member_budget_table: {
+          max_budget: 500,
+          budget_duration: "30d",
+          tpm_limit: 5000,
+          rpm_limit: 50,
+        },
+      })
     );

-    const settingsTab = await screen.findByRole("tab", { name: "Settings" });
-    act(() => fireEvent.click(settingsTab));
+    renderWithProviders();

-    const editButton = await screen.findByRole("button", { name: "Edit Settings" });
-    act(() => fireEvent.click(editButton));
+    await waitFor(() => {
+      expect(screen.getByText("Budget Status")).toBeInTheDocument();
+    });
+  });

-    // Verify that Vector Stores field is present
-    expect(screen.getByLabelText("Vector Stores")).toBeInTheDocument();
+  it("should display virtual keys information", async () => {
+    vi.mocked(networking.teamInfoCall).mockResolvedValue({
+      ...createMockTeamData(),
+      keys: [
+        { user_id: "user1", token: "key1" },
+        { token: "key2" },
+      ],
+    });

-    const saveButton = await screen.findByRole("button", { name: "Save Changes" });
-    act(() => fireEvent.click(saveButton));
+    renderWithProviders();

     await waitFor(() => {
-      expect(networking.teamUpdateCall).toHaveBeenCalled();
+      expect(screen.getByText("Virtual Keys")).toBeInTheDocument();
     });
+  });
+
+  it("should display object permissions when present", async () => {
+    vi.mocked(networking.teamInfoCall).mockResolvedValue(
+      createMockTeamData({
+        object_permission: {
+          object_permission_id: "perm-1",
+          mcp_servers: ["server1"],
+          vector_stores: ["store1"],
+        },
+      })
+    );
+
+    renderWithProviders();

-    const payload = vi.mocked(networking.teamUpdateCall).mock.calls[0][1];
-    expect(payload.object_permission.vector_stores).toEqual(["store1", "store2"]);
-  }, 10000);
+    await waitFor(() => {
+      const teamNameElements = screen.queryAllByText("Test Team");
+      expect(teamNameElements.length).toBeGreaterThan(0);
+    });
+  });
 });
diff --git a/ui/litellm-dashboard/src/components/team/team_info.tsx b/ui/litellm-dashboard/src/components/team/team_info.tsx
index 193d056fdd4..34ff903c864 100644
--- a/ui/litellm-dashboard/src/components/team/team_info.tsx
+++ b/ui/litellm-dashboard/src/components/team/team_info.tsx
@@ -462,6 +462,7 @@ const TeamInfoView: React.FC = ({
           ...parsedMetadata,
           guardrails: values.guardrails || [],
           logging: values.logging_settings || [],
+          disable_global_guardrails: values.disable_global_guardrails || false,
           ...(secretManagerSettings !== undefined ? { secret_manager_settings: secretManagerSettings } : {}),
         },
         policies: values.policies || [],
@@ -572,11 +573,10 @@ const TeamInfoView: React.FC = ({
             size="small"
             icon={copiedStates["team-id"] ? : }
             onClick={() => copyToClipboard(info.team_id, "team-id")}
-            className={`left-2 z-10 transition-all duration-200 ${
-              copiedStates["team-id"]
-                ? "text-green-600 bg-green-50 border-green-200"
-                : "text-gray-500 hover:text-gray-700 hover:bg-gray-100"
-            }`}
+            className={`left-2 z-10 transition-all duration-200 ${copiedStates["team-id"]
+              ? "text-green-600 bg-green-50 border-green-200"
+              : "text-gray-500 hover:text-gray-700 hover:bg-gray-100"
+              }`}
           />
@@ -588,10 +588,10 @@ const TeamInfoView: React.FC = ({
           Overview,
           ...(canEditTeam
             ? [
-                Members,
-                Member Permissions,
-                Settings,
-              ]
+              Members,
+              Member Permissions,
+              Settings,
+            ]
             : []),
         ]}
@@ -764,10 +764,10 @@ const TeamInfoView: React.FC = ({
               disable_global_guardrails: info.metadata?.disable_global_guardrails || false,
               metadata: info.metadata
                 ? JSON.stringify(
-                    (({ logging, secret_manager_settings, ...rest }) => rest)(info.metadata),
-                    null,
-                    2,
-                  )
+                  (({ logging, secret_manager_settings, ...rest }) => rest)(info.metadata),
+                  null,
+                  2,
+                )
                 : "",
               logging_settings: info.metadata?.logging || [],
               secret_manager_settings: info.metadata?.secret_manager_settings
@@ -905,7 +905,7 @@ const TeamInfoView: React.FC = ({
-              Disable Global Guardrails{" "}
+              Disable Global Guardrails