diff --git a/core/providers/anthropic.go b/core/providers/anthropic.go index 881c0aceda..f9c4fb601b 100644 --- a/core/providers/anthropic.go +++ b/core/providers/anthropic.go @@ -4,6 +4,7 @@ package providers import ( "fmt" + "strings" "sync" "time" @@ -279,46 +280,59 @@ func (provider *AnthropicProvider) TextCompletion(model, key, text string, param // It formats the request, sends it to Anthropic, and processes the response. // Returns a BifrostResponse containing the completion results or an error if the request fails. func (provider *AnthropicProvider) ChatCompletion(model, key string, messages []schemas.Message, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { + // Add system messages if present + var systemMessages []BedrockAnthropicSystemMessage + for _, msg := range messages { + if msg.Role == schemas.RoleSystem { + //TODO handling image inputs here + systemMessages = append(systemMessages, BedrockAnthropicSystemMessage{ + Text: *msg.Content, + }) + } + } + // Format messages for Anthropic API var formattedMessages []map[string]interface{} for _, msg := range messages { - if msg.ImageContent != nil { - var content []map[string]interface{} - - imageContent := map[string]interface{}{ - "type": "image", - "source": map[string]interface{}{ - "type": msg.ImageContent.Type, - }, - } - - // Handle different image source types - if *msg.ImageContent.Type == "url" { - imageContent["source"].(map[string]interface{})["url"] = msg.ImageContent.URL + if msg.Role != schemas.RoleSystem { + if msg.ImageContent != nil { + var content []map[string]interface{} + + imageContent := map[string]interface{}{ + "type": "image", + "source": map[string]interface{}{ + "type": msg.ImageContent.Type, + }, + } + + // Handle different image source types + if *msg.ImageContent.Type == "url" { + imageContent["source"].(map[string]interface{})["url"] = msg.ImageContent.URL + } else { + imageContent["source"].(map[string]interface{})["media_type"] = msg.ImageContent.MediaType + imageContent["source"].(map[string]interface{})["data"] = msg.ImageContent.URL + } + + content = append(content, imageContent) + + // Add text content if present + if msg.Content != nil { + content = append(content, map[string]interface{}{ + "type": "text", + "text": msg.Content, + }) + } + + formattedMessages = append(formattedMessages, map[string]interface{}{ + "role": msg.Role, + "content": content, + }) } else { - imageContent["source"].(map[string]interface{})["media_type"] = msg.ImageContent.MediaType - imageContent["source"].(map[string]interface{})["data"] = msg.ImageContent.URL - } - - content = append(content, imageContent) - - // Add text content if present - if msg.Content != nil { - content = append(content, map[string]interface{}{ - "type": "text", - "text": msg.Content, + formattedMessages = append(formattedMessages, map[string]interface{}{ + "role": msg.Role, + "content": msg.Content, }) } - - formattedMessages = append(formattedMessages, map[string]interface{}{ - "role": msg.Role, - "content": content, - }) - } else { - formattedMessages = append(formattedMessages, map[string]interface{}{ - "role": msg.Role, - "content": msg.Content, - }) } } @@ -344,6 +358,15 @@ func (provider *AnthropicProvider) ChatCompletion(model, key string, messages [] "messages": formattedMessages, }, preparedParams) + if len(systemMessages) > 0 { + var messages []string + for _, message := range systemMessages { + messages = append(messages, message.Text) + } + + requestBody["system"] = strings.Join(messages, " ") + } + responseBody, err := provider.completeRequest(requestBody, "https://api.anthropic.com/v1/messages", key) if err != nil { return nil, err diff --git a/core/providers/openai.go b/core/providers/openai.go index ff96a69b77..c4be162bcb 100644 --- a/core/providers/openai.go +++ b/core/providers/openai.go @@ -199,12 +199,16 @@ func (provider *OpenAIProvider) ChatCompletion(model, key string, messages []sch bifrostErr := handleProviderAPIError(resp, &errorResp) - bifrostErr.EventID = &errorResp.EventID + if errorResp.EventID != "" { + bifrostErr.EventID = &errorResp.EventID + } bifrostErr.Error.Type = &errorResp.Error.Type bifrostErr.Error.Code = &errorResp.Error.Code bifrostErr.Error.Message = errorResp.Error.Message bifrostErr.Error.Param = errorResp.Error.Param - bifrostErr.Error.EventID = &errorResp.Error.EventID + if errorResp.Error.EventID != "" { + bifrostErr.Error.EventID = &errorResp.Error.EventID + } return nil, bifrostErr } diff --git a/core/schemas/bifrost.go b/core/schemas/bifrost.go index 4e3f06041e..5b95b5d796 100644 --- a/core/schemas/bifrost.go +++ b/core/schemas/bifrost.go @@ -73,14 +73,14 @@ type Fallback struct { type ModelParameters struct { ToolChoice *ToolChoice `json:"tool_choice,omitempty"` Tools *[]Tool `json:"tools,omitempty"` - Temperature *float64 `json:"temperature,omitempty"` // Controls randomness in the output - TopP *float64 `json:"top_p,omitempty"` // Controls diversity via nucleus sampling - TopK *int `json:"top_k,omitempty"` // Controls diversity via top-k sampling - MaxTokens *int `json:"max_tokens,omitempty"` // Maximum number of tokens to generate - StopSequences *[]string `json:"stop_sequences,omitempty"` // Sequences that stop generation - PresencePenalty *float64 `json:"presence_penalty,omitempty"` // Penalizes repeated tokens - FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` // Penalizes frequent tokens - ParallelToolCalls *bool `json:"parallel_tool_calls"` // Enables parallel tool calls + Temperature *float64 `json:"temperature,omitempty"` // Controls randomness in the output + TopP *float64 `json:"top_p,omitempty"` // Controls diversity via nucleus sampling + TopK *int `json:"top_k,omitempty"` // Controls diversity via top-k sampling + MaxTokens *int `json:"max_tokens,omitempty"` // Maximum number of tokens to generate + StopSequences *[]string `json:"stop_sequences,omitempty"` // Sequences that stop generation + PresencePenalty *float64 `json:"presence_penalty,omitempty"` // Penalizes repeated tokens + FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` // Penalizes frequent tokens + ParallelToolCalls *bool `json:"parallel_tool_calls,omitempty"` // Enables parallel tool calls // Dynamic parameters that can be provider-specific, they are directly // added to the request as is. @@ -89,10 +89,11 @@ type ModelParameters struct { // FunctionParameters represents the parameters for a function definition. type FunctionParameters struct { - Type string `json:"type,"` // Type of the parameters + Type string `json:"type"` // Type of the parameters Description *string `json:"description,omitempty"` // Description of the parameters - Required []string `json:"required"` // Required parameter names - Properties map[string]interface{} `json:"properties"` // Parameter properties + Required []string `json:"required,omitempty"` // Required parameter names + Properties map[string]interface{} `json:"properties,omitempty"` // Parameter properties + Enum *[]string `json:"enum,omitempty"` // Enum values for the parameters } // Function represents a function that can be called by the model. diff --git a/core/tests/anthropic_test.go b/core/tests/anthropic_test.go index 5df5170b45..9522f76ba7 100644 --- a/core/tests/anthropic_test.go +++ b/core/tests/anthropic_test.go @@ -21,9 +21,9 @@ func TestAnthropic(t *testing.T) { config := TestConfig{ Provider: schemas.Anthropic, TextModel: "claude-2.1", - ChatModel: "claude-3-5-sonnet-20240620", + ChatModel: "claude-3-7-sonnet-20250219", SetupText: true, - SetupToolCalls: false, // available in 3.7 sonnet + SetupToolCalls: true, // available in 3.7 sonnet only SetupImage: true, SetupBaseImage: true, CustomParams: &schemas.ModelParameters{ diff --git a/core/tests/openai_test.go b/core/tests/openai_test.go index cae22a1b79..6cf5e67086 100644 --- a/core/tests/openai_test.go +++ b/core/tests/openai_test.go @@ -20,7 +20,7 @@ func TestOpenAI(t *testing.T) { Provider: schemas.OpenAI, TextModel: "gpt-4o-mini", ChatModel: "gpt-4o-mini", - SetupText: true, // OpenAI does not support text completion + SetupText: false, // OpenAI does not support text completion SetupToolCalls: false, SetupImage: false, SetupBaseImage: false, diff --git a/core/tests/tests.go b/core/tests/tests.go index f8a70b8d53..69b62c7f97 100644 --- a/core/tests/tests.go +++ b/core/tests/tests.go @@ -5,7 +5,7 @@ package tests import ( "context" - "fmt" + "log" "time" bifrost "github.com/maximhq/bifrost/core" @@ -103,9 +103,9 @@ func setupTextCompletionRequest(bifrost *bifrost.Bifrost, config TestConfig, ctx Fallbacks: config.Fallbacks, }, ctx) if err != nil { - fmt.Printf("\nError in %s text completion: %v\n", config.Provider, err.Error.Message) + log.Println("Error in", config.Provider, "text completion:", err.Error.Message) } else { - fmt.Printf("\nšŸ’ %s Text Completion Result: %s\n", config.Provider, *result.Choices[0].Message.Content) + log.Println("šŸ’", config.Provider, "Text Completion Result:", *result.Choices[0].Message.Content) } }() } @@ -147,9 +147,9 @@ func setupChatCompletionRequests(bifrost *bifrost.Bifrost, config TestConfig, ct Fallbacks: config.Fallbacks, }, ctx) if err != nil { - fmt.Printf("\nError in %s request %d: %v\n", config.Provider, index+1, err.Error.Message) + log.Println("Error in", config.Provider, "request", index+1, ":", err.Error.Message) } else { - fmt.Printf("\nšŸ’ %s Chat Completion Result %d: %s\n", config.Provider, index+1, *result.Choices[0].Message.Content) + log.Println("šŸ’", config.Provider, "Chat Completion Result", index+1, ":", *result.Choices[0].Message.Content) } }(message, delay, i) } @@ -194,9 +194,9 @@ func setupImageTests(bifrost *bifrost.Bifrost, config TestConfig, ctx context.Co Fallbacks: config.Fallbacks, }, ctx) if err != nil { - fmt.Printf("\nError in %s URL image request: %v\n", config.Provider, err.Error.Message) + log.Println("Error in", config.Provider, "URL image request:", err.Error.Message) } else { - fmt.Printf("\nšŸ’ %s URL Image Result: %s\n", config.Provider, *result.Choices[0].Message.Content) + log.Println("šŸ’", config.Provider, "URL Image Result:", *result.Choices[0].Message.Content) } }() @@ -224,9 +224,9 @@ func setupImageTests(bifrost *bifrost.Bifrost, config TestConfig, ctx context.Co Fallbacks: config.Fallbacks, }, ctx) if err != nil { - fmt.Printf("\nError in %s base64 image request: %v\n", config.Provider, err.Error.Message) + log.Println("Error in", config.Provider, "base64 image request:", err.Error.Message) } else { - fmt.Printf("\nšŸ’ %s Base64 Image Result: %s\n", config.Provider, *result.Choices[0].Message.Content) + log.Println("šŸ’", config.Provider, "Base64 Image Result:", *result.Choices[0].Message.Content) } }() } @@ -272,15 +272,21 @@ func setupToolCalls(bifrost *bifrost.Bifrost, config TestConfig, ctx context.Con Fallbacks: config.Fallbacks, }, ctx) if err != nil { - fmt.Printf("\nError in %s tool call request %d: %v\n", config.Provider, index+1, err.Error.Message) + log.Println("Error in", config.Provider, "tool call request", index+1, ":", err.Error.Message) } else { if result.Choices[0].Message.ToolCalls != nil && len(*result.Choices[0].Message.ToolCalls) > 0 { - toolCall := *result.Choices[0].Message.ToolCalls - fmt.Printf("\nšŸ’ %s Tool Call Result %d: %s\n", config.Provider, index+1, toolCall[0].Function.Arguments) + for i, choice := range result.Choices { + if choice.Message.ToolCalls != nil && len(*choice.Message.ToolCalls) > 0 { + toolCall := *choice.Message.ToolCalls + log.Println("šŸ’", config.Provider, "Tool Call Result", index+1, "(Choice", i+1, "):", toolCall[0].Function.Arguments) + } else { + log.Println("šŸ’", config.Provider, "No tool calls in response", index+1, "(Choice", i+1, ")") + } + } } else { - fmt.Printf("\nšŸ’ %s No tool calls in response %d\n", config.Provider, index+1) + log.Println("šŸ’", config.Provider, "No tool calls in response", index+1) if result.ExtraFields.RawResponse != nil { - fmt.Println("\nRaw JSON Response", result.ExtraFields.RawResponse) + log.Println("Raw JSON Response", result.ExtraFields.RawResponse) } } } diff --git a/transports/http/main.go b/transports/http/main.go index 8af6fb3171..aa3e4b710a 100644 --- a/transports/http/main.go +++ b/transports/http/main.go @@ -434,7 +434,7 @@ func main() { Handler: r.Handler, } - fmt.Printf("Starting HTTP server on port %s\n", port) + log.Println("Starting HTTP server on port", port) if err := server.ListenAndServe(fmt.Sprintf(":%s", port)); err != nil { log.Fatalf("failed to start server: %v", err) }