Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
.env
.vscode
.DS_Store
*_creds*
44 changes: 26 additions & 18 deletions core/bifrost.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,13 @@ func (bifrost *Bifrost) createProviderFromProviderKey(providerKey schemas.ModelP
case schemas.Anthropic:
return providers.NewAnthropicProvider(config, bifrost.logger), nil
case schemas.Bedrock:
return providers.NewBedrockProvider(config, bifrost.logger), nil
return providers.NewBedrockProvider(config, bifrost.logger)
case schemas.Cohere:
return providers.NewCohereProvider(config, bifrost.logger), nil
case schemas.Azure:
return providers.NewAzureProvider(config, bifrost.logger), nil
return providers.NewAzureProvider(config, bifrost.logger)
case schemas.Vertex:
return providers.NewVertexProvider(config, bifrost.logger)
default:
return nil, fmt.Errorf("unsupported provider: %s", providerKey)
}
Expand All @@ -78,10 +80,12 @@ func (bifrost *Bifrost) prepareProvider(providerKey schemas.ModelProvider, confi
return fmt.Errorf("failed to get config for provider: %v", err)
}

// Check if the provider has any keys
keys, err := bifrost.account.GetKeysForProvider(providerKey)
if err != nil || len(keys) == 0 {
return fmt.Errorf("failed to get keys for provider: %v", err)
// Check if the provider has any keys (skip vertex)
if providerKey != schemas.Vertex {
keys, err := bifrost.account.GetKeysForProvider(providerKey)
if err != nil || len(keys) == 0 {
return fmt.Errorf("failed to get keys for provider: %v", err)
}
}

queue := make(chan ChannelMessage, providerConfig.ConcurrencyAndBufferSize.BufferSize) // Buffered channel per provider
Expand All @@ -93,7 +97,7 @@ func (bifrost *Bifrost) prepareProvider(providerKey schemas.ModelProvider, confi

provider, err := bifrost.createProviderFromProviderKey(providerKey, config)
if err != nil {
return fmt.Errorf("failed to get provider for the given key: %v", err)
return fmt.Errorf("failed to create provider for the given key: %v", err)
}

for range providerConfig.ConcurrencyAndBufferSize.Concurrency {
Expand Down Expand Up @@ -166,7 +170,7 @@ func Init(config schemas.BifrostConfig) (*Bifrost, error) {
}

if err := bifrost.prepareProvider(providerKey, config); err != nil {
bifrost.logger.Warn(fmt.Sprintf("failed to prepare provider: %v", err))
bifrost.logger.Warn(fmt.Sprintf("failed to prepare provider %s: %v", providerKey, err))
}
}

Expand Down Expand Up @@ -291,18 +295,22 @@ func (bifrost *Bifrost) requestWorker(provider schemas.Provider, queue chan Chan
for req := range queue {
var result *schemas.BifrostResponse
var bifrostError *schemas.BifrostError
var err error

key, err := bifrost.SelectKeyFromProviderForModel(provider.GetProviderKey(), req.Model)
if err != nil {
bifrost.logger.Warn(fmt.Sprintf("Error selecting key for model %s: %v", req.Model, err))
req.Err <- schemas.BifrostError{
IsBifrostError: false,
Error: schemas.ErrorField{
Message: err.Error(),
Error: err,
},
key := ""
if provider.GetProviderKey() != schemas.Vertex {
key, err = bifrost.SelectKeyFromProviderForModel(provider.GetProviderKey(), req.Model)
if err != nil {
bifrost.logger.Warn(fmt.Sprintf("Error selecting key for model %s: %v", req.Model, err))
req.Err <- schemas.BifrostError{
IsBifrostError: false,
Error: schemas.ErrorField{
Message: err.Error(),
Error: err,
},
}
continue
}
continue
}

config, err := bifrost.account.GetConfigForProvider(provider.GetProviderKey())
Expand Down
2 changes: 2 additions & 0 deletions core/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@ require (
github.com/goccy/go-json v0.10.5
github.com/maximhq/bifrost/plugins v1.0.0
github.com/valyala/fasthttp v1.60.0
golang.org/x/oauth2 v0.30.0
)

require (
cloud.google.com/go/compute/metadata v0.3.0 // indirect
github.com/andybalholm/brotli v1.1.1 // indirect
github.com/aws/aws-sdk-go-v2/credentials v1.17.67 // indirect
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.30 // indirect
Expand Down
4 changes: 4 additions & 0 deletions core/go.sum
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
cloud.google.com/go/compute/metadata v0.3.0 h1:Tz+eQXMEqDIKRsmY3cHTL6FVaynIjX2QxYC4trgAKZc=
cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k=
github.com/andybalholm/brotli v1.1.1 h1:PR2pgnyFznKEugtsUo0xLdDop5SKXd5Qf5ysW+7XdTA=
github.com/andybalholm/brotli v1.1.1/go.mod h1:05ib4cKhjx3OQYUY22hTVd34Bc8upXjOLL2rKwwZBoA=
github.com/aws/aws-sdk-go-v2 v1.36.3 h1:mJoei2CxPutQVxaATCzDUjcZEjVRdpsiiXi2o38yqWM=
Expand Down Expand Up @@ -44,5 +46,7 @@ github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZ
github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E=
golang.org/x/net v0.39.0 h1:ZCu7HMWDxpXpaiKdhzIfaltL9Lp31x/3fCP11bc6/fY=
golang.org/x/net v0.39.0/go.mod h1:X7NRbYVEA+ewNkCNyJ513WmMdQ3BineSwVtN2zD/d+E=
golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI=
golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU=
golang.org/x/text v0.24.0 h1:dd5Bzh4yt5KYA8f9CJHCP4FB4D51c2c6JvN37xJJkJ0=
golang.org/x/text v0.24.0/go.mod h1:L8rBsPeo2pSS+xqN0d5u2ikmjtmoJbDBT1b7nHvFCdU=
76 changes: 47 additions & 29 deletions core/providers/anthropic.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ func releaseAnthropicTextResponse(resp *AnthropicTextResponse) {
// It initializes the HTTP client with the provided configuration and sets up response pools.
// The client is configured with timeouts, concurrency limits, and optional proxy settings.
func NewAnthropicProvider(config *schemas.ProviderConfig, logger schemas.Logger) *AnthropicProvider {
setConfigDefaults(config)
config.CheckAndSetDefaults()

client := &fasthttp.Client{
ReadTimeout: time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds),
Expand Down Expand Up @@ -207,6 +207,8 @@ func (provider *AnthropicProvider) completeRequest(requestBody map[string]interf

// Handle error response
if resp.StatusCode() != fasthttp.StatusOK {
provider.logger.Debug(fmt.Sprintf("error from anthropic provider: %s", string(resp.Body())))

var errorResp AnthropicError

bifrostErr := handleProviderAPIError(resp, &errorResp)
Expand Down Expand Up @@ -280,6 +282,46 @@ func (provider *AnthropicProvider) TextCompletion(model, key, text string, param
// It formats the request, sends it to Anthropic, and processes the response.
// Returns a BifrostResponse containing the completion results or an error if the request fails.
func (provider *AnthropicProvider) ChatCompletion(model, key string, messages []schemas.Message, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) {
formattedMessages, preparedParams := prepareAnthropicChatRequest(model, messages, params)

// Merge additional parameters
requestBody := mergeConfig(map[string]interface{}{
"model": model,
"messages": formattedMessages,
}, preparedParams)

responseBody, err := provider.completeRequest(requestBody, "https://api.anthropic.com/v1/messages", key)
if err != nil {
return nil, err
}

// Create response object from pool
response := acquireAnthropicChatResponse()
defer releaseAnthropicChatResponse(response)

// Create Bifrost response from pool
bifrostResponse := acquireBifrostResponse()
defer releaseBifrostResponse(bifrostResponse)

rawResponse, bifrostErr := handleProviderResponse(responseBody, response)
if bifrostErr != nil {
return nil, bifrostErr
}

bifrostResponse, err = parseAnthropicResponse(response, bifrostResponse)
if err != nil {
return nil, err
}

bifrostResponse.ExtraFields = schemas.BifrostResponseExtraFields{
Provider: schemas.Anthropic,
RawResponse: rawResponse,
}

return bifrostResponse, nil
}

func prepareAnthropicChatRequest(model string, messages []schemas.Message, params *schemas.ModelParameters) ([]map[string]interface{}, map[string]interface{}) {
// Add system messages if present
var systemMessages []BedrockAnthropicSystemMessage
for _, msg := range messages {
Expand Down Expand Up @@ -352,39 +394,19 @@ func (provider *AnthropicProvider) ChatCompletion(model, key string, messages []
preparedParams["tools"] = tools
}

// Merge additional parameters
requestBody := mergeConfig(map[string]interface{}{
"model": model,
"messages": formattedMessages,
}, preparedParams)

if len(systemMessages) > 0 {
var messages []string
for _, message := range systemMessages {
messages = append(messages, message.Text)
}

requestBody["system"] = strings.Join(messages, " ")
}

responseBody, err := provider.completeRequest(requestBody, "https://api.anthropic.com/v1/messages", key)
if err != nil {
return nil, err
preparedParams["system"] = strings.Join(messages, " ")
}

// Create response object from pool
response := acquireAnthropicChatResponse()
defer releaseAnthropicChatResponse(response)

// Create Bifrost response from pool
bifrostResponse := acquireBifrostResponse()
defer releaseBifrostResponse(bifrostResponse)

rawResponse, bifrostErr := handleProviderResponse(responseBody, response)
if bifrostErr != nil {
return nil, bifrostErr
}
return formattedMessages, preparedParams
}

func parseAnthropicResponse(response *AnthropicChatResponse, bifrostResponse *schemas.BifrostResponse) (*schemas.BifrostResponse, *schemas.BifrostError) {
// Process the response into our BifrostResponse format
var choices []schemas.BifrostResponseChoice

Expand Down Expand Up @@ -437,10 +459,6 @@ func (provider *AnthropicProvider) ChatCompletion(model, key string, messages []
TotalTokens: response.Usage.InputTokens + response.Usage.OutputTokens,
}
bifrostResponse.Model = response.Model
bifrostResponse.ExtraFields = schemas.BifrostResponseExtraFields{
Provider: schemas.Anthropic,
RawResponse: rawResponse,
}

return bifrostResponse, nil
}
12 changes: 9 additions & 3 deletions core/providers/azure.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,12 @@ type AzureProvider struct {
// NewAzureProvider creates a new Azure provider instance.
// It initializes the HTTP client with the provided configuration and sets up response pools.
// The client is configured with timeouts, concurrency limits, and optional proxy settings.
func NewAzureProvider(config *schemas.ProviderConfig, logger schemas.Logger) *AzureProvider {
setConfigDefaults(config)
func NewAzureProvider(config *schemas.ProviderConfig, logger schemas.Logger) (*AzureProvider, error) {
config.CheckAndSetDefaults()

if config.MetaConfig == nil {
return nil, fmt.Errorf("meta config is not set")
}

client := &fasthttp.Client{
ReadTimeout: time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds),
Expand All @@ -126,7 +130,7 @@ func NewAzureProvider(config *schemas.ProviderConfig, logger schemas.Logger) *Az
logger: logger,
client: client,
meta: config.MetaConfig,
}
}, nil
}

// GetProviderKey returns the provider identifier for Azure.
Expand Down Expand Up @@ -212,6 +216,8 @@ func (provider *AzureProvider) completeRequest(requestBody map[string]interface{

// Handle error response
if resp.StatusCode() != fasthttp.StatusOK {
provider.logger.Debug(fmt.Sprintf("error from azure provider: %s", string(resp.Body())))

var errorResp AzureError

bifrostErr := handleProviderAPIError(resp, &errorResp)
Expand Down
11 changes: 8 additions & 3 deletions core/providers/bedrock.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,8 +159,12 @@ func releaseBedrockChatResponse(resp *BedrockChatResponse) {
// NewBedrockProvider creates a new Bedrock provider instance.
// It initializes the HTTP client with the provided configuration and sets up response pools.
// The client is configured with timeouts and AWS-specific settings.
func NewBedrockProvider(config *schemas.ProviderConfig, logger schemas.Logger) *BedrockProvider {
setConfigDefaults(config)
func NewBedrockProvider(config *schemas.ProviderConfig, logger schemas.Logger) (*BedrockProvider, error) {
config.CheckAndSetDefaults()

if config.MetaConfig == nil {
return nil, fmt.Errorf("meta config is not set")
}

client := &http.Client{Timeout: time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds)}

Expand All @@ -174,7 +178,7 @@ func NewBedrockProvider(config *schemas.ProviderConfig, logger schemas.Logger) *
logger: logger,
client: client,
meta: config.MetaConfig,
}
}, nil
}

// GetProviderKey returns the provider identifier for Bedrock.
Expand Down Expand Up @@ -258,6 +262,7 @@ func (provider *BedrockProvider) completeRequest(requestBody map[string]interfac
if err := json.Unmarshal(body, &errorResp); err != nil {
return nil, &schemas.BifrostError{
IsBifrostError: true,
StatusCode: &resp.StatusCode,
Error: schemas.ErrorField{
Message: schemas.ErrProviderResponseUnmarshal,
Error: err,
Expand Down
4 changes: 3 additions & 1 deletion core/providers/cohere.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ type CohereProvider struct {
// It initializes the HTTP client with the provided configuration and sets up response pools.
// The client is configured with timeouts and connection limits.
func NewCohereProvider(config *schemas.ProviderConfig, logger schemas.Logger) *CohereProvider {
setConfigDefaults(config)
config.CheckAndSetDefaults()

client := &fasthttp.Client{
ReadTimeout: time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds),
Expand Down Expand Up @@ -234,6 +234,8 @@ func (provider *CohereProvider) ChatCompletion(model, key string, messages []sch

// Handle error response
if resp.StatusCode() != fasthttp.StatusOK {
provider.logger.Debug(fmt.Sprintf("error from cohere provider: %s", string(resp.Body())))

var errorResp CohereError

bifrostErr := handleProviderAPIError(resp, &errorResp)
Expand Down
Loading