diff --git a/internal/apischema/openai/openai.go b/internal/apischema/openai/openai.go index c2afcb7f06..d1dd9c3a7b 100644 --- a/internal/apischema/openai/openai.go +++ b/internal/apischema/openai/openai.go @@ -1444,7 +1444,7 @@ type ChatCompletionResponseChunkChoiceDelta struct { Role string `json:"role,omitempty"` ToolCalls []ChatCompletionChunkChoiceDeltaToolCall `json:"tool_calls,omitempty"` Annotations *[]Annotation `json:"annotations,omitempty"` - ReasoningContent *AWSBedrockStreamReasoningContent `json:"reasoning_content,omitempty"` + ReasoningContent *StreamReasoningContent `json:"reasoning_content,omitempty"` } // Error is described in the OpenAI API documentation @@ -1662,7 +1662,7 @@ func (r *ReasoningContentUnion) UnmarshalJSON(data []byte) error { return nil } - var content *AWSBedrockReasoningContent + var content *ReasoningContent err = json.Unmarshal(data, &content) if err == nil { r.Value = content @@ -1675,19 +1675,20 @@ func (r ReasoningContentUnion) MarshalJSON() ([]byte, error) { if stringContent, ok := r.Value.(string); ok { return json.Marshal(stringContent) } - if reasoningContent, ok := r.Value.(*AWSBedrockReasoningContent); ok { + if reasoningContent, ok := r.Value.(*ReasoningContent); ok { return json.Marshal(reasoningContent) } return nil, errors.New("no reasoning content to marshal") } -type AWSBedrockReasoningContent struct { +// ReasoningContent is used on both aws bedrock and gemini's reasoning +type ReasoningContent struct { // See https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ReasoningContentBlock.html for more information. 
ReasoningContent *awsbedrock.ReasoningContentBlock `json:"reasoningContent,omitzero"` } -type AWSBedrockStreamReasoningContent struct { +type StreamReasoningContent struct { Text string `json:"text,omitzero"` Signature string `json:"signature,omitzero"` RedactedContent []byte `json:"redactedContent,omitzero"` diff --git a/internal/translator/gemini_helper.go b/internal/translator/gemini_helper.go index 94482d29ae..c79d157b9e 100644 --- a/internal/translator/gemini_helper.go +++ b/internal/translator/gemini_helper.go @@ -18,6 +18,7 @@ import ( openaisdk "github.com/openai/openai-go/v2" "google.golang.org/genai" + "github.com/envoyproxy/ai-gateway/internal/apischema/awsbedrock" "github.com/envoyproxy/ai-gateway/internal/apischema/openai" "github.com/envoyproxy/ai-gateway/internal/internalapi" ) @@ -640,9 +641,22 @@ func geminiCandidatesToOpenAIChoices(candidates []*genai.Candidate, responseMode message := openai.ChatCompletionResponseChoiceMessage{ Role: openai.ChatMessageRoleAssistant, } - // Extract text from parts. - content := extractTextFromGeminiParts(candidate.Content.Parts, responseMode) - message.Content = &content + // Extract thought summary and text from parts. + thoughtSummary, content := extractTextAndThoughtSummaryFromGeminiParts(candidate.Content.Parts, responseMode) + if thoughtSummary != "" { + message.ReasoningContent = &openai.ReasoningContentUnion{ + Value: &openai.ReasoningContent{ + ReasoningContent: &awsbedrock.ReasoningContentBlock{ + ReasoningText: &awsbedrock.ReasoningTextBlock{ + Text: thoughtSummary, + }, + }, + }, + } + } + if content != "" { + message.Content = &content + } // Extract tool calls if any. 
toolCalls, err = extractToolCallsFromGeminiParts(toolCalls, candidate.Content.Parts) @@ -657,6 +671,7 @@ func geminiCandidatesToOpenAIChoices(candidates []*genai.Candidate, responseMode } choice.Message = message + } if candidate.SafetyRatings != nil { @@ -704,23 +719,28 @@ func geminiFinishReasonToOpenAI[T toolCallSlice](reason genai.FinishReason, tool } } -// extractTextFromGeminiParts extracts text from Gemini parts. -func extractTextFromGeminiParts(parts []*genai.Part, responseMode geminiResponseMode) string { - var text string +// extractTextAndThoughtSummaryFromGeminiParts extracts thought summary and text from Gemini parts. +func extractTextAndThoughtSummaryFromGeminiParts(parts []*genai.Part, responseMode geminiResponseMode) (string, string) { + text := "" + thoughtSummary := "" for _, part := range parts { if part != nil && part.Text != "" { - if responseMode == responseModeRegex { - // GCP doesn't natively support REGEX response modes, so we instead express them as json schema. - // This causes the response to be wrapped in double-quotes. - // E.g. `"positive"` (the double-quotes at the start and end are unwanted) - // Here we remove the wrapping double-quotes. - part.Text = strings.TrimPrefix(part.Text, "\"") - part.Text = strings.TrimSuffix(part.Text, "\"") + if part.Thought { + thoughtSummary += part.Text + } else { + if responseMode == responseModeRegex { + // GCP doesn't natively support REGEX response modes, so we instead express them as json schema. + // This causes the response to be wrapped in double-quotes. + // E.g. `"positive"` (the double-quotes at the start and end are unwanted) + // Here we remove the wrapping double-quotes. + part.Text = strings.TrimPrefix(part.Text, "\"") + part.Text = strings.TrimSuffix(part.Text, "\"") + } + text += part.Text } - text += part.Text } } - return text + return thoughtSummary, text } // extractToolCallsFromGeminiParts extracts tool calls from Gemini parts. 
diff --git a/internal/translator/gemini_helper_test.go b/internal/translator/gemini_helper_test.go index 1262e9664a..58e410effd 100644 --- a/internal/translator/gemini_helper_test.go +++ b/internal/translator/gemini_helper_test.go @@ -1829,24 +1829,27 @@ func TestGeminiFinishReasonToOpenAI(t *testing.T) { } } -func TestExtractTextFromGeminiParts(t *testing.T) { +func TestExtractTextAndThoughtSummaryFromGeminiParts(t *testing.T) { tests := []struct { - name string - parts []*genai.Part - responseMode geminiResponseMode - expected string + name string + parts []*genai.Part + responseMode geminiResponseMode + expectedThoughtSummary string + expectedText string }{ { - name: "nil parts", - parts: nil, - responseMode: responseModeNone, - expected: "", + name: "nil parts", + parts: nil, + responseMode: responseModeNone, + expectedThoughtSummary: "", + expectedText: "", }, { - name: "empty parts", - parts: []*genai.Part{}, - responseMode: responseModeNone, - expected: "", + name: "empty parts", + parts: []*genai.Part{}, + responseMode: responseModeNone, + expectedThoughtSummary: "", + expectedText: "", }, { name: "multiple text parts without regex mode", @@ -1854,8 +1857,9 @@ func TestExtractTextFromGeminiParts(t *testing.T) { {Text: "Hello, "}, {Text: "world!"}, }, - responseMode: responseModeJSON, - expected: "Hello, world!", + responseMode: responseModeJSON, + expectedThoughtSummary: "", + expectedText: "Hello, world!", }, { name: "regex mode with mixed quoted and unquoted text", @@ -1864,40 +1868,56 @@ func TestExtractTextFromGeminiParts(t *testing.T) { {Text: `unquoted`}, {Text: `"negative"`}, }, - responseMode: responseModeRegex, - expected: "positiveunquotednegative", + responseMode: responseModeRegex, + expectedThoughtSummary: "", + expectedText: "positiveunquotednegative", }, { name: "regex mode with only double-quoted first and last words", parts: []*genai.Part{ {Text: "\"\"ERROR\" Unable to connect to database \"DatabaseModule\"\""}, }, - responseMode: 
responseModeRegex, - expected: "\"ERROR\" Unable to connect to database \"DatabaseModule\"", + responseMode: responseModeRegex, + expectedThoughtSummary: "", + expectedText: "\"ERROR\" Unable to connect to database \"DatabaseModule\"", }, { name: "non-regex mode with double-quoted text (should not remove quotes)", parts: []*genai.Part{ {Text: `"positive"`}, }, - responseMode: responseModeJSON, - expected: `"positive"`, + responseMode: responseModeJSON, + expectedThoughtSummary: "", + expectedText: `"positive"`, }, { name: "regex mode with text containing internal quotes", parts: []*genai.Part{ {Text: `"He said \"hello\" to me"`}, }, - responseMode: responseModeRegex, - expected: `He said \"hello\" to me`, + responseMode: responseModeRegex, + expectedThoughtSummary: "", + expectedText: `He said \"hello\" to me`, + }, + { + name: "test thought summary", + parts: []*genai.Part{ + {Text: "Let me think step by step", Thought: true}, + {Text: "Here is the conclusion"}, + }, + expectedThoughtSummary: "Let me think step by step", + expectedText: "Here is the conclusion", }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - result := extractTextFromGeminiParts(tc.parts, tc.responseMode) - if result != tc.expected { - t.Errorf("extractTextFromGeminiParts() = %q, want %q", result, tc.expected) + thoughtSummary, text := extractTextAndThoughtSummaryFromGeminiParts(tc.parts, tc.responseMode) + if thoughtSummary != tc.expectedThoughtSummary { + t.Errorf("thought summary result of extractTextAndThoughtSummaryFromGeminiParts() = %q, want %q", thoughtSummary, tc.expectedThoughtSummary) + } + if text != tc.expectedText { + t.Errorf("text result of extractTextAndThoughtSummaryFromGeminiParts() = %q, want %q", text, tc.expectedText) } }) } diff --git a/internal/translator/openai_awsbedrock.go b/internal/translator/openai_awsbedrock.go index df1c53e947..2d73b00d9c 100644 --- a/internal/translator/openai_awsbedrock.go +++ b/internal/translator/openai_awsbedrock.go @@ -702,7 +702,7
@@ func (o *openAIToAWSBedrockTranslatorV1ChatCompletion) ResponseBody(_ map[string } case output.ReasoningContent != nil: choice.Message.ReasoningContent = &openai.ReasoningContentUnion{ - Value: &openai.AWSBedrockReasoningContent{ + Value: &openai.ReasoningContent{ ReasoningContent: output.ReasoningContent, }, } @@ -819,7 +819,7 @@ func (o *openAIToAWSBedrockTranslatorV1ChatCompletion) convertEvent(event *awsbe }, }) case event.Delta.ReasoningContent != nil: - reasoningDelta := &openai.AWSBedrockStreamReasoningContent{} + reasoningDelta := &openai.StreamReasoningContent{} // Map all relevant fields from the Bedrock delta to our flattened OpenAI delta struct. if event.Delta.ReasoningContent != nil { diff --git a/internal/translator/openai_awsbedrock_test.go b/internal/translator/openai_awsbedrock_test.go index a70ef23bff..01d4797754 100644 --- a/internal/translator/openai_awsbedrock_test.go +++ b/internal/translator/openai_awsbedrock_test.go @@ -1665,7 +1665,7 @@ func TestOpenAIToAWSBedrockTranslatorV1ChatCompletion_ResponseBody(t *testing.T) Role: awsbedrock.ConversationRoleAssistant, Content: ptr.To("This is the final answer."), ReasoningContent: &openai.ReasoningContentUnion{ - Value: &openai.AWSBedrockReasoningContent{ + Value: &openai.ReasoningContent{ ReasoningContent: &awsbedrock.ReasoningContentBlock{ ReasoningText: &awsbedrock.ReasoningTextBlock{ Text: "This is the model's thought process.", @@ -1990,7 +1990,7 @@ func TestOpenAIToAWSBedrockTranslator_convertEvent(t *testing.T) { Choices: []openai.ChatCompletionResponseChunkChoice{ { Delta: &openai.ChatCompletionResponseChunkChoiceDelta{ - ReasoningContent: &openai.AWSBedrockStreamReasoningContent{ + ReasoningContent: &openai.StreamReasoningContent{ Text: "thinking...", }, }, @@ -2171,7 +2171,7 @@ func TestOpenAIToAWSBedrockTranslatorV1ChatCompletion_ResponseBody_WithReasoning require.Equal(t, "9.11 is greater than 9.8.", *message.Content) require.NotNil(t, message.ReasoningContent, "Reasoning content 
should not be nil") - reasoningBlock, _ := message.ReasoningContent.Value.(*openai.AWSBedrockReasoningContent) + reasoningBlock, _ := message.ReasoningContent.Value.(*openai.ReasoningContent) require.NotNil(t, reasoningBlock, "The nested reasoning content block should not be nil") require.NotEmpty(t, reasoningBlock.ReasoningContent.ReasoningText.Text, "The reasoning text itself should not be empty") diff --git a/internal/translator/openai_gcpvertexai.go b/internal/translator/openai_gcpvertexai.go index c42def23bb..3e766bb077 100644 --- a/internal/translator/openai_gcpvertexai.go +++ b/internal/translator/openai_gcpvertexai.go @@ -344,8 +344,14 @@ func (o *openAIToGCPVertexAITranslatorV1ChatCompletion) geminiCandidatesToOpenAI Role: openai.ChatMessageRoleAssistant, } - // Extract text from parts for streaming (delta). - content := extractTextFromGeminiParts(candidate.Content.Parts, responseMode) + // Extract thought summary and text from parts for streaming (delta). + thoughtSummary, content := extractTextAndThoughtSummaryFromGeminiParts(candidate.Content.Parts, responseMode) + if thoughtSummary != "" { + delta.ReasoningContent = &openai.StreamReasoningContent{ + Text: thoughtSummary, + } + } + if content != "" { delta.Content = &content } diff --git a/internal/translator/openai_gcpvertexai_test.go b/internal/translator/openai_gcpvertexai_test.go index 7a16024c6d..a60b1bcf97 100644 --- a/internal/translator/openai_gcpvertexai_test.go +++ b/internal/translator/openai_gcpvertexai_test.go @@ -1099,6 +1099,91 @@ data: [DONE] }`), wantTokenUsage: tokenUsageFrom(8, 0, 12, 20), }, + { + name: "response with thought summary", + respHeaders: map[string]string{ + "content-type": "application/json", + }, + body: `{ + "candidates": [ + { + "content": { + "parts": [ + { + "text": "Let me think step by step.", + "thought": true + }, + { + "text": "AI Gateways act as intermediaries between clients and LLM services." 
+ } + ] + }, + "finishReason": "STOP", + "safetyRatings": [] + } + ], + "promptFeedback": { + "safetyRatings": [] + }, + "usageMetadata": { + "promptTokenCount": 10, + "candidatesTokenCount": 15, + "totalTokenCount": 25, + "cachedContentTokenCount": 10, + "thoughtsTokenCount": 10 + } + }`, + endOfStream: true, + wantError: false, + wantHeaderMut: []internalapi.Header{{contentLengthHeaderName, "450"}}, + wantBodyMut: []byte(`{ + "choices": [ + { + "finish_reason": "stop", + "index": 0, + "message": { + "content": "AI Gateways act as intermediaries between clients and LLM services.", + "reasoning_content": {"reasoningContent": {"reasoningText": {"text": "Let me think step by step."}}}, + "role": "assistant" + } + } + ], + "object": "chat.completion", + "usage": { + "completion_tokens": 25, + "completion_tokens_details": { + "reasoning_tokens": 10 + }, + "prompt_tokens": 10, + "prompt_tokens_details": { + "cached_tokens": 10 + }, + "total_tokens": 25 + } +}`), + + wantTokenUsage: tokenUsageFrom(10, 10, 15, 25), + }, + { + name: "stream chunks with thought summary", + respHeaders: map[string]string{ + "content-type": "application/json", + }, + body: `data: {"candidates":[{"content":{"parts":[{"text":"let me think step by step and reply you.", "thought": true}]}}]} + +data: {"candidates":[{"content":{"parts":[{"text":"Hello"}]}}],"usageMetadata":{"promptTokenCount":5,"candidatesTokenCount":3,"totalTokenCount":8}}`, + stream: true, + endOfStream: true, + wantError: false, + wantHeaderMut: nil, + wantBodyMut: []byte(`data: {"choices":[{"index":0,"delta":{"role":"assistant","reasoning_content":{"text":"let me think step by step and reply you."}}}],"object":"chat.completion.chunk"} + +data: {"choices":[{"index":0,"delta":{"content":"Hello","role":"assistant"}}],"object":"chat.completion.chunk","usage":{"prompt_tokens":5,"completion_tokens":3,"total_tokens":8,"completion_tokens_details":{},"prompt_tokens_details":{}}} + +data: [DONE] +`), + wantTokenUsage: tokenUsageFrom(5, 
0, 3, 8), + }, } for _, tc := range tests {