diff --git a/core/bifrost.go b/core/bifrost.go index 9a2a49db0b..8f8f58b7e8 100644 --- a/core/bifrost.go +++ b/core/bifrost.go @@ -7,11 +7,8 @@ import ( "context" "fmt" "math/rand" - "os" - "os/signal" "slices" "sync" - "syscall" "time" "github.com/maximhq/bifrost/core/providers" @@ -30,6 +27,7 @@ const ( // It contains the request, response and error channels, and the request type. type ChannelMessage struct { schemas.BifrostRequest + Context context.Context Response chan *schemas.BifrostResponse Err chan schemas.BifrostError Type RequestType @@ -357,7 +355,7 @@ func (bifrost *Bifrost) requestWorker(provider schemas.Provider, queue chan Chan } break // Don't retry client errors } else { - result, bifrostError = provider.TextCompletion(req.Model, key, *req.Input.TextCompletionInput, req.Params) + result, bifrostError = provider.TextCompletion(req.Context, req.Model, key, *req.Input.TextCompletionInput, req.Params) } } else if req.Type == ChatCompletionRequest { if req.Input.ChatCompletionInput == nil { @@ -369,14 +367,17 @@ func (bifrost *Bifrost) requestWorker(provider schemas.Provider, queue chan Chan } break // Don't retry client errors } else { - result, bifrostError = provider.ChatCompletion(req.Model, key, *req.Input.ChatCompletionInput, req.Params) + result, bifrostError = provider.ChatCompletion(req.Context, req.Model, key, *req.Input.ChatCompletionInput, req.Params) } } bifrost.logger.Debug(fmt.Sprintf("Request for provider %s completed", provider.GetProviderKey())) // Check if successful or if we should retry - if bifrostError == nil || bifrostError.IsBifrostError || (bifrostError.StatusCode != nil && !retryableStatusCodes[*bifrostError.StatusCode]) { + if bifrostError == nil || + bifrostError.IsBifrostError || + (bifrostError.StatusCode != nil && !retryableStatusCodes[*bifrostError.StatusCode]) || + (bifrostError.Error.Type != nil && *bifrostError.Error.Type == schemas.RequestCancelled) { break } } @@ -466,11 +467,15 @@ func (bifrost *Bifrost) 
TextCompletionRequest(ctx context.Context, req *schemas. } // Try the primary provider first - primaryResult, primaryErr := bifrost.tryTextCompletion(req.Provider, req, ctx) + primaryResult, primaryErr := bifrost.tryTextCompletion(req, ctx) if primaryErr == nil { return primaryResult, nil } + if primaryErr.Error.Type != nil && *primaryErr.Error.Type == schemas.RequestCancelled { + return nil, primaryErr + } + // If primary provider failed and we have fallbacks, try them in order if len(req.Fallbacks) > 0 { for _, fallback := range req.Fallbacks { @@ -486,11 +491,15 @@ func (bifrost *Bifrost) TextCompletionRequest(ctx context.Context, req *schemas. fallbackReq.Model = fallback.Model // Try the fallback provider - result, fallbackErr := bifrost.tryTextCompletion(fallback.Provider, &fallbackReq, ctx) + result, fallbackErr := bifrost.tryTextCompletion(&fallbackReq, ctx) if fallbackErr == nil { bifrost.logger.Info(fmt.Sprintf("Successfully used fallback provider %s with model %s", fallback.Provider, fallback.Model)) return result, nil } + if fallbackErr.Error.Type != nil && *fallbackErr.Error.Type == schemas.RequestCancelled { + return nil, fallbackErr + } + bifrost.logger.Warn(fmt.Sprintf("Fallback provider %s failed: %s", fallback.Provider, fallbackErr.Error.Message)) } } @@ -501,8 +510,8 @@ func (bifrost *Bifrost) TextCompletionRequest(ctx context.Context, req *schemas. // tryTextCompletion attempts a text completion request with a single provider. // This is a helper function used by TextCompletionRequest to handle individual provider attempts. 
-func (bifrost *Bifrost) tryTextCompletion(providerKey schemas.ModelProvider, req *schemas.BifrostRequest, ctx context.Context) (*schemas.BifrostResponse, *schemas.BifrostError) { - queue, err := bifrost.GetProviderQueue(providerKey) +func (bifrost *Bifrost) tryTextCompletion(req *schemas.BifrostRequest, ctx context.Context) (*schemas.BifrostResponse, *schemas.BifrostError) { + queue, err := bifrost.GetProviderQueue(req.Provider) if err != nil { return nil, &schemas.BifrostError{ IsBifrostError: false, @@ -535,6 +544,7 @@ func (bifrost *Bifrost) tryTextCompletion(providerKey schemas.ModelProvider, req // Get a ChannelMessage from the pool msg := bifrost.getChannelMessage(*req, TextCompletionRequest) + msg.Context = ctx // Handle queue send with context and proper cleanup select { @@ -561,6 +571,7 @@ func (bifrost *Bifrost) tryTextCompletion(providerKey schemas.ModelProvider, req }, } } + // If not dropping excess requests, wait with context if ctx == nil { ctx = bifrost.backgroundCtx @@ -638,7 +649,7 @@ func (bifrost *Bifrost) ChatCompletionRequest(ctx context.Context, req *schemas. } // Try the primary provider first - primaryResult, primaryErr := bifrost.tryChatCompletion(req.Provider, req, ctx) + primaryResult, primaryErr := bifrost.tryChatCompletion(req, ctx) if primaryErr == nil { return primaryResult, nil } @@ -658,7 +669,7 @@ func (bifrost *Bifrost) ChatCompletionRequest(ctx context.Context, req *schemas. fallbackReq.Model = fallback.Model // Try the fallback provider - result, fallbackErr := bifrost.tryChatCompletion(fallback.Provider, &fallbackReq, ctx) + result, fallbackErr := bifrost.tryChatCompletion(&fallbackReq, ctx) if fallbackErr == nil { bifrost.logger.Info(fmt.Sprintf("Successfully used fallback provider %s with model %s", fallback.Provider, fallback.Model)) return result, nil @@ -673,8 +684,8 @@ func (bifrost *Bifrost) ChatCompletionRequest(ctx context.Context, req *schemas. 
// tryChatCompletion attempts a chat completion request with a single provider. // This is a helper function used by ChatCompletionRequest to handle individual provider attempts. -func (bifrost *Bifrost) tryChatCompletion(providerKey schemas.ModelProvider, req *schemas.BifrostRequest, ctx context.Context) (*schemas.BifrostResponse, *schemas.BifrostError) { - queue, err := bifrost.GetProviderQueue(providerKey) +func (bifrost *Bifrost) tryChatCompletion(req *schemas.BifrostRequest, ctx context.Context) (*schemas.BifrostResponse, *schemas.BifrostError) { + queue, err := bifrost.GetProviderQueue(req.Provider) if err != nil { return nil, &schemas.BifrostError{ IsBifrostError: false, @@ -707,6 +718,7 @@ func (bifrost *Bifrost) tryChatCompletion(providerKey schemas.ModelProvider, req // Get a ChannelMessage from the pool msg := bifrost.getChannelMessage(*req, ChatCompletionRequest) + msg.Context = ctx // Handle queue send with context and proper cleanup select { @@ -778,10 +790,10 @@ func (bifrost *Bifrost) tryChatCompletion(providerKey schemas.ModelProvider, req return result, nil } -// Shutdown gracefully stops all workers when triggered. +// Cleanup gracefully stops all workers when triggered. // It closes all request channels and waits for workers to exit. -func (bifrost *Bifrost) Shutdown() { - bifrost.logger.Info("[BIFROST] Graceful Shutdown Initiated - Closing all request channels...") +func (bifrost *Bifrost) Cleanup() { + bifrost.logger.Info("[BIFROST] Graceful Cleanup Initiated - Closing all request channels...") // Close all provider queues to signal workers to stop for _, queue := range bifrost.requestQueues { @@ -793,13 +805,3 @@ func (bifrost *Bifrost) Shutdown() { waitGroup.Wait() } } - -// Cleanup handles SIGINT (Ctrl+C) to exit cleanly. -// It sets up signal handling and calls Shutdown when interrupted. 
-func (bifrost *Bifrost) Cleanup() { - signalChan := make(chan os.Signal, 1) - signal.Notify(signalChan, os.Interrupt, syscall.SIGTERM) - - <-signalChan // Wait for interrupt signal - bifrost.Shutdown() // Gracefully shut down -} diff --git a/core/providers/anthropic.go b/core/providers/anthropic.go index 987bc7be3b..a9078d1bb3 100644 --- a/core/providers/anthropic.go +++ b/core/providers/anthropic.go @@ -3,6 +3,7 @@ package providers import ( + "context" "fmt" "strings" "sync" @@ -168,7 +169,7 @@ func (provider *AnthropicProvider) prepareTextCompletionParams(params map[string // completeRequest sends a request to Anthropic's API and handles the response. // It constructs the API URL, sets up authentication, and processes the response. // Returns the response body or an error if the request fails. -func (provider *AnthropicProvider) completeRequest(requestBody map[string]interface{}, url string, key string) ([]byte, *schemas.BifrostError) { +func (provider *AnthropicProvider) completeRequest(ctx context.Context, requestBody map[string]interface{}, url string, key string) ([]byte, *schemas.BifrostError) { // Marshal the request body jsonData, err := json.Marshal(requestBody) if err != nil { @@ -195,14 +196,9 @@ func (provider *AnthropicProvider) completeRequest(requestBody map[string]interf req.SetBody(jsonData) // Send the request - if err := provider.client.Do(req, resp); err != nil { - return nil, &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: schemas.ErrProviderRequest, - Error: err, - }, - } + bifrostErr := makeRequestWithContext(ctx, provider.client, req, resp) + if bifrostErr != nil { + return nil, bifrostErr } // Handle error response @@ -227,7 +223,7 @@ func (provider *AnthropicProvider) completeRequest(requestBody map[string]interf // TextCompletion performs a text completion request to Anthropic's API. // It formats the request, sends it to Anthropic, and processes the response. 
// Returns a BifrostResponse containing the completion results or an error if the request fails. -func (provider *AnthropicProvider) TextCompletion(model, key, text string, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { +func (provider *AnthropicProvider) TextCompletion(ctx context.Context, model, key, text string, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { preparedParams := provider.prepareTextCompletionParams(prepareParams(params)) // Merge additional parameters @@ -236,7 +232,7 @@ func (provider *AnthropicProvider) TextCompletion(model, key, text string, param "prompt": fmt.Sprintf("\n\nHuman: %s\n\nAssistant:", text), }, preparedParams) - responseBody, err := provider.completeRequest(requestBody, "https://api.anthropic.com/v1/complete", key) + responseBody, err := provider.completeRequest(ctx, requestBody, "https://api.anthropic.com/v1/complete", key) if err != nil { return nil, err } @@ -281,7 +277,7 @@ func (provider *AnthropicProvider) TextCompletion(model, key, text string, param // ChatCompletion performs a chat completion request to Anthropic's API. // It formats the request, sends it to Anthropic, and processes the response. // Returns a BifrostResponse containing the completion results or an error if the request fails. 
-func (provider *AnthropicProvider) ChatCompletion(model, key string, messages []schemas.Message, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { +func (provider *AnthropicProvider) ChatCompletion(ctx context.Context, model, key string, messages []schemas.Message, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { formattedMessages, preparedParams := prepareAnthropicChatRequest(model, messages, params) // Merge additional parameters @@ -290,7 +286,7 @@ func (provider *AnthropicProvider) ChatCompletion(model, key string, messages [] "messages": formattedMessages, }, preparedParams) - responseBody, err := provider.completeRequest(requestBody, "https://api.anthropic.com/v1/messages", key) + responseBody, err := provider.completeRequest(ctx, requestBody, "https://api.anthropic.com/v1/messages", key) if err != nil { return nil, err } diff --git a/core/providers/azure.go b/core/providers/azure.go index 0788d0196f..983c2d3b21 100644 --- a/core/providers/azure.go +++ b/core/providers/azure.go @@ -3,6 +3,7 @@ package providers import ( + "context" "fmt" "sync" "time" @@ -141,7 +142,7 @@ func (provider *AzureProvider) GetProviderKey() schemas.ModelProvider { // completeRequest sends a request to Azure's API and handles the response. // It constructs the API URL, sets up authentication, and processes the response. // Returns the response body or an error if the request fails. 
-func (provider *AzureProvider) completeRequest(requestBody map[string]interface{}, path string, key string, model string) ([]byte, *schemas.BifrostError) { +func (provider *AzureProvider) completeRequest(ctx context.Context, requestBody map[string]interface{}, path string, key string, model string) ([]byte, *schemas.BifrostError) { // Marshal the request body jsonData, err := json.Marshal(requestBody) if err != nil { @@ -204,14 +205,9 @@ func (provider *AzureProvider) completeRequest(requestBody map[string]interface{ req.SetBody(jsonData) // Send the request - if err := provider.client.Do(req, resp); err != nil { - return nil, &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: schemas.ErrProviderRequest, - Error: err, - }, - } + bifrostErr := makeRequestWithContext(ctx, provider.client, req, resp) + if bifrostErr != nil { + return nil, bifrostErr } // Handle error response @@ -236,7 +232,7 @@ func (provider *AzureProvider) completeRequest(requestBody map[string]interface{ // TextCompletion performs a text completion request to Azure's API. // It formats the request, sends it to Azure, and processes the response. // Returns a BifrostResponse containing the completion results or an error if the request fails. 
-func (provider *AzureProvider) TextCompletion(model, key, text string, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { +func (provider *AzureProvider) TextCompletion(ctx context.Context, model, key, text string, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { preparedParams := prepareParams(params) // Merge additional parameters @@ -245,7 +241,7 @@ func (provider *AzureProvider) TextCompletion(model, key, text string, params *s "prompt": text, }, preparedParams) - responseBody, err := provider.completeRequest(requestBody, "completions", key, model) + responseBody, err := provider.completeRequest(ctx, requestBody, "completions", key, model) if err != nil { return nil, err } @@ -297,7 +293,7 @@ func (provider *AzureProvider) TextCompletion(model, key, text string, params *s // ChatCompletion performs a chat completion request to Azure's API. // It formats the request, sends it to Azure, and processes the response. // Returns a BifrostResponse containing the completion results or an error if the request fails. 
-func (provider *AzureProvider) ChatCompletion(model, key string, messages []schemas.Message, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { +func (provider *AzureProvider) ChatCompletion(ctx context.Context, model, key string, messages []schemas.Message, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { preparedParams := prepareParams(params) // Format messages for Azure API @@ -315,7 +311,7 @@ func (provider *AzureProvider) ChatCompletion(model, key string, messages []sche "messages": formattedMessages, }, preparedParams) - responseBody, err := provider.completeRequest(requestBody, "chat/completions", key, model) + responseBody, err := provider.completeRequest(ctx, requestBody, "chat/completions", key, model) if err != nil { return nil, err } diff --git a/core/providers/bedrock.go b/core/providers/bedrock.go index 3eb6bc1206..c1c2d102a2 100644 --- a/core/providers/bedrock.go +++ b/core/providers/bedrock.go @@ -7,6 +7,7 @@ import ( "context" "crypto/sha256" "encoding/hex" + "errors" "fmt" "io" "net/http" @@ -189,7 +190,7 @@ func (provider *BedrockProvider) GetProviderKey() schemas.ModelProvider { // CompleteRequest sends a request to Bedrock's API and handles the response. // It constructs the API URL, sets up AWS authentication, and processes the response. // Returns the response body or an error if the request fails. 
-func (provider *BedrockProvider) completeRequest(requestBody map[string]interface{}, path string, accessKey string) ([]byte, *schemas.BifrostError) { +func (provider *BedrockProvider) completeRequest(ctx context.Context, requestBody map[string]interface{}, path string, accessKey string) ([]byte, *schemas.BifrostError) { if provider.meta == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, @@ -206,6 +207,16 @@ func (provider *BedrockProvider) completeRequest(requestBody map[string]interfac jsonBody, err := json.Marshal(requestBody) if err != nil { + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + return nil, &schemas.BifrostError{ + IsBifrostError: false, + Error: schemas.ErrorField{ + Type: StrPtr(schemas.RequestCancelled), + Message: fmt.Sprintf("Request cancelled or timed out by context: %v", ctx.Err()), + Error: err, + }, + } + } return nil, &schemas.BifrostError{ IsBifrostError: true, Error: schemas.ErrorField{ @@ -216,7 +227,7 @@ func (provider *BedrockProvider) completeRequest(requestBody map[string]interfac } // Create the request with the JSON body - req, err := http.NewRequest("POST", fmt.Sprintf("https://bedrock-runtime.%s.amazonaws.com/model/%s", region, path), bytes.NewBuffer(jsonBody)) + req, err := http.NewRequestWithContext(ctx, "POST", fmt.Sprintf("https://bedrock-runtime.%s.amazonaws.com/model/%s", region, path), bytes.NewBuffer(jsonBody)) if err != nil { return nil, &schemas.BifrostError{ IsBifrostError: true, @@ -558,14 +569,14 @@ func (provider *BedrockProvider) prepareTextCompletionParams(params map[string]i // TextCompletion performs a text completion request to Bedrock's API. // It formats the request, sends it to Bedrock, and processes the response. // Returns a BifrostResponse containing the completion results or an error if the request fails. 
-func (provider *BedrockProvider) TextCompletion(model, key, text string, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { +func (provider *BedrockProvider) TextCompletion(ctx context.Context, model, key, text string, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { preparedParams := provider.prepareTextCompletionParams(prepareParams(params), model) requestBody := mergeConfig(map[string]interface{}{ "prompt": text, }, preparedParams) - body, err := provider.completeRequest(requestBody, fmt.Sprintf("%s/invoke", model), key) + body, err := provider.completeRequest(ctx, requestBody, fmt.Sprintf("%s/invoke", model), key) if err != nil { return nil, err } @@ -595,7 +606,7 @@ func (provider *BedrockProvider) TextCompletion(model, key, text string, params // ChatCompletion performs a chat completion request to Bedrock's API. // It formats the request, sends it to Bedrock, and processes the response. // Returns a BifrostResponse containing the completion results or an error if the request fails. 
-func (provider *BedrockProvider) ChatCompletion(model, key string, messages []schemas.Message, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { +func (provider *BedrockProvider) ChatCompletion(ctx context.Context, model, key string, messages []schemas.Message, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { messageBody, err := provider.prepareChatCompletionMessages(messages, model) if err != nil { return nil, err @@ -623,7 +634,7 @@ func (provider *BedrockProvider) ChatCompletion(model, key string, messages []sc } // Create the signed request - responseBody, err := provider.completeRequest(requestBody, path, key) + responseBody, err := provider.completeRequest(ctx, requestBody, path, key) if err != nil { return nil, err } diff --git a/core/providers/cohere.go b/core/providers/cohere.go index c8e6f41b19..4d84d740df 100644 --- a/core/providers/cohere.go +++ b/core/providers/cohere.go @@ -3,6 +3,7 @@ package providers import ( + "context" "fmt" "slices" "sync" @@ -128,7 +129,7 @@ func (provider *CohereProvider) GetProviderKey() schemas.ModelProvider { // TextCompletion is not supported by the Cohere provider. // Returns an error indicating that text completion is not supported. -func (provider *CohereProvider) TextCompletion(model, key, text string, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { +func (provider *CohereProvider) TextCompletion(ctx context.Context, model, key, text string, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: schemas.ErrorField{ @@ -140,7 +141,7 @@ func (provider *CohereProvider) TextCompletion(model, key, text string, params * // ChatCompletion performs a chat completion request to the Cohere API. // It formats the request, sends it to Cohere, and processes the response. 
// Returns a BifrostResponse containing the completion results or an error if the request fails. -func (provider *CohereProvider) ChatCompletion(model, key string, messages []schemas.Message, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { +func (provider *CohereProvider) ChatCompletion(ctx context.Context, model, key string, messages []schemas.Message, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { // Get the last message and chat history lastMessage := messages[len(messages)-1] chatHistory := messages[:len(messages)-1] @@ -222,14 +223,9 @@ func (provider *CohereProvider) ChatCompletion(model, key string, messages []sch req.SetBody(jsonBody) // Make request - if err := provider.client.Do(req, resp); err != nil { - return nil, &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: schemas.ErrProviderRequest, - Error: err, - }, - } + bifrostErr := makeRequestWithContext(ctx, provider.client, req, resp) + if bifrostErr != nil { + return nil, bifrostErr } // Handle error response diff --git a/core/providers/openai.go b/core/providers/openai.go index 9dff479d98..1420663be6 100644 --- a/core/providers/openai.go +++ b/core/providers/openai.go @@ -3,6 +3,7 @@ package providers import ( + "context" "fmt" "sync" "time" @@ -101,7 +102,7 @@ func (provider *OpenAIProvider) GetProviderKey() schemas.ModelProvider { // TextCompletion is not supported by the OpenAI provider. // Returns an error indicating that text completion is not available. 
-func (provider *OpenAIProvider) TextCompletion(model, key, text string, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { +func (provider *OpenAIProvider) TextCompletion(ctx context.Context, model, key, text string, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: schemas.ErrorField{ @@ -113,7 +114,7 @@ func (provider *OpenAIProvider) TextCompletion(model, key, text string, params * // ChatCompletion performs a chat completion request to the OpenAI API. // It supports both text and image content in messages. // Returns a BifrostResponse containing the completion results or an error if the request fails. -func (provider *OpenAIProvider) ChatCompletion(model, key string, messages []schemas.Message, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { +func (provider *OpenAIProvider) ChatCompletion(ctx context.Context, model, key string, messages []schemas.Message, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { formattedMessages, preparedParams := prepareOpenAIChatRequest(model, messages, params) requestBody := mergeConfig(map[string]interface{}{ @@ -145,14 +146,9 @@ func (provider *OpenAIProvider) ChatCompletion(model, key string, messages []sch req.SetBody(jsonBody) // Make request - if err := provider.client.Do(req, resp); err != nil { - return nil, &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: schemas.ErrProviderRequest, - Error: err, - }, - } + bifrostErr := makeRequestWithContext(ctx, provider.client, req, resp) + if bifrostErr != nil { + return nil, bifrostErr } // Handle error response @@ -242,10 +238,16 @@ func prepareOpenAIChatRequest(model string, messages []schemas.Message, params * "content": content, }) } else { - formattedMessages = append(formattedMessages, map[string]interface{}{ + message := 
map[string]interface{}{ "role": msg.Role, "content": msg.Content, - }) + } + + if msg.ToolCallID != nil { + message["tool_call_id"] = msg.ToolCallID + } + + formattedMessages = append(formattedMessages, message) } } diff --git a/core/providers/utils.go b/core/providers/utils.go index 9f1482ee64..7fa930216b 100644 --- a/core/providers/utils.go +++ b/core/providers/utils.go @@ -3,6 +3,7 @@ package providers import ( + "context" "fmt" "net/url" "reflect" @@ -104,6 +105,49 @@ func prepareParams(params *schemas.ModelParameters) map[string]interface{} { return flatParams } +// IMPORTANT: This function does NOT truly cancel the underlying fasthttp network request if the +// context is done. The fasthttp client call will continue in its goroutine until it completes +// or times out based on its own settings. This function merely stops *waiting* for the +// fasthttp call and returns an error related to the context. +func makeRequestWithContext(ctx context.Context, client *fasthttp.Client, req *fasthttp.Request, resp *fasthttp.Response) *schemas.BifrostError { + errChan := make(chan error, 1) + + go func() { + // client.Do is a blocking call. + // It will send an error (or nil for success) to errChan when it completes. + errChan <- client.Do(req, resp) + }() + + select { + case <-ctx.Done(): + // Context was cancelled (e.g., deadline exceeded or manual cancellation). + // Return a BifrostError indicating this. + return &schemas.BifrostError{ + IsBifrostError: true, + Error: schemas.ErrorField{ + Type: StrPtr(schemas.RequestCancelled), + Message: fmt.Sprintf("Request cancelled or timed out by context: %v", ctx.Err()), + Error: ctx.Err(), + }, + } + case err := <-errChan: + // The fasthttp.Do call completed. + if err != nil { + // The HTTP request itself failed (e.g., connection error, fasthttp timeout). 
+ return &schemas.BifrostError{ + IsBifrostError: false, + Error: schemas.ErrorField{ + Message: schemas.ErrProviderRequest, + Error: err, + }, + } + } + // HTTP request was successful from fasthttp's perspective (err is nil). + // The caller should check resp.StatusCode() for HTTP-level errors (4xx, 5xx). + return nil + } +} + // configureProxy sets up a proxy for the fasthttp client based on the provided configuration. // It supports HTTP, SOCKS5, and environment-based proxy configurations. // Returns the configured client or the original client if proxy configuration is invalid. diff --git a/core/providers/vertex.go b/core/providers/vertex.go index 3f04c09310..73c5036ffc 100644 --- a/core/providers/vertex.go +++ b/core/providers/vertex.go @@ -5,6 +5,7 @@ package providers import ( "bytes" "context" + "errors" "fmt" "io" "net/http" @@ -76,7 +77,7 @@ func (provider *VertexProvider) GetProviderKey() schemas.ModelProvider { // TextCompletion is not supported by the Vertex provider. // Returns an error indicating that text completion is not available. -func (provider *VertexProvider) TextCompletion(model, key, text string, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { +func (provider *VertexProvider) TextCompletion(ctx context.Context, model, key, text string, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: schemas.ErrorField{ @@ -88,7 +89,7 @@ func (provider *VertexProvider) TextCompletion(model, key, text string, params * // ChatCompletion performs a chat completion request to the Vertex API. // It supports both text and image content in messages. // Returns a BifrostResponse containing the completion results or an error if the request fails. 
-func (provider *VertexProvider) ChatCompletion(model, key string, messages []schemas.Message, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { +func (provider *VertexProvider) ChatCompletion(ctx context.Context, model, key string, messages []schemas.Message, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { // Format messages for Vertex API var formattedMessages []map[string]interface{} var preparedParams map[string]interface{} @@ -152,7 +153,7 @@ func (provider *VertexProvider) ChatCompletion(model, key string, messages []sch } // Create request - req, err := http.NewRequest("POST", url, bytes.NewReader(jsonBody)) + req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(jsonBody)) if err != nil { return nil, &schemas.BifrostError{ IsBifrostError: false, @@ -167,6 +168,16 @@ func (provider *VertexProvider) ChatCompletion(model, key string, messages []sch // Make request resp, err := provider.client.Do(req) if err != nil { + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + return nil, &schemas.BifrostError{ + IsBifrostError: false, + Error: schemas.ErrorField{ + Type: StrPtr(schemas.RequestCancelled), + Message: fmt.Sprintf("Request cancelled or timed out by context: %v", ctx.Err()), + Error: err, + }, + } + } return nil, &schemas.BifrostError{ IsBifrostError: false, Error: schemas.ErrorField{ diff --git a/core/schemas/bifrost.go b/core/schemas/bifrost.go index 2dcd888268..df11742d01 100644 --- a/core/schemas/bifrost.go +++ b/core/schemas/bifrost.go @@ -143,6 +143,7 @@ type ToolChoice struct { type Message struct { Role ModelChatMessageRole `json:"role"` Content *string `json:"content,omitempty"` + ToolCallID *string `json:"tool_call_id,omitempty"` ImageContent *ImageContent `json:"image_content,omitempty"` ToolCalls *[]Tool `json:"tool_calls,omitempty"` } @@ -290,6 +291,10 @@ type BifrostResponseExtraFields struct { RawResponse interface{} 
`json:"raw_response"` } +const ( + RequestCancelled = "request_cancelled" +) + // BifrostError represents an error from the Bifrost system. type BifrostError struct { EventID *string `json:"event_id,omitempty"` diff --git a/core/schemas/provider.go b/core/schemas/provider.go index 3ed487d043..b26cea1df9 100644 --- a/core/schemas/provider.go +++ b/core/schemas/provider.go @@ -1,7 +1,10 @@ // Package schemas defines the core schemas and types used by the Bifrost system. package schemas -import "time" +import ( + "context" + "time" +) const ( DefaultMaxRetries = 0 @@ -140,7 +143,7 @@ type Provider interface { // GetProviderKey returns the provider's identifier GetProviderKey() ModelProvider // TextCompletion performs a text completion request - TextCompletion(model, key, text string, params *ModelParameters) (*BifrostResponse, *BifrostError) + TextCompletion(ctx context.Context, model, key, text string, params *ModelParameters) (*BifrostResponse, *BifrostError) // ChatCompletion performs a chat completion request - ChatCompletion(model, key string, messages []Message, params *ModelParameters) (*BifrostResponse, *BifrostError) + ChatCompletion(ctx context.Context, model, key string, messages []Message, params *ModelParameters) (*BifrostResponse, *BifrostError) } diff --git a/core/tests/anthropic_test.go b/core/tests/anthropic_test.go index 9522f76ba7..edd0144aa2 100644 --- a/core/tests/anthropic_test.go +++ b/core/tests/anthropic_test.go @@ -32,6 +32,4 @@ func TestAnthropic(t *testing.T) { } SetupAllRequests(bifrost, config) - - bifrost.Cleanup() } diff --git a/core/tests/azure_test.go b/core/tests/azure_test.go index 81f37b6308..5c2194364c 100644 --- a/core/tests/azure_test.go +++ b/core/tests/azure_test.go @@ -26,5 +26,4 @@ func TestAzure(t *testing.T) { } SetupAllRequests(bifrost, config) - bifrost.Cleanup() } diff --git a/core/tests/bedrock_test.go b/core/tests/bedrock_test.go index ed227629cc..6c9b6126ff 100644 --- a/core/tests/bedrock_test.go +++ 
b/core/tests/bedrock_test.go @@ -34,5 +34,4 @@ func TestBedrock(t *testing.T) { } SetupAllRequests(bifrost, config) - bifrost.Cleanup() } diff --git a/core/tests/cohere_test.go b/core/tests/cohere_test.go index 37a7bfb37b..93c9675c01 100644 --- a/core/tests/cohere_test.go +++ b/core/tests/cohere_test.go @@ -26,6 +26,4 @@ func TestCohere(t *testing.T) { } SetupAllRequests(bifrost, config) - - bifrost.Cleanup() } diff --git a/core/tests/openai_test.go b/core/tests/openai_test.go index 6cf5e67086..f748e08ca6 100644 --- a/core/tests/openai_test.go +++ b/core/tests/openai_test.go @@ -27,11 +27,10 @@ func TestOpenAI(t *testing.T) { Fallbacks: []schemas.Fallback{ { Provider: schemas.Anthropic, - Model: "claude-3-5-sonnet-20240620", + Model: "claude-3-7-sonnet-20250219", }, }, } SetupAllRequests(bifrost, config) - bifrost.Cleanup() } diff --git a/core/tests/tests.go b/core/tests/tests.go index 2c4a6d3ec8..561f7078a8 100644 --- a/core/tests/tests.go +++ b/core/tests/tests.go @@ -6,6 +6,10 @@ package tests import ( "context" "log" + "os" + "os/signal" + "sync" + "syscall" "time" bifrost "github.com/maximhq/bifrost/core" @@ -82,7 +86,8 @@ var WeatherToolParams = schemas.ModelParameters{ // - bifrost: The Bifrost instance to use for the request // - config: Test configuration containing model and parameters // - ctx: Context for the request -func setupTextCompletionRequest(bifrostClient *bifrost.Bifrost, config TestConfig, ctx context.Context) { +// - wg: WaitGroup for synchronization +func setupTextCompletionRequest(bifrostClient *bifrost.Bifrost, config TestConfig, ctx context.Context, wg *sync.WaitGroup) { text := "Hello world!" 
if config.CustomTextCompletion != nil { text = *config.CustomTextCompletion @@ -93,7 +98,9 @@ func setupTextCompletionRequest(bifrostClient *bifrost.Bifrost, config TestConfi params = *config.CustomParams } + wg.Add(1) go func() { + defer wg.Done() result, err := bifrostClient.TextCompletionRequest(ctx, &schemas.BifrostRequest{ Provider: config.Provider, Model: config.TextModel, @@ -118,7 +125,8 @@ func setupTextCompletionRequest(bifrostClient *bifrost.Bifrost, config TestConfi // - bifrost: The Bifrost instance to use for the requests // - config: Test configuration containing model and parameters // - ctx: Context for the requests -func setupChatCompletionRequests(bifrostClient *bifrost.Bifrost, config TestConfig, ctx context.Context) { +// - wg: WaitGroup for synchronization +func setupChatCompletionRequests(bifrostClient *bifrost.Bifrost, config TestConfig, ctx context.Context, wg *sync.WaitGroup) { messages := config.Messages if len(messages) == 0 { messages = CommonTestMessages @@ -131,7 +139,9 @@ func setupChatCompletionRequests(bifrostClient *bifrost.Bifrost, config TestConf for i, message := range messages { delay := time.Duration(100*(i+1)) * time.Millisecond + wg.Add(1) go func(msg string, delay time.Duration, index int) { + defer wg.Done() time.Sleep(delay) messages := []schemas.Message{ { @@ -164,7 +174,8 @@ func setupChatCompletionRequests(bifrostClient *bifrost.Bifrost, config TestConf // - bifrost: The Bifrost instance to use for the requests // - config: Test configuration containing model and parameters // - ctx: Context for the requests -func setupImageTests(bifrostClient *bifrost.Bifrost, config TestConfig, ctx context.Context) { +// - wg: WaitGroup for synchronization +func setupImageTests(bifrostClient *bifrost.Bifrost, config TestConfig, ctx context.Context, wg *sync.WaitGroup) { params := schemas.ModelParameters{} if config.CustomParams != nil { params = *config.CustomParams @@ -186,7 +197,9 @@ func setupImageTests(bifrostClient 
*bifrost.Bifrost, config TestConfig, ctx cont urlImageMessages[0].ImageContent.Type = bifrost.Ptr("url") } + wg.Add(1) go func() { + defer wg.Done() result, err := bifrostClient.ChatCompletionRequest(ctx, &schemas.BifrostRequest{ Provider: config.Provider, Model: config.ChatModel, @@ -217,7 +230,9 @@ func setupImageTests(bifrostClient *bifrost.Bifrost, config TestConfig, ctx cont }, } + wg.Add(1) go func() { + defer wg.Done() result, err := bifrostClient.ChatCompletionRequest(ctx, &schemas.BifrostRequest{ Provider: config.Provider, Model: config.ChatModel, @@ -243,7 +258,8 @@ func setupImageTests(bifrostClient *bifrost.Bifrost, config TestConfig, ctx cont // - bifrost: The Bifrost instance to use for the requests // - config: Test configuration containing model and parameters // - ctx: Context for the requests -func setupToolCalls(bifrostClient *bifrost.Bifrost, config TestConfig, ctx context.Context) { +// - wg: WaitGroup for synchronization +func setupToolCalls(bifrostClient *bifrost.Bifrost, config TestConfig, ctx context.Context, wg *sync.WaitGroup) { messages := []string{"What's the weather like in Mumbai?"} params := WeatherToolParams @@ -259,7 +275,9 @@ func setupToolCalls(bifrostClient *bifrost.Bifrost, config TestConfig, ctx conte for i, message := range messages { delay := time.Duration(100*(i+1)) * time.Millisecond + wg.Add(1) go func(msg string, delay time.Duration, index int) { + defer wg.Done() time.Sleep(delay) messages := []schemas.Message{ { @@ -306,20 +324,48 @@ func setupToolCalls(bifrostClient *bifrost.Bifrost, config TestConfig, ctx conte // Parameters: // - bifrost: The Bifrost instance to use for the requests // - config: Test configuration specifying which tests to run -func SetupAllRequests(bifrost *bifrost.Bifrost, config TestConfig) { - ctx := context.Background() +func SetupAllRequests(bifrostClient *bifrost.Bifrost, config TestConfig) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + sigChan := make(chan 
os.Signal, 1) + signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM) + + var wg sync.WaitGroup + + go func() { + <-sigChan + log.Println("\n🛑 Interrupt signal received, cancelling requests...") + cancel() + }() if config.SetupText { - setupTextCompletionRequest(bifrost, config, ctx) + setupTextCompletionRequest(bifrostClient, config, ctx, &wg) } - setupChatCompletionRequests(bifrost, config, ctx) + setupChatCompletionRequests(bifrostClient, config, ctx, &wg) if config.SetupImage { - setupImageTests(bifrost, config, ctx) + setupImageTests(bifrostClient, config, ctx, &wg) } if config.SetupToolCalls { - setupToolCalls(bifrost, config, ctx) + setupToolCalls(bifrostClient, config, ctx, &wg) + } + + allDoneChan := make(chan struct{}) + go func() { + wg.Wait() + close(allDoneChan) + }() + + select { + case <-ctx.Done(): + log.Println("Context cancelled, test setup winding down.") + time.Sleep(1 * time.Second) + case <-allDoneChan: + log.Println("All test goroutines completed.") } + log.Println("Test setup finished.") + bifrostClient.Cleanup() } diff --git a/core/tests/vertex_test.go b/core/tests/vertex_test.go index 7a777f7848..c580284442 100644 --- a/core/tests/vertex_test.go +++ b/core/tests/vertex_test.go @@ -26,5 +26,4 @@ func TestVertex(t *testing.T) { } SetupAllRequests(bifrostClient, config) - bifrostClient.Cleanup() }