Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions core/bifrost.go
Original file line number Diff line number Diff line change
Expand Up @@ -215,9 +215,9 @@ func (bifrost *Bifrost) releaseChannelMessage(msg *ChannelMessage) {
bifrost.channelMessagePool.Put(msg)
}

// SelectKeyFromProviderForModel selects an appropriate API key for a given provider and model.
// selectKeyFromProviderForModel selects an appropriate API key for a given provider and model.
// It uses weighted random selection if multiple keys are available.
func (bifrost *Bifrost) SelectKeyFromProviderForModel(providerKey schemas.ModelProvider, model string) (string, error) {
func (bifrost *Bifrost) selectKeyFromProviderForModel(providerKey schemas.ModelProvider, model string) (string, error) {
keys, err := bifrost.account.GetKeysForProvider(providerKey)
if err != nil {
return "", err
Expand Down Expand Up @@ -298,7 +298,7 @@ func (bifrost *Bifrost) requestWorker(provider schemas.Provider, queue chan Chan

key := ""
if provider.GetProviderKey() != schemas.Vertex {
key, err = bifrost.SelectKeyFromProviderForModel(provider.GetProviderKey(), req.Model)
key, err = bifrost.selectKeyFromProviderForModel(provider.GetProviderKey(), req.Model)
if err != nil {
bifrost.logger.Warn(fmt.Sprintf("Error selecting key for model %s: %v", req.Model, err))
req.Err <- schemas.BifrostError{
Expand Down Expand Up @@ -411,10 +411,10 @@ func (bifrost *Bifrost) GetConfiguredProviderFromProviderKey(key schemas.ModelPr
return nil, fmt.Errorf("no provider found for key: %s", key)
}

// GetProviderQueue returns the request queue for a given provider key.
// getProviderQueue returns the request queue for a given provider key.
// If the queue doesn't exist, it creates one at runtime and initializes the provider,
// given the provider config is provided in the account interface implementation.
func (bifrost *Bifrost) GetProviderQueue(providerKey schemas.ModelProvider) (chan ChannelMessage, error) {
func (bifrost *Bifrost) getProviderQueue(providerKey schemas.ModelProvider) (chan ChannelMessage, error) {
var queue chan ChannelMessage
var exists bool

Expand Down Expand Up @@ -512,7 +512,7 @@ func (bifrost *Bifrost) TextCompletionRequest(ctx context.Context, req *schemas.
// tryTextCompletion attempts a text completion request with a single provider.
// This is a helper function used by TextCompletionRequest to handle individual provider attempts.
func (bifrost *Bifrost) tryTextCompletion(req *schemas.BifrostRequest, ctx context.Context) (*schemas.BifrostResponse, *schemas.BifrostError) {
queue, err := bifrost.GetProviderQueue(req.Provider)
queue, err := bifrost.getProviderQueue(req.Provider)
if err != nil {
return nil, &schemas.BifrostError{
IsBifrostError: false,
Expand Down Expand Up @@ -686,7 +686,7 @@ func (bifrost *Bifrost) ChatCompletionRequest(ctx context.Context, req *schemas.
// tryChatCompletion attempts a chat completion request with a single provider.
// This is a helper function used by ChatCompletionRequest to handle individual provider attempts.
func (bifrost *Bifrost) tryChatCompletion(req *schemas.BifrostRequest, ctx context.Context) (*schemas.BifrostResponse, *schemas.BifrostError) {
queue, err := bifrost.GetProviderQueue(req.Provider)
queue, err := bifrost.getProviderQueue(req.Provider)
if err != nil {
return nil, &schemas.BifrostError{
IsBifrostError: false,
Expand Down
20 changes: 13 additions & 7 deletions core/providers/anthropic.go
Original file line number Diff line number Diff line change
Expand Up @@ -319,21 +319,23 @@ func (provider *AnthropicProvider) ChatCompletion(ctx context.Context, model, ke

// buildAnthropicImageSourceMap creates the "source" map for an Anthropic image content part.
func buildAnthropicImageSourceMap(imgContent *schemas.ImageContent) map[string]interface{} {
if imgContent == nil || imgContent.Type == nil {
if imgContent == nil {
return nil
}

formattedImgContent := *FormatImageContent(imgContent, false)

sourceMap := map[string]interface{}{
"type": *imgContent.Type, // "base64" or "url"
"type": string(formattedImgContent.Type), // "base64" or "url"
}

if *imgContent.Type == "url" {
sourceMap["url"] = imgContent.URL
if formattedImgContent.Type == schemas.ImageContentTypeURL {
sourceMap["url"] = formattedImgContent.URL
} else {
if imgContent.MediaType != nil {
sourceMap["media_type"] = *imgContent.MediaType
if formattedImgContent.MediaType != nil {
sourceMap["media_type"] = *formattedImgContent.MediaType
}
sourceMap["data"] = imgContent.URL // URL field is used for base64 data string
sourceMap["data"] = formattedImgContent.URL // URL field contains base64 data string
}
return sourceMap
}
Expand Down Expand Up @@ -375,8 +377,10 @@ func prepareAnthropicChatRequest(messages []schemas.BifrostMessage, params *sche
if (msg.UserMessage != nil && msg.UserMessage.ImageContent != nil) || (msg.ToolMessage != nil && msg.ToolMessage.ImageContent != nil) {
var messageImageContent schemas.ImageContent
if msg.UserMessage != nil && msg.UserMessage.ImageContent != nil {
// Create a copy to avoid modifying the original
messageImageContent = *msg.UserMessage.ImageContent
} else if msg.ToolMessage != nil && msg.ToolMessage.ImageContent != nil {
// Create a copy to avoid modifying the original
messageImageContent = *msg.ToolMessage.ImageContent
}

Expand All @@ -396,8 +400,10 @@ func prepareAnthropicChatRequest(messages []schemas.BifrostMessage, params *sche
if (msg.UserMessage != nil && msg.UserMessage.ImageContent != nil) || (msg.ToolMessage != nil && msg.ToolMessage.ImageContent != nil) {
var messageImageContent schemas.ImageContent
if msg.UserMessage != nil && msg.UserMessage.ImageContent != nil {
// Create a copy to avoid modifying the original
messageImageContent = *msg.UserMessage.ImageContent
} else if msg.ToolMessage != nil && msg.ToolMessage.ImageContent != nil {
// Create a copy to avoid modifying the original
messageImageContent = *msg.ToolMessage.ImageContent
}

Expand Down
9 changes: 6 additions & 3 deletions core/providers/bedrock.go
Original file line number Diff line number Diff line change
Expand Up @@ -494,19 +494,22 @@ func (provider *BedrockProvider) prepareChatCompletionMessages(messages []schema
messageImageContent = *msg.ToolMessage.ImageContent
}

formattedImgContent := *FormatImageContent(&messageImageContent, false)

content = append(content, BedrockAnthropicImageMessage{
Type: "image",
Image: BedrockAnthropicImage{
Format: func() string {
if messageImageContent.MediaType != nil {
mediaType := *messageImageContent.MediaType
if formattedImgContent.MediaType != nil {
mediaType := *formattedImgContent.MediaType
// Remove "image/" prefix if present, since normalizeMediaType ensures full format
mediaType = strings.TrimPrefix(mediaType, "image/")
return mediaType
}
return ""
}(),
Source: BedrockAnthropicImageSource{
Bytes: messageImageContent.URL,
Bytes: formattedImgContent.URL,
},
},
})
Expand Down
74 changes: 68 additions & 6 deletions core/providers/cohere.go
Original file line number Diff line number Diff line change
Expand Up @@ -235,9 +235,31 @@ func (provider *CohereProvider) ChatCompletion(ctx context.Context, model, key s
historyMsg["tool_results"] = toolResults
}

// Only add message content if it's not nil
if msg.Content != nil {
historyMsg["message"] = *msg.Content
// Handle message content based on whether it supports vision
if msg.UserMessage != nil && msg.UserMessage.ImageContent != nil {
// Create content array with text and image
contentArray := []map[string]interface{}{}

// Add text content if present
if msg.Content != nil {
contentArray = append(contentArray, map[string]interface{}{
"type": "text",
"text": *msg.Content,
})
}

// Add image content using our helper function
// NOTE: Cohere v1 does not support image content
// if processedImageContent := processImageContent(msg.UserMessage.ImageContent); processedImageContent != nil {
// contentArray = append(contentArray, processedImageContent)
// }

historyMsg["content"] = contentArray
} else {
// For non-vision models or text-only messages, use simple message field
if msg.Content != nil {
historyMsg["message"] = *msg.Content
}
}
Comment thread
akshaydeo marked this conversation as resolved.

cohereHistory = append(cohereHistory, historyMsg)
Expand All @@ -251,9 +273,31 @@ func (provider *CohereProvider) ChatCompletion(ctx context.Context, model, key s
"model": model,
}, preparedParams)

// Only add last message content if it's not nil
if lastMessage.Content != nil {
requestBody["message"] = *lastMessage.Content
// Handle the last message content based on whether it supports vision
if lastMessage.UserMessage != nil && lastMessage.UserMessage.ImageContent != nil {
// Create content array with text and image
contentArray := []map[string]interface{}{}

// Add text content if present
if lastMessage.Content != nil {
contentArray = append(contentArray, map[string]interface{}{
"type": "text",
"text": *lastMessage.Content,
})
}

// Add image content using our helper function
// NOTE: Cohere v1 does not support image content
// if processedImageContent := processImageContent(lastMessage.UserMessage.ImageContent); processedImageContent != nil {
// contentArray = append(contentArray, processedImageContent)
// }

requestBody["content"] = contentArray
} else {
// For non-vision models or text-only messages, use simple message field
if lastMessage.Content != nil {
requestBody["message"] = *lastMessage.Content
}
}

// Add tools if present
Expand Down Expand Up @@ -418,6 +462,24 @@ func (provider *CohereProvider) ChatCompletion(ctx context.Context, model, key s
return bifrostResponse, nil
}

// processImageContent processes image content for Cohere API format.
// It creates a copy of the image content, normalizes and formats it, then returns the properly formatted map.
// This prevents unintended mutations to the original image content.
func processImageContent(imageContent *schemas.ImageContent) map[string]interface{} {
if imageContent == nil {
return nil
}

formattedImgContent := *FormatImageContent(imageContent, true)

return map[string]interface{}{
"type": "image_url",
"image_url": map[string]interface{}{
"url": formattedImgContent.URL,
},
}
}

// convertChatHistory converts Cohere's chat history format to Bifrost's format for standardization.
// It transforms the chat history messages and their tool calls.
func convertChatHistory(history []struct {
Expand Down
8 changes: 5 additions & 3 deletions core/providers/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,8 @@ func prepareOpenAIChatRequest(messages []schemas.BifrostMessage, params *schemas
messageImageContent = *msg.ToolMessage.ImageContent
}

formattedImgContent := *FormatImageContent(&messageImageContent, true)

var content []map[string]interface{}

// Add text content if present
Expand All @@ -239,12 +241,12 @@ func prepareOpenAIChatRequest(messages []schemas.BifrostMessage, params *schemas
imageContent := map[string]interface{}{
"type": "image_url",
"image_url": map[string]interface{}{
"url": messageImageContent.URL,
"url": formattedImgContent.URL,
},
}

if messageImageContent.Detail != nil {
imageContent["image_url"].(map[string]interface{})["detail"] = messageImageContent.Detail
if formattedImgContent.Detail != nil {
imageContent["image_url"].(map[string]interface{})["detail"] = formattedImgContent.Detail
}

content = append(content, imageContent)
Expand Down
Loading