diff --git a/litellm-proxy-extras/litellm_proxy_extras/schema.prisma b/litellm-proxy-extras/litellm_proxy_extras/schema.prisma index b1ca1f71c9e..558dfcc9517 100644 --- a/litellm-proxy-extras/litellm_proxy_extras/schema.prisma +++ b/litellm-proxy-extras/litellm_proxy_extras/schema.prisma @@ -914,6 +914,7 @@ model LiteLLM_PolicyAttachmentTable { teams String[] @default([]) // Team aliases or patterns keys String[] @default([]) // Key aliases or patterns models String[] @default([]) // Model names or patterns + tags String[] @default([]) // Tag patterns (e.g., ["healthcare", "prod-*"]) created_at DateTime @default(now()) created_by String? updated_at DateTime @default(now()) @updatedAt diff --git a/litellm/constants.py b/litellm/constants.py index 9c25cf77906..f288b11f2b5 100644 --- a/litellm/constants.py +++ b/litellm/constants.py @@ -1298,6 +1298,9 @@ os.getenv("DEFAULT_SLACK_ALERTING_THRESHOLD", 300) ) MAX_TEAM_LIST_LIMIT = int(os.getenv("MAX_TEAM_LIST_LIMIT", 20)) +MAX_POLICY_ESTIMATE_IMPACT_ROWS = int( + os.getenv("MAX_POLICY_ESTIMATE_IMPACT_ROWS", 1000) +) DEFAULT_PROMPT_INJECTION_SIMILARITY_THRESHOLD = float( os.getenv("DEFAULT_PROMPT_INJECTION_SIMILARITY_THRESHOLD", 0.7) ) diff --git a/litellm/proxy/common_utils/callback_utils.py b/litellm/proxy/common_utils/callback_utils.py index faeca9b2aed..62ca6dc2ae2 100644 --- a/litellm/proxy/common_utils/callback_utils.py +++ b/litellm/proxy/common_utils/callback_utils.py @@ -394,6 +394,14 @@ def get_logging_caching_headers(request_data: Dict) -> Optional[Dict]: _metadata["applied_policies"] ) + if "policy_sources" in _metadata: + sources = _metadata["policy_sources"] + if isinstance(sources, dict) and sources: + # Use ';' as delimiter — matched_via reasons may contain commas + headers["x-litellm-policy-sources"] = "; ".join( + f"{name}={reason}" for name, reason in sources.items() + ) + if "semantic-similarity" in _metadata: headers["x-litellm-semantic-similarity"] = str(_metadata["semantic-similarity"]) @@ -441,6 +449,27 @@ def add_policy_to_applied_policies_header( request_data["metadata"] = _metadata +def add_policy_sources_to_metadata( + request_data: Dict, policy_sources: Dict[str, str] +): + """ + Store policy match reasons in metadata for x-litellm-policy-sources header. + + Args: + request_data: The request data dict + policy_sources: Map of policy_name -> matched_via reason + """ + if not policy_sources: + return + _metadata = request_data.get("metadata", None) or {} + existing = _metadata.get("policy_sources", {}) + if not isinstance(existing, dict): + existing = {} + existing.update(policy_sources) + _metadata["policy_sources"] = existing + request_data["metadata"] = _metadata + + def add_guardrail_response_to_standard_logging_object( litellm_logging_obj: Optional["LiteLLMLogging"], guardrail_response: StandardLoggingGuardrailInformation, diff --git a/litellm/proxy/litellm_pre_call_utils.py b/litellm/proxy/litellm_pre_call_utils.py index 9be78264e85..49d31c1efec 100644 --- a/litellm/proxy/litellm_pre_call_utils.py +++ b/litellm/proxy/litellm_pre_call_utils.py @@ -1539,8 +1539,15 @@ def add_guardrails_from_policy_engine( """ from litellm._logging import verbose_proxy_logger from litellm.proxy.common_utils.callback_utils import ( + add_policy_sources_to_metadata, add_policy_to_applied_policies_header, ) + from litellm.proxy.common_utils.http_parsing_utils import ( + get_tags_from_request_body, + ) + from litellm.proxy.policy_engine.attachment_registry import ( + get_attachment_registry, + ) from litellm.proxy.policy_engine.policy_matcher import PolicyMatcher from litellm.proxy.policy_engine.policy_registry import get_policy_registry from litellm.proxy.policy_engine.policy_resolver import PolicyResolver @@ -1561,20 +1568,31 @@ def add_guardrails_from_policy_engine( ) return - # Build context from request + # Extract tags using the shared helper (handles metadata / litellm_metadata, + # top-level tags, deduplication, and type filtering). + + all_tags = get_tags_from_request_body(data) or None + context = PolicyMatchContext( team_alias=user_api_key_dict.team_alias, key_alias=user_api_key_dict.key_alias, model=data.get("model"), + tags=all_tags, ) verbose_proxy_logger.debug( f"Policy engine: matching policies for context team_alias={context.team_alias}, " - f"key_alias={context.key_alias}, model={context.model}" + f"key_alias={context.key_alias}, model={context.model}, tags={context.tags}" ) - # Get matching policies via attachments - matching_policy_names = PolicyMatcher.get_matching_policies(context=context) + # Get matching policies via attachments (with match reasons for attribution) + attachment_registry = get_attachment_registry() + matches_with_reasons = attachment_registry.get_attached_policies_with_reasons( + context + ) + matching_policy_names = [m["policy_name"] for m in matches_with_reasons] + # Build reasons map: {"hipaa-policy": "tag:healthcare", ...} + policy_reasons = {m["policy_name"]: m["matched_via"] for m in matches_with_reasons} verbose_proxy_logger.debug( f"Policy engine: matched policies via attachments: {matching_policy_names}" @@ -1607,6 +1625,16 @@ def add_guardrails_from_policy_engine( request_data=data, policy_name=policy_name ) + # Track policy attribution sources for x-litellm-policy-sources header + applied_reasons = { + name: policy_reasons[name] + for name in applied_policy_names + if name in policy_reasons + } + add_policy_sources_to_metadata( + request_data=data, policy_sources=applied_reasons + ) + # Resolve guardrails from matching policies resolved_guardrails = PolicyResolver.resolve_guardrails_for_context(context=context) diff --git a/litellm/proxy/policy_engine/attachment_registry.py b/litellm/proxy/policy_engine/attachment_registry.py index 4a335b54747..69b3b3599f3 100644 --- a/litellm/proxy/policy_engine/attachment_registry.py +++ b/litellm/proxy/policy_engine/attachment_registry.py @@ -84,6 +84,7 @@ def _parse_attachment(self, attachment_data: Dict[str, Any]) -> PolicyAttachment teams=attachment_data.get("teams"), keys=attachment_data.get("keys"), models=attachment_data.get("models"), + tags=attachment_data.get("tags"), ) def get_attached_policies(self, context: PolicyMatchContext) -> List[str]: @@ -96,21 +97,68 @@ def get_attached_policies(self, context: PolicyMatchContext) -> List[str]: Returns: List of policy names that are attached to matching scopes """ + return [r["policy_name"] for r in self.get_attached_policies_with_reasons(context)] + + def get_attached_policies_with_reasons( + self, context: PolicyMatchContext + ) -> List[Dict[str, Any]]: + """ + Get list of policy names and match reasons for the given context. + + Returns a list of dicts with 'policy_name' and 'matched_via' keys. + The 'matched_via' describes which dimension caused the match. + """ from litellm.proxy.policy_engine.policy_matcher import PolicyMatcher - attached_policies: List[str] = [] + results: List[Dict[str, Any]] = [] + seen_policies: set = set() for attachment in self._attachments: scope = attachment.to_policy_scope() if PolicyMatcher.scope_matches(scope=scope, context=context): - if attachment.policy not in attached_policies: - attached_policies.append(attachment.policy) + if attachment.policy not in seen_policies: + seen_policies.add(attachment.policy) + matched_via = self._describe_match_reason(attachment, context) + results.append( + { + "policy_name": attachment.policy, + "matched_via": matched_via, + } + ) verbose_proxy_logger.debug( f"Attachment matched: policy={attachment.policy}, " + f"matched_via={matched_via}, " f"context=(team={context.team_alias}, key={context.key_alias}, model={context.model})" ) - return attached_policies + return results + + @staticmethod + def _describe_match_reason( + attachment: PolicyAttachment, context: PolicyMatchContext + ) -> str: + """Describe why an attachment matched the context.""" + from litellm.proxy.policy_engine.policy_matcher import PolicyMatcher + + if attachment.is_global(): + return "scope:*" + + reasons = [] + if attachment.tags and context.tags: + matching_tags = [ + t for t in context.tags + if PolicyMatcher.matches_pattern(t, attachment.tags) + ] + if matching_tags: + reasons.append(f"tag:{matching_tags[0]}") + if attachment.teams and context.team_alias: + reasons.append(f"team:{context.team_alias}") + if attachment.keys and context.key_alias: + reasons.append(f"key:{context.key_alias}") + if attachment.models and context.model: + reasons.append(f"model:{context.model}") + + return "+".join(reasons) if reasons else "scope:default" def is_policy_attached( self, policy_name: str, context: PolicyMatchContext @@ -238,6 +286,7 @@ async def add_attachment_to_db( "teams": attachment_request.teams or [], "keys": attachment_request.keys or [], "models": attachment_request.models or [], + "tags": attachment_request.tags or [], "created_at": datetime.now(timezone.utc), "updated_at": datetime.now(timezone.utc), "created_by": created_by, @@ -253,6 +302,7 @@ async def add_attachment_to_db( teams=attachment_request.teams, keys=attachment_request.keys, models=attachment_request.models, + tags=attachment_request.tags, ) self.add_attachment(attachment) @@ -263,6 +313,7 @@ async def add_attachment_to_db( teams=created_attachment.teams or [], keys=created_attachment.keys or [], models=created_attachment.models or [], + tags=created_attachment.tags or [], created_at=created_attachment.created_at, updated_at=created_attachment.updated_at, created_by=created_attachment.created_by, @@ -344,6 +395,7 @@ async def get_attachment_by_id_from_db( teams=attachment.teams or [], keys=attachment.keys or [], models=attachment.models or [], + tags=attachment.tags or [], created_at=attachment.created_at, updated_at=attachment.updated_at, created_by=attachment.created_by, @@ -381,6 +433,7 @@ async def get_all_attachments_from_db( teams=a.teams or [], keys=a.keys or [], models=a.models or [], + tags=a.tags or [], created_at=a.created_at, updated_at=a.updated_at, created_by=a.created_by, @@ -415,6 +468,7 @@ async def sync_attachments_from_db( teams=attachment_response.teams if attachment_response.teams else None, keys=attachment_response.keys if attachment_response.keys else None, models=attachment_response.models if attachment_response.models else None, + tags=attachment_response.tags if attachment_response.tags else None, ) self._attachments.append(attachment) diff --git a/litellm/proxy/policy_engine/policy_endpoints.py b/litellm/proxy/policy_engine/policy_endpoints.py index 615e153862a..3bd893b0034 100644 --- a/litellm/proxy/policy_engine/policy_endpoints.py +++ b/litellm/proxy/policy_engine/policy_endpoints.py @@ -23,10 +23,6 @@ router = APIRouter() -# Get singleton instances -POLICY_REGISTRY = get_policy_registry() -ATTACHMENT_REGISTRY = get_attachment_registry() - # ───────────────────────────────────────────────────────────────────────────── # Policy CRUD Endpoints @@ -75,7 +71,7 @@ async def list_policies(): raise HTTPException(status_code=500, detail="Database not connected") try: - policies = await POLICY_REGISTRY.get_all_policies_from_db(prisma_client) + policies = await get_policy_registry().get_all_policies_from_db(prisma_client) return PolicyListDBResponse(policies=policies, total_count=len(policies)) except Exception as e: verbose_proxy_logger.exception(f"Error listing policies: {e}") @@ -130,7 +126,7 @@ async def create_policy( try: created_by = user_api_key_dict.user_id - result = await POLICY_REGISTRY.add_policy_to_db( + result = await get_policy_registry().add_policy_to_db( policy_request=request, prisma_client=prisma_client, created_by=created_by, @@ -168,7 +164,7 @@ async def get_policy(policy_id: str): raise HTTPException(status_code=500, detail="Database not connected") try: - result = await POLICY_REGISTRY.get_policy_by_id_from_db( + result = await get_policy_registry().get_policy_by_id_from_db( policy_id=policy_id, prisma_client=prisma_client, ) @@ -216,7 +212,7 @@ async def update_policy( try: # Check if policy exists - existing = await POLICY_REGISTRY.get_policy_by_id_from_db( + existing = await get_policy_registry().get_policy_by_id_from_db( policy_id=policy_id, prisma_client=prisma_client, ) @@ -226,7 +222,7 @@ async def update_policy( ) updated_by = user_api_key_dict.user_id - result = await POLICY_REGISTRY.update_policy_in_db( + result = await get_policy_registry().update_policy_in_db( policy_id=policy_id, policy_request=request, prisma_client=prisma_client, @@ -269,7 +265,7 @@ async def delete_policy(policy_id: str): try: # Check if policy exists - existing = await POLICY_REGISTRY.get_policy_by_id_from_db( + existing = await get_policy_registry().get_policy_by_id_from_db( policy_id=policy_id, prisma_client=prisma_client, ) @@ -278,7 +274,7 @@ async def delete_policy(policy_id: str): status_code=404, detail=f"Policy with ID {policy_id} not found" ) - result = await POLICY_REGISTRY.delete_policy_from_db( + result = await get_policy_registry().delete_policy_from_db( policy_id=policy_id, prisma_client=prisma_client, ) @@ -324,7 +320,7 @@ async def get_resolved_guardrails(policy_id: str): try: # Get the policy - policy = await POLICY_REGISTRY.get_policy_by_id_from_db( + policy = await get_policy_registry().get_policy_by_id_from_db( policy_id=policy_id, prisma_client=prisma_client, ) @@ -334,7 +330,7 @@ async def get_resolved_guardrails(policy_id: str): ) # Resolve guardrails - resolved = await POLICY_REGISTRY.resolve_guardrails_from_db( + resolved = await get_policy_registry().resolve_guardrails_from_db( policy_name=policy.policy_name, prisma_client=prisma_client, ) @@ -399,7 +395,7 @@ async def list_policy_attachments(): raise HTTPException(status_code=500, detail="Database not connected") try: - attachments = await ATTACHMENT_REGISTRY.get_all_attachments_from_db( + attachments = await get_attachment_registry().get_all_attachments_from_db( prisma_client ) return PolicyAttachmentListResponse( @@ -466,7 +462,7 @@ async def create_policy_attachment( try: # Verify the policy exists - policy = await POLICY_REGISTRY.get_all_policies_from_db(prisma_client) + policy = await get_policy_registry().get_all_policies_from_db(prisma_client) policy_names = [p.policy_name for p in policy] if request.policy_name not in policy_names: raise HTTPException( @@ -475,7 +471,7 @@ async def create_policy_attachment( ) created_by = user_api_key_dict.user_id - result = await ATTACHMENT_REGISTRY.add_attachment_to_db( + result = await get_attachment_registry().add_attachment_to_db( attachment_request=request, prisma_client=prisma_client, created_by=created_by, @@ -510,7 +506,7 @@ async def get_policy_attachment(attachment_id: str): raise HTTPException(status_code=500, detail="Database not connected") try: - result = await ATTACHMENT_REGISTRY.get_attachment_by_id_from_db( + result = await get_attachment_registry().get_attachment_by_id_from_db( attachment_id=attachment_id, prisma_client=prisma_client, ) @@ -556,7 +552,7 @@ async def delete_policy_attachment(attachment_id: str): try: # Check if attachment exists - existing = await ATTACHMENT_REGISTRY.get_attachment_by_id_from_db( + existing = await get_attachment_registry().get_attachment_by_id_from_db( attachment_id=attachment_id, prisma_client=prisma_client, ) @@ -566,7 +562,7 @@ async def delete_policy_attachment(attachment_id: str): detail=f"Attachment with ID {attachment_id} not found", ) - result = await ATTACHMENT_REGISTRY.delete_attachment_from_db( + result = await get_attachment_registry().delete_attachment_from_db( attachment_id=attachment_id, prisma_client=prisma_client, ) diff --git a/litellm/proxy/policy_engine/policy_matcher.py b/litellm/proxy/policy_engine/policy_matcher.py index ab73970bfab..888981f85f5 100644 --- a/litellm/proxy/policy_engine/policy_matcher.py +++ b/litellm/proxy/policy_engine/policy_matcher.py @@ -81,6 +81,19 @@ def scope_matches(scope: PolicyScope, context: PolicyMatchContext) -> bool: if not PolicyMatcher.matches_pattern(context.model, scope.get_models()): return False + # Check tags (only if scope specifies tags) + # Unlike teams/keys/models, empty tags means "do not check" rather than "match all" + scope_tags = scope.get_tags() + if scope_tags: + if not context.tags: + return False + # Match if ANY context tag matches ANY scope tag pattern + if not any( + PolicyMatcher.matches_pattern(tag, scope_tags) + for tag in context.tags + ): + return False + return True @staticmethod diff --git a/litellm/proxy/policy_engine/policy_registry.py b/litellm/proxy/policy_engine/policy_registry.py index 5fb5084f648..a2431977b24 100644 --- a/litellm/proxy/policy_engine/policy_registry.py +++ b/litellm/proxy/policy_engine/policy_registry.py @@ -484,6 +484,7 @@ async def sync_policies_from_db( ) self.add_policy(policy_response.policy_name, policy) + self._initialized = True verbose_proxy_logger.info( f"Synced {len(policies)} policies from DB to in-memory registry" ) diff --git a/litellm/proxy/policy_engine/policy_resolve_endpoints.py b/litellm/proxy/policy_engine/policy_resolve_endpoints.py new file mode 100644 index 00000000000..947c659abe4 --- /dev/null +++ b/litellm/proxy/policy_engine/policy_resolve_endpoints.py @@ -0,0 +1,408 @@ +""" +Policy resolve and attachment impact estimation endpoints. + +- /policies/resolve — debug which guardrails apply for a given context +- /policies/attachments/estimate-impact — preview blast radius before creating an attachment +""" + +import json + +from litellm.proxy.auth.route_checks import RouteChecks + +from litellm._logging import verbose_proxy_logger +from litellm.constants import MAX_POLICY_ESTIMATE_IMPACT_ROWS +from litellm.proxy._types import UserAPIKeyAuth +from litellm.proxy.auth.user_api_key_auth import user_api_key_auth +from litellm.proxy.policy_engine.attachment_registry import get_attachment_registry +from litellm.proxy.policy_engine.policy_registry import get_policy_registry +from litellm.types.proxy.policy_engine import ( + AttachmentImpactResponse, + PolicyAttachmentCreateRequest, + PolicyMatchContext, + PolicyMatchDetail, + PolicyResolveRequest, + PolicyResolveResponse, +) + +router = APIRouter() + + +def _build_alias_where(field: str, patterns: list) -> dict: + """Build a Prisma ``where`` clause for alias patterns. + + Supports exact matches and suffix wildcards (``prefix*``). + Returns something like: + {"OR": [{"field": {"in": ["a","b"]}}, {"field": {"startsWith": "dev-"}}]} + """ + exact: list = [] + prefix_conditions: list = [] + for pat in patterns: + if pat.endswith("*"): + prefix_conditions.append({field: {"startsWith": pat[:-1]}}) + else: + exact.append(pat) + + conditions: list = [] + if exact: + conditions.append({field: {"in": exact}}) + conditions.extend(prefix_conditions) + + if not conditions: + return {field: {"not": None}} + if len(conditions) == 1: + return conditions[0] + return {"OR": conditions} + + +def _parse_metadata(raw_metadata: object) -> dict: + """Parse metadata that may be a dict, JSON string, or None.""" + if raw_metadata is None: + return {} + if isinstance(raw_metadata, str): + try: + return json.loads(raw_metadata) + except (json.JSONDecodeError, TypeError): + return {} + return raw_metadata if isinstance(raw_metadata, dict) else {} + + +def _get_tags_from_metadata(metadata: object, json_metadata: object = None) -> list: + """Extract tags list from a metadata field (or metadata_json fallback).""" + raw = json_metadata if json_metadata is not None else metadata + parsed = _parse_metadata(raw) + return parsed.get("tags", []) or [] + + +async def _fetch_all_teams(prisma_client: object) -> list: + """Fetch teams from DB once. Reuse the result across tag and alias lookups.""" + return await prisma_client.db.litellm_teamtable.find_many( # type: ignore + where={}, order={"created_at": "desc"}, take=MAX_POLICY_ESTIMATE_IMPACT_ROWS, + ) + + +def _filter_keys_by_tags(keys: list, tag_patterns: list) -> tuple: + """Filter key rows whose metadata.tags match any of the given patterns. + + Returns (named_aliases, unnamed_count). + """ + from litellm.proxy.auth.route_checks import RouteChecks + + affected: list = [] + unnamed_count = 0 + for key in keys: + key_alias = key.key_alias or "" + key_tags = _get_tags_from_metadata( + key.metadata, getattr(key, "metadata_json", None) + ) + if key_tags and any( + RouteChecks._route_matches_wildcard_pattern(route=tag, pattern=pat) + for tag in key_tags + for pat in tag_patterns + ): + if key_alias: + affected.append(key_alias) + else: + unnamed_count += 1 + return affected, unnamed_count + + +def _filter_teams_by_tags(teams: list, tag_patterns: list) -> tuple: + """Filter pre-fetched team rows whose metadata.tags match any patterns. + + Returns (named_aliases, unnamed_count). + """ + from litellm.proxy.auth.route_checks import RouteChecks + + affected: list = [] + unnamed_count = 0 + for team in teams: + team_alias = team.team_alias or "" + team_tags = _get_tags_from_metadata(team.metadata) + if team_tags and any( + RouteChecks._route_matches_wildcard_pattern(route=tag, pattern=pat) + for tag in team_tags + for pat in tag_patterns + ): + if team_alias: + affected.append(team_alias) + else: + unnamed_count += 1 + return affected, unnamed_count + + +async def _find_affected_by_team_patterns( + prisma_client: object, + all_teams: list, + team_patterns: list, + existing_teams: list, + existing_keys: list, +) -> tuple: + """Filter pre-fetched teams by alias patterns, then fetch their keys. + + Returns (new_teams, new_keys, unnamed_keys_count). + """ + from litellm.proxy.auth.route_checks import RouteChecks + + new_teams: list = [] + matched_team_ids: list = [] + + for team in all_teams: + team_alias = team.team_alias or "" + if team_alias and any( + RouteChecks._route_matches_wildcard_pattern(route=team_alias, pattern=pat) + for pat in team_patterns + ): + if team_alias not in existing_teams: + new_teams.append(team_alias) + matched_team_ids.append(str(team.team_id)) + + new_keys: list = [] + unnamed_keys_count = 0 + if matched_team_ids: + keys = await prisma_client.db.litellm_verificationtoken.find_many( # type: ignore + where={"team_id": {"in": matched_team_ids}}, + order={"created_at": "desc"}, take=MAX_POLICY_ESTIMATE_IMPACT_ROWS, + ) + for key in keys: + key_alias = key.key_alias or "" + if key_alias: + if key_alias not in existing_keys: + new_keys.append(key_alias) + else: + unnamed_keys_count += 1 + + return new_teams, new_keys, unnamed_keys_count + + +async def _find_affected_keys_by_alias( + prisma_client: object, key_patterns: list, existing_keys: list +) -> list: + """Find keys whose alias matches the given patterns.""" + from litellm.proxy.auth.route_checks import RouteChecks + + affected: list = [] + + keys = await prisma_client.db.litellm_verificationtoken.find_many( # type: ignore + where=_build_alias_where("key_alias", key_patterns), + order={"created_at": "desc"}, take=MAX_POLICY_ESTIMATE_IMPACT_ROWS, + ) + for key in keys: + key_alias = key.key_alias or "" + if key_alias and any( + RouteChecks._route_matches_wildcard_pattern(route=key_alias, pattern=pat) + for pat in key_patterns + ): + if key_alias not in existing_keys: + affected.append(key_alias) + return affected + + +# ───────────────────────────────────────────────────────────────────────────── +# Policy Resolve Endpoint +# ───────────────────────────────────────────────────────────────────────────── + + +@router.post( + "/policies/resolve", + tags=["Policies"], + dependencies=[Depends(user_api_key_auth)], + response_model=PolicyResolveResponse, +) +async def resolve_policies_for_context( + request: PolicyResolveRequest, + force_sync: bool = Query( + default=False, + description="Force a DB sync before resolving. Default uses in-memory cache.", + ), + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + """ + Resolve which policies and guardrails apply for a given context. + + Use this endpoint to debug "what guardrails would apply to a request + with this team/key/model/tags combination?" + + Example Request: + ```bash + curl -X POST "http://localhost:4000/policies/resolve" \\ + -H "Authorization: Bearer " \\ + -H "Content-Type: application/json" \\ + -d '{ + "tags": ["healthcare"], + "model": "gpt-4" + }' + ``` + """ + from litellm.proxy.policy_engine.policy_matcher import PolicyMatcher + from litellm.proxy.policy_engine.policy_resolver import PolicyResolver + from litellm.proxy.proxy_server import prisma_client + + if prisma_client is None: + raise HTTPException(status_code=500, detail="Database not connected") + + try: + # Only sync from DB when explicitly requested; otherwise use in-memory cache + if force_sync: + await get_policy_registry().sync_policies_from_db(prisma_client) + await get_attachment_registry().sync_attachments_from_db(prisma_client) + + # Build context from request + context = PolicyMatchContext( + team_alias=request.team_alias, + key_alias=request.key_alias, + model=request.model, + tags=request.tags, + ) + + # Get matching policies with reasons + match_results = get_attachment_registry().get_attached_policies_with_reasons( + context=context + ) + + if not match_results: + return PolicyResolveResponse( + effective_guardrails=[], + matched_policies=[], + ) + + # Filter by conditions + policy_names = [r["policy_name"] for r in match_results] + applied_policy_names = PolicyMatcher.get_policies_with_matching_conditions( + policy_names=policy_names, + context=context, + ) + + # Resolve guardrails for each applied policy + matched_policies = [] + all_guardrails: set = set() + for result in match_results: + pname = result["policy_name"] + if pname not in applied_policy_names: + continue + resolved = PolicyResolver.resolve_policy_guardrails( + policy_name=pname, + policies=get_policy_registry().get_all_policies(), + context=context, + ) + guardrails = resolved.guardrails if resolved else [] + all_guardrails.update(guardrails) + matched_policies.append( + PolicyMatchDetail( + policy_name=pname, + matched_via=result["matched_via"], + guardrails_added=guardrails, + ) + ) + + return PolicyResolveResponse( + effective_guardrails=sorted(all_guardrails), + matched_policies=matched_policies, + ) + except HTTPException: + raise + except Exception as e: + verbose_proxy_logger.exception(f"Error resolving policies: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +# ───────────────────────────────────────────────────────────────────────────── +# Attachment Impact Estimation Endpoint +# ───────────────────────────────────────────────────────────────────────────── + + +@router.post( + "/policies/attachments/estimate-impact", + tags=["Policies"], + dependencies=[Depends(user_api_key_auth)], + response_model=AttachmentImpactResponse, +) +async def estimate_attachment_impact( + request: PolicyAttachmentCreateRequest, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + """ + Estimate how many keys and teams would be affected by a policy attachment. + + Use this before creating an attachment to preview the blast radius. + + Example Request: + ```bash + curl -X POST "http://localhost:4000/policies/attachments/estimate-impact" \\ + -H "Authorization: Bearer " \\ + -H "Content-Type: application/json" \\ + -d '{ + "policy_name": "hipaa-compliance", + "tags": ["healthcare", "health-*"] + }' + ``` + """ + from litellm.proxy.proxy_server import prisma_client + + if prisma_client is None: + raise HTTPException(status_code=500, detail="Database not connected") + + try: + # If global scope, everything is affected — not useful to enumerate + if request.scope == "*": + return AttachmentImpactResponse( + affected_keys_count=-1, + affected_teams_count=-1, + sample_keys=["(global scope — affects all keys)"], + sample_teams=["(global scope — affects all teams)"], + ) + + affected_keys: list = [] + affected_teams: list = [] + unnamed_keys = 0 + unnamed_teams = 0 + + tag_patterns = request.tags or [] + team_patterns = request.teams or [] + + # Fetch teams once — reused by both tag-based and alias-based lookups + all_teams: list = [] + if tag_patterns or team_patterns: + all_teams = await _fetch_all_teams(prisma_client) + + # Tag-based impact + if tag_patterns: + keys = await prisma_client.db.litellm_verificationtoken.find_many( # type: ignore + where={}, order={"created_at": "desc"}, + take=MAX_POLICY_ESTIMATE_IMPACT_ROWS, + ) + affected_keys, unnamed_keys = _filter_keys_by_tags(keys, tag_patterns) + affected_teams, unnamed_teams = _filter_teams_by_tags( + all_teams, tag_patterns, + ) + + # Team-based impact (alias matching + keys belonging to those teams) + if team_patterns: + new_teams, new_keys, new_unnamed = await _find_affected_by_team_patterns( + prisma_client, all_teams, team_patterns, + affected_teams, affected_keys, + ) + affected_teams.extend(new_teams) + affected_keys.extend(new_keys) + unnamed_keys += new_unnamed + + # Key-based impact (direct alias matching) + key_patterns = request.keys or [] + if key_patterns: + new_keys = await _find_affected_keys_by_alias( + prisma_client, key_patterns, affected_keys, + ) + affected_keys.extend(new_keys) + + return AttachmentImpactResponse( + affected_keys_count=len(affected_keys) + unnamed_keys, + affected_teams_count=len(affected_teams) + unnamed_teams, + unnamed_keys_count=unnamed_keys, + unnamed_teams_count=unnamed_teams, + sample_keys=affected_keys[:10], + sample_teams=affected_teams[:10], + ) + except HTTPException: + raise + except Exception as e: + verbose_proxy_logger.exception(f"Error estimating attachment impact: {e}") + raise HTTPException(status_code=500, detail=str(e)) diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 294294cdda7..2130aed7770 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -427,6 +427,9 @@ def generate_feedback_box(): router as pass_through_router, ) from litellm.proxy.policy_engine.policy_endpoints import router as policy_crud_router +from litellm.proxy.policy_engine.policy_resolve_endpoints import ( + router as policy_resolve_router, +) from litellm.proxy.prompts.prompt_endpoints import router as prompts_router from litellm.proxy.public_endpoints import router as public_endpoints_router from litellm.proxy.rag_endpoints.endpoints import router as rag_router @@ -11746,6 +11749,7 @@ async def get_routes(): app.include_router(guardrails_router) app.include_router(policy_router) app.include_router(policy_crud_router) +app.include_router(policy_resolve_router) app.include_router(search_tool_management_router) app.include_router(prompts_router) app.include_router(callback_management_endpoints_router) diff --git a/litellm/proxy/schema.prisma b/litellm/proxy/schema.prisma index 1750efed92c..37ed0182663 100644 --- a/litellm/proxy/schema.prisma +++ b/litellm/proxy/schema.prisma @@ -911,6 +911,7 @@ model LiteLLM_PolicyAttachmentTable { teams String[] @default([]) // Team aliases or patterns keys String[] @default([]) // Key aliases or patterns models String[] @default([]) // Model names or patterns + tags String[] @default([]) // Tag patterns (e.g., ["healthcare", "prod-*"]) created_at DateTime @default(now()) created_by String? updated_at DateTime @default(now()) @updatedAt diff --git a/litellm/types/proxy/policy_engine/__init__.py b/litellm/types/proxy/policy_engine/__init__.py index bc54c3eb36b..42490c2eddc 100644 --- a/litellm/types/proxy/policy_engine/__init__.py +++ b/litellm/types/proxy/policy_engine/__init__.py @@ -19,6 +19,7 @@ PolicyScope, ) from litellm.types.proxy.policy_engine.resolver_types import ( + AttachmentImpactResponse, PolicyAttachmentCreateRequest, PolicyAttachmentDBResponse, PolicyAttachmentListResponse, @@ -30,6 +31,9 @@ PolicyListDBResponse, PolicyListResponse, PolicyMatchContext, + PolicyMatchDetail, + PolicyResolveRequest, + PolicyResolveResponse, PolicyScopeResponse, PolicySummaryItem, PolicyTestResponse, @@ -75,4 +79,9 @@ "PolicyAttachmentCreateRequest", "PolicyAttachmentDBResponse", "PolicyAttachmentListResponse", + # Resolve types + "PolicyResolveRequest", + "PolicyResolveResponse", + "PolicyMatchDetail", + "AttachmentImpactResponse", ] diff --git a/litellm/types/proxy/policy_engine/policy_types.py b/litellm/types/proxy/policy_engine/policy_types.py index 1c01f89e8b4..f221ba7e038 100644 --- a/litellm/types/proxy/policy_engine/policy_types.py +++ b/litellm/types/proxy/policy_engine/policy_types.py @@ -73,13 +73,15 @@ class PolicyScope(BaseModel): Used internally by PolicyAttachment to define WHERE a policy applies. Scope Fields: - | Field | What it matches | Wildcard support | - |--------|-----------------|----------------------| - | teams | Team aliases | *, healthcare-* | - | keys | Key aliases | *, dev-key-* | - | models | Model names | *, bedrock/*, gpt-* | - - If a field is None or empty, it defaults to matching everything (["*"]). + | Field | What it matches | Wildcard support | Default behavior | + |--------|-----------------|----------------------|---------------------| + | teams | Team aliases | *, healthcare-* | None → matches all | + | keys | Key aliases | *, dev-key-* | None → matches all | + | models | Model names | *, bedrock/*, gpt-* | None → matches all | + | tags | Key/team tags | *, health-*, prod-* | None → not checked | + + If teams/keys/models is None or empty, it defaults to matching everything (["*"]). + If tags is None or empty, the tag dimension is NOT checked (matches all). A request must match ALL specified scope fields for the attachment to apply. """ @@ -95,6 +97,10 @@ class PolicyScope(BaseModel): default=None, description="Model names or wildcard patterns. Use '*' for all models.", ) + tags: Optional[List[str]] = Field( + default=None, + description="Tag patterns to match against key/team tags. Supports wildcards (e.g., health-*).", + ) model_config = ConfigDict(extra="forbid") @@ -110,6 +116,14 @@ def get_models(self) -> List[str]: """Returns models list, defaulting to ['*'] if not specified.""" return self.models if self.models else ["*"] + def get_tags(self) -> List[str]: + """Returns tags list, defaulting to empty list if not specified. + + Unlike teams/keys/models, empty tags means 'do not check tags' + rather than 'match all'. This is because tags are opt-in scoping. + """ + return self.tags if self.tags else [] + # ───────────────────────────────────────────────────────────────────────────── # Policy Guardrails @@ -266,6 +280,10 @@ class PolicyAttachment(BaseModel): default=None, description="Model names or patterns this attachment applies to.", ) + tags: Optional[List[str]] = Field( + default=None, + description="Tag patterns this attachment applies to. Supports wildcards (e.g., health-*).", + ) model_config = ConfigDict(extra="forbid") @@ -281,6 +299,7 @@ def to_policy_scope(self) -> PolicyScope: teams=self.teams, keys=self.keys, models=self.models, + tags=self.tags, ) diff --git a/litellm/types/proxy/policy_engine/resolver_types.py b/litellm/types/proxy/policy_engine/resolver_types.py index 9488b8b0841..0c2c7336f8a 100644 --- a/litellm/types/proxy/policy_engine/resolver_types.py +++ b/litellm/types/proxy/policy_engine/resolver_types.py @@ -30,6 +30,10 @@ class PolicyMatchContext(BaseModel): default=None, description="Model name from the request.", ) + tags: Optional[List[str]] = Field( + default=None, + description="Tags from key/team metadata.", + ) model_config = ConfigDict(extra="forbid") @@ -65,6 +69,7 @@ class PolicyScopeResponse(BaseModel): teams: List[str] = Field(default_factory=list) keys: List[str] = Field(default_factory=list) models: List[str] = Field(default_factory=list) + tags: List[str] = Field(default_factory=list) class PolicyGuardrailsResponse(BaseModel): @@ -242,6 +247,10 @@ class PolicyAttachmentCreateRequest(BaseModel): default=None, description="Model names or patterns this attachment applies to.", ) + tags: Optional[List[str]] = Field( + default=None, + description="Tag patterns this attachment applies to. Supports wildcards (e.g., health-*).", + ) class PolicyAttachmentDBResponse(BaseModel): @@ -253,6 +262,7 @@ class PolicyAttachmentDBResponse(BaseModel): teams: List[str] = Field(default_factory=list, description="Team patterns.") keys: List[str] = Field(default_factory=list, description="Key patterns.") models: List[str] = Field(default_factory=list, description="Model patterns.") + tags: List[str] = Field(default_factory=list, description="Tag patterns.") created_at: Optional[datetime] = Field( default=None, description="When the attachment was created." ) @@ -274,3 +284,81 @@ class PolicyAttachmentListResponse(BaseModel): default_factory=list, description="List of policy attachments." ) total_count: int = Field(default=0, description="Total number of attachments.") + + +# ───────────────────────────────────────────────────────────────────────────── +# Policy Resolve Types +# ───────────────────────────────────────────────────────────────────────────── + + +class PolicyResolveRequest(BaseModel): + """Request body for resolving effective policies/guardrails for a context.""" + + team_alias: Optional[str] = Field( + default=None, description="Team alias to resolve for." + ) + key_alias: Optional[str] = Field( + default=None, description="Key alias to resolve for." + ) + model: Optional[str] = Field( + default=None, description="Model name to resolve for." + ) + tags: Optional[List[str]] = Field( + default=None, description="Tags to resolve for." + ) + + +class PolicyMatchDetail(BaseModel): + """Details about why a specific policy matched.""" + + policy_name: str = Field(description="Name of the matched policy.") + matched_via: str = Field( + description="How the policy was matched (e.g., 'tag:healthcare', 'team:health-team', 'scope:*')." + ) + guardrails_added: List[str] = Field( + default_factory=list, + description="Guardrails this policy contributes.", + ) + + +class PolicyResolveResponse(BaseModel): + """Response for resolving effective policies/guardrails for a context.""" + + effective_guardrails: List[str] = Field( + default_factory=list, + description="Final list of guardrails that would be applied.", + ) + matched_policies: List[PolicyMatchDetail] = Field( + default_factory=list, + description="Details about each matched policy and why it matched.", + ) + + +# ───────────────────────────────────────────────────────────────────────────── +# Attachment Impact Estimation Types +# ───────────────────────────────────────────────────────────────────────────── + + +class AttachmentImpactResponse(BaseModel): + """Response for estimating the impact of a policy attachment.""" + + affected_keys_count: int = Field( + default=0, description="Number of keys that would be affected (named + unnamed)." + ) + affected_teams_count: int = Field( + default=0, description="Number of teams that would be affected (named + unnamed)." + ) + unnamed_keys_count: int = Field( + default=0, description="Number of affected keys without an alias." + ) + unnamed_teams_count: int = Field( + default=0, description="Number of affected teams without an alias." + ) + sample_keys: List[str] = Field( + default_factory=list, + description="Sample of affected key aliases (up to 10).", + ) + sample_teams: List[str] = Field( + default_factory=list, + description="Sample of affected team aliases (up to 10).", + ) diff --git a/schema.prisma b/schema.prisma index 9a87a491cf7..4329f939a7b 100644 --- a/schema.prisma +++ b/schema.prisma @@ -913,6 +913,7 @@ model LiteLLM_PolicyAttachmentTable { teams String[] @default([]) // Team aliases or patterns keys String[] @default([]) // Key aliases or patterns models String[] @default([]) // Model names or patterns + tags String[] @default([]) // Tag patterns (e.g., ["healthcare", "prod-*"]) created_at DateTime @default(now()) created_by String? updated_at DateTime @default(now()) @updatedAt diff --git a/tests/test_litellm/proxy/policy_engine/test_attachment_registry.py b/tests/test_litellm/proxy/policy_engine/test_attachment_registry.py index 1ed956fe99f..c853253eedd 100644 --- a/tests/test_litellm/proxy/policy_engine/test_attachment_registry.py +++ b/tests/test_litellm/proxy/policy_engine/test_attachment_registry.py @@ -192,6 +192,139 @@ def test_combined_team_and_model_attachment(self): assert "strict-policy" not in registry.get_attached_policies(context_wrong_team) +class TestTagBasedAttachments: + """Test tag-based policy attachment matching.""" + + def test_tag_matching_and_wildcards(self): + """Test tag matching: exact match, wildcard match, and no-match cases.""" + registry = AttachmentRegistry() + registry.load_attachments([ + {"policy": "hipaa-policy", "tags": ["healthcare"]}, + {"policy": "health-policy", "tags": ["health-*"]}, + ]) + + # Exact tag match + context = PolicyMatchContext( + team_alias="team", key_alias="key", model="gpt-4", + tags=["healthcare"], + ) + attached = registry.get_attached_policies(context) + assert "hipaa-policy" in attached + assert "health-policy" not in attached # "healthcare" doesn't match "health-*" + + # Wildcard tag match + context_wildcard = PolicyMatchContext( + team_alias="team", key_alias="key", model="gpt-4", + tags=["health-prod"], + ) + attached_wildcard = registry.get_attached_policies(context_wildcard) + assert "health-policy" in attached_wildcard + assert "hipaa-policy" not in attached_wildcard + + # No match — wrong tag + context_no_match = PolicyMatchContext( + team_alias="team", key_alias="key", model="gpt-4", + tags=["finance"], + ) + assert registry.get_attached_policies(context_no_match) == [] + + # No match — no tags on context + context_no_tags = PolicyMatchContext( + team_alias="team", key_alias="key", model="gpt-4", + tags=None, + ) + assert registry.get_attached_policies(context_no_tags) == [] + + def test_tag_combined_with_team(self): + """Test attachment with both tags and teams requires BOTH to match (AND logic).""" + registry = AttachmentRegistry() + registry.load_attachments([ + {"policy": "strict-policy", "teams": ["team-a"], "tags": ["healthcare"]}, + ]) + + # Match — both team and tag match + context = PolicyMatchContext( + team_alias="team-a", key_alias="key", model="gpt-4", + tags=["healthcare"], + ) + assert "strict-policy" in registry.get_attached_policies(context) + + # No match — tag matches but team doesn't + context_wrong_team = PolicyMatchContext( + team_alias="team-b", key_alias="key", model="gpt-4", + tags=["healthcare"], + ) + assert "strict-policy" not in registry.get_attached_policies(context_wrong_team) + + # No match — team matches but tag doesn't + context_wrong_tag = PolicyMatchContext( + team_alias="team-a", key_alias="key", model="gpt-4", + tags=["finance"], + ) + assert "strict-policy" not in registry.get_attached_policies(context_wrong_tag) + + +class TestMatchAttribution: + """Test get_attached_policies_with_reasons — the attribution logic that + powers response headers and the Policy Simulator UI.""" + + def test_reasons_for_global_tag_team_attachments(self): + """Test that match reasons correctly describe WHY each policy matched.""" + registry = AttachmentRegistry() + registry.load_attachments([ + {"policy": "global-baseline", "scope": "*"}, + {"policy": "hipaa-policy", "tags": ["healthcare"]}, + {"policy": "team-policy", "teams": ["health-team"]}, + ]) + + context = PolicyMatchContext( + team_alias="health-team", key_alias="key", model="gpt-4", + tags=["healthcare"], + ) + results = registry.get_attached_policies_with_reasons(context) + reasons = {r["policy_name"]: r["matched_via"] for r in results} + + assert reasons["global-baseline"] == "scope:*" + assert "tag:healthcare" in reasons["hipaa-policy"] + assert "team:health-team" in reasons["team-policy"] + + def test_tags_only_attachment_matches_any_team_key_model(self): + """Test the primary use case: tags-only attachment with no team/key/model + constraint matches any request that carries the tag.""" + registry = AttachmentRegistry() + registry.load_attachments([ + {"policy": "hipaa-guardrails", "tags": ["healthcare"]}, + ]) + + # Should match regardless of team/key/model + context = PolicyMatchContext( + team_alias="random-team", key_alias="random-key", model="claude-3", + tags=["healthcare"], + ) + attached = registry.get_attached_policies(context) + assert "hipaa-guardrails" in attached + + # Should not match without the tag + context_no_tag = PolicyMatchContext( + team_alias="random-team", key_alias="random-key", model="claude-3", + ) + assert registry.get_attached_policies(context_no_tag) == [] + + def test_attachment_with_no_scope_matches_everything(self): + """Test that an attachment with no scope/teams/keys/models/tags + matches everything because teams/keys/models default to ['*'].""" + registry = AttachmentRegistry() + registry.load_attachments([ + {"policy": "catch-all"}, + ]) + + context = PolicyMatchContext( + team_alias="any-team", key_alias="any-key", model="gpt-4", + ) + attached = registry.get_attached_policies(context) + assert "catch-all" in attached + + class TestAttachmentRegistrySingleton: """Test global singleton behavior.""" diff --git a/tests/test_litellm/proxy/policy_engine/test_policy_matcher.py b/tests/test_litellm/proxy/policy_engine/test_policy_matcher.py index c011f31af6a..fccb26496ac 100644 --- a/tests/test_litellm/proxy/policy_engine/test_policy_matcher.py +++ b/tests/test_litellm/proxy/policy_engine/test_policy_matcher.py @@ -64,6 +64,70 @@ def test_scope_global_wildcard(self): assert PolicyMatcher.scope_matches(scope, context) is True +class TestPolicyMatcherScopeMatchingWithTags: + """Test scope matching with tag patterns.""" + + def test_scope_tag_matching(self): + """Test scope tag matching: exact, wildcard, no-match, and empty context tags.""" + # Exact match + scope = PolicyScope(teams=["*"], keys=["*"], models=["*"], tags=["healthcare"]) + context = PolicyMatchContext( + team_alias="team", key_alias="key", model="gpt-4", + tags=["healthcare", "internal"], + ) + assert PolicyMatcher.scope_matches(scope, context) is True + + # Wildcard match + scope_wc = PolicyScope(teams=["*"], keys=["*"], models=["*"], tags=["health-*"]) + context_wc = PolicyMatchContext( + team_alias="team", key_alias="key", model="gpt-4", + tags=["health-prod"], + ) + assert PolicyMatcher.scope_matches(scope_wc, context_wc) is True + + # No match — wrong tag + context_wrong = PolicyMatchContext( + team_alias="team", key_alias="key", model="gpt-4", + tags=["finance"], + ) + assert PolicyMatcher.scope_matches(scope, context_wrong) is False + + # No match — context has no tags + context_none = PolicyMatchContext( + team_alias="team", key_alias="key", model="gpt-4", tags=None, + ) + assert PolicyMatcher.scope_matches(scope, context_none) is False + + # Scope without tags matches any context (opt-in semantics) + scope_no_tags = PolicyScope(teams=["*"], keys=["*"], models=["*"]) + assert PolicyMatcher.scope_matches(scope_no_tags, context) is True + + def test_scope_tags_and_team_combined(self): + """Test scope with both tags and team — both must match (AND logic).""" + scope = PolicyScope(teams=["team-a"], keys=["*"], models=["*"], tags=["healthcare"]) + + # Both match + context_both = PolicyMatchContext( + team_alias="team-a", key_alias="key", model="gpt-4", + tags=["healthcare"], + ) + assert PolicyMatcher.scope_matches(scope, context_both) is True + + # Tag matches, team doesn't + context_wrong_team = PolicyMatchContext( + team_alias="team-b", key_alias="key", model="gpt-4", + tags=["healthcare"], + ) + assert PolicyMatcher.scope_matches(scope, context_wrong_team) is False + + # Team matches, tag doesn't + context_wrong_tag = PolicyMatchContext( + team_alias="team-a", key_alias="key", model="gpt-4", + tags=["finance"], + ) + assert PolicyMatcher.scope_matches(scope, context_wrong_tag) is False + + class TestPolicyMatcherWithAttachments: """Test getting matching policies via attachments.""" diff --git a/ui/litellm-dashboard/src/components/networking.tsx b/ui/litellm-dashboard/src/components/networking.tsx index 023c88c5e83..ecf97cea2dc 100644 --- a/ui/litellm-dashboard/src/components/networking.tsx +++ b/ui/litellm-dashboard/src/components/networking.tsx @@ -5652,6 +5652,68 @@ export const getResolvedGuardrails = async (accessToken: string, policyId: strin } }; +export const resolvePoliciesCall = async ( + accessToken: string, + context: { team_alias?: string; key_alias?: string; model?: string; tags?: string[] } +) => { + try { + const url = proxyBaseUrl + ? `${proxyBaseUrl}/policies/resolve` + : `/policies/resolve`; + const response = await fetch(url, { + method: "POST", + headers: { + [globalLitellmHeaderName]: `Bearer ${accessToken}`, + "Content-Type": "application/json", + }, + body: JSON.stringify(context), + }); + + if (!response.ok) { + const errorData = await response.json(); + const errorMessage = deriveErrorMessage(errorData); + handleError(errorMessage); + throw new Error(errorMessage); + } + + return await response.json(); + } catch (error) { + console.error("Failed to resolve policies:", error); + throw error; + } +}; + +export const estimateAttachmentImpactCall = async ( + accessToken: string, + attachmentData: any +) => { + try { + const url = proxyBaseUrl + ? `${proxyBaseUrl}/policies/attachments/estimate-impact` + : `/policies/attachments/estimate-impact`; + const response = await fetch(url, { + method: "POST", + headers: { + [globalLitellmHeaderName]: `Bearer ${accessToken}`, + "Content-Type": "application/json", + }, + body: JSON.stringify(attachmentData), + }); + + if (!response.ok) { + const errorData = await response.json(); + const errorMessage = deriveErrorMessage(errorData); + handleError(errorMessage); + throw new Error(errorMessage); + } + + return await response.json(); + } catch (error) { + console.error("Failed to estimate attachment impact:", error); + throw error; + } +}; + export const getPromptsList = async (accessToken: string): Promise => { try { const url = proxyBaseUrl ? `${proxyBaseUrl}/prompts/list` : `/prompts/list`; diff --git a/ui/litellm-dashboard/src/components/policies/add_attachment_form.tsx b/ui/litellm-dashboard/src/components/policies/add_attachment_form.tsx index 9198eda8a94..7426f4fefa2 100644 --- a/ui/litellm-dashboard/src/components/policies/add_attachment_form.tsx +++ b/ui/litellm-dashboard/src/components/policies/add_attachment_form.tsx @@ -1,10 +1,12 @@ import React, { useState, useEffect } from "react"; import { Modal, Form, Select, Radio, Divider, Typography } from "antd"; import { Button } from "@tremor/react"; -import { Policy, PolicyAttachmentCreateRequest } from "./types"; -import { teamListCall, keyInfoCall, modelAvailableCall } from "../networking"; +import { Policy } from "./types"; +import { teamListCall, keyListCall, modelAvailableCall, estimateAttachmentImpactCall } from "../networking"; import NotificationsManager from "../molecules/notifications_manager"; import useAuthorized from "@/app/(dashboard)/hooks/useAuthorized"; +import { buildAttachmentData } from "./build_attachment_data"; +import ImpactPreviewAlert from "./impact_preview_alert"; const { Text } = Typography; @@ -34,6 +36,8 @@ const AddAttachmentForm: React.FC = ({ const [isLoadingTeams, setIsLoadingTeams] = useState(false); const [isLoadingKeys, setIsLoadingKeys] = useState(false); const [isLoadingModels, setIsLoadingModels] = useState(false); + const [isEstimating, setIsEstimating] = useState(false); + const [impactResult, setImpactResult] = useState(null); const { userId, userRole } = useAuthorized(); useEffect(() => { @@ -46,33 +50,30 @@ const AddAttachmentForm: React.FC = ({ const loadTeamsKeysAndModels = async () => { if (!accessToken) return; - // Load teams + // Load teams — teamListCall returns a plain array of team objects setIsLoadingTeams(true); try { - // Pass null for organizationID since we're loading all teams the user has access to const teamsResponse = await teamListCall(accessToken, null, userId); - if (teamsResponse?.data) { - const teamAliases = teamsResponse.data - .map((t: any) => t.team_alias) - .filter(Boolean); - setAvailableTeams(teamAliases); - } + const teamsArray = Array.isArray(teamsResponse) ? teamsResponse : (teamsResponse?.data || []); + const teamAliases = teamsArray + .map((t: any) => t.team_alias) + .filter(Boolean); + setAvailableTeams(teamAliases); } catch (error) { console.error("Failed to load teams:", error); } finally { setIsLoadingTeams(false); } - // Load keys + // Load keys — keyListCall returns {keys: [...], total_count, ...} setIsLoadingKeys(true); try { - const keysResponse = await keyInfoCall(accessToken, []); - if (keysResponse?.data) { - const keyAliases = keysResponse.data - .map((k: any) => k.key_alias) - .filter(Boolean); - setAvailableKeys(keyAliases); - } + const keysResponse = await keyListCall(accessToken, null, null, null, null, null, 1, 100); + const keysArray = keysResponse?.keys || keysResponse?.data || []; + const keyAliases = keysArray + .map((k: any) => k.key_alias) + .filter(Boolean); + setAvailableKeys(keyAliases); } catch (error) { console.error("Failed to load keys:", error); } finally { @@ -83,12 +84,11 @@ const AddAttachmentForm: React.FC = ({ setIsLoadingModels(true); try { const modelsResponse = await modelAvailableCall(accessToken, userId || "", userRole || ""); - if (modelsResponse?.data) { - const modelIds = modelsResponse.data - .map((m: any) => m.id || m.model_name) - .filter(Boolean); - setAvailableModels(modelIds); - } + const modelsArray = modelsResponse?.data || (Array.isArray(modelsResponse) ? modelsResponse : []); + const modelIds = modelsArray + .map((m: any) => m.id || m.model_name) + .filter(Boolean); + setAvailableModels(modelIds); } catch (error) { console.error("Failed to load models:", error); } finally { @@ -99,6 +99,28 @@ const AddAttachmentForm: React.FC = ({ const resetForm = () => { form.resetFields(); setScopeType("global"); + setImpactResult(null); + }; + + const getAttachmentData = () => buildAttachmentData(form.getFieldsValue(true), scopeType); + + const handlePreviewImpact = async () => { + if (!accessToken) return; + try { + await form.validateFields(["policy_name"]); + } catch { + return; + } + setIsEstimating(true); + try { + const data = getAttachmentData(); + const result = await estimateAttachmentImpactCall(accessToken, data); + setImpactResult(result); + } catch (error) { + console.error("Failed to estimate impact:", error); + } finally { + setIsEstimating(false); + } }; const handleClose = () => { @@ -110,30 +132,12 @@ const AddAttachmentForm: React.FC = ({ try { setIsSubmitting(true); await form.validateFields(); - const values = form.getFieldsValue(true); if (!accessToken) { throw new Error("No access token available"); } - const data: PolicyAttachmentCreateRequest = { - policy_name: values.policy_name, - }; - - if (scopeType === "global") { - data.scope = "*"; - } else { - if (values.teams && values.teams.length > 0) { - data.teams = values.teams; - } - if (values.keys && values.keys.length > 0) { - data.keys = values.keys; - } - if (values.models && values.models.length > 0) { - data.models = values.models; - } - } - + const data = getAttachmentData(); await createAttachment(accessToken, data); NotificationsManager.success("Attachment created successfully"); @@ -195,8 +199,8 @@ const AddAttachmentForm: React.FC = ({ value={scopeType} onChange={(e) => setScopeType(e.target.value)} > + Specific (teams, keys, models, or tags) Global (applies to all requests) - Specific (teams, keys, or models) @@ -267,13 +271,41 @@ const AddAttachmentForm: React.FC = ({ style={{ width: "100%" }} /> + + + Matches tags from key/team metadata.tags or tags passed dynamically in the request body. Use * as a suffix wildcard (e.g., prod-* matches prod-us, prod-eu). + + } + > + ({ label: t, value: t }))} + filterOption={(input, option) => + (option?.label ?? "").toLowerCase().includes(input.toLowerCase()) + } + /> + + + ({ label: m, value: m }))} + filterOption={(input, option) => + (option?.label ?? "").toLowerCase().includes(input.toLowerCase()) + } + /> + + +