Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 57 additions & 34 deletions core/providers/anthropic.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ package providers

import (
"fmt"
"strings"
"sync"
"time"

Expand Down Expand Up @@ -279,46 +280,59 @@ func (provider *AnthropicProvider) TextCompletion(model, key, text string, param
// It formats the request, sends it to Anthropic, and processes the response.
// Returns a BifrostResponse containing the completion results or an error if the request fails.
func (provider *AnthropicProvider) ChatCompletion(model, key string, messages []schemas.Message, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) {
// Add system messages if present
var systemMessages []BedrockAnthropicSystemMessage
for _, msg := range messages {
if msg.Role == schemas.RoleSystem {
//TODO handling image inputs here
systemMessages = append(systemMessages, BedrockAnthropicSystemMessage{
Text: *msg.Content,
})
}
}

// Format messages for Anthropic API
var formattedMessages []map[string]interface{}
for _, msg := range messages {
if msg.ImageContent != nil {
var content []map[string]interface{}

imageContent := map[string]interface{}{
"type": "image",
"source": map[string]interface{}{
"type": msg.ImageContent.Type,
},
}

// Handle different image source types
if *msg.ImageContent.Type == "url" {
imageContent["source"].(map[string]interface{})["url"] = msg.ImageContent.URL
if msg.Role != schemas.RoleSystem {
if msg.ImageContent != nil {
var content []map[string]interface{}

imageContent := map[string]interface{}{
"type": "image",
"source": map[string]interface{}{
"type": msg.ImageContent.Type,
},
}

// Handle different image source types
if *msg.ImageContent.Type == "url" {
imageContent["source"].(map[string]interface{})["url"] = msg.ImageContent.URL
} else {
imageContent["source"].(map[string]interface{})["media_type"] = msg.ImageContent.MediaType
imageContent["source"].(map[string]interface{})["data"] = msg.ImageContent.URL
}

content = append(content, imageContent)

// Add text content if present
if msg.Content != nil {
content = append(content, map[string]interface{}{
"type": "text",
"text": msg.Content,
})
}

formattedMessages = append(formattedMessages, map[string]interface{}{
"role": msg.Role,
"content": content,
})
} else {
imageContent["source"].(map[string]interface{})["media_type"] = msg.ImageContent.MediaType
imageContent["source"].(map[string]interface{})["data"] = msg.ImageContent.URL
}

content = append(content, imageContent)

// Add text content if present
if msg.Content != nil {
content = append(content, map[string]interface{}{
"type": "text",
"text": msg.Content,
formattedMessages = append(formattedMessages, map[string]interface{}{
"role": msg.Role,
"content": msg.Content,
})
}

formattedMessages = append(formattedMessages, map[string]interface{}{
"role": msg.Role,
"content": content,
})
} else {
formattedMessages = append(formattedMessages, map[string]interface{}{
"role": msg.Role,
"content": msg.Content,
})
}
}

Expand All @@ -344,6 +358,15 @@ func (provider *AnthropicProvider) ChatCompletion(model, key string, messages []
"messages": formattedMessages,
}, preparedParams)

if len(systemMessages) > 0 {
var messages []string
for _, message := range systemMessages {
messages = append(messages, message.Text)
}

requestBody["system"] = strings.Join(messages, " ")
}

responseBody, err := provider.completeRequest(requestBody, "https://api.anthropic.com/v1/messages", key)
if err != nil {
return nil, err
Expand Down
8 changes: 6 additions & 2 deletions core/providers/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -199,12 +199,16 @@ func (provider *OpenAIProvider) ChatCompletion(model, key string, messages []sch

bifrostErr := handleProviderAPIError(resp, &errorResp)

bifrostErr.EventID = &errorResp.EventID
if errorResp.EventID != "" {
bifrostErr.EventID = &errorResp.EventID
}
bifrostErr.Error.Type = &errorResp.Error.Type
bifrostErr.Error.Code = &errorResp.Error.Code
bifrostErr.Error.Message = errorResp.Error.Message
bifrostErr.Error.Param = errorResp.Error.Param
bifrostErr.Error.EventID = &errorResp.Error.EventID
if errorResp.Error.EventID != "" {
bifrostErr.Error.EventID = &errorResp.Error.EventID
}

return nil, bifrostErr
}
Expand Down
23 changes: 12 additions & 11 deletions core/schemas/bifrost.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,14 +73,14 @@ type Fallback struct {
type ModelParameters struct {
ToolChoice *ToolChoice `json:"tool_choice,omitempty"`
Tools *[]Tool `json:"tools,omitempty"`
Temperature *float64 `json:"temperature,omitempty"` // Controls randomness in the output
TopP *float64 `json:"top_p,omitempty"` // Controls diversity via nucleus sampling
TopK *int `json:"top_k,omitempty"` // Controls diversity via top-k sampling
MaxTokens *int `json:"max_tokens,omitempty"` // Maximum number of tokens to generate
StopSequences *[]string `json:"stop_sequences,omitempty"` // Sequences that stop generation
PresencePenalty *float64 `json:"presence_penalty,omitempty"` // Penalizes repeated tokens
FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` // Penalizes frequent tokens
ParallelToolCalls *bool `json:"parallel_tool_calls"` // Enables parallel tool calls
Temperature *float64 `json:"temperature,omitempty"` // Controls randomness in the output
TopP *float64 `json:"top_p,omitempty"` // Controls diversity via nucleus sampling
TopK *int `json:"top_k,omitempty"` // Controls diversity via top-k sampling
MaxTokens *int `json:"max_tokens,omitempty"` // Maximum number of tokens to generate
StopSequences *[]string `json:"stop_sequences,omitempty"` // Sequences that stop generation
PresencePenalty *float64 `json:"presence_penalty,omitempty"` // Penalizes repeated tokens
FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` // Penalizes frequent tokens
ParallelToolCalls *bool `json:"parallel_tool_calls,omitempty"` // Enables parallel tool calls

// Dynamic parameters that can be provider-specific, they are directly
// added to the request as is.
Expand All @@ -89,10 +89,11 @@ type ModelParameters struct {

// FunctionParameters represents the parameters for a function definition.
type FunctionParameters struct {
Type string `json:"type,"` // Type of the parameters
Type string `json:"type"` // Type of the parameters
Description *string `json:"description,omitempty"` // Description of the parameters
Required []string `json:"required"` // Required parameter names
Properties map[string]interface{} `json:"properties"` // Parameter properties
Required []string `json:"required,omitempty"` // Required parameter names
Properties map[string]interface{} `json:"properties,omitempty"` // Parameter properties
Enum *[]string `json:"enum,omitempty"` // Enum values for the parameters
}

// Function represents a function that can be called by the model.
Expand Down
4 changes: 2 additions & 2 deletions core/tests/anthropic_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ func TestAnthropic(t *testing.T) {
config := TestConfig{
Provider: schemas.Anthropic,
TextModel: "claude-2.1",
ChatModel: "claude-3-5-sonnet-20240620",
ChatModel: "claude-3-7-sonnet-20250219",
SetupText: true,
SetupToolCalls: false, // available in 3.7 sonnet
SetupToolCalls: true, // available in 3.7 sonnet only
SetupImage: true,
SetupBaseImage: true,
CustomParams: &schemas.ModelParameters{
Expand Down
2 changes: 1 addition & 1 deletion core/tests/openai_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ func TestOpenAI(t *testing.T) {
Provider: schemas.OpenAI,
TextModel: "gpt-4o-mini",
ChatModel: "gpt-4o-mini",
SetupText: true, // OpenAI does not support text completion
SetupText: false, // OpenAI does not support text completion
SetupToolCalls: false,
SetupImage: false,
SetupBaseImage: false,
Expand Down
34 changes: 20 additions & 14 deletions core/tests/tests.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ package tests

import (
"context"
"fmt"
"log"
"time"

bifrost "github.com/maximhq/bifrost/core"
Expand Down Expand Up @@ -103,9 +103,9 @@ func setupTextCompletionRequest(bifrost *bifrost.Bifrost, config TestConfig, ctx
Fallbacks: config.Fallbacks,
}, ctx)
if err != nil {
fmt.Printf("\nError in %s text completion: %v\n", config.Provider, err.Error.Message)
log.Println("Error in", config.Provider, "text completion:", err.Error.Message)
} else {
fmt.Printf("\n🐒 %s Text Completion Result: %s\n", config.Provider, *result.Choices[0].Message.Content)
log.Println("🐒", config.Provider, "Text Completion Result:", *result.Choices[0].Message.Content)
}
}()
}
Expand Down Expand Up @@ -147,9 +147,9 @@ func setupChatCompletionRequests(bifrost *bifrost.Bifrost, config TestConfig, ct
Fallbacks: config.Fallbacks,
}, ctx)
if err != nil {
fmt.Printf("\nError in %s request %d: %v\n", config.Provider, index+1, err.Error.Message)
log.Println("Error in", config.Provider, "request", index+1, ":", err.Error.Message)
} else {
fmt.Printf("\n🐒 %s Chat Completion Result %d: %s\n", config.Provider, index+1, *result.Choices[0].Message.Content)
log.Println("🐒", config.Provider, "Chat Completion Result", index+1, ":", *result.Choices[0].Message.Content)
}
}(message, delay, i)
}
Expand Down Expand Up @@ -194,9 +194,9 @@ func setupImageTests(bifrost *bifrost.Bifrost, config TestConfig, ctx context.Co
Fallbacks: config.Fallbacks,
}, ctx)
if err != nil {
fmt.Printf("\nError in %s URL image request: %v\n", config.Provider, err.Error.Message)
log.Println("Error in", config.Provider, "URL image request:", err.Error.Message)
} else {
fmt.Printf("\n🐒 %s URL Image Result: %s\n", config.Provider, *result.Choices[0].Message.Content)
log.Println("🐒", config.Provider, "URL Image Result:", *result.Choices[0].Message.Content)
}
}()

Expand Down Expand Up @@ -224,9 +224,9 @@ func setupImageTests(bifrost *bifrost.Bifrost, config TestConfig, ctx context.Co
Fallbacks: config.Fallbacks,
}, ctx)
if err != nil {
fmt.Printf("\nError in %s base64 image request: %v\n", config.Provider, err.Error.Message)
log.Println("Error in", config.Provider, "base64 image request:", err.Error.Message)
} else {
fmt.Printf("\n🐒 %s Base64 Image Result: %s\n", config.Provider, *result.Choices[0].Message.Content)
log.Println("🐒", config.Provider, "Base64 Image Result:", *result.Choices[0].Message.Content)
}
}()
}
Expand Down Expand Up @@ -272,15 +272,21 @@ func setupToolCalls(bifrost *bifrost.Bifrost, config TestConfig, ctx context.Con
Fallbacks: config.Fallbacks,
}, ctx)
if err != nil {
fmt.Printf("\nError in %s tool call request %d: %v\n", config.Provider, index+1, err.Error.Message)
log.Println("Error in", config.Provider, "tool call request", index+1, ":", err.Error.Message)
} else {
if result.Choices[0].Message.ToolCalls != nil && len(*result.Choices[0].Message.ToolCalls) > 0 {
toolCall := *result.Choices[0].Message.ToolCalls
fmt.Printf("\n🐒 %s Tool Call Result %d: %s\n", config.Provider, index+1, toolCall[0].Function.Arguments)
for i, choice := range result.Choices {
if choice.Message.ToolCalls != nil && len(*choice.Message.ToolCalls) > 0 {
toolCall := *choice.Message.ToolCalls
log.Println("🐒", config.Provider, "Tool Call Result", index+1, "(Choice", i+1, "):", toolCall[0].Function.Arguments)
} else {
log.Println("🐒", config.Provider, "No tool calls in response", index+1, "(Choice", i+1, ")")
}
}
} else {
fmt.Printf("\n🐒 %s No tool calls in response %d\n", config.Provider, index+1)
log.Println("🐒", config.Provider, "No tool calls in response", index+1)
if result.ExtraFields.RawResponse != nil {
fmt.Println("\nRaw JSON Response", result.ExtraFields.RawResponse)
log.Println("Raw JSON Response", result.ExtraFields.RawResponse)
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion transports/http/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -434,7 +434,7 @@ func main() {
Handler: r.Handler,
}

fmt.Printf("Starting HTTP server on port %s\n", port)
log.Println("Starting HTTP server on port", port)
if err := server.ListenAndServe(fmt.Sprintf(":%s", port)); err != nil {
log.Fatalf("failed to start server: %v", err)
}
Expand Down