From 340185f236ec263f9c13cb5ffe4c40f8cdc4cff3 Mon Sep 17 00:00:00 2001 From: Pratham-Mishra04 Date: Mon, 27 Apr 2026 02:02:21 +0530 Subject: [PATCH] feat: add default provider selection in integration paths --- core/providers/utils/utils.go | 3 +- .../handlers/webrtc_realtime_test.go | 6 +- .../bifrost-http/handlers/wsresponses_test.go | 2 +- .../bifrost-http/integrations/anthropic.go | 15 +++++ .../bifrost-http/integrations/bedrock.go | 22 +++++++ .../bifrost-http/integrations/bedrock_test.go | 2 +- .../bifrost-http/integrations/cohere.go | 20 ++++++ transports/bifrost-http/integrations/genai.go | 60 +++++++++++++++--- .../bifrost-http/integrations/openai.go | 59 +++++++++++++++--- .../bifrost-http/integrations/router.go | 61 +++++++++++++++++-- transports/bifrost-http/lib/config.go | 29 +++++---- transports/bifrost-http/lib/ctx_test.go | 22 ++++--- transports/changelog.md | 2 +- 13 files changed, 252 insertions(+), 51 deletions(-) diff --git a/core/providers/utils/utils.go b/core/providers/utils/utils.go index fc66309189..e96b3061c0 100644 --- a/core/providers/utils/utils.go +++ b/core/providers/utils/utils.go @@ -2787,7 +2787,8 @@ func CheckAndSetDefaultProvider(ctx *schemas.BifrostContext, defaultProvider sch if slices.Contains(availableProviders, defaultProvider) { return defaultProvider } - return "" + // Return the first available provider + return availableProviders[0] } return defaultProvider } diff --git a/transports/bifrost-http/handlers/webrtc_realtime_test.go b/transports/bifrost-http/handlers/webrtc_realtime_test.go index f5b2143749..f0eef3fc78 100644 --- a/transports/bifrost-http/handlers/webrtc_realtime_test.go +++ b/transports/bifrost-http/handlers/webrtc_realtime_test.go @@ -18,9 +18,9 @@ type testHandlerStore struct { kv *kvstore.Store } -func (s testHandlerStore) ShouldAllowDirectKeys() bool { return true } -func (s testHandlerStore) GetHeaderMatcher() *lib.HeaderMatcher { return nil } -func (s testHandlerStore) GetAvailableProviders() 
[]schemas.ModelProvider { return nil } +func (s testHandlerStore) ShouldAllowDirectKeys() bool { return true } +func (s testHandlerStore) GetHeaderMatcher() *lib.HeaderMatcher { return nil } +func (s testHandlerStore) GetAvailableProviders(model string) []schemas.ModelProvider { return nil } func (s testHandlerStore) GetStreamChunkInterceptor() lib.StreamChunkInterceptor { return nil } diff --git a/transports/bifrost-http/handlers/wsresponses_test.go b/transports/bifrost-http/handlers/wsresponses_test.go index c87f0b889b..9b8fd78e76 100644 --- a/transports/bifrost-http/handlers/wsresponses_test.go +++ b/transports/bifrost-http/handlers/wsresponses_test.go @@ -26,7 +26,7 @@ func (s testWSHandlerStore) GetHeaderMatcher() *lib.HeaderMatcher { return nil } -func (s testWSHandlerStore) GetAvailableProviders() []schemas.ModelProvider { +func (s testWSHandlerStore) GetAvailableProviders(model string) []schemas.ModelProvider { return nil } diff --git a/transports/bifrost-http/integrations/anthropic.go b/transports/bifrost-http/integrations/anthropic.go index e21e54d2f5..eb5dfdc450 100644 --- a/transports/bifrost-http/integrations/anthropic.go +++ b/transports/bifrost-http/integrations/anthropic.go @@ -23,6 +23,18 @@ type AnthropicRouter struct { *GenericRouter } +// anthropicModelGetter extracts the model field from any Anthropic integration request type. +// It is called after body parsing, so req is fully populated. +func anthropicModelGetter(_ *fasthttp.RequestCtx, req interface{}) (string, error) { + switch r := req.(type) { + case *anthropic.AnthropicTextRequest: + return r.Model, nil + case *anthropic.AnthropicMessageRequest: + return r.Model, nil + } + return "", nil +} + // createAnthropicCompleteRouteConfig creates a route configuration for the `/v1/complete` endpoint. 
func createAnthropicCompleteRouteConfig(pathPrefix string) RouteConfig { return RouteConfig{ @@ -35,6 +47,7 @@ func createAnthropicCompleteRouteConfig(pathPrefix string) RouteConfig { GetRequestTypeInstance: func(ctx context.Context) interface{} { return &anthropic.AnthropicTextRequest{} }, + GetRequestModel: anthropicModelGetter, RequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*schemas.BifrostRequest, error) { if anthropicReq, ok := req.(*anthropic.AnthropicTextRequest); ok { return &schemas.BifrostRequest{ @@ -75,6 +88,7 @@ func createAnthropicMessagesRouteConfig(pathPrefix string, logger schemas.Logger GetRequestTypeInstance: func(ctx context.Context) interface{} { return &anthropic.AnthropicMessageRequest{} }, + GetRequestModel: anthropicModelGetter, RequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*schemas.BifrostRequest, error) { if anthropicReq, ok := req.(*anthropic.AnthropicMessageRequest); ok { bifrostReq := anthropicReq.ToBifrostResponsesRequest(ctx) @@ -394,6 +408,7 @@ func CreateAnthropicCountTokensRouteConfigs(pathPrefix string, handlerStore lib. GetRequestTypeInstance: func(ctx context.Context) interface{} { return &anthropic.AnthropicMessageRequest{} }, + GetRequestModel: anthropicModelGetter, RequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*schemas.BifrostRequest, error) { if anthropicReq, ok := req.(*anthropic.AnthropicMessageRequest); ok { bifrostReq := anthropicReq.ToBifrostResponsesRequest(ctx) diff --git a/transports/bifrost-http/integrations/bedrock.go b/transports/bifrost-http/integrations/bedrock.go index 14bc4d9911..14c2bd826d 100644 --- a/transports/bifrost-http/integrations/bedrock.go +++ b/transports/bifrost-http/integrations/bedrock.go @@ -21,6 +21,23 @@ type BedrockRouter struct { *GenericRouter } +// bedrockModelGetter extracts the model ID from any Bedrock integration request type. +// It is called after PreCallback, so req.ModelID is populated from the URL path param. 
+func bedrockModelGetter(_ *fasthttp.RequestCtx, req interface{}) (string, error) { + switch r := req.(type) { + case *bedrock.BedrockConverseRequest: + return r.ModelID, nil + case *bedrock.BedrockInvokeRequest: + return r.ModelID, nil + case *bedrock.BedrockCountTokensRequest: + if r.Input.Converse != nil { + return r.Input.Converse.ModelID, nil + } + return "", nil + } + return "", nil +} + // S3 context keys for storing request parameters const ( @@ -42,6 +59,7 @@ func createBedrockConverseRouteConfig(pathPrefix string, handlerStore lib.Handle GetHTTPRequestType: func(ctx *fasthttp.RequestCtx) schemas.RequestType { return schemas.ResponsesRequest }, + GetRequestModel: bedrockModelGetter, RequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*schemas.BifrostRequest, error) { if bedrockReq, ok := req.(*bedrock.BedrockConverseRequest); ok { bifrostReq, err := bedrockReq.ToBifrostResponsesRequest(ctx) @@ -77,6 +95,7 @@ func createBedrockConverseStreamRouteConfig(pathPrefix string, handlerStore lib. 
GetRequestTypeInstance: func(ctx context.Context) interface{} { return &bedrock.BedrockConverseRequest{} }, + GetRequestModel: bedrockModelGetter, RequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*schemas.BifrostRequest, error) { if bedrockReq, ok := req.(*bedrock.BedrockConverseRequest); ok { // Mark as streaming request @@ -127,6 +146,7 @@ func createBedrockInvokeWithResponseStreamRouteConfig(pathPrefix string, handler GetRequestTypeInstance: func(ctx context.Context) interface{} { return &bedrock.BedrockInvokeRequest{} }, + GetRequestModel: bedrockModelGetter, RequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*schemas.BifrostRequest, error) { if invokeReq, ok := req.(*bedrock.BedrockInvokeRequest); ok { requestType, _ := ctx.Value(schemas.BifrostContextKeyHTTPRequestType).(schemas.RequestType) @@ -201,6 +221,7 @@ func createBedrockInvokeRouteConfig(pathPrefix string, handlerStore lib.HandlerS GetRequestTypeInstance: func(ctx context.Context) interface{} { return &bedrock.BedrockInvokeRequest{} }, + GetRequestModel: bedrockModelGetter, RequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*schemas.BifrostRequest, error) { invokeReq, ok := req.(*bedrock.BedrockInvokeRequest) if !ok { @@ -317,6 +338,7 @@ func createBedrockCountTokensRouteConfig(pathPrefix string, handlerStore lib.Han GetHTTPRequestType: func(ctx *fasthttp.RequestCtx) schemas.RequestType { return schemas.CountTokensRequest }, + GetRequestModel: bedrockModelGetter, RequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*schemas.BifrostRequest, error) { if countTokensReq, ok := req.(*bedrock.BedrockCountTokensRequest); ok { if countTokensReq.Input.Converse == nil { diff --git a/transports/bifrost-http/integrations/bedrock_test.go b/transports/bifrost-http/integrations/bedrock_test.go index 0c1af6583a..8c01edaa83 100644 --- a/transports/bifrost-http/integrations/bedrock_test.go +++ 
b/transports/bifrost-http/integrations/bedrock_test.go @@ -30,7 +30,7 @@ func (m *mockHandlerStore) GetHeaderMatcher() *lib.HeaderMatcher { return m.headerMatcher } -func (m *mockHandlerStore) GetAvailableProviders() []schemas.ModelProvider { +func (m *mockHandlerStore) GetAvailableProviders(model string) []schemas.ModelProvider { return m.availableProviders } diff --git a/transports/bifrost-http/integrations/cohere.go b/transports/bifrost-http/integrations/cohere.go index 37aad1c1a8..cf6b7ceaca 100644 --- a/transports/bifrost-http/integrations/cohere.go +++ b/transports/bifrost-http/integrations/cohere.go @@ -69,6 +69,22 @@ func NewCohereRouter(client *bifrost.Bifrost, handlerStore lib.HandlerStore, log } } +// cohereModelGetter extracts the model field from any Cohere integration request type. +// It is called after body parsing, so req is fully populated. +func cohereModelGetter(_ *fasthttp.RequestCtx, req interface{}) (string, error) { + switch r := req.(type) { + case *cohere.CohereChatRequest: + return r.Model, nil + case *cohere.CohereEmbeddingRequest: + return r.Model, nil + case *cohere.CohereRerankRequest: + return r.Model, nil + case *cohere.CohereCountTokensRequest: + return r.Model, nil + } + return "", nil +} + // CreateCohereRouteConfigs creates route configurations for Cohere API endpoints. 
func CreateCohereRouteConfigs(pathPrefix string) []RouteConfig { var routes []RouteConfig @@ -85,6 +101,7 @@ func CreateCohereRouteConfigs(pathPrefix string) []RouteConfig { GetRequestTypeInstance: func(ctx context.Context) interface{} { return &cohere.CohereChatRequest{} }, + GetRequestModel: cohereModelGetter, RequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*schemas.BifrostRequest, error) { if cohereReq, ok := req.(*cohere.CohereChatRequest); ok { return &schemas.BifrostRequest{ @@ -131,6 +148,7 @@ func CreateCohereRouteConfigs(pathPrefix string) []RouteConfig { GetRequestTypeInstance: func(ctx context.Context) interface{} { return &cohere.CohereEmbeddingRequest{} }, + GetRequestModel: cohereModelGetter, RequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*schemas.BifrostRequest, error) { if cohereReq, ok := req.(*cohere.CohereEmbeddingRequest); ok { return &schemas.BifrostRequest{ @@ -164,6 +182,7 @@ func CreateCohereRouteConfigs(pathPrefix string) []RouteConfig { GetRequestTypeInstance: func(ctx context.Context) interface{} { return &cohere.CohereRerankRequest{} }, + GetRequestModel: cohereModelGetter, RequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*schemas.BifrostRequest, error) { if cohereReq, ok := req.(*cohere.CohereRerankRequest); ok { return &schemas.BifrostRequest{ @@ -197,6 +216,7 @@ func CreateCohereRouteConfigs(pathPrefix string) []RouteConfig { GetRequestTypeInstance: func(ctx context.Context) interface{} { return &cohere.CohereCountTokensRequest{} }, + GetRequestModel: cohereModelGetter, RequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*schemas.BifrostRequest, error) { if cohereReq, ok := req.(*cohere.CohereCountTokensRequest); ok { return &schemas.BifrostRequest{ diff --git a/transports/bifrost-http/integrations/genai.go b/transports/bifrost-http/integrations/genai.go index 012a5e078a..88229393fe 100644 --- a/transports/bifrost-http/integrations/genai.go +++ 
b/transports/bifrost-http/integrations/genai.go @@ -35,6 +35,33 @@ type GenAIRouter struct { *GenericRouter } +// genAIModelGetter extracts the model name for GenAI routes. +// For request types populated by extractAndSetModelAndRequestType (the PreCallback), +// the model is already clean on the struct. For BifrostVideoRetrieveRequest (which has +// no model field), the provider-scoped model is extracted from the operation_id suffix +// (format: "op123:openai/gpt-4o") since the route pins the provider via operation_id. +func genAIModelGetter(ctx *fasthttp.RequestCtx, req interface{}) (string, error) { + switch r := req.(type) { + case *gemini.GeminiGenerationRequest: + return r.Model, nil + case *gemini.GeminiEmbeddingRequest: + return r.Model, nil + case *gemini.GeminiVideoGenerationRequest: + return r.Model, nil + case *gemini.GeminiBatchCreateRequest: + return r.Model, nil + case *schemas.BifrostVideoRetrieveRequest: + // operation_id encodes the full model string: "op123:gpt-4o" or "op123:openai/gpt-4o". + operationID, _ := ctx.UserValue("operation_id").(string) + parts := strings.Split(operationID, ":") + if len(parts) >= 2 && parts[len(parts)-1] != "" { + return parts[len(parts)-1], nil + } + return "", nil + } + return "", nil +} + // CreateGenAIRouteConfigs creates a route configurations for GenAI endpoints. 
func CreateGenAIRouteConfigs(pathPrefix string) []RouteConfig { var routes []RouteConfig @@ -51,6 +78,7 @@ func CreateGenAIRouteConfigs(pathPrefix string) []RouteConfig { GetRequestTypeInstance: func(ctx context.Context) interface{} { return &schemas.BifrostVideoRetrieveRequest{} }, + GetRequestModel: genAIModelGetter, RequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*schemas.BifrostRequest, error) { if videoRetrieveReq, ok := req.(*schemas.BifrostVideoRetrieveRequest); ok { return &schemas.BifrostRequest{ @@ -89,6 +117,7 @@ func CreateGenAIRouteConfigs(pathPrefix string) []RouteConfig { } return &gemini.GeminiGenerationRequest{} }, + GetRequestModel: genAIModelGetter, RequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*schemas.BifrostRequest, error) { if geminiReq, ok := req.(*gemini.GeminiGenerationRequest); ok { if geminiReq.IsCountTokens { @@ -790,6 +819,12 @@ func createGenAIRerankRouteConfig(pathPrefix string) RouteConfig { GetRequestTypeInstance: func(ctx context.Context) interface{} { return &vertex.VertexRankRequest{} }, + GetRequestModel: func(_ *fasthttp.RequestCtx, req interface{}) (string, error) { + if r, ok := req.(*vertex.VertexRankRequest); ok && r.Model != nil { + return *r.Model, nil + } + return "", nil + }, RequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*schemas.BifrostRequest, error) { if vertexReq, ok := req.(*vertex.VertexRankRequest); ok { return &schemas.BifrostRequest{ @@ -1282,31 +1317,38 @@ func extractGeminiVideoOperationFromPath(ctx *fasthttp.RequestCtx, bifrostCtx *s return errors.New("operation_id must be a non-empty string") } - // check provider from operation id suffix, id:provider, could be any provider + // operation_id encodes the raw model string as a suffix: "id:rawModel" + // rawModel is either "gpt-4o" (provider name or bare model) or "openai/gpt-4o" (provider/model). 
parts := strings.Split(operationIDStr, ":") if len(parts) < 2 || parts[len(parts)-1] == "" { - return errors.New("provider is required in operation_id format 'id:provider'") + return errors.New("raw model is required in operation_id format 'id:rawModel' or 'id:provider/model'") + } + rawModel := parts[len(parts)-1] + + // Parse provider from rawModel: "openai/gpt-4o" → provider="openai"; "gemini" → provider="gemini". + var provider schemas.ModelProvider + rawModelParts := strings.SplitN(rawModel, "/", 2) + if len(rawModelParts) == 2 { + provider = schemas.ModelProvider(rawModelParts[0]) + } else { + provider = schemas.ModelProvider(rawModel) } - provider := parts[len(parts)-1] modelStr, ok := model.(string) if !ok || modelStr == "" { - modelStr = provider + modelStr = rawModel } - // if its gemini, set r.ID in format models/model/operations/operation_id:provider - // else set r.ID in format operation_id:provider - switch r := req.(type) { case *schemas.BifrostVideoRetrieveRequest: - r.Provider = schemas.ModelProvider(provider) + r.Provider = provider if r.Provider == schemas.OpenAI || r.Provider == schemas.Azure { // set a context flag to have video download request after video retrieve request when incoming request is coming from genai integration bifrostCtx.SetValue(schemas.BifrostContextKeyVideoOutputRequested, true) } // Gemini provider expects an operation resource path (without /v1beta prefix). 
- if provider == string(schemas.Gemini) { + if provider == schemas.Gemini { r.ID = "models/" + modelStr + "/operations/" + operationIDStr } else { r.ID = operationIDStr diff --git a/transports/bifrost-http/integrations/openai.go b/transports/bifrost-http/integrations/openai.go index 24c987f658..262119f925 100644 --- a/transports/bifrost-http/integrations/openai.go +++ b/transports/bifrost-http/integrations/openai.go @@ -135,6 +135,10 @@ func hydrateOpenAIRequestFromLargePayloadMetadata(ctx *fasthttp.RequestCtx, bifr if r.Model == "" { r.Model = metadata.Model } + case *openai.OpenAIVideoGenerationRequest: + if r.Model == "" { + r.Model = metadata.Model + } } } @@ -300,14 +304,43 @@ func AzureEndpointPreHook(handlerStore lib.HandlerStore) func(ctx *fasthttp.Requ } } +// openAIModelGetter extracts the model field from any OpenAI integration request type. +// It is called after body parsing and PreCallback, so req is fully populated. +func openAIModelGetter(_ *fasthttp.RequestCtx, req interface{}) (string, error) { + switch r := req.(type) { + case *openai.OpenAIChatRequest: + return r.Model, nil + case *openai.OpenAITextCompletionRequest: + return r.Model, nil + case *openai.OpenAIEmbeddingRequest: + return r.Model, nil + case *openai.OpenAIResponsesRequest: + return r.Model, nil + case *openai.OpenAISpeechRequest: + return r.Model, nil + case *openai.OpenAITranscriptionRequest: + return r.Model, nil + case *openai.OpenAIImageGenerationRequest: + return r.Model, nil + case *openai.OpenAIImageEditRequest: + return r.Model, nil + case *openai.OpenAIImageVariationRequest: + return r.Model, nil + case *openai.OpenAIVideoGenerationRequest: + return r.Model, nil + } + return "", nil +} + // CreateOpenAIRouteConfigs creates route configurations for OpenAI endpoints. 
func CreateOpenAIRouteConfigs(pathPrefix string, handlerStore lib.HandlerStore) []RouteConfig { var routes []RouteConfig routes = append(routes, RouteConfig{ - Type: RouteConfigTypeOpenAI, - Path: pathPrefix + "/openai/deployments/{deploymentPath:*}", - Method: "POST", + Type: RouteConfigTypeOpenAI, + Path: pathPrefix + "/openai/deployments/{deploymentPath:*}", + Method: "POST", + GetRequestModel: openAIModelGetter, GetHTTPRequestType: func(ctx *fasthttp.RequestCtx) schemas.RequestType { deploymentPathVal, ok := ctx.UserValue("deploymentPath").(string) if !ok { @@ -543,6 +576,7 @@ func CreateOpenAIRouteConfigs(pathPrefix string, handlerStore lib.HandlerStore) GetRequestTypeInstance: func(ctx context.Context) interface{} { return &openai.OpenAIChatRequest{} }, + GetRequestModel: openAIModelGetter, RequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*schemas.BifrostRequest, error) { if openaiReq, ok := req.(*openai.OpenAIChatRequest); ok { br := &schemas.BifrostRequest{ @@ -639,6 +673,7 @@ func CreateOpenAIRouteConfigs(pathPrefix string, handlerStore lib.HandlerStore) GetRequestTypeInstance: func(ctx context.Context) interface{} { return &openai.OpenAITextCompletionRequest{} }, + GetRequestModel: openAIModelGetter, RequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*schemas.BifrostRequest, error) { if openaiReq, ok := req.(*openai.OpenAITextCompletionRequest); ok { return &schemas.BifrostRequest{ @@ -690,6 +725,7 @@ func CreateOpenAIRouteConfigs(pathPrefix string, handlerStore lib.HandlerStore) GetRequestTypeInstance: func(ctx context.Context) interface{} { return &openai.OpenAIResponsesRequest{} }, + GetRequestModel: openAIModelGetter, RequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*schemas.BifrostRequest, error) { if openaiReq, ok := req.(*openai.OpenAIResponsesRequest); ok { return &schemas.BifrostRequest{ @@ -772,6 +808,7 @@ func CreateOpenAIRouteConfigs(pathPrefix string, handlerStore lib.HandlerStore) 
GetRequestTypeInstance: func(ctx context.Context) interface{} { return &openai.OpenAIResponsesRequest{} }, + GetRequestModel: openAIModelGetter, RequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*schemas.BifrostRequest, error) { if openaiReq, ok := req.(*openai.OpenAIResponsesRequest); ok { return &schemas.BifrostRequest{ @@ -810,6 +847,7 @@ func CreateOpenAIRouteConfigs(pathPrefix string, handlerStore lib.HandlerStore) GetRequestTypeInstance: func(ctx context.Context) interface{} { return &openai.OpenAIEmbeddingRequest{} }, + GetRequestModel: openAIModelGetter, RequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*schemas.BifrostRequest, error) { if embeddingReq, ok := req.(*openai.OpenAIEmbeddingRequest); ok { return &schemas.BifrostRequest{ @@ -848,6 +886,7 @@ func CreateOpenAIRouteConfigs(pathPrefix string, handlerStore lib.HandlerStore) GetRequestTypeInstance: func(ctx context.Context) interface{} { return &openai.OpenAISpeechRequest{} }, + GetRequestModel: openAIModelGetter, RequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*schemas.BifrostRequest, error) { if speechReq, ok := req.(*openai.OpenAISpeechRequest); ok { return &schemas.BifrostRequest{ @@ -891,7 +930,8 @@ func CreateOpenAIRouteConfigs(pathPrefix string, handlerStore lib.HandlerStore) GetRequestTypeInstance: func(ctx context.Context) interface{} { return &openai.OpenAITranscriptionRequest{} }, - RequestParser: parseTranscriptionMultipartRequest, // Handle multipart form parsing + GetRequestModel: openAIModelGetter, + RequestParser: parseTranscriptionMultipartRequest, // Handle multipart form parsing RequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*schemas.BifrostRequest, error) { if transcriptionReq, ok := req.(*openai.OpenAITranscriptionRequest); ok { return &schemas.BifrostRequest{ @@ -946,6 +986,7 @@ func CreateOpenAIRouteConfigs(pathPrefix string, handlerStore lib.HandlerStore) GetRequestTypeInstance: func(ctx 
context.Context) interface{} { return &openai.OpenAIImageGenerationRequest{} }, + GetRequestModel: openAIModelGetter, RequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*schemas.BifrostRequest, error) { if imageGenReq, ok := req.(*openai.OpenAIImageGenerationRequest); ok { return &schemas.BifrostRequest{ @@ -996,7 +1037,8 @@ func CreateOpenAIRouteConfigs(pathPrefix string, handlerStore lib.HandlerStore) GetRequestTypeInstance: func(ctx context.Context) interface{} { return &openai.OpenAIImageEditRequest{} }, - RequestParser: parseOpenAIImageEditMultipartRequest, // Handle multipart form parsing + GetRequestModel: openAIModelGetter, + RequestParser: parseOpenAIImageEditMultipartRequest, // Handle multipart form parsing RequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*schemas.BifrostRequest, error) { if imageEditReq, ok := req.(*openai.OpenAIImageEditRequest); ok { return &schemas.BifrostRequest{ @@ -1046,7 +1088,8 @@ func CreateOpenAIRouteConfigs(pathPrefix string, handlerStore lib.HandlerStore) GetRequestTypeInstance: func(ctx context.Context) interface{} { return &openai.OpenAIImageVariationRequest{} }, - RequestParser: parseOpenAIImageVariationMultipartRequest, + GetRequestModel: openAIModelGetter, + RequestParser: parseOpenAIImageVariationMultipartRequest, RequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*schemas.BifrostRequest, error) { if imageVariationReq, ok := req.(*openai.OpenAIImageVariationRequest); ok { return &schemas.BifrostRequest{ @@ -1098,7 +1141,8 @@ func CreateOpenAIRouteConfigs(pathPrefix string, handlerStore lib.HandlerStore) GetRequestTypeInstance: func(ctx context.Context) interface{} { return &openai.OpenAIVideoGenerationRequest{} }, - RequestParser: parseOpenAIVideoGenerationMultipartRequest, + GetRequestModel: openAIModelGetter, + RequestParser: parseOpenAIVideoGenerationMultipartRequest, RequestConverter: func(ctx *schemas.BifrostContext, req interface{}) 
(*schemas.BifrostRequest, error) { if videoGenerationReq, ok := req.(*openai.OpenAIVideoGenerationRequest); ok { return &schemas.BifrostRequest{ @@ -1114,6 +1158,7 @@ func CreateOpenAIRouteConfigs(pathPrefix string, handlerStore lib.HandlerStore) return err }, PreCallback: func(ctx *fasthttp.RequestCtx, bifrostCtx *schemas.BifrostContext, req interface{}) error { + hydrateOpenAIRequestFromLargePayloadMetadata(ctx, bifrostCtx, req) if isAzureSDKRequest(ctx) { bifrostCtx.SetValue(schemas.BifrostContextKeyIsAzureUserAgent, true) } diff --git a/transports/bifrost-http/integrations/router.go b/transports/bifrost-http/integrations/router.go index 950574471e..5eb9276b2f 100644 --- a/transports/bifrost-http/integrations/router.go +++ b/transports/bifrost-http/integrations/router.go @@ -55,6 +55,7 @@ import ( "io" "mime" "mime/multipart" + "slices" "strconv" "strings" @@ -354,6 +355,10 @@ type PostRequestCallback func(ctx *fasthttp.RequestCtx, req interface{}, resp in // returns a schemas.RequestType indicating the HTTP request type derived from the context. type HTTPRequestTypeGetter func(ctx *fasthttp.RequestCtx) schemas.RequestType +// RequestModelGetter is a function type that accepts a *fasthttp.RequestCtx and the parsed request, and +// returns the model string derived from them, or an error. +type RequestModelGetter func(ctx *fasthttp.RequestCtx, req interface{}) (string, error) + // ShortCircuit is a function that determines if the request should be short-circuited. 
type ShortCircuit func(ctx *fasthttp.RequestCtx, bifrostCtx *schemas.BifrostContext, req interface{}) (bool, error) @@ -397,6 +402,14 @@ const ( RouteConfigTypeCohere RouteConfigType = "cohere" ) +var RouteConfigTypeToProvider = map[RouteConfigType]schemas.ModelProvider{ + RouteConfigTypeOpenAI: schemas.OpenAI, + RouteConfigTypeAnthropic: schemas.Anthropic, + RouteConfigTypeGenAI: schemas.Gemini, + RouteConfigTypeBedrock: schemas.Bedrock, + RouteConfigTypeCohere: schemas.Cohere, +} + // RouteConfig defines the configuration for a single route in an integration. // It specifies the path, method, and handlers for request/response conversion. type RouteConfig struct { @@ -404,6 +417,7 @@ type RouteConfig struct { Path string // HTTP path pattern (e.g., "/openai/v1/chat/completions") Method string // HTTP method (POST, GET, PUT, DELETE) GetHTTPRequestType HTTPRequestTypeGetter // Function to get the HTTP request type from the context (SHOULD NOT BE NIL) + GetRequestModel RequestModelGetter // Optional: function to get the model from the parsed request (checked for nil) GetRequestTypeInstance func(ctx context.Context) interface{} // Factory function to create request instance (SHOULD NOT BE NIL) RequestParser RequestParser // Optional: custom request parsing (e.g., multipart/form-data) RequestConverter RequestConverter // Function to convert request to BifrostRequest (for inference requests) @@ -621,10 +635,6 @@ func (g *GenericRouter) createHandler(config RouteConfig) fasthttp.RequestHandle // Set integration type to context bifrostCtx.SetValue(schemas.BifrostContextKeyIntegrationType, string(config.Type)) - // Set available providers to context - availableProviders := g.handlerStore.GetAvailableProviders() - bifrostCtx.SetValue(schemas.BifrostContextKeyAvailableProviders, availableProviders) - // Async retrieve: check x-bf-async-id header early (before body parsing) if asyncID := string(ctx.Request.Header.Peek(schemas.AsyncHeaderGetID)); asyncID != "" { defer cancel() @@ -725,6 
+735,44 @@ func (g *GenericRouter) createHandler(config RouteConfig) fasthttp.RequestHandle } } + // Set available providers to context + if config.GetRequestModel != nil { + model, err := config.GetRequestModel(ctx, req) + if err != nil { + cancel() + g.sendError(ctx, bifrostCtx, config.ErrorConverter, newBifrostError(err, "failed to get model from context")) + return + } + extractedProvider, extractedModel := schemas.ParseModelString(model, "") + if extractedProvider == "" { + availableProviders := g.handlerStore.GetAvailableProviders(extractedModel) + availableProvidersStrs := make([]string, len(availableProviders)) + for i, p := range availableProviders { + availableProvidersStrs[i] = string(p) + } + bifrostCtx.AppendRoutingEngineLog(schemas.RoutingEngineModelCatalog, fmt.Sprintf( + "No provider specified for model %s, found %d options in model catalog: [%s]", + extractedModel, len(availableProviders), strings.Join(availableProvidersStrs, ", "), + )) + if len(availableProviders) > 0 { + if slices.Contains(availableProviders, RouteConfigTypeToProvider[config.Type]) { + availableProviders = []schemas.ModelProvider{RouteConfigTypeToProvider[config.Type]} + bifrostCtx.AppendRoutingEngineLog(schemas.RoutingEngineModelCatalog, fmt.Sprintf( + "Integration route default provider %s is found in the available providers list, selecting it", + RouteConfigTypeToProvider[config.Type], + )) + } else { + bifrostCtx.AppendRoutingEngineLog(schemas.RoutingEngineModelCatalog, fmt.Sprintf( + "Integration route default provider %s is not found in the available providers list, selecting first: %s", + RouteConfigTypeToProvider[config.Type], availableProviders[0], + )) + } + bifrostCtx.SetValue(schemas.BifrostContextKeyAvailableProviders, availableProviders) + } + schemas.AppendToContextList(bifrostCtx, schemas.BifrostContextKeyRoutingEnginesUsed, schemas.RoutingEngineModelCatalog) + } + } + // Handle batch requests if BatchRequestConverter is set // GenAI has two cases: (1) Dedicated 
batch routes (list/retrieve) have only BatchRequestConverter — always use batch path. // (2) The models path has both BatchRequestConverter and RequestConverter — use batch path only for batch create. @@ -795,10 +843,12 @@ func (g *GenericRouter) createHandler(config RouteConfig) fasthttp.RequestHandle // Convert the integration-specific request to Bifrost format (inference requests) bifrostReq, err := config.RequestConverter(bifrostCtx, req) if err != nil { + defer cancel() g.sendError(ctx, bifrostCtx, config.ErrorConverter, newBifrostError(err, "failed to convert request to Bifrost format")) return } if bifrostReq == nil { + defer cancel() g.sendError(ctx, bifrostCtx, config.ErrorConverter, newBifrostError(nil, "invalid request")) return } @@ -808,6 +858,7 @@ func (g *GenericRouter) createHandler(config RouteConfig) fasthttp.RequestHandle // Extract and parse fallbacks from the request if present if err := g.extractAndParseFallbacks(req, bifrostReq); err != nil { + defer cancel() g.sendError(ctx, bifrostCtx, config.ErrorConverter, newBifrostError(err, "failed to parse fallbacks: "+err.Error())) return } @@ -842,9 +893,7 @@ func (g *GenericRouter) handleNonStreamingRequest(ctx *fasthttp.RequestCtx, conf // streaming requests (where we actively detect write errors), but still provides a mechanism // for providers to respect cancellation. 
var response interface{} - var err error - var providerResponseHeaders map[string]string switch { diff --git a/transports/bifrost-http/lib/config.go b/transports/bifrost-http/lib/config.go index 3d231e6675..aa845ab0e1 100644 --- a/transports/bifrost-http/lib/config.go +++ b/transports/bifrost-http/lib/config.go @@ -64,8 +64,8 @@ type HandlerStore interface { ShouldAllowDirectKeys() bool // GetHeaderMatcher returns the precompiled header matcher for header filtering GetHeaderMatcher() *HeaderMatcher - // GetAvailableProviders returns the list of available providers - GetAvailableProviders() []schemas.ModelProvider + // GetAvailableProviders returns the list of available providers for the given model + GetAvailableProviders(model string) []schemas.ModelProvider // GetStreamChunkInterceptor returns the interceptor for streaming chunks. // Returns nil if no plugins are loaded or streaming interception is not needed. GetStreamChunkInterceptor() StreamChunkInterceptor @@ -4933,20 +4933,25 @@ func (c *Config) RemoveProviderKeysFromSemanticCacheConfig(config *configstoreTa return nil } -func (c *Config) GetAvailableProviders() []schemas.ModelProvider { +func (c *Config) GetAvailableProviders(model string) []schemas.ModelProvider { c.Mu.RLock() defer c.Mu.RUnlock() availableProviders := []schemas.ModelProvider{} - for provider, config := range c.Providers { - // Check if the provider has at least one key with a non-empty value. If so, add the provider to the list. - // If the provider allows empty keys, add the provider to the list. - for _, key := range config.Keys { - if key.Value.GetValue() != "" || bifrost.CanProviderKeyValueBeEmpty(provider) { - if key.Enabled != nil && !*key.Enabled { - continue + if c.ModelCatalog != nil { + availableProviders = c.ModelCatalog.GetProvidersForModel(model) + } else { + // Return all providers that have at least one key with a non-empty value. 
+ for provider, config := range c.Providers { + // Check if the provider has at least one key with a non-empty value. If so, add the provider to the list. + // If the provider allows empty keys, add the provider to the list. + for _, key := range config.Keys { + if key.Value.GetValue() != "" || bifrost.CanProviderKeyValueBeEmpty(provider) { + if key.Enabled != nil && !*key.Enabled { + continue + } + availableProviders = append(availableProviders, provider) + break } - availableProviders = append(availableProviders, provider) - break } } } diff --git a/transports/bifrost-http/lib/ctx_test.go b/transports/bifrost-http/lib/ctx_test.go index a90c2fff69..924196249c 100644 --- a/transports/bifrost-http/lib/ctx_test.go +++ b/transports/bifrost-http/lib/ctx_test.go @@ -18,16 +18,18 @@ type testHandlerStore struct { matcher *HeaderMatcher } -func (s testHandlerStore) ShouldAllowDirectKeys() bool { return s.allowDirectKeys } -func (s testHandlerStore) GetHeaderMatcher() *HeaderMatcher { return s.matcher } -func (s testHandlerStore) GetAvailableProviders() []schemas.ModelProvider { return nil } -func (s testHandlerStore) GetStreamChunkInterceptor() StreamChunkInterceptor { return nil } -func (s testHandlerStore) GetAsyncJobExecutor() *logstore.AsyncJobExecutor { return nil } -func (s testHandlerStore) GetAsyncJobResultTTL() int { return 0 } -func (s testHandlerStore) GetKVStore() *kvstore.Store { return nil } -func (s testHandlerStore) GetMCPHeaderCombinedAllowlist() schemas.WhiteList { return schemas.WhiteList{} } -func (s testHandlerStore) ShouldAllowPerRequestStorageOverride() bool { return false } -func (s testHandlerStore) ShouldAllowPerRequestRawOverride() bool { return false } +func (s testHandlerStore) ShouldAllowDirectKeys() bool { return s.allowDirectKeys } +func (s testHandlerStore) GetHeaderMatcher() *HeaderMatcher { return s.matcher } +func (s testHandlerStore) GetAvailableProviders(_ string) []schemas.ModelProvider { return nil } +func (s testHandlerStore) 
GetStreamChunkInterceptor() StreamChunkInterceptor { return nil } +func (s testHandlerStore) GetAsyncJobExecutor() *logstore.AsyncJobExecutor { return nil } +func (s testHandlerStore) GetAsyncJobResultTTL() int { return 0 } +func (s testHandlerStore) GetKVStore() *kvstore.Store { return nil } +func (s testHandlerStore) GetMCPHeaderCombinedAllowlist() schemas.WhiteList { + return schemas.WhiteList{} +} +func (s testHandlerStore) ShouldAllowPerRequestStorageOverride() bool { return false } +func (s testHandlerStore) ShouldAllowPerRequestRawOverride() bool { return false } func TestParseSessionIDFromBaggage(t *testing.T) { tests := []struct { diff --git a/transports/changelog.md b/transports/changelog.md index 624a35f66e..960d3d740b 100644 --- a/transports/changelog.md +++ b/transports/changelog.md @@ -1,3 +1,3 @@ - fix: response extra fields request type corruption for streaming requests on high concurrency - feat: added support for per-request content logging toggle via `x-bf-disable-content-logging` header -- feat: auto-resolve provider via model catalog when model string has no provider prefix; adds `model-catalog` routing engine with selection log +- feat: auto-resolve provider when model string has no provider prefix