Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 12 additions & 8 deletions bifrost.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,11 +89,15 @@ func (bifrost *Bifrost) prepareProvider(providerKey interfaces.SupportedModelPro
return nil
}

// Initializes infinite listening channels for each provider
func Init(account interfaces.Account, plugins []interfaces.Plugin, logger interfaces.Logger) (*Bifrost, error) {
// Init initializes a new Bifrost instance with the given account
func Init(config interfaces.BifrostConfig) (*Bifrost, error) {
if config.Account == nil {
return nil, fmt.Errorf("account is required to initialize Bifrost")
}

bifrost := &Bifrost{
account: account,
plugins: plugins,
account: config.Account,
plugins: config.Plugins,
waitGroups: make(map[interfaces.SupportedModelProvider]*sync.WaitGroup),
requestQueues: make(map[interfaces.SupportedModelProvider]chan ChannelMessage),
}
Expand All @@ -116,7 +120,7 @@ func Init(account interfaces.Account, plugins []interfaces.Plugin, logger interf
}

// Prewarm pools with multiple objects
for range 2500 {
for range config.InitialPoolSize {
// Create and put new objects directly into pools
bifrost.channelMessagePool.Put(&ChannelMessage{})
bifrost.responseChannelPool.Put(make(chan *interfaces.BifrostResponse, 1))
Expand All @@ -128,10 +132,10 @@ func Init(account interfaces.Account, plugins []interfaces.Plugin, logger interf
return nil, err
}

if logger == nil {
logger = NewDefaultLogger(interfaces.LogLevelInfo)
if config.Logger == nil {
config.Logger = NewDefaultLogger(interfaces.LogLevelInfo)
}
bifrost.logger = logger
bifrost.logger = config.Logger

// Create buffered channels for each provider and start workers
for _, providerKey := range providerKeys {
Expand Down
7 changes: 7 additions & 0 deletions interfaces/bifrost.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
package interfaces

type BifrostConfig struct {
Account Account
Plugins []Plugin
Logger Logger
InitialPoolSize int
}

// ModelChatMessageRole represents the role of a chat message
type ModelChatMessageRole string

Expand Down
2 changes: 2 additions & 0 deletions providers/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -299,5 +299,7 @@ func (provider *OpenAIProvider) ChatCompletion(model, key string, messages []int
RawResponse: rawResponse,
}

ReleaseBifrostResponse(result)

return result, nil
}
9 changes: 7 additions & 2 deletions tests/setup.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,15 @@ func getBifrost() (*bifrost.Bifrost, error) {
return nil, err
}

bifrost, err := bifrost.Init(&account, []interfaces.Plugin{plugin}, nil)
// Initialize Bifrost
b, err := bifrost.Init(interfaces.BifrostConfig{
Account: &account,
Plugins: []interfaces.Plugin{plugin},
Logger: bifrost.NewDefaultLogger(interfaces.LogLevelInfo),
})
if err != nil {
return nil, err
}

return bifrost, nil
return b, nil
}