diff --git a/litellm/proxy/policy_engine/pipeline_executor.py b/litellm/proxy/policy_engine/pipeline_executor.py index c015d755a9a..b3982678d37 100644 --- a/litellm/proxy/policy_engine/pipeline_executor.py +++ b/litellm/proxy/policy_engine/pipeline_executor.py @@ -56,7 +56,9 @@ async def execute_steps( PipelineExecutionResult with terminal action and step results """ step_results: List[PipelineStepResult] = [] - working_data = copy.deepcopy(data) + working_data = data.copy() + if "metadata" in working_data: + working_data["metadata"] = working_data["metadata"].copy() for i, step in enumerate(steps): start_time = time.perf_counter() @@ -148,6 +150,12 @@ async def _run_step( return ("error", None, f"Guardrail '{step.guardrail}' not found") try: + # Inject guardrail name into metadata so should_run_guardrail() allows it + if "metadata" not in data: + data["metadata"] = {} + original_guardrails = data["metadata"].get("guardrails") + data["metadata"]["guardrails"] = [step.guardrail] + # Use unified_guardrail path if callback implements apply_guardrail target = callback use_unified = "apply_guardrail" in type(callback).__dict__ diff --git a/litellm/proxy/policy_engine/policy_endpoints.py b/litellm/proxy/policy_engine/policy_endpoints.py index 3bd893b0034..af12a8598f6 100644 --- a/litellm/proxy/policy_engine/policy_endpoints.py +++ b/litellm/proxy/policy_engine/policy_endpoints.py @@ -10,8 +10,11 @@ from litellm.proxy._types import UserAPIKeyAuth from litellm.proxy.auth.user_api_key_auth import user_api_key_auth from litellm.proxy.policy_engine.attachment_registry import get_attachment_registry +from litellm.proxy.policy_engine.pipeline_executor import PipelineExecutor from litellm.proxy.policy_engine.policy_registry import get_policy_registry from litellm.types.proxy.policy_engine import ( + GuardrailPipeline, + PipelineTestRequest, PolicyAttachmentCreateRequest, PolicyAttachmentDBResponse, PolicyAttachmentListResponse, @@ -349,6 +352,69 @@ async def get_resolved_guardrails(policy_id: str): raise HTTPException(status_code=500, detail=str(e)) +# ───────────────────────────────────────────────────────────────────────────── +# Pipeline Test Endpoint +# ───────────────────────────────────────────────────────────────────────────── + + +@router.post( + "/policies/test-pipeline", + tags=["Policies"], + dependencies=[Depends(user_api_key_auth)], +) +async def test_pipeline( + request: PipelineTestRequest, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + """ + Test a guardrail pipeline with sample messages. + + Executes the pipeline steps against the provided test messages and returns + step-by-step results showing which guardrails passed/failed, actions taken, + and timing information. + + Example Request: + ```bash + curl -X POST "http://localhost:4000/policies/test-pipeline" \\ + -H "Authorization: Bearer " \\ + -H "Content-Type: application/json" \\ + -d '{ + "pipeline": { + "mode": "pre_call", + "steps": [ + {"guardrail": "pii-guard", "on_pass": "next", "on_fail": "block"} + ] + }, + "test_messages": [{"role": "user", "content": "My SSN is 123-45-6789"}] + }' + ``` + """ + try: + validated_pipeline = GuardrailPipeline(**request.pipeline) + except Exception as e: + raise HTTPException(status_code=400, detail=f"Invalid pipeline: {e}") + + data = { + "messages": request.test_messages, + "model": "test", + "metadata": {}, + } + + try: + result = await PipelineExecutor.execute_steps( + steps=validated_pipeline.steps, + mode=validated_pipeline.mode, + data=data, + user_api_key_dict=user_api_key_dict, + call_type="completion", + policy_name="test-pipeline", + ) + return result.model_dump() + except Exception as e: + verbose_proxy_logger.exception(f"Error testing pipeline: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + # ───────────────────────────────────────────────────────────────────────────── # Policy Attachment CRUD Endpoints # ───────────────────────────────────────────────────────────────────────────── diff --git a/litellm/proxy/policy_engine/policy_registry.py b/litellm/proxy/policy_engine/policy_registry.py index 377b4cd86dd..50acddd2b9d 100644 --- a/litellm/proxy/policy_engine/policy_registry.py +++ b/litellm/proxy/policy_engine/policy_registry.py @@ -10,6 +10,8 @@ from datetime import datetime, timezone from typing import TYPE_CHECKING, Any, Dict, List, Optional +from prisma import Json as PrismaJson + from litellm._logging import verbose_proxy_logger from litellm.types.proxy.policy_engine import ( GuardrailPipeline, @@ -248,7 +250,10 @@ async def add_policy_to_db( data["created_by"] = created_by data["updated_by"] = created_by if policy_request.condition is not None: - data["condition"] = policy_request.condition.model_dump() + data["condition"] = PrismaJson(policy_request.condition.model_dump()) + if policy_request.pipeline is not None: + validated_pipeline = GuardrailPipeline(**policy_request.pipeline) + data["pipeline"] = PrismaJson(validated_pipeline.model_dump()) created_policy = await prisma_client.db.litellm_policytable.create( data=data @@ -267,6 +272,7 @@ async def add_policy_to_db( "condition": policy_request.condition.model_dump() if policy_request.condition else None, + "pipeline": policy_request.pipeline, }, ) self.add_policy(policy_request.policy_name, policy) @@ -279,6 +285,7 @@ async def add_policy_to_db( guardrails_add=created_policy.guardrails_add or [], guardrails_remove=created_policy.guardrails_remove or [], condition=created_policy.condition, + pipeline=created_policy.pipeline, created_at=created_policy.created_at, updated_at=created_policy.updated_at, created_by=created_policy.created_by, @@ -325,7 +332,10 @@ async def update_policy_in_db( if policy_request.guardrails_remove is not None: update_data["guardrails_remove"] = policy_request.guardrails_remove if policy_request.condition is not None: - update_data["condition"] = policy_request.condition.model_dump() + update_data["condition"] = PrismaJson(policy_request.condition.model_dump()) + if policy_request.pipeline is not None: + validated_pipeline = GuardrailPipeline(**policy_request.pipeline) + update_data["pipeline"] = PrismaJson(validated_pipeline.model_dump()) updated_policy = await prisma_client.db.litellm_policytable.update( where={"policy_id": policy_id}, @@ -343,6 +353,7 @@ async def update_policy_in_db( "remove": updated_policy.guardrails_remove, }, "condition": updated_policy.condition, + "pipeline": updated_policy.pipeline, }, ) self.add_policy(updated_policy.policy_name, policy) @@ -355,6 +366,7 @@ async def update_policy_in_db( guardrails_add=updated_policy.guardrails_add or [], guardrails_remove=updated_policy.guardrails_remove or [], condition=updated_policy.condition, + pipeline=updated_policy.pipeline, created_at=updated_policy.created_at, updated_at=updated_policy.updated_at, created_by=updated_policy.created_by, @@ -432,6 +444,7 @@ async def get_policy_by_id_from_db( guardrails_add=policy.guardrails_add or [], guardrails_remove=policy.guardrails_remove or [], condition=policy.condition, + pipeline=policy.pipeline, created_at=policy.created_at, updated_at=policy.updated_at, created_by=policy.created_by, @@ -468,6 +481,7 @@ async def get_all_policies_from_db( guardrails_add=p.guardrails_add or [], guardrails_remove=p.guardrails_remove or [], condition=p.condition, + pipeline=p.pipeline, created_at=p.created_at, updated_at=p.updated_at, created_by=p.created_by, @@ -503,6 +517,7 @@ async def sync_policies_from_db( "remove": policy_response.guardrails_remove, }, "condition": policy_response.condition, + "pipeline": policy_response.pipeline, }, ) self.add_policy(policy_response.policy_name, policy) @@ -551,6 +566,7 @@ async def resolve_guardrails_from_db( "remove": policy_response.guardrails_remove, }, "condition": policy_response.condition, + "pipeline": policy_response.pipeline, }, ) temp_policies[policy_response.policy_name] = policy diff --git a/litellm/proxy/schema.prisma b/litellm/proxy/schema.prisma index c02c3824d68..dab0c8237a7 100644 --- a/litellm/proxy/schema.prisma +++ b/litellm/proxy/schema.prisma @@ -917,6 +917,7 @@ model LiteLLM_PolicyTable { guardrails_add String[] @default([]) guardrails_remove String[] @default([]) condition Json? @default("{}") // Policy conditions (e.g., model matching) + pipeline Json? // Optional guardrail pipeline (mode + steps[]) created_at DateTime @default(now()) created_by String? updated_at DateTime @default(now()) @updatedAt diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index a441b2ae7d1..66cf95f8e6b 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -1222,10 +1222,7 @@ def _handle_pipeline_result( }, } } - if HTTPException is not None: - raise HTTPException(status_code=400, detail=error_detail) - else: - raise Exception(str(error_detail)) + raise HTTPException(status_code=400, detail=error_detail) if result.terminal_action == "modify_response": raise ModifyResponseException( @@ -1236,9 +1233,6 @@ def _handle_pipeline_result( detection_info=None, ) - verbose_proxy_logger.warning( - f"Pipeline '{policy_name}': unrecognized terminal_action '{result.terminal_action}', defaulting to allow" - ) return data # The actual implementation of the function diff --git a/litellm/types/proxy/policy_engine/__init__.py b/litellm/types/proxy/policy_engine/__init__.py index 6f1a8d27d34..e0c1d6f30da 100644 --- a/litellm/types/proxy/policy_engine/__init__.py +++ b/litellm/types/proxy/policy_engine/__init__.py @@ -26,6 +26,7 @@ ) from litellm.types.proxy.policy_engine.resolver_types import ( AttachmentImpactResponse, + PipelineTestRequest, PolicyAttachmentCreateRequest, PolicyAttachmentDBResponse, PolicyAttachmentListResponse, @@ -90,6 +91,8 @@ "PolicyAttachmentCreateRequest", "PolicyAttachmentDBResponse", "PolicyAttachmentListResponse", + # Pipeline test types + "PipelineTestRequest", # Resolve types "PolicyResolveRequest", "PolicyResolveResponse", diff --git a/litellm/types/proxy/policy_engine/resolver_types.py b/litellm/types/proxy/policy_engine/resolver_types.py index 0c2c7336f8a..a5a2334ae4b 100644 --- a/litellm/types/proxy/policy_engine/resolver_types.py +++ b/litellm/types/proxy/policy_engine/resolver_types.py @@ -154,6 +154,10 @@ class PolicyCreateRequest(BaseModel): default=None, description="Condition for when this policy applies.", ) + pipeline: Optional[Dict[str, Any]] = Field( + default=None, + description="Optional guardrail pipeline for ordered execution. Contains 'mode' and 'steps'.", + ) class PolicyUpdateRequest(BaseModel): @@ -183,6 +187,10 @@ class PolicyUpdateRequest(BaseModel): default=None, description="Condition for when this policy applies.", ) + pipeline: Optional[Dict[str, Any]] = Field( + default=None, + description="Optional guardrail pipeline for ordered execution. Contains 'mode' and 'steps'.", + ) class PolicyDBResponse(BaseModel): @@ -201,6 +209,9 @@ class PolicyDBResponse(BaseModel): condition: Optional[Dict[str, Any]] = Field( default=None, description="Policy condition." ) + pipeline: Optional[Dict[str, Any]] = Field( + default=None, description="Optional guardrail pipeline." + ) created_at: Optional[datetime] = Field( default=None, description="When the policy was created." ) @@ -291,6 +302,17 @@ class PolicyAttachmentListResponse(BaseModel): # ───────────────────────────────────────────────────────────────────────────── +class PipelineTestRequest(BaseModel): + """Request body for testing a guardrail pipeline with sample messages.""" + + pipeline: Dict[str, Any] = Field( + description="Pipeline definition with 'mode' and 'steps'.", + ) + test_messages: List[Dict[str, str]] = Field( + description="Test messages to run through the pipeline, e.g. [{'role': 'user', 'content': '...'}].", + ) + + class PolicyResolveRequest(BaseModel): """Request body for resolving effective policies/guardrails for a context.""" diff --git a/tests/test_litellm/types/proxy/policy_engine/test_resolver_types.py b/tests/test_litellm/types/proxy/policy_engine/test_resolver_types.py new file mode 100644 index 00000000000..c23ed5d4319 --- /dev/null +++ b/tests/test_litellm/types/proxy/policy_engine/test_resolver_types.py @@ -0,0 +1,102 @@ +""" +Tests for pipeline field on policy CRUD types (resolver_types.py). +""" + +import pytest + +from litellm.types.proxy.policy_engine.resolver_types import ( + PolicyCreateRequest, + PolicyDBResponse, + PolicyUpdateRequest, +) + + +def test_policy_create_request_with_pipeline(): + pipeline_data = { + "mode": "pre_call", + "steps": [ + {"guardrail": "g1", "on_fail": "next", "on_pass": "allow"}, + {"guardrail": "g2", "on_fail": "block", "on_pass": "allow"}, + ], + } + req = PolicyCreateRequest( + policy_name="test-policy", + guardrails_add=["g1", "g2"], + pipeline=pipeline_data, + ) + assert req.pipeline is not None + assert req.pipeline["mode"] == "pre_call" + assert len(req.pipeline["steps"]) == 2 + + +def test_policy_create_request_without_pipeline(): + req = PolicyCreateRequest( + policy_name="test-policy", + guardrails_add=["g1"], + ) + assert req.pipeline is None + + +def test_policy_update_request_with_pipeline(): + pipeline_data = { + "mode": "pre_call", + "steps": [ + {"guardrail": "g1", "on_fail": "block", "on_pass": "allow"}, + ], + } + req = PolicyUpdateRequest(pipeline=pipeline_data) + assert req.pipeline is not None + assert req.pipeline["steps"][0]["guardrail"] == "g1" + + +def test_policy_db_response_with_pipeline(): + pipeline_data = { + "mode": "pre_call", + "steps": [ + {"guardrail": "g1", "on_fail": "next", "on_pass": "allow"}, + {"guardrail": "g2", "on_fail": "block", "on_pass": "allow"}, + ], + } + resp = PolicyDBResponse( + policy_id="test-id", + policy_name="test-policy", + guardrails_add=["g1", "g2"], + pipeline=pipeline_data, + ) + assert resp.pipeline is not None + assert resp.pipeline["mode"] == "pre_call" + dumped = resp.model_dump() + assert dumped["pipeline"]["steps"][0]["guardrail"] == "g1" + + +def test_policy_db_response_without_pipeline(): + resp = PolicyDBResponse( + policy_id="test-id", + policy_name="test-policy", + ) + assert resp.pipeline is None + dumped = resp.model_dump() + assert dumped["pipeline"] is None + + +def test_policy_create_request_roundtrip(): + pipeline_data = { + "mode": "post_call", + "steps": [ + { + "guardrail": "g1", + "on_fail": "modify_response", + "on_pass": "next", + "pass_data": True, + "modify_response_message": "custom msg", + }, + ], + } + req = PolicyCreateRequest( + policy_name="roundtrip-test", + guardrails_add=["g1"], + pipeline=pipeline_data, + ) + dumped = req.model_dump() + restored = PolicyCreateRequest(**dumped) + assert restored.pipeline == pipeline_data diff --git a/ui/litellm-dashboard/src/components/networking.tsx b/ui/litellm-dashboard/src/components/networking.tsx index 55b97f5252f..885a62fab54 100644 --- a/ui/litellm-dashboard/src/components/networking.tsx +++ b/ui/litellm-dashboard/src/components/networking.tsx @@ -5673,6 +5673,37 @@ export const deletePolicyAttachmentCall = async (accessToken: string, attachment } }; +export const testPipelineCall = async ( + accessToken: string, + pipeline: any, + testMessages: Array<{role: string, content: string}> +) => { + try { + const url = proxyBaseUrl ? `${proxyBaseUrl}/policies/test-pipeline` : `/policies/test-pipeline`; + const response = await fetch(url, { + method: "POST", + headers: { + [globalLitellmHeaderName]: `Bearer ${accessToken}`, + "Content-Type": "application/json", + }, + body: JSON.stringify({ pipeline, test_messages: testMessages }), + }); + + if (!response.ok) { + const errorData = await response.json(); + const errorMessage = deriveErrorMessage(errorData); + handleError(errorMessage); + throw new Error(errorMessage); + } + + const data = await response.json(); + return data; + } catch (error) { + console.error("Failed to test pipeline:", error); + throw error; + } +}; + export const getResolvedGuardrails = async (accessToken: string, policyId: string) => { try { const url = proxyBaseUrl diff --git a/ui/litellm-dashboard/src/components/policies/add_policy_form.tsx b/ui/litellm-dashboard/src/components/policies/add_policy_form.tsx index 99ea09db130..4ccc151bd13 100644 --- a/ui/litellm-dashboard/src/components/policies/add_policy_form.tsx +++ b/ui/litellm-dashboard/src/components/policies/add_policy_form.tsx @@ -14,6 +14,7 @@ interface AddPolicyFormProps { visible: boolean; onClose: () => void; onSuccess: () => void; + onOpenFlowBuilder: () => void; accessToken: string | null; editingPolicy?: Policy | null; existingPolicies: Policy[]; @@ -22,10 +23,117 @@ interface AddPolicyFormProps { updatePolicy: (accessToken: string, policyId: string, policyData: any) => Promise; } +// ───────────────────────────────────────────────────────────────────────────── +// Mode Picker (Step 1) - shown first when creating a new policy +// ───────────────────────────────────────────────────────────────────────────── + +interface ModePicker { + selected: "simple" | "flow_builder"; + onSelect: (mode: "simple" | "flow_builder") => void; +} + +const ModePicker: React.FC = ({ selected, onSelect }) => ( +
+ {/* Simple Mode Card */} +
onSelect("simple")} + style={{ + flex: 1, + padding: "24px 20px", + border: `2px solid ${selected === "simple" ? "#4f46e5" : "#e5e7eb"}`, + borderRadius: 12, + cursor: "pointer", + backgroundColor: selected === "simple" ? "#eef2ff" : "#fff", + transition: "all 0.15s ease", + }} + > +
+ + + + +
+ + Simple Mode + + + Pick guardrails from a list. All run in parallel. + +
+ + {/* Flow Builder Card */} +
onSelect("flow_builder")} + style={{ + flex: 1, + padding: "24px 20px", + border: `2px solid ${selected === "flow_builder" ? "#4f46e5" : "#e5e7eb"}`, + borderRadius: 12, + cursor: "pointer", + backgroundColor: selected === "flow_builder" ? "#eef2ff" : "#fff", + transition: "all 0.15s ease", + position: "relative", + }} + > + + NEW + +
+ + + +
+ + Flow Builder + + + Define steps, conditions, and error responses. + +
+
+); + +// ───────────────────────────────────────────────────────────────────────────── +// Main Component +// ───────────────────────────────────────────────────────────────────────────── + const AddPolicyForm: React.FC = ({ visible, onClose, onSuccess, + onOpenFlowBuilder, accessToken, editingPolicy, existingPolicies, @@ -39,16 +147,16 @@ const AddPolicyForm: React.FC = ({ const [isLoadingResolved, setIsLoadingResolved] = useState(false); const [modelConditionType, setModelConditionType] = useState<"model" | "regex">("model"); const [availableModels, setAvailableModels] = useState([]); + const [step, setStep] = useState<"pick_mode" | "simple_form">("pick_mode"); + const [selectedMode, setSelectedMode] = useState<"simple" | "flow_builder">("simple"); const { userId, userRole } = useAuthorized(); // Only consider it "editing" if editingPolicy has a policy_id (real existing policy) - // If editingPolicy is set but has no policy_id, it's just pre-filled data for a new policy (e.g., from a template) const isEditing = !!editingPolicy?.policy_id; useEffect(() => { if (visible && editingPolicy) { const modelCondition = editingPolicy.condition?.model; - // Detect if it's a regex pattern (contains *, ., [, ], etc.) const isRegex = modelCondition && /[.*+?^${}()|[\]\\]/.test(modelCondition); setModelConditionType(isRegex ? "regex" : "model"); @@ -60,14 +168,25 @@ const AddPolicyForm: React.FC = ({ guardrails_remove: editingPolicy.guardrails_remove || [], model_condition: modelCondition, }); - // Load resolved guardrails for editing + if (editingPolicy.policy_id && accessToken) { loadResolvedGuardrails(editingPolicy.policy_id); } + + // If editing a pipeline policy, go directly to flow builder + if (editingPolicy.pipeline) { + onClose(); + onOpenFlowBuilder(); + return; + } + // If editing a simple policy, skip mode picker + setStep("simple_form"); } else if (visible) { form.resetFields(); setResolvedGuardrails([]); setModelConditionType("model"); + setSelectedMode("simple"); + setStep("pick_mode"); } // eslint-disable-next-line react-hooks/exhaustive-deps }, [visible, editingPolicy, form]); @@ -81,7 +200,6 @@ const AddPolicyForm: React.FC = ({ const loadAvailableModels = async () => { if (!accessToken) return; - try { const response = await modelAvailableCall(accessToken, userId, userRole); if (response?.data) { @@ -95,7 +213,6 @@ const AddPolicyForm: React.FC = ({ const loadResolvedGuardrails = async (policyId: string) => { if (!accessToken) return; - setIsLoadingResolved(true); try { const data = await getResolvedGuardrails(accessToken, policyId); @@ -115,20 +232,15 @@ const AddPolicyForm: React.FC = ({ let resolved = new Set(); - // If inheriting, find parent policy and get its guardrails if (inheritFrom) { const parentPolicy = existingPolicies.find(p => p.policy_name === inheritFrom); if (parentPolicy) { - // Recursively resolve parent's guardrails const parentResolved = resolveParentGuardrails(parentPolicy); parentResolved.forEach(g => resolved.add(g)); } } - // Add guardrails guardrailsAdd.forEach((g: string) => resolved.add(g)); - - // Remove guardrails guardrailsRemove.forEach((g: string) => resolved.delete(g)); return Array.from(resolved).sort(); @@ -137,32 +249,23 @@ const AddPolicyForm: React.FC = ({ const resolveParentGuardrails = (policy: Policy): string[] => { let resolved = new Set(); - // If parent inherits, resolve recursively if (policy.inherit) { const grandparent = existingPolicies.find(p => p.policy_name === policy.inherit); if (grandparent) { - const grandparentResolved = resolveParentGuardrails(grandparent); - grandparentResolved.forEach(g => resolved.add(g)); + resolveParentGuardrails(grandparent).forEach(g => resolved.add(g)); } } - - // Add parent's guardrails if (policy.guardrails_add) { policy.guardrails_add.forEach(g => resolved.add(g)); } - - // Remove parent's removed guardrails if (policy.guardrails_remove) { policy.guardrails_remove.forEach(g => resolved.delete(g)); } - return Array.from(resolved); }; - // Recompute resolved guardrails when form values change const handleFormChange = () => { - const resolved = computeResolvedGuardrails(); - setResolvedGuardrails(resolved); + setResolvedGuardrails(computeResolvedGuardrails()); }; const resetForm = () => { @@ -171,9 +274,20 @@ const AddPolicyForm: React.FC = ({ const handleClose = () => { resetForm(); + setStep("pick_mode"); + setSelectedMode("simple"); onClose(); }; + const handleModeConfirm = () => { + if (selectedMode === "flow_builder") { + onClose(); + onOpenFlowBuilder(); + } else { + setStep("simple_form"); + } + }; + const handleSubmit = async () => { try { setIsSubmitting(true); @@ -228,6 +342,50 @@ const AddPolicyForm: React.FC = ({ value: p.policy_name, })); + // ── Mode Picker Step ────────────────────────────────────────────────────── + if (step === "pick_mode") { + return ( + + + + {selectedMode === "flow_builder" && ( + + )} + +
+ + +
+
+ ); + } + + // ── Simple Form Step ────────────────────────────────────────────────────── return ( = ({ const [selectedTemplate, setSelectedTemplate] = useState(null); const [existingGuardrailNames, setExistingGuardrailNames] = useState>(new Set()); const [isCreatingGuardrails, setIsCreatingGuardrails] = useState(false); + const [showFlowBuilder, setShowFlowBuilder] = useState(false); const isAdmin = userRole ? isAdminRole(userRole) : false; @@ -349,8 +351,12 @@ const PoliciesPanel: React.FC = ({ onClose={() => setSelectedPolicyId(null)} onEdit={(policy) => { setEditingPolicy(policy); - setIsAddPolicyModalVisible(true); setSelectedPolicyId(null); + if (policy.pipeline) { + setShowFlowBuilder(true); + } else { + setIsAddPolicyModalVisible(true); + } }} accessToken={accessToken} isAdmin={isAdmin} @@ -363,7 +369,11 @@ const PoliciesPanel: React.FC = ({ onDeleteClick={handleDeleteClick} onEditClick={(policy) => { setEditingPolicy(policy); - setIsAddPolicyModalVisible(true); + if (policy.pipeline) { + setShowFlowBuilder(true); + } else { + setIsAddPolicyModalVisible(true); + } }} onViewClick={(policyId) => setSelectedPolicyId(policyId)} isAdmin={isAdmin} @@ -374,6 +384,10 @@ const PoliciesPanel: React.FC = ({ visible={isAddPolicyModalVisible} onClose={handleCloseModal} onSuccess={handleSuccess} + onOpenFlowBuilder={() => { + setIsAddPolicyModalVisible(false); + setShowFlowBuilder(true); + }} accessToken={accessToken} editingPolicy={editingPolicy} existingPolicies={policiesList} @@ -473,6 +487,24 @@ const PoliciesPanel: React.FC = ({ + + {showFlowBuilder && ( + { + setShowFlowBuilder(false); + setEditingPolicy(null); + }} + onSuccess={() => { + fetchPolicies(); + setEditingPolicy(null); + }} + accessToken={accessToken} + editingPolicy={editingPolicy} + availableGuardrails={guardrailsList} + createPolicy={createPolicyCall} + updatePolicy={updatePolicyCall} + /> + )} ); }; diff --git a/ui/litellm-dashboard/src/components/policies/pipeline_flow_builder.tsx b/ui/litellm-dashboard/src/components/policies/pipeline_flow_builder.tsx new file mode 100644 index 00000000000..259378f2e19 --- /dev/null +++ b/ui/litellm-dashboard/src/components/policies/pipeline_flow_builder.tsx @@ -0,0 +1,999 @@ +import React, { useState } from "react"; +import { Select, Typography, message } from "antd"; +import { Button, TextInput } from "@tremor/react"; +import { ArrowLeftIcon, PlusIcon } from "@heroicons/react/outline"; +import { DotsVerticalIcon } from "@heroicons/react/solid"; +import { GuardrailPipeline, PipelineStep, PipelineTestResult, PolicyCreateRequest, PolicyUpdateRequest, Policy } from "./types"; +import { Guardrail } from "../guardrails/types"; +import { testPipelineCall } from "../networking"; +import NotificationsManager from "../molecules/notifications_manager"; + +const { Text } = Typography; + +const ACTION_OPTIONS = [ + { label: "Next Step", value: "next" }, + { label: "Allow", value: "allow" }, + { label: "Block", value: "block" }, + { label: "Custom Response", value: "modify_response" }, +]; + +const ACTION_LABELS: Record = { + allow: "Allow", + block: "Block", + next: "Next Step", + modify_response: "Custom Response", +}; + +function createDefaultStep(): PipelineStep { + return { + guardrail: "", + on_pass: "next", + on_fail: "block", + pass_data: false, + modify_response_message: null, + }; +} + +function insertStep(steps: PipelineStep[], atIndex: number): PipelineStep[] { + const newSteps = [...steps]; + newSteps.splice(atIndex, 0, createDefaultStep()); + return newSteps; +} + +function removeStep(steps: PipelineStep[], index: number): PipelineStep[] { + if (steps.length <= 1) return steps; + const newSteps = [...steps]; + newSteps.splice(index, 1); + return newSteps; +} + +function updateStepAtIndex( + steps: PipelineStep[], + index: number, + updated: Partial +): PipelineStep[] { + return steps.map((s, i) => (i === index ? { ...s, ...updated } : s)); +} + +// ───────────────────────────────────────────────────────────────────────────── +// Icons (matching the reference image) +// ───────────────────────────────────────────────────────────────────────────── + +const GuardrailIcon: React.FC = () => ( +
+ + + + +
+); + +const PlayIcon: React.FC = () => ( +
+ + + +
+); + +const PassIcon: React.FC = () => ( + + + + +); + +const FailIcon: React.FC = () => ( + + + +); + +// ───────────────────────────────────────────────────────────────────────────── +// Connector +// ───────────────────────────────────────────────────────────────────────────── + +interface ConnectorProps { + onInsert: () => void; +} + +const Connector: React.FC = ({ onInsert }) => ( +
+
+ +
+
+); + +// ───────────────────────────────────────────────────────────────────────────── +// Step Card (editable) +// ───────────────────────────────────────────────────────────────────────────── + +interface StepCardProps { + step: PipelineStep; + stepIndex: number; + totalSteps: number; + onChange: (updated: Partial) => void; + onDelete: () => void; + availableGuardrails: Guardrail[]; +} + +const StepCard: React.FC = ({ + step, + stepIndex, + totalSteps, + onChange, + onDelete, + availableGuardrails, +}) => { + const guardrailOptions = availableGuardrails.map((g) => ({ + label: g.guardrail_name || g.guardrail_id, + value: g.guardrail_name || g.guardrail_id, + })); + + return ( +
+ {/* Header row */} +
+
+ + + GUARDRAIL + +
+
+ + Step {stepIndex + 1} + + +
+
+ + {/* Guardrail selector */} +
+ + onChange({ on_pass: value as PipelineStep["on_pass"] })} + options={ACTION_OPTIONS} + /> + {step.on_pass === "modify_response" && ( +
+ + onChange({ modify_response_message: e.target.value || null })} + /> +
+ )} +
+ + {/* ON FAIL section */} +
+
+ + ON FAIL +
+ +