diff --git a/core/bifrost.go b/core/bifrost.go index 8f8f58b7e8..e0d0e73198 100644 --- a/core/bifrost.go +++ b/core/bifrost.go @@ -8,6 +8,7 @@ import ( "fmt" "math/rand" "slices" + "strings" "sync" "time" @@ -229,7 +230,7 @@ func (bifrost *Bifrost) SelectKeyFromProviderForModel(providerKey schemas.ModelP // filter out keys which dont support the model var supportedKeys []schemas.Key for _, key := range keys { - if slices.Contains(key.Models, model) { + if slices.Contains(key.Models, model) && strings.TrimSpace(key.Value) != "" { supportedKeys = append(supportedKeys, key) } } diff --git a/core/providers/anthropic.go b/core/providers/anthropic.go index 48e81a0b2a..d39b531728 100644 --- a/core/providers/anthropic.go +++ b/core/providers/anthropic.go @@ -278,7 +278,7 @@ func (provider *AnthropicProvider) TextCompletion(ctx context.Context, model, ke // It formats the request, sends it to Anthropic, and processes the response. // Returns a BifrostResponse containing the completion results or an error if the request fails. func (provider *AnthropicProvider) ChatCompletion(ctx context.Context, model, key string, messages []schemas.BifrostMessage, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { - formattedMessages, preparedParams := prepareAnthropicChatRequest(model, messages, params) + formattedMessages, preparedParams := prepareAnthropicChatRequest(messages, params) // Merge additional parameters requestBody := mergeConfig(map[string]interface{}{ @@ -317,12 +317,32 @@ func (provider *AnthropicProvider) ChatCompletion(ctx context.Context, model, ke return bifrostResponse, nil } -func prepareAnthropicChatRequest(model string, messages []schemas.BifrostMessage, params *schemas.ModelParameters) ([]map[string]interface{}, map[string]interface{}) { +// buildAnthropicImageSourceMap creates the "source" map for an Anthropic image content part. 
+func buildAnthropicImageSourceMap(imgContent *schemas.ImageContent) map[string]interface{} { + if imgContent == nil || imgContent.Type == nil { + return nil + } + + kind := *imgContent.Type // "base64" or "url" + source := map[string]interface{}{"type": kind} + // A "url" image is described by its address only. + if kind == "url" { + source["url"] = imgContent.URL + return source + } + // Any other type carries its payload inline: the URL field holds the base64 data string. + if imgContent.MediaType != nil { + source["media_type"] = *imgContent.MediaType + } + source["data"] = imgContent.URL + return source +} + +func prepareAnthropicChatRequest(messages []schemas.BifrostMessage, params *schemas.ModelParameters) ([]map[string]interface{}, map[string]interface{}) { // Add system messages if present var systemMessages []BedrockAnthropicSystemMessage for _, msg := range messages { if msg.Role == schemas.ModelChatMessageRoleSystem { - //TODO handling image inputs here if msg.Content != nil { systemMessages = append(systemMessages, BedrockAnthropicSystemMessage{ Text: *msg.Content, @@ -335,57 +355,64 @@ func prepareAnthropicChatRequest(model string, messages []schemas.BifrostMessage var formattedMessages []map[string]interface{} for _, msg := range messages { if msg.Role != schemas.ModelChatMessageRoleSystem { - if (msg.UserMessage != nil && msg.UserMessage.ImageContent != nil) || (msg.ToolMessage != nil && msg.ToolMessage.ImageContent != nil) { - var messageImageContent schemas.ImageContent - if msg.UserMessage != nil && msg.UserMessage.ImageContent != nil { - messageImageContent = *msg.UserMessage.ImageContent - } else if msg.ToolMessage != nil && msg.ToolMessage.ImageContent != nil { - messageImageContent = *msg.ToolMessage.ImageContent - } - - var content []map[string]interface{} - - imageContent := map[string]interface{}{ - "type": "image", - "source": map[string]interface{}{ - "type": messageImageContent.Type, - }, - } - - // Handle different image source types - if messageImageContent.Type != nil && *messageImageContent.Type == "url" { -
imageContent["source"].(map[string]interface{})["url"] = messageImageContent.URL - } else { - imageContent["source"].(map[string]interface{})["media_type"] = messageImageContent.MediaType - imageContent["source"].(map[string]interface{})["data"] = messageImageContent.URL + if msg.Role == schemas.ModelChatMessageRoleTool && msg.ToolCallID != nil { // Tool results are emitted as a "tool_result" block whose tool_use_id is the message's ToolCallID. + toolCallResult := map[string]interface{}{ + "type": "tool_result", + "tool_use_id": *msg.ToolCallID, } - content = append(content, imageContent) + var toolCallResultContent []map[string]interface{} - // Add text content if present if msg.Content != nil { - content = append(content, map[string]interface{}{ + toolCallResultContent = append(toolCallResultContent, map[string]interface{}{ "type": "text", "text": *msg.Content, }) } - // Add thinking content if present in AssistantMessage - if msg.AssistantMessage != nil && msg.AssistantMessage.Thought != nil { - content = append(content, map[string]interface{}{ - "type": "thinking", - "thinking": *msg.AssistantMessage.Thought, - }) + if (msg.UserMessage != nil && msg.UserMessage.ImageContent != nil) || (msg.ToolMessage != nil && msg.ToolMessage.ImageContent != nil) { // Nil-check the embedded message pointers before dereferencing them, exactly as the non-tool branch does. + var messageImageContent schemas.ImageContent + if msg.UserMessage != nil && msg.UserMessage.ImageContent != nil { + messageImageContent = *msg.UserMessage.ImageContent + } else if msg.ToolMessage != nil && msg.ToolMessage.ImageContent != nil { + messageImageContent = *msg.ToolMessage.ImageContent + } + + imageSource := buildAnthropicImageSourceMap(&messageImageContent) + if imageSource != nil { + toolCallResultContent = append(toolCallResultContent, map[string]interface{}{ + "type": "image", + "source": imageSource, + }) + } } + toolCallResult["content"] = toolCallResultContent + formattedMessages = append(formattedMessages, map[string]interface{}{ - "role": msg.Role, - "content": content, + "role": schemas.ModelChatMessageRoleTool, // Provisional role: the post-processing pass later folds consecutive tool messages into a single user message. + "content": toolCallResult, }) } else { - // Handle non-image messages var content []map[string]interface{} + if (msg.UserMessage != nil && msg.UserMessage.ImageContent != nil) ||
(msg.ToolMessage != nil && msg.ToolMessage.ImageContent != nil) { + var messageImageContent schemas.ImageContent + if msg.UserMessage != nil && msg.UserMessage.ImageContent != nil { + messageImageContent = *msg.UserMessage.ImageContent + } else if msg.ToolMessage != nil && msg.ToolMessage.ImageContent != nil { + messageImageContent = *msg.ToolMessage.ImageContent + } + + imageSource := buildAnthropicImageSourceMap(&messageImageContent) + if imageSource != nil { + content = append(content, map[string]interface{}{ + "type": "image", + "source": imageSource, + }) + } + } + // Add text content if present if msg.Content != nil && *msg.Content != "" { content = append(content, map[string]interface{}{ @@ -429,7 +456,6 @@ func prepareAnthropicChatRequest(model string, messages []schemas.BifrostMessage } } - // Always use content block structure if len(content) > 0 { formattedMessages = append(formattedMessages, map[string]interface{}{ "role": msg.Role, @@ -465,6 +491,54 @@ func prepareAnthropicChatRequest(model string, messages []schemas.BifrostMessage preparedParams["system"] = strings.Join(messages, " ") } + // Post-process formattedMessages for tool call results + processedFormattedMessages := []map[string]interface{}{} // Use a new slice + i := 0 + for i < len(formattedMessages) { + currentMsg := formattedMessages[i] + currentRole, roleOk := getRoleFromMessage(currentMsg) + + if !roleOk { + // If role is of an unexpected type or missing, treat as non-tool message + processedFormattedMessages = append(processedFormattedMessages, currentMsg) + i++ + continue + } + + if currentRole == schemas.ModelChatMessageRoleTool { + // Content of a tool message is the toolCallResult map + // Initialize accumulatedToolResults with the content of the current tool message. 
+ accumulatedToolResults := []interface{}{currentMsg["content"]} + + // Look ahead for more sequential tool messages + j := i + 1 + for j < len(formattedMessages) { + nextMsg := formattedMessages[j] + nextRole, nextRoleOk := getRoleFromMessage(nextMsg) + + if !nextRoleOk || nextRole != schemas.ModelChatMessageRoleTool { + break // Not a sequential tool message or role is invalid/missing + } + + accumulatedToolResults = append(accumulatedToolResults, nextMsg["content"]) + j++ + } + + // Create a new message with role User and accumulated content + mergedMsg := map[string]interface{}{ + "role": schemas.ModelChatMessageRoleUser, // Final role is User + "content": accumulatedToolResults, + } + processedFormattedMessages = append(processedFormattedMessages, mergedMsg) + i = j // Advance main loop index past all merged messages + } else { + // Not a tool message, add it as is + processedFormattedMessages = append(processedFormattedMessages, currentMsg) + i++ + } + } + formattedMessages = processedFormattedMessages // Update with processed messages + return formattedMessages, preparedParams } diff --git a/core/providers/bedrock.go b/core/providers/bedrock.go index 88e4a7fae2..1eb0827e01 100644 --- a/core/providers/bedrock.go +++ b/core/providers/bedrock.go @@ -415,7 +415,6 @@ func (provider *BedrockProvider) prepareChatCompletionMessages(messages []schema var systemMessages []BedrockAnthropicSystemMessage for _, msg := range messages { if msg.Role == schemas.ModelChatMessageRoleSystem { - //TODO handling image inputs here if msg.Content != nil { systemMessages = append(systemMessages, BedrockAnthropicSystemMessage{ Text: *msg.Content, @@ -428,42 +427,121 @@ func (provider *BedrockProvider) prepareChatCompletionMessages(messages []schema var bedrockMessages []map[string]interface{} for _, msg := range messages { if msg.Role != schemas.ModelChatMessageRoleSystem { - var content any - if msg.Content != nil { - content = BedrockAnthropicTextMessage{ - Type: "text", - Text: 
*msg.Content, - } - } else if (msg.UserMessage != nil && msg.UserMessage.ImageContent != nil) || (msg.ToolMessage != nil && msg.ToolMessage.ImageContent != nil) { - var messageImageContent schemas.ImageContent - if msg.UserMessage != nil && msg.UserMessage.ImageContent != nil { - messageImageContent = *msg.UserMessage.ImageContent - } else if msg.ToolMessage != nil && msg.ToolMessage.ImageContent != nil { - messageImageContent = *msg.ToolMessage.ImageContent + if msg.Role == schemas.ModelChatMessageRoleTool && msg.ToolCallID != nil { + toolCallResult := map[string]interface{}{ + "toolUseId": *msg.ToolCallID, } - content = BedrockAnthropicImageMessage{ - Type: "image", - Image: BedrockAnthropicImage{ - Format: func() string { - if messageImageContent.Type != nil { - return *messageImageContent.Type - } - return "" - }(), - Source: BedrockAnthropicImageSource{ - Bytes: messageImageContent.URL, + var toolResultContentBlock map[string]interface{} + if msg.Content != nil { + toolResultContentBlock = map[string]interface{}{} + var parsedJSON interface{} + err := json.Unmarshal([]byte(*msg.Content), &parsedJSON) + if err == nil { + if arr, ok := parsedJSON.([]interface{}); ok { + toolResultContentBlock["json"] = map[string]interface{}{"content": arr} + } else { + toolResultContentBlock["json"] = parsedJSON + } + } else { + toolResultContentBlock["text"] = *msg.Content + } + + toolCallResult["content"] = []interface{}{toolResultContentBlock} + + bedrockMessages = append(bedrockMessages, map[string]interface{}{ + "role": schemas.ModelChatMessageRoleTool, + "content": map[string]interface{}{ + "toolResult": toolCallResult, + }, + }) + } + } else { + var content any + if msg.Content != nil { + content = BedrockAnthropicTextMessage{ + Type: "text", + Text: *msg.Content, + } + } else if (msg.UserMessage != nil && msg.UserMessage.ImageContent != nil) || (msg.ToolMessage != nil && msg.ToolMessage.ImageContent != nil) { + var messageImageContent schemas.ImageContent + if 
msg.UserMessage != nil && msg.UserMessage.ImageContent != nil { + messageImageContent = *msg.UserMessage.ImageContent + } else if msg.ToolMessage != nil && msg.ToolMessage.ImageContent != nil { + messageImageContent = *msg.ToolMessage.ImageContent + } + + content = BedrockAnthropicImageMessage{ + Type: "image", + Image: BedrockAnthropicImage{ + Format: func() string { + if messageImageContent.Type != nil { + return *messageImageContent.Type + } + return "" + }(), + Source: BedrockAnthropicImageSource{ + Bytes: messageImageContent.URL, + }, }, - }, + } } + + bedrockMessages = append(bedrockMessages, map[string]interface{}{ + "role": msg.Role, + "content": []interface{}{content}, + }) } + } + } - bedrockMessages = append(bedrockMessages, map[string]interface{}{ - "role": msg.Role, - "content": []interface{}{content}, - }) + // Post-process bedrockMessages for tool call results + processedBedrockMessages := []map[string]interface{}{} + i := 0 + for i < len(bedrockMessages) { + currentMsg := bedrockMessages[i] + currentRole, roleOk := getRoleFromMessage(currentMsg) + + if !roleOk { + // If role is of an unexpected type or missing, treat as non-tool message + processedBedrockMessages = append(processedBedrockMessages, currentMsg) + i++ + continue + } + + if currentRole == schemas.ModelChatMessageRoleTool { + // Content of a tool message is the toolCallResult map + // Initialize accumulatedToolResults with the content of the current tool message. 
+ accumulatedToolResults := []interface{}{currentMsg["content"]} + + // Look ahead for more sequential tool messages + j := i + 1 + for j < len(bedrockMessages) { + nextMsg := bedrockMessages[j] + nextRole, nextRoleOk := getRoleFromMessage(nextMsg) + + if !nextRoleOk || nextRole != schemas.ModelChatMessageRoleTool { + break // Not a sequential tool message or role is invalid/missing + } + + accumulatedToolResults = append(accumulatedToolResults, nextMsg["content"]) + j++ + } + + // Create a new message with role User and accumulated content + mergedMsg := map[string]interface{}{ + "role": schemas.ModelChatMessageRoleUser, // Final role is User + "content": accumulatedToolResults, + } + processedBedrockMessages = append(processedBedrockMessages, mergedMsg) + i = j // Advance main loop index past all merged messages + } else { + // Not a tool message, add it as is + processedBedrockMessages = append(processedBedrockMessages, currentMsg) + i++ } } + bedrockMessages = processedBedrockMessages // Update with processed messages body := map[string]interface{}{ "messages": bedrockMessages, diff --git a/core/providers/utils.go b/core/providers/utils.go index 7fa930216b..8f97b7cb40 100644 --- a/core/providers/utils.go +++ b/core/providers/utils.go @@ -268,6 +268,28 @@ func handleProviderResponse[T any](responseBody []byte, response *T) (interface{ return rawResponse, nil } +// getRoleFromMessage extracts and validates the role from a message map. 
+func getRoleFromMessage(msg map[string]interface{}) (schemas.ModelChatMessageRole, bool) { + rawRole, present := msg["role"] + if !present { + // The map carries no "role" entry at all. + return "", false + } + + // A value that is already the typed role is returned unchanged. + if typedRole, isTyped := rawRole.(schemas.ModelChatMessageRole); isTyped { + return typedRole, true + } + + // A plain string is converted to the typed role. + if textRole, isText := rawRole.(string); isText { + return schemas.ModelChatMessageRole(textRole), true + } + + // Any other dynamic type (including a nil value) is rejected. + return "", false +} + // float64Ptr creates a pointer to a float64 value. // This is a helper function for creating pointers to float64 values. func float64Ptr(f float64) *float64 { diff --git a/core/providers/vertex.go b/core/providers/vertex.go index b47d81c066..0334ff9a46 100644 --- a/core/providers/vertex.go +++ b/core/providers/vertex.go @@ -95,7 +95,7 @@ func (provider *VertexProvider) ChatCompletion(ctx context.Context, model, key s var preparedParams map[string]interface{} if strings.Contains(model, "claude") { - formattedMessages, preparedParams = prepareAnthropicChatRequest(model, messages, params) + formattedMessages, preparedParams = prepareAnthropicChatRequest(messages, params) } else { formattedMessages, preparedParams = prepareOpenAIChatRequest(model, messages, params) } diff --git a/core/tests/openai_test.go b/core/tests/openai_test.go index f748e08ca6..7dc7fa8c7c 100644 --- a/core/tests/openai_test.go +++ b/core/tests/openai_test.go @@ -21,7 +21,7 @@ func TestOpenAI(t *testing.T) { TextModel: "gpt-4o-mini", ChatModel: "gpt-4o-mini", SetupText: false, // OpenAI does not support text completion - SetupToolCalls: false, + SetupToolCalls: true, SetupImage: false, SetupBaseImage: false, Fallbacks: []schemas.Fallback{ diff --git a/core/tests/tests.go b/core/tests/tests.go index 3bbe01c40f..19ca319267 100644 --- a/core/tests/tests.go +++ b/core/tests/tests.go @@ -149,6 +149,7 @@ func
setupChatCompletionRequests(bifrostClient *bifrost.Bifrost, config TestConf Content: &msg, }, } + result, err := bifrostClient.ChatCompletionRequest(ctx, &schemas.BifrostRequest{ Provider: config.Provider, Model: config.ChatModel, diff --git a/core/tests/vertex_test.go b/core/tests/vertex_test.go index c580284442..89265aae57 100644 --- a/core/tests/vertex_test.go +++ b/core/tests/vertex_test.go @@ -20,7 +20,7 @@ func TestVertex(t *testing.T) { Provider: schemas.Vertex, ChatModel: "google/gemini-2.0-flash-001", SetupText: false, // Vertex does not support text completion - SetupToolCalls: false, + SetupToolCalls: true, SetupImage: false, SetupBaseImage: false, }