Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion core/providers/utils/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -2787,7 +2787,8 @@ func CheckAndSetDefaultProvider(ctx *schemas.BifrostContext, defaultProvider sch
if slices.Contains(availableProviders, defaultProvider) {
return defaultProvider
}
return ""
// Return the first available provider
return availableProviders[0]
Comment thread
Pratham-Mishra04 marked this conversation as resolved.
}
return defaultProvider
}
Expand Down
6 changes: 3 additions & 3 deletions transports/bifrost-http/handlers/webrtc_realtime_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@ type testHandlerStore struct {
kv *kvstore.Store
}

func (s testHandlerStore) ShouldAllowDirectKeys() bool { return true }
func (s testHandlerStore) GetHeaderMatcher() *lib.HeaderMatcher { return nil }
func (s testHandlerStore) GetAvailableProviders() []schemas.ModelProvider { return nil }
func (s testHandlerStore) ShouldAllowDirectKeys() bool { return true }
func (s testHandlerStore) GetHeaderMatcher() *lib.HeaderMatcher { return nil }
func (s testHandlerStore) GetAvailableProviders(model string) []schemas.ModelProvider { return nil }
// GetStreamChunkInterceptor always returns nil: this test store supplies no
// stream chunk interceptor.
func (s testHandlerStore) GetStreamChunkInterceptor() lib.StreamChunkInterceptor {
return nil
}
Expand Down
2 changes: 1 addition & 1 deletion transports/bifrost-http/handlers/wsresponses_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ func (s testWSHandlerStore) GetHeaderMatcher() *lib.HeaderMatcher {
return nil
}

