From 8ef9e3681a1e0bcce9caa312cfda16b4ca63fb29 Mon Sep 17 00:00:00 2001 From: Pratham-Mishra04 Date: Thu, 10 Apr 2025 16:30:16 +0530 Subject: [PATCH] feat: bifrost config added --- bifrost.go | 20 ++++++++++++-------- interfaces/bifrost.go | 7 +++++++ providers/openai.go | 2 ++ tests/setup.go | 9 +++++++-- 4 files changed, 28 insertions(+), 10 deletions(-) diff --git a/bifrost.go b/bifrost.go index 56a129c342..2d3cd61d5a 100644 --- a/bifrost.go +++ b/bifrost.go @@ -89,11 +89,15 @@ func (bifrost *Bifrost) prepareProvider(providerKey interfaces.SupportedModelPro return nil } -// Initializes infinite listening channels for each provider -func Init(account interfaces.Account, plugins []interfaces.Plugin, logger interfaces.Logger) (*Bifrost, error) { +// Init initializes a new Bifrost instance with the given account +func Init(config interfaces.BifrostConfig) (*Bifrost, error) { + if config.Account == nil { + return nil, fmt.Errorf("account is required to initialize Bifrost") + } + bifrost := &Bifrost{ - account: account, - plugins: plugins, + account: config.Account, + plugins: config.Plugins, waitGroups: make(map[interfaces.SupportedModelProvider]*sync.WaitGroup), requestQueues: make(map[interfaces.SupportedModelProvider]chan ChannelMessage), } @@ -116,7 +120,7 @@ func Init(account interfaces.Account, plugins []interfaces.Plugin, logger interf } // Prewarm pools with multiple objects - for range 2500 { + for range config.InitialPoolSize { // Create and put new objects directly into pools bifrost.channelMessagePool.Put(&ChannelMessage{}) bifrost.responseChannelPool.Put(make(chan *interfaces.BifrostResponse, 1)) @@ -128,10 +132,10 @@ func Init(account interfaces.Account, plugins []interfaces.Plugin, logger interf return nil, err } - if logger == nil { - logger = NewDefaultLogger(interfaces.LogLevelInfo) + if config.Logger == nil { + config.Logger = NewDefaultLogger(interfaces.LogLevelInfo) } - bifrost.logger = logger + bifrost.logger = config.Logger // Create buffered channels for each provider and start workers for _, providerKey := range providerKeys { diff --git a/interfaces/bifrost.go b/interfaces/bifrost.go index f298bf32ca..c2e1733e1e 100644 --- a/interfaces/bifrost.go +++ b/interfaces/bifrost.go @@ -1,5 +1,12 @@ package interfaces +type BifrostConfig struct { + Account Account + Plugins []Plugin + Logger Logger + InitialPoolSize int +} + // ModelChatMessageRole represents the role of a chat message type ModelChatMessageRole string diff --git a/providers/openai.go b/providers/openai.go index c9550a1ac0..2d17377f0f 100644 --- a/providers/openai.go +++ b/providers/openai.go @@ -299,5 +299,7 @@ func (provider *OpenAIProvider) ChatCompletion(model, key string, messages []int RawResponse: rawResponse, } + ReleaseBifrostResponse(result) + return result, nil } diff --git a/tests/setup.go b/tests/setup.go index f7eb058fc6..06e0ee8b28 100644 --- a/tests/setup.go +++ b/tests/setup.go @@ -45,10 +45,15 @@ func getBifrost() (*bifrost.Bifrost, error) { return nil, err } - bifrost, err := bifrost.Init(&account, []interfaces.Plugin{plugin}, nil) + // Initialize Bifrost + b, err := bifrost.Init(interfaces.BifrostConfig{ + Account: &account, + Plugins: []interfaces.Plugin{plugin}, + Logger: bifrost.NewDefaultLogger(interfaces.LogLevelInfo), + }) if err != nil { return nil, err } - return bifrost, nil + return b, nil }