Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

**The fastest way to build AI applications that never go down.**

Bifrost is a high-performance AI gateway that connects you to 10+ providers (OpenAI, Anthropic, Bedrock, and more) through a single API. Get automatic failover, load balancing, and zero-downtime deployments in under 30 seconds.

![Bifrost](./docs/media/cover.png)

Expand Down Expand Up @@ -260,7 +260,7 @@ Choose higher settings (like the t3.xlarge profile above) for raw speed, or lowe
<details>
<summary><strong>🎯 I want to understand what Bifrost can do</strong></summary>

- **[🔗 Multi-Provider Support](./docs/usage/providers.md)** - Connect to 10+ AI providers with one API
- **[🛡️ Fallback & Reliability](./docs/usage/providers.md#fallback-mechanisms)** - Never lose a request with automatic failover
- **[🛠️ MCP Tool Integration](./docs/usage/http-transport/configuration/mcp.md)** - Give your AI external capabilities
- **[🔌 Plugin Ecosystem](./docs/usage/http-transport/configuration/plugins.md)** - Extend Bifrost with custom middleware
Expand Down
226 changes: 205 additions & 21 deletions core/bifrost.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,10 @@ import (
// RequestType identifies which kind of operation a ChannelMessage carries,
// so the worker can dispatch it to the matching provider call.
type RequestType string

const (
	TextCompletionRequest       RequestType = "text_completion"
	ChatCompletionRequest       RequestType = "chat_completion"
	ChatCompletionStreamRequest RequestType = "chat_completion_stream"
	EmbeddingRequest            RequestType = "embedding"
)

// executor is a function type that handles specific request types.
Expand All @@ -40,10 +41,11 @@ var messageExecutors = map[RequestType]executor{
// ChannelMessage is the unit of work placed on a provider's queue.
// It embeds the original request and carries the context plus the channels
// the worker uses to deliver either a single response, a stream of stream
// channels (for streaming requests), or an error back to the caller.
type ChannelMessage struct {
	schemas.BifrostRequest
	Context context.Context
	// Response receives the final result for non-streaming request types.
	Response chan *schemas.BifrostResponse
	// ResponseStream receives the per-request stream channel for
	// ChatCompletionStreamRequest; only one of Response/ResponseStream is used.
	ResponseStream chan chan *schemas.BifrostStream
	Err            chan schemas.BifrostError
	Type           RequestType
}

// Bifrost manages providers and maintains specified open channels for concurrent processing.
Expand Down Expand Up @@ -315,6 +317,70 @@ func (bifrost *Bifrost) ChatCompletionRequest(ctx context.Context, req *schemas.
return nil, primaryErr
}

// ChatCompletionStreamRequest sends a chat completion stream request to the specified provider.
// It handles plugin hooks, request validation, response processing, and fallback providers.
// If the primary provider fails, it will try each fallback provider in order until one succeeds.
func (bifrost *Bifrost) ChatCompletionStreamRequest(ctx context.Context, req *schemas.BifrostRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) {
if err := validateRequest(req); err != nil {
err.Provider = req.Provider
return nil, err
}

// Try the primary provider first
primaryResult, primaryErr := bifrost.tryStreamRequest(req, ctx, ChatCompletionStreamRequest)
if primaryErr == nil {
return primaryResult, nil
}

if primaryErr.Error.Type != nil && *primaryErr.Error.Type == schemas.RequestCancelled {
primaryErr.Provider = req.Provider
return nil, primaryErr
}

// Check if this is a short-circuit error that doesn't allow fallbacks
// Note: AllowFallbacks = nil is treated as true (allow fallbacks by default)
if primaryErr.AllowFallbacks != nil && !*primaryErr.AllowFallbacks {
primaryErr.Provider = req.Provider
return nil, primaryErr
}

// If primary provider failed and we have fallbacks, try them in order
// This includes both regular provider errors and plugin short-circuit errors with AllowFallbacks=true/nil
if len(req.Fallbacks) > 0 {
for _, fallback := range req.Fallbacks {
// Check if we have config for this fallback provider
_, err := bifrost.account.GetConfigForProvider(fallback.Provider)
if err != nil {
bifrost.logger.Warn(fmt.Sprintf("Config not found for provider %s, skipping fallback: %v", fallback.Provider, err))
continue
}

// Create a new request with the fallback provider and model
fallbackReq := *req
fallbackReq.Provider = fallback.Provider
fallbackReq.Model = fallback.Model

// Try the fallback provider
result, fallbackErr := bifrost.tryStreamRequest(&fallbackReq, ctx, ChatCompletionStreamRequest)
if fallbackErr == nil {
bifrost.logger.Info(fmt.Sprintf("Successfully used fallback provider %s with model %s", fallback.Provider, fallback.Model))
return result, nil
}
if fallbackErr.Error.Type != nil && *fallbackErr.Error.Type == schemas.RequestCancelled {
fallbackErr.Provider = fallback.Provider
return nil, fallbackErr
}

bifrost.logger.Warn(fmt.Sprintf("Fallback provider %s failed: %s", fallback.Provider, fallbackErr.Error.Message))
}
Comment thread
coderabbitai[bot] marked this conversation as resolved.
}

primaryErr.Provider = req.Provider

// All providers failed, return the original error
return nil, primaryErr
}

// EmbeddingRequest sends an embedding request to the specified provider.
// It handles plugin hooks, request validation, response processing, and fallback providers.
// If the primary provider fails, it will try each fallback provider in order until one succeeds.
Expand Down Expand Up @@ -752,6 +818,10 @@ func (bifrost *Bifrost) createProviderFromProviderKey(providerKey schemas.ModelP
return providers.NewMistralProvider(config, bifrost.logger), nil
case schemas.Ollama:
return providers.NewOllamaProvider(config, bifrost.logger)
case schemas.Groq:
return providers.NewGroqProvider(config, bifrost.logger)
case schemas.SGL:
return providers.NewSGLProvider(config, bifrost.logger)
default:
return nil, fmt.Errorf("unsupported provider: %s", providerKey)
}
Expand Down Expand Up @@ -930,6 +1000,82 @@ func (bifrost *Bifrost) tryRequest(req *schemas.BifrostRequest, ctx context.Cont
}
}

// tryStreamRequest is a generic function that handles common request processing logic
// It consolidates queue setup, plugin pipeline execution, enqueue logic, and response handling
func (bifrost *Bifrost) tryStreamRequest(req *schemas.BifrostRequest, ctx context.Context, requestType RequestType) (chan *schemas.BifrostStream, *schemas.BifrostError) {
queue, err := bifrost.getProviderQueue(req.Provider)
if err != nil {
return nil, newBifrostError(err)
}

// Add MCP tools to request if MCP is configured and requested
if requestType != EmbeddingRequest && bifrost.mcpManager != nil {
req = bifrost.mcpManager.addMCPToolsToBifrostRequest(ctx, req)
}

pipeline := bifrost.getPluginPipeline()
defer bifrost.releasePluginPipeline(pipeline)

preReq, shortCircuit, preCount := pipeline.RunPreHooks(&ctx, req)
if shortCircuit != nil {
// Handle short-circuit with response (success case)
if shortCircuit.Response != nil {
resp, bifrostErr := pipeline.RunPostHooks(&ctx, shortCircuit.Response, nil, preCount)
if bifrostErr != nil {
return nil, bifrostErr
}
return newBifrostMessageChan(resp), nil
}
// Handle short-circuit with error
if shortCircuit.Error != nil {
resp, bifrostErr := pipeline.RunPostHooks(&ctx, nil, shortCircuit.Error, preCount)
if bifrostErr != nil {
return nil, bifrostErr
}
return newBifrostMessageChan(resp), nil
}
}
if preReq == nil {
return nil, newBifrostErrorFromMsg("bifrost request after plugin hooks cannot be nil")
}

msg := bifrost.getChannelMessage(*preReq, requestType)
msg.Context = ctx

select {
case queue <- *msg:
// Message was sent successfully
case <-ctx.Done():
bifrost.releaseChannelMessage(msg)
return nil, newBifrostErrorFromMsg("request cancelled while waiting for queue space")
default:
if bifrost.dropExcessRequests.Load() {
bifrost.releaseChannelMessage(msg)
bifrost.logger.Warn("Request dropped: queue is full, please increase the queue size or set dropExcessRequests to false")
return nil, newBifrostErrorFromMsg("request dropped: queue is full")
}
if ctx == nil {
ctx = bifrost.backgroundCtx
}
select {
case queue <- *msg:
// Message was sent successfully
case <-ctx.Done():
bifrost.releaseChannelMessage(msg)
return nil, newBifrostErrorFromMsg("request cancelled while waiting for queue space")
}
}

select {
case stream := <-msg.ResponseStream:
bifrost.releaseChannelMessage(msg)
return stream, nil
case bifrostErrVal := <-msg.Err:
bifrost.releaseChannelMessage(msg)
return nil, &bifrostErrVal
}
}
Comment thread
Pratham-Mishra04 marked this conversation as resolved.

// requestWorker handles incoming requests from the queue for a specific provider.
// It manages retries, error handling, and response processing.
func (bifrost *Bifrost) requestWorker(provider schemas.Provider, queue chan ChannelMessage) {
Expand All @@ -942,6 +1088,7 @@ func (bifrost *Bifrost) requestWorker(provider schemas.Provider, queue chan Chan

for req := range queue {
var result *schemas.BifrostResponse
var stream chan *schemas.BifrostStream
var bifrostError *schemas.BifrostError
var err error

Expand Down Expand Up @@ -994,21 +1141,38 @@ func (bifrost *Bifrost) requestWorker(provider schemas.Provider, queue chan Chan

bifrost.logger.Debug(fmt.Sprintf("Attempting request for provider %s", provider.GetProviderKey()))

// Attempt the request using factory pattern
executor := messageExecutors[req.Type]
if executor == nil {
bifrostError = &schemas.BifrostError{
IsBifrostError: false,
Error: schemas.ErrorField{
Message: fmt.Sprintf("unsupported request type: %s", req.Type),
},
// Attempt the request
if req.Type == ChatCompletionStreamRequest {
pipeline := bifrost.getPluginPipeline()
defer bifrost.releasePluginPipeline(pipeline)

postHookRunner := func(ctx *context.Context, result *schemas.BifrostResponse, err *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError) {
resp, bifrostErr := pipeline.RunPostHooks(ctx, result, err, len(bifrost.plugins))
if bifrostErr != nil {
return nil, bifrostErr
}
return resp, nil
}
stream, bifrostError = handleChatCompletionStream(provider, &req, key, postHookRunner)
if bifrostError != nil && !bifrostError.IsBifrostError {
break // Don't retry client errors
}
} else {
Comment thread
coderabbitai[bot] marked this conversation as resolved.
executor := messageExecutors[req.Type]
if executor == nil {
bifrostError = &schemas.BifrostError{
IsBifrostError: false,
Error: schemas.ErrorField{
Message: fmt.Sprintf("unsupported request type: %s", req.Type),
},
}
break
}
break
}

result, bifrostError = executor(provider, &req, key)
if bifrostError != nil && !bifrostError.IsBifrostError {
break // Don't retry client errors
result, bifrostError = executor(provider, &req, key)
if bifrostError != nil && !bifrostError.IsBifrostError {
break // Don't retry client errors
}
}

bifrost.logger.Debug(fmt.Sprintf("Request for provider %s completed", provider.GetProviderKey()))
Expand All @@ -1031,7 +1195,11 @@ func (bifrost *Bifrost) requestWorker(provider schemas.Provider, queue chan Chan
}
req.Err <- *bifrostError
} else {
req.Response <- result
if req.Type == ChatCompletionStreamRequest {
req.ResponseStream <- stream
} else {
req.Response <- result
}
}
}

Expand Down Expand Up @@ -1077,6 +1245,20 @@ func handleEmbedding(provider schemas.Provider, req *ChannelMessage, key string)
return provider.Embedding(req.Context, req.Model, key, req.Input.EmbeddingInput, req.Params)
}

// handleChatCompletionStream executes a chat completion stream request against
// the given provider, validating that chat input is present before delegating.
func handleChatCompletionStream(provider schemas.Provider, req *ChannelMessage, key string, postHookRunner schemas.PostHookRunner) (chan *schemas.BifrostStream, *schemas.BifrostError) {
	chatInput := req.Input.ChatCompletionInput
	if chatInput == nil {
		// Client-side error: nothing to stream without chat messages.
		return nil, &schemas.BifrostError{
			IsBifrostError: false,
			Error: schemas.ErrorField{
				Message: "chats not provided for chat completion request",
			},
		}
	}

	return provider.ChatCompletionStream(req.Context, postHookRunner, req.Model, key, *chatInput, req.Params)
}

// PLUGIN MANAGEMENT

// RunPreHooks executes PreHooks in order, tracks how many ran, and returns the final request, any short-circuit decision, and the count.
Expand Down Expand Up @@ -1176,6 +1358,7 @@ func (bifrost *Bifrost) getChannelMessage(req schemas.BifrostRequest, reqType Re
msg := bifrost.channelMessagePool.Get().(*ChannelMessage)
msg.BifrostRequest = req
msg.Response = responseChan
msg.ResponseStream = make(chan chan *schemas.BifrostStream, 1) // Initialize the ResponseStream channel
Comment thread
Pratham-Mishra04 marked this conversation as resolved.
msg.Err = errorChan
msg.Type = reqType

Expand All @@ -1190,6 +1373,7 @@ func (bifrost *Bifrost) releaseChannelMessage(msg *ChannelMessage) {

// Clear references and return to pool
msg.Response = nil
msg.ResponseStream = nil
msg.Err = nil
bifrost.channelMessagePool.Put(msg)
}
Expand Down
Loading