diff --git a/.circleci/config.yml b/.circleci/config.yml index 34c3f05cd25..6c7bbddb9f1 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -112,6 +112,24 @@ jobs: python -m mypy . cd .. no_output_timeout: 10m + + semgrep: + docker: + - image: cimg/python:3.12 + auth: + username: ${DOCKERHUB_USERNAME} + password: ${DOCKERHUB_PASSWORD} + working_directory: ~/project + steps: + - checkout + - setup_google_dns + - run: + name: Install Semgrep + command: pip install semgrep + - run: + name: Run Semgrep (custom rules only) + command: semgrep scan --config .semgrep/rules . --error + local_testing_part1: docker: - image: cimg/python:3.12 @@ -3932,6 +3950,9 @@ jobs: image: ubuntu-2204:2023.10.1 resource_class: xlarge working_directory: ~/project + parameters: + browser: + type: string steps: - checkout - setup_google_dns @@ -3961,7 +3982,7 @@ jobs: echo "Expires at: $EXPIRES_AT" neon branches create \ --project-id $NEON_PROJECT_ID \ - --name preview/commit-${CIRCLE_SHA1:0:7} \ + --name preview/commit-${CIRCLE_SHA1:0:7}-<< parameters.browser >> \ --expires-at $EXPIRES_AT \ --parent br-fancy-paper-ad1olsb3 \ --api-key $NEON_API_KEY || true @@ -3971,7 +3992,7 @@ jobs: E2E_UI_TEST_DATABASE_URL=$(neon connection-string \ --project-id $NEON_PROJECT_ID \ --api-key $NEON_API_KEY \ - --branch preview/commit-${CIRCLE_SHA1:0:7} \ + --branch preview/commit-${CIRCLE_SHA1:0:7}-<< parameters.browser >> \ --database-name yuneng-trial-db \ --role neondb_owner) echo $E2E_UI_TEST_DATABASE_URL @@ -3983,7 +4004,7 @@ jobs: -e UI_USERNAME="admin" \ -e UI_PASSWORD="gm" \ -e LITELLM_LICENSE=$LITELLM_LICENSE \ - --name litellm-docker-database \ + --name litellm-docker-database-<< parameters.browser >> \ -v $(pwd)/litellm/proxy/example_config_yaml/simple_config.yaml:/app/config.yaml \ litellm-docker-database:ci \ --config /app/config.yaml \ @@ -3999,7 +4020,7 @@ jobs: sudo rm dockerize-linux-amd64-v0.6.1.tar.gz - run: name: Start outputting logs - command: docker logs -f 
litellm-docker-database + command: docker logs -f litellm-docker-database-<< parameters.browser >> background: true - run: name: Wait for app to be ready @@ -4008,6 +4029,7 @@ jobs: name: Run Playwright Tests command: | npx playwright test \ + --project << parameters.browser >> \ --config ui/litellm-dashboard/e2e_tests/playwright.config.ts \ --reporter=html \ --output=test-results @@ -4114,6 +4136,12 @@ workflows: only: - main - /litellm_.*/ + - semgrep: + filters: + branches: + only: + - main + - /litellm_.*/ - local_testing_part1: filters: branches: @@ -4213,6 +4241,20 @@ workflows: - main - /litellm_.*/ - e2e_ui_testing: + name: e2e_ui_testing_chromium + browser: chromium + context: e2e_ui_tests + requires: + - ui_build + - build_docker_database_image + filters: + branches: + only: + - main + - /litellm_.*/ + - e2e_ui_testing: + name: e2e_ui_testing_firefox + browser: firefox context: e2e_ui_tests requires: - ui_build @@ -4492,6 +4534,7 @@ workflows: - publish_to_pypi: requires: - mypy_linting + - semgrep - local_testing_part1 - local_testing_part2 - build_and_test @@ -4524,7 +4567,8 @@ workflows: - litellm_assistants_api_testing - auth_ui_unit_tests - db_migration_disable_update_check - - e2e_ui_testing + - e2e_ui_testing_chromium + - e2e_ui_testing_firefox - litellm_proxy_unit_testing_key_generation - litellm_proxy_unit_testing_part1 - litellm_proxy_unit_testing_part2 diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md index b91b16c955c..f13039f4516 100644 --- a/.github/pull_request_template.md +++ b/.github/pull_request_template.md @@ -9,6 +9,7 @@ - [ ] I have Added testing in the [`tests/litellm/`](https://github.com/BerriAI/litellm/tree/main/tests/litellm) directory, **Adding at least 1 test is a hard requirement** - [see details](https://docs.litellm.ai/docs/extras/contributing_code) - [ ] My PR passes all unit tests on [`make test-unit`](https://docs.litellm.ai/docs/extras/contributing_code) - [ ] My PR's scope is as isolated 
as possible, it only solves 1 specific problem +- [ ] I have requested a Greptile review by commenting `@greptileai` and received a **Confidence Score of at least 4/5** before requesting a maintainer review ## CI (LiteLLM team) diff --git a/.semgrep/rules/README.md b/.semgrep/rules/README.md new file mode 100644 index 00000000000..0dbb77cdd48 --- /dev/null +++ b/.semgrep/rules/README.md @@ -0,0 +1,22 @@ +# Custom Semgrep rules for LiteLLM + +Add custom rule YAML files here. Semgrep loads all `.yml`/`.yaml` files under this directory. + +**Run only custom rules (CI / fail on findings):** + +```bash +semgrep scan --config .semgrep/rules . --error +``` + +**Run with registry + custom rules:** + +```bash +semgrep scan --config auto --config .semgrep/rules . +``` + +**Layout:** + +- `python/` – Python-specific rules (security, patterns) +- Add more subdirs as needed (e.g. `generic/` for language-agnostic rules) + +See [Semgrep rule syntax](https://semgrep.dev/docs/writing-rules/rule-syntax/). diff --git a/.semgrep/rules/python/unbounded-memory.yml b/.semgrep/rules/python/unbounded-memory.yml new file mode 100644 index 00000000000..811ef689344 --- /dev/null +++ b/.semgrep/rules/python/unbounded-memory.yml @@ -0,0 +1,14 @@ +# Unbounded memory growth – data structures without a clear max limit +# Can lead to OOM under load. + +rules: + - id: unbounded-asyncio-queue + message: asyncio.Queue() with no maxsize can grow unbounded. Use asyncio.Queue(maxsize=N) for integrations (e.g. log queues). 
+ severity: ERROR + languages: [python] + pattern-either: + - pattern: asyncio.Queue() + - pattern: asyncio.Queue(maxsize=0) + metadata: + category: correctness + cwe: "CWE-400: Uncontrolled Resource Consumption" \ No newline at end of file diff --git a/cookbook/nova_sonic_realtime.py b/cookbook/nova_sonic_realtime.py index 0ea0badfb01..c7a73c1d00f 100644 --- a/cookbook/nova_sonic_realtime.py +++ b/cookbook/nova_sonic_realtime.py @@ -16,10 +16,14 @@ import asyncio import base64 import json +import os import pyaudio import websockets from typing import Optional +# Bounded queue size for audio chunks (configurable via env to avoid unbounded memory) +AUDIO_QUEUE_MAXSIZE = int(os.getenv("LITELLM_ASYNCIO_QUEUE_MAXSIZE", 10_000)) + # Audio configuration (matching Nova Sonic requirements) INPUT_SAMPLE_RATE = 16000 # Nova Sonic expects 16kHz input OUTPUT_SAMPLE_RATE = 24000 # Nova Sonic outputs 24kHz @@ -40,7 +44,7 @@ def __init__(self, url: str, api_key: str): self.api_key = api_key self.ws: Optional[websockets.WebSocketClientProtocol] = None self.is_active = False - self.audio_queue = asyncio.Queue() + self.audio_queue = asyncio.Queue(maxsize=AUDIO_QUEUE_MAXSIZE) self.pyaudio = pyaudio.PyAudio() self.input_stream = None self.output_stream = None diff --git a/docs/my-website/docs/proxy/config_settings.md b/docs/my-website/docs/proxy/config_settings.md index 5b6c6669b91..a6d4e3102a6 100644 --- a/docs/my-website/docs/proxy/config_settings.md +++ b/docs/my-website/docs/proxy/config_settings.md @@ -395,7 +395,7 @@ router_settings: | ATHINA_API_KEY | API key for Athina service | ATHINA_BASE_URL | Base URL for Athina service (defaults to `https://log.athina.ai`) | AUTH_STRATEGY | Strategy used for authentication (e.g., OAuth, API key) -| AUTO_REDIRECT_UI_LOGIN_TO_SSO | Flag to enable automatic redirect of UI login page to SSO when SSO is configured. 
Default is **true** +| AUTO_REDIRECT_UI_LOGIN_TO_SSO | Flag to enable automatic redirect of UI login page to SSO when SSO is configured. Default is **false** | AUDIO_SPEECH_CHUNK_SIZE | Chunk size for audio speech processing. Default is 1024 | ANTHROPIC_API_KEY | API key for Anthropic service | ANTHROPIC_API_BASE | Base URL for Anthropic API. Default is https://api.anthropic.com @@ -784,6 +784,7 @@ router_settings: | LITELLM_USER_AGENT | Custom user agent string for LiteLLM API requests. Used for partner telemetry attribution | LITELLM_PRINT_STANDARD_LOGGING_PAYLOAD | If true, prints the standard logging payload to the console - useful for debugging | LITELM_ENVIRONMENT | Environment for LiteLLM Instance. This is currently only logged to DeepEval to determine the environment for DeepEval integration. +| LITELLM_ASYNCIO_QUEUE_MAXSIZE | Maximum size for asyncio queues (e.g. log queues, spend update queues, and cookbook examples such as realtime audio in `nova_sonic_realtime.py`). Bounds in-memory growth to prevent OOM. Default is 1000. | LOGFIRE_TOKEN | Token for Logfire logging service | LOGFIRE_BASE_URL | Base URL for Logfire logging service (useful for self hosted deployments) | LOGGING_WORKER_CONCURRENCY | Maximum number of concurrent coroutine slots for the logging worker on the asyncio event loop. Default is 100. Setting too high will flood the event loop with logging tasks which will lower the overall latency of the requests. 
diff --git a/docs/my-website/docs/proxy/guardrails/guardrail_policies.md b/docs/my-website/docs/proxy/guardrails/guardrail_policies.md index 56be11c85a7..e2cb839203e 100644 --- a/docs/my-website/docs/proxy/guardrails/guardrail_policies.md +++ b/docs/my-website/docs/proxy/guardrails/guardrail_policies.md @@ -1,3 +1,7 @@ +import Image from '@theme/IdealImage'; +import Tabs from '@theme/Tabs'; +import TabItem from '@theme/TabItem'; + # [Beta] Guardrail Policies Use policies to group guardrails and control which ones run for specific teams, keys, or models. @@ -10,6 +14,9 @@ Use policies to group guardrails and control which ones run for specific teams, ## Quick Start + + + ```yaml showLineNumbers title="config.yaml" model_list: - model_name: gpt-4 @@ -43,6 +50,26 @@ policy_attachments: scope: "*" # apply to all requests ``` + + + +**Step 1: Create a Policy** + +Go to **Policies** tab and click **+ Create New Policy**. Fill in the policy name, description, and select guardrails to add. + +![Enter policy name](https://colony-recorder.s3.amazonaws.com/files/2026-02-11/4ba62cc8-d2c4-4af1-a526-686295466928/ascreenshot_401eab3e2081466e8f4d4ffa3bf7bff4_text_export.jpeg) + +![Add a description for the policy](https://colony-recorder.s3.amazonaws.com/files/2026-02-11/51685e47-1d94-4d9c-acb0-3c88dce9f938/ascreenshot_a5cd40066ff34afbb1e4089a3c93d889_text_export.jpeg) + +![Select a parent policy to inherit from](https://colony-recorder.s3.amazonaws.com/files/2026-02-11/1d96c3d3-187a-4f7c-97d2-6ac1f093d51e/ascreenshot_8a3af3b2210547dca3d4709df920d005_text_export.jpeg) + +![Select guardrails to add to the policy](https://colony-recorder.s3.amazonaws.com/files/2026-02-11/23781274-e600-4d5f-a8a6-4a2a977a166c/ascreenshot_a2a45d2c5d064c77ab7cb47b569ad9e9_text_export.jpeg) + +![Click Create Policy to save](https://colony-recorder.s3.amazonaws.com/files/2026-02-11/1d1ae8a8-daa5-451b-9fa2-c5b607ff6220/ascreenshot_218c2dd259714be4aa3c4e1894c96878_text_export.jpeg) + + + + Response headers 
show what ran: ``` @@ -58,6 +85,9 @@ x-litellm-applied-guardrails: pii_masking,prompt_injection You have a global baseline, but want to add extra guardrails for a specific team. + + + ```yaml showLineNumbers title="config.yaml" policies: global-baseline: @@ -81,6 +111,30 @@ policy_attachments: - finance # team alias from /team/new ``` + + + +**Option 1: Create a team-scoped attachment** + +Go to **Policies** > **Attachments** tab and click **+ Create New Attachment**. Select the policy and the teams to scope it to. + +![Select teams for the attachment](https://colony-recorder.s3.amazonaws.com/files/2026-02-11/50e58f54-3bc3-477e-a106-e58cb65fde7e/ascreenshot_85d2e3d9d8d24842baced92fea170427_text_export.jpeg) + +![Select the teams to attach the policy to](https://colony-recorder.s3.amazonaws.com/files/2026-02-11/f24066bb-0a73-49fb-87b6-c65ad3ca5b2f/ascreenshot_242476fbdac447309f65de78b0ed9fdd_text_export.jpeg) + +**Option 2: Attach from team settings** + +Go to **Teams** > click on a team > **Settings** tab > under **Policies**, select the policies to attach. + +![Open team settings and click Edit Settings](https://colony-recorder.s3.amazonaws.com/files/2026-02-11/c31c3735-4f9d-4c6a-896b-186e97296940/ascreenshot_4749bb24ce5942cca462acc958fd3822_text_export.jpeg) + +![Select policies to attach to this team](https://colony-recorder.s3.amazonaws.com/files/2026-02-11/da8d5d7a-d975-4bfe-acd2-f41dcea29520/ascreenshot_835a33b6cec545cbb2987f017fbaff90_text_export.jpeg) + + + + + + Now the `finance` team gets `pii_masking` + `strict_compliance_check` + `audit_logger`, while everyone else just gets `pii_masking`. 
## Remove guardrails for a specific team @@ -201,6 +255,60 @@ policy_attachments: - "test-*" # key alias pattern ``` +**Tag-based** (matches keys/teams by metadata tags, wildcards supported): + +```yaml showLineNumbers title="config.yaml" +policy_attachments: + - policy: hipaa-compliance + tags: + - "healthcare" + - "health-*" # wildcard - matches health-team, health-dev, etc. +``` + +Tags are read from key and team `metadata.tags`. For example, a key created with `metadata: {"tags": ["healthcare"]}` would match the attachment above. + +## Test Policy Matching + +Debug which policies and guardrails apply for a given context. Use this to verify your policy configuration before deploying. + + + + +Go to **Policies** > **Test** tab. Enter a team alias, key alias, model, or tags and click **Test** to see which policies match and what guardrails would be applied. + + + + + + +```bash +curl -X POST "http://localhost:4000/policies/resolve" \ + -H "Authorization: Bearer " \ + -H "Content-Type: application/json" \ + -d '{ + "tags": ["healthcare"], + "model": "gpt-4" + }' +``` + +Response: + +```json +{ + "effective_guardrails": ["pii_masking"], + "matched_policies": [ + { + "policy_name": "hipaa-compliance", + "matched_via": "tag:healthcare", + "guardrails_added": ["pii_masking"] + } + ] +} +``` + + + + ## Config Reference ### `policies` @@ -233,14 +341,18 @@ policy_attachments: scope: ... teams: [...] keys: [...] + models: [...] + tags: [...] ``` | Field | Type | Description | |-------|------|-------------| | `policy` | `string` | **Required.** Name of the policy to attach. | | `scope` | `string` | Use `"*"` to apply globally. | -| `teams` | `list[string]` | Team aliases (from `/team/new`). | +| `teams` | `list[string]` | Team aliases (from `/team/new`). Supports `*` wildcard. | | `keys` | `list[string]` | Key aliases (from `/key/generate`). Supports `*` wildcard. | +| `models` | `list[string]` | Model names. Supports `*` wildcard. 
| +| `tags` | `list[string]` | Tag patterns (from key/team `metadata.tags`). Supports `*` wildcard. | ### Response Headers @@ -248,6 +360,7 @@ policy_attachments: |--------|-------------| | `x-litellm-applied-policies` | Policies that matched this request | | `x-litellm-applied-guardrails` | Guardrails that actually ran | +| `x-litellm-policy-sources` | Why each policy matched (e.g., `hipaa=tag:healthcare; baseline=scope:*`) | ## How it works diff --git a/docs/my-website/docs/proxy/guardrails/policy_tags.md b/docs/my-website/docs/proxy/guardrails/policy_tags.md new file mode 100644 index 00000000000..11840116c31 --- /dev/null +++ b/docs/my-website/docs/proxy/guardrails/policy_tags.md @@ -0,0 +1,139 @@ +# Tag-Based Policy Attachments + +Apply guardrail policies automatically to any key or team that has a specific tag. Instead of attaching policies one-by-one, tag your keys and let the policy engine handle the rest. + +**Example:** Your security team requires all healthcare-related keys to run PII masking and PHI detection. Tag those keys with `health`, create a single tag-based attachment, and every matching key gets the guardrails automatically. + +## 1. Create a Policy with Guardrails + +Navigate to **Policies** in the left sidebar. You'll see a list of existing policies along with their guardrails. + +![Policies list page showing existing policies and the + Add New Policy button](https://colony-recorder.s3.amazonaws.com/files/2026-02-11/d7aa1e1f-011e-40bf-a356-6dfe9d5d54f1/ascreenshot_8db95c231a7f4a79a36c2a98ba127542_text_export.jpeg) + +Click **+ Add New Policy**. In the modal, enter a name for your policy (e.g., `high-risk-policy2`). You can also type to search existing policy names if you want to reference them. 
+ +![Create New Policy modal — enter the policy name and optional description](https://colony-recorder.s3.amazonaws.com/files/2026-02-11/18f1ff69-9b83-4a98-9aad-9892a104d3ff/ascreenshot_1c6b85231cad4ec695750b53bbbda52c_text_export.jpeg) + +Scroll down to **Guardrails to Add**. Click the dropdown to see all available guardrails configured on your proxy — select the ones this policy should enforce. + +![Guardrails to Add dropdown showing available guardrails like OAI-moderation, phi-pre-guard, pii-pre-guard](https://colony-recorder.s3.amazonaws.com/files/2026-02-11/55cedad7-9939-44a1-8644-a184cde82ab7/ascreenshot_eab4e55b82b8411893eccb6234d60b82_text_export.jpeg) + +After selecting your guardrails, they appear as chips in the input field. The **Resolved Guardrails** section below shows the final set that will be applied (including any inherited from a parent policy). + +![Selected guardrails shown as chips: testing-pl, phi-pre-guard, pii-pre-guard. Resolved Guardrails preview below.](https://colony-recorder.s3.amazonaws.com/files/2026-02-11/c06d5b08-1c85-4715-b827-3e6864880428/ascreenshot_7a082e55f3ad425f9009346c68afae23_text_export.jpeg) + +Click **Create Policy** to save. + +![Click Create Policy to save the new policy](https://colony-recorder.s3.amazonaws.com/files/2026-02-11/7e6eae64-4bba-4d72-b226-d1308ac576a8/ascreenshot_22d0ed686c594221bbbd2f40df214d75_text_export.jpeg) + +## 2. Add a Tag Attachment for the Policy + +After creating the policy, switch to the **Attachments** tab. This is where you define *where* the policy applies. + +![Switch to the Attachments tab — shows the attachment table and scope documentation](https://colony-recorder.s3.amazonaws.com/files/2026-02-11/871ae6d9-16d1-44e2-baf2-7bb8a9e72087/ascreenshot_76e124619d70462ea0e2fbb46ded1ac9_text_export.jpeg) + +Click **+ Add New Attachment**. The Attachments page explains the available scopes: Global, Teams, Keys, Models, and **Tags**. 
+ +![Attachments page showing scope types including Tags — click + Add New Attachment](https://colony-recorder.s3.amazonaws.com/files/2026-02-11/d45ab8bc-fc1e-425b-8a3f-44d18df810ec/ascreenshot_425824030f3144b7ab3c0ac570349b00_text_export.jpeg) + +In the **Create Policy Attachment** modal, first select the policy you just created from the dropdown. + +![Select the policy to attach from the dropdown (e.g., high-risk-policy2)](https://colony-recorder.s3.amazonaws.com/files/2026-02-11/e0dcac40-e39c-4a6a-9d9c-4bbb9ec0ee91/ascreenshot_445b19894e0b466196a13e20c8e67f2d_text_export.jpeg) + +Choose **Specific (teams, keys, models, or tags)** as the scope type. This expands the form to show fields for Teams, Keys, Models, and Tags. + +![Select "Specific" scope type to reveal the Tags field](https://colony-recorder.s3.amazonaws.com/files/2026-02-11/f685e02a-e22e-4c6c-9742-d5268746214b/ascreenshot_14d63d9d06dd4fc7854cfeb5e8d9ef85_text_export.jpeg) + +Scroll down to the **Tags** field and type the tag to match — here we enter `health`. You can enter any string, or use a wildcard pattern like `health-*` to match all tags starting with `health-` (e.g., `health-team`, `health-dev`). + +![Tags field with "health" entered. Supports wildcards like prod-* matching prod-us, prod-eu.](https://colony-recorder.s3.amazonaws.com/files/2026-02-11/14581df7-732c-4ea5-b36d-58270b00e92c/ascreenshot_e734c81418f046549b61a84b9d352a29_text_export.jpeg) + +## 3. Check the Impact of the Attachment + +Before creating the attachment, click **Estimate Impact** to preview how many keys and teams would be affected. This is your blast-radius check — make sure the scope is what you expect before applying. 
+ +![Click Estimate Impact — the tag "health" is entered and ready to preview](https://colony-recorder.s3.amazonaws.com/files/2026-02-11/6ccb81d7-3d11-48b0-b634-fc4d738aa530/ascreenshot_2eb89e6ff13a4b12b61004660a36c30c_text_export.jpeg) + +The **Impact Preview** appears inline, showing exactly how many keys and teams would be affected. In this example: "This attachment would affect **1 key** and **0 teams**", with the key alias `hi` listed. + +![Impact Preview showing "This attachment would affect 1 key and 0 teams." Keys: hi](https://colony-recorder.s3.amazonaws.com/files/2026-02-11/8834d85a-2c15-48dd-8d6b-810cf11ee5c4/ascreenshot_d814b42ca9f34c23b0c2269bfa3e64fb_text_export.jpeg) + +Once you're satisfied with the impact, click **Create Attachment** to save. + +![Click Create Attachment to finalize](https://colony-recorder.s3.amazonaws.com/files/2026-02-11/4a8918f2-eedb-4f49-a53b-4e46d0387d2a/ascreenshot_b08d490d836d4f46b4e5cbb14f61377a_text_export.jpeg) + +The attachment now appears in the table with the policy name `high-risk-policy2` and tag `health` visible. + +![Attachments table showing the new attachment with policy high-risk-policy2 and tag "health"](https://colony-recorder.s3.amazonaws.com/files/2026-02-11/45867887-0aec-44a4-963b-b6cc6c302e3e/ascreenshot_981caeff98574ec89a8a53cd295e5043_text_export.jpeg) + +## 4. Create a Key with the Tag + +Navigate to **Virtual Keys** in the left sidebar. Click **+ Create New Key**. + +![Virtual Keys page showing existing keys — click + Create New Key](https://colony-recorder.s3.amazonaws.com/files/2026-02-11/4c1f9448-e590-4546-9357-6f68aa395b27/ascreenshot_4a7bc5be9e4347f3a9fe46f78d938d7c_text_export.jpeg) + +Enter a key name and select a model. Then expand **Optional Settings** and scroll down to the **Tags** field. 
+ +![Create New Key modal — enter the key name](https://colony-recorder.s3.amazonaws.com/files/2026-02-11/f84f7a2b-8057-4926-9f80-d68e437c77cf/ascreenshot_a277c8611b6e41059663b0759cd85cab_text_export.jpeg) + +In the **Tags** field, type `health` and press Enter. This is the tag the policy engine will match against. + +![Tags field in key creation — type "health" to add the tag](https://colony-recorder.s3.amazonaws.com/files/2026-02-11/3ad3bf10-76d2-4f15-9a66-ed6c99bb25c4/ascreenshot_8a8773fb65fc49329cb1716da92b2723_text_export.jpeg) + +The tag `health` now appears as a chip in the Tags field. Confirm your settings look correct. + +![Tags field showing "health" selected with a checkmark](https://colony-recorder.s3.amazonaws.com/files/2026-02-11/de3e58a9-6013-4d0c-882e-5517ea286684/ascreenshot_c7eef1736fce4aa894ac3b118b3800a2_text_export.jpeg) + +Click **Create Key** at the bottom of the form. + +![Click Create Key to generate the new virtual key with the health tag](https://colony-recorder.s3.amazonaws.com/files/2026-02-11/51d419ea-ee80-4e24-8e93-b99a844881bc/ascreenshot_097d4564289943a88e30b5d2e3eab262_text_export.jpeg) + +A dialog appears with your new virtual key. Click **Copy Virtual Key** — you'll need this to test in the next step. + +![Save your Key dialog — click Copy Virtual Key to copy it to clipboard](https://colony-recorder.s3.amazonaws.com/files/2026-02-11/e87a0cc1-4d12-4066-bfa2-973159808fd1/ascreenshot_7b616a7291d0497a9c61bdcdb59394d7_text_export.jpeg) + +## 5. Test the Key and Validate the Policy is Applied + +Navigate to **Playground** in the left sidebar to test the key interactively. + +![Navigate to Playground from the sidebar](https://colony-recorder.s3.amazonaws.com/files/2026-02-11/e6f8a3ee-e9e8-4107-93d1-bfca734c5ce9/ascreenshot_539bde38abe646e49148a912fff2d257_text_export.jpeg) + +Under **Virtual Key Source**, select "Virtual Key" and paste the key you just copied into the input field. 
+ +![Paste the virtual key into the Playground configuration](https://colony-recorder.s3.amazonaws.com/files/2026-02-11/a6612c4a-d499-4e54-8019-f54fde674ad9/ascreenshot_e85ebb9051554594bab0da57823fafad_text_export.jpeg) + +Select a model from the **Select Model** dropdown. + +![Select a model (e.g., bedrock-claude-opus-4.5) from the dropdown](https://colony-recorder.s3.amazonaws.com/files/2026-02-11/325e330f-3eff-4c5e-b177-21916138a2f5/ascreenshot_693478f89c034e949e08f3ed0dd05120_text_export.jpeg) + +Type a message and press Enter. If a guardrail blocks the request, you'll see it in the response. In this example, the `testing-pl` guardrail detected an email pattern and returned a 403 error — confirming the policy is working. + +![Guardrail in action — the request was blocked with "Content blocked: email pattern detected"](https://colony-recorder.s3.amazonaws.com/files/2026-02-11/2cf16809-d2e5-4eae-a7dd-6a16dfcca7ce/ascreenshot_727d7d4ed20b4a52b2b41e39fd36eccb_text_export.jpeg) + +**Using curl:** + +You can also verify via the command line. 
The response headers confirm which policies and guardrails were applied: + +```bash +curl -v http://localhost:4000/chat/completions \ + -H "Authorization: Bearer " \ + -H "Content-Type: application/json" \ + -d '{ + "model": "gpt-4o", + "messages": [{"role": "user", "content": "say hi"}] + }' +``` + +Check the response headers: + +``` +x-litellm-applied-policies: high-risk-policy2 +x-litellm-applied-guardrails: pii-pre-guard,phi-pre-guard,testing-pl +x-litellm-policy-sources: high-risk-policy2=tag:health +``` + +| Header | What it tells you | +|--------|-------------------| +| `x-litellm-applied-policies` | Which policies matched this request | +| `x-litellm-applied-guardrails` | Which guardrails actually ran | +| `x-litellm-policy-sources` | **Why** each policy matched — `tag:health` confirms it was the tag | diff --git a/docs/my-website/img/policy_team_attach.png b/docs/my-website/img/policy_team_attach.png new file mode 100644 index 00000000000..4e337931ed8 Binary files /dev/null and b/docs/my-website/img/policy_team_attach.png differ diff --git a/docs/my-website/img/policy_test_matching.png b/docs/my-website/img/policy_test_matching.png new file mode 100644 index 00000000000..5d024ae78b4 Binary files /dev/null and b/docs/my-website/img/policy_test_matching.png differ diff --git a/docs/my-website/sidebars.js b/docs/my-website/sidebars.js index 28e1724ef82..579b5699101 100644 --- a/docs/my-website/sidebars.js +++ b/docs/my-website/sidebars.js @@ -42,49 +42,62 @@ const sidebars = { label: "Guardrails", items: [ "proxy/guardrails/quick_start", - "proxy/guardrails/guardrail_policies", "proxy/guardrails/guardrail_load_balancing", + "proxy/guardrails/test_playground", + "proxy/guardrails/litellm_content_filter", { type: "category", - "label": "Contributing to Guardrails", + label: "Providers", + items: [ + ...[ + "proxy/guardrails/qualifire", + "proxy/guardrails/aim_security", + "proxy/guardrails/onyx_security", + "proxy/guardrails/aporia_api", + 
"proxy/guardrails/azure_content_guardrail", + "proxy/guardrails/bedrock", + "proxy/guardrails/enkryptai", + "proxy/guardrails/ibm_guardrails", + "proxy/guardrails/grayswan", + "proxy/guardrails/hiddenlayer", + "proxy/guardrails/lasso_security", + "proxy/guardrails/guardrails_ai", + "proxy/guardrails/lakera_ai", + "proxy/guardrails/model_armor", + "proxy/guardrails/noma_security", + "proxy/guardrails/dynamoai", + "proxy/guardrails/openai_moderation", + "proxy/guardrails/pangea", + "proxy/guardrails/pillar_security", + "proxy/guardrails/pii_masking_v2", + "proxy/guardrails/panw_prisma_airs", + "proxy/guardrails/secret_detection", + "proxy/guardrails/custom_guardrail", + "proxy/guardrails/custom_code_guardrail", + "proxy/guardrails/prompt_injection", + "proxy/guardrails/tool_permission", + "proxy/guardrails/zscaler_ai_guard", + "proxy/guardrails/javelin" + ].sort(), + ], + }, + { + type: "category", + label: "Contributing to Guardrails", items: [ "adding_provider/generic_guardrail_api", "adding_provider/simple_guardrail_tutorial", "adding_provider/adding_guardrail_support", ] }, - "proxy/guardrails/test_playground", - "proxy/guardrails/litellm_content_filter", - ...[ - "proxy/guardrails/qualifire", - "proxy/guardrails/aim_security", - "proxy/guardrails/onyx_security", - "proxy/guardrails/aporia_api", - "proxy/guardrails/azure_content_guardrail", - "proxy/guardrails/bedrock", - "proxy/guardrails/enkryptai", - "proxy/guardrails/ibm_guardrails", - "proxy/guardrails/grayswan", - "proxy/guardrails/hiddenlayer", - "proxy/guardrails/lasso_security", - "proxy/guardrails/guardrails_ai", - "proxy/guardrails/lakera_ai", - "proxy/guardrails/model_armor", - "proxy/guardrails/noma_security", - "proxy/guardrails/dynamoai", - "proxy/guardrails/openai_moderation", - "proxy/guardrails/pangea", - "proxy/guardrails/pillar_security", - "proxy/guardrails/pii_masking_v2", - "proxy/guardrails/panw_prisma_airs", - "proxy/guardrails/secret_detection", - "proxy/guardrails/custom_guardrail", - 
"proxy/guardrails/custom_code_guardrail", - "proxy/guardrails/prompt_injection", - "proxy/guardrails/tool_permission", - "proxy/guardrails/zscaler_ai_guard", - "proxy/guardrails/javelin" - ].sort(), + ], + }, + { + type: "category", + label: "Policies", + items: [ + "proxy/guardrails/guardrail_policies", + "proxy/guardrails/policy_tags", ], }, { @@ -396,6 +409,16 @@ const sidebars = { ], }, "proxy/caching", + { + type: "link", + label: "Guardrails", + href: "https://docs.litellm.ai/docs/proxy/guardrails/quick_start", + }, + { + type: "link", + label: "Policies", + href: "https://docs.litellm.ai/docs/proxy/guardrails/guardrail_policies", + }, { type: "category", label: "Create Custom Plugins", diff --git a/litellm-proxy-extras/litellm_proxy_extras/schema.prisma b/litellm-proxy-extras/litellm_proxy_extras/schema.prisma index b1ca1f71c9e..558dfcc9517 100644 --- a/litellm-proxy-extras/litellm_proxy_extras/schema.prisma +++ b/litellm-proxy-extras/litellm_proxy_extras/schema.prisma @@ -914,6 +914,7 @@ model LiteLLM_PolicyAttachmentTable { teams String[] @default([]) // Team aliases or patterns keys String[] @default([]) // Key aliases or patterns models String[] @default([]) // Model names or patterns + tags String[] @default([]) // Tag patterns (e.g., ["healthcare", "prod-*"]) created_at DateTime @default(now()) created_by String? updated_at DateTime @default(now()) @updatedAt diff --git a/litellm/constants.py b/litellm/constants.py index 180315ace0e..88c57d3ce4c 100644 --- a/litellm/constants.py +++ b/litellm/constants.py @@ -213,6 +213,10 @@ REDIS_DAILY_TAG_SPEND_UPDATE_BUFFER_KEY = "litellm_daily_tag_spend_update_buffer" MAX_REDIS_BUFFER_DEQUEUE_COUNT = int(os.getenv("MAX_REDIS_BUFFER_DEQUEUE_COUNT", 100)) MAX_SIZE_IN_MEMORY_QUEUE = int(os.getenv("MAX_SIZE_IN_MEMORY_QUEUE", 2000)) +# Bounds asyncio.Queue() instances (log queues, spend update queues, etc.) 
to prevent unbounded memory growth +LITELLM_ASYNCIO_QUEUE_MAXSIZE = int( + os.getenv("LITELLM_ASYNCIO_QUEUE_MAXSIZE", 1000) +) MAX_IN_MEMORY_QUEUE_FLUSH_COUNT = int( os.getenv("MAX_IN_MEMORY_QUEUE_FLUSH_COUNT", 1000) ) @@ -1306,6 +1310,9 @@ os.getenv("DEFAULT_SLACK_ALERTING_THRESHOLD", 300) ) MAX_TEAM_LIST_LIMIT = int(os.getenv("MAX_TEAM_LIST_LIMIT", 20)) +MAX_POLICY_ESTIMATE_IMPACT_ROWS = int( + os.getenv("MAX_POLICY_ESTIMATE_IMPACT_ROWS", 1000) +) DEFAULT_PROMPT_INJECTION_SIMILARITY_THRESHOLD = float( os.getenv("DEFAULT_PROMPT_INJECTION_SIMILARITY_THRESHOLD", 0.7) ) diff --git a/litellm/integrations/cloudzero/transform.py b/litellm/integrations/cloudzero/transform.py index e06b944a419..c36833a6dbf 100644 --- a/litellm/integrations/cloudzero/transform.py +++ b/litellm/integrations/cloudzero/transform.py @@ -103,10 +103,15 @@ def _create_cbf_record(self, row: dict[str, Any]) -> CBFRecord: # Use team_alias if available, otherwise team_id, otherwise fallback to 'unknown' entity_id = str(team_alias) if team_alias else (str(team_id) if team_id else 'unknown') + # Get alias fields if they exist + api_key_alias = row.get('api_key_alias') + organization_alias = row.get('organization_alias') + project_alias = row.get('project_alias') + user_alias = row.get('user_alias') + dimensions = { 'entity_type': CZEntityType.TEAM.value, 'entity_id': entity_id, - 'team_id': str(team_id) if team_id else 'unknown', 'team_alias': str(team_alias) if team_alias else 'unknown', 'model': model, 'model_group': str(row.get('model_group', '')), @@ -119,28 +124,37 @@ def _create_cbf_record(self, row: dict[str, Any]) -> CBFRecord: 'failed_requests': str(row.get('failed_requests', 0)), 'cache_creation_tokens': str(row.get('cache_creation_input_tokens', 0)), 'cache_read_tokens': str(row.get('cache_read_input_tokens', 0)), + 'organization_alias': str(organization_alias) if organization_alias else '', + 'project_alias': str(project_alias) if project_alias else '', + 'user_alias': str(user_alias) if 
user_alias else '', } # Extract CZRN components to populate corresponding CBF columns czrn_components = self.czrn_generator.extract_components(resource_id) service_type, provider, region, owner_account_id, resource_type, cloud_local_id = czrn_components + # Build resource/account as concat of api_key_alias and api_key_prefix + resource_account = f"{api_key_alias}|{api_key_hash}" if api_key_alias else api_key_hash + # CloudZero CBF format with proper column names cbf_record = { # Required CBF fields 'time/usage_start': usage_date.isoformat() if usage_date else None, # Required: ISO-formatted UTC datetime 'cost/cost': float(row.get('spend', 0.0)), # Required: billed cost - 'resource/id': resource_id, # Required when resource tags are present + 'resource/id': model, # Send model name # Usage metrics for token consumption 'usage/amount': total_tokens, # Numeric value of tokens consumed 'usage/units': 'tokens', # Description of token units - # CBF fields that correspond to CZRN components - 'resource/service': service_type, # Maps to CZRN service-type (litellm) - 'resource/account': owner_account_id, # Maps to CZRN owner-account-id (entity_id) + # CBF fields - updated per LIT-1907 + 'resource/service': str(row.get('model_group', '')), # Send model_group + 'resource/account': resource_account, # Send api_key_alias|api_key_prefix 'resource/region': region, # Maps to CZRN region (cross-region) - 'resource/usage_family': resource_type, # Maps to CZRN resource-type (llm-usage) + 'resource/usage_family': str(row.get('custom_llm_provider', '')), # Send provider + + # Action field + 'action/operation': str(team_id) if team_id else '', # Send team_id # Line item details 'lineitem/type': 'Usage', # Standard usage line item @@ -155,13 +169,11 @@ def _create_cbf_record(self, row: dict[str, Any]) -> CBFRecord: if value and value != 'N/A' and value != 'unknown': # Only add meaningful tags cbf_record[f'resource/tag:{key}'] = str(value) - # Add token breakdown as resource tags for 
analysis + # Add token breakdown as resource tags for analysis (excluding total_tokens per LIT-1907) if prompt_tokens > 0: cbf_record['resource/tag:prompt_tokens'] = str(prompt_tokens) if completion_tokens > 0: cbf_record['resource/tag:completion_tokens'] = str(completion_tokens) - if total_tokens > 0: - cbf_record['resource/tag:total_tokens'] = str(total_tokens) return CBFRecord(cbf_record) diff --git a/litellm/integrations/gcs_bucket/gcs_bucket.py b/litellm/integrations/gcs_bucket/gcs_bucket.py index 3cb62905531..0f1ba4a4093 100644 --- a/litellm/integrations/gcs_bucket/gcs_bucket.py +++ b/litellm/integrations/gcs_bucket/gcs_bucket.py @@ -9,6 +9,7 @@ from urllib.parse import quote from litellm._logging import verbose_logger +from litellm.constants import LITELLM_ASYNCIO_QUEUE_MAXSIZE from litellm.integrations.additional_logging_utils import AdditionalLoggingUtils from litellm.integrations.gcs_bucket.gcs_bucket_base import GCSBucketBase from litellm.proxy._types import CommonProxyErrors @@ -41,7 +42,9 @@ def __init__(self, bucket_name: Optional[str] = None) -> None: batch_size=self.batch_size, flush_interval=self.flush_interval, ) - self.log_queue: asyncio.Queue[GCSLogQueueItem] = asyncio.Queue() # type: ignore[assignment] + self.log_queue: asyncio.Queue[GCSLogQueueItem] = asyncio.Queue( # type: ignore[assignment] + maxsize=LITELLM_ASYNCIO_QUEUE_MAXSIZE + ) asyncio.create_task(self.periodic_flush()) AdditionalLoggingUtils.__init__(self) @@ -69,6 +72,9 @@ async def async_log_success_event(self, kwargs, response_obj, start_time, end_ti ) if logging_payload is None: raise ValueError("standard_logging_object not found in kwargs") + # When queue is at maxsize, flush immediately to make room (no blocking, no data dropped) + if self.log_queue.full(): + await self.flush_queue() await self.log_queue.put( GCSLogQueueItem( payload=logging_payload, kwargs=kwargs, response_obj=response_obj @@ -91,9 +97,9 @@ async def async_log_failure_event(self, kwargs, response_obj, 
start_time, end_ti ) if logging_payload is None: raise ValueError("standard_logging_object not found in kwargs") - # Add to logging queue - this will be flushed periodically - # Use asyncio.Queue.put() for thread-safe concurrent access - # If queue is full, this will block until space is available (backpressure) + # When queue is at maxsize, flush immediately to make room (no blocking, no data dropped) + if self.log_queue.full(): + await self.flush_queue() await self.log_queue.put( GCSLogQueueItem( payload=logging_payload, kwargs=kwargs, response_obj=response_obj diff --git a/litellm/litellm_core_utils/exception_mapping_utils.py b/litellm/litellm_core_utils/exception_mapping_utils.py index 3ddcae69315..03fbdd463dd 100644 --- a/litellm/litellm_core_utils/exception_mapping_utils.py +++ b/litellm/litellm_core_utils/exception_mapping_utils.py @@ -98,16 +98,18 @@ def is_azure_content_policy_violation_error(error_str: str) -> bool: """ Check if an error string indicates a content policy violation error. """ + _lower = error_str.lower() known_exception_substrings = [ - "invalid_request_error", "content_policy_violation", + "responsibleaipolicyviolation", "the response was filtered due to the prompt triggering azure openai's content management", "your task failed as a result of our safety system", "the model produced invalid content", "content_filter_policy", + "your request was rejected as a result of our safety system", ] for substring in known_exception_substrings: - if substring in error_str.lower(): + if substring in _lower: return True return False @@ -2060,6 +2062,19 @@ def exception_type( # type: ignore # noqa: PLR0915 if isinstance(body_dict, dict): if isinstance(body_dict.get("error"), dict): azure_error_code = body_dict["error"].get("code") # type: ignore[index] + # Also check inner_error for + # ResponsibleAIPolicyViolation which indicates a + # content policy violation even when the top-level + # code is generic (e.g. "invalid_request_error"). 
+ if azure_error_code != "content_policy_violation": + _inner = ( + body_dict["error"].get("inner_error") # type: ignore[index] + or body_dict["error"].get("innererror") # type: ignore[index] + ) + if isinstance(_inner, dict) and _inner.get( + "code" + ) == "ResponsibleAIPolicyViolation": + azure_error_code = "content_policy_violation" else: azure_error_code = body_dict.get("code") except Exception: diff --git a/litellm/llms/anthropic/chat/transformation.py b/litellm/llms/anthropic/chat/transformation.py index 82aa7390188..a1e3f59de14 100644 --- a/litellm/llms/anthropic/chat/transformation.py +++ b/litellm/llms/anthropic/chat/transformation.py @@ -834,6 +834,10 @@ def map_openai_params( # noqa: PLR0915 "sonnet-4-5", "opus-4.1", "opus-4-1", + "opus-4.5", + "opus-4-5", + "opus-4.6", + "opus-4-6", } ): _output_format = ( diff --git a/litellm/llms/azure/azure.py b/litellm/llms/azure/azure.py index 76fa713ca8c..44ee51d14ab 100644 --- a/litellm/llms/azure/azure.py +++ b/litellm/llms/azure/azure.py @@ -901,7 +901,20 @@ async def make_async_azure_httpx_request( if response.json()["status"] == "failed": error_data = response.json() - raise AzureOpenAIError(status_code=400, message=json.dumps(error_data)) + # Preserve Azure error details (e.g. content_policy_violation, + # inner_error, content_filter_results) as structured body so + # exception_type() can route them correctly. + _error_body = error_data.get("error", error_data) + _error_msg = ( + _error_body.get("message", "Image generation failed") + if isinstance(_error_body, dict) + else json.dumps(error_data) + ) + raise AzureOpenAIError( + status_code=400, + message=_error_msg, + body=error_data, + ) result = response.json()["result"] return httpx.Response( @@ -999,7 +1012,20 @@ def make_sync_azure_httpx_request( if response.json()["status"] == "failed": error_data = response.json() - raise AzureOpenAIError(status_code=400, message=json.dumps(error_data)) + # Preserve Azure error details (e.g. 
content_policy_violation, + # inner_error, content_filter_results) as structured body so + # exception_type() can route them correctly. + _error_body = error_data.get("error", error_data) + _error_msg = ( + _error_body.get("message", "Image generation failed") + if isinstance(_error_body, dict) + else json.dumps(error_data) + ) + raise AzureOpenAIError( + status_code=400, + message=_error_msg, + body=error_data, + ) result = response.json()["result"] return httpx.Response( diff --git a/litellm/model_prices_and_context_window_backup.json b/litellm/model_prices_and_context_window_backup.json index d794aa50d2e..9b5a7b42d0e 100644 --- a/litellm/model_prices_and_context_window_backup.json +++ b/litellm/model_prices_and_context_window_backup.json @@ -6115,6 +6115,17 @@ "supports_function_calling": true, "supports_reasoning": true }, + "bedrock/moonshotai.kimi-k2-thinking": { + "input_cost_per_token": 7.3e-07, + "litellm_provider": "bedrock", + "max_input_tokens": 262144, + "max_output_tokens": 262144, + "max_tokens": 262144, + "mode": "chat", + "output_cost_per_token": 3.03e-06, + "supports_function_calling": true, + "supports_reasoning": true + }, "bedrock/moonshotai.kimi-k2.5": { "input_cost_per_token": 7.3e-07, "litellm_provider": "bedrock", @@ -9035,6 +9046,43 @@ } ] }, + "dashscope/qwen3-max": { + "litellm_provider": "dashscope", + "max_input_tokens": 258048, + "max_output_tokens": 65536, + "max_tokens": 65536, + "mode": "chat", + "source": "https://www.alibabacloud.com/help/en/model-studio/models", + "supports_function_calling": true, + "supports_reasoning": true, + "supports_tool_choice": true, + "tiered_pricing": [ + { + "input_cost_per_token": 1.2e-06, + "output_cost_per_token": 6e-06, + "range": [ + 0, + 32000.0 + ] + }, + { + "input_cost_per_token": 2.4e-06, + "output_cost_per_token": 1.2e-05, + "range": [ + 32000.0, + 128000.0 + ] + }, + { + "input_cost_per_token": 3e-06, + "output_cost_per_token": 1.5e-05, + "range": [ + 128000.0, + 252000.0 + ] + } + ] + }, 
"dashscope/qwq-plus": { "input_cost_per_token": 8e-07, "litellm_provider": "dashscope", diff --git a/litellm/proxy/common_utils/callback_utils.py b/litellm/proxy/common_utils/callback_utils.py index faeca9b2aed..62ca6dc2ae2 100644 --- a/litellm/proxy/common_utils/callback_utils.py +++ b/litellm/proxy/common_utils/callback_utils.py @@ -394,6 +394,14 @@ def get_logging_caching_headers(request_data: Dict) -> Optional[Dict]: _metadata["applied_policies"] ) + if "policy_sources" in _metadata: + sources = _metadata["policy_sources"] + if isinstance(sources, dict) and sources: + # Use ';' as delimiter — matched_via reasons may contain commas + headers["x-litellm-policy-sources"] = "; ".join( + f"{name}={reason}" for name, reason in sources.items() + ) + if "semantic-similarity" in _metadata: headers["x-litellm-semantic-similarity"] = str(_metadata["semantic-similarity"]) @@ -441,6 +449,27 @@ def add_policy_to_applied_policies_header( request_data["metadata"] = _metadata +def add_policy_sources_to_metadata( + request_data: Dict, policy_sources: Dict[str, str] +): + """ + Store policy match reasons in metadata for x-litellm-policy-sources header. 
+ + Args: + request_data: The request data dict + policy_sources: Map of policy_name -> matched_via reason + """ + if not policy_sources: + return + _metadata = request_data.get("metadata", None) or {} + existing = _metadata.get("policy_sources", {}) + if not isinstance(existing, dict): + existing = {} + existing.update(policy_sources) + _metadata["policy_sources"] = existing + request_data["metadata"] = _metadata + + def add_guardrail_response_to_standard_logging_object( litellm_logging_obj: Optional["LiteLLMLogging"], guardrail_response: StandardLoggingGuardrailInformation, diff --git a/litellm/proxy/db/db_transaction_queue/base_update_queue.py b/litellm/proxy/db/db_transaction_queue/base_update_queue.py index 202829b78b6..a5ec1c3eaf4 100644 --- a/litellm/proxy/db/db_transaction_queue/base_update_queue.py +++ b/litellm/proxy/db/db_transaction_queue/base_update_queue.py @@ -10,14 +10,18 @@ service_logger_obj = ( ServiceLogging() ) # used for tracking metrics for In memory buffer, redis buffer, pod lock manager -from litellm.constants import MAX_IN_MEMORY_QUEUE_FLUSH_COUNT, MAX_SIZE_IN_MEMORY_QUEUE +from litellm.constants import ( + LITELLM_ASYNCIO_QUEUE_MAXSIZE, + MAX_IN_MEMORY_QUEUE_FLUSH_COUNT, + MAX_SIZE_IN_MEMORY_QUEUE, +) class BaseUpdateQueue: """Base class for in memory buffer for database transactions""" def __init__(self): - self.update_queue = asyncio.Queue() + self.update_queue = asyncio.Queue(maxsize=LITELLM_ASYNCIO_QUEUE_MAXSIZE) self.MAX_SIZE_IN_MEMORY_QUEUE = MAX_SIZE_IN_MEMORY_QUEUE async def add_update(self, update): diff --git a/litellm/proxy/db/db_transaction_queue/daily_spend_update_queue.py b/litellm/proxy/db/db_transaction_queue/daily_spend_update_queue.py index c3074e641b2..5ba8fb13596 100644 --- a/litellm/proxy/db/db_transaction_queue/daily_spend_update_queue.py +++ b/litellm/proxy/db/db_transaction_queue/daily_spend_update_queue.py @@ -3,6 +3,7 @@ from typing import Dict, List, Optional from litellm._logging import verbose_proxy_logger 
+from litellm.constants import LITELLM_ASYNCIO_QUEUE_MAXSIZE from litellm.proxy._types import BaseDailySpendTransaction from litellm.proxy.db.db_transaction_queue.base_update_queue import ( BaseUpdateQueue, @@ -54,7 +55,7 @@ class DailySpendUpdateQueue(BaseUpdateQueue): def __init__(self): super().__init__() self.update_queue: asyncio.Queue[Dict[str, BaseDailySpendTransaction]] = ( - asyncio.Queue() + asyncio.Queue(maxsize=LITELLM_ASYNCIO_QUEUE_MAXSIZE) ) async def add_update(self, update: Dict[str, BaseDailySpendTransaction]): diff --git a/litellm/proxy/db/db_transaction_queue/spend_update_queue.py b/litellm/proxy/db/db_transaction_queue/spend_update_queue.py index 9b0449bb9ab..c96564252d0 100644 --- a/litellm/proxy/db/db_transaction_queue/spend_update_queue.py +++ b/litellm/proxy/db/db_transaction_queue/spend_update_queue.py @@ -2,6 +2,7 @@ from typing import Dict, List, Optional from litellm._logging import verbose_proxy_logger +from litellm.constants import LITELLM_ASYNCIO_QUEUE_MAXSIZE from litellm.proxy._types import ( DBSpendUpdateTransactions, Litellm_EntityType, @@ -21,7 +22,9 @@ class SpendUpdateQueue(BaseUpdateQueue): def __init__(self): super().__init__() - self.update_queue: asyncio.Queue[SpendUpdateQueueItem] = asyncio.Queue() + self.update_queue: asyncio.Queue[SpendUpdateQueueItem] = asyncio.Queue( + maxsize=LITELLM_ASYNCIO_QUEUE_MAXSIZE + ) async def flush_and_get_aggregated_db_spend_update_transactions( self, diff --git a/litellm/proxy/discovery_endpoints/ui_discovery_endpoints.py b/litellm/proxy/discovery_endpoints/ui_discovery_endpoints.py index 3cbca27ce0c..955067b486a 100644 --- a/litellm/proxy/discovery_endpoints/ui_discovery_endpoints.py +++ b/litellm/proxy/discovery_endpoints/ui_discovery_endpoints.py @@ -19,13 +19,15 @@ async def get_ui_config(): from litellm.proxy.utils import get_proxy_base_url, get_server_root_path auto_redirect_ui_login_to_sso = ( - os.getenv("AUTO_REDIRECT_UI_LOGIN_TO_SSO", "true").lower() == "true" + 
os.getenv("AUTO_REDIRECT_UI_LOGIN_TO_SSO", "false").lower() == "true" ) admin_ui_disabled = os.getenv("DISABLE_ADMIN_UI", "false").lower() == "true" + sso_configured = _has_user_setup_sso() return UiDiscoveryEndpoints( server_root_path=get_server_root_path(), proxy_base_url=get_proxy_base_url(), - auto_redirect_to_sso=_has_user_setup_sso() and auto_redirect_ui_login_to_sso, + auto_redirect_to_sso=sso_configured and auto_redirect_ui_login_to_sso, admin_ui_disabled=admin_ui_disabled, + sso_configured=sso_configured, ) diff --git a/litellm/proxy/guardrails/guardrail_hooks/presidio.py b/litellm/proxy/guardrails/guardrail_hooks/presidio.py index d71b8449f94..3984384aae4 100644 --- a/litellm/proxy/guardrails/guardrail_hooks/presidio.py +++ b/litellm/proxy/guardrails/guardrail_hooks/presidio.py @@ -38,7 +38,7 @@ from litellm._uuid import uuid from litellm.caching.caching import DualCache -from litellm.exceptions import BlockedPiiEntityError +from litellm.exceptions import BlockedPiiEntityError, GuardrailRaisedException from litellm.integrations.custom_guardrail import ( CustomGuardrail, log_guardrail_information, @@ -232,6 +232,14 @@ def __del__(self): """Cleanup: we try to close, but doing async cleanup in __del__ is risky.""" pass + def _has_block_action(self) -> bool: + """Return True if pii_entities_config has any BLOCK action (fail-closed on analyzer errors).""" + if not self.pii_entities_config: + return False + return any( + action == PiiAction.BLOCK for action in self.pii_entities_config.values() + ) + def _get_presidio_analyze_request_payload( self, text: str, @@ -316,13 +324,30 @@ async def analyze_text( # Handle error responses from Presidio (e.g., {'error': 'No text provided'}) # Presidio may return a dict instead of a list when errors occur + def _fail_on_invalid_response( + reason: str, + ) -> List[PresidioAnalyzeResponseItem]: + should_fail_closed = ( + bool(self.pii_entities_config) + or self.output_parse_pii + or self.apply_to_output + ) + if 
should_fail_closed: + raise GuardrailRaisedException( + guardrail_name=self.guardrail_name, + message=f"Presidio analyzer returned invalid response; cannot verify PII when PII protection is configured: {reason}", + should_wrap_with_default_message=False, + ) + verbose_proxy_logger.warning( + "Presidio analyzer %s, returning empty list", reason + ) + return [] + if isinstance(analyze_results, dict): if "error" in analyze_results: - verbose_proxy_logger.warning( - "Presidio analyzer returned error: %s, returning empty list", - analyze_results.get("error"), + return _fail_on_invalid_response( + f"error: {analyze_results.get('error')}" ) - return [] # If it's a dict but not an error, try to process it as a single item verbose_proxy_logger.debug( "Presidio returned dict (not list), attempting to process as single item" @@ -330,23 +355,33 @@ async def analyze_text( try: return [PresidioAnalyzeResponseItem(**analyze_results)] except Exception as e: - verbose_proxy_logger.warning( - "Failed to parse Presidio dict response: %s, returning empty list", - e, + return _fail_on_invalid_response( + f"failed to parse dict response: {e}" ) - return [] + + # Handle unexpected types (str, None, etc.) - e.g. 
from malformed/error + if not isinstance(analyze_results, list): + return _fail_on_invalid_response( + f"unexpected type {type(analyze_results).__name__} (expected list or dict), response: {str(analyze_results)[:200]}" + ) # Normal case: list of results final_results = [] for item in analyze_results: + if not isinstance(item, dict): + verbose_proxy_logger.warning( + "Skipping invalid Presidio result item (expected dict, got %s): %s", + type(item).__name__, + str(item)[:100], + ) + continue try: final_results.append(PresidioAnalyzeResponseItem(**item)) - except TypeError as te: - # Handle case where item is not a dict (shouldn't happen, but be defensive) + except Exception as e: verbose_proxy_logger.warning( - "Skipping invalid Presidio result item: %s (error: %s)", + "Failed to parse Presidio result item: %s (error: %s)", item, - te, + e, ) continue return final_results diff --git a/litellm/proxy/litellm_pre_call_utils.py b/litellm/proxy/litellm_pre_call_utils.py index 9be78264e85..49d31c1efec 100644 --- a/litellm/proxy/litellm_pre_call_utils.py +++ b/litellm/proxy/litellm_pre_call_utils.py @@ -1539,8 +1539,15 @@ def add_guardrails_from_policy_engine( """ from litellm._logging import verbose_proxy_logger from litellm.proxy.common_utils.callback_utils import ( + add_policy_sources_to_metadata, add_policy_to_applied_policies_header, ) + from litellm.proxy.common_utils.http_parsing_utils import ( + get_tags_from_request_body, + ) + from litellm.proxy.policy_engine.attachment_registry import ( + get_attachment_registry, + ) from litellm.proxy.policy_engine.policy_matcher import PolicyMatcher from litellm.proxy.policy_engine.policy_registry import get_policy_registry from litellm.proxy.policy_engine.policy_resolver import PolicyResolver @@ -1561,20 +1568,31 @@ def add_guardrails_from_policy_engine( ) return - # Build context from request + # Extract tags using the shared helper (handles metadata / litellm_metadata, + # top-level tags, deduplication, and type filtering). 
+ + all_tags = get_tags_from_request_body(data) or None + context = PolicyMatchContext( team_alias=user_api_key_dict.team_alias, key_alias=user_api_key_dict.key_alias, model=data.get("model"), + tags=all_tags, ) verbose_proxy_logger.debug( f"Policy engine: matching policies for context team_alias={context.team_alias}, " - f"key_alias={context.key_alias}, model={context.model}" + f"key_alias={context.key_alias}, model={context.model}, tags={context.tags}" ) - # Get matching policies via attachments - matching_policy_names = PolicyMatcher.get_matching_policies(context=context) + # Get matching policies via attachments (with match reasons for attribution) + attachment_registry = get_attachment_registry() + matches_with_reasons = attachment_registry.get_attached_policies_with_reasons( + context + ) + matching_policy_names = [m["policy_name"] for m in matches_with_reasons] + # Build reasons map: {"hipaa-policy": "tag:healthcare", ...} + policy_reasons = {m["policy_name"]: m["matched_via"] for m in matches_with_reasons} verbose_proxy_logger.debug( f"Policy engine: matched policies via attachments: {matching_policy_names}" @@ -1607,6 +1625,16 @@ def add_guardrails_from_policy_engine( request_data=data, policy_name=policy_name ) + # Track policy attribution sources for x-litellm-policy-sources header + applied_reasons = { + name: policy_reasons[name] + for name in applied_policy_names + if name in policy_reasons + } + add_policy_sources_to_metadata( + request_data=data, policy_sources=applied_reasons + ) + # Resolve guardrails from matching policies resolved_guardrails = PolicyResolver.resolve_guardrails_for_context(context=context) diff --git a/litellm/proxy/policy_engine/attachment_registry.py b/litellm/proxy/policy_engine/attachment_registry.py index 4a335b54747..69b3b3599f3 100644 --- a/litellm/proxy/policy_engine/attachment_registry.py +++ b/litellm/proxy/policy_engine/attachment_registry.py @@ -84,6 +84,7 @@ def _parse_attachment(self, attachment_data: Dict[str, 
Any]) -> PolicyAttachment teams=attachment_data.get("teams"), keys=attachment_data.get("keys"), models=attachment_data.get("models"), + tags=attachment_data.get("tags"), ) def get_attached_policies(self, context: PolicyMatchContext) -> List[str]: @@ -96,21 +97,68 @@ def get_attached_policies(self, context: PolicyMatchContext) -> List[str]: Returns: List of policy names that are attached to matching scopes """ + return [r["policy_name"] for r in self.get_attached_policies_with_reasons(context)] + + def get_attached_policies_with_reasons( + self, context: PolicyMatchContext + ) -> List[Dict[str, Any]]: + """ + Get list of policy names and match reasons for the given context. + + Returns a list of dicts with 'policy_name' and 'matched_via' keys. + The 'matched_via' describes which dimension caused the match. + """ from litellm.proxy.policy_engine.policy_matcher import PolicyMatcher - attached_policies: List[str] = [] + results: List[Dict[str, Any]] = [] + seen_policies: set = set() for attachment in self._attachments: scope = attachment.to_policy_scope() if PolicyMatcher.scope_matches(scope=scope, context=context): - if attachment.policy not in attached_policies: - attached_policies.append(attachment.policy) + if attachment.policy not in seen_policies: + seen_policies.add(attachment.policy) + matched_via = self._describe_match_reason(attachment, context) + results.append( + { + "policy_name": attachment.policy, + "matched_via": matched_via, + } + ) verbose_proxy_logger.debug( f"Attachment matched: policy={attachment.policy}, " + f"matched_via={matched_via}, " f"context=(team={context.team_alias}, key={context.key_alias}, model={context.model})" ) - return attached_policies + return results + + @staticmethod + def _describe_match_reason( + attachment: PolicyAttachment, context: PolicyMatchContext + ) -> str: + """Describe why an attachment matched the context.""" + from litellm.proxy.policy_engine.policy_matcher import PolicyMatcher + + if attachment.is_global(): + 
return "scope:*" + + reasons = [] + if attachment.tags and context.tags: + matching_tags = [ + t for t in context.tags + if PolicyMatcher.matches_pattern(t, attachment.tags) + ] + if matching_tags: + reasons.append(f"tag:{matching_tags[0]}") + if attachment.teams and context.team_alias: + reasons.append(f"team:{context.team_alias}") + if attachment.keys and context.key_alias: + reasons.append(f"key:{context.key_alias}") + if attachment.models and context.model: + reasons.append(f"model:{context.model}") + + return "+".join(reasons) if reasons else "scope:default" def is_policy_attached( self, policy_name: str, context: PolicyMatchContext @@ -238,6 +286,7 @@ async def add_attachment_to_db( "teams": attachment_request.teams or [], "keys": attachment_request.keys or [], "models": attachment_request.models or [], + "tags": attachment_request.tags or [], "created_at": datetime.now(timezone.utc), "updated_at": datetime.now(timezone.utc), "created_by": created_by, @@ -253,6 +302,7 @@ async def add_attachment_to_db( teams=attachment_request.teams, keys=attachment_request.keys, models=attachment_request.models, + tags=attachment_request.tags, ) self.add_attachment(attachment) @@ -263,6 +313,7 @@ async def add_attachment_to_db( teams=created_attachment.teams or [], keys=created_attachment.keys or [], models=created_attachment.models or [], + tags=created_attachment.tags or [], created_at=created_attachment.created_at, updated_at=created_attachment.updated_at, created_by=created_attachment.created_by, @@ -344,6 +395,7 @@ async def get_attachment_by_id_from_db( teams=attachment.teams or [], keys=attachment.keys or [], models=attachment.models or [], + tags=attachment.tags or [], created_at=attachment.created_at, updated_at=attachment.updated_at, created_by=attachment.created_by, @@ -381,6 +433,7 @@ async def get_all_attachments_from_db( teams=a.teams or [], keys=a.keys or [], models=a.models or [], + tags=a.tags or [], created_at=a.created_at, updated_at=a.updated_at, 
created_by=a.created_by, @@ -415,6 +468,7 @@ async def sync_attachments_from_db( teams=attachment_response.teams if attachment_response.teams else None, keys=attachment_response.keys if attachment_response.keys else None, models=attachment_response.models if attachment_response.models else None, + tags=attachment_response.tags if attachment_response.tags else None, ) self._attachments.append(attachment) diff --git a/litellm/proxy/policy_engine/policy_endpoints.py b/litellm/proxy/policy_engine/policy_endpoints.py index 615e153862a..3bd893b0034 100644 --- a/litellm/proxy/policy_engine/policy_endpoints.py +++ b/litellm/proxy/policy_engine/policy_endpoints.py @@ -23,10 +23,6 @@ router = APIRouter() -# Get singleton instances -POLICY_REGISTRY = get_policy_registry() -ATTACHMENT_REGISTRY = get_attachment_registry() - # ───────────────────────────────────────────────────────────────────────────── # Policy CRUD Endpoints @@ -75,7 +71,7 @@ async def list_policies(): raise HTTPException(status_code=500, detail="Database not connected") try: - policies = await POLICY_REGISTRY.get_all_policies_from_db(prisma_client) + policies = await get_policy_registry().get_all_policies_from_db(prisma_client) return PolicyListDBResponse(policies=policies, total_count=len(policies)) except Exception as e: verbose_proxy_logger.exception(f"Error listing policies: {e}") @@ -130,7 +126,7 @@ async def create_policy( try: created_by = user_api_key_dict.user_id - result = await POLICY_REGISTRY.add_policy_to_db( + result = await get_policy_registry().add_policy_to_db( policy_request=request, prisma_client=prisma_client, created_by=created_by, @@ -168,7 +164,7 @@ async def get_policy(policy_id: str): raise HTTPException(status_code=500, detail="Database not connected") try: - result = await POLICY_REGISTRY.get_policy_by_id_from_db( + result = await get_policy_registry().get_policy_by_id_from_db( policy_id=policy_id, prisma_client=prisma_client, ) @@ -216,7 +212,7 @@ async def update_policy( try: # 
Check if policy exists - existing = await POLICY_REGISTRY.get_policy_by_id_from_db( + existing = await get_policy_registry().get_policy_by_id_from_db( policy_id=policy_id, prisma_client=prisma_client, ) @@ -226,7 +222,7 @@ async def update_policy( ) updated_by = user_api_key_dict.user_id - result = await POLICY_REGISTRY.update_policy_in_db( + result = await get_policy_registry().update_policy_in_db( policy_id=policy_id, policy_request=request, prisma_client=prisma_client, @@ -269,7 +265,7 @@ async def delete_policy(policy_id: str): try: # Check if policy exists - existing = await POLICY_REGISTRY.get_policy_by_id_from_db( + existing = await get_policy_registry().get_policy_by_id_from_db( policy_id=policy_id, prisma_client=prisma_client, ) @@ -278,7 +274,7 @@ async def delete_policy(policy_id: str): status_code=404, detail=f"Policy with ID {policy_id} not found" ) - result = await POLICY_REGISTRY.delete_policy_from_db( + result = await get_policy_registry().delete_policy_from_db( policy_id=policy_id, prisma_client=prisma_client, ) @@ -324,7 +320,7 @@ async def get_resolved_guardrails(policy_id: str): try: # Get the policy - policy = await POLICY_REGISTRY.get_policy_by_id_from_db( + policy = await get_policy_registry().get_policy_by_id_from_db( policy_id=policy_id, prisma_client=prisma_client, ) @@ -334,7 +330,7 @@ async def get_resolved_guardrails(policy_id: str): ) # Resolve guardrails - resolved = await POLICY_REGISTRY.resolve_guardrails_from_db( + resolved = await get_policy_registry().resolve_guardrails_from_db( policy_name=policy.policy_name, prisma_client=prisma_client, ) @@ -399,7 +395,7 @@ async def list_policy_attachments(): raise HTTPException(status_code=500, detail="Database not connected") try: - attachments = await ATTACHMENT_REGISTRY.get_all_attachments_from_db( + attachments = await get_attachment_registry().get_all_attachments_from_db( prisma_client ) return PolicyAttachmentListResponse( @@ -466,7 +462,7 @@ async def create_policy_attachment( try: # 
Verify the policy exists - policy = await POLICY_REGISTRY.get_all_policies_from_db(prisma_client) + policy = await get_policy_registry().get_all_policies_from_db(prisma_client) policy_names = [p.policy_name for p in policy] if request.policy_name not in policy_names: raise HTTPException( @@ -475,7 +471,7 @@ async def create_policy_attachment( ) created_by = user_api_key_dict.user_id - result = await ATTACHMENT_REGISTRY.add_attachment_to_db( + result = await get_attachment_registry().add_attachment_to_db( attachment_request=request, prisma_client=prisma_client, created_by=created_by, @@ -510,7 +506,7 @@ async def get_policy_attachment(attachment_id: str): raise HTTPException(status_code=500, detail="Database not connected") try: - result = await ATTACHMENT_REGISTRY.get_attachment_by_id_from_db( + result = await get_attachment_registry().get_attachment_by_id_from_db( attachment_id=attachment_id, prisma_client=prisma_client, ) @@ -556,7 +552,7 @@ async def delete_policy_attachment(attachment_id: str): try: # Check if attachment exists - existing = await ATTACHMENT_REGISTRY.get_attachment_by_id_from_db( + existing = await get_attachment_registry().get_attachment_by_id_from_db( attachment_id=attachment_id, prisma_client=prisma_client, ) @@ -566,7 +562,7 @@ async def delete_policy_attachment(attachment_id: str): detail=f"Attachment with ID {attachment_id} not found", ) - result = await ATTACHMENT_REGISTRY.delete_attachment_from_db( + result = await get_attachment_registry().delete_attachment_from_db( attachment_id=attachment_id, prisma_client=prisma_client, ) diff --git a/litellm/proxy/policy_engine/policy_matcher.py b/litellm/proxy/policy_engine/policy_matcher.py index ab73970bfab..888981f85f5 100644 --- a/litellm/proxy/policy_engine/policy_matcher.py +++ b/litellm/proxy/policy_engine/policy_matcher.py @@ -81,6 +81,19 @@ def scope_matches(scope: PolicyScope, context: PolicyMatchContext) -> bool: if not PolicyMatcher.matches_pattern(context.model, scope.get_models()): 
return False + # Check tags (only if scope specifies tags) + # Unlike teams/keys/models, empty tags means "do not check" rather than "match all" + scope_tags = scope.get_tags() + if scope_tags: + if not context.tags: + return False + # Match if ANY context tag matches ANY scope tag pattern + if not any( + PolicyMatcher.matches_pattern(tag, scope_tags) + for tag in context.tags + ): + return False + return True @staticmethod diff --git a/litellm/proxy/policy_engine/policy_registry.py b/litellm/proxy/policy_engine/policy_registry.py index 5fb5084f648..a2431977b24 100644 --- a/litellm/proxy/policy_engine/policy_registry.py +++ b/litellm/proxy/policy_engine/policy_registry.py @@ -484,6 +484,7 @@ async def sync_policies_from_db( ) self.add_policy(policy_response.policy_name, policy) + self._initialized = True verbose_proxy_logger.info( f"Synced {len(policies)} policies from DB to in-memory registry" ) diff --git a/litellm/proxy/policy_engine/policy_resolve_endpoints.py b/litellm/proxy/policy_engine/policy_resolve_endpoints.py new file mode 100644 index 00000000000..318e990ff12 --- /dev/null +++ b/litellm/proxy/policy_engine/policy_resolve_endpoints.py @@ -0,0 +1,405 @@ +""" +Policy resolve and attachment impact estimation endpoints. 
+ +- /policies/resolve — debug which guardrails apply for a given context +- /policies/attachments/estimate-impact — preview blast radius before creating an attachment +""" + +import json + +from fastapi import APIRouter, Depends, HTTPException, Query + +from litellm._logging import verbose_proxy_logger +from litellm.constants import MAX_POLICY_ESTIMATE_IMPACT_ROWS +from litellm.proxy._types import UserAPIKeyAuth +from litellm.proxy.auth.route_checks import RouteChecks +from litellm.proxy.auth.user_api_key_auth import user_api_key_auth +from litellm.proxy.policy_engine.attachment_registry import get_attachment_registry +from litellm.proxy.policy_engine.policy_registry import get_policy_registry +from litellm.types.proxy.policy_engine import ( + AttachmentImpactResponse, + PolicyAttachmentCreateRequest, + PolicyMatchContext, + PolicyMatchDetail, + PolicyResolveRequest, + PolicyResolveResponse, +) + +router = APIRouter() + + +def _build_alias_where(field: str, patterns: list) -> dict: + """Build a Prisma ``where`` clause for alias patterns. + + Supports exact matches and suffix wildcards (``prefix*``). 
+ Returns something like: + {"OR": [{"field": {"in": ["a","b"]}}, {"field": {"startsWith": "dev-"}}]} + """ + exact: list = [] + prefix_conditions: list = [] + for pat in patterns: + if pat.endswith("*"): + prefix_conditions.append({field: {"startsWith": pat[:-1]}}) + else: + exact.append(pat) + + conditions: list = [] + if exact: + conditions.append({field: {"in": exact}}) + conditions.extend(prefix_conditions) + + if not conditions: + return {field: {"not": None}} + if len(conditions) == 1: + return conditions[0] + return {"OR": conditions} + + +def _parse_metadata(raw_metadata: object) -> dict: + """Parse metadata that may be a dict, JSON string, or None.""" + if raw_metadata is None: + return {} + if isinstance(raw_metadata, str): + try: + return json.loads(raw_metadata) + except (json.JSONDecodeError, TypeError): + return {} + return raw_metadata if isinstance(raw_metadata, dict) else {} + + +def _get_tags_from_metadata(metadata: object, json_metadata: object = None) -> list: + """Extract tags list from a metadata field (or metadata_json fallback).""" + raw = json_metadata if json_metadata is not None else metadata + parsed = _parse_metadata(raw) + return parsed.get("tags", []) or [] + + +async def _fetch_all_teams(prisma_client: object) -> list: + """Fetch teams from DB once. Reuse the result across tag and alias lookups.""" + return await prisma_client.db.litellm_teamtable.find_many( # type: ignore + where={}, order={"created_at": "desc"}, take=MAX_POLICY_ESTIMATE_IMPACT_ROWS, + ) + + +def _filter_keys_by_tags(keys: list, tag_patterns: list) -> tuple: + """Filter key rows whose metadata.tags match any of the given patterns. + + Returns (named_aliases, unnamed_count). 
+ """ + + affected: list = [] + unnamed_count = 0 + for key in keys: + key_alias = key.key_alias or "" + key_tags = _get_tags_from_metadata( + key.metadata, getattr(key, "metadata_json", None) + ) + if key_tags and any( + RouteChecks._route_matches_wildcard_pattern(route=tag, pattern=pat) + for tag in key_tags + for pat in tag_patterns + ): + if key_alias: + affected.append(key_alias) + else: + unnamed_count += 1 + return affected, unnamed_count + + +def _filter_teams_by_tags(teams: list, tag_patterns: list) -> tuple: + """Filter pre-fetched team rows whose metadata.tags match any patterns. + + Returns (named_aliases, unnamed_count). + """ + + affected: list = [] + unnamed_count = 0 + for team in teams: + team_alias = team.team_alias or "" + team_tags = _get_tags_from_metadata(team.metadata) + if team_tags and any( + RouteChecks._route_matches_wildcard_pattern(route=tag, pattern=pat) + for tag in team_tags + for pat in tag_patterns + ): + if team_alias: + affected.append(team_alias) + else: + unnamed_count += 1 + return affected, unnamed_count + + +async def _find_affected_by_team_patterns( + prisma_client: object, + all_teams: list, + team_patterns: list, + existing_teams: list, + existing_keys: list, +) -> tuple: + """Filter pre-fetched teams by alias patterns, then fetch their keys. + + Returns (new_teams, new_keys, unnamed_keys_count). 
+ """ + + new_teams: list = [] + matched_team_ids: list = [] + + for team in all_teams: + team_alias = team.team_alias or "" + if team_alias and any( + RouteChecks._route_matches_wildcard_pattern(route=team_alias, pattern=pat) + for pat in team_patterns + ): + if team_alias not in existing_teams: + new_teams.append(team_alias) + matched_team_ids.append(str(team.team_id)) + + new_keys: list = [] + unnamed_keys_count = 0 + if matched_team_ids: + keys = await prisma_client.db.litellm_verificationtoken.find_many( # type: ignore + where={"team_id": {"in": matched_team_ids}}, + order={"created_at": "desc"}, take=MAX_POLICY_ESTIMATE_IMPACT_ROWS, + ) + for key in keys: + key_alias = key.key_alias or "" + if key_alias: + if key_alias not in existing_keys: + new_keys.append(key_alias) + else: + unnamed_keys_count += 1 + + return new_teams, new_keys, unnamed_keys_count + + +async def _find_affected_keys_by_alias( + prisma_client: object, key_patterns: list, existing_keys: list +) -> list: + """Find keys whose alias matches the given patterns.""" + + affected: list = [] + + keys = await prisma_client.db.litellm_verificationtoken.find_many( # type: ignore + where=_build_alias_where("key_alias", key_patterns), + order={"created_at": "desc"}, take=MAX_POLICY_ESTIMATE_IMPACT_ROWS, + ) + for key in keys: + key_alias = key.key_alias or "" + if key_alias and any( + RouteChecks._route_matches_wildcard_pattern(route=key_alias, pattern=pat) + for pat in key_patterns + ): + if key_alias not in existing_keys: + affected.append(key_alias) + return affected + + +# ───────────────────────────────────────────────────────────────────────────── +# Policy Resolve Endpoint +# ───────────────────────────────────────────────────────────────────────────── + + +@router.post( + "/policies/resolve", + tags=["Policies"], + dependencies=[Depends(user_api_key_auth)], + response_model=PolicyResolveResponse, +) +async def resolve_policies_for_context( + request: PolicyResolveRequest, + force_sync: bool = 
Query( + default=False, + description="Force a DB sync before resolving. Default uses in-memory cache.", + ), + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + """ + Resolve which policies and guardrails apply for a given context. + + Use this endpoint to debug "what guardrails would apply to a request + with this team/key/model/tags combination?" + + Example Request: + ```bash + curl -X POST "http://localhost:4000/policies/resolve" \\ + -H "Authorization: Bearer sk-1234" \\ + -H "Content-Type: application/json" \\ + -d '{ + "tags": ["healthcare"], + "model": "gpt-4" + }' + ``` + """ + from litellm.proxy.policy_engine.policy_matcher import PolicyMatcher + from litellm.proxy.policy_engine.policy_resolver import PolicyResolver + from litellm.proxy.proxy_server import prisma_client + + if prisma_client is None: + raise HTTPException(status_code=500, detail="Database not connected") + + try: + # Only sync from DB when explicitly requested; otherwise use in-memory cache + if force_sync: + await get_policy_registry().sync_policies_from_db(prisma_client) + await get_attachment_registry().sync_attachments_from_db(prisma_client) + + # Build context from request + context = PolicyMatchContext( + team_alias=request.team_alias, + key_alias=request.key_alias, + model=request.model, + tags=request.tags, + ) + + # Get matching policies with reasons + match_results = get_attachment_registry().get_attached_policies_with_reasons( + context=context + ) + + if not match_results: + return PolicyResolveResponse( + effective_guardrails=[], + matched_policies=[], + ) + + # Filter by conditions + policy_names = [r["policy_name"] for r in match_results] + applied_policy_names = PolicyMatcher.get_policies_with_matching_conditions( + policy_names=policy_names, + context=context, + ) + + # Resolve guardrails for each applied policy + matched_policies = [] + all_guardrails: set = set() + for result in match_results: + pname = result["policy_name"] + if pname not in 
applied_policy_names: + continue + resolved = PolicyResolver.resolve_policy_guardrails( + policy_name=pname, + policies=get_policy_registry().get_all_policies(), + context=context, + ) + guardrails = resolved.guardrails if resolved else [] + all_guardrails.update(guardrails) + matched_policies.append( + PolicyMatchDetail( + policy_name=pname, + matched_via=result["matched_via"], + guardrails_added=guardrails, + ) + ) + + return PolicyResolveResponse( + effective_guardrails=sorted(all_guardrails), + matched_policies=matched_policies, + ) + except HTTPException: + raise + except Exception as e: + verbose_proxy_logger.exception(f"Error resolving policies: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +# ───────────────────────────────────────────────────────────────────────────── +# Attachment Impact Estimation Endpoint +# ───────────────────────────────────────────────────────────────────────────── + + +@router.post( + "/policies/attachments/estimate-impact", + tags=["Policies"], + dependencies=[Depends(user_api_key_auth)], + response_model=AttachmentImpactResponse, +) +async def estimate_attachment_impact( + request: PolicyAttachmentCreateRequest, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + """ + Estimate how many keys and teams would be affected by a policy attachment. + + Use this before creating an attachment to preview the blast radius. 
+ + Example Request: + ```bash + curl -X POST "http://localhost:4000/policies/attachments/estimate-impact" \\ + -H "Authorization: Bearer sk-1234" \\ + -H "Content-Type: application/json" \\ + -d '{ + "policy_name": "hipaa-compliance", + "tags": ["healthcare", "health-*"] + }' + ``` + """ + from litellm.proxy.proxy_server import prisma_client + + if prisma_client is None: + raise HTTPException(status_code=500, detail="Database not connected") + + try: + # If global scope, everything is affected — not useful to enumerate + if request.scope == "*": + return AttachmentImpactResponse( + affected_keys_count=-1, + affected_teams_count=-1, + sample_keys=["(global scope — affects all keys)"], + sample_teams=["(global scope — affects all teams)"], + ) + + affected_keys: list = [] + affected_teams: list = [] + unnamed_keys = 0 + unnamed_teams = 0 + + tag_patterns = request.tags or [] + team_patterns = request.teams or [] + + # Fetch teams once — reused by both tag-based and alias-based lookups + all_teams: list = [] + if tag_patterns or team_patterns: + all_teams = await _fetch_all_teams(prisma_client) + + # Tag-based impact + if tag_patterns: + keys = await prisma_client.db.litellm_verificationtoken.find_many( # type: ignore + where={}, order={"created_at": "desc"}, + take=MAX_POLICY_ESTIMATE_IMPACT_ROWS, + ) + affected_keys, unnamed_keys = _filter_keys_by_tags(keys, tag_patterns) + affected_teams, unnamed_teams = _filter_teams_by_tags( + all_teams, tag_patterns, + ) + + # Team-based impact (alias matching + keys belonging to those teams) + if team_patterns: + new_teams, new_keys, new_unnamed = await _find_affected_by_team_patterns( + prisma_client, all_teams, team_patterns, + affected_teams, affected_keys, + ) + affected_teams.extend(new_teams) + affected_keys.extend(new_keys) + unnamed_keys += new_unnamed + + # Key-based impact (direct alias matching) + key_patterns = request.keys or [] + if key_patterns: + new_keys = await _find_affected_keys_by_alias( + prisma_client, 
key_patterns, affected_keys, + ) + affected_keys.extend(new_keys) + + return AttachmentImpactResponse( + affected_keys_count=len(affected_keys) + unnamed_keys, + affected_teams_count=len(affected_teams) + unnamed_teams, + unnamed_keys_count=unnamed_keys, + unnamed_teams_count=unnamed_teams, + sample_keys=affected_keys[:10], + sample_teams=affected_teams[:10], + ) + except HTTPException: + raise + except Exception as e: + verbose_proxy_logger.exception(f"Error estimating attachment impact: {e}") + raise HTTPException(status_code=500, detail=str(e)) diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 294294cdda7..2130aed7770 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -427,6 +427,9 @@ def generate_feedback_box(): router as pass_through_router, ) from litellm.proxy.policy_engine.policy_endpoints import router as policy_crud_router +from litellm.proxy.policy_engine.policy_resolve_endpoints import ( + router as policy_resolve_router, +) from litellm.proxy.prompts.prompt_endpoints import router as prompts_router from litellm.proxy.public_endpoints import router as public_endpoints_router from litellm.proxy.rag_endpoints.endpoints import router as rag_router @@ -11746,6 +11749,7 @@ async def get_routes(): app.include_router(guardrails_router) app.include_router(policy_router) app.include_router(policy_crud_router) +app.include_router(policy_resolve_router) app.include_router(search_tool_management_router) app.include_router(prompts_router) app.include_router(callback_management_endpoints_router) diff --git a/litellm/proxy/schema.prisma b/litellm/proxy/schema.prisma index 1750efed92c..37ed0182663 100644 --- a/litellm/proxy/schema.prisma +++ b/litellm/proxy/schema.prisma @@ -911,6 +911,7 @@ model LiteLLM_PolicyAttachmentTable { teams String[] @default([]) // Team aliases or patterns keys String[] @default([]) // Key aliases or patterns models String[] @default([]) // Model names or patterns + tags String[] 
@default([]) // Tag patterns (e.g., ["healthcare", "prod-*"]) created_at DateTime @default(now()) created_by String? updated_at DateTime @default(now()) @updatedAt diff --git a/litellm/types/integrations/cloudzero.py b/litellm/types/integrations/cloudzero.py index aeda76aa5f9..e79500e08db 100644 --- a/litellm/types/integrations/cloudzero.py +++ b/litellm/types/integrations/cloudzero.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Optional +from typing import Any, Dict class CBFRecord(Dict[str, Any]): @@ -9,19 +9,23 @@ class CBFRecord(Dict[str, Any]): (e.g., 'time/usage_start', 'cost/cost'), we use a Dict base class rather than TypedDict to accommodate the special characters in field names. - Expected CBF fields: + Expected CBF fields (per LIT-1907): - time/usage_start: ISO-formatted UTC datetime (Optional[str]) - cost/cost: Billed cost (float) - - resource/id: CloudZero Resource Name (CZRN) (str) + - resource/id: Model name (str) - usage/amount: Numeric value of tokens consumed (int) - usage/units: Description of units, e.g., 'tokens' (str) - - resource/service: Maps to CZRN service-type, e.g., 'litellm' (str) - - resource/account: Maps to CZRN owner-account-id (entity_id) (str) + - resource/service: Model group (str) + - resource/account: api_key_alias|api_key_prefix (str) - resource/region: Maps to CZRN region, e.g., 'cross-region' (str) - - resource/usage_family: Maps to CZRN resource-type, e.g., 'llm-usage' (str) + - resource/usage_family: Provider (str) + - action/operation: Team ID (str) - lineitem/type: Standard usage line item, e.g., 'Usage' (str) - resource/tag:provider: CZRN provider component (str) - resource/tag:model: CZRN cloud-local-id component (model) (str) + - resource/tag:organization_alias: Organization alias if available (Optional[str]) + - resource/tag:project_alias: Project alias if available (Optional[str]) + - resource/tag:user_alias: User alias if available (Optional[str]) - resource/tag:{key}: Various resource tags for dimensions and 
metrics (Optional[str]) """ pass diff --git a/litellm/types/proxy/discovery_endpoints/ui_discovery_endpoints.py b/litellm/types/proxy/discovery_endpoints/ui_discovery_endpoints.py index dc167667bc0..4a4cdaa2bae 100644 --- a/litellm/types/proxy/discovery_endpoints/ui_discovery_endpoints.py +++ b/litellm/types/proxy/discovery_endpoints/ui_discovery_endpoints.py @@ -8,3 +8,4 @@ class UiDiscoveryEndpoints(BaseModel): proxy_base_url: Optional[str] auto_redirect_to_sso: bool admin_ui_disabled: bool + sso_configured: bool diff --git a/litellm/types/proxy/policy_engine/__init__.py b/litellm/types/proxy/policy_engine/__init__.py index bc54c3eb36b..42490c2eddc 100644 --- a/litellm/types/proxy/policy_engine/__init__.py +++ b/litellm/types/proxy/policy_engine/__init__.py @@ -19,6 +19,7 @@ PolicyScope, ) from litellm.types.proxy.policy_engine.resolver_types import ( + AttachmentImpactResponse, PolicyAttachmentCreateRequest, PolicyAttachmentDBResponse, PolicyAttachmentListResponse, @@ -30,6 +31,9 @@ PolicyListDBResponse, PolicyListResponse, PolicyMatchContext, + PolicyMatchDetail, + PolicyResolveRequest, + PolicyResolveResponse, PolicyScopeResponse, PolicySummaryItem, PolicyTestResponse, @@ -75,4 +79,9 @@ "PolicyAttachmentCreateRequest", "PolicyAttachmentDBResponse", "PolicyAttachmentListResponse", + # Resolve types + "PolicyResolveRequest", + "PolicyResolveResponse", + "PolicyMatchDetail", + "AttachmentImpactResponse", ] diff --git a/litellm/types/proxy/policy_engine/policy_types.py b/litellm/types/proxy/policy_engine/policy_types.py index 1c01f89e8b4..f221ba7e038 100644 --- a/litellm/types/proxy/policy_engine/policy_types.py +++ b/litellm/types/proxy/policy_engine/policy_types.py @@ -73,13 +73,15 @@ class PolicyScope(BaseModel): Used internally by PolicyAttachment to define WHERE a policy applies. 
Scope Fields: - | Field | What it matches | Wildcard support | - |--------|-----------------|----------------------| - | teams | Team aliases | *, healthcare-* | - | keys | Key aliases | *, dev-key-* | - | models | Model names | *, bedrock/*, gpt-* | - - If a field is None or empty, it defaults to matching everything (["*"]). + | Field | What it matches | Wildcard support | Default behavior | + |--------|-----------------|----------------------|---------------------| + | teams | Team aliases | *, healthcare-* | None → matches all | + | keys | Key aliases | *, dev-key-* | None → matches all | + | models | Model names | *, bedrock/*, gpt-* | None → matches all | + | tags | Key/team tags | *, health-*, prod-* | None → not checked | + + If teams/keys/models is None or empty, it defaults to matching everything (["*"]). + If tags is None or empty, the tag dimension is NOT checked (matches all). A request must match ALL specified scope fields for the attachment to apply. """ @@ -95,6 +97,10 @@ class PolicyScope(BaseModel): default=None, description="Model names or wildcard patterns. Use '*' for all models.", ) + tags: Optional[List[str]] = Field( + default=None, + description="Tag patterns to match against key/team tags. Supports wildcards (e.g., health-*).", + ) model_config = ConfigDict(extra="forbid") @@ -110,6 +116,14 @@ def get_models(self) -> List[str]: """Returns models list, defaulting to ['*'] if not specified.""" return self.models if self.models else ["*"] + def get_tags(self) -> List[str]: + """Returns tags list, defaulting to empty list if not specified. + + Unlike teams/keys/models, empty tags means 'do not check tags' + rather than 'match all'. This is because tags are opt-in scoping. 
+ """ + return self.tags if self.tags else [] + # ───────────────────────────────────────────────────────────────────────────── # Policy Guardrails @@ -266,6 +280,10 @@ class PolicyAttachment(BaseModel): default=None, description="Model names or patterns this attachment applies to.", ) + tags: Optional[List[str]] = Field( + default=None, + description="Tag patterns this attachment applies to. Supports wildcards (e.g., health-*).", + ) model_config = ConfigDict(extra="forbid") @@ -281,6 +299,7 @@ def to_policy_scope(self) -> PolicyScope: teams=self.teams, keys=self.keys, models=self.models, + tags=self.tags, ) diff --git a/litellm/types/proxy/policy_engine/resolver_types.py b/litellm/types/proxy/policy_engine/resolver_types.py index 9488b8b0841..0c2c7336f8a 100644 --- a/litellm/types/proxy/policy_engine/resolver_types.py +++ b/litellm/types/proxy/policy_engine/resolver_types.py @@ -30,6 +30,10 @@ class PolicyMatchContext(BaseModel): default=None, description="Model name from the request.", ) + tags: Optional[List[str]] = Field( + default=None, + description="Tags from key/team metadata.", + ) model_config = ConfigDict(extra="forbid") @@ -65,6 +69,7 @@ class PolicyScopeResponse(BaseModel): teams: List[str] = Field(default_factory=list) keys: List[str] = Field(default_factory=list) models: List[str] = Field(default_factory=list) + tags: List[str] = Field(default_factory=list) class PolicyGuardrailsResponse(BaseModel): @@ -242,6 +247,10 @@ class PolicyAttachmentCreateRequest(BaseModel): default=None, description="Model names or patterns this attachment applies to.", ) + tags: Optional[List[str]] = Field( + default=None, + description="Tag patterns this attachment applies to. 
Supports wildcards (e.g., health-*).", + ) class PolicyAttachmentDBResponse(BaseModel): @@ -253,6 +262,7 @@ class PolicyAttachmentDBResponse(BaseModel): teams: List[str] = Field(default_factory=list, description="Team patterns.") keys: List[str] = Field(default_factory=list, description="Key patterns.") models: List[str] = Field(default_factory=list, description="Model patterns.") + tags: List[str] = Field(default_factory=list, description="Tag patterns.") created_at: Optional[datetime] = Field( default=None, description="When the attachment was created." ) @@ -274,3 +284,81 @@ class PolicyAttachmentListResponse(BaseModel): default_factory=list, description="List of policy attachments." ) total_count: int = Field(default=0, description="Total number of attachments.") + + +# ───────────────────────────────────────────────────────────────────────────── +# Policy Resolve Types +# ───────────────────────────────────────────────────────────────────────────── + + +class PolicyResolveRequest(BaseModel): + """Request body for resolving effective policies/guardrails for a context.""" + + team_alias: Optional[str] = Field( + default=None, description="Team alias to resolve for." + ) + key_alias: Optional[str] = Field( + default=None, description="Key alias to resolve for." + ) + model: Optional[str] = Field( + default=None, description="Model name to resolve for." + ) + tags: Optional[List[str]] = Field( + default=None, description="Tags to resolve for." + ) + + +class PolicyMatchDetail(BaseModel): + """Details about why a specific policy matched.""" + + policy_name: str = Field(description="Name of the matched policy.") + matched_via: str = Field( + description="How the policy was matched (e.g., 'tag:healthcare', 'team:health-team', 'scope:*')." 
+ ) + guardrails_added: List[str] = Field( + default_factory=list, + description="Guardrails this policy contributes.", + ) + + +class PolicyResolveResponse(BaseModel): + """Response for resolving effective policies/guardrails for a context.""" + + effective_guardrails: List[str] = Field( + default_factory=list, + description="Final list of guardrails that would be applied.", + ) + matched_policies: List[PolicyMatchDetail] = Field( + default_factory=list, + description="Details about each matched policy and why it matched.", + ) + + +# ───────────────────────────────────────────────────────────────────────────── +# Attachment Impact Estimation Types +# ───────────────────────────────────────────────────────────────────────────── + + +class AttachmentImpactResponse(BaseModel): + """Response for estimating the impact of a policy attachment.""" + + affected_keys_count: int = Field( + default=0, description="Number of keys that would be affected (named + unnamed)." + ) + affected_teams_count: int = Field( + default=0, description="Number of teams that would be affected (named + unnamed)." + ) + unnamed_keys_count: int = Field( + default=0, description="Number of affected keys without an alias." + ) + unnamed_teams_count: int = Field( + default=0, description="Number of affected teams without an alias." 
+ ) + sample_keys: List[str] = Field( + default_factory=list, + description="Sample of affected key aliases (up to 10).", + ) + sample_teams: List[str] = Field( + default_factory=list, + description="Sample of affected team aliases (up to 10).", + ) diff --git a/model_prices_and_context_window.json b/model_prices_and_context_window.json index 35538ab1003..9b5a7b42d0e 100644 --- a/model_prices_and_context_window.json +++ b/model_prices_and_context_window.json @@ -9046,6 +9046,43 @@ } ] }, + "dashscope/qwen3-max": { + "litellm_provider": "dashscope", + "max_input_tokens": 258048, + "max_output_tokens": 65536, + "max_tokens": 65536, + "mode": "chat", + "source": "https://www.alibabacloud.com/help/en/model-studio/models", + "supports_function_calling": true, + "supports_reasoning": true, + "supports_tool_choice": true, + "tiered_pricing": [ + { + "input_cost_per_token": 1.2e-06, + "output_cost_per_token": 6e-06, + "range": [ + 0, + 32000.0 + ] + }, + { + "input_cost_per_token": 2.4e-06, + "output_cost_per_token": 1.2e-05, + "range": [ + 32000.0, + 128000.0 + ] + }, + { + "input_cost_per_token": 3e-06, + "output_cost_per_token": 1.5e-05, + "range": [ + 128000.0, + 252000.0 + ] + } + ] + }, "dashscope/qwq-plus": { "input_cost_per_token": 8e-07, "litellm_provider": "dashscope", diff --git a/schema.prisma b/schema.prisma index 9a87a491cf7..4329f939a7b 100644 --- a/schema.prisma +++ b/schema.prisma @@ -913,6 +913,7 @@ model LiteLLM_PolicyAttachmentTable { teams String[] @default([]) // Team aliases or patterns keys String[] @default([]) // Key aliases or patterns models String[] @default([]) // Model names or patterns + tags String[] @default([]) // Tag patterns (e.g., ["healthcare", "prod-*"]) created_at DateTime @default(now()) created_by String? 
updated_at DateTime @default(now()) @updatedAt diff --git a/tests/test_litellm/llms/anthropic/chat/test_anthropic_chat_transformation.py b/tests/test_litellm/llms/anthropic/chat/test_anthropic_chat_transformation.py index e3bd7d2bb31..57e0dd494e0 100644 --- a/tests/test_litellm/llms/anthropic/chat/test_anthropic_chat_transformation.py +++ b/tests/test_litellm/llms/anthropic/chat/test_anthropic_chat_transformation.py @@ -3,7 +3,6 @@ import sys import pytest -from fastapi.testclient import TestClient sys.path.insert( 0, os.path.abspath("../../../../..") @@ -855,6 +854,92 @@ def test_anthropic_structured_output_beta_header(): ) +@pytest.mark.parametrize( + "model_name", + [ + "claude-opus-4-6-20250918", + "claude-opus-4.6-20250918", + "claude-opus-4-5-20251101", + "claude-opus-4.5-20251101", + ], +) +def test_opus_uses_native_structured_output(model_name): + """ + Test that Opus 4.5 and 4.6 models use native Anthropic structured outputs + (output_format) rather than the tool-based workaround. 
+ """ + config = AnthropicConfig() + + response_format = { + "type": "json_schema", + "json_schema": { + "name": "test_schema", + "schema": { + "type": "object", + "properties": { + "name": {"type": "string"}, + "age": {"type": "integer"}, + }, + "required": ["name", "age"], + "additionalProperties": False, + }, + }, + } + + optional_params = config.map_openai_params( + non_default_params={"response_format": response_format}, + optional_params={}, + model=model_name, + drop_params=False, + ) + + # Should use output_format (native structured outputs) + assert "output_format" in optional_params + assert optional_params["output_format"]["type"] == "json_schema" + + # Should NOT create a tool-based workaround + assert "tools" not in optional_params + assert "tool_choice" not in optional_params + + # Should set json_mode + assert optional_params.get("json_mode") is True + + +def test_non_structured_output_model_uses_tool_workaround(): + """ + Test that models NOT in the native structured output list still use the + tool-based workaround for response_format. 
+ """ + config = AnthropicConfig() + + response_format = { + "type": "json_schema", + "json_schema": { + "name": "test_schema", + "schema": { + "type": "object", + "properties": {"result": {"type": "string"}}, + "required": ["result"], + "additionalProperties": False, + }, + }, + } + + optional_params = config.map_openai_params( + non_default_params={"response_format": response_format}, + optional_params={}, + model="claude-3-5-sonnet-20241022", + drop_params=False, + ) + + # Should NOT use output_format + assert "output_format" not in optional_params + + # Should use tool-based workaround + assert "tools" in optional_params + assert "tool_choice" in optional_params + + # ============ Tool Search Tests ============ diff --git a/tests/test_litellm/llms/azure/test_azure_exception_mapping.py b/tests/test_litellm/llms/azure/test_azure_exception_mapping.py index f4abe7f2b9a..495ca958cf5 100644 --- a/tests/test_litellm/llms/azure/test_azure_exception_mapping.py +++ b/tests/test_litellm/llms/azure/test_azure_exception_mapping.py @@ -239,4 +239,149 @@ def test_azure_images_content_policy_violation_preserves_nested_inner_error(self assert e.provider_specific_fields is not None assert e.provider_specific_fields["inner_error"]["code"] == "ResponsibleAIPolicyViolation" assert e.provider_specific_fields["inner_error"]["revised_prompt"] == "revised" - assert e.provider_specific_fields["inner_error"]["content_filter_results"]["violence"]["filtered"] is True \ No newline at end of file + assert e.provider_specific_fields["inner_error"]["content_filter_results"]["violence"]["filtered"] is True + + def test_azure_content_policy_violation_detected_via_inner_error_code(self): + """Regression test for #20811: Azure returns inner_error with + ResponsibleAIPolicyViolation but the top-level error message is + generic. 
Previously this fell through to the generic + BadRequestError handler and all error details were lost.""" + + mock_exception = Exception("Bad request") + # This body structure mirrors what Azure OpenAI Images API returns + # for DALL-E 3 content policy violations (issue #20811). + mock_exception.body = { + "error": { + "code": "content_policy_violation", + "inner_error": { + "code": "ResponsibleAIPolicyViolation", + "content_filter_results": { + "hate": {"filtered": False, "severity": "safe"}, + "profanity": {"detected": False, "filtered": False}, + "self_harm": {"filtered": False, "severity": "safe"}, + "sexual": {"filtered": False, "severity": "safe"}, + "violence": {"filtered": True, "severity": "low"}, + }, + "revised_prompt": ( + "A dark and intense illustration of a man " + "in a dramatic action scene." + ), + }, + "message": ( + "Your request was rejected as a result of our safety system." + ), + "type": "invalid_request_error", + } + } + + mock_response = MagicMock() + mock_response.status_code = 400 + mock_exception.response = mock_response + + with pytest.raises(ContentPolicyViolationError) as exc_info: + exception_type( + model="azure/dall-e-3", + original_exception=mock_exception, + custom_llm_provider="azure", + ) + + e = exc_info.value + # Must surface as ContentPolicyViolationError, not generic BadRequestError + assert "safety system" in str(e) + assert e.provider_specific_fields is not None + inner = e.provider_specific_fields["inner_error"] + assert inner["code"] == "ResponsibleAIPolicyViolation" + assert inner["content_filter_results"]["violence"]["filtered"] is True + assert inner["revised_prompt"] is not None + + def test_azure_policy_violation_detected_via_inner_error_without_top_code(self): + """When the top-level code is NOT 'content_policy_violation' but + inner_error.code IS 'ResponsibleAIPolicyViolation', the error + should still be recognized as a content policy violation.""" + + mock_exception = Exception("Some error") + 
mock_exception.body = { + "error": { + "code": "BadRequest", + "inner_error": { + "code": "ResponsibleAIPolicyViolation", + "content_filter_results": { + "violence": {"filtered": True, "severity": "medium"}, + }, + }, + "message": "The request was rejected.", + "type": "invalid_request_error", + } + } + + mock_response = MagicMock() + mock_response.status_code = 400 + mock_exception.response = mock_response + + with pytest.raises(ContentPolicyViolationError) as exc_info: + exception_type( + model="azure/dall-e-3", + original_exception=mock_exception, + custom_llm_provider="azure", + ) + + e = exc_info.value + assert e.provider_specific_fields is not None + assert e.provider_specific_fields["inner_error"]["code"] == "ResponsibleAIPolicyViolation" + + def test_azure_image_polling_error_preserves_body(self): + """Verify that AzureOpenAIError raised from the DALL-E polling path + carries the structured body so exception_type() can inspect it.""" + from litellm.llms.azure.common_utils import AzureOpenAIError + + error_payload = { + "status": "failed", + "error": { + "code": "content_policy_violation", + "message": "Your request was rejected.", + "inner_error": { + "code": "ResponsibleAIPolicyViolation", + "content_filter_results": { + "violence": {"filtered": True, "severity": "low"}, + }, + }, + }, + } + + # Simulate what the fixed polling path now does + _error_body = error_payload.get("error", error_payload) + _error_msg = ( + _error_body.get("message", "Image generation failed") + if isinstance(_error_body, dict) + else json.dumps(error_payload) + ) + exc = AzureOpenAIError( + status_code=400, + message=_error_msg, + body=error_payload, + ) + + assert exc.body is not None + assert isinstance(exc.body, dict) + assert exc.body["error"]["code"] == "content_policy_violation" + assert "Your request was rejected" in exc.message + + def test_azure_safety_system_message_detected_as_policy_violation(self): + """Azure's rejection message 'Your request was rejected as a result 
+ of our safety system' should be detected by string matching even + when the structured body is unavailable.""" + + mock_exception = Exception( + "Your request was rejected as a result of our safety system. " + "The revised prompt may contain text that is not allowed." + ) + mock_response = MagicMock() + mock_response.status_code = 400 + mock_exception.response = mock_response + + with pytest.raises(ContentPolicyViolationError): + exception_type( + model="azure/dall-e-3", + original_exception=mock_exception, + custom_llm_provider="azure", + ) \ No newline at end of file diff --git a/tests/test_litellm/proxy/discovery_endpoints/test_ui_discovery_endpoints.py b/tests/test_litellm/proxy/discovery_endpoints/test_ui_discovery_endpoints.py index 88d31e993dd..9d0c771e1d9 100644 --- a/tests/test_litellm/proxy/discovery_endpoints/test_ui_discovery_endpoints.py +++ b/tests/test_litellm/proxy/discovery_endpoints/test_ui_discovery_endpoints.py @@ -31,6 +31,7 @@ def test_ui_discovery_endpoints_with_defaults(): assert data["proxy_base_url"] is None assert data["auto_redirect_to_sso"] is False assert data["admin_ui_disabled"] is False + assert data["sso_configured"] is False def test_ui_discovery_endpoints_with_custom_server_root_path(): @@ -50,6 +51,7 @@ def test_ui_discovery_endpoints_with_custom_server_root_path(): assert data["server_root_path"] == "/litellm" assert data["proxy_base_url"] is None assert data["auto_redirect_to_sso"] is False + assert data["sso_configured"] is False def test_ui_discovery_endpoints_with_proxy_base_url_when_set(): @@ -69,6 +71,7 @@ def test_ui_discovery_endpoints_with_proxy_base_url_when_set(): assert data["server_root_path"] == "/" assert data["proxy_base_url"] == "https://proxy.example.com" assert data["auto_redirect_to_sso"] is False + assert data["sso_configured"] is False def test_ui_discovery_endpoints_with_sso_configured_and_auto_redirect_enabled(): @@ -88,6 +91,30 @@ def 
test_ui_discovery_endpoints_with_sso_configured_and_auto_redirect_enabled(): assert data["server_root_path"] == "/litellm" assert data["proxy_base_url"] == "https://proxy.example.com" assert data["auto_redirect_to_sso"] is True + assert data["sso_configured"] is True + + +def test_ui_discovery_endpoints_with_sso_configured_and_auto_redirect_not_set_defaults_to_false(): + """When SSO is configured but AUTO_REDIRECT_UI_LOGIN_TO_SSO is not set, defaults to False.""" + app = FastAPI() + app.include_router(router) + client = TestClient(app) + + with patch("litellm.proxy.utils.get_server_root_path", return_value="/litellm"), \ + patch("litellm.proxy.utils.get_proxy_base_url", return_value="https://proxy.example.com"), \ + patch("litellm.proxy.auth.auth_utils._has_user_setup_sso", return_value=True), \ + patch.dict(os.environ, {"DISABLE_ADMIN_UI": "false"}, clear=False): + # Ensure AUTO_REDIRECT_UI_LOGIN_TO_SSO is not set (simulate default) + os.environ.pop("AUTO_REDIRECT_UI_LOGIN_TO_SSO", None) + + response = client.get("/.well-known/litellm-ui-config") + + assert response.status_code == 200 + data = response.json() + assert data["server_root_path"] == "/litellm" + assert data["proxy_base_url"] == "https://proxy.example.com" + assert data["auto_redirect_to_sso"] is False + assert data["sso_configured"] is True def test_ui_discovery_endpoints_with_sso_configured_but_auto_redirect_disabled(): @@ -107,6 +134,7 @@ def test_ui_discovery_endpoints_with_sso_configured_but_auto_redirect_disabled() assert data["server_root_path"] == "/litellm" assert data["proxy_base_url"] == "https://proxy.example.com" assert data["auto_redirect_to_sso"] is False + assert data["sso_configured"] is True def test_ui_discovery_endpoints_with_sso_not_configured_but_auto_redirect_enabled(): @@ -126,6 +154,7 @@ def test_ui_discovery_endpoints_with_sso_not_configured_but_auto_redirect_enable assert data["server_root_path"] == "/" assert data["proxy_base_url"] is None assert data["auto_redirect_to_sso"] 
is False + assert data["sso_configured"] is False def test_ui_discovery_endpoints_both_routes_return_same_data(): @@ -164,6 +193,7 @@ def test_ui_discovery_endpoints_with_admin_ui_disabled(): assert data["proxy_base_url"] is None assert data["auto_redirect_to_sso"] is False assert data["admin_ui_disabled"] is True + assert data["sso_configured"] is False def test_ui_discovery_endpoints_with_admin_ui_enabled(): @@ -184,4 +214,5 @@ def test_ui_discovery_endpoints_with_admin_ui_enabled(): assert data["proxy_base_url"] is None assert data["auto_redirect_to_sso"] is False assert data["admin_ui_disabled"] is False + assert data["sso_configured"] is False diff --git a/tests/test_litellm/proxy/guardrails/guardrail_hooks/test_presidio.py b/tests/test_litellm/proxy/guardrails/guardrail_hooks/test_presidio.py index 5ec9b13408e..f01c23f7116 100644 --- a/tests/test_litellm/proxy/guardrails/guardrail_hooks/test_presidio.py +++ b/tests/test_litellm/proxy/guardrails/guardrail_hooks/test_presidio.py @@ -6,6 +6,7 @@ import asyncio import os import sys +from contextlib import asynccontextmanager from unittest.mock import MagicMock, patch import pytest @@ -18,10 +19,41 @@ from litellm.proxy.guardrails.guardrail_hooks.presidio import ( _OPTIONAL_PresidioPIIMasking, ) +from litellm.exceptions import GuardrailRaisedException from litellm.types.guardrails import LitellmParams, PiiAction, PiiEntityType from litellm.types.utils import Choices, Message, ModelResponse +def _make_mock_session_iterator(json_response): + """Create a mock _get_session_iterator that yields a session returning json_response.""" + + @asynccontextmanager + async def mock_iterator(): + class MockResponse: + async def json(self): + return json_response + + async def __aenter__(self): + return self + + async def __aexit__(self, *args): + pass + + class MockSession: + def post(self, *args, **kwargs): + return MockResponse() + + async def __aenter__(self): + return self + + async def __aexit__(self, *args): + pass + + 
yield MockSession() + + return mock_iterator + + @pytest.fixture def presidio_guardrail(): """Create a Presidio guardrail instance for testing""" @@ -889,37 +921,132 @@ async def test_analyze_text_error_dict_handling(): output_parse_pii=False, ) - # Mock the HTTP response to return error dict - class MockResponse: - async def json(self): - return {"error": "No text provided"} - - async def __aenter__(self): - return self + with patch.object( + presidio, + "_get_session_iterator", + _make_mock_session_iterator({"error": "No text provided"}), + ): + result = await presidio.analyze_text( + text="some text", + presidio_config=None, + request_data={}, + ) + assert result == [], "Error dict should be handled gracefully" - async def __aexit__(self, *args): - pass + print("✓ analyze_text error dict handling test passed") - class MockSession: - def post(self, *args, **kwargs): - return MockResponse() - async def __aenter__(self): - return self +@pytest.mark.asyncio +async def test_analyze_text_string_response_handling(): + """ + Test that analyze_text handles string responses from Presidio API. - async def __aexit__(self, *args): - pass + When Presidio returns a string (e.g. error message from websearch/hosted models), + should handle gracefully instead of crashing with TypeError about mapping vs str. 
+ """ + presidio = _OPTIONAL_PresidioPIIMasking( + presidio_analyzer_api_base="http://mock-presidio:5002/", + presidio_anonymizer_api_base="http://mock-presidio:5001/", + output_parse_pii=False, + ) - with patch("aiohttp.ClientSession", return_value=MockSession()): + with patch.object( + presidio, + "_get_session_iterator", + _make_mock_session_iterator("Internal Server Error"), + ): result = await presidio.analyze_text( text="some text", presidio_config=None, request_data={}, ) - # Should return empty list when error dict is received - assert result == [], "Error dict should be handled gracefully" + assert result == [], "String response should be handled gracefully" - print("✓ analyze_text error dict handling test passed") + +@pytest.mark.asyncio +async def test_analyze_text_invalid_response_raises_when_block_configured(): + """ + When pii_entities_config has BLOCK and Presidio returns invalid response, + should raise GuardrailRaisedException (fail-closed) rather than silently allowing content. + """ + presidio = _OPTIONAL_PresidioPIIMasking( + presidio_analyzer_api_base="http://mock-presidio:5002/", + presidio_anonymizer_api_base="http://mock-presidio:5001/", + output_parse_pii=False, + pii_entities_config={PiiEntityType.CREDIT_CARD: PiiAction.BLOCK}, + ) + + with patch.object( + presidio, + "_get_session_iterator", + _make_mock_session_iterator("Internal Server Error"), + ): + with pytest.raises(GuardrailRaisedException) as exc_info: + await presidio.analyze_text( + text="some text", + presidio_config=None, + request_data={}, + ) + assert "BLOCK" in str(exc_info.value) or "Presidio" in str(exc_info.value) + + +@pytest.mark.asyncio +async def test_analyze_text_invalid_response_raises_when_mask_configured(): + """ + When pii_entities_config has MASK and Presidio returns invalid response, + should raise GuardrailRaisedException (fail-closed) because PII masking is expected. 
+ """ + presidio = _OPTIONAL_PresidioPIIMasking( + presidio_analyzer_api_base="http://mock-presidio:5002/", + presidio_anonymizer_api_base="http://mock-presidio:5001/", + output_parse_pii=False, + pii_entities_config={PiiEntityType.CREDIT_CARD: PiiAction.MASK}, + ) + + with patch.object( + presidio, + "_get_session_iterator", + _make_mock_session_iterator("Internal Server Error"), + ): + with pytest.raises(GuardrailRaisedException) as exc_info: + await presidio.analyze_text( + text="some text", + presidio_config=None, + request_data={}, + ) + assert "PII protection is configured" in str(exc_info.value) + + +@pytest.mark.asyncio +async def test_analyze_text_list_with_non_dict_items(): + """ + Test that analyze_text skips non-dict items in the result list. + + When Presidio returns a list containing strings (malformed response), + should skip invalid items and return parsed valid ones. + """ + presidio = _OPTIONAL_PresidioPIIMasking( + presidio_analyzer_api_base="http://mock-presidio:5002/", + presidio_anonymizer_api_base="http://mock-presidio:5001/", + output_parse_pii=False, + ) + + json_response = [ + {"entity_type": "PERSON", "start": 0, "end": 5, "score": 0.9}, + "invalid_string_item", + {"entity_type": "EMAIL", "start": 10, "end": 25, "score": 0.85}, + ] + with patch.object( + presidio, "_get_session_iterator", _make_mock_session_iterator(json_response) + ): + result = await presidio.analyze_text( + text="some text", + presidio_config=None, + request_data={}, + ) + assert len(result) == 2, "Should parse 2 valid dict items and skip the string" + assert result[0].get("entity_type") == "PERSON" + assert result[1].get("entity_type") == "EMAIL" @pytest.mark.asyncio diff --git a/tests/test_litellm/proxy/policy_engine/test_attachment_registry.py b/tests/test_litellm/proxy/policy_engine/test_attachment_registry.py index 1ed956fe99f..c853253eedd 100644 --- a/tests/test_litellm/proxy/policy_engine/test_attachment_registry.py +++ 
b/tests/test_litellm/proxy/policy_engine/test_attachment_registry.py @@ -192,6 +192,139 @@ def test_combined_team_and_model_attachment(self): assert "strict-policy" not in registry.get_attached_policies(context_wrong_team) +class TestTagBasedAttachments: + """Test tag-based policy attachment matching.""" + + def test_tag_matching_and_wildcards(self): + """Test tag matching: exact match, wildcard match, and no-match cases.""" + registry = AttachmentRegistry() + registry.load_attachments([ + {"policy": "hipaa-policy", "tags": ["healthcare"]}, + {"policy": "health-policy", "tags": ["health-*"]}, + ]) + + # Exact tag match + context = PolicyMatchContext( + team_alias="team", key_alias="key", model="gpt-4", + tags=["healthcare"], + ) + attached = registry.get_attached_policies(context) + assert "hipaa-policy" in attached + assert "health-policy" not in attached # "healthcare" doesn't match "health-*" + + # Wildcard tag match + context_wildcard = PolicyMatchContext( + team_alias="team", key_alias="key", model="gpt-4", + tags=["health-prod"], + ) + attached_wildcard = registry.get_attached_policies(context_wildcard) + assert "health-policy" in attached_wildcard + assert "hipaa-policy" not in attached_wildcard + + # No match — wrong tag + context_no_match = PolicyMatchContext( + team_alias="team", key_alias="key", model="gpt-4", + tags=["finance"], + ) + assert registry.get_attached_policies(context_no_match) == [] + + # No match — no tags on context + context_no_tags = PolicyMatchContext( + team_alias="team", key_alias="key", model="gpt-4", + tags=None, + ) + assert registry.get_attached_policies(context_no_tags) == [] + + def test_tag_combined_with_team(self): + """Test attachment with both tags and teams requires BOTH to match (AND logic).""" + registry = AttachmentRegistry() + registry.load_attachments([ + {"policy": "strict-policy", "teams": ["team-a"], "tags": ["healthcare"]}, + ]) + + # Match — both team and tag match + context = PolicyMatchContext( + 
team_alias="team-a", key_alias="key", model="gpt-4", + tags=["healthcare"], + ) + assert "strict-policy" in registry.get_attached_policies(context) + + # No match — tag matches but team doesn't + context_wrong_team = PolicyMatchContext( + team_alias="team-b", key_alias="key", model="gpt-4", + tags=["healthcare"], + ) + assert "strict-policy" not in registry.get_attached_policies(context_wrong_team) + + # No match — team matches but tag doesn't + context_wrong_tag = PolicyMatchContext( + team_alias="team-a", key_alias="key", model="gpt-4", + tags=["finance"], + ) + assert "strict-policy" not in registry.get_attached_policies(context_wrong_tag) + + +class TestMatchAttribution: + """Test get_attached_policies_with_reasons — the attribution logic that + powers response headers and the Policy Simulator UI.""" + + def test_reasons_for_global_tag_team_attachments(self): + """Test that match reasons correctly describe WHY each policy matched.""" + registry = AttachmentRegistry() + registry.load_attachments([ + {"policy": "global-baseline", "scope": "*"}, + {"policy": "hipaa-policy", "tags": ["healthcare"]}, + {"policy": "team-policy", "teams": ["health-team"]}, + ]) + + context = PolicyMatchContext( + team_alias="health-team", key_alias="key", model="gpt-4", + tags=["healthcare"], + ) + results = registry.get_attached_policies_with_reasons(context) + reasons = {r["policy_name"]: r["matched_via"] for r in results} + + assert reasons["global-baseline"] == "scope:*" + assert "tag:healthcare" in reasons["hipaa-policy"] + assert "team:health-team" in reasons["team-policy"] + + def test_tags_only_attachment_matches_any_team_key_model(self): + """Test the primary use case: tags-only attachment with no team/key/model + constraint matches any request that carries the tag.""" + registry = AttachmentRegistry() + registry.load_attachments([ + {"policy": "hipaa-guardrails", "tags": ["healthcare"]}, + ]) + + # Should match regardless of team/key/model + context = PolicyMatchContext( + 
team_alias="random-team", key_alias="random-key", model="claude-3", + tags=["healthcare"], + ) + attached = registry.get_attached_policies(context) + assert "hipaa-guardrails" in attached + + # Should not match without the tag + context_no_tag = PolicyMatchContext( + team_alias="random-team", key_alias="random-key", model="claude-3", + ) + assert registry.get_attached_policies(context_no_tag) == [] + + def test_attachment_with_no_scope_matches_everything(self): + """Test that an attachment with no scope/teams/keys/models/tags + matches everything because teams/keys/models default to ['*'].""" + registry = AttachmentRegistry() + registry.load_attachments([ + {"policy": "catch-all"}, + ]) + + context = PolicyMatchContext( + team_alias="any-team", key_alias="any-key", model="gpt-4", + ) + attached = registry.get_attached_policies(context) + assert "catch-all" in attached + + class TestAttachmentRegistrySingleton: """Test global singleton behavior.""" diff --git a/tests/test_litellm/proxy/policy_engine/test_policy_matcher.py b/tests/test_litellm/proxy/policy_engine/test_policy_matcher.py index c011f31af6a..fccb26496ac 100644 --- a/tests/test_litellm/proxy/policy_engine/test_policy_matcher.py +++ b/tests/test_litellm/proxy/policy_engine/test_policy_matcher.py @@ -64,6 +64,70 @@ def test_scope_global_wildcard(self): assert PolicyMatcher.scope_matches(scope, context) is True +class TestPolicyMatcherScopeMatchingWithTags: + """Test scope matching with tag patterns.""" + + def test_scope_tag_matching(self): + """Test scope tag matching: exact, wildcard, no-match, and empty context tags.""" + # Exact match + scope = PolicyScope(teams=["*"], keys=["*"], models=["*"], tags=["healthcare"]) + context = PolicyMatchContext( + team_alias="team", key_alias="key", model="gpt-4", + tags=["healthcare", "internal"], + ) + assert PolicyMatcher.scope_matches(scope, context) is True + + # Wildcard match + scope_wc = PolicyScope(teams=["*"], keys=["*"], models=["*"], tags=["health-*"]) + 
context_wc = PolicyMatchContext( + team_alias="team", key_alias="key", model="gpt-4", + tags=["health-prod"], + ) + assert PolicyMatcher.scope_matches(scope_wc, context_wc) is True + + # No match — wrong tag + context_wrong = PolicyMatchContext( + team_alias="team", key_alias="key", model="gpt-4", + tags=["finance"], + ) + assert PolicyMatcher.scope_matches(scope, context_wrong) is False + + # No match — context has no tags + context_none = PolicyMatchContext( + team_alias="team", key_alias="key", model="gpt-4", tags=None, + ) + assert PolicyMatcher.scope_matches(scope, context_none) is False + + # Scope without tags matches any context (opt-in semantics) + scope_no_tags = PolicyScope(teams=["*"], keys=["*"], models=["*"]) + assert PolicyMatcher.scope_matches(scope_no_tags, context) is True + + def test_scope_tags_and_team_combined(self): + """Test scope with both tags and team — both must match (AND logic).""" + scope = PolicyScope(teams=["team-a"], keys=["*"], models=["*"], tags=["healthcare"]) + + # Both match + context_both = PolicyMatchContext( + team_alias="team-a", key_alias="key", model="gpt-4", + tags=["healthcare"], + ) + assert PolicyMatcher.scope_matches(scope, context_both) is True + + # Tag matches, team doesn't + context_wrong_team = PolicyMatchContext( + team_alias="team-b", key_alias="key", model="gpt-4", + tags=["healthcare"], + ) + assert PolicyMatcher.scope_matches(scope, context_wrong_team) is False + + # Team matches, tag doesn't + context_wrong_tag = PolicyMatchContext( + team_alias="team-a", key_alias="key", model="gpt-4", + tags=["finance"], + ) + assert PolicyMatcher.scope_matches(scope, context_wrong_tag) is False + + class TestPolicyMatcherWithAttachments: """Test getting matching policies via attachments.""" diff --git a/ui/litellm-dashboard/e2e_tests/constants.ts b/ui/litellm-dashboard/e2e_tests/constants.ts index b07bd68fcf1..58b56af0a2b 100644 --- a/ui/litellm-dashboard/e2e_tests/constants.ts +++ 
b/ui/litellm-dashboard/e2e_tests/constants.ts @@ -1 +1,6 @@ export const ADMIN_STORAGE_PATH = "admin.storageState.json"; + +export const E2E_UPDATE_LIMITS_KEY_ID_PREFIX = "102c"; +export const E2E_DELETE_KEY_ID_PREFIX = "94a5"; +export const E2E_DELETE_KEY_NAME = "e2eDeleteKey"; +export const E2E_REGENERATE_KEY_ID_PREFIX = "593a"; diff --git a/ui/litellm-dashboard/e2e_tests/tests/keys/deleteKey.spec.ts b/ui/litellm-dashboard/e2e_tests/tests/keys/deleteKey.spec.ts new file mode 100644 index 00000000000..a5841316251 --- /dev/null +++ b/ui/litellm-dashboard/e2e_tests/tests/keys/deleteKey.spec.ts @@ -0,0 +1,25 @@ +import { test, expect } from "@playwright/test"; +import { ADMIN_STORAGE_PATH, E2E_DELETE_KEY_ID_PREFIX, E2E_DELETE_KEY_NAME } from "../../constants"; +import { Page } from "../../fixtures/pages"; +import { navigateToPage } from "../../helpers/navigation"; + +test.describe("Delete Key", () => { + test.use({ storageState: ADMIN_STORAGE_PATH }); + + test("Able to delete a key", async ({ page }) => { + await navigateToPage(page, Page.ApiKeys); + await expect(page.getByRole("button", { name: "Next" })).toBeVisible(); + await page + .locator("button", { + hasText: E2E_DELETE_KEY_ID_PREFIX, + }) + .click(); + await page.getByRole("button", { name: "Delete Key" }).click(); + await page.getByRole("textbox", { name: E2E_DELETE_KEY_NAME }).click(); + await page.getByRole("textbox", { name: E2E_DELETE_KEY_NAME }).fill(E2E_DELETE_KEY_NAME); + const deleteButton = page.getByRole("button", { name: "Delete", exact: true }); + await expect(deleteButton).toBeEnabled(); + await deleteButton.click(); + await expect(page.getByText("Key deleted successfully")).toBeVisible(); + }); +}); diff --git a/ui/litellm-dashboard/e2e_tests/tests/keys/regenerateKey.spec.ts b/ui/litellm-dashboard/e2e_tests/tests/keys/regenerateKey.spec.ts new file mode 100644 index 00000000000..0188a4f81ce --- /dev/null +++ b/ui/litellm-dashboard/e2e_tests/tests/keys/regenerateKey.spec.ts @@ -0,0 +1,21 @@ 
+import { test, expect } from "@playwright/test"; +import { ADMIN_STORAGE_PATH, E2E_REGENERATE_KEY_ID_PREFIX } from "../../constants"; +import { Page } from "../../fixtures/pages"; +import { navigateToPage } from "../../helpers/navigation"; + +test.describe("Regenerate Key", () => { + test.use({ storageState: ADMIN_STORAGE_PATH }); + + test("Able to regenerate a key", async ({ page }) => { + await navigateToPage(page, Page.ApiKeys); + await expect(page.getByRole("button", { name: "Next" })).toBeVisible(); + await page + .locator("button", { + hasText: E2E_REGENERATE_KEY_ID_PREFIX, + }) + .click(); + await page.getByRole("button", { name: "Regenerate Key" }).click(); + await page.getByRole("button", { name: "Regenerate", exact: true }).click(); + await expect(page.getByText("Virtual Key regenerated")).toBeVisible(); + }); +}); diff --git a/ui/litellm-dashboard/e2e_tests/tests/keys/updateKeyLimits.spec.ts b/ui/litellm-dashboard/e2e_tests/tests/keys/updateKeyLimits.spec.ts new file mode 100644 index 00000000000..6cae36272ab --- /dev/null +++ b/ui/litellm-dashboard/e2e_tests/tests/keys/updateKeyLimits.spec.ts @@ -0,0 +1,27 @@ +import { test, expect } from "@playwright/test"; +import { ADMIN_STORAGE_PATH, E2E_UPDATE_LIMITS_KEY_ID_PREFIX } from "../../constants"; +import { Page } from "../../fixtures/pages"; +import { navigateToPage } from "../../helpers/navigation"; + +test.describe("Update Key TPM and RPM Limits", () => { + test.use({ storageState: ADMIN_STORAGE_PATH }); + + test("Able to update a key's TPM and RPM limits", async ({ page }) => { + await navigateToPage(page, Page.ApiKeys); + await expect(page.getByRole("button", { name: "Next" })).toBeVisible(); + await page + .locator("button", { + hasText: E2E_UPDATE_LIMITS_KEY_ID_PREFIX, + }) + .click(); + await page.getByRole("tab", { name: "Settings" }).click(); + await page.getByRole("button", { name: "Edit Settings" }).click(); + await page.getByRole("spinbutton", { name: "TPM Limit" }).click(); + await 
page.getByRole("spinbutton", { name: "TPM Limit" }).fill("123"); + await page.getByRole("spinbutton", { name: "RPM Limit" }).click(); + await page.getByRole("spinbutton", { name: "RPM Limit" }).fill("456"); + await page.getByRole("button", { name: "Save Changes" }).click(); + await expect(page.getByRole("paragraph").filter({ hasText: "TPM: 123" })).toBeVisible(); + await expect(page.getByRole("paragraph").filter({ hasText: "RPM: 456" })).toBeVisible(); + }); +}); diff --git a/ui/litellm-dashboard/scripts/e2e_tests/neonHelperScripts.ts b/ui/litellm-dashboard/scripts/e2e_tests/neonHelperScripts.ts index 3078a0d90d2..089ad4e7926 100644 --- a/ui/litellm-dashboard/scripts/e2e_tests/neonHelperScripts.ts +++ b/ui/litellm-dashboard/scripts/e2e_tests/neonHelperScripts.ts @@ -1,4 +1,4 @@ -import { createApiClient } from "@neondatabase/api-client"; +import { createApiClient, EndpointType } from "@neondatabase/api-client"; import { config } from "dotenv"; import { resolve } from "path"; @@ -27,6 +27,13 @@ export async function createNeonE2ETestingBranch(projectId: string, parentBranch parent_id: parentBranchId, expires_at: expireAt ?? 
new Date(Date.now() + 1000 * 60 * 30).toISOString(), }, + endpoints: [ + { + type: EndpointType.ReadWrite, + autoscaling_limit_min_cu: 0.25, + autoscaling_limit_max_cu: 1, + }, + ], }); return response; } catch (error) { @@ -35,13 +42,15 @@ export async function createNeonE2ETestingBranch(projectId: string, parentBranch } export async function getNeonE2ETestingBranchConnectionString() { - await createNeonE2ETestingBranch(PROJECT_ID, PARENT_BRANCH); - + const createBranchResponse = await createNeonE2ETestingBranch(PROJECT_ID, PARENT_BRANCH); + const projectId = createBranchResponse.data.branch.project_id; const response = await apiClient.getConnectionUri({ database_name: NEON_E2E_UI_TEST_DB_NAME, role_name: "neondb_owner", - projectId: PROJECT_ID, + projectId: projectId, }); console.log("connection string:", response.data.uri); return response.data.uri; } + +getNeonE2ETestingBranchConnectionString(); diff --git a/ui/litellm-dashboard/src/app/(dashboard)/components/Sidebar2.tsx b/ui/litellm-dashboard/src/app/(dashboard)/components/Sidebar2.tsx index 405f8329b67..a74d3c108d6 100644 --- a/ui/litellm-dashboard/src/app/(dashboard)/components/Sidebar2.tsx +++ b/ui/litellm-dashboard/src/app/(dashboard)/components/Sidebar2.tsx @@ -31,7 +31,7 @@ import { import * as React from "react"; import { useRouter, usePathname } from "next/navigation"; import { all_admin_roles, internalUserRoles, isAdminRole, rolesWithWriteAccess } from "@/utils/roles"; -import UsageIndicator from "@/components/usage_indicator"; +import UsageIndicator from "@/components/UsageIndicator"; import { serverRootPath } from "@/components/networking"; const { Sider } = Layout; @@ -64,7 +64,7 @@ const getBasePath = () => { const raw = process.env.NEXT_PUBLIC_BASE_URL ?? ""; const trimmed = raw.replace(/^\/+|\/+$/g, ""); // strip leading/trailing slashes const uiPath = trimmed ? 
`/${trimmed}/` : "/"; - + // If serverRootPath is set and not "/", prepend it to the UI path if (serverRootPath && serverRootPath !== "/") { // Remove trailing slash from serverRootPath and ensure uiPath has no leading slash for proper joining @@ -72,7 +72,7 @@ const getBasePath = () => { const cleanUiPath = uiPath.replace(/^\/+/, ""); return `${cleanServerRoot}/${cleanUiPath}`; } - + return uiPath; }; @@ -153,170 +153,170 @@ const toHref = (slugOrPath: string) => { // ----- Menu config (unchanged labels/icons; same appearance) ----- const menuItems: MenuItemCfg[] = [ - { key: "1", page: "api-keys", label: "Virtual Keys", icon: }, - { - key: "3", - page: "llm-playground", - label: "Test Key", - icon: , - roles: rolesWithWriteAccess, - }, - { - key: "2", - page: "models", - label: "Models + Endpoints", - icon: , - roles: rolesWithWriteAccess, - }, - { - key: "12", - page: "new_usage", - label: "Usage", - icon: , - roles: [...all_admin_roles, ...internalUserRoles], - }, - { key: "6", page: "teams", label: "Teams", icon: }, - { - key: "17", - page: "organizations", - label: "Organizations", - icon: , - roles: all_admin_roles, - }, - { - key: "5", - page: "users", - label: "Internal Users", - icon: , - roles: all_admin_roles, - }, - { key: "14", page: "api_ref", label: "API Reference", icon: }, - { - key: "16", - page: "model-hub-table", - label: "Model Hub", - icon: , - }, - { key: "15", page: "logs", label: "Logs", icon: }, - { - key: "11", - page: "guardrails", - label: "Guardrails", - icon: , - roles: all_admin_roles, - }, - { - key: "28", - page: "policies", - label: "Policies", - icon: , - roles: all_admin_roles, - }, - { - key: "26", - page: "tools", - label: "Tools", - icon: , - children: [ - { key: "18", page: "mcp-servers", label: "MCP Servers", icon: }, - { - key: "21", - page: "vector-stores", - label: "Vector Stores", - icon: , - roles: all_admin_roles, - }, - ], - }, - { - key: "experimental", - page: "experimental", - label: "Experimental", - icon: , - 
children: [ - { - key: "9", - page: "caching", - label: "Caching", - icon: , - roles: all_admin_roles, - }, - { - key: "25", - page: "prompts", - label: "Prompts", - icon: , - roles: all_admin_roles, - }, - { - key: "10", - page: "budgets", - label: "Budgets", - icon: , - roles: all_admin_roles, - }, - { - key: "20", - page: "transform-request", - label: "API Playground", - icon: , - roles: [...all_admin_roles, ...internalUserRoles], - }, - { - key: "19", - page: "tag-management", - label: "Tag Management", - icon: , - roles: all_admin_roles, - }, - { - key: "27", - page: "claude-code-plugins", - label: "Claude Code Plugins", - icon: , - roles: all_admin_roles, - }, - { key: "4", page: "usage", label: "Old Usage", icon: }, - ], - }, - { - key: "settings", - page: "settings", - label: "Settings", - icon: , - roles: all_admin_roles, - children: [ - { - key: "11", - page: "general-settings", - label: "Router Settings", - icon: , - roles: all_admin_roles, - }, - { - key: "8", - page: "settings", - label: "Logging & Alerts", - icon: , - roles: all_admin_roles, - }, - { - key: "13", - page: "admin-panel", - label: "Admin Settings", - icon: , - roles: all_admin_roles, - }, - { - key: "14", - page: "ui-theme", - label: "UI Theme", - icon: , - roles: all_admin_roles, - }, - ], - }, - ]; + { key: "1", page: "api-keys", label: "Virtual Keys", icon: }, + { + key: "3", + page: "llm-playground", + label: "Test Key", + icon: , + roles: rolesWithWriteAccess, + }, + { + key: "2", + page: "models", + label: "Models + Endpoints", + icon: , + roles: rolesWithWriteAccess, + }, + { + key: "12", + page: "new_usage", + label: "Usage", + icon: , + roles: [...all_admin_roles, ...internalUserRoles], + }, + { key: "6", page: "teams", label: "Teams", icon: }, + { + key: "17", + page: "organizations", + label: "Organizations", + icon: , + roles: all_admin_roles, + }, + { + key: "5", + page: "users", + label: "Internal Users", + icon: , + roles: all_admin_roles, + }, + { key: "14", page: 
"api_ref", label: "API Reference", icon: }, + { + key: "16", + page: "model-hub-table", + label: "Model Hub", + icon: , + }, + { key: "15", page: "logs", label: "Logs", icon: }, + { + key: "11", + page: "guardrails", + label: "Guardrails", + icon: , + roles: all_admin_roles, + }, + { + key: "28", + page: "policies", + label: "Policies", + icon: , + roles: all_admin_roles, + }, + { + key: "26", + page: "tools", + label: "Tools", + icon: , + children: [ + { key: "18", page: "mcp-servers", label: "MCP Servers", icon: }, + { + key: "21", + page: "vector-stores", + label: "Vector Stores", + icon: , + roles: all_admin_roles, + }, + ], + }, + { + key: "experimental", + page: "experimental", + label: "Experimental", + icon: , + children: [ + { + key: "9", + page: "caching", + label: "Caching", + icon: , + roles: all_admin_roles, + }, + { + key: "25", + page: "prompts", + label: "Prompts", + icon: , + roles: all_admin_roles, + }, + { + key: "10", + page: "budgets", + label: "Budgets", + icon: , + roles: all_admin_roles, + }, + { + key: "20", + page: "transform-request", + label: "API Playground", + icon: , + roles: [...all_admin_roles, ...internalUserRoles], + }, + { + key: "19", + page: "tag-management", + label: "Tag Management", + icon: , + roles: all_admin_roles, + }, + { + key: "27", + page: "claude-code-plugins", + label: "Claude Code Plugins", + icon: , + roles: all_admin_roles, + }, + { key: "4", page: "usage", label: "Old Usage", icon: }, + ], + }, + { + key: "settings", + page: "settings", + label: "Settings", + icon: , + roles: all_admin_roles, + children: [ + { + key: "11", + page: "general-settings", + label: "Router Settings", + icon: , + roles: all_admin_roles, + }, + { + key: "8", + page: "settings", + label: "Logging & Alerts", + icon: , + roles: all_admin_roles, + }, + { + key: "13", + page: "admin-panel", + label: "Admin Settings", + icon: , + roles: all_admin_roles, + }, + { + key: "14", + page: "ui-theme", + label: "UI Theme", + icon: , + roles: 
all_admin_roles, + }, + ], + }, +]; const Sidebar2: React.FC = ({ accessToken, userRole, defaultSelectedKey, collapsed = false }) => { const router = useRouter(); diff --git a/ui/litellm-dashboard/src/app/(dashboard)/hooks/uiConfig/useUIConfig.test.ts b/ui/litellm-dashboard/src/app/(dashboard)/hooks/uiConfig/useUIConfig.test.ts index 6429aeafb5a..aba5dddf13d 100644 --- a/ui/litellm-dashboard/src/app/(dashboard)/hooks/uiConfig/useUIConfig.test.ts +++ b/ui/litellm-dashboard/src/app/(dashboard)/hooks/uiConfig/useUIConfig.test.ts @@ -23,6 +23,7 @@ vi.mock("../common/queryKeysFactory", () => ({ // Mock data const mockUIConfig: LiteLLMWellKnownUiConfig = { + sso_configured: true, server_root_path: "/api", proxy_base_url: "https://proxy.example.com", auto_redirect_to_sso: true, @@ -99,6 +100,7 @@ describe("useUIConfig", () => { server_root_path: "/v1", proxy_base_url: null, auto_redirect_to_sso: false, + sso_configured: false, admin_ui_disabled: true, }; diff --git a/ui/litellm-dashboard/src/app/(dashboard)/hooks/useAuthorized.test.ts b/ui/litellm-dashboard/src/app/(dashboard)/hooks/useAuthorized.test.ts index ef4a779b50b..76a3129d6d7 100644 --- a/ui/litellm-dashboard/src/app/(dashboard)/hooks/useAuthorized.test.ts +++ b/ui/litellm-dashboard/src/app/(dashboard)/hooks/useAuthorized.test.ts @@ -90,6 +90,7 @@ describe("useAuthorized", () => { proxy_base_url: null, auto_redirect_to_sso: false, admin_ui_disabled: false, + sso_configured: false, }); const decodedPayload = { @@ -131,6 +132,7 @@ describe("useAuthorized", () => { proxy_base_url: null, auto_redirect_to_sso: false, admin_ui_disabled: false, + sso_configured: false, }); decodeTokenMock.mockReturnValue(null); @@ -155,6 +157,7 @@ describe("useAuthorized", () => { proxy_base_url: null, auto_redirect_to_sso: false, admin_ui_disabled: true, + sso_configured: false, }); const decodedPayload = { @@ -190,6 +193,7 @@ describe("useAuthorized", () => { proxy_base_url: null, auto_redirect_to_sso: false, admin_ui_disabled: false, 
+ sso_configured: false, }); decodeTokenMock.mockReturnValue(null); @@ -212,6 +216,7 @@ describe("useAuthorized", () => { proxy_base_url: null, auto_redirect_to_sso: false, admin_ui_disabled: false, + sso_configured: false, }); const decodedPayload = { diff --git a/ui/litellm-dashboard/src/app/(dashboard)/hooks/useDisableUsageIndicator.test.ts b/ui/litellm-dashboard/src/app/(dashboard)/hooks/useDisableUsageIndicator.test.ts new file mode 100644 index 00000000000..bd0e69c0de3 --- /dev/null +++ b/ui/litellm-dashboard/src/app/(dashboard)/hooks/useDisableUsageIndicator.test.ts @@ -0,0 +1,190 @@ +import { describe, it, expect, beforeEach, afterEach, vi } from "vitest"; +import { act, renderHook, waitFor } from "@testing-library/react"; +import { useDisableUsageIndicator } from "./useDisableUsageIndicator"; +import { LOCAL_STORAGE_EVENT } from "@/utils/localStorageUtils"; + +describe("useDisableUsageIndicator", () => { + const STORAGE_KEY = "disableUsageIndicator"; + + beforeEach(() => { + localStorage.clear(); + vi.clearAllMocks(); + }); + + afterEach(() => { + localStorage.clear(); + }); + + it("should return false when localStorage is empty", () => { + const { result } = renderHook(() => useDisableUsageIndicator()); + + expect(result.current).toBe(false); + }); + + it("should return false when localStorage value is not 'true'", () => { + localStorage.setItem(STORAGE_KEY, "false"); + + const { result } = renderHook(() => useDisableUsageIndicator()); + + expect(result.current).toBe(false); + }); + + it("should return true when localStorage value is 'true'", () => { + localStorage.setItem(STORAGE_KEY, "true"); + + const { result } = renderHook(() => useDisableUsageIndicator()); + + expect(result.current).toBe(true); + }); + + it("should return false when localStorage value is an empty string", () => { + localStorage.setItem(STORAGE_KEY, ""); + + const { result } = renderHook(() => useDisableUsageIndicator()); + + expect(result.current).toBe(false); + }); + + it("should 
update when storage event fires for the correct key", async () => { + const { result } = renderHook(() => useDisableUsageIndicator()); + + expect(result.current).toBe(false); + + await act(async () => { + localStorage.setItem(STORAGE_KEY, "true"); + const storageEvent = new StorageEvent("storage", { + key: STORAGE_KEY, + newValue: "true", + }); + window.dispatchEvent(storageEvent); + }); + + await waitFor(() => { + expect(result.current).toBe(true); + }); + }); + + it("should not update when storage event fires for a different key", () => { + localStorage.setItem(STORAGE_KEY, "false"); + const { result } = renderHook(() => useDisableUsageIndicator()); + + expect(result.current).toBe(false); + + const storageEvent = new StorageEvent("storage", { + key: "otherKey", + newValue: "true", + }); + window.dispatchEvent(storageEvent); + + expect(result.current).toBe(false); + }); + + it("should update when custom LOCAL_STORAGE_EVENT fires for the correct key", async () => { + const { result } = renderHook(() => useDisableUsageIndicator()); + + expect(result.current).toBe(false); + + await act(async () => { + localStorage.setItem(STORAGE_KEY, "true"); + const customEvent = new CustomEvent(LOCAL_STORAGE_EVENT, { + detail: { key: STORAGE_KEY }, + }); + window.dispatchEvent(customEvent); + }); + + await waitFor(() => { + expect(result.current).toBe(true); + }); + }); + + it("should not update when custom LOCAL_STORAGE_EVENT fires for a different key", () => { + localStorage.setItem(STORAGE_KEY, "false"); + const { result } = renderHook(() => useDisableUsageIndicator()); + + expect(result.current).toBe(false); + + const customEvent = new CustomEvent(LOCAL_STORAGE_EVENT, { + detail: { key: "otherKey" }, + }); + window.dispatchEvent(customEvent); + + expect(result.current).toBe(false); + }); + + it("should update when localStorage changes from false to true via custom event", async () => { + localStorage.setItem(STORAGE_KEY, "false"); + const { result } = renderHook(() => 
useDisableUsageIndicator()); + + expect(result.current).toBe(false); + + await act(async () => { + localStorage.setItem(STORAGE_KEY, "true"); + const customEvent = new CustomEvent(LOCAL_STORAGE_EVENT, { + detail: { key: STORAGE_KEY }, + }); + window.dispatchEvent(customEvent); + }); + + await waitFor(() => { + expect(result.current).toBe(true); + }); + }); + + it("should update when localStorage changes from true to false via storage event", async () => { + localStorage.setItem(STORAGE_KEY, "true"); + const { result } = renderHook(() => useDisableUsageIndicator()); + + expect(result.current).toBe(true); + + await act(async () => { + localStorage.setItem(STORAGE_KEY, "false"); + const storageEvent = new StorageEvent("storage", { + key: STORAGE_KEY, + newValue: "false", + }); + window.dispatchEvent(storageEvent); + }); + + await waitFor(() => { + expect(result.current).toBe(false); + }); + }); + + it("should cleanup event listeners on unmount", () => { + const addEventListenerSpy = vi.spyOn(window, "addEventListener"); + const removeEventListenerSpy = vi.spyOn(window, "removeEventListener"); + + const { unmount } = renderHook(() => useDisableUsageIndicator()); + + expect(addEventListenerSpy).toHaveBeenCalledTimes(2); + expect(addEventListenerSpy).toHaveBeenCalledWith("storage", expect.any(Function)); + expect(addEventListenerSpy).toHaveBeenCalledWith(LOCAL_STORAGE_EVENT, expect.any(Function)); + + unmount(); + + expect(removeEventListenerSpy).toHaveBeenCalledTimes(2); + expect(removeEventListenerSpy).toHaveBeenCalledWith("storage", expect.any(Function)); + expect(removeEventListenerSpy).toHaveBeenCalledWith(LOCAL_STORAGE_EVENT, expect.any(Function)); + }); + + it("should handle multiple hooks independently", async () => { + const { result: result1 } = renderHook(() => useDisableUsageIndicator()); + const { result: result2 } = renderHook(() => useDisableUsageIndicator()); + + expect(result1.current).toBe(false); + expect(result2.current).toBe(false); + + await 
act(async () => { + localStorage.setItem(STORAGE_KEY, "true"); + const customEvent = new CustomEvent(LOCAL_STORAGE_EVENT, { + detail: { key: STORAGE_KEY }, + }); + window.dispatchEvent(customEvent); + }); + + await waitFor(() => { + expect(result1.current).toBe(true); + expect(result2.current).toBe(true); + }); + }); +}); diff --git a/ui/litellm-dashboard/src/app/(dashboard)/hooks/useDisableUsageIndicator.ts b/ui/litellm-dashboard/src/app/(dashboard)/hooks/useDisableUsageIndicator.ts new file mode 100644 index 00000000000..7f4e2295090 --- /dev/null +++ b/ui/litellm-dashboard/src/app/(dashboard)/hooks/useDisableUsageIndicator.ts @@ -0,0 +1,33 @@ +import { getLocalStorageItem, LOCAL_STORAGE_EVENT } from "@/utils/localStorageUtils"; +import { useSyncExternalStore } from "react"; + +function subscribe(callback: () => void) { + const onStorage = (e: StorageEvent) => { + if (e.key === "disableUsageIndicator") { + callback(); + } + }; + + const onCustom = (e: Event) => { + const { key } = (e as CustomEvent).detail; + if (key === "disableUsageIndicator") { + callback(); + } + }; + + window.addEventListener("storage", onStorage); + window.addEventListener(LOCAL_STORAGE_EVENT, onCustom); + + return () => { + window.removeEventListener("storage", onStorage); + window.removeEventListener(LOCAL_STORAGE_EVENT, onCustom); + }; +} + +function getSnapshot() { + return getLocalStorageItem("disableUsageIndicator") === "true"; +} + +export function useDisableUsageIndicator() { + return useSyncExternalStore(subscribe, getSnapshot); +} diff --git a/ui/litellm-dashboard/src/app/login/LoginPage.test.tsx b/ui/litellm-dashboard/src/app/login/LoginPage.test.tsx index 79834512605..ad2dde2da83 100644 --- a/ui/litellm-dashboard/src/app/login/LoginPage.test.tsx +++ b/ui/litellm-dashboard/src/app/login/LoginPage.test.tsx @@ -64,7 +64,12 @@ describe("LoginPage", () => { it("should render", async () => { (useUIConfig as ReturnType).mockReturnValue({ - data: { auto_redirect_to_sso: false, 
server_root_path: "/", proxy_base_url: null }, + data: { + auto_redirect_to_sso: false, + server_root_path: "/", + proxy_base_url: null, + sso_configured: false, + }, isLoading: false, }); (getCookie as ReturnType).mockReturnValue(null); @@ -84,7 +89,12 @@ describe("LoginPage", () => { it("should call router.replace to dashboard when jwt is valid", async () => { const validToken = "valid-token"; (useUIConfig as ReturnType).mockReturnValue({ - data: { auto_redirect_to_sso: false, server_root_path: "/", proxy_base_url: null }, + data: { + auto_redirect_to_sso: false, + server_root_path: "/", + proxy_base_url: null, + sso_configured: false, + }, isLoading: false, }); (getCookie as ReturnType).mockReturnValue(validToken); @@ -105,7 +115,12 @@ describe("LoginPage", () => { it("should call router.push to SSO when jwt is invalid and auto_redirect_to_sso is true", async () => { const invalidToken = "invalid-token"; (useUIConfig as ReturnType).mockReturnValue({ - data: { auto_redirect_to_sso: true, server_root_path: "/", proxy_base_url: null }, + data: { + auto_redirect_to_sso: true, + server_root_path: "/", + proxy_base_url: null, + sso_configured: true, + }, isLoading: false, }); (getCookie as ReturnType).mockReturnValue(invalidToken); @@ -126,7 +141,12 @@ describe("LoginPage", () => { it("should not call router when jwt is invalid and auto_redirect_to_sso is false", async () => { const invalidToken = "invalid-token"; (useUIConfig as ReturnType).mockReturnValue({ - data: { auto_redirect_to_sso: false, server_root_path: "/", proxy_base_url: null }, + data: { + auto_redirect_to_sso: false, + server_root_path: "/", + proxy_base_url: null, + sso_configured: false, + }, isLoading: false, }); (getCookie as ReturnType).mockReturnValue(invalidToken); @@ -150,7 +170,12 @@ describe("LoginPage", () => { it("should send user to dashboard when jwt is valid even if auto_redirect_to_sso is true", async () => { const validToken = "valid-token"; (useUIConfig as 
ReturnType).mockReturnValue({ - data: { auto_redirect_to_sso: true, server_root_path: "/", proxy_base_url: null }, + data: { + auto_redirect_to_sso: true, + server_root_path: "/", + proxy_base_url: null, + sso_configured: true, + }, isLoading: false, }); (getCookie as ReturnType).mockReturnValue(validToken); @@ -172,7 +197,12 @@ describe("LoginPage", () => { it("should show alert when admin_ui_disabled is true", async () => { (useUIConfig as ReturnType).mockReturnValue({ - data: { admin_ui_disabled: true, server_root_path: "/", proxy_base_url: null }, + data: { + admin_ui_disabled: true, + server_root_path: "/", + proxy_base_url: null, + sso_configured: false, + }, isLoading: false, }); (getCookie as ReturnType).mockReturnValue(null); @@ -192,4 +222,60 @@ describe("LoginPage", () => { expect(mockPush).not.toHaveBeenCalled(); expect(mockReplace).not.toHaveBeenCalled(); }); + + it("should show Login with SSO button when sso_configured is true", async () => { + (useUIConfig as ReturnType).mockReturnValue({ + data: { + auto_redirect_to_sso: false, + server_root_path: "/", + proxy_base_url: null, + sso_configured: true, + }, + isLoading: false, + }); + (getCookie as ReturnType).mockReturnValue(null); + (isJwtExpired as ReturnType).mockReturnValue(true); + + const queryClient = createQueryClient(); + render( + + + , + ); + + await waitFor(() => { + expect(screen.getByRole("heading", { name: "Login" })).toBeInTheDocument(); + }); + + expect(screen.getByRole("button", { name: "Login with SSO" })).toBeInTheDocument(); + }); + + it("should show disabled Login with SSO button with popover when sso_configured is false", async () => { + (useUIConfig as ReturnType).mockReturnValue({ + data: { + auto_redirect_to_sso: false, + server_root_path: "/", + proxy_base_url: null, + sso_configured: false, + }, + isLoading: false, + }); + (getCookie as ReturnType).mockReturnValue(null); + (isJwtExpired as ReturnType).mockReturnValue(true); + + const queryClient = createQueryClient(); + 
render( + + + , + ); + + await waitFor(() => { + expect(screen.getByRole("heading", { name: "Login" })).toBeInTheDocument(); + }); + + const ssoButton = screen.getByRole("button", { name: "Login with SSO" }); + expect(ssoButton).toBeInTheDocument(); + expect(ssoButton).toBeDisabled(); + }); }); diff --git a/ui/litellm-dashboard/src/app/login/LoginPage.tsx b/ui/litellm-dashboard/src/app/login/LoginPage.tsx index 620cb41dfee..a05fa4e214e 100644 --- a/ui/litellm-dashboard/src/app/login/LoginPage.tsx +++ b/ui/litellm-dashboard/src/app/login/LoginPage.tsx @@ -8,7 +8,7 @@ import { getCookie } from "@/utils/cookieUtils"; import { isJwtExpired } from "@/utils/jwtUtils"; import { InfoCircleOutlined } from "@ant-design/icons"; import { QueryClient, QueryClientProvider } from "@tanstack/react-query"; -import { Alert, Button, Card, Form, Input, Space, Typography } from "antd"; +import { Alert, Button, Card, Form, Input, Popover, Space, Typography } from "antd"; import { useRouter } from "next/navigation"; import { useEffect, useState } from "react"; @@ -179,8 +179,39 @@ function LoginPageContent() { {isLoginLoading ? "Logging in..." : "Login"} + + {!uiConfig?.sso_configured ? ( + + + + ) : ( + + )} + + {uiConfig?.sso_configured && ( + Single Sign-On (SSO) is enabled. LiteLLM no longer automatically redirects to the SSO login flow upon loading this page. 
To re-enable auto-redirect-to-SSO, set AUTO_REDIRECT_UI_LOGIN_TO_SSO=true in your environment configuration.} + /> + )} ); diff --git a/ui/litellm-dashboard/src/components/AIHub/ModelHubTable.test.tsx b/ui/litellm-dashboard/src/components/AIHub/ModelHubTable.test.tsx index 0a5cd17e571..ee59ac84ece 100644 --- a/ui/litellm-dashboard/src/components/AIHub/ModelHubTable.test.tsx +++ b/ui/litellm-dashboard/src/components/AIHub/ModelHubTable.test.tsx @@ -71,6 +71,7 @@ describe("ModelHubTable", () => { proxy_base_url: "http://localhost:4000", auto_redirect_to_sso: false, admin_ui_disabled: false, + sso_configured: false, }); vi.mocked(networking.modelHubPublicModelsCall).mockResolvedValue([]); vi.mocked(networking.getUiSettings).mockResolvedValue({ @@ -140,6 +141,7 @@ describe("ModelHubTable", () => { proxy_base_url: "http://localhost:4000", auto_redirect_to_sso: false, admin_ui_disabled: false, + sso_configured: false, }); modelHubPublicModelsCallMock.mockResolvedValue([]); vi.mocked(networking.getUiSettings).mockResolvedValue({ diff --git a/ui/litellm-dashboard/src/components/Navbar/UserDropdown/UserDropdown.tsx b/ui/litellm-dashboard/src/components/Navbar/UserDropdown/UserDropdown.tsx index f80af33f9e2..90e02ae447b 100644 --- a/ui/litellm-dashboard/src/components/Navbar/UserDropdown/UserDropdown.tsx +++ b/ui/litellm-dashboard/src/components/Navbar/UserDropdown/UserDropdown.tsx @@ -1,5 +1,6 @@ import useAuthorized from "@/app/(dashboard)/hooks/useAuthorized"; import { useDisableShowPrompts } from "@/app/(dashboard)/hooks/useDisableShowPrompts"; +import { useDisableUsageIndicator } from "@/app/(dashboard)/hooks/useDisableUsageIndicator"; import { emitLocalStorageChange, getLocalStorageItem, @@ -27,6 +28,7 @@ interface UserDropdownProps { const UserDropdown: React.FC = ({ onLogout }) => { const { userId, userEmail, userRole, premiumUser } = useAuthorized(); const disableShowPrompts = useDisableShowPrompts(); + const disableUsageIndicator = useDisableUsageIndicator(); const 
[disableShowNewBadge, setDisableShowNewBadge] = useState(false); useEffect(() => { @@ -129,6 +131,23 @@ const UserDropdown: React.FC = ({ onLogout }) => { aria-label="Toggle hide all prompts" /> + + Hide Usage Indicator + { + if (checked) { + setLocalStorageItem("disableUsageIndicator", "true"); + emitLocalStorageChange("disableUsageIndicator"); + } else { + removeLocalStorageItem("disableUsageIndicator"); + emitLocalStorageChange("disableUsageIndicator"); + } + }} + aria-label="Toggle hide usage indicator" + /> + ); diff --git a/ui/litellm-dashboard/src/components/UsageIndicator.test.tsx b/ui/litellm-dashboard/src/components/UsageIndicator.test.tsx new file mode 100644 index 00000000000..71a37263980 --- /dev/null +++ b/ui/litellm-dashboard/src/components/UsageIndicator.test.tsx @@ -0,0 +1,186 @@ +import React from "react"; +import { describe, it, expect, vi, beforeEach } from "vitest"; +import { render, screen, waitFor } from "@testing-library/react"; +import userEvent from "@testing-library/user-event"; +import UsageIndicator from "./UsageIndicator"; + +vi.mock("./networking", () => ({ + getRemainingUsers: vi.fn(), +})); + +vi.mock("@/app/(dashboard)/hooks/useDisableUsageIndicator", () => ({ + useDisableUsageIndicator: vi.fn(() => false), +})); + +import { getRemainingUsers } from "./networking"; + +const mockGetRemainingUsers = vi.mocked(getRemainingUsers); + +const DEFAULT_USAGE_DATA = { + total_users: 100, + total_users_used: 1, + total_users_remaining: 99, + total_teams: null, + total_teams_used: 0, + total_teams_remaining: null, +}; + +describe("UsageIndicator", () => { + beforeEach(() => { + vi.clearAllMocks(); + mockGetRemainingUsers.mockResolvedValue(DEFAULT_USAGE_DATA); + }); + + it("should render when given access token and usage data loads", async () => { + render(); + + await screen.findByText("Usage"); + + expect(screen.getByText("Usage")).toBeInTheDocument(); + }); + + it("should not show Near limit when users usage is below 80% (1/100 -> 1%)", 
async () => { + render(); + + await screen.findByText("Usage"); + + expect(screen.queryByText("Near limit")).not.toBeInTheDocument(); + }); + + it("should render nothing when both total_users and total_teams are null", async () => { + mockGetRemainingUsers.mockResolvedValue({ + total_users: null, + total_teams: null, + total_users_used: 520, + total_teams_used: 4, + total_teams_remaining: null, + total_users_remaining: null, + }); + + render(); + + await waitFor(() => { + expect(screen.queryByText("Usage")).not.toBeInTheDocument(); + expect(screen.queryByText("Loading...")).not.toBeInTheDocument(); + }); + }); + + it("should show Near limit for Teams when at 80% usage (4/5)", async () => { + mockGetRemainingUsers.mockResolvedValue({ + total_users: null, + total_users_used: 0, + total_users_remaining: null, + total_teams: 5, + total_teams_used: 4, + total_teams_remaining: 1, + }); + + render(); + + await screen.findByText("Usage"); + + expect(screen.getByText("Teams")).toBeInTheDocument(); + expect(screen.getByText("Near limit")).toBeInTheDocument(); + }); + + it("should show Over limit for Users when usage exceeds 100% (105/100)", async () => { + mockGetRemainingUsers.mockResolvedValue({ + total_users: 100, + total_users_used: 105, + total_users_remaining: -5, + total_teams: null, + total_teams_used: 0, + total_teams_remaining: null, + }); + + render(); + + await screen.findByText("Usage"); + + expect(screen.getByText("Users")).toBeInTheDocument(); + expect(screen.getByText("Over limit")).toBeInTheDocument(); + }); + + it("should show Over limit for Teams when usage exceeds 100%", async () => { + mockGetRemainingUsers.mockResolvedValue({ + total_users: null, + total_users_used: 0, + total_users_remaining: null, + total_teams: 10, + total_teams_used: 12, + total_teams_remaining: -2, + }); + + render(); + + await screen.findByText("Usage"); + + expect(screen.getByText("Teams")).toBeInTheDocument(); + expect(screen.getByText("Over limit")).toBeInTheDocument(); + }); + 
+ it("should render nothing when accessToken is null", () => { + render(); + + expect(mockGetRemainingUsers).not.toHaveBeenCalled(); + expect(screen.queryByText("Usage")).not.toBeInTheDocument(); + }); + + it("should render nothing when disableUsageIndicator is true", async () => { + const { useDisableUsageIndicator } = await import("@/app/(dashboard)/hooks/useDisableUsageIndicator"); + (useDisableUsageIndicator as ReturnType).mockReturnValue(true); + + render(); + + await waitFor(() => { + expect(screen.queryByText("Usage")).not.toBeInTheDocument(); + }); + + (useDisableUsageIndicator as ReturnType).mockReturnValue(false); + }); + + it("should show Loading while fetching", () => { + mockGetRemainingUsers.mockImplementation(() => new Promise(() => {})); + + render(); + + expect(screen.getByText("Loading...")).toBeInTheDocument(); + }); + + it("should show error message when fetch fails", async () => { + const consoleSpy = vi.spyOn(console, "error").mockImplementation(() => {}); + mockGetRemainingUsers.mockRejectedValue(new Error("Network error")); + + render(); + + expect(await screen.findByText("Failed to load usage data")).toBeInTheDocument(); + + consoleSpy.mockRestore(); + }); + + it("should minimize when user clicks minimize button", async () => { + const user = userEvent.setup(); + render(); + + await screen.findByText("Usage"); + + const minimizeButton = screen.getByTitle("Minimize"); + await user.click(minimizeButton); + + expect(screen.queryByText("Users")).not.toBeInTheDocument(); + expect(screen.getByTitle("Show usage details")).toBeInTheDocument(); + }); + + it("should restore from minimized when user clicks restore button", async () => { + const user = userEvent.setup(); + render(); + + await screen.findByText("Usage"); + + await user.click(screen.getByTitle("Minimize")); + await user.click(screen.getByTitle("Show usage details")); + + expect(screen.getByText("Usage")).toBeInTheDocument(); + expect(screen.getByText("Users")).toBeInTheDocument(); + }); 
+}); diff --git a/ui/litellm-dashboard/src/components/usage_indicator.tsx b/ui/litellm-dashboard/src/components/UsageIndicator.tsx similarity index 98% rename from ui/litellm-dashboard/src/components/usage_indicator.tsx rename to ui/litellm-dashboard/src/components/UsageIndicator.tsx index ed5e8e07555..3976e4d3d6f 100644 --- a/ui/litellm-dashboard/src/components/usage_indicator.tsx +++ b/ui/litellm-dashboard/src/components/UsageIndicator.tsx @@ -1,3 +1,4 @@ +import { useDisableUsageIndicator } from "@/app/(dashboard)/hooks/useDisableUsageIndicator"; import { Badge } from "@tremor/react"; import { AlertTriangle, ChevronDown, ChevronUp, Loader2, Minus, TrendingUp, UserCheck, Users } from "lucide-react"; import { useEffect, useState } from "react"; @@ -23,7 +24,7 @@ interface UsageData { } export default function UsageIndicator({ accessToken, width = 220 }: UsageIndicatorProps) { - const position = "bottom-left"; + const disableUsageIndicator = useDisableUsageIndicator(); const [isExpanded, setIsExpanded] = useState(false); const [isMinimized, setIsMinimized] = useState(false); const [data, setData] = useState(null); @@ -541,8 +542,8 @@ export default function UsageIndicator({ accessToken, width = 220 }: UsageIndica ); }; - // Don't render anything if no access token or if both total_users and total_teams are null - if (!accessToken || (data?.total_users === null && data?.total_teams === null)) { + // Don't render anything if disabled, no access token, or if both total_users and total_teams are null + if (disableUsageIndicator || !accessToken || (data?.total_users === null && data?.total_teams === null)) { return null; } diff --git a/ui/litellm-dashboard/src/components/leftnav.tsx b/ui/litellm-dashboard/src/components/leftnav.tsx index b26590989d2..e35d70a65bb 100644 --- a/ui/litellm-dashboard/src/components/leftnav.tsx +++ b/ui/litellm-dashboard/src/components/leftnav.tsx @@ -29,9 +29,9 @@ import type { MenuProps } from "antd"; import { ConfigProvider, Layout, Menu } 
from "antd"; import { useMemo } from "react"; import { all_admin_roles, internalUserRoles, isAdminRole, rolesWithWriteAccess } from "../utils/roles"; -import type { Organization } from "./networking"; -import UsageIndicator from "./usage_indicator"; import NewBadge from "./common_components/NewBadge"; +import type { Organization } from "./networking"; +import UsageIndicator from "./UsageIndicator"; const { Sider } = Layout; // Define the props type diff --git a/ui/litellm-dashboard/src/components/networking.tsx b/ui/litellm-dashboard/src/components/networking.tsx index aa509c65aa4..bee0930dda3 100644 --- a/ui/litellm-dashboard/src/components/networking.tsx +++ b/ui/litellm-dashboard/src/components/networking.tsx @@ -259,6 +259,7 @@ export interface LiteLLMWellKnownUiConfig { proxy_base_url: string | null; auto_redirect_to_sso: boolean; admin_ui_disabled: boolean; + sso_configured: boolean; } export interface CredentialsResponse { @@ -5654,6 +5655,68 @@ export const getResolvedGuardrails = async (accessToken: string, policyId: strin } }; +export const resolvePoliciesCall = async ( + accessToken: string, + context: { team_alias?: string; key_alias?: string; model?: string; tags?: string[] } +) => { + try { + const url = proxyBaseUrl + ? `${proxyBaseUrl}/policies/resolve` + : `/policies/resolve`; + const response = await fetch(url, { + method: "POST", + headers: { + [globalLitellmHeaderName]: `Bearer ${accessToken}`, + "Content-Type": "application/json", + }, + body: JSON.stringify(context), + }); + + if (!response.ok) { + const errorData = await response.json(); + const errorMessage = deriveErrorMessage(errorData); + handleError(errorMessage); + throw new Error(errorMessage); + } + + return await response.json(); + } catch (error) { + console.error("Failed to resolve policies:", error); + throw error; + } +}; + +export const estimateAttachmentImpactCall = async ( + accessToken: string, + attachmentData: any +) => { + try { + const url = proxyBaseUrl + ? 
`${proxyBaseUrl}/policies/attachments/estimate-impact` + : `/policies/attachments/estimate-impact`; + const response = await fetch(url, { + method: "POST", + headers: { + [globalLitellmHeaderName]: `Bearer ${accessToken}`, + "Content-Type": "application/json", + }, + body: JSON.stringify(attachmentData), + }); + + if (!response.ok) { + const errorData = await response.json(); + const errorMessage = deriveErrorMessage(errorData); + handleError(errorMessage); + throw new Error(errorMessage); + } + + return await response.json(); + } catch (error) { + console.error("Failed to estimate attachment impact:", error); + throw error; + } +}; + export const getPromptsList = async (accessToken: string): Promise => { try { const url = proxyBaseUrl ? `${proxyBaseUrl}/prompts/list` : `/prompts/list`; diff --git a/ui/litellm-dashboard/src/components/policies/add_attachment_form.tsx b/ui/litellm-dashboard/src/components/policies/add_attachment_form.tsx index 9198eda8a94..7426f4fefa2 100644 --- a/ui/litellm-dashboard/src/components/policies/add_attachment_form.tsx +++ b/ui/litellm-dashboard/src/components/policies/add_attachment_form.tsx @@ -1,10 +1,12 @@ import React, { useState, useEffect } from "react"; import { Modal, Form, Select, Radio, Divider, Typography } from "antd"; import { Button } from "@tremor/react"; -import { Policy, PolicyAttachmentCreateRequest } from "./types"; -import { teamListCall, keyInfoCall, modelAvailableCall } from "../networking"; +import { Policy } from "./types"; +import { teamListCall, keyListCall, modelAvailableCall, estimateAttachmentImpactCall } from "../networking"; import NotificationsManager from "../molecules/notifications_manager"; import useAuthorized from "@/app/(dashboard)/hooks/useAuthorized"; +import { buildAttachmentData } from "./build_attachment_data"; +import ImpactPreviewAlert from "./impact_preview_alert"; const { Text } = Typography; @@ -34,6 +36,8 @@ const AddAttachmentForm: React.FC = ({ const [isLoadingTeams, setIsLoadingTeams] 
= useState(false); const [isLoadingKeys, setIsLoadingKeys] = useState(false); const [isLoadingModels, setIsLoadingModels] = useState(false); + const [isEstimating, setIsEstimating] = useState(false); + const [impactResult, setImpactResult] = useState(null); const { userId, userRole } = useAuthorized(); useEffect(() => { @@ -46,33 +50,30 @@ const AddAttachmentForm: React.FC = ({ const loadTeamsKeysAndModels = async () => { if (!accessToken) return; - // Load teams + // Load teams — teamListCall returns a plain array of team objects setIsLoadingTeams(true); try { - // Pass null for organizationID since we're loading all teams the user has access to const teamsResponse = await teamListCall(accessToken, null, userId); - if (teamsResponse?.data) { - const teamAliases = teamsResponse.data - .map((t: any) => t.team_alias) - .filter(Boolean); - setAvailableTeams(teamAliases); - } + const teamsArray = Array.isArray(teamsResponse) ? teamsResponse : (teamsResponse?.data || []); + const teamAliases = teamsArray + .map((t: any) => t.team_alias) + .filter(Boolean); + setAvailableTeams(teamAliases); } catch (error) { console.error("Failed to load teams:", error); } finally { setIsLoadingTeams(false); } - // Load keys + // Load keys — keyListCall returns {keys: [...], total_count, ...} setIsLoadingKeys(true); try { - const keysResponse = await keyInfoCall(accessToken, []); - if (keysResponse?.data) { - const keyAliases = keysResponse.data - .map((k: any) => k.key_alias) - .filter(Boolean); - setAvailableKeys(keyAliases); - } + const keysResponse = await keyListCall(accessToken, null, null, null, null, null, 1, 100); + const keysArray = keysResponse?.keys || keysResponse?.data || []; + const keyAliases = keysArray + .map((k: any) => k.key_alias) + .filter(Boolean); + setAvailableKeys(keyAliases); } catch (error) { console.error("Failed to load keys:", error); } finally { @@ -83,12 +84,11 @@ const AddAttachmentForm: React.FC = ({ setIsLoadingModels(true); try { const modelsResponse 
= await modelAvailableCall(accessToken, userId || "", userRole || ""); - if (modelsResponse?.data) { - const modelIds = modelsResponse.data - .map((m: any) => m.id || m.model_name) - .filter(Boolean); - setAvailableModels(modelIds); - } + const modelsArray = modelsResponse?.data || (Array.isArray(modelsResponse) ? modelsResponse : []); + const modelIds = modelsArray + .map((m: any) => m.id || m.model_name) + .filter(Boolean); + setAvailableModels(modelIds); } catch (error) { console.error("Failed to load models:", error); } finally { @@ -99,6 +99,28 @@ const AddAttachmentForm: React.FC = ({ const resetForm = () => { form.resetFields(); setScopeType("global"); + setImpactResult(null); + }; + + const getAttachmentData = () => buildAttachmentData(form.getFieldsValue(true), scopeType); + + const handlePreviewImpact = async () => { + if (!accessToken) return; + try { + await form.validateFields(["policy_name"]); + } catch { + return; + } + setIsEstimating(true); + try { + const data = getAttachmentData(); + const result = await estimateAttachmentImpactCall(accessToken, data); + setImpactResult(result); + } catch (error) { + console.error("Failed to estimate impact:", error); + } finally { + setIsEstimating(false); + } }; const handleClose = () => { @@ -110,30 +132,12 @@ const AddAttachmentForm: React.FC = ({ try { setIsSubmitting(true); await form.validateFields(); - const values = form.getFieldsValue(true); if (!accessToken) { throw new Error("No access token available"); } - const data: PolicyAttachmentCreateRequest = { - policy_name: values.policy_name, - }; - - if (scopeType === "global") { - data.scope = "*"; - } else { - if (values.teams && values.teams.length > 0) { - data.teams = values.teams; - } - if (values.keys && values.keys.length > 0) { - data.keys = values.keys; - } - if (values.models && values.models.length > 0) { - data.models = values.models; - } - } - + const data = getAttachmentData(); await createAttachment(accessToken, data); 
NotificationsManager.success("Attachment created successfully"); @@ -195,8 +199,8 @@ const AddAttachmentForm: React.FC = ({ value={scopeType} onChange={(e) => setScopeType(e.target.value)} > + Specific (teams, keys, models, or tags) Global (applies to all requests) - Specific (teams, keys, or models) @@ -267,13 +271,41 @@ const AddAttachmentForm: React.FC = ({ style={{ width: "100%" }} /> + + + Matches tags from key/team metadata.tags or tags passed dynamically in the request body. Use * as a suffix wildcard (e.g., prod-* matches prod-us, prod-eu). + + } + > + ({ label: t, value: t }))} + filterOption={(input, option) => + (option?.label ?? "").toLowerCase().includes(input.toLowerCase()) + } + /> + + + ({ label: m, value: m }))} + filterOption={(input, option) => + (option?.label ?? "").toLowerCase().includes(input.toLowerCase()) + } + /> + + +