diff --git a/docs/my-website/blog/claude_code_beta_headers/index.md b/docs/my-website/blog/claude_code_beta_headers/index.md new file mode 100644 index 00000000000..138a85a60c5 --- /dev/null +++ b/docs/my-website/blog/claude_code_beta_headers/index.md @@ -0,0 +1,274 @@ +--- +slug: claude_code_beta_headers +title: "Claude Code - Managing Anthropic Beta Headers" +date: 2026-02-16T10:00:00 +authors: + - name: Sameer Kankute + title: SWE @ LiteLLM (LLM Translation) + url: https://www.linkedin.com/in/sameer-kankute/ + image_url: https://pbs.twimg.com/profile_images/2001352686994907136/ONgNuSk5_400x400.jpg + - name: Ishaan Jaff + title: "CTO, LiteLLM" + url: https://www.linkedin.com/in/reffajnaahsi/ + image_url: https://pbs.twimg.com/profile_images/1613813310264340481/lz54oEiB_400x400.jpg + - name: Krrish Dholakia + title: "CEO, LiteLLM" + url: https://www.linkedin.com/in/krish-d/ + image_url: https://pbs.twimg.com/profile_images/1298587542745358340/DZv3Oj-h_400x400.jpg +description: "How to manage and configure Anthropic beta headers with Claude Code in LiteLLM: filtering, mapping, and dynamic updates across providers." +tags: [anthropic, claude, beta headers, configuration, liteLLM] +hide_table_of_contents: false + +--- +import Image from '@theme/IdealImage'; + +When using Claude Code with LiteLLM and non-Anthropic providers (Bedrock, Azure AI, Vertex AI), you need to ensure that only supported beta headers are sent to each provider. This guide explains how to add support for new beta headers or fix invalid beta header errors. + +## What Are Beta Headers? + +Anthropic uses beta headers to enable experimental features in Claude. When you use Claude Code, it may send beta headers like: + +``` +anthropic-beta: prompt-caching-scope-2026-01-05,advanced-tool-use-2025-11-20 +``` + +However, not all providers support all Anthropic beta features. LiteLLM uses `anthropic_beta_headers_config.json` to manage which beta headers are supported by each provider. + +## Common Error Message + +```bash +Error: The model returned the following errors: invalid beta flag +``` + +## How LiteLLM Handles Beta Headers + +LiteLLM uses a strict validation approach with a configuration file: + +``` +litellm/litellm/anthropic_beta_headers_config.json +``` + +This JSON file contains a **mapping** of beta headers for each provider: +- **Keys**: Input beta header names (from Anthropic) +- **Values**: Provider-specific header names (or `null` if unsupported) +- **Validation**: Only headers present in the mapping with non-null values are forwarded + +This enforces stricter validation than just filtering unsupported headers - headers must be explicitly defined to be allowed. + +## Adding Support for a New Beta Header + +When Anthropic releases a new beta feature, you need to add it to the configuration file for each provider. + +### Step 1: Add the New Beta Header + +Open `anthropic_beta_headers_config.json` and add the new header to each provider's mapping: + +```json title="anthropic_beta_headers_config.json" +{ + "description": "Mapping of Anthropic beta headers for each provider. Keys are input header names, values are provider-specific header names (or null if unsupported). Only headers present in mapping keys with non-null values can be forwarded.", + "anthropic": { + "advanced-tool-use-2025-11-20": "advanced-tool-use-2025-11-20", + "new-feature-2026-03-01": "new-feature-2026-03-01", + ... + }, + "azure_ai": { + "advanced-tool-use-2025-11-20": "advanced-tool-use-2025-11-20", + "new-feature-2026-03-01": "new-feature-2026-03-01", + ... + }, + "bedrock_converse": { + "advanced-tool-use-2025-11-20": "tool-search-tool-2025-10-19", + "new-feature-2026-03-01": null, + ... + }, + "bedrock": { + "advanced-tool-use-2025-11-20": "tool-search-tool-2025-10-19", + "new-feature-2026-03-01": null, + ... + }, + "vertex_ai": { + "advanced-tool-use-2025-11-20": "tool-search-tool-2025-10-19", + "new-feature-2026-03-01": null, + ... + } +} +``` + +**Key Points:** +- **Supported headers**: Set the value to the provider-specific header name (often the same as the key) +- **Unsupported headers**: Set the value to `null` +- **Header transformations**: Some providers use different header names (e.g., Bedrock maps `advanced-tool-use-2025-11-20` to `tool-search-tool-2025-10-19`) +- **Alphabetical order**: Keep headers sorted alphabetically for maintainability + +### Step 2: Reload Configuration (No Restart Required!) + +**Option 1: Dynamic Reload Without Restart** + +Instead of restarting your application, you can dynamically reload the beta headers configuration using environment variables and API endpoints: + +```bash +# Set environment variable to fetch from remote URL (Do this if you want to point it to some other URL) +export LITELLM_ANTHROPIC_BETA_HEADERS_URL="https://raw.githubusercontent.com/BerriAI/litellm/main/litellm/anthropic_beta_headers_config.json" + +# Manually trigger reload via API (no restart needed!) +curl -X POST "https://your-proxy-url/reload/anthropic_beta_headers" \ + -H "Authorization: Bearer YOUR_ADMIN_TOKEN" +``` + +**Option 2: Schedule Automatic Reloads** + +Set up automatic reloading to always stay up-to-date with the latest beta headers: + +```bash +# Reload configuration every 24 hours +curl -X POST "https://your-proxy-url/schedule/anthropic_beta_headers_reload?hours=24" \ + -H "Authorization: Bearer YOUR_ADMIN_TOKEN" +``` + +**Option 3: Traditional Restart** + +If you prefer the traditional approach, restart your LiteLLM proxy or application: + +```bash +# If using LiteLLM proxy +litellm --config config.yaml + +# If using Python SDK +# Just restart your Python application +``` + +:::tip Zero-Downtime Updates +With dynamic reloading, you can fix invalid beta header errors **without restarting your service**! This is especially useful in production environments where downtime is costly. + +See [Auto Sync Anthropic Beta Headers](../proxy/sync_anthropic_beta_headers.md) for complete documentation. +::: + +## Fixing Invalid Beta Header Errors + +If you encounter an "invalid beta flag" error, it means a beta header is being sent that the provider doesn't support. + +### Step 1: Identify the Problematic Header + +Check your logs to see which header is causing the issue: + +```bash +Error: The model returned the following errors: invalid beta flag: new-feature-2026-03-01 +``` + +### Step 2: Update the Config + +Set the header value to `null` for that provider: + +```json title="anthropic_beta_headers_config.json" +{ + "bedrock_converse": { + "new-feature-2026-03-01": null + } +} +``` + +### Step 3: Restart and Test + +Restart your application and verify the header is now filtered out. + +## Contributing a Fix to LiteLLM + +Help the community by contributing your fix! + +### What to Include in Your PR + +1. **Update the config file**: Add the new beta header to `litellm/anthropic_beta_headers_config.json` +2. **Test your changes**: Verify the header is correctly filtered/mapped for each provider +3. **Documentation**: Include provider documentation links showing which headers are supported + +### Example PR Description + +```markdown +## Add support for new-feature-2026-03-01 beta header + +### Changes +- Added `new-feature-2026-03-01` to anthropic_beta_headers_config.json +- Set to `null` for bedrock_converse (unsupported) +- Set to header name for anthropic, azure_ai (supported) + +### Testing +Tested with: +- ✅ Anthropic: Header passed through correctly +- ✅ Azure AI: Header passed through correctly +- ✅ Bedrock Converse: Header filtered out (returns error without fix) + +### References +- Anthropic docs: [link] +- AWS Bedrock docs: [link] +``` + + +## How Beta Header Filtering Works + +When you make a request through LiteLLM: + +```mermaid +sequenceDiagram + participant CC as Claude Code + participant LP as LiteLLM + participant Config as Beta Headers Config + participant Provider as Provider (Bedrock/Azure/etc) + + CC->>LP: Request with beta headers + Note over CC,LP: anthropic-beta: header1,header2,header3 + + LP->>Config: Load header mapping for provider + Config-->>LP: Returns mapping (header→value or null) + + Note over LP: Validate & Transform:
1. Check if header exists in mapping
2. Filter out null values
3. Map to provider-specific names + + LP->>Provider: Request with filtered & mapped headers + Note over LP,Provider: anthropic-beta: mapped-header2
(header1, header3 filtered out) + + Provider-->>LP: Success response + LP-->>CC: Response +``` + +### Filtering Rules + +1. **Header must exist in mapping**: Unknown headers are filtered out +2. **Header must have non-null value**: Headers with `null` values are filtered out +3. **Header transformation**: Headers are mapped to provider-specific names (e.g., `advanced-tool-use-2025-11-20` → `tool-search-tool-2025-10-19` for Bedrock) + +### Example + +Request with headers: +``` +anthropic-beta: advanced-tool-use-2025-11-20,computer-use-2025-01-24,unknown-header +``` + +For Bedrock Converse: +- ✅ `computer-use-2025-01-24` → `computer-use-2025-01-24` (supported, passed through) +- ❌ `advanced-tool-use-2025-11-20` → filtered out (null value in config) +- ❌ `unknown-header` → filtered out (not in config) + +Result sent to Bedrock: +``` +anthropic-beta: computer-use-2025-01-24 +``` + +## Dynamic Configuration Management (No Restart Required!) + +### Environment Variables + +Control how LiteLLM loads the beta headers configuration: + +| Variable | Description | Default | +|----------|-------------|---------| +| `LITELLM_ANTHROPIC_BETA_HEADERS_URL` | URL to fetch config from | GitHub main branch | +| `LITELLM_LOCAL_ANTHROPIC_BETA_HEADERS` | Set to `True` to use local config only | `False` | + +**Example: Use Custom Config URL** +```bash +export LITELLM_ANTHROPIC_BETA_HEADERS_URL="https://your-company.com/custom-beta-headers.json" +``` + +**Example: Use Local Config Only (No Remote Fetching)** +```bash +export LITELLM_LOCAL_ANTHROPIC_BETA_HEADERS=True +``` diff --git a/docs/my-website/blog/claude_opus_4_6/index.md b/docs/my-website/blog/claude_opus_4_6/index.md index 82320472e13..e44420bd570 100644 --- a/docs/my-website/blog/claude_opus_4_6/index.md +++ b/docs/my-website/blog/claude_opus_4_6/index.md @@ -185,7 +185,7 @@ curl --location 'http://0.0.0.0:4000/chat/completions' \ model_list: - model_name: claude-opus-4-6 litellm_params: - model: bedrock/anthropic.claude-opus-4-6-v1:0 + model: bedrock/anthropic.claude-opus-4-6-v1 aws_access_key_id: os.environ/AWS_ACCESS_KEY_ID aws_secret_access_key: os.environ/AWS_SECRET_ACCESS_KEY aws_region_name: us-east-1 diff --git a/docs/my-website/docs/projects/openai-agents.md b/docs/my-website/docs/projects/openai-agents.md index 95a2191b883..86983e7e510 100644 --- a/docs/my-website/docs/projects/openai-agents.md +++ b/docs/my-website/docs/projects/openai-agents.md @@ -1,22 +1,121 @@ +import Tabs from '@theme/Tabs'; +import TabItem from '@theme/TabItem'; # OpenAI Agents SDK -The [OpenAI Agents SDK](https://github.com/openai/openai-agents-python) is a lightweight framework for building multi-agent workflows. -It includes an official LiteLLM extension that lets you use any of the 100+ supported providers (Anthropic, Gemini, Mistral, Bedrock, etc.) +Use OpenAI Agents SDK with any LLM provider through LiteLLM Proxy. + +The [OpenAI Agents SDK](https://github.com/openai/openai-agents-python) is a lightweight framework for building multi-agent workflows. It includes an official LiteLLM extension that lets you use any of the 100+ supported providers. + +## Quick Start + +### 1. Install Dependencies + +```bash +pip install "openai-agents[litellm]" +``` + +### 2. Add Model to Config + +```yaml title="config.yaml" +model_list: + - model_name: gpt-4o + litellm_params: + model: "openai/gpt-4o" + api_key: "os.environ/OPENAI_API_KEY" + + - model_name: claude-sonnet + litellm_params: + model: "anthropic/claude-3-5-sonnet-20241022" + api_key: "os.environ/ANTHROPIC_API_KEY" + + - model_name: gemini-pro + litellm_params: + model: "gemini/gemini-2.0-flash-exp" + api_key: "os.environ/GEMINI_API_KEY" +``` + +### 3. Start LiteLLM Proxy + +```bash +litellm --config config.yaml +``` + +### 4. Use with Proxy + + + + +```python +from agents import Agent, Runner +from agents.extensions.models.litellm_model import LitellmModel + +# Point to LiteLLM proxy +agent = Agent( + name="Assistant", + instructions="You are a helpful assistant.", + model=LitellmModel( + model="claude-sonnet", # Model from config.yaml + api_key="sk-1234", # LiteLLM API key + base_url="http://localhost:4000" + ) +) + +result = await Runner.run(agent, "What is LiteLLM?") +print(result.final_output) +``` + + + ```python from agents import Agent, Runner from agents.extensions.models.litellm_model import LitellmModel +# Use any provider directly agent = Agent( name="Assistant", instructions="You are a helpful assistant.", - model=LitellmModel(model="provider/model-name") + model=LitellmModel( + model="anthropic/claude-3-5-sonnet-20241022", + api_key="your-anthropic-key" + ) ) -result = Runner.run_sync(agent, "your_prompt_here") -print("Result:", result.final_output) +result = await Runner.run(agent, "What is LiteLLM?") +print(result.final_output) ``` -- [GitHub](https://github.com/openai/openai-agents-python) -- [LiteLLM Extension Docs](https://openai.github.io/openai-agents-python/ref/extensions/litellm/) + + + +## Track Usage + +Enable usage tracking to monitor token consumption: + +```python +from agents import Agent, ModelSettings +from agents.extensions.models.litellm_model import LitellmModel + +agent = Agent( + name="Assistant", + model=LitellmModel(model="claude-sonnet", api_key="sk-1234"), + model_settings=ModelSettings(include_usage=True) +) + +result = await Runner.run(agent, "Hello") +print(result.context_wrapper.usage) # Token counts +``` + +## Environment Variables + +| Variable | Value | Description | +|----------|-------|-------------| +| `LITELLM_BASE_URL` | `http://localhost:4000` | LiteLLM proxy URL | +| `LITELLM_API_KEY` | `sk-1234` | Your LiteLLM API key | + +## Related Resources + +- [OpenAI Agents SDK Documentation](https://openai.github.io/openai-agents-python/) +- [LiteLLM Extension Docs](https://openai.github.io/openai-agents-python/models/litellm/) +- [LiteLLM Proxy Quick Start](../proxy/quick_start) diff --git a/docs/my-website/docs/proxy/config_settings.md b/docs/my-website/docs/proxy/config_settings.md index ac554b09174..5e3f56c4206 100644 --- a/docs/my-website/docs/proxy/config_settings.md +++ b/docs/my-website/docs/proxy/config_settings.md @@ -769,6 +769,7 @@ router_settings: | LITELM_ENVIRONMENT | Environment of LiteLLM Instance, used by logging services. Currently only used by DeepEval. | LITELLM_KEY_ROTATION_ENABLED | Enable auto-key rotation for LiteLLM (boolean). Default is false. | LITELLM_KEY_ROTATION_CHECK_INTERVAL_SECONDS | Interval in seconds for how often to run job that auto-rotates keys. Default is 86400 (24 hours). +| LITELLM_KEY_ROTATION_GRACE_PERIOD | Duration to keep old key valid after rotation (e.g. "24h", "2d"). Default is empty (immediate revoke). Used for scheduled rotations and as fallback when not specified in regenerate request. | LITELLM_LICENSE | License key for LiteLLM usage | LITELLM_LOCAL_ANTHROPIC_BETA_HEADERS | Set to `True` to use the local bundled Anthropic beta headers config only, disabling remote fetching. Default is `False` | LITELLM_LOCAL_MODEL_COST_MAP | Local configuration for model cost mapping in LiteLLM diff --git a/docs/my-website/docs/proxy/logging.md b/docs/my-website/docs/proxy/logging.md index 56fb420e6cf..1abb127dfda 100644 --- a/docs/my-website/docs/proxy/logging.md +++ b/docs/my-website/docs/proxy/logging.md @@ -1338,6 +1338,7 @@ litellm_settings: s3_aws_secret_access_key: os.environ/AWS_SECRET_ACCESS_KEY # AWS Secret Access Key for S3 s3_path: my-test-path # [OPTIONAL] set path in bucket you want to write logs to s3_endpoint_url: https://s3.amazonaws.com # [OPTIONAL] S3 endpoint URL, if you want to use Backblaze/cloudflare s3 buckets + s3_use_virtual_hosted_style: false # [OPTIONAL] use virtual-hosted-style URLs (bucket.endpoint/key) instead of path-style (endpoint/bucket/key). Useful for S3-compatible services like MinIO s3_strip_base64_files: false # [OPTIONAL] remove base64 files before storing in s3 ``` diff --git a/docs/my-website/docs/proxy/virtual_keys.md b/docs/my-website/docs/proxy/virtual_keys.md index 38ff4ede280..c74aa75ff4a 100644 --- a/docs/my-website/docs/proxy/virtual_keys.md +++ b/docs/my-website/docs/proxy/virtual_keys.md @@ -549,11 +549,14 @@ curl 'http://localhost:4000/key/sk-1234/regenerate' \ "models": [ "gpt-4", "gpt-3.5-turbo" - ] + ], + "grace_period": "48h" }' ``` +**Grace period (optional)**: Set `grace_period` (e.g. `"24h"`, `"2d"`, `"1w"`) to keep the old key valid for a transitional period. Both old and new keys work until the grace period elapses, enabling seamless cutover without production downtime. Omitted or empty = immediate revoke. Can also be set via `LITELLM_KEY_ROTATION_GRACE_PERIOD` env var for scheduled rotations. + **Read More** - [Write rotated keys to secrets manager](https://docs.litellm.ai/docs/secret#aws-secret-manager) @@ -640,11 +643,13 @@ Set these environment variables when starting the proxy: |----------|-------------|---------| | `LITELLM_KEY_ROTATION_ENABLED` | Enable the rotation worker | `false` | | `LITELLM_KEY_ROTATION_CHECK_INTERVAL_SECONDS` | How often to scan for keys to rotate (in seconds) | `86400` (24 hours) | +| `LITELLM_KEY_ROTATION_GRACE_PERIOD` | Duration to keep old key valid after rotation (e.g. `24h`, `2d`) | `""` (immediate revoke) | **Example:** ```bash export LITELLM_KEY_ROTATION_ENABLED=true export LITELLM_KEY_ROTATION_CHECK_INTERVAL_SECONDS=3600 # Check every hour +export LITELLM_KEY_ROTATION_GRACE_PERIOD=48h # Keep old key valid for 48h during cutover litellm --config config.yaml ``` diff --git a/docs/my-website/sidebars.js b/docs/my-website/sidebars.js index 4efb2475755..42996d1a3e9 100644 --- a/docs/my-website/sidebars.js +++ b/docs/my-website/sidebars.js @@ -176,6 +176,7 @@ const sidebars = { "tutorials/copilotkit_sdk", "tutorials/google_adk", "tutorials/livekit_xai_realtime", + "projects/openai-agents" ] }, diff --git a/enterprise/litellm_enterprise/proxy/common_utils/check_batch_cost.py b/enterprise/litellm_enterprise/proxy/common_utils/check_batch_cost.py index bb25e4f0626..b28b4497e7c 100644 --- a/enterprise/litellm_enterprise/proxy/common_utils/check_batch_cost.py +++ b/enterprise/litellm_enterprise/proxy/common_utils/check_batch_cost.py @@ -4,7 +4,7 @@ from litellm._uuid import uuid from datetime import datetime -from typing import TYPE_CHECKING, Optional, cast +from typing import TYPE_CHECKING, Optional from litellm._logging import verbose_proxy_logger @@ -35,14 +35,11 @@ async def check_batch_cost(self): - if not, return False - if so, return True """ - from litellm_enterprise.proxy.hooks.managed_files import ( - _PROXY_LiteLLMManagedFiles, - ) - from litellm.batches.batch_utils import ( _get_file_content_as_dictionary, calculate_batch_cost_and_usage, ) + from litellm.files.main import afile_content from litellm.litellm_core_utils.get_llm_provider_logic import get_llm_provider from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLogging from litellm.proxy.openai_files_endpoints.common_utils import ( @@ -102,27 +99,29 @@ async def check_batch_cost(self): continue ## RETRIEVE THE BATCH JOB OUTPUT FILE - managed_files_obj = cast( - Optional[_PROXY_LiteLLMManagedFiles], - self.proxy_logging_obj.get_proxy_hook("managed_files"), - ) if ( response.status == "completed" and response.output_file_id is not None - and managed_files_obj is not None ): verbose_proxy_logger.info( f"Batch ID: {batch_id} is complete, tracking cost and usage" ) - # track cost - model_file_id_mapping = { - response.output_file_id: {model_id: response.output_file_id} - } - _file_content = await managed_files_obj.afile_content( - file_id=response.output_file_id, - litellm_parent_otel_span=None, - llm_router=self.llm_router, - model_file_id_mapping=model_file_id_mapping, + + # This background job runs as default_user_id, so going through the HTTP endpoint + # would trigger check_managed_file_id_access and get 403. Instead, extract the raw + # provider file ID and call afile_content directly with deployment credentials. + raw_output_file_id = response.output_file_id + decoded = _is_base64_encoded_unified_file_id(raw_output_file_id) + if decoded: + try: + raw_output_file_id = decoded.split("llm_output_file_id,")[1].split(";")[0] + except (IndexError, AttributeError): + pass + + credentials = self.llm_router.get_deployment_credentials_with_provider(model_id) or {} + _file_content = await afile_content( + file_id=raw_output_file_id, + **credentials, ) file_content_as_dict = _get_file_content_as_dictionary( @@ -143,11 +142,15 @@ async def check_batch_cost(self): custom_llm_provider=custom_llm_provider, ) + # Pass deployment model_info so custom batch pricing + # (input_cost_per_token_batches etc.) is used for cost calc + deployment_model_info = deployment_info.model_info.model_dump() if deployment_info.model_info else {} batch_cost, batch_usage, batch_models = ( await calculate_batch_cost_and_usage( file_content_dictionary=file_content_as_dict, custom_llm_provider=llm_provider, # type: ignore model_name=model_name, + model_info=deployment_model_info, ) ) logging_obj = LiteLLMLogging( diff --git a/enterprise/litellm_enterprise/proxy/hooks/managed_files.py b/enterprise/litellm_enterprise/proxy/hooks/managed_files.py index a41b3f3bf6f..f341a1e9634 100644 --- a/enterprise/litellm_enterprise/proxy/hooks/managed_files.py +++ b/enterprise/litellm_enterprise/proxy/hooks/managed_files.py @@ -230,12 +230,14 @@ async def can_user_call_unified_file_id( if managed_file: return managed_file.created_by == user_id - return False + raise HTTPException( + status_code=404, + detail=f"File not found: {unified_file_id}", + ) async def can_user_call_unified_object_id( self, unified_object_id: str, user_api_key_dict: UserAPIKeyAuth ) -> bool: - ## check if the user has access to the unified object id ## check if the user has access to the unified object id user_id = user_api_key_dict.user_id managed_object = ( @@ -246,7 +248,10 @@ async def can_user_call_unified_object_id( if managed_object: return managed_object.created_by == user_id - return True # don't raise error if managed object is not found + raise HTTPException( + status_code=404, + detail=f"Object not found: {unified_object_id}", + ) async def list_user_batches( self, @@ -911,15 +916,22 @@ async def async_post_call_success_hook( ) setattr(response, file_attr, unified_file_id) - # Fetch the actual file object from the provider + # Use llm_router credentials when available. Without credentials, + # Azure and other auth-required providers return 500/401. file_object = None try: - # Use litellm to retrieve the file object from the provider - from litellm import afile_retrieve - file_object = await afile_retrieve( - custom_llm_provider=model_name.split("/")[0] if model_name and "/" in model_name else "openai", - file_id=original_file_id - ) + from litellm.proxy.proxy_server import llm_router as _llm_router + if _llm_router is not None and model_id: + _creds = _llm_router.get_deployment_credentials_with_provider(model_id) or {} + file_object = await litellm.afile_retrieve( + file_id=original_file_id, + **_creds, + ) + else: + file_object = await litellm.afile_retrieve( + custom_llm_provider=model_name.split("/")[0] if model_name and "/" in model_name else "openai", + file_id=original_file_id, + ) verbose_logger.debug( f"Successfully retrieved file object for {file_attr}={original_file_id}" ) @@ -1004,7 +1016,10 @@ async def afile_retrieve( raise Exception(f"LiteLLM Managed File object with id={file_id} not found") # Case 2: Managed file and the file object exists in the database + # The stored file_object has the raw provider ID. Replace with the unified ID + # so callers see a consistent ID (matching Case 3 which does response.id = file_id). if stored_file_object and stored_file_object.file_object: + stored_file_object.file_object.id = file_id return stored_file_object.file_object # Case 3: Managed file exists in the database but not the file object (for. e.g the batch task might not have run) diff --git a/litellm-proxy-extras/litellm_proxy_extras/migrations/20260203120000_add_deprecated_verification_token_table/migration.sql b/litellm-proxy-extras/litellm_proxy_extras/migrations/20260203120000_add_deprecated_verification_token_table/migration.sql new file mode 100644 index 00000000000..51d88444191 --- /dev/null +++ b/litellm-proxy-extras/litellm_proxy_extras/migrations/20260203120000_add_deprecated_verification_token_table/migration.sql @@ -0,0 +1,19 @@ +-- CreateTable +CREATE TABLE "LiteLLM_DeprecatedVerificationToken" ( + "id" TEXT NOT NULL, + "token" TEXT NOT NULL, + "active_token_id" TEXT NOT NULL, + "revoke_at" TIMESTAMP(3) NOT NULL, + "created_at" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP, + + CONSTRAINT "LiteLLM_DeprecatedVerificationToken_pkey" PRIMARY KEY ("id") +); + +-- CreateIndex +CREATE UNIQUE INDEX "LiteLLM_DeprecatedVerificationToken_token_key" ON "LiteLLM_DeprecatedVerificationToken"("token"); + +-- CreateIndex +CREATE INDEX "LiteLLM_DeprecatedVerificationToken_token_revoke_at_idx" ON "LiteLLM_DeprecatedVerificationToken"("token", "revoke_at"); + +-- CreateIndex +CREATE INDEX "LiteLLM_DeprecatedVerificationToken_revoke_at_idx" ON "LiteLLM_DeprecatedVerificationToken"("revoke_at"); diff --git a/litellm-proxy-extras/litellm_proxy_extras/migrations/20260214124140_baseline_diff/migration.sql b/litellm-proxy-extras/litellm_proxy_extras/migrations/20260214124140_baseline_diff/migration.sql new file mode 100644 index 00000000000..2f725d83806 --- /dev/null +++ b/litellm-proxy-extras/litellm_proxy_extras/migrations/20260214124140_baseline_diff/migration.sql @@ -0,0 +1,2 @@ +-- This is an empty migration. + diff --git a/litellm-proxy-extras/litellm_proxy_extras/schema.prisma b/litellm-proxy-extras/litellm_proxy_extras/schema.prisma index c2fca8705cb..441c2cdf70d 100644 --- a/litellm-proxy-extras/litellm_proxy_extras/schema.prisma +++ b/litellm-proxy-extras/litellm_proxy_extras/schema.prisma @@ -325,6 +325,19 @@ model LiteLLM_VerificationToken { @@index([budget_reset_at, expires]) } +// Deprecated keys during grace period - allows old key to work until revoke_at +model LiteLLM_DeprecatedVerificationToken { + id String @id @default(uuid()) + token String // Hashed old key + active_token_id String // Current token hash in LiteLLM_VerificationToken + revoke_at DateTime // When the old key stops working + created_at DateTime @default(now()) @map("created_at") + + @@unique([token]) + @@index([token, revoke_at]) + @@index([revoke_at]) +} + // Audit table for deleted keys - preserves spend and key information for historical tracking model LiteLLM_DeletedVerificationToken { id String @id @default(uuid()) diff --git a/litellm/_service_logger.py b/litellm/_service_logger.py index b67d0d86063..8f9a3c5083f 100644 --- a/litellm/_service_logger.py +++ b/litellm/_service_logger.py @@ -312,10 +312,12 @@ async def async_log_success_event(self, kwargs, response_obj, start_time, end_ti _duration, type(_duration) ) ) # invalid _duration value + # Batch polling callbacks (check_batch_cost) don't include call_type in kwargs. + # Use .get() to avoid KeyError. await self.async_service_success_hook( service=ServiceTypes.LITELLM, duration=_duration, - call_type=kwargs["call_type"], + call_type=kwargs.get("call_type", "unknown") ) except Exception as e: raise e diff --git a/litellm/batches/batch_utils.py b/litellm/batches/batch_utils.py index 16a467e00cb..29bd99c2a60 100644 --- a/litellm/batches/batch_utils.py +++ b/litellm/batches/batch_utils.py @@ -8,7 +8,7 @@ from litellm._logging import verbose_logger from litellm._uuid import uuid from litellm.types.llms.openai import Batch -from litellm.types.utils import CallTypes, ModelResponse, Usage +from litellm.types.utils import CallTypes, ModelInfo, ModelResponse, Usage from litellm.utils import token_counter @@ -16,14 +16,22 @@ async def calculate_batch_cost_and_usage( file_content_dictionary: List[dict], custom_llm_provider: Literal["openai", "azure", "vertex_ai", "hosted_vllm", "anthropic"], model_name: Optional[str] = None, + model_info: Optional[ModelInfo] = None, ) -> Tuple[float, Usage, List[str]]: """ - Calculate the cost and usage of a batch + Calculate the cost and usage of a batch. + + Args: + model_info: Optional deployment-level model info with custom batch + pricing. Threaded through to batch_cost_calculator so that + deployment-specific pricing (e.g. input_cost_per_token_batches) + is used instead of the global cost map. """ batch_cost = _batch_cost_calculator( custom_llm_provider=custom_llm_provider, file_content_dictionary=file_content_dictionary, model_name=model_name, + model_info=model_info, ) batch_usage = _get_batch_job_total_usage_from_file_content( file_content_dictionary=file_content_dictionary, @@ -94,6 +102,7 @@ def _batch_cost_calculator( file_content_dictionary: List[dict], custom_llm_provider: Literal["openai", "azure", "vertex_ai", "hosted_vllm", "anthropic"] = "openai", model_name: Optional[str] = None, + model_info: Optional[ModelInfo] = None, ) -> float: """ Calculate the cost of a batch based on the output file id @@ -108,6 +117,7 @@ def _batch_cost_calculator( total_cost = _get_batch_job_cost_from_file_content( file_content_dictionary=file_content_dictionary, custom_llm_provider=custom_llm_provider, + model_info=model_info, ) verbose_logger.debug("total_cost=%s", total_cost) return total_cost @@ -290,10 +300,13 @@ def _get_file_content_as_dictionary(file_content: bytes) -> List[dict]: def _get_batch_job_cost_from_file_content( file_content_dictionary: List[dict], custom_llm_provider: Literal["openai", "azure", "vertex_ai", "hosted_vllm", "anthropic"] = "openai", + model_info: Optional[ModelInfo] = None, ) -> float: """ Get the cost of a batch job from the file content """ + from litellm.cost_calculator import batch_cost_calculator + try: total_cost: float = 0.0 # parse the file content as json @@ -303,11 +316,22 @@ def _get_batch_job_cost_from_file_content( for _item in file_content_dictionary: if _batch_response_was_successful(_item): _response_body = _get_response_from_batch_job_output_file(_item) - total_cost += litellm.completion_cost( - completion_response=_response_body, - custom_llm_provider=custom_llm_provider, - call_type=CallTypes.aretrieve_batch.value, - ) + if model_info is not None: + usage = _get_batch_job_usage_from_response_body(_response_body) + model = _response_body.get("model", "") + prompt_cost, completion_cost = batch_cost_calculator( + usage=usage, + model=model, + custom_llm_provider=custom_llm_provider, + model_info=model_info, + ) + total_cost += prompt_cost + completion_cost + else: + total_cost += litellm.completion_cost( + completion_response=_response_body, + custom_llm_provider=custom_llm_provider, + call_type=CallTypes.aretrieve_batch.value, + ) verbose_logger.debug("total_cost=%s", total_cost) return total_cost except Exception as e: diff --git a/litellm/constants.py b/litellm/constants.py index 03f80a8cb78..a4a0e7882ea 100644 --- a/litellm/constants.py +++ b/litellm/constants.py @@ -319,6 +319,9 @@ MAX_EXCEPTION_MESSAGE_LENGTH = int(os.getenv("MAX_EXCEPTION_MESSAGE_LENGTH", 2000)) MAX_STRING_LENGTH_PROMPT_IN_DB = int(os.getenv("MAX_STRING_LENGTH_PROMPT_IN_DB", 2048)) BEDROCK_MAX_POLICY_SIZE = int(os.getenv("BEDROCK_MAX_POLICY_SIZE", 75)) +BEDROCK_MIN_THINKING_BUDGET_TOKENS = int( + os.getenv("BEDROCK_MIN_THINKING_BUDGET_TOKENS", 1024) +) REPLICATE_POLLING_DELAY_SECONDS = float( os.getenv("REPLICATE_POLLING_DELAY_SECONDS", 0.5) ) @@ -1258,6 +1261,9 @@ LITELLM_KEY_ROTATION_CHECK_INTERVAL_SECONDS = int( os.getenv("LITELLM_KEY_ROTATION_CHECK_INTERVAL_SECONDS", 86400) ) # 24 hours default +LITELLM_KEY_ROTATION_GRACE_PERIOD: str = os.getenv( + "LITELLM_KEY_ROTATION_GRACE_PERIOD", "" +) # Duration to keep old key valid after rotation (e.g. "24h", "2d"); empty = immediate revoke (default) UI_SESSION_TOKEN_TEAM_ID = "litellm-dashboard" LITELLM_PROXY_ADMIN_NAME = "default_user_id" diff --git a/litellm/cost_calculator.py b/litellm/cost_calculator.py index fe082843306..dae0bb1c2c0 100644 --- a/litellm/cost_calculator.py +++ b/litellm/cost_calculator.py @@ -1896,9 +1896,16 @@ def batch_cost_calculator( usage: Usage, model: str, custom_llm_provider: Optional[str] = None, + model_info: Optional[ModelInfo] = None, ) -> Tuple[float, float]: """ - Calculate the cost of a batch job + Calculate the cost of a batch job. + + Args: + model_info: Optional deployment-level model info containing custom + batch pricing (e.g. input_cost_per_token_batches). When provided, + skips the global litellm.get_model_info() lookup so that + deployment-specific pricing is used. """ _, custom_llm_provider, _, _ = litellm.get_llm_provider( @@ -1911,12 +1918,13 @@ def batch_cost_calculator( custom_llm_provider, ) - try: - model_info: Optional[ModelInfo] = litellm.get_model_info( - model=model, custom_llm_provider=custom_llm_provider - ) - except Exception: - model_info = None + if model_info is None: + try: + model_info = litellm.get_model_info( + model=model, custom_llm_provider=custom_llm_provider + ) + except Exception: + model_info = None if not model_info: return 0.0, 0.0 diff --git a/litellm/integrations/s3_v2.py b/litellm/integrations/s3_v2.py index 534b85e4752..eddc80dbc1f 100644 --- a/litellm/integrations/s3_v2.py +++ b/litellm/integrations/s3_v2.py @@ -51,6 +51,7 @@ def __init__( s3_use_team_prefix: bool = False, s3_strip_base64_files: bool = False, s3_use_key_prefix: bool = False, + s3_use_virtual_hosted_style: bool = False, **kwargs, ): try: @@ -78,7 +79,8 @@ def __init__( s3_path=s3_path, s3_use_team_prefix=s3_use_team_prefix, s3_strip_base64_files=s3_strip_base64_files, - s3_use_key_prefix=s3_use_key_prefix + s3_use_key_prefix=s3_use_key_prefix, + s3_use_virtual_hosted_style=s3_use_virtual_hosted_style ) verbose_logger.debug(f"s3 logger using endpoint url {s3_endpoint_url}") @@ -135,6 +137,7 @@ def _init_s3_params( s3_use_team_prefix: bool = False, s3_strip_base64_files: bool = False, s3_use_key_prefix: bool = False, + s3_use_virtual_hosted_style: bool = False, ): """ Initialize the s3 params for this logging callback @@ -217,6 +220,11 @@ def _init_s3_params( or s3_strip_base64_files ) + self.s3_use_virtual_hosted_style = ( + bool(litellm.s3_callback_params.get("s3_use_virtual_hosted_style", False)) + or s3_use_virtual_hosted_style + ) + return async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): @@ -247,8 +255,14 @@ async def _async_log_event_base(self, kwargs, response_obj, start_time, end_time standard_logging_payload=kwargs.get("standard_logging_object", None), ) + # afile_delete and other non-model call types never produce a standard_logging_object, + # so s3_batch_logging_element is None. Skip gracefully instead of raising ValueError. if s3_batch_logging_element is None: - raise ValueError("s3_batch_logging_element is None") + verbose_logger.debug( + "s3 Logging - skipping event, no standard_logging_object for call_type=%s", + kwargs.get("call_type", "unknown"), + ) + return verbose_logger.debug( "\ns3 Logger - Logging payload = %s", s3_batch_logging_element @@ -302,13 +316,20 @@ async def async_upload_data_to_s3( url = f"https://{self.s3_bucket_name}.s3.{self.s3_region_name}.amazonaws.com/{batch_logging_element.s3_object_key}" if self.s3_endpoint_url and self.s3_bucket_name: - url = ( - self.s3_endpoint_url - + "/" - + self.s3_bucket_name - + "/" - + batch_logging_element.s3_object_key - ) + if self.s3_use_virtual_hosted_style: + # Virtual-hosted-style: bucket.endpoint/key + endpoint_host = self.s3_endpoint_url.replace("https://", "").replace("http://", "") + protocol = "https://" if self.s3_endpoint_url.startswith("https://") else "http://" + url = f"{protocol}{self.s3_bucket_name}.{endpoint_host}/{batch_logging_element.s3_object_key}" + else: + # Path-style: endpoint/bucket/key + url = ( + self.s3_endpoint_url + + "/" + + self.s3_bucket_name + + "/" + + batch_logging_element.s3_object_key + ) # Convert JSON to string json_string = safe_dumps(batch_logging_element.payload) @@ -456,13 +477,20 @@ def upload_data_to_s3(self, batch_logging_element: s3BatchLoggingElement): url = f"https://{self.s3_bucket_name}.s3.{self.s3_region_name}.amazonaws.com/{batch_logging_element.s3_object_key}" if self.s3_endpoint_url and self.s3_bucket_name: - url = ( - self.s3_endpoint_url - + "/" - + self.s3_bucket_name - + "/" - + batch_logging_element.s3_object_key - ) + if self.s3_use_virtual_hosted_style: + # Virtual-hosted-style: bucket.endpoint/key + endpoint_host = self.s3_endpoint_url.replace("https://", "").replace("http://", "") + protocol = "https://" if self.s3_endpoint_url.startswith("https://") else "http://" + url = f"{protocol}{self.s3_bucket_name}.{endpoint_host}/{batch_logging_element.s3_object_key}" + else: + # Path-style: endpoint/bucket/key + url = ( + self.s3_endpoint_url + + "/" + + self.s3_bucket_name + + "/" + + batch_logging_element.s3_object_key + ) # Convert JSON to string json_string = safe_dumps(batch_logging_element.payload) @@ -550,13 +578,20 @@ async def _download_object_from_s3(self, s3_object_key: str) -> Optional[dict]: url = f"https://{self.s3_bucket_name}.s3.{self.s3_region_name}.amazonaws.com/{s3_object_key}" if self.s3_endpoint_url and self.s3_bucket_name: - url = ( - self.s3_endpoint_url - + "/" - + self.s3_bucket_name - + "/" - + s3_object_key - ) + if self.s3_use_virtual_hosted_style: + # Virtual-hosted-style: bucket.endpoint/key + endpoint_host = self.s3_endpoint_url.replace("https://", "").replace("http://", "") + protocol = "https://" if self.s3_endpoint_url.startswith("https://") else "http://" + url = f"{protocol}{self.s3_bucket_name}.{endpoint_host}/{s3_object_key}" + else: + # Path-style: endpoint/bucket/key + url = ( + self.s3_endpoint_url + + "/" + + self.s3_bucket_name + + "/" + + s3_object_key + ) # Prepare the request for GET operation # For GET requests, we need x-amz-content-sha256 with hash of empty string @@ -618,4 +653,4 @@ async def get_proxy_server_request_from_cold_storage_with_object_key( verbose_logger.exception( f"Error retrieving object {object_key} from cold storage: {str(e)}" ) - return None + return None \ No newline at end of file diff --git a/litellm/llms/anthropic/chat/transformation.py b/litellm/llms/anthropic/chat/transformation.py index c2cfff80685..85a4790a9b9 100644 --- a/litellm/llms/anthropic/chat/transformation.py +++ b/litellm/llms/anthropic/chat/transformation.py @@ -1282,9 +1282,13 @@ def transform_request( output_config = optional_params.get("output_config") if output_config and isinstance(output_config, dict): effort = output_config.get("effort") - if effort and effort not in ["high", "medium", "low"]: + if effort and effort not in ["high", "medium", "low", "max"]: raise ValueError( - f"Invalid effort value: {effort}. Must be one of: 'high', 'medium', 'low'" + f"Invalid effort value: {effort}. Must be one of: 'high', 'medium', 'low', 'max'" + ) + if effort == "max" and not self._is_claude_opus_4_6(model): + raise ValueError( + f"effort='max' is only supported by Claude Opus 4.6. Got model: {model}" ) data["output_config"] = output_config diff --git a/litellm/llms/anthropic/experimental_pass_through/adapters/handler.py b/litellm/llms/anthropic/experimental_pass_through/adapters/handler.py index c6caaddf98b..73e74c228ba 100644 --- a/litellm/llms/anthropic/experimental_pass_through/adapters/handler.py +++ b/litellm/llms/anthropic/experimental_pass_through/adapters/handler.py @@ -19,6 +19,7 @@ AnthropicMessagesResponse, ) from litellm.types.utils import ModelResponse +from litellm.utils import get_model_info if TYPE_CHECKING: pass @@ -63,6 +64,14 @@ def _route_openai_thinking_to_responses_api_if_needed( return model = completion_kwargs.get("model") + try: + model_info = get_model_info(model=cast(str, model), custom_llm_provider=custom_llm_provider) + if model_info and model_info.get("supports_reasoning") is False: + # Model doesn't support reasoning/responses API, don't route + return + except Exception: + pass + if isinstance(model, str) and model and not model.startswith("responses/"): # Prefix model with "responses/" to route to OpenAI Responses API completion_kwargs["model"] = f"responses/{model}" diff --git a/litellm/llms/anthropic/experimental_pass_through/adapters/streaming_iterator.py b/litellm/llms/anthropic/experimental_pass_through/adapters/streaming_iterator.py index a86820f82e8..de634ff9ecf 100644 --- a/litellm/llms/anthropic/experimental_pass_through/adapters/streaming_iterator.py +++ b/litellm/llms/anthropic/experimental_pass_through/adapters/streaming_iterator.py @@ -239,8 +239,13 @@ async def __anext__(self): # noqa: PLR0915 merged_chunk["delta"] = {} # Add usage to the held chunk + uncached_input_tokens = chunk.usage.prompt_tokens or 0 + if hasattr(chunk.usage, "prompt_tokens_details") and chunk.usage.prompt_tokens_details: + cached_tokens = getattr(chunk.usage.prompt_tokens_details, "cached_tokens", 0) or 0 + uncached_input_tokens -= cached_tokens + usage_dict: UsageDelta = { - "input_tokens": chunk.usage.prompt_tokens or 0, + "input_tokens": uncached_input_tokens, "output_tokens": chunk.usage.completion_tokens or 0, } # Add cache tokens if available (for prompt caching support) @@ -412,6 +417,7 @@ def _should_start_new_content_block(self, chunk: "ModelResponseStream") -> bool: if block_type == "tool_use": # Type narrowing: content_block_start is ToolUseBlock when block_type is "tool_use" from typing import cast + from litellm.types.llms.anthropic import ToolUseBlock tool_block = cast(ToolUseBlock, content_block_start) @@ -430,6 +436,7 @@ def _should_start_new_content_block(self, chunk: "ModelResponseStream") -> bool: # if we get a function name since it signals a new tool call if block_type == "tool_use": from typing import cast + from litellm.types.llms.anthropic import ToolUseBlock tool_block = cast(ToolUseBlock, content_block_start) diff --git a/litellm/llms/anthropic/experimental_pass_through/adapters/transformation.py b/litellm/llms/anthropic/experimental_pass_through/adapters/transformation.py index 169b138a5f7..efbac13735c 100644 --- a/litellm/llms/anthropic/experimental_pass_through/adapters/transformation.py +++ b/litellm/llms/anthropic/experimental_pass_through/adapters/transformation.py @@ -1070,8 +1070,13 @@ def translate_openai_response_to_anthropic( ) # extract usage usage: Usage = getattr(response, "usage") + uncached_input_tokens = usage.prompt_tokens or 0 + if hasattr(usage, "prompt_tokens_details") and usage.prompt_tokens_details: + cached_tokens = getattr(usage.prompt_tokens_details, "cached_tokens", 0) or 0 + uncached_input_tokens -= cached_tokens + anthropic_usage = AnthropicUsage( - input_tokens=usage.prompt_tokens or 0, + input_tokens=uncached_input_tokens, output_tokens=usage.completion_tokens or 0, ) # Add cache tokens if available (for prompt caching support) @@ -1230,8 +1235,13 @@ def translate_streaming_openai_response_to_anthropic( else: litellm_usage_chunk = None if litellm_usage_chunk is not None: + uncached_input_tokens = litellm_usage_chunk.prompt_tokens or 0 + if hasattr(litellm_usage_chunk, "prompt_tokens_details") and litellm_usage_chunk.prompt_tokens_details: + cached_tokens = getattr(litellm_usage_chunk.prompt_tokens_details, "cached_tokens", 0) or 0 + uncached_input_tokens -= cached_tokens + usage_delta = UsageDelta( - input_tokens=litellm_usage_chunk.prompt_tokens or 0, + input_tokens=uncached_input_tokens, output_tokens=litellm_usage_chunk.completion_tokens or 0, ) # Add cache tokens if available (for prompt caching support) diff --git a/litellm/llms/bedrock/chat/converse_transformation.py b/litellm/llms/bedrock/chat/converse_transformation.py index efa755d515e..5faae07e2b9 100644 --- a/litellm/llms/bedrock/chat/converse_transformation.py +++ b/litellm/llms/bedrock/chat/converse_transformation.py @@ -11,7 +11,10 @@ import litellm from litellm._logging import verbose_logger -from litellm.constants import RESPONSE_FORMAT_TOOL_NAME +from litellm.constants import ( + BEDROCK_MIN_THINKING_BUDGET_TOKENS, + RESPONSE_FORMAT_TOOL_NAME, +) from litellm.litellm_core_utils.core_helpers import ( filter_exceptions_from_params, filter_internal_params, @@ -434,6 +437,25 @@ def _handle_reasoning_effort_parameter( reasoning_effort=reasoning_effort, model=model ) + @staticmethod + def _clamp_thinking_budget_tokens(optional_params: dict) -> None: + """ + Clamp thinking.budget_tokens to the Bedrock minimum (1024). + + Bedrock returns a 400 error if budget_tokens < 1024. + """ + thinking = optional_params.get("thinking") + if isinstance(thinking, dict): + budget = thinking.get("budget_tokens") + if isinstance(budget, int) and budget < BEDROCK_MIN_THINKING_BUDGET_TOKENS: + verbose_logger.debug( + "Bedrock requires thinking.budget_tokens >= %d, got %d. " + "Clamping to minimum.", + BEDROCK_MIN_THINKING_BUDGET_TOKENS, + budget, + ) + thinking["budget_tokens"] = BEDROCK_MIN_THINKING_BUDGET_TOKENS + def get_supported_openai_params(self, model: str) -> List[str]: from litellm.utils import supports_function_calling @@ -871,9 +893,14 @@ def update_optional_params_with_thinking_tokens( Checks 'non_default_params' for 'thinking' and 'max_tokens' if 'thinking' is enabled and 'max_tokens' is not specified, set 'max_tokens' to the thinking token budget + DEFAULT_MAX_TOKENS + + Also clamps thinking.budget_tokens to the Bedrock minimum (1024) to + prevent 400 errors from the Bedrock API. """ from litellm.constants import DEFAULT_MAX_TOKENS + self._clamp_thinking_budget_tokens(optional_params) + is_thinking_enabled = self.is_thinking_enabled(optional_params) is_max_tokens_in_request = self.is_max_tokens_in_request(non_default_params) if is_thinking_enabled and not is_max_tokens_in_request: diff --git a/litellm/llms/chatgpt/responses/transformation.py b/litellm/llms/chatgpt/responses/transformation.py index 0ce24f63a89..bcb6edd39f9 100644 --- a/litellm/llms/chatgpt/responses/transformation.py +++ b/litellm/llms/chatgpt/responses/transformation.py @@ -73,10 +73,6 @@ def transform_responses_api_request( litellm_params, headers, ) - request.pop("max_output_tokens", None) - request.pop("max_tokens", None) - request.pop("max_completion_tokens", None) - request.pop("metadata", None) base_instructions = get_chatgpt_default_instructions() existing_instructions = request.get("instructions") if existing_instructions: @@ -92,7 +88,22 @@ def transform_responses_api_request( if "reasoning.encrypted_content" not in include: include.append("reasoning.encrypted_content") request["include"] = include - return request + + allowed_keys = { + "model", + "input", + "instructions", + "stream", + "store", + "include", + "tools", + "tool_choice", + "reasoning", + "previous_response_id", + "truncation", + } + + return {k: v for k, v in request.items() if k in allowed_keys} def transform_response_api_response( self, diff --git a/litellm/llms/custom_httpx/aiohttp_transport.py b/litellm/llms/custom_httpx/aiohttp_transport.py index fb98006c7e4..6cec1f4fe16 100644 --- a/litellm/llms/custom_httpx/aiohttp_transport.py +++ b/litellm/llms/custom_httpx/aiohttp_transport.py @@ -119,8 +119,13 @@ async def aclose(self) -> None: class AiohttpTransport(httpx.AsyncBaseTransport): - def __init__(self, client: Union[ClientSession, Callable[[], ClientSession]]) -> None: + def __init__( + self, + client: Union[ClientSession, Callable[[], ClientSession]], + owns_session: bool = True, + ) -> None: self.client = client + self._owns_session = owns_session ######################################################### # Class variables for proxy settings @@ -128,7 +133,7 @@ def __init__(self, client: Union[ClientSession, Callable[[], ClientSession]]) -> self.proxy_cache: Dict[str, Optional[str]] = {} async def aclose(self) -> None: - if isinstance(self.client, ClientSession): + if self._owns_session and isinstance(self.client, ClientSession): await self.client.close() @@ -144,10 +149,11 @@ def __init__( self, client: Union[ClientSession, Callable[[], ClientSession]], ssl_verify: Optional[Union[bool, ssl.SSLContext]] = None, + owns_session: bool = True, ): self.client = client self._ssl_verify = ssl_verify # Store for per-request SSL override - super().__init__(client=client) + super().__init__(client=client, owns_session=owns_session) # Store the client factory for recreating sessions when needed if callable(client): self._client_factory = client diff --git a/litellm/llms/custom_httpx/http_handler.py b/litellm/llms/custom_httpx/http_handler.py index 5cf6efe5ba2..328097639e5 100644 --- a/litellm/llms/custom_httpx/http_handler.py +++ b/litellm/llms/custom_httpx/http_handler.py @@ -866,6 +866,7 @@ def _create_aiohttp_transport( return LiteLLMAiohttpTransport( client=shared_session, ssl_verify=ssl_for_transport, + owns_session=False, ) # Create new session only if none provided or existing one is invalid diff --git a/litellm/llms/openai_like/dynamic_config.py b/litellm/llms/openai_like/dynamic_config.py index 1e7866bebbe..a2ce6b9a531 100644 --- a/litellm/llms/openai_like/dynamic_config.py +++ b/litellm/llms/openai_like/dynamic_config.py @@ -4,6 +4,7 @@ from typing import Any, Coroutine, List, Literal, Optional, Tuple, Union, overload +from litellm._logging import verbose_logger from litellm.litellm_core_utils.prompt_templates.common_utils import ( handle_messages_with_content_list_to_str_conversion, ) @@ -96,8 +97,27 @@ def get_complete_url( return api_base def get_supported_openai_params(self, model: str) -> list: - """Get supported OpenAI params from base class""" - return super().get_supported_openai_params(model=model) + """Get supported OpenAI params, excluding tool-related params for models + that don't support function calling.""" + from litellm.utils import supports_function_calling + + supported_params = super().get_supported_openai_params(model=model) + + _supports_fc = supports_function_calling( + model=model, custom_llm_provider=provider.slug + ) + + if not _supports_fc: + tool_params = ["tools", "tool_choice", "function_call", "functions", "parallel_tool_calls"] + for param in tool_params: + if param in supported_params: + supported_params.remove(param) + verbose_logger.debug( + f"Model {model} on provider {provider.slug} does not support " + f"function calling — removed tool-related params from supported params." + ) + + return supported_params def map_openai_params( self, diff --git a/litellm/model_prices_and_context_window_backup.json b/litellm/model_prices_and_context_window_backup.json index 95d8ba2ff60..2b6d2800124 100644 --- a/litellm/model_prices_and_context_window_backup.json +++ b/litellm/model_prices_and_context_window_backup.json @@ -12456,6 +12456,19 @@ "supports_tool_choice": true, "supports_web_search": true }, + "fireworks_ai/accounts/fireworks/models/kimi-k2p5": { + "input_cost_per_token": 6e-07, + "litellm_provider": "fireworks_ai", + "max_input_tokens": 262144, + "max_output_tokens": 262144, + "max_tokens": 262144, + "mode": "chat", + "output_cost_per_token": 3e-06, + "source": "https://fireworks.ai/pricing", + "supports_function_calling": true, + "supports_response_schema": true, + "supports_tool_choice": true + }, "fireworks_ai/accounts/fireworks/models/llama-v3p1-405b-instruct": { "input_cost_per_token": 3e-06, "litellm_provider": "fireworks_ai", @@ -23759,7 +23772,7 @@ "max_output_tokens": 131072, "max_tokens": 131072, "mode": "chat", - "output_cost_per_token": 1.5e-07, + "output_cost_per_token": 1.5e-05, "source": "https://www.oracle.com/artificial-intelligence/generative-ai/generative-ai-service/pricing", "supports_function_calling": true, "supports_response_schema": false @@ -23807,7 +23820,7 @@ "max_output_tokens": 128000, "max_tokens": 128000, "mode": "chat", - "output_cost_per_token": 1.5e-07, + "output_cost_per_token": 1.5e-05, "source": "https://www.oracle.com/artificial-intelligence/generative-ai/generative-ai-service/pricing", "supports_function_calling": true, "supports_response_schema": false diff --git a/litellm/proxy/_experimental/mcp_server/server.py b/litellm/proxy/_experimental/mcp_server/server.py index ba107a9dd10..31836a27509 100644 --- a/litellm/proxy/_experimental/mcp_server/server.py +++ b/litellm/proxy/_experimental/mcp_server/server.py @@ -149,7 +149,7 @@ class ListMCPToolsRestAPIResponseObject(MCPTool): app=server, event_store=None, json_response=False, # enables SSE streaming - stateless=False, # enables session state + stateless=True, ) # Create SSE session manager diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index d549338972c..c327dd130a5 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -854,9 +854,9 @@ class GenerateRequestBase(LiteLLMPydanticObjectBase): allowed_cache_controls: Optional[list] = [] config: Optional[dict] = {} permissions: Optional[dict] = {} - model_max_budget: Optional[dict] = ( - {} - ) # {"gpt-4": 5.0, "gpt-3.5-turbo": 5.0}, defaults to {} + model_max_budget: Optional[ + dict + ] = {} # {"gpt-4": 5.0, "gpt-3.5-turbo": 5.0}, defaults to {} model_config = ConfigDict(protected_namespaces=()) model_rpm_limit: Optional[dict] = None @@ -995,6 +995,9 @@ class RegenerateKeyRequest(GenerateKeyRequest): spend: Optional[float] = None metadata: Optional[dict] = None new_master_key: Optional[str] = None + grace_period: Optional[ + str + ] = None # Duration to keep old key valid (e.g. "24h", "2d"); None = immediate revoke class ResetSpendRequest(LiteLLMPydanticObjectBase): @@ -1406,12 +1409,12 @@ class NewCustomerRequest(BudgetNewRequest): blocked: bool = False # allow/disallow requests for this end-user budget_id: Optional[str] = None # give either a budget_id or max_budget spend: Optional[float] = None - allowed_model_region: Optional[AllowedModelRegion] = ( - None # require all user requests to use models in this specific region - ) - default_model: Optional[str] = ( - None # if no equivalent model in allowed region - default all requests to this model - ) + allowed_model_region: Optional[ + AllowedModelRegion + ] = None # require all user requests to use models in this specific region + default_model: Optional[ + str + ] = None # if no equivalent model in allowed region - default all requests to this model @model_validator(mode="before") @classmethod @@ -1433,12 +1436,12 @@ class UpdateCustomerRequest(LiteLLMPydanticObjectBase): blocked: bool = False # allow/disallow requests for this end-user max_budget: Optional[float] = None budget_id: Optional[str] = None # give either a budget_id or max_budget - allowed_model_region: Optional[AllowedModelRegion] = ( - None # require all user requests to use models in this specific region - ) - default_model: Optional[str] = ( - None # if no equivalent model in allowed region - default all requests to this model - ) + allowed_model_region: Optional[ + AllowedModelRegion + ] = None # require all user requests to use models in this specific region + default_model: Optional[ + str + ] = None # if no equivalent model in allowed region - default all requests to this model class DeleteCustomerRequest(LiteLLMPydanticObjectBase): @@ -1527,15 +1530,15 @@ class NewTeamRequest(TeamBase): ] = None # raise an error if 'guaranteed_throughput' is set and we're overallocating tpm model_tpm_limit: Optional[Dict[str, int]] = None - team_member_budget: Optional[float] = ( - None # allow user to set a budget for all team members - ) - team_member_rpm_limit: Optional[int] = ( - None # allow user to set RPM limit for all team members - ) - team_member_tpm_limit: Optional[int] = ( - None # allow user to set TPM limit for all team members - ) + team_member_budget: Optional[ + float + ] = None # allow user to set a budget for all team members + team_member_rpm_limit: Optional[ + int + ] = None # allow user to set RPM limit for all team members + team_member_tpm_limit: Optional[ + int + ] = None # allow user to set TPM limit for all team members team_member_key_duration: Optional[str] = None # e.g. "1d", "1w", "1m" allowed_vector_store_indexes: Optional[List[AllowedVectorStoreIndexItem]] = None @@ -1627,9 +1630,9 @@ class BlockKeyRequest(LiteLLMPydanticObjectBase): class AddTeamCallback(LiteLLMPydanticObjectBase): callback_name: str - callback_type: Optional[Literal["success", "failure", "success_and_failure"]] = ( - "success_and_failure" - ) + callback_type: Optional[ + Literal["success", "failure", "success_and_failure"] + ] = "success_and_failure" callback_vars: Dict[str, str] @model_validator(mode="before") @@ -1961,9 +1964,9 @@ class ConfigList(LiteLLMPydanticObjectBase): stored_in_db: Optional[bool] field_default_value: Any premium_field: bool = False - nested_fields: Optional[List[FieldDetail]] = ( - None # For nested dictionary or Pydantic fields - ) + nested_fields: Optional[ + List[FieldDetail] + ] = None # For nested dictionary or Pydantic fields class UserHeaderMapping(LiteLLMPydanticObjectBase): @@ -2403,9 +2406,9 @@ class LiteLLM_OrganizationMembershipTable(LiteLLMPydanticObjectBase): budget_id: Optional[str] = None created_at: datetime updated_at: datetime - user: Optional[Any] = ( - None # You might want to replace 'Any' with a more specific type if available - ) + user: Optional[ + Any + ] = None # You might want to replace 'Any' with a more specific type if available litellm_budget_table: Optional[LiteLLM_BudgetTable] = None model_config = ConfigDict(protected_namespaces=()) @@ -3396,9 +3399,9 @@ class TeamModelDeleteRequest(BaseModel): # Organization Member Requests class OrganizationMemberAddRequest(OrgMemberAddRequest): organization_id: str - max_budget_in_organization: Optional[float] = ( - None # Users max budget within the organization - ) + max_budget_in_organization: Optional[ + float + ] = None # Users max budget within the organization class OrganizationMemberDeleteRequest(MemberDeleteRequest): @@ -3616,9 +3619,9 @@ class ProviderBudgetResponse(LiteLLMPydanticObjectBase): Maps provider names to their budget configs. """ - providers: Dict[str, ProviderBudgetResponseObject] = ( - {} - ) # Dictionary mapping provider names to their budget configurations + providers: Dict[ + str, ProviderBudgetResponseObject + ] = {} # Dictionary mapping provider names to their budget configurations class ProxyStateVariables(TypedDict): @@ -3761,9 +3764,9 @@ class LiteLLM_JWTAuth(LiteLLMPydanticObjectBase): enforce_rbac: bool = False roles_jwt_field: Optional[str] = None # v2 on role mappings role_mappings: Optional[List[RoleMapping]] = None - object_id_jwt_field: Optional[str] = ( - None # can be either user / team, inferred from the role mapping - ) + object_id_jwt_field: Optional[ + str + ] = None # can be either user / team, inferred from the role mapping scope_mappings: Optional[List[ScopeMapping]] = None enforce_scope_based_access: bool = False enforce_team_based_model_access: bool = False diff --git a/litellm/proxy/batches_endpoints/endpoints.py b/litellm/proxy/batches_endpoints/endpoints.py index 06800cb4524..143b2607feb 100644 --- a/litellm/proxy/batches_endpoints/endpoints.py +++ b/litellm/proxy/batches_endpoints/endpoints.py @@ -29,6 +29,7 @@ get_models_from_unified_file_id, get_original_file_id, prepare_data_with_credentials, + resolve_input_file_id_to_unified, update_batch_in_database, ) from litellm.proxy.utils import handle_exception_on_proxy, is_known_model @@ -305,7 +306,7 @@ async def create_batch( # noqa: PLR0915 dependencies=[Depends(user_api_key_auth)], tags=["batch"], ) -async def retrieve_batch( +async def retrieve_batch( # noqa: PLR0915 request: Request, fastapi_response: Response, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), @@ -377,6 +378,11 @@ async def retrieve_batch( response = await proxy_logging_obj.post_call_success_hook( data=data, user_api_key_dict=user_api_key_dict, response=response ) + + # async_post_call_success_hook replaces batch.id and output_file_id with unified IDs + # but not input_file_id. Resolve raw provider ID to unified ID. + if unified_batch_id: + await resolve_input_file_id_to_unified(response, prisma_client) asyncio.create_task( proxy_logging_obj.update_request_status( @@ -479,6 +485,11 @@ async def retrieve_batch( data=data, user_api_key_dict=user_api_key_dict, response=response ) + # Fix: bug_feb14_batch_retrieve_returns_raw_input_file_id + # Resolve raw provider input_file_id to unified ID. + if unified_batch_id: + await resolve_input_file_id_to_unified(response, prisma_client) + ### ALERTING ### asyncio.create_task( proxy_logging_obj.update_request_status( diff --git a/litellm/proxy/common_utils/key_rotation_manager.py b/litellm/proxy/common_utils/key_rotation_manager.py index 13bbf2272f7..5a0a1fabc7d 100644 --- a/litellm/proxy/common_utils/key_rotation_manager.py +++ b/litellm/proxy/common_utils/key_rotation_manager.py @@ -8,7 +8,10 @@ from typing import List from litellm._logging import verbose_proxy_logger -from litellm.constants import LITELLM_INTERNAL_JOBS_SERVICE_ACCOUNT_NAME +from litellm.constants import ( + LITELLM_INTERNAL_JOBS_SERVICE_ACCOUNT_NAME, + LITELLM_KEY_ROTATION_GRACE_PERIOD, +) from litellm.proxy._types import ( GenerateKeyResponse, LiteLLM_VerificationToken, @@ -37,6 +40,9 @@ async def process_rotations(self): try: verbose_proxy_logger.info("Starting scheduled key rotation check...") + # Clean up expired deprecated keys first + await self._cleanup_expired_deprecated_keys() + # Find keys that are due for rotation keys_to_rotate = await self._find_keys_needing_rotation() @@ -97,6 +103,24 @@ async def _find_keys_needing_rotation(self) -> List[LiteLLM_VerificationToken]: return keys_with_rotation + async def _cleanup_expired_deprecated_keys(self) -> None: + """ + Remove deprecated key entries whose revoke_at has passed. + """ + try: + now = datetime.now(timezone.utc) + result = await self.prisma_client.db.litellm_deprecatedverificationtoken.delete_many( + where={"revoke_at": {"lt": now}} + ) + if result > 0: + verbose_proxy_logger.debug( + "Cleaned up %s expired deprecated key(s)", result + ) + except Exception as e: + verbose_proxy_logger.debug( + "Deprecated key cleanup skipped (table may not exist): %s", e + ) + def _should_rotate_key(self, key: LiteLLM_VerificationToken, now: datetime) -> bool: """ Determine if a key should be rotated based on key_rotation_at timestamp. @@ -115,10 +139,11 @@ async def _rotate_key(self, key: LiteLLM_VerificationToken): """ Rotate a single key using existing regenerate_key_fn and call the rotation hook """ - # Create regenerate request + # Create regenerate request with grace period for seamless cutover regenerate_request = RegenerateKeyRequest( key=key.token or "", key_alias=key.key_alias, # Pass key alias to ensure correct secret is updated in AWS Secrets Manager + grace_period=LITELLM_KEY_ROTATION_GRACE_PERIOD or None, ) # Create a system user for key rotation diff --git a/litellm/proxy/db/db_spend_update_writer.py b/litellm/proxy/db/db_spend_update_writer.py index dc928921425..9675b82b145 100644 --- a/litellm/proxy/db/db_spend_update_writer.py +++ b/litellm/proxy/db/db_spend_update_writer.py @@ -1725,13 +1725,6 @@ async def add_spend_log_transaction_to_daily_agent_transaction( "prisma_client is None. Skipping writing spend logs to db." ) return - base_daily_transaction = ( - await self._common_add_spend_log_transaction_to_daily_transaction( - payload, prisma_client, "agent" - ) - ) - if base_daily_transaction is None: - return if payload["agent_id"] is None: verbose_proxy_logger.debug( "agent_id is None for request. Skipping incrementing agent spend." diff --git a/litellm/proxy/hooks/batch_rate_limiter.py b/litellm/proxy/hooks/batch_rate_limiter.py index 45b1bd8653f..5bebcc92072 100644 --- a/litellm/proxy/hooks/batch_rate_limiter.py +++ b/litellm/proxy/hooks/batch_rate_limiter.py @@ -259,9 +259,10 @@ async def count_input_file_usage( from litellm.proxy.openai_files_endpoints.common_utils import ( _is_base64_encoded_unified_file_id, ) + # Managed files require bypassing the HTTP endpoint (which runs access-check hooks) + # and calling the managed files hook directly with the user's credentials. is_managed_file = _is_base64_encoded_unified_file_id(file_id) if is_managed_file and user_api_key_dict is not None: - # For managed files, use the managed files hook directly file_content = await self._fetch_managed_file_content( file_id=file_id, user_api_key_dict=user_api_key_dict, diff --git a/litellm/proxy/hooks/proxy_track_cost_callback.py b/litellm/proxy/hooks/proxy_track_cost_callback.py index 37b79e6d065..d903ce0d9d7 100644 --- a/litellm/proxy/hooks/proxy_track_cost_callback.py +++ b/litellm/proxy/hooks/proxy_track_cost_callback.py @@ -202,6 +202,14 @@ async def _PROXY_track_cost_callback( max_budget=end_user_max_budget, ) else: + # Non-model call types (health checks, afile_delete) have no model or standard_logging_object. + # Use .get() for "stream" to avoid KeyError on health checks. + if sl_object is None and not kwargs.get("model"): + verbose_proxy_logger.warning( + "Cost tracking - skipping, no standard_logging_object and no model for call_type=%s", + kwargs.get("call_type", "unknown"), + ) + return if kwargs.get("stream") is not True or ( kwargs.get("stream") is True and "complete_streaming_response" in kwargs ): diff --git a/litellm/proxy/management_endpoints/key_management_endpoints.py b/litellm/proxy/management_endpoints/key_management_endpoints.py index 17cabb69cb2..2b5c17e4745 100644 --- a/litellm/proxy/management_endpoints/key_management_endpoints.py +++ b/litellm/proxy/management_endpoints/key_management_endpoints.py @@ -12,12 +12,14 @@ import asyncio import copy import json +import os import secrets import traceback from datetime import datetime, timedelta, timezone from typing import Any, Dict, List, Literal, Optional, Tuple, cast import fastapi +import prisma import yaml from fastapi import APIRouter, Depends, Header, HTTPException, Query, Request, status @@ -629,7 +631,11 @@ async def _common_key_generation_helper( # noqa: PLR0915 # Validate user-provided key format if data.key is not None and not data.key.startswith("sk-"): - _masked = "{}****{}".format(data.key[:4], data.key[-4:]) if len(data.key) > 8 else "****" + _masked = ( + "{}****{}".format(data.key[:4], data.key[-4:]) + if len(data.key) > 8 + else "****" + ) raise HTTPException( status_code=400, detail={ @@ -1343,6 +1349,7 @@ async def prepare_key_update_data( data_json: dict = data.model_dump(exclude_unset=True) data_json.pop("key", None) data_json.pop("new_key", None) + data_json.pop("grace_period", None) # Request-only param, not a DB column if ( data.metadata is not None and data.metadata.get("service_account_id") is not None @@ -3087,13 +3094,17 @@ async def _rotate_master_key( should_create_model_in_db=False, ) if new_model: - new_models.append(jsonify_object(new_model.model_dump())) + _dumped = new_model.model_dump(exclude_none=True) + _dumped["litellm_params"] = prisma.Json(_dumped["litellm_params"]) + _dumped["model_info"] = prisma.Json(_dumped["model_info"]) + new_models.append(_dumped) verbose_proxy_logger.debug("Resetting proxy model table") - await prisma_client.db.litellm_proxymodeltable.delete_many() - verbose_proxy_logger.debug("Creating %s models", len(new_models)) - await prisma_client.db.litellm_proxymodeltable.create_many( - data=new_models, - ) + async with prisma_client.db.tx() as tx: + await tx.litellm_proxymodeltable.delete_many() + verbose_proxy_logger.debug("Creating %s models", len(new_models)) + await tx.litellm_proxymodeltable.create_many( + data=new_models, + ) # 3. process config table try: config = await prisma_client.db.litellm_config.find_many() @@ -3119,15 +3130,20 @@ async def _rotate_master_key( if encrypted_env_vars: await prisma_client.db.litellm_config.update( where={"param_name": "environment_variables"}, - data={"param_value": jsonify_object(encrypted_env_vars)}, + data={"param_value": prisma.Json(encrypted_env_vars)}, ) # 4. process MCP server table - await rotate_mcp_server_credentials_master_key( - prisma_client=prisma_client, - touched_by=user_api_key_dict.user_id or LITELLM_PROXY_ADMIN_NAME, - new_master_key=new_master_key, - ) + try: + await rotate_mcp_server_credentials_master_key( + prisma_client=prisma_client, + touched_by=user_api_key_dict.user_id or LITELLM_PROXY_ADMIN_NAME, + new_master_key=new_master_key, + ) + except Exception as e: + verbose_proxy_logger.warning( + "Failed to rotate MCP server credentials: %s", str(e) + ) # 5. process credentials table try: @@ -3145,13 +3161,19 @@ async def _rotate_master_key( updated_patch=decrypted_cred, new_encryption_key=new_master_key, ) - credential_object_jsonified = jsonify_object( - encrypted_cred.model_dump() - ) + _cred_data = encrypted_cred.model_dump(exclude_none=True) + if "credential_values" in _cred_data: + _cred_data["credential_values"] = prisma.Json( + _cred_data["credential_values"] + ) + if "credential_info" in _cred_data: + _cred_data["credential_info"] = prisma.Json( + _cred_data["credential_info"] + ) await prisma_client.db.litellm_credentialstable.update( where={"credential_name": cred.credential_name}, data={ - **credential_object_jsonified, + **_cred_data, "updated_by": user_api_key_dict.user_id, }, ) @@ -3181,6 +3203,67 @@ def get_new_token(data: Optional[RegenerateKeyRequest]) -> str: return new_token +async def _insert_deprecated_key( + prisma_client: "PrismaClient", + old_token_hash: str, + new_token_hash: str, + grace_period: Optional[str], +) -> None: + """ + Insert old key into deprecated table so it remains valid during grace period. + + Uses upsert to handle concurrent rotations gracefully. + + Parameters: + prisma_client: DB client + old_token_hash: Hash of the old key being rotated out + new_token_hash: Hash of the new replacement key + grace_period: Duration string (e.g. "24h", "2d") or None/empty for immediate revoke + """ + grace_period_value = grace_period or os.getenv( + "LITELLM_KEY_ROTATION_GRACE_PERIOD", "" + ) + if not grace_period_value: + return + + try: + grace_seconds = duration_in_seconds(grace_period_value) + except ValueError: + verbose_proxy_logger.warning( + "Invalid grace_period format: %s. Expected format like '24h', '2d'.", + grace_period_value, + ) + return + + if grace_seconds <= 0: + return + + try: + revoke_at = datetime.now(timezone.utc) + timedelta(seconds=grace_seconds) + await prisma_client.db.litellm_deprecatedverificationtoken.upsert( + where={"token": old_token_hash}, + data={ + "create": { + "token": old_token_hash, + "active_token_id": new_token_hash, + "revoke_at": revoke_at, + }, + "update": { + "active_token_id": new_token_hash, + "revoke_at": revoke_at, + }, + }, + ) + verbose_proxy_logger.debug( + "Deprecated key retained for %s (revoke_at: %s)", + grace_period_value, + revoke_at, + ) + except Exception as deprecated_err: + verbose_proxy_logger.warning( + "Failed to insert deprecated key for grace period: %s", + deprecated_err, + ) async def _execute_virtual_key_regeneration( *, prisma_client: PrismaClient, @@ -3288,6 +3371,7 @@ async def regenerate_key_fn( # noqa: PLR0915 - permissions: Optional[dict] - Key-specific permissions - guardrails: Optional[List[str]] - List of active guardrails for the key - blocked: Optional[bool] - Whether the key is blocked + - grace_period: Optional[str] - Duration to keep old key valid after rotation (e.g. "24h", "2d"). Omitted = immediate revoke. Env: LITELLM_KEY_ROTATION_GRACE_PERIOD Returns: @@ -3406,6 +3490,58 @@ async def regenerate_key_fn( # noqa: PLR0915 ) verbose_proxy_logger.debug("key_in_db: %s", _key_in_db) + new_token = get_new_token(data=data) + + new_token_hash = hash_token(new_token) + new_token_key_name = f"sk-...{new_token[-4:]}" + + # Prepare the update data + update_data = { + "token": new_token_hash, + "key_name": new_token_key_name, + } + + non_default_values = {} + if data is not None: + # Update with any provided parameters from GenerateKeyRequest + non_default_values = await prepare_key_update_data( + data=data, existing_key_row=_key_in_db + ) + verbose_proxy_logger.debug("non_default_values: %s", non_default_values) + + update_data.update(non_default_values) + update_data = prisma_client.jsonify_object(data=update_data) + + # If grace period set, insert deprecated key so old key remains valid + await _insert_deprecated_key( + prisma_client=prisma_client, + old_token_hash=hashed_api_key, + new_token_hash=new_token_hash, + grace_period=data.grace_period if data else None, + ) + + # Update the token in the database + updated_token = await prisma_client.db.litellm_verificationtoken.update( + where={"token": hashed_api_key}, + data=update_data, # type: ignore + ) + + updated_token_dict = {} + if updated_token is not None: + updated_token_dict = dict(updated_token) + + updated_token_dict["key"] = new_token + updated_token_dict["token_id"] = updated_token_dict.pop("token") + + ### 3. remove existing key entry from cache + ###################################################################### + + if hashed_api_key or key: + await _delete_cache_key_object( + hashed_token=hash_token(key), + user_api_key_cache=user_api_key_cache, + proxy_logging_obj=proxy_logging_obj, + ) # Normalize litellm_changed_by: if it's a Header object or not a string, convert to None if litellm_changed_by is not None and not isinstance(litellm_changed_by, str): litellm_changed_by = None diff --git a/litellm/proxy/management_endpoints/ui_sso.py b/litellm/proxy/management_endpoints/ui_sso.py index 7274b389a92..a57dafd1f00 100644 --- a/litellm/proxy/management_endpoints/ui_sso.py +++ b/litellm/proxy/management_endpoints/ui_sso.py @@ -14,7 +14,7 @@ import os import secrets from copy import deepcopy -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast +from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, Union, cast from fastapi import APIRouter, Depends, HTTPException, Request, status from fastapi.responses import RedirectResponse @@ -82,7 +82,15 @@ get_server_root_path, ) from litellm.secret_managers.main import get_secret_bool, str_to_bool -from litellm.types.proxy.management_endpoints.ui_sso import * +from litellm.types.proxy.management_endpoints.ui_sso import ( + DefaultTeamSSOParams, + MicrosoftGraphAPIUserGroupDirectoryObject, + MicrosoftGraphAPIUserGroupResponse, + MicrosoftServicePrincipalTeam, + RoleMappings, + TeamMappings, +) +from litellm.types.proxy.management_endpoints.ui_sso import * # noqa: F403, F401 from litellm.types.proxy.ui_sso import ParsedOpenIDResult if TYPE_CHECKING: @@ -96,15 +104,15 @@ def normalize_email(email: Optional[str]) -> Optional[str]: """ Normalize email address to lowercase for consistent storage and comparison. - + Email addresses should be treated as case-insensitive for SSO purposes, even though RFC 5321 technically allows case-sensitive local parts. This prevents issues where SSO providers return emails with different casing than what's stored in the database. - + Args: email: Email address to normalize, can be None - + Returns: Lowercased email address, or None if input is None """ @@ -336,7 +344,7 @@ async def google_login( # check if user defined a custom auth sso sign in handler, if yes, use it if user_custom_ui_sso_sign_in_handler is not None: try: - from litellm_enterprise.proxy.auth.custom_sso_handler import ( + from litellm_enterprise.proxy.auth.custom_sso_handler import ( # type: ignore[import-untyped] EnterpriseCustomSSOHandler, ) @@ -494,7 +502,9 @@ def generic_response_convertor( display_name=get_nested_value( response, generic_user_display_name_attribute_name ), - email=normalize_email(get_nested_value(response, generic_user_email_attribute_name)), + email=normalize_email( + get_nested_value(response, generic_user_email_attribute_name) + ), first_name=get_nested_value(response, generic_user_first_name_attribute_name), last_name=get_nested_value(response, generic_user_last_name_attribute_name), provider=get_nested_value(response, generic_provider_attribute_name), @@ -584,6 +594,7 @@ async def _setup_team_mappings() -> Optional["TeamMappings"]: if team_mappings_data: from litellm.types.proxy.management_endpoints.ui_sso import TeamMappings + if isinstance(team_mappings_data, dict): team_mappings = TeamMappings(**team_mappings_data) elif isinstance(team_mappings_data, TeamMappings): @@ -621,6 +632,7 @@ async def _setup_role_mappings() -> Optional["RoleMappings"]: if role_mappings_data: from litellm.types.proxy.management_endpoints.ui_sso import RoleMappings + if isinstance(role_mappings_data, dict): role_mappings = RoleMappings(**role_mappings_data) elif isinstance(role_mappings_data, RoleMappings): @@ -634,7 +646,7 @@ async def _setup_role_mappings() -> Optional["RoleMappings"]: verbose_proxy_logger.debug( f"Could not load role_mappings from database: {e}. Continuing with existing role logic." ) - + generic_role_mappings = os.getenv("GENERIC_ROLE_MAPPINGS_ROLES", None) generic_role_mappings_group_claim = os.getenv( "GENERIC_ROLE_MAPPINGS_GROUP_CLAIM", None @@ -644,8 +656,8 @@ async def _setup_role_mappings() -> Optional["RoleMappings"]: ) if generic_role_mappings is not None: verbose_proxy_logger.debug( - "Found role_mappings for generic provider in environment variables" - ) + "Found role_mappings for generic provider in environment variables" + ) import ast try: @@ -670,7 +682,9 @@ async def _setup_role_mappings() -> Optional["RoleMappings"]: ) return role_mappings except TypeError as e: - verbose_proxy_logger.warning(f"Error decoding role mappings from environment variables: {e}. Continuing with existing role logic.") + verbose_proxy_logger.warning( + f"Error decoding role mappings from environment variables: {e}. Continuing with existing role logic." + ) return role_mappings @@ -747,7 +761,7 @@ def response_convertor(response, client): try: result = await generic_sso.verify_and_process( request, - params=SSOAuthenticationHandler.prepare_token_exchange_parameters( + params=await SSOAuthenticationHandler.prepare_token_exchange_parameters( request=request, generic_include_client_id=generic_include_client_id, ), @@ -942,7 +956,7 @@ def _build_sso_user_update_data( Returns: dict: Update data containing user_email and optionally user_role if valid - """ + """ update_data: dict = {"user_email": normalize_email(user_email)} # Get SSO role from result and include if valid @@ -1740,7 +1754,7 @@ async def get_generic_sso_redirect_response( """ from urllib.parse import parse_qs, urlencode, urlparse, urlunparse - from litellm.proxy.proxy_server import user_api_key_cache + from litellm.proxy.proxy_server import redis_usage_cache, user_api_key_cache with generic_sso: # TODO: state should be a random string and added to the user session with cookie @@ -1769,13 +1783,21 @@ async def get_generic_sso_redirect_response( # If PKCE is enabled, add PKCE parameters to the redirect URL if code_verifier and "state" in redirect_params: - # Store code_verifier in cache (10 min TTL) + # Store code_verifier in cache (10 min TTL). Use Redis when available + # so callbacks landing on another pod can retrieve it (multi-pod SSO). cache_key = f"pkce_verifier:{redirect_params['state']}" - user_api_key_cache.set_cache( - key=cache_key, - value=code_verifier, - ttl=600, - ) + if redis_usage_cache is not None: + await redis_usage_cache.async_set_cache( + key=cache_key, + value=code_verifier, + ttl=600, + ) + else: + await user_api_key_cache.async_set_cache( + key=cache_key, + value=code_verifier, + ttl=600, + ) # Add PKCE parameters to the authorization URL if pkce_params: @@ -2372,7 +2394,7 @@ async def get_redirect_response_from_openid( # noqa: PLR0915 return redirect_response @staticmethod - def prepare_token_exchange_parameters( + async def prepare_token_exchange_parameters( request: Request, generic_include_client_id: bool, ) -> dict: @@ -2386,27 +2408,38 @@ def prepare_token_exchange_parameters( Returns: dict: Token exchange parameters """ - # Prepare token exchange parameters - token_params = {"include_client_id": generic_include_client_id} + # Prepare token exchange parameters (may add code_verifier: str later) + token_params: Dict[str, Any] = {"include_client_id": generic_include_client_id} - # Retrieve PKCE code_verifier if PKCE was used in authorization + # Retrieve PKCE code_verifier if PKCE was used in authorization. + # Use same cache as store: Redis when available (multi-pod), else in-memory. query_params = dict(request.query_params) state = query_params.get("state") if state: - from litellm.proxy.proxy_server import user_api_key_cache + from litellm.proxy.proxy_server import redis_usage_cache, user_api_key_cache cache_key = f"pkce_verifier:{state}" - code_verifier = user_api_key_cache.get_cache(key=cache_key) + if redis_usage_cache is not None: + code_verifier = await redis_usage_cache.async_get_cache(key=cache_key) + else: + code_verifier = await user_api_key_cache.async_get_cache(key=cache_key) if code_verifier: - # Add code_verifier to token exchange parameters - token_params["code_verifier"] = code_verifier + # Add code_verifier to token exchange parameters (Redis returns decoded string) + token_params["code_verifier"] = ( + code_verifier + if isinstance(code_verifier, str) + else str(code_verifier) + ) verbose_proxy_logger.debug( "PKCE code_verifier retrieved and will be included in token exchange" ) # Clean up the cache entry (single-use verifier) - user_api_key_cache.delete_cache(key=cache_key) + if redis_usage_cache is not None: + await redis_usage_cache.async_delete_cache(key=cache_key) + else: + await user_api_key_cache.async_delete_cache(key=cache_key) return token_params @staticmethod @@ -2549,7 +2582,9 @@ def openid_from_response( response = response or {} verbose_proxy_logger.debug(f"Microsoft SSO Callback Response: {response}") openid_response = CustomOpenID( - email=normalize_email(response.get(MICROSOFT_USER_EMAIL_ATTRIBUTE) or response.get("mail")), + email=normalize_email( + response.get(MICROSOFT_USER_EMAIL_ATTRIBUTE) or response.get("mail") + ), display_name=response.get(MICROSOFT_USER_DISPLAY_NAME_ATTRIBUTE), provider="microsoft", id=response.get(MICROSOFT_USER_ID_ATTRIBUTE), diff --git a/litellm/proxy/openai_files_endpoints/common_utils.py b/litellm/proxy/openai_files_endpoints/common_utils.py index f67dc5e2aaa..75f64cddf59 100644 --- a/litellm/proxy/openai_files_endpoints/common_utils.py +++ b/litellm/proxy/openai_files_endpoints/common_utils.py @@ -644,6 +644,28 @@ def _extract_model_param(request: "Request", request_body: dict) -> Optional[str # ============================================================================ +async def resolve_input_file_id_to_unified(response, prisma_client) -> None: + """ + If the batch response contains a raw provider input_file_id (not already a + unified ID), look up the corresponding unified file ID from the managed file + table and replace it in-place. + """ + if ( + hasattr(response, "input_file_id") + and response.input_file_id + and not _is_base64_encoded_unified_file_id(response.input_file_id) + and prisma_client + ): + try: + managed_file = await prisma_client.db.litellm_managedfiletable.find_first( + where={"flat_model_file_ids": {"has": response.input_file_id}} + ) + if managed_file: + response.input_file_id = managed_file.unified_file_id + except Exception: + pass + + async def get_batch_from_database( batch_id: str, unified_batch_id: Union[str, Literal[False]], @@ -687,6 +709,9 @@ async def get_batch_from_database( batch_data = json.loads(db_batch_object.file_object) if isinstance(db_batch_object.file_object, str) else db_batch_object.file_object response = LiteLLMBatch(**batch_data) response.id = batch_id + + # The stored batch object has the raw provider input_file_id. Resolve to unified ID. + await resolve_input_file_id_to_unified(response, prisma_client) verbose_proxy_logger.debug( f"Retrieved batch {batch_id} from ManagedObjectTable with status={response.status}" diff --git a/litellm/proxy/proxy_cli.py b/litellm/proxy/proxy_cli.py index 2509a80b140..e91447af895 100644 --- a/litellm/proxy/proxy_cli.py +++ b/litellm/proxy/proxy_cli.py @@ -37,11 +37,16 @@ class LiteLLMDatabaseConnectionPool(Enum): database_connection_pool_timeout = 60 -def append_query_params(url, params) -> str: +def append_query_params(url: Optional[str], params: dict) -> str: from litellm._logging import verbose_proxy_logger verbose_proxy_logger.debug(f"url: {url}") verbose_proxy_logger.debug(f"params: {params}") + if not isinstance(url, str) or url == "": + # Preserve previous startup behavior when DATABASE_URL is absent. + # Returning an empty string avoids urlparse type errors in test/dev flows. + verbose_proxy_logger.warning("append_query_params received empty or non-string URL, returning empty string") + return "" parsed_url = urlparse.urlparse(url) parsed_query = urlparse.parse_qs(parsed_url.query) parsed_query.update(params) diff --git a/litellm/proxy/schema.prisma b/litellm/proxy/schema.prisma index c2fca8705cb..441c2cdf70d 100644 --- a/litellm/proxy/schema.prisma +++ b/litellm/proxy/schema.prisma @@ -325,6 +325,19 @@ model LiteLLM_VerificationToken { @@index([budget_reset_at, expires]) } +// Deprecated keys during grace period - allows old key to work until revoke_at +model LiteLLM_DeprecatedVerificationToken { + id String @id @default(uuid()) + token String // Hashed old key + active_token_id String // Current token hash in LiteLLM_VerificationToken + revoke_at DateTime // When the old key stops working + created_at DateTime @default(now()) @map("created_at") + + @@unique([token]) + @@index([token, revoke_at]) + @@index([revoke_at]) +} + // Audit table for deleted keys - preserves spend and key information for historical tracking model LiteLLM_DeletedVerificationToken { id String @id @default(uuid()) diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index 6c8db8d7d99..1c2a6d378fc 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -7,7 +7,7 @@ import threading import time import traceback -from datetime import date, datetime, timedelta +from datetime import date, datetime, timedelta, timezone from email.mime.multipart import MIMEMultipart from email.mime.text import MIMEText from typing import ( @@ -76,6 +76,7 @@ from litellm._logging import verbose_proxy_logger from litellm._service_logger import ServiceLogging, ServiceTypes from litellm.caching.caching import DualCache, RedisCache +from litellm.caching.dual_cache import LimitedSizeOrderedDict from litellm.exceptions import RejectedRequestError from litellm.integrations.custom_guardrail import ( CustomGuardrail, @@ -2154,6 +2155,58 @@ def jsonify_object(data: dict) -> dict: return db_data +# In-memory cache for deprecated key lookups: maps old_token_hash -> (active_token_id, expires_at_ts) +# Avoids a DB query on every auth request for non-deprecated keys. +# Bounded to prevent memory leaks from accumulated rotations. +_deprecated_key_cache: LimitedSizeOrderedDict = LimitedSizeOrderedDict(max_size=1000) +_DEPRECATED_KEY_CACHE_TTL_SECONDS = 60 + + +async def _lookup_deprecated_key( + db: Any, + hashed_token: str, +) -> Optional[str]: + """ + Check if a token exists in the deprecated keys table and is still within its grace period. + + Returns the active_token_id if found and valid, otherwise None. + Uses an in-memory cache to avoid DB queries on every auth request. + """ + now = datetime.now(timezone.utc) + now_ts = now.timestamp() + + # Check cache first + cached = _deprecated_key_cache.get(hashed_token) + cached = _deprecated_key_cache.get(hashed_token) + if cached is not None: + active_token_id, cache_expires_at_ts, revoke_at_ts = cached + if now_ts < cache_expires_at_ts and now_ts < revoke_at_ts: + return active_token_id + else: + _deprecated_key_cache.pop(hashed_token, None) + + try: + deprecated_row = await db.litellm_deprecatedverificationtoken.find_first( + where={ + "token": hashed_token, + "revoke_at": {"gt": now}, + }, + select={"active_token_id": True}, + ) + if deprecated_row and deprecated_row.active_token_id: + _deprecated_key_cache[hashed_token] = ( + deprecated_row.active_token_id, + now_ts + _DEPRECATED_KEY_CACHE_TTL_SECONDS, + ) + return deprecated_row.active_token_id + # Only cache positive results; negative lookups are fast on indexed columns + # and caching them risks evicting real deprecated key entries. + except Exception as e: + verbose_proxy_logger.debug("Deprecated key lookup skipped: %s", e) + + return None + + class PrismaClient: spend_log_transactions: List = [] _spend_log_transactions_lock = asyncio.Lock() @@ -2489,6 +2542,7 @@ async def get_data( # noqa: PLR0915 parent_otel_span: Optional[Span] = None, proxy_logging_obj: Optional[ProxyLogging] = None, budget_id_list: Optional[List[str]] = None, + check_deprecated: bool = True, ): args_passed_in = locals() start_time = time.time() @@ -2786,6 +2840,30 @@ async def get_data( # noqa: PLR0915 sql_query ) + # If not found in main table, check deprecated keys (grace period) + # check_deprecated=False on the recursive call prevents unbounded chaining + if ( + response is None + and hashed_token is not None + and check_deprecated + ): + active_token_id = await _lookup_deprecated_key( + db=self.db, hashed_token=hashed_token + ) + if active_token_id: + response = await self.get_data( + token=active_token_id, + table_name="combined_view", + query_type="find_unique", + parent_otel_span=parent_otel_span, + proxy_logging_obj=proxy_logging_obj, + check_deprecated=False, + ) + if response is not None: + verbose_proxy_logger.debug( + "Deprecated key used during grace period" + ) + if response is not None: if response["team_models"] is None: response["team_models"] = [] diff --git a/model_prices_and_context_window.json b/model_prices_and_context_window.json index 95d8ba2ff60..2b6d2800124 100644 --- a/model_prices_and_context_window.json +++ b/model_prices_and_context_window.json @@ -12456,6 +12456,19 @@ "supports_tool_choice": true, "supports_web_search": true }, + "fireworks_ai/accounts/fireworks/models/kimi-k2p5": { + "input_cost_per_token": 6e-07, + "litellm_provider": "fireworks_ai", + "max_input_tokens": 262144, + "max_output_tokens": 262144, + "max_tokens": 262144, + "mode": "chat", + "output_cost_per_token": 3e-06, + "source": "https://fireworks.ai/pricing", + "supports_function_calling": true, + "supports_response_schema": true, + "supports_tool_choice": true + }, "fireworks_ai/accounts/fireworks/models/llama-v3p1-405b-instruct": { "input_cost_per_token": 3e-06, "litellm_provider": "fireworks_ai", @@ -23759,7 +23772,7 @@ "max_output_tokens": 131072, "max_tokens": 131072, "mode": "chat", - "output_cost_per_token": 1.5e-07, + "output_cost_per_token": 1.5e-05, "source": "https://www.oracle.com/artificial-intelligence/generative-ai/generative-ai-service/pricing", "supports_function_calling": true, "supports_response_schema": false @@ -23807,7 +23820,7 @@ "max_output_tokens": 128000, "max_tokens": 128000, "mode": "chat", - "output_cost_per_token": 1.5e-07, + "output_cost_per_token": 1.5e-05, "source": "https://www.oracle.com/artificial-intelligence/generative-ai/generative-ai-service/pricing", "supports_function_calling": true, "supports_response_schema": false diff --git a/poetry.lock b/poetry.lock index 4d44b36aa26..d04e20eb0fe 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 2.1.4 and should not be changed by hand. +# This file is automatically @generated by Poetry 2.3.2 and should not be changed by hand. [[package]] name = "a2a-sdk" @@ -7,11 +7,11 @@ description = "A2A Python SDK" optional = false python-versions = ">=3.10" groups = ["main", "proxy-dev"] -markers = "python_version >= \"3.10\"" files = [ {file = "a2a_sdk-0.3.22-py3-none-any.whl", hash = "sha256:b98701135bb90b0ff85d35f31533b6b7a299bf810658c1c65f3814a6c15ea385"}, {file = "a2a_sdk-0.3.22.tar.gz", hash = "sha256:77a5694bfc4f26679c11b70c7f1062522206d430b34bc1215cfbb1eba67b7e7d"}, ] +markers = {main = "python_version >= \"3.10\" and extra == \"extra-proxy\"", proxy-dev = "python_version >= \"3.10\""} [package.dependencies] google-api-core = ">=1.26.0" @@ -385,6 +385,7 @@ files = [ {file = "azure_core-1.36.0-py3-none-any.whl", hash = "sha256:fee9923a3a753e94a259563429f3644aaf05c486d45b1215d098115102d91d3b"}, {file = "azure_core-1.36.0.tar.gz", hash = "sha256:22e5605e6d0bf1d229726af56d9e92bc37b6e726b141a18be0b4d424131741b7"}, ] +markers = {main = "extra == \"proxy\" or extra == \"extra-proxy\""} [package.dependencies] requests = ">=2.21.0" @@ -405,6 +406,7 @@ files = [ {file = "azure_identity-1.25.1-py3-none-any.whl", hash = "sha256:e9edd720af03dff020223cd269fa3a61e8f345ea75443858273bcb44844ab651"}, {file = "azure_identity-1.25.1.tar.gz", hash = "sha256:87ca8328883de6036443e1c37b40e8dc8fb74898240f61071e09d2e369361456"}, ] +markers = {main = "extra == \"proxy\" or extra == \"extra-proxy\""} [package.dependencies] azure-core = ">=1.31.0" @@ -598,7 +600,7 @@ files = [ {file = "cachetools-6.2.2-py3-none-any.whl", hash = "sha256:6c09c98183bf58560c97b2abfcedcbaf6a896a490f534b031b661d3723b45ace"}, {file = "cachetools-6.2.2.tar.gz", hash = "sha256:8e6d266b25e539df852251cfd6f990b4bc3a141db73b939058d809ebd2590fc6"}, ] -markers = {main = "extra == \"google\" or extra == \"extra-proxy\" or python_version >= \"3.10\"", proxy-dev = "python_version >= \"3.10\""} +markers = {main = "python_version >= \"3.10\" and (extra == \"extra-proxy\" or extra == \"google\" or extra == \"mlflow\") or extra == \"google\" or extra == \"extra-proxy\"", proxy-dev = "python_version >= \"3.10\""} [[package]] name = "certifi" @@ -705,7 +707,7 @@ files = [ {file = "cffi-2.0.0-cp39-cp39-win_amd64.whl", hash = "sha256:b882b3df248017dba09d6b16defe9b5c407fe32fc7c65a9c69798e6175601be9"}, {file = "cffi-2.0.0.tar.gz", hash = "sha256:44d1b5909021139fe36001ae048dbdde8214afa20200eda0f64c068cac5d5529"}, ] -markers = {main = "platform_python_implementation != \"PyPy\" or extra == \"proxy\"", dev = "platform_python_implementation != \"PyPy\"", proxy-dev = "platform_python_implementation != \"PyPy\""} +markers = {main = "(platform_python_implementation != \"PyPy\" or extra == \"proxy\") and (python_version >= \"3.10\" or extra == \"proxy\" or extra == \"extra-proxy\") and (extra == \"proxy\" or extra == \"extra-proxy\" or extra == \"mlflow\")", dev = "platform_python_implementation != \"PyPy\"", proxy-dev = "platform_python_implementation != \"PyPy\""} [package.dependencies] pycparser = {version = "*", markers = "implementation_name != \"PyPy\""} @@ -1055,6 +1057,7 @@ files = [ {file = "cryptography-43.0.3-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:2ce6fae5bdad59577b44e4dfed356944fbf1d925269114c28be377692643b4ff"}, {file = "cryptography-43.0.3.tar.gz", hash = "sha256:315b9001266a492a6ff443b61238f956b214dbec9910a081ba5b6646a055a805"}, ] +markers = {main = "python_version >= \"3.10\" and (extra == \"proxy\" or extra == \"extra-proxy\" or extra == \"mlflow\") or extra == \"proxy\" or extra == \"extra-proxy\""} [package.dependencies] cffi = {version = ">=1.12", markers = "platform_python_implementation != \"PyPy\""} @@ -1822,11 +1825,11 @@ description = "Google API client core library" optional = false python-versions = ">=3.7" groups = ["main", "proxy-dev"] -markers = "python_version >= \"3.14\"" files = [ {file = "google_api_core-2.25.2-py3-none-any.whl", hash = "sha256:e9a8f62d363dc8424a8497f4c2a47d6bcda6c16514c935629c257ab5d10210e7"}, {file = "google_api_core-2.25.2.tar.gz", hash = "sha256:1c63aa6af0d0d5e37966f157a77f9396d820fba59f9e43e9415bc3dc5baff300"}, ] +markers = {main = "python_version >= \"3.14\" and (extra == \"extra-proxy\" or extra == \"google\")", proxy-dev = "python_version >= \"3.14\""} [package.dependencies] google-auth = ">=2.14.1,<3.0.0" @@ -1854,7 +1857,7 @@ files = [ {file = "google_api_core-2.28.1-py3-none-any.whl", hash = "sha256:4021b0f8ceb77a6fb4de6fde4502cecab45062e66ff4f2895169e0b35bc9466c"}, {file = "google_api_core-2.28.1.tar.gz", hash = "sha256:2b405df02d68e68ce0fbc138559e6036559e685159d148ae5861013dc201baf8"}, ] -markers = {main = "(python_version >= \"3.10\" or extra == \"google\" or extra == \"extra-proxy\") and python_version < \"3.14\"", proxy-dev = "python_version >= \"3.10\" and python_version < \"3.14\""} +markers = {main = "python_version < \"3.14\" and (extra == \"extra-proxy\" or extra == \"google\")", proxy-dev = "python_version >= \"3.10\" and python_version < \"3.14\""} [package.dependencies] google-auth = ">=2.14.1,<3.0.0" @@ -1891,7 +1894,7 @@ files = [ {file = "google_auth-2.43.0-py2.py3-none-any.whl", hash = "sha256:af628ba6fa493f75c7e9dbe9373d148ca9f4399b5ea29976519e0a3848eddd16"}, {file = "google_auth-2.43.0.tar.gz", hash = "sha256:88228eee5fc21b62a1b5fe773ca15e67778cb07dc8363adcb4a8827b52d81483"}, ] -markers = {main = "extra == \"google\" or extra == \"extra-proxy\" or python_version >= \"3.10\"", proxy-dev = "python_version >= \"3.10\""} +markers = {main = "python_version >= \"3.10\" and (extra == \"extra-proxy\" or extra == \"google\" or extra == \"mlflow\") or extra == \"google\" or extra == \"extra-proxy\"", proxy-dev = "python_version >= \"3.10\""} [package.dependencies] cachetools = ">=2.0.0,<7.0" @@ -2063,11 +2066,11 @@ files = [ ] [package.dependencies] -google-api-core = {version = ">=1.34.1,<2.0.dev0 || >=2.11.dev0,<3.0.0dev", extras = ["grpc"]} -google-auth = ">=2.14.1,<2.24.0 || >2.24.0,<2.25.0 || >2.25.0,<3.0.0dev" -grpc-google-iam-v1 = ">=0.12.4,<1.0.0dev" -proto-plus = ">=1.22.3,<2.0.0dev" -protobuf = ">=3.20.2,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<6.0.0dev" +google-api-core = {version = ">=1.34.1,<2.0.dev0 || >=2.11.dev0,<3.0.0.dev0", extras = ["grpc"]} +google-auth = ">=2.14.1,<2.24.0 || >2.24.0,<2.25.0 || >2.25.0,<3.0.0.dev0" +grpc-google-iam-v1 = ">=0.12.4,<1.0.0.dev0" +proto-plus = ">=1.22.3,<2.0.0.dev0" +protobuf = ">=3.20.2,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<6.0.0.dev0" [[package]] name = "google-cloud-resource-manager" @@ -2249,7 +2252,7 @@ files = [ {file = "googleapis_common_protos-1.72.0-py3-none-any.whl", hash = "sha256:4299c5a82d5ae1a9702ada957347726b167f9f8d1fc352477702a1e851ff4038"}, {file = "googleapis_common_protos-1.72.0.tar.gz", hash = "sha256:e55a601c1b32b52d7a3e65f43563e2aa61bcd737998ee672ac9b951cd49319f5"}, ] -markers = {main = "extra == \"google\" or extra == \"extra-proxy\" or python_version >= \"3.10\""} +markers = {main = "python_version >= \"3.10\" and (extra == \"extra-proxy\" or extra == \"google\") or extra == \"google\" or extra == \"extra-proxy\""} [package.dependencies] grpcio = {version = ">=1.44.0,<2.0.0", optional = true, markers = "extra == \"grpc\""} @@ -2658,11 +2661,11 @@ description = "Consume Server-Sent Event (SSE) messages with HTTPX." optional = false python-versions = ">=3.9" groups = ["main", "proxy-dev"] -markers = "python_version >= \"3.10\"" files = [ {file = "httpx_sse-0.4.3-py3-none-any.whl", hash = "sha256:0ac1c9fe3c0afad2e0ebb25a934a59f4c7823b60792691f779fad2c5568830fc"}, {file = "httpx_sse-0.4.3.tar.gz", hash = "sha256:9b1ed0127459a66014aec3c56bebd93da3c1bc8bb6618c8082039a44889a755d"}, ] +markers = {main = "python_version >= \"3.10\" and (extra == \"proxy\" or extra == \"extra-proxy\")", proxy-dev = "python_version >= \"3.10\""} [[package]] name = "huey" @@ -3027,7 +3030,7 @@ files = [ [package.dependencies] attrs = ">=22.2.0" -jsonschema-specifications = ">=2023.03.6" +jsonschema-specifications = ">=2023.3.6" referencing = ">=0.28.4" rpds-py = ">=0.7.1" @@ -3500,6 +3503,38 @@ files = [ {file = "mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba"}, ] +[[package]] +name = "mirakuru" +version = "2.6.1" +description = "Process executor (not only) for tests." +optional = false +python-versions = ">=3.9" +groups = ["dev"] +markers = "python_version == \"3.9\"" +files = [ + {file = "mirakuru-2.6.1-py3-none-any.whl", hash = "sha256:4be0bfd270744454fa0c0466b8127b66bd55f4decaf05bbee9b071f2acbd9473"}, + {file = "mirakuru-2.6.1.tar.gz", hash = "sha256:95d4f5a5ad406a625e9ca418f20f8e09386a35dad1ea30fd9073e0ae93f712c7"}, +] + +[package.dependencies] +psutil = {version = ">=4.0.0", markers = "sys_platform != \"cygwin\""} + +[[package]] +name = "mirakuru" +version = "3.0.2" +description = "Process executor (not only) for tests." +optional = false +python-versions = ">=3.10" +groups = ["dev"] +markers = "python_version >= \"3.10\"" +files = [ + {file = "mirakuru-3.0.2-py3-none-any.whl", hash = "sha256:10e5dac4a8f26872c63e9cdfdc01b775aaa2beb3ced98abc497279d2dc525b8f"}, + {file = "mirakuru-3.0.2.tar.gz", hash = "sha256:21192186a8680ea7567ca68170261df3785768b12962dd19fe8cccab15ad3441"}, +] + +[package.dependencies] +psutil = {version = ">=4.0.0", markers = "sys_platform != \"cygwin\""} + [[package]] name = "ml-dtypes" version = "0.4.1" @@ -3666,6 +3701,7 @@ files = [ {file = "msal-1.34.0-py3-none-any.whl", hash = "sha256:f669b1644e4950115da7a176441b0e13ec2975c29528d8b9e81316023676d6e1"}, {file = "msal-1.34.0.tar.gz", hash = "sha256:76ba83b716ea5a6d75b0279c0ac353a0e05b820ca1f6682c0eb7f45190c43c2f"}, ] +markers = {main = "extra == \"proxy\" or extra == \"extra-proxy\""} [package.dependencies] cryptography = ">=2.5,<49" @@ -3686,6 +3722,7 @@ files = [ {file = "msal_extensions-1.3.1-py3-none-any.whl", hash = "sha256:96d3de4d034504e969ac5e85bae8106c8373b5c6568e4c8fa7af2eca9dbe6bca"}, {file = "msal_extensions-1.3.1.tar.gz", hash = "sha256:c5b0fd10f65ef62b5f1d62f4251d51cbcaf003fcedae8c91b040a488614be1a4"}, ] +markers = {main = "extra == \"proxy\" or extra == \"extra-proxy\""} [package.dependencies] msal = ">=1.29,<2" @@ -3936,6 +3973,7 @@ files = [ {file = "nodeenv-1.9.1-py2.py3-none-any.whl", hash = "sha256:ba11c9782d29c27c70ffbdda2d7415098754709be8a7056d79a737cd901155c9"}, {file = "nodeenv-1.9.1.tar.gz", hash = "sha256:6ec12890a2dab7946721edbfbcd91f3319c6ccc9aec47be7c7e6b7011ee6645f"}, ] +markers = {main = "extra == \"extra-proxy\""} [[package]] name = "numpy" @@ -4058,7 +4096,7 @@ files = [ {file = "opentelemetry_api-1.39.1-py3-none-any.whl", hash = "sha256:2edd8463432a7f8443edce90972169b195e7d6a05500cd29e6d13898187c9950"}, {file = "opentelemetry_api-1.39.1.tar.gz", hash = "sha256:fbde8c80e1b937a2c61f20347e91c0c18a1940cecf012d62e65a7caf08967c9c"}, ] -markers = {main = "python_version >= \"3.10\""} +markers = {main = "python_version >= \"3.10\" and extra == \"mlflow\""} [package.dependencies] importlib-metadata = ">=6.0,<8.8.0" @@ -4173,7 +4211,7 @@ files = [ {file = "opentelemetry_sdk-1.39.1-py3-none-any.whl", hash = "sha256:4d5482c478513ecb0a5d938dcc61394e647066e0cc2676bee9f3af3f3f45f01c"}, {file = "opentelemetry_sdk-1.39.1.tar.gz", hash = "sha256:cf4d4563caf7bff906c9f7967e2be22d0d6b349b908be0d90fb21c8e9c995cc6"}, ] -markers = {main = "python_version >= \"3.10\""} +markers = {main = "python_version >= \"3.10\" and extra == \"mlflow\""} [package.dependencies] opentelemetry-api = "1.39.1" @@ -4191,7 +4229,7 @@ files = [ {file = "opentelemetry_semantic_conventions-0.60b1-py3-none-any.whl", hash = "sha256:9fa8c8b0c110da289809292b0591220d3a7b53c1526a23021e977d68597893fb"}, {file = "opentelemetry_semantic_conventions-0.60b1.tar.gz", hash = "sha256:87c228b5a0669b748c76d76df6c364c369c28f1c465e50f661e39737e84bc953"}, ] -markers = {main = "python_version >= \"3.10\""} +markers = {main = "python_version >= \"3.10\" and extra == \"mlflow\""} [package.dependencies] opentelemetry-api = "1.39.1" @@ -4626,6 +4664,32 @@ files = [ {file = "polars_runtime_32-1.35.2.tar.gz", hash = "sha256:6e6e35733ec52abe54b7d30d245e6586b027d433315d20edfb4a5d162c79fe90"}, ] +[[package]] +name = "port-for" +version = "0.7.4" +description = "Utility that helps with local TCP ports management. It can find an unused TCP localhost port and remember the association." +optional = false +python-versions = ">=3.9" +groups = ["dev"] +markers = "python_version == \"3.9\"" +files = [ + {file = "port_for-0.7.4-py3-none-any.whl", hash = "sha256:08404aa072651a53dcefe8d7a598ee8a1dca320d9ac44ac464da16ccf2a02c4a"}, + {file = "port_for-0.7.4.tar.gz", hash = "sha256:fc7713e7b22f89442f335ce12536653656e8f35146739eccaeff43d28436028d"}, +] + +[[package]] +name = "port-for" +version = "1.0.0" +description = "Utility that helps with local TCP ports management. It can find an unused TCP localhost port and remember the association." +optional = false +python-versions = ">=3.10" +groups = ["dev"] +markers = "python_version >= \"3.10\"" +files = [ + {file = "port_for-1.0.0-py3-none-any.whl", hash = "sha256:35a848b98cf4cc075fe80dc49ae5c3a78e3ca345a23bd39bf5252277b4eef5c2"}, + {file = "port_for-1.0.0.tar.gz", hash = "sha256:404d161b1b2c82e2f6b31d8646396b4847d02bf5ee10068c92b7263657a14582"}, +] + [[package]] name = "priority" version = "2.0.0" @@ -4649,6 +4713,7 @@ files = [ {file = "prisma-0.11.0-py3-none-any.whl", hash = "sha256:22bb869e59a2968b99f3483bb417717273ffbc569fd1e9ceed95e5614cbaf53a"}, {file = "prisma-0.11.0.tar.gz", hash = "sha256:3f2f2fd2361e1ec5ff655f2a04c7860c2f2a5bc4c91f78ca9c5c6349735bf693"}, ] +markers = {main = "extra == \"extra-proxy\""} [package.dependencies] click = ">=7.1.2" @@ -4822,7 +4887,7 @@ files = [ {file = "proto_plus-1.26.1-py3-none-any.whl", hash = "sha256:13285478c2dcf2abb829db158e1047e2f1e8d63a077d94263c2b88b043c75a66"}, {file = "proto_plus-1.26.1.tar.gz", hash = "sha256:21a515a4c4c0088a773899e23c7bbade3d18f9c66c73edd4c7ee3816bc96a012"}, ] -markers = {main = "extra == \"google\" or extra == \"extra-proxy\" or python_version >= \"3.10\"", proxy-dev = "python_version >= \"3.10\""} +markers = {main = "extra == \"google\" or extra == \"extra-proxy\"", proxy-dev = "python_version >= \"3.10\""} [package.dependencies] protobuf = ">=3.19.0,<7.0.0" @@ -4850,7 +4915,93 @@ files = [ {file = "protobuf-5.29.5-py3-none-any.whl", hash = "sha256:6cf42630262c59b2d8de33954443d94b746c952b01434fc58a417fdbd2e84bd5"}, {file = "protobuf-5.29.5.tar.gz", hash = "sha256:bc1463bafd4b0929216c35f437a8e28731a2b7fe3d98bb77a600efced5a15c84"}, ] -markers = {main = "extra == \"google\" or extra == \"extra-proxy\" or python_version >= \"3.10\""} +markers = {main = "python_version >= \"3.10\" and (extra == \"extra-proxy\" or extra == \"google\" or extra == \"mlflow\") or extra == \"google\" or extra == \"extra-proxy\""} + +[[package]] +name = "psutil" +version = "7.2.2" +description = "Cross-platform lib for process and system monitoring." +optional = false +python-versions = ">=3.6" +groups = ["dev"] +markers = "sys_platform != \"cygwin\"" +files = [ + {file = "psutil-7.2.2-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:2edccc433cbfa046b980b0df0171cd25bcaeb3a68fe9022db0979e7aa74a826b"}, + {file = "psutil-7.2.2-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:e78c8603dcd9a04c7364f1a3e670cea95d51ee865e4efb3556a3a63adef958ea"}, + {file = "psutil-7.2.2-cp313-cp313t-manylinux2010_x86_64.manylinux_2_12_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:1a571f2330c966c62aeda00dd24620425d4b0cc86881c89861fbc04549e5dc63"}, + {file = "psutil-7.2.2-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:917e891983ca3c1887b4ef36447b1e0873e70c933afc831c6b6da078ba474312"}, + {file = "psutil-7.2.2-cp313-cp313t-win_amd64.whl", hash = "sha256:ab486563df44c17f5173621c7b198955bd6b613fb87c71c161f827d3fb149a9b"}, + {file = "psutil-7.2.2-cp313-cp313t-win_arm64.whl", hash = "sha256:ae0aefdd8796a7737eccea863f80f81e468a1e4cf14d926bd9b6f5f2d5f90ca9"}, + {file = "psutil-7.2.2-cp314-cp314t-macosx_10_15_x86_64.whl", hash = "sha256:eed63d3b4d62449571547b60578c5b2c4bcccc5387148db46e0c2313dad0ee00"}, + {file = "psutil-7.2.2-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:7b6d09433a10592ce39b13d7be5a54fbac1d1228ed29abc880fb23df7cb694c9"}, + {file = "psutil-7.2.2-cp314-cp314t-manylinux2010_x86_64.manylinux_2_12_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:1fa4ecf83bcdf6e6c8f4449aff98eefb5d0604bf88cb883d7da3d8d2d909546a"}, + {file = "psutil-7.2.2-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e452c464a02e7dc7822a05d25db4cde564444a67e58539a00f929c51eddda0cf"}, + {file = "psutil-7.2.2-cp314-cp314t-win_amd64.whl", hash = "sha256:c7663d4e37f13e884d13994247449e9f8f574bc4655d509c3b95e9ec9e2b9dc1"}, + {file = "psutil-7.2.2-cp314-cp314t-win_arm64.whl", hash = "sha256:11fe5a4f613759764e79c65cf11ebdf26e33d6dd34336f8a337aa2996d71c841"}, + {file = "psutil-7.2.2-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:ed0cace939114f62738d808fdcecd4c869222507e266e574799e9c0faa17d486"}, + {file = "psutil-7.2.2-cp36-abi3-macosx_11_0_arm64.whl", hash = "sha256:1a7b04c10f32cc88ab39cbf606e117fd74721c831c98a27dc04578deb0c16979"}, + {file = "psutil-7.2.2-cp36-abi3-manylinux2010_x86_64.manylinux_2_12_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:076a2d2f923fd4821644f5ba89f059523da90dc9014e85f8e45a5774ca5bc6f9"}, + {file = "psutil-7.2.2-cp36-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b0726cecd84f9474419d67252add4ac0cd9811b04d61123054b9fb6f57df6e9e"}, + {file = "psutil-7.2.2-cp36-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:fd04ef36b4a6d599bbdb225dd1d3f51e00105f6d48a28f006da7f9822f2606d8"}, + {file = "psutil-7.2.2-cp36-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:b58fabe35e80b264a4e3bb23e6b96f9e45a3df7fb7eed419ac0e5947c61e47cc"}, + {file = "psutil-7.2.2-cp37-abi3-win_amd64.whl", hash = "sha256:eb7e81434c8d223ec4a219b5fc1c47d0417b12be7ea866e24fb5ad6e84b3d988"}, + {file = "psutil-7.2.2-cp37-abi3-win_arm64.whl", hash = "sha256:8c233660f575a5a89e6d4cb65d9f938126312bca76d8fe087b947b3a1aaac9ee"}, + {file = "psutil-7.2.2.tar.gz", hash = "sha256:0746f5f8d406af344fd547f1c8daa5f5c33dbc293bb8d6a16d80b4bb88f59372"}, +] + +[package.extras] +dev = ["abi3audit", "black", "check-manifest", "colorama ; os_name == \"nt\"", "coverage", "packaging", "psleak", "pylint", "pyperf", "pypinfo", "pyreadline3 ; os_name == \"nt\"", "pytest", "pytest-cov", "pytest-instafail", "pytest-xdist", "pywin32 ; os_name == \"nt\" and implementation_name != \"pypy\"", "requests", "rstcheck", "ruff", "setuptools", "sphinx", "sphinx_rtd_theme", "toml-sort", "twine", "validate-pyproject[all]", "virtualenv", "vulture", "wheel", "wheel ; os_name == \"nt\" and implementation_name != \"pypy\"", "wmi ; os_name == \"nt\" and implementation_name != \"pypy\""] +test = ["psleak", "pytest", "pytest-instafail", "pytest-xdist", "pywin32 ; os_name == \"nt\" and implementation_name != \"pypy\"", "setuptools", "wheel ; os_name == \"nt\" and implementation_name != \"pypy\"", "wmi ; os_name == \"nt\" and implementation_name != \"pypy\""] + +[[package]] +name = "psycopg" +version = "3.2.13" +description = "PostgreSQL database adapter for Python" +optional = false +python-versions = ">=3.8" +groups = ["dev"] +markers = "python_version == \"3.9\"" +files = [ + {file = "psycopg-3.2.13-py3-none-any.whl", hash = "sha256:a481374514f2da627157f767a9336705ebefe93ea7a0522a6cbacba165da179a"}, + {file = "psycopg-3.2.13.tar.gz", hash = "sha256:309adaeda61d44556046ec9a83a93f42bbe5310120b1995f3af49ab6d9f13c1d"}, +] + +[package.dependencies] +typing-extensions = {version = ">=4.6", markers = "python_version < \"3.13\""} +tzdata = {version = "*", markers = "sys_platform == \"win32\""} + +[package.extras] +binary = ["psycopg-binary (==3.2.13) ; implementation_name != \"pypy\""] +c = ["psycopg-c (==3.2.13) ; implementation_name != \"pypy\""] +dev = ["ast-comments (>=1.1.2)", "black (>=24.1.0)", "codespell (>=2.2)", "dnspython (>=2.1)", "flake8 (>=4.0)", "isort-psycopg", "isort[colors] (>=6.0)", "mypy (>=1.14)", "pre-commit (>=4.0.1)", "types-setuptools (>=57.4)", "types-shapely (>=2.0)", "wheel (>=0.37)"] +docs = ["Sphinx (>=5.0)", "furo (==2022.6.21)", "sphinx-autobuild (>=2021.3.14)", "sphinx-autodoc-typehints (>=1.12)"] +pool = ["psycopg-pool"] +test = ["anyio (>=4.0)", "mypy (>=1.14)", "pproxy (>=2.7)", "pytest (>=6.2.5)", "pytest-cov (>=3.0)", "pytest-randomly (>=3.5)"] + +[[package]] +name = "psycopg" +version = "3.3.2" +description = "PostgreSQL database adapter for Python" +optional = false +python-versions = ">=3.10" +groups = ["dev"] +markers = "python_version >= \"3.10\"" +files = [ + {file = "psycopg-3.3.2-py3-none-any.whl", hash = "sha256:3e94bc5f4690247d734599af56e51bae8e0db8e4311ea413f801fef82b14a99b"}, + {file = "psycopg-3.3.2.tar.gz", hash = "sha256:707a67975ee214d200511177a6a80e56e654754c9afca06a7194ea6bbfde9ca7"}, +] + +[package.dependencies] +typing-extensions = {version = ">=4.6", markers = "python_version < \"3.13\""} +tzdata = {version = "*", markers = "sys_platform == \"win32\""} + +[package.extras] +binary = ["psycopg-binary (==3.3.2) ; implementation_name != \"pypy\""] +c = ["psycopg-c (==3.3.2) ; implementation_name != \"pypy\""] +dev = ["ast-comments (>=1.1.2)", "black (>=24.1.0)", "codespell (>=2.2)", "cython-lint (>=0.16)", "dnspython (>=2.1)", "flake8 (>=4.0)", "isort-psycopg", "isort[colors] (>=6.0)", "mypy (>=1.19.0)", "pre-commit (>=4.0.1)", "types-setuptools (>=57.4)", "types-shapely (>=2.0)", "wheel (>=0.37)"] +docs = ["Sphinx (>=5.0)", "furo (==2022.6.21)", "sphinx-autobuild (>=2021.3.14)", "sphinx-autodoc-typehints (>=1.12)"] +pool = ["psycopg-pool"] +test = ["anyio (>=4.0)", "mypy (>=1.19.0) ; implementation_name != \"pypy\"", "pproxy (>=2.7)", "pytest (>=6.2.5)", "pytest-cov (>=3.0)", "pytest-randomly (>=3.5)"] [[package]] name = "pyarrow" @@ -4924,7 +5075,7 @@ files = [ {file = "pyasn1-0.6.1-py3-none-any.whl", hash = "sha256:0d632f46f2ba09143da3a8afe9e33fb6f92fa2320ab7e886e2d0f7672af84629"}, {file = "pyasn1-0.6.1.tar.gz", hash = "sha256:6f580d2bdd84365380830acf45550f2511469f673cb4a5ae3857a3170128b034"}, ] -markers = {main = "extra == \"google\" or extra == \"extra-proxy\" or python_version >= \"3.10\"", proxy-dev = "python_version >= \"3.10\""} +markers = {main = "python_version >= \"3.10\" and (extra == \"extra-proxy\" or extra == \"google\" or extra == \"mlflow\") or extra == \"google\" or extra == \"extra-proxy\"", proxy-dev = "python_version >= \"3.10\""} [[package]] name = "pyasn1-modules" @@ -4937,7 +5088,7 @@ files = [ {file = "pyasn1_modules-0.4.2-py3-none-any.whl", hash = "sha256:29253a9207ce32b64c3ac6600edc75368f98473906e8fd1043bd6b5b1de2c14a"}, {file = "pyasn1_modules-0.4.2.tar.gz", hash = "sha256:677091de870a80aae844b1ca6134f54652fa2c8c5a52aa396440ac3106e941e6"}, ] -markers = {main = "extra == \"google\" or extra == \"extra-proxy\" or python_version >= \"3.10\"", proxy-dev = "python_version >= \"3.10\""} +markers = {main = "python_version >= \"3.10\" and (extra == \"extra-proxy\" or extra == \"google\" or extra == \"mlflow\") or extra == \"google\" or extra == \"extra-proxy\"", proxy-dev = "python_version >= \"3.10\""} [package.dependencies] pyasn1 = ">=0.6.1,<0.7.0" @@ -4965,7 +5116,7 @@ files = [ {file = "pycparser-2.23-py3-none-any.whl", hash = "sha256:e5c6e8d3fbad53479cab09ac03729e0a9faf2bee3db8208a550daf5af81a5934"}, {file = "pycparser-2.23.tar.gz", hash = "sha256:78816d4f24add8f10a06d6f05b4d424ad9e96cfebf68a4ddc99c65c0720d00c2"}, ] -markers = {main = "implementation_name != \"PyPy\" and (platform_python_implementation != \"PyPy\" or extra == \"proxy\")", dev = "platform_python_implementation != \"PyPy\" and implementation_name != \"PyPy\"", proxy-dev = "platform_python_implementation != \"PyPy\" and implementation_name != \"PyPy\""} +markers = {main = "implementation_name != \"PyPy\" and (platform_python_implementation != \"PyPy\" or extra == \"proxy\") and (python_version >= \"3.10\" or extra == \"proxy\" or extra == \"extra-proxy\") and (extra == \"proxy\" or extra == \"extra-proxy\" or extra == \"mlflow\")", dev = "platform_python_implementation != \"PyPy\" and implementation_name != \"PyPy\"", proxy-dev = "platform_python_implementation != \"PyPy\" and implementation_name != \"PyPy\""} [[package]] name = "pydantic" @@ -5188,6 +5339,7 @@ files = [ {file = "PyJWT-2.10.1-py3-none-any.whl", hash = "sha256:dcdd193e30abefd5debf142f9adfcdd2b58004e644f25406ffaebd50bd98dacb"}, {file = "pyjwt-2.10.1.tar.gz", hash = "sha256:3cc5772eb20009233caf06e9d8a0577824723b44e6648ee0a2aedb6cf9381953"}, ] +markers = {main = "(python_version <= \"3.13\" or extra == \"proxy\" or extra == \"extra-proxy\") and (extra == \"extra-proxy\" or extra == \"proxy\")"} [package.dependencies] cryptography = {version = ">=3.4.0", optional = true, markers = "extra == \"crypto\""} @@ -5353,6 +5505,25 @@ pytest = ">=6.2.5" [package.extras] dev = ["pre-commit", "pytest-asyncio", "tox"] +[[package]] +name = "pytest-postgresql" +version = "6.1.1" +description = "Postgresql fixtures and fixture factories for Pytest." +optional = false +python-versions = ">=3.8" +groups = ["dev"] +files = [ + {file = "pytest_postgresql-6.1.1-py3-none-any.whl", hash = "sha256:bd4c0970d25685ac3d34d42263fcbfbf134bf02d22519fce7e1ccf4122d8b99a"}, + {file = "pytest_postgresql-6.1.1.tar.gz", hash = "sha256:f996637367e6aecebba1349da52eea95340bdb434c90e4b79739e62c656056e2"}, +] + +[package.dependencies] +mirakuru = "*" +port-for = ">=0.7.3" +psycopg = ">=3.0.0" +pytest = ">=6.2" +setuptools = "*" + [[package]] name = "pytest-retry" version = "1.7.0" @@ -6079,7 +6250,7 @@ files = [ {file = "rsa-4.9.1-py3-none-any.whl", hash = "sha256:68635866661c6836b8d39430f97a996acbd61bfa49406748ea243539fe239762"}, {file = "rsa-4.9.1.tar.gz", hash = "sha256:e7bdbfdb5497da4c07dfd35530e1a902659db6ff241e39d9953cad06ebd0ae75"}, ] -markers = {main = "extra == \"google\" or extra == \"extra-proxy\" or python_version >= \"3.10\"", proxy-dev = "python_version >= \"3.10\""} +markers = {main = "python_version >= \"3.10\" and (extra == \"extra-proxy\" or extra == \"google\" or extra == \"mlflow\") or extra == \"google\" or extra == \"extra-proxy\"", proxy-dev = "python_version >= \"3.10\""} [package.dependencies] pyasn1 = ">=0.1.3" @@ -6125,10 +6296,10 @@ files = [ ] [package.dependencies] -botocore = ">=1.37.4,<2.0a.0" +botocore = ">=1.37.4,<2.0a0" [package.extras] -crt = ["botocore[crt] (>=1.37.4,<2.0a.0)"] +crt = ["botocore[crt] (>=1.37.4,<2.0a0)"] [[package]] name = "scikit-learn" @@ -6281,9 +6452,9 @@ tornado = ">=6.4.2,<7" urllib3 = ">=1.26,<3" [package.extras] -all = ["boto3 (>=1.34.98,<2)", "botocore (>=1.34.110,<2)", "cohere (>=5.9.4,<6.00)", "dagger-io (>=0.1.1) ; python_version >= \"3.11\"", "fastembed (>=0.3.0,<0.4) ; python_version < \"3.13\"", "google-cloud-aiplatform (>=1.45.0,<2)", "ipykernel (>=6.25.0,<7)", "llama-cpp-python (>=0.2.28,<0.2.86) ; python_version < \"3.13\"", "mistralai (>=0.0.12,<0.1.0)", "mypy (>=1.7.1,<2)", "ollama (>=0.1.7)", "pillow (>=10.2.0,<11.0.0) ; python_version < \"3.13\"", "pinecone[asyncio] (>=7.0.0,<8.0.0)", "psycopg[binary] (>=3.1.0,<4)", "pytest (>=8.2,<9.0)", "pytest-asyncio (>=0.24.0,<0.25)", "pytest-cov (>=4.1.0,<5)", "pytest-mock (>=3.12.0,<4)", "pytest-timeout", "pytest-xdist (>=3.5.0,<4)", "python-dotenv (>=1.0.0,<2)", "qdrant-client (>=1.11.1,<2)", "requests-mock (>=1.12.1,<2)", "ruff (>=0.11.2,<0.12)", "sentence-transformers (>=5.0.0) ; python_version < \"3.13\"", "tokenizers (>=0.19) ; python_version < \"3.13\"", "torch (>=2.6.0) ; python_version < \"3.13\"", "torchvision (>=0.17.0) ; python_version < \"3.13\"", "transformers (>=4.36.2) ; python_version < \"3.13\"", "types-pyyaml (>=6.0.12.12,<7)", "types-requests (>=2.31.0,<3)"] +all = ["boto3 (>=1.34.98,<2)", "botocore (>=1.34.110,<2)", "cohere (>=5.9.4,<6.0)", "dagger-io (>=0.1.1) ; python_version >= \"3.11\"", "fastembed (>=0.3.0,<0.4) ; python_version < \"3.13\"", "google-cloud-aiplatform (>=1.45.0,<2)", "ipykernel (>=6.25.0,<7)", "llama-cpp-python (>=0.2.28,<0.2.86) ; python_version < \"3.13\"", "mistralai (>=0.0.12,<0.1.0)", "mypy (>=1.7.1,<2)", "ollama (>=0.1.7)", "pillow (>=10.2.0,<11.0.0) ; python_version < \"3.13\"", "pinecone[asyncio] (>=7.0.0,<8.0.0)", "psycopg[binary] (>=3.1.0,<4)", "pytest (>=8.2,<9.0)", "pytest-asyncio (>=0.24.0,<0.25)", "pytest-cov (>=4.1.0,<5)", "pytest-mock (>=3.12.0,<4)", "pytest-timeout", "pytest-xdist (>=3.5.0,<4)", "python-dotenv (>=1.0.0,<2)", "qdrant-client (>=1.11.1,<2)", "requests-mock (>=1.12.1,<2)", "ruff (>=0.11.2,<0.12)", "sentence-transformers (>=5.0.0) ; python_version < \"3.13\"", "tokenizers (>=0.19) ; python_version < \"3.13\"", "torch (>=2.6.0) ; python_version < \"3.13\"", "torchvision (>=0.17.0) ; python_version < \"3.13\"", "transformers (>=4.36.2) ; python_version < \"3.13\"", "types-pyyaml (>=6.0.12.12,<7)", "types-requests (>=2.31.0,<3)"] bedrock = ["boto3 (>=1.34.98,<2)", "botocore (>=1.34.110,<2)"] -cohere = ["cohere (>=5.9.4,<6.00)"] +cohere = ["cohere (>=5.9.4,<6.0)"] dev = ["dagger-io (>=0.1.1) ; python_version >= \"3.11\"", "ipykernel (>=6.25.0,<7)", "mypy (>=1.7.1,<2)", "pytest (>=8.2,<9.0)", "pytest-asyncio (>=0.24.0,<0.25)", "pytest-cov (>=4.1.0,<5)", "pytest-mock (>=3.12.0,<4)", "pytest-timeout", "pytest-xdist (>=3.5.0,<4)", "python-dotenv (>=1.0.0,<2)", "requests-mock (>=1.12.1,<2)", "ruff (>=0.11.2,<0.12)", "types-pyyaml (>=6.0.12.12,<7)", "types-requests (>=2.31.0,<3)"] docs = ["pydoc-markdown (>=4.8.2) ; python_version < \"3.12\""] fastembed = ["fastembed (>=0.3.0,<0.4) ; python_version < \"3.13\""] @@ -6296,6 +6467,27 @@ postgres = ["psycopg[binary] (>=3.1.0,<4)"] qdrant = ["qdrant-client (>=1.11.1,<2)"] vision = ["pillow (>=10.2.0,<11.0.0) ; python_version < \"3.13\"", "torch (>=2.6.0) ; python_version < \"3.13\"", "torchvision (>=0.17.0) ; python_version < \"3.13\"", "transformers (>=4.36.2) ; python_version < \"3.13\""] +[[package]] +name = "setuptools" +version = "82.0.0" +description = "Easily download, build, install, upgrade, and uninstall Python packages" +optional = false +python-versions = ">=3.9" +groups = ["dev"] +files = [ + {file = "setuptools-82.0.0-py3-none-any.whl", hash = "sha256:70b18734b607bd1da571d097d236cfcfacaf01de45717d59e6e04b96877532e0"}, + {file = "setuptools-82.0.0.tar.gz", hash = "sha256:22e0a2d69474c6ae4feb01951cb69d515ed23728cf96d05513d36e42b62b37cb"}, +] + +[package.extras] +check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1) ; sys_platform != \"cygwin\"", "ruff (>=0.13.0) ; sys_platform != \"cygwin\""] +core = ["importlib_metadata (>=6) ; python_version < \"3.10\"", "jaraco.functools (>=4)", "jaraco.text (>=3.7)", "more_itertools", "more_itertools (>=8.8)", "packaging (>=24.2)", "platformdirs (>=4.2.2)", "tomli (>=2.0.1) ; python_version < \"3.11\"", "wheel (>=0.43.0)"] +cover = ["pytest-cov"] +doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "pyproject-hooks (!=1.1)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier", "towncrier (<24.7)"] +enabler = ["pytest-enabler (>=2.2)"] +test = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "ini2toml[lite] (>=0.14)", "jaraco.develop (>=7.21) ; python_version >= \"3.9\" and sys_platform != \"cygwin\"", "jaraco.envs (>=2.2)", "jaraco.path (>=3.7.2)", "jaraco.test (>=5.5)", "packaging (>=24.2)", "pip (>=19.1)", "pyproject-hooks (!=1.1)", "pytest (>=6,!=8.1.*)", "pytest-home (>=0.5)", "pytest-perf ; sys_platform != \"cygwin\"", "pytest-subprocess", "pytest-timeout", "pytest-xdist (>=3)", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel (>=0.44.0)"] +type = ["importlib_metadata (>=7.0.2) ; python_version < \"3.10\"", "jaraco.develop (>=7.21) ; sys_platform != \"cygwin\"", "mypy (==1.18.*)", "pytest-mypy"] + [[package]] name = "shapely" version = "2.0.7" @@ -6990,6 +7182,7 @@ files = [ {file = "tomlkit-0.13.3-py3-none-any.whl", hash = "sha256:c89c649d79ee40629a9fda55f8ace8c6a1b42deb912b2a8fd8d942ddadb606b0"}, {file = "tomlkit-0.13.3.tar.gz", hash = "sha256:430cf247ee57df2b94ee3fbe588e71d362a941ebb545dec29b53961d61add2a1"}, ] +markers = {main = "extra == \"extra-proxy\""} [[package]] name = "tornado" @@ -7202,14 +7395,14 @@ typing-extensions = ">=4.12.0" name = "tzdata" version = "2025.2" description = "Provider of IANA time zone data" -optional = true +optional = false python-versions = ">=2" -groups = ["main"] -markers = "(extra == \"proxy\" or extra == \"mlflow\") and (platform_system == \"Windows\" or extra == \"mlflow\") and python_version >= \"3.10\" or extra == \"proxy\" and platform_system == \"Windows\" and python_version == \"3.9\"" +groups = ["main", "dev"] files = [ {file = "tzdata-2025.2-py2.py3-none-any.whl", hash = "sha256:1a403fada01ff9221ca8044d701868fa132215d84beb92242d9acd2147f667a8"}, {file = "tzdata-2025.2.tar.gz", hash = "sha256:b60a638fcc0daffadf82fe0f57e53d06bdec2f36c4df66280ae79bce6bd6f2b9"}, ] +markers = {main = "(extra == \"proxy\" or extra == \"mlflow\") and (platform_system == \"Windows\" or extra == \"mlflow\") and python_version >= \"3.10\" or extra == \"proxy\" and platform_system == \"Windows\" and python_version == \"3.9\"", dev = "sys_platform == \"win32\""} [[package]] name = "tzlocal" @@ -7741,4 +7934,4 @@ utils = ["numpydoc"] [metadata] lock-version = "2.1" python-versions = ">=3.9,<4.0" -content-hash = "f2b4f98542c48ba2316a4c90563fc3551f34d9e3771bac39044f55a390e1f1c1" +content-hash = "20ca098d83da3b9364b05930a74e9ff8512e31d626018fc9f056b6fbd50a69af" diff --git a/pyproject.toml b/pyproject.toml index 5726d33c14c..31c0246b4c5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -150,6 +150,7 @@ mypy = "^1.0" pytest = "^7.4.3" pytest-mock = "^3.12.0" pytest-asyncio = "^0.21.1" +pytest-postgresql = "^6.0.0" pytest-retry = "^1.6.3" requests-mock = "^1.12.1" responses = "^0.25.7" diff --git a/schema.prisma b/schema.prisma index c2fca8705cb..441c2cdf70d 100644 --- a/schema.prisma +++ b/schema.prisma @@ -325,6 +325,19 @@ model LiteLLM_VerificationToken { @@index([budget_reset_at, expires]) } +// Deprecated keys during grace period - allows old key to work until revoke_at +model LiteLLM_DeprecatedVerificationToken { + id String @id @default(uuid()) + token String // Hashed old key + active_token_id String // Current token hash in LiteLLM_VerificationToken + revoke_at DateTime // When the old key stops working + created_at DateTime @default(now()) @map("created_at") + + @@unique([token]) + @@index([token, revoke_at]) + @@index([revoke_at]) +} + // Audit table for deleted keys - preserves spend and key information for historical tracking model LiteLLM_DeletedVerificationToken { id String @id @default(uuid()) diff --git a/tests/batches_tests/test_batch_custom_pricing.py b/tests/batches_tests/test_batch_custom_pricing.py new file mode 100644 index 00000000000..8bc1bd5a307 --- /dev/null +++ b/tests/batches_tests/test_batch_custom_pricing.py @@ -0,0 +1,131 @@ +""" +Test that batch cost calculation uses custom deployment-level pricing +when model_info is provided. + +Reproduces the bug where `input_cost_per_token_batches` / +`output_cost_per_token_batches` set on a proxy deployment's model_info +are ignored by the batch cost pipeline because they are never threaded +through to `batch_cost_calculator`. +""" + +import pytest + +from litellm.batches.batch_utils import ( + _batch_cost_calculator, + _get_batch_job_cost_from_file_content, + calculate_batch_cost_and_usage, +) +from litellm.cost_calculator import batch_cost_calculator +from litellm.types.utils import Usage + + +# --- helpers --- + +def _make_batch_output_line(prompt_tokens: int = 10, completion_tokens: int = 5): + """Return a single successful batch output line (OpenAI JSONL format).""" + return { + "id": "batch_req_1", + "custom_id": "req-1", + "response": { + "status_code": 200, + "body": { + "id": "chatcmpl-test", + "object": "chat.completion", + "model": "fake-batch-model", + "usage": { + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + "total_tokens": prompt_tokens + completion_tokens, + }, + "choices": [ + { + "index": 0, + "message": {"role": "assistant", "content": "Hello"}, + "finish_reason": "stop", + } + ], + }, + }, + "error": None, + } + + +CUSTOM_MODEL_INFO = { + "input_cost_per_token_batches": 0.00125, + "output_cost_per_token_batches": 0.005, +} + + +# --- tests --- + + +def test_batch_cost_calculator_uses_custom_model_info(): + """batch_cost_calculator should use model_info override when provided.""" + usage = Usage(prompt_tokens=10, completion_tokens=5, total_tokens=15) + + prompt_cost, completion_cost = batch_cost_calculator( + usage=usage, + model="fake-batch-model", + custom_llm_provider="openai", + model_info=CUSTOM_MODEL_INFO, + ) + + expected_prompt = 10 * 0.00125 + expected_completion = 5 * 0.005 + assert prompt_cost == pytest.approx(expected_prompt), ( + f"Expected prompt cost {expected_prompt}, got {prompt_cost}" + ) + assert completion_cost == pytest.approx(expected_completion), ( + f"Expected completion cost {expected_completion}, got {completion_cost}" + ) + + +def test_get_batch_job_cost_from_file_content_uses_custom_model_info(): + """_get_batch_job_cost_from_file_content should thread model_info to completion_cost.""" + file_content = [_make_batch_output_line(prompt_tokens=10, completion_tokens=5)] + + cost = _get_batch_job_cost_from_file_content( + file_content_dictionary=file_content, + custom_llm_provider="openai", + model_info=CUSTOM_MODEL_INFO, + ) + + expected = (10 * 0.00125) + (5 * 0.005) + assert cost == pytest.approx(expected), ( + f"Expected total cost {expected}, got {cost}" + ) + + +def test_batch_cost_calculator_func_uses_custom_model_info(): + """_batch_cost_calculator should thread model_info.""" + file_content = [_make_batch_output_line(prompt_tokens=10, completion_tokens=5)] + + cost = _batch_cost_calculator( + file_content_dictionary=file_content, + custom_llm_provider="openai", + model_info=CUSTOM_MODEL_INFO, + ) + + expected = (10 * 0.00125) + (5 * 0.005) + assert cost == pytest.approx(expected), ( + f"Expected total cost {expected}, got {cost}" + ) + + +@pytest.mark.asyncio +async def test_calculate_batch_cost_and_usage_uses_custom_model_info(): + """calculate_batch_cost_and_usage should thread model_info.""" + file_content = [_make_batch_output_line(prompt_tokens=10, completion_tokens=5)] + + batch_cost, batch_usage, batch_models = await calculate_batch_cost_and_usage( + file_content_dictionary=file_content, + custom_llm_provider="openai", + model_info=CUSTOM_MODEL_INFO, + ) + + expected = (10 * 0.00125) + (5 * 0.005) + assert batch_cost == pytest.approx(expected), ( + f"Expected total cost {expected}, got {batch_cost}" + ) + assert batch_usage.prompt_tokens == 10 + assert batch_usage.completion_tokens == 5 diff --git a/tests/test_litellm/enterprise/proxy/test_afile_retrieve_returns_unified_id.py b/tests/test_litellm/enterprise/proxy/test_afile_retrieve_returns_unified_id.py new file mode 100644 index 00000000000..7040aef73e5 --- /dev/null +++ b/tests/test_litellm/enterprise/proxy/test_afile_retrieve_returns_unified_id.py @@ -0,0 +1,67 @@ +""" +Test that managed_files.afile_retrieve returns the unified file ID, not the +raw provider file ID, when file_object is already stored in the database. + +Bug: managed_files.py Case 2 returns stored_file_object.file_object directly +without replacing .id with the unified ID. Case 3 (fetch from provider) does +it correctly at line 1028. +""" + +import pytest +from unittest.mock import AsyncMock, MagicMock + +from litellm.proxy._types import LiteLLM_ManagedFileTable +from litellm.types.llms.openai import OpenAIFileObject + + +def _make_managed_files_instance(): + from litellm_enterprise.proxy.hooks.managed_files import ( + _PROXY_LiteLLMManagedFiles, + ) + + instance = _PROXY_LiteLLMManagedFiles( + internal_usage_cache=MagicMock(), + prisma_client=MagicMock(), + ) + return instance + + +@pytest.mark.asyncio +async def test_should_return_unified_id_when_file_object_exists_in_db(): + """ + When get_unified_file_id returns a stored file_object (Case 2), + afile_retrieve must set .id to the unified file ID before returning. + """ + unified_id = "bGl0ZWxsbV9wcm94eTp1bmlmaWVkX291dHB1dF9maWxl" + raw_provider_id = "batch_20260214-output-file-1" + + stored = LiteLLM_ManagedFileTable( + unified_file_id=unified_id, + file_object=OpenAIFileObject( + id=raw_provider_id, + bytes=489, + created_at=1700000000, + filename="batch_output.jsonl", + object="file", + purpose="batch_output", + status="processed", + ), + model_mappings={"model-abc": raw_provider_id}, + flat_model_file_ids=[raw_provider_id], + created_by="test-user", + updated_by="test-user", + ) + + managed_files = _make_managed_files_instance() + managed_files.get_unified_file_id = AsyncMock(return_value=stored) + + result = await managed_files.afile_retrieve( + file_id=unified_id, + litellm_parent_otel_span=None, + llm_router=None, + ) + + assert result.id == unified_id, ( + f"afile_retrieve should return the unified ID '{unified_id}', " + f"but got raw provider ID '{result.id}'" + ) diff --git a/tests/test_litellm/enterprise/proxy/test_batch_retrieve_input_file_id.py b/tests/test_litellm/enterprise/proxy/test_batch_retrieve_input_file_id.py new file mode 100644 index 00000000000..6e9c3c0354b --- /dev/null +++ b/tests/test_litellm/enterprise/proxy/test_batch_retrieve_input_file_id.py @@ -0,0 +1,75 @@ +""" +Test that batch retrieve endpoint resolves raw input_file_id to the +unified managed file ID before returning. + +Bug: After batch completion, batches.retrieve returns the raw provider +input_file_id instead of the LiteLLM unified ID. +""" + +import base64 +import json + +import pytest +from unittest.mock import AsyncMock, MagicMock, patch + +from litellm.proxy.openai_files_endpoints.common_utils import ( + _is_base64_encoded_unified_file_id, +) + + +DECODED_UNIFIED_INPUT_FILE_ID = "litellm_proxy:application/octet-stream;unified_id,test-uuid;target_model_names,azure-gpt-4" +B64_UNIFIED_INPUT_FILE_ID = base64.urlsafe_b64encode(DECODED_UNIFIED_INPUT_FILE_ID.encode()).decode().rstrip("=") +RAW_INPUT_FILE_ID = "file-raw-provider-abc123" + +DECODED_UNIFIED_BATCH_ID = "litellm_proxy;model_id:model-xyz;llm_batch_id:batch-123" +B64_UNIFIED_BATCH_ID = base64.urlsafe_b64encode(DECODED_UNIFIED_BATCH_ID.encode()).decode().rstrip("=") + + +@pytest.mark.asyncio +async def test_should_resolve_raw_input_file_id_to_unified(): + """ + When a completed batch has a raw input_file_id and the managed file table + contains a record for that raw ID, the retrieve endpoint should resolve + it to the unified file ID. + """ + unified_batch_id = _is_base64_encoded_unified_file_id(B64_UNIFIED_BATCH_ID) + assert unified_batch_id, "Test setup: batch_id should decode as unified" + + from litellm.types.utils import LiteLLMBatch + + batch_data = { + "id": B64_UNIFIED_BATCH_ID, + "completion_window": "24h", + "created_at": 1700000000, + "endpoint": "/v1/chat/completions", + "input_file_id": RAW_INPUT_FILE_ID, + "object": "batch", + "status": "completed", + "output_file_id": "file-output-xyz", + } + + mock_db_object = MagicMock() + mock_db_object.file_object = json.dumps(batch_data) + + mock_managed_file = MagicMock() + mock_managed_file.unified_file_id = B64_UNIFIED_INPUT_FILE_ID + + mock_prisma = MagicMock() + mock_prisma.db.litellm_managedobjecttable.find_first = AsyncMock(return_value=mock_db_object) + mock_prisma.db.litellm_managedfiletable.find_first = AsyncMock(return_value=mock_managed_file) + + from litellm.proxy.openai_files_endpoints.common_utils import get_batch_from_database + + _, response = await get_batch_from_database( + batch_id=B64_UNIFIED_BATCH_ID, + unified_batch_id=unified_batch_id, + managed_files_obj=MagicMock(), + prisma_client=mock_prisma, + verbose_proxy_logger=MagicMock(), + ) + + assert response is not None, "Batch should be found in DB" + assert response.input_file_id == B64_UNIFIED_INPUT_FILE_ID, ( + f"input_file_id should be unified '{B64_UNIFIED_INPUT_FILE_ID}', " + f"got raw '{response.input_file_id}'" + ) diff --git a/tests/test_litellm/enterprise/proxy/test_batch_retrieve_returns_unified_input_file_id.py b/tests/test_litellm/enterprise/proxy/test_batch_retrieve_returns_unified_input_file_id.py new file mode 100644 index 00000000000..420f5f9789c --- /dev/null +++ b/tests/test_litellm/enterprise/proxy/test_batch_retrieve_returns_unified_input_file_id.py @@ -0,0 +1,124 @@ +""" +Test that get_batch_from_database resolves raw input_file_id to the +unified/managed file ID when reading a batch from the database. + +Bug: The batch retrieve path stores the raw provider input_file_id in the +DB (via async_post_call_success_hook on the retrieve endpoint). When the +batch is later read from DB, get_batch_from_database returns the raw ID +without resolving it to the unified ID. +""" + +import json +import pytest +from typing import Optional +from unittest.mock import AsyncMock, MagicMock + +from litellm.proxy.openai_files_endpoints.common_utils import get_batch_from_database + + +def _mock_prisma(batch_json: str, managed_file_record=None): + """Create a mock prisma client with canned responses.""" + prisma = MagicMock() + + batch_db_record = MagicMock() + batch_db_record.file_object = batch_json + + prisma.db.litellm_managedobjecttable.find_first = AsyncMock( + return_value=batch_db_record + ) + + prisma.db.litellm_managedfiletable.find_first = AsyncMock( + return_value=managed_file_record + ) + + return prisma + + +@pytest.mark.asyncio +async def test_should_resolve_raw_input_file_id_to_unified_id(): + """ + When input_file_id in the stored batch is a raw provider ID, + get_batch_from_database must look up the unified ID from the + managed files table. + """ + unified_batch_id = "bGl0ZWxsbV9wcm94eTpiYXRjaF9pZA" + unified_input_file_id = "bGl0ZWxsbV9wcm94eTp1bmlmaWVkX2lucHV0" + raw_input_file_id = "file-abc123-raw" + + batch_data = { + "id": "batch-raw-123", + "completion_window": "24h", + "created_at": 1700000000, + "endpoint": "/v1/chat/completions", + "input_file_id": raw_input_file_id, + "object": "batch", + "status": "completed", + "output_file_id": "file-output-raw", + } + + managed_file_record = MagicMock() + managed_file_record.unified_file_id = unified_input_file_id + + prisma = _mock_prisma( + batch_json=json.dumps(batch_data), + managed_file_record=managed_file_record, + ) + + _, response = await get_batch_from_database( + batch_id=unified_batch_id, + unified_batch_id="decoded_unified_batch_id", + managed_files_obj=MagicMock(), + prisma_client=prisma, + verbose_proxy_logger=MagicMock(), + ) + + assert response is not None + assert response.input_file_id == unified_input_file_id, ( + f"input_file_id should be resolved to '{unified_input_file_id}', " + f"got raw: '{response.input_file_id}'" + ) + + prisma.db.litellm_managedfiletable.find_first.assert_called_once_with( + where={"flat_model_file_ids": {"has": raw_input_file_id}} + ) + + +@pytest.mark.asyncio +async def test_should_preserve_already_managed_input_file_id(): + """ + When input_file_id is already a managed/unified ID, it should + not be modified. + """ + import base64 + + unified_batch_id = "bGl0ZWxsbV9wcm94eTpiYXRjaF9pZA" + decoded_unified = "litellm_proxy:application/octet-stream;unified_id,test-123" + base64_input_file_id = base64.urlsafe_b64encode(decoded_unified.encode()).decode().rstrip("=") + + batch_data = { + "id": "batch-raw-123", + "completion_window": "24h", + "created_at": 1700000000, + "endpoint": "/v1/chat/completions", + "input_file_id": base64_input_file_id, + "object": "batch", + "status": "completed", + } + + prisma = _mock_prisma(batch_json=json.dumps(batch_data)) + + _, response = await get_batch_from_database( + batch_id=unified_batch_id, + unified_batch_id="decoded_unified_batch_id", + managed_files_obj=MagicMock(), + prisma_client=prisma, + verbose_proxy_logger=MagicMock(), + ) + + assert response is not None + assert response.input_file_id == base64_input_file_id, ( + f"input_file_id was already managed, should be preserved as '{base64_input_file_id}', " + f"got: '{response.input_file_id}'" + ) + + prisma.db.litellm_managedfiletable.find_first.assert_not_called() diff --git a/tests/test_litellm/enterprise/proxy/test_deleted_file_returns_403_not_404.py b/tests/test_litellm/enterprise/proxy/test_deleted_file_returns_403_not_404.py new file mode 100644 index 00000000000..7ad564dc8f9 --- /dev/null +++ b/tests/test_litellm/enterprise/proxy/test_deleted_file_returns_403_not_404.py @@ -0,0 +1,119 @@ +""" +Regression test: deleted managed files should return 404, not 403. + +When a managed file's DB record has been deleted, can_user_call_unified_file_id() +raises HTTPException(404) directly — rather than returning True (which would +weaken access control) or False (which would cause a misleading 403). +""" + +import base64 + +import pytest +from unittest.mock import AsyncMock, MagicMock + +from fastapi import HTTPException + +from litellm.proxy._types import UserAPIKeyAuth + + +def _make_user_api_key_dict(user_id: str) -> UserAPIKeyAuth: + return UserAPIKeyAuth( + api_key="sk-test", + user_id=user_id, + parent_otel_span=None, + ) + + +def _make_unified_file_id() -> str: + raw = "litellm_proxy:application/octet-stream;unified_id,test-deleted-file;target_model_names,azure-gpt-4" + return base64.b64encode(raw.encode()).decode() + + +def _make_managed_files_with_no_db_record(): + """Create a _PROXY_LiteLLMManagedFiles where the DB returns None (file was deleted).""" + from litellm_enterprise.proxy.hooks.managed_files import ( + _PROXY_LiteLLMManagedFiles, + ) + + mock_prisma = MagicMock() + mock_prisma.db.litellm_managedfiletable.find_first = AsyncMock(return_value=None) + + return _PROXY_LiteLLMManagedFiles( + internal_usage_cache=MagicMock(), + prisma_client=mock_prisma, + ) + + +@pytest.mark.asyncio +async def test_should_raise_404_for_deleted_file(): + """ + When a managed file record has been deleted from the DB, + check_managed_file_id_access should raise 404 (not 403). + """ + unified_file_id = _make_unified_file_id() + managed_files = _make_managed_files_with_no_db_record() + user = _make_user_api_key_dict("any-user") + data = {"file_id": unified_file_id} + + with pytest.raises(HTTPException) as exc_info: + await managed_files.check_managed_file_id_access(data, user) + assert exc_info.value.status_code == 404 + + +@pytest.mark.asyncio +async def test_should_allow_owner_access_when_record_exists(): + """Baseline: file owner can access their own file.""" + from litellm_enterprise.proxy.hooks.managed_files import ( + _PROXY_LiteLLMManagedFiles, + ) + + unified_file_id = _make_unified_file_id() + + mock_db_record = MagicMock() + mock_db_record.created_by = "user-A" + + mock_prisma = MagicMock() + mock_prisma.db.litellm_managedfiletable.find_first = AsyncMock( + return_value=mock_db_record + ) + + managed_files = _PROXY_LiteLLMManagedFiles( + internal_usage_cache=MagicMock(), + prisma_client=mock_prisma, + ) + + user = _make_user_api_key_dict("user-A") + data = {"file_id": unified_file_id} + + result = await managed_files.check_managed_file_id_access(data, user) + assert result is True + + +@pytest.mark.asyncio +async def test_should_block_different_user_when_record_exists(): + """Baseline: different user cannot access another user's file.""" + from litellm_enterprise.proxy.hooks.managed_files import ( + _PROXY_LiteLLMManagedFiles, + ) + + unified_file_id = _make_unified_file_id() + + mock_db_record = MagicMock() + mock_db_record.created_by = "user-A" + + mock_prisma = MagicMock() + mock_prisma.db.litellm_managedfiletable.find_first = AsyncMock( + return_value=mock_db_record + ) + + managed_files = _PROXY_LiteLLMManagedFiles( + internal_usage_cache=MagicMock(), + prisma_client=mock_prisma, + ) + + user = _make_user_api_key_dict("user-B") + data = {"file_id": unified_file_id} + + with pytest.raises(HTTPException) as exc_info: + await managed_files.check_managed_file_id_access(data, user) + assert exc_info.value.status_code == 403 diff --git a/tests/test_litellm/enterprise/proxy/test_managed_files_access_check.py b/tests/test_litellm/enterprise/proxy/test_managed_files_access_check.py new file mode 100644 index 00000000000..2db5a2214cb --- /dev/null +++ b/tests/test_litellm/enterprise/proxy/test_managed_files_access_check.py @@ -0,0 +1,200 @@ +""" +Tests for managed files access control in batch polling context. + +Regression test for: batch polling job running as default_user_id gets 403 +when trying to access managed files created by a real user. + +The fix (Option C) makes check_batch_cost call litellm.afile_content directly +with deployment credentials, bypassing the managed files access-control hooks. +""" + +import base64 +import pytest +from unittest.mock import AsyncMock, MagicMock, patch + +from fastapi import HTTPException + +from litellm.proxy._types import UserAPIKeyAuth + + +def _make_user_api_key_dict(user_id: str) -> UserAPIKeyAuth: + return UserAPIKeyAuth( + api_key="sk-test", + user_id=user_id, + parent_otel_span=None, + ) + + +def _make_unified_file_id() -> str: + """Create a base64-encoded unified file ID that passes _is_base64_encoded_unified_file_id.""" + raw = "litellm_proxy:application/octet-stream;unified_id,test-123;target_model_names,azure-gpt-4" + return base64.b64encode(raw.encode()).decode() + + +def _make_managed_files_instance(file_created_by: str, unified_file_id: str): + """Create a _PROXY_LiteLLMManagedFiles with a mocked DB that returns a file owned by file_created_by.""" + from litellm_enterprise.proxy.hooks.managed_files import ( + _PROXY_LiteLLMManagedFiles, + ) + + mock_db_record = MagicMock() + mock_db_record.created_by = file_created_by + + mock_prisma = MagicMock() + mock_prisma.db.litellm_managedfiletable.find_first = AsyncMock( + return_value=mock_db_record + ) + + instance = _PROXY_LiteLLMManagedFiles( + internal_usage_cache=MagicMock(), + prisma_client=mock_prisma, + ) + return instance + + +# --- Access control unit tests (document existing behavior) --- + + +@pytest.mark.asyncio +async def test_should_allow_file_owner_access(): + """File owner can access their own file — baseline sanity check.""" + unified_file_id = _make_unified_file_id() + managed_files = _make_managed_files_instance( + file_created_by="user-A", + unified_file_id=unified_file_id, + ) + user = _make_user_api_key_dict("user-A") + data = {"file_id": unified_file_id} + + result = await managed_files.check_managed_file_id_access(data, user) + assert result is True + + +@pytest.mark.asyncio +async def test_should_block_different_user_access(): + """A different regular user cannot access another user's file — correct behavior.""" + unified_file_id = _make_unified_file_id() + managed_files = _make_managed_files_instance( + file_created_by="user-A", + unified_file_id=unified_file_id, + ) + user = _make_user_api_key_dict("user-B") + data = {"file_id": unified_file_id} + + with pytest.raises(HTTPException) as exc_info: + await managed_files.check_managed_file_id_access(data, user) + assert exc_info.value.status_code == 403 + + +@pytest.mark.asyncio +async def test_should_block_default_user_id_access(): + """ + default_user_id is correctly blocked by the access check. + This documents the existing behavior that the Option C fix works around. + """ + unified_file_id = _make_unified_file_id() + managed_files = _make_managed_files_instance( + file_created_by="user-A", + unified_file_id=unified_file_id, + ) + system_user = _make_user_api_key_dict("default_user_id") + data = {"file_id": unified_file_id} + + with pytest.raises(HTTPException) as exc_info: + await managed_files.check_managed_file_id_access(data, system_user) + assert exc_info.value.status_code == 403 + + +# --- Option C fix test: check_batch_cost bypasses managed files hook --- + + +@pytest.mark.asyncio +async def test_check_batch_cost_should_call_afile_content_directly_with_credentials(): + """ + check_batch_cost should call litellm.afile_content directly with deployment + credentials, bypassing managed_files_obj.afile_content and its access-control + hooks. This avoids the 403 that occurs when the background job runs as + default_user_id. + """ + from litellm_enterprise.proxy.common_utils.check_batch_cost import CheckBatchCost + + # Build a unified object ID in the expected format: + # litellm_proxy;model_id:{};llm_batch_id:{};llm_output_file_id:{} + unified_raw = "litellm_proxy;model_id:model-deploy-xyz;llm_batch_id:batch-123;llm_output_file_id:file-raw-output" + unified_object_id = base64.b64encode(unified_raw.encode()).decode() + + # Mock a pending job from the DB + mock_job = MagicMock() + mock_job.unified_object_id = unified_object_id + mock_job.created_by = "user-A" + mock_job.id = "job-1" + + # Mock prisma + mock_prisma = MagicMock() + mock_prisma.db.litellm_managedobjecttable.find_many = AsyncMock( + return_value=[mock_job] + ) + mock_prisma.db.litellm_managedobjecttable.update_many = AsyncMock() + + # Mock proxy_logging_obj — should NOT be called for file content + mock_proxy_logging = MagicMock() + mock_managed_files_hook = MagicMock() + mock_managed_files_hook.afile_content = AsyncMock() + mock_proxy_logging.get_proxy_hook = MagicMock(return_value=mock_managed_files_hook) + + # Mock the batch response (completed, with output file) + from litellm.types.utils import LiteLLMBatch + batch_response = LiteLLMBatch( + id="batch-123", + completion_window="24h", + created_at=1700000000, + endpoint="/v1/chat/completions", + input_file_id="file-input", + object="batch", + status="completed", + output_file_id="file-raw-output", + ) + + # Mock router + mock_router = MagicMock() + mock_router.aretrieve_batch = AsyncMock(return_value=batch_response) + mock_router.get_deployment_credentials_with_provider = MagicMock( + return_value={ + "api_key": "test-key", + "api_base": "https://test.azure.com/", + "custom_llm_provider": "azure", + } + ) + + mock_deployment = MagicMock() + mock_deployment.litellm_params.custom_llm_provider = "azure" + mock_deployment.litellm_params.model = "azure/gpt-4" + mock_router.get_deployment = MagicMock(return_value=mock_deployment) + + checker = CheckBatchCost( + proxy_logging_obj=mock_proxy_logging, + prisma_client=mock_prisma, + llm_router=mock_router, + ) + + mock_file_content = MagicMock() + mock_file_content.content = b'{"id":"req-1","response":{"status_code":200,"body":{"id":"cmpl-1","object":"chat.completion","created":1700000000,"model":"gpt-4","choices":[{"index":0,"message":{"role":"assistant","content":"hi"},"finish_reason":"stop"}],"usage":{"prompt_tokens":10,"completion_tokens":5,"total_tokens":15}}}}\n' + + with patch( + "litellm.files.main.afile_content", + new_callable=AsyncMock, + return_value=mock_file_content, + ) as mock_direct_afile_content: + await checker.check_batch_cost() + + # afile_content should be called directly (not through managed_files_obj) + mock_direct_afile_content.assert_called_once() + call_kwargs = mock_direct_afile_content.call_args.kwargs + + assert call_kwargs.get("api_key") == "test-key", ( + f"afile_content should receive api_key from deployment credentials. " + f"Got: {call_kwargs}" + ) + + # managed_files_obj.afile_content should NOT have been called + mock_managed_files_hook.afile_content.assert_not_called() diff --git a/tests/test_litellm/enterprise/proxy/test_managed_files_hook.py b/tests/test_litellm/enterprise/proxy/test_managed_files_hook.py new file mode 100644 index 00000000000..9526304aff0 --- /dev/null +++ b/tests/test_litellm/enterprise/proxy/test_managed_files_hook.py @@ -0,0 +1,167 @@ +""" +Tests for enterprise/litellm_enterprise/proxy/hooks/managed_files.py + +Regression test for afile_retrieve called without credentials in +async_post_call_success_hook when processing completed batch responses. +""" + +import pytest +from typing import Optional +from unittest.mock import AsyncMock, MagicMock, patch + +from litellm.proxy._types import UserAPIKeyAuth +from litellm.types.llms.openai import OpenAIFileObject +from litellm.types.utils import LiteLLMBatch + + +def _make_file_object(file_id: str = "file-output-abc") -> OpenAIFileObject: + return OpenAIFileObject( + id=file_id, + bytes=100, + created_at=1700000000, + filename="output.jsonl", + object="file", + purpose="batch_output", + status="processed", + ) + + +def _make_batch_response( + batch_id: str = "batch-123", + output_file_id: Optional[str] = "file-output-abc", + status: str = "completed", + model_id: str = "model-deploy-xyz", + model_name: str = "azure/gpt-4", +) -> LiteLLMBatch: + """Create a LiteLLMBatch response with hidden params set as the router would.""" + batch = LiteLLMBatch( + id=batch_id, + completion_window="24h", + created_at=1700000000, + endpoint="/v1/chat/completions", + input_file_id="file-input-abc", + object="batch", + status=status, + output_file_id=output_file_id, + ) + batch._hidden_params = { + "unified_file_id": "some-unified-id", + "unified_batch_id": "some-unified-batch-id", + "model_id": model_id, + "model_name": model_name, + } + return batch + + +def _make_user_api_key_dict() -> UserAPIKeyAuth: + return UserAPIKeyAuth( + api_key="sk-test", + user_id="test-user", + parent_otel_span=None, + ) + + +def _make_managed_files_instance(): + """Create a _PROXY_LiteLLMManagedFiles with storage methods mocked out.""" + from litellm_enterprise.proxy.hooks.managed_files import ( + _PROXY_LiteLLMManagedFiles, + ) + + mock_cache = MagicMock() + mock_prisma = MagicMock() + + instance = _PROXY_LiteLLMManagedFiles( + internal_usage_cache=mock_cache, + prisma_client=mock_prisma, + ) + instance.store_unified_file_id = AsyncMock() + instance.store_unified_object_id = AsyncMock() + return instance + + +@pytest.mark.asyncio +async def test_should_pass_credentials_to_afile_retrieve(): + """ + When async_post_call_success_hook processes a completed batch with an output_file_id, + it calls afile_retrieve to fetch file metadata. It must pass credentials from the + router deployment, not just custom_llm_provider and file_id. + + Regression test for: managed_files.py:919 calling afile_retrieve without api_key/api_base. + """ + managed_files = _make_managed_files_instance() + batch_response = _make_batch_response( + model_id="model-deploy-xyz", + model_name="azure/gpt-4", + output_file_id="file-output-abc", + ) + user_api_key_dict = _make_user_api_key_dict() + + mock_credentials = { + "api_key": "test-azure-key", + "api_base": "https://my-azure.openai.azure.com/", + "api_version": "2025-03-01-preview", + "custom_llm_provider": "azure", + } + + mock_router = MagicMock() + mock_router.get_deployment_credentials_with_provider = MagicMock( + return_value=mock_credentials + ) + + mock_afile_retrieve = AsyncMock(return_value=_make_file_object("file-output-abc")) + + with patch( + "litellm.afile_retrieve", mock_afile_retrieve + ), patch( + "litellm.proxy.proxy_server.llm_router", mock_router + ): + await managed_files.async_post_call_success_hook( + data={}, + user_api_key_dict=user_api_key_dict, + response=batch_response, + ) + + mock_afile_retrieve.assert_called() + call_kwargs = mock_afile_retrieve.call_args + + assert call_kwargs.kwargs.get("api_key") == "test-azure-key", ( + f"afile_retrieve must receive api_key from router credentials. " + f"Got kwargs: {call_kwargs.kwargs}" + ) + assert call_kwargs.kwargs.get("api_base") == "https://my-azure.openai.azure.com/", ( + f"afile_retrieve must receive api_base from router credentials. " + f"Got kwargs: {call_kwargs.kwargs}" + ) + + +@pytest.mark.asyncio +async def test_should_fallback_when_no_router(): + """ + When llm_router is not available, afile_retrieve should still be called + with the fallback behavior (custom_llm_provider extracted from model_name). + """ + managed_files = _make_managed_files_instance() + batch_response = _make_batch_response( + model_id="model-deploy-xyz", + model_name="azure/gpt-4", + output_file_id="file-output-abc", + ) + user_api_key_dict = _make_user_api_key_dict() + + mock_afile_retrieve = AsyncMock(return_value=_make_file_object("file-output-abc")) + + with patch( + "litellm.afile_retrieve", mock_afile_retrieve + ), patch( + "litellm.proxy.proxy_server.llm_router", None + ): + await managed_files.async_post_call_success_hook( + data={}, + user_api_key_dict=user_api_key_dict, + response=batch_response, + ) + + mock_afile_retrieve.assert_called() + call_kwargs = mock_afile_retrieve.call_args + assert call_kwargs.kwargs.get("custom_llm_provider") == "azure" + assert call_kwargs.kwargs.get("file_id") == "file-output-abc" diff --git a/tests/test_litellm/integrations/test_langfuse.py b/tests/test_litellm/integrations/test_langfuse.py index cd3d5b9ebe3..010d9f863c2 100644 --- a/tests/test_litellm/integrations/test_langfuse.py +++ b/tests/test_litellm/integrations/test_langfuse.py @@ -268,22 +268,21 @@ def test_log_langfuse_v2_handles_null_usage_values(self): Test that _log_langfuse_v2 correctly handles None values in the usage object by converting them to 0, preventing validation errors. """ - # Create fresh mocks for this test to avoid state pollution from setUp's side_effect - # The setUp configures trace.side_effect which can interfere with return_value - mock_trace = MagicMock() - mock_generation = MagicMock() - mock_generation.trace_id = "test-trace-id" + # Reset the mock to ensure clean state + self.mock_langfuse_client.reset_mock() + self.mock_langfuse_trace.reset_mock() + self.mock_langfuse_generation.reset_mock() + + # Re-setup the trace and generation chain with clean state + self.mock_langfuse_generation.trace_id = "test-trace-id" mock_span = MagicMock() mock_span.end = MagicMock() - - mock_trace.generation.return_value = mock_generation - mock_trace.span.return_value = mock_span - - mock_client = MagicMock() - mock_client.trace.return_value = mock_trace - - # Use our fresh mock client - self.logger.Langfuse = mock_client + self.mock_langfuse_trace.span.return_value = mock_span + self.mock_langfuse_trace.generation.return_value = self.mock_langfuse_generation + + # Ensure trace returns our mock + self.mock_langfuse_client.trace.return_value = self.mock_langfuse_trace + self.logger.Langfuse = self.mock_langfuse_client with patch( "litellm.integrations.langfuse.langfuse._add_prompt_to_generation_params", @@ -337,13 +336,13 @@ def mock_get(key, default=None): ) except Exception as e: self.fail(f"_log_langfuse_v2 raised an exception: {e}") - + # Verify that trace was called first - mock_client.trace.assert_called() - + self.mock_langfuse_client.trace.assert_called() + # Check the arguments passed to the mocked langfuse generation call - mock_trace.generation.assert_called_once() - call_args, call_kwargs = mock_trace.generation.call_args + self.mock_langfuse_trace.generation.assert_called_once() + call_args, call_kwargs = self.mock_langfuse_trace.generation.call_args # Inspect the usage and usage_details dictionaries usage_arg = call_kwargs.get("usage") diff --git a/tests/test_litellm/integrations/test_s3_v2.py b/tests/test_litellm/integrations/test_s3_v2.py index 0a3523699a9..b53c05fa241 100644 --- a/tests/test_litellm/integrations/test_s3_v2.py +++ b/tests/test_litellm/integrations/test_s3_v2.py @@ -157,6 +157,186 @@ def test_s3_v2_endpoint_url(self, mock_periodic_flush, mock_create_task): assert result == {"downloaded": "data"} + @patch('asyncio.create_task') + @patch('litellm.integrations.s3_v2.CustomBatchLogger.periodic_flush') + def test_s3_v2_virtual_hosted_style(self, mock_periodic_flush, mock_create_task): + """Test s3_use_virtual_hosted_style parameter for virtual-hosted-style URLs""" + from unittest.mock import AsyncMock, MagicMock + + from litellm.types.integrations.s3_v2 import s3BatchLoggingElement + + # Mock periodic_flush and create_task to prevent async task creation during init + mock_periodic_flush.return_value = None + mock_create_task.return_value = None + + # Mock response for all tests + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.raise_for_status = MagicMock() + + # Create a test batch logging element + test_element = s3BatchLoggingElement( + s3_object_key="2025-09-14/test-key.json", + payload={"test": "data"}, + s3_object_download_filename="test-file.json" + ) + + # Test 1: Virtual-hosted-style with custom endpoint + s3_logger_virtual = S3Logger( + s3_bucket_name="test-bucket", + s3_endpoint_url="https://s3.custom-endpoint.com", + s3_aws_access_key_id="test-key", + s3_aws_secret_access_key="test-secret", + s3_region_name="us-east-1", + s3_use_virtual_hosted_style=True + ) + + s3_logger_virtual.async_httpx_client = AsyncMock() + s3_logger_virtual.async_httpx_client.put.return_value = mock_response + + asyncio.run(s3_logger_virtual.async_upload_data_to_s3(test_element)) + + call_args = s3_logger_virtual.async_httpx_client.put.call_args + assert call_args is not None + url = call_args[0][0] + expected_url = "https://test-bucket.s3.custom-endpoint.com/2025-09-14/test-key.json" + assert url == expected_url, f"Expected virtual-hosted-style URL {expected_url}, got {url}" + + # Test 2: Path-style (default behavior with s3_use_virtual_hosted_style=False) + s3_logger_path = S3Logger( + s3_bucket_name="test-bucket", + s3_endpoint_url="https://s3.custom-endpoint.com", + s3_aws_access_key_id="test-key", + s3_aws_secret_access_key="test-secret", + s3_region_name="us-east-1", + s3_use_virtual_hosted_style=False + ) + + s3_logger_path.async_httpx_client = AsyncMock() + s3_logger_path.async_httpx_client.put.return_value = mock_response + + asyncio.run(s3_logger_path.async_upload_data_to_s3(test_element)) + + call_args_path = s3_logger_path.async_httpx_client.put.call_args + assert call_args_path is not None + url_path = call_args_path[0][0] + expected_path_url = "https://s3.custom-endpoint.com/test-bucket/2025-09-14/test-key.json" + assert url_path == expected_path_url, f"Expected path-style URL {expected_path_url}, got {url_path}" + + # Test 3: Virtual-hosted-style with http protocol + s3_logger_http = S3Logger( + s3_bucket_name="http-bucket", + s3_endpoint_url="http://minio.local:9000", + s3_aws_access_key_id="minio-key", + s3_aws_secret_access_key="minio-secret", + s3_region_name="us-east-1", + s3_use_virtual_hosted_style=True + ) + + s3_logger_http.async_httpx_client = AsyncMock() + s3_logger_http.async_httpx_client.put.return_value = mock_response + + asyncio.run(s3_logger_http.async_upload_data_to_s3(test_element)) + + call_args_http = s3_logger_http.async_httpx_client.put.call_args + assert call_args_http is not None + url_http = call_args_http[0][0] + expected_http_url = "http://http-bucket.minio.local:9000/2025-09-14/test-key.json" + assert url_http == expected_http_url, f"Expected virtual-hosted-style URL with http {expected_http_url}, got {url_http}" + + # Test 4: Sync upload method with virtual-hosted-style + s3_logger_sync_virtual = S3Logger( + s3_bucket_name="sync-bucket", + s3_endpoint_url="https://storage.example.com", + s3_aws_access_key_id="sync-key", + s3_aws_secret_access_key="sync-secret", + s3_region_name="us-east-1", + s3_use_virtual_hosted_style=True + ) + + mock_sync_client = MagicMock() + mock_sync_client.put.return_value = mock_response + + with patch('litellm.integrations.s3_v2._get_httpx_client', return_value=mock_sync_client): + s3_logger_sync_virtual.upload_data_to_s3(test_element) + + call_args_sync = mock_sync_client.put.call_args + assert call_args_sync is not None + url_sync = call_args_sync[0][0] + expected_sync_url = "https://sync-bucket.storage.example.com/2025-09-14/test-key.json" + assert url_sync == expected_sync_url, f"Expected virtual-hosted-style sync URL {expected_sync_url}, got {url_sync}" + + # Test 5: Download method with virtual-hosted-style + s3_logger_download_virtual = S3Logger( + s3_bucket_name="download-bucket", + s3_endpoint_url="https://download.endpoint.com", + s3_aws_access_key_id="download-key", + s3_aws_secret_access_key="download-secret", + s3_region_name="us-east-1", + s3_use_virtual_hosted_style=True + ) + + mock_download_response = MagicMock() + mock_download_response.status_code = 200 + mock_download_response.json = MagicMock(return_value={"downloaded": "data"}) + s3_logger_download_virtual.async_httpx_client = AsyncMock() + s3_logger_download_virtual.async_httpx_client.get.return_value = mock_download_response + + result = asyncio.run(s3_logger_download_virtual._download_object_from_s3("2025-09-14/download-test-key.json")) + + call_args_download = s3_logger_download_virtual.async_httpx_client.get.call_args + assert call_args_download is not None + url_download = call_args_download[0][0] + expected_download_url = "https://download-bucket.download.endpoint.com/2025-09-14/download-test-key.json" + assert url_download == expected_download_url, f"Expected virtual-hosted-style download URL {expected_download_url}, got {url_download}" + + assert result == {"downloaded": "data"} + +@pytest.mark.asyncio +async def test_async_log_event_skips_when_standard_logging_object_missing(): + """ + Reproduces the bug where _async_log_event_base raises ValueError when + kwargs has no standard_logging_object (e.g. call_type=afile_delete). + + The S3 logger should skip gracefully, not raise. + """ + logger = S3Logger( + s3_bucket_name="test-bucket", + s3_region_name="us-east-1", + s3_aws_access_key_id="fake", + s3_aws_secret_access_key="fake", + ) + + kwargs_without_slo = { + "call_type": "afile_delete", + "model": None, + "litellm_call_id": "test-call-id", + } + + start_time = datetime.utcnow() + end_time = datetime.utcnow() + + # Spy on handle_callback_failure — should NOT be called if we skip gracefully. + # Without the fix, the ValueError is caught by the except block which calls + # handle_callback_failure. With the fix, we return early and never hit except. + with patch.object(logger, "handle_callback_failure") as mock_failure: + await logger._async_log_event_base( + kwargs=kwargs_without_slo, + response_obj=None, + start_time=start_time, + end_time=end_time, + ) + + assert not mock_failure.called, ( + "handle_callback_failure should not be called — " + "missing standard_logging_object should be a graceful skip, not an error" + ) + + # Nothing should have been queued (catches the case where code falls + # through without returning and appends None to the queue) + assert len(logger.log_queue) == 0, "log_queue should be empty when standard_logging_object is missing" + + @pytest.mark.asyncio async def test_strip_base64_removes_file_and_nontext_entries(): logger = S3Logger(s3_strip_base64_files=True) diff --git a/tests/test_litellm/llms/anthropic/chat/test_anthropic_chat_transformation.py b/tests/test_litellm/llms/anthropic/chat/test_anthropic_chat_transformation.py index 57e0dd494e0..50e948c1a27 100644 --- a/tests/test_litellm/llms/anthropic/chat/test_anthropic_chat_transformation.py +++ b/tests/test_litellm/llms/anthropic/chat/test_anthropic_chat_transformation.py @@ -1638,6 +1638,41 @@ def test_effort_with_claude_opus_45(): assert result["model"] == "claude-opus-4-5-20251101" +def test_effort_validation_with_opus_46(): + """Test that all four effort levels are accepted for Claude Opus 4.6.""" + config = AnthropicConfig() + + messages = [{"role": "user", "content": "Test"}] + + for effort in ["high", "medium", "low", "max"]: + optional_params = {"output_config": {"effort": effort}} + result = config.transform_request( + model="claude-opus-4-6-20260205", + messages=messages, + optional_params=optional_params, + litellm_params={}, + headers={} + ) + assert result["output_config"]["effort"] == effort + + +def test_max_effort_rejected_for_opus_45(): + """Test that effort='max' is rejected when using Claude Opus 4.5.""" + config = AnthropicConfig() + + messages = [{"role": "user", "content": "Test"}] + + with pytest.raises(ValueError, match="effort='max' is only supported by Claude Opus 4.6"): + optional_params = {"output_config": {"effort": "max"}} + config.transform_request( + model="claude-opus-4-5-20251101", + messages=messages, + optional_params=optional_params, + litellm_params={}, + headers={} + ) + + def test_effort_with_other_features(): """Test effort works alongside other features (thinking, tools).""" config = AnthropicConfig() diff --git a/tests/test_litellm/llms/anthropic/experimental_pass_through/adapters/test_anthropic_experimental_pass_through_adapters_transformation.py b/tests/test_litellm/llms/anthropic/experimental_pass_through/adapters/test_anthropic_experimental_pass_through_adapters_transformation.py index b228a51447b..f9e5c6d0252 100644 --- a/tests/test_litellm/llms/anthropic/experimental_pass_through/adapters/test_anthropic_experimental_pass_through_adapters_transformation.py +++ b/tests/test_litellm/llms/anthropic/experimental_pass_through/adapters/test_anthropic_experimental_pass_through_adapters_transformation.py @@ -1706,3 +1706,108 @@ def test_translate_openai_response_restores_tool_names(): assert len(tool_use_blocks) == 1 # Name should be restored to original assert tool_use_blocks[0]["name"] == original_name + + +def test_translate_openai_response_to_anthropic_input_tokens_excludes_cached_tokens(): + """ + Regression test: input_tokens in Anthropic format should NOT include cached tokens. + + Issue: v1/messages API was returning incorrect input_token count when using prompt caching. + The OpenAI format includes cached tokens in prompt_tokens, but Anthropic format should not. + + According to Anthropic's spec: + - input_tokens = uncached input tokens only + - cache_read_input_tokens = tokens read from cache + + In OpenAI format: + - prompt_tokens = all input tokens (including cached) + - prompt_tokens_details.cached_tokens = cached tokens + + Expected: anthropic.input_tokens = openai.prompt_tokens - openai.prompt_tokens_details.cached_tokens + """ + from litellm.types.utils import PromptTokensDetailsWrapper + + # Create OpenAI format response with cached tokens + # Scenario: 100 total prompt tokens, 30 of which are cached + usage = Usage( + prompt_tokens=100, + completion_tokens=50, + total_tokens=150, + prompt_tokens_details=PromptTokensDetailsWrapper( + cached_tokens=30 + ), + cache_read_input_tokens=30, # Anthropic format cache info + ) + + response = ModelResponse( + id="test-id", + choices=[ + Choices( + index=0, + finish_reason="stop", + message=Message( + role="assistant", + content="Test response", + ), + ) + ], + model="claude-3-sonnet-20240229", + usage=usage, + ) + + # Convert to Anthropic format + adapter = LiteLLMAnthropicMessagesAdapter() + anthropic_response = adapter.translate_openai_response_to_anthropic( + response=response, + tool_name_mapping=None, + ) + + # Validate: input_tokens should be 70 (100 - 30 cached), not 100 + assert anthropic_response["usage"]["input_tokens"] == 70, ( + f"Expected input_tokens=70 (100 total - 30 cached), " + f"but got {anthropic_response['usage']['input_tokens']}. " + f"input_tokens should NOT include cached tokens per Anthropic spec." + ) + assert anthropic_response["usage"]["output_tokens"] == 50 + assert anthropic_response["usage"]["cache_read_input_tokens"] == 30 + + +def test_translate_openai_response_to_anthropic_input_tokens_no_cache(): + """ + Regression test: input_tokens should equal prompt_tokens when there are no cached tokens. + """ + from litellm.types.utils import PromptTokensDetailsWrapper + + # Create OpenAI format response without cached tokens + usage = Usage( + prompt_tokens=100, + completion_tokens=50, + total_tokens=150, + ) + + response = ModelResponse( + id="test-id", + choices=[ + Choices( + index=0, + finish_reason="stop", + message=Message( + role="assistant", + content="Test response", + ), + ) + ], + model="claude-3-sonnet-20240229", + usage=usage, + ) + + # Convert to Anthropic format + adapter = LiteLLMAnthropicMessagesAdapter() + anthropic_response = adapter.translate_openai_response_to_anthropic( + response=response, + tool_name_mapping=None, + ) + + # Validate: input_tokens should equal prompt_tokens when no caching + assert anthropic_response["usage"]["input_tokens"] == 100 + assert anthropic_response["usage"]["output_tokens"] == 50 diff --git a/tests/test_litellm/llms/bedrock/chat/test_converse_transformation.py b/tests/test_litellm/llms/bedrock/chat/test_converse_transformation.py index ce43f22d8f8..ddbb0454cac 100644 --- a/tests/test_litellm/llms/bedrock/chat/test_converse_transformation.py +++ b/tests/test_litellm/llms/bedrock/chat/test_converse_transformation.py @@ -2934,3 +2934,47 @@ def test_drop_thinking_param_when_thinking_blocks_missing(): finally: # Restore original modify_params setting litellm.modify_params = original_modify_params + + +class TestBedrockMinThinkingBudgetTokens: + """Test that thinking.budget_tokens is clamped to the Bedrock minimum (1024).""" + + def _map_params( + self, thinking_value, model="anthropic.claude-3-7-sonnet-20250219-v1:0" + ): + """Helper to call map_openai_params with the given thinking value.""" + config = AmazonConverseConfig() + non_default_params = {"thinking": thinking_value} + optional_params = {"thinking": thinking_value} + return config.map_openai_params( + non_default_params=non_default_params, + optional_params=optional_params, + model=model, + drop_params=False, + ) + + def test_budget_tokens_below_minimum_is_clamped(self): + """budget_tokens < 1024 should be clamped to 1024.""" + result = self._map_params({"type": "enabled", "budget_tokens": 499}) + assert result["thinking"]["budget_tokens"] == 1024 + + def test_budget_tokens_at_minimum_is_unchanged(self): + """budget_tokens == 1024 should remain 1024.""" + result = self._map_params({"type": "enabled", "budget_tokens": 1024}) + assert result["thinking"]["budget_tokens"] == 1024 + + def test_budget_tokens_above_minimum_is_unchanged(self): + """budget_tokens > 1024 should remain unchanged.""" + result = self._map_params({"type": "enabled", "budget_tokens": 2048}) + assert result["thinking"]["budget_tokens"] == 2048 + + def test_no_thinking_param_does_not_error(self): + """When thinking is not provided, map_openai_params should not raise.""" + config = AmazonConverseConfig() + result = config.map_openai_params( + non_default_params={}, + optional_params={}, + model="anthropic.claude-3-7-sonnet-20250219-v1:0", + drop_params=False, + ) + assert "thinking" not in result or result.get("thinking") is None diff --git a/tests/test_litellm/llms/chatgpt/responses/test_chatgpt_responses_transformation.py b/tests/test_litellm/llms/chatgpt/responses/test_chatgpt_responses_transformation.py index bec748d8dc8..03cea8785bc 100644 --- a/tests/test_litellm/llms/chatgpt/responses/test_chatgpt_responses_transformation.py +++ b/tests/test_litellm/llms/chatgpt/responses/test_chatgpt_responses_transformation.py @@ -88,6 +88,45 @@ def test_chatgpt_forces_streaming_and_reasoning_include(self): "You are Codex, based on GPT-5." ) + def test_chatgpt_drops_unsupported_responses_params(self): + config = ChatGPTResponsesAPIConfig() + request = config.transform_responses_api_request( + model="chatgpt/gpt-5.2-codex", + input="hi", + response_api_optional_request_params={ + # unsupported by ChatGPT Codex + "user": "user_123", + "temperature": 0.2, + "top_p": 0.9, + "context_management": [{"type": "compaction", "compact_threshold": 200000}], + "metadata": {"foo": "bar"}, + "max_output_tokens": 123, + "stream_options": {"include_usage": True}, + # supported and should be preserved + "truncation": "auto", + "previous_response_id": "resp_123", + "reasoning": {"effort": "medium"}, + "tools": [{"type": "function", "function": {"name": "hello"}}], + "tool_choice": {"type": "function", "function": {"name": "hello"}}, + }, + litellm_params=GenericLiteLLMParams(), + headers={}, + ) + + assert "user" not in request + assert "temperature" not in request + assert "top_p" not in request + assert "context_management" not in request + assert "metadata" not in request + assert "max_output_tokens" not in request + assert "stream_options" not in request + + assert request["truncation"] == "auto" + assert request["previous_response_id"] == "resp_123" + assert request["reasoning"] == {"effort": "medium"} + assert request["tools"] == [{"type": "function", "function": {"name": "hello"}}] + assert request["tool_choice"] == {"type": "function", "function": {"name": "hello"}} + def test_chatgpt_non_stream_sse_response_parsing(self): config = ChatGPTResponsesAPIConfig() response_payload = { diff --git a/tests/test_litellm/llms/custom_httpx/test_aiohttp_transport.py b/tests/test_litellm/llms/custom_httpx/test_aiohttp_transport.py index 002fa81b9b5..6e2e60ba0dd 100644 --- a/tests/test_litellm/llms/custom_httpx/test_aiohttp_transport.py +++ b/tests/test_litellm/llms/custom_httpx/test_aiohttp_transport.py @@ -12,10 +12,42 @@ from litellm.llms.custom_httpx.aiohttp_transport import ( AiohttpResponseStream, + AiohttpTransport, LiteLLMAiohttpTransport, ) +@pytest.mark.asyncio +async def test_aclose_does_not_close_shared_session(): + """Test that aclose() does not close a session it does not own (shared session).""" + session = aiohttp.ClientSession() + try: + transport = LiteLLMAiohttpTransport(client=session, owns_session=False) + await transport.aclose() + assert not session.closed, "Shared session should not be closed by transport" + finally: + await session.close() + + +@pytest.mark.asyncio +async def test_aclose_closes_owned_session(): + """Test that aclose() closes a session it owns.""" + session = aiohttp.ClientSession() + transport = LiteLLMAiohttpTransport(client=session, owns_session=True) + await transport.aclose() + assert session.closed, "Owned session should be closed by transport" + + +@pytest.mark.asyncio +async def test_owns_session_defaults_to_true(): + """Test that owns_session defaults to True for backwards compatibility.""" + session = aiohttp.ClientSession() + transport = AiohttpTransport(client=session) + assert transport._owns_session is True + await transport.aclose() + assert session.closed + + class MockAiohttpResponse: """Mock aiohttp ClientResponse for testing""" diff --git a/tests/test_litellm/llms/openai_like/test_json_providers.py b/tests/test_litellm/llms/openai_like/test_json_providers.py index 5efd3c4cd6d..81c7eccd353 100644 --- a/tests/test_litellm/llms/openai_like/test_json_providers.py +++ b/tests/test_litellm/llms/openai_like/test_json_providers.py @@ -97,6 +97,47 @@ def test_supported_params(self): assert isinstance(supported, list) assert len(supported) > 0 + def test_tool_params_excluded_when_function_calling_not_supported(self): + """Test that tool-related params are excluded for models that don't support + function calling. Regression test for https://github.com/BerriAI/litellm/issues/21125""" + from litellm.llms.openai_like.dynamic_config import create_config_class + from litellm.llms.openai_like.json_loader import JSONProviderRegistry + + provider = JSONProviderRegistry.get("publicai") + config_class = create_config_class(provider) + config = config_class() + + # Mock supports_function_calling to return False + with patch("litellm.utils.supports_function_calling", return_value=False): + supported = config.get_supported_openai_params("some-model-without-fc") + + tool_params = ["tools", "tool_choice", "function_call", "functions", "parallel_tool_calls"] + for param in tool_params: + assert param not in supported, ( + f"'{param}' should not be in supported params when function calling is not supported" + ) + + # Non-tool params should still be present + assert "temperature" in supported + assert "max_tokens" in supported + assert "stop" in supported + + def test_tool_params_included_when_function_calling_supported(self): + """Test that tool-related params are included for models that support function calling.""" + from litellm.llms.openai_like.dynamic_config import create_config_class + from litellm.llms.openai_like.json_loader import JSONProviderRegistry + + provider = JSONProviderRegistry.get("publicai") + config_class = create_config_class(provider) + config = config_class() + + # Mock supports_function_calling to return True + with patch("litellm.utils.supports_function_calling", return_value=True): + supported = config.get_supported_openai_params("some-model-with-fc") + + assert "tools" in supported + assert "tool_choice" in supported + def test_provider_resolution(self): """Test that provider resolution finds JSON providers""" from litellm.litellm_core_utils.get_llm_provider_logic import ( diff --git a/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_server.py b/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_server.py index 07c3dfcc763..0513650e1ff 100644 --- a/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_server.py +++ b/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_server.py @@ -16,6 +16,30 @@ from litellm.types.mcp_server.mcp_server_manager import MCPServer +@pytest.fixture(autouse=True) +def cleanup_mcp_global_state(): + """Clean up MCP global state before and after each test. + + This fixture ensures test isolation when running with pytest-xdist + parallel execution. Without this, global_mcp_server_manager state + can leak between tests causing mock assertion failures. + """ + try: + from litellm.proxy._experimental.mcp_server.mcp_server_manager import ( + global_mcp_server_manager, + ) + # Clear before test + global_mcp_server_manager.registry.clear() + global_mcp_server_manager.tool_name_to_mcp_server_name_mapping.clear() + yield + # Clear after test + global_mcp_server_manager.registry.clear() + global_mcp_server_manager.tool_name_to_mcp_server_name_mapping.clear() + except ImportError: + # MCP not available, skip cleanup + yield + + @pytest.mark.asyncio async def test_mcp_server_tool_call_body_contains_request_data(): """Test that proxy_server_request body contains name and arguments""" @@ -756,6 +780,31 @@ async def init_task(): @pytest.mark.asyncio +async def test_streamable_http_session_manager_is_stateless(): + """ + Test that the StreamableHTTPSessionManager is initialized with stateless=True. + + Regression test for GitHub issue #20242 / PR #19809. + When stateless=False, the mcp library rejects non-initialize requests + that lack an mcp-session-id header, breaking clients like MCP Inspector, + curl, and any HTTP client without automatic session management. + """ + try: + from litellm.proxy._experimental.mcp_server.server import session_manager + except ImportError: + pytest.skip("MCP server not available") + + # The session manager must be stateless to avoid requiring mcp-session-id + # on every request. This was regressed by PR #19809 (stateless=True -> False). + assert session_manager.stateless is True, ( + "StreamableHTTPSessionManager must be initialized with stateless=True. " + "stateless=False breaks MCP clients that don't manage session IDs. " + "See: https://github.com/BerriAI/litellm/issues/20242" + ) + + +@pytest.mark.asyncio +@pytest.mark.no_parallel async def test_mcp_routing_with_conflicting_alias_and_group_name(): """ Tests (GH #14536) where an MCP server alias (e.g., "group/id") @@ -839,6 +888,7 @@ async def test_mcp_routing_with_conflicting_alias_and_group_name(): @pytest.mark.asyncio +@pytest.mark.no_parallel async def test_oauth2_headers_passed_to_mcp_client(): """Test that OAuth2 headers are properly passed through to the MCP client for OAuth2 servers like github_mcp""" try: diff --git a/tests/test_litellm/proxy/common_utils/test_key_rotation_manager.py b/tests/test_litellm/proxy/common_utils/test_key_rotation_manager.py index 6b3b4c92416..24828cdff36 100644 --- a/tests/test_litellm/proxy/common_utils/test_key_rotation_manager.py +++ b/tests/test_litellm/proxy/common_utils/test_key_rotation_manager.py @@ -4,7 +4,7 @@ import os import sys from datetime import datetime, timedelta, timezone -from unittest.mock import AsyncMock, MagicMock +from unittest.mock import AsyncMock import pytest @@ -24,7 +24,7 @@ class TestKeyRotationManager: async def test_should_rotate_key_logic(self): """ Test the core logic for determining when a key should be rotated. - + This tests: - Keys with null key_rotation_at should rotate immediately - Keys with future key_rotation_at should not rotate @@ -33,69 +33,69 @@ async def test_should_rotate_key_logic(self): # Setup mock_prisma_client = AsyncMock() manager = KeyRotationManager(mock_prisma_client) - + now = datetime.now(timezone.utc) - + # Test Case 1: No rotation time set (key_rotation_at = None) - should rotate key_no_rotation_time = LiteLLM_VerificationToken( token="test-token-1", auto_rotate=True, rotation_interval="30s", key_rotation_at=None, - rotation_count=0 + rotation_count=0, ) - - assert manager._should_rotate_key(key_no_rotation_time, now) == True - + + assert manager._should_rotate_key(key_no_rotation_time, now) is True + # Test Case 2: Future rotation time - should NOT rotate key_future_rotation = LiteLLM_VerificationToken( token="test-token-2", auto_rotate=True, rotation_interval="30s", key_rotation_at=now + timedelta(seconds=10), - rotation_count=1 + rotation_count=1, ) - - assert manager._should_rotate_key(key_future_rotation, now) == False - + + assert manager._should_rotate_key(key_future_rotation, now) is False + # Test Case 3: Past rotation time - should rotate key_past_rotation = LiteLLM_VerificationToken( token="test-token-3", auto_rotate=True, rotation_interval="30s", key_rotation_at=now - timedelta(seconds=10), - rotation_count=2 + rotation_count=2, ) - - assert manager._should_rotate_key(key_past_rotation, now) == True - + + assert manager._should_rotate_key(key_past_rotation, now) is True + # Test Case 4: Exact rotation time - should rotate key_exact_rotation = LiteLLM_VerificationToken( token="test-token-4", auto_rotate=True, rotation_interval="30s", key_rotation_at=now, - rotation_count=1 + rotation_count=1, ) - - assert manager._should_rotate_key(key_exact_rotation, now) == True - + + assert manager._should_rotate_key(key_exact_rotation, now) is True + # Test Case 5: No rotation interval - should NOT rotate key_no_interval = LiteLLM_VerificationToken( token="test-token-5", auto_rotate=True, rotation_interval=None, key_rotation_at=None, - rotation_count=0 + rotation_count=0, ) - - assert manager._should_rotate_key(key_no_interval, now) == False + + assert manager._should_rotate_key(key_no_interval, now) is False @pytest.mark.asyncio async def test_find_keys_needing_rotation(self): """ Test finding keys that need rotation from database. - + This tests: - Only keys with auto_rotate=True are considered - Database query filters by key_rotation_at properly @@ -104,10 +104,10 @@ async def test_find_keys_needing_rotation(self): # Setup mock_prisma_client = AsyncMock() manager = KeyRotationManager(mock_prisma_client) - + # Use a fixed timestamp to avoid timing issues in tests now = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc) - + # Mock database response - these are the keys the database query would return mock_keys = [ LiteLLM_VerificationToken( @@ -115,42 +115,47 @@ async def test_find_keys_needing_rotation(self): auto_rotate=True, rotation_interval="30s", key_rotation_at=None, # Should rotate (null key_rotation_at) - rotation_count=0 + rotation_count=0, ), LiteLLM_VerificationToken( token="token-2", auto_rotate=True, rotation_interval="60s", - key_rotation_at=now - timedelta(seconds=10), # Should rotate (past time) - rotation_count=1 - ) + key_rotation_at=now + - timedelta(seconds=10), # Should rotate (past time) + rotation_count=1, + ), ] - - mock_prisma_client.db.litellm_verificationtoken.find_many.return_value = mock_keys - + + mock_prisma_client.db.litellm_verificationtoken.find_many.return_value = ( + mock_keys + ) + # Mock datetime.now to return our fixed timestamp from unittest.mock import patch - with patch('litellm.proxy.common_utils.key_rotation_manager.datetime') as mock_datetime: + + with patch( + "litellm.proxy.common_utils.key_rotation_manager.datetime" + ) as mock_datetime: mock_datetime.now.return_value = now - mock_datetime.side_effect = lambda *args, **kwargs: datetime(*args, **kwargs) - + mock_datetime.side_effect = lambda *args, **kwargs: datetime( + *args, **kwargs + ) + # Execute keys_needing_rotation = await manager._find_keys_needing_rotation() - + # Verify database query - should use OR condition for key_rotation_at mock_prisma_client.db.litellm_verificationtoken.find_many.assert_called_once_with( where={ "auto_rotate": True, - "OR": [ - {"key_rotation_at": None}, - {"key_rotation_at": {"lte": now}} - ] + "OR": [{"key_rotation_at": None}, {"key_rotation_at": {"lte": now}}], } ) - + # Verify all keys returned by database query are included (no additional filtering) assert len(keys_needing_rotation) == 2 - + tokens_needing_rotation = [key.token for key in keys_needing_rotation] assert "token-1" in tokens_needing_rotation # Null key_rotation_at assert "token-2" in tokens_needing_rotation # Past key_rotation_at @@ -159,7 +164,7 @@ async def test_find_keys_needing_rotation(self): async def test_rotate_key_updates_database(self): """ Test that key rotation properly updates the database with new rotation info. - + This tests: - Rotation count is incremented - last_rotation_at is set to current time @@ -169,7 +174,7 @@ async def test_rotate_key_updates_database(self): # Setup mock_prisma_client = AsyncMock() manager = KeyRotationManager(mock_prisma_client) - + # Mock key to rotate key_to_rotate = LiteLLM_VerificationToken( token="old-token", @@ -177,31 +182,35 @@ async def test_rotate_key_updates_database(self): rotation_interval="30s", last_rotation_at=None, key_rotation_at=None, - rotation_count=0 + rotation_count=0, ) - + # Mock regenerate_key_fn response mock_response = GenerateKeyResponse( - key="new-api-key", - token_id="new-token-id", - user_id="test-user" + key="new-api-key", token_id="new-token-id", user_id="test-user" ) - + # Mock the regenerate function from unittest.mock import patch - with patch('litellm.proxy.common_utils.key_rotation_manager.regenerate_key_fn', return_value=mock_response): - with patch('litellm.proxy.common_utils.key_rotation_manager.KeyManagementEventHooks.async_key_rotated_hook'): + + with patch( + "litellm.proxy.common_utils.key_rotation_manager.regenerate_key_fn", + return_value=mock_response, + ): + with patch( + "litellm.proxy.common_utils.key_rotation_manager.KeyManagementEventHooks.async_key_rotated_hook" + ): # Execute await manager._rotate_key(key_to_rotate) - + # Verify database update was called with correct data mock_prisma_client.db.litellm_verificationtoken.update.assert_called_once() - + call_args = mock_prisma_client.db.litellm_verificationtoken.update.call_args - + # Check the WHERE clause targets the new token assert call_args[1]["where"]["token"] == "new-token-id" - + # Check the data being updated update_data = call_args[1]["data"] assert update_data["rotation_count"] == 1 # Incremented from 0 @@ -209,9 +218,75 @@ async def test_rotate_key_updates_database(self): assert isinstance(update_data["last_rotation_at"], datetime) assert "key_rotation_at" in update_data assert isinstance(update_data["key_rotation_at"], datetime) - + # Verify key_rotation_at is set to future time (30s from now) now = datetime.now(timezone.utc) next_rotation = update_data["key_rotation_at"] time_diff = (next_rotation - now).total_seconds() - assert 25 <= time_diff <= 35 # Should be around 30 seconds, allow some tolerance + assert ( + 25 <= time_diff <= 35 + ) # Should be around 30 seconds, allow some tolerance + + @pytest.mark.asyncio + async def test_cleanup_expired_deprecated_keys(self): + """ + Test that _cleanup_expired_deprecated_keys deletes expired deprecated keys. + """ + mock_prisma_client = AsyncMock() + mock_prisma_client.db.litellm_deprecatedverificationtoken.delete_many.return_value = ( + 3 + ) + manager = KeyRotationManager(mock_prisma_client) + + await manager._cleanup_expired_deprecated_keys() + + mock_prisma_client.db.litellm_deprecatedverificationtoken.delete_many.assert_called_once() + call_args = ( + mock_prisma_client.db.litellm_deprecatedverificationtoken.delete_many.call_args + ) + assert "revoke_at" in call_args[1]["where"] + assert call_args[1]["where"]["revoke_at"]["lt"] is not None + + @pytest.mark.asyncio + async def test_rotate_key_passes_grace_period(self): + """ + Test that _rotate_key passes grace_period in RegenerateKeyRequest. + """ + mock_prisma_client = AsyncMock() + manager = KeyRotationManager(mock_prisma_client) + + key_to_rotate = LiteLLM_VerificationToken( + token="old-token", + auto_rotate=True, + rotation_interval="30s", + key_rotation_at=None, + rotation_count=0, + ) + + mock_response = GenerateKeyResponse( + key="new-api-key", + token_id="new-token-id", + user_id="test-user", + ) + + from unittest.mock import patch + + with patch( + "litellm.proxy.common_utils.key_rotation_manager.regenerate_key_fn", + new_callable=AsyncMock, + ) as mock_regenerate: + mock_regenerate.return_value = mock_response + with patch( + "litellm.proxy.common_utils.key_rotation_manager.KeyManagementEventHooks.async_key_rotated_hook", + new_callable=AsyncMock, + ): + with patch( + "litellm.proxy.common_utils.key_rotation_manager.LITELLM_KEY_ROTATION_GRACE_PERIOD", + "48h", + ): + await manager._rotate_key(key_to_rotate) + + mock_regenerate.assert_called_once() + call_args = mock_regenerate.call_args + regenerate_request = call_args[1]["data"] + assert regenerate_request.grace_period == "48h" diff --git a/tests/test_litellm/proxy/db/test_db_spend_update_writer.py b/tests/test_litellm/proxy/db/test_db_spend_update_writer.py index 6ccecf59eed..1dd5cba2c4b 100644 --- a/tests/test_litellm/proxy/db/test_db_spend_update_writer.py +++ b/tests/test_litellm/proxy/db/test_db_spend_update_writer.py @@ -756,6 +756,45 @@ async def test_add_spend_log_transaction_to_daily_agent_transaction_injects_agen assert transaction["custom_llm_provider"] == "openai" +@pytest.mark.asyncio +async def test_add_spend_log_transaction_to_daily_agent_transaction_calls_common_helper_once(): + writer = DBSpendUpdateWriter() + mock_prisma = MagicMock() + mock_prisma.get_request_status = MagicMock(return_value="success") + + payload = { + "request_id": "req-common-helper", + "agent_id": "agent-abc", + "user": "test-user", + "startTime": "2024-01-01T12:00:00", + "api_key": "test-key", + "model": "gpt-4", + "custom_llm_provider": "openai", + "model_group": "gpt-4-group", + "prompt_tokens": 12, + "completion_tokens": 6, + "spend": 0.25, + "metadata": '{"usage_object": {}}', + } + + writer.daily_agent_spend_update_queue.add_update = AsyncMock() + original_common_helper = ( + writer._common_add_spend_log_transaction_to_daily_transaction + ) + writer._common_add_spend_log_transaction_to_daily_transaction = AsyncMock( + wraps=original_common_helper + ) + + await writer.add_spend_log_transaction_to_daily_agent_transaction( + payload=payload, + prisma_client=mock_prisma, + ) + + assert ( + writer._common_add_spend_log_transaction_to_daily_transaction.await_count == 1 + ) + + @pytest.mark.asyncio async def test_add_spend_log_transaction_to_daily_agent_transaction_skips_when_agent_id_missing(): """ @@ -960,4 +999,4 @@ async def test_update_daily_spend_re_raises_exception_after_logging(): entity_id_field="user_id", table_name="litellm_dailyuserspend", unique_constraint_name="user_id_date_api_key_model_custom_llm_provider_mcp_namespaced_tool_name_endpoint", - ) \ No newline at end of file + ) diff --git a/tests/test_litellm/proxy/hooks/test_proxy_track_cost_callback.py b/tests/test_litellm/proxy/hooks/test_proxy_track_cost_callback.py index cb6d90103f7..e8765cf78ca 100644 --- a/tests/test_litellm/proxy/hooks/test_proxy_track_cost_callback.py +++ b/tests/test_litellm/proxy/hooks/test_proxy_track_cost_callback.py @@ -126,3 +126,77 @@ async def test_async_post_call_failure_hook_non_llm_route(): # Assert that update_database was NOT called for non-LLM routes mock_update_database.assert_not_called() + + +@pytest.mark.asyncio +async def test_track_cost_callback_skips_when_no_standard_logging_object(): + """ + Reproduces the bug where _PROXY_track_cost_callback raises + 'Cost tracking failed for model=None' when kwargs has no + standard_logging_object (e.g. call_type=afile_delete). + + File operations have no model and no standard_logging_object. + The callback should skip gracefully instead of raising. + """ + logger = _ProxyDBLogger() + + kwargs = { + "call_type": "afile_delete", + "model": None, + "litellm_call_id": "test-call-id", + "litellm_params": {}, + "stream": False, + } + + with patch( + "litellm.proxy.proxy_server.proxy_logging_obj", + ) as mock_proxy_logging: + mock_proxy_logging.failed_tracking_alert = AsyncMock() + mock_proxy_logging.db_spend_update_writer = MagicMock() + mock_proxy_logging.db_spend_update_writer.update_database = AsyncMock() + + await logger._PROXY_track_cost_callback( + kwargs=kwargs, + completion_response=None, + start_time=datetime.now(), + end_time=datetime.now(), + ) + + # update_database should NOT be called — nothing to track + mock_proxy_logging.db_spend_update_writer.update_database.assert_not_called() + + # failed_tracking_alert should NOT be called — this is not an error + mock_proxy_logging.failed_tracking_alert.assert_not_called() + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_value", [None, ""]) +async def test_track_cost_callback_skips_for_falsy_model_and_no_slo(model_value): + """ + Same bug as above but model can also be empty string (e.g. health check callbacks). + The guard should catch all falsy model values when sl_object is missing. + """ + logger = _ProxyDBLogger() + + kwargs = { + "call_type": "acompletion", + "model": model_value, + "litellm_params": {}, + "stream": False, + } + + with patch( + "litellm.proxy.proxy_server.proxy_logging_obj", + ) as mock_proxy_logging: + mock_proxy_logging.failed_tracking_alert = AsyncMock() + mock_proxy_logging.db_spend_update_writer = MagicMock() + mock_proxy_logging.db_spend_update_writer.update_database = AsyncMock() + + await logger._PROXY_track_cost_callback( + kwargs=kwargs, + completion_response=None, + start_time=datetime.now(), + end_time=datetime.now(), + ) + + mock_proxy_logging.failed_tracking_alert.assert_not_called() diff --git a/tests/test_litellm/proxy/management_endpoints/test_key_management_endpoints.py b/tests/test_litellm/proxy/management_endpoints/test_key_management_endpoints.py index de2c940943b..2c526d340ef 100644 --- a/tests/test_litellm/proxy/management_endpoints/test_key_management_endpoints.py +++ b/tests/test_litellm/proxy/management_endpoints/test_key_management_endpoints.py @@ -5606,3 +5606,122 @@ async def test_validate_key_list_check_key_hash_not_found(): assert exc_info.value.code == "403" or exc_info.value.code == 403 assert "Key Hash not found" in exc_info.value.message + + +@pytest.mark.asyncio +@patch( + "litellm.proxy.management_endpoints.key_management_endpoints.rotate_mcp_server_credentials_master_key" +) +async def test_rotate_master_key_model_data_valid_for_prisma( + mock_rotate_mcp, +): + """ + Test that _rotate_master_key produces valid data for Prisma create_many(). + + Regression test for: master key rotation fails with Prisma validation error + because created_at/updated_at are None (non-nullable DateTime) and + litellm_params/model_info are JSON strings (create_many expects dicts). + """ + from unittest.mock import AsyncMock, MagicMock + from litellm.proxy._types import LitellmUserRoles, UserAPIKeyAuth + from litellm.proxy.management_endpoints.key_management_endpoints import ( + _rotate_master_key, + ) + + # Setup mock prisma client + mock_prisma_client = AsyncMock() + mock_prisma_client.db = MagicMock() + + # Mock model table — return one model + mock_model = MagicMock() + mock_model.model_id = "model-1" + mock_model.model_name = "test-model" + mock_model.litellm_params = '{"model": "openai/gpt-4", "api_key": "sk-encrypted-old"}' + mock_model.model_info = '{"id": "model-1"}' + mock_model.created_by = "admin" + mock_model.updated_by = "admin" + mock_prisma_client.db.litellm_proxymodeltable.find_many = AsyncMock( + return_value=[mock_model] + ) + + # Mock transaction context manager + mock_tx = AsyncMock() + mock_tx.litellm_proxymodeltable = MagicMock() + mock_tx.litellm_proxymodeltable.delete_many = AsyncMock() + mock_tx.litellm_proxymodeltable.create_many = AsyncMock() + mock_prisma_client.db.tx = MagicMock(return_value=AsyncMock( + __aenter__=AsyncMock(return_value=mock_tx), + __aexit__=AsyncMock(return_value=False), + )) + + # Mock config table — no env vars + mock_prisma_client.db.litellm_config.find_many = AsyncMock(return_value=[]) + + # Mock credentials table — no credentials + mock_prisma_client.db.litellm_credentialstable.find_many = AsyncMock( + return_value=[] + ) + + # Mock MCP rotation + mock_rotate_mcp.return_value = None + + # Mock proxy_config + mock_proxy_config = MagicMock() + mock_proxy_config.decrypt_model_list_from_db.return_value = [ + { + "model_name": "test-model", + "litellm_params": { + "model": "openai/gpt-4", + "api_key": "sk-decrypted-key", + }, + "model_info": {"id": "model-1"}, + } + ] + + user_api_key_dict = UserAPIKeyAuth( + user_role=LitellmUserRoles.PROXY_ADMIN, + api_key="sk-1234", + user_id="test-user", + ) + + with patch( + "litellm.proxy.proxy_server.proxy_config", + mock_proxy_config, + ): + await _rotate_master_key( + prisma_client=mock_prisma_client, + user_api_key_dict=user_api_key_dict, + current_master_key="sk-old-master-key", + new_master_key="sk-new-master-key", + ) + + # Verify create_many was called + mock_tx.litellm_proxymodeltable.create_many.assert_called_once() + + # Get the data passed to create_many + call_args = mock_tx.litellm_proxymodeltable.create_many.call_args + created_models = call_args.kwargs.get("data") or call_args[1].get("data") + + assert len(created_models) == 1 + model_data = created_models[0] + + # Verify timestamps are NOT present (Prisma @default(now()) should apply) + assert "created_at" not in model_data, ( + "created_at should be excluded so Prisma @default(now()) applies" + ) + assert "updated_at" not in model_data, ( + "updated_at should be excluded so Prisma @default(now()) applies" + ) + + # Verify litellm_params and model_info are prisma.Json wrappers, NOT JSON strings + import prisma + + assert isinstance(model_data["litellm_params"], prisma.Json), ( + f"litellm_params should be prisma.Json for create_many(), got {type(model_data['litellm_params'])}" + ) + assert isinstance(model_data["model_info"], prisma.Json), ( + f"model_info should be prisma.Json for create_many(), got {type(model_data['model_info'])}" + ) + + # Verify delete_many was called inside the transaction (before create_many) + mock_tx.litellm_proxymodeltable.delete_many.assert_called_once() diff --git a/tests/test_litellm/proxy/management_endpoints/test_ui_sso.py b/tests/test_litellm/proxy/management_endpoints/test_ui_sso.py index 74d36c0acac..09b78335054 100644 --- a/tests/test_litellm/proxy/management_endpoints/test_ui_sso.py +++ b/tests/test_litellm/proxy/management_endpoints/test_ui_sso.py @@ -2,12 +2,10 @@ import json import os import sys -from typing import Optional, cast from unittest.mock import AsyncMock, MagicMock, patch import pytest from fastapi import Request -from fastapi.testclient import TestClient from litellm._uuid import uuid @@ -16,7 +14,7 @@ ) # Adds the parent directory to the system path import litellm -from litellm.proxy._types import LiteLLM_UserTable, NewTeamRequest, NewUserResponse +from litellm.proxy._types import LiteLLM_UserTable, NewUserResponse from litellm.proxy.auth.handle_jwt import JWTHandler from litellm.proxy.management_endpoints.sso import CustomMicrosoftSSO from litellm.proxy.management_endpoints.types import CustomOpenID @@ -136,16 +134,32 @@ def test_microsoft_sso_handler_openid_from_response_with_custom_attributes(): expected_team_ids = ["team1"] # Act - with patch("litellm.constants.MICROSOFT_USER_EMAIL_ATTRIBUTE", "custom_email_field"), \ - patch("litellm.constants.MICROSOFT_USER_DISPLAY_NAME_ATTRIBUTE", "custom_display_name"), \ - patch("litellm.constants.MICROSOFT_USER_ID_ATTRIBUTE", "custom_id_field"), \ - patch("litellm.constants.MICROSOFT_USER_FIRST_NAME_ATTRIBUTE", "custom_first_name"), \ - patch("litellm.constants.MICROSOFT_USER_LAST_NAME_ATTRIBUTE", "custom_last_name"), \ - patch("litellm.proxy.management_endpoints.ui_sso.MICROSOFT_USER_EMAIL_ATTRIBUTE", "custom_email_field"), \ - patch("litellm.proxy.management_endpoints.ui_sso.MICROSOFT_USER_DISPLAY_NAME_ATTRIBUTE", "custom_display_name"), \ - patch("litellm.proxy.management_endpoints.ui_sso.MICROSOFT_USER_ID_ATTRIBUTE", "custom_id_field"), \ - patch("litellm.proxy.management_endpoints.ui_sso.MICROSOFT_USER_FIRST_NAME_ATTRIBUTE", "custom_first_name"), \ - patch("litellm.proxy.management_endpoints.ui_sso.MICROSOFT_USER_LAST_NAME_ATTRIBUTE", "custom_last_name"): + with patch( + "litellm.constants.MICROSOFT_USER_EMAIL_ATTRIBUTE", "custom_email_field" + ), patch( + "litellm.constants.MICROSOFT_USER_DISPLAY_NAME_ATTRIBUTE", "custom_display_name" + ), patch( + "litellm.constants.MICROSOFT_USER_ID_ATTRIBUTE", "custom_id_field" + ), patch( + "litellm.constants.MICROSOFT_USER_FIRST_NAME_ATTRIBUTE", "custom_first_name" + ), patch( + "litellm.constants.MICROSOFT_USER_LAST_NAME_ATTRIBUTE", "custom_last_name" + ), patch( + "litellm.proxy.management_endpoints.ui_sso.MICROSOFT_USER_EMAIL_ATTRIBUTE", + "custom_email_field", + ), patch( + "litellm.proxy.management_endpoints.ui_sso.MICROSOFT_USER_DISPLAY_NAME_ATTRIBUTE", + "custom_display_name", + ), patch( + "litellm.proxy.management_endpoints.ui_sso.MICROSOFT_USER_ID_ATTRIBUTE", + "custom_id_field", + ), patch( + "litellm.proxy.management_endpoints.ui_sso.MICROSOFT_USER_FIRST_NAME_ATTRIBUTE", + "custom_first_name", + ), patch( + "litellm.proxy.management_endpoints.ui_sso.MICROSOFT_USER_LAST_NAME_ATTRIBUTE", + "custom_last_name", + ): result = MicrosoftSSOHandler.openid_from_response( response=mock_response, team_ids=expected_team_ids, user_role=None ) @@ -231,7 +245,6 @@ def test_get_microsoft_callback_response_raw_sso_response(): ) # Assert - print("result from verify_and_process", result) assert isinstance(result, dict) assert result["mail"] == "microsoft_user@example.com" assert result["displayName"] == "Microsoft User" @@ -455,10 +468,6 @@ def mock_jsonify_team_object(db_data): # Assert # Verify team was created with correct parameters mock_prisma.db.litellm_teamtable.create.assert_called_once() - print( - "mock_prisma.db.litellm_teamtable.create.call_args", - mock_prisma.db.litellm_teamtable.create.call_args, - ) create_call_args = mock_prisma.db.litellm_teamtable.create.call_args.kwargs[ "data" ] @@ -583,7 +592,7 @@ def test_apply_user_info_values_to_sso_user_defined_values_with_models(): def test_apply_user_info_values_sso_role_takes_precedence(): """ Test that SSO role takes precedence over DB role. - + When Microsoft SSO returns a user_role, it should be used instead of the role stored in the database. This ensures SSO is the authoritative source for user roles. """ @@ -678,16 +687,16 @@ def test_normalize_email(): """ # Test with lowercase email assert normalize_email("test@example.com") == "test@example.com" - + # Test with uppercase email assert normalize_email("TEST@EXAMPLE.COM") == "test@example.com" - + # Test with mixed case email assert normalize_email("Test.User@Example.COM") == "test.user@example.com" - + # Test with None assert normalize_email(None) is None - + # Test with empty string assert normalize_email("") == "" @@ -900,7 +909,7 @@ async def test_upsert_sso_user_no_role_in_sso_response(): def test_get_user_email_and_id_extracts_microsoft_role(): """ Test that _get_user_email_and_id_from_result extracts user_role from Microsoft SSO. - + This ensures Microsoft SSO roles (from app_roles in id_token) are properly extracted and converted from enum to string. """ @@ -966,7 +975,7 @@ async def test_get_user_info_from_db_user_exists(): with patch( "litellm.proxy.management_endpoints.ui_sso.get_user_object" ) as mock_get_user_object: - user_info = await get_user_info_from_db(**args) + await get_user_info_from_db(**args) mock_get_user_object.assert_called_once() assert mock_get_user_object.call_args.kwargs["user_id"] == "krrishd" @@ -1008,7 +1017,7 @@ async def test_get_user_info_from_db_user_exists_alternate_user_id(): with patch( "litellm.proxy.management_endpoints.ui_sso.get_user_object" ) as mock_get_user_object: - user_info = await get_user_info_from_db(**args) + await get_user_info_from_db(**args) mock_get_user_object.assert_called_once() assert mock_get_user_object.call_args.kwargs["user_id"] == "krrishd-email1234" @@ -1017,7 +1026,7 @@ async def test_get_user_info_from_db_user_exists_alternate_user_id(): async def test_get_user_info_from_db_user_not_exists_creates_user(): """ Test that get_user_info_from_db creates a new user when user doesn't exist in DB. - + When get_existing_user_info_from_db returns None, get_user_info_from_db should: 1. Call upsert_sso_user with user_info=None 2. upsert_sso_user should call insert_sso_user to create the user @@ -1105,7 +1114,7 @@ async def test_get_user_info_from_db_user_not_exists_creates_user(): async def test_get_user_info_from_db_user_exists_updates_user(): """ Test that get_user_info_from_db updates existing user when user exists in DB. - + When get_existing_user_info_from_db returns a user, get_user_info_from_db should: 1. Call upsert_sso_user with the existing user_info 2. upsert_sso_user should update the user in the database @@ -1197,6 +1206,7 @@ async def test_get_user_info_from_db_user_exists_updates_user(): # Should return the updated user assert user_info == updated_user + @pytest.mark.asyncio async def test_check_and_update_if_proxy_admin_id(): """ @@ -1305,10 +1315,10 @@ async def test_get_generic_sso_response_with_additional_headers(): mock_sso_class = MagicMock(return_value=mock_sso_instance) with patch.dict(os.environ, test_env_vars): - with patch("fastapi_sso.sso.base.DiscoveryDocument") as mock_discovery: + with patch("fastapi_sso.sso.base.DiscoveryDocument"): with patch( "fastapi_sso.sso.generic.create_provider", return_value=mock_sso_class - ) as mock_create_provider: + ): # Act result, received_response = await get_generic_sso_response( request=mock_request, @@ -1367,10 +1377,10 @@ async def test_get_generic_sso_response_with_empty_headers(): mock_sso_class = MagicMock(return_value=mock_sso_instance) with patch.dict(os.environ, test_env_vars): - with patch("fastapi_sso.sso.base.DiscoveryDocument") as mock_discovery: + with patch("fastapi_sso.sso.base.DiscoveryDocument"): with patch( "fastapi_sso.sso.generic.create_provider", return_value=mock_sso_class - ) as mock_create_provider: + ): # Act result, received_response = await get_generic_sso_response( request=mock_request, @@ -1755,8 +1765,6 @@ def test_enterprise_import_error_handling(self): """Test that proper error is raised when enterprise module is not available""" from unittest.mock import MagicMock, patch - from litellm.proxy.management_endpoints.ui_sso import google_login - # Mock request mock_request = MagicMock() mock_request.base_url = "https://test.example.com/" @@ -1778,7 +1786,7 @@ async def mock_google_login(): # This mimics the relevant part of google_login that would trigger the import error try: from enterprise.litellm_enterprise.proxy.auth.custom_sso_handler import ( - EnterpriseCustomSSOHandler, + EnterpriseCustomSSOHandler, # noqa: F401 ) return "success" @@ -1982,59 +1990,56 @@ async def test_cli_sso_callback_stores_session(self): # Test data session_key = "sk-session-456" - + # Mock user info mock_user_info = LiteLLM_UserTable( user_id="test-user-123", user_role="internal_user", teams=["team1", "team2"], - models=["gpt-4"] + models=["gpt-4"], ) # Mock SSO result - mock_sso_result = { - "user_email": "test@example.com", - "user_id": "test-user-123" - } + mock_sso_result = {"user_email": "test@example.com", "user_id": "test-user-123"} # Mock cache mock_cache = MagicMock() - + with patch( "litellm.proxy.management_endpoints.ui_sso.get_user_info_from_db", - return_value=mock_user_info - ), patch( - "litellm.proxy.proxy_server.prisma_client", MagicMock() - ), patch( + return_value=mock_user_info, + ), patch("litellm.proxy.proxy_server.prisma_client", MagicMock()), patch( "litellm.proxy.proxy_server.user_api_key_cache", mock_cache ), patch( "litellm.proxy.common_utils.html_forms.cli_sso_success.render_cli_sso_success_page", return_value="Success", ): - # Act result = await cli_sso_callback( - request=mock_request, key=session_key, existing_key=None, result=mock_sso_result + request=mock_request, + key=session_key, + existing_key=None, + result=mock_sso_result, ) # Assert - verify session was stored in cache mock_cache.set_cache.assert_called_once() call_args = mock_cache.set_cache.call_args - + # Verify cache key format assert "cli_sso_session:" in call_args.kwargs["key"] assert session_key in call_args.kwargs["key"] - + # Verify session data structure session_data = call_args.kwargs["value"] assert session_data["user_id"] == "test-user-123" assert session_data["user_role"] == "internal_user" assert session_data["teams"] == ["team1", "team2"] assert session_data["models"] == ["gpt-4"] - + # Verify TTL assert call_args.kwargs["ttl"] == 600 # 10 minutes - + assert result.status_code == 200 # Verify response contains success message (response is HTML) assert result.body is not None @@ -2050,17 +2055,14 @@ async def test_cli_poll_key_returns_teams_for_selection(self): "user_id": "test-user-456", "user_role": "internal_user", "teams": ["team-a", "team-b", "team-c"], - "models": ["gpt-4"] + "models": ["gpt-4"], } # Mock cache mock_cache = MagicMock() mock_cache.get_cache.return_value = session_data - - with patch( - "litellm.proxy.proxy_server.user_api_key_cache", mock_cache - ): + with patch("litellm.proxy.proxy_server.user_api_key_cache", mock_cache): # Act - First poll without team_id result = await cli_poll_key(key_id=session_key, team_id=None) @@ -2070,7 +2072,7 @@ async def test_cli_poll_key_returns_teams_for_selection(self): assert result["user_id"] == "test-user-456" assert result["teams"] == ["team-a", "team-b", "team-c"] assert "key" not in result # JWT should not be generated yet - + # Verify session was NOT deleted mock_cache.delete_cache.assert_not_called() @@ -2174,34 +2176,33 @@ async def test_cli_poll_key_generates_jwt_with_team(self): "user_role": "internal_user", "teams": ["team-a", "team-b", "team-c"], "models": ["gpt-4"], - "user_email": "test@example.com" + "user_email": "test@example.com", } - + # Mock user info mock_user_info = LiteLLM_UserTable( user_id="test-user-789", user_role="internal_user", teams=["team-a", "team-b", "team-c"], - models=["gpt-4"] + models=["gpt-4"], ) # Mock cache mock_cache = MagicMock() mock_cache.get_cache.return_value = session_data - + mock_jwt_token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.test.token" - - with patch( - "litellm.proxy.proxy_server.user_api_key_cache", mock_cache - ), patch( + + with patch("litellm.proxy.proxy_server.user_api_key_cache", mock_cache), patch( "litellm.proxy.proxy_server.prisma_client" ) as mock_prisma, patch( "litellm.proxy.auth.auth_checks.ExperimentalUIJWTToken.get_cli_jwt_auth_token", - return_value=mock_jwt_token + return_value=mock_jwt_token, ) as mock_get_jwt: - # Mock the user lookup - mock_prisma.db.litellm_usertable.find_unique = AsyncMock(return_value=mock_user_info) + mock_prisma.db.litellm_usertable.find_unique = AsyncMock( + return_value=mock_user_info + ) # Act - Second poll with team_id result = await cli_poll_key(key_id=session_key, team_id=selected_team) @@ -2212,12 +2213,12 @@ async def test_cli_poll_key_generates_jwt_with_team(self): assert result["user_id"] == "test-user-789" assert result["team_id"] == selected_team assert result["teams"] == ["team-a", "team-b", "team-c"] - + # Verify JWT was generated with correct team mock_get_jwt.assert_called_once() jwt_call_args = mock_get_jwt.call_args assert jwt_call_args.kwargs["team_id"] == selected_team - + # Verify session was deleted after JWT generation mock_cache.delete_cache.assert_called_once() @@ -2227,7 +2228,6 @@ class TestGetAppRolesFromIdToken: def test_roles_picked_when_app_roles_not_exists(self): """Test that 'roles' is picked when 'app_roles' doesn't exist""" - import jwt # Create a token with only 'roles' claim token_payload = { @@ -2251,7 +2251,6 @@ def test_roles_picked_when_app_roles_not_exists(self): def test_app_roles_picked_when_both_exist(self): """Test that 'app_roles' takes precedence when both 'app_roles' and 'roles' exist""" - import jwt # Create a token with both 'app_roles' and 'roles' claims token_payload = { @@ -2272,7 +2271,6 @@ def test_app_roles_picked_when_both_exist(self): def test_roles_picked_when_app_roles_is_empty(self): """Test that 'roles' is picked when 'app_roles' exists but is empty""" - import jwt # Create a token with empty 'app_roles' and populated 'roles' token_payload = { @@ -2293,7 +2291,6 @@ def test_roles_picked_when_app_roles_is_empty(self): def test_empty_list_when_neither_exists(self): """Test that empty list is returned when neither 'app_roles' nor 'roles' exist""" - import jwt # Create a token without roles claims token_payload = {"sub": "user123", "email": "test@example.com"} @@ -2317,7 +2314,6 @@ def test_empty_list_when_no_token_provided(self): def test_empty_list_when_roles_not_a_list(self): """Test that empty list is returned when roles is not a list""" - import jwt # Create a token with non-list roles token_payload = { @@ -2337,7 +2333,6 @@ def test_empty_list_when_roles_not_a_list(self): def test_error_handling_on_jwt_decode_exception(self): """Test that exceptions during JWT decode are handled gracefully""" - import jwt mock_token = "invalid.jwt.token" @@ -2788,12 +2783,6 @@ def test_generic_response_convertor_with_nested_attributes(self): # to handle dotted paths like "attributes.userId" # Current behavior: returns None for nested paths - print(f"User ID result: {result.id}") - print(f"Email result: {result.email}") - print(f"First name result: {result.first_name}") - print(f"Last name result: {result.last_name}") - print(f"Display name result: {result.display_name}") - # Expected behavior with current implementation (no nested path support): assert result.id == "nested-user-456" assert ( @@ -2883,14 +2872,15 @@ def test_state_priority_cli_state_provided(self): # Arrange cli_state = "litellm-session-token:sk-test123" - + with patch.dict(os.environ, {"GENERIC_CLIENT_STATE": "env_state_value"}): # Act - redirect_params, code_verifier = ( - SSOAuthenticationHandler._get_generic_sso_redirect_params( - state=cli_state, - generic_authorization_endpoint="https://auth.example.com/authorize", - ) + ( + redirect_params, + code_verifier, + ) = SSOAuthenticationHandler._get_generic_sso_redirect_params( + state=cli_state, + generic_authorization_endpoint="https://auth.example.com/authorize", ) # Assert @@ -2905,14 +2895,15 @@ def test_state_priority_env_variable_when_no_cli_state(self): # Arrange env_state = "custom_env_state_value" - + with patch.dict(os.environ, {"GENERIC_CLIENT_STATE": env_state}): # Act - redirect_params, code_verifier = ( - SSOAuthenticationHandler._get_generic_sso_redirect_params( - state=None, - generic_authorization_endpoint="https://auth.example.com/authorize", - ) + ( + redirect_params, + code_verifier, + ) = SSOAuthenticationHandler._get_generic_sso_redirect_params( + state=None, + generic_authorization_endpoint="https://auth.example.com/authorize", ) # Assert @@ -2929,13 +2920,14 @@ def test_state_priority_generated_uuid_fallback(self): with patch.dict(os.environ, {}, clear=False): # Remove GENERIC_CLIENT_STATE if it exists os.environ.pop("GENERIC_CLIENT_STATE", None) - + # Act - redirect_params, code_verifier = ( - SSOAuthenticationHandler._get_generic_sso_redirect_params( - state=None, - generic_authorization_endpoint="https://auth.example.com/authorize", - ) + ( + redirect_params, + code_verifier, + ) = SSOAuthenticationHandler._get_generic_sso_redirect_params( + state=None, + generic_authorization_endpoint="https://auth.example.com/authorize", ) # Assert @@ -2955,26 +2947,27 @@ def test_state_with_pkce_enabled(self): # Arrange test_state = "test_state_123" - + with patch.dict(os.environ, {"GENERIC_CLIENT_USE_PKCE": "true"}): # Act - redirect_params, code_verifier = ( - SSOAuthenticationHandler._get_generic_sso_redirect_params( - state=test_state, - generic_authorization_endpoint="https://auth.example.com/authorize", - ) + ( + redirect_params, + code_verifier, + ) = SSOAuthenticationHandler._get_generic_sso_redirect_params( + state=test_state, + generic_authorization_endpoint="https://auth.example.com/authorize", ) # Assert state assert redirect_params["state"] == test_state - + # Assert PKCE parameters assert code_verifier is not None assert len(code_verifier) == 43 # Standard PKCE verifier length assert "code_challenge" in redirect_params assert "code_challenge_method" in redirect_params assert redirect_params["code_challenge_method"] == "S256" - + # Verify code_challenge is correctly derived from code_verifier expected_challenge_bytes = hashlib.sha256( code_verifier.encode("utf-8") @@ -2994,14 +2987,15 @@ def test_state_with_pkce_disabled(self): # Arrange test_state = "test_state_456" - + with patch.dict(os.environ, {"GENERIC_CLIENT_USE_PKCE": "false"}): # Act - redirect_params, code_verifier = ( - SSOAuthenticationHandler._get_generic_sso_redirect_params( - state=test_state, - generic_authorization_endpoint="https://auth.example.com/authorize", - ) + ( + redirect_params, + code_verifier, + ) = SSOAuthenticationHandler._get_generic_sso_redirect_params( + state=test_state, + generic_authorization_endpoint="https://auth.example.com/authorize", ) # Assert @@ -3019,7 +3013,7 @@ def test_state_priority_cli_state_overrides_env_with_pkce(self): # Arrange cli_state = "cli_state_priority" env_state = "env_state_should_not_be_used" - + with patch.dict( os.environ, { @@ -3028,17 +3022,18 @@ def test_state_priority_cli_state_overrides_env_with_pkce(self): }, ): # Act - redirect_params, code_verifier = ( - SSOAuthenticationHandler._get_generic_sso_redirect_params( - state=cli_state, - generic_authorization_endpoint="https://auth.example.com/authorize", - ) + ( + redirect_params, + code_verifier, + ) = SSOAuthenticationHandler._get_generic_sso_redirect_params( + state=cli_state, + generic_authorization_endpoint="https://auth.example.com/authorize", ) # Assert assert redirect_params["state"] == cli_state # CLI state takes priority assert redirect_params["state"] != env_state - + # PKCE should still be generated assert code_verifier is not None assert "code_challenge" in redirect_params @@ -3052,14 +3047,15 @@ def test_empty_string_state_uses_env_variable(self): # Arrange env_state = "env_state_for_empty_cli" - + with patch.dict(os.environ, {"GENERIC_CLIENT_STATE": env_state}): # Act - redirect_params, code_verifier = ( - SSOAuthenticationHandler._get_generic_sso_redirect_params( - state="", # Empty string - generic_authorization_endpoint="https://auth.example.com/authorize", - ) + ( + redirect_params, + code_verifier, + ) = SSOAuthenticationHandler._get_generic_sso_redirect_params( + state="", # Empty string + generic_authorization_endpoint="https://auth.example.com/authorize", ) # Assert - empty string is falsy, so env variable should be used @@ -3076,7 +3072,7 @@ def test_multiple_calls_generate_different_uuids(self): # Arrange - no state provided with patch.dict(os.environ, {}, clear=False): os.environ.pop("GENERIC_CLIENT_STATE", None) - + # Act params1, _ = SSOAuthenticationHandler._get_generic_sso_redirect_params( state=None, @@ -3139,15 +3135,18 @@ async def test_prepare_token_exchange_parameters_with_pkce(self): test_state = "test_oauth_state_123" mock_request.query_params = {"state": test_state} - # Mock cache + # Mock cache with async methods mock_cache = MagicMock() test_code_verifier = "test_code_verifier_abc123xyz" - mock_cache.get_cache.return_value = test_code_verifier + mock_cache.async_get_cache = AsyncMock(return_value=test_code_verifier) + mock_cache.async_delete_cache = AsyncMock() - with patch("litellm.proxy.proxy_server.user_api_key_cache", mock_cache): + with patch("litellm.proxy.proxy_server.redis_usage_cache", None), patch("litellm.proxy.proxy_server.user_api_key_cache", mock_cache): # Act - token_params = SSOAuthenticationHandler.prepare_token_exchange_parameters( - request=mock_request, generic_include_client_id=False + token_params = ( + await SSOAuthenticationHandler.prepare_token_exchange_parameters( + request=mock_request, generic_include_client_id=False + ) ) # Assert @@ -3155,10 +3154,10 @@ async def test_prepare_token_exchange_parameters_with_pkce(self): assert token_params["code_verifier"] == test_code_verifier # Verify cache was accessed and deleted - mock_cache.get_cache.assert_called_once_with( + mock_cache.async_get_cache.assert_called_once_with( key=f"pkce_verifier:{test_state}" ) - mock_cache.delete_cache.assert_called_once_with( + mock_cache.async_delete_cache.assert_called_once_with( key=f"pkce_verifier:{test_state}" ) @@ -3183,6 +3182,8 @@ async def test_get_generic_sso_redirect_response_with_pkce(self): test_state = "test456" mock_cache = MagicMock() + mock_cache.async_set_cache = AsyncMock() + with patch.dict(os.environ, {"GENERIC_CLIENT_USE_PKCE": "true"}): with patch("litellm.proxy.proxy_server.user_api_key_cache", mock_cache): # Act @@ -3193,9 +3194,9 @@ async def test_get_generic_sso_redirect_response_with_pkce(self): ) # Assert - # Verify cache was called to store code_verifier - mock_cache.set_cache.assert_called_once() - cache_call = mock_cache.set_cache.call_args + # Verify async cache was called to store code_verifier + mock_cache.async_set_cache.assert_called_once() + cache_call = mock_cache.async_set_cache.call_args assert cache_call.kwargs["key"] == f"pkce_verifier:{test_state}" assert cache_call.kwargs["ttl"] == 600 assert len(cache_call.kwargs["value"]) == 43 @@ -3207,6 +3208,178 @@ async def test_get_generic_sso_redirect_response_with_pkce(self): assert "code_challenge_method=S256" in updated_location assert f"state={test_state}" in updated_location + @pytest.mark.asyncio + async def test_pkce_redis_multi_pod_verifier_roundtrip(self): + """ + Mock Redis to verify PKCE code_verifier round-trip across "pods": + Pod A stores verifier in Redis; Pod B retrieves it (no real IdP). + """ + from litellm.proxy.management_endpoints.ui_sso import SSOAuthenticationHandler + + # In-memory mock of Redis (shared between "pods") + class MockRedisCache: + def __init__(self): + self._store = {} + + async def async_set_cache(self, key, value, **kwargs): + self._store[key] = json.dumps(value) + + async def async_get_cache(self, key, **kwargs): + val = self._store.get(key) + if val is None: + return None + # Simulate RedisCache._get_cache_logic: stored as JSON string, return decoded + if isinstance(val, str): + try: + return json.loads(val) + except (ValueError, TypeError): + return val + return val + + async def async_delete_cache(self, key): + self._store.pop(key, None) + + mock_redis = MockRedisCache() + mock_in_memory = MagicMock() + + mock_sso = MagicMock() + mock_redirect_response = MagicMock() + mock_redirect_response.headers = { + "location": "https://auth.example.com/authorize?state=multi_pod_state_xyz&client_id=abc" + } + mock_sso.get_login_redirect = AsyncMock(return_value=mock_redirect_response) + mock_sso.__enter__ = MagicMock(return_value=mock_sso) + mock_sso.__exit__ = MagicMock(return_value=False) + + with patch.dict(os.environ, {"GENERIC_CLIENT_USE_PKCE": "true"}): + with patch("litellm.proxy.proxy_server.redis_usage_cache", mock_redis): + with patch( + "litellm.proxy.proxy_server.user_api_key_cache", mock_in_memory + ): + # Pod A: start login, store code_verifier in "Redis" + await SSOAuthenticationHandler.get_generic_sso_redirect_response( + generic_sso=mock_sso, + state="multi_pod_state_xyz", + generic_authorization_endpoint="https://auth.example.com/authorize", + ) + mock_in_memory.async_set_cache.assert_not_called() + # MockRedisCache is a real class; assert on state, not .assert_called_* + stored_key = "pkce_verifier:multi_pod_state_xyz" + assert stored_key in mock_redis._store + stored_value = mock_redis._store[stored_key] + assert isinstance(stored_value, str) and len(json.loads(stored_value)) == 43 + + # Pod B: callback with same state, retrieve from "Redis" + mock_request = MagicMock(spec=Request) + mock_request.query_params = {"state": "multi_pod_state_xyz"} + token_params = await SSOAuthenticationHandler.prepare_token_exchange_parameters( + request=mock_request, generic_include_client_id=False + ) + assert "code_verifier" in token_params + assert token_params["code_verifier"] == json.loads(stored_value) + mock_in_memory.async_get_cache.assert_not_called() + # delete_cache called; key removed (asserted below) + + # Verifier consumed (single-use); key removed from "Redis" + assert "pkce_verifier:multi_pod_state_xyz" not in mock_redis._store + + @pytest.mark.asyncio + async def test_pkce_fallback_in_memory_roundtrip_when_redis_none(self): + """ + Regression: When redis_usage_cache is None (no Redis configured), + code_verifier is stored and retrieved via user_api_key_cache. + Roundtrip works when callback hits same pod (same in-memory cache). + Single-pod or no-Redis deployments must continue to work. + """ + from litellm.proxy.management_endpoints.ui_sso import SSOAuthenticationHandler + + # In-memory store (simulates user_api_key_cache on one pod) + in_memory_store = {} + + async def async_set_cache(key, value, **kwargs): + in_memory_store[key] = value + + async def async_get_cache(key, **kwargs): + return in_memory_store.get(key) + + async def async_delete_cache(key): + in_memory_store.pop(key, None) + + mock_in_memory = MagicMock() + mock_in_memory.async_set_cache = AsyncMock(side_effect=async_set_cache) + mock_in_memory.async_get_cache = AsyncMock(side_effect=async_get_cache) + mock_in_memory.async_delete_cache = AsyncMock(side_effect=async_delete_cache) + + mock_sso = MagicMock() + mock_redirect_response = MagicMock() + mock_redirect_response.headers = { + "location": "https://auth.example.com/authorize?state=fallback_state_xyz&client_id=abc" + } + mock_sso.get_login_redirect = AsyncMock(return_value=mock_redirect_response) + mock_sso.__enter__ = MagicMock(return_value=mock_sso) + mock_sso.__exit__ = MagicMock(return_value=False) + + with patch.dict(os.environ, {"GENERIC_CLIENT_USE_PKCE": "true"}): + with patch("litellm.proxy.proxy_server.redis_usage_cache", None): + with patch( + "litellm.proxy.proxy_server.user_api_key_cache", mock_in_memory + ): + # Pod A: start login, store code_verifier in in-memory cache + await SSOAuthenticationHandler.get_generic_sso_redirect_response( + generic_sso=mock_sso, + state="fallback_state_xyz", + generic_authorization_endpoint="https://auth.example.com/authorize", + ) + mock_in_memory.async_set_cache.assert_called_once() + stored_key = mock_in_memory.async_set_cache.call_args.kwargs["key"] + stored_value = mock_in_memory.async_set_cache.call_args.kwargs[ + "value" + ] + assert stored_key == "pkce_verifier:fallback_state_xyz" + assert isinstance(stored_value, str) and len(stored_value) == 43 + + # Same pod: callback retrieves from in-memory cache + mock_request = MagicMock(spec=Request) + mock_request.query_params = {"state": "fallback_state_xyz"} + token_params = await SSOAuthenticationHandler.prepare_token_exchange_parameters( + request=mock_request, generic_include_client_id=False + ) + assert "code_verifier" in token_params + assert token_params["code_verifier"] == stored_value + mock_in_memory.async_get_cache.assert_called_once_with( + key=stored_key + ) + mock_in_memory.async_delete_cache.assert_called_once_with( + key=stored_key + ) + + # Verifier consumed; key removed from in-memory + assert "pkce_verifier:fallback_state_xyz" not in in_memory_store + + @pytest.mark.asyncio + async def test_pkce_prepare_token_exchange_returns_nothing_when_no_state(self): + """ + Regression: prepare_token_exchange_parameters with no state in request + does not call cache and does not add code_verifier. + """ + from litellm.proxy.management_endpoints.ui_sso import SSOAuthenticationHandler + + mock_redis = MagicMock() + mock_in_memory = MagicMock() + + with patch("litellm.proxy.proxy_server.redis_usage_cache", mock_redis): + with patch("litellm.proxy.proxy_server.user_api_key_cache", mock_in_memory): + mock_request = MagicMock(spec=Request) + mock_request.query_params = {} + token_params = ( + await SSOAuthenticationHandler.prepare_token_exchange_parameters( + request=mock_request, generic_include_client_id=False + ) + ) + assert "code_verifier" not in token_params + mock_redis.async_get_cache.assert_not_called() + mock_in_memory.async_get_cache.assert_not_called() + # Tests for SSO user team assignment bug (Issue: SSO Users Not Added to Entra-Synced Teams on First Login) class TestAddMissingTeamMember: @@ -3330,9 +3503,7 @@ async def test_sso_first_login_full_flow_adds_user_to_teams(self): team_member_calls = [] async def track_team_member_add(team_id, user_info): - team_member_calls.append( - {"team_id": team_id, "user_id": user_info.user_id} - ) + team_member_calls.append({"team_id": team_id, "user_id": user_info.user_id}) # New SSO user with Entra groups new_user = NewUserResponse( @@ -3393,7 +3564,6 @@ async def test_add_missing_team_member_handles_all_user_types( """ Parametrized test ensuring add_missing_team_member works for all user types. """ - from litellm.proxy._types import LiteLLM_UserTable from litellm.proxy.management_endpoints.ui_sso import add_missing_team_member user_info = user_info_factory("test-user-id") @@ -3483,7 +3653,7 @@ async def test_role_mappings_override_default_internal_user_params(): return_value=mock_new_user_response, ) as mock_new_user: # Act - result = await insert_sso_user( + _ = await insert_sso_user( result_openid=mock_result_openid, user_defined_values=user_defined_values, ) @@ -3505,7 +3675,7 @@ async def test_role_mappings_override_default_internal_user_params(): assert ( new_user_request.budget_duration == "30d" ), "budget_duration from default_internal_user_params should be applied" - + # Note: models are applied via _update_internal_new_user_params inside new_user, # not in insert_sso_user, so we verify user_defined_values was updated correctly # by checking that the function completed successfully and other defaults were applied @@ -3620,7 +3790,10 @@ async def test_sso_readiness_google_missing_secret(self): assert data["sso_configured"] is True assert data["provider"] == "google" assert "GOOGLE_CLIENT_SECRET" in data["missing_environment_variables"] - assert "Google SSO is configured but missing required environment variables" in data["message"] + assert ( + "Google SSO is configured but missing required environment variables" + in data["message"] + ) finally: app.dependency_overrides.clear() @@ -3669,7 +3842,7 @@ async def test_sso_readiness_microsoft_configurations( response = client.get("/sso/readiness") assert response.status_code == expected_status - + if expected_status == 200: data = response.json() assert data["sso_configured"] is True @@ -3739,7 +3912,7 @@ async def test_sso_readiness_generic_configurations( response = client.get("/sso/readiness") assert response.status_code == expected_status - + if expected_status == 200: data = response.json() assert data["sso_configured"] is True @@ -3784,8 +3957,14 @@ async def test_custom_microsoft_sso_uses_default_endpoints_when_no_env_vars(self discovery = await sso.get_discovery_document() - assert discovery["authorization_endpoint"] == "https://login.microsoftonline.com/test-tenant/oauth2/v2.0/authorize" - assert discovery["token_endpoint"] == "https://login.microsoftonline.com/test-tenant/oauth2/v2.0/token" + assert ( + discovery["authorization_endpoint"] + == "https://login.microsoftonline.com/test-tenant/oauth2/v2.0/authorize" + ) + assert ( + discovery["token_endpoint"] + == "https://login.microsoftonline.com/test-tenant/oauth2/v2.0/token" + ) assert discovery["userinfo_endpoint"] == "https://graph.microsoft.com/v1.0/me" @pytest.mark.asyncio @@ -3849,8 +4028,13 @@ async def test_custom_microsoft_sso_uses_partial_custom_endpoints(self): # Custom auth endpoint assert discovery["authorization_endpoint"] == custom_auth_endpoint # Default token and userinfo endpoints - assert discovery["token_endpoint"] == "https://login.microsoftonline.com/test-tenant/oauth2/v2.0/token" - assert discovery["userinfo_endpoint"] == "https://graph.microsoft.com/v1.0/me" + assert ( + discovery["token_endpoint"] + == "https://login.microsoftonline.com/test-tenant/oauth2/v2.0/token" + ) + assert ( + discovery["userinfo_endpoint"] == "https://graph.microsoft.com/v1.0/me" + ) def test_custom_microsoft_sso_uses_common_tenant_when_none(self): """ @@ -3887,11 +4071,7 @@ async def test_setup_team_mappings(): # Arrange mock_prisma = MagicMock() mock_sso_config = MagicMock() - mock_sso_config.sso_settings = { - "team_mappings": { - "team_ids_jwt_field": "groups" - } - } + mock_sso_config.sso_settings = {"team_mappings": {"team_ids_jwt_field": "groups"}} mock_prisma.db.litellm_ssoconfig.find_unique = AsyncMock( return_value=mock_sso_config ) diff --git a/tests/test_litellm/proxy/test_proxy_cli.py b/tests/test_litellm/proxy/test_proxy_cli.py index 12065ad5b4d..be91800732b 100644 --- a/tests/test_litellm/proxy/test_proxy_cli.py +++ b/tests/test_litellm/proxy/test_proxy_cli.py @@ -218,6 +218,12 @@ def test_database_url_construction_with_special_characters(self): assert "connection_limit=10" in modified_url assert "pool_timeout=60" in modified_url + def test_append_query_params_handles_missing_url(self): + from litellm.proxy.proxy_cli import append_query_params + + modified_url = append_query_params(None, {"connection_limit": 10}) + assert modified_url == "" + @patch("uvicorn.run") @patch("atexit.register") # 🔥 critical def test_skip_server_startup(self, mock_atexit_register, mock_uvicorn_run): diff --git a/tests/test_litellm/test_cost_calculation_log_level.py b/tests/test_litellm/test_cost_calculation_log_level.py index 3925ea751af..8ee9ad95cd0 100644 --- a/tests/test_litellm/test_cost_calculation_log_level.py +++ b/tests/test_litellm/test_cost_calculation_log_level.py @@ -3,25 +3,39 @@ import os import sys -import pytest - sys.path.insert(0, os.path.abspath("../../..")) import litellm from litellm import completion_cost -def test_cost_calculation_uses_debug_level(caplog): +def test_cost_calculation_uses_debug_level(): """ Test that cost calculation logs use DEBUG level instead of INFO. This ensures cost calculation details don't appear in production logs. Part of fix for issue #9815. + + Note: This test uses a custom log handler instead of caplog because + caplog doesn't work reliably with pytest-xdist parallel execution. """ - # Ensure verbose_logger is set to DEBUG level to capture the debug logs from litellm._logging import verbose_logger + + # Create a custom handler to capture log records + class LogRecordHandler(logging.Handler): + def __init__(self): + super().__init__() + self.records = [] + + def emit(self, record): + self.records.append(record) + + # Set up custom handler + handler = LogRecordHandler() + handler.setLevel(logging.DEBUG) original_level = verbose_logger.level verbose_logger.setLevel(logging.DEBUG) - + verbose_logger.addHandler(handler) + try: # Create a mock completion response mock_response = { @@ -40,72 +54,87 @@ def test_cost_calculation_uses_debug_level(caplog): "total_tokens": 30 } } - - # Test that cost calculation logs are at DEBUG level - with caplog.at_level(logging.DEBUG, logger="LiteLLM"): - try: - cost = completion_cost( - completion_response=mock_response, - model="gpt-3.5-turbo" - ) - except Exception: - pass # Cost calculation may fail, but we're checking log levels - + + # Call completion_cost to trigger logs + try: + cost = completion_cost( + completion_response=mock_response, + model="gpt-3.5-turbo" + ) + except Exception: + pass # Cost calculation may fail, but we're checking log levels + # Find the cost calculation log records cost_calc_records = [ - record for record in caplog.records + record for record in handler.records if "selected model name for cost calculation" in record.message ] - + # Verify that cost calculation logs are at DEBUG level assert len(cost_calc_records) > 0, "No cost calculation logs found" - + for record in cost_calc_records: assert record.levelno == logging.DEBUG, \ f"Cost calculation log should be DEBUG level, but was {record.levelname}" finally: - # Restore original logger level + # Clean up: remove handler and restore original logger level + verbose_logger.removeHandler(handler) verbose_logger.setLevel(original_level) -def test_batch_cost_calculation_uses_debug_level(caplog): +def test_batch_cost_calculation_uses_debug_level(): """ Test that batch cost calculation logs also use DEBUG level. + + Note: This test uses a custom log handler instead of caplog because + caplog doesn't work reliably with pytest-xdist parallel execution. """ from litellm.cost_calculator import batch_cost_calculator from litellm.types.utils import Usage from litellm._logging import verbose_logger - - # Ensure verbose_logger is set to DEBUG level to capture the debug logs + + # Create a custom handler to capture log records + class LogRecordHandler(logging.Handler): + def __init__(self): + super().__init__() + self.records = [] + + def emit(self, record): + self.records.append(record) + + # Set up custom handler + handler = LogRecordHandler() + handler.setLevel(logging.DEBUG) original_level = verbose_logger.level verbose_logger.setLevel(logging.DEBUG) - + verbose_logger.addHandler(handler) + try: # Create a mock usage object usage = Usage(prompt_tokens=100, completion_tokens=200, total_tokens=300) - - # Test that batch cost calculation logs are at DEBUG level - with caplog.at_level(logging.DEBUG, logger="LiteLLM"): - try: - batch_cost_calculator( - usage=usage, - model="gpt-3.5-turbo", - custom_llm_provider="openai" - ) - except Exception: - pass # May fail, but we're checking log levels - + + # Call batch_cost_calculator to trigger logs + try: + batch_cost_calculator( + usage=usage, + model="gpt-3.5-turbo", + custom_llm_provider="openai" + ) + except Exception: + pass # May fail, but we're checking log levels + # Find batch cost calculation log records batch_cost_records = [ - record for record in caplog.records + record for record in handler.records if "Calculating batch cost per token" in record.message ] - + # Verify logs exist and are at DEBUG level if batch_cost_records: # May not always log depending on the code path for record in batch_cost_records: assert record.levelno == logging.DEBUG, \ f"Batch cost calculation log should be DEBUG level, but was {record.levelname}" finally: - # Restore original logger level - verbose_logger.setLevel(original_level) \ No newline at end of file + # Clean up: remove handler and restore original logger level + verbose_logger.removeHandler(handler) + verbose_logger.setLevel(original_level) diff --git a/tests/test_litellm/test_service_logger.py b/tests/test_litellm/test_service_logger.py new file mode 100644 index 00000000000..ed44fe9b9f2 --- /dev/null +++ b/tests/test_litellm/test_service_logger.py @@ -0,0 +1,97 @@ +""" +Tests for litellm/_service_logger.py + +Regression test for KeyError: 'call_type' when async_log_success_event +is called without call_type in kwargs (e.g. from batch polling callbacks). +""" + +import pytest +from datetime import datetime, timedelta +from unittest.mock import AsyncMock, patch + +from litellm._service_logger import ServiceLogging + + +@pytest.mark.asyncio +async def test_async_log_success_event_should_not_raise_when_call_type_missing(): + """ + When async_log_success_event is called with kwargs that omit 'call_type', + it should not raise a KeyError. This happens in the batch polling flow + where check_batch_cost.py creates a Logging object whose model_call_details + don't include call_type. + """ + service_logger = ServiceLogging(mock_testing=True) + + start_time = datetime(2026, 2, 13, 22, 35, 0) + end_time = datetime(2026, 2, 13, 22, 35, 1) + kwargs_without_call_type = {"model": "gpt-4", "stream": False} + + with patch.object( + service_logger, "async_service_success_hook", new_callable=AsyncMock + ) as mock_hook: + await service_logger.async_log_success_event( + kwargs=kwargs_without_call_type, + response_obj=None, + start_time=start_time, + end_time=end_time, + ) + + mock_hook.assert_called_once() + call_kwargs = mock_hook.call_args + assert call_kwargs.kwargs["call_type"] == "unknown" + + +@pytest.mark.asyncio +async def test_async_log_success_event_should_pass_call_type_when_present(): + """ + When call_type IS present in kwargs, it should be forwarded correctly. + """ + service_logger = ServiceLogging(mock_testing=True) + + start_time = datetime(2026, 2, 13, 22, 35, 0) + end_time = datetime(2026, 2, 13, 22, 35, 1) + kwargs_with_call_type = { + "model": "gpt-4", + "stream": False, + "call_type": "aretrieve_batch", + } + + with patch.object( + service_logger, "async_service_success_hook", new_callable=AsyncMock + ) as mock_hook: + await service_logger.async_log_success_event( + kwargs=kwargs_with_call_type, + response_obj=None, + start_time=start_time, + end_time=end_time, + ) + + mock_hook.assert_called_once() + call_kwargs = mock_hook.call_args + assert call_kwargs.kwargs["call_type"] == "aretrieve_batch" + + +@pytest.mark.asyncio +async def test_async_log_success_event_should_handle_float_duration(): + """ + When start_time and end_time produce a float duration (not timedelta), + it should still work correctly. + """ + service_logger = ServiceLogging(mock_testing=True) + + start_time = 1000.0 + end_time = 1001.5 + + with patch.object( + service_logger, "async_service_success_hook", new_callable=AsyncMock + ) as mock_hook: + await service_logger.async_log_success_event( + kwargs={"call_type": "completion"}, + response_obj=None, + start_time=start_time, + end_time=end_time, + ) + + mock_hook.assert_called_once() + call_kwargs = mock_hook.call_args + assert call_kwargs.kwargs["duration"] == 1.5 diff --git a/tests/test_litellm/test_video_generation.py b/tests/test_litellm/test_video_generation.py index d4150c349f4..121bf1a1f03 100644 --- a/tests/test_litellm/test_video_generation.py +++ b/tests/test_litellm/test_video_generation.py @@ -731,50 +731,56 @@ async def async_log_success_event(self, kwargs, response_obj, start_time, end_ti @pytest.mark.asyncio async def test_video_generation_logging(self): - """Test that video generation creates proper logging payload with cost tracking.""" + """Test that video generation creates proper logging payload with cost tracking. + + Note: Uses AsyncMock with side_effect pattern for reliable parallel execution. + """ custom_logger = self.TestVideoLogger() litellm.logging_callback_manager._reset_all_callbacks() litellm.callbacks = [custom_logger] - + # Mock video generation response mock_response = VideoObject( id="video_test_123", - object="video", + object="video", status="queued", created_at=1712697600, model="sora-2", size="720x1280", seconds="8" ) - - with patch('litellm.videos.main.base_llm_http_handler') as mock_handler: - mock_handler.video_generation_handler.return_value = mock_response - + + # Create async mock function to return the mock_response + async def mock_async_handler(*args, **kwargs): + return mock_response + + # Patch the async_video_generation_handler method on base_llm_http_handler + with patch.object(videos_main.base_llm_http_handler, 'async_video_generation_handler', side_effect=mock_async_handler): response = await litellm.avideo_generation( prompt="A cat running in a garden", model="sora-2", seconds="8", size="720x1280" ) - + await asyncio.sleep(1) # Allow logging to complete - + # Verify logging payload was created assert custom_logger.standard_logging_payload is not None - + payload = custom_logger.standard_logging_payload - + # Verify basic logging fields assert payload["call_type"] == "avideo_generation" assert payload["status"] == "success" assert payload["model"] == "sora-2" assert payload["custom_llm_provider"] == "openai" - + # Verify response object is recognized for logging assert payload["response"] is not None assert payload["response"]["id"] == "video_test_123" assert payload["response"]["object"] == "video" - + # Verify cost tracking is present (may be 0 in test environment) assert payload["response_cost"] is not None # Note: Cost calculation may not work in test environment due to mocking diff --git a/ui/litellm-dashboard/src/components/guardrails/add_guardrail_form.tsx b/ui/litellm-dashboard/src/components/guardrails/add_guardrail_form.tsx index 71b61904dd5..0aad42feb08 100644 --- a/ui/litellm-dashboard/src/components/guardrails/add_guardrail_form.tsx +++ b/ui/litellm-dashboard/src/components/guardrails/add_guardrail_form.tsx @@ -109,6 +109,7 @@ const AddGuardrailForm: React.FC = ({ visible, onClose, a const [selectedPatterns, setSelectedPatterns] = useState([]); const [blockedWords, setBlockedWords] = useState([]); const [selectedContentCategories, setSelectedContentCategories] = useState([]); + const [pendingCategorySelection, setPendingCategorySelection] = useState(""); const [toolPermissionConfig, setToolPermissionConfig] = useState({ rules: [], default_action: "deny", @@ -169,6 +170,12 @@ const AddGuardrailForm: React.FC = ({ visible, onClose, a setGlobalSeverityThreshold(2); setCategorySpecificThresholds({}); + // Reset Content Filter selections + setSelectedPatterns([]); + setBlockedWords([]); + setSelectedContentCategories([]); + setPendingCategorySelection(""); + setToolPermissionConfig({ rules: [], default_action: "deny", @@ -247,6 +254,39 @@ const AddGuardrailForm: React.FC = ({ visible, onClose, a setCurrentStep(currentStep - 1); }; + const handleAddAndContinue = () => { + if (!pendingCategorySelection || !guardrailSettings) return; + + const contentFilterSettings = guardrailSettings.content_filter_settings; + if (!contentFilterSettings) return; + + const category = contentFilterSettings.content_categories?.find((c) => c.name === pendingCategorySelection); + if (!category) return; + + // Check if already added + if (selectedContentCategories.some((c) => c.category === pendingCategorySelection)) { + setPendingCategorySelection(""); + setCurrentStep(currentStep + 1); + return; + } + + // Add the category + setSelectedContentCategories([ + ...selectedContentCategories, + { + id: `category-${Date.now()}`, + category: category.name, + display_name: category.display_name, + action: category.default_action as "BLOCK" | "MASK", + severity_threshold: "medium", + }, + ]); + + // Clear pending selection and advance to next step + setPendingCategorySelection(""); + setCurrentStep(currentStep + 1); + }; + const resetForm = () => { form.resetFields(); setSelectedProvider(null); @@ -258,6 +298,7 @@ const AddGuardrailForm: React.FC = ({ visible, onClose, a setSelectedPatterns([]); setBlockedWords([]); setSelectedContentCategories([]); + setPendingCategorySelection(""); setToolPermissionConfig({ rules: [], default_action: "deny", @@ -324,6 +365,15 @@ const AddGuardrailForm: React.FC = ({ visible, onClose, a // For Content Filter, add patterns, blocked words, and categories if (shouldRenderContentFilterConfigSettings(values.provider)) { + // Validate that at least one content filter setting is configured + if (selectedPatterns.length === 0 && blockedWords.length === 0 && selectedContentCategories.length === 0) { + NotificationsManager.fromBackend( + "Please configure at least one content filter setting (category, pattern, or keyword)" + ); + setLoading(false); + return; + } + if (selectedPatterns.length > 0) { guardrailData.litellm_params.patterns = selectedPatterns.map((p) => ({ pattern_type: p.type === "prebuilt" ? "prebuilt" : "regex", @@ -658,6 +708,8 @@ const AddGuardrailForm: React.FC = ({ visible, onClose, a selectedContentCategories.map((c) => (c.id === id ? { ...c, [field]: value } : c)) ); }} + pendingCategorySelection={pendingCategorySelection} + onPendingCategorySelectionChange={setPendingCategorySelection} accessToken={accessToken} showStep={step} /> @@ -720,6 +772,8 @@ const AddGuardrailForm: React.FC = ({ visible, onClose, a const renderStepButtons = () => { const totalSteps = shouldRenderContentFilterConfigSettings(selectedProvider) ? 4 : 2; const isLastStep = currentStep === totalSteps - 1; + const isCategoriesStep = shouldRenderContentFilterConfigSettings(selectedProvider) && currentStep === 1; + const hasPendingCategory = pendingCategorySelection !== ""; return (
@@ -728,11 +782,30 @@ const AddGuardrailForm: React.FC = ({ visible, onClose, a Previous )} - {!isLastStep && } - {isLastStep && ( - + {isCategoriesStep ? ( + <> + + + + ) : ( + <> + {!isLastStep && ( + + )} + {isLastStep && ( + + )} + )} + ), + } as any); + } + + if (categories.length === 0) { + return ( +
+ No categories configured. +
+ ); + } + + return ( + + ); +}; + +export default CategoryTable; diff --git a/ui/litellm-dashboard/src/components/guardrails/content_filter/ContentCategoryConfiguration.tsx b/ui/litellm-dashboard/src/components/guardrails/content_filter/ContentCategoryConfiguration.tsx index 0f8a02220b7..5ac5c70cd36 100644 --- a/ui/litellm-dashboard/src/components/guardrails/content_filter/ContentCategoryConfiguration.tsx +++ b/ui/litellm-dashboard/src/components/guardrails/content_filter/ContentCategoryConfiguration.tsx @@ -28,6 +28,8 @@ interface ContentCategoryConfigurationProps { onCategoryRemove: (id: string) => void; onCategoryUpdate: (id: string, field: string, value: any) => void; accessToken?: string | null; + pendingSelection?: string; + onPendingSelectionChange?: (value: string) => void; } const ContentCategoryConfiguration: React.FC = ({ @@ -37,8 +39,13 @@ const ContentCategoryConfiguration: React.FC onCategoryRemove, onCategoryUpdate, accessToken, + pendingSelection, + onPendingSelectionChange, }) => { - const [selectedCategoryName, setSelectedCategoryName] = React.useState(""); + // Use controlled state if parent provides it, otherwise use local state + const [localSelectedCategoryName, setLocalSelectedCategoryName] = React.useState(""); + const selectedCategoryName = pendingSelection !== undefined ? pendingSelection : localSelectedCategoryName; + const setSelectedCategoryName = onPendingSelectionChange || setLocalSelectedCategoryName; const [categoryYaml, setCategoryYaml] = React.useState<{ [key: string]: string }>({}); const [categoryFileTypes, setCategoryFileTypes] = React.useState<{ [key: string]: string }>({}); const [loadingYaml, setLoadingYaml] = React.useState<{ [key: string]: boolean }>({}); diff --git a/ui/litellm-dashboard/src/components/guardrails/content_filter/ContentFilterConfiguration.tsx b/ui/litellm-dashboard/src/components/guardrails/content_filter/ContentFilterConfiguration.tsx index 882abc0b933..5715b3c136b 100644 --- a/ui/litellm-dashboard/src/components/guardrails/content_filter/ContentFilterConfiguration.tsx +++ b/ui/litellm-dashboard/src/components/guardrails/content_filter/ContentFilterConfiguration.tsx @@ -69,6 +69,8 @@ interface ContentFilterConfigurationProps { onContentCategoryAdd?: (category: SelectedContentCategory) => void; onContentCategoryRemove?: (id: string) => void; onContentCategoryUpdate?: (id: string, field: string, value: any) => void; + pendingCategorySelection?: string; + onPendingCategorySelectionChange?: (value: string) => void; } const ContentFilterConfiguration: React.FC = ({ @@ -90,6 +92,8 @@ const ContentFilterConfiguration: React.FC = ({ onContentCategoryAdd, onContentCategoryRemove, onContentCategoryUpdate, + pendingCategorySelection, + onPendingCategorySelectionChange, }) => { const [patternModalVisible, setPatternModalVisible] = useState(false); const [keywordModalVisible, setKeywordModalVisible] = useState(false); @@ -278,6 +282,8 @@ const ContentFilterConfiguration: React.FC = ({ onCategoryRemove={onContentCategoryRemove} onCategoryUpdate={onContentCategoryUpdate} accessToken={accessToken} + pendingSelection={pendingCategorySelection} + onPendingSelectionChange={onPendingCategorySelectionChange} /> )} diff --git a/ui/litellm-dashboard/src/components/guardrails/content_filter/ContentFilterDisplay.tsx b/ui/litellm-dashboard/src/components/guardrails/content_filter/ContentFilterDisplay.tsx index 0c1e12d8860..db7345fa361 100644 --- a/ui/litellm-dashboard/src/components/guardrails/content_filter/ContentFilterDisplay.tsx +++ b/ui/litellm-dashboard/src/components/guardrails/content_filter/ContentFilterDisplay.tsx @@ -2,6 +2,7 @@ import React from "react"; import { Card, Text, Badge } from "@tremor/react"; import PatternTable from "./PatternTable"; import KeywordTable from "./KeywordTable"; +import CategoryTable from "./CategoryTable"; interface Pattern { id: string; @@ -19,26 +20,42 @@ interface BlockedWord { description?: string; } +interface ContentCategory { + id: string; + category: string; + display_name: string; + action: "BLOCK" | "MASK"; + severity_threshold: "high" | "medium" | "low"; +} + interface ContentFilterDisplayProps { patterns: Pattern[]; blockedWords: BlockedWord[]; + categories?: ContentCategory[]; readOnly?: boolean; onPatternActionChange?: (id: string, action: "BLOCK" | "MASK") => void; onPatternRemove?: (id: string) => void; onBlockedWordUpdate?: (id: string, field: string, value: any) => void; onBlockedWordRemove?: (id: string) => void; + onCategoryActionChange?: (id: string, action: "BLOCK" | "MASK") => void; + onCategorySeverityChange?: (id: string, severity: "high" | "medium" | "low") => void; + onCategoryRemove?: (id: string) => void; } const ContentFilterDisplay: React.FC = ({ patterns, blockedWords, + categories = [], readOnly = true, onPatternActionChange, onPatternRemove, onBlockedWordUpdate, onBlockedWordRemove, + onCategoryActionChange, + onCategorySeverityChange, + onCategoryRemove, }) => { - if (patterns.length === 0 && blockedWords.length === 0) { + if (patterns.length === 0 && blockedWords.length === 0 && categories.length === 0) { return null; } @@ -47,6 +64,22 @@ const ContentFilterDisplay: React.FC = ({ return ( <> + {categories.length > 0 && ( + +
+ Content Categories + {categories.length} categories configured +
+ +
+ )} + {patterns.length > 0 && (
diff --git a/ui/litellm-dashboard/src/components/guardrails/content_filter/ContentFilterManager.tsx b/ui/litellm-dashboard/src/components/guardrails/content_filter/ContentFilterManager.tsx index aa23c0e1db0..1070453425b 100644 --- a/ui/litellm-dashboard/src/components/guardrails/content_filter/ContentFilterManager.tsx +++ b/ui/litellm-dashboard/src/components/guardrails/content_filter/ContentFilterManager.tsx @@ -158,7 +158,14 @@ const ContentFilterManager: React.FC = ({ // Read-only display mode if (!isEditing) { - return ; + return ( + + ); } // Edit mode diff --git a/ui/litellm-dashboard/src/components/organisms/regenerate_key_modal.tsx b/ui/litellm-dashboard/src/components/organisms/regenerate_key_modal.tsx index a4339e11920..2fad101c20f 100644 --- a/ui/litellm-dashboard/src/components/organisms/regenerate_key_modal.tsx +++ b/ui/litellm-dashboard/src/components/organisms/regenerate_key_modal.tsx @@ -37,6 +37,7 @@ export function RegenerateKeyModal({ selectedToken, visible, onClose, onKeyUpdat tpm_limit: selectedToken.tpm_limit, rpm_limit: selectedToken.rpm_limit, duration: selectedToken.duration || "", + grace_period: "", }); // Initialize the current access token @@ -223,6 +224,23 @@ export function RegenerateKeyModal({ selectedToken, visible, onClose, onKeyUpdat Current expiry: {selectedToken?.expires ? new Date(selectedToken.expires).toLocaleString() : "Never"}
{newExpiryTime &&
New expiry: {newExpiryTime}
} + + + +
+ Recommended: 24h to 72h for production keys to allow seamless client migration. +
)}