Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
276 changes: 242 additions & 34 deletions core/bifrost.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,21 +33,22 @@ type ChannelMessage struct {
// It handles request routing, provider management, and response processing.
type Bifrost struct {
	ctx                 context.Context
	account             schemas.Account                    // account interface
	plugins             atomic.Pointer[[]schemas.Plugin]   // list of plugins
	providers           atomic.Pointer[[]schemas.Provider] // provider instances; replaced copy-and-swap as providers are prepared, scanned by getProviderByKey
	requestQueues       sync.Map                           // provider request queues (thread-safe)
	waitGroups          sync.Map                           // wait groups for each provider (thread-safe)
	providerMutexes     sync.Map                           // mutexes for each provider to prevent concurrent updates (thread-safe)
	channelMessagePool  sync.Pool                          // Pool for ChannelMessage objects, initial pool size is set in Init
	responseChannelPool sync.Pool                          // Pool for response channels, initial pool size is set in Init
	errorChannelPool    sync.Pool                          // Pool for error channels, initial pool size is set in Init
	responseStreamPool  sync.Pool                          // Pool for response stream channels, initial pool size is set in Init
	pluginPipelinePool  sync.Pool                          // Pool for PluginPipeline objects
	bifrostRequestPool  sync.Pool                          // Pool for BifrostRequest objects
	logger              schemas.Logger                     // logger instance, default logger is used if not provided
	mcpManager          *MCPManager                        // MCP integration manager (nil if MCP not configured)
	dropExcessRequests  atomic.Bool                        // If true, requests arriving while the queue is full are dropped instead of waiting for space.
	keySelector         schemas.KeySelector                // Custom key selector function
}

// PluginPipeline encapsulates the execution of plugin PreHooks and PostHooks, tracks how many plugins ran, and manages short-circuiting and error aggregation.
Expand Down Expand Up @@ -91,6 +92,10 @@ func Init(ctx context.Context, config schemas.BifrostConfig) (*Bifrost, error) {
keySelector: config.KeySelector,
}
bifrost.plugins.Store(&config.Plugins)

// Initialize providers slice
bifrost.providers.Store(&[]schemas.Provider{})

bifrost.dropExcessRequests.Store(config.DropExcessRequests)

if bifrost.keySelector == nil {
Expand Down Expand Up @@ -203,6 +208,169 @@ func (bifrost *Bifrost) ReloadConfig(config schemas.BifrostConfig) error {

// PUBLIC API METHODS

// ListModelsRequest sends a list models request to the specified provider.
func (bifrost *Bifrost) ListModelsRequest(ctx context.Context, req *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) {
if req == nil {
return nil, &schemas.BifrostError{
IsBifrostError: false,
Error: &schemas.ErrorField{
Message: "list models request is nil",
},
}
}
if req.Provider == "" {
return nil, &schemas.BifrostError{
IsBifrostError: false,
Error: &schemas.ErrorField{
Message: "provider is required for list models request",
},
}
}

request := &schemas.BifrostListModelsRequest{
Provider: req.Provider,
PageSize: req.PageSize,
PageToken: req.PageToken,
ExtraParams: req.ExtraParams,
}

provider := bifrost.getProviderByKey(req.Provider)
if provider == nil {
return nil, &schemas.BifrostError{
IsBifrostError: false,
Error: &schemas.ErrorField{
Message: "provider not found for list models request",
},
}
}

// Determine the base provider type for key requirement checks
baseProvider := req.Provider
providerConfig, err := bifrost.account.GetConfigForProvider(req.Provider)
if err == nil && providerConfig.CustomProviderConfig != nil && providerConfig.CustomProviderConfig.BaseProviderType != "" {
baseProvider = providerConfig.CustomProviderConfig.BaseProviderType
}

// Get API key for the provider if required
key := schemas.Key{}
if providerRequiresKey(baseProvider) {
key, err = bifrost.selectKeyFromProviderForModel(&ctx, schemas.ListModelsRequest, req.Provider, "", baseProvider)
if err != nil {
return nil, &schemas.BifrostError{
IsBifrostError: false,
Error: &schemas.ErrorField{
Message: err.Error(),
Error: err,
},
}
}
}

response, bifrostErr := provider.ListModels(ctx, key, request)
if bifrostErr != nil {
return nil, bifrostErr
}
Comment thread
TejasGhatte marked this conversation as resolved.
return response, nil
}
Comment thread
coderabbitai[bot] marked this conversation as resolved.

// ListAllModels lists all models from all configured providers.
// It accumulates responses from all providers with a limit of 1000 per provider to get all results.
func (bifrost *Bifrost) ListAllModels(ctx context.Context, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) {
startTime := time.Now()

if request == nil {
request = &schemas.BifrostListModelsRequest{}
}

providerKeys, err := bifrost.account.GetConfiguredProviders()
if err != nil {
return nil, &schemas.BifrostError{
IsBifrostError: false,
Error: &schemas.ErrorField{
Message: "failed to get configured providers",
Error: err,
},
}
}

// Accumulate all models from all providers
allModels := make([]schemas.Model, 0)
var firstError *schemas.BifrostError

for _, providerKey := range providerKeys {
if strings.TrimSpace(string(providerKey)) == "" {
continue
}

// Create request for this provider with limit of 1000
providerRequest := &schemas.BifrostListModelsRequest{
Provider: providerKey,
PageSize: schemas.DefaultPageSize,
}

iterations := 0
Comment thread
TejasGhatte marked this conversation as resolved.
for {
iterations++
if iterations > schemas.MaxPaginationRequests {
bifrost.logger.Warn(fmt.Sprintf("reached maximum pagination requests (%d) for provider %s", schemas.MaxPaginationRequests, providerKey))
break
}

response, bifrostErr := bifrost.ListModelsRequest(ctx, providerRequest)
if bifrostErr != nil {
// Log the error but continue with other providers
bifrost.logger.Warn(fmt.Sprintf("failed to list models for provider %s: %v", providerKey, bifrostErr.Error.Message))
if firstError == nil {
firstError = bifrostErr
}
break
}

if response == nil {
break
}

if len(response.Data) > 0 {
allModels = append(allModels, response.Data...)
}

// Check if there are more pages
if response.NextPageToken == "" {
break
}

// Set the page token for the next request
providerRequest.PageToken = response.NextPageToken
}
}

// If we couldn't get any models from any provider, return the first error
if len(allModels) == 0 && firstError != nil {
return nil, firstError
}

// Sort models alphabetically by ID
sort.Slice(allModels, func(i, j int) bool {
return allModels[i].ID < allModels[j].ID
})

// Calculate total elapsed time
elapsedTime := time.Since(startTime).Milliseconds()

// Return aggregated response with accumulated latency
response := &schemas.BifrostListModelsResponse{
Data: allModels,
ExtraFields: schemas.BifrostResponseExtraFields{
RequestType: schemas.ListModelsRequest,
Latency: elapsedTime,
},
}

response = response.ApplyPagination(request.PageSize, request.PageToken)

return response, nil
}

// TextCompletionRequest sends a text completion request to the specified provider.
func (bifrost *Bifrost) TextCompletionRequest(ctx context.Context, req *schemas.BifrostTextCompletionRequest) (*schemas.BifrostTextCompletionResponse, *schemas.BifrostError) {
if req == nil {
Expand Down Expand Up @@ -1040,6 +1208,21 @@ func (bifrost *Bifrost) prepareProvider(providerKey schemas.ModelProvider, confi
waitGroupValue, _ := bifrost.waitGroups.Load(providerKey)
currentWaitGroup := waitGroupValue.(*sync.WaitGroup)

// Atomically append provider to the providers slice
for {
oldPtr := bifrost.providers.Load()
var oldSlice []schemas.Provider
if oldPtr != nil {
oldSlice = *oldPtr
}
newSlice := make([]schemas.Provider, len(oldSlice)+1)
copy(newSlice, oldSlice)
newSlice[len(oldSlice)] = provider
if bifrost.providers.CompareAndSwap(oldPtr, &newSlice) {
break
}
}

Comment thread
coderabbitai[bot] marked this conversation as resolved.
for range providerConfig.ConcurrencyAndBufferSize.Concurrency {
currentWaitGroup.Add(1)
go bifrost.requestWorker(provider, providerConfig, queue)
Expand Down Expand Up @@ -1092,6 +1275,23 @@ func (bifrost *Bifrost) getProviderQueue(providerKey schemas.ModelProvider) (cha
return queue, nil
}

// getProviderByKey looks up a provider instance by its provider key.
// It loads the current providers slice and scans it in order, returning the
// first provider whose GetProviderKey matches providerKey. It returns nil when
// the slice pointer is unset or no provider matches.
func (bifrost *Bifrost) getProviderByKey(providerKey schemas.ModelProvider) schemas.Provider {
	snapshot := bifrost.providers.Load()
	if snapshot != nil {
		registered := *snapshot
		for idx := range registered {
			if candidate := registered[idx]; candidate.GetProviderKey() == providerKey {
				return candidate
			}
		}
	}

	return nil
}

// CORE INTERNAL LOGIC

// shouldTryFallbacks handles the primary error and returns true if we should proceed with fallbacks, false if we should return immediately
Expand Down Expand Up @@ -1589,7 +1789,7 @@ func (bifrost *Bifrost) requestWorker(provider schemas.Provider, config *schemas
key := schemas.Key{}
if providerRequiresKey(baseProvider) {
// Use the custom provider name for actual key selection, but pass base provider type for key validation
key, err = bifrost.selectKeyFromProviderForModel(&req.Context, provider.GetProviderKey(), model, baseProvider)
key, err = bifrost.selectKeyFromProviderForModel(&req.Context, req.RequestType, provider.GetProviderKey(), model, baseProvider)
if err != nil {
bifrost.logger.Warn("error selecting key for model %s: %v", model, err)
req.Err <- schemas.BifrostError{
Expand Down Expand Up @@ -1954,7 +2154,7 @@ func (bifrost *Bifrost) releaseBifrostRequest(req *schemas.BifrostRequest) {

// selectKeyFromProviderForModel selects an appropriate API key for a given provider and model.
// It uses weighted random selection if multiple keys are available.
func (bifrost *Bifrost) selectKeyFromProviderForModel(ctx *context.Context, providerKey schemas.ModelProvider, model string, baseProviderType schemas.ModelProvider) (schemas.Key, error) {
func (bifrost *Bifrost) selectKeyFromProviderForModel(ctx *context.Context, requestType schemas.RequestType, providerKey schemas.ModelProvider, model string, baseProviderType schemas.ModelProvider) (schemas.Key, error) {
// Check if key has been set in the context explicitly
if ctx != nil {
key, ok := (*ctx).Value(schemas.BifrostContextKeyDirectKey).(schemas.Key)
Expand All @@ -1974,28 +2174,36 @@ func (bifrost *Bifrost) selectKeyFromProviderForModel(ctx *context.Context, prov

// Filter out keys that don't support the model; a key with no models listed is treated as supporting all models.
var supportedKeys []schemas.Key
for _, key := range keys {
modelSupported := (slices.Contains(key.Models, model) && (strings.TrimSpace(key.Value) != "" || canProviderKeyValueBeEmpty(baseProviderType))) || len(key.Models) == 0

// Additional deployment checks for Azure and Bedrock
deploymentSupported := true
if baseProviderType == schemas.Azure && key.AzureKeyConfig != nil {
// For Azure, check if deployment exists for this model
if len(key.AzureKeyConfig.Deployments) > 0 {
_, deploymentSupported = key.AzureKeyConfig.Deployments[model]
}
} else if baseProviderType == schemas.Bedrock && key.BedrockKeyConfig != nil {
// For Bedrock, check if deployment exists for this model
if len(key.BedrockKeyConfig.Deployments) > 0 {
_, deploymentSupported = key.BedrockKeyConfig.Deployments[model]
if requestType == schemas.ListModelsRequest {
// Skip deployment check but still check if the key has a value
for _, k := range keys {
if strings.TrimSpace(k.Value) != "" || canProviderKeyValueBeEmpty(baseProviderType) {
supportedKeys = append(supportedKeys, k)
}
}
} else {
for _, key := range keys {
modelSupported := (slices.Contains(key.Models, model) && (strings.TrimSpace(key.Value) != "" || canProviderKeyValueBeEmpty(baseProviderType))) || len(key.Models) == 0

if modelSupported && deploymentSupported {
supportedKeys = append(supportedKeys, key)
// Additional deployment checks for Azure and Bedrock
deploymentSupported := true
if baseProviderType == schemas.Azure && key.AzureKeyConfig != nil {
// For Azure, check if deployment exists for this model
if len(key.AzureKeyConfig.Deployments) > 0 {
_, deploymentSupported = key.AzureKeyConfig.Deployments[model]
}
} else if baseProviderType == schemas.Bedrock && key.BedrockKeyConfig != nil {
// For Bedrock, check if deployment exists for this model
if len(key.BedrockKeyConfig.Deployments) > 0 {
_, deploymentSupported = key.BedrockKeyConfig.Deployments[model]
}
}

if modelSupported && deploymentSupported {
supportedKeys = append(supportedKeys, key)
}
}
}

if len(supportedKeys) == 0 {
if baseProviderType == schemas.Azure || baseProviderType == schemas.Bedrock {
return schemas.Key{}, fmt.Errorf("no keys found that support model/deployment: %s", model)
Expand Down
8 changes: 2 additions & 6 deletions core/changelog.md
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
<!-- The pattern we follow here is to keep the changelog for the latest version -->
<!-- Old changelogs are automatically attached to the GitHub releases -->

- bug: fixed embedding request not being handled in `GetExtraFields()` method of `BifrostResponse`
- fix: added latency calculation for vertex native requests
- feat: added cached tokens and reasoning tokens to the usage metadata for chat completions
- feat: added global region support for vertex API
- fix: added filter for extra fields in chat completions request for Mistral provider
- fix: fixed ResponsesComputerToolCallPendingSafetyCheck code field
- feat: added ListModels method to Provider interface
- feat: enabled provider tracking in Bifrost core for API exposure
Loading