From 43c207efdfd885edde1cde8ebfcba9bd375ef853 Mon Sep 17 00:00:00 2001 From: Pratham Mishra <99235987+Pratham-Mishra04@users.noreply.github.com> Date: Mon, 9 Jun 2025 01:54:12 +0530 Subject: [PATCH] feat: image input normalisation added across all providers --- core/bifrost.go | 14 ++-- core/providers/anthropic.go | 20 ++++-- core/providers/bedrock.go | 9 ++- core/providers/cohere.go | 74 ++++++++++++++++++-- core/providers/openai.go | 8 ++- core/providers/utils.go | 133 ++++++++++++++++++++++++++++++++++++ core/schemas/bifrost.go | 15 ++-- core/tests/account.go | 2 +- core/tests/azure_test.go | 2 +- core/tests/openai_test.go | 4 +- core/tests/tests.go | 4 +- core/tests/vertex_test.go | 4 +- 12 files changed, 251 insertions(+), 38 deletions(-) diff --git a/core/bifrost.go b/core/bifrost.go index e0d0e73198..f77e6cf07e 100644 --- a/core/bifrost.go +++ b/core/bifrost.go @@ -215,9 +215,9 @@ func (bifrost *Bifrost) releaseChannelMessage(msg *ChannelMessage) { bifrost.channelMessagePool.Put(msg) } -// SelectKeyFromProviderForModel selects an appropriate API key for a given provider and model. +// selectKeyFromProviderForModel selects an appropriate API key for a given provider and model. // It uses weighted random selection if multiple keys are available. -func (bifrost *Bifrost) SelectKeyFromProviderForModel(providerKey schemas.ModelProvider, model string) (string, error) { +func (bifrost *Bifrost) selectKeyFromProviderForModel(providerKey schemas.ModelProvider, model string) (string, error) { keys, err := bifrost.account.GetKeysForProvider(providerKey) if err != nil { return "", err @@ -298,7 +298,7 @@ func (bifrost *Bifrost) requestWorker(provider schemas.Provider, queue chan Chan key := "" if provider.GetProviderKey() != schemas.Vertex { - key, err = bifrost.SelectKeyFromProviderForModel(provider.GetProviderKey(), req.Model) + key, err = bifrost.selectKeyFromProviderForModel(provider.GetProviderKey(), req.Model) if err != nil { bifrost.logger.Warn(fmt.Sprintf("Error selecting key for model %s: %v", req.Model, err)) req.Err <- schemas.BifrostError{ @@ -411,10 +411,10 @@ func (bifrost *Bifrost) GetConfiguredProviderFromProviderKey(key schemas.ModelPr return nil, fmt.Errorf("no provider found for key: %s", key) } -// GetProviderQueue returns the request queue for a given provider key. +// getProviderQueue returns the request queue for a given provider key. // If the queue doesn't exist, it creates one at runtime and initializes the provider, // given the provider config is provided in the account interface implementation. -func (bifrost *Bifrost) GetProviderQueue(providerKey schemas.ModelProvider) (chan ChannelMessage, error) { +func (bifrost *Bifrost) getProviderQueue(providerKey schemas.ModelProvider) (chan ChannelMessage, error) { var queue chan ChannelMessage var exists bool @@ -512,7 +512,7 @@ func (bifrost *Bifrost) TextCompletionRequest(ctx context.Context, req *schemas. // tryTextCompletion attempts a text completion request with a single provider. // This is a helper function used by TextCompletionRequest to handle individual provider attempts. func (bifrost *Bifrost) tryTextCompletion(req *schemas.BifrostRequest, ctx context.Context) (*schemas.BifrostResponse, *schemas.BifrostError) { - queue, err := bifrost.GetProviderQueue(req.Provider) + queue, err := bifrost.getProviderQueue(req.Provider) if err != nil { return nil, &schemas.BifrostError{ IsBifrostError: false, @@ -686,7 +686,7 @@ func (bifrost *Bifrost) ChatCompletionRequest(ctx context.Context, req *schemas. // tryChatCompletion attempts a chat completion request with a single provider. // This is a helper function used by ChatCompletionRequest to handle individual provider attempts. func (bifrost *Bifrost) tryChatCompletion(req *schemas.BifrostRequest, ctx context.Context) (*schemas.BifrostResponse, *schemas.BifrostError) { - queue, err := bifrost.GetProviderQueue(req.Provider) + queue, err := bifrost.getProviderQueue(req.Provider) if err != nil { return nil, &schemas.BifrostError{ IsBifrostError: false, diff --git a/core/providers/anthropic.go b/core/providers/anthropic.go index 9084255efa..f78b8fcbbd 100644 --- a/core/providers/anthropic.go +++ b/core/providers/anthropic.go @@ -319,21 +319,23 @@ func (provider *AnthropicProvider) ChatCompletion(ctx context.Context, model, ke // buildAnthropicImageSourceMap creates the "source" map for an Anthropic image content part. func buildAnthropicImageSourceMap(imgContent *schemas.ImageContent) map[string]interface{} { - if imgContent == nil || imgContent.Type == nil { + if imgContent == nil { return nil } + formattedImgContent := *FormatImageContent(imgContent, false) + sourceMap := map[string]interface{}{ - "type": *imgContent.Type, // "base64" or "url" + "type": string(formattedImgContent.Type), // "base64" or "url" } - if *imgContent.Type == "url" { - sourceMap["url"] = imgContent.URL + if formattedImgContent.Type == schemas.ImageContentTypeURL { + sourceMap["url"] = formattedImgContent.URL } else { - if imgContent.MediaType != nil { - sourceMap["media_type"] = *imgContent.MediaType + if formattedImgContent.MediaType != nil { + sourceMap["media_type"] = *formattedImgContent.MediaType } - sourceMap["data"] = imgContent.URL // URL field is used for base64 data string + sourceMap["data"] = formattedImgContent.URL // URL field contains base64 data string } return sourceMap } @@ -375,8 +377,10 @@ func prepareAnthropicChatRequest(messages []schemas.BifrostMessage, params *sche if (msg.UserMessage != nil && msg.UserMessage.ImageContent != nil) || (msg.ToolMessage != nil && msg.ToolMessage.ImageContent != nil) { var messageImageContent schemas.ImageContent if msg.UserMessage != nil && msg.UserMessage.ImageContent != nil { + // Create a copy to avoid modifying the original messageImageContent = *msg.UserMessage.ImageContent } else if msg.ToolMessage != nil && msg.ToolMessage.ImageContent != nil { + // Create a copy to avoid modifying the original messageImageContent = *msg.ToolMessage.ImageContent } @@ -396,8 +400,10 @@ func prepareAnthropicChatRequest(messages []schemas.BifrostMessage, params *sche if (msg.UserMessage != nil && msg.UserMessage.ImageContent != nil) || (msg.ToolMessage != nil && msg.ToolMessage.ImageContent != nil) { var messageImageContent schemas.ImageContent if msg.UserMessage != nil && msg.UserMessage.ImageContent != nil { + // Create a copy to avoid modifying the original messageImageContent = *msg.UserMessage.ImageContent } else if msg.ToolMessage != nil && msg.ToolMessage.ImageContent != nil { + // Create a copy to avoid modifying the original messageImageContent = *msg.ToolMessage.ImageContent } diff --git a/core/providers/bedrock.go b/core/providers/bedrock.go index 2d98ba69bf..e4718237d1 100644 --- a/core/providers/bedrock.go +++ b/core/providers/bedrock.go @@ -494,19 +494,22 @@ func (provider *BedrockProvider) prepareChatCompletionMessages(messages []schema messageImageContent = *msg.ToolMessage.ImageContent } + formattedImgContent := *FormatImageContent(&messageImageContent, false) + content = append(content, BedrockAnthropicImageMessage{ Type: "image", Image: BedrockAnthropicImage{ Format: func() string { - if messageImageContent.MediaType != nil { - mediaType := *messageImageContent.MediaType + if formattedImgContent.MediaType != nil { + mediaType := *formattedImgContent.MediaType + // Remove "image/" prefix if present, since normalizeMediaType ensures full format mediaType = strings.TrimPrefix(mediaType, "image/") return mediaType } return "" }(), Source: BedrockAnthropicImageSource{ - Bytes: messageImageContent.URL, + Bytes: formattedImgContent.URL, }, }, }) diff --git a/core/providers/cohere.go b/core/providers/cohere.go index 82749aaafc..9e60809280 100644 --- a/core/providers/cohere.go +++ b/core/providers/cohere.go @@ -235,9 +235,31 @@ func (provider *CohereProvider) ChatCompletion(ctx context.Context, model, key s historyMsg["tool_results"] = toolResults } - // Only add message content if it's not nil - if msg.Content != nil { - historyMsg["message"] = *msg.Content + // Handle message content based on whether it supports vision + if msg.UserMessage != nil && msg.UserMessage.ImageContent != nil { + // Create content array with text and image + contentArray := []map[string]interface{}{} + + // Add text content if present + if msg.Content != nil { + contentArray = append(contentArray, map[string]interface{}{ + "type": "text", + "text": *msg.Content, + }) + } + + // Add image content using our helper function + // NOTE: Cohere v1 does not support image content + // if processedImageContent := processImageContent(msg.UserMessage.ImageContent); processedImageContent != nil { + // contentArray = append(contentArray, processedImageContent) + // } + + historyMsg["content"] = contentArray + } else { + // For non-vision models or text-only messages, use simple message field + if msg.Content != nil { + historyMsg["message"] = *msg.Content + } } cohereHistory = append(cohereHistory, historyMsg) @@ -251,9 +273,31 @@ func (provider *CohereProvider) ChatCompletion(ctx context.Context, model, key s "model": model, }, preparedParams) - // Only add last message content if it's not nil - if lastMessage.Content != nil { - requestBody["message"] = *lastMessage.Content + // Handle the last message content based on whether it supports vision + if lastMessage.UserMessage != nil && lastMessage.UserMessage.ImageContent != nil { + // Create content array with text and image + contentArray := []map[string]interface{}{} + + // Add text content if present + if lastMessage.Content != nil { + contentArray = append(contentArray, map[string]interface{}{ + "type": "text", + "text": *lastMessage.Content, + }) + } + + // Add image content using our helper function + // NOTE: Cohere v1 does not support image content + // if processedImageContent := processImageContent(lastMessage.UserMessage.ImageContent); processedImageContent != nil { + // contentArray = append(contentArray, processedImageContent) + // } + + requestBody["content"] = contentArray + } else { + // For non-vision models or text-only messages, use simple message field + if lastMessage.Content != nil { + requestBody["message"] = *lastMessage.Content + } } // Add tools if present @@ -418,6 +462,24 @@ func (provider *CohereProvider) ChatCompletion(ctx context.Context, model, key s return bifrostResponse, nil } +// processImageContent processes image content for Cohere API format. +// It creates a copy of the image content, normalizes and formats it, then returns the properly formatted map. +// This prevents unintended mutations to the original image content. +func processImageContent(imageContent *schemas.ImageContent) map[string]interface{} { + if imageContent == nil { + return nil + } + + formattedImgContent := *FormatImageContent(imageContent, true) + + return map[string]interface{}{ + "type": "image_url", + "image_url": map[string]interface{}{ + "url": formattedImgContent.URL, + }, + } +} + // convertChatHistory converts Cohere's chat history format to Bifrost's format for standardization. // It transforms the chat history messages and their tool calls. func convertChatHistory(history []struct { diff --git a/core/providers/openai.go b/core/providers/openai.go index 8e8e7eee77..a5a972649f 100644 --- a/core/providers/openai.go +++ b/core/providers/openai.go @@ -226,6 +226,8 @@ func prepareOpenAIChatRequest(messages []schemas.BifrostMessage, params *schemas messageImageContent = *msg.ToolMessage.ImageContent } + formattedImgContent := *FormatImageContent(&messageImageContent, true) + var content []map[string]interface{} // Add text content if present @@ -239,12 +241,12 @@ func prepareOpenAIChatRequest(messages []schemas.BifrostMessage, params *schemas imageContent := map[string]interface{}{ "type": "image_url", "image_url": map[string]interface{}{ - "url": messageImageContent.URL, + "url": formattedImgContent.URL, }, } - if messageImageContent.Detail != nil { - imageContent["image_url"].(map[string]interface{})["detail"] = messageImageContent.Detail + if formattedImgContent.Detail != nil { + imageContent["image_url"].(map[string]interface{})["detail"] = formattedImgContent.Detail } content = append(content, imageContent) diff --git a/core/providers/utils.go b/core/providers/utils.go index d3ec36ba25..42dca5c1ad 100644 --- a/core/providers/utils.go +++ b/core/providers/utils.go @@ -7,6 +7,7 @@ import ( "fmt" "net/url" "reflect" + "regexp" "strings" "sync" @@ -25,6 +26,10 @@ var bifrostResponsePool = sync.Pool{ }, } +// dataURIRegex is a precompiled regex for matching data URI format patterns. +// It matches patterns like: data:image/png;base64,iVBORw0KGgo... +var dataURIRegex = regexp.MustCompile(`^data:([^;]+);base64,(.*)$`) + // acquireBifrostResponse gets a Bifrost response from the pool and resets it. func acquireBifrostResponse() *schemas.BifrostResponse { resp := bifrostResponsePool.Get().(*schemas.BifrostResponse) @@ -310,3 +315,131 @@ func coalesceString(s *string) string { } return *s } + +// normalizeMediaType converts short media types to full media types +// e.g., "jpeg" -> "image/jpeg", "png" -> "image/png" +func normalizeMediaType(mediaType string) string { + if mediaType == "" { + return "image/jpeg" // default + } + + // If it already has the image/ prefix, return as is + if strings.HasPrefix(mediaType, "image/") { + return mediaType + } + + // Add image/ prefix for common formats + switch strings.ToLower(mediaType) { + case "jpeg", "jpg": + return "image/jpeg" + case "png": + return "image/png" + case "gif": + return "image/gif" + case "webp": + return "image/webp" + case "bmp": + return "image/bmp" + case "svg": + return "image/svg+xml" + default: + return "image/" + mediaType + } +} + +// Normalize handles type inference and media type normalization for image content. +// It automatically detects content type from URL patterns and normalizes media types. +// +// NOTE: This function is called internally by the Bifrost system - you do not need to call it yourself. +// It is automatically invoked when processing image content in requests. +func normalizeImageContent(ic *schemas.ImageContent) { + if ic == nil { + return + } + + // Handle unknown/empty type - try to infer from URL + if ic.Type == "" && ic.URL != "" { + if dataURIRegex.MatchString(ic.URL) { + // Looks like base64 data URI + ic.Type = schemas.ImageContentTypeBase64 + } else if strings.HasPrefix(ic.URL, "http://") || strings.HasPrefix(ic.URL, "https://") { + // Looks like a regular URL + ic.Type = schemas.ImageContentTypeURL + } else { + // Assume it's raw base64 data + ic.Type = schemas.ImageContentTypeBase64 + } + } + + // Normalize MediaType if provided + if ic.MediaType != nil && *ic.MediaType != "" { + normalizedMediaType := normalizeMediaType(*ic.MediaType) + ic.MediaType = &normalizedMediaType + } + +} + +// FormatDataURL modifies the image content struct in place to format data URL for base64 image content. +// +// NOTE: This function is called internally by the Bifrost system - you do not need to call it yourself. +// It is automatically invoked when processing image content for different providers. +// +// Parameters: +// - includePrefix: Whether to include the "data:mediatype;base64," prefix +// - true: URL will be in full data URI format (data:image/png;base64,iVBORw0KGgo...) +// - false: URL will contain only the base64 data (iVBORw0KGgo...) +func FormatImageContent(imageContent *schemas.ImageContent, includePrefix bool) *schemas.ImageContent { + if imageContent == nil { + return nil + } + + newImageContent := *imageContent + + normalizeImageContent(&newImageContent) + + if newImageContent.Type != schemas.ImageContentTypeBase64 { + return &newImageContent + } + + var finalMediaType string + var base64Data string + + // Extract base64 data and media type from URL using precompiled regex + if matches := dataURIRegex.FindStringSubmatch(newImageContent.URL); matches != nil { + // URL already has data URI format + existingMediaType := matches[1] + base64Data = matches[2] + + // Determine final media type (prefer explicit MediaType field) + if newImageContent.MediaType != nil && *newImageContent.MediaType != "" { + finalMediaType = normalizeMediaType(*newImageContent.MediaType) + } else { + finalMediaType = normalizeMediaType(existingMediaType) + } + } else { + // URL contains raw base64 data (no data URI prefix) + base64Data = newImageContent.URL + + // Determine media type + if newImageContent.MediaType != nil && *newImageContent.MediaType != "" { + finalMediaType = normalizeMediaType(*newImageContent.MediaType) + } else { + finalMediaType = "image/jpeg" // default when no media type provided + } + } + + // Ensure MediaType field is always set with normalized value + normalizedMediaType := finalMediaType + newImageContent.MediaType = &normalizedMediaType + + // Set URL based on includePrefix preference + if includePrefix { + // Full data URI format + newImageContent.URL = fmt.Sprintf("data:%s;base64,%s", finalMediaType, base64Data) + } else { + // Raw base64 data only + newImageContent.URL = base64Data + } + + return &newImageContent +} diff --git a/core/schemas/bifrost.go b/core/schemas/bifrost.go index 59c485a220..6d9dcd2cd8 100644 --- a/core/schemas/bifrost.go +++ b/core/schemas/bifrost.go @@ -170,12 +170,19 @@ type AssistantMessage struct { Thought *string `json:"thought,omitempty"` } +type ImageContentType string + +const ( + ImageContentTypeBase64 ImageContentType = "base64" + ImageContentTypeURL ImageContentType = "url" +) + // ImageContent represents image data in a message. type ImageContent struct { - Type *string `json:"type"` - URL string `json:"url"` - MediaType *string `json:"media_type"` - Detail *string `json:"detail"` + Type ImageContentType `json:"type"` + URL string `json:"url"` + MediaType *string `json:"media_type,omitempty"` + Detail *string `json:"detail,omitempty"` } //* Response Structs diff --git a/core/tests/account.go b/core/tests/account.go index 954cab8480..d5652a3c14 100644 --- a/core/tests/account.go +++ b/core/tests/account.go @@ -77,7 +77,7 @@ func (baseAccount *BaseAccount) GetKeysForProvider(providerKey schemas.ModelProv return []schemas.Key{ { Value: os.Getenv("COHERE_API_KEY"), - Models: []string{"command-a-03-2025"}, + Models: []string{"command-a-03-2025", "c4ai-aya-vision-8b"}, Weight: 1.0, }, }, nil diff --git a/core/tests/azure_test.go b/core/tests/azure_test.go index 5c2194364c..2d78aefe05 100644 --- a/core/tests/azure_test.go +++ b/core/tests/azure_test.go @@ -22,7 +22,7 @@ func TestAzure(t *testing.T) { SetupText: false, // gpt-4o does not support text completion SetupToolCalls: true, SetupImage: true, - SetupBaseImage: false, + SetupBaseImage: true, } SetupAllRequests(bifrost, config) diff --git a/core/tests/openai_test.go b/core/tests/openai_test.go index 7dc7fa8c7c..9c3d137dbb 100644 --- a/core/tests/openai_test.go +++ b/core/tests/openai_test.go @@ -22,8 +22,8 @@ func TestOpenAI(t *testing.T) { ChatModel: "gpt-4o-mini", SetupText: false, // OpenAI does not support text completion SetupToolCalls: true, - SetupImage: false, - SetupBaseImage: false, + SetupImage: true, + SetupBaseImage: true, Fallbacks: []schemas.Fallback{ { Provider: schemas.Anthropic, diff --git a/core/tests/tests.go b/core/tests/tests.go index e439cfbea3..d1acc53e95 100644 --- a/core/tests/tests.go +++ b/core/tests/tests.go @@ -189,7 +189,7 @@ func setupImageTests(bifrostClient *bifrost.Bifrost, config TestConfig, ctx cont Content: bifrost.Ptr("What is Happening in this picture?"), UserMessage: &schemas.UserMessage{ ImageContent: &schemas.ImageContent{ - Type: bifrost.Ptr("url"), + Type: schemas.ImageContentTypeURL, URL: "https://upload.wikimedia.org/wikipedia/commons/a/a7/Camponotus_flavomarginatus_ant.jpg", }, }, @@ -225,7 +225,7 @@ func setupImageTests(bifrostClient *bifrost.Bifrost, config TestConfig, ctx cont Content: bifrost.Ptr("What is this image about?"), UserMessage: &schemas.UserMessage{ ImageContent: &schemas.ImageContent{ - Type: bifrost.Ptr("base64"), + Type: schemas.ImageContentTypeBase64, URL: "/9j/4AAQSkZJRgABAQEAYABgAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/2wBDAQkJCQwLDBgNDRgyIRwhMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjL/wAARCAAIAAoDASIAAhEBAxEB/8QAFQABAQAAAAAAAAAAAAAAAAAAAAb/xAAUEAEAAAAAAAAAAAAAAAAAAAAA/8QAFQEBAQAAAAAAAAAAAAAAAAAAAAX/xAAUEQEAAAAAAAAAAAAAAAAAAAAA/9oADAMBAAIRAxEAPwCdABmX/9k=", MediaType: bifrost.Ptr("image/jpeg"), }, diff --git a/core/tests/vertex_test.go b/core/tests/vertex_test.go index 89265aae57..93f445e36a 100644 --- a/core/tests/vertex_test.go +++ b/core/tests/vertex_test.go @@ -21,8 +21,8 @@ func TestVertex(t *testing.T) { ChatModel: "google/gemini-2.0-flash-001", SetupText: false, // Vertex does not support text completion SetupToolCalls: true, - SetupImage: false, - SetupBaseImage: false, + SetupImage: true, + SetupBaseImage: true, } SetupAllRequests(bifrostClient, config)