diff --git a/core/bifrost.go b/core/bifrost.go index 8f8f58b7e8..e0d0e73198 100644 --- a/core/bifrost.go +++ b/core/bifrost.go @@ -8,6 +8,7 @@ import ( "fmt" "math/rand" "slices" + "strings" "sync" "time" @@ -229,7 +230,7 @@ func (bifrost *Bifrost) SelectKeyFromProviderForModel(providerKey schemas.ModelP // filter out keys which dont support the model var supportedKeys []schemas.Key for _, key := range keys { - if slices.Contains(key.Models, model) { + if slices.Contains(key.Models, model) && strings.TrimSpace(key.Value) != "" { supportedKeys = append(supportedKeys, key) } } diff --git a/core/go.mod b/core/go.mod index 01e20c64e2..6b12f9568b 100644 --- a/core/go.mod +++ b/core/go.mod @@ -8,6 +8,7 @@ require ( github.com/aws/aws-sdk-go-v2 v1.36.3 github.com/aws/aws-sdk-go-v2/config v1.29.14 github.com/goccy/go-json v0.10.5 + github.com/stretchr/testify v1.10.0 github.com/valyala/fasthttp v1.60.0 golang.org/x/oauth2 v0.30.0 ) @@ -26,8 +27,11 @@ require ( github.com/aws/aws-sdk-go-v2/service/ssooidc v1.30.1 // indirect github.com/aws/aws-sdk-go-v2/service/sts v1.33.19 // indirect github.com/aws/smithy-go v1.22.3 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect github.com/klauspost/compress v1.18.0 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect golang.org/x/net v0.39.0 // indirect golang.org/x/text v0.24.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/core/go.sum b/core/go.sum index 2a37e31a12..b78980b657 100644 --- a/core/go.sum +++ b/core/go.sum @@ -28,12 +28,18 @@ github.com/aws/aws-sdk-go-v2/service/sts v1.33.19 h1:1XuUZ8mYJw9B6lzAkXhqHlJd/Xv github.com/aws/aws-sdk-go-v2/service/sts v1.33.19/go.mod h1:cQnB8CUnxbMU82JvlqjKR2HBOm3fe9pWorWBza6MBJ4= github.com/aws/smithy-go v1.22.3 h1:Z//5NuZCSW6R4PhQ93hShNbyBbn8BWCmCVCt+Q8Io5k= github.com/aws/smithy-go v1.22.3/go.mod h1:t1ufH5HMublsJYulve2RKmHDC15xu1f26kHCp/HgceI= +github.com/davecgh/go-spew v1.1.1 
h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4= github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= github.com/valyala/fasthttp v1.60.0 h1:kBRYS0lOhVJ6V+bYN8PqAHELKHtXqwq9zNMLKx1MBsw= @@ -46,3 +52,7 @@ golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI= golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU= golang.org/x/text v0.24.0 h1:dd5Bzh4yt5KYA8f9CJHCP4FB4D51c2c6JvN37xJJkJ0= golang.org/x/text v0.24.0/go.mod h1:L8rBsPeo2pSS+xqN0d5u2ikmjtmoJbDBT1b7nHvFCdU= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git 
a/core/providers/anthropic.go b/core/providers/anthropic.go index 48e81a0b2a..9084255efa 100644 --- a/core/providers/anthropic.go +++ b/core/providers/anthropic.go @@ -278,7 +278,7 @@ func (provider *AnthropicProvider) TextCompletion(ctx context.Context, model, ke // It formats the request, sends it to Anthropic, and processes the response. // Returns a BifrostResponse containing the completion results or an error if the request fails. func (provider *AnthropicProvider) ChatCompletion(ctx context.Context, model, key string, messages []schemas.BifrostMessage, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { - formattedMessages, preparedParams := prepareAnthropicChatRequest(model, messages, params) + formattedMessages, preparedParams := prepareAnthropicChatRequest(messages, params) // Merge additional parameters requestBody := mergeConfig(map[string]interface{}{ @@ -317,12 +317,32 @@ func (provider *AnthropicProvider) ChatCompletion(ctx context.Context, model, ke return bifrostResponse, nil } -func prepareAnthropicChatRequest(model string, messages []schemas.BifrostMessage, params *schemas.ModelParameters) ([]map[string]interface{}, map[string]interface{}) { +// buildAnthropicImageSourceMap creates the "source" map for an Anthropic image content part. 
+func buildAnthropicImageSourceMap(imgContent *schemas.ImageContent) map[string]interface{} { + if imgContent == nil || imgContent.Type == nil { + return nil + } + + sourceMap := map[string]interface{}{ + "type": *imgContent.Type, // "base64" or "url" + } + + if *imgContent.Type == "url" { + sourceMap["url"] = imgContent.URL + } else { + if imgContent.MediaType != nil { + sourceMap["media_type"] = *imgContent.MediaType + } + sourceMap["data"] = imgContent.URL // URL field is used for base64 data string + } + return sourceMap +} + +func prepareAnthropicChatRequest(messages []schemas.BifrostMessage, params *schemas.ModelParameters) ([]map[string]interface{}, map[string]interface{}) { // Add system messages if present var systemMessages []BedrockAnthropicSystemMessage for _, msg := range messages { if msg.Role == schemas.ModelChatMessageRoleSystem { - //TODO handling image inputs here if msg.Content != nil { systemMessages = append(systemMessages, BedrockAnthropicSystemMessage{ Text: *msg.Content, @@ -334,57 +354,61 @@ func prepareAnthropicChatRequest(model string, messages []schemas.BifrostMessage // Format messages for Anthropic API var formattedMessages []map[string]interface{} for _, msg := range messages { - if msg.Role != schemas.ModelChatMessageRoleSystem { - if (msg.UserMessage != nil && msg.UserMessage.ImageContent != nil) || (msg.ToolMessage != nil && msg.ToolMessage.ImageContent != nil) { - var messageImageContent schemas.ImageContent - if msg.UserMessage != nil && msg.UserMessage.ImageContent != nil { - messageImageContent = *msg.UserMessage.ImageContent - } else if msg.ToolMessage != nil && msg.ToolMessage.ImageContent != nil { - messageImageContent = *msg.ToolMessage.ImageContent - } - - var content []map[string]interface{} - - imageContent := map[string]interface{}{ - "type": "image", - "source": map[string]interface{}{ - "type": messageImageContent.Type, - }, - } + var content []interface{} - // Handle different image source types - if 
messageImageContent.Type != nil && *messageImageContent.Type == "url" { - imageContent["source"].(map[string]interface{})["url"] = messageImageContent.URL - } else { - imageContent["source"].(map[string]interface{})["media_type"] = messageImageContent.MediaType - imageContent["source"].(map[string]interface{})["data"] = messageImageContent.URL + if msg.Role != schemas.ModelChatMessageRoleSystem { + if msg.Role == schemas.ModelChatMessageRoleTool && msg.ToolMessage != nil && msg.ToolMessage.ToolCallID != nil { + toolCallResult := map[string]interface{}{ + "type": "tool_result", + "tool_use_id": *msg.ToolMessage.ToolCallID, } - content = append(content, imageContent) + var toolCallResultContent []map[string]interface{} - // Add text content if present if msg.Content != nil { - content = append(content, map[string]interface{}{ + toolCallResultContent = append(toolCallResultContent, map[string]interface{}{ "type": "text", "text": *msg.Content, }) } - // Add thinking content if present in AssistantMessage - if msg.AssistantMessage != nil && msg.AssistantMessage.Thought != nil { - content = append(content, map[string]interface{}{ - "type": "thinking", - "thinking": *msg.AssistantMessage.Thought, - }) + if (msg.UserMessage != nil && msg.UserMessage.ImageContent != nil) || (msg.ToolMessage != nil && msg.ToolMessage.ImageContent != nil) { + var messageImageContent schemas.ImageContent + if msg.UserMessage != nil && msg.UserMessage.ImageContent != nil { + messageImageContent = *msg.UserMessage.ImageContent + } else if msg.ToolMessage != nil && msg.ToolMessage.ImageContent != nil { + messageImageContent = *msg.ToolMessage.ImageContent + } + + imageSource := buildAnthropicImageSourceMap(&messageImageContent) + if imageSource != nil { + toolCallResultContent = append(toolCallResultContent, map[string]interface{}{ + "type": "image", + "source": imageSource, + }) + } } - formattedMessages = append(formattedMessages, map[string]interface{}{ - "role": msg.Role, - "content": 
content, - }) + toolCallResult["content"] = toolCallResultContent + + content = append(content, toolCallResult) } else { - // Handle non-image messages - var content []map[string]interface{} + if (msg.UserMessage != nil && msg.UserMessage.ImageContent != nil) || (msg.ToolMessage != nil && msg.ToolMessage.ImageContent != nil) { + var messageImageContent schemas.ImageContent + if msg.UserMessage != nil && msg.UserMessage.ImageContent != nil { + messageImageContent = *msg.UserMessage.ImageContent + } else if msg.ToolMessage != nil && msg.ToolMessage.ImageContent != nil { + messageImageContent = *msg.ToolMessage.ImageContent + } + + imageSource := buildAnthropicImageSourceMap(&messageImageContent) + if imageSource != nil { + content = append(content, map[string]interface{}{ + "type": "image", + "source": imageSource, + }) + } + } // Add text content if present if msg.Content != nil && *msg.Content != "" { @@ -428,14 +452,13 @@ func prepareAnthropicChatRequest(model string, messages []schemas.BifrostMessage } } } + } - // Always use content block structure - if len(content) > 0 { - formattedMessages = append(formattedMessages, map[string]interface{}{ - "role": msg.Role, - "content": content, - }) - } + if len(content) > 0 { + formattedMessages = append(formattedMessages, map[string]interface{}{ + "role": msg.Role, + "content": content, + }) } } } @@ -456,6 +479,21 @@ func prepareAnthropicChatRequest(model string, messages []schemas.BifrostMessage preparedParams["tools"] = tools } + // Transform tool choice if present + if params != nil && params.ToolChoice != nil { + switch toolChoice := params.ToolChoice.Type; toolChoice { + case schemas.ToolChoiceTypeFunction: + preparedParams["tool_choice"] = map[string]interface{}{ + "type": "tool", + "name": params.ToolChoice.Function.Name, + } + default: + preparedParams["tool_choice"] = map[string]interface{}{ + "type": toolChoice, + } + } + } + if len(systemMessages) > 0 { var messages []string for _, message := range 
systemMessages { @@ -465,6 +503,67 @@ func prepareAnthropicChatRequest(model string, messages []schemas.BifrostMessage preparedParams["system"] = strings.Join(messages, " ") } + // Post-process formattedMessages for tool call results + processedFormattedMessages := []map[string]interface{}{} // Use a new slice + i := 0 + for i < len(formattedMessages) { + currentMsg := formattedMessages[i] + currentRole, roleOk := getRoleFromMessage(currentMsg) + + if !roleOk || currentRole == "" { + // If role is of an unexpected type, missing, or empty, treat as non-tool message + processedFormattedMessages = append(processedFormattedMessages, currentMsg) + i++ + continue + } + + if currentRole == schemas.ModelChatMessageRoleTool { + // Content of a tool message is the toolCallResult map + // Initialize accumulatedToolResults with the content of the current tool message. + var accumulatedToolResults []interface{} + + // Safely extract content from current message + if content, ok := currentMsg["content"].([]interface{}); ok { + accumulatedToolResults = content + } else { + // If content is not the expected type, skip this message + processedFormattedMessages = append(processedFormattedMessages, currentMsg) + i++ + continue + } + + // Look ahead for more sequential tool messages + j := i + 1 + for j < len(formattedMessages) { + nextMsg := formattedMessages[j] + nextRole, nextRoleOk := getRoleFromMessage(nextMsg) + + if !nextRoleOk || nextRole == "" || nextRole != schemas.ModelChatMessageRoleTool { + break // Not a sequential tool message or role is invalid/missing/empty + } + + // Safely extract content from next message + if nextContent, ok := nextMsg["content"].([]interface{}); ok { + accumulatedToolResults = append(accumulatedToolResults, nextContent...) 
+ } + j++ + } + + // Create a new message with role User and accumulated content + mergedMsg := map[string]interface{}{ + "role": schemas.ModelChatMessageRoleUser, // Final role is User + "content": accumulatedToolResults, + } + processedFormattedMessages = append(processedFormattedMessages, mergedMsg) + i = j // Advance main loop index past all merged messages + } else { + // Not a tool message, add it as is + processedFormattedMessages = append(processedFormattedMessages, currentMsg) + i++ + } + } + formattedMessages = processedFormattedMessages // Update with processed messages + return formattedMessages, preparedParams } diff --git a/core/providers/azure.go b/core/providers/azure.go index e2786222e1..08039450f5 100644 --- a/core/providers/azure.go +++ b/core/providers/azure.go @@ -297,22 +297,7 @@ func (provider *AzureProvider) TextCompletion(ctx context.Context, model, key, t // It formats the request, sends it to Azure, and processes the response. // Returns a BifrostResponse containing the completion results or an error if the request fails. 
func (provider *AzureProvider) ChatCompletion(ctx context.Context, model, key string, messages []schemas.BifrostMessage, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { - preparedParams := prepareParams(params) - - // Format messages for Azure API - var formattedMessages []map[string]interface{} - for _, msg := range messages { - message := map[string]interface{}{ - "role": msg.Role, - } - - // Only add content if it's not nil - if msg.Content != nil { - message["content"] = *msg.Content - } - - formattedMessages = append(formattedMessages, message) - } + formattedMessages, preparedParams := prepareOpenAIChatRequest(messages, params) // Merge additional parameters requestBody := mergeConfig(map[string]interface{}{ diff --git a/core/providers/bedrock.go b/core/providers/bedrock.go index 88e4a7fae2..2d98ba69bf 100644 --- a/core/providers/bedrock.go +++ b/core/providers/bedrock.go @@ -50,7 +50,9 @@ type BedrockChatResponse struct { Output struct { Message struct { Content []struct { - Text string `json:"text"` // Message content + Text *string `json:"text"` // Message content + // Bedrock returns a union type where either Text or ToolUse is present (mutually exclusive) + BedrockAnthropicToolUseMessage } `json:"content"` // Array of message content Role string `json:"role"` // Role of the message sender } `json:"message"` // Message structure @@ -95,8 +97,8 @@ type BedrockAnthropicImageMessage struct { // BedrockAnthropicImage represents image data for Anthropic models. type BedrockAnthropicImage struct { - Format string `json:"string"` // Image format - Source BedrockAnthropicImageSource `json:"source"` // Image source + Format string `json:"format,omitempty"` // Image format + Source BedrockAnthropicImageSource `json:"source,omitempty"` // Image source } // BedrockAnthropicImageSource represents the source of an image for Anthropic models. 
@@ -110,6 +112,16 @@ type BedrockMistralToolCall struct { Function schemas.FunctionCall `json:"function"` // Function to call } +type BedrockAnthropicToolUseMessage struct { + ToolUse *BedrockAnthropicToolUse `json:"toolUse"` +} + +type BedrockAnthropicToolUse struct { + ToolUseID string `json:"toolUseId"` + Name string `json:"name"` + Input map[string]interface{} `json:"input"` +} + // BedrockAnthropicToolCall represents a tool call for Anthropic models. type BedrockAnthropicToolCall struct { ToolSpec BedrockAnthropicToolSpec `json:"toolSpec"` // Tool specification @@ -415,7 +427,6 @@ func (provider *BedrockProvider) prepareChatCompletionMessages(messages []schema var systemMessages []BedrockAnthropicSystemMessage for _, msg := range messages { if msg.Role == schemas.ModelChatMessageRoleSystem { - //TODO handling image inputs here if msg.Content != nil { systemMessages = append(systemMessages, BedrockAnthropicSystemMessage{ Text: *msg.Content, @@ -427,43 +438,157 @@ func (provider *BedrockProvider) prepareChatCompletionMessages(messages []schema // Format messages for Bedrock API var bedrockMessages []map[string]interface{} for _, msg := range messages { + var content []interface{} if msg.Role != schemas.ModelChatMessageRoleSystem { - var content any - if msg.Content != nil { - content = BedrockAnthropicTextMessage{ - Type: "text", - Text: *msg.Content, - } - } else if (msg.UserMessage != nil && msg.UserMessage.ImageContent != nil) || (msg.ToolMessage != nil && msg.ToolMessage.ImageContent != nil) { - var messageImageContent schemas.ImageContent - if msg.UserMessage != nil && msg.UserMessage.ImageContent != nil { - messageImageContent = *msg.UserMessage.ImageContent - } else if msg.ToolMessage != nil && msg.ToolMessage.ImageContent != nil { - messageImageContent = *msg.ToolMessage.ImageContent + if msg.Role == schemas.ModelChatMessageRoleTool && msg.ToolCallID != nil { + toolCallResult := map[string]interface{}{ + "toolUseId": *msg.ToolCallID, } - content = 
BedrockAnthropicImageMessage{ - Type: "image", - Image: BedrockAnthropicImage{ - Format: func() string { - if messageImageContent.Type != nil { - return *messageImageContent.Type + var toolResultContentBlock map[string]interface{} + if msg.Content != nil { + toolResultContentBlock = map[string]interface{}{} + var parsedJSON interface{} + err := json.Unmarshal([]byte(*msg.Content), &parsedJSON) + if err == nil { + if arr, ok := parsedJSON.([]interface{}); ok { + toolResultContentBlock["json"] = map[string]interface{}{"content": arr} + } else { + toolResultContentBlock["json"] = parsedJSON + } + } else { + toolResultContentBlock["text"] = *msg.Content + } + + toolCallResult["content"] = []interface{}{toolResultContentBlock} + + content = append(content, map[string]interface{}{ + "toolResult": toolCallResult, + }) + } + } else { + if msg.AssistantMessage != nil && msg.AssistantMessage.ToolCalls != nil { + for _, toolCall := range *msg.AssistantMessage.ToolCalls { + var input map[string]interface{} + if toolCall.Function.Arguments != "" { + if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &input); err != nil { + input = map[string]interface{}{"arguments": toolCall.Function.Arguments} } - return "" - }(), - Source: BedrockAnthropicImageSource{ - Bytes: messageImageContent.URL, + } + + content = append(content, BedrockAnthropicToolUseMessage{ + ToolUse: &BedrockAnthropicToolUse{ + ToolUseID: *toolCall.ID, + Name: *toolCall.Function.Name, + Input: input, + }, + }) + } + } + + if (msg.UserMessage != nil && msg.UserMessage.ImageContent != nil) || (msg.ToolMessage != nil && msg.ToolMessage.ImageContent != nil) { + var messageImageContent schemas.ImageContent + if msg.UserMessage != nil && msg.UserMessage.ImageContent != nil { + messageImageContent = *msg.UserMessage.ImageContent + } else if msg.ToolMessage != nil && msg.ToolMessage.ImageContent != nil { + messageImageContent = *msg.ToolMessage.ImageContent + } + + content = append(content, 
BedrockAnthropicImageMessage{ + Type: "image", + Image: BedrockAnthropicImage{ + Format: func() string { + if messageImageContent.MediaType != nil { + mediaType := *messageImageContent.MediaType + mediaType = strings.TrimPrefix(mediaType, "image/") + return mediaType + } + return "" + }(), + Source: BedrockAnthropicImageSource{ + Bytes: messageImageContent.URL, + }, }, - }, + }) + } + + if msg.Content != nil { + content = append(content, BedrockAnthropicTextMessage{ + Type: "text", + Text: *msg.Content, + }) + } + } + + if len(content) > 0 { + bedrockMessages = append(bedrockMessages, map[string]interface{}{ + "role": msg.Role, + "content": content, + }) + } + } + } + + // Post-process bedrockMessages for tool call results + processedBedrockMessages := []map[string]interface{}{} + i := 0 + for i < len(bedrockMessages) { + currentMsg := bedrockMessages[i] + currentRole, roleOk := getRoleFromMessage(currentMsg) + + if !roleOk { + // If role is of an unexpected type or missing, treat as non-tool message + processedBedrockMessages = append(processedBedrockMessages, currentMsg) + i++ + continue + } + + if currentRole == schemas.ModelChatMessageRoleTool { + // Content of a tool message is the toolCallResult map + // Initialize accumulatedToolResults with the content of the current tool message. 
+ var accumulatedToolResults []interface{} + + // Safely extract content from current message + if content, ok := currentMsg["content"].([]interface{}); ok { + accumulatedToolResults = content + } else { + // If content is not the expected type, skip this message + processedBedrockMessages = append(processedBedrockMessages, currentMsg) + i++ + continue + } + + // Look ahead for more sequential tool messages + j := i + 1 + for j < len(bedrockMessages) { + nextMsg := bedrockMessages[j] + nextRole, nextRoleOk := getRoleFromMessage(nextMsg) + + if !nextRoleOk || nextRole != schemas.ModelChatMessageRoleTool { + break // Not a sequential tool message or role is invalid/missing } + + // Safely extract content from next message + if nextContent, ok := nextMsg["content"].([]interface{}); ok { + accumulatedToolResults = append(accumulatedToolResults, nextContent...) + } + j++ } - bedrockMessages = append(bedrockMessages, map[string]interface{}{ - "role": msg.Role, - "content": []interface{}{content}, - }) + // Create a new message with role User and accumulated content + mergedMsg := map[string]interface{}{ + "role": schemas.ModelChatMessageRoleUser, // Final role is User + "content": accumulatedToolResults, + } + processedBedrockMessages = append(processedBedrockMessages, mergedMsg) + i = j // Advance main loop index past all merged messages + } else { + // Not a tool message, add it as is + processedBedrockMessages = append(processedBedrockMessages, currentMsg) + i++ } } + bedrockMessages = processedBedrockMessages // Update with processed messages body := map[string]interface{}{ "messages": bedrockMessages, @@ -631,6 +756,56 @@ func (provider *BedrockProvider) TextCompletion(ctx context.Context, model, key, return result, nil } +// extractToolsFromHistory extracts minimal tool definitions from conversation history. 
+// It analyzes the messages to find tool-related content and returns whether tool content +// was found and a list of unique minimal tool definitions extracted from the conversation. +// This is needed when Bedrock requires toolConfig but no tools are provided in the current request. +func (provider *BedrockProvider) extractToolsFromHistory(messages []schemas.BifrostMessage) (bool, []BedrockAnthropicToolCall) { + hasToolContent := false + var toolsFromHistory []BedrockAnthropicToolCall + seenTools := make(map[string]BedrockAnthropicToolCall) + + for _, msg := range messages { + // Check for tool result messages + if msg.Role == schemas.ModelChatMessageRoleTool { + hasToolContent = true + } + // Check for assistant messages with tool calls + if msg.Role == schemas.ModelChatMessageRoleAssistant && msg.AssistantMessage != nil && msg.AssistantMessage.ToolCalls != nil { + hasToolContent = true + // Extract tool definitions from tool calls for toolConfig + for _, toolCall := range *msg.AssistantMessage.ToolCalls { + if toolCall.Function.Name != nil { + toolName := *toolCall.Function.Name + if _, exists := seenTools[toolName]; !exists { + // Create a basic tool definition from the tool call + // Note: We can't fully reconstruct the original tool definition, + // but we can provide a minimal one that satisfies Bedrock's requirement + tool := BedrockAnthropicToolCall{ + ToolSpec: BedrockAnthropicToolSpec{ + Name: toolName, + Description: fmt.Sprintf("Tool: %s", toolName), + InputSchema: struct { + Json interface{} `json:"json"` + }{ + Json: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{}, + }, + }, + }, + } + seenTools[toolName] = tool + toolsFromHistory = append(toolsFromHistory, tool) + } + } + } + } + } + + return hasToolContent, toolsFromHistory +} + // ChatCompletion performs a chat completion request to Bedrock's API. // It formats the request, sends it to Bedrock, and processes the response. 
// Returns a BifrostResponse containing the completion results or an error if the request fails. @@ -644,7 +819,21 @@ func (provider *BedrockProvider) ChatCompletion(ctx context.Context, model, key // Transform tools if present if params != nil && params.Tools != nil && len(*params.Tools) > 0 { - preparedParams["tools"] = provider.getChatCompletionTools(params, model) + preparedParams["toolConfig"] = map[string]interface{}{ + "tools": provider.getChatCompletionTools(params, model), + } + } else { + // Check if conversation history contains tool use/result blocks + // Bedrock requires toolConfig when such blocks are present + hasToolContent, toolsFromHistory := provider.extractToolsFromHistory(messages) + + // If conversation contains tool content but no tools provided in current request, + // include the extracted tools to satisfy Bedrock's toolConfig requirement + if hasToolContent && len(toolsFromHistory) > 0 { + preparedParams["toolConfig"] = map[string]interface{}{ + "tools": toolsFromHistory, + } + } } requestBody := mergeConfig(messageBody, preparedParams) @@ -680,16 +869,69 @@ func (provider *BedrockProvider) ChatCompletion(ctx context.Context, model, key return nil, bifrostErr } - var choices []schemas.BifrostResponseChoice - for i, choice := range response.Output.Message.Content { - choices = append(choices, schemas.BifrostResponseChoice{ - Index: i, + // Collect all content and tool calls into a single message (similar to Anthropic aggregation) + var content strings.Builder + var toolCalls []schemas.ToolCall + + // Process content and tool calls + for _, choice := range response.Output.Message.Content { + if choice.Text != nil && *choice.Text != "" { + if content.Len() > 0 { + content.WriteString("\n") + } + content.WriteString(*choice.Text) + } + + if choice.ToolUse != nil { + input := choice.ToolUse.Input + if input == nil { + input = map[string]any{} + } + arguments, err := json.Marshal(input) + if err != nil { + arguments = []byte("{}") + } + + idCopy 
:= choice.ToolUse.ToolUseID // copy to avoid unsafe pointer creation + nameCopy := choice.ToolUse.Name // copy to avoid unsafe pointer creation + toolCalls = append(toolCalls, schemas.ToolCall{ + Type: StrPtr("function"), + ID: &idCopy, + Function: schemas.FunctionCall{ + Name: &nameCopy, + Arguments: string(arguments), + }, + }) + } + } + + // Create the assistant message + messageContent := content.String() + var contentPtr *string + if messageContent != "" { + contentPtr = &messageContent + } + + var assistantMessage *schemas.AssistantMessage + + // Create AssistantMessage if we have tool calls + if len(toolCalls) > 0 { + assistantMessage = &schemas.AssistantMessage{ + ToolCalls: &toolCalls, + } + } + + // Create a single choice with the aggregated content + choices := []schemas.BifrostResponseChoice{ + { + Index: 0, Message: schemas.BifrostMessage{ - Role: schemas.ModelChatMessageRoleAssistant, - Content: &choice.Text, + Role: schemas.ModelChatMessageRoleAssistant, + Content: contentPtr, + AssistantMessage: assistantMessage, }, FinishReason: &response.StopReason, - }) + }, } latency := float64(response.Metrics.Latency) diff --git a/core/providers/cohere.go b/core/providers/cohere.go index 0dc995a23b..82749aaafc 100644 --- a/core/providers/cohere.go +++ b/core/providers/cohere.go @@ -6,6 +6,7 @@ import ( "context" "fmt" "slices" + "strings" "sync" "time" @@ -153,6 +154,87 @@ func (provider *CohereProvider) ChatCompletion(ctx context.Context, model, key s "role": msg.Role, } + if msg.Role == schemas.ModelChatMessageRoleAssistant { + if msg.AssistantMessage != nil && msg.AssistantMessage.ToolCalls != nil { + var toolCalls []map[string]interface{} + for _, toolCall := range *msg.AssistantMessage.ToolCalls { + var arguments map[string]interface{} + var parsedJSON interface{} + err := json.Unmarshal([]byte(toolCall.Function.Arguments), &parsedJSON) + if err == nil { + if arr, ok := parsedJSON.(map[string]interface{}); ok { + arguments = arr + } else { + arguments = 
map[string]interface{}{"content": parsedJSON} + } + } else { + arguments = map[string]interface{}{"content": toolCall.Function.Arguments} + } + + toolCalls = append(toolCalls, map[string]interface{}{ + "name": toolCall.Function.Name, + "parameters": arguments, + }) + } + historyMsg["tool_calls"] = toolCalls + } + } else if msg.Role == schemas.ModelChatMessageRoleTool { + // Find the original tool call parameters from conversation history + var toolCallParameters map[string]interface{} + + // Look back through the chat history to find the assistant message with the matching tool call + for i := len(chatHistory) - 1; i >= 0; i-- { + prevMsg := chatHistory[i] + if prevMsg.Role == schemas.ModelChatMessageRoleAssistant && + prevMsg.AssistantMessage != nil && + prevMsg.AssistantMessage.ToolCalls != nil { + + // Search through tool calls in this assistant message + for _, toolCall := range *prevMsg.AssistantMessage.ToolCalls { + if toolCall.ID != nil && msg.ToolMessage != nil && msg.ToolMessage.ToolCallID != nil && + *toolCall.ID == *msg.ToolMessage.ToolCallID { + + // Found the matching tool call, extract its parameters + var parsedJSON interface{} + err := json.Unmarshal([]byte(toolCall.Function.Arguments), &parsedJSON) + if err == nil { + if arr, ok := parsedJSON.(map[string]interface{}); ok { + toolCallParameters = arr + } else { + toolCallParameters = map[string]interface{}{"content": parsedJSON} + } + } else { + toolCallParameters = map[string]interface{}{"content": toolCall.Function.Arguments} + } + break + } + } + + // If we found the parameters, stop searching + if toolCallParameters != nil { + break + } + } + } + + // If no parameters found, use empty map as fallback + if toolCallParameters == nil { + toolCallParameters = map[string]interface{}{} + } + + toolResults := []map[string]interface{}{ + { + "call": map[string]interface{}{ + "name": *msg.ToolMessage.ToolCallID, + "parameters": toolCallParameters, + }, + "outputs": *msg.Content, + }, + } + + 
historyMsg["tool_results"] = toolResults + } + // Only add message content if it's not nil if msg.Content != nil { historyMsg["message"] = *msg.Content @@ -207,6 +289,10 @@ func (provider *CohereProvider) ChatCompletion(ctx context.Context, model, key s } requestBody["tools"] = tools } + // Add tool choice if present + if params != nil && params.ToolChoice != nil { + requestBody["tool_choice"] = strings.ToUpper(string(params.ToolChoice.Type)) + } // Marshal request body jsonBody, err := json.Marshal(requestBody) diff --git a/core/providers/openai.go b/core/providers/openai.go index 8c1fc02b6a..8e8e7eee77 100644 --- a/core/providers/openai.go +++ b/core/providers/openai.go @@ -115,7 +115,7 @@ func (provider *OpenAIProvider) TextCompletion(ctx context.Context, model, key, // It supports both text and image content in messages. // Returns a BifrostResponse containing the completion results or an error if the request fails. func (provider *OpenAIProvider) ChatCompletion(ctx context.Context, model, key string, messages []schemas.BifrostMessage, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { - formattedMessages, preparedParams := prepareOpenAIChatRequest(model, messages, params) + formattedMessages, preparedParams := prepareOpenAIChatRequest(messages, params) requestBody := mergeConfig(map[string]interface{}{ "model": model, @@ -205,11 +205,20 @@ func (provider *OpenAIProvider) ChatCompletion(ctx context.Context, model, key s return result, nil } -func prepareOpenAIChatRequest(model string, messages []schemas.BifrostMessage, params *schemas.ModelParameters) ([]map[string]interface{}, map[string]interface{}) { +func prepareOpenAIChatRequest(messages []schemas.BifrostMessage, params *schemas.ModelParameters) ([]map[string]interface{}, map[string]interface{}) { // Format messages for OpenAI API var formattedMessages []map[string]interface{} for _, msg := range messages { - if (msg.UserMessage != nil && msg.UserMessage.ImageContent != 
nil) || (msg.ToolMessage != nil && msg.ToolMessage.ImageContent != nil) { + if msg.Role == schemas.ModelChatMessageRoleAssistant { + assistantMessage := map[string]interface{}{ + "role": msg.Role, + "content": coalesceString(msg.Content), + } + if msg.AssistantMessage != nil && msg.AssistantMessage.ToolCalls != nil { + assistantMessage["tool_calls"] = *msg.AssistantMessage.ToolCalls + } + formattedMessages = append(formattedMessages, assistantMessage) + } else if (msg.UserMessage != nil && msg.UserMessage.ImageContent != nil) || (msg.ToolMessage != nil && msg.ToolMessage.ImageContent != nil) { var messageImageContent schemas.ImageContent if msg.UserMessage != nil && msg.UserMessage.ImageContent != nil { messageImageContent = *msg.UserMessage.ImageContent @@ -247,7 +256,7 @@ func prepareOpenAIChatRequest(model string, messages []schemas.BifrostMessage, p } else { message := map[string]interface{}{ "role": msg.Role, - "content": msg.Content, + "content": coalesceString(msg.Content), } if msg.ToolMessage != nil && msg.ToolMessage.ToolCallID != nil { diff --git a/core/providers/utils.go b/core/providers/utils.go index 7fa930216b..d3ec36ba25 100644 --- a/core/providers/utils.go +++ b/core/providers/utils.go @@ -268,12 +268,45 @@ func handleProviderResponse[T any](responseBody []byte, response *T) (interface{ return rawResponse, nil } +// getRoleFromMessage extracts and validates the role from a message map. 
+func getRoleFromMessage(msg map[string]interface{}) (schemas.ModelChatMessageRole, bool) { + roleVal, exists := msg["role"] + if !exists { + return "", false // Role key doesn't exist + } + + // Try direct assertion to ModelChatMessageRole + roleAsModelType, ok := roleVal.(schemas.ModelChatMessageRole) + if ok { + return roleAsModelType, true + } + + // Try assertion to string and then convert + roleAsString, okStr := roleVal.(string) + if okStr { + return schemas.ModelChatMessageRole(roleAsString), true + } + + return "", false // Role is of an unexpected or invalid type +} + // float64Ptr creates a pointer to a float64 value. // This is a helper function for creating pointers to float64 values. func float64Ptr(f float64) *float64 { return &f } +// StrPtr creates a pointer to a string value. +// This is a helper function for creating pointers to string values. func StrPtr(s string) *string { return &s } + +// coalesceString returns the string value of a pointer to a string, or an empty string if the pointer is nil. +// This is a helper function for safely handling pointer-to-string values. 
+func coalesceString(s *string) string { + if s == nil { + return "" + } + return *s +} diff --git a/core/providers/vertex.go b/core/providers/vertex.go index b47d81c066..e5dbb08683 100644 --- a/core/providers/vertex.go +++ b/core/providers/vertex.go @@ -95,9 +95,9 @@ func (provider *VertexProvider) ChatCompletion(ctx context.Context, model, key s var preparedParams map[string]interface{} if strings.Contains(model, "claude") { - formattedMessages, preparedParams = prepareAnthropicChatRequest(model, messages, params) + formattedMessages, preparedParams = prepareAnthropicChatRequest(messages, params) } else { - formattedMessages, preparedParams = prepareOpenAIChatRequest(model, messages, params) + formattedMessages, preparedParams = prepareOpenAIChatRequest(messages, params) } requestBody := mergeConfig(map[string]interface{}{ diff --git a/core/schemas/bifrost.go b/core/schemas/bifrost.go index ed852f7ca7..59c485a220 100644 --- a/core/schemas/bifrost.go +++ b/core/schemas/bifrost.go @@ -116,16 +116,16 @@ type Tool struct { type ToolChoiceType string const ( - // ToolChoiceNone means no tool will be called - ToolChoiceNone ToolChoiceType = "none" - // ToolChoiceAuto means the model can choose whether to call a tool - ToolChoiceAuto ToolChoiceType = "auto" - // ToolChoiceAny means any tool can be called - ToolChoiceAny ToolChoiceType = "any" - // ToolChoiceTool means a specific tool must be called - ToolChoiceTool ToolChoiceType = "tool" - // ToolChoiceRequired means a tool must be called - ToolChoiceRequired ToolChoiceType = "required" + // ToolChoiceTypeNone means no tool will be called + ToolChoiceTypeNone ToolChoiceType = "none" + // ToolChoiceTypeAuto means the model can choose whether to call a tool + ToolChoiceTypeAuto ToolChoiceType = "auto" + // ToolChoiceTypeAny means any tool can be called + ToolChoiceTypeAny ToolChoiceType = "any" + // ToolChoiceTypeFunction means a specific tool must be called (converted to "tool" for Anthropic) + ToolChoiceTypeFunction 
ToolChoiceType = "function" + // ToolChoiceTypeRequired means a tool must be called + ToolChoiceTypeRequired ToolChoiceType = "required" ) // ToolChoiceFunction represents a specific function to be called. @@ -135,8 +135,8 @@ type ToolChoiceFunction struct { // ToolChoice represents how a tool should be chosen for a request. type ToolChoice struct { - Type ToolChoiceType `json:"type"` // Type of tool choice - Function ToolChoiceFunction `json:"function"` // Function to call if type is ToolChoiceTool + Type ToolChoiceType `json:"type"` // Type of tool choice + Function ToolChoiceFunction `json:"function,omitempty"` // Function to call if type is ToolChoiceTypeFunction } // BifrostMessage represents a message in a chat conversation. @@ -145,6 +145,7 @@ type BifrostMessage struct { Content *string `json:"content,omitempty"` // Embedded pointer structs - when non-nil, their exported fields are flattened into the top-level JSON object + // IMPORTANT: Only one of the following can be non-nil at a time, otherwise the JSON marshalling will override the common fields *UserMessage *ToolMessage *AssistantMessage diff --git a/core/tests/bedrock_test.go b/core/tests/bedrock_test.go index 6c9b6126ff..1e7002604f 100644 --- a/core/tests/bedrock_test.go +++ b/core/tests/bedrock_test.go @@ -25,8 +25,8 @@ func TestBedrock(t *testing.T) { ChatModel: "anthropic.claude-3-sonnet-20240229-v1:0", SetupText: true, SetupToolCalls: true, - SetupImage: true, - SetupBaseImage: false, + SetupImage: false, // bedrock does not support image URLs + SetupBaseImage: true, CustomParams: &schemas.ModelParameters{ MaxTokens: &maxTokens, }, diff --git a/core/tests/e2e_tool_test.go b/core/tests/e2e_tool_test.go new file mode 100644 index 0000000000..5353538f56 --- /dev/null +++ b/core/tests/e2e_tool_test.go @@ -0,0 +1,128 @@ +package tests + +import ( + "context" + "encoding/json" + "os" + "testing" + "time" + + "github.com/joho/godotenv" + bifrost "github.com/maximhq/bifrost/core" + 
"github.com/maximhq/bifrost/core/schemas" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestToolCallingEndToEnd(t *testing.T) { + // Load environment variables + if err := godotenv.Load(); err != nil && !os.IsNotExist(err) { + t.Fatalf("Error loading .env: %v", err) + } + + // Initialize Bifrost client + client, err := getBifrost() + require.NoError(t, err) + require.NotNil(t, client) + + ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) + defer cancel() + + provider := schemas.Bedrock + model := "anthropic.claude-3-sonnet-20240229-v1:0" + + // Step 1: User asks for weather, LLM should request tool usage + userMessage := schemas.BifrostMessage{ + Role: schemas.ModelChatMessageRoleUser, + Content: bifrost.Ptr("What's the weather in London?"), + } + + toolParams := WeatherToolParams + toolParams.ToolChoice = &schemas.ToolChoice{ + Type: schemas.ToolChoiceTypeFunction, + Function: schemas.ToolChoiceFunction{ + Name: "get_weather", + }, + } + toolParams.MaxTokens = bifrost.Ptr(1000) + + firstRequest := &schemas.BifrostRequest{ + Provider: provider, + Model: model, + Input: schemas.RequestInput{ + ChatCompletionInput: &[]schemas.BifrostMessage{userMessage}, + }, + Params: &toolParams, + } + + // Execute first request + firstResponse, bifrostErr := client.ChatCompletionRequest(ctx, firstRequest) + require.Nilf(t, bifrostErr, "First request failed: %v", bifrostErr) + require.NotNil(t, firstResponse) + require.NotEmpty(t, firstResponse.Choices) + + // Verify tool call was requested + message := firstResponse.Choices[0].Message + require.NotNil(t, message.AssistantMessage) + require.NotNil(t, message.AssistantMessage.ToolCalls) + require.Len(t, *message.AssistantMessage.ToolCalls, 1) + + toolCall := (*message.AssistantMessage.ToolCalls)[0] + // Only assert on Type if it's populated by the provider + if toolCall.Type != nil { + assert.Equal(t, "function", *toolCall.Type) + } + // Only assert on Function.Name 
if it's not nil to prevent panic + require.NotNil(t, toolCall.Function.Name, "toolCall.Function.Name should not be nil") + assert.Equal(t, "get_weather", *toolCall.Function.Name) + require.NotNil(t, toolCall.ID) + + // Verify tool arguments contain location + var params map[string]interface{} + err = json.Unmarshal([]byte(toolCall.Function.Arguments), &params) + require.NoError(t, err) + assert.Contains(t, params, "location") + + // Step 2: Simulate tool execution and provide result to LLM + toolResult := `{"temperature": "15", "unit": "celsius", "description": "Partly cloudy"}` + + conversationMessages := []schemas.BifrostMessage{ + userMessage, + message, + { + Role: schemas.ModelChatMessageRoleTool, + Content: &toolResult, + ToolMessage: &schemas.ToolMessage{ + ToolCallID: toolCall.ID, + }, + }, + } + + secondRequest := &schemas.BifrostRequest{ + Provider: provider, + Model: model, + Input: schemas.RequestInput{ + ChatCompletionInput: &conversationMessages, + }, + Params: &schemas.ModelParameters{ + MaxTokens: bifrost.Ptr(1000), + }, + } + + // Execute second request + finalResponse, bifrostErr := client.ChatCompletionRequest(ctx, secondRequest) + require.Nilf(t, bifrostErr, "Second request failed: %v", bifrostErr) + require.NotNil(t, finalResponse) + require.NotEmpty(t, finalResponse.Choices) + + // Verify final response + finalMessage := finalResponse.Choices[0].Message + require.NotNil(t, finalMessage.Content) + + content := *finalMessage.Content + assert.Contains(t, content, "London", "Response should mention London") + assert.Contains(t, content, "15", "Response should mention temperature") + assert.Contains(t, content, "cloudy", "Response should mention weather description") + + t.Logf("Final response: %s", content) +} diff --git a/core/tests/openai_test.go b/core/tests/openai_test.go index f748e08ca6..7dc7fa8c7c 100644 --- a/core/tests/openai_test.go +++ b/core/tests/openai_test.go @@ -21,7 +21,7 @@ func TestOpenAI(t *testing.T) { TextModel: "gpt-4o-mini",
ChatModel: "gpt-4o-mini", SetupText: false, // OpenAI does not support text completion - SetupToolCalls: false, + SetupToolCalls: true, SetupImage: false, SetupBaseImage: false, Fallbacks: []schemas.Fallback{ diff --git a/core/tests/tests.go b/core/tests/tests.go index 9830f63fb7..e439cfbea3 100644 --- a/core/tests/tests.go +++ b/core/tests/tests.go @@ -149,6 +149,7 @@ func setupChatCompletionRequests(bifrostClient *bifrost.Bifrost, config TestConf Content: &msg, }, } + result, err := bifrostClient.ChatCompletionRequest(ctx, &schemas.BifrostRequest{ Provider: config.Provider, Model: config.ChatModel, @@ -195,24 +196,26 @@ func setupImageTests(bifrostClient *bifrost.Bifrost, config TestConfig, ctx cont }, } - wg.Add(1) - go func() { - defer wg.Done() - result, err := bifrostClient.ChatCompletionRequest(ctx, &schemas.BifrostRequest{ - Provider: config.Provider, - Model: config.ChatModel, - Input: schemas.RequestInput{ - ChatCompletionInput: &urlImageMessages, - }, - Params: &params, - Fallbacks: config.Fallbacks, - }) - if err != nil { - log.Println("Error in", config.Provider, "URL image request:", err.Error.Message) - } else { - log.Println("🐒", config.Provider, "URL Image Result:", *result.Choices[0].Message.Content) - } - }() + if config.SetupImage { + wg.Add(1) + go func() { + defer wg.Done() + result, err := bifrostClient.ChatCompletionRequest(ctx, &schemas.BifrostRequest{ + Provider: config.Provider, + Model: config.ChatModel, + Input: schemas.RequestInput{ + ChatCompletionInput: &urlImageMessages, + }, + Params: &params, + Fallbacks: config.Fallbacks, + }) + if err != nil { + log.Println("Error in", config.Provider, "URL image request:", err.Error.Message) + } else { + log.Println("🐒", config.Provider, "URL Image Result:", *result.Choices[0].Message.Content) + } + }() + } // Base64 image test (only for providers that support it) if config.SetupBaseImage { @@ -345,7 +348,7 @@ func SetupAllRequests(bifrostClient *bifrost.Bifrost, config TestConfig) {
setupChatCompletionRequests(bifrostClient, config, ctx, &wg) - if config.SetupImage { + if config.SetupImage || config.SetupBaseImage { setupImageTests(bifrostClient, config, ctx, &wg) } diff --git a/core/tests/vertex_test.go b/core/tests/vertex_test.go index c580284442..89265aae57 100644 --- a/core/tests/vertex_test.go +++ b/core/tests/vertex_test.go @@ -20,7 +20,7 @@ func TestVertex(t *testing.T) { Provider: schemas.Vertex, ChatModel: "google/gemini-2.0-flash-001", SetupText: false, // Vertex does not support text completion - SetupToolCalls: false, + SetupToolCalls: true, SetupImage: false, SetupBaseImage: false, } diff --git a/transports/bifrost-http/integrations/genai/types.go b/transports/bifrost-http/integrations/genai/types.go index ce89a9b134..924b5badac 100644 --- a/transports/bifrost-http/integrations/genai/types.go +++ b/transports/bifrost-http/integrations/genai/types.go @@ -3,6 +3,7 @@ package genai import ( "encoding/json" + bifrost "github.com/maximhq/bifrost/core" "github.com/maximhq/bifrost/core/schemas" genai_sdk "google.golang.org/genai" ) @@ -21,13 +22,13 @@ func (r *GeminiChatRequest) ConvertToBifrostRequest(modelStr string) *schemas.Bi Provider: schemas.Vertex, Model: modelStr, Input: schemas.RequestInput{ - ChatCompletionInput: &[]schemas.Message{}, + ChatCompletionInput: &[]schemas.BifrostMessage{}, }, } // Convert messages (contents) for _, content := range r.Contents { - var bifrostMsg schemas.Message + var bifrostMsg schemas.BifrostMessage bifrostMsg.Role = schemas.ModelChatMessageRole(content.Role) if len(content.Parts) > 0 { @@ -37,15 +38,16 @@ func (r *GeminiChatRequest) ConvertToBifrostRequest(modelStr string) *schemas.Bi bifrostMsg.Content = &part.Text case part.FunctionCall != nil: - toolCalls := []schemas.Tool{ + jsonArgs, err := json.Marshal(part.FunctionCall.Args) + if err != nil { + jsonArgs = []byte("{}") + } + toolCalls := []schemas.ToolCall{ { - Type: "function", - Function: schemas.Function{ - Name: part.FunctionCall.Name, 
- Parameters: schemas.FunctionParameters{ - Type: "object", - Properties: part.FunctionCall.Args, - }, + Type: bifrost.Ptr(string(schemas.ToolChoiceTypeFunction)), + Function: schemas.FunctionCall{ + Name: &part.FunctionCall.Name, + Arguments: string(jsonArgs), }, }, } diff --git a/transports/bifrost-http/main.go b/transports/bifrost-http/main.go index 5b4ba76c1a..8e953ace17 100644 --- a/transports/bifrost-http/main.go +++ b/transports/bifrost-http/main.go @@ -91,7 +91,7 @@ func init() { // It includes all necessary fields for both types of completions. type CompletionRequest struct { Provider schemas.ModelProvider `json:"provider"` // The AI model provider to use - Messages []schemas.Message `json:"messages"` // Chat messages (for chat completion) + Messages []schemas.BifrostMessage `json:"messages"` // Chat messages (for chat completion) Text string `json:"text"` // Text input (for text completion) Model string `json:"model"` // Model to use Params *schemas.ModelParameters `json:"params"` // Additional model parameters @@ -215,7 +215,7 @@ func main() { log.Fatalf("failed to start server: %v", err) } - client.Shutdown() + client.Cleanup() } // handleCompletion processes both text and chat completion requests. diff --git a/transports/go.mod b/transports/go.mod index 1b755b961a..7403c1320e 100644 --- a/transports/go.mod +++ b/transports/go.mod @@ -4,7 +4,7 @@ go 1.24.1 require ( github.com/fasthttp/router v1.5.4 - github.com/maximhq/bifrost/core v1.0.7 + github.com/maximhq/bifrost/core v1.1.0 github.com/maximhq/bifrost/plugins/maxim v1.0.2 github.com/prometheus/client_golang v1.22.0 github.com/valyala/fasthttp v1.62.0