From 94552817b9abe7300ec485c378d0dfb1f10ec44d Mon Sep 17 00:00:00 2001 From: Pratham Mishra <99235987+Pratham-Mishra04@users.noreply.github.com> Date: Tue, 22 Jul 2025 16:43:42 +0530 Subject: [PATCH] enhancement: core leaks fixes and refactor --- core/bifrost.go | 847 +++++++++++++++++----------------------------- transports/go.mod | 2 - 2 files changed, 303 insertions(+), 546 deletions(-) diff --git a/core/bifrost.go b/core/bifrost.go index 536701191b..3ce37dfd0f 100644 --- a/core/bifrost.go +++ b/core/bifrost.go @@ -53,6 +53,7 @@ type Bifrost struct { channelMessagePool sync.Pool // Pool for ChannelMessage objects, initial pool size is set in Init responseChannelPool sync.Pool // Pool for response channels, initial pool size is set in Init errorChannelPool sync.Pool // Pool for error channels, initial pool size is set in Init + responseStreamPool sync.Pool // Pool for response stream channels, initial pool size is set in Init pluginPipelinePool sync.Pool // Pool for PluginPipeline objects logger schemas.Logger // logger instance, default logger is used if not provided backgroundCtx context.Context // Shared background context for nil context handling @@ -117,6 +118,11 @@ func Init(config schemas.BifrostConfig) (*Bifrost, error) { return make(chan schemas.BifrostError, 1) }, } + bifrost.responseStreamPool = sync.Pool{ + New: func() interface{} { + return make(chan chan *schemas.BifrostStream, 1) + }, + } bifrost.pluginPipelinePool = sync.Pool{ New: func() interface{} { return &PluginPipeline{ @@ -132,6 +138,7 @@ func Init(config schemas.BifrostConfig) (*Bifrost, error) { bifrost.channelMessagePool.Put(&ChannelMessage{}) bifrost.responseChannelPool.Put(make(chan *schemas.BifrostResponse, 1)) bifrost.errorChannelPool.Put(make(chan schemas.BifrostError, 1)) + bifrost.responseStreamPool.Put(make(chan chan *schemas.BifrostStream, 1)) bifrost.pluginPipelinePool.Put(&PluginPipeline{ preHookErrors: make([]error, 0), postHookErrors: make([]error, 0), @@ -184,516 +191,115 @@ func Init(config schemas.BifrostConfig) (*Bifrost, error) { // PUBLIC API METHODS // TextCompletionRequest sends a text completion request to the specified provider. -// It handles plugin hooks, request validation, response processing, and fallback providers. -// If the primary provider fails, it will try each fallback provider in order until one succeeds. func (bifrost *Bifrost) TextCompletionRequest(ctx context.Context, req *schemas.BifrostRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { - if err := validateRequest(req); err != nil { - err.Provider = req.Provider - return nil, err - } - - // Try the primary provider first - primaryResult, primaryErr := bifrost.tryRequest(req, ctx, TextCompletionRequest) - if primaryErr == nil { - return primaryResult, nil - } - - if primaryErr.Error.Type != nil && *primaryErr.Error.Type == schemas.RequestCancelled { - primaryErr.Provider = req.Provider - return nil, primaryErr - } - - // Check if this is a short-circuit error that doesn't allow fallbacks - // Note: AllowFallbacks = nil is treated as true (allow fallbacks by default) - if primaryErr.AllowFallbacks != nil && !*primaryErr.AllowFallbacks { - primaryErr.Provider = req.Provider - return nil, primaryErr - } - - // If primary provider failed and we have fallbacks, try them in order - // This includes both regular provider errors and plugin short-circuit errors with AllowFallbacks=true/nil - if len(req.Fallbacks) > 0 { - for _, fallback := range req.Fallbacks { - // Check if we have config for this fallback provider - _, err := bifrost.account.GetConfigForProvider(fallback.Provider) - if err != nil { - bifrost.logger.Warn(fmt.Sprintf("Config not found for provider %s, skipping fallback: %v", fallback.Provider, err)) - continue - } - - // Create a new request with the fallback provider and model - fallbackReq := *req - fallbackReq.Provider = fallback.Provider - fallbackReq.Model = fallback.Model - - // Try the fallback provider - result, fallbackErr := bifrost.tryRequest(&fallbackReq, ctx, TextCompletionRequest) - if fallbackErr == nil { - bifrost.logger.Info(fmt.Sprintf("Successfully used fallback provider %s with model %s", fallback.Provider, fallback.Model)) - return result, nil - } - if fallbackErr.Error.Type != nil && *fallbackErr.Error.Type == schemas.RequestCancelled { - fallbackErr.Provider = fallback.Provider - return nil, fallbackErr - } - - bifrost.logger.Warn(fmt.Sprintf("Fallback provider %s failed: %s", fallback.Provider, fallbackErr.Error.Message)) + if req.Input.TextCompletionInput == nil { + return nil, &schemas.BifrostError{ + IsBifrostError: false, + Error: schemas.ErrorField{ + Message: "text not provided for text completion request", + }, } } - primaryErr.Provider = req.Provider - - // All providers failed, return the original error - return nil, primaryErr + return bifrost.handleRequest(ctx, req, TextCompletionRequest) } // ChatCompletionRequest sends a chat completion request to the specified provider. -// It handles plugin hooks, request validation, response processing, and fallback providers. -// If the primary provider fails, it will try each fallback provider in order until one succeeds. func (bifrost *Bifrost) ChatCompletionRequest(ctx context.Context, req *schemas.BifrostRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { - if err := validateRequest(req); err != nil { - err.Provider = req.Provider - return nil, err - } - - // Try the primary provider first - primaryResult, primaryErr := bifrost.tryRequest(req, ctx, ChatCompletionRequest) - if primaryErr == nil { - return primaryResult, nil - } - - if primaryErr.Error.Type != nil && *primaryErr.Error.Type == schemas.RequestCancelled { - primaryErr.Provider = req.Provider - return nil, primaryErr - } - - // Check if this is a short-circuit error that doesn't allow fallbacks - // Note: AllowFallbacks = nil is treated as true (allow fallbacks by default) - if primaryErr.AllowFallbacks != nil && !*primaryErr.AllowFallbacks { - primaryErr.Provider = req.Provider - return nil, primaryErr - } - - // If primary provider failed and we have fallbacks, try them in order - // This includes both regular provider errors and plugin short-circuit errors with AllowFallbacks=true/nil - if len(req.Fallbacks) > 0 { - for _, fallback := range req.Fallbacks { - // Check if we have config for this fallback provider - _, err := bifrost.account.GetConfigForProvider(fallback.Provider) - if err != nil { - bifrost.logger.Warn(fmt.Sprintf("Config not found for provider %s, skipping fallback: %v", fallback.Provider, err)) - continue - } - - // Create a new request with the fallback provider and model - fallbackReq := *req - fallbackReq.Provider = fallback.Provider - fallbackReq.Model = fallback.Model - - // Try the fallback provider - result, fallbackErr := bifrost.tryRequest(&fallbackReq, ctx, ChatCompletionRequest) - if fallbackErr == nil { - bifrost.logger.Info(fmt.Sprintf("Successfully used fallback provider %s with model %s", fallback.Provider, fallback.Model)) - return result, nil - } - if fallbackErr.Error.Type != nil && *fallbackErr.Error.Type == schemas.RequestCancelled { - fallbackErr.Provider = fallback.Provider - return nil, fallbackErr - } - - bifrost.logger.Warn(fmt.Sprintf("Fallback provider %s failed: %s", fallback.Provider, fallbackErr.Error.Message)) + if req.Input.ChatCompletionInput == nil { + return nil, &schemas.BifrostError{ + IsBifrostError: false, + Error: schemas.ErrorField{ + Message: "chats not provided for chat completion request", + }, } } - primaryErr.Provider = req.Provider - - // All providers failed, return the original error - return nil, primaryErr + return bifrost.handleRequest(ctx, req, ChatCompletionRequest) } // ChatCompletionStreamRequest sends a chat completion stream request to the specified provider. -// It handles plugin hooks, request validation, response processing, and fallback providers. -// If the primary provider fails, it will try each fallback provider in order until one succeeds. func (bifrost *Bifrost) ChatCompletionStreamRequest(ctx context.Context, req *schemas.BifrostRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { - if err := validateRequest(req); err != nil { - err.Provider = req.Provider - return nil, err - } - - // Try the primary provider first - primaryResult, primaryErr := bifrost.tryStreamRequest(req, ctx, ChatCompletionStreamRequest) - if primaryErr == nil { - return primaryResult, nil - } - - if primaryErr.Error.Type != nil && *primaryErr.Error.Type == schemas.RequestCancelled { - primaryErr.Provider = req.Provider - return nil, primaryErr - } - - // Check if this is a short-circuit error that doesn't allow fallbacks - // Note: AllowFallbacks = nil is treated as true (allow fallbacks by default) - if primaryErr.AllowFallbacks != nil && !*primaryErr.AllowFallbacks { - primaryErr.Provider = req.Provider - return nil, primaryErr - } - - // If primary provider failed and we have fallbacks, try them in order - // This includes both regular provider errors and plugin short-circuit errors with AllowFallbacks=true/nil - if len(req.Fallbacks) > 0 { - for _, fallback := range req.Fallbacks { - // Check if we have config for this fallback provider - _, err := bifrost.account.GetConfigForProvider(fallback.Provider) - if err != nil { - bifrost.logger.Warn(fmt.Sprintf("Config not found for provider %s, skipping fallback: %v", fallback.Provider, err)) - continue - } - - // Create a new request with the fallback provider and model - fallbackReq := *req - fallbackReq.Provider = fallback.Provider - fallbackReq.Model = fallback.Model - - // Try the fallback provider - result, fallbackErr := bifrost.tryStreamRequest(&fallbackReq, ctx, ChatCompletionStreamRequest) - if fallbackErr == nil { - bifrost.logger.Info(fmt.Sprintf("Successfully used fallback provider %s with model %s", fallback.Provider, fallback.Model)) - return result, nil - } - if fallbackErr.Error.Type != nil && *fallbackErr.Error.Type == schemas.RequestCancelled { - fallbackErr.Provider = fallback.Provider - return nil, fallbackErr - } - - bifrost.logger.Warn(fmt.Sprintf("Fallback provider %s failed: %s", fallback.Provider, fallbackErr.Error.Message)) + if req.Input.ChatCompletionInput == nil { + return nil, &schemas.BifrostError{ + IsBifrostError: false, + Error: schemas.ErrorField{ + Message: "chats not provided for chat completion request", + }, } } - primaryErr.Provider = req.Provider - - // All providers failed, return the original error - return nil, primaryErr + return bifrost.handleStreamRequest(ctx, req, ChatCompletionStreamRequest) } // EmbeddingRequest sends an embedding request to the specified provider. -// It handles plugin hooks, request validation, response processing, and fallback providers. -// If the primary provider fails, it will try each fallback provider in order until one succeeds. func (bifrost *Bifrost) EmbeddingRequest(ctx context.Context, req *schemas.BifrostRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { - if err := validateRequest(req); err != nil { - err.Provider = req.Provider - return nil, err - } - if req.Input.EmbeddingInput == nil { - return nil, newBifrostErrorFromMsg("embedding_input is required") - } - - // Try the primary provider first - primaryResult, primaryErr := bifrost.tryRequest(req, ctx, EmbeddingRequest) - if primaryErr == nil { - return primaryResult, nil - } - - if primaryErr.Error.Type != nil && *primaryErr.Error.Type == schemas.RequestCancelled { - primaryErr.Provider = req.Provider - return nil, primaryErr - } - - // Check if this is a short-circuit error that doesn't allow fallbacks - // Note: AllowFallbacks = nil is treated as true (allow fallbacks by default) - if primaryErr.AllowFallbacks != nil && !*primaryErr.AllowFallbacks { - primaryErr.Provider = req.Provider - return nil, primaryErr - } - - // If primary provider failed and we have fallbacks, try them in order - if len(req.Fallbacks) > 0 { - for _, fallback := range req.Fallbacks { - // Check if we have config for this fallback provider - _, err := bifrost.account.GetConfigForProvider(fallback.Provider) - if err != nil { - bifrost.logger.Warn(fmt.Sprintf("Config not found for provider %s, skipping fallback: %v", fallback.Provider, err)) - continue - } - - // Create a new request with the fallback provider and model - fallbackReq := *req - fallbackReq.Provider = fallback.Provider - fallbackReq.Model = fallback.Model - - // Try the fallback provider - result, fallbackErr := bifrost.tryRequest(&fallbackReq, ctx, EmbeddingRequest) - if fallbackErr == nil { - bifrost.logger.Info(fmt.Sprintf("Successfully used fallback provider %s with model %s", fallback.Provider, fallback.Model)) - return result, nil - } - if fallbackErr.Error.Type != nil && *fallbackErr.Error.Type == schemas.RequestCancelled { - fallbackErr.Provider = fallback.Provider - return nil, fallbackErr - } - - bifrost.logger.Warn(fmt.Sprintf("Fallback provider %s failed: %s", fallback.Provider, fallbackErr.Error.Message)) + return nil, &schemas.BifrostError{ + IsBifrostError: false, + Error: schemas.ErrorField{ + Message: "embedding input not provided for embedding request", + }, } } - primaryErr.Provider = req.Provider - - // All providers failed, return the original error - return nil, primaryErr + return bifrost.handleRequest(ctx, req, EmbeddingRequest) } // SpeechRequest sends a speech request to the specified provider. -// It handles plugin hooks, request validation, response processing, and fallback providers. -// If the primary provider fails, it will try each fallback provider in order until one succeeds. func (bifrost *Bifrost) SpeechRequest(ctx context.Context, req *schemas.BifrostRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { - if err := validateRequest(req); err != nil { - err.Provider = req.Provider - return nil, err - } - - // Try the primary provider first - primaryResult, primaryErr := bifrost.tryRequest(req, ctx, SpeechRequest) - if primaryErr == nil { - return primaryResult, nil - } - - if primaryErr.Error.Type != nil && *primaryErr.Error.Type == schemas.RequestCancelled { - primaryErr.Provider = req.Provider - return nil, primaryErr - } - - // Check if this is a short-circuit error that doesn't allow fallbacks - // Note: AllowFallbacks = nil is treated as true (allow fallbacks by default) - if primaryErr.AllowFallbacks != nil && !*primaryErr.AllowFallbacks { - primaryErr.Provider = req.Provider - return nil, primaryErr - } - - // If primary provider failed and we have fallbacks, try them in order - if len(req.Fallbacks) > 0 { - for _, fallback := range req.Fallbacks { - // Check if we have config for this fallback provider - _, err := bifrost.account.GetConfigForProvider(fallback.Provider) - if err != nil { - bifrost.logger.Warn(fmt.Sprintf("Config not found for provider %s, skipping fallback: %v", fallback.Provider, err)) - continue - } - - // Create a new request with the fallback provider and model - fallbackReq := *req - fallbackReq.Provider = fallback.Provider - fallbackReq.Model = fallback.Model - - // Try the fallback provider - result, fallbackErr := bifrost.tryRequest(&fallbackReq, ctx, SpeechRequest) - if fallbackErr == nil { - bifrost.logger.Info(fmt.Sprintf("Successfully used fallback provider %s with model %s", fallback.Provider, fallback.Model)) - return result, nil - } - if fallbackErr.Error.Type != nil && *fallbackErr.Error.Type == schemas.RequestCancelled { - fallbackErr.Provider = fallback.Provider - return nil, fallbackErr - } - - bifrost.logger.Warn(fmt.Sprintf("Fallback provider %s failed: %s", fallback.Provider, fallbackErr.Error.Message)) + if req.Input.SpeechInput == nil { + return nil, &schemas.BifrostError{ + IsBifrostError: false, + Error: schemas.ErrorField{ + Message: "speech input not provided for speech request", + }, } } - primaryErr.Provider = req.Provider - - // All providers failed, return the original error - return nil, primaryErr + return bifrost.handleRequest(ctx, req, SpeechRequest) } // SpeechStreamRequest sends a speech stream request to the specified provider. -// It handles plugin hooks, request validation, response processing, and fallback providers. -// If the primary provider fails, it will try each fallback provider in order until one succeeds. func (bifrost *Bifrost) SpeechStreamRequest(ctx context.Context, req *schemas.BifrostRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { - if err := validateRequest(req); err != nil { - err.Provider = req.Provider - return nil, err - } - - // Try the primary provider first - primaryResult, primaryErr := bifrost.tryStreamRequest(req, ctx, SpeechStreamRequest) - if primaryErr == nil { - return primaryResult, nil - } - - if primaryErr.Error.Type != nil && *primaryErr.Error.Type == schemas.RequestCancelled { - primaryErr.Provider = req.Provider - return nil, primaryErr - } - - // Check if this is a short-circuit error that doesn't allow fallbacks - // Note: AllowFallbacks = nil is treated as true (allow fallbacks by default) - if primaryErr.AllowFallbacks != nil && !*primaryErr.AllowFallbacks { - primaryErr.Provider = req.Provider - return nil, primaryErr - } - - // If primary provider failed and we have fallbacks, try them in order - // This includes both regular provider errors and plugin short-circuit errors with AllowFallbacks=true/nil - if len(req.Fallbacks) > 0 { - for _, fallback := range req.Fallbacks { - // Check if we have config for this fallback provider - _, err := bifrost.account.GetConfigForProvider(fallback.Provider) - if err != nil { - bifrost.logger.Warn(fmt.Sprintf("Config not found for provider %s, skipping fallback: %v", fallback.Provider, err)) - continue - } - - // Create a new request with the fallback provider and model - fallbackReq := *req - fallbackReq.Provider = fallback.Provider - fallbackReq.Model = fallback.Model - - // Try the fallback provider - result, fallbackErr := bifrost.tryStreamRequest(&fallbackReq, ctx, SpeechStreamRequest) - if fallbackErr == nil { - bifrost.logger.Info(fmt.Sprintf("Successfully used fallback provider %s with model %s", fallback.Provider, fallback.Model)) - return result, nil - } - if fallbackErr.Error.Type != nil && *fallbackErr.Error.Type == schemas.RequestCancelled { - fallbackErr.Provider = fallback.Provider - return nil, fallbackErr - } - - bifrost.logger.Warn(fmt.Sprintf("Fallback provider %s failed: %s", fallback.Provider, fallbackErr.Error.Message)) + if req.Input.SpeechInput == nil { + return nil, &schemas.BifrostError{ + IsBifrostError: false, + Error: schemas.ErrorField{ + Message: "speech input not provided for speech stream request", + }, } } - primaryErr.Provider = req.Provider - - // All providers failed, return the original error - return nil, primaryErr + return bifrost.handleStreamRequest(ctx, req, SpeechStreamRequest) } // TranscriptionRequest sends a transcription request to the specified provider. -// It handles plugin hooks, request validation, response processing, and fallback providers. -// If the primary provider fails, it will try each fallback provider in order until one succeeds. func (bifrost *Bifrost) TranscriptionRequest(ctx context.Context, req *schemas.BifrostRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { - if err := validateRequest(req); err != nil { - err.Provider = req.Provider - return nil, err - } - - // Try the primary provider first - primaryResult, primaryErr := bifrost.tryRequest(req, ctx, TranscriptionRequest) - if primaryErr == nil { - return primaryResult, nil - } - - if primaryErr.Error.Type != nil && *primaryErr.Error.Type == schemas.RequestCancelled { - primaryErr.Provider = req.Provider - return nil, primaryErr - } - - // Check if this is a short-circuit error that doesn't allow fallbacks - // Note: AllowFallbacks = nil is treated as true (allow fallbacks by default) - if primaryErr.AllowFallbacks != nil && !*primaryErr.AllowFallbacks { - primaryErr.Provider = req.Provider - return nil, primaryErr - } - - // If primary provider failed and we have fallbacks, try them in order - if len(req.Fallbacks) > 0 { - for _, fallback := range req.Fallbacks { - // Check if we have config for this fallback provider - _, err := bifrost.account.GetConfigForProvider(fallback.Provider) - if err != nil { - bifrost.logger.Warn(fmt.Sprintf("Config not found for provider %s, skipping fallback: %v", fallback.Provider, err)) - continue - } - - // Create a new request with the fallback provider and model - fallbackReq := *req - fallbackReq.Provider = fallback.Provider - fallbackReq.Model = fallback.Model - - // Try the fallback provider - result, fallbackErr := bifrost.tryRequest(&fallbackReq, ctx, TranscriptionRequest) - if fallbackErr == nil { - bifrost.logger.Info(fmt.Sprintf("Successfully used fallback provider %s with model %s", fallback.Provider, fallback.Model)) - return result, nil - } - if fallbackErr.Error.Type != nil && *fallbackErr.Error.Type == schemas.RequestCancelled { - fallbackErr.Provider = fallback.Provider - return nil, fallbackErr - } - - bifrost.logger.Warn(fmt.Sprintf("Fallback provider %s failed: %s", fallback.Provider, fallbackErr.Error.Message)) + if req.Input.TranscriptionInput == nil { + return nil, &schemas.BifrostError{ + IsBifrostError: false, + Error: schemas.ErrorField{ + Message: "transcription input not provided for transcription request", + }, } } - primaryErr.Provider = req.Provider - - // All providers failed, return the original error - return nil, primaryErr + return bifrost.handleRequest(ctx, req, TranscriptionRequest) } // TranscriptionStreamRequest sends a transcription stream request to the specified provider. -// It handles plugin hooks, request validation, response processing, and fallback providers. -// If the primary provider fails, it will try each fallback provider in order until one succeeds. func (bifrost *Bifrost) TranscriptionStreamRequest(ctx context.Context, req *schemas.BifrostRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { - if err := validateRequest(req); err != nil { - err.Provider = req.Provider - return nil, err - } - - // Try the primary provider first - primaryResult, primaryErr := bifrost.tryStreamRequest(req, ctx, TranscriptionStreamRequest) - if primaryErr == nil { - return primaryResult, nil - } - - if primaryErr.Error.Type != nil && *primaryErr.Error.Type == schemas.RequestCancelled { - primaryErr.Provider = req.Provider - return nil, primaryErr - } - - // Check if this is a short-circuit error that doesn't allow fallbacks - // Note: AllowFallbacks = nil is treated as true (allow fallbacks by default) - if primaryErr.AllowFallbacks != nil && !*primaryErr.AllowFallbacks { - primaryErr.Provider = req.Provider - return nil, primaryErr - } - - // If primary provider failed and we have fallbacks, try them in order - // This includes both regular provider errors and plugin short-circuit errors with AllowFallbacks=true/nil - if len(req.Fallbacks) > 0 { - for _, fallback := range req.Fallbacks { - // Check if we have config for this fallback provider - _, err := bifrost.account.GetConfigForProvider(fallback.Provider) - if err != nil { - bifrost.logger.Warn(fmt.Sprintf("Config not found for provider %s, skipping fallback: %v", fallback.Provider, err)) - continue - } - - // Create a new request with the fallback provider and model - fallbackReq := *req - fallbackReq.Provider = fallback.Provider - fallbackReq.Model = fallback.Model - - // Try the fallback provider - result, fallbackErr := bifrost.tryStreamRequest(&fallbackReq, ctx, TranscriptionStreamRequest) - if fallbackErr == nil { - bifrost.logger.Info(fmt.Sprintf("Successfully used fallback provider %s with model %s", fallback.Provider, fallback.Model)) - return result, nil - } - if fallbackErr.Error.Type != nil && *fallbackErr.Error.Type == schemas.RequestCancelled { - fallbackErr.Provider = fallback.Provider - return nil, fallbackErr - } - - bifrost.logger.Warn(fmt.Sprintf("Fallback provider %s failed: %s", fallback.Provider, fallbackErr.Error.Message)) + if req.Input.TranscriptionInput == nil { + return nil, &schemas.BifrostError{ + IsBifrostError: false, + Error: schemas.ErrorField{ + Message: "transcription input not provided for transcription stream request", + }, } } - primaryErr.Provider = req.Provider - - // All providers failed, return the original error - return nil, primaryErr + return bifrost.handleStreamRequest(ctx, req, TranscriptionStreamRequest) } // UpdateProviderConcurrency dynamically updates the queue size and concurrency for an existing provider. @@ -741,6 +347,7 @@ func (bifrost *Bifrost) UpdateProviderConcurrency(providerKey schemas.ModelProvi // Step 2: Transfer any buffered requests from old queue to new queue // This prevents request loss during the transition transferredCount := 0 + var transferWaitGroup sync.WaitGroup for { select { case msg := <-oldQueue: @@ -748,19 +355,27 @@ func (bifrost *Bifrost) UpdateProviderConcurrency(providerKey schemas.ModelProvi case newQueue <- msg: transferredCount++ default: - // New queue is full, put message back and break + // New queue is full, handle this request in a goroutine // This is unlikely with proper buffer sizing but provides safety + transferWaitGroup.Add(1) go func(m ChannelMessage) { + defer transferWaitGroup.Done() select { case newQueue <- m: + // Message successfully transferred case <-time.After(5 * time.Second): bifrost.logger.Warn("Failed to transfer buffered request to new queue within timeout") // Send error response to avoid hanging the client - m.Err <- schemas.BifrostError{ + select { + case m.Err <- schemas.BifrostError{ IsBifrostError: false, Error: schemas.ErrorField{ Message: "request failed during provider concurrency update", }, + }: + case <-time.After(1 * time.Second): + // If we can't send the error either, just log and continue + bifrost.logger.Warn("Failed to send error response during transfer timeout") } } }(msg) @@ -773,6 +388,8 @@ func (bifrost *Bifrost) UpdateProviderConcurrency(providerKey schemas.ModelProvi } transferComplete: + // Wait for all transfer goroutines to complete + transferWaitGroup.Wait() if transferredCount > 0 { bifrost.logger.Info(fmt.Sprintf("Transferred %d buffered requests to new queue for provider %s", transferredCount, providerKey)) } @@ -1144,6 +761,159 @@ func (bifrost *Bifrost) getProviderQueue(providerKey schemas.ModelProvider) (cha // CORE INTERNAL LOGIC +// shouldTryFallbacks handles the primary error and returns true if we should proceed with fallbacks, false if we should return immediately +func (bifrost *Bifrost) shouldTryFallbacks(req *schemas.BifrostRequest, primaryErr *schemas.BifrostError) bool { + // If no primary error, we succeeded + if primaryErr == nil { + return false + } + + // Handle request cancellation + if primaryErr.Error.Type != nil && *primaryErr.Error.Type == schemas.RequestCancelled { + primaryErr.Provider = req.Provider + return false + } + + // Check if this is a short-circuit error that doesn't allow fallbacks + // Note: AllowFallbacks = nil is treated as true (allow fallbacks by default) + if primaryErr.AllowFallbacks != nil && !*primaryErr.AllowFallbacks { + primaryErr.Provider = req.Provider + return false + } + + // If no fallbacks configured, return primary error + if len(req.Fallbacks) == 0 { + primaryErr.Provider = req.Provider + return false + } + + // Should proceed with fallbacks + return true +} + +// prepareFallbackRequest creates a fallback request and validates the provider config +// Returns the fallback request or nil if this fallback should be skipped +func (bifrost *Bifrost) prepareFallbackRequest(req *schemas.BifrostRequest, fallback schemas.Fallback) *schemas.BifrostRequest { + // Check if we have config for this fallback provider + _, err := bifrost.account.GetConfigForProvider(fallback.Provider) + if err != nil { + bifrost.logger.Warn(fmt.Sprintf("Config not found for provider %s, skipping fallback: %v", fallback.Provider, err)) + return nil + } + + // Create a new request with the fallback provider and model + fallbackReq := *req + fallbackReq.Provider = fallback.Provider + fallbackReq.Model = fallback.Model + return &fallbackReq +} + +// shouldContinueWithFallbacks processes errors from fallback attempts +// Returns true if we should continue with more fallbacks, false if we should stop +func (bifrost *Bifrost) shouldContinueWithFallbacks(fallback schemas.Fallback, fallbackErr *schemas.BifrostError) bool { + if fallbackErr.Error.Type != nil && *fallbackErr.Error.Type == schemas.RequestCancelled { + fallbackErr.Provider = fallback.Provider + return false + } + + // Check if it was a short-circuit error that doesn't allow fallbacks + if fallbackErr.AllowFallbacks != nil && !*fallbackErr.AllowFallbacks { + fallbackErr.Provider = fallback.Provider + return false + } + + bifrost.logger.Warn(fmt.Sprintf("Fallback provider %s failed: %s", fallback.Provider, fallbackErr.Error.Message)) + return true +} + +// handleRequest handles the request to the provider based on the request type +// It handles plugin hooks, request validation, response processing, and fallback providers. +// If the primary provider fails, it will try each fallback provider in order until one succeeds. +// It is the wrapper for all non-streaming public API methods. +func (bifrost *Bifrost) handleRequest(ctx context.Context, req *schemas.BifrostRequest, requestType RequestType) (*schemas.BifrostResponse, *schemas.BifrostError) { + if err := validateRequest(req); err != nil { + err.Provider = req.Provider + return nil, err + } + + // Try the primary provider first + primaryResult, primaryErr := bifrost.tryRequest(req, ctx, requestType) + + // Check if we should proceed with fallbacks + shouldTryFallbacks := bifrost.shouldTryFallbacks(req, primaryErr) + if !shouldTryFallbacks { + return primaryResult, primaryErr + } + + // Try fallbacks in order + for _, fallback := range req.Fallbacks { + fallbackReq := bifrost.prepareFallbackRequest(req, fallback) + if fallbackReq == nil { + continue + } + + // Try the fallback provider + result, fallbackErr := bifrost.tryRequest(fallbackReq, ctx, requestType) + if fallbackErr == nil { + bifrost.logger.Info(fmt.Sprintf("Successfully used fallback provider %s with model %s", fallback.Provider, fallback.Model)) + return result, nil + } + + // Check if we should continue with more fallbacks + if !bifrost.shouldContinueWithFallbacks(fallback, fallbackErr) { + return nil, fallbackErr + } + } + + primaryErr.Provider = req.Provider + // All providers failed, return the original error + return nil, primaryErr +} + +// handleStreamRequest handles the stream request to the provider based on the request type +// It handles plugin hooks, request validation, response processing, and fallback providers. +// If the primary provider fails, it will try each fallback provider in order until one succeeds. +// It is the wrapper for all streaming public API methods. +func (bifrost *Bifrost) handleStreamRequest(ctx context.Context, req *schemas.BifrostRequest, requestType RequestType) (chan *schemas.BifrostStream, *schemas.BifrostError) { + if err := validateRequest(req); err != nil { + err.Provider = req.Provider + return nil, err + } + + // Try the primary provider first + primaryResult, primaryErr := bifrost.tryStreamRequest(req, ctx, requestType) + + // Check if we should proceed with fallbacks + shouldTryFallbacks := bifrost.shouldTryFallbacks(req, primaryErr) + if !shouldTryFallbacks { + return primaryResult, primaryErr + } + + // Try fallbacks in order + for _, fallback := range req.Fallbacks { + fallbackReq := bifrost.prepareFallbackRequest(req, fallback) + if fallbackReq == nil { + continue + } + + // Try the fallback provider + result, fallbackErr := bifrost.tryStreamRequest(fallbackReq, ctx, requestType) + if fallbackErr == nil { + bifrost.logger.Info(fmt.Sprintf("Successfully used fallback provider %s with model %s", fallback.Provider, fallback.Model)) + return result, nil + } + + // Check if we should continue with more fallbacks + if !bifrost.shouldContinueWithFallbacks(fallback, fallbackErr) { + return nil, fallbackErr + } + } + + primaryErr.Provider = req.Provider + // All providers failed, return the original error + return nil, primaryErr +} + // tryRequest is a generic function that handles common request processing logic // It consolidates queue setup, plugin pipeline execution, enqueue logic, and response handling func (bifrost *Bifrost) tryRequest(req *schemas.BifrostRequest, ctx context.Context, requestType RequestType) (*schemas.BifrostResponse, *schemas.BifrostError) { @@ -1152,6 +922,11 @@ func (bifrost *Bifrost) tryRequest(req *schemas.BifrostRequest, ctx context.Cont return nil, newBifrostError(err) } + // Handle nil context early to prevent blocking + if ctx == nil { + ctx = bifrost.backgroundCtx + } + // Add MCP tools to request if MCP is configured and requested if requestType != EmbeddingRequest && requestType != SpeechRequest && bifrost.mcpManager != nil { req = bifrost.mcpManager.addMCPToolsToBifrostRequest(ctx, req) @@ -1198,9 +973,6 @@ func (bifrost *Bifrost) tryRequest(req *schemas.BifrostRequest, ctx context.Cont bifrost.logger.Warn("Request dropped: queue is full, please increase the queue size or set dropExcessRequests to false") return nil, newBifrostErrorFromMsg("request dropped: queue is full") } - if ctx == nil { - ctx = bifrost.backgroundCtx - } select { case queue <- *msg: // Message was sent successfully @@ -1240,6 +1012,11 @@ func (bifrost *Bifrost) tryStreamRequest(req *schemas.BifrostRequest, ctx contex return nil, newBifrostError(err) } + // Handle nil context early to prevent blocking + if ctx == nil { + ctx = bifrost.backgroundCtx + } + // Add MCP tools to request if MCP is configured and requested if requestType != SpeechStreamRequest && requestType != TranscriptionStreamRequest && bifrost.mcpManager != nil { req = bifrost.mcpManager.addMCPToolsToBifrostRequest(ctx, req) @@ -1286,9 +1063,6 @@ func (bifrost *Bifrost) tryStreamRequest(req *schemas.BifrostRequest, ctx contex bifrost.logger.Warn("Request dropped: queue is full, please increase the queue size or set dropExcessRequests to false") return nil, newBifrostErrorFromMsg("request dropped: queue is full") } - if ctx == nil { - ctx = bifrost.backgroundCtx - } select { case queue <- *msg: // Message was sent successfully @@ -1356,6 +1130,21 @@ func (bifrost *Bifrost) requestWorker(provider schemas.Provider, queue chan Chan // Track attempts var attempts int + // Create plugin pipeline for streaming requests outside retry loop to prevent leaks + var postHookRunner schemas.PostHookRunner + if isStreamRequestType(req.Type) { + pipeline := bifrost.getPluginPipeline() + defer bifrost.releasePluginPipeline(pipeline) + + postHookRunner = func(ctx *context.Context, result *schemas.BifrostResponse, err *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError) { + resp, bifrostErr := pipeline.RunPostHooks(ctx, result, err, len(bifrost.plugins)) + if bifrostErr != nil { + return nil, bifrostErr + } + return resp, nil + } + } + // Execute request with retries for attempts = 0; attempts <= config.NetworkConfig.MaxRetries; attempts++ { if attempts > 0 { @@ -1375,17 +1164,6 @@ func (bifrost *Bifrost) requestWorker(provider schemas.Provider, queue chan Chan // Attempt the request if isStreamRequestType(req.Type) { - pipeline := bifrost.getPluginPipeline() - defer bifrost.releasePluginPipeline(pipeline) - - postHookRunner := func(ctx *context.Context, result *schemas.BifrostResponse, err *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError) { - resp, bifrostErr := pipeline.RunPostHooks(ctx, result, err, len(bifrost.plugins)) - if bifrostErr != nil { - return nil, bifrostErr - } - return resp, nil - } - stream, bifrostError = handleProviderStreamRequest(provider, &req, key, postHookRunner, req.Type) if bifrostError != nil && !bifrostError.IsBifrostError { break // Don't retry client errors @@ -1415,12 +1193,42 @@ func (bifrost *Bifrost) requestWorker(provider schemas.Provider, queue chan Chan attempts, map[bool]string{true: "retries", false: "retry"}[attempts > 1])) } - req.Err <- *bifrostError + // Send error with context awareness to prevent deadlock + select { + case req.Err <- *bifrostError: + // Error sent successfully + case <-req.Context.Done(): + // Client no longer listening, log and continue + bifrost.logger.Debug("Client context cancelled while sending error response") + case <-time.After(5 * time.Second): + // Timeout to prevent indefinite blocking + bifrost.logger.Warn("Timeout while sending error response, client may have disconnected") + } } else { if isStreamRequestType(req.Type) { - req.ResponseStream <- stream + // Send stream with context awareness to prevent deadlock + select { + case req.ResponseStream <- stream: + // Stream sent successfully + case <-req.Context.Done(): + // Client no longer listening, log and continue + bifrost.logger.Debug("Client context cancelled while sending stream response") + case <-time.After(5 * time.Second): + // Timeout to prevent indefinite blocking + bifrost.logger.Warn("Timeout while sending stream response, client may have disconnected") + } } else { - req.Response <- result + // Send response with context awareness to prevent deadlock + select { + case req.Response <- result: + // Response sent successfully + case <-req.Context.Done(): + // Client no longer listening, log and continue + bifrost.logger.Debug("Client context cancelled while sending response") + case <-time.After(5 * time.Second): + // Timeout to prevent indefinite blocking + bifrost.logger.Warn("Timeout while sending response, client may have disconnected") + } } } } @@ -1432,54 +1240,14 @@ func (bifrost *Bifrost) requestWorker(provider schemas.Provider, queue chan Chan func handleProviderRequest(provider schemas.Provider, req *ChannelMessage, key schemas.Key, reqType RequestType) (*schemas.BifrostResponse, *schemas.BifrostError) { switch reqType { case TextCompletionRequest: - if req.Input.TextCompletionInput == nil { - return nil, &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: "text not provided for text completion request", - }, - } - } return provider.TextCompletion(req.Context, req.Model, key, *req.Input.TextCompletionInput, req.Params) case ChatCompletionRequest: - if req.Input.ChatCompletionInput == nil { - return nil, &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: "chats not provided for chat completion request", - }, - } - } return provider.ChatCompletion(req.Context, req.Model, key, *req.Input.ChatCompletionInput, req.Params) case EmbeddingRequest: - if req.Input.EmbeddingInput == nil { - return nil, &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: "input not provided for embedding request", - }, - } - } return provider.Embedding(req.Context, req.Model, key, req.Input.EmbeddingInput, req.Params) case SpeechRequest: - if req.Input.SpeechInput == nil { - return nil, &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: "input not provided for speech request", - }, - } - } return provider.Speech(req.Context, req.Model, key, req.Input.SpeechInput, req.Params) case TranscriptionRequest: - if req.Input.TranscriptionInput == nil { - return nil, &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: "input not provided for transcription request", - }, - } - } return provider.Transcription(req.Context, req.Model, key, req.Input.TranscriptionInput, req.Params) default: return nil, &schemas.BifrostError{ @@ -1495,35 +1263,10 @@ func handleProviderRequest(provider schemas.Provider, req *ChannelMessage, key s func handleProviderStreamRequest(provider schemas.Provider, req *ChannelMessage, key schemas.Key, postHookRunner schemas.PostHookRunner, reqType RequestType) (chan *schemas.BifrostStream, *schemas.BifrostError) { switch reqType { case ChatCompletionStreamRequest: - if req.Input.ChatCompletionInput == nil { - return nil, &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: "chats not provided for chat completion request", - }, - } - } - return provider.ChatCompletionStream(req.Context, postHookRunner, req.Model, key, *req.Input.ChatCompletionInput, req.Params) case SpeechStreamRequest: - if req.Input.SpeechInput == nil { - return nil, &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: "input not provided for speech request", - }, - } - } return provider.SpeechStream(req.Context, postHookRunner, req.Model, key, req.Input.SpeechInput, req.Params) case TranscriptionStreamRequest: - if req.Input.TranscriptionInput == nil { - return nil, &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: "input not provided for transcription request", - }, - } - } return provider.TranscriptionStream(req.Context, postHookRunner, req.Model, key, req.Input.TranscriptionInput, req.Params) default: return nil, &schemas.BifrostError{ @@ -1639,7 +1382,13 @@ func (bifrost *Bifrost) getChannelMessage(req schemas.BifrostRequest, reqType Re // Conditionally allocate ResponseStream for streaming requests only if isStreamRequestType(reqType) { - msg.ResponseStream = make(chan chan *schemas.BifrostStream, 1) + responseStreamChan := bifrost.responseStreamPool.Get().(chan chan *schemas.BifrostStream) + // Clear any previous values to avoid leaking between requests + select { + case <-responseStreamChan: + default: + } + msg.ResponseStream = responseStreamChan } return msg @@ -1651,6 +1400,16 @@ func (bifrost *Bifrost) releaseChannelMessage(msg *ChannelMessage) { bifrost.responseChannelPool.Put(msg.Response) bifrost.errorChannelPool.Put(msg.Err) + // Return ResponseStream to pool if it was used + if msg.ResponseStream != nil { + // Drain any remaining channels to prevent memory leaks + select { + case <-msg.ResponseStream: + default: + } + bifrost.responseStreamPool.Put(msg.ResponseStream) + } + // Clear references and return to pool msg.Response = nil msg.ResponseStream = nil diff --git a/transports/go.mod b/transports/go.mod index 8781cee299..47ca754c10 100644 --- a/transports/go.mod +++ b/transports/go.mod @@ -14,8 +14,6 @@ require ( google.golang.org/genai v1.4.0 ) -replace github.com/maximhq/bifrost/core => ../core - require ( cloud.google.com/go v0.121.0 // indirect cloud.google.com/go/auth v0.16.0 // indirect