Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions core/bifrost.go
Original file line number Diff line number Diff line change
Expand Up @@ -569,8 +569,9 @@ func (bifrost *Bifrost) TextCompletionRequest(ctx context.Context, req *schemas.
continue
}

// Create a new request with the fallback model
// Create a new request with the fallback provider and model
fallbackReq := *req
fallbackReq.Provider = fallback.Provider
fallbackReq.Model = fallback.Model

// Try the fallback provider
Expand Down Expand Up @@ -690,12 +691,13 @@ func (bifrost *Bifrost) ChatCompletionRequest(ctx context.Context, req *schemas.
// Check if we have config for this fallback provider
_, err := bifrost.account.GetConfigForProvider(fallback.Provider)
if err != nil {
bifrost.logger.Warn(fmt.Sprintf("Skipping fallback provider %s: %v", fallback.Provider, err))
bifrost.logger.Warn(fmt.Sprintf("Config not found for provider %s, skipping fallback: %v", fallback.Provider, err))
continue
}

// Create a new request with the fallback model
// Create a new request with the fallback provider and model
fallbackReq := *req
fallbackReq.Provider = fallback.Provider
fallbackReq.Model = fallback.Model

// Try the fallback provider
Expand All @@ -704,7 +706,11 @@ func (bifrost *Bifrost) ChatCompletionRequest(ctx context.Context, req *schemas.
bifrost.logger.Info(fmt.Sprintf("Successfully used fallback provider %s with model %s", fallback.Provider, fallback.Model))
return result, nil
}
bifrost.logger.Warn(fmt.Sprintf("Fallback provider %s failed: %v", fallback.Provider, fallbackErr.Error.Message))
if fallbackErr.Error.Type != nil && *fallbackErr.Error.Type == schemas.RequestCancelled {
return nil, fallbackErr
}

bifrost.logger.Warn(fmt.Sprintf("Fallback provider %s failed: %s", fallback.Provider, fallbackErr.Error.Message))
}
}

Expand Down
61 changes: 34 additions & 27 deletions core/providers/anthropic.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ func NewAnthropicProvider(config *schemas.ProviderConfig, logger schemas.Logger)
for range config.ConcurrencyAndBufferSize.Concurrency {
anthropicTextResponsePool.Put(&AnthropicTextResponse{})
anthropicChatResponsePool.Put(&AnthropicChatResponse{})
bifrostResponsePool.Put(&schemas.BifrostResponse{})

}

