Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion core/bifrost.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"fmt"
"math/rand"
"slices"
"strings"
"sync"
"time"

Expand Down Expand Up @@ -229,7 +230,7 @@ func (bifrost *Bifrost) SelectKeyFromProviderForModel(providerKey schemas.ModelP
// filter out keys which dont support the model
var supportedKeys []schemas.Key
for _, key := range keys {
if slices.Contains(key.Models, model) {
if slices.Contains(key.Models, model) && strings.TrimSpace(key.Value) != "" {
supportedKeys = append(supportedKeys, key)
}
}
Expand Down
152 changes: 113 additions & 39 deletions core/providers/anthropic.go
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ func (provider *AnthropicProvider) TextCompletion(ctx context.Context, model, ke
// It formats the request, sends it to Anthropic, and processes the response.
// Returns a BifrostResponse containing the completion results or an error if the request fails.
func (provider *AnthropicProvider) ChatCompletion(ctx context.Context, model, key string, messages []schemas.BifrostMessage, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) {
formattedMessages, preparedParams := prepareAnthropicChatRequest(model, messages, params)
formattedMessages, preparedParams := prepareAnthropicChatRequest(messages, params)

// Merge additional parameters
requestBody := mergeConfig(map[string]interface{}{
Expand Down Expand Up @@ -317,12 +317,32 @@ func (provider *AnthropicProvider) ChatCompletion(ctx context.Context, model, ke
return bifrostResponse, nil
}

func prepareAnthropicChatRequest(model string, messages []schemas.BifrostMessage, params *schemas.ModelParameters) ([]map[string]interface{}, map[string]interface{}) {
// buildAnthropicImageSourceMap creates the "source" map for an Anthropic image content part.
func buildAnthropicImageSourceMap(imgContent *schemas.ImageContent) map[string]interface{} {
if imgContent == nil || imgContent.Type == nil {
return nil
}

sourceMap := map[string]interface{}{
"type": *imgContent.Type, // "base64" or "url"
}

if *imgContent.Type == "url" {
sourceMap["url"] = imgContent.URL
} else {
if imgContent.MediaType != nil {
sourceMap["media_type"] = *imgContent.MediaType
}
sourceMap["data"] = imgContent.URL // URL field is used for base64 data string
}
return sourceMap
}

func prepareAnthropicChatRequest(messages []schemas.BifrostMessage, params *schemas.ModelParameters) ([]map[string]interface{}, map[string]interface{}) {
// Add system messages if present
var systemMessages []BedrockAnthropicSystemMessage
for _, msg := range messages {
if msg.Role == schemas.ModelChatMessageRoleSystem {
//TODO handling image inputs here
if msg.Content != nil {
systemMessages = append(systemMessages, BedrockAnthropicSystemMessage{
Text: *msg.Content,
Expand All @@ -335,57 +355,64 @@ func prepareAnthropicChatRequest(model string, messages []schemas.BifrostMessage
var formattedMessages []map[string]interface{}
for _, msg := range messages {
if msg.Role != schemas.ModelChatMessageRoleSystem {
if (msg.UserMessage != nil && msg.UserMessage.ImageContent != nil) || (msg.ToolMessage != nil && msg.ToolMessage.ImageContent != nil) {
var messageImageContent schemas.ImageContent
if msg.UserMessage != nil && msg.UserMessage.ImageContent != nil {
messageImageContent = *msg.UserMessage.ImageContent
} else if msg.ToolMessage != nil && msg.ToolMessage.ImageContent != nil {
messageImageContent = *msg.ToolMessage.ImageContent
}

var content []map[string]interface{}

imageContent := map[string]interface{}{
"type": "image",
"source": map[string]interface{}{
"type": messageImageContent.Type,
},
}

// Handle different image source types
if messageImageContent.Type != nil && *messageImageContent.Type == "url" {
imageContent["source"].(map[string]interface{})["url"] = messageImageContent.URL
} else {
imageContent["source"].(map[string]interface{})["media_type"] = messageImageContent.MediaType
imageContent["source"].(map[string]interface{})["data"] = messageImageContent.URL
if msg.Role == schemas.ModelChatMessageRoleTool && msg.ToolCallID != nil {
toolCallResult := map[string]interface{}{
"type": "tool_result",
"tool_use_id": *msg.ToolCallID,
}

content = append(content, imageContent)
var toolCallResultContent []map[string]interface{}

// Add text content if present
if msg.Content != nil {
content = append(content, map[string]interface{}{
toolCallResultContent = append(toolCallResultContent, map[string]interface{}{
"type": "text",
"text": *msg.Content,
})
}

// Add thinking content if present in AssistantMessage
if msg.AssistantMessage != nil && msg.AssistantMessage.Thought != nil {
content = append(content, map[string]interface{}{
"type": "thinking",
"thinking": *msg.AssistantMessage.Thought,
})
if msg.UserMessage.ImageContent != nil || msg.ToolMessage.ImageContent != nil {
var messageImageContent schemas.ImageContent
if msg.UserMessage.ImageContent != nil {
messageImageContent = *msg.UserMessage.ImageContent
} else if msg.ToolMessage.ImageContent != nil {
messageImageContent = *msg.ToolMessage.ImageContent
}

imageSource := buildAnthropicImageSourceMap(&messageImageContent)
if imageSource != nil {
toolCallResultContent = append(toolCallResultContent, map[string]interface{}{
"type": "image",
"source": imageSource,
})
}
}

toolCallResult["content"] = toolCallResultContent

formattedMessages = append(formattedMessages, map[string]interface{}{
"role": msg.Role,
"content": content,
"role": schemas.ModelChatMessageRoleTool,
"content": toolCallResult,
})
} else {
// Handle non-image messages
var content []map[string]interface{}

if (msg.UserMessage != nil && msg.UserMessage.ImageContent != nil) || (msg.ToolMessage != nil && msg.ToolMessage.ImageContent != nil) {
var messageImageContent schemas.ImageContent
if msg.UserMessage != nil && msg.UserMessage.ImageContent != nil {
messageImageContent = *msg.UserMessage.ImageContent
} else if msg.ToolMessage != nil && msg.ToolMessage.ImageContent != nil {
messageImageContent = *msg.ToolMessage.ImageContent
}

imageSource := buildAnthropicImageSourceMap(&messageImageContent)
if imageSource != nil {
content = append(content, map[string]interface{}{
"type": "image",
"source": imageSource,
})
}
}

// Add text content if present
if msg.Content != nil && *msg.Content != "" {
content = append(content, map[string]interface{}{
Expand Down Expand Up @@ -429,7 +456,6 @@ func prepareAnthropicChatRequest(model string, messages []schemas.BifrostMessage
}
}

// Always use content block structure
if len(content) > 0 {
formattedMessages = append(formattedMessages, map[string]interface{}{
"role": msg.Role,
Expand Down Expand Up @@ -465,6 +491,54 @@ func prepareAnthropicChatRequest(model string, messages []schemas.BifrostMessage
preparedParams["system"] = strings.Join(messages, " ")
}

// Post-process formattedMessages for tool call results
processedFormattedMessages := []map[string]interface{}{} // Use a new slice
i := 0
for i < len(formattedMessages) {
currentMsg := formattedMessages[i]
currentRole, roleOk := getRoleFromMessage(currentMsg)

if !roleOk {
// If role is of an unexpected type or missing, treat as non-tool message
processedFormattedMessages = append(processedFormattedMessages, currentMsg)
i++
continue
}

if currentRole == schemas.ModelChatMessageRoleTool {
// Content of a tool message is the toolCallResult map
// Initialize accumulatedToolResults with the content of the current tool message.
accumulatedToolResults := []interface{}{currentMsg["content"]}

// Look ahead for more sequential tool messages
j := i + 1
for j < len(formattedMessages) {
nextMsg := formattedMessages[j]
nextRole, nextRoleOk := getRoleFromMessage(nextMsg)

if !nextRoleOk || nextRole != schemas.ModelChatMessageRoleTool {
break // Not a sequential tool message or role is invalid/missing
}

accumulatedToolResults = append(accumulatedToolResults, nextMsg["content"])
j++
}

// Create a new message with role User and accumulated content
mergedMsg := map[string]interface{}{
"role": schemas.ModelChatMessageRoleUser, // Final role is User
"content": accumulatedToolResults,
}
processedFormattedMessages = append(processedFormattedMessages, mergedMsg)
i = j // Advance main loop index past all merged messages
} else {
// Not a tool message, add it as is
processedFormattedMessages = append(processedFormattedMessages, currentMsg)
i++
}
}
formattedMessages = processedFormattedMessages // Update with processed messages

return formattedMessages, preparedParams
}

Expand Down
136 changes: 107 additions & 29 deletions core/providers/bedrock.go
Original file line number Diff line number Diff line change
Expand Up @@ -415,7 +415,6 @@ func (provider *BedrockProvider) prepareChatCompletionMessages(messages []schema
var systemMessages []BedrockAnthropicSystemMessage
for _, msg := range messages {
if msg.Role == schemas.ModelChatMessageRoleSystem {
//TODO handling image inputs here
if msg.Content != nil {
systemMessages = append(systemMessages, BedrockAnthropicSystemMessage{
Text: *msg.Content,
Expand All @@ -428,42 +427,121 @@ func (provider *BedrockProvider) prepareChatCompletionMessages(messages []schema
var bedrockMessages []map[string]interface{}
for _, msg := range messages {
if msg.Role != schemas.ModelChatMessageRoleSystem {
var content any
if msg.Content != nil {
content = BedrockAnthropicTextMessage{
Type: "text",
Text: *msg.Content,
}
} else if (msg.UserMessage != nil && msg.UserMessage.ImageContent != nil) || (msg.ToolMessage != nil && msg.ToolMessage.ImageContent != nil) {
var messageImageContent schemas.ImageContent
if msg.UserMessage != nil && msg.UserMessage.ImageContent != nil {
messageImageContent = *msg.UserMessage.ImageContent
} else if msg.ToolMessage != nil && msg.ToolMessage.ImageContent != nil {
messageImageContent = *msg.ToolMessage.ImageContent
if msg.Role == schemas.ModelChatMessageRoleTool && msg.ToolCallID != nil {
toolCallResult := map[string]interface{}{
"toolUseId": *msg.ToolCallID,
}

content = BedrockAnthropicImageMessage{
Type: "image",
Image: BedrockAnthropicImage{
Format: func() string {
if messageImageContent.Type != nil {
return *messageImageContent.Type
}
return ""
}(),
Source: BedrockAnthropicImageSource{
Bytes: messageImageContent.URL,
var toolResultContentBlock map[string]interface{}
if msg.Content != nil {
toolResultContentBlock = map[string]interface{}{}
var parsedJSON interface{}
err := json.Unmarshal([]byte(*msg.Content), &parsedJSON)
if err == nil {
if arr, ok := parsedJSON.([]interface{}); ok {
toolResultContentBlock["json"] = map[string]interface{}{"content": arr}
} else {
toolResultContentBlock["json"] = parsedJSON
}
} else {
toolResultContentBlock["text"] = *msg.Content
}

toolCallResult["content"] = []interface{}{toolResultContentBlock}

bedrockMessages = append(bedrockMessages, map[string]interface{}{
"role": schemas.ModelChatMessageRoleTool,
"content": map[string]interface{}{
"toolResult": toolCallResult,
},
})
}
} else {
var content any
if msg.Content != nil {
content = BedrockAnthropicTextMessage{
Type: "text",
Text: *msg.Content,
}
} else if (msg.UserMessage != nil && msg.UserMessage.ImageContent != nil) || (msg.ToolMessage != nil && msg.ToolMessage.ImageContent != nil) {
var messageImageContent schemas.ImageContent
if msg.UserMessage != nil && msg.UserMessage.ImageContent != nil {
messageImageContent = *msg.UserMessage.ImageContent
} else if msg.ToolMessage != nil && msg.ToolMessage.ImageContent != nil {
messageImageContent = *msg.ToolMessage.ImageContent
}

content = BedrockAnthropicImageMessage{
Type: "image",
Image: BedrockAnthropicImage{
Format: func() string {
if messageImageContent.Type != nil {
return *messageImageContent.Type
}
return ""
}(),
Source: BedrockAnthropicImageSource{
Bytes: messageImageContent.URL,
},
},
},
}
}

bedrockMessages = append(bedrockMessages, map[string]interface{}{
"role": msg.Role,
"content": []interface{}{content},
})
}
}
}

bedrockMessages = append(bedrockMessages, map[string]interface{}{
"role": msg.Role,
"content": []interface{}{content},
})
// Post-process bedrockMessages for tool call results
processedBedrockMessages := []map[string]interface{}{}
i := 0
for i < len(bedrockMessages) {
currentMsg := bedrockMessages[i]
currentRole, roleOk := getRoleFromMessage(currentMsg)

if !roleOk {
// If role is of an unexpected type or missing, treat as non-tool message
processedBedrockMessages = append(processedBedrockMessages, currentMsg)
i++
continue
}

if currentRole == schemas.ModelChatMessageRoleTool {
// Content of a tool message is the toolCallResult map
// Initialize accumulatedToolResults with the content of the current tool message.
accumulatedToolResults := []interface{}{currentMsg["content"]}

// Look ahead for more sequential tool messages
j := i + 1
for j < len(bedrockMessages) {
nextMsg := bedrockMessages[j]
nextRole, nextRoleOk := getRoleFromMessage(nextMsg)

if !nextRoleOk || nextRole != schemas.ModelChatMessageRoleTool {
break // Not a sequential tool message or role is invalid/missing
}

accumulatedToolResults = append(accumulatedToolResults, nextMsg["content"])
j++
}

// Create a new message with role User and accumulated content
mergedMsg := map[string]interface{}{
"role": schemas.ModelChatMessageRoleUser, // Final role is User
"content": accumulatedToolResults,
}
processedBedrockMessages = append(processedBedrockMessages, mergedMsg)
i = j // Advance main loop index past all merged messages
} else {
// Not a tool message, add it as is
processedBedrockMessages = append(processedBedrockMessages, currentMsg)
i++
}
}
bedrockMessages = processedBedrockMessages // Update with processed messages

body := map[string]interface{}{
"messages": bedrockMessages,
Expand Down
Loading