Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 106 additions & 26 deletions core/bifrost.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ type Bifrost struct {
plugins []schemas.Plugin // list of plugins
requestQueues sync.Map // provider request queues (thread-safe)
waitGroups sync.Map // wait groups for each provider (thread-safe)
providerMutexes sync.Map // mutexes for each provider to prevent concurrent updates (thread-safe)
channelMessagePool sync.Pool // Pool for ChannelMessage objects, initial pool size is set in Init
responseChannelPool sync.Pool // Pool for response channels, initial pool size is set in Init
errorChannelPool sync.Pool // Pool for error channels, initial pool size is set in Init
Expand Down Expand Up @@ -170,7 +171,13 @@ func Init(config schemas.BifrostConfig) (*Bifrost, error) {
continue
}

if err := bifrost.prepareProvider(providerKey, config); err != nil {
// Lock the provider mutex during initialization
providerMutex := bifrost.getProviderMutex(providerKey)
providerMutex.Lock()
err = bifrost.prepareProvider(providerKey, config)
providerMutex.Unlock()

if err != nil {
bifrost.logger.Warn(fmt.Sprintf("failed to prepare provider %s: %v", providerKey, err))
}
}
Expand Down Expand Up @@ -361,6 +368,7 @@ func (bifrost *Bifrost) EmbeddingRequest(ctx context.Context, req *schemas.Bifro
//
// Note: This operation will temporarily pause request processing for the specified provider
// while the transition occurs. In-flight requests will complete before workers are stopped.
// Buffered requests in the old queue will be transferred to the new queue to prevent loss.
func (bifrost *Bifrost) UpdateProviderConcurrency(providerKey schemas.ModelProvider) error {
bifrost.logger.Info(fmt.Sprintf("Updating concurrency configuration for provider %s", providerKey))

Expand All @@ -370,14 +378,21 @@ func (bifrost *Bifrost) UpdateProviderConcurrency(providerKey schemas.ModelProvi
return fmt.Errorf("failed to get updated config for provider %s: %v", providerKey, err)
}

// Lock the provider to prevent concurrent access during update
providerMutex := bifrost.getProviderMutex(providerKey)
providerMutex.Lock()
defer providerMutex.Unlock()

// Check if provider currently exists
oldQueue, exists := bifrost.requestQueues.Load(providerKey)
oldQueueValue, exists := bifrost.requestQueues.Load(providerKey)
if !exists {
bifrost.logger.Debug(fmt.Sprintf("Provider %s not currently active, initializing with new configuration", providerKey))
// If provider doesn't exist, just prepare it with new configuration
return bifrost.prepareProvider(providerKey, providerConfig)
}

oldQueue := oldQueueValue.(chan ChannelMessage)

// Check if the provider has any keys (skip keyless providers)
if providerRequiresKey(providerKey) {
keys, err := bifrost.account.GetKeysForProvider(providerKey)
Expand All @@ -388,30 +403,71 @@ func (bifrost *Bifrost) UpdateProviderConcurrency(providerKey schemas.ModelProvi

bifrost.logger.Debug(fmt.Sprintf("Gracefully stopping existing workers for provider %s", providerKey))

// Step 1: Close the existing queue to signal workers to stop processing new requests
close(oldQueue.(chan ChannelMessage))
// Step 1: Create new queue with updated buffer size
newQueue := make(chan ChannelMessage, providerConfig.ConcurrencyAndBufferSize.BufferSize)

// Step 2: Transfer any buffered requests from old queue to new queue
// This prevents request loss during the transition
transferredCount := 0
for {
select {
case msg := <-oldQueue:
select {
case newQueue <- msg:
transferredCount++
default:
// New queue is full, put message back and break
// This is unlikely with proper buffer sizing but provides safety
go func(m ChannelMessage) {
select {
case newQueue <- m:
case <-time.After(5 * time.Second):
bifrost.logger.Warn("Failed to transfer buffered request to new queue within timeout")
// Send error response to avoid hanging the client
m.Err <- schemas.BifrostError{
IsBifrostError: false,
Error: schemas.ErrorField{
Message: "request failed during provider concurrency update",
},
}
}
}(msg)
goto transferComplete
}
default:
// No more buffered messages
goto transferComplete
}
}

transferComplete:
if transferredCount > 0 {
bifrost.logger.Info(fmt.Sprintf("Transferred %d buffered requests to new queue for provider %s", transferredCount, providerKey))
}

// Step 3: Close the old queue to signal workers to stop
close(oldQueue)

// Step 4: Atomically replace the queue
bifrost.requestQueues.Store(providerKey, newQueue)

// Step 2: Wait for all existing workers to finish processing in-flight requests
// Step 5: Wait for all existing workers to finish processing in-flight requests
waitGroup, exists := bifrost.waitGroups.Load(providerKey)
if exists {
waitGroup.(*sync.WaitGroup).Wait()
bifrost.logger.Debug(fmt.Sprintf("All workers for provider %s have stopped", providerKey))
}

// Step 3: Create new queue with updated buffer size
newQueue := make(chan ChannelMessage, providerConfig.ConcurrencyAndBufferSize.BufferSize)
bifrost.requestQueues.Store(providerKey, newQueue)

// Step 4: Create new wait group for the updated workers
// Step 6: Create new wait group for the updated workers
bifrost.waitGroups.Store(providerKey, &sync.WaitGroup{})

// Step 5: Create provider instance
// Step 7: Create provider instance
provider, err := bifrost.createProviderFromProviderKey(providerKey, providerConfig)
if err != nil {
return fmt.Errorf("failed to create provider instance for %s: %v", providerKey, err)
}

// Step 6: Start new workers with updated concurrency
// Step 8: Start new workers with updated concurrency
bifrost.logger.Debug(fmt.Sprintf("Starting %d new workers for provider %s with buffer size %d",
providerConfig.ConcurrencyAndBufferSize.Concurrency,
providerKey,
Expand Down Expand Up @@ -440,6 +496,12 @@ func (bifrost *Bifrost) UpdateDropExcessRequests(value bool) {
bifrost.logger.Info(fmt.Sprintf("DropExcessRequests updated to: %v", value))
}

// getProviderMutex returns the RWMutex guarding the given provider's queue
// and worker state, lazily creating and registering one the first time the
// provider is seen. Safe for concurrent use via sync.Map.LoadOrStore.
func (bifrost *Bifrost) getProviderMutex(providerKey schemas.ModelProvider) *sync.RWMutex {
	m, _ := bifrost.providerMutexes.LoadOrStore(providerKey, new(sync.RWMutex))
	return m.(*sync.RWMutex)
}

// MCP PUBLIC API

// RegisterMCPTool registers a typed tool handler with the MCP integration.
Expand Down Expand Up @@ -671,6 +733,7 @@ func (bifrost *Bifrost) createProviderFromProviderKey(providerKey schemas.ModelP

// prepareProvider sets up a provider with its configuration, keys, and worker channels.
// It initializes the request queue and starts worker goroutines for processing requests.
// Note: This function assumes the caller has already acquired the appropriate mutex for the provider.
func (bifrost *Bifrost) prepareProvider(providerKey schemas.ModelProvider, config *schemas.ProviderConfig) error {
providerConfig, err := bifrost.account.GetConfigForProvider(providerKey)
if err != nil {
Expand Down Expand Up @@ -710,27 +773,44 @@ func (bifrost *Bifrost) prepareProvider(providerKey schemas.ModelProvider, confi
// getProviderQueue returns the request queue for a given provider key.
// If the queue doesn't exist, it creates one at runtime and initializes the provider,
// given the provider config is provided in the account interface implementation.
// This function uses read locks to prevent race conditions during provider updates.
func (bifrost *Bifrost) getProviderQueue(providerKey schemas.ModelProvider) (chan ChannelMessage, error) {
var queue chan ChannelMessage
// Use read lock to allow concurrent reads but prevent concurrent updates
providerMutex := bifrost.getProviderMutex(providerKey)
providerMutex.RLock()

if queueValue, exists := bifrost.requestQueues.Load(providerKey); !exists {
bifrost.logger.Debug(fmt.Sprintf("Creating new request queue for provider %s at runtime", providerKey))
if queueValue, exists := bifrost.requestQueues.Load(providerKey); exists {
queue := queueValue.(chan ChannelMessage)
providerMutex.RUnlock()
return queue, nil
}

config, err := bifrost.account.GetConfigForProvider(providerKey)
if err != nil {
return nil, fmt.Errorf("failed to get config for provider: %v", err)
}
// Provider doesn't exist, need to create it
// Upgrade to write lock for creation
providerMutex.RUnlock()
providerMutex.Lock()
defer providerMutex.Unlock()

if err := bifrost.prepareProvider(providerKey, config); err != nil {
return nil, err
}
// Double-check after acquiring write lock (another goroutine might have created it)
if queueValue, exists := bifrost.requestQueues.Load(providerKey); exists {
queue := queueValue.(chan ChannelMessage)
return queue, nil
}

bifrost.logger.Debug(fmt.Sprintf("Creating new request queue for provider %s at runtime", providerKey))

config, err := bifrost.account.GetConfigForProvider(providerKey)
if err != nil {
return nil, fmt.Errorf("failed to get config for provider: %v", err)
}

queueValue, _ = bifrost.requestQueues.Load(providerKey)
queue = queueValue.(chan ChannelMessage)
} else {
queue = queueValue.(chan ChannelMessage)
if err := bifrost.prepareProvider(providerKey, config); err != nil {
return nil, err
}

queueValue, _ := bifrost.requestQueues.Load(providerKey)
queue := queueValue.(chan ChannelMessage)

return queue, nil
}

Expand Down
8 changes: 2 additions & 6 deletions transports/bifrost-http/handlers/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,13 +69,9 @@ func (h *ConfigHandler) handleUpdateConfig(ctx *fasthttp.RequestCtx) {
updatedConfig.PrometheusLabels = req.PrometheusLabels
}

if req.InitialPoolSize != currentConfig.InitialPoolSize {
updatedConfig.InitialPoolSize = req.InitialPoolSize
}
updatedConfig.InitialPoolSize = req.InitialPoolSize

if req.LogQueueSize != currentConfig.LogQueueSize {
updatedConfig.LogQueueSize = req.LogQueueSize
}
updatedConfig.EnableLogging = req.EnableLogging

// Update the store with the new config
h.store.ClientConfig = updatedConfig
Expand Down
2 changes: 1 addition & 1 deletion transports/bifrost-http/lib/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ type ClientConfig struct {
DropExcessRequests bool `json:"drop_excess_requests"` // Drop excess requests if the provider queue is full
InitialPoolSize int `json:"initial_pool_size"` // The initial pool size for the bifrost client
PrometheusLabels []string `json:"prometheus_labels"` // The labels to be used for prometheus metrics
LogQueueSize int `json:"log_queue_size"` // The size of the log queue, additional requests will be dropped (not saved for ui) if the queue is full
EnableLogging bool `json:"enable_logging"` // Enable logging of requests and responses
}

// ProviderConfig represents the configuration for a specific AI model provider.
Expand Down
6 changes: 3 additions & 3 deletions transports/bifrost-http/lib/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ var DefaultClientConfig = ClientConfig{
DropExcessRequests: false,
PrometheusLabels: []string{},
InitialPoolSize: 300,
LogQueueSize: 1000,
EnableLogging: true,
}

// NewConfigStore creates a new in-memory configuration store instance.
Expand Down Expand Up @@ -251,7 +251,7 @@ func (s *ConfigStore) writeConfigToFile(configPath string) error {
s.mu.RLock()
defer s.mu.RUnlock()

s.logger.Info(fmt.Sprintf("Writing current configuration to: %s", configPath))
s.logger.Debug(fmt.Sprintf("Writing current configuration to: %s", configPath))

// Create a map for quick lookup of env vars by provider and path
envVarsByPath := make(map[string]string)
Expand Down Expand Up @@ -325,7 +325,7 @@ func (s *ConfigStore) writeConfigToFile(configPath string) error {
return fmt.Errorf("failed to write config file: %w", err)
}

s.logger.Info(fmt.Sprintf("Successfully wrote configuration to: %s", configPath))
s.logger.Debug(fmt.Sprintf("Successfully wrote configuration to: %s", configPath))
return nil
}

Expand Down
55 changes: 33 additions & 22 deletions transports/bifrost-http/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -293,23 +293,27 @@ func main() {

promPlugin := telemetry.NewPrometheusPlugin()

// Initialize logging plugin with app-dir based path
loggingConfig := &logging.Config{
DatabasePath: logDir,
LogQueueSize: store.ClientConfig.LogQueueSize,
MaxCacheMemoryMB: 5,
}
var loggingPlugin *logging.LoggerPlugin
var loggingHandler *handlers.LoggingHandler
var wsHandler *handlers.WebSocketHandler

if store.ClientConfig.EnableLogging {
// Initialize logging plugin with app-dir based path
loggingConfig := &logging.Config{
DatabasePath: logDir,
}

loggingPlugin, err := logging.NewLoggerPlugin(loggingConfig, logger)
if err != nil {
log.Fatalf("failed to initialize logging plugin: %v", err)
}
var err error
loggingPlugin, err = logging.NewLoggerPlugin(loggingConfig, logger)
if err != nil {
log.Fatalf("failed to initialize logging plugin: %v", err)
}

if err != nil {
log.Fatalf("failed to initialize mocker plugin: %v", err)
}
loadedPlugins = append(loadedPlugins, promPlugin, loggingPlugin)

loadedPlugins = append(loadedPlugins, promPlugin, loggingPlugin)
loggingHandler = handlers.NewLoggingHandler(loggingPlugin.GetPluginLogManager(), logger)
wsHandler = handlers.NewWebSocketHandler(loggingPlugin.GetPluginLogManager(), logger)
}

client, err := bifrost.Init(schemas.BifrostConfig{
Account: account,
Expand All @@ -331,14 +335,14 @@ func main() {
mcpHandler := handlers.NewMCPHandler(client, logger, store)
integrationHandler := handlers.NewIntegrationHandler(client)
configHandler := handlers.NewConfigHandler(client, logger, store, configPath)
loggingHandler := handlers.NewLoggingHandler(loggingPlugin.GetPluginLogManager(), logger)
wsHandler := handlers.NewWebSocketHandler(loggingPlugin.GetPluginLogManager(), logger)

// Set up WebSocket callback for real-time log updates
loggingPlugin.SetLogCallback(wsHandler.BroadcastLogUpdate)
if wsHandler != nil && loggingPlugin != nil {
loggingPlugin.SetLogCallback(wsHandler.BroadcastLogUpdate)

// Start WebSocket heartbeat
wsHandler.StartHeartbeat()
// Start WebSocket heartbeat
wsHandler.StartHeartbeat()
}

r := router.New()

Expand All @@ -348,8 +352,12 @@ func main() {
mcpHandler.RegisterRoutes(r)
integrationHandler.RegisterRoutes(r)
configHandler.RegisterRoutes(r)
loggingHandler.RegisterRoutes(r)
wsHandler.RegisterRoutes(r)
if loggingHandler != nil {
loggingHandler.RegisterRoutes(r)
}
if wsHandler != nil {
wsHandler.RegisterRoutes(r)
}

// Add Prometheus /metrics endpoint
r.GET("/metrics", fasthttpadaptor.NewFastHTTPHandler(promhttp.Handler()))
Expand All @@ -370,6 +378,9 @@ func main() {
log.Fatalf("Error starting server: %v", err)
}

wsHandler.Stop()
if wsHandler != nil {
wsHandler.Stop()
}

client.Cleanup()
}
8 changes: 4 additions & 4 deletions transports/bifrost-http/ui/404.html

Large diffs are not rendered by default.

8 changes: 4 additions & 4 deletions transports/bifrost-http/ui/404/index.html

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Large diffs are not rendered by default.

This file was deleted.

Loading