diff --git a/core/bifrost.go b/core/bifrost.go index 99c92a8e3a..91071567e5 100644 --- a/core/bifrost.go +++ b/core/bifrost.go @@ -5526,16 +5526,24 @@ func (bifrost *Bifrost) requestWorker(provider schemas.Provider, config *schemas // Step 1: compute effective value for each flag (provider config ← per-request override). effectiveSendBackReq := config.SendBackRawRequest - if override, ok := req.Context.Value(schemas.BifrostContextKeySendBackRawRequest).(bool); ok { - effectiveSendBackReq = override + allowRawOverride, _ := req.Context.Value(schemas.BifrostContextKeyAllowPerRequestRawOverride).(bool) + if allowRawOverride { + if override, ok := req.Context.Value(schemas.BifrostContextKeySendBackRawRequest).(bool); ok { + effectiveSendBackReq = override + } } effectiveSendBackResp := config.SendBackRawResponse - if override, ok := req.Context.Value(schemas.BifrostContextKeySendBackRawResponse).(bool); ok { - effectiveSendBackResp = override + if allowRawOverride { + if override, ok := req.Context.Value(schemas.BifrostContextKeySendBackRawResponse).(bool); ok { + effectiveSendBackResp = override + } } effectiveStore := config.StoreRawRequestResponse - if override, ok := req.Context.Value(schemas.BifrostContextKeyStoreRawRequestResponse).(bool); ok { - effectiveStore = override + allowStorageOverride, _ := req.Context.Value(schemas.BifrostContextKeyAllowPerRequestStorageOverride).(bool) + if allowStorageOverride { + if override, ok := req.Context.Value(schemas.BifrostContextKeyStoreRawRequestResponse).(bool); ok { + effectiveStore = override + } } // Step 2: derive per-side capture and strip flags. 
diff --git a/core/providers/utils/utils.go b/core/providers/utils/utils.go index fc66309189..e96b3061c0 100644 --- a/core/providers/utils/utils.go +++ b/core/providers/utils/utils.go @@ -2787,7 +2787,8 @@ func CheckAndSetDefaultProvider(ctx *schemas.BifrostContext, defaultProvider sch if slices.Contains(availableProviders, defaultProvider) { return defaultProvider } - return "" + // Return the first available provider + return availableProviders[0] } return defaultProvider } diff --git a/core/schemas/bifrost.go b/core/schemas/bifrost.go index 953b05b21e..b28d9dcaef 100644 --- a/core/schemas/bifrost.go +++ b/core/schemas/bifrost.go @@ -248,6 +248,9 @@ const ( BifrostContextKeyRealtimeEventType BifrostContextKey = "bifrost-realtime-event-type" // string BifrostIsAsyncRequest BifrostContextKey = "bifrost-is-async-request" // bool (set by bifrost - DO NOT SET THIS MANUALLY)) - whether the request is an async request (only used in gateway) BifrostContextKeyRequestHeaders BifrostContextKey = "bifrost-request-headers" // map[string]string (all request headers with lowercased keys) + BifrostContextKeyAllowPerRequestStorageOverride BifrostContextKey = "bifrost-allow-per-request-storage-override" // bool (set by transport from config — gates whether x-bf-disable-content-logging and x-bf-store-raw-request-response per-request overrides are honored) + BifrostContextKeyAllowPerRequestRawOverride BifrostContextKey = "bifrost-allow-per-request-raw-override" // bool (set by transport from config — gates whether x-bf-send-back-raw-request and x-bf-send-back-raw-response per-request overrides are honored) + BifrostContextKeyDisableContentLogging BifrostContextKey = "x-bf-disable-content-logging" // bool (per-request override for content logging; only honored when BifrostContextKeyAllowPerRequestStorageOverride is true) BifrostContextKeySkipListModelsGovernanceFiltering BifrostContextKey = "bifrost-skip-list-models-governance-filtering" // bool (set by bifrost - DO NOT SET THIS 
MANUALLY)) BifrostContextKeySCIMClaims BifrostContextKey = "scim_claims" BifrostContextKeyUserID BifrostContextKey = "bifrost-user-id" // string (to store the user ID (set by enterprise auth middleware - DO NOT SET THIS MANUALLY)) @@ -299,6 +302,7 @@ const ( RoutingEngineGovernance = "governance" RoutingEngineRoutingRule = "routing-rule" RoutingEngineLoadbalancing = "loadbalancing" + RoutingEngineModelCatalog = "model-catalog" ) // KeyAttemptRecord captures the outcome of a single request attempt within executeRequestWithRetries. diff --git a/docs/architecture/framework/model-catalog.mdx b/docs/architecture/framework/model-catalog.mdx index 26b138e070..30ef3627ec 100644 --- a/docs/architecture/framework/model-catalog.mdx +++ b/docs/architecture/framework/model-catalog.mdx @@ -7,26 +7,35 @@ icon: "book-open" The Model Catalog is a foundational component of Bifrost that provides a unified interface for managing AI models, including their pricing, capabilities, and availability. It serves as a centralized repository for all model-related information, enabling dynamic cost calculation, intelligent model routing, and efficient resource management. -**Related Documentation**: The Model Catalog powers Bifrost's intelligent routing system. See [Provider Routing](/providers/provider-routing) for detailed examples of how governance and load balancing use the catalog to make routing decisions, including cross-provider scenarios and weighted routing via proxy providers. + **Related Documentation**: The Model Catalog powers Bifrost's intelligent + routing system. See [Provider Routing](/providers/provider-routing) for + detailed examples of how governance and load balancing use the catalog to make + routing decisions, including cross-provider scenarios and weighted routing via + proxy providers. ## Core Features ### **1. 
Automatic Pricing Synchronization** + The Model Catalog manages pricing data through a two-phase approach: **Startup Behavior:** + - **With ConfigStore**: Downloads a pricing sheet from Maxim's datasheet, persists it to the config store, and then loads it into memory for fast lookups. - **Without ConfigStore**: Downloads the pricing sheet directly into memory on every startup. **Ongoing Synchronization:** + - When ConfigStore is available, an automatic sync occurs every 24 hours to keep pricing data current. - All pricing data is cached in memory for O(1) lookup performance during cost calculations. This ensures that cost calculations always use the latest pricing information from AI providers while maintaining optimal performance. ### **2. Multi-Modal Cost Calculation** + It supports diverse pricing models across different AI operation types: + - **Text Operations**: Token-based pricing for chat completions, text completions, responses, and embeddings. Cache-read/cache-write pricing applies to chat/text/responses when providers surface prompt cache token details. - **Audio Processing**: Character-based, token-based, and duration-based pricing for speech synthesis and transcription, with audio token detail breakdown. Speech responses populate `usage.input_chars` so speech can be billed by input characters in addition to tokens/duration. - **Image Processing**: Per-image (`input_cost_per_image`/`output_cost_per_image`), per-pixel (`input_cost_per_pixel`/`output_cost_per_pixel`), or token-based pricing with text/image token breakdown. @@ -35,17 +44,22 @@ It supports diverse pricing models across different AI operation types: - **Prompt Caching**: Separate rates for cache-read tokens (`cached_read_tokens`) and cache-creation tokens (`cached_write_tokens`), both surfaced under `prompt_tokens_details` (see [Prompt Cache Cost Calculation](#prompt-cache-cost-calculation)). ### **3. 
Model Information Management** + The Model Catalog maintains a pool of available models for each provider, populated from both pricing data and provider list models APIs. This enables: + - **Model Discovery**: Listing all available models for a given provider - **Provider Discovery**: Finding all providers that support a specific model with intelligent cross-provider resolution (OpenRouter, Vertex, Groq, Bedrock) - **Model Validation**: Checking if a model is allowed for a provider based on allowed models lists (supports provider-prefixed entries) ### **4. Intelligent Cache Cost Handling** + It integrates with semantic caching to provide accurate cost calculations: + - **Cache Hits**: Zero cost for direct cache hits, and embedding cost only for semantic matches. - **Cache Misses**: Combined cost of the base model usage plus the embedding generation cost for cache storage. ### **5. Tiered Pricing Support** + The system automatically applies different pricing rates for high-token contexts, reflecting real provider pricing models. Two tiers are supported: above 128k tokens and above 200k tokens, with the higher tier taking precedence when both are configured. ## Configuration @@ -74,6 +88,7 @@ modelCatalog, err := modelcatalog.Init(context.Background(), config, configStore ## Architecture ### ModelCatalog + The `ModelCatalog` is the central component that handles all model and pricing operations: ```go @@ -100,6 +115,7 @@ type ModelCatalog struct { ``` ### Pricing Data Structure + Each model's pricing information includes comprehensive cost metrics, supporting various modalities and tiered pricing: ```go @@ -166,10 +182,14 @@ type PricingEntry struct { The Model Catalog is designed to be shared across all Bifrost plugins, providing consistent model information and validation logic for governance, load balancing, and other routing mechanisms. 
-**Governance & Load Balancing**: Both plugins delegate model validation to the Model Catalog's `IsModelAllowedForProvider` method, ensuring consistent handling of cross-provider scenarios and provider-prefixed allowed models. See [Provider Routing](/providers/provider-routing) for configuration examples. + **Governance & Load Balancing**: Both plugins delegate model validation to the + Model Catalog's `IsModelAllowedForProvider` method, ensuring consistent + handling of cross-provider scenarios and provider-prefixed allowed models. See + [Provider Routing](/providers/provider-routing) for configuration examples. ### Initialization + In Bifrost's gateway, the `ModelCatalog` is initialized once at the start and shared across all plugins: ```go @@ -183,6 +203,7 @@ if err != nil { ``` ### Basic Cost Calculation + Calculate costs from a Bifrost response: ```go @@ -196,6 +217,7 @@ logger.Info("Request cost: $%.6f", cost) ``` ### Unified Cost Calculation + `CalculateCost` is the single entry point for all cost calculations. It handles all request types, semantic cache billing, and tiered pricing automatically: ```go @@ -208,10 +230,13 @@ cost := modelCatalog.CalculateCost(result, nil) // *schemas.BifrostResponse, *Pr ``` ### Model Discovery + The `ModelCatalog` provides several methods to query for model and provider information. #### Get Models for a Provider + Retrieve a list of all models supported by a specific provider. + ```go openaiModels := modelCatalog.GetModelsForProvider(schemas.OpenAI) for _, model := range openaiModels { @@ -222,6 +247,7 @@ for _, model := range openaiModels { **Thread-safe**: Uses read lock for concurrent access. #### Get Providers for a Model + Find all providers that offer a specific model, including cross-provider resolution. ```go @@ -244,6 +270,7 @@ This method implements intelligent cross-provider routing logic to discover all 5. 
**Bedrock Claude Models**: For Claude models, flexible matching against Bedrock's full ARN format **Example**: + ```go providers := modelCatalog.GetProvidersForModel("claude-3-5-sonnet") // Returns: [anthropic, vertex, bedrock, openrouter] @@ -251,10 +278,14 @@ providers := modelCatalog.GetProvidersForModel("claude-3-5-sonnet") ``` -This cross-provider logic powers Bifrost's intelligent routing capabilities. See [Provider Routing](/providers/provider-routing#the-model-catalog) for detailed examples of how this enables features like weighted routing via proxy providers. + This cross-provider logic powers Bifrost's intelligent routing capabilities. + See [Provider Routing](/providers/provider-routing#the-model-catalog) for + detailed examples of how this enables features like weighted routing via proxy + providers. #### Check Model Allowance for Provider + Validate if a model is allowed for a specific provider based on an allowed models list. This method is used internally by governance and load balancing plugins. ```go @@ -284,39 +315,55 @@ isAllowed := modelCatalog.IsModelAllowedForProvider( ``` **Behavior**: + - **`["*"]` wildcard**: Delegates to `GetProvidersForModel` (includes cross-provider logic) — this is the "allow all via catalog" mode - **Non-empty explicit list**: Checks for both direct matches and provider-prefixed entries - **Empty slice (`[]string{}` / empty `schemas.WhiteList`)**: Returns `false` (deny-all) — mirrors the config deny-by-default semantics -In `config.json` and the governance API, `allowed_models: []` (empty array) means **deny all models** (deny-by-default, v1.5.0+). The Go helper `IsModelAllowedForProvider` behaves the same way: an empty `allowedModels` slice also returns `false`. Use `["*"]` to allow all models validated through the catalog. + In `config.json` and the governance API, `allowed_models: []` (empty array) + means **deny all models** (deny-by-default, v1.5.0+). 
The Go helper + `IsModelAllowedForProvider` behaves the same way: an empty `allowedModels` + slice also returns `false`. Use `["*"]` to allow all models validated through + the catalog. - - Direct: `"gpt-4o"` matches `"gpt-4o"` - - Prefixed: `"openai/gpt-4o"` matches request for `"gpt-4o"` (prefix stripped) +- Direct: `"gpt-4o"` matches `"gpt-4o"` +- Prefixed: `"openai/gpt-4o"` matches request for `"gpt-4o"` (prefix stripped) **Use Cases**: + - **Governance Routing**: Validate if a model request is allowed for a provider configuration - **Load Balancing**: Filter providers based on allowed models before performance scoring - **Virtual Key Validation**: Check if a model can be used with a specific virtual key's provider configs -This method is the central validation point for both governance and load balancing plugins, ensuring consistent model allowance logic across all routing mechanisms. It handles all edge cases including proxy providers (OpenRouter, Vertex) and provider-prefixed model entries. + This method is the central validation point for both governance and load + balancing plugins, ensuring consistent model allowance logic across all + routing mechanisms. It handles all edge cases including proxy providers + (OpenRouter, Vertex) and provider-prefixed model entries. #### Dynamically Add Models + You can dynamically add models to the catalog's pool from a `v1/models` compatible response structure. This is useful for providers that expose a model list endpoint. + ```go // response is *schemas.BifrostListModelsResponse modelCatalog.AddModelDataToPool(response) ``` + This is automatically done in Bifrost gateway initialization for all providers that are supported by Bifrost. 
**When to use**: + - After fetching models from a provider's `/v1/models` endpoint - When a new provider is dynamically added at runtime - For testing with custom model lists + ### Reloading Configuration + You can reload the pricing configuration at runtime if you need to change the pricing URL or sync interval. + ```go newConfig := &modelcatalog.Config{ PricingSyncInterval: 12 * time.Hour, @@ -383,7 +430,6 @@ func (mc *ModelCatalog) getPricing(model, provider string, requestType schemas.R // This ensures operations continue smoothly without billing failures. ``` - ## Cleanup and Lifecycle Management Properly clean up resources when shutting down: diff --git a/docs/features/semantic-caching.mdx b/docs/features/semantic-caching.mdx index d63211cfab..f25747c720 100644 --- a/docs/features/semantic-caching.mdx +++ b/docs/features/semantic-caching.mdx @@ -118,7 +118,6 @@ import ( cacheConfig := &semanticcache.Config{ // Embedding model configuration (Required) Provider: schemas.OpenAI, - Keys: []schemas.Key{{Value: "sk-..."}}, EmbeddingModel: "text-embedding-3-small", Dimension: 1536, @@ -155,22 +154,32 @@ bifrostConfig := schemas.BifrostConfig{ ![Semantic Cache Plugin Configuration](../media/ui-semantic-cache-config.png) -**Note**: Make sure you have a vector store setup (using `config.json`) before configuring the semantic cache plugin. +**Prerequisites**: A vector store must be configured and enabled in `config.json`, and at least one provider must be configured, before the toggle becomes available. -1. **Navigate to Settings** - - Open Bifrost UI at `http://localhost:8080` - - Go to Settings. +1. **Navigate to the Config page** in the Bifrost UI and find the **Plugins** section. -2. **Configure Semantic Cache Plugin** +2. **Toggle** the **Enable Semantic Caching** switch to enable it. The configuration form expands below. -- Toggle the plugin switch to enable it, and fill in the required fields. +3. 
**Fill in the fields** across the four sections: -**Required Fields:** -- **Provider**: The provider to use for caching. -- **Embedding Model**: The embedding model to use for caching. -- **Dimension**: The embedding dimension for the configured embedding model. +**Provider and Model Settings** (required for semantic mode): +- **Configured Providers**: Dropdown of providers already set up in Bifrost. The selected provider's API keys are inherited automatically. +- **Embedding Model**: The embedding model to use (e.g. `text-embedding-3-small`). -**Note**: Changes will need a restart of the Bifrost server to take effect, because the plugin is loaded on startup only. +**Cache Settings**: +- **TTL (seconds)**: How long cached responses are kept (default: 300 s). +- **Similarity Threshold**: Cosine similarity cutoff for a cache hit (0–1, default: 0.8). +- **Dimension**: Vector dimension matching your embedding model (e.g. 1536 for `text-embedding-3-small`). + +**Conversation Settings**: +- **Conversation History Threshold**: Skip caching when the conversation has more than this many messages (default: 3). +- **Exclude System Prompt** (toggle): Exclude system messages from cache-key generation. + +**Cache Behavior**: +- **Cache by Model** (toggle): Include the model name in the cache key (default: on). +- **Cache by Provider** (toggle): Include the provider name in the cache key (default: on). + +4. Click **Save**. Changes are persisted and applied immediately for enabled plugins via the API reload path; other plugin changes (e.g. via `config.json`) may still require a restart. @@ -202,7 +211,7 @@ bifrostConfig := schemas.BifrostConfig{ } ``` -> **Note**: In `config.json` setups, provider keys are taken from the provider config on initialization, so you do not need to duplicate `keys` inside the plugin config. Any updates to the provider keys will not be reflected until next restart. 
+> **Note**: Provider API keys are inherited automatically from the global provider configuration. You do not need to (and cannot) specify keys inside the plugin config. **TTL Format Options:** - Duration strings: `"30s"`, `"5m"`, `"1h"`, `"24h"` @@ -228,7 +237,7 @@ Exact-match direct entries are stored and retrieved using a deterministic cache ### Setup -To enable direct-only mode globally, set `dimension: 1` and omit the `provider` and `keys` fields from the plugin config. The plugin will automatically fall back to direct search only. +To enable direct-only mode globally, set `dimension: 1` and omit the `provider` and `embedding_model` fields from the plugin config. The plugin will automatically fall back to direct search only. > **Important**: If you specify `dimension: 1` and also provide a `provider`, Bifrost treats the config as provider-backed semantic mode, not direct-only mode. To use direct-only mode, omit the `provider` field entirely. @@ -246,7 +255,7 @@ import ( ) cacheConfig := &semanticcache.Config{ - // No Provider, Keys, or EmbeddingModel -- direct hash mode only + // No Provider or EmbeddingModel -- direct hash mode only Dimension: 1, // Placeholder; entries are stored as metadata-only (no embedding vectors). Change dimension before switching to dual-layer mode to avoid mixed-dimension issues. TTL: 5 * time.Minute, diff --git a/docs/features/telemetry.mdx b/docs/features/telemetry.mdx index f816e30b4f..f8e2395902 100644 --- a/docs/features/telemetry.mdx +++ b/docs/features/telemetry.mdx @@ -9,6 +9,7 @@ icon: "gauge" Bifrost provides built-in telemetry and monitoring capabilities through Prometheus metrics collection. The telemetry system tracks both HTTP-level performance metrics and upstream provider interactions, giving you complete visibility into your AI gateway's performance and usage patterns. 
**Key Features:** + - **Prometheus Integration** - Native metrics collection at `/metrics` endpoint - **Comprehensive Tracking** - Success/error rates, token usage, costs, and cache performance - **Custom Labels** - Configurable dimensions for detailed analysis @@ -28,14 +29,15 @@ The telemetry plugin operates asynchronously to ensure metrics collection doesn' These metrics track all incoming HTTP requests to Bifrost: -| Metric | Type | Description | -|--------|------|-------------| -| `http_requests_total` | Counter | Total number of HTTP requests | -| `http_request_duration_seconds` | Histogram | Duration of HTTP requests | -| `http_request_size_bytes` | Histogram | Size of incoming HTTP requests | -| `http_response_size_bytes` | Histogram | Size of outgoing HTTP responses | +| Metric | Type | Description | +| ------------------------------- | --------- | ------------------------------- | +| `http_requests_total` | Counter | Total number of HTTP requests | +| `http_request_duration_seconds` | Histogram | Duration of HTTP requests | +| `http_request_size_bytes` | Histogram | Size of incoming HTTP requests | +| `http_response_size_bytes` | Histogram | Size of outgoing HTTP responses | Labels: + - `path`: HTTP endpoint path - `method`: HTTP verb (e.g., `GET`, `POST`, `PUT`, `DELETE`) - `status`: HTTP status code @@ -45,24 +47,25 @@ Labels: These metrics track requests forwarded to AI providers: -| Metric | Type | Description | Labels | -|--------|------|-------------|---------| -| `bifrost_upstream_requests_total` | Counter | Total requests forwarded to upstream providers | Base Labels, custom labels | -| `bifrost_success_requests_total` | Counter | Total successful requests to upstream providers | Base Labels, custom labels | -| `bifrost_error_requests_total` | Counter | Total failed requests to upstream providers | Base Labels, `status_code`, custom labels | -| `bifrost_upstream_latency_seconds` | Histogram | Latency of upstream provider requests | Base Labels, 
`is_success`, custom labels | -| `bifrost_input_tokens_total` | Counter | Total input tokens sent to upstream providers | Base Labels, custom labels | -| `bifrost_output_tokens_total` | Counter | Total output tokens received from upstream providers | Base Labels, custom labels | -| `bifrost_cache_hits_total` | Counter | Total cache hits by type (direct/semantic) | Base Labels, `cache_type`, custom labels | -| `bifrost_cost_total` | Counter | Total cost in USD for upstream provider requests | Base Labels, custom labels | +| Metric | Type | Description | Labels | +| ---------------------------------- | --------- | ---------------------------------------------------- | ----------------------------------------- | +| `bifrost_upstream_requests_total` | Counter | Total requests forwarded to upstream providers | Base Labels, custom labels | +| `bifrost_success_requests_total` | Counter | Total successful requests to upstream providers | Base Labels, custom labels | +| `bifrost_error_requests_total` | Counter | Total failed requests to upstream providers | Base Labels, `status_code`, custom labels | +| `bifrost_upstream_latency_seconds` | Histogram | Latency of upstream provider requests | Base Labels, `is_success`, custom labels | +| `bifrost_input_tokens_total` | Counter | Total input tokens sent to upstream providers | Base Labels, custom labels | +| `bifrost_output_tokens_total` | Counter | Total output tokens received from upstream providers | Base Labels, custom labels | +| `bifrost_cache_hits_total` | Counter | Total cache hits by type (direct/semantic) | Base Labels, `cache_type`, custom labels | +| `bifrost_cost_total` | Counter | Total cost in USD for upstream provider requests | Base Labels, custom labels | Base Labels: + - `provider`: AI provider name (e.g., `openai`, `anthropic`, `azure`) - `model`: Model name (e.g., `gpt-4o-mini`, `claude-3-sonnet`) - `method`: Request type (`chat`, `text`, `embedding`, `speech`, `transcription`) - `virtual_key_id`: Virtual 
key ID - `virtual_key_name`: Virtual key name -- `routing_engines_used`: Comma-separated routing engines used ("routing-rule", "governance", "loadbalancing") +- `routing_engines_used`: Comma-separated routing engines used ("routing-rule", "governance", "loadbalancing", "model-catalog") - `routing_rule_id`: Routing rule ID that matched the request - `routing_rule_name`: Routing rule name that matched the request - `selected_key_id`: ID of the key that successfully served the request (`null` on final errors) @@ -75,40 +78,43 @@ Base Labels: These metrics capture latency characteristics specific to streaming responses: -| Metric | Type | Description | Labels | -|--------|------|-------------|---------| +| Metric | Type | Description | Labels | +| -------------------------------------------- | --------- | ----------------------------------------------- | ----------- | | `bifrost_stream_first_token_latency_seconds` | Histogram | Time from request start to first streamed token | Base Labels | -| `bifrost_stream_inter_token_latency_seconds` | Histogram | Latency between subsequent streamed tokens | Base Labels | +| `bifrost_stream_inter_token_latency_seconds` | Histogram | Latency between subsequent streamed tokens | Base Labels | --- ## Monitoring Examples ### Success Rate Monitoring + Track the success rate of requests to different providers: ```promql # Success rate by provider -rate(bifrost_success_requests_total[5m]) / +rate(bifrost_success_requests_total[5m]) / rate(bifrost_upstream_requests_total[5m]) * 100 ``` ### Token Usage Analysis + Monitor token consumption across different models: ```promql # Input tokens per minute by model increase(bifrost_input_tokens_total[1m]) -# Output tokens per minute by model +# Output tokens per minute by model increase(bifrost_output_tokens_total[1m]) # Token efficiency (output/input ratio) -rate(bifrost_output_tokens_total[5m]) / +rate(bifrost_output_tokens_total[5m]) / rate(bifrost_input_tokens_total[5m]) ``` ### Cost Tracking + 
Monitor spending across providers and models: ```promql @@ -119,16 +125,17 @@ sum by (provider) (rate(bifrost_cost_total[1m])) sum by (provider) (increase(bifrost_cost_total[1d])) # Cost per request by provider and model -sum by (provider, model) (rate(bifrost_cost_total[5m])) / +sum by (provider, model) (rate(bifrost_cost_total[5m])) / sum by (provider, model) (rate(bifrost_upstream_requests_total[5m])) ``` ### Cache Performance + Track cache effectiveness: ```promql # Cache hit rate by type -rate(bifrost_cache_hits_total[5m]) / +rate(bifrost_cache_hits_total[5m]) / rate(bifrost_upstream_requests_total[5m]) * 100 # Direct vs semantic cache hits @@ -136,11 +143,12 @@ sum by (cache_type) (rate(bifrost_cache_hits_total[5m])) ``` ### Error Rate Analysis + Monitor error patterns: ```promql # Error rate by provider -rate(bifrost_error_requests_total[5m]) / +rate(bifrost_error_requests_total[5m]) / rate(bifrost_upstream_requests_total[5m]) * 100 # Errors by model @@ -216,6 +224,7 @@ curl -X POST http://localhost:8080/v1/chat/completions \ ``` **Header Format:** + - Prefix: `x-bf-prom-` - Label name: Any string after the prefix - Value: String value for the label @@ -242,7 +251,9 @@ docker-compose up -d ``` -**Development Only**: The provided Docker Compose setup is for testing purposes only. Do not use in production without proper security, scaling, and persistence configuration. + **Development Only**: The provided Docker Compose setup is for testing + purposes only. Do not use in production without proper security, scaling, and + persistence configuration. You can use the Prometheus scraping endpoint to create your own Grafana dashboards. Given below are few examples created using the Docker Compose setup. @@ -259,6 +270,7 @@ For production environments: 4. 
**Configure alerts** based on your SLA requirements **Prometheus Scrape Configuration:** + ```yaml scrape_configs: - job_name: "bifrost-gateway" @@ -273,7 +285,10 @@ scrape_configs: ``` - If you have Bifrost authentication enabled (`auth_config`), you must include `basic_auth` in the scrape config with your `admin_username` and `admin_password`. See the [Prometheus docs](/features/observability/prometheus#pull-based-scraping) for details. + If you have Bifrost authentication enabled (`auth_config`), you must include + `basic_auth` in the scrape config with your `admin_username` and + `admin_password`. See the [Prometheus + docs](/features/observability/prometheus#pull-based-scraping) for details. ### Production Alerting Examples @@ -281,6 +296,7 @@ scrape_configs: Configure alerts for critical scenarios using the new metrics: **High Error Rate Alert:** + ```yaml - alert: BifrostHighErrorRate expr: sum by (provider) (rate(bifrost_error_requests_total[5m])) / sum by (provider) (rate(bifrost_upstream_requests_total[5m])) > 0.05 @@ -292,17 +308,19 @@ Configure alerts for critical scenarios using the new metrics: ``` **High Cost Alert:** + ```yaml - alert: BifrostHighCosts - expr: sum by (provider) (increase(bifrost_cost_total[1d])) > 100 # $100/day threshold + expr: sum by (provider) (increase(bifrost_cost_total[1d])) > 100 # $100/day threshold for: 10m labels: severity: warning annotations: - summary: "Daily cost for provider {{ $labels.provider }} exceeds $100 ({{ $value | printf \"%.2f\" }})" + summary: 'Daily cost for provider {{ $labels.provider }} exceeds $100 ({{ $value | printf "%.2f" }})' ``` **Cache Performance Alert:** + ```yaml - alert: BifrostLowCacheHitRate expr: sum by (provider) (rate(bifrost_cache_hits_total[15m])) / sum by (provider) (rate(bifrost_upstream_requests_total[15m])) < 0.1 diff --git a/docs/openapi/schemas/management/config.yaml b/docs/openapi/schemas/management/config.yaml index eaafb3821f..40237e52ab 100644 --- 
a/docs/openapi/schemas/management/config.yaml +++ b/docs/openapi/schemas/management/config.yaml @@ -31,6 +31,14 @@ ClientConfig: disable_content_logging: type: boolean description: Whether content logging is disabled + allow_per_request_content_storage_override: + type: boolean + default: false + description: Allow individual requests to override content storage via the x-bf-disable-content-logging header or context key. When false (default), per-request overrides are ignored. + allow_per_request_raw_override: + type: boolean + default: false + description: Allow individual requests to override raw request/response visibility via the x-bf-send-back-raw-request and x-bf-send-back-raw-response headers. When false (default), provider-level settings are authoritative and per-request overrides are ignored. enforce_auth_on_inference: type: boolean description: Whether to enforce virtual key authentication on inference requests diff --git a/docs/providers/provider-routing.mdx b/docs/providers/provider-routing.mdx index 7a341a8e40..141be0d248 100644 --- a/docs/providers/provider-routing.mdx +++ b/docs/providers/provider-routing.mdx @@ -14,9 +14,9 @@ Bifrost offers two powerful methods for routing requests across AI providers, ea When both methods are available, **governance takes precedence** because users have explicitly defined their routing preferences through provider configurations on Virtual Keys. 
-**When to use which method:** -- Use **Governance** when you need explicit control, compliance requirements, or specific cost optimization strategies -- Use **Adaptive Load Balancing** for automatic performance optimization and minimal configuration overhead + **When to use which method:** + - Use **Governance** when you need explicit control, compliance requirements, or specific cost optimization strategies + - Use **Adaptive Load Balancing** for automatic performance optimization and minimal configuration overhead --- @@ -26,7 +26,10 @@ When both methods are available, **governance takes precedence** because users h The Model Catalog is Bifrost's central registry that tracks which models are available from which providers. It powers both governance-based routing and adaptive load balancing by maintaining an up-to-date mapping of models to providers. -**Architecture Documentation**: For detailed technical documentation on the Model Catalog implementation, including API reference, thread safety, and advanced usage patterns, see [Model Catalog Architecture](/architecture/framework/model-catalog). + **Architecture Documentation**: For detailed technical documentation on the + Model Catalog implementation, including API reference, thread safety, and + advanced usage patterns, see [Model Catalog + Architecture](/architecture/framework/model-catalog). ### Data Sources @@ -48,7 +51,9 @@ The Model Catalog combines two data sources to maintain a comprehensive and up-t - **Stored as**: In-memory map `modelPool[provider][]models` -**Why two sources?** Pricing data provides comprehensive model coverage with cost information, while the List Models API ensures you can use newly released models immediately without waiting for pricing data updates. + **Why two sources?** Pricing data provides comprehensive model coverage with + cost information, while the List Models API ensures you can use newly released + models immediately without waiting for pricing data updates. 
### How Model Availability is Determined @@ -73,6 +78,7 @@ Bifrost uses a sophisticated multi-step process to determine if a model is avail - Routing Methods to validate `allowed_models` - Dashboard model selector dropdowns - API responses for `/v1/models?provider=openai` + @@ -134,6 +140,7 @@ Bifrost uses a sophisticated multi-step process to determine if a model is avail - Load balancing to find candidate providers - Fallback generation - Model validation in requests + @@ -163,6 +170,7 @@ Bifrost uses a sophisticated multi-step process to determine if a model is avail - Cost calculation for billing - Model validation during routing - Budget enforcement + @@ -219,6 +227,7 @@ Bifrost uses a sophisticated multi-step process to determine if a model is avail **Result**: Bifrost is ready with a comprehensive model catalog combining both sources. + @@ -261,6 +270,7 @@ Bifrost uses a sophisticated multi-step process to determine if a model is avail - Pricing URL fails but database has data → Use cached database records - Pricing URL fails and no database data → Error logged, existing memory cache retained - List models API fails → Log warning, retain existing model pool entries + @@ -296,6 +306,7 @@ Bifrost uses a sophisticated multi-step process to determine if a model is avail ``` This design ensures **requests never fail due to sync issues** as long as one data source is available. 
+ @@ -307,6 +318,7 @@ The `allowed_models` field in provider configs controls which models can be used **Configuration**: + ```json { "provider_configs": [ @@ -320,11 +332,13 @@ The `allowed_models` field in provider configs controls which models can be used ``` **Behavior**: + - Bifrost calls `GetModelsForProvider("openai")` - Returns all models in `modelPool["openai"]` - Request validated against catalog **Examples**: + ```bash # ✅ Allowed (in catalog) curl -H "x-bf-vk: vk-123" -d '{"model": "gpt-4o"}' @@ -337,12 +351,14 @@ curl -H "x-bf-vk: vk-123" -d '{"model": "claude-3-5-sonnet"}' ``` **Use Cases**: + - Default behavior for most deployments - Automatically stays up-to-date with provider's model offerings - No manual model list maintenance required -Using `"allowed_models": []` (empty array) means **deny all models** — no requests will be served. Use `["*"]` to allow all models via the catalog. + Using `"allowed_models": []` (empty array) means **deny all models** — no + requests will be served. Use `["*"]` to allow all models via the catalog. 
@@ -350,17 +366,18 @@ Using `"allowed_models": []` (empty array) means **deny all models** — no requ **Configuration**: + ```json { "provider_configs": [ { "provider": "openai", - "allowed_models": ["gpt-4o", "gpt-4o-mini"], // Only these two + "allowed_models": ["gpt-4o", "gpt-4o-mini"], // Only these two "weight": 1.0 }, { "provider": "anthropic", - "allowed_models": ["claude-3-5-sonnet-20241022"], // Specific version + "allowed_models": ["claude-3-5-sonnet-20241022"], // Specific version "weight": 1.0 } ] @@ -368,12 +385,14 @@ Using `"allowed_models": []` (empty array) means **deny all models** — no requ ``` **Behavior**: + - Bifrost validates request model against explicit list - Catalog is **ignored** for this provider - Supports both direct matches and provider-prefixed entries - Case-sensitive matching **Examples**: + ```bash # ✅ Allowed (in explicit list) curl -H "x-bf-vk: vk-123" -d '{"model": "gpt-4o"}' @@ -406,6 +425,7 @@ You can also use provider-prefixed model names in `allowed_models`. Bifrost will ``` **How it works**: + ```bash # Request without prefix curl -H "x-bf-vk: vk-123" -d '{"model": "gpt-4o"}' @@ -419,6 +439,7 @@ curl -H "x-bf-vk: vk-123" -d '{"model": "gpt-4o"}' This is particularly useful for proxy providers (OpenRouter, Vertex) where you want to explicitly control which upstream models are accessible. **Use Cases**: + - Compliance requirements (only approved models) - Cost control (restrict to cheaper models) - Version pinning (prevent automatic updates) @@ -432,6 +453,7 @@ This is particularly useful for proxy providers (OpenRouter, Vertex) where you w **Key Concept**: Aliases are **key-level** mappings that allow user-friendly model names to map to provider-specific identifiers. 
**How Aliases Work**: + - Defined at the **Key level**, not Virtual Key level - Structure: `aliases: {"user-facing-name": "provider-specific-id"}` - **Alias key** (left side): User-facing model name used in requests @@ -440,6 +462,7 @@ This is particularly useful for proxy providers (OpenRouter, Vertex) where you w **Azure OpenAI Example**: Provider configuration with alias mapping: + ```json { "providers": { @@ -463,6 +486,7 @@ Provider configuration with alias mapping: ``` **What Happens**: + 1. **Allowed models derived from aliases**: `["gpt-4o", "gpt-4o-mini"]` 2. **User requests with alias**: `{"model": "gpt-4o"}` 3. **Bifrost validates**: `gpt-4o` is in derived allowed models ✅ @@ -496,6 +520,7 @@ Provider configuration with alias mapping: ``` **What Happens**: + 1. **Allowed models**: `["claude-sonnet", "claude-opus"]` (from alias keys) 2. **User requests**: `{"model": "claude-sonnet"}` 3. **Bifrost validates**: `claude-sonnet` in allowed models ✅ @@ -505,6 +530,7 @@ Provider configuration with alias mapping: **Priority of Model Restrictions**: When determining allowed models for a key: + ``` 1. If key.models is NOT empty → Use key.models 2. Else if aliases exist → Use alias keys @@ -512,14 +538,15 @@ When determining allowed models for a key: ``` **Example with Both**: + ```json { "keys": [ { - "models": ["gpt-4o", "gpt-3.5-turbo"], // Explicit restriction + "models": ["gpt-4o", "gpt-3.5-turbo"], // Explicit restriction "aliases": { "gpt-4o": "my-deployment", - "gpt-4-turbo": "another-deployment" // NOT accessible! + "gpt-4-turbo": "another-deployment" // NOT accessible! 
}, "azure_key_config": { "endpoint": "https://your-resource.openai.azure.com" @@ -528,9 +555,11 @@ When determining allowed models for a key: ] } ``` + Result: Only `["gpt-4o", "gpt-3.5-turbo"]` allowed (models field takes priority) **Vertex Example** (similar pattern): + ```json { "keys": [ @@ -549,12 +578,14 @@ Result: Only `["gpt-4o", "gpt-3.5-turbo"]` allowed (models field takes priority) ``` **Use Cases for Aliases**: + - **Azure**: Map generic model names to specific deployment names in your Azure resource - **Bedrock**: Use short aliases for long inference profile ARNs - **Vertex**: Map to specific model versions or regional endpoints - **Multi-environment**: Different aliases per key (dev/staging/prod) **Key Insight**: + ``` User Request: {"model": "gpt-4o"} ↓ @@ -574,6 +605,7 @@ This allows user-friendly model names in requests while supporting provider-spec **Configuration**: + ```json { "provider_configs": [ @@ -592,12 +624,14 @@ This allows user-friendly model names in requests while supporting provider-spec ``` **Request**: + ```bash curl -H "x-bf-vk: vk-123" \ -d '{"model": "gpt-4o"}' ``` **Routing Behavior**: + 1. **Model validation**: Both providers have `gpt-4o` in allowed_models ✅ 2. **Weighted selection**: 50% chance each 3. **Provider selected**: Let's say Azure @@ -623,6 +657,7 @@ curl -H "x-bf-vk: vk-123" \ - Bifrost checks: `GetModelsForProvider("openrouter")` - Finds: `anthropic/claude-3-5-sonnet` in OpenRouter catalog - ✅ Allowed, routes to OpenRouter + @@ -653,6 +688,7 @@ curl -H "x-bf-vk: vk-123" \ - **Fallbacks**: `["openai/gpt-4o"]` (1% provider as fallback) **Why this works**: Bifrost now supports provider-prefixed entries in `allowed_models`, so `"openai/gpt-4o"` matches requests for `"gpt-4o"`. 
+ @@ -672,6 +708,7 @@ curl -H "x-bf-vk: vk-123" \ - Finds: `["anthropic", "vertex", "bedrock"]` - Validation: `claude-3-5-sonnet` in allowed_models ✅ - Sends to Vertex as: `anthropic/claude-3-5-sonnet` + @@ -690,6 +727,7 @@ curl -H "x-bf-vk: vk-123" \ - Special handling: Checks Groq catalog for `openai/gpt-3.5-turbo` - ✅ Found, validation passes - Sends to Groq as: `openai/gpt-3.5-turbo` + @@ -704,6 +742,7 @@ curl -H "x-bf-vk: vk-123" \ When a Virtual Key has `provider_configs`, governance uses the model catalog for validation: **Wildcard allowed_models Example**: + ```json { "provider_configs": [ @@ -717,6 +756,7 @@ When a Virtual Key has `provider_configs`, governance uses the model catalog for ``` **Request Flow**: + ```bash curl -H "x-bf-vk: vk-123" -d '{"model": "gpt-4o"}' @@ -727,6 +767,7 @@ curl -H "x-bf-vk: vk-123" -d '{"model": "gpt-4o"}' ``` **Rejection Example**: + ```bash curl -H "x-bf-vk: vk-123" -d '{"model": "claude-3-5-sonnet"}' @@ -743,6 +784,7 @@ curl -H "x-bf-vk: vk-123" -d '{"model": "claude-3-5-sonnet"}' When load balancing selects providers, it queries the catalog to find candidates: **Request Flow**: + ```bash curl -X POST http://localhost:8080/v1/chat/completions \ -d '{"model": "gpt-4o", "messages": [...]}' @@ -757,6 +799,7 @@ curl -X POST http://localhost:8080/v1/chat/completions \ ``` **Cross-Provider Discovery**: + ```bash curl -d '{"model": "claude-3-5-sonnet"}' @@ -777,7 +820,57 @@ This is how Bifrost achieves **intelligent cross-provider routing** without manu -**Model Catalog is essential for cross-provider routing**. Without it, Bifrost wouldn't know that `gpt-4o` is available from OpenAI, Azure, and Groq, or that `claude-3-5-sonnet` can be routed through Anthropic, Vertex, Bedrock, and OpenRouter. This knowledge powers both governance validation and load balancing provider discovery. + **Model Catalog is essential for cross-provider routing**. 
Without it, Bifrost + wouldn't know that `gpt-4o` is available from OpenAI, Azure, and Groq, or that + `claude-3-5-sonnet` can be routed through Anthropic, Vertex, Bedrock, and + OpenRouter. This knowledge powers both governance validation and load + balancing provider discovery. + + +--- + +## Default Provider Resolution + + + Default provider resolution via model catalog is available in **Bifrost + v1.5.0-prerelease7 and above**. + + +When a request includes a bare model name without a `provider/` prefix (e.g., `"model": "gpt-4o"` instead of `"model": "openai/gpt-4o"`), Bifrost automatically resolves the provider using the Model Catalog. Note that this default behavior is applied **after all other routing engines** have run. + +### How It Works + +1. **Request arrives** without a provider prefix (e.g., `"model": "gpt-4o"`) +2. **Catalog lookup**: Bifrost calls `GetProvidersForModel("gpt-4o")` to find all providers that support the model +3. **Provider selected**: A provider from the catalog's available list is used (e.g., `openai`) +4. 
**Request continues**: The resolved `provider/model` string is used for load balancing and fallback handling + +This is logged as the **`model-catalog`** routing engine in telemetry and routing logs, with a message like: + +``` +No provider specified for model gpt-4o, found 3 options in model catalog: +[openai, azure, groq], selecting first: openai +``` + +### Example + +```bash +# These two requests are equivalent when the model catalog +# maps gpt-4o → openai as the first provider: +curl -X POST http://localhost:8080/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{"model": "gpt-4o", "messages": [{"role": "user", "content": "Hello!"}]}' + +curl -X POST http://localhost:8080/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{"model": "openai/gpt-4o", "messages": [{"role": "user", "content": "Hello!"}]}' +``` + + + If the model catalog is not available or the model is not found in any + provider, the request returns an error asking for the `provider/model` format. + For deterministic provider selection, always use the explicit `provider/model` + prefix. 
--- @@ -858,13 +951,13 @@ When a Virtual Key has `provider_configs` defined: ### Key Features -| Feature | Description | -|---------|-------------| -| **Explicit Control** | Define exactly which providers and models are accessible | -| **Budget Enforcement** | Automatically exclude providers exceeding budget limits | -| **Rate Limit Protection** | Skip providers that have hit rate limits | -| **Weighted Distribution** | Control traffic distribution with custom weights | -| **Automatic Fallbacks** | Failed providers automatically retry with next highest weight | +| Feature | Description | +| ------------------------- | ------------------------------------------------------------- | +| **Explicit Control** | Define exactly which providers and models are accessible | +| **Budget Enforcement** | Automatically exclude providers exceeding budget limits | +| **Rate Limit Protection** | Skip providers that have hit rate limits | +| **Weighted Distribution** | Control traffic distribution with custom weights | +| **Automatic Fallbacks** | Failed providers automatically retry with next highest weight | ### Best Practices @@ -921,6 +1014,7 @@ When a Virtual Key has `provider_configs` defined: **`allowed_models: []` (empty array)**: Denies **all** models — no requests will be served for this provider config. This is deny-by-default behavior introduced in v1.5.0. **Empty `provider_configs`**: When `provider_configs` is empty (no providers configured), **all providers are blocked** (deny-by-default). You must explicitly add provider configurations to allow traffic through a Virtual Key. + --- @@ -928,7 +1022,9 @@ When a Virtual Key has `provider_configs` defined: ## Adaptive Load Balancing -**Enterprise Feature**: Adaptive Load Balancing is available in Bifrost Enterprise. [Contact us](https://www.getmaxim.ai/bifrost/enterprise) to enable it. + **Enterprise Feature**: Adaptive Load Balancing is available in Bifrost + Enterprise. 
[Contact us](https://www.getmaxim.ai/bifrost/enterprise) to enable + it. Adaptive Load Balancing automatically optimizes routing based on real-time performance metrics. It operates at **two levels** to provide both macro-level provider selection and micro-level key optimization. @@ -936,10 +1032,10 @@ Adaptive Load Balancing automatically optimizes routing based on real-time perfo ### Two-Level Architecture -Separating provider selection (direction) from key selection (route) enables: -- **Provider-level optimization**: Choose the best provider for a model based on aggregate performance -- **Key-level optimization**: Within that provider, choose the best API key based on individual key performance -- **Resilience**: Even when provider is specified (by governance or user), key-level load balancing still optimizes which API key to use + Separating provider selection (direction) from key selection (route) enables: + - **Provider-level optimization**: Choose the best provider for a model based on aggregate performance + - **Key-level optimization**: Within that provider, choose the best API key based on individual key performance + - **Resilience**: Even when provider is specified (by governance or user), key-level load balancing still optimizes which API key to use ```mermaid @@ -1007,7 +1103,9 @@ Score = (P_{error} \times 0.5) + (P_{latency} \times 0.2) + (P_{util} \times 0.0 $$ -Lower penalties = Higher weights = More traffic. The system self-heals by quickly penalizing failing routes but enabling fast recovery once issues are resolved. + Lower penalties = Higher weights = More traffic. The system self-heals by + quickly penalizing failing routes but enabling fast recovery once issues are + resolved. ### Request Flow @@ -1043,25 +1141,28 @@ Lower penalties = Higher weights = More traffic. 
The system self-heals by quickl ### Key Features -| Feature | Description | -|---------|-------------| -| **Automatic Optimization** | No manual weight tuning required | -| **Real-time Adaptation** | Weights recomputed every 5 seconds based on live metrics | -| **Circuit Breakers** | Failing routes automatically removed from rotation | -| **Fast Recovery** | 90% penalty reduction in 30 seconds after issues resolve | -| **Health States** | Routes transition between Healthy, Degraded, Failed, and Recovering | -| **Smart Exploration** | 25% chance to probe potentially recovered routes | - +| Feature | Description | +| -------------------------- | ------------------------------------------------------------------- | +| **Automatic Optimization** | No manual weight tuning required | +| **Real-time Adaptation** | Weights recomputed every 5 seconds based on live metrics | +| **Circuit Breakers** | Failing routes automatically removed from rotation | +| **Fast Recovery** | 90% penalty reduction in 30 seconds after issues resolve | +| **Health States** | Routes transition between Healthy, Degraded, Failed, and Recovering | +| **Smart Exploration** | 25% chance to probe potentially recovered routes | ### Dashboard Visibility Monitor load balancing performance in real-time: - Adaptive Load Balancing Dashboard + Adaptive Load Balancing Dashboard The dashboard shows: + - Weight distribution across provider-model-key routes - Performance metrics (error rates, latency, success rates) - State transitions (Healthy → Degraded → Failed → Recovering) @@ -1079,6 +1180,7 @@ When both methods are available in your Bifrost deployment, they work together i - **Level 2 (Route/Key)**: **Always runs**, even when provider is specified This means key-level optimization works regardless of how the provider was chosen! 
+ ### Execution Flow @@ -1138,7 +1240,9 @@ flowchart TD - **Result**: Optimal key selected within the provider -**Important**: Even when governance specifies `azure/gpt-4o`, load balancing **still optimizes which Azure key to use** based on performance metrics. This is the power of the two-level architecture! + **Important**: Even when governance specifies `azure/gpt-4o`, load balancing + **still optimizes which Azure key to use** based on performance metrics. This + is the power of the two-level architecture! ### Example Scenarios @@ -1147,10 +1251,12 @@ flowchart TD **Setup:** + - Virtual Key has `provider_configs` defined - No adaptive load balancing enabled **Request:** + ```bash curl -X POST http://localhost:8080/v1/chat/completions \ -H "x-bf-vk: vk-prod-main" \ @@ -1158,6 +1264,7 @@ curl -X POST http://localhost:8080/v1/chat/completions \ ``` **Behavior:** + 1. **Governance** applies weighted provider routing → selects Azure (70% weight) 2. Model becomes `azure/gpt-4o` 3. **Standard key selection** (non-adaptive) chooses an Azure key based on static weights @@ -1168,17 +1275,20 @@ curl -X POST http://localhost:8080/v1/chat/completions \ **Setup:** + - **No Virtual Key** (do not send `x-bf-vk`) → this is the **Load Balancing–only** setup - **Virtual Key with empty / missing `provider_configs`** → **blocks all providers** (deny-by-default) and therefore is **NOT** an LB-only setup - Adaptive load balancing enabled **Request:** + ```bash curl -X POST http://localhost:8080/v1/chat/completions \ -d '{"model": "gpt-4o", "messages": [...]}' ``` **Behavior:** + 1. **Load Balancing Level 1** applies performance-based provider routing → selects OpenAI (best performing) 2. Model becomes `openai/gpt-4o` 3. 
**Load Balancing Level 2** selects best OpenAI key based on performance metrics (error rate, latency, TPM status) @@ -1189,11 +1299,13 @@ curl -X POST http://localhost:8080/v1/chat/completions \ **Setup:** + - Virtual Key has `provider_configs` defined - Adaptive load balancing enabled - Azure has 3 keys: `azure-key-1`, `azure-key-2`, `azure-key-3` **Request:** + ```bash curl -X POST http://localhost:8080/v1/chat/completions \ -H "x-bf-vk: vk-prod-main" \ @@ -1201,6 +1313,7 @@ curl -X POST http://localhost:8080/v1/chat/completions \ ``` **Behavior:** + 1. **Governance** applies first (respects explicit user config) → selects Azure provider 2. Model becomes `azure/gpt-4o` 3. **Load Balancing Level 1** sees "/" and **skips provider selection** (already decided) @@ -1218,16 +1331,19 @@ curl -X POST http://localhost:8080/v1/chat/completions \ **Setup:** + - Both governance and load balancing enabled - OpenAI has 2 keys available **Request:** + ```bash curl -X POST http://localhost:8080/v1/chat/completions \ -d '{"model": "openai/gpt-4o", "messages": [...]}' ``` **Behavior:** + 1. **Governance** sees "/" and skips 2. **Load Balancing Level 1** sees "/" and **skips provider selection** 3. **Load Balancing Level 2** still runs! 
Selects best OpenAI key based on current metrics @@ -1240,13 +1356,13 @@ curl -X POST http://localhost:8080/v1/chat/completions \ ### Provider vs Key Selection Rules -| Scenario | Provider Selection | Key Selection | -|----------|-------------------|---------------| -| VK with provider_configs | **Governance** (weighted random) | **Standard** or **Adaptive** (if enabled) | -| VK without provider_configs + LB | **Blocked** (empty = no providers allowed) | N/A | -| No VK + LB | **Load Balancing Level 1** (performance) | **Load Balancing Level 2** (performance) | -| Model with provider prefix + LB | **Skip** (already specified) | **Load Balancing Level 2** (performance) ✅ | -| No Load Balancing enabled | **Governance** or **User** or **Model Catalog** | **Standard** (static weights) | +| Scenario | Provider Selection | Key Selection | +| -------------------------------- | ----------------------------------------------- | ------------------------------------------- | +| VK with provider_configs | **Governance** (weighted random) | **Standard** or **Adaptive** (if enabled) | +| VK without provider_configs + LB | **Blocked** (empty = no providers allowed) | N/A | +| No VK + LB | **Load Balancing Level 1** (performance) | **Load Balancing Level 2** (performance) | +| Model with provider prefix + LB | **Skip** (already specified) | **Load Balancing Level 2** (performance) ✅ | +| No Load Balancing enabled | **Governance** or **User** or **Model Catalog** | **Standard** (static weights) | **Critical Insight**: @@ -1254,6 +1370,7 @@ curl -X POST http://localhost:8080/v1/chat/completions \ - **Key selection** runs independently and benefits from load balancing **even when provider is predetermined** This separation is what makes the two-level architecture so powerful! + --- @@ -1261,7 +1378,11 @@ This separation is what makes the two-level architecture so powerful! 
## Routing Rules (Dynamic Expression-Based Routing) -**Position in routing pipeline**: Routing Rules execute **before governance provider selection** and can override it. They are evaluated before adaptive load balancing, enabling dynamic provider/model overrides based on runtime conditions like headers, parameters, capacity metrics, and organizational hierarchy. + **Position in routing pipeline**: Routing Rules execute **before governance + provider selection** and can override it. They are evaluated before adaptive + load balancing, enabling dynamic provider/model overrides based on runtime + conditions like headers, parameters, capacity metrics, and organizational + hierarchy. ### Overview @@ -1331,21 +1452,25 @@ request // Request rate limit usage % ### Examples #### Route based on user tier + ```cel headers["x-tier"] == "premium" // → openai/gpt-4o ``` #### Route to fallback when budget high + ```cel budget_used > 85 // → groq/llama-2 (cheaper) ``` #### Route by team + ```cel team_name == "ml-research" // → anthropic/claude-3-opus ``` #### Complex multi-condition routing + ```cel headers["x-environment"] == "production" && tokens_used < 75 && @@ -1367,20 +1492,21 @@ Within each scope, rules are sorted by **priority** (ascending: 0 before 10). 
### Key Features -| Feature | Description | -|---------|-------------| -| **CEL Expressions** | Powerful, composable condition language with multiple operators | -| **Scope Hierarchy** | Rules at VirtualKey/Team/Customer/Global levels with proper precedence | -| **Dynamic Override** | Override provider and/or model based on runtime conditions | -| **Fallback Chains** | Define multiple fallback providers for automatic failover | -| **Priority Ordering** | Lower priority evaluated first within same scope | -| **Capacity Awareness** | Access real-time budget and rate limit usage percentages | +| Feature | Description | +| ---------------------- | ---------------------------------------------------------------------- | +| **CEL Expressions** | Powerful, composable condition language with multiple operators | +| **Scope Hierarchy** | Rules at VirtualKey/Team/Customer/Global levels with proper precedence | +| **Dynamic Override** | Override provider and/or model based on runtime conditions | +| **Fallback Chains** | Define multiple fallback providers for automatic failover | +| **Priority Ordering** | Lower priority evaluated first within same scope | +| **Capacity Awareness** | Access real-time budget and rate limit usage percentages | ### Integration with Governance Routing Rules execute **before** governance provider selection and can override it: **If a routing rule matches**: + ``` Routing Rules evaluate ↓ @@ -1394,6 +1520,7 @@ Load Balancing selects best key ``` **If no routing rule matches**: + ``` Routing Rules evaluate ↓ @@ -1479,16 +1606,31 @@ For complete documentation, see [Routing Rules Documentation](/providers/routing ## Additional Resources - - Configuration instructions for setting up governance routing via Virtual Keys (Web UI, API, config.json) + + Configuration instructions for setting up governance routing via Virtual + Keys (Web UI, API, config.json) - Dynamic, expression-based routing using CEL expressions for runtime conditions + Dynamic, 
expression-based routing using CEL expressions for runtime + conditions - - Technical implementation details: scoring algorithms, weight calculations, and performance characteristics + + Technical implementation details: scoring algorithms, weight calculations, + and performance characteristics - + Learn how to create and configure Virtual Keys diff --git a/docs/providers/request-options.mdx b/docs/providers/request-options.mdx index 55d91481a7..b4e87b0ab3 100644 --- a/docs/providers/request-options.mdx +++ b/docs/providers/request-options.mdx @@ -8,36 +8,36 @@ Bifrost provides request options that control behavior, enable features, and pas ## Complete Reference -| Context Key | Header | Type | Description | -|-------------|--------|------|-------------| -| `BifrostContextKeyVirtualKey` | `x-bf-vk` | `string` | Virtual key identifier for governance | -| `BifrostContextKeyAPIKeyName` | `x-bf-api-key` | `string` | Explicit API key name selection | -| `BifrostContextKeyAPIKeyID` | `x-bf-api-key-id` | `string` | Explicit API key ID selection (takes priority over name) | -| `BifrostContextKeySessionID` | `x-bf-session-id` | `string` | Session ID for key stickiness (requires KV store) | -| `BifrostContextKeySessionTTL` | `x-bf-session-ttl` | `time.Duration` | Session-to-key cache TTL (duration string or seconds) | -| `BifrostContextKeyRequestID` | `x-request-id` | `string` | Custom request ID for tracking | -| `BifrostContextKeySendBackRawRequest` | `x-bf-send-back-raw-request` | `bool` | Include raw provider request in the response | -| `BifrostContextKeySendBackRawResponse` | `x-bf-send-back-raw-response` | `bool` | Include raw provider response in the response | -| `BifrostContextKeyStoreRawRequestResponse` | `x-bf-store-raw-request-response` | `bool` | Persist raw request/response in log records | -| `BifrostContextKeyPassthroughExtraParams` | `x-bf-passthrough-extra-params` | `bool` | Enable passthrough for extra parameters | -| `BifrostContextKeyExtraHeaders` | 
`x-bf-eh-*` | `map[string][]string` | Custom headers forwarded to provider | -| `BifrostContextKeyDirectKey` | `-` | `schemas.Key` | Direct key credentials (Go SDK only) | -| `BifrostContextKeySkipKeySelection` | `-` | `bool` | Skip key selection process (Go SDK only) | -| `BifrostContextKeyURLPath` | `-` | `string` | Custom URL path appended to provider base URL (Go SDK only) | -| `BifrostContextKeyUseRawRequestBody` | `-` | `bool` | Use raw request body (Go SDK only, requires RawRequestBody field) | -| `semanticcache.CacheKey` | `x-bf-cache-key` | `string` | Custom cache key | -| `semanticcache.CacheTTLKey` | `x-bf-cache-ttl` | `time.Duration` | Cache TTL (duration string or seconds) | -| `semanticcache.CacheThresholdKey` | `x-bf-cache-threshold` | `float64` | Similarity threshold (0.0-1.0) | -| `semanticcache.CacheTypeKey` | `x-bf-cache-type` | `string` | Cache type | -| `semanticcache.CacheNoStoreKey` | `x-bf-cache-no-store` | `bool` | Prevent caching | -| `mcp-include-clients` | `x-bf-mcp-include-clients` | `[]string` | Filter MCP clients (comma-separated). 
| -| `mcp-include-tools` | `x-bf-mcp-include-tools` | `[]string` | Filter MCP tools (`clientName-toolName` format, comma-separated) | -| `BifrostContextKeyMCPExtraHeaders` | *(any header in a client's `allowed_extra_headers`)* | `map[string][]string` | Headers forwarded to MCP servers at tool execution time, filtered per-client against `allowed_extra_headers` | -| `maxim.TraceIDKey` | `x-bf-maxim-trace-id` | `string` | Maxim trace ID | -| `maxim.GenerationIDKey` | `x-bf-maxim-generation-id` | `string` | Maxim generation ID | -| `maxim.TagsKey` | `x-bf-maxim-*` | `map[string]string` | Maxim tags (custom tag names) | -| `BifrostContextKey(labelName)` | `x-bf-prom-*` | `string` | Prometheus metric labels | - +| Context Key | Header | Type | Description | +| ------------------------------------------ | ---------------------------------------------------- | --------------------- | ------------------------------------------------------------------------------------------------------------------------------------- | +| `BifrostContextKeyVirtualKey` | `x-bf-vk` | `string` | Virtual key identifier for governance | +| `BifrostContextKeyAPIKeyName` | `x-bf-api-key` | `string` | Explicit API key name selection | +| `BifrostContextKeyAPIKeyID` | `x-bf-api-key-id` | `string` | Explicit API key ID selection (takes priority over name) | +| `BifrostContextKeySessionID` | `x-bf-session-id` | `string` | Session ID for key stickiness (requires KV store) | +| `BifrostContextKeySessionTTL` | `x-bf-session-ttl` | `time.Duration` | Session-to-key cache TTL (duration string or seconds) | +| `BifrostContextKeyRequestID` | `x-request-id` | `string` | Custom request ID for tracking | +| `BifrostContextKeySendBackRawRequest` | `x-bf-send-back-raw-request` | `bool` | Include raw provider request in the response | +| `BifrostContextKeySendBackRawResponse` | `x-bf-send-back-raw-response` | `bool` | Include raw provider response in the response | +| `BifrostContextKeyStoreRawRequestResponse` | 
`x-bf-store-raw-request-response` | `bool` | Persist raw request/response in log records | +| `BifrostContextKeyDisableContentLogging` | `x-bf-disable-content-logging` | `bool` | Per-request override for content logging; only honored when `allow_per_request_content_storage_override` is enabled in logging config | +| `BifrostContextKeyPassthroughExtraParams` | `x-bf-passthrough-extra-params` | `bool` | Enable passthrough for extra parameters | +| `BifrostContextKeyExtraHeaders` | `x-bf-eh-*` | `map[string][]string` | Custom headers forwarded to provider | +| `BifrostContextKeyDirectKey` | `-` | `schemas.Key` | Direct key credentials (Go SDK only) | +| `BifrostContextKeySkipKeySelection` | `-` | `bool` | Skip key selection process (Go SDK only) | +| `BifrostContextKeyURLPath` | `-` | `string` | Custom URL path appended to provider base URL (Go SDK only) | +| `BifrostContextKeyUseRawRequestBody` | `-` | `bool` | Use raw request body (Go SDK only, requires RawRequestBody field) | +| `semanticcache.CacheKey` | `x-bf-cache-key` | `string` | Custom cache key | +| `semanticcache.CacheTTLKey` | `x-bf-cache-ttl` | `time.Duration` | Cache TTL (duration string or seconds) | +| `semanticcache.CacheThresholdKey` | `x-bf-cache-threshold` | `float64` | Similarity threshold (0.0-1.0) | +| `semanticcache.CacheTypeKey` | `x-bf-cache-type` | `string` | Cache type | +| `semanticcache.CacheNoStoreKey` | `x-bf-cache-no-store` | `bool` | Prevent caching | +| `mcp-include-clients` | `x-bf-mcp-include-clients` | `[]string` | Filter MCP clients (comma-separated). 
| +| `mcp-include-tools` | `x-bf-mcp-include-tools` | `[]string` | Filter MCP tools (`clientName-toolName` format, comma-separated) | +| `BifrostContextKeyMCPExtraHeaders` | _(any header in a client's `allowed_extra_headers`)_ | `map[string][]string` | Headers forwarded to MCP servers at tool execution time, filtered per-client against `allowed_extra_headers` | +| `maxim.TraceIDKey` | `x-bf-maxim-trace-id` | `string` | Maxim trace ID | +| `maxim.GenerationIDKey` | `x-bf-maxim-generation-id` | `string` | Maxim generation ID | +| `maxim.TagsKey` | `x-bf-maxim-*` | `map[string]string` | Maxim tags (custom tag names) | +| `BifrostContextKey(labelName)` | `x-bf-prom-*` | `string` | Prometheus metric labels | ## Request Configuration Options @@ -70,11 +70,12 @@ ctx := context.Background() ctx = context.WithValue(ctx, schemas.BifrostContextKeyVirtualKey, "sk-bf-*") response, err := client.ChatCompletionRequest(schemas.NewBifrostContext(ctx, schemas.NoDeadline), &schemas.BifrostChatRequest{ - Provider: schemas.OpenAI, - Model: "gpt-4o-mini", - Input: messages, +Provider: schemas.OpenAI, +Model: "gpt-4o-mini", +Input: messages, }) -``` + +```` @@ -105,7 +106,8 @@ curl --location 'http://localhost:8080/v1/chat/completions' \ "model": "openai/gpt-4o-mini", "messages": [{"role": "user", "content": "Hello!"}] }' -``` +```` + ```go @@ -113,11 +115,12 @@ ctx := context.Background() ctx = context.WithValue(ctx, schemas.BifrostContextKeyAPIKeyID, "key-uuid-1234") response, err := client.ChatCompletionRequest(schemas.NewBifrostContext(ctx, schemas.NoDeadline), &schemas.BifrostChatRequest{ - Provider: schemas.OpenAI, - Model: "gpt-4o-mini", - Input: messages, +Provider: schemas.OpenAI, +Model: "gpt-4o-mini", +Input: messages, }) -``` + +```` @@ -140,7 +143,8 @@ curl --location 'http://localhost:8080/v1/chat/completions' \ "model": "openai/gpt-4o-mini", "messages": [{"role": "user", "content": "Hello!"}] }' -``` +```` + ```go @@ -148,19 +152,20 @@ ctx := context.Background() ctx = 
context.WithValue(ctx, schemas.BifrostContextKeyAPIKeyName, "premium-key") response, err := client.ChatCompletionRequest(schemas.NewBifrostContext(ctx, schemas.NoDeadline), &schemas.BifrostChatRequest{ - Provider: schemas.OpenAI, - Model: "gpt-4o-mini", - Input: messages, +Provider: schemas.OpenAI, +Model: "gpt-4o-mini", +Input: messages, }) -``` + +```` ### Session Stickiness (Session ID) -**Context Key:** `BifrostContextKeySessionID` -**Header:** `x-bf-session-id` -**Type:** `string` +**Context Key:** `BifrostContextKeySessionID` +**Header:** `x-bf-session-id` +**Type:** `string` **Required:** No Bind a session to a specific API key so that requests with the same session ID consistently use the same key. Useful for predictable rate-limit buckets, cost attribution per user, and consistent model routing per session. @@ -181,7 +186,8 @@ curl --location 'http://localhost:8080/v1/chat/completions' \ "model": "openai/gpt-4o-mini", "messages": [{"role": "user", "content": "Hello!"}] }' -``` +```` + ```go @@ -189,19 +195,20 @@ ctx := context.Background() ctx = context.WithValue(ctx, schemas.BifrostContextKeySessionID, "user-123-session-abc") response, err := client.ChatCompletionRequest(schemas.NewBifrostContext(ctx, schemas.NoDeadline), &schemas.BifrostChatRequest{ - Provider: schemas.OpenAI, - Model: "gpt-4o-mini", - Input: messages, +Provider: schemas.OpenAI, +Model: "gpt-4o-mini", +Input: messages, }) -``` + +```` ### Session TTL -**Context Key:** `BifrostContextKeySessionTTL` -**Header:** `x-bf-session-ttl` -**Type:** `time.Duration` (header value: duration string like `"30m"` or `"1h"`, or seconds as integer) +**Context Key:** `BifrostContextKeySessionTTL` +**Header:** `x-bf-session-ttl` +**Type:** `time.Duration` (header value: duration string like `"30m"` or `"1h"`, or seconds as integer) **Required:** No Optional. Controls how long the session-to-key binding is cached. If not set, Bifrost uses 1 hour. 
The TTL is refreshed on each request so active sessions do not expire. @@ -219,7 +226,8 @@ curl --location 'http://localhost:8080/v1/chat/completions' \ "model": "openai/gpt-4o-mini", "messages": [{"role": "user", "content": "Hello!"}] }' -``` +```` + ```go @@ -228,19 +236,20 @@ ctx = context.WithValue(ctx, schemas.BifrostContextKeySessionID, "user-123-sessi ctx = context.WithValue(ctx, schemas.BifrostContextKeySessionTTL, 30*time.Minute) response, err := client.ChatCompletionRequest(schemas.NewBifrostContext(ctx, schemas.NoDeadline), &schemas.BifrostChatRequest{ - Provider: schemas.OpenAI, - Model: "gpt-4o-mini", - Input: messages, +Provider: schemas.OpenAI, +Model: "gpt-4o-mini", +Input: messages, }) -``` + +```` ### Request ID -**Context Key:** `BifrostContextKeyRequestID` -**Header:** `x-request-id` -**Type:** `string` +**Context Key:** `BifrostContextKeyRequestID` +**Header:** `x-request-id` +**Type:** `string` **Required:** No Set a custom request ID for tracking and correlation. If not provided, Bifrost generates a UUID. 
@@ -255,7 +264,8 @@ curl --location 'http://localhost:8080/v1/chat/completions' \ "model": "openai/gpt-4o-mini", "messages": [{"role": "user", "content": "Hello!"}] }' -``` +```` + ```go @@ -263,23 +273,28 @@ ctx := context.Background() ctx = context.WithValue(ctx, schemas.BifrostContextKeyRequestID, "req-12345-abc") response, err := client.ChatCompletionRequest(schemas.NewBifrostContext(ctx, schemas.NoDeadline), &schemas.BifrostChatRequest{ - Provider: schemas.OpenAI, - Model: "gpt-4o-mini", - Input: messages, +Provider: schemas.OpenAI, +Model: "gpt-4o-mini", +Input: messages, }) -``` + +```` ### Send Back Raw Request -**Context Key:** `BifrostContextKeySendBackRawRequest` -**Header:** `x-bf-send-back-raw-request` -**Type:** `bool` (header values: `"true"` or `"false"`) +**Context Key:** `BifrostContextKeySendBackRawRequest` +**Header:** `x-bf-send-back-raw-request` +**Type:** `bool` (header values: `"true"` or `"false"`) **Required:** No Include the exact JSON body sent to the provider alongside Bifrost's standardized response. Accepts `"true"` or `"false"` — either value fully overrides the provider-level `send_back_raw_request` config for this request. + +Per-request overrides are **disabled by default**. You must first enable `allow_per_request_raw_override` in your logging configuration (or in the UI under **Logs Settings**) before this header or context key has any effect. This flag controls only what is **sent back to the caller** — it does not affect log storage. To persist raw bytes in logs, use `x-bf-store-raw-request-response` (gated by `allow_per_request_content_storage_override`). 
+ + ```bash @@ -290,7 +305,8 @@ curl --location 'http://localhost:8080/v1/chat/completions' \ "model": "openai/gpt-4o-mini", "messages": [{"role": "user", "content": "Hello!"}] }' -``` +```` + ```go @@ -298,16 +314,17 @@ ctx := context.Background() ctx = context.WithValue(ctx, schemas.BifrostContextKeySendBackRawRequest, true) response, err := client.ChatCompletionRequest(schemas.NewBifrostContext(ctx, schemas.NoDeadline), &schemas.BifrostChatRequest{ - Provider: schemas.OpenAI, - Model: "gpt-4o-mini", - Input: messages, +Provider: schemas.OpenAI, +Model: "gpt-4o-mini", +Input: messages, }) // Access raw request if response.ChatResponse != nil { - rawReq := response.ChatResponse.ExtraFields.RawRequest +rawReq := response.ChatResponse.ExtraFields.RawRequest } -``` + +```` @@ -324,7 +341,7 @@ The raw request appears in `extra_fields.raw_request`: } } } -``` +```` ### Send Back Raw Response @@ -335,6 +352,10 @@ The raw request appears in `extra_fields.raw_request`: Include the original provider response alongside Bifrost's standardized response format. Accepts `"true"` or `"false"` — either value fully overrides the provider-level `send_back_raw_response` config for this request. + +Per-request overrides are **disabled by default**. You must first enable `allow_per_request_raw_override` in your logging configuration (or in the UI under **Logs Settings**) before this header or context key has any effect. This flag controls only what is **sent back to the caller** — it does not affect log storage. To persist raw bytes in logs, use `x-bf-store-raw-request-response` (gated by `allow_per_request_content_storage_override`). 
+ + ```bash @@ -353,16 +374,17 @@ ctx := context.Background() ctx = context.WithValue(ctx, schemas.BifrostContextKeySendBackRawResponse, true) response, err := client.ChatCompletionRequest(schemas.NewBifrostContext(ctx, schemas.NoDeadline), &schemas.BifrostChatRequest{ - Provider: schemas.OpenAI, - Model: "gpt-4o-mini", - Input: messages, +Provider: schemas.OpenAI, +Model: "gpt-4o-mini", +Input: messages, }) // Access raw response if response.ChatResponse != nil { - rawResp := response.ChatResponse.ExtraFields.RawResponse +rawResp := response.ChatResponse.ExtraFields.RawResponse } -``` + +```` @@ -379,7 +401,7 @@ The raw response appears in `extra_fields.raw_response`: } } } -``` +```` ### Store Raw Request/Response @@ -392,6 +414,10 @@ Persist the raw provider request and response in the log record. Accepts `"true" This is orthogonal to the send-back flags: enabling this does not affect whether raw data appears in the API response, and enabling send-back does not automatically store raw data in logs. Use this when you want observability into provider payloads without necessarily exposing them to the caller, or combine it with `x-bf-send-back-raw-*` to do both. + +Per-request overrides are **disabled by default**. You must first enable `allow_per_request_content_storage_override` in your logging configuration (or in the UI under **Logs Settings**) before this header or context key has any effect. Note that this is gated by the **content storage** override, not the raw override — `allow_per_request_raw_override` only gates `x-bf-send-back-raw-request` and `x-bf-send-back-raw-response` (sending raw bytes back to the caller). 
+ + ```bash @@ -410,13 +436,14 @@ ctx := context.Background() ctx = context.WithValue(ctx, schemas.BifrostContextKeyStoreRawRequestResponse, true) response, err := client.ChatCompletionRequest(schemas.NewBifrostContext(ctx, schemas.NoDeadline), &schemas.BifrostChatRequest{ - Provider: schemas.OpenAI, - Model: "gpt-4o-mini", - Input: messages, +Provider: schemas.OpenAI, +Model: "gpt-4o-mini", +Input: messages, }) // Raw data is persisted in the log record. // ExtraFields.RawRequest/RawResponse are nil unless send-back flags are also enabled. -``` + +```` @@ -426,11 +453,66 @@ response, err := client.ChatCompletionRequest(schemas.NewBifrostContext(ctx, sch `x-bf-store-raw-request-response` and `x-bf-send-back-raw-*` are orthogonal — you can enable any combination. Enabling store does not send data back to the caller; enabling send-back does not persist data in logs. Enable both to do both. +### Disable Content Logging (Per-Request) + +**Context Key:** `BifrostContextKeyDisableContentLogging` +**Header:** `x-bf-disable-content-logging` +**Type:** `bool` (header values: `"true"` or `"false"`) +**Required:** No + +Override the logging plugin's global `disable_content_logging` config for a single request. When set to `true`, messages, parameters, tool arguments, and tool results are omitted from the log record for that request. When set to `false`, content is recorded even if the global toggle is off. + +This is useful when you need to suppress sensitive data (e.g. PII, credentials) for specific requests while keeping content logging enabled globally. + + +Per-request overrides are **disabled by default**. You must first enable `allow_per_request_content_storage_override` in your logging configuration (or in the UI under **Logs Settings**) before this header or context key has any effect. When the toggle is off, the global `disable_content_logging` setting is authoritative and this value is ignored. 
+ + + + +```bash +# Suppress content for this request only +curl --location 'http://localhost:8080/v1/chat/completions' \ +--header 'x-bf-disable-content-logging: true' \ +--header 'Content-Type: application/json' \ +--data '{ + "model": "openai/gpt-4o-mini", + "messages": [{"role": "user", "content": "Sensitive data here"}] +}' +```` + + + +```go +ctx := context.Background() +bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) + +// Suppress content logging for this request +bfCtx.SetValue(schemas.BifrostContextKeyDisableContentLogging, true) + +response, err := client.ChatCompletionRequest(bfCtx, &schemas.BifrostChatRequest{ +Provider: schemas.OpenAI, +Model: "gpt-4o-mini", +Input: messages, +}) + +```` + + + +**Prerequisite:** `allow_per_request_content_storage_override` must be `true` in the logging plugin config (set in `config.json` or via the UI). + +**Precedence (when override is enabled):** The per-request value takes precedence over the global `disable_content_logging` setting. A value of `true` suppresses content; `false` forces content on. + + +This flag affects only what is written to the log record (messages, params, tool results, raw request/response). Token counts, latency, cost, status, and routing metadata are always logged regardless of this setting. + + ### Passthrough Extra Parameters -**Context Key:** `BifrostContextKeyPassthroughExtraParams` -**Header:** `x-bf-passthrough-extra-params` -**Type:** `bool` (header value: `"true"`) +**Context Key:** `BifrostContextKeyPassthroughExtraParams` +**Header:** `x-bf-passthrough-extra-params` +**Type:** `bool` (header value: `"true"`) **Required:** No Enable passthrough mode for extra parameters. When enabled, any parameters in `extra_params` (or provider-specific extra parameter fields) will be merged directly into the request sent to the provider. 
@@ -450,7 +532,8 @@ curl --location 'http://localhost:8080/v1/chat/completions' \ "b": 123 } }' -``` +```` + ```go @@ -458,20 +541,21 @@ ctx := context.Background() ctx = context.WithValue(ctx, schemas.BifrostContextKeyPassthroughExtraParams, true) response, err := client.ChatCompletionRequest(schemas.NewBifrostContext(ctx, schemas.NoDeadline), &schemas.BifrostChatRequest{ - Provider: schemas.OpenAI, - Model: "gpt-4o-mini", - Input: messages, - Params: &schemas.ChatParameters{ - ExtraParams: map[string]interface{}{ - "custom_param": "value", - "nested_param": map[string]interface{}{ - "a": "value", - "b": 123, - }, - }, - }, +Provider: schemas.OpenAI, +Model: "gpt-4o-mini", +Input: messages, +Params: &schemas.ChatParameters{ +ExtraParams: map[string]interface{}{ +"custom_param": "value", +"nested_param": map[string]interface{}{ +"a": "value", +"b": 123, +}, +}, +}, }) -``` + +```` @@ -483,9 +567,9 @@ response, err := client.ChatCompletionRequest(schemas.NewBifrostContext(ctx, sch ### Direct Key (Go SDK Only) -**Context Key:** `BifrostContextKeyDirectKey` -**Header:** `-` (not available via HTTP) -**Type:** `schemas.Key` +**Context Key:** `BifrostContextKeyDirectKey` +**Header:** `-` (not available via HTTP) +**Type:** `schemas.Key` **Required:** No Bypass key selection and provide credentials directly. Useful for dynamic key scenarios. @@ -504,7 +588,7 @@ response, err := client.ChatCompletionRequest(schemas.NewBifrostContext(ctx, sch Model: "gpt-4o", Input: messages, }) -``` +```` ### Skip Key Selection (Go SDK Only) @@ -574,12 +658,14 @@ response, err := client.ChatCompletionRequest(schemas.NewBifrostContext(ctx, sch ``` -When using raw request body, Bifrost bypasses its request conversion and sends your payload directly to the provider. You're responsible for ensuring the payload matches the provider's expected format. + When using raw request body, Bifrost bypasses its request conversion and sends + your payload directly to the provider. 
You're responsible for ensuring the + payload matches the provider's expected format. ## Custom Headers -### Extra Headers (x-bf-eh-*) +### Extra Headers (x-bf-eh-\*) **Context Key:** `BifrostContextKeyExtraHeaders` **Header Pattern:** `x-bf-eh-{header-name}` @@ -611,11 +697,12 @@ ctx := context.Background() ctx = context.WithValue(ctx, schemas.BifrostContextKeyExtraHeaders, extraHeaders) response, err := client.ChatCompletionRequest(schemas.NewBifrostContext(ctx, schemas.NoDeadline), &schemas.BifrostChatRequest{ - Provider: schemas.OpenAI, - Model: "gpt-4o-mini", - Input: messages, +Provider: schemas.OpenAI, +Model: "gpt-4o-mini", +Input: messages, }) -``` + +```` @@ -633,9 +720,9 @@ These options control semantic caching behavior. ### Cache Key -**Context Key:** `semanticcache.CacheKey` -**Header:** `x-bf-cache-key` -**Type:** `string` +**Context Key:** `semanticcache.CacheKey` +**Header:** `x-bf-cache-key` +**Type:** `string` **Required:** No Specify a custom cache key for semantic cache lookups. 
@@ -650,7 +737,8 @@ curl --location 'http://localhost:8080/v1/chat/completions' \ "model": "openai/gpt-4o-mini", "messages": [{"role": "user", "content": "Hello!"}] }' -``` +```` + ```go @@ -658,19 +746,20 @@ ctx := context.Background() ctx = context.WithValue(ctx, semanticcache.CacheKey, "custom-key-123") response, err := client.ChatCompletionRequest(schemas.NewBifrostContext(ctx, schemas.NoDeadline), &schemas.BifrostChatRequest{ - Provider: schemas.OpenAI, - Model: "gpt-4o-mini", - Input: messages, +Provider: schemas.OpenAI, +Model: "gpt-4o-mini", +Input: messages, }) -``` + +```` ### Cache TTL -**Context Key:** `semanticcache.CacheTTLKey` -**Header:** `x-bf-cache-ttl` -**Type:** `time.Duration` (header value: duration string like `"30s"` or `"5m"`, or seconds as integer) +**Context Key:** `semanticcache.CacheTTLKey` +**Header:** `x-bf-cache-ttl` +**Type:** `time.Duration` (header value: duration string like `"30s"` or `"5m"`, or seconds as integer) **Required:** No Set a custom time-to-live for cached responses. 
@@ -685,7 +774,8 @@ curl --location 'http://localhost:8080/v1/chat/completions' \ "model": "openai/gpt-4o-mini", "messages": [{"role": "user", "content": "Hello!"}] }' -``` +```` + ```go @@ -693,11 +783,12 @@ ctx := context.Background() ctx = context.WithValue(ctx, semanticcache.CacheTTLKey, 5*time.Minute) response, err := client.ChatCompletionRequest(schemas.NewBifrostContext(ctx, schemas.NoDeadline), &schemas.BifrostChatRequest{ - Provider: schemas.OpenAI, - Model: "gpt-4o-mini", - Input: messages, +Provider: schemas.OpenAI, +Model: "gpt-4o-mini", +Input: messages, }) -``` + +```` @@ -705,9 +796,9 @@ Accepts duration strings (`"30s"`, `"5m"`, `"1h"`) or plain numbers (treated as ### Cache Threshold -**Context Key:** `semanticcache.CacheThresholdKey` -**Header:** `x-bf-cache-threshold` -**Type:** `float64` (range: 0.0 to 1.0) +**Context Key:** `semanticcache.CacheThresholdKey` +**Header:** `x-bf-cache-threshold` +**Type:** `float64` (range: 0.0 to 1.0) **Required:** No Set the similarity threshold for semantic cache matching. @@ -722,7 +813,8 @@ curl --location 'http://localhost:8080/v1/chat/completions' \ "model": "openai/gpt-4o-mini", "messages": [{"role": "user", "content": "Hello!"}] }' -``` +```` + ```go @@ -730,19 +822,20 @@ ctx := context.Background() ctx = context.WithValue(ctx, semanticcache.CacheThresholdKey, 0.85) response, err := client.ChatCompletionRequest(schemas.NewBifrostContext(ctx, schemas.NoDeadline), &schemas.BifrostChatRequest{ - Provider: schemas.OpenAI, - Model: "gpt-4o-mini", - Input: messages, +Provider: schemas.OpenAI, +Model: "gpt-4o-mini", +Input: messages, }) -``` + +```` ### Cache Type -**Context Key:** `semanticcache.CacheTypeKey` -**Header:** `x-bf-cache-type` -**Type:** `semanticcache.CacheType` (string) +**Context Key:** `semanticcache.CacheTypeKey` +**Header:** `x-bf-cache-type` +**Type:** `semanticcache.CacheType` (string) **Required:** No Specify the cache type for this request. 
@@ -757,7 +850,8 @@ curl --location 'http://localhost:8080/v1/chat/completions' \ "model": "openai/gpt-4o-mini", "messages": [{"role": "user", "content": "Hello!"}] }' -``` +```` + ```go @@ -765,19 +859,20 @@ ctx := context.Background() ctx = context.WithValue(ctx, semanticcache.CacheTypeKey, semanticcache.CacheTypeSemantic) response, err := client.ChatCompletionRequest(schemas.NewBifrostContext(ctx, schemas.NoDeadline), &schemas.BifrostChatRequest{ - Provider: schemas.OpenAI, - Model: "gpt-4o-mini", - Input: messages, +Provider: schemas.OpenAI, +Model: "gpt-4o-mini", +Input: messages, }) -``` + +```` ### Cache No Store -**Context Key:** `semanticcache.CacheNoStoreKey` -**Header:** `x-bf-cache-no-store` -**Type:** `bool` (header value: `"true"`) +**Context Key:** `semanticcache.CacheNoStoreKey` +**Header:** `x-bf-cache-no-store` +**Type:** `bool` (header value: `"true"`) **Required:** No Prevent caching of this request/response. @@ -792,7 +887,8 @@ curl --location 'http://localhost:8080/v1/chat/completions' \ "model": "openai/gpt-4o-mini", "messages": [{"role": "user", "content": "Hello!"}] }' -``` +```` + ```go @@ -800,11 +896,12 @@ ctx := context.Background() ctx = context.WithValue(ctx, semanticcache.CacheNoStoreKey, true) response, err := client.ChatCompletionRequest(schemas.NewBifrostContext(ctx, schemas.NoDeadline), &schemas.BifrostChatRequest{ - Provider: schemas.OpenAI, - Model: "gpt-4o-mini", - Input: messages, +Provider: schemas.OpenAI, +Model: "gpt-4o-mini", +Input: messages, }) -``` + +```` @@ -814,9 +911,9 @@ These options control MCP client and tool filtering. ### Include Clients -**Context Key:** `mcp-include-clients` -**Header:** `x-bf-mcp-include-clients` -**Type:** `[]string` (comma-separated values) +**Context Key:** `mcp-include-clients` +**Header:** `x-bf-mcp-include-clients` +**Type:** `[]string` (comma-separated values) **Required:** No Filter MCP clients to include only the specified ones. 
@@ -831,7 +928,8 @@ curl --location 'http://localhost:8080/v1/chat/completions' \ "model": "openai/gpt-4o-mini", "messages": [{"role": "user", "content": "Hello!"}] }' -``` +```` + ```go @@ -839,11 +937,12 @@ ctx := context.Background() ctx = context.WithValue(ctx, schemas.BifrostContextKey("mcp-include-clients"), []string{"client1", "client2"}) response, err := client.ChatCompletionRequest(schemas.NewBifrostContext(ctx, schemas.NoDeadline), &schemas.BifrostChatRequest{ - Provider: schemas.OpenAI, - Model: "gpt-4o-mini", - Input: messages, +Provider: schemas.OpenAI, +Model: "gpt-4o-mini", +Input: messages, }) -``` + +```` @@ -866,7 +965,8 @@ curl --location 'http://localhost:8080/v1/chat/completions' \ "model": "openai/gpt-4o-mini", "messages": [{"role": "user", "content": "Hello!"}] }' -``` +```` + ```go @@ -874,11 +974,12 @@ ctx := context.Background() ctx = context.WithValue(ctx, schemas.BifrostContextKey("mcp-include-tools"), []string{"gmail-send_email", "filesystem-read_file"}) response, err := client.ChatCompletionRequest(schemas.NewBifrostContext(ctx, schemas.NoDeadline), &schemas.BifrostChatRequest{ - Provider: schemas.OpenAI, - Model: "gpt-4o-mini", - Input: messages, +Provider: schemas.OpenAI, +Model: "gpt-4o-mini", +Input: messages, }) -``` + +```` @@ -888,9 +989,9 @@ These options enable Maxim observability integration and tag propagation. ### Maxim Trace ID -**Context Key:** `maxim.TraceIDKey` -**Header:** `x-bf-maxim-trace-id` -**Type:** `string` +**Context Key:** `maxim.TraceIDKey` +**Header:** `x-bf-maxim-trace-id` +**Type:** `string` **Required:** No Set the Maxim trace ID for distributed tracing. 
@@ -905,7 +1006,8 @@ curl --location 'http://localhost:8080/v1/chat/completions' \ "model": "openai/gpt-4o-mini", "messages": [{"role": "user", "content": "Hello!"}] }' -``` +```` + ```go @@ -913,19 +1015,20 @@ ctx := context.Background() ctx = context.WithValue(ctx, maxim.TraceIDKey, "trace-12345") response, err := client.ChatCompletionRequest(schemas.NewBifrostContext(ctx, schemas.NoDeadline), &schemas.BifrostChatRequest{ - Provider: schemas.OpenAI, - Model: "gpt-4o-mini", - Input: messages, +Provider: schemas.OpenAI, +Model: "gpt-4o-mini", +Input: messages, }) -``` + +```` ### Maxim Generation ID -**Context Key:** `maxim.GenerationIDKey` -**Header:** `x-bf-maxim-generation-id` -**Type:** `string` +**Context Key:** `maxim.GenerationIDKey` +**Header:** `x-bf-maxim-generation-id` +**Type:** `string` **Required:** No Set the Maxim generation ID for request correlation. @@ -940,7 +1043,8 @@ curl --location 'http://localhost:8080/v1/chat/completions' \ "model": "openai/gpt-4o-mini", "messages": [{"role": "user", "content": "Hello!"}] }' -``` +```` + ```go @@ -948,19 +1052,20 @@ ctx := context.Background() ctx = context.WithValue(ctx, maxim.GenerationIDKey, "gen-12345") response, err := client.ChatCompletionRequest(schemas.NewBifrostContext(ctx, schemas.NoDeadline), &schemas.BifrostChatRequest{ - Provider: schemas.OpenAI, - Model: "gpt-4o-mini", - Input: messages, +Provider: schemas.OpenAI, +Model: "gpt-4o-mini", +Input: messages, }) -``` + +```` ### Maxim Tags (x-bf-maxim-*) -**Context Key:** `maxim.TagsKey` -**Header Pattern:** `x-bf-maxim-{tag-name}` -**Type:** `map[string]string` +**Context Key:** `maxim.TagsKey` +**Header Pattern:** `x-bf-maxim-{tag-name}` +**Type:** `map[string]string` **Required:** No Add custom tags to Maxim traces. Any header starting with `x-bf-maxim-` that isn't a reserved header becomes a tag. 
@@ -976,7 +1081,8 @@ curl --location 'http://localhost:8080/v1/chat/completions' \ "model": "openai/gpt-4o-mini", "messages": [{"role": "user", "content": "Hello!"}] }' -``` +```` + ```go @@ -988,19 +1094,20 @@ ctx := context.Background() ctx = context.WithValue(ctx, maxim.TagsKey, tags) response, err := client.ChatCompletionRequest(schemas.NewBifrostContext(ctx, schemas.NoDeadline), &schemas.BifrostChatRequest{ - Provider: schemas.OpenAI, - Model: "gpt-4o-mini", - Input: messages, +Provider: schemas.OpenAI, +Model: "gpt-4o-mini", +Input: messages, }) -``` + +```` ## Prometheus Options -**Context Key:** `BifrostContextKey(labelName)` -**Header Pattern:** `x-bf-prom-{label-name}` -**Type:** `string` +**Context Key:** `BifrostContextKey(labelName)` +**Header Pattern:** `x-bf-prom-{label-name}` +**Type:** `string` **Required:** No Add custom labels to Prometheus metrics. The `x-bf-prom-` prefix is stripped and the remainder becomes the label name. @@ -1016,7 +1123,8 @@ curl --location 'http://localhost:8080/v1/chat/completions' \ "model": "openai/gpt-4o-mini", "messages": [{"role": "user", "content": "Hello!"}] }' -``` +```` + ```go @@ -1025,10 +1133,11 @@ ctx = context.WithValue(ctx, schemas.BifrostContextKey("environment"), "producti ctx = context.WithValue(ctx, schemas.BifrostContextKey("team"), "engineering") response, err := client.ChatCompletionRequest(schemas.NewBifrostContext(ctx, schemas.NoDeadline), &schemas.BifrostChatRequest{ - Provider: schemas.OpenAI, - Model: "gpt-4o-mini", - Input: messages, +Provider: schemas.OpenAI, +Model: "gpt-4o-mini", +Input: messages, }) + ``` @@ -1048,7 +1157,7 @@ Bifrost maintains a security denylist of headers that are **never** forwarded to - `x-bf-api-key` (when used via `x-bf-eh-*`) - `x-bf-vk` (when used via `x-bf-eh-*`) -## Internal Context Keys +## Internal Context Keys These context keys are read-only and set when request is completed. 
**Do not set these values.** @@ -1068,4 +1177,5 @@ The following context keys are set by Bifrost internally. - **[Gateway Provider Configuration](../quickstart/gateway/provider-configuration)** - Configure providers and headers - **[Go SDK Context Keys](../quickstart/go-sdk/context-keys)** - Programmatic context key usage - **[Virtual Keys](../features/governance/virtual-keys)** - Virtual key usage and governance -- **[Semantic Cache](../features/semantic-caching)** - Caching configuration \ No newline at end of file +- **[Semantic Cache](../features/semantic-caching)** - Caching configuration +``` diff --git a/docs/quickstart/gateway/setting-up.mdx b/docs/quickstart/gateway/setting-up.mdx index 7dd58228a9..dac2b552f0 100644 --- a/docs/quickstart/gateway/setting-up.mdx +++ b/docs/quickstart/gateway/setting-up.mdx @@ -17,7 +17,10 @@ Both options work perfectly - choose what fits your workflow: #### NPX Binary @@ -119,7 +122,7 @@ curl -X POST http://localhost:8080/v1/chat/completions \ 1. **Zero Configuration Start**: Bifrost launched without any config files - everything can be configured through the Web UI or API 2. **OpenAI-Compatible API**: All Bifrost APIs follow OpenAI request/response format for seamless integration 3. **Unified API Endpoint**: `/v1/chat/completions` works with any provider (OpenAI, Anthropic, Bedrock, etc.) -4. **Provider Resolution**: `openai/gpt-4o-mini` tells Bifrost to use OpenAI's GPT-4o Mini model +4. **Provider Resolution**: `openai/gpt-4o-mini` tells Bifrost to use OpenAI's GPT-4o Mini model. You can also use bare model names like `gpt-4o-mini`; Bifrost will automatically resolve the provider via the [Model Catalog](/providers/provider-routing#default-provider-resolution) 5.
**Automatic Routing**: Bifrost handles authentication, rate limiting, and request routing automatically --- @@ -139,7 +142,9 @@ Bifrost supports **two configuration approaches** - you cannot use both simultan ### Mode 2: File-based Configuration -You can view entire config schema [here](https://www.getbifrost.ai/schema) + + You can view the entire config schema [here](https://www.getbifrost.ai/schema) + **When to use:** Advanced setups, GitOps workflows, or when UI is not needed @@ -199,13 +204,12 @@ If you want database persistence but prefer not to use the UI, note that modifyi ## PostgreSQL UTF8 Requirement - - The minimum PostgreSQL version required is 16 or above. - +The minimum PostgreSQL version required is 16 or above. - For the log store, Bifrost creates materialized views to improve analytics performance. Ensure that the PostgreSQL user - has the necessary permissions to perform these operations on the target schema. + For the log store, Bifrost creates materialized views to improve analytics + performance. Ensure that the PostgreSQL user has the necessary permissions to + perform these operations on the target schema. If you use PostgreSQL for `config_store` or `logs_store`, the target database must use `UTF8` encoding. diff --git a/framework/configstore/clientconfig.go b/framework/configstore/clientconfig.go index fa70fba8c5..ddeb6ab9a7 100644 --- a/framework/configstore/clientconfig.go +++ b/framework/configstore/clientconfig.go @@ -46,34 +46,36 @@ type CompatConfig struct { // ClientConfig represents the core configuration for Bifrost HTTP transport and the Bifrost Client. // It includes settings for excess request handling, Prometheus metrics, and initial pool size.
type ClientConfig struct { - DropExcessRequests bool `json:"drop_excess_requests"` // Drop excess requests if the provider queue is full - InitialPoolSize int `json:"initial_pool_size"` // The initial pool size for the bifrost client - PrometheusLabels []string `json:"prometheus_labels"` // The labels to be used for prometheus metrics - EnableLogging *bool `json:"enable_logging"` // Enable logging of requests and responses - DisableContentLogging bool `json:"disable_content_logging"` // Disable logging of content - DisableDBPingsInHealth bool `json:"disable_db_pings_in_health"` - LogRetentionDays int `json:"log_retention_days" validate:"min=1"` // Number of days to retain logs (minimum 1 day) - EnforceAuthOnInference bool `json:"enforce_auth_on_inference"` // Require auth (VK, API key, or user token) on inference endpoints - EnforceGovernanceHeader bool `json:"enforce_governance_header,omitempty"` // Deprecated: use EnforceAuthOnInference - EnforceSCIMAuth bool `json:"enforce_scim_auth,omitempty"` // Deprecated: use EnforceAuthOnInference - AllowDirectKeys bool `json:"allow_direct_keys"` // Allow direct keys to be used for requests - AllowedOrigins []string `json:"allowed_origins,omitempty"` // Additional allowed origins for CORS and WebSocket (localhost is always allowed) - AllowedHeaders []string `json:"allowed_headers,omitempty"` // Additional allowed headers for CORS and WebSocket - MaxRequestBodySizeMB int `json:"max_request_body_size_mb"` // The maximum request body size in MB - Compat CompatConfig `json:"compat"` // Compat plugin configuration - MCPAgentDepth int `json:"mcp_agent_depth"` // The maximum depth for MCP agent mode tool execution - MCPToolExecutionTimeout int `json:"mcp_tool_execution_timeout"` // The timeout for individual tool execution in seconds - MCPCodeModeBindingLevel string `json:"mcp_code_mode_binding_level"` // Code mode binding level: "server" or "tool" - MCPToolSyncInterval int `json:"mcp_tool_sync_interval"` // Global tool sync 
interval in minutes (default: 10, 0 = disabled) - MCPDisableAutoToolInject bool `json:"mcp_disable_auto_tool_inject"` // When true, MCP tools are not injected into requests by default - HeaderFilterConfig *tables.GlobalHeaderFilterConfig `json:"header_filter_config,omitempty"` // Global header filtering configuration for x-bf-eh-* headers - AsyncJobResultTTL int `json:"async_job_result_ttl"` // Default TTL for async job results in seconds (default: 3600 = 1 hour) - RequiredHeaders []string `json:"required_headers,omitempty"` // Headers that must be present on every request (case-insensitive) - LoggingHeaders []string `json:"logging_headers,omitempty"` // Headers to capture in log metadata - WhitelistedRoutes []string `json:"whitelisted_routes,omitempty"` // Routes that bypass auth middleware - HideDeletedVirtualKeysInFilters bool `json:"hide_deleted_virtual_keys_in_filters"` // Hide deleted virtual keys from logs/MCP filter data - RoutingChainMaxDepth int `json:"routing_chain_max_depth"` // Maximum depth for routing rule chain evaluation (default: 10) - ConfigHash string `json:"-"` // Config hash for reconciliation (not serialized) + DropExcessRequests bool `json:"drop_excess_requests"` // Drop excess requests if the provider queue is full + InitialPoolSize int `json:"initial_pool_size"` // The initial pool size for the bifrost client + PrometheusLabels []string `json:"prometheus_labels"` // The labels to be used for prometheus metrics + EnableLogging *bool `json:"enable_logging"` // Enable logging of requests and responses + DisableContentLogging bool `json:"disable_content_logging"` // Disable logging of content + AllowPerRequestContentStorageOverride bool `json:"allow_per_request_content_storage_override"` // Allow per-request override of content storage via x-bf-disable-content-logging header/context + AllowPerRequestRawOverride bool `json:"allow_per_request_raw_override"` // Allow per-request override of raw request/response visibility via 
x-bf-send-back-raw-request and x-bf-send-back-raw-response headers + DisableDBPingsInHealth bool `json:"disable_db_pings_in_health"` + LogRetentionDays int `json:"log_retention_days" validate:"min=1"` // Number of days to retain logs (minimum 1 day) + EnforceAuthOnInference bool `json:"enforce_auth_on_inference"` // Require auth (VK, API key, or user token) on inference endpoints + EnforceGovernanceHeader bool `json:"enforce_governance_header,omitempty"` // Deprecated: use EnforceAuthOnInference + EnforceSCIMAuth bool `json:"enforce_scim_auth,omitempty"` // Deprecated: use EnforceAuthOnInference + AllowDirectKeys bool `json:"allow_direct_keys"` // Allow direct keys to be used for requests + AllowedOrigins []string `json:"allowed_origins,omitempty"` // Additional allowed origins for CORS and WebSocket (localhost is always allowed) + AllowedHeaders []string `json:"allowed_headers,omitempty"` // Additional allowed headers for CORS and WebSocket + MaxRequestBodySizeMB int `json:"max_request_body_size_mb"` // The maximum request body size in MB + Compat CompatConfig `json:"compat"` // Compat plugin configuration + MCPAgentDepth int `json:"mcp_agent_depth"` // The maximum depth for MCP agent mode tool execution + MCPToolExecutionTimeout int `json:"mcp_tool_execution_timeout"` // The timeout for individual tool execution in seconds + MCPCodeModeBindingLevel string `json:"mcp_code_mode_binding_level"` // Code mode binding level: "server" or "tool" + MCPToolSyncInterval int `json:"mcp_tool_sync_interval"` // Global tool sync interval in minutes (default: 10, 0 = disabled) + MCPDisableAutoToolInject bool `json:"mcp_disable_auto_tool_inject"` // When true, MCP tools are not injected into requests by default + HeaderFilterConfig *tables.GlobalHeaderFilterConfig `json:"header_filter_config,omitempty"` // Global header filtering configuration for x-bf-eh-* headers + AsyncJobResultTTL int `json:"async_job_result_ttl"` // Default TTL for async job results in seconds (default: 3600 
= 1 hour) + RequiredHeaders []string `json:"required_headers,omitempty"` // Headers that must be present on every request (case-insensitive) + LoggingHeaders []string `json:"logging_headers,omitempty"` // Headers to capture in log metadata + WhitelistedRoutes []string `json:"whitelisted_routes,omitempty"` // Routes that bypass auth middleware + HideDeletedVirtualKeysInFilters bool `json:"hide_deleted_virtual_keys_in_filters"` // Hide deleted virtual keys from logs/MCP filter data + RoutingChainMaxDepth int `json:"routing_chain_max_depth"` // Maximum depth for routing rule chain evaluation (default: 10) + ConfigHash string `json:"-"` // Config hash for reconciliation (not serialized) } // GenerateClientConfigHash generates a SHA256 hash of the client configuration. @@ -174,6 +176,15 @@ func (c *ClientConfig) GenerateClientConfigHash() (string, error) { hash.Write([]byte("mcpDisableAutoToolInject:true")) } + // Only hash non-default value to avoid legacy config hash churn on upgrade. + if c.AllowPerRequestContentStorageOverride { + hash.Write([]byte("allowPerRequestContentStorageOverride:true")) + } + + if c.AllowPerRequestRawOverride { + hash.Write([]byte("allowPerRequestRawOverride:true")) + } + if c.AsyncJobResultTTL > 0 { hash.Write([]byte("asyncJobResultTTL:" + strconv.Itoa(c.AsyncJobResultTTL))) } else { diff --git a/plugins/governance/changelog.md b/plugins/governance/changelog.md index e69de29bb2..140698b1c3 100644 --- a/plugins/governance/changelog.md +++ b/plugins/governance/changelog.md @@ -0,0 +1 @@ +- fix: routing rule chain no longer halts when a chain_rule resolves to the same provider/model (self-loop); subsequent rules now continue to evaluate correctly diff --git a/plugins/governance/routing.go b/plugins/governance/routing.go index 4638f70cab..4c8b33eec4 100644 --- a/plugins/governance/routing.go +++ b/plugins/governance/routing.go @@ -79,7 +79,7 @@ func NewRoutingEngine(store GovernanceStore, logger schemas.Logger, chainMaxDept // and the full 
scope chain is re-evaluated with the updated context. This repeats until: // 1. No rule matches the current context // 2. A terminal rule matches (chain_rule=false, the default) -// 3. A cycle is detected (a provider/model state was already visited) +// 3. Every chain-rule that could match has already fired once (all candidates exhausted) // 4. The chain exceeds the configured max depth (chainMaxDepth, default 10) func (re *RoutingEngine) EvaluateRoutingRules(ctx *schemas.BifrostContext, routingCtx *RoutingContext) (*RoutingDecision, error) { if routingCtx == nil { @@ -92,10 +92,10 @@ func (re *RoutingEngine) EvaluateRoutingRules(ctx *schemas.BifrostContext, routi currentProvider := routingCtx.Provider currentModel := routingCtx.Model - // Track visited provider/model states to detect cycles (e.g. A→B→A). - visited := map[string]struct{}{ - fmt.Sprintf("%s|%s", currentProvider, currentModel): {}, - } + // Track which rule IDs have already fired to prevent a rule from matching more than once per chain. + // This allows a self-looping rule (target == current state) to fire once and then let subsequent + // rules in the chain run, rather than halting with a cycle error. 
+ visitedRuleIDs := map[string]struct{}{} var finalDecision *RoutingDecision @@ -154,6 +154,11 @@ func (re *RoutingEngine) EvaluateRoutingRules(ctx *schemas.BifrostContext, routi ctx.AppendRoutingEngineLog(schemas.RoutingEngineRoutingRule, schemas.LogLevelInfo, fmt.Sprintf("Evaluating scope %s: %d rules [%s]", scope.ScopeName, len(rules), strings.Join(ruleNames, ", "))) for _, rule := range rules { + if _, fired := visitedRuleIDs[rule.ID]; fired { + re.logger.Debug("[RoutingEngine] Skipping rule %s (already fired this chain)", rule.Name) + ctx.AppendRoutingEngineLog(schemas.RoutingEngineRoutingRule, schemas.LogLevelInfo, fmt.Sprintf("Rule '%s' skipped: already fired in this chain", rule.Name)) + continue + } re.logger.Debug("[RoutingEngine] Evaluating rule: name=%s, expression=%s", rule.Name, rule.CelExpression) program, err := re.store.GetRoutingProgram(ctx, rule) @@ -174,7 +179,7 @@ func (re *RoutingEngine) EvaluateRoutingRules(ctx *schemas.BifrostContext, routi if !matched { ctx.AppendRoutingEngineLog(schemas.RoutingEngineRoutingRule, schemas.LogLevelInfo, - fmt.Sprintf("Rule '%s' [%s] → no match (%s)", rule.Name, rule.CelExpression, buildNoMatchContext(rule.CelExpression, variables))) + fmt.Sprintf("Rule '%s' [%s] → no match (%s)", rule.Name, rule.CelExpression, buildNoMatchContext(rule.CelExpression, variables))) continue } @@ -236,14 +241,8 @@ func (re *RoutingEngine) EvaluateRoutingRules(ctx *schemas.BifrostContext, routi break } - // TERMINATION 3: Cycle detection — if the next state was already visited, continuing would loop forever. - nextState := fmt.Sprintf("%s|%s", stepDecision.Provider, stepDecision.Model) - if _, seen := visited[nextState]; seen { - re.logger.Debug("[RoutingEngine] Chain cycle detected at step=%d (state=%s already visited), stopping", chainStep, nextState) - ctx.AppendRoutingEngineLog(schemas.RoutingEngineRoutingRule, schemas.LogLevelInfo, fmt.Sprintf("Chain cycle detected at step %d (provider=%s, model=%s already visited), stopping. 
Final resolved: provider=%s, model=%s", chainStep, stepDecision.Provider, stepDecision.Model, stepDecision.Provider, stepDecision.Model)) - break - } - visited[nextState] = struct{}{} + // Mark this chain-rule as fired; it will be skipped in all subsequent chain steps. + visitedRuleIDs[matchedRule.ID] = struct{}{} // Advance context for next chain iteration. currentProvider = schemas.ModelProvider(stepDecision.Provider) diff --git a/plugins/governance/routing_test.go b/plugins/governance/routing_test.go index b1816d8dd4..f8125a975c 100644 --- a/plugins/governance/routing_test.go +++ b/plugins/governance/routing_test.go @@ -813,9 +813,9 @@ func TestEvaluateRoutingRules_TerminalRuleStopsChain(t *testing.T) { assert.Equal(t, "terminal-a", decision.MatchedRuleID) } -// TestEvaluateRoutingRules_ConvergenceStopsChain tests that the cycle-detection mechanism stops -// the chain when a chain_rule=true rule resolves to a provider/model already visited (no-op loop). -func TestEvaluateRoutingRules_ConvergenceStopsChain(t *testing.T) { +// TestEvaluateRoutingRules_SelfLoopContinuesToNextRule tests that a chain_rule=true rule which +// resolves to the same provider/model (self-loop) fires once and then allows the next rule to run. +func TestEvaluateRoutingRules_SelfLoopContinuesToNextRule(t *testing.T) { store, err := NewLocalGovernanceStore(context.Background(), NewMockLogger(), nil, &configstore.GovernanceConfig{}, nil) require.NoError(t, err) bgCtx := schemas.NewBifrostContext(context.Background(), time.Now()) @@ -823,10 +823,67 @@ func TestEvaluateRoutingRules_ConvergenceStopsChain(t *testing.T) { engine, err := NewRoutingEngine(store, NewMockLogger(), schemas.Ptr(10)) require.NoError(t, err) - // Rule A: chain_rule=true but resolves back to the initial provider/model — creates a cycle. + // Rule A: matches gpt-4o, chain_rule=true but resolves back to openai/gpt-4o (self-loop). + // Should fire once and then be skipped so Rule B can run. 
ruleA := &configstoreTables.TableRoutingRule{ - ID: "converge-a", - Name: "Convergence Rule A", + ID: "self-loop-a", + Name: "Self-Loop Rule A", + CelExpression: "model == 'gpt-4o'", + Targets: []configstoreTables.TableRoutingTarget{ + {Provider: bifrost.Ptr("openai"), Model: bifrost.Ptr("gpt-4o"), Weight: 1.0}, + }, + Enabled: true, + Scope: "global", + Priority: 0, + ChainRule: true, + } + require.NoError(t, store.UpdateRoutingRuleInMemory(context.Background(), ruleA)) + + // Rule B: also matches gpt-4o, terminal — should be reached after Rule A fires once. + ruleB := &configstoreTables.TableRoutingRule{ + ID: "self-loop-b", + Name: "Self-Loop Rule B", + CelExpression: "model == 'gpt-4o'", + Targets: []configstoreTables.TableRoutingTarget{ + {Provider: bifrost.Ptr("anthropic"), Model: bifrost.Ptr("claude-3"), Weight: 1.0}, + }, + Enabled: true, + Scope: "global", + Priority: 1, + ChainRule: false, + } + require.NoError(t, store.UpdateRoutingRuleInMemory(context.Background(), ruleB)) + + ctx := &RoutingContext{ + Provider: schemas.OpenAI, + Model: "gpt-4o", + Headers: map[string]string{}, + QueryParams: map[string]string{}, + } + + decision, err := engine.EvaluateRoutingRules(bgCtx, ctx) + require.NoError(t, err) + require.NotNil(t, decision) + + // Rule A fired once (self-loop), then was skipped. Rule B matched on the second step. + assert.Equal(t, "anthropic", decision.Provider) + assert.Equal(t, "claude-3", decision.Model) + assert.Equal(t, "self-loop-b", decision.MatchedRuleID) +} + +// TestEvaluateRoutingRules_SelfLoopAloneTerminates tests that a self-looping chain rule with no +// other rules terminates cleanly after firing once (TERMINATION 1: no remaining rule matches). 
+func TestEvaluateRoutingRules_SelfLoopAloneTerminates(t *testing.T) { + store, err := NewLocalGovernanceStore(context.Background(), NewMockLogger(), nil, &configstore.GovernanceConfig{}, nil) + require.NoError(t, err) + bgCtx := schemas.NewBifrostContext(context.Background(), time.Now()) + + engine, err := NewRoutingEngine(store, NewMockLogger(), schemas.Ptr(10)) + require.NoError(t, err) + + ruleA := &configstoreTables.TableRoutingRule{ + ID: "solo-self-loop", + Name: "Solo Self-Loop", CelExpression: "model == 'gpt-4o'", Targets: []configstoreTables.TableRoutingTarget{ {Provider: bifrost.Ptr("openai"), Model: bifrost.Ptr("gpt-4o"), Weight: 1.0}, @@ -849,10 +906,10 @@ func TestEvaluateRoutingRules_ConvergenceStopsChain(t *testing.T) { require.NoError(t, err) require.NotNil(t, decision) - // Cycle detected after the first match; the last matched decision (openai/gpt-4o) is returned. + // Rule A fired once, then was skipped. No other rules → terminates with Rule A's decision. assert.Equal(t, "openai", decision.Provider) assert.Equal(t, "gpt-4o", decision.Model) - assert.Equal(t, "converge-a", decision.MatchedRuleID) + assert.Equal(t, "solo-self-loop", decision.MatchedRuleID) } // TestEvaluateRoutingRules_MaxDepthCutoff tests that the chain stops once chainMaxDepth is reached, diff --git a/plugins/logging/changelog.md b/plugins/logging/changelog.md index e69de29bb2..baa3f9df70 100644 --- a/plugins/logging/changelog.md +++ b/plugins/logging/changelog.md @@ -0,0 +1 @@ +- feat: added per-request content logging toggle that overrides the global setting diff --git a/plugins/logging/main.go b/plugins/logging/main.go index fa1a484b90..d5607878dd 100644 --- a/plugins/logging/main.go +++ b/plugins/logging/main.go @@ -106,6 +106,37 @@ func applyLargePayloadPreviewsToEntry(ctx *schemas.BifrostContext, entry *logsto } } +// sanitizeErrorForLogging returns a shallow copy of err with ExtraFields.RawRequest and +// RawResponse cleared when raw-byte persistence is disabled, 
preventing raw bytes from +// leaking into entry.ErrorDetails via JSON serialization. +func sanitizeErrorForLogging(err *schemas.BifrostError, contentLoggingEnabled, shouldStoreRaw bool) *schemas.BifrostError { + if err == nil { + return nil + } + if contentLoggingEnabled && shouldStoreRaw { + return err + } + cloned := *err + cloned.ExtraFields.RawRequest = nil + cloned.ExtraFields.RawResponse = nil + return &cloned +} + +// contentLoggingEnabled returns true if content (messages, params, tool results) should be +// recorded for this request. The BifrostContextKeyDisableContentLogging per-request override is +// only honored when BifrostContextKeyAllowPerRequestStorageOverride is true in context (set by +// ConvertToBifrostContext from allow_per_request_content_storage_override config). +func (p *LoggerPlugin) contentLoggingEnabled(ctx *schemas.BifrostContext) bool { + if ctx != nil { + if perRequestAllowed, _ := ctx.Value(schemas.BifrostContextKeyAllowPerRequestStorageOverride).(bool); perRequestAllowed { + if override, ok := ctx.Value(schemas.BifrostContextKeyDisableContentLogging).(bool); ok { + return !override + } + } + } + return p.disableContentLogging == nil || !*p.disableContentLogging +} + // scheduleDeferredUsageUpdate schedules a deferred usage update for the request. func (p *LoggerPlugin) scheduleDeferredUsageUpdate(ctx *schemas.BifrostContext, requestID string, usageAlreadyPresent bool) { if usageAlreadyPresent || ctx == nil { @@ -475,7 +506,7 @@ func (p *LoggerPlugin) PreLLMHook(ctx *schemas.BifrostContext, req *schemas.Bifr initialData.Object = "realtime.turn" } - if p.disableContentLogging == nil || !*p.disableContentLogging { + if p.contentLoggingEnabled(ctx) { inputHistory, responsesInputHistory := p.extractInputHistory(req) initialData.InputHistory = inputHistory initialData.ResponsesInputHistory = responsesInputHistory @@ -713,7 +744,7 @@ func (p *LoggerPlugin) PostLLMHook(ctx *schemas.BifrostContext, result *schemas. 
requestType, _, originalModelRequested, resolvedModelUsed := bifrost.GetResponseFields(result, bifrostErr) shouldStoreRaw, _ := ctx.Value(schemas.BifrostContextKeyShouldStoreRawInLogs).(bool) - contentLoggingEnabled := p.disableContentLogging == nil || !*p.disableContentLogging + contentLoggingEnabled := p.contentLoggingEnabled(ctx) isFinalChunk := bifrost.IsFinalChunk(ctx) @@ -746,7 +777,7 @@ func (p *LoggerPlugin) PostLLMHook(ctx *schemas.BifrostContext, result *schemas. CreatedAt: time.Now().UTC(), } applyModelAlias(entry, originalModelRequested, resolvedModelUsed) - if data, err := sonic.Marshal(bifrostErr); err == nil { + if data, err := sonic.Marshal(sanitizeErrorForLogging(bifrostErr, contentLoggingEnabled, shouldStoreRaw)); err == nil { entry.ErrorDetails = string(data) } entry.ErrorDetailsParsed = bifrostErr @@ -827,7 +858,7 @@ func (p *LoggerPlugin) PostLLMHook(ctx *schemas.BifrostContext, result *schemas. traceID != "" { if accResult := tracer.ProcessStreamingChunk(traceID, true, result, bifrostErr); accResult != nil { if streamResponse := convertToProcessedStreamResponse(accResult, requestType); streamResponse != nil { - p.applyStreamingOutputToEntry(entry, streamResponse, shouldStoreRaw) + p.applyStreamingOutputToEntry(entry, streamResponse, shouldStoreRaw, contentLoggingEnabled) } } tracer.CleanupStreamAccumulator(traceID) @@ -836,7 +867,7 @@ func (p *LoggerPlugin) PostLLMHook(ctx *schemas.BifrostContext, result *schemas. // Serialize error details immediately since bifrostErr may be released // back to the pool before the async batch writer processes this entry. // Also set ErrorDetailsParsed for UI callback (JSON serialization uses this field). 
- if data, err := sonic.Marshal(bifrostErr); err == nil { + if data, err := sonic.Marshal(sanitizeErrorForLogging(bifrostErr, contentLoggingEnabled, shouldStoreRaw)); err == nil { entry.ErrorDetails = string(data) } entry.ErrorDetailsParsed = bifrostErr @@ -875,7 +906,7 @@ func (p *LoggerPlugin) PostLLMHook(ctx *schemas.BifrostContext, result *schemas. entry.Status = "error" entry.Stream = true applyModelAlias(entry, originalModelRequested, resolvedModelUsed) - if data, err := sonic.Marshal(bifrostErr); err == nil { + if data, err := sonic.Marshal(sanitizeErrorForLogging(bifrostErr, contentLoggingEnabled, shouldStoreRaw)); err == nil { entry.ErrorDetails = string(data) } entry.ErrorDetailsParsed = bifrostErr @@ -887,7 +918,7 @@ func (p *LoggerPlugin) PostLLMHook(ctx *schemas.BifrostContext, result *schemas. } else if isFinalChunk { // Apply streaming output fields to the entry entry.Stream = true - p.applyStreamingOutputToEntry(entry, streamResponse, shouldStoreRaw) + p.applyStreamingOutputToEntry(entry, streamResponse, shouldStoreRaw, contentLoggingEnabled) } if entry.ErrorDetails != "" || entry.ErrorDetailsParsed != nil { entry.Status = "error" @@ -926,7 +957,7 @@ func (p *LoggerPlugin) PostLLMHook(ctx *schemas.BifrostContext, result *schemas. // Serialize error details immediately since bifrostErr may be released // back to the pool before the async batch writer processes this entry. // Also set ErrorDetailsParsed for UI callback (JSON serialization uses this field). - if data, err := sonic.Marshal(bifrostErr); err == nil { + if data, err := sonic.Marshal(sanitizeErrorForLogging(bifrostErr, contentLoggingEnabled, shouldStoreRaw)); err == nil { entry.ErrorDetails = string(data) } entry.ErrorDetailsParsed = bifrostErr @@ -940,9 +971,9 @@ func (p *LoggerPlugin) PostLLMHook(ctx *schemas.BifrostContext, result *schemas. 
extraFields := result.GetExtraFields() applyModelAlias(entry, extraFields.OriginalModelRequested, extraFields.ResolvedModelUsed) if requestType == schemas.RealtimeRequest { - p.applyRealtimeOutputToEntry(entry, result, shouldStoreRaw) + p.applyRealtimeOutputToEntry(entry, result, shouldStoreRaw, contentLoggingEnabled) } else { - p.applyNonStreamingOutputToEntry(entry, result, shouldStoreRaw) + p.applyNonStreamingOutputToEntry(entry, result, shouldStoreRaw, contentLoggingEnabled) } // Flip status for passthrough error responses (4xx/5xx from provider) if isPassthroughErrorResponse(result) { @@ -1161,7 +1192,7 @@ func (p *LoggerPlugin) PreMCPHook(ctx *schemas.BifrostContext, req *schemas.Bifr } // Set arguments if content logging is enabled - if p.disableContentLogging == nil || !*p.disableContentLogging { + if p.contentLoggingEnabled(ctx) { entry.ArgumentsParsed = arguments } @@ -1262,7 +1293,7 @@ func (p *LoggerPlugin) PostMCPHook(ctx *schemas.BifrostContext, resp *schemas.Bi } else if resp != nil { updates["status"] = "success" // Store result if content logging is enabled - if p.disableContentLogging == nil || !*p.disableContentLogging { + if p.contentLoggingEnabled(ctx) { var result interface{} if resp.ChatMessage != nil { // For ChatMessage, try to parse the content as JSON if it's a string diff --git a/plugins/logging/operations.go b/plugins/logging/operations.go index 42850c5953..8443a32c35 100644 --- a/plugins/logging/operations.go +++ b/plugins/logging/operations.go @@ -147,6 +147,7 @@ func (p *LoggerPlugin) updateLogEntry( cacheDebug *schemas.BifrostCacheDebug, routingEngineLogs string, data *UpdateLogData, + contentLoggingEnabled bool, ) error { updates := make(map[string]interface{}) if selectedKeyID != "" { @@ -177,7 +178,6 @@ func (p *LoggerPlugin) updateLogEntry( if routingEngineLogs != "" { updates["routing_engine_logs"] = routingEngineLogs } - contentLoggingEnabled := p.disableContentLogging == nil || !*p.disableContentLogging tempEntry := 
&logstore.Log{} needsSerialization := false @@ -294,12 +294,12 @@ func (p *LoggerPlugin) updateLogEntry( if data.IsLargePayloadResponse { updates["is_large_payload_response"] = true // Large payload preview is already a string — skip sonic.Marshal. - if p.disableContentLogging == nil || !*p.disableContentLogging { + if contentLoggingEnabled { if str, ok := data.RawResponse.(string); ok { updates["raw_response"] = str } } - } else if (p.disableContentLogging == nil || !*p.disableContentLogging) && data.RawResponse != nil { + } else if contentLoggingEnabled && data.RawResponse != nil { rawResponseBytes, err := sonic.Marshal(data.RawResponse) if err != nil { p.logger.Error("failed to marshal raw response: %v", err) @@ -332,7 +332,7 @@ func (p *LoggerPlugin) makePostWriteCallback(enrichFn func(*logstore.Log)) func( // applyStreamingOutputToEntry applies accumulated streaming data to a log entry. // shouldStoreRaw gates whether raw request/response bytes are written to the entry. -func (p *LoggerPlugin) applyStreamingOutputToEntry(entry *logstore.Log, streamResponse *streaming.ProcessedStreamResponse, shouldStoreRaw bool) { +func (p *LoggerPlugin) applyStreamingOutputToEntry(entry *logstore.Log, streamResponse *streaming.ProcessedStreamResponse, shouldStoreRaw bool, contentLoggingEnabled bool) { if streamResponse.Data == nil { return } @@ -378,7 +378,17 @@ func (p *LoggerPlugin) applyStreamingOutputToEntry(entry *logstore.Log, streamRe entry.StopReason = streamResponse.Data.FinishReason } - if p.disableContentLogging == nil || !*p.disableContentLogging { + // Cache + if streamResponse.Data.CacheDebug != nil { + entry.CacheDebugParsed = streamResponse.Data.CacheDebug + } + + // Finish/stop reason - always persist regardless of content logging settings + if streamResponse.Data.FinishReason != nil { + entry.StopReason = streamResponse.Data.FinishReason + } + + if contentLoggingEnabled { // Transcription output if streamResponse.Data.TranscriptionOutput != nil { 
entry.TranscriptionOutputParsed = streamResponse.Data.TranscriptionOutput @@ -430,7 +440,7 @@ func isPassthroughErrorResponse(result *schemas.BifrostResponse) bool { // applyNonStreamingOutputToEntry applies non-streaming response data to a log entry. // shouldStoreRaw gates whether raw request/response bytes are written to the entry. -func (p *LoggerPlugin) applyNonStreamingOutputToEntry(entry *logstore.Log, result *schemas.BifrostResponse, shouldStoreRaw bool) { +func (p *LoggerPlugin) applyNonStreamingOutputToEntry(entry *logstore.Log, result *schemas.BifrostResponse, shouldStoreRaw bool, contentLoggingEnabled bool) { if result == nil { return } @@ -493,7 +503,7 @@ func (p *LoggerPlugin) applyNonStreamingOutputToEntry(entry *logstore.Log, resul entry.StopReason = result.ResponsesResponse.StopReason } - if p.disableContentLogging == nil || !*p.disableContentLogging { + if contentLoggingEnabled { if shouldStoreRaw { if extraFields.RawRequest != nil { rawRequestBytes, err := sonic.Marshal(extraFields.RawRequest) @@ -565,7 +575,7 @@ func (p *LoggerPlugin) applyNonStreamingOutputToEntry(entry *logstore.Log, resul } } -func (p *LoggerPlugin) applyRealtimeOutputToEntry(entry *logstore.Log, result *schemas.BifrostResponse, shouldStoreRaw bool) { +func (p *LoggerPlugin) applyRealtimeOutputToEntry(entry *logstore.Log, result *schemas.BifrostResponse, shouldStoreRaw bool, contentLoggingEnabled bool) { if result == nil || result.ResponsesResponse == nil { return } @@ -583,8 +593,6 @@ func (p *LoggerPlugin) applyRealtimeOutputToEntry(entry *logstore.Log, result *s entry.TotalTokens = bifrostUsage.TotalTokens } - contentLoggingEnabled := p.disableContentLogging == nil || !*p.disableContentLogging - if contentLoggingEnabled { if outputMessage := extractRealtimeOutputMessage(result.ResponsesResponse.Output); outputMessage != nil { entry.OutputMessageParsed = outputMessage diff --git a/plugins/logging/operations_test.go b/plugins/logging/operations_test.go index 
45a0b747e6..df03af9354 100644 --- a/plugins/logging/operations_test.go +++ b/plugins/logging/operations_test.go @@ -75,7 +75,7 @@ func TestUpdateLogEntryPreservesResponsesInputContentSummary(t *testing.T) { }}, } - if err := plugin.updateLogEntry(context.Background(), requestID, "", "", 10, "", "", "", "", 0, nil, "", update); err != nil { + if err := plugin.updateLogEntry(context.Background(), requestID, "", "", 10, "", "", "", "", 0, nil, "", update, true); err != nil { t.Fatalf("updateLogEntry() error = %v", err) } @@ -121,7 +121,7 @@ func TestUpdateLogEntryUpdatesContentSummaryForChatOutput(t *testing.T) { }, } - if err := plugin.updateLogEntry(context.Background(), requestID, "", "", 10, "", "", "", "", 0, nil, "", update); err != nil { + if err := plugin.updateLogEntry(context.Background(), requestID, "", "", 10, "", "", "", "", 0, nil, "", update, true); err != nil { t.Fatalf("updateLogEntry() error = %v", err) } @@ -136,11 +136,9 @@ func TestUpdateLogEntryUpdatesContentSummaryForChatOutput(t *testing.T) { func TestUpdateLogEntrySuppressesChatOutputWhenContentLoggingDisabled(t *testing.T) { store := newTestStore(t) - disableContentLogging := true plugin := &LoggerPlugin{ - store: store, - logger: testLogger{}, - disableContentLogging: &disableContentLogging, + store: store, + logger: testLogger{}, } requestID := "req-chat-disabled" @@ -166,7 +164,7 @@ func TestUpdateLogEntrySuppressesChatOutputWhenContentLoggingDisabled(t *testing }, } - if err := plugin.updateLogEntry(context.Background(), requestID, "", "", 10, "", "", "", "", 0, nil, "", update); err != nil { + if err := plugin.updateLogEntry(context.Background(), requestID, "", "", 10, "", "", "", "", 0, nil, "", update, false); err != nil { t.Fatalf("updateLogEntry() error = %v", err) } @@ -292,7 +290,7 @@ func TestApplyRealtimeOutputToEntryBackfillsUserTranscriptFromRawRequest(t *test }, } - plugin.applyRealtimeOutputToEntry(entry, result, true) + plugin.applyRealtimeOutputToEntry(entry, result, true, 
true) if err := entry.SerializeFields(); err != nil { t.Fatalf("SerializeFields() error = %v", err) } @@ -341,7 +339,7 @@ func TestApplyRealtimeOutputToEntryBackfillsMissingTranscriptPlaceholder(t *test }, } - plugin.applyRealtimeOutputToEntry(entry, result, true) + plugin.applyRealtimeOutputToEntry(entry, result, true, true) if err := entry.SerializeFields(); err != nil { t.Fatalf("SerializeFields() error = %v", err) } @@ -381,7 +379,7 @@ func TestApplyRealtimeOutputToEntryBackfillsDoneMissingTranscriptPlaceholder(t * }, } - plugin.applyRealtimeOutputToEntry(entry, result, true) + plugin.applyRealtimeOutputToEntry(entry, result, true, true) if err := entry.SerializeFields(); err != nil { t.Fatalf("SerializeFields() error = %v", err) } @@ -421,7 +419,7 @@ func TestApplyRealtimeOutputToEntryBackfillsRetrievedUserAndToolHistory(t *testi }, } - plugin.applyRealtimeOutputToEntry(entry, result, true) + plugin.applyRealtimeOutputToEntry(entry, result, true, true) if err := entry.SerializeFields(); err != nil { t.Fatalf("SerializeFields() error = %v", err) } @@ -462,7 +460,7 @@ func TestApplyRealtimeOutputToEntryBackfillsCreatedUserAndToolHistory(t *testing }, } - plugin.applyRealtimeOutputToEntry(entry, result, true) + plugin.applyRealtimeOutputToEntry(entry, result, true, true) if len(entry.InputHistoryParsed) != 2 { t.Fatalf("len(InputHistoryParsed) = %d, want 2", len(entry.InputHistoryParsed)) @@ -513,7 +511,7 @@ func TestApplyRealtimeOutputToEntryBackfillsAddedUserAndToolHistory(t *testing.T }, } - plugin.applyRealtimeOutputToEntry(entry, result, true) + plugin.applyRealtimeOutputToEntry(entry, result, true, true) if err := entry.SerializeFields(); err != nil { t.Fatalf("SerializeFields() error = %v", err) } @@ -576,7 +574,7 @@ func TestApplyRealtimeOutputToEntryMergesRawTranscriptIntoStructuredRealtimeHist }, } - plugin.applyRealtimeOutputToEntry(entry, result, true) + plugin.applyRealtimeOutputToEntry(entry, result, true, true) if err := entry.SerializeFields(); 
err != nil { t.Fatalf("SerializeFields() error = %v", err) } @@ -628,7 +626,7 @@ func TestApplyRealtimeOutputToEntryDoesNotPersistRawWhenShouldStoreRawFalse(t *t }, } - plugin.applyRealtimeOutputToEntry(entry, result, false) + plugin.applyRealtimeOutputToEntry(entry, result, false, true) if entry.RawRequest != "" { t.Fatalf("expected RawRequest to remain empty when shouldStoreRaw=false, got %q", entry.RawRequest) @@ -643,3 +641,202 @@ func TestApplyRealtimeOutputToEntryDoesNotPersistRawWhenShouldStoreRawFalse(t *t t.Fatalf("InputHistoryParsed[0].Role = %q, want user", entry.InputHistoryParsed[0].Role) } } + +// TestContentLoggingEnabledHelper verifies precedence: ctx override > global config > default-enabled. +func TestContentLoggingEnabledHelper(t *testing.T) { + boolPtr := func(b bool) *bool { return &b } + + tests := []struct { + name string + globalDisable *bool + ctxOverride *bool // nil = don't set the key + want bool + }{ + {"no config no override → enabled", nil, nil, true}, + {"global disable=false no override → enabled", boolPtr(false), nil, true}, + {"global disable=true no override → disabled", boolPtr(true), nil, false}, + {"ctx override=false global disable=true → enabled", boolPtr(true), boolPtr(false), true}, + {"ctx override=true global disable=false → disabled", boolPtr(false), boolPtr(true), false}, + {"ctx override=true nil global → disabled", nil, boolPtr(true), false}, + {"ctx override=false nil global → enabled", nil, boolPtr(false), true}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + p := &LoggerPlugin{disableContentLogging: tc.globalDisable} + + var ctx *schemas.BifrostContext + if tc.ctxOverride != nil { + ctx = schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) + ctx.SetValue(schemas.BifrostContextKeyAllowPerRequestStorageOverride, true) + ctx.SetValue(schemas.BifrostContextKeyDisableContentLogging, *tc.ctxOverride) + } + + got := p.contentLoggingEnabled(ctx) + if got != tc.want { + 
t.Errorf("contentLoggingEnabled() = %v, want %v", got, tc.want) + } + }) + } +} + +// TestContentLoggingEnabledHelperNilCtx verifies nil context falls back to global config. +func TestContentLoggingEnabledHelperNilCtx(t *testing.T) { + disabled := true + p := &LoggerPlugin{disableContentLogging: &disabled} + if p.contentLoggingEnabled(nil) { + t.Error("expected false with nil ctx and global disable=true") + } +} + +// TestUpdateLogEntryPerRequestOverrideEnablesContent verifies that passing contentLoggingEnabled=true +// to updateLogEntry stores output even when the plugin's global toggle is disabled. +func TestUpdateLogEntryPerRequestOverrideEnablesContent(t *testing.T) { + store := newTestStore(t) + disabled := true + plugin := &LoggerPlugin{ + store: store, + logger: testLogger{}, + disableContentLogging: &disabled, // global: off + } + + requestID := "req-per-request-enable" + now := time.Now().UTC() + if err := plugin.insertInitialLogEntry(context.Background(), requestID, "", now, 0, nil, &InitialLogData{ + Object: "chat_completion", + Provider: "openai", + Model: "gpt-4o-mini", + }); err != nil { + t.Fatalf("insertInitialLogEntry() error = %v", err) + } + + chatText := "should be stored via per-request override" + update := &UpdateLogData{ + Status: "success", + ChatOutput: &schemas.ChatMessage{ + Role: schemas.ChatMessageRoleAssistant, + Content: &schemas.ChatMessageContent{ContentStr: &chatText}, + }, + } + + // Explicitly pass true — simulates the per-request ctx override enabling content logging + if err := plugin.updateLogEntry(context.Background(), requestID, "", "", 10, "", "", "", "", 0, nil, "", update, true); err != nil { + t.Fatalf("updateLogEntry() error = %v", err) + } + + logEntry, err := store.FindByID(context.Background(), requestID) + if err != nil { + t.Fatalf("FindByID() error = %v", err) + } + if logEntry.OutputMessage == "" { + t.Error("expected output_message to be stored when contentLoggingEnabled=true override is used") + } +} + +// 
TestUpdateLogEntryPerRequestOverrideDisablesContent verifies that passing contentLoggingEnabled=false +// suppresses output even when the plugin's global toggle is enabled. +func TestUpdateLogEntryPerRequestOverrideDisablesContent(t *testing.T) { + store := newTestStore(t) + plugin := &LoggerPlugin{ + store: store, + logger: testLogger{}, + // global: nil → content logging on by default + } + + requestID := "req-per-request-disable" + now := time.Now().UTC() + if err := plugin.insertInitialLogEntry(context.Background(), requestID, "", now, 0, nil, &InitialLogData{ + Object: "chat_completion", + Provider: "openai", + Model: "gpt-4o-mini", + }); err != nil { + t.Fatalf("insertInitialLogEntry() error = %v", err) + } + + chatText := "should NOT be stored" + update := &UpdateLogData{ + Status: "success", + ChatOutput: &schemas.ChatMessage{ + Role: schemas.ChatMessageRoleAssistant, + Content: &schemas.ChatMessageContent{ContentStr: &chatText}, + }, + } + + // Explicitly pass false — simulates x-bf-disable-content-logging: true on this request + if err := plugin.updateLogEntry(context.Background(), requestID, "", "", 10, "", "", "", "", 0, nil, "", update, false); err != nil { + t.Fatalf("updateLogEntry() error = %v", err) + } + + logEntry, err := store.FindByID(context.Background(), requestID) + if err != nil { + t.Fatalf("FindByID() error = %v", err) + } + if logEntry.OutputMessage != "" { + t.Errorf("expected output_message to be suppressed, got %q", logEntry.OutputMessage) + } +} + +// TestApplyNonStreamingOutputToEntryContentLoggingDisabled verifies that output fields are +// suppressed when contentLoggingEnabled=false. 
+func TestApplyNonStreamingOutputToEntryContentLoggingDisabled(t *testing.T) { + plugin := &LoggerPlugin{} + entry := &logstore.Log{} + + chatText := "should not appear" + result := &schemas.BifrostResponse{ + ChatResponse: &schemas.BifrostChatResponse{ + Choices: []schemas.BifrostResponseChoice{ + { + ChatNonStreamResponseChoice: &schemas.ChatNonStreamResponseChoice{ + Message: &schemas.ChatMessage{ + Role: schemas.ChatMessageRoleAssistant, + Content: &schemas.ChatMessageContent{ContentStr: &chatText}, + }, + }, + }, + }, + ExtraFields: schemas.BifrostResponseExtraFields{ + RequestType: schemas.ChatCompletionRequest, + }, + }, + } + + plugin.applyNonStreamingOutputToEntry(entry, result, false, false) + + if entry.OutputMessageParsed != nil { + t.Error("expected OutputMessageParsed to be nil when contentLoggingEnabled=false") + } +} + +// TestApplyNonStreamingOutputToEntryContentLoggingEnabled verifies that output fields are +// stored when contentLoggingEnabled=true regardless of the global plugin config. 
+func TestApplyNonStreamingOutputToEntryContentLoggingEnabled(t *testing.T) { + disabled := true + plugin := &LoggerPlugin{disableContentLogging: &disabled} // global off, but explicit true passed + entry := &logstore.Log{} + + chatText := "should appear" + result := &schemas.BifrostResponse{ + ChatResponse: &schemas.BifrostChatResponse{ + Choices: []schemas.BifrostResponseChoice{ + { + ChatNonStreamResponseChoice: &schemas.ChatNonStreamResponseChoice{ + Message: &schemas.ChatMessage{ + Role: schemas.ChatMessageRoleAssistant, + Content: &schemas.ChatMessageContent{ContentStr: &chatText}, + }, + }, + }, + }, + ExtraFields: schemas.BifrostResponseExtraFields{ + RequestType: schemas.ChatCompletionRequest, + }, + }, + } + + plugin.applyNonStreamingOutputToEntry(entry, result, false, true) + + if entry.OutputMessageParsed == nil { + t.Error("expected OutputMessageParsed to be set when contentLoggingEnabled=true") + } +} diff --git a/plugins/semanticcache/main.go b/plugins/semanticcache/main.go index 148bbfef1a..c065ceff35 100644 --- a/plugins/semanticcache/main.go +++ b/plugins/semanticcache/main.go @@ -15,7 +15,6 @@ import ( bifrost "github.com/maximhq/bifrost/core" "github.com/maximhq/bifrost/core/schemas" - "github.com/maximhq/bifrost/framework" "github.com/maximhq/bifrost/framework/vectorstore" ) @@ -25,7 +24,6 @@ import ( type Config struct { // Embedding Model settings - REQUIRED for semantic caching Provider schemas.ModelProvider `json:"provider"` - Keys []schemas.Key `json:"keys"` EmbeddingModel string `json:"embedding_model,omitempty"` // Model to use for generating embeddings (optional) // Plugin behavior settings @@ -48,19 +46,18 @@ type Config struct { func (c *Config) UnmarshalJSON(data []byte) error { // Define a temporary struct to avoid infinite recursion type TempConfig struct { - Provider string `json:"provider"` - Keys []schemas.Key `json:"keys"` - EmbeddingModel string `json:"embedding_model,omitempty"` - CleanUpOnShutdown bool 
`json:"cleanup_on_shutdown,omitempty"` - Dimension int `json:"dimension"` - TTL interface{} `json:"ttl,omitempty"` - Threshold float64 `json:"threshold,omitempty"` - VectorStoreNamespace string `json:"vector_store_namespace,omitempty"` - DefaultCacheKey string `json:"default_cache_key,omitempty"` - ConversationHistoryThreshold int `json:"conversation_history_threshold,omitempty"` - CacheByModel *bool `json:"cache_by_model,omitempty"` - CacheByProvider *bool `json:"cache_by_provider,omitempty"` - ExcludeSystemPrompt *bool `json:"exclude_system_prompt,omitempty"` + Provider string `json:"provider"` + EmbeddingModel string `json:"embedding_model,omitempty"` + CleanUpOnShutdown bool `json:"cleanup_on_shutdown,omitempty"` + Dimension int `json:"dimension"` + TTL interface{} `json:"ttl,omitempty"` + Threshold float64 `json:"threshold,omitempty"` + VectorStoreNamespace string `json:"vector_store_namespace,omitempty"` + DefaultCacheKey string `json:"default_cache_key,omitempty"` + ConversationHistoryThreshold int `json:"conversation_history_threshold,omitempty"` + CacheByModel *bool `json:"cache_by_model,omitempty"` + CacheByProvider *bool `json:"cache_by_provider,omitempty"` + ExcludeSystemPrompt *bool `json:"exclude_system_prompt,omitempty"` } var temp TempConfig @@ -70,7 +67,6 @@ func (c *Config) UnmarshalJSON(data []byte) error { // Set simple fields c.Provider = schemas.ModelProvider(temp.Provider) - c.Keys = temp.Keys c.EmbeddingModel = temp.EmbeddingModel c.CleanUpOnShutdown = temp.CleanUpOnShutdown c.Dimension = temp.Dimension @@ -129,6 +125,10 @@ type StreamAccumulator struct { mu sync.Mutex // Protects chunk operations } +// EmbeddingRequestExecutor is a function that executes a request and returns a response and an error. +// It maps to .EmbeddingRequest() of the bifrost client. 
+type EmbeddingRequestExecutor func(ctx *schemas.BifrostContext, req *schemas.BifrostEmbeddingRequest) (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) + // Plugin implements the schemas.LLMPlugin interface for semantic caching. // It caches responses using a two-tier approach: direct hash matching for exact requests // and semantic similarity search for related content. The plugin supports configurable caching behavior @@ -139,12 +139,12 @@ type StreamAccumulator struct { // - config: Plugin configuration including semantic cache and caching settings // - logger: Logger instance for plugin operations type Plugin struct { - store vectorstore.VectorStore - config *Config - logger schemas.Logger - client *bifrost.Bifrost - streamAccumulators sync.Map // Track stream accumulators by request ID - waitGroup sync.WaitGroup + store vectorstore.VectorStore + config *Config + logger schemas.Logger + embeddingRequestExecutor EmbeddingRequestExecutor + streamAccumulators sync.Map // Track stream accumulators by request ID + waitGroup sync.WaitGroup } // Plugin constants @@ -201,45 +201,6 @@ var VectorStoreProperties = map[string]vectorstore.VectorStoreProperties{ }, } -type PluginAccount struct { - provider schemas.ModelProvider - keys []schemas.Key -} - -func (pa *PluginAccount) GetConfiguredProviders() ([]schemas.ModelProvider, error) { - return []schemas.ModelProvider{pa.provider}, nil -} - -func (pa *PluginAccount) GetKeysForProvider(ctx context.Context, providerKey schemas.ModelProvider) ([]schemas.Key, error) { - return pa.keys, nil -} - -func (pa *PluginAccount) GetConfigForProvider(providerKey schemas.ModelProvider) (*schemas.ProviderConfig, error) { - return &schemas.ProviderConfig{ - NetworkConfig: schemas.DefaultNetworkConfig, - ConcurrencyAndBufferSize: schemas.DefaultConcurrencyAndBufferSize, - }, nil -} - -// Dependencies is a list of dependencies that the plugin requires. 
-var Dependencies []framework.FrameworkDependency = []framework.FrameworkDependency{framework.FrameworkDependencyVectorStore} - -// ProvidersWithEmbeddingSupport lists all providers that support embedding operations. -// Providers not in this list will return UnsupportedOperationError for embedding requests. -var ProvidersWithEmbeddingSupport = map[schemas.ModelProvider]bool{ - schemas.OpenAI: true, - schemas.Azure: true, - schemas.Bedrock: true, - schemas.Cohere: true, - schemas.Gemini: true, - schemas.Vertex: true, - schemas.Mistral: true, - schemas.Ollama: true, - schemas.Nebius: true, - schemas.HuggingFace: true, - schemas.SGL: true, -} - const ( CacheKey schemas.BifrostContextKey = "semantic_cache_key" // To set the cache key for a request - REQUIRED for all requests CacheTTLKey schemas.BifrostContextKey = "semantic_cache_ttl" // To explicitly set the TTL for a request @@ -323,26 +284,8 @@ func Init(ctx context.Context, config *Config, logger schemas.Logger, store vect if config.Provider == "" && config.Dimension == 1 { logger.Info(PluginLoggerPrefix + " Starting in direct-only mode (dimension=1, no embedding provider)") - } else if config.Provider == "" || len(config.Keys) == 0 { - logger.Warn(PluginLoggerPrefix + " Incomplete semantic mode config: missing provider or keys, falling back to direct search only") - } else { - // Validate that the provider supports embeddings - if bifrost.IsStandardProvider(config.Provider) && !ProvidersWithEmbeddingSupport[config.Provider] { - return nil, fmt.Errorf("provider '%s' does not support embedding operations required for semantic cache. Supported providers: openai, azure, bedrock, cohere, gemini, vertex, mistral, ollama, nebius, huggingface, sgl. 
Note: custom providers based on embedding-capable providers are also supported", config.Provider) - } - - bifrost, err := bifrost.Init(ctx, schemas.BifrostConfig{ - Logger: logger, - Account: &PluginAccount{ - provider: config.Provider, - keys: config.Keys, - }, - }) - if err != nil { - return nil, fmt.Errorf("failed to initialize bifrost for semantic cache: %w", err) - } - - plugin.client = bifrost + } else if config.Provider == "" { + logger.Warn(PluginLoggerPrefix + " Incomplete semantic mode config: missing provider, falling back to direct search only") } createCtx, cancel := context.WithTimeout(ctx, CreateNamespaceTimeout) @@ -378,19 +321,6 @@ func (plugin *Plugin) HTTPTransportStreamChunkHook(ctx *schemas.BifrostContext, return chunk, nil } -func (plugin *Plugin) clearRequestScopedContext(ctx *schemas.BifrostContext) { - ctx.ClearValue(requestIDKey) - ctx.ClearValue(requestStorageIDKey) - ctx.ClearValue(requestHashKey) - ctx.ClearValue(requestParamsHashKey) - ctx.ClearValue(requestModelKey) - ctx.ClearValue(requestProviderKey) - ctx.ClearValue(requestEmbeddingKey) - ctx.ClearValue(requestEmbeddingTokensKey) - ctx.ClearValue(isCacheHitKey) - ctx.ClearValue(cacheHitTypeKey) -} - // PreLLMHook is called before a request is processed by Bifrost. // It performs a two-stage cache lookup: first direct hash matching, then semantic similarity search. // Uses UUID-based keys for entries stored in the VectorStore. 
@@ -465,7 +395,7 @@ func (plugin *Plugin) PreLLMHook(ctx *schemas.BifrostContext, req *schemas.Bifro } } - if performSemanticSearch && plugin.client != nil { + if performSemanticSearch && plugin.embeddingRequestExecutor != nil { if req.EmbeddingRequest != nil || req.TranscriptionRequest != nil { plugin.logger.Debug(PluginLoggerPrefix + " Skipping semantic search for embedding/transcription input") // For vector stores that require vectors, set a zero vector placeholder @@ -488,7 +418,7 @@ func (plugin *Plugin) PreLLMHook(ctx *schemas.BifrostContext, req *schemas.Bifro if shortCircuit != nil { return req, shortCircuit, nil } - } else if !performSemanticSearch && plugin.store.RequiresVectors() && plugin.client != nil { + } else if !performSemanticSearch && plugin.store.RequiresVectors() && plugin.embeddingRequestExecutor != nil { // Vector store requires vectors but we're in direct-only mode // Generate embeddings for storage purposes (not for searching) if req.EmbeddingRequest != nil || req.TranscriptionRequest != nil { @@ -759,11 +689,6 @@ func (plugin *Plugin) Cleanup() error { // Clean up old stream accumulators first plugin.cleanupOldStreamAccumulators() - // Shutdown the internal Bifrost client used for embeddings - if plugin.client != nil { - plugin.client.Shutdown() - } - // Only clean up cache entries if configured to do so if !plugin.config.CleanUpOnShutdown { plugin.logger.Debug(PluginLoggerPrefix + " Cleanup on shutdown is disabled, skipping cache cleanup") @@ -804,6 +729,15 @@ func (plugin *Plugin) Cleanup() error { return nil } +// SetEmbeddingRequestExecutor sets the embedding request executor for the plugin. +// Needs to be set before the plugin is used. 
+// +// Parameters: +// - executor: The embedding request executor to set +func (plugin *Plugin) SetEmbeddingRequestExecutor(executor EmbeddingRequestExecutor) { + plugin.embeddingRequestExecutor = executor +} + // Public Methods for External Use // ClearCacheForKey deletes cache entries for a specific cache key. @@ -869,3 +803,16 @@ func (plugin *Plugin) ClearCacheForRequestID(requestID string) error { return nil } + +func (plugin *Plugin) clearRequestScopedContext(ctx *schemas.BifrostContext) { + ctx.ClearValue(requestIDKey) + ctx.ClearValue(requestStorageIDKey) + ctx.ClearValue(requestHashKey) + ctx.ClearValue(requestParamsHashKey) + ctx.ClearValue(requestModelKey) + ctx.ClearValue(requestProviderKey) + ctx.ClearValue(requestEmbeddingKey) + ctx.ClearValue(requestEmbeddingTokensKey) + ctx.ClearValue(isCacheHitKey) + ctx.ClearValue(cacheHitTypeKey) +} diff --git a/plugins/semanticcache/plugin_core_test.go b/plugins/semanticcache/plugin_core_test.go index 822fc1f645..5bed26528d 100644 --- a/plugins/semanticcache/plugin_core_test.go +++ b/plugins/semanticcache/plugin_core_test.go @@ -2,7 +2,6 @@ package semanticcache import ( "context" - "strings" "testing" "time" @@ -389,9 +388,6 @@ func TestCacheConfiguration(t *testing.T) { EmbeddingModel: "text-embedding-3-small", Dimension: 1536, Threshold: 0.95, // Very high threshold - Keys: []schemas.Key{ - {Value: *schemas.NewEnvVar("env.OPENAI_API_KEY"), Models: schemas.WhiteList{"*"}, Weight: 1.0}, - }, }, expectedBehavior: "strict_matching", }, @@ -402,9 +398,6 @@ func TestCacheConfiguration(t *testing.T) { EmbeddingModel: "text-embedding-3-small", Dimension: 1536, Threshold: 0.1, // Very low threshold - Keys: []schemas.Key{ - {Value: *schemas.NewEnvVar("env.OPENAI_API_KEY"), Models: schemas.WhiteList{"*"}, Weight: 1.0}, - }, }, expectedBehavior: "loose_matching", }, @@ -416,9 +409,6 @@ func TestCacheConfiguration(t *testing.T) { Dimension: 1536, Threshold: 0.8, TTL: 1 * time.Hour, // Custom TTL - Keys: []schemas.Key{ - 
{Value: *schemas.NewEnvVar("env.OPENAI_API_KEY"), Models: schemas.WhiteList{"*"}, Weight: 1.0}, - }, }, expectedBehavior: "custom_ttl", }, @@ -463,7 +453,7 @@ func (m *MockUnsupportedStore) Ping(ctx context.Context) error { } func (m *MockUnsupportedStore) CreateNamespace(ctx context.Context, namespace string, dimension int, properties map[string]vectorstore.VectorStoreProperties) error { - return vectorstore.ErrNotSupported + return nil } func (m *MockUnsupportedStore) DeleteNamespace(ctx context.Context, namespace string) error { @@ -547,23 +537,13 @@ func TestInvalidProviderRejection(t *testing.T) { Dimension: 1536, Threshold: 0.8, CleanUpOnShutdown: false, - Keys: []schemas.Key{ - { - Value: *schemas.NewEnvVar("env.TEST_API_KEY"), - Models: schemas.WhiteList{"*"}, - Weight: 1.0, - }, - }, } + // Provider validation was moved to request time (global client handles it). + // Init itself should succeed regardless of the provider set in config. _, err := Init(ctx, config, logger, mockStore) - if err == nil { - t.Errorf("Expected error for provider '%s' but got none", provider) - } - - expectedErrSubstring := "does not support embedding operations" - if err != nil && !strings.Contains(err.Error(), expectedErrSubstring) { - t.Errorf("Expected error message to contain '%s', but got: %v", expectedErrSubstring, err) + if err != nil { + t.Errorf("Init should succeed for provider '%s' (validation happens at request time), but got: %v", provider, err) } }) } @@ -584,18 +564,11 @@ func TestValidProviderAccepted(t *testing.T) { Dimension: 1536, Threshold: 0.8, CleanUpOnShutdown: false, - Keys: []schemas.Key{ - { - Value: *schemas.NewEnvVar("env.OPENAI_API_KEY"), - Models: schemas.WhiteList{"*"}, - Weight: 1.0, - }, - }, } - // Should fail due to namespace creation, not provider validation + // Init should succeed; provider validation happens at request time via the global client. 
_, err := Init(ctx, config, logger, mockStore) - if err != nil && strings.Contains(err.Error(), "does not support embedding operations") { - t.Errorf("Valid provider OpenAI should not be rejected for embedding support, but got: %v", err) + if err != nil { + t.Errorf("Valid provider OpenAI should be accepted at Init, but got: %v", err) } } diff --git a/plugins/semanticcache/plugin_image_generation_test.go b/plugins/semanticcache/plugin_image_generation_test.go index f50f3c5c9b..a65c06e81b 100644 --- a/plugins/semanticcache/plugin_image_generation_test.go +++ b/plugins/semanticcache/plugin_image_generation_test.go @@ -128,9 +128,6 @@ func TestImageGenerationSemanticSearch(t *testing.T) { EmbeddingModel: "text-embedding-3-small", Dimension: 1536, Threshold: 0.5, - Keys: []schemas.Key{ - {Value: *schemas.NewEnvVar("env.OPENAI_API_KEY"), Models: []string{"*"}, Weight: 1.0}, - }, } setup := NewTestSetupWithConfig(t, config) defer setup.Cleanup() diff --git a/plugins/semanticcache/plugin_vectorstore_test.go b/plugins/semanticcache/plugin_vectorstore_test.go index 5e390bbe80..f4ac8130f2 100644 --- a/plugins/semanticcache/plugin_vectorstore_test.go +++ b/plugins/semanticcache/plugin_vectorstore_test.go @@ -55,13 +55,6 @@ func getDefaultTestConfig() *Config { Dimension: 1536, Threshold: 0.8, CleanUpOnShutdown: true, - Keys: []schemas.Key{ - { - Value: *schemas.NewEnvVar("env.OPENAI_API_KEY"), - Models: schemas.WhiteList{"*"}, - Weight: 1.0, - }, - }, } } diff --git a/plugins/semanticcache/test_utils.go b/plugins/semanticcache/test_utils.go index d267254473..e9b847c6dc 100644 --- a/plugins/semanticcache/test_utils.go +++ b/plugins/semanticcache/test_utils.go @@ -371,13 +371,6 @@ func NewTestSetup(t *testing.T) *TestSetup { Dimension: 1536, Threshold: 0.8, CleanUpOnShutdown: true, - Keys: []schemas.Key{ - { - Value: *schemas.NewEnvVar("env.OPENAI_API_KEY"), - Models: schemas.WhiteList{"*"}, - Weight: 1.0, - }, - }, }) } @@ -427,6 +420,9 @@ func NewTestSetupWithVectorStore(t 
*testing.T, config *Config, storeType vectors // Get a mocked Bifrost client client := getMockedBifrostClient(t, ctx, logger, plugin) + // Wire the global client as the embedding executor so semantic search works. + pluginImpl.SetEmbeddingRequestExecutor(client.EmbeddingRequest) + return &TestSetup{ Logger: logger, Store: store, @@ -648,13 +644,6 @@ func CreateTestSetupWithConversationThreshold(t *testing.T, threshold int) *Test CleanUpOnShutdown: true, Threshold: 0.8, ConversationHistoryThreshold: threshold, - Keys: []schemas.Key{ - { - Value: *schemas.NewEnvVar("env.OPENAI_API_KEY"), - Models: []string{"*"}, - Weight: 1.0, - }, - }, } return NewTestSetupWithConfig(t, config) @@ -669,13 +658,6 @@ func CreateTestSetupWithExcludeSystemPrompt(t *testing.T, excludeSystem bool) *T CleanUpOnShutdown: true, Threshold: 0.8, ExcludeSystemPrompt: &excludeSystem, - Keys: []schemas.Key{ - { - Value: *schemas.NewEnvVar("env.OPENAI_API_KEY"), - Models: []string{"*"}, - Weight: 1.0, - }, - }, } return NewTestSetupWithConfig(t, config) @@ -691,13 +673,6 @@ func CreateTestSetupWithThresholdAndExcludeSystem(t *testing.T, threshold int, e Threshold: 0.8, ConversationHistoryThreshold: threshold, ExcludeSystemPrompt: &excludeSystem, - Keys: []schemas.Key{ - { - Value: *schemas.NewEnvVar("env.OPENAI_API_KEY"), - Models: []string{"*"}, - Weight: 1.0, - }, - }, } return NewTestSetupWithConfig(t, config) diff --git a/plugins/semanticcache/utils.go b/plugins/semanticcache/utils.go index 712030051a..957115ee24 100644 --- a/plugins/semanticcache/utils.go +++ b/plugins/semanticcache/utils.go @@ -67,8 +67,16 @@ func (plugin *Plugin) generateEmbedding(ctx *schemas.BifrostContext, text string }, } - // Generate embedding using bifrost client - response, err := plugin.client.EmbeddingRequest(ctx, embeddingReq) + // Create a new context from incoming context. Parent ctx will be used for cancellation. 
+ embeddingCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) + defer embeddingCtx.ReleasePluginScope() + + embeddingCtx.SetValue(schemas.BifrostContextKeySkipPluginPipeline, true) + + if plugin.embeddingRequestExecutor == nil { + return nil, 0, fmt.Errorf("embedding request executor is not configured") + } + response, err := plugin.embeddingRequestExecutor(embeddingCtx, embeddingReq) if err != nil { return nil, 0, fmt.Errorf("failed to generate embedding: %v", err) } diff --git a/transports/bifrost-http/handlers/asyncinference.go b/transports/bifrost-http/handlers/asyncinference.go index 574bea702d..1a4f07e741 100644 --- a/transports/bifrost-http/handlers/asyncinference.go +++ b/transports/bifrost-http/handlers/asyncinference.go @@ -100,7 +100,7 @@ func (h *AsyncHandler) RegisterRoutes(r *router.Router, middlewares ...schemas.B // asyncTextCompletion handles POST /v1/async/completions func (h *AsyncHandler) asyncTextCompletion(ctx *fasthttp.RequestCtx) { - req, bifrostTextReq, err := prepareTextCompletionRequest(ctx) + req, bifrostTextReq, err := prepareTextCompletionRequest(ctx, h.config) if err != nil { SendError(ctx, fasthttp.StatusBadRequest, err.Error()) return @@ -111,7 +111,7 @@ func (h *AsyncHandler) asyncTextCompletion(ctx *fasthttp.RequestCtx) { return } - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore) if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") return @@ -137,7 +137,7 @@ func (h *AsyncHandler) asyncTextCompletion(ctx *fasthttp.RequestCtx) { // asyncChatCompletion handles POST /v1/async/chat/completions func (h *AsyncHandler) asyncChatCompletion(ctx *fasthttp.RequestCtx) { - req, bifrostChatReq, err := prepareChatCompletionRequest(ctx) + req, bifrostChatReq, err := prepareChatCompletionRequest(ctx, h.config) 
if err != nil { SendError(ctx, fasthttp.StatusBadRequest, err.Error()) return @@ -148,7 +148,7 @@ func (h *AsyncHandler) asyncChatCompletion(ctx *fasthttp.RequestCtx) { return } - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore) if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") return @@ -174,7 +174,7 @@ func (h *AsyncHandler) asyncChatCompletion(ctx *fasthttp.RequestCtx) { // asyncResponses handles POST /v1/async/responses func (h *AsyncHandler) asyncResponses(ctx *fasthttp.RequestCtx) { - req, bifrostResponsesReq, err := prepareResponsesRequest(ctx) + req, bifrostResponsesReq, err := prepareResponsesRequest(ctx, h.config) if err != nil { SendError(ctx, fasthttp.StatusBadRequest, err.Error()) return @@ -185,7 +185,7 @@ func (h *AsyncHandler) asyncResponses(ctx *fasthttp.RequestCtx) { return } - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore) if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") return @@ -212,13 +212,13 @@ func (h *AsyncHandler) asyncResponses(ctx *fasthttp.RequestCtx) { // asyncEmbeddings handles POST /v1/async/embeddings func (h *AsyncHandler) asyncEmbeddings(ctx *fasthttp.RequestCtx) { - _, bifrostEmbeddingReq, err := prepareEmbeddingRequest(ctx) + _, bifrostEmbeddingReq, err := prepareEmbeddingRequest(ctx, h.config) if err != nil { SendError(ctx, fasthttp.StatusBadRequest, err.Error()) return } - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) + bifrostCtx, cancel := 
lib.ConvertToBifrostContext(ctx, h.handlerStore) if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") return @@ -244,7 +244,7 @@ func (h *AsyncHandler) asyncEmbeddings(ctx *fasthttp.RequestCtx) { // asyncSpeech handles POST /v1/async/audio/speech func (h *AsyncHandler) asyncSpeech(ctx *fasthttp.RequestCtx) { - req, bifrostSpeechReq, err := prepareSpeechRequest(ctx) + req, bifrostSpeechReq, err := prepareSpeechRequest(ctx, h.config) if err != nil { SendError(ctx, fasthttp.StatusBadRequest, err.Error()) return @@ -255,7 +255,7 @@ func (h *AsyncHandler) asyncSpeech(ctx *fasthttp.RequestCtx) { return } - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore) if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") return @@ -281,7 +281,7 @@ func (h *AsyncHandler) asyncSpeech(ctx *fasthttp.RequestCtx) { // asyncTranscription handles POST /v1/async/audio/transcriptions func (h *AsyncHandler) asyncTranscription(ctx *fasthttp.RequestCtx) { - bifrostTranscriptionReq, stream, err := prepareTranscriptionRequest(ctx) + bifrostTranscriptionReq, stream, err := prepareTranscriptionRequest(ctx, h.config) if err != nil { SendError(ctx, fasthttp.StatusBadRequest, err.Error()) return @@ -292,7 +292,7 @@ func (h *AsyncHandler) asyncTranscription(ctx *fasthttp.RequestCtx) { return } - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore) if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") return @@ -318,7 +318,7 @@ func (h *AsyncHandler) asyncTranscription(ctx *fasthttp.RequestCtx) { // 
asyncImageGeneration handles POST /v1/async/images/generations func (h *AsyncHandler) asyncImageGeneration(ctx *fasthttp.RequestCtx) { - req, bifrostReq, err := prepareImageGenerationRequest(ctx) + req, bifrostReq, err := prepareImageGenerationRequest(ctx, h.config) if err != nil { SendError(ctx, fasthttp.StatusBadRequest, err.Error()) return @@ -329,7 +329,7 @@ func (h *AsyncHandler) asyncImageGeneration(ctx *fasthttp.RequestCtx) { return } - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore) if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") return @@ -355,7 +355,7 @@ func (h *AsyncHandler) asyncImageGeneration(ctx *fasthttp.RequestCtx) { // asyncImageEdit handles POST /v1/async/images/edits func (h *AsyncHandler) asyncImageEdit(ctx *fasthttp.RequestCtx) { - req, bifrostReq, err := prepareImageEditRequest(ctx) + req, bifrostReq, err := prepareImageEditRequest(ctx, h.config) if err != nil { SendError(ctx, fasthttp.StatusBadRequest, err.Error()) return @@ -366,7 +366,7 @@ func (h *AsyncHandler) asyncImageEdit(ctx *fasthttp.RequestCtx) { return } - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore) if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") return @@ -392,13 +392,13 @@ func (h *AsyncHandler) asyncImageEdit(ctx *fasthttp.RequestCtx) { // asyncImageVariation handles POST /v1/async/images/variations func (h *AsyncHandler) asyncImageVariation(ctx *fasthttp.RequestCtx) { - bifrostReq, err := prepareImageVariationRequest(ctx) + bifrostReq, err := prepareImageVariationRequest(ctx, h.config) if err != nil { 
SendError(ctx, fasthttp.StatusBadRequest, err.Error()) return } - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore) if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") return @@ -424,13 +424,13 @@ func (h *AsyncHandler) asyncImageVariation(ctx *fasthttp.RequestCtx) { // asyncRerank handles POST /v1/async/rerank func (h *AsyncHandler) asyncRerank(ctx *fasthttp.RequestCtx) { - _, bifrostReq, err := prepareRerankRequest(ctx) + _, bifrostReq, err := prepareRerankRequest(ctx, h.config) if err != nil { SendError(ctx, fasthttp.StatusBadRequest, err.Error()) return } - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore) if bifrostCtx == nil { SendError(ctx, fasthttp.StatusInternalServerError, "Failed to convert context") return @@ -456,13 +456,13 @@ func (h *AsyncHandler) asyncRerank(ctx *fasthttp.RequestCtx) { // asyncOCR handles POST /v1/async/ocr func (h *AsyncHandler) asyncOCR(ctx *fasthttp.RequestCtx) { - _, bifrostReq, err := prepareOCRRequest(ctx) + _, bifrostReq, err := prepareOCRRequest(ctx, h.config) if err != nil { SendError(ctx, fasthttp.StatusBadRequest, err.Error()) return } - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore) if bifrostCtx == nil { SendError(ctx, fasthttp.StatusInternalServerError, "Failed to convert context") return @@ -498,7 +498,7 @@ func (h *AsyncHandler) getJob(operationType schemas.RequestType) fasthttp.Reques } // Get the requesting user's 
VK for auth check - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore) if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") return diff --git a/transports/bifrost-http/handlers/inference.go b/transports/bifrost-http/handlers/inference.go index eb3c4bccb2..6ce3f3f04f 100644 --- a/transports/bifrost-http/handlers/inference.go +++ b/transports/bifrost-http/handlers/inference.go @@ -43,20 +43,76 @@ func forwardProviderHeadersFromContext(ctx *fasthttp.RequestCtx, bifrostCtx *sch // CompletionHandler manages HTTP requests for completion operations type CompletionHandler struct { - client *bifrost.Bifrost - handlerStore lib.HandlerStore - config *lib.Config + client *bifrost.Bifrost + config *lib.Config } // NewInferenceHandler creates a new completion handler instance func NewInferenceHandler(client *bifrost.Bifrost, config *lib.Config) *CompletionHandler { return &CompletionHandler{ - client: client, - handlerStore: config, - config: config, + client: client, + config: config, } } +// resolveModelAndProvider parses the model string, validates it, and resolves +// the provider via model catalog when no provider prefix is present. Stores +// resolution metadata on the fasthttp context for ConvertToBifrostContext to +// emit the routing engine log. 
+func resolveModelAndProvider(ctx *fasthttp.RequestCtx, config *lib.Config, model string) (schemas.ModelProvider, string, error) { + provider, modelName := schemas.ParseModelString(model, "") + if modelName == "" { + return "", "", fmt.Errorf("model is required") + } + if provider == "" { + providers := config.GetProvidersForModel(modelName) + if len(providers) == 0 { + return "", "", fmt.Errorf("provider is required in model field (format: provider/model) — no providers found for model %q in model catalog", modelName) + } + ctx.SetUserValue(lib.FastHTTPUserValueModelCatalogResolution, &lib.ModelCatalogResolution{ + Model: modelName, + ResolvedProvider: providers[0], + AllProviders: providers, + }) + provider = providers[0] + } + return provider, modelName, nil +} + +// prepareRequest is the generic entry point for all JSON-body prepare functions. +// It unmarshals the request body into T, resolves model+provider, parses +// fallbacks, and extracts extra params. Type-specific validation is left to +// the caller. 
+func prepareRequest[T baseRequest](ctx *fasthttp.RequestCtx, config *lib.Config, knownFields map[string]bool) (*T, *requestBase, error) { + req := new(T) + if err := sonic.Unmarshal(ctx.PostBody(), req); err != nil { + return nil, nil, fmt.Errorf("invalid request format: %v", err) + } + provider, modelName, err := resolveModelAndProvider(ctx, config, (*req).getModel()) + if err != nil { + return nil, nil, err + } + fallbacks, err := parseFallbacks((*req).getFallbacks()) + if err != nil { + return nil, nil, err + } + var extraParams map[string]any + if knownFields != nil { + ep, epErr := extractExtraParams(ctx.PostBody(), knownFields) + if epErr != nil { + logger.Warn("Failed to extract extra params: %v", epErr) + } else { + extraParams = ep + } + } + return req, &requestBase{ + Provider: provider, + ModelName: modelName, + Fallbacks: fallbacks, + ExtraParams: extraParams, + }, nil +} + // Known fields for CompletionRequest var textParamsKnownFields = map[string]bool{ "prompt": true, @@ -280,15 +336,6 @@ var transcriptionParamsKnownFields = map[string]bool{ "file_format": true, } -var countTokensParamsKnownFields = map[string]bool{ - "model": true, - "messages": true, - "fallbacks": true, - "tools": true, - "instructions": true, - "text": true, -} - var batchCreateParamsKnownFields = map[string]bool{ "model": true, "input_file_id": true, @@ -314,6 +361,24 @@ type BifrostParams struct { StreamFormat *string `json:"stream_format,omitempty"` // For speech } +func (b BifrostParams) getModel() string { return b.Model } +func (b BifrostParams) getFallbacks() []string { return b.Fallbacks } + +// baseRequest is satisfied by any type that embeds BifrostParams. +type baseRequest interface { + getModel() string + getFallbacks() []string +} + +// requestBase holds the fields common to every JSON-body prepare function +// so that each type-specific prepareXRequest only handles validation. 
+type requestBase struct { + Provider schemas.ModelProvider + ModelName string + Fallbacks []schemas.Fallback + ExtraParams map[string]any +} + type TextRequest struct { Prompt *schemas.TextCompletionInput `json:"prompt"` BifrostParams @@ -721,7 +786,7 @@ func (h *CompletionHandler) listModels(ctx *fasthttp.RequestCtx) { provider := string(ctx.QueryArgs().Peek("provider")) // Convert context - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.config) defer cancel() // Ensure cleanup on function exit if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") @@ -812,16 +877,8 @@ func (h *CompletionHandler) listModels(ctx *fasthttp.RequestCtx) { } // prepareTextCompletionRequest prepares a BifrostTextCompletionRequest from the HTTP request body -func prepareTextCompletionRequest(ctx *fasthttp.RequestCtx) (*TextRequest, *schemas.BifrostTextCompletionRequest, error) { - var req TextRequest - if err := sonic.Unmarshal(ctx.PostBody(), &req); err != nil { - return nil, nil, fmt.Errorf("invalid request format: %v", err) - } - provider, modelName := schemas.ParseModelString(req.Model, "") - if provider == "" || modelName == "" { - return nil, nil, fmt.Errorf("model should be in provider/model format") - } - fallbacks, err := parseFallbacks(req.Fallbacks) +func prepareTextCompletionRequest(ctx *fasthttp.RequestCtx, config *lib.Config) (*TextRequest, *schemas.BifrostTextCompletionRequest, error) { + req, base, err := prepareRequest[TextRequest](ctx, config, textParamsKnownFields) if err != nil { return nil, nil, err } @@ -831,30 +888,24 @@ func prepareTextCompletionRequest(ctx *fasthttp.RequestCtx) (*TextRequest, *sche if req.TextCompletionParameters == nil { req.TextCompletionParameters = &schemas.TextCompletionParameters{} } - extraParams, err := 
extractExtraParams(ctx.PostBody(), textParamsKnownFields) - if err != nil { - logger.Warn("Failed to extract extra params: %v", err) - } else { - req.TextCompletionParameters.ExtraParams = extraParams - } - bifrostTextReq := &schemas.BifrostTextCompletionRequest{ - Provider: schemas.ModelProvider(provider), - Model: modelName, + req.TextCompletionParameters.ExtraParams = base.ExtraParams + return req, &schemas.BifrostTextCompletionRequest{ + Provider: base.Provider, + Model: base.ModelName, Input: req.Prompt, Params: req.TextCompletionParameters, - Fallbacks: fallbacks, - } - return &req, bifrostTextReq, nil + Fallbacks: base.Fallbacks, + }, nil } // textCompletion handles POST /v1/completions - Process text completion requests func (h *CompletionHandler) textCompletion(ctx *fasthttp.RequestCtx) { - req, bifrostTextReq, err := prepareTextCompletionRequest(ctx) + req, bifrostTextReq, err := prepareTextCompletionRequest(ctx, h.config) if err != nil { SendError(ctx, fasthttp.StatusBadRequest, err.Error()) return } - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.config) if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") return @@ -888,84 +939,52 @@ func (h *CompletionHandler) textCompletion(ctx *fasthttp.RequestCtx) { } // prepareChatCompletionRequest prepares a BifrostChatRequest from a ChatRequest -func prepareChatCompletionRequest(ctx *fasthttp.RequestCtx) (*ChatRequest, *schemas.BifrostChatRequest, error) { - req := ChatRequest{ - ChatParameters: &schemas.ChatParameters{}, - } - if err := sonic.Unmarshal(ctx.PostBody(), &req); err != nil { - return nil, nil, fmt.Errorf("invalid request format: %v", err) - } - - // Create BifrostChatRequest directly using segregated structure - provider, modelName := schemas.ParseModelString(req.Model, "") - if 
provider == "" || modelName == "" { - return nil, nil, fmt.Errorf("model should be in provider/model format") - } - - // Parse fallbacks using helper function - fallbacks, err := parseFallbacks(req.Fallbacks) +func prepareChatCompletionRequest(ctx *fasthttp.RequestCtx, config *lib.Config) (*ChatRequest, *schemas.BifrostChatRequest, error) { + req, base, err := prepareRequest[ChatRequest](ctx, config, chatParamsKnownFields) if err != nil { - return nil, nil, fmt.Errorf("failed to parse fallbacks: %v", err) + return nil, nil, err } - if len(req.Messages) == 0 { return nil, nil, fmt.Errorf("messages is required for chat completion") } - - // Extract extra params if req.ChatParameters == nil { req.ChatParameters = &schemas.ChatParameters{} } - - extraParams, err := extractExtraParams(ctx.PostBody(), chatParamsKnownFields) - if err != nil { - logger.Warn("Failed to extract extra params: %v", err) - } else { - // Handle max_tokens -> max_completion_tokens mapping after extracting extra params - // If max_completion_tokens is nil and max_tokens is present in extra params, map it - // This is to support the legacy max_tokens field, which is still used by some implementations. - if req.ChatParameters.MaxCompletionTokens == nil { - if maxTokensVal, exists := extraParams["max_tokens"]; exists { - // Type check and convert to int - // JSON numbers are unmarshaled as float64, so we need to handle that - var maxTokens int + // Handle max_tokens -> max_completion_tokens mapping. + // This supports the legacy max_tokens field still used by some implementations. 
+ if base.ExtraParams != nil { + if maxTokensVal, exists := base.ExtraParams["max_tokens"]; exists { + delete(base.ExtraParams, "max_tokens") + if req.ChatParameters.MaxCompletionTokens == nil { if maxTokensFloat, ok := maxTokensVal.(float64); ok { - maxTokens = int(maxTokensFloat) + maxTokens := int(maxTokensFloat) req.ChatParameters.MaxCompletionTokens = &maxTokens - // Remove max_tokens from extra params since we've mapped it - delete(extraParams, "max_tokens") } else if maxTokensInt, ok := maxTokensVal.(int); ok { req.ChatParameters.MaxCompletionTokens = &maxTokensInt - // Remove max_tokens from extra params since we've mapped it - delete(extraParams, "max_tokens") } } } - req.ChatParameters.ExtraParams = extraParams } - - // Create segregated BifrostChatRequest - bifrostChatReq := &schemas.BifrostChatRequest{ - Provider: schemas.ModelProvider(provider), - Model: modelName, + req.ChatParameters.ExtraParams = base.ExtraParams + return req, &schemas.BifrostChatRequest{ + Provider: base.Provider, + Model: base.ModelName, Input: req.Messages, Params: req.ChatParameters, - Fallbacks: fallbacks, - } - - return &req, bifrostChatReq, nil + Fallbacks: base.Fallbacks, + }, nil } // chatCompletion handles POST /v1/chat/completions - Process chat completion requests func (h *CompletionHandler) chatCompletion(ctx *fasthttp.RequestCtx) { - req, bifrostChatReq, err := prepareChatCompletionRequest(ctx) + req, bifrostChatReq, err := prepareChatCompletionRequest(ctx, h.config) if err != nil { SendError(ctx, fasthttp.StatusBadRequest, err.Error()) return } // Convert context - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.config) if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") return @@ -994,39 +1013,18 @@ func (h *CompletionHandler) chatCompletion(ctx 
*fasthttp.RequestCtx) { } // prepareResponsesRequest prepares a BifrostResponsesRequest from a ResponsesRequest -func prepareResponsesRequest(ctx *fasthttp.RequestCtx) (*ResponsesRequest, *schemas.BifrostResponsesRequest, error) { - var req ResponsesRequest - if err := sonic.Unmarshal(ctx.PostBody(), &req); err != nil { - return nil, nil, fmt.Errorf("invalid request format: %v", err) - } - - // Create BifrostResponsesRequest directly using segregated structure - provider, modelName := schemas.ParseModelString(req.Model, "") - if provider == "" || modelName == "" { - return nil, nil, fmt.Errorf("model should be in provider/model format") - } - - // Parse fallbacks using helper function - fallbacks, err := parseFallbacks(req.Fallbacks) +func prepareResponsesRequest(ctx *fasthttp.RequestCtx, config *lib.Config) (*ResponsesRequest, *schemas.BifrostResponsesRequest, error) { + req, base, err := prepareRequest[ResponsesRequest](ctx, config, responsesParamsKnownFields) if err != nil { - return nil, nil, fmt.Errorf("failed to parse fallbacks: %v", err) + return nil, nil, err } - if len(req.Input.ResponsesRequestInputArray) == 0 && req.Input.ResponsesRequestInputStr == nil { return nil, nil, fmt.Errorf("input is required for responses") } - - // Extract extra params if req.ResponsesParameters == nil { req.ResponsesParameters = &schemas.ResponsesParameters{} } - - extraParams, err := extractExtraParams(ctx.PostBody(), responsesParamsKnownFields) - if err != nil { - logger.Warn("Failed to extract extra params: %v", err) - } else { - req.ResponsesParameters.ExtraParams = extraParams - } + req.ResponsesParameters.ExtraParams = base.ExtraParams input := req.Input.ResponsesRequestInputArray if input == nil { @@ -1037,29 +1035,25 @@ func prepareResponsesRequest(ctx *fasthttp.RequestCtx) (*ResponsesRequest, *sche }, } } - - // Create segregated BifrostResponsesRequest - bifrostResponsesReq := &schemas.BifrostResponsesRequest{ - Provider: schemas.ModelProvider(provider), - Model: 
modelName, + return req, &schemas.BifrostResponsesRequest{ + Provider: base.Provider, + Model: base.ModelName, Input: input, Params: req.ResponsesParameters, - Fallbacks: fallbacks, - } - - return &req, bifrostResponsesReq, nil + Fallbacks: base.Fallbacks, + }, nil } // responses handles POST /v1/responses - Process responses requests func (h *CompletionHandler) responses(ctx *fasthttp.RequestCtx) { - req, bifrostResponsesReq, err := prepareResponsesRequest(ctx) + req, bifrostResponsesReq, err := prepareResponsesRequest(ctx, h.config) if err != nil { SendError(ctx, fasthttp.StatusBadRequest, err.Error()) return } // Convert context - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.config) if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") return @@ -1090,16 +1084,8 @@ func (h *CompletionHandler) responses(ctx *fasthttp.RequestCtx) { } // prepareEmbeddingRequest prepares a BifrostEmbeddingRequest from the HTTP request body -func prepareEmbeddingRequest(ctx *fasthttp.RequestCtx) (*EmbeddingRequest, *schemas.BifrostEmbeddingRequest, error) { - var req EmbeddingRequest - if err := sonic.Unmarshal(ctx.PostBody(), &req); err != nil { - return nil, nil, fmt.Errorf("invalid request format: %v", err) - } - provider, modelName := schemas.ParseModelString(req.Model, "") - if provider == "" || modelName == "" { - return nil, nil, fmt.Errorf("model should be in provider/model format") - } - fallbacks, err := parseFallbacks(req.Fallbacks) +func prepareEmbeddingRequest(ctx *fasthttp.RequestCtx, config *lib.Config) (*EmbeddingRequest, *schemas.BifrostEmbeddingRequest, error) { + req, base, err := prepareRequest[EmbeddingRequest](ctx, config, embeddingParamsKnownFields) if err != nil { return nil, nil, err } @@ -1109,31 +1095,25 @@ func prepareEmbeddingRequest(ctx 
*fasthttp.RequestCtx) (*EmbeddingRequest, *sche if req.EmbeddingParameters == nil { req.EmbeddingParameters = &schemas.EmbeddingParameters{} } - extraParams, err := extractExtraParams(ctx.PostBody(), embeddingParamsKnownFields) - if err != nil { - logger.Warn("Failed to extract extra params: %v", err) - } else { - req.EmbeddingParameters.ExtraParams = extraParams - } - bifrostEmbeddingReq := &schemas.BifrostEmbeddingRequest{ - Provider: schemas.ModelProvider(provider), - Model: modelName, + req.EmbeddingParameters.ExtraParams = base.ExtraParams + return req, &schemas.BifrostEmbeddingRequest{ + Provider: base.Provider, + Model: base.ModelName, Input: req.Input, Params: req.EmbeddingParameters, - Fallbacks: fallbacks, - } - return &req, bifrostEmbeddingReq, nil + Fallbacks: base.Fallbacks, + }, nil } // embeddings handles POST /v1/embeddings - Process embeddings requests func (h *CompletionHandler) embeddings(ctx *fasthttp.RequestCtx) { - _, bifrostEmbeddingReq, err := prepareEmbeddingRequest(ctx) + _, bifrostEmbeddingReq, err := prepareEmbeddingRequest(ctx, h.config) if err != nil { SendError(ctx, fasthttp.StatusBadRequest, err.Error()) return } - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.config) defer cancel() if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") @@ -1158,28 +1138,14 @@ func (h *CompletionHandler) embeddings(ctx *fasthttp.RequestCtx) { } // prepareRerankRequest prepares a BifrostRerankRequest from the HTTP request body -func prepareRerankRequest(ctx *fasthttp.RequestCtx) (*RerankRequest, *schemas.BifrostRerankRequest, error) { - var req RerankRequest - if err := sonic.Unmarshal(ctx.PostBody(), &req); err != nil { - return nil, nil, fmt.Errorf("invalid request format: %v", err) - } - - // Parse model - provider, modelName := 
schemas.ParseModelString(req.Model, "") - if provider == "" || modelName == "" { - return nil, nil, fmt.Errorf("model should be in provider/model format") - } - - // Parse fallbacks - fallbacks, err := parseFallbacks(req.Fallbacks) +func prepareRerankRequest(ctx *fasthttp.RequestCtx, config *lib.Config) (*RerankRequest, *schemas.BifrostRerankRequest, error) { + req, base, err := prepareRequest[RerankRequest](ctx, config, rerankParamsKnownFields) if err != nil { - return nil, nil, fmt.Errorf("failed to parse fallbacks: %v", err) + return nil, nil, err } - if strings.TrimSpace(req.Query) == "" { return nil, nil, fmt.Errorf("query is required for rerank") } - if len(req.Documents) == 0 { return nil, nil, fmt.Errorf("documents are required for rerank") } @@ -1188,45 +1154,33 @@ func prepareRerankRequest(ctx *fasthttp.RequestCtx) (*RerankRequest, *schemas.Bi return nil, nil, fmt.Errorf("document text is required for rerank at index %d", i) } } - - // Extract extra params if req.RerankParameters == nil { req.RerankParameters = &schemas.RerankParameters{} } if req.RerankParameters.TopN != nil && *req.RerankParameters.TopN < 1 { return nil, nil, fmt.Errorf("top_n must be at least 1") } - - extraParams, err := extractExtraParams(ctx.PostBody(), rerankParamsKnownFields) - if err != nil { - logger.Warn("Failed to extract extra params: %v", err) - } else { - req.RerankParameters.ExtraParams = extraParams - } - - // Create BifrostRerankRequest - bifrostRerankReq := &schemas.BifrostRerankRequest{ - Provider: schemas.ModelProvider(provider), - Model: modelName, + req.RerankParameters.ExtraParams = base.ExtraParams + return req, &schemas.BifrostRerankRequest{ + Provider: base.Provider, + Model: base.ModelName, Query: req.Query, Documents: req.Documents, Params: req.RerankParameters, - Fallbacks: fallbacks, - } - - return &req, bifrostRerankReq, nil + Fallbacks: base.Fallbacks, + }, nil } // rerank handles POST /v1/rerank - Process rerank requests func (h *CompletionHandler) 
rerank(ctx *fasthttp.RequestCtx) { - _, bifrostRerankReq, err := prepareRerankRequest(ctx) + _, bifrostRerankReq, err := prepareRerankRequest(ctx, h.config) if err != nil { SendError(ctx, fasthttp.StatusBadRequest, err.Error()) return } // Convert context - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.config) defer cancel() if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") @@ -1252,71 +1206,44 @@ func (h *CompletionHandler) rerank(ctx *fasthttp.RequestCtx) { } // prepareOCRRequest prepares a BifrostOCRRequest from the HTTP request body -func prepareOCRRequest(ctx *fasthttp.RequestCtx) (*OCRHandlerRequest, *schemas.BifrostOCRRequest, error) { - var req OCRHandlerRequest - if err := sonic.Unmarshal(ctx.PostBody(), &req); err != nil { - return nil, nil, fmt.Errorf("invalid request format: %v", err) - } - - // Parse model - provider, modelName := schemas.ParseModelString(req.Model, "") - if provider == "" || modelName == "" { - return nil, nil, fmt.Errorf("model should be in provider/model format") - } - - // Parse fallbacks - fallbacks, err := parseFallbacks(req.Fallbacks) +func prepareOCRRequest(ctx *fasthttp.RequestCtx, config *lib.Config) (*OCRHandlerRequest, *schemas.BifrostOCRRequest, error) { + req, base, err := prepareRequest[OCRHandlerRequest](ctx, config, ocrParamsKnownFields) if err != nil { - return nil, nil, fmt.Errorf("failed to parse fallbacks: %v", err) + return nil, nil, err } - if req.Document.Type == "" { return nil, nil, fmt.Errorf("document type is required for ocr") } - if req.Document.Type == schemas.OCRDocumentTypeDocumentURL && (req.Document.DocumentURL == nil || *req.Document.DocumentURL == "") { return nil, nil, fmt.Errorf("document_url is required when document type is document_url") } - if req.Document.Type == 
schemas.OCRDocumentTypeImageURL && (req.Document.ImageURL == nil || *req.Document.ImageURL == "") { return nil, nil, fmt.Errorf("image_url is required when document type is image_url") } - - // Extract extra params if req.OCRParameters == nil { req.OCRParameters = &schemas.OCRParameters{} } - - extraParams, err := extractExtraParams(ctx.PostBody(), ocrParamsKnownFields) - if err != nil { - logger.Warn("Failed to extract extra params: %v", err) - } else { - req.OCRParameters.ExtraParams = extraParams - } - - // Create BifrostOCRRequest - bifrostOCRReq := &schemas.BifrostOCRRequest{ - Provider: schemas.ModelProvider(provider), - Model: modelName, + req.OCRParameters.ExtraParams = base.ExtraParams + return req, &schemas.BifrostOCRRequest{ + Provider: base.Provider, + Model: base.ModelName, ID: req.ID, Document: req.Document, Params: req.OCRParameters, - Fallbacks: fallbacks, - } - - return &req, bifrostOCRReq, nil + Fallbacks: base.Fallbacks, + }, nil } // ocr handles POST /v1/ocr - Process OCR requests func (h *CompletionHandler) ocr(ctx *fasthttp.RequestCtx) { - _, bifrostOCRReq, err := prepareOCRRequest(ctx) + _, bifrostOCRReq, err := prepareOCRRequest(ctx, h.config) if err != nil { SendError(ctx, fasthttp.StatusBadRequest, err.Error()) return } // Convert context - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.config) defer cancel() if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") @@ -1342,16 +1269,8 @@ func (h *CompletionHandler) ocr(ctx *fasthttp.RequestCtx) { } // prepareSpeechRequest prepares a BifrostSpeechRequest from the HTTP request body -func prepareSpeechRequest(ctx *fasthttp.RequestCtx) (*SpeechRequest, *schemas.BifrostSpeechRequest, error) { - var req SpeechRequest - if err := sonic.Unmarshal(ctx.PostBody(), &req); err != nil { - 
return nil, nil, fmt.Errorf("invalid request format: %v", err) - } - provider, modelName := schemas.ParseModelString(req.Model, "") - if provider == "" || modelName == "" { - return nil, nil, fmt.Errorf("model should be in provider/model format") - } - fallbacks, err := parseFallbacks(req.Fallbacks) +func prepareSpeechRequest(ctx *fasthttp.RequestCtx, config *lib.Config) (*SpeechRequest, *schemas.BifrostSpeechRequest, error) { + req, base, err := prepareRequest[SpeechRequest](ctx, config, speechParamsKnownFields) if err != nil { return nil, nil, err } @@ -1361,34 +1280,25 @@ func prepareSpeechRequest(ctx *fasthttp.RequestCtx) (*SpeechRequest, *schemas.Bi if req.SpeechParameters == nil || req.VoiceConfig == nil || (req.VoiceConfig.Voice == nil && len(req.VoiceConfig.MultiVoiceConfig) == 0) { return nil, nil, fmt.Errorf("voice is required for speech completion") } - if req.SpeechParameters == nil { - req.SpeechParameters = &schemas.SpeechParameters{} - } - extraParams, err := extractExtraParams(ctx.PostBody(), speechParamsKnownFields) - if err != nil { - logger.Warn("Failed to extract extra params: %v", err) - } else { - req.SpeechParameters.ExtraParams = extraParams - } - bifrostSpeechReq := &schemas.BifrostSpeechRequest{ - Provider: schemas.ModelProvider(provider), - Model: modelName, + req.SpeechParameters.ExtraParams = base.ExtraParams + return req, &schemas.BifrostSpeechRequest{ + Provider: base.Provider, + Model: base.ModelName, Input: req.SpeechInput, Params: req.SpeechParameters, - Fallbacks: fallbacks, - } - return &req, bifrostSpeechReq, nil + Fallbacks: base.Fallbacks, + }, nil } // speech handles POST /v1/audio/speech - Process speech completion requests func (h *CompletionHandler) speech(ctx *fasthttp.RequestCtx) { - req, bifrostSpeechReq, err := prepareSpeechRequest(ctx) + req, bifrostSpeechReq, err := prepareSpeechRequest(ctx, h.config) if err != nil { SendError(ctx, fasthttp.StatusBadRequest, err.Error()) return } - bifrostCtx, cancel := 
lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.config) if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") return @@ -1445,7 +1355,7 @@ func (h *CompletionHandler) speech(ctx *fasthttp.RequestCtx) { // prepareTranscriptionRequest prepares a BifrostTranscriptionRequest from a multipart form. // Returns the request, whether streaming was requested, and any error. -func prepareTranscriptionRequest(ctx *fasthttp.RequestCtx) (*schemas.BifrostTranscriptionRequest, bool, error) { +func prepareTranscriptionRequest(ctx *fasthttp.RequestCtx, config *lib.Config) (*schemas.BifrostTranscriptionRequest, bool, error) { form, err := ctx.MultipartForm() if err != nil { return nil, false, fmt.Errorf("failed to parse multipart form: %v", err) @@ -1454,9 +1364,9 @@ func prepareTranscriptionRequest(ctx *fasthttp.RequestCtx) (*schemas.BifrostTran if len(modelValues) == 0 || modelValues[0] == "" { return nil, false, fmt.Errorf("model is required") } - provider, modelName := schemas.ParseModelString(modelValues[0], "") - if provider == "" || modelName == "" { - return nil, false, fmt.Errorf("model should be in provider/model format") + provider, modelName, err := resolveModelAndProvider(ctx, config, modelValues[0]) + if err != nil { + return nil, false, err } fileHeaders := form.File["file"] if len(fileHeaders) == 0 { @@ -1509,13 +1419,13 @@ func prepareTranscriptionRequest(ctx *fasthttp.RequestCtx) (*schemas.BifrostTran // transcription handles POST /v1/audio/transcriptions - Process transcription requests func (h *CompletionHandler) transcription(ctx *fasthttp.RequestCtx) { - bifrostTranscriptionReq, stream, err := prepareTranscriptionRequest(ctx) + bifrostTranscriptionReq, stream, err := prepareTranscriptionRequest(ctx, h.config) if err != nil { SendError(ctx, fasthttp.StatusBadRequest, 
err.Error()) return } - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.config) if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") return @@ -1549,13 +1459,13 @@ func (h *CompletionHandler) transcription(ctx *fasthttp.RequestCtx) { // countTokens handles POST /v1/responses/input_tokens - Process count tokens requests func (h *CompletionHandler) countTokens(ctx *fasthttp.RequestCtx) { - _, bifrostResponsesReq, err := prepareResponsesRequest(ctx) + _, bifrostResponsesReq, err := prepareResponsesRequest(ctx, h.config) if err != nil { SendError(ctx, fasthttp.StatusBadRequest, err.Error()) return } - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.config) if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") return @@ -1922,14 +1832,10 @@ func (h *CompletionHandler) validateAudioFile(fileHeader *multipart.FileHeader) } // prepareImageGenerationRequest prepares a BifrostImageGenerationRequest from the HTTP request body -func prepareImageGenerationRequest(ctx *fasthttp.RequestCtx) (*ImageGenerationHTTPRequest, *schemas.BifrostImageGenerationRequest, error) { - var req ImageGenerationHTTPRequest - if err := sonic.Unmarshal(ctx.PostBody(), &req); err != nil { - return nil, nil, fmt.Errorf("invalid request format: %v", err) - } - provider, modelName := schemas.ParseModelString(req.Model, "") - if provider == "" || modelName == "" { - return nil, nil, fmt.Errorf("model should be in provider/model format") +func prepareImageGenerationRequest(ctx *fasthttp.RequestCtx, config *lib.Config) (*ImageGenerationHTTPRequest, 
*schemas.BifrostImageGenerationRequest, error) { + req, base, err := prepareRequest[ImageGenerationHTTPRequest](ctx, config, imageGenerationParamsKnownFields) + if err != nil { + return nil, nil, err } if req.ImageGenerationInput == nil || req.Prompt == "" { return nil, nil, fmt.Errorf("prompt cannot be empty") @@ -1937,35 +1843,25 @@ func prepareImageGenerationRequest(ctx *fasthttp.RequestCtx) (*ImageGenerationHT if req.ImageGenerationParameters == nil { req.ImageGenerationParameters = &schemas.ImageGenerationParameters{} } - extraParams, err := extractExtraParams(ctx.PostBody(), imageGenerationParamsKnownFields) - if err != nil { - logger.Warn("Failed to extract extra params: %v", err) - } else { - req.ImageGenerationParameters.ExtraParams = extraParams - } - fallbacks, err := parseFallbacks(req.Fallbacks) - if err != nil { - return nil, nil, err - } - bifrostReq := &schemas.BifrostImageGenerationRequest{ - Provider: schemas.ModelProvider(provider), - Model: modelName, + req.ImageGenerationParameters.ExtraParams = base.ExtraParams + return req, &schemas.BifrostImageGenerationRequest{ + Provider: base.Provider, + Model: base.ModelName, Input: req.ImageGenerationInput, Params: req.ImageGenerationParameters, - Fallbacks: fallbacks, - } - return &req, bifrostReq, nil + Fallbacks: base.Fallbacks, + }, nil } // imageGeneration handles POST /v1/images/generations - Processes image generation requests func (h *CompletionHandler) imageGeneration(ctx *fasthttp.RequestCtx) { - req, bifrostReq, err := prepareImageGenerationRequest(ctx) + req, bifrostReq, err := prepareImageGenerationRequest(ctx, h.config) if err != nil { SendError(ctx, fasthttp.StatusBadRequest, err.Error()) return } - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.config) if bifrostCtx == nil { cancel() SendError(ctx, 
fasthttp.StatusBadRequest, "Failed to convert context") @@ -2010,7 +1906,7 @@ func (h *CompletionHandler) handleStreamingImageGeneration(ctx *fasthttp.Request } // prepareImageEditRequest prepares a BifrostImageEditRequest from a multipart form -func prepareImageEditRequest(ctx *fasthttp.RequestCtx) (*ImageEditHTTPRequest, *schemas.BifrostImageEditRequest, error) { +func prepareImageEditRequest(ctx *fasthttp.RequestCtx, config *lib.Config) (*ImageEditHTTPRequest, *schemas.BifrostImageEditRequest, error) { var req ImageEditHTTPRequest form, err := ctx.MultipartForm() if err != nil { @@ -2021,9 +1917,9 @@ func prepareImageEditRequest(ctx *fasthttp.RequestCtx) (*ImageEditHTTPRequest, * return nil, nil, fmt.Errorf("model is required") } req.Model = modelValues[0] - provider, modelName := schemas.ParseModelString(req.Model, "") - if provider == "" || modelName == "" { - return nil, nil, fmt.Errorf("model should be in provider/model format") + provider, modelName, err := resolveModelAndProvider(ctx, config, req.Model) + if err != nil { + return nil, nil, err } var editType string if typeValues := form.Value["type"]; len(typeValues) > 0 && typeValues[0] != "" { @@ -2167,13 +2063,13 @@ func prepareImageEditRequest(ctx *fasthttp.RequestCtx) (*ImageEditHTTPRequest, * // imageEdit handles POST /v1/images/edits - Processes image edit requests func (h *CompletionHandler) imageEdit(ctx *fasthttp.RequestCtx) { - req, bifrostReq, err := prepareImageEditRequest(ctx) + req, bifrostReq, err := prepareImageEditRequest(ctx, h.config) if err != nil { SendError(ctx, fasthttp.StatusBadRequest, err.Error()) return } - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.config) if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") return @@ -2216,7 +2112,7 @@ func (h 
*CompletionHandler) handleStreamingImageEditRequest(ctx *fasthttp.Reques } // prepareImageVariationRequest prepares a BifrostImageVariationRequest from a multipart form -func prepareImageVariationRequest(ctx *fasthttp.RequestCtx) (*schemas.BifrostImageVariationRequest, error) { +func prepareImageVariationRequest(ctx *fasthttp.RequestCtx, config *lib.Config) (*schemas.BifrostImageVariationRequest, error) { rawBody := ctx.Request.Body() form, err := ctx.MultipartForm() if err != nil { @@ -2226,9 +2122,9 @@ func prepareImageVariationRequest(ctx *fasthttp.RequestCtx) (*schemas.BifrostIma if len(modelValues) == 0 || modelValues[0] == "" { return nil, fmt.Errorf("model is required") } - provider, modelName := schemas.ParseModelString(modelValues[0], "") - if provider == "" || modelName == "" { - return nil, fmt.Errorf("model should be in provider/model format") + provider, modelName, err := resolveModelAndProvider(ctx, config, modelValues[0]) + if err != nil { + return nil, err } var imageFiles []*multipart.FileHeader if imageFilesArray := form.File["image[]"]; len(imageFilesArray) > 0 { @@ -2310,13 +2206,13 @@ func prepareImageVariationRequest(ctx *fasthttp.RequestCtx) (*schemas.BifrostIma // imageVariation handles POST /v1/images/variations - Processes image variation requests func (h *CompletionHandler) imageVariation(ctx *fasthttp.RequestCtx) { - bifrostReq, err := prepareImageVariationRequest(ctx) + bifrostReq, err := prepareImageVariationRequest(ctx, h.config) if err != nil { SendError(ctx, fasthttp.StatusBadRequest, err.Error()) return } - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.config) if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") return @@ -2348,10 +2244,9 @@ func (h *CompletionHandler) videoGeneration(ctx *fasthttp.RequestCtx) { 
return } - // Create BifrostVideoGenerationRequest directly using segregated structure - provider, modelName := schemas.ParseModelString(req.Model, "") - if provider == "" || modelName == "" { - SendError(ctx, fasthttp.StatusBadRequest, "model should be in provider/model format") + provider, modelName, err := resolveModelAndProvider(ctx, h.config, req.Model) + if err != nil { + SendError(ctx, fasthttp.StatusBadRequest, err.Error()) return } @@ -2385,7 +2280,7 @@ func (h *CompletionHandler) videoGeneration(ctx *fasthttp.RequestCtx) { Fallbacks: fallbacks, } - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.config) if bifrostCtx == nil { cancel() SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") @@ -2439,7 +2334,7 @@ func (h *CompletionHandler) videoRetrieve(ctx *fasthttp.RequestCtx) { } // Convert context - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.config) defer cancel() if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") @@ -2497,7 +2392,7 @@ func (h *CompletionHandler) videoDownload(ctx *fasthttp.RequestCtx) { } // Convert context - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.config) defer cancel() if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") @@ -2559,7 +2454,7 @@ func (h *CompletionHandler) videoList(ctx *fasthttp.RequestCtx) { } // Convert context - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, 
h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.config) defer cancel() if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") @@ -2610,7 +2505,7 @@ func (h *CompletionHandler) videoDelete(ctx *fasthttp.RequestCtx) { } // Convert context - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.config) defer cancel() if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") @@ -2687,7 +2582,7 @@ func (h *CompletionHandler) videoRemix(ctx *fasthttp.RequestCtx) { } // Convert context - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.config) defer cancel() if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") @@ -2718,10 +2613,9 @@ func (h *CompletionHandler) batchCreate(ctx *fasthttp.RequestCtx) { return } - // Parse provider from model string - provider, modelName := schemas.ParseModelString(req.Model, "") - if provider == "" { - SendError(ctx, fasthttp.StatusBadRequest, "model should be in provider/model format or provider must be specified") + provider, modelName, err := resolveModelAndProvider(ctx, h.config, req.Model) + if err != nil { + SendError(ctx, fasthttp.StatusBadRequest, err.Error()) return } @@ -2755,7 +2649,7 @@ func (h *CompletionHandler) batchCreate(ctx *fasthttp.RequestCtx) { } // Convert context - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) + bifrostCtx, 
cancel := lib.ConvertToBifrostContext(ctx, h.config) defer cancel() if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") @@ -2815,7 +2709,7 @@ func (h *CompletionHandler) batchList(ctx *fasthttp.RequestCtx) { } // Convert context - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.config) defer cancel() if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") @@ -2861,7 +2755,7 @@ func (h *CompletionHandler) batchRetrieve(ctx *fasthttp.RequestCtx) { } // Convert context - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.config) defer cancel() if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") @@ -2907,7 +2801,7 @@ func (h *CompletionHandler) batchCancel(ctx *fasthttp.RequestCtx) { } // Convert context - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.config) defer cancel() if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") @@ -2953,7 +2847,7 @@ func (h *CompletionHandler) batchResults(ctx *fasthttp.RequestCtx) { } // Convert context - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.config) defer cancel() if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") @@ -3042,7 
+2936,7 @@ func (h *CompletionHandler) fileUpload(ctx *fasthttp.RequestCtx) { } // Convert context - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.config) defer cancel() if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") @@ -3108,7 +3002,7 @@ func (h *CompletionHandler) fileList(ctx *fasthttp.RequestCtx) { } // Convert context - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.config) defer cancel() if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") @@ -3154,7 +3048,7 @@ func (h *CompletionHandler) fileRetrieve(ctx *fasthttp.RequestCtx) { } // Convert context - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.config) defer cancel() if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") @@ -3200,7 +3094,7 @@ func (h *CompletionHandler) fileDelete(ctx *fasthttp.RequestCtx) { } // Convert context - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.config) defer cancel() if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") @@ -3246,7 +3140,7 @@ func (h *CompletionHandler) fileContent(ctx *fasthttp.RequestCtx) { } // Convert context - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, 
h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.config) defer cancel() if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") @@ -3307,7 +3201,7 @@ func (h *CompletionHandler) containerCreate(ctx *fasthttp.RequestCtx) { } // Convert context - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.config) defer cancel() if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") @@ -3366,7 +3260,7 @@ func (h *CompletionHandler) containerList(ctx *fasthttp.RequestCtx) { } // Convert context - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.config) defer cancel() if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") @@ -3413,7 +3307,7 @@ func (h *CompletionHandler) containerRetrieve(ctx *fasthttp.RequestCtx) { } // Convert context - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.config) defer cancel() if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") @@ -3460,7 +3354,7 @@ func (h *CompletionHandler) containerDelete(ctx *fasthttp.RequestCtx) { } // Convert context - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) + bifrostCtx, cancel := 
lib.ConvertToBifrostContext(ctx, h.config) defer cancel() if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") @@ -3557,7 +3451,7 @@ func (h *CompletionHandler) containerFileCreate(ctx *fasthttp.RequestCtx) { } // Convert context - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.config) defer cancel() if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") @@ -3617,7 +3511,7 @@ func (h *CompletionHandler) containerFileList(ctx *fasthttp.RequestCtx) { } // Convert context - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.config) defer cancel() if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") @@ -3672,7 +3566,7 @@ func (h *CompletionHandler) containerFileRetrieve(ctx *fasthttp.RequestCtx) { } // Convert context - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.config) defer cancel() if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") @@ -3727,7 +3621,7 @@ func (h *CompletionHandler) containerFileContent(ctx *fasthttp.RequestCtx) { } // Convert context - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.config) defer cancel() if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert 
context") @@ -3782,7 +3676,7 @@ func (h *CompletionHandler) containerFileDelete(ctx *fasthttp.RequestCtx) { } // Convert context - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.config) defer cancel() if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") diff --git a/transports/bifrost-http/handlers/mcpinference.go b/transports/bifrost-http/handlers/mcpinference.go index 4e80e18d5d..a3f8d65940 100644 --- a/transports/bifrost-http/handlers/mcpinference.go +++ b/transports/bifrost-http/handlers/mcpinference.go @@ -60,7 +60,7 @@ func (h *MCPInferenceHandler) executeChatMCPTool(ctx *fasthttp.RequestCtx) { } // Convert context - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, false, h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.config) defer cancel() // Ensure cleanup on function exit if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") @@ -93,7 +93,7 @@ func (h *MCPInferenceHandler) executeResponsesMCPTool(ctx *fasthttp.RequestCtx) } // Convert context - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, false, h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.config) defer cancel() // Ensure cleanup on function exit if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") diff --git a/transports/bifrost-http/handlers/mcpserver.go b/transports/bifrost-http/handlers/mcpserver.go index b03488ea6f..ca735dee3f 100644 --- a/transports/bifrost-http/handlers/mcpserver.go +++ b/transports/bifrost-http/handlers/mcpserver.go @@ -112,7 +112,7 @@ func (h *MCPServerHandler) handleMCPServer(ctx *fasthttp.RequestCtx) { } 
// Convert context - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, false, h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.config) defer cancel() injectMCPSessionIdentity(bifrostCtx, session) @@ -153,7 +153,7 @@ func (h *MCPServerHandler) handleMCPServerSSE(ctx *fasthttp.RequestCtx) { ctx.Response.Header.Set("Connection", "keep-alive") // Convert context - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, false, h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.config) injectMCPSessionIdentity(bifrostCtx, session) diff --git a/transports/bifrost-http/handlers/plugins.go b/transports/bifrost-http/handlers/plugins.go index 11b362dddf..3dc4f8353c 100644 --- a/transports/bifrost-http/handlers/plugins.go +++ b/transports/bifrost-http/handlers/plugins.go @@ -35,25 +35,23 @@ func NewPluginsHandler(pluginsLoader PluginsLoader, configStore configstore.Conf } } - - // CreatePluginRequest is the request body for creating a plugin type CreatePluginRequest struct { - Name string `json:"name"` - Enabled bool `json:"enabled"` - Config map[string]any `json:"config"` - Path *string `json:"path"` + Name string `json:"name"` + Enabled bool `json:"enabled"` + Config map[string]any `json:"config"` + Path *string `json:"path"` Placement *schemas.PluginPlacement `json:"placement,omitempty"` - Order *int `json:"order,omitempty"` + Order *int `json:"order,omitempty"` } // UpdatePluginRequest is the request body for updating a plugin type UpdatePluginRequest struct { - Enabled bool `json:"enabled"` - Path *string `json:"path"` - Config map[string]any `json:"config"` + Enabled bool `json:"enabled"` + Path *string `json:"path"` + Config map[string]any `json:"config"` Placement *schemas.PluginPlacement `json:"placement,omitempty"` - Order *int `json:"order,omitempty"` + Order *int `json:"order,omitempty"` } // 
RegisterRoutes registers the routes for the PluginsHandler @@ -66,15 +64,15 @@ func (h *PluginsHandler) RegisterRoutes(r *router.Router, middlewares ...schemas } type PluginResponse struct { - Name string `json:"name"` - ActualName string `json:"actualName"` - Enabled bool `json:"enabled"` - Config any `json:"config"` - IsCustom bool `json:"isCustom"` - Path *string `json:"path"` - Placement *schemas.PluginPlacement `json:"placement,omitempty"` - Order *int `json:"order,omitempty"` - Status schemas.PluginStatus `json:"status"` + Name string `json:"name"` + ActualName string `json:"actualName"` + Enabled bool `json:"enabled"` + Config any `json:"config"` + IsCustom bool `json:"isCustom"` + Path *string `json:"path"` + Placement *schemas.PluginPlacement `json:"placement,omitempty"` + Order *int `json:"order,omitempty"` + Status schemas.PluginStatus `json:"status"` } // buildPluginResponse constructs a PluginResponse with status for a given TablePlugin. diff --git a/transports/bifrost-http/handlers/realtime_client_secrets.go b/transports/bifrost-http/handlers/realtime_client_secrets.go index 620a1db0e1..9fe07dd61e 100644 --- a/transports/bifrost-http/handlers/realtime_client_secrets.go +++ b/transports/bifrost-http/handlers/realtime_client_secrets.go @@ -86,12 +86,7 @@ func (h *RealtimeClientSecretsHandler) handleRequest(ctx *fasthttp.RequestCtx) { return } - bifrostCtx, cancel := lib.ConvertToBifrostContext( - ctx, - h.handlerStore.ShouldAllowDirectKeys(), - h.config.GetHeaderMatcher(), - h.config.GetMCPHeaderCombinedAllowlist(), - ) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore) defer cancel() bifrostCtx.SetValue(schemas.BifrostContextKeyHTTPRequestType, schemas.RealtimeRequest) if route.DefaultProvider == schemas.OpenAI { diff --git a/transports/bifrost-http/handlers/webrtc_realtime.go b/transports/bifrost-http/handlers/webrtc_realtime.go index 9bf547952a..342328ec00 100644 --- a/transports/bifrost-http/handlers/webrtc_realtime.go +++ 
b/transports/bifrost-http/handlers/webrtc_realtime.go @@ -252,12 +252,7 @@ func (h *WebRTCRealtimeHandler) runWebRTCRelay( sdpOffer string, exchangeSDP func(ctx *schemas.BifrostContext, key schemas.Key, upstreamOffer string) (string, *schemas.BifrostError), ) { - bifrostCtx, cancel := lib.ConvertToBifrostContext( - ctx, - h.handlerStore.ShouldAllowDirectKeys(), - h.config.GetHeaderMatcher(), - h.config.GetMCPHeaderCombinedAllowlist(), - ) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore) defer cancel() bifrostCtx.SetValue(schemas.BifrostContextKeyHTTPRequestType, schemas.RealtimeRequest) if strings.HasPrefix(string(ctx.Path()), "/openai") { diff --git a/transports/bifrost-http/handlers/webrtc_realtime_test.go b/transports/bifrost-http/handlers/webrtc_realtime_test.go index a0c0d72c1a..f0eef3fc78 100644 --- a/transports/bifrost-http/handlers/webrtc_realtime_test.go +++ b/transports/bifrost-http/handlers/webrtc_realtime_test.go @@ -18,9 +18,9 @@ type testHandlerStore struct { kv *kvstore.Store } -func (s testHandlerStore) ShouldAllowDirectKeys() bool { return true } -func (s testHandlerStore) GetHeaderMatcher() *lib.HeaderMatcher { return nil } -func (s testHandlerStore) GetAvailableProviders() []schemas.ModelProvider { return nil } +func (s testHandlerStore) ShouldAllowDirectKeys() bool { return true } +func (s testHandlerStore) GetHeaderMatcher() *lib.HeaderMatcher { return nil } +func (s testHandlerStore) GetAvailableProviders(model string) []schemas.ModelProvider { return nil } func (s testHandlerStore) GetStreamChunkInterceptor() lib.StreamChunkInterceptor { return nil } @@ -28,6 +28,8 @@ func (s testHandlerStore) GetAsyncJobExecutor() *logstore.AsyncJobExecutor { re func (s testHandlerStore) GetAsyncJobResultTTL() int { return 0 } func (s testHandlerStore) GetKVStore() *kvstore.Store { return s.kv } func (s testHandlerStore) GetMCPHeaderCombinedAllowlist() schemas.WhiteList { return nil } +func (s testHandlerStore) 
ShouldAllowPerRequestStorageOverride() bool { return false } +func (s testHandlerStore) ShouldAllowPerRequestRawOverride() bool { return false } func TestResolveRealtimeSDPTarget_BaseRouteRequiresProviderPrefix(t *testing.T) { _, _, _, err := resolveRealtimeSDPTarget("/v1/realtime", []byte(`{"model":"gpt-4o-realtime-preview"}`)) diff --git a/transports/bifrost-http/handlers/wsresponses_test.go b/transports/bifrost-http/handlers/wsresponses_test.go index d334d04d33..9b8fd78e76 100644 --- a/transports/bifrost-http/handlers/wsresponses_test.go +++ b/transports/bifrost-http/handlers/wsresponses_test.go @@ -26,7 +26,7 @@ func (s testWSHandlerStore) GetHeaderMatcher() *lib.HeaderMatcher { return nil } -func (s testWSHandlerStore) GetAvailableProviders() []schemas.ModelProvider { +func (s testWSHandlerStore) GetAvailableProviders(model string) []schemas.ModelProvider { return nil } @@ -50,6 +50,9 @@ func (s testWSHandlerStore) GetMCPHeaderCombinedAllowlist() schemas.WhiteList { return nil } +func (s testWSHandlerStore) ShouldAllowPerRequestStorageOverride() bool { return false } +func (s testWSHandlerStore) ShouldAllowPerRequestRawOverride() bool { return false } + type timeoutNetError struct{} func (timeoutNetError) Error() string { return "i/o timeout" } diff --git a/transports/bifrost-http/integrations/anthropic.go b/transports/bifrost-http/integrations/anthropic.go index e21e54d2f5..eb5dfdc450 100644 --- a/transports/bifrost-http/integrations/anthropic.go +++ b/transports/bifrost-http/integrations/anthropic.go @@ -23,6 +23,18 @@ type AnthropicRouter struct { *GenericRouter } +// anthropicModelGetter extracts the model field from any Anthropic integration request type. +// It is called after body parsing, so req is fully populated. 
+func anthropicModelGetter(_ *fasthttp.RequestCtx, req interface{}) (string, error) { + switch r := req.(type) { + case *anthropic.AnthropicTextRequest: + return r.Model, nil + case *anthropic.AnthropicMessageRequest: + return r.Model, nil + } + return "", nil +} + // createAnthropicCompleteRouteConfig creates a route configuration for the `/v1/complete` endpoint. func createAnthropicCompleteRouteConfig(pathPrefix string) RouteConfig { return RouteConfig{ @@ -35,6 +47,7 @@ func createAnthropicCompleteRouteConfig(pathPrefix string) RouteConfig { GetRequestTypeInstance: func(ctx context.Context) interface{} { return &anthropic.AnthropicTextRequest{} }, + GetRequestModel: anthropicModelGetter, RequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*schemas.BifrostRequest, error) { if anthropicReq, ok := req.(*anthropic.AnthropicTextRequest); ok { return &schemas.BifrostRequest{ @@ -75,6 +88,7 @@ func createAnthropicMessagesRouteConfig(pathPrefix string, logger schemas.Logger GetRequestTypeInstance: func(ctx context.Context) interface{} { return &anthropic.AnthropicMessageRequest{} }, + GetRequestModel: anthropicModelGetter, RequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*schemas.BifrostRequest, error) { if anthropicReq, ok := req.(*anthropic.AnthropicMessageRequest); ok { bifrostReq := anthropicReq.ToBifrostResponsesRequest(ctx) @@ -394,6 +408,7 @@ func CreateAnthropicCountTokensRouteConfigs(pathPrefix string, handlerStore lib. 
GetRequestTypeInstance: func(ctx context.Context) interface{} { return &anthropic.AnthropicMessageRequest{} }, + GetRequestModel: anthropicModelGetter, RequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*schemas.BifrostRequest, error) { if anthropicReq, ok := req.(*anthropic.AnthropicMessageRequest); ok { bifrostReq := anthropicReq.ToBifrostResponsesRequest(ctx) diff --git a/transports/bifrost-http/integrations/bedrock.go b/transports/bifrost-http/integrations/bedrock.go index 14bc4d9911..14c2bd826d 100644 --- a/transports/bifrost-http/integrations/bedrock.go +++ b/transports/bifrost-http/integrations/bedrock.go @@ -21,6 +21,23 @@ type BedrockRouter struct { *GenericRouter } +// bedrockModelGetter extracts the model ID from any Bedrock integration request type. +// It is called after PreCallback, so req.ModelID is populated from the URL path param. +func bedrockModelGetter(_ *fasthttp.RequestCtx, req interface{}) (string, error) { + switch r := req.(type) { + case *bedrock.BedrockConverseRequest: + return r.ModelID, nil + case *bedrock.BedrockInvokeRequest: + return r.ModelID, nil + case *bedrock.BedrockCountTokensRequest: + if r.Input.Converse != nil { + return r.Input.Converse.ModelID, nil + } + return "", nil + } + return "", nil +} + // S3 context keys for storing request parameters const ( @@ -42,6 +59,7 @@ func createBedrockConverseRouteConfig(pathPrefix string, handlerStore lib.Handle GetHTTPRequestType: func(ctx *fasthttp.RequestCtx) schemas.RequestType { return schemas.ResponsesRequest }, + GetRequestModel: bedrockModelGetter, RequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*schemas.BifrostRequest, error) { if bedrockReq, ok := req.(*bedrock.BedrockConverseRequest); ok { bifrostReq, err := bedrockReq.ToBifrostResponsesRequest(ctx) @@ -77,6 +95,7 @@ func createBedrockConverseStreamRouteConfig(pathPrefix string, handlerStore lib. 
GetRequestTypeInstance: func(ctx context.Context) interface{} { return &bedrock.BedrockConverseRequest{} }, + GetRequestModel: bedrockModelGetter, RequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*schemas.BifrostRequest, error) { if bedrockReq, ok := req.(*bedrock.BedrockConverseRequest); ok { // Mark as streaming request @@ -127,6 +146,7 @@ func createBedrockInvokeWithResponseStreamRouteConfig(pathPrefix string, handler GetRequestTypeInstance: func(ctx context.Context) interface{} { return &bedrock.BedrockInvokeRequest{} }, + GetRequestModel: bedrockModelGetter, RequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*schemas.BifrostRequest, error) { if invokeReq, ok := req.(*bedrock.BedrockInvokeRequest); ok { requestType, _ := ctx.Value(schemas.BifrostContextKeyHTTPRequestType).(schemas.RequestType) @@ -201,6 +221,7 @@ func createBedrockInvokeRouteConfig(pathPrefix string, handlerStore lib.HandlerS GetRequestTypeInstance: func(ctx context.Context) interface{} { return &bedrock.BedrockInvokeRequest{} }, + GetRequestModel: bedrockModelGetter, RequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*schemas.BifrostRequest, error) { invokeReq, ok := req.(*bedrock.BedrockInvokeRequest) if !ok { @@ -317,6 +338,7 @@ func createBedrockCountTokensRouteConfig(pathPrefix string, handlerStore lib.Han GetHTTPRequestType: func(ctx *fasthttp.RequestCtx) schemas.RequestType { return schemas.CountTokensRequest }, + GetRequestModel: bedrockModelGetter, RequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*schemas.BifrostRequest, error) { if countTokensReq, ok := req.(*bedrock.BedrockCountTokensRequest); ok { if countTokensReq.Input.Converse == nil { diff --git a/transports/bifrost-http/integrations/bedrock_test.go b/transports/bifrost-http/integrations/bedrock_test.go index 16ab7ad4d5..8c01edaa83 100644 --- a/transports/bifrost-http/integrations/bedrock_test.go +++ 
b/transports/bifrost-http/integrations/bedrock_test.go @@ -30,7 +30,7 @@ func (m *mockHandlerStore) GetHeaderMatcher() *lib.HeaderMatcher { return m.headerMatcher } -func (m *mockHandlerStore) GetAvailableProviders() []schemas.ModelProvider { +func (m *mockHandlerStore) GetAvailableProviders(model string) []schemas.ModelProvider { return m.availableProviders } @@ -54,6 +54,14 @@ func (m *mockHandlerStore) GetMCPHeaderCombinedAllowlist() schemas.WhiteList { return m.mcpHeaderCombinedAllowlist } +func (m *mockHandlerStore) ShouldAllowPerRequestStorageOverride() bool { + return false +} + +func (m *mockHandlerStore) ShouldAllowPerRequestRawOverride() bool { + return false +} + // Ensure mockHandlerStore implements lib.HandlerStore var _ lib.HandlerStore = (*mockHandlerStore)(nil) diff --git a/transports/bifrost-http/integrations/cohere.go b/transports/bifrost-http/integrations/cohere.go index 37aad1c1a8..cf6b7ceaca 100644 --- a/transports/bifrost-http/integrations/cohere.go +++ b/transports/bifrost-http/integrations/cohere.go @@ -69,6 +69,22 @@ func NewCohereRouter(client *bifrost.Bifrost, handlerStore lib.HandlerStore, log } } +// cohereModelGetter extracts the model field from any Cohere integration request type. +// It is called after body parsing, so req is fully populated. +func cohereModelGetter(_ *fasthttp.RequestCtx, req interface{}) (string, error) { + switch r := req.(type) { + case *cohere.CohereChatRequest: + return r.Model, nil + case *cohere.CohereEmbeddingRequest: + return r.Model, nil + case *cohere.CohereRerankRequest: + return r.Model, nil + case *cohere.CohereCountTokensRequest: + return r.Model, nil + } + return "", nil +} + // CreateCohereRouteConfigs creates route configurations for Cohere API endpoints. 
func CreateCohereRouteConfigs(pathPrefix string) []RouteConfig { var routes []RouteConfig @@ -85,6 +101,7 @@ func CreateCohereRouteConfigs(pathPrefix string) []RouteConfig { GetRequestTypeInstance: func(ctx context.Context) interface{} { return &cohere.CohereChatRequest{} }, + GetRequestModel: cohereModelGetter, RequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*schemas.BifrostRequest, error) { if cohereReq, ok := req.(*cohere.CohereChatRequest); ok { return &schemas.BifrostRequest{ @@ -131,6 +148,7 @@ func CreateCohereRouteConfigs(pathPrefix string) []RouteConfig { GetRequestTypeInstance: func(ctx context.Context) interface{} { return &cohere.CohereEmbeddingRequest{} }, + GetRequestModel: cohereModelGetter, RequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*schemas.BifrostRequest, error) { if cohereReq, ok := req.(*cohere.CohereEmbeddingRequest); ok { return &schemas.BifrostRequest{ @@ -164,6 +182,7 @@ func CreateCohereRouteConfigs(pathPrefix string) []RouteConfig { GetRequestTypeInstance: func(ctx context.Context) interface{} { return &cohere.CohereRerankRequest{} }, + GetRequestModel: cohereModelGetter, RequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*schemas.BifrostRequest, error) { if cohereReq, ok := req.(*cohere.CohereRerankRequest); ok { return &schemas.BifrostRequest{ @@ -197,6 +216,7 @@ func CreateCohereRouteConfigs(pathPrefix string) []RouteConfig { GetRequestTypeInstance: func(ctx context.Context) interface{} { return &cohere.CohereCountTokensRequest{} }, + GetRequestModel: cohereModelGetter, RequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*schemas.BifrostRequest, error) { if cohereReq, ok := req.(*cohere.CohereCountTokensRequest); ok { return &schemas.BifrostRequest{ diff --git a/transports/bifrost-http/integrations/genai.go b/transports/bifrost-http/integrations/genai.go index 012a5e078a..88229393fe 100644 --- a/transports/bifrost-http/integrations/genai.go +++ 
b/transports/bifrost-http/integrations/genai.go @@ -35,6 +35,33 @@ type GenAIRouter struct { *GenericRouter } +// genAIModelGetter extracts the model name for GenAI routes. +// For request types populated by extractAndSetModelAndRequestType (the PreCallback), +// the model is already clean on the struct. For BifrostVideoRetrieveRequest (which has +// no model field), the provider-scoped model is extracted from the operation_id suffix +// (format: "op123:openai/gpt-4o") since the route pins the provider via operation_id. +func genAIModelGetter(ctx *fasthttp.RequestCtx, req interface{}) (string, error) { + switch r := req.(type) { + case *gemini.GeminiGenerationRequest: + return r.Model, nil + case *gemini.GeminiEmbeddingRequest: + return r.Model, nil + case *gemini.GeminiVideoGenerationRequest: + return r.Model, nil + case *gemini.GeminiBatchCreateRequest: + return r.Model, nil + case *schemas.BifrostVideoRetrieveRequest: + // operation_id encodes the full model string: "op123:gpt-4o" or "op123:openai/gpt-4o". + operationID, _ := ctx.UserValue("operation_id").(string) + parts := strings.Split(operationID, ":") + if len(parts) >= 2 && parts[len(parts)-1] != "" { + return parts[len(parts)-1], nil + } + return "", nil + } + return "", nil +} + // CreateGenAIRouteConfigs creates a route configurations for GenAI endpoints. 
func CreateGenAIRouteConfigs(pathPrefix string) []RouteConfig { var routes []RouteConfig @@ -51,6 +78,7 @@ func CreateGenAIRouteConfigs(pathPrefix string) []RouteConfig { GetRequestTypeInstance: func(ctx context.Context) interface{} { return &schemas.BifrostVideoRetrieveRequest{} }, + GetRequestModel: genAIModelGetter, RequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*schemas.BifrostRequest, error) { if videoRetrieveReq, ok := req.(*schemas.BifrostVideoRetrieveRequest); ok { return &schemas.BifrostRequest{ @@ -89,6 +117,7 @@ func CreateGenAIRouteConfigs(pathPrefix string) []RouteConfig { } return &gemini.GeminiGenerationRequest{} }, + GetRequestModel: genAIModelGetter, RequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*schemas.BifrostRequest, error) { if geminiReq, ok := req.(*gemini.GeminiGenerationRequest); ok { if geminiReq.IsCountTokens { @@ -790,6 +819,12 @@ func createGenAIRerankRouteConfig(pathPrefix string) RouteConfig { GetRequestTypeInstance: func(ctx context.Context) interface{} { return &vertex.VertexRankRequest{} }, + GetRequestModel: func(_ *fasthttp.RequestCtx, req interface{}) (string, error) { + if r, ok := req.(*vertex.VertexRankRequest); ok && r.Model != nil { + return *r.Model, nil + } + return "", nil + }, RequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*schemas.BifrostRequest, error) { if vertexReq, ok := req.(*vertex.VertexRankRequest); ok { return &schemas.BifrostRequest{ @@ -1282,31 +1317,38 @@ func extractGeminiVideoOperationFromPath(ctx *fasthttp.RequestCtx, bifrostCtx *s return errors.New("operation_id must be a non-empty string") } - // check provider from operation id suffix, id:provider, could be any provider + // operation_id encodes the raw model string as a suffix: "id:rawModel" + // rawModel is either "gpt-4o" (provider name or bare model) or "openai/gpt-4o" (provider/model). 
parts := strings.Split(operationIDStr, ":") if len(parts) < 2 || parts[len(parts)-1] == "" { - return errors.New("provider is required in operation_id format 'id:provider'") + return errors.New("raw model is required in operation_id format 'id:rawModel' or 'id:provider/model'") + } + rawModel := parts[len(parts)-1] + + // Parse provider from rawModel: "openai/gpt-4o" → provider="openai"; "gemini" → provider="gemini". + var provider schemas.ModelProvider + rawModelParts := strings.SplitN(rawModel, "/", 2) + if len(rawModelParts) == 2 { + provider = schemas.ModelProvider(rawModelParts[0]) + } else { + provider = schemas.ModelProvider(rawModel) } - provider := parts[len(parts)-1] modelStr, ok := model.(string) if !ok || modelStr == "" { - modelStr = provider + modelStr = rawModel } - // if its gemini, set r.ID in format models/model/operations/operation_id:provider - // else set r.ID in format operation_id:provider - switch r := req.(type) { case *schemas.BifrostVideoRetrieveRequest: - r.Provider = schemas.ModelProvider(provider) + r.Provider = provider if r.Provider == schemas.OpenAI || r.Provider == schemas.Azure { // set a context flag to have video download request after video retrieve request when incoming request is coming from genai integration bifrostCtx.SetValue(schemas.BifrostContextKeyVideoOutputRequested, true) } // Gemini provider expects an operation resource path (without /v1beta prefix). 
- if provider == string(schemas.Gemini) { + if provider == schemas.Gemini { r.ID = "models/" + modelStr + "/operations/" + operationIDStr } else { r.ID = operationIDStr diff --git a/transports/bifrost-http/integrations/openai.go b/transports/bifrost-http/integrations/openai.go index 24c987f658..262119f925 100644 --- a/transports/bifrost-http/integrations/openai.go +++ b/transports/bifrost-http/integrations/openai.go @@ -135,6 +135,10 @@ func hydrateOpenAIRequestFromLargePayloadMetadata(ctx *fasthttp.RequestCtx, bifr if r.Model == "" { r.Model = metadata.Model } + case *openai.OpenAIVideoGenerationRequest: + if r.Model == "" { + r.Model = metadata.Model + } } } @@ -300,14 +304,43 @@ func AzureEndpointPreHook(handlerStore lib.HandlerStore) func(ctx *fasthttp.Requ } } +// openAIModelGetter extracts the model field from any OpenAI integration request type. +// It is called after body parsing and PreCallback, so req is fully populated. +func openAIModelGetter(_ *fasthttp.RequestCtx, req interface{}) (string, error) { + switch r := req.(type) { + case *openai.OpenAIChatRequest: + return r.Model, nil + case *openai.OpenAITextCompletionRequest: + return r.Model, nil + case *openai.OpenAIEmbeddingRequest: + return r.Model, nil + case *openai.OpenAIResponsesRequest: + return r.Model, nil + case *openai.OpenAISpeechRequest: + return r.Model, nil + case *openai.OpenAITranscriptionRequest: + return r.Model, nil + case *openai.OpenAIImageGenerationRequest: + return r.Model, nil + case *openai.OpenAIImageEditRequest: + return r.Model, nil + case *openai.OpenAIImageVariationRequest: + return r.Model, nil + case *openai.OpenAIVideoGenerationRequest: + return r.Model, nil + } + return "", nil +} + // CreateOpenAIRouteConfigs creates route configurations for OpenAI endpoints. 
func CreateOpenAIRouteConfigs(pathPrefix string, handlerStore lib.HandlerStore) []RouteConfig { var routes []RouteConfig routes = append(routes, RouteConfig{ - Type: RouteConfigTypeOpenAI, - Path: pathPrefix + "/openai/deployments/{deploymentPath:*}", - Method: "POST", + Type: RouteConfigTypeOpenAI, + Path: pathPrefix + "/openai/deployments/{deploymentPath:*}", + Method: "POST", + GetRequestModel: openAIModelGetter, GetHTTPRequestType: func(ctx *fasthttp.RequestCtx) schemas.RequestType { deploymentPathVal, ok := ctx.UserValue("deploymentPath").(string) if !ok { @@ -543,6 +576,7 @@ func CreateOpenAIRouteConfigs(pathPrefix string, handlerStore lib.HandlerStore) GetRequestTypeInstance: func(ctx context.Context) interface{} { return &openai.OpenAIChatRequest{} }, + GetRequestModel: openAIModelGetter, RequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*schemas.BifrostRequest, error) { if openaiReq, ok := req.(*openai.OpenAIChatRequest); ok { br := &schemas.BifrostRequest{ @@ -639,6 +673,7 @@ func CreateOpenAIRouteConfigs(pathPrefix string, handlerStore lib.HandlerStore) GetRequestTypeInstance: func(ctx context.Context) interface{} { return &openai.OpenAITextCompletionRequest{} }, + GetRequestModel: openAIModelGetter, RequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*schemas.BifrostRequest, error) { if openaiReq, ok := req.(*openai.OpenAITextCompletionRequest); ok { return &schemas.BifrostRequest{ @@ -690,6 +725,7 @@ func CreateOpenAIRouteConfigs(pathPrefix string, handlerStore lib.HandlerStore) GetRequestTypeInstance: func(ctx context.Context) interface{} { return &openai.OpenAIResponsesRequest{} }, + GetRequestModel: openAIModelGetter, RequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*schemas.BifrostRequest, error) { if openaiReq, ok := req.(*openai.OpenAIResponsesRequest); ok { return &schemas.BifrostRequest{ @@ -772,6 +808,7 @@ func CreateOpenAIRouteConfigs(pathPrefix string, handlerStore lib.HandlerStore) 
GetRequestTypeInstance: func(ctx context.Context) interface{} { return &openai.OpenAIResponsesRequest{} }, + GetRequestModel: openAIModelGetter, RequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*schemas.BifrostRequest, error) { if openaiReq, ok := req.(*openai.OpenAIResponsesRequest); ok { return &schemas.BifrostRequest{ @@ -810,6 +847,7 @@ func CreateOpenAIRouteConfigs(pathPrefix string, handlerStore lib.HandlerStore) GetRequestTypeInstance: func(ctx context.Context) interface{} { return &openai.OpenAIEmbeddingRequest{} }, + GetRequestModel: openAIModelGetter, RequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*schemas.BifrostRequest, error) { if embeddingReq, ok := req.(*openai.OpenAIEmbeddingRequest); ok { return &schemas.BifrostRequest{ @@ -848,6 +886,7 @@ func CreateOpenAIRouteConfigs(pathPrefix string, handlerStore lib.HandlerStore) GetRequestTypeInstance: func(ctx context.Context) interface{} { return &openai.OpenAISpeechRequest{} }, + GetRequestModel: openAIModelGetter, RequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*schemas.BifrostRequest, error) { if speechReq, ok := req.(*openai.OpenAISpeechRequest); ok { return &schemas.BifrostRequest{ @@ -891,7 +930,8 @@ func CreateOpenAIRouteConfigs(pathPrefix string, handlerStore lib.HandlerStore) GetRequestTypeInstance: func(ctx context.Context) interface{} { return &openai.OpenAITranscriptionRequest{} }, - RequestParser: parseTranscriptionMultipartRequest, // Handle multipart form parsing + GetRequestModel: openAIModelGetter, + RequestParser: parseTranscriptionMultipartRequest, // Handle multipart form parsing RequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*schemas.BifrostRequest, error) { if transcriptionReq, ok := req.(*openai.OpenAITranscriptionRequest); ok { return &schemas.BifrostRequest{ @@ -946,6 +986,7 @@ func CreateOpenAIRouteConfigs(pathPrefix string, handlerStore lib.HandlerStore) GetRequestTypeInstance: func(ctx 
context.Context) interface{} { return &openai.OpenAIImageGenerationRequest{} }, + GetRequestModel: openAIModelGetter, RequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*schemas.BifrostRequest, error) { if imageGenReq, ok := req.(*openai.OpenAIImageGenerationRequest); ok { return &schemas.BifrostRequest{ @@ -996,7 +1037,8 @@ func CreateOpenAIRouteConfigs(pathPrefix string, handlerStore lib.HandlerStore) GetRequestTypeInstance: func(ctx context.Context) interface{} { return &openai.OpenAIImageEditRequest{} }, - RequestParser: parseOpenAIImageEditMultipartRequest, // Handle multipart form parsing + GetRequestModel: openAIModelGetter, + RequestParser: parseOpenAIImageEditMultipartRequest, // Handle multipart form parsing RequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*schemas.BifrostRequest, error) { if imageEditReq, ok := req.(*openai.OpenAIImageEditRequest); ok { return &schemas.BifrostRequest{ @@ -1046,7 +1088,8 @@ func CreateOpenAIRouteConfigs(pathPrefix string, handlerStore lib.HandlerStore) GetRequestTypeInstance: func(ctx context.Context) interface{} { return &openai.OpenAIImageVariationRequest{} }, - RequestParser: parseOpenAIImageVariationMultipartRequest, + GetRequestModel: openAIModelGetter, + RequestParser: parseOpenAIImageVariationMultipartRequest, RequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*schemas.BifrostRequest, error) { if imageVariationReq, ok := req.(*openai.OpenAIImageVariationRequest); ok { return &schemas.BifrostRequest{ @@ -1098,7 +1141,8 @@ func CreateOpenAIRouteConfigs(pathPrefix string, handlerStore lib.HandlerStore) GetRequestTypeInstance: func(ctx context.Context) interface{} { return &openai.OpenAIVideoGenerationRequest{} }, - RequestParser: parseOpenAIVideoGenerationMultipartRequest, + GetRequestModel: openAIModelGetter, + RequestParser: parseOpenAIVideoGenerationMultipartRequest, RequestConverter: func(ctx *schemas.BifrostContext, req interface{}) 
(*schemas.BifrostRequest, error) { if videoGenerationReq, ok := req.(*openai.OpenAIVideoGenerationRequest); ok { return &schemas.BifrostRequest{ @@ -1114,6 +1158,7 @@ func CreateOpenAIRouteConfigs(pathPrefix string, handlerStore lib.HandlerStore) return err }, PreCallback: func(ctx *fasthttp.RequestCtx, bifrostCtx *schemas.BifrostContext, req interface{}) error { + hydrateOpenAIRequestFromLargePayloadMetadata(ctx, bifrostCtx, req) if isAzureSDKRequest(ctx) { bifrostCtx.SetValue(schemas.BifrostContextKeyIsAzureUserAgent, true) } diff --git a/transports/bifrost-http/integrations/router.go b/transports/bifrost-http/integrations/router.go index 24d07026e9..5eb9276b2f 100644 --- a/transports/bifrost-http/integrations/router.go +++ b/transports/bifrost-http/integrations/router.go @@ -50,11 +50,12 @@ package integrations import ( "bytes" "context" + "errors" "fmt" "io" - "errors" "mime" "mime/multipart" + "slices" "strconv" "strings" @@ -354,6 +355,10 @@ type PostRequestCallback func(ctx *fasthttp.RequestCtx, req interface{}, resp interface{}) // returns a schemas.RequestType indicating the HTTP request type derived from the context. type HTTPRequestTypeGetter func(ctx *fasthttp.RequestCtx, req interface{}) schemas.RequestType +// RequestModelGetter is a function type that accepts the fasthttp context and the parsed request, and +// returns the model string derived from them (possibly "provider/model"), or an error. +type RequestModelGetter func(ctx *fasthttp.RequestCtx, req interface{}) (string, error) + // ShortCircuit is a function that determines if the request should be short-circuited.
type ShortCircuit func(ctx *fasthttp.RequestCtx, bifrostCtx *schemas.BifrostContext, req interface{}) (bool, error) @@ -397,6 +402,14 @@ const ( RouteConfigTypeCohere RouteConfigType = "cohere" ) +var RouteConfigTypeToProvider = map[RouteConfigType]schemas.ModelProvider{ + RouteConfigTypeOpenAI: schemas.OpenAI, + RouteConfigTypeAnthropic: schemas.Anthropic, + RouteConfigTypeGenAI: schemas.Gemini, + RouteConfigTypeBedrock: schemas.Bedrock, + RouteConfigTypeCohere: schemas.Cohere, +} + // RouteConfig defines the configuration for a single route in an integration. // It specifies the path, method, and handlers for request/response conversion. type RouteConfig struct { @@ -404,6 +417,7 @@ type RouteConfig struct { Path string // HTTP path pattern (e.g., "/openai/v1/chat/completions") Method string // HTTP method (POST, GET, PUT, DELETE) GetHTTPRequestType HTTPRequestTypeGetter // Function to get the HTTP request type from the context (SHOULD NOT BE NIL) + GetRequestModel RequestModelGetter // Function to get the model from the parsed request, used for provider resolution (MAY BE NIL; skipped when nil) GetRequestTypeInstance func(ctx context.Context) interface{} // Factory function to create request instance (SHOULD NOT BE NIL) RequestParser RequestParser // Optional: custom request parsing (e.g., multipart/form-data) RequestConverter RequestConverter // Function to convert request to BifrostRequest (for inference requests) @@ -616,15 +630,11 @@ func (g *GenericRouter) createHandler(config RouteConfig) fasthttp.RequestHandle var rawBody []byte // Execute the request through Bifrost - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, g.handlerStore.ShouldAllowDirectKeys(), g.handlerStore.GetHeaderMatcher(), g.handlerStore.GetMCPHeaderCombinedAllowlist()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, g.handlerStore) // Set integration type to context bifrostCtx.SetValue(schemas.BifrostContextKeyIntegrationType, string(config.Type)) - // Set available providers to context - availableProviders :=
g.handlerStore.GetAvailableProviders() - bifrostCtx.SetValue(schemas.BifrostContextKeyAvailableProviders, availableProviders) - // Async retrieve: check x-bf-async-id header early (before body parsing) if asyncID := string(ctx.Request.Header.Peek(schemas.AsyncHeaderGetID)); asyncID != "" { defer cancel() @@ -725,6 +735,44 @@ func (g *GenericRouter) createHandler(config RouteConfig) fasthttp.RequestHandle } } + // Set available providers to context + if config.GetRequestModel != nil { + model, err := config.GetRequestModel(ctx, req) + if err != nil { + cancel() + g.sendError(ctx, bifrostCtx, config.ErrorConverter, newBifrostError(err, "failed to get model from context")) + return + } + extractedProvider, extractedModel := schemas.ParseModelString(model, "") + if extractedProvider == "" { + availableProviders := g.handlerStore.GetAvailableProviders(extractedModel) + availableProvidersStrs := make([]string, len(availableProviders)) + for i, p := range availableProviders { + availableProvidersStrs[i] = string(p) + } + bifrostCtx.AppendRoutingEngineLog(schemas.RoutingEngineModelCatalog, fmt.Sprintf( + "No provider specified for model %s, found %d options in model catalog: [%s]", + extractedModel, len(availableProviders), strings.Join(availableProvidersStrs, ", "), + )) + if len(availableProviders) > 0 { + if slices.Contains(availableProviders, RouteConfigTypeToProvider[config.Type]) { + availableProviders = []schemas.ModelProvider{RouteConfigTypeToProvider[config.Type]} + bifrostCtx.AppendRoutingEngineLog(schemas.RoutingEngineModelCatalog, fmt.Sprintf( + "Integration route default provider %s is found in the available providers list, selecting it", + RouteConfigTypeToProvider[config.Type], + )) + } else { + bifrostCtx.AppendRoutingEngineLog(schemas.RoutingEngineModelCatalog, fmt.Sprintf( + "Integration route default provider %s is not found in the available providers list, selecting first: %s", + RouteConfigTypeToProvider[config.Type], availableProviders[0], + )) + } + 
bifrostCtx.SetValue(schemas.BifrostContextKeyAvailableProviders, availableProviders) + } + schemas.AppendToContextList(bifrostCtx, schemas.BifrostContextKeyRoutingEnginesUsed, schemas.RoutingEngineModelCatalog) + } + } + // Handle batch requests if BatchRequestConverter is set // GenAI has two cases: (1) Dedicated batch routes (list/retrieve) have only BatchRequestConverter — always use batch path. // (2) The models path has both BatchRequestConverter and RequestConverter — use batch path only for batch create. @@ -795,10 +843,12 @@ func (g *GenericRouter) createHandler(config RouteConfig) fasthttp.RequestHandle // Convert the integration-specific request to Bifrost format (inference requests) bifrostReq, err := config.RequestConverter(bifrostCtx, req) if err != nil { + defer cancel() g.sendError(ctx, bifrostCtx, config.ErrorConverter, newBifrostError(err, "failed to convert request to Bifrost format")) return } if bifrostReq == nil { + defer cancel() g.sendError(ctx, bifrostCtx, config.ErrorConverter, newBifrostError(nil, "invalid request")) return } @@ -808,6 +858,7 @@ func (g *GenericRouter) createHandler(config RouteConfig) fasthttp.RequestHandle // Extract and parse fallbacks from the request if present if err := g.extractAndParseFallbacks(req, bifrostReq); err != nil { + defer cancel() g.sendError(ctx, bifrostCtx, config.ErrorConverter, newBifrostError(err, "failed to parse fallbacks: "+err.Error())) return } @@ -842,9 +893,7 @@ func (g *GenericRouter) handleNonStreamingRequest(ctx *fasthttp.RequestCtx, conf // streaming requests (where we actively detect write errors), but still provides a mechanism // for providers to respect cancellation. 
var response interface{} - var err error - var providerResponseHeaders map[string]string switch { @@ -1550,7 +1599,6 @@ func (g *GenericRouter) handleAsyncRetrieve( } g.handleAsyncJobResponse(ctx, bifrostCtx, config, job) - return } func (g *GenericRouter) handleAsyncJobResponse(ctx *fasthttp.RequestCtx, bifrostCtx *schemas.BifrostContext, config RouteConfig, job *logstore.AsyncJob) { @@ -2683,7 +2731,7 @@ func (g *GenericRouter) handlePassthrough(ctx *fasthttp.RequestCtx) { return true }) - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, g.handlerStore.ShouldAllowDirectKeys(), g.handlerStore.GetHeaderMatcher(), g.handlerStore.GetMCPHeaderCombinedAllowlist()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, g.handlerStore) if directKey := ctx.UserValue(string(schemas.BifrostContextKeyDirectKey)); directKey != nil { if key, ok := directKey.(schemas.Key); ok { bifrostCtx.SetValue(schemas.BifrostContextKeyDirectKey, key) diff --git a/transports/bifrost-http/lib/config.go b/transports/bifrost-http/lib/config.go index 229d528efe..77ed94437f 100644 --- a/transports/bifrost-http/lib/config.go +++ b/transports/bifrost-http/lib/config.go @@ -64,8 +64,8 @@ type HandlerStore interface { ShouldAllowDirectKeys() bool // GetHeaderMatcher returns the precompiled header matcher for header filtering GetHeaderMatcher() *HeaderMatcher - // GetAvailableProviders returns the list of available providers - GetAvailableProviders() []schemas.ModelProvider + // GetAvailableProviders returns the list of available providers for the given model + GetAvailableProviders(model string) []schemas.ModelProvider // GetStreamChunkInterceptor returns the interceptor for streaming chunks. // Returns nil if no plugins are loaded or streaming interception is not needed. 
GetStreamChunkInterceptor() StreamChunkInterceptor @@ -79,6 +79,10 @@ type HandlerStore interface { GetKVStore() *kvstore.Store // GetMCPHeaderCombinedAllowlist returns the combined allowlist for MCP headers GetMCPHeaderCombinedAllowlist() schemas.WhiteList + // ShouldAllowPerRequestStorageOverride returns whether per-request overrides for content storage are permitted + ShouldAllowPerRequestStorageOverride() bool + // ShouldAllowPerRequestRawOverride returns whether per-request overrides for raw request/response visibility are permitted + ShouldAllowPerRequestRawOverride() bool } // Retry backoff constants for validation @@ -2619,8 +2623,8 @@ func loadPlugins(ctx context.Context, config *Config, configData *ConfigData) { Order: plugin.Order, } if plugin.Name == semanticcache.PluginName { - if err := config.AddProviderKeysToSemanticCacheConfig(pluginConfig); err != nil { - logger.Warn("failed to add provider keys to semantic cache config: %v", err) + if err := config.ValidateSemanticCacheConfig(pluginConfig); err != nil { + logger.Warn("failed to validate semantic cache config: %v", err) } } config.PluginConfigs[i] = pluginConfig @@ -2689,16 +2693,6 @@ func mergePlugins(ctx context.Context, config *Config, configData *ConfigData) { } } - // Process semantic cache plugin - for i, plugin := range config.PluginConfigs { - if plugin.Name == semanticcache.PluginName { - if err := config.AddProviderKeysToSemanticCacheConfig(plugin); err != nil { - logger.Warn("failed to add provider keys to semantic cache config: %v", err) - } - config.PluginConfigs[i] = plugin - } - } - // Update store if config.ConfigStore != nil { logger.Debug("updating plugins in store") @@ -2720,11 +2714,6 @@ func mergePlugins(ctx context.Context, config *Config, configData *ConfigData) { Placement: plugin.Placement, Order: plugin.Order, } - if plugin.Name == semanticcache.PluginName { - if err := config.RemoveProviderKeysFromSemanticCacheConfig(pluginConfig); err != nil { - logger.Warn("failed to 
remove provider keys from semantic cache config: %v", err) - } - } if err := config.ConfigStore.UpsertPlugin(ctx, pluginConfig); err != nil { logger.Warn("failed to update plugin: %v", err) } @@ -3255,6 +3244,16 @@ func (c *Config) ShouldAllowDirectKeys() bool { return c.ClientConfig.AllowDirectKeys } +// ShouldAllowPerRequestStorageOverride returns whether per-request content storage overrides are permitted. +func (c *Config) ShouldAllowPerRequestStorageOverride() bool { + return c.ClientConfig.AllowPerRequestContentStorageOverride +} + +// ShouldAllowPerRequestRawOverride returns whether per-request raw request/response overrides are permitted. +func (c *Config) ShouldAllowPerRequestRawOverride() bool { + return c.ClientConfig.AllowPerRequestRawOverride +} + // GetHeaderMatcher returns the precompiled header matcher for header filtering. // Lock-free via atomic pointer; safe for concurrent reads from hot paths. func (c *Config) GetHeaderMatcher() *HeaderMatcher { @@ -3375,6 +3374,19 @@ func (c *Config) GetPerUserOAuthMCPClientsForVirtualKey(ctx context.Context, vir return result } +// GetProvidersForModel returns the list of providers for a given model, sorted +// deterministically so callers picking providers[0] always get the same result. +func (c *Config) GetProvidersForModel(model string) []schemas.ModelProvider { + if c.ModelCatalog == nil { + return []schemas.ModelProvider{} + } + providers := c.ModelCatalog.GetProvidersForModel(model) + slices.SortFunc(providers, func(a, b schemas.ModelProvider) int { + return strings.Compare(string(a), string(b)) + }) + return providers +} + // GetPluginOrder returns the names of all base plugins in their sorted placement order. // This method is lock-free and safe for concurrent access from hot paths. // Do not modify the returned slice; it is a shared snapshot and must be treated read-only. 
@@ -4760,7 +4772,7 @@ func ValidateCustomProviderUpdate(newConfig, existingConfig configstore.Provider return nil } -func (c *Config) AddProviderKeysToSemanticCacheConfig(config *schemas.PluginConfig) error { +func (c *Config) ValidateSemanticCacheConfig(config *schemas.PluginConfig) error { if config.Name != semanticcache.PluginName { return nil } @@ -4829,13 +4841,11 @@ func (c *Config) AddProviderKeysToSemanticCacheConfig(config *schemas.PluginConf } configMap["embedding_model"] = embeddingModel - keys, err := c.GetProviderConfigRaw(schemas.ModelProvider(provider)) - if err != nil { + // Validate that the provider is configured in the global client (keys are inherited automatically). + if _, err := c.GetProviderConfigRaw(schemas.ModelProvider(provider)); err != nil { return fmt.Errorf("failed to get provider config for %s: %w", provider, err) } - configMap["keys"] = keys.Keys - return nil } @@ -4882,44 +4892,25 @@ func semanticCacheConfigDimension(configMap map[string]interface{}) (int, bool, return 0, false, fmt.Errorf("semantic_cache plugin 'dimension' field must be numeric, got %T", dimensionVal) } } - -func (c *Config) RemoveProviderKeysFromSemanticCacheConfig(config *configstoreTables.TablePlugin) error { - if config.Name != semanticcache.PluginName { - return nil - } - - // Check if config.Config exists - if config.Config == nil { - return fmt.Errorf("semantic_cache plugin config is nil") - } - - // Type assert config.Config to map[string]interface{} - configMap, ok := config.Config.(map[string]interface{}) - if !ok { - return fmt.Errorf("semantic_cache plugin config must be a map, got %T", config.Config) - } - - configMap["keys"] = []schemas.Key{} - - config.Config = configMap - - return nil -} - -func (c *Config) GetAvailableProviders() []schemas.ModelProvider { +func (c *Config) GetAvailableProviders(model string) []schemas.ModelProvider { c.Mu.RLock() defer c.Mu.RUnlock() availableProviders := []schemas.ModelProvider{} - for provider, config := range 
c.Providers { - // Check if the provider has at least one key with a non-empty value. If so, add the provider to the list. - // If the provider allows empty keys, add the provider to the list. - for _, key := range config.Keys { - if key.Value.GetValue() != "" || bifrost.CanProviderKeyValueBeEmpty(provider) { - if key.Enabled != nil && !*key.Enabled { - continue + if c.ModelCatalog != nil { + availableProviders = c.ModelCatalog.GetProvidersForModel(model) + } else { + // Return all providers that have at least one key with a non-empty value. + for provider, config := range c.Providers { + // Check if the provider has at least one key with a non-empty value. If so, add the provider to the list. + // If the provider allows empty keys, add the provider to the list. + for _, key := range config.Keys { + if key.Value.GetValue() != "" || bifrost.CanProviderKeyValueBeEmpty(provider) { + if key.Enabled != nil && !*key.Enabled { + continue + } + availableProviders = append(availableProviders, provider) + break } - availableProviders = append(availableProviders, provider) - break } } } diff --git a/transports/bifrost-http/lib/ctx.go b/transports/bifrost-http/lib/ctx.go index e1fccb1c23..231baa167b 100644 --- a/transports/bifrost-http/lib/ctx.go +++ b/transports/bifrost-http/lib/ctx.go @@ -31,8 +31,20 @@ const ( // FastHTTPUserValueLargeResponseMode marks requests that streamed a large response body. // It is used by transport middleware to avoid re-buffering response bodies for post-hooks. FastHTTPUserValueLargeResponseMode = "__bifrost_large_response_mode" + // FastHTTPUserValueModelCatalogResolution stores model catalog resolution metadata + // set by prepare*Request functions when a provider was auto-resolved. Picked up + // centrally in ConvertToBifrostContext to add the routing engine log. 
+ FastHTTPUserValueModelCatalogResolution = "__bifrost_model_catalog_resolution" ) +// ModelCatalogResolution carries the result of an automatic provider lookup so +// that ConvertToBifrostContext can emit the routing engine log in one place. +type ModelCatalogResolution struct { + Model string + ResolvedProvider schemas.ModelProvider + AllProviders []schemas.ModelProvider +} + // ParseSessionIDFromBaggage extracts the session-id baggage member value. // It supports simple W3C baggage parsing sufficient for log grouping. func ParseSessionIDFromBaggage(header string) string { @@ -120,22 +132,34 @@ func ParseSessionIDFromBaggage(header string) string { // Parameters: // - ctx: The FastHTTP request context containing the original headers -// - allowDirectKeys: Whether to allow direct API key usage from headers +// - store: HandlerStore providing per-request policy flags and header matchers // // Returns: -// - *context.Context: A new cancellable context.Context containing the propagated values +// - *schemas.BifrostContext: A new cancellable context containing the propagated values // - context.CancelFunc: Function to cancel the context (should be called when request completes) // // Example Usage: // // fastCtx := &fasthttp.RequestCtx{...} -// bifrostCtx, cancel := ConvertToBifrostContext(fastCtx, true, nil) +// bifrostCtx, cancel := ConvertToBifrostContext(fastCtx, handlerStore) // defer cancel() // Ensure cleanup // // bifrostCtx now contains propagated header values including Prometheus metrics, // // Maxim tracing data, MCP filters, governance keys, API keys, cache settings, // // session stickiness, and extra headers -func ConvertToBifrostContext(ctx *fasthttp.RequestCtx, allowDirectKeys bool, matcher *HeaderMatcher, mcpHeaderCombinedAllowlist schemas.WhiteList) (*schemas.BifrostContext, context.CancelFunc) { +func ConvertToBifrostContext(ctx *fasthttp.RequestCtx, store HandlerStore) (*schemas.BifrostContext, context.CancelFunc) { + allowDirectKeys := false + 
var matcher *HeaderMatcher + mcpHeaderCombinedAllowlist := schemas.WhiteList{} + allowPerRequestStorageOverride := false + allowPerRequestRawOverride := false + if store != nil { + allowDirectKeys = store.ShouldAllowDirectKeys() + matcher = store.GetHeaderMatcher() + mcpHeaderCombinedAllowlist = store.GetMCPHeaderCombinedAllowlist() + allowPerRequestStorageOverride = store.ShouldAllowPerRequestStorageOverride() + allowPerRequestRawOverride = store.ShouldAllowPerRequestRawOverride() + } // Reuse a shared request-scoped context when available. var bifrostCtx *schemas.BifrostContext var cancel context.CancelFunc @@ -180,6 +204,22 @@ func ConvertToBifrostContext(ctx *fasthttp.RequestCtx, allowDirectKeys bool, mat ctx.VisitUserValuesAll(func(key, value any) { bifrostCtx.SetValue(key, value) }) + + // When a prepare*Request function resolved a provider via the model catalog, + // it stores the resolution info on the fasthttp context. Emit the routing + // engine log and mark the engine as used centrally here. 
+ if res, ok := ctx.UserValue(FastHTTPUserValueModelCatalogResolution).(*ModelCatalogResolution); ok && res != nil { + providerStrs := make([]string, len(res.AllProviders)) + for i, p := range res.AllProviders { + providerStrs[i] = string(p) + } + bifrostCtx.AppendRoutingEngineLog(schemas.RoutingEngineModelCatalog, fmt.Sprintf( + "No provider specified for model %s, found %d options in model catalog: [%s], selecting first: %s", + res.Model, len(res.AllProviders), strings.Join(providerStrs, ", "), res.ResolvedProvider, + )) + schemas.AppendToContextList(bifrostCtx, schemas.BifrostContextKeyRoutingEnginesUsed, schemas.RoutingEngineModelCatalog) + } + // Initialize tags map for collecting maxim tags maximTags := make(map[string]string) // Initialize dimensions map for x-bf-dim-* headers @@ -404,7 +444,7 @@ func ConvertToBifrostContext(ctx *fasthttp.RequestCtx, allowDirectKeys bool, mat return true } // Apply configurable header filter - if !matcher.ShouldAllow(labelName) { + if matcher != nil && !matcher.ShouldAllow(labelName) { return true } // Append header value (allow multiple values for the same header) @@ -415,7 +455,7 @@ func ConvertToBifrostContext(ctx *fasthttp.RequestCtx, allowDirectKeys bool, mat // in the allowlist can be forwarded directly without the x-bf-eh- prefix. // This enables forwarding arbitrary headers like "anthropic-beta" directly. // Only applies when allowlist is non-empty (backward compatible). 
- if matcher.HasAllowlist() { + if matcher != nil && matcher.HasAllowlist() { if matcher.MatchesAllow(keyStr) { // Skip reserved x-bf-* headers (handled separately) if strings.HasPrefix(keyStr, "x-bf-") { @@ -462,6 +502,12 @@ func ConvertToBifrostContext(ctx *fasthttp.RequestCtx, allowDirectKeys bool, mat } return true } + if keyStr == "x-bf-disable-content-logging" { + if b, err := strconv.ParseBool(string(value)); err == nil { + bifrostCtx.SetValue(schemas.BifrostContextKeyDisableContentLogging, b) + } + return true + } // Parent request ID header (for linking MCP tool calls to parent LLM requests) if keyStr == "x-bf-parent-request-id" { if valueStr := strings.TrimSpace(string(value)); valueStr != "" { @@ -476,6 +522,12 @@ func ConvertToBifrostContext(ctx *fasthttp.RequestCtx, allowDirectKeys bool, mat } return true } + if keyStr == "x-bf-disable-content-logging" { + if b, err := strconv.ParseBool(string(value)); err == nil { + bifrostCtx.SetValue(schemas.BifrostContextKeyDisableContentLogging, b) + } + return true + } // Compat header: per-request override of compat plugin settings. // Accepts: "true" (enable all), JSON array of feature names, or ["*"] (enable all). 
@@ -608,6 +660,8 @@ func ConvertToBifrostContext(ctx *fasthttp.RequestCtx, allowDirectKeys bool, mat bifrostCtx.SetValue(schemas.BifrostContextKeyDirectKey, key) } } + bifrostCtx.SetValue(schemas.BifrostContextKeyAllowPerRequestStorageOverride, allowPerRequestStorageOverride) + bifrostCtx.SetValue(schemas.BifrostContextKeyAllowPerRequestRawOverride, allowPerRequestRawOverride) return bifrostCtx, cancel } diff --git a/transports/bifrost-http/lib/ctx_test.go b/transports/bifrost-http/lib/ctx_test.go index 3f7ba72c44..924196249c 100644 --- a/transports/bifrost-http/lib/ctx_test.go +++ b/transports/bifrost-http/lib/ctx_test.go @@ -5,11 +5,32 @@ import ( "testing" configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables" + "github.com/maximhq/bifrost/framework/kvstore" + "github.com/maximhq/bifrost/framework/logstore" "github.com/maximhq/bifrost/core/schemas" "github.com/valyala/fasthttp" ) +// testHandlerStore is a minimal HandlerStore for ctx tests. +type testHandlerStore struct { + allowDirectKeys bool + matcher *HeaderMatcher +} + +func (s testHandlerStore) ShouldAllowDirectKeys() bool { return s.allowDirectKeys } +func (s testHandlerStore) GetHeaderMatcher() *HeaderMatcher { return s.matcher } +func (s testHandlerStore) GetAvailableProviders(_ string) []schemas.ModelProvider { return nil } +func (s testHandlerStore) GetStreamChunkInterceptor() StreamChunkInterceptor { return nil } +func (s testHandlerStore) GetAsyncJobExecutor() *logstore.AsyncJobExecutor { return nil } +func (s testHandlerStore) GetAsyncJobResultTTL() int { return 0 } +func (s testHandlerStore) GetKVStore() *kvstore.Store { return nil } +func (s testHandlerStore) GetMCPHeaderCombinedAllowlist() schemas.WhiteList { + return schemas.WhiteList{} +} +func (s testHandlerStore) ShouldAllowPerRequestStorageOverride() bool { return false } +func (s testHandlerStore) ShouldAllowPerRequestRawOverride() bool { return false } + func TestParseSessionIDFromBaggage(t *testing.T) { tests := 
[]struct { name string @@ -39,7 +60,7 @@ func TestConvertToBifrostContext_ReusesSharedContext(t *testing.T) { base.SetValue(schemas.BifrostContextKeyRequestID, "req-shared") ctx.SetUserValue(FastHTTPUserValueBifrostContext, base) - converted, cancel := ConvertToBifrostContext(ctx, false, nil, schemas.WhiteList{}) + converted, cancel := ConvertToBifrostContext(ctx, testHandlerStore{}) defer cancel() if converted == nil { @@ -59,13 +80,13 @@ func TestConvertToBifrostContext_ReusesSharedContext(t *testing.T) { func TestConvertToBifrostContext_SecondCallReturnsSameSharedContext(t *testing.T) { ctx := &fasthttp.RequestCtx{} - first, cancelFirst := ConvertToBifrostContext(ctx, false, nil, schemas.WhiteList{}) + first, cancelFirst := ConvertToBifrostContext(ctx, testHandlerStore{}) defer cancelFirst() if first == nil { t.Fatal("expected first context to be non-nil") } - second, cancelSecond := ConvertToBifrostContext(ctx, false, nil, schemas.WhiteList{}) + second, cancelSecond := ConvertToBifrostContext(ctx, testHandlerStore{}) defer cancelSecond() if second == nil { t.Fatal("expected second context to be non-nil") @@ -92,7 +113,7 @@ func TestConvertToBifrostContext_StarAllowlistSecurityHeadersBlocked(t *testing. 
ctx.Request.Header.Set("x-bf-eh-connection", "should-be-blocked") ctx.Request.Header.Set("x-bf-eh-proxy-authorization", "should-be-blocked") - bifrostCtx, cancel := ConvertToBifrostContext(ctx, false, matcher, schemas.WhiteList{}) + bifrostCtx, cancel := ConvertToBifrostContext(ctx, testHandlerStore{matcher: matcher}) defer cancel() extraHeaders, _ := bifrostCtx.Value(schemas.BifrostContextKeyExtraHeaders).(map[string][]string) @@ -126,7 +147,7 @@ func TestConvertToBifrostContext_StarAllowlistDirectForwardingSecurityBlocked(t // Security headers sent directly — should be blocked ctx.Request.Header.Set("proxy-authorization", "should-be-blocked") - bifrostCtx, cancel := ConvertToBifrostContext(ctx, false, matcher, schemas.WhiteList{}) + bifrostCtx, cancel := ConvertToBifrostContext(ctx, testHandlerStore{matcher: matcher}) defer cancel() extraHeaders, _ := bifrostCtx.Value(schemas.BifrostContextKeyExtraHeaders).(map[string][]string) @@ -163,7 +184,7 @@ func TestConvertToBifrostContext_PrefixWildcardDirectForwarding(t *testing.T) { // Header not matching the pattern ctx.Request.Header.Set("openai-version", "should-not-forward") - bifrostCtx, cancel := ConvertToBifrostContext(ctx, false, matcher, schemas.WhiteList{}) + bifrostCtx, cancel := ConvertToBifrostContext(ctx, testHandlerStore{matcher: matcher}) defer cancel() extraHeaders, _ := bifrostCtx.Value(schemas.BifrostContextKeyExtraHeaders).(map[string][]string) @@ -191,7 +212,7 @@ func TestConvertToBifrostContext_WildcardAllowlistFiltering(t *testing.T) { ctx.Request.Header.Set("x-bf-eh-anthropic-version", "2024-01-01") ctx.Request.Header.Set("x-bf-eh-openai-version", "should-be-blocked") - bifrostCtx, cancel := ConvertToBifrostContext(ctx, false, matcher, schemas.WhiteList{}) + bifrostCtx, cancel := ConvertToBifrostContext(ctx, testHandlerStore{matcher: matcher}) defer cancel() extraHeaders, _ := bifrostCtx.Value(schemas.BifrostContextKeyExtraHeaders).(map[string][]string) @@ -219,7 +240,7 @@ func 
TestConvertToBifrostContext_WildcardDenylistBlocking(t *testing.T) { ctx.Request.Header.Set("x-bf-eh-x-internal-secret", "blocked-value") ctx.Request.Header.Set("x-bf-eh-custom-header", "allowed-value") - bifrostCtx, cancel := ConvertToBifrostContext(ctx, false, matcher, schemas.WhiteList{}) + bifrostCtx, cancel := ConvertToBifrostContext(ctx, testHandlerStore{matcher: matcher}) defer cancel() extraHeaders, _ := bifrostCtx.Value(schemas.BifrostContextKeyExtraHeaders).(map[string][]string) @@ -240,7 +261,7 @@ func TestConvertToBifrostContext_NilMatcher(t *testing.T) { ctx := &fasthttp.RequestCtx{} ctx.Request.Header.Set("x-bf-eh-custom-header", "allowed-value") - bifrostCtx, cancel := ConvertToBifrostContext(ctx, false, nil, schemas.WhiteList{}) + bifrostCtx, cancel := ConvertToBifrostContext(ctx, testHandlerStore{}) defer cancel() extraHeaders, _ := bifrostCtx.Value(schemas.BifrostContextKeyExtraHeaders).(map[string][]string) @@ -254,7 +275,7 @@ func TestConvertToBifrostContext_BaggageSessionIDSetsGrouping(t *testing.T) { ctx := &fasthttp.RequestCtx{} ctx.Request.Header.Set("baggage", "foo=bar, session-id=rt-123, baz=qux") - bifrostCtx, cancel := ConvertToBifrostContext(ctx, false, nil, schemas.WhiteList{}) + bifrostCtx, cancel := ConvertToBifrostContext(ctx, testHandlerStore{}) defer cancel() if got, _ := bifrostCtx.Value(schemas.BifrostContextKeyParentRequestID).(string); got != "rt-123" { @@ -266,7 +287,7 @@ func TestConvertToBifrostContext_EmptyBaggageSessionIDIgnored(t *testing.T) { ctx := &fasthttp.RequestCtx{} ctx.Request.Header.Set("baggage", "session-id= ") - bifrostCtx, cancel := ConvertToBifrostContext(ctx, false, nil, schemas.WhiteList{}) + bifrostCtx, cancel := ConvertToBifrostContext(ctx, testHandlerStore{}) defer cancel() if got := bifrostCtx.Value(schemas.BifrostContextKeyParentRequestID); got != nil { diff --git a/transports/bifrost-http/lib/semantic_cache_config_test.go b/transports/bifrost-http/lib/semantic_cache_config_test.go index 
61c5da22c7..2d79bd9526 100644 --- a/transports/bifrost-http/lib/semantic_cache_config_test.go +++ b/transports/bifrost-http/lib/semantic_cache_config_test.go @@ -9,7 +9,7 @@ import ( "github.com/stretchr/testify/require" ) -func TestAddProviderKeysToSemanticCacheConfig_DirectOnlyMode(t *testing.T) { +func TestValidateSemanticCacheConfig_DirectOnlyMode(t *testing.T) { config := &Config{} pluginConfig := &schemas.PluginConfig{ Name: semanticcache.PluginName, @@ -19,7 +19,7 @@ func TestAddProviderKeysToSemanticCacheConfig_DirectOnlyMode(t *testing.T) { }, } - err := config.AddProviderKeysToSemanticCacheConfig(pluginConfig) + err := config.ValidateSemanticCacheConfig(pluginConfig) require.NoError(t, err) configMap, ok := pluginConfig.Config.(map[string]interface{}) @@ -28,7 +28,7 @@ func TestAddProviderKeysToSemanticCacheConfig_DirectOnlyMode(t *testing.T) { require.False(t, hasKeys, "direct-only mode should not inject provider keys") } -func TestAddProviderKeysToSemanticCacheConfig_DirectOnlyModeRemovesStaleProviderBackedFields(t *testing.T) { +func TestValidateSemanticCacheConfig_DirectOnlyModeRemovesStaleProviderBackedFields(t *testing.T) { config := &Config{} pluginConfig := &schemas.PluginConfig{ Name: semanticcache.PluginName, @@ -39,18 +39,16 @@ func TestAddProviderKeysToSemanticCacheConfig_DirectOnlyModeRemovesStaleProvider }, } - err := config.AddProviderKeysToSemanticCacheConfig(pluginConfig) + err := config.ValidateSemanticCacheConfig(pluginConfig) require.NoError(t, err) configMap, ok := pluginConfig.Config.(map[string]interface{}) require.True(t, ok) - _, hasKeys := configMap["keys"] - require.False(t, hasKeys, "direct-only mode should remove stale provider keys") _, hasEmbeddingModel := configMap["embedding_model"] require.False(t, hasEmbeddingModel, "direct-only mode should remove stale embedding_model") } -func TestAddProviderKeysToSemanticCacheConfig_InjectsProviderKeys(t *testing.T) { +func 
TestValidateSemanticCacheConfig_ProviderBackedModeValidationPasses(t *testing.T) { config := &Config{ Providers: map[schemas.ModelProvider]configstore.ProviderConfig{ schemas.OpenAI: { @@ -73,19 +71,17 @@ func TestAddProviderKeysToSemanticCacheConfig_InjectsProviderKeys(t *testing.T) }, } - err := config.AddProviderKeysToSemanticCacheConfig(pluginConfig) + err := config.ValidateSemanticCacheConfig(pluginConfig) require.NoError(t, err) configMap, ok := pluginConfig.Config.(map[string]interface{}) require.True(t, ok) - keys, ok := configMap["keys"].([]schemas.Key) - require.True(t, ok, "provider-backed mode should inject provider keys") - require.Len(t, keys, 1) - require.Equal(t, "openai-key", keys[0].Name) + _, hasKeys := configMap["keys"] + require.False(t, hasKeys, "keys are inherited from global client; they must not be injected into the plugin config") require.Equal(t, "openai", configMap["provider"]) } -func TestAddProviderKeysToSemanticCacheConfig_SemanticModeMissingProvider(t *testing.T) { +func TestValidateSemanticCacheConfig_SemanticModeMissingProvider(t *testing.T) { config := &Config{} pluginConfig := &schemas.PluginConfig{ Name: semanticcache.PluginName, @@ -94,12 +90,12 @@ func TestAddProviderKeysToSemanticCacheConfig_SemanticModeMissingProvider(t *tes }, } - err := config.AddProviderKeysToSemanticCacheConfig(pluginConfig) + err := config.ValidateSemanticCacheConfig(pluginConfig) require.Error(t, err) require.Contains(t, err.Error(), "requires 'provider' for semantic mode") } -func TestAddProviderKeysToSemanticCacheConfig_ProviderBackedModeMissingDimension(t *testing.T) { +func TestValidateSemanticCacheConfig_ProviderBackedModeMissingDimension(t *testing.T) { config := &Config{} pluginConfig := &schemas.PluginConfig{ Name: semanticcache.PluginName, @@ -109,12 +105,12 @@ func TestAddProviderKeysToSemanticCacheConfig_ProviderBackedModeMissingDimension }, } - err := config.AddProviderKeysToSemanticCacheConfig(pluginConfig) + err := 
config.ValidateSemanticCacheConfig(pluginConfig) require.Error(t, err) require.Contains(t, err.Error(), "requires 'dimension' for provider-backed semantic mode") } -func TestAddProviderKeysToSemanticCacheConfig_ProviderBackedModeDimensionOne(t *testing.T) { +func TestValidateSemanticCacheConfig_ProviderBackedModeDimensionOne(t *testing.T) { config := &Config{} pluginConfig := &schemas.PluginConfig{ Name: semanticcache.PluginName, @@ -125,12 +121,12 @@ func TestAddProviderKeysToSemanticCacheConfig_ProviderBackedModeDimensionOne(t * }, } - err := config.AddProviderKeysToSemanticCacheConfig(pluginConfig) + err := config.ValidateSemanticCacheConfig(pluginConfig) require.Error(t, err) require.Contains(t, err.Error(), "requires 'dimension' > 1") } -func TestAddProviderKeysToSemanticCacheConfig_ProviderBackedModeMissingEmbeddingModel(t *testing.T) { +func TestValidateSemanticCacheConfig_ProviderBackedModeMissingEmbeddingModel(t *testing.T) { config := &Config{} pluginConfig := &schemas.PluginConfig{ Name: semanticcache.PluginName, @@ -140,12 +136,12 @@ func TestAddProviderKeysToSemanticCacheConfig_ProviderBackedModeMissingEmbedding }, } - err := config.AddProviderKeysToSemanticCacheConfig(pluginConfig) + err := config.ValidateSemanticCacheConfig(pluginConfig) require.Error(t, err) require.Contains(t, err.Error(), "requires 'embedding_model'") } -func TestAddProviderKeysToSemanticCacheConfig_InvalidDimensionZero(t *testing.T) { +func TestValidateSemanticCacheConfig_InvalidDimensionZero(t *testing.T) { config := &Config{} pluginConfig := &schemas.PluginConfig{ Name: semanticcache.PluginName, @@ -154,12 +150,12 @@ func TestAddProviderKeysToSemanticCacheConfig_InvalidDimensionZero(t *testing.T) }, } - err := config.AddProviderKeysToSemanticCacheConfig(pluginConfig) + err := config.ValidateSemanticCacheConfig(pluginConfig) require.Error(t, err) require.Contains(t, err.Error(), "'dimension' must be >= 1") } -func 
TestAddProviderKeysToSemanticCacheConfig_InvalidDimensionNegative(t *testing.T) { +func TestValidateSemanticCacheConfig_InvalidDimensionNegative(t *testing.T) { config := &Config{} pluginConfig := &schemas.PluginConfig{ Name: semanticcache.PluginName, @@ -168,7 +164,7 @@ func TestAddProviderKeysToSemanticCacheConfig_InvalidDimensionNegative(t *testin }, } - err := config.AddProviderKeysToSemanticCacheConfig(pluginConfig) + err := config.ValidateSemanticCacheConfig(pluginConfig) require.Error(t, err) require.Contains(t, err.Error(), "'dimension' must be >= 1") } diff --git a/transports/bifrost-http/server/server.go b/transports/bifrost-http/server/server.go index ca5a3da674..2bbdeadad5 100644 --- a/transports/bifrost-http/server/server.go +++ b/transports/bifrost-http/server/server.go @@ -969,6 +969,10 @@ func (s *BifrostHTTPServer) ReloadPlugin(ctx context.Context, name string, path if err != nil { return s.updatePluginErrorStatus(name, "loading", err) } + // Wire the embedding executor on the new instance before syncing. 
+ if semanticCachePlugin, ok := plugin.(*semanticcache.Plugin); ok { + semanticCachePlugin.SetEmbeddingRequestExecutor(s.Client.EmbeddingRequest) + } return s.SyncLoadedPlugin(ctx, name, plugin, placement, order) } @@ -1372,7 +1376,6 @@ func (s *BifrostHTTPServer) Bootstrap(ctx context.Context) error { } wg.Wait() } - logger.Info("models added to catalog") s.Config.SetBifrostClient(s.Client) // Initialize routes @@ -1394,6 +1397,11 @@ func (s *BifrostHTTPServer) Bootstrap(ctx context.Context) error { apiMiddlewares = append(apiMiddlewares, s.AuthMiddleware.APIMiddleware()) } } + // Add semantic cache plugin embedding request executor if it exists + semanticCachePlugin, err := lib.FindPluginAs[*semanticcache.Plugin](s.Config, semanticcache.PluginName) + if err == nil && semanticCachePlugin != nil { + semanticCachePlugin.SetEmbeddingRequestExecutor(s.Client.EmbeddingRequest) + } // Register routes err = s.RegisterAPIRoutes(s.Ctx, s, apiMiddlewares...) if err != nil { diff --git a/transports/changelog.md b/transports/changelog.md index e69de29bb2..53550c479e 100644 --- a/transports/changelog.md +++ b/transports/changelog.md @@ -0,0 +1,4 @@ +- fix: routing rule chain no longer halts when a chain_rule resolves to the same provider/model (self-loop); subsequent rules now continue to evaluate correctly +- fix: response extra fields request type corruption for streaming requests on high concurrency +- feat: added support for per-request content logging toggle via `x-bf-disable-content-logging` header +- feat: auto-resolve provider when model string has no provider prefix diff --git a/transports/config.schema.json b/transports/config.schema.json index c453811892..fd69a78878 100644 --- a/transports/config.schema.json +++ b/transports/config.schema.json @@ -68,6 +68,16 @@ "type": "boolean", "description": "Disable logging of sensitive content (inputs, outputs, embeddings, etc.)" }, + "allow_per_request_content_storage_override": { + "type": "boolean", + "description": "Allow 
individual requests to override content storage via the x-bf-disable-content-logging header or context key, and to opt in to raw-byte persistence in logs via x-bf-store-raw-request-response. When false (default), the global disable_content_logging setting is authoritative and per-request storage overrides are ignored. Does not control sending raw bytes back to callers — see allow_per_request_raw_override.", + "default": false + }, + "allow_per_request_raw_override": { + "type": "boolean", + "description": "Allow individual requests to send raw provider request/response bytes back to the caller via the x-bf-send-back-raw-request and x-bf-send-back-raw-response headers. When false (default), the provider-level send_back_raw_request/response settings are authoritative and per-request overrides are ignored. Does not affect raw-byte persistence in logs — see allow_per_request_content_storage_override.", + "default": false + }, "disable_db_pings_in_health": { "type": "boolean", "description": "Disable DB pings in health check", @@ -1259,6 +1269,11 @@ "type": "boolean", "description": "Disable logging of request and response content" }, + "allow_per_request_content_storage_override": { + "type": "boolean", + "description": "Allow individual requests to override content storage via the x-bf-disable-content-logging header or context key, and to opt in to raw-byte persistence in logs via x-bf-store-raw-request-response. When false (default), per-request storage overrides are ignored. Does not control sending raw bytes back to callers — see allow_per_request_raw_override.", + "default": false + }, "logging_headers": { "type": "array", "items": { @@ -1378,13 +1393,6 @@ "huggingface" ] }, - "keys": { - "type": "array", - "description": "API keys for the embedding provider. 
These are injected at runtime for config-driven setups and are not needed for direct caching with dimension: 1.", - "items": { - "type": "string" - } - }, "embedding_model": { "type": "string", "description": "Model to use for generating embeddings in provider-backed semantic caching. Required when provider is set and not allowed in direct-only mode." diff --git a/ui/app/workspace/config/views/loggingView.tsx b/ui/app/workspace/config/views/loggingView.tsx index bb78edc733..ef96fa6a02 100644 --- a/ui/app/workspace/config/views/loggingView.tsx +++ b/ui/app/workspace/config/views/loggingView.tsx @@ -31,6 +31,8 @@ export default function LoggingView() { return ( localConfig.enable_logging !== config.enable_logging || localConfig.disable_content_logging !== config.disable_content_logging || + localConfig.allow_per_request_content_storage_override !== config.allow_per_request_content_storage_override || + localConfig.allow_per_request_raw_override !== config.allow_per_request_raw_override || localConfig.log_retention_days !== config.log_retention_days || localConfig.hide_deleted_virtual_keys_in_filters !== config.hide_deleted_virtual_keys_in_filters || JSON.stringify(localConfig.logging_headers || []) !== JSON.stringify(config.logging_headers || []) @@ -130,6 +132,50 @@ export default function LoggingView() { )} + {/* Allow Per-Request Content Storage Override - Only show when logging is enabled */} + {localConfig.enable_logging && bifrostConfig?.is_logs_connected && ( +
+
+ +

+ When enabled, individual requests can override the global content logging setting using the{" "} + x-bf-disable-content-logging header or context key, and can opt-in to persisting raw provider bytes in logs using the{" "} + x-bf-store-raw-request-response header. Raw data is not stored by default — each request must explicitly opt in. Does not control sending raw bytes back to callers — see Allow Per-Request Raw Override. +

+
+ handleConfigChange("allow_per_request_content_storage_override", checked)} + /> +
+ )} + + {/* Allow Per-Request Raw Override */} +
+
+ +

+ When enabled, individual requests can send raw provider request/response bytes back to the caller using the{" "} + x-bf-send-back-raw-request and{" "} + x-bf-send-back-raw-response headers. Does not affect log storage — raw-byte persistence in logs is controlled by Allow Per-Request Content Storage Override. +

+
+ handleConfigChange("allow_per_request_raw_override", checked)} + /> +
+ {/* Log Retention Days */} {localConfig.enable_logging && bifrostConfig?.is_logs_connected && (
diff --git a/ui/app/workspace/config/views/pluginsForm.tsx b/ui/app/workspace/config/views/pluginsForm.tsx index 33477ec8a2..dcd459de4c 100644 --- a/ui/app/workspace/config/views/pluginsForm.tsx +++ b/ui/app/workspace/config/views/pluginsForm.tsx @@ -48,9 +48,6 @@ const normalizeCacheConfigForSave = (config: EditorCacheConfig) => { if (config.updated_at !== undefined) { normalized.updated_at = config.updated_at; } - if (config.keys !== undefined) { - normalized.keys = config.keys; - } const provider = config.provider?.trim(); const embeddingModel = config.embedding_model?.trim(); @@ -375,7 +372,7 @@ export default function PluginsForm({ isVectorStoreEnabled }: PluginsFormProps)

API keys for the embedding provider will be inherited from the main provider configuration. The semantic cache will use - the configured provider's keys automatically. Updates in keys will be reflected on Bifrost restart. + the configured provider's keys automatically.

diff --git a/ui/app/workspace/logs/sheets/logDetailView.tsx b/ui/app/workspace/logs/sheets/logDetailView.tsx index 6a9fe82e57..8341e3edb5 100644 --- a/ui/app/workspace/logs/sheets/logDetailView.tsx +++ b/ui/app/workspace/logs/sheets/logDetailView.tsx @@ -1456,7 +1456,8 @@ export function LogDetailView({ ) : null} - {!isPassthrough && ( + + {!isPassthrough && !log.list_models_output && ( Tools {log.params?.tools?.length ? ( @@ -1979,6 +1980,28 @@ export function LogDetailView({ )} + {log.list_models_output && ( + JSON.stringify(log.list_models_output, null, 2)} + > + + + )} + {(log.error_details?.error.message || log.error_details?.error.error != null) && (
@@ -2246,32 +2269,10 @@ export function LogDetailView({ )} - {log.list_models_output && ( - JSON.stringify(log.list_models_output, null, 2)} - > - - - )} {!rawRequest && !rawResponse && !passthroughRequestBody && - !passthroughResponseBody && - !log.list_models_output && ( + !passthroughResponseBody && (
No raw JSON available.
diff --git a/ui/lib/types/config.ts b/ui/lib/types/config.ts index 9b89847af6..25d645fda4 100644 --- a/ui/lib/types/config.ts +++ b/ui/lib/types/config.ts @@ -472,6 +472,8 @@ export interface CoreConfig { prometheus_labels: string[]; enable_logging: boolean; disable_content_logging: boolean; + allow_per_request_content_storage_override: boolean; + allow_per_request_raw_override: boolean; disable_db_pings_in_health: boolean; log_retention_days: number; enforce_auth_on_inference: boolean; @@ -500,6 +502,8 @@ export const DefaultCoreConfig: CoreConfig = { prometheus_labels: [], enable_logging: true, disable_content_logging: false, + allow_per_request_content_storage_override: false, + allow_per_request_raw_override: false, disable_db_pings_in_health: false, log_retention_days: 365, enforce_auth_on_inference: false, @@ -536,13 +540,11 @@ interface BaseCacheConfig { export interface DirectCacheConfig extends BaseCacheConfig { dimension: 1; provider?: undefined; - keys?: ModelProviderKey[]; embedding_model?: undefined; } export interface ProviderBackedCacheConfig extends BaseCacheConfig { provider: ModelProviderName; - keys?: ModelProviderKey[]; embedding_model: string; dimension: number; } @@ -551,7 +553,6 @@ export type CacheConfig = DirectCacheConfig | ProviderBackedCacheConfig; export interface EditorCacheConfig extends BaseCacheConfig { provider?: ModelProviderName; - keys?: ModelProviderKey[]; embedding_model?: string; dimension?: number; } diff --git a/ui/lib/types/schemas.ts b/ui/lib/types/schemas.ts index 60ef6c2572..7a1fbc0834 100644 --- a/ui/lib/types/schemas.ts +++ b/ui/lib/types/schemas.ts @@ -4,1020 +4,1115 @@ import { z } from "zod"; // Global error map - turns Zod's default messages into readable, human-friendly ones. // Individual schemas can still override by passing their own message. 
z.config({ - customError: (issue) => { - if (issue.code === "invalid_type") { - // Field is missing / undefined - if (issue.input === undefined || issue.input === null) { - return "This field is required"; - } - const expected = issue.expected; - const received = typeof issue.input; - if (expected === "number") return "Must be a valid number"; - if (expected === "string") return "Must be a valid text value"; - if (expected === "boolean") return "Must be true or false"; - return `Expected ${expected}, received ${received}`; - } - if (issue.code === "too_small") { - if (issue.origin === "string" && issue.minimum === 1) { - return "This field is required"; - } - if (issue.origin === "number") { - return `Must be at least ${issue.minimum}`; - } - if (issue.origin === "array" && issue.minimum === 1) { - return "At least one item is required"; - } - } - if (issue.code === "too_big") { - if (issue.origin === "number") { - return `Must be at most ${issue.maximum}`; - } - if (issue.origin === "string") { - return `Must be at most ${issue.maximum} characters`; - } - } - if (issue.code === "invalid_format") { - if (issue.format === "url") return "Must be a valid URL"; - if (issue.format === "email") return "Must be a valid email"; - } - return undefined; // fall back to Zod default - }, + customError: (issue) => { + if (issue.code === "invalid_type") { + // Field is missing / undefined + if (issue.input === undefined || issue.input === null) { + return "This field is required"; + } + const expected = issue.expected; + const received = typeof issue.input; + if (expected === "number") return "Must be a valid number"; + if (expected === "string") return "Must be a valid text value"; + if (expected === "boolean") return "Must be true or false"; + return `Expected ${expected}, received ${received}`; + } + if (issue.code === "too_small") { + if (issue.origin === "string" && issue.minimum === 1) { + return "This field is required"; + } + if (issue.origin === "number") { + return 
`Must be at least ${issue.minimum}`; + } + if (issue.origin === "array" && issue.minimum === 1) { + return "At least one item is required"; + } + } + if (issue.code === "too_big") { + if (issue.origin === "number") { + return `Must be at most ${issue.maximum}`; + } + if (issue.origin === "string") { + return `Must be at most ${issue.maximum} characters`; + } + } + if (issue.code === "invalid_format") { + if (issue.format === "url") return "Must be a valid URL"; + if (issue.format === "email") return "Must be a valid email"; + } + return undefined; // fall back to Zod default + }, }); // Base Zod schemas matching the TypeScript types // Known provider schema -export const knownProviderSchema = z.enum(KnownProvidersNames as unknown as [string, ...string[]]); +export const knownProviderSchema = z.enum( + KnownProvidersNames as unknown as [string, ...string[]], +); // Custom provider name schema (branded type simulation) -export const customProviderNameSchema = z.string().min(1, "Custom provider name is required"); +export const customProviderNameSchema = z + .string() + .min(1, "Custom provider name is required"); // Model provider name schema (union of known and custom providers) -export const modelProviderNameSchema = z.union([knownProviderSchema, customProviderNameSchema]); +export const modelProviderNameSchema = z.union([ + knownProviderSchema, + customProviderNameSchema, +]); // EnvVar schema - matches the Go EnvVar type from schemas/env.go export const _envVarBase = z.object({ - value: z.string().optional(), - env_var: z.string().optional(), - from_env: z.boolean().optional(), + value: z.string().optional(), + env_var: z.string().optional(), + from_env: z.boolean().optional(), }); // Extending the base schema export const envVarSchema = Object.assign(_envVarBase, { - required: (message: string) => _envVarBase.refine((v) => !!v?.value?.trim() || !!v?.env_var?.trim(), message), + required: (message: string) => + _envVarBase.refine( + (v) => !!v?.value?.trim() || 
!!v?.env_var?.trim(), + message, + ), }); // Helper to check if an envVar field has a value or env reference -function isEnvVarSet(v: { value?: string; env_var?: string } | undefined): boolean { - if (!v) return false; - return !!v.value?.trim() || !!v.env_var?.trim(); +function isEnvVarSet( + v: { value?: string; env_var?: string } | undefined, +): boolean { + if (!v) return false; + return !!v.value?.trim() || !!v.env_var?.trim(); } // Azure key config schema export const azureKeyConfigSchema = z - .object({ - _auth_type: z.enum(["api_key", "entra_id", "default_credential"]).optional(), - endpoint: envVarSchema.optional(), - api_version: envVarSchema.optional(), - client_id: envVarSchema.optional(), - client_secret: envVarSchema.optional(), - tenant_id: envVarSchema.optional(), - scopes: z.array(z.string()).optional(), - }) - .refine((data) => isEnvVarSet(data.endpoint), { - message: "Endpoint is required", - path: ["endpoint"], - }) - .refine( - (data) => { - // When using Entra ID, all three fields are required - if (data._auth_type === "entra_id") { - return isEnvVarSet(data.client_id) && isEnvVarSet(data.client_secret) && isEnvVarSet(data.tenant_id); - } - // Otherwise, if any Entra ID field is set, all three must be set - const hasClientId = isEnvVarSet(data.client_id); - const hasClientSecret = isEnvVarSet(data.client_secret); - const hasTenantId = isEnvVarSet(data.tenant_id); - const anyEntraField = hasClientId || hasClientSecret || hasTenantId; - if (!anyEntraField) return true; - return hasClientId && hasClientSecret && hasTenantId; - }, - { - message: "Client ID, Client Secret, and Tenant ID are all required for Entra ID authentication", - path: ["client_id"], - }, - ); + .object({ + _auth_type: z + .enum(["api_key", "entra_id", "default_credential"]) + .optional(), + endpoint: envVarSchema.optional(), + api_version: envVarSchema.optional(), + client_id: envVarSchema.optional(), + client_secret: envVarSchema.optional(), + tenant_id: 
envVarSchema.optional(), + scopes: z.array(z.string()).optional(), + }) + .refine((data) => isEnvVarSet(data.endpoint), { + message: "Endpoint is required", + path: ["endpoint"], + }) + .refine( + (data) => { + // When using Entra ID, all three fields are required + if (data._auth_type === "entra_id") { + return ( + isEnvVarSet(data.client_id) && + isEnvVarSet(data.client_secret) && + isEnvVarSet(data.tenant_id) + ); + } + // Otherwise, if any Entra ID field is set, all three must be set + const hasClientId = isEnvVarSet(data.client_id); + const hasClientSecret = isEnvVarSet(data.client_secret); + const hasTenantId = isEnvVarSet(data.tenant_id); + const anyEntraField = hasClientId || hasClientSecret || hasTenantId; + if (!anyEntraField) return true; + return hasClientId && hasClientSecret && hasTenantId; + }, + { + message: + "Client ID, Client Secret, and Tenant ID are all required for Entra ID authentication", + path: ["client_id"], + }, + ); // Vertex key config schema export const vertexKeyConfigSchema = z - .object({ - _auth_type: z.enum(["service_account", "service_account_json", "api_key"]).optional(), - project_id: envVarSchema.optional(), - project_number: envVarSchema.optional(), - region: envVarSchema.optional(), - auth_credentials: envVarSchema.optional(), - }) - .refine((data) => isEnvVarSet(data.project_id), { - message: "Project ID is required", - path: ["project_id"], - }) - .refine((data) => isEnvVarSet(data.region), { - message: "Region is required", - path: ["region"], - }) - .refine( - (data) => { - // When using service_account_json auth, auth_credentials is required - if (data._auth_type === "service_account_json") { - return isEnvVarSet(data.auth_credentials); - } - return true; - }, - { - message: "Auth Credentials is required for service account JSON authentication", - path: ["auth_credentials"], - }, - ); + .object({ + _auth_type: z + .enum(["service_account", "service_account_json", "api_key"]) + .optional(), + project_id: 
envVarSchema.optional(), + project_number: envVarSchema.optional(), + region: envVarSchema.optional(), + auth_credentials: envVarSchema.optional(), + }) + .refine((data) => isEnvVarSet(data.project_id), { + message: "Project ID is required", + path: ["project_id"], + }) + .refine((data) => isEnvVarSet(data.region), { + message: "Region is required", + path: ["region"], + }) + .refine( + (data) => { + // When using service_account_json auth, auth_credentials is required + if (data._auth_type === "service_account_json") { + return isEnvVarSet(data.auth_credentials); + } + return true; + }, + { + message: + "Auth Credentials is required for service account JSON authentication", + path: ["auth_credentials"], + }, + ); // S3 bucket configuration for Bedrock batch operations export const s3BucketConfigSchema = z.object({ - bucket_name: z.string().min(1, "Bucket name is required"), - prefix: z.string().optional(), - is_default: z.boolean().optional(), + bucket_name: z.string().min(1, "Bucket name is required"), + prefix: z.string().optional(), + is_default: z.boolean().optional(), }); export const batchS3ConfigSchema = z.object({ - buckets: z.array(s3BucketConfigSchema).optional(), + buckets: z.array(s3BucketConfigSchema).optional(), }); // Bedrock key config schema export const bedrockKeyConfigSchema = z - .object({ - _auth_type: z.enum(["iam_role", "explicit", "api_key"]).optional(), - access_key: envVarSchema.optional(), - secret_key: envVarSchema.optional(), - session_token: envVarSchema.optional(), - region: envVarSchema.optional(), - role_arn: envVarSchema.optional(), - external_id: envVarSchema.optional(), - session_name: envVarSchema.optional(), - arn: envVarSchema.optional(), - batch_s3_config: batchS3ConfigSchema.optional(), - }) - .refine( - (data) => { - // Region is required for Bedrock - return isEnvVarSet(data.region); - }, - { - message: "Region is required", - path: ["region"], - }, - ) - .refine( - (data) => { - // When using explicit credentials, both 
access_key and secret_key are required - if (data._auth_type === "explicit") { - return isEnvVarSet(data.access_key) && isEnvVarSet(data.secret_key); - } - // Otherwise, if either is set both must be set - const hasAccessKey = isEnvVarSet(data.access_key); - const hasSecretKey = isEnvVarSet(data.secret_key); - if (!hasAccessKey && !hasSecretKey) return true; - return hasAccessKey && hasSecretKey; - }, - { - message: "Both Access Key and Secret Key are required for explicit credentials", - path: ["access_key"], - }, - ); + .object({ + _auth_type: z.enum(["iam_role", "explicit", "api_key"]).optional(), + access_key: envVarSchema.optional(), + secret_key: envVarSchema.optional(), + session_token: envVarSchema.optional(), + region: envVarSchema.optional(), + role_arn: envVarSchema.optional(), + external_id: envVarSchema.optional(), + session_name: envVarSchema.optional(), + arn: envVarSchema.optional(), + batch_s3_config: batchS3ConfigSchema.optional(), + }) + .refine( + (data) => { + // Region is required for Bedrock + return isEnvVarSet(data.region); + }, + { + message: "Region is required", + path: ["region"], + }, + ) + .refine( + (data) => { + // When using explicit credentials, both access_key and secret_key are required + if (data._auth_type === "explicit") { + return isEnvVarSet(data.access_key) && isEnvVarSet(data.secret_key); + } + // Otherwise, if either is set both must be set + const hasAccessKey = isEnvVarSet(data.access_key); + const hasSecretKey = isEnvVarSet(data.secret_key); + if (!hasAccessKey && !hasSecretKey) return true; + return hasAccessKey && hasSecretKey; + }, + { + message: + "Both Access Key and Secret Key are required for explicit credentials", + path: ["access_key"], + }, + ); // VLLM key config schema export const vllmKeyConfigSchema = z - .object({ - url: envVarSchema.optional(), - model_name: z.string().trim().min(1, "Model name is required"), - }) - .refine((data) => isEnvVarSet(data.url), { - message: "Server URL is required", - path: 
["url"], - }); + .object({ + url: envVarSchema.optional(), + model_name: z.string().trim().min(1, "Model name is required"), + }) + .refine((data) => isEnvVarSet(data.url), { + message: "Server URL is required", + path: ["url"], + }); export const replicateKeyConfigSchema = z.object({ - use_deployments_endpoint: z.boolean(), + use_deployments_endpoint: z.boolean(), }); // Ollama key config schema export const ollamaKeyConfigSchema = z - .object({ - url: envVarSchema.optional(), - }) - .refine((data) => isEnvVarSet(data.url), { - message: "Server URL is required", - path: ["url"], - }); + .object({ + url: envVarSchema.optional(), + }) + .refine((data) => isEnvVarSet(data.url), { + message: "Server URL is required", + path: ["url"], + }); // SGL key config schema export const sglKeyConfigSchema = z - .object({ - url: envVarSchema.optional(), - }) - .refine((data) => isEnvVarSet(data.url), { - message: "Server URL is required", - path: ["url"], - }); + .object({ + url: envVarSchema.optional(), + }) + .refine((data) => isEnvVarSet(data.url), { + message: "Server URL is required", + path: ["url"], + }); // Model provider key schema export const modelProviderKeySchema = z - .object({ - id: z.string().min(1, "Id is required"), - name: z.string().min(1, "Name is required"), - value: envVarSchema.optional(), - models: z.array(z.string()).optional().default(["*"]), - blacklisted_models: z.array(z.string()).default([]).optional(), - weight: z - .union([z.number(), z.string()]) - .transform((val, ctx) => { - if (typeof val === "number") return val; - if (val.trim() === "") return 1.0; - // Use Number() rather than parseFloat() so that strings like "0.5abc" - // are rejected outright instead of silently parsing to 0.5. 
- const num = Number(val); - if (!Number.isFinite(num)) { - ctx.addIssue({ - code: "custom", - message: "Weight must be a valid number between 0 and 1", - }); - return z.NEVER; - } - return num; - }) - .pipe(z.number().min(0, "Weight must be equal to or greater than 0").max(1, "Weight must be equal to or less than 1")), - aliases: z.record(z.string(), z.string()).optional(), - azure_key_config: azureKeyConfigSchema.optional(), - vertex_key_config: vertexKeyConfigSchema.optional(), - bedrock_key_config: bedrockKeyConfigSchema.optional(), - vllm_key_config: vllmKeyConfigSchema.optional(), - replicate_key_config: replicateKeyConfigSchema.optional(), - ollama_key_config: ollamaKeyConfigSchema.optional(), - sgl_key_config: sglKeyConfigSchema.optional(), - use_for_batch_api: z.boolean().optional(), - enabled: z.boolean().optional(), - }) - .refine( - (data) => { - // Providers with dedicated config that never need a top-level API key - if (data.vllm_key_config || data.replicate_key_config || data.ollama_key_config || data.sgl_key_config) { - return true; - } - // Azure requires API key only when using api_key auth - if (data.azure_key_config) { - if (data.azure_key_config._auth_type === "api_key") { - return isEnvVarSet(data.value); - } - return true; - } - // Bedrock only requires API key when using api_key auth - if (data.bedrock_key_config) { - if (data.bedrock_key_config._auth_type === "api_key") { - return isEnvVarSet(data.value); - } - return true; - } - // Vertex requires API key only when using api_key auth - if (data.vertex_key_config) { - if (data.vertex_key_config._auth_type === "api_key") { - return isEnvVarSet(data.value); - } - return true; - } - // Otherwise, value is required - return isEnvVarSet(data.value); - }, - { - message: "API Key is required", - path: ["value"], - }, - ); + .object({ + id: z.string().min(1, "Id is required"), + name: z.string().min(1, "Name is required"), + value: envVarSchema.optional(), + models: 
z.array(z.string()).optional().default(["*"]), + blacklisted_models: z.array(z.string()).default([]).optional(), + weight: z + .union([z.number(), z.string()]) + .transform((val, ctx) => { + if (typeof val === "number") return val; + if (val.trim() === "") return 1.0; + // Use Number() rather than parseFloat() so that strings like "0.5abc" + // are rejected outright instead of silently parsing to 0.5. + const num = Number(val); + if (!Number.isFinite(num)) { + ctx.addIssue({ + code: "custom", + message: "Weight must be a valid number between 0 and 1", + }); + return z.NEVER; + } + return num; + }) + .pipe( + z + .number() + .min(0, "Weight must be equal to or greater than 0") + .max(1, "Weight must be equal to or less than 1"), + ), + aliases: z.record(z.string(), z.string()).optional(), + azure_key_config: azureKeyConfigSchema.optional(), + vertex_key_config: vertexKeyConfigSchema.optional(), + bedrock_key_config: bedrockKeyConfigSchema.optional(), + vllm_key_config: vllmKeyConfigSchema.optional(), + replicate_key_config: replicateKeyConfigSchema.optional(), + ollama_key_config: ollamaKeyConfigSchema.optional(), + sgl_key_config: sglKeyConfigSchema.optional(), + use_for_batch_api: z.boolean().optional(), + enabled: z.boolean().optional(), + }) + .refine( + (data) => { + // Providers with dedicated config that never need a top-level API key + if ( + data.vllm_key_config || + data.replicate_key_config || + data.ollama_key_config || + data.sgl_key_config + ) { + return true; + } + // Azure requires API key only when using api_key auth + if (data.azure_key_config) { + if (data.azure_key_config._auth_type === "api_key") { + return isEnvVarSet(data.value); + } + return true; + } + // Bedrock only requires API key when using api_key auth + if (data.bedrock_key_config) { + if (data.bedrock_key_config._auth_type === "api_key") { + return isEnvVarSet(data.value); + } + return true; + } + // Vertex requires API key only when using api_key auth + if (data.vertex_key_config) { 
+ if (data.vertex_key_config._auth_type === "api_key") { + return isEnvVarSet(data.value); + } + return true; + } + // Otherwise, value is required + return isEnvVarSet(data.value); + }, + { + message: "API Key is required", + path: ["value"], + }, + ); // Network config schema export const networkConfigSchema = z - .object({ - base_url: z.union([z.string().url("Must be a valid URL"), z.string().length(0)]).optional(), - extra_headers: z.record(z.string(), z.string()).optional(), - default_request_timeout_in_seconds: z - .number() - .min(1, "Timeout must be greater than 0 seconds") - .max(3600, "Timeout must be less than 3600 seconds"), - max_retries: z.number().min(0, "Max retries must be greater than 0").max(10, "Max retries must be less than 10"), - retry_backoff_initial: z.number().min(100), - retry_backoff_max: z.number().min(100), - insecure_skip_verify: z.boolean().optional(), - ca_cert_pem: envVarSchema.optional(), - stream_idle_timeout_in_seconds: z - .number() - .int("Stream idle timeout must be a whole number of seconds") - .min(5, "Stream idle timeout must be at least 5 seconds") - .max(3600, "Stream idle timeout must be at most 3600 seconds i.e. 
60 minutes") - .optional(), - max_conns_per_host: z - .number() - .int("Max connections must be a whole number") - .min(1, "Max connections must be at least 1") - .max(10000, "Max connections must be at most 10000") - .optional(), - enforce_http2: z.boolean().optional(), - }) - .refine((d) => d.retry_backoff_initial <= d.retry_backoff_max, { - message: "retry_backoff_initial must be <= retry_backoff_max", - path: ["retry_backoff_initial"], - }); + .object({ + base_url: z + .union([z.string().url("Must be a valid URL"), z.string().length(0)]) + .optional(), + extra_headers: z.record(z.string(), z.string()).optional(), + default_request_timeout_in_seconds: z + .number() + .min(1, "Timeout must be greater than 0 seconds") + .max(3600, "Timeout must be less than 3600 seconds"), + max_retries: z + .number() + .min(0, "Max retries must be greater than 0") + .max(10, "Max retries must be less than 10"), + retry_backoff_initial: z.number().min(100), + retry_backoff_max: z.number().min(100), + insecure_skip_verify: z.boolean().optional(), + ca_cert_pem: envVarSchema.optional(), + stream_idle_timeout_in_seconds: z + .number() + .int("Stream idle timeout must be a whole number of seconds") + .min(5, "Stream idle timeout must be at least 5 seconds") + .max( + 3600, + "Stream idle timeout must be at most 3600 seconds i.e. 
60 minutes", + ) + .optional(), + max_conns_per_host: z + .number() + .int("Max connections must be a whole number") + .min(1, "Max connections must be at least 1") + .max(10000, "Max connections must be at most 10000") + .optional(), + enforce_http2: z.boolean().optional(), + }) + .refine((d) => d.retry_backoff_initial <= d.retry_backoff_max, { + message: "retry_backoff_initial must be <= retry_backoff_max", + path: ["retry_backoff_initial"], + }); // Network form schema - more lenient for form inputs export const networkFormConfigSchema = z - .object({ - base_url: z - .union([ - z - .string() - .url("Must be a valid URL") - .refine((url) => url.startsWith("https://") || url.startsWith("http://"), { - message: "Must be a valid HTTP or HTTPS URL", - }), - z.string().length(0), - ]) - .optional(), - extra_headers: z.record(z.string(), z.string()).optional(), - default_request_timeout_in_seconds: z.coerce - .number("Timeout must be a number") - .min(1, "Timeout must be greater than 0 seconds") - .max(172800, "Timeout must be less than 172800 seconds i.e. 
48 hours"), - max_retries: z.coerce - .number("Max retries must be a number") - .min(0, "Max retries must be greater than 0") - .max(10, "Max retries must be less than 10"), - retry_backoff_initial: z.coerce - .number("Retry backoff initial must be a number") - .min(100, "Retry backoff initial must be at least 100ms") - .max(1000000, "Retry backoff initial must be at most 1000000ms"), - retry_backoff_max: z.coerce - .number("Retry backoff max must be a number") - .min(100, "Retry backoff max must be at least 100ms") - .max(1000000, "Retry backoff max must be at most 1000000ms"), - insecure_skip_verify: z.boolean().optional(), - ca_cert_pem: envVarSchema.optional(), - stream_idle_timeout_in_seconds: z.coerce - .number("Stream idle timeout must be a number") - .int("Stream idle timeout must be a whole number of seconds") - .min(5, "Stream idle timeout must be at least 5 seconds") - .max(3600, "Stream idle timeout must be at most 3600 seconds i.e. 60 minutes") - .optional(), - max_conns_per_host: z.coerce - .number("Max connections must be a number") - .int("Max connections must be a whole number") - .min(1, "Max connections must be at least 1") - .max(10000, "Max connections must be at most 10000") - .optional(), - enforce_http2: z.boolean().optional(), - }) - .refine((d) => d.retry_backoff_initial <= d.retry_backoff_max, { - message: "Initial backoff must be less than or equal to max backoff", - path: ["retry_backoff_initial"], - }); + .object({ + base_url: z + .union([ + z + .string() + .url("Must be a valid URL") + .refine( + (url) => url.startsWith("https://") || url.startsWith("http://"), + { + message: "Must be a valid HTTP or HTTPS URL", + }, + ), + z.string().length(0), + ]) + .optional(), + extra_headers: z.record(z.string(), z.string()).optional(), + default_request_timeout_in_seconds: z.coerce + .number("Timeout must be a number") + .min(1, "Timeout must be greater than 0 seconds") + .max(172800, "Timeout must be less than 172800 seconds i.e. 
48 hours"), + max_retries: z.coerce + .number("Max retries must be a number") + .min(0, "Max retries must be greater than 0") + .max(10, "Max retries must be less than 10"), + retry_backoff_initial: z.coerce + .number("Retry backoff initial must be a number") + .min(100, "Retry backoff initial must be at least 100ms") + .max(1000000, "Retry backoff initial must be at most 1000000ms"), + retry_backoff_max: z.coerce + .number("Retry backoff max must be a number") + .min(100, "Retry backoff max must be at least 100ms") + .max(1000000, "Retry backoff max must be at most 1000000ms"), + insecure_skip_verify: z.boolean().optional(), + ca_cert_pem: envVarSchema.optional(), + stream_idle_timeout_in_seconds: z.coerce + .number("Stream idle timeout must be a number") + .int("Stream idle timeout must be a whole number of seconds") + .min(5, "Stream idle timeout must be at least 5 seconds") + .max( + 3600, + "Stream idle timeout must be at most 3600 seconds i.e. 60 minutes", + ) + .optional(), + max_conns_per_host: z.coerce + .number("Max connections must be a number") + .int("Max connections must be a whole number") + .min(1, "Max connections must be at least 1") + .max(10000, "Max connections must be at most 10000") + .optional(), + enforce_http2: z.boolean().optional(), + }) + .refine((d) => d.retry_backoff_initial <= d.retry_backoff_max, { + message: "Initial backoff must be less than or equal to max backoff", + path: ["retry_backoff_initial"], + }); // Concurrency and buffer size schema export const concurrencyAndBufferSizeSchema = z.object({ - concurrency: z.number().min(1, "Concurrency must be greater than 0").max(100, "Concurrency must be less than or equal to 100"), - buffer_size: z.number().min(1, "Buffer size must be greater than 0").max(1000, "Buffer size must be less than or equal to 1000"), + concurrency: z + .number() + .min(1, "Concurrency must be greater than 0") + .max(100, "Concurrency must be less than or equal to 100"), + buffer_size: z + .number() + 
.min(1, "Buffer size must be greater than 0") + .max(1000, "Buffer size must be less than or equal to 1000"), }); // Proxy type schema -export const proxyTypeSchema = z.enum(["none", "http", "socks5", "environment"]); +export const proxyTypeSchema = z.enum([ + "none", + "http", + "socks5", + "environment", +]); // Proxy config schema export const proxyConfigSchema = z - .object({ - type: proxyTypeSchema, - url: envVarSchema.optional(), - username: envVarSchema.optional(), - password: envVarSchema.optional(), - ca_cert_pem: envVarSchema.optional(), - }) - .refine( - (data) => - !(data.type === "http" || data.type === "socks5") || - data.url?.from_env === true || - (data.url?.value && data.url.value.trim().length > 0), - { - message: "Proxy URL is required when using HTTP or SOCKS5 proxy", - path: ["url"], - }, - ) - .refine( - (data) => { - if ((data.type === "http" || data.type === "socks5") && data.url?.value?.trim()) { - if (data.url.from_env || data.url.env_var?.startsWith("env.")) { - return true; - } - try { - new URL(data.url.value); - return true; - } catch { - return false; - } - } - return true; - }, - { message: "Must be a valid URL (e.g., http://proxy.example.com:8080)", path: ["url"] }, - ); + .object({ + type: proxyTypeSchema, + url: envVarSchema.optional(), + username: envVarSchema.optional(), + password: envVarSchema.optional(), + ca_cert_pem: envVarSchema.optional(), + }) + .refine( + (data) => + !(data.type === "http" || data.type === "socks5") || + data.url?.from_env === true || + (data.url?.value && data.url.value.trim().length > 0), + { + message: "Proxy URL is required when using HTTP or SOCKS5 proxy", + path: ["url"], + }, + ) + .refine( + (data) => { + if ( + (data.type === "http" || data.type === "socks5") && + data.url?.value?.trim() + ) { + if (data.url.from_env || data.url.env_var?.startsWith("env.")) { + return true; + } + try { + new URL(data.url.value); + return true; + } catch { + return false; + } + } + return true; + }, + { + 
message: "Must be a valid URL (e.g., http://proxy.example.com:8080)", + path: ["url"], + }, + ); // Proxy form schema - more lenient for form inputs with conditional validation export const proxyFormConfigSchema = z - .object({ - type: proxyTypeSchema, - url: envVarSchema.optional(), - username: envVarSchema.optional(), - password: envVarSchema.optional(), - ca_cert_pem: envVarSchema.optional(), - }) - .refine( - (data) => { - if (data.type === "none") { - return true; - } - // URL is required when proxy type is http or socks5 - if (data.type === "http" || data.type === "socks5") { - // Env-backed URLs may have empty resolved value before env resolution. - if (data.url?.from_env || data.url?.env_var?.startsWith("env.")) return true; - // Literal URLs must be non-empty. - if (!data.url?.value || data.url.value.trim().length === 0) return false; - } - return true; - }, - { - message: "Proxy URL is required when using HTTP or SOCKS5 proxy", - path: ["url"], - }, - ) - .refine( - (data) => { - // URL must be valid format when provided and proxy type requires it - if ((data.type === "http" || data.type === "socks5") && data.url?.value && data.url.value.trim().length > 0) { - if (data.url.from_env || data.url.env_var?.startsWith("env.")) { - return true; - } - try { - new URL(data.url.value); - return true; - } catch { - return false; - } - } - return true; - }, - { - message: "Must be a valid URL (e.g., http://proxy.example.com:8080)", - path: ["url"], - }, - ); + .object({ + type: proxyTypeSchema, + url: envVarSchema.optional(), + username: envVarSchema.optional(), + password: envVarSchema.optional(), + ca_cert_pem: envVarSchema.optional(), + }) + .refine( + (data) => { + if (data.type === "none") { + return true; + } + // URL is required when proxy type is http or socks5 + if (data.type === "http" || data.type === "socks5") { + // Env-backed URLs may have empty resolved value before env resolution. 
+ if (data.url?.from_env || data.url?.env_var?.startsWith("env.")) + return true; + // Literal URLs must be non-empty. + if (!data.url?.value || data.url.value.trim().length === 0) + return false; + } + return true; + }, + { + message: "Proxy URL is required when using HTTP or SOCKS5 proxy", + path: ["url"], + }, + ) + .refine( + (data) => { + // URL must be valid format when provided and proxy type requires it + if ( + (data.type === "http" || data.type === "socks5") && + data.url?.value && + data.url.value.trim().length > 0 + ) { + if (data.url.from_env || data.url.env_var?.startsWith("env.")) { + return true; + } + try { + new URL(data.url.value); + return true; + } catch { + return false; + } + } + return true; + }, + { + message: "Must be a valid URL (e.g., http://proxy.example.com:8080)", + path: ["url"], + }, + ); // OpenAI Config tab export const openaiConfigFormSchema = z.object({ - disable_store: z.boolean(), + disable_store: z.boolean(), }); export type OpenAIConfigFormSchema = z.infer; // Allowed requests schema export const allowedRequestsSchema = z.object({ - text_completion: z.boolean(), - text_completion_stream: z.boolean(), - chat_completion: z.boolean(), - chat_completion_stream: z.boolean(), - responses: z.boolean(), - responses_stream: z.boolean(), - embedding: z.boolean(), - speech: z.boolean(), - speech_stream: z.boolean(), - transcription: z.boolean(), - transcription_stream: z.boolean(), - image_generation: z.boolean(), - image_generation_stream: z.boolean(), - image_edit: z.boolean(), - image_edit_stream: z.boolean(), - image_variation: z.boolean(), - ocr: z.boolean().optional(), - ocr_stream: z.boolean().optional(), - rerank: z.boolean(), - video_generation: z.boolean(), - video_retrieve: z.boolean(), - video_download: z.boolean(), - video_delete: z.boolean(), - video_list: z.boolean(), - video_remix: z.boolean(), - count_tokens: z.boolean(), - list_models: z.boolean(), - websocket_responses: z.boolean(), - realtime: z.boolean(), + 
text_completion: z.boolean(), + text_completion_stream: z.boolean(), + chat_completion: z.boolean(), + chat_completion_stream: z.boolean(), + responses: z.boolean(), + responses_stream: z.boolean(), + embedding: z.boolean(), + speech: z.boolean(), + speech_stream: z.boolean(), + transcription: z.boolean(), + transcription_stream: z.boolean(), + image_generation: z.boolean(), + image_generation_stream: z.boolean(), + image_edit: z.boolean(), + image_edit_stream: z.boolean(), + image_variation: z.boolean(), + ocr: z.boolean().optional(), + ocr_stream: z.boolean().optional(), + rerank: z.boolean(), + video_generation: z.boolean(), + video_retrieve: z.boolean(), + video_download: z.boolean(), + video_delete: z.boolean(), + video_list: z.boolean(), + video_remix: z.boolean(), + count_tokens: z.boolean(), + list_models: z.boolean(), + websocket_responses: z.boolean(), + realtime: z.boolean(), }); // Custom provider config schema export const customProviderConfigSchema = z - .object({ - base_provider_type: knownProviderSchema, - is_key_less: z.boolean().optional(), - allowed_requests: allowedRequestsSchema.optional(), - request_path_overrides: z.record(z.string(), z.string().optional()).optional(), - }) - .refine( - (data) => { - if (data.base_provider_type === "bedrock") { - return !data.is_key_less; - } - return true; - }, - { - message: "Is keyless is not allowed for Bedrock", - path: ["is_key_less"], - }, - ); + .object({ + base_provider_type: knownProviderSchema, + is_key_less: z.boolean().optional(), + allowed_requests: allowedRequestsSchema.optional(), + request_path_overrides: z + .record(z.string(), z.string().optional()) + .optional(), + }) + .refine( + (data) => { + if (data.base_provider_type === "bedrock") { + return !data.is_key_less; + } + return true; + }, + { + message: "Is keyless is not allowed for Bedrock", + path: ["is_key_less"], + }, + ); // Form-specific custom provider config schema export const formCustomProviderConfigSchema = z - .object({ - 
base_provider_type: z.string().min(1, "Base provider type is required"), - is_key_less: z.boolean().optional(), - allowed_requests: allowedRequestsSchema.optional(), - request_path_overrides: z.record(z.string(), z.string().optional()).optional(), - }) - .refine( - (data) => { - if (data.base_provider_type === "bedrock") { - return !data.is_key_less; - } - return true; - }, - { - message: "Is keyless is not allowed for Bedrock", - path: ["is_key_less"], - }, - ); + .object({ + base_provider_type: z.string().min(1, "Base provider type is required"), + is_key_less: z.boolean().optional(), + allowed_requests: allowedRequestsSchema.optional(), + request_path_overrides: z + .record(z.string(), z.string().optional()) + .optional(), + }) + .refine( + (data) => { + if (data.base_provider_type === "bedrock") { + return !data.is_key_less; + } + return true; + }, + { + message: "Is keyless is not allowed for Bedrock", + path: ["is_key_less"], + }, + ); // Full model provider config schema export const modelProviderConfigSchema = z.object({ - keys: z.array(modelProviderKeySchema).min(1, "At least one key is required"), - network_config: networkConfigSchema.optional(), - concurrency_and_buffer_size: concurrencyAndBufferSizeSchema.optional(), - proxy_config: proxyConfigSchema.optional(), - send_back_raw_request: z.boolean().optional(), - send_back_raw_response: z.boolean().optional(), - store_raw_request_response: z.boolean().optional(), - custom_provider_config: customProviderConfigSchema.optional(), + keys: z.array(modelProviderKeySchema).min(1, "At least one key is required"), + network_config: networkConfigSchema.optional(), + concurrency_and_buffer_size: concurrencyAndBufferSizeSchema.optional(), + proxy_config: proxyConfigSchema.optional(), + send_back_raw_request: z.boolean().optional(), + send_back_raw_response: z.boolean().optional(), + store_raw_request_response: z.boolean().optional(), + custom_provider_config: customProviderConfigSchema.optional(), }); // Model 
provider schema export const modelProviderSchema = modelProviderConfigSchema.extend({ - name: modelProviderNameSchema, + name: modelProviderNameSchema, }); // Form-specific model provider config schema export const formModelProviderConfigSchema = z.object({ - keys: z.array(modelProviderKeySchema).min(1, "At least one key is required"), - network_config: networkConfigSchema.optional(), - concurrency_and_buffer_size: concurrencyAndBufferSizeSchema.optional(), - proxy_config: proxyConfigSchema.optional(), - send_back_raw_request: z.boolean().optional(), - send_back_raw_response: z.boolean().optional(), - store_raw_request_response: z.boolean().optional(), - custom_provider_config: formCustomProviderConfigSchema.optional(), + keys: z.array(modelProviderKeySchema).min(1, "At least one key is required"), + network_config: networkConfigSchema.optional(), + concurrency_and_buffer_size: concurrencyAndBufferSizeSchema.optional(), + proxy_config: proxyConfigSchema.optional(), + send_back_raw_request: z.boolean().optional(), + send_back_raw_response: z.boolean().optional(), + store_raw_request_response: z.boolean().optional(), + custom_provider_config: formCustomProviderConfigSchema.optional(), }); // Flexible model provider schema for form data - allows any string for name export const formModelProviderSchema = formModelProviderConfigSchema.extend({ - name: z.string().min(1, "Provider name is required"), + name: z.string().min(1, "Provider name is required"), }); // Add provider request schema export const addProviderRequestSchema = z.object({ - provider: modelProviderNameSchema, - keys: z.array(modelProviderKeySchema).min(1, "At least one key is required"), - network_config: networkConfigSchema.optional(), - concurrency_and_buffer_size: concurrencyAndBufferSizeSchema.optional(), - proxy_config: proxyConfigSchema.optional(), - send_back_raw_request: z.boolean().optional(), - send_back_raw_response: z.boolean().optional(), - store_raw_request_response: z.boolean().optional(), 
- custom_provider_config: customProviderConfigSchema.optional(), - openai_config: openaiConfigFormSchema.optional(), + provider: modelProviderNameSchema, + keys: z.array(modelProviderKeySchema).min(1, "At least one key is required"), + network_config: networkConfigSchema.optional(), + concurrency_and_buffer_size: concurrencyAndBufferSizeSchema.optional(), + proxy_config: proxyConfigSchema.optional(), + send_back_raw_request: z.boolean().optional(), + send_back_raw_response: z.boolean().optional(), + store_raw_request_response: z.boolean().optional(), + custom_provider_config: customProviderConfigSchema.optional(), + openai_config: openaiConfigFormSchema.optional(), }); // Update provider request schema export const updateProviderRequestSchema = z.object({ - keys: z.array(modelProviderKeySchema).min(1, "At least one key is required"), - network_config: networkConfigSchema, - concurrency_and_buffer_size: concurrencyAndBufferSizeSchema, - proxy_config: proxyConfigSchema, - send_back_raw_request: z.boolean().optional(), - send_back_raw_response: z.boolean().optional(), - store_raw_request_response: z.boolean().optional(), - custom_provider_config: customProviderConfigSchema.optional(), - openai_config: openaiConfigFormSchema.optional(), + keys: z.array(modelProviderKeySchema).min(1, "At least one key is required"), + network_config: networkConfigSchema, + concurrency_and_buffer_size: concurrencyAndBufferSizeSchema, + proxy_config: proxyConfigSchema, + send_back_raw_request: z.boolean().optional(), + send_back_raw_response: z.boolean().optional(), + store_raw_request_response: z.boolean().optional(), + custom_provider_config: customProviderConfigSchema.optional(), + openai_config: openaiConfigFormSchema.optional(), }); // Cache config schema const baseCacheConfigSchema = z.object({ - ttl_seconds: z.number().int().min(1).default(3600), - threshold: z.number().min(0).max(1).default(0.8), - conversation_history_threshold: z.number().int().min(0).optional(), - 
exclude_system_prompt: z.boolean().optional(), - cache_by_model: z.boolean().default(false), - cache_by_provider: z.boolean().default(false), - created_at: z.string().optional(), - updated_at: z.string().optional(), + ttl_seconds: z.number().int().min(1).default(3600), + threshold: z.number().min(0).max(1).default(0.8), + conversation_history_threshold: z.number().int().min(0).optional(), + exclude_system_prompt: z.boolean().optional(), + cache_by_model: z.boolean().default(false), + cache_by_provider: z.boolean().default(false), + created_at: z.string().optional(), + updated_at: z.string().optional(), }); const directCacheConfigSchema = baseCacheConfigSchema - .extend({ - dimension: z.literal(1), - keys: z.array(modelProviderKeySchema).optional(), - }) - .strict(); + .extend({ + dimension: z.literal(1), + keys: z.array(modelProviderKeySchema).optional(), + }) + .strict(); const providerBackedCacheConfigSchema = baseCacheConfigSchema - .extend({ - provider: modelProviderNameSchema, - keys: z.array(modelProviderKeySchema).optional(), - embedding_model: z.string().min(1, "Embedding model is required"), - dimension: z.number().int().min(2, "Dimension must be greater than 1 for provider-backed semantic cache"), - }) - .strict(); - -export const cacheConfigSchema = z.union([directCacheConfigSchema, providerBackedCacheConfigSchema]); + .extend({ + provider: modelProviderNameSchema, + keys: z.array(modelProviderKeySchema).optional(), + embedding_model: z.string().min(1, "Embedding model is required"), + dimension: z + .number() + .int() + .min( + 2, + "Dimension must be greater than 1 for provider-backed semantic cache", + ), + }) + .strict(); + +export const cacheConfigSchema = z.union([ + directCacheConfigSchema, + providerBackedCacheConfigSchema, +]); // Core config schema export const coreConfigSchema = z.object({ - drop_excess_requests: z.boolean().default(false), - initial_pool_size: z.number().min(1).default(10), - prometheus_labels: 
z.array(z.string()).default([]), - enable_logging: z.boolean().default(true), - disable_content_logging: z.boolean().default(false), - enforce_auth_on_inference: z.boolean().default(false), - allow_direct_keys: z.boolean().default(false), - hide_deleted_virtual_keys_in_filters: z.boolean().default(false), - allowed_origins: z.array(z.string()).default(["*"]), - max_request_body_size_mb: z.number().min(1).default(100), - mcp_agent_depth: z.number().min(1).default(10), - mcp_tool_execution_timeout: z.number().min(1).default(30), - mcp_code_mode_binding_level: z.enum(["server", "tool"]).default("server"), - mcp_disable_auto_tool_inject: z.boolean().default(false), + drop_excess_requests: z.boolean().default(false), + initial_pool_size: z.number().min(1).default(10), + prometheus_labels: z.array(z.string()).default([]), + enable_logging: z.boolean().default(true), + disable_content_logging: z.boolean().default(false), + enforce_auth_on_inference: z.boolean().default(false), + allow_direct_keys: z.boolean().default(false), + hide_deleted_virtual_keys_in_filters: z.boolean().default(false), + allowed_origins: z.array(z.string()).default(["*"]), + max_request_body_size_mb: z.number().min(1).default(100), + mcp_agent_depth: z.number().min(1).default(10), + mcp_tool_execution_timeout: z.number().min(1).default(30), + mcp_code_mode_binding_level: z.enum(["server", "tool"]).default("server"), + mcp_disable_auto_tool_inject: z.boolean().default(false), }); // Bifrost config schema export const bifrostConfigSchema = z.object({ - client_config: coreConfigSchema, - is_db_connected: z.boolean(), - is_cache_connected: z.boolean(), - is_logs_connected: z.boolean(), + client_config: coreConfigSchema, + is_db_connected: z.boolean(), + is_cache_connected: z.boolean(), + is_logs_connected: z.boolean(), }); // Network and proxy form schema - combined for the NetworkFormFragment export const networkAndProxyFormSchema = z.object({ - network_config: networkFormConfigSchema.optional(), - 
proxy_config: proxyFormConfigSchema.optional(), + network_config: networkFormConfigSchema.optional(), + proxy_config: proxyFormConfigSchema.optional(), }); // Proxy-only form schema for the ProxyFormFragment export const proxyOnlyFormSchema = z.object({ - proxy_config: proxyFormConfigSchema.optional(), + proxy_config: proxyFormConfigSchema.optional(), }); // Network-only form schema for the NetworkFormFragment export const networkOnlyFormSchema = z.object({ - network_config: networkFormConfigSchema.optional(), + network_config: networkFormConfigSchema.optional(), }); // Performance form schema for the PerformanceFormFragment (concurrency/buffer only; raw request/response are in Debugging tab) export const performanceFormSchema = z.object({ - concurrency_and_buffer_size: z - .object({ - concurrency: z - .number({ error: "Concurrency must be a number" }) - .min(1, "Concurrency must be greater than 0") - .max(100000, "Concurrency must be less than 100000"), - buffer_size: z - .number({ error: "Buffer size must be a number" }) - .min(1, "Buffer size must be greater than 0") - .max(100000, "Buffer size must be less than 100000"), - }) - .refine((data) => data.concurrency <= data.buffer_size, { - message: "Concurrency must be less than or equal to buffer size", - path: ["concurrency"], - }), + concurrency_and_buffer_size: z + .object({ + concurrency: z + .number({ error: "Concurrency must be a number" }) + .min(1, "Concurrency must be greater than 0") + .max(100000, "Concurrency must be less than 100000"), + buffer_size: z + .number({ error: "Buffer size must be a number" }) + .min(1, "Buffer size must be greater than 0") + .max(100000, "Buffer size must be less than 100000"), + }) + .refine((data) => data.concurrency <= data.buffer_size, { + message: "Concurrency must be less than or equal to buffer size", + path: ["concurrency"], + }), }); // Debugging tab (raw request/response toggles) export const debuggingFormSchema = z.object({ - send_back_raw_request: z.boolean(), 
- send_back_raw_response: z.boolean(), - store_raw_request_response: z.boolean(), + send_back_raw_request: z.boolean(), + send_back_raw_response: z.boolean(), + store_raw_request_response: z.boolean(), }); export type DebuggingFormSchema = z.infer; // Beta Headers tab export const betaHeadersFormSchema = z.object({ - beta_header_overrides: z.record(z.string(), z.boolean()).optional(), + beta_header_overrides: z.record(z.string(), z.boolean()).optional(), }); export type BetaHeadersFormSchema = z.infer; // OTEL Configuration Schema export const otelConfigSchema = z - .object({ - service_name: z.string().optional(), - collector_url: z.string().default(""), - trace_type: z - .enum(["genai_extension", "vercel", "open_inference"], { - message: "Please select a trace type", - }) - .default("genai_extension"), - headers: z.record(z.string(), z.string()).optional(), - protocol: z - .enum(["http", "grpc"], { - message: "Please select a protocol", - }) - .default("http"), - // TLS configuration - tls_ca_cert: z.string().optional(), - insecure: z.boolean().default(true), - // Metrics push configuration - metrics_enabled: z.boolean().default(false), - metrics_endpoint: z.string().optional(), - metrics_push_interval: z.number().int().min(1).max(300).default(15), - }) - .superRefine((data, ctx) => { - const protocol = data.protocol; - const hostPortRegex = /^(?!https?:\/\/)([a-zA-Z0-9.-]+|\[[0-9a-fA-F:]+\]|\d{1,3}(?:\.\d{1,3}){3}):(\d{1,5})$/; - - // Helper to validate URL format - const validateHttpUrl = (url: string, path: string[]) => { - try { - const u = new URL(url); - if (!(u.protocol === "http:" || u.protocol === "https:")) { - ctx.addIssue({ - code: "custom", - path, - message: "Must be a valid HTTP or HTTPS URL", - }); - return false; - } - return true; - } catch { - ctx.addIssue({ - code: "custom", - path, - message: "Must be a valid HTTP or HTTPS URL", - }); - return false; - } - }; - - // Helper to validate host:port format - const validateHostPort = (value: string, 
path: string[], example: string) => { - const match = value.match(hostPortRegex); - if (!match) { - ctx.addIssue({ - code: "custom", - path, - message: `Must be in the format : for gRPC (e.g. ${example})`, - }); - return false; - } - const port = Number(match[2]); - if (!(port >= 1 && port <= 65535)) { - ctx.addIssue({ - code: "custom", - path, - message: "Port must be between 1 and 65535", - }); - return false; - } - return true; - }; - - // Validate collector_url format (emptiness check is at form level, gated by enabled) - const collectorUrl = (data.collector_url || "").trim(); - if (collectorUrl && protocol === "http") { - validateHttpUrl(collectorUrl, ["collector_url"]); - } else if (collectorUrl && protocol === "grpc") { - validateHostPort(collectorUrl, ["collector_url"], "otel-collector:4317"); - } - - // Validate metrics_endpoint when metrics_enabled is true - if (data.metrics_enabled) { - const metricsEndpoint = (data.metrics_endpoint || "").trim(); - if (!metricsEndpoint) { - ctx.addIssue({ - code: "custom", - path: ["metrics_endpoint"], - message: "Metrics endpoint is required when metrics push is enabled", - }); - } else if (protocol === "http") { - validateHttpUrl(metricsEndpoint, ["metrics_endpoint"]); - } else if (protocol === "grpc") { - validateHostPort(metricsEndpoint, ["metrics_endpoint"], "otel-collector:4317"); - } - } - }); + .object({ + service_name: z.string().optional(), + collector_url: z.string().default(""), + trace_type: z + .enum(["genai_extension", "vercel", "open_inference"], { + message: "Please select a trace type", + }) + .default("genai_extension"), + headers: z.record(z.string(), z.string()).optional(), + protocol: z + .enum(["http", "grpc"], { + message: "Please select a protocol", + }) + .default("http"), + // TLS configuration + tls_ca_cert: z.string().optional(), + insecure: z.boolean().default(true), + // Metrics push configuration + metrics_enabled: z.boolean().default(false), + metrics_endpoint: z.string().optional(), + 
metrics_push_interval: z.number().int().min(1).max(300).default(15), + }) + .superRefine((data, ctx) => { + const protocol = data.protocol; + const hostPortRegex = + /^(?!https?:\/\/)([a-zA-Z0-9.-]+|\[[0-9a-fA-F:]+\]|\d{1,3}(?:\.\d{1,3}){3}):(\d{1,5})$/; + + // Helper to validate URL format + const validateHttpUrl = (url: string, path: string[]) => { + try { + const u = new URL(url); + if (!(u.protocol === "http:" || u.protocol === "https:")) { + ctx.addIssue({ + code: "custom", + path, + message: "Must be a valid HTTP or HTTPS URL", + }); + return false; + } + return true; + } catch { + ctx.addIssue({ + code: "custom", + path, + message: "Must be a valid HTTP or HTTPS URL", + }); + return false; + } + }; + + // Helper to validate host:port format + const validateHostPort = ( + value: string, + path: string[], + example: string, + ) => { + const match = value.match(hostPortRegex); + if (!match) { + ctx.addIssue({ + code: "custom", + path, + message: `Must be in the format : for gRPC (e.g. 
${example})`, + }); + return false; + } + const port = Number(match[2]); + if (!(port >= 1 && port <= 65535)) { + ctx.addIssue({ + code: "custom", + path, + message: "Port must be between 1 and 65535", + }); + return false; + } + return true; + }; + + // Validate collector_url format (emptiness check is at form level, gated by enabled) + const collectorUrl = (data.collector_url || "").trim(); + if (collectorUrl && protocol === "http") { + validateHttpUrl(collectorUrl, ["collector_url"]); + } else if (collectorUrl && protocol === "grpc") { + validateHostPort(collectorUrl, ["collector_url"], "otel-collector:4317"); + } + + // Validate metrics_endpoint when metrics_enabled is true + if (data.metrics_enabled) { + const metricsEndpoint = (data.metrics_endpoint || "").trim(); + if (!metricsEndpoint) { + ctx.addIssue({ + code: "custom", + path: ["metrics_endpoint"], + message: "Metrics endpoint is required when metrics push is enabled", + }); + } else if (protocol === "http") { + validateHttpUrl(metricsEndpoint, ["metrics_endpoint"]); + } else if (protocol === "grpc") { + validateHostPort( + metricsEndpoint, + ["metrics_endpoint"], + "otel-collector:4317", + ); + } + } + }); // OTEL form schema for the OtelFormFragment export const otelFormSchema = z - .object({ - enabled: z.boolean().default(true), - otel_config: otelConfigSchema, - }) - .superRefine((data, ctx) => { - if (data.enabled) { - const collectorUrl = (data.otel_config.collector_url || "").trim(); - if (!collectorUrl) { - ctx.addIssue({ - code: "custom", - path: ["otel_config", "collector_url"], - message: "Collector address is required", - }); - } - } - }); + .object({ + enabled: z.boolean().default(true), + otel_config: otelConfigSchema, + }) + .superRefine((data, ctx) => { + if (data.enabled) { + const collectorUrl = (data.otel_config.collector_url || "").trim(); + if (!collectorUrl) { + ctx.addIssue({ + code: "custom", + path: ["otel_config", "collector_url"], + message: "Collector address is required", + 
}); + } + } + }); // Maxim Configuration Schema export const maximConfigSchema = z.object({ - api_key: z.string().default(""), - log_repo_id: z.string().optional(), + api_key: z.string().default(""), + log_repo_id: z.string().optional(), }); // Maxim form schema for the MaximFormFragment export const maximFormSchema = z - .object({ - enabled: z.boolean().default(true), - maxim_config: maximConfigSchema, - }) - .superRefine((data, ctx) => { - if (data.enabled) { - const apiKey = (data.maxim_config.api_key || "").trim(); - if (!apiKey) { - ctx.addIssue({ - code: "custom", - path: ["maxim_config", "api_key"], - message: "API key is required", - }); - } else if (!apiKey.startsWith("sk_mx_")) { - ctx.addIssue({ - code: "custom", - path: ["maxim_config", "api_key"], - message: "API key must start with 'sk_mx_'", - }); - } - } - }); + .object({ + enabled: z.boolean().default(true), + maxim_config: maximConfigSchema, + }) + .superRefine((data, ctx) => { + if (data.enabled) { + const apiKey = (data.maxim_config.api_key || "").trim(); + if (!apiKey) { + ctx.addIssue({ + code: "custom", + path: ["maxim_config", "api_key"], + message: "API key is required", + }); + } else if (!apiKey.startsWith("sk_mx_")) { + ctx.addIssue({ + code: "custom", + path: ["maxim_config", "api_key"], + message: "API key must start with 'sk_mx_'", + }); + } + } + }); // Prometheus Push Gateway Configuration Schema export const prometheusConfigSchema = z - .object({ - push_gateway_url: z.string().optional(), - job_name: z.string().default("bifrost"), - instance_id: z.string().optional(), - push_interval: z.number().min(1).max(300).default(15), - basic_auth_username: z.string().optional(), - basic_auth_password: z.string().optional(), - }) - .superRefine((data, ctx) => { - // Validate push_gateway_url format - const url = (data.push_gateway_url || "").trim(); - if (url) { - try { - const u = new URL(url); - if (!(u.protocol === "http:" || u.protocol === "https:")) { - ctx.addIssue({ - code: "custom", - 
path: ["push_gateway_url"], - message: "Must be a valid HTTP or HTTPS URL", - }); - } - } catch { - ctx.addIssue({ - code: "custom", - path: ["push_gateway_url"], - message: "Must be a valid URL (e.g., http://pushgateway:9091)", - }); - } - } - - // Validate basic auth: if one credential is provided, both must be provided - const hasUsername = !!data.basic_auth_username?.trim(); - const hasPassword = !!data.basic_auth_password?.trim(); - if (hasUsername && !hasPassword) { - ctx.addIssue({ - code: "custom", - path: ["basic_auth_password"], - message: "Password is required when username is provided", - }); - } - if (hasPassword && !hasUsername) { - ctx.addIssue({ - code: "custom", - path: ["basic_auth_username"], - message: "Username is required when password is provided", - }); - } - }); + .object({ + push_gateway_url: z.string().optional(), + job_name: z.string().default("bifrost"), + instance_id: z.string().optional(), + push_interval: z.number().min(1).max(300).default(15), + basic_auth_username: z.string().optional(), + basic_auth_password: z.string().optional(), + }) + .superRefine((data, ctx) => { + // Validate push_gateway_url format + const url = (data.push_gateway_url || "").trim(); + if (url) { + try { + const u = new URL(url); + if (!(u.protocol === "http:" || u.protocol === "https:")) { + ctx.addIssue({ + code: "custom", + path: ["push_gateway_url"], + message: "Must be a valid HTTP or HTTPS URL", + }); + } + } catch { + ctx.addIssue({ + code: "custom", + path: ["push_gateway_url"], + message: "Must be a valid URL (e.g., http://pushgateway:9091)", + }); + } + } + + // Validate basic auth: if one credential is provided, both must be provided + const hasUsername = !!data.basic_auth_username?.trim(); + const hasPassword = !!data.basic_auth_password?.trim(); + if (hasUsername && !hasPassword) { + ctx.addIssue({ + code: "custom", + path: ["basic_auth_password"], + message: "Password is required when username is provided", + }); + } + if (hasPassword && 
!hasUsername) { + ctx.addIssue({ + code: "custom", + path: ["basic_auth_username"], + message: "Username is required when password is provided", + }); + } + }); // Prometheus form schema for the PrometheusFormFragment export const prometheusFormSchema = z - .object({ - enabled: z.boolean().default(true), - prometheus_config: prometheusConfigSchema, - }) - .superRefine((data, ctx) => { - // When enabled, push_gateway_url is required - if (data.enabled) { - const url = (data.prometheus_config.push_gateway_url || "").trim(); - if (!url) { - ctx.addIssue({ - code: "custom", - path: ["prometheus_config", "push_gateway_url"], - message: "Push Gateway URL is required when enabled", - }); - } - } - }); + .object({ + enabled: z.boolean().default(true), + prometheus_config: prometheusConfigSchema, + }) + .superRefine((data, ctx) => { + // When enabled, push_gateway_url is required + if (data.enabled) { + const url = (data.prometheus_config.push_gateway_url || "").trim(); + if (!url) { + ctx.addIssue({ + code: "custom", + path: ["prometheus_config", "push_gateway_url"], + message: "Push Gateway URL is required when enabled", + }); + } + } + }); // MCP Client update schema export const mcpClientUpdateSchema = z.object({ - is_code_mode_client: z.boolean().optional(), - is_ping_available: z.boolean().optional(), - allow_on_all_virtual_keys: z.boolean().optional(), - name: z - .string() - .min(1, "Name is required") - .refine((val) => !val.includes("-"), { - message: "Client name cannot contain hyphens", - }) - .refine((val) => !val.includes(" "), { - message: "Client name cannot contain spaces", - }) - .refine((val) => !/^[0-9]/.test(val), { - message: "Client name cannot start with a number", - }), - headers: z.record(z.string(), envVarSchema).optional().nullable(), - tools_to_execute: z - .array(z.string()) - .optional() - .refine( - (tools) => { - if (!tools || tools.length === 0) return true; - const hasWildcard = tools.includes("*"); - return !hasWildcard || tools.length 
=== 1; - }, - { message: "Wildcard '*' cannot be combined with other tool names" }, - ) - .refine( - (tools) => { - if (!tools) return true; - return tools.length === new Set(tools).size; - }, - { message: "Duplicate tool names are not allowed" }, - ), - tools_to_auto_execute: z - .array(z.string()) - .optional() - .refine( - (tools) => { - if (!tools || tools.length === 0) return true; - const hasWildcard = tools.includes("*"); - return !hasWildcard || tools.length === 1; - }, - { message: "Wildcard '*' cannot be combined with other tool names" }, - ) - .refine( - (tools) => { - if (!tools) return true; - return tools.length === new Set(tools).size; - }, - { message: "Duplicate tool names are not allowed" }, - ), - tool_pricing: z.record(z.string(), z.number().min(0, "Cost must be non-negative")).optional(), - tool_sync_interval: z.number().optional(), // -1 = disabled, 0 = use global, >0 = custom interval in minutes - allowed_extra_headers: z - .array(z.string()) - .optional() - .refine( - (headers) => { - if (!headers || headers.length === 0) return true; - const hasWildcard = headers.includes("*"); - return !hasWildcard || headers.length === 1; - }, - { message: "Wildcard '*' cannot be combined with specific header names" }, - ), + is_code_mode_client: z.boolean().optional(), + is_ping_available: z.boolean().optional(), + allow_on_all_virtual_keys: z.boolean().optional(), + name: z + .string() + .min(1, "Name is required") + .refine((val) => !val.includes("-"), { + message: "Client name cannot contain hyphens", + }) + .refine((val) => !val.includes(" "), { + message: "Client name cannot contain spaces", + }) + .refine((val) => !/^[0-9]/.test(val), { + message: "Client name cannot start with a number", + }), + headers: z.record(z.string(), envVarSchema).optional().nullable(), + tools_to_execute: z + .array(z.string()) + .optional() + .refine( + (tools) => { + if (!tools || tools.length === 0) return true; + const hasWildcard = tools.includes("*"); + return 
!hasWildcard || tools.length === 1; + }, + { message: "Wildcard '*' cannot be combined with other tool names" }, + ) + .refine( + (tools) => { + if (!tools) return true; + return tools.length === new Set(tools).size; + }, + { message: "Duplicate tool names are not allowed" }, + ), + tools_to_auto_execute: z + .array(z.string()) + .optional() + .refine( + (tools) => { + if (!tools || tools.length === 0) return true; + const hasWildcard = tools.includes("*"); + return !hasWildcard || tools.length === 1; + }, + { message: "Wildcard '*' cannot be combined with other tool names" }, + ) + .refine( + (tools) => { + if (!tools) return true; + return tools.length === new Set(tools).size; + }, + { message: "Duplicate tool names are not allowed" }, + ), + tool_pricing: z + .record(z.string(), z.number().min(0, "Cost must be non-negative")) + .optional(), + tool_sync_interval: z.number().optional(), // -1 = disabled, 0 = use global, >0 = custom interval in minutes + allowed_extra_headers: z + .array(z.string()) + .optional() + .refine( + (headers) => { + if (!headers || headers.length === 0) return true; + const hasWildcard = headers.includes("*"); + return !hasWildcard || headers.length === 1; + }, + { message: "Wildcard '*' cannot be combined with specific header names" }, + ), }); // Global proxy type schema @@ -1025,88 +1120,102 @@ export const globalProxyTypeSchema = z.enum(["http", "socks5", "tcp"]); // Global proxy configuration schema export const globalProxyConfigSchema = z - .object({ - enabled: z.boolean(), - type: globalProxyTypeSchema, - url: z.string(), - username: z.string().optional(), - password: z.string().optional(), - ca_cert_pem: z.string().optional(), - no_proxy: z.string().optional(), - timeout: z.number().min(0).optional(), - skip_tls_verify: z.boolean().optional(), - enable_for_scim: z.boolean(), - enable_for_inference: z.boolean(), - enable_for_api: z.boolean(), - }) - .refine( - (data) => { - // URL is required when proxy is enabled - if 
(data.enabled && (!data.url || data.url.trim().length === 0)) { - return false; - } - return true; - }, - { - message: "Proxy URL is required when proxy is enabled", - path: ["url"], - }, - ) - .refine( - (data) => { - // Validate URL format when provided and enabled - if (data.enabled && data.url && data.url.trim().length > 0) { - try { - new URL(data.url); - return true; - } catch { - return false; - } - } - return true; - }, - { - message: "Must be a valid URL (e.g., http://proxy.example.com:8080)", - path: ["url"], - }, - ); + .object({ + enabled: z.boolean(), + type: globalProxyTypeSchema, + url: z.string(), + username: z.string().optional(), + password: z.string().optional(), + ca_cert_pem: z.string().optional(), + no_proxy: z.string().optional(), + timeout: z.number().min(0).optional(), + skip_tls_verify: z.boolean().optional(), + enable_for_scim: z.boolean(), + enable_for_inference: z.boolean(), + enable_for_api: z.boolean(), + }) + .refine( + (data) => { + // URL is required when proxy is enabled + if (data.enabled && (!data.url || data.url.trim().length === 0)) { + return false; + } + return true; + }, + { + message: "Proxy URL is required when proxy is enabled", + path: ["url"], + }, + ) + .refine( + (data) => { + // Validate URL format when provided and enabled + if (data.enabled && data.url && data.url.trim().length > 0) { + try { + new URL(data.url); + return true; + } catch { + return false; + } + } + return true; + }, + { + message: "Must be a valid URL (e.g., http://proxy.example.com:8080)", + path: ["url"], + }, + ); // Global proxy form schema for the ProxyView export const globalProxyFormSchema = z.object({ - proxy_config: globalProxyConfigSchema, + proxy_config: globalProxyConfigSchema, }); // Global header filter configuration schema // Controls which headers with the x-bf-eh-* prefix are forwarded to LLM providers export const globalHeaderFilterConfigSchema = z.object({ - allowlist: z.array(z.string()).optional(), // If non-empty, only these 
headers are allowed - denylist: z.array(z.string()).optional(), // Headers to always block + allowlist: z.array(z.string()).optional(), // If non-empty, only these headers are allowed + denylist: z.array(z.string()).optional(), // Headers to always block }); // Global header filter form schema for the HeaderFilterView export const globalHeaderFilterFormSchema = z.object({ - header_filter_config: globalHeaderFilterConfigSchema, + header_filter_config: globalHeaderFilterConfigSchema, }); // Routing rule creation schema export const routingRuleSchema = z - .object({ - name: z.string().min(1, "Rule name is required").max(255, "Rule name must be less than 255 characters"), - description: z.string().max(1000, "Description must be less than 1000 characters").optional(), - cel_expression: z.string().optional(), - provider: z.string().min(1, "Provider is required"), - model: z.string().optional(), - fallbacks: z.array(z.string()).optional().default([]), - scope: z.enum(["global", "team", "customer", "virtual_key"]), - scope_id: z.string().optional(), - priority: z.number().min(0, "Priority must be 0 or greater").max(1000, "Priority must be 1000 or less"), - enabled: z.boolean().default(true), - chain_rule: z.boolean().default(false), - }) - .refine((data) => data.scope === "global" || (data.scope_id != null && data.scope_id.trim() !== ""), { - message: "Scope ID is required when scope is not global", - path: ["scope_id"], - }); + .object({ + name: z + .string() + .min(1, "Rule name is required") + .max(255, "Rule name must be less than 255 characters"), + description: z + .string() + .max(1000, "Description must be less than 1000 characters") + .optional(), + cel_expression: z.string().optional(), + provider: z.string().min(1, "Provider is required"), + model: z.string().optional(), + fallbacks: z.array(z.string()).optional().default([]), + scope: z.enum(["global", "team", "customer", "virtual_key"]), + scope_id: z.string().optional(), + priority: z + .number() + .min(0, 
"Priority must be 0 or greater") + .max(1000, "Priority must be 1000 or less"), + enabled: z.boolean().default(true), + chain_rule: z.boolean().default(false), + }) + .refine( + (data) => + data.scope === "global" || + (data.scope_id != null && data.scope_id.trim() !== ""), + { + message: "Scope ID is required when scope is not global", + path: ["scope_id"], + }, + ); // Export type inference helpers export type EnvVar = z.infer; @@ -1115,7 +1224,9 @@ export type ModelProviderKeySchema = z.infer; export type NetworkConfigSchema = z.infer; export type NetworkFormConfigSchema = z.infer; export type ProxyFormConfigSchema = z.infer; -export type NetworkAndProxyFormSchema = z.infer; +export type NetworkAndProxyFormSchema = z.infer< + typeof networkAndProxyFormSchema +>; export type ProxyOnlyFormSchema = z.infer; export type OtelConfigSchema = z.infer; export type OtelFormSchema = z.infer; @@ -1125,9 +1236,15 @@ export type PrometheusConfigSchema = z.infer; export type PrometheusFormSchema = z.infer; export type NetworkOnlyFormSchema = z.infer; export type PerformanceFormSchema = z.infer; -export type CustomProviderConfigSchema = z.infer; +export type CustomProviderConfigSchema = z.infer< + typeof customProviderConfigSchema +>; export type GlobalProxyConfigSchema = z.infer; export type GlobalProxyFormSchema = z.infer; -export type GlobalHeaderFilterConfigSchema = z.infer; -export type GlobalHeaderFilterFormSchema = z.infer; -export type RoutingRuleSchema = z.infer; \ No newline at end of file +export type GlobalHeaderFilterConfigSchema = z.infer< + typeof globalHeaderFilterConfigSchema +>; +export type GlobalHeaderFilterFormSchema = z.infer< + typeof globalHeaderFilterFormSchema +>; +export type RoutingRuleSchema = z.infer;