Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 45 additions & 15 deletions bifrost.go
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
// Package bifrost provides the core implementation of the Bifrost system.
// Bifrost is a unified interface for interacting with various AI model providers,
// managing concurrent requests, and handling provider-specific configurations.
package bifrost

import (
Expand All @@ -15,33 +18,39 @@ import (
"github.com/maximhq/bifrost/providers"
)

// RequestType represents the type of request being made to a provider.
type RequestType string

const (
TextCompletionRequest RequestType = "text_completion"
ChatCompletionRequest RequestType = "chat_completion"
)

// ChannelMessage represents a message passed through the request channel.
// It contains the request, response and error channels, and the request type.
type ChannelMessage struct {
interfaces.BifrostRequest
Response chan *interfaces.BifrostResponse
Err chan interfaces.BifrostError
Type RequestType
}

// Bifrost manages providers and maintains infinite open channels
// Bifrost manages providers and maintains sepcified open channels for concurrent processing.
// It handles request routing, provider management, and response processing.
type Bifrost struct {
account interfaces.Account
providers []interfaces.Provider // list of processed providers
plugins []interfaces.Plugin
account interfaces.Account // account interface
providers []interfaces.Provider // list of processed providers
plugins []interfaces.Plugin // list of plugins
requestQueues map[interfaces.SupportedModelProvider]chan ChannelMessage // provider request queues
waitGroups map[interfaces.SupportedModelProvider]*sync.WaitGroup
channelMessagePool sync.Pool // Pool for ChannelMessage objects
responseChannelPool sync.Pool // Pool for response channels
errorChannelPool sync.Pool // Pool for error channels
logger interfaces.Logger
waitGroups map[interfaces.SupportedModelProvider]*sync.WaitGroup // wait groups for each provider
channelMessagePool sync.Pool // Pool for ChannelMessage objects, initial pool size is set in Init
responseChannelPool sync.Pool // Pool for response channels, initial pool size is set in Init
errorChannelPool sync.Pool // Pool for error channels, initial pool size is set in Init
logger interfaces.Logger // logger instance, default logger is used if not provided
}

// createProviderFromProviderKey creates a new provider instance based on the provider key.
// It returns an error if the provider is not supported.
func (bifrost *Bifrost) createProviderFromProviderKey(providerKey interfaces.SupportedModelProvider, config *interfaces.ProviderConfig) (interfaces.Provider, error) {
switch providerKey {
case interfaces.OpenAI:
Expand All @@ -59,6 +68,8 @@ func (bifrost *Bifrost) createProviderFromProviderKey(providerKey interfaces.Sup
}
}

// prepareProvider sets up a provider with its configuration, keys, and worker channels.
// It initializes the request queue and starts worker goroutines for processing requests.
func (bifrost *Bifrost) prepareProvider(providerKey interfaces.SupportedModelProvider, config *interfaces.ProviderConfig) error {
providerConfig, err := bifrost.account.GetConfigForProvider(providerKey)
if err != nil {
Expand Down Expand Up @@ -91,7 +102,10 @@ func (bifrost *Bifrost) prepareProvider(providerKey interfaces.SupportedModelPro
return nil
}

// Init initializes a new Bifrost instance with the given account
// Init initializes a new Bifrost instance with the given configuration.
// It sets up the account, plugins, object pools, and initializes providers.
// Returns an error if initialization fails.
// Initial Memory Allocations happens here as per the initial pool size.
func Init(config interfaces.BifrostConfig) (*Bifrost, error) {
if config.Account == nil {
return nil, fmt.Errorf("account is required to initialize Bifrost")
Expand Down Expand Up @@ -155,7 +169,8 @@ func Init(config interfaces.BifrostConfig) (*Bifrost, error) {
return bifrost, nil
}

// getChannelMessage gets a ChannelMessage from the pool
// getChannelMessage gets a ChannelMessage from the pool and configures it with the request.
// It also gets response and error channels from their respective pools.
func (bifrost *Bifrost) getChannelMessage(req interfaces.BifrostRequest, reqType RequestType) *ChannelMessage {
// Get channels from pool
responseChan := bifrost.responseChannelPool.Get().(chan *interfaces.BifrostResponse)
Expand All @@ -181,7 +196,7 @@ func (bifrost *Bifrost) getChannelMessage(req interfaces.BifrostRequest, reqType
return msg
}

// releaseChannelMessage returns a ChannelMessage and its channels to the pool
// releaseChannelMessage returns a ChannelMessage and its channels to their respective pools.
func (bifrost *Bifrost) releaseChannelMessage(msg *ChannelMessage) {
// Put channels back in pools
bifrost.responseChannelPool.Put(msg.Response)
Expand All @@ -193,6 +208,8 @@ func (bifrost *Bifrost) releaseChannelMessage(msg *ChannelMessage) {
bifrost.channelMessagePool.Put(msg)
}

// SelectKeyFromProviderForModel selects an appropriate API key for a given provider and model.
// It uses weighted random selection if multiple keys are available.
func (bifrost *Bifrost) SelectKeyFromProviderForModel(providerKey interfaces.SupportedModelProvider, model string) (string, error) {
keys, err := bifrost.account.GetKeysForProvider(providerKey)
if err != nil {
Expand Down Expand Up @@ -242,7 +259,7 @@ func (bifrost *Bifrost) SelectKeyFromProviderForModel(providerKey interfaces.Sup
return supportedKeys[0].Value, nil
}

// calculateBackoff implements exponential backoff with jitter
// calculateBackoff implements exponential backoff with jitter for retry attempts.
func (bifrost *Bifrost) calculateBackoff(attempt int, config *interfaces.ProviderConfig) time.Duration {
// Calculate an exponential backoff: initial * 2^attempt
backoff := config.NetworkConfig.RetryBackoffInitial * time.Duration(1<<uint(attempt))
Expand All @@ -256,6 +273,8 @@ func (bifrost *Bifrost) calculateBackoff(attempt int, config *interfaces.Provide
return time.Duration(jitter)
}

// processRequests handles incoming requests from the queue for a specific provider.
// It manages retries, error handling, and response processing.
func (bifrost *Bifrost) processRequests(provider interfaces.Provider, queue chan ChannelMessage) {
defer bifrost.waitGroups[provider.GetProviderKey()].Done()

Expand Down Expand Up @@ -357,6 +376,8 @@ func (bifrost *Bifrost) processRequests(provider interfaces.Provider, queue chan
bifrost.logger.Debug(fmt.Sprintf("Worker for provider %s exiting...", provider.GetProviderKey()))
}

// GetConfiguredProviderFromProviderKey returns the provider instance for a given provider key.
// Uses the GetProviderKey method of the provider interface to find the provider.
func (bifrost *Bifrost) GetConfiguredProviderFromProviderKey(key interfaces.SupportedModelProvider) (interfaces.Provider, error) {
for _, provider := range bifrost.providers {
if provider.GetProviderKey() == key {
Expand All @@ -367,6 +388,9 @@ func (bifrost *Bifrost) GetConfiguredProviderFromProviderKey(key interfaces.Supp
return nil, fmt.Errorf("no provider found for key: %s", key)
}

// GetProviderQueue returns the request queue for a given provider key.
// If the queue doesn't exist, it creates one at runtime and initializes the provider,
// given the provider config is provided in the account interface implementation.
func (bifrost *Bifrost) GetProviderQueue(providerKey interfaces.SupportedModelProvider) (chan ChannelMessage, error) {
var queue chan ChannelMessage
var exists bool
Expand All @@ -387,6 +411,8 @@ func (bifrost *Bifrost) GetProviderQueue(providerKey interfaces.SupportedModelPr
return queue, nil
}

// TextCompletionRequest sends a text completion request to the specified provider.
// It handles plugin hooks, request validation, and response processing.
func (bifrost *Bifrost) TextCompletionRequest(providerKey interfaces.SupportedModelProvider, req *interfaces.BifrostRequest, ctx context.Context) (*interfaces.BifrostResponse, *interfaces.BifrostError) {
if req == nil {
return nil, &interfaces.BifrostError{
Expand Down Expand Up @@ -459,6 +485,8 @@ func (bifrost *Bifrost) TextCompletionRequest(providerKey interfaces.SupportedMo
return result, nil
}

// ChatCompletionRequest sends a chat completion request to the specified provider.
// It handles plugin hooks, request validation, and response processing.
func (bifrost *Bifrost) ChatCompletionRequest(providerKey interfaces.SupportedModelProvider, req *interfaces.BifrostRequest, ctx context.Context) (*interfaces.BifrostResponse, *interfaces.BifrostError) {
if req == nil {
return nil, &interfaces.BifrostError{
Expand Down Expand Up @@ -532,7 +560,8 @@ func (bifrost *Bifrost) ChatCompletionRequest(providerKey interfaces.SupportedMo
return result, nil
}

// Shutdown gracefully stops all workers when triggered
// Shutdown gracefully stops all workers when triggered.
// It closes all request channels and waits for workers to exit.
func (bifrost *Bifrost) Shutdown() {
bifrost.logger.Info("[BIFROST] Graceful Shutdown Initiated - Closing all request channels...")

Expand All @@ -547,7 +576,8 @@ func (bifrost *Bifrost) Shutdown() {
}
}

// Cleanup handles SIGINT (Ctrl+C) to exit cleanly
// Cleanup handles SIGINT (Ctrl+C) to exit cleanly.
// It sets up signal handling and calls Shutdown when interrupted.
func (bifrost *Bifrost) Cleanup() {
signalChan := make(chan os.Signal, 1)
signal.Notify(signalChan, os.Interrupt, syscall.SIGTERM)
Expand Down
21 changes: 17 additions & 4 deletions interfaces/account.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,27 @@
// Package interfaces defines the core interfaces and types used by the Bifrost system.
package interfaces

// Key represents an API key and its associated configuration for a provider.
// It contains the key value, supported models, and a weight for load balancing.
type Key struct {
Value string `json:"value"`
Models []string `json:"models"`
Weight float64 `json:"weight"`
Value string `json:"value"` // The actual API key value
Models []string `json:"models"` // List of models this key can access
Weight float64 `json:"weight"` // Weight for load balancing between multiple keys
}

// TODO one get config method
// Account defines the interface for managing provider accounts and their configurations.
// It provides methods to access provider-specific settings, API keys, and configurations.
type Account interface {
// GetInitiallyConfiguredProviders returns a list of providers that are configured
// in the account. This is used to determine which providers are available for use.
GetInitiallyConfiguredProviders() ([]SupportedModelProvider, error)

// GetKeysForProvider returns the API keys configured for a specific provider.
// The keys include their values, supported models, and weights for load balancing.
GetKeysForProvider(providerKey SupportedModelProvider) ([]Key, error)

// GetConfigForProvider returns the configuration for a specific provider.
// This includes network settings, authentication details, and other provider-specific
// configurations.
GetConfigForProvider(providerKey SupportedModelProvider) (*ProviderConfig, error)
}
Loading