diff --git a/.github/workflows/configs/withpostgresmcpclientsinconfig/config.json b/.github/workflows/configs/withpostgresmcpclientsinconfig/config.json index a0122adfa2..42758ab9dd 100644 --- a/.github/workflows/configs/withpostgresmcpclientsinconfig/config.json +++ b/.github/workflows/configs/withpostgresmcpclientsinconfig/config.json @@ -7,7 +7,6 @@ ], "disable_content_logging": false, "drop_excess_requests": false, - "enable_litellm_fallbacks": false, "enable_logging": true, "enforce_auth_on_inference": true, "initial_pool_size": 300, diff --git a/.github/workflows/scripts/run-migration-tests.sh b/.github/workflows/scripts/run-migration-tests.sh index 85f901cfd5..e0f80eca7f 100755 --- a/.github/workflows/scripts/run-migration-tests.sh +++ b/.github/workflows/scripts/run-migration-tests.sh @@ -542,8 +542,8 @@ VALUES ('migration-test-lock', 'holder-migration-test-001', $future, $now) ON CONFLICT DO NOTHING; -- config_client (global client configuration) -INSERT INTO config_client (id, drop_excess_requests, prometheus_labels_json, allowed_origins_json, allowed_headers_json, header_filter_config_json, initial_pool_size, enable_logging, disable_content_logging, disable_db_pings_in_health, log_retention_days, enforce_governance_header, allow_direct_keys, max_request_body_size_mb, mcp_agent_depth, mcp_tool_execution_timeout, mcp_code_mode_binding_level, mcp_tool_sync_interval, enable_litellm_fallbacks, config_hash, created_at, updated_at) -VALUES (1, false, '["provider", "model"]', '["*"]', '["Authorization"]', '{}', 300, true, false, false, 365, true, false, true, 100, 10, 30, 'server', 10, false, 'client-config-hash-001', $now, $now) +INSERT INTO config_client (id, drop_excess_requests, prometheus_labels_json, allowed_origins_json, allowed_headers_json, header_filter_config_json, initial_pool_size, enable_logging, disable_content_logging, disable_db_pings_in_health, log_retention_days, enforce_governance_header, allow_direct_keys, max_request_body_size_mb, 
mcp_agent_depth, mcp_tool_execution_timeout, mcp_code_mode_binding_level, mcp_tool_sync_interval, compat_convert_text_to_chat, compat_convert_chat_to_responses, compat_should_drop_params, compat_should_convert_params, config_hash, created_at, updated_at) +VALUES (1, false, '["provider", "model"]', '["*"]', '["Authorization"]', '{}', 300, true, false, false, 365, true, false, 100, 10, 30, 'server', 10, false, false, false, true, 'client-config-hash-001', $now, $now) ON CONFLICT DO NOTHING; -- governance_config (key-value config table) @@ -3509,4 +3509,4 @@ main() { exit $exit_code } -main "$@" +main "$@" \ No newline at end of file diff --git a/.github/workflows/scripts/test-docker-image.sh b/.github/workflows/scripts/test-docker-image.sh index 5d770fbd64..ac115394bf 100755 --- a/.github/workflows/scripts/test-docker-image.sh +++ b/.github/workflows/scripts/test-docker-image.sh @@ -212,8 +212,7 @@ cat > "$CONFIG_FILE" << 'CONFIGEOF' "enable_logging": true, "enforce_governance_header": false, "allow_direct_keys": false, - "max_request_body_size_mb": 100, - "enable_litellm_fallbacks": false + "max_request_body_size_mb": 100 }, "encryption_key": "" } diff --git a/.github/workflows/scripts/validate-helm-config-fields.sh b/.github/workflows/scripts/validate-helm-config-fields.sh index 3b08dfffe9..8b38e9d717 100755 --- a/.github/workflows/scripts/validate-helm-config-fields.sh +++ b/.github/workflows/scripts/validate-helm-config-fields.sh @@ -164,7 +164,11 @@ bifrost: enforceGovernanceHeader: true allowDirectKeys: true maxRequestBodySizeMb: 50 - enableLitellmFallbacks: true + compat: + convertTextToChat: true + convertChatToResponses: true + shouldDropParams: true + shouldConvertParams: true prometheusLabels: - "team" - "env" @@ -200,7 +204,10 @@ assert_field_value 'client.log_retention_days' '.client.log_retention_days' '30' assert_field_value 'client.enforce_governance_header' '.client.enforce_governance_header' 'true' assert_field_value 'client.allow_direct_keys' 
'.client.allow_direct_keys' 'true' assert_field_value 'client.max_request_body_size_mb' '.client.max_request_body_size_mb' '50' -assert_field_value 'client.enable_litellm_fallbacks' '.client.enable_litellm_fallbacks' 'true' +assert_field_value 'client.compat.convert_text_to_chat' '.client.compat.convert_text_to_chat' 'true' +assert_field_value 'client.compat.convert_chat_to_responses' '.client.compat.convert_chat_to_responses' 'true' +assert_field_value 'client.compat.should_drop_params' '.client.compat.should_drop_params' 'true' +assert_field_value 'client.compat.should_convert_params' '.client.compat.should_convert_params' 'true' assert_field 'client.prometheus_labels' '.client.prometheus_labels' assert_field 'client.header_filter_config.allowlist' '.client.header_filter_config.allowlist' assert_field 'client.header_filter_config.denylist' '.client.header_filter_config.denylist' @@ -1194,4 +1201,4 @@ if [ "$TESTS_FAILED" -gt 0 ]; then else echo -e "${GREEN}✅ All config.json field validations passed!${NC}" exit 0 -fi +fi \ No newline at end of file diff --git a/AGENTS.md b/AGENTS.md index 03fdd812ee..bd41feaa78 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -103,7 +103,7 @@ bifrost/ │ ├── mocker/ # Mock responses for testing │ ├── jsonparser/ # JSON extraction utilities │ ├── maxim/ # Maxim observability -│ └── litellmcompat/ # LiteLLM SDK compatibility (HTTP transport) +│ └── compat/ # LiteLLM SDK compatibility (HTTP transport) │ ├── ui/ # Next.js web interface │ ├── app/workspace/ # Feature pages (20+ workspace sections) @@ -647,4 +647,4 @@ Systematically address unresolved PR review comments. Uses GraphQL to get unreso - **Provider types**: Prefixed with provider name in PascalCase (`AnthropicChatRequest`, `GeminiEmbeddingResponse`). - **Converter functions**: Pure — no side effects, no logging, no HTTP. - **Pool names**: Descriptive string passed to `pool.New()` (e.g., `"channel-message"`, `"response-stream"`). -- **Context keys**: Use `BifrostContextKey` type. 
Custom plugins should define their own key types to avoid collisions. +- **Context keys**: Use `BifrostContextKey` type. Custom plugins should define their own key types to avoid collisions. \ No newline at end of file diff --git a/core/bifrost.go b/core/bifrost.go index a514d1b936..b4015cca9f 100644 --- a/core/bifrost.go +++ b/core/bifrost.go @@ -612,7 +612,7 @@ func (bifrost *Bifrost) TextCompletionRequest(ctx *schemas.BifrostContext, req * if err != nil { return nil, err } - //TODO: Release the response + // TODO: Release the response return response.TextCompletionResponse, nil } @@ -934,7 +934,7 @@ func (bifrost *Bifrost) EmbeddingRequest(ctx *schemas.BifrostContext, req *schem if err != nil { return nil, err } - //TODO: Release the response + // TODO: Release the response return response.EmbeddingResponse, nil } @@ -1042,7 +1042,7 @@ func (bifrost *Bifrost) SpeechRequest(ctx *schemas.BifrostContext, req *schemas. if err != nil { return nil, err } - //TODO: Release the response + // TODO: Release the response return response.SpeechResponse, nil } @@ -1117,7 +1117,7 @@ func (bifrost *Bifrost) TranscriptionRequest(ctx *schemas.BifrostContext, req *s if err != nil { return nil, err } - //TODO: Release the response + // TODO: Release the response return response.TranscriptionResponse, nil } @@ -1158,7 +1158,8 @@ func (bifrost *Bifrost) TranscriptionStreamRequest(ctx *schemas.BifrostContext, // ImageGenerationRequest sends an image generation request to the specified provider. 
func (bifrost *Bifrost) ImageGenerationRequest(ctx *schemas.BifrostContext, - req *schemas.BifrostImageGenerationRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) { + req *schemas.BifrostImageGenerationRequest, +) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) { if req == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, @@ -1213,7 +1214,8 @@ func (bifrost *Bifrost) ImageGenerationRequest(ctx *schemas.BifrostContext, // ImageGenerationStreamRequest sends an image generation stream request to the specified provider. func (bifrost *Bifrost) ImageGenerationStreamRequest(ctx *schemas.BifrostContext, - req *schemas.BifrostImageGenerationRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { + req *schemas.BifrostImageGenerationRequest, +) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { if req == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, @@ -1434,7 +1436,8 @@ func (bifrost *Bifrost) ImageVariationRequest(ctx *schemas.BifrostContext, req * // VideoGenerationRequest sends a video generation request to the specified provider. 
func (bifrost *Bifrost) VideoGenerationRequest(ctx *schemas.BifrostContext, - req *schemas.BifrostVideoGenerationRequest) (*schemas.BifrostVideoGenerationResponse, *schemas.BifrostError) { + req *schemas.BifrostVideoGenerationRequest, +) (*schemas.BifrostVideoGenerationResponse, *schemas.BifrostError) { if req == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, @@ -4692,7 +4695,7 @@ func (bifrost *Bifrost) tryStreamRequest(ctx *schemas.BifrostContext, req *schem // Send the processed message to the output stream outputStream <- streamResponse - //TODO: Release the processed response immediately after use + // TODO: Release the processed response immediately after use } }() @@ -5264,12 +5267,34 @@ func (bifrost *Bifrost) handleProviderRequest(provider schemas.Provider, req *Ch } response.ListModelsResponse = listModelsResponse case schemas.TextCompletionRequest: + if changeType, ok := req.Context.Value(schemas.BifrostContextKeyChangeRequestType).(schemas.RequestType); ok && changeType == schemas.ChatCompletionRequest { + chatRequest := req.BifrostRequest.TextCompletionRequest.ToBifrostChatRequest() + if chatRequest != nil { + chatCompletionResponse, bifrostError := provider.ChatCompletion(req.Context, key, chatRequest) + if bifrostError != nil { + return nil, bifrostError + } + response.TextCompletionResponse = chatCompletionResponse.ToBifrostTextCompletionResponse() + break + } + } textCompletionResponse, bifrostError := provider.TextCompletion(req.Context, key, req.BifrostRequest.TextCompletionRequest) if bifrostError != nil { return nil, bifrostError } response.TextCompletionResponse = textCompletionResponse case schemas.ChatCompletionRequest: + if changeType, ok := req.Context.Value(schemas.BifrostContextKeyChangeRequestType).(schemas.RequestType); ok && changeType == schemas.ResponsesRequest { + responsesRequest := req.BifrostRequest.ChatRequest.ToResponsesRequest() + if responsesRequest != nil { + responsesResponse, bifrostError := 
provider.Responses(req.Context, key, responsesRequest) + if bifrostError != nil { + return nil, bifrostError + } + response.ChatResponse = responsesResponse.ToBifrostChatResponse() + break + } + } chatCompletionResponse, bifrostError := provider.ChatCompletion(req.Context, key, req.BifrostRequest.ChatRequest) if bifrostError != nil { return nil, bifrostError @@ -5519,8 +5544,20 @@ func (bifrost *Bifrost) handleProviderRequest(provider schemas.Provider, req *Ch func (bifrost *Bifrost) handleProviderStreamRequest(provider schemas.Provider, req *ChannelMessage, key schemas.Key, postHookRunner schemas.PostHookRunner) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { switch req.RequestType { case schemas.TextCompletionStreamRequest: + if changeType, ok := req.Context.Value(schemas.BifrostContextKeyChangeRequestType).(schemas.RequestType); ok && changeType == schemas.ChatCompletionRequest { + chatRequest := req.BifrostRequest.TextCompletionRequest.ToBifrostChatRequest() + if chatRequest != nil { + return provider.ChatCompletionStream(req.Context, wrapConvertedStreamPostHookRunner(postHookRunner, schemas.ChatCompletionRequest), key, chatRequest) + } + } return provider.TextCompletionStream(req.Context, postHookRunner, key, req.BifrostRequest.TextCompletionRequest) case schemas.ChatCompletionStreamRequest: + if changeType, ok := req.Context.Value(schemas.BifrostContextKeyChangeRequestType).(schemas.RequestType); ok && changeType == schemas.ResponsesRequest { + responsesRequest := req.BifrostRequest.ChatRequest.ToResponsesRequest() + if responsesRequest != nil { + return provider.ResponsesStream(req.Context, wrapConvertedStreamPostHookRunner(postHookRunner, schemas.ResponsesRequest), key, responsesRequest) + } + } return provider.ChatCompletionStream(req.Context, postHookRunner, key, req.BifrostRequest.ChatRequest) case schemas.ResponsesStreamRequest: return provider.ResponsesStream(req.Context, postHookRunner, key, req.BifrostRequest.ResponsesRequest) diff --git 
a/core/providers/anthropic/types.go b/core/providers/anthropic/types.go index 867dcf7e97..f3c45370cd 100644 --- a/core/providers/anthropic/types.go +++ b/core/providers/anthropic/types.go @@ -1316,4 +1316,4 @@ func parseAnthropicFileTimestamp(timestamp string) int64 { // AnthropicCountTokensResponse models the payload returned by Anthropic's count tokens endpoint. type AnthropicCountTokensResponse struct { InputTokens int `json:"input_tokens"` -} +} \ No newline at end of file diff --git a/core/providers/bedrock/images.go b/core/providers/bedrock/images.go index b0ac35dc01..dc3c76edd4 100644 --- a/core/providers/bedrock/images.go +++ b/core/providers/bedrock/images.go @@ -153,7 +153,6 @@ func ToBedrockImageGenerationRequest(request *schemas.BifrostImageGenerationRequ } return bedrockReq, nil - } // ToStabilityAIImageGenerationResponse converts a BifrostImageGenerationResponse back to diff --git a/core/providers/bedrock/models.go b/core/providers/bedrock/models.go index 6d2f9006f2..549db2e3bd 100644 --- a/core/providers/bedrock/models.go +++ b/core/providers/bedrock/models.go @@ -81,7 +81,6 @@ type BedrockRerankResponseDocument struct { TextDocument *BedrockRerankTextValue `json:"textDocument,omitempty"` } - func (response *BedrockListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels schemas.WhiteList, blacklistedModels schemas.BlackList, aliases map[string]string, unfiltered bool) *schemas.BifrostListModelsResponse { if response == nil { return nil @@ -128,4 +127,4 @@ func (response *BedrockListModelsResponse) ToBifrostListModelsResponse(providerK pipeline.BackfillModels(included)...) 
return bifrostResponse -} +} \ No newline at end of file diff --git a/core/providers/cohere/types.go b/core/providers/cohere/types.go index a4d78f9a48..8e5aa31402 100644 --- a/core/providers/cohere/types.go +++ b/core/providers/cohere/types.go @@ -9,8 +9,11 @@ import ( "github.com/maximhq/bifrost/core/schemas" ) -const MinimumReasoningMaxTokens = 1 -const DefaultCompletionMaxTokens = 4096 // Only used for relative reasoning max token calculation - not passed in body by default +const ( + MinimumReasoningMaxTokens = 1 + DefaultCompletionMaxTokens = 4096 // Only used for relative reasoning max token calculation - not passed in body by default +) + // Limits for tokenize input api call https://docs.cohere.com/reference/tokenize#request const ( cohereTokenizeMinTextLength = 1 diff --git a/core/providers/gemini/types.go b/core/providers/gemini/types.go index 75cf9f504f..44755ee761 100644 --- a/core/providers/gemini/types.go +++ b/core/providers/gemini/types.go @@ -17,11 +17,13 @@ import ( "github.com/maximhq/bifrost/core/schemas" ) -const MinReasoningMaxTokens = 1 // Minimum max tokens for reasoning - used for estimation of effort level -const DefaultCompletionMaxTokens = 8192 // Default max output tokens for Gemini - used for relative reasoning max token calculation -const DefaultReasoningMinBudget = 1024 // Default minimum reasoning budget for Gemini -const DynamicReasoningBudget = -1 // Special value for dynamic reasoning budget in Gemini -const skipThoughtSignatureValidator = "skip_thought_signature_validator" +const ( + MinReasoningMaxTokens = 1 // Minimum max tokens for reasoning - used for estimation of effort level + DefaultCompletionMaxTokens = 8192 // Default max output tokens for Gemini - used for relative reasoning max token calculation + DefaultReasoningMinBudget = 1024 // Default minimum reasoning budget for Gemini + DynamicReasoningBudget = -1 // Special value for dynamic reasoning budget in Gemini + skipThoughtSignatureValidator = 
"skip_thought_signature_validator" +) type thinkingBudgetRange struct { Min int @@ -509,8 +511,7 @@ type GoogleMaps struct { } // URLContext is a tool to support URL context retrieval. -type URLContext struct { -} +type URLContext struct{} // ToolComputerUse is a tool to support computer use. type ToolComputerUse struct { @@ -555,8 +556,7 @@ type ExternalAPIElasticSearchParams struct { } // ExternalAPISimpleSearchParams represents the search parameters to use for SIMPLE_SEARCH spec. -type ExternalAPISimpleSearchParams struct { -} +type ExternalAPISimpleSearchParams struct{} // ExternalAPI retrieves from data source powered by external API for grounding. The external API // is not owned by Google, but needs to follow the pre-defined API spec. @@ -714,8 +714,7 @@ type Retrieval struct { // ToolCodeExecution is a tool that executes code generated by the model, and automatically returns the result // to the model. See also [ExecutableCode]and [CodeExecutionResult] which are input // and output to this tool. -type ToolCodeExecution struct { -} +type ToolCodeExecution struct{} // Tool details of a tool that the model may use to generate a response. type Tool struct { diff --git a/core/providers/openai/types.go b/core/providers/openai/types.go index 89de4e1e66..39e25990d8 100644 --- a/core/providers/openai/types.go +++ b/core/providers/openai/types.go @@ -6,8 +6,8 @@ import ( "fmt" "github.com/bytedance/sonic" - "github.com/maximhq/bifrost/core/schemas" providerUtils "github.com/maximhq/bifrost/core/providers/utils" + "github.com/maximhq/bifrost/core/schemas" ) const MinMaxCompletionTokens = 16 @@ -82,7 +82,7 @@ type OpenAIChatRequest struct { // PromptCacheIsolationKey is the Fireworks chat-completions field for cache isolation. PromptCacheIsolationKey *string `json:"prompt_cache_isolation_key,omitempty"` - //NOTE: MaxCompletionTokens is a new replacement for max_tokens but some providers still use max_tokens. 
+ // NOTE: MaxCompletionTokens is a new replacement for max_tokens but some providers still use max_tokens. // This Field is populated only for such providers and is NOT to be used externally. MaxTokens *int `json:"max_tokens,omitempty"` diff --git a/core/providers/perplexity/types.go b/core/providers/perplexity/types.go index feef9e0ccb..d5ad5c65f6 100644 --- a/core/providers/perplexity/types.go +++ b/core/providers/perplexity/types.go @@ -4,45 +4,45 @@ import "github.com/maximhq/bifrost/core/schemas" // PerplexityChatRequest represents a Perplexity chat completion request type PerplexityChatRequest struct { - Model string `json:"model"` // Required: Model to use for chat completion - Messages []schemas.ChatMessage `json:"messages"` // Required: Array of message objects - SearchMode *string `json:"search_mode"` // Required: Search mode - ReasoningEffort *string `json:"reasoning_effort"` // Required: Reasoning effort (low, medium, high) - MaxTokens *int `json:"max_tokens,omitempty"` // Optional: Maximum tokens to generate - Temperature *float64 `json:"temperature,omitempty"` // Optional: Sampling temperature - TopP *float64 `json:"top_p,omitempty"` // Optional: Top-p sampling - LanguagePreference *string `json:"language_preference,omitempty"` // Optional: Language preference - SearchDomainFilter []string `json:"search_domain_filter,omitempty"` // Optional: Search domain filter - ReturnImages *bool `json:"return_images,omitempty"` // Optional: Return images - ReturnRelatedQuestions *bool `json:"return_related_questions,omitempty"` // Optional: Return related questions - SearchRecencyFilter *string `json:"search_recency_filter,omitempty"` // Optional: Search recency filter - SearchAfterDateFilter *string `json:"search_after_date_filter,omitempty"` // Optional: Search after date filter - SearchBeforeDateFilter *string `json:"search_before_date_filter,omitempty"` // Optional: Search before date filter - LastUpdatedAfterFilter *string 
`json:"last_updated_after_filter,omitempty"` // Optional: Last updated after filter - LastUpdatedBeforeFilter *string `json:"last_updated_before_filter,omitempty"` // Optional: Last updated before filter - TopK *int `json:"top_k,omitempty"` // Optional: Top-k sampling - Stream *bool `json:"stream,omitempty"` // Optional: Enable streaming - PresencePenalty *float64 `json:"presence_penalty,omitempty"` // Optional: Presence penalty - FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` // Optional: Frequency penalty - ResponseFormat *interface{} `json:"response_format,omitempty"` // Format for the response - DisableSearch *bool `json:"disable_search,omitempty"` // Optional: Disable search - EnableSearchClassifier *bool `json:"enable_search_classifier,omitempty"` // Optional: Enable search classifier - WebSearchOptions []WebSearchOption `json:"web_search_options,omitempty"` // Optional: Web search options - MediaResponse *MediaResponse `json:"media_response,omitempty"` // Optional: Media response - Tools []schemas.ChatTool `json:"tools,omitempty"` // Optional: Tools available for the model - ToolChoice *schemas.ChatToolChoice `json:"tool_choice,omitempty"` // Optional: Whether to call a tool - ParallelToolCalls *bool `json:"parallel_tool_calls,omitempty"` // Optional: Enable parallel tool calls - Stop []string `json:"stop,omitempty"` // Optional: Stop sequences - LogProbs *bool `json:"logprobs,omitempty"` // Optional: Return log probabilities - TopLogProbs *int `json:"top_logprobs,omitempty"` // Optional: Number of top log probabilities - NumSearchResults *int `json:"num_search_results,omitempty"` // Optional: Number of search results - NumImages *int `json:"num_images,omitempty"` // Optional: Number of images - SearchLanguageFilter []string `json:"search_language_filter,omitempty"` // Optional: Search language filter - ImageFormatFilter []string `json:"image_format_filter,omitempty"` // Optional: Image format filter - ImageDomainFilter []string 
`json:"image_domain_filter,omitempty"` // Optional: Image domain filter - SafeSearch *bool `json:"safe_search,omitempty"` // Optional: Enable safe search - StreamMode *string `json:"stream_mode,omitempty"` // Optional: Stream mode - ExtraParams map[string]interface{} `json:"-"` + Model string `json:"model"` // Required: Model to use for chat completion + Messages []schemas.ChatMessage `json:"messages"` // Required: Array of message objects + SearchMode *string `json:"search_mode"` // Required: Search mode + ReasoningEffort *string `json:"reasoning_effort"` // Required: Reasoning effort (low, medium, high) + MaxTokens *int `json:"max_tokens,omitempty"` // Optional: Maximum tokens to generate + Temperature *float64 `json:"temperature,omitempty"` // Optional: Sampling temperature + TopP *float64 `json:"top_p,omitempty"` // Optional: Top-p sampling + LanguagePreference *string `json:"language_preference,omitempty"` // Optional: Language preference + SearchDomainFilter []string `json:"search_domain_filter,omitempty"` // Optional: Search domain filter + ReturnImages *bool `json:"return_images,omitempty"` // Optional: Return images + ReturnRelatedQuestions *bool `json:"return_related_questions,omitempty"` // Optional: Return related questions + SearchRecencyFilter *string `json:"search_recency_filter,omitempty"` // Optional: Search recency filter + SearchAfterDateFilter *string `json:"search_after_date_filter,omitempty"` // Optional: Search after date filter + SearchBeforeDateFilter *string `json:"search_before_date_filter,omitempty"` // Optional: Search before date filter + LastUpdatedAfterFilter *string `json:"last_updated_after_filter,omitempty"` // Optional: Last updated after filter + LastUpdatedBeforeFilter *string `json:"last_updated_before_filter,omitempty"` // Optional: Last updated before filter + TopK *int `json:"top_k,omitempty"` // Optional: Top-k sampling + Stream *bool `json:"stream,omitempty"` // Optional: Enable streaming + PresencePenalty *float64 
`json:"presence_penalty,omitempty"` // Optional: Presence penalty + FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` // Optional: Frequency penalty + ResponseFormat *interface{} `json:"response_format,omitempty"` // Format for the response + DisableSearch *bool `json:"disable_search,omitempty"` // Optional: Disable search + EnableSearchClassifier *bool `json:"enable_search_classifier,omitempty"` // Optional: Enable search classifier + WebSearchOptions []WebSearchOption `json:"web_search_options,omitempty"` // Optional: Web search options + MediaResponse *MediaResponse `json:"media_response,omitempty"` // Optional: Media response + Tools []schemas.ChatTool `json:"tools,omitempty"` // Optional: Tools available for the model + ToolChoice *schemas.ChatToolChoice `json:"tool_choice,omitempty"` // Optional: Whether to call a tool + ParallelToolCalls *bool `json:"parallel_tool_calls,omitempty"` // Optional: Enable parallel tool calls + Stop []string `json:"stop,omitempty"` // Optional: Stop sequences + LogProbs *bool `json:"logprobs,omitempty"` // Optional: Return log probabilities + TopLogProbs *int `json:"top_logprobs,omitempty"` // Optional: Number of top log probabilities + NumSearchResults *int `json:"num_search_results,omitempty"` // Optional: Number of search results + NumImages *int `json:"num_images,omitempty"` // Optional: Number of images + SearchLanguageFilter []string `json:"search_language_filter,omitempty"` // Optional: Search language filter + ImageFormatFilter []string `json:"image_format_filter,omitempty"` // Optional: Image format filter + ImageDomainFilter []string `json:"image_domain_filter,omitempty"` // Optional: Image domain filter + SafeSearch *bool `json:"safe_search,omitempty"` // Optional: Enable safe search + StreamMode *string `json:"stream_mode,omitempty"` // Optional: Stream mode + ExtraParams map[string]interface{} `json:"-"` } // GetExtraParams implements the RequestBodyWithExtraParams interface diff --git 
a/core/providers/replicate/types.go b/core/providers/replicate/types.go index 98f84e613e..3ae88c0095 100644 --- a/core/providers/replicate/types.go +++ b/core/providers/replicate/types.go @@ -313,28 +313,28 @@ type ReplicatePredictionListResponse struct { // ReplicateModelResponse represents a model response type ReplicateModelResponse struct { - URL string `json:"url"` // Model API URL - Owner string `json:"owner"` // Owner username or org name - Name string `json:"name"` // Model name - Description *string `json:"description,omitempty"` // Model description - Visibility string `json:"visibility"` // "public" or "private" - GithubURL *string `json:"github_url,omitempty"` // GitHub repository URL - PaperURL *string `json:"paper_url,omitempty"` // Research paper URL - LicenseURL *string `json:"license_url,omitempty"` // License URL - RunCount *int `json:"run_count,omitempty"` // Number of times run - CoverImageURL *string `json:"cover_image_url,omitempty"` // Cover image URL - DefaultExample *json.RawMessage `json:"default_example,omitempty"` // Default example prediction (json.RawMessage preserves key ordering) - LatestVersion *ReplicateModelVersion `json:"latest_version,omitempty"` // Latest version details - FeaturedVersion *ReplicateModelVersion `json:"featured_version,omitempty"` // Featured version details + URL string `json:"url"` // Model API URL + Owner string `json:"owner"` // Owner username or org name + Name string `json:"name"` // Model name + Description *string `json:"description,omitempty"` // Model description + Visibility string `json:"visibility"` // "public" or "private" + GithubURL *string `json:"github_url,omitempty"` // GitHub repository URL + PaperURL *string `json:"paper_url,omitempty"` // Research paper URL + LicenseURL *string `json:"license_url,omitempty"` // License URL + RunCount *int `json:"run_count,omitempty"` // Number of times run + CoverImageURL *string `json:"cover_image_url,omitempty"` // Cover image URL + DefaultExample 
*json.RawMessage `json:"default_example,omitempty"` // Default example prediction (json.RawMessage preserves key ordering) + LatestVersion *ReplicateModelVersion `json:"latest_version,omitempty"` // Latest version details + FeaturedVersion *ReplicateModelVersion `json:"featured_version,omitempty"` // Featured version details } // ReplicateModelVersion represents a model version type ReplicateModelVersion struct { - ID string `json:"id"` // Version ID - CreatedAt string `json:"created_at"` // ISO 8601 timestamp - CogVersion *string `json:"cog_version,omitempty"` // Cog version used - OpenAPISchema json.RawMessage `json:"openapi_schema,omitempty"` // OpenAPI schema for the model (json.RawMessage preserves key ordering) - DockerImageID *string `json:"docker_image_id,omitempty"` // Docker image ID + ID string `json:"id"` // Version ID + CreatedAt string `json:"created_at"` // ISO 8601 timestamp + CogVersion *string `json:"cog_version,omitempty"` // Cog version used + OpenAPISchema json.RawMessage `json:"openapi_schema,omitempty"` // OpenAPI schema for the model (json.RawMessage preserves key ordering) + DockerImageID *string `json:"docker_image_id,omitempty"` // Docker image ID } // ReplicateModelListResponse represents a paginated list of models diff --git a/core/providers/utils/utils.go b/core/providers/utils/utils.go index 00e12392f0..22ae9fa7d3 100644 --- a/core/providers/utils/utils.go +++ b/core/providers/utils/utils.go @@ -1680,7 +1680,7 @@ func SendCreatedEventResponsesChunk(ctx *schemas.BifrostContext, postHookRunner Latency: time.Since(startTime).Milliseconds(), }, } - //TODO add bifrost response pooling here + // TODO add bifrost response pooling here bifrostResponse := &schemas.BifrostResponse{ ResponsesStreamResponse: firstChunk, } @@ -1698,7 +1698,7 @@ func SendInProgressEventResponsesChunk(ctx *schemas.BifrostContext, postHookRunn Latency: time.Since(startTime).Milliseconds(), }, } - //TODO add bifrost response pooling here + // TODO add bifrost response 
pooling here bifrostResponse := &schemas.BifrostResponse{ ResponsesStreamResponse: chunk, } @@ -2051,14 +2051,13 @@ func ProcessAndSendError( logger schemas.Logger, ) { // Send scanner error through channel - bifrostError := - &schemas.BifrostError{ - IsBifrostError: true, - Error: &schemas.ErrorField{ - Message: fmt.Sprintf("Error reading stream: %v", err), - Error: err, - }, - } + bifrostError := &schemas.BifrostError{ + IsBifrostError: true, + Error: &schemas.ErrorField{ + Message: fmt.Sprintf("Error reading stream: %v", err), + Error: err, + }, + } processedResponse, processedError := postHookRunner(ctx, nil, bifrostError) if HandleStreamControlSkip(processedError) { @@ -2220,7 +2219,7 @@ func GetBifrostResponseForStreamResponse( transcriptionStreamResponse *schemas.BifrostTranscriptionStreamResponse, imageGenerationStreamResponse *schemas.BifrostImageGenerationStreamResponse, ) *schemas.BifrostResponse { - //TODO add bifrost response pooling here + // TODO add bifrost response pooling here bifrostResponse := &schemas.BifrostResponse{} switch { diff --git a/core/providers/vertex/models.go b/core/providers/vertex/models.go index 48837563eb..2fbe83979d 100644 --- a/core/providers/vertex/models.go +++ b/core/providers/vertex/models.go @@ -193,4 +193,4 @@ func (response *VertexListPublisherModelsResponse) ToBifrostListModelsResponse(a bifrostResponse.NextPageToken = response.NextPageToken return bifrostResponse -} +} \ No newline at end of file diff --git a/core/providers/vertex/types.go b/core/providers/vertex/types.go index 97d6de7fa2..bbdb89d17f 100644 --- a/core/providers/vertex/types.go +++ b/core/providers/vertex/types.go @@ -192,23 +192,23 @@ type VertexModelLabels struct { // These types are for the publishers.models.list endpoint (Model Garden) type VertexPublisherModel struct { - Name string `json:"name"` - VersionID string `json:"versionId"` - OpenSourceCategory string `json:"openSourceCategory"` - LaunchStage string `json:"launchStage"` - VersionState 
string `json:"versionState"` - PublisherModelTemplate string `json:"publisherModelTemplate"` - SupportedActions *VertexPublisherModelActions `json:"supportedActions"` + Name string `json:"name"` + VersionID string `json:"versionId"` + OpenSourceCategory string `json:"openSourceCategory"` + LaunchStage string `json:"launchStage"` + VersionState string `json:"versionState"` + PublisherModelTemplate string `json:"publisherModelTemplate"` + SupportedActions *VertexPublisherModelActions `json:"supportedActions"` } type VertexPublisherModelActions struct { - OpenGenerationAIStudio *VertexPublisherModelURI `json:"openGenerationAiStudio"` - OpenGenie *VertexPublisherModelURI `json:"openGenie"` - OpenPromptTuningPipeline *VertexPublisherModelURI `json:"openPromptTuningPipeline"` - OpenNotebook *VertexPublisherModelURI `json:"openNotebook"` - OpenFineTuningPipeline *VertexPublisherModelURI `json:"openFineTuningPipeline"` - Deploy *VertexPublisherModelDeploy `json:"deploy"` - OpenEvaluationPipeline *VertexPublisherModelURI `json:"openEvaluationPipeline"` + OpenGenerationAIStudio *VertexPublisherModelURI `json:"openGenerationAiStudio"` + OpenGenie *VertexPublisherModelURI `json:"openGenie"` + OpenPromptTuningPipeline *VertexPublisherModelURI `json:"openPromptTuningPipeline"` + OpenNotebook *VertexPublisherModelURI `json:"openNotebook"` + OpenFineTuningPipeline *VertexPublisherModelURI `json:"openFineTuningPipeline"` + Deploy *VertexPublisherModelDeploy `json:"deploy"` + OpenEvaluationPipeline *VertexPublisherModelURI `json:"openEvaluationPipeline"` } type VertexPublisherModelURI struct { diff --git a/core/providers/vertex/vertex.go b/core/providers/vertex/vertex.go index 626cf00207..91acb476af 100644 --- a/core/providers/vertex/vertex.go +++ b/core/providers/vertex/vertex.go @@ -3153,4 +3153,4 @@ func (provider *VertexProvider) PassthroughStream( } }() return ch, nil -} +} \ No newline at end of file diff --git a/core/schemas/bifrost.go b/core/schemas/bifrost.go index 
8f68d2fbad..db187ac717 100644 --- a/core/schemas/bifrost.go +++ b/core/schemas/bifrost.go @@ -196,6 +196,7 @@ const ( BifrostContextKeyExtraHeaders BifrostContextKey = "bifrost-extra-headers" // map[string][]string BifrostContextKeyURLPath BifrostContextKey = "bifrost-extra-url-path" // string BifrostContextKeyUseRawRequestBody BifrostContextKey = "bifrost-use-raw-request-body" + BifrostContextKeyChangeRequestType BifrostContextKey = "bifrost-change-request-type" // RequestType (set by plugins to trigger request type conversion in core, e.g. text->chat or chat->responses) BifrostContextKeySendBackRawRequest BifrostContextKey = "bifrost-send-back-raw-request" // bool BifrostContextKeySendBackRawResponse BifrostContextKey = "bifrost-send-back-raw-response" // bool BifrostContextKeyIntegrationType BifrostContextKey = "bifrost-integration-type" // integration used in gateway (e.g. openai, anthropic, bedrock, etc.) @@ -271,6 +272,10 @@ const ( BifrostContextKeySessionTTL BifrostContextKey = "bifrost-session-ttl" // time.Duration session TTL for the request (session stickiness) BifrostContextKeyMCPExtraHeaders BifrostContextKey = "bifrost-mcp-extra-headers" // map[string][]string (these headers are forwarded only to the MCP while tool execution if they are in the allowlist of the MCP client) BifrostContextKeyMCPLogID BifrostContextKey = "bifrost-mcp-log-id" // string (unique UUID for each MCP tool log entry - set per goroutine by agent executor - DO NOT SET THIS MANUALLY) + BifrostContextKeyCompatConvertTextToChat BifrostContextKey = "bifrost-compat-convert-text-to-chat" // bool (per-request override from x-bf-compat header) + BifrostContextKeyCompatConvertChatToResponses BifrostContextKey = "bifrost-compat-convert-chat-to-responses" // bool (per-request override from x-bf-compat header) + BifrostContextKeyCompatShouldDropParams BifrostContextKey = "bifrost-compat-should-drop-params" // bool (per-request override from x-bf-compat header) + 
BifrostContextKeyCompatShouldConvertParams BifrostContextKey = "bifrost-compat-should-convert-params" // bool (per-request override from x-bf-compat header) ) const ( @@ -1032,18 +1037,19 @@ type BifrostMCPResponse struct { // BifrostResponseExtraFields contains additional fields in a response. type BifrostResponseExtraFields struct { - RequestType RequestType `json:"request_type"` - Provider ModelProvider `json:"provider,omitempty"` - OriginalModelRequested string `json:"original_model_requested,omitempty"` // the model alias the caller sent in the request - ResolvedModelUsed string `json:"resolved_model_used,omitempty"` // the actual provider API identifier used (equals OriginalModelRequested when no alias mapping exists) - Latency int64 `json:"latency"` // in milliseconds (for streaming responses this will be each chunk latency, and the last chunk latency will be the total latency) - ChunkIndex int `json:"chunk_index"` // used for streaming responses to identify the chunk index, will be 0 for non-streaming responses - RawRequest interface{} `json:"raw_request,omitempty"` - RawResponse interface{} `json:"raw_response,omitempty"` - CacheDebug *BifrostCacheDebug `json:"cache_debug,omitempty"` - ParseErrors []BatchError `json:"parse_errors,omitempty"` // errors encountered while parsing JSONL batch results - LiteLLMCompat bool `json:"litellm_compat,omitempty"` - ProviderResponseHeaders map[string]string `json:"provider_response_headers,omitempty"` // HTTP response headers from the provider (filtered to exclude transport-level headers) + RequestType RequestType `json:"request_type"` + Provider ModelProvider `json:"provider,omitempty"` + OriginalModelRequested string `json:"original_model_requested,omitempty"` // the model alias the caller sent in the request + ResolvedModelUsed string `json:"resolved_model_used,omitempty"` // the actual provider API identifier used (equals OriginalModelRequested when no alias mapping exists) + Latency int64 `json:"latency"` // in 
milliseconds (for streaming responses this will be each chunk latency, and the last chunk latency will be the total latency) + ChunkIndex int `json:"chunk_index"` // used for streaming responses to identify the chunk index, will be 0 for non-streaming responses + RawRequest interface{} `json:"raw_request,omitempty"` + RawResponse interface{} `json:"raw_response,omitempty"` + CacheDebug *BifrostCacheDebug `json:"cache_debug,omitempty"` + ParseErrors []BatchError `json:"parse_errors,omitempty"` // errors encountered while parsing JSONL batch results + ConvertedRequestType RequestType `json:"converted_request_type,omitempty"` + DroppedCompatPluginParams []string `json:"dropped_compat_plugin_params,omitempty"` // params dropped by the compat plugin based on model catalog + ProviderResponseHeaders map[string]string `json:"provider_response_headers,omitempty"` // HTTP response headers from the provider (filtered to exclude transport-level headers) } type BifrostMCPResponseExtraFields struct { @@ -1218,13 +1224,14 @@ func (e *ErrorField) UnmarshalJSON(data []byte) error { // BifrostErrorExtraFields contains additional fields in an error response. 
type BifrostErrorExtraFields struct { - Provider ModelProvider `json:"provider,omitempty"` - OriginalModelRequested string `json:"original_model_requested,omitempty"` - ResolvedModelUsed string `json:"resolved_model_used,omitempty"` - RequestType RequestType `json:"request_type,omitempty"` - RawRequest any `json:"raw_request,omitempty"` - RawResponse any `json:"raw_response,omitempty"` - LiteLLMCompat bool `json:"litellm_compat,omitempty"` - KeyStatuses []KeyStatus `json:"key_statuses,omitempty"` - MCPAuthRequired *MCPUserOAuthRequiredError `json:"mcp_auth_required,omitempty"` // Set when a per-user OAuth MCP tool requires authentication -} + Provider ModelProvider `json:"provider,omitempty"` + OriginalModelRequested string `json:"original_model_requested,omitempty"` + ResolvedModelUsed string `json:"resolved_model_used,omitempty"` + RequestType RequestType `json:"request_type,omitempty"` + RawRequest interface{} `json:"raw_request,omitempty"` + RawResponse interface{} `json:"raw_response,omitempty"` + ConvertedRequestType RequestType `json:"converted_request_type,omitempty"` + DroppedCompatPluginParams []string `json:"dropped_compat_plugin_params,omitempty"` + KeyStatuses []KeyStatus `json:"key_statuses,omitempty"` + MCPAuthRequired *MCPUserOAuthRequiredError `json:"mcp_auth_required,omitempty"` // Set when a per-user OAuth MCP tool requires authentication +} \ No newline at end of file diff --git a/core/schemas/chatcompletions.go b/core/schemas/chatcompletions.go index b864349213..f54002c9ba 100644 --- a/core/schemas/chatcompletions.go +++ b/core/schemas/chatcompletions.go @@ -29,16 +29,16 @@ func (cr *BifrostChatRequest) GetExtraParams() map[string]interface{} { // BifrostChatResponse represents the complete result from a chat completion request. type BifrostChatResponse struct { - ID string `json:"id"` - Choices []BifrostResponseChoice `json:"choices"` - Created int `json:"created"` // The Unix timestamp (in seconds). 
- Model string `json:"model"` - Object string `json:"object"` // "chat.completion" or "chat.completion.chunk" - ServiceTier *string `json:"service_tier,omitempty"` - SystemFingerprint string `json:"system_fingerprint"` - Usage *BifrostLLMUsage `json:"usage"` - ExtraFields BifrostResponseExtraFields `json:"extra_fields"` - ExtraParams map[string]interface{} `json:"-"` + ID string `json:"id"` + Choices []BifrostResponseChoice `json:"choices"` + Created int `json:"created"` // The Unix timestamp (in seconds). + Model string `json:"model"` + Object string `json:"object"` // "chat.completion" or "chat.completion.chunk" + ServiceTier *string `json:"service_tier,omitempty"` + SystemFingerprint string `json:"system_fingerprint"` + Usage *BifrostLLMUsage `json:"usage"` + ExtraFields BifrostResponseExtraFields `json:"extra_fields"` + ExtraParams map[string]interface{} `json:"-"` // Perplexity-specific fields SearchResults []SearchResult `json:"search_results,omitempty"` @@ -46,125 +46,6 @@ type BifrostChatResponse struct { Citations []string `json:"citations,omitempty"` } -// ToTextCompletionResponse converts a BifrostChatResponse to a BifrostTextCompletionResponse -func (cr *BifrostChatResponse) ToTextCompletionResponse() *BifrostTextCompletionResponse { - if cr == nil { - return nil - } - - if len(cr.Choices) == 0 { - return &BifrostTextCompletionResponse{ - ID: cr.ID, - Model: cr.Model, - Object: "text_completion", - SystemFingerprint: cr.SystemFingerprint, - Usage: cr.Usage, - ExtraFields: BifrostResponseExtraFields{ - RequestType: TextCompletionRequest, - ChunkIndex: cr.ExtraFields.ChunkIndex, - Provider: cr.ExtraFields.Provider, - OriginalModelRequested: cr.ExtraFields.OriginalModelRequested, - ResolvedModelUsed: cr.ExtraFields.ResolvedModelUsed, - Latency: cr.ExtraFields.Latency, - RawResponse: cr.ExtraFields.RawResponse, - CacheDebug: cr.ExtraFields.CacheDebug, - ProviderResponseHeaders: cr.ExtraFields.ProviderResponseHeaders, - }, - } - } - - choice := cr.Choices[0] 
- - // Handle streaming response choice - if choice.ChatStreamResponseChoice != nil && choice.ChatStreamResponseChoice.Delta != nil { - return &BifrostTextCompletionResponse{ - ID: cr.ID, - Model: cr.Model, - Object: "text_completion", - SystemFingerprint: cr.SystemFingerprint, - Choices: []BifrostResponseChoice{ - { - Index: 0, - TextCompletionResponseChoice: &TextCompletionResponseChoice{ - Text: choice.ChatStreamResponseChoice.Delta.Content, - }, - FinishReason: choice.FinishReason, - LogProbs: choice.LogProbs, - }, - }, - Usage: cr.Usage, - ExtraFields: BifrostResponseExtraFields{ - RequestType: TextCompletionRequest, - ChunkIndex: cr.ExtraFields.ChunkIndex, - Provider: cr.ExtraFields.Provider, - OriginalModelRequested: cr.ExtraFields.OriginalModelRequested, - ResolvedModelUsed: cr.ExtraFields.ResolvedModelUsed, - Latency: cr.ExtraFields.Latency, - RawResponse: cr.ExtraFields.RawResponse, - CacheDebug: cr.ExtraFields.CacheDebug, - ProviderResponseHeaders: cr.ExtraFields.ProviderResponseHeaders, - }, - } - } - - // Handle non-streaming response choice - if choice.ChatNonStreamResponseChoice != nil { - msg := choice.ChatNonStreamResponseChoice.Message - var textContent *string - if msg != nil && msg.Content != nil && msg.Content.ContentStr != nil { - textContent = msg.Content.ContentStr - } - return &BifrostTextCompletionResponse{ - ID: cr.ID, - Model: cr.Model, - Object: "text_completion", - SystemFingerprint: cr.SystemFingerprint, - Choices: []BifrostResponseChoice{ - { - Index: 0, - TextCompletionResponseChoice: &TextCompletionResponseChoice{ - Text: textContent, - }, - FinishReason: choice.FinishReason, - LogProbs: choice.LogProbs, - }, - }, - Usage: cr.Usage, - ExtraFields: BifrostResponseExtraFields{ - RequestType: TextCompletionRequest, - ChunkIndex: cr.ExtraFields.ChunkIndex, - Provider: cr.ExtraFields.Provider, - OriginalModelRequested: cr.ExtraFields.OriginalModelRequested, - ResolvedModelUsed: cr.ExtraFields.ResolvedModelUsed, - Latency: 
cr.ExtraFields.Latency, - RawResponse: cr.ExtraFields.RawResponse, - CacheDebug: cr.ExtraFields.CacheDebug, - ProviderResponseHeaders: cr.ExtraFields.ProviderResponseHeaders, - }, - } - } - - // Fallback case - return basic response structure - return &BifrostTextCompletionResponse{ - ID: cr.ID, - Model: cr.Model, - Object: "text_completion", - SystemFingerprint: cr.SystemFingerprint, - Usage: cr.Usage, - ExtraFields: BifrostResponseExtraFields{ - RequestType: TextCompletionRequest, - ChunkIndex: cr.ExtraFields.ChunkIndex, - Provider: cr.ExtraFields.Provider, - OriginalModelRequested: cr.ExtraFields.OriginalModelRequested, - ResolvedModelUsed: cr.ExtraFields.ResolvedModelUsed, - Latency: cr.ExtraFields.Latency, - RawResponse: cr.ExtraFields.RawResponse, - CacheDebug: cr.ExtraFields.CacheDebug, - ProviderResponseHeaders: cr.ExtraFields.ProviderResponseHeaders, - }, - } -} - // ChatParameters represents the parameters for a chat completion. type ChatParameters struct { Audio *ChatAudioParameters `json:"audio,omitempty"` // Audio parameters @@ -531,7 +412,6 @@ type AdditionalPropertiesStruct struct { // MarshalJSON implements custom JSON marshalling for AdditionalPropertiesStruct. // It marshals either AdditionalPropertiesBool or AdditionalPropertiesMap based on which is set. 
func (a AdditionalPropertiesStruct) MarshalJSON() ([]byte, error) { - // if both are set, return an error if a.AdditionalPropertiesBool != nil && a.AdditionalPropertiesMap != nil { return nil, fmt.Errorf("both AdditionalPropertiesBool and AdditionalPropertiesMap are set; only one should be non-nil") @@ -1198,7 +1078,7 @@ type BifrostLLMUsage struct { CompletionTokens int `json:"completion_tokens,omitempty"` CompletionTokensDetails *ChatCompletionTokensDetails `json:"completion_tokens_details,omitempty"` TotalTokens int `json:"total_tokens"` - Cost *BifrostCost `json:"cost,omitempty"` //Only for the providers which support cost calculation + Cost *BifrostCost `json:"cost,omitempty"` // Only for the providers which support cost calculation } type ChatPromptTokensDetails struct { diff --git a/core/schemas/mux.go b/core/schemas/mux.go index f899f41739..24943d3fbd 100644 --- a/core/schemas/mux.go +++ b/core/schemas/mux.go @@ -1258,6 +1258,10 @@ func (responsesResp *BifrostResponsesResponse) ToBifrostChatResponse() *BifrostC Videos: responsesResp.Videos, } + if responsesResp.ID != nil { + chatResp.ID = *responsesResp.ID + } + // Create Choices from ResponsesResponse if len(responsesResp.Output) > 0 { // Convert ResponsesMessages back to ChatMessages @@ -2013,3 +2017,362 @@ func (cr *BifrostChatResponse) ToBifrostResponsesStreamResponse(state *ChatToRes return responses } + +// ToBifrostChatResponse converts a BifrostResponsesStreamResponse chunk to a BifrostChatResponse (chat.completion.chunk). 
+func (rsr *BifrostResponsesStreamResponse) ToBifrostChatResponse() *BifrostChatResponse { + if rsr == nil { + return nil + } + + extraFields := rsr.ExtraFields + extraFields.RequestType = ChatCompletionStreamRequest + + resp := &BifrostChatResponse{ + Object: "chat.completion.chunk", + ExtraFields: extraFields, + SearchResults: rsr.SearchResults, + Videos: rsr.Videos, + Citations: rsr.Citations, + } + + if rsr.Response != nil { + if rsr.Response.ID != nil { + resp.ID = *rsr.Response.ID + } + resp.Created = rsr.Response.CreatedAt + resp.Model = rsr.Response.Model + } + + switch rsr.Type { + case ResponsesStreamResponseTypeOutputTextDelta: + resp.Choices = []BifrostResponseChoice{ + { + Index: 0, + ChatStreamResponseChoice: &ChatStreamResponseChoice{ + Delta: &ChatStreamResponseChoiceDelta{ + Content: rsr.Delta, + }, + }, + }, + } + return resp + + case ResponsesStreamResponseTypeReasoningSummaryTextDelta: + resp.Choices = []BifrostResponseChoice{ + { + Index: 0, + ChatStreamResponseChoice: &ChatStreamResponseChoice{ + Delta: &ChatStreamResponseChoiceDelta{ + Reasoning: rsr.Delta, + }, + }, + }, + } + return resp + + case ResponsesStreamResponseTypeRefusalDelta: + resp.Choices = []BifrostResponseChoice{ + { + Index: 0, + ChatStreamResponseChoice: &ChatStreamResponseChoice{ + Delta: &ChatStreamResponseChoiceDelta{ + Refusal: rsr.Refusal, + }, + }, + }, + } + return resp + + case ResponsesStreamResponseTypeOutputItemAdded: + if rsr.Item == nil || rsr.Item.Type == nil { + resp.Choices = []BifrostResponseChoice{ + { + Index: 0, + ChatStreamResponseChoice: &ChatStreamResponseChoice{ + Delta: &ChatStreamResponseChoiceDelta{}, + }, + }, + } + return resp + } + + switch *rsr.Item.Type { + case ResponsesMessageTypeFunctionCall: + if rsr.Item.ResponsesToolMessage == nil { + resp.Choices = []BifrostResponseChoice{ + { + Index: 0, + ChatStreamResponseChoice: &ChatStreamResponseChoice{ + Delta: &ChatStreamResponseChoiceDelta{}, + }, + }, + } + return resp + } + funcType := 
"function" + var idx uint16 + if rsr.OutputIndex != nil && *rsr.OutputIndex > 0 { + idx = uint16(*rsr.OutputIndex - 1) + } + resp.Choices = []BifrostResponseChoice{ + { + Index: 0, + ChatStreamResponseChoice: &ChatStreamResponseChoice{ + Delta: &ChatStreamResponseChoiceDelta{ + ToolCalls: []ChatAssistantMessageToolCall{ + { + Index: idx, + Type: &funcType, + ID: rsr.Item.ResponsesToolMessage.CallID, + Function: ChatAssistantMessageToolCallFunction{ + Name: rsr.Item.ResponsesToolMessage.Name, + }, + }, + }, + }, + }, + }, + } + return resp + + case ResponsesMessageTypeMessage: + role := "assistant" + resp.Choices = []BifrostResponseChoice{ + { + Index: 0, + ChatStreamResponseChoice: &ChatStreamResponseChoice{ + Delta: &ChatStreamResponseChoiceDelta{ + Role: &role, + }, + }, + }, + } + return resp + + default: + // reasoning, file_search_call, web_search_call, etc. — no chat equivalent, + // actual content arrives via separate delta events. + resp.Choices = []BifrostResponseChoice{ + { + Index: 0, + ChatStreamResponseChoice: &ChatStreamResponseChoice{ + Delta: &ChatStreamResponseChoiceDelta{}, + }, + }, + } + return resp + } + + case ResponsesStreamResponseTypeFunctionCallArgumentsDelta: + if rsr.Delta == nil { + resp.Choices = []BifrostResponseChoice{ + { + Index: 0, + ChatStreamResponseChoice: &ChatStreamResponseChoice{ + Delta: &ChatStreamResponseChoiceDelta{}, + }, + }, + } + return resp + } + var idx uint16 + if rsr.OutputIndex != nil && *rsr.OutputIndex > 0 { + idx = uint16(*rsr.OutputIndex - 1) + } + + resp.Choices = []BifrostResponseChoice{ + { + Index: 0, + ChatStreamResponseChoice: &ChatStreamResponseChoice{ + Delta: &ChatStreamResponseChoiceDelta{ + ToolCalls: []ChatAssistantMessageToolCall{ + { + Index: idx, + Function: ChatAssistantMessageToolCallFunction{ + Arguments: *rsr.Delta, + }, + }, + }, + }, + }, + }, + } + return resp + + case ResponsesStreamResponseTypeCompleted, ResponsesStreamResponseTypeIncomplete: + finishReason := 
string(BifrostFinishReasonStop) + if rsr.Type == ResponsesStreamResponseTypeIncomplete { + finishReason = string(BifrostFinishReasonLength) + } + resp.Choices = []BifrostResponseChoice{ + { + Index: 0, + FinishReason: &finishReason, + ChatStreamResponseChoice: &ChatStreamResponseChoice{ + Delta: &ChatStreamResponseChoiceDelta{}, + }, + }, + } + if rsr.Response != nil { + if rsr.Response.Usage != nil { + resp.Usage = rsr.Response.Usage.ToBifrostLLMUsage() + } + // Check for tool_calls finish reason + if rsr.Type == ResponsesStreamResponseTypeCompleted { + for _, output := range rsr.Response.Output { + if output.Type != nil && *output.Type == ResponsesMessageTypeFunctionCall { + finishReason = string(BifrostFinishReasonToolCalls) + resp.Choices[0].FinishReason = &finishReason + break + } + } + } + } + return resp + + default: + // Lifecycle events (created, in_progress, content_part.added/done, output_text.done, + // output_item.done, function_call_arguments.done, etc.) → empty chat chunk with no content. 
+ resp.Choices = []BifrostResponseChoice{ + { + Index: 0, + ChatStreamResponseChoice: &ChatStreamResponseChoice{ + Delta: &ChatStreamResponseChoiceDelta{}, + }, + }, + } + return resp + } +} + +// ============================================================================= +// RESPONSE CONVERSION METHODS +// ============================================================================= + +// ToBifrostTextCompletionResponse converts a BifrostChatResponse to a BifrostTextCompletionResponse +func (cr *BifrostChatResponse) ToBifrostTextCompletionResponse() *BifrostTextCompletionResponse { + if cr == nil { + return nil + } + + if len(cr.Choices) == 0 { + return &BifrostTextCompletionResponse{ + ID: cr.ID, + Model: cr.Model, + Object: "text_completion", + SystemFingerprint: cr.SystemFingerprint, + Usage: cr.Usage, + ExtraFields: BifrostResponseExtraFields{ + RequestType: TextCompletionRequest, + ChunkIndex: cr.ExtraFields.ChunkIndex, + Provider: cr.ExtraFields.Provider, + OriginalModelRequested: cr.ExtraFields.OriginalModelRequested, + Latency: cr.ExtraFields.Latency, + RawResponse: cr.ExtraFields.RawResponse, + CacheDebug: cr.ExtraFields.CacheDebug, + ProviderResponseHeaders: cr.ExtraFields.ProviderResponseHeaders, + }, + } + } + + choice := cr.Choices[0] + + // Handle streaming response choice + if choice.ChatStreamResponseChoice != nil && choice.ChatStreamResponseChoice.Delta != nil { + return &BifrostTextCompletionResponse{ + ID: cr.ID, + Model: cr.Model, + Object: "text_completion", + SystemFingerprint: cr.SystemFingerprint, + Choices: []BifrostResponseChoice{ + { + Index: 0, + TextCompletionResponseChoice: &TextCompletionResponseChoice{ + Text: choice.ChatStreamResponseChoice.Delta.Content, + }, + FinishReason: choice.FinishReason, + LogProbs: choice.LogProbs, + }, + }, + Usage: cr.Usage, + ExtraFields: BifrostResponseExtraFields{ + RequestType: TextCompletionRequest, + ChunkIndex: cr.ExtraFields.ChunkIndex, + Provider: cr.ExtraFields.Provider, + 
OriginalModelRequested: cr.ExtraFields.OriginalModelRequested, + Latency: cr.ExtraFields.Latency, + RawResponse: cr.ExtraFields.RawResponse, + CacheDebug: cr.ExtraFields.CacheDebug, + ProviderResponseHeaders: cr.ExtraFields.ProviderResponseHeaders, + }, + } + } + + // Handle non-streaming response choice + if choice.ChatNonStreamResponseChoice != nil { + msg := choice.ChatNonStreamResponseChoice.Message + var textContent *string + if msg != nil && msg.Content != nil { + if msg.Content.ContentStr != nil { + textContent = msg.Content.ContentStr + } else if len(msg.Content.ContentBlocks) > 0 { + var sb strings.Builder + for _, block := range msg.Content.ContentBlocks { + if block.Text != nil { + sb.WriteString(*block.Text) + } + } + if sb.Len() > 0 { + s := sb.String() + textContent = &s + } + } + } + return &BifrostTextCompletionResponse{ + ID: cr.ID, + Model: cr.Model, + Object: "text_completion", + SystemFingerprint: cr.SystemFingerprint, + Choices: []BifrostResponseChoice{ + { + Index: 0, + TextCompletionResponseChoice: &TextCompletionResponseChoice{ + Text: textContent, + }, + FinishReason: choice.FinishReason, + LogProbs: choice.LogProbs, + }, + }, + Usage: cr.Usage, + ExtraFields: BifrostResponseExtraFields{ + RequestType: TextCompletionRequest, + ChunkIndex: cr.ExtraFields.ChunkIndex, + Provider: cr.ExtraFields.Provider, + OriginalModelRequested: cr.ExtraFields.OriginalModelRequested, + Latency: cr.ExtraFields.Latency, + RawResponse: cr.ExtraFields.RawResponse, + CacheDebug: cr.ExtraFields.CacheDebug, + ProviderResponseHeaders: cr.ExtraFields.ProviderResponseHeaders, + }, + } + } + + // Fallback case - return basic response structure + return &BifrostTextCompletionResponse{ + ID: cr.ID, + Model: cr.Model, + Object: "text_completion", + SystemFingerprint: cr.SystemFingerprint, + Usage: cr.Usage, + ExtraFields: BifrostResponseExtraFields{ + RequestType: TextCompletionRequest, + ChunkIndex: cr.ExtraFields.ChunkIndex, + Provider: cr.ExtraFields.Provider, + 
OriginalModelRequested: cr.ExtraFields.OriginalModelRequested, + Latency: cr.ExtraFields.Latency, + RawResponse: cr.ExtraFields.RawResponse, + CacheDebug: cr.ExtraFields.CacheDebug, + ProviderResponseHeaders: cr.ExtraFields.ProviderResponseHeaders, + }, + } +} diff --git a/core/utils.go b/core/utils.go index adf5a57066..444b3d9cb3 100644 --- a/core/utils.go +++ b/core/utils.go @@ -281,6 +281,7 @@ func clearCtxForFallback(ctx *schemas.BifrostContext) { ctx.ClearValue(schemas.BifrostContextKeyAPIKeyID) ctx.ClearValue(schemas.BifrostContextKeyAPIKeyName) ctx.ClearValue(schemas.BifrostContextKeyGovernanceIncludeOnlyKeys) + ctx.ClearValue(schemas.BifrostContextKeyChangeRequestType) } var supportedBaseProvidersSet = func() map[schemas.ModelProvider]struct{} { @@ -604,3 +605,30 @@ func isPromptOptionalImageEditType(t *string) bool { normalized, ) } + +// wrapConvertedStreamPostHookRunner wraps a PostHookRunner so that streaming +// responses produced by a type-converted request are converted back to the +// caller's original type before the post-hook runs. 
+func wrapConvertedStreamPostHookRunner(postHookRunner schemas.PostHookRunner, targetType schemas.RequestType) schemas.PostHookRunner { + return func(ctx *schemas.BifrostContext, result *schemas.BifrostResponse, bifrostErr *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError) { + if result != nil { + switch targetType { + case schemas.ChatCompletionRequest: + // text→chat: convert chat stream chunk back to text completion + if result.ChatResponse != nil { + if converted := result.ChatResponse.ToBifrostTextCompletionResponse(); converted != nil { + result = &schemas.BifrostResponse{TextCompletionResponse: converted} + } + } + case schemas.ResponsesRequest: + // chat→responses: convert responses stream chunk back to chat + if result.ResponsesStreamResponse != nil { + if converted := result.ResponsesStreamResponse.ToBifrostChatResponse(); converted != nil { + result = &schemas.BifrostResponse{ChatResponse: converted} + } + } + } + } + return postHookRunner(ctx, result, bifrostErr) + } +} \ No newline at end of file diff --git a/docs/features/litellm-compat.mdx b/docs/features/litellm-compat.mdx index 51cd26dcd9..490a37efa4 100644 --- a/docs/features/litellm-compat.mdx +++ b/docs/features/litellm-compat.mdx @@ -9,8 +9,10 @@ icon: "train" The LiteLLM compatibility plugin provides two transformations: 1. **Text-to-Chat Conversion** - Automatically converts text completion requests to chat completion format for models that only support chat APIs +2. **Chat-to-Responses Conversion** - Automatically converts chat completion requests to responses format for models that only support responses APIs +3. **Drop Unsupported Params** - Automatically drops unsupported parameters if the model doesn't support them -When either transformation is applied, responses include `extra_fields.litellm_compat: true`. +When either transformation is applied, responses include `extra_fields.converted_request_type: `. 
If request parameters are dropped, the keys are added in `extra_fields.dropped_compat_plugin_params`. --- @@ -55,6 +57,36 @@ F --> G - `object: "chat.completion"` → `object: "text_completion"` - Usage statistics and metadata are preserved +## 2. Chat-to-Responses Conversion + +Some AI models (like OpenAI o1-pro) only support the responses API and don't support native chat completion endpoints. LiteLLM compatibility mode automatically handles this by: + +1. Checking if the model supports chat completion natively (using the model catalog) +2. If not supported, converting your chat message to responses API format +3. Calling the responses endpoint internally +4. Transforming the response back to chat completion format + + +**Smart Conversion**: The conversion only happens when the model doesn't support chat completions natively. If a model has native chat completion support (like OpenAI's gpt-4 models), Bifrost uses the chat completion endpoint directly without any conversion. + + +This allows you to use a unified chat completion interface across all providers, even those that only support responses API. + +## How It Works + +When LiteLLM compatibility is enabled and you make a chat completion request, Bifrost first checks if the model supports chat completion: + +```mermaid +flowchart LR +A[Chat Completion Request] --> B{Model Supports Chat Completion?} +B -->|Yes| C[Call Chat Completion API] +B -->|No| D[Convert to Responses Message] +D --> E[Call Responses API] +E --> F[Transform Response] +C --> G[Chat Completion Response] +F --> G +``` + ## Enabling LiteLLM Compatibility @@ -63,7 +95,10 @@ F --> G 1. Open the Bifrost dashboard 2. Navigate to **Settings** → **Client Configuration** -3. Enable **LiteLLM Fallbacks** +3. 
Expand **LiteLLM Compat** and enable the features you need: + - **Convert Text to Chat** — converts text completion requests to chat for models that only support chat + - **Convert Chat to Responses** — converts chat completion requests to responses for models that only support responses + - **Drop Unsupported Params** — drops unsupported parameters based on model catalog allowlist 4. Save your configuration @@ -73,7 +108,11 @@ F --> G ```json { "client_config": { - "enable_litellm_fallbacks": true + "compat": { + "convert_text_to_chat": true, + "convert_chat_to_responses": true, + "should_drop_params": true + } } } ``` @@ -84,9 +123,9 @@ F --> G ## Supported Providers -LiteLLM compatibility mode works with any provider that supports chat completions but lacks native text completion support: +Text completion to chat completion conversion works with any provider that supports chat completions but lacks native text completion support: -| Provider | Native Text Completion | LiteLLM Fallback | +| Provider | Native Text Completion | With Fallback | |----------|----------------------|------------------| | OpenAI (GPT-4, GPT-3.5-turbo) | No | Yes | | Anthropic (Claude) | No | Yes | @@ -95,6 +134,12 @@ LiteLLM compatibility mode works with any provider that supports chat completion | Mistral | No | Yes | | Bedrock | Varies by model | Yes | +Chat completion to responses conversion works with any provider that supports responses but lacks native chat completion support: + +| Provider | Native Chat Completion | With Fallback | +|----------|----------------------|------------------| +| OpenAI (o1-pro) | No | Yes | + ## Behavior Details **Model Capability Detection:** @@ -117,13 +162,19 @@ LiteLLM compatibility mode works with any provider that supports chat completion | Response | `choices[0].message.content` | `choices[0].text` | | Response | `object: "chat.completion"` | `object: "text_completion"` | +### Transformation 2: Chat-to-Responses Conversion + +**Applies to:** Chat 
completion requests on responses-only models + +| Phase | Original | Transformed | +|-------|----------|-------------| +| Request | Chat message with `role: "user"` | Responses input with `role: "user"` | +| Request | `chat_completion` request type | `responses` request type | ### Metadata Set on Transformed Responses When either transformation is applied: -- `extra_fields.litellm_compat`: Set to `true` -- `extra_fields.provider`: The provider that handled the request - `extra_fields.request_type`: Reflects the original request type - `extra_fields.original_model_requested`: The originally requested model - `extra_fields.resolved_model_used`: The actual provider API identifier used (equals original_model_requested when no alias mapping exists) @@ -131,8 +182,11 @@ When either transformation is applied: ### Error Handling When errors occur on transformed requests: -- `extra_fields.litellm_compat` is set to `true` - Original request type and model are preserved in error metadata +- `extra_fields.converted_request_type`: Set to type of request that was converted to (i.e., `chat_completion` or `responses`) +- `extra_fields.provider`: The provider that handled the request +- `extra_fields.original_model_requested`: The originally requested model +- `extra_fields.dropped_compat_plugin_params`: If any unsupported parameters were dropped, the keys are added here ## What's Preserved @@ -145,7 +199,7 @@ When errors occur on transformed requests: **Good Use Cases:** - Migrating from LiteLLM to Bifrost without code changes -- Maintaining backward compatibility with text completion interfaces +- Maintaining backward compatibility with text completion interfaces or chat completion interfaces - Using a unified API across providers with different capabilities **Consider Alternatives When:** @@ -157,4 +211,4 @@ When errors occur on transformed requests: - [Fallbacks](/features/fallbacks) - Automatic provider failover - [Drop-in Replacement](/features/drop-in-replacement) - Use 
existing SDKs with Bifrost -- [LiteLLM Integration](/integrations/litellm-sdk) - Using LiteLLM SDK with Bifrost +- [LiteLLM Integration](/integrations/litellm-sdk) - Using LiteLLM SDK with Bifrost \ No newline at end of file diff --git a/docs/openapi/openapi.json b/docs/openapi/openapi.json index 1043039a1f..e258d653b5 100644 --- a/docs/openapi/openapi.json +++ b/docs/openapi/openapi.json @@ -133221,9 +133221,15 @@ "type": "integer", "description": "Maximum request body size in MB" }, - "enable_litellm_fallbacks": { - "type": "boolean", - "description": "Whether LiteLLM fallbacks are enabled" + "compat": { + "type": "object", + "description": "Compat plugin configuration", + "properties": { + "convert_text_to_chat": { "type": "boolean", "description": "Convert text completion requests to chat" }, + "convert_chat_to_responses": { "type": "boolean", "description": "Convert chat completion requests to responses" }, + "should_drop_params": { "type": "boolean", "description": "Drop unsupported parameters based on model catalog" }, + "should_convert_params": { "type": "boolean", "description": "Converts model parameter values that are not supported by the model.", "default": false } + } }, "log_retention_days": { "type": "integer", @@ -133537,9 +133543,15 @@ "type": "integer", "description": "Maximum request body size in MB" }, - "enable_litellm_fallbacks": { - "type": "boolean", - "description": "Whether LiteLLM fallbacks are enabled" + "compat": { + "type": "object", + "description": "Compat plugin configuration", + "properties": { + "convert_text_to_chat": { "type": "boolean", "description": "Convert text completion requests to chat" }, + "convert_chat_to_responses": { "type": "boolean", "description": "Convert chat completion requests to responses" }, + "should_drop_params": { "type": "boolean", "description": "Drop unsupported parameters based on model catalog" }, + "should_convert_params": { "type": "boolean", "description": "Converts model parameter values that 
are not supported by the model.", "default": false } + } }, "log_retention_days": { "type": "integer", @@ -205784,9 +205796,15 @@ "type": "integer", "description": "Maximum request body size in MB" }, - "enable_litellm_fallbacks": { - "type": "boolean", - "description": "Whether LiteLLM fallbacks are enabled" + "compat": { + "type": "object", + "description": "Compat plugin configuration", + "properties": { + "convert_text_to_chat": { "type": "boolean", "description": "Convert text completion requests to chat" }, + "convert_chat_to_responses": { "type": "boolean", "description": "Convert chat completion requests to responses" }, + "should_drop_params": { "type": "boolean", "description": "Drop unsupported parameters based on model catalog" }, + "should_convert_params": { "type": "boolean", "description": "Converts model parameter values that are not supported by the model.", "default": false } + } }, "log_retention_days": { "type": "integer", @@ -205999,9 +206017,15 @@ "type": "integer", "description": "Maximum request body size in MB" }, - "enable_litellm_fallbacks": { - "type": "boolean", - "description": "Whether LiteLLM fallbacks are enabled" + "compat": { + "type": "object", + "description": "Compat plugin configuration", + "properties": { + "convert_text_to_chat": { "type": "boolean", "description": "Convert text completion requests to chat" }, + "convert_chat_to_responses": { "type": "boolean", "description": "Convert chat completion requests to responses" }, + "should_drop_params": { "type": "boolean", "description": "Drop unsupported parameters based on model catalog" }, + "should_convert_params": { "type": "boolean", "description": "Converts model parameter values that are not supported by the model.", "default": false } + } }, "log_retention_days": { "type": "integer", @@ -224498,4 +224522,4 @@ } } } -} \ No newline at end of file +} diff --git a/docs/openapi/schemas/management/config.yaml b/docs/openapi/schemas/management/config.yaml index 
2c54b3979d..eaafb3821f 100644 --- a/docs/openapi/schemas/management/config.yaml +++ b/docs/openapi/schemas/management/config.yaml @@ -44,9 +44,24 @@ ClientConfig: max_request_body_size_mb: type: integer description: Maximum request body size in MB - enable_litellm_fallbacks: - type: boolean - description: Whether LiteLLM fallbacks are enabled + compat: + type: object + description: Compat plugin configuration + properties: + convert_text_to_chat: + type: boolean + description: Convert text completion requests to chat + convert_chat_to_responses: + type: boolean + description: Convert chat completion requests to responses + should_drop_params: + type: boolean + description: Drop unsupported parameters based on model catalog + should_convert_params: + type: boolean + default: false + description: Converts model parameter values that are not supported by the model + additionalProperties: false log_retention_days: type: integer description: Number of days to retain logs diff --git a/docs/providers/supported-providers/overview.mdx b/docs/providers/supported-providers/overview.mdx index b3ae42f62f..98d13ffa73 100644 --- a/docs/providers/supported-providers/overview.mdx +++ b/docs/providers/supported-providers/overview.mdx @@ -48,7 +48,7 @@ The following table summarizes which operations are supported by each provider v Some operations are not supported by the downstream provider, and their internal implementation in Bifrost is optional. 🟡 -Like Text completions are not supported by Groq, but Bifrost can emulate them internally using the Chat Completions API. This feature is disabled by default, but it can be enabled by setting the `enable_litellm_fallbacks` flag to `true` in the client configuration. +Like Text completions are not supported by Groq, but Bifrost can emulate them internally using the Chat Completions API. This feature is disabled by default, but it can be enabled by setting `compat.convert_text_to_chat` to `true` in the client configuration. 
We do not promote using such fallbacks, since text completions and chat completions are fundamentally different. However, this option is available to help users migrating from LiteLLM (which does support these fallbacks). diff --git a/examples/configs/withpostgresmcpclientsinconfig/config.json b/examples/configs/withpostgresmcpclientsinconfig/config.json index 8e03969988..068bc88012 100644 --- a/examples/configs/withpostgresmcpclientsinconfig/config.json +++ b/examples/configs/withpostgresmcpclientsinconfig/config.json @@ -7,7 +7,6 @@ ], "disable_content_logging": false, "drop_excess_requests": false, - "enable_litellm_fallbacks": false, "enable_logging": true, "enforce_auth_on_inference": true, "initial_pool_size": 300, diff --git a/examples/configs/withprompushgateway/config.json b/examples/configs/withprompushgateway/config.json index f697041388..110557d797 100644 --- a/examples/configs/withprompushgateway/config.json +++ b/examples/configs/withprompushgateway/config.json @@ -183,8 +183,7 @@ "enable_logging": true, "enforce_auth_on_inference": false, "allow_direct_keys": false, - "max_request_body_size_mb": 100, - "enable_litellm_fallbacks": false + "max_request_body_size_mb": 100 }, "config_store": { "enabled": true, diff --git a/examples/configs/withvirtualkeys/config.json b/examples/configs/withvirtualkeys/config.json index a968bad65c..9d9ae2c87a 100644 --- a/examples/configs/withvirtualkeys/config.json +++ b/examples/configs/withvirtualkeys/config.json @@ -7,7 +7,6 @@ ], "disable_content_logging": false, "drop_excess_requests": false, - "enable_litellm_fallbacks": false, "enable_logging": true, "enforce_auth_on_inference": true, "initial_pool_size": 300, diff --git a/examples/dockers/data/config.json b/examples/dockers/data/config.json index 46cbfd8e68..072691c2ea 100644 --- a/examples/dockers/data/config.json +++ b/examples/dockers/data/config.json @@ -27,7 +27,9 @@ "*" ], "max_request_body_size_mb": 100, - "enable_litellm_fallbacks": false + "compat": { + 
"should_convert_params": false + } }, "framework": { "pricing": { @@ -35,4 +37,4 @@ "pricing_sync_interval": 86400 } } -} \ No newline at end of file +} diff --git a/framework/configstore/clientconfig.go b/framework/configstore/clientconfig.go index 698437c328..90a27db25e 100644 --- a/framework/configstore/clientconfig.go +++ b/framework/configstore/clientconfig.go @@ -34,6 +34,14 @@ type EnvKeyInfo struct { KeyID string // The key ID this env var belongs to (empty for non-key configs like bedrock_config, connection_string) } +// CompatConfig holds the compat plugin feature flags. +type CompatConfig struct { + ConvertTextToChat bool `json:"convert_text_to_chat"` + ConvertChatToResponses bool `json:"convert_chat_to_responses"` + ShouldDropParams bool `json:"should_drop_params"` + ShouldConvertParams bool `json:"should_convert_params"` +} + // ClientConfig represents the core configuration for Bifrost HTTP transport and the Bifrost Client. // It includes settings for excess request handling, Prometheus metrics, and initial pool size. 
type ClientConfig struct { @@ -51,7 +59,7 @@ type ClientConfig struct { AllowedOrigins []string `json:"allowed_origins,omitempty"` // Additional allowed origins for CORS and WebSocket (localhost is always allowed) AllowedHeaders []string `json:"allowed_headers,omitempty"` // Additional allowed headers for CORS and WebSocket MaxRequestBodySizeMB int `json:"max_request_body_size_mb"` // The maximum request body size in MB - EnableLiteLLMFallbacks bool `json:"enable_litellm_fallbacks"` // Enable litellm-specific fallbacks for text completion for Groq + Compat CompatConfig `json:"compat"` // Compat plugin configuration MCPAgentDepth int `json:"mcp_agent_depth"` // The maximum depth for MCP agent mode tool execution MCPToolExecutionTimeout int `json:"mcp_tool_execution_timeout"` // The timeout for individual tool execution in seconds MCPCodeModeBindingLevel string `json:"mcp_code_mode_binding_level"` // Code mode binding level: "server" or "tool" @@ -110,10 +118,17 @@ func (c *ClientConfig) GenerateClientConfigHash() (string, error) { hash.Write([]byte("allowDirectKeys:false")) } - if c.EnableLiteLLMFallbacks { - hash.Write([]byte("enableLiteLLMFallbacks:true")) - } else { - hash.Write([]byte("enableLiteLLMFallbacks:false")) + if c.Compat.ConvertTextToChat { + hash.Write([]byte("compatConvertTextToChat:true")) + } + if c.Compat.ConvertChatToResponses { + hash.Write([]byte("compatConvertChatToResponses:true")) + } + if c.Compat.ShouldDropParams { + hash.Write([]byte("compatShouldDropParams:true")) + } + if c.Compat.ShouldConvertParams { + hash.Write([]byte("compatShouldConvertParams:true")) } // Only hash non-default value to avoid legacy config hash churn. 
diff --git a/framework/configstore/migrations.go b/framework/configstore/migrations.go index 8ff803cc99..e7281c4d19 100644 --- a/framework/configstore/migrations.go +++ b/framework/configstore/migrations.go @@ -377,9 +377,15 @@ func triggerMigrations(ctx context.Context, db *gorm.DB) error { if err := migrationAddWhitelistedRoutesJSONColumn(ctx, db); err != nil { return err } + if err := migrationReplaceEnableLiteLLMWithCompatColumns(ctx, db); err != nil { + return err + } if err := migrationAddModelPricingUniqueIndex(ctx, db); err != nil { return err } + if err := migrationDefaultCompatShouldConvertParamsFalse(ctx, db); err != nil { + return err + } return nil } @@ -789,9 +795,10 @@ func migrationAddEnableLiteLLMFallbacksColumn(ctx context.Context, db *gorm.DB) ID: "add_enable_litellm_fallbacks_column", Migrate: func(tx *gorm.DB) error { tx = tx.WithContext(ctx) - migrator := tx.Migrator() - if !migrator.HasColumn(&tables.TableClientConfig{}, "enable_litellm_fallbacks") { - if err := migrator.AddColumn(&tables.TableClientConfig{}, "enable_litellm_fallbacks"); err != nil { + // Use raw SQL since the struct field was removed in a later migration. + // This column is subsequently dropped by migrationReplaceEnableLiteLLMWithCompatColumns. 
+ if !tx.Migrator().HasColumn(&tables.TableClientConfig{}, "enable_litellm_fallbacks") { + if err := tx.Exec("ALTER TABLE config_client ADD COLUMN enable_litellm_fallbacks BOOLEAN DEFAULT FALSE").Error; err != nil { return err } } @@ -799,9 +806,7 @@ func migrationAddEnableLiteLLMFallbacksColumn(ctx context.Context, db *gorm.DB) }, Rollback: func(tx *gorm.DB) error { tx = tx.WithContext(ctx) - migrator := tx.Migrator() - - if err := migrator.DropColumn(&tables.TableClientConfig{}, "enable_litellm_fallbacks"); err != nil { + if err := tx.Exec("ALTER TABLE config_client DROP COLUMN IF EXISTS enable_litellm_fallbacks").Error; err != nil { return err } return nil @@ -2166,7 +2171,6 @@ func migrationAddAdditionalConfigHashColumns(ctx context.Context, db *gorm.DB) e AllowDirectKeys: cc.AllowDirectKeys, AllowedOrigins: cc.AllowedOrigins, MaxRequestBodySizeMB: cc.MaxRequestBodySizeMB, - EnableLiteLLMFallbacks: cc.EnableLiteLLMFallbacks, } hash, err := clientConfig.GenerateClientConfigHash() if err != nil { @@ -5674,7 +5678,6 @@ func migrationAddRoutingChainMaxDepthColumn(ctx context.Context, db *gorm.DB) er AllowedOrigins: cc.AllowedOrigins, AllowedHeaders: cc.AllowedHeaders, MaxRequestBodySizeMB: cc.MaxRequestBodySizeMB, - EnableLiteLLMFallbacks: cc.EnableLiteLLMFallbacks, HideDeletedVirtualKeysInFilters: cc.HideDeletedVirtualKeysInFilters, MCPAgentDepth: cc.MCPAgentDepth, MCPToolExecutionTimeout: cc.MCPToolExecutionTimeout, @@ -5970,7 +5973,6 @@ func migrationAddMultiBudgetTables(ctx context.Context, db *gorm.DB) error { if mg.HasColumn(&tables.TableBudget{}, "provider_config_id") { if err := mg.DropColumn(&tables.TableBudget{}, "provider_config_id"); err != nil { return err - } } return nil @@ -6126,21 +6128,155 @@ func migrationAddWhitelistedRoutesJSONColumn(ctx context.Context, db *gorm.DB) e return fmt.Errorf("failed to add whitelisted_routes_json column: %w", err) } } + return nil }, Rollback: func(tx *gorm.DB) error { tx = tx.WithContext(ctx) migrator := 
tx.Migrator() + if migrator.HasColumn(&tables.TableClientConfig{}, "whitelisted_routes_json") { if err := migrator.DropColumn(&tables.TableClientConfig{}, "whitelisted_routes_json"); err != nil { return fmt.Errorf("failed to drop whitelisted_routes_json column: %w", err) } } + + return nil + }, + }}) + if err := m.Migrate(); err != nil { + return fmt.Errorf("error running whitelisted_routes_json migration: %s", err.Error()) + } + return nil +} + +// migrationReplaceEnableLiteLLMWithCompatColumns replaces the single enable_litellm_fallbacks +// boolean with compat feature columns. If enable_litellm_fallbacks was true, +// only convert_text_to_chat is set to true (preserving the original behavior). +func migrationReplaceEnableLiteLLMWithCompatColumns(ctx context.Context, db *gorm.DB) error { + m := migrator.New(db, migrator.DefaultOptions, []*migrator.Migration{{ + ID: "replace_enable_litellm_with_compat_columns", + Migrate: func(tx *gorm.DB) error { + tx = tx.WithContext(ctx) + mig := tx.Migrator() + + // Add new columns + if !mig.HasColumn(&tables.TableClientConfig{}, "compat_convert_text_to_chat") { + if err := mig.AddColumn(&tables.TableClientConfig{}, "compat_convert_text_to_chat"); err != nil { + return err + } + } + if !mig.HasColumn(&tables.TableClientConfig{}, "compat_convert_chat_to_responses") { + if err := mig.AddColumn(&tables.TableClientConfig{}, "compat_convert_chat_to_responses"); err != nil { + return err + } + } + if !mig.HasColumn(&tables.TableClientConfig{}, "compat_should_drop_params") { + if err := mig.AddColumn(&tables.TableClientConfig{}, "compat_should_drop_params"); err != nil { + return err + } + } + if !mig.HasColumn(&tables.TableClientConfig{}, "compat_should_convert_params") { + if err := mig.AddColumn(&tables.TableClientConfig{}, "compat_should_convert_params"); err != nil { + return err + } + } + + if err := tx.Exec("UPDATE config_client SET compat_should_convert_params = FALSE").Error; err != nil { + return err + } + + // Migrate 
data: if enable_litellm_fallbacks was true, set convert_text_to_chat = true
+			if mig.HasColumn(&tables.TableClientConfig{}, "enable_litellm_fallbacks") {
+				if err := tx.Exec("UPDATE config_client SET compat_convert_text_to_chat = enable_litellm_fallbacks").Error; err != nil {
+					return err
+				}
+				if err := mig.DropColumn(&tables.TableClientConfig{}, "enable_litellm_fallbacks"); err != nil {
+					return err
+				}
+			}
+
+			return nil
+		},
+		Rollback: func(tx *gorm.DB) error {
+			tx = tx.WithContext(ctx)
+			mig := tx.Migrator()
+			// Re-add the legacy column only when it is absent; the previous
+			// positive HasColumn guard meant ADD COLUMN never ran and the
+			// restore UPDATE below would reference a missing column.
+			if !tx.Migrator().HasColumn(&tables.TableClientConfig{}, "enable_litellm_fallbacks") {
+				if err := tx.Exec("ALTER TABLE config_client ADD COLUMN enable_litellm_fallbacks BOOLEAN DEFAULT FALSE").Error; err != nil {
+					return err
+				}
+			}
+			if mig.HasColumn(&tables.TableClientConfig{}, "compat_convert_text_to_chat") {
+				if err := tx.Exec("UPDATE config_client SET enable_litellm_fallbacks = COALESCE(compat_convert_text_to_chat, FALSE)").Error; err != nil {
+					return err
+				}
+			}
+			for _, col := range []string{
+				"compat_convert_text_to_chat",
+				"compat_convert_chat_to_responses",
+				"compat_should_drop_params",
+				"compat_should_convert_params",
+			} {
+				if mig.HasColumn(&tables.TableClientConfig{}, col) {
+					if err := mig.DropColumn(&tables.TableClientConfig{}, col); err != nil {
+						return err
+					}
+				}
+			}
+			return nil
+		},
+	}})
+	if err := m.Migrate(); err != nil {
+		return fmt.Errorf("error while running replace_enable_litellm_with_compat_columns migration: %s", err.Error())
+	}
+	return nil
+}
+
+// migrationDefaultCompatShouldConvertParamsFalse ensures existing deployments
+// converge to the new default for compat_should_convert_params. The earlier
+// compat migration may already be marked as applied, so changing its body is not
+// sufficient for installed databases.
+func migrationDefaultCompatShouldConvertParamsFalse(ctx context.Context, db *gorm.DB) error { + m := migrator.New(db, migrator.DefaultOptions, []*migrator.Migration{{ + ID: "default_compat_should_convert_params_false", + Migrate: func(tx *gorm.DB) error { + tx = tx.WithContext(ctx) + mig := tx.Migrator() + + if !mig.HasColumn(&tables.TableClientConfig{}, "compat_should_convert_params") { + return nil + } + + if err := tx.Exec("UPDATE config_client SET compat_should_convert_params = FALSE").Error; err != nil { + return err + } + + if err := mig.AlterColumn(&tables.TableClientConfig{}, "CompatShouldConvertParams"); err != nil { + return fmt.Errorf("failed to alter compat_should_convert_params default: %w", err) + } + + return nil + }, + Rollback: func(tx *gorm.DB) error { + tx = tx.WithContext(ctx) + mig := tx.Migrator() + + if !mig.HasColumn(&tables.TableClientConfig{}, "compat_should_convert_params") { + return nil + } + + switch tx.Dialector.Name() { + case "postgres": + if err := tx.Exec("ALTER TABLE config_client ALTER COLUMN compat_should_convert_params SET DEFAULT FALSE").Error; err != nil { + return err + } + } + return nil }, }}) if err := m.Migrate(); err != nil { - return fmt.Errorf("error running add_whitelisted_routes_json_column migration: %s", err.Error()) + return fmt.Errorf("error running default_compat_should_convert_params_false migration: %s", err.Error()) } return nil } diff --git a/framework/configstore/rdb.go b/framework/configstore/rdb.go index a0a82912cb..6631da5b1c 100644 --- a/framework/configstore/rdb.go +++ b/framework/configstore/rdb.go @@ -137,7 +137,10 @@ func (s *RDBConfigStore) UpdateClientConfig(ctx context.Context, config *ClientC AllowedOrigins: config.AllowedOrigins, AllowedHeaders: config.AllowedHeaders, MaxRequestBodySizeMB: config.MaxRequestBodySizeMB, - EnableLiteLLMFallbacks: config.EnableLiteLLMFallbacks, + CompatConvertTextToChat: config.Compat.ConvertTextToChat, + CompatConvertChatToResponses: 
config.Compat.ConvertChatToResponses, + CompatShouldDropParams: config.Compat.ShouldDropParams, + CompatShouldConvertParams: config.Compat.ShouldConvertParams, MCPAgentDepth: config.MCPAgentDepth, MCPToolExecutionTimeout: config.MCPToolExecutionTimeout, MCPCodeModeBindingLevel: config.MCPCodeModeBindingLevel, @@ -289,21 +292,26 @@ func (s *RDBConfigStore) GetClientConfig(ctx context.Context) (*ClientConfig, er return nil, err } return &ClientConfig{ - DropExcessRequests: dbConfig.DropExcessRequests, - InitialPoolSize: dbConfig.InitialPoolSize, - PrometheusLabels: dbConfig.PrometheusLabels, - EnableLogging: dbConfig.EnableLogging, - DisableContentLogging: dbConfig.DisableContentLogging, - DisableDBPingsInHealth: dbConfig.DisableDBPingsInHealth, - LogRetentionDays: dbConfig.LogRetentionDays, - EnforceAuthOnInference: dbConfig.EnforceAuthOnInference, - EnforceGovernanceHeader: dbConfig.EnforceGovernanceHeader, - EnforceSCIMAuth: dbConfig.EnforceSCIMAuth, - AllowDirectKeys: dbConfig.AllowDirectKeys, - AllowedOrigins: dbConfig.AllowedOrigins, - AllowedHeaders: dbConfig.AllowedHeaders, - MaxRequestBodySizeMB: dbConfig.MaxRequestBodySizeMB, - EnableLiteLLMFallbacks: dbConfig.EnableLiteLLMFallbacks, + DropExcessRequests: dbConfig.DropExcessRequests, + InitialPoolSize: dbConfig.InitialPoolSize, + PrometheusLabels: dbConfig.PrometheusLabels, + EnableLogging: dbConfig.EnableLogging, + DisableContentLogging: dbConfig.DisableContentLogging, + DisableDBPingsInHealth: dbConfig.DisableDBPingsInHealth, + LogRetentionDays: dbConfig.LogRetentionDays, + EnforceAuthOnInference: dbConfig.EnforceAuthOnInference, + EnforceGovernanceHeader: dbConfig.EnforceGovernanceHeader, + EnforceSCIMAuth: dbConfig.EnforceSCIMAuth, + AllowDirectKeys: dbConfig.AllowDirectKeys, + AllowedOrigins: dbConfig.AllowedOrigins, + AllowedHeaders: dbConfig.AllowedHeaders, + MaxRequestBodySizeMB: dbConfig.MaxRequestBodySizeMB, + Compat: CompatConfig{ + ConvertTextToChat: dbConfig.CompatConvertTextToChat, + 
ConvertChatToResponses: dbConfig.CompatConvertChatToResponses, + ShouldDropParams: dbConfig.CompatShouldDropParams, + ShouldConvertParams: dbConfig.CompatShouldConvertParams, + }, MCPAgentDepth: dbConfig.MCPAgentDepth, MCPToolExecutionTimeout: dbConfig.MCPToolExecutionTimeout, MCPCodeModeBindingLevel: dbConfig.MCPCodeModeBindingLevel, @@ -4461,4 +4469,4 @@ func (s *RDBConfigStore) TransferOauthUserTokensFromGatewaySession(ctx context.C } s.logger.Debug("[rdb] TransferOauthUserTokensFromGatewaySession done: rows_affected=%d", result.RowsAffected) return nil -} +} \ No newline at end of file diff --git a/framework/configstore/tables/clientconfig.go b/framework/configstore/tables/clientconfig.go index a9ff7fc7f6..7dafc96f8e 100644 --- a/framework/configstore/tables/clientconfig.go +++ b/framework/configstore/tables/clientconfig.go @@ -37,8 +37,11 @@ type TableClientConfig struct { RoutingChainMaxDepth int `gorm:"default:10" json:"routing_chain_max_depth"` // Maximum depth for routing rule chain evaluation (default: 10) WhitelistedRoutesJSON string `gorm:"type:text" json:"-"` // JSON serialized []string - // LiteLLM fallback flag - EnableLiteLLMFallbacks bool `gorm:"column:enable_litellm_fallbacks;default:false" json:"enable_litellm_fallbacks"` + // Compat plugin feature flags + CompatConvertTextToChat bool `gorm:"column:compat_convert_text_to_chat;default:false" json:"-"` + CompatConvertChatToResponses bool `gorm:"column:compat_convert_chat_to_responses;default:false" json:"-"` + CompatShouldDropParams bool `gorm:"column:compat_should_drop_params;default:false" json:"-"` + CompatShouldConvertParams bool `gorm:"column:compat_should_convert_params;default:false" json:"-"` // Config hash is used to detect the changes synced from config.json file // Every time we sync the config.json file, we will update the config hash diff --git a/framework/go.mod b/framework/go.mod index b4a2a0d0c3..e2c1149c95 100644 --- a/framework/go.mod +++ b/framework/go.mod @@ -9,6 +9,7 @@ require 
( github.com/qdrant/go-client v1.16.2 github.com/redis/go-redis/v9 v9.17.2 github.com/stretchr/testify v1.11.1 + github.com/tidwall/gjson v1.18.0 github.com/weaviate/weaviate v1.36.5 github.com/weaviate/weaviate-go-client/v5 v5.7.1 golang.org/x/crypto v0.49.0 @@ -54,7 +55,6 @@ require ( github.com/kylelemons/godebug v1.1.0 // indirect github.com/oapi-codegen/runtime v1.1.1 // indirect github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c // indirect - github.com/tidwall/gjson v1.18.0 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.0 // indirect github.com/tidwall/sjson v1.2.5 // indirect diff --git a/framework/go.sum b/framework/go.sum index ed13db0550..e75ab122ec 100644 --- a/framework/go.sum +++ b/framework/go.sum @@ -22,6 +22,7 @@ github.com/apapsch/go-jsonmerge/v2 v2.0.0/go.mod h1:lvDnEdqiQrp0O42VQGgmlKpxL1AP github.com/aws/aws-sdk-go-v2 v1.41.3 h1:4kQ/fa22KjDt13QCy1+bYADvdgcxpfH18f0zP542kZA= github.com/aws/aws-sdk-go-v2 v1.41.3/go.mod h1:mwsPRE8ceUUpiTgF7QmQIJ7lgsKUPQOUl3o72QBrE1o= github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.8 h1:eBMB84YGghSocM7PsjmmPffTa+1FBUeNvGvFou6V/4o= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.8/go.mod h1:lyw7GFp3qENLh7kwzf7iMzAxDn+NzjXEAGjKS2UOKqI= github.com/aws/aws-sdk-go-v2/config v1.32.11 h1:ftxI5sgz8jZkckuUHXfC/wMUc8u3fG1vQS0plr2F2Zs= github.com/aws/aws-sdk-go-v2/config v1.32.11/go.mod h1:twF11+6ps9aNRKEDimksp923o44w/Thk9+8YIlzWMmo= github.com/aws/aws-sdk-go-v2/credentials v1.19.11 h1:NdV8cwCcAXrCWyxArt58BrvZJ9pZ9Fhf9w6Uh5W3Uyc= @@ -265,10 +266,15 @@ go.mongodb.org/mongo-driver v1.17.6/go.mod h1:Hy04i7O2kC4RS06ZrhPRqj/u4DTYkFDAAc go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64= go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y= go.opentelemetry.io/otel v1.43.0 h1:mYIM03dnh5zfN7HautFE4ieIig9amkNANT+xcVxAj9I= +go.opentelemetry.io/otel v1.43.0/go.mod 
h1:JuG+u74mvjvcm8vj8pI5XiHy1zDeoCS2LB1spIq7Ay0= go.opentelemetry.io/otel/metric v1.43.0 h1:d7638QeInOnuwOONPp4JAOGfbCEpYb+K6DVWvdxGzgM= +go.opentelemetry.io/otel/metric v1.43.0/go.mod h1:RDnPtIxvqlgO8GRW18W6Z/4P462ldprJtfxHxyKd2PY= go.opentelemetry.io/otel/sdk v1.43.0 h1:pi5mE86i5rTeLXqoF/hhiBtUNcrAGHLKQdhg4h4V9Dg= +go.opentelemetry.io/otel/sdk v1.43.0/go.mod h1:P+IkVU3iWukmiit/Yf9AWvpyRDlUeBaRg6Y+C58QHzg= go.opentelemetry.io/otel/sdk/metric v1.43.0 h1:S88dyqXjJkuBNLeMcVPRFXpRw2fuwdvfCGLEo89fDkw= +go.opentelemetry.io/otel/sdk/metric v1.43.0/go.mod h1:C/RJtwSEJ5hzTiUz5pXF1kILHStzb9zFlIEe85bhj6A= go.opentelemetry.io/otel/trace v1.43.0 h1:BkNrHpup+4k4w+ZZ86CZoHHEkohws8AY+WTX09nk+3A= +go.opentelemetry.io/otel/trace v1.43.0/go.mod h1:/QJhyVBUUswCphDVxq+8mld+AvhXZLhe+8WVFxiFff0= go.starlark.net v0.0.0-20260102030733-3fee463870c9 h1:nV1OyvU+0CYrp5eKfQ3rD03TpFYYhH08z31NK1HmtTk= go.starlark.net v0.0.0-20260102030733-3fee463870c9/go.mod h1:YKMCv9b1WrfWmeqdV5MAuEHWsu5iC+fe6kYl2sQjdI8= go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc= @@ -292,9 +298,13 @@ golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8= golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA= gonum.org/v1/gonum v0.17.0 h1:VbpOemQlsSMrYmn7T2OUvQ4dqxQXU+ouZFQsZOx50z4= +gonum.org/v1/gonum v0.17.0/go.mod h1:El3tOrEuMpv2UdMrbNlKEh9vd86bmQ6vqIcDwxEOc1E= google.golang.org/genproto/googleapis/api v0.0.0-20260401024825-9d38bb4040a9 h1:VPWxll4HlMw1Vs/qXtN7BvhZqsS9cdAittCNvVENElA= +google.golang.org/genproto/googleapis/api v0.0.0-20260401024825-9d38bb4040a9/go.mod h1:7QBABkRtR8z+TEnmXTqIqwJLlzrZKVfAUm7tY3yGv0M= google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9 h1:m8qni9SQFH0tJc1X0vmnpw/0t+AImlSvp30sEupozUg= +google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9/go.mod h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8= 
google.golang.org/grpc v1.80.0 h1:Xr6m2WmWZLETvUNvIUmeD5OAagMw3FiKmMlTdViWsHM= +google.golang.org/grpc v1.80.0/go.mod h1:ho/dLnxwi3EDJA4Zghp7k2Ec1+c2jqup0bFkw07bwF4= google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE= google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/framework/modelcatalog/main.go b/framework/modelcatalog/main.go index 64eb494b53..8fc0a8056b 100644 --- a/framework/modelcatalog/main.go +++ b/framework/modelcatalog/main.go @@ -3,12 +3,12 @@ package modelcatalog import ( "context" + "encoding/json" "fmt" + "slices" "sync" "time" - "encoding/json" - providerUtils "github.com/maximhq/bifrost/core/providers/utils" "github.com/maximhq/bifrost/core/schemas" "github.com/maximhq/bifrost/framework/configstore" @@ -46,6 +46,14 @@ type ModelCatalog struct { unfilteredModelPool map[schemas.ModelProvider][]string // model pool without allowed models filtering baseModelIndex map[string]string // model string → canonical base model name + // Pre-parsed supported response types index (keyed by model name) + // Values are normalized response types: "chat_completion", "responses", "text_completion" + supportedResponseTypes map[string][]string + + // Pre-parsed supported parameters index (keyed by model name, populated from model parameters supported_parameters) + // Values are parameter names the model accepts (e.g., "temperature", "top_p", "tools") + supportedParams map[string][]string + // Background sync worker syncTicker *time.Ticker done chan struct{} @@ -75,6 +83,8 @@ func Init(ctx context.Context, config *Config, configStore configstore.ConfigSto modelPool: make(map[schemas.ModelProvider][]string), unfilteredModelPool: make(map[schemas.ModelProvider][]string), baseModelIndex: make(map[string]string), + supportedResponseTypes: make(map[string][]string), + supportedParams: 
make(map[string][]string), done: make(chan struct{}), distributedLockManager: configstore.NewDistributedLockManager(configStore, logger, configstore.WithDefaultTTL(30*time.Second)), } @@ -332,6 +342,30 @@ func (mc *ModelCatalog) getPricingURL() string { return mc.pricingURL } +// IsRequestTypeSupported checks if a model supports chat completion. +// It checks the supportedResponseTypes index. +func (mc *ModelCatalog) IsRequestTypeSupported(model string, provider schemas.ModelProvider, requestType schemas.RequestType) bool { + mc.mu.RLock() + defer mc.mu.RUnlock() + outputs, ok := mc.supportedResponseTypes[model] + return ok && slices.Contains(outputs, string(requestType)) +} + +// GetSupportedParameters returns the list of supported parameter names for a model. +// Returns nil if the model is not found in the catalog. +func (mc *ModelCatalog) GetSupportedParameters(model string) []string { + mc.mu.RLock() + params, ok := mc.supportedParams[model] + mc.mu.RUnlock() + if !ok { + return nil + } + // Return a copy to prevent external modification + result := make([]string, len(params)) + copy(result, params) + return result +} + // populateModelPool populates the model pool with all available models per provider (thread-safe) func (mc *ModelCatalog) populateModelPoolFromPricingData() { // Acquire write lock for the entire rebuild operation @@ -409,10 +443,12 @@ func NewTestCatalog(baseModelIndex map[string]string) *ModelCatalog { baseModelIndex = make(map[string]string) } return &ModelCatalog{ - modelPool: make(map[schemas.ModelProvider][]string), - unfilteredModelPool: make(map[schemas.ModelProvider][]string), - baseModelIndex: baseModelIndex, - pricingData: make(map[string]configstoreTables.TableModelPricing), - done: make(chan struct{}), + modelPool: make(map[schemas.ModelProvider][]string), + unfilteredModelPool: make(map[schemas.ModelProvider][]string), + baseModelIndex: baseModelIndex, + pricingData: make(map[string]configstoreTables.TableModelPricing), + 
supportedResponseTypes: make(map[string][]string), + supportedParams: make(map[string][]string), + done: make(chan struct{}), } -} +} \ No newline at end of file diff --git a/framework/modelcatalog/sync.go b/framework/modelcatalog/sync.go index 661d38066e..166178cd1c 100644 --- a/framework/modelcatalog/sync.go +++ b/framework/modelcatalog/sync.go @@ -6,11 +6,14 @@ import ( "fmt" "io" "net/http" + "slices" "sync" "time" providerUtils "github.com/maximhq/bifrost/core/providers/utils" + "github.com/maximhq/bifrost/core/schemas" configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables" + "github.com/tidwall/gjson" "gorm.io/gorm" ) @@ -68,7 +71,6 @@ func (mc *ModelCatalog) syncPricing(ctx context.Context) error { return nil }) - if err != nil { return fmt.Errorf("failed to sync pricing data to database: %w", err) } @@ -212,7 +214,7 @@ func (mc *ModelCatalog) loadModelParametersFromDatabase(ctx context.Context) (in for _, row := range rows { paramsData[row.Model] = json.RawMessage(row.Data) } - applyModelParametersToProviderCache(paramsData) + mc.applyModelParameters(paramsData) mc.logger.Debug("loaded %d model parameters records from database into cache", len(rows)) return len(rows), nil } @@ -323,16 +325,71 @@ func (mc *ModelCatalog) syncWorker(ctx context.Context) { // --- Model Parameters sync --- -func applyModelParametersToProviderCache(paramsData map[string]json.RawMessage) { +func (mc *ModelCatalog) applyModelParameters(paramsData map[string]json.RawMessage) { modelParamsEntries := make(map[string]providerUtils.ModelParams, len(paramsData)) + newResponseTypes := make(map[string][]string, len(paramsData)) + newParamsIndex := make(map[string][]string, len(paramsData)) + for model, rawData := range paramsData { + var parsed modelParametersParseResult + if err := json.Unmarshal(rawData, &parsed); err != nil { + mc.logger.Warn("model-parameters-sync: skipping malformed parameters for model %s: %v", model, err) + continue + } + + outputs := 
make([]string, 0, len(parsed.SupportedEndpoints)) + for _, endpoint := range parsed.SupportedEndpoints { + if normalized := normalizeEndpointToOutputType(endpoint); normalized != "" && !slices.Contains(outputs, normalized) { + outputs = append(outputs, normalized) + } + } + + if parsed.Mode != nil { + if normalized := normalizeModeToOutputType(*parsed.Mode); normalized != "" && !slices.Contains(outputs, normalized) { + outputs = append(outputs, normalized) + } + } + + if !slices.Contains(outputs, "text_completion") { + provider := gjson.GetBytes(rawData, "provider") + if provider.Exists() { + key := makeKey(model, normalizeProvider(provider.String()), normalizeRequestType(schemas.TextCompletionRequest)) + + mc.mu.RLock() + _, ok := mc.pricingData[key] + mc.mu.RUnlock() + if ok { + outputs = append(outputs, "text_completion") + } + } + } + + if len(outputs) > 0 { + newResponseTypes[model] = outputs + } + + supported := extractSupportedParams(&parsed) + if len(supported) > 0 { + newParamsIndex[model] = supported + } + var p struct { MaxOutputTokens *int `json:"max_output_tokens"` } - if err := json.Unmarshal(rawData, &p); err == nil && p.MaxOutputTokens != nil { + if p.MaxOutputTokens == nil { + if err := json.Unmarshal(rawData, &p); err == nil && p.MaxOutputTokens != nil { + modelParamsEntries[model] = providerUtils.ModelParams{MaxOutputTokens: p.MaxOutputTokens} + } + } else { modelParamsEntries[model] = providerUtils.ModelParams{MaxOutputTokens: p.MaxOutputTokens} } } + + mc.mu.Lock() + mc.supportedResponseTypes = newResponseTypes + mc.supportedParams = newParamsIndex + mc.mu.Unlock() + if len(modelParamsEntries) > 0 { providerUtils.BulkSetModelParams(modelParamsEntries) } @@ -347,7 +404,7 @@ func (mc *ModelCatalog) loadModelParametersIntoMemoryFromURL(ctx context.Context if err != nil { return fmt.Errorf("failed to load model parameters from URL: %w", err) } - applyModelParametersToProviderCache(paramsData) + mc.applyModelParameters(paramsData) return nil } @@ 
-394,7 +451,7 @@ func (mc *ModelCatalog) syncModelParameters(ctx context.Context) error { } } - applyModelParametersToProviderCache(paramsData) + mc.applyModelParameters(paramsData) mc.logger.Info("successfully synced %d model parameters records", len(paramsData)) return nil @@ -431,4 +488,4 @@ func (mc *ModelCatalog) loadModelParametersFromURL(ctx context.Context) (map[str mc.logger.Debug("successfully downloaded and parsed %d model parameters records", len(paramsData)) return paramsData, nil -} +} \ No newline at end of file diff --git a/framework/modelcatalog/utils.go b/framework/modelcatalog/utils.go index 85c9977234..3ffe956b4c 100644 --- a/framework/modelcatalog/utils.go +++ b/framework/modelcatalog/utils.go @@ -2,6 +2,7 @@ package modelcatalog import ( "context" + "slices" "strings" "time" @@ -310,3 +311,95 @@ func convertTablePricingOverrideToPricingOverride(override *configstoreTables.Ta Options: options, }, nil } + +// normalizeEndpointToOutputType converts a supported_endpoints URL path to a normalized output type. +// Returns empty string for unrecognized endpoints. +func normalizeEndpointToOutputType(endpoint string) string { + switch { + case strings.Contains(endpoint, "/chat/completions"): + return "chat_completion" + case strings.Contains(endpoint, "/responses"): + return "responses" + case strings.Contains(endpoint, "/completions"): + return "text_completion" + default: + return "" + } +} + +// normalizeModeToOutputType converts mode to a normalized output type. +func normalizeModeToOutputType(mode string) string { + switch mode { + case "chat": + return "chat_completion" + case "completion": + return "text_completion" + case "responses": + return "responses" + default: + return "" + } +} + +// modelParametersParseResult is the parsed result type used by buildSupportedOutputsIndex. 
+type modelParametersParseResult struct { + Mode *string `json:"mode,omitempty"` + SupportedEndpoints []string `json:"supported_endpoints,omitempty"` + ModelParameters []struct { + ID string `json:"id"` + } `json:"model_parameters,omitempty"` + SupportsFunctionCalling *bool `json:"supports_function_calling,omitempty"` + SupportsParallelFunctionCalling *bool `json:"supports_parallel_function_calling,omitempty"` + SupportsToolChoice *bool `json:"supports_tool_choice,omitempty"` + SupportsReasoning *bool `json:"supports_reasoning,omitempty"` + SupportsServiceTier *bool `json:"supports_service_tier,omitempty"` + SupportsPromptCaching *bool `json:"supports_prompt_caching,omitempty"` +} + +// extractSupportedParams builds a list of supported OpenAI-compatible parameter +// names from model_parameters[].id values and supports_* boolean flags. +func extractSupportedParams(parsed *modelParametersParseResult) []string { + var supported []string + addParam := func(name string) { + if !slices.Contains(supported, name) { + supported = append(supported, name) + } + } + + // From model_parameters[].id — map IDs to request param names + for _, mp := range parsed.ModelParameters { + switch mp.ID { + case "reasoning_effort", "reasoning_summary": + addParam("reasoning") + case "web_search": + addParam("web_search_options") + case "promptTools", "image_detail", "stream": + // skip — not top-level request parameters + default: + addParam(mp.ID) + } + } + + // From supports_* boolean flags + if parsed.SupportsFunctionCalling != nil && *parsed.SupportsFunctionCalling { + addParam("tools") + } + if parsed.SupportsParallelFunctionCalling != nil && *parsed.SupportsParallelFunctionCalling { + addParam("parallel_tool_calls") + } + if parsed.SupportsToolChoice != nil && *parsed.SupportsToolChoice { + addParam("tool_choice") + } + if parsed.SupportsReasoning != nil && *parsed.SupportsReasoning { + addParam("reasoning") + } + if parsed.SupportsServiceTier != nil && *parsed.SupportsServiceTier { 
+ addParam("service_tier") + } + if parsed.SupportsPromptCaching != nil && *parsed.SupportsPromptCaching { + addParam("prompt_cache_key") + addParam("prompt_cache_retention") + } + + return supported +} diff --git a/helm-charts/bifrost/templates/_helpers.tpl b/helm-charts/bifrost/templates/_helpers.tpl index 8dc0606658..97e7b7e4f4 100644 --- a/helm-charts/bifrost/templates/_helpers.tpl +++ b/helm-charts/bifrost/templates/_helpers.tpl @@ -227,8 +227,21 @@ false {{- if .Values.bifrost.client.maxRequestBodySizeMb }} {{- $_ := set $client "max_request_body_size_mb" .Values.bifrost.client.maxRequestBodySizeMb }} {{- end }} -{{- if hasKey .Values.bifrost.client "enableLitellmFallbacks" }} -{{- $_ := set $client "enable_litellm_fallbacks" .Values.bifrost.client.enableLitellmFallbacks }} +{{- if .Values.bifrost.client.compat }} +{{- $compat := dict }} +{{- if hasKey .Values.bifrost.client.compat "convertTextToChat" }} +{{- $_ := set $compat "convert_text_to_chat" .Values.bifrost.client.compat.convertTextToChat }} +{{- end }} +{{- if hasKey .Values.bifrost.client.compat "convertChatToResponses" }} +{{- $_ := set $compat "convert_chat_to_responses" .Values.bifrost.client.compat.convertChatToResponses }} +{{- end }} +{{- if hasKey .Values.bifrost.client.compat "shouldDropParams" }} +{{- $_ := set $compat "should_drop_params" .Values.bifrost.client.compat.shouldDropParams }} +{{- end }} +{{- if hasKey .Values.bifrost.client.compat "shouldConvertParams" }} +{{- $_ := set $compat "should_convert_params" .Values.bifrost.client.compat.shouldConvertParams }} +{{- end }} +{{- $_ := set $client "compat" $compat }} {{- end }} {{- if .Values.bifrost.client.prometheusLabels }} {{- $_ := set $client "prometheus_labels" .Values.bifrost.client.prometheusLabels }} diff --git a/helm-charts/bifrost/values.schema.json b/helm-charts/bifrost/values.schema.json index 495e9c8e79..62239bae54 100644 --- a/helm-charts/bifrost/values.schema.json +++ b/helm-charts/bifrost/values.schema.json @@ -293,8 
+293,15 @@ "type": "integer", "minimum": 1 }, - "enableLitellmFallbacks": { - "type": "boolean" + "compat": { + "type": "object", + "additionalProperties": false, + "properties": { + "convertTextToChat": { "type": "boolean" }, + "convertChatToResponses": { "type": "boolean" }, + "shouldDropParams": { "type": "boolean" }, + "shouldConvertParams": { "type": "boolean" } + } }, "prometheusLabels": { "type": "array", @@ -3163,4 +3170,4 @@ "additionalProperties": false } } -} +} \ No newline at end of file diff --git a/helm-charts/bifrost/values.yaml b/helm-charts/bifrost/values.yaml index 8509859f01..10d1078b0d 100644 --- a/helm-charts/bifrost/values.yaml +++ b/helm-charts/bifrost/values.yaml @@ -188,7 +188,11 @@ bifrost: enforceGovernanceHeader: false allowDirectKeys: false maxRequestBodySizeMb: 100 - enableLitellmFallbacks: false + compat: + convertTextToChat: false + convertChatToResponses: false + shouldDropParams: false + shouldConvertParams: false prometheusLabels: [] # Header filtering configuration for x-bf-eh-* headers forwarded to LLM providers headerFilterConfig: diff --git a/nix/packages/bifrost-http.nix b/nix/packages/bifrost-http.nix index f0f3b16ea1..0d05dd1e59 100644 --- a/nix/packages/bifrost-http.nix +++ b/nix/packages/bifrost-http.nix @@ -20,7 +20,7 @@ let replace github.com/maximhq/bifrost/core => ../core replace github.com/maximhq/bifrost/framework => ../framework replace github.com/maximhq/bifrost/plugins/governance => ../plugins/governance - replace github.com/maximhq/bifrost/plugins/litellmcompat => ../plugins/litellmcompat + replace github.com/maximhq/bifrost/plugins/compat => ../plugins/compat replace github.com/maximhq/bifrost/plugins/logging => ../plugins/logging replace github.com/maximhq/bifrost/plugins/maxim => ../plugins/maxim replace github.com/maximhq/bifrost/plugins/otel => ../plugins/otel diff --git a/plugins/compat/changelog.md b/plugins/compat/changelog.md new file mode 100644 index 0000000000..ad3d633b71 --- /dev/null +++ 
b/plugins/compat/changelog.md @@ -0,0 +1,2 @@ +- feat: Adds option for converting chat completions to responses for models that support it +- feat: Adds option for dropping unsupported model parameters \ No newline at end of file diff --git a/plugins/compat/conversion.go b/plugins/compat/conversion.go new file mode 100644 index 0000000000..176f1b1389 --- /dev/null +++ b/plugins/compat/conversion.go @@ -0,0 +1,25 @@ +package compat + +import "github.com/maximhq/bifrost/core/schemas" + +// applyParameterConversion rewrites request fields in place for provider compatibility. +func applyParameterConversion(req *schemas.BifrostRequest) { + if req == nil { + return + } + + if req.ChatRequest != nil { + normalizeDeveloperRoleForChatRequest(req.ChatRequest) + } +} + +func normalizeDeveloperRoleForChatRequest(req *schemas.BifrostChatRequest) { + if req.Provider != schemas.Bedrock && req.Provider != schemas.Vertex && req.Provider != schemas.Gemini { + return + } + for i := range req.Input { + if req.Input[i].Role == schemas.ChatMessageRoleDeveloper { + req.Input[i].Role = schemas.ChatMessageRoleSystem + } + } +} diff --git a/plugins/compat/dropparams.go b/plugins/compat/dropparams.go new file mode 100644 index 0000000000..fcf79a2df1 --- /dev/null +++ b/plugins/compat/dropparams.go @@ -0,0 +1,218 @@ +package compat + +import "github.com/maximhq/bifrost/core/schemas" + +// dropUnsupportedParams removes unsupported model parameters from a request in place. 
+func dropUnsupportedParams(req *schemas.BifrostRequest, supportedParams []string) []string { + if req == nil { + return nil + } + + isSupported := make(map[string]bool, len(supportedParams)) + for _, param := range supportedParams { + isSupported[param] = true + } + + var dropped []string + + if req.ChatRequest != nil && req.ChatRequest.Params != nil { + params := req.ChatRequest.Params + + if params.Audio != nil && !isSupported["audio"] { + params.Audio = nil + dropped = append(dropped, "audio") + } + if params.FrequencyPenalty != nil && !isSupported["frequency_penalty"] { + params.FrequencyPenalty = nil + dropped = append(dropped, "frequency_penalty") + } + if params.LogitBias != nil && !isSupported["logit_bias"] { + params.LogitBias = nil + dropped = append(dropped, "logit_bias") + } + if params.LogProbs != nil && !isSupported["logprobs"] { + params.LogProbs = nil + dropped = append(dropped, "logprobs") + } + if params.MaxCompletionTokens != nil && !isSupported["max_completion_tokens"] { + params.MaxCompletionTokens = nil + dropped = append(dropped, "max_completion_tokens") + } + if params.Metadata != nil && !isSupported["metadata"] { + params.Metadata = nil + dropped = append(dropped, "metadata") + } + if params.ParallelToolCalls != nil && !isSupported["parallel_tool_calls"] { + params.ParallelToolCalls = nil + dropped = append(dropped, "parallel_tool_calls") + } + if params.Prediction != nil && !isSupported["prediction"] { + params.Prediction = nil + dropped = append(dropped, "prediction") + } + if params.PresencePenalty != nil && !isSupported["presence_penalty"] { + params.PresencePenalty = nil + dropped = append(dropped, "presence_penalty") + } + if params.PromptCacheKey != nil && !isSupported["prompt_cache_key"] { + params.PromptCacheKey = nil + dropped = append(dropped, "prompt_cache_key") + } + if params.PromptCacheRetention != nil && !isSupported["prompt_cache_retention"] { + params.PromptCacheRetention = nil + dropped = append(dropped, 
"prompt_cache_retention") + } + if params.Reasoning != nil && !isSupported["reasoning"] { + params.Reasoning = nil + dropped = append(dropped, "reasoning") + } + if params.ResponseFormat != nil && !isSupported["response_format"] { + params.ResponseFormat = nil + dropped = append(dropped, "response_format") + } + if params.Seed != nil && !isSupported["seed"] { + params.Seed = nil + dropped = append(dropped, "seed") + } + if params.ServiceTier != nil && !isSupported["service_tier"] { + params.ServiceTier = nil + dropped = append(dropped, "service_tier") + } + if len(params.Stop) > 0 && !isSupported["stop"] { + params.Stop = nil + dropped = append(dropped, "stop") + } + if params.Temperature != nil && !isSupported["temperature"] { + params.Temperature = nil + dropped = append(dropped, "temperature") + } + if params.TopLogProbs != nil && !isSupported["top_logprobs"] { + params.TopLogProbs = nil + dropped = append(dropped, "top_logprobs") + } + if params.TopP != nil && !isSupported["top_p"] { + params.TopP = nil + dropped = append(dropped, "top_p") + } + if params.ToolChoice != nil && !isSupported["tool_choice"] { + params.ToolChoice = nil + dropped = append(dropped, "tool_choice") + } + if len(params.Tools) > 0 && !isSupported["tools"] { + params.Tools = nil + dropped = append(dropped, "tools") + } + if params.Verbosity != nil && !isSupported["verbosity"] { + params.Verbosity = nil + dropped = append(dropped, "verbosity") + } + if params.WebSearchOptions != nil && !isSupported["web_search_options"] { + params.WebSearchOptions = nil + dropped = append(dropped, "web_search_options") + } + } + + if req.ResponsesRequest != nil && req.ResponsesRequest.Params != nil { + params := req.ResponsesRequest.Params + + if params.MaxOutputTokens != nil && !isSupported["max_output_tokens"] { + params.MaxOutputTokens = nil + dropped = append(dropped, "max_output_tokens") + } + if params.MaxToolCalls != nil && !isSupported["max_tool_calls"] { + params.MaxToolCalls = nil + dropped = 
append(dropped, "max_tool_calls") + } + if params.Metadata != nil && !isSupported["metadata"] { + params.Metadata = nil + dropped = append(dropped, "metadata") + } + if params.ParallelToolCalls != nil && !isSupported["parallel_tool_calls"] { + params.ParallelToolCalls = nil + dropped = append(dropped, "parallel_tool_calls") + } + if params.PromptCacheKey != nil && !isSupported["prompt_cache_key"] { + params.PromptCacheKey = nil + dropped = append(dropped, "prompt_cache_key") + } + if params.Reasoning != nil && !isSupported["reasoning"] { + params.Reasoning = nil + dropped = append(dropped, "reasoning") + } + if params.ServiceTier != nil && !isSupported["service_tier"] { + params.ServiceTier = nil + dropped = append(dropped, "service_tier") + } + if params.Temperature != nil && !isSupported["temperature"] { + params.Temperature = nil + dropped = append(dropped, "temperature") + } + if params.Text != nil && !isSupported["text"] { + params.Text = nil + dropped = append(dropped, "text") + } + if params.TopLogProbs != nil && !isSupported["top_logprobs"] { + params.TopLogProbs = nil + dropped = append(dropped, "top_logprobs") + } + if params.TopP != nil && !isSupported["top_p"] { + params.TopP = nil + dropped = append(dropped, "top_p") + } + if params.ToolChoice != nil && !isSupported["tool_choice"] { + params.ToolChoice = nil + dropped = append(dropped, "tool_choice") + } + if len(params.Tools) > 0 && !isSupported["tools"] { + params.Tools = nil + dropped = append(dropped, "tools") + } + } + + if req.TextCompletionRequest != nil && req.TextCompletionRequest.Params != nil { + params := req.TextCompletionRequest.Params + + if params.FrequencyPenalty != nil && !isSupported["frequency_penalty"] { + params.FrequencyPenalty = nil + dropped = append(dropped, "frequency_penalty") + } + if params.LogitBias != nil && !isSupported["logit_bias"] { + params.LogitBias = nil + dropped = append(dropped, "logit_bias") + } + if params.LogProbs != nil && !isSupported["logprobs"] { + 
params.LogProbs = nil + dropped = append(dropped, "logprobs") + } + if params.MaxTokens != nil && !isSupported["max_tokens"] { + params.MaxTokens = nil + dropped = append(dropped, "max_tokens") + } + if params.N != nil && !isSupported["n"] { + params.N = nil + dropped = append(dropped, "n") + } + if params.PresencePenalty != nil && !isSupported["presence_penalty"] { + params.PresencePenalty = nil + dropped = append(dropped, "presence_penalty") + } + if params.Seed != nil && !isSupported["seed"] { + params.Seed = nil + dropped = append(dropped, "seed") + } + if len(params.Stop) > 0 && !isSupported["stop"] { + params.Stop = nil + dropped = append(dropped, "stop") + } + if params.Temperature != nil && !isSupported["temperature"] { + params.Temperature = nil + dropped = append(dropped, "temperature") + } + if params.TopP != nil && !isSupported["top_p"] { + params.TopP = nil + dropped = append(dropped, "top_p") + } + } + + return dropped +} diff --git a/plugins/litellmcompat/go.mod b/plugins/compat/go.mod similarity index 99% rename from plugins/litellmcompat/go.mod rename to plugins/compat/go.mod index f4afc7b154..7c282948fc 100644 --- a/plugins/litellmcompat/go.mod +++ b/plugins/compat/go.mod @@ -1,4 +1,4 @@ -module github.com/maximhq/bifrost/plugins/litellmcompat +module github.com/maximhq/bifrost/plugins/compat go 1.26.2 diff --git a/plugins/litellmcompat/go.sum b/plugins/compat/go.sum similarity index 96% rename from plugins/litellmcompat/go.sum rename to plugins/compat/go.sum index d231e3c8df..2bc5801fc0 100644 --- a/plugins/litellmcompat/go.sum +++ b/plugins/compat/go.sum @@ -22,6 +22,7 @@ github.com/apapsch/go-jsonmerge/v2 v2.0.0/go.mod h1:lvDnEdqiQrp0O42VQGgmlKpxL1AP github.com/aws/aws-sdk-go-v2 v1.41.3 h1:4kQ/fa22KjDt13QCy1+bYADvdgcxpfH18f0zP542kZA= github.com/aws/aws-sdk-go-v2 v1.41.3/go.mod h1:mwsPRE8ceUUpiTgF7QmQIJ7lgsKUPQOUl3o72QBrE1o= github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.8 h1:eBMB84YGghSocM7PsjmmPffTa+1FBUeNvGvFou6V/4o= 
+github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.8/go.mod h1:lyw7GFp3qENLh7kwzf7iMzAxDn+NzjXEAGjKS2UOKqI= github.com/aws/aws-sdk-go-v2/config v1.32.11 h1:ftxI5sgz8jZkckuUHXfC/wMUc8u3fG1vQS0plr2F2Zs= github.com/aws/aws-sdk-go-v2/config v1.32.11/go.mod h1:twF11+6ps9aNRKEDimksp923o44w/Thk9+8YIlzWMmo= github.com/aws/aws-sdk-go-v2/credentials v1.19.11 h1:NdV8cwCcAXrCWyxArt58BrvZJ9pZ9Fhf9w6Uh5W3Uyc= @@ -267,10 +268,15 @@ go.mongodb.org/mongo-driver v1.17.6/go.mod h1:Hy04i7O2kC4RS06ZrhPRqj/u4DTYkFDAAc go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64= go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y= go.opentelemetry.io/otel v1.43.0 h1:mYIM03dnh5zfN7HautFE4ieIig9amkNANT+xcVxAj9I= +go.opentelemetry.io/otel v1.43.0/go.mod h1:JuG+u74mvjvcm8vj8pI5XiHy1zDeoCS2LB1spIq7Ay0= go.opentelemetry.io/otel/metric v1.43.0 h1:d7638QeInOnuwOONPp4JAOGfbCEpYb+K6DVWvdxGzgM= +go.opentelemetry.io/otel/metric v1.43.0/go.mod h1:RDnPtIxvqlgO8GRW18W6Z/4P462ldprJtfxHxyKd2PY= go.opentelemetry.io/otel/sdk v1.43.0 h1:pi5mE86i5rTeLXqoF/hhiBtUNcrAGHLKQdhg4h4V9Dg= +go.opentelemetry.io/otel/sdk v1.43.0/go.mod h1:P+IkVU3iWukmiit/Yf9AWvpyRDlUeBaRg6Y+C58QHzg= go.opentelemetry.io/otel/sdk/metric v1.43.0 h1:S88dyqXjJkuBNLeMcVPRFXpRw2fuwdvfCGLEo89fDkw= +go.opentelemetry.io/otel/sdk/metric v1.43.0/go.mod h1:C/RJtwSEJ5hzTiUz5pXF1kILHStzb9zFlIEe85bhj6A= go.opentelemetry.io/otel/trace v1.43.0 h1:BkNrHpup+4k4w+ZZ86CZoHHEkohws8AY+WTX09nk+3A= +go.opentelemetry.io/otel/trace v1.43.0/go.mod h1:/QJhyVBUUswCphDVxq+8mld+AvhXZLhe+8WVFxiFff0= go.starlark.net v0.0.0-20260102030733-3fee463870c9 h1:nV1OyvU+0CYrp5eKfQ3rD03TpFYYhH08z31NK1HmtTk= go.starlark.net v0.0.0-20260102030733-3fee463870c9/go.mod h1:YKMCv9b1WrfWmeqdV5MAuEHWsu5iC+fe6kYl2sQjdI8= go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc= @@ -294,9 +300,13 @@ golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= golang.org/x/text v0.35.0 
h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8= golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA= gonum.org/v1/gonum v0.17.0 h1:VbpOemQlsSMrYmn7T2OUvQ4dqxQXU+ouZFQsZOx50z4= +gonum.org/v1/gonum v0.17.0/go.mod h1:El3tOrEuMpv2UdMrbNlKEh9vd86bmQ6vqIcDwxEOc1E= google.golang.org/genproto/googleapis/api v0.0.0-20260401024825-9d38bb4040a9 h1:VPWxll4HlMw1Vs/qXtN7BvhZqsS9cdAittCNvVENElA= +google.golang.org/genproto/googleapis/api v0.0.0-20260401024825-9d38bb4040a9/go.mod h1:7QBABkRtR8z+TEnmXTqIqwJLlzrZKVfAUm7tY3yGv0M= google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9 h1:m8qni9SQFH0tJc1X0vmnpw/0t+AImlSvp30sEupozUg= +google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9/go.mod h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8= google.golang.org/grpc v1.80.0 h1:Xr6m2WmWZLETvUNvIUmeD5OAagMw3FiKmMlTdViWsHM= +google.golang.org/grpc v1.80.0/go.mod h1:ho/dLnxwi3EDJA4Zghp7k2Ec1+c2jqup0bFkw07bwF4= google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE= google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/plugins/compat/main.go b/plugins/compat/main.go new file mode 100644 index 0000000000..6c536d6bd6 --- /dev/null +++ b/plugins/compat/main.go @@ -0,0 +1,169 @@ +// Package compat provides LiteLLM-compatible request normalization for the +// Bifrost gateway. It drops unsupported model params first, then rewrites +// requests to a compatible endpoint type when the target model does not support +// the caller's original request type. +package compat + +import ( + "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/framework/modelcatalog" +) + +const PluginName = "compat" + +// Config defines the configuration for the compat plugin. 
+type Config struct { + ConvertTextToChat bool `json:"convert_text_to_chat"` + ConvertChatToResponses bool `json:"convert_chat_to_responses"` + ShouldDropParams bool `json:"should_drop_params"` + ShouldConvertParams bool `json:"should_convert_params"` +} + +// IsEnabled returns true if any compat feature is enabled +func (c Config) IsEnabled() bool { + return c.ConvertTextToChat || c.ConvertChatToResponses || c.ShouldDropParams || c.ShouldConvertParams +} + +// CompatPlugin provides LiteLLM-compatible request/response transformations. +// When enabled, it automatically converts text completion requests to chat +// completion requests for models that only support chat completions, matching +// LiteLLM's behavior. It also converts chat completion requests to responses +// for models that only support the responses endpoint. +type CompatPlugin struct { + config Config + logger schemas.Logger + modelCatalog *modelcatalog.ModelCatalog + droppedParams []string +} + +// Init creates a new compat plugin instance with model catalog support. +// The model catalog is used to determine if a model supports text completion or +// chat completion natively. If the model catalog is nil, the plugin will +// convert all text completion requests to chat completion and all chat +// completion requests to responses. 
+func Init(config Config, logger schemas.Logger, mc *modelcatalog.ModelCatalog) (*CompatPlugin, error) { + return &CompatPlugin{ + config: config, + logger: logger, + modelCatalog: mc, + }, nil +} + +// GetName returns the plugin name +func (p *CompatPlugin) GetName() string { + return PluginName +} + +// HTTPTransportPreHook is not used for this plugin +func (p *CompatPlugin) HTTPTransportPreHook(ctx *schemas.BifrostContext, req *schemas.HTTPRequest) (*schemas.HTTPResponse, error) { + return nil, nil +} + +// HTTPTransportPostHook is not used for this plugin +func (p *CompatPlugin) HTTPTransportPostHook(ctx *schemas.BifrostContext, req *schemas.HTTPRequest, resp *schemas.HTTPResponse) error { + return nil +} + +// HTTPTransportStreamChunkHook passes through streaming chunks unchanged. +func (p *CompatPlugin) HTTPTransportStreamChunkHook(ctx *schemas.BifrostContext, req *schemas.HTTPRequest, chunk *schemas.BifrostStreamChunk) (*schemas.BifrostStreamChunk, error) { + return chunk, nil +} + +// PreLLMHook intercepts requests and applies LiteLLM-compatible request normalization. 
+func (p *CompatPlugin) PreLLMHook(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.LLMPluginShortCircuit, error) { + if ctx == nil || req == nil { + return req, nil, nil + } + + convertTextToChatOverride, convertTextToChatOverrideEnabled := ctx.Value(schemas.BifrostContextKeyCompatConvertTextToChat).(bool) + convertChatToResponsesOverride, convertChatToResponsesOverrideEnabled := ctx.Value(schemas.BifrostContextKeyCompatConvertChatToResponses).(bool) + shouldDropParamsOverride, shouldDropParamsOverrideEnabled := ctx.Value(schemas.BifrostContextKeyCompatShouldDropParams).(bool) + shouldConvertParamsOverride, shouldConvertParamsOverrideEnabled := ctx.Value(schemas.BifrostContextKeyCompatShouldConvertParams).(bool) + + modifiedReq := req + if (shouldDropParamsOverrideEnabled && shouldDropParamsOverride) || (shouldConvertParamsOverrideEnabled && shouldConvertParamsOverride) || p.config.ShouldConvertParams || p.config.ShouldDropParams { + modifiedReq = cloneBifrostReq(req) + } + p.droppedParams = nil + + // Text completion → chat conversion + if (convertTextToChatOverrideEnabled && convertTextToChatOverride) || p.config.ConvertTextToChat { + if (modifiedReq.RequestType == schemas.TextCompletionRequest || modifiedReq.RequestType == schemas.TextCompletionStreamRequest) && modifiedReq.TextCompletionRequest != nil { + p.markForConversion(ctx, modifiedReq.TextCompletionRequest.Provider, modifiedReq.TextCompletionRequest.Model, schemas.TextCompletionRequest, schemas.ChatCompletionRequest) + } + } + + // Chat completion → responses conversion + if (convertChatToResponsesOverrideEnabled && convertChatToResponsesOverride) || p.config.ConvertChatToResponses { + if (modifiedReq.RequestType == schemas.ChatCompletionRequest || modifiedReq.RequestType == schemas.ChatCompletionStreamRequest) && modifiedReq.ChatRequest != nil { + p.markForConversion(ctx, modifiedReq.ChatRequest.Provider, modifiedReq.ChatRequest.Model, 
schemas.ChatCompletionRequest, schemas.ResponsesRequest) + } + } + + // Compute unsupported parameters to drop based on model catalog allowlist + if ((shouldDropParamsOverrideEnabled && shouldDropParamsOverride) || p.config.ShouldDropParams) && p.modelCatalog != nil { + _, model, _ := modifiedReq.GetRequestFields() + if model != "" { + if supportedParams := p.modelCatalog.GetSupportedParameters(model); supportedParams != nil { + droppedParams := dropUnsupportedParams(modifiedReq, supportedParams) + if len(droppedParams) > 0 { + p.droppedParams = droppedParams + } + } + } + } + + if (shouldConvertParamsOverride && shouldConvertParamsOverrideEnabled) || p.config.ShouldConvertParams { + applyParameterConversion(modifiedReq) + } + + return modifiedReq, nil, nil +} + +// PostLLMHook converts provider responses back to the caller-facing shape +func (p *CompatPlugin) PostLLMHook(ctx *schemas.BifrostContext, result *schemas.BifrostResponse, bifrostErr *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError, error) { + if ctx == nil { + return result, bifrostErr, nil + } + + if changeType, ok := ctx.Value(schemas.BifrostContextKeyChangeRequestType).(schemas.RequestType); ok { + if result != nil { + extraFields := result.GetExtraFields() + if extraFields != nil { + extraFields.ConvertedRequestType = changeType + } + } + if bifrostErr != nil { + bifrostErr.ExtraFields.ConvertedRequestType = changeType + } + } + + if result != nil { + if extraFields := result.GetExtraFields(); extraFields != nil { + extraFields.DroppedCompatPluginParams = p.droppedParams + } + } + + return result, bifrostErr, nil +} + +// Cleanup performs plugin cleanup. 
+func (p *CompatPlugin) Cleanup() error { + return nil +} + +// markForConversion checks if the model supports the current request type; if not, mark for conversion +func (p *CompatPlugin) markForConversion(ctx *schemas.BifrostContext, provider schemas.ModelProvider, model string, currentType schemas.RequestType, targetType schemas.RequestType) { + shouldConvert := false + if p.modelCatalog != nil { + if !p.modelCatalog.IsRequestTypeSupported(model, provider, currentType) && p.modelCatalog.IsRequestTypeSupported(model, provider, targetType) { + shouldConvert = true + } + } else { + p.logger.Debug("compat: model catalog is nil") + } + + if shouldConvert { + ctx.SetValue(schemas.BifrostContextKeyChangeRequestType, targetType) + } +} diff --git a/plugins/compat/requestcopy.go b/plugins/compat/requestcopy.go new file mode 100644 index 0000000000..a92e9a81b8 --- /dev/null +++ b/plugins/compat/requestcopy.go @@ -0,0 +1,358 @@ +package compat + +import ( + "maps" + "slices" + + "github.com/maximhq/bifrost/core/schemas" +) + +func cloneBifrostReq(req *schemas.BifrostRequest) *schemas.BifrostRequest { + if req == nil { + return nil + } + + cloned := *req + + // Clone the inner request structs before swapping in cloned Params: + // cloned is a shallow copy, so writing through cloned.X.Params would + // otherwise mutate the caller's original request via the shared pointer. + if req.TextCompletionRequest != nil && req.TextCompletionRequest.Params != nil { + textReq := *req.TextCompletionRequest + textReq.Params = cloneTextCompletionParameters(req.TextCompletionRequest.Params) + cloned.TextCompletionRequest = &textReq + } + if req.ChatRequest != nil && req.ChatRequest.Params != nil { + chatReq := *req.ChatRequest + chatReq.Params = cloneChatParameters(req.ChatRequest.Params) + cloned.ChatRequest = &chatReq + } + if req.ResponsesRequest != nil && req.ResponsesRequest.Params != nil { + respReq := *req.ResponsesRequest + respReq.Params = cloneResponsesParameters(req.ResponsesRequest.Params) + cloned.ResponsesRequest = &respReq + } + + return &cloned +} + +func cloneTextCompletionParameters(params *schemas.TextCompletionParameters) *schemas.TextCompletionParameters { + if params == nil { + return nil + } + cloned := *params + if params.LogitBias != nil { + logitBias := cloneStringFloat64Map(*params.LogitBias) + cloned.LogitBias = &logitBias + } + if params.Stop 
!= nil { + cloned.Stop = slices.Clone(params.Stop) + } + if params.StreamOptions != nil { + streamOptions := *params.StreamOptions + cloned.StreamOptions = &streamOptions + } + if params.ExtraParams != nil { + cloned.ExtraParams = cloneAnyMap(params.ExtraParams) + } + return &cloned +} + +func cloneChatParameters(params *schemas.ChatParameters) *schemas.ChatParameters { + if params == nil { + return nil + } + + cloned := *params + if params.Audio != nil { + audio := *params.Audio + cloned.Audio = &audio + } + if params.LogitBias != nil { + logitBias := cloneStringFloat64Map(*params.LogitBias) + cloned.LogitBias = &logitBias + } + if params.Metadata != nil { + metadata := cloneAnyMap(*params.Metadata) + cloned.Metadata = &metadata + } + if params.Modalities != nil { + cloned.Modalities = slices.Clone(params.Modalities) + } + if params.Prediction != nil { + prediction := *params.Prediction + prediction.Content = cloneAnyValue(params.Prediction.Content) + cloned.Prediction = &prediction + } + if params.Reasoning != nil { + reasoning := *params.Reasoning + cloned.Reasoning = &reasoning + } + if params.ResponseFormat != nil { + responseFormat := cloneAnyValue(*params.ResponseFormat) + cloned.ResponseFormat = &responseFormat + } + if params.StreamOptions != nil { + streamOptions := *params.StreamOptions + cloned.StreamOptions = &streamOptions + } + if params.Stop != nil { + cloned.Stop = slices.Clone(params.Stop) + } + if params.ToolChoice != nil { + cloned.ToolChoice = cloneChatToolChoice(params.ToolChoice) + } + if params.Tools != nil { + cloned.Tools = make([]schemas.ChatTool, len(params.Tools)) + for i, tool := range params.Tools { + cloned.Tools[i] = schemas.DeepCopyChatTool(tool) + } + } + if params.WebSearchOptions != nil { + cloned.WebSearchOptions = cloneChatWebSearchOptions(params.WebSearchOptions) + } + if params.ExtraParams != nil { + cloned.ExtraParams = cloneAnyMap(params.ExtraParams) + } + return &cloned +} + +func cloneChatToolChoice(choice 
*schemas.ChatToolChoice) *schemas.ChatToolChoice { + if choice == nil { + return nil + } + + cloned := &schemas.ChatToolChoice{} + if choice.ChatToolChoiceStr != nil { + value := *choice.ChatToolChoiceStr + cloned.ChatToolChoiceStr = &value + } + if choice.ChatToolChoiceStruct != nil { + choiceStruct := *choice.ChatToolChoiceStruct + if choice.ChatToolChoiceStruct.Function != nil { + function := *choice.ChatToolChoiceStruct.Function + choiceStruct.Function = &function + } + if choice.ChatToolChoiceStruct.Custom != nil { + custom := *choice.ChatToolChoiceStruct.Custom + choiceStruct.Custom = &custom + } + if choice.ChatToolChoiceStruct.AllowedTools != nil { + allowedTools := *choice.ChatToolChoiceStruct.AllowedTools + allowedTools.Tools = slices.Clone(choice.ChatToolChoiceStruct.AllowedTools.Tools) + choiceStruct.AllowedTools = &allowedTools + } + cloned.ChatToolChoiceStruct = &choiceStruct + } + return cloned +} + +func cloneChatWebSearchOptions(options *schemas.ChatWebSearchOptions) *schemas.ChatWebSearchOptions { + if options == nil { + return nil + } + + cloned := *options + if options.UserLocation != nil { + userLocation := *options.UserLocation + if options.UserLocation.Approximate != nil { + approximate := *options.UserLocation.Approximate + userLocation.Approximate = &approximate + } + cloned.UserLocation = &userLocation + } + return &cloned +} + +func cloneResponsesParameters(params *schemas.ResponsesParameters) *schemas.ResponsesParameters { + if params == nil { + return nil + } + + cloned := *params + if params.Include != nil { + cloned.Include = slices.Clone(params.Include) + } + if params.Metadata != nil { + metadata := cloneAnyMap(*params.Metadata) + cloned.Metadata = &metadata + } + if params.Reasoning != nil { + reasoning := *params.Reasoning + cloned.Reasoning = &reasoning + } + if params.StreamOptions != nil { + streamOptions := *params.StreamOptions + cloned.StreamOptions = &streamOptions + } + if params.Text != nil { + cloned.Text = 
cloneResponsesTextConfig(params.Text) + } + if params.ToolChoice != nil { + cloned.ToolChoice = cloneResponsesToolChoice(params.ToolChoice) + } + if params.Tools != nil { + cloned.Tools = make([]schemas.ResponsesTool, len(params.Tools)) + for i, tool := range params.Tools { + cloned.Tools[i] = cloneResponsesTool(tool) + } + } + if params.ExtraParams != nil { + cloned.ExtraParams = cloneAnyMap(params.ExtraParams) + } + return &cloned +} + +func cloneResponsesTextConfig(text *schemas.ResponsesTextConfig) *schemas.ResponsesTextConfig { + if text == nil { + return nil + } + + cloned := *text + if text.Format != nil { + format := *text.Format + if text.Format.JSONSchema != nil { + jsonSchema := *text.Format.JSONSchema + if text.Format.JSONSchema.Schema != nil { + schema := cloneAnyValue(*text.Format.JSONSchema.Schema) + jsonSchema.Schema = &schema + } + if text.Format.JSONSchema.Properties != nil { + properties := cloneAnyMap(*text.Format.JSONSchema.Properties) + jsonSchema.Properties = &properties + } + if text.Format.JSONSchema.Required != nil { + jsonSchema.Required = slices.Clone(text.Format.JSONSchema.Required) + } + if text.Format.JSONSchema.Defs != nil { + defs := cloneAnyMap(*text.Format.JSONSchema.Defs) + jsonSchema.Defs = &defs + } + if text.Format.JSONSchema.Definitions != nil { + definitions := cloneAnyMap(*text.Format.JSONSchema.Definitions) + jsonSchema.Definitions = &definitions + } + if text.Format.JSONSchema.Items != nil { + items := cloneAnyMap(*text.Format.JSONSchema.Items) + jsonSchema.Items = &items + } + if text.Format.JSONSchema.AnyOf != nil { + jsonSchema.AnyOf = cloneAnyMapSlice(text.Format.JSONSchema.AnyOf) + } + if text.Format.JSONSchema.OneOf != nil { + jsonSchema.OneOf = cloneAnyMapSlice(text.Format.JSONSchema.OneOf) + } + if text.Format.JSONSchema.AllOf != nil { + jsonSchema.AllOf = cloneAnyMapSlice(text.Format.JSONSchema.AllOf) + } + if text.Format.JSONSchema.Default != nil { + jsonSchema.Default = 
cloneAnyValue(text.Format.JSONSchema.Default) + } + if text.Format.JSONSchema.Enum != nil { + jsonSchema.Enum = slices.Clone(text.Format.JSONSchema.Enum) + } + if text.Format.JSONSchema.PropertyOrdering != nil { + jsonSchema.PropertyOrdering = slices.Clone(text.Format.JSONSchema.PropertyOrdering) + } + format.JSONSchema = &jsonSchema + } + cloned.Format = &format + } + return &cloned +} + +func cloneResponsesToolChoice(choice *schemas.ResponsesToolChoice) *schemas.ResponsesToolChoice { + if choice == nil { + return nil + } + + cloned := &schemas.ResponsesToolChoice{} + if choice.ResponsesToolChoiceStr != nil { + value := *choice.ResponsesToolChoiceStr + cloned.ResponsesToolChoiceStr = &value + } + if choice.ResponsesToolChoiceStruct != nil { + choiceStruct := *choice.ResponsesToolChoiceStruct + if choice.ResponsesToolChoiceStruct.Tools != nil { + choiceStruct.Tools = slices.Clone(choice.ResponsesToolChoiceStruct.Tools) + } + cloned.ResponsesToolChoiceStruct = &choiceStruct + } + return cloned +} + +func cloneResponsesTool(tool schemas.ResponsesTool) schemas.ResponsesTool { + data, err := schemas.MarshalSorted(tool) + if err != nil { + return tool + } + + var cloned schemas.ResponsesTool + if err := schemas.Unmarshal(data, &cloned); err != nil { + return tool + } + + return cloned +} + +func cloneStringFloat64Map(input map[string]float64) map[string]float64 { + if input == nil { + return nil + } + + cloned := make(map[string]float64, len(input)) + maps.Copy(cloned, input) + return cloned +} + +func cloneAnyMap(input map[string]any) map[string]any { + if input == nil { + return nil + } + + cloned := make(map[string]any, len(input)) + for key, value := range input { + cloned[key] = cloneAnyValue(value) + } + return cloned +} + +func cloneAnyMapSlice(input []map[string]any) []map[string]any { + if input == nil { + return nil + } + + cloned := make([]map[string]any, len(input)) + for i, value := range input { + cloned[i] = cloneAnyMap(value) + } + return cloned +} + 
+func cloneAnySlice(input []any) []any { + if input == nil { + return nil + } + + cloned := make([]any, len(input)) + for i, value := range input { + cloned[i] = cloneAnyValue(value) + } + return cloned +} + +func cloneAnyValue(value any) any { + switch typed := value.(type) { + case nil: + return nil + case map[string]any: + return cloneAnyMap(typed) + case []any: + return cloneAnySlice(typed) + case []string: + return slices.Clone(typed) + case map[string]string: + cloned := make(map[string]string, len(typed)) + maps.Copy(cloned, typed) + return cloned + default: + return typed + } +} diff --git a/plugins/compat/version b/plugins/compat/version new file mode 100644 index 0000000000..6c6aa7cb09 --- /dev/null +++ b/plugins/compat/version @@ -0,0 +1 @@ +0.1.0 \ No newline at end of file diff --git a/plugins/litellmcompat/context.go b/plugins/litellmcompat/context.go deleted file mode 100644 index 1bcc79b99c..0000000000 --- a/plugins/litellmcompat/context.go +++ /dev/null @@ -1,20 +0,0 @@ -package litellmcompat - -import "github.com/maximhq/bifrost/core/schemas" - -// TransformContextKey is the key used to store TransformContext in BifrostContext -const TransformContextKey schemas.BifrostContextKey = "litellmcompat-transform-context" - -// TransformContext tracks what transformations were applied to a request -// so they can be reversed on the response -type TransformContext struct { - // Text-to-chat transform state - // TextToChatApplied indicates that a text completion request was converted to chat - TextToChatApplied bool - // OriginalRequestType stores the original request type before transformation - OriginalRequestType schemas.RequestType - // OriginalModel preserves the original model string for response metadata - OriginalModel string - // IsStreaming indicates if the original request was a streaming request - IsStreaming bool -} diff --git a/plugins/litellmcompat/main.go b/plugins/litellmcompat/main.go deleted file mode 100644 index 730d56beef..0000000000 
--- a/plugins/litellmcompat/main.go +++ /dev/null @@ -1,128 +0,0 @@ -// Package litellmcompat provides LiteLLM-compatible request/response transformations -// for the Bifrost gateway. It enables automatic conversion of text completion requests -// to chat completion requests for models that only support chat completions, matching -// LiteLLM's behavior. -// -// When enabled, this plugin: -// - Silently converts text_completion() calls to chat completion format -// - Routes to the chat completion endpoint -// - Transforms the response back to text completion format -// - Places content in choices[0].text instead of choices[0].message.content -package litellmcompat - -import ( - "github.com/maximhq/bifrost/core/schemas" - "github.com/maximhq/bifrost/framework/modelcatalog" -) - -const ( - PluginName = "litellmcompat" -) - -// Config defines the configuration for the litellmcompat plugin -type Config struct { - Enabled bool `json:"enabled"` -} - -// LiteLLMCompatPlugin provides LiteLLM-compatible request/response transformations. -// When enabled, it automatically converts text completion requests to chat completion -// requests for models that only support chat completions, matching LiteLLM's behavior. -type LiteLLMCompatPlugin struct { - config Config - logger schemas.Logger - modelCatalog *modelcatalog.ModelCatalog -} - -// Init creates a new litellmcompat plugin instance -func Init(config Config, logger schemas.Logger) (*LiteLLMCompatPlugin, error) { - return &LiteLLMCompatPlugin{ - config: config, - logger: logger, - }, nil -} - -// InitWithModelCatalog creates a new litellmcompat plugin instance with model catalog support. -// The model catalog is used to determine if a model supports text completion natively. -// If the model catalog is nil, the plugin will convert ALL text completion requests. 
-func InitWithModelCatalog(config Config, logger schemas.Logger, mc *modelcatalog.ModelCatalog) (*LiteLLMCompatPlugin, error) { - return &LiteLLMCompatPlugin{ - config: config, - logger: logger, - modelCatalog: mc, - }, nil -} - -// SetModelCatalog sets the model catalog for checking text completion support. -// This can be called after initialization to add model catalog support. -func (p *LiteLLMCompatPlugin) SetModelCatalog(mc *modelcatalog.ModelCatalog) { - p.modelCatalog = mc -} - -// GetName returns the plugin name -func (p *LiteLLMCompatPlugin) GetName() string { - return PluginName -} - -// HTTPTransportPreHook is not used for this plugin -func (p *LiteLLMCompatPlugin) HTTPTransportPreHook(ctx *schemas.BifrostContext, req *schemas.HTTPRequest) (*schemas.HTTPResponse, error) { - return nil, nil -} - -// HTTPTransportPostHook is not used for this plugin -func (p *LiteLLMCompatPlugin) HTTPTransportPostHook(ctx *schemas.BifrostContext, req *schemas.HTTPRequest, resp *schemas.HTTPResponse) error { - return nil -} - -// HTTPTransportStreamChunkHook passes through streaming chunks unchanged -func (p *LiteLLMCompatPlugin) HTTPTransportStreamChunkHook(ctx *schemas.BifrostContext, req *schemas.HTTPRequest, chunk *schemas.BifrostStreamChunk) (*schemas.BifrostStreamChunk, error) { - return chunk, nil -} - -// PreLLMHook intercepts requests and applies LiteLLM-compatible transformations. -// For text completion requests on models that don't support text completion, -// it converts them to chat completion requests. 
-func (p *LiteLLMCompatPlugin) PreLLMHook(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.LLMPluginShortCircuit, error) { - tc := &TransformContext{} - - // Apply request transforms in sequence - req = transformTextToChatRequest(ctx, req, tc, p.modelCatalog, p.logger) - - // Store the transform context for use in PostHook - ctx.SetValue(TransformContextKey, tc) - - return req, nil, nil -} - -// PostLLMHook processes responses and applies LiteLLM-compatible transformations. -// If a text completion request was converted to chat, this converts the -// chat response back to text completion format. -func (p *LiteLLMCompatPlugin) PostLLMHook(ctx *schemas.BifrostContext, result *schemas.BifrostResponse, bifrostErr *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError, error) { - // Retrieve the transform context - transformCtxValue := ctx.Value(TransformContextKey) - if transformCtxValue == nil { - return result, bifrostErr, nil - } - tc, ok := transformCtxValue.(*TransformContext) - if !ok || tc == nil { - return result, bifrostErr, nil - } - - // Apply response transforms in sequence - // Note: tool-call content runs before text-to-chat because text-to-chat may convert - // the response type, and tool-call content needs to operate on chat responses - if result != nil { - result = transformTextToChatResponse(ctx, result, tc, p.logger) - } - - // Transform error metadata if there's an error - if bifrostErr != nil { - bifrostErr = transformTextToChatError(ctx, bifrostErr, tc) - } - - return result, bifrostErr, nil -} - -// Cleanup performs plugin cleanup -func (p *LiteLLMCompatPlugin) Cleanup() error { - return nil -} diff --git a/plugins/litellmcompat/texttochat.go b/plugins/litellmcompat/texttochat.go deleted file mode 100644 index b0c1b0a309..0000000000 --- a/plugins/litellmcompat/texttochat.go +++ /dev/null @@ -1,151 +0,0 @@ -package litellmcompat - -import ( - "github.com/maximhq/bifrost/core/schemas" - 
"github.com/maximhq/bifrost/framework/modelcatalog" -) - -// transformTextToChatRequest converts a text completion request to a chat completion request -// if the model doesn't support text completion natively. -// It updates the TransformContext with the transformation state. -func transformTextToChatRequest(_ *schemas.BifrostContext, req *schemas.BifrostRequest, tc *TransformContext, mc *modelcatalog.ModelCatalog, logger schemas.Logger) *schemas.BifrostRequest { - // Only process text completion requests - if req.RequestType != schemas.TextCompletionRequest && req.RequestType != schemas.TextCompletionStreamRequest { - return req - } - - // Check if text completion request is present - if req.TextCompletionRequest == nil || tc == nil { - return req - } - - // Check if the model supports text completion via model catalog - if mc != nil { - provider := req.TextCompletionRequest.Provider - model := req.TextCompletionRequest.Model - if mc.IsTextCompletionSupported(model, provider) { - if logger != nil { - logger.Debug("litellmcompat: model %s/%s supports text completion, skipping conversion", provider, model) - } - return req - } - } - - // Convert text completion to chat completion - chatRequest := req.TextCompletionRequest.ToBifrostChatRequest() - if chatRequest == nil { - return req - } - - // Track the transformation - tc.TextToChatApplied = true - tc.OriginalRequestType = req.RequestType - tc.OriginalModel = req.TextCompletionRequest.Model - tc.IsStreaming = req.RequestType == schemas.TextCompletionStreamRequest - - // Create a new request with the chat completion - transformedReq := &schemas.BifrostRequest{ - ChatRequest: chatRequest, - } - - // Set the appropriate request type - if tc.IsStreaming { - transformedReq.RequestType = schemas.ChatCompletionStreamRequest - } else { - transformedReq.RequestType = schemas.ChatCompletionRequest - } - - if logger != nil { - logger.Debug("litellmcompat: converted text completion to chat completion for model %s (text 
completion not supported)", tc.OriginalModel) - } - - return transformedReq -} - -// transformTextToChatResponse converts a chat response back to text completion format -// if the original request was a text completion that was converted. -func transformTextToChatResponse(_ *schemas.BifrostContext, resp *schemas.BifrostResponse, tc *TransformContext, logger schemas.Logger) *schemas.BifrostResponse { - // Only transform if we converted text completion to chat - if !tc.TextToChatApplied { - return resp - } - - // Check if we have a chat response to transform - if resp == nil || resp.ChatResponse == nil { - return resp - } - - // Convert chat response to text completion response - textCompletionResponse := resp.ChatResponse.ToTextCompletionResponse() - if textCompletionResponse == nil { - return resp - } - - // Restore original request type metadata - textCompletionResponse.ExtraFields.RequestType = tc.OriginalRequestType - textCompletionResponse.ExtraFields.OriginalModelRequested = tc.OriginalModel - textCompletionResponse.ExtraFields.LiteLLMCompat = true - - if logger != nil { - logger.Debug("litellmcompat: converted chat response back to text completion for model %s", tc.OriginalModel) - } - - // Return a new response with the text completion - return &schemas.BifrostResponse{ - TextCompletionResponse: textCompletionResponse, - } -} - -// transformTextToChatError ensures error metadata reflects the original request type -// if a text completion request was converted to chat. 
-func transformTextToChatError(_ *schemas.BifrostContext, err *schemas.BifrostError, tc *TransformContext) *schemas.BifrostError { - if tc == nil || err == nil { - return err - } - - // Only transform if we converted text completion to chat - if !tc.TextToChatApplied { - return err - } - - // Restore original request type in error metadata - err.ExtraFields.RequestType = tc.OriginalRequestType - err.ExtraFields.OriginalModelRequested = tc.OriginalModel - err.ExtraFields.LiteLLMCompat = true - - return err -} - -// TransformTextToChatStreamResponse transforms a streaming chat response back to text completion format. -// This is exported for use by streaming handlers. -func TransformTextToChatStreamResponse(ctx *schemas.BifrostContext, stream *schemas.BifrostStreamChunk, tc *TransformContext) *schemas.BifrostStreamChunk { - if tc == nil { - return stream - } - - // Only transform if we converted text completion to chat - if !tc.TextToChatApplied { - return stream - } - - // Check if we have a chat response in the stream to transform - if stream == nil || stream.BifrostChatResponse == nil { - return stream - } - - // Convert chat response to text completion response - textCompletionResponse := stream.BifrostChatResponse.ToTextCompletionResponse() - if textCompletionResponse == nil { - return stream - } - - // Restore original request type metadata - textCompletionResponse.ExtraFields.RequestType = tc.OriginalRequestType - textCompletionResponse.ExtraFields.OriginalModelRequested = tc.OriginalModel - textCompletionResponse.ExtraFields.LiteLLMCompat = true - - // Return a new stream with the text completion response - return &schemas.BifrostStreamChunk{ - BifrostTextCompletionResponse: textCompletionResponse, - } -} diff --git a/tests/governance/config.json b/tests/governance/config.json index bd9080a064..b8cedf9a3e 100644 --- a/tests/governance/config.json +++ b/tests/governance/config.json @@ -62,7 +62,6 @@ "enable_logging": true, "enforce_auth_on_inference": true, 
"allow_direct_keys": false, - "max_request_body_size_mb": 100, - "enable_litellm_fallbacks": false + "max_request_body_size_mb": 100 } } diff --git a/tests/integrations/python/config.json b/tests/integrations/python/config.json index 00b89b5bdb..866469cc1d 100644 --- a/tests/integrations/python/config.json +++ b/tests/integrations/python/config.json @@ -343,7 +343,6 @@ "enable_logging": true, "enforce_auth_on_inference": false, "allow_direct_keys": false, - "max_request_body_size_mb": 100, - "enable_litellm_fallbacks": false + "max_request_body_size_mb": 100 } } diff --git a/tests/integrations/typescript/config.json b/tests/integrations/typescript/config.json index cf49dba281..46bc65af6b 100644 --- a/tests/integrations/typescript/config.json +++ b/tests/integrations/typescript/config.json @@ -220,7 +220,6 @@ "enable_logging": true, "enforce_auth_on_inference": false, "allow_direct_keys": false, - "max_request_body_size_mb": 100, - "enable_litellm_fallbacks": false + "max_request_body_size_mb": 100 } } diff --git a/transports/bifrost-http/handlers/config.go b/transports/bifrost-http/handlers/config.go index 4f08f31b12..dee74227c7 100644 --- a/transports/bifrost-http/handlers/config.go +++ b/transports/bifrost-http/handlers/config.go @@ -19,7 +19,7 @@ import ( configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables" "github.com/maximhq/bifrost/framework/encrypt" "github.com/maximhq/bifrost/framework/modelcatalog" - "github.com/maximhq/bifrost/plugins/litellmcompat" + "github.com/maximhq/bifrost/plugins/compat" "github.com/maximhq/bifrost/transports/bifrost-http/lib" "github.com/valyala/fasthttp" ) @@ -88,7 +88,7 @@ func (h *ConfigHandler) getVersion(ctx *fasthttp.RequestCtx) { // getConfig handles GET /config - Get the current configuration func (h *ConfigHandler) getConfig(ctx *fasthttp.RequestCtx) { - var mapConfig = make(map[string]any) + mapConfig := make(map[string]any) if query := string(ctx.QueryArgs().Peek("from_db")); query == "true" { if 
h.store.ConfigStore == nil { @@ -342,22 +342,33 @@ func (h *ConfigHandler) updateConfig(ctx *fasthttp.RequestCtx) { updatedConfig.MaxRequestBodySizeMB = payload.ClientConfig.MaxRequestBodySizeMB } - // Handle LiteLLM compat plugin toggle - if payload.ClientConfig.EnableLiteLLMFallbacks != currentConfig.EnableLiteLLMFallbacks { - if payload.ClientConfig.EnableLiteLLMFallbacks { - // Load and register the litellmcompat plugin - if err := h.configManager.ReloadPlugin(ctx, "litellmcompat", nil, &litellmcompat.Config{Enabled: true}, nil, nil); err != nil { - logger.Warn(fmt.Sprintf("failed to load litellmcompat plugin: %v", err)) + // Handle compat plugin toggle + newCompat := payload.ClientConfig.Compat + oldCompat := currentConfig.Compat + if newCompat != oldCompat { + newEnabled := newCompat.ConvertTextToChat || newCompat.ConvertChatToResponses || newCompat.ShouldDropParams || newCompat.ShouldConvertParams + if newEnabled { + compatCfg := &compat.Config{ + ConvertTextToChat: newCompat.ConvertTextToChat, + ConvertChatToResponses: newCompat.ConvertChatToResponses, + ShouldDropParams: newCompat.ShouldDropParams, + ShouldConvertParams: newCompat.ShouldConvertParams, + } + if err := h.configManager.ReloadPlugin(ctx, compat.PluginName, nil, compatCfg, nil, nil); err != nil { + logger.Warn("failed to load compat plugin: %v", err) + SendError(ctx, 400, "Failed to load compat plugin") + return } } else { - // Remove the litellmcompat plugin disabledCtx := context.WithValue(ctx, PluginDisabledKey, true) - if err := h.configManager.RemovePlugin(disabledCtx, "litellmcompat"); err != nil { - logger.Warn("failed to remove litellmcompat plugin: %v", err) + if err := h.configManager.RemovePlugin(disabledCtx, compat.PluginName); err != nil { + logger.Warn("failed to remove compat plugin: %v", err) + SendError(ctx, 400, "Failed to remove compat plugin") + return } } } - updatedConfig.EnableLiteLLMFallbacks = payload.ClientConfig.EnableLiteLLMFallbacks + updatedConfig.Compat = 
newCompat // Only update MCP fields if explicitly provided (non-zero) to avoid clearing stored values if payload.ClientConfig.MCPAgentDepth > 0 { updatedConfig.MCPAgentDepth = payload.ClientConfig.MCPAgentDepth diff --git a/transports/bifrost-http/lib/config.go b/transports/bifrost-http/lib/config.go index 2954520ce3..0d54489341 100644 --- a/transports/bifrost-http/lib/config.go +++ b/transports/bifrost-http/lib/config.go @@ -36,8 +36,8 @@ import ( "github.com/maximhq/bifrost/framework/oauth2" plugins "github.com/maximhq/bifrost/framework/plugins" "github.com/maximhq/bifrost/framework/vectorstore" + "github.com/maximhq/bifrost/plugins/compat" "github.com/maximhq/bifrost/plugins/governance" - "github.com/maximhq/bifrost/plugins/litellmcompat" "github.com/maximhq/bifrost/plugins/logging" "github.com/maximhq/bifrost/plugins/maxim" "github.com/maximhq/bifrost/plugins/otel" @@ -107,7 +107,7 @@ func IsBuiltinPlugin(name string) bool { name == prompts.PluginName || name == logging.PluginName || name == governance.PluginName || - name == litellmcompat.PluginName || + name == compat.PluginName || name == maxim.PluginName || name == semanticcache.PluginName || name == otel.PluginName @@ -316,7 +316,6 @@ var DefaultClientConfig = configstore.ClientConfig{ MCPAgentDepth: 10, MCPToolExecutionTimeout: 30, MCPCodeModeBindingLevel: string(schemas.CodeModeBindingLevelServer), - EnableLiteLLMFallbacks: false, HideDeletedVirtualKeysInFilters: false, RoutingChainMaxDepth: governance.DefaultRoutingChainMaxDepth, } @@ -4143,4 +4142,4 @@ func DeepCopy[T any](in T) (T, error) { } err = sonic.Unmarshal(b, &out) return out, err -} +} \ No newline at end of file diff --git a/transports/bifrost-http/lib/config_test.go b/transports/bifrost-http/lib/config_test.go index e5006f8384..9d65e2f436 100644 --- a/transports/bifrost-http/lib/config_test.go +++ b/transports/bifrost-http/lib/config_test.go @@ -425,6 +425,7 @@ func (m *MockConfigStore) DB() *gorm.DB { retu func (m *MockConfigStore) 
ExecuteTransaction(ctx context.Context, fn func(tx *gorm.DB) error) error { return fn(nil) } + func (m *MockConfigStore) RunMigration(ctx context.Context, migration *migrator.Migration) error { return nil } @@ -1155,18 +1156,23 @@ func (m *MockConfigStore) DeleteOauthToken(ctx context.Context, id string) error func (m *MockConfigStore) GetOauthUserSessionByID(ctx context.Context, id string) (*tables.TableOauthUserSession, error) { return nil, nil } + func (m *MockConfigStore) GetOauthUserSessionByState(ctx context.Context, state string) (*tables.TableOauthUserSession, error) { return nil, nil } + func (m *MockConfigStore) ClaimOauthUserSessionByState(ctx context.Context, state string) (*tables.TableOauthUserSession, error) { return nil, nil } + func (m *MockConfigStore) GetOauthUserSessionBySessionToken(ctx context.Context, sessionToken string) (*tables.TableOauthUserSession, error) { return nil, nil } + func (m *MockConfigStore) CreateOauthUserSession(ctx context.Context, session *tables.TableOauthUserSession) error { return nil } + func (m *MockConfigStore) UpdateOauthUserSession(ctx context.Context, session *tables.TableOauthUserSession) error { return nil } @@ -1175,18 +1181,23 @@ func (m *MockConfigStore) UpdateOauthUserSession(ctx context.Context, session *t func (m *MockConfigStore) GetOauthUserTokenByIdentity(ctx context.Context, virtualKeyID, userID, sessionToken, mcpClientID string) (*tables.TableOauthUserToken, error) { return nil, nil } + func (m *MockConfigStore) GetOauthUserTokenBySessionToken(ctx context.Context, sessionToken string) (*tables.TableOauthUserToken, error) { return nil, nil } + func (m *MockConfigStore) CreateOauthUserToken(ctx context.Context, token *tables.TableOauthUserToken) error { return nil } + func (m *MockConfigStore) UpdateOauthUserToken(ctx context.Context, token *tables.TableOauthUserToken) error { return nil } + func (m *MockConfigStore) DeleteOauthUserToken(ctx context.Context, id string) error { return nil } + func (m 
*MockConfigStore) DeleteOauthUserTokensByMCPClient(ctx context.Context, mcpClientID string) error { return nil } @@ -1195,33 +1206,43 @@ func (m *MockConfigStore) DeleteOauthUserTokensByMCPClient(ctx context.Context, func (m *MockConfigStore) GetPerUserOAuthClientByClientID(ctx context.Context, clientID string) (*tables.TablePerUserOAuthClient, error) { return nil, nil } + func (m *MockConfigStore) CreatePerUserOAuthClient(ctx context.Context, client *tables.TablePerUserOAuthClient) error { return nil } + func (m *MockConfigStore) GetPerUserOAuthSessionByAccessToken(ctx context.Context, accessToken string) (*tables.TablePerUserOAuthSession, error) { return nil, nil } + func (m *MockConfigStore) GetPerUserOAuthSessionByID(ctx context.Context, id string) (*tables.TablePerUserOAuthSession, error) { return nil, nil } + func (m *MockConfigStore) CreatePerUserOAuthSession(ctx context.Context, session *tables.TablePerUserOAuthSession) error { return nil } + func (m *MockConfigStore) UpdatePerUserOAuthSession(ctx context.Context, session *tables.TablePerUserOAuthSession) error { return nil } + func (m *MockConfigStore) DeletePerUserOAuthSession(ctx context.Context, id string) error { return nil } + func (m *MockConfigStore) GetPerUserOAuthCodeByCode(ctx context.Context, code string) (*tables.TablePerUserOAuthCode, error) { return nil, nil } + func (m *MockConfigStore) ClaimPerUserOAuthCode(ctx context.Context, code string) (*tables.TablePerUserOAuthCode, error) { return nil, nil } + func (m *MockConfigStore) CreatePerUserOAuthCode(ctx context.Context, code *tables.TablePerUserOAuthCode) error { return nil } + func (m *MockConfigStore) UpdatePerUserOAuthCode(ctx context.Context, code *tables.TablePerUserOAuthCode) error { return nil } @@ -1229,24 +1250,31 @@ func (m *MockConfigStore) UpdatePerUserOAuthCode(ctx context.Context, code *tabl func (m *MockConfigStore) GetPerUserOAuthPendingFlow(ctx context.Context, id string) (*tables.TablePerUserOAuthPendingFlow, error) { 
return nil, nil } + func (m *MockConfigStore) CreatePerUserOAuthPendingFlow(ctx context.Context, flow *tables.TablePerUserOAuthPendingFlow) error { return nil } + func (m *MockConfigStore) UpdatePerUserOAuthPendingFlow(ctx context.Context, flow *tables.TablePerUserOAuthPendingFlow) error { return nil } + func (m *MockConfigStore) DeletePerUserOAuthPendingFlow(ctx context.Context, id string) error { return nil } + func (m *MockConfigStore) ConsumePerUserOAuthPendingFlow(ctx context.Context, id string) (int64, error) { return 1, nil } + func (m *MockConfigStore) GetOauthUserTokensByGatewaySessionID(ctx context.Context, gatewaySessionID string) ([]tables.TableOauthUserToken, error) { return nil, nil } + func (m *MockConfigStore) TransferOauthUserTokensFromGatewaySession(ctx context.Context, gatewaySessionID, realSessionToken, virtualKeyID, userID string) error { return nil } + func (m *MockConfigStore) FinalizePerUserOAuthConsent(ctx context.Context, flowID string, session *tables.TablePerUserOAuthSession, code *tables.TablePerUserOAuthCode) (int64, error) { return 1, nil } @@ -1288,12 +1316,15 @@ func (m *MockConfigStore) DeleteRoutingRule(ctx context.Context, id string, tx . 
func (m *MockConfigStore) GetFolders(ctx context.Context) ([]tables.TableFolder, error) { return nil, nil } + func (m *MockConfigStore) GetFolderByID(ctx context.Context, id string) (*tables.TableFolder, error) { return nil, nil } + func (m *MockConfigStore) CreateFolder(ctx context.Context, folder *tables.TableFolder) error { return nil } + func (m *MockConfigStore) UpdateFolder(ctx context.Context, folder *tables.TableFolder) error { return nil } @@ -1303,12 +1334,15 @@ func (m *MockConfigStore) DeleteFolder(ctx context.Context, id string) error { r func (m *MockConfigStore) GetPrompts(ctx context.Context, folderID *string) ([]tables.TablePrompt, error) { return nil, nil } + func (m *MockConfigStore) GetPromptByID(ctx context.Context, id string) (*tables.TablePrompt, error) { return nil, nil } + func (m *MockConfigStore) CreatePrompt(ctx context.Context, prompt *tables.TablePrompt) error { return nil } + func (m *MockConfigStore) UpdatePrompt(ctx context.Context, prompt *tables.TablePrompt) error { return nil } @@ -1318,15 +1352,19 @@ func (m *MockConfigStore) DeletePrompt(ctx context.Context, id string) error { r func (m *MockConfigStore) GetPromptVersions(ctx context.Context, promptID string) ([]tables.TablePromptVersion, error) { return nil, nil } + func (m *MockConfigStore) GetAllPromptVersions(ctx context.Context) ([]tables.TablePromptVersion, error) { return nil, nil } + func (m *MockConfigStore) GetPromptVersionByID(ctx context.Context, id uint) (*tables.TablePromptVersion, error) { return nil, nil } + func (m *MockConfigStore) GetLatestPromptVersion(ctx context.Context, promptID string) (*tables.TablePromptVersion, error) { return nil, nil } + func (m *MockConfigStore) CreatePromptVersion(ctx context.Context, version *tables.TablePromptVersion) error { return nil } @@ -1336,15 +1374,19 @@ func (m *MockConfigStore) DeletePromptVersion(ctx context.Context, id uint) erro func (m *MockConfigStore) GetPromptSessions(ctx context.Context, promptID string) 
([]tables.TablePromptSession, error) { return nil, nil } + func (m *MockConfigStore) GetPromptSessionByID(ctx context.Context, id uint) (*tables.TablePromptSession, error) { return nil, nil } + func (m *MockConfigStore) CreatePromptSession(ctx context.Context, session *tables.TablePromptSession) error { return nil } + func (m *MockConfigStore) UpdatePromptSession(ctx context.Context, session *tables.TablePromptSession) error { return nil } + func (m *MockConfigStore) RenamePromptSession(ctx context.Context, id uint, name string) error { return nil } @@ -12006,6 +12048,7 @@ type mockLLMPlugin struct { func (p *mockLLMPlugin) PreLLMHook(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.LLMPluginShortCircuit, error) { return req, nil, nil } + func (p *mockLLMPlugin) PostLLMHook(ctx *schemas.BifrostContext, resp *schemas.BifrostResponse, bifrostErr *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError, error) { return resp, bifrostErr, nil } @@ -12349,7 +12392,6 @@ func TestGenerateClientConfigHash(t *testing.T) { AllowDirectKeys: true, AllowedOrigins: []string{"http://localhost:3000"}, MaxRequestBodySizeMB: 100, - EnableLiteLLMFallbacks: false, } hash1, err := cc1.GenerateClientConfigHash() @@ -12446,12 +12488,12 @@ func TestGenerateClientConfigHash(t *testing.T) { t.Error("Different MaxRequestBodySizeMB should produce different hash") } - // Different EnableLiteLLMFallbacks should produce different hash + // Different Compat should produce different hash cc13 := cc1 - cc13.EnableLiteLLMFallbacks = true + cc13.Compat.ConvertTextToChat = true hash13, _ := cc13.GenerateClientConfigHash() if hash1 == hash13 { - t.Error("Different EnableLiteLLMFallbacks should produce different hash") + t.Error("Different Compat.ConvertTextToChat should produce different hash") } // PrometheusLabels order should not matter (sorted) @@ -13484,7 +13526,6 @@ func TestGenerateClientConfigHash_RuntimeVsMigrationParity(t *testing.T) { 
EnforceAuthOnInference: false, AllowDirectKeys: true, MaxRequestBodySizeMB: 100, - EnableLiteLLMFallbacks: false, } // Generate hash from config @@ -13498,7 +13539,12 @@ func TestGenerateClientConfigHash_RuntimeVsMigrationParity(t *testing.T) { EnforceAuthOnInference: ccToSave.EnforceAuthOnInference, AllowDirectKeys: ccToSave.AllowDirectKeys, MaxRequestBodySizeMB: ccToSave.MaxRequestBodySizeMB, - EnableLiteLLMFallbacks: ccToSave.EnableLiteLLMFallbacks, + Compat: configstore.CompatConfig{ + ConvertTextToChat: ccToSave.CompatConvertTextToChat, + ConvertChatToResponses: ccToSave.CompatConvertChatToResponses, + ShouldDropParams: ccToSave.CompatShouldDropParams, + ShouldConvertParams: ccToSave.CompatShouldConvertParams, + }, } hashBeforeSave, _ := clientConfig.GenerateClientConfigHash() @@ -13517,7 +13563,12 @@ func TestGenerateClientConfigHash_RuntimeVsMigrationParity(t *testing.T) { EnforceAuthOnInference: ccFromDB.EnforceAuthOnInference, AllowDirectKeys: ccFromDB.AllowDirectKeys, MaxRequestBodySizeMB: ccFromDB.MaxRequestBodySizeMB, - EnableLiteLLMFallbacks: ccFromDB.EnableLiteLLMFallbacks, + Compat: configstore.CompatConfig{ + ConvertTextToChat: ccFromDB.CompatConvertTextToChat, + ConvertChatToResponses: ccFromDB.CompatConvertChatToResponses, + ShouldDropParams: ccFromDB.CompatShouldDropParams, + ShouldConvertParams: ccFromDB.CompatShouldConvertParams, + }, } hashAfterLoad, _ := clientConfigFromDB.GenerateClientConfigHash() @@ -15674,13 +15725,13 @@ func TestConfigSchemaSyncTopLevel(t *testing.T) { // Enterprise-only features: These fields exist in the JSON schema for documentation // and validation purposes, but are only available in the enterprise version. 
enterpriseSchemaFields := map[string]bool{ - "$schema": true, - "audit_logs": true, - "cluster_config": true, - "saml_config": true, - "load_balancer_config": true, - "guardrails_config": true, - "large_payload_optimization": true, + "$schema": true, + "audit_logs": true, + "cluster_config": true, + "saml_config": true, + "load_balancer_config": true, + "guardrails_config": true, + "large_payload_optimization": true, } schema := loadJSONSchema(t) @@ -16627,7 +16678,10 @@ func assertDefaultClientConfigValues(t *testing.T, cc configstore.ClientConfig) require.Equal(t, 100, cc.MaxRequestBodySizeMB, "MaxRequestBodySizeMB should default to 100") require.Equal(t, 10, cc.MCPAgentDepth, "MCPAgentDepth should default to 10") require.Equal(t, 30, cc.MCPToolExecutionTimeout, "MCPToolExecutionTimeout should default to 30") - require.Equal(t, false, cc.EnableLiteLLMFallbacks, "EnableLiteLLMFallbacks should default to false") + require.Equal(t, false, cc.Compat.ConvertTextToChat, "Compat.ConvertTextToChat should default to false") + require.Equal(t, false, cc.Compat.ConvertChatToResponses, "Compat.ConvertChatToResponses should default to false") + require.Equal(t, false, cc.Compat.ShouldDropParams, "Compat.ShouldDropParams should default to false") + require.Equal(t, false, cc.Compat.ShouldConvertParams, "Compat.ShouldConvertParams should default to false") require.Equal(t, false, cc.HideDeletedVirtualKeysInFilters, "HideDeletedVirtualKeysInFilters should default to false") } diff --git a/transports/bifrost-http/lib/ctx.go b/transports/bifrost-http/lib/ctx.go index 36c56cc2e5..7f6127406d 100644 --- a/transports/bifrost-http/lib/ctx.go +++ b/transports/bifrost-http/lib/ctx.go @@ -8,6 +8,7 @@ package lib import ( "context" + "encoding/json" "fmt" "strconv" "strings" @@ -443,6 +444,47 @@ func ConvertToBifrostContext(ctx *fasthttp.RequestCtx, allowDirectKeys bool, mat } return true } + + // Compat header: per-request override of compat plugin settings. 
+ // Accepts: "true" (enable all), JSON array of feature names, or ["*"] (enable all). + // An empty array [] or absent header means no overrides. + if keyStr == "x-bf-compat" { + bifrostCtx.ClearValue(schemas.BifrostContextKeyCompatConvertTextToChat) + bifrostCtx.ClearValue(schemas.BifrostContextKeyCompatConvertChatToResponses) + bifrostCtx.ClearValue(schemas.BifrostContextKeyCompatShouldDropParams) + bifrostCtx.ClearValue(schemas.BifrostContextKeyCompatShouldConvertParams) + valueStr := strings.TrimSpace(string(value)) + if valueStr == "true" { + bifrostCtx.SetValue(schemas.BifrostContextKeyCompatConvertTextToChat, true) + bifrostCtx.SetValue(schemas.BifrostContextKeyCompatConvertChatToResponses, true) + bifrostCtx.SetValue(schemas.BifrostContextKeyCompatShouldDropParams, true) + bifrostCtx.SetValue(schemas.BifrostContextKeyCompatShouldConvertParams, true) + } else if strings.HasPrefix(valueStr, "[") { + var features []string + if err := json.Unmarshal([]byte(valueStr), &features); err == nil { + if len(features) == 1 && features[0] == "*" { + bifrostCtx.SetValue(schemas.BifrostContextKeyCompatConvertTextToChat, true) + bifrostCtx.SetValue(schemas.BifrostContextKeyCompatConvertChatToResponses, true) + bifrostCtx.SetValue(schemas.BifrostContextKeyCompatShouldDropParams, true) + bifrostCtx.SetValue(schemas.BifrostContextKeyCompatShouldConvertParams, true) + } else { + for _, f := range features { + switch f { + case "convert_text_to_chat": + bifrostCtx.SetValue(schemas.BifrostContextKeyCompatConvertTextToChat, true) + case "convert_chat_to_responses": + bifrostCtx.SetValue(schemas.BifrostContextKeyCompatConvertChatToResponses, true) + case "should_drop_params": + bifrostCtx.SetValue(schemas.BifrostContextKeyCompatShouldDropParams, true) + case "should_convert_params": + bifrostCtx.SetValue(schemas.BifrostContextKeyCompatShouldConvertParams, true) + } + } + } + } + } + return true + } return true }) @@ -568,4 +610,4 @@ func BuildHTTPRequestFromFastHTTP(ctx 
*fasthttp.RequestCtx) *schemas.HTTPRequest // Note: Body not copied - for streaming, body was already consumed return req -} +} \ No newline at end of file diff --git a/transports/bifrost-http/server/plugins.go b/transports/bifrost-http/server/plugins.go index 3cdf2f31fa..031dadd73b 100644 --- a/transports/bifrost-http/server/plugins.go +++ b/transports/bifrost-http/server/plugins.go @@ -6,8 +6,8 @@ import ( "slices" "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/plugins/compat" "github.com/maximhq/bifrost/plugins/governance" - "github.com/maximhq/bifrost/plugins/litellmcompat" "github.com/maximhq/bifrost/plugins/logging" "github.com/maximhq/bifrost/plugins/maxim" "github.com/maximhq/bifrost/plugins/otel" @@ -105,12 +105,12 @@ func loadBuiltinPlugin(ctx context.Context, name string, pluginConfig any, bifro } return otel.Init(ctx, otelConfig, logger, bifrostConfig.ModelCatalog, handlers.GetVersion()) - case litellmcompat.PluginName: - litellmConfig, err := MarshalPluginConfig[litellmcompat.Config](pluginConfig) + case compat.PluginName: + compatConfig, err := MarshalPluginConfig[compat.Config](pluginConfig) if err != nil { - return nil, fmt.Errorf("failed to marshal litellmcompat plugin config: %w", err) + return nil, fmt.Errorf("failed to marshal compat plugin config: %w", err) } - return litellmcompat.Init(*litellmConfig, logger) + return compat.Init(*compatConfig, logger, bifrostConfig.ModelCatalog) default: return nil, fmt.Errorf("unknown built-in plugin: %s", name) @@ -215,14 +215,16 @@ func (s *BifrostHTTPServer) loadBuiltinPlugins(ctx context.Context) error { } s.Config.SetPluginOrderInfo(semanticcache.PluginName, builtinPlacement, schemas.Ptr(6)) - // 7. 
Litellmcompat (if configured in PluginConfigs) - litellmcompatConfig := s.getPluginConfig(litellmcompat.PluginName) - if litellmcompatConfig != nil && litellmcompatConfig.Enabled { - s.registerPluginWithStatus(ctx, litellmcompat.PluginName, nil, litellmcompatConfig.Config, false) - } else { - s.markPluginDisabled(litellmcompat.PluginName) + // 7. Compat (if any compat feature is enabled in ClientConfig) + cc := s.Config.ClientConfig.Compat + compatCfg := &compat.Config{ + ConvertTextToChat: cc.ConvertTextToChat, + ConvertChatToResponses: cc.ConvertChatToResponses, + ShouldDropParams: cc.ShouldDropParams, + ShouldConvertParams: cc.ShouldConvertParams, } - s.Config.SetPluginOrderInfo(litellmcompat.PluginName, builtinPlacement, schemas.Ptr(7)) + s.registerPluginWithStatus(ctx, compat.PluginName, nil, compatCfg, false) + s.Config.SetPluginOrderInfo(compat.PluginName, builtinPlacement, schemas.Ptr(7)) // 8. Maxim (if configured in PluginConfigs) maximConfig := s.getPluginConfig(maxim.PluginName) diff --git a/transports/bifrost-http/server/server.go b/transports/bifrost-http/server/server.go index 7f8378d417..c24ea67555 100644 --- a/transports/bifrost-http/server/server.go +++ b/transports/bifrost-http/server/server.go @@ -1554,4 +1554,4 @@ func (s *BifrostHTTPServer) Start() error { return err } return nil -} +} \ No newline at end of file diff --git a/transports/config.schema.json b/transports/config.schema.json index ef73414d62..3db8b83322 100644 --- a/transports/config.schema.json +++ b/transports/config.schema.json @@ -100,9 +100,29 @@ "minimum": 1, "description": "Maximum request body size in MB" }, - "enable_litellm_fallbacks": { - "type": "boolean", - "description": "Enable litellm-specific fallbacks for text completion for Groq" + "compat": { + "type": "object", + "description": "Compat plugin configuration for request type conversion, parameter dropping, and parameter value conversion", + "properties": { + "convert_text_to_chat": { + "type": "boolean", + 
"description": "Convert text completion requests to chat for models that only support chat" + }, + "convert_chat_to_responses": { + "type": "boolean", + "description": "Convert chat completion requests to responses for models that only support responses" + }, + "should_drop_params": { + "type": "boolean", + "description": "Drop unsupported parameters based on model catalog allowlist" + }, + "should_convert_params": { + "type": "boolean", + "description": "Converts model parameter values that are not supported by the model.", + "default": false + } + }, + "additionalProperties": false }, "header_filter_config": { "type": "object", @@ -4136,4 +4156,4 @@ "additionalProperties": false } } -} \ No newline at end of file +} diff --git a/transports/go.mod b/transports/go.mod index c182319a96..76c6a7fb0a 100644 --- a/transports/go.mod +++ b/transports/go.mod @@ -14,8 +14,8 @@ require ( github.com/mark3labs/mcp-go v0.43.2 github.com/maximhq/bifrost/core v1.5.1 github.com/maximhq/bifrost/framework v1.3.1 + github.com/maximhq/bifrost/plugins/compat v0.1.0 github.com/maximhq/bifrost/plugins/governance v1.5.1 - github.com/maximhq/bifrost/plugins/litellmcompat v0.1.1 github.com/maximhq/bifrost/plugins/logging v1.5.1 github.com/maximhq/bifrost/plugins/maxim v1.6.1 github.com/maximhq/bifrost/plugins/otel v1.2.1 diff --git a/transports/go.sum b/transports/go.sum index 80dea86d19..096621cfdb 100644 --- a/transports/go.sum +++ b/transports/go.sum @@ -26,6 +26,7 @@ github.com/apapsch/go-jsonmerge/v2 v2.0.0/go.mod h1:lvDnEdqiQrp0O42VQGgmlKpxL1AP github.com/aws/aws-sdk-go-v2 v1.41.3 h1:4kQ/fa22KjDt13QCy1+bYADvdgcxpfH18f0zP542kZA= github.com/aws/aws-sdk-go-v2 v1.41.3/go.mod h1:mwsPRE8ceUUpiTgF7QmQIJ7lgsKUPQOUl3o72QBrE1o= github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.8 h1:eBMB84YGghSocM7PsjmmPffTa+1FBUeNvGvFou6V/4o= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.8/go.mod h1:lyw7GFp3qENLh7kwzf7iMzAxDn+NzjXEAGjKS2UOKqI= github.com/aws/aws-sdk-go-v2/config v1.32.11 
h1:ftxI5sgz8jZkckuUHXfC/wMUc8u3fG1vQS0plr2F2Zs= github.com/aws/aws-sdk-go-v2/config v1.32.11/go.mod h1:twF11+6ps9aNRKEDimksp923o44w/Thk9+8YIlzWMmo= github.com/aws/aws-sdk-go-v2/credentials v1.19.11 h1:NdV8cwCcAXrCWyxArt58BrvZJ9pZ9Fhf9w6Uh5W3Uyc= @@ -167,6 +168,7 @@ github.com/google/pprof v0.0.0-20251213031049-b05bdaca462f/go.mod h1:67FPmZWbr+K github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/grpc-ecosystem/grpc-gateway/v2 v2.28.0 h1:HWRh5R2+9EifMyIHV7ZV+MIZqgz+PMpZ14Jynv3O2Zs= +github.com/grpc-ecosystem/grpc-gateway/v2 v2.28.0/go.mod h1:JfhWUomR1baixubs02l85lZYYOm7LV6om4ceouMv45c= github.com/hajimehoshi/go-mp3 v0.3.4 h1:NUP7pBYH8OguP4diaTZ9wJbUbk3tC0KlfzsEpWmYj68= github.com/hajimehoshi/go-mp3 v0.3.4/go.mod h1:fRtZraRFcWb0pu7ok0LqyFhCUrPeMsGRSVop0eemFmo= github.com/invopop/jsonschema v0.13.0 h1:KvpoAJWEjR3uD9Kbm2HWJmqsEaHt8lBUpd0qHcIi21E= @@ -215,10 +217,10 @@ github.com/maximhq/bifrost/core v1.5.1 h1:iJoVnI4q0CpNylBqXLVaZUc0qgJhd8j8Xa2vtN github.com/maximhq/bifrost/core v1.5.1/go.mod h1:O6VEP2MHkQgo1iLYoxGQ7a+3VBBlHoETCH+pOR6Q5X8= github.com/maximhq/bifrost/framework v1.3.1 h1:HpKD0JigkxsR6+jI3DDxAm9AKsO241E3sj2BpxG82Xs= github.com/maximhq/bifrost/framework v1.3.1/go.mod h1:M+MDjP4cRZMinI2qk0DHtTp9ayFWaoQ2Ye+ikmyhGYQ= +github.com/maximhq/bifrost/plugins/compat v0.1.0 h1:N9IVY4hmvQj/tCyppWu7zy41N8pyo0dZ+1W6Z+pQCKE= +github.com/maximhq/bifrost/plugins/compat v0.1.0/go.mod h1:PpVbCGimxQUiCHLzpHZRSjyNlSo+LgIbGzZFhtHcytI= github.com/maximhq/bifrost/plugins/governance v1.5.1 h1:zc7TY5Xb4HsEqKfL7mdkIushgAbD1a0MSoQpjYFEhtY= github.com/maximhq/bifrost/plugins/governance v1.5.1/go.mod h1:WosnY6eDKAufCZKJpNsqWiHt/fyZOx2THoDLzkqRTnM= -github.com/maximhq/bifrost/plugins/litellmcompat v0.1.1 h1:90SzGOuPZjau6wQ1CJwB7f//XETKyf6yFZ/2jC/DMCU= -github.com/maximhq/bifrost/plugins/litellmcompat v0.1.1/go.mod h1:BC1dOa23dED8rSYi7ntrIwqZGHkm3nktuPtEFSMx2tE= 
github.com/maximhq/bifrost/plugins/logging v1.5.1 h1:kNjmevWpt7nmsRyDmVTz8GPhnljtgCOtO52vjfTMvG8= github.com/maximhq/bifrost/plugins/logging v1.5.1/go.mod h1:qcutU7X+Qt7zuNgT7m/zblLvMsI4/SAaoMwlDDBopvY= github.com/maximhq/bifrost/plugins/maxim v1.6.1 h1:pwWflCaINS+6nPihSjezUpbCHdENqRFVSNiwiGzPyoI= @@ -359,13 +361,21 @@ go.mongodb.org/mongo-driver v1.17.6/go.mod h1:Hy04i7O2kC4RS06ZrhPRqj/u4DTYkFDAAc go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64= go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y= go.opentelemetry.io/otel v1.43.0 h1:mYIM03dnh5zfN7HautFE4ieIig9amkNANT+xcVxAj9I= +go.opentelemetry.io/otel v1.43.0/go.mod h1:JuG+u74mvjvcm8vj8pI5XiHy1zDeoCS2LB1spIq7Ay0= go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.43.0 h1:8UQVDcZxOJLtX6gxtDt3vY2WTgvZqMQRzjsqiIHQdkc= +go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.43.0/go.mod h1:2lmweYCiHYpEjQ/lSJBYhj9jP1zvCvQW4BqL9dnT7FQ= go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetrichttp v1.43.0 h1:w1K+pCJoPpQifuVpsKamUdn9U0zM3xUziVOqsGksUrY= +go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetrichttp v1.43.0/go.mod h1:HBy4BjzgVE8139ieRI75oXm3EcDN+6GhD88JT1Kjvxg= go.opentelemetry.io/otel/metric v1.43.0 h1:d7638QeInOnuwOONPp4JAOGfbCEpYb+K6DVWvdxGzgM= +go.opentelemetry.io/otel/metric v1.43.0/go.mod h1:RDnPtIxvqlgO8GRW18W6Z/4P462ldprJtfxHxyKd2PY= go.opentelemetry.io/otel/sdk v1.43.0 h1:pi5mE86i5rTeLXqoF/hhiBtUNcrAGHLKQdhg4h4V9Dg= +go.opentelemetry.io/otel/sdk v1.43.0/go.mod h1:P+IkVU3iWukmiit/Yf9AWvpyRDlUeBaRg6Y+C58QHzg= go.opentelemetry.io/otel/sdk/metric v1.43.0 h1:S88dyqXjJkuBNLeMcVPRFXpRw2fuwdvfCGLEo89fDkw= +go.opentelemetry.io/otel/sdk/metric v1.43.0/go.mod h1:C/RJtwSEJ5hzTiUz5pXF1kILHStzb9zFlIEe85bhj6A= go.opentelemetry.io/otel/trace v1.43.0 h1:BkNrHpup+4k4w+ZZ86CZoHHEkohws8AY+WTX09nk+3A= +go.opentelemetry.io/otel/trace v1.43.0/go.mod h1:/QJhyVBUUswCphDVxq+8mld+AvhXZLhe+8WVFxiFff0= 
go.opentelemetry.io/proto/otlp v1.10.0 h1:IQRWgT5srOCYfiWnpqUYz9CVmbO8bFmKcwYxpuCSL2g= +go.opentelemetry.io/proto/otlp v1.10.0/go.mod h1:/CV4QoCR/S9yaPj8utp3lvQPoqMtxXdzn7ozvvozVqk= go.starlark.net v0.0.0-20260102030733-3fee463870c9 h1:nV1OyvU+0CYrp5eKfQ3rD03TpFYYhH08z31NK1HmtTk= go.starlark.net v0.0.0-20260102030733-3fee463870c9/go.mod h1:YKMCv9b1WrfWmeqdV5MAuEHWsu5iC+fe6kYl2sQjdI8= go.uber.org/automaxprocs v1.6.0 h1:O3y2/QNTOdbF+e/dpXNNW7Rx2hZ4sTIPyybbxyNqTUs= @@ -399,9 +409,13 @@ golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA= golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI= golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4= gonum.org/v1/gonum v0.17.0 h1:VbpOemQlsSMrYmn7T2OUvQ4dqxQXU+ouZFQsZOx50z4= +gonum.org/v1/gonum v0.17.0/go.mod h1:El3tOrEuMpv2UdMrbNlKEh9vd86bmQ6vqIcDwxEOc1E= google.golang.org/genproto/googleapis/api v0.0.0-20260401024825-9d38bb4040a9 h1:VPWxll4HlMw1Vs/qXtN7BvhZqsS9cdAittCNvVENElA= +google.golang.org/genproto/googleapis/api v0.0.0-20260401024825-9d38bb4040a9/go.mod h1:7QBABkRtR8z+TEnmXTqIqwJLlzrZKVfAUm7tY3yGv0M= google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9 h1:m8qni9SQFH0tJc1X0vmnpw/0t+AImlSvp30sEupozUg= +google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9/go.mod h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8= google.golang.org/grpc v1.80.0 h1:Xr6m2WmWZLETvUNvIUmeD5OAagMw3FiKmMlTdViWsHM= +google.golang.org/grpc v1.80.0/go.mod h1:ho/dLnxwi3EDJA4Zghp7k2Ec1+c2jqup0bFkw07bwF4= google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE= google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/ui/app/workspace/config/compatibility/page.tsx b/ui/app/workspace/config/compatibility/page.tsx new file mode 100644 index 
0000000000..d8193f35a7 --- /dev/null +++ b/ui/app/workspace/config/compatibility/page.tsx @@ -0,0 +1,11 @@ +"use client"; + +import CompatibilityView from "../views/compatibilityView"; + +export default function CompatibilityPage() { + return ( +
+ +
+ ); +} diff --git a/ui/app/workspace/config/views/clientSettingsView.tsx b/ui/app/workspace/config/views/clientSettingsView.tsx index 1550b6566c..0ae5b736ca 100644 --- a/ui/app/workspace/config/views/clientSettingsView.tsx +++ b/ui/app/workspace/config/views/clientSettingsView.tsx @@ -107,7 +107,6 @@ export default function ClientSettingsView() { if (!config) return false; return ( localConfig.drop_excess_requests !== config.drop_excess_requests || - localConfig.enable_litellm_fallbacks !== config.enable_litellm_fallbacks || localConfig.disable_db_pings_in_health !== config.disable_db_pings_in_health || localConfig.async_job_result_ttl !== config.async_job_result_ttl || !headerFilterConfigEqual(localConfig.header_filter_config, config.header_filter_config) @@ -320,34 +319,6 @@ export default function ClientSettingsView() { /> - {/* Enable LiteLLM Fallbacks */} -
-
- -

- Enable litellm-specific fallbacks.{" "} - - Learn more - -

-
- handleConfigChange("enable_litellm_fallbacks", checked)} - disabled={!hasSettingsUpdateAccess} - /> -
- {/* Disable DB Pings in Health */}
@@ -438,9 +409,8 @@ export default function ClientSettingsView() {
  • Wildcards: Use{" "} - * at the end of a pattern to match - prefixes (e.g.,{" "} - anthropic-* matches all headers starting + * at the end of a pattern to match prefixes + (e.g., anthropic-* matches all headers starting with anthropic-). Use{" "} * alone to match all headers.
  • diff --git a/ui/app/workspace/config/views/compatibilityView.tsx b/ui/app/workspace/config/views/compatibilityView.tsx new file mode 100644 index 0000000000..50d4273c97 --- /dev/null +++ b/ui/app/workspace/config/views/compatibilityView.tsx @@ -0,0 +1,158 @@ +"use client"; + +import { Button } from "@/components/ui/button"; +import { Switch } from "@/components/ui/switch"; +import { getErrorMessage, useGetCoreConfigQuery, useUpdateCoreConfigMutation } from "@/lib/store"; +import { CompatConfig, DefaultCoreConfig } from "@/lib/types/config"; +import { RbacOperation, RbacResource, useRbac } from "@enterprise/lib"; +import Link from "next/link"; +import { useCallback, useEffect, useMemo, useState } from "react"; +import { toast } from "sonner"; + +export default function CompatibilityView() { + const hasSettingsUpdateAccess = useRbac(RbacResource.Settings, RbacOperation.Update); + const { data: bifrostConfig } = useGetCoreConfigQuery({ fromDB: true }); + const config = bifrostConfig?.client_config?.compat; + const [updateCoreConfig, { isLoading }] = useUpdateCoreConfigMutation(); + const [localCompatConfig, setLocalCompatConfig] = useState(DefaultCoreConfig.compat); + + useEffect(() => { + if (config) { + setLocalCompatConfig(config); + return; + } + setLocalCompatConfig(DefaultCoreConfig.compat); + }, [config]); + + const hasChanges = useMemo(() => { + const baseline = config ?? 
DefaultCoreConfig.compat; + return ( + localCompatConfig.convert_text_to_chat !== baseline.convert_text_to_chat || + localCompatConfig.convert_chat_to_responses !== baseline.convert_chat_to_responses || + localCompatConfig.should_drop_params !== baseline.should_drop_params || + localCompatConfig.should_convert_params !== baseline.should_convert_params + ); + }, [config, localCompatConfig]); + + const handleCompatChange = useCallback((field: keyof CompatConfig, value: boolean) => { + setLocalCompatConfig((prev) => ({ ...prev, [field]: value })); + }, []); + + const handleSave = useCallback(async () => { + if (!bifrostConfig) { + toast.error("Configuration not loaded"); + return; + } + + try { + await updateCoreConfig({ + ...bifrostConfig, + client_config: { + ...(bifrostConfig.client_config ?? DefaultCoreConfig), + compat: localCompatConfig, + }, + }).unwrap(); + toast.success("Compatibility settings updated successfully."); + } catch (error) { + toast.error(getErrorMessage(error)); + } + }, [bifrostConfig, localCompatConfig, updateCoreConfig]); + + return ( +
    +
    +

    Compatibility

    +

    + Configure request conversions and compatibility fallbacks.{" "} + + Learn more + +

    +
    + +
    +
    +
    + +

    Convert text completion requests to chat for models that only support chat.

    +
    + handleCompatChange("convert_text_to_chat", checked)} + disabled={!hasSettingsUpdateAccess} + /> +
    + +
    +
    + +

    + Convert chat completion requests to responses for models that only support responses. +

    +
    + handleCompatChange("convert_chat_to_responses", checked)} + disabled={!hasSettingsUpdateAccess} + /> +
    + +
    +
    + +

    Drop unsupported parameters based on model catalog allowlist.

    +
    + handleCompatChange("should_drop_params", checked)} + disabled={!hasSettingsUpdateAccess} + /> +
    + +
    +
    + +

    Converts model parameter values that are not supported by the model.

    +
    + handleCompatChange("should_convert_params", checked)} + disabled={!hasSettingsUpdateAccess} + /> +
    +
    + +
    + +
    +
    + ); +} \ No newline at end of file diff --git a/ui/components/sidebar.tsx b/ui/components/sidebar.tsx index f331a0d740..a95264ffe9 100644 --- a/ui/components/sidebar.tsx +++ b/ui/components/sidebar.tsx @@ -23,6 +23,7 @@ import { Logs, Network, PanelLeftClose, + Plug, Puzzle, Router, ScrollText, @@ -720,6 +721,13 @@ export default function AppSidebar() { description: "Client configuration settings", hasAccess: hasSettingsAccess, }, + { + title: "Compatibility", + url: "/workspace/config/compatibility", + icon: Plug, + description: "Compatibility conversion settings", + hasAccess: hasSettingsAccess, + }, { title: "Caching", url: "/workspace/config/caching", @@ -1251,4 +1259,4 @@ export default function AppSidebar() { ); -} +} \ No newline at end of file diff --git a/ui/components/ui/accordion.tsx b/ui/components/ui/accordion.tsx index 2978b5e54f..6ee186762e 100644 --- a/ui/components/ui/accordion.tsx +++ b/ui/components/ui/accordion.tsx @@ -26,7 +26,7 @@ function AccordionTrigger({ className, children, ...props }: React.ComponentProp {...props} > {children} - + ); @@ -44,4 +44,4 @@ function AccordionContent({ className, children, ...props }: React.ComponentProp ); } -export { Accordion, AccordionContent, AccordionItem, AccordionTrigger }; +export { Accordion, AccordionContent, AccordionItem, AccordionTrigger }; \ No newline at end of file diff --git a/ui/lib/types/config.ts b/ui/lib/types/config.ts index b3929e54b9..882585ded9 100644 --- a/ui/lib/types/config.ts +++ b/ui/lib/types/config.ts @@ -454,6 +454,13 @@ export interface BifrostConfig { auth_token?: string; } +export interface CompatConfig { + convert_text_to_chat: boolean; + convert_chat_to_responses: boolean; + should_drop_params: boolean; + should_convert_params: boolean; +} + // Core Bifrost configuration types export interface CoreConfig { drop_excess_requests: boolean; @@ -468,7 +475,7 @@ export interface CoreConfig { allowed_origins: string[]; allowed_headers: string[]; max_request_body_size_mb: 
number; - enable_litellm_fallbacks: boolean; + compat: CompatConfig; mcp_agent_depth: number; mcp_tool_execution_timeout: number; mcp_code_mode_binding_level?: string; @@ -495,7 +502,7 @@ export const DefaultCoreConfig: CoreConfig = { allow_direct_keys: false, allowed_origins: [], max_request_body_size_mb: 100, - enable_litellm_fallbacks: false, + compat: { convert_text_to_chat: false, convert_chat_to_responses: false, should_drop_params: false, should_convert_params: false }, mcp_agent_depth: 10, mcp_tool_execution_timeout: 30, mcp_code_mode_binding_level: "server",