Skip to content
Merged
247 changes: 142 additions & 105 deletions core/bifrost.go
Comment thread
connyay marked this conversation as resolved.
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,19 @@ type RequestType string
const (
	// TextCompletionRequest identifies a text completion request.
	TextCompletionRequest RequestType = "text_completion"
	// ChatCompletionRequest identifies a chat completion request.
	ChatCompletionRequest RequestType = "chat_completion"
	// EmbeddingRequest identifies an embedding request.
	EmbeddingRequest RequestType = "embedding"
)

// executor is a function type that handles one specific request type.
// It receives the provider to call, the queued channel message, and the key
// string to use for the call, and returns either a response or a BifrostError.
type executor func(provider schemas.Provider, req *ChannelMessage, key string) (*schemas.BifrostResponse, *schemas.BifrostError)

// messageExecutors is a factory map for handling different request types.
// It maps each RequestType to the executor that performs it. A lookup with a
// RequestType that has no entry yields a nil executor, which callers must
// check for before invoking.
var messageExecutors = map[RequestType]executor{
	TextCompletionRequest: handleTextCompletion,
	ChatCompletionRequest: handleChatCompletion,
	EmbeddingRequest:      handleEmbedding,
}

// ChannelMessage represents a message passed through the request channel.
// It contains the request, response and error channels, and the request type.
type ChannelMessage struct {
Expand Down Expand Up @@ -380,6 +391,45 @@ func (bifrost *Bifrost) calculateBackoff(attempt int, config *schemas.ProviderCo
return time.Duration(jitter)
}

// handleTextCompletion runs a text completion request against provider.
// The request's context, model, text completion input and params are passed
// to provider.TextCompletion together with key. If the request carries no text
// completion input, the provider is not called and a BifrostError with
// IsBifrostError set to false is returned instead.
func handleTextCompletion(provider schemas.Provider, req *ChannelMessage, key string) (*schemas.BifrostResponse, *schemas.BifrostError) {
	if input := req.Input.TextCompletionInput; input != nil {
		return provider.TextCompletion(req.Context, req.Model, key, *input, req.Params)
	}

	// Nothing to send: report the missing input without calling the provider.
	missingInput := schemas.ErrorField{
		Message: "text not provided for text completion request",
	}
	return nil, &schemas.BifrostError{
		IsBifrostError: false,
		Error:          missingInput,
	}
}

// handleChatCompletion runs a chat completion request against provider.
// The request's context, model, chat completion input and params are passed
// to provider.ChatCompletion together with key. If the request carries no chat
// completion input, the provider is not called and a BifrostError with
// IsBifrostError set to false is returned instead.
func handleChatCompletion(provider schemas.Provider, req *ChannelMessage, key string) (*schemas.BifrostResponse, *schemas.BifrostError) {
	if input := req.Input.ChatCompletionInput; input != nil {
		return provider.ChatCompletion(req.Context, req.Model, key, *input, req.Params)
	}

	// Nothing to send: report the missing input without calling the provider.
	missingInput := schemas.ErrorField{
		Message: "chats not provided for chat completion request",
	}
	return nil, &schemas.BifrostError{
		IsBifrostError: false,
		Error:          missingInput,
	}
}

// handleEmbedding runs an embedding request against provider.
// The request's context, model, embedding input and params are passed to
// provider.Embedding together with key; the input is handed over as a pointer,
// not dereferenced. If the request carries no embedding input, the provider is
// not called and a BifrostError with IsBifrostError set to false is returned
// instead.
func handleEmbedding(provider schemas.Provider, req *ChannelMessage, key string) (*schemas.BifrostResponse, *schemas.BifrostError) {
	if input := req.Input.EmbeddingInput; input != nil {
		return provider.Embedding(req.Context, req.Model, key, input, req.Params)
	}

	// Nothing to send: report the missing input without calling the provider.
	missingInput := schemas.ErrorField{
		Message: "input not provided for embedding request",
	}
	return nil, &schemas.BifrostError{
		IsBifrostError: false,
		Error:          missingInput,
	}
}

// requestWorker handles incoming requests from the queue for a specific provider.
// It manages retries, error handling, and response processing.
func (bifrost *Bifrost) requestWorker(provider schemas.Provider, queue chan ChannelMessage) {
Expand Down Expand Up @@ -439,31 +489,21 @@ func (bifrost *Bifrost) requestWorker(provider schemas.Provider, queue chan Chan

bifrost.logger.Debug(fmt.Sprintf("Attempting request for provider %s", provider.GetProviderKey()))

// Attempt the request
if req.Type == TextCompletionRequest {
if req.Input.TextCompletionInput == nil {
bifrostError = &schemas.BifrostError{
IsBifrostError: false,
Error: schemas.ErrorField{
Message: "text not provided for text completion request",
},
}
break // Don't retry client errors
} else {
result, bifrostError = provider.TextCompletion(req.Context, req.Model, key, *req.Input.TextCompletionInput, req.Params)
}
} else if req.Type == ChatCompletionRequest {
if req.Input.ChatCompletionInput == nil {
bifrostError = &schemas.BifrostError{
IsBifrostError: false,
Error: schemas.ErrorField{
Message: "chats not provided for chat completion request",
},
}
break // Don't retry client errors
} else {
result, bifrostError = provider.ChatCompletion(req.Context, req.Model, key, *req.Input.ChatCompletionInput, req.Params)
// Attempt the request using factory pattern
executor := messageExecutors[req.Type]
if executor == nil {
bifrostError = &schemas.BifrostError{
IsBifrostError: false,
Error: schemas.ErrorField{
Message: fmt.Sprintf("unsupported request type: %s", req.Type),
},
}
break
}

result, bifrostError = executor(provider, &req, key)
if bifrostError != nil && !bifrostError.IsBifrostError {
break // Don't retry client errors
}

bifrost.logger.Debug(fmt.Sprintf("Request for provider %s completed", provider.GetProviderKey()))
Expand Down Expand Up @@ -614,93 +654,81 @@ func (bifrost *Bifrost) TextCompletionRequest(ctx context.Context, req *schemas.
// tryTextCompletion attempts a text completion request with a single provider.
// This is a helper function used by TextCompletionRequest to handle individual provider attempts.
func (bifrost *Bifrost) tryTextCompletion(req *schemas.BifrostRequest, ctx context.Context) (*schemas.BifrostResponse, *schemas.BifrostError) {
queue, err := bifrost.getProviderQueue(req.Provider)
if err != nil {
return nil, newBifrostError(err)
}
return bifrost.tryRequest(req, ctx, TextCompletionRequest, true)
}

// Add MCP tools to request if MCP is configured
if bifrost.mcpManager != nil {
req = bifrost.mcpManager.addMCPToolsToBifrostRequest(ctx, req)
// ChatCompletionRequest sends a chat completion request to the specified provider.
// It handles plugin hooks, request validation, response processing, and fallback providers.
// If the primary provider fails, it will try each fallback provider in order until one succeeds.
func (bifrost *Bifrost) ChatCompletionRequest(ctx context.Context, req *schemas.BifrostRequest) (*schemas.BifrostResponse, *schemas.BifrostError) {
if req == nil {
return nil, newBifrostErrorFromMsg("bifrost request cannot be nil")
}

pipeline := NewPluginPipeline(bifrost.plugins, bifrost.logger)
preReq, shortCircuit, preCount := pipeline.RunPreHooks(&ctx, req)
if shortCircuit != nil {
// Handle short-circuit with response (success case)
if shortCircuit.Response != nil {
resp, bifrostErr := pipeline.RunPostHooks(&ctx, shortCircuit.Response, nil, preCount)
if bifrostErr != nil {
return nil, bifrostErr
}
return resp, nil
}
// Handle short-circuit with error
if shortCircuit.Error != nil {
resp, bifrostErr := pipeline.RunPostHooks(&ctx, nil, shortCircuit.Error, preCount)
if bifrostErr != nil {
return nil, bifrostErr
}
return resp, nil
}
if req.Provider == "" {
return nil, newBifrostErrorFromMsg("provider is required")
}
if preReq == nil {
return nil, newBifrostErrorFromMsg("bifrost request after plugin hooks cannot be nil")

if req.Model == "" {
return nil, newBifrostErrorFromMsg("model is required")
}

msg := bifrost.getChannelMessage(*preReq, TextCompletionRequest)
msg.Context = ctx
// Try the primary provider first
primaryResult, primaryErr := bifrost.tryChatCompletion(req, ctx)
if primaryErr == nil {
return primaryResult, nil
}

select {
case queue <- *msg:
// Message was sent successfully
case <-ctx.Done():
bifrost.releaseChannelMessage(msg)
return nil, newBifrostErrorFromMsg("request cancelled while waiting for queue space")
default:
if bifrost.dropExcessRequests {
bifrost.releaseChannelMessage(msg)
bifrost.logger.Warn("Request dropped: queue is full, please increase the queue size or set dropExcessRequests to false")
return nil, newBifrostErrorFromMsg("request dropped: queue is full")
}
if ctx == nil {
ctx = bifrost.backgroundCtx
}
select {
case queue <- *msg:
// Message was sent successfully
case <-ctx.Done():
bifrost.releaseChannelMessage(msg)
return nil, newBifrostErrorFromMsg("request cancelled while waiting for queue space")
}
// Check if this is a short-circuit error that doesn't allow fallbacks
// Note: AllowFallbacks = nil is treated as true (allow fallbacks by default)
if primaryErr.AllowFallbacks != nil && !*primaryErr.AllowFallbacks {
return nil, primaryErr
}

var result *schemas.BifrostResponse
var resp *schemas.BifrostResponse
select {
case result = <-msg.Response:
resp, bifrostErr := pipeline.RunPostHooks(&ctx, result, nil, len(bifrost.plugins))
if bifrostErr != nil {
bifrost.releaseChannelMessage(msg)
return nil, bifrostErr
}
bifrost.releaseChannelMessage(msg)
return resp, nil
case bifrostErrVal := <-msg.Err:
bifrostErrPtr := &bifrostErrVal
resp, bifrostErrPtr = pipeline.RunPostHooks(&ctx, nil, bifrostErrPtr, len(bifrost.plugins))
bifrost.releaseChannelMessage(msg)
if bifrostErrPtr != nil {
return nil, bifrostErrPtr
// If primary provider failed and we have fallbacks, try them in order
// This includes both regular provider errors and plugin short-circuit errors with AllowFallbacks=true/nil
if len(req.Fallbacks) > 0 {
for _, fallback := range req.Fallbacks {
// Check if we have config for this fallback provider
_, err := bifrost.account.GetConfigForProvider(fallback.Provider)
if err != nil {
bifrost.logger.Warn(fmt.Sprintf("Config not found for provider %s, skipping fallback: %v", fallback.Provider, err))
continue
}

// Create a new request with the fallback provider and model
fallbackReq := *req
fallbackReq.Provider = fallback.Provider
fallbackReq.Model = fallback.Model

// Try the fallback provider
result, fallbackErr := bifrost.tryChatCompletion(&fallbackReq, ctx)
if fallbackErr == nil {
bifrost.logger.Info(fmt.Sprintf("Successfully used fallback provider %s with model %s", fallback.Provider, fallback.Model))
return result, nil
}
if fallbackErr.Error.Type != nil && *fallbackErr.Error.Type == schemas.RequestCancelled {
return nil, fallbackErr
}

bifrost.logger.Warn(fmt.Sprintf("Fallback provider %s failed: %s", fallback.Provider, fallbackErr.Error.Message))
}
return resp, nil
}

// All providers failed, return the original error
return nil, primaryErr
}

// tryChatCompletion attempts a chat completion request with a single provider.
// It is the per-provider helper used by ChatCompletionRequest and delegates to
// tryRequest with the ChatCompletionRequest type.
func (bifrost *Bifrost) tryChatCompletion(req *schemas.BifrostRequest, ctx context.Context) (*schemas.BifrostResponse, *schemas.BifrostError) {
	// Chat completion requests opt in to MCP tool handling in tryRequest.
	const includeMCP = true
	return bifrost.tryRequest(req, ctx, ChatCompletionRequest, includeMCP)
}

// EmbeddingRequest sends an embedding request to the specified provider.
// It handles plugin hooks, request validation, response processing, and fallback providers.
// If the primary provider fails, it will try each fallback provider in order until one succeeds.
func (bifrost *Bifrost) ChatCompletionRequest(ctx context.Context, req *schemas.BifrostRequest) (*schemas.BifrostResponse, *schemas.BifrostError) {
func (bifrost *Bifrost) EmbeddingRequest(ctx context.Context, req *schemas.BifrostRequest) (*schemas.BifrostResponse, *schemas.BifrostError) {
if req == nil {
return nil, newBifrostErrorFromMsg("bifrost request cannot be nil")
}
Expand All @@ -713,8 +741,12 @@ func (bifrost *Bifrost) ChatCompletionRequest(ctx context.Context, req *schemas.
return nil, newBifrostErrorFromMsg("model is required")
}

if req.Input.EmbeddingInput == nil {
return nil, newBifrostErrorFromMsg("embedding_input is required")
}

// Try the primary provider first
primaryResult, primaryErr := bifrost.tryChatCompletion(req, ctx)
primaryResult, primaryErr := bifrost.tryEmbedding(req, ctx)
if primaryErr == nil {
return primaryResult, nil
}
Expand All @@ -726,7 +758,6 @@ func (bifrost *Bifrost) ChatCompletionRequest(ctx context.Context, req *schemas.
}

// If primary provider failed and we have fallbacks, try them in order
// This includes both regular provider errors and plugin short-circuit errors with AllowFallbacks=true/nil
if len(req.Fallbacks) > 0 {
for _, fallback := range req.Fallbacks {
// Check if we have config for this fallback provider
Expand All @@ -742,7 +773,7 @@ func (bifrost *Bifrost) ChatCompletionRequest(ctx context.Context, req *schemas.
fallbackReq.Model = fallback.Model

// Try the fallback provider
result, fallbackErr := bifrost.tryChatCompletion(&fallbackReq, ctx)
result, fallbackErr := bifrost.tryEmbedding(&fallbackReq, ctx)
if fallbackErr == nil {
bifrost.logger.Info(fmt.Sprintf("Successfully used fallback provider %s with model %s", fallback.Provider, fallback.Model))
return result, nil
Expand All @@ -759,16 +790,22 @@ func (bifrost *Bifrost) ChatCompletionRequest(ctx context.Context, req *schemas.
return nil, primaryErr
}

// tryChatCompletion attempts a chat completion request with a single provider.
// This is a helper function used by ChatCompletionRequest to handle individual provider attempts.
func (bifrost *Bifrost) tryChatCompletion(req *schemas.BifrostRequest, ctx context.Context) (*schemas.BifrostResponse, *schemas.BifrostError) {
// tryEmbedding attempts an embedding request with a single provider.
// It is the per-provider helper used by EmbeddingRequest and delegates to
// tryRequest with the EmbeddingRequest type.
func (bifrost *Bifrost) tryEmbedding(req *schemas.BifrostRequest, ctx context.Context) (*schemas.BifrostResponse, *schemas.BifrostError) {
	// Embedding requests opt out of MCP tool handling in tryRequest.
	const includeMCP = false
	return bifrost.tryRequest(req, ctx, EmbeddingRequest, includeMCP)
}

// tryRequest is a generic function that handles common request processing logic
// It consolidates queue setup, plugin pipeline execution, enqueue logic, and response handling
func (bifrost *Bifrost) tryRequest(req *schemas.BifrostRequest, ctx context.Context, requestType RequestType, includeMCP bool) (*schemas.BifrostResponse, *schemas.BifrostError) {
queue, err := bifrost.getProviderQueue(req.Provider)
if err != nil {
return nil, newBifrostError(err)
}

// Add MCP tools to request if MCP is configured
if bifrost.mcpManager != nil {
// Add MCP tools to request if MCP is configured and requested
if includeMCP && bifrost.mcpManager != nil {
req = bifrost.mcpManager.addMCPToolsToBifrostRequest(ctx, req)
}

Expand Down Expand Up @@ -796,7 +833,7 @@ func (bifrost *Bifrost) tryChatCompletion(req *schemas.BifrostRequest, ctx conte
return nil, newBifrostErrorFromMsg("bifrost request after plugin hooks cannot be nil")
}

msg := bifrost.getChannelMessage(*preReq, ChatCompletionRequest)
msg := bifrost.getChannelMessage(*preReq, requestType)
msg.Context = ctx

select {
Expand Down
5 changes: 5 additions & 0 deletions core/providers/anthropic.go
Original file line number Diff line number Diff line change
Expand Up @@ -694,3 +694,8 @@ func parseAnthropicResponse(response *AnthropicChatResponse, bifrostResponse *sc

return bifrostResponse, nil
}

// Embedding is not supported by the Anthropic provider.
// All arguments are ignored; the method always returns a nil response and the
// error produced by newUnsupportedOperationError("embedding", "anthropic").
func (provider *AnthropicProvider) Embedding(ctx context.Context, model, key string, input *schemas.EmbeddingInput, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) {
	return nil, newUnsupportedOperationError("embedding", "anthropic")
}
Loading