func (s testWSHandlerStore) GetAvailableProviders() []schemas.ModelProvider {
func (s testWSHandlerStore) GetAvailableProviders(model string) []schemas.ModelProvider {
return nil
}

Expand Down
15 changes: 15 additions & 0 deletions transports/bifrost-http/integrations/anthropic.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,18 @@ type AnthropicRouter struct {
*GenericRouter
}

// anthropicModelGetter extracts the model field from any Anthropic integration request type.
// It is called after body parsing, so req is fully populated.
func anthropicModelGetter(_ *fasthttp.RequestCtx, req interface{}) (string, error) {
switch r := req.(type) {
case *anthropic.AnthropicTextRequest:
return r.Model, nil
case *anthropic.AnthropicMessageRequest:
return r.Model, nil
}
return "", nil
Comment thread
Pratham-Mishra04 marked this conversation as resolved.
}

// createAnthropicCompleteRouteConfig creates a route configuration for the `/v1/complete` endpoint.
func createAnthropicCompleteRouteConfig(pathPrefix string) RouteConfig {
return RouteConfig{
Expand All @@ -35,6 +47,7 @@ func createAnthropicCompleteRouteConfig(pathPrefix string) RouteConfig {
GetRequestTypeInstance: func(ctx context.Context) interface{} {
return &anthropic.AnthropicTextRequest{}
},
GetRequestModel: anthropicModelGetter,
RequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*schemas.BifrostRequest, error) {
if anthropicReq, ok := req.(*anthropic.AnthropicTextRequest); ok {
return &schemas.BifrostRequest{
Expand Down Expand Up @@ -75,6 +88,7 @@ func createAnthropicMessagesRouteConfig(pathPrefix string, logger schemas.Logger
GetRequestTypeInstance: func(ctx context.Context) interface{} {
return &anthropic.AnthropicMessageRequest{}
},
GetRequestModel: anthropicModelGetter,
RequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*schemas.BifrostRequest, error) {
if anthropicReq, ok := req.(*anthropic.AnthropicMessageRequest); ok {
bifrostReq := anthropicReq.ToBifrostResponsesRequest(ctx)
Expand Down Expand Up @@ -394,6 +408,7 @@ func CreateAnthropicCountTokensRouteConfigs(pathPrefix string, handlerStore lib.
GetRequestTypeInstance: func(ctx context.Context) interface{} {
return &anthropic.AnthropicMessageRequest{}
},
GetRequestModel: anthropicModelGetter,
RequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*schemas.BifrostRequest, error) {
if anthropicReq, ok := req.(*anthropic.AnthropicMessageRequest); ok {
bifrostReq := anthropicReq.ToBifrostResponsesRequest(ctx)
Expand Down
22 changes: 22 additions & 0 deletions transports/bifrost-http/integrations/bedrock.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,23 @@ type BedrockRouter struct {
*GenericRouter
}

// bedrockModelGetter returns the model ID carried by a Bedrock integration request.
// It is called after PreCallback, so req.ModelID is populated from the URL path param.
//
// Converse and Invoke requests expose ModelID directly. Count-tokens requests
// carry it on the nested Converse input, which may be absent. Unknown request
// types, and count-tokens requests without a Converse input, yield an empty
// model ID. The returned error is always nil.
func bedrockModelGetter(_ *fasthttp.RequestCtx, req interface{}) (string, error) {
	var modelID string
	switch typed := req.(type) {
	case *bedrock.BedrockConverseRequest:
		modelID = typed.ModelID
	case *bedrock.BedrockInvokeRequest:
		modelID = typed.ModelID
	case *bedrock.BedrockCountTokensRequest:
		if converse := typed.Input.Converse; converse != nil {
			modelID = converse.ModelID
		}
	}
	return modelID, nil
}

// S3 context keys for storing request parameters

const (
Expand All @@ -42,6 +59,7 @@ func createBedrockConverseRouteConfig(pathPrefix string, handlerStore lib.Handle
GetHTTPRequestType: func(ctx *fasthttp.RequestCtx) schemas.RequestType {
return schemas.ResponsesRequest
},
GetRequestModel: bedrockModelGetter,
RequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*schemas.BifrostRequest, error) {
if bedrockReq, ok := req.(*bedrock.BedrockConverseRequest); ok {
bifrostReq, err := bedrockReq.ToBifrostResponsesRequest(ctx)
Expand Down Expand Up @@ -77,6 +95,7 @@ func createBedrockConverseStreamRouteConfig(pathPrefix string, handlerStore lib.
GetRequestTypeInstance: func(ctx context.Context) interface{} {
return &bedrock.BedrockConverseRequest{}
},
GetRequestModel: bedrockModelGetter,
RequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*schemas.BifrostRequest, error) {
if bedrockReq, ok := req.(*bedrock.BedrockConverseRequest); ok {
// Mark as streaming request
Expand Down Expand Up @@ -127,6 +146,7 @@ func createBedrockInvokeWithResponseStreamRouteConfig(pathPrefix string, handler
GetRequestTypeInstance: func(ctx context.Context) interface{} {
return &bedrock.BedrockInvokeRequest{}
},
GetRequestModel: bedrockModelGetter,
RequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*schemas.BifrostRequest, error) {
if invokeReq, ok := req.(*bedrock.BedrockInvokeRequest); ok {
requestType, _ := ctx.Value(schemas.BifrostContextKeyHTTPRequestType).(schemas.RequestType)
Expand Down Expand Up @@ -201,6 +221,7 @@ func createBedrockInvokeRouteConfig(pathPrefix string, handlerStore lib.HandlerS
GetRequestTypeInstance: func(ctx context.Context) interface{} {
return &bedrock.BedrockInvokeRequest{}
},
GetRequestModel: bedrockModelGetter,
RequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*schemas.BifrostRequest, error) {
invokeReq, ok := req.(*bedrock.BedrockInvokeRequest)
if !ok {
Expand Down Expand Up @@ -317,6 +338,7 @@ func createBedrockCountTokensRouteConfig(pathPrefix string, handlerStore lib.Han
GetHTTPRequestType: func(ctx *fasthttp.RequestCtx) schemas.RequestType {
return schemas.CountTokensRequest
},
GetRequestModel: bedrockModelGetter,
RequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*schemas.BifrostRequest, error) {
if countTokensReq, ok := req.(*bedrock.BedrockCountTokensRequest); ok {
if countTokensReq.Input.Converse == nil {
Expand Down
2 changes: 1 addition & 1 deletion transports/bifrost-http/integrations/bedrock_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ func (m *mockHandlerStore) GetHeaderMatcher() *lib.HeaderMatcher {
return m.headerMatcher
}

func (m *mockHandlerStore) GetAvailableProviders() []schemas.ModelProvider {
func (m *mockHandlerStore) GetAvailableProviders(model string) []schemas.ModelProvider {
return m.availableProviders
}

Expand Down
20 changes: 20 additions & 0 deletions transports/bifrost-http/integrations/cohere.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,22 @@ func NewCohereRouter(client *bifrost.Bifrost, handlerStore lib.HandlerStore, log
}
}

// cohereModelGetter returns the model named by a Cohere integration request.
// It is called after body parsing, so req is fully populated.
//
// Chat, embedding, rerank and count-tokens requests all expose the model via
// their Model field. Any other request type yields an empty model. The
// returned error is always nil.
func cohereModelGetter(_ *fasthttp.RequestCtx, req interface{}) (string, error) {
	var model string
	switch typed := req.(type) {
	case *cohere.CohereChatRequest:
		model = typed.Model
	case *cohere.CohereEmbeddingRequest:
		model = typed.Model
	case *cohere.CohereRerankRequest:
		model = typed.Model
	case *cohere.CohereCountTokensRequest:
		model = typed.Model
	}
	return model, nil
}

// CreateCohereRouteConfigs creates route configurations for Cohere API endpoints.
func CreateCohereRouteConfigs(pathPrefix string) []RouteConfig {
var routes []RouteConfig
Expand All @@ -85,6 +101,7 @@ func CreateCohereRouteConfigs(pathPrefix string) []RouteConfig {
GetRequestTypeInstance: func(ctx context.Context) interface{} {
return &cohere.CohereChatRequest{}
},
GetRequestModel: cohereModelGetter,
RequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*schemas.BifrostRequest, error) {
if cohereReq, ok := req.(*cohere.CohereChatRequest); ok {
return &schemas.BifrostRequest{
Expand Down Expand Up @@ -131,6 +148,7 @@ func CreateCohereRouteConfigs(pathPrefix string) []RouteConfig {
GetRequestTypeInstance: func(ctx context.Context) interface{} {
return &cohere.CohereEmbeddingRequest{}
},
GetRequestModel: cohereModelGetter,
RequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*schemas.BifrostRequest, error) {
if cohereReq, ok := req.(*cohere.CohereEmbeddingRequest); ok {
return &schemas.BifrostRequest{
Expand Down Expand Up @@ -164,6 +182,7 @@ func CreateCohereRouteConfigs(pathPrefix string) []RouteConfig {
GetRequestTypeInstance: func(ctx context.Context) interface{} {
return &cohere.CohereRerankRequest{}
},
GetRequestModel: cohereModelGetter,
RequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*schemas.BifrostRequest, error) {
if cohereReq, ok := req.(*cohere.CohereRerankRequest); ok {
return &schemas.BifrostRequest{
Expand Down Expand Up @@ -197,6 +216,7 @@ func CreateCohereRouteConfigs(pathPrefix string) []RouteConfig {
GetRequestTypeInstance: func(ctx context.Context) interface{} {
return &cohere.CohereCountTokensRequest{}
},
GetRequestModel: cohereModelGetter,
RequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*schemas.BifrostRequest, error) {
if cohereReq, ok := req.(*cohere.CohereCountTokensRequest); ok {
return &schemas.BifrostRequest{
Expand Down
60 changes: 51 additions & 9 deletions transports/bifrost-http/integrations/genai.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,33 @@ type GenAIRouter struct {
*GenericRouter
}

// genAIModelGetter extracts the model name for GenAI routes.
// For request types populated by extractAndSetModelAndRequestType (the PreCallback),
// the model is already clean on the struct. For BifrostVideoRetrieveRequest (which has
// no model field), the provider-scoped model is extracted from the operation_id suffix
// (format: "op123:openai/gpt-4o") since the route pins the provider via operation_id.
func genAIModelGetter(ctx *fasthttp.RequestCtx, req interface{}) (string, error) {
switch r := req.(type) {
case *gemini.GeminiGenerationRequest:
return r.Model, nil
case *gemini.GeminiEmbeddingRequest:
return r.Model, nil
case *gemini.GeminiVideoGenerationRequest:
return r.Model, nil
case *gemini.GeminiBatchCreateRequest:
return r.Model, nil
case *schemas.BifrostVideoRetrieveRequest:
// operation_id encodes the full model string: "op123:gpt-4o" or "op123:openai/gpt-4o".
operationID, _ := ctx.UserValue("operation_id").(string)
parts := strings.Split(operationID, ":")
if len(parts) >= 2 && parts[len(parts)-1] != "" {
return parts[len(parts)-1], nil
}
return "", nil
Comment thread
coderabbitai[bot] marked this conversation as resolved.
}
return "", nil
}
Comment thread
coderabbitai[bot] marked this conversation as resolved.

// CreateGenAIRouteConfigs creates route configurations for GenAI endpoints.
func CreateGenAIRouteConfigs(pathPrefix string) []RouteConfig {
var routes []RouteConfig
Expand All @@ -51,6 +78,7 @@ func CreateGenAIRouteConfigs(pathPrefix string) []RouteConfig {
GetRequestTypeInstance: func(ctx context.Context) interface{} {
return &schemas.BifrostVideoRetrieveRequest{}
},
GetRequestModel: genAIModelGetter,
RequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*schemas.BifrostRequest, error) {
if videoRetrieveReq, ok := req.(*schemas.BifrostVideoRetrieveRequest); ok {
return &schemas.BifrostRequest{
Expand Down Expand Up @@ -89,6 +117,7 @@ func CreateGenAIRouteConfigs(pathPrefix string) []RouteConfig {
}
return &gemini.GeminiGenerationRequest{}
},
GetRequestModel: genAIModelGetter,
RequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*schemas.BifrostRequest, error) {
if geminiReq, ok := req.(*gemini.GeminiGenerationRequest); ok {
if geminiReq.IsCountTokens {
Expand Down Expand Up @@ -790,6 +819,12 @@ func createGenAIRerankRouteConfig(pathPrefix string) RouteConfig {
GetRequestTypeInstance: func(ctx context.Context) interface{} {
return &vertex.VertexRankRequest{}
},
GetRequestModel: func(_ *fasthttp.RequestCtx, req interface{}) (string, error) {
if r, ok := req.(*vertex.VertexRankRequest); ok && r.Model != nil {
return *r.Model, nil
}
return "", nil
},
RequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*schemas.BifrostRequest, error) {
if vertexReq, ok := req.(*vertex.VertexRankRequest); ok {
return &schemas.BifrostRequest{
Expand Down Expand Up @@ -1282,31 +1317,38 @@ func extractGeminiVideoOperationFromPath(ctx *fasthttp.RequestCtx, bifrostCtx *s
return errors.New("operation_id must be a non-empty string")
}

// check provider from operation id suffix, id:provider, could be any provider
// operation_id encodes the raw model string as a suffix: "id:rawModel"
// rawModel is either "gpt-4o" (provider name or bare model) or "openai/gpt-4o" (provider/model).
parts := strings.Split(operationIDStr, ":")
if len(parts) < 2 || parts[len(parts)-1] == "" {
return errors.New("provider is required in operation_id format 'id:provider'")
return errors.New("raw model is required in operation_id format 'id:rawModel' or 'id:provider/model'")
}
rawModel := parts[len(parts)-1]

// Parse provider from rawModel: "openai/gpt-4o" → provider="openai"; "gemini" → provider="gemini".
var provider schemas.ModelProvider
rawModelParts := strings.SplitN(rawModel, "/", 2)
if len(rawModelParts) == 2 {
provider = schemas.ModelProvider(rawModelParts[0])
} else {
provider = schemas.ModelProvider(rawModel)
}
provider := parts[len(parts)-1]

modelStr, ok := model.(string)
if !ok || modelStr == "" {
modelStr = provider
modelStr = rawModel
}

// if its gemini, set r.ID in format models/model/operations/operation_id:provider
// else set r.ID in format operation_id:provider

switch r := req.(type) {
case *schemas.BifrostVideoRetrieveRequest:
r.Provider = schemas.ModelProvider(provider)
r.Provider = provider

if r.Provider == schemas.OpenAI || r.Provider == schemas.Azure {
// set a context flag to have video download request after video retrieve request when incoming request is coming from genai integration
bifrostCtx.SetValue(schemas.BifrostContextKeyVideoOutputRequested, true)
}
// Gemini provider expects an operation resource path (without /v1beta prefix).
if provider == string(schemas.Gemini) {
if provider == schemas.Gemini {
r.ID = "models/" + modelStr + "/operations/" + operationIDStr
} else {
r.ID = operationIDStr
Expand Down
Loading
Loading