diff --git a/.gitignore b/.gitignore index 9d9e28dc466..0248d68c1e1 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,6 @@ .python-version .venv +.venv_policy_test .env .newenv newenv/* diff --git a/litellm-proxy-extras/litellm_proxy_extras/schema.prisma b/litellm-proxy-extras/litellm_proxy_extras/schema.prisma index 71b398c59a4..52170f2f3e6 100644 --- a/litellm-proxy-extras/litellm_proxy_extras/schema.prisma +++ b/litellm-proxy-extras/litellm_proxy_extras/schema.prisma @@ -124,7 +124,7 @@ model LiteLLM_TeamTable { updated_at DateTime @default(now()) @updatedAt @map("updated_at") model_spend Json @default("{}") model_max_budget Json @default("{}") - router_settings Json? @default("{}") + router_settings Json? @default("{}") team_member_permissions String[] @default([]) model_id Int? @unique // id for LiteLLM_ModelTable -> stores team-level model aliases litellm_organization_table LiteLLM_OrganizationTable? @relation(fields: [organization_id], references: [organization_id]) diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index 2fd03ac2128..0516a4aaa66 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -73,6 +73,7 @@ class SupportedDBObjectType(str, enum.Enum): MODELS = "models" MCP = "mcp" GUARDRAILS = "guardrails" + POLICIES = "policies" VECTOR_STORES = "vector_stores" PASS_THROUGH_ENDPOINTS = "pass_through_endpoints" PROMPTS = "prompts" @@ -844,6 +845,7 @@ class GenerateRequestBase(LiteLLMPydanticObjectBase): model_rpm_limit: Optional[dict] = None model_tpm_limit: Optional[dict] = None guardrails: Optional[List[str]] = None + policies: Optional[List[str]] = None prompts: Optional[List[str]] = None blocked: Optional[bool] = None aliases: Optional[dict] = {} @@ -1477,6 +1479,7 @@ class NewTeamRequest(TeamBase): model_aliases: Optional[dict] = None tags: Optional[list] = None guardrails: Optional[List[str]] = None + policies: Optional[List[str]] = None prompts: Optional[List[str]] = None object_permission: Optional[LiteLLM_ObjectPermissionBase] = None allowed_passthrough_routes: Optional[list] = None @@ -1526,6 +1529,7 @@ class UpdateTeamRequest(LiteLLMPydanticObjectBase): blocked: Optional[bool] = None budget_duration: Optional[str] = None guardrails: Optional[List[str]] = None + policies: Optional[List[str]] = None """ team_id: str # required @@ -1541,6 +1545,7 @@ class UpdateTeamRequest(LiteLLMPydanticObjectBase): tags: Optional[list] = None model_aliases: Optional[dict] = None guardrails: Optional[List[str]] = None + policies: Optional[List[str]] = None object_permission: Optional[LiteLLM_ObjectPermissionBase] = None team_member_budget: Optional[float] = None team_member_budget_duration: Optional[str] = None @@ -3499,6 +3504,7 @@ class PassThroughEndpointLoggingTypedDict(TypedDict): LiteLLM_ManagementEndpoint_MetadataFields_Premium = [ "guardrails", + "policies", "tags", "team_member_key_duration", "prompts", diff --git a/litellm/proxy/litellm_pre_call_utils.py b/litellm/proxy/litellm_pre_call_utils.py index 32cddc0ef58..1d3ef2e10c2 100644 --- a/litellm/proxy/litellm_pre_call_utils.py +++ b/litellm/proxy/litellm_pre_call_utils.py @@ -1311,6 +1311,118 @@ def _add_guardrails_from_key_or_team_metadata( data[metadata_variable_name]["guardrails"] = list(combined_guardrails) +def _add_guardrails_from_policies_in_metadata( + key_metadata: Optional[dict], + team_metadata: Optional[dict], + data: dict, + metadata_variable_name: str, +) -> None: + """ + Helper to resolve guardrails from policies attached to key/team metadata. + + This function: + 1. Gets policy names from key and team metadata + 2. Resolves guardrails from those policies (including inheritance) + 3. Adds resolved guardrails to request metadata + + Args: + key_metadata: The key metadata dictionary to check for policies + team_metadata: The team metadata dictionary to check for policies + data: The request data to update + metadata_variable_name: The name of the metadata field in data + """ + from litellm._logging import verbose_proxy_logger + from litellm.proxy.policy_engine.policy_registry import get_policy_registry + from litellm.proxy.policy_engine.policy_resolver import PolicyResolver + from litellm.proxy.utils import _premium_user_check + from litellm.types.proxy.policy_engine import PolicyMatchContext + + # Collect policy names from key and team metadata + policy_names: set = set() + + # Add key-level policies first + if key_metadata and "policies" in key_metadata: + if ( + isinstance(key_metadata["policies"], list) + and len(key_metadata["policies"]) > 0 + ): + _premium_user_check() + policy_names.update(key_metadata["policies"]) + + # Add team-level policies + if team_metadata and "policies" in team_metadata: + if ( + isinstance(team_metadata["policies"], list) + and len(team_metadata["policies"]) > 0 + ): + _premium_user_check() + policy_names.update(team_metadata["policies"]) + + if not policy_names: + return + + verbose_proxy_logger.debug( + f"Policy engine: resolving guardrails from key/team policies: {policy_names}" + ) + + # Check if policy registry is initialized + registry = get_policy_registry() + if not registry.is_initialized(): + verbose_proxy_logger.debug( + "Policy engine not initialized, skipping policy resolution from metadata" + ) + return + + # Build context for policy resolution (model from request data) + context = PolicyMatchContext(model=data.get("model")) + + # Get all policies from registry + all_policies = registry.get_all_policies() + + # Resolve guardrails from the specified policies + resolved_guardrails: set = set() + for policy_name in policy_names: + if registry.has_policy(policy_name): + resolved_policy = PolicyResolver.resolve_policy_guardrails( + policy_name=policy_name, + policies=all_policies, + context=context, + ) + resolved_guardrails.update(resolved_policy.guardrails) + verbose_proxy_logger.debug( + f"Policy engine: resolved guardrails from policy '{policy_name}': {resolved_policy.guardrails}" + ) + else: + verbose_proxy_logger.warning( + f"Policy engine: policy '{policy_name}' not found in registry" + ) + + if not resolved_guardrails: + return + + # Add resolved guardrails to request metadata + if metadata_variable_name not in data: + data[metadata_variable_name] = {} + + existing_guardrails = data[metadata_variable_name].get("guardrails", []) + if not isinstance(existing_guardrails, list): + existing_guardrails = [] + + # Combine existing guardrails with policy-resolved guardrails (no duplicates) + combined = set(existing_guardrails) + combined.update(resolved_guardrails) + data[metadata_variable_name]["guardrails"] = list(combined) + + # Store applied policies in metadata for tracking + if "applied_policies" not in data[metadata_variable_name]: + data[metadata_variable_name]["applied_policies"] = [] + data[metadata_variable_name]["applied_policies"].extend(list(policy_names)) + + verbose_proxy_logger.debug( + f"Policy engine: added guardrails from key/team policies to request metadata: {list(resolved_guardrails)}" + ) + + def move_guardrails_to_metadata( data: dict, _metadata_variable_name: str, @@ -1321,6 +1433,7 @@ def move_guardrails_to_metadata( - If guardrails set on API Key metadata then sets guardrails on request metadata - If guardrails not set on API key, then checks request metadata + - Adds guardrails from policies attached to key/team metadata - Adds guardrails from policy engine based on team/key/model context """ # Check key-level guardrails @@ -1331,6 +1444,16 @@ def move_guardrails_to_metadata( metadata_variable_name=_metadata_variable_name, ) + ######################################################################################### + # Add guardrails from policies attached to key/team metadata + ######################################################################################### + _add_guardrails_from_policies_in_metadata( + key_metadata=user_api_key_dict.metadata, + team_metadata=user_api_key_dict.team_metadata, + data=data, + metadata_variable_name=_metadata_variable_name, + ) + ######################################################################################### # Add guardrails from policy engine based on team/key/model context ######################################################################################### diff --git a/litellm/proxy/management_endpoints/internal_user_endpoints.py b/litellm/proxy/management_endpoints/internal_user_endpoints.py index 2672c41893d..9a986326c03 100644 --- a/litellm/proxy/management_endpoints/internal_user_endpoints.py +++ b/litellm/proxy/management_endpoints/internal_user_endpoints.py @@ -14,7 +14,6 @@ import asyncio import json import traceback -from litellm._uuid import uuid from datetime import datetime, timezone from typing import Any, Dict, List, Optional, Union, cast @@ -23,6 +22,7 @@ import litellm from litellm._logging import verbose_proxy_logger +from litellm._uuid import uuid from litellm.proxy._types import * from litellm.proxy.auth.user_api_key_auth import user_api_key_auth from litellm.proxy.hooks.user_management_event_hooks import UserManagementEventHooks @@ -355,6 +355,7 @@ async def new_user( - allowed_cache_controls: Optional[list] - List of allowed cache control values. Example - ["no-cache", "no-store"]. See all values - https://docs.litellm.ai/docs/proxy/caching#turn-on--off-caching-per-request- - blocked: Optional[bool] - [Not Implemented Yet] Whether the user is blocked. - guardrails: Optional[List[str]] - [Not Implemented Yet] List of active guardrails for the user + - policies: Optional[List[str]] - List of policy names to apply to the user. Policies define guardrails, conditions, and inheritance rules. - permissions: Optional[dict] - [Not Implemented Yet] User-specific permissions, eg. turning off pii masking. - metadata: Optional[dict] - Metadata for user, store information for user. Example metadata = {"team": "core-infra", "app": "app2", "email": "ishaan@berri.ai" } - max_parallel_requests: Optional[int] - Rate limit a user based on the number of parallel requests. Raises 429 error, if user's parallel requests > x. @@ -1060,6 +1061,7 @@ async def user_update( - allowed_cache_controls: Optional[list] - List of allowed cache control values. Example - ["no-cache", "no-store"]. See all values - https://docs.litellm.ai/docs/proxy/caching#turn-on--off-caching-per-request- - blocked: Optional[bool] - [Not Implemented Yet] Whether the user is blocked. - guardrails: Optional[List[str]] - [Not Implemented Yet] List of active guardrails for the user + - policies: Optional[List[str]] - List of policy names to apply to the user. Policies define guardrails, conditions, and inheritance rules. - permissions: Optional[dict] - [Not Implemented Yet] User-specific permissions, eg. turning off pii masking. - metadata: Optional[dict] - Metadata for user, store information for user. Example metadata = {"team": "core-infra", "app": "app2", "email": "ishaan@berri.ai" } - max_parallel_requests: Optional[int] - Rate limit a user based on the number of parallel requests. Raises 429 error, if user's parallel requests > x. diff --git a/litellm/proxy/management_endpoints/key_management_endpoints.py b/litellm/proxy/management_endpoints/key_management_endpoints.py index e40a44edf5c..ab87e862ea6 100644 --- a/litellm/proxy/management_endpoints/key_management_endpoints.py +++ b/litellm/proxy/management_endpoints/key_management_endpoints.py @@ -14,11 +14,11 @@ import json import secrets import traceback -import yaml from datetime import datetime, timedelta, timezone from typing import Any, Dict, List, Literal, Optional, Tuple, cast -from litellm.litellm_core_utils.safe_json_dumps import safe_dumps + import fastapi +import yaml from fastapi import APIRouter, Depends, Header, HTTPException, Query, Request, status import litellm @@ -31,6 +31,7 @@ UI_SESSION_TOKEN_TEAM_ID, ) from litellm.litellm_core_utils.duration_parser import duration_in_seconds +from litellm.litellm_core_utils.safe_json_dumps import safe_dumps from litellm.proxy._experimental.mcp_server.db import ( rotate_mcp_server_credentials_master_key, ) @@ -1010,6 +1011,7 @@ async def generate_key_fn( - max_parallel_requests: Optional[int] - Rate limit a user based on the number of parallel requests. Raises 429 error, if user's parallel requests > x. - metadata: Optional[dict] - Metadata for key, store information for key. Example metadata = {"team": "core-infra", "app": "app2", "email": "ishaan@berri.ai" } - guardrails: Optional[List[str]] - List of active guardrails for the key + - policies: Optional[List[str]] - List of policy names to apply to the key. Policies define guardrails, conditions, and inheritance rules. - disable_global_guardrails: Optional[bool] - Whether to disable global guardrails for the key. - permissions: Optional[dict] - key-specific permissions. Currently just used for turning off pii masking (if connected). Example - {"pii": false} - model_max_budget: Optional[Dict[str, BudgetConfig]] - Model-specific budgets {"gpt-4": {"budget_limit": 0.0005, "time_period": "30d"}}}. IF null or {} then no model specific budget. @@ -1480,6 +1482,7 @@ async def update_key_fn( - permissions: Optional[dict] - Key-specific permissions - send_invite_email: Optional[bool] - Send invite email to user_id - guardrails: Optional[List[str]] - List of active guardrails for the key + - policies: Optional[List[str]] - List of policy names to apply to the key. Policies define guardrails, conditions, and inheritance rules. - disable_global_guardrails: Optional[bool] - Whether to disable global guardrails for the key. - prompts: Optional[List[str]] - List of prompts that the key is allowed to use. - blocked: Optional[bool] - Whether the key is blocked @@ -2077,6 +2080,7 @@ async def generate_key_helper_fn( # noqa: PLR0915 model_rpm_limit: Optional[dict] = None, model_tpm_limit: Optional[dict] = None, guardrails: Optional[list] = None, + policies: Optional[list] = None, prompts: Optional[list] = None, teams: Optional[list] = None, organization_id: Optional[str] = None, @@ -2139,6 +2143,9 @@ async def generate_key_helper_fn( # noqa: PLR0915 if guardrails is not None: metadata = metadata or {} metadata["guardrails"] = guardrails + if policies is not None: + metadata = metadata or {} + metadata["policies"] = policies if prompts is not None: metadata = metadata or {} metadata["prompts"] = prompts diff --git a/litellm/proxy/management_endpoints/team_endpoints.py b/litellm/proxy/management_endpoints/team_endpoints.py index 4d313fb1235..c77b60649ab 100644 --- a/litellm/proxy/management_endpoints/team_endpoints.py +++ b/litellm/proxy/management_endpoints/team_endpoints.py @@ -22,11 +22,13 @@ import litellm from litellm._logging import verbose_proxy_logger from litellm._uuid import uuid +from litellm.litellm_core_utils.safe_json_dumps import safe_dumps from litellm.proxy._types import ( BlockTeamRequest, CommonProxyErrors, DeleteTeamRequest, LiteLLM_AuditLogs, + LiteLLM_DeletedTeamTable, LiteLLM_ManagementEndpoint_MetadataFields, LiteLLM_ManagementEndpoint_MetadataFields_Premium, LiteLLM_ModelTable, @@ -34,7 +36,6 @@ LiteLLM_OrganizationTableWithMembers, LiteLLM_TeamMembership, LiteLLM_TeamTable, - LiteLLM_DeletedTeamTable, LiteLLM_TeamTableCachedObj, LiteLLM_UserTable, LiteLLM_VerificationToken, @@ -102,7 +103,7 @@ TeamMemberAddResult, UpdateTeamMemberPermissionsRequest, ) -from litellm.litellm_core_utils.safe_json_dumps import safe_dumps + router = APIRouter() @@ -689,6 +690,7 @@ async def new_team( # noqa: PLR0915 - organization_id: Optional[str] - The organization id of the team. Default is None. Create via `/organization/new`. - model_aliases: Optional[dict] - Model aliases for the team. [Docs](https://docs.litellm.ai/docs/proxy/team_based_routing#create-team-with-model-alias) - guardrails: Optional[List[str]] - Guardrails for the team. [Docs](https://docs.litellm.ai/docs/proxy/guardrails) + - policies: Optional[List[str]] - Policies for the team. [Docs](https://docs.litellm.ai/docs/proxy/guardrails/guardrail_policies) - disable_global_guardrails: Optional[bool] - Whether to disable global guardrails for the key. - object_permission: Optional[LiteLLM_ObjectPermissionBase] - team-specific object permission. Example - {"vector_stores": ["vector_store_1", "vector_store_2"], "agents": ["agent_1", "agent_2"], "agent_access_groups": ["dev_group"]}. IF null or {} then no object permission. - team_member_budget: Optional[float] - The maximum budget allocated to an individual team member. @@ -1228,6 +1230,7 @@ async def update_team( # noqa: PLR0915 - organization_id: Optional[str] - The organization id of the team. Default is None. Create via `/organization/new`. - model_aliases: Optional[dict] - Model aliases for the team. [Docs](https://docs.litellm.ai/docs/proxy/team_based_routing#create-team-with-model-alias) - guardrails: Optional[List[str]] - Guardrails for the team. [Docs](https://docs.litellm.ai/docs/proxy/guardrails) + - policies: Optional[List[str]] - Policies for the team. [Docs](https://docs.litellm.ai/docs/proxy/guardrails/guardrail_policies) - disable_global_guardrails: Optional[bool] - Whether to disable global guardrails for the key. - object_permission: Optional[LiteLLM_ObjectPermissionBase] - team-specific object permission. Example - {"vector_stores": ["vector_store_1", "vector_store_2"], "agents": ["agent_1", "agent_2"], "agent_access_groups": ["dev_group"]}. IF null or {} then no object permission. - team_member_budget: Optional[float] - The maximum budget allocated to an individual team member. diff --git a/litellm/proxy/policy_engine/attachment_registry.py b/litellm/proxy/policy_engine/attachment_registry.py index b5d6f2fb745..4a335b54747 100644 --- a/litellm/proxy/policy_engine/attachment_registry.py +++ b/litellm/proxy/policy_engine/attachment_registry.py @@ -5,14 +5,20 @@ This allows the same policy to be attached to multiple scopes. """ -from typing import Any, Dict, List, Optional +from datetime import datetime, timezone +from typing import TYPE_CHECKING, Any, Dict, List, Optional from litellm._logging import verbose_proxy_logger from litellm.types.proxy.policy_engine import ( PolicyAttachment, + PolicyAttachmentCreateRequest, + PolicyAttachmentDBResponse, PolicyMatchContext, ) +if TYPE_CHECKING: + from litellm.proxy.utils import PrismaClient + class AttachmentRegistry: """ @@ -188,6 +194,238 @@ def remove_attachments_for_policy(self, policy_name: str) -> int: ) return removed_count + def remove_attachment_by_id(self, attachment_id: str) -> bool: + """ + Remove an attachment by its ID (for DB-synced attachments). + + Args: + attachment_id: The ID of the attachment to remove + + Returns: + True if removed, False if not found + """ + # Note: In-memory attachments don't have IDs, so this is primarily + # for consistency after DB operations + return False + + # ───────────────────────────────────────────────────────────────────────── + # Database CRUD Methods + # ───────────────────────────────────────────────────────────────────────── + + async def add_attachment_to_db( + self, + attachment_request: PolicyAttachmentCreateRequest, + prisma_client: "PrismaClient", + created_by: Optional[str] = None, + ) -> PolicyAttachmentDBResponse: + """ + Add a policy attachment to the database. + + Args: + attachment_request: The attachment creation request + prisma_client: The Prisma client instance + created_by: User who created the attachment + + Returns: + PolicyAttachmentDBResponse with the created attachment + """ + try: + created_attachment = ( + await prisma_client.db.litellm_policyattachmenttable.create( + data={ + "policy_name": attachment_request.policy_name, + "scope": attachment_request.scope, + "teams": attachment_request.teams or [], + "keys": attachment_request.keys or [], + "models": attachment_request.models or [], + "created_at": datetime.now(timezone.utc), + "updated_at": datetime.now(timezone.utc), + "created_by": created_by, + "updated_by": created_by, + } + ) + ) + + # Also add to in-memory registry + attachment = PolicyAttachment( + policy=attachment_request.policy_name, + scope=attachment_request.scope, + teams=attachment_request.teams, + keys=attachment_request.keys, + models=attachment_request.models, + ) + self.add_attachment(attachment) + + return PolicyAttachmentDBResponse( + attachment_id=created_attachment.attachment_id, + policy_name=created_attachment.policy_name, + scope=created_attachment.scope, + teams=created_attachment.teams or [], + keys=created_attachment.keys or [], + models=created_attachment.models or [], + created_at=created_attachment.created_at, + updated_at=created_attachment.updated_at, + created_by=created_attachment.created_by, + updated_by=created_attachment.updated_by, + ) + except Exception as e: + verbose_proxy_logger.exception(f"Error adding attachment to DB: {e}") + raise Exception(f"Error adding attachment to DB: {str(e)}") + + async def delete_attachment_from_db( + self, + attachment_id: str, + prisma_client: "PrismaClient", + ) -> Dict[str, str]: + """ + Delete a policy attachment from the database. + + Args: + attachment_id: The ID of the attachment to delete + prisma_client: The Prisma client instance + + Returns: + Dict with success message + """ + try: + # Get attachment before deleting + attachment = ( + await prisma_client.db.litellm_policyattachmenttable.find_unique( + where={"attachment_id": attachment_id} + ) + ) + + if attachment is None: + raise Exception(f"Attachment with ID {attachment_id} not found") + + # Delete from DB + await prisma_client.db.litellm_policyattachmenttable.delete( + where={"attachment_id": attachment_id} + ) + + # Note: In-memory attachments don't have IDs, so we need to sync from DB + # to properly update in-memory state + await self.sync_attachments_from_db(prisma_client) + + return {"message": f"Attachment {attachment_id} deleted successfully"} + except Exception as e: + verbose_proxy_logger.exception(f"Error deleting attachment from DB: {e}") + raise Exception(f"Error deleting attachment from DB: {str(e)}") + + async def get_attachment_by_id_from_db( + self, + attachment_id: str, + prisma_client: "PrismaClient", + ) -> Optional[PolicyAttachmentDBResponse]: + """ + Get a policy attachment by ID from the database. + + Args: + attachment_id: The ID of the attachment to retrieve + prisma_client: The Prisma client instance + + Returns: + PolicyAttachmentDBResponse if found, None otherwise + """ + try: + attachment = ( + await prisma_client.db.litellm_policyattachmenttable.find_unique( + where={"attachment_id": attachment_id} + ) + ) + + if attachment is None: + return None + + return PolicyAttachmentDBResponse( + attachment_id=attachment.attachment_id, + policy_name=attachment.policy_name, + scope=attachment.scope, + teams=attachment.teams or [], + keys=attachment.keys or [], + models=attachment.models or [], + created_at=attachment.created_at, + updated_at=attachment.updated_at, + created_by=attachment.created_by, + updated_by=attachment.updated_by, + ) + except Exception as e: + verbose_proxy_logger.exception(f"Error getting attachment from DB: {e}") + raise Exception(f"Error getting attachment from DB: {str(e)}") + + async def get_all_attachments_from_db( + self, + prisma_client: "PrismaClient", + ) -> List[PolicyAttachmentDBResponse]: + """ + Get all policy attachments from the database. + + Args: + prisma_client: The Prisma client instance + + Returns: + List of PolicyAttachmentDBResponse objects + """ + try: + attachments = ( + await prisma_client.db.litellm_policyattachmenttable.find_many( + order={"created_at": "desc"}, + ) + ) + + return [ + PolicyAttachmentDBResponse( + attachment_id=a.attachment_id, + policy_name=a.policy_name, + scope=a.scope, + teams=a.teams or [], + keys=a.keys or [], + models=a.models or [], + created_at=a.created_at, + updated_at=a.updated_at, + created_by=a.created_by, + updated_by=a.updated_by, + ) + for a in attachments + ] + except Exception as e: + verbose_proxy_logger.exception(f"Error getting attachments from DB: {e}") + raise Exception(f"Error getting attachments from DB: {str(e)}") + + async def sync_attachments_from_db( + self, + prisma_client: "PrismaClient", + ) -> None: + """ + Sync policy attachments from the database to in-memory registry. + + Args: + prisma_client: The Prisma client instance + """ + try: + attachments = await self.get_all_attachments_from_db(prisma_client) + + # Clear existing attachments and reload from DB + self._attachments = [] + + for attachment_response in attachments: + attachment = PolicyAttachment( + policy=attachment_response.policy_name, + scope=attachment_response.scope, + teams=attachment_response.teams if attachment_response.teams else None, + keys=attachment_response.keys if attachment_response.keys else None, + models=attachment_response.models if attachment_response.models else None, + ) + self._attachments.append(attachment) + + self._initialized = True + verbose_proxy_logger.info( + f"Synced {len(attachments)} attachments from DB to in-memory registry" + ) + except Exception as e: + verbose_proxy_logger.exception(f"Error syncing attachments from DB: {e}") + raise Exception(f"Error syncing attachments from DB: {str(e)}") + # Global singleton instance _attachment_registry: Optional[AttachmentRegistry] = None diff --git a/litellm/proxy/policy_engine/policy_endpoints.py b/litellm/proxy/policy_engine/policy_endpoints.py new file mode 100644 index 00000000000..615e153862a --- /dev/null +++ b/litellm/proxy/policy_engine/policy_endpoints.py @@ -0,0 +1,578 @@ +""" +CRUD ENDPOINTS FOR POLICIES + +Provides REST API endpoints for managing policies and policy attachments. +""" + +from fastapi import APIRouter, Depends, HTTPException + +from litellm._logging import verbose_proxy_logger +from litellm.proxy._types import UserAPIKeyAuth +from litellm.proxy.auth.user_api_key_auth import user_api_key_auth +from litellm.proxy.policy_engine.attachment_registry import get_attachment_registry +from litellm.proxy.policy_engine.policy_registry import get_policy_registry +from litellm.types.proxy.policy_engine import ( + PolicyAttachmentCreateRequest, + PolicyAttachmentDBResponse, + PolicyAttachmentListResponse, + PolicyCreateRequest, + PolicyDBResponse, + PolicyListDBResponse, + PolicyUpdateRequest, +) + +router = APIRouter() + +# Get singleton instances +POLICY_REGISTRY = get_policy_registry() +ATTACHMENT_REGISTRY = get_attachment_registry() + + +# ───────────────────────────────────────────────────────────────────────────── +# Policy CRUD Endpoints +# ───────────────────────────────────────────────────────────────────────────── + + +@router.get( + "/policies/list", + tags=["Policies"], + dependencies=[Depends(user_api_key_auth)], + response_model=PolicyListDBResponse, +) +async def list_policies(): + """ + List all policies from the database. + + Example Request: + ```bash + curl -X GET "http://localhost:4000/policies/list" \\ + -H "Authorization: Bearer " + ``` + + Example Response: + ```json + { + "policies": [ + { + "policy_id": "123e4567-e89b-12d3-a456-426614174000", + "policy_name": "global-baseline", + "inherit": null, + "description": "Base guardrails for all requests", + "guardrails_add": ["pii_masking"], + "guardrails_remove": [], + "condition": null, + "created_at": "2024-01-01T00:00:00Z", + "updated_at": "2024-01-01T00:00:00Z" + } + ], + "total_count": 1 + } + ``` + """ + from litellm.proxy.proxy_server import prisma_client + + if prisma_client is None: + raise HTTPException(status_code=500, detail="Database not connected") + + try: + policies = await POLICY_REGISTRY.get_all_policies_from_db(prisma_client) + return PolicyListDBResponse(policies=policies, total_count=len(policies)) + except Exception as e: + verbose_proxy_logger.exception(f"Error listing policies: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.post( + "/policies", + tags=["Policies"], + dependencies=[Depends(user_api_key_auth)], + response_model=PolicyDBResponse, +) +async def create_policy( + request: PolicyCreateRequest, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + """ + Create a new policy. + + Example Request: + ```bash + curl -X POST "http://localhost:4000/policies" \\ + -H "Authorization: Bearer " \\ + -H "Content-Type: application/json" \\ + -d '{ + "policy_name": "global-baseline", + "description": "Base guardrails for all requests", + "guardrails_add": ["pii_masking", "prompt_injection"], + "guardrails_remove": [] + }' + ``` + + Example Response: + ```json + { + "policy_id": "123e4567-e89b-12d3-a456-426614174000", + "policy_name": "global-baseline", + "inherit": null, + "description": "Base guardrails for all requests", + "guardrails_add": ["pii_masking", "prompt_injection"], + "guardrails_remove": [], + "condition": null, + "created_at": "2024-01-01T00:00:00Z", + "updated_at": "2024-01-01T00:00:00Z" + } + ``` + """ + from litellm.proxy.proxy_server import prisma_client + + if prisma_client is None: + raise HTTPException(status_code=500, detail="Database not connected") + + try: + created_by = user_api_key_dict.user_id + result = await POLICY_REGISTRY.add_policy_to_db( + policy_request=request, + prisma_client=prisma_client, + created_by=created_by, + ) + return result + except Exception as e: + verbose_proxy_logger.exception(f"Error creating policy: {e}") + if "unique constraint" in str(e).lower(): + raise HTTPException( + status_code=400, + detail=f"Policy with name '{request.policy_name}' already exists", + ) + raise HTTPException(status_code=500, detail=str(e)) + + +@router.get( + "/policies/{policy_id}", + tags=["Policies"], + dependencies=[Depends(user_api_key_auth)], + response_model=PolicyDBResponse, +) +async def get_policy(policy_id: str): + """ + Get a policy by ID. + + Example Request: + ```bash + curl -X GET "http://localhost:4000/policies/123e4567-e89b-12d3-a456-426614174000" \\ + -H "Authorization: Bearer " + ``` + """ + from litellm.proxy.proxy_server import prisma_client + + if prisma_client is None: + raise HTTPException(status_code=500, detail="Database not connected") + + try: + result = await POLICY_REGISTRY.get_policy_by_id_from_db( + policy_id=policy_id, + prisma_client=prisma_client, + ) + if result is None: + raise HTTPException( + status_code=404, detail=f"Policy with ID {policy_id} not found" + ) + return result + except HTTPException: + raise + except Exception as e: + verbose_proxy_logger.exception(f"Error getting policy: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.put( + "/policies/{policy_id}", + tags=["Policies"], + dependencies=[Depends(user_api_key_auth)], + response_model=PolicyDBResponse, +) +async def update_policy( + policy_id: str, + request: PolicyUpdateRequest, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + """ + Update an existing policy. + + Example Request: + ```bash + curl -X PUT "http://localhost:4000/policies/123e4567-e89b-12d3-a456-426614174000" \\ + -H "Authorization: Bearer " \\ + -H "Content-Type: application/json" \\ + -d '{ + "description": "Updated description", + "guardrails_add": ["pii_masking", "toxicity_filter"] + }' + ``` + """ + from litellm.proxy.proxy_server import prisma_client + + if prisma_client is None: + raise HTTPException(status_code=500, detail="Database not connected") + + try: + # Check if policy exists + existing = await POLICY_REGISTRY.get_policy_by_id_from_db( + policy_id=policy_id, + prisma_client=prisma_client, + ) + if existing is None: + raise HTTPException( + status_code=404, detail=f"Policy with ID {policy_id} not found" + ) + + updated_by = user_api_key_dict.user_id + result = await POLICY_REGISTRY.update_policy_in_db( + policy_id=policy_id, + policy_request=request, + prisma_client=prisma_client, + updated_by=updated_by, + ) + return result + except HTTPException: + raise + except Exception as e: + verbose_proxy_logger.exception(f"Error updating policy: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.delete( + "/policies/{policy_id}", + tags=["Policies"], + dependencies=[Depends(user_api_key_auth)], +) +async def delete_policy(policy_id: str): + """ + Delete a policy. + + Example Request: + ```bash + curl -X DELETE "http://localhost:4000/policies/123e4567-e89b-12d3-a456-426614174000" \\ + -H "Authorization: Bearer " + ``` + + Example Response: + ```json + { + "message": "Policy 123e4567-e89b-12d3-a456-426614174000 deleted successfully" + } + ``` + """ + from litellm.proxy.proxy_server import prisma_client + + if prisma_client is None: + raise HTTPException(status_code=500, detail="Database not connected") + + try: + # Check if policy exists + existing = await POLICY_REGISTRY.get_policy_by_id_from_db( + policy_id=policy_id, + prisma_client=prisma_client, + ) + if existing is None: + raise HTTPException( + status_code=404, detail=f"Policy with ID {policy_id} not found" + ) + + result = await POLICY_REGISTRY.delete_policy_from_db( + policy_id=policy_id, + prisma_client=prisma_client, + ) + return result + except HTTPException: + raise + except Exception as e: + verbose_proxy_logger.exception(f"Error deleting policy: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.get( + "/policies/{policy_id}/resolved-guardrails", + tags=["Policies"], + dependencies=[Depends(user_api_key_auth)], +) +async def get_resolved_guardrails(policy_id: str): + """ + Get the resolved guardrails for a policy (including inherited guardrails). + + This endpoint resolves the full inheritance chain and returns the final + set of guardrails that would be applied for this policy. + + Example Request: + ```bash + curl -X GET "http://localhost:4000/policies/123e4567-e89b-12d3-a456-426614174000/resolved-guardrails" \\ + -H "Authorization: Bearer " + ``` + + Example Response: + ```json + { + "policy_id": "123e4567-e89b-12d3-a456-426614174000", + "policy_name": "healthcare-compliance", + "resolved_guardrails": ["pii_masking", "prompt_injection", "toxicity_filter"] + } + ``` + """ + from litellm.proxy.proxy_server import prisma_client + + if prisma_client is None: + raise HTTPException(status_code=500, detail="Database not connected") + + try: + # Get the policy + policy = await POLICY_REGISTRY.get_policy_by_id_from_db( + policy_id=policy_id, + prisma_client=prisma_client, + ) + if policy is None: + raise HTTPException( + status_code=404, detail=f"Policy with ID {policy_id} not found" + ) + + # Resolve guardrails + resolved = await POLICY_REGISTRY.resolve_guardrails_from_db( + policy_name=policy.policy_name, + prisma_client=prisma_client, + ) + + return { + "policy_id": policy.policy_id, + "policy_name": policy.policy_name, + "resolved_guardrails": resolved, + } + except HTTPException: + raise + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + except Exception as e: + verbose_proxy_logger.exception(f"Error resolving guardrails: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +# ───────────────────────────────────────────────────────────────────────────── +# Policy Attachment CRUD Endpoints +# ───────────────────────────────────────────────────────────────────────────── + + +@router.get( + "/policies/attachments/list", + tags=["Policies"], + dependencies=[Depends(user_api_key_auth)], + response_model=PolicyAttachmentListResponse, +) +async def list_policy_attachments(): + """ + List all policy attachments from the database. + + Example Request: + ```bash + curl -X GET "http://localhost:4000/policies/attachments/list" \\ + -H "Authorization: Bearer " + ``` + + Example Response: + ```json + { + "attachments": [ + { + "attachment_id": "123e4567-e89b-12d3-a456-426614174000", + "policy_name": "global-baseline", + "scope": "*", + "teams": [], + "keys": [], + "models": [], + "created_at": "2024-01-01T00:00:00Z", + "updated_at": "2024-01-01T00:00:00Z" + } + ], + "total_count": 1 + } + ``` + """ + from litellm.proxy.proxy_server import prisma_client + + if prisma_client is None: + raise HTTPException(status_code=500, detail="Database not connected") + + try: + attachments = await ATTACHMENT_REGISTRY.get_all_attachments_from_db( + prisma_client + ) + return PolicyAttachmentListResponse( + attachments=attachments, total_count=len(attachments) + ) + except Exception as e: + verbose_proxy_logger.exception(f"Error listing policy attachments: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.post( + "/policies/attachments", + tags=["Policies"], + dependencies=[Depends(user_api_key_auth)], + response_model=PolicyAttachmentDBResponse, +) +async def create_policy_attachment( + request: PolicyAttachmentCreateRequest, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + """ + Create a new policy attachment. + + Example Request: + ```bash + curl -X POST "http://localhost:4000/policies/attachments" \\ + -H "Authorization: Bearer " \\ + -H "Content-Type: application/json" \\ + -d '{ + "policy_name": "global-baseline", + "scope": "*" + }' + ``` + + Example with team-specific attachment: + ```bash + curl -X POST "http://localhost:4000/policies/attachments" \\ + -H "Authorization: Bearer " \\ + -H "Content-Type: application/json" \\ + -d '{ + "policy_name": "healthcare-compliance", + "teams": ["healthcare-team", "medical-research"] + }' + ``` + + Example Response: + ```json + { + "attachment_id": "123e4567-e89b-12d3-a456-426614174000", + "policy_name": "global-baseline", + "scope": "*", + "teams": [], + "keys": [], + "models": [], + "created_at": "2024-01-01T00:00:00Z", + "updated_at": "2024-01-01T00:00:00Z" + } + ``` + """ + from litellm.proxy.proxy_server import prisma_client + + if prisma_client is None: + raise HTTPException(status_code=500, detail="Database not connected") + + try: + # Verify the policy exists + policy = await POLICY_REGISTRY.get_all_policies_from_db(prisma_client) + policy_names = [p.policy_name for p in policy] + if request.policy_name not in policy_names: + raise HTTPException( + status_code=404, + detail=f"Policy '{request.policy_name}' not found. Create the policy first.", + ) + + created_by = user_api_key_dict.user_id + result = await ATTACHMENT_REGISTRY.add_attachment_to_db( + attachment_request=request, + prisma_client=prisma_client, + created_by=created_by, + ) + return result + except HTTPException: + raise + except Exception as e: + verbose_proxy_logger.exception(f"Error creating policy attachment: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.get( + "/policies/attachments/{attachment_id}", + tags=["Policies"], + dependencies=[Depends(user_api_key_auth)], + response_model=PolicyAttachmentDBResponse, +) +async def get_policy_attachment(attachment_id: str): + """ + Get a policy attachment by ID. + + Example Request: + ```bash + curl -X GET "http://localhost:4000/policies/attachments/123e4567-e89b-12d3-a456-426614174000" \\ + -H "Authorization: Bearer " + ``` + """ + from litellm.proxy.proxy_server import prisma_client + + if prisma_client is None: + raise HTTPException(status_code=500, detail="Database not connected") + + try: + result = await ATTACHMENT_REGISTRY.get_attachment_by_id_from_db( + attachment_id=attachment_id, + prisma_client=prisma_client, + ) + if result is None: + raise HTTPException( + status_code=404, + detail=f"Attachment with ID {attachment_id} not found", + ) + return result + except HTTPException: + raise + except Exception as e: + verbose_proxy_logger.exception(f"Error getting policy attachment: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.delete( + "/policies/attachments/{attachment_id}", + tags=["Policies"], + dependencies=[Depends(user_api_key_auth)], +) +async def delete_policy_attachment(attachment_id: str): + """ + Delete a policy attachment. + + Example Request: + ```bash + curl -X DELETE "http://localhost:4000/policies/attachments/123e4567-e89b-12d3-a456-426614174000" \\ + -H "Authorization: Bearer " + ``` + + Example Response: + ```json + { + "message": "Attachment 123e4567-e89b-12d3-a456-426614174000 deleted successfully" + } + ``` + """ + from litellm.proxy.proxy_server import prisma_client + + if prisma_client is None: + raise HTTPException(status_code=500, detail="Database not connected") + + try: + # Check if attachment exists + existing = await ATTACHMENT_REGISTRY.get_attachment_by_id_from_db( + attachment_id=attachment_id, + prisma_client=prisma_client, + ) + if existing is None: + raise HTTPException( + status_code=404, + detail=f"Attachment with ID {attachment_id} not found", + ) + + result = await ATTACHMENT_REGISTRY.delete_attachment_from_db( + attachment_id=attachment_id, + prisma_client=prisma_client, + ) + return result + except HTTPException: + raise + except Exception as e: + verbose_proxy_logger.exception(f"Error deleting policy attachment: {e}") + raise HTTPException(status_code=500, detail=str(e)) diff --git a/litellm/proxy/policy_engine/policy_registry.py b/litellm/proxy/policy_engine/policy_registry.py index 68485f92489..6a251a584cb 100644 --- a/litellm/proxy/policy_engine/policy_registry.py +++ b/litellm/proxy/policy_engine/policy_registry.py @@ -7,15 +7,22 @@ by policy_attachments (see AttachmentRegistry). """ -from typing import Any, Dict, List, Optional +from datetime import datetime, timezone +from typing import TYPE_CHECKING, Any, Dict, List, Optional from litellm._logging import verbose_proxy_logger from litellm.types.proxy.policy_engine import ( Policy, PolicyCondition, + PolicyCreateRequest, + PolicyDBResponse, PolicyGuardrails, + PolicyUpdateRequest, ) +if TYPE_CHECKING: + from litellm.proxy.utils import PrismaClient + class PolicyRegistry: """ @@ -178,6 +185,367 @@ def remove_policy(self, policy_name: str) -> bool: return True return False + # ───────────────────────────────────────────────────────────────────────── + # Database CRUD Methods + # ───────────────────────────────────────────────────────────────────────── + + async def add_policy_to_db( + self, + policy_request: PolicyCreateRequest, + prisma_client: "PrismaClient", + created_by: Optional[str] = None, + ) -> PolicyDBResponse: + """ + Add a policy to the database. + + Args: + policy_request: The policy creation request + prisma_client: The Prisma client instance + created_by: User who created the policy + + Returns: + PolicyDBResponse with the created policy + """ + try: + from prisma import Json + + # Build data dict, only include condition if it's set + data: Dict[str, Any] = { + "policy_name": policy_request.policy_name, + "guardrails_add": policy_request.guardrails_add or [], + "guardrails_remove": policy_request.guardrails_remove or [], + "created_at": datetime.now(timezone.utc), + "updated_at": datetime.now(timezone.utc), + } + + # Only add optional fields if they have values + if policy_request.inherit is not None: + data["inherit"] = policy_request.inherit + if policy_request.description is not None: + data["description"] = policy_request.description + if created_by is not None: + data["created_by"] = created_by + data["updated_by"] = created_by + if policy_request.condition is not None: + data["condition"] = Json(policy_request.condition.model_dump()) + + created_policy = await prisma_client.db.litellm_policytable.create( + data=data + ) + + # Also add to in-memory registry + policy = self._parse_policy( + policy_request.policy_name, + { + "inherit": policy_request.inherit, + "description": policy_request.description, + "guardrails": { + "add": policy_request.guardrails_add, + "remove": policy_request.guardrails_remove, + }, + "condition": policy_request.condition.model_dump() + if policy_request.condition + else None, + }, + ) + self.add_policy(policy_request.policy_name, policy) + + return PolicyDBResponse( + policy_id=created_policy.policy_id, + policy_name=created_policy.policy_name, + inherit=created_policy.inherit, + description=created_policy.description, + guardrails_add=created_policy.guardrails_add or [], + guardrails_remove=created_policy.guardrails_remove or [], + condition=created_policy.condition, + created_at=created_policy.created_at, + updated_at=created_policy.updated_at, + created_by=created_policy.created_by, + updated_by=created_policy.updated_by, + ) + except Exception as e: + verbose_proxy_logger.exception(f"Error adding policy to DB: {e}") + raise Exception(f"Error adding policy to DB: {str(e)}") + + async def update_policy_in_db( + self, + policy_id: str, + policy_request: PolicyUpdateRequest, + prisma_client: "PrismaClient", + updated_by: Optional[str] = None, + ) -> PolicyDBResponse: + """ + Update a policy in the database. + + Args: + policy_id: The ID of the policy to update + policy_request: The policy update request + prisma_client: The Prisma client instance + updated_by: User who updated the policy + + Returns: + PolicyDBResponse with the updated policy + """ + try: + # Build update data - only include fields that are set + update_data: Dict[str, Any] = { + "updated_at": datetime.now(timezone.utc), + "updated_by": updated_by, + } + + if policy_request.policy_name is not None: + update_data["policy_name"] = policy_request.policy_name + if policy_request.inherit is not None: + update_data["inherit"] = policy_request.inherit + if policy_request.description is not None: + update_data["description"] = policy_request.description + if policy_request.guardrails_add is not None: + update_data["guardrails_add"] = policy_request.guardrails_add + if policy_request.guardrails_remove is not None: + update_data["guardrails_remove"] = policy_request.guardrails_remove + if policy_request.condition is not None: + from prisma import Json + update_data["condition"] = Json(policy_request.condition.model_dump()) + + updated_policy = await prisma_client.db.litellm_policytable.update( + where={"policy_id": policy_id}, + data=update_data, + ) + + # Update in-memory registry + policy = self._parse_policy( + updated_policy.policy_name, + { + "inherit": updated_policy.inherit, + "description": updated_policy.description, + "guardrails": { + "add": updated_policy.guardrails_add, + "remove": updated_policy.guardrails_remove, + }, + "condition": updated_policy.condition, + }, + ) + self.add_policy(updated_policy.policy_name, policy) + + return PolicyDBResponse( + policy_id=updated_policy.policy_id, + policy_name=updated_policy.policy_name, + inherit=updated_policy.inherit, + description=updated_policy.description, + guardrails_add=updated_policy.guardrails_add or [], + guardrails_remove=updated_policy.guardrails_remove or [], + condition=updated_policy.condition, + created_at=updated_policy.created_at, + updated_at=updated_policy.updated_at, + created_by=updated_policy.created_by, + updated_by=updated_policy.updated_by, + ) + except Exception as e: + verbose_proxy_logger.exception(f"Error updating policy in DB: {e}") + raise Exception(f"Error updating policy in DB: {str(e)}") + + async def delete_policy_from_db( + self, + policy_id: str, + prisma_client: "PrismaClient", + ) -> Dict[str, str]: + """ + Delete a policy from the database. + + Args: + policy_id: The ID of the policy to delete + prisma_client: The Prisma client instance + + Returns: + Dict with success message + """ + try: + # Get policy name before deleting + policy = await prisma_client.db.litellm_policytable.find_unique( + where={"policy_id": policy_id} + ) + + if policy is None: + raise Exception(f"Policy with ID {policy_id} not found") + + # Delete from DB + await prisma_client.db.litellm_policytable.delete( + where={"policy_id": policy_id} + ) + + # Remove from in-memory registry + self.remove_policy(policy.policy_name) + + return {"message": f"Policy {policy_id} deleted successfully"} + except Exception as e: + verbose_proxy_logger.exception(f"Error deleting policy from DB: {e}") + raise Exception(f"Error deleting policy from DB: {str(e)}") + + async def get_policy_by_id_from_db( + self, + policy_id: str, + prisma_client: "PrismaClient", + ) -> Optional[PolicyDBResponse]: + """ + Get a policy by ID from the database. + + Args: + policy_id: The ID of the policy to retrieve + prisma_client: The Prisma client instance + + Returns: + PolicyDBResponse if found, None otherwise + """ + try: + policy = await prisma_client.db.litellm_policytable.find_unique( + where={"policy_id": policy_id} + ) + + if policy is None: + return None + + return PolicyDBResponse( + policy_id=policy.policy_id, + policy_name=policy.policy_name, + inherit=policy.inherit, + description=policy.description, + guardrails_add=policy.guardrails_add or [], + guardrails_remove=policy.guardrails_remove or [], + condition=policy.condition, + created_at=policy.created_at, + updated_at=policy.updated_at, + created_by=policy.created_by, + updated_by=policy.updated_by, + ) + except Exception as e: + verbose_proxy_logger.exception(f"Error getting policy from DB: {e}") + raise Exception(f"Error getting policy from DB: {str(e)}") + + async def get_all_policies_from_db( + self, + prisma_client: "PrismaClient", + ) -> List[PolicyDBResponse]: + """ + Get all policies from the database. + + Args: + prisma_client: The Prisma client instance + + Returns: + List of PolicyDBResponse objects + """ + try: + policies = await prisma_client.db.litellm_policytable.find_many( + order={"created_at": "desc"}, + ) + + return [ + PolicyDBResponse( + policy_id=p.policy_id, + policy_name=p.policy_name, + inherit=p.inherit, + description=p.description, + guardrails_add=p.guardrails_add or [], + guardrails_remove=p.guardrails_remove or [], + condition=p.condition, + created_at=p.created_at, + updated_at=p.updated_at, + created_by=p.created_by, + updated_by=p.updated_by, + ) + for p in policies + ] + except Exception as e: + verbose_proxy_logger.exception(f"Error getting policies from DB: {e}") + raise Exception(f"Error getting policies from DB: {str(e)}") + + async def sync_policies_from_db( + self, + prisma_client: "PrismaClient", + ) -> None: + """ + Sync policies from the database to in-memory registry. + + Args: + prisma_client: The Prisma client instance + """ + try: + policies = await self.get_all_policies_from_db(prisma_client) + + for policy_response in policies: + policy = self._parse_policy( + policy_response.policy_name, + { + "inherit": policy_response.inherit, + "description": policy_response.description, + "guardrails": { + "add": policy_response.guardrails_add, + "remove": policy_response.guardrails_remove, + }, + "condition": policy_response.condition, + }, + ) + self.add_policy(policy_response.policy_name, policy) + + verbose_proxy_logger.info( + f"Synced {len(policies)} policies from DB to in-memory registry" + ) + except Exception as e: + verbose_proxy_logger.exception(f"Error syncing policies from DB: {e}") + raise Exception(f"Error syncing policies from DB: {str(e)}") + + async def resolve_guardrails_from_db( + self, + policy_name: str, + prisma_client: "PrismaClient", + ) -> List[str]: + """ + Resolve all guardrails for a policy from the database. + + Uses the existing PolicyResolver to handle inheritance chain resolution. + + Args: + policy_name: Name of the policy to resolve + prisma_client: The Prisma client instance + + Returns: + List of resolved guardrail names + """ + from litellm.proxy.policy_engine.policy_resolver import PolicyResolver + + try: + # Load all policies from DB to ensure we have the full inheritance chain + policies = await self.get_all_policies_from_db(prisma_client) + + # Build a temporary in-memory map for resolution + temp_policies = {} + for policy_response in policies: + policy = self._parse_policy( + policy_response.policy_name, + { + "inherit": policy_response.inherit, + "description": policy_response.description, + "guardrails": { + "add": policy_response.guardrails_add, + "remove": policy_response.guardrails_remove, + }, + "condition": policy_response.condition, + }, + ) + temp_policies[policy_response.policy_name] = policy + + # Use the existing PolicyResolver to resolve guardrails + resolved_policy = PolicyResolver.resolve_policy_guardrails( + policy_name=policy_name, + policies=temp_policies, + context=None, # No context needed for simple resolution + ) + + return sorted(resolved_policy.guardrails) + except Exception as e: + verbose_proxy_logger.exception(f"Error resolving guardrails from DB: {e}") + raise Exception(f"Error resolving guardrails from DB: {str(e)}") + # Global singleton instance _policy_registry: Optional[PolicyRegistry] = None diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 994f6ee862c..64dfc6a5d8f 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -381,6 +381,7 @@ def generate_feedback_box(): from litellm.proxy.pass_through_endpoints.pass_through_endpoints import ( router as pass_through_router, ) +from litellm.proxy.policy_engine.policy_endpoints import router as policy_crud_router from litellm.proxy.prompts.prompt_endpoints import router as prompts_router from litellm.proxy.public_endpoints import router as public_endpoints_router from litellm.proxy.rag_endpoints.endpoints import router as rag_router @@ -3801,6 +3802,9 @@ async def _init_non_llm_objects_in_db(self, prisma_client: PrismaClient): if self._should_load_db_object(object_type="guardrails"): await self._init_guardrails_in_db(prisma_client=prisma_client) + if self._should_load_db_object(object_type="policies"): + await self._init_policies_in_db(prisma_client=prisma_client) + if self._should_load_db_object(object_type="vector_stores"): await self._init_vector_stores_in_db(prisma_client=prisma_client) @@ -4024,6 +4028,36 @@ async def _init_guardrails_in_db(self, prisma_client: PrismaClient): ) ) + async def _init_policies_in_db(self, prisma_client: PrismaClient): + """ + Initialize policies and policy attachments from database into the in-memory registries. + """ + from litellm.proxy.policy_engine.attachment_registry import ( + get_attachment_registry, + ) + from litellm.proxy.policy_engine.policy_registry import get_policy_registry + + try: + # Get the global singleton instances + policy_registry = get_policy_registry() + attachment_registry = get_attachment_registry() + + # Sync policies from DB to in-memory registry + await policy_registry.sync_policies_from_db(prisma_client=prisma_client) + + # Sync attachments from DB to in-memory registry + await attachment_registry.sync_attachments_from_db(prisma_client=prisma_client) + + verbose_proxy_logger.debug( + "Successfully synced policies and attachments from DB" + ) + except Exception as e: + verbose_proxy_logger.exception( + "litellm.proxy.proxy_server.py::ProxyConfig:_init_policies_in_db - {}".format( + str(e) + ) + ) + async def _init_vector_stores_in_db(self, prisma_client: PrismaClient): from litellm.vector_stores.vector_store_registry import VectorStoreRegistry @@ -10702,6 +10736,7 @@ async def get_routes(): app.include_router(analytics_router) app.include_router(guardrails_router) app.include_router(policy_router) +app.include_router(policy_crud_router) app.include_router(search_tool_management_router) app.include_router(prompts_router) app.include_router(callback_management_endpoints_router) diff --git a/litellm/proxy/schema.prisma b/litellm/proxy/schema.prisma index 22888f6d3af..52170f2f3e6 100644 --- a/litellm/proxy/schema.prisma +++ b/litellm/proxy/schema.prisma @@ -124,7 +124,7 @@ model LiteLLM_TeamTable { updated_at DateTime @default(now()) @updatedAt @map("updated_at") model_spend Json @default("{}") model_max_budget Json @default("{}") - router_settings Json? @default("{}") + router_settings Json? @default("{}") team_member_permissions String[] @default([]) model_id Int? @unique // id for LiteLLM_ModelTable -> stores team-level model aliases litellm_organization_table LiteLLM_OrganizationTable? @relation(fields: [organization_id], references: [organization_id]) @@ -863,20 +863,3 @@ model LiteLLM_SkillsTable { updated_at DateTime @default(now()) @updatedAt updated_by String? } - -// Claude Code Marketplace - stores plugins for Claude Code integration -model LiteLLM_ClaudeCodePluginTable { - id String @id @default(uuid()) - name String @unique // Plugin name (kebab-case) - version String? // Semantic version - description String? // Plugin description - manifest_json String // Full plugin.json as JSON string - files_json String // All files as JSON: {"path": "content"} - enabled Boolean @default(true) - created_at DateTime @default(now()) - updated_at DateTime @default(now()) @updatedAt - created_by String? - - @@index([name]) - @@map("litellm_claudecodeplugin") -} diff --git a/litellm/secret_managers/hashicorp_secret_manager.py b/litellm/secret_managers/hashicorp_secret_manager.py index dac0397dd95..c59f2ef638a 100644 --- a/litellm/secret_managers/hashicorp_secret_manager.py +++ b/litellm/secret_managers/hashicorp_secret_manager.py @@ -563,7 +563,7 @@ async def async_rotate_secret( return create_response - except httpx.TimeoutException as e: + except httpx.TimeoutException: verbose_logger.exception("Timeout error occurred during secret rotation") return {"status": "error", "message": "Timeout error occurred"} except Exception as e: diff --git a/litellm/types/proxy/policy_engine/__init__.py b/litellm/types/proxy/policy_engine/__init__.py index 50ed4581013..bc54c3eb36b 100644 --- a/litellm/types/proxy/policy_engine/__init__.py +++ b/litellm/types/proxy/policy_engine/__init__.py @@ -19,13 +19,21 @@ PolicyScope, ) from litellm.types.proxy.policy_engine.resolver_types import ( + PolicyAttachmentCreateRequest, + PolicyAttachmentDBResponse, + PolicyAttachmentListResponse, + PolicyConditionRequest, + PolicyCreateRequest, + PolicyDBResponse, PolicyGuardrailsResponse, PolicyInfoResponse, + PolicyListDBResponse, PolicyListResponse, PolicyMatchContext, PolicyScopeResponse, PolicySummaryItem, PolicyTestResponse, + PolicyUpdateRequest, ResolvedPolicy, ) from litellm.types.proxy.policy_engine.validation_types import ( @@ -58,4 +66,13 @@ "PolicyScopeResponse", "PolicySummaryItem", "PolicyTestResponse", + # CRUD Request/Response types + "PolicyConditionRequest", + "PolicyCreateRequest", + "PolicyUpdateRequest", + "PolicyDBResponse", + "PolicyListDBResponse", + "PolicyAttachmentCreateRequest", + "PolicyAttachmentDBResponse", + "PolicyAttachmentListResponse", ] diff --git a/litellm/types/proxy/policy_engine/resolver_types.py b/litellm/types/proxy/policy_engine/resolver_types.py index 81ae248d436..9488b8b0841 100644 --- a/litellm/types/proxy/policy_engine/resolver_types.py +++ b/litellm/types/proxy/policy_engine/resolver_types.py @@ -5,7 +5,8 @@ the final guardrails list. """ -from typing import Dict, List, Optional +from datetime import datetime +from typing import Any, Dict, List, Optional from pydantic import BaseModel, ConfigDict, Field @@ -108,3 +109,168 @@ class PolicyTestResponse(BaseModel): matching_policies: List[str] resolved_guardrails: List[str] message: Optional[str] = None + + +# ───────────────────────────────────────────────────────────────────────────── +# CRUD Request/Response Types for Policy Endpoints +# ───────────────────────────────────────────────────────────────────────────── + + +class PolicyConditionRequest(BaseModel): + """Condition for when a policy applies.""" + + model: Optional[str] = Field( + default=None, + description="Model name pattern (exact match or regex) for when policy applies.", + ) + + +class PolicyCreateRequest(BaseModel): + """Request body for creating a new policy.""" + + policy_name: str = Field(description="Unique name for the policy.") + inherit: Optional[str] = Field( + default=None, + description="Name of parent policy to inherit from.", + ) + description: Optional[str] = Field( + default=None, + description="Human-readable description of the policy.", + ) + guardrails_add: Optional[List[str]] = Field( + default=None, + description="List of guardrail names to add.", + ) + guardrails_remove: Optional[List[str]] = Field( + default=None, + description="List of guardrail names to remove (from inherited).", + ) + condition: Optional[PolicyConditionRequest] = Field( + default=None, + description="Condition for when this policy applies.", + ) + + +class PolicyUpdateRequest(BaseModel): + """Request body for updating a policy.""" + + policy_name: Optional[str] = Field( + default=None, + description="New name for the policy.", + ) + inherit: Optional[str] = Field( + default=None, + description="Name of parent policy to inherit from.", + ) + description: Optional[str] = Field( + default=None, + description="Human-readable description of the policy.", + ) + guardrails_add: Optional[List[str]] = Field( + default=None, + description="List of guardrail names to add.", + ) + guardrails_remove: Optional[List[str]] = Field( + default=None, + description="List of guardrail names to remove (from inherited).", + ) + condition: Optional[PolicyConditionRequest] = Field( + default=None, + description="Condition for when this policy applies.", + ) + + +class PolicyDBResponse(BaseModel): + """Response for a policy from the database.""" + + policy_id: str = Field(description="Unique ID of the policy.") + policy_name: str = Field(description="Name of the policy.") + inherit: Optional[str] = Field(default=None, description="Parent policy name.") + description: Optional[str] = Field(default=None, description="Policy description.") + guardrails_add: List[str] = Field( + default_factory=list, description="Guardrails to add." + ) + guardrails_remove: List[str] = Field( + default_factory=list, description="Guardrails to remove." + ) + condition: Optional[Dict[str, Any]] = Field( + default=None, description="Policy condition." + ) + created_at: Optional[datetime] = Field( + default=None, description="When the policy was created." + ) + updated_at: Optional[datetime] = Field( + default=None, description="When the policy was last updated." + ) + created_by: Optional[str] = Field(default=None, description="Who created the policy.") + updated_by: Optional[str] = Field( + default=None, description="Who last updated the policy." + ) + + +class PolicyListDBResponse(BaseModel): + """Response for listing policies from the database.""" + + policies: List[PolicyDBResponse] = Field( + default_factory=list, description="List of policies." + ) + total_count: int = Field(default=0, description="Total number of policies.") + + +# ───────────────────────────────────────────────────────────────────────────── +# Policy Attachment CRUD Types +# ───────────────────────────────────────────────────────────────────────────── + + +class PolicyAttachmentCreateRequest(BaseModel): + """Request body for creating a policy attachment.""" + + policy_name: str = Field(description="Name of the policy to attach.") + scope: Optional[str] = Field( + default=None, + description="Use '*' for global scope (applies to all requests).", + ) + teams: Optional[List[str]] = Field( + default=None, + description="Team aliases or patterns this attachment applies to.", + ) + keys: Optional[List[str]] = Field( + default=None, + description="Key aliases or patterns this attachment applies to.", + ) + models: Optional[List[str]] = Field( + default=None, + description="Model names or patterns this attachment applies to.", + ) + + +class PolicyAttachmentDBResponse(BaseModel): + """Response for a policy attachment from the database.""" + + attachment_id: str = Field(description="Unique ID of the attachment.") + policy_name: str = Field(description="Name of the attached policy.") + scope: Optional[str] = Field(default=None, description="Scope of the attachment.") + teams: List[str] = Field(default_factory=list, description="Team patterns.") + keys: List[str] = Field(default_factory=list, description="Key patterns.") + models: List[str] = Field(default_factory=list, description="Model patterns.") + created_at: Optional[datetime] = Field( + default=None, description="When the attachment was created." + ) + updated_at: Optional[datetime] = Field( + default=None, description="When the attachment was last updated." + ) + created_by: Optional[str] = Field( + default=None, description="Who created the attachment." + ) + updated_by: Optional[str] = Field( + default=None, description="Who last updated the attachment." + ) + + +class PolicyAttachmentListResponse(BaseModel): + """Response for listing policy attachments.""" + + attachments: List[PolicyAttachmentDBResponse] = Field( + default_factory=list, description="List of policy attachments." + ) + total_count: int = Field(default=0, description="Total number of attachments.") diff --git a/ui/litellm-dashboard/src/app/(dashboard)/components/Sidebar2.tsx b/ui/litellm-dashboard/src/app/(dashboard)/components/Sidebar2.tsx index 260cac16e02..405f8329b67 100644 --- a/ui/litellm-dashboard/src/app/(dashboard)/components/Sidebar2.tsx +++ b/ui/litellm-dashboard/src/app/(dashboard)/components/Sidebar2.tsx @@ -19,6 +19,7 @@ import { ExperimentOutlined, ToolOutlined, TagsOutlined, + AuditOutlined, } from "@ant-design/icons"; // import { // all_admin_roles, @@ -102,6 +103,8 @@ const routeFor = (slug: string): string => { return "logs"; case "guardrails": return "guardrails"; + case "policies": + return "policies"; // tools case "mcp-servers": @@ -202,6 +205,13 @@ const menuItems: MenuItemCfg[] = [ icon: , roles: all_admin_roles, }, + { + key: "28", + page: "policies", + label: "Policies", + icon: , + roles: all_admin_roles, + }, { key: "26", page: "tools", diff --git a/ui/litellm-dashboard/src/app/(dashboard)/policies/page.tsx b/ui/litellm-dashboard/src/app/(dashboard)/policies/page.tsx new file mode 100644 index 00000000000..86e063c7358 --- /dev/null +++ b/ui/litellm-dashboard/src/app/(dashboard)/policies/page.tsx @@ -0,0 +1,37 @@ +"use client"; + +import PoliciesPanel from "@/components/policies"; +import useAuthorized from "@/app/(dashboard)/hooks/useAuthorized"; +import { + getPoliciesList, + createPolicyCall, + updatePolicyCall, + deletePolicyCall, + getPolicyInfo, + getPolicyAttachmentsList, + createPolicyAttachmentCall, + deletePolicyAttachmentCall, + getGuardrailsList, +} from "@/components/networking"; + +const PoliciesPage = () => { + const { accessToken, userRole } = useAuthorized(); + + return ( + + ); +}; + +export default PoliciesPage; diff --git a/ui/litellm-dashboard/src/app/page.tsx b/ui/litellm-dashboard/src/app/page.tsx index 8ca25eb9e4c..8ac1f756f96 100644 --- a/ui/litellm-dashboard/src/app/page.tsx +++ b/ui/litellm-dashboard/src/app/page.tsx @@ -14,6 +14,7 @@ import LoadingScreen from "@/components/common_components/LoadingScreen"; import { CostTrackingSettings } from "@/components/CostTrackingSettings"; import GeneralSettings from "@/components/general_settings"; import GuardrailsPanel from "@/components/guardrails"; +import PoliciesPanel from "@/components/policies"; import { Team } from "@/components/key_team_helpers/key_list"; import { MCPServers } from "@/components/mcp_tools"; import ModelHubTable from "@/components/AIHub/ModelHubTable"; @@ -472,6 +473,8 @@ export default function CreateKeyPage() { ) : page == "guardrails" ? ( + ) : page == "policies" ? ( + ) : page == "agents" ? ( ) : page == "prompts" ? ( diff --git a/ui/litellm-dashboard/src/components/leftnav.tsx b/ui/litellm-dashboard/src/components/leftnav.tsx index 1aa1df14a10..543d95d2ccd 100644 --- a/ui/litellm-dashboard/src/components/leftnav.tsx +++ b/ui/litellm-dashboard/src/components/leftnav.tsx @@ -3,6 +3,7 @@ import useAuthorized from "@/app/(dashboard)/hooks/useAuthorized"; import { ApiOutlined, AppstoreOutlined, + AuditOutlined, BankOutlined, BarChartOutlined, BgColorsOutlined, @@ -123,6 +124,13 @@ const Sidebar: React.FC = ({ setPage, defaultSelectedKey, collapse icon: , roles: all_admin_roles, }, + { + key: "policies", + page: "policies", + label: "Policies", + icon: , + roles: all_admin_roles, + }, { key: "tools", page: "tools", diff --git a/ui/litellm-dashboard/src/components/networking.tsx b/ui/litellm-dashboard/src/components/networking.tsx index a4f2dfce3bd..ffc0f80911d 100644 --- a/ui/litellm-dashboard/src/components/networking.tsx +++ b/ui/litellm-dashboard/src/components/networking.tsx @@ -5347,6 +5347,249 @@ export const getGuardrailsList = async (accessToken: string) => { } }; +// ───────────────────────────────────────────────────────────────────────────── +// Policy CRUD API Calls +// ───────────────────────────────────────────────────────────────────────────── + +export const getPoliciesList = async (accessToken: string) => { + try { + const url = proxyBaseUrl ? `${proxyBaseUrl}/policies/list` : `/policies/list`; + const response = await fetch(url, { + method: "GET", + headers: { + [globalLitellmHeaderName]: `Bearer ${accessToken}`, + "Content-Type": "application/json", + }, + }); + + if (!response.ok) { + const errorData = await response.json(); + const errorMessage = deriveErrorMessage(errorData); + handleError(errorMessage); + throw new Error(errorMessage); + } + + const data = await response.json(); + return data; + } catch (error) { + console.error("Failed to get policies list:", error); + throw error; + } +}; + +export const createPolicyCall = async (accessToken: string, policyData: any) => { + try { + const url = proxyBaseUrl ? `${proxyBaseUrl}/policies` : `/policies`; + const response = await fetch(url, { + method: "POST", + headers: { + [globalLitellmHeaderName]: `Bearer ${accessToken}`, + "Content-Type": "application/json", + }, + body: JSON.stringify(policyData), + }); + + if (!response.ok) { + const errorData = await response.json(); + const errorMessage = deriveErrorMessage(errorData); + handleError(errorMessage); + throw new Error(errorMessage); + } + + const data = await response.json(); + return data; + } catch (error) { + console.error("Failed to create policy:", error); + throw error; + } +}; + +export const updatePolicyCall = async (accessToken: string, policyId: string, policyData: any) => { + try { + const url = proxyBaseUrl ? `${proxyBaseUrl}/policies/${policyId}` : `/policies/${policyId}`; + const response = await fetch(url, { + method: "PUT", + headers: { + [globalLitellmHeaderName]: `Bearer ${accessToken}`, + "Content-Type": "application/json", + }, + body: JSON.stringify(policyData), + }); + + if (!response.ok) { + const errorData = await response.json(); + const errorMessage = deriveErrorMessage(errorData); + handleError(errorMessage); + throw new Error(errorMessage); + } + + const data = await response.json(); + return data; + } catch (error) { + console.error("Failed to update policy:", error); + throw error; + } +}; + +export const deletePolicyCall = async (accessToken: string, policyId: string) => { + try { + const url = proxyBaseUrl ? `${proxyBaseUrl}/policies/${policyId}` : `/policies/${policyId}`; + const response = await fetch(url, { + method: "DELETE", + headers: { + [globalLitellmHeaderName]: `Bearer ${accessToken}`, + "Content-Type": "application/json", + }, + }); + + if (!response.ok) { + const errorData = await response.json(); + const errorMessage = deriveErrorMessage(errorData); + handleError(errorMessage); + throw new Error(errorMessage); + } + + const data = await response.json(); + return data; + } catch (error) { + console.error("Failed to delete policy:", error); + throw error; + } +}; + +export const getPolicyInfo = async (accessToken: string, policyId: string) => { + try { + const url = proxyBaseUrl ? `${proxyBaseUrl}/policies/${policyId}` : `/policies/${policyId}`; + const response = await fetch(url, { + method: "GET", + headers: { + [globalLitellmHeaderName]: `Bearer ${accessToken}`, + "Content-Type": "application/json", + }, + }); + + if (!response.ok) { + const errorData = await response.json(); + const errorMessage = deriveErrorMessage(errorData); + handleError(errorMessage); + throw new Error(errorMessage); + } + + const data = await response.json(); + return data; + } catch (error) { + console.error("Failed to get policy info:", error); + throw error; + } +}; + +// Policy Attachments API Calls + +export const getPolicyAttachmentsList = async (accessToken: string) => { + try { + const url = proxyBaseUrl ? `${proxyBaseUrl}/policies/attachments/list` : `/policies/attachments/list`; + const response = await fetch(url, { + method: "GET", + headers: { + [globalLitellmHeaderName]: `Bearer ${accessToken}`, + "Content-Type": "application/json", + }, + }); + + if (!response.ok) { + const errorData = await response.json(); + const errorMessage = deriveErrorMessage(errorData); + handleError(errorMessage); + throw new Error(errorMessage); + } + + const data = await response.json(); + return data; + } catch (error) { + console.error("Failed to get policy attachments list:", error); + throw error; + } +}; + +export const createPolicyAttachmentCall = async (accessToken: string, attachmentData: any) => { + try { + const url = proxyBaseUrl ? `${proxyBaseUrl}/policies/attachments` : `/policies/attachments`; + const response = await fetch(url, { + method: "POST", + headers: { + [globalLitellmHeaderName]: `Bearer ${accessToken}`, + "Content-Type": "application/json", + }, + body: JSON.stringify(attachmentData), + }); + + if (!response.ok) { + const errorData = await response.json(); + const errorMessage = deriveErrorMessage(errorData); + handleError(errorMessage); + throw new Error(errorMessage); + } + + const data = await response.json(); + return data; + } catch (error) { + console.error("Failed to create policy attachment:", error); + throw error; + } +}; + +export const deletePolicyAttachmentCall = async (accessToken: string, attachmentId: string) => { + try { + const url = proxyBaseUrl ? `${proxyBaseUrl}/policies/attachments/${attachmentId}` : `/policies/attachments/${attachmentId}`; + const response = await fetch(url, { + method: "DELETE", + headers: { + [globalLitellmHeaderName]: `Bearer ${accessToken}`, + "Content-Type": "application/json", + }, + }); + + if (!response.ok) { + const errorData = await response.json(); + const errorMessage = deriveErrorMessage(errorData); + handleError(errorMessage); + throw new Error(errorMessage); + } + + const data = await response.json(); + return data; + } catch (error) { + console.error("Failed to delete policy attachment:", error); + throw error; + } +}; + +export const getResolvedGuardrails = async (accessToken: string, policyId: string) => { + try { + const url = proxyBaseUrl ? `${proxyBaseUrl}/policies/${policyId}/resolved-guardrails` : `/policies/${policyId}/resolved-guardrails`; + const response = await fetch(url, { + method: "GET", + headers: { + [globalLitellmHeaderName]: `Bearer ${accessToken}`, + "Content-Type": "application/json", + }, + }); + + if (!response.ok) { + const errorData = await response.json(); + const errorMessage = deriveErrorMessage(errorData); + handleError(errorMessage); + throw new Error(errorMessage); + } + + const data = await response.json(); + return data; + } catch (error) { + console.error("Failed to get resolved guardrails:", error); + throw error; + } +}; + export const getPromptsList = async (accessToken: string): Promise => { try { const url = proxyBaseUrl ? `${proxyBaseUrl}/prompts/list` : `/prompts/list`; diff --git a/ui/litellm-dashboard/src/components/playground/chat_ui/ChatUI.tsx b/ui/litellm-dashboard/src/components/playground/chat_ui/ChatUI.tsx index 3a7f8fec650..ff93194c3f1 100644 --- a/ui/litellm-dashboard/src/components/playground/chat_ui/ChatUI.tsx +++ b/ui/litellm-dashboard/src/components/playground/chat_ui/ChatUI.tsx @@ -30,6 +30,7 @@ import { coy } from "react-syntax-highlighter/dist/esm/styles/prism"; import { v4 as uuidv4 } from "uuid"; import { truncateString } from "../../../utils/textUtils"; import GuardrailSelector from "../../guardrails/GuardrailSelector"; +import PolicySelector from "../../policies/PolicySelector"; import { MCPServer } from "../../mcp_tools/types"; import NotificationsManager from "../../molecules/notifications_manager"; import { fetchMCPServers, listMCPTools } from "../../networking"; @@ -188,6 +189,15 @@ const ChatUI: React.FC = ({ return []; } }); + const [selectedPolicies, setSelectedPolicies] = useState(() => { + const saved = sessionStorage.getItem("selectedPolicies"); + try { + return saved ? JSON.parse(saved) : []; + } catch (error) { + console.error("Error parsing selectedPolicies from sessionStorage", error); + return []; + } + }); const [messageTraceId, setMessageTraceId] = useState( () => sessionStorage.getItem("messageTraceId") || null, ); @@ -261,6 +271,7 @@ const ChatUI: React.FC = ({ selectedTags, selectedVectorStores, selectedGuardrails, + selectedPolicies, selectedMCPServers, mcpServers, mcpServerToolRestrictions, @@ -283,6 +294,7 @@ const ChatUI: React.FC = ({ selectedTags, selectedVectorStores, selectedGuardrails, + selectedPolicies, selectedMCPServers, mcpServers, mcpServerToolRestrictions, @@ -308,6 +320,7 @@ const ChatUI: React.FC = ({ sessionStorage.setItem("selectedTags", JSON.stringify(selectedTags)); sessionStorage.setItem("selectedVectorStores", JSON.stringify(selectedVectorStores)); sessionStorage.setItem("selectedGuardrails", JSON.stringify(selectedGuardrails)); + sessionStorage.setItem("selectedPolicies", JSON.stringify(selectedPolicies)); sessionStorage.setItem("selectedMCPServers", JSON.stringify(selectedMCPServers)); sessionStorage.setItem("mcpServerToolRestrictions", JSON.stringify(mcpServerToolRestrictions)); sessionStorage.setItem("selectedVoice", selectedVoice); @@ -338,6 +351,7 @@ const ChatUI: React.FC = ({ selectedTags, selectedVectorStores, selectedGuardrails, + selectedPolicies, messageTraceId, responsesSessionId, useApiSessionManagement, @@ -897,6 +911,7 @@ const ChatUI: React.FC = ({ traceId, selectedVectorStores.length > 0 ? selectedVectorStores : undefined, selectedGuardrails.length > 0 ? selectedGuardrails : undefined, + selectedPolicies.length > 0 ? selectedPolicies : undefined, selectedMCPServers, updateChatImageUI, updateSearchResults, @@ -977,6 +992,7 @@ const ChatUI: React.FC = ({ traceId, selectedVectorStores.length > 0 ? selectedVectorStores : undefined, selectedGuardrails.length > 0 ? selectedGuardrails : undefined, + selectedPolicies.length > 0 ? selectedPolicies : undefined, selectedMCPServers, // Pass the selected servers array useApiSessionManagement ? responsesSessionId : null, // Only pass session ID if API mode is enabled handleResponseId, // Pass callback to capture new response ID @@ -1008,6 +1024,7 @@ const ChatUI: React.FC = ({ traceId, selectedVectorStores.length > 0 ? selectedVectorStores : undefined, selectedGuardrails.length > 0 ? selectedGuardrails : undefined, + selectedPolicies.length > 0 ? selectedPolicies : undefined, selectedMCPServers, // Pass the selected tools array customProxyBaseUrl || undefined, ); @@ -1587,6 +1604,32 @@ const ChatUI: React.FC = ({ /> +
+ + Policies + + Select policy/policies to apply to this LLM API call. Policies define which guardrails are applied based on conditions. You can set up your policies{" "} + + here + + . + + } + > + + + + +
+ {/* Code Interpreter Toggle - Only for Responses endpoint */} {endpointType === EndpointType.RESPONSES && (
diff --git a/ui/litellm-dashboard/src/components/playground/chat_ui/CodeSnippets.tsx b/ui/litellm-dashboard/src/components/playground/chat_ui/CodeSnippets.tsx index 84f7f6b9e1c..6998d542401 100644 --- a/ui/litellm-dashboard/src/components/playground/chat_ui/CodeSnippets.tsx +++ b/ui/litellm-dashboard/src/components/playground/chat_ui/CodeSnippets.tsx @@ -6,6 +6,7 @@ interface CodeGenMetadata { tags?: string[]; vector_stores?: string[]; guardrails?: string[]; + policies?: string[]; } interface GenerateCodeParams { @@ -17,6 +18,7 @@ interface GenerateCodeParams { selectedTags: string[]; selectedVectorStores: string[]; selectedGuardrails: string[]; + selectedPolicies: string[]; selectedMCPServers: string[]; mcpServers?: MCPServer[]; mcpServerToolRestrictions?: Record; @@ -40,6 +42,7 @@ export const generateCodeSnippet = (params: GenerateCodeParams): string => { selectedTags, selectedVectorStores, selectedGuardrails, + selectedPolicies, selectedMCPServers, mcpServers, mcpServerToolRestrictions, @@ -72,6 +75,7 @@ export const generateCodeSnippet = (params: GenerateCodeParams): string => { if (selectedTags.length > 0) metadata.tags = selectedTags; if (selectedVectorStores.length > 0) metadata.vector_stores = selectedVectorStores; if (selectedGuardrails.length > 0) metadata.guardrails = selectedGuardrails; + if (selectedPolicies.length > 0) metadata.policies = selectedPolicies; const modelNameForCode = selectedModel || "your-model-name"; diff --git a/ui/litellm-dashboard/src/components/playground/llm_calls/anthropic_messages.tsx b/ui/litellm-dashboard/src/components/playground/llm_calls/anthropic_messages.tsx index 2e8b0be88bb..5570c7408fa 100644 --- a/ui/litellm-dashboard/src/components/playground/llm_calls/anthropic_messages.tsx +++ b/ui/litellm-dashboard/src/components/playground/llm_calls/anthropic_messages.tsx @@ -17,6 +17,7 @@ export async function makeAnthropicMessagesRequest( traceId?: string, vector_store_ids?: string[], guardrails?: string[], + policies?: string[], selectedMCPTools?: string[], customBaseUrl?: string, ) { @@ -59,6 +60,7 @@ export async function makeAnthropicMessagesRequest( if (vector_store_ids) requestBody.vector_store_ids = vector_store_ids; if (guardrails) requestBody.guardrails = guardrails; + if (policies) requestBody.policies = policies; // Use the streaming helper method for cleaner async iteration // @ts-ignore - The SDK types might not include all litellm-specific parameters const stream = client.messages.stream(requestBody, { signal }); diff --git a/ui/litellm-dashboard/src/components/playground/llm_calls/chat_completion.tsx b/ui/litellm-dashboard/src/components/playground/llm_calls/chat_completion.tsx index c3c623c25c3..61d232082e0 100644 --- a/ui/litellm-dashboard/src/components/playground/llm_calls/chat_completion.tsx +++ b/ui/litellm-dashboard/src/components/playground/llm_calls/chat_completion.tsx @@ -19,6 +19,7 @@ export async function makeOpenAIChatCompletionRequest( traceId?: string, vector_store_ids?: string[], guardrails?: string[], + policies?: string[], selectedMCPServers?: string[], onImageGenerated?: (imageUrl: string, model?: string) => void, onSearchResults?: (searchResults: VectorStoreSearchResponse[]) => void, @@ -110,6 +111,7 @@ export async function makeOpenAIChatCompletionRequest( messages: chatHistory as ChatCompletionMessageParam[], ...(vector_store_ids ? { vector_store_ids } : {}), ...(guardrails ? { guardrails } : {}), + ...(policies ? { policies } : {}), ...(tools.length > 0 ? { tools, tool_choice: "auto" } : {}), ...(temperature !== undefined ? { temperature } : {}), ...(max_tokens !== undefined ? { max_tokens } : {}), diff --git a/ui/litellm-dashboard/src/components/playground/llm_calls/responses_api.tsx b/ui/litellm-dashboard/src/components/playground/llm_calls/responses_api.tsx index b658610a21e..c69f82a37bf 100644 --- a/ui/litellm-dashboard/src/components/playground/llm_calls/responses_api.tsx +++ b/ui/litellm-dashboard/src/components/playground/llm_calls/responses_api.tsx @@ -27,6 +27,7 @@ export async function makeOpenAIResponsesRequest( traceId?: string, vector_store_ids?: string[], guardrails?: string[], + policies?: string[], selectedMCPServers?: string[], previousResponseId?: string | null, onResponseId?: (responseId: string) => void, @@ -137,6 +138,7 @@ export async function makeOpenAIResponsesRequest( ...(previousResponseId ? { previous_response_id: previousResponseId } : {}), ...(vector_store_ids ? { vector_store_ids } : {}), ...(guardrails ? { guardrails } : {}), + ...(policies ? { policies } : {}), ...(tools.length > 0 ? { tools, tool_choice: "auto" } : {}), }, { signal }, diff --git a/ui/litellm-dashboard/src/components/policies/PolicySelector.tsx b/ui/litellm-dashboard/src/components/policies/PolicySelector.tsx new file mode 100644 index 00000000000..fd20d9330a3 --- /dev/null +++ b/ui/litellm-dashboard/src/components/policies/PolicySelector.tsx @@ -0,0 +1,77 @@ +import React, { useEffect, useState } from "react"; +import { Select } from "antd"; +import { Policy } from "./types"; +import { getPoliciesList } from "../networking"; + +interface PolicySelectorProps { + onChange: (selectedPolicies: string[]) => void; + value?: string[]; + className?: string; + accessToken: string; + disabled?: boolean; +} + +const PolicySelector: React.FC = ({ + onChange, + value, + className, + accessToken, + disabled +}) => { + const [policies, setPolicies] = useState([]); + const [loading, setLoading] = useState(false); + + useEffect(() => { + const fetchPolicies = async () => { + if (!accessToken) return; + + setLoading(true); + try { + const response = await getPoliciesList(accessToken); + console.log("Policies response:", response); + if (response.policies) { + console.log("Policies data:", response.policies); + setPolicies(response.policies); + } + } catch (error) { + console.error("Error fetching policies:", error); + } finally { + setLoading(false); + } + }; + + fetchPolicies(); + }, [accessToken]); + + const handlePolicyChange = (selectedValues: string[]) => { + console.log("Selected policies:", selectedValues); + onChange(selectedValues); + }; + + return ( +
+ + (option?.label ?? "").toLowerCase().includes(input.toLowerCase()) + } + style={{ width: "100%" }} + /> + + + + Scope + + + + setScopeType(e.target.value)} + > + Global (applies to all requests) + Specific (teams, keys, or models) + + + + {scopeType === "specific" && ( + <> + + ({ + label: key, + value: key, + }))} + tokenSeparators={[","]} + showSearch + filterOption={(input, option) => + (option?.label ?? "").toLowerCase().includes(input.toLowerCase()) + } + style={{ width: "100%" }} + /> + + + +