diff --git a/internal/extproc/translator/gemini_helper.go b/internal/extproc/translator/gemini_helper.go index 18c3a7dfb3..f171ad3979 100644 --- a/internal/extproc/translator/gemini_helper.go +++ b/internal/extproc/translator/gemini_helper.go @@ -6,11 +6,19 @@ package translator import ( + "encoding/json" "fmt" + "mime" + "net/url" + "path" "strconv" "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3" + "github.com/google/uuid" + "google.golang.org/genai" + + "github.com/envoyproxy/ai-gateway/internal/apischema/openai" ) const ( @@ -20,6 +28,455 @@ const ( HTTPHeaderKeyContentLength = "Content-Length" ) +// ------------------------------------------------------------- +// Request Conversion Helper for OpenAI to GCP Gemini Translator +// -------------------------------------------------------------. + +// openAIMessagesToGeminiContents converts OpenAI messages to Gemini Contents and SystemInstruction. +func openAIMessagesToGeminiContents(messages []openai.ChatCompletionMessageParamUnion) ([]genai.Content, *genai.Content, error) { + var gcpContents []genai.Content + var systemInstruction *genai.Content + knownToolCalls := make(map[string]string) + var gcpParts []*genai.Part + + for _, msgUnion := range messages { + switch msgUnion.Type { + case openai.ChatMessageRoleDeveloper: + msg := msgUnion.Value.(openai.ChatCompletionDeveloperMessageParam) + inst, err := developerMsgToGeminiParts(msg) + if err != nil { + return nil, nil, fmt.Errorf("error converting developer message: %w", err) + } + if len(inst) != 0 { + if systemInstruction == nil { + systemInstruction = &genai.Content{} + } + systemInstruction.Parts = append(systemInstruction.Parts, inst...) 
+		}
+		case openai.ChatMessageRoleSystem:
+			msg := msgUnion.Value.(openai.ChatCompletionSystemMessageParam)
+			devMsg := systemMsgToDeveloperMsg(msg)
+			inst, err := developerMsgToGeminiParts(devMsg)
+			if err != nil {
+				return nil, nil, fmt.Errorf("error converting system message: %w", err)
+			}
+			if len(inst) != 0 {
+				if systemInstruction == nil {
+					systemInstruction = &genai.Content{}
+				}
+				systemInstruction.Parts = append(systemInstruction.Parts, inst...)
+			}
+		case openai.ChatMessageRoleUser:
+			msg := msgUnion.Value.(openai.ChatCompletionUserMessageParam)
+			parts, err := userMsgToGeminiParts(msg)
+			if err != nil {
+				return nil, nil, fmt.Errorf("error converting user message: %w", err)
+			}
+			gcpParts = append(gcpParts, parts...)
+		case openai.ChatMessageRoleTool:
+			msg := msgUnion.Value.(openai.ChatCompletionToolMessageParam)
+			part, err := toolMsgToGeminiParts(msg, knownToolCalls)
+			if err != nil {
+				return nil, nil, fmt.Errorf("error converting tool message: %w", err)
+			}
+			gcpParts = append(gcpParts, part)
+		case openai.ChatMessageRoleAssistant:
+			// Flush any accumulated user/tool parts before assistant.
+			if len(gcpParts) > 0 {
+				gcpContents = append(gcpContents, genai.Content{Role: genai.RoleUser, Parts: gcpParts})
+				gcpParts = nil
+			}
+			msg := msgUnion.Value.(openai.ChatCompletionAssistantMessageParam)
+			assistantParts, toolCalls, err := assistantMsgToGeminiParts(msg)
+			if err != nil {
+				return nil, nil, fmt.Errorf("error converting assistant message: %w", err)
+			}
+			for k, v := range toolCalls {
+				knownToolCalls[k] = v
+			}
+			gcpContents = append(gcpContents, genai.Content{Role: genai.RoleModel, Parts: assistantParts})
+		default:
+			return nil, nil, fmt.Errorf("invalid role in message: %s", msgUnion.Type)
+		}
+	}
+
+	// If there are any remaining parts after processing all messages, add them as user content.
+ if len(gcpParts) > 0 { + gcpContents = append(gcpContents, genai.Content{Role: genai.RoleUser, Parts: gcpParts}) + } + return gcpContents, systemInstruction, nil +} + +// systemMsgToDeveloperMsg converts OpenAI system message to developer message. +// Since systemMsg is deprecated, this function is provided to maintain backward compatibility. +func systemMsgToDeveloperMsg(msg openai.ChatCompletionSystemMessageParam) openai.ChatCompletionDeveloperMessageParam { + // Convert OpenAI system message to developer message. + return openai.ChatCompletionDeveloperMessageParam{ + Name: msg.Name, + Role: openai.ChatMessageRoleDeveloper, + Content: msg.Content, + } +} + +// developerMsgToGeminiParts converts OpenAI developer message to Gemini Content. +func developerMsgToGeminiParts(msg openai.ChatCompletionDeveloperMessageParam) ([]*genai.Part, error) { + var parts []*genai.Part + + switch contentValue := msg.Content.Value.(type) { + case string: + if contentValue != "" { + parts = append(parts, genai.NewPartFromText(contentValue)) + } + case []openai.ChatCompletionContentPartTextParam: + if len(contentValue) > 0 { + for _, textParam := range contentValue { + if textParam.Text != "" { + parts = append(parts, genai.NewPartFromText(textParam.Text)) + } + } + } + default: + return nil, fmt.Errorf("unsupported content type in developer message: %T", contentValue) + + } + return parts, nil +} + +// userMsgToGeminiParts converts OpenAI user message to Gemini Parts. 
+func userMsgToGeminiParts(msg openai.ChatCompletionUserMessageParam) ([]*genai.Part, error) {
+	var parts []*genai.Part
+	switch contentValue := msg.Content.Value.(type) {
+	case string:
+		if contentValue != "" {
+			parts = append(parts, genai.NewPartFromText(contentValue))
+		}
+	case []openai.ChatCompletionContentPartUserUnionParam:
+		for _, content := range contentValue {
+			switch {
+			case content.TextContent != nil:
+				parts = append(parts, genai.NewPartFromText(content.TextContent.Text))
+			case content.ImageContent != nil:
+				imgURL := content.ImageContent.ImageURL.URL
+				if imgURL == "" {
+					// If image URL is empty, we skip it.
+					continue
+				}
+
+				parsedURL, err := url.Parse(imgURL)
+				if err != nil {
+					return nil, fmt.Errorf("invalid image URL: %w", err)
+				}
+
+				if parsedURL.Scheme == "data" {
+					mimeType, imgBytes, err := parseDataURI(imgURL)
+					if err != nil {
+						return nil, fmt.Errorf("failed to parse data URI: %w", err)
+					}
+					parts = append(parts, genai.NewPartFromBytes(imgBytes, mimeType))
+				} else {
+					// Identify mimeType based on the URL path only, so query strings/fragments don't break extension sniffing.
+					mimeType := mimeTypeImageJPEG // Default to jpeg if unknown.
+					if mt := mime.TypeByExtension(path.Ext(parsedURL.Path)); mt != "" {
+						mimeType = mt
+					}
+
+					parts = append(parts, genai.NewPartFromURI(imgURL, mimeType))
+				}
+			case content.InputAudioContent != nil:
+				// Audio content is currently not supported in this implementation.
+				return nil, fmt.Errorf("audio content not supported yet")
+			}
+		}
+	default:
+		return nil, fmt.Errorf("unsupported content type in user message: %T", contentValue)
+	}
+	return parts, nil
+}
+
+// toolMsgToGeminiParts converts OpenAI tool message to Gemini Parts.
+func toolMsgToGeminiParts(msg openai.ChatCompletionToolMessageParam, knownToolCalls map[string]string) (*genai.Part, error) { + var part *genai.Part + name := knownToolCalls[msg.ToolCallID] + funcResponse := "" + switch contentValue := msg.Content.Value.(type) { + case string: + funcResponse = contentValue + case []openai.ChatCompletionContentPartTextParam: + for _, textParam := range contentValue { + if textParam.Text != "" { + funcResponse += textParam.Text + } + } + default: + return nil, fmt.Errorf("unsupported content type in tool message: %T", contentValue) + } + + part = genai.NewPartFromFunctionResponse(name, map[string]any{"output": funcResponse}) + return part, nil +} + +// assistantMsgToGeminiParts converts OpenAI assistant message to Gemini Parts and known tool calls. +func assistantMsgToGeminiParts(msg openai.ChatCompletionAssistantMessageParam) ([]*genai.Part, map[string]string, error) { + var parts []*genai.Part + + // Handle tool calls in the assistant message. + knownToolCalls := make(map[string]string) + for _, toolCall := range msg.ToolCalls { + knownToolCalls[toolCall.ID] = toolCall.Function.Name + var parsedArgs map[string]any + if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &parsedArgs); err != nil { + return nil, nil, fmt.Errorf("function arguments should be valid json string. failed to parse function arguments: %w", err) + } + parts = append(parts, genai.NewPartFromFunctionCall(toolCall.Function.Name, parsedArgs)) + } + + // Handle content in the assistant message. 
+ switch v := msg.Content.Value.(type) { + case string: + if v != "" { + parts = append(parts, genai.NewPartFromText(v)) + } + case []openai.ChatCompletionAssistantMessageParamContent: + for _, contPart := range v { + switch contPart.Type { + case openai.ChatCompletionAssistantMessageParamContentTypeText: + if contPart.Text != nil && *contPart.Text != "" { + parts = append(parts, genai.NewPartFromText(*contPart.Text)) + } + case openai.ChatCompletionAssistantMessageParamContentTypeRefusal: + // Refusal messages are currently ignored in this implementation. + default: + return nil, nil, fmt.Errorf("unsupported content type in assistant message: %s", contPart.Type) + } + } + case nil: + // No content provided, this is valid. + default: + return nil, nil, fmt.Errorf("unsupported content type in assistant message: %T", v) + } + + return parts, knownToolCalls, nil +} + +// openAIReqToGeminiGenerationConfig converts OpenAI request to Gemini GenerationConfig. +func openAIReqToGeminiGenerationConfig(openAIReq *openai.ChatCompletionRequest) (*genai.GenerationConfig, error) { + gc := &genai.GenerationConfig{} + if openAIReq.Temperature != nil { + f := float32(*openAIReq.Temperature) + gc.Temperature = &f + } + if openAIReq.TopP != nil { + f := float32(*openAIReq.TopP) + gc.TopP = &f + } + + if openAIReq.Seed != nil { + seed := int32(*openAIReq.Seed) // nolint:gosec + gc.Seed = &seed + } + + if openAIReq.TopLogProbs != nil { + logProbs := int32(*openAIReq.TopLogProbs) // nolint:gosec + gc.Logprobs = &logProbs + } + + if openAIReq.LogProbs != nil { + gc.ResponseLogprobs = *openAIReq.LogProbs + } + + if openAIReq.N != nil { + gc.CandidateCount = int32(*openAIReq.N) // nolint:gosec + } + if openAIReq.MaxTokens != nil { + gc.MaxOutputTokens = int32(*openAIReq.MaxTokens) // nolint:gosec + } + if openAIReq.PresencePenalty != nil { + gc.PresencePenalty = openAIReq.PresencePenalty + } + if openAIReq.FrequencyPenalty != nil { + gc.FrequencyPenalty = openAIReq.FrequencyPenalty + } + if 
len(openAIReq.Stop) > 0 { + var stops []string + for _, s := range openAIReq.Stop { + if s != nil { + stops = append(stops, *s) + } + } + gc.StopSequences = stops + } + return gc, nil +} + +// -------------------------------------------------------------- +// Response Conversion Helper for GCP Gemini to OpenAI Translator +// --------------------------------------------------------------. + +// geminiCandidatesToOpenAIChoices converts Gemini candidates to OpenAI choices. +func geminiCandidatesToOpenAIChoices(candidates []*genai.Candidate) ([]openai.ChatCompletionResponseChoice, error) { + choices := make([]openai.ChatCompletionResponseChoice, 0, len(candidates)) + + for idx, candidate := range candidates { + if candidate == nil { + continue + } + + // Create the choice. + choice := openai.ChatCompletionResponseChoice{ + Index: int64(idx), + FinishReason: geminiFinishReasonToOpenAI(candidate.FinishReason), + } + + if candidate.Content != nil { + message := openai.ChatCompletionResponseChoiceMessage{ + Role: openai.ChatMessageRoleAssistant, + } + // Extract text from parts. + content := extractTextFromGeminiParts(candidate.Content.Parts) + message.Content = &content + + // Extract tool calls if any. + toolCalls, err := extractToolCallsFromGeminiParts(candidate.Content.Parts) + if err != nil { + return nil, fmt.Errorf("error extracting tool calls: %w", err) + } + message.ToolCalls = toolCalls + + // If there's no content but there are tool calls, set content to nil. + if content == "" && len(toolCalls) > 0 { + message.Content = nil + } + + choice.Message = message + } + + // Handle logprobs if available. + if candidate.LogprobsResult != nil { + choice.Logprobs = geminiLogprobsToOpenAILogprobs(*candidate.LogprobsResult) + } + + choices = append(choices, choice) + } + + return choices, nil +} + +// geminiFinishReasonToOpenAI converts Gemini finish reason to OpenAI finish reason. 
+func geminiFinishReasonToOpenAI(reason genai.FinishReason) openai.ChatCompletionChoicesFinishReason { + switch reason { + case genai.FinishReasonStop: + return openai.ChatCompletionChoicesFinishReasonStop + case genai.FinishReasonMaxTokens: + return openai.ChatCompletionChoicesFinishReasonLength + default: + return openai.ChatCompletionChoicesFinishReasonContentFilter + } +} + +// extractTextFromGeminiParts extracts text from Gemini parts. +func extractTextFromGeminiParts(parts []*genai.Part) string { + var text string + for _, part := range parts { + if part != nil && part.Text != "" { + text += part.Text + } + } + return text +} + +// extractToolCallsFromGeminiParts extracts tool calls from Gemini parts. +func extractToolCallsFromGeminiParts(parts []*genai.Part) ([]openai.ChatCompletionMessageToolCallParam, error) { + var toolCalls []openai.ChatCompletionMessageToolCallParam + + for _, part := range parts { + if part == nil || part.FunctionCall == nil { + continue + } + + // Convert function call arguments to JSON string. + args, err := json.Marshal(part.FunctionCall.Args) + if err != nil { + return nil, fmt.Errorf("failed to marshal function arguments: %w", err) + } + + // Generate a random ID for the tool call. + toolCallID := uuid.New().String() + + toolCall := openai.ChatCompletionMessageToolCallParam{ + ID: toolCallID, + Type: "function", + Function: openai.ChatCompletionMessageToolCallFunctionParam{ + Name: part.FunctionCall.Name, + Arguments: string(args), + }, + } + + toolCalls = append(toolCalls, toolCall) + } + + if len(toolCalls) == 0 { + return nil, nil + } + + return toolCalls, nil +} + +// geminiUsageToOpenAIUsage converts Gemini usage metadata to OpenAI usage. 
+func geminiUsageToOpenAIUsage(metadata *genai.GenerateContentResponseUsageMetadata) openai.ChatCompletionResponseUsage { + if metadata == nil { + return openai.ChatCompletionResponseUsage{} + } + + return openai.ChatCompletionResponseUsage{ + CompletionTokens: int(metadata.CandidatesTokenCount), + PromptTokens: int(metadata.PromptTokenCount), + TotalTokens: int(metadata.TotalTokenCount), + } +} + +// geminiLogprobsToOpenAILogprobs converts Gemini logprobs to OpenAI logprobs. +func geminiLogprobsToOpenAILogprobs(logprobsResult genai.LogprobsResult) openai.ChatCompletionChoicesLogprobs { + if len(logprobsResult.ChosenCandidates) == 0 { + return openai.ChatCompletionChoicesLogprobs{} + } + + content := make([]openai.ChatCompletionTokenLogprob, 0, len(logprobsResult.ChosenCandidates)) + + for i := 0; i < len(logprobsResult.ChosenCandidates); i++ { + chosen := logprobsResult.ChosenCandidates[i] + + var topLogprobs []openai.ChatCompletionTokenLogprobTopLogprob + + // Process top candidates if available. + if i < len(logprobsResult.TopCandidates) && logprobsResult.TopCandidates[i] != nil { + topCandidates := logprobsResult.TopCandidates[i].Candidates + if len(topCandidates) > 0 { + topLogprobs = make([]openai.ChatCompletionTokenLogprobTopLogprob, 0, len(topCandidates)) + for _, tc := range topCandidates { + topLogprobs = append(topLogprobs, openai.ChatCompletionTokenLogprobTopLogprob{ + Token: tc.Token, + Logprob: float64(tc.LogProbability), + }) + } + } + } + + // Create token logprob. + tokenLogprob := openai.ChatCompletionTokenLogprob{ + Token: chosen.Token, + Logprob: float64(chosen.LogProbability), + TopLogprobs: topLogprobs, + } + + content = append(content, tokenLogprob) + } + + // Return the logprobs. 
+ return openai.ChatCompletionChoicesLogprobs{ + Content: content, + } +} + func buildGCPModelPathSuffix(publisher, model, gcpMethod string) string { pathSuffix := fmt.Sprintf("publishers/%s/models/%s:%s", publisher, model, gcpMethod) return pathSuffix @@ -29,21 +486,27 @@ func buildGCPModelPathSuffix(publisher, model, gcpMethod string) string { // It sets the ":path" header, the "content-length" header and the request body. func buildGCPRequestMutations(path string, reqBody []byte) (*ext_procv3.HeaderMutation, *ext_procv3.BodyMutation) { var bodyMutation *ext_procv3.BodyMutation + var headerMutation *ext_procv3.HeaderMutation // Create header mutation. - headerMutation := &ext_procv3.HeaderMutation{ - SetHeaders: []*corev3.HeaderValueOption{ - { - Header: &corev3.HeaderValue{ - Key: ":path", - RawValue: []byte(path), + if len(path) != 0 { + headerMutation = &ext_procv3.HeaderMutation{ + SetHeaders: []*corev3.HeaderValueOption{ + { + Header: &corev3.HeaderValue{ + Key: ":path", + RawValue: []byte(path), + }, }, }, - }, + } } // If the request body is not empty, we set the content-length header and create a body mutation. if len(reqBody) != 0 { + if headerMutation == nil { + headerMutation = &ext_procv3.HeaderMutation{} + } // Set the "content-length" header. 
headerMutation.SetHeaders = append(headerMutation.SetHeaders, &corev3.HeaderValueOption{ Header: &corev3.HeaderValue{ @@ -56,7 +519,6 @@ func buildGCPRequestMutations(path string, reqBody []byte) (*ext_procv3.HeaderMu bodyMutation = &ext_procv3.BodyMutation{ Mutation: &ext_procv3.BodyMutation_Body{Body: reqBody}, } - } return headerMutation, bodyMutation diff --git a/internal/extproc/translator/gemini_helper_test.go b/internal/extproc/translator/gemini_helper_test.go new file mode 100644 index 0000000000..6b89e9d5e1 --- /dev/null +++ b/internal/extproc/translator/gemini_helper_test.go @@ -0,0 +1,786 @@ +// Copyright Envoy AI Gateway Authors +// SPDX-License-Identifier: Apache-2.0 +// The full text of the Apache license is available in the LICENSE file at +// the root of the repo. + +package translator + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/genai" + "k8s.io/utils/ptr" + + "github.com/envoyproxy/ai-gateway/internal/apischema/openai" +) + +func TestOpenAIMessagesToGeminiContents(t *testing.T) { + tests := []struct { + name string + messages []openai.ChatCompletionMessageParamUnion + expectedErrorMsg string + expectedContents []genai.Content + expectedSystemInstruction *genai.Content + }{ + { + name: "happy-path", + messages: []openai.ChatCompletionMessageParamUnion{ + { + Type: openai.ChatMessageRoleDeveloper, + Value: openai.ChatCompletionDeveloperMessageParam{ + Role: openai.ChatMessageRoleDeveloper, + Content: openai.StringOrArray{Value: "This is a developer message"}, + }, + }, + { + Type: openai.ChatMessageRoleSystem, + Value: openai.ChatCompletionSystemMessageParam{ + Role: openai.ChatMessageRoleSystem, + Content: openai.StringOrArray{Value: "This is a system message"}, + }, + }, + { + Type: openai.ChatMessageRoleUser, + Value: openai.ChatCompletionUserMessageParam{ + Role: openai.ChatMessageRoleUser, + 
Content: openai.StringOrUserRoleContentUnion{Value: "This is a user message"}, + }, + }, + { + Type: openai.ChatMessageRoleAssistant, + Value: openai.ChatCompletionAssistantMessageParam{ + Role: openai.ChatMessageRoleAssistant, + Audio: openai.ChatCompletionAssistantMessageParamAudio{}, + Content: openai.StringOrAssistantRoleContentUnion{Value: "This is a assistant message"}, + ToolCalls: []openai.ChatCompletionMessageToolCallParam{ + { + ID: "tool_call_1", + Function: openai.ChatCompletionMessageToolCallFunctionParam{ + Name: "example_tool", + Arguments: "{\"param1\":\"value1\"}", + }, + Type: openai.ChatCompletionMessageToolCallTypeFunction, + }, + }, + }, + }, + { + Type: openai.ChatMessageRoleTool, + Value: openai.ChatCompletionToolMessageParam{ + ToolCallID: "tool_call_1", + Content: openai.StringOrArray{Value: "This is a message from the example_tool"}, + }, + }, + }, + expectedContents: []genai.Content{ + { + Parts: []*genai.Part{ + {Text: "This is a user message"}, + }, + Role: genai.RoleUser, + }, + { + Role: genai.RoleModel, + Parts: []*genai.Part{ + { + FunctionCall: &genai.FunctionCall{ + Name: "example_tool", + Args: map[string]any{ + "param1": "value1", + }, + }, + }, + {Text: "This is a assistant message"}, + }, + }, + { + Role: genai.RoleUser, + Parts: []*genai.Part{ + { + FunctionResponse: &genai.FunctionResponse{ + Name: "example_tool", + Response: map[string]any{ + "output": "This is a message from the example_tool", + }, + }, + }, + }, + }, + }, + expectedSystemInstruction: &genai.Content{ + Parts: []*genai.Part{ + {Text: "This is a developer message"}, + {Text: "This is a system message"}, + }, + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + contents, systemInstruction, err := openAIMessagesToGeminiContents(tc.messages) + + if tc.expectedErrorMsg != "" || err != nil { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.expectedErrorMsg) + } else { + if d := cmp.Diff(tc.expectedContents, contents); d 
!= "" { + t.Errorf("Gemini Contents mismatch (-want +got):\n%s", d) + } + if d := cmp.Diff(tc.expectedSystemInstruction, systemInstruction); d != "" { + t.Errorf("SystemInstruction mismatch (-want +got):\n%s", d) + } + } + }) + } +} + +// TestAssistantMsgToGeminiParts tests the assistantMsgToGeminiParts function. +func TestAssistantMsgToGeminiParts(t *testing.T) { + tests := []struct { + name string + msg openai.ChatCompletionAssistantMessageParam + expectedParts []*genai.Part + expectedToolCalls map[string]string + expectedErrorMsg string + }{ + { + name: "empty text content", + msg: openai.ChatCompletionAssistantMessageParam{ + Content: openai.StringOrAssistantRoleContentUnion{ + Value: "", + }, + Role: openai.ChatMessageRoleAssistant, + }, + expectedParts: nil, + expectedToolCalls: map[string]string{}, + }, + { + name: "invalid content type", + msg: openai.ChatCompletionAssistantMessageParam{ + Content: openai.StringOrAssistantRoleContentUnion{ + Value: 10, // Invalid type. + }, + Role: openai.ChatMessageRoleAssistant, + }, + expectedParts: nil, + expectedToolCalls: map[string]string{}, + expectedErrorMsg: "unsupported content type in assistant message: int", + }, + { + name: "simple text content", + msg: openai.ChatCompletionAssistantMessageParam{ + Content: openai.StringOrAssistantRoleContentUnion{ + Value: "Hello, I'm an AI assistant", + }, + Role: openai.ChatMessageRoleAssistant, + }, + expectedParts: []*genai.Part{ + genai.NewPartFromText("Hello, I'm an AI assistant"), + }, + expectedToolCalls: map[string]string{}, + }, + // Currently noting is returned for refusal messages. 
+ { + name: "text content with refusal message", + msg: openai.ChatCompletionAssistantMessageParam{ + Content: openai.StringOrAssistantRoleContentUnion{ + Value: []openai.ChatCompletionAssistantMessageParamContent{ + { + Type: openai.ChatCompletionAssistantMessageParamContentTypeRefusal, + Refusal: ptr.To("Response was refused"), + }, + }, + }, + Role: openai.ChatMessageRoleAssistant, + }, + expectedParts: nil, + expectedToolCalls: map[string]string{}, + }, + { + name: "content with an array of texts", + msg: openai.ChatCompletionAssistantMessageParam{ + Content: openai.StringOrAssistantRoleContentUnion{ + Value: []openai.ChatCompletionAssistantMessageParamContent{ + { + Type: openai.ChatCompletionAssistantMessageParamContentTypeText, + Text: ptr.To("Hello, I'm an AI assistant"), + }, + { + Type: openai.ChatCompletionAssistantMessageParamContentTypeText, + Text: ptr.To("How can I assist you today?"), + }, + }, + }, + Role: openai.ChatMessageRoleAssistant, + }, + expectedParts: []*genai.Part{ + genai.NewPartFromText("Hello, I'm an AI assistant"), + genai.NewPartFromText("How can I assist you today?"), + }, + expectedToolCalls: map[string]string{}, + }, + { + name: "tool calls without content", + msg: openai.ChatCompletionAssistantMessageParam{ + Content: openai.StringOrAssistantRoleContentUnion{ + Value: "", + }, + Role: openai.ChatMessageRoleAssistant, + ToolCalls: []openai.ChatCompletionMessageToolCallParam{ + { + ID: "call_123", + Function: openai.ChatCompletionMessageToolCallFunctionParam{ + Name: "get_weather", + Arguments: `{"location":"New York","unit":"celsius"}`, + }, + Type: openai.ChatCompletionMessageToolCallTypeFunction, + }, + }, + }, + expectedParts: []*genai.Part{ + { + FunctionCall: &genai.FunctionCall{ + Args: map[string]any{"location": "New York", "unit": "celsius"}, + Name: "get_weather", + }, + }, + }, + expectedToolCalls: map[string]string{ + "call_123": "get_weather", + }, + }, + { + name: "multiple tool calls with content", + msg: 
openai.ChatCompletionAssistantMessageParam{ + Content: openai.StringOrAssistantRoleContentUnion{ + Value: "I'll help you with that", + }, + Role: openai.ChatMessageRoleAssistant, + ToolCalls: []openai.ChatCompletionMessageToolCallParam{ + { + ID: "call_789", + Function: openai.ChatCompletionMessageToolCallFunctionParam{ + Name: "get_weather", + Arguments: `{"location":"New York","unit":"celsius"}`, + }, + Type: openai.ChatCompletionMessageToolCallTypeFunction, + }, + { + ID: "call_abc", + Function: openai.ChatCompletionMessageToolCallFunctionParam{ + Name: "get_time", + Arguments: `{"timezone":"EST"}`, + }, + Type: openai.ChatCompletionMessageToolCallTypeFunction, + }, + }, + }, + expectedParts: []*genai.Part{ + genai.NewPartFromFunctionCall("get_weather", map[string]any{ + "location": "New York", + "unit": "celsius", + }), + genai.NewPartFromFunctionCall("get_time", map[string]any{ + "timezone": "EST", + }), + genai.NewPartFromText("I'll help you with that"), + }, + expectedToolCalls: map[string]string{ + "call_789": "get_weather", + "call_abc": "get_time", + }, + }, + { + name: "invalid tool call arguments", + msg: openai.ChatCompletionAssistantMessageParam{ + Role: openai.ChatMessageRoleAssistant, + ToolCalls: []openai.ChatCompletionMessageToolCallParam{ + { + ID: "call_def", + Function: openai.ChatCompletionMessageToolCallFunctionParam{ + Name: "get_weather", + Arguments: `{"location":"New York"`, // Invalid JSON. 
+ }, + Type: openai.ChatCompletionMessageToolCallTypeFunction, + }, + }, + }, + expectedErrorMsg: "function arguments should be valid json string", + }, + { + name: "nil content", + msg: openai.ChatCompletionAssistantMessageParam{ + Role: openai.ChatMessageRoleAssistant, + }, + expectedParts: nil, + expectedToolCalls: map[string]string{}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + parts, toolCalls, err := assistantMsgToGeminiParts(tc.msg) + + if tc.expectedErrorMsg != "" || err != nil { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.expectedErrorMsg) + } else { + require.NoError(t, err) + if d := cmp.Diff(tc.expectedParts, parts); d != "" { + t.Errorf("Parts mismatch (-want +got):\n%s", d) + } + if d := cmp.Diff(tc.expectedToolCalls, toolCalls); d != "" { + t.Errorf("Tools mismatch (-want +got):\n%s", d) + } + } + }) + } +} + +func TestDeveloperMsgToGeminiParts(t *testing.T) { + tests := []struct { + name string + msg openai.ChatCompletionDeveloperMessageParam + expectedParts []*genai.Part + expectedErrorMsg string + }{ + { + name: "string content", + msg: openai.ChatCompletionDeveloperMessageParam{ + Content: openai.StringOrArray{ + Value: "This is a system message", + }, + Role: openai.ChatMessageRoleSystem, + }, + expectedParts: []*genai.Part{ + {Text: "This is a system message"}, + }, + }, + { + name: "content as string array", + msg: openai.ChatCompletionDeveloperMessageParam{ + Content: openai.StringOrArray{ + Value: []openai.ChatCompletionContentPartTextParam{ + {Text: "This is a system message"}, + {Text: "It can be multiline"}, + }, + }, + Role: openai.ChatMessageRoleSystem, + }, + expectedParts: []*genai.Part{ + {Text: "This is a system message"}, + {Text: "It can be multiline"}, + }, + }, + { + name: "invalid content type", + msg: openai.ChatCompletionDeveloperMessageParam{ + Content: openai.StringOrArray{ + Value: 10, // Invalid type. 
+ }, + Role: openai.ChatMessageRoleSystem, + }, + expectedParts: []*genai.Part{ + {Text: "This is a system message"}, + }, + expectedErrorMsg: "unsupported content type in developer message: int", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + content, err := developerMsgToGeminiParts(tc.msg) + + if tc.expectedErrorMsg != "" || err != nil { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.expectedErrorMsg) + } else { + require.NoError(t, err) + if d := cmp.Diff(tc.expectedParts, content); d != "" { + t.Errorf("Content mismatch (-want +got):\n%s", d) + } + } + }) + } +} + +func TestToolMsgToGeminiParts(t *testing.T) { + tests := []struct { + name string + msg openai.ChatCompletionToolMessageParam + knownToolCalls map[string]string + expectedPart *genai.Part + expectedErrorMsg string + }{ + { + name: "Tool message with invalid content", + msg: openai.ChatCompletionToolMessageParam{ + Content: openai.StringOrArray{ + Value: 10, // Invalid type. + }, + Role: openai.ChatMessageRoleTool, + ToolCallID: "tool_123", + }, + knownToolCalls: map[string]string{"tool_123": "get_weather"}, + expectedErrorMsg: "unsupported content type in tool message: int", + }, + { + name: "Tool message with string content", + msg: openai.ChatCompletionToolMessageParam{ + Content: openai.StringOrArray{ + Value: "This is a tool message", + }, + Role: openai.ChatMessageRoleTool, + ToolCallID: "tool_123", + }, + knownToolCalls: map[string]string{"tool_123": "get_weather"}, + expectedPart: &genai.Part{ + FunctionResponse: &genai.FunctionResponse{ + Name: "get_weather", + Response: map[string]interface{}{"output": "This is a tool message"}, + }, + }, + }, + { + name: "Tool message with string array content", + msg: openai.ChatCompletionToolMessageParam{ + Content: openai.StringOrArray{ + Value: []openai.ChatCompletionContentPartTextParam{ + { + Type: string(openai.ChatCompletionContentPartTextTypeText), + Text: "This is a tool message. 
", + }, + { + Type: string(openai.ChatCompletionContentPartTextTypeText), + Text: "And this is another part", + }, + }, + }, + Role: openai.ChatMessageRoleTool, + ToolCallID: "tool_123", + }, + knownToolCalls: map[string]string{"tool_123": "get_weather"}, + expectedPart: &genai.Part{ + FunctionResponse: &genai.FunctionResponse{ + Name: "get_weather", + Response: map[string]interface{}{"output": "This is a tool message. And this is another part"}, + }, + }, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + parts, err := toolMsgToGeminiParts(tc.msg, tc.knownToolCalls) + + if tc.expectedErrorMsg != "" || err != nil { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.expectedErrorMsg) + } else { + require.NoError(t, err) + if d := cmp.Diff(tc.expectedPart, parts); d != "" { + t.Errorf("Parts mismatch (-want +got):\n%s", d) + } + } + }) + } +} + +// TestUserMsgToGeminiParts tests the gcpPartsFromUserMsgToGeminiParts function with different inputs. +func TestUserMsgToGeminiParts(t *testing.T) { + tests := []struct { + name string + msg openai.ChatCompletionUserMessageParam + expectedParts []*genai.Part + expectedErrMsg string + }{ + { + name: "simple string content", + msg: openai.ChatCompletionUserMessageParam{ + Role: openai.ChatMessageRoleUser, + Content: openai.StringOrUserRoleContentUnion{ + Value: "Hello, how are you?", + }, + }, + expectedParts: []*genai.Part{ + {Text: "Hello, how are you?"}, + }, + }, + { + name: "empty string content", + msg: openai.ChatCompletionUserMessageParam{ + Role: openai.ChatMessageRoleUser, + Content: openai.StringOrUserRoleContentUnion{ + Value: "", + }, + }, + expectedParts: nil, + }, + { + name: "array with multiple text contents", + msg: openai.ChatCompletionUserMessageParam{ + Role: openai.ChatMessageRoleUser, + Content: openai.StringOrUserRoleContentUnion{ + Value: []openai.ChatCompletionContentPartUserUnionParam{ + { + TextContent: &openai.ChatCompletionContentPartTextParam{ + Type: 
string(openai.ChatCompletionContentPartTextTypeText), + Text: "First message", + }, + }, + { + TextContent: &openai.ChatCompletionContentPartTextParam{ + Type: string(openai.ChatCompletionContentPartTextTypeText), + Text: "Second message", + }, + }, + }, + }, + }, + expectedParts: []*genai.Part{ + {Text: "First message"}, + {Text: "Second message"}, + }, + }, + { + name: "image content with URL", + msg: openai.ChatCompletionUserMessageParam{ + Role: openai.ChatMessageRoleUser, + Content: openai.StringOrUserRoleContentUnion{ + Value: []openai.ChatCompletionContentPartUserUnionParam{ + { + ImageContent: &openai.ChatCompletionContentPartImageParam{ + Type: openai.ChatCompletionContentPartImageTypeImageURL, + ImageURL: openai.ChatCompletionContentPartImageImageURLParam{ + URL: "https://example.com/image.jpg", + }, + }, + }, + }, + }, + }, + expectedParts: []*genai.Part{ + {FileData: &genai.FileData{FileURI: "https://example.com/image.jpg", MIMEType: "image/jpeg"}}, + }, + }, + { + name: "empty image URL", + msg: openai.ChatCompletionUserMessageParam{ + Role: openai.ChatMessageRoleUser, + Content: openai.StringOrUserRoleContentUnion{ + Value: []openai.ChatCompletionContentPartUserUnionParam{ + { + ImageContent: &openai.ChatCompletionContentPartImageParam{ + Type: openai.ChatCompletionContentPartImageTypeImageURL, + ImageURL: openai.ChatCompletionContentPartImageImageURLParam{ + URL: "", + }, + }, + }, + }, + }, + }, + expectedParts: nil, + }, + { + name: "invalid image URL", + msg: openai.ChatCompletionUserMessageParam{ + Role: openai.ChatMessageRoleUser, + Content: openai.StringOrUserRoleContentUnion{ + Value: []openai.ChatCompletionContentPartUserUnionParam{ + { + ImageContent: &openai.ChatCompletionContentPartImageParam{ + Type: openai.ChatCompletionContentPartImageTypeImageURL, + ImageURL: openai.ChatCompletionContentPartImageImageURLParam{ + URL: ":%invalid-url%:", + }, + }, + }, + }, + }, + }, + expectedErrMsg: "invalid image URL", + }, + { + name: "mixed content 
- text and image", + msg: openai.ChatCompletionUserMessageParam{ + Role: openai.ChatMessageRoleUser, + Content: openai.StringOrUserRoleContentUnion{ + Value: []openai.ChatCompletionContentPartUserUnionParam{ + { + TextContent: &openai.ChatCompletionContentPartTextParam{ + Type: string(openai.ChatCompletionContentPartTextTypeText), + Text: "Check this image:", + }, + }, + { + ImageContent: &openai.ChatCompletionContentPartImageParam{ + Type: openai.ChatCompletionContentPartImageTypeImageURL, + ImageURL: openai.ChatCompletionContentPartImageImageURLParam{ + URL: "https://example.com/image.jpg", + }, + }, + }, + }, + }, + }, + expectedParts: []*genai.Part{ + {Text: "Check this image:"}, + {FileData: &genai.FileData{FileURI: "https://example.com/image.jpg", MIMEType: "image/jpeg"}}, + }, + }, + { + name: "data URI image content", + msg: openai.ChatCompletionUserMessageParam{ + Role: openai.ChatMessageRoleUser, + Content: openai.StringOrUserRoleContentUnion{ + Value: []openai.ChatCompletionContentPartUserUnionParam{ + { + ImageContent: &openai.ChatCompletionContentPartImageParam{ + Type: openai.ChatCompletionContentPartImageTypeImageURL, + ImageURL: openai.ChatCompletionContentPartImageImageURLParam{ + URL: 
"data:image/jpeg;base64,/9j/4AAQSkZJRgABAQEAYABgAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/2wBDAQkJCQwLDBgNDRgyIRwhMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjL/wAARCAABAAEDASIAAhEBAxEB/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/8QAHwEAAwEBAQEBAQEBAQAAAAAAAAECAwQFBgcICQoL/8QAtREAAgECBAQDBAcFBAQAAQJ3AAECAxEEBSExBhJBUQdhcRMiMoEIFEKRobHBCSMzUvAVYnLRChYkNOEl8RcYGRomJygpKjU2Nzg5OkNERUZHSElKU1RVVldYWVpjZGVmZ2hpanN0dXZ3eHl6goOEhYaHiImKkpOUlZaXmJmaoqOkpaanqKmqsrO0tba3uLm6wsPExcbHyMnK0tPU1dbX2Nna4uPk5ebn6Onq8vP09fb3+Pn6/9oADAMBAAIRAxEAPwD3+iiigD//2Q==", + }, + }, + }, + }, + }, + }, + expectedParts: []*genai.Part{ + { + InlineData: &genai.Blob{ + Data: []byte("This field is ignored during testcase comparison"), + MIMEType: "image/jpeg", + }, + }, + }, + }, + { + name: "invalid data URI format", + msg: openai.ChatCompletionUserMessageParam{ + Role: openai.ChatMessageRoleUser, + Content: openai.StringOrUserRoleContentUnion{ + Value: []openai.ChatCompletionContentPartUserUnionParam{ + { + ImageContent: &openai.ChatCompletionContentPartImageParam{ + Type: openai.ChatCompletionContentPartImageTypeImageURL, + ImageURL: openai.ChatCompletionContentPartImageImageURLParam{ + URL: "data:invalid-format", + }, + }, + }, + }, + }, + }, + expectedErrMsg: "data uri does not have a valid format", + }, + { + name: "audio content - not supported", + msg: openai.ChatCompletionUserMessageParam{ + Role: openai.ChatMessageRoleUser, + Content: openai.StringOrUserRoleContentUnion{ + Value: []openai.ChatCompletionContentPartUserUnionParam{ + { + InputAudioContent: &openai.ChatCompletionContentPartInputAudioParam{ + Type: "audio", + }, + }, + }, + }, + }, + expectedErrMsg: "audio 
content not supported yet", + }, + { + name: "unsupported content type", + msg: openai.ChatCompletionUserMessageParam{ + Role: openai.ChatMessageRoleUser, + Content: openai.StringOrUserRoleContentUnion{ + Value: 42, // not a string or array. + }, + }, + expectedErrMsg: "unsupported content type in user message: int", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + parts, err := userMsgToGeminiParts(tc.msg) + + if tc.expectedErrMsg != "" || err != nil { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.expectedErrMsg) + } else { + if d := cmp.Diff(tc.expectedParts, parts, cmpopts.IgnoreFields(genai.Blob{}, "Data")); d != "" { + t.Errorf("Parts mismatch (-want +got):\n%s", d) + } + } + }) + } +} + +func TestOpenAIReqToGeminiGenerationConfig(t *testing.T) { + tests := []struct { + name string + input *openai.ChatCompletionRequest + expects *genai.GenerationConfig + wantErr bool + }{ + { + name: "all fields set", + input: &openai.ChatCompletionRequest{ + Temperature: ptr.To(0.7), + TopP: ptr.To(0.9), + Seed: ptr.To(42), + TopLogProbs: ptr.To(3), + LogProbs: ptr.To(true), + N: ptr.To(2), + MaxTokens: ptr.To(int64(256)), + PresencePenalty: ptr.To(float32(1.1)), + FrequencyPenalty: ptr.To(float32(0.5)), + Stop: []*string{ptr.To("stop1"), ptr.To("stop2")}, + }, + expects: &genai.GenerationConfig{ + Temperature: ptr.To(float32(0.7)), + TopP: ptr.To(float32(0.9)), + Seed: ptr.To(int32(42)), + Logprobs: ptr.To(int32(3)), + ResponseLogprobs: true, + CandidateCount: 2, + MaxOutputTokens: 256, + PresencePenalty: ptr.To(float32(1.1)), + FrequencyPenalty: ptr.To(float32(0.5)), + StopSequences: []string{"stop1", "stop2"}, + }, + wantErr: false, + }, + { + name: "minimal fields", + input: &openai.ChatCompletionRequest{}, + expects: &genai.GenerationConfig{}, + wantErr: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got, err := openAIReqToGeminiGenerationConfig(tc.input) + if tc.wantErr { + if err == 
nil { + t.Errorf("expected error but got nil") + } + return + } + if err != nil { + t.Errorf("unexpected error: %v", err) + return + } + if diff := cmp.Diff(tc.expects, got, cmpopts.IgnoreUnexported(genai.GenerationConfig{})); diff != "" { + t.Errorf("GenerationConfig mismatch (-want +got):\n%s", diff) + } + }) + } +} diff --git a/internal/extproc/translator/openai_awsbedrock.go b/internal/extproc/translator/openai_awsbedrock.go index ad170d2c0e..2f57c6d685 100644 --- a/internal/extproc/translator/openai_awsbedrock.go +++ b/internal/extproc/translator/openai_awsbedrock.go @@ -7,11 +7,9 @@ package translator import ( "bytes" - "encoding/base64" "encoding/json" "fmt" "io" - "regexp" "strconv" "strings" @@ -160,25 +158,6 @@ func (o *openAIToAWSBedrockTranslatorV1ChatCompletion) openAIToolsToBedrockToolC return nil } -// regDataURI follows the web uri regex definition. -// https://developer.mozilla.org/en-US/docs/Web/URI/Schemes/data#syntax -var regDataURI = regexp.MustCompile(`\Adata:(.+?)?(;base64)?,`) - -// parseDataURI parse data uri example: data:image/jpeg;base64,/9j/4AAQSkZJRgABAgAAZABkAAD. -func parseDataURI(uri string) (string, []byte, error) { - matches := regDataURI.FindStringSubmatch(uri) - if len(matches) != 3 { - return "", nil, fmt.Errorf("data uri does not have a valid format") - } - l := len(matches[0]) - contentType := matches[1] - bin, err := base64.StdEncoding.DecodeString(uri[l:]) - if err != nil { - return "", nil, err - } - return contentType, bin, nil -} - // openAIMessageToBedrockMessageRoleUser converts openai user role message. 
func (o *openAIToAWSBedrockTranslatorV1ChatCompletion) openAIMessageToBedrockMessageRoleUser( openAiMessage *openai.ChatCompletionUserMessageParam, role string, @@ -208,13 +187,13 @@ func (o *openAIToAWSBedrockTranslatorV1ChatCompletion) openAIMessageToBedrockMes } var format string switch contentType { - case "image/png": + case mimeTypeImagePNG: format = "png" - case "image/jpeg": + case mimeTypeImageJPEG: format = "jpeg" - case "image/gif": + case mimeTypeImageGIF: format = "gif" - case "image/webp": + case mimeTypeImageWEBP: format = "webp" default: return nil, fmt.Errorf("unsupported image type: %s please use one of [png, jpeg, gif, webp]", diff --git a/internal/extproc/translator/openai_gcpvertexai.go b/internal/extproc/translator/openai_gcpvertexai.go index 657cad446b..029f100e41 100644 --- a/internal/extproc/translator/openai_gcpvertexai.go +++ b/internal/extproc/translator/openai_gcpvertexai.go @@ -6,10 +6,15 @@ package translator import ( + "encoding/json" + "fmt" "io" + "strconv" extprocv3 "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3" + "google.golang.org/genai" + "github.com/envoyproxy/ai-gateway/internal/apischema/gcp" "github.com/envoyproxy/ai-gateway/internal/apischema/openai" ) @@ -23,15 +28,21 @@ type openAIToGCPVertexAITranslatorV1ChatCompletion struct{} // RequestBody implements [Translator.RequestBody] for GCP Gemini. // This method translates an OpenAI ChatCompletion request to a GCP Gemini API request. 
-func (o *openAIToGCPVertexAITranslatorV1ChatCompletion) RequestBody(_ []byte, openAIReq *openai.ChatCompletionRequest, onRetry bool) ( +func (o *openAIToGCPVertexAITranslatorV1ChatCompletion) RequestBody(_ []byte, openAIReq *openai.ChatCompletionRequest, _ bool) ( headerMutation *extprocv3.HeaderMutation, bodyMutation *extprocv3.BodyMutation, err error, ) { - _, _ = openAIReq, onRetry pathSuffix := buildGCPModelPathSuffix(GCPModelPublisherGoogle, openAIReq.Model, GCPMethodGenerateContent) - // TODO: Implement actual translation from OpenAI to Gemini request. + gcpReq, err := o.openAIMessageToGeminiMessage(openAIReq) + if err != nil { + return nil, nil, fmt.Errorf("error converting OpenAI request to Gemini request: %w", err) + } + gcpReqBody, err := json.Marshal(gcpReq) + if err != nil { + return nil, nil, fmt.Errorf("error marshaling Gemini request: %w", err) + } - headerMutation, bodyMutation = buildGCPRequestMutations(pathSuffix, nil) + headerMutation, bodyMutation = buildGCPRequestMutations(pathSuffix, gcpReqBody) return headerMutation, bodyMutation, nil } @@ -46,10 +57,92 @@ func (o *openAIToGCPVertexAITranslatorV1ChatCompletion) ResponseHeaders(headers // ResponseBody implements [Translator.ResponseBody] for GCP Gemini. // This method translates a GCP Gemini API response to the OpenAI ChatCompletion format. -func (o *openAIToGCPVertexAITranslatorV1ChatCompletion) ResponseBody(respHeaders map[string]string, body io.Reader, endOfStream bool) ( +func (o *openAIToGCPVertexAITranslatorV1ChatCompletion) ResponseBody(respHeaders map[string]string, body io.Reader, _ bool) ( headerMutation *extprocv3.HeaderMutation, bodyMutation *extprocv3.BodyMutation, tokenUsage LLMTokenUsage, err error, ) { - // TODO: Implement response body translation from GCP Gemini to OpenAI format. 
- _, _, _ = respHeaders, body, endOfStream - return nil, nil, LLMTokenUsage{}, nil + if statusStr, ok := respHeaders[statusHeaderName]; ok { + var status int + if status, err = strconv.Atoi(statusStr); err == nil { + if !isGoodStatusCode(status) { + // TODO: Parse GCP error response and convert to OpenAI error format. + // For now, just return error response as-is. + return nil, nil, LLMTokenUsage{}, err + } + } + } + + // Parse the GCP response. + var gcpResp genai.GenerateContentResponse + if err = json.NewDecoder(body).Decode(&gcpResp); err != nil { + return nil, nil, LLMTokenUsage{}, fmt.Errorf("error decoding GCP response: %w", err) + } + + var openAIRespBytes []byte + // Convert to OpenAI format. + openAIResp, err := o.geminiResponseToOpenAIMessage(gcpResp) + if err != nil { + return nil, nil, LLMTokenUsage{}, fmt.Errorf("error converting GCP response to OpenAI format: %w", err) + } + + // Marshal the OpenAI response. + openAIRespBytes, err = json.Marshal(openAIResp) + if err != nil { + return nil, nil, LLMTokenUsage{}, fmt.Errorf("error marshaling OpenAI response: %w", err) + } + + // Update token usage if available. + var usage LLMTokenUsage + if gcpResp.UsageMetadata != nil { + usage = LLMTokenUsage{ + InputTokens: uint32(gcpResp.UsageMetadata.PromptTokenCount), // nolint:gosec + OutputTokens: uint32(gcpResp.UsageMetadata.CandidatesTokenCount), // nolint:gosec + TotalTokens: uint32(gcpResp.UsageMetadata.TotalTokenCount), // nolint:gosec + } + } + + headerMutation, bodyMutation = buildGCPRequestMutations("", openAIRespBytes) + + return headerMutation, bodyMutation, usage, nil +} + +// openAIMessageToGeminiMessage converts an OpenAI ChatCompletionRequest to a GCP Gemini GenerateContentRequest. +func (o *openAIToGCPVertexAITranslatorV1ChatCompletion) openAIMessageToGeminiMessage(openAIReq *openai.ChatCompletionRequest) (gcp.GenerateContentRequest, error) { + // Convert OpenAI messages to Gemini Contents and SystemInstruction. 
+ contents, systemInstruction, err := openAIMessagesToGeminiContents(openAIReq.Messages) + if err != nil { + return gcp.GenerateContentRequest{}, err + } + + // Convert generation config. + generationConfig, err := openAIReqToGeminiGenerationConfig(openAIReq) + if err != nil { + return gcp.GenerateContentRequest{}, fmt.Errorf("error converting generation config: %w", err) + } + + gcr := gcp.GenerateContentRequest{ + Contents: contents, + Tools: nil, + ToolConfig: nil, + GenerationConfig: generationConfig, + SystemInstruction: systemInstruction, + } + + return gcr, nil +} + +func (o *openAIToGCPVertexAITranslatorV1ChatCompletion) geminiResponseToOpenAIMessage(gcr genai.GenerateContentResponse) (openai.ChatCompletionResponse, error) { + // Convert candidates to OpenAI choices. + choices, err := geminiCandidatesToOpenAIChoices(gcr.Candidates) + if err != nil { + return openai.ChatCompletionResponse{}, fmt.Errorf("error converting choices: %w", err) + } + + // Set up the OpenAI response. + openaiResp := openai.ChatCompletionResponse{ + Choices: choices, + Object: "chat.completion", + Usage: geminiUsageToOpenAIUsage(gcr.UsageMetadata), + } + + return openaiResp, nil } diff --git a/internal/extproc/translator/openai_gcpvertexai_test.go b/internal/extproc/translator/openai_gcpvertexai_test.go index cbb5b4de06..e398e68115 100644 --- a/internal/extproc/translator/openai_gcpvertexai_test.go +++ b/internal/extproc/translator/openai_gcpvertexai_test.go @@ -7,6 +7,7 @@ package translator import ( "bytes" + "encoding/json" "testing" corev3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" @@ -20,13 +21,36 @@ import ( ) func TestOpenAIToGCPVertexAITranslatorV1ChatCompletion_RequestBody(t *testing.T) { + wantBdy := []byte(`{ + "contents": [ + { + "parts": [ + { + "text": "Tell me about AI Gateways" + } + ], + "role": "user" + } + ], + "tools": null, + "generation_config": {}, + "system_instruction": { + "parts": [ + { + "text": "You are a helpful assistant" + } + ] + } 
+} +`) + tests := []struct { name string input openai.ChatCompletionRequest onRetry bool wantError bool wantHeaderMut *extprocv3.HeaderMutation - wantBodyMut *extprocv3.BodyMutation + wantBody *extprocv3.BodyMutation }{ { name: "basic request", @@ -63,9 +87,19 @@ func TestOpenAIToGCPVertexAITranslatorV1ChatCompletion_RequestBody(t *testing.T) RawValue: []byte("publishers/google/models/gemini-pro:generateContent"), }, }, + { + Header: &corev3.HeaderValue{ + Key: "Content-Length", + RawValue: []byte("185"), + }, + }, + }, + }, + wantBody: &extprocv3.BodyMutation{ + Mutation: &extprocv3.BodyMutation_Body{ + Body: wantBdy, }, }, - wantBodyMut: nil, }, } @@ -83,7 +117,7 @@ func TestOpenAIToGCPVertexAITranslatorV1ChatCompletion_RequestBody(t *testing.T) t.Errorf("HeaderMutation mismatch (-want +got):\n%s", diff) } - if diff := cmp.Diff(tc.wantBodyMut, bodyMut); diff != "" { + if diff := cmp.Diff(tc.wantBody, bodyMut, bodyMutTransformer(t)); diff != "" { t.Errorf("BodyMutation mismatch (-want +got):\n%s", diff) } }) @@ -118,7 +152,7 @@ func TestOpenAIToGCPVertexAITranslatorV1ChatCompletion_ResponseHeaders(t *testin } require.NoError(t, err) - if diff := cmp.Diff(tc.wantHeaderMut, headerMut); diff != "" { + if diff := cmp.Diff(tc.wantHeaderMut, headerMut, cmpopts.IgnoreUnexported(extprocv3.HeaderMutation{}, corev3.HeaderValueOption{}, corev3.HeaderValue{})); diff != "" { t.Errorf("HeaderMutation mismatch (-want +got):\n%s", diff) } }) @@ -158,56 +192,68 @@ func TestOpenAIToGCPVertexAITranslatorV1ChatCompletion_ResponseBody(t *testing.T "promptFeedback": { "safetyRatings": [] }, - "usage": { - "promptTokens": 10, - "candidatesTokens": 15, - "totalTokens": 25 + "usageMetadata": { + "promptTokenCount": 10, + "candidatesTokenCount": 15, + "totalTokenCount": 25 } }`, - endOfStream: true, - wantError: false, - wantHeaderMut: nil, - wantBodyMut: nil, + endOfStream: true, + wantError: false, + wantHeaderMut: &extprocv3.HeaderMutation{ + SetHeaders: []*corev3.HeaderValueOption{{ + 
Header: &corev3.HeaderValue{Key: "Content-Length", RawValue: []byte("270")}, + }}, + }, + wantBodyMut: &extprocv3.BodyMutation{ + Mutation: &extprocv3.BodyMutation_Body{ + Body: []byte(`{ + "choices": [ + { + "finish_reason": "stop", + "index": 0, + "logprobs": {}, + "message": { + "content": "AI Gateways act as intermediaries between clients and LLM services.", + "role": "assistant" + } + } + ], + "object": "chat.completion", + "usage": { + "completion_tokens": 15, + "prompt_tokens": 10, + "total_tokens": 25 + } +}`), + }, + }, wantTokenUsage: LLMTokenUsage{ - InputTokens: 0, - OutputTokens: 0, - TotalTokens: 0, + InputTokens: 10, + OutputTokens: 15, + TotalTokens: 25, }, }, { - name: "streaming chunk", + name: "empty response", respHeaders: map[string]string{ "content-type": "application/json", }, - body: `{ - "candidates": [ + body: `{}`, + endOfStream: true, + wantError: false, + wantHeaderMut: &extprocv3.HeaderMutation{ + SetHeaders: []*corev3.HeaderValueOption{ { - "content": { - "parts": [ - { - "text": "AI" - } - ] - } - } - ] - }`, - endOfStream: false, - wantError: false, - wantHeaderMut: nil, - wantBodyMut: nil, - wantTokenUsage: LLMTokenUsage{}, - }, - { - name: "empty response", - respHeaders: map[string]string{ - "content-type": "application/json", + Header: &corev3.HeaderValue{Key: "Content-Length", RawValue: []byte("39")}, + }, + }, + }, + wantBodyMut: &extprocv3.BodyMutation{ + Mutation: &extprocv3.BodyMutation_Body{ + Body: []byte(`{"object":"chat.completion","usage":{}}`), + }, }, - body: `{}`, - endOfStream: true, - wantError: false, - wantHeaderMut: nil, - wantBodyMut: nil, wantTokenUsage: LLMTokenUsage{}, }, } @@ -223,11 +269,11 @@ func TestOpenAIToGCPVertexAITranslatorV1ChatCompletion_ResponseBody(t *testing.T } require.NoError(t, err) - if diff := cmp.Diff(tc.wantHeaderMut, headerMut); diff != "" { + if diff := cmp.Diff(tc.wantHeaderMut, headerMut, cmpopts.IgnoreUnexported(extprocv3.HeaderMutation{}, corev3.HeaderValueOption{}, 
corev3.HeaderValue{})); diff != "" { t.Errorf("HeaderMutation mismatch (-want +got):\n%s", diff) } - if diff := cmp.Diff(tc.wantBodyMut, bodyMut); diff != "" { + if diff := cmp.Diff(tc.wantBodyMut, bodyMut, bodyMutTransformer(t)); diff != "" { t.Errorf("BodyMutation mismatch (-want +got):\n%s", diff) } @@ -237,3 +283,21 @@ func TestOpenAIToGCPVertexAITranslatorV1ChatCompletion_ResponseBody(t *testing.T }) } } + +func bodyMutTransformer(t *testing.T) cmp.Option { + return cmp.Transformer("BodyMutationsToBodyBytes", func(bm *extprocv3.BodyMutation) map[string]interface{} { + if bm == nil { + return nil + } + + var bdy map[string]interface{} + if body, ok := bm.Mutation.(*extprocv3.BodyMutation_Body); ok { + if err := json.Unmarshal(body.Body, &bdy); err != nil { + t.Errorf("error unmarshaling body: %v", err) + return nil + } + return bdy + } + return nil + }) +} diff --git a/internal/extproc/translator/util.go b/internal/extproc/translator/util.go new file mode 100644 index 0000000000..4605009ff2 --- /dev/null +++ b/internal/extproc/translator/util.go @@ -0,0 +1,38 @@ +// Copyright Envoy AI Gateway Authors +// SPDX-License-Identifier: Apache-2.0 +// The full text of the Apache license is available in the LICENSE file at +// the root of the repo. + +package translator + +import ( + "encoding/base64" + "fmt" + "regexp" +) + +const ( + mimeTypeImageJPEG = "image/jpeg" + mimeTypeImagePNG = "image/png" + mimeTypeImageGIF = "image/gif" + mimeTypeImageWEBP = "image/webp" +) + +// regDataURI follows the web uri regex definition. +// https://developer.mozilla.org/en-US/docs/Web/URI/Schemes/data#syntax +var regDataURI = regexp.MustCompile(`\Adata:(.+?)?(;base64)?,`) + +// parseDataURI parse data uri example: data:image/jpeg;base64,/9j/4AAQSkZJRgABAgAAZABkAAD. 
+func parseDataURI(uri string) (string, []byte, error) { + matches := regDataURI.FindStringSubmatch(uri) + if len(matches) != 3 { + return "", nil, fmt.Errorf("data uri does not have a valid format") + } + l := len(matches[0]) + contentType := matches[1] + bin, err := base64.StdEncoding.DecodeString(uri[l:]) + if err != nil { + return "", nil, err + } + return contentType, bin, nil +} diff --git a/tests/extproc/envoy.yaml b/tests/extproc/envoy.yaml index 8e7a8f8838..63ec747d09 100644 --- a/tests/extproc/envoy.yaml +++ b/tests/extproc/envoy.yaml @@ -188,6 +188,14 @@ static_resources: exact: azure-openai route: cluster: testupstream-azure + - match: + prefix: "/" + headers: + - name: x-test-backend + string_match: + exact: gcp-vertexai + route: + cluster: testupstream-gcp-vertexai http_filters: - name: envoy.filters.http.ext_proc typed_config: @@ -556,6 +564,66 @@ static_resources: typed_config: "@type": type.googleapis.com/envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext auto_host_sni: true + - name: testupstream-gcp-vertexai + connect_timeout: 0.25s + type: STATIC + lb_policy: ROUND_ROBIN + outlier_detection: + consecutive_5xx: 1 + interval: 1s + base_ejection_time: 2s # Must be smaller than the require.Eventually's interval. Otherwise, the tests may pass without going through the fallback since the always-failing backend could be ejected by the time when require.Eventually retries due to the previous request IF the retry is not configured. 
+ max_ejection_percent: 100 + typed_extension_protocol_options: + envoy.extensions.upstreams.http.v3.HttpProtocolOptions: + "@type": type.googleapis.com/envoy.extensions.upstreams.http.v3.HttpProtocolOptions + explicit_http_config: + http_protocol_options: {} + http_filters: + - name: upstream_extproc + typed_config: + "@type": type.googleapis.com/envoy.extensions.filters.http.ext_proc.v3.ExternalProcessor + allow_mode_override: true + request_attributes: + - xds.upstream_host_metadata + processing_mode: + request_header_mode: "SEND" + request_body_mode: "NONE" + response_header_mode: "SKIP" + response_body_mode: "NONE" + grpc_service: + envoy_grpc: + cluster_name: extproc_cluster + metadataOptions: + receivingNamespaces: + untyped: + - ai_gateway_llm_ns + - name: envoy.filters.http.header_mutation + typed_config: + "@type": type.googleapis.com/envoy.extensions.filters.http.header_mutation.v3.HeaderMutation + mutations: + request_mutations: + - append: + append_action: ADD_IF_ABSENT + header: + key: content-length + value: "%DYNAMIC_METADATA(ai_gateway_llm_ns:content_length)%" + - name: envoy.filters.http.upstream_codec + typed_config: + "@type": type.googleapis.com/envoy.extensions.filters.http.upstream_codec.v3.UpstreamCodec + load_assignment: + cluster_name: testupstream-gcp-vertexai + endpoints: + - priority: 0 + lb_endpoints: + - endpoint: + address: + socket_address: + address: 127.0.0.1 + port_value: 8080 + metadata: + filter_metadata: + aigateway.envoy.io: + per_route_rule_backend_name: "testupstream-gcp-vertexai" - name: openai connect_timeout: 30s type: STRICT_DNS diff --git a/tests/extproc/extproc_test.go b/tests/extproc/extproc_test.go index 893c6578e6..dcfff31f87 100644 --- a/tests/extproc/extproc_test.go +++ b/tests/extproc/extproc_test.go @@ -30,6 +30,7 @@ const ( listenerAddress = "http://localhost:1062" eventuallyTimeout = 60 * time.Second eventuallyInterval = 4 * time.Second + fakeGCPAuthToken = "fake-gcp-auth-token" //nolint:gosec ) func 
TestMain(m *testing.M) { @@ -67,16 +68,22 @@ var ( openAISchema = filterapi.VersionedAPISchema{Name: filterapi.APISchemaOpenAI, Version: "v1"} awsBedrockSchema = filterapi.VersionedAPISchema{Name: filterapi.APISchemaAWSBedrock} azureOpenAISchema = filterapi.VersionedAPISchema{Name: filterapi.APISchemaAzureOpenAI, Version: "2025-01-01-preview"} + gcpVertexAISchema = filterapi.VersionedAPISchema{Name: filterapi.APISchemaGCPVertexAI} geminiSchema = filterapi.VersionedAPISchema{Name: filterapi.APISchemaOpenAI, Version: "v1beta/openai"} groqSchema = filterapi.VersionedAPISchema{Name: filterapi.APISchemaOpenAI, Version: "openai/v1"} grokSchema = filterapi.VersionedAPISchema{Name: filterapi.APISchemaOpenAI, Version: "v1"} sambaNovaSchema = filterapi.VersionedAPISchema{Name: filterapi.APISchemaOpenAI, Version: "v1"} deepInfraSchema = filterapi.VersionedAPISchema{Name: filterapi.APISchemaOpenAI, Version: "v1/openai"} - testUpstreamOpenAIBackend = filterapi.Backend{Name: "testupstream-openai", Schema: openAISchema} - testUpstreamModelNameOverride = filterapi.Backend{Name: "testupstream-modelname-override", ModelNameOverride: "override-model", Schema: openAISchema} - testUpstreamAAWSBackend = filterapi.Backend{Name: "testupstream-aws", Schema: awsBedrockSchema} - testUpstreamAzureBackend = filterapi.Backend{Name: "testupstream-azure", Schema: azureOpenAISchema} + testUpstreamOpenAIBackend = filterapi.Backend{Name: "testupstream-openai", Schema: openAISchema} + testUpstreamModelNameOverride = filterapi.Backend{Name: "testupstream-modelname-override", ModelNameOverride: "override-model", Schema: openAISchema} + testUpstreamAAWSBackend = filterapi.Backend{Name: "testupstream-aws", Schema: awsBedrockSchema} + testUpstreamAzureBackend = filterapi.Backend{Name: "testupstream-azure", Schema: azureOpenAISchema} + testUpstreamGCPVertexAIBackend = filterapi.Backend{Name: "testupstream-gcp-vertexai", Schema: gcpVertexAISchema, Auth: &filterapi.BackendAuth{GCPAuth: &filterapi.GCPAuth{ + 
AccessToken: fakeGCPAuthToken, + Region: "gcp-region", + ProjectName: "gcp-project-name", + }}} // This always failing backend is configured to have AWS Bedrock schema so that // we can test that the extproc can fallback to the different schema. E.g. Primary AWS and then OpenAI. alwaysFailingBackend = filterapi.Backend{Name: "always-failing-backend", Schema: awsBedrockSchema} diff --git a/tests/extproc/testupstream_test.go b/tests/extproc/testupstream_test.go index 5a27d5a626..464702ee65 100644 --- a/tests/extproc/testupstream_test.go +++ b/tests/extproc/testupstream_test.go @@ -14,6 +14,7 @@ import ( "io" "net/http" "os" + "strconv" "strings" "testing" "time" @@ -52,6 +53,7 @@ func TestWithTestUpstream(t *testing.T) { testUpstreamModelNameOverride, testUpstreamAAWSBackend, testUpstreamAzureBackend, + testUpstreamGCPVertexAIBackend, }, Models: []filterapi.Model{ {Name: "some-model1", OwnedBy: "Envoy AI Gateway", CreatedAt: now}, @@ -94,6 +96,12 @@ func TestWithTestUpstream(t *testing.T) { responseHeaders, // expPath is the expected path to be sent to the test upstream. expPath string + // expHost is the expected host to be sent to the test upstream. + expHost string + // expHeaders are the expected headers to be sent to the test upstream. + // The value is a base64 encoded string of comma separated key-value pairs. + // E.g. "key1:value1,key2:value2". + expHeaders map[string]string // expRequestBody is the expected body to be sent to the test upstream. // This can be used to test the request body translation. 
expRequestBody string @@ -156,6 +164,21 @@ func TestWithTestUpstream(t *testing.T) { expStatus: http.StatusOK, expResponseBody: `{"choices":[{"message":{"content":"This is a test."}}]}`, }, + { + name: "gcp-vertexai - /v1/chat/completions", + backend: "gcp-vertexai", + path: "/v1/chat/completions", + method: http.MethodPost, + requestBody: `{"model":"gemini-1.5-pro","messages":[{"role":"system","content":"You are a helpful assistant."}]}`, + expRequestBody: `{"contents":null,"tools":null,"generation_config":{},"system_instruction":{"parts":[{"text":"You are a helpful assistant."}]}}`, + expHost: "gcp-region-aiplatform.googleapis.com", + expPath: "/v1/projects/gcp-project-name/locations/gcp-region/publishers/google/models/gemini-1.5-pro:generateContent", + expHeaders: map[string]string{"Authorization": "Bearer " + fakeGCPAuthToken}, + responseStatus: strconv.Itoa(http.StatusOK), + responseBody: `{"candidates":[{"content":{"parts":[{"text":"This is a test response from Gemini."}],"role":"model"},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":15,"candidatesTokenCount":10,"totalTokenCount":25}}`, + expStatus: http.StatusOK, + expResponseBody: `{"choices":[{"finish_reason":"stop","index":0,"logprobs":{},"message":{"content":"This is a test response from Gemini.","role":"assistant"}}],"object":"chat.completion","usage":{"completion_tokens":10,"prompt_tokens":15,"total_tokens":25}}`, + }, { name: "modelname-override - /v1/chat/completions", backend: "modelname-override", @@ -252,6 +275,19 @@ data: [DONE] responseBody: `{"message": "aws bedrock rate limit exceeded"}`, expResponseBody: `{"type":"error","error":{"type":"ThrottledException","code":"429","message":"aws bedrock rate limit exceeded"}}`, }, + { + name: "gcp-vertexai - /v1/chat/completions - error response", + backend: "gcp-vertexai", + path: "/v1/chat/completions", + responseType: "", + method: http.MethodPost, + requestBody: `{"model":"gemini-1.5-pro","messages":[{"role":"system","content":"You are 
a helpful assistant."}]}`, + expPath: "/v1/projects/gcp-project-name/locations/gcp-region/publishers/google/models/gemini-1.5-pro:generateContent", + responseStatus: "400", + expStatus: http.StatusBadRequest, + responseBody: `{"error":{"code":400,"message":"Invalid request: missing required field","status":"INVALID_ARGUMENT"}}`, + expResponseBody: `{"error":{"code":400,"message":"Invalid request: missing required field","status":"INVALID_ARGUMENT"}}`, + }, { name: "openai - /v1/embeddings", backend: "openai", @@ -305,6 +341,22 @@ data: [DONE] req.Header.Set(testupstreamlib.ResponseBodyHeaderKey, base64.StdEncoding.EncodeToString([]byte(tc.responseBody))) req.Header.Set(testupstreamlib.ExpectedPathHeaderKey, base64.StdEncoding.EncodeToString([]byte(tc.expPath))) req.Header.Set(testupstreamlib.ResponseStatusKey, tc.responseStatus) + + var expHeaders []string + for k, v := range tc.expHeaders { + expHeaders = append(expHeaders, fmt.Sprintf("%s:%s", k, v)) + } + if len(expHeaders) > 0 { + req.Header.Set( + testupstreamlib.ExpectedHeadersKey, + base64.StdEncoding.EncodeToString( + []byte(strings.Join(expHeaders, ","))), + ) + } + + if tc.expHost != "" { + req.Header.Set(testupstreamlib.ExpectedHostKey, tc.expHost) + } if tc.responseType != "" { req.Header.Set(testupstreamlib.ResponseTypeKey, tc.responseType) }