diff --git a/docs/my-website/docs/pass_through/assembly_ai.md b/docs/my-website/docs/pass_through/assembly_ai.md index 4606640c5c4..c7c70639e7e 100644 --- a/docs/my-website/docs/pass_through/assembly_ai.md +++ b/docs/my-website/docs/pass_through/assembly_ai.md @@ -1,31 +1,36 @@ -# Assembly AI +# AssemblyAI -Pass-through endpoints for Assembly AI - call Assembly AI endpoints, in native format (no translation). +Pass-through endpoints for AssemblyAI - call AssemblyAI endpoints, in native format (no translation). -| Feature | Supported | Notes | +| Feature | Supported | Notes | |-------|-------|-------| | Cost Tracking | ✅ | works across all integrations | | Logging | ✅ | works across all integrations | -Supports **ALL** Assembly AI Endpoints +Supports **ALL** AssemblyAI Endpoints -[**See All Assembly AI Endpoints**](https://www.assemblyai.com/docs/api-reference) +[**See All AssemblyAI Endpoints**](https://www.assemblyai.com/docs/api-reference) - +## Supported Routes + +| AssemblyAI Service | LiteLLM Route | AssemblyAI Base URL | +|-------------------|---------------|---------------------| +| Speech-to-Text (US) | `/assemblyai/*` | `api.assemblyai.com` | +| Speech-to-Text (EU) | `/eu.assemblyai/*` | `eu.api.assemblyai.com` | ## Quick Start -Let's call the Assembly AI [`/v2/transcripts` endpoint](https://www.assemblyai.com/docs/api-reference/transcripts) +Let's call the AssemblyAI [`/v2/transcripts` endpoint](https://www.assemblyai.com/docs/api-reference/transcripts) -1. Add Assembly AI API Key to your environment +1. Add AssemblyAI API Key to your environment ```bash export ASSEMBLYAI_API_KEY="" ``` -2. Start LiteLLM Proxy +2. Start LiteLLM Proxy ```bash litellm @@ -33,53 +38,157 @@ litellm # RUNNING on http://0.0.0.0:4000 ``` -3. Test it! +3. Test it! -Let's call the Assembly AI `/v2/transcripts` endpoint +Let's call the AssemblyAI [`/v2/transcripts` endpoint](https://www.assemblyai.com/docs/api-reference/transcripts). 
Includes commented-out [Speech Understanding](https://www.assemblyai.com/docs/speech-understanding) features you can toggle on. ```python import assemblyai as aai -LITELLM_VIRTUAL_KEY = "sk-1234" # -LITELLM_PROXY_BASE_URL = "http://0.0.0.0:4000/assemblyai" # /assemblyai +aai.settings.base_url = "http://0.0.0.0:4000/assemblyai" # /assemblyai +aai.settings.api_key = "Bearer sk-1234" # Bearer -aai.settings.api_key = f"Bearer {LITELLM_VIRTUAL_KEY}" -aai.settings.base_url = LITELLM_PROXY_BASE_URL +# Use a publicly-accessible URL +audio_file = "https://assembly.ai/wildfires.mp3" -# URL of the file to transcribe -FILE_URL = "https://assembly.ai/wildfires.mp3" +# Or use a local file: +# audio_file = "./example.mp3" -# You can also transcribe a local file by passing in a file path -# FILE_URL = './path/to/file.mp3' +config = aai.TranscriptionConfig( + speech_models=["universal-3-pro", "universal-2"], + language_detection=True, + speaker_labels=True, + # Speech understanding features + # sentiment_analysis=True, + # entity_detection=True, + # auto_chapters=True, + # summarization=True, + # summary_type=aai.SummarizationType.bullets, + # redact_pii=True, + # content_safety=True, +) -transcriber = aai.Transcriber() -transcript = transcriber.transcribe(FILE_URL) -print(transcript) -print(transcript.id) -``` +transcript = aai.Transcriber().transcribe(audio_file, config=config) -## Calling Assembly AI EU endpoints +if transcript.status == aai.TranscriptStatus.error: + raise RuntimeError(f"Transcription failed: {transcript.error}") -If you want to send your request to the Assembly AI EU endpoint, you can do so by setting the `LITELLM_PROXY_BASE_URL` to `/eu.assemblyai` +print(f"\nFull Transcript:\n\n{transcript.text}") +# Optionally print speaker diarization results +# for utterance in transcript.utterances: +# print(f"Speaker {utterance.speaker}: {utterance.text}") +``` + +4. 
[Prompting with Universal-3 Pro](https://www.assemblyai.com/docs/speech-to-text/prompting) (optional) ```python import assemblyai as aai -LITELLM_VIRTUAL_KEY = "sk-1234" # -LITELLM_PROXY_BASE_URL = "http://0.0.0.0:4000/eu.assemblyai" # /eu.assemblyai +aai.settings.base_url = "http://0.0.0.0:4000/assemblyai" # /assemblyai +aai.settings.api_key = "Bearer sk-1234" # Bearer + +audio_file = "https://assemblyaiassets.com/audios/verbatim.mp3" + +config = aai.TranscriptionConfig( + speech_models=["universal-3-pro", "universal-2"], + language_detection=True, + prompt="Produce a transcript suitable for conversational analysis. Every disfluency is meaningful data. Include: fillers (um, uh, er, ah, hmm, mhm, like, you know, I mean), repetitions (I I, the the), restarts (I was- I went), stutters (th-that, b-but, no-not), and informal speech (gonna, wanna, gotta)", +) + +transcript = aai.Transcriber().transcribe(audio_file, config) + +print(transcript.text) +``` + +## Calling AssemblyAI EU endpoints + +If you want to send your request to the AssemblyAI EU endpoint, you can do so by setting the `LITELLM_PROXY_BASE_URL` to `/eu.assemblyai` -aai.settings.api_key = f"Bearer {LITELLM_VIRTUAL_KEY}" -aai.settings.base_url = LITELLM_PROXY_BASE_URL -# URL of the file to transcribe -FILE_URL = "https://assembly.ai/wildfires.mp3" +```python +import assemblyai as aai + +aai.settings.base_url = "http://0.0.0.0:4000/eu.assemblyai" # /eu.assemblyai +aai.settings.api_key = "Bearer sk-1234" # Bearer -# You can also transcribe a local file by passing in a file path -# FILE_URL = './path/to/file.mp3' +# Use a publicly-accessible URL +audio_file = "https://assembly.ai/wildfires.mp3" + +# Or use a local file: +# audio_file = "./path/to/file.mp3" transcriber = aai.Transcriber() -transcript = transcriber.transcribe(FILE_URL) +transcript = transcriber.transcribe(audio_file) print(transcript) print(transcript.id) ``` + +## LLM Gateway + +Use AssemblyAI's [LLM 
Gateway](https://www.assemblyai.com/docs/llm-gateway) as an OpenAI-compatible provider — a unified API for Claude, GPT, and Gemini models with full LiteLLM logging, guardrails, and cost tracking support. + +[**See Available Models**](https://www.assemblyai.com/docs/llm-gateway#available-models) + +### Usage + +#### LiteLLM Python SDK + +```python +import litellm +import os + +os.environ["ASSEMBLYAI_API_KEY"] = "your-assemblyai-api-key" + +response = litellm.completion( + model="assemblyai/claude-sonnet-4-5-20250929", + messages=[{"role": "user", "content": "What is the capital of France?"}] +) + +print(response.choices[0].message.content) +``` + +#### LiteLLM Proxy + +1. Config + +```yaml +model_list: + - model_name: assemblyai/* + litellm_params: + model: assemblyai/* + api_key: os.environ/ASSEMBLYAI_API_KEY +``` + +2. Start proxy + +```bash +litellm --config config.yaml + +# RUNNING on http://0.0.0.0:4000 +``` + +3. Test it! + +```python +import requests + +headers = { + "authorization": "Bearer sk-1234" # Bearer +} + +response = requests.post( + "http://0.0.0.0:4000/v1/chat/completions", + headers=headers, + json={ + "model": "assemblyai/claude-sonnet-4-5-20250929", + "messages": [ + {"role": "user", "content": "What is the capital of France?"} + ], + "max_tokens": 1000 + } +) + +result = response.json() +print(result["choices"][0]["message"]["content"]) +``` diff --git a/docs/my-website/docs/proxy/prometheus.md b/docs/my-website/docs/proxy/prometheus.md index 18a139d1d29..d8f0d83b59d 100644 --- a/docs/my-website/docs/proxy/prometheus.md +++ b/docs/my-website/docs/proxy/prometheus.md @@ -113,6 +113,31 @@ litellm_settings: ``` +## Pod Health Metrics + +Use these to measure per-pod queue depth and diagnose latency that occurs **before** LiteLLM starts processing a request. + +| Metric Name | Type | Description | +|---|---|---| +| `litellm_in_flight_requests` | Gauge | Number of HTTP requests currently in-flight on this uvicorn worker. 
Tracks the pod's queue depth in real time. With multiple workers, values are summed across all live workers (`livesum`). | + +### When to use this + +LiteLLM measures latency from when its handler starts. If a request waits in uvicorn's event loop before the handler runs, that wait is invisible to LiteLLM's own logs. `litellm_in_flight_requests` shows how loaded the pod was at any point in time. + +``` +high in_flight_requests + high ALB TargetResponseTime → pod overloaded, scale out +low in_flight_requests + high ALB TargetResponseTime → delay is pre-ASGI (event loop blocking) +``` + +You can also check the current value directly without Prometheus: + +```bash +curl http://localhost:4000/health/backlog \ + -H "Authorization: Bearer sk-..." +# {"in_flight_requests": 47} +``` + ## Proxy Level Tracking Metrics Use this to track overall LiteLLM Proxy usage. diff --git a/docs/my-website/docs/troubleshoot/latency_overhead.md b/docs/my-website/docs/troubleshoot/latency_overhead.md index cfb2cb43a7e..dd7f012dcde 100644 --- a/docs/my-website/docs/troubleshoot/latency_overhead.md +++ b/docs/my-website/docs/troubleshoot/latency_overhead.md @@ -2,9 +2,41 @@ Use this guide when you see unexpected latency overhead between LiteLLM proxy and the LLM provider. +## The Invisible Latency Gap + +LiteLLM measures latency from when its handler starts. If a request waits in uvicorn's event loop **before** the handler runs, that wait is invisible to LiteLLM's own logs. + +``` +T=0 Request arrives at load balancer + [queue wait — LiteLLM never logs this] +T=10 LiteLLM handler starts → timer begins +T=20 Response sent + +LiteLLM logs: 10s User experiences: 20s +``` + +To measure the pre-handler wait, poll `/health/backlog` on each pod: + +```bash +curl http://localhost:4000/health/backlog \ + -H "Authorization: Bearer sk-..." +# {"in_flight_requests": 47} +``` + +Or scrape the `litellm_in_flight_requests` Prometheus gauge at `/metrics`. 
+
+| `in_flight_requests` | ALB `TargetResponseTime` | Diagnosis |
+|---|---|---|
+| High | High | Pod overloaded → scale out |
+| Low | High | Delay is pre-ASGI — check for sync blocking code or event loop saturation |
+| High | Normal | Pod is busy but healthy, no queue buildup |
+
+If you're on **AWS ALB**, correlate `litellm_in_flight_requests` spikes with ALB's `TargetResponseTime` CloudWatch metric. The gap between what ALB reports and what LiteLLM logs is the invisible wait.
+
 ## Quick Checklist
 
-1. **Collect the `x-litellm-overhead-duration-ms` response header** — this tells you LiteLLM's total overhead on every request. Start here.
+1. **Check `in_flight_requests` on each pod** via `/health/backlog` or the `litellm_in_flight_requests` Prometheus gauge — this tells you if requests are queuing before LiteLLM starts processing. Start here for unexplained latency.
+2. **Collect the `x-litellm-overhead-duration-ms` response header** — this tells you LiteLLM's total overhead on every request.
-2. **Is DEBUG logging enabled?** This is the #1 cause of latency with large payloads.
-3. **Are you sending large base64 payloads?** (images, PDFs) — see [Large Payload Overhead](#large-payload-overhead).
-4. **Enable detailed timing headers** to pinpoint where time is spent.
+3. **Is DEBUG logging enabled?** This is the #1 cause of latency with large payloads.
+4. **Are you sending large base64 payloads?** (images, PDFs) — see [Large Payload Overhead](#large-payload-overhead).
+5. **Enable detailed timing headers** to pinpoint where time is spent.
diff --git a/litellm-proxy-extras/litellm_proxy_extras/migrations/20260228000000_add_claude_code_plugin_table/migration.sql b/litellm-proxy-extras/litellm_proxy_extras/migrations/20260228000000_add_claude_code_plugin_table/migration.sql new file mode 100644 index 00000000000..e2a3694e8ef --- /dev/null +++ b/litellm-proxy-extras/litellm_proxy_extras/migrations/20260228000000_add_claude_code_plugin_table/migration.sql @@ -0,0 +1,18 @@ +-- CreateTable +CREATE TABLE "LiteLLM_ClaudeCodePluginTable" ( + "id" TEXT NOT NULL, + "name" TEXT NOT NULL, + "version" TEXT, + "description" TEXT, + "manifest_json" TEXT, + "files_json" TEXT DEFAULT '{}', + "enabled" BOOLEAN NOT NULL DEFAULT true, + "created_at" TIMESTAMP(3) DEFAULT CURRENT_TIMESTAMP, + "updated_at" TIMESTAMP(3) DEFAULT CURRENT_TIMESTAMP, + "created_by" TEXT, + + CONSTRAINT "LiteLLM_ClaudeCodePluginTable_pkey" PRIMARY KEY ("id") +); + +-- CreateIndex +CREATE UNIQUE INDEX "LiteLLM_ClaudeCodePluginTable_name_key" ON "LiteLLM_ClaudeCodePluginTable"("name"); diff --git a/litellm-proxy-extras/litellm_proxy_extras/schema.prisma b/litellm-proxy-extras/litellm_proxy_extras/schema.prisma index 13461be3e7c..2717480c7ef 100644 --- a/litellm-proxy-extras/litellm_proxy_extras/schema.prisma +++ b/litellm-proxy-extras/litellm_proxy_extras/schema.prisma @@ -300,7 +300,7 @@ model LiteLLM_MCPServerTable { token_url String? registration_url String? 
allow_all_keys Boolean @default(false) - available_on_public_internet Boolean @default(false) + available_on_public_internet Boolean @default(true) } // Generate Tokens for Proxy diff --git a/litellm/caching/llm_caching_handler.py b/litellm/caching/llm_caching_handler.py index 16eb824f4c9..331aa8f51cd 100644 --- a/litellm/caching/llm_caching_handler.py +++ b/litellm/caching/llm_caching_handler.py @@ -3,11 +3,37 @@ """ import asyncio +from typing import Set from .in_memory_cache import InMemoryCache class LLMClientCache(InMemoryCache): + # Background tasks must be stored to prevent garbage collection, which would + # trigger "coroutine was never awaited" warnings. See: + # https://docs.python.org/3/library/asyncio-task.html#creating-tasks + # Intentionally shared across all instances as a global task registry. + _background_tasks: Set[asyncio.Task] = set() + + def _remove_key(self, key: str) -> None: + """Close async clients before evicting them to prevent connection pool leaks.""" + value = self.cache_dict.get(key) + super()._remove_key(key) + if value is not None: + close_fn = getattr(value, "aclose", None) or getattr(value, "close", None) + if close_fn and asyncio.iscoroutinefunction(close_fn): + try: + task = asyncio.get_running_loop().create_task(close_fn()) + self._background_tasks.add(task) + task.add_done_callback(self._background_tasks.discard) + except RuntimeError: + pass + elif close_fn and callable(close_fn): + try: + close_fn() + except Exception: + pass + def update_cache_key_with_event_loop(self, key): """ Add the event loop to the cache key, to prevent event loop closed errors. 
diff --git a/litellm/constants.py b/litellm/constants.py index 3d2cebf2224..4c38ecd74b5 100644 --- a/litellm/constants.py +++ b/litellm/constants.py @@ -193,9 +193,9 @@ # Aiohttp connection pooling - prevents memory leaks from unbounded connection growth # Set to 0 for unlimited (not recommended for production) -AIOHTTP_CONNECTOR_LIMIT = int(os.getenv("AIOHTTP_CONNECTOR_LIMIT", 300)) +AIOHTTP_CONNECTOR_LIMIT = int(os.getenv("AIOHTTP_CONNECTOR_LIMIT", 1000)) AIOHTTP_CONNECTOR_LIMIT_PER_HOST = int( - os.getenv("AIOHTTP_CONNECTOR_LIMIT_PER_HOST", 50) + os.getenv("AIOHTTP_CONNECTOR_LIMIT_PER_HOST", 500) ) AIOHTTP_KEEPALIVE_TIMEOUT = int(os.getenv("AIOHTTP_KEEPALIVE_TIMEOUT", 120)) AIOHTTP_TTL_DNS_CACHE = int(os.getenv("AIOHTTP_TTL_DNS_CACHE", 300)) diff --git a/litellm/images/main.py b/litellm/images/main.py index 6c4c502a7b0..236266af6ad 100644 --- a/litellm/images/main.py +++ b/litellm/images/main.py @@ -483,6 +483,7 @@ def image_generation( # noqa: PLR0915 organization=organization, aimg_generation=aimg_generation, client=client, + headers=headers, ) elif custom_llm_provider == "bedrock": if model is None: diff --git a/litellm/litellm_core_utils/streaming_handler.py b/litellm/litellm_core_utils/streaming_handler.py index baf274f2c62..3b75a56fcc9 100644 --- a/litellm/litellm_core_utils/streaming_handler.py +++ b/litellm/litellm_core_utils/streaming_handler.py @@ -1968,22 +1968,24 @@ async def __anext__(self) -> "ModelResponseStream": # noqa: PLR0915 self.rules.post_call_rules( input=self.response_uptil_now, model=self.model ) - # Store a shallow copy so usage stripping below - # does not mutate the stored chunk. 
- self.chunks.append(processed_chunk.model_copy()) - # Add mcp_list_tools to first chunk if present if not self.sent_first_chunk: processed_chunk = self._add_mcp_list_tools_to_first_chunk(processed_chunk) self.sent_first_chunk = True - if ( + + _has_usage = ( hasattr(processed_chunk, "usage") and getattr(processed_chunk, "usage", None) is not None - ): + ) + + if _has_usage: + # Store a copy ONLY when usage stripping below will mutate + # the chunk. For non-usage chunks (vast majority), store + # directly to avoid expensive model_copy() per chunk. + self.chunks.append(processed_chunk.model_copy()) + # Strip usage from the outgoing chunk so it's not sent twice # (once in the chunk, once in _hidden_params). - # Create a new object without usage, matching sync behavior. - # The copy in self.chunks retains usage for calculate_total_usage(). obj_dict = processed_chunk.model_dump() if "usage" in obj_dict: del obj_dict["usage"] @@ -1995,6 +1997,9 @@ async def __anext__(self) -> "ModelResponseStream": # noqa: PLR0915 ) if is_empty: continue + else: + # No usage data — safe to store directly without copying + self.chunks.append(processed_chunk) # add usage as hidden param if self.sent_last_chunk is True and self.stream_options is None: diff --git a/litellm/llms/azure/realtime/handler.py b/litellm/llms/azure/realtime/handler.py index 8f4291ec271..0ad6fb57354 100644 --- a/litellm/llms/azure/realtime/handler.py +++ b/litellm/llms/azure/realtime/handler.py @@ -33,7 +33,7 @@ def _construct_url( self, api_base: str, model: str, - api_version: str, + api_version: Optional[str], realtime_protocol: Optional[str] = None, ) -> str: """ @@ -56,8 +56,9 @@ def _construct_url( """ api_base = api_base.replace("https://", "wss://") - # Determine path based on realtime_protocol - if realtime_protocol in ("GA", "v1"): + # Determine path based on realtime_protocol (case-insensitive) + _is_ga = realtime_protocol is not None and realtime_protocol.upper() in ("GA", "V1") + if _is_ga: path = 
"/openai/v1/realtime" return f"{api_base}{path}?model={model}" else: @@ -85,7 +86,7 @@ async def async_realtime( if api_base is None: raise ValueError("api_base is required for Azure OpenAI calls") - if api_version is None: + if api_version is None and (realtime_protocol is None or realtime_protocol.upper() not in ("GA", "V1")): raise ValueError("api_version is required for Azure OpenAI calls") url = self._construct_url( diff --git a/litellm/llms/bedrock/chat/converse_handler.py b/litellm/llms/bedrock/chat/converse_handler.py index 60a93b169c8..ec5b942ec1b 100644 --- a/litellm/llms/bedrock/chat/converse_handler.py +++ b/litellm/llms/bedrock/chat/converse_handler.py @@ -68,7 +68,7 @@ def make_sync_call( model_response=model_response, json_mode=json_mode ) else: - decoder = AWSEventStreamDecoder(model=model) + decoder = AWSEventStreamDecoder(model=model, json_mode=json_mode) completion_stream = decoder.iter_bytes(response.iter_bytes(chunk_size=stream_chunk_size)) # LOGGING diff --git a/litellm/llms/bedrock/chat/converse_transformation.py b/litellm/llms/bedrock/chat/converse_transformation.py index 306d63b77d0..d210f294c64 100644 --- a/litellm/llms/bedrock/chat/converse_transformation.py +++ b/litellm/llms/bedrock/chat/converse_transformation.py @@ -511,6 +511,7 @@ def get_supported_openai_params(self, model: str) -> List[str]: "response_format", "requestMetadata", "service_tier", + "parallel_tool_calls", ] if ( @@ -913,6 +914,13 @@ def map_openai_params( ) if _tool_choice_value is not None: optional_params["tool_choice"] = _tool_choice_value + if param == "parallel_tool_calls": + disable_parallel = not value + optional_params["_parallel_tool_use_config"] = { + "tool_choice": { + "disable_parallel_tool_use": disable_parallel + } + } if param == "thinking": optional_params["thinking"] = value elif param == "reasoning_effort" and isinstance(value, str): @@ -1771,6 +1779,92 @@ def _translate_message_content(self, content_blocks: List[ContentBlock]) -> Tupl return 
content_str, tools, reasoningContentBlocks, citationsContentBlocks + @staticmethod + def _unwrap_bedrock_properties(json_str: str) -> str: + """ + Unwrap Bedrock's response_format JSON structure. + + If the JSON has a single "properties" key, extract its value. + Otherwise, return the original string. + + Args: + json_str: JSON string to unwrap + + Returns: + Unwrapped JSON string or original if unwrapping not needed + """ + try: + response_data = json.loads(json_str) + if ( + isinstance(response_data, dict) + and "properties" in response_data + and len(response_data) == 1 + ): + response_data = response_data["properties"] + return json.dumps(response_data) + except json.JSONDecodeError: + pass + return json_str + + @staticmethod + def _filter_json_mode_tools( + json_mode: Optional[bool], + tools: List[ChatCompletionToolCallChunk], + chat_completion_message: ChatCompletionResponseMessage, + ) -> Optional[List[ChatCompletionToolCallChunk]]: + """ + When json_mode is True, Bedrock may return the internal `json_tool_call` + tool alongside real user-defined tools. This method handles 3 scenarios: + + 1. Only json_tool_call present -> convert to text content, return None + 2. Mixed json_tool_call + real -> filter out json_tool_call, return real tools + 3. 
No json_tool_call / no json_mode -> return tools as-is + """ + if not json_mode or not tools: + return tools if tools else None + + json_tool_indices = [ + i + for i, t in enumerate(tools) + if t["function"].get("name") == RESPONSE_FORMAT_TOOL_NAME + ] + + if not json_tool_indices: + # No json_tool_call found, return tools unchanged + return tools + + if len(json_tool_indices) == len(tools): + # All tools are json_tool_call — convert first one to content + verbose_logger.debug( + "Processing JSON tool call response for response_format" + ) + json_mode_content_str: Optional[str] = tools[0]["function"].get( + "arguments" + ) + if json_mode_content_str is not None: + json_mode_content_str = AmazonConverseConfig._unwrap_bedrock_properties( + json_mode_content_str + ) + chat_completion_message["content"] = json_mode_content_str + return None + + # Mixed: filter out json_tool_call, keep real tools. + # Preserve the json_tool_call content as message text so the structured + # output from response_format is not silently lost. 
+ first_idx = json_tool_indices[0] + json_mode_args = tools[first_idx]["function"].get("arguments") + if json_mode_args is not None: + json_mode_args = AmazonConverseConfig._unwrap_bedrock_properties( + json_mode_args + ) + existing = chat_completion_message.get("content") or "" + chat_completion_message["content"] = ( + existing + json_mode_args if existing else json_mode_args + ) + + real_tools = [t for i, t in enumerate(tools) if i not in json_tool_indices] + return real_tools if real_tools else None + def _transform_response( # noqa: PLR0915 self, model: str, @@ -1793,7 +1887,7 @@ def _transform_response( # noqa: PLR0915 additional_args={"complete_input_dict": data}, ) - json_mode: Optional[bool] = optional_params.pop("json_mode", None) + json_mode: Optional[bool] = optional_params.get("json_mode", None) ## RESPONSE OBJECT try: completion_response = ConverseResponseBlock(**response.json()) # type: ignore @@ -1877,37 +1971,13 @@ def _transform_response( # noqa: PLR0915 self._transform_thinking_blocks(reasoningContentBlocks) ) chat_completion_message["content"] = content_str - if ( - json_mode is True - and tools is not None - and len(tools) == 1 - and tools[0]["function"].get("name") == RESPONSE_FORMAT_TOOL_NAME - ): - verbose_logger.debug( - "Processing JSON tool call response for response_format" - ) - json_mode_content_str: Optional[str] = tools[0]["function"].get("arguments") - if json_mode_content_str is not None: - # Bedrock returns the response wrapped in a "properties" object - # We need to extract the actual content from this wrapper - try: - response_data = json.loads(json_mode_content_str) - - # If Bedrock wrapped the response in "properties", extract the content - if ( - isinstance(response_data, dict) - and "properties" in response_data - and len(response_data) == 1 - ): - response_data = response_data["properties"] - json_mode_content_str = json.dumps(response_data) - except json.JSONDecodeError: - # If parsing fails, use the original response - 
pass - - chat_completion_message["content"] = json_mode_content_str - elif tools: - chat_completion_message["tool_calls"] = tools + filtered_tools = self._filter_json_mode_tools( + json_mode=json_mode, + tools=tools, + chat_completion_message=chat_completion_message, + ) + if filtered_tools: + chat_completion_message["tool_calls"] = filtered_tools ## CALCULATING USAGE - bedrock returns usage in the headers usage = self._transform_usage(completion_response["usage"]) diff --git a/litellm/llms/bedrock/chat/invoke_handler.py b/litellm/llms/bedrock/chat/invoke_handler.py index 1c58a11eebe..88f7341ed08 100644 --- a/litellm/llms/bedrock/chat/invoke_handler.py +++ b/litellm/llms/bedrock/chat/invoke_handler.py @@ -22,6 +22,7 @@ from litellm import verbose_logger from litellm._uuid import uuid from litellm.caching.caching import InMemoryCache +from litellm.constants import RESPONSE_FORMAT_TOOL_NAME from litellm.litellm_core_utils.core_helpers import map_finish_reason from litellm.litellm_core_utils.litellm_logging import Logging from litellm.litellm_core_utils.logging_utils import track_llm_api_timing @@ -252,7 +253,7 @@ async def make_call( response.aiter_bytes(chunk_size=stream_chunk_size) ) else: - decoder = AWSEventStreamDecoder(model=model) + decoder = AWSEventStreamDecoder(model=model, json_mode=json_mode) completion_stream = decoder.aiter_bytes( response.aiter_bytes(chunk_size=stream_chunk_size) ) @@ -346,7 +347,7 @@ def make_sync_call( response.iter_bytes(chunk_size=stream_chunk_size) ) else: - decoder = AWSEventStreamDecoder(model=model) + decoder = AWSEventStreamDecoder(model=model, json_mode=json_mode) completion_stream = decoder.iter_bytes( response.iter_bytes(chunk_size=stream_chunk_size) ) @@ -1282,7 +1283,7 @@ def get_response_stream_shape(): class AWSEventStreamDecoder: - def __init__(self, model: str) -> None: + def __init__(self, model: str, json_mode: Optional[bool] = False) -> None: from botocore.parsers import EventStreamJSONParser self.model = model @@ 
-1290,6 +1291,8 @@ def __init__(self, model: str) -> None: self.content_blocks: List[ContentBlockDeltaEvent] = [] self.tool_calls_index: Optional[int] = None self.response_id: Optional[str] = None + self.json_mode = json_mode + self._current_tool_name: Optional[str] = None def check_empty_tool_call_args(self) -> bool: """ @@ -1391,6 +1394,16 @@ def _handle_converse_start_event( response_tool_name = get_bedrock_tool_name( response_tool_name=_response_tool_name ) + self._current_tool_name = response_tool_name + + # When json_mode is True, suppress the internal json_tool_call + # and convert its content to text in delta events instead + if ( + self.json_mode is True + and response_tool_name == RESPONSE_FORMAT_TOOL_NAME + ): + return tool_use, provider_specific_fields, thinking_blocks + self.tool_calls_index = ( 0 if self.tool_calls_index is None else self.tool_calls_index + 1 ) @@ -1445,19 +1458,27 @@ def _handle_converse_delta_event( if "text" in delta_obj: text = delta_obj["text"] elif "toolUse" in delta_obj: - tool_use = { - "id": None, - "type": "function", - "function": { - "name": None, - "arguments": delta_obj["toolUse"]["input"], - }, - "index": ( - self.tool_calls_index - if self.tool_calls_index is not None - else index - ), - } + # When json_mode is True and this is the internal json_tool_call, + # convert tool input to text content instead of tool call arguments + if ( + self.json_mode is True + and self._current_tool_name == RESPONSE_FORMAT_TOOL_NAME + ): + text = delta_obj["toolUse"]["input"] + else: + tool_use = { + "id": None, + "type": "function", + "function": { + "name": None, + "arguments": delta_obj["toolUse"]["input"], + }, + "index": ( + self.tool_calls_index + if self.tool_calls_index is not None + else index + ), + } elif "reasoningContent" in delta_obj: provider_specific_fields = { "reasoningContent": delta_obj["reasoningContent"], @@ -1494,6 +1515,17 @@ def _handle_converse_stop_event( ) -> Optional[ChatCompletionToolCallChunk]: """Handle 
stop/contentBlockIndex event in converse chunk parsing.""" tool_use: Optional[ChatCompletionToolCallChunk] = None + + # If the ending block was the internal json_tool_call, skip emitting + # the empty-args tool chunk and reset tracking state + if ( + self.json_mode is True + and self._current_tool_name == RESPONSE_FORMAT_TOOL_NAME + ): + self._current_tool_name = None + return tool_use + + self._current_tool_name = None is_empty = self.check_empty_tool_call_args() if is_empty: tool_use = { diff --git a/litellm/llms/openai/openai.py b/litellm/llms/openai/openai.py index c7524925bd0..7020f796bb7 100644 --- a/litellm/llms/openai/openai.py +++ b/litellm/llms/openai/openai.py @@ -1401,6 +1401,7 @@ async def aimage_generation( client=None, max_retries=None, organization: Optional[str] = None, + headers: Optional[dict] = None, ): response = None try: @@ -1414,6 +1415,8 @@ async def aimage_generation( client=client, ) + if headers: + data["extra_headers"] = headers response = await openai_aclient.images.generate(**data, timeout=timeout) # type: ignore stringified_response = response.model_dump() ## LOGGING @@ -1446,6 +1449,7 @@ def image_generation( client=None, aimg_generation=None, organization: Optional[str] = None, + headers: Optional[dict] = None, ) -> ImageResponse: data = {} try: @@ -1455,7 +1459,7 @@ def image_generation( raise OpenAIError(status_code=422, message="max retries must be an int") if aimg_generation is True: - return self.aimage_generation(data=data, prompt=prompt, logging_obj=logging_obj, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries, organization=organization) # type: ignore + return self.aimage_generation(data=data, prompt=prompt, logging_obj=logging_obj, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries, organization=organization, headers=headers) # type: ignore openai_client: OpenAI = 
self._get_openai_client( # type: ignore is_async=False, @@ -1480,6 +1484,8 @@ def image_generation( ) ## COMPLETION CALL + if headers: + data["extra_headers"] = headers _response = openai_client.images.generate(**data, timeout=timeout) # type: ignore response = _response.model_dump() diff --git a/litellm/llms/openai_like/providers.json b/litellm/llms/openai_like/providers.json index 1b1b1c2f8cc..b3125d4ad38 100644 --- a/litellm/llms/openai_like/providers.json +++ b/litellm/llms/openai_like/providers.json @@ -90,5 +90,9 @@ "headers": { "api-subscription-key": "{api_key}" } + }, + "assemblyai": { + "base_url": "https://llm-gateway.assemblyai.com/v1", + "api_key_env": "ASSEMBLYAI_API_KEY" } } diff --git a/litellm/model_prices_and_context_window_backup.json b/litellm/model_prices_and_context_window_backup.json index b21f23ac022..f52288ea72a 100644 --- a/litellm/model_prices_and_context_window_backup.json +++ b/litellm/model_prices_and_context_window_backup.json @@ -19210,6 +19210,39 @@ "supports_tool_choice": true, "supports_vision": false }, + "gpt-audio-1.5": { + "input_cost_per_audio_token": 3.2e-05, + "input_cost_per_token": 2.5e-06, + "litellm_provider": "openai", + "max_input_tokens": 128000, + "max_output_tokens": 16384, + "max_tokens": 16384, + "mode": "chat", + "output_cost_per_audio_token": 6.4e-05, + "output_cost_per_token": 1e-05, + "supported_endpoints": [ + "/v1/chat/completions" + ], + "supported_modalities": [ + "text", + "audio" + ], + "supported_output_modalities": [ + "text", + "audio" + ], + "supports_audio_input": true, + "supports_audio_output": true, + "supports_function_calling": true, + "supports_native_streaming": true, + "supports_parallel_function_calling": true, + "supports_prompt_caching": false, + "supports_reasoning": false, + "supports_response_schema": false, + "supports_system_messages": true, + "supports_tool_choice": true, + "supports_vision": false + }, "gpt-audio-2025-08-28": { "input_cost_per_audio_token": 3.2e-05, 
"input_cost_per_token": 2.5e-06, @@ -20927,6 +20960,38 @@ "supports_system_messages": true, "supports_tool_choice": true }, + "gpt-realtime-1.5": { + "cache_creation_input_audio_token_cost": 4e-07, + "cache_read_input_token_cost": 4e-07, + "input_cost_per_audio_token": 3.2e-05, + "input_cost_per_image": 5e-06, + "input_cost_per_token": 4e-06, + "litellm_provider": "openai", + "max_input_tokens": 32000, + "max_output_tokens": 4096, + "max_tokens": 4096, + "mode": "chat", + "output_cost_per_audio_token": 6.4e-05, + "output_cost_per_token": 1.6e-05, + "supported_endpoints": [ + "/v1/realtime" + ], + "supported_modalities": [ + "text", + "image", + "audio" + ], + "supported_output_modalities": [ + "text", + "audio" + ], + "supports_audio_input": true, + "supports_audio_output": true, + "supports_function_calling": true, + "supports_parallel_function_calling": true, + "supports_system_messages": true, + "supports_tool_choice": true + }, "gpt-realtime-mini": { "cache_creation_input_audio_token_cost": 3e-07, "cache_read_input_audio_token_cost": 3e-07, @@ -25092,6 +25157,25 @@ "supports_vision": true, "tool_use_system_prompt_tokens": 159 }, + "openrouter/anthropic/claude-opus-4.6": { + "cache_creation_input_token_cost": 6.25e-06, + "cache_read_input_token_cost": 5e-07, + "input_cost_per_token": 5e-06, + "litellm_provider": "openrouter", + "max_input_tokens": 1000000, + "max_output_tokens": 128000, + "max_tokens": 128000, + "mode": "chat", + "output_cost_per_token": 2.5e-05, + "supports_assistant_prefill": true, + "supports_computer_use": true, + "supports_function_calling": true, + "supports_prompt_caching": true, + "supports_reasoning": true, + "supports_tool_choice": true, + "supports_vision": true, + "tool_use_system_prompt_tokens": 346 + }, "openrouter/anthropic/claude-sonnet-4.5": { "input_cost_per_image": 0.0048, "cache_creation_input_token_cost": 3.75e-06, @@ -26104,6 +26188,42 @@ "supports_prompt_caching": true, "supports_computer_use": false }, + 
"openrouter/openrouter/auto": { + "input_cost_per_token": 0, + "output_cost_per_token": 0, + "litellm_provider": "openrouter", + "max_input_tokens": 2000000, + "max_tokens": 2000000, + "mode": "chat", + "supports_function_calling": true, + "supports_tool_choice": true, + "supports_reasoning": true, + "supports_response_schema": true, + "supports_vision": true, + "supports_audio_input": true, + "supports_video_input": true + }, + "openrouter/openrouter/free": { + "input_cost_per_token": 0, + "output_cost_per_token": 0, + "litellm_provider": "openrouter", + "max_input_tokens": 200000, + "max_tokens": 200000, + "mode": "chat", + "supports_function_calling": true, + "supports_tool_choice": true, + "supports_reasoning": true, + "supports_response_schema": true, + "supports_vision": true + }, + "openrouter/openrouter/bodybuilder": { + "input_cost_per_token": 0, + "output_cost_per_token": 0, + "litellm_provider": "openrouter", + "max_input_tokens": 128000, + "max_tokens": 128000, + "mode": "chat" + }, "ovhcloud/DeepSeek-R1-Distill-Llama-70B": { "input_cost_per_token": 6.7e-07, "litellm_provider": "ovhcloud", @@ -26650,8 +26770,8 @@ "mode": "chat", "output_cost_per_token": 0.0, "source": "https://platform.publicai.co/docs", - "supports_function_calling": true, - "supports_tool_choice": true + "supports_function_calling": false, + "supports_tool_choice": false }, "publicai/swiss-ai/apertus-70b-instruct": { "input_cost_per_token": 0.0, @@ -26662,8 +26782,8 @@ "mode": "chat", "output_cost_per_token": 0.0, "source": "https://platform.publicai.co/docs", - "supports_function_calling": true, - "supports_tool_choice": true + "supports_function_calling": false, + "supports_tool_choice": false }, "publicai/aisingapore/Gemma-SEA-LION-v4-27B-IT": { "input_cost_per_token": 0.0, @@ -32991,6 +33111,7 @@ "supports_web_search": true }, "xai/grok-2-vision-1212": { + "deprecation_date": "2026-02-28", "input_cost_per_image": 2e-06, "input_cost_per_token": 2e-06, "litellm_provider": "xai", @@ 
-33095,6 +33216,7 @@ }, "xai/grok-3-mini": { "cache_read_input_token_cost": 7.5e-08, + "deprecation_date": "2026-02-28", "input_cost_per_token": 3e-07, "litellm_provider": "xai", "max_input_tokens": 131072, @@ -33111,6 +33233,7 @@ }, "xai/grok-3-mini-beta": { "cache_read_input_token_cost": 7.5e-08, + "deprecation_date": "2026-02-28", "input_cost_per_token": 3e-07, "litellm_provider": "xai", "max_input_tokens": 131072, diff --git a/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py b/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py index 7484de33ce4..08213f40b43 100644 --- a/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py +++ b/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py @@ -331,7 +331,7 @@ async def load_servers_from_config( static_headers=server_config.get("static_headers", None), allow_all_keys=bool(server_config.get("allow_all_keys", False)), available_on_public_internet=bool( - server_config.get("available_on_public_internet", False) + server_config.get("available_on_public_internet", True) ), ) self.config_mcp_servers[server_id] = new_server @@ -634,7 +634,7 @@ async def build_mcp_server_from_table( disallowed_tools=getattr(mcp_server, "disallowed_tools", None), allow_all_keys=mcp_server.allow_all_keys, available_on_public_internet=bool( - getattr(mcp_server, "available_on_public_internet", False) + getattr(mcp_server, "available_on_public_internet", True) ), updated_at=getattr(mcp_server, "updated_at", None), ) diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index dfc2ba59d96..afedb6c8e72 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -431,6 +431,7 @@ class LiteLLMRoutes(enum.Enum): agent_routes = [ "/v1/agents", + "/v1/agents/{agent_id}", "/agents", "/a2a/{agent_id}", "/a2a/{agent_id}/message/send", @@ -1092,7 +1093,7 @@ class NewMCPServerRequest(LiteLLMPydanticObjectBase): token_url: Optional[str] = None registration_url: Optional[str] = None allow_all_keys: bool = 
False - available_on_public_internet: bool = False + available_on_public_internet: bool = True @model_validator(mode="before") @classmethod @@ -1146,7 +1147,7 @@ class UpdateMCPServerRequest(LiteLLMPydanticObjectBase): token_url: Optional[str] = None registration_url: Optional[str] = None allow_all_keys: bool = False - available_on_public_internet: bool = False + available_on_public_internet: bool = True @model_validator(mode="before") @classmethod @@ -1203,7 +1204,7 @@ class LiteLLM_MCPServerTable(LiteLLMPydanticObjectBase): token_url: Optional[str] = None registration_url: Optional[str] = None allow_all_keys: bool = False - available_on_public_internet: bool = False + available_on_public_internet: bool = True class MakeMCPServersPublicRequest(LiteLLMPydanticObjectBase): @@ -2416,6 +2417,7 @@ def get_litellm_internal_jobs_user_api_key_auth(cls) -> "UserAPIKeyAuth": key_alias=LITELLM_INTERNAL_JOBS_SERVICE_ACCOUNT_NAME, team_alias="system", user_id="system", + user_role=LitellmUserRoles.PROXY_ADMIN, ) diff --git a/litellm/proxy/agent_endpoints/endpoints.py b/litellm/proxy/agent_endpoints/endpoints.py index b411b81b434..65674d01be7 100644 --- a/litellm/proxy/agent_endpoints/endpoints.py +++ b/litellm/proxy/agent_endpoints/endpoints.py @@ -31,6 +31,23 @@ router = APIRouter() +def _check_agent_management_permission(user_api_key_dict: UserAPIKeyAuth) -> None: + """ + Raises HTTP 403 if the caller does not have permission to create, update, + or delete agents. Only PROXY_ADMIN users are allowed to perform these + write operations. + """ + if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN: + raise HTTPException( + status_code=403, + detail={ + "error": "Only proxy admins can create, update, or delete agents. 
Your role={}".format( + user_api_key_dict.user_role + ) + }, + ) + + @router.get( "/v1/agents", tags=["[beta] A2A Agents"], @@ -164,6 +181,8 @@ async def create_agent( """ from litellm.proxy.proxy_server import prisma_client + _check_agent_management_permission(user_api_key_dict) + if prisma_client is None: raise HTTPException(status_code=500, detail="Prisma client not initialized") @@ -302,6 +321,8 @@ async def update_agent( """ from litellm.proxy.proxy_server import prisma_client + _check_agent_management_permission(user_api_key_dict) + if prisma_client is None: raise HTTPException( status_code=500, detail=CommonProxyErrors.db_not_connected_error.value @@ -391,6 +412,8 @@ async def patch_agent( """ from litellm.proxy.proxy_server import prisma_client + _check_agent_management_permission(user_api_key_dict) + if prisma_client is None: raise HTTPException( status_code=500, detail=CommonProxyErrors.db_not_connected_error.value @@ -441,7 +464,10 @@ async def patch_agent( tags=["Agents"], dependencies=[Depends(user_api_key_auth)], ) -async def delete_agent(agent_id: str): +async def delete_agent( + agent_id: str, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): """ Delete an agent @@ -460,6 +486,8 @@ async def delete_agent(agent_id: str): """ from litellm.proxy.proxy_server import prisma_client + _check_agent_management_permission(user_api_key_dict) + if prisma_client is None: raise HTTPException(status_code=500, detail="Prisma client not initialized") diff --git a/litellm/proxy/auth/handle_jwt.py b/litellm/proxy/auth/handle_jwt.py index 9921b74b561..553ba4d6c49 100644 --- a/litellm/proxy/auth/handle_jwt.py +++ b/litellm/proxy/auth/handle_jwt.py @@ -8,6 +8,7 @@ import fnmatch import os +import re from typing import Any, List, Literal, Optional, Set, Tuple, cast from cryptography import x509 @@ -235,7 +236,17 @@ def get_team_id(self, token: dict, default_value: Optional[str]) -> Optional[str return self.litellm_jwtauth.team_id_default else: return 
default_value - # At this point, team_id is not the sentinel, so it should be a string + # AAD and other IdPs often send roles/groups as a list of strings. + # team_id_jwt_field is singular, so take the first element when a list + # is returned. This avoids "unhashable type: 'list'" errors downstream. + if isinstance(team_id, list): + if not team_id: + return default_value + verbose_proxy_logger.debug( + f"JWT Auth: team_id_jwt_field '{self.litellm_jwtauth.team_id_jwt_field}' " + f"returned a list {team_id}; using first element '{team_id[0]}' automatically." + ) + team_id = team_id[0] return team_id # type: ignore[return-value] elif self.litellm_jwtauth.team_id_default is not None: team_id = self.litellm_jwtauth.team_id_default @@ -453,6 +464,52 @@ def get_scopes(self, token: dict) -> List[str]: scopes = [] return scopes + async def _resolve_jwks_url(self, url: str) -> str: + """ + If url points to an OIDC discovery document (*.well-known/openid-configuration), + fetch it and return the jwks_uri contained within. Otherwise return url unchanged. + This lets JWT_PUBLIC_KEY_URL be set to a well-known discovery endpoint instead of + requiring operators to manually find the JWKS URL. 
+ """ + if ".well-known/openid-configuration" not in url: + return url + + cache_key = f"litellm_oidc_discovery_{url}" + cached_jwks_uri = await self.user_api_key_cache.async_get_cache(cache_key) + if cached_jwks_uri is not None: + return cached_jwks_uri + + verbose_proxy_logger.debug( + f"JWT Auth: Fetching OIDC discovery document from {url}" + ) + response = await self.http_handler.get(url) + if response.status_code != 200: + raise Exception( + f"JWT Auth: OIDC discovery endpoint {url} returned status {response.status_code}: {response.text}" + ) + try: + discovery = response.json() + except Exception as e: + raise Exception( + f"JWT Auth: Failed to parse OIDC discovery document at {url}: {e}" + ) + + jwks_uri = discovery.get("jwks_uri") + if not jwks_uri: + raise Exception( + f"JWT Auth: OIDC discovery document at {url} does not contain a 'jwks_uri' field." + ) + + verbose_proxy_logger.debug( + f"JWT Auth: Resolved OIDC discovery {url} -> jwks_uri={jwks_uri}" + ) + await self.user_api_key_cache.async_set_cache( + key=cache_key, + value=jwks_uri, + ttl=self.litellm_jwtauth.public_key_ttl, + ) + return jwks_uri + async def get_public_key(self, kid: Optional[str]) -> dict: keys_url = os.getenv("JWT_PUBLIC_KEY_URL") @@ -462,6 +519,7 @@ async def get_public_key(self, kid: Optional[str]) -> dict: keys_url_list = [url.strip() for url in keys_url.split(",")] for key_url in keys_url_list: + key_url = await self._resolve_jwks_url(key_url) cache_key = f"litellm_jwt_auth_keys_{key_url}" cached_keys = await self.user_api_key_cache.async_get_cache(cache_key) @@ -913,8 +971,30 @@ async def find_and_validate_specific_team_id( if jwt_handler.is_required_team_id() is True: team_id_field = jwt_handler.litellm_jwtauth.team_id_jwt_field team_alias_field = jwt_handler.litellm_jwtauth.team_alias_jwt_field + hint = "" + if team_id_field: + # "roles.0" — dot-notation numeric indexing is not supported + if "." 
in team_id_field: + parts = team_id_field.rsplit(".", 1) + if parts[-1].isdigit(): + base_field = parts[0] + hint = ( + f" Hint: dot-notation array indexing (e.g. '{team_id_field}') is not " + f"supported. Use '{base_field}' instead — LiteLLM automatically " + f"uses the first element when the field value is a list." + ) + # "roles[0]" — bracket-notation indexing is also not supported in get_nested_value + elif "[" in team_id_field and team_id_field.endswith("]"): + m = re.match(r"^(\w+)\[(\d+)\]$", team_id_field) + if m: + base_field = m.group(1) + hint = ( + f" Hint: array indexing (e.g. '{team_id_field}') is not supported " + f"in team_id_jwt_field. Use '{base_field}' instead — LiteLLM " + f"automatically uses the first element when the field value is a list." + ) raise Exception( - f"No team found in token. Checked team_id field '{team_id_field}' and team_alias field '{team_alias_field}'" + f"No team found in token. Checked team_id field '{team_id_field}' and team_alias field '{team_alias_field}'.{hint}" ) return individual_team_id, team_object diff --git a/litellm/proxy/guardrails/guardrail_hooks/litellm_content_filter/guardrail_benchmarks/results/block_insults_-_contentfilter_(denied_insults.yaml).json b/litellm/proxy/guardrails/guardrail_hooks/litellm_content_filter/guardrail_benchmarks/results/insults_cf.json similarity index 100% rename from litellm/proxy/guardrails/guardrail_hooks/litellm_content_filter/guardrail_benchmarks/results/block_insults_-_contentfilter_(denied_insults.yaml).json rename to litellm/proxy/guardrails/guardrail_hooks/litellm_content_filter/guardrail_benchmarks/results/insults_cf.json diff --git a/litellm/proxy/guardrails/guardrail_hooks/litellm_content_filter/guardrail_benchmarks/results/block_investment_-_contentfilter_(denied_financial_advice.yaml).json b/litellm/proxy/guardrails/guardrail_hooks/litellm_content_filter/guardrail_benchmarks/results/investment_cf.json similarity index 100% rename from 
litellm/proxy/guardrails/guardrail_hooks/litellm_content_filter/guardrail_benchmarks/results/block_investment_-_contentfilter_(denied_financial_advice.yaml).json rename to litellm/proxy/guardrails/guardrail_hooks/litellm_content_filter/guardrail_benchmarks/results/investment_cf.json diff --git a/litellm/proxy/guardrails/guardrail_hooks/litellm_content_filter/guardrail_benchmarks/test_eval.py b/litellm/proxy/guardrails/guardrail_hooks/litellm_content_filter/guardrail_benchmarks/test_eval.py index 01e820163fd..ca66b4da652 100644 --- a/litellm/proxy/guardrails/guardrail_hooks/litellm_content_filter/guardrail_benchmarks/test_eval.py +++ b/litellm/proxy/guardrails/guardrail_hooks/litellm_content_filter/guardrail_benchmarks/test_eval.py @@ -18,6 +18,7 @@ import json import os +import re import time from datetime import datetime, timezone from typing import Any, Dict, List @@ -105,7 +106,25 @@ def _print_confusion_report(label: str, metrics: dict, wrong: list) -> None: def _save_confusion_results(label: str, metrics: dict, wrong: list, rows: list) -> dict: """Save confusion matrix results to a JSON file and return the result dict.""" os.makedirs(RESULTS_DIR, exist_ok=True) - safe_label = label.lower().replace(" ", "_").replace("—", "-") + # Build a short, filesystem-safe filename from the label. + # Full label is preserved inside the JSON; filename just needs to be + # unique and recognisable. Format: {topic}_{method_abbrev}.json + parts = label.split("\u2014") + topic = parts[0].strip().lower().replace("block ", "").replace(" ", "_") + method_full = parts[1].strip() if len(parts) > 1 else "" + method_name = re.sub(r"\s*\(.*?\)", "", method_full).strip().lower() + qualifier_match = re.search(r"\(([^)]+)\)", method_full) + qualifier = qualifier_match.group(1) if qualifier_match else "" + qualifier = re.sub(r"\.[a-z]+$", "", qualifier) # drop .yaml etc. 
+ if method_name == "contentfilter": + safe_label = f"{topic}_cf" + elif qualifier: + safe_label = f"{topic}_{method_name}_{qualifier}" + else: + safe_label = f"{topic}_{method_name}" + safe_label = safe_label.replace(" ", "_") + safe_label = re.sub(r"[^a-z0-9_.\-]", "", safe_label) + safe_label = re.sub(r"_+", "_", safe_label).strip("_") result = { "label": label, "timestamp": datetime.now(timezone.utc).isoformat(), diff --git a/litellm/proxy/health_endpoints/_health_endpoints.py b/litellm/proxy/health_endpoints/_health_endpoints.py index 4496ad92631..95b1836d8a9 100644 --- a/litellm/proxy/health_endpoints/_health_endpoints.py +++ b/litellm/proxy/health_endpoints/_health_endpoints.py @@ -33,6 +33,9 @@ perform_health_check, run_with_timeout, ) +from litellm.proxy.middleware.in_flight_requests_middleware import ( + get_in_flight_requests, +) from litellm.secret_managers.main import get_secret #### Health ENDPOINTS #### @@ -1297,6 +1300,23 @@ async def health_readiness(): raise HTTPException(status_code=503, detail=f"Service Unhealthy ({str(e)})") +@router.get( + "/health/backlog", + tags=["health"], + dependencies=[Depends(user_api_key_auth)], +) +async def health_backlog(): + """ + Returns the number of HTTP requests currently in-flight on this uvicorn worker. + + Use this to measure per-pod queue depth. A high value means the worker is + processing many concurrent requests — requests arriving now will have to wait + for the event loop to get to them, adding latency before LiteLLM even starts + its own timer. 
+ """ + return {"in_flight_requests": get_in_flight_requests()} + + @router.get( "/health/liveliness", # Historical LiteLLM name; doesn't match k8s terminology but kept for backwards compatibility tags=["health"], diff --git a/litellm/proxy/management_endpoints/internal_user_endpoints.py b/litellm/proxy/management_endpoints/internal_user_endpoints.py index e535ccaaa46..5a0a05114a3 100644 --- a/litellm/proxy/management_endpoints/internal_user_endpoints.py +++ b/litellm/proxy/management_endpoints/internal_user_endpoints.py @@ -614,7 +614,7 @@ async def user_info( user_id is None and user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN ): - return await _get_user_info_for_proxy_admin() + return await _get_user_info_for_proxy_admin(user_api_key_dict=user_api_key_dict) elif user_id is None: user_id = user_api_key_dict.user_id ## GET USER ROW ## @@ -714,7 +714,7 @@ async def user_info( raise handle_exception_on_proxy(e) -async def _get_user_info_for_proxy_admin(): +async def _get_user_info_for_proxy_admin(user_api_key_dict: UserAPIKeyAuth): """ Admin UI Endpoint - Returns All Teams and Keys when Proxy Admin is querying @@ -754,9 +754,23 @@ async def _get_user_info_for_proxy_admin(): _teams_in_db = [LiteLLM_TeamTable(**team) for team in _teams_in_db] _teams_in_db.sort(key=lambda x: (getattr(x, "team_alias", "") or "")) returned_keys = _process_keys_for_user_info(keys=keys_in_db, all_teams=_teams_in_db) + + # Get admin's own user_id and user_info + admin_user_id = user_api_key_dict.user_id + admin_user_info = None + + if admin_user_id is not None: + admin_user_info = await prisma_client.get_data(user_id=admin_user_id) + if admin_user_info is not None: + admin_user_info = ( + admin_user_info.model_dump() + if isinstance(admin_user_info, BaseModel) + else admin_user_info + ) + return UserInfoResponse( - user_id=None, - user_info=None, + user_id=admin_user_id, + user_info=admin_user_info, keys=returned_keys, teams=_teams_in_db, ) diff --git 
a/litellm/proxy/management_endpoints/key_management_endpoints.py b/litellm/proxy/management_endpoints/key_management_endpoints.py index 5b56133f1ce..9414ce6f686 100644 --- a/litellm/proxy/management_endpoints/key_management_endpoints.py +++ b/litellm/proxy/management_endpoints/key_management_endpoints.py @@ -3983,6 +3983,10 @@ async def list_keys( status: Optional[str] = Query( None, description="Filter by status (e.g. 'deleted')" ), + project_id: Optional[str] = Query(None, description="Filter keys by project ID"), + access_group_id: Optional[str] = Query( + None, description="Filter keys by access group ID" + ), ) -> KeyListResponseObject: """ List all keys for a given user / team / organization. @@ -4076,6 +4080,8 @@ async def list_keys( sort_order=sort_order, expand=expand, status=status, + project_id=project_id, + access_group_id=access_group_id, ) verbose_proxy_logger.debug("Successfully prepared response") @@ -4252,6 +4258,8 @@ def _build_key_filter_conditions( admin_team_ids: Optional[List[str]], member_team_ids: Optional[List[str]] = None, include_created_by_keys: bool = False, + project_id: Optional[str] = None, + access_group_id: Optional[str] = None, ) -> Dict[str, Union[str, Dict[str, Any], List[Dict[str, Any]]]]: """Build filter conditions for key listing. @@ -4343,6 +4351,13 @@ def _build_key_filter_conditions( elif len(or_conditions) == 1: where.update(or_conditions[0]) + # Apply project_id and access_group_id as global AND filters so they + # narrow results across all visibility conditions (own keys, team keys, etc.) 
+ if project_id: + where = {"AND": [where, {"project_id": project_id}]} + if access_group_id: + where = {"AND": [where, {"access_group_ids": {"hasSome": [access_group_id]}}]} + verbose_proxy_logger.debug(f"Filter conditions: {where}") return where @@ -4369,6 +4384,8 @@ async def _list_key_helper( sort_order: str = "desc", expand: Optional[List[str]] = None, status: Optional[str] = None, + project_id: Optional[str] = None, + access_group_id: Optional[str] = None, ) -> KeyListResponseObject: """ Helper function to list keys @@ -4402,6 +4419,8 @@ async def _list_key_helper( admin_team_ids=admin_team_ids, member_team_ids=member_team_ids, include_created_by_keys=include_created_by_keys, + project_id=project_id, + access_group_id=access_group_id, ) # Calculate skip for pagination diff --git a/litellm/proxy/middleware/in_flight_requests_middleware.py b/litellm/proxy/middleware/in_flight_requests_middleware.py new file mode 100644 index 00000000000..d615640d870 --- /dev/null +++ b/litellm/proxy/middleware/in_flight_requests_middleware.py @@ -0,0 +1,81 @@ +""" +Tracks the number of HTTP requests currently in-flight on this uvicorn worker. + +Used by /health/backlog to expose per-pod queue depth, and emitted as the +Prometheus gauge `litellm_in_flight_requests`. +""" + +import os +from typing import Optional + +from starlette.types import ASGIApp, Receive, Scope, Send + + +class InFlightRequestsMiddleware: + """ + ASGI middleware that increments a counter when a request arrives and + decrements it when the response is sent (or an error occurs). + + The counter is class-level and therefore scoped to a single uvicorn worker + process — exactly the per-pod granularity we want. + + Also updates the `litellm_in_flight_requests` Prometheus gauge if + prometheus_client is installed. The gauge is lazily initialised on the + first request so that PROMETHEUS_MULTIPROC_DIR is already set by the time + we register the metric. 
Initialisation is attempted only once — if + prometheus_client is absent the class remembers and never retries. + """ + + _in_flight: int = 0 + _gauge: Optional[object] = None + _gauge_init_attempted: bool = False + + def __init__(self, app: ASGIApp) -> None: + self.app = app + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + if scope["type"] != "http": + await self.app(scope, receive, send) + return + + InFlightRequestsMiddleware._in_flight += 1 + gauge = InFlightRequestsMiddleware._get_gauge() + if gauge is not None: + gauge.inc() # type: ignore[union-attr] + try: + await self.app(scope, receive, send) + finally: + InFlightRequestsMiddleware._in_flight -= 1 + if gauge is not None: + gauge.dec() # type: ignore[union-attr] + + @staticmethod + def get_count() -> int: + """Return the number of HTTP requests currently in-flight.""" + return InFlightRequestsMiddleware._in_flight + + @staticmethod + def _get_gauge() -> Optional[object]: + if InFlightRequestsMiddleware._gauge_init_attempted: + return InFlightRequestsMiddleware._gauge + InFlightRequestsMiddleware._gauge_init_attempted = True + try: + from prometheus_client import Gauge + + kwargs = {} + if "PROMETHEUS_MULTIPROC_DIR" in os.environ: + # livesum aggregates across all worker processes in the scrape response + kwargs["multiprocess_mode"] = "livesum" + InFlightRequestsMiddleware._gauge = Gauge( + "litellm_in_flight_requests", + "Number of HTTP requests currently in-flight on this uvicorn worker", + **kwargs, + ) + except Exception: + InFlightRequestsMiddleware._gauge = None + return InFlightRequestsMiddleware._gauge + + +def get_in_flight_requests() -> int: + """Module-level convenience wrapper used by the /health/backlog endpoint.""" + return InFlightRequestsMiddleware.get_count() diff --git a/litellm/proxy/prometheus_cleanup.py b/litellm/proxy/prometheus_cleanup.py index 6d935a8dd90..6353588532a 100644 --- a/litellm/proxy/prometheus_cleanup.py +++ 
b/litellm/proxy/prometheus_cleanup.py @@ -28,3 +28,20 @@ def wipe_directory(directory: str) -> None: verbose_proxy_logger.info( f"Prometheus cleanup: wiped {deleted} stale .db files from {directory}" ) + + +def mark_worker_exit(worker_pid: int) -> None: + """Remove prometheus .db files for a dead worker. Called by gunicorn child_exit hook.""" + if not os.environ.get("PROMETHEUS_MULTIPROC_DIR"): + return + try: + from prometheus_client import multiprocess + + multiprocess.mark_process_dead(worker_pid) + verbose_proxy_logger.info( + f"Prometheus cleanup: marked worker {worker_pid} as dead" + ) + except Exception as e: + verbose_proxy_logger.warning( + f"Failed to mark prometheus worker {worker_pid} as dead: {e}" + ) diff --git a/litellm/proxy/proxy_cli.py b/litellm/proxy/proxy_cli.py index f5163114983..921d86c35c1 100644 --- a/litellm/proxy/proxy_cli.py +++ b/litellm/proxy/proxy_cli.py @@ -277,6 +277,15 @@ def load(self): if max_requests_before_restart is not None: gunicorn_options["max_requests"] = max_requests_before_restart + # Clean up prometheus .db files when a worker exits (prevents ghost gauge values) + if os.environ.get("PROMETHEUS_MULTIPROC_DIR"): + from litellm.proxy.prometheus_cleanup import mark_worker_exit + + def child_exit(server, worker): + mark_worker_exit(worker.pid) + + gunicorn_options["child_exit"] = child_exit + if ssl_certfile_path is not None and ssl_keyfile_path is not None: print( # noqa f"\033[1;32mLiteLLM Proxy: Using SSL with certfile: {ssl_certfile_path} and keyfile: {ssl_keyfile_path}\033[0m\n" # noqa diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index be76c2ac5fb..bd5b5309e0f 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -424,6 +424,9 @@ def generate_feedback_box(): router as user_agent_analytics_router, ) from litellm.proxy.management_helpers.audit_logs import create_audit_log_for_update +from litellm.proxy.middleware.in_flight_requests_middleware import ( + 
InFlightRequestsMiddleware, +) from litellm.proxy.middleware.prometheus_auth_middleware import PrometheusAuthMiddleware from litellm.proxy.ocr_endpoints.endpoints import router as ocr_router from litellm.proxy.openai_evals_endpoints.endpoints import router as evals_router @@ -1404,6 +1407,7 @@ def _restructure_ui_html_files(ui_root: str) -> None: ) app.add_middleware(PrometheusAuthMiddleware) +app.add_middleware(InFlightRequestsMiddleware) def mount_swagger_ui(): @@ -5298,13 +5302,15 @@ async def async_data_generator( ): verbose_proxy_logger.debug("inside generator") try: - # Use a list to accumulate response segments to avoid O(n^2) string concatenation - str_so_far_parts: list[str] = [] error_message: Optional[str] = None requested_model_from_client = _get_client_requested_model_for_streaming( request_data=request_data ) model_mismatch_logged = False + # Use a running string instead of list + join to avoid O(n^2) overhead. + # Previously "".join(str_so_far_parts) was called every chunk, re-joining + # the entire accumulated response. String += is O(n) amortized total. + _str_so_far: str = "" async for chunk in proxy_logging_obj.async_post_call_streaming_iterator_hook( user_api_key_dict=user_api_key_dict, response=response, @@ -5315,12 +5321,12 @@ async def async_data_generator( user_api_key_dict=user_api_key_dict, response=chunk, data=request_data, - str_so_far="".join(str_so_far_parts), + str_so_far=_str_so_far if _str_so_far else None, ) if isinstance(chunk, (ModelResponse, ModelResponseStream)): response_str = litellm.get_response_string(response_obj=chunk) - str_so_far_parts.append(response_str) + _str_so_far += response_str chunk, model_mismatch_logged = _restamp_streaming_chunk_model( chunk=chunk, diff --git a/litellm/proxy/schema.prisma b/litellm/proxy/schema.prisma index a5b0d930f58..f18556ac329 100644 --- a/litellm/proxy/schema.prisma +++ b/litellm/proxy/schema.prisma @@ -300,7 +300,7 @@ model LiteLLM_MCPServerTable { token_url String? 
registration_url String? allow_all_keys Boolean @default(false) - available_on_public_internet Boolean @default(false) + available_on_public_internet Boolean @default(true) } // Generate Tokens for Proxy diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index f6613b5548f..5e0d5336aa9 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -23,23 +23,31 @@ ) from litellm import _custom_logger_compatible_callbacks_literal -from litellm.constants import (DEFAULT_MODEL_CREATED_AT_TIME, - MAX_TEAM_LIST_LIMIT) -from litellm.proxy._types import (DB_CONNECTION_ERROR_TYPES, CommonProxyErrors, - ProxyErrorTypes, ProxyException, - SpendLogsMetadata, SpendLogsPayload) +from litellm.constants import DEFAULT_MODEL_CREATED_AT_TIME, MAX_TEAM_LIST_LIMIT +from litellm.proxy._types import ( + DB_CONNECTION_ERROR_TYPES, + CommonProxyErrors, + ProxyErrorTypes, + ProxyException, + SpendLogsMetadata, + SpendLogsPayload, +) from litellm.types.guardrails import GuardrailEventHooks from litellm.types.utils import CallTypes, CallTypesLiteral try: - from litellm_enterprise.enterprise_callbacks.send_emails.base_email import \ - BaseEmailLogger - from litellm_enterprise.enterprise_callbacks.send_emails.resend_email import \ - ResendEmailLogger - from litellm_enterprise.enterprise_callbacks.send_emails.sendgrid_email import \ - SendGridEmailLogger - from litellm_enterprise.enterprise_callbacks.send_emails.smtp_email import \ - SMTPEmailLogger + from litellm_enterprise.enterprise_callbacks.send_emails.base_email import ( + BaseEmailLogger, + ) + from litellm_enterprise.enterprise_callbacks.send_emails.resend_email import ( + ResendEmailLogger, + ) + from litellm_enterprise.enterprise_callbacks.send_emails.sendgrid_email import ( + SendGridEmailLogger, + ) + from litellm_enterprise.enterprise_callbacks.send_emails.smtp_email import ( + SMTPEmailLogger, + ) except ImportError: BaseEmailLogger = None # type: ignore SendGridEmailLogger = None # type: ignore @@ -58,56 +66,70 @@ 
import litellm import litellm.litellm_core_utils import litellm.litellm_core_utils.litellm_logging -from litellm import (EmbeddingResponse, ImageResponse, ModelResponse, - ModelResponseStream, Router) +from litellm import ( + EmbeddingResponse, + ImageResponse, + ModelResponse, + ModelResponseStream, + Router, +) from litellm._logging import verbose_proxy_logger from litellm._service_logger import ServiceLogging, ServiceTypes from litellm.caching.caching import DualCache, RedisCache from litellm.caching.dual_cache import LimitedSizeOrderedDict from litellm.exceptions import RejectedRequestError -from litellm.integrations.custom_guardrail import (CustomGuardrail, - ModifyResponseException) +from litellm.integrations.custom_guardrail import ( + CustomGuardrail, + ModifyResponseException, +) from litellm.integrations.custom_logger import CustomLogger from litellm.integrations.SlackAlerting.slack_alerting import SlackAlerting -from litellm.integrations.SlackAlerting.utils import \ - _add_langfuse_trace_id_to_alert +from litellm.integrations.SlackAlerting.utils import _add_langfuse_trace_id_to_alert from litellm.litellm_core_utils.litellm_logging import Logging from litellm.litellm_core_utils.safe_json_dumps import safe_dumps from litellm.litellm_core_utils.safe_json_loads import safe_json_loads from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler -from litellm.proxy._types import (AlertType, CallInfo, - LiteLLM_VerificationTokenView, Member, - UserAPIKeyAuth) +from litellm.proxy._types import ( + AlertType, + CallInfo, + LiteLLM_VerificationTokenView, + Member, + UserAPIKeyAuth, +) from litellm.proxy.auth.route_checks import RouteChecks -from litellm.proxy.db.create_views import (create_missing_views, - should_create_missing_views) +from litellm.proxy.db.create_views import ( + create_missing_views, + should_create_missing_views, +) from litellm.proxy.db.db_spend_update_writer import DBSpendUpdateWriter from litellm.proxy.db.exception_handler import 
PrismaDBExceptionHandler from litellm.proxy.db.log_db_metrics import log_db_metrics from litellm.proxy.db.prisma_client import PrismaWrapper -from litellm.proxy.guardrails.guardrail_hooks.unified_guardrail.unified_guardrail import \ - UnifiedLLMGuardrails +from litellm.proxy.guardrails.guardrail_hooks.unified_guardrail.unified_guardrail import ( + UnifiedLLMGuardrails, +) from litellm.proxy.hooks import PROXY_HOOKS, get_proxy_hook from litellm.proxy.hooks.cache_control_check import _PROXY_CacheControlCheck from litellm.proxy.hooks.max_budget_limiter import _PROXY_MaxBudgetLimiter -from litellm.proxy.hooks.parallel_request_limiter import \ - _PROXY_MaxParallelRequestsHandler +from litellm.proxy.hooks.parallel_request_limiter import ( + _PROXY_MaxParallelRequestsHandler, +) from litellm.proxy.litellm_pre_call_utils import LiteLLMProxyRequestSetup from litellm.proxy.policy_engine.pipeline_executor import PipelineExecutor from litellm.secret_managers.main import str_to_bool from litellm.types.integrations.slack_alerting import DEFAULT_ALERT_TYPES -from litellm.types.mcp import (MCPDuringCallResponseObject, - MCPPreCallRequestObject, - MCPPreCallResponseObject) -from litellm.types.proxy.policy_engine.pipeline_types import \ - PipelineExecutionResult +from litellm.types.mcp import ( + MCPDuringCallResponseObject, + MCPPreCallRequestObject, + MCPPreCallResponseObject, +) +from litellm.types.proxy.policy_engine.pipeline_types import PipelineExecutionResult from litellm.types.utils import LLMResponseTypes, LoggedLiteLLMParams if TYPE_CHECKING: from opentelemetry.trace import Span as _Span - from litellm.litellm_core_utils.litellm_logging import \ - Logging as LiteLLMLoggingObj + from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj Span = Union[_Span, Any] else: @@ -1050,9 +1072,10 @@ async def _process_prompt_template( """Process prompt template if applicable.""" from litellm.proxy.prompts.prompt_endpoints import ( - 
construct_versioned_prompt_id, get_latest_version_prompt_id) - from litellm.proxy.prompts.prompt_registry import \ - IN_MEMORY_PROMPT_REGISTRY + construct_versioned_prompt_id, + get_latest_version_prompt_id, + ) + from litellm.proxy.prompts.prompt_registry import IN_MEMORY_PROMPT_REGISTRY from litellm.utils import get_non_default_completion_params if prompt_version is None: @@ -1102,8 +1125,9 @@ async def _process_prompt_template( def _process_guardrail_metadata(self, data: dict) -> None: """Process guardrails from metadata and add to applied_guardrails.""" - from litellm.proxy.common_utils.callback_utils import \ - add_guardrail_to_applied_guardrails_header + from litellm.proxy.common_utils.callback_utils import ( + add_guardrail_to_applied_guardrails_header, + ) metadata_standard = data.get("metadata") or {} metadata_litellm = data.get("litellm_metadata") or {} @@ -2000,27 +2024,32 @@ async def async_post_call_streaming_hook( if isinstance(response, (ModelResponse, ModelResponseStream)): response_str = litellm.get_response_string(response_obj=response) elif isinstance(response, dict) and self.is_a2a_streaming_response(response): - from litellm.llms.a2a.common_utils import \ - extract_text_from_a2a_response + from litellm.llms.a2a.common_utils import extract_text_from_a2a_response response_str = extract_text_from_a2a_response(response) if response_str is not None: + # Cache model-level guardrails check per-request to avoid repeated + # dict lookups + llm_router.get_deployment() per callback per chunk. 
+ _cached_guardrail_data: Optional[dict] = None + _guardrail_data_computed = False + for callback in litellm.callbacks: try: _callback: Optional[CustomLogger] = None if isinstance(callback, CustomGuardrail): # Main - V2 Guardrails implementation - from litellm.types.guardrails import \ - GuardrailEventHooks + from litellm.types.guardrails import GuardrailEventHooks - ## CHECK FOR MODEL-LEVEL GUARDRAILS - modified_data = _check_and_merge_model_level_guardrails( - data=data, llm_router=llm_router - ) + ## CHECK FOR MODEL-LEVEL GUARDRAILS (cached per-request) + if not _guardrail_data_computed: + _cached_guardrail_data = _check_and_merge_model_level_guardrails( + data=data, llm_router=llm_router + ) + _guardrail_data_computed = True if ( callback.should_run_guardrail( - data=modified_data, + data=_cached_guardrail_data, event_type=GuardrailEventHooks.post_call, ) is not True @@ -4626,8 +4655,9 @@ async def update_spend_logs_job( # Guardrail/policy usage tracking (same batch, outside spend-logs update) try: - from litellm.proxy.guardrails.usage_tracking import \ - process_spend_logs_guardrail_usage + from litellm.proxy.guardrails.usage_tracking import ( + process_spend_logs_guardrail_usage, + ) await process_spend_logs_guardrail_usage( prisma_client=prisma_client, logs_to_process=logs_to_process, @@ -4653,8 +4683,10 @@ async def _monitor_spend_logs_queue( db_writer_client: Optional HTTP handler for external spend logs endpoint proxy_logging_obj: Proxy logging object """ - from litellm.constants import (SPEND_LOG_QUEUE_POLL_INTERVAL, - SPEND_LOG_QUEUE_SIZE_THRESHOLD) + from litellm.constants import ( + SPEND_LOG_QUEUE_POLL_INTERVAL, + SPEND_LOG_QUEUE_SIZE_THRESHOLD, + ) threshold = SPEND_LOG_QUEUE_SIZE_THRESHOLD base_interval = SPEND_LOG_QUEUE_POLL_INTERVAL @@ -5175,11 +5207,12 @@ async def get_available_models_for_user( List of model names available to the user """ from litellm.proxy.auth.auth_checks import get_team_object - from litellm.proxy.auth.model_checks import 
(get_complete_model_list, - get_key_models, - get_team_models) - from litellm.proxy.management_endpoints.team_endpoints import \ - validate_membership + from litellm.proxy.auth.model_checks import ( + get_complete_model_list, + get_key_models, + get_team_models, + ) + from litellm.proxy.management_endpoints.team_endpoints import validate_membership # Get proxy model list and access groups if llm_router is None: diff --git a/litellm/realtime_api/main.py b/litellm/realtime_api/main.py index 3e64f61abdb..83ab63ef146 100644 --- a/litellm/realtime_api/main.py +++ b/litellm/realtime_api/main.py @@ -1,5 +1,6 @@ """Abstraction function for OpenAI's realtime API""" +import os from typing import Any, Optional, cast import litellm @@ -132,6 +133,8 @@ async def _arealtime( # noqa: PLR0915 realtime_protocol = ( kwargs.get("realtime_protocol") + or litellm_params.get("realtime_protocol") + or os.environ.get("LITELLM_AZURE_REALTIME_PROTOCOL") or "beta" ) await azure_realtime.async_realtime( diff --git a/litellm/responses/main.py b/litellm/responses/main.py index 276e0cb9a04..a627531e994 100644 --- a/litellm/responses/main.py +++ b/litellm/responses/main.py @@ -177,6 +177,19 @@ async def aresponses_api_with_mcp( "litellm_metadata", {} ).get("user_api_key_auth") + # Extract MCP auth headers from request (for dynamic auth when fetching tools) + mcp_auth_header: Optional[str] = None + mcp_server_auth_headers: Optional[Dict[str, Dict[str, str]]] = None + secret_fields = kwargs.get("secret_fields") + if secret_fields and isinstance(secret_fields, dict): + from litellm.responses.utils import ResponsesAPIRequestUtils + + mcp_auth_header, mcp_server_auth_headers, _, _ = ( + ResponsesAPIRequestUtils.extract_mcp_headers_from_request( + secret_fields=secret_fields, tools=tools + ) + ) + # Get original MCP tools (for events) and OpenAI tools (for LLM) by reusing existing methods ( original_mcp_tools, @@ -185,6 +198,8 @@ async def aresponses_api_with_mcp( user_api_key_auth=user_api_key_auth, 
mcp_tools_with_litellm_proxy=mcp_tools_with_litellm_proxy, litellm_trace_id=kwargs.get("litellm_trace_id"), + mcp_auth_header=mcp_auth_header, + mcp_server_auth_headers=mcp_server_auth_headers, ) openai_tools = LiteLLM_Proxy_MCP_Handler._transform_mcp_tools_to_openai( original_mcp_tools @@ -370,6 +385,8 @@ async def aresponses_api_with_mcp( ) = await LiteLLM_Proxy_MCP_Handler._process_mcp_tools_without_openai_transform( user_api_key_auth=user_api_key_auth, mcp_tools_with_litellm_proxy=mcp_tools_with_litellm_proxy, + mcp_auth_header=mcp_auth_header, + mcp_server_auth_headers=mcp_server_auth_headers, ) final_response = ( LiteLLM_Proxy_MCP_Handler._add_mcp_output_elements_to_response( diff --git a/litellm/responses/mcp/chat_completions_handler.py b/litellm/responses/mcp/chat_completions_handler.py index 377ce396457..bacc627cc84 100644 --- a/litellm/responses/mcp/chat_completions_handler.py +++ b/litellm/responses/mcp/chat_completions_handler.py @@ -120,7 +120,18 @@ async def acompletion_with_mcp( # noqa: PLR0915 (kwargs.get("metadata", {}) or {}).get("user_api_key_auth") ) - # Process MCP tools + # Extract MCP auth headers before fetching tools (needed for dynamic auth) + ( + mcp_auth_header, + mcp_server_auth_headers, + oauth2_headers, + raw_headers, + ) = ResponsesAPIRequestUtils.extract_mcp_headers_from_request( + secret_fields=kwargs.get("secret_fields"), + tools=tools, + ) + + # Process MCP tools (pass auth headers for dynamic auth) ( deduplicated_mcp_tools, tool_server_map, @@ -128,6 +139,8 @@ async def acompletion_with_mcp( # noqa: PLR0915 user_api_key_auth=user_api_key_auth, mcp_tools_with_litellm_proxy=mcp_tools_with_litellm_proxy, litellm_trace_id=kwargs.get("litellm_trace_id"), + mcp_auth_header=mcp_auth_header, + mcp_server_auth_headers=mcp_server_auth_headers, ) openai_tools = LiteLLM_Proxy_MCP_Handler._transform_mcp_tools_to_openai( @@ -143,17 +156,6 @@ async def acompletion_with_mcp( # noqa: PLR0915 
mcp_tools_with_litellm_proxy=mcp_tools_with_litellm_proxy ) - # Extract MCP auth headers - ( - mcp_auth_header, - mcp_server_auth_headers, - oauth2_headers, - raw_headers, - ) = ResponsesAPIRequestUtils.extract_mcp_headers_from_request( - secret_fields=kwargs.get("secret_fields"), - tools=tools, - ) - # Prepare call parameters # Remove keys that shouldn't be passed to acompletion clean_kwargs = {k: v for k, v in kwargs.items() if k not in ["acompletion"]} diff --git a/litellm/responses/mcp/litellm_proxy_mcp_handler.py b/litellm/responses/mcp/litellm_proxy_mcp_handler.py index 805a1958552..5776ef95acb 100644 --- a/litellm/responses/mcp/litellm_proxy_mcp_handler.py +++ b/litellm/responses/mcp/litellm_proxy_mcp_handler.py @@ -99,6 +99,8 @@ async def _get_mcp_tools_from_manager( user_api_key_auth: Any, mcp_tools_with_litellm_proxy: Optional[Iterable[ToolParam]], litellm_trace_id: Optional[str] = None, + mcp_auth_header: Optional[str] = None, + mcp_server_auth_headers: Optional[Dict[str, Dict[str, str]]] = None, ) -> tuple[List[MCPTool], List[str]]: """ Get available tools from the MCP server manager. @@ -106,6 +108,8 @@ async def _get_mcp_tools_from_manager( Args: user_api_key_auth: User authentication info for access control mcp_tools_with_litellm_proxy: ToolParam objects with server_url starting with "litellm_proxy" + mcp_auth_header: Optional deprecated auth header for MCP servers + mcp_server_auth_headers: Optional server-specific auth headers (e.g. 
from x-mcp-{alias}-*) Returns: List of MCP tools @@ -133,13 +137,14 @@ async def _get_mcp_tools_from_manager( tools = await _get_tools_from_mcp_servers( user_api_key_auth=user_api_key_auth, - mcp_auth_header=None, + mcp_auth_header=mcp_auth_header, mcp_servers=mcp_servers, - mcp_server_auth_headers=None, + mcp_server_auth_headers=mcp_server_auth_headers, log_list_tools_to_spendlogs=True, list_tools_log_source="responses", litellm_trace_id=litellm_trace_id, ) + allowed_mcp_server_ids = ( await global_mcp_server_manager.get_allowed_mcp_servers(user_api_key_auth) ) @@ -278,6 +283,8 @@ async def _process_mcp_tools_without_openai_transform( user_api_key_auth: Any, mcp_tools_with_litellm_proxy: List[ToolParam], litellm_trace_id: Optional[str] = None, + mcp_auth_header: Optional[str] = None, + mcp_server_auth_headers: Optional[Dict[str, Dict[str, str]]] = None, ) -> tuple[List[Any], dict[str, str]]: """ Process MCP tools through filtering and deduplication pipeline without OpenAI transformation. @@ -286,6 +293,8 @@ async def _process_mcp_tools_without_openai_transform( Args: user_api_key_auth: User authentication info for access control mcp_tools_with_litellm_proxy: ToolParam objects with server_url starting with "litellm_proxy" + mcp_auth_header: Optional deprecated auth header for MCP servers + mcp_server_auth_headers: Optional server-specific auth headers (e.g. 
from x-mcp-{alias}-*) Returns: List of filtered and deduplicated MCP tools in their original format @@ -301,6 +310,8 @@ async def _process_mcp_tools_without_openai_transform( user_api_key_auth=user_api_key_auth, mcp_tools_with_litellm_proxy=mcp_tools_with_litellm_proxy, litellm_trace_id=litellm_trace_id, + mcp_auth_header=mcp_auth_header, + mcp_server_auth_headers=mcp_server_auth_headers, ) # Step 2: Filter tools based on allowed_tools parameter diff --git a/litellm/types/integrations/prometheus.py b/litellm/types/integrations/prometheus.py index 2c75276d9ca..0856d8a6f9b 100644 --- a/litellm/types/integrations/prometheus.py +++ b/litellm/types/integrations/prometheus.py @@ -237,6 +237,7 @@ class UserAPIKeyLabelNames(Enum): "litellm_remaining_api_key_tokens_for_model", "litellm_llm_api_failed_requests_metric", "litellm_callback_logging_failures_metric", + "litellm_in_flight_requests", ] diff --git a/litellm/types/mcp_server/mcp_server_manager.py b/litellm/types/mcp_server/mcp_server_manager.py index 7f99fd526c8..69b34a25a21 100644 --- a/litellm/types/mcp_server/mcp_server_manager.py +++ b/litellm/types/mcp_server/mcp_server_manager.py @@ -52,7 +52,7 @@ class MCPServer(BaseModel): env: Optional[Dict[str, str]] = None access_groups: Optional[List[str]] = None allow_all_keys: bool = False - available_on_public_internet: bool = False + available_on_public_internet: bool = True updated_at: Optional[datetime] = None model_config = ConfigDict(arbitrary_types_allowed=True) diff --git a/schema.prisma b/schema.prisma index bc32a8cce32..a8cb297a3ed 100644 --- a/schema.prisma +++ b/schema.prisma @@ -300,7 +300,7 @@ model LiteLLM_MCPServerTable { token_url String? registration_url String? 
allow_all_keys Boolean @default(false) - available_on_public_internet Boolean @default(false) + available_on_public_internet Boolean @default(true) } // Generate Tokens for Proxy diff --git a/tests/litellm/llms/openai_like/test_assemblyai_provider.py b/tests/litellm/llms/openai_like/test_assemblyai_provider.py new file mode 100644 index 00000000000..7eee810b271 --- /dev/null +++ b/tests/litellm/llms/openai_like/test_assemblyai_provider.py @@ -0,0 +1,77 @@ +""" +Unit tests for the AssemblyAI LLM Gateway OpenAI-like provider. +""" + +import os +import sys + +sys.path.insert( + 0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../..")) +) + +from litellm.llms.openai_like.dynamic_config import create_config_class +from litellm.llms.openai_like.json_loader import JSONProviderRegistry + +ASSEMBLYAI_BASE_URL = "https://llm-gateway.assemblyai.com/v1" + + +def _get_config(): + provider = JSONProviderRegistry.get("assemblyai") + assert provider is not None + config_class = create_config_class(provider) + return config_class() + + +def test_assemblyai_provider_registered(): + provider = JSONProviderRegistry.get("assemblyai") + assert provider is not None + assert provider.base_url == ASSEMBLYAI_BASE_URL + assert provider.api_key_env == "ASSEMBLYAI_API_KEY" + + +def test_assemblyai_resolves_env_api_key(monkeypatch): + config = _get_config() + monkeypatch.setenv("ASSEMBLYAI_API_KEY", "test-key") + api_base, api_key = config._get_openai_compatible_provider_info(None, None) + assert api_base == ASSEMBLYAI_BASE_URL + assert api_key == "test-key" + + +def test_assemblyai_complete_url_appends_endpoint(): + config = _get_config() + url = config.get_complete_url( + api_base=ASSEMBLYAI_BASE_URL, + api_key="test-key", + model="assemblyai/claude-sonnet-4-5-20250929", + optional_params={}, + litellm_params={}, + stream=False, + ) + assert url == f"{ASSEMBLYAI_BASE_URL}/chat/completions" + + +def test_assemblyai_provider_resolution(): + from 
litellm.litellm_core_utils.get_llm_provider_logic import get_llm_provider + + model, provider, api_key, api_base = get_llm_provider( + model="assemblyai/claude-sonnet-4-5-20250929", + custom_llm_provider=None, + api_base=None, + api_key=None, + ) + + assert model == "claude-sonnet-4-5-20250929" + assert provider == "assemblyai" + assert api_base == ASSEMBLYAI_BASE_URL + + +def test_assemblyai_provider_config_manager(): + from litellm import LlmProviders + from litellm.utils import ProviderConfigManager + + config = ProviderConfigManager.get_provider_chat_config( + model="claude-sonnet-4-5-20250929", provider=LlmProviders.ASSEMBLYAI + ) + + assert config is not None + assert config.custom_llm_provider == "assemblyai" diff --git a/tests/llm_translation/test_skills_data/slack-gif-creator.zip b/tests/llm_translation/test_skills_data/slack-gif-creator.zip index 15c60e3667d..9827db9ac2d 100644 Binary files a/tests/llm_translation/test_skills_data/slack-gif-creator.zip and b/tests/llm_translation/test_skills_data/slack-gif-creator.zip differ diff --git a/tests/llm_translation/test_skills_data/slack-gif-creator/LICENSE.txt b/tests/llm_translation/test_skills_data/slack-gif-creator/LICENSE.txt deleted file mode 100644 index 7a4a3ea2424..00000000000 --- a/tests/llm_translation/test_skills_data/slack-gif-creator/LICENSE.txt +++ /dev/null @@ -1,202 +0,0 @@ - - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - - TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - - 1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. 
For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. 
For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - - 2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - - 3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. 
If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - - 4. Redistribution. You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. 
You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - - 5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - - 6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - - 7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - - 8. 
Limitation of Liability. In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - - 9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - - END OF TERMS AND CONDITIONS - - APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "[]" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. 
- - Copyright [yyyy] [name of copyright owner] - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. \ No newline at end of file diff --git a/tests/llm_translation/test_skills_data/slack-gif-creator/SKILL.md b/tests/llm_translation/test_skills_data/slack-gif-creator/SKILL.md index 16660d8ceb7..3cae971b731 100644 --- a/tests/llm_translation/test_skills_data/slack-gif-creator/SKILL.md +++ b/tests/llm_translation/test_skills_data/slack-gif-creator/SKILL.md @@ -1,7 +1,6 @@ --- name: slack-gif-creator description: Knowledge and utilities for creating animated GIFs optimized for Slack. Provides constraints, validation tools, and animation concepts. Use when users request animated GIFs for Slack like "make me a GIF of X doing Y for Slack." 
-license: Complete terms in LICENSE.txt --- # Slack GIF Creator diff --git a/tests/mcp_tests/test_aresponses_api_with_mcp.py b/tests/mcp_tests/test_aresponses_api_with_mcp.py index bae0b15dfec..c22c3537af8 100644 --- a/tests/mcp_tests/test_aresponses_api_with_mcp.py +++ b/tests/mcp_tests/test_aresponses_api_with_mcp.py @@ -3,6 +3,7 @@ import sys import pytest from typing import List, Any, cast +from unittest.mock import AsyncMock, patch sys.path.insert(0, os.path.abspath("../../..")) @@ -254,6 +255,76 @@ async def test_aresponses_api_with_mcp_mock_integration(): print(f"Other tools parsed: {len(other_parsed)}") +@pytest.mark.asyncio +async def test_aresponses_api_with_mcp_passes_mcp_server_auth_headers_to_process_tools(): + """ + Test that MCP auth headers from secret_fields (e.g. x-mcp-linear_config-authorization) + are passed to _process_mcp_tools_without_openai_transform when using the responses API. + """ + from litellm.responses.main import aresponses_api_with_mcp + + captured_process_kwargs = {} + + async def mock_process(**kwargs): + captured_process_kwargs.update(kwargs) + return ([], {}) + + mock_response = ResponsesAPIResponse( + **{ + "id": "resp_test", + "object": "response", + "created_at": 1234567890, + "status": "completed", + "error": None, + "incomplete_details": None, + "instructions": None, + "max_output_tokens": None, + "model": "gpt-4o", + "output": [{"type": "message", "id": "msg_1", "status": "completed", "role": "assistant", "content": []}], + "parallel_tool_calls": True, + "previous_response_id": None, + "reasoning": {"effort": None, "summary": None}, + "store": True, + "temperature": 1.0, + "text": {"format": {"type": "text"}}, + "tool_choice": "auto", + "tools": [], + "top_p": 1.0, + "truncation": "disabled", + "usage": {"input_tokens": 1, "output_tokens": 1, "total_tokens": 2}, + "user": None, + "metadata": {}, + } + ) + + mcp_tools = [{"type": "mcp", "server_url": "litellm_proxy"}] + secret_fields = { + "raw_headers": 
{"x-mcp-linear_config-authorization": "Bearer linear-token"}, + } + + with patch.object( + LiteLLM_Proxy_MCP_Handler, + "_process_mcp_tools_without_openai_transform", + mock_process, + ), patch( + "litellm.responses.main.aresponses", + new_callable=AsyncMock, + return_value=mock_response, + ): + await aresponses_api_with_mcp( + input=[{"role": "user", "type": "message", "content": "hi"}], + model="gpt-4o", + tools=mcp_tools, + secret_fields=secret_fields, + ) + + assert "mcp_server_auth_headers" in captured_process_kwargs + mcp_server_auth_headers = captured_process_kwargs["mcp_server_auth_headers"] + assert mcp_server_auth_headers is not None + assert "linear_config" in mcp_server_auth_headers + assert mcp_server_auth_headers["linear_config"]["Authorization"] == "Bearer linear-token" + + @pytest.mark.asyncio async def test_mcp_allowed_tools_filtering(): """ diff --git a/tests/mcp_tests/test_mcp_server.py b/tests/mcp_tests/test_mcp_server.py index dc1e2068365..a81702d0db3 100644 --- a/tests/mcp_tests/test_mcp_server.py +++ b/tests/mcp_tests/test_mcp_server.py @@ -659,9 +659,9 @@ async def test_list_tools_rest_api_server_not_found(): mock_manager.get_allowed_mcp_servers = AsyncMock( return_value=["non_existent_server_id"] ) - # Mock filter_server_ids_by_ip to return input unchanged (no IP filtering in test) - mock_manager.filter_server_ids_by_ip = MagicMock( - side_effect=lambda server_ids, client_ip: server_ids + # Mock filter_server_ids_by_ip_with_info to return input unchanged (no IP filtering in test) + mock_manager.filter_server_ids_by_ip_with_info = MagicMock( + side_effect=lambda server_ids, client_ip: (server_ids, 0) ) # Return None when trying to get the server (server doesn't exist) mock_manager.get_mcp_server_by_id = MagicMock(return_value=None) @@ -732,9 +732,9 @@ async def test_list_tools_rest_api_success(): return_value=["test-server-123"] ) mock_manager.get_mcp_server_by_id = MagicMock(return_value=mock_server) - # Mock filter_server_ids_by_ip to 
return input unchanged (no IP filtering in test) - mock_manager.filter_server_ids_by_ip = MagicMock( - side_effect=lambda server_ids, client_ip: server_ids + # Mock filter_server_ids_by_ip_with_info to return input unchanged (no IP filtering in test) + mock_manager.filter_server_ids_by_ip_with_info = MagicMock( + side_effect=lambda server_ids, client_ip: (server_ids, 0) ) # Mock the _get_tools_for_single_server function @@ -814,9 +814,9 @@ def mock_get_server_by_id(server_id): ) mock_manager.get_mcp_server_by_id = lambda server_id: mock_server_1 if server_id == "server1_id" else mock_server_2 mock_manager._get_tools_from_server = AsyncMock(return_value=[mock_tool_1]) - # Mock filter_server_ids_by_ip to return input unchanged (no IP filtering in test) - mock_manager.filter_server_ids_by_ip = MagicMock( - side_effect=lambda server_ids, client_ip: server_ids + # Mock filter_server_ids_by_ip_with_info to return input unchanged (no IP filtering in test) + mock_manager.filter_server_ids_by_ip_with_info = MagicMock( + side_effect=lambda server_ids, client_ip: (server_ids, 0) ) with patch( @@ -853,9 +853,9 @@ async def mock_get_tools_side_effect( mock_manager_2._get_tools_from_server = AsyncMock( side_effect=mock_get_tools_side_effect ) - # Mock filter_server_ids_by_ip to return input unchanged (no IP filtering in test) - mock_manager_2.filter_server_ids_by_ip = MagicMock( - side_effect=lambda server_ids, client_ip: server_ids + # Mock filter_server_ids_by_ip_with_info to return input unchanged (no IP filtering in test) + mock_manager_2.filter_server_ids_by_ip_with_info = MagicMock( + side_effect=lambda server_ids, client_ip: (server_ids, 0) ) with patch( @@ -881,9 +881,9 @@ async def mock_get_tools_side_effect( ) mock_manager.get_mcp_server_by_id = lambda server_id: mock_server_1 if server_id == "server1_id" else (mock_server_2 if server_id == "server2_id" else mock_server_3) mock_manager._get_tools_from_server = AsyncMock(return_value=[mock_tool_1]) - # Mock 
filter_server_ids_by_ip to return input unchanged (no IP filtering in test) - mock_manager.filter_server_ids_by_ip = MagicMock( - side_effect=lambda server_ids, client_ip: server_ids + # Mock filter_server_ids_by_ip_with_info to return input unchanged (no IP filtering in test) + mock_manager.filter_server_ids_by_ip_with_info = MagicMock( + side_effect=lambda server_ids, client_ip: (server_ids, 0) ) with patch( @@ -1817,9 +1817,9 @@ async def test_list_tool_rest_api_with_server_specific_auth(): mock_server.mcp_info = {"server_name": "zapier"} mock_manager.get_mcp_server_by_id.return_value = mock_server - # Mock filter_server_ids_by_ip to return input unchanged (no IP filtering in test) - mock_manager.filter_server_ids_by_ip = MagicMock( - side_effect=lambda server_ids, client_ip: server_ids + # Mock filter_server_ids_by_ip_with_info to return input unchanged (no IP filtering in test) + mock_manager.filter_server_ids_by_ip_with_info = MagicMock( + side_effect=lambda server_ids, client_ip: (server_ids, 0) ) mock_user_api_key_dict = UserAPIKeyAuth( @@ -1911,9 +1911,9 @@ async def test_list_tool_rest_api_with_default_auth(): mock_server.mcp_info = {"server_name": "unknown_server"} mock_manager.get_mcp_server_by_id.return_value = mock_server - # Mock filter_server_ids_by_ip to return input unchanged (no IP filtering in test) - mock_manager.filter_server_ids_by_ip = MagicMock( - side_effect=lambda server_ids, client_ip: server_ids + # Mock filter_server_ids_by_ip_with_info to return input unchanged (no IP filtering in test) + mock_manager.filter_server_ids_by_ip_with_info = MagicMock( + side_effect=lambda server_ids, client_ip: (server_ids, 0) ) mock_user_api_key_dict = UserAPIKeyAuth( @@ -2021,9 +2021,9 @@ async def test_list_tool_rest_api_all_servers_with_auth(): server_id ) ) - # Mock filter_server_ids_by_ip to return input unchanged (no IP filtering in test) - mock_manager.filter_server_ids_by_ip = MagicMock( - side_effect=lambda server_ids, client_ip: server_ids + # 
Mock filter_server_ids_by_ip_with_info to return input unchanged (no IP filtering in test) + mock_manager.filter_server_ids_by_ip_with_info = MagicMock( + side_effect=lambda server_ids, client_ip: (server_ids, 0) ) mock_user_api_key_dict = UserAPIKeyAuth( @@ -2154,9 +2154,9 @@ def mock_client_constructor(*args, **kwargs): return_value=["test-server-123"] ) mock_manager.get_mcp_server_by_id = MagicMock(return_value=mock_server) - # Mock filter_server_ids_by_ip to return input unchanged (no IP filtering in test) - mock_manager.filter_server_ids_by_ip = MagicMock( - side_effect=lambda server_ids, client_ip: server_ids + # Mock filter_server_ids_by_ip_with_info to return input unchanged (no IP filtering in test) + mock_manager.filter_server_ids_by_ip_with_info = MagicMock( + side_effect=lambda server_ids, client_ip: (server_ids, 0) ) # Mock the _get_tools_from_server method to return all tools @@ -2268,9 +2268,9 @@ def mock_client_constructor(*args, **kwargs): return_value=["test-server-456"] ) mock_manager.get_mcp_server_by_id = MagicMock(return_value=mock_server) - # Mock filter_server_ids_by_ip to return input unchanged (no IP filtering in test) - mock_manager.filter_server_ids_by_ip = MagicMock( - side_effect=lambda server_ids, client_ip: server_ids + # Mock filter_server_ids_by_ip_with_info to return input unchanged (no IP filtering in test) + mock_manager.filter_server_ids_by_ip_with_info = MagicMock( + side_effect=lambda server_ids, client_ip: (server_ids, 0) ) # Mock the _get_tools_from_server method to return all tools mock_manager._get_tools_from_server = AsyncMock(return_value=mock_tools) @@ -2368,9 +2368,9 @@ def mock_client_constructor(*args, **kwargs): return_value=["test-server-000"] ) mock_manager.get_mcp_server_by_id = MagicMock(return_value=mock_server) - # Mock filter_server_ids_by_ip to return input unchanged (no IP filtering in test) - mock_manager.filter_server_ids_by_ip = MagicMock( - side_effect=lambda server_ids, client_ip: server_ids + # Mock 
filter_server_ids_by_ip_with_info to return input unchanged (no IP filtering in test) + mock_manager.filter_server_ids_by_ip_with_info = MagicMock( + side_effect=lambda server_ids, client_ip: (server_ids, 0) ) # Mock the _get_tools_from_server method to return all tools diff --git a/tests/proxy_unit_tests/test_proxy_utils.py b/tests/proxy_unit_tests/test_proxy_utils.py index cc1cc278392..fd97a38b41e 100644 --- a/tests/proxy_unit_tests/test_proxy_utils.py +++ b/tests/proxy_unit_tests/test_proxy_utils.py @@ -2,7 +2,8 @@ import json import os import sys -from typing import Any, Dict, List, Optional +from datetime import datetime +from typing import Any, Dict, List, Optional, Union from unittest.mock import Mock import pytest @@ -1486,12 +1487,46 @@ def __init__( mock_key_data, ): self.db = MockDb(mock_team_data, mock_key_data) + + async def get_data( + self, + token: Optional[Union[str, list]] = None, + user_id: Optional[str] = None, + user_id_list: Optional[list] = None, + team_id: Optional[str] = None, + team_id_list: Optional[list] = None, + key_val: Optional[dict] = None, + table_name: Optional[str] = None, + query_type: str = "find_unique", + expires: Optional[datetime] = None, + reset_at: Optional[datetime] = None, + offset: Optional[int] = None, + limit: Optional[int] = None, + ): + """Mock get_data method to return user info for admin""" + from litellm.proxy._types import LiteLLM_UserTable + + # Return a proper LiteLLM_UserTable object when querying by user_id + if user_id: + return LiteLLM_UserTable( + user_id=user_id, + user_role="proxy_admin", + spend=0.0, + max_budget=None, + ) + return None @pytest.mark.asyncio async def test_get_user_info_for_proxy_admin(mock_team_data, mock_key_data): # Patch the prisma_client import - from litellm.proxy._types import UserInfoResponse + from litellm.proxy._types import UserAPIKeyAuth, UserInfoResponse + + # Create a mock user_api_key_dict for admin user + mock_user_api_key_dict = UserAPIKeyAuth( + 
user_id="admin_user_123", + user_role="proxy_admin", + ) with patch( "litellm.proxy.proxy_server.prisma_client", @@ -1502,11 +1537,18 @@ async def test_get_user_info_for_proxy_admin(mock_team_data, mock_key_data): ) # Execute the function - result = await _get_user_info_for_proxy_admin() + result = await _get_user_info_for_proxy_admin( + user_api_key_dict=mock_user_api_key_dict + ) # Verify the result structure assert isinstance(result, UserInfoResponse) assert len(result.keys) == 2 + # Verify admin's user_id is populated + assert result.user_id == "admin_user_123" + # Verify admin's user_info is populated + assert result.user_info is not None + assert result.user_info["user_id"] == "admin_user_123" def test_custom_openid_response(): diff --git a/tests/test_litellm/caching/test_llm_caching_handler.py b/tests/test_litellm/caching/test_llm_caching_handler.py new file mode 100644 index 00000000000..0ac4ac5de79 --- /dev/null +++ b/tests/test_litellm/caching/test_llm_caching_handler.py @@ -0,0 +1,158 @@ +import asyncio +import os +import sys +import warnings + +import pytest + +sys.path.insert( + 0, os.path.abspath("../../..") +) # Adds the parent directory to the system path + +from litellm.caching.llm_caching_handler import LLMClientCache + + +class MockAsyncClient: + """Mock async HTTP client with an async close method.""" + + def __init__(self): + self.closed = False + + async def close(self): + self.closed = True + + +class MockSyncClient: + """Mock sync HTTP client with a sync close method.""" + + def __init__(self): + self.closed = False + + def close(self): + self.closed = True + + +@pytest.mark.asyncio +async def test_remove_key_no_unawaited_coroutine_warning(): + """ + Test that evicting an async client from LLMClientCache does not produce + 'coroutine was never awaited' warnings. 
+ + Regression test for https://github.com/BerriAI/litellm/issues/22128 + """ + cache = LLMClientCache(max_size_in_memory=2) + + mock_client = MockAsyncClient() + cache.cache_dict["test-key"] = mock_client + cache.ttl_dict["test-key"] = 0 # expired + + with warnings.catch_warnings(record=True) as caught_warnings: + warnings.simplefilter("always") + cache._remove_key("test-key") + # Let the event loop process the close task + await asyncio.sleep(0.1) + + coroutine_warnings = [ + w for w in caught_warnings if "coroutine" in str(w.message).lower() + ] + assert ( + len(coroutine_warnings) == 0 + ), f"Got unawaited coroutine warnings: {coroutine_warnings}" + + +@pytest.mark.asyncio +async def test_remove_key_closes_async_client(): + """ + Test that evicting an async client from the cache properly closes it. + """ + cache = LLMClientCache(max_size_in_memory=2) + + mock_client = MockAsyncClient() + cache.cache_dict["test-key"] = mock_client + cache.ttl_dict["test-key"] = 0 + + cache._remove_key("test-key") + # Let the event loop process the close task + await asyncio.sleep(0.1) + + assert mock_client.closed is True + assert "test-key" not in cache.cache_dict + assert "test-key" not in cache.ttl_dict + + +def test_remove_key_closes_sync_client(): + """ + Test that evicting a sync client from the cache properly closes it. + """ + cache = LLMClientCache(max_size_in_memory=2) + + mock_client = MockSyncClient() + cache.cache_dict["test-key"] = mock_client + cache.ttl_dict["test-key"] = 0 + + cache._remove_key("test-key") + + assert mock_client.closed is True + assert "test-key" not in cache.cache_dict + + +@pytest.mark.asyncio +async def test_eviction_closes_async_clients(): + """ + Test that cache eviction (when cache is full) properly closes async clients + without producing warnings. 
+ """ + cache = LLMClientCache(max_size_in_memory=2, default_ttl=1) + + clients = [] + for i in range(2): + client = MockAsyncClient() + clients.append(client) + cache.set_cache(f"key-{i}", client) + + with warnings.catch_warnings(record=True) as caught_warnings: + warnings.simplefilter("always") + # This should trigger eviction of one of the existing entries + cache.set_cache("key-new", "new-value") + await asyncio.sleep(0.1) + + coroutine_warnings = [ + w for w in caught_warnings if "coroutine" in str(w.message).lower() + ] + assert ( + len(coroutine_warnings) == 0 + ), f"Got unawaited coroutine warnings: {coroutine_warnings}" + + +def test_remove_key_no_event_loop(): + """ + Test that _remove_key doesn't raise when there's no running event loop + (falls through to the RuntimeError except branch). + """ + cache = LLMClientCache(max_size_in_memory=2) + + mock_client = MockAsyncClient() + cache.cache_dict["test-key"] = mock_client + cache.ttl_dict["test-key"] = 0 + + # Should not raise even though there's no running event loop + cache._remove_key("test-key") + assert "test-key" not in cache.cache_dict + + +@pytest.mark.asyncio +async def test_background_tasks_cleaned_up_after_completion(): + """ + Test that completed close tasks are removed from the _background_tasks set. 
+ """ + cache = LLMClientCache(max_size_in_memory=2) + + mock_client = MockAsyncClient() + cache.cache_dict["test-key"] = mock_client + cache.ttl_dict["test-key"] = 0 + + cache._remove_key("test-key") + # Let the task complete + await asyncio.sleep(0.1) + + assert len(cache._background_tasks) == 0 diff --git a/tests/test_litellm/litellm_core_utils/test_realtime_streaming.py b/tests/test_litellm/litellm_core_utils/test_realtime_streaming.py index 11d6bb028d8..7d38a5cc80a 100644 --- a/tests/test_litellm/litellm_core_utils/test_realtime_streaming.py +++ b/tests/test_litellm/litellm_core_utils/test_realtime_streaming.py @@ -379,7 +379,8 @@ async def test_realtime_guardrail_blocks_prompt_injection(): """ Test that when a transcription event containing prompt injection arrives from the backend, a registered guardrail blocks it — sending a warning to the client - and NOT sending response.create to the backend. + and voicing the guardrail violation message via response.cancel + + conversation.item.create + response.create. """ import litellm from litellm.integrations.custom_guardrail import CustomGuardrail @@ -430,19 +431,36 @@ async def apply_guardrail(self, inputs, request_data, input_type, logging_obj=No streaming = RealTimeStreaming(client_ws, backend_ws, logging_obj) await streaming.backend_to_client_send_messages() - # ASSERT 1: no response.create was sent to backend (injection blocked). + # ASSERT 1: the guardrail blocked the normal auto-response and instead + # injected a conversation.item.create + response.create to voice the + # violation message. There should be exactly ONE response.create (the + # guardrail-triggered one), preceded by a response.cancel and a + # conversation.item.create carrying the violation text. 
sent_to_backend = [ json.loads(c.args[0]) for c in backend_ws.send.call_args_list if c.args ] - response_creates = [ + response_cancels = [ + e for e in sent_to_backend if e.get("type") == "response.cancel" + ] + assert len(response_cancels) == 1, ( + f"Guardrail should send response.cancel, got: {response_cancels}" + ) + guardrail_items = [ e for e in sent_to_backend - if e.get("type") == "response.create" + if e.get("type") == "conversation.item.create" ] - assert len(response_creates) == 0, ( - f"Guardrail should prevent response.create for injected content, " - f"but got: {response_creates}" + assert len(guardrail_items) == 1, ( + f"Guardrail should inject a conversation.item.create with violation message, " + f"got: {guardrail_items}" + ) + response_creates = [ + e for e in sent_to_backend if e.get("type") == "response.create" + ] + assert len(response_creates) == 1, ( + f"Guardrail should send exactly one response.create to voice the violation, " + f"got: {response_creates}" ) # ASSERT 2: error event was sent directly to the client WebSocket @@ -595,14 +613,26 @@ async def apply_guardrail(self, inputs, request_data, input_type, logging_obj=No assert len(error_events) == 1, f"Expected one error event, got: {sent_texts}" assert error_events[0]["error"]["type"] == "guardrail_violation" - # ASSERT: blocked item was NOT forwarded to the backend + # ASSERT: the original blocked item was NOT forwarded to the backend. + # The guardrail handler injects its own conversation.item.create with + # the violation message — only that one should be present, not the + # original user message. 
sent_to_backend = [c.args[0] for c in backend_ws.send.call_args_list if c.args] forwarded_items = [ json.loads(m) for m in sent_to_backend if isinstance(m, str) and json.loads(m).get("type") == "conversation.item.create" ] - assert len(forwarded_items) == 0, ( - f"Blocked item should not be forwarded to backend, got: {forwarded_items}" + # Filter out guardrail-injected items (contain "Say exactly the following message") + original_items = [ + item for item in forwarded_items + if not any( + "Say exactly the following message" in c.get("text", "") + for c in item.get("item", {}).get("content", []) + if isinstance(c, dict) + ) + ] + assert len(original_items) == 0, ( + f"Blocked item should not be forwarded to backend, got: {original_items}" ) litellm.callbacks = [] # cleanup diff --git a/tests/test_litellm/llms/anthropic/experimental_pass_through/messages/test_anthropic_experimental_pass_through_messages_handler.py b/tests/test_litellm/llms/anthropic/experimental_pass_through/messages/test_anthropic_experimental_pass_through_messages_handler.py index c671d9b37b8..636e84fe796 100644 --- a/tests/test_litellm/llms/anthropic/experimental_pass_through/messages/test_anthropic_experimental_pass_through_messages_handler.py +++ b/tests/test_litellm/llms/anthropic/experimental_pass_through/messages/test_anthropic_experimental_pass_through_messages_handler.py @@ -39,14 +39,14 @@ def test_anthropic_experimental_pass_through_messages_handler(): def test_anthropic_experimental_pass_through_messages_handler_dynamic_api_key_and_api_base_and_custom_values(): """ - Test that api key, api base, and extra kwargs are forwarded to litellm.responses for Azure models. - Azure models are routed directly to the Responses API. + Test that api key, api base, and extra kwargs are forwarded to litellm.completion for Azure models. + Azure models are routed through chat/completions (not the Responses API). 
""" from litellm.llms.anthropic.experimental_pass_through.messages.handler import ( anthropic_messages_handler, ) - with patch("litellm.responses", return_value="test-response") as mock_responses: + with patch("litellm.completion", return_value=MagicMock()) as mock_completion: try: anthropic_messages_handler( max_tokens=100, @@ -58,10 +58,10 @@ def test_anthropic_experimental_pass_through_messages_handler_dynamic_api_key_an ) except Exception as e: print(f"Error: {e}") - mock_responses.assert_called_once() - assert mock_responses.call_args.kwargs["api_key"] == "test-api-key" - assert mock_responses.call_args.kwargs["api_base"] == "test-api-base" - assert mock_responses.call_args.kwargs["custom_key"] == "custom_value" + mock_completion.assert_called_once() + assert mock_completion.call_args.kwargs["api_key"] == "test-api-key" + assert mock_completion.call_args.kwargs["api_base"] == "test-api-base" + assert mock_completion.call_args.kwargs["custom_key"] == "custom_value" def test_anthropic_experimental_pass_through_messages_handler_custom_llm_provider(): diff --git a/tests/test_litellm/llms/azure/realtime/test_azure_realtime_handler.py b/tests/test_litellm/llms/azure/realtime/test_azure_realtime_handler.py index 2a110c8f9a7..e9c5c9cfc1b 100644 --- a/tests/test_litellm/llms/azure/realtime/test_azure_realtime_handler.py +++ b/tests/test_litellm/llms/azure/realtime/test_azure_realtime_handler.py @@ -158,6 +158,27 @@ async def test_construct_url_v1_protocol(): assert url.count("/realtime") == 1 +@pytest.mark.asyncio +@pytest.mark.parametrize("protocol", ["ga", "Ga", "gA", "V1", "v1", "GA"]) +async def test_construct_url_case_insensitive_protocol(protocol): + """ + Test that realtime_protocol matching is case-insensitive. 
+ """ + from litellm.llms.azure.realtime.handler import AzureOpenAIRealtime + + handler = AzureOpenAIRealtime() + url = handler._construct_url( + api_base="https://my-endpoint.openai.azure.com", + model="gpt-realtime-deployment", + api_version=None, + realtime_protocol=protocol, + ) + + assert "/openai/v1/realtime?" in url + assert "model=gpt-realtime-deployment" in url + assert "api-version" not in url + + @pytest.mark.asyncio async def test_async_realtime_uses_ga_protocol_end_to_end(): """ @@ -212,6 +233,113 @@ async def __aexit__(self, exc_type, exc, tb): assert "deployment" not in called_url +@pytest.mark.asyncio +async def test_async_realtime_ga_without_api_version(): + """ + Test that GA/v1 protocol works without api_version (which is not needed for the GA path). + Fixes #22127: api_version check was unconditional, blocking GA path. + """ + from litellm.llms.azure.realtime.handler import AzureOpenAIRealtime + + handler = AzureOpenAIRealtime() + api_base = "https://my-endpoint.openai.azure.com" + api_key = "test-key" + model = "gpt-realtime-deployment" + + dummy_websocket = AsyncMock() + dummy_logging_obj = MagicMock() + mock_backend_ws = AsyncMock() + + class DummyAsyncContextManager: + def __init__(self, value): + self.value = value + async def __aenter__(self): + return self.value + async def __aexit__(self, exc_type, exc, tb): + return None + + with patch("websockets.connect", return_value=DummyAsyncContextManager(mock_backend_ws)) as mock_ws_connect, \ + patch("litellm.llms.azure.realtime.handler.RealTimeStreaming") as mock_realtime_streaming: + + mock_streaming_instance = MagicMock() + mock_realtime_streaming.return_value = mock_streaming_instance + mock_streaming_instance.bidirectional_forward = AsyncMock() + + # GA protocol with api_version=None should NOT raise ValueError + await handler.async_realtime( + model=model, + websocket=dummy_websocket, + logging_obj=dummy_logging_obj, + api_base=api_base, + api_key=api_key, + api_version=None, + 
realtime_protocol="GA", + ) + + called_url = mock_ws_connect.call_args[0][0] + assert "/openai/v1/realtime?" in called_url + assert "model=gpt-realtime-deployment" in called_url + assert "api-version" not in called_url + + +@pytest.mark.asyncio +async def test_async_realtime_beta_without_api_version_raises(): + """ + Test that beta protocol still requires api_version. + """ + from litellm.llms.azure.realtime.handler import AzureOpenAIRealtime + + handler = AzureOpenAIRealtime() + dummy_websocket = AsyncMock() + dummy_logging_obj = MagicMock() + + with pytest.raises(ValueError, match="api_version is required"): + await handler.async_realtime( + model="gpt-4o-realtime-preview", + websocket=dummy_websocket, + logging_obj=dummy_logging_obj, + api_base="https://my-endpoint.openai.azure.com", + api_key="test-key", + api_version=None, + realtime_protocol="beta", + ) + + +@pytest.mark.asyncio +async def test_realtime_protocol_env_var_fallback(): + """ + Test that LITELLM_AZURE_REALTIME_PROTOCOL env var is used as fallback. + Fixes #22127: no way to set realtime_protocol from config. + """ + from litellm.realtime_api.main import _arealtime + from litellm.types.router import GenericLiteLLMParams + + with patch.dict(os.environ, {"LITELLM_AZURE_REALTIME_PROTOCOL": "v1"}): + # Create a GenericLiteLLMParams without realtime_protocol + litellm_params = GenericLiteLLMParams() + # The env var should be picked up as fallback + realtime_protocol = ( + {}.get("realtime_protocol") + or litellm_params.get("realtime_protocol") + or os.environ.get("LITELLM_AZURE_REALTIME_PROTOCOL") + or "beta" + ) + assert realtime_protocol == "v1" + + +@pytest.mark.asyncio +async def test_realtime_protocol_from_litellm_params(): + """ + Test that realtime_protocol is read from litellm_params (config.yaml extra field). + Fixes #22127: realtime_protocol in litellm_params was not used. 
+ """ + from litellm.types.router import GenericLiteLLMParams + + # Simulate config.yaml with realtime_protocol as an extra field + litellm_params = GenericLiteLLMParams(realtime_protocol="GA") + assert litellm_params.get("realtime_protocol") == "GA" + + @pytest.mark.asyncio async def test_async_realtime_default_maintains_backwards_compatibility(): """ diff --git a/tests/test_litellm/llms/bedrock/chat/test_converse_transformation.py b/tests/test_litellm/llms/bedrock/chat/test_converse_transformation.py index f6d3d3c12f7..345f3ae7c5d 100644 --- a/tests/test_litellm/llms/bedrock/chat/test_converse_transformation.py +++ b/tests/test_litellm/llms/bedrock/chat/test_converse_transformation.py @@ -3493,3 +3493,262 @@ def test_no_thinking_param_does_not_error(self): drop_params=False, ) assert "thinking" not in result or result.get("thinking") is None + +def test_transform_response_with_both_json_tool_call_and_real_tool(): + """ + When Bedrock returns BOTH json_tool_call AND a real tool (get_weather), + only the real tool should remain in tool_calls. The json_tool_call should be filtered out. 
+ Fixes https://github.com/BerriAI/litellm/issues/18381 + """ + from litellm.llms.bedrock.chat.converse_transformation import AmazonConverseConfig + from litellm.types.utils import ModelResponse + + response_json = { + "metrics": {"latencyMs": 200}, + "output": { + "message": { + "role": "assistant", + "content": [ + { + "toolUse": { + "toolUseId": "tooluse_json_001", + "name": "json_tool_call", + "input": { + "Current_Temperature": 62, + "Weather_Explanation": "Mild and cool.", + }, + } + }, + { + "toolUse": { + "toolUseId": "tooluse_weather_001", + "name": "get_weather", + "input": { + "location": "San Francisco, CA", + "unit": "fahrenheit", + }, + } + }, + ], + } + }, + "stopReason": "tool_use", + "usage": { + "inputTokens": 100, + "outputTokens": 50, + "totalTokens": 150, + "cacheReadInputTokenCount": 0, + "cacheReadInputTokens": 0, + "cacheWriteInputTokenCount": 0, + "cacheWriteInputTokens": 0, + }, + } + + class MockResponse: + def json(self): + return response_json + + @property + def text(self): + return json.dumps(response_json) + + config = AmazonConverseConfig() + model_response = ModelResponse() + optional_params = {"json_mode": True} + + result = config._transform_response( + model="bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0", + response=MockResponse(), + model_response=model_response, + stream=False, + logging_obj=None, + optional_params=optional_params, + api_key=None, + data=None, + messages=[], + encoding=None, + ) + + # Only real tool should remain + assert result.choices[0].message.tool_calls is not None + assert len(result.choices[0].message.tool_calls) == 1 + assert result.choices[0].message.tool_calls[0].function.name == "get_weather" + assert ( + result.choices[0].message.tool_calls[0].function.arguments + == '{"location": "San Francisco, CA", "unit": "fahrenheit"}' + ) + + # json_tool_call content should be preserved as message text + content = result.choices[0].message.content + assert content is not None + parsed = 
json.loads(content) + assert parsed["Current_Temperature"] == 62 + assert parsed["Weather_Explanation"] == "Mild and cool." + + +def test_transform_response_does_not_mutate_optional_params(): + """ + Verify that optional_params still contains json_mode after _transform_response. + Previously, .pop() was used which mutated the caller's dict. + """ + from litellm.llms.bedrock.chat.converse_transformation import AmazonConverseConfig + from litellm.types.utils import ModelResponse + + response_json = { + "metrics": {"latencyMs": 50}, + "output": { + "message": { + "role": "assistant", + "content": [ + { + "toolUse": { + "toolUseId": "tooluse_001", + "name": "json_tool_call", + "input": {"result": "ok"}, + } + } + ], + } + }, + "stopReason": "tool_use", + "usage": { + "inputTokens": 10, + "outputTokens": 5, + "totalTokens": 15, + "cacheReadInputTokenCount": 0, + "cacheReadInputTokens": 0, + "cacheWriteInputTokenCount": 0, + "cacheWriteInputTokens": 0, + }, + } + + class MockResponse: + def json(self): + return response_json + + @property + def text(self): + return json.dumps(response_json) + + config = AmazonConverseConfig() + model_response = ModelResponse() + optional_params = {"json_mode": True, "other_key": "value"} + + config._transform_response( + model="bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0", + response=MockResponse(), + model_response=model_response, + stream=False, + logging_obj=None, + optional_params=optional_params, + api_key=None, + data=None, + messages=[], + encoding=None, + ) + + # json_mode should still be in optional_params (not popped) + assert "json_mode" in optional_params + assert optional_params["json_mode"] is True + assert optional_params["other_key"] == "value" + + +def test_streaming_filters_json_tool_call_with_real_tools(): + """ + Simulate streaming chunks where both json_tool_call and a real tool arrive. + Verify json_tool_call chunks are converted to text content while real tool + chunks pass through normally. 
+ """ + from litellm.llms.bedrock.chat.invoke_handler import AWSEventStreamDecoder + from litellm.types.llms.bedrock import ( + ContentBlockDeltaEvent, + ContentBlockStartEvent, + ) + + decoder = AWSEventStreamDecoder(model="test-model", json_mode=True) + + # Chunk 1: json_tool_call start + json_start = ContentBlockStartEvent( + toolUse={ + "toolUseId": "tooluse_json_001", + "name": "json_tool_call", + } + ) + tool_use_1, _, _ = decoder._handle_converse_start_event(json_start) + # json_tool_call start should be suppressed (return None tool_use) + assert tool_use_1 is None + # tool_calls_index should NOT have been incremented + assert decoder.tool_calls_index is None + + # Chunk 2: json_tool_call delta — should become text, not tool_use + json_delta = ContentBlockDeltaEvent(toolUse={"input": '{"temp": 62}'}) + text_2, tool_use_2, _, _, _ = decoder._handle_converse_delta_event( + json_delta, index=0 + ) + assert text_2 == '{"temp": 62}' + assert tool_use_2 is None + + # Chunk 3: json_tool_call stop + stop_tool = decoder._handle_converse_stop_event(index=0) + assert stop_tool is None + # _current_tool_name should be reset + assert decoder._current_tool_name is None + + # Chunk 4: real tool start + real_start = ContentBlockStartEvent( + toolUse={ + "toolUseId": "tooluse_weather_001", + "name": "get_weather", + } + ) + tool_use_4, _, _ = decoder._handle_converse_start_event(real_start) + assert tool_use_4 is not None + assert tool_use_4["function"]["name"] == "get_weather" + assert decoder.tool_calls_index == 0 + + # Chunk 5: real tool delta + real_delta = ContentBlockDeltaEvent( + toolUse={"input": '{"location": "SF"}'} + ) + text_5, tool_use_5, _, _, _ = decoder._handle_converse_delta_event( + real_delta, index=1 + ) + assert text_5 == "" + assert tool_use_5 is not None + assert tool_use_5["function"]["arguments"] == '{"location": "SF"}' + + +def test_streaming_without_json_mode_passes_all_tools(): + """ + Verify backward compatibility: when json_mode=False, all tools 
+ (including json_tool_call if present) pass through unchanged. + """ + from litellm.llms.bedrock.chat.invoke_handler import AWSEventStreamDecoder + from litellm.types.llms.bedrock import ( + ContentBlockDeltaEvent, + ContentBlockStartEvent, + ) + + decoder = AWSEventStreamDecoder(model="test-model", json_mode=False) + + # json_tool_call start — should pass through when json_mode=False + json_start = ContentBlockStartEvent( + toolUse={ + "toolUseId": "tooluse_json_001", + "name": "json_tool_call", + } + ) + tool_use, _, _ = decoder._handle_converse_start_event(json_start) + assert tool_use is not None + assert tool_use["function"]["name"] == "json_tool_call" + assert decoder.tool_calls_index == 0 + + # json_tool_call delta — should be a tool_use, not text + json_delta = ContentBlockDeltaEvent(toolUse={"input": '{"data": 1}'}) + text, tool_use_delta, _, _, _ = decoder._handle_converse_delta_event( + json_delta, index=0 + ) + assert text == "" + assert tool_use_delta is not None + assert tool_use_delta["function"]["arguments"] == '{"data": 1}' + diff --git a/tests/test_litellm/llms/openai/image_generation/test_openai_image_generation_extra_headers.py b/tests/test_litellm/llms/openai/image_generation/test_openai_image_generation_extra_headers.py new file mode 100644 index 00000000000..33db9d33c1c --- /dev/null +++ b/tests/test_litellm/llms/openai/image_generation/test_openai_image_generation_extra_headers.py @@ -0,0 +1,212 @@ +""" +Unit tests for extra_headers propagation in OpenAI image generation. + +Verifies that extra_headers passed to litellm.image_generation() / +litellm.aimage_generation() are forwarded to the OpenAI API client as +extra_headers in the images.generate() call. 
+""" + +import os +import sys +from unittest.mock import MagicMock, AsyncMock, patch + +import pytest + +sys.path.insert(0, os.path.abspath("../../../../..")) + +from litellm.llms.openai.openai import OpenAIChatCompletion + + +@pytest.fixture +def openai_chat_completions(): + return OpenAIChatCompletion() + + +@pytest.fixture +def mock_logging_obj(): + logging_obj = MagicMock() + logging_obj.pre_call = MagicMock() + logging_obj.post_call = MagicMock() + return logging_obj + + +class TestImageGenerationExtraHeaders: + """Test that extra_headers are properly injected into OpenAI image generation calls.""" + + def test_sync_image_generation_with_headers( + self, openai_chat_completions, mock_logging_obj + ): + """Sync image_generation should pass headers as extra_headers to images.generate().""" + mock_image_data = MagicMock() + mock_image_data.model_dump.return_value = { + "created": 1700000000, + "data": [{"url": "https://example.com/image.png"}], + } + + mock_openai_client = MagicMock() + mock_openai_client.images.generate.return_value = mock_image_data + mock_openai_client.api_key = "test-key" + mock_openai_client._base_url._uri_reference = "https://api.openai.com" + + test_headers = {"cf-aig-authorization": "Bearer custom-token"} + + openai_chat_completions.image_generation( + model="dall-e-3", + prompt="A white cat", + timeout=60.0, + optional_params={}, + logging_obj=mock_logging_obj, + api_key="test-key", + headers=test_headers, + client=mock_openai_client, + ) + + _, kwargs = mock_openai_client.images.generate.call_args + assert kwargs.get("extra_headers") == test_headers + + def test_sync_image_generation_without_headers( + self, openai_chat_completions, mock_logging_obj + ): + """Sync image_generation without headers should not inject extra_headers.""" + mock_image_data = MagicMock() + mock_image_data.model_dump.return_value = { + "created": 1700000000, + "data": [{"url": "https://example.com/image.png"}], + } + + mock_openai_client = MagicMock() + 
mock_openai_client.images.generate.return_value = mock_image_data + mock_openai_client.api_key = "test-key" + mock_openai_client._base_url._uri_reference = "https://api.openai.com" + + openai_chat_completions.image_generation( + model="dall-e-3", + prompt="A white cat", + timeout=60.0, + optional_params={}, + logging_obj=mock_logging_obj, + api_key="test-key", + client=mock_openai_client, + ) + + _, kwargs = mock_openai_client.images.generate.call_args + assert "extra_headers" not in kwargs + + @pytest.mark.asyncio + async def test_async_image_generation_with_headers( + self, openai_chat_completions, mock_logging_obj + ): + """Async aimage_generation should pass headers as extra_headers to images.generate().""" + mock_image_data = MagicMock() + mock_image_data.model_dump.return_value = { + "created": 1700000000, + "data": [{"url": "https://example.com/image.png"}], + } + + mock_openai_client = MagicMock() + mock_openai_client.images.generate = AsyncMock(return_value=mock_image_data) + mock_openai_client.api_key = "test-key" + + test_headers = {"cf-aig-authorization": "Bearer custom-token"} + + await openai_chat_completions.aimage_generation( + prompt="A white cat", + data={"model": "dall-e-3", "prompt": "A white cat"}, + model_response=MagicMock(), + timeout=60.0, + logging_obj=mock_logging_obj, + api_key="test-key", + headers=test_headers, + client=mock_openai_client, + ) + + _, kwargs = mock_openai_client.images.generate.call_args + assert kwargs.get("extra_headers") == test_headers + + @pytest.mark.asyncio + async def test_async_image_generation_without_headers( + self, openai_chat_completions, mock_logging_obj + ): + """Async aimage_generation without headers should not inject extra_headers.""" + mock_image_data = MagicMock() + mock_image_data.model_dump.return_value = { + "created": 1700000000, + "data": [{"url": "https://example.com/image.png"}], + } + + mock_openai_client = MagicMock() + mock_openai_client.images.generate = 
AsyncMock(return_value=mock_image_data) + mock_openai_client.api_key = "test-key" + + await openai_chat_completions.aimage_generation( + prompt="A white cat", + data={"model": "dall-e-3", "prompt": "A white cat"}, + model_response=MagicMock(), + timeout=60.0, + logging_obj=mock_logging_obj, + api_key="test-key", + client=mock_openai_client, + ) + + _, kwargs = mock_openai_client.images.generate.call_args + assert "extra_headers" not in kwargs + + def test_sync_image_generation_forwards_headers_to_async( + self, openai_chat_completions, mock_logging_obj + ): + """When aimg_generation=True, image_generation should forward headers to aimage_generation.""" + with patch.object( + openai_chat_completions, "aimage_generation" + ) as mock_aimage_gen: + mock_aimage_gen.return_value = MagicMock() + + test_headers = {"x-custom-header": "value"} + + openai_chat_completions.image_generation( + model="dall-e-3", + prompt="A white cat", + timeout=60.0, + optional_params={}, + logging_obj=mock_logging_obj, + api_key="test-key", + aimg_generation=True, + headers=test_headers, + ) + + mock_aimage_gen.assert_called_once() + call_kwargs = mock_aimage_gen.call_args[1] + assert call_kwargs["headers"] == test_headers + + +class TestImageGenerationEntryPointHeaders: + """Test that litellm.image_generation() passes headers through to the OpenAI provider.""" + + @pytest.mark.asyncio + async def test_extra_headers_reach_openai_provider(self): + """End-to-end: extra_headers from litellm.aimage_generation() reach OpenAI images.generate().""" + import litellm + + mock_image_data = MagicMock() + mock_image_data.model_dump.return_value = { + "created": 1700000000, + "data": [{"url": "https://example.com/image.png"}], + } + + mock_openai_client = MagicMock() + mock_openai_client.images.generate = AsyncMock(return_value=mock_image_data) + mock_openai_client.api_key = "test-key" + mock_openai_client._base_url._uri_reference = "https://api.openai.com" + + test_headers = {"cf-aig-authorization": 
"Bearer my-secret"} + + await litellm.aimage_generation( + model="dall-e-3", + prompt="A white cat", + extra_headers=test_headers, + client=mock_openai_client, + api_key="test-key", + ) + + mock_openai_client.images.generate.assert_called_once() + _, kwargs = mock_openai_client.images.generate.call_args + assert kwargs.get("extra_headers") == test_headers diff --git a/tests/test_litellm/proxy/agent_endpoints/test_endpoints.py b/tests/test_litellm/proxy/agent_endpoints/test_endpoints.py index 8cec3538077..fcf8f048190 100644 --- a/tests/test_litellm/proxy/agent_endpoints/test_endpoints.py +++ b/tests/test_litellm/proxy/agent_endpoints/test_endpoints.py @@ -7,6 +7,7 @@ from litellm.proxy._types import LitellmUserRoles, UserAPIKeyAuth from litellm.proxy.agent_endpoints import endpoints as agent_endpoints from litellm.proxy.agent_endpoints.endpoints import ( + _check_agent_management_permission, get_agent_daily_activity, router, user_api_key_auth, @@ -47,6 +48,16 @@ def _sample_agent_response( ) +def _make_app_with_role(role: LitellmUserRoles) -> TestClient: + """Create a TestClient where the auth dependency returns the given role.""" + test_app = FastAPI() + test_app.include_router(router) + test_app.dependency_overrides[user_api_key_auth] = lambda: UserAPIKeyAuth( + user_id="test-user", user_role=role + ) + return TestClient(test_app) + + app = FastAPI() app.include_router(router) app.dependency_overrides[user_api_key_auth] = lambda: UserAPIKeyAuth( @@ -258,3 +269,173 @@ async def test_get_agent_daily_activity_with_agent_names(monkeypatch): "agent-1": {"agent_name": "First Agent"}, "agent-2": {"agent_name": "Second Agent"}, } + + +# ---------- RBAC enforcement tests ---------- + + +class TestAgentRBACInternalUser: + """Internal users should be able to read agents but not create/update/delete.""" + + @pytest.fixture(autouse=True) + def _setup(self, monkeypatch): + self.internal_client = _make_app_with_role(LitellmUserRoles.INTERNAL_USER) + self.mock_registry = 
MagicMock() + monkeypatch.setattr(agent_endpoints, "AGENT_REGISTRY", self.mock_registry) + + def test_should_allow_internal_user_to_list_agents(self, monkeypatch): + self.mock_registry.get_agent_list = MagicMock(return_value=[]) + resp = self.internal_client.get( + "/v1/agents", headers={"Authorization": "Bearer k"} + ) + assert resp.status_code == 200 + + def test_should_allow_internal_user_to_get_agent_by_id(self, monkeypatch): + self.mock_registry.get_agent_by_id = MagicMock( + return_value=_sample_agent_response() + ) + with patch("litellm.proxy.proxy_server.prisma_client") as mock_prisma: + resp = self.internal_client.get( + "/v1/agents/agent-123", headers={"Authorization": "Bearer k"} + ) + assert resp.status_code == 200 + + def test_should_block_internal_user_from_creating_agent(self): + resp = self.internal_client.post( + "/v1/agents", + json=_sample_agent_config(), + headers={"Authorization": "Bearer k"}, + ) + assert resp.status_code == 403 + assert "Only proxy admins" in resp.json()["detail"]["error"] + + def test_should_block_internal_user_from_updating_agent(self): + resp = self.internal_client.put( + "/v1/agents/agent-123", + json=_sample_agent_config(), + headers={"Authorization": "Bearer k"}, + ) + assert resp.status_code == 403 + + def test_should_block_internal_user_from_patching_agent(self): + resp = self.internal_client.patch( + "/v1/agents/agent-123", + json={"agent_name": "new-name"}, + headers={"Authorization": "Bearer k"}, + ) + assert resp.status_code == 403 + + def test_should_block_internal_user_from_deleting_agent(self): + resp = self.internal_client.delete( + "/v1/agents/agent-123", headers={"Authorization": "Bearer k"} + ) + assert resp.status_code == 403 + + +class TestAgentRBACInternalUserViewOnly: + """View-only internal users should only be able to read agents.""" + + @pytest.fixture(autouse=True) + def _setup(self, monkeypatch): + self.viewer_client = _make_app_with_role( + LitellmUserRoles.INTERNAL_USER_VIEW_ONLY + ) + 
self.mock_registry = MagicMock() + monkeypatch.setattr(agent_endpoints, "AGENT_REGISTRY", self.mock_registry) + + def test_should_allow_view_only_user_to_list_agents(self): + self.mock_registry.get_agent_list = MagicMock(return_value=[]) + resp = self.viewer_client.get( + "/v1/agents", headers={"Authorization": "Bearer k"} + ) + assert resp.status_code == 200 + + def test_should_block_view_only_user_from_creating_agent(self): + resp = self.viewer_client.post( + "/v1/agents", + json=_sample_agent_config(), + headers={"Authorization": "Bearer k"}, + ) + assert resp.status_code == 403 + + def test_should_block_view_only_user_from_deleting_agent(self): + resp = self.viewer_client.delete( + "/v1/agents/agent-123", headers={"Authorization": "Bearer k"} + ) + assert resp.status_code == 403 + + +class TestAgentRBACProxyAdmin: + """Proxy admins should have full CRUD access to agents.""" + + @pytest.fixture(autouse=True) + def _setup(self, monkeypatch): + self.admin_client = _make_app_with_role(LitellmUserRoles.PROXY_ADMIN) + self.mock_registry = MagicMock() + monkeypatch.setattr(agent_endpoints, "AGENT_REGISTRY", self.mock_registry) + + def test_should_allow_admin_to_create_agent(self, monkeypatch): + with patch("litellm.proxy.proxy_server.prisma_client") as mock_prisma: + self.mock_registry.get_agent_by_name = MagicMock(return_value=None) + self.mock_registry.add_agent_to_db = AsyncMock( + return_value=_sample_agent_response() + ) + self.mock_registry.register_agent = MagicMock() + resp = self.admin_client.post( + "/v1/agents", + json=_sample_agent_config(), + headers={"Authorization": "Bearer k"}, + ) + assert resp.status_code == 200 + + def test_should_allow_admin_to_delete_agent(self): + existing = { + "agent_id": "agent-123", + "agent_name": "Existing Agent", + "agent_card_params": _sample_agent_card_params(), + } + with patch("litellm.proxy.proxy_server.prisma_client") as mock_prisma: + mock_prisma.db.litellm_agentstable.find_unique = AsyncMock( + 
return_value=existing + ) + self.mock_registry.delete_agent_from_db = AsyncMock() + self.mock_registry.deregister_agent = MagicMock() + resp = self.admin_client.delete( + "/v1/agents/agent-123", headers={"Authorization": "Bearer k"} + ) + assert resp.status_code == 200 + + +class TestCheckAgentManagementPermission: + """Unit tests for the _check_agent_management_permission helper.""" + + def test_should_allow_proxy_admin(self): + auth = UserAPIKeyAuth( + user_id="admin", user_role=LitellmUserRoles.PROXY_ADMIN + ) + _check_agent_management_permission(auth) + + @pytest.mark.parametrize( + "role", + [ + LitellmUserRoles.INTERNAL_USER, + LitellmUserRoles.INTERNAL_USER_VIEW_ONLY, + LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY, + ], + ) + def test_should_block_non_admin_roles(self, role): + from fastapi import HTTPException + + auth = UserAPIKeyAuth(user_id="user", user_role=role) + with pytest.raises(HTTPException) as exc_info: + _check_agent_management_permission(auth) + assert exc_info.value.status_code == 403 + + +class TestAgentRoutesIncludesAgentIdPattern: + """Verify that agent_routes includes the {agent_id} pattern for route access.""" + + def test_should_include_agent_id_pattern(self): + from litellm.proxy._types import LiteLLMRoutes + + assert "/v1/agents/{agent_id}" in LiteLLMRoutes.agent_routes.value diff --git a/tests/test_litellm/proxy/auth/test_handle_jwt.py b/tests/test_litellm/proxy/auth/test_handle_jwt.py index b56d13bb932..8418dde5e9c 100644 --- a/tests/test_litellm/proxy/auth/test_handle_jwt.py +++ b/tests/test_litellm/proxy/auth/test_handle_jwt.py @@ -1485,4 +1485,273 @@ async def test_get_objects_resolves_org_by_name(): ) +# --------------------------------------------------------------------------- +# Fix 1: OIDC discovery URL resolution +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_resolve_jwks_url_passthrough_for_direct_jwks_url(): + """Non-discovery URLs are returned 
unchanged.""" + from unittest.mock import AsyncMock, MagicMock + + from litellm.caching.dual_cache import DualCache + + handler = JWTHandler() + handler.update_environment( + prisma_client=None, + user_api_key_cache=DualCache(), + litellm_jwtauth=LiteLLM_JWTAuth(), + ) + url = "https://login.microsoftonline.com/common/discovery/keys" + result = await handler._resolve_jwks_url(url) + assert result == url + + +@pytest.mark.asyncio +async def test_resolve_jwks_url_resolves_oidc_discovery_document(): + """ + A .well-known/openid-configuration URL should be fetched and its + jwks_uri returned. + """ + from unittest.mock import AsyncMock, MagicMock, patch + + from litellm.caching.dual_cache import DualCache + + handler = JWTHandler() + cache = DualCache() + handler.update_environment( + prisma_client=None, + user_api_key_cache=cache, + litellm_jwtauth=LiteLLM_JWTAuth(), + ) + + discovery_url = "https://login.microsoftonline.com/tenant/.well-known/openid-configuration" + jwks_url = "https://login.microsoftonline.com/tenant/discovery/keys" + + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = {"jwks_uri": jwks_url, "issuer": "https://..."} + + with patch.object(handler.http_handler, "get", new_callable=AsyncMock, return_value=mock_response) as mock_get: + result = await handler._resolve_jwks_url(discovery_url) + + assert result == jwks_url + mock_get.assert_called_once_with(discovery_url) + + +@pytest.mark.asyncio +async def test_resolve_jwks_url_caches_resolved_jwks_uri(): + """Resolved jwks_uri is cached — second call does not hit the network.""" + from unittest.mock import AsyncMock, MagicMock, patch + + from litellm.caching.dual_cache import DualCache + + handler = JWTHandler() + cache = DualCache() + handler.update_environment( + prisma_client=None, + user_api_key_cache=cache, + litellm_jwtauth=LiteLLM_JWTAuth(), + ) + + discovery_url = "https://login.microsoftonline.com/tenant/.well-known/openid-configuration" + 
jwks_url = "https://login.microsoftonline.com/tenant/discovery/keys" + + mock_response = MagicMock() + mock_response.json.return_value = {"jwks_uri": jwks_url} + + with patch.object(handler.http_handler, "get", new_callable=AsyncMock, return_value=mock_response) as mock_get: + first = await handler._resolve_jwks_url(discovery_url) + second = await handler._resolve_jwks_url(discovery_url) + + assert first == jwks_url + assert second == jwks_url + # Network should only be hit once + assert mock_get.call_count == 1 + + +@pytest.mark.asyncio +async def test_resolve_jwks_url_raises_if_no_jwks_uri_in_discovery_doc(): + """Raise a helpful error if the discovery document has no jwks_uri.""" + from unittest.mock import AsyncMock, MagicMock, patch + + from litellm.caching.dual_cache import DualCache + + handler = JWTHandler() + handler.update_environment( + prisma_client=None, + user_api_key_cache=DualCache(), + litellm_jwtauth=LiteLLM_JWTAuth(), + ) + + discovery_url = "https://example.com/.well-known/openid-configuration" + mock_response = MagicMock() + mock_response.json.return_value = {"issuer": "https://example.com"} # no jwks_uri + + with patch.object(handler.http_handler, "get", new_callable=AsyncMock, return_value=mock_response): + with pytest.raises(Exception, match="jwks_uri"): + await handler._resolve_jwks_url(discovery_url) + + +# --------------------------------------------------------------------------- +# Fix 2: handle array values in team_id_jwt_field (e.g. 
AAD "roles" claim) +# --------------------------------------------------------------------------- + + +def _make_jwt_handler(team_id_jwt_field: str) -> JWTHandler: + from litellm.caching.dual_cache import DualCache + + handler = JWTHandler() + handler.update_environment( + prisma_client=None, + user_api_key_cache=DualCache(), + litellm_jwtauth=LiteLLM_JWTAuth(team_id_jwt_field=team_id_jwt_field), + ) + return handler + + +def test_get_team_id_returns_first_element_when_roles_is_list(): + """ + AAD sends roles as a list. get_team_id() must return the first string + element rather than the raw list (which would later crash with + 'unhashable type: list'). + """ + handler = _make_jwt_handler("roles") + token = {"oid": "user-oid", "roles": ["team1"]} + result = handler.get_team_id(token=token, default_value=None) + assert result == "team1" + + +def test_get_team_id_returns_first_element_from_multi_value_roles_list(): + """When roles has multiple entries, the first one is used.""" + handler = _make_jwt_handler("roles") + token = {"roles": ["team2", "team1"]} + result = handler.get_team_id(token=token, default_value=None) + assert result == "team2" + + +def test_get_team_id_returns_default_when_roles_list_is_empty(): + """Empty list should fall back to default_value.""" + handler = _make_jwt_handler("roles") + token = {"roles": []} + result = handler.get_team_id(token=token, default_value="fallback") + assert result == "fallback" + + +def test_get_team_id_still_works_with_string_value(): + """String values (non-array) continue to work as before.""" + handler = _make_jwt_handler("appid") + token = {"appid": "my-team-id"} + result = handler.get_team_id(token=token, default_value=None) + assert result == "my-team-id" + + +def test_get_team_id_list_result_is_hashable(): + """ + The value returned by get_team_id() must be hashable so it can be + added to a set (the operation that previously crashed). 
+ """ + handler = _make_jwt_handler("roles") + token = {"roles": ["team1"]} + result = handler.get_team_id(token=token, default_value=None) + # This must not raise TypeError + s: set = set() + s.add(result) + assert "team1" in s + + +# --------------------------------------------------------------------------- +# Fix 3: helpful error message for dot-notation array indexing (roles.0) +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_find_and_validate_specific_team_id_hints_bracket_notation(): + """ + When team_id_jwt_field is set to 'roles.0' (unsupported dot-notation for + array indexing) and no team is found, the exception message should suggest + using 'roles' instead (and explain LiteLLM auto-unwraps list values). + """ + from unittest.mock import MagicMock + + from litellm.caching.dual_cache import DualCache + + handler = _make_jwt_handler("roles.0") + # token has roles as a list — dot-notation won't find anything + token = {"roles": ["team1"]} + + with pytest.raises(Exception) as exc_info: + await JWTAuthManager.find_and_validate_specific_team_id( + jwt_handler=handler, + jwt_valid_token=token, + prisma_client=None, + user_api_key_cache=DualCache(), + parent_otel_span=None, + proxy_logging_obj=MagicMock(), + ) + + error_msg = str(exc_info.value) + # Should mention the bad field name and suggest the fix + assert "roles.0" in error_msg, f"Expected field name in: {error_msg}" + assert "roles" in error_msg and "list" in error_msg, ( + f"Expected hint about using 'roles' instead: {error_msg}" + ) + + +@pytest.mark.asyncio +async def test_find_and_validate_specific_team_id_hints_bracket_index_notation(): + """ + When team_id_jwt_field is set to 'roles[0]' (bracket indexing, also unsupported + in get_nested_value) the error message should suggest using 'roles' instead. 
+ """ + from unittest.mock import MagicMock + + from litellm.caching.dual_cache import DualCache + + handler = _make_jwt_handler("roles[0]") + token = {"roles": ["team1"]} + + with pytest.raises(Exception) as exc_info: + await JWTAuthManager.find_and_validate_specific_team_id( + jwt_handler=handler, + jwt_valid_token=token, + prisma_client=None, + user_api_key_cache=DualCache(), + parent_otel_span=None, + proxy_logging_obj=MagicMock(), + ) + + error_msg = str(exc_info.value) + assert "roles[0]" in error_msg, f"Expected field name in: {error_msg}" + assert "roles" in error_msg and "list" in error_msg, ( + f"Expected hint about using 'roles' instead: {error_msg}" + ) + + +@pytest.mark.asyncio +async def test_find_and_validate_specific_team_id_no_hint_for_valid_field(): + """ + When team_id_jwt_field is a normal field name (no dot-notation) the + error message should not contain a spurious bracket-notation hint. + """ + from unittest.mock import AsyncMock, MagicMock + + from litellm.caching.dual_cache import DualCache + + handler = _make_jwt_handler("appid") + token = {} # no appid — triggers the "no team found" path + + with pytest.raises(Exception) as exc_info: + await JWTAuthManager.find_and_validate_specific_team_id( + jwt_handler=handler, + jwt_valid_token=token, + prisma_client=None, + user_api_key_cache=DualCache(), + parent_otel_span=None, + proxy_logging_obj=MagicMock(), + ) + + error_msg = str(exc_info.value) + assert "Hint" not in error_msg diff --git a/tests/test_litellm/proxy/management_endpoints/test_key_management_endpoints.py b/tests/test_litellm/proxy/management_endpoints/test_key_management_endpoints.py index 7565e901ecd..3a2e2f93949 100644 --- a/tests/test_litellm/proxy/management_endpoints/test_key_management_endpoints.py +++ b/tests/test_litellm/proxy/management_endpoints/test_key_management_endpoints.py @@ -6070,6 +6070,121 @@ async def test_build_key_filter_admin_all_member_overlap(): ) +@pytest.mark.asyncio +async def 
test_build_key_filter_project_id(): + """ + Test that project_id is applied as a global AND condition, narrowing all results + to keys that belong to the specified project. + """ + from litellm.proxy.management_endpoints.key_management_endpoints import ( + _build_key_filter_conditions, + ) + + user_id = "user-123" + project_id = "proj-abc" + + where = _build_key_filter_conditions( + user_id=user_id, + team_id=None, + organization_id=None, + key_alias=None, + key_hash=None, + exclude_team_id=None, + admin_team_ids=None, + member_team_ids=None, + include_created_by_keys=False, + project_id=project_id, + ) + + # Should be wrapped in a top-level AND for the project_id filter + assert "AND" in where + and_parts = where["AND"] + assert len(and_parts) == 2 + + # Second part of AND should be the project_id filter + assert {"project_id": project_id} in and_parts + + +@pytest.mark.asyncio +async def test_build_key_filter_access_group_id(): + """ + Test that access_group_id is applied as a global AND condition using hasSome, + narrowing results to keys whose access_group_ids array contains the given ID. 
+ """ + from litellm.proxy.management_endpoints.key_management_endpoints import ( + _build_key_filter_conditions, + ) + + user_id = "user-123" + access_group_id = "ag-xyz" + + where = _build_key_filter_conditions( + user_id=user_id, + team_id=None, + organization_id=None, + key_alias=None, + key_hash=None, + exclude_team_id=None, + admin_team_ids=None, + member_team_ids=None, + include_created_by_keys=False, + access_group_id=access_group_id, + ) + + # Should be wrapped in a top-level AND for the access_group_id filter + assert "AND" in where + and_parts = where["AND"] + assert len(and_parts) == 2 + + # Second part of AND should use hasSome for the array field + assert {"access_group_ids": {"hasSome": [access_group_id]}} in and_parts + + +@pytest.mark.asyncio +async def test_build_key_filter_project_id_and_access_group_id(): + """ + Test that project_id and access_group_id stack correctly when both are provided. + Both should be applied as AND conditions, narrowing results to keys that match both. 
+ """ + from litellm.proxy.management_endpoints.key_management_endpoints import ( + _build_key_filter_conditions, + ) + + user_id = "user-123" + project_id = "proj-abc" + access_group_id = "ag-xyz" + + where = _build_key_filter_conditions( + user_id=user_id, + team_id=None, + organization_id=None, + key_alias=None, + key_hash=None, + exclude_team_id=None, + admin_team_ids=None, + member_team_ids=None, + include_created_by_keys=False, + project_id=project_id, + access_group_id=access_group_id, + ) + + # After project_id: {"AND": [visibility_where, {"project_id": ...}]} + # After access_group_id: {"AND": [above, {"access_group_ids": ...}]} + assert "AND" in where + outer_and = where["AND"] + assert len(outer_and) == 2 + + # The access_group_ids filter is the outermost AND + access_group_filter = outer_and[1] + assert access_group_filter == {"access_group_ids": {"hasSome": [access_group_id]}} + + # The project_id filter is nested one level in + inner = outer_and[0] + assert "AND" in inner + inner_and = inner["AND"] + assert {"project_id": project_id} in inner_and + + @pytest.mark.asyncio async def test_get_member_team_ids(): """ diff --git a/tests/test_litellm/proxy/middleware/test_in_flight_requests_middleware.py b/tests/test_litellm/proxy/middleware/test_in_flight_requests_middleware.py new file mode 100644 index 00000000000..830bca49936 --- /dev/null +++ b/tests/test_litellm/proxy/middleware/test_in_flight_requests_middleware.py @@ -0,0 +1,98 @@ +""" +Tests for InFlightRequestsMiddleware. + +Verifies that in_flight_requests is incremented during a request and +decremented after it completes, including on errors. 
+""" +import asyncio + +import pytest +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.requests import Request +from starlette.responses import JSONResponse, Response +from starlette.routing import Route +from starlette.testclient import TestClient + +from litellm.proxy.middleware.in_flight_requests_middleware import ( + InFlightRequestsMiddleware, + get_in_flight_requests, +) + + +@pytest.fixture(autouse=True) +def reset_state(): + """Reset class-level state between tests.""" + InFlightRequestsMiddleware._in_flight = 0 + yield + InFlightRequestsMiddleware._in_flight = 0 + + +def _make_app(handler): + from starlette.applications import Starlette + + app = Starlette(routes=[Route("/", handler)]) + app.add_middleware(InFlightRequestsMiddleware) + return app + + +# ── Structure ───────────────────────────────────────────────────────────────── + + +def test_is_not_base_http_middleware(): + """Must be pure ASGI — BaseHTTPMiddleware causes streaming degradation.""" + assert not issubclass(InFlightRequestsMiddleware, BaseHTTPMiddleware) + + +def test_has_asgi_call_protocol(): + assert "__call__" in InFlightRequestsMiddleware.__dict__ + + +# ── Counter behaviour ───────────────────────────────────────────────────────── + + +def test_counter_zero_at_start(): + assert get_in_flight_requests() == 0 + + +def test_counter_increments_inside_handler(): + captured = [] + + async def handler(request: Request) -> Response: + captured.append(InFlightRequestsMiddleware.get_count()) + return JSONResponse({}) + + TestClient(_make_app(handler)).get("/") + assert captured == [1] + + +def test_counter_returns_to_zero_after_request(): + async def handler(request: Request) -> Response: + return JSONResponse({}) + + TestClient(_make_app(handler)).get("/") + assert get_in_flight_requests() == 0 + + +def test_counter_decrements_after_error(): + """Counter must reach 0 even when the handler raises.""" + + async def handler(request: Request) -> Response: + return 
Response("boom", status_code=500) + + TestClient(_make_app(handler)).get("/") + assert get_in_flight_requests() == 0 + + +def test_non_http_scopes_not_counted(): + """Lifespan / websocket scopes must not touch the counter.""" + + class _InnerApp: + async def __call__(self, scope, receive, send): + pass + + mw = InFlightRequestsMiddleware(_InnerApp()) + + asyncio.get_event_loop().run_until_complete( + mw({"type": "lifespan"}, None, None) # type: ignore[arg-type] + ) + assert get_in_flight_requests() == 0 diff --git a/tests/test_litellm/proxy/test_prometheus_cleanup.py b/tests/test_litellm/proxy/test_prometheus_cleanup.py index 276f2b592db..b3d785f1133 100644 --- a/tests/test_litellm/proxy/test_prometheus_cleanup.py +++ b/tests/test_litellm/proxy/test_prometheus_cleanup.py @@ -10,7 +10,7 @@ import pytest -from litellm.proxy.prometheus_cleanup import wipe_directory +from litellm.proxy.prometheus_cleanup import mark_worker_exit, wipe_directory from litellm.proxy.proxy_cli import ProxyInitializationHelpers @@ -23,6 +23,35 @@ def test_deletes_all_db_files(self, tmp_path): assert not list(tmp_path.glob("*.db")) +class TestMarkWorkerExit: + def test_calls_mark_process_dead_when_env_set(self, tmp_path): + with patch.dict(os.environ, {"PROMETHEUS_MULTIPROC_DIR": str(tmp_path)}): + with patch( + "prometheus_client.multiprocess.mark_process_dead" + ) as mock_mark: + mark_worker_exit(12345) + mock_mark.assert_called_once_with(12345) + + def test_noop_when_env_not_set(self): + with patch.dict(os.environ, {}, clear=False): + os.environ.pop("PROMETHEUS_MULTIPROC_DIR", None) + with patch( + "prometheus_client.multiprocess.mark_process_dead" + ) as mock_mark: + mark_worker_exit(12345) + mock_mark.assert_not_called() + + def test_exception_is_caught_and_logged(self, tmp_path): + with patch.dict(os.environ, {"PROMETHEUS_MULTIPROC_DIR": str(tmp_path)}): + with patch( + "prometheus_client.multiprocess.mark_process_dead", + side_effect=FileNotFoundError("gone"), + ) as mock_mark: + # 
Should not raise + mark_worker_exit(99) + mock_mark.assert_called_once_with(99) + + class TestMaybeSetupPrometheusMultiprocDir: def test_respects_existing_env_var(self, tmp_path): """When PROMETHEUS_MULTIPROC_DIR is already set, don't override it.""" diff --git a/tests/test_litellm/proxy/test_proxy_types.py b/tests/test_litellm/proxy/test_proxy_types.py index 0e47134478b..ae2b7bbf24c 100644 --- a/tests/test_litellm/proxy/test_proxy_types.py +++ b/tests/test_litellm/proxy/test_proxy_types.py @@ -45,3 +45,27 @@ def test_audit_log_masking(): json_before_value = json.loads(audit_log.before_value) assert json_before_value["token"] == "1q2132r222" assert json_before_value["key"] == "sk-1*****7890" + + +def test_internal_jobs_user_has_proxy_admin_role(): + """ + Test that the internal jobs system user has PROXY_ADMIN role. + + This is critical for key rotation to work properly. The system user needs + PROXY_ADMIN role to bypass team permission checks in + TeamMemberPermissionChecks.can_team_member_execute_key_management_endpoint() + + Regression test for: https://github.com/BerriAI/litellm/pull/21896 + """ + from litellm.proxy._types import LitellmUserRoles, UserAPIKeyAuth + + # Get the system user used for internal jobs like key rotation + system_user = UserAPIKeyAuth.get_litellm_internal_jobs_user_api_key_auth() + + # Verify the system user has PROXY_ADMIN role + assert system_user.user_role == LitellmUserRoles.PROXY_ADMIN + + # Verify other expected properties + assert system_user.user_id == "system" + assert system_user.team_id == "system" + assert system_user.team_alias == "system" diff --git a/tests/test_litellm/responses/mcp/test_chat_completions_handler.py b/tests/test_litellm/responses/mcp/test_chat_completions_handler.py index a238531d2e0..3ba41705733 100644 --- a/tests/test_litellm/responses/mcp/test_chat_completions_handler.py +++ b/tests/test_litellm/responses/mcp/test_chat_completions_handler.py @@ -90,6 +90,72 @@ def mock_extract(**kwargs): assert 
captured_secret_fields["value"] == {"api_key": "value"} +@pytest.mark.asyncio +async def test_acompletion_with_mcp_passes_mcp_server_auth_headers_to_process_tools( + monkeypatch, +): + """ + Test that MCP auth headers extracted from secret_fields (e.g. x-mcp-linear_config-authorization) + are passed to _process_mcp_tools_without_openai_transform for dynamic auth when fetching tools. + """ + tools = [{"type": "mcp", "server_url": "litellm_proxy"}] + mock_acompletion = AsyncMock(return_value="ok") + + captured_process_kwargs = {} + + async def mock_process(**kwargs): + captured_process_kwargs.update(kwargs) + return ([], {}) + + monkeypatch.setattr( + LiteLLM_Proxy_MCP_Handler, + "_should_use_litellm_mcp_gateway", + staticmethod(lambda t: True), + ) + monkeypatch.setattr( + LiteLLM_Proxy_MCP_Handler, + "_parse_mcp_tools", + staticmethod(lambda t: (t, [])), + ) + monkeypatch.setattr( + LiteLLM_Proxy_MCP_Handler, + "_process_mcp_tools_without_openai_transform", + mock_process, + ) + monkeypatch.setattr( + LiteLLM_Proxy_MCP_Handler, + "_transform_mcp_tools_to_openai", + staticmethod(lambda *_, **__: ["openai-tool"]), + ) + monkeypatch.setattr( + LiteLLM_Proxy_MCP_Handler, + "_should_auto_execute_tools", + staticmethod(lambda **_: False), + ) + + # secret_fields with raw_headers containing MCP auth - extract_mcp_headers_from_request + # will parse these and pass to _process_mcp_tools_without_openai_transform + secret_fields = { + "raw_headers": { + "x-mcp-linear_config-authorization": "Bearer linear-token", + }, + } + + with patch("litellm.acompletion", mock_acompletion): + await acompletion_with_mcp( + model="test-model", + messages=[], + tools=tools, + secret_fields=secret_fields, + ) + + assert "mcp_server_auth_headers" in captured_process_kwargs + mcp_server_auth_headers = captured_process_kwargs["mcp_server_auth_headers"] + assert mcp_server_auth_headers is not None + assert "linear_config" in mcp_server_auth_headers + assert 
mcp_server_auth_headers["linear_config"]["Authorization"] == "Bearer linear-token" + + @pytest.mark.asyncio async def test_acompletion_with_mcp_auto_exec_performs_follow_up(monkeypatch): from litellm.utils import CustomStreamWrapper diff --git a/ui/litellm-dashboard/package-lock.json b/ui/litellm-dashboard/package-lock.json index fc2aa1599d3..503ed4a62a8 100644 --- a/ui/litellm-dashboard/package-lock.json +++ b/ui/litellm-dashboard/package-lock.json @@ -90,7 +90,6 @@ "version": "5.2.0", "resolved": "https://registry.npmjs.org/@alloc/quick-lru/-/quick-lru-5.2.0.tgz", "integrity": "sha512-UrcABB+4bUrFABwbluTIBErXwvbsU/V7TZWfmbgJfbkwiBuziS9gxdODUyuiecfdGQ85jglMW6juS3+z5TsKLw==", - "dev": true, "license": "MIT", "engines": { "node": ">=10" @@ -1772,7 +1771,6 @@ "version": "0.3.13", "resolved": "https://registry.npmjs.org/@jridgewell/gen-mapping/-/gen-mapping-0.3.13.tgz", "integrity": "sha512-2kkt/7niJ6MgEPxF0bYdQ6etZaA+fQvDcLKckhy1yIQOzaoKjBBjSj63/aLVjYE3qhRt5dvM+uUyfCg6UKCBbA==", - "dev": true, "license": "MIT", "dependencies": { "@jridgewell/sourcemap-codec": "^1.5.0", @@ -1783,7 +1781,6 @@ "version": "3.1.2", "resolved": "https://registry.npmjs.org/@jridgewell/resolve-uri/-/resolve-uri-3.1.2.tgz", "integrity": "sha512-bRISgCIjP20/tbWSPWMEi54QVPRZExkuD9lJL+UIxUKtwVJA8wW1Trb1jMs1RFXo1CBTNZ/5hpC9QvmKWdopKw==", - "dev": true, "license": "MIT", "engines": { "node": ">=6.0.0" @@ -1793,14 +1790,12 @@ "version": "1.5.5", "resolved": "https://registry.npmjs.org/@jridgewell/sourcemap-codec/-/sourcemap-codec-1.5.5.tgz", "integrity": "sha512-cYQ9310grqxueWbl+WuIUIaiUaDcj7WOq5fVhEljNVgRfOUhY9fy2zTvfoqWsnebh8Sl70VScFbICvJnLKB0Og==", - "dev": true, "license": "MIT" }, "node_modules/@jridgewell/trace-mapping": { "version": "0.3.31", "resolved": "https://registry.npmjs.org/@jridgewell/trace-mapping/-/trace-mapping-0.3.31.tgz", "integrity": "sha512-zzNR+SdQSDJzc8joaeP8QQoCQr8NuYx2dIIytl1QeBEZHJ9uW6hebsrYgbz8hJwUQao3TWCMtmfV8Nu1twOLAw==", - "dev": true, "license": "MIT", 
"dependencies": { "@jridgewell/resolve-uri": "^3.1.0", @@ -1978,7 +1973,6 @@ "version": "2.1.5", "resolved": "https://registry.npmjs.org/@nodelib/fs.scandir/-/fs.scandir-2.1.5.tgz", "integrity": "sha512-vq24Bq3ym5HEQm2NKCr3yXDwjc7vTsEThRDnkp2DK9p1uqLR+DHurm/NOTo0KG7HYHU7eppKZj3MyqYuMBf62g==", - "dev": true, "license": "MIT", "dependencies": { "@nodelib/fs.stat": "2.0.5", @@ -1992,7 +1986,6 @@ "version": "2.0.5", "resolved": "https://registry.npmjs.org/@nodelib/fs.stat/-/fs.stat-2.0.5.tgz", "integrity": "sha512-RkhPPp2zrqDAQA/2jNhnztcPAlv64XdhIp7a7454A5ovI7Bukxgt7MX7udwAu3zg1DcpPU0rz3VV1SeaqvY4+A==", - "dev": true, "license": "MIT", "engines": { "node": ">= 8" @@ -2002,7 +1995,6 @@ "version": "1.2.8", "resolved": "https://registry.npmjs.org/@nodelib/fs.walk/-/fs.walk-1.2.8.tgz", "integrity": "sha512-oGB+UxlgWcgQkgwo8GcEGwemoTFt3FIO9ababBmaGwXIoBKZ+GTy0pP185beGg7Llih/NSHSV2XAs1lnznocSg==", - "dev": true, "license": "MIT", "dependencies": { "@nodelib/fs.scandir": "2.1.5", @@ -2326,7 +2318,7 @@ "version": "1.58.1", "resolved": "https://registry.npmjs.org/@playwright/test/-/test-1.58.1.tgz", "integrity": "sha512-6LdVIUERWxQMmUSSQi0I53GgCBYgM2RpGngCPY7hSeju+VrKjq3lvs7HpJoPbDiY5QM5EYRtRX5fvrinnMAz3w==", - "dev": true, + "devOptional": true, "license": "Apache-2.0", "dependencies": { "playwright": "1.58.1" @@ -3431,14 +3423,12 @@ "version": "15.7.15", "resolved": "https://registry.npmjs.org/@types/prop-types/-/prop-types-15.7.15.tgz", "integrity": "sha512-F6bEyamV9jKGAFBEmlQnesRPGOQqS2+Uwi0Em15xenOxHaf2hv6L8YCVn3rPdPJOiJfPiCnLIRyvwVaqMY3MIw==", - "dev": true, "license": "MIT" }, "node_modules/@types/react": { "version": "18.2.48", "resolved": "https://registry.npmjs.org/@types/react/-/react-18.2.48.tgz", "integrity": "sha512-qboRCl6Ie70DQQG9hhNREz81jqC1cs9EVNcjQ1AU+jH6NFfSAhVVbrrY/+nSF+Bsk4AOwm9Qa61InvMCyV+H3w==", - "dev": true, "license": "MIT", "dependencies": { "@types/prop-types": "*", @@ -3480,7 +3470,6 @@ "version": "0.26.0", "resolved": 
"https://registry.npmjs.org/@types/scheduler/-/scheduler-0.26.0.tgz", "integrity": "sha512-WFHp9YUJQ6CKshqoC37iOlHnQSmxNc795UhB26CyBBttrN9svdIrUjl/NjnNmfcwtncN0h/0PPAFWv9ovP8mLA==", - "dev": true, "license": "MIT" }, "node_modules/@types/unist": { @@ -4341,14 +4330,12 @@ "version": "1.3.0", "resolved": "https://registry.npmjs.org/any-promise/-/any-promise-1.3.0.tgz", "integrity": "sha512-7UvmKalWRt1wgjL1RrGxoSJW/0QZFIegpeGvZG9kjp8vrRu55XTHbwnqq2GpXm9uLbcuhxm3IqX9OB4MZR1b2A==", - "dev": true, "license": "MIT" }, "node_modules/anymatch": { "version": "3.1.3", "resolved": "https://registry.npmjs.org/anymatch/-/anymatch-3.1.3.tgz", "integrity": "sha512-KMReFUr0B4t+D+OBkjR3KYqvocp2XaSzO55UcB6mgQMd3KbcE+mWTyvVV7D/zsdEbNnV6acZUutkiHQXvTr1Rw==", - "dev": true, "license": "ISC", "dependencies": { "normalize-path": "^3.0.0", @@ -4362,7 +4349,6 @@ "version": "2.3.1", "resolved": "https://registry.npmjs.org/picomatch/-/picomatch-2.3.1.tgz", "integrity": "sha512-JU3teHTNjmE2VCGFzuY8EXzCDVwEqB2a8fsIvwaStHhAWJEeVd1o1QD80CU6+ZdEXXSLbSsuLwJjkCBWqRQUVA==", - "dev": true, "license": "MIT", "engines": { "node": ">=8.6" @@ -4375,7 +4361,6 @@ "version": "5.0.2", "resolved": "https://registry.npmjs.org/arg/-/arg-5.0.2.tgz", "integrity": "sha512-PYjyFOLKQ9y57JvQ6QLo8dAgNqswh8M1RMJYdQduT6xbWSgK36P/Z/v+p888pM69jMMfS8Xd8F6I1kQ/I9HUGg==", - "dev": true, "license": "MIT" }, "node_modules/argparse": { @@ -4747,7 +4732,6 @@ "version": "2.3.0", "resolved": "https://registry.npmjs.org/binary-extensions/-/binary-extensions-2.3.0.tgz", "integrity": "sha512-Ceh+7ox5qe7LJuLHoY0feh3pHuUDHAcRUeyL2VYghZwfpkNIy/+8Ocg0a3UuSoYzavmylwuLWQOf3hl0jjMMIw==", - "dev": true, "license": "MIT", "engines": { "node": ">=8" @@ -4773,7 +4757,6 @@ "version": "3.0.3", "resolved": "https://registry.npmjs.org/braces/-/braces-3.0.3.tgz", "integrity": "sha512-yQbXgO/OSZVD2IsiLlro+7Hf6Q18EJrKSEsdoMzKePKXct3gvD8oLcOQdIzGupr5Fj+EDe8gO/lxc1BzfMpxvA==", - "dev": true, "license": "MIT", "dependencies": { "fill-range": "^7.1.1" @@ 
-4889,7 +4872,6 @@ "version": "2.0.1", "resolved": "https://registry.npmjs.org/camelcase-css/-/camelcase-css-2.0.1.tgz", "integrity": "sha512-QOSvevhslijgYwRx6Rv7zKdMF8lbRmx+uQGx2+vDc+KI/eBnsy9kit5aj23AgGu3pa4t9AgwbnXWqS+iOY+2aA==", - "dev": true, "license": "MIT", "engines": { "node": ">= 6" @@ -5013,7 +4995,6 @@ "version": "3.6.0", "resolved": "https://registry.npmjs.org/chokidar/-/chokidar-3.6.0.tgz", "integrity": "sha512-7VT13fmjotKpGipCW9JEQAusEPE+Ei8nl6/g4FBAmIm0GOOLMua9NDDo/DWp0ZAxCr3cPq5ZpBqmPAQgDda2Pw==", - "dev": true, "license": "MIT", "dependencies": { "anymatch": "~3.1.2", @@ -5038,7 +5019,6 @@ "version": "5.1.2", "resolved": "https://registry.npmjs.org/glob-parent/-/glob-parent-5.1.2.tgz", "integrity": "sha512-AOIgSQCepiJYwP3ARnGx+5VnTu2HBYdzbGP45eLw1vr3zB3vZLeyed1sC9hnbcOc9/SrMyM5RPQrkGz4aS9Zow==", - "dev": true, "license": "ISC", "dependencies": { "is-glob": "^4.0.1" @@ -5114,7 +5094,6 @@ "version": "4.1.1", "resolved": "https://registry.npmjs.org/commander/-/commander-4.1.1.tgz", "integrity": "sha512-NOKm8xhkzAjzFx8B2v5OAHT+u5pRQc2UCa2Vq9jYL/31o2wi9mxBA7LIFs3sV5VSC49z6pEhfbMULvShKj26WA==", - "dev": true, "license": "MIT", "engines": { "node": ">= 6" @@ -5175,7 +5154,6 @@ "version": "3.0.0", "resolved": "https://registry.npmjs.org/cssesc/-/cssesc-3.0.0.tgz", "integrity": "sha512-/Tb/JcjK111nNScGob5MNtsntNM1aCNUDipB/TkwZFhyDrrE47SOx/18wF2bbjgc3ZzCSKW1T5nt5EbFoAz/Vg==", - "dev": true, "license": "MIT", "bin": { "cssesc": "bin/cssesc" @@ -5589,14 +5567,12 @@ "version": "1.2.2", "resolved": "https://registry.npmjs.org/didyoumean/-/didyoumean-1.2.2.tgz", "integrity": "sha512-gxtyfqMg7GKyhQmb056K7M3xszy/myH8w+B4RT+QXBQsvAOdc3XymqDDPHx1BgPgsdAA5SIifona89YtRATDzw==", - "dev": true, "license": "Apache-2.0" }, "node_modules/dlv": { "version": "1.1.3", "resolved": "https://registry.npmjs.org/dlv/-/dlv-1.1.3.tgz", "integrity": "sha512-+HlytyjlPKnIG8XuRG8WvmBP8xs8P71y+SKKS6ZXWoEgLuePxtDoUEiH7WkdePWrQ5JBpE6aoVqfZfJUQkjXwA==", - "dev": true, "license": "MIT" }, 
"node_modules/doctrine": { @@ -6510,7 +6486,6 @@ "version": "1.20.1", "resolved": "https://registry.npmjs.org/fastq/-/fastq-1.20.1.tgz", "integrity": "sha512-GGToxJ/w1x32s/D2EKND7kTil4n8OVk/9mycTc4VDza13lOvpUZTGX3mFSCtV9ksdGBVzvsyAVLM6mHFThxXxw==", - "dev": true, "license": "ISC", "dependencies": { "reusify": "^1.0.4" @@ -6543,7 +6518,6 @@ "version": "6.5.0", "resolved": "https://registry.npmjs.org/fdir/-/fdir-6.5.0.tgz", "integrity": "sha512-tIbYtZbucOs0BRGqPJkshJUYdL+SDH7dVM8gjy+ERp3WAUjLEFJE+02kanyHtwjWOnwrKYBiwAmM0p4kLJAnXg==", - "dev": true, "license": "MIT", "engines": { "node": ">=12.0.0" @@ -6581,7 +6555,6 @@ "version": "7.1.1", "resolved": "https://registry.npmjs.org/fill-range/-/fill-range-7.1.1.tgz", "integrity": "sha512-YsGpe3WHLK8ZYi4tWDg2Jy3ebRz2rXowDxnld4bkQB00cc/1Zw9AWnC0i9ztDJitivtQvaI9KaLyKrc+hBW0yg==", - "dev": true, "license": "MIT", "dependencies": { "to-regex-range": "^5.0.1" @@ -6742,7 +6715,6 @@ "version": "2.3.2", "resolved": "https://registry.npmjs.org/fsevents/-/fsevents-2.3.2.tgz", "integrity": "sha512-xiqMQR4xAeHTuB9uWm+fFRcIOgKBMiOBP+eXiyT7jsgVCq1bkVygt00oASowB7EdtpOHaaPgKt812P9ab+DDKA==", - "dev": true, "hasInstallScript": true, "license": "MIT", "optional": true, @@ -6893,7 +6865,6 @@ "version": "6.0.2", "resolved": "https://registry.npmjs.org/glob-parent/-/glob-parent-6.0.2.tgz", "integrity": "sha512-XxwI8EOhVQgWp6iDL+3b0r86f4d6AX6zSU55HfB4ydCEuXLXc5FcYeOu+nnGftS4TEju/11rt4KJPTMgbfmv4A==", - "dev": true, "license": "ISC", "dependencies": { "is-glob": "^4.0.3" @@ -7391,7 +7362,6 @@ "version": "2.1.0", "resolved": "https://registry.npmjs.org/is-binary-path/-/is-binary-path-2.1.0.tgz", "integrity": "sha512-ZMERYes6pDydyuGidse7OsHxtbI7WVeUEozgR/g7rd0xUimYNlvZRE/K2MgZTjWy725IfelLeVcEM97mmtRGXw==", - "dev": true, "license": "MIT", "dependencies": { "binary-extensions": "^2.0.0" @@ -7444,7 +7414,6 @@ "version": "2.16.1", "resolved": "https://registry.npmjs.org/is-core-module/-/is-core-module-2.16.1.tgz", "integrity": 
"sha512-UfoeMA6fIJ8wTYFEUjelnaGI67v6+N7qXJEvQuIGa99l4xsCruSYOVSQ0uPANn4dAzm8lkYPaKLrrijLq7x23w==", - "dev": true, "license": "MIT", "dependencies": { "hasown": "^2.0.2" @@ -7505,7 +7474,6 @@ "version": "2.1.1", "resolved": "https://registry.npmjs.org/is-extglob/-/is-extglob-2.1.1.tgz", "integrity": "sha512-SbKbANkN603Vi4jEZv49LeVJMn4yGwsbzZworEoyEiutsN3nJYdbO36zfhGJ6QEDpOZIFkDtnq5JRxmvl3jsoQ==", - "dev": true, "license": "MIT", "engines": { "node": ">=0.10.0" @@ -7551,7 +7519,6 @@ "version": "4.0.3", "resolved": "https://registry.npmjs.org/is-glob/-/is-glob-4.0.3.tgz", "integrity": "sha512-xelSayHH36ZgE7ZWhli7pW34hNbNl8Ojv5KVmkJD4hBdD3th8Tfk9vYasLM+mXWOZhFkgZfxhLSnrwRr4elSSg==", - "dev": true, "license": "MIT", "dependencies": { "is-extglob": "^2.1.1" @@ -7600,7 +7567,6 @@ "version": "7.0.0", "resolved": "https://registry.npmjs.org/is-number/-/is-number-7.0.0.tgz", "integrity": "sha512-41Cifkg6e8TylSpdtTpeLVMqvSBEVzTttHvERD741+pnZ8ANv0004MRL43QKPDlK9cGvNp6NZWZUBlbGXYxxng==", - "dev": true, "license": "MIT", "engines": { "node": ">=0.12.0" @@ -7877,7 +7843,6 @@ "version": "1.21.7", "resolved": "https://registry.npmjs.org/jiti/-/jiti-1.21.7.tgz", "integrity": "sha512-/imKNG4EbWNrVjoNC/1H5/9GFy+tqjGBHCaSsN+P2RnPqjsLmv6UD3Ej+Kj8nBWaRAwyk7kK5ZUc+OEatnTR3A==", - "dev": true, "license": "MIT", "bin": { "jiti": "bin/jiti.js" @@ -8163,7 +8128,6 @@ "version": "3.1.3", "resolved": "https://registry.npmjs.org/lilconfig/-/lilconfig-3.1.3.tgz", "integrity": "sha512-/vlFKAoH5Cgt3Ie+JLhRbwOsCQePABiU3tJ1egGvyQ+33R/vcwM2Zl2QR/LzjsBeItPt3oSVXapn+m4nQDvpzw==", - "dev": true, "license": "MIT", "engines": { "node": ">=14" @@ -8176,7 +8140,6 @@ "version": "1.2.4", "resolved": "https://registry.npmjs.org/lines-and-columns/-/lines-and-columns-1.2.4.tgz", "integrity": "sha512-7ylylesZQ/PV29jhEDl3Ufjo6ZX7gCqJr5F7PKrqc93v7fzSymt1BpwEU8nAUXs8qzzvqhbjhK5QZg6Mt/HkBg==", - "dev": true, "license": "MIT" }, "node_modules/locate-path": { @@ -8491,7 +8454,6 @@ "version": "1.4.1", "resolved": 
"https://registry.npmjs.org/merge2/-/merge2-1.4.1.tgz", "integrity": "sha512-8q7VEgMJW4J8tcfVPy8g09NcQwZdbwFEqhe/WZkoIzjn/3TGDwtOCYtXGxA3O8tPzpczCCDgv+P2P5y00ZJOOg==", - "dev": true, "license": "MIT", "engines": { "node": ">= 8" @@ -8943,7 +8905,6 @@ "version": "4.0.8", "resolved": "https://registry.npmjs.org/micromatch/-/micromatch-4.0.8.tgz", "integrity": "sha512-PXwfBhYu0hBCPw8Dn0E+WDYb7af3dSLVWKi3HGv84IdF4TyFoC0ysxFd0Goxw7nSv4T/PzEJQxsYsEiFCKo2BA==", - "dev": true, "license": "MIT", "dependencies": { "braces": "^3.0.3", @@ -8957,7 +8918,6 @@ "version": "2.3.1", "resolved": "https://registry.npmjs.org/picomatch/-/picomatch-2.3.1.tgz", "integrity": "sha512-JU3teHTNjmE2VCGFzuY8EXzCDVwEqB2a8fsIvwaStHhAWJEeVd1o1QD80CU6+ZdEXXSLbSsuLwJjkCBWqRQUVA==", - "dev": true, "license": "MIT", "engines": { "node": ">=8.6" @@ -9072,7 +9032,6 @@ "version": "2.7.0", "resolved": "https://registry.npmjs.org/mz/-/mz-2.7.0.tgz", "integrity": "sha512-z81GNO7nnYMEhrGh9LeymoE4+Yr0Wn5McHIZMK5cfQCl+NDX08sCZgUc9/6MHni9IWuFLm1Z3HTCXu2z9fN62Q==", - "dev": true, "license": "MIT", "dependencies": { "any-promise": "^1.0.0", @@ -9284,7 +9243,6 @@ "version": "3.0.0", "resolved": "https://registry.npmjs.org/normalize-path/-/normalize-path-3.0.0.tgz", "integrity": "sha512-6eZs5Ls3WtCisHWp9S2GUy8dqkpGi4BVSz3GaqiE6ezub0512ESztXUwUB6C6IKbQkY2Pnb/mD4WYojCRwcwLA==", - "dev": true, "license": "MIT", "engines": { "node": ">=0.10.0" @@ -9303,7 +9261,6 @@ "version": "3.0.0", "resolved": "https://registry.npmjs.org/object-hash/-/object-hash-3.0.0.tgz", "integrity": "sha512-RSn9F68PjH9HqtltsSnqYC1XXoWe9Bju5+213R98cNGttag9q9yAOTzdbsqvIa7aNm5WffBZFpWYr2aWrklWAw==", - "dev": true, "license": "MIT", "engines": { "node": ">= 6" @@ -9648,7 +9605,6 @@ "version": "1.0.7", "resolved": "https://registry.npmjs.org/path-parse/-/path-parse-1.0.7.tgz", "integrity": "sha512-LDJzPVEEEPR+y48z93A0Ed0yXb8pAByGWo/k5YYdYgpY2/2EsOsksJrq7lOHxryrVOn1ejG6oAp8ahvOIQD8sw==", - "dev": true, "license": "MIT" }, "node_modules/path-scurry": 
{ @@ -9695,7 +9651,6 @@ "version": "4.0.3", "resolved": "https://registry.npmjs.org/picomatch/-/picomatch-4.0.3.tgz", "integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==", - "dev": true, "license": "MIT", "engines": { "node": ">=12" @@ -9708,7 +9663,6 @@ "version": "2.3.0", "resolved": "https://registry.npmjs.org/pify/-/pify-2.3.0.tgz", "integrity": "sha512-udgsAY+fTnvv7kI7aaxbqwWNb0AHiB0qBO89PZKPkoTmGOgdbrHDKD+0B2X4uTfJ/FT1R09r9gTsjUjNJotuog==", - "dev": true, "license": "MIT", "engines": { "node": ">=0.10.0" @@ -9718,7 +9672,6 @@ "version": "4.0.7", "resolved": "https://registry.npmjs.org/pirates/-/pirates-4.0.7.tgz", "integrity": "sha512-TfySrs/5nm8fQJDcBDuUng3VOUKsd7S+zqvbOTiGXHfxX4wK31ard+hoNuvkicM/2YFzlpDgABOevKSsB4G/FA==", - "dev": true, "license": "MIT", "engines": { "node": ">= 6" @@ -9728,7 +9681,7 @@ "version": "1.58.1", "resolved": "https://registry.npmjs.org/playwright/-/playwright-1.58.1.tgz", "integrity": "sha512-+2uTZHxSCcxjvGc5C891LrS1/NlxglGxzrC4seZiVjcYVQfUa87wBL6rTDqzGjuoWNjnBzRqKmF6zRYGMvQUaQ==", - "dev": true, + "devOptional": true, "license": "Apache-2.0", "dependencies": { "playwright-core": "1.58.1" @@ -9747,7 +9700,7 @@ "version": "1.58.1", "resolved": "https://registry.npmjs.org/playwright-core/-/playwright-core-1.58.1.tgz", "integrity": "sha512-bcWzOaTxcW+VOOGBCQgnaKToLJ65d6AqfLVKEWvexyS3AS6rbXl+xdpYRMGSRBClPvyj44njOWoxjNdL/H9UNg==", - "dev": true, + "devOptional": true, "license": "Apache-2.0", "bin": { "playwright-core": "cli.js" @@ -9770,7 +9723,6 @@ "version": "8.5.6", "resolved": "https://registry.npmjs.org/postcss/-/postcss-8.5.6.tgz", "integrity": "sha512-3Ybi1tAuwAP9s0r1UQ2J4n5Y0G05bJkpUIO0/bI9MhwmD70S5aTWbXGBwxHrelT+XM1k6dM0pk+SwNkpTRN7Pg==", - "dev": true, "funding": [ { "type": "opencollective", @@ -9799,7 +9751,6 @@ "version": "15.1.0", "resolved": "https://registry.npmjs.org/postcss-import/-/postcss-import-15.1.0.tgz", "integrity": 
"sha512-hpr+J05B2FVYUAXHeK1YyI267J/dDDhMU6B6civm8hSY1jYJnBXxzKDKDswzJmtLHryrjhnDjqqp/49t8FALew==", - "dev": true, "license": "MIT", "dependencies": { "postcss-value-parser": "^4.0.0", @@ -9817,7 +9768,6 @@ "version": "4.1.0", "resolved": "https://registry.npmjs.org/postcss-js/-/postcss-js-4.1.0.tgz", "integrity": "sha512-oIAOTqgIo7q2EOwbhb8UalYePMvYoIeRY2YKntdpFQXNosSu3vLrniGgmH9OKs/qAkfoj5oB3le/7mINW1LCfw==", - "dev": true, "funding": [ { "type": "opencollective", @@ -9843,7 +9793,6 @@ "version": "6.0.1", "resolved": "https://registry.npmjs.org/postcss-load-config/-/postcss-load-config-6.0.1.tgz", "integrity": "sha512-oPtTM4oerL+UXmx+93ytZVN82RrlY/wPUV8IeDxFrzIjXOLF1pN+EmKPLbubvKHT2HC20xXsCAH2Z+CKV6Oz/g==", - "dev": true, "funding": [ { "type": "opencollective", @@ -9886,7 +9835,6 @@ "version": "6.2.0", "resolved": "https://registry.npmjs.org/postcss-nested/-/postcss-nested-6.2.0.tgz", "integrity": "sha512-HQbt28KulC5AJzG+cZtj9kvKB93CFCdLvog1WFLf1D+xmMvPGlBstkpTEZfK5+AN9hfJocyBFCNiqyS48bpgzQ==", - "dev": true, "funding": [ { "type": "opencollective", @@ -9912,7 +9860,6 @@ "version": "6.1.2", "resolved": "https://registry.npmjs.org/postcss-selector-parser/-/postcss-selector-parser-6.1.2.tgz", "integrity": "sha512-Q8qQfPiZ+THO/3ZrOrO0cJJKfpYCagtMUkXbnEfmgUjwXg6z/WBeOyS9APBBPCTSiDV+s4SwQGu8yFsiMRIudg==", - "dev": true, "license": "MIT", "dependencies": { "cssesc": "^3.0.0", @@ -9926,7 +9873,6 @@ "version": "4.2.0", "resolved": "https://registry.npmjs.org/postcss-value-parser/-/postcss-value-parser-4.2.0.tgz", "integrity": "sha512-1NNCs6uurfkVbeXG4S8JFT9t19m45ICnif8zWLd5oPSZ50QnwMfK+H3jv408d4jw/7Bttv5axS5IiHoLaVNHeQ==", - "dev": true, "license": "MIT" }, "node_modules/prelude-ls": { @@ -10040,7 +9986,6 @@ "version": "1.2.3", "resolved": "https://registry.npmjs.org/queue-microtask/-/queue-microtask-1.2.3.tgz", "integrity": "sha512-NuaNSa6flKT5JaSYQzJok04JzTL1CA6aGhv5rfLW3PgqA+M2ChpZQnAC8h8i4ZFkBS8X5RqkDBHA7r4hej3K9A==", - "dev": true, "funding": [ { "type": "github", 
@@ -10829,7 +10774,6 @@ "version": "1.0.0", "resolved": "https://registry.npmjs.org/read-cache/-/read-cache-1.0.0.tgz", "integrity": "sha512-Owdv/Ft7IjOgm/i0xvNDZ1LrRANRfew4b2prF3OWMQLxLfu3bS8FVhCsrSCMK4lR56Y9ya+AThoTpDCTxCmpRA==", - "dev": true, "license": "MIT", "dependencies": { "pify": "^2.3.0" @@ -10839,7 +10783,6 @@ "version": "3.6.0", "resolved": "https://registry.npmjs.org/readdirp/-/readdirp-3.6.0.tgz", "integrity": "sha512-hOS089on8RduqdbhvQ5Z37A0ESjsqz6qnRcffsMU3495FuTdqSm+7bhJ29JvIOsBDEEnan5DPu9t3To9VRlMzA==", - "dev": true, "license": "MIT", "dependencies": { "picomatch": "^2.2.1" @@ -10852,7 +10795,6 @@ "version": "2.3.1", "resolved": "https://registry.npmjs.org/picomatch/-/picomatch-2.3.1.tgz", "integrity": "sha512-JU3teHTNjmE2VCGFzuY8EXzCDVwEqB2a8fsIvwaStHhAWJEeVd1o1QD80CU6+ZdEXXSLbSsuLwJjkCBWqRQUVA==", - "dev": true, "license": "MIT", "engines": { "node": ">=8.6" @@ -11117,7 +11059,6 @@ "version": "1.22.11", "resolved": "https://registry.npmjs.org/resolve/-/resolve-1.22.11.tgz", "integrity": "sha512-RfqAvLnMl313r7c9oclB1HhUEAezcpLjz95wFH4LVuhk9JF/r22qmVP9AMmOU4vMX7Q8pN8jwNg/CSpdFnMjTQ==", - "dev": true, "license": "MIT", "dependencies": { "is-core-module": "^2.16.1", @@ -11158,7 +11099,6 @@ "version": "1.1.0", "resolved": "https://registry.npmjs.org/reusify/-/reusify-1.1.0.tgz", "integrity": "sha512-g6QUff04oZpHs0eG5p83rFLhHeV00ug/Yf9nZM6fLeUrPguBTkTQOdpAWWspMh55TZfVQDPaN3NQJfbVRAxdIw==", - "dev": true, "license": "MIT", "engines": { "iojs": ">=1.0.0", @@ -11214,7 +11154,6 @@ "version": "1.2.0", "resolved": "https://registry.npmjs.org/run-parallel/-/run-parallel-1.2.0.tgz", "integrity": "sha512-5l4VyZR86LZ/lDxZTR6jqL8AFE2S0IFLMP26AbjsLVADxHdhB/c0GUsH+y39UfCi3dzz8OlQuPmnaJOMoDHQBA==", - "dev": true, "funding": [ { "type": "github", @@ -11855,7 +11794,6 @@ "version": "3.35.1", "resolved": "https://registry.npmjs.org/sucrase/-/sucrase-3.35.1.tgz", "integrity": 
"sha512-DhuTmvZWux4H1UOnWMB3sk0sbaCVOoQZjv8u1rDoTV0HTdGem9hkAZtl4JZy8P2z4Bg0nT+YMeOFyVr4zcG5Tw==", - "dev": true, "license": "MIT", "dependencies": { "@jridgewell/gen-mapping": "^0.3.2", @@ -11891,7 +11829,6 @@ "version": "1.0.0", "resolved": "https://registry.npmjs.org/supports-preserve-symlinks-flag/-/supports-preserve-symlinks-flag-1.0.0.tgz", "integrity": "sha512-ot0WnXS9fgdkgIcePe6RHNk1WA8+muPa6cSjeR3V8K27q9BB1rTE3R1p7Hv0z1ZyAc8s6Vvv8DIyWf681MAt0w==", - "dev": true, "license": "MIT", "engines": { "node": ">= 0.4" @@ -11927,7 +11864,6 @@ "version": "3.4.19", "resolved": "https://registry.npmjs.org/tailwindcss/-/tailwindcss-3.4.19.tgz", "integrity": "sha512-3ofp+LL8E+pK/JuPLPggVAIaEuhvIz4qNcf3nA1Xn2o/7fb7s/TYpHhwGDv1ZU3PkBluUVaF8PyCHcm48cKLWQ==", - "dev": true, "license": "MIT", "dependencies": { "@alloc/quick-lru": "^5.2.0", @@ -11965,7 +11901,6 @@ "version": "3.3.3", "resolved": "https://registry.npmjs.org/fast-glob/-/fast-glob-3.3.3.tgz", "integrity": "sha512-7MptL8U0cqcFdzIzwOTHoilX9x5BrNqye7Z/LuC7kCMRio1EMSyqRK3BEAUD7sXRq4iT4AzTVuZdhgQ2TCvYLg==", - "dev": true, "license": "MIT", "dependencies": { "@nodelib/fs.stat": "^2.0.2", @@ -11982,7 +11917,6 @@ "version": "5.1.2", "resolved": "https://registry.npmjs.org/glob-parent/-/glob-parent-5.1.2.tgz", "integrity": "sha512-AOIgSQCepiJYwP3ARnGx+5VnTu2HBYdzbGP45eLw1vr3zB3vZLeyed1sC9hnbcOc9/SrMyM5RPQrkGz4aS9Zow==", - "dev": true, "license": "ISC", "dependencies": { "is-glob": "^4.0.1" @@ -12010,7 +11944,6 @@ "version": "3.3.1", "resolved": "https://registry.npmjs.org/thenify/-/thenify-3.3.1.tgz", "integrity": "sha512-RVZSIV5IG10Hk3enotrhvz0T9em6cyHBLkH/YAZuKqd8hRkKhSfCGIcP2KUY0EPxndzANBmNllzWPwak+bheSw==", - "dev": true, "license": "MIT", "dependencies": { "any-promise": "^1.0.0" @@ -12020,7 +11953,6 @@ "version": "1.6.0", "resolved": "https://registry.npmjs.org/thenify-all/-/thenify-all-1.6.0.tgz", "integrity": "sha512-RNxQH/qI8/t3thXJDwcstUO4zeqo64+Uy/+sNVRBx4Xn2OX+OZ9oP+iJnNFqplFra2ZUVeKCSa2oVWi3T4uVmA==", - 
"dev": true, "license": "MIT", "dependencies": { "thenify": ">= 3.1.0 < 4" @@ -12062,7 +11994,6 @@ "version": "0.2.15", "resolved": "https://registry.npmjs.org/tinyglobby/-/tinyglobby-0.2.15.tgz", "integrity": "sha512-j2Zq4NyQYG5XMST4cbs02Ak8iJUdxRM0XI5QyxXuZOzKOINmWurp3smXu3y5wDcJrptwpSjgXHzIQxR0omXljQ==", - "dev": true, "license": "MIT", "dependencies": { "fdir": "^6.5.0", @@ -12129,7 +12060,6 @@ "version": "5.0.1", "resolved": "https://registry.npmjs.org/to-regex-range/-/to-regex-range-5.0.1.tgz", "integrity": "sha512-65P7iz6X5yEr1cwcgvQxbbIw7Uk3gOy5dIdtZ4rDveLqhrdJP+Li/Hx6tyK0NEb+2GCyneCMJiGqrADCSNk8sQ==", - "dev": true, "license": "MIT", "dependencies": { "is-number": "^7.0.0" @@ -12217,7 +12147,6 @@ "version": "0.1.13", "resolved": "https://registry.npmjs.org/ts-interface-checker/-/ts-interface-checker-0.1.13.tgz", "integrity": "sha512-Y/arvbn+rrz3JCKl9C4kVNfTfSm2/mEp5FSz5EsZSANGPSlQrpRI5M4PKF+mJnE52jOO90PnPSc3Ur3bTQw0gA==", - "dev": true, "license": "Apache-2.0" }, "node_modules/tsconfig-paths": { @@ -12334,7 +12263,7 @@ "version": "5.3.3", "resolved": "https://registry.npmjs.org/typescript/-/typescript-5.3.3.tgz", "integrity": "sha512-pXWcraxM0uxAS+tN0AG/BF2TyqmHO014Z070UsJ+pFvYuRSq8KH8DmWpnbXe0pEPDHXZV3FcAbJkijJ5oNEnWw==", - "dev": true, + "devOptional": true, "license": "Apache-2.0", "bin": { "tsc": "bin/tsc", @@ -12536,7 +12465,6 @@ "version": "1.0.2", "resolved": "https://registry.npmjs.org/util-deprecate/-/util-deprecate-1.0.2.tgz", "integrity": "sha512-EPD5q1uXyFxJpCrLnCc1nHnq3gOa6DZBocAIiI2TaSCA7VCJ1UJDMagCzIkXNsUYfD1daK//LTEQ8xiIbrHtcw==", - "dev": true, "license": "MIT" }, "node_modules/uuid": { @@ -12990,7 +12918,7 @@ "version": "8.19.0", "resolved": "https://registry.npmjs.org/ws/-/ws-8.19.0.tgz", "integrity": "sha512-blAT2mjOEIi0ZzruJfIhb3nps74PRWTCz1IjglWEEpQl5XS/UNama6u2/rjFkDDouqr4L67ry+1aGIALViWjDg==", - "dev": true, + "devOptional": true, "license": "MIT", "engines": { "node": ">=10.0.0" diff --git 
a/ui/litellm-dashboard/src/app/(dashboard)/hooks/projects/useProjectDetails.ts b/ui/litellm-dashboard/src/app/(dashboard)/hooks/projects/useProjectDetails.ts new file mode 100644 index 00000000000..1d35ac1bf70 --- /dev/null +++ b/ui/litellm-dashboard/src/app/(dashboard)/hooks/projects/useProjectDetails.ts @@ -0,0 +1,63 @@ +import { useQuery, useQueryClient } from "@tanstack/react-query"; +import { + getProxyBaseUrl, + getGlobalLitellmHeaderName, + deriveErrorMessage, + handleError, +} from "@/components/networking"; +import { all_admin_roles } from "@/utils/roles"; +import useAuthorized from "@/app/(dashboard)/hooks/useAuthorized"; +import { ProjectResponse, projectKeys } from "./useProjects"; + +// ── Fetch function ─────────────────────────────────────────────────────────── + +const fetchProjectDetails = async ( + accessToken: string, + projectId: string, +): Promise => { + const baseUrl = getProxyBaseUrl(); + const url = `${baseUrl}/project/info?project_id=${encodeURIComponent(projectId)}`; + + const response = await fetch(url, { + method: "GET", + headers: { + [getGlobalLitellmHeaderName()]: `Bearer ${accessToken}`, + "Content-Type": "application/json", + }, + }); + + if (!response.ok) { + const errorData = await response.json(); + const errorMessage = deriveErrorMessage(errorData); + handleError(errorMessage); + throw new Error(errorMessage); + } + + return response.json(); +}; + +// ── Hook ───────────────────────────────────────────────────────────────────── + +export const useProjectDetails = (projectId?: string) => { + const { accessToken, userRole } = useAuthorized(); + const queryClient = useQueryClient(); + + return useQuery({ + queryKey: projectKeys.detail(projectId!), + queryFn: async () => fetchProjectDetails(accessToken!, projectId!), + enabled: + Boolean(accessToken && projectId) && + all_admin_roles.includes(userRole || ""), + + // Seed from the list cache when available + initialData: () => { + if (!projectId) return undefined; + + const projects 
= queryClient.getQueryData( + projectKeys.list({}), + ); + + return projects?.find((p) => p.project_id === projectId); + }, + }); +}; diff --git a/ui/litellm-dashboard/src/app/(dashboard)/hooks/projects/useUpdateProject.ts b/ui/litellm-dashboard/src/app/(dashboard)/hooks/projects/useUpdateProject.ts new file mode 100644 index 00000000000..e6cd3071f5f --- /dev/null +++ b/ui/litellm-dashboard/src/app/(dashboard)/hooks/projects/useUpdateProject.ts @@ -0,0 +1,75 @@ +import { useMutation, useQueryClient } from "@tanstack/react-query"; +import { + getProxyBaseUrl, + getGlobalLitellmHeaderName, + deriveErrorMessage, + handleError, +} from "@/components/networking"; +import useAuthorized from "@/app/(dashboard)/hooks/useAuthorized"; +import { ProjectResponse, projectKeys } from "./useProjects"; + +// ── Types ──────────────────────────────────────────────────────────────────── + +export interface ProjectUpdateParams { + project_alias?: string; + description?: string; + team_id?: string; + models?: string[]; + max_budget?: number; + blocked?: boolean; + metadata?: Record; + model_rpm_limit?: Record; + model_tpm_limit?: Record; +} + +// ── Fetch function ─────────────────────────────────────────────────────────── + +const updateProject = async ( + accessToken: string, + projectId: string, + params: ProjectUpdateParams, +): Promise => { + const baseUrl = getProxyBaseUrl(); + const url = `${baseUrl}/project/update`; + + const response = await fetch(url, { + method: "POST", + headers: { + [getGlobalLitellmHeaderName()]: `Bearer ${accessToken}`, + "Content-Type": "application/json", + }, + body: JSON.stringify({ project_id: projectId, ...params }), + }); + + if (!response.ok) { + const errorData = await response.json(); + const errorMessage = deriveErrorMessage(errorData); + handleError(errorMessage); + throw new Error(errorMessage); + } + + return response.json(); +}; + +// ── Hook ───────────────────────────────────────────────────────────────────── + +export const 
useUpdateProject = () => { + const { accessToken } = useAuthorized(); + const queryClient = useQueryClient(); + + return useMutation< + ProjectResponse, + Error, + { projectId: string; params: ProjectUpdateParams } + >({ + mutationFn: async ({ projectId, params }) => { + if (!accessToken) { + throw new Error("Access token is required"); + } + return updateProject(accessToken, projectId, params); + }, + onSuccess: () => { + queryClient.invalidateQueries({ queryKey: projectKeys.all }); + }, + }); +}; diff --git a/ui/litellm-dashboard/src/components/CostTrackingSettings/add_margin_form.test.tsx b/ui/litellm-dashboard/src/components/CostTrackingSettings/add_margin_form.test.tsx new file mode 100644 index 00000000000..d61f0987ac2 --- /dev/null +++ b/ui/litellm-dashboard/src/components/CostTrackingSettings/add_margin_form.test.tsx @@ -0,0 +1,148 @@ +import React from "react"; +import { describe, it, expect, vi, beforeEach } from "vitest"; +import { screen } from "@testing-library/react"; +import userEvent from "@testing-library/user-event"; +import { renderWithProviders } from "../../../tests/test-utils"; +import AddMarginForm from "./add_margin_form"; +import { MarginConfig } from "./types"; + +vi.mock("../provider_info_helpers", () => ({ + Providers: { + OpenAI: "OpenAI", + Anthropic: "Anthropic", + }, + provider_map: { + OpenAI: "openai", + Anthropic: "anthropic", + }, + providerLogoMap: { + OpenAI: "https://example.com/openai.png", + Anthropic: "https://example.com/anthropic.png", + }, +})); + +vi.mock("./provider_display_helpers", () => ({ + handleImageError: vi.fn(), +})); + +const DEFAULT_PROPS = { + marginConfig: {} as MarginConfig, + selectedProvider: undefined, + marginType: "percentage" as const, + percentageValue: "", + fixedAmountValue: "", + onProviderChange: vi.fn(), + onMarginTypeChange: vi.fn(), + onPercentageChange: vi.fn(), + onFixedAmountChange: vi.fn(), + onAddProvider: vi.fn(), +}; + +describe("AddMarginForm", () => { + beforeEach(() => { + 
vi.clearAllMocks(); + }); + + it("should render", () => { + renderWithProviders(); + expect(screen.getByRole("button", { name: /add provider margin/i })).toBeInTheDocument(); + }); + + it("should show the percentage input when marginType is percentage", () => { + renderWithProviders(); + expect(screen.getByPlaceholderText("10")).toBeInTheDocument(); + }); + + it("should show the fixed amount input when marginType is fixed", () => { + renderWithProviders(); + expect(screen.getByPlaceholderText("0.001")).toBeInTheDocument(); + }); + + it("should not show the fixed amount input when marginType is percentage", () => { + renderWithProviders(); + expect(screen.queryByPlaceholderText("0.001")).not.toBeInTheDocument(); + }); + + it("should not show the percentage input when marginType is fixed", () => { + renderWithProviders(); + expect(screen.queryByPlaceholderText("10")).not.toBeInTheDocument(); + }); + + it("should show the Percentage-based and Fixed Amount radio options", () => { + renderWithProviders(); + expect(screen.getByText("Percentage-based")).toBeInTheDocument(); + expect(screen.getByText("Fixed Amount")).toBeInTheDocument(); + }); + + it("should disable the submit button when no provider is selected (percentage mode)", () => { + renderWithProviders( + + ); + expect(screen.getByRole("button", { name: /add provider margin/i })).toBeDisabled(); + }); + + it("should disable the submit button when provider is selected but no percentage value (percentage mode)", () => { + renderWithProviders( + + ); + expect(screen.getByRole("button", { name: /add provider margin/i })).toBeDisabled(); + }); + + it("should enable the submit button when provider and percentage value are both provided", () => { + renderWithProviders( + + ); + expect(screen.getByRole("button", { name: /add provider margin/i })).not.toBeDisabled(); + }); + + it("should disable the submit button in fixed mode when no fixed amount is provided", () => { + renderWithProviders( + + ); + 
expect(screen.getByRole("button", { name: /add provider margin/i })).toBeDisabled(); + }); + + it("should enable the submit button in fixed mode when provider and fixed amount are provided", () => { + renderWithProviders( + + ); + expect(screen.getByRole("button", { name: /add provider margin/i })).not.toBeDisabled(); + }); + + it("should call onAddProvider when the enabled submit button is clicked", async () => { + const onAddProvider = vi.fn(); + const user = userEvent.setup(); + renderWithProviders( + + ); + + await user.click(screen.getByRole("button", { name: /add provider margin/i })); + expect(onAddProvider).toHaveBeenCalledTimes(1); + }); + + it("should call onMarginTypeChange when the Fixed Amount radio is clicked", async () => { + const onMarginTypeChange = vi.fn(); + const user = userEvent.setup(); + renderWithProviders( + + ); + + await user.click(screen.getByText("Fixed Amount")); + expect(onMarginTypeChange).toHaveBeenCalledWith("fixed"); + }); +}); diff --git a/ui/litellm-dashboard/src/components/CostTrackingSettings/add_provider_form.test.tsx b/ui/litellm-dashboard/src/components/CostTrackingSettings/add_provider_form.test.tsx new file mode 100644 index 00000000000..611c8609c36 --- /dev/null +++ b/ui/litellm-dashboard/src/components/CostTrackingSettings/add_provider_form.test.tsx @@ -0,0 +1,98 @@ +import React from "react"; +import { describe, it, expect, vi, beforeEach } from "vitest"; +import { screen } from "@testing-library/react"; +import userEvent from "@testing-library/user-event"; +import { renderWithProviders } from "../../../tests/test-utils"; +import AddProviderForm from "./add_provider_form"; +import { DiscountConfig } from "./types"; + +vi.mock("../provider_info_helpers", () => ({ + Providers: { + OpenAI: "OpenAI", + Anthropic: "Anthropic", + }, + provider_map: { + OpenAI: "openai", + Anthropic: "anthropic", + }, + providerLogoMap: { + OpenAI: "https://example.com/openai.png", + Anthropic: "https://example.com/anthropic.png", + }, +})); 
+ +vi.mock("./provider_display_helpers", () => ({ + handleImageError: vi.fn(), +})); + +const DEFAULT_PROPS = { + discountConfig: {} as DiscountConfig, + selectedProvider: undefined, + newDiscount: "", + onProviderChange: vi.fn(), + onDiscountChange: vi.fn(), + onAddProvider: vi.fn(), +}; + +describe("AddProviderForm", () => { + beforeEach(() => { + vi.clearAllMocks(); + }); + + it("should render", () => { + renderWithProviders(); + expect(screen.getByRole("button", { name: /add provider discount/i })).toBeInTheDocument(); + }); + + it("should render the discount percentage input field", () => { + renderWithProviders(); + expect(screen.getByPlaceholderText("5")).toBeInTheDocument(); + }); + + it("should disable the submit button when no provider is selected and no discount is entered", () => { + renderWithProviders(); + expect(screen.getByRole("button", { name: /add provider discount/i })).toBeDisabled(); + }); + + it("should disable the submit button when a provider is selected but no discount is entered", () => { + renderWithProviders( + + ); + expect(screen.getByRole("button", { name: /add provider discount/i })).toBeDisabled(); + }); + + it("should disable the submit button when a discount is entered but no provider is selected", () => { + renderWithProviders( + + ); + expect(screen.getByRole("button", { name: /add provider discount/i })).toBeDisabled(); + }); + + it("should enable the submit button when both a provider and a discount value are provided", () => { + renderWithProviders( + + ); + expect(screen.getByRole("button", { name: /add provider discount/i })).not.toBeDisabled(); + }); + + it("should call onAddProvider when the enabled submit button is clicked", async () => { + const onAddProvider = vi.fn(); + const user = userEvent.setup(); + renderWithProviders( + + ); + + await user.click(screen.getByRole("button", { name: /add provider discount/i })); + expect(onAddProvider).toHaveBeenCalledTimes(1); + }); + + it("should show the percent sign next to 
the discount input", () => { + renderWithProviders(); + expect(screen.getByText("%")).toBeInTheDocument(); + }); +}); diff --git a/ui/litellm-dashboard/src/components/CostTrackingSettings/cost_tracking_settings.test.tsx b/ui/litellm-dashboard/src/components/CostTrackingSettings/cost_tracking_settings.test.tsx new file mode 100644 index 00000000000..db6899ba17f --- /dev/null +++ b/ui/litellm-dashboard/src/components/CostTrackingSettings/cost_tracking_settings.test.tsx @@ -0,0 +1,201 @@ +import React from "react"; +import { describe, it, expect, vi, beforeEach } from "vitest"; +import { screen } from "@testing-library/react"; +import userEvent from "@testing-library/user-event"; +import { renderWithProviders } from "../../../tests/test-utils"; +import CostTrackingSettings from "./cost_tracking_settings"; + +// Mock sub-hooks so we can control their state without network calls +const mockDiscountConfig = vi.fn(() => ({})); +const mockMarginConfig = vi.fn(() => ({})); + +vi.mock("./use_discount_config", () => ({ + useDiscountConfig: () => ({ + discountConfig: mockDiscountConfig(), + fetchDiscountConfig: vi.fn().mockResolvedValue(undefined), + handleAddProvider: vi.fn().mockResolvedValue(true), + handleRemoveProvider: vi.fn().mockResolvedValue(undefined), + handleDiscountChange: vi.fn().mockResolvedValue(undefined), + }), +})); + +vi.mock("./use_margin_config", () => ({ + useMarginConfig: () => ({ + marginConfig: mockMarginConfig(), + fetchMarginConfig: vi.fn().mockResolvedValue(undefined), + handleAddMargin: vi.fn().mockResolvedValue(true), + handleRemoveMargin: vi.fn().mockResolvedValue(undefined), + handleMarginChange: vi.fn().mockResolvedValue(undefined), + }), +})); + +vi.mock("./pricing_calculator/index", () => ({ + default: () =>
Pricing Calculator
, +})); + +vi.mock("../playground/llm_calls/fetch_models", () => ({ + fetchAvailableModels: vi.fn().mockResolvedValue([]), +})); + +vi.mock("../HelpLink", () => ({ + DocsMenu: () => null, +})); + +vi.mock("./how_it_works", () => ({ + default: () =>
How It Works
, +})); + +vi.mock("../provider_info_helpers", () => ({ + Providers: { OpenAI: "OpenAI" }, + provider_map: { OpenAI: "openai" }, + providerLogoMap: {}, +})); + +vi.mock("./provider_display_helpers", () => ({ + getProviderDisplayInfo: vi.fn(() => ({ displayName: "OpenAI", logo: "", enumKey: "OpenAI" })), + handleImageError: vi.fn(), +})); + +const ADMIN_PROPS = { + userID: "user-1", + userRole: "proxy_admin", + accessToken: "test-token", +}; + +describe("CostTrackingSettings", () => { + beforeEach(() => { + vi.clearAllMocks(); + mockDiscountConfig.mockReturnValue({}); + mockMarginConfig.mockReturnValue({}); + }); + + it("should return nothing when accessToken is null", () => { + const { container } = renderWithProviders( + + ); + expect(container.firstChild).toBeNull(); + }); + + it("should render the page title", () => { + renderWithProviders(); + expect(screen.getByText("Cost Tracking Settings")).toBeInTheDocument(); + }); + + it("should show the Provider Discounts accordion header for proxy_admin", () => { + renderWithProviders(); + expect(screen.getByText("Provider Discounts")).toBeInTheDocument(); + }); + + it("should show the Fee/Price Margin accordion header for proxy_admin", () => { + renderWithProviders(); + expect(screen.getByText("Fee/Price Margin")).toBeInTheDocument(); + }); + + it("should always show the Pricing Calculator section", () => { + renderWithProviders(); + // The accordion header text appears in the DOM; getAllByText tolerates duplicates + expect(screen.getAllByText("Pricing Calculator").length).toBeGreaterThan(0); + }); + + it("should show the pricing calculator component", async () => { + renderWithProviders(); + expect(await screen.findByTestId("pricing-calculator")).toBeInTheDocument(); + }); + + it("should not show Provider Discounts section for a non-admin role", () => { + renderWithProviders( + + ); + expect(screen.queryByText("Provider Discounts")).not.toBeInTheDocument(); + }); + + it("should not show Fee/Price Margin section for a 
non-admin role", () => { + renderWithProviders( + + ); + expect(screen.queryByText("Fee/Price Margin")).not.toBeInTheDocument(); + }); + + it("should show Provider Discounts for the 'Admin' role as well", () => { + renderWithProviders( + + ); + expect(screen.getByText("Provider Discounts")).toBeInTheDocument(); + }); + + it("should show the subtitle describing discount/margin configuration", () => { + renderWithProviders(); + expect( + screen.getByText(/configure cost discounts and margins/i) + ).toBeInTheDocument(); + }); + + describe("Add Provider Discount modal", () => { + it("should open the Add Provider Discount modal when the button is clicked", async () => { + const user = userEvent.setup(); + renderWithProviders(); + + // The button lives inside the Provider Discounts accordion — click the header to expand first + const accordionHeader = screen.getByText("Provider Discounts").closest("button"); + if (accordionHeader) { + await user.click(accordionHeader); + } + + const addButton = await screen.findByRole("button", { name: /add provider discount/i }); + await user.click(addButton); + + expect( + await screen.findByText("Add Provider Discount", { selector: "h2" }) + ).toBeInTheDocument(); + }); + }); + + describe("Add Provider Margin modal", () => { + it("should open the Add Provider Margin modal when the button is clicked", async () => { + const user = userEvent.setup(); + renderWithProviders(); + + const accordionHeader = screen.getByText("Fee/Price Margin").closest("button"); + if (accordionHeader) { + await user.click(accordionHeader); + } + + const addButton = await screen.findByRole("button", { name: /add provider margin/i }); + await user.click(addButton); + + expect( + await screen.findByText("Add Provider Margin", { selector: "h2" }) + ).toBeInTheDocument(); + }); + }); + + describe("empty state messages", () => { + it("should show the empty state message when no discount config is loaded", async () => { + mockDiscountConfig.mockReturnValue({}); + 
renderWithProviders(); + + const accordionHeader = screen.getByText("Provider Discounts").closest("button"); + if (accordionHeader) { + await userEvent.setup().click(accordionHeader); + } + + expect( + await screen.findByText(/no provider discounts configured/i) + ).toBeInTheDocument(); + }); + + it("should show the empty state message when no margin config is loaded", async () => { + mockMarginConfig.mockReturnValue({}); + renderWithProviders(); + + const accordionHeader = screen.getByText("Fee/Price Margin").closest("button"); + if (accordionHeader) { + await userEvent.setup().click(accordionHeader); + } + + expect( + await screen.findByText(/no provider margins configured/i) + ).toBeInTheDocument(); + }); + }); +}); diff --git a/ui/litellm-dashboard/src/components/CostTrackingSettings/how_it_works.test.tsx b/ui/litellm-dashboard/src/components/CostTrackingSettings/how_it_works.test.tsx new file mode 100644 index 00000000000..fa608f555ce --- /dev/null +++ b/ui/litellm-dashboard/src/components/CostTrackingSettings/how_it_works.test.tsx @@ -0,0 +1,95 @@ +import React from "react"; +import { describe, it, expect, vi, beforeEach } from "vitest"; +import { screen } from "@testing-library/react"; +import userEvent from "@testing-library/user-event"; +import { renderWithProviders } from "../../../tests/test-utils"; +import HowItWorks from "./how_it_works"; + +vi.mock("@/app/(dashboard)/api-reference/components/CodeBlock", () => ({ + default: ({ code }: { code: string }) =>
{code}
, +})); + +describe("HowItWorks", () => { + beforeEach(() => { + vi.clearAllMocks(); + }); + + it("should render", () => { + renderWithProviders(); + expect(screen.getByText("Cost Calculation")).toBeInTheDocument(); + }); + + it("should display the cost calculation formula", () => { + renderWithProviders(); + expect(screen.getByText(/final_cost = base_cost/i)).toBeInTheDocument(); + }); + + it("should display the valid range information", () => { + renderWithProviders(); + expect(screen.getByText(/0% and 100%/i)).toBeInTheDocument(); + }); + + it("should render the code block with a curl example", () => { + renderWithProviders(); + expect(screen.getByTestId("code-block")).toBeInTheDocument(); + expect(screen.getByTestId("code-block").textContent).toContain("curl"); + }); + + it("should show the response header names for discount verification", () => { + renderWithProviders(); + expect(screen.getByText("x-litellm-response-cost")).toBeInTheDocument(); + expect(screen.getByText("x-litellm-response-cost-original")).toBeInTheDocument(); + expect(screen.getByText("x-litellm-response-cost-discount-amount")).toBeInTheDocument(); + }); + + it("should not show calculated results initially when no input is provided", () => { + renderWithProviders(); + expect(screen.queryByText("Calculated Results")).not.toBeInTheDocument(); + }); + + it("should not show calculated results when only response cost is entered", async () => { + const user = userEvent.setup(); + renderWithProviders(); + + const responseCostInput = screen.getByPlaceholderText("0.0171938125"); + await user.type(responseCostInput, "0.01"); + + expect(screen.queryByText("Calculated Results")).not.toBeInTheDocument(); + }); + + it("should not show calculated results when only discount amount is entered", async () => { + const user = userEvent.setup(); + renderWithProviders(); + + const discountAmountInput = screen.getByPlaceholderText("0.0009049375"); + await user.type(discountAmountInput, "0.001"); + + 
expect(screen.queryByText("Calculated Results")).not.toBeInTheDocument(); + }); + + it("should show calculated results when both fields are filled", async () => { + const user = userEvent.setup(); + renderWithProviders(); + + const responseCostInput = screen.getByPlaceholderText("0.0171938125"); + const discountAmountInput = screen.getByPlaceholderText("0.0009049375"); + + await user.type(responseCostInput, "0.0171938125"); + await user.type(discountAmountInput, "0.0009049375"); + + expect(await screen.findByText("Calculated Results")).toBeInTheDocument(); + }); + + it("should show original cost, final cost, and discount amount in results", async () => { + const user = userEvent.setup(); + renderWithProviders(); + + await user.type(screen.getByPlaceholderText("0.0171938125"), "0.0171938125"); + await user.type(screen.getByPlaceholderText("0.0009049375"), "0.0009049375"); + + expect(await screen.findByText("Original Cost:")).toBeInTheDocument(); + expect(screen.getByText("Final Cost:")).toBeInTheDocument(); + expect(screen.getByText("Discount Amount:")).toBeInTheDocument(); + expect(screen.getByText("Discount Applied:")).toBeInTheDocument(); + }); +}); diff --git a/ui/litellm-dashboard/src/components/CostTrackingSettings/provider_discount_table.test.tsx b/ui/litellm-dashboard/src/components/CostTrackingSettings/provider_discount_table.test.tsx new file mode 100644 index 00000000000..7697c6e7686 --- /dev/null +++ b/ui/litellm-dashboard/src/components/CostTrackingSettings/provider_discount_table.test.tsx @@ -0,0 +1,241 @@ +import React from "react"; +import { describe, it, expect, vi, beforeEach } from "vitest"; +import { screen, within } from "@testing-library/react"; +import userEvent from "@testing-library/user-event"; +import { renderWithProviders } from "../../../tests/test-utils"; +import ProviderDiscountTable from "./provider_discount_table"; + +vi.mock("@heroicons/react/outline", () => ({ + TrashIcon: function TrashIcon() { return null; }, + PencilAltIcon: 
function PencilAltIcon() { return null; }, + CheckIcon: function CheckIcon() { return null; }, + XIcon: function XIcon() { return null; }, +})); + +vi.mock("@tremor/react", () => ({ + Table: ({ children }: any) => {children}
, + TableHead: ({ children }: any) => {children}, + TableRow: ({ children }: any) => {children}, + TableHeaderCell: ({ children }: any) => {children}, + TableBody: ({ children }: any) => {children}, + TableCell: ({ children }: any) => {children}, + Text: ({ children }: any) => {children}, + TextInput: ({ value, onValueChange, onKeyDown, placeholder, ...rest }: any) => ( + onValueChange?.(e.target.value)} + onKeyDown={onKeyDown} + placeholder={placeholder} + {...rest} + /> + ), + Icon: ({ icon: IconComponent, onClick }: any) => { + const name = IconComponent?.displayName ?? IconComponent?.name ?? "icon"; + return + + + {/* Project Details */} + + + + {project.description || "\u2014"} + + {new Date(project.created_at).toLocaleString()} + {project.created_by && ( + +  {"by"}  + + + )} + + + {new Date(project.updated_at).toLocaleString()} + {project.updated_by && ( + +  {"by"}  + + + )} + + + + + + {/* Spend / Budget */} + + + + + Budget + + } + style={{ height: "100%" }} + > + +
+ + ${spend.toFixed(2)} + +
+ {hasLimit ? `of $${maxBudget.toFixed(2)} budget` : "No budget limit"} +
+ {hasLimit && ( +
+ + + {(Math.round(spendPercent * 10) / 10).toFixed(1)}% utilized + +
+ )} +
+
+ + + + {modelSpendData.length > 0 ? ( + `$${value.toFixed(4)}`} + yAxisWidth={140} + showLegend={false} + style={{ height: Math.max(modelSpendData.length * 40, 120) }} + /> + ) : ( + + )} + + +
+ + {/* Keys & Team */} + + + + + Keys + + } + style={{ height: "100%" }} + > + + + + + + + Team + + } + style={{ height: "100%" }} + > + {teamInfo ? ( + (() => { + const teamBudget = teamInfo.max_budget ?? null; + const teamSpend = teamInfo.spend ?? 0; + const teamHasLimit = teamBudget != null && teamBudget > 0; + const teamPercent = teamHasLimit ? Math.min((teamSpend / teamBudget) * 100, 100) : 0; + const teamColor = teamPercent >= 90 ? "#f5222d" : teamPercent >= 70 ? "#faad14" : "#52c41a"; + + return ( + + {/* Team name + ID */} +
+ + {teamInfo.team_alias || teamInfo.team_id} + +
+ + ID:{" "} + + {teamInfo.team_id} + + +
+ + {/* Models */} +
+ + Models + + {(teamInfo.models?.length ?? 0) > 0 ? ( + + {teamInfo.models?.map((m: string) => ( + + {m} + + ))} + + ) : ( + All models + )} +
+ + {/* Budget + Spend compact */} +
+ + + Spend + + + ${teamSpend.toFixed(2)} + {teamHasLimit ? ( + + {" "} + / ${teamBudget.toFixed(2)} + + ) : ( + + {" "} + (Unlimited) + + )} + + + {teamHasLimit && ( + + )} +
+ + {/* Members */} + + + Members + + {teamInfo.members_with_roles?.length ?? 0} + +
+ ); + })() + ) : project.team_id ? ( + + } size="small" /> + + ) : ( + + )} +
+ +
+ + {/* Edit Modal */} + setIsEditModalVisible(false)} /> + + ); +} diff --git a/ui/litellm-dashboard/src/components/Projects/ProjectModals/CreateProjectModal.tsx b/ui/litellm-dashboard/src/components/Projects/ProjectModals/CreateProjectModal.tsx index 14b4d70b743..e490f89303f 100644 --- a/ui/litellm-dashboard/src/components/Projects/ProjectModals/CreateProjectModal.tsx +++ b/ui/litellm-dashboard/src/components/Projects/ProjectModals/CreateProjectModal.tsx @@ -1,95 +1,39 @@ -import { useEffect, useState } from "react"; +import { Modal, Form, Button, Typography, message } from "antd"; +import { FolderAddOutlined } from "@ant-design/icons"; import { - Alert, - Modal, - Form, - Input, - Select, - Switch, - InputNumber, - Collapse, - Button, - Col, - Flex, - Row, - Space, - Divider, - Typography, - message, -} from "antd"; -import { FolderAddOutlined, PlusOutlined, MinusCircleOutlined } from "@ant-design/icons"; -import useAuthorized from "@/app/(dashboard)/hooks/useAuthorized"; -import { useTeams } from "@/app/(dashboard)/hooks/teams/useTeams"; -import { useCreateProject, ProjectCreateParams } from "@/app/(dashboard)/hooks/projects/useCreateProject"; -import { Team } from "../../key_team_helpers/key_list"; -import { fetchTeamModels } from "../../organisms/create_key_button"; -import { getModelDisplayName } from "../../key_team_helpers/fetch_available_models_team_key"; + useCreateProject, + ProjectCreateParams, +} from "@/app/(dashboard)/hooks/projects/useCreateProject"; +import { + ProjectBaseForm, + ProjectFormValues, +} from "./ProjectBaseForm"; +import { buildProjectApiParams } from "./projectFormUtils"; interface CreateProjectModalProps { isOpen: boolean; onClose: () => void; } -export function CreateProjectModal({ isOpen, onClose }: CreateProjectModalProps) { - const [form] = Form.useForm(); - const { accessToken, userId, userRole } = useAuthorized(); - const { data: teams } = useTeams(); +export function CreateProjectModal({ + isOpen, + onClose, +}: 
CreateProjectModalProps) { + const [form] = Form.useForm(); const createMutation = useCreateProject(); - const [selectedTeam, setSelectedTeam] = useState(null); - const [modelsToPick, setModelsToPick] = useState([]); - - // Fetch team-scoped models when team selection changes - useEffect(() => { - if (userId && userRole && accessToken && selectedTeam) { - fetchTeamModels(userId, userRole, accessToken, selectedTeam.team_id).then((models) => { - const allModels = Array.from(new Set([...(selectedTeam.models ?? []), ...models])); - setModelsToPick(allModels); - }); - } else { - setModelsToPick([]); - } - form.setFieldValue("models", []); - }, [selectedTeam, accessToken, userId, userRole, form]); - const handleSubmit = async () => { try { const values = await form.validateFields(); - - // Build model-specific limits from the dynamic form list - const modelRpmLimit: Record = {}; - const modelTpmLimit: Record = {}; - for (const entry of values.modelLimits ?? []) { - if (entry.model) { - if (entry.rpm != null) modelRpmLimit[entry.model] = entry.rpm; - if (entry.tpm != null) modelTpmLimit[entry.model] = entry.tpm; - } - } - - // Build metadata from the dynamic form list - const metadata: Record = {}; - for (const entry of values.metadata ?? []) { - if (entry.key) metadata[entry.key] = entry.value; - } - const params: ProjectCreateParams = { - project_alias: values.project_alias, - description: values.description, + ...buildProjectApiParams(values), team_id: values.team_id, - models: values.models ?? [], - max_budget: values.max_budget, - blocked: values.isBlocked ?? 
false, - ...(Object.keys(modelRpmLimit).length > 0 && { model_rpm_limit: modelRpmLimit }), - ...(Object.keys(modelTpmLimit).length > 0 && { model_tpm_limit: modelTpmLimit }), - ...(Object.keys(metadata).length > 0 && { metadata }), }; createMutation.mutate(params, { onSuccess: () => { message.success("Project created successfully"); form.resetFields(); - setSelectedTeam(null); - setModelsToPick([]); onClose(); }, onError: (error) => { @@ -103,16 +47,9 @@ export function CreateProjectModal({ isOpen, onClose }: CreateProjectModalProps) const handleCancel = () => { form.resetFields(); - setSelectedTeam(null); - setModelsToPick([]); onClose(); }; - const handleTeamChange = (teamId: string) => { - const team = teams?.find((t) => t.team_id === teamId) ?? null; - setSelectedTeam(team); - }; - return ( Cancel , - , ]} > -
- {/* Basic Info */} - - Basic Information - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - {/* Advanced Settings */} - - - - - Advanced Settings - - } - key="1" - > - - Block Project - - - - - prev.isBlocked !== cur.isBlocked}> - {({ getFieldValue }) => - getFieldValue("isBlocked") ? ( - - ) : null - } - - - - - - Model-Specific Limits - - - {(fields, { add, remove }) => ( - <> - {fields.map(({ key, name, ...restField }) => ( - - - - - - - - - - - remove(name)} style={{ color: "#ef4444" }} /> - - ))} - - - - - )} - - - - - - Metadata - - - {(fields, { add, remove }) => ( - <> - {fields.map(({ key, name, ...restField }) => ( - - - - - - - - remove(name)} style={{ color: "#ef4444" }} /> - - ))} - - - - - )} - - - - - - +
); } diff --git a/ui/litellm-dashboard/src/components/Projects/ProjectModals/EditProjectModal.tsx b/ui/litellm-dashboard/src/components/Projects/ProjectModals/EditProjectModal.tsx new file mode 100644 index 00000000000..75f56b1373f --- /dev/null +++ b/ui/litellm-dashboard/src/components/Projects/ProjectModals/EditProjectModal.tsx @@ -0,0 +1,126 @@ +import { useEffect } from "react"; +import { Modal, Form, Button, Typography, message } from "antd"; +import { SaveOutlined } from "@ant-design/icons"; +import { ProjectResponse } from "@/app/(dashboard)/hooks/projects/useProjects"; +import { + useUpdateProject, + ProjectUpdateParams, +} from "@/app/(dashboard)/hooks/projects/useUpdateProject"; +import { ProjectBaseForm, ProjectFormValues } from "./ProjectBaseForm"; +import { buildProjectApiParams } from "./projectFormUtils"; + +interface EditProjectModalProps { + isOpen: boolean; + project: ProjectResponse; + onClose: () => void; + onSuccess?: () => void; +} + +export function EditProjectModal({ + isOpen, + project, + onClose, + onSuccess, +}: EditProjectModalProps) { + const [form] = Form.useForm(); + const updateMutation = useUpdateProject(); + + // Populate form with existing project data when modal opens + useEffect(() => { + if (isOpen && project) { + // Model limits are stored inside metadata by the backend + const metadataObj = (project.metadata ?? {}) as Record; + const rpmLimits = (metadataObj.model_rpm_limit ?? {}) as Record; + const tpmLimits = (metadataObj.model_tpm_limit ?? 
{}) as Record; + + const modelLimits: ProjectFormValues["modelLimits"] = []; + const allLimitModels = new Set([ + ...Object.keys(rpmLimits), + ...Object.keys(tpmLimits), + ]); + for (const model of allLimitModels) { + modelLimits.push({ + model, + rpm: rpmLimits[model], + tpm: tpmLimits[model], + }); + } + + // Filter out internal keys from user-facing metadata + const internalKeys = new Set(["model_rpm_limit", "model_tpm_limit"]); + const metadata: ProjectFormValues["metadata"] = []; + for (const [key, value] of Object.entries(metadataObj)) { + if (!internalKeys.has(key)) { + metadata.push({ key, value: String(value) }); + } + } + + form.setFieldsValue({ + project_alias: project.project_alias ?? "", + team_id: project.team_id ?? "", + description: project.description ?? "", + models: project.models ?? [], + max_budget: project.litellm_budget_table?.max_budget ?? undefined, + isBlocked: project.blocked, + modelLimits: modelLimits.length > 0 ? modelLimits : undefined, + metadata: metadata.length > 0 ? 
metadata : undefined, + }); + } + }, [isOpen, project, form]); + + const handleSubmit = async () => { + try { + const values = await form.validateFields(); + const params: ProjectUpdateParams = { + ...buildProjectApiParams(values), + team_id: values.team_id, + }; + + updateMutation.mutate( + { projectId: project.project_id, params }, + { + onSuccess: () => { + message.success("Project updated successfully"); + onSuccess?.(); + onClose(); + }, + onError: (error) => { + message.error(error.message || "Failed to update project"); + }, + }, + ); + } catch (error) { + console.error("Validation failed:", error); + } + }; + + return ( + + Edit Project + + } + open={isOpen} + onCancel={onClose} + width={720} + destroyOnHidden + footer={[ + , + , + ]} + > + + + ); +} diff --git a/ui/litellm-dashboard/src/components/Projects/ProjectModals/ProjectBaseForm.tsx b/ui/litellm-dashboard/src/components/Projects/ProjectModals/ProjectBaseForm.tsx new file mode 100644 index 00000000000..bf1eca882c3 --- /dev/null +++ b/ui/litellm-dashboard/src/components/Projects/ProjectModals/ProjectBaseForm.tsx @@ -0,0 +1,401 @@ +import { useEffect, useState } from "react"; +import { + Alert, + Col, + Collapse, + Divider, + Flex, + Form, + Input, + InputNumber, + Row, + Select, + Space, + Switch, + Typography, + Button, +} from "antd"; +import type { FormInstance } from "antd"; +import { PlusOutlined, MinusCircleOutlined } from "@ant-design/icons"; +import useAuthorized from "@/app/(dashboard)/hooks/useAuthorized"; +import { useTeams } from "@/app/(dashboard)/hooks/teams/useTeams"; +import { Team } from "../../key_team_helpers/key_list"; +import { fetchTeamModels } from "../../organisms/create_key_button"; +import { getModelDisplayName } from "../../key_team_helpers/fetch_available_models_team_key"; + +export interface ProjectFormValues { + project_alias: string; + team_id: string; + description?: string; + models: string[]; + max_budget?: number; + isBlocked: boolean; + modelLimits?: { model: 
string; tpm?: number; rpm?: number }[]; + metadata?: { key: string; value: string }[]; +} + +interface ProjectBaseFormProps { + form: FormInstance; +} + +export function ProjectBaseForm({ + form, +}: ProjectBaseFormProps) { + const { accessToken, userId, userRole } = useAuthorized(); + const { data: teams } = useTeams(); + + const [selectedTeam, setSelectedTeam] = useState(null); + const [modelsToPick, setModelsToPick] = useState([]); + + // Sync selectedTeam from form value (needed for edit mode pre-fill) + const teamIdValue = Form.useWatch("team_id", form); + useEffect(() => { + if (teamIdValue && teams) { + const team = teams.find((t) => t.team_id === teamIdValue) ?? null; + if (team && team.team_id !== selectedTeam?.team_id) { + setSelectedTeam(team); + } + } + }, [teamIdValue, teams, selectedTeam?.team_id]); + + // Fetch team-scoped models when team selection changes + useEffect(() => { + if (userId && userRole && accessToken && selectedTeam) { + fetchTeamModels(userId, userRole, accessToken, selectedTeam.team_id).then( + (models) => { + const allModels = Array.from( + new Set([...(selectedTeam.models ?? []), ...models]), + ); + setModelsToPick(allModels); + }, + ); + } else { + setModelsToPick([]); + } + }, [selectedTeam, accessToken, userId, userRole]); + + const handleTeamChange = (teamId: string) => { + const team = teams?.find((t) => t.team_id === teamId) ?? null; + setSelectedTeam(team); + form.setFieldValue("models", []); + }; + + return ( +
+ {/* Basic Info */} + + Basic Information + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + {/* Advanced Settings */} + + + + Advanced Settings + + ), + children: ( + <> + + Block Project + + + + + prev.isBlocked !== cur.isBlocked} + > + {({ getFieldValue }) => + getFieldValue("isBlocked") ? ( + + ) : null + } + + + + + + Model-Specific Limits + + + {(fields, { add, remove }) => ( + <> + {fields.map(({ key, name, ...restField }) => ( + + { + if (!value) return Promise.resolve(); + const all = form.getFieldValue("modelLimits") ?? []; + const dupes = all.filter( + (entry: { model?: string }) => entry?.model === value, + ); + if (dupes.length > 1) { + return Promise.reject(new Error("Duplicate model")); + } + return Promise.resolve(); + }, + }, + ]} + > + + + + + + + + + remove(name)} + style={{ color: "#ef4444" }} + /> + + ))} + + + + + )} + + + + + + Metadata + + + {(fields, { add, remove }) => ( + <> + {fields.map(({ key, name, ...restField }) => ( + + { + if (!value) return Promise.resolve(); + const all = form.getFieldValue("metadata") ?? []; + const dupes = all.filter( + (entry: { key?: string }) => entry?.key === value, + ); + if (dupes.length > 1) { + return Promise.reject(new Error("Duplicate key")); + } + return Promise.resolve(); + }, + }, + ]} + > + + + + + + remove(name)} + style={{ color: "#ef4444" }} + /> + + ))} + + + + + )} + + + ), + }, + ]} + /> + + + + ); +} diff --git a/ui/litellm-dashboard/src/components/Projects/ProjectModals/projectFormUtils.ts b/ui/litellm-dashboard/src/components/Projects/ProjectModals/projectFormUtils.ts new file mode 100644 index 00000000000..97c093b57d9 --- /dev/null +++ b/ui/litellm-dashboard/src/components/Projects/ProjectModals/projectFormUtils.ts @@ -0,0 +1,36 @@ +import { ProjectFormValues } from "./ProjectBaseForm"; + +/** + * Transforms ProjectFormValues into the flat API param shape + * shared by both create and update endpoints. 
+ */ +export function buildProjectApiParams(values: ProjectFormValues) { + const modelRpmLimit: Record = {}; + const modelTpmLimit: Record = {}; + for (const entry of values.modelLimits ?? []) { + if (entry.model) { + if (entry.rpm != null) modelRpmLimit[entry.model] = entry.rpm; + if (entry.tpm != null) modelTpmLimit[entry.model] = entry.tpm; + } + } + + const metadata: Record = {}; + for (const entry of values.metadata ?? []) { + if (entry.key) metadata[entry.key] = entry.value; + } + + return { + project_alias: values.project_alias, + description: values.description, + models: values.models ?? [], + max_budget: values.max_budget, + blocked: values.isBlocked ?? false, + ...(Object.keys(modelRpmLimit).length > 0 && { + model_rpm_limit: modelRpmLimit, + }), + ...(Object.keys(modelTpmLimit).length > 0 && { + model_tpm_limit: modelTpmLimit, + }), + ...(Object.keys(metadata).length > 0 && { metadata }), + }; +} diff --git a/ui/litellm-dashboard/src/components/Projects/ProjectsPage.tsx b/ui/litellm-dashboard/src/components/Projects/ProjectsPage.tsx index 40ab5045703..f0b593c2e49 100644 --- a/ui/litellm-dashboard/src/components/Projects/ProjectsPage.tsx +++ b/ui/litellm-dashboard/src/components/Projects/ProjectsPage.tsx @@ -1,13 +1,15 @@ import { useProjects, ProjectResponse } from "@/app/(dashboard)/hooks/projects/useProjects"; import { useTeams } from "@/app/(dashboard)/hooks/teams/useTeams"; -import { PlusOutlined } from "@ant-design/icons"; +import { LoadingOutlined, PlusOutlined } from "@ant-design/icons"; import { Button, Card, Flex, Input, Layout, + Pagination, Space, + Spin, Table, Tag, theme, @@ -18,6 +20,7 @@ import type { ColumnsType } from "antd/es/table"; import { LayersIcon, SearchIcon } from "lucide-react"; import { useEffect, useMemo, useState } from "react"; import { CreateProjectModal } from "./ProjectModals/CreateProjectModal"; +import { ProjectDetail } from "./ProjectDetailsPage"; const { Title, Text } = Typography; const { Content } = Layout; @@ 
-25,8 +28,9 @@ const { Content } = Layout; export function ProjectsPage() { const { token } = theme.useToken(); const { data: projects, isLoading } = useProjects(); - const { data: teams } = useTeams(); + const { data: teams, isLoading: isTeamsLoading } = useTeams(); + const [selectedProjectId, setSelectedProjectId] = useState(null); const [isCreateModalVisible, setIsCreateModalVisible] = useState(false); const [searchText, setSearchText] = useState(""); const [currentPage, setCurrentPage] = useState(1); @@ -74,6 +78,7 @@ export function ProjectsPage() { ellipsis className="text-blue-500 bg-blue-50 hover:bg-blue-100 text-xs cursor-pointer" style={{ fontSize: 14, padding: "1px 8px" }} + onClick={() => setSelectedProjectId(id)} > {id} @@ -96,8 +101,11 @@ export function ProjectsPage() { return aAlias.localeCompare(bAlias); }, render: (_: unknown, record: ProjectResponse) => { - const alias = teamAliasMap.get(record.team_id ?? ""); - return alias ?? record.team_id ?? "—"; + if (!record.team_id) return "—"; + const alias = teamAliasMap.get(record.team_id); + if (alias) return alias; + if (isTeamsLoading) return } size="small" />; + return record.team_id; }, }, { @@ -144,6 +152,15 @@ export function ProjectsPage() { }, ]; + if (selectedProjectId) { + return ( + setSelectedProjectId(null)} + /> + ); + } + return ( setSearchText(e.target.value)} allowClear /> + setCurrentPage(page)} + size="small" + showTotal={(total) => `${total} projects`} + showSizeChanger={false} + /> setCurrentPage(page), - size: "small", - showTotal: (total) => `${total} projects`, - showSizeChanger: false, - }} + pagination={false} /> diff --git a/ui/litellm-dashboard/src/components/agents.test.tsx b/ui/litellm-dashboard/src/components/agents.test.tsx new file mode 100644 index 00000000000..2d4c879dec0 --- /dev/null +++ b/ui/litellm-dashboard/src/components/agents.test.tsx @@ -0,0 +1,71 @@ +import React from "react"; +import { render, screen, waitFor } from "@testing-library/react"; +import { 
describe, it, expect, vi, beforeEach } from "vitest"; +import AgentsPanel from "./agents"; + +vi.mock("./networking", () => ({ + getAgentsList: vi.fn().mockResolvedValue({ agents: [] }), + deleteAgentCall: vi.fn(), + keyListCall: vi.fn().mockResolvedValue({ keys: [] }), +})); + +vi.mock("./agents/add_agent_form", () => ({ + default: () =>
, +})); + +vi.mock("./agents/agent_card_grid", () => ({ + default: ({ isAdmin }: { isAdmin: boolean }) => ( +
+ ), +})); + +vi.mock("./agents/agent_info", () => ({ + default: () =>
, +})); + +describe("AgentsPanel", () => { + beforeEach(() => { + vi.clearAllMocks(); + }); + + it("should render the Agents panel title", async () => { + render(); + expect(screen.getByText("Agents")).toBeInTheDocument(); + }); + + it("should show Add New Agent button for admin users", async () => { + render(); + expect(screen.getByText("+ Add New Agent")).toBeInTheDocument(); + }); + + it("should show Add New Agent button for proxy_admin users", async () => { + render(); + expect(screen.getByText("+ Add New Agent")).toBeInTheDocument(); + }); + + it("should not show Add New Agent button for internal_user role", async () => { + render(); + expect(screen.queryByText("+ Add New Agent")).not.toBeInTheDocument(); + }); + + it("should not show Add New Agent button for internal_user_viewer role", async () => { + render(); + expect(screen.queryByText("+ Add New Agent")).not.toBeInTheDocument(); + }); + + it("should pass isAdmin=true to AgentCardGrid for admin role", async () => { + render(); + await waitFor(() => { + const grid = screen.getByTestId("agent-card-grid"); + expect(grid).toHaveAttribute("data-is-admin", "true"); + }); + }); + + it("should pass isAdmin=false to AgentCardGrid for internal user role", async () => { + render(); + await waitFor(() => { + const grid = screen.getByTestId("agent-card-grid"); + expect(grid).toHaveAttribute("data-is-admin", "false"); + }); + }); +}); diff --git a/ui/litellm-dashboard/src/components/agents.tsx b/ui/litellm-dashboard/src/components/agents.tsx index 7ab6a5f0381..8dd9bc7d01c 100644 --- a/ui/litellm-dashboard/src/components/agents.tsx +++ b/ui/litellm-dashboard/src/components/agents.tsx @@ -141,11 +141,13 @@ const AgentsPanel: React.FC = ({ accessToken, userRole }) => { showIcon className="mb-3" /> -
- -
+ {isAdmin && ( +
+ +
+ )}
{selectedAgentId ? ( diff --git a/ui/litellm-dashboard/src/components/agents/agent_card_grid.tsx b/ui/litellm-dashboard/src/components/agents/agent_card_grid.tsx index 0ba7902f8b2..5e984cf220d 100644 --- a/ui/litellm-dashboard/src/components/agents/agent_card_grid.tsx +++ b/ui/litellm-dashboard/src/components/agents/agent_card_grid.tsx @@ -37,7 +37,11 @@ const AgentCardGrid: React.FC = ({ if (!agentsList || agentsList.length === 0) { return (
-

No agents found. Create one to get started.

+

+ {isAdmin + ? "No agents found. Create one to get started." + : "No agents found. Contact an admin to create agents."} +

); } diff --git a/ui/litellm-dashboard/src/components/mcp_tools/MCPPermissionManagement.tsx b/ui/litellm-dashboard/src/components/mcp_tools/MCPPermissionManagement.tsx index 2adfe52f6f8..89aa756cb27 100644 --- a/ui/litellm-dashboard/src/components/mcp_tools/MCPPermissionManagement.tsx +++ b/ui/litellm-dashboard/src/components/mcp_tools/MCPPermissionManagement.tsx @@ -46,7 +46,7 @@ const MCPPermissionManagement: React.FC = ({ } } else { form.setFieldValue("allow_all_keys", false); - form.setFieldValue("available_on_public_internet", false); + form.setFieldValue("available_on_public_internet", true); } }, [mcpServer, form]); @@ -64,6 +64,7 @@ const MCPPermissionManagement: React.FC = ({ } key="permissions" className="border-0" + forceRender >
@@ -89,17 +90,19 @@ const MCPPermissionManagement: React.FC = ({
- Available on Public Internet - + Internal network only + -

Enable if this server should be reachable from the public internet.

+

Turn on to restrict access to callers within your internal network only.

({ checked: !value })} + getValueFromEvent={(checked: boolean) => !checked} + initialValue={true} className="mb-0" > diff --git a/ui/litellm-dashboard/src/components/mcp_tools/mcp_server_columns.tsx b/ui/litellm-dashboard/src/components/mcp_tools/mcp_server_columns.tsx index d49949e7e04..c42a8593cf1 100644 --- a/ui/litellm-dashboard/src/components/mcp_tools/mcp_server_columns.tsx +++ b/ui/litellm-dashboard/src/components/mcp_tools/mcp_server_columns.tsx @@ -157,9 +157,9 @@ export const mcpServerColumns = ( cell: ({ row }) => { const isPublic = row.original.available_on_public_internet; return isPublic ? ( - Public + All networks ) : ( - Internal + Internal only ); }, }, diff --git a/ui/litellm-dashboard/src/components/mcp_tools/mcp_server_view.tsx b/ui/litellm-dashboard/src/components/mcp_tools/mcp_server_view.tsx index cb7768ab21a..4f8ba4e1307 100644 --- a/ui/litellm-dashboard/src/components/mcp_tools/mcp_server_view.tsx +++ b/ui/litellm-dashboard/src/components/mcp_tools/mcp_server_view.tsx @@ -251,20 +251,20 @@ export const MCPServerView: React.FC = ({
- Available on Public Internet + Network Access
{mcpServer.available_on_public_internet ? ( - Public + All networks ) : ( - - Internal + + Internal only )} - {mcpServer.available_on_public_internet && ( + {!mcpServer.available_on_public_internet && ( - Accessible from external/public IPs + Restricted to internal network )}