diff --git a/docs/my-website/docs/proxy/config_settings.md b/docs/my-website/docs/proxy/config_settings.md index 847f8623e72..d563adcaa74 100644 --- a/docs/my-website/docs/proxy/config_settings.md +++ b/docs/my-website/docs/proxy/config_settings.md @@ -775,6 +775,10 @@ router_settings: | LITELLM_METER_NAME | Name for OTEL Meter | LITELLM_OTEL_INTEGRATION_ENABLE_EVENTS | Optionally enable semantic logs for OTEL | LITELLM_OTEL_INTEGRATION_ENABLE_METRICS | Optionally enable emantic metrics for OTEL +| LITELLM_ENABLE_PYROSCOPE | If true, enables Pyroscope CPU profiling. Profiles are sent to PYROSCOPE_SERVER_ADDRESS. Off by default. See [Pyroscope profiling](/proxy/pyroscope_profiling). +| PYROSCOPE_APP_NAME | Application name reported to Pyroscope. Required when LITELLM_ENABLE_PYROSCOPE is true. No default. +| PYROSCOPE_SERVER_ADDRESS | Pyroscope server URL to send profiles to. Required when LITELLM_ENABLE_PYROSCOPE is true. No default. +| PYROSCOPE_SAMPLE_RATE | Optional. Sample rate for Pyroscope profiling (integer). No default; when unset, the pyroscope-io library default is used. | LITELLM_MASTER_KEY | Master key for proxy authentication | LITELLM_MODE | Operating mode for LiteLLM (e.g., production, development) | LITELLM_NON_ROOT | Flag to run LiteLLM in non-root mode for enhanced security in Docker containers diff --git a/docs/my-website/docs/proxy/pyroscope_profiling.md b/docs/my-website/docs/proxy/pyroscope_profiling.md new file mode 100644 index 00000000000..fa3db3a8782 --- /dev/null +++ b/docs/my-website/docs/proxy/pyroscope_profiling.md @@ -0,0 +1,43 @@ +# Grafana Pyroscope CPU profiling + +LiteLLM proxy can send continuous CPU profiles to [Grafana Pyroscope](https://grafana.com/docs/pyroscope/latest/) when enabled via environment variables. This is optional and off by default. + +## Quick start + +1. **Install the optional dependency** (required only when enabling Pyroscope): + + ```bash + pip install pyroscope-io + ``` + + Or install the proxy extra: + + ```bash + pip install "litellm[proxy]" + ``` + +2. **Set environment variables** before starting the proxy: + + | Variable | Required | Description | + |----------|----------|-------------| + | `LITELLM_ENABLE_PYROSCOPE` | Yes (to enable) | Set to `true` to enable Pyroscope profiling. | + | `PYROSCOPE_APP_NAME` | Yes (when enabled) | Application name shown in the Pyroscope UI. | + | `PYROSCOPE_SERVER_ADDRESS` | Yes (when enabled) | Pyroscope server URL (e.g. `http://localhost:4040`). | + | `PYROSCOPE_SAMPLE_RATE` | No | Sample rate (integer). If unset, the pyroscope-io library default is used. | + +3. **Start the proxy**; profiling will begin automatically when the proxy starts. + + ```bash + export LITELLM_ENABLE_PYROSCOPE=true + export PYROSCOPE_APP_NAME=litellm-proxy + export PYROSCOPE_SERVER_ADDRESS=http://localhost:4040 + litellm --config config.yaml + ``` + +4. **View profiles** in the Pyroscope (or Grafana) UI and select your `PYROSCOPE_APP_NAME`. + +## Notes + +- **Optional dependency**: `pyroscope-io` is an optional dependency. If it is not installed and `LITELLM_ENABLE_PYROSCOPE=true`, the proxy will log a warning and continue without profiling. +- **Platform support**: The `pyroscope-io` package uses a native extension and is not available on all platforms (e.g. Windows is excluded by the package). +- **Other settings**: See [Configuration settings](/proxy/config_settings) for all proxy environment variables. diff --git a/docs/my-website/sidebars.js b/docs/my-website/sidebars.js index 9e2eb47f4c9..9b3581cce32 100644 --- a/docs/my-website/sidebars.js +++ b/docs/my-website/sidebars.js @@ -107,7 +107,8 @@ const sidebars = { items: [ "proxy/alerting", "proxy/pagerduty", - "proxy/prometheus" + "proxy/prometheus", + "proxy/pyroscope_profiling" ] }, { diff --git a/litellm-proxy-extras/dist/litellm_proxy_extras-0.4.37-py3-none-any.whl b/litellm-proxy-extras/dist/litellm_proxy_extras-0.4.37-py3-none-any.whl new file mode 100644 index 00000000000..695dc102c72 Binary files /dev/null and b/litellm-proxy-extras/dist/litellm_proxy_extras-0.4.37-py3-none-any.whl differ diff --git a/litellm-proxy-extras/dist/litellm_proxy_extras-0.4.37.tar.gz b/litellm-proxy-extras/dist/litellm_proxy_extras-0.4.37.tar.gz new file mode 100644 index 00000000000..d3ecef1752e Binary files /dev/null and b/litellm-proxy-extras/dist/litellm_proxy_extras-0.4.37.tar.gz differ diff --git a/litellm-proxy-extras/litellm_proxy_extras/migrations/20260211181323_baseline_diff/migration.sql b/litellm-proxy-extras/litellm_proxy_extras/migrations/20260211181323_baseline_diff/migration.sql deleted file mode 100644 index f3a0821d37f..00000000000 --- a/litellm-proxy-extras/litellm_proxy_extras/migrations/20260211181323_baseline_diff/migration.sql +++ /dev/null @@ -1,3 +0,0 @@ --- AlterTable -ALTER TABLE "LiteLLM_PolicyAttachmentTable" ADD COLUMN "tags" TEXT[] DEFAULT ARRAY[]::TEXT[]; - diff --git a/litellm-proxy-extras/litellm_proxy_extras/migrations/20260213170952_access_group_change_to_model_name/migration.sql b/litellm-proxy-extras/litellm_proxy_extras/migrations/20260213170952_access_group_change_to_model_name/migration.sql new file mode 100644 index 00000000000..c940d3aca8b --- /dev/null +++ b/litellm-proxy-extras/litellm_proxy_extras/migrations/20260213170952_access_group_change_to_model_name/migration.sql @@ -0,0 +1,3 @@ +-- AlterTable +ALTER TABLE "LiteLLM_AccessGroupTable" DROP COLUMN "access_model_ids", +ADD COLUMN "access_model_names" TEXT[] DEFAULT ARRAY[]::TEXT[]; diff --git a/litellm-proxy-extras/litellm_proxy_extras/schema.prisma b/litellm-proxy-extras/litellm_proxy_extras/schema.prisma index 02dddc74e3e..39359a34547 100644 --- a/litellm-proxy-extras/litellm_proxy_extras/schema.prisma +++ b/litellm-proxy-extras/litellm_proxy_extras/schema.prisma @@ -948,7 +948,7 @@ model LiteLLM_AccessGroupTable { description String? // Resource memberships - explicit arrays per type - access_model_ids String[] @default([]) + access_model_names String[] @default([]) access_mcp_server_ids String[] @default([]) access_agent_ids String[] @default([]) diff --git a/litellm-proxy-extras/pyproject.toml b/litellm-proxy-extras/pyproject.toml index eda49bfb9fa..909c5e7e6cd 100644 --- a/litellm-proxy-extras/pyproject.toml +++ b/litellm-proxy-extras/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "litellm-proxy-extras" -version = "0.4.36" +version = "0.4.37" description = "Additional files for the LiteLLM Proxy. Reduces the size of the main litellm package." authors = ["BerriAI"] readme = "README.md" @@ -22,7 +22,7 @@ requires = ["poetry-core"] build-backend = "poetry.core.masonry.api" [tool.commitizen] -version = "0.4.36" +version = "0.4.37" version_files = [ "pyproject.toml:version", "../requirements.txt:litellm-proxy-extras==", diff --git a/litellm/completion_extras/litellm_responses_transformation/transformation.py b/litellm/completion_extras/litellm_responses_transformation/transformation.py index e546a0dbb02..a11341b80aa 100644 --- a/litellm/completion_extras/litellm_responses_transformation/transformation.py +++ b/litellm/completion_extras/litellm_responses_transformation/transformation.py @@ -163,15 +163,23 @@ def convert_chat_completion_messages_to_responses_api( instructions = f"{instructions} {content}" else: instructions = content + elif isinstance(content, list): + # Extract text from content blocks (e.g. [{"type": "text", "text": "..."}]) + text_parts = [] + for block in content: + if isinstance(block, dict) and block.get("type") == "text": + text_parts.append(block.get("text", "")) + elif isinstance(block, str): + text_parts.append(block) + extracted = " ".join(text_parts) + if instructions: + instructions = f"{instructions} {extracted}" + else: + instructions = extracted else: - input_items.append( - { - "type": "message", - "role": role, - "content": self._convert_content_to_responses_format( - content, role # type: ignore - ), - } + verbose_logger.warning( + "Unexpected system message content type: %s. Skipping.", + type(content), ) elif role == "tool": # Convert tool message to function call output format diff --git a/litellm/llms/vertex_ai/gemini/transformation.py b/litellm/llms/vertex_ai/gemini/transformation.py index 6a5ac92816b..5d397297891 100644 --- a/litellm/llms/vertex_ai/gemini/transformation.py +++ b/litellm/llms/vertex_ai/gemini/transformation.py @@ -533,11 +533,12 @@ def _pop_and_merge_extra_body(data: RequestBody, optional_params: dict) -> None: """Pop extra_body from optional_params and shallow-merge into data, deep-merging dict values.""" extra_body: Optional[dict] = optional_params.pop("extra_body", None) if extra_body is not None: + data_dict: dict = data # type: ignore[assignment] for k, v in extra_body.items(): - if k in data and isinstance(data[k], dict) and isinstance(v, dict): - data[k].update(v) + if k in data_dict and isinstance(data_dict[k], dict) and isinstance(v, dict): + data_dict[k].update(v) else: - data[k] = v + data_dict[k] = v def _transform_request_body( diff --git a/litellm/proxy/_experimental/mcp_server/server.py b/litellm/proxy/_experimental/mcp_server/server.py index ec31652aa54..ba107a9dd10 100644 --- a/litellm/proxy/_experimental/mcp_server/server.py +++ b/litellm/proxy/_experimental/mcp_server/server.py @@ -2029,7 +2029,7 @@ async def handle_streamable_http_mcp( # Inject masked debug headers when client sends x-litellm-mcp-debug: true _debug_headers = MCPDebug.maybe_build_debug_headers( raw_headers=raw_headers, - scope=scope, + scope=dict(scope), mcp_servers=mcp_servers, mcp_auth_header=mcp_auth_header, mcp_server_auth_headers=mcp_server_auth_headers, diff --git a/litellm/proxy/auth/user_api_key_auth.py b/litellm/proxy/auth/user_api_key_auth.py index ba4e3b42c37..f643f7205bf 100644 --- a/litellm/proxy/auth/user_api_key_auth.py +++ b/litellm/proxy/auth/user_api_key_auth.py @@ -585,7 +585,20 @@ async def _user_api_key_auth_builder( # noqa: PLR0915 if is_proxy_admin: return UserAPIKeyAuth( + api_key=None, user_role=LitellmUserRoles.PROXY_ADMIN, + user_id=user_id, + team_id=team_id, + team_alias=( + team_object.team_alias + if team_object is not None + else None + ), + team_metadata=team_object.metadata + if team_object is not None + else None, + org_id=org_id, + end_user_id=end_user_id, parent_otel_span=parent_otel_span, ) diff --git a/litellm/proxy/example_config_yaml/pipeline_test_guardrails.py b/litellm/proxy/example_config_yaml/pipeline_test_guardrails.py new file mode 100644 index 00000000000..539a520fcef --- /dev/null +++ b/litellm/proxy/example_config_yaml/pipeline_test_guardrails.py @@ -0,0 +1,69 @@ +""" +Test guardrails for pipeline E2E testing. + +- StrictFilter: blocks any message containing "bad" (case-insensitive) +- PermissiveFilter: always passes (simulates an advanced guardrail that is more lenient) +""" + +from typing import Optional, Union + +from fastapi import HTTPException + +from litellm._logging import verbose_proxy_logger +from litellm.caching.caching import DualCache +from litellm.integrations.custom_guardrail import CustomGuardrail +from litellm.proxy._types import UserAPIKeyAuth +from litellm.types.utils import CallTypesLiteral + + +class StrictFilter(CustomGuardrail): + """Blocks any message containing the word 'bad'.""" + + async def async_pre_call_hook( + self, + user_api_key_dict: UserAPIKeyAuth, + cache: DualCache, + data: dict, + call_type: CallTypesLiteral, + ) -> Optional[Union[Exception, str, dict]]: + for msg in data.get("messages", []): + content = msg.get("content", "") + if isinstance(content, str) and "bad" in content.lower(): + verbose_proxy_logger.info("StrictFilter: BLOCKED - found 'bad'") + raise HTTPException( + status_code=400, + detail="StrictFilter: content contains forbidden word 'bad'", + ) + verbose_proxy_logger.info("StrictFilter: PASSED") + return data + + +class PermissiveFilter(CustomGuardrail): + """Always passes - simulates a lenient advanced guardrail.""" + + async def async_pre_call_hook( + self, + user_api_key_dict: UserAPIKeyAuth, + cache: DualCache, + data: dict, + call_type: CallTypesLiteral, + ) -> Optional[Union[Exception, str, dict]]: + verbose_proxy_logger.info("PermissiveFilter: PASSED (always passes)") + return data + + +class AlwaysBlockFilter(CustomGuardrail): + """Always blocks - for testing full escalation->block path.""" + + async def async_pre_call_hook( + self, + user_api_key_dict: UserAPIKeyAuth, + cache: DualCache, + data: dict, + call_type: CallTypesLiteral, + ) -> Optional[Union[Exception, str, dict]]: + verbose_proxy_logger.info("AlwaysBlockFilter: BLOCKED") + raise HTTPException( + status_code=400, + detail="AlwaysBlockFilter: all content blocked", + ) diff --git a/litellm/proxy/example_config_yaml/test_pipeline_config.yaml b/litellm/proxy/example_config_yaml/test_pipeline_config.yaml new file mode 100644 index 00000000000..d3a8c56b48a --- /dev/null +++ b/litellm/proxy/example_config_yaml/test_pipeline_config.yaml @@ -0,0 +1,64 @@ +model_list: + - model_name: fake-openai-endpoint + litellm_params: + model: openai/gpt-3.5-turbo + api_key: fake-key + api_base: https://exampleopenaiendpoint-production.up.railway.app/ + - model_name: fake-blocked-endpoint + litellm_params: + model: openai/gpt-3.5-turbo + api_key: fake-key + api_base: https://exampleopenaiendpoint-production.up.railway.app/ + +guardrails: + - guardrail_name: "strict-filter" + litellm_params: + guardrail: pipeline_test_guardrails.StrictFilter + mode: "pre_call" + - guardrail_name: "permissive-filter" + litellm_params: + guardrail: pipeline_test_guardrails.PermissiveFilter + mode: "pre_call" + - guardrail_name: "always-block-filter" + litellm_params: + guardrail: pipeline_test_guardrails.AlwaysBlockFilter + mode: "pre_call" + +policies: + # Pipeline: strict-filter fails -> escalate to permissive-filter + # If strict fails but permissive passes -> allow the request + content-safety-permissive: + description: "Multi-tier: strict filter with permissive fallback" + guardrails: + add: [strict-filter, permissive-filter] + pipeline: + mode: "pre_call" + steps: + - guardrail: strict-filter + on_fail: next # escalate to permissive + on_pass: allow # clean content proceeds + - guardrail: permissive-filter + on_fail: block # hard block + on_pass: allow # permissive says OK + + # Pipeline: strict-filter fails -> escalate to always-block + # Both fail -> block + content-safety-strict: + description: "Multi-tier: strict filter with strict fallback (both block)" + guardrails: + add: [strict-filter, always-block-filter] + pipeline: + mode: "pre_call" + steps: + - guardrail: strict-filter + on_fail: next + on_pass: allow + - guardrail: always-block-filter + on_fail: block + on_pass: allow + +policy_attachments: + - policy: content-safety-permissive + models: [fake-openai-endpoint] + - policy: content-safety-strict + models: [fake-blocked-endpoint] diff --git a/litellm/proxy/litellm_pre_call_utils.py b/litellm/proxy/litellm_pre_call_utils.py index 49d31c1efec..fa024cc33d4 100644 --- a/litellm/proxy/litellm_pre_call_utils.py +++ b/litellm/proxy/litellm_pre_call_utils.py @@ -1642,20 +1642,40 @@ def add_guardrails_from_policy_engine( f"Policy engine: resolved guardrails: {resolved_guardrails}" ) - if not resolved_guardrails: - return + # Resolve pipelines from matching policies + pipelines = PolicyResolver.resolve_pipelines_for_context(context=context) # Add resolved guardrails to request metadata if metadata_variable_name not in data: data[metadata_variable_name] = {} + # Track pipeline-managed guardrails to exclude from independent execution + pipeline_managed_guardrails: set = set() + if pipelines: + pipeline_managed_guardrails = PolicyResolver.get_pipeline_managed_guardrails( + pipelines + ) + data[metadata_variable_name]["_guardrail_pipelines"] = pipelines + data[metadata_variable_name]["_pipeline_managed_guardrails"] = ( + pipeline_managed_guardrails + ) + verbose_proxy_logger.debug( + f"Policy engine: resolved {len(pipelines)} pipeline(s), " + f"managed guardrails: {pipeline_managed_guardrails}" + ) + + if not resolved_guardrails and not pipelines: + return + existing_guardrails = data[metadata_variable_name].get("guardrails", []) if not isinstance(existing_guardrails, list): existing_guardrails = [] # Combine existing guardrails with policy-resolved guardrails (no duplicates) + # Exclude pipeline-managed guardrails from the flat list combined = set(existing_guardrails) combined.update(resolved_guardrails) + combined -= pipeline_managed_guardrails data[metadata_variable_name]["guardrails"] = list(combined) verbose_proxy_logger.debug( diff --git a/litellm/proxy/management_endpoints/access_group_endpoints.py b/litellm/proxy/management_endpoints/access_group_endpoints.py index 33b81e85654..100b1d2659b 100644 --- a/litellm/proxy/management_endpoints/access_group_endpoints.py +++ b/litellm/proxy/management_endpoints/access_group_endpoints.py @@ -31,7 +31,7 @@ def _record_to_response(record) -> AccessGroupResponse: access_group_id=record.access_group_id, access_group_name=record.access_group_name, description=record.description, - access_model_ids=record.access_model_ids, + access_model_names=record.access_model_names, access_mcp_server_ids=record.access_mcp_server_ids, access_agent_ids=record.access_agent_ids, assigned_team_ids=record.assigned_team_ids, @@ -69,7 +69,7 @@ async def create_access_group( data={ "access_group_name": data.access_group_name, "description": data.description, - "access_model_ids": data.access_model_ids or [], + "access_model_names": data.access_model_names or [], "access_mcp_server_ids": data.access_mcp_server_ids or [], "access_agent_ids": data.access_agent_ids or [], "assigned_team_ids": data.assigned_team_ids or [], @@ -153,10 +153,19 @@ async def update_access_group( for field, value in data.model_dump(exclude_unset=True).items(): update_data[field] = value - record = await prisma_client.db.litellm_accessgrouptable.update( - where={"access_group_id": access_group_id}, - data=update_data, - ) + try: + record = await prisma_client.db.litellm_accessgrouptable.update( + where={"access_group_id": access_group_id}, + data=update_data, + ) + except Exception as e: + # Unique constraint violation (e.g. access_group_name already exists). + if "unique constraint" in str(e).lower() or "P2002" in str(e): + raise HTTPException( + status_code=status.HTTP_409_CONFLICT, + detail=f"Access group '{update_data.get('access_group_name', '')}' already exists", + ) + raise return _record_to_response(record) diff --git a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py index a7b60c8b185..56b513554a8 100644 --- a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py +++ b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py @@ -1099,6 +1099,7 @@ async def endpoint_func( # type: ignore fastapi_response: Response, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), subpath: str = "", # captures sub-paths when include_subpath=True + custom_body: Optional[dict] = None, # accepted for signature compatibility with URL-based path; not forwarded because chat_completion_pass_through_endpoint does not support it ): return await chat_completion_pass_through_endpoint( fastapi_response=fastapi_response, @@ -1115,6 +1116,7 @@ async def endpoint_func( # type: ignore fastapi_response: Response, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), subpath: str = "", # captures sub-paths when include_subpath=True + custom_body: Optional[dict] = None, ): from litellm.proxy.pass_through_endpoints.pass_through_endpoints import ( InitPassThroughEndpointHelpers, @@ -1189,11 +1191,13 @@ async def endpoint_func( # type: ignore ) if query_params: final_query_params.update(query_params) - final_custom_body = ( - custom_body_data - if isinstance(custom_body_data, dict) or custom_body_data is None - else None - ) + # When a caller (e.g. bedrock_proxy_route) supplies a pre-built + # body, use it instead of the body parsed from the raw request. + final_custom_body: Optional[dict] = None + if custom_body is not None: + final_custom_body = custom_body + elif isinstance(custom_body_data, dict): + final_custom_body = custom_body_data return await pass_through_request( # type: ignore request=request, diff --git a/litellm/proxy/policy_engine/pipeline_executor.py b/litellm/proxy/policy_engine/pipeline_executor.py new file mode 100644 index 00000000000..b3982678d37 --- /dev/null +++ b/litellm/proxy/policy_engine/pipeline_executor.py @@ -0,0 +1,216 @@ +""" +Pipeline Executor - Executes guardrail pipelines with conditional step logic. + +Runs guardrails sequentially per pipeline step definitions, handling +pass/fail actions (allow, block, next, modify_response) and data forwarding. +""" + +import time +from typing import Any, List, Optional + +import litellm +from litellm._logging import verbose_proxy_logger +from litellm.integrations.custom_guardrail import ( + CustomGuardrail, + ModifyResponseException, +) +from litellm.proxy.guardrails.guardrail_hooks.unified_guardrail.unified_guardrail import ( + UnifiedLLMGuardrails, +) +from litellm.types.proxy.policy_engine.pipeline_types import ( + PipelineExecutionResult, + PipelineStep, + PipelineStepResult, +) + +try: + from fastapi.exceptions import HTTPException +except ImportError: + HTTPException = None # type: ignore + + +class PipelineExecutor: + """Executes guardrail pipelines with ordered, conditional step logic.""" + + @staticmethod + async def execute_steps( + steps: List[PipelineStep], + mode: str, + data: dict, + user_api_key_dict: Any, + call_type: str, + policy_name: str, + ) -> PipelineExecutionResult: + """ + Execute pipeline steps sequentially with conditional actions. + + Args: + steps: Ordered list of pipeline steps + mode: Event hook mode (pre_call, post_call) + data: Request data dict + user_api_key_dict: User API key auth + call_type: Type of call (completion, etc.) + policy_name: Name of the owning policy (for logging) + + Returns: + PipelineExecutionResult with terminal action and step results + """ + step_results: List[PipelineStepResult] = [] + working_data = data.copy() + if "metadata" in working_data: + working_data["metadata"] = working_data["metadata"].copy() + + for i, step in enumerate(steps): + start_time = time.perf_counter() + + outcome, modified_data, error_detail = await PipelineExecutor._run_step( + step=step, + mode=mode, + data=working_data, + user_api_key_dict=user_api_key_dict, + call_type=call_type, + ) + + duration = time.perf_counter() - start_time + + action = step.on_pass if outcome == "pass" else step.on_fail + + step_result = PipelineStepResult( + guardrail_name=step.guardrail, + outcome=outcome, + action_taken=action, + modified_data=modified_data, + error_detail=error_detail, + duration_seconds=round(duration, 4), + ) + step_results.append(step_result) + + verbose_proxy_logger.debug( + f"Pipeline '{policy_name}' step {i}: guardrail={step.guardrail}, " + f"outcome={outcome}, action={action}" + ) + + # Forward modified data to next step if pass_data is True + if step.pass_data and modified_data is not None: + working_data = {**working_data, **modified_data} + + # Handle terminal actions + if action == "allow": + return PipelineExecutionResult( + terminal_action="allow", + step_results=step_results, + modified_data=working_data if working_data != data else None, + ) + + if action == "block": + return PipelineExecutionResult( + terminal_action="block", + step_results=step_results, + error_message=error_detail, + ) + + if action == "modify_response": + return PipelineExecutionResult( + terminal_action="modify_response", + step_results=step_results, + modify_response_message=step.modify_response_message or error_detail, + ) + + # action == "next" → continue to next step + + # Ran out of steps without a terminal action → default allow + return PipelineExecutionResult( + terminal_action="allow", + step_results=step_results, + modified_data=working_data if working_data != data else None, + ) + + @staticmethod + async def _run_step( + step: PipelineStep, + mode: str, + data: dict, + user_api_key_dict: Any, + call_type: str, + ) -> tuple: + """ + Run a single pipeline step's guardrail. + + Returns: + Tuple of (outcome, modified_data, error_detail) where: + - outcome: "pass", "fail", or "error" + - modified_data: dict if guardrail returned modified data, else None + - error_detail: error message string if fail/error, else None + """ + callback = PipelineExecutor._find_guardrail_callback(step.guardrail) + if callback is None: + verbose_proxy_logger.warning( + f"Pipeline: guardrail '{step.guardrail}' not found in callbacks" + ) + return ("error", None, f"Guardrail '{step.guardrail}' not found") + + try: + # Inject guardrail name into metadata so should_run_guardrail() allows it + if "metadata" not in data: + data["metadata"] = {} + original_guardrails = data["metadata"].get("guardrails") + data["metadata"]["guardrails"] = [step.guardrail] + + # Use unified_guardrail path if callback implements apply_guardrail + target = callback + use_unified = "apply_guardrail" in type(callback).__dict__ + if use_unified: + data["guardrail_to_apply"] = callback + target = UnifiedLLMGuardrails() + + if mode == "pre_call": + response = await target.async_pre_call_hook( + user_api_key_dict=user_api_key_dict, + cache=None, # type: ignore + data=data, + call_type=call_type, # type: ignore + ) + elif mode == "post_call": + response = await target.async_post_call_success_hook( + user_api_key_dict=user_api_key_dict, + data=data, + response=data.get("response"), # type: ignore + ) + else: + return ("error", None, f"Unsupported pipeline mode: {mode}") + + # Normal return means pass + modified_data = None + if response is not None and isinstance(response, dict): + modified_data = response + return ("pass", modified_data, None) + + except Exception as e: + if CustomGuardrail._is_guardrail_intervention(e): + error_msg = _extract_error_message(e) + return ("fail", None, error_msg) + else: + verbose_proxy_logger.error( + f"Pipeline: unexpected error from guardrail '{step.guardrail}': {e}" + ) + return ("error", None, str(e)) + + @staticmethod + def _find_guardrail_callback(guardrail_name: str) -> Optional[CustomGuardrail]: + """Look up an initialized guardrail callback by name from litellm.callbacks.""" + for callback in litellm.callbacks: + if isinstance(callback, CustomGuardrail): + if callback.guardrail_name == guardrail_name: + return callback + return None + + +def _extract_error_message(e: Exception) -> str: + """Extract a human-readable error message from a guardrail exception.""" + if isinstance(e, ModifyResponseException): + return str(e) + if HTTPException is not None and isinstance(e, HTTPException): + detail = getattr(e, "detail", None) + if detail: + return str(detail) + return str(e) diff --git a/litellm/proxy/policy_engine/policy_endpoints.py b/litellm/proxy/policy_engine/policy_endpoints.py index 3bd893b0034..af12a8598f6 100644 --- a/litellm/proxy/policy_engine/policy_endpoints.py +++ b/litellm/proxy/policy_engine/policy_endpoints.py @@ -10,8 +10,11 @@ from litellm.proxy._types import UserAPIKeyAuth from litellm.proxy.auth.user_api_key_auth import user_api_key_auth from litellm.proxy.policy_engine.attachment_registry import get_attachment_registry +from litellm.proxy.policy_engine.pipeline_executor import PipelineExecutor from litellm.proxy.policy_engine.policy_registry import get_policy_registry from litellm.types.proxy.policy_engine import ( + GuardrailPipeline, + PipelineTestRequest, PolicyAttachmentCreateRequest, PolicyAttachmentDBResponse, PolicyAttachmentListResponse, @@ -349,6 +352,69 @@ async def get_resolved_guardrails(policy_id: str): raise HTTPException(status_code=500, detail=str(e)) +# ───────────────────────────────────────────────────────────────────────────── +# Pipeline Test Endpoint +# ───────────────────────────────────────────────────────────────────────────── + + +@router.post( + "/policies/test-pipeline", + tags=["Policies"], + dependencies=[Depends(user_api_key_auth)], +) +async def test_pipeline( + request: PipelineTestRequest, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + """ + Test a guardrail pipeline with sample messages. + + Executes the pipeline steps against the provided test messages and returns + step-by-step results showing which guardrails passed/failed, actions taken, + and timing information. + + Example Request: + ```bash + curl -X POST "http://localhost:4000/policies/test-pipeline" \\ + -H "Authorization: Bearer " \\ + -H "Content-Type: application/json" \\ + -d '{ + "pipeline": { + "mode": "pre_call", + "steps": [ + {"guardrail": "pii-guard", "on_pass": "next", "on_fail": "block"} + ] + }, + "test_messages": [{"role": "user", "content": "My SSN is 123-45-6789"}] + }' + ``` + """ + try: + validated_pipeline = GuardrailPipeline(**request.pipeline) + except Exception as e: + raise HTTPException(status_code=400, detail=f"Invalid pipeline: {e}") + + data = { + "messages": request.test_messages, + "model": "test", + "metadata": {}, + } + + try: + result = await PipelineExecutor.execute_steps( + steps=validated_pipeline.steps, + mode=validated_pipeline.mode, + data=data, + user_api_key_dict=user_api_key_dict, + call_type="completion", + policy_name="test-pipeline", + ) + return result.model_dump() + except Exception as e: + verbose_proxy_logger.exception(f"Error testing pipeline: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + # ───────────────────────────────────────────────────────────────────────────── # Policy Attachment CRUD Endpoints # ───────────────────────────────────────────────────────────────────────────── diff --git a/litellm/proxy/policy_engine/policy_registry.py b/litellm/proxy/policy_engine/policy_registry.py index a2431977b24..50acddd2b9d 100644 --- a/litellm/proxy/policy_engine/policy_registry.py +++ b/litellm/proxy/policy_engine/policy_registry.py @@ -10,8 +10,12 @@ from datetime import datetime, timezone from typing import TYPE_CHECKING, Any, Dict, List, Optional +from prisma import Json as PrismaJson + from litellm._logging import verbose_proxy_logger from litellm.types.proxy.policy_engine import ( + GuardrailPipeline, + PipelineStep, Policy, PolicyCondition, PolicyCreateRequest, @@ -93,11 +97,32 @@ def _parse_policy(self, policy_name: str, policy_data: Dict[str, Any]) -> Policy if condition_data: condition = PolicyCondition(model=condition_data.get("model")) + # Parse pipeline (optional ordered guardrail execution) + pipeline = PolicyRegistry._parse_pipeline(policy_data.get("pipeline")) + return Policy( inherit=policy_data.get("inherit"), description=policy_data.get("description"), guardrails=guardrails, condition=condition, + pipeline=pipeline, + ) + + @staticmethod + def _parse_pipeline(pipeline_data: Optional[Dict[str, Any]]) -> Optional[GuardrailPipeline]: + """Parse a pipeline configuration from raw data.""" + if pipeline_data is None: + return None + + steps_data = pipeline_data.get("steps", []) + steps = [ + PipelineStep(**step_data) if isinstance(step_data, dict) else step_data + for step_data in steps_data + ] + + return GuardrailPipeline( + mode=pipeline_data.get("mode", "pre_call"), + steps=steps, ) def get_policy(self, policy_name: str) -> Optional[Policy]: @@ -225,7 +250,10 @@ async def add_policy_to_db( data["created_by"] = created_by data["updated_by"] = created_by if policy_request.condition is not None: - data["condition"] = policy_request.condition.model_dump() + data["condition"] = PrismaJson(policy_request.condition.model_dump()) + if policy_request.pipeline is not None: + validated_pipeline = GuardrailPipeline(**policy_request.pipeline) + data["pipeline"] = PrismaJson(validated_pipeline.model_dump()) created_policy = await prisma_client.db.litellm_policytable.create( data=data @@ -244,6 +272,7 @@ async def add_policy_to_db( "condition": policy_request.condition.model_dump() if policy_request.condition else None, + "pipeline": policy_request.pipeline, }, ) self.add_policy(policy_request.policy_name, policy) @@ -256,6 +285,7 @@ async def add_policy_to_db( guardrails_add=created_policy.guardrails_add or [], guardrails_remove=created_policy.guardrails_remove or [], condition=created_policy.condition, + pipeline=created_policy.pipeline, created_at=created_policy.created_at, updated_at=created_policy.updated_at, created_by=created_policy.created_by, @@ -302,7 +332,10 @@ async def update_policy_in_db( if policy_request.guardrails_remove is not None: update_data["guardrails_remove"] = policy_request.guardrails_remove if policy_request.condition is not None: - update_data["condition"] = policy_request.condition.model_dump() + update_data["condition"] = PrismaJson(policy_request.condition.model_dump()) + if policy_request.pipeline is not None: + validated_pipeline = GuardrailPipeline(**policy_request.pipeline) + update_data["pipeline"] = PrismaJson(validated_pipeline.model_dump()) updated_policy = await prisma_client.db.litellm_policytable.update( where={"policy_id": policy_id}, @@ -320,6 +353,7 @@ async def update_policy_in_db( "remove": updated_policy.guardrails_remove, }, "condition": updated_policy.condition, + "pipeline": updated_policy.pipeline, }, ) self.add_policy(updated_policy.policy_name, policy) @@ -332,6 +366,7 @@ async def update_policy_in_db( guardrails_add=updated_policy.guardrails_add or [], guardrails_remove=updated_policy.guardrails_remove or [], condition=updated_policy.condition, + pipeline=updated_policy.pipeline, created_at=updated_policy.created_at, updated_at=updated_policy.updated_at, created_by=updated_policy.created_by, @@ -409,6 +444,7 @@ async def get_policy_by_id_from_db( guardrails_add=policy.guardrails_add or [], guardrails_remove=policy.guardrails_remove or [], condition=policy.condition, + pipeline=policy.pipeline, created_at=policy.created_at, updated_at=policy.updated_at, created_by=policy.created_by, @@ -445,6 +481,7 @@ async def get_all_policies_from_db( guardrails_add=p.guardrails_add or [], guardrails_remove=p.guardrails_remove or [], condition=p.condition, + pipeline=p.pipeline, created_at=p.created_at, updated_at=p.updated_at, created_by=p.created_by, @@ -480,6 +517,7 @@ async def sync_policies_from_db( "remove": policy_response.guardrails_remove, }, "condition": policy_response.condition, + "pipeline": policy_response.pipeline, }, ) self.add_policy(policy_response.policy_name, policy) @@ -528,6 +566,7 @@ async def resolve_guardrails_from_db( "remove": policy_response.guardrails_remove, }, "condition": policy_response.condition, + "pipeline": policy_response.pipeline, }, ) temp_policies[policy_response.policy_name] = policy diff --git a/litellm/proxy/policy_engine/policy_resolver.py b/litellm/proxy/policy_engine/policy_resolver.py index cfdedc467d8..a8ad78d6491 100644 --- a/litellm/proxy/policy_engine/policy_resolver.py +++ b/litellm/proxy/policy_engine/policy_resolver.py @@ -8,10 +8,11 @@ - Combining guardrails from multiple matching policies """ -from typing import Dict, List, Optional, Set +from typing import Dict, List, Optional, Set, Tuple from litellm._logging import verbose_proxy_logger from litellm.types.proxy.policy_engine import ( + GuardrailPipeline, Policy, PolicyMatchContext, ResolvedPolicy, @@ -190,6 +191,67 @@ def resolve_guardrails_for_context( return result + @staticmethod + def resolve_pipelines_for_context( + context: PolicyMatchContext, + policies: Optional[Dict[str, Policy]] = None, + ) -> List[Tuple[str, GuardrailPipeline]]: + """ + Resolve pipelines from matching policies for a request context. + + Returns (policy_name, pipeline) tuples for policies that have pipelines. + Guardrails managed by pipelines should be excluded from the flat + guardrails list to avoid double execution. + + Args: + context: The request context + policies: Dictionary of all policies (if None, uses global registry) + + Returns: + List of (policy_name, GuardrailPipeline) tuples + """ + from litellm.proxy.policy_engine.policy_matcher import PolicyMatcher + from litellm.proxy.policy_engine.policy_registry import get_policy_registry + + if policies is None: + registry = get_policy_registry() + if not registry.is_initialized(): + return [] + policies = registry.get_all_policies() + + matching_policy_names = PolicyMatcher.get_matching_policies(context=context) + if not matching_policy_names: + return [] + + pipelines: List[Tuple[str, GuardrailPipeline]] = [] + for policy_name in matching_policy_names: + policy = policies.get(policy_name) + if policy is None: + continue + if policy.pipeline is not None: + pipelines.append((policy_name, policy.pipeline)) + verbose_proxy_logger.debug( + f"Policy '{policy_name}' has pipeline with " + f"{len(policy.pipeline.steps)} steps" + ) + + return pipelines + + @staticmethod + def get_pipeline_managed_guardrails( + pipelines: List[Tuple[str, GuardrailPipeline]], + ) -> Set[str]: + """ + Get the set of guardrail names managed by pipelines. + + These guardrails should be excluded from normal independent execution. + """ + managed: Set[str] = set() + for _policy_name, pipeline in pipelines: + for step in pipeline.steps: + managed.add(step.guardrail) + return managed + @staticmethod def get_all_resolved_policies( policies: Optional[Dict[str, Policy]] = None, diff --git a/litellm/proxy/policy_engine/policy_validator.py b/litellm/proxy/policy_engine/policy_validator.py index 3eaa67a54d3..89c9b0e2e99 100644 --- a/litellm/proxy/policy_engine/policy_validator.py +++ b/litellm/proxy/policy_engine/policy_validator.py @@ -283,8 +283,14 @@ async def validate_policies( ) ) - # Note: Team, key, and model validation is done via policy_attachments - # Policies no longer have scope - attachments define where policies apply + # Validate pipeline if present + if policy.pipeline is not None: + pipeline_errors = PolicyValidator._validate_pipeline( + policy_name=policy_name, + policy=policy, + available_guardrails=available_guardrails, + ) + errors.extend(pipeline_errors) # Validate inheritance inheritance_errors = self._validate_inheritance_chain( @@ -298,6 +304,53 @@ async def validate_policies( warnings=warnings, ) + @staticmethod + def _validate_pipeline( + policy_name: str, + policy: Policy, + available_guardrails: Set[str], + ) -> List[PolicyValidationError]: + """Validate a policy's pipeline configuration.""" + errors: List[PolicyValidationError] = [] + pipeline = policy.pipeline + if pipeline is None: + return errors + + guardrails_add = set(policy.guardrails.get_add()) + + for i, step in enumerate(pipeline.steps): + # Check guardrail is in policy's guardrails.add + if step.guardrail not in guardrails_add: + errors.append( + PolicyValidationError( + policy_name=policy_name, + error_type=PolicyValidationErrorType.INVALID_GUARDRAIL, + message=( + f"Pipeline step {i} guardrail '{step.guardrail}' " + f"is not in the policy's guardrails.add list" + ), + field="pipeline.steps", + value=step.guardrail, + ) + ) + + # Check guardrail exists in registry + if available_guardrails and step.guardrail not in available_guardrails: + errors.append( + PolicyValidationError( + policy_name=policy_name, + error_type=PolicyValidationErrorType.INVALID_GUARDRAIL, + message=( + f"Pipeline step {i} guardrail '{step.guardrail}' " + f"not found in guardrail registry" + ), + field="pipeline.steps", + value=step.guardrail, + ) + ) + + return errors + async def validate_policy_config( self, policy_config: Dict[str, Any], diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 45751c5724c..bc2d32f141d 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -867,6 +867,9 @@ async def proxy_startup_event(app: FastAPI): # noqa: PLR0915 ## [Optional] Initialize dd tracer ProxyStartupEvent._init_dd_tracer() + ## [Optional] Initialize Pyroscope continuous profiling (env: LITELLM_ENABLE_PYROSCOPE=true) + ProxyStartupEvent._init_pyroscope() + ## Initialize shared aiohttp session for connection reuse shared_aiohttp_session = await _initialize_shared_aiohttp_session() @@ -5814,6 +5817,69 @@ def _init_dd_tracer(cls): prof.start() verbose_proxy_logger.debug("Datadog Profiler started......") + @classmethod + def _init_pyroscope(cls): + """ + Optional continuous profiling via Grafana Pyroscope. + + Off by default. Enable with LITELLM_ENABLE_PYROSCOPE=true. + Requires: pip install pyroscope-io (optional dependency). + When enabled, PYROSCOPE_SERVER_ADDRESS and PYROSCOPE_APP_NAME are required (no defaults). + Optional: PYROSCOPE_SAMPLE_RATE (parsed as integer) to set the sample rate. + """ + if not get_secret_bool("LITELLM_ENABLE_PYROSCOPE", False): + verbose_proxy_logger.debug( + "LiteLLM: Pyroscope profiling is disabled (set LITELLM_ENABLE_PYROSCOPE=true to enable)." + ) + try: + import pyroscope + + app_name = os.getenv("PYROSCOPE_APP_NAME") + if not app_name: + raise ValueError( + "LITELLM_ENABLE_PYROSCOPE is true but PYROSCOPE_APP_NAME is not set. " + "Set PYROSCOPE_APP_NAME when enabling Pyroscope." + ) + server_address = os.getenv("PYROSCOPE_SERVER_ADDRESS") + if not server_address: + raise ValueError( + "LITELLM_ENABLE_PYROSCOPE is true but PYROSCOPE_SERVER_ADDRESS is not set. " + "Set PYROSCOPE_SERVER_ADDRESS when enabling Pyroscope." + ) + tags = {} + env_name = os.getenv("OTEL_ENVIRONMENT_NAME") or os.getenv( + "LITELLM_DEPLOYMENT_ENVIRONMENT", + ) + if env_name: + tags["environment"] = env_name + sample_rate_env = os.getenv("PYROSCOPE_SAMPLE_RATE") + configure_kwargs = { + "app_name": app_name, + "server_address": server_address, + "tags": tags if tags else None, + } + if sample_rate_env is not None: + try: + # pyroscope-io expects sample_rate as an integer + configure_kwargs["sample_rate"] = int(float(sample_rate_env)) + except (ValueError, TypeError): + raise ValueError( + "PYROSCOPE_SAMPLE_RATE must be a number, got: " + f"{sample_rate_env!r}" + ) + pyroscope.configure(**configure_kwargs) + msg = ( + f"LiteLLM: Pyroscope profiling started (app_name={app_name}, server_address={server_address}). " + f"View CPU profiles at the Pyroscope UI and select application '{app_name}'." + ) + if "sample_rate" in configure_kwargs: + msg += f" sample_rate={configure_kwargs['sample_rate']}" + verbose_proxy_logger.info(msg) + except ImportError: + verbose_proxy_logger.warning( + "LiteLLM: LITELLM_ENABLE_PYROSCOPE is set but the 'pyroscope-io' package is not installed. " + "Pyroscope profiling will not run. Install with: pip install pyroscope-io" + ) #### API ENDPOINTS #### @router.get( diff --git a/litellm/proxy/schema.prisma b/litellm/proxy/schema.prisma index 390b0415d15..dab0c8237a7 100644 --- a/litellm/proxy/schema.prisma +++ b/litellm/proxy/schema.prisma @@ -917,6 +917,7 @@ model LiteLLM_PolicyTable { guardrails_add String[] @default([]) guardrails_remove String[] @default([]) condition Json? @default("{}") // Policy conditions (e.g., model matching) + pipeline Json? // Optional guardrail pipeline (mode + steps[]) created_at DateTime @default(now()) created_by String? updated_at DateTime @default(now()) @updatedAt @@ -945,7 +946,7 @@ model LiteLLM_AccessGroupTable { description String? // Resource memberships - explicit arrays per type - access_model_ids String[] @default([]) + access_model_names String[] @default([]) access_mcp_server_ids String[] @default([]) access_agent_ids String[] @default([]) diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index d977751004c..66cf95f8e6b 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -77,7 +77,10 @@ from litellm._service_logger import ServiceLogging, ServiceTypes from litellm.caching.caching import DualCache, RedisCache from litellm.exceptions import RejectedRequestError -from litellm.integrations.custom_guardrail import CustomGuardrail +from litellm.integrations.custom_guardrail import ( + CustomGuardrail, + ModifyResponseException, +) from litellm.integrations.custom_logger import CustomLogger from litellm.integrations.SlackAlerting.slack_alerting import SlackAlerting from litellm.integrations.SlackAlerting.utils import _add_langfuse_trace_id_to_alert @@ -110,6 +113,7 @@ _PROXY_MaxParallelRequestsHandler, ) from litellm.proxy.litellm_pre_call_utils import LiteLLMProxyRequestSetup +from litellm.proxy.policy_engine.pipeline_executor import PipelineExecutor from litellm.secret_managers.main import str_to_bool from litellm.types.integrations.slack_alerting import DEFAULT_ALERT_TYPES from litellm.types.mcp import ( @@ -117,6 +121,7 @@ MCPPreCallRequestObject, MCPPreCallResponseObject, ) +from litellm.types.proxy.policy_engine.pipeline_types import PipelineExecutionResult from litellm.types.utils import LLMResponseTypes, LoggedLiteLLMParams if TYPE_CHECKING: @@ -1141,6 +1146,95 @@ def _process_guardrail_metadata(self, data: dict) -> None: request_data=data, guardrail_name=guardrail_name ) + async def _maybe_execute_pipelines( + self, + data: dict, + user_api_key_dict: UserAPIKeyAuth, + call_type: str, + event_hook: str, + ) -> dict: + """ + Execute guardrail pipelines if any are configured for this request. + + Checks metadata for pipelines resolved by the policy engine + and executes them. Handles the result (allow/block/modify_response). + + Returns the (possibly modified) data dict. + """ + metadata = data.get("metadata", data.get("litellm_metadata", {})) or {} + pipelines = metadata.get("_guardrail_pipelines") + if not pipelines: + return data + + for policy_name, pipeline in pipelines: + if pipeline.mode != event_hook: + continue + + result: PipelineExecutionResult = await PipelineExecutor.execute_steps( + steps=pipeline.steps, + mode=pipeline.mode, + data=data, + user_api_key_dict=user_api_key_dict, + call_type=call_type, + policy_name=policy_name, + ) + + data = self._handle_pipeline_result( + result=result, + data=data, + policy_name=policy_name, + ) + + return data + + @staticmethod + def _handle_pipeline_result( + result: Any, + data: dict, + policy_name: str, + ) -> dict: + """ + Handle a PipelineExecutionResult — allow, block, or modify_response. + + Returns data dict if allowed, raises on block/modify_response. + """ + if result.terminal_action == "allow": + if result.modified_data is not None: + data.update(result.modified_data) + return data + + if result.terminal_action == "block": + step_results_serializable = [ + { + "guardrail": sr.guardrail_name, + "outcome": sr.outcome, + "action": sr.action_taken, + } + for sr in result.step_results + ] + error_detail = { + "error": { + "message": f"Content blocked by guardrail pipeline '{policy_name}'", + "type": "guardrail_pipeline_error", + "pipeline_context": { + "policy": policy_name, + "step_results": step_results_serializable, + }, + } + } + raise HTTPException(status_code=400, detail=error_detail) + + if result.terminal_action == "modify_response": + raise ModifyResponseException( + message=result.modify_response_message or "Response modified by pipeline", + model=data.get("model", "unknown"), + request_data=data, + guardrail_name=f"pipeline:{policy_name}", + detection_info=None, + ) + + return data + # The actual implementation of the function @overload async def pre_call_hook( @@ -1203,6 +1297,18 @@ async def pre_call_hook( ) try: + # Execute guardrail pipelines before the normal callback loop + data = await self._maybe_execute_pipelines( + data=data, + user_api_key_dict=user_api_key_dict, + call_type=call_type, + event_hook="pre_call", + ) + + # Get pipeline-managed guardrails to skip in normal loop + metadata = data.get("metadata", data.get("litellm_metadata", {})) or {} + pipeline_managed: set = metadata.get("_pipeline_managed_guardrails", set()) + for callback in litellm.callbacks: start_time = time.time() _callback = None @@ -1217,6 +1323,10 @@ async def pre_call_hook( and isinstance(_callback, CustomGuardrail) and data is not None ): + # Skip guardrails managed by a pipeline + if _callback.guardrail_name and _callback.guardrail_name in pipeline_managed: + continue + result = await self._process_guardrail_callback( callback=_callback, data=data, # type: ignore diff --git a/litellm/proxy/vector_store_endpoints/endpoints.py b/litellm/proxy/vector_store_endpoints/endpoints.py index 0775e05f4fa..30cabd3eeff 100644 --- a/litellm/proxy/vector_store_endpoints/endpoints.py +++ b/litellm/proxy/vector_store_endpoints/endpoints.py @@ -1,4 +1,4 @@ -from typing import Dict, Optional +from typing import Any, Dict, Optional from fastapi import APIRouter, Depends, HTTPException, Request, Response @@ -230,7 +230,7 @@ async def vector_store_create( ) # Get managed vector stores hook - managed_vector_stores = proxy_logging_obj.get_proxy_hook("managed_vector_stores") + managed_vector_stores: Any = proxy_logging_obj.get_proxy_hook("managed_vector_stores") if managed_vector_stores is None: raise HTTPException( status_code=500, diff --git a/litellm/rag/ingestion/vertex_ai_ingestion.py b/litellm/rag/ingestion/vertex_ai_ingestion.py index 47a94185d1d..7394ec7a616 100644 --- a/litellm/rag/ingestion/vertex_ai_ingestion.py +++ b/litellm/rag/ingestion/vertex_ai_ingestion.py @@ -10,7 +10,7 @@ from __future__ import annotations import json -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, cast +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple from litellm._logging import verbose_logger from litellm.llms.custom_httpx.http_handler import ( diff --git a/litellm/responses/litellm_completion_transformation/transformation.py b/litellm/responses/litellm_completion_transformation/transformation.py index 08e31c59662..900f56fea26 100644 --- a/litellm/responses/litellm_completion_transformation/transformation.py +++ b/litellm/responses/litellm_completion_transformation/transformation.py @@ -1500,7 +1500,7 @@ def transform_chat_completion_response_to_responses_api_response( previous_response_id=getattr( chat_completion_response, "previous_response_id", None ), - reasoning=Reasoning(), + reasoning=dict(Reasoning()), status=LiteLLMCompletionResponsesConfig._map_chat_completion_finish_reason_to_responses_status( finish_reason ), @@ -1516,7 +1516,7 @@ def transform_chat_completion_response_to_responses_api_response( # Surface provider-specific fields (generic passthrough from any provider) provider_fields = responses_api_response._hidden_params.get("provider_specific_fields") if provider_fields: - responses_api_response.provider_specific_fields = provider_fields + setattr(responses_api_response, "provider_specific_fields", provider_fields) return responses_api_response diff --git a/litellm/types/access_group.py b/litellm/types/access_group.py index 3a6b75768ef..e26ebe00625 100644 --- a/litellm/types/access_group.py +++ b/litellm/types/access_group.py @@ -7,7 +7,7 @@ class AccessGroupCreateRequest(BaseModel): access_group_name: str description: Optional[str] = None - access_model_ids: Optional[List[str]] = None + access_model_names: Optional[List[str]] = None access_mcp_server_ids: Optional[List[str]] = None access_agent_ids: Optional[List[str]] = None assigned_team_ids: Optional[List[str]] = None @@ -15,8 +15,9 @@ class AccessGroupCreateRequest(BaseModel): class AccessGroupUpdateRequest(BaseModel): + access_group_name: Optional[str] = None description: Optional[str] = None - access_model_ids: Optional[List[str]] = None + access_model_names: Optional[List[str]] = None access_mcp_server_ids: Optional[List[str]] = None access_agent_ids: Optional[List[str]] = None assigned_team_ids: Optional[List[str]] = None @@ -27,7 +28,7 @@ class AccessGroupResponse(BaseModel): access_group_id: str access_group_name: str description: Optional[str] = None - access_model_ids: List[str] + access_model_names: List[str] access_mcp_server_ids: List[str] access_agent_ids: List[str] assigned_team_ids: List[str] diff --git a/litellm/types/proxy/guardrails/guardrail_hooks/zscaler_ai_guard.py b/litellm/types/proxy/guardrails/guardrail_hooks/zscaler_ai_guard.py index 7cbdf751e1b..f522f5b470a 100644 --- a/litellm/types/proxy/guardrails/guardrail_hooks/zscaler_ai_guard.py +++ b/litellm/types/proxy/guardrails/guardrail_hooks/zscaler_ai_guard.py @@ -106,6 +106,7 @@ def validate_endpoint_configuration(self) -> "ZscalerAIGuardConfigModel": ) # Check for configuration issues + assert api_base is not None # always set via env default above is_resolve_policy = api_base.endswith("/resolve-and-execute-policy") is_execute_policy = api_base.endswith("/execute-policy") and not is_resolve_policy diff --git a/litellm/types/proxy/policy_engine/__init__.py b/litellm/types/proxy/policy_engine/__init__.py index 42490c2eddc..e0c1d6f30da 100644 --- a/litellm/types/proxy/policy_engine/__init__.py +++ b/litellm/types/proxy/policy_engine/__init__.py @@ -10,6 +10,12 @@ - `policy_attachments`: Define WHERE policies apply (teams, keys, models) """ +from litellm.types.proxy.policy_engine.pipeline_types import ( + GuardrailPipeline, + PipelineExecutionResult, + PipelineStep, + PipelineStepResult, +) from litellm.types.proxy.policy_engine.policy_types import ( Policy, PolicyAttachment, @@ -20,6 +26,7 @@ ) from litellm.types.proxy.policy_engine.resolver_types import ( AttachmentImpactResponse, + PipelineTestRequest, PolicyAttachmentCreateRequest, PolicyAttachmentDBResponse, PolicyAttachmentListResponse, @@ -48,6 +55,11 @@ ) __all__ = [ + # Pipeline types + "GuardrailPipeline", + "PipelineStep", + "PipelineStepResult", + "PipelineExecutionResult", # Policy types "Policy", "PolicyConfig", @@ -79,6 +91,8 @@ "PolicyAttachmentCreateRequest", "PolicyAttachmentDBResponse", "PolicyAttachmentListResponse", + # Pipeline test types + "PipelineTestRequest", # Resolve types "PolicyResolveRequest", "PolicyResolveResponse", diff --git a/litellm/types/proxy/policy_engine/pipeline_types.py b/litellm/types/proxy/policy_engine/pipeline_types.py new file mode 100644 index 00000000000..29d2e576000 --- /dev/null +++ b/litellm/types/proxy/policy_engine/pipeline_types.py @@ -0,0 +1,98 @@ +""" +Pipeline type definitions for guardrail pipelines. + +Pipelines define ordered, conditional execution of guardrails within a policy. +When a policy has a `pipeline`, its guardrails run in the defined step order +with configurable actions on pass/fail, rather than independently. +""" + +from typing import Any, Dict, List, Literal, Optional + +from pydantic import BaseModel, ConfigDict, Field, field_validator + +VALID_PIPELINE_ACTIONS = {"allow", "block", "next", "modify_response"} +VALID_PIPELINE_MODES = {"pre_call", "post_call"} + + +class PipelineStep(BaseModel): + """ + A single step in a guardrail pipeline. + + Each step runs a guardrail and takes an action based on pass/fail. + """ + + guardrail: str = Field(description="Name of the guardrail to run.") + on_fail: str = Field( + default="block", + description="Action when guardrail rejects: next | block | allow | modify_response", + ) + on_pass: str = Field( + default="allow", + description="Action when guardrail passes: next | block | allow | modify_response", + ) + pass_data: bool = Field( + default=False, + description="Forward modified request data (e.g., PII-masked) to next step.", + ) + modify_response_message: Optional[str] = Field( + default=None, + description="Custom message for modify_response action.", + ) + + model_config = ConfigDict(extra="forbid") + + @field_validator("on_fail", "on_pass") + @classmethod + def validate_action(cls, v: str) -> str: + if v not in VALID_PIPELINE_ACTIONS: + raise ValueError( + f"Invalid action '{v}'. Must be one of: {sorted(VALID_PIPELINE_ACTIONS)}" + ) + return v + + +class GuardrailPipeline(BaseModel): + """ + Defines ordered execution of guardrails with conditional actions. + + When present on a policy, the guardrails in `steps` are executed + sequentially instead of independently. + """ + + mode: str = Field(description="Event hook: pre_call | post_call") + steps: List[PipelineStep] = Field( + description="Ordered list of pipeline steps. Must have at least 1 step.", + min_length=1, + ) + + model_config = ConfigDict(extra="forbid") + + @field_validator("mode") + @classmethod + def validate_mode(cls, v: str) -> str: + if v not in VALID_PIPELINE_MODES: + raise ValueError( + f"Invalid mode '{v}'. Must be one of: {sorted(VALID_PIPELINE_MODES)}" + ) + return v + + +class PipelineStepResult(BaseModel): + """Result of executing a single pipeline step.""" + + guardrail_name: str + outcome: Literal["pass", "fail", "error"] + action_taken: str + modified_data: Optional[Dict[str, Any]] = None + error_detail: Optional[str] = None + duration_seconds: Optional[float] = None + + +class PipelineExecutionResult(BaseModel): + """Result of executing an entire pipeline.""" + + terminal_action: str # block | allow | modify_response + step_results: List[PipelineStepResult] + modified_data: Optional[Dict[str, Any]] = None + error_message: Optional[str] = None + modify_response_message: Optional[str] = None diff --git a/litellm/types/proxy/policy_engine/policy_types.py b/litellm/types/proxy/policy_engine/policy_types.py index f221ba7e038..53a74ca6fd8 100644 --- a/litellm/types/proxy/policy_engine/policy_types.py +++ b/litellm/types/proxy/policy_engine/policy_types.py @@ -29,10 +29,12 @@ - `condition`: Optional model condition for when guardrails apply """ -from typing import Any, Dict, List, Optional, Union +from typing import Dict, List, Optional, Union from pydantic import BaseModel, ConfigDict, Field +from litellm.types.proxy.policy_engine.pipeline_types import GuardrailPipeline + # ───────────────────────────────────────────────────────────────────────────── # Policy Condition # ───────────────────────────────────────────────────────────────────────────── @@ -231,6 +233,10 @@ class Policy(BaseModel): default=None, description="Optional condition for when this policy's guardrails apply.", ) + pipeline: Optional[GuardrailPipeline] = Field( + default=None, + description="Optional pipeline for ordered, conditional guardrail execution.", + ) model_config = ConfigDict(extra="forbid") diff --git a/litellm/types/proxy/policy_engine/resolver_types.py b/litellm/types/proxy/policy_engine/resolver_types.py index 0c2c7336f8a..a5a2334ae4b 100644 --- a/litellm/types/proxy/policy_engine/resolver_types.py +++ b/litellm/types/proxy/policy_engine/resolver_types.py @@ -154,6 +154,10 @@ class PolicyCreateRequest(BaseModel): default=None, description="Condition for when this policy applies.", ) + pipeline: Optional[Dict[str, Any]] = Field( + default=None, + description="Optional guardrail pipeline for ordered execution. Contains 'mode' and 'steps'.", + ) class PolicyUpdateRequest(BaseModel): @@ -183,6 +187,10 @@ class PolicyUpdateRequest(BaseModel): default=None, description="Condition for when this policy applies.", ) + pipeline: Optional[Dict[str, Any]] = Field( + default=None, + description="Optional guardrail pipeline for ordered execution. Contains 'mode' and 'steps'.", + ) class PolicyDBResponse(BaseModel): @@ -201,6 +209,9 @@ class PolicyDBResponse(BaseModel): condition: Optional[Dict[str, Any]] = Field( default=None, description="Policy condition." ) + pipeline: Optional[Dict[str, Any]] = Field( + default=None, description="Optional guardrail pipeline." + ) created_at: Optional[datetime] = Field( default=None, description="When the policy was created." ) @@ -291,6 +302,17 @@ class PolicyAttachmentListResponse(BaseModel): # ───────────────────────────────────────────────────────────────────────────── +class PipelineTestRequest(BaseModel): + """Request body for testing a guardrail pipeline with sample messages.""" + + pipeline: Dict[str, Any] = Field( + description="Pipeline definition with 'mode' and 'steps'.", + ) + test_messages: List[Dict[str, str]] = Field( + description="Test messages to run through the pipeline, e.g. [{'role': 'user', 'content': '...'}].", + ) + + class PolicyResolveRequest(BaseModel): """Request body for resolving effective policies/guardrails for a context.""" diff --git a/model_prices_and_context_window.json b/model_prices_and_context_window.json index e6b7cf17297..18d0f0079ba 100644 --- a/model_prices_and_context_window.json +++ b/model_prices_and_context_window.json @@ -14835,7 +14835,9 @@ "supports_tool_choice": true, "supports_url_context": true, "supports_vision": true, - "supports_web_search": true + "supports_web_search": true, + "tpm": 250000, + "rpm": 10 }, "gemini-2.5-computer-use-preview-10-2025": { "input_cost_per_token": 1.25e-06, @@ -16323,7 +16325,9 @@ "source": "https://ai.google.dev/pricing", "supported_endpoints": [ "/v1/audio/speech" - ] + ], + "tpm": 4000000, + "rpm": 10 }, "gemini/gemini-2.5-pro": { "cache_read_input_token_cost": 1.25e-07, @@ -16821,7 +16825,9 @@ "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models", "supports_function_calling": true, "supports_tool_choice": true, - "supports_vision": true + "supports_vision": true, + "tpm": 250000, + "rpm": 10 }, "gemini/gemini-gemma-2-9b-it": { "input_cost_per_token": 3.5e-07, @@ -16833,7 +16839,9 @@ "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models", "supports_function_calling": true, "supports_tool_choice": true, - "supports_vision": true + "supports_vision": true, + "tpm": 250000, + "rpm": 10 }, "gemini/gemini-pro": { "input_cost_per_token": 3.5e-07, @@ -36495,7 +36503,9 @@ "text", "image" ], - "supports_vision": true + "supports_vision": true, + "tpm": 250000, + "rpm": 10 }, "gemini/gemini-2.0-flash-lite-001": { "cache_read_input_token_cost": 1.875e-08, @@ -36628,7 +36638,9 @@ "audio" ], "supports_audio_input": true, - "supports_audio_output": true + "supports_audio_output": true, + "tpm": 250000, + "rpm": 10 }, "gemini/gemini-2.5-flash-native-audio-preview-09-2025": { "input_cost_per_audio_token": 1e-06, @@ -36652,7 +36664,9 @@ "audio" ], "supports_audio_input": true, - "supports_audio_output": true + "supports_audio_output": true, + "tpm": 250000, + "rpm": 10 }, "gemini/gemini-2.5-flash-native-audio-preview-12-2025": { "input_cost_per_audio_token": 1e-06, @@ -36676,7 +36690,9 @@ "audio" ], "supports_audio_input": true, - "supports_audio_output": true + "supports_audio_output": true, + "tpm": 250000, + "rpm": 10 }, "gemini-2.5-flash-preview-tts": { "input_cost_per_token": 3e-07, diff --git a/poetry.lock b/poetry.lock index d01baa854af..e30857a3b2f 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 2.1.4 and should not be changed by hand. +# This file is automatically @generated by Poetry 2.2.1 and should not be changed by hand. [[package]] name = "a2a-sdk" @@ -5659,6 +5659,24 @@ files = [ [package.extras] dev = ["build", "flake8", "mypy", "pytest", "twine"] +[[package]] +name = "pyroscope-io" +version = "0.8.16" +description = "Pyroscope Python integration" +optional = false +python-versions = "*" +groups = ["main"] +markers = "extra == \"proxy\" and sys_platform != \"win32\"" +files = [ + {file = "pyroscope_io-0.8.16-py2.py3-none-macosx_11_0_arm64.whl", hash = "sha256:e07edcfd59f5bdce42948b92c9b118c824edbd551730305f095a6b9af401a9e8"}, + {file = "pyroscope_io-0.8.16-py2.py3-none-macosx_11_0_x86_64.whl", hash = "sha256:dc98355e27c0b7b61f27066500fe1045b70e9459bb8b9a3082bc4755cb6392b6"}, + {file = "pyroscope_io-0.8.16-py2.py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:86f0f047554ff62bd92c3e5a26bc2809ccd467d11fbacb9fef898ba299dbda59"}, + {file = "pyroscope_io-0.8.16-py2.py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:6b91ce5b240f8de756c16a17022ca8e25ef8a4eed461c7d074b8a0841cf7b445"}, +] + +[package.dependencies] +cffi = ">=1.6.0" + [[package]] name = "pytest" version = "7.4.4" @@ -8516,7 +8534,7 @@ extra-proxy = ["a2a-sdk", "azure-identity", "azure-keyvault-secrets", "google-cl google = ["google-cloud-aiplatform"] grpc = ["grpcio", "grpcio"] mlflow = ["mlflow"] -proxy = ["PyJWT", "apscheduler", "azure-identity", "azure-storage-blob", "backoff", "boto3", "cryptography", "fastapi", "fastapi-sso", "gunicorn", "litellm-enterprise", "litellm-proxy-extras", "mcp", "orjson", "polars", "pynacl", "python-multipart", "pyyaml", "rich", "rq", "soundfile", "uvicorn", "uvloop", "websockets"] +proxy = ["PyJWT", "apscheduler", "azure-identity", "azure-storage-blob", "backoff", "boto3", "cryptography", "fastapi", "fastapi-sso", "gunicorn", "litellm-enterprise", "litellm-proxy-extras", "mcp", "orjson", "polars", "pynacl", "pyroscope-io", "python-multipart", "pyyaml", "rich", "rq", "soundfile", "uvicorn", "uvloop", "websockets"] semantic-router = ["semantic-router"] utils = ["numpydoc"] diff --git a/pyproject.toml b/pyproject.toml index 6ed7618dd26..be15013267b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,7 +61,7 @@ boto3 = { version = "1.40.76", optional = true } redisvl = {version = "^0.4.1", optional = true, markers = "python_version >= '3.9' and python_version < '3.14'"} mcp = {version = ">=1.25.0,<2.0.0", optional = true, python = ">=3.10"} a2a-sdk = {version = "^0.3.22", optional = true, python = ">=3.10"} -litellm-proxy-extras = {version = "0.4.36", optional = true} +litellm-proxy-extras = {version = "0.4.37", optional = true} rich = {version = "13.7.1", optional = true} litellm-enterprise = {version = "0.1.31", optional = true} diskcache = {version = "^5.6.1", optional = true} @@ -69,6 +69,7 @@ polars = {version = "^1.31.0", optional = true, python = ">=3.10"} semantic-router = {version = ">=0.1.12", optional = true, python = ">=3.9,<3.14"} mlflow = {version = ">3.1.4", optional = true, python = ">=3.10"} soundfile = {version = "^0.12.1", optional = true} +pyroscope-io = {version = "^0.8", optional = true, markers = "sys_platform != 'win32'"} # grpcio constraints: # - 1.62.3+ required by grpcio-status # - 1.68.0-1.68.1 has reconnect bug (https://github.com/grpc/grpc/issues/38290) @@ -104,6 +105,7 @@ proxy = [ "rich", "polars", "soundfile", + "pyroscope-io", ] extra_proxy = [ @@ -121,6 +123,8 @@ utils = [ "numpydoc", ] + + caching = ["diskcache"] semantic-router = ["semantic-router"] diff --git a/requirements.txt b/requirements.txt index f31730e20f5..18fc39dca3e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -55,7 +55,7 @@ grpcio>=1.75.0; python_version >= "3.14" sentry_sdk==2.21.0 # for sentry error handling detect-secrets==1.5.0 # Enterprise - secret detection / masking in LLM requests tzdata==2025.1 # IANA time zone database -litellm-proxy-extras==0.4.36 # for proxy extras - e.g. prisma migrations +litellm-proxy-extras==0.4.37 # for proxy extras - e.g. prisma migrations llm-sandbox==0.3.31 # for skill execution in sandbox ### LITELLM PACKAGE DEPENDENCIES python-dotenv==1.0.1 # for env diff --git a/schema.prisma b/schema.prisma index 2a11d0028fb..01965eaafc2 100644 --- a/schema.prisma +++ b/schema.prisma @@ -930,7 +930,7 @@ model LiteLLM_AccessGroupTable { description String? // Resource memberships - explicit arrays per type - access_model_ids String[] @default([]) + access_model_names String[] @default([]) access_mcp_server_ids String[] @default([]) access_agent_ids String[] @default([]) diff --git a/tests/test_litellm/completion_extras/litellm_responses_transformation/test_completion_extras_litellm_responses_transformation_transformation.py b/tests/test_litellm/completion_extras/litellm_responses_transformation/test_completion_extras_litellm_responses_transformation_transformation.py index f8a082ee30c..c35cedd3c7f 100644 --- a/tests/test_litellm/completion_extras/litellm_responses_transformation/test_completion_extras_litellm_responses_transformation_transformation.py +++ b/tests/test_litellm/completion_extras/litellm_responses_transformation/test_completion_extras_litellm_responses_transformation_transformation.py @@ -1278,3 +1278,86 @@ def test_transform_response_preserves_annotations(): assert result.usage.total_tokens == 30 print("✓ Annotations from Responses API are correctly preserved in Chat Completions format") + + +def test_convert_chat_completion_messages_to_responses_api_system_string(): + """Test that string system content is extracted into instructions.""" + from litellm.completion_extras.litellm_responses_transformation.transformation import ( + LiteLLMResponsesTransformationHandler, + ) + + handler = LiteLLMResponsesTransformationHandler() + + messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Hello"}, + ] + + input_items, instructions = handler.convert_chat_completion_messages_to_responses_api(messages) + + assert instructions == "You are a helpful assistant." + # System message should NOT appear in input items + for item in input_items: + assert item.get("role") != "system" + # User message should be in input items + assert len(input_items) == 1 + assert input_items[0]["role"] == "user" + + +def test_convert_chat_completion_messages_to_responses_api_system_list_content(): + """Test that list-format system content blocks are extracted into instructions. + + This happens when requests arrive via the Anthropic /v1/messages adapter, + which converts system prompts into list-format content blocks. + """ + from litellm.completion_extras.litellm_responses_transformation.transformation import ( + LiteLLMResponsesTransformationHandler, + ) + + handler = LiteLLMResponsesTransformationHandler() + + messages = [ + { + "role": "system", + "content": [ + {"type": "text", "text": "You are a helpful assistant."}, + {"type": "text", "text": "Be concise."}, + ], + }, + {"role": "user", "content": "Hello"}, + ] + + input_items, instructions = handler.convert_chat_completion_messages_to_responses_api(messages) + + assert instructions == "You are a helpful assistant. Be concise." + # System message should NOT appear in input items + for item in input_items: + assert item.get("role") != "system" + assert len(input_items) == 1 + assert input_items[0]["role"] == "user" + + +def test_convert_chat_completion_messages_to_responses_api_multiple_system_messages(): + """Test that multiple system messages (string and list) are concatenated.""" + from litellm.completion_extras.litellm_responses_transformation.transformation import ( + LiteLLMResponsesTransformationHandler, + ) + + handler = LiteLLMResponsesTransformationHandler() + + messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + { + "role": "system", + "content": [ + {"type": "text", "text": "Be concise."}, + ], + }, + {"role": "user", "content": "Hello"}, + ] + + input_items, instructions = handler.convert_chat_completion_messages_to_responses_api(messages) + + assert instructions == "You are a helpful assistant. Be concise." + for item in input_items: + assert item.get("role") != "system" diff --git a/tests/test_litellm/proxy/auth/test_user_api_key_auth.py b/tests/test_litellm/proxy/auth/test_user_api_key_auth.py index 9b7b7f46155..00e348b5b7c 100644 --- a/tests/test_litellm/proxy/auth/test_user_api_key_auth.py +++ b/tests/test_litellm/proxy/auth/test_user_api_key_auth.py @@ -422,3 +422,77 @@ async def test_return_user_api_key_auth_obj_user_spend_and_budget(): assert result.user_tpm_limit == 1000 assert result.user_rpm_limit == 100 assert result.user_email == "test@example.com" + + +def test_proxy_admin_jwt_auth_includes_identity_fields(): + """ + Test that the proxy admin early-return path in JWT auth populates + user_id, team_id, team_alias, team_metadata, org_id, and end_user_id. + + Regression test: previously the is_proxy_admin branch only set user_role + and parent_otel_span, discarding all identity fields resolved from the JWT. + This caused blank Team Name and Internal User in Request Logs UI. + """ + from litellm.proxy._types import LiteLLM_TeamTable, LitellmUserRoles, UserAPIKeyAuth + + team_object = LiteLLM_TeamTable( + team_id="team-123", + team_alias="my-team", + metadata={"tags": ["prod"], "env": "production"}, + ) + + # Simulate the proxy admin early-return path (user_api_key_auth.py ~line 586) + result = UserAPIKeyAuth( + api_key=None, + user_role=LitellmUserRoles.PROXY_ADMIN, + user_id="user-abc", + team_id="team-123", + team_alias=( + team_object.team_alias if team_object is not None else None + ), + team_metadata=team_object.metadata if team_object is not None else None, + org_id="org-456", + end_user_id="end-user-789", + parent_otel_span=None, + ) + + assert result.user_role == LitellmUserRoles.PROXY_ADMIN + assert result.user_id == "user-abc" + assert result.team_id == "team-123" + assert result.team_alias == "my-team" + assert result.team_metadata == {"tags": ["prod"], "env": "production"} + assert result.org_id == "org-456" + assert result.end_user_id == "end-user-789" + assert result.api_key is None + + +def test_proxy_admin_jwt_auth_handles_no_team_object(): + """ + Test that the proxy admin early-return path works correctly when + team_object is None (user has admin role but no team association). + """ + from litellm.proxy._types import LitellmUserRoles, UserAPIKeyAuth + + team_object = None + + result = UserAPIKeyAuth( + api_key=None, + user_role=LitellmUserRoles.PROXY_ADMIN, + user_id="admin-user", + team_id=None, + team_alias=( + team_object.team_alias if team_object is not None else None + ), + team_metadata=team_object.metadata if team_object is not None else None, + org_id=None, + end_user_id=None, + parent_otel_span=None, + ) + + assert result.user_role == LitellmUserRoles.PROXY_ADMIN + assert result.user_id == "admin-user" + assert result.team_id is None + assert result.team_alias is None + assert result.team_metadata is None + assert result.org_id is None + assert result.end_user_id is None diff --git a/tests/test_litellm/proxy/management_endpoints/test_access_group_endpoints.py b/tests/test_litellm/proxy/management_endpoints/test_access_group_endpoints.py index 5f204918d08..54df8941fa5 100644 --- a/tests/test_litellm/proxy/management_endpoints/test_access_group_endpoints.py +++ b/tests/test_litellm/proxy/management_endpoints/test_access_group_endpoints.py @@ -28,7 +28,7 @@ def _make_access_group_record( access_group_id: str = "ag-123", access_group_name: str = "test-group", description: str | None = "Test description", - access_model_ids: list | None = None, + access_model_names: list | None = None, access_mcp_server_ids: list | None = None, access_agent_ids: list | None = None, assigned_team_ids: list | None = None, @@ -41,7 +41,7 @@ def _make_access_group_record( record.access_group_id = access_group_id record.access_group_name = access_group_name record.description = description - record.access_model_ids = access_model_ids or [] + record.access_model_names = access_model_names or [] record.access_mcp_server_ids = access_mcp_server_ids or [] record.access_agent_ids = access_agent_ids or [] record.assigned_team_ids = assigned_team_ids or [] @@ -64,7 +64,7 @@ def _create_side_effect(*, data): access_group_id="ag-new", access_group_name=data.get("access_group_name", "new"), description=data.get("description"), - access_model_ids=data.get("access_model_ids", []), + access_model_names=data.get("access_model_names", []), access_mcp_server_ids=data.get("access_mcp_server_ids", []), access_agent_ids=data.get("access_agent_ids", []), assigned_team_ids=data.get("assigned_team_ids", []), @@ -80,7 +80,7 @@ def _create_side_effect(*, data): access_group_id=where.get("access_group_id", "ag-123"), access_group_name=data.get("access_group_name", "updated"), description=data.get("description"), - access_model_ids=data.get("access_model_ids", []), + access_model_names=data.get("access_model_names", []), access_mcp_server_ids=data.get("access_mcp_server_ids", []), access_agent_ids=data.get("access_agent_ids", []), assigned_team_ids=data.get("assigned_team_ids", []), @@ -147,7 +147,7 @@ async def mock_tx(): { "access_group_name": "group-b", "description": "Group B description", - "access_model_ids": ["model-1"], + "access_model_names": ["model-1"], "access_mcp_server_ids": ["mcp-1"], "assigned_team_ids": ["team-1"], }, @@ -369,7 +369,7 @@ def test_get_access_group_forbidden_non_admin(client_and_mocks, user_role): "update_payload", [ {"description": "Updated description"}, - {"access_model_ids": ["model-1", "model-2"]}, + {"access_model_names": ["model-1", "model-2"]}, {"assigned_team_ids": [], "assigned_key_ids": ["key-1"]}, ], ) @@ -431,6 +431,57 @@ def test_update_access_group_empty_body(client_and_mocks): assert call_kwargs["data"]["updated_by"] == "admin_user" +def test_update_access_group_name_success(client_and_mocks): + """Update access_group_name succeeds when new name is unique.""" + client, _, mock_table = client_and_mocks + + existing = _make_access_group_record(access_group_id="ag-update", access_group_name="old-name") + mock_table.find_unique = AsyncMock(return_value=existing) + + resp = client.put("/v1/access_group/ag-update", json={"access_group_name": "new-name"}) + assert resp.status_code == 200 + mock_table.update.assert_awaited_once() + call_kwargs = mock_table.update.call_args.kwargs + assert call_kwargs["data"]["access_group_name"] == "new-name" + + +def test_update_access_group_name_duplicate_conflict(client_and_mocks): + """Update access_group_name to existing name returns 409 (unique constraint).""" + client, _, mock_table = client_and_mocks + + existing = _make_access_group_record(access_group_id="ag-update", access_group_name="old-name") + mock_table.find_unique = AsyncMock(return_value=existing) + mock_table.update = AsyncMock( + side_effect=Exception("Unique constraint failed on the fields: (`access_group_name`)") + ) + + resp = client.put("/v1/access_group/ag-update", json={"access_group_name": "taken-name"}) + assert resp.status_code == 409 + assert "already exists" in resp.json()["detail"] + mock_table.update.assert_awaited_once() + + +@pytest.mark.parametrize( + "error_message", + [ + "Unique constraint failed on the fields: (`access_group_name`)", + "P2002: Unique constraint failed", + "unique constraint violation", + ], +) +def test_update_access_group_name_unique_constraint_returns_409(client_and_mocks, error_message): + """Update access_group_name: Prisma unique constraint surfaces as 409.""" + client, _, mock_table = client_and_mocks + + existing = _make_access_group_record(access_group_id="ag-update", access_group_name="old-name") + mock_table.find_unique = AsyncMock(return_value=existing) + mock_table.update = AsyncMock(side_effect=Exception(error_message)) + + resp = client.put("/v1/access_group/ag-update", json={"access_group_name": "race-name"}) + assert resp.status_code == 409 + assert "already exists" in resp.json()["detail"] + + # --------------------------------------------------------------------------- # DELETE # --------------------------------------------------------------------------- diff --git a/tests/test_litellm/proxy/pass_through_endpoints/test_pass_through_endpoints.py b/tests/test_litellm/proxy/pass_through_endpoints/test_pass_through_endpoints.py index e50e10352e2..7ec97ddc185 100644 --- a/tests/test_litellm/proxy/pass_through_endpoints/test_pass_through_endpoints.py +++ b/tests/test_litellm/proxy/pass_through_endpoints/test_pass_through_endpoints.py @@ -2087,6 +2087,143 @@ async def test_add_litellm_data_to_request_adds_headers_to_metadata(): assert "headers" in result["proxy_server_request"] +@pytest.mark.asyncio +async def test_create_pass_through_route_custom_body_url_target(): + """ + Test that the URL-based endpoint_func created by create_pass_through_route + accepts a custom_body parameter and forwards it to pass_through_request, + taking precedence over the request-parsed body. + + This verifies the fix for issue #16999 where bedrock_proxy_route passes + custom_body=data to the endpoint function, which previously crashed with: + TypeError: endpoint_func() got an unexpected keyword argument 'custom_body' + """ + from litellm.proxy.pass_through_endpoints.pass_through_endpoints import ( + create_pass_through_route, + ) + + unique_path = "/test/path/unique/custom_body_url" + endpoint_func = create_pass_through_route( + endpoint=unique_path, + target="https://bedrock-agent-runtime.us-east-1.amazonaws.com", + custom_headers={"Content-Type": "application/json"}, + _forward_headers=True, + ) + + with patch( + "litellm.proxy.pass_through_endpoints.pass_through_endpoints.pass_through_request" + ) as mock_pass_through, patch( + "litellm.proxy.pass_through_endpoints.pass_through_endpoints.InitPassThroughEndpointHelpers.is_registered_pass_through_route" + ) as mock_is_registered, patch( + "litellm.proxy.pass_through_endpoints.pass_through_endpoints.InitPassThroughEndpointHelpers.get_registered_pass_through_route" + ) as mock_get_registered, patch( + "litellm.proxy.pass_through_endpoints.pass_through_endpoints._parse_request_data_by_content_type" + ) as mock_parse_request: + mock_pass_through.return_value = MagicMock() + mock_is_registered.return_value = True + mock_get_registered.return_value = None + # Simulate the request parser returning a different body + mock_parse_request.return_value = ( + {}, # query_params_data + {"parsed_from_request": True}, # custom_body_data (from request) + None, # file_data + False, # stream + ) + + mock_request = MagicMock(spec=Request) + mock_request.url = MagicMock() + mock_request.url.path = unique_path + mock_request.path_params = {} + mock_request.query_params = QueryParams({}) + + mock_user_api_key_dict = MagicMock() + mock_user_api_key_dict.api_key = "test-key" + + # The caller-supplied body (e.g. from bedrock_proxy_route) + bedrock_body = { + "retrievalQuery": {"text": "What is in the knowledge base?"}, + } + + # Call endpoint_func with custom_body — this is the call that + # used to crash with TypeError before the fix + await endpoint_func( + request=mock_request, + fastapi_response=MagicMock(), + user_api_key_dict=mock_user_api_key_dict, + custom_body=bedrock_body, + ) + + mock_pass_through.assert_called_once() + call_kwargs = mock_pass_through.call_args[1] + + # The critical assertion: custom_body takes precedence over + # the body parsed from the raw request + assert call_kwargs["custom_body"] == bedrock_body + + +@pytest.mark.asyncio +async def test_create_pass_through_route_no_custom_body_falls_back(): + """ + Test that the URL-based endpoint_func falls back to the request-parsed body + when custom_body is not provided. + + This ensures the default pass-through behavior is preserved — only the + Bedrock proxy route (and similar callers) supply a pre-built body. + """ + from litellm.proxy.pass_through_endpoints.pass_through_endpoints import ( + create_pass_through_route, + ) + + unique_path = "/test/path/unique/no_custom_body" + endpoint_func = create_pass_through_route( + endpoint=unique_path, + target="http://example.com/api", + custom_headers={}, + ) + + with patch( + "litellm.proxy.pass_through_endpoints.pass_through_endpoints.pass_through_request" + ) as mock_pass_through, patch( + "litellm.proxy.pass_through_endpoints.pass_through_endpoints.InitPassThroughEndpointHelpers.is_registered_pass_through_route" + ) as mock_is_registered, patch( + "litellm.proxy.pass_through_endpoints.pass_through_endpoints.InitPassThroughEndpointHelpers.get_registered_pass_through_route" + ) as mock_get_registered, patch( + "litellm.proxy.pass_through_endpoints.pass_through_endpoints._parse_request_data_by_content_type" + ) as mock_parse_request: + mock_pass_through.return_value = MagicMock() + mock_is_registered.return_value = True + mock_get_registered.return_value = None + request_parsed_body = {"key": "from_request"} + mock_parse_request.return_value = ( + {}, # query_params_data + request_parsed_body, # custom_body_data + None, # file_data + False, # stream + ) + + mock_request = MagicMock(spec=Request) + mock_request.url = MagicMock() + mock_request.url.path = unique_path + mock_request.path_params = {} + mock_request.query_params = QueryParams({}) + + mock_user_api_key_dict = MagicMock() + mock_user_api_key_dict.api_key = "test-key" + + # Call without custom_body — should use the request-parsed body + await endpoint_func( + request=mock_request, + fastapi_response=MagicMock(), + user_api_key_dict=mock_user_api_key_dict, + ) + + mock_pass_through.assert_called_once() + call_kwargs = mock_pass_through.call_args[1] + + # Should fall back to the body parsed from the request + assert call_kwargs["custom_body"] == request_parsed_body + + def test_build_full_path_with_root_default(): """ Test _build_full_path_with_root with default root path (/) diff --git a/tests/test_litellm/proxy/policy_engine/test_pipeline_executor.py b/tests/test_litellm/proxy/policy_engine/test_pipeline_executor.py new file mode 100644 index 00000000000..226e88bea3e --- /dev/null +++ b/tests/test_litellm/proxy/policy_engine/test_pipeline_executor.py @@ -0,0 +1,484 @@ +""" +Tests for the pipeline executor. + +Uses mock guardrails to validate pipeline execution without external services. +""" + +from unittest.mock import MagicMock + +import pytest + +import litellm +from litellm.integrations.custom_guardrail import CustomGuardrail +from litellm.proxy.policy_engine.pipeline_executor import PipelineExecutor +from litellm.types.proxy.policy_engine.pipeline_types import ( + GuardrailPipeline, + PipelineStep, +) + +try: + from fastapi.exceptions import HTTPException +except ImportError: + HTTPException = None + + +# ───────────────────────────────────────────────────────────────────────────── +# Mock Guardrails +# ───────────────────────────────────────────────────────────────────────────── + + +class AlwaysFailGuardrail(CustomGuardrail): + """Mock guardrail that always raises HTTPException(400).""" + + def __init__(self, guardrail_name: str): + super().__init__( + guardrail_name=guardrail_name, + event_hook="pre_call", + default_on=True, + ) + self.calls = 0 + + def should_run_guardrail(self, data, event_type) -> bool: + return True + + async def async_pre_call_hook(self, user_api_key_dict, cache, data, call_type): + self.calls += 1 + raise HTTPException(status_code=400, detail="Content policy violation") + + +class AlwaysPassGuardrail(CustomGuardrail): + """Mock guardrail that always passes.""" + + def __init__(self, guardrail_name: str): + super().__init__( + guardrail_name=guardrail_name, + event_hook="pre_call", + default_on=True, + ) + self.calls = 0 + + def should_run_guardrail(self, data, event_type) -> bool: + return True + + async def async_pre_call_hook(self, user_api_key_dict, cache, data, call_type): + self.calls += 1 + return None + + +class PiiMaskingGuardrail(CustomGuardrail): + """Mock guardrail that masks PII in messages and returns modified data.""" + + def __init__(self, guardrail_name: str): + super().__init__( + guardrail_name=guardrail_name, + event_hook="pre_call", + default_on=True, + ) + self.calls = 0 + self.received_messages = None + + def should_run_guardrail(self, data, event_type) -> bool: + return True + + async def async_pre_call_hook(self, user_api_key_dict, cache, data, call_type): + self.calls += 1 + self.received_messages = data.get("messages", []) + masked_messages = [] + for msg in data.get("messages", []): + masked_msg = dict(msg) + masked_msg["content"] = msg["content"].replace( + "John Smith", "[REDACTED]" + ) + masked_messages.append(masked_msg) + return {"messages": masked_messages} + + +class ContentCheckGuardrail(CustomGuardrail): + """Mock guardrail that records what messages it received.""" + + def __init__(self, guardrail_name: str): + super().__init__( + guardrail_name=guardrail_name, + event_hook="pre_call", + default_on=True, + ) + self.calls = 0 + self.received_messages = None + + def should_run_guardrail(self, data, event_type) -> bool: + return True + + async def async_pre_call_hook(self, user_api_key_dict, cache, data, call_type): + self.calls += 1 + self.received_messages = data.get("messages", []) + return None + + +# ───────────────────────────────────────────────────────────────────────────── +# Tests +# ───────────────────────────────────────────────────────────────────────────── + + +@pytest.mark.skipif(HTTPException is None, reason="fastapi not installed") +@pytest.mark.asyncio +async def test_escalation_step1_fails_step2_blocks(): + """ + Pipeline: simple-filter (on_fail: next) -> advanced-filter (on_fail: block) + Input: request that fails simple-filter + Expected: simple-filter fails -> escalate -> advanced-filter fails -> block + """ + simple_guard = AlwaysFailGuardrail(guardrail_name="simple-filter") + advanced_guard = AlwaysFailGuardrail(guardrail_name="advanced-filter") + + pipeline = GuardrailPipeline( + mode="pre_call", + steps=[ + PipelineStep( + guardrail="simple-filter", on_fail="next", on_pass="allow" + ), + PipelineStep( + guardrail="advanced-filter", on_fail="block", on_pass="allow" + ), + ], + ) + + original_callbacks = litellm.callbacks.copy() + litellm.callbacks = [simple_guard, advanced_guard] + + try: + result = await PipelineExecutor.execute_steps( + steps=pipeline.steps, + mode=pipeline.mode, + data={"messages": [{"role": "user", "content": "bad content"}]}, + user_api_key_dict=MagicMock(), + call_type="completion", + policy_name="content-safety", + ) + + assert simple_guard.calls == 1 + assert advanced_guard.calls == 1 + assert result.terminal_action == "block" + assert len(result.step_results) == 2 + assert result.step_results[0].guardrail_name == "simple-filter" + assert result.step_results[0].outcome == "fail" + assert result.step_results[0].action_taken == "next" + assert result.step_results[1].guardrail_name == "advanced-filter" + assert result.step_results[1].outcome == "fail" + assert result.step_results[1].action_taken == "block" + finally: + litellm.callbacks = original_callbacks + + +@pytest.mark.skipif(HTTPException is None, reason="fastapi not installed") +@pytest.mark.asyncio +async def test_early_allow_step1_passes_step2_skipped(): + """ + Pipeline: simple-filter (on_pass: allow) -> advanced-filter + Input: clean request that passes simple-filter + Expected: simple-filter passes -> allow (advanced-filter never called) + """ + simple_guard = AlwaysPassGuardrail(guardrail_name="simple-filter") + advanced_guard = AlwaysFailGuardrail(guardrail_name="advanced-filter") + + pipeline = GuardrailPipeline( + mode="pre_call", + steps=[ + PipelineStep( + guardrail="simple-filter", on_fail="next", on_pass="allow" + ), + PipelineStep( + guardrail="advanced-filter", on_fail="block", on_pass="allow" + ), + ], + ) + + original_callbacks = litellm.callbacks.copy() + litellm.callbacks = [simple_guard, advanced_guard] + + try: + result = await PipelineExecutor.execute_steps( + steps=pipeline.steps, + mode=pipeline.mode, + data={"messages": [{"role": "user", "content": "clean content"}]}, + user_api_key_dict=MagicMock(), + call_type="completion", + policy_name="content-safety", + ) + + assert simple_guard.calls == 1 + assert advanced_guard.calls == 0 + assert result.terminal_action == "allow" + assert len(result.step_results) == 1 + assert result.step_results[0].outcome == "pass" + assert result.step_results[0].action_taken == "allow" + finally: + litellm.callbacks = original_callbacks + + +@pytest.mark.skipif(HTTPException is None, reason="fastapi not installed") +@pytest.mark.asyncio +async def test_escalation_step1_fails_step2_passes(): + """ + Pipeline: simple-filter (on_fail: next) -> advanced-filter (on_pass: allow) + Input: request that fails simple but passes advanced + Expected: simple-filter fails -> escalate -> advanced-filter passes -> allow + """ + simple_guard = AlwaysFailGuardrail(guardrail_name="simple-filter") + advanced_guard = AlwaysPassGuardrail(guardrail_name="advanced-filter") + + pipeline = GuardrailPipeline( + mode="pre_call", + steps=[ + PipelineStep( + guardrail="simple-filter", on_fail="next", on_pass="allow" + ), + PipelineStep( + guardrail="advanced-filter", on_fail="block", on_pass="allow" + ), + ], + ) + + original_callbacks = litellm.callbacks.copy() + litellm.callbacks = [simple_guard, advanced_guard] + + try: + result = await PipelineExecutor.execute_steps( + steps=pipeline.steps, + mode=pipeline.mode, + data={"messages": [{"role": "user", "content": "borderline content"}]}, + user_api_key_dict=MagicMock(), + call_type="completion", + policy_name="content-safety", + ) + + assert simple_guard.calls == 1 + assert advanced_guard.calls == 1 + assert result.terminal_action == "allow" + assert len(result.step_results) == 2 + assert result.step_results[0].outcome == "fail" + assert result.step_results[0].action_taken == "next" + assert result.step_results[1].outcome == "pass" + assert result.step_results[1].action_taken == "allow" + finally: + litellm.callbacks = original_callbacks + + +@pytest.mark.skipif(HTTPException is None, reason="fastapi not installed") +@pytest.mark.asyncio +async def test_data_forwarding_pii_masking(): + """ + Pipeline: pii-masker (pass_data: true, on_pass: next) -> content-check (on_pass: allow) + Input: "Hello John Smith" + Expected: pii-masker masks -> content-check receives "[REDACTED]" -> allow + """ + pii_guard = PiiMaskingGuardrail(guardrail_name="pii-masker") + content_guard = ContentCheckGuardrail(guardrail_name="content-check") + + pipeline = GuardrailPipeline( + mode="pre_call", + steps=[ + PipelineStep( + guardrail="pii-masker", + on_fail="block", + on_pass="next", + pass_data=True, + ), + PipelineStep( + guardrail="content-check", on_fail="block", on_pass="allow" + ), + ], + ) + + original_callbacks = litellm.callbacks.copy() + litellm.callbacks = [pii_guard, content_guard] + + try: + result = await PipelineExecutor.execute_steps( + steps=pipeline.steps, + mode=pipeline.mode, + data={ + "messages": [{"role": "user", "content": "Hello John Smith"}] + }, + user_api_key_dict=MagicMock(), + call_type="completion", + policy_name="pii-then-safety", + ) + + assert pii_guard.calls == 1 + assert content_guard.calls == 1 + assert content_guard.received_messages[0]["content"] == "Hello [REDACTED]" + assert result.terminal_action == "allow" + assert result.modified_data is not None + assert result.modified_data["messages"][0]["content"] == "Hello [REDACTED]" + finally: + litellm.callbacks = original_callbacks + + +@pytest.mark.asyncio +async def test_guardrail_not_found_uses_on_fail(): + """ + If a guardrail is not found, treat as error and use on_fail action. + """ + pipeline = GuardrailPipeline( + mode="pre_call", + steps=[ + PipelineStep( + guardrail="nonexistent-guard", + on_fail="block", + on_pass="allow", + ), + ], + ) + + original_callbacks = litellm.callbacks.copy() + litellm.callbacks = [] + + try: + result = await PipelineExecutor.execute_steps( + steps=pipeline.steps, + mode=pipeline.mode, + data={"messages": [{"role": "user", "content": "test"}]}, + user_api_key_dict=MagicMock(), + call_type="completion", + policy_name="test-policy", + ) + + assert result.terminal_action == "block" + assert result.step_results[0].outcome == "error" + assert "not found" in result.step_results[0].error_detail + finally: + litellm.callbacks = original_callbacks + + +@pytest.mark.asyncio +async def test_guardrail_not_found_with_next_continues(): + """ + If a guardrail is not found and on_fail is 'next', continue to next step. + """ + pass_guard = AlwaysPassGuardrail(guardrail_name="fallback-guard") + + pipeline = GuardrailPipeline( + mode="pre_call", + steps=[ + PipelineStep( + guardrail="nonexistent-guard", + on_fail="next", + on_pass="allow", + ), + PipelineStep( + guardrail="fallback-guard", + on_fail="block", + on_pass="allow", + ), + ], + ) + + original_callbacks = litellm.callbacks.copy() + litellm.callbacks = [pass_guard] + + try: + result = await PipelineExecutor.execute_steps( + steps=pipeline.steps, + mode=pipeline.mode, + data={"messages": [{"role": "user", "content": "test"}]}, + user_api_key_dict=MagicMock(), + call_type="completion", + policy_name="test-policy", + ) + + assert result.terminal_action == "allow" + assert len(result.step_results) == 2 + assert result.step_results[0].outcome == "error" + assert result.step_results[0].action_taken == "next" + assert result.step_results[1].outcome == "pass" + assert pass_guard.calls == 1 + finally: + litellm.callbacks = original_callbacks + + +@pytest.mark.skipif(HTTPException is None, reason="fastapi not installed") +@pytest.mark.asyncio +async def test_single_step_pipeline_block(): + """Single step pipeline that blocks.""" + guard = AlwaysFailGuardrail(guardrail_name="blocker") + + pipeline = GuardrailPipeline( + mode="pre_call", + steps=[PipelineStep(guardrail="blocker", on_fail="block")], + ) + + original_callbacks = litellm.callbacks.copy() + litellm.callbacks = [guard] + + try: + result = await PipelineExecutor.execute_steps( + steps=pipeline.steps, + mode=pipeline.mode, + data={"messages": [{"role": "user", "content": "test"}]}, + user_api_key_dict=MagicMock(), + call_type="completion", + policy_name="test", + ) + + assert result.terminal_action == "block" + assert guard.calls == 1 + finally: + litellm.callbacks = original_callbacks + + +@pytest.mark.asyncio +async def test_single_step_pipeline_allow(): + """Single step pipeline that allows.""" + guard = AlwaysPassGuardrail(guardrail_name="passer") + + pipeline = GuardrailPipeline( + mode="pre_call", + steps=[PipelineStep(guardrail="passer", on_pass="allow")], + ) + + original_callbacks = litellm.callbacks.copy() + litellm.callbacks = [guard] + + try: + result = await PipelineExecutor.execute_steps( + steps=pipeline.steps, + mode=pipeline.mode, + data={"messages": [{"role": "user", "content": "test"}]}, + user_api_key_dict=MagicMock(), + call_type="completion", + policy_name="test", + ) + + assert result.terminal_action == "allow" + assert guard.calls == 1 + finally: + litellm.callbacks = original_callbacks + + +@pytest.mark.asyncio +async def test_step_results_include_duration(): + """Step results should include timing information.""" + guard = AlwaysPassGuardrail(guardrail_name="timed") + + pipeline = GuardrailPipeline( + mode="pre_call", + steps=[PipelineStep(guardrail="timed")], + ) + + original_callbacks = litellm.callbacks.copy() + litellm.callbacks = [guard] + + try: + result = await PipelineExecutor.execute_steps( + steps=pipeline.steps, + mode=pipeline.mode, + data={"messages": [{"role": "user", "content": "test"}]}, + user_api_key_dict=MagicMock(), + call_type="completion", + policy_name="test", + ) + + assert result.step_results[0].duration_seconds is not None + assert result.step_results[0].duration_seconds >= 0 + finally: + litellm.callbacks = original_callbacks diff --git a/tests/test_litellm/proxy/test_pyroscope.py b/tests/test_litellm/proxy/test_pyroscope.py new file mode 100644 index 00000000000..6bfdf81ec1a --- /dev/null +++ b/tests/test_litellm/proxy/test_pyroscope.py @@ -0,0 +1,138 @@ +"""Unit tests for ProxyStartupEvent._init_pyroscope (Grafana Pyroscope profiling).""" + +import os +import sys +from unittest.mock import MagicMock, patch + +import pytest + +from litellm.proxy.proxy_server import ProxyStartupEvent + + +def _mock_pyroscope_module(): + """Return a mock module so 'import pyroscope' succeeds in _init_pyroscope.""" + m = MagicMock() + m.configure = MagicMock() + return m + + +def test_init_pyroscope_returns_cleanly_when_disabled(): + """When LITELLM_ENABLE_PYROSCOPE is false, _init_pyroscope returns without error.""" + with patch( + "litellm.proxy.proxy_server.get_secret_bool", + return_value=False, + ): + ProxyStartupEvent._init_pyroscope() + + +def test_init_pyroscope_raises_when_enabled_but_missing_app_name(): + """When LITELLM_ENABLE_PYROSCOPE is true but PYROSCOPE_APP_NAME is not set, raises ValueError.""" + mock_pyroscope = _mock_pyroscope_module() + with patch( + "litellm.proxy.proxy_server.get_secret_bool", + return_value=True, + ), patch.dict( + sys.modules, + {"pyroscope": mock_pyroscope}, + ), patch.dict( + os.environ, + { + "PYROSCOPE_APP_NAME": "", + "PYROSCOPE_SERVER_ADDRESS": "http://localhost:4040", + }, + clear=False, + ): + with pytest.raises(ValueError, match="PYROSCOPE_APP_NAME"): + ProxyStartupEvent._init_pyroscope() + + +def test_init_pyroscope_raises_when_enabled_but_missing_server_address(): + """When LITELLM_ENABLE_PYROSCOPE is true but PYROSCOPE_SERVER_ADDRESS is not set, raises ValueError.""" + mock_pyroscope = _mock_pyroscope_module() + with patch( + "litellm.proxy.proxy_server.get_secret_bool", + return_value=True, + ), patch.dict( + sys.modules, + {"pyroscope": mock_pyroscope}, + ), patch.dict( + os.environ, + { + "PYROSCOPE_APP_NAME": "myapp", + "PYROSCOPE_SERVER_ADDRESS": "", + }, + clear=False, + ): + with pytest.raises(ValueError, match="PYROSCOPE_SERVER_ADDRESS"): + ProxyStartupEvent._init_pyroscope() + + +def test_init_pyroscope_raises_when_sample_rate_invalid(): + """When PYROSCOPE_SAMPLE_RATE is not a number, raises ValueError.""" + mock_pyroscope = _mock_pyroscope_module() + with patch( + "litellm.proxy.proxy_server.get_secret_bool", + return_value=True, + ), patch.dict( + sys.modules, + {"pyroscope": mock_pyroscope}, + ), patch.dict( + os.environ, + { + "PYROSCOPE_APP_NAME": "myapp", + "PYROSCOPE_SERVER_ADDRESS": "http://localhost:4040", + "PYROSCOPE_SAMPLE_RATE": "not-a-number", + }, + clear=False, + ): + with pytest.raises(ValueError, match="PYROSCOPE_SAMPLE_RATE"): + ProxyStartupEvent._init_pyroscope() + + +def test_init_pyroscope_accepts_integer_sample_rate(): + """When enabled with valid config and integer sample rate, configures pyroscope.""" + mock_pyroscope = _mock_pyroscope_module() + with patch( + "litellm.proxy.proxy_server.get_secret_bool", + return_value=True, + ), patch.dict( + sys.modules, + {"pyroscope": mock_pyroscope}, + ), patch.dict( + os.environ, + { + "PYROSCOPE_APP_NAME": "myapp", + "PYROSCOPE_SERVER_ADDRESS": "http://localhost:4040", + "PYROSCOPE_SAMPLE_RATE": "100", + }, + clear=False, + ): + ProxyStartupEvent._init_pyroscope() + mock_pyroscope.configure.assert_called_once() + call_kw = mock_pyroscope.configure.call_args[1] + assert call_kw["app_name"] == "myapp" + assert call_kw["server_address"] == "http://localhost:4040" + assert call_kw["sample_rate"] == 100 + + +def test_init_pyroscope_accepts_float_sample_rate_parsed_as_int(): + """PYROSCOPE_SAMPLE_RATE can be a float string; it is parsed as integer.""" + mock_pyroscope = _mock_pyroscope_module() + with patch( + "litellm.proxy.proxy_server.get_secret_bool", + return_value=True, + ), patch.dict( + sys.modules, + {"pyroscope": mock_pyroscope}, + ), patch.dict( + os.environ, + { + "PYROSCOPE_APP_NAME": "myapp", + "PYROSCOPE_SERVER_ADDRESS": "http://localhost:4040", + "PYROSCOPE_SAMPLE_RATE": "100.7", + }, + clear=False, + ): + ProxyStartupEvent._init_pyroscope() + call_kw = mock_pyroscope.configure.call_args[1] + assert call_kw["sample_rate"] == 100 diff --git a/tests/test_litellm/types/__init__.py b/tests/test_litellm/types/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/test_litellm/types/proxy/__init__.py b/tests/test_litellm/types/proxy/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/test_litellm/types/proxy/policy_engine/__init__.py b/tests/test_litellm/types/proxy/policy_engine/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/test_litellm/types/proxy/policy_engine/test_pipeline_types.py b/tests/test_litellm/types/proxy/policy_engine/test_pipeline_types.py new file mode 100644 index 00000000000..21fecc015a3 --- /dev/null +++ b/tests/test_litellm/types/proxy/policy_engine/test_pipeline_types.py @@ -0,0 +1,152 @@ +""" +Tests for pipeline type definitions. +""" + +import pytest +from pydantic import ValidationError + +from litellm.types.proxy.policy_engine.pipeline_types import ( + GuardrailPipeline, + PipelineExecutionResult, + PipelineStep, + PipelineStepResult, +) +from litellm.types.proxy.policy_engine.policy_types import ( + Policy, + PolicyGuardrails, +) + + +def test_pipeline_step_defaults(): + step = PipelineStep(guardrail="my-guard") + assert step.on_fail == "block" + assert step.on_pass == "allow" + assert step.pass_data is False + assert step.modify_response_message is None + + +def test_pipeline_step_valid_actions(): + step = PipelineStep(guardrail="my-guard", on_fail="next", on_pass="next") + assert step.on_fail == "next" + assert step.on_pass == "next" + + +def test_pipeline_step_all_action_types(): + for action in ("allow", "block", "next", "modify_response"): + step = PipelineStep(guardrail="g", on_fail=action, on_pass=action) + assert step.on_fail == action + assert step.on_pass == action + + +def test_pipeline_step_invalid_action_rejected(): + with pytest.raises(ValidationError): + PipelineStep(guardrail="my-guard", on_fail="invalid_action") + + +def test_pipeline_step_invalid_on_pass_rejected(): + with pytest.raises(ValidationError): + PipelineStep(guardrail="my-guard", on_pass="skip") + + +def test_pipeline_requires_at_least_one_step(): + with pytest.raises(ValidationError): + GuardrailPipeline(mode="pre_call", steps=[]) + + +def test_pipeline_invalid_mode_rejected(): + with pytest.raises(ValidationError): + GuardrailPipeline( + mode="during_call", + steps=[PipelineStep(guardrail="g")], + ) + + +def test_pipeline_valid_modes(): + for mode in ("pre_call", "post_call"): + pipeline = GuardrailPipeline( + mode=mode, + steps=[PipelineStep(guardrail="g")], + ) + assert pipeline.mode == mode + + +def test_pipeline_with_multiple_steps(): + pipeline = GuardrailPipeline( + mode="pre_call", + steps=[ + PipelineStep(guardrail="g1", on_fail="next", on_pass="allow"), + PipelineStep(guardrail="g2", on_fail="block", on_pass="allow"), + ], + ) + assert len(pipeline.steps) == 2 + assert pipeline.steps[0].guardrail == "g1" + assert pipeline.steps[1].guardrail == "g2" + + +def test_policy_with_pipeline_parses(): + policy = Policy( + guardrails=PolicyGuardrails(add=["g1", "g2"]), + pipeline=GuardrailPipeline( + mode="pre_call", + steps=[ + PipelineStep(guardrail="g1", on_fail="next"), + PipelineStep(guardrail="g2"), + ], + ), + ) + assert policy.pipeline is not None + assert len(policy.pipeline.steps) == 2 + + +def test_policy_without_pipeline(): + policy = Policy( + guardrails=PolicyGuardrails(add=["g1"]), + ) + assert policy.pipeline is None + + +def test_pipeline_step_result(): + result = PipelineStepResult( + guardrail_name="g1", + outcome="fail", + action_taken="next", + error_detail="Content policy violation", + duration_seconds=0.05, + ) + assert result.outcome == "fail" + assert result.action_taken == "next" + + +def test_pipeline_execution_result(): + result = PipelineExecutionResult( + terminal_action="block", + step_results=[ + PipelineStepResult( + guardrail_name="g1", + outcome="fail", + action_taken="next", + ), + PipelineStepResult( + guardrail_name="g2", + outcome="fail", + action_taken="block", + ), + ], + error_message="Content blocked", + ) + assert result.terminal_action == "block" + assert len(result.step_results) == 2 + + +def test_pipeline_step_extra_fields_rejected(): + with pytest.raises(ValidationError): + PipelineStep(guardrail="g", unknown_field="value") + + +def test_pipeline_extra_fields_rejected(): + with pytest.raises(ValidationError): + GuardrailPipeline( + mode="pre_call", + steps=[PipelineStep(guardrail="g")], + unknown="value", + ) diff --git a/tests/test_litellm/types/proxy/policy_engine/test_resolver_types.py b/tests/test_litellm/types/proxy/policy_engine/test_resolver_types.py new file mode 100644 index 00000000000..c23ed5d4319 --- /dev/null +++ b/tests/test_litellm/types/proxy/policy_engine/test_resolver_types.py @@ -0,0 +1,102 @@ +""" +Tests for pipeline field on policy CRUD types (resolver_types.py). +""" + +import pytest + +from litellm.types.proxy.policy_engine.resolver_types import ( + PolicyCreateRequest, + PolicyDBResponse, + PolicyUpdateRequest, +) + + +def test_policy_create_request_with_pipeline(): + pipeline_data = { + "mode": "pre_call", + "steps": [ + {"guardrail": "g1", "on_fail": "next", "on_pass": "allow"}, + {"guardrail": "g2", "on_fail": "block", "on_pass": "allow"}, + ], + } + req = PolicyCreateRequest( + policy_name="test-policy", + guardrails_add=["g1", "g2"], + pipeline=pipeline_data, + ) + assert req.pipeline is not None + assert req.pipeline["mode"] == "pre_call" + assert len(req.pipeline["steps"]) == 2 + + +def test_policy_create_request_without_pipeline(): + req = PolicyCreateRequest( + policy_name="test-policy", + guardrails_add=["g1"], + ) + assert req.pipeline is None + + +def test_policy_update_request_with_pipeline(): + pipeline_data = { + "mode": "pre_call", + "steps": [ + {"guardrail": "g1", "on_fail": "block", "on_pass": "allow"}, + ], + } + req = PolicyUpdateRequest(pipeline=pipeline_data) + assert req.pipeline is not None + assert req.pipeline["steps"][0]["guardrail"] == "g1" + + +def test_policy_db_response_with_pipeline(): + pipeline_data = { + "mode": "pre_call", + "steps": [ + {"guardrail": "g1", "on_fail": "next", "on_pass": "allow"}, + {"guardrail": "g2", "on_fail": "block", "on_pass": "allow"}, + ], + } + resp = PolicyDBResponse( + policy_id="test-id", + policy_name="test-policy", + guardrails_add=["g1", "g2"], + pipeline=pipeline_data, + ) + assert resp.pipeline is not None + assert resp.pipeline["mode"] == "pre_call" + dumped = resp.model_dump() + assert dumped["pipeline"]["steps"][0]["guardrail"] == "g1" + + +def test_policy_db_response_without_pipeline(): + resp = PolicyDBResponse( + policy_id="test-id", + policy_name="test-policy", + ) + assert resp.pipeline is None + dumped = resp.model_dump() + assert dumped["pipeline"] is None + + +def test_policy_create_request_roundtrip(): + pipeline_data = { + "mode": "post_call", + "steps": [ + { + "guardrail": "g1", + "on_fail": "modify_response", + "on_pass": "next", + "pass_data": True, + "modify_response_message": "custom msg", + }, + ], + } + req = PolicyCreateRequest( + policy_name="roundtrip-test", + guardrails_add=["g1"], + pipeline=pipeline_data, + ) + dumped = req.model_dump() + restored = PolicyCreateRequest(**dumped) + assert restored.pipeline == pipeline_data diff --git a/ui/litellm-dashboard/e2e_tests/globalSetup.ts b/ui/litellm-dashboard/e2e_tests/globalSetup.ts index a725c58f35b..e37d0bd718b 100644 --- a/ui/litellm-dashboard/e2e_tests/globalSetup.ts +++ b/ui/litellm-dashboard/e2e_tests/globalSetup.ts @@ -8,7 +8,7 @@ async function globalSetup() { await page.goto("http://localhost:4000/ui/login"); await page.getByPlaceholder("Enter your username").fill(users[Role.ProxyAdmin].email); await page.getByPlaceholder("Enter your password").fill(users[Role.ProxyAdmin].password); - const loginButton = page.getByRole("button", { name: "Login" }); + const loginButton = page.getByRole("button", { name: "Login", exact: true }); await loginButton.click(); await page.waitForSelector("text=AI Gateway"); await page.context().storageState({ path: "admin.storageState.json" }); diff --git a/ui/litellm-dashboard/e2e_tests/tests/login/login.spec.ts b/ui/litellm-dashboard/e2e_tests/tests/login/login.spec.ts index 5ac977ff0c8..1d445944712 100644 --- a/ui/litellm-dashboard/e2e_tests/tests/login/login.spec.ts +++ b/ui/litellm-dashboard/e2e_tests/tests/login/login.spec.ts @@ -6,7 +6,7 @@ test("user can log in", async ({ page }) => { await page.goto("http://localhost:4000/ui/login"); await page.getByPlaceholder("Enter your username").fill(users[Role.ProxyAdmin].email); await page.getByPlaceholder("Enter your password").fill(users[Role.ProxyAdmin].password); - const loginButton = page.getByRole("button", { name: "Login" }); + const loginButton = page.getByRole("button", { name: "Login", exact: true }); await expect(loginButton).toBeEnabled(); await loginButton.click(); await expect(page.getByText("AI Gateway")).toBeVisible(); diff --git a/ui/litellm-dashboard/src/app/(dashboard)/hooks/accessGroups/useAccessGroupDetails.ts b/ui/litellm-dashboard/src/app/(dashboard)/hooks/accessGroups/useAccessGroupDetails.ts new file mode 100644 index 00000000000..c0379b25321 --- /dev/null +++ b/ui/litellm-dashboard/src/app/(dashboard)/hooks/accessGroups/useAccessGroupDetails.ts @@ -0,0 +1,63 @@ +import { useQuery, useQueryClient } from "@tanstack/react-query"; +import { + getProxyBaseUrl, + getGlobalLitellmHeaderName, + deriveErrorMessage, + handleError, +} from "@/components/networking"; +import { all_admin_roles } from "@/utils/roles"; +import useAuthorized from "@/app/(dashboard)/hooks/useAuthorized"; +import { AccessGroupResponse, accessGroupKeys } from "./useAccessGroups"; + +// ── Fetch function ─────────────────────────────────────────────────────────── + +const fetchAccessGroupDetails = async ( + accessToken: string, + accessGroupId: string, +): Promise => { + const baseUrl = getProxyBaseUrl(); + const url = `${baseUrl}/v1/access_group/${encodeURIComponent(accessGroupId)}`; + + const response = await fetch(url, { + method: "GET", + headers: { + [getGlobalLitellmHeaderName()]: `Bearer ${accessToken}`, + "Content-Type": "application/json", + }, + }); + + if (!response.ok) { + const errorData = await response.json(); + const errorMessage = deriveErrorMessage(errorData); + handleError(errorMessage); + throw new Error(errorMessage); + } + + return response.json(); +}; + +// ── Hook ───────────────────────────────────────────────────────────────────── + +export const useAccessGroupDetails = (accessGroupId?: string) => { + const { accessToken, userRole } = useAuthorized(); + const queryClient = useQueryClient(); + + return useQuery({ + queryKey: accessGroupKeys.detail(accessGroupId!), + queryFn: async () => fetchAccessGroupDetails(accessToken!, accessGroupId!), + enabled: + Boolean(accessToken && accessGroupId) && + all_admin_roles.includes(userRole || ""), + + // Seed from the list cache when available + initialData: () => { + if (!accessGroupId) return undefined; + + const groups = queryClient.getQueryData( + accessGroupKeys.list({}), + ); + + return groups?.find((g) => g.access_group_id === accessGroupId); + }, + }); +}; diff --git a/ui/litellm-dashboard/src/app/(dashboard)/hooks/accessGroups/useAccessGroups.test.ts b/ui/litellm-dashboard/src/app/(dashboard)/hooks/accessGroups/useAccessGroups.test.ts new file mode 100644 index 00000000000..587064353ac --- /dev/null +++ b/ui/litellm-dashboard/src/app/(dashboard)/hooks/accessGroups/useAccessGroups.test.ts @@ -0,0 +1,242 @@ +/* @vitest-environment jsdom */ +import React from "react"; +import { renderHook, waitFor } from "@testing-library/react"; +import { describe, it, expect, vi, beforeEach } from "vitest"; +import { QueryClient, QueryClientProvider } from "@tanstack/react-query"; +import { useAccessGroups, AccessGroupResponse } from "./useAccessGroups"; +import * as networking from "@/components/networking"; + +vi.mock("@/components/networking", () => ({ + getProxyBaseUrl: vi.fn(() => "http://proxy.example"), + getGlobalLitellmHeaderName: vi.fn(() => "Authorization"), + deriveErrorMessage: vi.fn((data: unknown) => (data as { detail?: string })?.detail ?? "Unknown error"), + handleError: vi.fn(), +})); + +vi.mock("@/app/(dashboard)/hooks/useAuthorized", () => ({ + default: vi.fn(() => ({ + accessToken: "test-token-123", + userRole: "Admin", + })), +})); + +const createQueryClient = () => + new QueryClient({ + defaultOptions: { + queries: { + retry: false, + gcTime: 0, + }, + }, + }); + +const wrapper = ({ children }: { children: React.ReactNode }) => { + const queryClient = createQueryClient(); + return React.createElement(QueryClientProvider, { client: queryClient }, children); +}; + +const mockAccessToken = "test-token-123"; +const mockAccessGroups: AccessGroupResponse[] = [ + { + access_group_id: "ag-1", + access_group_name: "Group One", + description: "First group", + access_model_ids: [], + access_mcp_server_ids: [], + access_agent_ids: [], + assigned_team_ids: [], + assigned_key_ids: [], + created_at: "2025-01-01T00:00:00Z", + created_by: "user-1", + updated_at: "2025-01-01T00:00:00Z", + updated_by: "user-1", + }, +]; + +const fetchMock = vi.fn(); + +describe("useAccessGroups", () => { + beforeEach(async () => { + vi.clearAllMocks(); + vi.mocked(networking.getProxyBaseUrl).mockReturnValue("http://proxy.example"); + vi.mocked(networking.getGlobalLitellmHeaderName).mockReturnValue("Authorization"); + + const useAuthorizedModule = await import("@/app/(dashboard)/hooks/useAuthorized"); + vi.mocked(useAuthorizedModule.default).mockReturnValue({ + accessToken: mockAccessToken, + userRole: "Admin", + } as any); + + global.fetch = fetchMock; + }); + + it("should return hook result without errors", () => { + fetchMock.mockResolvedValue({ + ok: true, + json: () => Promise.resolve([]), + } as Response); + + const { result } = renderHook(() => useAccessGroups(), { wrapper }); + + expect(result.current).toBeDefined(); + expect(result.current).toHaveProperty("data"); + expect(result.current).toHaveProperty("isSuccess"); + expect(result.current).toHaveProperty("isError"); + expect(result.current).toHaveProperty("status"); + }); + + it("should return access groups when access token and admin role are present", async () => { + fetchMock.mockResolvedValue({ + ok: true, + json: () => Promise.resolve(mockAccessGroups), + } as Response); + + const { result } = renderHook(() => useAccessGroups(), { wrapper }); + + await waitFor(() => { + expect(result.current.isSuccess).toBe(true); + }); + + expect(fetchMock).toHaveBeenCalledWith( + "http://proxy.example/v1/access_group", + expect.objectContaining({ + method: "GET", + headers: expect.objectContaining({ + Authorization: `Bearer ${mockAccessToken}`, + "Content-Type": "application/json", + }), + }), + ); + expect(result.current.data).toEqual(mockAccessGroups); + }); + + it("should not fetch when access token is null", async () => { + const useAuthorizedModule = await import("@/app/(dashboard)/hooks/useAuthorized"); + vi.mocked(useAuthorizedModule.default).mockReturnValue({ + accessToken: null, + userRole: "Admin", + } as any); + + const { result } = renderHook(() => useAccessGroups(), { wrapper }); + + expect(result.current.isFetching).toBe(false); + expect(result.current.isLoading).toBe(false); + expect(result.current.data).toBeUndefined(); + expect(fetchMock).not.toHaveBeenCalled(); + }); + + it("should not fetch when access token is empty string", async () => { + const useAuthorizedModule = await import("@/app/(dashboard)/hooks/useAuthorized"); + vi.mocked(useAuthorizedModule.default).mockReturnValue({ + accessToken: "", + userRole: "Admin", + } as any); + + const { result } = renderHook(() => useAccessGroups(), { wrapper }); + + expect(result.current.isFetching).toBe(false); + expect(result.current.isLoading).toBe(false); + expect(result.current.data).toBeUndefined(); + expect(fetchMock).not.toHaveBeenCalled(); + }); + + it("should not fetch when user role is not an admin role", async () => { + const useAuthorizedModule = await import("@/app/(dashboard)/hooks/useAuthorized"); + vi.mocked(useAuthorizedModule.default).mockReturnValue({ + accessToken: mockAccessToken, + userRole: "Viewer", + } as any); + + const { result } = renderHook(() => useAccessGroups(), { wrapper }); + + expect(result.current.isFetching).toBe(false); + expect(result.current.isLoading).toBe(false); + expect(result.current.data).toBeUndefined(); + expect(fetchMock).not.toHaveBeenCalled(); + }); + + it("should not fetch when user role is null", async () => { + const useAuthorizedModule = await import("@/app/(dashboard)/hooks/useAuthorized"); + vi.mocked(useAuthorizedModule.default).mockReturnValue({ + accessToken: mockAccessToken, + userRole: null, + } as any); + + const { result } = renderHook(() => useAccessGroups(), { wrapper }); + + expect(result.current.isFetching).toBe(false); + expect(result.current.isLoading).toBe(false); + expect(result.current.data).toBeUndefined(); + expect(fetchMock).not.toHaveBeenCalled(); + }); + + it("should fetch when user role is proxy_admin", async () => { + const useAuthorizedModule = await import("@/app/(dashboard)/hooks/useAuthorized"); + vi.mocked(useAuthorizedModule.default).mockReturnValue({ + accessToken: mockAccessToken, + userRole: "proxy_admin", + } as any); + + fetchMock.mockResolvedValue({ + ok: true, + json: () => Promise.resolve(mockAccessGroups), + } as Response); + + const { result } = renderHook(() => useAccessGroups(), { wrapper }); + + await waitFor(() => { + expect(result.current.isSuccess).toBe(true); + }); + + expect(fetchMock).toHaveBeenCalled(); + expect(result.current.data).toEqual(mockAccessGroups); + }); + + it("should expose error state when fetch fails", async () => { + fetchMock.mockResolvedValue({ + ok: false, + json: () => Promise.resolve({ detail: "Forbidden" }), + } as Response); + vi.mocked(networking.deriveErrorMessage).mockReturnValue("Forbidden"); + + const { result } = renderHook(() => useAccessGroups(), { wrapper }); + + await waitFor(() => { + expect(result.current.isError).toBe(true); + }); + + expect(result.current.error).toBeInstanceOf(Error); + expect((result.current.error as Error).message).toBe("Forbidden"); + expect(result.current.data).toBeUndefined(); + expect(networking.handleError).toHaveBeenCalledWith("Forbidden"); + }); + + it("should return empty array when API returns empty list", async () => { + fetchMock.mockResolvedValue({ + ok: true, + json: () => Promise.resolve([]), + } as Response); + + const { result } = renderHook(() => useAccessGroups(), { wrapper }); + + await waitFor(() => { + expect(result.current.isSuccess).toBe(true); + }); + + expect(result.current.data).toEqual([]); + }); + + it("should propagate network errors", async () => { + const networkError = new Error("Network failure"); + fetchMock.mockRejectedValue(networkError); + + const { result } = renderHook(() => useAccessGroups(), { wrapper }); + + await waitFor(() => { + expect(result.current.isError).toBe(true); + }); + + expect(result.current.error).toEqual(networkError); + expect(result.current.data).toBeUndefined(); + }); +}); diff --git a/ui/litellm-dashboard/src/app/(dashboard)/hooks/accessGroups/useAccessGroups.ts b/ui/litellm-dashboard/src/app/(dashboard)/hooks/accessGroups/useAccessGroups.ts new file mode 100644 index 00000000000..e5d8829278d --- /dev/null +++ b/ui/litellm-dashboard/src/app/(dashboard)/hooks/accessGroups/useAccessGroups.ts @@ -0,0 +1,70 @@ +import { useQuery } from "@tanstack/react-query"; +import { createQueryKeys } from "../common/queryKeysFactory"; +import { + getProxyBaseUrl, + getGlobalLitellmHeaderName, + deriveErrorMessage, + handleError, +} from "@/components/networking"; +import { all_admin_roles } from "@/utils/roles"; +import useAuthorized from "@/app/(dashboard)/hooks/useAuthorized"; + +// ── Types ──────────────────────────────────────────────────────────────────── + +export interface AccessGroupResponse { + access_group_id: string; + access_group_name: string; + description: string | null; + access_model_ids: string[]; + access_mcp_server_ids: string[]; + access_agent_ids: string[]; + assigned_team_ids: string[]; + assigned_key_ids: string[]; + created_at: string; + created_by: string | null; + updated_at: string; + updated_by: string | null; +} + +// ── Query keys (shared across access-group hooks) ──────────────────────────── + +export const accessGroupKeys = createQueryKeys("accessGroups"); + +// ── Fetch function ─────────────────────────────────────────────────────────── + +const fetchAccessGroups = async ( + accessToken: string, +): Promise => { + const baseUrl = getProxyBaseUrl(); + const url = `${baseUrl}/v1/access_group`; + + const response = await fetch(url, { + method: "GET", + headers: { + [getGlobalLitellmHeaderName()]: `Bearer ${accessToken}`, + "Content-Type": "application/json", + }, + }); + + if (!response.ok) { + const errorData = await response.json(); + const errorMessage = deriveErrorMessage(errorData); + handleError(errorMessage); + throw new Error(errorMessage); + } + + return response.json(); +}; + +// ── Hook ───────────────────────────────────────────────────────────────────── + +export const useAccessGroups = () => { + const { accessToken, userRole } = useAuthorized(); + + return useQuery({ + queryKey: accessGroupKeys.list({}), + queryFn: async () => fetchAccessGroups(accessToken!), + enabled: + Boolean(accessToken) && all_admin_roles.includes(userRole || ""), + }); +}; diff --git a/ui/litellm-dashboard/src/app/(dashboard)/hooks/accessGroups/useCreateAccessGroup.ts b/ui/litellm-dashboard/src/app/(dashboard)/hooks/accessGroups/useCreateAccessGroup.ts new file mode 100644 index 00000000000..4d71be94455 --- /dev/null +++ b/ui/litellm-dashboard/src/app/(dashboard)/hooks/accessGroups/useCreateAccessGroup.ts @@ -0,0 +1,68 @@ +import { useMutation, useQueryClient } from "@tanstack/react-query"; +import { + getProxyBaseUrl, + getGlobalLitellmHeaderName, + deriveErrorMessage, + handleError, +} from "@/components/networking"; +import useAuthorized from "@/app/(dashboard)/hooks/useAuthorized"; +import { AccessGroupResponse, accessGroupKeys } from "./useAccessGroups"; + +// ── Types ──────────────────────────────────────────────────────────────────── + +export interface AccessGroupCreateParams { + access_group_name: string; + description?: string | null; + access_model_ids?: string[]; + access_mcp_server_ids?: string[]; + access_agent_ids?: string[]; + assigned_team_ids?: string[]; + assigned_key_ids?: string[]; +} + +// ── Fetch function ─────────────────────────────────────────────────────────── + +const createAccessGroup = async ( + accessToken: string, + params: AccessGroupCreateParams, +): Promise => { + const baseUrl = getProxyBaseUrl(); + const url = `${baseUrl}/v1/access_group`; + + const response = await fetch(url, { + method: "POST", + headers: { + [getGlobalLitellmHeaderName()]: `Bearer ${accessToken}`, + "Content-Type": "application/json", + }, + body: JSON.stringify(params), + }); + + if (!response.ok) { + const errorData = await response.json(); + const errorMessage = deriveErrorMessage(errorData); + handleError(errorMessage); + throw new Error(errorMessage); + } + + return response.json(); +}; + +// ── Hook ───────────────────────────────────────────────────────────────────── + +export const useCreateAccessGroup = () => { + const { accessToken } = useAuthorized(); + const queryClient = useQueryClient(); + + return useMutation({ + mutationFn: async (params) => { + if (!accessToken) { + throw new Error("Access token is required"); + } + return createAccessGroup(accessToken, params); + }, + onSuccess: () => { + queryClient.invalidateQueries({ queryKey: accessGroupKeys.all }); + }, + }); +}; diff --git a/ui/litellm-dashboard/src/app/(dashboard)/hooks/accessGroups/useDeleteAccessGroup.ts b/ui/litellm-dashboard/src/app/(dashboard)/hooks/accessGroups/useDeleteAccessGroup.ts new file mode 100644 index 00000000000..5df5960ce0a --- /dev/null +++ b/ui/litellm-dashboard/src/app/(dashboard)/hooks/accessGroups/useDeleteAccessGroup.ts @@ -0,0 +1,55 @@ +import { useMutation, useQueryClient } from "@tanstack/react-query"; +import { + getProxyBaseUrl, + getGlobalLitellmHeaderName, + deriveErrorMessage, + handleError, +} from "@/components/networking"; +import useAuthorized from "@/app/(dashboard)/hooks/useAuthorized"; +import { accessGroupKeys } from "./useAccessGroups"; + +// ── Fetch function ─────────────────────────────────────────────────────────── + +const deleteAccessGroup = async ( + accessToken: string, + accessGroupId: string, +): Promise => { + const baseUrl = getProxyBaseUrl(); + const url = `${baseUrl}/v1/access_group/${encodeURIComponent(accessGroupId)}`; + + const response = await fetch(url, { + method: "DELETE", + headers: { + [getGlobalLitellmHeaderName()]: `Bearer ${accessToken}`, + "Content-Type": "application/json", + }, + }); + + if (!response.ok) { + const errorData = await response.json(); + const errorMessage = deriveErrorMessage(errorData); + handleError(errorMessage); + throw new Error(errorMessage); + } + + // 204 No Content — nothing to parse +}; + +// ── Hook ───────────────────────────────────────────────────────────────────── + +export const useDeleteAccessGroup = () => { + const { accessToken } = useAuthorized(); + const queryClient = useQueryClient(); + + return useMutation({ + mutationFn: async (accessGroupId) => { + if (!accessToken) { + throw new Error("Access token is required"); + } + return deleteAccessGroup(accessToken, accessGroupId); + }, + onSuccess: () => { + queryClient.invalidateQueries({ queryKey: accessGroupKeys.all }); + }, + }); +}; diff --git a/ui/litellm-dashboard/src/app/(dashboard)/hooks/accessGroups/useEditAccessGroup.ts b/ui/litellm-dashboard/src/app/(dashboard)/hooks/accessGroups/useEditAccessGroup.ts new file mode 100644 index 00000000000..1646458c63d --- /dev/null +++ b/ui/litellm-dashboard/src/app/(dashboard)/hooks/accessGroups/useEditAccessGroup.ts @@ -0,0 +1,77 @@ +import { useMutation, useQueryClient } from "@tanstack/react-query"; +import { + getProxyBaseUrl, + getGlobalLitellmHeaderName, + deriveErrorMessage, + handleError, +} from "@/components/networking"; +import useAuthorized from "@/app/(dashboard)/hooks/useAuthorized"; +import { AccessGroupResponse, accessGroupKeys } from "./useAccessGroups"; + +// ── Types ──────────────────────────────────────────────────────────────────── + +export interface AccessGroupUpdateParams { + access_group_name?: string; + description?: string | null; + access_model_ids?: string[]; + access_mcp_server_ids?: string[]; + access_agent_ids?: string[]; + assigned_team_ids?: string[]; + assigned_key_ids?: string[]; +} + +export interface EditAccessGroupVariables { + accessGroupId: string; + params: AccessGroupUpdateParams; +} + +// ── Fetch function ─────────────────────────────────────────────────────────── + +const updateAccessGroup = async ( + accessToken: string, + accessGroupId: string, + params: AccessGroupUpdateParams, +): Promise => { + const baseUrl = getProxyBaseUrl(); + const url = `${baseUrl}/v1/access_group/${encodeURIComponent(accessGroupId)}`; + + const response = await fetch(url, { + method: "PUT", + headers: { + [getGlobalLitellmHeaderName()]: `Bearer ${accessToken}`, + "Content-Type": "application/json", + }, + body: JSON.stringify(params), + }); + + if (!response.ok) { + const errorData = await response.json(); + const errorMessage = deriveErrorMessage(errorData); + handleError(errorMessage); + throw new Error(errorMessage); + } + + return response.json(); +}; + +// ── Hook ───────────────────────────────────────────────────────────────────── + +export const useEditAccessGroup = () => { + const { accessToken } = useAuthorized(); + const queryClient = useQueryClient(); + + return useMutation({ + mutationFn: async ({ accessGroupId, params }) => { + if (!accessToken) { + throw new Error("Access token is required"); + } + return updateAccessGroup(accessToken, accessGroupId, params); + }, + onSuccess: (_data, { accessGroupId }) => { + queryClient.invalidateQueries({ queryKey: accessGroupKeys.all }); + queryClient.invalidateQueries({ + queryKey: accessGroupKeys.detail(accessGroupId), + }); + }, + }); +}; diff --git a/ui/litellm-dashboard/src/app/page.tsx b/ui/litellm-dashboard/src/app/page.tsx index 28b8d81cc56..ae3bd76e3cf 100644 --- a/ui/litellm-dashboard/src/app/page.tsx +++ b/ui/litellm-dashboard/src/app/page.tsx @@ -35,6 +35,7 @@ import TransformRequestPanel from "@/components/transform_request"; import UIThemeSettings from "@/components/ui_theme_settings"; import Usage from "@/components/usage"; import UserDashboard from "@/components/user_dashboard"; +import { AccessGroupsPage } from "@/components/AccessGroups/AccessGroupsPage"; import VectorStoreManagement from "@/components/vector_store_management"; import SpendLogsTable from "@/components/view_logs"; import ViewUserDashboard from "@/components/view_users"; @@ -542,6 +543,8 @@ function CreateKeyPageContent() { ) : page == "claude-code-plugins" ? ( + ) : page == "access-groups" ? ( + ) : page == "vector-stores" ? ( ) : page == "new_usage" ? ( diff --git a/ui/litellm-dashboard/src/components/AccessGroups/AccessGroupsDetailsPage.test.tsx b/ui/litellm-dashboard/src/components/AccessGroups/AccessGroupsDetailsPage.test.tsx new file mode 100644 index 00000000000..db9d25d886f --- /dev/null +++ b/ui/litellm-dashboard/src/components/AccessGroups/AccessGroupsDetailsPage.test.tsx @@ -0,0 +1,384 @@ +import { useAccessGroupDetails } from "@/app/(dashboard)/hooks/accessGroups/useAccessGroupDetails"; +import { AccessGroupResponse } from "@/app/(dashboard)/hooks/accessGroups/useAccessGroups"; +import { screen } from "@testing-library/react"; +import userEvent from "@testing-library/user-event"; +import { beforeEach, describe, expect, it, vi } from "vitest"; +import { renderWithProviders } from "../../../tests/test-utils"; +import { AccessGroupDetail } from "./AccessGroupsDetailsPage"; + +vi.mock("@/app/(dashboard)/hooks/accessGroups/useAccessGroupDetails"); +vi.mock("./AccessGroupsModal/AccessGroupEditModal", () => ({ + AccessGroupEditModal: ({ + visible, + onCancel, + }: { + visible: boolean; + onCancel: () => void; + }) => + visible ? ( +
+ +
+ ) : null, +})); + +const mockUseAccessGroupDetails = vi.mocked(useAccessGroupDetails); + +const baseMockReturnValue = { + data: undefined, + isLoading: false, + isError: false, + error: null, + isFetching: false, + isPending: false, + isSuccess: true, + status: "success" as const, + dataUpdatedAt: 0, + errorUpdatedAt: 0, + failureCount: 0, + failureReason: null, + errorUpdateCount: 0, + isFetched: true, + isFetchedAfterMount: true, + isRefetching: false, + isLoadingError: false, + isPaused: false, + isPlaceholderData: false, + isRefetchError: false, + isStale: false, + fetchStatus: "idle" as const, + refetch: vi.fn(), +} as unknown as ReturnType; + +const createMockAccessGroup = ( + overrides: Partial = {} +): AccessGroupResponse => ({ + access_group_id: "ag-1", + access_group_name: "Test Group", + description: "A test access group", + access_model_ids: ["model-1", "model-2"], + access_mcp_server_ids: ["mcp-1"], + access_agent_ids: ["agent-1"], + assigned_team_ids: ["team-1"], + assigned_key_ids: ["key-1", "key-2"], + created_at: "2025-01-01T00:00:00Z", + created_by: null, + updated_at: "2025-01-02T00:00:00Z", + updated_by: null, + ...overrides, +}); + +describe("AccessGroupDetail", () => { + const mockOnBack = vi.fn(); + const accessGroupId = "ag-1"; + + beforeEach(() => { + vi.clearAllMocks(); + mockUseAccessGroupDetails.mockReturnValue({ + ...baseMockReturnValue, + data: createMockAccessGroup(), + } as ReturnType); + }); + + it("should render the component", () => { + renderWithProviders( + + ); + expect(screen.getByRole("heading", { name: "Test Group" })).toBeInTheDocument(); + }); + + it("should not show access group content when loading", () => { + mockUseAccessGroupDetails.mockReturnValue({ + ...baseMockReturnValue, + data: undefined, + isLoading: true, + } as ReturnType); + + renderWithProviders( + + ); + + expect(screen.queryByRole("heading", { name: "Test Group" })).not.toBeInTheDocument(); + }); + + it("should show empty state when access group is not found", () => { + mockUseAccessGroupDetails.mockReturnValue({ + ...baseMockReturnValue, + data: undefined, + isLoading: false, + } as ReturnType); + + renderWithProviders( + + ); + + expect(screen.getByText("Access group not found")).toBeInTheDocument(); + expect(screen.getByRole("button")).toBeInTheDocument(); + }); + + it("should call onBack when back button is clicked", async () => { + const user = userEvent.setup(); + renderWithProviders( + + ); + + const buttons = screen.getAllByRole("button"); + const backButton = buttons.find((btn) => !btn.textContent?.includes("Edit")); + await user.click(backButton!); + + expect(mockOnBack).toHaveBeenCalledTimes(1); + }); + + it("should display access group name and ID", () => { + renderWithProviders( + + ); + + expect(screen.getByRole("heading", { name: "Test Group" })).toBeInTheDocument(); + expect(screen.getByText(/ID:/)).toBeInTheDocument(); + }); + + it("should display description in Group Details", () => { + renderWithProviders( + + ); + + expect(screen.getByText("Group Details")).toBeInTheDocument(); + expect(screen.getByText("A test access group")).toBeInTheDocument(); + }); + + it("should display em dash when description is empty", () => { + mockUseAccessGroupDetails.mockReturnValue({ + ...baseMockReturnValue, + data: createMockAccessGroup({ description: null }), + } as ReturnType); + + renderWithProviders( + + ); + + expect(screen.getByText("—")).toBeInTheDocument(); + }); + + it("should open edit modal when Edit Access Group button is clicked", async () => { + const user = userEvent.setup(); + renderWithProviders( + + ); + + expect(screen.queryByRole("dialog", { name: "Edit Access Group" })).not.toBeInTheDocument(); + + const editButton = screen.getByRole("button", { name: /Edit Access Group/i }); + await user.click(editButton); + + expect(screen.getByRole("dialog", { name: "Edit Access Group" })).toBeInTheDocument(); + }); + + it("should close edit modal when Close Modal is clicked", async () => { + const user = userEvent.setup(); + renderWithProviders( + + ); + + await user.click(screen.getByRole("button", { name: /Edit Access Group/i })); + expect(screen.getByRole("dialog", { name: "Edit Access Group" })).toBeInTheDocument(); + + await user.click(screen.getByRole("button", { name: "Close Modal" })); + expect(screen.queryByRole("dialog", { name: "Edit Access Group" })).not.toBeInTheDocument(); + }); + + it("should display attached keys", () => { + renderWithProviders( + + ); + + expect(screen.getByText("Attached Keys")).toBeInTheDocument(); + expect(screen.getByText("key-1")).toBeInTheDocument(); + expect(screen.getByText("key-2")).toBeInTheDocument(); + }); + + it("should display attached teams", () => { + renderWithProviders( + + ); + + expect(screen.getByText("Attached Teams")).toBeInTheDocument(); + expect(screen.getByText("team-1")).toBeInTheDocument(); + }); + + it("should show View All button for keys when more than 5", () => { + mockUseAccessGroupDetails.mockReturnValue({ + ...baseMockReturnValue, + data: createMockAccessGroup({ + assigned_key_ids: ["k1", "k2", "k3", "k4", "k5", "k6"], + }), + } as ReturnType); + + renderWithProviders( + + ); + + expect(screen.getByRole("button", { name: "View All (6)" })).toBeInTheDocument(); + }); + + it("should toggle between View All and Show Less for keys", async () => { + const user = userEvent.setup(); + mockUseAccessGroupDetails.mockReturnValue({ + ...baseMockReturnValue, + data: createMockAccessGroup({ + assigned_key_ids: ["k1", "k2", "k3", "k4", "k5", "k6"], + }), + } as ReturnType); + + renderWithProviders( + + ); + + await user.click(screen.getByRole("button", { name: "View All (6)" })); + expect(screen.getByRole("button", { name: "Show Less" })).toBeInTheDocument(); + + await user.click(screen.getByRole("button", { name: "Show Less" })); + expect(screen.getByRole("button", { name: "View All (6)" })).toBeInTheDocument(); + }); + + it("should show View All button for teams when more than 5", () => { + mockUseAccessGroupDetails.mockReturnValue({ + ...baseMockReturnValue, + data: createMockAccessGroup({ + assigned_team_ids: ["t1", "t2", "t3", "t4", "t5", "t6"], + }), + } as ReturnType); + + renderWithProviders( + + ); + + expect(screen.getByRole("button", { name: "View All (6)" })).toBeInTheDocument(); + }); + + it("should show empty state when no keys attached", () => { + mockUseAccessGroupDetails.mockReturnValue({ + ...baseMockReturnValue, + data: createMockAccessGroup({ assigned_key_ids: [] }), + } as ReturnType); + + renderWithProviders( + + ); + + expect(screen.getByText("No keys attached")).toBeInTheDocument(); + }); + + it("should show empty state when no teams attached", () => { + mockUseAccessGroupDetails.mockReturnValue({ + ...baseMockReturnValue, + data: createMockAccessGroup({ assigned_team_ids: [] }), + } as ReturnType); + + renderWithProviders( + + ); + + expect(screen.getByText("No teams attached")).toBeInTheDocument(); + }); + + it("should display Models tab with model IDs", () => { + renderWithProviders( + + ); + + expect(screen.getByRole("tab", { name: /Models/i })).toBeInTheDocument(); + expect(screen.getByText("model-1")).toBeInTheDocument(); + expect(screen.getByText("model-2")).toBeInTheDocument(); + }); + + it("should display MCP Servers tab with server IDs", async () => { + const user = userEvent.setup(); + renderWithProviders( + + ); + + const mcpTab = screen.getByRole("tab", { name: /MCP Servers/i }); + expect(mcpTab).toBeInTheDocument(); + await user.click(mcpTab); + expect(screen.getByText("mcp-1")).toBeInTheDocument(); + }); + + it("should display Agents tab with agent IDs", async () => { + const user = userEvent.setup(); + renderWithProviders( + + ); + + const agentsTab = screen.getByRole("tab", { name: /Agents/i }); + expect(agentsTab).toBeInTheDocument(); + await user.click(agentsTab); + expect(screen.getByText("agent-1")).toBeInTheDocument(); + }); + + it("should show empty state in Models tab when no models assigned", () => { + mockUseAccessGroupDetails.mockReturnValue({ + ...baseMockReturnValue, + data: createMockAccessGroup({ access_model_ids: [] }), + } as ReturnType); + + renderWithProviders( + + ); + + expect(screen.getByText("No models assigned to this group")).toBeInTheDocument(); + }); + + it("should show empty state in MCP Servers tab when none assigned", async () => { + const user = userEvent.setup(); + mockUseAccessGroupDetails.mockReturnValue({ + ...baseMockReturnValue, + data: createMockAccessGroup({ access_mcp_server_ids: [] }), + } as ReturnType); + + renderWithProviders( + + ); + + await user.click(screen.getByRole("tab", { name: /MCP Servers/i })); + expect(screen.getByText("No MCP servers assigned to this group")).toBeInTheDocument(); + }); + + it("should show empty state in Agents tab when none assigned", async () => { + const user = userEvent.setup(); + mockUseAccessGroupDetails.mockReturnValue({ + ...baseMockReturnValue, + data: createMockAccessGroup({ access_agent_ids: [] }), + } as ReturnType); + + renderWithProviders( + + ); + + await user.click(screen.getByRole("tab", { name: /Agents/i })); + expect(screen.getByText("No agents assigned to this group")).toBeInTheDocument(); + }); + + it("should truncate long key IDs with ellipsis", () => { + const longKeyId = "a".repeat(25); + mockUseAccessGroupDetails.mockReturnValue({ + ...baseMockReturnValue, + data: createMockAccessGroup({ assigned_key_ids: [longKeyId] }), + } as ReturnType); + + renderWithProviders( + + ); + + expect(screen.getByText(/a{10}\.\.\.a{6}/)).toBeInTheDocument(); + }); + + it("should display created and last updated timestamps", () => { + renderWithProviders( + + ); + + expect(screen.getByText("Created")).toBeInTheDocument(); + expect(screen.getByText("Last Updated")).toBeInTheDocument(); + }); +}); diff --git a/ui/litellm-dashboard/src/components/AccessGroups/AccessGroupsDetailsPage.tsx b/ui/litellm-dashboard/src/components/AccessGroups/AccessGroupsDetailsPage.tsx new file mode 100644 index 00000000000..9b794959baa --- /dev/null +++ b/ui/litellm-dashboard/src/components/AccessGroups/AccessGroupsDetailsPage.tsx @@ -0,0 +1,345 @@ +import { useAccessGroupDetails } from "@/app/(dashboard)/hooks/accessGroups/useAccessGroupDetails"; +import { + Button, + Card, + Col, + Descriptions, + Empty, + Flex, + Layout, + List, + Row, + Spin, + Tabs, + Tag, + theme, + Typography +} from "antd"; +import { + ArrowLeftIcon, + BotIcon, + EditIcon, + KeyIcon, + LayersIcon, + ServerIcon, + UsersIcon, +} from "lucide-react"; +import { useState } from "react"; +import DefaultProxyAdminTag from "../common_components/DefaultProxyAdminTag"; +import { AccessGroupEditModal } from "./AccessGroupsModal/AccessGroupEditModal"; + +const { Title, Text } = Typography; +const { Content } = Layout; + +interface AccessGroupDetailProps { + accessGroupId: string; + onBack: () => void; +} + +export function AccessGroupDetail({ + accessGroupId, + onBack, +}: AccessGroupDetailProps) { + const { data: accessGroup, isLoading } = + useAccessGroupDetails(accessGroupId); + const { token } = theme.useToken(); + const [isEditModalVisible, setIsEditModalVisible] = useState(false); + const [showAllKeys, setShowAllKeys] = useState(false); + const [showAllTeams, setShowAllTeams] = useState(false); + + const MAX_PREVIEW = 5; + + if (isLoading) { + return ( + + + + + + ); + } + + if (!accessGroup) { + return ( + + + + + {/* Group Details */} + + + + + {accessGroup.description || "—"} + + + {new Date(accessGroup.created_at).toLocaleString()} + {accessGroup.created_by && ( + +  {"by"}  + + + )} + + + {new Date(accessGroup.updated_at).toLocaleString()} + {accessGroup.updated_by && ( + +  {"by"}  + + + )} + + + + + + {/* Attached Keys & Teams */} + + + + + Attached Keys + {keyIds.length} + + } + extra={ + keyIds.length > MAX_PREVIEW ? ( + + ) : null + } + > + {keyIds.length > 0 ? ( + + {displayedKeys.map((id) => ( + + + {id.length > 20 + ? `${id.slice(0, 10)}...${id.slice(-6)}` + : id} + + + ))} + + ) : ( + + )} + + + + + + Attached Teams + {teamIds.length} + + } + extra={ + teamIds.length > MAX_PREVIEW ? ( + + ) : null + } + > + {teamIds.length > 0 ? ( + + {displayedTeams.map((id) => ( + + + {id} + + + ))} + + ) : ( + + )} + + + + + {/* Resources Tabs */} + + + + + {/* Edit Modal */} + setIsEditModalVisible(false)} + /> + + ); +} diff --git a/ui/litellm-dashboard/src/components/AccessGroups/AccessGroupsModal/AccessGroupBaseForm.tsx b/ui/litellm-dashboard/src/components/AccessGroups/AccessGroupsModal/AccessGroupBaseForm.tsx new file mode 100644 index 00000000000..df60457571e --- /dev/null +++ b/ui/litellm-dashboard/src/components/AccessGroups/AccessGroupsModal/AccessGroupBaseForm.tsx @@ -0,0 +1,159 @@ +import { useAgents } from "@/app/(dashboard)/hooks/agents/useAgents"; +import { useMCPServers } from "@/app/(dashboard)/hooks/mcpServers/useMCPServers"; +import { ModelSelect } from "@/components/ModelSelect/ModelSelect"; +import type { FormInstance } from "antd"; +import { Form, Input, Select, Space, Tabs } from "antd"; +import { BotIcon, InfoIcon, LayersIcon, ServerIcon } from "lucide-react"; + +const { TextArea } = Input; + +export interface AccessGroupFormValues { + name: string; + description: string; + modelIds: string[]; + mcpServerIds: string[]; + agentIds: string[]; +} + +interface AccessGroupBaseFormProps { + form: FormInstance; + isNameDisabled?: boolean; +} + +export function AccessGroupBaseForm({ + form, + isNameDisabled = false, +}: AccessGroupBaseFormProps) { + const { data: agentsData } = useAgents(); + const { data: mcpServersData } = useMCPServers(); + + const agents = agentsData?.agents ?? []; + const mcpServers = mcpServersData ?? []; + const items = [ + { + key: "1", + label: ( + + + General Info + + ), + children: ( +
+ + + + +