Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions internal/apischema/openai/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -1444,7 +1444,7 @@ type ChatCompletionResponseChunkChoiceDelta struct {
Role string `json:"role,omitempty"`
ToolCalls []ChatCompletionChunkChoiceDeltaToolCall `json:"tool_calls,omitempty"`
Annotations *[]Annotation `json:"annotations,omitempty"`
ReasoningContent *AWSBedrockStreamReasoningContent `json:"reasoning_content,omitempty"`
ReasoningContent *StreamReasoningContent `json:"reasoning_content,omitempty"`
}

// Error is described in the OpenAI API documentation
Expand Down Expand Up @@ -1662,7 +1662,7 @@ func (r *ReasoningContentUnion) UnmarshalJSON(data []byte) error {
return nil
}

var content *AWSBedrockReasoningContent
var content *ReasoningContent
err = json.Unmarshal(data, &content)
if err == nil {
r.Value = content
Expand All @@ -1675,19 +1675,20 @@ func (r ReasoningContentUnion) MarshalJSON() ([]byte, error) {
if stringContent, ok := r.Value.(string); ok {
return json.Marshal(stringContent)
}
if reasoningContent, ok := r.Value.(*AWSBedrockReasoningContent); ok {
if reasoningContent, ok := r.Value.(*ReasoningContent); ok {
return json.Marshal(reasoningContent)
}

return nil, errors.New("no reasoning content to marshal")
}

type AWSBedrockReasoningContent struct {
// ReasoningContent represents reasoning content for both AWS Bedrock and Gemini responses.
type ReasoningContent struct {
// See https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ReasoningContentBlock.html for more information.
ReasoningContent *awsbedrock.ReasoningContentBlock `json:"reasoningContent,omitzero"`
}

type AWSBedrockStreamReasoningContent struct {
type StreamReasoningContent struct {
Text string `json:"text,omitzero"`
Signature string `json:"signature,omitzero"`
RedactedContent []byte `json:"redactedContent,omitzero"`
Expand Down
50 changes: 35 additions & 15 deletions internal/translator/gemini_helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
openaisdk "github.com/openai/openai-go/v2"
"google.golang.org/genai"

"github.com/envoyproxy/ai-gateway/internal/apischema/awsbedrock"
"github.com/envoyproxy/ai-gateway/internal/apischema/openai"
"github.com/envoyproxy/ai-gateway/internal/internalapi"
)
Expand Down Expand Up @@ -640,9 +641,22 @@ func geminiCandidatesToOpenAIChoices(candidates []*genai.Candidate, responseMode
message := openai.ChatCompletionResponseChoiceMessage{
Role: openai.ChatMessageRoleAssistant,
}
// Extract text from parts.
content := extractTextFromGeminiParts(candidate.Content.Parts, responseMode)
message.Content = &content
// Extract thought summary and text from parts.
thoughtSummary, content := extractTextAndThoughtSummaryFromGeminiParts(candidate.Content.Parts, responseMode)
if thoughtSummary != "" {
message.ReasoningContent = &openai.ReasoningContentUnion{
Value: &openai.ReasoningContent{
ReasoningContent: &awsbedrock.ReasoningContentBlock{
ReasoningText: &awsbedrock.ReasoningTextBlock{
Text: thoughtSummary,
},
},
},
}
}
if content != "" {
message.Content = &content
}

// Extract tool calls if any.
toolCalls, err = extractToolCallsFromGeminiParts(toolCalls, candidate.Content.Parts)
Expand All @@ -657,6 +671,7 @@ func geminiCandidatesToOpenAIChoices(candidates []*genai.Candidate, responseMode
}

choice.Message = message

}

if candidate.SafetyRatings != nil {
Expand Down Expand Up @@ -704,23 +719,28 @@ func geminiFinishReasonToOpenAI[T toolCallSlice](reason genai.FinishReason, tool
}
}

// extractTextFromGeminiParts extracts text from Gemini parts.
func extractTextFromGeminiParts(parts []*genai.Part, responseMode geminiResponseMode) string {
var text string
// extractTextAndThoughtSummaryFromGeminiParts extracts thought summary and text from Gemini parts.
func extractTextAndThoughtSummaryFromGeminiParts(parts []*genai.Part, responseMode geminiResponseMode) (string, string) {
text := ""
thoughtSummary := ""
for _, part := range parts {
if part != nil && part.Text != "" {
if responseMode == responseModeRegex {
// GCP doesn't natively support REGEX response modes, so we instead express them as json schema.
// This causes the response to be wrapped in double-quotes.
// E.g. `"positive"` (the double-quotes at the start and end are unwanted)
// Here we remove the wrapping double-quotes.
part.Text = strings.TrimPrefix(part.Text, "\"")
part.Text = strings.TrimSuffix(part.Text, "\"")
if part.Thought {
thoughtSummary += part.Text
} else {
if responseMode == responseModeRegex {
// GCP doesn't natively support REGEX response modes, so we instead express them as json schema.
// This causes the response to be wrapped in double-quotes.
// E.g. `"positive"` (the double-quotes at the start and end are unwanted)
// Here we remove the wrapping double-quotes.
part.Text = strings.TrimPrefix(part.Text, "\"")
part.Text = strings.TrimSuffix(part.Text, "\"")
}
text += part.Text
}
text += part.Text
}
}
return text
return thoughtSummary, text
}

// extractToolCallsFromGeminiParts extracts tool calls from Gemini parts.
Expand Down
72 changes: 46 additions & 26 deletions internal/translator/gemini_helper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1829,33 +1829,37 @@ func TestGeminiFinishReasonToOpenAI(t *testing.T) {
}
}

func TestExtractTextFromGeminiParts(t *testing.T) {
func TestExtractTextAndThoughtSummaryFromGeminiParts(t *testing.T) {
tests := []struct {
name string
parts []*genai.Part
responseMode geminiResponseMode
expected string
name string
parts []*genai.Part
responseMode geminiResponseMode
expectedThoughtSummary string
expectedText string
}{
{
name: "nil parts",
parts: nil,
responseMode: responseModeNone,
expected: "",
name: "nil parts",
parts: nil,
responseMode: responseModeNone,
expectedThoughtSummary: "",
expectedText: "",
},
{
name: "empty parts",
parts: []*genai.Part{},
responseMode: responseModeNone,
expected: "",
name: "empty parts",
parts: []*genai.Part{},
responseMode: responseModeNone,
expectedThoughtSummary: "",
expectedText: "",
},
{
name: "multiple text parts without regex mode",
parts: []*genai.Part{
{Text: "Hello, "},
{Text: "world!"},
},
responseMode: responseModeJSON,
expected: "Hello, world!",
responseMode: responseModeJSON,
expectedThoughtSummary: "",
expectedText: "Hello, world!",
},
{
name: "regex mode with mixed quoted and unquoted text",
Expand All @@ -1864,40 +1868,56 @@ func TestExtractTextFromGeminiParts(t *testing.T) {
{Text: `unquoted`},
{Text: `"negative"`},
},
responseMode: responseModeRegex,
expected: "positiveunquotednegative",
responseMode: responseModeRegex,
expectedThoughtSummary: "",
expectedText: "positiveunquotednegative",
},
{
name: "regex mode with only double-quoted first and last words",
parts: []*genai.Part{
{Text: "\"\"ERROR\" Unable to connect to database \"DatabaseModule\"\""},
},
responseMode: responseModeRegex,
expected: "\"ERROR\" Unable to connect to database \"DatabaseModule\"",
responseMode: responseModeRegex,
expectedThoughtSummary: "",
expectedText: "\"ERROR\" Unable to connect to database \"DatabaseModule\"",
},
{
name: "non-regex mode with double-quoted text (should not remove quotes)",
parts: []*genai.Part{
{Text: `"positive"`},
},
responseMode: responseModeJSON,
expected: `"positive"`,
responseMode: responseModeJSON,
expectedThoughtSummary: "",
expectedText: `"positive"`,
},
{
name: "regex mode with text containing internal quotes",
parts: []*genai.Part{
{Text: `"He said \"hello\" to me"`},
},
responseMode: responseModeRegex,
expected: `He said \"hello\" to me`,
responseMode: responseModeRegex,
expectedThoughtSummary: "",
expectedText: `He said \"hello\" to me`,
},
{
name: "test thought summary",
parts: []*genai.Part{
{Text: "Let me think step by step", Thought: true},
{Text: "Here is the conclusion"},
},
expectedThoughtSummary: "Let me think step by step",
expectedText: "Here is the conclusion",
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
result := extractTextFromGeminiParts(tc.parts, tc.responseMode)
if result != tc.expected {
t.Errorf("extractTextFromGeminiParts() = %q, want %q", result, tc.expected)
thoughtSummary, text := extractTextAndThoughtSummaryFromGeminiParts(tc.parts, tc.responseMode)
if thoughtSummary != tc.expectedThoughtSummary {
t.Errorf("thought summary result of extractTextAndThoughtSummaryFromGeminiParts() = %q, want %q", thoughtSummary, tc.expectedThoughtSummary)
}
if text != tc.expectedText {
t.Errorf("text result of extractTextAndThoughtSummaryFromGeminiParts() = %q, want %q", text, tc.expectedText)
}
})
}
Expand Down
4 changes: 2 additions & 2 deletions internal/translator/openai_awsbedrock.go
Original file line number Diff line number Diff line change
Expand Up @@ -702,7 +702,7 @@ func (o *openAIToAWSBedrockTranslatorV1ChatCompletion) ResponseBody(_ map[string
}
case output.ReasoningContent != nil:
choice.Message.ReasoningContent = &openai.ReasoningContentUnion{
Value: &openai.AWSBedrockReasoningContent{
Value: &openai.ReasoningContent{
ReasoningContent: output.ReasoningContent,
},
}
Expand Down Expand Up @@ -819,7 +819,7 @@ func (o *openAIToAWSBedrockTranslatorV1ChatCompletion) convertEvent(event *awsbe
},
})
case event.Delta.ReasoningContent != nil:
reasoningDelta := &openai.AWSBedrockStreamReasoningContent{}
reasoningDelta := &openai.StreamReasoningContent{}

// Map all relevant fields from the Bedrock delta to our flattened OpenAI delta struct.
if event.Delta.ReasoningContent != nil {
Expand Down
6 changes: 3 additions & 3 deletions internal/translator/openai_awsbedrock_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1665,7 +1665,7 @@ func TestOpenAIToAWSBedrockTranslatorV1ChatCompletion_ResponseBody(t *testing.T)
Role: awsbedrock.ConversationRoleAssistant,
Content: ptr.To("This is the final answer."),
ReasoningContent: &openai.ReasoningContentUnion{
Value: &openai.AWSBedrockReasoningContent{
Value: &openai.ReasoningContent{
ReasoningContent: &awsbedrock.ReasoningContentBlock{
ReasoningText: &awsbedrock.ReasoningTextBlock{
Text: "This is the model's thought process.",
Expand Down Expand Up @@ -1990,7 +1990,7 @@ func TestOpenAIToAWSBedrockTranslator_convertEvent(t *testing.T) {
Choices: []openai.ChatCompletionResponseChunkChoice{
{
Delta: &openai.ChatCompletionResponseChunkChoiceDelta{
ReasoningContent: &openai.AWSBedrockStreamReasoningContent{
ReasoningContent: &openai.StreamReasoningContent{
Text: "thinking...",
},
},
Expand Down Expand Up @@ -2171,7 +2171,7 @@ func TestOpenAIToAWSBedrockTranslatorV1ChatCompletion_ResponseBody_WithReasoning
require.Equal(t, "9.11 is greater than 9.8.", *message.Content)

require.NotNil(t, message.ReasoningContent, "Reasoning content should not be nil")
reasoningBlock, _ := message.ReasoningContent.Value.(*openai.AWSBedrockReasoningContent)
reasoningBlock, _ := message.ReasoningContent.Value.(*openai.ReasoningContent)
require.NotNil(t, reasoningBlock, "The nested reasoning content block should not be nil")
require.NotEmpty(t, reasoningBlock.ReasoningContent.ReasoningText.Text, "The reasoning text itself should not be empty")

Expand Down
10 changes: 8 additions & 2 deletions internal/translator/openai_gcpvertexai.go
Original file line number Diff line number Diff line change
Expand Up @@ -344,8 +344,14 @@ func (o *openAIToGCPVertexAITranslatorV1ChatCompletion) geminiCandidatesToOpenAI
Role: openai.ChatMessageRoleAssistant,
}

// Extract text from parts for streaming (delta).
content := extractTextFromGeminiParts(candidate.Content.Parts, responseMode)
// Extract thought summary and text from parts for streaming (delta).
thoughtSummary, content := extractTextAndThoughtSummaryFromGeminiParts(candidate.Content.Parts, responseMode)
if thoughtSummary != "" {
delta.ReasoningContent = &openai.StreamReasoningContent{
Text: thoughtSummary,
}
}

if content != "" {
delta.Content = &content
}
Expand Down
85 changes: 85 additions & 0 deletions internal/translator/openai_gcpvertexai_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1099,6 +1099,91 @@ data: [DONE]
}`),
wantTokenUsage: tokenUsageFrom(8, 0, 12, 20),
},
{
name: "response with thought summary",
respHeaders: map[string]string{
"content-type": "application/json",
},
body: `{
"candidates": [
{
"content": {
"parts": [
{
"text": "Let me think step by step.",
"thought": true
},
{
"text": "AI Gateways act as intermediaries between clients and LLM services."
}
]
},
"finishReason": "STOP",
"safetyRatings": []
}
],
"promptFeedback": {
"safetyRatings": []
},
"usageMetadata": {
"promptTokenCount": 10,
"candidatesTokenCount": 15,
"totalTokenCount": 25,
"cachedContentTokenCount": 10,
"thoughtsTokenCount": 10
}
}`,
endOfStream: true,
wantError: false,
wantHeaderMut: []internalapi.Header{{contentLengthHeaderName, "450"}},
wantBodyMut: []byte(`{
"choices": [
{
"finish_reason": "stop",
"index": 0,
"message": {
"content": "AI Gateways act as intermediaries between clients and LLM services.",
"reasoning_content": {"reasoningContent": {"reasoningText": {"text": "Let me think step by step."}}},
"role": "assistant"
}
}
],
"object": "chat.completion",
"usage": {
"completion_tokens": 25,
"completion_tokens_details": {
"reasoning_tokens": 10
},
"prompt_tokens": 10,
"prompt_tokens_details": {
"cached_tokens": 10
},
"total_tokens": 25
}
}`),

wantTokenUsage: tokenUsageFrom(10, 10, 15, 25),
},
{
name: "stream chunks with thought summary",
respHeaders: map[string]string{
"content-type": "application/json",
},
body: `data: {"candidates":[{"content":{"parts":[{"text":"let me think step by step and reply you.", "thought": true}]}}]}

data: {"candidates":[{"content":{"parts":[{"text":"Hello"}]}}],"usageMetadata":{"promptTokenCount":5,"candidatesTokenCount":3,"totalTokenCount":8}}`,
stream: true,
endOfStream: true,
wantError: false,
wantHeaderMut: nil,
wantBodyMut: []byte(`data: {"choices":[{"index":0,"delta":{"role":"assistant","reasoning_content":{"text":"let me think step by step and reply you."}}}],"object":"chat.completion.chunk"}

data: {"choices":[{"index":0,"delta":{"content":"Hello","role":"assistant"}}],"object":"chat.completion.chunk","usage":{"prompt_tokens":5,"completion_tokens":3,"total_tokens":8,"completion_tokens_details":{},"prompt_tokens_details":{}}}

data: [DONE]
`),
wantTokenUsage: tokenUsageFrom(5, 0, 3, 8),
},
}

for _, tc := range tests {
Expand Down
Loading