Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 29 additions & 27 deletions core/bifrost.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,8 @@ import (
"context"
"fmt"
"math/rand"
"os"
"os/signal"
"slices"
"sync"
"syscall"
"time"

"github.com/maximhq/bifrost/core/providers"
Expand All @@ -30,6 +27,7 @@ const (
// It contains the request, response and error channels, and the request type.
type ChannelMessage struct {
schemas.BifrostRequest
Context context.Context
Response chan *schemas.BifrostResponse
Err chan schemas.BifrostError
Type RequestType
Expand Down Expand Up @@ -357,7 +355,7 @@ func (bifrost *Bifrost) requestWorker(provider schemas.Provider, queue chan Chan
}
break // Don't retry client errors
} else {
result, bifrostError = provider.TextCompletion(req.Model, key, *req.Input.TextCompletionInput, req.Params)
result, bifrostError = provider.TextCompletion(req.Context, req.Model, key, *req.Input.TextCompletionInput, req.Params)
}
} else if req.Type == ChatCompletionRequest {
if req.Input.ChatCompletionInput == nil {
Expand All @@ -369,14 +367,17 @@ func (bifrost *Bifrost) requestWorker(provider schemas.Provider, queue chan Chan
}
break // Don't retry client errors
} else {
result, bifrostError = provider.ChatCompletion(req.Model, key, *req.Input.ChatCompletionInput, req.Params)
result, bifrostError = provider.ChatCompletion(req.Context, req.Model, key, *req.Input.ChatCompletionInput, req.Params)
}
}

bifrost.logger.Debug(fmt.Sprintf("Request for provider %s completed", provider.GetProviderKey()))

// Check if successful or if we should retry
if bifrostError == nil || bifrostError.IsBifrostError || (bifrostError.StatusCode != nil && !retryableStatusCodes[*bifrostError.StatusCode]) {
if bifrostError == nil ||
bifrostError.IsBifrostError ||
(bifrostError.StatusCode != nil && !retryableStatusCodes[*bifrostError.StatusCode]) ||
(bifrostError.Error.Type != nil && *bifrostError.Error.Type == schemas.RequestCancelled) {
break
}
}
Expand Down Expand Up @@ -466,11 +467,15 @@ func (bifrost *Bifrost) TextCompletionRequest(ctx context.Context, req *schemas.
}

// Try the primary provider first
primaryResult, primaryErr := bifrost.tryTextCompletion(req.Provider, req, ctx)
primaryResult, primaryErr := bifrost.tryTextCompletion(req, ctx)
if primaryErr == nil {
return primaryResult, nil
}

if primaryErr.Error.Type != nil && *primaryErr.Error.Type == schemas.RequestCancelled {
return nil, primaryErr
}

// If primary provider failed and we have fallbacks, try them in order
if len(req.Fallbacks) > 0 {
for _, fallback := range req.Fallbacks {
Expand All @@ -486,11 +491,15 @@ func (bifrost *Bifrost) TextCompletionRequest(ctx context.Context, req *schemas.
fallbackReq.Model = fallback.Model

// Try the fallback provider
result, fallbackErr := bifrost.tryTextCompletion(fallback.Provider, &fallbackReq, ctx)
result, fallbackErr := bifrost.tryTextCompletion(&fallbackReq, ctx)
if fallbackErr == nil {
bifrost.logger.Info(fmt.Sprintf("Successfully used fallback provider %s with model %s", fallback.Provider, fallback.Model))
return result, nil
}
if fallbackErr.Error.Type != nil && *fallbackErr.Error.Type == schemas.RequestCancelled {
return nil, fallbackErr
}

bifrost.logger.Warn(fmt.Sprintf("Fallback provider %s failed: %s", fallback.Provider, fallbackErr.Error.Message))
}
}
Expand All @@ -501,8 +510,8 @@ func (bifrost *Bifrost) TextCompletionRequest(ctx context.Context, req *schemas.

// tryTextCompletion attempts a text completion request with a single provider.
// This is a helper function used by TextCompletionRequest to handle individual provider attempts.
func (bifrost *Bifrost) tryTextCompletion(providerKey schemas.ModelProvider, req *schemas.BifrostRequest, ctx context.Context) (*schemas.BifrostResponse, *schemas.BifrostError) {
queue, err := bifrost.GetProviderQueue(providerKey)
func (bifrost *Bifrost) tryTextCompletion(req *schemas.BifrostRequest, ctx context.Context) (*schemas.BifrostResponse, *schemas.BifrostError) {
queue, err := bifrost.GetProviderQueue(req.Provider)
if err != nil {
return nil, &schemas.BifrostError{
IsBifrostError: false,
Expand Down Expand Up @@ -535,6 +544,7 @@ func (bifrost *Bifrost) tryTextCompletion(providerKey schemas.ModelProvider, req

// Get a ChannelMessage from the pool
msg := bifrost.getChannelMessage(*req, TextCompletionRequest)
msg.Context = ctx

// Handle queue send with context and proper cleanup
select {
Expand All @@ -561,6 +571,7 @@ func (bifrost *Bifrost) tryTextCompletion(providerKey schemas.ModelProvider, req
},
}
}

// If not dropping excess requests, wait with context
if ctx == nil {
ctx = bifrost.backgroundCtx
Expand Down Expand Up @@ -638,7 +649,7 @@ func (bifrost *Bifrost) ChatCompletionRequest(ctx context.Context, req *schemas.
}

// Try the primary provider first
primaryResult, primaryErr := bifrost.tryChatCompletion(req.Provider, req, ctx)
primaryResult, primaryErr := bifrost.tryChatCompletion(req, ctx)
if primaryErr == nil {
return primaryResult, nil
}
Expand All @@ -658,7 +669,7 @@ func (bifrost *Bifrost) ChatCompletionRequest(ctx context.Context, req *schemas.
fallbackReq.Model = fallback.Model

// Try the fallback provider
result, fallbackErr := bifrost.tryChatCompletion(fallback.Provider, &fallbackReq, ctx)
result, fallbackErr := bifrost.tryChatCompletion(&fallbackReq, ctx)
if fallbackErr == nil {
bifrost.logger.Info(fmt.Sprintf("Successfully used fallback provider %s with model %s", fallback.Provider, fallback.Model))
return result, nil
Expand All @@ -673,8 +684,8 @@ func (bifrost *Bifrost) ChatCompletionRequest(ctx context.Context, req *schemas.

// tryChatCompletion attempts a chat completion request with a single provider.
// This is a helper function used by ChatCompletionRequest to handle individual provider attempts.
func (bifrost *Bifrost) tryChatCompletion(providerKey schemas.ModelProvider, req *schemas.BifrostRequest, ctx context.Context) (*schemas.BifrostResponse, *schemas.BifrostError) {
queue, err := bifrost.GetProviderQueue(providerKey)
func (bifrost *Bifrost) tryChatCompletion(req *schemas.BifrostRequest, ctx context.Context) (*schemas.BifrostResponse, *schemas.BifrostError) {
queue, err := bifrost.GetProviderQueue(req.Provider)
if err != nil {
return nil, &schemas.BifrostError{
IsBifrostError: false,
Expand Down Expand Up @@ -707,6 +718,7 @@ func (bifrost *Bifrost) tryChatCompletion(providerKey schemas.ModelProvider, req

// Get a ChannelMessage from the pool
msg := bifrost.getChannelMessage(*req, ChatCompletionRequest)
msg.Context = ctx

// Handle queue send with context and proper cleanup
select {
Expand Down Expand Up @@ -778,10 +790,10 @@ func (bifrost *Bifrost) tryChatCompletion(providerKey schemas.ModelProvider, req
return result, nil
}

// Shutdown gracefully stops all workers when triggered.
// Cleanup gracefully stops all workers when triggered.
// It closes all request channels and waits for workers to exit.
func (bifrost *Bifrost) Shutdown() {
bifrost.logger.Info("[BIFROST] Graceful Shutdown Initiated - Closing all request channels...")
func (bifrost *Bifrost) Cleanup() {
bifrost.logger.Info("[BIFROST] Graceful Cleanup Initiated - Closing all request channels...")

// Close all provider queues to signal workers to stop
for _, queue := range bifrost.requestQueues {
Expand All @@ -793,13 +805,3 @@ func (bifrost *Bifrost) Shutdown() {
waitGroup.Wait()
}
}

// Cleanup blocks the calling goroutine until the process receives an
// interrupt (Ctrl+C) or SIGTERM, then hands off to Shutdown.
// It registers its own signal channel and returns once Shutdown has returned.
func (bifrost *Bifrost) Cleanup() {
	// Buffer of one so a signal delivered before the receive below is not dropped.
	interrupts := make(chan os.Signal, 1)
	signal.Notify(interrupts, os.Interrupt, syscall.SIGTERM)

	// Park here until one of the registered signals arrives.
	<-interrupts

	// Gracefully shut down.
	bifrost.Shutdown()
}
22 changes: 9 additions & 13 deletions core/providers/anthropic.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
package providers

import (
"context"
"fmt"
"strings"
"sync"
Expand Down Expand Up @@ -168,7 +169,7 @@ func (provider *AnthropicProvider) prepareTextCompletionParams(params map[string
// completeRequest sends a request to Anthropic's API and handles the response.
// It constructs the API URL, sets up authentication, and processes the response.
// Returns the response body or an error if the request fails.
func (provider *AnthropicProvider) completeRequest(requestBody map[string]interface{}, url string, key string) ([]byte, *schemas.BifrostError) {
func (provider *AnthropicProvider) completeRequest(ctx context.Context, requestBody map[string]interface{}, url string, key string) ([]byte, *schemas.BifrostError) {
// Marshal the request body
jsonData, err := json.Marshal(requestBody)
if err != nil {
Expand All @@ -195,14 +196,9 @@ func (provider *AnthropicProvider) completeRequest(requestBody map[string]interf
req.SetBody(jsonData)

// Send the request
if err := provider.client.Do(req, resp); err != nil {
return nil, &schemas.BifrostError{
IsBifrostError: false,
Error: schemas.ErrorField{
Message: schemas.ErrProviderRequest,
Error: err,
},
}
bifrostErr := makeRequestWithContext(ctx, provider.client, req, resp)
if bifrostErr != nil {
return nil, bifrostErr
}

// Handle error response
Expand All @@ -227,7 +223,7 @@ func (provider *AnthropicProvider) completeRequest(requestBody map[string]interf
// TextCompletion performs a text completion request to Anthropic's API.
// It formats the request, sends it to Anthropic, and processes the response.
// Returns a BifrostResponse containing the completion results or an error if the request fails.
func (provider *AnthropicProvider) TextCompletion(model, key, text string, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) {
func (provider *AnthropicProvider) TextCompletion(ctx context.Context, model, key, text string, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) {
preparedParams := provider.prepareTextCompletionParams(prepareParams(params))

// Merge additional parameters
Expand All @@ -236,7 +232,7 @@ func (provider *AnthropicProvider) TextCompletion(model, key, text string, param
"prompt": fmt.Sprintf("\n\nHuman: %s\n\nAssistant:", text),
}, preparedParams)

responseBody, err := provider.completeRequest(requestBody, "https://api.anthropic.com/v1/complete", key)
responseBody, err := provider.completeRequest(ctx, requestBody, "https://api.anthropic.com/v1/complete", key)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -281,7 +277,7 @@ func (provider *AnthropicProvider) TextCompletion(model, key, text string, param
// ChatCompletion performs a chat completion request to Anthropic's API.
// It formats the request, sends it to Anthropic, and processes the response.
// Returns a BifrostResponse containing the completion results or an error if the request fails.
func (provider *AnthropicProvider) ChatCompletion(model, key string, messages []schemas.Message, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) {
func (provider *AnthropicProvider) ChatCompletion(ctx context.Context, model, key string, messages []schemas.Message, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) {
formattedMessages, preparedParams := prepareAnthropicChatRequest(model, messages, params)

// Merge additional parameters
Expand All @@ -290,7 +286,7 @@ func (provider *AnthropicProvider) ChatCompletion(model, key string, messages []
"messages": formattedMessages,
}, preparedParams)

responseBody, err := provider.completeRequest(requestBody, "https://api.anthropic.com/v1/messages", key)
responseBody, err := provider.completeRequest(ctx, requestBody, "https://api.anthropic.com/v1/messages", key)
if err != nil {
return nil, err
}
Expand Down
22 changes: 9 additions & 13 deletions core/providers/azure.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
package providers

import (
"context"
"fmt"
"sync"
"time"
Expand Down Expand Up @@ -141,7 +142,7 @@ func (provider *AzureProvider) GetProviderKey() schemas.ModelProvider {
// completeRequest sends a request to Azure's API and handles the response.
// It constructs the API URL, sets up authentication, and processes the response.
// Returns the response body or an error if the request fails.
func (provider *AzureProvider) completeRequest(requestBody map[string]interface{}, path string, key string, model string) ([]byte, *schemas.BifrostError) {
func (provider *AzureProvider) completeRequest(ctx context.Context, requestBody map[string]interface{}, path string, key string, model string) ([]byte, *schemas.BifrostError) {
// Marshal the request body
jsonData, err := json.Marshal(requestBody)
if err != nil {
Expand Down Expand Up @@ -204,14 +205,9 @@ func (provider *AzureProvider) completeRequest(requestBody map[string]interface{
req.SetBody(jsonData)

// Send the request
if err := provider.client.Do(req, resp); err != nil {
return nil, &schemas.BifrostError{
IsBifrostError: false,
Error: schemas.ErrorField{
Message: schemas.ErrProviderRequest,
Error: err,
},
}
bifrostErr := makeRequestWithContext(ctx, provider.client, req, resp)
if bifrostErr != nil {
return nil, bifrostErr
}

// Handle error response
Expand All @@ -236,7 +232,7 @@ func (provider *AzureProvider) completeRequest(requestBody map[string]interface{
// TextCompletion performs a text completion request to Azure's API.
// It formats the request, sends it to Azure, and processes the response.
// Returns a BifrostResponse containing the completion results or an error if the request fails.
func (provider *AzureProvider) TextCompletion(model, key, text string, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) {
func (provider *AzureProvider) TextCompletion(ctx context.Context, model, key, text string, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) {
preparedParams := prepareParams(params)

// Merge additional parameters
Expand All @@ -245,7 +241,7 @@ func (provider *AzureProvider) TextCompletion(model, key, text string, params *s
"prompt": text,
}, preparedParams)

responseBody, err := provider.completeRequest(requestBody, "completions", key, model)
responseBody, err := provider.completeRequest(ctx, requestBody, "completions", key, model)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -297,7 +293,7 @@ func (provider *AzureProvider) TextCompletion(model, key, text string, params *s
// ChatCompletion performs a chat completion request to Azure's API.
// It formats the request, sends it to Azure, and processes the response.
// Returns a BifrostResponse containing the completion results or an error if the request fails.
func (provider *AzureProvider) ChatCompletion(model, key string, messages []schemas.Message, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) {
func (provider *AzureProvider) ChatCompletion(ctx context.Context, model, key string, messages []schemas.Message, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) {
preparedParams := prepareParams(params)

// Format messages for Azure API
Expand All @@ -315,7 +311,7 @@ func (provider *AzureProvider) ChatCompletion(model, key string, messages []sche
"messages": formattedMessages,
}, preparedParams)

responseBody, err := provider.completeRequest(requestBody, "chat/completions", key, model)
responseBody, err := provider.completeRequest(ctx, requestBody, "chat/completions", key, model)
if err != nil {
return nil, err
}
Expand Down
23 changes: 17 additions & 6 deletions core/providers/bedrock.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"context"
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"io"
"net/http"
Expand Down Expand Up @@ -189,7 +190,7 @@ func (provider *BedrockProvider) GetProviderKey() schemas.ModelProvider {
// completeRequest sends a request to Bedrock's API and handles the response.
// It constructs the API URL, sets up AWS authentication, and processes the response.
// Returns the response body or an error if the request fails.
func (provider *BedrockProvider) completeRequest(requestBody map[string]interface{}, path string, accessKey string) ([]byte, *schemas.BifrostError) {
func (provider *BedrockProvider) completeRequest(ctx context.Context, requestBody map[string]interface{}, path string, accessKey string) ([]byte, *schemas.BifrostError) {
if provider.meta == nil {
return nil, &schemas.BifrostError{
IsBifrostError: false,
Expand All @@ -206,6 +207,16 @@ func (provider *BedrockProvider) completeRequest(requestBody map[string]interfac

jsonBody, err := json.Marshal(requestBody)
if err != nil {
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return nil, &schemas.BifrostError{
IsBifrostError: false,
Error: schemas.ErrorField{
Type: StrPtr(schemas.RequestCancelled),
Message: fmt.Sprintf("Request cancelled or timed out by context: %v", ctx.Err()),
Error: err,
},
}
}
return nil, &schemas.BifrostError{
IsBifrostError: true,
Error: schemas.ErrorField{
Expand All @@ -216,7 +227,7 @@ func (provider *BedrockProvider) completeRequest(requestBody map[string]interfac
}

// Create the request with the JSON body
req, err := http.NewRequest("POST", fmt.Sprintf("https://bedrock-runtime.%s.amazonaws.com/model/%s", region, path), bytes.NewBuffer(jsonBody))
req, err := http.NewRequestWithContext(ctx, "POST", fmt.Sprintf("https://bedrock-runtime.%s.amazonaws.com/model/%s", region, path), bytes.NewBuffer(jsonBody))
if err != nil {
return nil, &schemas.BifrostError{
IsBifrostError: true,
Expand Down Expand Up @@ -558,14 +569,14 @@ func (provider *BedrockProvider) prepareTextCompletionParams(params map[string]i
// TextCompletion performs a text completion request to Bedrock's API.
// It formats the request, sends it to Bedrock, and processes the response.
// Returns a BifrostResponse containing the completion results or an error if the request fails.
func (provider *BedrockProvider) TextCompletion(model, key, text string, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) {
func (provider *BedrockProvider) TextCompletion(ctx context.Context, model, key, text string, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) {
preparedParams := provider.prepareTextCompletionParams(prepareParams(params), model)

requestBody := mergeConfig(map[string]interface{}{
"prompt": text,
}, preparedParams)

body, err := provider.completeRequest(requestBody, fmt.Sprintf("%s/invoke", model), key)
body, err := provider.completeRequest(ctx, requestBody, fmt.Sprintf("%s/invoke", model), key)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -595,7 +606,7 @@ func (provider *BedrockProvider) TextCompletion(model, key, text string, params
// ChatCompletion performs a chat completion request to Bedrock's API.
// It formats the request, sends it to Bedrock, and processes the response.
// Returns a BifrostResponse containing the completion results or an error if the request fails.
func (provider *BedrockProvider) ChatCompletion(model, key string, messages []schemas.Message, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) {
func (provider *BedrockProvider) ChatCompletion(ctx context.Context, model, key string, messages []schemas.Message, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) {
messageBody, err := provider.prepareChatCompletionMessages(messages, model)
if err != nil {
return nil, err
Expand Down Expand Up @@ -623,7 +634,7 @@ func (provider *BedrockProvider) ChatCompletion(model, key string, messages []sc
}

// Create the signed request
responseBody, err := provider.completeRequest(requestBody, path, key)
responseBody, err := provider.completeRequest(ctx, requestBody, path, key)
if err != nil {
return nil, err
}
Expand Down
Loading