From 5bf47420270a9882a3e7d3130835ade67ba15056 Mon Sep 17 00:00:00 2001 From: Samyabrata Maji <116789799+sammaji@users.noreply.github.com> Date: Thu, 12 Mar 2026 16:52:54 +0530 Subject: [PATCH 1/4] fix: conversions for litellm compat to happen at the provider level --- core/bifrost.go | 60 ++++++++++-- core/schemas/bifrost.go | 1 + core/schemas/chatcompletions.go | 142 +++------------------------ core/schemas/mux.go | 118 +++++++++++++++++++++++ plugins/litellmcompat/context.go | 20 ---- plugins/litellmcompat/main.go | 50 +++------- plugins/litellmcompat/texttochat.go | 144 ++++++++++------------------ 7 files changed, 244 insertions(+), 291 deletions(-) delete mode 100644 plugins/litellmcompat/context.go diff --git a/core/bifrost.go b/core/bifrost.go index 96c206eb8b..b8b3d2d673 100644 --- a/core/bifrost.go +++ b/core/bifrost.go @@ -612,7 +612,7 @@ func (bifrost *Bifrost) TextCompletionRequest(ctx *schemas.BifrostContext, req * if err != nil { return nil, err } - //TODO: Release the response + // TODO: Release the response return response.TextCompletionResponse, nil } @@ -934,7 +934,7 @@ func (bifrost *Bifrost) EmbeddingRequest(ctx *schemas.BifrostContext, req *schem if err != nil { return nil, err } - //TODO: Release the response + // TODO: Release the response return response.EmbeddingResponse, nil } @@ -1042,7 +1042,7 @@ func (bifrost *Bifrost) SpeechRequest(ctx *schemas.BifrostContext, req *schemas. if err != nil { return nil, err } - //TODO: Release the response + // TODO: Release the response return response.SpeechResponse, nil } @@ -1117,7 +1117,7 @@ func (bifrost *Bifrost) TranscriptionRequest(ctx *schemas.BifrostContext, req *s if err != nil { return nil, err } - //TODO: Release the response + // TODO: Release the response return response.TranscriptionResponse, nil } @@ -1158,7 +1158,8 @@ func (bifrost *Bifrost) TranscriptionStreamRequest(ctx *schemas.BifrostContext, // ImageGenerationRequest sends an image generation request to the specified provider. 
func (bifrost *Bifrost) ImageGenerationRequest(ctx *schemas.BifrostContext, - req *schemas.BifrostImageGenerationRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) { + req *schemas.BifrostImageGenerationRequest, +) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) { if req == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, @@ -1213,7 +1214,8 @@ func (bifrost *Bifrost) ImageGenerationRequest(ctx *schemas.BifrostContext, // ImageGenerationStreamRequest sends an image generation stream request to the specified provider. func (bifrost *Bifrost) ImageGenerationStreamRequest(ctx *schemas.BifrostContext, - req *schemas.BifrostImageGenerationRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { + req *schemas.BifrostImageGenerationRequest, +) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { if req == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, @@ -1434,7 +1436,8 @@ func (bifrost *Bifrost) ImageVariationRequest(ctx *schemas.BifrostContext, req * // VideoGenerationRequest sends a video generation request to the specified provider. 
func (bifrost *Bifrost) VideoGenerationRequest(ctx *schemas.BifrostContext, - req *schemas.BifrostVideoGenerationRequest) (*schemas.BifrostVideoGenerationResponse, *schemas.BifrostError) { + req *schemas.BifrostVideoGenerationRequest, +) (*schemas.BifrostVideoGenerationResponse, *schemas.BifrostError) { if req == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, @@ -4692,7 +4695,7 @@ func (bifrost *Bifrost) tryStreamRequest(ctx *schemas.BifrostContext, req *schem // Send the processed message to the output stream outputStream <- streamResponse - //TODO: Release the processed response immediately after use + // TODO: Release the processed response immediately after use } }() @@ -5252,6 +5255,30 @@ func (bifrost *Bifrost) requestWorker(provider schemas.Provider, config *schemas // bifrost.logger.Debug("worker for provider %s exiting...", provider.GetProviderKey()) } +func shouldConvertTextToChat(ctx *schemas.BifrostContext, requestType schemas.RequestType, request *schemas.BifrostTextCompletionRequest) bool { + if ctx == nil || request == nil { + return false + } + if requestType != schemas.TextCompletionRequest && requestType != schemas.TextCompletionStreamRequest { + return false + } + shouldConvert, ok := ctx.Value(schemas.BifrostContextKeyShouldConvertTextToChat).(bool) + return ok && shouldConvert +} + +func wrapTextToChatStreamPostHookRunner(postHookRunner schemas.PostHookRunner) schemas.PostHookRunner { + return func(ctx *schemas.BifrostContext, result *schemas.BifrostResponse, bifrostErr *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError) { + if result != nil && result.ChatResponse != nil { + if convertedResponse := result.ChatResponse.ToBifrostTextCompletionResponse(); convertedResponse != nil { + result = &schemas.BifrostResponse{ + TextCompletionResponse: convertedResponse, + } + } + } + return postHookRunner(ctx, result, bifrostErr) + } +} + // handleProviderRequest handles the request to the provider based on the 
request type // key is used for single-key operations, keys is used for batch/file operations that need multiple keys func (bifrost *Bifrost) handleProviderRequest(provider schemas.Provider, req *ChannelMessage, key schemas.Key, keys []schemas.Key) (*schemas.BifrostResponse, *schemas.BifrostError) { @@ -5264,6 +5291,17 @@ func (bifrost *Bifrost) handleProviderRequest(provider schemas.Provider, req *Ch } response.ListModelsResponse = listModelsResponse case schemas.TextCompletionRequest: + if shouldConvertTextToChat(req.Context, req.RequestType, req.BifrostRequest.TextCompletionRequest) { + chatRequest := req.BifrostRequest.TextCompletionRequest.ToBifrostChatRequest() + if chatRequest != nil { + chatCompletionResponse, bifrostError := provider.ChatCompletion(req.Context, key, chatRequest) + if bifrostError != nil { + return nil, bifrostError + } + response.TextCompletionResponse = chatCompletionResponse.ToBifrostTextCompletionResponse() + break + } + } textCompletionResponse, bifrostError := provider.TextCompletion(req.Context, key, req.BifrostRequest.TextCompletionRequest) if bifrostError != nil { return nil, bifrostError @@ -5519,6 +5557,12 @@ func (bifrost *Bifrost) handleProviderRequest(provider schemas.Provider, req *Ch func (bifrost *Bifrost) handleProviderStreamRequest(provider schemas.Provider, req *ChannelMessage, key schemas.Key, postHookRunner schemas.PostHookRunner) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { switch req.RequestType { case schemas.TextCompletionStreamRequest: + if shouldConvertTextToChat(req.Context, req.RequestType, req.BifrostRequest.TextCompletionRequest) { + chatRequest := req.BifrostRequest.TextCompletionRequest.ToBifrostChatRequest() + if chatRequest != nil { + return provider.ChatCompletionStream(req.Context, wrapTextToChatStreamPostHookRunner(postHookRunner), key, chatRequest) + } + } return provider.TextCompletionStream(req.Context, postHookRunner, key, req.BifrostRequest.TextCompletionRequest) case 
schemas.ChatCompletionStreamRequest: return provider.ChatCompletionStream(req.Context, postHookRunner, key, req.BifrostRequest.ChatRequest) diff --git a/core/schemas/bifrost.go b/core/schemas/bifrost.go index 8f68d2fbad..cd8640880e 100644 --- a/core/schemas/bifrost.go +++ b/core/schemas/bifrost.go @@ -196,6 +196,7 @@ const ( BifrostContextKeyExtraHeaders BifrostContextKey = "bifrost-extra-headers" // map[string][]string BifrostContextKeyURLPath BifrostContextKey = "bifrost-extra-url-path" // string BifrostContextKeyUseRawRequestBody BifrostContextKey = "bifrost-use-raw-request-body" + BifrostContextKeyShouldConvertTextToChat BifrostContextKey = "bifrost-should-convert-text-to-chat" // bool (set by plugins to trigger text->chat provider conversion in core) BifrostContextKeySendBackRawRequest BifrostContextKey = "bifrost-send-back-raw-request" // bool BifrostContextKeySendBackRawResponse BifrostContextKey = "bifrost-send-back-raw-response" // bool BifrostContextKeyIntegrationType BifrostContextKey = "bifrost-integration-type" // integration used in gateway (e.g. openai, anthropic, bedrock, etc.) diff --git a/core/schemas/chatcompletions.go b/core/schemas/chatcompletions.go index b864349213..f54002c9ba 100644 --- a/core/schemas/chatcompletions.go +++ b/core/schemas/chatcompletions.go @@ -29,16 +29,16 @@ func (cr *BifrostChatRequest) GetExtraParams() map[string]interface{} { // BifrostChatResponse represents the complete result from a chat completion request. type BifrostChatResponse struct { - ID string `json:"id"` - Choices []BifrostResponseChoice `json:"choices"` - Created int `json:"created"` // The Unix timestamp (in seconds). 
- Model string `json:"model"` - Object string `json:"object"` // "chat.completion" or "chat.completion.chunk" - ServiceTier *string `json:"service_tier,omitempty"` - SystemFingerprint string `json:"system_fingerprint"` - Usage *BifrostLLMUsage `json:"usage"` - ExtraFields BifrostResponseExtraFields `json:"extra_fields"` - ExtraParams map[string]interface{} `json:"-"` + ID string `json:"id"` + Choices []BifrostResponseChoice `json:"choices"` + Created int `json:"created"` // The Unix timestamp (in seconds). + Model string `json:"model"` + Object string `json:"object"` // "chat.completion" or "chat.completion.chunk" + ServiceTier *string `json:"service_tier,omitempty"` + SystemFingerprint string `json:"system_fingerprint"` + Usage *BifrostLLMUsage `json:"usage"` + ExtraFields BifrostResponseExtraFields `json:"extra_fields"` + ExtraParams map[string]interface{} `json:"-"` // Perplexity-specific fields SearchResults []SearchResult `json:"search_results,omitempty"` @@ -46,125 +46,6 @@ type BifrostChatResponse struct { Citations []string `json:"citations,omitempty"` } -// ToTextCompletionResponse converts a BifrostChatResponse to a BifrostTextCompletionResponse -func (cr *BifrostChatResponse) ToTextCompletionResponse() *BifrostTextCompletionResponse { - if cr == nil { - return nil - } - - if len(cr.Choices) == 0 { - return &BifrostTextCompletionResponse{ - ID: cr.ID, - Model: cr.Model, - Object: "text_completion", - SystemFingerprint: cr.SystemFingerprint, - Usage: cr.Usage, - ExtraFields: BifrostResponseExtraFields{ - RequestType: TextCompletionRequest, - ChunkIndex: cr.ExtraFields.ChunkIndex, - Provider: cr.ExtraFields.Provider, - OriginalModelRequested: cr.ExtraFields.OriginalModelRequested, - ResolvedModelUsed: cr.ExtraFields.ResolvedModelUsed, - Latency: cr.ExtraFields.Latency, - RawResponse: cr.ExtraFields.RawResponse, - CacheDebug: cr.ExtraFields.CacheDebug, - ProviderResponseHeaders: cr.ExtraFields.ProviderResponseHeaders, - }, - } - } - - choice := cr.Choices[0] 
- - // Handle streaming response choice - if choice.ChatStreamResponseChoice != nil && choice.ChatStreamResponseChoice.Delta != nil { - return &BifrostTextCompletionResponse{ - ID: cr.ID, - Model: cr.Model, - Object: "text_completion", - SystemFingerprint: cr.SystemFingerprint, - Choices: []BifrostResponseChoice{ - { - Index: 0, - TextCompletionResponseChoice: &TextCompletionResponseChoice{ - Text: choice.ChatStreamResponseChoice.Delta.Content, - }, - FinishReason: choice.FinishReason, - LogProbs: choice.LogProbs, - }, - }, - Usage: cr.Usage, - ExtraFields: BifrostResponseExtraFields{ - RequestType: TextCompletionRequest, - ChunkIndex: cr.ExtraFields.ChunkIndex, - Provider: cr.ExtraFields.Provider, - OriginalModelRequested: cr.ExtraFields.OriginalModelRequested, - ResolvedModelUsed: cr.ExtraFields.ResolvedModelUsed, - Latency: cr.ExtraFields.Latency, - RawResponse: cr.ExtraFields.RawResponse, - CacheDebug: cr.ExtraFields.CacheDebug, - ProviderResponseHeaders: cr.ExtraFields.ProviderResponseHeaders, - }, - } - } - - // Handle non-streaming response choice - if choice.ChatNonStreamResponseChoice != nil { - msg := choice.ChatNonStreamResponseChoice.Message - var textContent *string - if msg != nil && msg.Content != nil && msg.Content.ContentStr != nil { - textContent = msg.Content.ContentStr - } - return &BifrostTextCompletionResponse{ - ID: cr.ID, - Model: cr.Model, - Object: "text_completion", - SystemFingerprint: cr.SystemFingerprint, - Choices: []BifrostResponseChoice{ - { - Index: 0, - TextCompletionResponseChoice: &TextCompletionResponseChoice{ - Text: textContent, - }, - FinishReason: choice.FinishReason, - LogProbs: choice.LogProbs, - }, - }, - Usage: cr.Usage, - ExtraFields: BifrostResponseExtraFields{ - RequestType: TextCompletionRequest, - ChunkIndex: cr.ExtraFields.ChunkIndex, - Provider: cr.ExtraFields.Provider, - OriginalModelRequested: cr.ExtraFields.OriginalModelRequested, - ResolvedModelUsed: cr.ExtraFields.ResolvedModelUsed, - Latency: 
cr.ExtraFields.Latency, - RawResponse: cr.ExtraFields.RawResponse, - CacheDebug: cr.ExtraFields.CacheDebug, - ProviderResponseHeaders: cr.ExtraFields.ProviderResponseHeaders, - }, - } - } - - // Fallback case - return basic response structure - return &BifrostTextCompletionResponse{ - ID: cr.ID, - Model: cr.Model, - Object: "text_completion", - SystemFingerprint: cr.SystemFingerprint, - Usage: cr.Usage, - ExtraFields: BifrostResponseExtraFields{ - RequestType: TextCompletionRequest, - ChunkIndex: cr.ExtraFields.ChunkIndex, - Provider: cr.ExtraFields.Provider, - OriginalModelRequested: cr.ExtraFields.OriginalModelRequested, - ResolvedModelUsed: cr.ExtraFields.ResolvedModelUsed, - Latency: cr.ExtraFields.Latency, - RawResponse: cr.ExtraFields.RawResponse, - CacheDebug: cr.ExtraFields.CacheDebug, - ProviderResponseHeaders: cr.ExtraFields.ProviderResponseHeaders, - }, - } -} - // ChatParameters represents the parameters for a chat completion. type ChatParameters struct { Audio *ChatAudioParameters `json:"audio,omitempty"` // Audio parameters @@ -531,7 +412,6 @@ type AdditionalPropertiesStruct struct { // MarshalJSON implements custom JSON marshalling for AdditionalPropertiesStruct. // It marshals either AdditionalPropertiesBool or AdditionalPropertiesMap based on which is set. 
func (a AdditionalPropertiesStruct) MarshalJSON() ([]byte, error) { - // if both are set, return an error if a.AdditionalPropertiesBool != nil && a.AdditionalPropertiesMap != nil { return nil, fmt.Errorf("both AdditionalPropertiesBool and AdditionalPropertiesMap are set; only one should be non-nil") @@ -1198,7 +1078,7 @@ type BifrostLLMUsage struct { CompletionTokens int `json:"completion_tokens,omitempty"` CompletionTokensDetails *ChatCompletionTokensDetails `json:"completion_tokens_details,omitempty"` TotalTokens int `json:"total_tokens"` - Cost *BifrostCost `json:"cost,omitempty"` //Only for the providers which support cost calculation + Cost *BifrostCost `json:"cost,omitempty"` // Only for the providers which support cost calculation } type ChatPromptTokensDetails struct { diff --git a/core/schemas/mux.go b/core/schemas/mux.go index f899f41739..7b0ffac048 100644 --- a/core/schemas/mux.go +++ b/core/schemas/mux.go @@ -2013,3 +2013,121 @@ func (cr *BifrostChatResponse) ToBifrostResponsesStreamResponse(state *ChatToRes return responses } + +// ============================================================================= +// RESPONSE CONVERSION METHODS +// ============================================================================= + +// ToBifrostTextCompletionResponse converts a BifrostChatResponse to a BifrostTextCompletionResponse +func (cr *BifrostChatResponse) ToBifrostTextCompletionResponse() *BifrostTextCompletionResponse { + if cr == nil { + return nil + } + + if len(cr.Choices) == 0 { + return &BifrostTextCompletionResponse{ + ID: cr.ID, + Model: cr.Model, + Object: "text_completion", + SystemFingerprint: cr.SystemFingerprint, + Usage: cr.Usage, + ExtraFields: BifrostResponseExtraFields{ + RequestType: TextCompletionRequest, + ChunkIndex: cr.ExtraFields.ChunkIndex, + Provider: cr.ExtraFields.Provider, + ModelRequested: cr.ExtraFields.ModelRequested, + Latency: cr.ExtraFields.Latency, + RawResponse: cr.ExtraFields.RawResponse, + CacheDebug: 
cr.ExtraFields.CacheDebug, + ProviderResponseHeaders: cr.ExtraFields.ProviderResponseHeaders, + }, + } + } + + choice := cr.Choices[0] + + // Handle streaming response choice + if choice.ChatStreamResponseChoice != nil && choice.ChatStreamResponseChoice.Delta != nil { + return &BifrostTextCompletionResponse{ + ID: cr.ID, + Model: cr.Model, + Object: "text_completion", + SystemFingerprint: cr.SystemFingerprint, + Choices: []BifrostResponseChoice{ + { + Index: 0, + TextCompletionResponseChoice: &TextCompletionResponseChoice{ + Text: choice.ChatStreamResponseChoice.Delta.Content, + }, + FinishReason: choice.FinishReason, + LogProbs: choice.LogProbs, + }, + }, + Usage: cr.Usage, + ExtraFields: BifrostResponseExtraFields{ + RequestType: TextCompletionRequest, + ChunkIndex: cr.ExtraFields.ChunkIndex, + Provider: cr.ExtraFields.Provider, + ModelRequested: cr.ExtraFields.ModelRequested, + Latency: cr.ExtraFields.Latency, + RawResponse: cr.ExtraFields.RawResponse, + CacheDebug: cr.ExtraFields.CacheDebug, + ProviderResponseHeaders: cr.ExtraFields.ProviderResponseHeaders, + }, + } + } + + // Handle non-streaming response choice + if choice.ChatNonStreamResponseChoice != nil { + msg := choice.ChatNonStreamResponseChoice.Message + var textContent *string + if msg != nil && msg.Content != nil && msg.Content.ContentStr != nil { + textContent = msg.Content.ContentStr + } + return &BifrostTextCompletionResponse{ + ID: cr.ID, + Model: cr.Model, + Object: "text_completion", + SystemFingerprint: cr.SystemFingerprint, + Choices: []BifrostResponseChoice{ + { + Index: 0, + TextCompletionResponseChoice: &TextCompletionResponseChoice{ + Text: textContent, + }, + FinishReason: choice.FinishReason, + LogProbs: choice.LogProbs, + }, + }, + Usage: cr.Usage, + ExtraFields: BifrostResponseExtraFields{ + RequestType: TextCompletionRequest, + ChunkIndex: cr.ExtraFields.ChunkIndex, + Provider: cr.ExtraFields.Provider, + ModelRequested: cr.ExtraFields.ModelRequested, + Latency: 
cr.ExtraFields.Latency, + RawResponse: cr.ExtraFields.RawResponse, + CacheDebug: cr.ExtraFields.CacheDebug, + ProviderResponseHeaders: cr.ExtraFields.ProviderResponseHeaders, + }, + } + } + + // Fallback case - return basic response structure + return &BifrostTextCompletionResponse{ + ID: cr.ID, + Model: cr.Model, + Object: "text_completion", + SystemFingerprint: cr.SystemFingerprint, + Usage: cr.Usage, + ExtraFields: BifrostResponseExtraFields{ + RequestType: TextCompletionRequest, + ChunkIndex: cr.ExtraFields.ChunkIndex, + Provider: cr.ExtraFields.Provider, + ModelRequested: cr.ExtraFields.ModelRequested, + Latency: cr.ExtraFields.Latency, + RawResponse: cr.ExtraFields.RawResponse, + CacheDebug: cr.ExtraFields.CacheDebug, + }, + } +} \ No newline at end of file diff --git a/plugins/litellmcompat/context.go b/plugins/litellmcompat/context.go deleted file mode 100644 index 1bcc79b99c..0000000000 --- a/plugins/litellmcompat/context.go +++ /dev/null @@ -1,20 +0,0 @@ -package litellmcompat - -import "github.com/maximhq/bifrost/core/schemas" - -// TransformContextKey is the key used to store TransformContext in BifrostContext -const TransformContextKey schemas.BifrostContextKey = "litellmcompat-transform-context" - -// TransformContext tracks what transformations were applied to a request -// so they can be reversed on the response -type TransformContext struct { - // Text-to-chat transform state - // TextToChatApplied indicates that a text completion request was converted to chat - TextToChatApplied bool - // OriginalRequestType stores the original request type before transformation - OriginalRequestType schemas.RequestType - // OriginalModel preserves the original model string for response metadata - OriginalModel string - // IsStreaming indicates if the original request was a streaming request - IsStreaming bool -} diff --git a/plugins/litellmcompat/main.go b/plugins/litellmcompat/main.go index 730d56beef..8983eb5ad3 100644 --- a/plugins/litellmcompat/main.go +++ 
b/plugins/litellmcompat/main.go @@ -1,13 +1,10 @@ -// Package litellmcompat provides LiteLLM-compatible request/response transformations -// for the Bifrost gateway. It enables automatic conversion of text completion requests -// to chat completion requests for models that only support chat completions, matching -// LiteLLM's behavior. +// Package litellmcompat provides LiteLLM-compatible text-to-chat conversion decisions +// for the Bifrost gateway. It marks text completion requests that should be converted +// by core provider dispatch for models that only support chat completions. // // When enabled, this plugin: -// - Silently converts text_completion() calls to chat completion format -// - Routes to the chat completion endpoint -// - Transforms the response back to text completion format -// - Places content in choices[0].text instead of choices[0].message.content +// - Decides whether text_completion() should be converted to chat +// - Stores the decision in context for core request dispatch package litellmcompat import ( @@ -78,47 +75,24 @@ func (p *LiteLLMCompatPlugin) HTTPTransportStreamChunkHook(ctx *schemas.BifrostC return chunk, nil } -// PreLLMHook intercepts requests and applies LiteLLM-compatible transformations. +// PreLLMHook intercepts requests and applies LiteLLM-compatible transformation intent. // For text completion requests on models that don't support text completion, -// it converts them to chat completion requests. +// it marks the request so core can convert at provider dispatch time. 
func (p *LiteLLMCompatPlugin) PreLLMHook(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.LLMPluginShortCircuit, error) { - tc := &TransformContext{} - // Apply request transforms in sequence - req = transformTextToChatRequest(ctx, req, tc, p.modelCatalog, p.logger) - - // Store the transform context for use in PostHook - ctx.SetValue(TransformContextKey, tc) - + req = transformTextToChatRequest(ctx, req, p.modelCatalog, p.logger) return req, nil, nil } -// PostLLMHook processes responses and applies LiteLLM-compatible transformations. -// If a text completion request was converted to chat, this converts the -// chat response back to text completion format. +// PostLLMHook normalizes metadata on converted responses/errors +// when this plugin requested text->chat conversion in PreLLMHook. func (p *LiteLLMCompatPlugin) PostLLMHook(ctx *schemas.BifrostContext, result *schemas.BifrostResponse, bifrostErr *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError, error) { - // Retrieve the transform context - transformCtxValue := ctx.Value(TransformContextKey) - if transformCtxValue == nil { - return result, bifrostErr, nil - } - tc, ok := transformCtxValue.(*TransformContext) - if !ok || tc == nil { - return result, bifrostErr, nil - } - - // Apply response transforms in sequence - // Note: tool-call content runs before text-to-chat because text-to-chat may convert - // the response type, and tool-call content needs to operate on chat responses if result != nil { - result = transformTextToChatResponse(ctx, result, tc, p.logger) + result = transformTextToChatResponse(ctx, result, p.logger) } - - // Transform error metadata if there's an error if bifrostErr != nil { - bifrostErr = transformTextToChatError(ctx, bifrostErr, tc) + bifrostErr = transformTextToChatError(ctx, bifrostErr) } - return result, bifrostErr, nil } diff --git a/plugins/litellmcompat/texttochat.go b/plugins/litellmcompat/texttochat.go index 
b0c1b0a309..9c78a1473f 100644 --- a/plugins/litellmcompat/texttochat.go +++ b/plugins/litellmcompat/texttochat.go @@ -5,17 +5,21 @@ import ( "github.com/maximhq/bifrost/framework/modelcatalog" ) -// transformTextToChatRequest converts a text completion request to a chat completion request -// if the model doesn't support text completion natively. -// It updates the TransformContext with the transformation state. -func transformTextToChatRequest(_ *schemas.BifrostContext, req *schemas.BifrostRequest, tc *TransformContext, mc *modelcatalog.ModelCatalog, logger schemas.Logger) *schemas.BifrostRequest { +const ( + OriginalRequestTypeContextKey schemas.BifrostContextKey = "litellmcompat-original-request-type" + OriginalModelContextKey schemas.BifrostContextKey = "litellmcompat-original-model" +) + +// transformTextToChatRequest determines whether a text request should be converted by core. +// It stores conversion intent in context; core performs the actual conversion. +func transformTextToChatRequest(ctx *schemas.BifrostContext, req *schemas.BifrostRequest, mc *modelcatalog.ModelCatalog, logger schemas.Logger) *schemas.BifrostRequest { // Only process text completion requests if req.RequestType != schemas.TextCompletionRequest && req.RequestType != schemas.TextCompletionStreamRequest { return req } // Check if text completion request is present - if req.TextCompletionRequest == nil || tc == nil { + if req.TextCompletionRequest == nil { return req } @@ -24,6 +28,9 @@ func transformTextToChatRequest(_ *schemas.BifrostContext, req *schemas.BifrostR provider := req.TextCompletionRequest.Provider model := req.TextCompletionRequest.Model if mc.IsTextCompletionSupported(model, provider) { + if ctx != nil { + ctx.SetValue(schemas.BifrostContextKeyShouldConvertTextToChat, false) + } if logger != nil { logger.Debug("litellmcompat: model %s/%s supports text completion, skipping conversion", provider, model) } @@ -31,121 +38,70 @@ func transformTextToChatRequest(_ 
*schemas.BifrostContext, req *schemas.BifrostR } } - // Convert text completion to chat completion - chatRequest := req.TextCompletionRequest.ToBifrostChatRequest() - if chatRequest == nil { - return req - } - - // Track the transformation - tc.TextToChatApplied = true - tc.OriginalRequestType = req.RequestType - tc.OriginalModel = req.TextCompletionRequest.Model - tc.IsStreaming = req.RequestType == schemas.TextCompletionStreamRequest - - // Create a new request with the chat completion - transformedReq := &schemas.BifrostRequest{ - ChatRequest: chatRequest, - } - - // Set the appropriate request type - if tc.IsStreaming { - transformedReq.RequestType = schemas.ChatCompletionStreamRequest - } else { - transformedReq.RequestType = schemas.ChatCompletionRequest + // Track conversion intent. Core will do the actual conversion during provider dispatch. + if ctx != nil { + ctx.SetValue(schemas.BifrostContextKeyShouldConvertTextToChat, true) + ctx.SetValue(OriginalRequestTypeContextKey, req.RequestType) + ctx.SetValue(OriginalModelContextKey, req.TextCompletionRequest.Model) } if logger != nil { - logger.Debug("litellmcompat: converted text completion to chat completion for model %s (text completion not supported)", tc.OriginalModel) + logger.Debug("litellmcompat: marked text completion for core text->chat conversion for model %s (text completion not supported)", req.TextCompletionRequest.Model) } - return transformedReq + return req } -// transformTextToChatResponse converts a chat response back to text completion format -// if the original request was a text completion that was converted. 
-func transformTextToChatResponse(_ *schemas.BifrostContext, resp *schemas.BifrostResponse, tc *TransformContext, logger schemas.Logger) *schemas.BifrostResponse { - // Only transform if we converted text completion to chat - if !tc.TextToChatApplied { - return resp +func getOriginalTextRequestMetadata(ctx *schemas.BifrostContext) (schemas.RequestType, string) { + requestType := schemas.TextCompletionRequest + if ctx == nil { + return requestType, "" + } + if value, ok := ctx.Value(OriginalRequestTypeContextKey).(schemas.RequestType); ok { + requestType = value } + model, _ := ctx.Value(OriginalModelContextKey).(string) + return requestType, model +} - // Check if we have a chat response to transform - if resp == nil || resp.ChatResponse == nil { +// transformTextToChatResponse normalizes metadata on converted text-completion responses. +// Core performs the actual stream/non-stream payload conversion. +func transformTextToChatResponse(ctx *schemas.BifrostContext, resp *schemas.BifrostResponse, logger schemas.Logger) *schemas.BifrostResponse { + if resp == nil || resp.TextCompletionResponse == nil || ctx == nil { return resp } - // Convert chat response to text completion response - textCompletionResponse := resp.ChatResponse.ToTextCompletionResponse() - if textCompletionResponse == nil { + shouldConvert, ok := ctx.Value(schemas.BifrostContextKeyShouldConvertTextToChat).(bool) + if !ok || !shouldConvert { return resp } - // Restore original request type metadata - textCompletionResponse.ExtraFields.RequestType = tc.OriginalRequestType - textCompletionResponse.ExtraFields.OriginalModelRequested = tc.OriginalModel - textCompletionResponse.ExtraFields.LiteLLMCompat = true + originalRequestType, originalModel := getOriginalTextRequestMetadata(ctx) + resp.TextCompletionResponse.ExtraFields.RequestType = originalRequestType + resp.TextCompletionResponse.ExtraFields.ModelRequested = originalModel + resp.TextCompletionResponse.ExtraFields.LiteLLMCompat = true if logger != 
nil { - logger.Debug("litellmcompat: converted chat response back to text completion for model %s", tc.OriginalModel) + logger.Debug("litellmcompat: normalized converted text completion metadata for model %s", originalModel) } - // Return a new response with the text completion - return &schemas.BifrostResponse{ - TextCompletionResponse: textCompletionResponse, - } + return resp } -// transformTextToChatError ensures error metadata reflects the original request type -// if a text completion request was converted to chat. -func transformTextToChatError(_ *schemas.BifrostContext, err *schemas.BifrostError, tc *TransformContext) *schemas.BifrostError { - if tc == nil || err == nil { +// transformTextToChatError restores original text-completion metadata on errors +// generated from chat fallback execution. +func transformTextToChatError(ctx *schemas.BifrostContext, err *schemas.BifrostError) *schemas.BifrostError { + if err == nil || ctx == nil { return err } - - // Only transform if we converted text completion to chat - if !tc.TextToChatApplied { + shouldConvert, ok := ctx.Value(schemas.BifrostContextKeyShouldConvertTextToChat).(bool) + if !ok || !shouldConvert { return err } - // Restore original request type in error metadata - err.ExtraFields.RequestType = tc.OriginalRequestType - err.ExtraFields.OriginalModelRequested = tc.OriginalModel + originalRequestType, originalModel := getOriginalTextRequestMetadata(ctx) + err.ExtraFields.RequestType = originalRequestType + err.ExtraFields.ModelRequested = originalModel err.ExtraFields.LiteLLMCompat = true - return err } - -// TransformTextToChatStreamResponse transforms a streaming chat response back to text completion format. -// This is exported for use by streaming handlers. 
-func TransformTextToChatStreamResponse(ctx *schemas.BifrostContext, stream *schemas.BifrostStreamChunk, tc *TransformContext) *schemas.BifrostStreamChunk { - if tc == nil { - return stream - } - - // Only transform if we converted text completion to chat - if !tc.TextToChatApplied { - return stream - } - - // Check if we have a chat response in the stream to transform - if stream == nil || stream.BifrostChatResponse == nil { - return stream - } - - // Convert chat response to text completion response - textCompletionResponse := stream.BifrostChatResponse.ToTextCompletionResponse() - if textCompletionResponse == nil { - return stream - } - - // Restore original request type metadata - textCompletionResponse.ExtraFields.RequestType = tc.OriginalRequestType - textCompletionResponse.ExtraFields.OriginalModelRequested = tc.OriginalModel - textCompletionResponse.ExtraFields.LiteLLMCompat = true - - // Return a new stream with the text completion response - return &schemas.BifrostStreamChunk{ - BifrostTextCompletionResponse: textCompletionResponse, - } -} From 99772bd821f2be8296df064abf2decaa88385175 Mon Sep 17 00:00:00 2001 From: Samyabrata Maji <116789799+sammaji@users.noreply.github.com> Date: Fri, 13 Mar 2026 10:27:02 +0530 Subject: [PATCH 2/4] feat: litellmcompat chat to responses --- core/bifrost.go | 43 +- core/schemas/bifrost.go | 1 + core/schemas/mux.go | 260 +++++++- core/utils.go | 4 +- framework/modelcatalog/main.go | 711 +++++++++++++++++++++- framework/modelcatalog/sync.go | 1 - framework/modelcatalog/utils.go | 15 + plugins/litellmcompat/chattoresponses.go | 108 ++++ plugins/litellmcompat/main.go | 39 +- transports/bifrost-http/server/plugins.go | 4 +- transports/bifrost-http/server/server.go | 2 +- 11 files changed, 1162 insertions(+), 26 deletions(-) create mode 100644 plugins/litellmcompat/chattoresponses.go diff --git a/core/bifrost.go b/core/bifrost.go index b8b3d2d673..722bd61e29 100644 --- a/core/bifrost.go +++ b/core/bifrost.go @@ -5266,6 +5266,17 
@@ func shouldConvertTextToChat(ctx *schemas.BifrostContext, requestType schemas.Re return ok && shouldConvert } +func shouldConvertChatToResponses(ctx *schemas.BifrostContext, requestType schemas.RequestType, request *schemas.BifrostChatRequest) bool { + if ctx == nil || request == nil { + return false + } + if requestType != schemas.ChatCompletionRequest && requestType != schemas.ChatCompletionStreamRequest { + return false + } + shouldConvert, ok := ctx.Value(schemas.BifrostContextKeyShouldConvertChatToResponses).(bool) + return ok && shouldConvert +} + func wrapTextToChatStreamPostHookRunner(postHookRunner schemas.PostHookRunner) schemas.PostHookRunner { return func(ctx *schemas.BifrostContext, result *schemas.BifrostResponse, bifrostErr *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError) { if result != nil && result.ChatResponse != nil { @@ -5279,6 +5290,19 @@ func wrapTextToChatStreamPostHookRunner(postHookRunner schemas.PostHookRunner) s } } +func wrapChatToResponsesStreamPostHookRunner(postHookRunner schemas.PostHookRunner) schemas.PostHookRunner { + return func(ctx *schemas.BifrostContext, result *schemas.BifrostResponse, bifrostErr *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError) { + if result != nil && result.ResponsesStreamResponse != nil { + if convertedResponse := result.ResponsesStreamResponse.ToBifrostChatResponse(); convertedResponse != nil { + result = &schemas.BifrostResponse{ + ChatResponse: convertedResponse, + } + } + } + return postHookRunner(ctx, result, bifrostErr) + } +} + // handleProviderRequest handles the request to the provider based on the request type // key is used for single-key operations, keys is used for batch/file operations that need multiple keys func (bifrost *Bifrost) handleProviderRequest(provider schemas.Provider, req *ChannelMessage, key schemas.Key, keys []schemas.Key) (*schemas.BifrostResponse, *schemas.BifrostError) { @@ -5308,6 +5332,17 @@ func (bifrost *Bifrost) 
handleProviderRequest(provider schemas.Provider, req *Ch } response.TextCompletionResponse = textCompletionResponse case schemas.ChatCompletionRequest: + if shouldConvertChatToResponses(req.Context, req.RequestType, req.BifrostRequest.ChatRequest) { + responsesRequest := req.BifrostRequest.ChatRequest.ToResponsesRequest() + if responsesRequest != nil { + responsesResponse, bifrostError := provider.Responses(req.Context, key, responsesRequest) + if bifrostError != nil { + return nil, bifrostError + } + response.ChatResponse = responsesResponse.ToBifrostChatResponse() + break + } + } chatCompletionResponse, bifrostError := provider.ChatCompletion(req.Context, key, req.BifrostRequest.ChatRequest) if bifrostError != nil { return nil, bifrostError @@ -5565,6 +5600,12 @@ func (bifrost *Bifrost) handleProviderStreamRequest(provider schemas.Provider, r } return provider.TextCompletionStream(req.Context, postHookRunner, key, req.BifrostRequest.TextCompletionRequest) case schemas.ChatCompletionStreamRequest: + if shouldConvertChatToResponses(req.Context, req.RequestType, req.BifrostRequest.ChatRequest) { + responsesRequest := req.BifrostRequest.ChatRequest.ToResponsesRequest() + if responsesRequest != nil { + return provider.ResponsesStream(req.Context, wrapChatToResponsesStreamPostHookRunner(postHookRunner), key, responsesRequest) + } + } return provider.ChatCompletionStream(req.Context, postHookRunner, key, req.BifrostRequest.ChatRequest) case schemas.ResponsesStreamRequest: return provider.ResponsesStream(req.Context, postHookRunner, key, req.BifrostRequest.ResponsesRequest) @@ -6716,4 +6757,4 @@ func (bifrost *Bifrost) Shutdown() { } } bifrost.logger.Info("all request channels closed") -} +} \ No newline at end of file diff --git a/core/schemas/bifrost.go b/core/schemas/bifrost.go index cd8640880e..310871ab96 100644 --- a/core/schemas/bifrost.go +++ b/core/schemas/bifrost.go @@ -197,6 +197,7 @@ const ( BifrostContextKeyURLPath BifrostContextKey = "bifrost-extra-url-path" 
// string BifrostContextKeyUseRawRequestBody BifrostContextKey = "bifrost-use-raw-request-body" BifrostContextKeyShouldConvertTextToChat BifrostContextKey = "bifrost-should-convert-text-to-chat" // bool (set by plugins to trigger text->chat provider conversion in core) + BifrostContextKeyShouldConvertChatToResponses BifrostContextKey = "bifrost-should-convert-chat-to-responses" // bool (set by plugins to trigger chat->responses provider conversion in core) BifrostContextKeySendBackRawRequest BifrostContextKey = "bifrost-send-back-raw-request" // bool BifrostContextKeySendBackRawResponse BifrostContextKey = "bifrost-send-back-raw-response" // bool BifrostContextKeyIntegrationType BifrostContextKey = "bifrost-integration-type" // integration used in gateway (e.g. openai, anthropic, bedrock, etc.) diff --git a/core/schemas/mux.go b/core/schemas/mux.go index 7b0ffac048..e719311539 100644 --- a/core/schemas/mux.go +++ b/core/schemas/mux.go @@ -1258,6 +1258,10 @@ func (responsesResp *BifrostResponsesResponse) ToBifrostChatResponse() *BifrostC Videos: responsesResp.Videos, } + if responsesResp.ID != nil { + chatResp.ID = *responsesResp.ID + } + // Create Choices from ResponsesResponse if len(responsesResp.Output) > 0 { // Convert ResponsesMessages back to ChatMessages @@ -1991,6 +1995,34 @@ func (cr *BifrostChatResponse) ToBifrostResponsesStreamResponse(state *ChatToRes response.Output = allOutput } + // Append finalized function call items so the terminal response carries them in Output. 
+ for toolCallID, args := range state.ToolArgumentBuffers { + if args == "" { + continue + } + statusFinal := terminalStatus + messageType := ResponsesMessageTypeFunctionCall + callName := state.ToolCallNames[toolCallID] + var callNamePtr *string + if callName != "" { + callNamePtr = &callName + } + argsValue := args + fcMsg := ResponsesMessage{ + Type: &messageType, + Status: &statusFinal, + ResponsesToolMessage: &ResponsesToolMessage{ + CallID: &toolCallID, + Name: callNamePtr, + Arguments: &argsValue, + }, + } + if itemID := state.ItemIDs[toolCallID]; itemID != "" { + fcMsg.ID = &itemID + } + response.Output = append(response.Output, fcMsg) + } + responses = append(responses, &BifrostResponsesStreamResponse{ Type: terminalEventType, SequenceNumber: state.SequenceNumber, @@ -2014,6 +2046,232 @@ func (cr *BifrostChatResponse) ToBifrostResponsesStreamResponse(state *ChatToRes return responses } +// ToBifrostChatResponse converts a BifrostResponsesStreamResponse chunk to a BifrostChatResponse (chat.completion.chunk). +// Returns nil for events that have no meaningful chat completion equivalent (lifecycle events, etc.). 
+func (rsr *BifrostResponsesStreamResponse) ToBifrostChatResponse() *BifrostChatResponse { + if rsr == nil { + return nil + } + + extraFields := rsr.ExtraFields + extraFields.RequestType = ChatCompletionStreamRequest + + resp := &BifrostChatResponse{ + Object: "chat.completion.chunk", + ExtraFields: extraFields, + SearchResults: rsr.SearchResults, + Videos: rsr.Videos, + Citations: rsr.Citations, + } + + if rsr.Response != nil { + if rsr.Response.ID != nil { + resp.ID = *rsr.Response.ID + } + resp.Created = rsr.Response.CreatedAt + resp.Model = rsr.Response.Model + } + + switch rsr.Type { + case ResponsesStreamResponseTypeOutputTextDelta: + resp.Choices = []BifrostResponseChoice{ + { + Index: 0, + ChatStreamResponseChoice: &ChatStreamResponseChoice{ + Delta: &ChatStreamResponseChoiceDelta{ + Content: rsr.Delta, + }, + }, + }, + } + return resp + + case ResponsesStreamResponseTypeReasoningSummaryTextDelta: + resp.Choices = []BifrostResponseChoice{ + { + Index: 0, + ChatStreamResponseChoice: &ChatStreamResponseChoice{ + Delta: &ChatStreamResponseChoiceDelta{ + Reasoning: rsr.Delta, + }, + }, + }, + } + return resp + + case ResponsesStreamResponseTypeRefusalDelta: + resp.Choices = []BifrostResponseChoice{ + { + Index: 0, + ChatStreamResponseChoice: &ChatStreamResponseChoice{ + Delta: &ChatStreamResponseChoiceDelta{ + Refusal: rsr.Refusal, + }, + }, + }, + } + return resp + + case ResponsesStreamResponseTypeOutputItemAdded: + if rsr.Item == nil || rsr.Item.Type == nil { + resp.Choices = []BifrostResponseChoice{ + { + Index: 0, + ChatStreamResponseChoice: &ChatStreamResponseChoice{ + Delta: &ChatStreamResponseChoiceDelta{}, + }, + }, + } + return resp + } + + switch *rsr.Item.Type { + case ResponsesMessageTypeFunctionCall: + if rsr.Item.ResponsesToolMessage == nil { + resp.Choices = []BifrostResponseChoice{ + { + Index: 0, + ChatStreamResponseChoice: &ChatStreamResponseChoice{ + Delta: &ChatStreamResponseChoiceDelta{}, + }, + }, + } + return resp + } + funcType := 
"function" + var idx uint16 + if rsr.OutputIndex != nil && *rsr.OutputIndex > 0 { + idx = uint16(*rsr.OutputIndex - 1) + } + resp.Choices = []BifrostResponseChoice{ + { + Index: 0, + ChatStreamResponseChoice: &ChatStreamResponseChoice{ + Delta: &ChatStreamResponseChoiceDelta{ + ToolCalls: []ChatAssistantMessageToolCall{ + { + Index: idx, + Type: &funcType, + ID: rsr.Item.ResponsesToolMessage.CallID, + Function: ChatAssistantMessageToolCallFunction{ + Name: rsr.Item.ResponsesToolMessage.Name, + }, + }, + }, + }, + }, + }, + } + return resp + + case ResponsesMessageTypeMessage: + role := "assistant" + resp.Choices = []BifrostResponseChoice{ + { + Index: 0, + ChatStreamResponseChoice: &ChatStreamResponseChoice{ + Delta: &ChatStreamResponseChoiceDelta{ + Role: &role, + }, + }, + }, + } + return resp + + default: + // reasoning, file_search_call, web_search_call, etc. — no chat equivalent, + // actual content arrives via separate delta events. + resp.Choices = []BifrostResponseChoice{ + { + Index: 0, + ChatStreamResponseChoice: &ChatStreamResponseChoice{ + Delta: &ChatStreamResponseChoiceDelta{}, + }, + }, + } + return resp + } + + case ResponsesStreamResponseTypeFunctionCallArgumentsDelta: + if rsr.Delta == nil { + resp.Choices = []BifrostResponseChoice{ + { + Index: 0, + ChatStreamResponseChoice: &ChatStreamResponseChoice{ + Delta: &ChatStreamResponseChoiceDelta{}, + }, + }, + } + return resp + } + var idx uint16 + if rsr.OutputIndex != nil && *rsr.OutputIndex > 0 { + idx = uint16(*rsr.OutputIndex - 1) + } + + resp.Choices = []BifrostResponseChoice{ + { + Index: 0, + ChatStreamResponseChoice: &ChatStreamResponseChoice{ + Delta: &ChatStreamResponseChoiceDelta{ + ToolCalls: []ChatAssistantMessageToolCall{ + { + Index: idx, + Function: ChatAssistantMessageToolCallFunction{ + Arguments: *rsr.Delta, + }, + }, + }, + }, + }, + }, + } + return resp + + case ResponsesStreamResponseTypeCompleted, ResponsesStreamResponseTypeIncomplete: + finishReason := 
string(BifrostFinishReasonStop) + if rsr.Type == ResponsesStreamResponseTypeIncomplete { + finishReason = string(BifrostFinishReasonLength) + } + resp.Choices = []BifrostResponseChoice{ + { + Index: 0, + FinishReason: &finishReason, + ChatStreamResponseChoice: &ChatStreamResponseChoice{ + Delta: &ChatStreamResponseChoiceDelta{}, + }, + }, + } + if rsr.Response != nil { + if rsr.Response.Usage != nil { + resp.Usage = rsr.Response.Usage.ToBifrostLLMUsage() + } + // Check for tool_calls finish reason + for _, output := range rsr.Response.Output { + if output.Type != nil && *output.Type == ResponsesMessageTypeFunctionCall { + finishReason = string(BifrostFinishReasonToolCalls) + resp.Choices[0].FinishReason = &finishReason + break + } + } + } + return resp + + default: + // Lifecycle events (created, in_progress, content_part.added/done, output_text.done, + // output_item.done, function_call_arguments.done, etc.) → empty chat chunk with no content. + resp.Choices = []BifrostResponseChoice{ + { + Index: 0, + ChatStreamResponseChoice: &ChatStreamResponseChoice{ + Delta: &ChatStreamResponseChoiceDelta{}, + }, + }, + } + return resp + } +} + // ============================================================================= // RESPONSE CONVERSION METHODS // ============================================================================= @@ -2130,4 +2388,4 @@ func (cr *BifrostChatResponse) ToBifrostTextCompletionResponse() *BifrostTextCom CacheDebug: cr.ExtraFields.CacheDebug, }, } -} \ No newline at end of file +} diff --git a/core/utils.go b/core/utils.go index 12d86e2508..75b491a9c1 100644 --- a/core/utils.go +++ b/core/utils.go @@ -272,6 +272,8 @@ func clearCtxForFallback(ctx *schemas.BifrostContext) { ctx.ClearValue(schemas.BifrostContextKeyAPIKeyID) ctx.ClearValue(schemas.BifrostContextKeyAPIKeyName) ctx.ClearValue(schemas.BifrostContextKeyGovernanceIncludeOnlyKeys) + ctx.ClearValue(schemas.BifrostContextKeyShouldConvertTextToChat) + 
 ctx.ClearValue(schemas.BifrostContextKeyShouldConvertChatToResponses) } var supportedBaseProvidersSet = func() map[schemas.ModelProvider]struct{} { @@ -579,7 +581,7 @@ func buildSessionKey(providerKey schemas.ModelProvider, sessionID string, model if discriminator == "" { discriminator = "__modelless__" } - return "session:" + string(providerKey) + ":" + hashedSessionID + ":" + hashSHA256(discriminator) + return "session:" + string(providerKey) + ":" + hashedSessionID + ":" + hashSHA256(discriminator) } // isPromptOptionalImageEditType returns true for edit task types that do not require a text prompt. diff --git a/framework/modelcatalog/main.go b/framework/modelcatalog/main.go index 3a797e8250..8aea69bb7b 100644 --- a/framework/modelcatalog/main.go +++ b/framework/modelcatalog/main.go @@ -3,12 +3,13 @@ package modelcatalog import ( "context" + "encoding/json" "fmt" + "slices" + "strings" "sync" "time" - "encoding/json" - providerUtils "github.com/maximhq/bifrost/core/providers/utils" "github.com/maximhq/bifrost/core/schemas" "github.com/maximhq/bifrost/framework/configstore" @@ -46,6 +47,10 @@ type ModelCatalog struct { unfilteredModelPool map[schemas.ModelProvider][]string // model pool without allowed models filtering baseModelIndex map[string]string // model string → canonical base model name + // Pre-parsed supported outputs index (keyed by model name, populated from model parameters supported_endpoints) + // Values are normalized output types: "chat_completion", "responses", "text_completion" + supportedOutputs map[string][]string + // Background sync worker syncTicker *time.Ticker done chan struct{} @@ -75,6 +80,7 @@ func Init(ctx context.Context, config *Config, configStore configstore.ConfigSto modelPool: make(map[schemas.ModelProvider][]string), unfilteredModelPool: make(map[schemas.ModelProvider][]string), baseModelIndex: make(map[string]string), + supportedOutputs: make(map[string][]string), done: make(chan struct{}), distributedLockManager: 
configstore.NewDistributedLockManager(configStore, logger, configstore.WithDefaultTTL(30*time.Second)), } @@ -312,6 +318,704 @@ func (mc *ModelCatalog) getPricingURL() string { return mc.pricingURL } +// getPricingSyncInterval returns a copy of the pricing sync interval under mutex protection +func (mc *ModelCatalog) getPricingSyncInterval() time.Duration { + mc.pricingMu.RLock() + defer mc.pricingMu.RUnlock() + return mc.pricingSyncInterval +} + +// GetPricingEntryForModel returns the pricing data +func (mc *ModelCatalog) GetPricingEntryForModel(model string, provider schemas.ModelProvider) *PricingEntry { + mc.mu.RLock() + defer mc.mu.RUnlock() + // Check all modes + for _, mode := range []schemas.RequestType{ + schemas.TextCompletionRequest, + schemas.ChatCompletionRequest, + schemas.ResponsesRequest, + schemas.EmbeddingRequest, + schemas.RerankRequest, + schemas.SpeechRequest, + schemas.TranscriptionRequest, + schemas.ImageGenerationRequest, + schemas.ImageEditRequest, + schemas.ImageVariationRequest, + schemas.VideoGenerationRequest, + } { + key := makeKey(model, string(provider), normalizeRequestType(mode)) + pricing, ok := mc.pricingData[key] + if ok { + return convertTableModelPricingToPricingData(&pricing) + } + } + return nil +} + +// GetModelCapabilityEntryForModel returns capability metadata for a model/provider pair. +// It prefers chat, then responses, then text-completion entries; if none exist, +// it falls back to the lexicographically first available mode for deterministic behavior. 
+func (mc *ModelCatalog) GetModelCapabilityEntryForModel(model string, provider schemas.ModelProvider) *PricingEntry { + mc.mu.RLock() + defer mc.mu.RUnlock() + + if entry := mc.getCapabilityEntryForExactModelUnsafe(model, provider); entry != nil { + return entry + } + + baseModel := mc.getBaseModelNameUnsafe(model) + if baseModel != model { + if entry := mc.getCapabilityEntryForExactModelUnsafe(baseModel, provider); entry != nil { + return entry + } + } + + if entry := mc.getCapabilityEntryForModelFamilyUnsafe(baseModel, provider); entry != nil { + return entry + } + + return nil +} + +func (mc *ModelCatalog) getCapabilityEntryForExactModelUnsafe(model string, provider schemas.ModelProvider) *PricingEntry { + preferredModes := []schemas.RequestType{ + schemas.ChatCompletionRequest, + schemas.ResponsesRequest, + schemas.TextCompletionRequest, + } + + for _, mode := range preferredModes { + key := makeKey(model, string(provider), normalizeRequestType(mode)) + pricing, ok := mc.pricingData[key] + if ok { + return convertTableModelPricingToPricingData(&pricing) + } + } + + prefix := model + "|" + string(provider) + "|" + matchingKeys := make([]string, 0) + for key := range mc.pricingData { + if strings.HasPrefix(key, prefix) { + matchingKeys = append(matchingKeys, key) + } + } + return mc.selectCapabilityEntryFromKeysUnsafe(matchingKeys) +} + +func (mc *ModelCatalog) getCapabilityEntryForModelFamilyUnsafe(baseModel string, provider schemas.ModelProvider) *PricingEntry { + if baseModel == "" { + return nil + } + + matchingKeys := make([]string, 0) + for key, pricing := range mc.pricingData { + if normalizeProvider(pricing.Provider) != string(provider) { + continue + } + if mc.getBaseModelNameUnsafe(pricing.Model) != baseModel { + continue + } + matchingKeys = append(matchingKeys, key) + } + return mc.selectCapabilityEntryFromKeysUnsafe(matchingKeys) +} + +func (mc *ModelCatalog) selectCapabilityEntryFromKeysUnsafe(matchingKeys []string) *PricingEntry { + if 
len(matchingKeys) == 0 { + return nil + } + + preferredModes := []string{ + normalizeRequestType(schemas.ChatCompletionRequest), + normalizeRequestType(schemas.ResponsesRequest), + normalizeRequestType(schemas.TextCompletionRequest), + } + + for _, mode := range preferredModes { + modeMatches := make([]string, 0) + for _, key := range matchingKeys { + parts := strings.SplitN(key, "|", 3) + if len(parts) != 3 || parts[2] != mode { + continue + } + modeMatches = append(modeMatches, key) + } + if len(modeMatches) == 0 { + continue + } + slices.Sort(modeMatches) + pricing := mc.pricingData[modeMatches[0]] + return convertTableModelPricingToPricingData(&pricing) + } + + slices.Sort(matchingKeys) + pricing := mc.pricingData[matchingKeys[0]] + return convertTableModelPricingToPricingData(&pricing) +} + +// GetModelsForProvider returns all available models for a given provider (thread-safe) +func (mc *ModelCatalog) GetModelsForProvider(provider schemas.ModelProvider) []string { + mc.mu.RLock() + defer mc.mu.RUnlock() + + models, exists := mc.modelPool[provider] + if !exists { + return []string{} + } + + // Return a copy to prevent external modification + result := make([]string, len(models)) + copy(result, models) + return result +} + +// GetUnfilteredModelsForProvider returns all available models for a given provider (thread-safe) +func (mc *ModelCatalog) GetUnfilteredModelsForProvider(provider schemas.ModelProvider) []string { + mc.mu.RLock() + defer mc.mu.RUnlock() + + models, exists := mc.unfilteredModelPool[provider] + if !exists { + return []string{} + } + + // Return a copy to prevent external modification + result := make([]string, len(models)) + copy(result, models) + return result +} + +// GetDistinctBaseModelNames returns all unique base model names from the catalog (thread-safe). +// This is used for governance model selection when no specific provider is chosen. 
+func (mc *ModelCatalog) GetDistinctBaseModelNames() []string { + mc.mu.RLock() + defer mc.mu.RUnlock() + + seen := make(map[string]bool) + for _, baseName := range mc.baseModelIndex { + seen[baseName] = true + } + + result := make([]string, 0, len(seen)) + for name := range seen { + result = append(result, name) + } + return result +} + +// GetProvidersForModel returns all providers for a given model (thread-safe) +func (mc *ModelCatalog) GetProvidersForModel(model string) []schemas.ModelProvider { + mc.mu.RLock() + defer mc.mu.RUnlock() + + providers := make([]schemas.ModelProvider, 0) + for provider, models := range mc.modelPool { + isModelMatch := false + for _, m := range models { + if m == model || mc.getBaseModelNameUnsafe(m) == mc.getBaseModelNameUnsafe(model) { + isModelMatch = true + break + } + } + if isModelMatch { + providers = append(providers, provider) + } + } + + // Handler special provider cases + // 1. Handler openrouter models + if !slices.Contains(providers, schemas.OpenRouter) { + for _, provider := range providers { + if openRouterModels, ok := mc.modelPool[schemas.OpenRouter]; ok { + if slices.Contains(openRouterModels, string(provider)+"/"+model) { + providers = append(providers, schemas.OpenRouter) + } + } + } + } + + // 2. Handle vertex models + if !slices.Contains(providers, schemas.Vertex) { + for _, provider := range providers { + if vertexModels, ok := mc.modelPool[schemas.Vertex]; ok { + if slices.Contains(vertexModels, string(provider)+"/"+model) { + providers = append(providers, schemas.Vertex) + } + } + } + } + + // 3. Handle openai models for groq + if !slices.Contains(providers, schemas.Groq) && strings.Contains(model, "gpt-") { + if groqModels, ok := mc.modelPool[schemas.Groq]; ok { + if slices.Contains(groqModels, "openai/"+model) { + providers = append(providers, schemas.Groq) + } + } + } + + // 4. 
Handle anthropic models for bedrock + if !slices.Contains(providers, schemas.Bedrock) && strings.Contains(model, "claude") { + if bedrockModels, ok := mc.modelPool[schemas.Bedrock]; ok { + for _, bedrockModel := range bedrockModels { + if strings.Contains(bedrockModel, model) { + providers = append(providers, schemas.Bedrock) + break + } + } + } + } + + return providers +} + +// IsModelAllowedForProvider checks if a model is allowed for a specific provider +// based on the allowed models list and catalog data. It handles all cross-provider +// logic including provider-prefixed models and special routing rules. +// +// Parameters: +// - provider: The provider to check against +// - model: The model name (without provider prefix, e.g., "gpt-4o" or "claude-3-5-sonnet") +// - allowedModels: List of allowed model names (can be empty, can include provider prefixes) +// +// Behavior: +// - If allowedModels is ["*"]: Uses model catalog to check if provider supports the model +// (delegates to GetProvidersForModel which handles all cross-provider logic) +// - If allowedModels is empty ([]): Deny-by-default — returns false for any provider/model pair +// - If allowedModels is not empty: Checks if model matches any entry in the list +// Provider-specific validation: +// - Direct matches: "gpt-4o" in allowedModels for any provider +// - Prefixed matches: Only if the prefixed model exists in provider's catalog +// (e.g., "openai/gpt-4o" in allowedModels only matches if openrouter's catalog +// contains "openai/gpt-4o" AND the model part matches the request) +// +// Returns: +// - bool: true if the model is allowed for the provider, false otherwise +// +// Examples: +// +// // Wildcard allowedModels - uses catalog to check provider support +// mc.IsModelAllowedForProvider("openrouter", "claude-3-5-sonnet", []string{"*"}) +// // Returns: true (catalog knows openrouter has "anthropic/claude-3-5-sonnet") +// +// // Empty allowedModels - deny all (deny-by-default) +// 
mc.IsModelAllowedForProvider("openrouter", "claude-3-5-sonnet", []string{}) +// // Returns: false (no models are permitted) +// +// // Explicit allowedModels with prefix - validates against catalog +// mc.IsModelAllowedForProvider("openrouter", "gpt-4o", []string{"openai/gpt-4o"}) +// // Returns: true (openrouter's catalog contains "openai/gpt-4o" AND model part is "gpt-4o") +// +// // Explicit allowedModels with prefix - wrong model +// mc.IsModelAllowedForProvider("openrouter", "claude-3-5-sonnet", []string{"openai/gpt-4o"}) +// // Returns: false (model part "gpt-4o" doesn't match request "claude-3-5-sonnet") +// +// // Explicit allowedModels without prefix +// mc.IsModelAllowedForProvider("openai", "gpt-4o", []string{"gpt-4o"}) +// // Returns: true (direct match) +func (mc *ModelCatalog) IsModelAllowedForProvider(provider schemas.ModelProvider, model string, allowedModels schemas.WhiteList) bool { + // Case 1: ["*"] = allow all models; use catalog to determine support + // Empty allowedModels = deny all (fail-safe deny-by-default) + if allowedModels.IsUnrestricted() { + supportedProviders := mc.GetProvidersForModel(model) + return slices.Contains(supportedProviders, provider) + } + if allowedModels.IsEmpty() { + return false + } + + // Case 2: Explicit allowedModels = check if model matches any entry + // Get provider's catalog models for validation of prefixed entries + providerCatalogModels := mc.GetModelsForProvider(provider) + + for _, allowedModel := range allowedModels { + // Direct match: "gpt-4o" == "gpt-4o" + if allowedModel == model { + return true + } + + // Provider-prefixed match: verify it exists in provider's catalog first + // This ensures we only allow provider-specific model combinations that are actually supported + if strings.Contains(allowedModel, "/") { + // Check if this exact prefixed model exists in the provider's catalog + // e.g., for openrouter, check if "openai/gpt-4o" is in its catalog + if slices.Contains(providerCatalogModels, 
allowedModel) { + // Extract the model part and compare with request + _, modelPart := schemas.ParseModelString(allowedModel, "") + if modelPart == model { + return true + } + } + } + } + + return false +} + +// GetBaseModelName returns the canonical base model name for a given model string. +// It uses the pre-computed base_model from the pricing catalog when available, +// falling back to algorithmic date/version stripping for models not in the catalog. +// +// Examples: +// +// mc.GetBaseModelName("gpt-4o") // Returns: "gpt-4o" +// mc.GetBaseModelName("openai/gpt-4o") // Returns: "gpt-4o" +// mc.GetBaseModelName("gpt-4o-2024-08-06") // Returns: "gpt-4o" (algorithmic fallback) +func (mc *ModelCatalog) GetBaseModelName(model string) string { + mc.mu.RLock() + defer mc.mu.RUnlock() + return mc.getBaseModelNameUnsafe(model) +} + +// getBaseModelNameUnsafe returns the canonical base model name for a given model string without locking. +// This is used to avoid locking overhead when getting the base model name for many models. +// Make sure the caller function is holding the read lock before calling this function. +// It is not safe to use this function when the model pool is being updated. +func (mc *ModelCatalog) getBaseModelNameUnsafe(model string) string { + // Step 1: Direct lookup in base model index + if base, ok := mc.baseModelIndex[model]; ok { + return base + } + + // Step 2: Strip provider prefix and try again + _, baseName := schemas.ParseModelString(model, "") + if baseName != model { + if base, ok := mc.baseModelIndex[baseName]; ok { + return base + } + } + + // Step 3: Fallback to algorithmic date/version stripping + // (for models not in the catalog, e.g., user-configured custom models) + return schemas.BaseModelName(baseName) +} + +// IsSameModel checks if two model strings refer to the same underlying model. +// It compares the canonical base model names derived from the pricing catalog +// (or algorithmic fallback for models not in the catalog). 
+// +// Examples: +// +// mc.IsSameModel("gpt-4o", "gpt-4o") // true (direct match) +// mc.IsSameModel("openai/gpt-4o", "gpt-4o") // true (same base model) +// mc.IsSameModel("gpt-4o", "claude-3-5-sonnet") // false (different models) +// mc.IsSameModel("openai/gpt-4o", "anthropic/claude-3-5-sonnet") // false +func (mc *ModelCatalog) IsSameModel(model1, model2 string) bool { + if model1 == model2 { + return true + } + return mc.GetBaseModelName(model1) == mc.GetBaseModelName(model2) +} + +// DeleteModelDataForProvider deletes all model data from the pool for a given provider +func (mc *ModelCatalog) DeleteModelDataForProvider(provider schemas.ModelProvider) { + mc.mu.Lock() + defer mc.mu.Unlock() + + delete(mc.modelPool, provider) + delete(mc.unfilteredModelPool, provider) +} + +// UpsertModelDataForProvider upserts model data for a given provider +func (mc *ModelCatalog) UpsertModelDataForProvider(provider schemas.ModelProvider, modelData *schemas.BifrostListModelsResponse, allowedModels []schemas.Model) { + if modelData == nil { + return + } + mc.mu.Lock() + defer mc.mu.Unlock() + + // Populating models from pricing data for the given provider + // Provider models map + providerModels := []string{} + // Iterate through all pricing data to collect models per provider + for _, pricing := range mc.pricingData { + // Normalize provider before adding to model pool + normalizedProvider := schemas.ModelProvider(normalizeProvider(pricing.Provider)) + // We will only add models for the given provider + if normalizedProvider != provider { + continue + } + // Add model to the provider's model set (using map for deduplication) + if slices.Contains(providerModels, pricing.Model) { + continue + } + providerModels = append(providerModels, pricing.Model) + // Build base model index from pre-computed base_model field + if pricing.BaseModel != "" { + mc.baseModelIndex[pricing.Model] = pricing.BaseModel + } + } + // If modelData is empty, then we allow all models + if 
len(modelData.Data) == 0 && len(allowedModels) == 0 { + mc.modelPool[provider] = providerModels + return + } + // Here we make sure that we still keep the backup for model catalog intact + // So we start with a existing model pool and add the new models from incoming data + finalModelList := make([]string, 0) + seenModels := make(map[string]bool) + // Case where list models failed but we have allowed models from keys + if len(modelData.Data) == 0 && len(allowedModels) > 0 { + for _, allowedModel := range allowedModels { + parsedProvider, parsedModel := schemas.ParseModelString(allowedModel.ID, "") + if parsedProvider != provider { + continue + } + if !seenModels[parsedModel] { + seenModels[parsedModel] = true + finalModelList = append(finalModelList, parsedModel) + } + } + } + for _, model := range modelData.Data { + parsedProvider, parsedModel := schemas.ParseModelString(model.ID, "") + if parsedProvider != provider { + continue + } + if !seenModels[parsedModel] { + seenModels[parsedModel] = true + finalModelList = append(finalModelList, parsedModel) + } + } + + if len(allowedModels) == 0 { + for _, model := range providerModels { + if !seenModels[model] { + seenModels[model] = true + finalModelList = append(finalModelList, model) + } + } + } + mc.modelPool[provider] = finalModelList +} + +// UpsertUnfilteredModelDataForProvider upserts unfiltered model data for a given provider +func (mc *ModelCatalog) UpsertUnfilteredModelDataForProvider(provider schemas.ModelProvider, modelData *schemas.BifrostListModelsResponse) { + if modelData == nil { + return + } + mc.mu.Lock() + defer mc.mu.Unlock() + + // Populating models from pricing data for the given provider + providerModels := []string{} + seenModels := make(map[string]bool) + for _, pricing := range mc.pricingData { + normalizedProvider := schemas.ModelProvider(normalizeProvider(pricing.Provider)) + if normalizedProvider != provider { + continue + } + if !seenModels[pricing.Model] { + seenModels[pricing.Model] = 
true + providerModels = append(providerModels, pricing.Model) + } + } + for _, model := range modelData.Data { + parsedProvider, parsedModel := schemas.ParseModelString(model.ID, "") + if parsedProvider != provider { + continue + } + if !seenModels[parsedModel] { + seenModels[parsedModel] = true + providerModels = append(providerModels, parsedModel) + } + } + mc.unfilteredModelPool[provider] = providerModels +} + +// RefineModelForProvider refines the model for a given provider by performing a lookup +// in mc.modelPool and using schemas.ParseModelString to extract provider and model parts. +// e.g. "gpt-oss-120b" for groq provider -> "openai/gpt-oss-120b" +// +// Behavior: +// - When the provider's catalog (mc.modelPool) yields multiple matching models, returns an error +// - When exactly one match is found, returns the fully-qualified model (provider/model format) +// - When the provider is not handled or no refinement is needed, returns the original model unchanged +func (mc *ModelCatalog) RefineModelForProvider(provider schemas.ModelProvider, model string) (string, error) { + switch provider { + case schemas.Groq: + if strings.Contains(model, "gpt-") { + return "openai/" + model, nil + } + return mc.refineNestedProviderModel(provider, model) + case schemas.Replicate: + return mc.refineNestedProviderModel(provider, model) + } + return model, nil +} + +// refineNestedProviderModel resolves provider-native model slugs such as +// "openai/gpt-5-nano" from a base model request like "gpt-5-nano". +// It only considers catalog entries whose leading segment is a known Bifrost provider, +// so Replicate owner/model identifiers like "meta/llama-3-8b" are left untouched. 
+func (mc *ModelCatalog) refineNestedProviderModel(provider schemas.ModelProvider, model string) (string, error) { + mc.mu.RLock() + models, ok := mc.modelPool[provider] + mc.mu.RUnlock() + if !ok { + return model, nil + } + + candidateModels := make([]string, 0) + seenCandidates := make(map[string]struct{}) + for _, poolModel := range models { + providerPart, modelPart := schemas.ParseModelString(poolModel, "") + if providerPart == "" || model != modelPart { + continue + } + + candidate := string(providerPart) + "/" + modelPart + if _, seen := seenCandidates[candidate]; seen { + continue + } + seenCandidates[candidate] = struct{}{} + candidateModels = append(candidateModels, candidate) + } + + switch len(candidateModels) { + case 0: + return model, nil + case 1: + return candidateModels[0], nil + default: + return "", fmt.Errorf("multiple compatible models found for model %s: %v", model, candidateModels) + } +} + +// SetPricingOverrides replaces the full in-memory pricing override set. +func (mc *ModelCatalog) SetPricingOverrides(rows []configstoreTables.TablePricingOverride) error { + seen := make(map[string]int, len(rows)) + overrides := make([]PricingOverride, 0, len(rows)) + for i := range rows { + o, err := convertTablePricingOverrideToPricingOverride(&rows[i]) + if err != nil { + return err + } + if idx, exists := seen[o.ID]; exists { + overrides[idx] = o // last entry wins for duplicate IDs + } else { + seen[o.ID] = len(overrides) + overrides = append(overrides, o) + } + } + mc.overridesMu.Lock() + mc.rawOverrides = overrides + mc.customPricing = buildCustomPricingData(overrides) + mc.overridesMu.Unlock() + return nil +} + +// UpsertPricingOverrides inserts or replaces one or more pricing overrides in a single +// operation, rebuilding the lookup map only once at the end. 
+func (mc *ModelCatalog) UpsertPricingOverrides(rows ...*configstoreTables.TablePricingOverride) error { + // Deduplicate the input batch by ID (last entry wins) and build the + // incoming set for O(1) lookup when filtering existing rawOverrides. + seenIncoming := make(map[string]int, len(rows)) + overrides := make([]PricingOverride, 0, len(rows)) + for _, row := range rows { + o, err := convertTablePricingOverrideToPricingOverride(row) + if err != nil { + return err + } + if idx, exists := seenIncoming[o.ID]; exists { + overrides[idx] = o // last entry wins for duplicate IDs + } else { + seenIncoming[o.ID] = len(overrides) + overrides = append(overrides, o) + } + } + + mc.overridesMu.Lock() + defer mc.overridesMu.Unlock() + + updated := make([]PricingOverride, 0, len(mc.rawOverrides)+len(overrides)) + for _, o := range mc.rawOverrides { + if _, replacing := seenIncoming[o.ID]; !replacing { + updated = append(updated, o) + } + } + updated = append(updated, overrides...) + mc.rawOverrides = updated + mc.customPricing = buildCustomPricingData(updated) + return nil +} + +// DeletePricingOverride removes a pricing override by ID. +func (mc *ModelCatalog) DeletePricingOverride(id string) { + mc.overridesMu.Lock() + defer mc.overridesMu.Unlock() + + updated := make([]PricingOverride, 0, len(mc.rawOverrides)) + for _, o := range mc.rawOverrides { + if o.ID != id { + updated = append(updated, o) + } + } + mc.rawOverrides = updated + mc.customPricing = buildCustomPricingData(updated) +} + +// IsTextCompletionSupported checks if a model supports text completion for the given provider. +// Returns true if the model has pricing data for text completion ("text_completion"), +// false otherwise. This is used by the litellmcompat plugin to determine whether to +// convert text completion requests to chat completion requests. 
+func (mc *ModelCatalog) IsTextCompletionSupported(model string, provider schemas.ModelProvider) bool { + mc.mu.RLock() + defer mc.mu.RUnlock() + // Check for text completion mode in pricing data + key := makeKey(model, normalizeProvider(string(provider)), normalizeRequestType(schemas.TextCompletionRequest)) + _, ok := mc.pricingData[key] + return ok +} + +// IsChatCompletionSupported checks if a model supports chat completion. +// It checks the supportedOutputs index (derived from supported_endpoints in the datasheet). +func (mc *ModelCatalog) IsChatCompletionSupported(model string, provider schemas.ModelProvider) bool { + mc.mu.RLock() + outputs, ok := mc.supportedOutputs[model] + mc.mu.RUnlock() + return ok && slices.Contains(outputs, "chat_completion") +} + +// IsResponsesSupported checks if a model supports the responses endpoint. +// It checks the supportedOutputs index (derived from supported_endpoints in the datasheet). +func (mc *ModelCatalog) IsResponsesSupported(model string, provider schemas.ModelProvider) bool { + mc.mu.RLock() + outputs, ok := mc.supportedOutputs[model] + mc.mu.RUnlock() + return ok && slices.Contains(outputs, "responses") +} + +// buildSupportedOutputsIndex parses supported_endpoints from model parameters data +// and rebuilds the supportedOutputs index with normalized output type names. 
+func (mc *ModelCatalog) buildSupportedOutputsIndex(paramsData map[string]json.RawMessage) { + newIndex := make(map[string][]string, len(paramsData)) + + for model, data := range paramsData { + var params struct { + SupportedEndpoints []string `json:"supported_endpoints"` + } + if err := json.Unmarshal(data, &params); err != nil || len(params.SupportedEndpoints) == 0 { + continue + } + outputs := make([]string, 0, len(params.SupportedEndpoints)) + for _, endpoint := range params.SupportedEndpoints { + if normalized := normalizeEndpointToOutputType(endpoint); normalized != "" { + if !slices.Contains(outputs, normalized) { + outputs = append(outputs, normalized) + } + } + } + if len(outputs) > 0 { + newIndex[model] = outputs + } + } + + mc.mu.Lock() + mc.supportedOutputs = newIndex + mc.mu.Unlock() +} + // populateModelPool populates the model pool with all available models per provider (thread-safe) func (mc *ModelCatalog) populateModelPoolFromPricingData() { // Acquire write lock for the entire rebuild operation @@ -393,6 +1097,7 @@ func NewTestCatalog(baseModelIndex map[string]string) *ModelCatalog { unfilteredModelPool: make(map[schemas.ModelProvider][]string), baseModelIndex: baseModelIndex, pricingData: make(map[string]configstoreTables.TableModelPricing), + supportedOutputs: make(map[string][]string), done: make(chan struct{}), } -} +} \ No newline at end of file diff --git a/framework/modelcatalog/sync.go b/framework/modelcatalog/sync.go index 6de795936c..589dac05c9 100644 --- a/framework/modelcatalog/sync.go +++ b/framework/modelcatalog/sync.go @@ -70,7 +70,6 @@ func (mc *ModelCatalog) syncPricing(ctx context.Context) error { return nil }) - if err != nil { return fmt.Errorf("failed to sync pricing data to database: %w", err) } diff --git a/framework/modelcatalog/utils.go b/framework/modelcatalog/utils.go index 85c9977234..af7e569d2f 100644 --- a/framework/modelcatalog/utils.go +++ b/framework/modelcatalog/utils.go @@ -310,3 +310,18 @@ func 
convertTablePricingOverrideToPricingOverride(override *configstoreTables.Ta Options: options, }, nil } + +// normalizeEndpointToOutputType converts a supported_endpoints URL path to a normalized output type. +// Returns empty string for unrecognized endpoints. +func normalizeEndpointToOutputType(endpoint string) string { + switch { + case strings.Contains(endpoint, "/chat/completions"): + return "chat_completion" + case strings.Contains(endpoint, "/responses"): + return "responses" + case strings.Contains(endpoint, "/completions"): + return "text_completion" + default: + return "" + } +} diff --git a/plugins/litellmcompat/chattoresponses.go b/plugins/litellmcompat/chattoresponses.go new file mode 100644 index 0000000000..c8438a2b17 --- /dev/null +++ b/plugins/litellmcompat/chattoresponses.go @@ -0,0 +1,108 @@ +package litellmcompat + +import ( + "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/framework/modelcatalog" +) + +const ( + ChatToResponsesOriginalRequestTypeContextKey schemas.BifrostContextKey = "litellmcompat-chat-to-responses-original-request-type" + ChatToResponsesOriginalModelContextKey schemas.BifrostContextKey = "litellmcompat-chat-to-responses-original-model" +) + +// transformChatToResponsesRequest determines whether a chat request should be converted +// to a responses request by core. It stores conversion intent in context; core performs +// the actual conversion. 
+func transformChatToResponsesRequest(ctx *schemas.BifrostContext, req *schemas.BifrostRequest, mc *modelcatalog.ModelCatalog, logger schemas.Logger) *schemas.BifrostRequest { + // Only process chat completion requests + if req.RequestType != schemas.ChatCompletionRequest && req.RequestType != schemas.ChatCompletionStreamRequest { + return req + } + + // Check if chat completion request is present + if req.ChatRequest == nil { + return req + } + + // Check if the model supports chat completion via model catalog + if mc != nil { + provider := req.ChatRequest.Provider + model := req.ChatRequest.Model + if mc.IsChatCompletionSupported(model, provider) { + if ctx != nil { + ctx.SetValue(schemas.BifrostContextKeyShouldConvertChatToResponses, false) + } + if logger != nil { + logger.Debug("litellmcompat: model %s/%s supports chat completion, skipping conversion", provider, model) + } + return req + } + } + + // Track conversion intent. Core will do the actual conversion during provider dispatch. 
+ if ctx != nil { + ctx.SetValue(schemas.BifrostContextKeyShouldConvertChatToResponses, true) + ctx.SetValue(ChatToResponsesOriginalRequestTypeContextKey, req.RequestType) + ctx.SetValue(ChatToResponsesOriginalModelContextKey, req.ChatRequest.Model) + } + + if logger != nil { + logger.Debug("litellmcompat: marked chat completion for core chat->responses conversion for model %s (chat completion not supported, responses supported)", req.ChatRequest.Model) + } + + return req +} + +func getOriginalChatRequestMetadata(ctx *schemas.BifrostContext) (schemas.RequestType, string) { + requestType := schemas.ChatCompletionRequest + if ctx == nil { + return requestType, "" + } + if value, ok := ctx.Value(ChatToResponsesOriginalRequestTypeContextKey).(schemas.RequestType); ok { + requestType = value + } + model, _ := ctx.Value(ChatToResponsesOriginalModelContextKey).(string) + return requestType, model +} + +// transformChatToResponsesResponse normalizes metadata on converted chat-completion responses. +// Core performs the actual stream/non-stream payload conversion. +func transformChatToResponsesResponse(ctx *schemas.BifrostContext, resp *schemas.BifrostResponse, logger schemas.Logger) *schemas.BifrostResponse { + if resp == nil || resp.ChatResponse == nil || ctx == nil { + return resp + } + + shouldConvert, ok := ctx.Value(schemas.BifrostContextKeyShouldConvertChatToResponses).(bool) + if !ok || !shouldConvert { + return resp + } + + originalRequestType, originalModel := getOriginalChatRequestMetadata(ctx) + resp.ChatResponse.ExtraFields.RequestType = originalRequestType + resp.ChatResponse.ExtraFields.ModelRequested = originalModel + resp.ChatResponse.ExtraFields.LiteLLMCompat = true + + if logger != nil { + logger.Debug("litellmcompat: normalized converted chat completion metadata for model %s", originalModel) + } + + return resp +} + +// transformChatToResponsesError restores original chat-completion metadata on errors +// generated from responses fallback execution. 
+func transformChatToResponsesError(ctx *schemas.BifrostContext, err *schemas.BifrostError) *schemas.BifrostError { + if err == nil || ctx == nil { + return err + } + shouldConvert, ok := ctx.Value(schemas.BifrostContextKeyShouldConvertChatToResponses).(bool) + if !ok || !shouldConvert { + return err + } + + originalRequestType, originalModel := getOriginalChatRequestMetadata(ctx) + err.ExtraFields.RequestType = originalRequestType + err.ExtraFields.ModelRequested = originalModel + err.ExtraFields.LiteLLMCompat = true + return err +} diff --git a/plugins/litellmcompat/main.go b/plugins/litellmcompat/main.go index 8983eb5ad3..59de6fedb8 100644 --- a/plugins/litellmcompat/main.go +++ b/plugins/litellmcompat/main.go @@ -1,9 +1,10 @@ -// Package litellmcompat provides LiteLLM-compatible text-to-chat conversion decisions -// for the Bifrost gateway. It marks text completion requests that should be converted -// by core provider dispatch for models that only support chat completions. +// Package litellmcompat provides LiteLLM-compatible request type conversion decisions +// for the Bifrost gateway. It marks requests that should be converted by core provider +// dispatch for models that don't natively support the requested endpoint type. // // When enabled, this plugin: // - Decides whether text_completion() should be converted to chat +// - Decides whether chat_completion() should be converted to responses // - Stores the decision in context for core request dispatch package litellmcompat @@ -24,24 +25,19 @@ type Config struct { // LiteLLMCompatPlugin provides LiteLLM-compatible request/response transformations. // When enabled, it automatically converts text completion requests to chat completion // requests for models that only support chat completions, matching LiteLLM's behavior. +// It also converts chat completion requests to responses for models that only support +// the responses endpoint. 
type LiteLLMCompatPlugin struct { config Config logger schemas.Logger modelCatalog *modelcatalog.ModelCatalog } -// Init creates a new litellmcompat plugin instance -func Init(config Config, logger schemas.Logger) (*LiteLLMCompatPlugin, error) { - return &LiteLLMCompatPlugin{ - config: config, - logger: logger, - }, nil -} - -// InitWithModelCatalog creates a new litellmcompat plugin instance with model catalog support. -// The model catalog is used to determine if a model supports text completion natively. -// If the model catalog is nil, the plugin will convert ALL text completion requests. -func InitWithModelCatalog(config Config, logger schemas.Logger, mc *modelcatalog.ModelCatalog) (*LiteLLMCompatPlugin, error) { +// Init creates a new litellmcompat plugin instance with model catalog support. +// The model catalog is used to determine if a model supports text completion or chat completion natively. +// If the model catalog is nil, the plugin will convert ALL text completion requests to chat completion +// and ALL chat completion requests to responses. +func Init(config Config, logger schemas.Logger, mc *modelcatalog.ModelCatalog) (*LiteLLMCompatPlugin, error) { return &LiteLLMCompatPlugin{ config: config, logger: logger, @@ -78,20 +74,31 @@ func (p *LiteLLMCompatPlugin) HTTPTransportStreamChunkHook(ctx *schemas.BifrostC // PreLLMHook intercepts requests and applies LiteLLM-compatible transformation intent. // For text completion requests on models that don't support text completion, // it marks the request so core can convert at provider dispatch time. +// For chat completion requests on models that don't support chat completion, +// it marks the request so core can convert at provider dispatch time. 
func (p *LiteLLMCompatPlugin) PreLLMHook(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.LLMPluginShortCircuit, error) { + // Reset context keys + if ctx != nil { + ctx.SetValue(schemas.BifrostContextKeyShouldConvertTextToChat, false) + ctx.SetValue(schemas.BifrostContextKeyShouldConvertChatToResponses, false) + } + // Apply request transforms in sequence req = transformTextToChatRequest(ctx, req, p.modelCatalog, p.logger) + req = transformChatToResponsesRequest(ctx, req, p.modelCatalog, p.logger) return req, nil, nil } // PostLLMHook normalizes metadata on converted responses/errors -// when this plugin requested text->chat conversion in PreLLMHook. +// when this plugin requested type conversion in PreLLMHook. func (p *LiteLLMCompatPlugin) PostLLMHook(ctx *schemas.BifrostContext, result *schemas.BifrostResponse, bifrostErr *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError, error) { if result != nil { result = transformTextToChatResponse(ctx, result, p.logger) + result = transformChatToResponsesResponse(ctx, result, p.logger) } if bifrostErr != nil { bifrostErr = transformTextToChatError(ctx, bifrostErr) + bifrostErr = transformChatToResponsesError(ctx, bifrostErr) } return result, bifrostErr, nil } diff --git a/transports/bifrost-http/server/plugins.go b/transports/bifrost-http/server/plugins.go index 3cdf2f31fa..2d9a0ace81 100644 --- a/transports/bifrost-http/server/plugins.go +++ b/transports/bifrost-http/server/plugins.go @@ -110,7 +110,7 @@ func loadBuiltinPlugin(ctx context.Context, name string, pluginConfig any, bifro if err != nil { return nil, fmt.Errorf("failed to marshal litellmcompat plugin config: %w", err) } - return litellmcompat.Init(*litellmConfig, logger) + return litellmcompat.Init(*litellmConfig, logger, bifrostConfig.ModelCatalog) default: return nil, fmt.Errorf("unknown built-in plugin: %s", name) @@ -293,4 +293,4 @@ func (s *BifrostHTTPServer) loadCustomPlugins(ctx 
context.Context) error { []string{fmt.Sprintf("plugin %s initialized successfully", cfg.Name)}, InferPluginTypes(plugin)) } return nil -} +} \ No newline at end of file diff --git a/transports/bifrost-http/server/server.go b/transports/bifrost-http/server/server.go index 7f8378d417..c24ea67555 100644 --- a/transports/bifrost-http/server/server.go +++ b/transports/bifrost-http/server/server.go @@ -1554,4 +1554,4 @@ func (s *BifrostHTTPServer) Start() error { return err } return nil -} +} \ No newline at end of file From 3af404e95238a9a7c6cfead14d565d5ef1271438 Mon Sep 17 00:00:00 2001 From: Samyabrata Maji <116789799+sammaji@users.noreply.github.com> Date: Sat, 21 Mar 2026 11:56:12 +0530 Subject: [PATCH 3/4] feat: drops unsupported openai params --- AGENTS.md | 4 +- core/bifrost.go | 62 +-- core/providers/anthropic/types.go | 2 +- core/providers/bedrock/images.go | 1 - core/providers/bedrock/models.go | 3 +- core/providers/cohere/types.go | 7 +- core/providers/gemini/types.go | 21 +- core/providers/openai/types.go | 4 +- core/providers/perplexity/types.go | 78 ++-- core/providers/replicate/types.go | 36 +- core/providers/utils/utils.go | 21 +- core/providers/vertex/models.go | 2 +- core/providers/vertex/types.go | 28 +- core/providers/vertex/vertex.go | 2 +- core/schemas/bifrost.go | 47 +-- core/schemas/mux.go | 79 ++-- core/utils.go | 32 +- framework/go.mod | 2 +- framework/modelcatalog/main.go | 97 ++--- framework/modelcatalog/sync.go | 70 +++- framework/modelcatalog/utils.go | 78 ++++ nix/packages/bifrost-http.nix | 2 +- plugins/compat/changelog.md | 2 + plugins/compat/conversion.go | 25 ++ plugins/compat/dropparams.go | 218 +++++++++++ plugins/{litellmcompat => compat}/go.mod | 2 +- plugins/{litellmcompat => compat}/go.sum | 0 plugins/compat/main.go | 146 ++++++++ plugins/compat/requestcopy.go | 416 +++++++++++++++++++++ plugins/compat/version | 1 + plugins/litellmcompat/chattoresponses.go | 108 ------ plugins/litellmcompat/main.go | 109 ------ 
plugins/litellmcompat/texttochat.go | 107 ------ transports/bifrost-http/handlers/config.go | 18 +- transports/bifrost-http/lib/config.go | 4 +- transports/bifrost-http/server/plugins.go | 24 +- transports/go.mod | 16 +- transports/go.sum | 26 +- 38 files changed, 1227 insertions(+), 673 deletions(-) create mode 100644 plugins/compat/changelog.md create mode 100644 plugins/compat/conversion.go create mode 100644 plugins/compat/dropparams.go rename plugins/{litellmcompat => compat}/go.mod (99%) rename plugins/{litellmcompat => compat}/go.sum (100%) create mode 100644 plugins/compat/main.go create mode 100644 plugins/compat/requestcopy.go create mode 100644 plugins/compat/version delete mode 100644 plugins/litellmcompat/chattoresponses.go delete mode 100644 plugins/litellmcompat/main.go delete mode 100644 plugins/litellmcompat/texttochat.go diff --git a/AGENTS.md b/AGENTS.md index 03fdd812ee..bd41feaa78 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -103,7 +103,7 @@ bifrost/ │ ├── mocker/ # Mock responses for testing │ ├── jsonparser/ # JSON extraction utilities │ ├── maxim/ # Maxim observability -│ └── litellmcompat/ # LiteLLM SDK compatibility (HTTP transport) +│ └── compat/ # LiteLLM SDK compatibility (HTTP transport) │ ├── ui/ # Next.js web interface │ ├── app/workspace/ # Feature pages (20+ workspace sections) @@ -647,4 +647,4 @@ Systematically address unresolved PR review comments. Uses GraphQL to get unreso - **Provider types**: Prefixed with provider name in PascalCase (`AnthropicChatRequest`, `GeminiEmbeddingResponse`). - **Converter functions**: Pure — no side effects, no logging, no HTTP. - **Pool names**: Descriptive string passed to `pool.New()` (e.g., `"channel-message"`, `"response-stream"`). -- **Context keys**: Use `BifrostContextKey` type. Custom plugins should define their own key types to avoid collisions. +- **Context keys**: Use `BifrostContextKey` type. Custom plugins should define their own key types to avoid collisions. 
\ No newline at end of file diff --git a/core/bifrost.go b/core/bifrost.go index 722bd61e29..9d8fb98c9a 100644 --- a/core/bifrost.go +++ b/core/bifrost.go @@ -5255,54 +5255,6 @@ func (bifrost *Bifrost) requestWorker(provider schemas.Provider, config *schemas // bifrost.logger.Debug("worker for provider %s exiting...", provider.GetProviderKey()) } -func shouldConvertTextToChat(ctx *schemas.BifrostContext, requestType schemas.RequestType, request *schemas.BifrostTextCompletionRequest) bool { - if ctx == nil || request == nil { - return false - } - if requestType != schemas.TextCompletionRequest && requestType != schemas.TextCompletionStreamRequest { - return false - } - shouldConvert, ok := ctx.Value(schemas.BifrostContextKeyShouldConvertTextToChat).(bool) - return ok && shouldConvert -} - -func shouldConvertChatToResponses(ctx *schemas.BifrostContext, requestType schemas.RequestType, request *schemas.BifrostChatRequest) bool { - if ctx == nil || request == nil { - return false - } - if requestType != schemas.ChatCompletionRequest && requestType != schemas.ChatCompletionStreamRequest { - return false - } - shouldConvert, ok := ctx.Value(schemas.BifrostContextKeyShouldConvertChatToResponses).(bool) - return ok && shouldConvert -} - -func wrapTextToChatStreamPostHookRunner(postHookRunner schemas.PostHookRunner) schemas.PostHookRunner { - return func(ctx *schemas.BifrostContext, result *schemas.BifrostResponse, bifrostErr *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError) { - if result != nil && result.ChatResponse != nil { - if convertedResponse := result.ChatResponse.ToBifrostTextCompletionResponse(); convertedResponse != nil { - result = &schemas.BifrostResponse{ - TextCompletionResponse: convertedResponse, - } - } - } - return postHookRunner(ctx, result, bifrostErr) - } -} - -func wrapChatToResponsesStreamPostHookRunner(postHookRunner schemas.PostHookRunner) schemas.PostHookRunner { - return func(ctx *schemas.BifrostContext, result 
*schemas.BifrostResponse, bifrostErr *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError) { - if result != nil && result.ResponsesStreamResponse != nil { - if convertedResponse := result.ResponsesStreamResponse.ToBifrostChatResponse(); convertedResponse != nil { - result = &schemas.BifrostResponse{ - ChatResponse: convertedResponse, - } - } - } - return postHookRunner(ctx, result, bifrostErr) - } -} - // handleProviderRequest handles the request to the provider based on the request type // key is used for single-key operations, keys is used for batch/file operations that need multiple keys func (bifrost *Bifrost) handleProviderRequest(provider schemas.Provider, req *ChannelMessage, key schemas.Key, keys []schemas.Key) (*schemas.BifrostResponse, *schemas.BifrostError) { @@ -5315,7 +5267,7 @@ func (bifrost *Bifrost) handleProviderRequest(provider schemas.Provider, req *Ch } response.ListModelsResponse = listModelsResponse case schemas.TextCompletionRequest: - if shouldConvertTextToChat(req.Context, req.RequestType, req.BifrostRequest.TextCompletionRequest) { + if changeType, ok := req.Context.Value(schemas.BifrostContextKeyChangeRequestType).(schemas.RequestType); ok && changeType == schemas.ChatCompletionRequest { chatRequest := req.BifrostRequest.TextCompletionRequest.ToBifrostChatRequest() if chatRequest != nil { chatCompletionResponse, bifrostError := provider.ChatCompletion(req.Context, key, chatRequest) @@ -5332,7 +5284,7 @@ func (bifrost *Bifrost) handleProviderRequest(provider schemas.Provider, req *Ch } response.TextCompletionResponse = textCompletionResponse case schemas.ChatCompletionRequest: - if shouldConvertChatToResponses(req.Context, req.RequestType, req.BifrostRequest.ChatRequest) { + if changeType, ok := req.Context.Value(schemas.BifrostContextKeyChangeRequestType).(schemas.RequestType); ok && changeType == schemas.ResponsesRequest { responsesRequest := req.BifrostRequest.ChatRequest.ToResponsesRequest() if responsesRequest != 
nil { responsesResponse, bifrostError := provider.Responses(req.Context, key, responsesRequest) @@ -5592,18 +5544,18 @@ func (bifrost *Bifrost) handleProviderRequest(provider schemas.Provider, req *Ch func (bifrost *Bifrost) handleProviderStreamRequest(provider schemas.Provider, req *ChannelMessage, key schemas.Key, postHookRunner schemas.PostHookRunner) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { switch req.RequestType { case schemas.TextCompletionStreamRequest: - if shouldConvertTextToChat(req.Context, req.RequestType, req.BifrostRequest.TextCompletionRequest) { + if changeType, ok := req.Context.Value(schemas.BifrostContextKeyChangeRequestType).(schemas.RequestType); ok && changeType == schemas.ChatCompletionRequest { chatRequest := req.BifrostRequest.TextCompletionRequest.ToBifrostChatRequest() if chatRequest != nil { - return provider.ChatCompletionStream(req.Context, wrapTextToChatStreamPostHookRunner(postHookRunner), key, chatRequest) + return provider.ChatCompletionStream(req.Context, wrapConvertedStreamPostHookRunner(postHookRunner, schemas.ChatCompletionRequest), key, chatRequest) } } return provider.TextCompletionStream(req.Context, postHookRunner, key, req.BifrostRequest.TextCompletionRequest) case schemas.ChatCompletionStreamRequest: - if shouldConvertChatToResponses(req.Context, req.RequestType, req.BifrostRequest.ChatRequest) { + if changeType, ok := req.Context.Value(schemas.BifrostContextKeyChangeRequestType).(schemas.RequestType); ok && changeType == schemas.ResponsesRequest { responsesRequest := req.BifrostRequest.ChatRequest.ToResponsesRequest() if responsesRequest != nil { - return provider.ResponsesStream(req.Context, wrapChatToResponsesStreamPostHookRunner(postHookRunner), key, responsesRequest) + return provider.ResponsesStream(req.Context, wrapConvertedStreamPostHookRunner(postHookRunner, schemas.ResponsesRequest), key, responsesRequest) } } return provider.ChatCompletionStream(req.Context, postHookRunner, key, 
req.BifrostRequest.ChatRequest) @@ -6757,4 +6709,4 @@ func (bifrost *Bifrost) Shutdown() { } } bifrost.logger.Info("all request channels closed") -} \ No newline at end of file +} diff --git a/core/providers/anthropic/types.go b/core/providers/anthropic/types.go index 867dcf7e97..f3c45370cd 100644 --- a/core/providers/anthropic/types.go +++ b/core/providers/anthropic/types.go @@ -1316,4 +1316,4 @@ func parseAnthropicFileTimestamp(timestamp string) int64 { // AnthropicCountTokensResponse models the payload returned by Anthropic's count tokens endpoint. type AnthropicCountTokensResponse struct { InputTokens int `json:"input_tokens"` -} +} \ No newline at end of file diff --git a/core/providers/bedrock/images.go b/core/providers/bedrock/images.go index b0ac35dc01..dc3c76edd4 100644 --- a/core/providers/bedrock/images.go +++ b/core/providers/bedrock/images.go @@ -153,7 +153,6 @@ func ToBedrockImageGenerationRequest(request *schemas.BifrostImageGenerationRequ } return bedrockReq, nil - } // ToStabilityAIImageGenerationResponse converts a BifrostImageGenerationResponse back to diff --git a/core/providers/bedrock/models.go b/core/providers/bedrock/models.go index 6d2f9006f2..549db2e3bd 100644 --- a/core/providers/bedrock/models.go +++ b/core/providers/bedrock/models.go @@ -81,7 +81,6 @@ type BedrockRerankResponseDocument struct { TextDocument *BedrockRerankTextValue `json:"textDocument,omitempty"` } - func (response *BedrockListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels schemas.WhiteList, blacklistedModels schemas.BlackList, aliases map[string]string, unfiltered bool) *schemas.BifrostListModelsResponse { if response == nil { return nil @@ -128,4 +127,4 @@ func (response *BedrockListModelsResponse) ToBifrostListModelsResponse(providerK pipeline.BackfillModels(included)...) 
return bifrostResponse -} +} \ No newline at end of file diff --git a/core/providers/cohere/types.go b/core/providers/cohere/types.go index a4d78f9a48..8e5aa31402 100644 --- a/core/providers/cohere/types.go +++ b/core/providers/cohere/types.go @@ -9,8 +9,11 @@ import ( "github.com/maximhq/bifrost/core/schemas" ) -const MinimumReasoningMaxTokens = 1 -const DefaultCompletionMaxTokens = 4096 // Only used for relative reasoning max token calculation - not passed in body by default +const ( + MinimumReasoningMaxTokens = 1 + DefaultCompletionMaxTokens = 4096 // Only used for relative reasoning max token calculation - not passed in body by default +) + // Limits for tokenize input api call https://docs.cohere.com/reference/tokenize#request const ( cohereTokenizeMinTextLength = 1 diff --git a/core/providers/gemini/types.go b/core/providers/gemini/types.go index 75cf9f504f..44755ee761 100644 --- a/core/providers/gemini/types.go +++ b/core/providers/gemini/types.go @@ -17,11 +17,13 @@ import ( "github.com/maximhq/bifrost/core/schemas" ) -const MinReasoningMaxTokens = 1 // Minimum max tokens for reasoning - used for estimation of effort level -const DefaultCompletionMaxTokens = 8192 // Default max output tokens for Gemini - used for relative reasoning max token calculation -const DefaultReasoningMinBudget = 1024 // Default minimum reasoning budget for Gemini -const DynamicReasoningBudget = -1 // Special value for dynamic reasoning budget in Gemini -const skipThoughtSignatureValidator = "skip_thought_signature_validator" +const ( + MinReasoningMaxTokens = 1 // Minimum max tokens for reasoning - used for estimation of effort level + DefaultCompletionMaxTokens = 8192 // Default max output tokens for Gemini - used for relative reasoning max token calculation + DefaultReasoningMinBudget = 1024 // Default minimum reasoning budget for Gemini + DynamicReasoningBudget = -1 // Special value for dynamic reasoning budget in Gemini + skipThoughtSignatureValidator = 
"skip_thought_signature_validator" +) type thinkingBudgetRange struct { Min int @@ -509,8 +511,7 @@ type GoogleMaps struct { } // URLContext is a tool to support URL context retrieval. -type URLContext struct { -} +type URLContext struct{} // ToolComputerUse is a tool to support computer use. type ToolComputerUse struct { @@ -555,8 +556,7 @@ type ExternalAPIElasticSearchParams struct { } // ExternalAPISimpleSearchParams represents the search parameters to use for SIMPLE_SEARCH spec. -type ExternalAPISimpleSearchParams struct { -} +type ExternalAPISimpleSearchParams struct{} // ExternalAPI retrieves from data source powered by external API for grounding. The external API // is not owned by Google, but needs to follow the pre-defined API spec. @@ -714,8 +714,7 @@ type Retrieval struct { // ToolCodeExecution is a tool that executes code generated by the model, and automatically returns the result // to the model. See also [ExecutableCode]and [CodeExecutionResult] which are input // and output to this tool. -type ToolCodeExecution struct { -} +type ToolCodeExecution struct{} // Tool details of a tool that the model may use to generate a response. type Tool struct { diff --git a/core/providers/openai/types.go b/core/providers/openai/types.go index 89de4e1e66..39e25990d8 100644 --- a/core/providers/openai/types.go +++ b/core/providers/openai/types.go @@ -6,8 +6,8 @@ import ( "fmt" "github.com/bytedance/sonic" - "github.com/maximhq/bifrost/core/schemas" providerUtils "github.com/maximhq/bifrost/core/providers/utils" + "github.com/maximhq/bifrost/core/schemas" ) const MinMaxCompletionTokens = 16 @@ -82,7 +82,7 @@ type OpenAIChatRequest struct { // PromptCacheIsolationKey is the Fireworks chat-completions field for cache isolation. PromptCacheIsolationKey *string `json:"prompt_cache_isolation_key,omitempty"` - //NOTE: MaxCompletionTokens is a new replacement for max_tokens but some providers still use max_tokens. 
+ // NOTE: MaxCompletionTokens is a new replacement for max_tokens but some providers still use max_tokens. // This Field is populated only for such providers and is NOT to be used externally. MaxTokens *int `json:"max_tokens,omitempty"` diff --git a/core/providers/perplexity/types.go b/core/providers/perplexity/types.go index feef9e0ccb..d5ad5c65f6 100644 --- a/core/providers/perplexity/types.go +++ b/core/providers/perplexity/types.go @@ -4,45 +4,45 @@ import "github.com/maximhq/bifrost/core/schemas" // PerplexityChatRequest represents a Perplexity chat completion request type PerplexityChatRequest struct { - Model string `json:"model"` // Required: Model to use for chat completion - Messages []schemas.ChatMessage `json:"messages"` // Required: Array of message objects - SearchMode *string `json:"search_mode"` // Required: Search mode - ReasoningEffort *string `json:"reasoning_effort"` // Required: Reasoning effort (low, medium, high) - MaxTokens *int `json:"max_tokens,omitempty"` // Optional: Maximum tokens to generate - Temperature *float64 `json:"temperature,omitempty"` // Optional: Sampling temperature - TopP *float64 `json:"top_p,omitempty"` // Optional: Top-p sampling - LanguagePreference *string `json:"language_preference,omitempty"` // Optional: Language preference - SearchDomainFilter []string `json:"search_domain_filter,omitempty"` // Optional: Search domain filter - ReturnImages *bool `json:"return_images,omitempty"` // Optional: Return images - ReturnRelatedQuestions *bool `json:"return_related_questions,omitempty"` // Optional: Return related questions - SearchRecencyFilter *string `json:"search_recency_filter,omitempty"` // Optional: Search recency filter - SearchAfterDateFilter *string `json:"search_after_date_filter,omitempty"` // Optional: Search after date filter - SearchBeforeDateFilter *string `json:"search_before_date_filter,omitempty"` // Optional: Search before date filter - LastUpdatedAfterFilter *string 
`json:"last_updated_after_filter,omitempty"` // Optional: Last updated after filter - LastUpdatedBeforeFilter *string `json:"last_updated_before_filter,omitempty"` // Optional: Last updated before filter - TopK *int `json:"top_k,omitempty"` // Optional: Top-k sampling - Stream *bool `json:"stream,omitempty"` // Optional: Enable streaming - PresencePenalty *float64 `json:"presence_penalty,omitempty"` // Optional: Presence penalty - FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` // Optional: Frequency penalty - ResponseFormat *interface{} `json:"response_format,omitempty"` // Format for the response - DisableSearch *bool `json:"disable_search,omitempty"` // Optional: Disable search - EnableSearchClassifier *bool `json:"enable_search_classifier,omitempty"` // Optional: Enable search classifier - WebSearchOptions []WebSearchOption `json:"web_search_options,omitempty"` // Optional: Web search options - MediaResponse *MediaResponse `json:"media_response,omitempty"` // Optional: Media response - Tools []schemas.ChatTool `json:"tools,omitempty"` // Optional: Tools available for the model - ToolChoice *schemas.ChatToolChoice `json:"tool_choice,omitempty"` // Optional: Whether to call a tool - ParallelToolCalls *bool `json:"parallel_tool_calls,omitempty"` // Optional: Enable parallel tool calls - Stop []string `json:"stop,omitempty"` // Optional: Stop sequences - LogProbs *bool `json:"logprobs,omitempty"` // Optional: Return log probabilities - TopLogProbs *int `json:"top_logprobs,omitempty"` // Optional: Number of top log probabilities - NumSearchResults *int `json:"num_search_results,omitempty"` // Optional: Number of search results - NumImages *int `json:"num_images,omitempty"` // Optional: Number of images - SearchLanguageFilter []string `json:"search_language_filter,omitempty"` // Optional: Search language filter - ImageFormatFilter []string `json:"image_format_filter,omitempty"` // Optional: Image format filter - ImageDomainFilter []string 
`json:"image_domain_filter,omitempty"` // Optional: Image domain filter - SafeSearch *bool `json:"safe_search,omitempty"` // Optional: Enable safe search - StreamMode *string `json:"stream_mode,omitempty"` // Optional: Stream mode - ExtraParams map[string]interface{} `json:"-"` + Model string `json:"model"` // Required: Model to use for chat completion + Messages []schemas.ChatMessage `json:"messages"` // Required: Array of message objects + SearchMode *string `json:"search_mode"` // Required: Search mode + ReasoningEffort *string `json:"reasoning_effort"` // Required: Reasoning effort (low, medium, high) + MaxTokens *int `json:"max_tokens,omitempty"` // Optional: Maximum tokens to generate + Temperature *float64 `json:"temperature,omitempty"` // Optional: Sampling temperature + TopP *float64 `json:"top_p,omitempty"` // Optional: Top-p sampling + LanguagePreference *string `json:"language_preference,omitempty"` // Optional: Language preference + SearchDomainFilter []string `json:"search_domain_filter,omitempty"` // Optional: Search domain filter + ReturnImages *bool `json:"return_images,omitempty"` // Optional: Return images + ReturnRelatedQuestions *bool `json:"return_related_questions,omitempty"` // Optional: Return related questions + SearchRecencyFilter *string `json:"search_recency_filter,omitempty"` // Optional: Search recency filter + SearchAfterDateFilter *string `json:"search_after_date_filter,omitempty"` // Optional: Search after date filter + SearchBeforeDateFilter *string `json:"search_before_date_filter,omitempty"` // Optional: Search before date filter + LastUpdatedAfterFilter *string `json:"last_updated_after_filter,omitempty"` // Optional: Last updated after filter + LastUpdatedBeforeFilter *string `json:"last_updated_before_filter,omitempty"` // Optional: Last updated before filter + TopK *int `json:"top_k,omitempty"` // Optional: Top-k sampling + Stream *bool `json:"stream,omitempty"` // Optional: Enable streaming + PresencePenalty *float64 
`json:"presence_penalty,omitempty"` // Optional: Presence penalty + FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` // Optional: Frequency penalty + ResponseFormat *interface{} `json:"response_format,omitempty"` // Format for the response + DisableSearch *bool `json:"disable_search,omitempty"` // Optional: Disable search + EnableSearchClassifier *bool `json:"enable_search_classifier,omitempty"` // Optional: Enable search classifier + WebSearchOptions []WebSearchOption `json:"web_search_options,omitempty"` // Optional: Web search options + MediaResponse *MediaResponse `json:"media_response,omitempty"` // Optional: Media response + Tools []schemas.ChatTool `json:"tools,omitempty"` // Optional: Tools available for the model + ToolChoice *schemas.ChatToolChoice `json:"tool_choice,omitempty"` // Optional: Whether to call a tool + ParallelToolCalls *bool `json:"parallel_tool_calls,omitempty"` // Optional: Enable parallel tool calls + Stop []string `json:"stop,omitempty"` // Optional: Stop sequences + LogProbs *bool `json:"logprobs,omitempty"` // Optional: Return log probabilities + TopLogProbs *int `json:"top_logprobs,omitempty"` // Optional: Number of top log probabilities + NumSearchResults *int `json:"num_search_results,omitempty"` // Optional: Number of search results + NumImages *int `json:"num_images,omitempty"` // Optional: Number of images + SearchLanguageFilter []string `json:"search_language_filter,omitempty"` // Optional: Search language filter + ImageFormatFilter []string `json:"image_format_filter,omitempty"` // Optional: Image format filter + ImageDomainFilter []string `json:"image_domain_filter,omitempty"` // Optional: Image domain filter + SafeSearch *bool `json:"safe_search,omitempty"` // Optional: Enable safe search + StreamMode *string `json:"stream_mode,omitempty"` // Optional: Stream mode + ExtraParams map[string]interface{} `json:"-"` } // GetExtraParams implements the RequestBodyWithExtraParams interface diff --git 
a/core/providers/replicate/types.go b/core/providers/replicate/types.go index 98f84e613e..3ae88c0095 100644 --- a/core/providers/replicate/types.go +++ b/core/providers/replicate/types.go @@ -313,28 +313,28 @@ type ReplicatePredictionListResponse struct { // ReplicateModelResponse represents a model response type ReplicateModelResponse struct { - URL string `json:"url"` // Model API URL - Owner string `json:"owner"` // Owner username or org name - Name string `json:"name"` // Model name - Description *string `json:"description,omitempty"` // Model description - Visibility string `json:"visibility"` // "public" or "private" - GithubURL *string `json:"github_url,omitempty"` // GitHub repository URL - PaperURL *string `json:"paper_url,omitempty"` // Research paper URL - LicenseURL *string `json:"license_url,omitempty"` // License URL - RunCount *int `json:"run_count,omitempty"` // Number of times run - CoverImageURL *string `json:"cover_image_url,omitempty"` // Cover image URL - DefaultExample *json.RawMessage `json:"default_example,omitempty"` // Default example prediction (json.RawMessage preserves key ordering) - LatestVersion *ReplicateModelVersion `json:"latest_version,omitempty"` // Latest version details - FeaturedVersion *ReplicateModelVersion `json:"featured_version,omitempty"` // Featured version details + URL string `json:"url"` // Model API URL + Owner string `json:"owner"` // Owner username or org name + Name string `json:"name"` // Model name + Description *string `json:"description,omitempty"` // Model description + Visibility string `json:"visibility"` // "public" or "private" + GithubURL *string `json:"github_url,omitempty"` // GitHub repository URL + PaperURL *string `json:"paper_url,omitempty"` // Research paper URL + LicenseURL *string `json:"license_url,omitempty"` // License URL + RunCount *int `json:"run_count,omitempty"` // Number of times run + CoverImageURL *string `json:"cover_image_url,omitempty"` // Cover image URL + DefaultExample 
*json.RawMessage `json:"default_example,omitempty"` // Default example prediction (json.RawMessage preserves key ordering) + LatestVersion *ReplicateModelVersion `json:"latest_version,omitempty"` // Latest version details + FeaturedVersion *ReplicateModelVersion `json:"featured_version,omitempty"` // Featured version details } // ReplicateModelVersion represents a model version type ReplicateModelVersion struct { - ID string `json:"id"` // Version ID - CreatedAt string `json:"created_at"` // ISO 8601 timestamp - CogVersion *string `json:"cog_version,omitempty"` // Cog version used - OpenAPISchema json.RawMessage `json:"openapi_schema,omitempty"` // OpenAPI schema for the model (json.RawMessage preserves key ordering) - DockerImageID *string `json:"docker_image_id,omitempty"` // Docker image ID + ID string `json:"id"` // Version ID + CreatedAt string `json:"created_at"` // ISO 8601 timestamp + CogVersion *string `json:"cog_version,omitempty"` // Cog version used + OpenAPISchema json.RawMessage `json:"openapi_schema,omitempty"` // OpenAPI schema for the model (json.RawMessage preserves key ordering) + DockerImageID *string `json:"docker_image_id,omitempty"` // Docker image ID } // ReplicateModelListResponse represents a paginated list of models diff --git a/core/providers/utils/utils.go b/core/providers/utils/utils.go index 6cadc5a62c..189b4809d4 100644 --- a/core/providers/utils/utils.go +++ b/core/providers/utils/utils.go @@ -1680,7 +1680,7 @@ func SendCreatedEventResponsesChunk(ctx *schemas.BifrostContext, postHookRunner Latency: time.Since(startTime).Milliseconds(), }, } - //TODO add bifrost response pooling here + // TODO add bifrost response pooling here bifrostResponse := &schemas.BifrostResponse{ ResponsesStreamResponse: firstChunk, } @@ -1698,7 +1698,7 @@ func SendInProgressEventResponsesChunk(ctx *schemas.BifrostContext, postHookRunn Latency: time.Since(startTime).Milliseconds(), }, } - //TODO add bifrost response pooling here + // TODO add bifrost response 
pooling here bifrostResponse := &schemas.BifrostResponse{ ResponsesStreamResponse: chunk, } @@ -2051,14 +2051,13 @@ func ProcessAndSendError( logger schemas.Logger, ) { // Send scanner error through channel - bifrostError := - &schemas.BifrostError{ - IsBifrostError: true, - Error: &schemas.ErrorField{ - Message: fmt.Sprintf("Error reading stream: %v", err), - Error: err, - }, - } + bifrostError := &schemas.BifrostError{ + IsBifrostError: true, + Error: &schemas.ErrorField{ + Message: fmt.Sprintf("Error reading stream: %v", err), + Error: err, + }, + } processedResponse, processedError := postHookRunner(ctx, nil, bifrostError) if HandleStreamControlSkip(processedError) { @@ -2220,7 +2219,7 @@ func GetBifrostResponseForStreamResponse( transcriptionStreamResponse *schemas.BifrostTranscriptionStreamResponse, imageGenerationStreamResponse *schemas.BifrostImageGenerationStreamResponse, ) *schemas.BifrostResponse { - //TODO add bifrost response pooling here + // TODO add bifrost response pooling here bifrostResponse := &schemas.BifrostResponse{} switch { diff --git a/core/providers/vertex/models.go b/core/providers/vertex/models.go index 48837563eb..2fbe83979d 100644 --- a/core/providers/vertex/models.go +++ b/core/providers/vertex/models.go @@ -193,4 +193,4 @@ func (response *VertexListPublisherModelsResponse) ToBifrostListModelsResponse(a bifrostResponse.NextPageToken = response.NextPageToken return bifrostResponse -} +} \ No newline at end of file diff --git a/core/providers/vertex/types.go b/core/providers/vertex/types.go index 97d6de7fa2..bbdb89d17f 100644 --- a/core/providers/vertex/types.go +++ b/core/providers/vertex/types.go @@ -192,23 +192,23 @@ type VertexModelLabels struct { // These types are for the publishers.models.list endpoint (Model Garden) type VertexPublisherModel struct { - Name string `json:"name"` - VersionID string `json:"versionId"` - OpenSourceCategory string `json:"openSourceCategory"` - LaunchStage string `json:"launchStage"` - VersionState 
string `json:"versionState"` - PublisherModelTemplate string `json:"publisherModelTemplate"` - SupportedActions *VertexPublisherModelActions `json:"supportedActions"` + Name string `json:"name"` + VersionID string `json:"versionId"` + OpenSourceCategory string `json:"openSourceCategory"` + LaunchStage string `json:"launchStage"` + VersionState string `json:"versionState"` + PublisherModelTemplate string `json:"publisherModelTemplate"` + SupportedActions *VertexPublisherModelActions `json:"supportedActions"` } type VertexPublisherModelActions struct { - OpenGenerationAIStudio *VertexPublisherModelURI `json:"openGenerationAiStudio"` - OpenGenie *VertexPublisherModelURI `json:"openGenie"` - OpenPromptTuningPipeline *VertexPublisherModelURI `json:"openPromptTuningPipeline"` - OpenNotebook *VertexPublisherModelURI `json:"openNotebook"` - OpenFineTuningPipeline *VertexPublisherModelURI `json:"openFineTuningPipeline"` - Deploy *VertexPublisherModelDeploy `json:"deploy"` - OpenEvaluationPipeline *VertexPublisherModelURI `json:"openEvaluationPipeline"` + OpenGenerationAIStudio *VertexPublisherModelURI `json:"openGenerationAiStudio"` + OpenGenie *VertexPublisherModelURI `json:"openGenie"` + OpenPromptTuningPipeline *VertexPublisherModelURI `json:"openPromptTuningPipeline"` + OpenNotebook *VertexPublisherModelURI `json:"openNotebook"` + OpenFineTuningPipeline *VertexPublisherModelURI `json:"openFineTuningPipeline"` + Deploy *VertexPublisherModelDeploy `json:"deploy"` + OpenEvaluationPipeline *VertexPublisherModelURI `json:"openEvaluationPipeline"` } type VertexPublisherModelURI struct { diff --git a/core/providers/vertex/vertex.go b/core/providers/vertex/vertex.go index 626cf00207..91acb476af 100644 --- a/core/providers/vertex/vertex.go +++ b/core/providers/vertex/vertex.go @@ -3153,4 +3153,4 @@ func (provider *VertexProvider) PassthroughStream( } }() return ch, nil -} +} \ No newline at end of file diff --git a/core/schemas/bifrost.go b/core/schemas/bifrost.go index 
310871ab96..7d7c28ea47 100644 --- a/core/schemas/bifrost.go +++ b/core/schemas/bifrost.go @@ -196,8 +196,7 @@ const ( BifrostContextKeyExtraHeaders BifrostContextKey = "bifrost-extra-headers" // map[string][]string BifrostContextKeyURLPath BifrostContextKey = "bifrost-extra-url-path" // string BifrostContextKeyUseRawRequestBody BifrostContextKey = "bifrost-use-raw-request-body" - BifrostContextKeyShouldConvertTextToChat BifrostContextKey = "bifrost-should-convert-text-to-chat" // bool (set by plugins to trigger text->chat provider conversion in core) - BifrostContextKeyShouldConvertChatToResponses BifrostContextKey = "bifrost-should-convert-chat-to-responses" // bool (set by plugins to trigger chat->responses provider conversion in core) + BifrostContextKeyChangeRequestType BifrostContextKey = "bifrost-change-request-type" // RequestType (set by plugins to trigger request type conversion in core, e.g. text->chat or chat->responses) BifrostContextKeySendBackRawRequest BifrostContextKey = "bifrost-send-back-raw-request" // bool BifrostContextKeySendBackRawResponse BifrostContextKey = "bifrost-send-back-raw-response" // bool BifrostContextKeyIntegrationType BifrostContextKey = "bifrost-integration-type" // integration used in gateway (e.g. openai, anthropic, bedrock, etc.) @@ -1034,18 +1033,19 @@ type BifrostMCPResponse struct { // BifrostResponseExtraFields contains additional fields in a response. 
type BifrostResponseExtraFields struct { - RequestType RequestType `json:"request_type"` - Provider ModelProvider `json:"provider,omitempty"` - OriginalModelRequested string `json:"original_model_requested,omitempty"` // the model alias the caller sent in the request - ResolvedModelUsed string `json:"resolved_model_used,omitempty"` // the actual provider API identifier used (equals OriginalModelRequested when no alias mapping exists) - Latency int64 `json:"latency"` // in milliseconds (for streaming responses this will be each chunk latency, and the last chunk latency will be the total latency) - ChunkIndex int `json:"chunk_index"` // used for streaming responses to identify the chunk index, will be 0 for non-streaming responses - RawRequest interface{} `json:"raw_request,omitempty"` - RawResponse interface{} `json:"raw_response,omitempty"` - CacheDebug *BifrostCacheDebug `json:"cache_debug,omitempty"` - ParseErrors []BatchError `json:"parse_errors,omitempty"` // errors encountered while parsing JSONL batch results - LiteLLMCompat bool `json:"litellm_compat,omitempty"` - ProviderResponseHeaders map[string]string `json:"provider_response_headers,omitempty"` // HTTP response headers from the provider (filtered to exclude transport-level headers) + RequestType RequestType `json:"request_type"` + Provider ModelProvider `json:"provider,omitempty"` + OriginalModelRequested string `json:"original_model_requested,omitempty"` // the model alias the caller sent in the request + ResolvedModelUsed string `json:"resolved_model_used,omitempty"` // the actual provider API identifier used (equals OriginalModelRequested when no alias mapping exists) + Latency int64 `json:"latency"` // in milliseconds (for streaming responses this will be each chunk latency, and the last chunk latency will be the total latency) + ChunkIndex int `json:"chunk_index"` // used for streaming responses to identify the chunk index, will be 0 for non-streaming responses + RawRequest interface{} 
`json:"raw_request,omitempty"` + RawResponse interface{} `json:"raw_response,omitempty"` + CacheDebug *BifrostCacheDebug `json:"cache_debug,omitempty"` + ParseErrors []BatchError `json:"parse_errors,omitempty"` // errors encountered while parsing JSONL batch results + ConvertedRequestType RequestType `json:"converted_request_type,omitempty"` + DroppedCompatPluginParams []string `json:"dropped_compat_plugin_params,omitempty"` // params dropped by the compat plugin based on model catalog + ProviderResponseHeaders map[string]string `json:"provider_response_headers,omitempty"` // HTTP response headers from the provider (filtered to exclude transport-level headers) } type BifrostMCPResponseExtraFields struct { @@ -1220,13 +1220,14 @@ func (e *ErrorField) UnmarshalJSON(data []byte) error { // BifrostErrorExtraFields contains additional fields in an error response. type BifrostErrorExtraFields struct { - Provider ModelProvider `json:"provider,omitempty"` - OriginalModelRequested string `json:"original_model_requested,omitempty"` - ResolvedModelUsed string `json:"resolved_model_used,omitempty"` - RequestType RequestType `json:"request_type,omitempty"` - RawRequest any `json:"raw_request,omitempty"` - RawResponse any `json:"raw_response,omitempty"` - LiteLLMCompat bool `json:"litellm_compat,omitempty"` - KeyStatuses []KeyStatus `json:"key_statuses,omitempty"` - MCPAuthRequired *MCPUserOAuthRequiredError `json:"mcp_auth_required,omitempty"` // Set when a per-user OAuth MCP tool requires authentication + Provider ModelProvider `json:"provider,omitempty"` + OriginalModelRequested string `json:"original_model_requested,omitempty"` + ResolvedModelUsed string `json:"resolved_model_used,omitempty"` + RequestType RequestType `json:"request_type,omitempty"` + RawRequest interface{} `json:"raw_request,omitempty"` + RawResponse interface{} `json:"raw_response,omitempty"` + ConvertedRequestType RequestType `json:"converted_request_type,omitempty"` + DroppedCompatPluginParams []string 
`json:"dropped_compat_plugin_params,omitempty"` + KeyStatuses []KeyStatus `json:"key_statuses,omitempty"` + MCPAuthRequired *MCPUserOAuthRequiredError `json:"mcp_auth_required,omitempty"` // Set when a per-user OAuth MCP tool requires authentication } diff --git a/core/schemas/mux.go b/core/schemas/mux.go index e719311539..24943d3fbd 100644 --- a/core/schemas/mux.go +++ b/core/schemas/mux.go @@ -1995,34 +1995,6 @@ func (cr *BifrostChatResponse) ToBifrostResponsesStreamResponse(state *ChatToRes response.Output = allOutput } - // Append finalized function call items so the terminal response carries them in Output. - for toolCallID, args := range state.ToolArgumentBuffers { - if args == "" { - continue - } - statusFinal := terminalStatus - messageType := ResponsesMessageTypeFunctionCall - callName := state.ToolCallNames[toolCallID] - var callNamePtr *string - if callName != "" { - callNamePtr = &callName - } - argsValue := args - fcMsg := ResponsesMessage{ - Type: &messageType, - Status: &statusFinal, - ResponsesToolMessage: &ResponsesToolMessage{ - CallID: &toolCallID, - Name: callNamePtr, - Arguments: &argsValue, - }, - } - if itemID := state.ItemIDs[toolCallID]; itemID != "" { - fcMsg.ID = &itemID - } - response.Output = append(response.Output, fcMsg) - } - responses = append(responses, &BifrostResponsesStreamResponse{ Type: terminalEventType, SequenceNumber: state.SequenceNumber, @@ -2047,7 +2019,6 @@ func (cr *BifrostChatResponse) ToBifrostResponsesStreamResponse(state *ChatToRes } // ToBifrostChatResponse converts a BifrostResponsesStreamResponse chunk to a BifrostChatResponse (chat.completion.chunk). -// Returns nil for events that have no meaningful chat completion equivalent (lifecycle events, etc.). 
func (rsr *BifrostResponsesStreamResponse) ToBifrostChatResponse() *BifrostChatResponse { if rsr == nil { return nil @@ -2247,11 +2218,13 @@ func (rsr *BifrostResponsesStreamResponse) ToBifrostChatResponse() *BifrostChatR resp.Usage = rsr.Response.Usage.ToBifrostLLMUsage() } // Check for tool_calls finish reason - for _, output := range rsr.Response.Output { - if output.Type != nil && *output.Type == ResponsesMessageTypeFunctionCall { - finishReason = string(BifrostFinishReasonToolCalls) - resp.Choices[0].FinishReason = &finishReason - break + if rsr.Type == ResponsesStreamResponseTypeCompleted { + for _, output := range rsr.Response.Output { + if output.Type != nil && *output.Type == ResponsesMessageTypeFunctionCall { + finishReason = string(BifrostFinishReasonToolCalls) + resp.Choices[0].FinishReason = &finishReason + break + } } } } @@ -2293,7 +2266,7 @@ func (cr *BifrostChatResponse) ToBifrostTextCompletionResponse() *BifrostTextCom RequestType: TextCompletionRequest, ChunkIndex: cr.ExtraFields.ChunkIndex, Provider: cr.ExtraFields.Provider, - ModelRequested: cr.ExtraFields.ModelRequested, + OriginalModelRequested: cr.ExtraFields.OriginalModelRequested, Latency: cr.ExtraFields.Latency, RawResponse: cr.ExtraFields.RawResponse, CacheDebug: cr.ExtraFields.CacheDebug, @@ -2326,7 +2299,7 @@ func (cr *BifrostChatResponse) ToBifrostTextCompletionResponse() *BifrostTextCom RequestType: TextCompletionRequest, ChunkIndex: cr.ExtraFields.ChunkIndex, Provider: cr.ExtraFields.Provider, - ModelRequested: cr.ExtraFields.ModelRequested, + OriginalModelRequested: cr.ExtraFields.OriginalModelRequested, Latency: cr.ExtraFields.Latency, RawResponse: cr.ExtraFields.RawResponse, CacheDebug: cr.ExtraFields.CacheDebug, @@ -2339,8 +2312,21 @@ func (cr *BifrostChatResponse) ToBifrostTextCompletionResponse() *BifrostTextCom if choice.ChatNonStreamResponseChoice != nil { msg := choice.ChatNonStreamResponseChoice.Message var textContent *string - if msg != nil && msg.Content != nil && 
msg.Content.ContentStr != nil { - textContent = msg.Content.ContentStr + if msg != nil && msg.Content != nil { + if msg.Content.ContentStr != nil { + textContent = msg.Content.ContentStr + } else if len(msg.Content.ContentBlocks) > 0 { + var sb strings.Builder + for _, block := range msg.Content.ContentBlocks { + if block.Text != nil { + sb.WriteString(*block.Text) + } + } + if sb.Len() > 0 { + s := sb.String() + textContent = &s + } + } } return &BifrostTextCompletionResponse{ ID: cr.ID, @@ -2362,7 +2348,7 @@ func (cr *BifrostChatResponse) ToBifrostTextCompletionResponse() *BifrostTextCom RequestType: TextCompletionRequest, ChunkIndex: cr.ExtraFields.ChunkIndex, Provider: cr.ExtraFields.Provider, - ModelRequested: cr.ExtraFields.ModelRequested, + OriginalModelRequested: cr.ExtraFields.OriginalModelRequested, Latency: cr.ExtraFields.Latency, RawResponse: cr.ExtraFields.RawResponse, CacheDebug: cr.ExtraFields.CacheDebug, @@ -2379,13 +2365,14 @@ func (cr *BifrostChatResponse) ToBifrostTextCompletionResponse() *BifrostTextCom SystemFingerprint: cr.SystemFingerprint, Usage: cr.Usage, ExtraFields: BifrostResponseExtraFields{ - RequestType: TextCompletionRequest, - ChunkIndex: cr.ExtraFields.ChunkIndex, - Provider: cr.ExtraFields.Provider, - ModelRequested: cr.ExtraFields.ModelRequested, - Latency: cr.ExtraFields.Latency, - RawResponse: cr.ExtraFields.RawResponse, - CacheDebug: cr.ExtraFields.CacheDebug, + RequestType: TextCompletionRequest, + ChunkIndex: cr.ExtraFields.ChunkIndex, + Provider: cr.ExtraFields.Provider, + OriginalModelRequested: cr.ExtraFields.OriginalModelRequested, + Latency: cr.ExtraFields.Latency, + RawResponse: cr.ExtraFields.RawResponse, + CacheDebug: cr.ExtraFields.CacheDebug, + ProviderResponseHeaders: cr.ExtraFields.ProviderResponseHeaders, }, } } diff --git a/core/utils.go b/core/utils.go index 75b491a9c1..f013687489 100644 --- a/core/utils.go +++ b/core/utils.go @@ -272,8 +272,7 @@ func clearCtxForFallback(ctx *schemas.BifrostContext) { 
ctx.ClearValue(schemas.BifrostContextKeyAPIKeyID) ctx.ClearValue(schemas.BifrostContextKeyAPIKeyName) ctx.ClearValue(schemas.BifrostContextKeyGovernanceIncludeOnlyKeys) - ctx.ClearValue(schemas.BifrostContextKeyShouldConvertTextToChat) - ctx.ClearValue(schemas.BifrostContextKeyShouldConvertChatToResponses) + ctx.ClearValue(schemas.BifrostContextKeyChangeRequestType) } var supportedBaseProvidersSet = func() map[schemas.ModelProvider]struct{} { @@ -581,7 +580,7 @@ func buildSessionKey(providerKey schemas.ModelProvider, sessionID string, model if discriminator == "" { discriminator = "__modelless__" } - return "session:" + string(provierKey) + ":" + hashedSessionID + ":" + hashSHA256(discriminator) + return "session:" + string(providerKey) + ":" + hashedSessionID + ":" + hashSHA256(discriminator) } // isPromptOptionalImageEditType returns true for edit task types that do not require a text prompt. @@ -597,3 +596,30 @@ func isPromptOptionalImageEditType(t *string) bool { normalized, ) } + +// wrapConvertedStreamPostHookRunner wraps a PostHookRunner so that streaming +// responses produced by a type-converted request are converted back to the +// caller's original type before the post-hook runs. 
+func wrapConvertedStreamPostHookRunner(postHookRunner schemas.PostHookRunner, targetType schemas.RequestType) schemas.PostHookRunner { + return func(ctx *schemas.BifrostContext, result *schemas.BifrostResponse, bifrostErr *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError) { + if result != nil { + switch targetType { + case schemas.ChatCompletionRequest: + // text→chat: convert chat stream chunk back to text completion + if result.ChatResponse != nil { + if converted := result.ChatResponse.ToBifrostTextCompletionResponse(); converted != nil { + result = &schemas.BifrostResponse{TextCompletionResponse: converted} + } + } + case schemas.ResponsesRequest: + // chat→responses: convert responses stream chunk back to chat + if result.ResponsesStreamResponse != nil { + if converted := result.ResponsesStreamResponse.ToBifrostChatResponse(); converted != nil { + result = &schemas.BifrostResponse{ChatResponse: converted} + } + } + } + } + return postHookRunner(ctx, result, bifrostErr) + } +} \ No newline at end of file diff --git a/framework/go.mod b/framework/go.mod index b4a2a0d0c3..e2c1149c95 100644 --- a/framework/go.mod +++ b/framework/go.mod @@ -9,6 +9,7 @@ require ( github.com/qdrant/go-client v1.16.2 github.com/redis/go-redis/v9 v9.17.2 github.com/stretchr/testify v1.11.1 + github.com/tidwall/gjson v1.18.0 github.com/weaviate/weaviate v1.36.5 github.com/weaviate/weaviate-go-client/v5 v5.7.1 golang.org/x/crypto v0.49.0 @@ -54,7 +55,6 @@ require ( github.com/kylelemons/godebug v1.1.0 // indirect github.com/oapi-codegen/runtime v1.1.1 // indirect github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c // indirect - github.com/tidwall/gjson v1.18.0 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.0 // indirect github.com/tidwall/sjson v1.2.5 // indirect diff --git a/framework/modelcatalog/main.go b/framework/modelcatalog/main.go index 8aea69bb7b..2122f339d5 100644 --- a/framework/modelcatalog/main.go +++ 
b/framework/modelcatalog/main.go @@ -47,9 +47,13 @@ type ModelCatalog struct { unfilteredModelPool map[schemas.ModelProvider][]string // model pool without allowed models filtering baseModelIndex map[string]string // model string → canonical base model name - // Pre-parsed supported outputs index (keyed by model name, populated from model parameters supported_endpoints) - // Values are normalized output types: "chat_completion", "responses", "text_completion" - supportedOutputs map[string][]string + // Pre-parsed supported response types index (keyed by model name) + // Values are normalized response types: "chat_completion", "responses", "text_completion" + supportedResponseTypes map[string][]string + + // Pre-parsed supported parameters index (keyed by model name, populated from model parameters supported_parameters) + // Values are parameter names the model accepts (e.g., "temperature", "top_p", "tools") + supportedParams map[string][]string // Background sync worker syncTicker *time.Ticker @@ -80,7 +84,8 @@ func Init(ctx context.Context, config *Config, configStore configstore.ConfigSto modelPool: make(map[schemas.ModelProvider][]string), unfilteredModelPool: make(map[schemas.ModelProvider][]string), baseModelIndex: make(map[string]string), - supportedOutputs: make(map[string][]string), + supportedResponseTypes: make(map[string][]string), + supportedParams: make(map[string][]string), done: make(chan struct{}), distributedLockManager: configstore.NewDistributedLockManager(configStore, logger, configstore.WithDefaultTTL(30*time.Second)), } @@ -955,65 +960,28 @@ func (mc *ModelCatalog) DeletePricingOverride(id string) { mc.customPricing = buildCustomPricingData(updated) } -// IsTextCompletionSupported checks if a model supports text completion for the given provider. -// Returns true if the model has pricing data for text completion ("text_completion"), -// false otherwise. 
This is used by the litellmcompat plugin to determine whether to -// convert text completion requests to chat completion requests. -func (mc *ModelCatalog) IsTextCompletionSupported(model string, provider schemas.ModelProvider) bool { +// IsRequestTypeSupported checks if a model supports chat completion. +// It checks the supportedResponseTypes index. +func (mc *ModelCatalog) IsRequestTypeSupported(model string, provider schemas.ModelProvider, requestType schemas.RequestType) bool { mc.mu.RLock() - defer mc.mu.RUnlock() - // Check for text completion mode in pricing data - key := makeKey(model, normalizeProvider(string(provider)), normalizeRequestType(schemas.TextCompletionRequest)) - _, ok := mc.pricingData[key] - return ok -} - -// IsChatCompletionSupported checks if a model supports chat completion. -// It checks the supportedOutputs index (derived from supported_endpoints in the datasheet). -func (mc *ModelCatalog) IsChatCompletionSupported(model string, provider schemas.ModelProvider) bool { - mc.mu.RLock() - outputs, ok := mc.supportedOutputs[model] + outputs, ok := mc.supportedResponseTypes[model] mc.mu.RUnlock() - return ok && slices.Contains(outputs, "chat_completion") + return ok && slices.Contains(outputs, string(requestType)) } -// IsResponsesSupported checks if a model supports the responses endpoint. -// It checks the supportedOutputs index (derived from supported_endpoints in the datasheet). -func (mc *ModelCatalog) IsResponsesSupported(model string, provider schemas.ModelProvider) bool { +// GetSupportedParameters returns the list of supported parameter names for a model. +// Returns nil if the model is not found in the catalog. 
+func (mc *ModelCatalog) GetSupportedParameters(model string) []string { mc.mu.RLock() - outputs, ok := mc.supportedOutputs[model] + params, ok := mc.supportedParams[model] mc.mu.RUnlock() - return ok && slices.Contains(outputs, "responses") -} - -// buildSupportedOutputsIndex parses supported_endpoints from model parameters data -// and rebuilds the supportedOutputs index with normalized output type names. -func (mc *ModelCatalog) buildSupportedOutputsIndex(paramsData map[string]json.RawMessage) { - newIndex := make(map[string][]string, len(paramsData)) - - for model, data := range paramsData { - var params struct { - SupportedEndpoints []string `json:"supported_endpoints"` - } - if err := json.Unmarshal(data, ¶ms); err != nil || len(params.SupportedEndpoints) == 0 { - continue - } - outputs := make([]string, 0, len(params.SupportedEndpoints)) - for _, endpoint := range params.SupportedEndpoints { - if normalized := normalizeEndpointToOutputType(endpoint); normalized != "" { - if !slices.Contains(outputs, normalized) { - outputs = append(outputs, normalized) - } - } - } - if len(outputs) > 0 { - newIndex[model] = outputs - } + if !ok { + return nil } - - mc.mu.Lock() - mc.supportedOutputs = newIndex - mc.mu.Unlock() + // Return a copy to prevent external modification + result := make([]string, len(params)) + copy(result, params) + return result } // populateModelPool populates the model pool with all available models per provider (thread-safe) @@ -1093,11 +1061,12 @@ func NewTestCatalog(baseModelIndex map[string]string) *ModelCatalog { baseModelIndex = make(map[string]string) } return &ModelCatalog{ - modelPool: make(map[schemas.ModelProvider][]string), - unfilteredModelPool: make(map[schemas.ModelProvider][]string), - baseModelIndex: baseModelIndex, - pricingData: make(map[string]configstoreTables.TableModelPricing), - supportedOutputs: make(map[string][]string), - done: make(chan struct{}), - } -} \ No newline at end of file + modelPool: 
make(map[schemas.ModelProvider][]string), + unfilteredModelPool: make(map[schemas.ModelProvider][]string), + baseModelIndex: baseModelIndex, + pricingData: make(map[string]configstoreTables.TableModelPricing), + supportedResponseTypes: make(map[string][]string), + supportedParams: make(map[string][]string), + done: make(chan struct{}), + } +} diff --git a/framework/modelcatalog/sync.go b/framework/modelcatalog/sync.go index 589dac05c9..4f0186ba40 100644 --- a/framework/modelcatalog/sync.go +++ b/framework/modelcatalog/sync.go @@ -6,11 +6,14 @@ import ( "fmt" "io" "net/http" + "slices" "sync" "time" providerUtils "github.com/maximhq/bifrost/core/providers/utils" + "github.com/maximhq/bifrost/core/schemas" configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables" + "github.com/tidwall/gjson" "gorm.io/gorm" ) @@ -213,7 +216,7 @@ func (mc *ModelCatalog) loadModelParametersFromDatabase(ctx context.Context) (in for _, row := range rows { paramsData[row.Model] = json.RawMessage(row.Data) } - applyModelParametersToProviderCache(paramsData) + mc.applyModelParameters(paramsData) mc.logger.Debug("loaded %d model parameters records from database into cache", len(rows)) return len(rows), nil } @@ -295,16 +298,71 @@ func (mc *ModelCatalog) syncWorker(ctx context.Context) { // --- Model Parameters sync --- -func applyModelParametersToProviderCache(paramsData map[string]json.RawMessage) { +func (mc *ModelCatalog) applyModelParameters(paramsData map[string]json.RawMessage) { modelParamsEntries := make(map[string]providerUtils.ModelParams, len(paramsData)) + newResponseTypes := make(map[string][]string, len(paramsData)) + newParamsIndex := make(map[string][]string, len(paramsData)) + for model, rawData := range paramsData { + var parsed modelParametersParseResult + if err := json.Unmarshal(rawData, &parsed); err != nil { + mc.logger.Warn("model-parameters-sync: skipping malformed parameters for model %s: %v", model, err) + continue + } + + outputs := make([]string, 
0, len(parsed.SupportedEndpoints)) + for _, endpoint := range parsed.SupportedEndpoints { + if normalized := normalizeEndpointToOutputType(endpoint); normalized != "" && !slices.Contains(outputs, normalized) { + outputs = append(outputs, normalized) + } + } + + if parsed.Mode != nil { + if normalized := normalizeModeToOutputType(*parsed.Mode); normalized != "" && !slices.Contains(outputs, normalized) { + outputs = append(outputs, normalized) + } + } + + if !slices.Contains(outputs, "text_completion") { + provider := gjson.GetBytes(rawData, "provider") + if provider.Exists() { + key := makeKey(model, normalizeProvider(provider.String()), normalizeRequestType(schemas.TextCompletionRequest)) + + mc.mu.RLock() + _, ok := mc.pricingData[key] + mc.mu.RUnlock() + if ok { + outputs = append(outputs, "text_completion") + } + } + } + + if len(outputs) > 0 { + newResponseTypes[model] = outputs + } + + supported := extractSupportedParams(&parsed) + if len(supported) > 0 { + newParamsIndex[model] = supported + } + var p struct { MaxOutputTokens *int `json:"max_output_tokens"` } - if err := json.Unmarshal(rawData, &p); err == nil && p.MaxOutputTokens != nil { + if p.MaxOutputTokens == nil { + if err := json.Unmarshal(rawData, &p); err == nil && p.MaxOutputTokens != nil { + modelParamsEntries[model] = providerUtils.ModelParams{MaxOutputTokens: p.MaxOutputTokens} + } + } else { modelParamsEntries[model] = providerUtils.ModelParams{MaxOutputTokens: p.MaxOutputTokens} } } + + mc.mu.Lock() + mc.supportedResponseTypes = newResponseTypes + mc.supportedParams = newParamsIndex + mc.mu.Unlock() + if len(modelParamsEntries) > 0 { providerUtils.BulkSetModelParams(modelParamsEntries) } @@ -319,7 +377,7 @@ func (mc *ModelCatalog) loadModelParametersIntoMemoryFromURL(ctx context.Context if err != nil { return fmt.Errorf("failed to load model parameters from URL: %w", err) } - applyModelParametersToProviderCache(paramsData) + mc.applyModelParameters(paramsData) return nil } @@ -366,7 +424,7 @@ 
// modeOutputTypes maps a datasheet "mode" value to its normalized response
// type name; modes not listed here have no normalized equivalent.
var modeOutputTypes = map[string]string{
	"chat":       "chat_completion",
	"completion": "text_completion",
	"responses":  "responses",
}

// normalizeModeToOutputType converts a datasheet mode to a normalized output
// type, returning "" for unrecognized modes.
func normalizeModeToOutputType(mode string) string {
	return modeOutputTypes[mode]
}
// modelParametersParseResult holds the subset of a model-parameters datasheet
// entry needed to derive supported response types and request parameters.
type modelParametersParseResult struct {
	Mode               *string  `json:"mode,omitempty"`
	SupportedEndpoints []string `json:"supported_endpoints,omitempty"`
	ModelParameters    []struct {
		ID string `json:"id"`
	} `json:"model_parameters,omitempty"`
	SupportsFunctionCalling         *bool `json:"supports_function_calling,omitempty"`
	SupportsParallelFunctionCalling *bool `json:"supports_parallel_function_calling,omitempty"`
	SupportsToolChoice              *bool `json:"supports_tool_choice,omitempty"`
	SupportsReasoning               *bool `json:"supports_reasoning,omitempty"`
	SupportsServiceTier             *bool `json:"supports_service_tier,omitempty"`
	SupportsPromptCaching           *bool `json:"supports_prompt_caching,omitempty"`
}

// extractSupportedParams derives the OpenAI-compatible request parameter
// names a model accepts from model_parameters[].id entries and supports_*
// boolean flags. First-appearance order is preserved; duplicates are dropped.
func extractSupportedParams(parsed *modelParametersParseResult) []string {
	var names []string
	add := func(name string) {
		if !slices.Contains(names, name) {
			names = append(names, name)
		}
	}

	// Datasheet parameter IDs map onto top-level request parameter names.
	for _, mp := range parsed.ModelParameters {
		switch mp.ID {
		case "reasoning_effort", "reasoning_summary":
			// Both surface through the single "reasoning" request parameter.
			add("reasoning")
		case "web_search":
			add("web_search_options")
		case "promptTools", "image_detail", "stream":
			// Not top-level request parameters — skip.
		default:
			add(mp.ID)
		}
	}

	// Capability flags each imply one or more request parameters.
	capabilityParams := []struct {
		flag   *bool
		params []string
	}{
		{parsed.SupportsFunctionCalling, []string{"tools"}},
		{parsed.SupportsParallelFunctionCalling, []string{"parallel_tool_calls"}},
		{parsed.SupportsToolChoice, []string{"tool_choice"}},
		{parsed.SupportsReasoning, []string{"reasoning"}},
		{parsed.SupportsServiceTier, []string{"service_tier"}},
		{parsed.SupportsPromptCaching, []string{"prompt_cache_key", "prompt_cache_retention"}},
	}
	for _, cp := range capabilityParams {
		if cp.flag != nil && *cp.flag {
			for _, name := range cp.params {
				add(name)
			}
		}
	}

	return names
}
// applyParameterConversion rewrites request fields in place for provider
// compatibility. It is the single entry point for the compat plugin's
// in-place request normalizations.
func applyParameterConversion(req *schemas.BifrostRequest) {
	if req == nil {
		return
	}
	normalizeDeveloperRoleForChatRequest(req)
}

// normalizeDeveloperRoleForChatRequest downgrades the OpenAI "developer"
// message role to "system" for providers that do not recognize it (Bedrock,
// Vertex, Gemini); other providers receive messages unchanged.
func normalizeDeveloperRoleForChatRequest(req *schemas.BifrostRequest) {
	if req.ChatRequest == nil {
		return
	}
	if req.ChatRequest.Provider != schemas.Bedrock && req.ChatRequest.Provider != schemas.Vertex && req.ChatRequest.Provider != schemas.Gemini {
		return
	}
	for i := range req.ChatRequest.Input {
		if req.ChatRequest.Input[i].Role == schemas.ChatMessageRoleDeveloper {
			req.ChatRequest.Input[i].Role = schemas.ChatMessageRoleSystem
		}
	}
}

// dropUnsupportedParams removes unsupported model parameters from a request
// in place and returns the names of the parameters it dropped, in the order
// they were checked. supportedParams is the allow-list of parameter names
// (as produced by the model catalog); any set parameter not on the list is
// cleared. The caller is expected to pass a cloned request, since mutation
// happens in place.
//
// NOTE(review): the allow-list keys come from datasheet parameter IDs —
// confirm the datasheet uses "max_completion_tokens" (chat) vs "max_tokens"
// (text completion) consistently with the names checked below, otherwise a
// supported parameter could be dropped on a name mismatch.
func dropUnsupportedParams(req *schemas.BifrostRequest, supportedParams []string) []string {
	if req == nil {
		return nil
	}

	// Build a set for O(1) membership tests.
	isSupported := make(map[string]bool, len(supportedParams))
	for _, param := range supportedParams {
		isSupported[param] = true
	}

	var dropped []string

	// --- Chat completion parameters ---
	if req.ChatRequest != nil && req.ChatRequest.Params != nil {
		params := req.ChatRequest.Params

		if params.Audio != nil && !isSupported["audio"] {
			params.Audio = nil
			dropped = append(dropped, "audio")
		}
		if params.FrequencyPenalty != nil && !isSupported["frequency_penalty"] {
			params.FrequencyPenalty = nil
			dropped = append(dropped, "frequency_penalty")
		}
		if params.LogitBias != nil && !isSupported["logit_bias"] {
			params.LogitBias = nil
			dropped = append(dropped, "logit_bias")
		}
		if params.LogProbs != nil && !isSupported["logprobs"] {
			params.LogProbs = nil
			dropped = append(dropped, "logprobs")
		}
		if params.MaxCompletionTokens != nil && !isSupported["max_completion_tokens"] {
			params.MaxCompletionTokens = nil
			dropped = append(dropped, "max_completion_tokens")
		}
		if params.Metadata != nil && !isSupported["metadata"] {
			params.Metadata = nil
			dropped = append(dropped, "metadata")
		}
		if params.ParallelToolCalls != nil && !isSupported["parallel_tool_calls"] {
			params.ParallelToolCalls = nil
			dropped = append(dropped, "parallel_tool_calls")
		}
		if params.Prediction != nil && !isSupported["prediction"] {
			params.Prediction = nil
			dropped = append(dropped, "prediction")
		}
		if params.PresencePenalty != nil && !isSupported["presence_penalty"] {
			params.PresencePenalty = nil
			dropped = append(dropped, "presence_penalty")
		}
		if params.PromptCacheKey != nil && !isSupported["prompt_cache_key"] {
			params.PromptCacheKey = nil
			dropped = append(dropped, "prompt_cache_key")
		}
		if params.PromptCacheRetention != nil && !isSupported["prompt_cache_retention"] {
			params.PromptCacheRetention = nil
			dropped = append(dropped, "prompt_cache_retention")
		}
		if params.Reasoning != nil && !isSupported["reasoning"] {
			params.Reasoning = nil
			dropped = append(dropped, "reasoning")
		}
		if params.ResponseFormat != nil && !isSupported["response_format"] {
			params.ResponseFormat = nil
			dropped = append(dropped, "response_format")
		}
		if params.Seed != nil && !isSupported["seed"] {
			params.Seed = nil
			dropped = append(dropped, "seed")
		}
		if params.ServiceTier != nil && !isSupported["service_tier"] {
			params.ServiceTier = nil
			dropped = append(dropped, "service_tier")
		}
		if len(params.Stop) > 0 && !isSupported["stop"] {
			params.Stop = nil
			dropped = append(dropped, "stop")
		}
		if params.Temperature != nil && !isSupported["temperature"] {
			params.Temperature = nil
			dropped = append(dropped, "temperature")
		}
		if params.TopLogProbs != nil && !isSupported["top_logprobs"] {
			params.TopLogProbs = nil
			dropped = append(dropped, "top_logprobs")
		}
		if params.TopP != nil && !isSupported["top_p"] {
			params.TopP = nil
			dropped = append(dropped, "top_p")
		}
		if params.ToolChoice != nil && !isSupported["tool_choice"] {
			params.ToolChoice = nil
			dropped = append(dropped, "tool_choice")
		}
		if len(params.Tools) > 0 && !isSupported["tools"] {
			params.Tools = nil
			dropped = append(dropped, "tools")
		}
		if params.Verbosity != nil && !isSupported["verbosity"] {
			params.Verbosity = nil
			dropped = append(dropped, "verbosity")
		}
		if params.WebSearchOptions != nil && !isSupported["web_search_options"] {
			params.WebSearchOptions = nil
			dropped = append(dropped, "web_search_options")
		}
	}

	// --- Responses API parameters ---
	if req.ResponsesRequest != nil && req.ResponsesRequest.Params != nil {
		params := req.ResponsesRequest.Params

		if params.MaxOutputTokens != nil && !isSupported["max_output_tokens"] {
			params.MaxOutputTokens = nil
			dropped = append(dropped, "max_output_tokens")
		}
		if params.MaxToolCalls != nil && !isSupported["max_tool_calls"] {
			params.MaxToolCalls = nil
			dropped = append(dropped, "max_tool_calls")
		}
		if params.Metadata != nil && !isSupported["metadata"] {
			params.Metadata = nil
			dropped = append(dropped, "metadata")
		}
		if params.ParallelToolCalls != nil && !isSupported["parallel_tool_calls"] {
			params.ParallelToolCalls = nil
			dropped = append(dropped, "parallel_tool_calls")
		}
		if params.PromptCacheKey != nil && !isSupported["prompt_cache_key"] {
			params.PromptCacheKey = nil
			dropped = append(dropped, "prompt_cache_key")
		}
		if params.Reasoning != nil && !isSupported["reasoning"] {
			params.Reasoning = nil
			dropped = append(dropped, "reasoning")
		}
		if params.ServiceTier != nil && !isSupported["service_tier"] {
			params.ServiceTier = nil
			dropped = append(dropped, "service_tier")
		}
		if params.Temperature != nil && !isSupported["temperature"] {
			params.Temperature = nil
			dropped = append(dropped, "temperature")
		}
		if params.Text != nil && !isSupported["text"] {
			params.Text = nil
			dropped = append(dropped, "text")
		}
		if params.TopLogProbs != nil && !isSupported["top_logprobs"] {
			params.TopLogProbs = nil
			dropped = append(dropped, "top_logprobs")
		}
		if params.TopP != nil && !isSupported["top_p"] {
			params.TopP = nil
			dropped = append(dropped, "top_p")
		}
		if params.ToolChoice != nil && !isSupported["tool_choice"] {
			params.ToolChoice = nil
			dropped = append(dropped, "tool_choice")
		}
		if len(params.Tools) > 0 && !isSupported["tools"] {
			params.Tools = nil
			dropped = append(dropped, "tools")
		}
	}

	// --- Text completion parameters ---
	if req.TextCompletionRequest != nil && req.TextCompletionRequest.Params != nil {
		params := req.TextCompletionRequest.Params

		if params.FrequencyPenalty != nil && !isSupported["frequency_penalty"] {
			params.FrequencyPenalty = nil
			dropped = append(dropped, "frequency_penalty")
		}
		if params.LogitBias != nil && !isSupported["logit_bias"] {
			params.LogitBias = nil
			dropped = append(dropped, "logit_bias")
		}
		if params.LogProbs != nil && !isSupported["logprobs"] {
			params.LogProbs = nil
			dropped = append(dropped, "logprobs")
		}
		if params.MaxTokens != nil && !isSupported["max_tokens"] {
			params.MaxTokens = nil
			dropped = append(dropped, "max_tokens")
		}
		if params.N != nil && !isSupported["n"] {
			params.N = nil
			dropped = append(dropped, "n")
		}
		if params.PresencePenalty != nil && !isSupported["presence_penalty"] {
			params.PresencePenalty = nil
			dropped = append(dropped, "presence_penalty")
		}
		if params.Seed != nil && !isSupported["seed"] {
			params.Seed = nil
			dropped = append(dropped, "seed")
		}
		if len(params.Stop) > 0 && !isSupported["stop"] {
			params.Stop = nil
			dropped = append(dropped, "stop")
		}
		if params.Temperature != nil && !isSupported["temperature"] {
			params.Temperature = nil
			dropped = append(dropped, "temperature")
		}
		if params.TopP != nil && !isSupported["top_p"] {
			params.TopP = nil
			dropped = append(dropped, "top_p")
		}
	}

	return dropped
}
// PluginName identifies this plugin in logs and plugin registries.
const PluginName = "compat"

// Config defines the configuration for the compat plugin.
type Config struct {
	Enabled bool `json:"enabled"`
}

// CompatPlugin provides LiteLLM-compatible request/response transformations.
// When enabled, it automatically converts text completion requests to chat
// completion requests for models that only support chat completions, matching
// LiteLLM's behavior. It also converts chat completion requests to responses
// for models that only support the responses endpoint.
//
// NOTE(review): droppedParams is per-request state stored on the shared
// plugin instance. If the plugin serves concurrent requests, PreLLMHook calls
// will race on this field and PostLLMHook may report another request's
// dropped parameters — confirm request handling is serialized per plugin, or
// move this into the request context.
type CompatPlugin struct {
	config       Config
	logger       schemas.Logger
	modelCatalog *modelcatalog.ModelCatalog
	droppedParams []string
}

// Init creates a new compat plugin instance with model catalog support.
// The model catalog is used to determine if a model supports text completion or
// chat completion natively. If the model catalog is nil, the plugin will
// convert all text completion requests to chat completion and all chat
// completion requests to responses.
func Init(config Config, logger schemas.Logger, mc *modelcatalog.ModelCatalog) (*CompatPlugin, error) {
	return &CompatPlugin{
		config:       config,
		logger:       logger,
		modelCatalog: mc,
	}, nil
}

// GetName returns the plugin name.
func (p *CompatPlugin) GetName() string {
	return PluginName
}

// HTTPTransportPreHook is not used for this plugin.
func (p *CompatPlugin) HTTPTransportPreHook(ctx *schemas.BifrostContext, req *schemas.HTTPRequest) (*schemas.HTTPResponse, error) {
	return nil, nil
}

// HTTPTransportPostHook is not used for this plugin.
func (p *CompatPlugin) HTTPTransportPostHook(ctx *schemas.BifrostContext, req *schemas.HTTPRequest, resp *schemas.HTTPResponse) error {
	return nil
}

// HTTPTransportStreamChunkHook passes through streaming chunks unchanged.
func (p *CompatPlugin) HTTPTransportStreamChunkHook(ctx *schemas.BifrostContext, req *schemas.HTTPRequest, chunk *schemas.BifrostStreamChunk) (*schemas.BifrostStreamChunk, error) {
	return chunk, nil
}

// PreLLMHook intercepts requests and applies LiteLLM-compatible request
// normalization: it marks requests for endpoint conversion when the model
// does not support the caller's request type, drops parameters the model
// catalog says the model does not accept, and applies in-place provider
// compatibility rewrites. The original request is cloned before mutation.
func (p *CompatPlugin) PreLLMHook(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.LLMPluginShortCircuit, error) {
	if ctx == nil || req == nil {
		return req, nil, nil
	}

	// text completion → chat conversion
	if (req.RequestType == schemas.TextCompletionRequest || req.RequestType == schemas.TextCompletionStreamRequest) && req.TextCompletionRequest != nil {
		p.markForConversion(ctx, req.TextCompletionRequest.Provider, req.TextCompletionRequest.Model, schemas.TextCompletionRequest, schemas.ChatCompletionRequest)
	}

	// chat completion → responses conversion
	if (req.RequestType == schemas.ChatCompletionRequest || req.RequestType == schemas.ChatCompletionStreamRequest) && req.ChatRequest != nil {
		p.markForConversion(ctx, req.ChatRequest.Provider, req.ChatRequest.Model, schemas.ChatCompletionRequest, schemas.ResponsesRequest)
	}

	// Clone before mutating so the caller's request is left untouched.
	modifiedReq := cloneBifrostReq(req)
	p.droppedParams = nil
	if p.modelCatalog != nil {
		_, model, _ := req.GetRequestFields()
		if model != "" {
			if supportedParams := p.modelCatalog.GetSupportedParameters(model); supportedParams != nil {
				droppedParams := dropUnsupportedParams(modifiedReq, supportedParams)
				if len(droppedParams) > 0 {
					p.droppedParams = droppedParams
				}
			}
		}
	}

	applyParameterConversion(modifiedReq)

	return modifiedReq, nil, nil
}

// PostLLMHook converts provider responses back to the caller-facing shape:
// it records the converted request type (if any) and the dropped parameter
// names in the response/error extra fields so callers can observe what the
// plugin changed.
func (p *CompatPlugin) PostLLMHook(ctx *schemas.BifrostContext, result *schemas.BifrostResponse, bifrostErr *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError, error) {
	if ctx == nil {
		return result, bifrostErr, nil
	}

	if changeType, ok := ctx.Value(schemas.BifrostContextKeyChangeRequestType).(schemas.RequestType); ok {
		if result != nil {
			extraFields := result.GetExtraFields()
			if extraFields != nil {
				extraFields.ConvertedRequestType = changeType
			}
		}
		if bifrostErr != nil {
			bifrostErr.ExtraFields.ConvertedRequestType = changeType
		}
	}

	if result != nil {
		if extraFields := result.GetExtraFields(); extraFields != nil {
			// NOTE(review): reads the shared droppedParams field — see the
			// concurrency note on CompatPlugin.
			extraFields.DroppedCompatPluginParams = p.droppedParams
		}
	}

	return result, bifrostErr, nil
}

// Cleanup performs plugin cleanup.
func (p *CompatPlugin) Cleanup() error {
	return nil
}

// markForConversion checks if the model supports the current request type;
// if not (or if no catalog is configured), it marks the request for core
// conversion to targetType via the context.
func (p *CompatPlugin) markForConversion(ctx *schemas.BifrostContext, provider schemas.ModelProvider, model string, currentType schemas.RequestType, targetType schemas.RequestType) {
	shouldConvert := true

	if p.modelCatalog != nil {
		if p.modelCatalog.IsRequestTypeSupported(model, provider, currentType) {
			p.logger.Debug("compat: model %s/%s supports %v, skipping conversion", provider, model, currentType)
			shouldConvert = false
		}
	}

	if shouldConvert {
		ctx.SetValue(schemas.BifrostContextKeyChangeRequestType, targetType)
		p.logger.Debug("compat: marked %v for core conversion to %v for model %s", currentType, targetType, model)
	}
}

// cloneBifrostReq returns a deep copy of req with each request variant cloned
// via its dedicated helper; top-level scalar fields are copied by value.
func cloneBifrostReq(req *schemas.BifrostRequest) *schemas.BifrostRequest {
	if req == nil {
		return nil
	}

	cloned := *req

	if req.TextCompletionRequest != nil {
		cloned.TextCompletionRequest = cloneTextCompletionRequest(req.TextCompletionRequest)
	}
	if req.ChatRequest != nil {
		cloned.ChatRequest = cloneChatRequest(req.ChatRequest)
	}
	if req.ResponsesRequest != nil {
		cloned.ResponsesRequest = cloneResponsesRequest(req.ResponsesRequest)
	}

	return &cloned
}
// cloneTextCompletionRequest deep-copies a text completion request, including
// its input, parameters, fallbacks, and raw body.
func cloneTextCompletionRequest(req *schemas.BifrostTextCompletionRequest) *schemas.BifrostTextCompletionRequest {
	if req == nil {
		return nil
	}

	cloned := *req
	cloned.Input = cloneTextCompletionInput(req.Input)
	cloned.Params = cloneTextCompletionParameters(req.Params)
	cloned.Fallbacks = slices.Clone(req.Fallbacks)
	cloned.RawRequestBody = bytes.Clone(req.RawRequestBody)
	return &cloned
}

// cloneTextCompletionInput deep-copies the prompt (string or array form).
func cloneTextCompletionInput(input *schemas.TextCompletionInput) *schemas.TextCompletionInput {
	if input == nil {
		return nil
	}
	cloned := &schemas.TextCompletionInput{
		PromptArray: slices.Clone(input.PromptArray),
	}
	if input.PromptStr != nil {
		prompt := *input.PromptStr
		cloned.PromptStr = &prompt
	}
	return cloned
}

// cloneTextCompletionParameters deep-copies text completion parameters,
// detaching pointer-held maps, slices, and nested structs.
func cloneTextCompletionParameters(params *schemas.TextCompletionParameters) *schemas.TextCompletionParameters {
	if params == nil {
		return nil
	}
	cloned := *params
	if params.LogitBias != nil {
		logitBias := cloneStringFloat64Map(*params.LogitBias)
		cloned.LogitBias = &logitBias
	}
	if params.Stop != nil {
		cloned.Stop = slices.Clone(params.Stop)
	}
	if params.StreamOptions != nil {
		streamOptions := *params.StreamOptions
		cloned.StreamOptions = &streamOptions
	}
	if params.ExtraParams != nil {
		cloned.ExtraParams = cloneAnyMap(params.ExtraParams)
	}
	return &cloned
}

// cloneChatRequest deep-copies a chat completion request, including every
// message, the parameters, fallbacks, and raw body.
func cloneChatRequest(req *schemas.BifrostChatRequest) *schemas.BifrostChatRequest {
	if req == nil {
		return nil
	}

	cloned := *req
	if req.Input != nil {
		cloned.Input = make([]schemas.ChatMessage, len(req.Input))
		for i, message := range req.Input {
			cloned.Input[i] = schemas.DeepCopyChatMessage(message)
		}
	}
	cloned.Params = cloneChatParameters(req.Params)
	cloned.Fallbacks = slices.Clone(req.Fallbacks)
	cloned.RawRequestBody = bytes.Clone(req.RawRequestBody)
	return &cloned
}

// cloneChatParameters deep-copies chat parameters, detaching every
// pointer-held map, slice, and nested struct so mutations on the clone do not
// leak back into the original.
func cloneChatParameters(params *schemas.ChatParameters) *schemas.ChatParameters {
	if params == nil {
		return nil
	}

	cloned := *params
	if params.Audio != nil {
		audio := *params.Audio
		cloned.Audio = &audio
	}
	if params.LogitBias != nil {
		logitBias := cloneStringFloat64Map(*params.LogitBias)
		cloned.LogitBias = &logitBias
	}
	if params.Metadata != nil {
		metadata := cloneAnyMap(*params.Metadata)
		cloned.Metadata = &metadata
	}
	if params.Modalities != nil {
		cloned.Modalities = slices.Clone(params.Modalities)
	}
	if params.Prediction != nil {
		prediction := *params.Prediction
		prediction.Content = cloneAnyValue(params.Prediction.Content)
		cloned.Prediction = &prediction
	}
	if params.Reasoning != nil {
		reasoning := *params.Reasoning
		cloned.Reasoning = &reasoning
	}
	if params.ResponseFormat != nil {
		responseFormat := cloneAnyValue(*params.ResponseFormat)
		cloned.ResponseFormat = &responseFormat
	}
	if params.StreamOptions != nil {
		streamOptions := *params.StreamOptions
		cloned.StreamOptions = &streamOptions
	}
	if params.Stop != nil {
		cloned.Stop = slices.Clone(params.Stop)
	}
	if params.ToolChoice != nil {
		cloned.ToolChoice = cloneChatToolChoice(params.ToolChoice)
	}
	if params.Tools != nil {
		cloned.Tools = make([]schemas.ChatTool, len(params.Tools))
		for i, tool := range params.Tools {
			cloned.Tools[i] = schemas.DeepCopyChatTool(tool)
		}
	}
	if params.WebSearchOptions != nil {
		cloned.WebSearchOptions = cloneChatWebSearchOptions(params.WebSearchOptions)
	}
	if params.ExtraParams != nil {
		cloned.ExtraParams = cloneAnyMap(params.ExtraParams)
	}
	return &cloned
}

// cloneChatToolChoice deep-copies a tool choice in either its string or
// struct form, including the nested function/custom/allowed-tools variants.
func cloneChatToolChoice(choice *schemas.ChatToolChoice) *schemas.ChatToolChoice {
	if choice == nil {
		return nil
	}

	cloned := &schemas.ChatToolChoice{}
	if choice.ChatToolChoiceStr != nil {
		value := *choice.ChatToolChoiceStr
		cloned.ChatToolChoiceStr = &value
	}
	if choice.ChatToolChoiceStruct != nil {
		choiceStruct := *choice.ChatToolChoiceStruct
		if choice.ChatToolChoiceStruct.Function != nil {
			function := *choice.ChatToolChoiceStruct.Function
			choiceStruct.Function = &function
		}
		if choice.ChatToolChoiceStruct.Custom != nil {
			custom := *choice.ChatToolChoiceStruct.Custom
			choiceStruct.Custom = &custom
		}
		if choice.ChatToolChoiceStruct.AllowedTools != nil {
			allowedTools := *choice.ChatToolChoiceStruct.AllowedTools
			allowedTools.Tools = slices.Clone(choice.ChatToolChoiceStruct.AllowedTools.Tools)
			choiceStruct.AllowedTools = &allowedTools
		}
		cloned.ChatToolChoiceStruct = &choiceStruct
	}
	return cloned
}

// cloneChatWebSearchOptions deep-copies web search options, including the
// nested user location and its approximate coordinates.
func cloneChatWebSearchOptions(options *schemas.ChatWebSearchOptions) *schemas.ChatWebSearchOptions {
	if options == nil {
		return nil
	}

	cloned := *options
	if options.UserLocation != nil {
		userLocation := *options.UserLocation
		if options.UserLocation.Approximate != nil {
			approximate := *options.UserLocation.Approximate
			userLocation.Approximate = &approximate
		}
		cloned.UserLocation = &userLocation
	}
	return &cloned
}

// cloneResponsesRequest deep-copies a responses request, including every
// message, the parameters, fallbacks, and raw body.
func cloneResponsesRequest(req *schemas.BifrostResponsesRequest) *schemas.BifrostResponsesRequest {
	if req == nil {
		return nil
	}

	cloned := *req
	if req.Input != nil {
		cloned.Input = make([]schemas.ResponsesMessage, len(req.Input))
		for i, message := range req.Input {
			cloned.Input[i] = schemas.DeepCopyResponsesMessage(message)
		}
	}
	cloned.Params = cloneResponsesParameters(req.Params)
	cloned.Fallbacks = slices.Clone(req.Fallbacks)
	cloned.RawRequestBody = bytes.Clone(req.RawRequestBody)
	return &cloned
}

// cloneResponsesParameters deep-copies responses parameters, detaching every
// pointer-held map, slice, and nested struct.
func cloneResponsesParameters(params *schemas.ResponsesParameters) *schemas.ResponsesParameters {
	if params == nil {
		return nil
	}

	cloned := *params
	if params.Include != nil {
		cloned.Include = slices.Clone(params.Include)
	}
	if params.Metadata != nil {
		metadata := cloneAnyMap(*params.Metadata)
		cloned.Metadata = &metadata
	}
	if params.Reasoning != nil {
		reasoning := *params.Reasoning
		cloned.Reasoning = &reasoning
	}
	if params.StreamOptions != nil {
		streamOptions := *params.StreamOptions
		cloned.StreamOptions = &streamOptions
	}
	if params.Text != nil {
		cloned.Text = cloneResponsesTextConfig(params.Text)
	}
	if params.ToolChoice != nil {
		cloned.ToolChoice = cloneResponsesToolChoice(params.ToolChoice)
	}
	if params.Tools != nil {
		cloned.Tools = make([]schemas.ResponsesTool, len(params.Tools))
		for i, tool := range params.Tools {
			cloned.Tools[i] = cloneResponsesTool(tool)
		}
	}
	if params.ExtraParams != nil {
		cloned.ExtraParams = cloneAnyMap(params.ExtraParams)
	}
	return &cloned
}

// cloneResponsesTextConfig deep-copies the text output configuration,
// including every field of a nested JSON schema definition.
func cloneResponsesTextConfig(text *schemas.ResponsesTextConfig) *schemas.ResponsesTextConfig {
	if text == nil {
		return nil
	}

	cloned := *text
	if text.Format != nil {
		format := *text.Format
		if text.Format.JSONSchema != nil {
			jsonSchema := *text.Format.JSONSchema
			if text.Format.JSONSchema.Schema != nil {
				schema := cloneAnyValue(*text.Format.JSONSchema.Schema)
				jsonSchema.Schema = &schema
			}
			if text.Format.JSONSchema.Properties != nil {
				properties := cloneAnyMap(*text.Format.JSONSchema.Properties)
				jsonSchema.Properties = &properties
			}
			if text.Format.JSONSchema.Required != nil {
				jsonSchema.Required = slices.Clone(text.Format.JSONSchema.Required)
			}
			if text.Format.JSONSchema.Defs != nil {
				defs := cloneAnyMap(*text.Format.JSONSchema.Defs)
				jsonSchema.Defs = &defs
			}
			if text.Format.JSONSchema.Definitions != nil {
				definitions := cloneAnyMap(*text.Format.JSONSchema.Definitions)
				jsonSchema.Definitions = &definitions
			}
			if text.Format.JSONSchema.Items != nil {
				items := cloneAnyMap(*text.Format.JSONSchema.Items)
				jsonSchema.Items = &items
			}
			if text.Format.JSONSchema.AnyOf != nil {
				jsonSchema.AnyOf = cloneAnyMapSlice(text.Format.JSONSchema.AnyOf)
			}
			if text.Format.JSONSchema.OneOf != nil {
				jsonSchema.OneOf = cloneAnyMapSlice(text.Format.JSONSchema.OneOf)
			}
			if text.Format.JSONSchema.AllOf != nil {
				jsonSchema.AllOf = cloneAnyMapSlice(text.Format.JSONSchema.AllOf)
			}
			if text.Format.JSONSchema.Default != nil {
				jsonSchema.Default = cloneAnyValue(text.Format.JSONSchema.Default)
			}
			if text.Format.JSONSchema.Enum != nil {
				jsonSchema.Enum = slices.Clone(text.Format.JSONSchema.Enum)
			}
			if text.Format.JSONSchema.PropertyOrdering != nil {
				jsonSchema.PropertyOrdering = slices.Clone(text.Format.JSONSchema.PropertyOrdering)
			}
			format.JSONSchema = &jsonSchema
		}
		cloned.Format = &format
	}
	return &cloned
}

// cloneResponsesToolChoice deep-copies a responses tool choice in either its
// string or struct form.
func cloneResponsesToolChoice(choice *schemas.ResponsesToolChoice) *schemas.ResponsesToolChoice {
	if choice == nil {
		return nil
	}

	cloned := &schemas.ResponsesToolChoice{}
	if choice.ResponsesToolChoiceStr != nil {
		value := *choice.ResponsesToolChoiceStr
		cloned.ResponsesToolChoiceStr = &value
	}
	if choice.ResponsesToolChoiceStruct != nil {
		choiceStruct := *choice.ResponsesToolChoiceStruct
		if choice.ResponsesToolChoiceStruct.Tools != nil {
			choiceStruct.Tools = slices.Clone(choice.ResponsesToolChoiceStruct.Tools)
		}
		cloned.ResponsesToolChoiceStruct = &choiceStruct
	}
	return cloned
}

// cloneResponsesTool deep-copies a tool by round-tripping it through JSON;
// on any marshal/unmarshal error it falls back to the original (shallow)
// value rather than failing.
func cloneResponsesTool(tool schemas.ResponsesTool) schemas.ResponsesTool {
	data, err := schemas.MarshalSorted(tool)
	if err != nil {
		return tool
	}

	var cloned schemas.ResponsesTool
	if err := schemas.Unmarshal(data, &cloned); err != nil {
		return tool
	}

	return cloned
}

// cloneStringFloat64Map returns a detached copy of a string→float64 map
// (nil in, nil out).
func cloneStringFloat64Map(input map[string]float64) map[string]float64 {
	if input == nil {
		return nil
	}

	cloned := make(map[string]float64, len(input))
	maps.Copy(cloned, input)
	return cloned
}

// cloneAnyMap deep-copies a string→any map, recursing into values via
// cloneAnyValue (nil in, nil out).
func cloneAnyMap(input map[string]any) map[string]any {
	if input == nil {
		return nil
	}

	cloned := make(map[string]any, len(input))
	for key, value := range input {
		cloned[key] = cloneAnyValue(value)
	}
	return cloned
}

// cloneAnyMapSlice deep-copies a slice of string→any maps (nil in, nil out).
func cloneAnyMapSlice(input []map[string]any) []map[string]any {
	if input == nil {
		return nil
	}

	cloned := make([]map[string]any, len(input))
	for i, value := range input {
		cloned[i] = cloneAnyMap(value)
	}
	return cloned
}
return cloned +} + +func cloneAnySlice(input []any) []any { + if input == nil { + return nil + } + + cloned := make([]any, len(input)) + for i, value := range input { + cloned[i] = cloneAnyValue(value) + } + return cloned +} + +func cloneAnyValue(value any) any { + switch typed := value.(type) { + case nil: + return nil + case map[string]any: + return cloneAnyMap(typed) + case []any: + return cloneAnySlice(typed) + case []string: + return slices.Clone(typed) + case map[string]string: + cloned := make(map[string]string, len(typed)) + maps.Copy(cloned, typed) + return cloned + default: + return typed + } +} diff --git a/plugins/compat/version b/plugins/compat/version new file mode 100644 index 0000000000..6c6aa7cb09 --- /dev/null +++ b/plugins/compat/version @@ -0,0 +1 @@ +0.1.0 \ No newline at end of file diff --git a/plugins/litellmcompat/chattoresponses.go b/plugins/litellmcompat/chattoresponses.go deleted file mode 100644 index c8438a2b17..0000000000 --- a/plugins/litellmcompat/chattoresponses.go +++ /dev/null @@ -1,108 +0,0 @@ -package litellmcompat - -import ( - "github.com/maximhq/bifrost/core/schemas" - "github.com/maximhq/bifrost/framework/modelcatalog" -) - -const ( - ChatToResponsesOriginalRequestTypeContextKey schemas.BifrostContextKey = "litellmcompat-chat-to-responses-original-request-type" - ChatToResponsesOriginalModelContextKey schemas.BifrostContextKey = "litellmcompat-chat-to-responses-original-model" -) - -// transformChatToResponsesRequest determines whether a chat request should be converted -// to a responses request by core. It stores conversion intent in context; core performs -// the actual conversion. 
-func transformChatToResponsesRequest(ctx *schemas.BifrostContext, req *schemas.BifrostRequest, mc *modelcatalog.ModelCatalog, logger schemas.Logger) *schemas.BifrostRequest { - // Only process chat completion requests - if req.RequestType != schemas.ChatCompletionRequest && req.RequestType != schemas.ChatCompletionStreamRequest { - return req - } - - // Check if chat completion request is present - if req.ChatRequest == nil { - return req - } - - // Check if the model supports chat completion via model catalog - if mc != nil { - provider := req.ChatRequest.Provider - model := req.ChatRequest.Model - if mc.IsChatCompletionSupported(model, provider) { - if ctx != nil { - ctx.SetValue(schemas.BifrostContextKeyShouldConvertChatToResponses, false) - } - if logger != nil { - logger.Debug("litellmcompat: model %s/%s supports chat completion, skipping conversion", provider, model) - } - return req - } - } - - // Track conversion intent. Core will do the actual conversion during provider dispatch. 
- if ctx != nil { - ctx.SetValue(schemas.BifrostContextKeyShouldConvertChatToResponses, true) - ctx.SetValue(ChatToResponsesOriginalRequestTypeContextKey, req.RequestType) - ctx.SetValue(ChatToResponsesOriginalModelContextKey, req.ChatRequest.Model) - } - - if logger != nil { - logger.Debug("litellmcompat: marked chat completion for core chat->responses conversion for model %s (chat completion not supported, responses supported)", req.ChatRequest.Model) - } - - return req -} - -func getOriginalChatRequestMetadata(ctx *schemas.BifrostContext) (schemas.RequestType, string) { - requestType := schemas.ChatCompletionRequest - if ctx == nil { - return requestType, "" - } - if value, ok := ctx.Value(ChatToResponsesOriginalRequestTypeContextKey).(schemas.RequestType); ok { - requestType = value - } - model, _ := ctx.Value(ChatToResponsesOriginalModelContextKey).(string) - return requestType, model -} - -// transformChatToResponsesResponse normalizes metadata on converted chat-completion responses. -// Core performs the actual stream/non-stream payload conversion. -func transformChatToResponsesResponse(ctx *schemas.BifrostContext, resp *schemas.BifrostResponse, logger schemas.Logger) *schemas.BifrostResponse { - if resp == nil || resp.ChatResponse == nil || ctx == nil { - return resp - } - - shouldConvert, ok := ctx.Value(schemas.BifrostContextKeyShouldConvertChatToResponses).(bool) - if !ok || !shouldConvert { - return resp - } - - originalRequestType, originalModel := getOriginalChatRequestMetadata(ctx) - resp.ChatResponse.ExtraFields.RequestType = originalRequestType - resp.ChatResponse.ExtraFields.ModelRequested = originalModel - resp.ChatResponse.ExtraFields.LiteLLMCompat = true - - if logger != nil { - logger.Debug("litellmcompat: normalized converted chat completion metadata for model %s", originalModel) - } - - return resp -} - -// transformChatToResponsesError restores original chat-completion metadata on errors -// generated from responses fallback execution. 
-func transformChatToResponsesError(ctx *schemas.BifrostContext, err *schemas.BifrostError) *schemas.BifrostError { - if err == nil || ctx == nil { - return err - } - shouldConvert, ok := ctx.Value(schemas.BifrostContextKeyShouldConvertChatToResponses).(bool) - if !ok || !shouldConvert { - return err - } - - originalRequestType, originalModel := getOriginalChatRequestMetadata(ctx) - err.ExtraFields.RequestType = originalRequestType - err.ExtraFields.ModelRequested = originalModel - err.ExtraFields.LiteLLMCompat = true - return err -} diff --git a/plugins/litellmcompat/main.go b/plugins/litellmcompat/main.go deleted file mode 100644 index 59de6fedb8..0000000000 --- a/plugins/litellmcompat/main.go +++ /dev/null @@ -1,109 +0,0 @@ -// Package litellmcompat provides LiteLLM-compatible request type conversion decisions -// for the Bifrost gateway. It marks requests that should be converted by core provider -// dispatch for models that don't natively support the requested endpoint type. -// -// When enabled, this plugin: -// - Decides whether text_completion() should be converted to chat -// - Decides whether chat_completion() should be converted to responses -// - Stores the decision in context for core request dispatch -package litellmcompat - -import ( - "github.com/maximhq/bifrost/core/schemas" - "github.com/maximhq/bifrost/framework/modelcatalog" -) - -const ( - PluginName = "litellmcompat" -) - -// Config defines the configuration for the litellmcompat plugin -type Config struct { - Enabled bool `json:"enabled"` -} - -// LiteLLMCompatPlugin provides LiteLLM-compatible request/response transformations. -// When enabled, it automatically converts text completion requests to chat completion -// requests for models that only support chat completions, matching LiteLLM's behavior. -// It also converts chat completion requests to responses for models that only support -// the responses endpoint. 
-type LiteLLMCompatPlugin struct { - config Config - logger schemas.Logger - modelCatalog *modelcatalog.ModelCatalog -} - -// Init creates a new litellmcompat plugin instance with model catalog support. -// The model catalog is used to determine if a model supports text completion or chat completion natively. -// If the model catalog is nil, the plugin will convert ALL text completion requests to chat completion -// and ALL chat completion requests to responses. -func Init(config Config, logger schemas.Logger, mc *modelcatalog.ModelCatalog) (*LiteLLMCompatPlugin, error) { - return &LiteLLMCompatPlugin{ - config: config, - logger: logger, - modelCatalog: mc, - }, nil -} - -// SetModelCatalog sets the model catalog for checking text completion support. -// This can be called after initialization to add model catalog support. -func (p *LiteLLMCompatPlugin) SetModelCatalog(mc *modelcatalog.ModelCatalog) { - p.modelCatalog = mc -} - -// GetName returns the plugin name -func (p *LiteLLMCompatPlugin) GetName() string { - return PluginName -} - -// HTTPTransportPreHook is not used for this plugin -func (p *LiteLLMCompatPlugin) HTTPTransportPreHook(ctx *schemas.BifrostContext, req *schemas.HTTPRequest) (*schemas.HTTPResponse, error) { - return nil, nil -} - -// HTTPTransportPostHook is not used for this plugin -func (p *LiteLLMCompatPlugin) HTTPTransportPostHook(ctx *schemas.BifrostContext, req *schemas.HTTPRequest, resp *schemas.HTTPResponse) error { - return nil -} - -// HTTPTransportStreamChunkHook passes through streaming chunks unchanged -func (p *LiteLLMCompatPlugin) HTTPTransportStreamChunkHook(ctx *schemas.BifrostContext, req *schemas.HTTPRequest, chunk *schemas.BifrostStreamChunk) (*schemas.BifrostStreamChunk, error) { - return chunk, nil -} - -// PreLLMHook intercepts requests and applies LiteLLM-compatible transformation intent. 
-// For text completion requests on models that don't support text completion, -// it marks the request so core can convert at provider dispatch time. -// For chat completion requests on models that don't support chat completion, -// it marks the request so core can convert at provider dispatch time. -func (p *LiteLLMCompatPlugin) PreLLMHook(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.LLMPluginShortCircuit, error) { - // Reset context keys - if ctx != nil { - ctx.SetValue(schemas.BifrostContextKeyShouldConvertTextToChat, false) - ctx.SetValue(schemas.BifrostContextKeyShouldConvertChatToResponses, false) - } - - // Apply request transforms in sequence - req = transformTextToChatRequest(ctx, req, p.modelCatalog, p.logger) - req = transformChatToResponsesRequest(ctx, req, p.modelCatalog, p.logger) - return req, nil, nil -} - -// PostLLMHook normalizes metadata on converted responses/errors -// when this plugin requested type conversion in PreLLMHook. 
-func (p *LiteLLMCompatPlugin) PostLLMHook(ctx *schemas.BifrostContext, result *schemas.BifrostResponse, bifrostErr *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError, error) { - if result != nil { - result = transformTextToChatResponse(ctx, result, p.logger) - result = transformChatToResponsesResponse(ctx, result, p.logger) - } - if bifrostErr != nil { - bifrostErr = transformTextToChatError(ctx, bifrostErr) - bifrostErr = transformChatToResponsesError(ctx, bifrostErr) - } - return result, bifrostErr, nil -} - -// Cleanup performs plugin cleanup -func (p *LiteLLMCompatPlugin) Cleanup() error { - return nil -} diff --git a/plugins/litellmcompat/texttochat.go b/plugins/litellmcompat/texttochat.go deleted file mode 100644 index 9c78a1473f..0000000000 --- a/plugins/litellmcompat/texttochat.go +++ /dev/null @@ -1,107 +0,0 @@ -package litellmcompat - -import ( - "github.com/maximhq/bifrost/core/schemas" - "github.com/maximhq/bifrost/framework/modelcatalog" -) - -const ( - OriginalRequestTypeContextKey schemas.BifrostContextKey = "litellmcompat-original-request-type" - OriginalModelContextKey schemas.BifrostContextKey = "litellmcompat-original-model" -) - -// transformTextToChatRequest determines whether a text request should be converted by core. -// It stores conversion intent in context; core performs the actual conversion. 
-func transformTextToChatRequest(ctx *schemas.BifrostContext, req *schemas.BifrostRequest, mc *modelcatalog.ModelCatalog, logger schemas.Logger) *schemas.BifrostRequest { - // Only process text completion requests - if req.RequestType != schemas.TextCompletionRequest && req.RequestType != schemas.TextCompletionStreamRequest { - return req - } - - // Check if text completion request is present - if req.TextCompletionRequest == nil { - return req - } - - // Check if the model supports text completion via model catalog - if mc != nil { - provider := req.TextCompletionRequest.Provider - model := req.TextCompletionRequest.Model - if mc.IsTextCompletionSupported(model, provider) { - if ctx != nil { - ctx.SetValue(schemas.BifrostContextKeyShouldConvertTextToChat, false) - } - if logger != nil { - logger.Debug("litellmcompat: model %s/%s supports text completion, skipping conversion", provider, model) - } - return req - } - } - - // Track conversion intent. Core will do the actual conversion during provider dispatch. - if ctx != nil { - ctx.SetValue(schemas.BifrostContextKeyShouldConvertTextToChat, true) - ctx.SetValue(OriginalRequestTypeContextKey, req.RequestType) - ctx.SetValue(OriginalModelContextKey, req.TextCompletionRequest.Model) - } - - if logger != nil { - logger.Debug("litellmcompat: marked text completion for core text->chat conversion for model %s (text completion not supported)", req.TextCompletionRequest.Model) - } - - return req -} - -func getOriginalTextRequestMetadata(ctx *schemas.BifrostContext) (schemas.RequestType, string) { - requestType := schemas.TextCompletionRequest - if ctx == nil { - return requestType, "" - } - if value, ok := ctx.Value(OriginalRequestTypeContextKey).(schemas.RequestType); ok { - requestType = value - } - model, _ := ctx.Value(OriginalModelContextKey).(string) - return requestType, model -} - -// transformTextToChatResponse normalizes metadata on converted text-completion responses. 
-// Core performs the actual stream/non-stream payload conversion. -func transformTextToChatResponse(ctx *schemas.BifrostContext, resp *schemas.BifrostResponse, logger schemas.Logger) *schemas.BifrostResponse { - if resp == nil || resp.TextCompletionResponse == nil || ctx == nil { - return resp - } - - shouldConvert, ok := ctx.Value(schemas.BifrostContextKeyShouldConvertTextToChat).(bool) - if !ok || !shouldConvert { - return resp - } - - originalRequestType, originalModel := getOriginalTextRequestMetadata(ctx) - resp.TextCompletionResponse.ExtraFields.RequestType = originalRequestType - resp.TextCompletionResponse.ExtraFields.ModelRequested = originalModel - resp.TextCompletionResponse.ExtraFields.LiteLLMCompat = true - - if logger != nil { - logger.Debug("litellmcompat: normalized converted text completion metadata for model %s", originalModel) - } - - return resp -} - -// transformTextToChatError restores original text-completion metadata on errors -// generated from chat fallback execution. 
-func transformTextToChatError(ctx *schemas.BifrostContext, err *schemas.BifrostError) *schemas.BifrostError { - if err == nil || ctx == nil { - return err - } - shouldConvert, ok := ctx.Value(schemas.BifrostContextKeyShouldConvertTextToChat).(bool) - if !ok || !shouldConvert { - return err - } - - originalRequestType, originalModel := getOriginalTextRequestMetadata(ctx) - err.ExtraFields.RequestType = originalRequestType - err.ExtraFields.ModelRequested = originalModel - err.ExtraFields.LiteLLMCompat = true - return err -} diff --git a/transports/bifrost-http/handlers/config.go b/transports/bifrost-http/handlers/config.go index 4f08f31b12..ec28c9d3ad 100644 --- a/transports/bifrost-http/handlers/config.go +++ b/transports/bifrost-http/handlers/config.go @@ -19,7 +19,7 @@ import ( configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables" "github.com/maximhq/bifrost/framework/encrypt" "github.com/maximhq/bifrost/framework/modelcatalog" - "github.com/maximhq/bifrost/plugins/litellmcompat" + "github.com/maximhq/bifrost/plugins/compat" "github.com/maximhq/bifrost/transports/bifrost-http/lib" "github.com/valyala/fasthttp" ) @@ -88,7 +88,7 @@ func (h *ConfigHandler) getVersion(ctx *fasthttp.RequestCtx) { // getConfig handles GET /config - Get the current configuration func (h *ConfigHandler) getConfig(ctx *fasthttp.RequestCtx) { - var mapConfig = make(map[string]any) + mapConfig := make(map[string]any) if query := string(ctx.QueryArgs().Peek("from_db")); query == "true" { if h.store.ConfigStore == nil { @@ -342,18 +342,18 @@ func (h *ConfigHandler) updateConfig(ctx *fasthttp.RequestCtx) { updatedConfig.MaxRequestBodySizeMB = payload.ClientConfig.MaxRequestBodySizeMB } - // Handle LiteLLM compat plugin toggle + // Handle compat plugin toggle if payload.ClientConfig.EnableLiteLLMFallbacks != currentConfig.EnableLiteLLMFallbacks { if payload.ClientConfig.EnableLiteLLMFallbacks { - // Load and register the litellmcompat plugin - if err := 
h.configManager.ReloadPlugin(ctx, "litellmcompat", nil, &litellmcompat.Config{Enabled: true}, nil, nil); err != nil { - logger.Warn(fmt.Sprintf("failed to load litellmcompat plugin: %v", err)) + // Load and register the compat plugin + if err := h.configManager.ReloadPlugin(ctx, compat.PluginName, nil, &compat.Config{Enabled: true}, nil, nil); err != nil { + logger.Warn("failed to load compat plugin: %v", err) } } else { - // Remove the litellmcompat plugin + // Remove the compat plugin disabledCtx := context.WithValue(ctx, PluginDisabledKey, true) - if err := h.configManager.RemovePlugin(disabledCtx, "litellmcompat"); err != nil { - logger.Warn("failed to remove litellmcompat plugin: %v", err) + if err := h.configManager.RemovePlugin(disabledCtx, compat.PluginName); err != nil { + logger.Warn("failed to remove compat plugin: %v", err) } } } diff --git a/transports/bifrost-http/lib/config.go b/transports/bifrost-http/lib/config.go index 5bdba4930f..c320086bb8 100644 --- a/transports/bifrost-http/lib/config.go +++ b/transports/bifrost-http/lib/config.go @@ -37,7 +37,7 @@ import ( plugins "github.com/maximhq/bifrost/framework/plugins" "github.com/maximhq/bifrost/framework/vectorstore" "github.com/maximhq/bifrost/plugins/governance" - "github.com/maximhq/bifrost/plugins/litellmcompat" + "github.com/maximhq/bifrost/plugins/compat" "github.com/maximhq/bifrost/plugins/logging" "github.com/maximhq/bifrost/plugins/maxim" "github.com/maximhq/bifrost/plugins/otel" @@ -107,7 +107,7 @@ func IsBuiltinPlugin(name string) bool { name == prompts.PluginName || name == logging.PluginName || name == governance.PluginName || - name == litellmcompat.PluginName || + name == compat.PluginName || name == maxim.PluginName || name == semanticcache.PluginName || name == otel.PluginName diff --git a/transports/bifrost-http/server/plugins.go b/transports/bifrost-http/server/plugins.go index 2d9a0ace81..39e07cc201 100644 --- a/transports/bifrost-http/server/plugins.go +++ 
b/transports/bifrost-http/server/plugins.go @@ -6,8 +6,8 @@ import ( "slices" "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/plugins/compat" "github.com/maximhq/bifrost/plugins/governance" - "github.com/maximhq/bifrost/plugins/litellmcompat" "github.com/maximhq/bifrost/plugins/logging" "github.com/maximhq/bifrost/plugins/maxim" "github.com/maximhq/bifrost/plugins/otel" @@ -105,12 +105,12 @@ func loadBuiltinPlugin(ctx context.Context, name string, pluginConfig any, bifro } return otel.Init(ctx, otelConfig, logger, bifrostConfig.ModelCatalog, handlers.GetVersion()) - case litellmcompat.PluginName: - litellmConfig, err := MarshalPluginConfig[litellmcompat.Config](pluginConfig) + case compat.PluginName: + compatConfig, err := MarshalPluginConfig[compat.Config](pluginConfig) if err != nil { - return nil, fmt.Errorf("failed to marshal litellmcompat plugin config: %w", err) + return nil, fmt.Errorf("failed to marshal compat plugin config: %w", err) } - return litellmcompat.Init(*litellmConfig, logger, bifrostConfig.ModelCatalog) + return compat.Init(*compatConfig, logger, bifrostConfig.ModelCatalog) default: return nil, fmt.Errorf("unknown built-in plugin: %s", name) @@ -215,14 +215,14 @@ func (s *BifrostHTTPServer) loadBuiltinPlugins(ctx context.Context) error { } s.Config.SetPluginOrderInfo(semanticcache.PluginName, builtinPlacement, schemas.Ptr(6)) - // 7. Litellmcompat (if configured in PluginConfigs) - litellmcompatConfig := s.getPluginConfig(litellmcompat.PluginName) - if litellmcompatConfig != nil && litellmcompatConfig.Enabled { - s.registerPluginWithStatus(ctx, litellmcompat.PluginName, nil, litellmcompatConfig.Config, false) + // 7. 
Compat (if configured in PluginConfigs) + compatConfig := s.getPluginConfig(compat.PluginName) + if compatConfig != nil && compatConfig.Enabled { + s.registerPluginWithStatus(ctx, compat.PluginName, nil, compatConfig.Config, false) } else { - s.markPluginDisabled(litellmcompat.PluginName) + s.markPluginDisabled(compat.PluginName) } - s.Config.SetPluginOrderInfo(litellmcompat.PluginName, builtinPlacement, schemas.Ptr(7)) + s.Config.SetPluginOrderInfo(compat.PluginName, builtinPlacement, schemas.Ptr(7)) // 8. Maxim (if configured in PluginConfigs) maximConfig := s.getPluginConfig(maxim.PluginName) @@ -293,4 +293,4 @@ func (s *BifrostHTTPServer) loadCustomPlugins(ctx context.Context) error { []string{fmt.Sprintf("plugin %s initialized successfully", cfg.Name)}, InferPluginTypes(plugin)) } return nil -} \ No newline at end of file +} diff --git a/transports/go.mod b/transports/go.mod index c182319a96..3a40f0d309 100644 --- a/transports/go.mod +++ b/transports/go.mod @@ -14,14 +14,14 @@ require ( github.com/mark3labs/mcp-go v0.43.2 github.com/maximhq/bifrost/core v1.5.1 github.com/maximhq/bifrost/framework v1.3.1 - github.com/maximhq/bifrost/plugins/governance v1.5.1 - github.com/maximhq/bifrost/plugins/litellmcompat v0.1.1 - github.com/maximhq/bifrost/plugins/logging v1.5.1 - github.com/maximhq/bifrost/plugins/maxim v1.6.1 - github.com/maximhq/bifrost/plugins/otel v1.2.1 + github.com/maximhq/bifrost/plugins/compat v0.1.0 + github.com/maximhq/bifrost/plugins/governance v1.5.0 + github.com/maximhq/bifrost/plugins/logging v1.5.0 + github.com/maximhq/bifrost/plugins/maxim v1.6.0 + github.com/maximhq/bifrost/plugins/otel v1.2.0 github.com/maximhq/bifrost/plugins/prompts v1.0.1 - github.com/maximhq/bifrost/plugins/semanticcache v1.5.1 - github.com/maximhq/bifrost/plugins/telemetry v1.5.1 + github.com/maximhq/bifrost/plugins/semanticcache v1.5.0 + github.com/maximhq/bifrost/plugins/telemetry v1.5.0 github.com/pion/rtcp v1.2.16 github.com/pion/webrtc/v4 v4.2.9 
github.com/prometheus/client_golang v1.23.2 @@ -185,3 +185,5 @@ require ( gopkg.in/yaml.v3 v3.0.1 // indirect gorm.io/driver/postgres v1.6.0 // indirect ) + +replace github.com/maximhq/bifrost/plugins/compat => ../plugins/compat diff --git a/transports/go.sum b/transports/go.sum index 80dea86d19..c8b2bde0ec 100644 --- a/transports/go.sum +++ b/transports/go.sum @@ -215,24 +215,22 @@ github.com/maximhq/bifrost/core v1.5.1 h1:iJoVnI4q0CpNylBqXLVaZUc0qgJhd8j8Xa2vtN github.com/maximhq/bifrost/core v1.5.1/go.mod h1:O6VEP2MHkQgo1iLYoxGQ7a+3VBBlHoETCH+pOR6Q5X8= github.com/maximhq/bifrost/framework v1.3.1 h1:HpKD0JigkxsR6+jI3DDxAm9AKsO241E3sj2BpxG82Xs= github.com/maximhq/bifrost/framework v1.3.1/go.mod h1:M+MDjP4cRZMinI2qk0DHtTp9ayFWaoQ2Ye+ikmyhGYQ= -github.com/maximhq/bifrost/plugins/governance v1.5.1 h1:zc7TY5Xb4HsEqKfL7mdkIushgAbD1a0MSoQpjYFEhtY= -github.com/maximhq/bifrost/plugins/governance v1.5.1/go.mod h1:WosnY6eDKAufCZKJpNsqWiHt/fyZOx2THoDLzkqRTnM= -github.com/maximhq/bifrost/plugins/litellmcompat v0.1.1 h1:90SzGOuPZjau6wQ1CJwB7f//XETKyf6yFZ/2jC/DMCU= -github.com/maximhq/bifrost/plugins/litellmcompat v0.1.1/go.mod h1:BC1dOa23dED8rSYi7ntrIwqZGHkm3nktuPtEFSMx2tE= -github.com/maximhq/bifrost/plugins/logging v1.5.1 h1:kNjmevWpt7nmsRyDmVTz8GPhnljtgCOtO52vjfTMvG8= -github.com/maximhq/bifrost/plugins/logging v1.5.1/go.mod h1:qcutU7X+Qt7zuNgT7m/zblLvMsI4/SAaoMwlDDBopvY= -github.com/maximhq/bifrost/plugins/maxim v1.6.1 h1:pwWflCaINS+6nPihSjezUpbCHdENqRFVSNiwiGzPyoI= -github.com/maximhq/bifrost/plugins/maxim v1.6.1/go.mod h1:t8xxjMGGqbXz2IRSYxQGvfKM27G2LlIAkWyFVIx8S54= +github.com/maximhq/bifrost/plugins/governance v1.5.0 h1:cT+QiIKqJNKjl6/q0W3HTuZSeql0MHx3UWTyZPMLag4= +github.com/maximhq/bifrost/plugins/governance v1.5.0/go.mod h1:hjC5TmTdk4bES89zPUwBTwWWteHNtTV8WytdkPZUWd8= +github.com/maximhq/bifrost/plugins/logging v1.5.0 h1:uGrernx8gENT84L7fXyEpgvJZgORsGZvyq5B4PkSj80= +github.com/maximhq/bifrost/plugins/logging v1.5.0/go.mod 
h1:uxdMIVHUG7u5Wc5HQzXY13UlExc3lDumRgC8M+kTQiw= +github.com/maximhq/bifrost/plugins/maxim v1.6.0 h1:F23T1qcMczcuauGCYO5p9qeZOAc48FPjFdaSK9TmVeY= +github.com/maximhq/bifrost/plugins/maxim v1.6.0/go.mod h1:V/ccWAfBiW6kVXGWLe9tXKoTgFSh9sYgaJRrtEwFTso= github.com/maximhq/bifrost/plugins/mocker v1.5.1 h1:tXB8WPH9J7MURk45PNjx0hh9TeZzyBXqAYFaKUWdQtM= github.com/maximhq/bifrost/plugins/mocker v1.5.1/go.mod h1:qbjCfskG6jN23rtrLYmaxFBvA5CzOTJ67UIEuyFkO90= -github.com/maximhq/bifrost/plugins/otel v1.2.1 h1:fSGOBTOMfsUzZ2Kk/C7CDkbxJ2JceUhrmtFlQ2S7xBs= -github.com/maximhq/bifrost/plugins/otel v1.2.1/go.mod h1:mw5DMoHxIms5L+QpSqN0ow97wM72CRsR4I0MAuFaBNM= +github.com/maximhq/bifrost/plugins/otel v1.2.0 h1:+aJnWdryDlhza7wc4KETosX9j3Mdad5uUFBuwhslNsk= +github.com/maximhq/bifrost/plugins/otel v1.2.0/go.mod h1:BwNVvRuEgdPlSlDLzANpGy2RugWQjtHkEUoBiwT5MNI= github.com/maximhq/bifrost/plugins/prompts v1.0.1 h1:JpM+uVkYmNLWEvg/hT8HN2Wpzax6TUsM/mdIyYzkx00= github.com/maximhq/bifrost/plugins/prompts v1.0.1/go.mod h1:379vljFVED/0L+odEmYQaaYDY/HFy4smb8tpXXCeBvA= -github.com/maximhq/bifrost/plugins/semanticcache v1.5.1 h1:rkXataDvgnE3HlkXCtraYVadeLHLWImtbuajhpUIOyU= -github.com/maximhq/bifrost/plugins/semanticcache v1.5.1/go.mod h1:YSjXwYxO0UvRWKnwqp9SdlgFjAajaMfzpjbtSNTnqnY= -github.com/maximhq/bifrost/plugins/telemetry v1.5.1 h1:bZC/MdVDr3zmvi686tqrQMCzDVPvwqxXScVSk404NqY= -github.com/maximhq/bifrost/plugins/telemetry v1.5.1/go.mod h1:t1DiP/jrfV9oGmpp/Jy1mb/5YYHSvOgGAQR2055xsHI= +github.com/maximhq/bifrost/plugins/semanticcache v1.5.0 h1:tibnQ8lSnKXujnjL4mt84P/5Vxj9e9wbhvh1Tjr68JA= +github.com/maximhq/bifrost/plugins/semanticcache v1.5.0/go.mod h1:+NfIRAlHpuh5ORv0MoOf5f8uY4WPx6v/8Kuk+8FEGnw= +github.com/maximhq/bifrost/plugins/telemetry v1.5.0 h1:hECZgcsqeJSmiLrWONTFFU6APzTyILQzZuVV96oql5Q= +github.com/maximhq/bifrost/plugins/telemetry v1.5.0/go.mod h1:dl/4mtQhxooqU+r42hXajhUaq04S1X3LaH+km5UJAy0= github.com/maximhq/maxim-go v0.2.1 h1:hCp8dQ4HsyyNC+y5HCUuY/HFD0sOnGkjL5MdYCHkgEQ= 
github.com/maximhq/maxim-go v0.2.1/go.mod h1:nwFznXy0Dn4mxXGU4X+BCnE3VP68L+FPEaW0yUgk96o= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= From 5fe51c62df8cc6c73bccc7bafb29e27b5d62f653 Mon Sep 17 00:00:00 2001 From: Samyabrata Maji <116789799+sammaji@users.noreply.github.com> Date: Tue, 24 Mar 2026 18:02:17 +0530 Subject: [PATCH 4/4] feat: removes enable_litellm_fallback option for more granular settings --- .../config.json | 1 - .../workflows/scripts/run-migration-tests.sh | 6 +- .../workflows/scripts/test-docker-image.sh | 3 +- .../scripts/validate-helm-config-fields.sh | 13 +- core/schemas/bifrost.go | 6 +- docs/features/litellm-compat.mdx | 74 +- docs/openapi/openapi.json | 50 +- docs/openapi/schemas/management/config.yaml | 21 +- .../supported-providers/overview.mdx | 2 +- .../config.json | 1 - .../configs/withprompushgateway/config.json | 3 +- examples/configs/withvirtualkeys/config.json | 1 - examples/dockers/data/config.json | 6 +- framework/configstore/clientconfig.go | 25 +- framework/configstore/migrations.go | 156 ++++- framework/configstore/rdb.go | 42 +- framework/configstore/tables/clientconfig.go | 7 +- framework/go.sum | 10 + framework/modelcatalog/main.go | 642 +----------------- helm-charts/bifrost/templates/_helpers.tpl | 17 +- helm-charts/bifrost/values.schema.json | 13 +- helm-charts/bifrost/values.yaml | 6 +- plugins/compat/conversion.go | 18 +- plugins/compat/go.sum | 10 + plugins/compat/main.go | 61 +- plugins/compat/requestcopy.go | 76 +-- tests/governance/config.json | 3 +- tests/integrations/python/config.json | 3 +- tests/integrations/typescript/config.json | 3 +- transports/bifrost-http/handlers/config.go | 23 +- transports/bifrost-http/lib/config.go | 5 +- transports/bifrost-http/lib/config_test.go | 84 ++- transports/bifrost-http/lib/ctx.go | 44 +- transports/bifrost-http/server/plugins.go | 14 +- transports/config.schema.json | 28 +- transports/go.mod | 14 +- 
transports/go.sum | 40 +- .../workspace/config/compatibility/page.tsx | 11 + .../config/views/clientSettingsView.tsx | 34 +- .../config/views/compatibilityView.tsx | 158 +++++ ui/components/sidebar.tsx | 10 +- ui/components/ui/accordion.tsx | 4 +- ui/lib/types/config.ts | 11 +- 43 files changed, 840 insertions(+), 919 deletions(-) create mode 100644 ui/app/workspace/config/compatibility/page.tsx create mode 100644 ui/app/workspace/config/views/compatibilityView.tsx diff --git a/.github/workflows/configs/withpostgresmcpclientsinconfig/config.json b/.github/workflows/configs/withpostgresmcpclientsinconfig/config.json index a0122adfa2..42758ab9dd 100644 --- a/.github/workflows/configs/withpostgresmcpclientsinconfig/config.json +++ b/.github/workflows/configs/withpostgresmcpclientsinconfig/config.json @@ -7,7 +7,6 @@ ], "disable_content_logging": false, "drop_excess_requests": false, - "enable_litellm_fallbacks": false, "enable_logging": true, "enforce_auth_on_inference": true, "initial_pool_size": 300, diff --git a/.github/workflows/scripts/run-migration-tests.sh b/.github/workflows/scripts/run-migration-tests.sh index 85f901cfd5..e0f80eca7f 100755 --- a/.github/workflows/scripts/run-migration-tests.sh +++ b/.github/workflows/scripts/run-migration-tests.sh @@ -542,8 +542,8 @@ VALUES ('migration-test-lock', 'holder-migration-test-001', $future, $now) ON CONFLICT DO NOTHING; -- config_client (global client configuration) -INSERT INTO config_client (id, drop_excess_requests, prometheus_labels_json, allowed_origins_json, allowed_headers_json, header_filter_config_json, initial_pool_size, enable_logging, disable_content_logging, disable_db_pings_in_health, log_retention_days, enforce_governance_header, allow_direct_keys, max_request_body_size_mb, mcp_agent_depth, mcp_tool_execution_timeout, mcp_code_mode_binding_level, mcp_tool_sync_interval, enable_litellm_fallbacks, config_hash, created_at, updated_at) -VALUES (1, false, '["provider", "model"]', '["*"]', 
'["Authorization"]', '{}', 300, true, false, false, 365, true, false, true, 100, 10, 30, 'server', 10, false, 'client-config-hash-001', $now, $now) +INSERT INTO config_client (id, drop_excess_requests, prometheus_labels_json, allowed_origins_json, allowed_headers_json, header_filter_config_json, initial_pool_size, enable_logging, disable_content_logging, disable_db_pings_in_health, log_retention_days, enforce_governance_header, allow_direct_keys, max_request_body_size_mb, mcp_agent_depth, mcp_tool_execution_timeout, mcp_code_mode_binding_level, mcp_tool_sync_interval, compat_convert_text_to_chat, compat_convert_chat_to_responses, compat_should_drop_params, compat_should_convert_params, config_hash, created_at, updated_at) +VALUES (1, false, '["provider", "model"]', '["*"]', '["Authorization"]', '{}', 300, true, false, false, 365, true, false, 100, 10, 30, 'server', 10, false, false, false, true, 'client-config-hash-001', $now, $now) ON CONFLICT DO NOTHING; -- governance_config (key-value config table) @@ -3509,4 +3509,4 @@ main() { exit $exit_code } -main "$@" +main "$@" \ No newline at end of file diff --git a/.github/workflows/scripts/test-docker-image.sh b/.github/workflows/scripts/test-docker-image.sh index 5d770fbd64..ac115394bf 100755 --- a/.github/workflows/scripts/test-docker-image.sh +++ b/.github/workflows/scripts/test-docker-image.sh @@ -212,8 +212,7 @@ cat > "$CONFIG_FILE" << 'CONFIGEOF' "enable_logging": true, "enforce_governance_header": false, "allow_direct_keys": false, - "max_request_body_size_mb": 100, - "enable_litellm_fallbacks": false + "max_request_body_size_mb": 100 }, "encryption_key": "" } diff --git a/.github/workflows/scripts/validate-helm-config-fields.sh b/.github/workflows/scripts/validate-helm-config-fields.sh index 3b08dfffe9..8b38e9d717 100755 --- a/.github/workflows/scripts/validate-helm-config-fields.sh +++ b/.github/workflows/scripts/validate-helm-config-fields.sh @@ -164,7 +164,11 @@ bifrost: enforceGovernanceHeader: true 
allowDirectKeys: true maxRequestBodySizeMb: 50 - enableLitellmFallbacks: true + compat: + convertTextToChat: true + convertChatToResponses: true + shouldDropParams: true + shouldConvertParams: true prometheusLabels: - "team" - "env" @@ -200,7 +204,10 @@ assert_field_value 'client.log_retention_days' '.client.log_retention_days' '30' assert_field_value 'client.enforce_governance_header' '.client.enforce_governance_header' 'true' assert_field_value 'client.allow_direct_keys' '.client.allow_direct_keys' 'true' assert_field_value 'client.max_request_body_size_mb' '.client.max_request_body_size_mb' '50' -assert_field_value 'client.enable_litellm_fallbacks' '.client.enable_litellm_fallbacks' 'true' +assert_field_value 'client.compat.convert_text_to_chat' '.client.compat.convert_text_to_chat' 'true' +assert_field_value 'client.compat.convert_chat_to_responses' '.client.compat.convert_chat_to_responses' 'true' +assert_field_value 'client.compat.should_drop_params' '.client.compat.should_drop_params' 'true' +assert_field_value 'client.compat.should_convert_params' '.client.compat.should_convert_params' 'true' assert_field 'client.prometheus_labels' '.client.prometheus_labels' assert_field 'client.header_filter_config.allowlist' '.client.header_filter_config.allowlist' assert_field 'client.header_filter_config.denylist' '.client.header_filter_config.denylist' @@ -1194,4 +1201,4 @@ if [ "$TESTS_FAILED" -gt 0 ]; then else echo -e "${GREEN}✅ All config.json field validations passed!${NC}" exit 0 -fi +fi \ No newline at end of file diff --git a/core/schemas/bifrost.go b/core/schemas/bifrost.go index 7d7c28ea47..db187ac717 100644 --- a/core/schemas/bifrost.go +++ b/core/schemas/bifrost.go @@ -272,6 +272,10 @@ const ( BifrostContextKeySessionTTL BifrostContextKey = "bifrost-session-ttl" // time.Duration session TTL for the request (session stickiness) BifrostContextKeyMCPExtraHeaders BifrostContextKey = "bifrost-mcp-extra-headers" // map[string][]string (these headers are 
forwarded only to the MCP while tool execution if they are in the allowlist of the MCP client) BifrostContextKeyMCPLogID BifrostContextKey = "bifrost-mcp-log-id" // string (unique UUID for each MCP tool log entry - set per goroutine by agent executor - DO NOT SET THIS MANUALLY) + BifrostContextKeyCompatConvertTextToChat BifrostContextKey = "bifrost-compat-convert-text-to-chat" // bool (per-request override from x-bf-compat header) + BifrostContextKeyCompatConvertChatToResponses BifrostContextKey = "bifrost-compat-convert-chat-to-responses" // bool (per-request override from x-bf-compat header) + BifrostContextKeyCompatShouldDropParams BifrostContextKey = "bifrost-compat-should-drop-params" // bool (per-request override from x-bf-compat header) + BifrostContextKeyCompatShouldConvertParams BifrostContextKey = "bifrost-compat-should-convert-params" // bool (per-request override from x-bf-compat header) ) const ( @@ -1230,4 +1234,4 @@ type BifrostErrorExtraFields struct { DroppedCompatPluginParams []string `json:"dropped_compat_plugin_params,omitempty"` KeyStatuses []KeyStatus `json:"key_statuses,omitempty"` MCPAuthRequired *MCPUserOAuthRequiredError `json:"mcp_auth_required,omitempty"` // Set when a per-user OAuth MCP tool requires authentication -} +} \ No newline at end of file diff --git a/docs/features/litellm-compat.mdx b/docs/features/litellm-compat.mdx index 51cd26dcd9..490a37efa4 100644 --- a/docs/features/litellm-compat.mdx +++ b/docs/features/litellm-compat.mdx @@ -9,8 +9,10 @@ icon: "train" -The LiteLLM compatibility plugin provides two transformations: +The LiteLLM compatibility plugin provides three transformations: 1. **Text-to-Chat Conversion** - Automatically converts text completion requests to chat completion format for models that only support chat APIs +2. **Chat-to-Responses Conversion** - Automatically converts chat completion requests to responses format for models that only support responses APIs +3.
**Drop Unsupported Params** - Automatically drops unsupported parameters if the model doesn't support them -When either transformation is applied, responses include `extra_fields.litellm_compat: true`. +When either transformation is applied, responses include `extra_fields.converted_request_type: `. If request parameters are dropped, the keys are added in `extra_fields.dropped_compat_plugin_params`. --- @@ -55,6 +57,36 @@ F --> G - `object: "chat.completion"` → `object: "text_completion"` - Usage statistics and metadata are preserved +## 2. Chat-to-Responses Conversion + +Some AI models (like OpenAI o1-pro) only support the responses API and don't support native chat completion endpoints. LiteLLM compatibility mode automatically handles this by: + +1. Checking if the model supports chat completion natively (using the model catalog) +2. If not supported, converting your chat message to responses API format +3. Calling the responses endpoint internally +4. Transforming the response back to chat completion format + + +**Smart Conversion**: The conversion only happens when the model doesn't support chat completions natively. If a model has native chat completion support (like OpenAI's gpt-4 models), Bifrost uses the chat completion endpoint directly without any conversion. + + +This allows you to use a unified chat completion interface across all providers, even those that only support responses API. + +## How It Works + +When LiteLLM compatibility is enabled and you make a chat completion request, Bifrost first checks if the model supports chat completion: + +```mermaid +flowchart LR +A[Chat Completion Request] --> B{Model Supports Chat Completion?} +B -->|Yes| C[Call Chat Completion API] +B -->|No| D[Convert to Responses Message] +D --> E[Call Responses API] +E --> F[Transform Response] +C --> G[Chat Completion Response] +F --> G +``` + ## Enabling LiteLLM Compatibility @@ -63,7 +95,10 @@ F --> G 1. Open the Bifrost dashboard 2. 
Navigate to **Settings** → **Client Configuration** -3. Enable **LiteLLM Fallbacks** +3. Expand **LiteLLM Compat** and enable the features you need: + - **Convert Text to Chat** — converts text completion requests to chat for models that only support chat + - **Convert Chat to Responses** — converts chat completion requests to responses for models that only support responses + - **Drop Unsupported Params** — drops unsupported parameters based on model catalog allowlist 4. Save your configuration @@ -73,7 +108,11 @@ F --> G ```json { "client_config": { - "enable_litellm_fallbacks": true + "compat": { + "convert_text_to_chat": true, + "convert_chat_to_responses": true, + "should_drop_params": true + } } } ``` @@ -84,9 +123,9 @@ F --> G ## Supported Providers -LiteLLM compatibility mode works with any provider that supports chat completions but lacks native text completion support: +Text completion to chat completion conversion works with any provider that supports chat completions but lacks native text completion support: -| Provider | Native Text Completion | LiteLLM Fallback | +| Provider | Native Text Completion | With Fallback | |----------|----------------------|------------------| | OpenAI (GPT-4, GPT-3.5-turbo) | No | Yes | | Anthropic (Claude) | No | Yes | @@ -95,6 +134,12 @@ LiteLLM compatibility mode works with any provider that supports chat completion | Mistral | No | Yes | | Bedrock | Varies by model | Yes | +Chat completion to responses conversion works with any provider that supports responses but lacks native chat completion support: + +| Provider | Native Chat Completion | With Fallback | +|----------|----------------------|------------------| +| OpenAI (o1-pro) | No | Yes | + ## Behavior Details **Model Capability Detection:** @@ -117,13 +162,19 @@ LiteLLM compatibility mode works with any provider that supports chat completion | Response | `choices[0].message.content` | `choices[0].text` | | Response | `object: "chat.completion"` | `object: 
"text_completion"` | +### Transformation 2: Chat-to-Responses Conversion + +**Applies to:** Chat completion requests on responses-only models + +| Phase | Original | Transformed | +|-------|----------|-------------| +| Request | Chat message with `role: "user"` | Responses input with `role: "user"` | +| Request | `chat_completion` request type | `responses` request type | ### Metadata Set on Transformed Responses When either transformation is applied: -- `extra_fields.litellm_compat`: Set to `true` -- `extra_fields.provider`: The provider that handled the request - `extra_fields.request_type`: Reflects the original request type - `extra_fields.original_model_requested`: The originally requested model - `extra_fields.resolved_model_used`: The actual provider API identifier used (equals original_model_requested when no alias mapping exists) @@ -131,8 +182,11 @@ When either transformation is applied: ### Error Handling When errors occur on transformed requests: -- `extra_fields.litellm_compat` is set to `true` - Original request type and model are preserved in error metadata +- `extra_fields.converted_request_type`: Set to type of request that was converted to (i.e., `chat_completion` or `responses`) +- `extra_fields.provider`: The provider that handled the request +- `extra_fields.original_model_requested`: The originally requested model +- `extra_fields.dropped_compat_plugin_params`: If any unsupported parameters were dropped, the keys are added here ## What's Preserved @@ -145,7 +199,7 @@ When errors occur on transformed requests: **Good Use Cases:** - Migrating from LiteLLM to Bifrost without code changes -- Maintaining backward compatibility with text completion interfaces +- Maintaining backward compatibility with text completion interfaces or chat completion interfaces - Using a unified API across providers with different capabilities **Consider Alternatives When:** @@ -157,4 +211,4 @@ When errors occur on transformed requests: - 
[Fallbacks](/features/fallbacks) - Automatic provider failover - [Drop-in Replacement](/features/drop-in-replacement) - Use existing SDKs with Bifrost -- [LiteLLM Integration](/integrations/litellm-sdk) - Using LiteLLM SDK with Bifrost +- [LiteLLM Integration](/integrations/litellm-sdk) - Using LiteLLM SDK with Bifrost \ No newline at end of file diff --git a/docs/openapi/openapi.json b/docs/openapi/openapi.json index 1043039a1f..e258d653b5 100644 --- a/docs/openapi/openapi.json +++ b/docs/openapi/openapi.json @@ -133221,9 +133221,15 @@ "type": "integer", "description": "Maximum request body size in MB" }, - "enable_litellm_fallbacks": { - "type": "boolean", - "description": "Whether LiteLLM fallbacks are enabled" + "compat": { + "type": "object", + "description": "Compat plugin configuration", + "properties": { + "convert_text_to_chat": { "type": "boolean", "description": "Convert text completion requests to chat" }, + "convert_chat_to_responses": { "type": "boolean", "description": "Convert chat completion requests to responses" }, + "should_drop_params": { "type": "boolean", "description": "Drop unsupported parameters based on model catalog" }, + "should_convert_params": { "type": "boolean", "description": "Converts model parameter values that are not supported by the model.", "default": false } + } }, "log_retention_days": { "type": "integer", @@ -133537,9 +133543,15 @@ "type": "integer", "description": "Maximum request body size in MB" }, - "enable_litellm_fallbacks": { - "type": "boolean", - "description": "Whether LiteLLM fallbacks are enabled" + "compat": { + "type": "object", + "description": "Compat plugin configuration", + "properties": { + "convert_text_to_chat": { "type": "boolean", "description": "Convert text completion requests to chat" }, + "convert_chat_to_responses": { "type": "boolean", "description": "Convert chat completion requests to responses" }, + "should_drop_params": { "type": "boolean", "description": "Drop unsupported parameters based 
on model catalog" }, + "should_convert_params": { "type": "boolean", "description": "Converts model parameter values that are not supported by the model.", "default": false } + } }, "log_retention_days": { "type": "integer", @@ -205784,9 +205796,15 @@ "type": "integer", "description": "Maximum request body size in MB" }, - "enable_litellm_fallbacks": { - "type": "boolean", - "description": "Whether LiteLLM fallbacks are enabled" + "compat": { + "type": "object", + "description": "Compat plugin configuration", + "properties": { + "convert_text_to_chat": { "type": "boolean", "description": "Convert text completion requests to chat" }, + "convert_chat_to_responses": { "type": "boolean", "description": "Convert chat completion requests to responses" }, + "should_drop_params": { "type": "boolean", "description": "Drop unsupported parameters based on model catalog" }, + "should_convert_params": { "type": "boolean", "description": "Converts model parameter values that are not supported by the model.", "default": false } + } }, "log_retention_days": { "type": "integer", @@ -205999,9 +206017,15 @@ "type": "integer", "description": "Maximum request body size in MB" }, - "enable_litellm_fallbacks": { - "type": "boolean", - "description": "Whether LiteLLM fallbacks are enabled" + "compat": { + "type": "object", + "description": "Compat plugin configuration", + "properties": { + "convert_text_to_chat": { "type": "boolean", "description": "Convert text completion requests to chat" }, + "convert_chat_to_responses": { "type": "boolean", "description": "Convert chat completion requests to responses" }, + "should_drop_params": { "type": "boolean", "description": "Drop unsupported parameters based on model catalog" }, + "should_convert_params": { "type": "boolean", "description": "Converts model parameter values that are not supported by the model.", "default": false } + } }, "log_retention_days": { "type": "integer", @@ -224498,4 +224522,4 @@ } } } -} \ No newline at end of file +} 
diff --git a/docs/openapi/schemas/management/config.yaml b/docs/openapi/schemas/management/config.yaml index 2c54b3979d..eaafb3821f 100644 --- a/docs/openapi/schemas/management/config.yaml +++ b/docs/openapi/schemas/management/config.yaml @@ -44,9 +44,24 @@ ClientConfig: max_request_body_size_mb: type: integer description: Maximum request body size in MB - enable_litellm_fallbacks: - type: boolean - description: Whether LiteLLM fallbacks are enabled + compat: + type: object + description: Compat plugin configuration + properties: + convert_text_to_chat: + type: boolean + description: Convert text completion requests to chat + convert_chat_to_responses: + type: boolean + description: Convert chat completion requests to responses + should_drop_params: + type: boolean + description: Drop unsupported parameters based on model catalog + should_convert_params: + type: boolean + default: false + description: Converts model parameter values that are not supported by the model + additionalProperties: false log_retention_days: type: integer description: Number of days to retain logs diff --git a/docs/providers/supported-providers/overview.mdx b/docs/providers/supported-providers/overview.mdx index b3ae42f62f..98d13ffa73 100644 --- a/docs/providers/supported-providers/overview.mdx +++ b/docs/providers/supported-providers/overview.mdx @@ -48,7 +48,7 @@ The following table summarizes which operations are supported by each provider v Some operations are not supported by the downstream provider, and their internal implementation in Bifrost is optional. 🟡 -Like Text completions are not supported by Groq, but Bifrost can emulate them internally using the Chat Completions API. This feature is disabled by default, but it can be enabled by setting the `enable_litellm_fallbacks` flag to `true` in the client configuration. +Like Text completions are not supported by Groq, but Bifrost can emulate them internally using the Chat Completions API. 
This feature is disabled by default, but it can be enabled by setting `compat.convert_text_to_chat` to `true` in the client configuration. We do not promote using such fallbacks, since text completions and chat completions are fundamentally different. However, this option is available to help users migrating from LiteLLM (which does support these fallbacks). diff --git a/examples/configs/withpostgresmcpclientsinconfig/config.json b/examples/configs/withpostgresmcpclientsinconfig/config.json index 8e03969988..068bc88012 100644 --- a/examples/configs/withpostgresmcpclientsinconfig/config.json +++ b/examples/configs/withpostgresmcpclientsinconfig/config.json @@ -7,7 +7,6 @@ ], "disable_content_logging": false, "drop_excess_requests": false, - "enable_litellm_fallbacks": false, "enable_logging": true, "enforce_auth_on_inference": true, "initial_pool_size": 300, diff --git a/examples/configs/withprompushgateway/config.json b/examples/configs/withprompushgateway/config.json index f697041388..110557d797 100644 --- a/examples/configs/withprompushgateway/config.json +++ b/examples/configs/withprompushgateway/config.json @@ -183,8 +183,7 @@ "enable_logging": true, "enforce_auth_on_inference": false, "allow_direct_keys": false, - "max_request_body_size_mb": 100, - "enable_litellm_fallbacks": false + "max_request_body_size_mb": 100 }, "config_store": { "enabled": true, diff --git a/examples/configs/withvirtualkeys/config.json b/examples/configs/withvirtualkeys/config.json index a968bad65c..9d9ae2c87a 100644 --- a/examples/configs/withvirtualkeys/config.json +++ b/examples/configs/withvirtualkeys/config.json @@ -7,7 +7,6 @@ ], "disable_content_logging": false, "drop_excess_requests": false, - "enable_litellm_fallbacks": false, "enable_logging": true, "enforce_auth_on_inference": true, "initial_pool_size": 300, diff --git a/examples/dockers/data/config.json b/examples/dockers/data/config.json index 46cbfd8e68..072691c2ea 100644 --- a/examples/dockers/data/config.json +++ 
b/examples/dockers/data/config.json @@ -27,7 +27,9 @@ "*" ], "max_request_body_size_mb": 100, - "enable_litellm_fallbacks": false + "compat": { + "should_convert_params": false + } }, "framework": { "pricing": { @@ -35,4 +37,4 @@ "pricing_sync_interval": 86400 } } -} \ No newline at end of file +} diff --git a/framework/configstore/clientconfig.go b/framework/configstore/clientconfig.go index 698437c328..90a27db25e 100644 --- a/framework/configstore/clientconfig.go +++ b/framework/configstore/clientconfig.go @@ -34,6 +34,14 @@ type EnvKeyInfo struct { KeyID string // The key ID this env var belongs to (empty for non-key configs like bedrock_config, connection_string) } +// CompatConfig holds the compat plugin feature flags. +type CompatConfig struct { + ConvertTextToChat bool `json:"convert_text_to_chat"` + ConvertChatToResponses bool `json:"convert_chat_to_responses"` + ShouldDropParams bool `json:"should_drop_params"` + ShouldConvertParams bool `json:"should_convert_params"` +} + // ClientConfig represents the core configuration for Bifrost HTTP transport and the Bifrost Client. // It includes settings for excess request handling, Prometheus metrics, and initial pool size. 
type ClientConfig struct { @@ -51,7 +59,7 @@ type ClientConfig struct { AllowedOrigins []string `json:"allowed_origins,omitempty"` // Additional allowed origins for CORS and WebSocket (localhost is always allowed) AllowedHeaders []string `json:"allowed_headers,omitempty"` // Additional allowed headers for CORS and WebSocket MaxRequestBodySizeMB int `json:"max_request_body_size_mb"` // The maximum request body size in MB - EnableLiteLLMFallbacks bool `json:"enable_litellm_fallbacks"` // Enable litellm-specific fallbacks for text completion for Groq + Compat CompatConfig `json:"compat"` // Compat plugin configuration MCPAgentDepth int `json:"mcp_agent_depth"` // The maximum depth for MCP agent mode tool execution MCPToolExecutionTimeout int `json:"mcp_tool_execution_timeout"` // The timeout for individual tool execution in seconds MCPCodeModeBindingLevel string `json:"mcp_code_mode_binding_level"` // Code mode binding level: "server" or "tool" @@ -110,10 +118,17 @@ func (c *ClientConfig) GenerateClientConfigHash() (string, error) { hash.Write([]byte("allowDirectKeys:false")) } - if c.EnableLiteLLMFallbacks { - hash.Write([]byte("enableLiteLLMFallbacks:true")) - } else { - hash.Write([]byte("enableLiteLLMFallbacks:false")) + if c.Compat.ConvertTextToChat { + hash.Write([]byte("compatConvertTextToChat:true")) + } + if c.Compat.ConvertChatToResponses { + hash.Write([]byte("compatConvertChatToResponses:true")) + } + if c.Compat.ShouldDropParams { + hash.Write([]byte("compatShouldDropParams:true")) + } + if c.Compat.ShouldConvertParams { + hash.Write([]byte("compatShouldConvertParams:true")) } // Only hash non-default value to avoid legacy config hash churn. 
diff --git a/framework/configstore/migrations.go b/framework/configstore/migrations.go index e64351eeaa..7f0b34fa87 100644 --- a/framework/configstore/migrations.go +++ b/framework/configstore/migrations.go @@ -376,6 +376,12 @@ func triggerMigrations(ctx context.Context, db *gorm.DB) error { if err := migrationAddWhitelistedRoutesJSONColumn(ctx, db); err != nil { return err } + if err := migrationReplaceEnableLiteLLMWithCompatColumns(ctx, db); err != nil { + return err + } + if err := migrationDefaultCompatShouldConvertParamsFalse(ctx, db); err != nil { + return err + } return nil } @@ -785,9 +791,10 @@ func migrationAddEnableLiteLLMFallbacksColumn(ctx context.Context, db *gorm.DB) ID: "add_enable_litellm_fallbacks_column", Migrate: func(tx *gorm.DB) error { tx = tx.WithContext(ctx) - migrator := tx.Migrator() - if !migrator.HasColumn(&tables.TableClientConfig{}, "enable_litellm_fallbacks") { - if err := migrator.AddColumn(&tables.TableClientConfig{}, "enable_litellm_fallbacks"); err != nil { + // Use raw SQL since the struct field was removed in a later migration. + // This column is subsequently dropped by migrationReplaceEnableLiteLLMWithCompatColumns. 
+ if !tx.Migrator().HasColumn(&tables.TableClientConfig{}, "enable_litellm_fallbacks") { + if err := tx.Exec("ALTER TABLE config_client ADD COLUMN enable_litellm_fallbacks BOOLEAN DEFAULT FALSE").Error; err != nil { return err } } @@ -795,9 +802,7 @@ func migrationAddEnableLiteLLMFallbacksColumn(ctx context.Context, db *gorm.DB) }, Rollback: func(tx *gorm.DB) error { tx = tx.WithContext(ctx) - migrator := tx.Migrator() - - if err := migrator.DropColumn(&tables.TableClientConfig{}, "enable_litellm_fallbacks"); err != nil { + if err := tx.Exec("ALTER TABLE config_client DROP COLUMN IF EXISTS enable_litellm_fallbacks").Error; err != nil { return err } return nil @@ -2162,7 +2167,6 @@ func migrationAddAdditionalConfigHashColumns(ctx context.Context, db *gorm.DB) e AllowDirectKeys: cc.AllowDirectKeys, AllowedOrigins: cc.AllowedOrigins, MaxRequestBodySizeMB: cc.MaxRequestBodySizeMB, - EnableLiteLLMFallbacks: cc.EnableLiteLLMFallbacks, } hash, err := clientConfig.GenerateClientConfigHash() if err != nil { @@ -5611,7 +5615,6 @@ func migrationAddRoutingChainMaxDepthColumn(ctx context.Context, db *gorm.DB) er AllowedOrigins: cc.AllowedOrigins, AllowedHeaders: cc.AllowedHeaders, MaxRequestBodySizeMB: cc.MaxRequestBodySizeMB, - EnableLiteLLMFallbacks: cc.EnableLiteLLMFallbacks, HideDeletedVirtualKeysInFilters: cc.HideDeletedVirtualKeysInFilters, MCPAgentDepth: cc.MCPAgentDepth, MCPToolExecutionTimeout: cc.MCPToolExecutionTimeout, @@ -5907,7 +5910,6 @@ func migrationAddMultiBudgetTables(ctx context.Context, db *gorm.DB) error { if mg.HasColumn(&tables.TableBudget{}, "provider_config_id") { if err := mg.DropColumn(&tables.TableBudget{}, "provider_config_id"); err != nil { return err - } } return nil @@ -6063,21 +6065,155 @@ func migrationAddWhitelistedRoutesJSONColumn(ctx context.Context, db *gorm.DB) e return fmt.Errorf("failed to add whitelisted_routes_json column: %w", err) } } + return nil }, Rollback: func(tx *gorm.DB) error { tx = tx.WithContext(ctx) migrator := 
tx.Migrator() + if migrator.HasColumn(&tables.TableClientConfig{}, "whitelisted_routes_json") { if err := migrator.DropColumn(&tables.TableClientConfig{}, "whitelisted_routes_json"); err != nil { return fmt.Errorf("failed to drop whitelisted_routes_json column: %w", err) } } + + return nil + }, + }}) + if err := m.Migrate(); err != nil { + return fmt.Errorf("error running whitelisted_routes_json migration: %s", err.Error()) + } + return nil +} + +// migrationReplaceEnableLiteLLMWithCompatColumns replaces the single enable_litellm_fallbacks +// boolean with compat feature columns. If enable_litellm_fallbacks was true, +// only convert_text_to_chat is set to true (preserving the original behavior). +func migrationReplaceEnableLiteLLMWithCompatColumns(ctx context.Context, db *gorm.DB) error { + m := migrator.New(db, migrator.DefaultOptions, []*migrator.Migration{{ + ID: "replace_enable_litellm_with_compat_columns", + Migrate: func(tx *gorm.DB) error { + tx = tx.WithContext(ctx) + mig := tx.Migrator() + + // Add new columns + if !mig.HasColumn(&tables.TableClientConfig{}, "compat_convert_text_to_chat") { + if err := mig.AddColumn(&tables.TableClientConfig{}, "compat_convert_text_to_chat"); err != nil { + return err + } + } + if !mig.HasColumn(&tables.TableClientConfig{}, "compat_convert_chat_to_responses") { + if err := mig.AddColumn(&tables.TableClientConfig{}, "compat_convert_chat_to_responses"); err != nil { + return err + } + } + if !mig.HasColumn(&tables.TableClientConfig{}, "compat_should_drop_params") { + if err := mig.AddColumn(&tables.TableClientConfig{}, "compat_should_drop_params"); err != nil { + return err + } + } + if !mig.HasColumn(&tables.TableClientConfig{}, "compat_should_convert_params") { + if err := mig.AddColumn(&tables.TableClientConfig{}, "compat_should_convert_params"); err != nil { + return err + } + } + + if err := tx.Exec("UPDATE config_client SET compat_should_convert_params = FALSE").Error; err != nil { + return err + } + + // Migrate 
data: if enable_litellm_fallbacks was true, set convert_text_to_chat = true + if mig.HasColumn(&tables.TableClientConfig{}, "enable_litellm_fallbacks") { + if err := tx.Exec("UPDATE config_client SET compat_convert_text_to_chat = enable_litellm_fallbacks").Error; err != nil { + return err + } + if err := mig.DropColumn(&tables.TableClientConfig{}, "enable_litellm_fallbacks"); err != nil { + return err + } + } + + return nil + }, + Rollback: func(tx *gorm.DB) error { + tx = tx.WithContext(ctx) + mig := tx.Migrator() + if !tx.Migrator().HasColumn(&tables.TableClientConfig{}, "enable_litellm_fallbacks") { + if err := tx.Exec("ALTER TABLE config_client ADD COLUMN enable_litellm_fallbacks BOOLEAN DEFAULT FALSE").Error; err != nil { + return err + } + } + if mig.HasColumn(&tables.TableClientConfig{}, "compat_convert_text_to_chat") { + if err := tx.Exec("UPDATE config_client SET enable_litellm_fallbacks = COALESCE(compat_convert_text_to_chat, FALSE)").Error; err != nil { + return err + } + } + for _, col := range []string{ + "compat_convert_text_to_chat", + "compat_convert_chat_to_responses", + "compat_should_drop_params", + "compat_should_convert_params", + } { + if mig.HasColumn(&tables.TableClientConfig{}, col) { + if err := mig.DropColumn(&tables.TableClientConfig{}, col); err != nil { + return err + } + } + } + return nil + }, + }}) + if err := m.Migrate(); err != nil { + return fmt.Errorf("error while running replace_enable_litellm_with_compat_columns migration: %s", err.Error()) + } + return nil +} + +// migrationDefaultCompatShouldConvertParamsFalse ensures existing deployments +// converge to the new default for compat_should_convert_params. The earlier +// compat migration may already be marked as applied, so changing its body is not +// sufficient for installed databases.
+func migrationDefaultCompatShouldConvertParamsFalse(ctx context.Context, db *gorm.DB) error { + m := migrator.New(db, migrator.DefaultOptions, []*migrator.Migration{{ + ID: "default_compat_should_convert_params_false", + Migrate: func(tx *gorm.DB) error { + tx = tx.WithContext(ctx) + mig := tx.Migrator() + + if !mig.HasColumn(&tables.TableClientConfig{}, "compat_should_convert_params") { + return nil + } + + if err := tx.Exec("UPDATE config_client SET compat_should_convert_params = FALSE").Error; err != nil { + return err + } + + if err := mig.AlterColumn(&tables.TableClientConfig{}, "CompatShouldConvertParams"); err != nil { + return fmt.Errorf("failed to alter compat_should_convert_params default: %w", err) + } + + return nil + }, + Rollback: func(tx *gorm.DB) error { + tx = tx.WithContext(ctx) + mig := tx.Migrator() + + if !mig.HasColumn(&tables.TableClientConfig{}, "compat_should_convert_params") { + return nil + } + + switch tx.Dialector.Name() { + case "postgres": + if err := tx.Exec("ALTER TABLE config_client ALTER COLUMN compat_should_convert_params SET DEFAULT FALSE").Error; err != nil { + return err + } + } + return nil }, }}) if err := m.Migrate(); err != nil { - return fmt.Errorf("error running add_whitelisted_routes_json_column migration: %s", err.Error()) + return fmt.Errorf("error running default_compat_should_convert_params_false migration: %s", err.Error()) } return nil } diff --git a/framework/configstore/rdb.go b/framework/configstore/rdb.go index c2740a594f..a21a5a4e62 100644 --- a/framework/configstore/rdb.go +++ b/framework/configstore/rdb.go @@ -137,7 +137,10 @@ func (s *RDBConfigStore) UpdateClientConfig(ctx context.Context, config *ClientC AllowedOrigins: config.AllowedOrigins, AllowedHeaders: config.AllowedHeaders, MaxRequestBodySizeMB: config.MaxRequestBodySizeMB, - EnableLiteLLMFallbacks: config.EnableLiteLLMFallbacks, + CompatConvertTextToChat: config.Compat.ConvertTextToChat, + CompatConvertChatToResponses: 
config.Compat.ConvertChatToResponses, + CompatShouldDropParams: config.Compat.ShouldDropParams, + CompatShouldConvertParams: config.Compat.ShouldConvertParams, MCPAgentDepth: config.MCPAgentDepth, MCPToolExecutionTimeout: config.MCPToolExecutionTimeout, MCPCodeModeBindingLevel: config.MCPCodeModeBindingLevel, @@ -289,21 +292,26 @@ func (s *RDBConfigStore) GetClientConfig(ctx context.Context) (*ClientConfig, er return nil, err } return &ClientConfig{ - DropExcessRequests: dbConfig.DropExcessRequests, - InitialPoolSize: dbConfig.InitialPoolSize, - PrometheusLabels: dbConfig.PrometheusLabels, - EnableLogging: dbConfig.EnableLogging, - DisableContentLogging: dbConfig.DisableContentLogging, - DisableDBPingsInHealth: dbConfig.DisableDBPingsInHealth, - LogRetentionDays: dbConfig.LogRetentionDays, - EnforceAuthOnInference: dbConfig.EnforceAuthOnInference, - EnforceGovernanceHeader: dbConfig.EnforceGovernanceHeader, - EnforceSCIMAuth: dbConfig.EnforceSCIMAuth, - AllowDirectKeys: dbConfig.AllowDirectKeys, - AllowedOrigins: dbConfig.AllowedOrigins, - AllowedHeaders: dbConfig.AllowedHeaders, - MaxRequestBodySizeMB: dbConfig.MaxRequestBodySizeMB, - EnableLiteLLMFallbacks: dbConfig.EnableLiteLLMFallbacks, + DropExcessRequests: dbConfig.DropExcessRequests, + InitialPoolSize: dbConfig.InitialPoolSize, + PrometheusLabels: dbConfig.PrometheusLabels, + EnableLogging: dbConfig.EnableLogging, + DisableContentLogging: dbConfig.DisableContentLogging, + DisableDBPingsInHealth: dbConfig.DisableDBPingsInHealth, + LogRetentionDays: dbConfig.LogRetentionDays, + EnforceAuthOnInference: dbConfig.EnforceAuthOnInference, + EnforceGovernanceHeader: dbConfig.EnforceGovernanceHeader, + EnforceSCIMAuth: dbConfig.EnforceSCIMAuth, + AllowDirectKeys: dbConfig.AllowDirectKeys, + AllowedOrigins: dbConfig.AllowedOrigins, + AllowedHeaders: dbConfig.AllowedHeaders, + MaxRequestBodySizeMB: dbConfig.MaxRequestBodySizeMB, + Compat: CompatConfig{ + ConvertTextToChat: dbConfig.CompatConvertTextToChat, + 
ConvertChatToResponses: dbConfig.CompatConvertChatToResponses, + ShouldDropParams: dbConfig.CompatShouldDropParams, + ShouldConvertParams: dbConfig.CompatShouldConvertParams, + }, MCPAgentDepth: dbConfig.MCPAgentDepth, MCPToolExecutionTimeout: dbConfig.MCPToolExecutionTimeout, MCPCodeModeBindingLevel: dbConfig.MCPCodeModeBindingLevel, @@ -4481,4 +4489,4 @@ func (s *RDBConfigStore) TransferOauthUserTokensFromGatewaySession(ctx context.C } s.logger.Debug("[rdb] TransferOauthUserTokensFromGatewaySession done: rows_affected=%d", result.RowsAffected) return nil -} +} \ No newline at end of file diff --git a/framework/configstore/tables/clientconfig.go b/framework/configstore/tables/clientconfig.go index a9ff7fc7f6..7dafc96f8e 100644 --- a/framework/configstore/tables/clientconfig.go +++ b/framework/configstore/tables/clientconfig.go @@ -37,8 +37,11 @@ type TableClientConfig struct { RoutingChainMaxDepth int `gorm:"default:10" json:"routing_chain_max_depth"` // Maximum depth for routing rule chain evaluation (default: 10) WhitelistedRoutesJSON string `gorm:"type:text" json:"-"` // JSON serialized []string - // LiteLLM fallback flag - EnableLiteLLMFallbacks bool `gorm:"column:enable_litellm_fallbacks;default:false" json:"enable_litellm_fallbacks"` + // Compat plugin feature flags + CompatConvertTextToChat bool `gorm:"column:compat_convert_text_to_chat;default:false" json:"-"` + CompatConvertChatToResponses bool `gorm:"column:compat_convert_chat_to_responses;default:false" json:"-"` + CompatShouldDropParams bool `gorm:"column:compat_should_drop_params;default:false" json:"-"` + CompatShouldConvertParams bool `gorm:"column:compat_should_convert_params;default:false" json:"-"` // Config hash is used to detect the changes synced from config.json file // Every time we sync the config.json file, we will update the config hash diff --git a/framework/go.sum b/framework/go.sum index ed13db0550..e75ab122ec 100644 --- a/framework/go.sum +++ b/framework/go.sum @@ -22,6 +22,7 @@ 
github.com/apapsch/go-jsonmerge/v2 v2.0.0/go.mod h1:lvDnEdqiQrp0O42VQGgmlKpxL1AP github.com/aws/aws-sdk-go-v2 v1.41.3 h1:4kQ/fa22KjDt13QCy1+bYADvdgcxpfH18f0zP542kZA= github.com/aws/aws-sdk-go-v2 v1.41.3/go.mod h1:mwsPRE8ceUUpiTgF7QmQIJ7lgsKUPQOUl3o72QBrE1o= github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.8 h1:eBMB84YGghSocM7PsjmmPffTa+1FBUeNvGvFou6V/4o= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.8/go.mod h1:lyw7GFp3qENLh7kwzf7iMzAxDn+NzjXEAGjKS2UOKqI= github.com/aws/aws-sdk-go-v2/config v1.32.11 h1:ftxI5sgz8jZkckuUHXfC/wMUc8u3fG1vQS0plr2F2Zs= github.com/aws/aws-sdk-go-v2/config v1.32.11/go.mod h1:twF11+6ps9aNRKEDimksp923o44w/Thk9+8YIlzWMmo= github.com/aws/aws-sdk-go-v2/credentials v1.19.11 h1:NdV8cwCcAXrCWyxArt58BrvZJ9pZ9Fhf9w6Uh5W3Uyc= @@ -265,10 +266,15 @@ go.mongodb.org/mongo-driver v1.17.6/go.mod h1:Hy04i7O2kC4RS06ZrhPRqj/u4DTYkFDAAc go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64= go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y= go.opentelemetry.io/otel v1.43.0 h1:mYIM03dnh5zfN7HautFE4ieIig9amkNANT+xcVxAj9I= +go.opentelemetry.io/otel v1.43.0/go.mod h1:JuG+u74mvjvcm8vj8pI5XiHy1zDeoCS2LB1spIq7Ay0= go.opentelemetry.io/otel/metric v1.43.0 h1:d7638QeInOnuwOONPp4JAOGfbCEpYb+K6DVWvdxGzgM= +go.opentelemetry.io/otel/metric v1.43.0/go.mod h1:RDnPtIxvqlgO8GRW18W6Z/4P462ldprJtfxHxyKd2PY= go.opentelemetry.io/otel/sdk v1.43.0 h1:pi5mE86i5rTeLXqoF/hhiBtUNcrAGHLKQdhg4h4V9Dg= +go.opentelemetry.io/otel/sdk v1.43.0/go.mod h1:P+IkVU3iWukmiit/Yf9AWvpyRDlUeBaRg6Y+C58QHzg= go.opentelemetry.io/otel/sdk/metric v1.43.0 h1:S88dyqXjJkuBNLeMcVPRFXpRw2fuwdvfCGLEo89fDkw= +go.opentelemetry.io/otel/sdk/metric v1.43.0/go.mod h1:C/RJtwSEJ5hzTiUz5pXF1kILHStzb9zFlIEe85bhj6A= go.opentelemetry.io/otel/trace v1.43.0 h1:BkNrHpup+4k4w+ZZ86CZoHHEkohws8AY+WTX09nk+3A= +go.opentelemetry.io/otel/trace v1.43.0/go.mod h1:/QJhyVBUUswCphDVxq+8mld+AvhXZLhe+8WVFxiFff0= go.starlark.net 
v0.0.0-20260102030733-3fee463870c9 h1:nV1OyvU+0CYrp5eKfQ3rD03TpFYYhH08z31NK1HmtTk= go.starlark.net v0.0.0-20260102030733-3fee463870c9/go.mod h1:YKMCv9b1WrfWmeqdV5MAuEHWsu5iC+fe6kYl2sQjdI8= go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc= @@ -292,9 +298,13 @@ golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8= golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA= gonum.org/v1/gonum v0.17.0 h1:VbpOemQlsSMrYmn7T2OUvQ4dqxQXU+ouZFQsZOx50z4= +gonum.org/v1/gonum v0.17.0/go.mod h1:El3tOrEuMpv2UdMrbNlKEh9vd86bmQ6vqIcDwxEOc1E= google.golang.org/genproto/googleapis/api v0.0.0-20260401024825-9d38bb4040a9 h1:VPWxll4HlMw1Vs/qXtN7BvhZqsS9cdAittCNvVENElA= +google.golang.org/genproto/googleapis/api v0.0.0-20260401024825-9d38bb4040a9/go.mod h1:7QBABkRtR8z+TEnmXTqIqwJLlzrZKVfAUm7tY3yGv0M= google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9 h1:m8qni9SQFH0tJc1X0vmnpw/0t+AImlSvp30sEupozUg= +google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9/go.mod h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8= google.golang.org/grpc v1.80.0 h1:Xr6m2WmWZLETvUNvIUmeD5OAagMw3FiKmMlTdViWsHM= +google.golang.org/grpc v1.80.0/go.mod h1:ho/dLnxwi3EDJA4Zghp7k2Ec1+c2jqup0bFkw07bwF4= google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE= google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/framework/modelcatalog/main.go b/framework/modelcatalog/main.go index 2122f339d5..0125118ea2 100644 --- a/framework/modelcatalog/main.go +++ b/framework/modelcatalog/main.go @@ -6,7 +6,6 @@ import ( "encoding/json" "fmt" "slices" - "strings" "sync" "time" @@ -323,649 +322,12 @@ func (mc *ModelCatalog) getPricingURL() string { return 
mc.pricingURL } -// getPricingSyncInterval returns a copy of the pricing sync interval under mutex protection -func (mc *ModelCatalog) getPricingSyncInterval() time.Duration { - mc.pricingMu.RLock() - defer mc.pricingMu.RUnlock() - return mc.pricingSyncInterval -} - -// GetPricingEntryForModel returns the pricing data -func (mc *ModelCatalog) GetPricingEntryForModel(model string, provider schemas.ModelProvider) *PricingEntry { - mc.mu.RLock() - defer mc.mu.RUnlock() - // Check all modes - for _, mode := range []schemas.RequestType{ - schemas.TextCompletionRequest, - schemas.ChatCompletionRequest, - schemas.ResponsesRequest, - schemas.EmbeddingRequest, - schemas.RerankRequest, - schemas.SpeechRequest, - schemas.TranscriptionRequest, - schemas.ImageGenerationRequest, - schemas.ImageEditRequest, - schemas.ImageVariationRequest, - schemas.VideoGenerationRequest, - } { - key := makeKey(model, string(provider), normalizeRequestType(mode)) - pricing, ok := mc.pricingData[key] - if ok { - return convertTableModelPricingToPricingData(&pricing) - } - } - return nil -} - -// GetModelCapabilityEntryForModel returns capability metadata for a model/provider pair. -// It prefers chat, then responses, then text-completion entries; if none exist, -// it falls back to the lexicographically first available mode for deterministic behavior. 
-func (mc *ModelCatalog) GetModelCapabilityEntryForModel(model string, provider schemas.ModelProvider) *PricingEntry { - mc.mu.RLock() - defer mc.mu.RUnlock() - - if entry := mc.getCapabilityEntryForExactModelUnsafe(model, provider); entry != nil { - return entry - } - - baseModel := mc.getBaseModelNameUnsafe(model) - if baseModel != model { - if entry := mc.getCapabilityEntryForExactModelUnsafe(baseModel, provider); entry != nil { - return entry - } - } - - if entry := mc.getCapabilityEntryForModelFamilyUnsafe(baseModel, provider); entry != nil { - return entry - } - - return nil -} - -func (mc *ModelCatalog) getCapabilityEntryForExactModelUnsafe(model string, provider schemas.ModelProvider) *PricingEntry { - preferredModes := []schemas.RequestType{ - schemas.ChatCompletionRequest, - schemas.ResponsesRequest, - schemas.TextCompletionRequest, - } - - for _, mode := range preferredModes { - key := makeKey(model, string(provider), normalizeRequestType(mode)) - pricing, ok := mc.pricingData[key] - if ok { - return convertTableModelPricingToPricingData(&pricing) - } - } - - prefix := model + "|" + string(provider) + "|" - matchingKeys := make([]string, 0) - for key := range mc.pricingData { - if strings.HasPrefix(key, prefix) { - matchingKeys = append(matchingKeys, key) - } - } - return mc.selectCapabilityEntryFromKeysUnsafe(matchingKeys) -} - -func (mc *ModelCatalog) getCapabilityEntryForModelFamilyUnsafe(baseModel string, provider schemas.ModelProvider) *PricingEntry { - if baseModel == "" { - return nil - } - - matchingKeys := make([]string, 0) - for key, pricing := range mc.pricingData { - if normalizeProvider(pricing.Provider) != string(provider) { - continue - } - if mc.getBaseModelNameUnsafe(pricing.Model) != baseModel { - continue - } - matchingKeys = append(matchingKeys, key) - } - return mc.selectCapabilityEntryFromKeysUnsafe(matchingKeys) -} - -func (mc *ModelCatalog) selectCapabilityEntryFromKeysUnsafe(matchingKeys []string) *PricingEntry { - if 
len(matchingKeys) == 0 { - return nil - } - - preferredModes := []string{ - normalizeRequestType(schemas.ChatCompletionRequest), - normalizeRequestType(schemas.ResponsesRequest), - normalizeRequestType(schemas.TextCompletionRequest), - } - - for _, mode := range preferredModes { - modeMatches := make([]string, 0) - for _, key := range matchingKeys { - parts := strings.SplitN(key, "|", 3) - if len(parts) != 3 || parts[2] != mode { - continue - } - modeMatches = append(modeMatches, key) - } - if len(modeMatches) == 0 { - continue - } - slices.Sort(modeMatches) - pricing := mc.pricingData[modeMatches[0]] - return convertTableModelPricingToPricingData(&pricing) - } - - slices.Sort(matchingKeys) - pricing := mc.pricingData[matchingKeys[0]] - return convertTableModelPricingToPricingData(&pricing) -} - -// GetModelsForProvider returns all available models for a given provider (thread-safe) -func (mc *ModelCatalog) GetModelsForProvider(provider schemas.ModelProvider) []string { - mc.mu.RLock() - defer mc.mu.RUnlock() - - models, exists := mc.modelPool[provider] - if !exists { - return []string{} - } - - // Return a copy to prevent external modification - result := make([]string, len(models)) - copy(result, models) - return result -} - -// GetUnfilteredModelsForProvider returns all available models for a given provider (thread-safe) -func (mc *ModelCatalog) GetUnfilteredModelsForProvider(provider schemas.ModelProvider) []string { - mc.mu.RLock() - defer mc.mu.RUnlock() - - models, exists := mc.unfilteredModelPool[provider] - if !exists { - return []string{} - } - - // Return a copy to prevent external modification - result := make([]string, len(models)) - copy(result, models) - return result -} - -// GetDistinctBaseModelNames returns all unique base model names from the catalog (thread-safe). -// This is used for governance model selection when no specific provider is chosen. 
-func (mc *ModelCatalog) GetDistinctBaseModelNames() []string { - mc.mu.RLock() - defer mc.mu.RUnlock() - - seen := make(map[string]bool) - for _, baseName := range mc.baseModelIndex { - seen[baseName] = true - } - - result := make([]string, 0, len(seen)) - for name := range seen { - result = append(result, name) - } - return result -} - -// GetProvidersForModel returns all providers for a given model (thread-safe) -func (mc *ModelCatalog) GetProvidersForModel(model string) []schemas.ModelProvider { - mc.mu.RLock() - defer mc.mu.RUnlock() - - providers := make([]schemas.ModelProvider, 0) - for provider, models := range mc.modelPool { - isModelMatch := false - for _, m := range models { - if m == model || mc.getBaseModelNameUnsafe(m) == mc.getBaseModelNameUnsafe(model) { - isModelMatch = true - break - } - } - if isModelMatch { - providers = append(providers, provider) - } - } - - // Handler special provider cases - // 1. Handler openrouter models - if !slices.Contains(providers, schemas.OpenRouter) { - for _, provider := range providers { - if openRouterModels, ok := mc.modelPool[schemas.OpenRouter]; ok { - if slices.Contains(openRouterModels, string(provider)+"/"+model) { - providers = append(providers, schemas.OpenRouter) - } - } - } - } - - // 2. Handle vertex models - if !slices.Contains(providers, schemas.Vertex) { - for _, provider := range providers { - if vertexModels, ok := mc.modelPool[schemas.Vertex]; ok { - if slices.Contains(vertexModels, string(provider)+"/"+model) { - providers = append(providers, schemas.Vertex) - } - } - } - } - - // 3. Handle openai models for groq - if !slices.Contains(providers, schemas.Groq) && strings.Contains(model, "gpt-") { - if groqModels, ok := mc.modelPool[schemas.Groq]; ok { - if slices.Contains(groqModels, "openai/"+model) { - providers = append(providers, schemas.Groq) - } - } - } - - // 4. 
Handle anthropic models for bedrock - if !slices.Contains(providers, schemas.Bedrock) && strings.Contains(model, "claude") { - if bedrockModels, ok := mc.modelPool[schemas.Bedrock]; ok { - for _, bedrockModel := range bedrockModels { - if strings.Contains(bedrockModel, model) { - providers = append(providers, schemas.Bedrock) - break - } - } - } - } - - return providers -} - -// IsModelAllowedForProvider checks if a model is allowed for a specific provider -// based on the allowed models list and catalog data. It handles all cross-provider -// logic including provider-prefixed models and special routing rules. -// -// Parameters: -// - provider: The provider to check against -// - model: The model name (without provider prefix, e.g., "gpt-4o" or "claude-3-5-sonnet") -// - allowedModels: List of allowed model names (can be empty, can include provider prefixes) -// -// Behavior: -// - If allowedModels is ["*"]: Uses model catalog to check if provider supports the model -// (delegates to GetProvidersForModel which handles all cross-provider logic) -// - If allowedModels is empty ([]): Deny-by-default — returns false for any provider/model pair -// - If allowedModels is not empty: Checks if model matches any entry in the list -// Provider-specific validation: -// - Direct matches: "gpt-4o" in allowedModels for any provider -// - Prefixed matches: Only if the prefixed model exists in provider's catalog -// (e.g., "openai/gpt-4o" in allowedModels only matches if openrouter's catalog -// contains "openai/gpt-4o" AND the model part matches the request) -// -// Returns: -// - bool: true if the model is allowed for the provider, false otherwise -// -// Examples: -// -// // Wildcard allowedModels - uses catalog to check provider support -// mc.IsModelAllowedForProvider("openrouter", "claude-3-5-sonnet", []string{"*"}) -// // Returns: true (catalog knows openrouter has "anthropic/claude-3-5-sonnet") -// -// // Empty allowedModels - deny all (deny-by-default) -// 
mc.IsModelAllowedForProvider("openrouter", "claude-3-5-sonnet", []string{}) -// // Returns: false (no models are permitted) -// -// // Explicit allowedModels with prefix - validates against catalog -// mc.IsModelAllowedForProvider("openrouter", "gpt-4o", []string{"openai/gpt-4o"}) -// // Returns: true (openrouter's catalog contains "openai/gpt-4o" AND model part is "gpt-4o") -// -// // Explicit allowedModels with prefix - wrong model -// mc.IsModelAllowedForProvider("openrouter", "claude-3-5-sonnet", []string{"openai/gpt-4o"}) -// // Returns: false (model part "gpt-4o" doesn't match request "claude-3-5-sonnet") -// -// // Explicit allowedModels without prefix -// mc.IsModelAllowedForProvider("openai", "gpt-4o", []string{"gpt-4o"}) -// // Returns: true (direct match) -func (mc *ModelCatalog) IsModelAllowedForProvider(provider schemas.ModelProvider, model string, allowedModels schemas.WhiteList) bool { - // Case 1: ["*"] = allow all models; use catalog to determine support - // Empty allowedModels = deny all (fail-safe deny-by-default) - if allowedModels.IsUnrestricted() { - supportedProviders := mc.GetProvidersForModel(model) - return slices.Contains(supportedProviders, provider) - } - if allowedModels.IsEmpty() { - return false - } - - // Case 2: Explicit allowedModels = check if model matches any entry - // Get provider's catalog models for validation of prefixed entries - providerCatalogModels := mc.GetModelsForProvider(provider) - - for _, allowedModel := range allowedModels { - // Direct match: "gpt-4o" == "gpt-4o" - if allowedModel == model { - return true - } - - // Provider-prefixed match: verify it exists in provider's catalog first - // This ensures we only allow provider-specific model combinations that are actually supported - if strings.Contains(allowedModel, "/") { - // Check if this exact prefixed model exists in the provider's catalog - // e.g., for openrouter, check if "openai/gpt-4o" is in its catalog - if slices.Contains(providerCatalogModels, 
allowedModel) { - // Extract the model part and compare with request - _, modelPart := schemas.ParseModelString(allowedModel, "") - if modelPart == model { - return true - } - } - } - } - - return false -} - -// GetBaseModelName returns the canonical base model name for a given model string. -// It uses the pre-computed base_model from the pricing catalog when available, -// falling back to algorithmic date/version stripping for models not in the catalog. -// -// Examples: -// -// mc.GetBaseModelName("gpt-4o") // Returns: "gpt-4o" -// mc.GetBaseModelName("openai/gpt-4o") // Returns: "gpt-4o" -// mc.GetBaseModelName("gpt-4o-2024-08-06") // Returns: "gpt-4o" (algorithmic fallback) -func (mc *ModelCatalog) GetBaseModelName(model string) string { - mc.mu.RLock() - defer mc.mu.RUnlock() - return mc.getBaseModelNameUnsafe(model) -} - -// getBaseModelNameUnsafe returns the canonical base model name for a given model string without locking. -// This is used to avoid locking overhead when getting the base model name for many models. -// Make sure the caller function is holding the read lock before calling this function. -// It is not safe to use this function when the model pool is being updated. -func (mc *ModelCatalog) getBaseModelNameUnsafe(model string) string { - // Step 1: Direct lookup in base model index - if base, ok := mc.baseModelIndex[model]; ok { - return base - } - - // Step 2: Strip provider prefix and try again - _, baseName := schemas.ParseModelString(model, "") - if baseName != model { - if base, ok := mc.baseModelIndex[baseName]; ok { - return base - } - } - - // Step 3: Fallback to algorithmic date/version stripping - // (for models not in the catalog, e.g., user-configured custom models) - return schemas.BaseModelName(baseName) -} - -// IsSameModel checks if two model strings refer to the same underlying model. -// It compares the canonical base model names derived from the pricing catalog -// (or algorithmic fallback for models not in the catalog). 
-// -// Examples: -// -// mc.IsSameModel("gpt-4o", "gpt-4o") // true (direct match) -// mc.IsSameModel("openai/gpt-4o", "gpt-4o") // true (same base model) -// mc.IsSameModel("gpt-4o", "claude-3-5-sonnet") // false (different models) -// mc.IsSameModel("openai/gpt-4o", "anthropic/claude-3-5-sonnet") // false -func (mc *ModelCatalog) IsSameModel(model1, model2 string) bool { - if model1 == model2 { - return true - } - return mc.GetBaseModelName(model1) == mc.GetBaseModelName(model2) -} - -// DeleteModelDataForProvider deletes all model data from the pool for a given provider -func (mc *ModelCatalog) DeleteModelDataForProvider(provider schemas.ModelProvider) { - mc.mu.Lock() - defer mc.mu.Unlock() - - delete(mc.modelPool, provider) - delete(mc.unfilteredModelPool, provider) -} - -// UpsertModelDataForProvider upserts model data for a given provider -func (mc *ModelCatalog) UpsertModelDataForProvider(provider schemas.ModelProvider, modelData *schemas.BifrostListModelsResponse, allowedModels []schemas.Model) { - if modelData == nil { - return - } - mc.mu.Lock() - defer mc.mu.Unlock() - - // Populating models from pricing data for the given provider - // Provider models map - providerModels := []string{} - // Iterate through all pricing data to collect models per provider - for _, pricing := range mc.pricingData { - // Normalize provider before adding to model pool - normalizedProvider := schemas.ModelProvider(normalizeProvider(pricing.Provider)) - // We will only add models for the given provider - if normalizedProvider != provider { - continue - } - // Add model to the provider's model set (using map for deduplication) - if slices.Contains(providerModels, pricing.Model) { - continue - } - providerModels = append(providerModels, pricing.Model) - // Build base model index from pre-computed base_model field - if pricing.BaseModel != "" { - mc.baseModelIndex[pricing.Model] = pricing.BaseModel - } - } - // If modelData is empty, then we allow all models - if 
len(modelData.Data) == 0 && len(allowedModels) == 0 { - mc.modelPool[provider] = providerModels - return - } - // Here we make sure that we still keep the backup for model catalog intact - // So we start with a existing model pool and add the new models from incoming data - finalModelList := make([]string, 0) - seenModels := make(map[string]bool) - // Case where list models failed but we have allowed models from keys - if len(modelData.Data) == 0 && len(allowedModels) > 0 { - for _, allowedModel := range allowedModels { - parsedProvider, parsedModel := schemas.ParseModelString(allowedModel.ID, "") - if parsedProvider != provider { - continue - } - if !seenModels[parsedModel] { - seenModels[parsedModel] = true - finalModelList = append(finalModelList, parsedModel) - } - } - } - for _, model := range modelData.Data { - parsedProvider, parsedModel := schemas.ParseModelString(model.ID, "") - if parsedProvider != provider { - continue - } - if !seenModels[parsedModel] { - seenModels[parsedModel] = true - finalModelList = append(finalModelList, parsedModel) - } - } - - if len(allowedModels) == 0 { - for _, model := range providerModels { - if !seenModels[model] { - seenModels[model] = true - finalModelList = append(finalModelList, model) - } - } - } - mc.modelPool[provider] = finalModelList -} - -// UpsertUnfilteredModelDataForProvider upserts unfiltered model data for a given provider -func (mc *ModelCatalog) UpsertUnfilteredModelDataForProvider(provider schemas.ModelProvider, modelData *schemas.BifrostListModelsResponse) { - if modelData == nil { - return - } - mc.mu.Lock() - defer mc.mu.Unlock() - - // Populating models from pricing data for the given provider - providerModels := []string{} - seenModels := make(map[string]bool) - for _, pricing := range mc.pricingData { - normalizedProvider := schemas.ModelProvider(normalizeProvider(pricing.Provider)) - if normalizedProvider != provider { - continue - } - if !seenModels[pricing.Model] { - seenModels[pricing.Model] = 
true - providerModels = append(providerModels, pricing.Model) - } - } - for _, model := range modelData.Data { - parsedProvider, parsedModel := schemas.ParseModelString(model.ID, "") - if parsedProvider != provider { - continue - } - if !seenModels[parsedModel] { - seenModels[parsedModel] = true - providerModels = append(providerModels, parsedModel) - } - } - mc.unfilteredModelPool[provider] = providerModels -} - -// RefineModelForProvider refines the model for a given provider by performing a lookup -// in mc.modelPool and using schemas.ParseModelString to extract provider and model parts. -// e.g. "gpt-oss-120b" for groq provider -> "openai/gpt-oss-120b" -// -// Behavior: -// - When the provider's catalog (mc.modelPool) yields multiple matching models, returns an error -// - When exactly one match is found, returns the fully-qualified model (provider/model format) -// - When the provider is not handled or no refinement is needed, returns the original model unchanged -func (mc *ModelCatalog) RefineModelForProvider(provider schemas.ModelProvider, model string) (string, error) { - switch provider { - case schemas.Groq: - if strings.Contains(model, "gpt-") { - return "openai/" + model, nil - } - return mc.refineNestedProviderModel(provider, model) - case schemas.Replicate: - return mc.refineNestedProviderModel(provider, model) - } - return model, nil -} - -// refineNestedProviderModel resolves provider-native model slugs such as -// "openai/gpt-5-nano" from a base model request like "gpt-5-nano". -// It only considers catalog entries whose leading segment is a known Bifrost provider, -// so Replicate owner/model identifiers like "meta/llama-3-8b" are left untouched. 
-func (mc *ModelCatalog) refineNestedProviderModel(provider schemas.ModelProvider, model string) (string, error) { - mc.mu.RLock() - models, ok := mc.modelPool[provider] - mc.mu.RUnlock() - if !ok { - return model, nil - } - - candidateModels := make([]string, 0) - seenCandidates := make(map[string]struct{}) - for _, poolModel := range models { - providerPart, modelPart := schemas.ParseModelString(poolModel, "") - if providerPart == "" || model != modelPart { - continue - } - - candidate := string(providerPart) + "/" + modelPart - if _, seen := seenCandidates[candidate]; seen { - continue - } - seenCandidates[candidate] = struct{}{} - candidateModels = append(candidateModels, candidate) - } - - switch len(candidateModels) { - case 0: - return model, nil - case 1: - return candidateModels[0], nil - default: - return "", fmt.Errorf("multiple compatible models found for model %s: %v", model, candidateModels) - } -} - -// SetPricingOverrides replaces the full in-memory pricing override set. -func (mc *ModelCatalog) SetPricingOverrides(rows []configstoreTables.TablePricingOverride) error { - seen := make(map[string]int, len(rows)) - overrides := make([]PricingOverride, 0, len(rows)) - for i := range rows { - o, err := convertTablePricingOverrideToPricingOverride(&rows[i]) - if err != nil { - return err - } - if idx, exists := seen[o.ID]; exists { - overrides[idx] = o // last entry wins for duplicate IDs - } else { - seen[o.ID] = len(overrides) - overrides = append(overrides, o) - } - } - mc.overridesMu.Lock() - mc.rawOverrides = overrides - mc.customPricing = buildCustomPricingData(overrides) - mc.overridesMu.Unlock() - return nil -} - -// UpsertPricingOverrides inserts or replaces one or more pricing overrides in a single -// operation, rebuilding the lookup map only once at the end. 
-func (mc *ModelCatalog) UpsertPricingOverrides(rows ...*configstoreTables.TablePricingOverride) error { - // Deduplicate the input batch by ID (last entry wins) and build the - // incoming set for O(1) lookup when filtering existing rawOverrides. - seenIncoming := make(map[string]int, len(rows)) - overrides := make([]PricingOverride, 0, len(rows)) - for _, row := range rows { - o, err := convertTablePricingOverrideToPricingOverride(row) - if err != nil { - return err - } - if idx, exists := seenIncoming[o.ID]; exists { - overrides[idx] = o // last entry wins for duplicate IDs - } else { - seenIncoming[o.ID] = len(overrides) - overrides = append(overrides, o) - } - } - - mc.overridesMu.Lock() - defer mc.overridesMu.Unlock() - - updated := make([]PricingOverride, 0, len(mc.rawOverrides)+len(overrides)) - for _, o := range mc.rawOverrides { - if _, replacing := seenIncoming[o.ID]; !replacing { - updated = append(updated, o) - } - } - updated = append(updated, overrides...) - mc.rawOverrides = updated - mc.customPricing = buildCustomPricingData(updated) - return nil -} - -// DeletePricingOverride removes a pricing override by ID. -func (mc *ModelCatalog) DeletePricingOverride(id string) { - mc.overridesMu.Lock() - defer mc.overridesMu.Unlock() - - updated := make([]PricingOverride, 0, len(mc.rawOverrides)) - for _, o := range mc.rawOverrides { - if o.ID != id { - updated = append(updated, o) - } - } - mc.rawOverrides = updated - mc.customPricing = buildCustomPricingData(updated) -} - // IsRequestTypeSupported checks if a model supports chat completion. // It checks the supportedResponseTypes index. 
func (mc *ModelCatalog) IsRequestTypeSupported(model string, provider schemas.ModelProvider, requestType schemas.RequestType) bool { mc.mu.RLock() + defer mc.mu.RUnlock() outputs, ok := mc.supportedResponseTypes[model] - mc.mu.RUnlock() return ok && slices.Contains(outputs, string(requestType)) } @@ -1069,4 +431,4 @@ func NewTestCatalog(baseModelIndex map[string]string) *ModelCatalog { supportedParams: make(map[string][]string), done: make(chan struct{}), } -} +} \ No newline at end of file diff --git a/helm-charts/bifrost/templates/_helpers.tpl b/helm-charts/bifrost/templates/_helpers.tpl index 8dc0606658..97e7b7e4f4 100644 --- a/helm-charts/bifrost/templates/_helpers.tpl +++ b/helm-charts/bifrost/templates/_helpers.tpl @@ -227,8 +227,21 @@ false {{- if .Values.bifrost.client.maxRequestBodySizeMb }} {{- $_ := set $client "max_request_body_size_mb" .Values.bifrost.client.maxRequestBodySizeMb }} {{- end }} -{{- if hasKey .Values.bifrost.client "enableLitellmFallbacks" }} -{{- $_ := set $client "enable_litellm_fallbacks" .Values.bifrost.client.enableLitellmFallbacks }} +{{- if .Values.bifrost.client.compat }} +{{- $compat := dict }} +{{- if hasKey .Values.bifrost.client.compat "convertTextToChat" }} +{{- $_ := set $compat "convert_text_to_chat" .Values.bifrost.client.compat.convertTextToChat }} +{{- end }} +{{- if hasKey .Values.bifrost.client.compat "convertChatToResponses" }} +{{- $_ := set $compat "convert_chat_to_responses" .Values.bifrost.client.compat.convertChatToResponses }} +{{- end }} +{{- if hasKey .Values.bifrost.client.compat "shouldDropParams" }} +{{- $_ := set $compat "should_drop_params" .Values.bifrost.client.compat.shouldDropParams }} +{{- end }} +{{- if hasKey .Values.bifrost.client.compat "shouldConvertParams" }} +{{- $_ := set $compat "should_convert_params" .Values.bifrost.client.compat.shouldConvertParams }} +{{- end }} +{{- $_ := set $client "compat" $compat }} {{- end }} {{- if .Values.bifrost.client.prometheusLabels }} {{- $_ := set $client 
"prometheus_labels" .Values.bifrost.client.prometheusLabels }} diff --git a/helm-charts/bifrost/values.schema.json b/helm-charts/bifrost/values.schema.json index 495e9c8e79..62239bae54 100644 --- a/helm-charts/bifrost/values.schema.json +++ b/helm-charts/bifrost/values.schema.json @@ -293,8 +293,15 @@ "type": "integer", "minimum": 1 }, - "enableLitellmFallbacks": { - "type": "boolean" + "compat": { + "type": "object", + "additionalProperties": false, + "properties": { + "convertTextToChat": { "type": "boolean" }, + "convertChatToResponses": { "type": "boolean" }, + "shouldDropParams": { "type": "boolean" }, + "shouldConvertParams": { "type": "boolean" } + } }, "prometheusLabels": { "type": "array", @@ -3163,4 +3170,4 @@ "additionalProperties": false } } -} +} \ No newline at end of file diff --git a/helm-charts/bifrost/values.yaml b/helm-charts/bifrost/values.yaml index 8509859f01..10d1078b0d 100644 --- a/helm-charts/bifrost/values.yaml +++ b/helm-charts/bifrost/values.yaml @@ -188,7 +188,11 @@ bifrost: enforceGovernanceHeader: false allowDirectKeys: false maxRequestBodySizeMb: 100 - enableLitellmFallbacks: false + compat: + convertTextToChat: false + convertChatToResponses: false + shouldDropParams: false + shouldConvertParams: false prometheusLabels: [] # Header filtering configuration for x-bf-eh-* headers forwarded to LLM providers headerFilterConfig: diff --git a/plugins/compat/conversion.go b/plugins/compat/conversion.go index d51ca4d730..176f1b1389 100644 --- a/plugins/compat/conversion.go +++ b/plugins/compat/conversion.go @@ -7,19 +7,19 @@ func applyParameterConversion(req *schemas.BifrostRequest) { if req == nil { return } - normalizeDeveloperRoleForChatRequest(req) -} -func normalizeDeveloperRoleForChatRequest(req *schemas.BifrostRequest) { - if req.ChatRequest == nil { - return + if req.ChatRequest != nil { + normalizeDeveloperRoleForChatRequest(req.ChatRequest) } - if req.ChatRequest.Provider != schemas.Bedrock && req.ChatRequest.Provider != 
schemas.Vertex && req.ChatRequest.Provider != schemas.Gemini { +} + +func normalizeDeveloperRoleForChatRequest(req *schemas.BifrostChatRequest) { + if req.Provider != schemas.Bedrock && req.Provider != schemas.Vertex && req.Provider != schemas.Gemini { return } - for i := range req.ChatRequest.Input { - if req.ChatRequest.Input[i].Role == schemas.ChatMessageRoleDeveloper { - req.ChatRequest.Input[i].Role = schemas.ChatMessageRoleSystem + for i := range req.Input { + if req.Input[i].Role == schemas.ChatMessageRoleDeveloper { + req.Input[i].Role = schemas.ChatMessageRoleSystem } } } diff --git a/plugins/compat/go.sum b/plugins/compat/go.sum index d231e3c8df..2bc5801fc0 100644 --- a/plugins/compat/go.sum +++ b/plugins/compat/go.sum @@ -22,6 +22,7 @@ github.com/apapsch/go-jsonmerge/v2 v2.0.0/go.mod h1:lvDnEdqiQrp0O42VQGgmlKpxL1AP github.com/aws/aws-sdk-go-v2 v1.41.3 h1:4kQ/fa22KjDt13QCy1+bYADvdgcxpfH18f0zP542kZA= github.com/aws/aws-sdk-go-v2 v1.41.3/go.mod h1:mwsPRE8ceUUpiTgF7QmQIJ7lgsKUPQOUl3o72QBrE1o= github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.8 h1:eBMB84YGghSocM7PsjmmPffTa+1FBUeNvGvFou6V/4o= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.8/go.mod h1:lyw7GFp3qENLh7kwzf7iMzAxDn+NzjXEAGjKS2UOKqI= github.com/aws/aws-sdk-go-v2/config v1.32.11 h1:ftxI5sgz8jZkckuUHXfC/wMUc8u3fG1vQS0plr2F2Zs= github.com/aws/aws-sdk-go-v2/config v1.32.11/go.mod h1:twF11+6ps9aNRKEDimksp923o44w/Thk9+8YIlzWMmo= github.com/aws/aws-sdk-go-v2/credentials v1.19.11 h1:NdV8cwCcAXrCWyxArt58BrvZJ9pZ9Fhf9w6Uh5W3Uyc= @@ -267,10 +268,15 @@ go.mongodb.org/mongo-driver v1.17.6/go.mod h1:Hy04i7O2kC4RS06ZrhPRqj/u4DTYkFDAAc go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64= go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y= go.opentelemetry.io/otel v1.43.0 h1:mYIM03dnh5zfN7HautFE4ieIig9amkNANT+xcVxAj9I= +go.opentelemetry.io/otel v1.43.0/go.mod h1:JuG+u74mvjvcm8vj8pI5XiHy1zDeoCS2LB1spIq7Ay0= 
go.opentelemetry.io/otel/metric v1.43.0 h1:d7638QeInOnuwOONPp4JAOGfbCEpYb+K6DVWvdxGzgM= +go.opentelemetry.io/otel/metric v1.43.0/go.mod h1:RDnPtIxvqlgO8GRW18W6Z/4P462ldprJtfxHxyKd2PY= go.opentelemetry.io/otel/sdk v1.43.0 h1:pi5mE86i5rTeLXqoF/hhiBtUNcrAGHLKQdhg4h4V9Dg= +go.opentelemetry.io/otel/sdk v1.43.0/go.mod h1:P+IkVU3iWukmiit/Yf9AWvpyRDlUeBaRg6Y+C58QHzg= go.opentelemetry.io/otel/sdk/metric v1.43.0 h1:S88dyqXjJkuBNLeMcVPRFXpRw2fuwdvfCGLEo89fDkw= +go.opentelemetry.io/otel/sdk/metric v1.43.0/go.mod h1:C/RJtwSEJ5hzTiUz5pXF1kILHStzb9zFlIEe85bhj6A= go.opentelemetry.io/otel/trace v1.43.0 h1:BkNrHpup+4k4w+ZZ86CZoHHEkohws8AY+WTX09nk+3A= +go.opentelemetry.io/otel/trace v1.43.0/go.mod h1:/QJhyVBUUswCphDVxq+8mld+AvhXZLhe+8WVFxiFff0= go.starlark.net v0.0.0-20260102030733-3fee463870c9 h1:nV1OyvU+0CYrp5eKfQ3rD03TpFYYhH08z31NK1HmtTk= go.starlark.net v0.0.0-20260102030733-3fee463870c9/go.mod h1:YKMCv9b1WrfWmeqdV5MAuEHWsu5iC+fe6kYl2sQjdI8= go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc= @@ -294,9 +300,13 @@ golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8= golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA= gonum.org/v1/gonum v0.17.0 h1:VbpOemQlsSMrYmn7T2OUvQ4dqxQXU+ouZFQsZOx50z4= +gonum.org/v1/gonum v0.17.0/go.mod h1:El3tOrEuMpv2UdMrbNlKEh9vd86bmQ6vqIcDwxEOc1E= google.golang.org/genproto/googleapis/api v0.0.0-20260401024825-9d38bb4040a9 h1:VPWxll4HlMw1Vs/qXtN7BvhZqsS9cdAittCNvVENElA= +google.golang.org/genproto/googleapis/api v0.0.0-20260401024825-9d38bb4040a9/go.mod h1:7QBABkRtR8z+TEnmXTqIqwJLlzrZKVfAUm7tY3yGv0M= google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9 h1:m8qni9SQFH0tJc1X0vmnpw/0t+AImlSvp30sEupozUg= +google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9/go.mod h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8= google.golang.org/grpc v1.80.0 
h1:Xr6m2WmWZLETvUNvIUmeD5OAagMw3FiKmMlTdViWsHM= +google.golang.org/grpc v1.80.0/go.mod h1:ho/dLnxwi3EDJA4Zghp7k2Ec1+c2jqup0bFkw07bwF4= google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE= google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/plugins/compat/main.go b/plugins/compat/main.go index 0c64b7b6ca..6c536d6bd6 100644 --- a/plugins/compat/main.go +++ b/plugins/compat/main.go @@ -13,7 +13,15 @@ const PluginName = "compat" // Config defines the configuration for the compat plugin. type Config struct { - Enabled bool `json:"enabled"` + ConvertTextToChat bool `json:"convert_text_to_chat"` + ConvertChatToResponses bool `json:"convert_chat_to_responses"` + ShouldDropParams bool `json:"should_drop_params"` + ShouldConvertParams bool `json:"should_convert_params"` +} + +// IsEnabled returns true if any compat feature is enabled +func (c Config) IsEnabled() bool { + return c.ConvertTextToChat || c.ConvertChatToResponses || c.ShouldDropParams || c.ShouldConvertParams } // CompatPlugin provides LiteLLM-compatible request/response transformations. 
@@ -67,20 +75,34 @@ func (p *CompatPlugin) PreLLMHook(ctx *schemas.BifrostContext, req *schemas.Bifr return req, nil, nil } - // text completion → chat conversion - if (req.RequestType == schemas.TextCompletionRequest || req.RequestType == schemas.TextCompletionStreamRequest) && req.TextCompletionRequest != nil { - p.markForConversion(ctx, req.TextCompletionRequest.Provider, req.TextCompletionRequest.Model, schemas.TextCompletionRequest, schemas.ChatCompletionRequest) + convertTextToChatOverride, convertTextToChatOverrideEnabled := ctx.Value(schemas.BifrostContextKeyCompatConvertTextToChat).(bool) + convertChatToResponsesOverride, convertChatToResponsesOverrideEnabled := ctx.Value(schemas.BifrostContextKeyCompatConvertChatToResponses).(bool) + shouldDropParamsOverride, shouldDropParamsOverrideEnabled := ctx.Value(schemas.BifrostContextKeyCompatShouldDropParams).(bool) + shouldConvertParamsOverride, shouldConvertParamsOverrideEnabled := ctx.Value(schemas.BifrostContextKeyCompatShouldConvertParams).(bool) + + modifiedReq := req + if (shouldDropParamsOverrideEnabled && shouldDropParamsOverride) || (shouldConvertParamsOverrideEnabled && shouldConvertParamsOverride) || p.config.ShouldConvertParams || p.config.ShouldDropParams { + modifiedReq = cloneBifrostReq(req) } + p.droppedParams = nil - // chat completion → responses conversion - if (req.RequestType == schemas.ChatCompletionRequest || req.RequestType == schemas.ChatCompletionStreamRequest) && req.ChatRequest != nil { - p.markForConversion(ctx, req.ChatRequest.Provider, req.ChatRequest.Model, schemas.ChatCompletionRequest, schemas.ResponsesRequest) + // Text completion → chat conversion + if (convertTextToChatOverrideEnabled && convertTextToChatOverride) || p.config.ConvertTextToChat { + if (modifiedReq.RequestType == schemas.TextCompletionRequest || modifiedReq.RequestType == schemas.TextCompletionStreamRequest) && modifiedReq.TextCompletionRequest != nil { + p.markForConversion(ctx,
modifiedReq.TextCompletionRequest.Provider, modifiedReq.TextCompletionRequest.Model, schemas.TextCompletionRequest, schemas.ChatCompletionRequest) + } } - modifiedReq := cloneBifrostReq(req) - p.droppedParams = nil - if p.modelCatalog != nil { - _, model, _ := req.GetRequestFields() + // Chat completion → responses conversion + if (convertChatToResponsesOverrideEnabled && convertChatToResponsesOverride) || p.config.ConvertChatToResponses { + if (modifiedReq.RequestType == schemas.ChatCompletionRequest || modifiedReq.RequestType == schemas.ChatCompletionStreamRequest) && modifiedReq.ChatRequest != nil { + p.markForConversion(ctx, modifiedReq.ChatRequest.Provider, modifiedReq.ChatRequest.Model, schemas.ChatCompletionRequest, schemas.ResponsesRequest) + } + } + + // Compute unsupported parameters to drop based on model catalog allowlist + if ((shouldDropParamsOverrideEnabled && shouldDropParamsOverride) || p.config.ShouldDropParams) && p.modelCatalog != nil { + _, model, _ := modifiedReq.GetRequestFields() if model != "" { if supportedParams := p.modelCatalog.GetSupportedParameters(model); supportedParams != nil { droppedParams := dropUnsupportedParams(modifiedReq, supportedParams) @@ -91,7 +113,9 @@ func (p *CompatPlugin) PreLLMHook(ctx *schemas.BifrostContext, req *schemas.Bifr } } - applyParameterConversion(modifiedReq) + if (shouldConvertParamsOverride && shouldConvertParamsOverrideEnabled) || p.config.ShouldConvertParams { + applyParameterConversion(modifiedReq) + } return modifiedReq, nil, nil } @@ -130,17 +154,16 @@ func (p *CompatPlugin) Cleanup() error { // markForConversion checks if the model supports the current request type; if not, mark for conversion func (p *CompatPlugin) markForConversion(ctx *schemas.BifrostContext, provider schemas.ModelProvider, model string, currentType schemas.RequestType, targetType schemas.RequestType) { - shouldConvert := true - + shouldConvert := false if p.modelCatalog != nil { - if 
p.modelCatalog.IsRequestTypeSupported(model, provider, currentType) { - p.logger.Debug("compat: model %s/%s supports %v, skipping conversion", provider, model, currentType) - shouldConvert = false + if !p.modelCatalog.IsRequestTypeSupported(model, provider, currentType) && p.modelCatalog.IsRequestTypeSupported(model, provider, targetType) { + shouldConvert = true } + } else { + p.logger.Debug("compat: model catalog is nil") } if shouldConvert { ctx.SetValue(schemas.BifrostContextKeyChangeRequestType, targetType) - p.logger.Debug("compat: marked %v for core conversion to %v for model %s", currentType, targetType, model) } -} \ No newline at end of file +} diff --git a/plugins/compat/requestcopy.go b/plugins/compat/requestcopy.go index f7a3937f92..a92e9a81b8 100644 --- a/plugins/compat/requestcopy.go +++ b/plugins/compat/requestcopy.go @@ -1,7 +1,6 @@ package compat import ( - "bytes" "maps" "slices" @@ -15,46 +14,19 @@ func cloneBifrostReq(req *schemas.BifrostRequest) *schemas.BifrostRequest { cloned := *req - if req.TextCompletionRequest != nil { - cloned.TextCompletionRequest = cloneTextCompletionRequest(req.TextCompletionRequest) + if tc := req.TextCompletionRequest; tc != nil && tc.Params != nil { + tcCopy := *tc; tcCopy.Params = cloneTextCompletionParameters(tc.Params); cloned.TextCompletionRequest = &tcCopy } - if req.ChatRequest != nil { - cloned.ChatRequest = cloneChatRequest(req.ChatRequest) + if cc := req.ChatRequest; cc != nil && cc.Params != nil { + ccCopy := *cc; ccCopy.Params = cloneChatParameters(cc.Params); cloned.ChatRequest = &ccCopy } - if req.ResponsesRequest != nil { - cloned.ResponsesRequest = cloneResponsesRequest(req.ResponsesRequest) + if rr := req.ResponsesRequest; rr != nil && rr.Params != nil { + rrCopy := *rr; rrCopy.Params = cloneResponsesParameters(rr.Params); cloned.ResponsesRequest = &rrCopy } return &cloned } -func cloneTextCompletionRequest(req *schemas.BifrostTextCompletionRequest) *schemas.BifrostTextCompletionRequest { - if req == nil { - return nil
- } - - cloned := *req - cloned.Input = cloneTextCompletionInput(req.Input) - cloned.Params = cloneTextCompletionParameters(req.Params) - cloned.Fallbacks = slices.Clone(req.Fallbacks) - cloned.RawRequestBody = bytes.Clone(req.RawRequestBody) - return &cloned -} - -func cloneTextCompletionInput(input *schemas.TextCompletionInput) *schemas.TextCompletionInput { - if input == nil { - return nil - } - cloned := &schemas.TextCompletionInput{ - PromptArray: slices.Clone(input.PromptArray), - } - if input.PromptStr != nil { - prompt := *input.PromptStr - cloned.PromptStr = &prompt - } - return cloned -} - func cloneTextCompletionParameters(params *schemas.TextCompletionParameters) *schemas.TextCompletionParameters { if params == nil { return nil @@ -77,24 +49,6 @@ func cloneTextCompletionParameters(params *schemas.TextCompletionParameters) *sc return &cloned } -func cloneChatRequest(req *schemas.BifrostChatRequest) *schemas.BifrostChatRequest { - if req == nil { - return nil - } - - cloned := *req - if req.Input != nil { - cloned.Input = make([]schemas.ChatMessage, len(req.Input)) - for i, message := range req.Input { - cloned.Input[i] = schemas.DeepCopyChatMessage(message) - } - } - cloned.Params = cloneChatParameters(req.Params) - cloned.Fallbacks = slices.Clone(req.Fallbacks) - cloned.RawRequestBody = bytes.Clone(req.RawRequestBody) - return &cloned -} - func cloneChatParameters(params *schemas.ChatParameters) *schemas.ChatParameters { if params == nil { return nil @@ -201,24 +155,6 @@ func cloneChatWebSearchOptions(options *schemas.ChatWebSearchOptions) *schemas.C return &cloned } -func cloneResponsesRequest(req *schemas.BifrostResponsesRequest) *schemas.BifrostResponsesRequest { - if req == nil { - return nil - } - - cloned := *req - if req.Input != nil { - cloned.Input = make([]schemas.ResponsesMessage, len(req.Input)) - for i, message := range req.Input { - cloned.Input[i] = schemas.DeepCopyResponsesMessage(message) - } - } - cloned.Params = 
cloneResponsesParameters(req.Params) - cloned.Fallbacks = slices.Clone(req.Fallbacks) - cloned.RawRequestBody = bytes.Clone(req.RawRequestBody) - return &cloned -} - func cloneResponsesParameters(params *schemas.ResponsesParameters) *schemas.ResponsesParameters { if params == nil { return nil diff --git a/tests/governance/config.json b/tests/governance/config.json index bd9080a064..b8cedf9a3e 100644 --- a/tests/governance/config.json +++ b/tests/governance/config.json @@ -62,7 +62,6 @@ "enable_logging": true, "enforce_auth_on_inference": true, "allow_direct_keys": false, - "max_request_body_size_mb": 100, - "enable_litellm_fallbacks": false + "max_request_body_size_mb": 100 } } diff --git a/tests/integrations/python/config.json b/tests/integrations/python/config.json index 00b89b5bdb..866469cc1d 100644 --- a/tests/integrations/python/config.json +++ b/tests/integrations/python/config.json @@ -343,7 +343,6 @@ "enable_logging": true, "enforce_auth_on_inference": false, "allow_direct_keys": false, - "max_request_body_size_mb": 100, - "enable_litellm_fallbacks": false + "max_request_body_size_mb": 100 } } diff --git a/tests/integrations/typescript/config.json b/tests/integrations/typescript/config.json index cf49dba281..46bc65af6b 100644 --- a/tests/integrations/typescript/config.json +++ b/tests/integrations/typescript/config.json @@ -220,7 +220,6 @@ "enable_logging": true, "enforce_auth_on_inference": false, "allow_direct_keys": false, - "max_request_body_size_mb": 100, - "enable_litellm_fallbacks": false + "max_request_body_size_mb": 100 } } diff --git a/transports/bifrost-http/handlers/config.go b/transports/bifrost-http/handlers/config.go index ec28c9d3ad..dee74227c7 100644 --- a/transports/bifrost-http/handlers/config.go +++ b/transports/bifrost-http/handlers/config.go @@ -343,21 +343,32 @@ func (h *ConfigHandler) updateConfig(ctx *fasthttp.RequestCtx) { } // Handle compat plugin toggle - if payload.ClientConfig.EnableLiteLLMFallbacks != 
currentConfig.EnableLiteLLMFallbacks { - if payload.ClientConfig.EnableLiteLLMFallbacks { - // Load and register the compat plugin - if err := h.configManager.ReloadPlugin(ctx, compat.PluginName, nil, &compat.Config{Enabled: true}, nil, nil); err != nil { + newCompat := payload.ClientConfig.Compat + oldCompat := currentConfig.Compat + if newCompat != oldCompat { + newEnabled := newCompat.ConvertTextToChat || newCompat.ConvertChatToResponses || newCompat.ShouldDropParams || newCompat.ShouldConvertParams + if newEnabled { + compatCfg := &compat.Config{ + ConvertTextToChat: newCompat.ConvertTextToChat, + ConvertChatToResponses: newCompat.ConvertChatToResponses, + ShouldDropParams: newCompat.ShouldDropParams, + ShouldConvertParams: newCompat.ShouldConvertParams, + } + if err := h.configManager.ReloadPlugin(ctx, compat.PluginName, nil, compatCfg, nil, nil); err != nil { logger.Warn("failed to load compat plugin: %v", err) + SendError(ctx, 400, "Failed to load compat plugin") + return } } else { - // Remove the compat plugin disabledCtx := context.WithValue(ctx, PluginDisabledKey, true) if err := h.configManager.RemovePlugin(disabledCtx, compat.PluginName); err != nil { logger.Warn("failed to remove compat plugin: %v", err) + SendError(ctx, 400, "Failed to remove compat plugin") + return } } } - updatedConfig.EnableLiteLLMFallbacks = payload.ClientConfig.EnableLiteLLMFallbacks + updatedConfig.Compat = newCompat // Only update MCP fields if explicitly provided (non-zero) to avoid clearing stored values if payload.ClientConfig.MCPAgentDepth > 0 { updatedConfig.MCPAgentDepth = payload.ClientConfig.MCPAgentDepth diff --git a/transports/bifrost-http/lib/config.go b/transports/bifrost-http/lib/config.go index c320086bb8..94202c5b30 100644 --- a/transports/bifrost-http/lib/config.go +++ b/transports/bifrost-http/lib/config.go @@ -36,8 +36,8 @@ import ( "github.com/maximhq/bifrost/framework/oauth2" plugins "github.com/maximhq/bifrost/framework/plugins" 
"github.com/maximhq/bifrost/framework/vectorstore" - "github.com/maximhq/bifrost/plugins/governance" "github.com/maximhq/bifrost/plugins/compat" + "github.com/maximhq/bifrost/plugins/governance" "github.com/maximhq/bifrost/plugins/logging" "github.com/maximhq/bifrost/plugins/maxim" "github.com/maximhq/bifrost/plugins/otel" @@ -309,7 +309,6 @@ var DefaultClientConfig = configstore.ClientConfig{ MCPAgentDepth: 10, MCPToolExecutionTimeout: 30, MCPCodeModeBindingLevel: string(schemas.CodeModeBindingLevelServer), - EnableLiteLLMFallbacks: false, HideDeletedVirtualKeysInFilters: false, RoutingChainMaxDepth: governance.DefaultRoutingChainMaxDepth, } @@ -4052,4 +4051,4 @@ func DeepCopy[T any](in T) (T, error) { } err = sonic.Unmarshal(b, &out) return out, err -} +} \ No newline at end of file diff --git a/transports/bifrost-http/lib/config_test.go b/transports/bifrost-http/lib/config_test.go index ce1c2432a8..05cf7ade20 100644 --- a/transports/bifrost-http/lib/config_test.go +++ b/transports/bifrost-http/lib/config_test.go @@ -400,6 +400,7 @@ func (m *MockConfigStore) DB() *gorm.DB { retu func (m *MockConfigStore) ExecuteTransaction(ctx context.Context, fn func(tx *gorm.DB) error) error { return fn(nil) } + func (m *MockConfigStore) RunMigration(ctx context.Context, migration *migrator.Migration) error { return nil } @@ -1130,18 +1131,23 @@ func (m *MockConfigStore) DeleteOauthToken(ctx context.Context, id string) error func (m *MockConfigStore) GetOauthUserSessionByID(ctx context.Context, id string) (*tables.TableOauthUserSession, error) { return nil, nil } + func (m *MockConfigStore) GetOauthUserSessionByState(ctx context.Context, state string) (*tables.TableOauthUserSession, error) { return nil, nil } + func (m *MockConfigStore) ClaimOauthUserSessionByState(ctx context.Context, state string) (*tables.TableOauthUserSession, error) { return nil, nil } + func (m *MockConfigStore) GetOauthUserSessionBySessionToken(ctx context.Context, sessionToken string) 
(*tables.TableOauthUserSession, error) { return nil, nil } + func (m *MockConfigStore) CreateOauthUserSession(ctx context.Context, session *tables.TableOauthUserSession) error { return nil } + func (m *MockConfigStore) UpdateOauthUserSession(ctx context.Context, session *tables.TableOauthUserSession) error { return nil } @@ -1150,18 +1156,23 @@ func (m *MockConfigStore) UpdateOauthUserSession(ctx context.Context, session *t func (m *MockConfigStore) GetOauthUserTokenByIdentity(ctx context.Context, virtualKeyID, userID, sessionToken, mcpClientID string) (*tables.TableOauthUserToken, error) { return nil, nil } + func (m *MockConfigStore) GetOauthUserTokenBySessionToken(ctx context.Context, sessionToken string) (*tables.TableOauthUserToken, error) { return nil, nil } + func (m *MockConfigStore) CreateOauthUserToken(ctx context.Context, token *tables.TableOauthUserToken) error { return nil } + func (m *MockConfigStore) UpdateOauthUserToken(ctx context.Context, token *tables.TableOauthUserToken) error { return nil } + func (m *MockConfigStore) DeleteOauthUserToken(ctx context.Context, id string) error { return nil } + func (m *MockConfigStore) DeleteOauthUserTokensByMCPClient(ctx context.Context, mcpClientID string) error { return nil } @@ -1170,33 +1181,43 @@ func (m *MockConfigStore) DeleteOauthUserTokensByMCPClient(ctx context.Context, func (m *MockConfigStore) GetPerUserOAuthClientByClientID(ctx context.Context, clientID string) (*tables.TablePerUserOAuthClient, error) { return nil, nil } + func (m *MockConfigStore) CreatePerUserOAuthClient(ctx context.Context, client *tables.TablePerUserOAuthClient) error { return nil } + func (m *MockConfigStore) GetPerUserOAuthSessionByAccessToken(ctx context.Context, accessToken string) (*tables.TablePerUserOAuthSession, error) { return nil, nil } + func (m *MockConfigStore) GetPerUserOAuthSessionByID(ctx context.Context, id string) (*tables.TablePerUserOAuthSession, error) { return nil, nil } + func (m *MockConfigStore) 
CreatePerUserOAuthSession(ctx context.Context, session *tables.TablePerUserOAuthSession) error { return nil } + func (m *MockConfigStore) UpdatePerUserOAuthSession(ctx context.Context, session *tables.TablePerUserOAuthSession) error { return nil } + func (m *MockConfigStore) DeletePerUserOAuthSession(ctx context.Context, id string) error { return nil } + func (m *MockConfigStore) GetPerUserOAuthCodeByCode(ctx context.Context, code string) (*tables.TablePerUserOAuthCode, error) { return nil, nil } + func (m *MockConfigStore) ClaimPerUserOAuthCode(ctx context.Context, code string) (*tables.TablePerUserOAuthCode, error) { return nil, nil } + func (m *MockConfigStore) CreatePerUserOAuthCode(ctx context.Context, code *tables.TablePerUserOAuthCode) error { return nil } + func (m *MockConfigStore) UpdatePerUserOAuthCode(ctx context.Context, code *tables.TablePerUserOAuthCode) error { return nil } @@ -1204,24 +1225,31 @@ func (m *MockConfigStore) UpdatePerUserOAuthCode(ctx context.Context, code *tabl func (m *MockConfigStore) GetPerUserOAuthPendingFlow(ctx context.Context, id string) (*tables.TablePerUserOAuthPendingFlow, error) { return nil, nil } + func (m *MockConfigStore) CreatePerUserOAuthPendingFlow(ctx context.Context, flow *tables.TablePerUserOAuthPendingFlow) error { return nil } + func (m *MockConfigStore) UpdatePerUserOAuthPendingFlow(ctx context.Context, flow *tables.TablePerUserOAuthPendingFlow) error { return nil } + func (m *MockConfigStore) DeletePerUserOAuthPendingFlow(ctx context.Context, id string) error { return nil } + func (m *MockConfigStore) ConsumePerUserOAuthPendingFlow(ctx context.Context, id string) (int64, error) { return 1, nil } + func (m *MockConfigStore) GetOauthUserTokensByGatewaySessionID(ctx context.Context, gatewaySessionID string) ([]tables.TableOauthUserToken, error) { return nil, nil } + func (m *MockConfigStore) TransferOauthUserTokensFromGatewaySession(ctx context.Context, gatewaySessionID, realSessionToken, virtualKeyID, userID 
string) error { return nil } + func (m *MockConfigStore) FinalizePerUserOAuthConsent(ctx context.Context, flowID string, session *tables.TablePerUserOAuthSession, code *tables.TablePerUserOAuthCode) (int64, error) { return 1, nil } @@ -1263,12 +1291,15 @@ func (m *MockConfigStore) DeleteRoutingRule(ctx context.Context, id string, tx . func (m *MockConfigStore) GetFolders(ctx context.Context) ([]tables.TableFolder, error) { return nil, nil } + func (m *MockConfigStore) GetFolderByID(ctx context.Context, id string) (*tables.TableFolder, error) { return nil, nil } + func (m *MockConfigStore) CreateFolder(ctx context.Context, folder *tables.TableFolder) error { return nil } + func (m *MockConfigStore) UpdateFolder(ctx context.Context, folder *tables.TableFolder) error { return nil } @@ -1278,12 +1309,15 @@ func (m *MockConfigStore) DeleteFolder(ctx context.Context, id string) error { r func (m *MockConfigStore) GetPrompts(ctx context.Context, folderID *string) ([]tables.TablePrompt, error) { return nil, nil } + func (m *MockConfigStore) GetPromptByID(ctx context.Context, id string) (*tables.TablePrompt, error) { return nil, nil } + func (m *MockConfigStore) CreatePrompt(ctx context.Context, prompt *tables.TablePrompt) error { return nil } + func (m *MockConfigStore) UpdatePrompt(ctx context.Context, prompt *tables.TablePrompt) error { return nil } @@ -1293,15 +1327,19 @@ func (m *MockConfigStore) DeletePrompt(ctx context.Context, id string) error { r func (m *MockConfigStore) GetPromptVersions(ctx context.Context, promptID string) ([]tables.TablePromptVersion, error) { return nil, nil } + func (m *MockConfigStore) GetAllPromptVersions(ctx context.Context) ([]tables.TablePromptVersion, error) { return nil, nil } + func (m *MockConfigStore) GetPromptVersionByID(ctx context.Context, id uint) (*tables.TablePromptVersion, error) { return nil, nil } + func (m *MockConfigStore) GetLatestPromptVersion(ctx context.Context, promptID string) (*tables.TablePromptVersion, error) { 
return nil, nil } + func (m *MockConfigStore) CreatePromptVersion(ctx context.Context, version *tables.TablePromptVersion) error { return nil } @@ -1311,15 +1349,19 @@ func (m *MockConfigStore) DeletePromptVersion(ctx context.Context, id uint) erro func (m *MockConfigStore) GetPromptSessions(ctx context.Context, promptID string) ([]tables.TablePromptSession, error) { return nil, nil } + func (m *MockConfigStore) GetPromptSessionByID(ctx context.Context, id uint) (*tables.TablePromptSession, error) { return nil, nil } + func (m *MockConfigStore) CreatePromptSession(ctx context.Context, session *tables.TablePromptSession) error { return nil } + func (m *MockConfigStore) UpdatePromptSession(ctx context.Context, session *tables.TablePromptSession) error { return nil } + func (m *MockConfigStore) RenamePromptSession(ctx context.Context, id uint, name string) error { return nil } @@ -11981,6 +12023,7 @@ type mockLLMPlugin struct { func (p *mockLLMPlugin) PreLLMHook(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.LLMPluginShortCircuit, error) { return req, nil, nil } + func (p *mockLLMPlugin) PostLLMHook(ctx *schemas.BifrostContext, resp *schemas.BifrostResponse, bifrostErr *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError, error) { return resp, bifrostErr, nil } @@ -12324,7 +12367,6 @@ func TestGenerateClientConfigHash(t *testing.T) { AllowDirectKeys: true, AllowedOrigins: []string{"http://localhost:3000"}, MaxRequestBodySizeMB: 100, - EnableLiteLLMFallbacks: false, } hash1, err := cc1.GenerateClientConfigHash() @@ -12421,12 +12463,12 @@ func TestGenerateClientConfigHash(t *testing.T) { t.Error("Different MaxRequestBodySizeMB should produce different hash") } - // Different EnableLiteLLMFallbacks should produce different hash + // Different Compat should produce different hash cc13 := cc1 - cc13.EnableLiteLLMFallbacks = true + cc13.Compat.ConvertTextToChat = true hash13, _ := 
cc13.GenerateClientConfigHash() if hash1 == hash13 { - t.Error("Different EnableLiteLLMFallbacks should produce different hash") + t.Error("Different Compat.ConvertTextToChat should produce different hash") } // PrometheusLabels order should not matter (sorted) @@ -13459,7 +13501,6 @@ func TestGenerateClientConfigHash_RuntimeVsMigrationParity(t *testing.T) { EnforceAuthOnInference: false, AllowDirectKeys: true, MaxRequestBodySizeMB: 100, - EnableLiteLLMFallbacks: false, } // Generate hash from config @@ -13473,7 +13514,12 @@ func TestGenerateClientConfigHash_RuntimeVsMigrationParity(t *testing.T) { EnforceAuthOnInference: ccToSave.EnforceAuthOnInference, AllowDirectKeys: ccToSave.AllowDirectKeys, MaxRequestBodySizeMB: ccToSave.MaxRequestBodySizeMB, - EnableLiteLLMFallbacks: ccToSave.EnableLiteLLMFallbacks, + Compat: configstore.CompatConfig{ + ConvertTextToChat: ccToSave.CompatConvertTextToChat, + ConvertChatToResponses: ccToSave.CompatConvertChatToResponses, + ShouldDropParams: ccToSave.CompatShouldDropParams, + ShouldConvertParams: ccToSave.CompatShouldConvertParams, + }, } hashBeforeSave, _ := clientConfig.GenerateClientConfigHash() @@ -13492,7 +13538,12 @@ func TestGenerateClientConfigHash_RuntimeVsMigrationParity(t *testing.T) { EnforceAuthOnInference: ccFromDB.EnforceAuthOnInference, AllowDirectKeys: ccFromDB.AllowDirectKeys, MaxRequestBodySizeMB: ccFromDB.MaxRequestBodySizeMB, - EnableLiteLLMFallbacks: ccFromDB.EnableLiteLLMFallbacks, + Compat: configstore.CompatConfig{ + ConvertTextToChat: ccFromDB.CompatConvertTextToChat, + ConvertChatToResponses: ccFromDB.CompatConvertChatToResponses, + ShouldDropParams: ccFromDB.CompatShouldDropParams, + ShouldConvertParams: ccFromDB.CompatShouldConvertParams, + }, } hashAfterLoad, _ := clientConfigFromDB.GenerateClientConfigHash() @@ -15649,13 +15700,13 @@ func TestConfigSchemaSyncTopLevel(t *testing.T) { // Enterprise-only features: These fields exist in the JSON schema for documentation // and validation purposes, but 
are only available in the enterprise version. enterpriseSchemaFields := map[string]bool{ - "$schema": true, - "audit_logs": true, - "cluster_config": true, - "saml_config": true, - "load_balancer_config": true, - "guardrails_config": true, - "large_payload_optimization": true, + "$schema": true, + "audit_logs": true, + "cluster_config": true, + "saml_config": true, + "load_balancer_config": true, + "guardrails_config": true, + "large_payload_optimization": true, } schema := loadJSONSchema(t) @@ -16602,7 +16653,10 @@ func assertDefaultClientConfigValues(t *testing.T, cc configstore.ClientConfig) require.Equal(t, 100, cc.MaxRequestBodySizeMB, "MaxRequestBodySizeMB should default to 100") require.Equal(t, 10, cc.MCPAgentDepth, "MCPAgentDepth should default to 10") require.Equal(t, 30, cc.MCPToolExecutionTimeout, "MCPToolExecutionTimeout should default to 30") - require.Equal(t, false, cc.EnableLiteLLMFallbacks, "EnableLiteLLMFallbacks should default to false") + require.Equal(t, false, cc.Compat.ConvertTextToChat, "Compat.ConvertTextToChat should default to false") + require.Equal(t, false, cc.Compat.ConvertChatToResponses, "Compat.ConvertChatToResponses should default to false") + require.Equal(t, false, cc.Compat.ShouldDropParams, "Compat.ShouldDropParams should default to false") + require.Equal(t, false, cc.Compat.ShouldConvertParams, "Compat.ShouldConvertParams should default to false") require.Equal(t, false, cc.HideDeletedVirtualKeysInFilters, "HideDeletedVirtualKeysInFilters should default to false") } diff --git a/transports/bifrost-http/lib/ctx.go b/transports/bifrost-http/lib/ctx.go index 36c56cc2e5..7f6127406d 100644 --- a/transports/bifrost-http/lib/ctx.go +++ b/transports/bifrost-http/lib/ctx.go @@ -8,6 +8,7 @@ package lib import ( "context" + "encoding/json" "fmt" "strconv" "strings" @@ -443,6 +444,47 @@ func ConvertToBifrostContext(ctx *fasthttp.RequestCtx, allowDirectKeys bool, mat } return true } + + // Compat header: per-request override of compat 
plugin settings. + // Accepts: "true" (enable all), JSON array of feature names, or ["*"] (enable all). + // An empty array [] or absent header means no overrides. + if keyStr == "x-bf-compat" { + bifrostCtx.ClearValue(schemas.BifrostContextKeyCompatConvertTextToChat) + bifrostCtx.ClearValue(schemas.BifrostContextKeyCompatConvertChatToResponses) + bifrostCtx.ClearValue(schemas.BifrostContextKeyCompatShouldDropParams) + bifrostCtx.ClearValue(schemas.BifrostContextKeyCompatShouldConvertParams) + valueStr := strings.TrimSpace(string(value)) + if valueStr == "true" { + bifrostCtx.SetValue(schemas.BifrostContextKeyCompatConvertTextToChat, true) + bifrostCtx.SetValue(schemas.BifrostContextKeyCompatConvertChatToResponses, true) + bifrostCtx.SetValue(schemas.BifrostContextKeyCompatShouldDropParams, true) + bifrostCtx.SetValue(schemas.BifrostContextKeyCompatShouldConvertParams, true) + } else if strings.HasPrefix(valueStr, "[") { + var features []string + if err := json.Unmarshal([]byte(valueStr), &features); err == nil { + if len(features) == 1 && features[0] == "*" { + bifrostCtx.SetValue(schemas.BifrostContextKeyCompatConvertTextToChat, true) + bifrostCtx.SetValue(schemas.BifrostContextKeyCompatConvertChatToResponses, true) + bifrostCtx.SetValue(schemas.BifrostContextKeyCompatShouldDropParams, true) + bifrostCtx.SetValue(schemas.BifrostContextKeyCompatShouldConvertParams, true) + } else { + for _, f := range features { + switch f { + case "convert_text_to_chat": + bifrostCtx.SetValue(schemas.BifrostContextKeyCompatConvertTextToChat, true) + case "convert_chat_to_responses": + bifrostCtx.SetValue(schemas.BifrostContextKeyCompatConvertChatToResponses, true) + case "should_drop_params": + bifrostCtx.SetValue(schemas.BifrostContextKeyCompatShouldDropParams, true) + case "should_convert_params": + bifrostCtx.SetValue(schemas.BifrostContextKeyCompatShouldConvertParams, true) + } + } + } + } + } + return true + } return true }) @@ -568,4 +610,4 @@ func 
BuildHTTPRequestFromFastHTTP(ctx *fasthttp.RequestCtx) *schemas.HTTPRequest // Note: Body not copied - for streaming, body was already consumed return req -} +} \ No newline at end of file diff --git a/transports/bifrost-http/server/plugins.go b/transports/bifrost-http/server/plugins.go index 39e07cc201..031dadd73b 100644 --- a/transports/bifrost-http/server/plugins.go +++ b/transports/bifrost-http/server/plugins.go @@ -215,13 +215,15 @@ func (s *BifrostHTTPServer) loadBuiltinPlugins(ctx context.Context) error { } s.Config.SetPluginOrderInfo(semanticcache.PluginName, builtinPlacement, schemas.Ptr(6)) - // 7. Compat (if configured in PluginConfigs) - compatConfig := s.getPluginConfig(compat.PluginName) - if compatConfig != nil && compatConfig.Enabled { - s.registerPluginWithStatus(ctx, compat.PluginName, nil, compatConfig.Config, false) - } else { - s.markPluginDisabled(compat.PluginName) + // 7. Compat (if any compat feature is enabled in ClientConfig) + cc := s.Config.ClientConfig.Compat + compatCfg := &compat.Config{ + ConvertTextToChat: cc.ConvertTextToChat, + ConvertChatToResponses: cc.ConvertChatToResponses, + ShouldDropParams: cc.ShouldDropParams, + ShouldConvertParams: cc.ShouldConvertParams, } + s.registerPluginWithStatus(ctx, compat.PluginName, nil, compatCfg, false) s.Config.SetPluginOrderInfo(compat.PluginName, builtinPlacement, schemas.Ptr(7)) // 8. 
Maxim (if configured in PluginConfigs) diff --git a/transports/config.schema.json b/transports/config.schema.json index 1e5070d76f..ebed8a4389 100644 --- a/transports/config.schema.json +++ b/transports/config.schema.json @@ -94,9 +94,29 @@ "minimum": 1, "description": "Maximum request body size in MB" }, - "enable_litellm_fallbacks": { - "type": "boolean", - "description": "Enable litellm-specific fallbacks for text completion for Groq" + "compat": { + "type": "object", + "description": "Compat plugin configuration for request type conversion, parameter dropping, and parameter value conversion", + "properties": { + "convert_text_to_chat": { + "type": "boolean", + "description": "Convert text completion requests to chat for models that only support chat" + }, + "convert_chat_to_responses": { + "type": "boolean", + "description": "Convert chat completion requests to responses for models that only support responses" + }, + "should_drop_params": { + "type": "boolean", + "description": "Drop unsupported parameters based on model catalog allowlist" + }, + "should_convert_params": { + "type": "boolean", + "description": "Converts model parameter values that are not supported by the model.", + "default": false + } + }, + "additionalProperties": false }, "header_filter_config": { "type": "object", @@ -4129,4 +4149,4 @@ "additionalProperties": false } } -} \ No newline at end of file +} diff --git a/transports/go.mod b/transports/go.mod index 3a40f0d309..76c6a7fb0a 100644 --- a/transports/go.mod +++ b/transports/go.mod @@ -15,13 +15,13 @@ require ( github.com/maximhq/bifrost/core v1.5.1 github.com/maximhq/bifrost/framework v1.3.1 github.com/maximhq/bifrost/plugins/compat v0.1.0 - github.com/maximhq/bifrost/plugins/governance v1.5.0 - github.com/maximhq/bifrost/plugins/logging v1.5.0 - github.com/maximhq/bifrost/plugins/maxim v1.6.0 - github.com/maximhq/bifrost/plugins/otel v1.2.0 + github.com/maximhq/bifrost/plugins/governance v1.5.1 + 
github.com/maximhq/bifrost/plugins/logging v1.5.1 + github.com/maximhq/bifrost/plugins/maxim v1.6.1 + github.com/maximhq/bifrost/plugins/otel v1.2.1 github.com/maximhq/bifrost/plugins/prompts v1.0.1 - github.com/maximhq/bifrost/plugins/semanticcache v1.5.0 - github.com/maximhq/bifrost/plugins/telemetry v1.5.0 + github.com/maximhq/bifrost/plugins/semanticcache v1.5.1 + github.com/maximhq/bifrost/plugins/telemetry v1.5.1 github.com/pion/rtcp v1.2.16 github.com/pion/webrtc/v4 v4.2.9 github.com/prometheus/client_golang v1.23.2 @@ -185,5 +185,3 @@ require ( gopkg.in/yaml.v3 v3.0.1 // indirect gorm.io/driver/postgres v1.6.0 // indirect ) - -replace github.com/maximhq/bifrost/plugins/compat => ../plugins/compat diff --git a/transports/go.sum b/transports/go.sum index c8b2bde0ec..096621cfdb 100644 --- a/transports/go.sum +++ b/transports/go.sum @@ -26,6 +26,7 @@ github.com/apapsch/go-jsonmerge/v2 v2.0.0/go.mod h1:lvDnEdqiQrp0O42VQGgmlKpxL1AP github.com/aws/aws-sdk-go-v2 v1.41.3 h1:4kQ/fa22KjDt13QCy1+bYADvdgcxpfH18f0zP542kZA= github.com/aws/aws-sdk-go-v2 v1.41.3/go.mod h1:mwsPRE8ceUUpiTgF7QmQIJ7lgsKUPQOUl3o72QBrE1o= github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.8 h1:eBMB84YGghSocM7PsjmmPffTa+1FBUeNvGvFou6V/4o= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.8/go.mod h1:lyw7GFp3qENLh7kwzf7iMzAxDn+NzjXEAGjKS2UOKqI= github.com/aws/aws-sdk-go-v2/config v1.32.11 h1:ftxI5sgz8jZkckuUHXfC/wMUc8u3fG1vQS0plr2F2Zs= github.com/aws/aws-sdk-go-v2/config v1.32.11/go.mod h1:twF11+6ps9aNRKEDimksp923o44w/Thk9+8YIlzWMmo= github.com/aws/aws-sdk-go-v2/credentials v1.19.11 h1:NdV8cwCcAXrCWyxArt58BrvZJ9pZ9Fhf9w6Uh5W3Uyc= @@ -167,6 +168,7 @@ github.com/google/pprof v0.0.0-20251213031049-b05bdaca462f/go.mod h1:67FPmZWbr+K github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/grpc-ecosystem/grpc-gateway/v2 v2.28.0 
h1:HWRh5R2+9EifMyIHV7ZV+MIZqgz+PMpZ14Jynv3O2Zs= +github.com/grpc-ecosystem/grpc-gateway/v2 v2.28.0/go.mod h1:JfhWUomR1baixubs02l85lZYYOm7LV6om4ceouMv45c= github.com/hajimehoshi/go-mp3 v0.3.4 h1:NUP7pBYH8OguP4diaTZ9wJbUbk3tC0KlfzsEpWmYj68= github.com/hajimehoshi/go-mp3 v0.3.4/go.mod h1:fRtZraRFcWb0pu7ok0LqyFhCUrPeMsGRSVop0eemFmo= github.com/invopop/jsonschema v0.13.0 h1:KvpoAJWEjR3uD9Kbm2HWJmqsEaHt8lBUpd0qHcIi21E= @@ -215,22 +217,24 @@ github.com/maximhq/bifrost/core v1.5.1 h1:iJoVnI4q0CpNylBqXLVaZUc0qgJhd8j8Xa2vtN github.com/maximhq/bifrost/core v1.5.1/go.mod h1:O6VEP2MHkQgo1iLYoxGQ7a+3VBBlHoETCH+pOR6Q5X8= github.com/maximhq/bifrost/framework v1.3.1 h1:HpKD0JigkxsR6+jI3DDxAm9AKsO241E3sj2BpxG82Xs= github.com/maximhq/bifrost/framework v1.3.1/go.mod h1:M+MDjP4cRZMinI2qk0DHtTp9ayFWaoQ2Ye+ikmyhGYQ= -github.com/maximhq/bifrost/plugins/governance v1.5.0 h1:cT+QiIKqJNKjl6/q0W3HTuZSeql0MHx3UWTyZPMLag4= -github.com/maximhq/bifrost/plugins/governance v1.5.0/go.mod h1:hjC5TmTdk4bES89zPUwBTwWWteHNtTV8WytdkPZUWd8= -github.com/maximhq/bifrost/plugins/logging v1.5.0 h1:uGrernx8gENT84L7fXyEpgvJZgORsGZvyq5B4PkSj80= -github.com/maximhq/bifrost/plugins/logging v1.5.0/go.mod h1:uxdMIVHUG7u5Wc5HQzXY13UlExc3lDumRgC8M+kTQiw= -github.com/maximhq/bifrost/plugins/maxim v1.6.0 h1:F23T1qcMczcuauGCYO5p9qeZOAc48FPjFdaSK9TmVeY= -github.com/maximhq/bifrost/plugins/maxim v1.6.0/go.mod h1:V/ccWAfBiW6kVXGWLe9tXKoTgFSh9sYgaJRrtEwFTso= +github.com/maximhq/bifrost/plugins/compat v0.1.0 h1:N9IVY4hmvQj/tCyppWu7zy41N8pyo0dZ+1W6Z+pQCKE= +github.com/maximhq/bifrost/plugins/compat v0.1.0/go.mod h1:PpVbCGimxQUiCHLzpHZRSjyNlSo+LgIbGzZFhtHcytI= +github.com/maximhq/bifrost/plugins/governance v1.5.1 h1:zc7TY5Xb4HsEqKfL7mdkIushgAbD1a0MSoQpjYFEhtY= +github.com/maximhq/bifrost/plugins/governance v1.5.1/go.mod h1:WosnY6eDKAufCZKJpNsqWiHt/fyZOx2THoDLzkqRTnM= +github.com/maximhq/bifrost/plugins/logging v1.5.1 h1:kNjmevWpt7nmsRyDmVTz8GPhnljtgCOtO52vjfTMvG8= +github.com/maximhq/bifrost/plugins/logging v1.5.1/go.mod 
h1:qcutU7X+Qt7zuNgT7m/zblLvMsI4/SAaoMwlDDBopvY= +github.com/maximhq/bifrost/plugins/maxim v1.6.1 h1:pwWflCaINS+6nPihSjezUpbCHdENqRFVSNiwiGzPyoI= +github.com/maximhq/bifrost/plugins/maxim v1.6.1/go.mod h1:t8xxjMGGqbXz2IRSYxQGvfKM27G2LlIAkWyFVIx8S54= github.com/maximhq/bifrost/plugins/mocker v1.5.1 h1:tXB8WPH9J7MURk45PNjx0hh9TeZzyBXqAYFaKUWdQtM= github.com/maximhq/bifrost/plugins/mocker v1.5.1/go.mod h1:qbjCfskG6jN23rtrLYmaxFBvA5CzOTJ67UIEuyFkO90= -github.com/maximhq/bifrost/plugins/otel v1.2.0 h1:+aJnWdryDlhza7wc4KETosX9j3Mdad5uUFBuwhslNsk= -github.com/maximhq/bifrost/plugins/otel v1.2.0/go.mod h1:BwNVvRuEgdPlSlDLzANpGy2RugWQjtHkEUoBiwT5MNI= +github.com/maximhq/bifrost/plugins/otel v1.2.1 h1:fSGOBTOMfsUzZ2Kk/C7CDkbxJ2JceUhrmtFlQ2S7xBs= +github.com/maximhq/bifrost/plugins/otel v1.2.1/go.mod h1:mw5DMoHxIms5L+QpSqN0ow97wM72CRsR4I0MAuFaBNM= github.com/maximhq/bifrost/plugins/prompts v1.0.1 h1:JpM+uVkYmNLWEvg/hT8HN2Wpzax6TUsM/mdIyYzkx00= github.com/maximhq/bifrost/plugins/prompts v1.0.1/go.mod h1:379vljFVED/0L+odEmYQaaYDY/HFy4smb8tpXXCeBvA= -github.com/maximhq/bifrost/plugins/semanticcache v1.5.0 h1:tibnQ8lSnKXujnjL4mt84P/5Vxj9e9wbhvh1Tjr68JA= -github.com/maximhq/bifrost/plugins/semanticcache v1.5.0/go.mod h1:+NfIRAlHpuh5ORv0MoOf5f8uY4WPx6v/8Kuk+8FEGnw= -github.com/maximhq/bifrost/plugins/telemetry v1.5.0 h1:hECZgcsqeJSmiLrWONTFFU6APzTyILQzZuVV96oql5Q= -github.com/maximhq/bifrost/plugins/telemetry v1.5.0/go.mod h1:dl/4mtQhxooqU+r42hXajhUaq04S1X3LaH+km5UJAy0= +github.com/maximhq/bifrost/plugins/semanticcache v1.5.1 h1:rkXataDvgnE3HlkXCtraYVadeLHLWImtbuajhpUIOyU= +github.com/maximhq/bifrost/plugins/semanticcache v1.5.1/go.mod h1:YSjXwYxO0UvRWKnwqp9SdlgFjAajaMfzpjbtSNTnqnY= +github.com/maximhq/bifrost/plugins/telemetry v1.5.1 h1:bZC/MdVDr3zmvi686tqrQMCzDVPvwqxXScVSk404NqY= +github.com/maximhq/bifrost/plugins/telemetry v1.5.1/go.mod h1:t1DiP/jrfV9oGmpp/Jy1mb/5YYHSvOgGAQR2055xsHI= github.com/maximhq/maxim-go v0.2.1 h1:hCp8dQ4HsyyNC+y5HCUuY/HFD0sOnGkjL5MdYCHkgEQ= 
github.com/maximhq/maxim-go v0.2.1/go.mod h1:nwFznXy0Dn4mxXGU4X+BCnE3VP68L+FPEaW0yUgk96o= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= @@ -357,13 +361,21 @@ go.mongodb.org/mongo-driver v1.17.6/go.mod h1:Hy04i7O2kC4RS06ZrhPRqj/u4DTYkFDAAc go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64= go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y= go.opentelemetry.io/otel v1.43.0 h1:mYIM03dnh5zfN7HautFE4ieIig9amkNANT+xcVxAj9I= +go.opentelemetry.io/otel v1.43.0/go.mod h1:JuG+u74mvjvcm8vj8pI5XiHy1zDeoCS2LB1spIq7Ay0= go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.43.0 h1:8UQVDcZxOJLtX6gxtDt3vY2WTgvZqMQRzjsqiIHQdkc= +go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.43.0/go.mod h1:2lmweYCiHYpEjQ/lSJBYhj9jP1zvCvQW4BqL9dnT7FQ= go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetrichttp v1.43.0 h1:w1K+pCJoPpQifuVpsKamUdn9U0zM3xUziVOqsGksUrY= +go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetrichttp v1.43.0/go.mod h1:HBy4BjzgVE8139ieRI75oXm3EcDN+6GhD88JT1Kjvxg= go.opentelemetry.io/otel/metric v1.43.0 h1:d7638QeInOnuwOONPp4JAOGfbCEpYb+K6DVWvdxGzgM= +go.opentelemetry.io/otel/metric v1.43.0/go.mod h1:RDnPtIxvqlgO8GRW18W6Z/4P462ldprJtfxHxyKd2PY= go.opentelemetry.io/otel/sdk v1.43.0 h1:pi5mE86i5rTeLXqoF/hhiBtUNcrAGHLKQdhg4h4V9Dg= +go.opentelemetry.io/otel/sdk v1.43.0/go.mod h1:P+IkVU3iWukmiit/Yf9AWvpyRDlUeBaRg6Y+C58QHzg= go.opentelemetry.io/otel/sdk/metric v1.43.0 h1:S88dyqXjJkuBNLeMcVPRFXpRw2fuwdvfCGLEo89fDkw= +go.opentelemetry.io/otel/sdk/metric v1.43.0/go.mod h1:C/RJtwSEJ5hzTiUz5pXF1kILHStzb9zFlIEe85bhj6A= go.opentelemetry.io/otel/trace v1.43.0 h1:BkNrHpup+4k4w+ZZ86CZoHHEkohws8AY+WTX09nk+3A= +go.opentelemetry.io/otel/trace v1.43.0/go.mod h1:/QJhyVBUUswCphDVxq+8mld+AvhXZLhe+8WVFxiFff0= go.opentelemetry.io/proto/otlp v1.10.0 h1:IQRWgT5srOCYfiWnpqUYz9CVmbO8bFmKcwYxpuCSL2g= 
+go.opentelemetry.io/proto/otlp v1.10.0/go.mod h1:/CV4QoCR/S9yaPj8utp3lvQPoqMtxXdzn7ozvvozVqk= go.starlark.net v0.0.0-20260102030733-3fee463870c9 h1:nV1OyvU+0CYrp5eKfQ3rD03TpFYYhH08z31NK1HmtTk= go.starlark.net v0.0.0-20260102030733-3fee463870c9/go.mod h1:YKMCv9b1WrfWmeqdV5MAuEHWsu5iC+fe6kYl2sQjdI8= go.uber.org/automaxprocs v1.6.0 h1:O3y2/QNTOdbF+e/dpXNNW7Rx2hZ4sTIPyybbxyNqTUs= @@ -397,9 +409,13 @@ golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA= golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI= golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4= gonum.org/v1/gonum v0.17.0 h1:VbpOemQlsSMrYmn7T2OUvQ4dqxQXU+ouZFQsZOx50z4= +gonum.org/v1/gonum v0.17.0/go.mod h1:El3tOrEuMpv2UdMrbNlKEh9vd86bmQ6vqIcDwxEOc1E= google.golang.org/genproto/googleapis/api v0.0.0-20260401024825-9d38bb4040a9 h1:VPWxll4HlMw1Vs/qXtN7BvhZqsS9cdAittCNvVENElA= +google.golang.org/genproto/googleapis/api v0.0.0-20260401024825-9d38bb4040a9/go.mod h1:7QBABkRtR8z+TEnmXTqIqwJLlzrZKVfAUm7tY3yGv0M= google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9 h1:m8qni9SQFH0tJc1X0vmnpw/0t+AImlSvp30sEupozUg= +google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9/go.mod h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8= google.golang.org/grpc v1.80.0 h1:Xr6m2WmWZLETvUNvIUmeD5OAagMw3FiKmMlTdViWsHM= +google.golang.org/grpc v1.80.0/go.mod h1:ho/dLnxwi3EDJA4Zghp7k2Ec1+c2jqup0bFkw07bwF4= google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE= google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/ui/app/workspace/config/compatibility/page.tsx b/ui/app/workspace/config/compatibility/page.tsx new file mode 100644 index 0000000000..d8193f35a7 --- /dev/null +++ b/ui/app/workspace/config/compatibility/page.tsx @@ -0,0 +1,11 
@@ +"use client"; + +import CompatibilityView from "../views/compatibilityView"; + +export default function CompatibilityPage() { + return ( +
+ +
+ ); +} diff --git a/ui/app/workspace/config/views/clientSettingsView.tsx b/ui/app/workspace/config/views/clientSettingsView.tsx index 1550b6566c..0ae5b736ca 100644 --- a/ui/app/workspace/config/views/clientSettingsView.tsx +++ b/ui/app/workspace/config/views/clientSettingsView.tsx @@ -107,7 +107,6 @@ export default function ClientSettingsView() { if (!config) return false; return ( localConfig.drop_excess_requests !== config.drop_excess_requests || - localConfig.enable_litellm_fallbacks !== config.enable_litellm_fallbacks || localConfig.disable_db_pings_in_health !== config.disable_db_pings_in_health || localConfig.async_job_result_ttl !== config.async_job_result_ttl || !headerFilterConfigEqual(localConfig.header_filter_config, config.header_filter_config) @@ -320,34 +319,6 @@ export default function ClientSettingsView() { /> - {/* Enable LiteLLM Fallbacks */} -
-
- -

- Enable litellm-specific fallbacks.{" "} - - Learn more - -

-
- handleConfigChange("enable_litellm_fallbacks", checked)} - disabled={!hasSettingsUpdateAccess} - /> -
- {/* Disable DB Pings in Health */}
@@ -438,9 +409,8 @@ export default function ClientSettingsView() {
  • Wildcards: Use{" "} - * at the end of a pattern to match - prefixes (e.g.,{" "} - anthropic-* matches all headers starting + * at the end of a pattern to match prefixes + (e.g., anthropic-* matches all headers starting with anthropic-). Use{" "} * alone to match all headers.
  • diff --git a/ui/app/workspace/config/views/compatibilityView.tsx b/ui/app/workspace/config/views/compatibilityView.tsx new file mode 100644 index 0000000000..50d4273c97 --- /dev/null +++ b/ui/app/workspace/config/views/compatibilityView.tsx @@ -0,0 +1,158 @@ +"use client"; + +import { Button } from "@/components/ui/button"; +import { Switch } from "@/components/ui/switch"; +import { getErrorMessage, useGetCoreConfigQuery, useUpdateCoreConfigMutation } from "@/lib/store"; +import { CompatConfig, DefaultCoreConfig } from "@/lib/types/config"; +import { RbacOperation, RbacResource, useRbac } from "@enterprise/lib"; +import Link from "next/link"; +import { useCallback, useEffect, useMemo, useState } from "react"; +import { toast } from "sonner"; + +export default function CompatibilityView() { + const hasSettingsUpdateAccess = useRbac(RbacResource.Settings, RbacOperation.Update); + const { data: bifrostConfig } = useGetCoreConfigQuery({ fromDB: true }); + const config = bifrostConfig?.client_config?.compat; + const [updateCoreConfig, { isLoading }] = useUpdateCoreConfigMutation(); + const [localCompatConfig, setLocalCompatConfig] = useState(DefaultCoreConfig.compat); + + useEffect(() => { + if (config) { + setLocalCompatConfig(config); + return; + } + setLocalCompatConfig(DefaultCoreConfig.compat); + }, [config]); + + const hasChanges = useMemo(() => { + const baseline = config ?? 
DefaultCoreConfig.compat; + return ( + localCompatConfig.convert_text_to_chat !== baseline.convert_text_to_chat || + localCompatConfig.convert_chat_to_responses !== baseline.convert_chat_to_responses || + localCompatConfig.should_drop_params !== baseline.should_drop_params || + localCompatConfig.should_convert_params !== baseline.should_convert_params + ); + }, [config, localCompatConfig]); + + const handleCompatChange = useCallback((field: keyof CompatConfig, value: boolean) => { + setLocalCompatConfig((prev) => ({ ...prev, [field]: value })); + }, []); + + const handleSave = useCallback(async () => { + if (!bifrostConfig) { + toast.error("Configuration not loaded"); + return; + } + + try { + await updateCoreConfig({ + ...bifrostConfig, + client_config: { + ...(bifrostConfig.client_config ?? DefaultCoreConfig), + compat: localCompatConfig, + }, + }).unwrap(); + toast.success("Compatibility settings updated successfully."); + } catch (error) { + toast.error(getErrorMessage(error)); + } + }, [bifrostConfig, localCompatConfig, updateCoreConfig]); + + return ( +
    +
    +

    Compatibility

    +

    + Configure request conversions and compatibility fallbacks.{" "} + + Learn more + +

    +
    + +
    +
    +
    + +

    Convert text completion requests to chat for models that only support chat.

    +
    + handleCompatChange("convert_text_to_chat", checked)} + disabled={!hasSettingsUpdateAccess} + /> +
    + +
    +
    + +

    + Convert chat completion requests to responses for models that only support responses. +

    +
    + handleCompatChange("convert_chat_to_responses", checked)} + disabled={!hasSettingsUpdateAccess} + /> +
    + +
    +
    + +

    Drop unsupported parameters based on model catalog allowlist.

    +
    + handleCompatChange("should_drop_params", checked)} + disabled={!hasSettingsUpdateAccess} + /> +
    + +
    +
    + +

    Convert model parameter values that are not supported by the model.

    +
    + handleCompatChange("should_convert_params", checked)} + disabled={!hasSettingsUpdateAccess} + /> +
    +
    + +
    + +
    +
    + ); +} \ No newline at end of file diff --git a/ui/components/sidebar.tsx b/ui/components/sidebar.tsx index f331a0d740..a95264ffe9 100644 --- a/ui/components/sidebar.tsx +++ b/ui/components/sidebar.tsx @@ -23,6 +23,7 @@ import { Logs, Network, PanelLeftClose, + Plug, Puzzle, Router, ScrollText, @@ -720,6 +721,13 @@ export default function AppSidebar() { description: "Client configuration settings", hasAccess: hasSettingsAccess, }, + { + title: "Compatibility", + url: "/workspace/config/compatibility", + icon: Plug, + description: "Compatibility conversion settings", + hasAccess: hasSettingsAccess, + }, { title: "Caching", url: "/workspace/config/caching", @@ -1251,4 +1259,4 @@ export default function AppSidebar() { ); -} +} \ No newline at end of file diff --git a/ui/components/ui/accordion.tsx b/ui/components/ui/accordion.tsx index 2978b5e54f..6ee186762e 100644 --- a/ui/components/ui/accordion.tsx +++ b/ui/components/ui/accordion.tsx @@ -26,7 +26,7 @@ function AccordionTrigger({ className, children, ...props }: React.ComponentProp {...props} > {children} - + ); @@ -44,4 +44,4 @@ function AccordionContent({ className, children, ...props }: React.ComponentProp ); } -export { Accordion, AccordionContent, AccordionItem, AccordionTrigger }; +export { Accordion, AccordionContent, AccordionItem, AccordionTrigger }; \ No newline at end of file diff --git a/ui/lib/types/config.ts b/ui/lib/types/config.ts index b3929e54b9..882585ded9 100644 --- a/ui/lib/types/config.ts +++ b/ui/lib/types/config.ts @@ -454,6 +454,13 @@ export interface BifrostConfig { auth_token?: string; } +export interface CompatConfig { + convert_text_to_chat: boolean; + convert_chat_to_responses: boolean; + should_drop_params: boolean; + should_convert_params: boolean; +} + // Core Bifrost configuration types export interface CoreConfig { drop_excess_requests: boolean; @@ -468,7 +475,7 @@ export interface CoreConfig { allowed_origins: string[]; allowed_headers: string[]; max_request_body_size_mb: 
number; - enable_litellm_fallbacks: boolean; + compat: CompatConfig; mcp_agent_depth: number; mcp_tool_execution_timeout: number; mcp_code_mode_binding_level?: string; @@ -495,7 +502,7 @@ export const DefaultCoreConfig: CoreConfig = { allow_direct_keys: false, allowed_origins: [], max_request_body_size_mb: 100, - enable_litellm_fallbacks: false, + compat: { convert_text_to_chat: false, convert_chat_to_responses: false, should_drop_params: false, should_convert_params: false }, mcp_agent_depth: 10, mcp_tool_execution_timeout: 30, mcp_code_mode_binding_level: "server",