diff --git a/CLAUDE.md b/CLAUDE.md index d9061b5e2be..f0478120181 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -140,6 +140,11 @@ LiteLLM is a unified interface for 100+ LLM providers with two main components: - **Check index coverage.** For new or modified queries, check `schema.prisma` for a supporting index. Prefer extending an existing index (e.g. `@@index([a])` → `@@index([a, b])`) over adding a new one, unless it's a `@@unique`. Only add indexes for large/frequent queries. - **Keep schema files in sync.** Apply schema changes to all `schema.prisma` copies (`schema.prisma`, `litellm/proxy/`, `litellm-proxy-extras/`, `litellm-js/spend-logs/` for SpendLogs) with a migration under `litellm-proxy-extras/litellm_proxy_extras/migrations/`. +### Setup Wizard (`litellm/setup_wizard.py`) +- The wizard is implemented as a single `SetupWizard` class with `@staticmethod` methods — keep it that way. No module-level functions except `run_setup_wizard()` (the public entrypoint) and pure helpers (color, ANSI). +- Use `litellm.utils.check_valid_key(model, api_key)` for credential validation — never roll a custom completion call. +- Do not hardcode provider env-key names or model lists that already exist in the codebase. Add a `test_model` field to each provider entry to drive `check_valid_key`; set it to `None` for providers that can't be validated with a single API key (Azure, Bedrock, Ollama). 
+ ### Enterprise Features - Enterprise-specific code in `enterprise/` directory - Optional features enabled via environment variables diff --git a/docs/my-website/docs/proxy/docker_quick_start.md b/docs/my-website/docs/proxy/docker_quick_start.md index 4975c76c61f..58a56604751 100644 --- a/docs/my-website/docs/proxy/docker_quick_start.md +++ b/docs/my-website/docs/proxy/docker_quick_start.md @@ -5,11 +5,76 @@ import Image from '@theme/IdealImage'; # Getting Started Tutorial End-to-End tutorial for LiteLLM Proxy to: -- Add an Azure OpenAI model -- Make a successful /chat/completion call -- Generate a virtual key -- Set RPM limit on virtual key +- Add an Azure OpenAI model +- Make a successful /chat/completion call +- Generate a virtual key +- Set RPM limit on virtual key +## Quick Install (Recommended for local / beginners) + +New to LiteLLM? This is the easiest way to get started locally. One command installs LiteLLM and walks you through setup interactively — no config files to write by hand. + +### 1. Install + +```bash +curl -fsSL https://raw.githubusercontent.com/BerriAI/litellm/main/scripts/install.sh | sh +``` + +This detects your OS, installs `litellm[proxy]`, and drops you straight into the setup wizard. + +### 2. Follow the wizard + +``` +$ litellm --setup + + Welcome to LiteLLM + + Choose your LLM providers + ○ 1. OpenAI GPT-4o, GPT-4o-mini, o1 + ○ 2. Anthropic Claude Opus, Sonnet, Haiku + ○ 3. Azure OpenAI GPT-4o via Azure + ○ 4. Google Gemini Gemini 2.0 Flash, 1.5 Pro + ○ 5. AWS Bedrock Claude, Llama via AWS + ○ 6. Ollama Local models + + ❯ Provider(s): 1,2 + + ❯ OpenAI API key: sk-... + ❯ Anthropic API key: sk-ant-... + + ❯ Port [4000]: + ❯ Master key [auto-generate]: + + ✔ Config saved → ./litellm_config.yaml + + ❯ Start the proxy now? (Y/n): +``` + +The wizard walks you through: +1. Pick your LLM providers (OpenAI, Anthropic, Azure, Bedrock, Gemini, Ollama) +2. Enter API keys for each provider +3. 
Set a port and master key (or accept the defaults) +4. Config is saved to `./litellm_config.yaml` and the proxy starts immediately + +### 3. Make a call + +Your proxy is running on `http://0.0.0.0:4000`. Test it: + +```bash +curl -X POST 'http://0.0.0.0:4000/chat/completions' \ +-H 'Content-Type: application/json' \ +-H 'Authorization: Bearer ' \ +-d '{ + "model": "gpt-4o", + "messages": [{"role": "user", "content": "Hello!"}] +}' +``` + +:::tip Already have pip installed? +You can skip the curl install and run `litellm --setup` directly after `pip install 'litellm[proxy]'`. +::: + +--- ## Pre-Requisites diff --git a/docs/my-website/img/mcp_zero_trust_gateway.png b/docs/my-website/img/mcp_zero_trust_gateway.png new file mode 100644 index 00000000000..3955cef0553 Binary files /dev/null and b/docs/my-website/img/mcp_zero_trust_gateway.png differ diff --git a/docs/my-website/sidebars.js b/docs/my-website/sidebars.js index 071f80de020..7db72da276a 100644 --- a/docs/my-website/sidebars.js +++ b/docs/my-website/sidebars.js @@ -672,6 +672,7 @@ const sidebars = { "mcp_control", "mcp_cost", "mcp_guardrail", + "mcp_zero_trust", "mcp_troubleshoot", ] }, diff --git a/litellm/__init__.py b/litellm/__init__.py index 51c66838613..7f72e0b0e89 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -1465,9 +1465,15 @@ def set_global_gitlab_config(config: Dict[str, Any]) -> None: from .llms.petals.completion.transformation import PetalsConfig as PetalsConfig from .llms.ollama.chat.transformation import OllamaChatConfig as OllamaChatConfig from .llms.ollama.completion.transformation import OllamaConfig as OllamaConfig - from .llms.sagemaker.completion.transformation import SagemakerConfig as SagemakerConfig - from .llms.sagemaker.chat.transformation import SagemakerChatConfig as SagemakerChatConfig - from .llms.sagemaker.nova.transformation import SagemakerNovaConfig as SagemakerNovaConfig + from .llms.sagemaker.completion.transformation import ( + SagemakerConfig as 
SagemakerConfig, + ) + from .llms.sagemaker.chat.transformation import ( + SagemakerChatConfig as SagemakerChatConfig, + ) + from .llms.sagemaker.nova.transformation import ( + SagemakerNovaConfig as SagemakerNovaConfig, + ) from .llms.cohere.chat.transformation import CohereChatConfig as CohereChatConfig from .llms.anthropic.experimental_pass_through.messages.transformation import ( AnthropicMessagesConfig as AnthropicMessagesConfig, diff --git a/litellm/_logging.py b/litellm/_logging.py index 5de9fbb3558..18c3bcb7e87 100644 --- a/litellm/_logging.py +++ b/litellm/_logging.py @@ -17,7 +17,9 @@ "`litellm.set_verbose` is deprecated. Please set `os.environ['LITELLM_LOG'] = 'DEBUG'` for debug logs." ) -_ENABLE_SECRET_REDACTION = os.getenv("LITELLM_DISABLE_REDACT_SECRETS", "").lower() != "true" +_ENABLE_SECRET_REDACTION = ( + os.getenv("LITELLM_DISABLE_REDACT_SECRETS", "").lower() != "true" +) _REDACTED = "REDACTED" @@ -199,7 +201,9 @@ def format(self, record): json_record[key] = value if record.exc_info: - json_record["stacktrace"] = record.exc_text or self.formatException(record.exc_info) + json_record["stacktrace"] = record.exc_text or self.formatException( + record.exc_info + ) return safe_dumps(json_record) diff --git a/litellm/cost_calculator.py b/litellm/cost_calculator.py index e0e1e35b94e..ee3c344169b 100644 --- a/litellm/cost_calculator.py +++ b/litellm/cost_calculator.py @@ -1189,7 +1189,9 @@ def completion_cost( # noqa: PLR0915 and _usage["prompt_tokens_details"] != {} and _usage["prompt_tokens_details"] ): - prompt_tokens_details = _usage.get("prompt_tokens_details") or {} + prompt_tokens_details = ( + _usage.get("prompt_tokens_details") or {} + ) cache_read_input_tokens = prompt_tokens_details.get( "cached_tokens", 0 ) @@ -1515,7 +1517,9 @@ def completion_cost( # noqa: PLR0915 if custom_llm_provider == "azure_ai": model_for_additional_costs = request_model_for_cost if completion_response is not None: - hidden_params = getattr(completion_response, 
"_hidden_params", None) or {} + hidden_params = ( + getattr(completion_response, "_hidden_params", None) or {} + ) hidden_model = hidden_params.get("model") or hidden_params.get( "litellm_model_name" ) diff --git a/litellm/integrations/focus/destinations/factory.py b/litellm/integrations/focus/destinations/factory.py index 01ea6ca9cb4..706e10624ce 100644 --- a/litellm/integrations/focus/destinations/factory.py +++ b/litellm/integrations/focus/destinations/factory.py @@ -59,17 +59,14 @@ def _resolve_config( return {k: v for k, v in resolved.items() if v is not None} if provider == "vantage": resolved = { - "api_key": overrides.get("api_key") - or os.getenv("VANTAGE_API_KEY"), + "api_key": overrides.get("api_key") or os.getenv("VANTAGE_API_KEY"), "integration_token": overrides.get("integration_token") or os.getenv("VANTAGE_INTEGRATION_TOKEN"), "base_url": overrides.get("base_url") or os.getenv("VANTAGE_BASE_URL", "https://api.vantage.sh"), } if not resolved.get("api_key"): - raise ValueError( - "VANTAGE_API_KEY must be provided for Vantage exports" - ) + raise ValueError("VANTAGE_API_KEY must be provided for Vantage exports") if not resolved.get("integration_token"): raise ValueError( "VANTAGE_INTEGRATION_TOKEN must be provided for Vantage exports" diff --git a/litellm/integrations/langfuse/langfuse_prompt_management.py b/litellm/integrations/langfuse/langfuse_prompt_management.py index 03a93cd988e..bea027aa63d 100644 --- a/litellm/integrations/langfuse/langfuse_prompt_management.py +++ b/litellm/integrations/langfuse/langfuse_prompt_management.py @@ -340,9 +340,9 @@ async def async_log_failure_event(self, kwargs, response_obj, start_time, end_ti ) status_message = str(kwargs.get("exception", "Unknown error")) if standard_logging_object is not None: - status_message = standard_logging_object.get( - "error_str", None - ) or status_message + status_message = ( + standard_logging_object.get("error_str", None) or status_message + ) 
langfuse_logger_to_use.log_event_on_langfuse( start_time=start_time, end_time=end_time, diff --git a/litellm/integrations/vantage/vantage_logger.py b/litellm/integrations/vantage/vantage_logger.py index 6689d932749..e0942472bec 100644 --- a/litellm/integrations/vantage/vantage_logger.py +++ b/litellm/integrations/vantage/vantage_logger.py @@ -83,7 +83,9 @@ def __init__( verbose_logger.debug( "VantageLogger initialized (integration_token=%s)", - resolved_token[:4] + "***" if resolved_token and len(resolved_token) > 4 else "***", + resolved_token[:4] + "***" + if resolved_token and len(resolved_token) > 4 + else "***", ) async def initialize_focus_export_job(self) -> None: @@ -128,9 +130,7 @@ async def init_vantage_background_job( callback_type=VantageLogger ) if not vantage_loggers: - verbose_logger.debug( - "No Vantage logger registered; skipping scheduler" - ) + verbose_logger.debug("No Vantage logger registered; skipping scheduler") return vantage_logger = cast(VantageLogger, vantage_loggers[0]) diff --git a/litellm/litellm_core_utils/default_encoding.py b/litellm/litellm_core_utils/default_encoding.py index 24533feeccc..f704ba568de 100644 --- a/litellm/litellm_core_utils/default_encoding.py +++ b/litellm/litellm_core_utils/default_encoding.py @@ -26,7 +26,9 @@ else: cache_dir = filename -os.environ["TIKTOKEN_CACHE_DIR"] = cache_dir # use local copy of tiktoken b/c of - https://github.com/BerriAI/litellm/issues/1071 +os.environ[ + "TIKTOKEN_CACHE_DIR" +] = cache_dir # use local copy of tiktoken b/c of - https://github.com/BerriAI/litellm/issues/1071 import tiktoken import time @@ -48,4 +50,3 @@ # Exponential backoff with jitter to reduce collision probability delay = _retry_delay * (2**attempt) + random.uniform(0, 0.1) time.sleep(delay) - diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py index 01565b99478..826396a70d8 100644 --- a/litellm/litellm_core_utils/litellm_logging.py +++ 
b/litellm/litellm_core_utils/litellm_logging.py @@ -352,9 +352,9 @@ def __init__( ) self.function_id = function_id self.streaming_chunks: List[Any] = [] # for generating complete stream response - self.sync_streaming_chunks: List[Any] = ( - [] - ) # for generating complete stream response + self.sync_streaming_chunks: List[ + Any + ] = [] # for generating complete stream response self.log_raw_request_response = log_raw_request_response # Initialize dynamic callbacks @@ -782,9 +782,9 @@ def _auto_detect_prompt_management_logger( prompt_spec=prompt_spec, dynamic_callback_params=dynamic_callback_params, ): - self.model_call_details["prompt_integration"] = ( - logger.__class__.__name__ - ) + self.model_call_details[ + "prompt_integration" + ] = logger.__class__.__name__ return logger except Exception: # If check fails, continue to next logger @@ -852,9 +852,9 @@ def get_custom_logger_for_prompt_management( if anthropic_cache_control_logger := AnthropicCacheControlHook.get_custom_logger_for_anthropic_cache_control_hook( non_default_params ): - self.model_call_details["prompt_integration"] = ( - anthropic_cache_control_logger.__class__.__name__ - ) + self.model_call_details[ + "prompt_integration" + ] = anthropic_cache_control_logger.__class__.__name__ return anthropic_cache_control_logger ######################################################### @@ -866,9 +866,9 @@ def get_custom_logger_for_prompt_management( internal_usage_cache=None, llm_router=None, ) - self.model_call_details["prompt_integration"] = ( - vector_store_custom_logger.__class__.__name__ - ) + self.model_call_details[ + "prompt_integration" + ] = vector_store_custom_logger.__class__.__name__ # Add to global callbacks so post-call hooks are invoked if ( vector_store_custom_logger @@ -928,9 +928,9 @@ def _pre_call(self, input, api_key, model=None, additional_args={}): model ): # if model name was changes pre-call, overwrite the initial model call name with the new one self.model_call_details["model"] = 
model - self.model_call_details["litellm_params"]["api_base"] = ( - self._get_masked_api_base(additional_args.get("api_base", "")) - ) + self.model_call_details["litellm_params"][ + "api_base" + ] = self._get_masked_api_base(additional_args.get("api_base", "")) def pre_call(self, input, api_key, model=None, additional_args={}): # noqa: PLR0915 # Log the exact input to the LLM API @@ -959,10 +959,10 @@ def pre_call(self, input, api_key, model=None, additional_args={}): # noqa: PLR try: # [Non-blocking Extra Debug Information in metadata] if turn_off_message_logging is True: - _metadata["raw_request"] = ( - "redacted by litellm. \ + _metadata[ + "raw_request" + ] = "redacted by litellm. \ 'litellm.turn_off_message_logging=True'" - ) else: curl_command = self._get_request_curl_command( api_base=additional_args.get("api_base", ""), @@ -973,34 +973,34 @@ def pre_call(self, input, api_key, model=None, additional_args={}): # noqa: PLR _metadata["raw_request"] = str(curl_command) # split up, so it's easier to parse in the UI - self.model_call_details["raw_request_typed_dict"] = ( - RawRequestTypedDict( - raw_request_api_base=str( - additional_args.get("api_base") or "" - ), - raw_request_body=self._get_raw_request_body( - additional_args.get("complete_input_dict", {}) - ), - # NOTE: setting ignore_sensitive_headers to True will cause - # the Authorization header to be leaked when calls to the health - # endpoint are made and fail. - raw_request_headers=self._get_masked_headers( - additional_args.get("headers", {}) or {}, - ), - error=None, - ) + self.model_call_details[ + "raw_request_typed_dict" + ] = RawRequestTypedDict( + raw_request_api_base=str( + additional_args.get("api_base") or "" + ), + raw_request_body=self._get_raw_request_body( + additional_args.get("complete_input_dict", {}) + ), + # NOTE: setting ignore_sensitive_headers to True will cause + # the Authorization header to be leaked when calls to the health + # endpoint are made and fail. 
+ raw_request_headers=self._get_masked_headers( + additional_args.get("headers", {}) or {}, + ), + error=None, ) except Exception as e: - self.model_call_details["raw_request_typed_dict"] = ( - RawRequestTypedDict( - error=str(e), - ) + self.model_call_details[ + "raw_request_typed_dict" + ] = RawRequestTypedDict( + error=str(e), ) - _metadata["raw_request"] = ( - "Unable to Log \ + _metadata[ + "raw_request" + ] = "Unable to Log \ raw request: {}".format( - str(e) - ) + str(e) ) if getattr(self, "logger_fn", None) and callable(self.logger_fn): try: @@ -1301,13 +1301,13 @@ async def async_post_mcp_tool_call_hook( for callback in callbacks: try: if isinstance(callback, CustomLogger): - response: Optional[MCPPostCallResponseObject] = ( - await callback.async_post_mcp_tool_call_hook( - kwargs=kwargs, - response_obj=post_mcp_tool_call_response_obj, - start_time=start_time, - end_time=end_time, - ) + response: Optional[ + MCPPostCallResponseObject + ] = await callback.async_post_mcp_tool_call_hook( + kwargs=kwargs, + response_obj=post_mcp_tool_call_response_obj, + start_time=start_time, + end_time=end_time, ) ###################################################################### # if any of the callbacks modify the response, use the modified response @@ -1502,9 +1502,9 @@ def _response_cost_calculator( verbose_logger.debug( f"response_cost_failure_debug_information: {debug_info}" ) - self.model_call_details["response_cost_failure_debug_information"] = ( - debug_info - ) + self.model_call_details[ + "response_cost_failure_debug_information" + ] = debug_info return None try: @@ -1530,9 +1530,9 @@ def _response_cost_calculator( verbose_logger.debug( f"response_cost_failure_debug_information: {debug_info}" ) - self.model_call_details["response_cost_failure_debug_information"] = ( - debug_info - ) + self.model_call_details[ + "response_cost_failure_debug_information" + ] = debug_info return None @@ -1688,9 +1688,9 @@ def _process_hidden_params_and_response_cost( 
result=logging_result ) - self.model_call_details["standard_logging_object"] = ( - self._build_standard_logging_payload(logging_result, start_time, end_time) - ) + self.model_call_details[ + "standard_logging_object" + ] = self._build_standard_logging_payload(logging_result, start_time, end_time) if ( standard_logging_payload := self.model_call_details.get( @@ -1768,9 +1768,9 @@ def _success_handler_helper_fn( end_time = datetime.datetime.now() if self.completion_start_time is None: self.completion_start_time = end_time - self.model_call_details["completion_start_time"] = ( - self.completion_start_time - ) + self.model_call_details[ + "completion_start_time" + ] = self.completion_start_time self.model_call_details["log_event_type"] = "successful_api_call" self.model_call_details["end_time"] = end_time @@ -1807,10 +1807,10 @@ def _success_handler_helper_fn( end_time=end_time, ) elif isinstance(result, dict) or isinstance(result, list): - self.model_call_details["standard_logging_object"] = ( - self._build_standard_logging_payload( - result, start_time, end_time - ) + self.model_call_details[ + "standard_logging_object" + ] = self._build_standard_logging_payload( + result, start_time, end_time ) if ( standard_logging_payload := self.model_call_details.get( @@ -1819,9 +1819,9 @@ def _success_handler_helper_fn( ) is not None: emit_standard_logging_payload(standard_logging_payload) elif standard_logging_object is not None: - self.model_call_details["standard_logging_object"] = ( - standard_logging_object - ) + self.model_call_details[ + "standard_logging_object" + ] = standard_logging_object else: self.model_call_details["response_cost"] = None @@ -1979,17 +1979,17 @@ def success_handler( # noqa: PLR0915 verbose_logger.debug( "Logging Details LiteLLM-Success Call streaming complete" ) - self.model_call_details["complete_streaming_response"] = ( - complete_streaming_response - ) - self.model_call_details["response_cost"] = ( - 
self._response_cost_calculator(result=complete_streaming_response) - ) + self.model_call_details[ + "complete_streaming_response" + ] = complete_streaming_response + self.model_call_details[ + "response_cost" + ] = self._response_cost_calculator(result=complete_streaming_response) ## STANDARDIZED LOGGING PAYLOAD - self.model_call_details["standard_logging_object"] = ( - self._build_standard_logging_payload( - complete_streaming_response, start_time, end_time - ) + self.model_call_details[ + "standard_logging_object" + ] = self._build_standard_logging_payload( + complete_streaming_response, start_time, end_time ) if ( standard_logging_payload := self.model_call_details.get( @@ -2323,10 +2323,10 @@ def success_handler( # noqa: PLR0915 ) else: if self.stream and complete_streaming_response: - self.model_call_details["complete_response"] = ( - self.model_call_details.get( - "complete_streaming_response", {} - ) + self.model_call_details[ + "complete_response" + ] = self.model_call_details.get( + "complete_streaming_response", {} ) result = self.model_call_details["complete_response"] openMeterLogger.log_success_event( @@ -2350,10 +2350,10 @@ def success_handler( # noqa: PLR0915 ) else: if self.stream and complete_streaming_response: - self.model_call_details["complete_response"] = ( - self.model_call_details.get( - "complete_streaming_response", {} - ) + self.model_call_details[ + "complete_response" + ] = self.model_call_details.get( + "complete_streaming_response", {} ) result = self.model_call_details["complete_response"] @@ -2492,9 +2492,9 @@ async def async_success_handler( # noqa: PLR0915 if complete_streaming_response is not None: print_verbose("Async success callbacks: Got a complete streaming response") - self.model_call_details["async_complete_streaming_response"] = ( - complete_streaming_response - ) + self.model_call_details[ + "async_complete_streaming_response" + ] = complete_streaming_response try: if self.model_call_details.get("cache_hit", False) is 
True: @@ -2505,10 +2505,10 @@ async def async_success_handler( # noqa: PLR0915 model_call_details=self.model_call_details ) # base_model defaults to None if not set on model_info - self.model_call_details["response_cost"] = ( - self._response_cost_calculator( - result=complete_streaming_response - ) + self.model_call_details[ + "response_cost" + ] = self._response_cost_calculator( + result=complete_streaming_response ) verbose_logger.debug( @@ -2521,10 +2521,10 @@ async def async_success_handler( # noqa: PLR0915 self.model_call_details["response_cost"] = None ## STANDARDIZED LOGGING PAYLOAD - self.model_call_details["standard_logging_object"] = ( - self._build_standard_logging_payload( - complete_streaming_response, start_time, end_time - ) + self.model_call_details[ + "standard_logging_object" + ] = self._build_standard_logging_payload( + complete_streaming_response, start_time, end_time ) # print standard logging payload @@ -2551,9 +2551,9 @@ async def async_success_handler( # noqa: PLR0915 # _success_handler_helper_fn if self.model_call_details.get("standard_logging_object") is None: ## STANDARDIZED LOGGING PAYLOAD - self.model_call_details["standard_logging_object"] = ( - self._build_standard_logging_payload(result, start_time, end_time) - ) + self.model_call_details[ + "standard_logging_object" + ] = self._build_standard_logging_payload(result, start_time, end_time) # print standard logging payload if ( @@ -2796,18 +2796,18 @@ def _failure_handler_helper_fn( ## STANDARDIZED LOGGING PAYLOAD - self.model_call_details["standard_logging_object"] = ( - get_standard_logging_object_payload( - kwargs=self.model_call_details, - init_response_obj={}, - start_time=start_time, - end_time=end_time, - logging_obj=self, - status="failure", - error_str=str(exception), - original_exception=exception, - standard_built_in_tools_params=self.standard_built_in_tools_params, - ) + self.model_call_details[ + "standard_logging_object" + ] = get_standard_logging_object_payload( + 
kwargs=self.model_call_details, + init_response_obj={}, + start_time=start_time, + end_time=end_time, + logging_obj=self, + status="failure", + error_str=str(exception), + original_exception=exception, + standard_built_in_tools_params=self.standard_built_in_tools_params, ) return start_time, end_time @@ -3774,9 +3774,9 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915 service_name=arize_config.project_name, ) - os.environ["OTEL_EXPORTER_OTLP_TRACES_HEADERS"] = ( - f"space_id={arize_config.space_key or arize_config.space_id},api_key={arize_config.api_key}" - ) + os.environ[ + "OTEL_EXPORTER_OTLP_TRACES_HEADERS" + ] = f"space_id={arize_config.space_key or arize_config.space_id},api_key={arize_config.api_key}" for callback in _in_memory_loggers: if ( isinstance(callback, ArizeLogger) @@ -3802,13 +3802,13 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915 existing_attrs = os.environ.get("OTEL_RESOURCE_ATTRIBUTES", "") # Add openinference.project.name attribute if existing_attrs: - os.environ["OTEL_RESOURCE_ATTRIBUTES"] = ( - f"{existing_attrs},openinference.project.name={arize_phoenix_config.project_name}" - ) + os.environ[ + "OTEL_RESOURCE_ATTRIBUTES" + ] = f"{existing_attrs},openinference.project.name={arize_phoenix_config.project_name}" else: - os.environ["OTEL_RESOURCE_ATTRIBUTES"] = ( - f"openinference.project.name={arize_phoenix_config.project_name}" - ) + os.environ[ + "OTEL_RESOURCE_ATTRIBUTES" + ] = f"openinference.project.name={arize_phoenix_config.project_name}" # Set Phoenix project name from environment variable phoenix_project_name = os.environ.get("PHOENIX_PROJECT_NAME", None) @@ -3816,19 +3816,19 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915 existing_attrs = os.environ.get("OTEL_RESOURCE_ATTRIBUTES", "") # Add openinference.project.name attribute if existing_attrs: - os.environ["OTEL_RESOURCE_ATTRIBUTES"] = ( - f"{existing_attrs},openinference.project.name={phoenix_project_name}" - ) + os.environ[ + 
"OTEL_RESOURCE_ATTRIBUTES" + ] = f"{existing_attrs},openinference.project.name={phoenix_project_name}" else: - os.environ["OTEL_RESOURCE_ATTRIBUTES"] = ( - f"openinference.project.name={phoenix_project_name}" - ) + os.environ[ + "OTEL_RESOURCE_ATTRIBUTES" + ] = f"openinference.project.name={phoenix_project_name}" # auth can be disabled on local deployments of arize phoenix if arize_phoenix_config.otlp_auth_headers is not None: - os.environ["OTEL_EXPORTER_OTLP_TRACES_HEADERS"] = ( - arize_phoenix_config.otlp_auth_headers - ) + os.environ[ + "OTEL_EXPORTER_OTLP_TRACES_HEADERS" + ] = arize_phoenix_config.otlp_auth_headers for callback in _in_memory_loggers: if ( @@ -3907,7 +3907,9 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915 from litellm.integrations.focus.focus_logger import FocusLogger for callback in _in_memory_loggers: - if type(callback) is FocusLogger: # exact match; exclude subclasses like VantageLogger + if ( + type(callback) is FocusLogger + ): # exact match; exclude subclasses like VantageLogger return callback # type: ignore focus_logger = FocusLogger() _in_memory_loggers.append(focus_logger) @@ -4013,9 +4015,9 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915 exporter="otlp_http", endpoint="https://langtrace.ai/api/trace", ) - os.environ["OTEL_EXPORTER_OTLP_TRACES_HEADERS"] = ( - f"api_key={os.getenv('LANGTRACE_API_KEY')}" - ) + os.environ[ + "OTEL_EXPORTER_OTLP_TRACES_HEADERS" + ] = f"api_key={os.getenv('LANGTRACE_API_KEY')}" for callback in _in_memory_loggers: if ( isinstance(callback, OpenTelemetry) @@ -4289,7 +4291,9 @@ def get_custom_logger_compatible_class( # noqa: PLR0915 from litellm.integrations.focus.focus_logger import FocusLogger for callback in _in_memory_loggers: - if type(callback) is FocusLogger: # exact match; exclude subclasses like VantageLogger + if ( + type(callback) is FocusLogger + ): # exact match; exclude subclasses like VantageLogger return callback elif logging_integration == "vantage": from 
litellm.integrations.vantage.vantage_logger import VantageLogger @@ -4937,10 +4941,10 @@ def get_hidden_params( for key in StandardLoggingHiddenParams.__annotations__.keys(): if key in hidden_params: if key == "additional_headers": - clean_hidden_params["additional_headers"] = ( - StandardLoggingPayloadSetup.get_additional_headers( - hidden_params[key] - ) + clean_hidden_params[ + "additional_headers" + ] = StandardLoggingPayloadSetup.get_additional_headers( + hidden_params[key] ) else: clean_hidden_params[key] = hidden_params[key] # type: ignore @@ -5579,9 +5583,9 @@ def scrub_sensitive_keys_in_metadata(litellm_params: Optional[dict]): ): for k, v in metadata["user_api_key_metadata"].items(): if k == "logging": # prevent logging user logging keys - cleaned_user_api_key_metadata[k] = ( - "scrubbed_by_litellm_for_sensitive_keys" - ) + cleaned_user_api_key_metadata[ + k + ] = "scrubbed_by_litellm_for_sensitive_keys" else: cleaned_user_api_key_metadata[k] = v diff --git a/litellm/litellm_core_utils/prompt_templates/factory.py b/litellm/litellm_core_utils/prompt_templates/factory.py index 2b838ad1f80..f6004616712 100644 --- a/litellm/litellm_core_utils/prompt_templates/factory.py +++ b/litellm/litellm_core_utils/prompt_templates/factory.py @@ -2442,7 +2442,9 @@ def anthropic_messages_pt( # noqa: PLR0915 _document_content_element = cast( AnthropicMessagesDocumentParam, add_cache_control_to_content( - anthropic_content_element=cast(AnthropicMessagesDocumentParam, m), + anthropic_content_element=cast( + AnthropicMessagesDocumentParam, m + ), original_content_element=dict(m), ), ) @@ -2454,10 +2456,18 @@ def anthropic_messages_pt( # noqa: PLR0915 ) ) _file_content_element = add_cache_control_to_content( - anthropic_content_element=cast(AnthropicMessagesDocumentParam, _file_content_element), + anthropic_content_element=cast( + AnthropicMessagesDocumentParam, + _file_content_element, + ), original_content_element=dict(m), ) - 
user_content.append(cast(AnthropicMessagesDocumentParam,_file_content_element)) + user_content.append( + cast( + AnthropicMessagesDocumentParam, + _file_content_element, + ) + ) elif isinstance(user_message_types_block["content"], str): _anthropic_content_text_element: AnthropicMessagesTextParam = { "type": "text", diff --git a/litellm/llms/anthropic/experimental_pass_through/adapters/transformation.py b/litellm/llms/anthropic/experimental_pass_through/adapters/transformation.py index c1a6bd67501..3fda05172b6 100644 --- a/litellm/llms/anthropic/experimental_pass_through/adapters/transformation.py +++ b/litellm/llms/anthropic/experimental_pass_through/adapters/transformation.py @@ -780,7 +780,7 @@ def translate_anthropic_tools_to_openai( # Keep Anthropic-native tools in their original format new_tools.append(tool) # type: ignore[arg-type] continue - + original_name = tool["name"] truncated_name = truncate_tool_name(original_name) diff --git a/litellm/llms/base_llm/videos/transformation.py b/litellm/llms/base_llm/videos/transformation.py index a2892e20601..87289ad6a0c 100644 --- a/litellm/llms/base_llm/videos/transformation.py +++ b/litellm/llms/base_llm/videos/transformation.py @@ -336,9 +336,7 @@ def transform_video_edit_request( Returns: Tuple[str, Dict]: (url, data) for the POST request """ - raise NotImplementedError( - "video edit is not supported for this provider" - ) + raise NotImplementedError("video edit is not supported for this provider") def transform_video_edit_response( self, @@ -346,9 +344,7 @@ def transform_video_edit_response( logging_obj: LiteLLMLoggingObj, custom_llm_provider: Optional[str] = None, ) -> VideoObject: - raise NotImplementedError( - "video edit is not supported for this provider" - ) + raise NotImplementedError("video edit is not supported for this provider") def transform_video_extension_request( self, @@ -366,9 +362,7 @@ def transform_video_extension_request( Returns: Tuple[str, Dict]: (url, data) for the POST request """ - raise 
NotImplementedError( - "video extension is not supported for this provider" - ) + raise NotImplementedError("video extension is not supported for this provider") def transform_video_extension_response( self, @@ -376,9 +370,7 @@ def transform_video_extension_response( logging_obj: LiteLLMLoggingObj, custom_llm_provider: Optional[str] = None, ) -> VideoObject: - raise NotImplementedError( - "video extension is not supported for this provider" - ) + raise NotImplementedError("video extension is not supported for this provider") def get_error_class( self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers] diff --git a/litellm/llms/custom_httpx/llm_http_handler.py b/litellm/llms/custom_httpx/llm_http_handler.py index 204fa4d0cca..4c9abaad908 100644 --- a/litellm/llms/custom_httpx/llm_http_handler.py +++ b/litellm/llms/custom_httpx/llm_http_handler.py @@ -6162,7 +6162,10 @@ def video_create_character_handler( litellm_params=dict(litellm_params), ) - url, files_list = video_provider_config.transform_video_create_character_request( + ( + url, + files_list, + ) = video_provider_config.transform_video_create_character_request( name=name, video=video, api_base=api_base, @@ -6230,7 +6233,10 @@ async def async_video_create_character_handler( litellm_params=dict(litellm_params), ) - url, files_list = video_provider_config.transform_video_create_character_request( + ( + url, + files_list, + ) = video_provider_config.transform_video_create_character_request( name=name, video=video, api_base=api_base, @@ -6324,11 +6330,7 @@ def video_get_character_handler( ) try: - response = sync_httpx_client.get( - url=url, - headers=headers, - params=params - ) + response = sync_httpx_client.get(url=url, headers=headers, params=params) response.raise_for_status() return video_provider_config.transform_video_get_character_response( raw_response=response, @@ -6386,9 +6388,7 @@ async def async_video_get_character_handler( try: response = await async_httpx_client.get( - 
url=url, - headers=headers, - params=params + url=url, headers=headers, params=params ) response.raise_for_status() return video_provider_config.transform_video_get_character_response( diff --git a/litellm/llms/gemini/videos/transformation.py b/litellm/llms/gemini/videos/transformation.py index 0798472310e..122cc954836 100644 --- a/litellm/llms/gemini/videos/transformation.py +++ b/litellm/llms/gemini/videos/transformation.py @@ -525,28 +525,47 @@ def transform_video_delete_response( """Video delete is not supported.""" raise NotImplementedError("Video delete is not supported by Google Veo.") - def transform_video_create_character_request(self, name, video, api_base, litellm_params, headers): + def transform_video_create_character_request( + self, name, video, api_base, litellm_params, headers + ): raise NotImplementedError("video create character is not supported for Gemini") def transform_video_create_character_response(self, raw_response, logging_obj): raise NotImplementedError("video create character is not supported for Gemini") - def transform_video_get_character_request(self, character_id, api_base, litellm_params, headers): + def transform_video_get_character_request( + self, character_id, api_base, litellm_params, headers + ): raise NotImplementedError("video get character is not supported for Gemini") def transform_video_get_character_response(self, raw_response, logging_obj): raise NotImplementedError("video get character is not supported for Gemini") - def transform_video_edit_request(self, prompt, video_id, api_base, litellm_params, headers, extra_body=None): + def transform_video_edit_request( + self, prompt, video_id, api_base, litellm_params, headers, extra_body=None + ): raise NotImplementedError("video edit is not supported for Gemini") - def transform_video_edit_response(self, raw_response, logging_obj, custom_llm_provider=None): + def transform_video_edit_response( + self, raw_response, logging_obj, custom_llm_provider=None + ): raise 
NotImplementedError("video edit is not supported for Gemini") - def transform_video_extension_request(self, prompt, video_id, seconds, api_base, litellm_params, headers, extra_body=None): + def transform_video_extension_request( + self, + prompt, + video_id, + seconds, + api_base, + litellm_params, + headers, + extra_body=None, + ): raise NotImplementedError("video extension is not supported for Gemini") - def transform_video_extension_response(self, raw_response, logging_obj, custom_llm_provider=None): + def transform_video_extension_response( + self, raw_response, logging_obj, custom_llm_provider=None + ): raise NotImplementedError("video extension is not supported for Gemini") def get_error_class( diff --git a/litellm/llms/moonshot/chat/transformation.py b/litellm/llms/moonshot/chat/transformation.py index 40096be05c9..24f852c28ba 100644 --- a/litellm/llms/moonshot/chat/transformation.py +++ b/litellm/llms/moonshot/chat/transformation.py @@ -19,7 +19,8 @@ class MoonshotChatConfig(OpenAIGPTConfig): @overload def _transform_messages( self, messages: List[AllMessageValues], model: str, is_async: Literal[True] - ) -> Coroutine[Any, Any, List[AllMessageValues]]: ... + ) -> Coroutine[Any, Any, List[AllMessageValues]]: + ... @overload def _transform_messages( @@ -27,7 +28,8 @@ def _transform_messages( messages: List[AllMessageValues], model: str, is_async: Literal[False] = False, - ) -> List[AllMessageValues]: ... + ) -> List[AllMessageValues]: + ... 
def _transform_messages( self, messages: List[AllMessageValues], model: str, is_async: bool = False @@ -53,9 +55,13 @@ def _transform_messages( messages = handle_messages_with_content_list_to_str_conversion(messages) if is_async: - return super()._transform_messages(messages=messages, model=model, is_async=True) + return super()._transform_messages( + messages=messages, model=model, is_async=True + ) else: - return super()._transform_messages(messages=messages, model=model, is_async=False) + return super()._transform_messages( + messages=messages, model=model, is_async=False + ) def _get_openai_compatible_provider_info( self, api_base: Optional[str], api_key: Optional[str] @@ -141,7 +147,9 @@ def map_openai_params( optional_params["temperature"] = 0.3 return optional_params - def fill_reasoning_content(self, messages: List[AllMessageValues]) -> List[AllMessageValues]: + def fill_reasoning_content( + self, messages: List[AllMessageValues] + ) -> List[AllMessageValues]: """ Moonshot reasoning models require `reasoning_content` on every assistant message that contains tool_calls (multi-turn tool-calling flows). 
diff --git a/litellm/llms/runwayml/videos/transformation.py b/litellm/llms/runwayml/videos/transformation.py index 2c29c2e21ee..8377dea952e 100644 --- a/litellm/llms/runwayml/videos/transformation.py +++ b/litellm/llms/runwayml/videos/transformation.py @@ -592,28 +592,51 @@ def transform_video_status_retrieve_response( return video_obj - def transform_video_create_character_request(self, name, video, api_base, litellm_params, headers): - raise NotImplementedError("video create character is not supported for RunwayML") + def transform_video_create_character_request( + self, name, video, api_base, litellm_params, headers + ): + raise NotImplementedError( + "video create character is not supported for RunwayML" + ) def transform_video_create_character_response(self, raw_response, logging_obj): - raise NotImplementedError("video create character is not supported for RunwayML") + raise NotImplementedError( + "video create character is not supported for RunwayML" + ) - def transform_video_get_character_request(self, character_id, api_base, litellm_params, headers): + def transform_video_get_character_request( + self, character_id, api_base, litellm_params, headers + ): raise NotImplementedError("video get character is not supported for RunwayML") def transform_video_get_character_response(self, raw_response, logging_obj): raise NotImplementedError("video get character is not supported for RunwayML") - def transform_video_edit_request(self, prompt, video_id, api_base, litellm_params, headers, extra_body=None): + def transform_video_edit_request( + self, prompt, video_id, api_base, litellm_params, headers, extra_body=None + ): raise NotImplementedError("video edit is not supported for RunwayML") - def transform_video_edit_response(self, raw_response, logging_obj, custom_llm_provider=None): + def transform_video_edit_response( + self, raw_response, logging_obj, custom_llm_provider=None + ): raise NotImplementedError("video edit is not supported for RunwayML") - def 
transform_video_extension_request(self, prompt, video_id, seconds, api_base, litellm_params, headers, extra_body=None): + def transform_video_extension_request( + self, + prompt, + video_id, + seconds, + api_base, + litellm_params, + headers, + extra_body=None, + ): raise NotImplementedError("video extension is not supported for RunwayML") - def transform_video_extension_response(self, raw_response, logging_obj, custom_llm_provider=None): + def transform_video_extension_response( + self, raw_response, logging_obj, custom_llm_provider=None + ): raise NotImplementedError("video extension is not supported for RunwayML") def get_error_class( diff --git a/litellm/llms/sagemaker/chat/transformation.py b/litellm/llms/sagemaker/chat/transformation.py index 60e85c9f93b..3e42c1e8c15 100644 --- a/litellm/llms/sagemaker/chat/transformation.py +++ b/litellm/llms/sagemaker/chat/transformation.py @@ -184,9 +184,7 @@ async def get_async_custom_stream_wrapper( llm_provider = LlmProviders(custom_llm_provider) except ValueError: llm_provider = LlmProviders.SAGEMAKER_CHAT - client = get_async_httpx_client( - llm_provider=llm_provider, params={} - ) + client = get_async_httpx_client(llm_provider=llm_provider, params={}) try: response = await client.post( diff --git a/litellm/llms/vertex_ai/batches/transformation.py b/litellm/llms/vertex_ai/batches/transformation.py index 86bdc2c7b5f..c1144654908 100644 --- a/litellm/llms/vertex_ai/batches/transformation.py +++ b/litellm/llms/vertex_ai/batches/transformation.py @@ -142,8 +142,8 @@ def _get_output_file_id_from_vertex_ai_batch_response( Gets the output file id from the Vertex AI Batch response """ - output_file_id: str = ( - response.get("outputInfo", OutputInfo()).get("gcsOutputDirectory", "") + output_file_id: str = response.get("outputInfo", OutputInfo()).get( + "gcsOutputDirectory", "" ) if output_file_id: output_file_id = output_file_id.rstrip("/") + "/predictions.jsonl" diff --git a/litellm/llms/vertex_ai/videos/transformation.py 
b/litellm/llms/vertex_ai/videos/transformation.py index 07b3d6faf70..1c24d657c16 100644 --- a/litellm/llms/vertex_ai/videos/transformation.py +++ b/litellm/llms/vertex_ai/videos/transformation.py @@ -624,28 +624,51 @@ def transform_video_delete_response( """Video delete is not supported.""" raise NotImplementedError("Video delete is not supported by Vertex AI Veo.") - def transform_video_create_character_request(self, name, video, api_base, litellm_params, headers): - raise NotImplementedError("video create character is not supported for Vertex AI") + def transform_video_create_character_request( + self, name, video, api_base, litellm_params, headers + ): + raise NotImplementedError( + "video create character is not supported for Vertex AI" + ) def transform_video_create_character_response(self, raw_response, logging_obj): - raise NotImplementedError("video create character is not supported for Vertex AI") + raise NotImplementedError( + "video create character is not supported for Vertex AI" + ) - def transform_video_get_character_request(self, character_id, api_base, litellm_params, headers): + def transform_video_get_character_request( + self, character_id, api_base, litellm_params, headers + ): raise NotImplementedError("video get character is not supported for Vertex AI") def transform_video_get_character_response(self, raw_response, logging_obj): raise NotImplementedError("video get character is not supported for Vertex AI") - def transform_video_edit_request(self, prompt, video_id, api_base, litellm_params, headers, extra_body=None): + def transform_video_edit_request( + self, prompt, video_id, api_base, litellm_params, headers, extra_body=None + ): raise NotImplementedError("video edit is not supported for Vertex AI") - def transform_video_edit_response(self, raw_response, logging_obj, custom_llm_provider=None): + def transform_video_edit_response( + self, raw_response, logging_obj, custom_llm_provider=None + ): raise NotImplementedError("video edit is not 
supported for Vertex AI") - def transform_video_extension_request(self, prompt, video_id, seconds, api_base, litellm_params, headers, extra_body=None): + def transform_video_extension_request( + self, + prompt, + video_id, + seconds, + api_base, + litellm_params, + headers, + extra_body=None, + ): raise NotImplementedError("video extension is not supported for Vertex AI") - def transform_video_extension_response(self, raw_response, logging_obj, custom_llm_provider=None): + def transform_video_extension_response( + self, raw_response, logging_obj, custom_llm_provider=None + ): raise NotImplementedError("video extension is not supported for Vertex AI") def get_error_class( diff --git a/litellm/main.py b/litellm/main.py index 81319bc432f..112fef44e55 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -7533,9 +7533,7 @@ def stream_chunk_builder( # noqa: PLR0915 # the final chunk. all_annotations: list = [] for ac in annotation_chunks: - all_annotations.extend( - ac["choices"][0]["delta"]["annotations"] - ) + all_annotations.extend(ac["choices"][0]["delta"]["annotations"]) response["choices"][0]["message"]["annotations"] = all_annotations audio_chunks = [ diff --git a/litellm/model_prices_and_context_window_backup.json b/litellm/model_prices_and_context_window_backup.json index 6786fc33595..181045809f8 100644 --- a/litellm/model_prices_and_context_window_backup.json +++ b/litellm/model_prices_and_context_window_backup.json @@ -32354,6 +32354,53 @@ "supports_vision": true, "supports_web_search": true }, + "xai/grok-4.20-multi-agent-beta-0309": { + "cache_read_input_token_cost": 2e-07, + "input_cost_per_token": 2e-06, + "litellm_provider": "xai", + "max_input_tokens": 2000000, + "max_output_tokens": 2000000, + "max_tokens": 2000000, + "mode": "chat", + "output_cost_per_token": 6e-06, + "source": "https://docs.x.ai/docs/models", + "supports_function_calling": true, + "supports_reasoning": true, + "supports_tool_choice": true, + "supports_vision": true, + 
"supports_web_search": true + }, + "xai/grok-4.20-beta-0309-reasoning": { + "cache_read_input_token_cost": 2e-07, + "input_cost_per_token": 2e-06, + "litellm_provider": "xai", + "max_input_tokens": 2000000, + "max_output_tokens": 2000000, + "max_tokens": 2000000, + "mode": "chat", + "output_cost_per_token": 6e-06, + "source": "https://docs.x.ai/docs/models", + "supports_function_calling": true, + "supports_reasoning": true, + "supports_tool_choice": true, + "supports_vision": true, + "supports_web_search": true + }, + "xai/grok-4.20-beta-0309-non-reasoning": { + "cache_read_input_token_cost": 2e-07, + "input_cost_per_token": 2e-06, + "litellm_provider": "xai", + "max_input_tokens": 2000000, + "max_output_tokens": 2000000, + "max_tokens": 2000000, + "mode": "chat", + "output_cost_per_token": 6e-06, + "source": "https://docs.x.ai/docs/models", + "supports_function_calling": true, + "supports_tool_choice": true, + "supports_vision": true, + "supports_web_search": true + }, "xai/grok-beta": { "input_cost_per_token": 5e-06, "litellm_provider": "xai", diff --git a/litellm/proxy/_experimental/mcp_server/discoverable_endpoints.py b/litellm/proxy/_experimental/mcp_server/discoverable_endpoints.py index af3a715051b..3385e7feef6 100644 --- a/litellm/proxy/_experimental/mcp_server/discoverable_endpoints.py +++ b/litellm/proxy/_experimental/mcp_server/discoverable_endpoints.py @@ -677,7 +677,60 @@ async def oauth_authorization_server_mcp( # Alias for standard OpenID discovery @router.get("/.well-known/openid-configuration") async def openid_configuration(request: Request): - return await oauth_authorization_server_mcp(request) + response = await oauth_authorization_server_mcp(request) + + # If MCPJWTSigner is active, augment the discovery doc with JWKS fields so + # MCP servers and gateways (e.g. AWS Bedrock AgentCore Gateway) can resolve + # the signing keys and verify liteLLM-issued tokens. 
+ try: + from litellm.proxy.guardrails.guardrail_hooks.mcp_jwt_signer.mcp_jwt_signer import ( + get_mcp_jwt_signer, + ) + + signer = get_mcp_jwt_signer() + if signer is not None: + request_base_url = get_request_base_url(request) + if isinstance(response, dict): + response = { + **response, + "jwks_uri": f"{request_base_url}/.well-known/jwks.json", + "id_token_signing_alg_values_supported": ["RS256"], + } + except ImportError: + pass + + return response + + +@router.get("/.well-known/jwks.json") +async def jwks_json(request: Request): + """ + JSON Web Key Set endpoint. + + Returns the RSA public key used by MCPJWTSigner to sign outbound MCP tokens. + MCP servers and gateways use this endpoint to verify liteLLM-issued JWTs. + + Returns an empty key set if MCPJWTSigner is not configured. + """ + try: + from litellm.proxy.guardrails.guardrail_hooks.mcp_jwt_signer.mcp_jwt_signer import ( + get_mcp_jwt_signer, + ) + + signer = get_mcp_jwt_signer() + if signer is not None: + return JSONResponse( + content=signer.get_jwks(), + headers={"Cache-Control": f"public, max-age={signer.jwks_max_age}"}, + ) + except ImportError: + pass + + # No signer active — return empty key set; short cache so activation is picked up quickly. + return JSONResponse( + content={"keys": []}, + headers={"Cache-Control": "public, max-age=60"}, + ) # Additional legacy pattern support diff --git a/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py b/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py index 43fe54fdfb7..1e9d5c5a529 100644 --- a/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py +++ b/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py @@ -1908,7 +1908,15 @@ async def pre_call_tool_check( user_api_key_auth: Optional[UserAPIKeyAuth], proxy_logging_obj: ProxyLogging, server: MCPServer, - ): + raw_headers: Optional[Dict[str, str]] = None, + ) -> Dict[str, Any]: + """ + Run pre-call checks and guardrail hooks for an MCP tool call. 
+ + Returns a dict that may contain: + - "arguments": hook-modified tool arguments (only if changed) + - "extra_headers": headers injected by pre_mcp_call guardrail hooks + """ ## check if the tool is allowed or banned for the given server if not self.check_allowed_or_banned_tools(name, server): raise HTTPException( @@ -1932,6 +1940,14 @@ async def pre_call_tool_check( server=server, ) + # Extract incoming Bearer token from raw request headers so + # guardrails like MCPJWTSigner can verify + re-sign it (FR-5). + normalized_raw = {k.lower(): v for k, v in (raw_headers or {}).items()} + incoming_bearer_token: Optional[str] = None + auth_hdr = normalized_raw.get("authorization", "") + if auth_hdr.lower().startswith("bearer "): + incoming_bearer_token = auth_hdr[len("bearer ") :] + pre_hook_kwargs = { "name": name, "arguments": arguments, @@ -1957,6 +1973,7 @@ async def pre_call_tool_check( if user_api_key_auth else None ), + "incoming_bearer_token": incoming_bearer_token, } # Create MCP request object for processing @@ -1969,6 +1986,7 @@ async def pre_call_tool_check( mcp_request_obj, pre_hook_kwargs ) + hook_result: Dict[str, Any] = {} try: # Use standard pre_call_hook modified_data = await proxy_logging_obj.pre_call_hook( @@ -1984,7 +2002,9 @@ async def pre_call_tool_check( ) ) if modified_kwargs.get("arguments") != arguments: - arguments = modified_kwargs["arguments"] + hook_result["arguments"] = modified_kwargs["arguments"] + if modified_kwargs.get("extra_headers"): + hook_result["extra_headers"] = modified_kwargs["extra_headers"] except ( BlockedPiiEntityError, @@ -1995,6 +2015,8 @@ async def pre_call_tool_check( verbose_logger.error(f"Guardrail blocked MCP tool call pre call: {str(e)}") raise e + return hook_result + def _create_during_hook_task( self, name: str, @@ -2047,6 +2069,7 @@ async def _call_regular_mcp_tool( raw_headers: Optional[Dict[str, str]], proxy_logging_obj: Optional[ProxyLogging], host_progress_callback: Optional[Callable] = None, + 
hook_extra_headers: Optional[Dict[str, str]] = None, ) -> CallToolResult: """ Call a regular MCP tool using the MCP client. @@ -2061,6 +2084,9 @@ async def _call_regular_mcp_tool( oauth2_headers: Optional OAuth2 headers raw_headers: Optional raw headers from the request proxy_logging_obj: Optional ProxyLogging object for hook integration + host_progress_callback: Optional callback for progress updates + hook_extra_headers: Optional headers injected by pre_mcp_call guardrail + hooks. Merged last (highest priority) into outbound request headers. Returns: CallToolResult from the MCP server @@ -2116,6 +2142,31 @@ async def _call_regular_mcp_tool( extra_headers = {} extra_headers.update(mcp_server.static_headers) + if hook_extra_headers: + if extra_headers is None: + extra_headers = {} + if "Authorization" in hook_extra_headers: + if "Authorization" in extra_headers: + verbose_logger.warning( + "MCPServerManager: hook_extra_headers 'Authorization' will overwrite " + "the existing Authorization header from static_headers. " + "The hook JWT will take precedence." + ) + elif server_auth_header is not None: + # server_auth_header is passed separately to _create_mcp_client as + # auth_value. Both will reach the upstream server — warn so admins + # know two Authorization credentials are being sent. + verbose_logger.warning( + "MCPServerManager: hook_extra_headers injects 'Authorization' while " + "server '%s' already has a configured authentication_token. " + "Both credentials will be sent; the hook header is in extra_headers " + "and the server token is in auth_value — the upstream server decides " + "which one wins. 
Consider unsetting authentication_token if you want " + "the hook JWT to be the sole credential.", + mcp_server.server_name or mcp_server.name, + ) + extra_headers.update(hook_extra_headers) + stdio_env = self._build_stdio_env(mcp_server, raw_headers) client = await self._create_mcp_client( @@ -2201,15 +2252,19 @@ async def call_tool( # Allow validation and modification of tool calls before execution # Using standard pre_call_hook ######################################################### + hook_result: Dict[str, Any] = {} if proxy_logging_obj: - await self.pre_call_tool_check( + hook_result = await self.pre_call_tool_check( name=name, arguments=arguments, server_name=server_name, user_api_key_auth=user_api_key_auth, proxy_logging_obj=proxy_logging_obj, server=mcp_server, + raw_headers=raw_headers, ) + if "arguments" in hook_result: + arguments = hook_result["arguments"] # Prepare tasks for during hooks tasks = [] @@ -2227,8 +2282,16 @@ async def call_tool( # For OpenAPI servers, call the tool handler directly instead of via MCP client if mcp_server.spec_path: verbose_logger.debug( - f"Calling OpenAPI tool {name} directly via HTTP handler" + "Calling OpenAPI tool %s directly via HTTP handler", name ) + if hook_result.get("extra_headers"): + verbose_logger.warning( + "pre_mcp_call hook returned extra_headers for OpenAPI-backed " + "MCP server '%s' — header injection is not supported for " + "OpenAPI servers; headers will be ignored. 
Use SSE/HTTP " + "transport to enable hook header injection.", + server_name, + ) tasks.append( asyncio.create_task( self._call_openapi_tool_handler(mcp_server, name, arguments) @@ -2247,6 +2310,7 @@ async def call_tool( raw_headers=raw_headers, proxy_logging_obj=proxy_logging_obj, host_progress_callback=host_progress_callback, + hook_extra_headers=hook_result.get("extra_headers"), ) # For OpenAPI tools, await outside the client context diff --git a/litellm/proxy/_experimental/mcp_server/rest_endpoints.py b/litellm/proxy/_experimental/mcp_server/rest_endpoints.py index ef01f027d6f..c0151d47e04 100644 --- a/litellm/proxy/_experimental/mcp_server/rest_endpoints.py +++ b/litellm/proxy/_experimental/mcp_server/rest_endpoints.py @@ -903,12 +903,12 @@ async def _execute_with_mcp_client( try: client_id, client_secret, scopes = _extract_credentials(request) - _oauth2_flow: Optional[Literal["client_credentials", "authorization_code"]] = ( - request.oauth2_flow or ( - "client_credentials" - if client_id and client_secret and request.token_url - else None - ) + _oauth2_flow: Optional[ + Literal["client_credentials", "authorization_code"] + ] = request.oauth2_flow or ( + "client_credentials" + if client_id and client_secret and request.token_url + else None ) # client_credentials requires token_url to fetch a token; without it the # incoming auth header would be dropped with nothing to replace it. diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index ecbd7314cd7..9e86680e355 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -2471,6 +2471,9 @@ class UserAPIKeyAuth( Any ] = None # Expanded created_by user when expand=user is used end_user_object_permission: Optional[LiteLLM_ObjectPermissionTable] = None + # Decoded upstream IdP claims (groups, roles, etc.) propagated by JWT auth machinery + # and forwarded into outbound tokens by guardrails such as MCPJWTSigner. 
+ jwt_claims: Optional[Dict] = None model_config = ConfigDict(arbitrary_types_allowed=True) diff --git a/litellm/proxy/auth/auth_utils.py b/litellm/proxy/auth/auth_utils.py index 0d3c627446b..a03e1fb94c1 100644 --- a/litellm/proxy/auth/auth_utils.py +++ b/litellm/proxy/auth/auth_utils.py @@ -680,7 +680,7 @@ def get_customer_user_header_from_mapping(user_id_mapping) -> Optional[list]: if customer_headers_mappings: return customer_headers_mappings - + return None @@ -754,15 +754,11 @@ def get_end_user_id_from_request_body( user_id_str = str(header_value) if user_id_str.strip(): return user_id_str - + elif isinstance(custom_header_name_to_check, str): for header_name, header_value in request_headers.items(): if header_name.lower() == custom_header_name_to_check.lower(): - user_id_str = ( - str(header_value) - if header_value is not None - else "" - ) + user_id_str = str(header_value) if header_value is not None else "" if user_id_str.strip(): return user_id_str diff --git a/litellm/proxy/auth/user_api_key_auth.py b/litellm/proxy/auth/user_api_key_auth.py index 376048e7a13..044333ac134 100644 --- a/litellm/proxy/auth/user_api_key_auth.py +++ b/litellm/proxy/auth/user_api_key_auth.py @@ -685,6 +685,7 @@ async def _user_api_key_auth_builder( # noqa: PLR0915 do_standard_jwt_auth = True if jwt_handler.litellm_jwtauth.virtual_key_claim_field is not None: # Decode JWT to get claims without running full auth_builder + jwt_claims: Optional[dict] if jwt_handler.litellm_jwtauth.oidc_userinfo_enabled: jwt_claims = await jwt_handler.get_oidc_userinfo(token=api_key) else: @@ -700,6 +701,7 @@ async def _user_api_key_auth_builder( # noqa: PLR0915 ) if valid_token is not None: api_key = valid_token.token or "" + valid_token.jwt_claims = jwt_claims do_standard_jwt_auth = False # Fall through to virtual key checks @@ -729,6 +731,7 @@ async def _user_api_key_auth_builder( # noqa: PLR0915 team_membership: Optional[LiteLLM_TeamMembership] = result.get( "team_membership", None ) + 
jwt_claims = result.get("jwt_claims", None) global_proxy_spend = await get_global_proxy_spend( litellm_proxy_admin_name=litellm_proxy_admin_name, @@ -757,6 +760,7 @@ async def _user_api_key_auth_builder( # noqa: PLR0915 org_id=org_id, end_user_id=end_user_id, parent_otel_span=parent_otel_span, + jwt_claims=jwt_claims, ) valid_token = UserAPIKeyAuth( @@ -803,6 +807,7 @@ async def _user_api_key_auth_builder( # noqa: PLR0915 team_metadata=( team_object.metadata if team_object is not None else None ), + jwt_claims=jwt_claims, ) # Check if model has zero cost - if so, skip all budget checks diff --git a/litellm/proxy/batches_endpoints/endpoints.py b/litellm/proxy/batches_endpoints/endpoints.py index 740e63b7f17..38e5229eee1 100644 --- a/litellm/proxy/batches_endpoints/endpoints.py +++ b/litellm/proxy/batches_endpoints/endpoints.py @@ -537,9 +537,10 @@ async def retrieve_batch( # noqa: PLR0915 ) # Fix: bug_feb14_batch_retrieve_returns_raw_input_file_id - # Resolve raw provider input_file_id to unified ID. + # Resolve raw provider file IDs (input, output, error) to unified IDs. 
if unified_batch_id: await resolve_input_file_id_to_unified(response, prisma_client) + await resolve_output_file_ids_to_unified(response, prisma_client) ### ALERTING ### asyncio.create_task( diff --git a/litellm/proxy/guardrails/guardrail_hooks/mcp_jwt_signer/__init__.py b/litellm/proxy/guardrails/guardrail_hooks/mcp_jwt_signer/__init__.py new file mode 100644 index 00000000000..abea9014a11 --- /dev/null +++ b/litellm/proxy/guardrails/guardrail_hooks/mcp_jwt_signer/__init__.py @@ -0,0 +1,84 @@ +"""MCP JWT Signer guardrail — built-in LiteLLM guardrail for zero trust MCP auth.""" + +from typing import TYPE_CHECKING + +from litellm.types.guardrails import SupportedGuardrailIntegrations + +from .mcp_jwt_signer import MCPJWTSigner, get_mcp_jwt_signer + +if TYPE_CHECKING: + from litellm.types.guardrails import Guardrail, LitellmParams + + +def initialize_guardrail( + litellm_params: "LitellmParams", guardrail: "Guardrail" +) -> MCPJWTSigner: + import litellm + + guardrail_name = guardrail.get("guardrail_name") + if not guardrail_name: + raise ValueError("MCPJWTSigner guardrail requires a guardrail_name") + + mode = litellm_params.mode + if mode != "pre_mcp_call": + raise ValueError( + f"MCPJWTSigner guardrail '{guardrail_name}' has mode='{mode}' but must use " + "mode='pre_mcp_call'. JWT injection only fires for MCP tool calls." 
+ ) + + optional_params = getattr(litellm_params, "optional_params", None) + + def _get(key): # type: ignore[no-untyped-def] + if optional_params is not None: + v = getattr(optional_params, key, None) + if v is not None: + return v + return getattr(litellm_params, key, None) + + signer = MCPJWTSigner( + guardrail_name=guardrail_name, + event_hook=litellm_params.mode, + default_on=litellm_params.default_on, + # Core signing + issuer=_get("issuer"), + audience=_get("audience"), + ttl_seconds=_get("ttl_seconds"), + # FR-5: verify + re-sign + access_token_discovery_uri=_get("access_token_discovery_uri"), + token_introspection_endpoint=_get("token_introspection_endpoint"), + verify_issuer=_get("verify_issuer"), + verify_audience=_get("verify_audience"), + # FR-12: end-user identity mapping + end_user_claim_sources=_get("end_user_claim_sources"), + # FR-13: claim operations + add_claims=_get("add_claims"), + set_claims=_get("set_claims"), + remove_claims=_get("remove_claims"), + # FR-14: two-token model + channel_token_audience=_get("channel_token_audience"), + channel_token_ttl=_get("channel_token_ttl"), + # FR-15: incoming claim validation + required_claims=_get("required_claims"), + optional_claims=_get("optional_claims"), + # FR-9: debug headers + debug_headers=_get("debug_headers") or False, + # FR-10: configurable scopes + allowed_scopes=_get("allowed_scopes"), + ) + litellm.logging_callback_manager.add_litellm_callback(signer) + return signer + + +guardrail_initializer_registry = { + SupportedGuardrailIntegrations.MCP_JWT_SIGNER.value: initialize_guardrail, +} + +guardrail_class_registry = { + SupportedGuardrailIntegrations.MCP_JWT_SIGNER.value: MCPJWTSigner, +} + +__all__ = [ + "MCPJWTSigner", + "initialize_guardrail", + "get_mcp_jwt_signer", +] diff --git a/litellm/proxy/guardrails/guardrail_hooks/mcp_jwt_signer/mcp_jwt_signer.py b/litellm/proxy/guardrails/guardrail_hooks/mcp_jwt_signer/mcp_jwt_signer.py new file mode 100644 index 00000000000..5502076829f --- 
/dev/null +++ b/litellm/proxy/guardrails/guardrail_hooks/mcp_jwt_signer/mcp_jwt_signer.py @@ -0,0 +1,891 @@ +""" +MCPJWTSigner — Built-in LiteLLM guardrail for zero trust MCP authentication. + +Signs outbound MCP requests with a LiteLLM-issued RS256 JWT so that MCP servers +can trust a single signing authority (liteLLM) instead of every upstream IdP. + +Usage in config.yaml: + + guardrails: + - guardrail_name: "mcp-jwt-signer" + litellm_params: + guardrail: mcp_jwt_signer + mode: "pre_mcp_call" + default_on: true + + # Core signing config + issuer: "https://my-litellm.example.com" # optional + audience: "mcp" # optional + ttl_seconds: 300 # optional + + # FR-5: Verify + re-sign — validate incoming Bearer token before signing + access_token_discovery_uri: "https://idp.example.com/.well-known/openid-configuration" + token_introspection_endpoint: "https://idp.example.com/introspect" # opaque tokens + verify_issuer: "https://idp.example.com" # expected iss in incoming JWT + verify_audience: "api://my-app" # expected aud in incoming JWT + + # FR-12: End-user identity mapping — ordered resolution chain + # Supported: token:, litellm:user_id, litellm:email, + # litellm:end_user_id, litellm:team_id + end_user_claim_sources: + - "token:sub" + - "token:email" + - "litellm:user_id" + + # FR-13: Claim operations + add_claims: # add if key not already present in the JWT + deployment_id: "prod-001" + set_claims: # always set (overrides computed value) + env: "production" + remove_claims: # remove from final JWT + - "nbf" + + # FR-14: Two-token model — issue a second JWT for the MCP transport channel + channel_token_audience: "bedrock-gateway" + channel_token_ttl: 60 + + # FR-15: Incoming claim validation — enforce required IdP claims + required_claims: + - "sub" + - "email" + optional_claims: # pass through from jwt_claims into outbound JWT + - "groups" + - "roles" + + # FR-9: Debug headers + debug_headers: false # emit x-litellm-mcp-debug header when true + + # FR-10: 
import base64
import hashlib
import os
import re
import time
from typing import Any, Dict, List, Optional, Union

import jwt
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey

from litellm._logging import verbose_proxy_logger
from litellm.caching import DualCache
from litellm.integrations.custom_guardrail import (
    CustomGuardrail,
    log_guardrail_information,
)
from litellm.proxy._types import UserAPIKeyAuth
from litellm.types.utils import CallTypesLiteral

# Module-level singleton so the JWKS / OIDC discovery endpoints can reach the
# active signer without threading it through the request path.
_mcp_jwt_signer_instance: Optional["MCPJWTSigner"] = None

# Simple in-memory JWKS cache: keyed by JWKS URI → (keys_list, fetched_at).
_jwks_cache: Dict[str, tuple] = {}
_JWKS_CACHE_TTL = 3600  # 1 hour


def get_mcp_jwt_signer() -> Optional["MCPJWTSigner"]:
    """Return the active MCPJWTSigner singleton, or None if not initialized."""
    return _mcp_jwt_signer_instance


def _load_private_key_from_env(env_var: str) -> RSAPrivateKey:
    """Load an RSA private key from an env var (PEM string or file:// path).

    Raises:
        ValueError: if the env var is unset or empty.
        TypeError: if the PEM material is not an RSA private key.
    """
    key_material = os.environ.get(env_var, "")
    if not key_material:
        # This branch is also hit when the variable is not set at all, not
        # just when it is set to an empty string — say so in the error.
        raise ValueError(
            f"MCPJWTSigner: environment variable '{env_var}' is not set or is empty."
        )
    if key_material.startswith("file://"):
        path = key_material[len("file://") :]
        with open(path, "rb") as f:
            key_bytes = f.read()
    else:
        key_bytes = key_material.encode("utf-8")
    private_key = serialization.load_pem_private_key(key_bytes, password=None)
    if not isinstance(private_key, RSAPrivateKey):
        # RS256 signing requires an RSA key; fail fast with a clear error
        # instead of a confusing downstream jwt.encode failure.
        raise TypeError(
            f"MCPJWTSigner: key in '{env_var}' must be an RSA private key, "
            f"got {type(private_key).__name__}."
        )
    return private_key


def _generate_rsa_key_pair() -> RSAPrivateKey:
    """Generate a new RSA-2048 private key (ephemeral, per-process)."""
    return rsa.generate_private_key(
        public_exponent=65537,
        key_size=2048,
    )


def _int_to_base64url(n: int) -> str:
    """Encode an integer as a base64url string (no padding), per RFC 7518 JWK."""
    byte_length = (n.bit_length() + 7) // 8
    return (
        base64.urlsafe_b64encode(n.to_bytes(byte_length, byteorder="big"))
        .rstrip(b"=")
        .decode("ascii")
    )


def _compute_kid(public_key: Any) -> str:
    """Derive a key ID from the public key's DER encoding (SHA-256, first 16 hex chars)."""
    der_bytes = public_key.public_bytes(
        encoding=serialization.Encoding.DER,
        format=serialization.PublicFormat.SubjectPublicKeyInfo,
    )
    return hashlib.sha256(der_bytes).hexdigest()[:16]


async def _fetch_jwks(jwks_uri: str) -> List[Dict[str, Any]]:
    """
    Fetch and cache a JWKS from the given URI.

    Results are cached for _JWKS_CACHE_TTL seconds to avoid hammering the IdP.
    """
    now = time.time()
    cached = _jwks_cache.get(jwks_uri)
    if cached is not None:
        keys, fetched_at = cached
        if now - fetched_at < _JWKS_CACHE_TTL:
            return keys  # type: ignore[return-value]

    # Lazy import to avoid a hard dependency at module import time.
    from litellm.llms.custom_httpx.http_handler import (
        get_async_httpx_client,
        httpxSpecialProvider,
    )

    client = get_async_httpx_client(llm_provider=httpxSpecialProvider.Oauth2Check)
    resp = await client.get(jwks_uri, headers={"Accept": "application/json"})
    resp.raise_for_status()
    keys = resp.json().get("keys", [])
    _jwks_cache[jwks_uri] = (keys, now)
    return keys  # type: ignore[return-value]
async def _fetch_oidc_discovery(discovery_uri: str) -> Dict[str, Any]:
    """Fetch an OIDC discovery document and return its parsed JSON body."""
    from litellm.llms.custom_httpx.http_handler import (
        get_async_httpx_client,
        httpxSpecialProvider,
    )

    http_client = get_async_httpx_client(
        llm_provider=httpxSpecialProvider.Oauth2Check
    )
    response = await http_client.get(
        discovery_uri, headers={"Accept": "application/json"}
    )
    response.raise_for_status()
    return response.json()  # type: ignore[return-value]


class MCPJWTSigner(CustomGuardrail):
    """
    Guardrail that re-signs outbound MCP requests with a LiteLLM-issued
    RS256 JWT, so MCP servers only need to trust LiteLLM's signing key
    (zero trust) rather than every upstream IdP directly.

    Verification endpoints exposed by the proxy:
      GET /.well-known/openid-configuration → points at jwks_uri
      GET /.well-known/jwks.json            → RSA public key (JWKS)

    Claims carried by the signed token:
      iss / aud / iat / exp / nbf — standard timing and audience claims
      sub   — end-user identity, resolved via end_user_claim_sources (RFC 8693)
      act   — actor/agent identity (team_id or org_id, RFC 8693 delegation)
      scope — tool-level access scopes (auto-generated or allowed_scopes)

    Feature map:
      FR-5  verify + re-sign    (access_token_discovery_uri, token_introspection_endpoint)
      FR-9  debug headers       (debug_headers)
      FR-10 configurable scopes (allowed_scopes)
      FR-12 identity mapping    (end_user_claim_sources)
      FR-13 claim operations    (add_claims, set_claims, remove_claims)
      FR-14 two-token model     (channel_token_audience, channel_token_ttl)
      FR-15 claim validation    (required_claims, optional_claims)
    """

    ALGORITHM = "RS256"
    DEFAULT_TTL = 300
    DEFAULT_AUDIENCE = "mcp"
    SIGNING_KEY_ENV = "MCP_JWT_SIGNING_KEY"
super().__init__(**kwargs) + + # --- Signing key setup --- + key_material = os.environ.get(self.SIGNING_KEY_ENV) + if key_material: + self._private_key = _load_private_key_from_env(self.SIGNING_KEY_ENV) + self._persistent_key: bool = True + verbose_proxy_logger.info( + "MCPJWTSigner: loaded RSA key from env var %s", self.SIGNING_KEY_ENV + ) + else: + self._private_key = _generate_rsa_key_pair() + self._persistent_key = False + verbose_proxy_logger.info( + "MCPJWTSigner: auto-generated RSA-2048 keypair (set %s to use your own key)", + self.SIGNING_KEY_ENV, + ) + + self._public_key = self._private_key.public_key() + self._kid = _compute_kid(self._public_key) + + # --- Core config --- + self.issuer: str = ( + issuer + or os.environ.get("MCP_JWT_ISSUER") + or os.environ.get("LITELLM_EXTERNAL_URL") + or "litellm" + ) + self.audience: str = ( + audience or os.environ.get("MCP_JWT_AUDIENCE") or self.DEFAULT_AUDIENCE + ) + resolved_ttl = int( + ttl_seconds + if ttl_seconds is not None + else os.environ.get("MCP_JWT_TTL_SECONDS", str(self.DEFAULT_TTL)) + ) + if resolved_ttl <= 0: + raise ValueError( + f"MCPJWTSigner: ttl_seconds must be > 0, got {resolved_ttl}" + ) + self.ttl_seconds: int = resolved_ttl + + # --- FR-5: Verify + re-sign --- + self.access_token_discovery_uri: Optional[str] = access_token_discovery_uri + self.token_introspection_endpoint: Optional[str] = token_introspection_endpoint + self.verify_issuer: Optional[str] = verify_issuer + self.verify_audience: Optional[str] = verify_audience + # Cached OIDC discovery document (fetched lazily, TTL = 24 h) + self._oidc_discovery_doc: Optional[Dict[str, Any]] = None + self._oidc_discovery_fetched_at: float = 0.0 + + # --- FR-12: End-user identity mapping --- + # Default chain: try incoming JWT sub, fall back to litellm user_id + self.end_user_claim_sources: List[str] = end_user_claim_sources or [ + "token:sub", + "litellm:user_id", + ] + + # --- FR-13: Claim operations --- + self.add_claims: Dict[str, Any] = 
add_claims or {} + self.set_claims: Dict[str, Any] = set_claims or {} + self.remove_claims: List[str] = remove_claims or [] + + # --- FR-14: Two-token model --- + self.channel_token_audience: Optional[str] = channel_token_audience + self.channel_token_ttl: int = ( + channel_token_ttl if channel_token_ttl is not None else self.ttl_seconds + ) + + # --- FR-15: Incoming claim validation --- + self.required_claims: List[str] = required_claims or [] + self.optional_claims: List[str] = optional_claims or [] + + # --- FR-9: Debug headers --- + self.debug_headers: bool = debug_headers + + # --- FR-10: Configurable scopes --- + self.allowed_scopes: Optional[List[str]] = allowed_scopes + + # Register singleton for JWKS/OIDC discovery endpoints. + global _mcp_jwt_signer_instance + if _mcp_jwt_signer_instance is not None: + verbose_proxy_logger.warning( + "MCPJWTSigner: replacing existing singleton — previously issued tokens " + "signed with the old key will fail JWKS verification. " + "Avoid configuring multiple mcp_jwt_signer guardrails." + ) + _mcp_jwt_signer_instance = self + + verbose_proxy_logger.info( + "MCPJWTSigner initialized: issuer=%s audience=%s ttl=%ds kid=%s " + "verify=%s channel_token=%s debug=%s", + self.issuer, + self.audience, + self.ttl_seconds, + self._kid, + bool(self.access_token_discovery_uri), + bool(self.channel_token_audience), + self.debug_headers, + ) + + # ------------------------------------------------------------------ + # Public helpers (used by /.well-known/jwks.json endpoint) + # ------------------------------------------------------------------ + + @property + def jwks_max_age(self) -> int: + """ + Recommended Cache-Control max-age for the JWKS response (seconds). + + 1 hour for persistent keys; 5 minutes for auto-generated keys so MCP + servers re-fetch quickly after a proxy restart. + """ + return 3600 if self._persistent_key else 300 + + def get_jwks(self) -> Dict[str, Any]: + """ + Return the JWKS for the RSA public key. 
    # ------------------------------------------------------------------
    # FR-5: Verify + re-sign helpers
    # ------------------------------------------------------------------

    # 24-hour TTL for the OIDC discovery doc — long enough to avoid hammering
    # the IdP, short enough to pick up jwks_uri changes after key rotation.
    _OIDC_DISCOVERY_TTL = 86400

    async def _get_oidc_discovery(self) -> Dict[str, Any]:
        """Fetch and cache the OIDC discovery document with a 24-hour TTL.

        Only caches when the doc contains a 'jwks_uri' so that a transient or
        malformed response doesn't permanently disable JWT verification.
        """
        now = time.time()
        cache_expired = (
            now - self._oidc_discovery_fetched_at
        ) >= self._OIDC_DISCOVERY_TTL
        if (
            self._oidc_discovery_doc is None or cache_expired
        ) and self.access_token_discovery_uri:
            doc = await _fetch_oidc_discovery(self.access_token_discovery_uri)
            if "jwks_uri" in doc:
                self._oidc_discovery_doc = doc
                self._oidc_discovery_fetched_at = now
            else:
                # Malformed doc: return it uncached so the caller raises a
                # clear error and the next call retries the fetch.
                return doc
        return self._oidc_discovery_doc or {}

    async def _verify_incoming_jwt(self, raw_token: str) -> Dict[str, Any]:
        """
        Verify an incoming Bearer JWT against the configured IdP's JWKS.

        Returns the verified payload claims dict.
        Raises jwt.PyJWTError (or subclass) if verification fails.
        """
        discovery = await self._get_oidc_discovery()
        jwks_uri = discovery.get("jwks_uri")
        if not jwks_uri:
            raise ValueError(
                "MCPJWTSigner: access_token_discovery_uri discovery document "
                f"at {self.access_token_discovery_uri!r} has no 'jwks_uri'."
            )

        jwks_keys = await _fetch_jwks(jwks_uri)

        # Only read `kid` from the unverified header — never `alg`.
        # Reading `alg` from an attacker-controlled header enables algorithm
        # confusion attacks (e.g. alg:none, HS256 with the public key as secret).
        # The algorithm is determined from the JWKS key entry instead.
        unverified_header = jwt.get_unverified_header(raw_token)
        kid = unverified_header.get("kid")

        # Build a JWKS object and pick the matching key.
        # PyJWT's PyJWKSet handles key-type parsing and kid matching correctly.
        from jwt import PyJWKSet

        try:
            jwks_set = PyJWKSet.from_dict({"keys": jwks_keys})
        except Exception as exc:
            raise jwt.exceptions.PyJWKSetError(  # type: ignore[attr-defined]
                f"Failed to parse JWKS from {jwks_uri!r}: {exc}"
            ) from exc

        signing_jwk = None
        for jwk_obj in jwks_set.keys:
            # When the token carries no `kid`, fall back to the first key.
            if not kid or jwk_obj.key_id == kid:
                signing_jwk = jwk_obj
                break

        if signing_jwk is None:
            raise jwt.exceptions.PyJWKSetError(  # type: ignore[attr-defined]
                f"No JWKS key matching kid={kid!r} at {jwks_uri!r}"
            )

        # Use the algorithm declared by the JWKS key entry, not the token header.
        # PyJWT populates algorithm_name from the key's `alg` field; when absent
        # it infers from the key type (RSAPublicKey → RS256).
        alg = getattr(signing_jwk, "algorithm_name", None) or "RS256"

        decode_options: Dict[str, Any] = {"verify_exp": True}
        decode_kwargs: Dict[str, Any] = {
            "algorithms": [alg],
            "options": decode_options,
        }
        if self.verify_audience:
            decode_kwargs["audience"] = self.verify_audience
        else:
            # No expected audience configured → skip aud verification.
            decode_options["verify_aud"] = False

        if self.verify_issuer:
            decode_kwargs["issuer"] = self.verify_issuer

        payload: Dict[str, Any] = jwt.decode(
            raw_token, signing_jwk.key, **decode_kwargs
        )
        return payload

    async def _introspect_opaque_token(self, token: str) -> Dict[str, Any]:
        """
        Perform RFC 7662 token introspection for opaque (non-JWT) tokens.

        Returns the introspection response dict.
        Raises on HTTP error or inactive token.
        """
        if not self.token_introspection_endpoint:
            raise ValueError(
                "MCPJWTSigner: token_introspection_endpoint is required for "
                "opaque token verification but is not configured."
            )

        from litellm.llms.custom_httpx.http_handler import (
            get_async_httpx_client,
            httpxSpecialProvider,
        )

        client = get_async_httpx_client(llm_provider=httpxSpecialProvider.Oauth2Check)
        resp = await client.post(
            self.token_introspection_endpoint,
            data={"token": token},
            headers={"Accept": "application/json"},
        )
        resp.raise_for_status()
        result: Dict[str, Any] = resp.json()
        if not result.get("active", False):
            raise jwt.exceptions.ExpiredSignatureError(  # type: ignore[attr-defined]
                "MCPJWTSigner: incoming token is inactive (introspection returned active=false)"
            )
        return result
Raises on HTTP error or + inactive token. + """ + if not self.token_introspection_endpoint: + raise ValueError( + "MCPJWTSigner: token_introspection_endpoint is required for " + "opaque token verification but is not configured." + ) + + from litellm.llms.custom_httpx.http_handler import ( + get_async_httpx_client, + httpxSpecialProvider, + ) + + client = get_async_httpx_client(llm_provider=httpxSpecialProvider.Oauth2Check) + resp = await client.post( + self.token_introspection_endpoint, + data={"token": token}, + headers={"Accept": "application/json"}, + ) + resp.raise_for_status() + result: Dict[str, Any] = resp.json() + if not result.get("active", False): + raise jwt.exceptions.ExpiredSignatureError( # type: ignore[attr-defined] + "MCPJWTSigner: incoming token is inactive (introspection returned active=false)" + ) + return result + + # ------------------------------------------------------------------ + # FR-15: Incoming claim validation + # ------------------------------------------------------------------ + + def _validate_required_claims( + self, + jwt_claims: Optional[Dict[str, Any]], + ) -> None: + """ + Raise HTTP 403 if any required_claims are absent from the verified + incoming token claims. + """ + if not self.required_claims: + return + + from fastapi import HTTPException + + missing = [c for c in self.required_claims if not (jwt_claims or {}).get(c)] + if missing: + raise HTTPException( + status_code=403, + detail={ + "error": ( + f"MCPJWTSigner: incoming token is missing required claims: " + f"{missing}. Configure the IdP to include these claims." 
+ ) + }, + ) + + # ------------------------------------------------------------------ + # FR-12: End-user identity mapping + # ------------------------------------------------------------------ + + def _resolve_end_user_identity( + self, + user_api_key_dict: UserAPIKeyAuth, + jwt_claims: Optional[Dict[str, Any]], + ) -> str: + """ + Resolve the outbound JWT 'sub' using the ordered end_user_claim_sources list. + + Supported source prefixes: + token: — from verified incoming JWT / introspection claims + litellm:user_id — from UserAPIKeyAuth.user_id + litellm:email — from UserAPIKeyAuth.user_email + litellm:end_user_id — from UserAPIKeyAuth.end_user_id + litellm:team_id — from UserAPIKeyAuth.team_id + + Falls back to a stable hash of the API token for service-account callers. + """ + for source in self.end_user_claim_sources: + value: Optional[str] = None + + if source.startswith("token:"): + claim_name = source[len("token:") :] + raw = (jwt_claims or {}).get(claim_name) + value = str(raw) if raw else None + + elif source == "litellm:user_id": + uid = getattr(user_api_key_dict, "user_id", None) + value = str(uid) if uid else None + + elif source == "litellm:email": + email = getattr(user_api_key_dict, "user_email", None) + value = str(email) if email else None + + elif source == "litellm:end_user_id": + eid = getattr(user_api_key_dict, "end_user_id", None) + value = str(eid) if eid else None + + elif source == "litellm:team_id": + tid = getattr(user_api_key_dict, "team_id", None) + value = str(tid) if tid else None + + else: + verbose_proxy_logger.warning( + "MCPJWTSigner: unknown end_user_claim_source %r — skipping", source + ) + continue + + if value: + return value + + # Final fallback for service accounts with no user identity + token = getattr(user_api_key_dict, "token", None) or getattr( + user_api_key_dict, "api_key", None + ) + if token: + return "apikey:" + hashlib.sha256(str(token).encode()).hexdigest()[:16] + return "litellm-proxy" + + # 
------------------------------------------------------------------ + # FR-10: Scope building + # ------------------------------------------------------------------ + + def _build_scope(self, raw_tool_name: str) -> str: + """ + Build the JWT scope string. + + When allowed_scopes is configured: join them verbatim. + Otherwise auto-generate minimal, least-privilege scopes: + - Tool call → mcp:tools/call mcp:tools/:call + - No tool → mcp:tools/call mcp:tools/list + + NOTE: tools/list is intentionally NOT granted on tool-call JWTs to + prevent callers from enumerating tools they didn't ask to use. + """ + if self.allowed_scopes is not None: + return " ".join(self.allowed_scopes) + + tool_name = ( + re.sub(r"[^a-zA-Z0-9_\-]", "_", raw_tool_name) if raw_tool_name else "" + ) + if tool_name: + scopes = ["mcp:tools/call", f"mcp:tools/{tool_name}:call"] + else: + scopes = ["mcp:tools/call", "mcp:tools/list"] + return " ".join(scopes) + + # ------------------------------------------------------------------ + # FR-13: Claim operations + # ------------------------------------------------------------------ + + def _apply_claim_operations(self, claims: Dict[str, Any]) -> Dict[str, Any]: + """Apply add_claims, set_claims, and remove_claims to the claim dict.""" + # add_claims: insert only when key is absent + for k, v in self.add_claims.items(): + if k not in claims: + claims[k] = v + + # set_claims: always override (highest priority) + claims = {**claims, **self.set_claims} + + # remove_claims: delete listed keys + for k in self.remove_claims: + claims.pop(k, None) + + return claims + + # ------------------------------------------------------------------ + # FR-15: optional_claims passthrough + # ------------------------------------------------------------------ + + def _passthrough_optional_claims( + self, + claims: Dict[str, Any], + jwt_claims: Optional[Dict[str, Any]], + ) -> Dict[str, Any]: + """Forward optional_claims from verified incoming token into the outbound JWT.""" + 
if not self.optional_claims or not jwt_claims: + return claims + for claim in self.optional_claims: + if claim in jwt_claims and claim not in claims: + claims[claim] = jwt_claims[claim] + return claims + + # ------------------------------------------------------------------ + # Core JWT builder + # ------------------------------------------------------------------ + + def _build_claims( + self, + user_api_key_dict: UserAPIKeyAuth, + data: dict, + jwt_claims: Optional[Dict[str, Any]] = None, + ) -> Dict[str, Any]: + """ + Build JWT claims for the outbound MCP access token. + + Args: + user_api_key_dict: LiteLLM auth context for the current request. + data: Pre-call hook data dict (contains mcp_tool_name etc.). + jwt_claims: Verified incoming IdP claims (FR-5), or LiteLLM-decoded + jwt_claims if available. None for pure API-key requests. + """ + now = int(time.time()) + claims: Dict[str, Any] = { + "iss": self.issuer, + "aud": self.audience, + "iat": now, + "exp": now + self.ttl_seconds, + "nbf": now, + } + + # sub — resolved via ordered claim sources (FR-12) + claims["sub"] = self._resolve_end_user_identity(user_api_key_dict, jwt_claims) + + # email passthrough when available from LiteLLM context + user_email = getattr(user_api_key_dict, "user_email", None) + if user_email: + claims["email"] = user_email + + # act — RFC 8693 delegation claim (team/org context) + team_id = getattr(user_api_key_dict, "team_id", None) + org_id = getattr(user_api_key_dict, "org_id", None) + act_sub = team_id or org_id or "litellm-proxy" + claims["act"] = {"sub": act_sub} + + # end_user_id when set separately from user_id + end_user_id = getattr(user_api_key_dict, "end_user_id", None) + if end_user_id: + claims["end_user_id"] = end_user_id + + # scope (FR-10) + raw_tool_name: str = data.get("mcp_tool_name", "") + claims["scope"] = self._build_scope(raw_tool_name) + + # optional_claims passthrough (FR-15) + claims = self._passthrough_optional_claims(claims, jwt_claims) + + # Claim 
    def _build_channel_token_claims(
        self,
        base_claims: Dict[str, Any],
    ) -> Dict[str, Any]:
        """
        Build claims for the channel token (FR-14 two-token model).

        Inherits sub/act/scope from the access token but uses a separate
        audience and TTL so the transport layer and resource layer receive
        purpose-bound credentials.
        """
        now = int(time.time())
        return {
            **base_claims,
            "aud": self.channel_token_audience,
            "iat": now,
            "exp": now + self.channel_token_ttl,
            "nbf": now,
        }

    # ------------------------------------------------------------------
    # FR-9: Debug header
    # ------------------------------------------------------------------

    @staticmethod
    def _build_debug_header(claims: Dict[str, Any], kid: str) -> str:
        """
        Build the x-litellm-mcp-debug header value.

        Format: v=1; kid=<kid>; sub=<sub>; iss=<iss>; exp=<exp>; scope=<scope>
        Scope is truncated to 80 chars for header safety.
        """
        sub = claims.get("sub", "")
        iss = claims.get("iss", "")
        exp = claims.get("exp", 0)
        scope = claims.get("scope", "")
        if len(scope) > 80:
            scope = scope[:77] + "..."
        return f"v=1; kid={kid}; sub={sub}; iss={iss}; exp={exp}; scope={scope}"

    # ------------------------------------------------------------------
    # Guardrail hook
    # ------------------------------------------------------------------

    @log_guardrail_information
    async def async_pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        cache: DualCache,
        data: dict,
        call_type: CallTypesLiteral,
    ) -> Optional[Union[Exception, str, dict]]:
        """
        Verifies the incoming token (when configured), validates required claims,
        then signs an outbound JWT and injects it as the Authorization header.

        All non-MCP call types pass through unchanged.
        """
        if call_type != "call_mcp_tool":
            return data

        # ------------------------------------------------------------------
        # FR-5: Verify incoming token before re-signing
        # ------------------------------------------------------------------
        jwt_claims: Optional[Dict[str, Any]] = None
        raw_token: Optional[str] = data.get("incoming_bearer_token")

        if self.access_token_discovery_uri and raw_token:
            # Three-dot pattern → JWT; otherwise opaque.
            is_jwt = raw_token.count(".") == 2
            try:
                if is_jwt:
                    jwt_claims = await self._verify_incoming_jwt(raw_token)
                elif self.token_introspection_endpoint:
                    jwt_claims = await self._introspect_opaque_token(raw_token)
                else:
                    # Opaque token but no introspection endpoint: best-effort
                    # pass-through — warn rather than reject.
                    verbose_proxy_logger.warning(
                        "MCPJWTSigner: access_token_discovery_uri is set but the "
                        "incoming token appears to be opaque and no "
                        "token_introspection_endpoint is configured. "
                        "Proceeding without incoming token verification."
                    )
            except Exception as exc:
                verbose_proxy_logger.error(
                    "MCPJWTSigner: incoming token verification failed: %s", exc
                )
                from fastapi import HTTPException

                # A token that fails verification is a hard 401 — never
                # re-sign on behalf of an unverified caller.
                raise HTTPException(
                    status_code=401,
                    detail={
                        "error": (
                            f"MCPJWTSigner: incoming token verification failed: {exc}"
                        )
                    },
                )
        elif not raw_token and self.access_token_discovery_uri:
            verbose_proxy_logger.debug(
                "MCPJWTSigner: access_token_discovery_uri configured but no Bearer "
                "token found in request (API-key auth request — skipping verification)."
            )

        # Fall back to LiteLLM-decoded JWT claims (available when proxy uses JWT auth).
        if jwt_claims is None:
            jwt_claims = getattr(user_api_key_dict, "jwt_claims", None)

        # ------------------------------------------------------------------
        # FR-15: Validate required claims
        # ------------------------------------------------------------------
        self._validate_required_claims(jwt_claims)

        # ------------------------------------------------------------------
        # Build outbound access token
        # ------------------------------------------------------------------
        claims = self._build_claims(user_api_key_dict, data, jwt_claims)

        signed_token = jwt.encode(
            claims,
            self._private_key,
            algorithm=self.ALGORITHM,
            headers={"kid": self._kid},
        )

        # Merge into existing extra_headers — a prior guardrail in the chain may
        # have already injected tracing headers or correlation IDs.
        existing_headers: Dict[str, str] = data.get("extra_headers") or {}
        new_headers: Dict[str, str] = {
            **existing_headers,
            "Authorization": f"Bearer {signed_token}",
        }

        # ------------------------------------------------------------------
        # FR-14: Two-token model — channel token
        # ------------------------------------------------------------------
        if self.channel_token_audience:
            channel_claims = self._build_channel_token_claims(claims)
            channel_token = jwt.encode(
                channel_claims,
                self._private_key,
                algorithm=self.ALGORITHM,
                headers={"kid": self._kid},
            )
            new_headers["x-mcp-channel-token"] = f"Bearer {channel_token}"

        # ------------------------------------------------------------------
        # FR-9: Debug header
        # ------------------------------------------------------------------
        if self.debug_headers:
            new_headers["x-litellm-mcp-debug"] = self._build_debug_header(
                claims, self._kid
            )

        data["extra_headers"] = new_headers

        verbose_proxy_logger.debug(
            "MCPJWTSigner: signed JWT sub=%s act=%s tool=%s exp=%d "
            "verified=%s channel=%s",
            claims.get("sub"),
            claims.get("act", {}).get("sub"),
            data.get("mcp_tool_name"),
            claims["exp"],
            jwt_claims is not None,
            bool(self.channel_token_audience),
        )

        return data
claims["exp"], + jwt_claims is not None, + bool(self.channel_token_audience), + ) + + return data diff --git a/litellm/proxy/management_endpoints/internal_user_endpoints.py b/litellm/proxy/management_endpoints/internal_user_endpoints.py index 8a71b8d4b59..ca8c345f46c 100644 --- a/litellm/proxy/management_endpoints/internal_user_endpoints.py +++ b/litellm/proxy/management_endpoints/internal_user_endpoints.py @@ -2142,8 +2142,7 @@ async def _resolve_org_filter_for_user_search( member_org_ids: List[str] = [] if caller_user is not None: member_org_ids = [ - m.organization_id - for m in (caller_user.organization_memberships or []) + m.organization_id for m in (caller_user.organization_memberships or []) ] if member_org_ids: diff --git a/litellm/proxy/management_endpoints/key_management_endpoints.py b/litellm/proxy/management_endpoints/key_management_endpoints.py index e09d7607ebe..1016e6b149c 100644 --- a/litellm/proxy/management_endpoints/key_management_endpoints.py +++ b/litellm/proxy/management_endpoints/key_management_endpoints.py @@ -1863,16 +1863,10 @@ async def _validate_update_key_data( user_api_key_cache: Any, ) -> None: """Validate permissions and constraints for key update.""" - _is_proxy_admin = ( - user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN.value - ) + _is_proxy_admin = user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN.value # Prevent non-admin from removing user_id (setting to empty string) (LIT-1884) - if ( - data.user_id is not None - and data.user_id == "" - and not _is_proxy_admin - ): + if data.user_id is not None and data.user_id == "" and not _is_proxy_admin: raise HTTPException( status_code=403, detail="Non-admin users cannot remove the user_id from a key.", diff --git a/litellm/proxy/management_endpoints/team_endpoints.py b/litellm/proxy/management_endpoints/team_endpoints.py index d83ceb1b095..ceb3d5d277a 100644 --- a/litellm/proxy/management_endpoints/team_endpoints.py +++ 
b/litellm/proxy/management_endpoints/team_endpoints.py @@ -857,7 +857,13 @@ async def new_team( # noqa: PLR0915 # Apply defaults from litellm.default_team_params for any fields # not explicitly provided in the request. - for field in ("max_budget", "budget_duration", "tpm_limit", "rpm_limit", "team_member_permissions"): + for field in ( + "max_budget", + "budget_duration", + "tpm_limit", + "rpm_limit", + "team_member_permissions", + ): if getattr(data, field, None) is None: default_value = _get_default_team_param(field) if default_value is not None: diff --git a/litellm/proxy/openai_files_endpoints/common_utils.py b/litellm/proxy/openai_files_endpoints/common_utils.py index 49f17535333..25ad1b25aad 100644 --- a/litellm/proxy/openai_files_endpoints/common_utils.py +++ b/litellm/proxy/openai_files_endpoints/common_utils.py @@ -857,7 +857,10 @@ async def update_batch_in_database( # If the batch_processed column doesn't exist (old schema), # retry without it so the status update still succeeds. 
err_str = str(col_err).lower() - if "batch_processed" in err_str and update_data.get("batch_processed") is not None: + if ( + "batch_processed" in err_str + and update_data.get("batch_processed") is not None + ): verbose_proxy_logger.warning( f"batch_processed column not found, retrying update without it: {col_err}" ) diff --git a/litellm/proxy/proxy_cli.py b/litellm/proxy/proxy_cli.py index e5a34ae8bdd..97d5de0d53d 100644 --- a/litellm/proxy/proxy_cli.py +++ b/litellm/proxy/proxy_cli.py @@ -468,6 +468,12 @@ def _maybe_setup_prometheus_multiproc_dir( type=str, help="Path to the logging configuration file", ) +@click.option( + "--setup", + is_flag=True, + default=False, + help="Run the interactive setup wizard to configure providers and generate a config file", +) @click.option( "--version", "-v", @@ -598,6 +604,7 @@ def run_server( # noqa: PLR0915 num_requests, use_queue, health, + setup, version, run_gunicorn, run_hypercorn, @@ -611,6 +618,12 @@ def run_server( # noqa: PLR0915 max_requests_before_restart, enforce_prisma_migration_check: bool, ): + if setup: + from litellm.setup_wizard import run_setup_wizard + + run_setup_wizard() + return + args = locals() if local: from proxy_server import ( @@ -904,7 +917,7 @@ def run_server( # noqa: PLR0915 # Auto-create PROMETHEUS_MULTIPROC_DIR for multi-worker setups ProxyInitializationHelpers._maybe_setup_prometheus_multiproc_dir( num_workers=num_workers, - litellm_settings=litellm_settings if config else None, + litellm_settings=litellm_settings if config else None, # type: ignore[possibly-unbound] ) # --- SEPARATE HEALTH APP LOGIC --- diff --git a/litellm/proxy/response_polling/background_streaming.py b/litellm/proxy/response_polling/background_streaming.py index b4d51814e5a..7583f30eb2d 100644 --- a/litellm/proxy/response_polling/background_streaming.py +++ b/litellm/proxy/response_polling/background_streaming.py @@ -115,7 +115,9 @@ async def background_streaming_task( # noqa: PLR0915 UPDATE_INTERVAL = 0.150 # 150ms 
batching interval # Track the terminal event from the stream (may not be "completed") - terminal_status: Optional[ResponsesAPIStatus] = None # Will be set by response.completed/failed/incomplete/cancelled + terminal_status: Optional[ + ResponsesAPIStatus + ] = None # Will be set by response.completed/failed/incomplete/cancelled terminal_error = None _event_to_status = { "response.completed": "completed", @@ -259,7 +261,10 @@ async def flush_state_if_needed(force: bool = False) -> None: ) # Extract error for failed and incomplete responses - if event_type == "response.failed" or event_type == "response.incomplete": + if ( + event_type == "response.failed" + or event_type == "response.incomplete" + ): terminal_error = response_data.get("error") # Core response fields diff --git a/litellm/proxy/spend_tracking/vantage_endpoints.py b/litellm/proxy/spend_tracking/vantage_endpoints.py index 14c0ebcb54d..7d8fbf74615 100644 --- a/litellm/proxy/spend_tracking/vantage_endpoints.py +++ b/litellm/proxy/spend_tracking/vantage_endpoints.py @@ -40,9 +40,7 @@ def _get_registered_vantage_logger(): return None -async def _set_vantage_settings( - api_key: str, integration_token: str, base_url: str -): +async def _set_vantage_settings(api_key: str, integration_token: str, base_url: str): """Store Vantage settings in the database with encrypted API key.""" from litellm.proxy.proxy_server import prisma_client @@ -341,9 +339,7 @@ async def init_vantage_settings( except HTTPException: raise except Exception as e: - verbose_proxy_logger.error( - f"Error initializing Vantage settings: {str(e)}" - ) + verbose_proxy_logger.error(f"Error initializing Vantage settings: {str(e)}") raise HTTPException( status_code=500, detail={"error": f"Failed to initialize Vantage settings: {str(e)}"}, @@ -395,7 +391,8 @@ def _to_json_safe_dicts(frame: pl.DataFrame) -> list: """Cast Decimal columns to Float64 so .to_dicts() produces JSON-serializable float values instead of decimal.Decimal.""" decimal_cols = [ - 
col for col, dtype in zip(frame.columns, frame.dtypes) + col + for col, dtype in zip(frame.columns, frame.dtypes) if isinstance(dtype, pl.Decimal) ] if decimal_cols: @@ -404,8 +401,16 @@ def _to_json_safe_dicts(frame: pl.DataFrame) -> list: ) return frame.to_dicts() - usage_sample = _to_json_safe_dicts(data.head(min(50, len(data)))) if not data.is_empty() else [] - normalized_sample = _to_json_safe_dicts(normalized.head(min(50, len(normalized)))) if not normalized.is_empty() else [] + usage_sample = ( + _to_json_safe_dicts(data.head(min(50, len(data)))) + if not data.is_empty() + else [] + ) + normalized_sample = ( + _to_json_safe_dicts(normalized.head(min(50, len(normalized)))) + if not normalized.is_empty() + else [] + ) # Use the same pre-transform column names as # FocusExportEngine.dry_run_export_usage_data for consistency. @@ -437,14 +442,10 @@ def _to_json_safe_dicts(frame: pl.DataFrame) -> list: except HTTPException: raise except Exception as e: - verbose_proxy_logger.error( - f"Error performing Vantage dry run export: {str(e)}" - ) + verbose_proxy_logger.error(f"Error performing Vantage dry run export: {str(e)}") raise HTTPException( status_code=500, - detail={ - "error": f"Failed to perform Vantage dry run export: {str(e)}" - }, + detail={"error": f"Failed to perform Vantage dry run export: {str(e)}"}, ) diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index 76662be1753..df527d08af8 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -454,8 +454,6 @@ def _add_proxy_hooks(self, llm_router: Optional[Router] = None): for hook in PROXY_HOOKS: proxy_hook = get_proxy_hook(hook) - import inspect - expected_args = inspect.getfullargspec(proxy_hook).args passed_in_args: Dict[str, Any] = {} if "internal_usage_cache" in expected_args: @@ -559,6 +557,10 @@ def _convert_mcp_to_llm_format(self, request_obj, kwargs: dict) -> dict: "user_api_key_request_route": kwargs.get("user_api_key_request_route"), "mcp_tool_name": request_obj.tool_name, # 
Keep original for reference "mcp_arguments": request_obj.arguments, # Keep original for reference + # Raw Bearer token from the original HTTP request — allows guardrails + # (e.g. MCPJWTSigner) to independently verify the caller's identity + # before re-signing an outbound token (FR-5 verify+re-sign). + "incoming_bearer_token": kwargs.get("incoming_bearer_token"), } return synthetic_data @@ -824,17 +826,30 @@ def _convert_mcp_hook_response_to_kwargs( ) -> dict: """ Helper function to convert pre_call_hook response back to kwargs for MCP usage. + + Supports: + - modified_arguments: Override tool call arguments + - extra_headers: Inject custom headers into the outbound MCP request """ if not response_data: return original_kwargs - # Apply any argument modifications from the hook response modified_kwargs = original_kwargs.copy() - # If the response contains modified arguments, apply them if response_data.get("modified_arguments"): modified_kwargs["arguments"] = response_data["modified_arguments"] + if response_data.get("extra_headers"): + # Merge rather than replace — a prior guardrail in the chain may have + # already injected headers (e.g. tracing IDs). Later guardrails win on + # key collisions so that the most-specific guardrail (e.g. JWT signer) + # takes precedence over earlier ones. 
+ existing = modified_kwargs.get("extra_headers") or {} + modified_kwargs["extra_headers"] = { + **existing, + **response_data["extra_headers"], + } + return modified_kwargs async def process_pre_call_hook_response(self, response, data, call_type): diff --git a/litellm/proxy/video_endpoints/utils.py b/litellm/proxy/video_endpoints/utils.py index 36203bdc77e..689fe4a371b 100644 --- a/litellm/proxy/video_endpoints/utils.py +++ b/litellm/proxy/video_endpoints/utils.py @@ -7,7 +7,9 @@ def extract_model_from_target_model_names(target_model_names: Any) -> Optional[str]: if isinstance(target_model_names, str): - target_model_names = [m.strip() for m in target_model_names.split(",") if m.strip()] + target_model_names = [ + m.strip() for m in target_model_names.split(",") if m.strip() + ] elif not isinstance(target_model_names, list): return None return target_model_names[0] if target_model_names else None diff --git a/litellm/responses/main.py b/litellm/responses/main.py index cd9ce67c26e..2a320517a4c 100644 --- a/litellm/responses/main.py +++ b/litellm/responses/main.py @@ -692,11 +692,11 @@ def responses( return run_async_function(aresponses_api_with_mcp, **mcp_call_kwargs) # get provider config - responses_api_provider_config: Optional[BaseResponsesAPIConfig] = ( - ProviderConfigManager.get_provider_responses_api_config( - model=model, - provider=custom_llm_provider, - ) + responses_api_provider_config: Optional[ + BaseResponsesAPIConfig + ] = ProviderConfigManager.get_provider_responses_api_config( + model=model, + provider=custom_llm_provider, ) local_vars.update(kwargs) @@ -908,11 +908,11 @@ def delete_responses( raise ValueError("custom_llm_provider is required but passed as None") # get provider config - responses_api_provider_config: Optional[BaseResponsesAPIConfig] = ( - ProviderConfigManager.get_provider_responses_api_config( - model=None, - provider=custom_llm_provider, - ) + responses_api_provider_config: Optional[ + BaseResponsesAPIConfig + ] = 
ProviderConfigManager.get_provider_responses_api_config( + model=None, + provider=custom_llm_provider, ) if responses_api_provider_config is None: @@ -1089,11 +1089,11 @@ def get_responses( raise ValueError("custom_llm_provider is required but passed as None") # get provider config - responses_api_provider_config: Optional[BaseResponsesAPIConfig] = ( - ProviderConfigManager.get_provider_responses_api_config( - model=None, - provider=custom_llm_provider, - ) + responses_api_provider_config: Optional[ + BaseResponsesAPIConfig + ] = ProviderConfigManager.get_provider_responses_api_config( + model=None, + provider=custom_llm_provider, ) if responses_api_provider_config is None: @@ -1247,11 +1247,11 @@ def list_input_items( if custom_llm_provider is None: raise ValueError("custom_llm_provider is required but passed as None") - responses_api_provider_config: Optional[BaseResponsesAPIConfig] = ( - ProviderConfigManager.get_provider_responses_api_config( - model=None, - provider=custom_llm_provider, - ) + responses_api_provider_config: Optional[ + BaseResponsesAPIConfig + ] = ProviderConfigManager.get_provider_responses_api_config( + model=None, + provider=custom_llm_provider, ) if responses_api_provider_config is None: @@ -1406,11 +1406,11 @@ def cancel_responses( raise ValueError("custom_llm_provider is required but passed as None") # get provider config - responses_api_provider_config: Optional[BaseResponsesAPIConfig] = ( - ProviderConfigManager.get_provider_responses_api_config( - model=None, - provider=custom_llm_provider, - ) + responses_api_provider_config: Optional[ + BaseResponsesAPIConfig + ] = ProviderConfigManager.get_provider_responses_api_config( + model=None, + provider=custom_llm_provider, ) if responses_api_provider_config is None: @@ -1594,11 +1594,11 @@ def compact_responses( raise ValueError("custom_llm_provider is required but passed as None") # get provider config - responses_api_provider_config: Optional[BaseResponsesAPIConfig] = ( - 
ProviderConfigManager.get_provider_responses_api_config( - model=model, - provider=custom_llm_provider, - ) + responses_api_provider_config: Optional[ + BaseResponsesAPIConfig + ] = ProviderConfigManager.get_provider_responses_api_config( + model=model, + provider=custom_llm_provider, ) if responses_api_provider_config is None: diff --git a/litellm/router.py b/litellm/router.py index f34368172ac..46998abb160 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -8611,6 +8611,7 @@ def _pre_call_checks( # noqa: PLR0915 _model_info = deployment.get("model_info", {}) # see if we have the info for this model + _deployment_model = None # per-deployment model name (avoids overwriting the outer `model` group name) try: base_model = _model_info.get("base_model", None) if base_model is None: @@ -8618,7 +8619,7 @@ def _pre_call_checks( # noqa: PLR0915 model_info = self.get_router_model_info( deployment=deployment, received_model_name=model ) - model = base_model or _litellm_params.get("model", None) + _deployment_model = base_model or _litellm_params.get("model", None) if ( isinstance(model_info, dict) @@ -8632,7 +8633,9 @@ def _pre_call_checks( # noqa: PLR0915 _context_window_error = True _potential_error_str += ( "Model={}, Max Input Tokens={}, Got={}".format( - model, model_info["max_input_tokens"], input_tokens + _deployment_model, + model_info["max_input_tokens"], + input_tokens, ) ) continue @@ -8688,13 +8691,21 @@ def _pre_call_checks( # noqa: PLR0915 ## INVALID PARAMS ## -> catch 'gpt-3.5-turbo-16k' not supporting 'response_format' param if request_kwargs is not None and litellm.drop_params is False: - # get supported params - model, custom_llm_provider, _, _ = litellm.get_llm_provider( - model=model, litellm_params=LiteLLM_Params(**_litellm_params) + # get supported params — use per-deployment model to avoid overwriting the outer model group name + _dep_model_for_params = _deployment_model or model + ( + _dep_model_for_params, + custom_llm_provider, + _, + _, + ) 
= litellm.get_llm_provider( + model=_dep_model_for_params, + litellm_params=LiteLLM_Params(**_litellm_params), ) supported_openai_params = litellm.get_supported_openai_params( - model=model, custom_llm_provider=custom_llm_provider + model=_dep_model_for_params, + custom_llm_provider=custom_llm_provider, ) if supported_openai_params is None: diff --git a/litellm/setup_wizard.py b/litellm/setup_wizard.py new file mode 100644 index 00000000000..ee5918e1273 --- /dev/null +++ b/litellm/setup_wizard.py @@ -0,0 +1,668 @@ +# ruff: noqa: T201 +# flake8: noqa: T201 +""" +LiteLLM Interactive Setup Wizard + +Guides users through selecting LLM providers, entering API keys, +and generating a proxy config file — mirroring the Claude Code onboarding UX. +""" + +import importlib.metadata +import os +import re +import secrets +import sys +import sysconfig +from pathlib import Path +from typing import Dict, List, Optional, Set + +# termios / tty are Unix-only; fall back gracefully on Windows +try: + import termios + import tty + + _HAS_RAW_TERMINAL: bool = True +except ImportError: + termios = None # type: ignore[assignment] + tty = None # type: ignore[assignment] + _HAS_RAW_TERMINAL = False + +from litellm.utils import check_valid_key + +# --------------------------------------------------------------------------- +# Provider definitions +# --------------------------------------------------------------------------- +# Each entry describes one provider card shown in the wizard. +# `env_key` — primary env var name (None = no key needed, e.g. Ollama) +# `test_model` — model passed to check_valid_key for credential validation +# (None = skip validation, e.g. 
Azure needs a deployment name) +# `models` — default models written into the generated config +# --------------------------------------------------------------------------- + +PROVIDERS: List[Dict] = [ + { + "id": "openai", + "name": "OpenAI", + "description": "GPT-4o, GPT-4o-mini, o3-mini", + "env_key": "OPENAI_API_KEY", + "key_hint": "sk-...", + "test_model": "gpt-4o-mini", + "models": ["gpt-4o", "gpt-4o-mini"], + }, + { + "id": "anthropic", + "name": "Anthropic", + "description": "Claude Opus 4.6, Sonnet 4.6, Haiku 4.5", + "env_key": "ANTHROPIC_API_KEY", + "key_hint": "sk-ant-...", + "test_model": "claude-haiku-4-5-20251001", + "models": ["claude-opus-4-6", "claude-sonnet-4-6", "claude-haiku-4-5-20251001"], + }, + { + "id": "gemini", + "name": "Google Gemini", + "description": "Gemini 2.0 Flash, Gemini 2.5 Pro", + "env_key": "GEMINI_API_KEY", + "key_hint": "AIza...", + "test_model": "gemini/gemini-2.0-flash", + "models": ["gemini/gemini-2.0-flash", "gemini/gemini-2.5-pro"], + }, + { + "id": "azure", + "name": "Azure OpenAI", + "description": "GPT-4o via Azure", + "env_key": "AZURE_API_KEY", + "key_hint": "your-azure-key", + "test_model": None, # needs deployment name — skip validation + "models": [], + "needs_api_base": True, + "api_base_hint": "https://.openai.azure.com/", + "api_version": "2024-07-01-preview", + }, + { + "id": "bedrock", + "name": "AWS Bedrock", + "description": "Claude 3.5, Llama 3 via AWS", + "env_key": "AWS_ACCESS_KEY_ID", + "key_hint": "AKIA...", + "test_model": None, # multi-key auth — skip validation + "models": ["bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0"], + "extra_keys": ["AWS_SECRET_ACCESS_KEY", "AWS_REGION_NAME"], + "extra_hints": ["your-secret-key", "us-east-1"], + }, + { + "id": "ollama", + "name": "Ollama", + "description": "Local models (llama3.2, mistral, etc.)", + "env_key": None, + "key_hint": None, + "test_model": None, # local — no remote validation + "models": ["ollama/llama3.2", "ollama/mistral"], + "api_base": 
"http://localhost:11434", + }, +] + + +# --------------------------------------------------------------------------- +# ANSI colour helpers +# --------------------------------------------------------------------------- + +_ANSI_RE = re.compile(r"\033\[[^m]*m") + +_ORANGE = "\033[38;2;215;119;87m" +_DIM = "\033[2m" +_BOLD = "\033[1m" +_GREEN = "\033[38;2;78;186;101m" +_BLUE = "\033[38;2;177;185;249m" +_GREY = "\033[38;2;153;153;153m" +_RESET = "\033[0m" +_CHECK = "✔" +_CROSS = "✘" + +_CURSOR_HIDE = "\033[?25l" +_CURSOR_SHOW = "\033[?25h" +_MOVE_UP = "\033[{}A" + + +def _supports_color() -> bool: + return sys.stdout.isatty() and os.environ.get("NO_COLOR") is None + + +def _c(code: str, text: str) -> str: + return f"{code}{text}{_RESET}" if _supports_color() else text + + +def orange(t: str) -> str: + return _c(_ORANGE, t) + + +def bold(t: str) -> str: + return _c(_BOLD, t) + + +def green(t: str) -> str: + return _c(_GREEN, t) + + +def blue(t: str) -> str: + return _c(_BLUE, t) + + +def grey(t: str) -> str: + return _c(_GREY, t) + + +def dim(t: str) -> str: + return _c(_DIM, t) + + +def _divider() -> str: + """Return a styled divider line (evaluated at call-time, not import-time).""" + return dim(" " + "╌" * 74) + + +def _styled_input(prompt: str) -> str: + """ + Like input() but wraps ANSI sequences in readline ignore markers + (\\001...\\002) so readline correctly tracks the cursor column. + In non-TTY contexts, strips ANSI entirely so no escape codes appear. 
+ """ + if sys.stdout.isatty(): + rl_prompt = _ANSI_RE.sub(lambda m: f"\001{m.group()}\002", prompt) + else: + rl_prompt = _ANSI_RE.sub("", prompt) + return input(rl_prompt).strip() + + +def _yaml_escape(value: str) -> str: + """Escape a string for safe embedding in a double-quoted YAML scalar.""" + return ( + value.replace("\\", "\\\\") + .replace('"', '\\"') + .replace("\n", "\\n") + .replace("\r", "\\r") + .replace("\t", "\\t") + ) + + +# --------------------------------------------------------------------------- +# Layout constants +# --------------------------------------------------------------------------- + +LITELLM_ASCII = r""" + ██╗ ██╗████████╗███████╗██╗ ██╗ ███╗ ███╗ + ██║ ██║╚══██╔══╝██╔════╝██║ ██║ ████╗ ████║ + ██║ ██║ ██║ █████╗ ██║ ██║ ██╔████╔██║ + ██║ ██║ ██║ ██╔══╝ ██║ ██║ ██║╚██╔╝██║ + ███████╗██║ ██║ ███████╗███████╗███████╗██║ ╚═╝ ██║ + ╚══════╝╚═╝ ╚═╝ ╚══════╝╚══════╝╚══════╝╚═╝ ╚═╝ +""" + + +# --------------------------------------------------------------------------- +# Setup wizard +# --------------------------------------------------------------------------- + + +class SetupWizard: + """ + Interactive onboarding wizard: provider selection → API keys → config file. + + All methods are static — the class is purely a namespace with clear + single-responsibility sections. Entry point: SetupWizard.run(). 
+ """ + + # ── entry point ───────────────────────────────────────────────────────── + + @staticmethod + def run() -> None: + try: + SetupWizard._wizard() + except (KeyboardInterrupt, EOFError): + print(f"\n\n {grey('Setup cancelled.')}\n") + + # ── wizard steps ──────────────────────────────────────────────────────── + + @staticmethod + def _wizard() -> None: + SetupWizard._print_welcome() + print(f" {bold('Lets get started.')}") + print() + + providers = SetupWizard._select_providers() + env_vars = SetupWizard._collect_keys(providers) + port, master_key = SetupWizard._proxy_settings() + + config_path = Path(os.getcwd()) / "litellm_config.yaml" + try: + config_path.write_text( + SetupWizard._build_config(providers, env_vars, master_key) + ) + except OSError as exc: + print(f"\n {bold(_CROSS + ' Could not write config:')} {exc}") + print(" Try running from a directory you have write access to.\n") + return + + SetupWizard._print_success(config_path, port, master_key) + SetupWizard._offer_start(config_path, port, master_key) + + # ── welcome ───────────────────────────────────────────────────────────── + + @staticmethod + def _print_welcome() -> None: + try: + version = importlib.metadata.version("litellm") + except Exception: + version = "unknown" + print() + print(orange(LITELLM_ASCII.rstrip("\n"))) + print(f" {orange('Welcome')} to {bold('LiteLLM')} {grey('v' + version)}") + print() + print(_divider()) + print() + + # ── provider selector ─────────────────────────────────────────────────── + + @staticmethod + def _select_providers() -> List[Dict]: + """Arrow-key multi-select. 
Falls back to number input if /dev/tty unavailable.""" + if not _HAS_RAW_TERMINAL: + return SetupWizard._select_fallback() + try: + return SetupWizard._select_interactive() + except OSError: + return SetupWizard._select_fallback() + + @staticmethod + def _read_key() -> str: + """Read one keypress from /dev/tty in raw mode.""" + assert ( + termios is not None and tty is not None + ) # only called when _HAS_RAW_TERMINAL + with open("/dev/tty", "rb") as tty_fh: + fd = tty_fh.fileno() + old = termios.tcgetattr(fd) + try: + tty.setraw(fd) + ch = tty_fh.read(1) + if ch == b"\x1b": + ch2 = tty_fh.read(1) + if ch2 == b"[": + ch3 = tty_fh.read(1) + return "\x1b[" + ch3.decode("utf-8", errors="replace") + return "\x1b" + ch2.decode("utf-8", errors="replace") + return ch.decode("utf-8", errors="replace") + finally: + termios.tcsetattr(fd, termios.TCSADRAIN, old) + + @staticmethod + def _render_selector(cursor: int, selected: Set[int], first_render: bool) -> int: + """Draw or redraw the provider list. 
Returns the number of lines printed.""" + lines = [ + f"\n {bold('Add your first model')}\n", + grey(" ↑↓ to navigate · Space to select · Enter to confirm") + "\n", + "\n", + ] + for i, p in enumerate(PROVIDERS): + arrow = blue("❯") if i == cursor else " " + bullet = green("◉") if i in selected else grey("○") + name_str = bold(p["name"]) if i == cursor else p["name"] + lines.append(f" {arrow} {bullet} {name_str} {grey(p['description'])}\n") + lines.append("\n") + + content = "".join(lines) + if not first_render and _supports_color(): + sys.stdout.write(_MOVE_UP.format(content.count("\n"))) + sys.stdout.write(content) + sys.stdout.flush() + return content.count("\n") + + @staticmethod + def _select_interactive() -> List[Dict]: + cursor = 0 + selected: set[int] = set() + + if _supports_color(): + sys.stdout.write(_CURSOR_HIDE) + sys.stdout.flush() + try: + SetupWizard._render_selector(cursor, selected, first_render=True) + while True: + key = SetupWizard._read_key() + dirty = False + if key == "\x1b[A": + cursor = (cursor - 1) % len(PROVIDERS) + dirty = True + elif key == "\x1b[B": + cursor = (cursor + 1) % len(PROVIDERS) + dirty = True + elif key == " ": + selected.symmetric_difference_update({cursor}) + dirty = True + elif key in ("\r", "\n"): + if not selected: + selected.add(cursor) + break + elif key in ("\x03", "\x04"): + raise KeyboardInterrupt + if dirty: + SetupWizard._render_selector(cursor, selected, first_render=False) + finally: + if _supports_color(): + sys.stdout.write(_CURSOR_SHOW) + sys.stdout.flush() + + return [PROVIDERS[i] for i in sorted(selected)] + + @staticmethod + def _select_fallback() -> List[Dict]: + """Number-based fallback when raw terminal input is unavailable.""" + print() + print(f" {bold('Add your first model')}") + print( + grey( + " Enter numbers separated by commas (e.g. 1,2). Press Enter to confirm." 
+ ) + ) + print() + for i, p in enumerate(PROVIDERS, 1): + print(f" {grey(str(i) + '.')} {bold(p['name'])} {grey(p['description'])}") + print() + + while True: + raw = _styled_input(f" {blue('❯')} Provider(s): ") + if not raw: + print(grey(" Please select at least one provider.")) + continue + try: + nums = [ + int(x.strip()) + for x in raw.replace(" ", ",").split(",") + if x.strip() + ] + valid = sorted({n for n in nums if 1 <= n <= len(PROVIDERS)}) + if not valid: + print(grey(f" Enter numbers between 1 and {len(PROVIDERS)}.")) + continue + return [PROVIDERS[i - 1] for i in valid] + except ValueError: + print(grey(" Enter numbers separated by commas, e.g. 1,3")) + + # ── key collection ─────────────────────────────────────────────────────── + + @staticmethod + def _collect_keys(providers: List[Dict]) -> Dict[str, str]: + env_vars: Dict[str, str] = {} + print() + print(_divider()) + print() + print(f" {bold('Enter your API keys')}") + print(grey(" Keys are stored only in the generated config file.")) + print( + grey( + " Tip: add litellm_config.yaml to .gitignore to avoid committing secrets." + ) + ) + print() + + for p in providers: + if p["env_key"] is None: + print( + f" {green(p['name'])}: {grey('no key needed (uses local Ollama)')}" + ) + continue + + key = SetupWizard._prompt_key(p) + if not key: + continue + + for extra_key, extra_hint in zip( + p.get("extra_keys", []), p.get("extra_hints", []) + ): + val = _styled_input(f" {blue('❯')} {extra_key} {grey(extra_hint)}: ") + if val: + env_vars[extra_key] = val + + if p.get("needs_api_base"): + api_base = _styled_input( + f" {blue('❯')} Azure endpoint URL {grey(p.get('api_base_hint', ''))}: " + ) + if api_base: + env_vars[f"_LITELLM_AZURE_API_BASE_{p['id'].upper()}"] = api_base + deployment = _styled_input( + f" {blue('❯')} Azure deployment name {grey('(e.g. 
my-gpt4o)')}: " + ) + if deployment: + env_vars[ + f"_LITELLM_AZURE_DEPLOYMENT_{p['id'].upper()}" + ] = deployment + + # Store the key returned by validation — may be a re-entered replacement + env_vars[p["env_key"]] = SetupWizard._validate_and_report(p, key) + + return env_vars + + @staticmethod + def _prompt_key(provider: Dict) -> str: + """Prompt for a provider's API key, with skip option. Returns the key or ''.""" + hint = grey(provider.get("key_hint", "")) + while True: + key = _styled_input( + f" {blue('❯')} {bold(provider['name'])} API key {hint}: " + ) + if key: + return key + print(grey(" Key is required. Leave blank to skip this provider.")) + if _styled_input(grey(" Skip? (y/N): ")).lower() == "y": + return "" + + @staticmethod + def _validate_and_report(provider: Dict, api_key: str) -> str: + """ + Validate credentials using litellm.utils.check_valid_key and print result. + Offers a re-entry loop on failure. Returns the final (possibly re-entered) key. + """ + test_model: Optional[str] = provider.get("test_model") + if not test_model: + return api_key # Azure / Bedrock / Ollama — skip validation + + while True: + print( + f" {grey('Testing connection to ' + provider['name'] + '...')}", + flush=True, + ) + valid = check_valid_key(model=test_model, api_key=api_key) + if valid: + print( + f" {green(_CHECK)} {bold(provider['name'])} connected successfully" + ) + return api_key + + print(f" {_CROSS} {bold(provider['name'])} {grey('— invalid API key')}") + if ( + _styled_input(f" {blue('❯')} Re-enter key? 
{grey('(y/N)')}: ").lower() + != "y" + ): + return api_key + + hint = grey(provider.get("key_hint", "")) + new_key = _styled_input( + f" {blue('❯')} {bold(provider['name'])} API key {hint}: " + ) + if not new_key: + return api_key + api_key = new_key + + # ── proxy settings ─────────────────────────────────────────────────────── + + @staticmethod + def _proxy_settings() -> "tuple[int, str]": + print() + print(_divider()) + print() + print(f" {bold('Proxy settings')}") + print() + port = 4000 + while True: + port_raw = _styled_input(f" {blue('❯')} Port {grey('[4000]')}: ") + if not port_raw: + break + if port_raw.isdigit() and 1 <= int(port_raw) <= 65535: + port = int(port_raw) + break + print(grey(" Enter a valid port number (1–65535).")) + key_raw = _styled_input(f" {blue('❯')} Master key {grey('[auto-generate]')}: ") + master_key = key_raw if key_raw else f"sk-{secrets.token_urlsafe(32)}" + return port, master_key + + # ── config generation ──────────────────────────────────────────────────── + + @staticmethod + def _build_config( + providers: List[Dict], + env_vars: Dict[str, str], + master_key: str, + ) -> str: + env_copy = dict(env_vars) # work on a copy — do not mutate caller's dict + lines = ["model_list:"] + for p in providers: + # Only emit models for providers that actually have credentials + has_creds = p["env_key"] is None or p["env_key"] in env_copy + if not has_creds: + continue + + if p["id"] == "azure": + deployment = env_copy.pop( + f"_LITELLM_AZURE_DEPLOYMENT_{p['id'].upper()}", "" + ) + if not deployment: + continue # skip Azure entirely if no deployment name was provided + models = [f"azure/{deployment}"] + else: + models = p["models"] + + for model in models: + raw_display = model.split("/")[-1] if "/" in model else model + # Qualify azure display names to avoid collision with OpenAI model names + display = f"azure-{raw_display}" if p["id"] == "azure" else raw_display + lines += [ + f" - model_name: {display}", + " litellm_params:", + f" model: 
{model}", + ] + if p["env_key"] and p["env_key"] in env_copy: + lines.append(f" api_key: os.environ/{p['env_key']}") + if p.get("api_base"): + lines.append( + f' api_base: "{_yaml_escape(str(p["api_base"]))}"' + ) + elif p.get("needs_api_base"): + azure_base_key = f"_LITELLM_AZURE_API_BASE_{p['id'].upper()}" + if azure_base_key in env_copy: + lines.append( + f' api_base: "{_yaml_escape(env_copy.pop(azure_base_key))}"' + ) + if p.get("api_version"): + lines.append(f" api_version: {p['api_version']}") + + lines += [ + "", + "general_settings:", + f' master_key: "{_yaml_escape(master_key)}"', + "", + ] + + real_vars = {k: v for k, v in env_copy.items() if not k.startswith("_LITELLM_")} + if real_vars: + lines.append("environment_variables:") + for k, v in real_vars.items(): + lines.append(f' {k}: "{_yaml_escape(v)}"') + lines.append("") + + return "\n".join(lines) + + # ── success + launch ───────────────────────────────────────────────────── + + @staticmethod + def _print_success(config_path: Path, port: int, master_key: str) -> None: + print() + print(_divider()) + print() + print(f" {green(_CHECK + ' Config saved')} → {bold(str(config_path))}") + print() + print(f" {bold('To start your proxy:')}") + print() + print(f" {grey('$')} litellm --config {config_path} --port {port}") + print() + print(f" {bold('Then set your client:')}") + print() + print(f" export OPENAI_BASE_URL=http://localhost:{port}") + print(f" export OPENAI_API_KEY={master_key}") + print() + print(_divider()) + print() + + @staticmethod + def _offer_start(config_path: Path, port: int, master_key: str) -> None: + start = _styled_input( + f" {blue('❯')} Start the proxy now? {grey('(Y/n)')}: " + ).lower() + if start not in ("", "y", "yes"): + print() + print( + f" Run {bold(f'litellm --config {config_path}')} whenever you're ready." 
+ ) + print() + print( + grey(f" Quick test once running: curl http://localhost:{port}/health") + ) + print() + return + + print() + print(_divider()) + print() + print(f" {bold('Proxy is starting on')} http://localhost:{port}") + print() + print(grey(" Your proxy is OpenAI-compatible. Point any OpenAI SDK at it:")) + print() + print(f" export OPENAI_BASE_URL=http://localhost:{port}") + print(f" export OPENAI_API_KEY={master_key}") + print() + print(grey(" Quick test (in another terminal):")) + print() + print(f" curl http://localhost:{port}/health") + print() + print(grey(" Dashboard:")) + print() + print(f" http://localhost:{port}/ui {grey('(login with your master key)')}") + print() + print(_divider()) + print() + print(f" {green(_CHECK)} Starting… {grey('(Ctrl+C to stop)')}") + print() + + scripts_dir = sysconfig.get_path("scripts") + litellm_bin = os.path.join(scripts_dir or "", "litellm") + try: + os.execlp( + litellm_bin, + litellm_bin, + "--config", + str(config_path), + "--port", + str(port), + ) # noqa: S606 + except OSError as exc: + print(f"\n {bold(_CROSS + ' Could not start proxy:')} {exc}") + print(f" Run manually: litellm --config {config_path} --port {port}\n") + + +# --------------------------------------------------------------------------- +# Public entrypoint +# --------------------------------------------------------------------------- + + +def run_setup_wizard() -> None: + """Run the interactive setup wizard. 
Called by `litellm --setup`.""" + SetupWizard.run() diff --git a/litellm/types/guardrails.py b/litellm/types/guardrails.py index 27fa27e6da3..f798f05380d 100644 --- a/litellm/types/guardrails.py +++ b/litellm/types/guardrails.py @@ -79,6 +79,7 @@ class SupportedGuardrailIntegrations(Enum): SEMANTIC_GUARD = "semantic_guard" MCP_END_USER_PERMISSION = "mcp_end_user_permission" BLOCK_CODE_EXECUTION = "block_code_execution" + MCP_JWT_SIGNER = "mcp_jwt_signer" class Role(Enum): diff --git a/litellm/types/proxy/vantage_endpoints.py b/litellm/types/proxy/vantage_endpoints.py index cf4f0a6685f..84199e78340 100644 --- a/litellm/types/proxy/vantage_endpoints.py +++ b/litellm/types/proxy/vantage_endpoints.py @@ -39,7 +39,8 @@ class VantageExportRequest(BaseModel): """Request model for Vantage export operations (actual export, no default limit)""" limit: Optional[int] = Field( - None, description="Optional limit on number of records to export (default: no limit)" + None, + description="Optional limit on number of records to export (default: no limit)", ) start_time_utc: Optional[datetime] = Field( None, description="Start time for data export in UTC" diff --git a/litellm/types/videos/utils.py b/litellm/types/videos/utils.py index 3a100129bcd..bf51fdda370 100644 --- a/litellm/types/videos/utils.py +++ b/litellm/types/videos/utils.py @@ -195,7 +195,9 @@ def decode_character_id_with_provider(encoded_character_id: str) -> DecodedChara character_id=decoded_character_id, ) except Exception as e: - verbose_logger.debug(f"Error decoding character_id '{encoded_character_id}': {e}") + verbose_logger.debug( + f"Error decoding character_id '{encoded_character_id}': {e}" + ) return DecodedCharacterId( custom_llm_provider=None, model_id=None, diff --git a/litellm/videos/main.py b/litellm/videos/main.py index d32a873e0b7..cd61293cd1c 100644 --- a/litellm/videos/main.py +++ b/litellm/videos/main.py @@ -1186,13 +1186,17 @@ def video_create_character( litellm_params = 
GenericLiteLLMParams(**kwargs) - provider_config: Optional[BaseVideoConfig] = ProviderConfigManager.get_provider_video_config( + provider_config: Optional[ + BaseVideoConfig + ] = ProviderConfigManager.get_provider_video_config( model=None, provider=litellm.LlmProviders(custom_llm_provider), ) if provider_config is None: - raise ValueError(f"video create character is not supported for {custom_llm_provider}") + raise ValueError( + f"video create character is not supported for {custom_llm_provider}" + ) local_vars.update(kwargs) request_params: Dict = {"name": name} @@ -1311,13 +1315,17 @@ def video_get_character( litellm_params = GenericLiteLLMParams(**kwargs) - provider_config: Optional[BaseVideoConfig] = ProviderConfigManager.get_provider_video_config( + provider_config: Optional[ + BaseVideoConfig + ] = ProviderConfigManager.get_provider_video_config( model=None, provider=litellm.LlmProviders(custom_llm_provider), ) if provider_config is None: - raise ValueError(f"video get character is not supported for {custom_llm_provider}") + raise ValueError( + f"video get character is not supported for {custom_llm_provider}" + ) local_vars.update(kwargs) request_params: Dict = {"character_id": character_id} @@ -1439,7 +1447,9 @@ def video_edit( litellm_params = GenericLiteLLMParams(**kwargs) - provider_config: Optional[BaseVideoConfig] = ProviderConfigManager.get_provider_video_config( + provider_config: Optional[ + BaseVideoConfig + ] = ProviderConfigManager.get_provider_video_config( model=None, provider=litellm.LlmProviders(custom_llm_provider), ) @@ -1572,16 +1582,24 @@ def video_extension( litellm_params = GenericLiteLLMParams(**kwargs) - provider_config: Optional[BaseVideoConfig] = ProviderConfigManager.get_provider_video_config( + provider_config: Optional[ + BaseVideoConfig + ] = ProviderConfigManager.get_provider_video_config( model=None, provider=litellm.LlmProviders(custom_llm_provider), ) if provider_config is None: - raise ValueError(f"video extension is not 
supported for {custom_llm_provider}") + raise ValueError( + f"video extension is not supported for {custom_llm_provider}" + ) local_vars.update(kwargs) - request_params: Dict = {"video_id": video_id, "prompt": prompt, "seconds": seconds} + request_params: Dict = { + "video_id": video_id, + "prompt": prompt, + "seconds": seconds, + } litellm_logging_obj.update_environment_variables( model="", diff --git a/model_prices_and_context_window.json b/model_prices_and_context_window.json index 6786fc33595..181045809f8 100644 --- a/model_prices_and_context_window.json +++ b/model_prices_and_context_window.json @@ -32354,6 +32354,53 @@ "supports_vision": true, "supports_web_search": true }, + "xai/grok-4.20-multi-agent-beta-0309": { + "cache_read_input_token_cost": 2e-07, + "input_cost_per_token": 2e-06, + "litellm_provider": "xai", + "max_input_tokens": 2000000, + "max_output_tokens": 2000000, + "max_tokens": 2000000, + "mode": "chat", + "output_cost_per_token": 6e-06, + "source": "https://docs.x.ai/docs/models", + "supports_function_calling": true, + "supports_reasoning": true, + "supports_tool_choice": true, + "supports_vision": true, + "supports_web_search": true + }, + "xai/grok-4.20-beta-0309-reasoning": { + "cache_read_input_token_cost": 2e-07, + "input_cost_per_token": 2e-06, + "litellm_provider": "xai", + "max_input_tokens": 2000000, + "max_output_tokens": 2000000, + "max_tokens": 2000000, + "mode": "chat", + "output_cost_per_token": 6e-06, + "source": "https://docs.x.ai/docs/models", + "supports_function_calling": true, + "supports_reasoning": true, + "supports_tool_choice": true, + "supports_vision": true, + "supports_web_search": true + }, + "xai/grok-4.20-beta-0309-non-reasoning": { + "cache_read_input_token_cost": 2e-07, + "input_cost_per_token": 2e-06, + "litellm_provider": "xai", + "max_input_tokens": 2000000, + "max_output_tokens": 2000000, + "max_tokens": 2000000, + "mode": "chat", + "output_cost_per_token": 6e-06, + "source": 
"https://docs.x.ai/docs/models", + "supports_function_calling": true, + "supports_tool_choice": true, + "supports_vision": true, + "supports_web_search": true + }, "xai/grok-beta": { "input_cost_per_token": 5e-06, "litellm_provider": "xai", diff --git a/scripts/install.sh b/scripts/install.sh new file mode 100755 index 00000000000..b9912287b70 --- /dev/null +++ b/scripts/install.sh @@ -0,0 +1,143 @@ +#!/usr/bin/env bash +# LiteLLM Installer +# Usage: curl -fsSL https://raw.githubusercontent.com/BerriAI/litellm/main/scripts/install.sh | sh +# +# NOTE: set -e without pipefail for POSIX sh compatibility (dash on Ubuntu/Debian +# ignores the shebang when invoked as `sh` and does not support `pipefail`). +set -eu + +MIN_PYTHON_MAJOR=3 +MIN_PYTHON_MINOR=9 + +# NOTE: before merging, this must stay as "litellm[proxy]" to install from PyPI. +LITELLM_PACKAGE="litellm[proxy]" + +# ── colours ──────────────────────────────────────────────────────────────── +if [ -t 1 ]; then + BOLD='\033[1m' + GREEN='\033[38;2;78;186;101m' + GREY='\033[38;2;153;153;153m' + RESET='\033[0m' +else + BOLD='' GREEN='' GREY='' RESET='' +fi + +info() { printf "${GREY} %s${RESET}\n" "$*"; } +success() { printf "${GREEN} ✔ %s${RESET}\n" "$*"; } +header() { printf "${BOLD} %s${RESET}\n" "$*"; } +die() { printf "\n Error: %s\n\n" "$*" >&2; exit 1; } + +# ── banner ───────────────────────────────────────────────────────────────── +echo "" +cat << 'EOF' + ██╗ ██╗████████╗███████╗██╗ ██╗ ███╗ ███╗ + ██║ ██║╚══██╔══╝██╔════╝██║ ██║ ████╗ ████║ + ██║ ██║ ██║ █████╗ ██║ ██║ ██╔████╔██║ + ██║ ██║ ██║ ██╔══╝ ██║ ██║ ██║╚██╔╝██║ + ███████╗██║ ██║ ███████╗███████╗███████╗██║ ╚═╝ ██║ + ╚══════╝╚═╝ ╚═╝ ╚══════╝╚══════╝╚══════╝╚═╝ ╚═╝ +EOF +printf " ${BOLD}LiteLLM Installer${RESET} ${GREY}— unified gateway for 100+ LLM providers${RESET}\n\n" + +# ── OS detection ─────────────────────────────────────────────────────────── +OS="$(uname -s)" +ARCH="$(uname -m)" + +case "$OS" in + Darwin) PLATFORM="macOS ($ARCH)" ;; + 
Linux) PLATFORM="Linux ($ARCH)" ;; + *) die "Unsupported OS: $OS. LiteLLM supports macOS and Linux." ;; +esac + +info "Platform: $PLATFORM" + +# ── Python detection ─────────────────────────────────────────────────────── +PYTHON_BIN="" +for candidate in python3 python; do + if command -v "$candidate" >/dev/null 2>&1; then + major="$("$candidate" -c 'import sys; print(sys.version_info.major)' 2>/dev/null || true)" + minor="$("$candidate" -c 'import sys; print(sys.version_info.minor)' 2>/dev/null || true)" + if [ "${major:-0}" -ge "$MIN_PYTHON_MAJOR" ] && [ "${minor:-0}" -ge "$MIN_PYTHON_MINOR" ]; then + PYTHON_BIN="$(command -v "$candidate")" + info "Python: $("$candidate" --version 2>&1)" + break + fi + fi +done + +if [ -z "$PYTHON_BIN" ]; then + die "Python ${MIN_PYTHON_MAJOR}.${MIN_PYTHON_MINOR}+ is required but not found. + Install it from https://python.org/downloads or via your package manager: + macOS: brew install python@3 + Ubuntu: sudo apt install python3 python3-pip" +fi + +# ── pip detection ────────────────────────────────────────────────────────── +if ! "$PYTHON_BIN" -m pip --version >/dev/null 2>&1; then + die "pip is not available. Install it with: + $PYTHON_BIN -m ensurepip --upgrade" +fi + +# ── install ──────────────────────────────────────────────────────────────── +echo "" +header "Installing litellm[proxy]…" +echo "" + +"$PYTHON_BIN" -m pip install --upgrade "${LITELLM_PACKAGE}" \ + || die "pip install failed. Try manually: $PYTHON_BIN -m pip install '${LITELLM_PACKAGE}'" + +# ── find the litellm binary installed by pip for this Python ─────────────── +# sysconfig.get_path('scripts') is where pip puts console scripts — reliable +# even when the Python lives in a libexec/ symlink tree (e.g. Homebrew). +SCRIPTS_DIR="$("$PYTHON_BIN" -c 'import sysconfig; print(sysconfig.get_path("scripts"))')" +LITELLM_BIN="${SCRIPTS_DIR}/litellm" + +if [ ! 
-x "$LITELLM_BIN" ]; then + # Fall back to user-base bin (pip install --user) + USER_BIN="$("$PYTHON_BIN" -c 'import site; print(site.getuserbase())')/bin" + LITELLM_BIN="${USER_BIN}/litellm" +fi + +if [ ! -x "$LITELLM_BIN" ]; then + die "litellm binary not found after install. Try: $PYTHON_BIN -m pip install --user '${LITELLM_PACKAGE}'" +fi + +# ── success banner ───────────────────────────────────────────────────────── +echo "" +success "LiteLLM installed" + +installed_ver="$("$LITELLM_BIN" --version 2>&1 | grep -oE '[0-9]+\.[0-9]+\.[0-9]+' | head -1 || true)" +[ -n "$installed_ver" ] && info "Version: $installed_ver" + +# ── PATH hint ────────────────────────────────────────────────────────────── +if ! command -v litellm >/dev/null 2>&1; then + info "Note: add litellm to your PATH: export PATH=\"\$PATH:${SCRIPTS_DIR}\"" +fi + +# ── launch setup wizard ──────────────────────────────────────────────────── +echo "" +printf " ${BOLD}Run the interactive setup wizard?${RESET} ${GREY}(Y/n)${RESET}: " +# /dev/tty may be unavailable in Docker/CI — default to yes if it can't be read +answer="" +if [ -r /dev/tty ]; then + read -r answer = 2, ( + f"Expected at least 2 batch items (trace-create + generation-create) " + f"after filtering by trace_id={trace_id}, " + f"but got {len(actual_request_body['batch'])}. 
" + f"Items: {json.dumps(actual_request_body['batch'], indent=2)}" + ) + # Replace dynamic values in actual request body for item in actual_request_body["batch"]: @@ -150,19 +173,36 @@ async def _verify_langfuse_call( """Helper method to verify Langfuse API calls""" await asyncio.sleep(3) - # Verify the call + # Verify at least one call was made assert mock_post.call_count >= 1 - url = mock_post.call_args[0][0] - request_body = mock_post.call_args[1].get("content") - # Parse the JSON string into a dict for assertions - actual_request_body = json.loads(request_body) + # Aggregate batch items from ALL calls — the Langfuse SDK may split + # trace-create and generation-create across separate HTTP flushes. + langfuse_url = "https://us.cloud.langfuse.com/api/public/ingestion" + all_batch_items: list = [] + metadata: Optional[dict] = None + for call in mock_post.call_args_list: + url = call[0][0] + if url != langfuse_url: + continue + request_body = call[1].get("content") + if request_body: + body = json.loads(request_body) + all_batch_items.extend(body.get("batch", [])) + if metadata is None: + metadata = body.get("metadata") + + assert len(all_batch_items) > 0, "No Langfuse ingestion calls found" + assert metadata is not None, "No metadata found in Langfuse calls" + + actual_request_body = { + "batch": all_batch_items, + "metadata": metadata, + } - print("\nMocked Request Details:") - print(f"URL: {url}") + print("\nMocked Request Details (aggregated from all calls):") print(f"Request Body: {json.dumps(actual_request_body, indent=4)}") - assert url == "https://us.cloud.langfuse.com/api/public/ingestion" assert_langfuse_request_matches_expected( actual_request_body, expected_file_name, @@ -170,6 +210,7 @@ async def _verify_langfuse_call( ) @pytest.mark.asyncio + @pytest.mark.flaky(retries=3, delay=1) async def test_langfuse_logging_completion(self, mock_setup): """Test Langfuse logging for chat completion""" setup = mock_setup @@ -185,6 +226,7 @@ async def 
test_langfuse_logging_completion(self, mock_setup): ) @pytest.mark.asyncio + @pytest.mark.flaky(retries=3, delay=1) async def test_langfuse_logging_completion_with_tags(self, mock_setup): """Test Langfuse logging for chat completion with tags""" setup = mock_setup @@ -203,6 +245,7 @@ async def test_langfuse_logging_completion_with_tags(self, mock_setup): ) @pytest.mark.asyncio + @pytest.mark.flaky(retries=3, delay=1) async def test_langfuse_logging_completion_with_tags_stream(self, mock_setup): """Test Langfuse logging for chat completion with tags""" setup = mock_setup @@ -223,6 +266,7 @@ async def test_langfuse_logging_completion_with_tags_stream(self, mock_setup): ) @pytest.mark.asyncio + @pytest.mark.flaky(retries=3, delay=1) async def test_langfuse_logging_completion_with_langfuse_metadata(self, mock_setup): """Test Langfuse logging for chat completion with metadata for langfuse""" setup = mock_setup @@ -252,6 +296,7 @@ async def test_langfuse_logging_completion_with_langfuse_metadata(self, mock_set ) @pytest.mark.asyncio + @pytest.mark.flaky(retries=3, delay=1) async def test_langfuse_logging_with_non_serializable_metadata(self, mock_setup): """Test Langfuse logging with metadata that requires preparation (Pydantic models, sets, etc)""" from pydantic import BaseModel @@ -358,6 +403,7 @@ async def test_langfuse_logging_with_various_metadata_types( ) @pytest.mark.asyncio + @pytest.mark.flaky(retries=3, delay=1) async def test_langfuse_logging_completion_with_malformed_llm_response( self, mock_setup ): @@ -387,6 +433,7 @@ async def test_langfuse_logging_completion_with_malformed_llm_response( ) @pytest.mark.asyncio + @pytest.mark.flaky(retries=3, delay=1) async def test_langfuse_logging_completion_with_bedrock_llm_response( self, mock_setup ): @@ -418,6 +465,7 @@ async def test_langfuse_logging_completion_with_bedrock_llm_response( setup["mock_post"], "completion_with_bedrock_call.json", setup["trace_id"] ) @pytest.mark.asyncio + @pytest.mark.flaky(retries=3, 
delay=1) async def test_langfuse_logging_completion_with_vertex_llm_response( self, mock_setup ): @@ -449,6 +497,7 @@ async def test_langfuse_logging_completion_with_vertex_llm_response( ) @pytest.mark.asyncio + @pytest.mark.flaky(retries=3, delay=1) async def test_langfuse_logging_vllm_embedding(self, mock_setup): """ Test that the request sent to the vllm embedding endpoint is correct. @@ -500,6 +549,7 @@ async def test_langfuse_logging_vllm_embedding(self, mock_setup): ) @pytest.mark.asyncio + @pytest.mark.flaky(retries=3, delay=1) async def test_langfuse_logging_with_router(self, mock_setup): """Test Langfuse logging with router""" litellm._turn_on_debug() diff --git a/tests/spend_tracking_tests/test_spend_accuracy_tests.py b/tests/spend_tracking_tests/test_spend_accuracy_tests.py index 11b8e209dc1..90efe28ab84 100644 --- a/tests/spend_tracking_tests/test_spend_accuracy_tests.py +++ b/tests/spend_tracking_tests/test_spend_accuracy_tests.py @@ -308,16 +308,19 @@ async def test_long_term_spend_accuracy_with_bursts(): response = await chat_completion(session, key) print(f"Burst 2 - Request {i + 1}/{BURST_2_REQUESTS} completed") - # Poll until key spend reflects burst 2 - burst_1_spend = intermediate_key_info["info"]["spend"] + # Poll until key spend reaches expected total (burst 1 + burst 2) start = time.time() while time.time() - start < 120: key_info_check = await get_spend_info(session, "key", key) current_spend = key_info_check["info"]["spend"] - if current_spend > burst_1_spend: - print(f"Key spend increased to {current_spend} after {time.time() - start:.1f}s") + if abs(current_spend - expected_spend) < TOLERANCE: + print( + f"Total spend reached expected {expected_spend} after {time.time() - start:.1f}s" + ) break - print(f"Key spend still {current_spend}, waiting for burst 2 flush...") + print( + f"Key spend {current_spend}, expected {expected_spend}, waiting..." 
+ ) await asyncio.sleep(10) # Allow extra time for all entity spend aggregations diff --git a/tests/test_litellm/interactions/test_openapi_compliance.py b/tests/test_litellm/interactions/test_openapi_compliance.py index 5b2edcf1ee7..a22244f3f8c 100644 --- a/tests/test_litellm/interactions/test_openapi_compliance.py +++ b/tests/test_litellm/interactions/test_openapi_compliance.py @@ -77,6 +77,11 @@ def test_input_types_match_spec(self, spec_dict): schema = spec_dict["components"]["schemas"]["CreateModelInteractionParams"] input_schema = schema["properties"]["input"] + # The input property may be inline oneOf or a $ref to InteractionsInput + if "$ref" in input_schema: + ref_name = input_schema["$ref"].split("/")[-1] + input_schema = spec_dict["components"]["schemas"][ref_name] + # Should be oneOf with multiple types assert "oneOf" in input_schema @@ -100,10 +105,21 @@ def test_content_schema_uses_discriminator(self, spec_dict): assert "discriminator" in content_schema assert content_schema["discriminator"]["propertyName"] == "type" - # Check TextContent is an option - mapping = content_schema["discriminator"]["mapping"] - assert "text" in mapping - print(f"Content type discriminator mapping: {list(mapping.keys())}") + # Check TextContent is an option (via mapping if present, or via oneOf refs) + mapping = content_schema["discriminator"].get("mapping") + if mapping: + assert "text" in mapping + print(f"Content type discriminator mapping: {list(mapping.keys())}") + else: + # Discriminator without explicit mapping — verify via oneOf + one_of = content_schema.get("oneOf", []) + ref_names = [ + opt["$ref"].split("/")[-1] for opt in one_of if "$ref" in opt + ] + assert "TextContent" in ref_names, ( + f"TextContent not found in oneOf refs: {ref_names}" + ) + print(f"Content type discriminator (no mapping), oneOf refs: {ref_names}") def test_text_content_schema(self, spec_dict): """Verify TextContent schema.""" diff --git 
a/tests/test_litellm/llms/xai/test_xai_cost_calculator.py b/tests/test_litellm/llms/xai/test_xai_cost_calculator.py index 1c2de953c2f..00955bc5252 100644 --- a/tests/test_litellm/llms/xai/test_xai_cost_calculator.py +++ b/tests/test_litellm/llms/xai/test_xai_cost_calculator.py @@ -386,8 +386,56 @@ def test_web_search_cost_without_prompt_tokens_details(self): completion_tokens=50, total_tokens=150, ) - + web_search_cost = cost_per_web_search_request(usage=usage, model_info={}) - + # Expected cost: No web search data = $0.0 assert web_search_cost == 0.0 + + def test_grok_4_20_beta_reasoning_cost_calculation(self): + """Test cost calculation for grok-4.20-beta-0309-reasoning model.""" + usage = Usage(prompt_tokens=100, completion_tokens=200, total_tokens=300) + + prompt_cost, completion_cost = cost_per_token( + model="grok-4.20-beta-0309-reasoning", usage=usage + ) + + # Input: 100 tokens * $2e-6 = $0.0002 + # Output: 200 tokens * $6e-6 = $0.0012 + expected_prompt_cost = 100 * 2e-6 + expected_completion_cost = 200 * 6e-6 + + assert math.isclose(prompt_cost, expected_prompt_cost, rel_tol=1e-10) + assert math.isclose(completion_cost, expected_completion_cost, rel_tol=1e-10) + + def test_grok_4_20_beta_non_reasoning_cost_calculation(self): + """Test cost calculation for grok-4.20-beta-0309-non-reasoning model.""" + usage = Usage(prompt_tokens=50, completion_tokens=100, total_tokens=150) + + prompt_cost, completion_cost = cost_per_token( + model="grok-4.20-beta-0309-non-reasoning", usage=usage + ) + + # Input: 50 tokens * $2e-6 = $0.0001 + # Output: 100 tokens * $6e-6 = $0.0006 + expected_prompt_cost = 50 * 2e-6 + expected_completion_cost = 100 * 6e-6 + + assert math.isclose(prompt_cost, expected_prompt_cost, rel_tol=1e-10) + assert math.isclose(completion_cost, expected_completion_cost, rel_tol=1e-10) + + def test_grok_4_20_multi_agent_cost_calculation(self): + """Test cost calculation for grok-4.20-multi-agent-beta-0309 model.""" + usage = Usage(prompt_tokens=200, 
completion_tokens=300, total_tokens=500) + + prompt_cost, completion_cost = cost_per_token( + model="grok-4.20-multi-agent-beta-0309", usage=usage + ) + + # Input: 200 tokens * $2e-6 = $0.0004 + # Output: 300 tokens * $6e-6 = $0.0018 + expected_prompt_cost = 200 * 2e-6 + expected_completion_cost = 300 * 6e-6 + + assert math.isclose(prompt_cost, expected_prompt_cost, rel_tol=1e-10) + assert math.isclose(completion_cost, expected_completion_cost, rel_tol=1e-10) diff --git a/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_hook_extra_headers.py b/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_hook_extra_headers.py new file mode 100644 index 00000000000..32f3a340855 --- /dev/null +++ b/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_hook_extra_headers.py @@ -0,0 +1,707 @@ +""" +Tests for pre_mcp_call guardrail hook header mutation support. + +Validates that: +1. _convert_mcp_hook_response_to_kwargs extracts extra_headers from hook response +2. pre_call_tool_check returns hook-provided extra_headers AND modified arguments +3. call_tool flows hook headers and modified arguments downstream +4. Hook-provided headers take highest priority (merge after static_headers) +5. OpenAPI-backed servers log a warning and continue (skip injection) when hook headers are present +6. JWT claims are propagated in both standard and virtual-key fast paths +7. 
Backward compatibility: hooks without extra_headers continue to work +""" + +import asyncio +from typing import Any, Dict, Optional +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from fastapi import HTTPException + +from litellm.proxy._experimental.mcp_server.mcp_server_manager import MCPServerManager +from litellm.proxy._types import UserAPIKeyAuth +from litellm.proxy.utils import ProxyLogging +from litellm.types.mcp import MCPAuth, MCPTransport +from litellm.types.mcp_server.mcp_server_manager import MCPServer + + +class TestConvertMcpHookResponseToKwargs: + """Tests for ProxyLogging._convert_mcp_hook_response_to_kwargs""" + + def setup_method(self): + self.proxy_logging = ProxyLogging(user_api_key_cache=MagicMock()) + + def test_returns_original_kwargs_when_response_is_none(self): + original = {"arguments": {"key": "val"}, "name": "tool"} + result = self.proxy_logging._convert_mcp_hook_response_to_kwargs( + None, original + ) + assert result == original + + def test_returns_original_kwargs_when_response_is_empty_dict(self): + original = {"arguments": {"key": "val"}} + result = self.proxy_logging._convert_mcp_hook_response_to_kwargs({}, original) + assert result == original + + def test_extracts_modified_arguments(self): + original = {"arguments": {"old": "value"}} + response = {"modified_arguments": {"new": "value"}} + result = self.proxy_logging._convert_mcp_hook_response_to_kwargs( + response, original + ) + assert result["arguments"] == {"new": "value"} + + def test_extracts_extra_headers(self): + original = {"arguments": {"key": "val"}} + response = {"extra_headers": {"Authorization": "Bearer signed-jwt"}} + result = self.proxy_logging._convert_mcp_hook_response_to_kwargs( + response, original + ) + assert result["extra_headers"] == {"Authorization": "Bearer signed-jwt"} + + def test_extracts_both_arguments_and_headers(self): + original = {"arguments": {"old": "value"}} + response = { + "modified_arguments": {"new": "value"}, + 
"extra_headers": {"X-Custom": "header-val"}, + } + result = self.proxy_logging._convert_mcp_hook_response_to_kwargs( + response, original + ) + assert result["arguments"] == {"new": "value"} + assert result["extra_headers"] == {"X-Custom": "header-val"} + + def test_no_extra_headers_key_preserves_original(self): + """Backward compat: hooks that only return modified_arguments still work.""" + original = {"arguments": {"key": "val"}} + response = {"modified_arguments": {"key": "new_val"}} + result = self.proxy_logging._convert_mcp_hook_response_to_kwargs( + response, original + ) + assert "extra_headers" not in result + assert result["arguments"] == {"key": "new_val"} + + def test_empty_extra_headers_not_set(self): + """Empty dict for extra_headers is falsy and should not be set.""" + original = {"arguments": {"key": "val"}} + response = {"extra_headers": {}} + result = self.proxy_logging._convert_mcp_hook_response_to_kwargs( + response, original + ) + assert "extra_headers" not in result + + +class TestPreCallToolCheckReturnsHeaders: + """Tests that pre_call_tool_check returns hook-provided headers.""" + + def _make_server(self, name="test_server"): + return MCPServer( + server_id="test-id", + name=name, + server_name=name, + url="https://example.com", + transport=MCPTransport.http, + auth_type=MCPAuth.none, + ) + + @pytest.mark.asyncio + async def test_returns_empty_dict_when_hook_has_no_headers(self): + manager = MCPServerManager() + server = self._make_server() + + proxy_logging = MagicMock(spec=ProxyLogging) + proxy_logging._create_mcp_request_object_from_kwargs = MagicMock( + return_value=MagicMock() + ) + proxy_logging._convert_mcp_to_llm_format = MagicMock( + return_value={"model": "fake"} + ) + proxy_logging.pre_call_hook = AsyncMock( + return_value={"modified_arguments": {"key": "val"}} + ) + proxy_logging._convert_mcp_hook_response_to_kwargs = MagicMock( + return_value={"arguments": {"key": "val"}} + ) + + with patch.object(manager, 
"check_allowed_or_banned_tools", return_value=True): + with patch.object( + manager, + "check_tool_permission_for_key_team", + new_callable=AsyncMock, + ): + with patch.object(manager, "validate_allowed_params"): + result = await manager.pre_call_tool_check( + name="test_tool", + arguments={"key": "val"}, + server_name="test_server", + user_api_key_auth=None, + proxy_logging_obj=proxy_logging, + server=server, + ) + + assert result == {} + + @pytest.mark.asyncio + async def test_returns_extra_headers_from_hook(self): + manager = MCPServerManager() + server = self._make_server() + + hook_headers = {"Authorization": "Bearer signed-jwt", "X-Trace-Id": "abc123"} + + proxy_logging = MagicMock(spec=ProxyLogging) + proxy_logging._create_mcp_request_object_from_kwargs = MagicMock( + return_value=MagicMock() + ) + proxy_logging._convert_mcp_to_llm_format = MagicMock( + return_value={"model": "fake"} + ) + proxy_logging.pre_call_hook = AsyncMock( + return_value={"extra_headers": hook_headers} + ) + proxy_logging._convert_mcp_hook_response_to_kwargs = MagicMock( + return_value={"arguments": {"key": "val"}, "extra_headers": hook_headers} + ) + + with patch.object(manager, "check_allowed_or_banned_tools", return_value=True): + with patch.object( + manager, + "check_tool_permission_for_key_team", + new_callable=AsyncMock, + ): + with patch.object(manager, "validate_allowed_params"): + result = await manager.pre_call_tool_check( + name="test_tool", + arguments={"key": "val"}, + server_name="test_server", + user_api_key_auth=None, + proxy_logging_obj=proxy_logging, + server=server, + ) + + assert result["extra_headers"] == hook_headers + + @pytest.mark.asyncio + async def test_returns_empty_dict_when_hook_returns_none(self): + manager = MCPServerManager() + server = self._make_server() + + proxy_logging = MagicMock(spec=ProxyLogging) + proxy_logging._create_mcp_request_object_from_kwargs = MagicMock( + return_value=MagicMock() + ) + proxy_logging._convert_mcp_to_llm_format = 
MagicMock( + return_value={"model": "fake"} + ) + proxy_logging.pre_call_hook = AsyncMock(return_value=None) + + with patch.object(manager, "check_allowed_or_banned_tools", return_value=True): + with patch.object( + manager, + "check_tool_permission_for_key_team", + new_callable=AsyncMock, + ): + with patch.object(manager, "validate_allowed_params"): + result = await manager.pre_call_tool_check( + name="test_tool", + arguments={"key": "val"}, + server_name="test_server", + user_api_key_auth=None, + proxy_logging_obj=proxy_logging, + server=server, + ) + + assert result == {} + + @pytest.mark.asyncio + async def test_returns_modified_arguments_from_hook(self): + """Modified arguments from the hook must be returned so the caller can use them.""" + manager = MCPServerManager() + server = self._make_server() + + original_args = {"key": "original"} + modified_args = {"key": "modified", "extra": "added"} + + proxy_logging = MagicMock(spec=ProxyLogging) + proxy_logging._create_mcp_request_object_from_kwargs = MagicMock( + return_value=MagicMock() + ) + proxy_logging._convert_mcp_to_llm_format = MagicMock( + return_value={"model": "fake"} + ) + proxy_logging.pre_call_hook = AsyncMock( + return_value={"modified_arguments": modified_args} + ) + proxy_logging._convert_mcp_hook_response_to_kwargs = MagicMock( + return_value={"arguments": modified_args} + ) + + with patch.object(manager, "check_allowed_or_banned_tools", return_value=True): + with patch.object( + manager, + "check_tool_permission_for_key_team", + new_callable=AsyncMock, + ): + with patch.object(manager, "validate_allowed_params"): + result = await manager.pre_call_tool_check( + name="test_tool", + arguments=original_args, + server_name="test_server", + user_api_key_auth=None, + proxy_logging_obj=proxy_logging, + server=server, + ) + + assert result["arguments"] == modified_args + + @pytest.mark.asyncio + async def test_returns_both_modified_arguments_and_headers(self): + """Hook can modify both arguments and 
inject headers simultaneously.""" + manager = MCPServerManager() + server = self._make_server() + + modified_args = {"key": "modified"} + hook_headers = {"Authorization": "Bearer jwt"} + + proxy_logging = MagicMock(spec=ProxyLogging) + proxy_logging._create_mcp_request_object_from_kwargs = MagicMock( + return_value=MagicMock() + ) + proxy_logging._convert_mcp_to_llm_format = MagicMock( + return_value={"model": "fake"} + ) + proxy_logging.pre_call_hook = AsyncMock(return_value={"dummy": True}) + proxy_logging._convert_mcp_hook_response_to_kwargs = MagicMock( + return_value={"arguments": modified_args, "extra_headers": hook_headers} + ) + + with patch.object(manager, "check_allowed_or_banned_tools", return_value=True): + with patch.object( + manager, + "check_tool_permission_for_key_team", + new_callable=AsyncMock, + ): + with patch.object(manager, "validate_allowed_params"): + result = await manager.pre_call_tool_check( + name="test_tool", + arguments={"key": "original"}, + server_name="test_server", + user_api_key_auth=None, + proxy_logging_obj=proxy_logging, + server=server, + ) + + assert result["arguments"] == modified_args + assert result["extra_headers"] == hook_headers + + +class TestCallToolFlowsHookHeaders: + """Tests that call_tool passes hook_extra_headers to _call_regular_mcp_tool.""" + + def _make_server(self, name="test_server"): + return MCPServer( + server_id="test-id", + name=name, + server_name=name, + url="https://example.com", + transport=MCPTransport.http, + auth_type=MCPAuth.none, + ) + + @pytest.mark.asyncio + async def test_hook_headers_passed_to_call_regular_mcp_tool(self): + """Verify that hook_extra_headers kwarg is forwarded.""" + manager = MCPServerManager() + server = self._make_server() + + hook_headers = {"Authorization": "Bearer signed-jwt"} + + with patch.object( + manager, + "_get_mcp_server_from_tool_name", + return_value=server, + ): + with patch.object( + manager, + "pre_call_tool_check", + new_callable=AsyncMock, + 
return_value={"extra_headers": hook_headers}, + ): + with patch.object( + manager, + "_create_during_hook_task", + return_value=asyncio.create_task(asyncio.sleep(0)), + ): + with patch.object( + manager, + "_call_regular_mcp_tool", + new_callable=AsyncMock, + return_value=MagicMock(), + ) as mock_call: + proxy_logging = MagicMock(spec=ProxyLogging) + + await manager.call_tool( + server_name="test_server", + name="test_tool", + arguments={"key": "val"}, + proxy_logging_obj=proxy_logging, + ) + + mock_call.assert_called_once() + call_kwargs = mock_call.call_args + assert call_kwargs.kwargs.get("hook_extra_headers") == hook_headers + + @pytest.mark.asyncio + async def test_no_hook_headers_when_no_proxy_logging(self): + """Without proxy_logging_obj, no pre_call_tool_check runs.""" + manager = MCPServerManager() + server = self._make_server() + + with patch.object( + manager, + "_get_mcp_server_from_tool_name", + return_value=server, + ): + with patch.object( + manager, + "_call_regular_mcp_tool", + new_callable=AsyncMock, + return_value=MagicMock(), + ) as mock_call: + await manager.call_tool( + server_name="test_server", + name="test_tool", + arguments={"key": "val"}, + proxy_logging_obj=None, + ) + + mock_call.assert_called_once() + call_kwargs = mock_call.call_args + assert call_kwargs.kwargs.get("hook_extra_headers") is None + + @pytest.mark.asyncio + async def test_modified_arguments_passed_to_downstream(self): + """Hook-modified arguments must be used for the actual tool call.""" + manager = MCPServerManager() + server = self._make_server() + + modified_args = {"key": "modified_by_hook"} + + with patch.object( + manager, + "_get_mcp_server_from_tool_name", + return_value=server, + ): + with patch.object( + manager, + "pre_call_tool_check", + new_callable=AsyncMock, + return_value={"arguments": modified_args}, + ): + with patch.object( + manager, + "_create_during_hook_task", + return_value=asyncio.create_task(asyncio.sleep(0)), + ): + with patch.object( + 
manager, + "_call_regular_mcp_tool", + new_callable=AsyncMock, + return_value=MagicMock(), + ) as mock_call: + proxy_logging = MagicMock(spec=ProxyLogging) + + await manager.call_tool( + server_name="test_server", + name="test_tool", + arguments={"key": "original"}, + proxy_logging_obj=proxy_logging, + ) + + mock_call.assert_called_once() + call_kwargs = mock_call.call_args + assert call_kwargs.kwargs.get("arguments") == modified_args + + @pytest.mark.asyncio + async def test_openapi_server_warns_and_continues_on_hook_headers(self): + """OpenAPI-backed servers log a warning and continue when hook injects headers.""" + manager = MCPServerManager() + server = MCPServer( + server_id="test-id", + name="openapi_server", + server_name="openapi_server", + url="https://example.com", + transport=MCPTransport.http, + auth_type=MCPAuth.none, + spec_path="/path/to/spec.yaml", + ) + + with patch.object( + manager, "_get_mcp_server_from_tool_name", return_value=server + ): + with patch.object( + manager, + "pre_call_tool_check", + new_callable=AsyncMock, + return_value={"extra_headers": {"Authorization": "Bearer jwt"}}, + ): + with patch.object( + manager, + "_create_during_hook_task", + return_value=asyncio.create_task(asyncio.sleep(0)), + ): + with patch.object( + manager, + "_call_openapi_tool_handler", + new_callable=AsyncMock, + return_value=MagicMock(), + ): + import litellm.proxy._experimental.mcp_server.mcp_server_manager as mgr_mod + + proxy_logging = MagicMock(spec=ProxyLogging) + + with patch.object(mgr_mod, "verbose_logger") as mock_logger: + # Should NOT raise — just warn and proceed + await manager.call_tool( + server_name="openapi_server", + name="test_tool", + arguments={}, + proxy_logging_obj=proxy_logging, + ) + mock_logger.warning.assert_called_once() + assert "header injection is not supported" in mock_logger.warning.call_args[0][0] + + @pytest.mark.asyncio + async def test_openapi_server_no_error_without_hook_headers(self): + """No exception when OpenAPI 
server has no hook-injected headers.""" + manager = MCPServerManager() + server = MCPServer( + server_id="test-id", + name="openapi_server", + server_name="openapi_server", + url="https://example.com", + transport=MCPTransport.http, + auth_type=MCPAuth.none, + spec_path="/path/to/spec.yaml", + ) + + with patch.object( + manager, "_get_mcp_server_from_tool_name", return_value=server + ): + with patch.object( + manager, + "pre_call_tool_check", + new_callable=AsyncMock, + return_value={}, + ): + with patch.object( + manager, + "_create_during_hook_task", + return_value=asyncio.create_task(asyncio.sleep(0)), + ): + with patch.object( + manager, + "_call_openapi_tool_handler", + new_callable=AsyncMock, + return_value=MagicMock(), + ): + proxy_logging = MagicMock(spec=ProxyLogging) + + await manager.call_tool( + server_name="openapi_server", + name="test_tool", + arguments={}, + proxy_logging_obj=proxy_logging, + ) + + +class TestHookHeaderMergePriority: + """Tests that hook-provided headers have highest priority in _call_regular_mcp_tool.""" + + def _make_server( + self, + static_headers: Optional[Dict[str, str]] = None, + extra_headers_config: Optional[list] = None, + ): + return MCPServer( + server_id="test-id", + name="Test Server", + server_name="test_server", + url="https://example.com", + transport=MCPTransport.http, + auth_type=MCPAuth.none, + static_headers=static_headers, + extra_headers=extra_headers_config, + ) + + @pytest.mark.asyncio + async def test_hook_headers_override_static_headers(self): + """Hook headers should take precedence over static_headers.""" + manager = MCPServerManager() + server = self._make_server( + static_headers={"Authorization": "Bearer static-token", "X-Static": "yes"} + ) + + hook_headers = {"Authorization": "Bearer hook-signed-jwt"} + + captured_extra_headers: Dict[str, Any] = {} + + async def fake_create_mcp_client( + server, mcp_auth_header=None, extra_headers=None, stdio_env=None + ): + captured_extra_headers["value"] = 
extra_headers + mock_client = MagicMock() + mock_client.call_tool = AsyncMock(return_value=MagicMock()) + return mock_client + + with patch.object( + manager, "_create_mcp_client", side_effect=fake_create_mcp_client + ): + with patch.object(manager, "_build_stdio_env", return_value=None): + try: + await manager._call_regular_mcp_tool( + mcp_server=server, + original_tool_name="test_tool", + arguments={"key": "val"}, + tasks=[], + mcp_auth_header=None, + mcp_server_auth_headers=None, + oauth2_headers=None, + raw_headers=None, + proxy_logging_obj=None, + hook_extra_headers=hook_headers, + ) + except Exception: + pass + + headers = captured_extra_headers.get("value", {}) + assert headers["Authorization"] == "Bearer hook-signed-jwt" + assert headers["X-Static"] == "yes" + + @pytest.mark.asyncio + async def test_no_hook_headers_preserves_existing_behavior(self): + """When hook_extra_headers is None, existing header logic is unchanged.""" + manager = MCPServerManager() + server = self._make_server( + static_headers={"X-Static": "static-value"} + ) + + captured_extra_headers: Dict[str, Any] = {} + + async def fake_create_mcp_client( + server, mcp_auth_header=None, extra_headers=None, stdio_env=None + ): + captured_extra_headers["value"] = extra_headers + mock_client = MagicMock() + mock_client.call_tool = AsyncMock(return_value=MagicMock()) + return mock_client + + with patch.object( + manager, "_create_mcp_client", side_effect=fake_create_mcp_client + ): + with patch.object(manager, "_build_stdio_env", return_value=None): + try: + await manager._call_regular_mcp_tool( + mcp_server=server, + original_tool_name="test_tool", + arguments={"key": "val"}, + tasks=[], + mcp_auth_header=None, + mcp_server_auth_headers=None, + oauth2_headers=None, + raw_headers=None, + proxy_logging_obj=None, + hook_extra_headers=None, + ) + except Exception: + pass + + headers = captured_extra_headers.get("value", {}) + assert headers == {"X-Static": "static-value"} + + @pytest.mark.asyncio + 
async def test_hook_headers_merge_with_oauth2(self): + """Hook headers merge on top of OAuth2 headers.""" + manager = MCPServerManager() + server = MCPServer( + server_id="test-id", + name="Test Server", + server_name="test_server", + url="https://example.com", + transport=MCPTransport.http, + auth_type=MCPAuth.oauth2, + ) + + captured_extra_headers: Dict[str, Any] = {} + + async def fake_create_mcp_client( + server, mcp_auth_header=None, extra_headers=None, stdio_env=None + ): + captured_extra_headers["value"] = extra_headers + mock_client = MagicMock() + mock_client.call_tool = AsyncMock(return_value=MagicMock()) + return mock_client + + with patch.object( + manager, "_create_mcp_client", side_effect=fake_create_mcp_client + ): + with patch.object(manager, "_build_stdio_env", return_value=None): + try: + await manager._call_regular_mcp_tool( + mcp_server=server, + original_tool_name="test_tool", + arguments={"key": "val"}, + tasks=[], + mcp_auth_header=None, + mcp_server_auth_headers=None, + oauth2_headers={ + "Authorization": "Bearer oauth2-token", + "X-OAuth": "yes", + }, + raw_headers=None, + proxy_logging_obj=None, + hook_extra_headers={ + "Authorization": "Bearer hook-jwt", + "X-Trace-Id": "trace-123", + }, + ) + except Exception: + pass + + headers = captured_extra_headers.get("value", {}) + assert headers["Authorization"] == "Bearer hook-jwt" + assert headers["X-OAuth"] == "yes" + assert headers["X-Trace-Id"] == "trace-123" + + +class TestUserAPIKeyAuthJwtClaims: + """Tests that UserAPIKeyAuth correctly carries jwt_claims.""" + + def test_jwt_claims_field_defaults_to_none(self): + auth = UserAPIKeyAuth(api_key="test-key") + assert auth.jwt_claims is None + + def test_jwt_claims_field_accepts_dict(self): + claims = {"sub": "user-123", "iss": "litellm", "exp": 9999999999} + auth = UserAPIKeyAuth(api_key="test-key", jwt_claims=claims) + assert auth.jwt_claims == claims + assert auth.jwt_claims["sub"] == "user-123" + + def 
test_jwt_claims_backward_compatible_without_field(self): + """Existing code that doesn't pass jwt_claims should still work.""" + auth = UserAPIKeyAuth( + api_key="test-key", + user_id="user-1", + team_id="team-1", + ) + assert auth.jwt_claims is None + assert auth.user_id == "user-1" + + def test_jwt_claims_set_after_construction(self): + """Virtual-key fast path sets jwt_claims after the object is created.""" + auth = UserAPIKeyAuth(api_key="test-key") + assert auth.jwt_claims is None + + claims = {"sub": "user-456", "iss": "okta", "groups": ["admin"]} + auth.jwt_claims = claims + assert auth.jwt_claims == claims + assert auth.jwt_claims["groups"] == ["admin"] diff --git a/tests/test_litellm/proxy/guardrails/test_mcp_jwt_signer.py b/tests/test_litellm/proxy/guardrails/test_mcp_jwt_signer.py new file mode 100644 index 00000000000..247fe7b5764 --- /dev/null +++ b/tests/test_litellm/proxy/guardrails/test_mcp_jwt_signer.py @@ -0,0 +1,1103 @@ +""" +Tests for the MCPJWTSigner built-in guardrail. 
+ +Tests cover: + - RSA key generation and loading + - JWT signing and JWKS format + - Claim building (sub, act, scope) + - Hook fires for call_mcp_tool, skips other call types + - get_mcp_jwt_signer() singleton pattern +""" + +import base64 +import time +from typing import Any, Dict, Optional +from unittest.mock import AsyncMock, MagicMock, patch + +import jwt +import pytest +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric import rsa + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_user_api_key_dict( + user_id: str = "user-123", + team_id: str = "team-abc", + user_email: str = "user@example.com", + end_user_id: Optional[str] = None, +) -> MagicMock: + mock = MagicMock() + mock.user_id = user_id + mock.team_id = team_id + mock.user_email = user_email + mock.end_user_id = end_user_id + mock.org_id = None + mock.token = None + mock.api_key = None + # Explicit None so MagicMock doesn't auto-create a truthy proxy attribute + mock.jwt_claims = None + return mock + + +def _decode_unverified(token: str) -> Dict[str, Any]: + return jwt.decode(token, options={"verify_signature": False}) + + +# --------------------------------------------------------------------------- +# Import target (inline so we can reset the singleton between tests) +# --------------------------------------------------------------------------- + + +def _make_signer(**kwargs: Any): + # Reset singleton before each signer creation to avoid cross-test pollution + import litellm.proxy.guardrails.guardrail_hooks.mcp_jwt_signer.mcp_jwt_signer as mod + + mod._mcp_jwt_signer_instance = None + + from litellm.proxy.guardrails.guardrail_hooks.mcp_jwt_signer.mcp_jwt_signer import ( + MCPJWTSigner, + ) + + return MCPJWTSigner( + guardrail_name="test-jwt-signer", + event_hook="pre_mcp_call", + default_on=True, + **kwargs, + ) + 
+ +# --------------------------------------------------------------------------- +# Key generation tests +# --------------------------------------------------------------------------- + + +def test_auto_generates_rsa_keypair(): + """MCPJWTSigner auto-generates an RSA-2048 keypair when env var is unset.""" + signer = _make_signer() + assert signer._private_key is not None + assert signer._public_key is not None + assert signer._kid is not None and len(signer._kid) == 16 + + +def test_kid_is_deterministic(): + """Two signers built from the same key have the same kid.""" + signer1 = _make_signer() + private_pem = signer1._private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.TraditionalOpenSSL, + encryption_algorithm=serialization.NoEncryption(), + ).decode("utf-8") + + with patch.dict("os.environ", {"MCP_JWT_SIGNING_KEY": private_pem}): + signer2 = _make_signer() + + assert signer1._kid == signer2._kid + + +def test_load_key_from_env_var(): + """MCPJWTSigner loads a user-provided RSA key from the env var.""" + private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048) + pem = private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.TraditionalOpenSSL, + encryption_algorithm=serialization.NoEncryption(), + ).decode("utf-8") + + with patch.dict("os.environ", {"MCP_JWT_SIGNING_KEY": pem}): + signer = _make_signer() + + assert signer._kid is not None + + +# --------------------------------------------------------------------------- +# JWKS tests +# --------------------------------------------------------------------------- + + +def test_get_jwks_format(): + """get_jwks() returns a valid JWKS dict with RSA fields.""" + signer = _make_signer() + jwks = signer.get_jwks() + + assert "keys" in jwks + assert len(jwks["keys"]) == 1 + key = jwks["keys"][0] + + assert key["kty"] == "RSA" + assert key["alg"] == "RS256" + assert key["use"] == "sig" + assert key["kid"] == 
signer._kid + assert "n" in key and len(key["n"]) > 0 + assert "e" in key and key["e"] == "AQAB" # 65537 in base64url + + +def test_jwks_public_key_can_verify_signed_jwt(): + """A JWT signed by MCPJWTSigner can be verified using the JWKS public key.""" + signer = _make_signer(issuer="https://litellm.example.com", audience="mcp") + now = int(time.time()) + claims = {"iss": "https://litellm.example.com", "aud": "mcp", "iat": now, "exp": now + 300} + + token = jwt.encode(claims, signer._private_key, algorithm="RS256") + + # Reconstruct public key from JWKS + jwks = signer.get_jwks() + key_data = jwks["keys"][0] + n = int.from_bytes(base64.urlsafe_b64decode(key_data["n"] + "=="), byteorder="big") + e = int.from_bytes(base64.urlsafe_b64decode(key_data["e"] + "=="), byteorder="big") + from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicNumbers + pub_key = RSAPublicNumbers(e=e, n=n).public_key() + + decoded = jwt.decode( + token, + pub_key, + algorithms=["RS256"], + audience="mcp", + issuer="https://litellm.example.com", + ) + assert decoded["iss"] == "https://litellm.example.com" + + +# --------------------------------------------------------------------------- +# Claim building tests +# --------------------------------------------------------------------------- + + +def test_build_claims_standard_fields(): + """_build_claims() populates iss, aud, iat, exp, nbf.""" + signer = _make_signer(issuer="https://litellm.example.com", audience="mcp", ttl_seconds=300) + user_dict = _make_user_api_key_dict() + data = {"mcp_tool_name": "get_weather"} + + claims = signer._build_claims(user_dict, data) + + assert claims["iss"] == "https://litellm.example.com" + assert claims["aud"] == "mcp" + assert "iat" in claims + assert "exp" in claims + assert claims["exp"] - claims["iat"] == 300 + assert "nbf" in claims + + +def test_build_claims_identity(): + """_build_claims() sets sub from user_id and act from team_id (RFC 8693).""" + signer = _make_signer() + user_dict = 
_make_user_api_key_dict(user_id="user-xyz", team_id="team-eng") + data: Dict[str, Any] = {} + + claims = signer._build_claims(user_dict, data) + + assert claims["sub"] == "user-xyz" + assert claims["act"]["sub"] == "team-eng" + assert claims["email"] == "user@example.com" + + +def test_build_claims_scope_with_tool(): + """_build_claims() encodes tool-specific scope when mcp_tool_name is set.""" + signer = _make_signer() + user_dict = _make_user_api_key_dict() + data = {"mcp_tool_name": "search_web"} + + claims = signer._build_claims(user_dict, data) + + scopes = set(claims["scope"].split()) + assert "mcp:tools/call" in scopes + assert "mcp:tools/search_web:call" in scopes + # Tool-call JWTs must NOT carry mcp:tools/list — least-privilege + assert "mcp:tools/list" not in scopes + + +def test_build_claims_scope_without_tool(): + """_build_claims() includes mcp:tools/list when no specific tool is called.""" + signer = _make_signer() + user_dict = _make_user_api_key_dict() + data: Dict[str, Any] = {} + + claims = signer._build_claims(user_dict, data) + + scopes = set(claims["scope"].split()) + assert "mcp:tools/call" in scopes + assert "mcp:tools/list" in scopes + # No per-tool call scope when no tool name was given + assert not any(s.endswith(":call") and s != "mcp:tools/call" for s in scopes) + + +def test_build_claims_act_fallback_to_litellm_proxy(): + """_build_claims() falls back to 'litellm-proxy' when team_id and org_id are absent.""" + signer = _make_signer() + user_dict = _make_user_api_key_dict() + user_dict.team_id = None + user_dict.org_id = None + + claims = signer._build_claims(user_dict, {}) + + assert claims["act"]["sub"] == "litellm-proxy" + + +def test_build_claims_sub_fallback_to_token_hash(): + """_build_claims() sets sub to an apikey: hash when user_id is absent.""" + signer = _make_signer() + user_dict = _make_user_api_key_dict(user_id="") + user_dict.user_id = None + user_dict.token = "sk-test-api-key-abc123" + + claims = 
signer._build_claims(user_dict, {}) + + assert claims["sub"].startswith("apikey:") + assert len(claims["sub"]) == len("apikey:") + 16 # sha256 hex[:16] + + +def test_build_claims_sub_fallback_to_litellm_proxy_when_no_token(): + """_build_claims() falls back to 'litellm-proxy' when user_id and token are both absent.""" + signer = _make_signer() + user_dict = _make_user_api_key_dict(user_id="") + user_dict.user_id = None + user_dict.token = None + user_dict.api_key = None + + claims = signer._build_claims(user_dict, {}) + + assert claims["sub"] == "litellm-proxy" + + +def test_init_raises_on_zero_ttl(): + """MCPJWTSigner raises ValueError when ttl_seconds is 0.""" + with pytest.raises(ValueError, match="ttl_seconds must be > 0"): + _make_signer(ttl_seconds=0) + + +def test_init_raises_on_negative_ttl(): + """MCPJWTSigner raises ValueError when ttl_seconds is negative.""" + with pytest.raises(ValueError, match="ttl_seconds must be > 0"): + _make_signer(ttl_seconds=-60) + + +def test_jwks_max_age_persistent_key(): + """jwks_max_age is 3600 when key loaded from env var.""" + from cryptography.hazmat.primitives import serialization + from cryptography.hazmat.primitives.asymmetric import rsa as crsa + + private_key = crsa.generate_private_key(public_exponent=65537, key_size=2048) + pem = private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.TraditionalOpenSSL, + encryption_algorithm=serialization.NoEncryption(), + ).decode("utf-8") + + with patch.dict("os.environ", {"MCP_JWT_SIGNING_KEY": pem}): + signer = _make_signer() + + assert signer.jwks_max_age == 3600 + + +def test_jwks_max_age_auto_generated_key(): + """jwks_max_age is 300 for auto-generated (ephemeral) keys.""" + signer = _make_signer() + assert signer.jwks_max_age == 300 + + +# --------------------------------------------------------------------------- +# Hook dispatch tests +# --------------------------------------------------------------------------- + + 
+@pytest.mark.asyncio +async def test_hook_fires_for_call_mcp_tool(): + """async_pre_call_hook() injects Authorization header for call_mcp_tool.""" + signer = _make_signer(issuer="https://litellm.example.com", audience="mcp") + user_dict = _make_user_api_key_dict() + data = {"mcp_tool_name": "do_thing"} + + result = await signer.async_pre_call_hook( + user_api_key_dict=user_dict, + cache=MagicMock(), + data=data, + call_type="call_mcp_tool", + ) + + assert isinstance(result, dict) + assert "extra_headers" in result + assert result["extra_headers"]["Authorization"].startswith("Bearer ") + + +@pytest.mark.asyncio +async def test_hook_skips_non_mcp_call_types(): + """async_pre_call_hook() leaves data unchanged for non-MCP call types.""" + signer = _make_signer() + user_dict = _make_user_api_key_dict() + data = {"messages": [{"role": "user", "content": "hello"}]} + + for call_type in ("completion", "acompletion", "embedding", "list_mcp_tools"): + original_data = {**data} + result = await signer.async_pre_call_hook( + user_api_key_dict=user_dict, + cache=MagicMock(), + data=original_data, + call_type=call_type, # type: ignore[arg-type] + ) + assert "extra_headers" not in (result or {}), f"extra_headers should not be set for {call_type}" + + +@pytest.mark.asyncio +async def test_signed_token_is_verifiable(): + """The JWT injected by the hook can be verified against the JWKS public key.""" + signer = _make_signer(issuer="https://litellm.example.com", audience="mcp", ttl_seconds=300) + user_dict = _make_user_api_key_dict(user_id="alice", team_id="backend") + data = {"mcp_tool_name": "search"} + + result = await signer.async_pre_call_hook( + user_api_key_dict=user_dict, + cache=MagicMock(), + data=data, + call_type="call_mcp_tool", + ) + + assert isinstance(result, dict) + token = result["extra_headers"]["Authorization"].removeprefix("Bearer ") + + decoded = _decode_unverified(token) + assert decoded["sub"] == "alice" + assert decoded["act"]["sub"] == "backend" + assert 
"mcp:tools/search:call" in decoded["scope"] + assert decoded["iss"] == "https://litellm.example.com" + assert decoded["aud"] == "mcp" + + +# --------------------------------------------------------------------------- +# Singleton tests +# --------------------------------------------------------------------------- + + +def test_get_mcp_jwt_signer_returns_none_before_init(): + """get_mcp_jwt_signer() returns None before any MCPJWTSigner is created.""" + import litellm.proxy.guardrails.guardrail_hooks.mcp_jwt_signer.mcp_jwt_signer as mod + + mod._mcp_jwt_signer_instance = None + + from litellm.proxy.guardrails.guardrail_hooks.mcp_jwt_signer.mcp_jwt_signer import ( + get_mcp_jwt_signer, + ) + + assert get_mcp_jwt_signer() is None + + +def test_get_mcp_jwt_signer_returns_instance_after_init(): + """get_mcp_jwt_signer() returns the initialized signer instance.""" + from litellm.proxy.guardrails.guardrail_hooks.mcp_jwt_signer.mcp_jwt_signer import ( + get_mcp_jwt_signer, + ) + + signer = _make_signer() + assert get_mcp_jwt_signer() is signer + + +# --------------------------------------------------------------------------- +# FR-10: Configurable scopes +# --------------------------------------------------------------------------- + + +def test_allowed_scopes_replaces_auto_generation(): + """When allowed_scopes is set it is used verbatim instead of auto-generating.""" + signer = _make_signer(allowed_scopes=["mcp:admin", "mcp:tools/call"]) + user_dict = _make_user_api_key_dict() + data = {"mcp_tool_name": "some_tool"} + + claims = signer._build_claims(user_dict, data) + + assert claims["scope"] == "mcp:admin mcp:tools/call" + + +def test_tool_call_scope_no_list_permission(): + """Tool-call JWTs must NOT carry mcp:tools/list (least-privilege).""" + signer = _make_signer() + user_dict = _make_user_api_key_dict() + data = {"mcp_tool_name": "my_tool"} + + claims = signer._build_claims(user_dict, data) + + scopes = set(claims["scope"].split()) + assert "mcp:tools/list" not in 
scopes + assert "mcp:tools/call" in scopes + assert "mcp:tools/my_tool:call" in scopes + + +# --------------------------------------------------------------------------- +# FR-12: End-user identity mapping +# --------------------------------------------------------------------------- + + +def test_end_user_claim_sources_token_sub(): + """end_user_claim_sources resolves sub from incoming JWT claims.""" + signer = _make_signer(end_user_claim_sources=["token:sub", "litellm:user_id"]) + user_dict = _make_user_api_key_dict(user_id="litellm-user") + jwt_claims = {"sub": "idp-user-123", "email": "idp@example.com"} + + claims = signer._build_claims(user_dict, {}, jwt_claims=jwt_claims) + + assert claims["sub"] == "idp-user-123" + + +def test_end_user_claim_sources_falls_back_to_litellm_user_id(): + """Falls back to litellm:user_id when token:sub is absent.""" + signer = _make_signer(end_user_claim_sources=["token:sub", "litellm:user_id"]) + user_dict = _make_user_api_key_dict(user_id="litellm-user") + jwt_claims: Dict[str, Any] = {} # no sub + + claims = signer._build_claims(user_dict, {}, jwt_claims=jwt_claims) + + assert claims["sub"] == "litellm-user" + + +def test_end_user_claim_sources_email_source(): + """token:email resolves correctly.""" + signer = _make_signer(end_user_claim_sources=["token:email"]) + user_dict = _make_user_api_key_dict(user_id="") + user_dict.user_id = None + jwt_claims = {"email": "alice@corp.com"} + + claims = signer._build_claims(user_dict, {}, jwt_claims=jwt_claims) + + assert claims["sub"] == "alice@corp.com" + + +def test_end_user_claim_sources_litellm_email(): + """litellm:email resolves from UserAPIKeyAuth.user_email.""" + signer = _make_signer(end_user_claim_sources=["litellm:email"]) + user_dict = _make_user_api_key_dict(user_email="proxy-user@example.com") + user_dict.user_id = None + + claims = signer._build_claims(user_dict, {}) + + assert claims["sub"] == "proxy-user@example.com" + + +# 
--------------------------------------------------------------------------- +# FR-13: Claim operations +# --------------------------------------------------------------------------- + + +def test_add_claims_inserts_when_absent(): + """add_claims inserts key when it is not already in the JWT.""" + signer = _make_signer(add_claims={"deployment_id": "prod-001"}) + user_dict = _make_user_api_key_dict() + + claims = signer._build_claims(user_dict, {}) + + assert claims["deployment_id"] == "prod-001" + + +def test_add_claims_does_not_overwrite_existing(): + """add_claims does NOT overwrite an existing claim (use set_claims for that).""" + signer = _make_signer(add_claims={"iss": "should-not-win"}) + user_dict = _make_user_api_key_dict() + + claims = signer._build_claims(user_dict, {}) + + # iss should be the configured issuer, not overwritten + assert claims["iss"] != "should-not-win" + + +def test_set_claims_always_overrides(): + """set_claims always overrides computed claims.""" + signer = _make_signer(set_claims={"iss": "override-issuer", "custom": "x"}) + user_dict = _make_user_api_key_dict() + + claims = signer._build_claims(user_dict, {}) + + assert claims["iss"] == "override-issuer" + assert claims["custom"] == "x" + + +def test_remove_claims_deletes_keys(): + """remove_claims deletes specified keys from the final JWT.""" + signer = _make_signer(remove_claims=["nbf", "email"]) + user_dict = _make_user_api_key_dict() + + claims = signer._build_claims(user_dict, {}) + + assert "nbf" not in claims + assert "email" not in claims + + +def test_claim_operations_order_add_then_set_then_remove(): + """add → set → remove is applied in order: set wins over add, remove beats both.""" + signer = _make_signer( + add_claims={"x": "from-add"}, + set_claims={"x": "from-set"}, + remove_claims=["x"], + ) + user_dict = _make_user_api_key_dict() + + claims = signer._build_claims(user_dict, {}) + + assert "x" not in claims # remove wins + + +# 
--------------------------------------------------------------------------- +# FR-14: Two-token model +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_channel_token_injected_when_configured(): + """When channel_token_audience is set, x-mcp-channel-token header is injected.""" + signer = _make_signer( + channel_token_audience="bedrock-gateway", + channel_token_ttl=60, + ) + user_dict = _make_user_api_key_dict() + data = {"mcp_tool_name": "list_tables"} + + result = await signer.async_pre_call_hook( + user_api_key_dict=user_dict, + cache=MagicMock(), + data=data, + call_type="call_mcp_tool", + ) + + assert isinstance(result, dict) + assert "x-mcp-channel-token" in result["extra_headers"] + channel_token = result["extra_headers"]["x-mcp-channel-token"].removeprefix("Bearer ") + channel_payload = _decode_unverified(channel_token) + assert channel_payload["aud"] == "bedrock-gateway" + + +@pytest.mark.asyncio +async def test_channel_token_absent_when_not_configured(): + """x-mcp-channel-token is not injected when channel_token_audience is unset.""" + signer = _make_signer() + user_dict = _make_user_api_key_dict() + data = {"mcp_tool_name": "tool"} + + result = await signer.async_pre_call_hook( + user_api_key_dict=user_dict, + cache=MagicMock(), + data=data, + call_type="call_mcp_tool", + ) + + assert isinstance(result, dict) + assert "x-mcp-channel-token" not in result["extra_headers"] + + +# --------------------------------------------------------------------------- +# FR-15: Incoming claim validation +# --------------------------------------------------------------------------- + + +def test_required_claims_pass_when_present(): + """_validate_required_claims() passes when all required claims are present.""" + signer = _make_signer(required_claims=["sub", "email"]) + # Should not raise + signer._validate_required_claims({"sub": "user", "email": "u@example.com"}) + + +def 
test_required_claims_raise_403_when_missing(): + """_validate_required_claims() raises HTTP 403 when a required claim is missing.""" + from fastapi import HTTPException + + signer = _make_signer(required_claims=["sub", "email"]) + with pytest.raises(HTTPException) as exc_info: + signer._validate_required_claims({"sub": "user"}) # email missing + + assert exc_info.value.status_code == 403 + assert "email" in str(exc_info.value.detail) + + +def test_required_claims_raise_when_no_jwt_claims(): + """_validate_required_claims() raises when jwt_claims is None and claims are required.""" + from fastapi import HTTPException + + signer = _make_signer(required_claims=["sub"]) + with pytest.raises(HTTPException): + signer._validate_required_claims(None) + + +def test_optional_claims_passed_through(): + """optional_claims are forwarded from incoming jwt_claims into the outbound JWT.""" + signer = _make_signer(optional_claims=["groups", "roles"]) + user_dict = _make_user_api_key_dict() + jwt_claims = {"sub": "u", "groups": ["admin"], "roles": ["editor"]} + + claims = signer._build_claims(user_dict, {}, jwt_claims=jwt_claims) + + assert claims["groups"] == ["admin"] + assert claims["roles"] == ["editor"] + + +def test_optional_claims_not_injected_if_absent(): + """optional_claims are silently skipped when absent in incoming jwt_claims.""" + signer = _make_signer(optional_claims=["groups"]) + user_dict = _make_user_api_key_dict() + jwt_claims: Dict[str, Any] = {"sub": "u"} # no groups + + claims = signer._build_claims(user_dict, {}, jwt_claims=jwt_claims) + + assert "groups" not in claims + + +# --------------------------------------------------------------------------- +# FR-9: Debug headers +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_debug_header_injected_when_enabled(): + """x-litellm-mcp-debug header is injected when debug_headers=True.""" + signer = _make_signer(debug_headers=True) + user_dict = 
_make_user_api_key_dict() + data = {"mcp_tool_name": "my_tool"} + + result = await signer.async_pre_call_hook( + user_api_key_dict=user_dict, + cache=MagicMock(), + data=data, + call_type="call_mcp_tool", + ) + + assert isinstance(result, dict) + assert "x-litellm-mcp-debug" in result["extra_headers"] + debug_val = result["extra_headers"]["x-litellm-mcp-debug"] + assert "v=1" in debug_val + assert "kid=" in debug_val + assert "sub=" in debug_val + + +@pytest.mark.asyncio +async def test_debug_header_absent_when_disabled(): + """x-litellm-mcp-debug is NOT injected when debug_headers=False (default).""" + signer = _make_signer() + user_dict = _make_user_api_key_dict() + data = {"mcp_tool_name": "tool"} + + result = await signer.async_pre_call_hook( + user_api_key_dict=user_dict, + cache=MagicMock(), + data=data, + call_type="call_mcp_tool", + ) + + assert isinstance(result, dict) + assert "x-litellm-mcp-debug" not in result["extra_headers"] + + +# --------------------------------------------------------------------------- +# P1 fix: extra_headers merging (multi-guardrail chains) +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_extra_headers_are_merged_not_replaced(): + """ + Existing extra_headers from a prior guardrail are preserved — only + Authorization is added/overwritten, other keys survive. 
+ """ + signer = _make_signer() + user_dict = _make_user_api_key_dict() + # Simulate a prior guardrail having injected a tracing header + data = { + "mcp_tool_name": "list", + "extra_headers": {"x-trace-id": "abc123", "x-correlation-id": "xyz"}, + } + + result = await signer.async_pre_call_hook( + user_api_key_dict=user_dict, + cache=MagicMock(), + data=data, + call_type="call_mcp_tool", + ) + + assert isinstance(result, dict) + headers = result["extra_headers"] + # Prior headers preserved + assert headers.get("x-trace-id") == "abc123" + assert headers.get("x-correlation-id") == "xyz" + # Authorization injected + assert "Authorization" in headers + + +# --------------------------------------------------------------------------- +# FR-5: Verify + re-sign — jwt_claims fallback from UserAPIKeyAuth +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_sub_resolved_from_user_api_key_dict_jwt_claims(): + """ + When no raw token is present but UserAPIKeyAuth.jwt_claims has a sub, + the guardrail resolves sub from jwt_claims (LiteLLM-decoded JWT path). 
+ """ + signer = _make_signer(end_user_claim_sources=["token:sub", "litellm:user_id"]) + user_dict = _make_user_api_key_dict(user_id="litellm-fallback") + # jwt_claims populated by LiteLLM's JWT auth machinery + user_dict.jwt_claims = {"sub": "idp-alice", "email": "alice@idp.com"} + data = {"mcp_tool_name": "query"} + + result = await signer.async_pre_call_hook( + user_api_key_dict=user_dict, + cache=MagicMock(), + data=data, + call_type="call_mcp_tool", + ) + + assert isinstance(result, dict) + token = result["extra_headers"]["Authorization"].removeprefix("Bearer ") + payload = _decode_unverified(token) + assert payload["sub"] == "idp-alice" + + +# --------------------------------------------------------------------------- +# initialize_guardrail factory — regression test for config.yaml wire-up +# --------------------------------------------------------------------------- + + +def test_initialize_guardrail_passes_all_params(): + """ + initialize_guardrail must wire every documented config.yaml param through + to MCPJWTSigner. Previously only issuer/audience/ttl_seconds were passed; + all FR-5/9/10/12/13/14/15 params were silently dropped. 
+ """ + import litellm.proxy.guardrails.guardrail_hooks.mcp_jwt_signer.mcp_jwt_signer as mod + + mod._mcp_jwt_signer_instance = None + + from litellm.proxy.guardrails.guardrail_hooks.mcp_jwt_signer import ( + initialize_guardrail, + ) + + litellm_params = MagicMock() + litellm_params.mode = "pre_mcp_call" + litellm_params.default_on = True + litellm_params.optional_params = None + # Set every non-default param directly on litellm_params + litellm_params.issuer = "https://litellm.example.com" + litellm_params.audience = "mcp-test" + litellm_params.ttl_seconds = 120 + litellm_params.access_token_discovery_uri = "https://idp.example.com/.well-known/openid-configuration" + litellm_params.token_introspection_endpoint = "https://idp.example.com/introspect" + litellm_params.verify_issuer = "https://idp.example.com" + litellm_params.verify_audience = "api://test" + litellm_params.end_user_claim_sources = ["token:email", "litellm:user_id"] + litellm_params.add_claims = {"deployment_id": "prod"} + litellm_params.set_claims = {"env": "production"} + litellm_params.remove_claims = ["nbf"] + litellm_params.channel_token_audience = "bedrock-gateway" + litellm_params.channel_token_ttl = 60 + litellm_params.required_claims = ["sub", "email"] + litellm_params.optional_claims = ["groups"] + litellm_params.debug_headers = True + litellm_params.allowed_scopes = ["mcp:tools/call"] + + guardrail = {"guardrail_name": "mcp-jwt-signer"} + + with patch("litellm.logging_callback_manager.add_litellm_callback"): + signer = initialize_guardrail(litellm_params, guardrail) + + assert signer.issuer == "https://litellm.example.com" + assert signer.audience == "mcp-test" + assert signer.ttl_seconds == 120 + assert signer.access_token_discovery_uri == "https://idp.example.com/.well-known/openid-configuration" + assert signer.token_introspection_endpoint == "https://idp.example.com/introspect" + assert signer.verify_issuer == "https://idp.example.com" + assert signer.verify_audience == "api://test" + 
assert signer.end_user_claim_sources == ["token:email", "litellm:user_id"] + assert signer.add_claims == {"deployment_id": "prod"} + assert signer.set_claims == {"env": "production"} + assert signer.remove_claims == ["nbf"] + assert signer.channel_token_audience == "bedrock-gateway" + assert signer.channel_token_ttl == 60 + assert signer.required_claims == ["sub", "email"] + assert signer.optional_claims == ["groups"] + assert signer.debug_headers is True + assert signer.allowed_scopes == ["mcp:tools/call"] + + +# --------------------------------------------------------------------------- +# FR-5: _fetch_jwks, _get_oidc_discovery, _verify_incoming_jwt, +# _introspect_opaque_token +# --------------------------------------------------------------------------- + +import litellm.proxy.guardrails.guardrail_hooks.mcp_jwt_signer.mcp_jwt_signer as _signer_mod + + +def _make_httpx_response(json_body: dict, status_code: int = 200): + """Build a minimal fake httpx Response object.""" + mock_resp = MagicMock() + mock_resp.status_code = status_code + mock_resp.json.return_value = json_body + mock_resp.raise_for_status = MagicMock() + if status_code >= 400: + from httpx import HTTPStatusError, Request, Response + + mock_resp.raise_for_status.side_effect = HTTPStatusError( + "error", request=MagicMock(), response=MagicMock() + ) + return mock_resp + + +# --- _fetch_jwks --- + + +@pytest.mark.asyncio +async def test_fetch_jwks_returns_keys_and_caches(): + """_fetch_jwks returns keys from the remote JWKS URI and caches the result.""" + _signer_mod._jwks_cache.clear() + + fake_keys = [{"kty": "RSA", "kid": "k1", "n": "abc", "e": "AQAB"}] + fake_resp = _make_httpx_response({"keys": fake_keys}) + + mock_client = MagicMock() + mock_client.get = AsyncMock(return_value=fake_resp) + + with patch( + "litellm.llms.custom_httpx.http_handler.get_async_httpx_client", + return_value=mock_client, + ): + keys = await _signer_mod._fetch_jwks("https://idp.example.com/jwks") + + assert keys == 
fake_keys + assert "https://idp.example.com/jwks" in _signer_mod._jwks_cache + _signer_mod._jwks_cache.clear() + + +@pytest.mark.asyncio +async def test_fetch_jwks_uses_cache_on_second_call(): + """_fetch_jwks returns the cached value without a second HTTP call.""" + _signer_mod._jwks_cache.clear() + fake_keys = [{"kty": "RSA", "kid": "k1"}] + _signer_mod._jwks_cache["https://idp.example.com/jwks"] = ( + fake_keys, + time.time(), + ) + + mock_client = MagicMock() + mock_client.get = AsyncMock() + + with patch( + "litellm.llms.custom_httpx.http_handler.get_async_httpx_client", + return_value=mock_client, + ): + keys = await _signer_mod._fetch_jwks("https://idp.example.com/jwks") + + mock_client.get.assert_not_called() + assert keys == fake_keys + _signer_mod._jwks_cache.clear() + + +# --- _get_oidc_discovery --- + + +@pytest.mark.asyncio +async def test_get_oidc_discovery_caches_when_jwks_uri_present(): + """_get_oidc_discovery caches the doc when jwks_uri is in the response.""" + signer = _make_signer( + access_token_discovery_uri="https://idp.example.com/.well-known/openid-configuration" + ) + signer._oidc_discovery_doc = None # ensure fresh + + discovery_doc = { + "issuer": "https://idp.example.com", + "jwks_uri": "https://idp.example.com/jwks", + } + + with patch( + "litellm.proxy.guardrails.guardrail_hooks.mcp_jwt_signer.mcp_jwt_signer._fetch_oidc_discovery", + new_callable=AsyncMock, + return_value=discovery_doc, + ): + result = await signer._get_oidc_discovery() + + assert result["jwks_uri"] == "https://idp.example.com/jwks" + assert signer._oidc_discovery_doc == discovery_doc + + +@pytest.mark.asyncio +async def test_get_oidc_discovery_does_not_cache_when_jwks_uri_absent(): + """_get_oidc_discovery does NOT cache a doc that is missing jwks_uri.""" + signer = _make_signer( + access_token_discovery_uri="https://idp.example.com/.well-known/openid-configuration" + ) + signer._oidc_discovery_doc = None + + bad_doc = {"issuer": "https://idp.example.com"} # no 
jwks_uri + + with patch( + "litellm.proxy.guardrails.guardrail_hooks.mcp_jwt_signer.mcp_jwt_signer._fetch_oidc_discovery", + new_callable=AsyncMock, + return_value=bad_doc, + ) as mock_fetch: + result1 = await signer._get_oidc_discovery() + result2 = await signer._get_oidc_discovery() + + # Returns the bad doc each time without caching it + assert "jwks_uri" not in result1 + assert signer._oidc_discovery_doc is None # never cached + assert mock_fetch.call_count == 2 # retried on second call + + +# --- _verify_incoming_jwt --- + + +@pytest.mark.asyncio +async def test_verify_incoming_jwt_returns_payload_on_valid_token(): + """_verify_incoming_jwt decodes and returns claims from a valid JWT.""" + # Build a signer to get a real RSA key pair; use its key to mint the "incoming" JWT + signer = _make_signer( + access_token_discovery_uri="https://idp.example.com/.well-known/openid-configuration", + verify_audience="api://test", + verify_issuer="https://idp.example.com", + ) + # Mint a JWT with signer's own key — we'll pretend it came from the IdP + now = int(time.time()) + incoming_claims = { + "sub": "idp-user-42", + "iss": "https://idp.example.com", + "aud": "api://test", + "iat": now, + "exp": now + 300, + } + incoming_token = jwt.encode(incoming_claims, signer._private_key, algorithm="RS256", headers={"kid": signer._kid}) + + # Build a JWKS from the same public key so verification passes + jwks = signer.get_jwks() + + with patch.object( + signer, + "_get_oidc_discovery", + new_callable=AsyncMock, + return_value={"jwks_uri": "https://idp.example.com/jwks"}, + ): + with patch( + "litellm.proxy.guardrails.guardrail_hooks.mcp_jwt_signer.mcp_jwt_signer._fetch_jwks", + new_callable=AsyncMock, + return_value=jwks["keys"], + ): + payload = await signer._verify_incoming_jwt(incoming_token) + + assert payload["sub"] == "idp-user-42" + + +@pytest.mark.asyncio +async def test_verify_incoming_jwt_raises_on_expired_token(): + """_verify_incoming_jwt raises PyJWTError on an expired 
token.""" + signer = _make_signer( + access_token_discovery_uri="https://idp.example.com/.well-known/openid-configuration", + ) + expired_claims = { + "sub": "idp-user", + "iss": "https://idp.example.com", + "aud": "api://test", + "iat": int(time.time()) - 600, + "exp": int(time.time()) - 300, # expired + } + expired_token = jwt.encode(expired_claims, signer._private_key, algorithm="RS256") + jwks = signer.get_jwks() + + with patch.object( + signer, + "_get_oidc_discovery", + new_callable=AsyncMock, + return_value={"jwks_uri": "https://idp.example.com/jwks"}, + ): + with patch( + "litellm.proxy.guardrails.guardrail_hooks.mcp_jwt_signer.mcp_jwt_signer._fetch_jwks", + new_callable=AsyncMock, + return_value=jwks["keys"], + ): + with pytest.raises(jwt.PyJWTError): + await signer._verify_incoming_jwt(expired_token) + + +# --- _introspect_opaque_token --- + + +@pytest.mark.asyncio +async def test_introspect_opaque_token_returns_claims_when_active(): + """_introspect_opaque_token returns the introspection payload for active tokens.""" + signer = _make_signer( + token_introspection_endpoint="https://idp.example.com/introspect" + ) + + introspection_response = { + "active": True, + "sub": "service-account", + "scope": "read write", + } + fake_resp = _make_httpx_response(introspection_response) + mock_client = MagicMock() + mock_client.post = AsyncMock(return_value=fake_resp) + + with patch( + "litellm.llms.custom_httpx.http_handler.get_async_httpx_client", + return_value=mock_client, + ): + result = await signer._introspect_opaque_token("opaque-token-abc") + + assert result["sub"] == "service-account" + assert result["active"] is True + + +@pytest.mark.asyncio +async def test_introspect_opaque_token_raises_on_inactive_token(): + """_introspect_opaque_token raises ExpiredSignatureError when active=false.""" + signer = _make_signer( + token_introspection_endpoint="https://idp.example.com/introspect" + ) + + fake_resp = _make_httpx_response({"active": False}) + mock_client = 
MagicMock() + mock_client.post = AsyncMock(return_value=fake_resp) + + with patch( + "litellm.llms.custom_httpx.http_handler.get_async_httpx_client", + return_value=mock_client, + ): + with pytest.raises(jwt.ExpiredSignatureError): + await signer._introspect_opaque_token("opaque-token-xyz") + + +@pytest.mark.asyncio +async def test_introspect_opaque_token_raises_without_endpoint_configured(): + """_introspect_opaque_token raises ValueError when no endpoint is set.""" + signer = _make_signer() # no token_introspection_endpoint + + with pytest.raises(ValueError, match="token_introspection_endpoint"): + await signer._introspect_opaque_token("some-token") + + +# --- FR-5 end-to-end hook path --- + + +@pytest.mark.asyncio +async def test_hook_raises_401_when_jwt_verification_fails(): + """async_pre_call_hook raises HTTP 401 when incoming JWT verification fails.""" + from fastapi import HTTPException + + signer = _make_signer( + access_token_discovery_uri="https://idp.example.com/.well-known/openid-configuration" + ) + + with patch.object( + signer, + "_verify_incoming_jwt", + new_callable=AsyncMock, + side_effect=jwt.InvalidSignatureError("bad signature"), + ): + with patch.object( + signer, + "_get_oidc_discovery", + new_callable=AsyncMock, + return_value={"jwks_uri": "https://idp.example.com/jwks"}, + ): + with pytest.raises(HTTPException) as exc_info: + await signer.async_pre_call_hook( + user_api_key_dict=_make_user_api_key_dict(), + cache=MagicMock(), + data={"mcp_tool_name": "tool", "incoming_bearer_token": "hdr.pld.sig"}, + call_type="call_mcp_tool", + ) + + assert exc_info.value.status_code == 401 diff --git a/tests/test_litellm/test_setup_wizard.py b/tests/test_litellm/test_setup_wizard.py new file mode 100644 index 00000000000..e10bd893e31 --- /dev/null +++ b/tests/test_litellm/test_setup_wizard.py @@ -0,0 +1,188 @@ +"""Unit tests for litellm.setup_wizard — pure functions only, no network calls.""" + +from litellm.setup_wizard import SetupWizard, _yaml_escape 
+ +# --------------------------------------------------------------------------- +# _yaml_escape +# --------------------------------------------------------------------------- + + +def test_yaml_escape_plain(): + assert _yaml_escape("sk-abc123") == "sk-abc123" + + +def test_yaml_escape_double_quote(): + assert _yaml_escape('sk-ab"cd') == 'sk-ab\\"cd' + + +def test_yaml_escape_backslash(): + assert _yaml_escape("sk-ab\\cd") == "sk-ab\\\\cd" + + +def test_yaml_escape_combined(): + assert _yaml_escape('ab\\"cd') == 'ab\\\\\\"cd' + + +def test_yaml_escape_newline(): + assert _yaml_escape("sk-abc\ndef") == "sk-abc\\ndef" + + +def test_yaml_escape_carriage_return(): + assert _yaml_escape("sk-abc\rdef") == "sk-abc\\rdef" + + +def test_yaml_escape_tab(): + assert _yaml_escape("sk-abc\tdef") == "sk-abc\\tdef" + + +# --------------------------------------------------------------------------- +# SetupWizard._build_config +# --------------------------------------------------------------------------- + +_OPENAI = { + "id": "openai", + "name": "OpenAI", + "env_key": "OPENAI_API_KEY", + "models": ["gpt-4o", "gpt-4o-mini"], + "test_model": "gpt-4o-mini", +} + +_ANTHROPIC = { + "id": "anthropic", + "name": "Anthropic", + "env_key": "ANTHROPIC_API_KEY", + "models": ["claude-opus-4-6"], + "test_model": "claude-haiku-4-5-20251001", +} + +_AZURE = { + "id": "azure", + "name": "Azure OpenAI", + "env_key": "AZURE_API_KEY", + "models": [], + "test_model": None, + "needs_api_base": True, + "api_base_hint": "https://.openai.azure.com/", + "api_version": "2024-07-01-preview", +} + +_OLLAMA = { + "id": "ollama", + "name": "Ollama", + "env_key": None, + "models": ["ollama/llama3.2"], + "test_model": None, + "api_base": "http://localhost:11434", +} + + +def test_build_config_basic_openai(): + config = SetupWizard._build_config( + [_OPENAI], + {"OPENAI_API_KEY": "sk-test"}, + "sk-master", + ) + assert "model_list:" in config + assert "model_name: gpt-4o" in config + assert "model: gpt-4o" in 
config + assert "api_key: os.environ/OPENAI_API_KEY" in config + assert 'master_key: "sk-master"' in config + + +def test_build_config_skipped_provider_omitted(): + """Provider with no key in env_vars should not appear in model_list.""" + config = SetupWizard._build_config( + [_OPENAI, _ANTHROPIC], + {"ANTHROPIC_API_KEY": "sk-ant-test"}, # OpenAI key missing + "sk-master", + ) + assert "gpt-4o" not in config + assert "claude-opus-4-6" in config + + +def test_build_config_env_vars_written_escaped(): + """API keys with special chars should be YAML-escaped.""" + config = SetupWizard._build_config( + [_OPENAI], + {"OPENAI_API_KEY": 'sk-ab"cd'}, + "sk-master", + ) + assert 'OPENAI_API_KEY: "sk-ab\\"cd"' in config + + +def test_build_config_master_key_quoted(): + """master_key must be quoted in YAML to handle special characters.""" + config = SetupWizard._build_config( + [_OPENAI], + {"OPENAI_API_KEY": "sk-test"}, + 'sk-master"special', + ) + assert 'master_key: "sk-master\\"special"' in config + + +def test_build_config_does_not_mutate_env_vars(): + """_build_config must not modify the caller's env_vars dict.""" + env_vars = { + "AZURE_API_KEY": "az-key", + "_LITELLM_AZURE_API_BASE_AZURE": "https://my.azure.com", + "_LITELLM_AZURE_DEPLOYMENT_AZURE": "my-deployment", + } + original_keys = set(env_vars.keys()) + SetupWizard._build_config([_AZURE], env_vars, "sk-master") + assert set(env_vars.keys()) == original_keys + + +def test_build_config_azure_uses_deployment_name(): + env_vars = { + "AZURE_API_KEY": "az-key", + "_LITELLM_AZURE_API_BASE_AZURE": "https://my.azure.com", + "_LITELLM_AZURE_DEPLOYMENT_AZURE": "my-gpt4o", + } + config = SetupWizard._build_config([_AZURE], env_vars, "sk-master") + assert "model: azure/my-gpt4o" in config + assert "model_name: azure-my-gpt4o" in config + # api_base must be quoted to survive YAML special chars + assert 'api_base: "https://my.azure.com"' in config + + +def test_build_config_azure_no_deployment_skipped(): + """Azure without a 
deployment name should emit nothing (not fallback to gpt-4o).""" + env_vars = {"AZURE_API_KEY": "az-key"} # no deployment sentinel + config = SetupWizard._build_config([_AZURE], env_vars, "sk-master") + # No azure model entry should be emitted when deployment name is absent + assert "model: azure/" not in config + + +def test_build_config_no_display_name_collision_openai_and_azure(): + """OpenAI gpt-4o and azure gpt-4o should get distinct model_name values.""" + env_vars = { + "OPENAI_API_KEY": "sk-openai", + "AZURE_API_KEY": "az-key", + "_LITELLM_AZURE_DEPLOYMENT_AZURE": "gpt-4o", + } + config = SetupWizard._build_config([_OPENAI, _AZURE], env_vars, "sk-master") + assert "model_name: gpt-4o" in config # OpenAI + assert "model_name: azure-gpt-4o" in config # Azure — qualified + + +def test_build_config_ollama_no_api_key_line(): + """Ollama has no env_key — config should not contain an api_key line for it.""" + config = SetupWizard._build_config([_OLLAMA], {}, "sk-master") + assert "ollama/llama3.2" in config + assert "api_key:" not in config + + +def test_build_config_master_key_in_general_settings(): + """master_key is written to general_settings.""" + config = SetupWizard._build_config([_OPENAI], {"OPENAI_API_KEY": "k"}, "sk-m") + assert 'master_key: "sk-m"' in config + + +def test_build_config_internal_sentinel_keys_excluded(): + """_LITELLM_ prefixed sentinel keys must not appear in environment_variables.""" + env_vars = { + "OPENAI_API_KEY": "sk-real", + "_LITELLM_AZURE_API_BASE_AZURE": "https://x.azure.com", + } + config = SetupWizard._build_config([_OPENAI], env_vars, "sk-master") + assert "_LITELLM_" not in config diff --git a/ui/litellm-dashboard/src/app/(dashboard)/components/Sidebar2.tsx b/ui/litellm-dashboard/src/app/(dashboard)/components/Sidebar2.tsx index ceb14864ad9..18ab475f228 100644 --- a/ui/litellm-dashboard/src/app/(dashboard)/components/Sidebar2.tsx +++ b/ui/litellm-dashboard/src/app/(dashboard)/components/Sidebar2.tsx @@ -20,7 +20,6 @@ import 
{ ToolOutlined, TagsOutlined, AuditOutlined, - MessageOutlined, } from "@ant-design/icons"; // import { // all_admin_roles, @@ -466,41 +465,6 @@ const Sidebar2: React.FC = ({ accessToken, userRole, defaultSelect {isAdminRole(userRole) && !collapsed && } - {/* Pinned "Open Chat" button at bottom */} - ); diff --git a/ui/litellm-dashboard/src/app/(dashboard)/playground/page.tsx b/ui/litellm-dashboard/src/app/(dashboard)/playground/page.tsx index 693dc6b20ee..2b4b4ace491 100644 --- a/ui/litellm-dashboard/src/app/(dashboard)/playground/page.tsx +++ b/ui/litellm-dashboard/src/app/(dashboard)/playground/page.tsx @@ -8,8 +8,6 @@ import ComplianceUI from "@/components/playground/complianceUI/ComplianceUI"; import { TabGroup, TabList, Tab, TabPanels, TabPanel } from "@tremor/react"; import useAuthorized from "@/app/(dashboard)/hooks/useAuthorized"; import { fetchProxySettings } from "@/utils/proxyUtils"; -import { useUIConfig } from "@/app/(dashboard)/hooks/uiConfig/useUIConfig"; -import { MessageOutlined, CloseOutlined } from "@ant-design/icons"; interface ProxySettings { PROXY_BASE_URL?: string; @@ -19,12 +17,6 @@ interface ProxySettings { export default function PlaygroundPage() { const { accessToken, userRole, userId, disabledPersonalKeyCreation, token } = useAuthorized(); const [proxySettings, setProxySettings] = useState(undefined); - const [chatBannerDismissed, setChatBannerDismissed] = useState(false); - const { data: uiConfig } = useUIConfig(); - const uiRoot = uiConfig?.server_root_path && uiConfig.server_root_path !== "/" - ? uiConfig.server_root_path.replace(/\/+$/, "") - : ""; - const chatHref = `${uiRoot}/ui/chat`; useEffect(() => { const initializeProxySettings = async () => { @@ -44,64 +36,6 @@ export default function PlaygroundPage() { return (
- {!chatBannerDismissed && ( -
- - New - - - Chat UI - {" "}— a ChatGPT-like interface for your users to chat with AI models and MCP tools. Share it with your team. - - - Open Chat UI → - - -
- )} Chat