diff --git a/core/bifrost.go b/core/bifrost.go
index 9d96a04eb7..096c8c94dc 100644
--- a/core/bifrost.go
+++ b/core/bifrost.go
@@ -53,6 +53,7 @@ type Bifrost struct {
 	plugins             []schemas.Plugin // list of plugins
 	requestQueues       sync.Map         // provider request queues (thread-safe)
 	waitGroups          sync.Map         // wait groups for each provider (thread-safe)
+	providerMutexes     sync.Map         // mutexes for each provider to prevent concurrent updates (thread-safe)
 	channelMessagePool  sync.Pool        // Pool for ChannelMessage objects, initial pool size is set in Init
 	responseChannelPool sync.Pool        // Pool for response channels, initial pool size is set in Init
 	errorChannelPool    sync.Pool        // Pool for error channels, initial pool size is set in Init
@@ -170,7 +171,13 @@ func Init(config schemas.BifrostConfig) (*Bifrost, error) {
 			continue
 		}
 
-		if err := bifrost.prepareProvider(providerKey, config); err != nil {
+		// Lock the provider mutex during initialization
+		providerMutex := bifrost.getProviderMutex(providerKey)
+		providerMutex.Lock()
+		err = bifrost.prepareProvider(providerKey, config)
+		providerMutex.Unlock()
+
+		if err != nil {
 			bifrost.logger.Warn(fmt.Sprintf("failed to prepare provider %s: %v", providerKey, err))
 		}
 	}
@@ -361,6 +368,7 @@ func (bifrost *Bifrost) EmbeddingRequest(ctx context.Context, req *schemas.Bifro
 //
 // Note: This operation will temporarily pause request processing for the specified provider
 // while the transition occurs. In-flight requests will complete before workers are stopped.
+// Buffered requests in the old queue will be transferred to the new queue to prevent loss.
 func (bifrost *Bifrost) UpdateProviderConcurrency(providerKey schemas.ModelProvider) error {
 	bifrost.logger.Info(fmt.Sprintf("Updating concurrency configuration for provider %s", providerKey))
 
@@ -370,14 +378,21 @@ func (bifrost *Bifrost) UpdateProviderConcurrency(providerKey schemas.ModelProvi
 		return fmt.Errorf("failed to get updated config for provider %s: %v", providerKey, err)
 	}
 
+	// Lock the provider to prevent concurrent access during update
+	providerMutex := bifrost.getProviderMutex(providerKey)
+	providerMutex.Lock()
+	defer providerMutex.Unlock()
+
 	// Check if provider currently exists
-	oldQueue, exists := bifrost.requestQueues.Load(providerKey)
+	oldQueueValue, exists := bifrost.requestQueues.Load(providerKey)
 	if !exists {
 		bifrost.logger.Debug(fmt.Sprintf("Provider %s not currently active, initializing with new configuration", providerKey))
 		// If provider doesn't exist, just prepare it with new configuration
 		return bifrost.prepareProvider(providerKey, providerConfig)
 	}
 
+	oldQueue := oldQueueValue.(chan ChannelMessage)
+
 	// Check if the provider has any keys (skip keyless providers)
 	if providerRequiresKey(providerKey) {
 		keys, err := bifrost.account.GetKeysForProvider(providerKey)
@@ -388,30 +403,71 @@ func (bifrost *Bifrost) UpdateProviderConcurrency(providerKey schemas.ModelProvi
 
 	bifrost.logger.Debug(fmt.Sprintf("Gracefully stopping existing workers for provider %s", providerKey))
 
-	// Step 1: Close the existing queue to signal workers to stop processing new requests
-	close(oldQueue.(chan ChannelMessage))
+	// Step 1: Create new queue with updated buffer size
+	newQueue := make(chan ChannelMessage, providerConfig.ConcurrencyAndBufferSize.BufferSize)
+
+	// Step 2: Transfer any buffered requests from old queue to new queue
+	// This prevents request loss during the transition
+	transferredCount := 0
+	for {
+		select {
+		case msg := <-oldQueue:
+			select {
+			case newQueue <- msg:
+				transferredCount++
+			default:
+				// New queue is full; hand the message off asynchronously instead of blocking the transfer
+				// This is unlikely with proper buffer sizing but provides safety
+				go func(m ChannelMessage) {
+					select {
+					case newQueue <- m:
+					case <-time.After(5 * time.Second):
+						bifrost.logger.Warn("Failed to transfer buffered request to new queue within timeout")
+						// Send error response to avoid hanging the client
+						m.Err <- schemas.BifrostError{
+							IsBifrostError: false,
+							Error: schemas.ErrorField{
+								Message: "request failed during provider concurrency update",
+							},
+						}
+					}
+				}(msg)
+				goto transferComplete
+			}
+		default:
+			// No more buffered messages
+			goto transferComplete
+		}
+	}
+
+transferComplete:
+	if transferredCount > 0 {
+		bifrost.logger.Info(fmt.Sprintf("Transferred %d buffered requests to new queue for provider %s", transferredCount, providerKey))
+	}
+
+	// Step 3: Close the old queue to signal workers to stop
+	close(oldQueue)
+
+	// Step 4: Atomically replace the queue
+	bifrost.requestQueues.Store(providerKey, newQueue)
 
-	// Step 2: Wait for all existing workers to finish processing in-flight requests
+	// Step 5: Wait for all existing workers to finish processing in-flight requests
 	waitGroup, exists := bifrost.waitGroups.Load(providerKey)
 	if exists {
 		waitGroup.(*sync.WaitGroup).Wait()
 		bifrost.logger.Debug(fmt.Sprintf("All workers for provider %s have stopped", providerKey))
 	}
 
-	// Step 3: Create new queue with updated buffer size
-	newQueue := make(chan ChannelMessage, providerConfig.ConcurrencyAndBufferSize.BufferSize)
-	bifrost.requestQueues.Store(providerKey, newQueue)
-
-	// Step 4: Create new wait group for the updated workers
+	// Step 6: Create new wait group for the updated workers
 	bifrost.waitGroups.Store(providerKey, &sync.WaitGroup{})
 
-	// Step 5: Create provider instance
+	// Step 7: Create provider instance
 	provider, err := bifrost.createProviderFromProviderKey(providerKey, providerConfig)
 	if err != nil {
 		return fmt.Errorf("failed to create provider instance for %s: %v", providerKey, err)
 	}
 
-	// Step 6: Start new workers with updated concurrency
+	// Step 8: Start new workers with updated concurrency
 	bifrost.logger.Debug(fmt.Sprintf("Starting %d new workers for provider %s with buffer size %d",
 		providerConfig.ConcurrencyAndBufferSize.Concurrency,
 		providerKey,
@@ -440,6 +496,12 @@ func (bifrost *Bifrost) UpdateDropExcessRequests(value bool) {
 	bifrost.logger.Info(fmt.Sprintf("DropExcessRequests updated to: %v", value))
 }
 
+// getProviderMutex gets or creates a mutex for the given provider
+func (bifrost *Bifrost) getProviderMutex(providerKey schemas.ModelProvider) *sync.RWMutex {
+	mutexValue, _ := bifrost.providerMutexes.LoadOrStore(providerKey, &sync.RWMutex{})
+	return mutexValue.(*sync.RWMutex)
+}
+
 // MCP PUBLIC API
 
 // RegisterMCPTool registers a typed tool handler with the MCP integration.
@@ -671,6 +733,7 @@ func (bifrost *Bifrost) createProviderFromProviderKey(providerKey schemas.ModelP
 
 // prepareProvider sets up a provider with its configuration, keys, and worker channels.
 // It initializes the request queue and starts worker goroutines for processing requests.
+// Note: This function assumes the caller has already acquired the appropriate mutex for the provider.
 func (bifrost *Bifrost) prepareProvider(providerKey schemas.ModelProvider, config *schemas.ProviderConfig) error {
 	providerConfig, err := bifrost.account.GetConfigForProvider(providerKey)
 	if err != nil {
@@ -710,27 +773,44 @@ func (bifrost *Bifrost) prepareProvider(providerKey schemas.ModelProvider, confi
 // getProviderQueue returns the request queue for a given provider key.
 // If the queue doesn't exist, it creates one at runtime and initializes the provider,
 // given the provider config is provided in the account interface implementation.
+// This function uses read locks to prevent race conditions during provider updates.
 func (bifrost *Bifrost) getProviderQueue(providerKey schemas.ModelProvider) (chan ChannelMessage, error) {
-	var queue chan ChannelMessage
+	// Use read lock to allow concurrent reads but prevent concurrent updates
+	providerMutex := bifrost.getProviderMutex(providerKey)
+	providerMutex.RLock()
 
-	if queueValue, exists := bifrost.requestQueues.Load(providerKey); !exists {
-		bifrost.logger.Debug(fmt.Sprintf("Creating new request queue for provider %s at runtime", providerKey))
+	if queueValue, exists := bifrost.requestQueues.Load(providerKey); exists {
+		queue := queueValue.(chan ChannelMessage)
+		providerMutex.RUnlock()
+		return queue, nil
+	}
 
-		config, err := bifrost.account.GetConfigForProvider(providerKey)
-		if err != nil {
-			return nil, fmt.Errorf("failed to get config for provider: %v", err)
-		}
+	// Provider doesn't exist, need to create it
+	// Upgrade to write lock for creation
+	providerMutex.RUnlock()
+	providerMutex.Lock()
+	defer providerMutex.Unlock()
 
-		if err := bifrost.prepareProvider(providerKey, config); err != nil {
-			return nil, err
-		}
+	// Double-check after acquiring write lock (another goroutine might have created it)
+	if queueValue, exists := bifrost.requestQueues.Load(providerKey); exists {
+		queue := queueValue.(chan ChannelMessage)
+		return queue, nil
+	}
+
+	bifrost.logger.Debug(fmt.Sprintf("Creating new request queue for provider %s at runtime", providerKey))
+
+	config, err := bifrost.account.GetConfigForProvider(providerKey)
+	if err != nil {
+		return nil, fmt.Errorf("failed to get config for provider: %v", err)
+	}
 
-		queueValue, _ = bifrost.requestQueues.Load(providerKey)
-		queue = queueValue.(chan ChannelMessage)
-	} else {
-		queue = queueValue.(chan ChannelMessage)
+	if err := bifrost.prepareProvider(providerKey, config); err != nil {
+		return nil, err
 	}
 
+	queueValue, _ := bifrost.requestQueues.Load(providerKey)
+	queue := queueValue.(chan ChannelMessage)
+
 	return queue, nil
 }
 
diff --git a/transports/bifrost-http/handlers/config.go b/transports/bifrost-http/handlers/config.go
index 2e4a143821..8589382d7e 100644
--- a/transports/bifrost-http/handlers/config.go
+++ b/transports/bifrost-http/handlers/config.go
@@ -69,13 +69,9 @@ func (h *ConfigHandler) handleUpdateConfig(ctx *fasthttp.RequestCtx) {
 		updatedConfig.PrometheusLabels = req.PrometheusLabels
 	}
 
-	if req.InitialPoolSize != currentConfig.InitialPoolSize {
-		updatedConfig.InitialPoolSize = req.InitialPoolSize
-	}
+	updatedConfig.InitialPoolSize = req.InitialPoolSize
 
-	if req.LogQueueSize != currentConfig.LogQueueSize {
-		updatedConfig.LogQueueSize = req.LogQueueSize
-	}
+	updatedConfig.EnableLogging = req.EnableLogging
 
 	// Update the store with the new config
 	h.store.ClientConfig = updatedConfig
diff --git a/transports/bifrost-http/lib/config.go b/transports/bifrost-http/lib/config.go
index 31dd113ba5..81cd7553b4 100644
--- a/transports/bifrost-http/lib/config.go
+++ b/transports/bifrost-http/lib/config.go
@@ -12,7 +12,7 @@ type ClientConfig struct {
 	DropExcessRequests bool     `json:"drop_excess_requests"` // Drop excess requests if the provider queue is full
 	InitialPoolSize    int      `json:"initial_pool_size"`    // The initial pool size for the bifrost client
 	PrometheusLabels   []string `json:"prometheus_labels"`    // The labels to be used for prometheus metrics
-	LogQueueSize       int      `json:"log_queue_size"`       // The size of the log queue, additional requests will be dropped (not saved for ui) if the queue is full
+	EnableLogging      bool     `json:"enable_logging"`       // Enable logging of requests and responses
 }
 
 // ProviderConfig represents the configuration for a specific AI model provider.
diff --git a/transports/bifrost-http/lib/store.go b/transports/bifrost-http/lib/store.go
index 4d2b7af6de..aaae7d77d7 100644
--- a/transports/bifrost-http/lib/store.go
+++ b/transports/bifrost-http/lib/store.go
@@ -53,7 +53,7 @@ var DefaultClientConfig = ClientConfig{
 	DropExcessRequests: false,
 	PrometheusLabels:   []string{},
 	InitialPoolSize:    300,
-	LogQueueSize:       1000,
+	EnableLogging:      true,
 }
 
 // NewConfigStore creates a new in-memory configuration store instance.
@@ -251,7 +251,7 @@ func (s *ConfigStore) writeConfigToFile(configPath string) error {
 	s.mu.RLock()
 	defer s.mu.RUnlock()
 
-	s.logger.Info(fmt.Sprintf("Writing current configuration to: %s", configPath))
+	s.logger.Debug(fmt.Sprintf("Writing current configuration to: %s", configPath))
 
 	// Create a map for quick lookup of env vars by provider and path
 	envVarsByPath := make(map[string]string)
@@ -325,7 +325,7 @@ func (s *ConfigStore) writeConfigToFile(configPath string) error {
 		return fmt.Errorf("failed to write config file: %w", err)
 	}
 
-	s.logger.Info(fmt.Sprintf("Successfully wrote configuration to: %s", configPath))
+	s.logger.Debug(fmt.Sprintf("Successfully wrote configuration to: %s", configPath))
 
 	return nil
 }
diff --git a/transports/bifrost-http/main.go b/transports/bifrost-http/main.go
index da91b4d23d..d2703b55e6 100644
--- a/transports/bifrost-http/main.go
+++ b/transports/bifrost-http/main.go
@@ -293,23 +293,27 @@ func main() {
 
 	promPlugin := telemetry.NewPrometheusPlugin()
 
-	// Initialize logging plugin with app-dir based path
-	loggingConfig := &logging.Config{
-		DatabasePath:     logDir,
-		LogQueueSize:     store.ClientConfig.LogQueueSize,
-		MaxCacheMemoryMB: 5,
-	}
+	var loggingPlugin *logging.LoggerPlugin
+	var loggingHandler *handlers.LoggingHandler
+	var wsHandler *handlers.WebSocketHandler
+
+	if store.ClientConfig.EnableLogging {
+		// Initialize logging plugin with app-dir based path
+		loggingConfig := &logging.Config{
+			DatabasePath: logDir,
+		}
 
-	loggingPlugin, err := logging.NewLoggerPlugin(loggingConfig, logger)
-	if err != nil {
-		log.Fatalf("failed to initialize logging plugin: %v", err)
-	}
+		var err error
+		loggingPlugin, err = logging.NewLoggerPlugin(loggingConfig, logger)
+		if err != nil {
+			log.Fatalf("failed to initialize logging plugin: %v", err)
+		}
 
-	if err != nil {
-		log.Fatalf("failed to initialize mocker plugin: %v", err)
-	}
+		loadedPlugins = append(loadedPlugins, promPlugin, loggingPlugin)
 
-	loadedPlugins = append(loadedPlugins, promPlugin, loggingPlugin)
+		loggingHandler = handlers.NewLoggingHandler(loggingPlugin.GetPluginLogManager(), logger)
+		wsHandler = handlers.NewWebSocketHandler(loggingPlugin.GetPluginLogManager(), logger)
+	}
 
 	client, err := bifrost.Init(schemas.BifrostConfig{
 		Account: account,
@@ -331,14 +335,14 @@ func main() {
 	mcpHandler := handlers.NewMCPHandler(client, logger, store)
 	integrationHandler := handlers.NewIntegrationHandler(client)
 	configHandler := handlers.NewConfigHandler(client, logger, store, configPath)
-	loggingHandler := handlers.NewLoggingHandler(loggingPlugin.GetPluginLogManager(), logger)
-	wsHandler := handlers.NewWebSocketHandler(loggingPlugin.GetPluginLogManager(), logger)
 
 	// Set up WebSocket callback for real-time log updates
-	loggingPlugin.SetLogCallback(wsHandler.BroadcastLogUpdate)
+	if wsHandler != nil && loggingPlugin != nil {
+		loggingPlugin.SetLogCallback(wsHandler.BroadcastLogUpdate)
 
-	// Start WebSocket heartbeat
-	wsHandler.StartHeartbeat()
+		// Start WebSocket heartbeat
+		wsHandler.StartHeartbeat()
+	}
 
 	r := router.New()
 
@@ -348,8 +352,12 @@ func main() {
 	mcpHandler.RegisterRoutes(r)
 	integrationHandler.RegisterRoutes(r)
 	configHandler.RegisterRoutes(r)
-	loggingHandler.RegisterRoutes(r)
-	wsHandler.RegisterRoutes(r)
+	if loggingHandler != nil {
+		loggingHandler.RegisterRoutes(r)
+	}
+	if wsHandler != nil {
+		wsHandler.RegisterRoutes(r)
+	}
 
 	// Add Prometheus /metrics endpoint
 	r.GET("/metrics", fasthttpadaptor.NewFastHTTPHandler(promhttp.Handler()))
@@ -370,6 +378,9 @@ func main() {
 		log.Fatalf("Error starting server: %v", err)
 	}
 
-	wsHandler.Stop()
+	if wsHandler != nil {
+		wsHandler.Stop()
+	}
+
 	client.Cleanup()
 }
diff --git a/transports/bifrost-http/ui/404.html b/transports/bifrost-http/ui/404.html
index e06483eca3..fb4e4626b4 100644
--- a/transports/bifrost-http/ui/404.html
+++ b/transports/bifrost-http/ui/404.html
@@ -1,11 +1,11 @@
-
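
Reviewer note, not part of the patch: the getProviderQueue rewrite above combines a per-provider RWMutex stored in a sync.Map with double-checked locking (read-lock fast path, write-lock creation, re-check after the upgrade). Below is a minimal standalone Go sketch of that pattern; the names (registry, mutexFor, getQueue) and the int channel standing in for ChannelMessage are invented for illustration only.

package main

import (
	"fmt"
	"sync"
)

type registry struct {
	mutexes sync.Map // per-key *sync.RWMutex, created lazily via LoadOrStore
	queues  sync.Map // per-key chan int
}

// mutexFor mirrors the role of getProviderMutex: LoadOrStore guarantees
// every caller observes the same mutex instance for a given key.
func (r *registry) mutexFor(key string) *sync.RWMutex {
	m, _ := r.mutexes.LoadOrStore(key, &sync.RWMutex{})
	return m.(*sync.RWMutex)
}

func (r *registry) getQueue(key string) chan int {
	mu := r.mutexFor(key)

	// Fast path: read lock, return the existing queue if present.
	mu.RLock()
	if q, ok := r.queues.Load(key); ok {
		mu.RUnlock()
		return q.(chan int)
	}
	mu.RUnlock()

	// Slow path: upgrade to a write lock, then re-check, because another
	// goroutine may have created the queue between RUnlock and Lock.
	mu.Lock()
	defer mu.Unlock()
	if q, ok := r.queues.Load(key); ok {
		return q.(chan int)
	}
	q := make(chan int, 8)
	r.queues.Store(key, q)
	return q
}

func main() {
	r := &registry{}
	var wg sync.WaitGroup
	queues := make([]chan int, 4)
	for i := 0; i < 4; i++ {
		wg.Add(1)
		go func(i int) {
			defer wg.Done()
			queues[i] = r.getQueue("openai")
		}(i)
	}
	wg.Wait()
	// All goroutines observe the same channel instance.
	fmt.Println(queues[0] == queues[1], queues[1] == queues[2], queues[2] == queues[3])
}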