// Configure proxy if provided
Expand Down Expand Up @@ -261,36 +261,39 @@ func (provider *AnthropicProvider) TextCompletion(ctx context.Context, model, ke
response := acquireAnthropicTextResponse()
defer releaseAnthropicTextResponse(response)

// Create Bifrost response from pool
bifrostResponse := acquireBifrostResponse()
defer releaseBifrostResponse(bifrostResponse)

rawResponse, bifrostErr := handleProviderResponse(responseBody, response)
if bifrostErr != nil {
return nil, bifrostErr
}

bifrostResponse.ID = response.ID
bifrostResponse.Choices = []schemas.BifrostResponseChoice{
{
Index: 0,
Message: schemas.BifrostMessage{
Role: schemas.ModelChatMessageRoleAssistant,
Content: schemas.MessageContent{
ContentStr: &response.Completion,
// Create final response
bifrostResponse := &schemas.BifrostResponse{
ID: response.ID,
Choices: []schemas.BifrostResponseChoice{
{
Index: 0,
Message: schemas.BifrostMessage{
Role: schemas.ModelChatMessageRoleAssistant,
Content: schemas.MessageContent{
ContentStr: &response.Completion,
},
},
},
},
Usage: schemas.LLMUsage{
PromptTokens: response.Usage.InputTokens,
CompletionTokens: response.Usage.OutputTokens,
TotalTokens: response.Usage.InputTokens + response.Usage.OutputTokens,
},
Model: response.Model,
ExtraFields: schemas.BifrostResponseExtraFields{
Provider: schemas.Anthropic,
RawResponse: rawResponse,
},
}
bifrostResponse.Usage = schemas.LLMUsage{
PromptTokens: response.Usage.InputTokens,
CompletionTokens: response.Usage.OutputTokens,
TotalTokens: response.Usage.InputTokens + response.Usage.OutputTokens,
}
bifrostResponse.Model = response.Model
bifrostResponse.ExtraFields = schemas.BifrostResponseExtraFields{
Provider: schemas.Anthropic,
RawResponse: rawResponse,

if params != nil {
bifrostResponse.ExtraFields.Params = *params
}

return bifrostResponse, nil
Expand All @@ -317,15 +320,13 @@ func (provider *AnthropicProvider) ChatCompletion(ctx context.Context, model, ke
response := acquireAnthropicChatResponse()
defer releaseAnthropicChatResponse(response)

// Create Bifrost response from pool
bifrostResponse := acquireBifrostResponse()
defer releaseBifrostResponse(bifrostResponse)

rawResponse, bifrostErr := handleProviderResponse(responseBody, response)
if bifrostErr != nil {
return nil, bifrostErr
}

// Create final response
bifrostResponse := &schemas.BifrostResponse{}
bifrostResponse, err = parseAnthropicResponse(response, bifrostResponse)
if err != nil {
return nil, err
Expand All @@ -336,6 +337,10 @@ func (provider *AnthropicProvider) ChatCompletion(ctx context.Context, model, ke
RawResponse: rawResponse,
}

if params != nil {
bifrostResponse.ExtraFields.Params = *params
}

return bifrostResponse, nil
}

Expand Down Expand Up @@ -521,7 +526,9 @@ func prepareAnthropicChatRequest(messages []schemas.BifrostMessage, params *sche
// Transform tool choice if present
if params != nil && params.ToolChoice != nil {
if params.ToolChoice.ToolChoiceStr != nil {
preparedParams["tool_choice"] = *params.ToolChoice.ToolChoiceStr
preparedParams["tool_choice"] = map[string]interface{}{
"type": *params.ToolChoice.ToolChoiceStr,
}
} else if params.ToolChoice.ToolChoiceStruct != nil {
switch toolChoice := params.ToolChoice.ToolChoiceStruct.Type; toolChoice {
case schemas.ToolChoiceTypeFunction:
Expand Down
65 changes: 34 additions & 31 deletions core/providers/azure.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ func NewAzureProvider(config *schemas.ProviderConfig, logger schemas.Logger) (*A
for range config.ConcurrencyAndBufferSize.Concurrency {
azureChatResponsePool.Put(&AzureChatResponse{})
azureTextCompletionResponsePool.Put(&AzureTextResponse{})
bifrostResponsePool.Put(&schemas.BifrostResponse{})

}

// Configure proxy if provided
Expand Down Expand Up @@ -256,10 +256,6 @@ func (provider *AzureProvider) TextCompletion(ctx context.Context, model, key, t
response := acquireAzureTextResponse()
defer releaseAzureTextResponse(response)

// Create Bifrost response from pool
bifrostResponse := acquireBifrostResponse()
defer releaseBifrostResponse(bifrostResponse)

rawResponse, bifrostErr := handleProviderResponse(responseBody, response)
if bifrostErr != nil {
return nil, bifrostErr
Expand All @@ -269,15 +265,12 @@ func (provider *AzureProvider) TextCompletion(ctx context.Context, model, key, t

// Create the completion result
if len(response.Choices) > 0 {
// Create a copy of the text to avoid dangling pointer to pooled object
textCopy := response.Choices[0].Text

choices = append(choices, schemas.BifrostResponseChoice{
Index: 0,
Message: schemas.BifrostMessage{
Role: schemas.ModelChatMessageRoleAssistant,
Content: schemas.MessageContent{
ContentStr: &textCopy,
ContentStr: &response.Choices[0].Text,
},
},
FinishReason: response.Choices[0].FinishReason,
Expand All @@ -287,15 +280,22 @@ func (provider *AzureProvider) TextCompletion(ctx context.Context, model, key, t
})
}

bifrostResponse.ID = response.ID
bifrostResponse.Choices = choices
bifrostResponse.Model = response.Model
bifrostResponse.Created = response.Created
bifrostResponse.SystemFingerprint = response.SystemFingerprint
bifrostResponse.Usage = response.Usage
bifrostResponse.ExtraFields = schemas.BifrostResponseExtraFields{
Provider: schemas.Azure,
RawResponse: rawResponse,
// Create final response
bifrostResponse := &schemas.BifrostResponse{
ID: response.ID,
Choices: choices,
Model: response.Model,
Created: response.Created,
SystemFingerprint: response.SystemFingerprint,
Usage: response.Usage,
ExtraFields: schemas.BifrostResponseExtraFields{
Provider: schemas.Azure,
RawResponse: rawResponse,
},
}
Comment thread
Pratham-Mishra04 marked this conversation as resolved.

if params != nil {
bifrostResponse.ExtraFields.Params = *params
}

return bifrostResponse, nil
Expand All @@ -322,24 +322,27 @@ func (provider *AzureProvider) ChatCompletion(ctx context.Context, model, key st
response := acquireAzureChatResponse()
defer releaseAzureChatResponse(response)

// Create Bifrost response from pool
bifrostResponse := acquireBifrostResponse()
defer releaseBifrostResponse(bifrostResponse)

rawResponse, bifrostErr := handleProviderResponse(responseBody, response)
if bifrostErr != nil {
return nil, bifrostErr
}

bifrostResponse.ID = response.ID
bifrostResponse.Choices = response.Choices
bifrostResponse.Model = response.Model
bifrostResponse.Created = response.Created
bifrostResponse.SystemFingerprint = response.SystemFingerprint
bifrostResponse.Usage = response.Usage
bifrostResponse.ExtraFields = schemas.BifrostResponseExtraFields{
Provider: schemas.Azure,
RawResponse: rawResponse,
// Create final response
bifrostResponse := &schemas.BifrostResponse{
ID: response.ID,
Choices: response.Choices,
Model: response.Model,
Created: response.Created,
SystemFingerprint: response.SystemFingerprint,
Usage: response.Usage,
ExtraFields: schemas.BifrostResponseExtraFields{
Provider: schemas.Azure,
RawResponse: rawResponse,
},
}

if params != nil {
bifrostResponse.ExtraFields.Params = *params
}

return bifrostResponse, nil
Expand Down
55 changes: 33 additions & 22 deletions core/providers/bedrock.go
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ func NewBedrockProvider(config *schemas.ProviderConfig, logger schemas.Logger) (
// Pre-warm response pools
for range config.ConcurrencyAndBufferSize.Concurrency {
bedrockChatResponsePool.Put(&BedrockChatResponse{})
bifrostResponsePool.Put(&schemas.BifrostResponse{})

}
Comment thread
coderabbitai[bot] marked this conversation as resolved.

return &BedrockProvider{
Expand Down Expand Up @@ -522,6 +522,12 @@ func (provider *BedrockProvider) prepareChatCompletionMessages(messages []schema
})
} else if msg.Content.ContentBlocks != nil {
for _, block := range *msg.Content.ContentBlocks {
if block.Text != nil {
content = append(content, BedrockAnthropicTextMessage{
Type: "text",
Text: *block.Text,
})
}
Comment thread
coderabbitai[bot] marked this conversation as resolved.
if block.ImageURL != nil {
sanitizedURL, _ := SanitizeImageURL(block.ImageURL.URL)
urlTypeInfo := ExtractURLTypeInfo(sanitizedURL)
Expand Down Expand Up @@ -781,7 +787,7 @@ func (provider *BedrockProvider) TextCompletion(ctx context.Context, model, key,
return nil, err
}

result, err := provider.getTextCompletionResult(body, model)
bifrostResponse, err := provider.getTextCompletionResult(body, model)
if err != nil {
return nil, err
}
Expand All @@ -798,9 +804,13 @@ func (provider *BedrockProvider) TextCompletion(ctx context.Context, model, key,
}
}

result.ExtraFields.RawResponse = rawResponse
bifrostResponse.ExtraFields.RawResponse = rawResponse

if params != nil {
bifrostResponse.ExtraFields.Params = *params
}

return result, nil
return bifrostResponse, nil
}

// extractToolsFromHistory extracts minimal tool definitions from conversation history.
Expand Down Expand Up @@ -907,10 +917,6 @@ func (provider *BedrockProvider) ChatCompletion(ctx context.Context, model, key
response := acquireBedrockChatResponse()
defer releaseBedrockChatResponse(response)

// Create Bifrost response from pool
bifrostResponse := acquireBifrostResponse()
defer releaseBifrostResponse(bifrostResponse)

rawResponse, bifrostErr := handleProviderResponse(responseBody, response)
if bifrostErr != nil {
return nil, bifrostErr
Expand Down Expand Up @@ -939,13 +945,11 @@ func (provider *BedrockProvider) ChatCompletion(ctx context.Context, model, key
arguments = []byte("{}")
}

idCopy := choice.ToolUse.ToolUseID // copy to avoid unsafe pointer creation
nameCopy := choice.ToolUse.Name // copy to avoid unsafe pointer creation
toolCalls = append(toolCalls, schemas.ToolCall{
Type: StrPtr("function"),
ID: &idCopy,
ID: &choice.ToolUse.ToolUseID,
Function: schemas.FunctionCall{
Name: &nameCopy,
Name: &choice.ToolUse.Name,
Arguments: string(arguments),
},
})
Expand Down Expand Up @@ -979,17 +983,24 @@ func (provider *BedrockProvider) ChatCompletion(ctx context.Context, model, key

latency := float64(response.Metrics.Latency)

bifrostResponse.Choices = choices
bifrostResponse.Usage = schemas.LLMUsage{
PromptTokens: response.Usage.InputTokens,
CompletionTokens: response.Usage.OutputTokens,
TotalTokens: response.Usage.TotalTokens,
// Create final response
bifrostResponse := &schemas.BifrostResponse{
Choices: choices,
Usage: schemas.LLMUsage{
PromptTokens: response.Usage.InputTokens,
CompletionTokens: response.Usage.OutputTokens,
TotalTokens: response.Usage.TotalTokens,
},
Model: model,
ExtraFields: schemas.BifrostResponseExtraFields{
Latency: &latency,
Provider: schemas.Bedrock,
RawResponse: rawResponse,
},
Comment thread
coderabbitai[bot] marked this conversation as resolved.
}
Comment thread
coderabbitai[bot] marked this conversation as resolved.
bifrostResponse.Model = model
bifrostResponse.ExtraFields = schemas.BifrostResponseExtraFields{
Latency: &latency,
Provider: schemas.Bedrock,
RawResponse: rawResponse,

if params != nil {
bifrostResponse.ExtraFields.Params = *params
}

return bifrostResponse, nil
Expand Down
Loading