diff --git a/.github/workflows/configs/withpostgresmcpclientsinconfig/config.json b/.github/workflows/configs/withpostgresmcpclientsinconfig/config.json index a0122adfa2..42758ab9dd 100644 --- a/.github/workflows/configs/withpostgresmcpclientsinconfig/config.json +++ b/.github/workflows/configs/withpostgresmcpclientsinconfig/config.json @@ -7,7 +7,6 @@ ], "disable_content_logging": false, "drop_excess_requests": false, - "enable_litellm_fallbacks": false, "enable_logging": true, "enforce_auth_on_inference": true, "initial_pool_size": 300, diff --git a/.github/workflows/scripts/run-migration-tests.sh b/.github/workflows/scripts/run-migration-tests.sh index 85f901cfd5..e33ee40d9a 100755 --- a/.github/workflows/scripts/run-migration-tests.sh +++ b/.github/workflows/scripts/run-migration-tests.sh @@ -542,8 +542,8 @@ VALUES ('migration-test-lock', 'holder-migration-test-001', $future, $now) ON CONFLICT DO NOTHING; -- config_client (global client configuration) -INSERT INTO config_client (id, drop_excess_requests, prometheus_labels_json, allowed_origins_json, allowed_headers_json, header_filter_config_json, initial_pool_size, enable_logging, disable_content_logging, disable_db_pings_in_health, log_retention_days, enforce_governance_header, allow_direct_keys, max_request_body_size_mb, mcp_agent_depth, mcp_tool_execution_timeout, mcp_code_mode_binding_level, mcp_tool_sync_interval, enable_litellm_fallbacks, config_hash, created_at, updated_at) -VALUES (1, false, '["provider", "model"]', '["*"]', '["Authorization"]', '{}', 300, true, false, false, 365, true, false, true, 100, 10, 30, 'server', 10, false, 'client-config-hash-001', $now, $now) +INSERT INTO config_client (id, drop_excess_requests, prometheus_labels_json, allowed_origins_json, allowed_headers_json, header_filter_config_json, initial_pool_size, enable_logging, disable_content_logging, disable_db_pings_in_health, log_retention_days, enforce_governance_header, allow_direct_keys, max_request_body_size_mb, 
mcp_agent_depth, mcp_tool_execution_timeout, mcp_code_mode_binding_level, mcp_tool_sync_interval, compat_convert_text_to_chat, compat_convert_chat_to_responses, compat_should_drop_params, compat_should_convert_params, config_hash, created_at, updated_at) +VALUES (1, false, '["provider", "model"]', '["*"]', '["Authorization"]', '{}', 300, true, false, false, 365, true, false, true, 100, 10, 30, 'server', 10, false, false, false, 'client-config-hash-001', $now, $now) ON CONFLICT DO NOTHING; -- governance_config (key-value config table) @@ -3509,4 +3509,4 @@ main() { exit $exit_code } -main "$@" +main "$@" \ No newline at end of file diff --git a/.github/workflows/scripts/test-docker-image.sh b/.github/workflows/scripts/test-docker-image.sh index 5d770fbd64..ac115394bf 100755 --- a/.github/workflows/scripts/test-docker-image.sh +++ b/.github/workflows/scripts/test-docker-image.sh @@ -212,8 +212,7 @@ cat > "$CONFIG_FILE" << 'CONFIGEOF' "enable_logging": true, "enforce_governance_header": false, "allow_direct_keys": false, - "max_request_body_size_mb": 100, - "enable_litellm_fallbacks": false + "max_request_body_size_mb": 100 }, "encryption_key": "" } diff --git a/.github/workflows/scripts/validate-helm-config-fields.sh b/.github/workflows/scripts/validate-helm-config-fields.sh index 3b08dfffe9..4b4c63a259 100755 --- a/.github/workflows/scripts/validate-helm-config-fields.sh +++ b/.github/workflows/scripts/validate-helm-config-fields.sh @@ -164,7 +164,11 @@ bifrost: enforceGovernanceHeader: true allowDirectKeys: true maxRequestBodySizeMb: 50 - enableLitellmFallbacks: true + compat: + convertTextToChat: true + convertChatToResponses: true + shouldDropParams: true + shouldConvertParams: true prometheusLabels: - "team" - "env" @@ -200,7 +204,9 @@ assert_field_value 'client.log_retention_days' '.client.log_retention_days' '30' assert_field_value 'client.enforce_governance_header' '.client.enforce_governance_header' 'true' assert_field_value 'client.allow_direct_keys' 
'.client.allow_direct_keys' 'true' assert_field_value 'client.max_request_body_size_mb' '.client.max_request_body_size_mb' '50' -assert_field_value 'client.enable_litellm_fallbacks' '.client.enable_litellm_fallbacks' 'true' +assert_field_value 'client.compat.convert_text_to_chat' '.client.compat.convert_text_to_chat' 'true' +assert_field_value 'client.compat.convert_chat_to_responses' '.client.compat.convert_chat_to_responses' 'true' +assert_field_value 'client.compat.should_drop_params' '.client.compat.should_drop_params' 'true' assert_field 'client.prometheus_labels' '.client.prometheus_labels' assert_field 'client.header_filter_config.allowlist' '.client.header_filter_config.allowlist' assert_field 'client.header_filter_config.denylist' '.client.header_filter_config.denylist' @@ -1194,4 +1200,4 @@ if [ "$TESTS_FAILED" -gt 0 ]; then else echo -e "${GREEN}✅ All config.json field validations passed!${NC}" exit 0 -fi +fi \ No newline at end of file diff --git a/docs/features/litellm-compat.mdx b/docs/features/litellm-compat.mdx index 51cd26dcd9..490a37efa4 100644 --- a/docs/features/litellm-compat.mdx +++ b/docs/features/litellm-compat.mdx @@ -9,8 +9,10 @@ icon: "train" The LiteLLM compatibility plugin provides two transformations: 1. **Text-to-Chat Conversion** - Automatically converts text completion requests to chat completion format for models that only support chat APIs +2. **Chat-to-Responses Conversion** - Automatically converts chat completion requests to responses format for models that only support responses APIs +3. **Drop Unsupported Params** - Automatically drops unsupported parameters if the model doesn't support them -When either transformation is applied, responses include `extra_fields.litellm_compat: true`. +When either transformation is applied, responses include `extra_fields.converted_request_type: `. If request parameters are dropped, the keys are added in `extra_fields.dropped_compat_plugin_params`. 
--- @@ -55,6 +57,36 @@ F --> G - `object: "chat.completion"` → `object: "text_completion"` - Usage statistics and metadata are preserved +## 2. Chat-to-Responses Conversion + +Some AI models (like OpenAI o1-pro) only support the responses API and don't support native chat completion endpoints. LiteLLM compatibility mode automatically handles this by: + +1. Checking if the model supports chat completion natively (using the model catalog) +2. If not supported, converting your chat message to responses API format +3. Calling the responses endpoint internally +4. Transforming the response back to chat completion format + + +**Smart Conversion**: The conversion only happens when the model doesn't support chat completions natively. If a model has native chat completion support (like OpenAI's gpt-4 models), Bifrost uses the chat completion endpoint directly without any conversion. + + +This allows you to use a unified chat completion interface across all providers, even those that only support responses API. + +## How It Works + +When LiteLLM compatibility is enabled and you make a chat completion request, Bifrost first checks if the model supports chat completion: + +```mermaid +flowchart LR +A[Chat Completion Request] --> B{Model Supports Chat Completion?} +B -->|Yes| C[Call Chat Completion API] +B -->|No| D[Convert to Responses Message] +D --> E[Call Responses API] +E --> F[Transform Response] +C --> G[Chat Completion Response] +F --> G +``` + ## Enabling LiteLLM Compatibility @@ -63,7 +95,10 @@ F --> G 1. Open the Bifrost dashboard 2. Navigate to **Settings** → **Client Configuration** -3. Enable **LiteLLM Fallbacks** +3. 
Expand **LiteLLM Compat** and enable the features you need: + - **Convert Text to Chat** — converts text completion requests to chat for models that only support chat + - **Convert Chat to Responses** — converts chat completion requests to responses for models that only support responses + - **Drop Unsupported Params** — drops unsupported parameters based on model catalog allowlist 4. Save your configuration @@ -73,7 +108,11 @@ F --> G ```json { "client_config": { - "enable_litellm_fallbacks": true + "compat": { + "convert_text_to_chat": true, + "convert_chat_to_responses": true, + "should_drop_params": true + } } } ``` @@ -84,9 +123,9 @@ F --> G ## Supported Providers -LiteLLM compatibility mode works with any provider that supports chat completions but lacks native text completion support: +Text completion to chat completion conversion works with any provider that supports chat completions but lacks native text completion support: -| Provider | Native Text Completion | LiteLLM Fallback | +| Provider | Native Text Completion | With Fallback | |----------|----------------------|------------------| | OpenAI (GPT-4, GPT-3.5-turbo) | No | Yes | | Anthropic (Claude) | No | Yes | @@ -95,6 +134,12 @@ LiteLLM compatibility mode works with any provider that supports chat completion | Mistral | No | Yes | | Bedrock | Varies by model | Yes | +Chat completion to responses conversion works with any provider that supports responses but lacks native chat completion support: + +| Provider | Native Chat Completion | With Fallback | +|----------|----------------------|------------------| +| OpenAI (o1-pro) | No | Yes | + ## Behavior Details **Model Capability Detection:** @@ -117,13 +162,19 @@ LiteLLM compatibility mode works with any provider that supports chat completion | Response | `choices[0].message.content` | `choices[0].text` | | Response | `object: "chat.completion"` | `object: "text_completion"` | +### Transformation 2: Chat-to-Responses Conversion + +**Applies to:** Chat 
completion requests on responses-only models + +| Phase | Original | Transformed | +|-------|----------|-------------| +| Request | Chat message with `role: "user"` | Responses input with `role: "user"` | +| Request | `chat_completion` request type | `responses` request type | ### Metadata Set on Transformed Responses When either transformation is applied: -- `extra_fields.litellm_compat`: Set to `true` -- `extra_fields.provider`: The provider that handled the request - `extra_fields.request_type`: Reflects the original request type - `extra_fields.original_model_requested`: The originally requested model - `extra_fields.resolved_model_used`: The actual provider API identifier used (equals original_model_requested when no alias mapping exists) @@ -131,8 +182,11 @@ When either transformation is applied: ### Error Handling When errors occur on transformed requests: -- `extra_fields.litellm_compat` is set to `true` - Original request type and model are preserved in error metadata +- `extra_fields.converted_request_type`: Set to type of request that was converted to (i.e., `chat_completion` or `responses`) +- `extra_fields.provider`: The provider that handled the request +- `extra_fields.original_model_requested`: The originally requested model +- `extra_fields.dropped_compat_plugin_params`: If any unsupported parameters were dropped, the keys are added here ## What's Preserved @@ -145,7 +199,7 @@ When errors occur on transformed requests: **Good Use Cases:** - Migrating from LiteLLM to Bifrost without code changes -- Maintaining backward compatibility with text completion interfaces +- Maintaining backward compatibility with text completion interfaces or chat completion interfaces - Using a unified API across providers with different capabilities **Consider Alternatives When:** @@ -157,4 +211,4 @@ When errors occur on transformed requests: - [Fallbacks](/features/fallbacks) - Automatic provider failover - [Drop-in Replacement](/features/drop-in-replacement) - Use 
existing SDKs with Bifrost -- [LiteLLM Integration](/integrations/litellm-sdk) - Using LiteLLM SDK with Bifrost +- [LiteLLM Integration](/integrations/litellm-sdk) - Using LiteLLM SDK with Bifrost \ No newline at end of file diff --git a/docs/openapi/openapi.json b/docs/openapi/openapi.json index 1043039a1f..b278472d01 100644 --- a/docs/openapi/openapi.json +++ b/docs/openapi/openapi.json @@ -133221,9 +133221,15 @@ "type": "integer", "description": "Maximum request body size in MB" }, - "enable_litellm_fallbacks": { - "type": "boolean", - "description": "Whether LiteLLM fallbacks are enabled" + "compat": { + "type": "object", + "description": "Compat plugin configuration", + "properties": { + "convert_text_to_chat": { "type": "boolean", "description": "Convert text completion requests to chat" }, + "convert_chat_to_responses": { "type": "boolean", "description": "Convert chat completion requests to responses" }, + "should_drop_params": { "type": "boolean", "description": "Drop unsupported parameters based on model catalog" }, + "should_convert_params": { "type": "boolean", "description": "Converts model parameter values that are not supported by the model.", "default": true } + } }, "log_retention_days": { "type": "integer", @@ -133537,9 +133543,15 @@ "type": "integer", "description": "Maximum request body size in MB" }, - "enable_litellm_fallbacks": { - "type": "boolean", - "description": "Whether LiteLLM fallbacks are enabled" + "compat": { + "type": "object", + "description": "Compat plugin configuration", + "properties": { + "convert_text_to_chat": { "type": "boolean", "description": "Convert text completion requests to chat" }, + "convert_chat_to_responses": { "type": "boolean", "description": "Convert chat completion requests to responses" }, + "should_drop_params": { "type": "boolean", "description": "Drop unsupported parameters based on model catalog" }, + "should_convert_params": { "type": "boolean", "description": "Converts model parameter values that are 
not supported by the model.", "default": true } + } }, "log_retention_days": { "type": "integer", @@ -205784,9 +205796,15 @@ "type": "integer", "description": "Maximum request body size in MB" }, - "enable_litellm_fallbacks": { - "type": "boolean", - "description": "Whether LiteLLM fallbacks are enabled" + "compat": { + "type": "object", + "description": "Compat plugin configuration", + "properties": { + "convert_text_to_chat": { "type": "boolean", "description": "Convert text completion requests to chat" }, + "convert_chat_to_responses": { "type": "boolean", "description": "Convert chat completion requests to responses" }, + "should_drop_params": { "type": "boolean", "description": "Drop unsupported parameters based on model catalog" }, + "should_convert_params": { "type": "boolean", "description": "Converts model parameter values that are not supported by the model.", "default": true } + } }, "log_retention_days": { "type": "integer", @@ -205999,9 +206017,15 @@ "type": "integer", "description": "Maximum request body size in MB" }, - "enable_litellm_fallbacks": { - "type": "boolean", - "description": "Whether LiteLLM fallbacks are enabled" + "compat": { + "type": "object", + "description": "Compat plugin configuration", + "properties": { + "convert_text_to_chat": { "type": "boolean", "description": "Convert text completion requests to chat" }, + "convert_chat_to_responses": { "type": "boolean", "description": "Convert chat completion requests to responses" }, + "should_drop_params": { "type": "boolean", "description": "Drop unsupported parameters based on model catalog" }, + "should_convert_params": { "type": "boolean", "description": "Converts model parameter values that are not supported by the model.", "default": true } + } }, "log_retention_days": { "type": "integer", @@ -224498,4 +224522,4 @@ } } } -} \ No newline at end of file +} diff --git a/docs/openapi/schemas/management/config.yaml b/docs/openapi/schemas/management/config.yaml index 
2c54b3979d..b4e6c2008e 100644 --- a/docs/openapi/schemas/management/config.yaml +++ b/docs/openapi/schemas/management/config.yaml @@ -44,9 +44,19 @@ ClientConfig: max_request_body_size_mb: type: integer description: Maximum request body size in MB - enable_litellm_fallbacks: - type: boolean - description: Whether LiteLLM fallbacks are enabled + compat: + type: object + description: Compat plugin configuration + properties: + convert_text_to_chat: + type: boolean + description: Convert text completion requests to chat + convert_chat_to_responses: + type: boolean + description: Convert chat completion requests to responses + should_drop_params: + type: boolean + description: Drop unsupported parameters based on model catalog log_retention_days: type: integer description: Number of days to retain logs diff --git a/docs/providers/supported-providers/overview.mdx b/docs/providers/supported-providers/overview.mdx index b3ae42f62f..98d13ffa73 100644 --- a/docs/providers/supported-providers/overview.mdx +++ b/docs/providers/supported-providers/overview.mdx @@ -48,7 +48,7 @@ The following table summarizes which operations are supported by each provider v Some operations are not supported by the downstream provider, and their internal implementation in Bifrost is optional. 🟡 -Like Text completions are not supported by Groq, but Bifrost can emulate them internally using the Chat Completions API. This feature is disabled by default, but it can be enabled by setting the `enable_litellm_fallbacks` flag to `true` in the client configuration. +Like Text completions are not supported by Groq, but Bifrost can emulate them internally using the Chat Completions API. This feature is disabled by default, but it can be enabled by setting `compat.convert_text_to_chat` to `true` in the client configuration. We do not promote using such fallbacks, since text completions and chat completions are fundamentally different. 
However, this option is available to help users migrating from LiteLLM (which does support these fallbacks). diff --git a/examples/configs/withpostgresmcpclientsinconfig/config.json b/examples/configs/withpostgresmcpclientsinconfig/config.json index 8e03969988..068bc88012 100644 --- a/examples/configs/withpostgresmcpclientsinconfig/config.json +++ b/examples/configs/withpostgresmcpclientsinconfig/config.json @@ -7,7 +7,6 @@ ], "disable_content_logging": false, "drop_excess_requests": false, - "enable_litellm_fallbacks": false, "enable_logging": true, "enforce_auth_on_inference": true, "initial_pool_size": 300, diff --git a/examples/configs/withprompushgateway/config.json b/examples/configs/withprompushgateway/config.json index f697041388..110557d797 100644 --- a/examples/configs/withprompushgateway/config.json +++ b/examples/configs/withprompushgateway/config.json @@ -183,8 +183,7 @@ "enable_logging": true, "enforce_auth_on_inference": false, "allow_direct_keys": false, - "max_request_body_size_mb": 100, - "enable_litellm_fallbacks": false + "max_request_body_size_mb": 100 }, "config_store": { "enabled": true, diff --git a/examples/configs/withvirtualkeys/config.json b/examples/configs/withvirtualkeys/config.json index a968bad65c..9d9ae2c87a 100644 --- a/examples/configs/withvirtualkeys/config.json +++ b/examples/configs/withvirtualkeys/config.json @@ -7,7 +7,6 @@ ], "disable_content_logging": false, "drop_excess_requests": false, - "enable_litellm_fallbacks": false, "enable_logging": true, "enforce_auth_on_inference": true, "initial_pool_size": 300, diff --git a/examples/dockers/data/config.json b/examples/dockers/data/config.json index 46cbfd8e68..598fa59785 100644 --- a/examples/dockers/data/config.json +++ b/examples/dockers/data/config.json @@ -27,7 +27,9 @@ "*" ], "max_request_body_size_mb": 100, - "enable_litellm_fallbacks": false + "compat": { + "should_convert_params": true + } }, "framework": { "pricing": { @@ -35,4 +37,4 @@ "pricing_sync_interval": 86400 
} } -} \ No newline at end of file +} diff --git a/framework/configstore/clientconfig.go b/framework/configstore/clientconfig.go index 698437c328..90a27db25e 100644 --- a/framework/configstore/clientconfig.go +++ b/framework/configstore/clientconfig.go @@ -34,6 +34,14 @@ type EnvKeyInfo struct { KeyID string // The key ID this env var belongs to (empty for non-key configs like bedrock_config, connection_string) } +// CompatConfig holds the compat plugin feature flags. +type CompatConfig struct { + ConvertTextToChat bool `json:"convert_text_to_chat"` + ConvertChatToResponses bool `json:"convert_chat_to_responses"` + ShouldDropParams bool `json:"should_drop_params"` + ShouldConvertParams bool `json:"should_convert_params"` +} + // ClientConfig represents the core configuration for Bifrost HTTP transport and the Bifrost Client. // It includes settings for excess request handling, Prometheus metrics, and initial pool size. type ClientConfig struct { @@ -51,7 +59,7 @@ type ClientConfig struct { AllowedOrigins []string `json:"allowed_origins,omitempty"` // Additional allowed origins for CORS and WebSocket (localhost is always allowed) AllowedHeaders []string `json:"allowed_headers,omitempty"` // Additional allowed headers for CORS and WebSocket MaxRequestBodySizeMB int `json:"max_request_body_size_mb"` // The maximum request body size in MB - EnableLiteLLMFallbacks bool `json:"enable_litellm_fallbacks"` // Enable litellm-specific fallbacks for text completion for Groq + Compat CompatConfig `json:"compat"` // Compat plugin configuration MCPAgentDepth int `json:"mcp_agent_depth"` // The maximum depth for MCP agent mode tool execution MCPToolExecutionTimeout int `json:"mcp_tool_execution_timeout"` // The timeout for individual tool execution in seconds MCPCodeModeBindingLevel string `json:"mcp_code_mode_binding_level"` // Code mode binding level: "server" or "tool" @@ -110,10 +118,17 @@ func (c *ClientConfig) GenerateClientConfigHash() (string, error) { 
hash.Write([]byte("allowDirectKeys:false")) } - if c.EnableLiteLLMFallbacks { - hash.Write([]byte("enableLiteLLMFallbacks:true")) - } else { - hash.Write([]byte("enableLiteLLMFallbacks:false")) + if c.Compat.ConvertTextToChat { + hash.Write([]byte("compatConvertTextToChat:true")) + } + if c.Compat.ConvertChatToResponses { + hash.Write([]byte("compatConvertChatToResponses:true")) + } + if c.Compat.ShouldDropParams { + hash.Write([]byte("compatShouldDropParams:true")) + } + if c.Compat.ShouldConvertParams { + hash.Write([]byte("compatShouldConvertParams:true")) } // Only hash non-default value to avoid legacy config hash churn. diff --git a/framework/configstore/migrations.go b/framework/configstore/migrations.go index e64351eeaa..d25303b136 100644 --- a/framework/configstore/migrations.go +++ b/framework/configstore/migrations.go @@ -376,6 +376,9 @@ func triggerMigrations(ctx context.Context, db *gorm.DB) error { if err := migrationAddWhitelistedRoutesJSONColumn(ctx, db); err != nil { return err } + if err := migrationReplaceEnableLiteLLMWithCompatColumns(ctx, db); err != nil { + return err + } return nil } @@ -785,9 +788,10 @@ func migrationAddEnableLiteLLMFallbacksColumn(ctx context.Context, db *gorm.DB) ID: "add_enable_litellm_fallbacks_column", Migrate: func(tx *gorm.DB) error { tx = tx.WithContext(ctx) - migrator := tx.Migrator() - if !migrator.HasColumn(&tables.TableClientConfig{}, "enable_litellm_fallbacks") { - if err := migrator.AddColumn(&tables.TableClientConfig{}, "enable_litellm_fallbacks"); err != nil { + // Use raw SQL since the struct field was removed in a later migration. + // This column is subsequently dropped by migrationReplaceEnableLiteLLMWithCompatColumns. 
+ if !tx.Migrator().HasColumn(&tables.TableClientConfig{}, "enable_litellm_fallbacks") { + if err := tx.Exec("ALTER TABLE config_client ADD COLUMN enable_litellm_fallbacks BOOLEAN DEFAULT FALSE").Error; err != nil { return err } } @@ -795,9 +799,7 @@ func migrationAddEnableLiteLLMFallbacksColumn(ctx context.Context, db *gorm.DB) }, Rollback: func(tx *gorm.DB) error { tx = tx.WithContext(ctx) - migrator := tx.Migrator() - - if err := migrator.DropColumn(&tables.TableClientConfig{}, "enable_litellm_fallbacks"); err != nil { + if err := tx.Exec("ALTER TABLE config_client DROP COLUMN IF EXISTS enable_litellm_fallbacks").Error; err != nil { return err } return nil @@ -2162,7 +2164,6 @@ func migrationAddAdditionalConfigHashColumns(ctx context.Context, db *gorm.DB) e AllowDirectKeys: cc.AllowDirectKeys, AllowedOrigins: cc.AllowedOrigins, MaxRequestBodySizeMB: cc.MaxRequestBodySizeMB, - EnableLiteLLMFallbacks: cc.EnableLiteLLMFallbacks, } hash, err := clientConfig.GenerateClientConfigHash() if err != nil { @@ -5611,7 +5612,6 @@ func migrationAddRoutingChainMaxDepthColumn(ctx context.Context, db *gorm.DB) er AllowedOrigins: cc.AllowedOrigins, AllowedHeaders: cc.AllowedHeaders, MaxRequestBodySizeMB: cc.MaxRequestBodySizeMB, - EnableLiteLLMFallbacks: cc.EnableLiteLLMFallbacks, HideDeletedVirtualKeysInFilters: cc.HideDeletedVirtualKeysInFilters, MCPAgentDepth: cc.MCPAgentDepth, MCPToolExecutionTimeout: cc.MCPToolExecutionTimeout, @@ -5907,7 +5907,6 @@ func migrationAddMultiBudgetTables(ctx context.Context, db *gorm.DB) error { if mg.HasColumn(&tables.TableBudget{}, "provider_config_id") { if err := mg.DropColumn(&tables.TableBudget{}, "provider_config_id"); err != nil { return err - } } return nil @@ -6063,21 +6062,104 @@ func migrationAddWhitelistedRoutesJSONColumn(ctx context.Context, db *gorm.DB) e return fmt.Errorf("failed to add whitelisted_routes_json column: %w", err) } } + return nil }, Rollback: func(tx *gorm.DB) error { tx = tx.WithContext(ctx) migrator := 
tx.Migrator() + if migrator.HasColumn(&tables.TableClientConfig{}, "whitelisted_routes_json") { if err := migrator.DropColumn(&tables.TableClientConfig{}, "whitelisted_routes_json"); err != nil { return fmt.Errorf("failed to drop whitelisted_routes_json column: %w", err) } } + + return nil + }, + }}) + if err := m.Migrate(); err != nil { + return fmt.Errorf("error running whitelisted_routes_json migration: %s", err.Error()) + } + return nil +} + +// migrationReplaceEnableLiteLLMWithCompatColumns replaces the single enable_litellm_fallbacks +// boolean with compat feature columns. If enable_litellm_fallbacks was true, +// only convert_text_to_chat is set to true (preserving the original behavior). +func migrationReplaceEnableLiteLLMWithCompatColumns(ctx context.Context, db *gorm.DB) error { + m := migrator.New(db, migrator.DefaultOptions, []*migrator.Migration{{ + ID: "replace_enable_litellm_with_compat_columns", + Migrate: func(tx *gorm.DB) error { + tx = tx.WithContext(ctx) + mig := tx.Migrator() + + // Add new columns + if !mig.HasColumn(&tables.TableClientConfig{}, "compat_convert_text_to_chat") { + if err := mig.AddColumn(&tables.TableClientConfig{}, "compat_convert_text_to_chat"); err != nil { + return err + } + } + if !mig.HasColumn(&tables.TableClientConfig{}, "compat_convert_chat_to_responses") { + if err := mig.AddColumn(&tables.TableClientConfig{}, "compat_convert_chat_to_responses"); err != nil { + return err + } + } + if !mig.HasColumn(&tables.TableClientConfig{}, "compat_should_drop_params") { + if err := mig.AddColumn(&tables.TableClientConfig{}, "compat_should_drop_params"); err != nil { + return err + } + } + if !mig.HasColumn(&tables.TableClientConfig{}, "compat_should_convert_params") { + if err := mig.AddColumn(&tables.TableClientConfig{}, "compat_should_convert_params"); err != nil { + return err + } + } + + if err := tx.Exec("UPDATE config_client SET compat_should_convert_params = TRUE").Error; err != nil { + return err + } + + // Migrate data: 
if enable_litellm_fallbacks was true, set convert_text_to_chat = true + if mig.HasColumn(&tables.TableClientConfig{}, "enable_litellm_fallbacks") { + if err := tx.Exec("UPDATE config_client SET compat_convert_text_to_chat = enable_litellm_fallbacks").Error; err != nil { + return err + } + if err := mig.DropColumn(&tables.TableClientConfig{}, "enable_litellm_fallbacks"); err != nil { + return err + } + } + + return nil + }, + Rollback: func(tx *gorm.DB) error { + tx = tx.WithContext(ctx) + mig := tx.Migrator() + if tx.Migrator().HasColumn(&tables.TableClientConfig{}, "enable_litellm_fallbacks") { + if err := tx.Exec("ALTER TABLE config_client DROP COLUMN enable_litellm_fallbacks").Error; err != nil { + return err + } + if err := tx.Exec("UPDATE config_client SET enable_litellm_fallbacks = compat_convert_text_to_chat").Error; err != nil { + return err + } + } + for _, col := range []string{ + "compat_convert_text_to_chat", + "compat_convert_chat_to_responses", + "compat_should_drop_params", + "compat_should_convert_params", + } { + if mig.HasColumn(&tables.TableClientConfig{}, col) { + if err := mig.DropColumn(&tables.TableClientConfig{}, col); err != nil { + return err + } + } + } return nil }, }}) if err := m.Migrate(); err != nil { - return fmt.Errorf("error running add_whitelisted_routes_json_column migration: %s", err.Error()) + return fmt.Errorf("error while running replace_enable_litellm_with_compat_columns migration: %s", err.Error()) } return nil } diff --git a/framework/configstore/rdb.go b/framework/configstore/rdb.go index 4002278501..c718f6e6c4 100644 --- a/framework/configstore/rdb.go +++ b/framework/configstore/rdb.go @@ -137,7 +137,10 @@ func (s *RDBConfigStore) UpdateClientConfig(ctx context.Context, config *ClientC AllowedOrigins: config.AllowedOrigins, AllowedHeaders: config.AllowedHeaders, MaxRequestBodySizeMB: config.MaxRequestBodySizeMB, - EnableLiteLLMFallbacks: config.EnableLiteLLMFallbacks, + CompatConvertTextToChat: 
config.Compat.ConvertTextToChat, + CompatConvertChatToResponses: config.Compat.ConvertChatToResponses, + CompatShouldDropParams: config.Compat.ShouldDropParams, + CompatShouldConvertParams: config.Compat.ShouldConvertParams, MCPAgentDepth: config.MCPAgentDepth, MCPToolExecutionTimeout: config.MCPToolExecutionTimeout, MCPCodeModeBindingLevel: config.MCPCodeModeBindingLevel, @@ -289,21 +292,26 @@ func (s *RDBConfigStore) GetClientConfig(ctx context.Context) (*ClientConfig, er return nil, err } return &ClientConfig{ - DropExcessRequests: dbConfig.DropExcessRequests, - InitialPoolSize: dbConfig.InitialPoolSize, - PrometheusLabels: dbConfig.PrometheusLabels, - EnableLogging: dbConfig.EnableLogging, - DisableContentLogging: dbConfig.DisableContentLogging, - DisableDBPingsInHealth: dbConfig.DisableDBPingsInHealth, - LogRetentionDays: dbConfig.LogRetentionDays, - EnforceAuthOnInference: dbConfig.EnforceAuthOnInference, - EnforceGovernanceHeader: dbConfig.EnforceGovernanceHeader, - EnforceSCIMAuth: dbConfig.EnforceSCIMAuth, - AllowDirectKeys: dbConfig.AllowDirectKeys, - AllowedOrigins: dbConfig.AllowedOrigins, - AllowedHeaders: dbConfig.AllowedHeaders, - MaxRequestBodySizeMB: dbConfig.MaxRequestBodySizeMB, - EnableLiteLLMFallbacks: dbConfig.EnableLiteLLMFallbacks, + DropExcessRequests: dbConfig.DropExcessRequests, + InitialPoolSize: dbConfig.InitialPoolSize, + PrometheusLabels: dbConfig.PrometheusLabels, + EnableLogging: dbConfig.EnableLogging, + DisableContentLogging: dbConfig.DisableContentLogging, + DisableDBPingsInHealth: dbConfig.DisableDBPingsInHealth, + LogRetentionDays: dbConfig.LogRetentionDays, + EnforceAuthOnInference: dbConfig.EnforceAuthOnInference, + EnforceGovernanceHeader: dbConfig.EnforceGovernanceHeader, + EnforceSCIMAuth: dbConfig.EnforceSCIMAuth, + AllowDirectKeys: dbConfig.AllowDirectKeys, + AllowedOrigins: dbConfig.AllowedOrigins, + AllowedHeaders: dbConfig.AllowedHeaders, + MaxRequestBodySizeMB: dbConfig.MaxRequestBodySizeMB, + Compat: CompatConfig{ + 
ConvertTextToChat: dbConfig.CompatConvertTextToChat, + ConvertChatToResponses: dbConfig.CompatConvertChatToResponses, + ShouldDropParams: dbConfig.CompatShouldDropParams, + ShouldConvertParams: dbConfig.CompatShouldConvertParams, + }, MCPAgentDepth: dbConfig.MCPAgentDepth, MCPToolExecutionTimeout: dbConfig.MCPToolExecutionTimeout, MCPCodeModeBindingLevel: dbConfig.MCPCodeModeBindingLevel, @@ -4465,4 +4473,4 @@ func (s *RDBConfigStore) TransferOauthUserTokensFromGatewaySession(ctx context.C } s.logger.Debug("[rdb] TransferOauthUserTokensFromGatewaySession done: rows_affected=%d", result.RowsAffected) return nil -} +} \ No newline at end of file diff --git a/framework/configstore/tables/clientconfig.go b/framework/configstore/tables/clientconfig.go index a9ff7fc7f6..ccb845c809 100644 --- a/framework/configstore/tables/clientconfig.go +++ b/framework/configstore/tables/clientconfig.go @@ -37,8 +37,11 @@ type TableClientConfig struct { RoutingChainMaxDepth int `gorm:"default:10" json:"routing_chain_max_depth"` // Maximum depth for routing rule chain evaluation (default: 10) WhitelistedRoutesJSON string `gorm:"type:text" json:"-"` // JSON serialized []string - // LiteLLM fallback flag - EnableLiteLLMFallbacks bool `gorm:"column:enable_litellm_fallbacks;default:false" json:"enable_litellm_fallbacks"` + // Compat plugin feature flags + CompatConvertTextToChat bool `gorm:"column:compat_convert_text_to_chat;default:false" json:"-"` + CompatConvertChatToResponses bool `gorm:"column:compat_convert_chat_to_responses;default:false" json:"-"` + CompatShouldDropParams bool `gorm:"column:compat_should_drop_params;default:false" json:"-"` + CompatShouldConvertParams bool `gorm:"column:compat_should_convert_params;default:true" json:"-"` // Config hash is used to detect the changes synced from config.json file // Every time we sync the config.json file, we will update the config hash diff --git a/framework/modelcatalog/main.go b/framework/modelcatalog/main.go index 
2122f339d5..f014e2965f 100644 --- a/framework/modelcatalog/main.go +++ b/framework/modelcatalog/main.go @@ -6,7 +6,6 @@ import ( "encoding/json" "fmt" "slices" - "strings" "sync" "time" @@ -323,643 +322,6 @@ func (mc *ModelCatalog) getPricingURL() string { return mc.pricingURL } -// getPricingSyncInterval returns a copy of the pricing sync interval under mutex protection -func (mc *ModelCatalog) getPricingSyncInterval() time.Duration { - mc.pricingMu.RLock() - defer mc.pricingMu.RUnlock() - return mc.pricingSyncInterval -} - -// GetPricingEntryForModel returns the pricing data -func (mc *ModelCatalog) GetPricingEntryForModel(model string, provider schemas.ModelProvider) *PricingEntry { - mc.mu.RLock() - defer mc.mu.RUnlock() - // Check all modes - for _, mode := range []schemas.RequestType{ - schemas.TextCompletionRequest, - schemas.ChatCompletionRequest, - schemas.ResponsesRequest, - schemas.EmbeddingRequest, - schemas.RerankRequest, - schemas.SpeechRequest, - schemas.TranscriptionRequest, - schemas.ImageGenerationRequest, - schemas.ImageEditRequest, - schemas.ImageVariationRequest, - schemas.VideoGenerationRequest, - } { - key := makeKey(model, string(provider), normalizeRequestType(mode)) - pricing, ok := mc.pricingData[key] - if ok { - return convertTableModelPricingToPricingData(&pricing) - } - } - return nil -} - -// GetModelCapabilityEntryForModel returns capability metadata for a model/provider pair. -// It prefers chat, then responses, then text-completion entries; if none exist, -// it falls back to the lexicographically first available mode for deterministic behavior. 
-func (mc *ModelCatalog) GetModelCapabilityEntryForModel(model string, provider schemas.ModelProvider) *PricingEntry { - mc.mu.RLock() - defer mc.mu.RUnlock() - - if entry := mc.getCapabilityEntryForExactModelUnsafe(model, provider); entry != nil { - return entry - } - - baseModel := mc.getBaseModelNameUnsafe(model) - if baseModel != model { - if entry := mc.getCapabilityEntryForExactModelUnsafe(baseModel, provider); entry != nil { - return entry - } - } - - if entry := mc.getCapabilityEntryForModelFamilyUnsafe(baseModel, provider); entry != nil { - return entry - } - - return nil -} - -func (mc *ModelCatalog) getCapabilityEntryForExactModelUnsafe(model string, provider schemas.ModelProvider) *PricingEntry { - preferredModes := []schemas.RequestType{ - schemas.ChatCompletionRequest, - schemas.ResponsesRequest, - schemas.TextCompletionRequest, - } - - for _, mode := range preferredModes { - key := makeKey(model, string(provider), normalizeRequestType(mode)) - pricing, ok := mc.pricingData[key] - if ok { - return convertTableModelPricingToPricingData(&pricing) - } - } - - prefix := model + "|" + string(provider) + "|" - matchingKeys := make([]string, 0) - for key := range mc.pricingData { - if strings.HasPrefix(key, prefix) { - matchingKeys = append(matchingKeys, key) - } - } - return mc.selectCapabilityEntryFromKeysUnsafe(matchingKeys) -} - -func (mc *ModelCatalog) getCapabilityEntryForModelFamilyUnsafe(baseModel string, provider schemas.ModelProvider) *PricingEntry { - if baseModel == "" { - return nil - } - - matchingKeys := make([]string, 0) - for key, pricing := range mc.pricingData { - if normalizeProvider(pricing.Provider) != string(provider) { - continue - } - if mc.getBaseModelNameUnsafe(pricing.Model) != baseModel { - continue - } - matchingKeys = append(matchingKeys, key) - } - return mc.selectCapabilityEntryFromKeysUnsafe(matchingKeys) -} - -func (mc *ModelCatalog) selectCapabilityEntryFromKeysUnsafe(matchingKeys []string) *PricingEntry { - if 
len(matchingKeys) == 0 { - return nil - } - - preferredModes := []string{ - normalizeRequestType(schemas.ChatCompletionRequest), - normalizeRequestType(schemas.ResponsesRequest), - normalizeRequestType(schemas.TextCompletionRequest), - } - - for _, mode := range preferredModes { - modeMatches := make([]string, 0) - for _, key := range matchingKeys { - parts := strings.SplitN(key, "|", 3) - if len(parts) != 3 || parts[2] != mode { - continue - } - modeMatches = append(modeMatches, key) - } - if len(modeMatches) == 0 { - continue - } - slices.Sort(modeMatches) - pricing := mc.pricingData[modeMatches[0]] - return convertTableModelPricingToPricingData(&pricing) - } - - slices.Sort(matchingKeys) - pricing := mc.pricingData[matchingKeys[0]] - return convertTableModelPricingToPricingData(&pricing) -} - -// GetModelsForProvider returns all available models for a given provider (thread-safe) -func (mc *ModelCatalog) GetModelsForProvider(provider schemas.ModelProvider) []string { - mc.mu.RLock() - defer mc.mu.RUnlock() - - models, exists := mc.modelPool[provider] - if !exists { - return []string{} - } - - // Return a copy to prevent external modification - result := make([]string, len(models)) - copy(result, models) - return result -} - -// GetUnfilteredModelsForProvider returns all available models for a given provider (thread-safe) -func (mc *ModelCatalog) GetUnfilteredModelsForProvider(provider schemas.ModelProvider) []string { - mc.mu.RLock() - defer mc.mu.RUnlock() - - models, exists := mc.unfilteredModelPool[provider] - if !exists { - return []string{} - } - - // Return a copy to prevent external modification - result := make([]string, len(models)) - copy(result, models) - return result -} - -// GetDistinctBaseModelNames returns all unique base model names from the catalog (thread-safe). -// This is used for governance model selection when no specific provider is chosen. 
-func (mc *ModelCatalog) GetDistinctBaseModelNames() []string { - mc.mu.RLock() - defer mc.mu.RUnlock() - - seen := make(map[string]bool) - for _, baseName := range mc.baseModelIndex { - seen[baseName] = true - } - - result := make([]string, 0, len(seen)) - for name := range seen { - result = append(result, name) - } - return result -} - -// GetProvidersForModel returns all providers for a given model (thread-safe) -func (mc *ModelCatalog) GetProvidersForModel(model string) []schemas.ModelProvider { - mc.mu.RLock() - defer mc.mu.RUnlock() - - providers := make([]schemas.ModelProvider, 0) - for provider, models := range mc.modelPool { - isModelMatch := false - for _, m := range models { - if m == model || mc.getBaseModelNameUnsafe(m) == mc.getBaseModelNameUnsafe(model) { - isModelMatch = true - break - } - } - if isModelMatch { - providers = append(providers, provider) - } - } - - // Handler special provider cases - // 1. Handler openrouter models - if !slices.Contains(providers, schemas.OpenRouter) { - for _, provider := range providers { - if openRouterModels, ok := mc.modelPool[schemas.OpenRouter]; ok { - if slices.Contains(openRouterModels, string(provider)+"/"+model) { - providers = append(providers, schemas.OpenRouter) - } - } - } - } - - // 2. Handle vertex models - if !slices.Contains(providers, schemas.Vertex) { - for _, provider := range providers { - if vertexModels, ok := mc.modelPool[schemas.Vertex]; ok { - if slices.Contains(vertexModels, string(provider)+"/"+model) { - providers = append(providers, schemas.Vertex) - } - } - } - } - - // 3. Handle openai models for groq - if !slices.Contains(providers, schemas.Groq) && strings.Contains(model, "gpt-") { - if groqModels, ok := mc.modelPool[schemas.Groq]; ok { - if slices.Contains(groqModels, "openai/"+model) { - providers = append(providers, schemas.Groq) - } - } - } - - // 4. 
Handle anthropic models for bedrock - if !slices.Contains(providers, schemas.Bedrock) && strings.Contains(model, "claude") { - if bedrockModels, ok := mc.modelPool[schemas.Bedrock]; ok { - for _, bedrockModel := range bedrockModels { - if strings.Contains(bedrockModel, model) { - providers = append(providers, schemas.Bedrock) - break - } - } - } - } - - return providers -} - -// IsModelAllowedForProvider checks if a model is allowed for a specific provider -// based on the allowed models list and catalog data. It handles all cross-provider -// logic including provider-prefixed models and special routing rules. -// -// Parameters: -// - provider: The provider to check against -// - model: The model name (without provider prefix, e.g., "gpt-4o" or "claude-3-5-sonnet") -// - allowedModels: List of allowed model names (can be empty, can include provider prefixes) -// -// Behavior: -// - If allowedModels is ["*"]: Uses model catalog to check if provider supports the model -// (delegates to GetProvidersForModel which handles all cross-provider logic) -// - If allowedModels is empty ([]): Deny-by-default — returns false for any provider/model pair -// - If allowedModels is not empty: Checks if model matches any entry in the list -// Provider-specific validation: -// - Direct matches: "gpt-4o" in allowedModels for any provider -// - Prefixed matches: Only if the prefixed model exists in provider's catalog -// (e.g., "openai/gpt-4o" in allowedModels only matches if openrouter's catalog -// contains "openai/gpt-4o" AND the model part matches the request) -// -// Returns: -// - bool: true if the model is allowed for the provider, false otherwise -// -// Examples: -// -// // Wildcard allowedModels - uses catalog to check provider support -// mc.IsModelAllowedForProvider("openrouter", "claude-3-5-sonnet", []string{"*"}) -// // Returns: true (catalog knows openrouter has "anthropic/claude-3-5-sonnet") -// -// // Empty allowedModels - deny all (deny-by-default) -// 
mc.IsModelAllowedForProvider("openrouter", "claude-3-5-sonnet", []string{}) -// // Returns: false (no models are permitted) -// -// // Explicit allowedModels with prefix - validates against catalog -// mc.IsModelAllowedForProvider("openrouter", "gpt-4o", []string{"openai/gpt-4o"}) -// // Returns: true (openrouter's catalog contains "openai/gpt-4o" AND model part is "gpt-4o") -// -// // Explicit allowedModels with prefix - wrong model -// mc.IsModelAllowedForProvider("openrouter", "claude-3-5-sonnet", []string{"openai/gpt-4o"}) -// // Returns: false (model part "gpt-4o" doesn't match request "claude-3-5-sonnet") -// -// // Explicit allowedModels without prefix -// mc.IsModelAllowedForProvider("openai", "gpt-4o", []string{"gpt-4o"}) -// // Returns: true (direct match) -func (mc *ModelCatalog) IsModelAllowedForProvider(provider schemas.ModelProvider, model string, allowedModels schemas.WhiteList) bool { - // Case 1: ["*"] = allow all models; use catalog to determine support - // Empty allowedModels = deny all (fail-safe deny-by-default) - if allowedModels.IsUnrestricted() { - supportedProviders := mc.GetProvidersForModel(model) - return slices.Contains(supportedProviders, provider) - } - if allowedModels.IsEmpty() { - return false - } - - // Case 2: Explicit allowedModels = check if model matches any entry - // Get provider's catalog models for validation of prefixed entries - providerCatalogModels := mc.GetModelsForProvider(provider) - - for _, allowedModel := range allowedModels { - // Direct match: "gpt-4o" == "gpt-4o" - if allowedModel == model { - return true - } - - // Provider-prefixed match: verify it exists in provider's catalog first - // This ensures we only allow provider-specific model combinations that are actually supported - if strings.Contains(allowedModel, "/") { - // Check if this exact prefixed model exists in the provider's catalog - // e.g., for openrouter, check if "openai/gpt-4o" is in its catalog - if slices.Contains(providerCatalogModels, 
allowedModel) { - // Extract the model part and compare with request - _, modelPart := schemas.ParseModelString(allowedModel, "") - if modelPart == model { - return true - } - } - } - } - - return false -} - -// GetBaseModelName returns the canonical base model name for a given model string. -// It uses the pre-computed base_model from the pricing catalog when available, -// falling back to algorithmic date/version stripping for models not in the catalog. -// -// Examples: -// -// mc.GetBaseModelName("gpt-4o") // Returns: "gpt-4o" -// mc.GetBaseModelName("openai/gpt-4o") // Returns: "gpt-4o" -// mc.GetBaseModelName("gpt-4o-2024-08-06") // Returns: "gpt-4o" (algorithmic fallback) -func (mc *ModelCatalog) GetBaseModelName(model string) string { - mc.mu.RLock() - defer mc.mu.RUnlock() - return mc.getBaseModelNameUnsafe(model) -} - -// getBaseModelNameUnsafe returns the canonical base model name for a given model string without locking. -// This is used to avoid locking overhead when getting the base model name for many models. -// Make sure the caller function is holding the read lock before calling this function. -// It is not safe to use this function when the model pool is being updated. -func (mc *ModelCatalog) getBaseModelNameUnsafe(model string) string { - // Step 1: Direct lookup in base model index - if base, ok := mc.baseModelIndex[model]; ok { - return base - } - - // Step 2: Strip provider prefix and try again - _, baseName := schemas.ParseModelString(model, "") - if baseName != model { - if base, ok := mc.baseModelIndex[baseName]; ok { - return base - } - } - - // Step 3: Fallback to algorithmic date/version stripping - // (for models not in the catalog, e.g., user-configured custom models) - return schemas.BaseModelName(baseName) -} - -// IsSameModel checks if two model strings refer to the same underlying model. -// It compares the canonical base model names derived from the pricing catalog -// (or algorithmic fallback for models not in the catalog). 
-// -// Examples: -// -// mc.IsSameModel("gpt-4o", "gpt-4o") // true (direct match) -// mc.IsSameModel("openai/gpt-4o", "gpt-4o") // true (same base model) -// mc.IsSameModel("gpt-4o", "claude-3-5-sonnet") // false (different models) -// mc.IsSameModel("openai/gpt-4o", "anthropic/claude-3-5-sonnet") // false -func (mc *ModelCatalog) IsSameModel(model1, model2 string) bool { - if model1 == model2 { - return true - } - return mc.GetBaseModelName(model1) == mc.GetBaseModelName(model2) -} - -// DeleteModelDataForProvider deletes all model data from the pool for a given provider -func (mc *ModelCatalog) DeleteModelDataForProvider(provider schemas.ModelProvider) { - mc.mu.Lock() - defer mc.mu.Unlock() - - delete(mc.modelPool, provider) - delete(mc.unfilteredModelPool, provider) -} - -// UpsertModelDataForProvider upserts model data for a given provider -func (mc *ModelCatalog) UpsertModelDataForProvider(provider schemas.ModelProvider, modelData *schemas.BifrostListModelsResponse, allowedModels []schemas.Model) { - if modelData == nil { - return - } - mc.mu.Lock() - defer mc.mu.Unlock() - - // Populating models from pricing data for the given provider - // Provider models map - providerModels := []string{} - // Iterate through all pricing data to collect models per provider - for _, pricing := range mc.pricingData { - // Normalize provider before adding to model pool - normalizedProvider := schemas.ModelProvider(normalizeProvider(pricing.Provider)) - // We will only add models for the given provider - if normalizedProvider != provider { - continue - } - // Add model to the provider's model set (using map for deduplication) - if slices.Contains(providerModels, pricing.Model) { - continue - } - providerModels = append(providerModels, pricing.Model) - // Build base model index from pre-computed base_model field - if pricing.BaseModel != "" { - mc.baseModelIndex[pricing.Model] = pricing.BaseModel - } - } - // If modelData is empty, then we allow all models - if 
len(modelData.Data) == 0 && len(allowedModels) == 0 { - mc.modelPool[provider] = providerModels - return - } - // Here we make sure that we still keep the backup for model catalog intact - // So we start with a existing model pool and add the new models from incoming data - finalModelList := make([]string, 0) - seenModels := make(map[string]bool) - // Case where list models failed but we have allowed models from keys - if len(modelData.Data) == 0 && len(allowedModels) > 0 { - for _, allowedModel := range allowedModels { - parsedProvider, parsedModel := schemas.ParseModelString(allowedModel.ID, "") - if parsedProvider != provider { - continue - } - if !seenModels[parsedModel] { - seenModels[parsedModel] = true - finalModelList = append(finalModelList, parsedModel) - } - } - } - for _, model := range modelData.Data { - parsedProvider, parsedModel := schemas.ParseModelString(model.ID, "") - if parsedProvider != provider { - continue - } - if !seenModels[parsedModel] { - seenModels[parsedModel] = true - finalModelList = append(finalModelList, parsedModel) - } - } - - if len(allowedModels) == 0 { - for _, model := range providerModels { - if !seenModels[model] { - seenModels[model] = true - finalModelList = append(finalModelList, model) - } - } - } - mc.modelPool[provider] = finalModelList -} - -// UpsertUnfilteredModelDataForProvider upserts unfiltered model data for a given provider -func (mc *ModelCatalog) UpsertUnfilteredModelDataForProvider(provider schemas.ModelProvider, modelData *schemas.BifrostListModelsResponse) { - if modelData == nil { - return - } - mc.mu.Lock() - defer mc.mu.Unlock() - - // Populating models from pricing data for the given provider - providerModels := []string{} - seenModels := make(map[string]bool) - for _, pricing := range mc.pricingData { - normalizedProvider := schemas.ModelProvider(normalizeProvider(pricing.Provider)) - if normalizedProvider != provider { - continue - } - if !seenModels[pricing.Model] { - seenModels[pricing.Model] = 
true - providerModels = append(providerModels, pricing.Model) - } - } - for _, model := range modelData.Data { - parsedProvider, parsedModel := schemas.ParseModelString(model.ID, "") - if parsedProvider != provider { - continue - } - if !seenModels[parsedModel] { - seenModels[parsedModel] = true - providerModels = append(providerModels, parsedModel) - } - } - mc.unfilteredModelPool[provider] = providerModels -} - -// RefineModelForProvider refines the model for a given provider by performing a lookup -// in mc.modelPool and using schemas.ParseModelString to extract provider and model parts. -// e.g. "gpt-oss-120b" for groq provider -> "openai/gpt-oss-120b" -// -// Behavior: -// - When the provider's catalog (mc.modelPool) yields multiple matching models, returns an error -// - When exactly one match is found, returns the fully-qualified model (provider/model format) -// - When the provider is not handled or no refinement is needed, returns the original model unchanged -func (mc *ModelCatalog) RefineModelForProvider(provider schemas.ModelProvider, model string) (string, error) { - switch provider { - case schemas.Groq: - if strings.Contains(model, "gpt-") { - return "openai/" + model, nil - } - return mc.refineNestedProviderModel(provider, model) - case schemas.Replicate: - return mc.refineNestedProviderModel(provider, model) - } - return model, nil -} - -// refineNestedProviderModel resolves provider-native model slugs such as -// "openai/gpt-5-nano" from a base model request like "gpt-5-nano". -// It only considers catalog entries whose leading segment is a known Bifrost provider, -// so Replicate owner/model identifiers like "meta/llama-3-8b" are left untouched. 
-func (mc *ModelCatalog) refineNestedProviderModel(provider schemas.ModelProvider, model string) (string, error) { - mc.mu.RLock() - models, ok := mc.modelPool[provider] - mc.mu.RUnlock() - if !ok { - return model, nil - } - - candidateModels := make([]string, 0) - seenCandidates := make(map[string]struct{}) - for _, poolModel := range models { - providerPart, modelPart := schemas.ParseModelString(poolModel, "") - if providerPart == "" || model != modelPart { - continue - } - - candidate := string(providerPart) + "/" + modelPart - if _, seen := seenCandidates[candidate]; seen { - continue - } - seenCandidates[candidate] = struct{}{} - candidateModels = append(candidateModels, candidate) - } - - switch len(candidateModels) { - case 0: - return model, nil - case 1: - return candidateModels[0], nil - default: - return "", fmt.Errorf("multiple compatible models found for model %s: %v", model, candidateModels) - } -} - -// SetPricingOverrides replaces the full in-memory pricing override set. -func (mc *ModelCatalog) SetPricingOverrides(rows []configstoreTables.TablePricingOverride) error { - seen := make(map[string]int, len(rows)) - overrides := make([]PricingOverride, 0, len(rows)) - for i := range rows { - o, err := convertTablePricingOverrideToPricingOverride(&rows[i]) - if err != nil { - return err - } - if idx, exists := seen[o.ID]; exists { - overrides[idx] = o // last entry wins for duplicate IDs - } else { - seen[o.ID] = len(overrides) - overrides = append(overrides, o) - } - } - mc.overridesMu.Lock() - mc.rawOverrides = overrides - mc.customPricing = buildCustomPricingData(overrides) - mc.overridesMu.Unlock() - return nil -} - -// UpsertPricingOverrides inserts or replaces one or more pricing overrides in a single -// operation, rebuilding the lookup map only once at the end. 
-func (mc *ModelCatalog) UpsertPricingOverrides(rows ...*configstoreTables.TablePricingOverride) error { - // Deduplicate the input batch by ID (last entry wins) and build the - // incoming set for O(1) lookup when filtering existing rawOverrides. - seenIncoming := make(map[string]int, len(rows)) - overrides := make([]PricingOverride, 0, len(rows)) - for _, row := range rows { - o, err := convertTablePricingOverrideToPricingOverride(row) - if err != nil { - return err - } - if idx, exists := seenIncoming[o.ID]; exists { - overrides[idx] = o // last entry wins for duplicate IDs - } else { - seenIncoming[o.ID] = len(overrides) - overrides = append(overrides, o) - } - } - - mc.overridesMu.Lock() - defer mc.overridesMu.Unlock() - - updated := make([]PricingOverride, 0, len(mc.rawOverrides)+len(overrides)) - for _, o := range mc.rawOverrides { - if _, replacing := seenIncoming[o.ID]; !replacing { - updated = append(updated, o) - } - } - updated = append(updated, overrides...) - mc.rawOverrides = updated - mc.customPricing = buildCustomPricingData(updated) - return nil -} - -// DeletePricingOverride removes a pricing override by ID. -func (mc *ModelCatalog) DeletePricingOverride(id string) { - mc.overridesMu.Lock() - defer mc.overridesMu.Unlock() - - updated := make([]PricingOverride, 0, len(mc.rawOverrides)) - for _, o := range mc.rawOverrides { - if o.ID != id { - updated = append(updated, o) - } - } - mc.rawOverrides = updated - mc.customPricing = buildCustomPricingData(updated) -} - // IsRequestTypeSupported checks if a model supports chat completion. // It checks the supportedResponseTypes index. 
func (mc *ModelCatalog) IsRequestTypeSupported(model string, provider schemas.ModelProvider, requestType schemas.RequestType) bool { @@ -1069,4 +431,4 @@ func NewTestCatalog(baseModelIndex map[string]string) *ModelCatalog { supportedParams: make(map[string][]string), done: make(chan struct{}), } -} +} \ No newline at end of file diff --git a/helm-charts/bifrost/templates/_helpers.tpl b/helm-charts/bifrost/templates/_helpers.tpl index 8dc0606658..97e7b7e4f4 100644 --- a/helm-charts/bifrost/templates/_helpers.tpl +++ b/helm-charts/bifrost/templates/_helpers.tpl @@ -227,8 +227,21 @@ false {{- if .Values.bifrost.client.maxRequestBodySizeMb }} {{- $_ := set $client "max_request_body_size_mb" .Values.bifrost.client.maxRequestBodySizeMb }} {{- end }} -{{- if hasKey .Values.bifrost.client "enableLitellmFallbacks" }} -{{- $_ := set $client "enable_litellm_fallbacks" .Values.bifrost.client.enableLitellmFallbacks }} +{{- if .Values.bifrost.client.compat }} +{{- $compat := dict }} +{{- if hasKey .Values.bifrost.client.compat "convertTextToChat" }} +{{- $_ := set $compat "convert_text_to_chat" .Values.bifrost.client.compat.convertTextToChat }} +{{- end }} +{{- if hasKey .Values.bifrost.client.compat "convertChatToResponses" }} +{{- $_ := set $compat "convert_chat_to_responses" .Values.bifrost.client.compat.convertChatToResponses }} +{{- end }} +{{- if hasKey .Values.bifrost.client.compat "shouldDropParams" }} +{{- $_ := set $compat "should_drop_params" .Values.bifrost.client.compat.shouldDropParams }} +{{- end }} +{{- if hasKey .Values.bifrost.client.compat "shouldConvertParams" }} +{{- $_ := set $compat "should_convert_params" .Values.bifrost.client.compat.shouldConvertParams }} +{{- end }} +{{- $_ := set $client "compat" $compat }} {{- end }} {{- if .Values.bifrost.client.prometheusLabels }} {{- $_ := set $client "prometheus_labels" .Values.bifrost.client.prometheusLabels }} diff --git a/helm-charts/bifrost/values.schema.json b/helm-charts/bifrost/values.schema.json index 
495e9c8e79..dac746716a 100644 --- a/helm-charts/bifrost/values.schema.json +++ b/helm-charts/bifrost/values.schema.json @@ -293,8 +293,15 @@ "type": "integer", "minimum": 1 }, - "enableLitellmFallbacks": { - "type": "boolean" + "compat": { + "type": "object", + "additionalProperties": false, + "properties": { + "convertTextToChat": { "type": "boolean" }, + "convertChatToResponses": { "type": "boolean" }, + "shouldDropParams": { "type": "boolean" }, + "shouldConvertParams": { "type": "boolean", "default": true } + } }, "prometheusLabels": { "type": "array", diff --git a/helm-charts/bifrost/values.yaml b/helm-charts/bifrost/values.yaml index 8509859f01..710aeea732 100644 --- a/helm-charts/bifrost/values.yaml +++ b/helm-charts/bifrost/values.yaml @@ -188,7 +188,11 @@ bifrost: enforceGovernanceHeader: false allowDirectKeys: false maxRequestBodySizeMb: 100 - enableLitellmFallbacks: false + compat: + convertTextToChat: false + convertChatToResponses: false + shouldDropParams: false + shouldConvertParams: true prometheusLabels: [] # Header filtering configuration for x-bf-eh-* headers forwarded to LLM providers headerFilterConfig: diff --git a/plugins/compat/main.go b/plugins/compat/main.go index 0c64b7b6ca..1e9450713b 100644 --- a/plugins/compat/main.go +++ b/plugins/compat/main.go @@ -13,7 +13,15 @@ const PluginName = "compat" // Config defines the configuration for the compat plugin. type Config struct { - Enabled bool `json:"enabled"` + ConvertTextToChat bool `json:"convert_text_to_chat"` + ConvertChatToResponses bool `json:"convert_chat_to_responses"` + ShouldDropParams bool `json:"should_drop_params"` + ShouldConvertParams bool `json:"should_convert_params"` +} + +// IsEnabled returns true if any compat feature is enabled +func (c Config) IsEnabled() bool { + return c.ConvertTextToChat || c.ConvertChatToResponses || c.ShouldDropParams || c.ShouldConvertParams } // CompatPlugin provides LiteLLM-compatible request/response transformations. 
@@ -67,20 +75,26 @@ func (p *CompatPlugin) PreLLMHook(ctx *schemas.BifrostContext, req *schemas.Bifr return req, nil, nil } - // text completion → chat conversion - if (req.RequestType == schemas.TextCompletionRequest || req.RequestType == schemas.TextCompletionStreamRequest) && req.TextCompletionRequest != nil { - p.markForConversion(ctx, req.TextCompletionRequest.Provider, req.TextCompletionRequest.Model, schemas.TextCompletionRequest, schemas.ChatCompletionRequest) + modifiedReq := cloneBifrostReq(req) + p.droppedParams = nil + + // Text completion → chat conversion + if p.config.ConvertTextToChat { + if (modifiedReq.RequestType == schemas.TextCompletionRequest || modifiedReq.RequestType == schemas.TextCompletionStreamRequest) && modifiedReq.TextCompletionRequest != nil { + p.markForConversion(ctx, modifiedReq.TextCompletionRequest.Provider, modifiedReq.TextCompletionRequest.Model, schemas.TextCompletionRequest, schemas.ChatCompletionRequest) + } } - // chat completion → responses conversion - if (req.RequestType == schemas.ChatCompletionRequest || req.RequestType == schemas.ChatCompletionStreamRequest) && req.ChatRequest != nil { - p.markForConversion(ctx, req.ChatRequest.Provider, req.ChatRequest.Model, schemas.ChatCompletionRequest, schemas.ResponsesRequest) + // Chat completion → responses conversion + if p.config.ConvertChatToResponses { + if (modifiedReq.RequestType == schemas.ChatCompletionRequest || modifiedReq.RequestType == schemas.ChatCompletionStreamRequest) && modifiedReq.ChatRequest != nil { + p.markForConversion(ctx, modifiedReq.ChatRequest.Provider, modifiedReq.ChatRequest.Model, schemas.ChatCompletionRequest, schemas.ResponsesRequest) + } } - modifiedReq := cloneBifrostReq(req) - p.droppedParams = nil - if p.modelCatalog != nil { - _, model, _ := req.GetRequestFields() + // Compute unsupported parameters to drop based on model catalog allowlist + if p.config.ShouldDropParams && p.modelCatalog != nil { + _, model, _ := 
modifiedReq.GetRequestFields() if model != "" { if supportedParams := p.modelCatalog.GetSupportedParameters(model); supportedParams != nil { droppedParams := dropUnsupportedParams(modifiedReq, supportedParams) @@ -91,7 +105,9 @@ func (p *CompatPlugin) PreLLMHook(ctx *schemas.BifrostContext, req *schemas.Bifr } } - applyParameterConversion(modifiedReq) + if p.config.ShouldConvertParams { + applyParameterConversion(modifiedReq) + } return modifiedReq, nil, nil } @@ -143,4 +159,4 @@ func (p *CompatPlugin) markForConversion(ctx *schemas.BifrostContext, provider s ctx.SetValue(schemas.BifrostContextKeyChangeRequestType, targetType) p.logger.Debug("compat: marked %v for core conversion to %v for model %s", currentType, targetType, model) } -} \ No newline at end of file +} diff --git a/tests/governance/config.json b/tests/governance/config.json index bd9080a064..b8cedf9a3e 100644 --- a/tests/governance/config.json +++ b/tests/governance/config.json @@ -62,7 +62,6 @@ "enable_logging": true, "enforce_auth_on_inference": true, "allow_direct_keys": false, - "max_request_body_size_mb": 100, - "enable_litellm_fallbacks": false + "max_request_body_size_mb": 100 } } diff --git a/tests/integrations/python/config.json b/tests/integrations/python/config.json index 00b89b5bdb..866469cc1d 100644 --- a/tests/integrations/python/config.json +++ b/tests/integrations/python/config.json @@ -343,7 +343,6 @@ "enable_logging": true, "enforce_auth_on_inference": false, "allow_direct_keys": false, - "max_request_body_size_mb": 100, - "enable_litellm_fallbacks": false + "max_request_body_size_mb": 100 } } diff --git a/tests/integrations/typescript/config.json b/tests/integrations/typescript/config.json index cf49dba281..46bc65af6b 100644 --- a/tests/integrations/typescript/config.json +++ b/tests/integrations/typescript/config.json @@ -220,7 +220,6 @@ "enable_logging": true, "enforce_auth_on_inference": false, "allow_direct_keys": false, - "max_request_body_size_mb": 100, - 
"enable_litellm_fallbacks": false + "max_request_body_size_mb": 100 } } diff --git a/transports/bifrost-http/handlers/config.go b/transports/bifrost-http/handlers/config.go index ec28c9d3ad..dee74227c7 100644 --- a/transports/bifrost-http/handlers/config.go +++ b/transports/bifrost-http/handlers/config.go @@ -343,21 +343,32 @@ func (h *ConfigHandler) updateConfig(ctx *fasthttp.RequestCtx) { } // Handle compat plugin toggle - if payload.ClientConfig.EnableLiteLLMFallbacks != currentConfig.EnableLiteLLMFallbacks { - if payload.ClientConfig.EnableLiteLLMFallbacks { - // Load and register the compat plugin - if err := h.configManager.ReloadPlugin(ctx, compat.PluginName, nil, &compat.Config{Enabled: true}, nil, nil); err != nil { + newCompat := payload.ClientConfig.Compat + oldCompat := currentConfig.Compat + if newCompat != oldCompat { + newEnabled := newCompat.ConvertTextToChat || newCompat.ConvertChatToResponses || newCompat.ShouldDropParams || newCompat.ShouldConvertParams + if newEnabled { + compatCfg := &compat.Config{ + ConvertTextToChat: newCompat.ConvertTextToChat, + ConvertChatToResponses: newCompat.ConvertChatToResponses, + ShouldDropParams: newCompat.ShouldDropParams, + ShouldConvertParams: newCompat.ShouldConvertParams, + } + if err := h.configManager.ReloadPlugin(ctx, compat.PluginName, nil, compatCfg, nil, nil); err != nil { logger.Warn("failed to load compat plugin: %v", err) + SendError(ctx, 400, "Failed to load compat plugin") + return } } else { - // Remove the compat plugin disabledCtx := context.WithValue(ctx, PluginDisabledKey, true) if err := h.configManager.RemovePlugin(disabledCtx, compat.PluginName); err != nil { logger.Warn("failed to remove compat plugin: %v", err) + SendError(ctx, 400, "Failed to remove compat plugin") + return } } } - updatedConfig.EnableLiteLLMFallbacks = payload.ClientConfig.EnableLiteLLMFallbacks + updatedConfig.Compat = newCompat // Only update MCP fields if explicitly provided (non-zero) to avoid clearing stored values 
if payload.ClientConfig.MCPAgentDepth > 0 { updatedConfig.MCPAgentDepth = payload.ClientConfig.MCPAgentDepth diff --git a/transports/bifrost-http/lib/config.go b/transports/bifrost-http/lib/config.go index c320086bb8..bc7ddd2641 100644 --- a/transports/bifrost-http/lib/config.go +++ b/transports/bifrost-http/lib/config.go @@ -36,8 +36,8 @@ import ( "github.com/maximhq/bifrost/framework/oauth2" plugins "github.com/maximhq/bifrost/framework/plugins" "github.com/maximhq/bifrost/framework/vectorstore" - "github.com/maximhq/bifrost/plugins/governance" "github.com/maximhq/bifrost/plugins/compat" + "github.com/maximhq/bifrost/plugins/governance" "github.com/maximhq/bifrost/plugins/logging" "github.com/maximhq/bifrost/plugins/maxim" "github.com/maximhq/bifrost/plugins/otel" @@ -309,7 +309,7 @@ var DefaultClientConfig = configstore.ClientConfig{ MCPAgentDepth: 10, MCPToolExecutionTimeout: 30, MCPCodeModeBindingLevel: string(schemas.CodeModeBindingLevelServer), - EnableLiteLLMFallbacks: false, + Compat: configstore.CompatConfig{ShouldConvertParams: true}, HideDeletedVirtualKeysInFilters: false, RoutingChainMaxDepth: governance.DefaultRoutingChainMaxDepth, } diff --git a/transports/bifrost-http/lib/config_test.go b/transports/bifrost-http/lib/config_test.go index ce1c2432a8..fed8d075d6 100644 --- a/transports/bifrost-http/lib/config_test.go +++ b/transports/bifrost-http/lib/config_test.go @@ -400,6 +400,7 @@ func (m *MockConfigStore) DB() *gorm.DB { retu func (m *MockConfigStore) ExecuteTransaction(ctx context.Context, fn func(tx *gorm.DB) error) error { return fn(nil) } + func (m *MockConfigStore) RunMigration(ctx context.Context, migration *migrator.Migration) error { return nil } @@ -1130,18 +1131,23 @@ func (m *MockConfigStore) DeleteOauthToken(ctx context.Context, id string) error func (m *MockConfigStore) GetOauthUserSessionByID(ctx context.Context, id string) (*tables.TableOauthUserSession, error) { return nil, nil } + func (m *MockConfigStore) 
GetOauthUserSessionByState(ctx context.Context, state string) (*tables.TableOauthUserSession, error) { return nil, nil } + func (m *MockConfigStore) ClaimOauthUserSessionByState(ctx context.Context, state string) (*tables.TableOauthUserSession, error) { return nil, nil } + func (m *MockConfigStore) GetOauthUserSessionBySessionToken(ctx context.Context, sessionToken string) (*tables.TableOauthUserSession, error) { return nil, nil } + func (m *MockConfigStore) CreateOauthUserSession(ctx context.Context, session *tables.TableOauthUserSession) error { return nil } + func (m *MockConfigStore) UpdateOauthUserSession(ctx context.Context, session *tables.TableOauthUserSession) error { return nil } @@ -1150,18 +1156,23 @@ func (m *MockConfigStore) UpdateOauthUserSession(ctx context.Context, session *t func (m *MockConfigStore) GetOauthUserTokenByIdentity(ctx context.Context, virtualKeyID, userID, sessionToken, mcpClientID string) (*tables.TableOauthUserToken, error) { return nil, nil } + func (m *MockConfigStore) GetOauthUserTokenBySessionToken(ctx context.Context, sessionToken string) (*tables.TableOauthUserToken, error) { return nil, nil } + func (m *MockConfigStore) CreateOauthUserToken(ctx context.Context, token *tables.TableOauthUserToken) error { return nil } + func (m *MockConfigStore) UpdateOauthUserToken(ctx context.Context, token *tables.TableOauthUserToken) error { return nil } + func (m *MockConfigStore) DeleteOauthUserToken(ctx context.Context, id string) error { return nil } + func (m *MockConfigStore) DeleteOauthUserTokensByMCPClient(ctx context.Context, mcpClientID string) error { return nil } @@ -1170,33 +1181,43 @@ func (m *MockConfigStore) DeleteOauthUserTokensByMCPClient(ctx context.Context, func (m *MockConfigStore) GetPerUserOAuthClientByClientID(ctx context.Context, clientID string) (*tables.TablePerUserOAuthClient, error) { return nil, nil } + func (m *MockConfigStore) CreatePerUserOAuthClient(ctx context.Context, client 
*tables.TablePerUserOAuthClient) error { return nil } + func (m *MockConfigStore) GetPerUserOAuthSessionByAccessToken(ctx context.Context, accessToken string) (*tables.TablePerUserOAuthSession, error) { return nil, nil } + func (m *MockConfigStore) GetPerUserOAuthSessionByID(ctx context.Context, id string) (*tables.TablePerUserOAuthSession, error) { return nil, nil } + func (m *MockConfigStore) CreatePerUserOAuthSession(ctx context.Context, session *tables.TablePerUserOAuthSession) error { return nil } + func (m *MockConfigStore) UpdatePerUserOAuthSession(ctx context.Context, session *tables.TablePerUserOAuthSession) error { return nil } + func (m *MockConfigStore) DeletePerUserOAuthSession(ctx context.Context, id string) error { return nil } + func (m *MockConfigStore) GetPerUserOAuthCodeByCode(ctx context.Context, code string) (*tables.TablePerUserOAuthCode, error) { return nil, nil } + func (m *MockConfigStore) ClaimPerUserOAuthCode(ctx context.Context, code string) (*tables.TablePerUserOAuthCode, error) { return nil, nil } + func (m *MockConfigStore) CreatePerUserOAuthCode(ctx context.Context, code *tables.TablePerUserOAuthCode) error { return nil } + func (m *MockConfigStore) UpdatePerUserOAuthCode(ctx context.Context, code *tables.TablePerUserOAuthCode) error { return nil } @@ -1204,24 +1225,31 @@ func (m *MockConfigStore) UpdatePerUserOAuthCode(ctx context.Context, code *tabl func (m *MockConfigStore) GetPerUserOAuthPendingFlow(ctx context.Context, id string) (*tables.TablePerUserOAuthPendingFlow, error) { return nil, nil } + func (m *MockConfigStore) CreatePerUserOAuthPendingFlow(ctx context.Context, flow *tables.TablePerUserOAuthPendingFlow) error { return nil } + func (m *MockConfigStore) UpdatePerUserOAuthPendingFlow(ctx context.Context, flow *tables.TablePerUserOAuthPendingFlow) error { return nil } + func (m *MockConfigStore) DeletePerUserOAuthPendingFlow(ctx context.Context, id string) error { return nil } + func (m *MockConfigStore) 
ConsumePerUserOAuthPendingFlow(ctx context.Context, id string) (int64, error) { return 1, nil } + func (m *MockConfigStore) GetOauthUserTokensByGatewaySessionID(ctx context.Context, gatewaySessionID string) ([]tables.TableOauthUserToken, error) { return nil, nil } + func (m *MockConfigStore) TransferOauthUserTokensFromGatewaySession(ctx context.Context, gatewaySessionID, realSessionToken, virtualKeyID, userID string) error { return nil } + func (m *MockConfigStore) FinalizePerUserOAuthConsent(ctx context.Context, flowID string, session *tables.TablePerUserOAuthSession, code *tables.TablePerUserOAuthCode) (int64, error) { return 1, nil } @@ -1263,12 +1291,15 @@ func (m *MockConfigStore) DeleteRoutingRule(ctx context.Context, id string, tx . func (m *MockConfigStore) GetFolders(ctx context.Context) ([]tables.TableFolder, error) { return nil, nil } + func (m *MockConfigStore) GetFolderByID(ctx context.Context, id string) (*tables.TableFolder, error) { return nil, nil } + func (m *MockConfigStore) CreateFolder(ctx context.Context, folder *tables.TableFolder) error { return nil } + func (m *MockConfigStore) UpdateFolder(ctx context.Context, folder *tables.TableFolder) error { return nil } @@ -1278,12 +1309,15 @@ func (m *MockConfigStore) DeleteFolder(ctx context.Context, id string) error { r func (m *MockConfigStore) GetPrompts(ctx context.Context, folderID *string) ([]tables.TablePrompt, error) { return nil, nil } + func (m *MockConfigStore) GetPromptByID(ctx context.Context, id string) (*tables.TablePrompt, error) { return nil, nil } + func (m *MockConfigStore) CreatePrompt(ctx context.Context, prompt *tables.TablePrompt) error { return nil } + func (m *MockConfigStore) UpdatePrompt(ctx context.Context, prompt *tables.TablePrompt) error { return nil } @@ -1293,15 +1327,19 @@ func (m *MockConfigStore) DeletePrompt(ctx context.Context, id string) error { r func (m *MockConfigStore) GetPromptVersions(ctx context.Context, promptID string) ([]tables.TablePromptVersion, 
error) { return nil, nil } + func (m *MockConfigStore) GetAllPromptVersions(ctx context.Context) ([]tables.TablePromptVersion, error) { return nil, nil } + func (m *MockConfigStore) GetPromptVersionByID(ctx context.Context, id uint) (*tables.TablePromptVersion, error) { return nil, nil } + func (m *MockConfigStore) GetLatestPromptVersion(ctx context.Context, promptID string) (*tables.TablePromptVersion, error) { return nil, nil } + func (m *MockConfigStore) CreatePromptVersion(ctx context.Context, version *tables.TablePromptVersion) error { return nil } @@ -1311,15 +1349,19 @@ func (m *MockConfigStore) DeletePromptVersion(ctx context.Context, id uint) erro func (m *MockConfigStore) GetPromptSessions(ctx context.Context, promptID string) ([]tables.TablePromptSession, error) { return nil, nil } + func (m *MockConfigStore) GetPromptSessionByID(ctx context.Context, id uint) (*tables.TablePromptSession, error) { return nil, nil } + func (m *MockConfigStore) CreatePromptSession(ctx context.Context, session *tables.TablePromptSession) error { return nil } + func (m *MockConfigStore) UpdatePromptSession(ctx context.Context, session *tables.TablePromptSession) error { return nil } + func (m *MockConfigStore) RenamePromptSession(ctx context.Context, id uint, name string) error { return nil } @@ -11981,6 +12023,7 @@ type mockLLMPlugin struct { func (p *mockLLMPlugin) PreLLMHook(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.LLMPluginShortCircuit, error) { return req, nil, nil } + func (p *mockLLMPlugin) PostLLMHook(ctx *schemas.BifrostContext, resp *schemas.BifrostResponse, bifrostErr *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError, error) { return resp, bifrostErr, nil } @@ -12324,7 +12367,6 @@ func TestGenerateClientConfigHash(t *testing.T) { AllowDirectKeys: true, AllowedOrigins: []string{"http://localhost:3000"}, MaxRequestBodySizeMB: 100, - EnableLiteLLMFallbacks: false, } hash1, err := 
cc1.GenerateClientConfigHash() @@ -12421,12 +12463,12 @@ func TestGenerateClientConfigHash(t *testing.T) { t.Error("Different MaxRequestBodySizeMB should produce different hash") } - // Different EnableLiteLLMFallbacks should produce different hash + // Different Compat should produce different hash cc13 := cc1 - cc13.EnableLiteLLMFallbacks = true + cc13.Compat.ConvertTextToChat = true hash13, _ := cc13.GenerateClientConfigHash() if hash1 == hash13 { - t.Error("Different EnableLiteLLMFallbacks should produce different hash") + t.Error("Different Compat.ConvertTextToChat should produce different hash") } // PrometheusLabels order should not matter (sorted) @@ -13459,7 +13501,6 @@ func TestGenerateClientConfigHash_RuntimeVsMigrationParity(t *testing.T) { EnforceAuthOnInference: false, AllowDirectKeys: true, MaxRequestBodySizeMB: 100, - EnableLiteLLMFallbacks: false, } // Generate hash from config @@ -13473,7 +13514,7 @@ func TestGenerateClientConfigHash_RuntimeVsMigrationParity(t *testing.T) { EnforceAuthOnInference: ccToSave.EnforceAuthOnInference, AllowDirectKeys: ccToSave.AllowDirectKeys, MaxRequestBodySizeMB: ccToSave.MaxRequestBodySizeMB, - EnableLiteLLMFallbacks: ccToSave.EnableLiteLLMFallbacks, + Compat: configstore.CompatConfig{ConvertTextToChat: ccToSave.CompatConvertTextToChat, ConvertChatToResponses: ccToSave.CompatConvertChatToResponses, ShouldDropParams: ccToSave.CompatShouldDropParams}, } hashBeforeSave, _ := clientConfig.GenerateClientConfigHash() @@ -13492,7 +13533,7 @@ func TestGenerateClientConfigHash_RuntimeVsMigrationParity(t *testing.T) { EnforceAuthOnInference: ccFromDB.EnforceAuthOnInference, AllowDirectKeys: ccFromDB.AllowDirectKeys, MaxRequestBodySizeMB: ccFromDB.MaxRequestBodySizeMB, - EnableLiteLLMFallbacks: ccFromDB.EnableLiteLLMFallbacks, + Compat: configstore.CompatConfig{ConvertTextToChat: ccFromDB.CompatConvertTextToChat, ConvertChatToResponses: ccFromDB.CompatConvertChatToResponses, ShouldDropParams: ccFromDB.CompatShouldDropParams}, 
} hashAfterLoad, _ := clientConfigFromDB.GenerateClientConfigHash() @@ -15649,13 +15690,13 @@ func TestConfigSchemaSyncTopLevel(t *testing.T) { // Enterprise-only features: These fields exist in the JSON schema for documentation // and validation purposes, but are only available in the enterprise version. enterpriseSchemaFields := map[string]bool{ - "$schema": true, - "audit_logs": true, - "cluster_config": true, - "saml_config": true, - "load_balancer_config": true, - "guardrails_config": true, - "large_payload_optimization": true, + "$schema": true, + "audit_logs": true, + "cluster_config": true, + "saml_config": true, + "load_balancer_config": true, + "guardrails_config": true, + "large_payload_optimization": true, } schema := loadJSONSchema(t) @@ -16602,7 +16643,10 @@ func assertDefaultClientConfigValues(t *testing.T, cc configstore.ClientConfig) require.Equal(t, 100, cc.MaxRequestBodySizeMB, "MaxRequestBodySizeMB should default to 100") require.Equal(t, 10, cc.MCPAgentDepth, "MCPAgentDepth should default to 10") require.Equal(t, 30, cc.MCPToolExecutionTimeout, "MCPToolExecutionTimeout should default to 30") - require.Equal(t, false, cc.EnableLiteLLMFallbacks, "EnableLiteLLMFallbacks should default to false") + require.Equal(t, false, cc.Compat.ConvertTextToChat, "Compat.ConvertTextToChat should default to false") + require.Equal(t, false, cc.Compat.ConvertChatToResponses, "Compat.ConvertChatToResponses should default to false") + require.Equal(t, false, cc.Compat.ShouldDropParams, "Compat.ShouldDropParams should default to false") + require.Equal(t, true, cc.Compat.ShouldConvertParams, "Compat.ShouldConvertParams should default to true") require.Equal(t, false, cc.HideDeletedVirtualKeysInFilters, "HideDeletedVirtualKeysInFilters should default to false") } diff --git a/transports/bifrost-http/server/plugins.go b/transports/bifrost-http/server/plugins.go index 39e07cc201..25f9597e2b 100644 --- a/transports/bifrost-http/server/plugins.go +++ 
b/transports/bifrost-http/server/plugins.go @@ -215,10 +215,16 @@ func (s *BifrostHTTPServer) loadBuiltinPlugins(ctx context.Context) error { } s.Config.SetPluginOrderInfo(semanticcache.PluginName, builtinPlacement, schemas.Ptr(6)) - // 7. Compat (if configured in PluginConfigs) - compatConfig := s.getPluginConfig(compat.PluginName) - if compatConfig != nil && compatConfig.Enabled { - s.registerPluginWithStatus(ctx, compat.PluginName, nil, compatConfig.Config, false) + // 7. Compat (if any compat feature is enabled in ClientConfig) + cc := s.Config.ClientConfig.Compat + if cc.ConvertTextToChat || cc.ConvertChatToResponses || cc.ShouldDropParams || cc.ShouldConvertParams { + compatCfg := &compat.Config{ + ConvertTextToChat: cc.ConvertTextToChat, + ConvertChatToResponses: cc.ConvertChatToResponses, + ShouldDropParams: cc.ShouldDropParams, + ShouldConvertParams: cc.ShouldConvertParams, + } + s.registerPluginWithStatus(ctx, compat.PluginName, nil, compatCfg, false) } else { s.markPluginDisabled(compat.PluginName) } diff --git a/transports/config.schema.json b/transports/config.schema.json index 1e5070d76f..a8994a2f91 100644 --- a/transports/config.schema.json +++ b/transports/config.schema.json @@ -94,9 +94,29 @@ "minimum": 1, "description": "Maximum request body size in MB" }, - "enable_litellm_fallbacks": { - "type": "boolean", - "description": "Enable litellm-specific fallbacks for text completion for Groq" + "compat": { + "type": "object", + "description": "Compat plugin configuration for request type conversion, parameter dropping, and parameter value conversion", + "properties": { + "convert_text_to_chat": { + "type": "boolean", + "description": "Convert text completion requests to chat for models that only support chat" + }, + "convert_chat_to_responses": { + "type": "boolean", + "description": "Convert chat completion requests to responses for models that only support responses" + }, + "should_drop_params": { + "type": "boolean", + "description": "Drop 
unsupported parameters based on model catalog allowlist" + }, + "should_convert_params": { + "type": "boolean", + "description": "Converts model parameter values that are not supported by the model.", + "default": true + } + }, + "additionalProperties": false }, "header_filter_config": { "type": "object", @@ -4129,4 +4149,4 @@ "additionalProperties": false } } -} \ No newline at end of file +} diff --git a/transports/go.mod b/transports/go.mod index e33ac0aeb7..9d63d23e2b 100644 --- a/transports/go.mod +++ b/transports/go.mod @@ -184,6 +184,4 @@ require ( google.golang.org/protobuf v1.36.11 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect gorm.io/driver/postgres v1.6.0 // indirect -) - -replace github.com/maximhq/bifrost/plugins/compat => ../plugins/compat +) diff --git a/ui/app/workspace/config/views/clientSettingsView.tsx b/ui/app/workspace/config/views/clientSettingsView.tsx index 1550b6566c..724b2b067e 100644 --- a/ui/app/workspace/config/views/clientSettingsView.tsx +++ b/ui/app/workspace/config/views/clientSettingsView.tsx @@ -6,13 +6,14 @@ import { Input } from "@/components/ui/input"; import { Switch } from "@/components/ui/switch"; import { Tooltip, TooltipContent, TooltipTrigger } from "@/components/ui/tooltip"; import { getErrorMessage, useGetCoreConfigQuery, useGetDroppedRequestsQuery, useUpdateCoreConfigMutation } from "@/lib/store"; -import { CoreConfig, DefaultCoreConfig, DefaultGlobalHeaderFilterConfig, GlobalHeaderFilterConfig } from "@/lib/types/config"; +import { CompatConfig, CoreConfig, DefaultCoreConfig, DefaultGlobalHeaderFilterConfig,
GlobalHeaderFilterConfig } from "@/lib/types/config"; import { cn } from "@/lib/utils"; import LargePayloadSettingsFragment from "@enterprise/components/large-payload/largePayloadSettingsFragment"; import { RbacOperation, RbacResource, useRbac } from "@enterprise/lib"; import { useGetLargePayloadConfigQuery, useUpdateLargePayloadConfigMutation } from "@enterprise/lib/store/apis/largePayloadApi"; import { DefaultLargePayloadConfig, LargePayloadConfig } from "@enterprise/lib/types/largePayload"; import { Info, Plus, X } from "lucide-react"; +import Link from "next/link"; import { useCallback, useEffect, useMemo, useState } from "react"; import { toast } from "sonner"; @@ -105,9 +106,14 @@ export default function ClientSettingsView() { const hasCoreConfigChanges = useMemo(() => { if (!config) return false; + const hasCompatConfigChanges = + (localConfig.compat?.convert_text_to_chat ?? false) !== (config.compat?.convert_text_to_chat ?? false) || + (localConfig.compat?.convert_chat_to_responses ?? false) !== (config.compat?.convert_chat_to_responses ?? false) || + (localConfig.compat?.should_drop_params ?? false) !== (config.compat?.should_drop_params ?? false) || + (localConfig.compat?.should_convert_params ?? true) !== (config.compat?.should_convert_params ?? 
true); return ( + hasCompatConfigChanges || localConfig.drop_excess_requests !== config.drop_excess_requests || - localConfig.enable_litellm_fallbacks !== config.enable_litellm_fallbacks || localConfig.disable_db_pings_in_health !== config.disable_db_pings_in_health || localConfig.async_job_result_ttl !== config.async_job_result_ttl || !headerFilterConfigEqual(localConfig.header_filter_config, config.header_filter_config) @@ -136,6 +142,10 @@ export default function ClientSettingsView() { setLocalConfig((prev) => ({ ...prev, [field]: value })); }, []); + const handleCompatChange = useCallback((field: keyof CompatConfig, value: boolean) => { + setLocalConfig((prev) => ({ ...prev, compat: { ...prev.compat, [field]: value } })); + }, []); + const handleLargePayloadConfigChange = useCallback((newConfig: LargePayloadConfig) => { setLocalLargePayloadConfig(newConfig); }, []); @@ -320,33 +330,95 @@ export default function ClientSettingsView() { /> - {/* Enable LiteLLM Fallbacks */} -
-
- -

- Enable litellm-specific fallbacks.{" "} - - Learn more - -

-
- handleConfigChange("enable_litellm_fallbacks", checked)} - disabled={!hasSettingsUpdateAccess} - /> -
+ {/* Compat Settings */} + + + +
+ LiteLLM Compat +

+ Request type conversion and parameter dropping.{" "} + e.stopPropagation()} + > + Learn more + +

+
+
+ +
+
+ +

+ Convert text completion requests to chat for models that only support chat. +

+
+ handleCompatChange("convert_text_to_chat", checked)} + disabled={!hasSettingsUpdateAccess} + /> +
+
+
+ +

+ Convert chat completion requests to responses for models that only support responses. +

+
+ handleCompatChange("convert_chat_to_responses", checked)} + disabled={!hasSettingsUpdateAccess} + /> +
+
+
+ +

Drop unsupported parameters based on model catalog allowlist.

+
+ handleCompatChange("should_drop_params", checked)} + disabled={!hasSettingsUpdateAccess} + /> +
+
+
+ +

Converts model parameter values that are not supported by the model.

+
+ handleCompatChange("should_convert_params", checked)} + disabled={!hasSettingsUpdateAccess} + /> +
+
+
+
{/* Disable DB Pings in Health */}
@@ -438,9 +510,8 @@ export default function ClientSettingsView() {
  • Wildcards: Use{" "} - * at the end of a pattern to match - prefixes (e.g.,{" "} - anthropic-* matches all headers starting + * at the end of a pattern to match prefixes + (e.g., anthropic-* matches all headers starting with anthropic-). Use{" "} * alone to match all headers.
  • diff --git a/ui/components/ui/accordion.tsx b/ui/components/ui/accordion.tsx index 2978b5e54f..6ee186762e 100644 --- a/ui/components/ui/accordion.tsx +++ b/ui/components/ui/accordion.tsx @@ -26,7 +26,7 @@ function AccordionTrigger({ className, children, ...props }: React.ComponentProp {...props} > {children} - + ); @@ -44,4 +44,4 @@ function AccordionContent({ className, children, ...props }: React.ComponentProp ); } -export { Accordion, AccordionContent, AccordionItem, AccordionTrigger }; +export { Accordion, AccordionContent, AccordionItem, AccordionTrigger }; \ No newline at end of file diff --git a/ui/lib/types/config.ts b/ui/lib/types/config.ts index b3929e54b9..be489e53b3 100644 --- a/ui/lib/types/config.ts +++ b/ui/lib/types/config.ts @@ -454,6 +454,13 @@ export interface BifrostConfig { auth_token?: string; } +export interface CompatConfig { + convert_text_to_chat: boolean; + convert_chat_to_responses: boolean; + should_drop_params: boolean; + should_convert_params: boolean; +} + // Core Bifrost configuration types export interface CoreConfig { drop_excess_requests: boolean; @@ -468,7 +475,7 @@ export interface CoreConfig { allowed_origins: string[]; allowed_headers: string[]; max_request_body_size_mb: number; - enable_litellm_fallbacks: boolean; + compat: CompatConfig; mcp_agent_depth: number; mcp_tool_execution_timeout: number; mcp_code_mode_binding_level?: string; @@ -495,7 +502,7 @@ export const DefaultCoreConfig: CoreConfig = { allow_direct_keys: false, allowed_origins: [], max_request_body_size_mb: 100, - enable_litellm_fallbacks: false, + compat: { convert_text_to_chat: false, convert_chat_to_responses: false, should_drop_params: false, should_convert_params: true }, mcp_agent_depth: 10, mcp_tool_execution_timeout: 30, mcp_code_mode_binding_level: "server",