From 724fbe211f644ea0bb7cd2a7fd30e38a5eda4ea1 Mon Sep 17 00:00:00 2001 From: Chang Min Date: Wed, 4 Feb 2026 21:02:02 -0500 Subject: [PATCH 01/11] feat: ContentBlockParam fields added Signed-off-by: Chang Min --- internal/apischema/anthropic/anthropic.go | 187 ++++++++++++++++-- .../apischema/anthropic/anthropic_test.go | 168 ++++++++++++++++ 2 files changed, 343 insertions(+), 12 deletions(-) diff --git a/internal/apischema/anthropic/anthropic.go b/internal/apischema/anthropic/anthropic.go index a73511d870..abd060f043 100644 --- a/internal/apischema/anthropic/anthropic.go +++ b/internal/apischema/anthropic/anthropic.go @@ -143,18 +143,102 @@ func (m *MessageContent) MarshalJSON() ([]byte, error) { type ( // ContentBlockParam represents an element of the array content in a message. - // https://docs.claude.com/en/api/messages#body-messages-content + // https://platform.claude.com/docs/en/api/messages#body-messages-content ContentBlockParam struct { - Text *TextBlockParam - // TODO add others when we need it for observability, etc. - } - + Text *TextBlockParam + Image *ImageBlockParam + Document *DocumentBlockParam + SearchResult *SearchResultBlockParam + Thinking *ThinkingBlockParam + RedactedThinking *RedactedThinkingBlockParam + ToolUse *ToolUseBlockParam + ToolResult *ToolResultBlockParam + ServerToolUse *ServerToolUseBlockParam + WebSearchToolResult *WebSearchToolResultBlockParam + } + + // TextBlockParam represents a text content block. TextBlockParam struct { Text string `json:"text"` Type string `json:"type"` // Always "text". CacheControl any `json:"cache_control,omitempty"` Citations []any `json:"citations,omitempty"` } + + // ImageBlockParam represents an image content block. + ImageBlockParam struct { + Type string `json:"type"` // Always "image". + Source any `json:"source"` + CacheControl any `json:"cache_control,omitempty"` + } + + // DocumentBlockParam represents a document content block. 
+ DocumentBlockParam struct { + Type string `json:"type"` // Always "document". + Source any `json:"source"` + CacheControl any `json:"cache_control,omitempty"` + Citations any `json:"citations,omitempty"` + Context string `json:"context,omitempty"` + Title string `json:"title,omitempty"` + } + + // SearchResultBlockParam represents a search result content block. + SearchResultBlockParam struct { + Type string `json:"type"` // Always "search_result". + Content []TextBlockParam `json:"content"` + Source string `json:"source"` + Title string `json:"title"` + CacheControl any `json:"cache_control,omitempty"` + Citations any `json:"citations,omitempty"` + } + + // ThinkingBlockParam represents a thinking content block in a request. + ThinkingBlockParam struct { + Type string `json:"type"` // Always "thinking". + Thinking string `json:"thinking"` + Signature string `json:"signature"` + } + + // RedactedThinkingBlockParam represents a redacted thinking content block. + RedactedThinkingBlockParam struct { + Type string `json:"type"` // Always "redacted_thinking". + Data string `json:"data"` + } + + // ToolUseBlockParam represents a tool use content block in a request. + ToolUseBlockParam struct { + Type string `json:"type"` // Always "tool_use". + ID string `json:"id"` + Name string `json:"name"` + Input map[string]any `json:"input"` + CacheControl any `json:"cache_control,omitempty"` + } + + // ToolResultBlockParam represents a tool result content block. + ToolResultBlockParam struct { + Type string `json:"type"` // Always "tool_result". + ToolUseID string `json:"tool_use_id"` + Content any `json:"content,omitempty"` // string or array of content blocks. + IsError bool `json:"is_error,omitempty"` + CacheControl any `json:"cache_control,omitempty"` + } + + // ServerToolUseBlockParam represents a server tool use content block. + ServerToolUseBlockParam struct { + Type string `json:"type"` // Always "server_tool_use". 
+ ID string `json:"id"` + Name string `json:"name"` + Input map[string]any `json:"input"` + CacheControl any `json:"cache_control,omitempty"` + } + + // WebSearchToolResultBlockParam represents a web search tool result content block. + WebSearchToolResultBlockParam struct { + Type string `json:"type"` // Always "web_search_tool_result". + ToolUseID string `json:"tool_use_id"` + Content any `json:"content"` + CacheControl any `json:"cache_control,omitempty"` + } ) func (m *ContentBlockParam) UnmarshalJSON(data []byte) error { @@ -164,24 +248,103 @@ func (m *ContentBlockParam) UnmarshalJSON(data []byte) error { } switch typ.String() { case "text": - var textBlock TextBlockParam - if err := json.Unmarshal(data, &textBlock); err != nil { + var block TextBlockParam + if err := json.Unmarshal(data, &block); err != nil { return fmt.Errorf("failed to unmarshal text block: %w", err) } - m.Text = &textBlock - return nil + m.Text = &block + case "image": + var block ImageBlockParam + if err := json.Unmarshal(data, &block); err != nil { + return fmt.Errorf("failed to unmarshal image block: %w", err) + } + m.Image = &block + case "document": + var block DocumentBlockParam + if err := json.Unmarshal(data, &block); err != nil { + return fmt.Errorf("failed to unmarshal document block: %w", err) + } + m.Document = &block + case "search_result": + var block SearchResultBlockParam + if err := json.Unmarshal(data, &block); err != nil { + return fmt.Errorf("failed to unmarshal search result block: %w", err) + } + m.SearchResult = &block + case "thinking": + var block ThinkingBlockParam + if err := json.Unmarshal(data, &block); err != nil { + return fmt.Errorf("failed to unmarshal thinking block: %w", err) + } + m.Thinking = &block + case "redacted_thinking": + var block RedactedThinkingBlockParam + if err := json.Unmarshal(data, &block); err != nil { + return fmt.Errorf("failed to unmarshal redacted thinking block: %w", err) + } + m.RedactedThinking = &block + case "tool_use": + var 
block ToolUseBlockParam + if err := json.Unmarshal(data, &block); err != nil { + return fmt.Errorf("failed to unmarshal tool use block: %w", err) + } + m.ToolUse = &block + case "tool_result": + var block ToolResultBlockParam + if err := json.Unmarshal(data, &block); err != nil { + return fmt.Errorf("failed to unmarshal tool result block: %w", err) + } + m.ToolResult = &block + case "server_tool_use": + var block ServerToolUseBlockParam + if err := json.Unmarshal(data, &block); err != nil { + return fmt.Errorf("failed to unmarshal server tool use block: %w", err) + } + m.ServerToolUse = &block + case "web_search_tool_result": + var block WebSearchToolResultBlockParam + if err := json.Unmarshal(data, &block); err != nil { + return fmt.Errorf("failed to unmarshal web search tool result block: %w", err) + } + m.WebSearchToolResult = &block default: - // TODO add others when we need it for observability, etc. - // Fow now, we ignore undefined types. + // Ignore unknown types for forward compatibility. return nil } + return nil } func (m *ContentBlockParam) MarshalJSON() ([]byte, error) { if m.Text != nil { return json.Marshal(m.Text) } - // TODO add others when we need it for observability, etc. 
+ if m.Image != nil { + return json.Marshal(m.Image) + } + if m.Document != nil { + return json.Marshal(m.Document) + } + if m.SearchResult != nil { + return json.Marshal(m.SearchResult) + } + if m.Thinking != nil { + return json.Marshal(m.Thinking) + } + if m.RedactedThinking != nil { + return json.Marshal(m.RedactedThinking) + } + if m.ToolUse != nil { + return json.Marshal(m.ToolUse) + } + if m.ToolResult != nil { + return json.Marshal(m.ToolResult) + } + if m.ServerToolUse != nil { + return json.Marshal(m.ServerToolUse) + } + if m.WebSearchToolResult != nil { + return json.Marshal(m.WebSearchToolResult) + } return nil, fmt.Errorf("content block must have a defined type") } diff --git a/internal/apischema/anthropic/anthropic_test.go b/internal/apischema/anthropic/anthropic_test.go index 9c85fcbc4c..6c37810180 100644 --- a/internal/apischema/anthropic/anthropic_test.go +++ b/internal/apischema/anthropic/anthropic_test.go @@ -276,6 +276,91 @@ func TestContentBlockParam_UnmarshalJSON(t *testing.T) { jsonStr: `{"type": "text", "text": "Hello"}`, want: ContentBlockParam{Text: &TextBlockParam{Text: "Hello", Type: "text"}}, }, + { + name: "image block", + jsonStr: `{"type": "image", "source": {"type": "base64", "media_type": "image/png", "data": "abc123"}}`, + want: ContentBlockParam{Image: &ImageBlockParam{ + Type: "image", + Source: map[string]any{"type": "base64", "media_type": "image/png", "data": "abc123"}, + }}, + }, + { + name: "document block", + jsonStr: `{"type": "document", "source": {"type": "text", "data": "hello", "media_type": "text/plain"}, "context": "some context", "title": "doc title"}`, + want: ContentBlockParam{Document: &DocumentBlockParam{ + Type: "document", + Source: map[string]any{"type": "text", "data": "hello", "media_type": "text/plain"}, + Context: "some context", + Title: "doc title", + }}, + }, + { + name: "search result block", + jsonStr: `{"type": "search_result", "source": "https://example.com", "title": "Example", "content": 
[{"type": "text", "text": "result text"}]}`, + want: ContentBlockParam{SearchResult: &SearchResultBlockParam{ + Type: "search_result", + Source: "https://example.com", + Title: "Example", + Content: []TextBlockParam{{Type: "text", Text: "result text"}}, + }}, + }, + { + name: "thinking block", + jsonStr: `{"type": "thinking", "thinking": "Let me think...", "signature": "sig123"}`, + want: ContentBlockParam{Thinking: &ThinkingBlockParam{ + Type: "thinking", + Thinking: "Let me think...", + Signature: "sig123", + }}, + }, + { + name: "redacted thinking block", + jsonStr: `{"type": "redacted_thinking", "data": "redacted_data_here"}`, + want: ContentBlockParam{RedactedThinking: &RedactedThinkingBlockParam{ + Type: "redacted_thinking", + Data: "redacted_data_here", + }}, + }, + { + name: "tool use block", + jsonStr: `{"type": "tool_use", "id": "tu_123", "name": "my_tool", "input": {"query": "test"}}`, + want: ContentBlockParam{ToolUse: &ToolUseBlockParam{ + Type: "tool_use", + ID: "tu_123", + Name: "my_tool", + Input: map[string]any{"query": "test"}, + }}, + }, + { + name: "tool result block", + jsonStr: `{"type": "tool_result", "tool_use_id": "tu_123", "content": "result text", "is_error": false}`, + want: ContentBlockParam{ToolResult: &ToolResultBlockParam{ + Type: "tool_result", + ToolUseID: "tu_123", + Content: "result text", + }}, + }, + { + name: "server tool use block", + jsonStr: `{"type": "server_tool_use", "id": "stu_123", "name": "web_search", "input": {"query": "test"}}`, + want: ContentBlockParam{ServerToolUse: &ServerToolUseBlockParam{ + Type: "server_tool_use", + ID: "stu_123", + Name: "web_search", + Input: map[string]any{"query": "test"}, + }}, + }, + { + name: "web search tool result block", + jsonStr: `{"type": "web_search_tool_result", "tool_use_id": "stu_123", "content": [{"type": "web_search_result", "title": "Result", "url": "https://example.com", "encrypted_content": "enc123"}]}`, + want: ContentBlockParam{WebSearchToolResult: 
&WebSearchToolResultBlockParam{ + Type: "web_search_tool_result", + ToolUseID: "stu_123", + Content: []any{ + map[string]any{"type": "web_search_result", "title": "Result", "url": "https://example.com", "encrypted_content": "enc123"}, + }, + }}, + }, { name: "missing type", jsonStr: `{"text": "Hello"}`, @@ -314,6 +399,89 @@ func TestContentBlockParam_MarshalJSON(t *testing.T) { cbp: ContentBlockParam{Text: &TextBlockParam{Text: "Hello", Type: "text"}}, want: `{"text":"Hello","type":"text"}`, }, + { + name: "image block", + cbp: ContentBlockParam{Image: &ImageBlockParam{ + Type: "image", + Source: map[string]any{"type": "base64", "media_type": "image/png", "data": "abc123"}, + }}, + want: `{"type":"image","source":{"type":"base64","media_type":"image/png","data":"abc123"}}`, + }, + { + name: "document block", + cbp: ContentBlockParam{Document: &DocumentBlockParam{ + Type: "document", + Source: map[string]any{"type": "text", "data": "hello", "media_type": "text/plain"}, + Context: "some context", + Title: "doc title", + }}, + want: `{"type":"document","source":{"type":"text","data":"hello","media_type":"text/plain"},"context":"some context","title":"doc title"}`, + }, + { + name: "search result block", + cbp: ContentBlockParam{SearchResult: &SearchResultBlockParam{ + Type: "search_result", + Source: "https://example.com", + Title: "Example", + Content: []TextBlockParam{{Type: "text", Text: "result text"}}, + }}, + want: `{"type":"search_result","content":[{"type":"text","text":"result text"}],"source":"https://example.com","title":"Example"}`, + }, + { + name: "thinking block", + cbp: ContentBlockParam{Thinking: &ThinkingBlockParam{ + Type: "thinking", + Thinking: "Let me think...", + Signature: "sig123", + }}, + want: `{"type":"thinking","thinking":"Let me think...","signature":"sig123"}`, + }, + { + name: "redacted thinking block", + cbp: ContentBlockParam{RedactedThinking: &RedactedThinkingBlockParam{ + Type: "redacted_thinking", + Data: "redacted_data_here", + 
}}, + want: `{"type":"redacted_thinking","data":"redacted_data_here"}`, + }, + { + name: "tool use block", + cbp: ContentBlockParam{ToolUse: &ToolUseBlockParam{ + Type: "tool_use", + ID: "tu_123", + Name: "my_tool", + Input: map[string]any{"query": "test"}, + }}, + want: `{"type":"tool_use","id":"tu_123","name":"my_tool","input":{"query":"test"}}`, + }, + { + name: "tool result block", + cbp: ContentBlockParam{ToolResult: &ToolResultBlockParam{ + Type: "tool_result", + ToolUseID: "tu_123", + Content: "result text", + }}, + want: `{"type":"tool_result","tool_use_id":"tu_123","content":"result text"}`, + }, + { + name: "server tool use block", + cbp: ContentBlockParam{ServerToolUse: &ServerToolUseBlockParam{ + Type: "server_tool_use", + ID: "stu_123", + Name: "web_search", + Input: map[string]any{"query": "test"}, + }}, + want: `{"type":"server_tool_use","id":"stu_123","name":"web_search","input":{"query":"test"}}`, + }, + { + name: "web search tool result block", + cbp: ContentBlockParam{WebSearchToolResult: &WebSearchToolResultBlockParam{ + Type: "web_search_tool_result", + ToolUseID: "stu_123", + Content: "some content", + }}, + want: `{"type":"web_search_tool_result","tool_use_id":"stu_123","content":"some content"}`, + }, { name: "empty block", cbp: ContentBlockParam{}, From 49cd009b9496213b587c71c2beb72124b7b37100 Mon Sep 17 00:00:00 2001 From: Chang Min Date: Thu, 5 Feb 2026 14:37:13 -0500 Subject: [PATCH 02/11] feat: Anthropic toolunion messages api Signed-off-by: Chang Min --- internal/apischema/anthropic/anthropic.go | 139 +++++++++++++++++- .../translator/anthropic_gcpanthropic_test.go | 12 +- 2 files changed, 137 insertions(+), 14 deletions(-) diff --git a/internal/apischema/anthropic/anthropic.go b/internal/apischema/anthropic/anthropic.go index abd060f043..e61258f436 100644 --- a/internal/apischema/anthropic/anthropic.go +++ b/internal/apischema/anthropic/anthropic.go @@ -76,7 +76,7 @@ type MessagesRequest struct { // Tools is the list of tools 
available to the model. // https://docs.claude.com/en/api/messages#body-tools - Tools []Tool `json:"tools,omitempty"` + Tools []ToolUnion `json:"tools,omitempty"` // Stream indicates whether to stream the response. Stream bool `json:"stream,omitempty"` @@ -158,6 +158,7 @@ type ( } // TextBlockParam represents a text content block. + // https://platform.claude.com/docs/en/api/messages#text_block_param TextBlockParam struct { Text string `json:"text"` Type string `json:"type"` // Always "text". @@ -166,6 +167,7 @@ type ( } // ImageBlockParam represents an image content block. + // https://platform.claude.com/docs/en/api/messages#image_block_param ImageBlockParam struct { Type string `json:"type"` // Always "image". Source any `json:"source"` @@ -173,6 +175,7 @@ type ( } // DocumentBlockParam represents a document content block. + // https://platform.claude.com/docs/en/api/messages#document_block_param DocumentBlockParam struct { Type string `json:"type"` // Always "document". Source any `json:"source"` @@ -183,6 +186,7 @@ type ( } // SearchResultBlockParam represents a search result content block. + // https://platform.claude.com/docs/en/api/messages#search_result_block_param SearchResultBlockParam struct { Type string `json:"type"` // Always "search_result". Content []TextBlockParam `json:"content"` @@ -193,6 +197,7 @@ type ( } // ThinkingBlockParam represents a thinking content block in a request. + // https://platform.claude.com/docs/en/api/messages#thinking_block_param ThinkingBlockParam struct { Type string `json:"type"` // Always "thinking". Thinking string `json:"thinking"` @@ -200,12 +205,14 @@ type ( } // RedactedThinkingBlockParam represents a redacted thinking content block. + // https://platform.claude.com/docs/en/api/messages#redacted_thinking_block_param RedactedThinkingBlockParam struct { Type string `json:"type"` // Always "redacted_thinking". Data string `json:"data"` } // ToolUseBlockParam represents a tool use content block in a request. 
+ // https://platform.claude.com/docs/en/api/messages#tool_use_block_param ToolUseBlockParam struct { Type string `json:"type"` // Always "tool_use". ID string `json:"id"` @@ -215,6 +222,7 @@ type ( } // ToolResultBlockParam represents a tool result content block. + // https://platform.claude.com/docs/en/api/messages#tool_result_block_param ToolResultBlockParam struct { Type string `json:"type"` // Always "tool_result". ToolUseID string `json:"tool_use_id"` @@ -224,6 +232,7 @@ type ( } // ServerToolUseBlockParam represents a server tool use content block. + // https://platform.claude.com/docs/en/api/messages#server_tool_use_block_param ServerToolUseBlockParam struct { Type string `json:"type"` // Always "server_tool_use". ID string `json:"id"` @@ -233,6 +242,7 @@ type ( } // WebSearchToolResultBlockParam represents a web search tool result content block. + // https://platform.claude.com/docs/en/api/messages#web_search_tool_result_block_param WebSearchToolResultBlockParam struct { Type string `json:"type"` // Always "web_search_tool_result". ToolUseID string `json:"tool_use_id"` @@ -366,24 +376,86 @@ const ( ) // Container represents a container identifier for reuse across requests. -// https://docs.claude.com/en/api/messages#body-container +// This became a beta status so it is not implemented for now. +// https://platform.claude.com/docs/en/api/beta/messages/create type Container any // TODO when we need it for observability, etc. type ( // ToolUnion represents a tool available to the model. // https://platform.claude.com/docs/en/api/messages#tool_union ToolUnion struct { - Tool *Tool - // TODO when we need it for observability, etc. + Tool *Tool + BashTool *BashTool + TextEditorTool20250124 *TextEditorTool20250124 + TextEditorTool20250429 *TextEditorTool20250429 + TextEditorTool20250728 *TextEditorTool20250728 + WebSearchTool *WebSearchTool } + + // Tool represents a custom tool definition. 
+ // https://platform.claude.com/docs/en/api/messages#tool Tool struct { Type string `json:"type"` // Always "custom". Name string `json:"name"` InputSchema ToolInputSchema `json:"input_schema"` - CacheControl any `json:"cache_schema,omitempty"` + CacheControl any `json:"cache_control,omitempty"` Description string `json:"description,omitempty"` } + // BashTool represents the bash tool for computer use. + // https://platform.claude.com/docs/en/api/messages#tool_bash_20250124 + BashTool struct { + Type string `json:"type"` // Always "bash_20250124". + Name string `json:"name"` // Always "bash". + CacheControl any `json:"cache_control,omitempty"` + } + + // TextEditorTool20250124 represents the text editor tool (v1). + // https://platform.claude.com/docs/en/api/messages#tool_text_editor_20250124 + TextEditorTool20250124 struct { + Type string `json:"type"` // Always "text_editor_20250124". + Name string `json:"name"` // Always "str_replace_editor". + CacheControl any `json:"cache_control,omitempty"` + } + + // TextEditorTool20250429 represents the text editor tool (v2). + // https://platform.claude.com/docs/en/api/messages#tool_text_editor_20250429 + TextEditorTool20250429 struct { + Type string `json:"type"` // Always "text_editor_20250429". + Name string `json:"name"` // Always "str_replace_based_edit_tool". + CacheControl any `json:"cache_control,omitempty"` + } + + // TextEditorTool20250728 represents the text editor tool (v3). + // https://platform.claude.com/docs/en/api/messages#tool_text_editor_20250728 + TextEditorTool20250728 struct { + Type string `json:"type"` // Always "text_editor_20250728". + Name string `json:"name"` // Always "str_replace_based_edit_tool". + MaxCharacters *float64 `json:"max_characters,omitempty"` + CacheControl any `json:"cache_control,omitempty"` + } + + // WebSearchTool represents the web search tool. 
+ // https://platform.claude.com/docs/en/api/messages#web_search_tool_20250305 + WebSearchTool struct { + Type string `json:"type"` // Always "web_search_20250305". + Name string `json:"name"` // Always "web_search". + AllowedDomains []string `json:"allowed_domains,omitempty"` + BlockedDomains []string `json:"blocked_domains,omitempty"` + MaxUses *float64 `json:"max_uses,omitempty"` + UserLocation *WebSearchLocation `json:"user_location,omitempty"` + CacheControl any `json:"cache_control,omitempty"` + } + + // WebSearchLocation represents the user location for the web search tool. + WebSearchLocation struct { + Type string `json:"type"` // Always "approximate". + City string `json:"city,omitempty"` + Country string `json:"country,omitempty"` + Region string `json:"region,omitempty"` + Timezone string `json:"timezone,omitempty"` + } + ToolInputSchema struct { Type string `json:"type"` // Always "object". Properties map[string]any `json:"properties,omitempty"` @@ -403,12 +475,63 @@ func (t *ToolUnion) UnmarshalJSON(data []byte) error { return fmt.Errorf("failed to unmarshal tool: %w", err) } t.Tool = &tool - return nil + case "bash_20250124": + var tool BashTool + if err := json.Unmarshal(data, &tool); err != nil { + return fmt.Errorf("failed to unmarshal bash tool: %w", err) + } + t.BashTool = &tool + case "text_editor_20250124": + var tool TextEditorTool20250124 + if err := json.Unmarshal(data, &tool); err != nil { + return fmt.Errorf("failed to unmarshal text editor tool: %w", err) + } + t.TextEditorTool20250124 = &tool + case "text_editor_20250429": + var tool TextEditorTool20250429 + if err := json.Unmarshal(data, &tool); err != nil { + return fmt.Errorf("failed to unmarshal text editor tool: %w", err) + } + t.TextEditorTool20250429 = &tool + case "text_editor_20250728": + var tool TextEditorTool20250728 + if err := json.Unmarshal(data, &tool); err != nil { + return fmt.Errorf("failed to unmarshal text editor tool: %w", err) + } + t.TextEditorTool20250728 = 
&tool + case "web_search_20250305": + var tool WebSearchTool + if err := json.Unmarshal(data, &tool); err != nil { + return fmt.Errorf("failed to unmarshal web search tool: %w", err) + } + t.WebSearchTool = &tool default: - // TODO add others when we need it for observability, etc. - // Fow now, we ignore undefined types. + // Ignore unknown types for forward compatibility. return nil } + return nil +} + +func (t *ToolUnion) MarshalJSON() ([]byte, error) { + if t.Tool != nil { + return json.Marshal(t.Tool) + } + if t.BashTool != nil { + return json.Marshal(t.BashTool) + } + if t.TextEditorTool20250124 != nil { + return json.Marshal(t.TextEditorTool20250124) + } + if t.TextEditorTool20250429 != nil { + return json.Marshal(t.TextEditorTool20250429) + } + if t.TextEditorTool20250728 != nil { + return json.Marshal(t.TextEditorTool20250728) + } + if t.WebSearchTool != nil { + return json.Marshal(t.WebSearchTool) + } + return nil, fmt.Errorf("tool union must have a defined type") } // ToolChoice represents the tool choice for the model. 
diff --git a/internal/translator/anthropic_gcpanthropic_test.go b/internal/translator/anthropic_gcpanthropic_test.go index ad6a249af8..d6369768f1 100644 --- a/internal/translator/anthropic_gcpanthropic_test.go +++ b/internal/translator/anthropic_gcpanthropic_test.go @@ -122,8 +122,8 @@ func TestAnthropicToGCPAnthropicTranslator_ComprehensiveMarshalling(t *testing.T TopP: func() *float64 { v := 0.95; return &v }(), StopSequences: []string{"Human:", "Assistant:"}, System: &anthropic.SystemPrompt{Text: "You are a helpful weather assistant."}, - Tools: []anthropic.Tool{ - { + Tools: []anthropic.ToolUnion{ + {Tool: &anthropic.Tool{ Name: "get_weather", Description: "Get current weather information", InputSchema: anthropic.ToolInputSchema{ @@ -136,7 +136,7 @@ func TestAnthropicToGCPAnthropicTranslator_ComprehensiveMarshalling(t *testing.T }, Required: []string{"location"}, }, - }, + }}, }, ToolChoice: ptr.To(anthropic.ToolChoice(map[string]any{ "type": "auto", @@ -348,8 +348,8 @@ func TestAnthropicToGCPAnthropicTranslator_RequestBody_FieldPassthrough(t *testi StopSequences: []string{"Human:", "Assistant:"}, Stream: false, System: &anthropic.SystemPrompt{Text: "You are a helpful assistant"}, - Tools: []anthropic.Tool{ - { + Tools: []anthropic.ToolUnion{ + {Tool: &anthropic.Tool{ Name: "get_weather", Description: "Get weather info", InputSchema: anthropic.ToolInputSchema{ @@ -358,7 +358,7 @@ func TestAnthropicToGCPAnthropicTranslator_RequestBody_FieldPassthrough(t *testi "location": map[string]any{"type": "string"}, }, }, - }, + }}, }, ToolChoice: ptr.To(anthropic.ToolChoice(map[string]any{ "type": "auto", From aa625e565c6d25fd438511778138a46fe317930c Mon Sep 17 00:00:00 2001 From: Chang Min Date: Thu, 5 Feb 2026 14:46:46 -0500 Subject: [PATCH 03/11] feat: Anthropic tool choice messages api Signed-off-by: Chang Min --- internal/apischema/anthropic/anthropic.go | 100 +++++++++++++++++- .../translator/anthropic_gcpanthropic_test.go | 10 +- 2 files changed, 98 insertions(+), 
12 deletions(-) diff --git a/internal/apischema/anthropic/anthropic.go b/internal/apischema/anthropic/anthropic.go index e61258f436..b9f9035738 100644 --- a/internal/apischema/anthropic/anthropic.go +++ b/internal/apischema/anthropic/anthropic.go @@ -534,9 +534,97 @@ func (t *ToolUnion) MarshalJSON() ([]byte, error) { return nil, fmt.Errorf("tool union must have a defined type") } -// ToolChoice represents the tool choice for the model. -// https://docs.claude.com/en/api/messages#body-tool-choice -type ToolChoice any // TODO when we need it for observability, etc. +type ( + // ToolChoice represents the tool choice for the model. + // https://platform.claude.com/docs/en/api/messages#body-tool-choice + ToolChoice struct { + Auto *ToolChoiceAuto + Any *ToolChoiceAny + Tool *ToolChoiceTool + None *ToolChoiceNone + } + + // ToolChoiceAuto lets the model automatically decide whether to use tools. + // https://platform.claude.com/docs/en/api/messages#tool_choice_auto + ToolChoiceAuto struct { + Type string `json:"type"` // Always "auto". + DisableParallelToolUse *bool `json:"disable_parallel_tool_use,omitempty"` + } + + // ToolChoiceAny forces the model to use any available tool. + // https://platform.claude.com/docs/en/api/messages#tool_choice_any + ToolChoiceAny struct { + Type string `json:"type"` // Always "any". + DisableParallelToolUse *bool `json:"disable_parallel_tool_use,omitempty"` + } + + // ToolChoiceTool forces the model to use the specified tool. + // https://platform.claude.com/docs/en/api/messages#tool_choice_tool + ToolChoiceTool struct { + Type string `json:"type"` // Always "tool". + Name string `json:"name"` + DisableParallelToolUse *bool `json:"disable_parallel_tool_use,omitempty"` + } + + // ToolChoiceNone prevents the model from using any tools. + // https://platform.claude.com/docs/en/api/messages#tool_choice_none + ToolChoiceNone struct { + Type string `json:"type"` // Always "none". 
+ } +) + +func (tc *ToolChoice) UnmarshalJSON(data []byte) error { + typ := gjson.GetBytes(data, "type") + if !typ.Exists() { + return errors.New("missing type field in tool choice") + } + switch typ.String() { + case "auto": + var v ToolChoiceAuto + if err := json.Unmarshal(data, &v); err != nil { + return fmt.Errorf("failed to unmarshal tool choice auto: %w", err) + } + tc.Auto = &v + case "any": + var v ToolChoiceAny + if err := json.Unmarshal(data, &v); err != nil { + return fmt.Errorf("failed to unmarshal tool choice any: %w", err) + } + tc.Any = &v + case "tool": + var v ToolChoiceTool + if err := json.Unmarshal(data, &v); err != nil { + return fmt.Errorf("failed to unmarshal tool choice tool: %w", err) + } + tc.Tool = &v + case "none": + var v ToolChoiceNone + if err := json.Unmarshal(data, &v); err != nil { + return fmt.Errorf("failed to unmarshal tool choice none: %w", err) + } + tc.None = &v + default: + // Ignore unknown types for forward compatibility. + return nil + } + return nil +} + +func (tc *ToolChoice) MarshalJSON() ([]byte, error) { + if tc.Auto != nil { + return json.Marshal(tc.Auto) + } + if tc.Any != nil { + return json.Marshal(tc.Any) + } + if tc.Tool != nil { + return json.Marshal(tc.Tool) + } + if tc.None != nil { + return json.Marshal(tc.None) + } + return nil, fmt.Errorf("tool choice must have a defined type") +} // Thinking represents the configuration for the model's "thinking" behavior. // https://docs.claude.com/en/api/messages#body-thinking @@ -577,11 +665,13 @@ func (s *SystemPrompt) MarshalJSON() ([]byte, error) { } // MCPServer represents an MCP server. -// https://docs.claude.com/en/api/messages#body-mcp-servers +// This has moved to beta status, so it is not implemented for now. +// https://platform.claude.com/docs/en/api/beta/messages/create type MCPServer any // TODO when we need it for observability, etc. // ContextManagement represents the context management configuration. 
-// https://docs.claude.com/en/api/messages#body-context-management +// This has moved to beta status, so it is not implemented for now. +// https://platform.claude.com/docs/en/api/beta/messages/create type ContextManagement any // TODO when we need it for observability, etc. // MessagesResponse represents a response from the Anthropic Messages API. diff --git a/internal/translator/anthropic_gcpanthropic_test.go b/internal/translator/anthropic_gcpanthropic_test.go index d6369768f1..542e193bca 100644 --- a/internal/translator/anthropic_gcpanthropic_test.go +++ b/internal/translator/anthropic_gcpanthropic_test.go @@ -138,9 +138,7 @@ func TestAnthropicToGCPAnthropicTranslator_ComprehensiveMarshalling(t *testing.T }, }}, }, - ToolChoice: ptr.To(anthropic.ToolChoice(map[string]any{ - "type": "auto", - })), + ToolChoice: &anthropic.ToolChoice{Auto: &anthropic.ToolChoiceAuto{Type: "auto"}}, } raw, err := json.Marshal(originalReq) @@ -360,10 +358,8 @@ func TestAnthropicToGCPAnthropicTranslator_RequestBody_FieldPassthrough(t *testi }, }}, }, - ToolChoice: ptr.To(anthropic.ToolChoice(map[string]any{ - "type": "auto", - })), - Metadata: &anthropic.MessagesMetadata{UserID: ptr.To("test123")}, + ToolChoice: &anthropic.ToolChoice{Auto: &anthropic.ToolChoiceAuto{Type: "auto"}}, + Metadata: &anthropic.MessagesMetadata{UserID: ptr.To("test123")}, } raw, err := json.Marshal(parsedReq) From 018d468d2ca76a3b8a7aef0d14d6208e4ac42ad8 Mon Sep 17 00:00:00 2001 From: Chang Min Date: Thu, 5 Feb 2026 14:56:46 -0500 Subject: [PATCH 04/11] feat: Anthropic request message thinking config Signed-off-by: Chang Min --- internal/apischema/anthropic/anthropic.go | 76 ++++++++++++++++++++++- 1 file changed, 73 insertions(+), 3 deletions(-) diff --git a/internal/apischema/anthropic/anthropic.go b/internal/apischema/anthropic/anthropic.go index b9f9035738..55927f5cb9 100644 --- a/internal/apischema/anthropic/anthropic.go +++ b/internal/apischema/anthropic/anthropic.go @@ -626,9 +626,79 @@ func (tc 
*ToolChoice) MarshalJSON() ([]byte, error) { return nil, fmt.Errorf("tool choice must have a defined type") } -// Thinking represents the configuration for the model's "thinking" behavior. -// https://docs.claude.com/en/api/messages#body-thinking -type Thinking any // TODO when we need it for observability, etc. +type ( + // Thinking represents the configuration for the model's "thinking" behavior. + // This is not to be confused with the thinking block that is part of the response message's content blocks. + // https://platform.claude.com/docs/en/api/messages#body-thinking + Thinking struct { + Enabled *ThinkingEnabled + Disabled *ThinkingDisabled + Adaptive *ThinkingAdaptive + } + + // ThinkingEnabled enables extended thinking with a token budget. + // https://platform.claude.com/docs/en/api/messages#thinking_config_enabled + ThinkingEnabled struct { + Type string `json:"type"` // Always "enabled". + BudgetTokens float64 `json:"budget_tokens"` // Must be >= 1024 and < max_tokens. + } + + // ThinkingDisabled disables extended thinking. + // https://platform.claude.com/docs/en/api/messages#thinking_config_disabled + ThinkingDisabled struct { + Type string `json:"type"` // Always "disabled". + } + + // ThinkingAdaptive lets the model decide whether to use extended thinking. + // https://platform.claude.com/docs/en/api/messages#thinking_config_adaptive + ThinkingAdaptive struct { + Type string `json:"type"` // Always "adaptive". 
+ } +) + +func (t *Thinking) UnmarshalJSON(data []byte) error { + typ := gjson.GetBytes(data, "type") + if !typ.Exists() { + return errors.New("missing type field in thinking config") + } + switch typ.String() { + case "enabled": + var v ThinkingEnabled + if err := json.Unmarshal(data, &v); err != nil { + return fmt.Errorf("failed to unmarshal thinking enabled: %w", err) + } + t.Enabled = &v + case "disabled": + var v ThinkingDisabled + if err := json.Unmarshal(data, &v); err != nil { + return fmt.Errorf("failed to unmarshal thinking disabled: %w", err) + } + t.Disabled = &v + case "adaptive": + var v ThinkingAdaptive + if err := json.Unmarshal(data, &v); err != nil { + return fmt.Errorf("failed to unmarshal thinking adaptive: %w", err) + } + t.Adaptive = &v + default: + // Ignore unknown types for forward compatibility. + return nil + } + return nil +} + +func (t *Thinking) MarshalJSON() ([]byte, error) { + if t.Enabled != nil { + return json.Marshal(t.Enabled) + } + if t.Disabled != nil { + return json.Marshal(t.Disabled) + } + if t.Adaptive != nil { + return json.Marshal(t.Adaptive) + } + return nil, fmt.Errorf("thinking config must have a defined type") +} // SystemPrompt represents a system prompt to guide the model's behavior. 
// https://docs.claude.com/en/api/messages#body-system From e4da90e885dda791c038ada604bdaa629cf03216 Mon Sep 17 00:00:00 2001 From: Chang Min Date: Thu, 5 Feb 2026 15:14:58 -0500 Subject: [PATCH 05/11] feat: anthropic response message contentblock api Signed-off-by: Chang Min --- internal/apischema/anthropic/anthropic.go | 225 ++++++++++++++-------- 1 file changed, 140 insertions(+), 85 deletions(-) diff --git a/internal/apischema/anthropic/anthropic.go b/internal/apischema/anthropic/anthropic.go index 55927f5cb9..ec8d1cef35 100644 --- a/internal/apischema/anthropic/anthropic.go +++ b/internal/apischema/anthropic/anthropic.go @@ -258,65 +258,65 @@ func (m *ContentBlockParam) UnmarshalJSON(data []byte) error { } switch typ.String() { case "text": - var block TextBlockParam - if err := json.Unmarshal(data, &block); err != nil { - return fmt.Errorf("failed to unmarshal text block: %w", err) + var blockParam TextBlockParam + if err := json.Unmarshal(data, &blockParam); err != nil { + return fmt.Errorf("failed to unmarshal text blockParam: %w", err) } - m.Text = &block + m.Text = &blockParam case "image": - var block ImageBlockParam - if err := json.Unmarshal(data, &block); err != nil { - return fmt.Errorf("failed to unmarshal image block: %w", err) + var blockParam ImageBlockParam + if err := json.Unmarshal(data, &blockParam); err != nil { + return fmt.Errorf("failed to unmarshal image blockParam: %w", err) } - m.Image = &block + m.Image = &blockParam case "document": - var block DocumentBlockParam - if err := json.Unmarshal(data, &block); err != nil { - return fmt.Errorf("failed to unmarshal document block: %w", err) + var blockParam DocumentBlockParam + if err := json.Unmarshal(data, &blockParam); err != nil { + return fmt.Errorf("failed to unmarshal document blockParam: %w", err) } - m.Document = &block + m.Document = &blockParam case "search_result": - var block SearchResultBlockParam - if err := json.Unmarshal(data, &block); err != nil { - return 
fmt.Errorf("failed to unmarshal search result block: %w", err) + var blockParam SearchResultBlockParam + if err := json.Unmarshal(data, &blockParam); err != nil { + return fmt.Errorf("failed to unmarshal search result blockParam: %w", err) } - m.SearchResult = &block + m.SearchResult = &blockParam case "thinking": - var block ThinkingBlockParam - if err := json.Unmarshal(data, &block); err != nil { - return fmt.Errorf("failed to unmarshal thinking block: %w", err) + var blockParam ThinkingBlockParam + if err := json.Unmarshal(data, &blockParam); err != nil { + return fmt.Errorf("failed to unmarshal thinking blockParam: %w", err) } - m.Thinking = &block + m.Thinking = &blockParam case "redacted_thinking": - var block RedactedThinkingBlockParam - if err := json.Unmarshal(data, &block); err != nil { - return fmt.Errorf("failed to unmarshal redacted thinking block: %w", err) + var blockParam RedactedThinkingBlockParam + if err := json.Unmarshal(data, &blockParam); err != nil { + return fmt.Errorf("failed to unmarshal redacted thinking blockParam: %w", err) } - m.RedactedThinking = &block + m.RedactedThinking = &blockParam case "tool_use": - var block ToolUseBlockParam - if err := json.Unmarshal(data, &block); err != nil { - return fmt.Errorf("failed to unmarshal tool use block: %w", err) + var blockParam ToolUseBlockParam + if err := json.Unmarshal(data, &blockParam); err != nil { + return fmt.Errorf("failed to unmarshal tool use blockParam: %w", err) } - m.ToolUse = &block + m.ToolUse = &blockParam case "tool_result": - var block ToolResultBlockParam - if err := json.Unmarshal(data, &block); err != nil { - return fmt.Errorf("failed to unmarshal tool result block: %w", err) + var blockParam ToolResultBlockParam + if err := json.Unmarshal(data, &blockParam); err != nil { + return fmt.Errorf("failed to unmarshal tool result blockParam: %w", err) } - m.ToolResult = &block + m.ToolResult = &blockParam case "server_tool_use": - var block ServerToolUseBlockParam - if err := 
json.Unmarshal(data, &block); err != nil { - return fmt.Errorf("failed to unmarshal server tool use block: %w", err) + var blockParam ServerToolUseBlockParam + if err := json.Unmarshal(data, &blockParam); err != nil { + return fmt.Errorf("failed to unmarshal server tool use blockParam: %w", err) } - m.ServerToolUse = &block + m.ServerToolUse = &blockParam case "web_search_tool_result": - var block WebSearchToolResultBlockParam - if err := json.Unmarshal(data, &block); err != nil { - return fmt.Errorf("failed to unmarshal web search tool result block: %w", err) + var blockParam WebSearchToolResultBlockParam + if err := json.Unmarshal(data, &blockParam); err != nil { + return fmt.Errorf("failed to unmarshal web search tool result blockParam: %w", err) } - m.WebSearchToolResult = &block + m.WebSearchToolResult = &blockParam default: // Ignore unknown types for forward compatibility. return nil @@ -355,7 +355,7 @@ func (m *ContentBlockParam) MarshalJSON() ([]byte, error) { if m.WebSearchToolResult != nil { return json.Marshal(m.WebSearchToolResult) } - return nil, fmt.Errorf("content block must have a defined type") + return nil, fmt.Errorf("content block param must have a defined type") } // MessagesMetadata represents the metadata for the Anthropic Messages API request. 
@@ -580,29 +580,29 @@ func (tc *ToolChoice) UnmarshalJSON(data []byte) error { } switch typ.String() { case "auto": - var v ToolChoiceAuto - if err := json.Unmarshal(data, &v); err != nil { + var toolChoice ToolChoiceAuto + if err := json.Unmarshal(data, &toolChoice); err != nil { return fmt.Errorf("failed to unmarshal tool choice auto: %w", err) } - tc.Auto = &v + tc.Auto = &toolChoice case "any": - var v ToolChoiceAny - if err := json.Unmarshal(data, &v); err != nil { + var toolChoice ToolChoiceAny + if err := json.Unmarshal(data, &toolChoice); err != nil { return fmt.Errorf("failed to unmarshal tool choice any: %w", err) } - tc.Any = &v + tc.Any = &toolChoice case "tool": - var v ToolChoiceTool - if err := json.Unmarshal(data, &v); err != nil { + var toolChoice ToolChoiceTool + if err := json.Unmarshal(data, &toolChoice); err != nil { return fmt.Errorf("failed to unmarshal tool choice tool: %w", err) } - tc.Tool = &v + tc.Tool = &toolChoice case "none": - var v ToolChoiceNone - if err := json.Unmarshal(data, &v); err != nil { + var toolChoice ToolChoiceNone + if err := json.Unmarshal(data, &toolChoice); err != nil { return fmt.Errorf("failed to unmarshal tool choice none: %w", err) } - tc.None = &v + tc.None = &toolChoice default: // Ignore unknown types for forward compatibility. 
return nil @@ -663,23 +663,23 @@ func (t *Thinking) UnmarshalJSON(data []byte) error { } switch typ.String() { case "enabled": - var v ThinkingEnabled - if err := json.Unmarshal(data, &v); err != nil { + var thinking ThinkingEnabled + if err := json.Unmarshal(data, &thinking); err != nil { return fmt.Errorf("failed to unmarshal thinking enabled: %w", err) } - t.Enabled = &v + t.Enabled = &thinking case "disabled": - var v ThinkingDisabled - if err := json.Unmarshal(data, &v); err != nil { + var thinking ThinkingDisabled + if err := json.Unmarshal(data, &thinking); err != nil { return fmt.Errorf("failed to unmarshal thinking disabled: %w", err) } - t.Disabled = &v + t.Disabled = &thinking case "adaptive": - var v ThinkingAdaptive - if err := json.Unmarshal(data, &v); err != nil { + var thinking ThinkingAdaptive + if err := json.Unmarshal(data, &thinking); err != nil { return fmt.Errorf("failed to unmarshal thinking adaptive: %w", err) } - t.Adaptive = &v + t.Adaptive = &thinking default: // Ignore unknown types for forward compatibility. return nil @@ -785,20 +785,26 @@ type ConstantMessagesResponseRoleAssistant string type ( // MessagesContentBlock represents a block of content in the Anthropic Messages API response. - // https://docs.claude.com/en/api/messages#response-content + // https://platform.claude.com/docs/en/api/messages#response-content MessagesContentBlock struct { - Text *TextBlock - Tool *ToolUseBlock - Thinking *ThinkingBlock - // TODO when we need it for observability, etc. + Text *TextBlock + Tool *ToolUseBlock + Thinking *ThinkingBlock + RedactedThinking *RedactedThinkingBlock + ServerToolUse *ServerToolUseBlock + WebSearchToolResult *WebSearchToolResultBlock } + // TextBlock represents a text content block in the response. + // https://platform.claude.com/docs/en/api/messages#text_block TextBlock struct { - Type string `json:"type"` // Always "text". - Text string `json:"text"` - // TODO: citation? + Type string `json:"type"` // Always "text". 
+ Text string `json:"text"` + Citations []any `json:"citations,omitempty"` } + // ToolUseBlock represents a tool use content block in the response. + // https://platform.claude.com/docs/en/api/messages#tool_use_block ToolUseBlock struct { Type string `json:"type"` // Always "tool_use". ID string `json:"id"` @@ -806,11 +812,37 @@ type ( Input map[string]any `json:"input"` } + // ThinkingBlock represents a thinking content block in the response. + // https://platform.claude.com/docs/en/api/messages#thinking_block ThinkingBlock struct { Type string `json:"type"` // Always "thinking". Thinking string `json:"thinking"` Signature string `json:"signature,omitempty"` } + + // RedactedThinkingBlock represents a redacted thinking content block in the response. + // https://platform.claude.com/docs/en/api/messages#redacted_thinking_block + RedactedThinkingBlock struct { + Type string `json:"type"` // Always "redacted_thinking". + Data string `json:"data"` + } + + // ServerToolUseBlock represents a server tool use content block in the response. + // https://platform.claude.com/docs/en/api/messages#server_tool_use_block + ServerToolUseBlock struct { + Type string `json:"type"` // Always "server_tool_use". + ID string `json:"id"` + Name string `json:"name"` // e.g. "web_search". + Input map[string]any `json:"input"` + } + + // WebSearchToolResultBlock represents a web search tool result content block in the response. + // https://platform.claude.com/docs/en/api/messages#web_search_tool_result_block + WebSearchToolResultBlock struct { + Type string `json:"type"` // Always "web_search_tool_result". + ToolUseID string `json:"tool_use_id"` + Content any `json:"content"` // Array of WebSearchResult or a WebSearchToolResultError. 
+ } ) func (m *MessagesContentBlock) UnmarshalJSON(data []byte) error { @@ -820,31 +852,46 @@ func (m *MessagesContentBlock) UnmarshalJSON(data []byte) error { } switch typ.String() { case "text": - var textBlock TextBlock - if err := json.Unmarshal(data, &textBlock); err != nil { + var contentBlock TextBlock + if err := json.Unmarshal(data, &contentBlock); err != nil { return fmt.Errorf("failed to unmarshal text block: %w", err) } - m.Text = &textBlock - return nil + m.Text = &contentBlock case "tool_use": - var toolUseBlock ToolUseBlock - if err := json.Unmarshal(data, &toolUseBlock); err != nil { + var contentBlock ToolUseBlock + if err := json.Unmarshal(data, &contentBlock); err != nil { return fmt.Errorf("failed to unmarshal tool use block: %w", err) } - m.Tool = &toolUseBlock - return nil + m.Tool = &contentBlock case "thinking": - var thinkingBlock ThinkingBlock - if err := json.Unmarshal(data, &thinkingBlock); err != nil { + var contentBlock ThinkingBlock + if err := json.Unmarshal(data, &contentBlock); err != nil { return fmt.Errorf("failed to unmarshal thinking block: %w", err) } - m.Thinking = &thinkingBlock - return nil + m.Thinking = &contentBlock + case "redacted_thinking": + var contentBlock RedactedThinkingBlock + if err := json.Unmarshal(data, &contentBlock); err != nil { + return fmt.Errorf("failed to unmarshal redacted thinking block: %w", err) + } + m.RedactedThinking = &contentBlock + case "server_tool_use": + var contentBlock ServerToolUseBlock + if err := json.Unmarshal(data, &contentBlock); err != nil { + return fmt.Errorf("failed to unmarshal server tool use block: %w", err) + } + m.ServerToolUse = &contentBlock + case "web_search_tool_result": + var contentBlock WebSearchToolResultBlock + if err := json.Unmarshal(data, &contentBlock); err != nil { + return fmt.Errorf("failed to unmarshal web search tool result block: %w", err) + } + m.WebSearchToolResult = &contentBlock default: - // TODO add others when we need it for observability, etc. 
- // Fow now, we ignore undefined types. + // Ignore unknown types for forward compatibility. return nil } + return nil } func (m *MessagesContentBlock) MarshalJSON() ([]byte, error) { @@ -857,7 +904,15 @@ func (m *MessagesContentBlock) MarshalJSON() ([]byte, error) { if m.Thinking != nil { return json.Marshal(m.Thinking) } - // TODO add others when we need it for observability, etc. + if m.RedactedThinking != nil { + return json.Marshal(m.RedactedThinking) + } + if m.ServerToolUse != nil { + return json.Marshal(m.ServerToolUse) + } + if m.WebSearchToolResult != nil { + return json.Marshal(m.WebSearchToolResult) + } return nil, fmt.Errorf("content block must have a defined type") } From 2fc2414c2929f9802c8c1cff1a821fefd0bfecce Mon Sep 17 00:00:00 2001 From: Johnu George Date: Fri, 6 Feb 2026 02:17:18 +0530 Subject: [PATCH 06/11] fix: trigger rollout when MCPRoute exists but extproc lacks -mcpAddr (#1836) --- internal/controller/gateway.go | 25 +++++++++- internal/controller/gateway_test.go | 74 ++++++++++++++++++++++++----- 2 files changed, 85 insertions(+), 14 deletions(-) diff --git a/internal/controller/gateway.go b/internal/controller/gateway.go index 70489921c1..9e605cb178 100644 --- a/internal/controller/gateway.go +++ b/internal/controller/gateway.go @@ -142,7 +142,7 @@ func (c *GatewayController) Reconcile(ctx context.Context, req ctrl.Request) (ct // Finally, we need to annotate the pods of the gateway deployment with the new uuid to propagate the filter config Secret update faster. // If the pod doesn't have the extproc container, it will roll out the deployment altogether which eventually ends up // the mutation hook invoked. 
- if err := c.annotateGatewayPods(ctx, pods, deployments, daemonSets, uid, hasEffectiveRoutes); err != nil { + if err := c.annotateGatewayPods(ctx, pods, deployments, daemonSets, uid, hasEffectiveRoutes, len(mcpRoutes.Items) > 0); err != nil { c.logger.Error(err, "Failed to annotate gateway pods", "namespace", gw.Namespace, "name", gw.Name) return ctrl.Result{}, err } @@ -755,6 +755,7 @@ func (c *GatewayController) annotateGatewayPods(ctx context.Context, daemonSets []appsv1.DaemonSet, uuid string, hasEffectiveRoute bool, + needMCP bool, ) error { hasSideCar := false for i := range pods { @@ -766,12 +767,23 @@ func (c *GatewayController) annotateGatewayPods(ctx context.Context, // If there's an extproc sidecar container with the current target image, we don't need to roll out the deployment. if podSpec.InitContainers[i].Name == extProcContainerName && podSpec.InitContainers[i].Image == c.extProcImage { hasSideCar = true + hasMCPAddr := false for j := range podSpec.InitContainers[i].Args { // logLevel arg should be indexed 2 based on gateway_mutator.go, but we check all args to be safe. if j > 0 && podSpec.InitContainers[i].Args[j-1] == "-logLevel" && podSpec.InitContainers[i].Args[j] != c.extProcLogLevel { hasSideCar = false break } + // Check if the -mcpAddr argument is present + if j > 0 && podSpec.InitContainers[i].Args[j-1] == "-mcpAddr" { + hasMCPAddr = true + } + } + // If MCPRoutes exist but the sidecar doesn't have -mcpAddr, we need to roll out + if needMCP && !hasMCPAddr { + c.logger.Info("MCPRoutes exist but sidecar is missing -mcpAddr argument, triggering rollout", + "pod", pod.Name, "namespace", pod.Namespace) + hasSideCar = false } break } @@ -781,11 +793,22 @@ func (c *GatewayController) annotateGatewayPods(ctx context.Context, // If there's an extproc container with the current target image, we don't need to roll out the deployment. 
if podSpec.Containers[i].Name == extProcContainerName && podSpec.Containers[i].Image == c.extProcImage { hasSideCar = true + hasMCPAddr := false for j := range podSpec.Containers[i].Args { if j > 0 && podSpec.Containers[i].Args[j-1] == "-logLevel" && podSpec.Containers[i].Args[j] != c.extProcLogLevel { hasSideCar = false break } + // Check if the -mcpAddr argument is present + if j > 0 && podSpec.Containers[i].Args[j-1] == "-mcpAddr" { + hasMCPAddr = true + } + } + // If MCPRoutes exist but the sidecar doesn't have -mcpAddr, we need to roll out + if needMCP && !hasMCPAddr { + c.logger.Info("MCPRoutes exist but sidecar is missing -mcpAddr argument, triggering rollout", + "pod", pod.Name, "namespace", pod.Namespace) + hasSideCar = false } break } diff --git a/internal/controller/gateway_test.go b/internal/controller/gateway_test.go index 4fdc728bcf..bfba5c5fab 100644 --- a/internal/controller/gateway_test.go +++ b/internal/controller/gateway_test.go @@ -681,7 +681,7 @@ func TestGatewayController_annotateGatewayPods(t *testing.T) { }, metav1.CreateOptions{}) require.NoError(t, err) hasEffectiveRoute := true - err = c.annotateGatewayPods(t.Context(), []corev1.Pod{*pod}, nil, nil, "some-uuid", hasEffectiveRoute) + err = c.annotateGatewayPods(t.Context(), []corev1.Pod{*pod}, nil, nil, "some-uuid", hasEffectiveRoute, false) require.NoError(t, err) annotated, err := kube.CoreV1().Pods(egNamespace).Get(t.Context(), "pod1", metav1.GetOptions{}) @@ -701,7 +701,7 @@ func TestGatewayController_annotateGatewayPods(t *testing.T) { // Since it has already a sidecar container, passing the hasEffectiveRoute=false should result in adding an annotation to the deployment. 
hasEffectiveRoute = false - err = c.annotateGatewayPods(t.Context(), []corev1.Pod{*pod}, []appsv1.Deployment{*deployment}, nil, "another-uuid", hasEffectiveRoute) + err = c.annotateGatewayPods(t.Context(), []corev1.Pod{*pod}, []appsv1.Deployment{*deployment}, nil, "another-uuid", hasEffectiveRoute, false) require.NoError(t, err) // Check the deployment's pod template has the annotation. @@ -734,7 +734,7 @@ func TestGatewayController_annotateGatewayPods(t *testing.T) { // When there's no effective route, this should not add the annotation to the deployment. hasEffectiveRoute := false - err = c.annotateGatewayPods(t.Context(), []corev1.Pod{*pod}, []appsv1.Deployment{*deployment}, nil, "some-uuid", hasEffectiveRoute) + err = c.annotateGatewayPods(t.Context(), []corev1.Pod{*pod}, []appsv1.Deployment{*deployment}, nil, "some-uuid", hasEffectiveRoute, false) require.NoError(t, err) deployment, err = kube.AppsV1().Deployments(egNamespace).Get(t.Context(), "deployment1", metav1.GetOptions{}) require.NoError(t, err) @@ -743,7 +743,7 @@ func TestGatewayController_annotateGatewayPods(t *testing.T) { // When there's an effective route, this should add the annotation to the deployment. hasEffectiveRoute = true - err = c.annotateGatewayPods(t.Context(), []corev1.Pod{*pod}, []appsv1.Deployment{*deployment}, nil, "some-uuid", hasEffectiveRoute) + err = c.annotateGatewayPods(t.Context(), []corev1.Pod{*pod}, []appsv1.Deployment{*deployment}, nil, "some-uuid", hasEffectiveRoute, false) require.NoError(t, err) // Check the deployment's pod template has the annotation. 
@@ -778,7 +778,7 @@ func TestGatewayController_annotateGatewayPods(t *testing.T) { }, metav1.CreateOptions{}) require.NoError(t, err) - err = c.annotateGatewayPods(t.Context(), []corev1.Pod{*pod}, []appsv1.Deployment{*deployment}, nil, "some-uuid", true) + err = c.annotateGatewayPods(t.Context(), []corev1.Pod{*pod}, []appsv1.Deployment{*deployment}, nil, "some-uuid", true, false) require.NoError(t, err) // Check the deployment's pod template has the annotation. @@ -792,7 +792,7 @@ func TestGatewayController_annotateGatewayPods(t *testing.T) { require.NoError(t, err) // Call annotateGatewayPods again but the deployment's pod template should not be updated again. - err = c.annotateGatewayPods(t.Context(), []corev1.Pod{*pod}, []appsv1.Deployment{*deployment}, nil, "some-uuid", true) + err = c.annotateGatewayPods(t.Context(), []corev1.Pod{*pod}, []appsv1.Deployment{*deployment}, nil, "some-uuid", true, false) require.NoError(t, err) deployment, err = kube.AppsV1().Deployments(egNamespace).Get(t.Context(), "deployment2", metav1.GetOptions{}) @@ -826,7 +826,7 @@ func TestGatewayController_annotateGatewayPods(t *testing.T) { }, metav1.CreateOptions{}) require.NoError(t, err) - err = c.annotateGatewayPods(t.Context(), []corev1.Pod{*pod}, []appsv1.Deployment{*deployment}, nil, "some-uuid", true) + err = c.annotateGatewayPods(t.Context(), []corev1.Pod{*pod}, []appsv1.Deployment{*deployment}, nil, "some-uuid", true, false) require.NoError(t, err) // Check the deployment's pod template has the annotation. @@ -840,13 +840,61 @@ func TestGatewayController_annotateGatewayPods(t *testing.T) { require.NoError(t, err) // Call annotateGatewayPods again but the deployment's pod template should not be updated again. 
- err = c.annotateGatewayPods(t.Context(), []corev1.Pod{*pod}, []appsv1.Deployment{*deployment}, nil, "some-uuid", true) + err = c.annotateGatewayPods(t.Context(), []corev1.Pod{*pod}, []appsv1.Deployment{*deployment}, nil, "some-uuid", true, false) require.NoError(t, err) deployment, err = kube.AppsV1().Deployments(egNamespace).Get(t.Context(), "deployment3", metav1.GetOptions{}) require.NoError(t, err) require.Equal(t, "some-uuid", deployment.Spec.Template.Annotations[aigatewayUUIDAnnotationKey]) }) + + t.Run("pod with extproc but missing mcpAddr", func(t *testing.T) { + pod := &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Name: "pod5", + Namespace: egNamespace, + Labels: labels, + }, + Spec: corev1.PodSpec{InitContainers: []corev1.Container{ + {Name: extProcContainerName, Image: v2Container, Args: []string{"-logLevel", logLevel, "-adminPort", "1064"}}, + }}, + } + _, err := kube.CoreV1().Pods(egNamespace).Create(t.Context(), pod, metav1.CreateOptions{}) + require.NoError(t, err) + + deployment, err := kube.AppsV1().Deployments(egNamespace).Create(t.Context(), &appsv1.Deployment{ + ObjectMeta: metav1.ObjectMeta{ + Name: "deployment4", + Namespace: egNamespace, + Labels: labels, + }, + Spec: appsv1.DeploymentSpec{Template: corev1.PodTemplateSpec{ObjectMeta: metav1.ObjectMeta{}}}, + }, metav1.CreateOptions{}) + require.NoError(t, err) + + // Call with needMCP=true - should trigger rollout due to missing -mcpAddr + err = c.annotateGatewayPods(t.Context(), []corev1.Pod{*pod}, []appsv1.Deployment{*deployment}, nil, "some-uuid", true, true) + require.NoError(t, err) + + // Check the deployment's pod template has the annotation (rollout triggered). 
+ deployment, err = kube.AppsV1().Deployments(egNamespace).Get(t.Context(), "deployment4", metav1.GetOptions{}) + require.NoError(t, err) + require.Equal(t, "some-uuid", deployment.Spec.Template.Annotations[aigatewayUUIDAnnotationKey]) + + // Simulate new pod created after rollout with -mcpAddr present + pod.Spec.InitContainers[0].Args = []string{"-logLevel", logLevel, "-mcpAddr", ":9856", "-adminPort", "1064"} + pod, err = kube.CoreV1().Pods(egNamespace).Update(t.Context(), pod, metav1.UpdateOptions{}) + require.NoError(t, err) + + // Call annotateGatewayPods again - should NOT trigger another rollout + err = c.annotateGatewayPods(t.Context(), []corev1.Pod{*pod}, []appsv1.Deployment{*deployment}, nil, "another-uuid", true, true) + require.NoError(t, err) + + // Deployment annotation should remain unchanged (no new rollout) + deployment, err = kube.AppsV1().Deployments(egNamespace).Get(t.Context(), "deployment4", metav1.GetOptions{}) + require.NoError(t, err) + require.Equal(t, "some-uuid", deployment.Spec.Template.Annotations[aigatewayUUIDAnnotationKey]) + }) } func TestGatewayController_annotateDaemonSetGatewayPods(t *testing.T) { @@ -887,7 +935,7 @@ func TestGatewayController_annotateDaemonSetGatewayPods(t *testing.T) { }, metav1.CreateOptions{}) require.NoError(t, err) - err = c.annotateGatewayPods(t.Context(), []corev1.Pod{*pod}, nil, []appsv1.DaemonSet{*dss}, "some-uuid", true) + err = c.annotateGatewayPods(t.Context(), []corev1.Pod{*pod}, nil, []appsv1.DaemonSet{*dss}, "some-uuid", true, false) require.NoError(t, err) // Check the deployment's pod template has the annotation. 
@@ -922,7 +970,7 @@ func TestGatewayController_annotateDaemonSetGatewayPods(t *testing.T) { }, metav1.CreateOptions{}) require.NoError(t, err) - err = c.annotateGatewayPods(t.Context(), []corev1.Pod{*pod}, nil, []appsv1.DaemonSet{*dss}, "some-uuid", true) + err = c.annotateGatewayPods(t.Context(), []corev1.Pod{*pod}, nil, []appsv1.DaemonSet{*dss}, "some-uuid", true, false) require.NoError(t, err) // Check the deployment's pod template has the annotation. @@ -936,7 +984,7 @@ func TestGatewayController_annotateDaemonSetGatewayPods(t *testing.T) { require.NoError(t, err) // Call annotateGatewayPods again, but the deployment's pod template should not be updated again. - err = c.annotateGatewayPods(t.Context(), []corev1.Pod{*pod}, nil, []appsv1.DaemonSet{*dss}, "some-uuid", true) + err = c.annotateGatewayPods(t.Context(), []corev1.Pod{*pod}, nil, []appsv1.DaemonSet{*dss}, "some-uuid", true, false) require.NoError(t, err) deployment, err = kube.AppsV1().DaemonSets(egNamespace).Get(t.Context(), "deployment2", metav1.GetOptions{}) @@ -970,7 +1018,7 @@ func TestGatewayController_annotateDaemonSetGatewayPods(t *testing.T) { }, metav1.CreateOptions{}) require.NoError(t, err) - err = c.annotateGatewayPods(t.Context(), []corev1.Pod{*pod}, nil, []appsv1.DaemonSet{*dss}, "some-uuid", true) + err = c.annotateGatewayPods(t.Context(), []corev1.Pod{*pod}, nil, []appsv1.DaemonSet{*dss}, "some-uuid", true, false) require.NoError(t, err) // Check the deployment's pod template has the annotation. @@ -984,7 +1032,7 @@ func TestGatewayController_annotateDaemonSetGatewayPods(t *testing.T) { require.NoError(t, err) // Call annotateGatewayPods again, but the deployment's pod template should not be updated again. 
- err = c.annotateGatewayPods(t.Context(), []corev1.Pod{*pod}, nil, []appsv1.DaemonSet{*dss}, "some-uuid", true) + err = c.annotateGatewayPods(t.Context(), []corev1.Pod{*pod}, nil, []appsv1.DaemonSet{*dss}, "some-uuid", true, false) require.NoError(t, err) deployment, err = kube.AppsV1().DaemonSets(egNamespace).Get(t.Context(), "deployment3", metav1.GetOptions{}) From dd0b016a6d7a1b001725fa4c3ff57f4aeed2b1ef Mon Sep 17 00:00:00 2001 From: Sivanantham <90966311+sivanantha321@users.noreply.github.com> Date: Fri, 6 Feb 2026 20:33:46 +0530 Subject: [PATCH 07/11] docs: fix helm release name in tracing documentation (#1841) --- site/docs/capabilities/observability/tracing.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/site/docs/capabilities/observability/tracing.md b/site/docs/capabilities/observability/tracing.md index 263f6e8f29..5686291ffb 100644 --- a/site/docs/capabilities/observability/tracing.md +++ b/site/docs/capabilities/observability/tracing.md @@ -58,7 +58,7 @@ kubectl wait --timeout=5m -n envoy-ai-gateway-system \ Upgrade your AI Gateway installation with [OpenTelemetry configuration][otel-config]: -{`helm upgrade ai-eg oci://docker.io/envoyproxy/ai-gateway-helm \\ +{`helm upgrade aieg oci://docker.io/envoyproxy/ai-gateway-helm \\ --version v${vars.aigwVersion} \\ --namespace envoy-ai-gateway-system \\ --set "extProc.extraEnvVars[0].name=OTEL_EXPORTER_OTLP_ENDPOINT" \\ @@ -179,7 +179,7 @@ Here's an example of keeping the default session mapping for spans/logs while only adding a low-cardinality attribute to metrics: -{`helm upgrade ai-eg oci://docker.io/envoyproxy/ai-gateway-helm \\ +{`helm upgrade aieg oci://docker.io/envoyproxy/ai-gateway-helm \\ --version v${vars.aigwVersion} \\ --namespace envoy-ai-gateway-system \\ --reuse-values \\ @@ -196,7 +196,7 @@ helm uninstall phoenix -n envoy-ai-gateway-system # Disable tracing in AI Gateway -helm upgrade ai-eg oci://docker.io/envoyproxy/ai-gateway-helm \\ +helm upgrade aieg 
oci://docker.io/envoyproxy/ai-gateway-helm \\ --version v${vars.aigwVersion} \\ --namespace envoy-ai-gateway-system \\ --reuse-values \\ From b1599d2e64065b7e27545be185296268626049a3 Mon Sep 17 00:00:00 2001 From: Xiaolin Lin Date: Sat, 7 Feb 2026 11:48:25 -0500 Subject: [PATCH 08/11] feat: support claude-opus-4.6 new thinking mode adaptive (#1842) **Description** Claude Opus 4.6 has new value `adaptive` for thinking [1] type 1. https://platform.claude.com/docs/en/about-claude/models/whats-new-claude-4-6#adaptive-thinking-mode Signed-off-by: Xiaolin Lin --- internal/apischema/openai/openai.go | 14 +++ internal/apischema/openai/union_test.go | 109 ++++++++++++++++++++++++ 2 files changed, 123 insertions(+) diff --git a/internal/apischema/openai/openai.go b/internal/apischema/openai/openai.go index e2ede0993d..de1e21284f 100644 --- a/internal/apischema/openai/openai.go +++ b/internal/apischema/openai/openai.go @@ -869,6 +869,7 @@ type WebSearchLocation struct { type ThinkingUnion struct { OfEnabled *ThinkingEnabled `json:",omitzero,inline"` OfDisabled *ThinkingDisabled `json:",omitzero,inline"` + OfAdaptive *ThinkingAdaptive `json:",omitzero,inline"` } type ThinkingEnabled struct { @@ -887,6 +888,10 @@ type ThinkingDisabled struct { Type string `json:"type,"` } +type ThinkingAdaptive struct { + Type string `json:"type,"` +} + // MarshalJSON implements the json.Marshaler interface for ThinkingUnion. func (t *ThinkingUnion) MarshalJSON() ([]byte, error) { if t.OfEnabled != nil { @@ -895,6 +900,9 @@ func (t *ThinkingUnion) MarshalJSON() ([]byte, error) { if t.OfDisabled != nil { return json.Marshal(t.OfDisabled) } + if t.OfAdaptive != nil { + return json.Marshal(t.OfAdaptive) + } // If both are nil, return an empty object or an error, depending on your desired behavior. 
return []byte(`{}`), nil } @@ -923,6 +931,12 @@ func (t *ThinkingUnion) UnmarshalJSON(data []byte) error { return err } t.OfDisabled = &disabled + case "adaptive": + var adaptive ThinkingAdaptive + if err := json.Unmarshal(data, &adaptive); err != nil { + return err + } + t.OfAdaptive = &adaptive default: return fmt.Errorf("invalid thinking union type: %s", typeVal) } diff --git a/internal/apischema/openai/union_test.go b/internal/apischema/openai/union_test.go index 83e8283687..1b37277ff0 100644 --- a/internal/apischema/openai/union_test.go +++ b/internal/apischema/openai/union_test.go @@ -9,6 +9,8 @@ import ( "testing" "github.com/stretchr/testify/require" + + "github.com/envoyproxy/ai-gateway/internal/json" ) // TestUnmarshalJSONNestedUnion tests the completion API prompt parsing. @@ -285,3 +287,110 @@ func TestUnmarshalJSONEmbeddingInput_Errors(t *testing.T) { }) } } + +func TestThinkingUnion_UnmarshalJSON(t *testing.T) { + tests := []struct { + name string + data string + expect ThinkingUnion + }{ + { + name: "enabled", + data: `{"type":"enabled","budget_tokens":1024}`, + expect: ThinkingUnion{ + OfEnabled: &ThinkingEnabled{Type: "enabled", BudgetTokens: 1024}, + }, + }, + { + name: "disabled", + data: `{"type":"disabled"}`, + expect: ThinkingUnion{ + OfDisabled: &ThinkingDisabled{Type: "disabled"}, + }, + }, + { + name: "adaptive", + data: `{"type":"adaptive"}`, + expect: ThinkingUnion{ + OfAdaptive: &ThinkingAdaptive{Type: "adaptive"}, + }, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + var got ThinkingUnion + err := json.Unmarshal([]byte(tc.data), &got) + require.NoError(t, err) + require.Equal(t, tc.expect, got) + }) + } +} + +func TestThinkingUnion_UnmarshalJSON_Errors(t *testing.T) { + tests := []struct { + name string + data string + expectedErr string + }{ + { + name: "missing type field", + data: `{"budget_tokens":1024}`, + expectedErr: "thinking config does not have a type", + }, + { + name: "invalid type value", + 
data: `{"type":"unknown"}`, + expectedErr: "invalid thinking union type: unknown", + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + var got ThinkingUnion + err := json.Unmarshal([]byte(tc.data), &got) + require.Error(t, err) + require.Contains(t, err.Error(), tc.expectedErr) + }) + } +} + +func TestThinkingUnion_MarshalJSON(t *testing.T) { + tests := []struct { + name string + input ThinkingUnion + expect string + }{ + { + name: "enabled", + input: ThinkingUnion{ + OfEnabled: &ThinkingEnabled{Type: "enabled", BudgetTokens: 1024}, + }, + expect: `{"budget_tokens":1024,"type":"enabled"}`, + }, + { + name: "disabled", + input: ThinkingUnion{ + OfDisabled: &ThinkingDisabled{Type: "disabled"}, + }, + expect: `{"type":"disabled"}`, + }, + { + name: "adaptive", + input: ThinkingUnion{ + OfAdaptive: &ThinkingAdaptive{Type: "adaptive"}, + }, + expect: `{"type":"adaptive"}`, + }, + { + name: "all nil returns empty object", + input: ThinkingUnion{}, + expect: `{}`, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got, err := json.Marshal(&tc.input) + require.NoError(t, err) + require.JSONEq(t, tc.expect, string(got)) + }) + } +} From 1cad24746765300e74b8c487430939014d721e5c Mon Sep 17 00:00:00 2001 From: hustxiayang Date: Sat, 7 Feb 2026 12:51:50 -0500 Subject: [PATCH 09/11] feat: add InvokeModel API support for claude models in aws bedrock (#1648) **Description** Add InvokeModel API support for claude models in aws bedrock. The motivation is to provide consistent services cross providers, https://github.com/envoyproxy/ai-gateway/issues/1644 for more details about the motivation. Other Changes: I put common codes related to anthropic into `anthropic_helper.go`, so that both aws and gcp can share these codes. 
--------- Signed-off-by: yxia216 Signed-off-by: Takeshi Yoneda Signed-off-by: Aaron Choo Signed-off-by: Dan Sun Co-authored-by: Takeshi Yoneda Co-authored-by: Aaron Choo Co-authored-by: Dan Sun --- api/v1alpha1/shared_types.go | 2 + .../basic/aws-bedrock-openai-anthropic.yaml | 109 ++ internal/endpointspec/endpointspec.go | 2 + internal/endpointspec/endpointspec_test.go | 1 + internal/filterapi/filterconfig.go | 5 +- internal/translator/anthropic_helper.go | 1133 ++++++++++++++++ internal/translator/anthropic_helper_test.go | 894 +++++++++++++ internal/translator/openai_awsanthropic.go | 261 ++++ .../translator/openai_awsanthropic_test.go | 812 ++++++++++++ internal/translator/openai_awsbedrock.go | 3 +- internal/translator/openai_gcpanthropic.go | 740 +---------- .../translator/openai_gcpanthropic_stream.go | 421 ------ .../openai_gcpanthropic_stream_test.go | 1031 --------------- .../translator/openai_gcpanthropic_test.go | 1170 +++++------------ site/docs/api/api.mdx | 2 +- 15 files changed, 3529 insertions(+), 3057 deletions(-) create mode 100644 examples/basic/aws-bedrock-openai-anthropic.yaml create mode 100644 internal/translator/anthropic_helper.go create mode 100644 internal/translator/anthropic_helper_test.go create mode 100644 internal/translator/openai_awsanthropic.go create mode 100644 internal/translator/openai_awsanthropic_test.go delete mode 100644 internal/translator/openai_gcpanthropic_stream.go delete mode 100644 internal/translator/openai_gcpanthropic_stream_test.go diff --git a/api/v1alpha1/shared_types.go b/api/v1alpha1/shared_types.go index 09fad0431a..036e8b7a47 100644 --- a/api/v1alpha1/shared_types.go +++ b/api/v1alpha1/shared_types.go @@ -80,6 +80,8 @@ const ( APISchemaAnthropic APISchema = "Anthropic" // APISchemaAWSAnthropic is the schema for Anthropic models hosted on AWS Bedrock. // Uses the native Anthropic Messages API format for requests and responses. 
+ // When used with /v1/chat/completions endpoint, translates OpenAI format to Anthropic. + // When used with /v1/messages endpoint, passes through native Anthropic format. // // https://aws.amazon.com/bedrock/anthropic/ // https://docs.claude.com/en/api/claude-on-amazon-bedrock diff --git a/examples/basic/aws-bedrock-openai-anthropic.yaml b/examples/basic/aws-bedrock-openai-anthropic.yaml new file mode 100644 index 0000000000..bf36817dcf --- /dev/null +++ b/examples/basic/aws-bedrock-openai-anthropic.yaml @@ -0,0 +1,109 @@ +# Copyright Envoy AI Gateway Authors +# SPDX-License-Identifier: Apache-2.0 +# The full text of the Apache license is available in the LICENSE file at +# the root of the repo. + +# This example demonstrates using the AWSAnthropic schema to access +# Claude models on AWS Bedrock via the InvokeModel API with OpenAI-compatible requests. +# +# The AWSAnthropic schema works with both input formats: +# - /v1/chat/completions: Translates OpenAI ChatCompletion requests to Anthropic Messages API format +# - /v1/messages: Passes through native Anthropic Messages API format +# +# Use cases: +# - When you want to use OpenAI SDK/format with Claude models on AWS Bedrock +# - When migrating from OpenAI to Claude on AWS without changing client code +# - When using tools that only support OpenAI format but need Claude on AWS + +apiVersion: aigateway.envoyproxy.io/v1alpha1 +kind: AIGatewayRoute +metadata: + name: envoy-ai-gateway-aws-bedrock-claude-openai-format + namespace: default +spec: + parentRefs: + - name: envoy-ai-gateway-basic + kind: Gateway + group: gateway.networking.k8s.io + rules: + - matches: + - headers: + - type: Exact + name: x-ai-eg-model + value: anthropic.claude-3-5-sonnet-20241022-v2:0 + backendRefs: + - name: envoy-ai-gateway-aws-bedrock-claude-openai +--- +apiVersion: aigateway.envoyproxy.io/v1alpha1 +kind: AIServiceBackend +metadata: + name: envoy-ai-gateway-aws-bedrock-claude-openai + namespace: default +spec: + # AWSAnthropic schema 
supports both OpenAI and Anthropic input formats. + # The endpoint path determines the translator used. + schema: + name: AWSAnthropic + # Optional: Specify Anthropic API version for Bedrock + # Default: bedrock-2023-05-31 + version: bedrock-2023-05-31 + backendRef: + name: envoy-ai-gateway-basic-aws + kind: Backend + group: gateway.envoyproxy.io +--- +apiVersion: aigateway.envoyproxy.io/v1alpha1 +kind: BackendSecurityPolicy +metadata: + name: envoy-ai-gateway-aws-bedrock-credentials + namespace: default +spec: + targetRefs: + - group: aigateway.envoyproxy.io + kind: AIServiceBackend + name: envoy-ai-gateway-aws-bedrock-claude-openai + type: AWSCredentials + awsCredentials: + region: us-east-1 + credentialsFile: + secretRef: + name: envoy-ai-gateway-basic-aws-credentials +--- +apiVersion: gateway.envoyproxy.io/v1alpha1 +kind: Backend +metadata: + name: envoy-ai-gateway-basic-aws + namespace: default +spec: + endpoints: + - fqdn: + hostname: bedrock-runtime.us-east-1.amazonaws.com + port: 443 +--- +apiVersion: gateway.networking.k8s.io/v1alpha3 +kind: BackendTLSPolicy +metadata: + name: envoy-ai-gateway-basic-aws-tls + namespace: default +spec: + targetRefs: + - group: "gateway.envoyproxy.io" + kind: Backend + name: envoy-ai-gateway-basic-aws + validation: + wellKnownCACertificates: "System" + hostname: bedrock-runtime.us-east-1.amazonaws.com +--- +apiVersion: v1 +kind: Secret +metadata: + name: envoy-ai-gateway-basic-aws-credentials + namespace: default +type: Opaque +stringData: + # Replace this with your AWS credentials. + # You can also use AWS IAM roles for service accounts (IRSA) in EKS. 
+ credentials: | + [default] + aws_access_key_id = AWS_ACCESS_KEY_ID + aws_secret_access_key = AWS_SECRET_ACCESS_KEY diff --git a/internal/endpointspec/endpointspec.go b/internal/endpointspec/endpointspec.go index a330f0873d..209ad2a8b7 100644 --- a/internal/endpointspec/endpointspec.go +++ b/internal/endpointspec/endpointspec.go @@ -129,6 +129,8 @@ func (ChatCompletionsEndpointSpec) GetTranslator(schema filterapi.VersionedAPISc return translator.NewChatCompletionOpenAIToOpenAITranslator(schema.OpenAIPrefix(), modelNameOverride), nil case filterapi.APISchemaAWSBedrock: return translator.NewChatCompletionOpenAIToAWSBedrockTranslator(modelNameOverride), nil + case filterapi.APISchemaAWSAnthropic: + return translator.NewChatCompletionOpenAIToAWSAnthropicTranslator(schema.Version, modelNameOverride), nil case filterapi.APISchemaAzureOpenAI: return translator.NewChatCompletionOpenAIToAzureOpenAITranslator(schema.Version, modelNameOverride), nil case filterapi.APISchemaGCPVertexAI: diff --git a/internal/endpointspec/endpointspec_test.go b/internal/endpointspec/endpointspec_test.go index b182760445..371c5cc998 100644 --- a/internal/endpointspec/endpointspec_test.go +++ b/internal/endpointspec/endpointspec_test.go @@ -81,6 +81,7 @@ func TestChatCompletionsEndpointSpec_GetTranslator(t *testing.T) { supported := []filterapi.VersionedAPISchema{ {Name: filterapi.APISchemaOpenAI, Prefix: "v1"}, {Name: filterapi.APISchemaAWSBedrock}, + {Name: filterapi.APISchemaAWSAnthropic}, {Name: filterapi.APISchemaAzureOpenAI, Version: "2024-02-01"}, {Name: filterapi.APISchemaGCPVertexAI}, {Name: filterapi.APISchemaGCPAnthropic, Version: "2024-05-01"}, diff --git a/internal/filterapi/filterconfig.go b/internal/filterapi/filterconfig.go index 947cce5f4f..529bb1cf59 100644 --- a/internal/filterapi/filterconfig.go +++ b/internal/filterapi/filterconfig.go @@ -114,7 +114,7 @@ const ( APISchemaOpenAI APISchemaName = "OpenAI" // APISchemaCohere represents the Cohere API schema. 
APISchemaCohere APISchemaName = "Cohere" - // APISchemaAWSBedrock represents the AWS Bedrock API schema. + // APISchemaAWSBedrock represents the AWS Bedrock Converse API schema. APISchemaAWSBedrock APISchemaName = "AWSBedrock" // APISchemaAzureOpenAI represents the Azure OpenAI API schema. APISchemaAzureOpenAI APISchemaName = "AzureOpenAI" @@ -127,7 +127,8 @@ const ( // APISchemaAnthropic represents the standard Anthropic API schema. APISchemaAnthropic APISchemaName = "Anthropic" // APISchemaAWSAnthropic represents the AWS Bedrock Anthropic API schema. - // Used for Claude models hosted on AWS Bedrock using the native Anthropic Messages API. + // Used for Claude models hosted on AWS Bedrock. Supports both OpenAI and Anthropic input formats + // depending on the endpoint path, similar to APISchemaGCPAnthropic. APISchemaAWSAnthropic APISchemaName = "AWSAnthropic" ) diff --git a/internal/translator/anthropic_helper.go b/internal/translator/anthropic_helper.go new file mode 100644 index 0000000000..3798c88469 --- /dev/null +++ b/internal/translator/anthropic_helper.go @@ -0,0 +1,1133 @@ +// Copyright Envoy AI Gateway Authors +// SPDX-License-Identifier: Apache-2.0 +// The full text of the Apache license is available in the LICENSE file at +// the root of the repo. 
+ +package translator + +import ( + "bytes" + "cmp" + "encoding/base64" + "fmt" + "io" + "strings" + "time" + + "github.com/anthropics/anthropic-sdk-go" + anthropicParam "github.com/anthropics/anthropic-sdk-go/packages/param" + "github.com/anthropics/anthropic-sdk-go/shared/constant" + openAIconstant "github.com/openai/openai-go/shared/constant" + "k8s.io/utils/ptr" + + "github.com/envoyproxy/ai-gateway/internal/apischema/awsbedrock" + "github.com/envoyproxy/ai-gateway/internal/apischema/openai" + "github.com/envoyproxy/ai-gateway/internal/internalapi" + "github.com/envoyproxy/ai-gateway/internal/json" + "github.com/envoyproxy/ai-gateway/internal/metrics" + "github.com/envoyproxy/ai-gateway/internal/tracing/tracingapi" +) + +const ( + anthropicVersionKey = "anthropic_version" + tempNotSupportedError = "temperature %.2f is not supported by Anthropic (must be between 0.0 and 1.0)" +) + +func anthropicToOpenAIFinishReason(stopReason anthropic.StopReason) (openai.ChatCompletionChoicesFinishReason, error) { + switch stopReason { + // The most common stop reason. Indicates Claude finished its response naturally. + // or Claude encountered one of your custom stop sequences. + // TODO: A better way to return pause_turn + // TODO: "pause_turn" Used with server tools like web search when Claude needs to pause a long-running operation. + case anthropic.StopReasonEndTurn, anthropic.StopReasonStopSequence, anthropic.StopReasonPauseTurn: + return openai.ChatCompletionChoicesFinishReasonStop, nil + case anthropic.StopReasonMaxTokens: // Claude stopped because it reached the max_tokens limit specified in your request. + // TODO: do we want to return an error? 
see: https://docs.anthropic.com/en/docs/agents-and-tools/tool-use/implement-tool-use#handling-the-max-tokens-stop-reason + return openai.ChatCompletionChoicesFinishReasonLength, nil + case anthropic.StopReasonToolUse: + return openai.ChatCompletionChoicesFinishReasonToolCalls, nil + case anthropic.StopReasonRefusal: + return openai.ChatCompletionChoicesFinishReasonContentFilter, nil + default: + return "", fmt.Errorf("received invalid stop reason %v", stopReason) + } +} + +// validateTemperatureForAnthropic checks if the temperature is within Anthropic's supported range (0.0 to 1.0). +// Returns an error if the value is less than 0.0 or greater than 1.0. +func validateTemperatureForAnthropic(temp *float64) error { + if temp != nil && (*temp < 0.0 || *temp > 1.0) { + return fmt.Errorf("%w: "+tempNotSupportedError, internalapi.ErrInvalidRequestBody, *temp) + } + return nil +} + +// translateAnthropicToolChoice converts the OpenAI tool_choice parameter to the Anthropic format. +func translateAnthropicToolChoice(openAIToolChoice *openai.ChatCompletionToolChoiceUnion, disableParallelToolUse anthropicParam.Opt[bool]) (anthropic.ToolChoiceUnionParam, error) { + var toolChoice anthropic.ToolChoiceUnionParam + + if openAIToolChoice == nil { + return toolChoice, nil + } + + switch choice := openAIToolChoice.Value.(type) { + case string: + switch choice { + case string(openAIconstant.ValueOf[openAIconstant.Auto]()): + toolChoice = anthropic.ToolChoiceUnionParam{OfAuto: &anthropic.ToolChoiceAutoParam{}} + toolChoice.OfAuto.DisableParallelToolUse = disableParallelToolUse + case "required", "any": + toolChoice = anthropic.ToolChoiceUnionParam{OfAny: &anthropic.ToolChoiceAnyParam{}} + toolChoice.OfAny.DisableParallelToolUse = disableParallelToolUse + case "none": + toolChoice = anthropic.ToolChoiceUnionParam{OfNone: &anthropic.ToolChoiceNoneParam{}} + case string(openAIconstant.ValueOf[openAIconstant.Function]()): + // This is how anthropic forces tool use.
+ // TODO: should we check if strict true in openAI request, and if so, use this? + toolChoice = anthropic.ToolChoiceUnionParam{OfTool: &anthropic.ToolChoiceToolParam{Name: choice}} + toolChoice.OfTool.DisableParallelToolUse = disableParallelToolUse + default: + return anthropic.ToolChoiceUnionParam{}, fmt.Errorf("unsupported tool_choice value: %s", choice) + } + case openai.ChatCompletionNamedToolChoice: + if choice.Type == openai.ToolTypeFunction && choice.Function.Name != "" { + toolChoice = anthropic.ToolChoiceUnionParam{ + OfTool: &anthropic.ToolChoiceToolParam{ + Type: constant.Tool("tool"), + Name: choice.Function.Name, + DisableParallelToolUse: disableParallelToolUse, + }, + } + } + default: + return anthropic.ToolChoiceUnionParam{}, fmt.Errorf("unsupported tool_choice type: %T", openAIToolChoice) + } + return toolChoice, nil +} + +func isAnthropicSupportedImageMediaType(mediaType string) bool { + switch anthropic.Base64ImageSourceMediaType(mediaType) { + case anthropic.Base64ImageSourceMediaTypeImageJPEG, + anthropic.Base64ImageSourceMediaTypeImagePNG, + anthropic.Base64ImageSourceMediaTypeImageGIF, + anthropic.Base64ImageSourceMediaTypeImageWebP: + return true + default: + return false + } +} + +// translateOpenAItoAnthropicTools translates OpenAI tool and tool_choice parameters +// into the Anthropic format and returns translated tool & tool choice. +func translateOpenAItoAnthropicTools(openAITools []openai.Tool, openAIToolChoice *openai.ChatCompletionToolChoiceUnion, parallelToolCalls *bool) (tools []anthropic.ToolUnionParam, toolChoice anthropic.ToolChoiceUnionParam, err error) { + if len(openAITools) > 0 { + anthropicTools := make([]anthropic.ToolUnionParam, 0, len(openAITools)) + for _, openAITool := range openAITools { + if openAITool.Type != openai.ToolTypeFunction || openAITool.Function == nil { + // Anthropic only supports 'function' tools, so we skip others. 
+ continue + } + toolParam := anthropic.ToolParam{ + Name: openAITool.Function.Name, + Description: anthropic.String(openAITool.Function.Description), + } + + if isCacheEnabled(openAITool.Function.AnthropicContentFields) { + toolParam.CacheControl = anthropic.NewCacheControlEphemeralParam() + } + + // The parameters for the function are expected to be a JSON Schema object. + // We can pass them through as-is. + if openAITool.Function.Parameters != nil { + paramsMap, ok := openAITool.Function.Parameters.(map[string]any) + if !ok { + err = fmt.Errorf("failed to cast tool parameters to map[string]interface{}") + return + } + + inputSchema := anthropic.ToolInputSchemaParam{} + + // Dereference json schema + // If the paramsMap contains $refs we need to dereference them + var dereferencedParamsMap any + if dereferencedParamsMap, err = jsonSchemaDereference(paramsMap); err != nil { + return nil, anthropic.ToolChoiceUnionParam{}, fmt.Errorf("failed to dereference tool parameters: %w", err) + } + if paramsMap, ok = dereferencedParamsMap.(map[string]any); !ok { + return nil, anthropic.ToolChoiceUnionParam{}, fmt.Errorf("failed to cast dereferenced tool parameters to map[string]interface{}") + } + + var typeVal string + if typeVal, ok = paramsMap["type"].(string); ok { + inputSchema.Type = constant.Object(typeVal) + } + + var propsVal map[string]any + if propsVal, ok = paramsMap["properties"].(map[string]any); ok { + inputSchema.Properties = propsVal + } + + var requiredVal []any + if requiredVal, ok = paramsMap["required"].([]any); ok { + requiredSlice := make([]string, len(requiredVal)) + for i, v := range requiredVal { + if s, ok := v.(string); ok { + requiredSlice[i] = s + } + } + inputSchema.Required = requiredSlice + } + + toolParam.InputSchema = inputSchema + } + + anthropicTools = append(anthropicTools, anthropic.ToolUnionParam{OfTool: &toolParam}) + if len(anthropicTools) > 0 { + tools = anthropicTools + } + } + + // 2. Handle the tool_choice parameter. 
+ // disable parallel tool use default value is false + // see: https://docs.anthropic.com/en/docs/agents-and-tools/tool-use/implement-tool-use#parallel-tool-use + disableParallelToolUse := anthropic.Bool(false) + if parallelToolCalls != nil { + // OpenAI variable checks to allow parallel tool calls. + // Anthropic variable checks to disable, so need to use the inverse. + disableParallelToolUse = anthropic.Bool(!*parallelToolCalls) + } + + toolChoice, err = translateAnthropicToolChoice(openAIToolChoice, disableParallelToolUse) + if err != nil { + return + } + + } + return +} + +// convertImageContentToAnthropic translates an OpenAI image URL into the corresponding Anthropic content block. +// It handles data URIs for various image types and PDFs, as well as remote URLs. +func convertImageContentToAnthropic(imageURL string, fields *openai.AnthropicContentFields) (anthropic.ContentBlockParamUnion, error) { + var cacheControlParam anthropic.CacheControlEphemeralParam + if isCacheEnabled(fields) { + cacheControlParam = fields.CacheControl + } + + switch { + case strings.HasPrefix(imageURL, "data:"): + contentType, data, err := parseDataURI(imageURL) + if err != nil { + return anthropic.ContentBlockParamUnion{}, fmt.Errorf("failed to parse image URL: %w", err) + } + base64Data := base64.StdEncoding.EncodeToString(data) + if contentType == string(constant.ValueOf[constant.ApplicationPDF]()) { + pdfSource := anthropic.Base64PDFSourceParam{Data: base64Data} + docBlock := anthropic.NewDocumentBlock(pdfSource) + docBlock.OfDocument.CacheControl = cacheControlParam + return docBlock, nil + } + if isAnthropicSupportedImageMediaType(contentType) { + imgBlock := anthropic.NewImageBlockBase64(contentType, base64Data) + imgBlock.OfImage.CacheControl = cacheControlParam + return imgBlock, nil + } + return anthropic.ContentBlockParamUnion{}, fmt.Errorf("invalid media_type for image '%s'", contentType) + case strings.HasSuffix(strings.ToLower(imageURL), ".pdf"): + docBlock := 
anthropic.NewDocumentBlock(anthropic.URLPDFSourceParam{URL: imageURL}) + docBlock.OfDocument.CacheControl = cacheControlParam + return docBlock, nil + default: + imgBlock := anthropic.NewImageBlock(anthropic.URLImageSourceParam{URL: imageURL}) + imgBlock.OfImage.CacheControl = cacheControlParam + return imgBlock, nil + } +} + +func isCacheEnabled(fields *openai.AnthropicContentFields) bool { + return fields != nil && fields.CacheControl.Type == constant.ValueOf[constant.Ephemeral]() +} + +// convertContentPartsToAnthropic iterates over a slice of OpenAI content parts +// and converts each into an Anthropic content block. +func convertContentPartsToAnthropic(parts []openai.ChatCompletionContentPartUserUnionParam) ([]anthropic.ContentBlockParamUnion, error) { + resultContent := make([]anthropic.ContentBlockParamUnion, 0, len(parts)) + for _, contentPart := range parts { + switch { + case contentPart.OfText != nil: + textBlock := anthropic.NewTextBlock(contentPart.OfText.Text) + if isCacheEnabled(contentPart.OfText.AnthropicContentFields) { + textBlock.OfText.CacheControl = contentPart.OfText.CacheControl + } + resultContent = append(resultContent, textBlock) + + case contentPart.OfImageURL != nil: + block, err := convertImageContentToAnthropic(contentPart.OfImageURL.ImageURL.URL, contentPart.OfImageURL.AnthropicContentFields) + if err != nil { + return nil, err + } + resultContent = append(resultContent, block) + + case contentPart.OfInputAudio != nil: + return nil, fmt.Errorf("input audio content not supported yet") + case contentPart.OfFile != nil: + return nil, fmt.Errorf("file content not supported yet") + } + } + return resultContent, nil +} + +// Helper: Convert OpenAI message content to Anthropic content. 
+func openAIToAnthropicContent(content any) ([]anthropic.ContentBlockParamUnion, error) { + switch v := content.(type) { + case nil: + return nil, nil + case string: + if v == "" { + return nil, nil + } + return []anthropic.ContentBlockParamUnion{ + anthropic.NewTextBlock(v), + }, nil + case []openai.ChatCompletionContentPartUserUnionParam: + return convertContentPartsToAnthropic(v) + case openai.ContentUnion: + switch val := v.Value.(type) { + case string: + if val == "" { + return nil, nil + } + return []anthropic.ContentBlockParamUnion{ + anthropic.NewTextBlock(val), + }, nil + case []openai.ChatCompletionContentPartTextParam: + var contentBlocks []anthropic.ContentBlockParamUnion + for _, part := range val { + textBlock := anthropic.NewTextBlock(part.Text) + // In an array of text parts, each can have its own cache setting. + if isCacheEnabled(part.AnthropicContentFields) { + textBlock.OfText.CacheControl = part.CacheControl + } + contentBlocks = append(contentBlocks, textBlock) + } + return contentBlocks, nil + default: + return nil, fmt.Errorf("unsupported ContentUnion value type: %T", val) + } + } + return nil, fmt.Errorf("unsupported OpenAI content type: %T", content) +} + +// extractSystemPromptFromDeveloperMsg flattens content and checks for cache flags. +// It returns the combined string and, if any part has caching enabled, a pointer to that part's cache control (nil otherwise). +func extractSystemPromptFromDeveloperMsg(msg openai.ChatCompletionDeveloperMessageParam) (msgValue string, cacheParam *anthropic.CacheControlEphemeralParam) { + switch v := msg.Content.Value.(type) { + case nil: + return + case string: + msgValue = v + return + case []openai.ChatCompletionContentPartTextParam: + // Concatenate all text parts and check for caching.
+ var sb strings.Builder + for _, part := range v { + sb.WriteString(part.Text) + if isCacheEnabled(part.AnthropicContentFields) { + cacheParam = &part.CacheControl + } + } + msgValue = sb.String() + return + default: + return + } +} + +func anthropicRoleToOpenAIRole(role anthropic.MessageParamRole) (string, error) { + switch role { + case anthropic.MessageParamRoleAssistant: + return openai.ChatMessageRoleAssistant, nil + case anthropic.MessageParamRoleUser: + return openai.ChatMessageRoleUser, nil + default: + return "", fmt.Errorf("invalid anthropic role %v", role) + } +} + +// processAssistantContent processes a single assistant content block and adds it to the content blocks. +func processAssistantContent(contentBlocks []anthropic.ContentBlockParamUnion, content openai.ChatCompletionAssistantMessageParamContent) ([]anthropic.ContentBlockParamUnion, error) { + switch content.Type { + case openai.ChatCompletionAssistantMessageParamContentTypeRefusal: + if content.Refusal != nil { + contentBlocks = append(contentBlocks, anthropic.NewTextBlock(*content.Refusal)) + } + case openai.ChatCompletionAssistantMessageParamContentTypeText: + if content.Text != nil { + textBlock := anthropic.NewTextBlock(*content.Text) + if isCacheEnabled(content.AnthropicContentFields) { + textBlock.OfText.CacheControl = content.CacheControl + } + contentBlocks = append(contentBlocks, textBlock) + } + case openai.ChatCompletionAssistantMessageParamContentTypeThinking: + // Thinking content requires both text and signature + if content.Text != nil && content.Signature != nil { + contentBlocks = append(contentBlocks, anthropic.NewThinkingBlock(*content.Signature, *content.Text)) + } + case openai.ChatCompletionAssistantMessageParamContentTypeRedactedThinking: + if content.RedactedContent != nil { + switch v := content.RedactedContent.Value.(type) { + case string: + contentBlocks = append(contentBlocks, anthropic.NewRedactedThinkingBlock(v)) + default: + return nil, fmt.Errorf("unsupported 
RedactedContent type: %T, expected string", v) + } + } + default: + return nil, fmt.Errorf("content type not supported: %v", content.Type) + } + return contentBlocks, nil +} + +// openAIMessageToAnthropicMessageRoleAssistant converts an OpenAI assistant message to Anthropic content blocks. +// The tool_use content is appended to the Anthropic message content list if tool_calls are present. +func openAIMessageToAnthropicMessageRoleAssistant(openAiMessage *openai.ChatCompletionAssistantMessageParam) (anthropicMsg anthropic.MessageParam, err error) { + contentBlocks := make([]anthropic.ContentBlockParamUnion, 0) + if v, ok := openAiMessage.Content.Value.(string); ok && len(v) > 0 { + contentBlocks = append(contentBlocks, anthropic.NewTextBlock(v)) + } else if content, ok := openAiMessage.Content.Value.(openai.ChatCompletionAssistantMessageParamContent); ok { + contentBlocks, err = processAssistantContent(contentBlocks, content) + if err != nil { + return + } + } else if contents, ok := openAiMessage.Content.Value.([]openai.ChatCompletionAssistantMessageParamContent); ok { + for _, content := range contents { + contentBlocks, err = processAssistantContent(contentBlocks, content) + if err != nil { + return + } + } + } + + // Handle tool_calls (if any). 
+ for i := range openAiMessage.ToolCalls { + toolCall := &openAiMessage.ToolCalls[i] + var input map[string]any + if err = json.Unmarshal([]byte(toolCall.Function.Arguments), &input); err != nil { + err = fmt.Errorf("failed to unmarshal tool call arguments: %w", err) + return + } + toolUse := anthropic.ToolUseBlockParam{ + ID: *toolCall.ID, + Type: "tool_use", + Name: toolCall.Function.Name, + Input: input, + } + + if isCacheEnabled(toolCall.AnthropicContentFields) { + toolUse.CacheControl = toolCall.CacheControl + } + + contentBlocks = append(contentBlocks, anthropic.ContentBlockParamUnion{OfToolUse: &toolUse}) + } + + return anthropic.MessageParam{ + Role: anthropic.MessageParamRoleAssistant, + Content: contentBlocks, + }, nil +} + +// openAIToAnthropicMessages converts OpenAI messages to Anthropic message params type, handling all roles and system/developer logic. +func openAIToAnthropicMessages(openAIMsgs []openai.ChatCompletionMessageParamUnion) (anthropicMessages []anthropic.MessageParam, systemBlocks []anthropic.TextBlockParam, err error) { + for i := 0; i < len(openAIMsgs); { + msg := &openAIMsgs[i] + switch { + case msg.OfSystem != nil: + devParam := systemMsgToDeveloperMsg(*msg.OfSystem) + systemText, cacheControl := extractSystemPromptFromDeveloperMsg(devParam) + systemBlock := anthropic.TextBlockParam{Text: systemText} + if cacheControl != nil { + systemBlock.CacheControl = *cacheControl + } + systemBlocks = append(systemBlocks, systemBlock) + i++ + case msg.OfDeveloper != nil: + systemText, cacheControl := extractSystemPromptFromDeveloperMsg(*msg.OfDeveloper) + systemBlock := anthropic.TextBlockParam{Text: systemText} + if cacheControl != nil { + systemBlock.CacheControl = *cacheControl + } + systemBlocks = append(systemBlocks, systemBlock) + i++ + case msg.OfUser != nil: + message := *msg.OfUser + var content []anthropic.ContentBlockParamUnion + content, err = openAIToAnthropicContent(message.Content.Value) + if err != nil { + return + } + 
anthropicMsg := anthropic.MessageParam{ + Role: anthropic.MessageParamRoleUser, + Content: content, + } + anthropicMessages = append(anthropicMessages, anthropicMsg) + i++ + case msg.OfAssistant != nil: + assistantMessage := msg.OfAssistant + var messages anthropic.MessageParam + messages, err = openAIMessageToAnthropicMessageRoleAssistant(assistantMessage) + if err != nil { + return + } + anthropicMessages = append(anthropicMessages, messages) + i++ + case msg.OfTool != nil: + // Aggregate all consecutive tool messages into a single user message + // to support parallel tool use. + var toolResultBlocks []anthropic.ContentBlockParamUnion + for i < len(openAIMsgs) && openAIMsgs[i].ExtractMessgaeRole() == openai.ChatMessageRoleTool { + currentMsg := &openAIMsgs[i] + toolMsg := currentMsg.OfTool + + var contentBlocks []anthropic.ContentBlockParamUnion + contentBlocks, err = openAIToAnthropicContent(toolMsg.Content) + if err != nil { + return + } + + var toolContent []anthropic.ToolResultBlockParamContentUnion + var cacheControl *anthropic.CacheControlEphemeralParam + + for _, c := range contentBlocks { + var trb anthropic.ToolResultBlockParamContentUnion + // Check if the translated part has caching enabled. 
+ switch { + case c.OfText != nil: + trb.OfText = c.OfText + cacheControl = &c.OfText.CacheControl + case c.OfImage != nil: + trb.OfImage = c.OfImage + cacheControl = &c.OfImage.CacheControl + case c.OfDocument != nil: + trb.OfDocument = c.OfDocument + cacheControl = &c.OfDocument.CacheControl + } + toolContent = append(toolContent, trb) + } + + isError := false + if contentStr, ok := toolMsg.Content.Value.(string); ok { + var contentMap map[string]any + if json.Unmarshal([]byte(contentStr), &contentMap) == nil { + if _, ok = contentMap["error"]; ok { + isError = true + } + } + } + + toolResultBlock := anthropic.ToolResultBlockParam{ + ToolUseID: toolMsg.ToolCallID, + Type: "tool_result", + Content: toolContent, + IsError: anthropic.Bool(isError), + } + + if cacheControl != nil { + toolResultBlock.CacheControl = *cacheControl + } + + toolResultBlockUnion := anthropic.ContentBlockParamUnion{OfToolResult: &toolResultBlock} + toolResultBlocks = append(toolResultBlocks, toolResultBlockUnion) + i++ + } + // Append all aggregated tool results. + anthropicMsg := anthropic.MessageParam{ + Role: anthropic.MessageParamRoleUser, + Content: toolResultBlocks, + } + anthropicMessages = append(anthropicMessages, anthropicMsg) + default: + err = fmt.Errorf("unsupported OpenAI role type: %s", msg.ExtractMessgaeRole()) + return + } + } + return +} + +// getThinkingConfigParamUnion converts a ThinkingUnion into a ThinkingConfigParamUnion.
+func getThinkingConfigParamUnion(tu *openai.ThinkingUnion) *anthropic.ThinkingConfigParamUnion { + if tu == nil { + return nil + } + + result := &anthropic.ThinkingConfigParamUnion{} + + if tu.OfEnabled != nil { + result.OfEnabled = &anthropic.ThinkingConfigEnabledParam{ + BudgetTokens: tu.OfEnabled.BudgetTokens, + Type: constant.Enabled(tu.OfEnabled.Type), + } + } else if tu.OfDisabled != nil { + result.OfDisabled = &anthropic.ThinkingConfigDisabledParam{ + Type: constant.Disabled(tu.OfDisabled.Type), + } + } + + return result +} + +// buildAnthropicParams is a helper function that translates an OpenAI request +// into the parameter struct required by the Anthropic SDK. +func buildAnthropicParams(openAIReq *openai.ChatCompletionRequest) (params *anthropic.MessageNewParams, err error) { + // 1. Handle simple parameters and defaults. + maxTokens := cmp.Or(openAIReq.MaxCompletionTokens, openAIReq.MaxTokens) + if maxTokens == nil { + err = fmt.Errorf("%w: max_tokens or max_completion_tokens is required", internalapi.ErrInvalidRequestBody) + return + } + + // Translate openAI contents to anthropic params. + // 2. Translate messages and system prompts. + messages, systemBlocks, err := openAIToAnthropicMessages(openAIReq.Messages) + if err != nil { + return + } + + // 3. Translate tools and tool choice. + tools, toolChoice, err := translateOpenAItoAnthropicTools(openAIReq.Tools, openAIReq.ToolChoice, openAIReq.ParallelToolCalls) + if err != nil { + return + } + + // 4. Construct the final struct in one place. 
+ params = &anthropic.MessageNewParams{ + Messages: messages, + MaxTokens: *maxTokens, + System: systemBlocks, + Tools: tools, + ToolChoice: toolChoice, + } + + if openAIReq.Temperature != nil { + if err = validateTemperatureForAnthropic(openAIReq.Temperature); err != nil { + return nil, err + } + params.Temperature = anthropic.Float(*openAIReq.Temperature) + } + if openAIReq.TopP != nil { + params.TopP = anthropic.Float(*openAIReq.TopP) + } + if openAIReq.Stop.OfString.Valid() { + params.StopSequences = []string{openAIReq.Stop.OfString.String()} + } else if openAIReq.Stop.OfStringArray != nil { + params.StopSequences = openAIReq.Stop.OfStringArray + } + + // 5. Handle Vendor specific fields. + // Since GCPAnthropic follows the Anthropic API, we also check for Anthropic vendor fields. + if openAIReq.Thinking != nil { + params.Thinking = *getThinkingConfigParamUnion(openAIReq.Thinking) + } + + return params, nil +} + +// anthropicToolUseToOpenAICalls converts Anthropic tool_use content blocks to OpenAI tool calls. +func anthropicToolUseToOpenAICalls(block *anthropic.ContentBlockUnion) ([]openai.ChatCompletionMessageToolCallParam, error) { + var toolCalls []openai.ChatCompletionMessageToolCallParam + if block.Type != string(constant.ValueOf[constant.ToolUse]()) { + return toolCalls, nil + } + argsBytes, err := json.Marshal(block.Input) + if err != nil { + return nil, fmt.Errorf("failed to marshal tool_use input: %w", err) + } + toolCalls = append(toolCalls, openai.ChatCompletionMessageToolCallParam{ + ID: &block.ID, + Type: openai.ChatCompletionMessageToolCallTypeFunction, + Function: openai.ChatCompletionMessageToolCallFunctionParam{ + Name: block.Name, + Arguments: string(argsBytes), + }, + }) + + return toolCalls, nil +} + +// following are streaming part + +var ( + sseEventPrefix = []byte("event: ") + emptyStrPtr = ptr.To("") +) + +// streamingToolCall holds the state for a single tool call that is being streamed. 
+type streamingToolCall struct {
+	// id and name are copied from the content_block_start event of the tool call.
+	id   string
+	name string
+	// inputJSON is the argument JSON gathered so far: any input present in the
+	// start event followed by every input_json_delta fragment.
+	inputJSON string
+}
+
+// anthropicStreamParser manages the stateful translation of an Anthropic SSE stream
+// to an OpenAI-compatible SSE stream.
+type anthropicStreamParser struct {
+	// buffer holds bytes that have been read but do not yet form a complete event.
+	buffer bytes.Buffer
+	// activeMessageID is taken from message_start and used as the ID of every chunk.
+	activeMessageID string
+	// activeToolCalls holds tool calls that have started and not yet been removed
+	// by a content_block_stop, keyed by their tool index.
+	activeToolCalls map[int64]*streamingToolCall
+	// toolIndex is the index of the most recently started tool call in the current
+	// message; it is -1 until the first tool call starts.
+	toolIndex int64
+	// tokenUsage accumulates the usage reported by message_start and message_delta.
+	tokenUsage metrics.TokenUsage
+	// stopReason is the last non-empty stop reason seen in a message_delta event.
+	stopReason anthropic.StopReason
+	// requestModel is echoed as the model of every chunk and as the response model.
+	requestModel internalapi.RequestModel
+	// sentFirstChunk records whether the assistant role has already been emitted.
+	sentFirstChunk bool
+	// created is the timestamp captured when message_start was handled.
+	created openai.JSONUNIXTime
+}
+
+// newAnthropicStreamParser creates a new parser for a streaming request.
+// The parser starts with no active tool calls and a tool index of -1.
+func newAnthropicStreamParser(requestModel string) *anthropicStreamParser {
+	toolIdx := int64(-1)
+	return &anthropicStreamParser{
+		requestModel:    requestModel,
+		activeToolCalls: make(map[int64]*streamingToolCall),
+		toolIndex:       toolIdx,
+	}
+}
+
+// writeChunk parses one SSE event block and, when the event translates to an
+// OpenAI chunk, hands that chunk to serializeOpenAIChatCompletionChunk together
+// with buf. Events that produce no chunk leave buf untouched.
+func (p *anthropicStreamParser) writeChunk(eventBlock []byte, buf *[]byte) error {
+	chunk, err := p.parseAndHandleEvent(eventBlock)
+	if err != nil {
+		return err
+	}
+	if chunk != nil {
+		err := serializeOpenAIChatCompletionChunk(chunk, buf)
+		if err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
+// Process reads from the Anthropic SSE stream, translates events to OpenAI chunks,
+// and returns the mutations for Envoy.
+//
+// Bytes that do not yet form a complete event (one terminated by a blank line,
+// "\n\n") stay in p.buffer for the next call. When endOfStream is true, leftover
+// bytes are handled as a final event, a closing chunk with usage is emitted if it
+// carries token counts or choices, and the done message is appended. newHeaders
+// is never assigned by this method, and responseModel is always p.requestModel.
+func (p *anthropicStreamParser) Process(body io.Reader, endOfStream bool, span tracingapi.ChatCompletionSpan) (
+	newHeaders []internalapi.Header, newBody []byte, tokenUsage metrics.TokenUsage, responseModel string, err error,
+) {
+	newBody = make([]byte, 0)
+	_ = span // TODO: add support for streaming chunks in tracing.
+	responseModel = p.requestModel
+	if _, err = p.buffer.ReadFrom(body); err != nil {
+		err = fmt.Errorf("failed to read from stream body: %w", err)
+		return
+	}
+
+	// Consume every complete event currently in the buffer.
+	for {
+		eventBlock, remaining, found := bytes.Cut(p.buffer.Bytes(), []byte("\n\n"))
+		if !found {
+			break
+		}
+
+		if err = p.writeChunk(eventBlock, &newBody); err != nil {
+			return
+		}
+
+		// remaining aliases the buffer's own storage; Reset keeps that storage and
+		// Write copies the unconsumed tail back to the front of it.
+		p.buffer.Reset()
+		p.buffer.Write(remaining)
+	}
+
+	// At end of stream, whatever is left is treated as one last event even though
+	// it has no terminating blank line.
+	if endOfStream && p.buffer.Len() > 0 {
+		finalEventBlock := p.buffer.Bytes()
+		p.buffer.Reset()
+
+		if err = p.writeChunk(finalEventBlock, &newBody); err != nil {
+			return
+		}
+	}
+
+	if endOfStream {
+		// The total is derived here from the accumulated input and output counts.
+		inputTokens, _ := p.tokenUsage.InputTokens()
+		outputTokens, _ := p.tokenUsage.OutputTokens()
+		p.tokenUsage.SetTotalTokens(inputTokens + outputTokens)
+		totalTokens, _ := p.tokenUsage.TotalTokens()
+		cachedTokens, _ := p.tokenUsage.CachedInputTokens()
+		cacheCreationTokens, _ := p.tokenUsage.CacheCreationInputTokens()
+		finalChunk := openai.ChatCompletionResponseChunk{
+			ID:      p.activeMessageID,
+			Created: p.created,
+			Object:  "chat.completion.chunk",
+			Choices: []openai.ChatCompletionResponseChunkChoice{},
+			Usage: &openai.Usage{
+				PromptTokens:     int(inputTokens),
+				CompletionTokens: int(outputTokens),
+				TotalTokens:      int(totalTokens),
+				PromptTokensDetails: &openai.PromptTokensDetails{
+					CachedTokens:        int(cachedTokens),
+					CacheCreationTokens: int(cacheCreationTokens),
+				},
+			},
+			Model: p.requestModel,
+		}
+
+		// Add active tool calls to the final chunk.
+		// NOTE(review): Go map iteration order is unspecified, so if more than one
+		// tool call is still active the order of ToolCalls is not deterministic.
+		var toolCalls []openai.ChatCompletionChunkChoiceDeltaToolCall
+		for toolIndex, tool := range p.activeToolCalls {
+			toolCalls = append(toolCalls, openai.ChatCompletionChunkChoiceDeltaToolCall{
+				ID:   &tool.id,
+				Type: openai.ChatCompletionMessageToolCallTypeFunction,
+				Function: openai.ChatCompletionMessageToolCallFunctionParam{
+					Name:      tool.name,
+					Arguments: tool.inputJSON,
+				},
+				Index: toolIndex,
+			})
+		}
+
+		if len(toolCalls) > 0 {
+			delta := openai.ChatCompletionResponseChunkChoiceDelta{
+				ToolCalls: toolCalls,
+			}
+			finalChunk.Choices = append(finalChunk.Choices, openai.ChatCompletionResponseChunkChoice{
+				Delta: &delta,
+			})
+		}
+
+		// Skip the closing chunk entirely when it would carry neither usage nor choices.
+		if finalChunk.Usage.PromptTokens > 0 || finalChunk.Usage.CompletionTokens > 0 || len(finalChunk.Choices) > 0 {
+			err := serializeOpenAIChatCompletionChunk(&finalChunk, &newBody)
+			if err != nil {
+				return nil, nil, metrics.TokenUsage{}, "", fmt.Errorf("failed to marshal final stream chunk: %w", err)
+			}
+		}
+		// Add the final [DONE] message to indicate the end of the stream.
+		newBody = append(newBody, sseDataPrefix...)
+		newBody = append(newBody, sseDoneMessage...)
+		newBody = append(newBody, '\n', '\n')
+	}
+	tokenUsage = p.tokenUsage
+	return
+}
+
+// parseAndHandleEvent splits one SSE event block into its event type and data
+// payload and dispatches the pair to handleAnthropicStreamEvent. A block that
+// lacks either part is ignored and yields (nil, nil).
+func (p *anthropicStreamParser) parseAndHandleEvent(eventBlock []byte) (*openai.ChatCompletionResponseChunk, error) {
+	var eventType []byte
+	var eventData []byte
+
+	lines := bytes.SplitSeq(eventBlock, []byte("\n"))
+	for line := range lines {
+		if after, ok := bytes.CutPrefix(line, sseEventPrefix); ok {
+			// If several event lines are present, the last one wins.
+			eventType = bytes.TrimSpace(after)
+		} else if after, ok := bytes.CutPrefix(line, sseDataPrefix); ok {
+			// This handles JSON data that might be split across multiple 'data:' lines
+			// by concatenating them (Anthropic's format).
+			// Each fragment is trimmed first, so no separator is kept between them.
+			data := bytes.TrimSpace(after)
+			eventData = append(eventData, data...)
+		}
+	}
+
+	if len(eventType) > 0 && len(eventData) > 0 {
+		return p.handleAnthropicStreamEvent(eventType, eventData)
+	}
+
+	return nil, nil
+}
+
+// handleAnthropicStreamEvent updates the parser state for a single Anthropic
+// stream event and returns the OpenAI chunk to emit for it, or nil when the
+// event produces no output. Event types without a case below are ignored.
+func (p *anthropicStreamParser) handleAnthropicStreamEvent(eventType []byte, data []byte) (*openai.ChatCompletionResponseChunk, error) {
+	switch string(eventType) {
+	case string(constant.ValueOf[constant.MessageStart]()):
+		var event anthropic.MessageStartEvent
+		if err := json.Unmarshal(data, &event); err != nil {
+			return nil, fmt.Errorf("unmarshal message_start: %w", err)
+		}
+		p.activeMessageID = event.Message.ID
+		p.created = openai.JSONUNIXTime(time.Now())
+		u := event.Message.Usage
+		usage := metrics.ExtractTokenUsageFromExplicitCaching(
+			u.InputTokens,
+			u.OutputTokens,
+			&u.CacheReadInputTokens,
+			&u.CacheCreationInputTokens,
+		)
+		// For message_start, we store the initial input and cached-input counts.
+		// Output tokens are accumulated later from message_delta events.
+		// NOTE(review): the cache-creation count is not copied into p.tokenUsage
+		// here or in message_delta, so unless these setters do so implicitly it
+		// stays unset for the final usage chunk; confirm this is intended.
+		if input, ok := usage.InputTokens(); ok {
+			p.tokenUsage.SetInputTokens(input)
+		}
+		if cached, ok := usage.CachedInputTokens(); ok {
+			p.tokenUsage.SetCachedInputTokens(cached)
+		}
+
+		// reset the toolIndex for each message
+		p.toolIndex = -1
+		return nil, nil
+
+	case string(constant.ValueOf[constant.ContentBlockStart]()):
+		var event anthropic.ContentBlockStartEvent
+		if err := json.Unmarshal(data, &event); err != nil {
+			return nil, fmt.Errorf("failed to unmarshal content_block_start: %w", err)
+		}
+		// tool_use and server_tool_use blocks both open a new OpenAI tool call.
+		if event.ContentBlock.Type == string(constant.ValueOf[constant.ToolUse]()) || event.ContentBlock.Type == string(constant.ValueOf[constant.ServerToolUse]()) {
+			p.toolIndex++
+			var argsJSON string
+			// Check if the input field is provided directly in the start event.
+			if event.ContentBlock.Input != nil {
+				switch input := event.ContentBlock.Input.(type) {
+				case map[string]any:
+					// for case where "input":{}, skip adding it to arguments.
+					if len(input) > 0 {
+						argsBytes, err := json.Marshal(input)
+						if err != nil {
+							return nil, fmt.Errorf("failed to marshal tool use input: %w", err)
+						}
+						argsJSON = string(argsBytes)
+					}
+				default:
+					// although golang sdk defines type of Input to be any,
+					// python sdk requires the type of Input to be Dict[str, object]:
+					// https://github.com/anthropics/anthropic-sdk-python/blob/main/src/anthropic/types/tool_use_block.py#L14.
+					return nil, fmt.Errorf("unexpected tool use input type: %T", input)
+				}
+			}
+
+			// Record the tool call with whatever input JSON the start event carried;
+			// later input_json_delta events append to it.
+			p.activeToolCalls[p.toolIndex] = &streamingToolCall{
+				id:        event.ContentBlock.ID,
+				name:      event.ContentBlock.Name,
+				inputJSON: argsJSON,
+			}
+
+			delta := openai.ChatCompletionResponseChunkChoiceDelta{
+				ToolCalls: []openai.ChatCompletionChunkChoiceDeltaToolCall{
+					{
+						Index: p.toolIndex,
+						ID:    &event.ContentBlock.ID,
+						Type:  openai.ChatCompletionMessageToolCallTypeFunction,
+						Function: openai.ChatCompletionMessageToolCallFunctionParam{
+							Name: event.ContentBlock.Name,
+							// Include the arguments if they are available.
+							Arguments: argsJSON,
+						},
+					},
+				},
+			}
+			return p.constructOpenAIChatCompletionChunk(delta, ""), nil
+		}
+		// A thinking block start is surfaced as a chunk with empty content.
+		if event.ContentBlock.Type == string(constant.ValueOf[constant.Thinking]()) {
+			delta := openai.ChatCompletionResponseChunkChoiceDelta{Content: emptyStrPtr}
+			return p.constructOpenAIChatCompletionChunk(delta, ""), nil
+		}
+
+		if event.ContentBlock.Type == string(constant.ValueOf[constant.RedactedThinking]()) {
+			// This is a latency-hiding event, ignore it.
+			return nil, nil
+		}
+
+		// Any other block type (for example text) produces no chunk at start.
+		return nil, nil
+
+	case string(constant.ValueOf[constant.MessageDelta]()):
+		var event anthropic.MessageDeltaEvent
+		if err := json.Unmarshal(data, &event); err != nil {
+			return nil, fmt.Errorf("unmarshal message_delta: %w", err)
+		}
+		u := event.Usage
+		usage := metrics.ExtractTokenUsageFromExplicitCaching(
+			u.InputTokens,
+			u.OutputTokens,
+			&u.CacheReadInputTokens,
+			&u.CacheCreationInputTokens,
+		)
+		// For message_delta, accumulate the incremental output tokens
+		if output, ok := usage.OutputTokens(); ok {
+			p.tokenUsage.AddOutputTokens(output)
+		}
+		// Update input tokens to include any cache tokens from delta
+		if cached, ok := usage.CachedInputTokens(); ok {
+			p.tokenUsage.AddInputTokens(cached)
+			// Accumulate any additional cache tokens from delta
+			p.tokenUsage.AddCachedInputTokens(cached)
+		}
+		// The stop reason is remembered for the message_stop event, which emits it.
+		if event.Delta.StopReason != "" {
+			p.stopReason = event.Delta.StopReason
+		}
+		return nil, nil
+
+	case string(constant.ValueOf[constant.ContentBlockDelta]()):
+		var event anthropic.ContentBlockDeltaEvent
+		if err := json.Unmarshal(data, &event); err != nil {
+			return nil, fmt.Errorf("unmarshal content_block_delta: %w", err)
+		}
+		switch event.Delta.Type {
+		case string(constant.ValueOf[constant.TextDelta]()), string(constant.ValueOf[constant.ThinkingDelta]()):
+			// Treat thinking_delta just like a text_delta.
+			// NOTE(review): the Anthropic API carries thinking_delta text in the
+			// delta's "thinking" field; confirm Delta.Text is populated for that
+			// type, otherwise this emits an empty content string.
+			delta := openai.ChatCompletionResponseChunkChoiceDelta{Content: &event.Delta.Text}
+			return p.constructOpenAIChatCompletionChunk(delta, ""), nil
+		case string(constant.ValueOf[constant.InputJSONDelta]()):
+			// The fragment is attributed to the most recently started tool call.
+			tool, ok := p.activeToolCalls[p.toolIndex]
+			if !ok {
+				return nil, fmt.Errorf("received input_json_delta for unknown tool at index %d", p.toolIndex)
+			}
+			delta := openai.ChatCompletionResponseChunkChoiceDelta{
+				ToolCalls: []openai.ChatCompletionChunkChoiceDeltaToolCall{
+					{
+						Index: p.toolIndex,
+						Function: openai.ChatCompletionMessageToolCallFunctionParam{
+							Arguments: event.Delta.PartialJSON,
+						},
+					},
+				},
+			}
+			tool.inputJSON += event.Delta.PartialJSON
+			return p.constructOpenAIChatCompletionChunk(delta, ""), nil
+		}
+		// Other delta types fall out of the switch and produce no chunk.
+
+	case string(constant.ValueOf[constant.ContentBlockStop]()):
+		// This event is for state cleanup, no chunk is sent.
+		var event anthropic.ContentBlockStopEvent
+		if err := json.Unmarshal(data, &event); err != nil {
+			return nil, fmt.Errorf("unmarshal content_block_stop: %w", err)
+		}
+		// NOTE(review): the entry removed is always the one at p.toolIndex,
+		// whichever block this stop event refers to; deleting a missing key is a no-op.
+		delete(p.activeToolCalls, p.toolIndex)
+		return nil, nil
+
+	case string(constant.ValueOf[constant.MessageStop]()):
+		var event anthropic.MessageStopEvent
+		if err := json.Unmarshal(data, &event); err != nil {
+			return nil, fmt.Errorf("unmarshal message_stop: %w", err)
+		}
+
+		// Without a stop reason from message_delta, fall back to end_turn.
+		if p.stopReason == "" {
+			p.stopReason = anthropic.StopReasonEndTurn
+		}
+
+		finishReason, err := anthropicToOpenAIFinishReason(p.stopReason)
+		if err != nil {
+			return nil, err
+		}
+		// The finish reason travels in a chunk with an empty delta.
+		return p.constructOpenAIChatCompletionChunk(openai.ChatCompletionResponseChunkChoiceDelta{}, finishReason), nil
+
+	case string(constant.ValueOf[constant.Error]()):
+		var errEvent anthropic.ErrorResponse
+		if err := json.Unmarshal(data, &errEvent); err != nil {
+			return nil, fmt.Errorf("unparsable error event: %s", string(data))
+		}
+		return nil, fmt.Errorf("anthropic stream error: %s - %s", errEvent.Error.Type, errEvent.Error.Message)
+
+	case "ping":
+		// Per documentation, ping events can be ignored.
+		return nil, nil
+	}
+	return nil, nil
+}
+
+// constructOpenAIChatCompletionChunk builds the stream chunk: a single choice
+// carrying delta and finishReason, stamped with the active message ID, the
+// creation time and the request model.
+func (p *anthropicStreamParser) constructOpenAIChatCompletionChunk(delta openai.ChatCompletionResponseChunkChoiceDelta, finishReason openai.ChatCompletionChoicesFinishReason) *openai.ChatCompletionResponseChunk {
+	// Add the 'assistant' role to the very first chunk of the response.
+	if !p.sentFirstChunk {
+		// Only add the role if the delta actually contains content or a tool call.
+		if delta.Content != nil || len(delta.ToolCalls) > 0 {
+			delta.Role = openai.ChatMessageRoleAssistant
+			p.sentFirstChunk = true
+		}
+	}
+
+	return &openai.ChatCompletionResponseChunk{
+		ID:      p.activeMessageID,
+		Created: p.created,
+		Object:  "chat.completion.chunk",
+		Choices: []openai.ChatCompletionResponseChunkChoice{
+			{
+				Delta:        &delta,
+				FinishReason: finishReason,
+			},
+		},
+		Model: p.requestModel,
+	}
+}
+
+// messageToChatCompletion translates a non-streaming Anthropic Message response
+// into an OpenAI ChatCompletion response with exactly one choice, and returns the
+// token usage extracted from the Anthropic response alongside it.
+//
+// An unsupported stop reason or role is reported as an error, in which case the
+// response is nil and the usage is the zero value.
+func messageToChatCompletion(anthropicResp *anthropic.Message, responseModel internalapi.RequestModel) (openAIResp *openai.ChatCompletionResponse, tokenUsage metrics.TokenUsage, err error) {
+	openAIResp = &openai.ChatCompletionResponse{
+		ID:      anthropicResp.ID,
+		Model:   responseModel,
+		Object:  string(openAIconstant.ValueOf[openAIconstant.ChatCompletion]()),
+		Choices: make([]openai.ChatCompletionResponseChoice, 0),
+		Created: openai.JSONUNIXTime(time.Now()),
+	}
+	usage := anthropicResp.Usage
+	tokenUsage = metrics.ExtractTokenUsageFromExplicitCaching(
+		usage.InputTokens,
+		usage.OutputTokens,
+		&usage.CacheReadInputTokens,
+		&usage.CacheCreationInputTokens,
+	)
+	inputTokens, _ := tokenUsage.InputTokens()
+	outputTokens, _ := tokenUsage.OutputTokens()
+	totalTokens, _ := tokenUsage.TotalTokens()
+	cachedTokens, _ := tokenUsage.CachedInputTokens()
+	cacheCreationTokens, _ := tokenUsage.CacheCreationInputTokens()
+	openAIResp.Usage = openai.Usage{
+		CompletionTokens: int(outputTokens),
+		PromptTokens:     int(inputTokens),
+		TotalTokens:      int(totalTokens),
+		PromptTokensDetails: &openai.PromptTokensDetails{
+			CachedTokens:        int(cachedTokens),
+			CacheCreationTokens: int(cacheCreationTokens),
+		},
+	}
+
+	finishReason, err := anthropicToOpenAIFinishReason(anthropicResp.StopReason)
+	if err != nil {
+		return nil, metrics.TokenUsage{}, err
+	}
+
+	role, err := anthropicRoleToOpenAIRole(anthropic.MessageParamRole(anthropicResp.Role))
+	if err != nil {
+		return nil, metrics.TokenUsage{}, err
+	}
+
+	choice := openai.ChatCompletionResponseChoice{
+		Index:        0,
+		Message:      openai.ChatCompletionResponseChoiceMessage{Role: role},
+		FinishReason: finishReason,
+	}
+
+	for i := range anthropicResp.Content { // NOTE: Content structure is massive, do not range over values.
+		output := &anthropicResp.Content[i]
+		switch output.Type {
+		case string(constant.ValueOf[constant.ToolUse]()):
+			// tool_use blocks without an ID are skipped.
+			if output.ID != "" {
+				toolCalls, toolErr := anthropicToolUseToOpenAICalls(output)
+				if toolErr != nil {
+					return nil, metrics.TokenUsage{}, fmt.Errorf("failed to convert anthropic tool use to openai tool call: %w", toolErr)
+				}
+				choice.Message.ToolCalls = append(choice.Message.ToolCalls, toolCalls...)
+			}
+		case string(constant.ValueOf[constant.Text]()):
+			// Only the first non-empty text block becomes the message content;
+			// later text blocks are ignored.
+			if output.Text != "" {
+				if choice.Message.Content == nil {
+					choice.Message.Content = &output.Text
+				}
+			}
+		case string(constant.ValueOf[constant.Thinking]()):
+			// A later thinking or redacted_thinking block overwrites ReasoningContent.
+			if output.Thinking != "" {
+				choice.Message.ReasoningContent = &openai.ReasoningContentUnion{
+					Value: &openai.ReasoningContent{
+						ReasoningContent: &awsbedrock.ReasoningContentBlock{
+							ReasoningText: &awsbedrock.ReasoningTextBlock{
+								Text:      output.Thinking,
+								Signature: output.Signature,
+							},
+						},
+					},
+				}
+			}
+		case string(constant.ValueOf[constant.RedactedThinking]()):
+			if output.Data != "" {
+				choice.Message.ReasoningContent = &openai.ReasoningContentUnion{
+					Value: &openai.ReasoningContent{
+						ReasoningContent: &awsbedrock.ReasoningContentBlock{
+							RedactedContent: []byte(output.Data),
+						},
+					},
+				}
+			}
+		}
+	}
+	openAIResp.Choices = append(openAIResp.Choices, choice)
+	return openAIResp, tokenUsage, nil
+}
diff --git a/internal/translator/anthropic_helper_test.go b/internal/translator/anthropic_helper_test.go
new file mode 100644
index 0000000000..f6d13ffad6
--- /dev/null
+++ b/internal/translator/anthropic_helper_test.go
@@ -0,0 +1,894 @@
+// Copyright Envoy AI Gateway Authors
+// SPDX-License-Identifier: Apache-2.0
+// The full text of the Apache license is available in the LICENSE file at
+// the root of the repo.
+
+package translator
+
+import (
+	"fmt"
+	"testing"
+
+	"github.com/anthropics/anthropic-sdk-go"
+	"github.com/anthropics/anthropic-sdk-go/shared/constant"
+	"github.com/stretchr/testify/require"
+	"k8s.io/utils/ptr"
+
+	"github.com/envoyproxy/ai-gateway/internal/apischema/openai"
+)
+
+// mockErrorReader is a helper for testing io.Reader failures.
+type mockErrorReader struct{}
+
+// Read always fails with a fixed error and reports zero bytes read.
+func (r *mockErrorReader) Read(_ []byte) (n int, err error) {
+	return 0, fmt.Errorf("mock reader error")
+}
+
+// TestHelperFunctions covers the error paths of the stop-reason and role conversion helpers.
+func TestHelperFunctions(t *testing.T) {
+	// An unknown stop reason must be rejected with a descriptive error.
+	t.Run("anthropicToOpenAIFinishReason invalid reason", func(t *testing.T) {
+		_, err := anthropicToOpenAIFinishReason("unknown_reason")
+		require.Error(t, err)
+		require.Contains(t, err.Error(), "received invalid stop reason")
+	})
+
+	// An unknown role must be rejected with a descriptive error.
+	t.Run("anthropicRoleToOpenAIRole invalid role", func(t *testing.T) {
+		_, err := anthropicRoleToOpenAIRole("unknown_role")
+		require.Error(t, err)
+		require.Contains(t, err.Error(), "invalid anthropic role")
+	})
+}
+
+// TestTranslateOpenAItoAnthropicTools is a table-driven test of
+// translateOpenAItoAnthropicTools covering tool-choice variants, the
+// parallel-tool-calls flag, tool definition shapes and the error paths.
+func TestTranslateOpenAItoAnthropicTools(t *testing.T) {
+	// Shared fixtures: one minimal function tool and its expected Anthropic form.
+	anthropicTestTool := []anthropic.ToolUnionParam{
+		{OfTool: &anthropic.ToolParam{Name: "get_weather", Description: anthropic.String("")}},
+	}
+	openaiTestTool := []openai.Tool{
+		{Type: "function", Function: &openai.FunctionDefinition{Name: "get_weather"}},
+	}
+	tests := []struct {
+		name               string
+		openAIReq          *openai.ChatCompletionRequest
+		expectedTools      []anthropic.ToolUnionParam
+		expectedToolChoice anthropic.ToolChoiceUnionParam
+		expectErr          bool
+	}{
+		{
+			name: "auto tool choice",
+			openAIReq: &openai.ChatCompletionRequest{
+				ToolChoice: &openai.ChatCompletionToolChoiceUnion{Value: "auto"},
+				Tools:      openaiTestTool,
+			},
+			expectedTools: anthropicTestTool,
+			expectedToolChoice: anthropic.ToolChoiceUnionParam{
+				OfAuto: &anthropic.ToolChoiceAutoParam{
+					DisableParallelToolUse: anthropic.Bool(false),
+				},
+			},
+		},
+		{
+			name: "any tool choice",
+			openAIReq: &openai.ChatCompletionRequest{
+				ToolChoice: &openai.ChatCompletionToolChoiceUnion{Value: "any"},
+				Tools:      openaiTestTool,
+			},
+			expectedTools: anthropicTestTool,
+			expectedToolChoice: anthropic.ToolChoiceUnionParam{
+				OfAny: &anthropic.ToolChoiceAnyParam{},
+			},
+		},
+		{
+			name: "specific tool choice by name",
+			openAIReq: &openai.ChatCompletionRequest{
+				ToolChoice: &openai.ChatCompletionToolChoiceUnion{Value: openai.ChatCompletionNamedToolChoice{Type: "function", Function: openai.ChatCompletionNamedToolChoiceFunction{Name: "my_func"}}},
+				Tools:      openaiTestTool,
+			},
+			expectedTools: anthropicTestTool,
+			expectedToolChoice: anthropic.ToolChoiceUnionParam{
+				OfTool: &anthropic.ToolChoiceToolParam{Type: "tool", Name: "my_func"},
+			},
+		},
+		{
+			name: "tool definition",
+			openAIReq: &openai.ChatCompletionRequest{
+				Tools: []openai.Tool{
+					{
+						Type: "function",
+						Function: &openai.FunctionDefinition{
+							Name:        "get_weather",
+							Description: "Get the weather",
+							Parameters: map[string]any{
+								"type": "object",
+								"properties": map[string]any{
+									"location": map[string]any{"type": "string"},
+								},
+							},
+						},
+					},
+				},
+			},
+			expectedTools: []anthropic.ToolUnionParam{
+				{
+					OfTool: &anthropic.ToolParam{
+						Name:        "get_weather",
+						Description: anthropic.String("Get the weather"),
+						InputSchema: anthropic.ToolInputSchemaParam{
+							Type: "object",
+							Properties: map[string]any{
+								"location": map[string]any{"type": "string"},
+							},
+						},
+					},
+				},
+			},
+		},
+		{
+			name: "tool_definition_with_required_field",
+			openAIReq: &openai.ChatCompletionRequest{
+				Tools: []openai.Tool{
+					{
+						Type: "function",
+						Function: &openai.FunctionDefinition{
+							Name:        "get_weather",
+							Description: "Get the weather with a required location",
+							Parameters: map[string]any{
+								"type": "object",
+								"properties": map[string]any{
+									"location": map[string]any{"type": "string"},
+									"unit":     map[string]any{"type": "string"},
+								},
+								"required": []any{"location"},
+							},
+						},
+					},
+				},
+			},
+			expectedTools: []anthropic.ToolUnionParam{
+				{
+					OfTool: &anthropic.ToolParam{
+						Name:        "get_weather",
+						Description: anthropic.String("Get the weather with a required location"),
+						InputSchema: anthropic.ToolInputSchemaParam{
+							Type: "object",
+							Properties: map[string]any{
+								"location": map[string]any{"type": "string"},
+								"unit":     map[string]any{"type": "string"},
+							},
+							Required: []string{"location"},
+						},
+					},
+				},
+			},
+		},
+		{
+			name: "tool definition with no parameters",
+			openAIReq: &openai.ChatCompletionRequest{
+				Tools: []openai.Tool{
+					{
+						Type: "function",
+						Function: &openai.FunctionDefinition{
+							Name:        "get_time",
+							Description: "Get the current time",
+						},
+					},
+				},
+			},
+			expectedTools: []anthropic.ToolUnionParam{
+				{
+					OfTool: &anthropic.ToolParam{
+						Name:        "get_time",
+						Description: anthropic.String("Get the current time"),
+					},
+				},
+			},
+		},
+		{
+			// parallel_tool_calls=false is expected to invert into DisableParallelToolUse=true.
+			name: "disable parallel tool calls",
+			openAIReq: &openai.ChatCompletionRequest{
+				ToolChoice:        &openai.ChatCompletionToolChoiceUnion{Value: "auto"},
+				Tools:             openaiTestTool,
+				ParallelToolCalls: ptr.To(false),
+			},
+			expectedTools: anthropicTestTool,
+			expectedToolChoice: anthropic.ToolChoiceUnionParam{
+				OfAuto: &anthropic.ToolChoiceAutoParam{
+					DisableParallelToolUse: anthropic.Bool(true),
+				},
+			},
+		},
+		{
+			name: "explicitly enable parallel tool calls",
+			openAIReq: &openai.ChatCompletionRequest{
+				Tools:             openaiTestTool,
+				ToolChoice:        &openai.ChatCompletionToolChoiceUnion{Value: "auto"},
+				ParallelToolCalls: ptr.To(true),
+			},
+			expectedTools: anthropicTestTool,
+			expectedToolChoice: anthropic.ToolChoiceUnionParam{
+				OfAuto: &anthropic.ToolChoiceAutoParam{DisableParallelToolUse: anthropic.Bool(false)},
+			},
+		},
+		{
+			name: "default disable parallel tool calls to false (nil)",
+			openAIReq: &openai.ChatCompletionRequest{
+				Tools:      openaiTestTool,
+				ToolChoice: &openai.ChatCompletionToolChoiceUnion{Value: "auto"},
+			},
+			expectedTools: anthropicTestTool,
+			expectedToolChoice: anthropic.ToolChoiceUnionParam{
+				OfAuto: &anthropic.ToolChoiceAutoParam{DisableParallelToolUse: anthropic.Bool(false)},
+			},
+		},
+		{
+			name: "none tool choice",
+			openAIReq: &openai.ChatCompletionRequest{
+				Tools:      openaiTestTool,
+				ToolChoice: &openai.ChatCompletionToolChoiceUnion{Value: "none"},
+			},
+			expectedTools: anthropicTestTool,
+			expectedToolChoice: anthropic.ToolChoiceUnionParam{
+				OfNone: &anthropic.ToolChoiceNoneParam{},
+			},
+		},
+		{
+			name: "function tool choice",
+			openAIReq: &openai.ChatCompletionRequest{
+				Tools:      openaiTestTool,
+				ToolChoice: &openai.ChatCompletionToolChoiceUnion{Value: "function"},
+			},
+			expectedTools: anthropicTestTool,
+			expectedToolChoice: anthropic.ToolChoiceUnionParam{
+				OfTool: &anthropic.ToolChoiceToolParam{Name: "function"},
+			},
+		},
+		{
+			name: "invalid tool choice string",
+			openAIReq: &openai.ChatCompletionRequest{
+				Tools:      openaiTestTool,
+				ToolChoice: &openai.ChatCompletionToolChoiceUnion{Value: "invalid_choice"},
+			},
+			expectErr: true,
+		},
+		{
+			name: "skips function tool with nil function definition",
+			openAIReq: &openai.ChatCompletionRequest{
+				Tools: []openai.Tool{
+					{
+						Type:     "function",
+						Function: nil, // This tool has the correct type but a nil definition and should be skipped.
+					},
+					{
+						Type:     "function",
+						Function: &openai.FunctionDefinition{Name: "get_weather"}, // This is a valid tool.
+					},
+				},
+			},
+			// We expect only the valid function tool to be translated.
+			expectedTools: []anthropic.ToolUnionParam{
+				{OfTool: &anthropic.ToolParam{Name: "get_weather", Description: anthropic.String("")}},
+			},
+			expectErr: false,
+		},
+		{
+			name: "skips non-function tools",
+			openAIReq: &openai.ChatCompletionRequest{
+				Tools: []openai.Tool{
+					{
+						Type: "retrieval",
+					},
+					{
+						Type:     "function",
+						Function: &openai.FunctionDefinition{Name: "get_weather"},
+					},
+				},
+			},
+			expectedTools: []anthropic.ToolUnionParam{
+				{OfTool: &anthropic.ToolParam{Name: "get_weather", Description: anthropic.String("")}},
+			},
+			expectErr: false,
+		},
+		{
+			name: "tool definition without type field",
+			openAIReq: &openai.ChatCompletionRequest{
+				Tools: []openai.Tool{
+					{
+						Type: "function",
+						Function: &openai.FunctionDefinition{
+							Name:        "get_weather",
+							Description: "Get the weather without type",
+							Parameters: map[string]any{
+								"properties": map[string]any{
+									"location": map[string]any{"type": "string"},
+								},
+								"required": []any{"location"},
+							},
+						},
+					},
+				},
+			},
+			expectedTools: []anthropic.ToolUnionParam{
+				{
+					OfTool: &anthropic.ToolParam{
+						Name:        "get_weather",
+						Description: anthropic.String("Get the weather without type"),
+						InputSchema: anthropic.ToolInputSchemaParam{
+							Type: "",
+							Properties: map[string]any{
+								"location": map[string]any{"type": "string"},
+							},
+							Required: []string{"location"},
+						},
+					},
+				},
+			},
+		},
+		{
+			name: "tool definition without properties field",
+			openAIReq: &openai.ChatCompletionRequest{
+				Tools: []openai.Tool{
+					{
+						Type: "function",
+						Function: &openai.FunctionDefinition{
+							Name:        "get_weather",
+							Description: "Get the weather without properties",
+							Parameters: map[string]any{
+								"type":     "object",
+								"required": []any{"location"},
+							},
+						},
+					},
+				},
+			},
+			expectedTools: []anthropic.ToolUnionParam{
+				{
+					OfTool: &anthropic.ToolParam{
+						Name:        "get_weather",
+						Description: anthropic.String("Get the weather without properties"),
+						InputSchema: anthropic.ToolInputSchemaParam{
+							Type:     "object",
+							Required: []string{"location"},
+						},
+					},
+				},
+			},
+		},
+		{
+			name: "unsupported tool_choice type",
+			openAIReq: &openai.ChatCompletionRequest{
+				Tools:      openaiTestTool,
+				ToolChoice: &openai.ChatCompletionToolChoiceUnion{Value: 123}, // Use an integer to trigger the default case.
+			},
+			expectErr: true,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			tools, toolChoice, err := translateOpenAItoAnthropicTools(tt.openAIReq.Tools, tt.openAIReq.ToolChoice, tt.openAIReq.ParallelToolCalls)
+			if tt.expectErr {
+				require.Error(t, err)
+			} else {
+				require.NoError(t, err)
+				// Tool choice is only checked for cases that set one in the request.
+				if tt.openAIReq.ToolChoice != nil {
+					require.NotNil(t, toolChoice)
+					require.Equal(t, *tt.expectedToolChoice.GetType(), *toolChoice.GetType())
+					if tt.expectedToolChoice.GetName() != nil {
+						require.Equal(t, *tt.expectedToolChoice.GetName(), *toolChoice.GetName())
+					}
+					if tt.expectedToolChoice.OfTool != nil {
+						require.Equal(t, tt.expectedToolChoice.OfTool.Name, toolChoice.OfTool.Name)
+					}
+					if tt.expectedToolChoice.OfAuto != nil {
+						require.Equal(t, tt.expectedToolChoice.OfAuto.DisableParallelToolUse, toolChoice.OfAuto.DisableParallelToolUse)
+					}
+				}
+				// NOTE(review): only the first tool is compared field by field, and the
+				// input schema's Type and Required are not asserted in this test.
+				if tt.openAIReq.Tools != nil {
+					require.NotNil(t, tools)
+					require.Len(t, tools, len(tt.expectedTools))
+					require.Equal(t, tt.expectedTools[0].GetName(), tools[0].GetName())
+					require.Equal(t, tt.expectedTools[0].GetType(), tools[0].GetType())
+					require.Equal(t, tt.expectedTools[0].GetDescription(), tools[0].GetDescription())
+					if tt.expectedTools[0].GetInputSchema().Properties != nil {
+						require.Equal(t, tt.expectedTools[0].GetInputSchema().Properties, tools[0].GetInputSchema().Properties)
+					}
+				}
+			}
+		})
+	}
+}
+
+// TestFinishReasonTranslation covers specific cases for the anthropicToOpenAIFinishReason function.
+func TestFinishReasonTranslation(t *testing.T) {
+	// NOTE(review): no case in this table sets expectErr, so the error branch of
+	// the loop below is never exercised here.
+	tests := []struct {
+		name                 string
+		input                anthropic.StopReason
+		expectedFinishReason openai.ChatCompletionChoicesFinishReason
+		expectErr            bool
+	}{
+		{
+			name:                 "max tokens stop reason",
+			input:                anthropic.StopReasonMaxTokens,
+			expectedFinishReason: openai.ChatCompletionChoicesFinishReasonLength,
+		},
+		{
+			name:                 "refusal stop reason",
+			input:                anthropic.StopReasonRefusal,
+			expectedFinishReason: openai.ChatCompletionChoicesFinishReasonContentFilter,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			reason, err := anthropicToOpenAIFinishReason(tt.input)
+			if tt.expectErr {
+				require.Error(t, err)
+			} else {
+				require.NoError(t, err)
+				require.Equal(t, tt.expectedFinishReason, reason)
+			}
+		})
+	}
+}
+
+// TestToolParameterDereferencing tests the JSON schema dereferencing functionality
+// for tool parameters when translating from OpenAI to GCP Anthropic.
+//
+// Successful cases compare the translated input schema (type, required and
+// properties) against the expectation; failing cases only check that the error
+// message contains expectedErrMsg.
+func TestToolParameterDereferencing(t *testing.T) {
+	tests := []struct {
+		name               string
+		openAIReq          *openai.ChatCompletionRequest
+		expectedTools      []anthropic.ToolUnionParam
+		expectedToolChoice anthropic.ToolChoiceUnionParam
+		expectErr          bool
+		expectedErrMsg     string
+	}{
+		{
+			name: "tool with complex nested $ref - successful dereferencing",
+			openAIReq: &openai.ChatCompletionRequest{
+				Tools: []openai.Tool{
+					{
+						Type: "function",
+						Function: &openai.FunctionDefinition{
+							Name:        "complex_tool",
+							Description: "Tool with complex nested references",
+							Parameters: map[string]any{
+								"type": "object",
+								"$defs": map[string]any{
+									"BaseType": map[string]any{
+										"type": "object",
+										"properties": map[string]any{
+											"id": map[string]any{
+												"type": "string",
+											},
+											// NOTE(review): "required" sits inside "properties" here (and in the
+											// expectation below), so it is a property named "required" rather
+											// than the schema keyword; confirm this is intentional.
+											"required": []any{"id"},
+										},
+									},
+									"NestedType": map[string]any{
+										"allOf": []any{
+											map[string]any{"$ref": "#/$defs/BaseType"},
+											map[string]any{
+												"properties": map[string]any{
+													"name": map[string]any{
+														"type": "string",
+													},
+												},
+											},
+										},
+									},
+								},
+								"properties": map[string]any{
+									"nested": map[string]any{
+										"$ref": "#/$defs/NestedType",
+									},
+								},
+							},
+						},
+					},
+				},
+			},
+			// Both levels of $ref are expected to be inlined in the translated properties.
+			expectedTools: []anthropic.ToolUnionParam{
+				{
+					OfTool: &anthropic.ToolParam{
+						Name:        "complex_tool",
+						Description: anthropic.String("Tool with complex nested references"),
+						InputSchema: anthropic.ToolInputSchemaParam{
+							Type: "object",
+							Properties: map[string]any{
+								"nested": map[string]any{
+									"allOf": []any{
+										map[string]any{
+											"type": "object",
+											"properties": map[string]any{
+												"id": map[string]any{
+													"type": "string",
+												},
+												"required": []any{"id"},
+											},
+										},
+										map[string]any{
+											"properties": map[string]any{
+												"name": map[string]any{
+													"type": "string",
+												},
+											},
+										},
+									},
+								},
+							},
+						},
+					},
+				},
+			},
+		},
+		{
+			// The referenced definition does not exist anywhere in the schema.
+			name: "tool with invalid $ref - dereferencing error",
+			openAIReq: &openai.ChatCompletionRequest{
+				Tools: []openai.Tool{
+					{
+						Type: "function",
+						Function: &openai.FunctionDefinition{
+							Name:        "invalid_ref_tool",
+							Description: "Tool with invalid reference",
+							Parameters: map[string]any{
+								"type": "object",
+								"properties": map[string]any{
+									"location": map[string]any{
+										"$ref": "#/$defs/NonExistent",
+									},
+								},
+							},
+						},
+					},
+				},
+			},
+			expectErr:      true,
+			expectedErrMsg: "failed to dereference tool parameters",
+		},
+		{
+			// A references B and B references A, so the schema cannot be fully inlined.
+			name: "tool with circular $ref - dereferencing error",
+			openAIReq: &openai.ChatCompletionRequest{
+				Tools: []openai.Tool{
+					{
+						Type: "function",
+						Function: &openai.FunctionDefinition{
+							Name:        "circular_ref_tool",
+							Description: "Tool with circular reference",
+							Parameters: map[string]any{
+								"type": "object",
+								"$defs": map[string]any{
+									"A": map[string]any{
+										"type": "object",
+										"properties": map[string]any{
+											"b": map[string]any{
+												"$ref": "#/$defs/B",
+											},
+										},
+									},
+									"B": map[string]any{
+										"type": "object",
+										"properties": map[string]any{
+											"a": map[string]any{
+												"$ref": "#/$defs/A",
+											},
+										},
+									},
+								},
+								"properties": map[string]any{
+									"circular": map[string]any{
+										"$ref": "#/$defs/A",
+									},
+								},
+							},
+						},
+					},
+				},
+			},
+			expectErr:      true,
+			expectedErrMsg: "failed to dereference tool parameters",
+		},
+		{
+			name: "tool without $ref - no dereferencing needed",
+			openAIReq: &openai.ChatCompletionRequest{
+				Tools: []openai.Tool{
+					{
+						Type: "function",
+						Function: &openai.FunctionDefinition{
+							Name:        "simple_tool",
+							Description: "Simple tool without references",
+							Parameters: map[string]any{
+								"type": "object",
+								"properties": map[string]any{
+									"location": map[string]any{
+										"type": "string",
+									},
+								},
+								"required": []any{"location"},
+							},
+						},
+					},
+				},
+			},
+			expectedTools: []anthropic.ToolUnionParam{
+				{
+					OfTool: &anthropic.ToolParam{
+						Name:        "simple_tool",
+						Description: anthropic.String("Simple tool without references"),
+						InputSchema: anthropic.ToolInputSchemaParam{
+							Type: "object",
+							Properties: map[string]any{
+								"location": map[string]any{
+									"type": "string",
+								},
+							},
+							Required: []string{"location"},
+						},
+					},
+				},
+			},
+		},
+		{
+			name: "tool parameter dereferencing returns non-map type - casting error",
+			openAIReq: &openai.ChatCompletionRequest{
+				Tools: []openai.Tool{
+					{
+						Type: "function",
+						Function: &openai.FunctionDefinition{
+							Name:        "problematic_tool",
+							Description: "Tool with parameters that can't be properly dereferenced to map",
+							// This creates a scenario where jsonSchemaDereference might return a non-map type
+							// though this is a contrived example since normally the function should return map[string]any
+							Parameters: map[string]any{
+								"$ref": "#/$defs/StringType", // This would resolve to a string, not a map
+								"$defs": map[string]any{
+									"StringType": "not-a-map", // This would cause the casting to fail
+								},
+							},
+						},
+					},
+				},
+			},
+			expectErr:      true,
+			expectedErrMsg: "failed to cast dereferenced tool parameters",
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			tools, toolChoice, err := translateOpenAItoAnthropicTools(tt.openAIReq.Tools, tt.openAIReq.ToolChoice, tt.openAIReq.ParallelToolCalls)
+
+			// Failing cases stop here once the error message has been checked.
+			if tt.expectErr {
+				require.Error(t, err)
+				if tt.expectedErrMsg != "" {
+					require.Contains(t, err.Error(), tt.expectedErrMsg)
+				}
+				return
+			}
+
+			require.NoError(t, err)
+
+			if tt.openAIReq.Tools != nil {
+				require.NotNil(t, tools)
+				require.Len(t, tools, len(tt.expectedTools))
+
+				for i, expectedTool := range tt.expectedTools {
+					actualTool := tools[i]
+					require.Equal(t, expectedTool.GetName(), actualTool.GetName())
+					require.Equal(t, expectedTool.GetType(), actualTool.GetType())
+					require.Equal(t, expectedTool.GetDescription(), actualTool.GetDescription())
+
+					expectedSchema := expectedTool.GetInputSchema()
+					actualSchema := actualTool.GetInputSchema()
+
+					require.Equal(t, expectedSchema.Type, actualSchema.Type)
+					require.Equal(t, expectedSchema.Required, actualSchema.Required)
+
+					// For properties, we'll do a deep comparison to verify dereferencing worked
+					if expectedSchema.Properties != nil {
+						require.NotNil(t, actualSchema.Properties)
+						require.Equal(t, expectedSchema.Properties, actualSchema.Properties)
+					}
+				}
+			}
+
+			// NOTE(review): no case in the table above sets ToolChoice, so this
+			// assertion is currently never reached.
+			if tt.openAIReq.ToolChoice != nil {
+				require.NotNil(t, toolChoice)
+				require.Equal(t, *tt.expectedToolChoice.GetType(), *toolChoice.GetType())
+			}
+		})
+	}
+}
+
+// TestContentTranslationCoverage adds specific coverage for the openAIToAnthropicContent helper.
+func TestContentTranslationCoverage(t *testing.T) { + tests := []struct { + name string + inputContent any + expectedContent []anthropic.ContentBlockParamUnion + expectErr bool + }{ + { + name: "nil content", + inputContent: nil, + }, + { + name: "empty string content", + inputContent: "", + }, + { + name: "pdf data uri", + inputContent: []openai.ChatCompletionContentPartUserUnionParam{ + {OfImageURL: &openai.ChatCompletionContentPartImageParam{ImageURL: openai.ChatCompletionContentPartImageImageURLParam{URL: "data:application/pdf;base64,dGVzdA=="}}}, + }, + expectedContent: []anthropic.ContentBlockParamUnion{ + { + OfDocument: &anthropic.DocumentBlockParam{ + Source: anthropic.DocumentBlockParamSourceUnion{ + OfBase64: &anthropic.Base64PDFSourceParam{ + Type: constant.ValueOf[constant.Base64](), + MediaType: constant.ValueOf[constant.ApplicationPDF](), + Data: "dGVzdA==", + }, + }, + }, + }, + }, + }, + { + name: "pdf url", + inputContent: []openai.ChatCompletionContentPartUserUnionParam{ + {OfImageURL: &openai.ChatCompletionContentPartImageParam{ImageURL: openai.ChatCompletionContentPartImageImageURLParam{URL: "https://example.com/doc.pdf"}}}, + }, + expectedContent: []anthropic.ContentBlockParamUnion{ + { + OfDocument: &anthropic.DocumentBlockParam{ + Source: anthropic.DocumentBlockParamSourceUnion{ + OfURL: &anthropic.URLPDFSourceParam{ + Type: constant.ValueOf[constant.URL](), + URL: "https://example.com/doc.pdf", + }, + }, + }, + }, + }, + }, + { + name: "image url", + inputContent: []openai.ChatCompletionContentPartUserUnionParam{ + {OfImageURL: &openai.ChatCompletionContentPartImageParam{ImageURL: openai.ChatCompletionContentPartImageImageURLParam{URL: "https://example.com/image.png"}}}, + }, + expectedContent: []anthropic.ContentBlockParamUnion{ + { + OfImage: &anthropic.ImageBlockParam{ + Source: anthropic.ImageBlockParamSourceUnion{ + OfURL: &anthropic.URLImageSourceParam{ + Type: constant.ValueOf[constant.URL](), + URL: "https://example.com/image.png", 
+ }, + }, + }, + }, + }, + }, + { + name: "audio content error", + inputContent: []openai.ChatCompletionContentPartUserUnionParam{{OfInputAudio: &openai.ChatCompletionContentPartInputAudioParam{}}}, + expectErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + content, err := openAIToAnthropicContent(tt.inputContent) + if tt.expectErr { + require.Error(t, err) + return + } + require.NoError(t, err) + + // Use direct assertions instead of cmp.Diff to avoid panics on unexported fields. + require.Len(t, content, len(tt.expectedContent), "Number of content blocks should match") + + // Use direct assertions instead of cmp.Diff to avoid panics on unexported fields. + require.Len(t, content, len(tt.expectedContent), "Number of content blocks should match") + for i, expectedBlock := range tt.expectedContent { + actualBlock := content[i] + require.Equal(t, expectedBlock.GetType(), actualBlock.GetType(), "Content block types should match") + if expectedBlock.OfDocument != nil { + require.NotNil(t, actualBlock.OfDocument, "Expected a document block, but got nil") + require.NotNil(t, actualBlock.OfDocument.Source, "Document source should not be nil") + + if expectedBlock.OfDocument.Source.OfBase64 != nil { + require.NotNil(t, actualBlock.OfDocument.Source.OfBase64, "Expected a base64 source") + require.Equal(t, expectedBlock.OfDocument.Source.OfBase64.Data, actualBlock.OfDocument.Source.OfBase64.Data) + } + if expectedBlock.OfDocument.Source.OfURL != nil { + require.NotNil(t, actualBlock.OfDocument.Source.OfURL, "Expected a URL source") + require.Equal(t, expectedBlock.OfDocument.Source.OfURL.URL, actualBlock.OfDocument.Source.OfURL.URL) + } + } + if expectedBlock.OfImage != nil { + require.NotNil(t, actualBlock.OfImage, "Expected an image block, but got nil") + require.NotNil(t, actualBlock.OfImage.Source, "Image source should not be nil") + + if expectedBlock.OfImage.Source.OfURL != nil { + require.NotNil(t, 
actualBlock.OfImage.Source.OfURL, "Expected a URL image source") + require.Equal(t, expectedBlock.OfImage.Source.OfURL.URL, actualBlock.OfImage.Source.OfURL.URL) + } + } + } + + for i, expectedBlock := range tt.expectedContent { + actualBlock := content[i] + if expectedBlock.OfDocument != nil { + require.NotNil(t, actualBlock.OfDocument, "Expected a document block, but got nil") + require.NotNil(t, actualBlock.OfDocument.Source, "Document source should not be nil") + + if expectedBlock.OfDocument.Source.OfBase64 != nil { + require.NotNil(t, actualBlock.OfDocument.Source.OfBase64, "Expected a base64 source") + require.Equal(t, expectedBlock.OfDocument.Source.OfBase64.Data, actualBlock.OfDocument.Source.OfBase64.Data) + } + if expectedBlock.OfDocument.Source.OfURL != nil { + require.NotNil(t, actualBlock.OfDocument.Source.OfURL, "Expected a URL source") + require.Equal(t, expectedBlock.OfDocument.Source.OfURL.URL, actualBlock.OfDocument.Source.OfURL.URL) + } + } + if expectedBlock.OfImage != nil { + require.NotNil(t, actualBlock.OfImage, "Expected an image block, but got nil") + require.NotNil(t, actualBlock.OfImage.Source, "Image source should not be nil") + + if expectedBlock.OfImage.Source.OfURL != nil { + require.NotNil(t, actualBlock.OfImage.Source.OfURL, "Expected a URL image source") + require.Equal(t, expectedBlock.OfImage.Source.OfURL.URL, actualBlock.OfImage.Source.OfURL.URL) + } + } + } + }) + } +} + +// TestSystemPromptExtractionCoverage adds specific coverage for the extractSystemPromptFromDeveloperMsg helper. 
+func TestSystemPromptExtractionCoverage(t *testing.T) { + tests := []struct { + name string + inputMsg openai.ChatCompletionDeveloperMessageParam + expectedPrompt string + }{ + { + name: "developer message with content parts", + inputMsg: openai.ChatCompletionDeveloperMessageParam{ + Content: openai.ContentUnion{Value: []openai.ChatCompletionContentPartTextParam{ + {Type: "text", Text: "part 1"}, + {Type: "text", Text: " part 2"}, + }}, + }, + expectedPrompt: "part 1 part 2", + }, + { + name: "developer message with nil content", + inputMsg: openai.ChatCompletionDeveloperMessageParam{Content: openai.ContentUnion{Value: nil}}, + expectedPrompt: "", + }, + { + name: "developer message with string content", + inputMsg: openai.ChatCompletionDeveloperMessageParam{ + Content: openai.ContentUnion{Value: "simple string"}, + }, + expectedPrompt: "simple string", + }, + { + name: "developer message with text parts array", + inputMsg: openai.ChatCompletionDeveloperMessageParam{ + Content: openai.ContentUnion{Value: []openai.ChatCompletionContentPartTextParam{ + {Type: "text", Text: "text part"}, + }}, + }, + expectedPrompt: "text part", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + prompt, _ := extractSystemPromptFromDeveloperMsg(tt.inputMsg) + require.Equal(t, tt.expectedPrompt, prompt) + }) + } +} diff --git a/internal/translator/openai_awsanthropic.go b/internal/translator/openai_awsanthropic.go new file mode 100644 index 0000000000..3da7e8ab69 --- /dev/null +++ b/internal/translator/openai_awsanthropic.go @@ -0,0 +1,261 @@ +// Copyright Envoy AI Gateway Authors +// SPDX-License-Identifier: Apache-2.0 +// The full text of the Apache license is available in the LICENSE file at +// the root of the repo. 
+ +package translator + +import ( + "bytes" + "encoding/base64" + "fmt" + "io" + "net/url" + "strconv" + "strings" + + "github.com/anthropics/anthropic-sdk-go" + "github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" + + "github.com/envoyproxy/ai-gateway/internal/apischema/awsbedrock" + "github.com/envoyproxy/ai-gateway/internal/apischema/openai" + "github.com/envoyproxy/ai-gateway/internal/internalapi" + "github.com/envoyproxy/ai-gateway/internal/json" + "github.com/envoyproxy/ai-gateway/internal/metrics" + "github.com/envoyproxy/ai-gateway/internal/tracing/tracingapi" +) + +// https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html +const BedrockDefaultVersion = "bedrock-2023-05-31" + +// NewChatCompletionOpenAIToAWSAnthropicTranslator implements [Factory] for OpenAI to AWS Anthropic translation. +// This translator converts OpenAI ChatCompletion API requests to AWS Anthropic API format. +func NewChatCompletionOpenAIToAWSAnthropicTranslator(apiVersion string, modelNameOverride internalapi.ModelNameOverride) OpenAIChatCompletionTranslator { + return &openAIToAWSAnthropicTranslatorV1ChatCompletion{ + apiVersion: apiVersion, + modelNameOverride: modelNameOverride, + } +} + +// openAIToAWSAnthropicTranslatorV1ChatCompletion translates OpenAI Chat Completions API to AWS Anthropic Claude API. +// This uses the AWS Bedrock InvokeModel and InvokeModelWithResponseStream APIs: +// https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html +type openAIToAWSAnthropicTranslatorV1ChatCompletion struct { + apiVersion string + modelNameOverride internalapi.ModelNameOverride + streamParser *anthropicStreamParser + requestModel internalapi.RequestModel + bufferedBody []byte +} + +// RequestBody implements [OpenAIChatCompletionTranslator.RequestBody] for AWS Anthropic. 
+func (o *openAIToAWSAnthropicTranslatorV1ChatCompletion) RequestBody(_ []byte, openAIReq *openai.ChatCompletionRequest, _ bool) ( + newHeaders []internalapi.Header, newBody []byte, err error, +) { + o.requestModel = openAIReq.Model + if o.modelNameOverride != "" { + o.requestModel = o.modelNameOverride + } + + // URL encode the model name for the path to handle special characters (e.g., ARNs) + encodedModelName := url.PathEscape(o.requestModel) + + // Set the path for AWS Bedrock InvokeModel API + // https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_InvokeModel.html#API_runtime_InvokeModel_RequestSyntax + pathTemplate := "/model/%s/invoke" + if openAIReq.Stream { + pathTemplate = "/model/%s/invoke-with-response-stream" + o.streamParser = newAnthropicStreamParser(o.requestModel) + } + + params, err := buildAnthropicParams(openAIReq) + if err != nil { + return + } + + body, err := json.Marshal(params) + if err != nil { + return + } + + // b. Set the "anthropic_version" key in the JSON body + // Using same logic as anthropic go SDK: https://github.com/anthropics/anthropic-sdk-go/blob/e252e284244755b2b2f6eef292b09d6d1e6cd989/bedrock/bedrock.go#L167 + anthropicVersion := BedrockDefaultVersion + if o.apiVersion != "" { + anthropicVersion = o.apiVersion + } + body, err = sjson.SetBytes(body, anthropicVersionKey, anthropicVersion) + if err != nil { + return + } + newBody = body + + newHeaders = []internalapi.Header{ + {pathHeaderName, fmt.Sprintf(pathTemplate, encodedModelName)}, + {contentLengthHeaderName, strconv.Itoa(len(newBody))}, + } + return +} + +// ResponseError implements [OpenAIChatCompletionTranslator.ResponseError]. +// Translate AWS Bedrock exceptions to OpenAI error type. +// The error type is stored in the "x-amzn-errortype" HTTP header for AWS error responses. 
+func (o *openAIToAWSAnthropicTranslatorV1ChatCompletion) ResponseError(respHeaders map[string]string, body io.Reader) ( + newHeaders []internalapi.Header, newBody []byte, err error, +) { + statusCode := respHeaders[statusHeaderName] + var openaiError openai.Error + if v, ok := respHeaders[contentTypeHeaderName]; ok && strings.Contains(v, jsonContentType) { + var bedrockError awsbedrock.BedrockException + if err = json.NewDecoder(body).Decode(&bedrockError); err != nil { + return nil, nil, fmt.Errorf("failed to unmarshal error body: %w", err) + } + openaiError = openai.Error{ + Type: "error", + Error: openai.ErrorType{ + Type: respHeaders[awsErrorTypeHeaderName], + Message: bedrockError.Message, + Code: &statusCode, + }, + } + } else { + var buf []byte + buf, err = io.ReadAll(body) + if err != nil { + return nil, nil, fmt.Errorf("failed to read error body: %w", err) + } + openaiError = openai.Error{ + Type: "error", + Error: openai.ErrorType{ + Type: awsBedrockBackendError, + Message: string(buf), + Code: &statusCode, + }, + } + } + newBody, err = json.Marshal(openaiError) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal error body: %w", err) + } + newHeaders = []internalapi.Header{ + {contentTypeHeaderName, jsonContentType}, + {contentLengthHeaderName, strconv.Itoa(len(newBody))}, + } + return +} + +// ResponseHeaders implements [OpenAIChatCompletionTranslator.ResponseHeaders]. +func (o *openAIToAWSAnthropicTranslatorV1ChatCompletion) ResponseHeaders(_ map[string]string) ( + newHeaders []internalapi.Header, err error, +) { + if o.streamParser != nil { + newHeaders = []internalapi.Header{{contentTypeHeaderName, eventStreamContentType}} + } + return +} + +// ResponseBody implements [OpenAIChatCompletionTranslator.ResponseBody] for AWS Anthropic. +// AWS Anthropic uses deterministic model mapping without virtualization, where the requested model +// is exactly what gets executed. 
The response does not contain a model field, so we return +// the request model that was originally sent. +func (o *openAIToAWSAnthropicTranslatorV1ChatCompletion) ResponseBody(_ map[string]string, body io.Reader, endOfStream bool, span tracingapi.ChatCompletionSpan) ( + newHeaders []internalapi.Header, newBody []byte, tokenUsage metrics.TokenUsage, responseModel string, err error, +) { + // If a stream parser was initialized, this is a streaming request. + if o.streamParser != nil { + // AWS Bedrock wraps Anthropic events in EventStream binary format + // We need to decode EventStream and extract the SSE payload + buf, readErr := io.ReadAll(body) + if readErr != nil { + return nil, nil, metrics.TokenUsage{}, "", fmt.Errorf("failed to read stream body: %w", readErr) + } + + // Buffer the data for EventStream decoding + o.bufferedBody = append(o.bufferedBody, buf...) + + // Extract Anthropic SSE from AWS EventStream wrapper + // This decodes the base64-encoded events and formats them as SSE + anthropicSSE := o.extractAnthropicSSEFromEventStream() + + // Pass the extracted SSE to the Anthropic parser + return o.streamParser.Process(bytes.NewReader(anthropicSSE), endOfStream, span) + } + + var anthropicResp anthropic.Message + if err = json.NewDecoder(body).Decode(&anthropicResp); err != nil { + return nil, nil, tokenUsage, "", fmt.Errorf("failed to unmarshal body: %w", err) + } + + responseModel = o.requestModel + if anthropicResp.Model != "" { + responseModel = string(anthropicResp.Model) + } + + openAIResp, tokenUsage, err := messageToChatCompletion(&anthropicResp, responseModel) + if err != nil { + return nil, nil, metrics.TokenUsage{}, "", err + } + + newBody, err = json.Marshal(openAIResp) + if err != nil { + return nil, nil, metrics.TokenUsage{}, "", fmt.Errorf("failed to marshal body: %w", err) + } + + if span != nil { + span.RecordResponse(openAIResp) + } + newHeaders = []internalapi.Header{{contentLengthHeaderName, strconv.Itoa(len(newBody))}} + return +} + 
+// extractAnthropicSSEFromEventStream decodes AWS EventStream binary format +// and extracts Anthropic events, converting them to SSE format. +// AWS Bedrock wraps each Anthropic event as base64-encoded JSON in EventStream messages. +func (o *openAIToAWSAnthropicTranslatorV1ChatCompletion) extractAnthropicSSEFromEventStream() []byte { + if len(o.bufferedBody) == 0 { + return nil + } + + r := bytes.NewReader(o.bufferedBody) + dec := eventstream.NewDecoder() + var result []byte + var lastRead int64 + + for { + msg, err := dec.Decode(r, nil) + if err != nil { + // End of stream or incomplete message - keep remaining data buffered + o.bufferedBody = o.bufferedBody[lastRead:] + return result + } + + // AWS Bedrock payload format: {"bytes":"base64-encoded-json","p":"..."} + var payload struct { + Bytes string `json:"bytes"` // base64-encoded Anthropic event JSON + } + if unMarshalErr := json.Unmarshal(msg.Payload, &payload); unMarshalErr != nil || payload.Bytes == "" { + lastRead = r.Size() - int64(r.Len()) + continue + } + + // Base64 decode to get the Anthropic event JSON + decodedBytes, err := base64.StdEncoding.DecodeString(payload.Bytes) + if err != nil { + lastRead = r.Size() - int64(r.Len()) + continue + } + + // Extract the event type from JSON + // Use gjson for robust extraction even from malformed JSON + eventType := gjson.GetBytes(decodedBytes, "type").String() + + // Convert to SSE format: "event: TYPE\ndata: JSON\n\n" + // Pass through even if malformed - streamParser will detect and report errors + sseEvent := fmt.Sprintf("event: %s\ndata: %s\n\n", eventType, string(decodedBytes)) + result = append(result, []byte(sseEvent)...) 
+ + lastRead = r.Size() - int64(r.Len()) + } +} diff --git a/internal/translator/openai_awsanthropic_test.go b/internal/translator/openai_awsanthropic_test.go new file mode 100644 index 0000000000..4c983bc03c --- /dev/null +++ b/internal/translator/openai_awsanthropic_test.go @@ -0,0 +1,812 @@ +// Copyright Envoy AI Gateway Authors +// SPDX-License-Identifier: Apache-2.0 +// The full text of the Apache license is available in the LICENSE file at +// the root of the repo. + +package translator + +import ( + "bytes" + "encoding/base64" + stdjson "encoding/json" // nolint: depguard + "fmt" + "io" + "strconv" + "testing" + "time" + + "github.com/anthropics/anthropic-sdk-go" + "github.com/anthropics/anthropic-sdk-go/shared/constant" + "github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" + "k8s.io/utils/ptr" + + "github.com/envoyproxy/ai-gateway/internal/apischema/awsbedrock" + "github.com/envoyproxy/ai-gateway/internal/apischema/openai" + "github.com/envoyproxy/ai-gateway/internal/json" +) + +// wrapAnthropicSSEInEventStream wraps Anthropic SSE data in AWS EventStream format. +// AWS Bedrock base64-encodes each event's JSON data (which includes the type field) and wraps it in EventStream messages. 
+func wrapAnthropicSSEInEventStream(sseData string) ([]byte, error) { + buf := bytes.NewBuffer(nil) + encoder := eventstream.NewEncoder() + + // Parse SSE format to extract individual events + // SSE format: "event: TYPE\ndata: JSON\n\n" + events := bytes.Split([]byte(sseData), []byte("\n\n")) + + for _, eventBlock := range events { + if len(bytes.TrimSpace(eventBlock)) == 0 { + continue + } + + // Extract both event type and data from the SSE event + lines := bytes.Split(eventBlock, []byte("\n")) + var eventType string + var jsonData []byte + for _, line := range lines { + if bytes.HasPrefix(line, []byte("event: ")) { + eventType = string(bytes.TrimPrefix(line, []byte("event: "))) + } else if bytes.HasPrefix(line, []byte("data: ")) { + jsonData = bytes.TrimPrefix(line, []byte("data: ")) + } + } + + if len(jsonData) == 0 { + continue + } + + // AWS Bedrock Anthropic format includes the type in the JSON data itself + // If the JSON doesn't already have a "type" field (like in malformed test cases), + // we need to add it to match real AWS Bedrock behavior + var finalJSON []byte + if eventType != "" && !bytes.Contains(jsonData, []byte(`"type"`)) { + // Prepend the type field to simulate real Anthropic event format + // For malformed JSON, this creates something like: {"type": "message_start", {invalid...} + // which is still malformed, but has the type field that can be extracted + finalJSON = []byte(fmt.Sprintf(`{"type": "%s", %s`, eventType, string(jsonData[1:]))) + if jsonData[0] != '{' { + // If it doesn't even start with {, just wrap it + finalJSON = []byte(fmt.Sprintf(`{"type": "%s", "data": %s}`, eventType, string(jsonData))) + } + } else { + finalJSON = jsonData + } + + // Base64 encode the JSON data (this is what AWS Bedrock does) + base64Data := base64.StdEncoding.EncodeToString(finalJSON) + + // Create a payload with the base64-encoded data in the "bytes" field + payload := struct { + Bytes string `json:"bytes"` + }{ + Bytes: base64Data, + } + + 
payloadBytes, err := json.Marshal(payload) + if err != nil { + return nil, err + } + + // Encode as EventStream message + err = encoder.Encode(buf, eventstream.Message{ + Headers: eventstream.Headers{{Name: ":event-type", Value: eventstream.StringValue("chunk")}}, + Payload: payloadBytes, + }) + if err != nil { + return nil, err + } + } + + return buf.Bytes(), nil +} + +// TestResponseModel_AWSAnthropic tests that AWS Anthropic (non-streaming) returns the request model +// AWS Anthropic uses deterministic model mapping without virtualization +func TestResponseModel_AWSAnthropic(t *testing.T) { + modelName := "anthropic.claude-sonnet-4-20250514-v1:0" + translator := NewChatCompletionOpenAIToAWSAnthropicTranslator("", modelName) + + // Initialize translator with the model + req := &openai.ChatCompletionRequest{ + Model: "claude-sonnet-4", + MaxTokens: ptr.To(int64(100)), + Messages: []openai.ChatCompletionMessageParamUnion{ + { + OfUser: &openai.ChatCompletionUserMessageParam{ + Content: openai.StringOrUserRoleContentUnion{Value: "Hello"}, + Role: openai.ChatMessageRoleUser, + }, + }, + }, + } + reqBody, _ := json.Marshal(req) + _, _, err := translator.RequestBody(reqBody, req, false) + require.NoError(t, err) + + // AWS Anthropic response doesn't have model field, uses Anthropic format + anthropicResponse := anthropic.Message{ + ID: "msg_01XYZ", + Type: constant.ValueOf[constant.Message](), + Role: constant.ValueOf[constant.Assistant](), + Content: []anthropic.ContentBlockUnion{ + { + Type: "text", + Text: "Hello!", + }, + }, + StopReason: anthropic.StopReasonEndTurn, + Usage: anthropic.Usage{ + InputTokens: 10, + OutputTokens: 5, + }, + } + + body, err := json.Marshal(anthropicResponse) + require.NoError(t, err) + + _, _, tokenUsage, responseModel, err := translator.ResponseBody(nil, bytes.NewReader(body), true, nil) + require.NoError(t, err) + require.Equal(t, modelName, responseModel) // Returns the request model since no virtualization + inputTokens, ok := 
tokenUsage.InputTokens() + require.True(t, ok) + require.Equal(t, uint32(10), inputTokens) + outputTokens, ok := tokenUsage.OutputTokens() + require.True(t, ok) + require.Equal(t, uint32(5), outputTokens) +} + +func TestOpenAIToAWSAnthropicTranslatorV1ChatCompletion_RequestBody(t *testing.T) { + // Define a common input request to use for both standard and vertex tests. + openAIReq := &openai.ChatCompletionRequest{ + Model: "anthropic.claude-3-opus-20240229-v1:0", + Messages: []openai.ChatCompletionMessageParamUnion{ + { + OfSystem: &openai.ChatCompletionSystemMessageParam{Content: openai.ContentUnion{Value: "You are a helpful assistant."}, Role: openai.ChatMessageRoleSystem}, + }, + { + OfUser: &openai.ChatCompletionUserMessageParam{Content: openai.StringOrUserRoleContentUnion{Value: "Hello!"}, Role: openai.ChatMessageRoleUser}, + }, + }, + MaxTokens: ptr.To(int64(1024)), + Temperature: ptr.To(0.7), + } + + t.Run("AWS Bedrock InvokeModel Values Configured Correctly", func(t *testing.T) { + translator := NewChatCompletionOpenAIToAWSAnthropicTranslator("", "") + hm, body, err := translator.RequestBody(nil, openAIReq, false) + require.NoError(t, err) + require.NotNil(t, hm) + require.NotNil(t, body) + + // Check the path header. + pathHeader := hm[0] + require.Equal(t, pathHeaderName, pathHeader.Key()) + expectedPath := fmt.Sprintf("/model/%s/invoke", openAIReq.Model) + require.Equal(t, expectedPath, pathHeader.Value()) + + // Check the body content. + require.NotNil(t, body) + // Model should NOT be present in the body for AWS Bedrock. + require.False(t, gjson.GetBytes(body, "model").Exists()) + // Anthropic version should be present for AWS Bedrock. + require.Equal(t, BedrockDefaultVersion, gjson.GetBytes(body, "anthropic_version").String()) + }) + + t.Run("Model Name Override", func(t *testing.T) { + overrideModelName := "anthropic.claude-3-haiku-20240307-v1:0" + // Instantiate the translator with the model name override. 
+ translator := NewChatCompletionOpenAIToAWSAnthropicTranslator("", overrideModelName) + + // Call RequestBody with the original request, which has a different model name. + hm, _, err := translator.RequestBody(nil, openAIReq, false) + require.NoError(t, err) + require.NotNil(t, hm) + + // Check that the :path header uses the override model name. + pathHeader := hm[0] + require.Equal(t, pathHeaderName, pathHeader.Key()) + expectedPath := fmt.Sprintf("/model/%s/invoke", overrideModelName) + require.Equal(t, expectedPath, pathHeader.Value()) + }) + + t.Run("Model Name with ARN (URL encoding)", func(t *testing.T) { + arnModelName := "arn:aws:bedrock:us-east-1::foundation-model/anthropic.claude-3-opus-20240229-v1:0" + translator := NewChatCompletionOpenAIToAWSAnthropicTranslator("", arnModelName) + + hm, _, err := translator.RequestBody(nil, openAIReq, false) + require.NoError(t, err) + require.NotNil(t, hm) + + // Check that the :path header uses URL-encoded model name. + pathHeader := hm[0] + require.Equal(t, pathHeaderName, pathHeader.Key()) + // url.PathEscape encodes slashes but not colons (colons are valid in URL paths) + // So we expect slashes to be encoded as %2F + require.Contains(t, pathHeader.Value(), "arn:aws:bedrock") // Colons are not encoded + require.Contains(t, pathHeader.Value(), "%2Fanthropic") // Slashes are encoded + }) + + t.Run("Streaming Request Validation", func(t *testing.T) { + streamReq := &openai.ChatCompletionRequest{ + Model: "anthropic.claude-3-sonnet-20240229-v1:0", + Messages: []openai.ChatCompletionMessageParamUnion{}, + MaxTokens: ptr.To(int64(100)), + Stream: true, + } + translator := NewChatCompletionOpenAIToAWSAnthropicTranslator("", "") + hm, body, err := translator.RequestBody(nil, streamReq, false) + require.NoError(t, err) + require.NotNil(t, hm) + + // Check that the :path header uses the invoke-with-response-stream endpoint. 
+ pathHeader := hm + require.Equal(t, pathHeaderName, pathHeader[0].Key()) + expectedPath := fmt.Sprintf("/model/%s/invoke-with-response-stream", streamReq.Model) + require.Equal(t, expectedPath, pathHeader[0].Value()) + + // AWS Bedrock uses the endpoint path to indicate streaming (invoke-with-response-stream) + // The Anthropic Messages API body format doesn't require a "stream" field + // Verify the body is valid JSON with expected Anthropic fields + require.True(t, gjson.GetBytes(body, "max_tokens").Exists()) + require.True(t, gjson.GetBytes(body, "anthropic_version").Exists()) + }) + + t.Run("API Version Override", func(t *testing.T) { + customAPIVersion := "bedrock-2024-01-01" + // Instantiate the translator with the custom API version. + translator := NewChatCompletionOpenAIToAWSAnthropicTranslator(customAPIVersion, "") + + // Call RequestBody with a standard request. + _, body, err := translator.RequestBody(nil, openAIReq, false) + require.NoError(t, err) + require.NotNil(t, body) + + // Check that the anthropic_version in the body uses the custom version. 
+ require.Equal(t, customAPIVersion, gjson.GetBytes(body, "anthropic_version").String()) + }) + + t.Run("Invalid Temperature (above bound)", func(t *testing.T) { + invalidTempReq := &openai.ChatCompletionRequest{ + Model: "anthropic.claude-3-opus-20240229-v1:0", + Messages: []openai.ChatCompletionMessageParamUnion{}, + MaxTokens: ptr.To(int64(100)), + Temperature: ptr.To(2.5), + } + translator := NewChatCompletionOpenAIToAWSAnthropicTranslator("", "") + _, _, err := translator.RequestBody(nil, invalidTempReq, false) + require.Error(t, err) + require.Contains(t, err.Error(), fmt.Sprintf(tempNotSupportedError, *invalidTempReq.Temperature)) + }) + + t.Run("Missing MaxTokens Throws Error", func(t *testing.T) { + missingTokensReq := &openai.ChatCompletionRequest{ + Model: "anthropic.claude-3-opus-20240229-v1:0", + Messages: []openai.ChatCompletionMessageParamUnion{}, + MaxTokens: nil, + } + translator := NewChatCompletionOpenAIToAWSAnthropicTranslator("", "") + _, _, err := translator.RequestBody(nil, missingTokensReq, false) + require.ErrorContains(t, err, "max_tokens or max_completion_tokens is required") + }) +} + +func TestOpenAIToAWSAnthropicTranslatorV1ChatCompletion_ResponseBody(t *testing.T) { + t.Run("invalid json body", func(t *testing.T) { + translator := NewChatCompletionOpenAIToAWSAnthropicTranslator("", "") + _, _, _, _, err := translator.ResponseBody(map[string]string{statusHeaderName: "200"}, bytes.NewBufferString("invalid json"), true, nil) + require.Error(t, err) + require.Contains(t, err.Error(), "failed to unmarshal body") + }) + + tests := []struct { + name string + inputResponse *anthropic.Message + respHeaders map[string]string + expectedOpenAIResponse openai.ChatCompletionResponse + }{ + { + name: "basic text response", + inputResponse: &anthropic.Message{ + ID: "msg_01XYZ123", + Model: "claude-3-5-sonnet-20241022", + Role: constant.Assistant(anthropic.MessageParamRoleAssistant), + Content: []anthropic.ContentBlockUnion{{Type: "text", Text: 
"Hello there!"}}, + StopReason: anthropic.StopReasonEndTurn, + Usage: anthropic.Usage{InputTokens: 10, OutputTokens: 20, CacheReadInputTokens: 5}, + }, + respHeaders: map[string]string{statusHeaderName: "200"}, + expectedOpenAIResponse: openai.ChatCompletionResponse{ + ID: "msg_01XYZ123", + Model: "claude-3-5-sonnet-20241022", + Created: openai.JSONUNIXTime(time.Unix(releaseDateUnix, 0)), + Object: "chat.completion", + Usage: openai.Usage{ + PromptTokens: 15, + CompletionTokens: 20, + TotalTokens: 35, + PromptTokensDetails: &openai.PromptTokensDetails{ + CachedTokens: 5, + }, + }, + Choices: []openai.ChatCompletionResponseChoice{ + { + Index: 0, + Message: openai.ChatCompletionResponseChoiceMessage{Role: "assistant", Content: ptr.To("Hello there!")}, + FinishReason: openai.ChatCompletionChoicesFinishReasonStop, + }, + }, + }, + }, + { + name: "response with tool use", + inputResponse: &anthropic.Message{ + ID: "msg_01XYZ123", + Model: "claude-3-5-sonnet-20241022", + Role: constant.Assistant(anthropic.MessageParamRoleAssistant), + Content: []anthropic.ContentBlockUnion{ + {Type: "text", Text: "Ok, I will call the tool."}, + {Type: "tool_use", ID: "toolu_01", Name: "get_weather", Input: stdjson.RawMessage(`{"location": "Tokyo", "unit": "celsius"}`)}, + }, + StopReason: anthropic.StopReasonToolUse, + Usage: anthropic.Usage{InputTokens: 25, OutputTokens: 15, CacheReadInputTokens: 10}, + }, + respHeaders: map[string]string{statusHeaderName: "200"}, + expectedOpenAIResponse: openai.ChatCompletionResponse{ + ID: "msg_01XYZ123", + Model: "claude-3-5-sonnet-20241022", + Created: openai.JSONUNIXTime(time.Unix(releaseDateUnix, 0)), + Object: "chat.completion", + Usage: openai.Usage{ + PromptTokens: 35, CompletionTokens: 15, TotalTokens: 50, + PromptTokensDetails: &openai.PromptTokensDetails{ + CachedTokens: 10, + }, + }, + Choices: []openai.ChatCompletionResponseChoice{ + { + Index: 0, + FinishReason: openai.ChatCompletionChoicesFinishReasonToolCalls, + Message: 
openai.ChatCompletionResponseChoiceMessage{ + Role: string(anthropic.MessageParamRoleAssistant), + Content: ptr.To("Ok, I will call the tool."), + ToolCalls: []openai.ChatCompletionMessageToolCallParam{ + { + ID: ptr.To("toolu_01"), + Type: openai.ChatCompletionMessageToolCallTypeFunction, + Function: openai.ChatCompletionMessageToolCallFunctionParam{ + Name: "get_weather", + Arguments: `{"location": "Tokyo", "unit": "celsius"}`, + }, + }, + }, + }, + }, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + body, err := json.Marshal(tt.inputResponse) + require.NoError(t, err, "Test setup failed: could not marshal input struct") + + translator := NewChatCompletionOpenAIToAWSAnthropicTranslator("", "") + hm, body, usedToken, _, err := translator.ResponseBody(tt.respHeaders, bytes.NewBuffer(body), true, nil) + + require.NoError(t, err, "Translator returned an unexpected internal error") + require.NotNil(t, hm) + require.NotNil(t, body) + + newBody := body + require.NotNil(t, newBody) + require.Len(t, hm, 1) + require.Equal(t, contentLengthHeaderName, hm[0].Key()) + require.Equal(t, strconv.Itoa(len(newBody)), hm[0].Value()) + + var gotResp openai.ChatCompletionResponse + err = json.Unmarshal(newBody, &gotResp) + require.NoError(t, err) + + expectedTokenUsage := tokenUsageFrom( + int32(tt.expectedOpenAIResponse.Usage.PromptTokens), // nolint:gosec + int32(tt.expectedOpenAIResponse.Usage.PromptTokensDetails.CachedTokens), // nolint:gosec + int32(tt.expectedOpenAIResponse.Usage.PromptTokensDetails.CacheCreationTokens), // nolint:gosec + int32(tt.expectedOpenAIResponse.Usage.CompletionTokens), // nolint:gosec + int32(tt.expectedOpenAIResponse.Usage.TotalTokens), // nolint:gosec + ) + require.Equal(t, expectedTokenUsage, usedToken) + + if diff := cmp.Diff(tt.expectedOpenAIResponse, gotResp, cmpopts.IgnoreFields(openai.ChatCompletionResponse{}, "Created")); diff != "" { + t.Errorf("ResponseBody mismatch (-want +got):\n%s", diff) + } + }) 
+ } +} + +func TestOpenAIToAWSAnthropicTranslator_ResponseError(t *testing.T) { + tests := []struct { + name string + responseHeaders map[string]string + inputBody any + expectedOutput openai.Error + }{ + { + name: "non-json error response", + responseHeaders: map[string]string{ + statusHeaderName: "503", + contentTypeHeaderName: "text/plain; charset=utf-8", + }, + inputBody: "Service Unavailable", + expectedOutput: openai.Error{ + Type: "error", + Error: openai.ErrorType{ + Type: awsBedrockBackendError, + Code: ptr.To("503"), + Message: "Service Unavailable", + }, + }, + }, + { + name: "json error response", + responseHeaders: map[string]string{ + statusHeaderName: "400", + contentTypeHeaderName: "application/json", + awsErrorTypeHeaderName: "ValidationException", + }, + inputBody: &awsbedrock.BedrockException{ + Message: "messages: field is required", + }, + expectedOutput: openai.Error{ + Type: "error", + Error: openai.ErrorType{ + Type: "ValidationException", + Code: ptr.To("400"), + Message: "messages: field is required", + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var reader io.Reader + if bodyStr, ok := tt.inputBody.(string); ok { + reader = bytes.NewBufferString(bodyStr) + } else { + bodyBytes, err := json.Marshal(tt.inputBody) + require.NoError(t, err) + reader = bytes.NewBuffer(bodyBytes) + } + + o := &openAIToAWSAnthropicTranslatorV1ChatCompletion{} + hm, body, err := o.ResponseError(tt.responseHeaders, reader) + + require.NoError(t, err) + require.NotNil(t, body) + require.NotNil(t, hm) + require.Len(t, hm, 2) + require.Equal(t, contentTypeHeaderName, hm[0].Key()) + require.Equal(t, jsonContentType, hm[0].Value()) //nolint:testifylint + require.Equal(t, contentLengthHeaderName, hm[1].Key()) + require.Equal(t, strconv.Itoa(len(body)), hm[1].Value()) + + var gotError openai.Error + err = json.Unmarshal(body, &gotError) + require.NoError(t, err) + + if diff := cmp.Diff(tt.expectedOutput, gotError); diff != "" { 
+ t.Errorf("ResponseError() mismatch (-want +got):\n%s", diff) + } + }) + } +} + +// TestResponseModel_AWSAnthropicStreaming tests that AWS Anthropic streaming returns the request model +// AWS Anthropic uses deterministic model mapping without virtualization +func TestResponseModel_AWSAnthropicStreaming(t *testing.T) { + modelName := "anthropic.claude-sonnet-4-20250514-v1:0" + sseStream := `event: message_start +data: {"type": "message_start", "message": {"id": "msg_1nZdL29xx5MUA1yADyHTEsnR8uuvGzszyY", "type": "message", "role": "assistant", "content": [], "model": "claude-sonnet-4@20250514", "stop_reason": null, "stop_sequence": null, "usage": {"input_tokens": 10, "output_tokens": 1}}} + +event: content_block_start +data: {"type": "content_block_start", "index": 0, "content_block": {"type": "text", "text": ""}} + +event: content_block_delta +data: {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": "Hello"}} + +event: content_block_stop +data: {"type": "content_block_stop", "index": 0} + +event: message_delta +data: {"type": "message_delta", "delta": {"stop_reason": "end_turn", "stop_sequence":null}, "usage": {"output_tokens": 5}} + +event: message_stop +data: {"type": "message_stop"} + +` + // Wrap SSE data in AWS EventStream format + eventStreamData, err := wrapAnthropicSSEInEventStream(sseStream) + require.NoError(t, err) + + openAIReq := &openai.ChatCompletionRequest{ + Stream: true, + Model: modelName, + MaxTokens: new(int64), + } + + translator := NewChatCompletionOpenAIToAWSAnthropicTranslator("", "").(*openAIToAWSAnthropicTranslatorV1ChatCompletion) + _, _, err = translator.RequestBody(nil, openAIReq, false) + require.NoError(t, err) + + // Test streaming response - AWS Anthropic doesn't return model in response, uses request model + _, _, tokenUsage, responseModel, err := translator.ResponseBody(map[string]string{}, bytes.NewReader(eventStreamData), true, nil) + require.NoError(t, err) + require.Equal(t, modelName, 
responseModel) // Returns the request model since no virtualization + inputTokens, ok := tokenUsage.InputTokens() + require.True(t, ok) + require.Equal(t, uint32(10), inputTokens) + outputTokens, ok := tokenUsage.OutputTokens() + require.True(t, ok) + require.Equal(t, uint32(5), outputTokens) +} + +func TestOpenAIToAWSAnthropicTranslatorV1ChatCompletion_ResponseBody_Streaming(t *testing.T) { + t.Run("handles simple text stream", func(t *testing.T) { + sseStream := ` +event: message_start +data: {"type": "message_start", "message": {"id": "msg_1nZdL29xx5MUA1yADyHTEsnR8uuvGzszyY", "type": "message", "role": "assistant", "content": [], "model": "claude-opus-4-20250514", "stop_reason": null, "stop_sequence": null, "usage": {"input_tokens": 25, "output_tokens": 1}}} + +event: content_block_start +data: {"type": "content_block_start", "index": 0, "content_block": {"type": "text", "text": ""}} + +event: ping +data: {"type": "ping"} + +event: content_block_delta +data: {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": "Hello"}} + +event: content_block_delta +data: {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": "!"}} + +event: content_block_stop +data: {"type": "content_block_stop", "index": 0} + +event: message_delta +data: {"type": "message_delta", "delta": {"stop_reason": "end_turn", "stop_sequence":null}, "usage": {"output_tokens": 15}} + +event: message_stop +data: {"type": "message_stop"} + +` + // Wrap SSE data in AWS EventStream format + eventStreamData, err := wrapAnthropicSSEInEventStream(sseStream) + require.NoError(t, err) + + openAIReq := &openai.ChatCompletionRequest{ + Stream: true, + Model: "test-model", + MaxTokens: new(int64), + } + translator := NewChatCompletionOpenAIToAWSAnthropicTranslator("", "").(*openAIToAWSAnthropicTranslatorV1ChatCompletion) + _, _, err = translator.RequestBody(nil, openAIReq, false) + require.NoError(t, err) + + _, bm, _, _, err := 
translator.ResponseBody(map[string]string{}, bytes.NewReader(eventStreamData), true, nil) + require.NoError(t, err) + require.NotNil(t, bm) + + bodyStr := string(bm) + require.Contains(t, bodyStr, `"content":"Hello"`) + require.Contains(t, bodyStr, `"finish_reason":"stop"`) + require.Contains(t, bodyStr, `"prompt_tokens":25`) + require.Contains(t, bodyStr, `"completion_tokens":15`) + require.Contains(t, bodyStr, string(sseDoneMessage)) + }) + + t.Run("handles tool use stream", func(t *testing.T) { + sseStream := `event: message_start +data: {"type":"message_start","message":{"id":"msg_014p7gG3wDgGV9EUtLvnow3U","type":"message","role":"assistant","model":"claude-opus-4-20250514","stop_sequence":null,"usage":{"input_tokens":472,"output_tokens":2},"content":[],"stop_reason":null}} + +event: content_block_start +data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}} + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Checking weather"}} + +event: content_block_stop +data: {"type":"content_block_stop","index":0} + +event: content_block_start +data: {"type":"content_block_start","index":1,"content_block":{"type":"tool_use","id":"toolu_01T1x1fJ34qAmk2tNTrN7Up6","name":"get_weather","input":{}}} + +event: content_block_delta +data: {"type":"content_block_delta","index":1,"delta":{"type":"input_json_delta","partial_json":"{\"location\": \"San Francisco, CA\", \"unit\": \"fahrenheit\"}"}} + +event: content_block_stop +data: {"type":"content_block_stop","index":1} + +event: message_delta +data: {"type":"message_delta","delta":{"stop_reason":"tool_use","stop_sequence":null},"usage":{"output_tokens":89}} + +event: message_stop +data: {"type":"message_stop"} +` + + // Wrap SSE data in AWS EventStream format + eventStreamData, err := wrapAnthropicSSEInEventStream(sseStream) + require.NoError(t, err) + + openAIReq := &openai.ChatCompletionRequest{Stream: true, Model: "test-model", 
MaxTokens: new(int64)} + translator := NewChatCompletionOpenAIToAWSAnthropicTranslator("", "").(*openAIToAWSAnthropicTranslatorV1ChatCompletion) + _, _, err = translator.RequestBody(nil, openAIReq, false) + require.NoError(t, err) + + _, bm, _, _, err := translator.ResponseBody(map[string]string{}, bytes.NewReader(eventStreamData), true, nil) + require.NoError(t, err) + require.NotNil(t, bm) + bodyStr := string(bm) + + require.Contains(t, bodyStr, `"content":"Checking weather"`) + require.Contains(t, bodyStr, `"name":"get_weather"`) + require.Contains(t, bodyStr, `"finish_reason":"tool_calls"`) + require.Contains(t, bodyStr, string(sseDoneMessage)) + }) +} + +func TestAWSAnthropicStreamParser_ErrorHandling(t *testing.T) { + runStreamErrTest := func(t *testing.T, sseStream string, endOfStream bool) error { + // Wrap SSE data in AWS EventStream format + eventStreamData, err := wrapAnthropicSSEInEventStream(sseStream) + require.NoError(t, err) + + openAIReq := &openai.ChatCompletionRequest{Stream: true, Model: "test-model", MaxTokens: new(int64)} + translator := NewChatCompletionOpenAIToAWSAnthropicTranslator("", "").(*openAIToAWSAnthropicTranslatorV1ChatCompletion) + _, _, err = translator.RequestBody(nil, openAIReq, false) + require.NoError(t, err) + + _, _, _, _, err = translator.ResponseBody(map[string]string{}, bytes.NewReader(eventStreamData), endOfStream, nil) + return err + } + + tests := []struct { + name string + sseStream string + endOfStream bool + expectedError string + }{ + { + name: "malformed message_start event", + sseStream: "event: message_start\ndata: {invalid\n\n", + expectedError: "unmarshal message_start", + }, + { + name: "malformed content_block_start event", + sseStream: "event: content_block_start\ndata: {invalid\n\n", + expectedError: "failed to unmarshal content_block_start", + }, + { + name: "malformed error event data", + sseStream: "event: error\ndata: {invalid\n\n", + expectedError: "unparsable error event", + }, + } + + for _, tt := 
range tests { + t.Run(tt.name, func(t *testing.T) { + err := runStreamErrTest(t, tt.sseStream, tt.endOfStream) + require.Error(t, err) + require.Contains(t, err.Error(), tt.expectedError) + }) + } + + t.Run("body read error", func(t *testing.T) { + parser := newAnthropicStreamParser("test-model") + _, _, _, _, err := parser.Process(&mockErrorReader{}, false, nil) + require.Error(t, err) + require.Contains(t, err.Error(), "failed to read from stream body") + }) +} + +func TestOpenAIToAWSAnthropicTranslator_ResponseHeaders(t *testing.T) { + t.Run("non-streaming request", func(t *testing.T) { + translator := &openAIToAWSAnthropicTranslatorV1ChatCompletion{ + streamParser: nil, // Not streaming + } + headers, err := translator.ResponseHeaders(nil) + require.NoError(t, err) + require.Empty(t, headers) + }) + + t.Run("streaming request", func(t *testing.T) { + translator := &openAIToAWSAnthropicTranslatorV1ChatCompletion{ + streamParser: newAnthropicStreamParser("test-model"), + } + headers, err := translator.ResponseHeaders(nil) + require.NoError(t, err) + require.Len(t, headers, 1) + require.Equal(t, contentTypeHeaderName, headers[0].Key()) + require.Equal(t, eventStreamContentType, headers[0].Value()) + }) +} + +func TestOpenAIToAWSAnthropicTranslator_EdgeCases(t *testing.T) { + t.Run("response with model field from API", func(t *testing.T) { + // AWS Anthropic may return model field in response + modelName := "custom-override-model" + translator := NewChatCompletionOpenAIToAWSAnthropicTranslator("", modelName) + + req := &openai.ChatCompletionRequest{ + Model: "original-model", + MaxTokens: ptr.To(int64(100)), + Messages: []openai.ChatCompletionMessageParamUnion{ + {OfUser: &openai.ChatCompletionUserMessageParam{ + Content: openai.StringOrUserRoleContentUnion{Value: "Test"}, + Role: openai.ChatMessageRoleUser, + }}, + }, + } + _, _, err := translator.RequestBody(nil, req, false) + require.NoError(t, err) + + // Response with model field + anthropicResp := 
anthropic.Message{ + ID: "msg_123", + Model: "claude-3-opus-20240229", + Role: constant.Assistant(anthropic.MessageParamRoleAssistant), + Content: []anthropic.ContentBlockUnion{{Type: "text", Text: "Response"}}, + StopReason: anthropic.StopReasonEndTurn, + Usage: anthropic.Usage{InputTokens: 5, OutputTokens: 3}, + } + + body, err := json.Marshal(anthropicResp) + require.NoError(t, err) + + _, _, _, responseModel, err := translator.ResponseBody(nil, bytes.NewReader(body), true, nil) + require.NoError(t, err) + // Should use model from response when available + assert.Equal(t, string(anthropicResp.Model), responseModel) + }) + + t.Run("response without model field", func(t *testing.T) { + // AWS Anthropic typically doesn't return model field + modelName := "anthropic.claude-3-haiku-20240307-v1:0" + translator := NewChatCompletionOpenAIToAWSAnthropicTranslator("", modelName) + + req := &openai.ChatCompletionRequest{ + Model: "original-model", + MaxTokens: ptr.To(int64(100)), + Messages: []openai.ChatCompletionMessageParamUnion{ + {OfUser: &openai.ChatCompletionUserMessageParam{ + Content: openai.StringOrUserRoleContentUnion{Value: "Test"}, + Role: openai.ChatMessageRoleUser, + }}, + }, + } + _, _, err := translator.RequestBody(nil, req, false) + require.NoError(t, err) + + // Response without model field (typical for AWS Bedrock) + anthropicResp := anthropic.Message{ + ID: "msg_123", + Role: constant.Assistant(anthropic.MessageParamRoleAssistant), + Content: []anthropic.ContentBlockUnion{{Type: "text", Text: "Response"}}, + StopReason: anthropic.StopReasonEndTurn, + Usage: anthropic.Usage{InputTokens: 5, OutputTokens: 3}, + } + + body, err := json.Marshal(anthropicResp) + require.NoError(t, err) + + _, _, _, responseModel, err := translator.ResponseBody(nil, bytes.NewReader(body), true, nil) + require.NoError(t, err) + // Should use request model when response doesn't have model field + assert.Equal(t, modelName, responseModel) + }) +} diff --git 
a/internal/translator/openai_awsbedrock.go b/internal/translator/openai_awsbedrock.go index 1bc9b5d03a..90975ea1de 100644 --- a/internal/translator/openai_awsbedrock.go +++ b/internal/translator/openai_awsbedrock.go @@ -376,7 +376,6 @@ func (o *openAIToAWSBedrockTranslatorV1ChatCompletion) openAIMessageToBedrockMes CachePoint: cachePointBlock, }) } - case string: return nil, fmt.Errorf("%w: redacted_content must be a binary/bytes value in bedrock", internalapi.ErrInvalidRequestBody) default: @@ -774,7 +773,7 @@ func (o *openAIToAWSBedrockTranslatorV1ChatCompletion) ResponseBody(_ map[string } } - // AWS Bedrock does not support N(multiple choices) > 0, so there could be only one choice. + // AWS Bedrock Converse API does not support N(multiple choices) > 0, so there could be only one choice. choice := openai.ChatCompletionResponseChoice{ Index: (int64)(0), Message: openai.ChatCompletionResponseChoiceMessage{ diff --git a/internal/translator/openai_gcpanthropic.go b/internal/translator/openai_gcpanthropic.go index 197dae26ef..6fcbf829ea 100644 --- a/internal/translator/openai_gcpanthropic.go +++ b/internal/translator/openai_gcpanthropic.go @@ -6,23 +6,16 @@ package translator import ( - "cmp" - "encoding/base64" "fmt" "io" "log/slog" "strconv" "strings" - "time" "github.com/anthropics/anthropic-sdk-go" - anthropicParam "github.com/anthropics/anthropic-sdk-go/packages/param" - "github.com/anthropics/anthropic-sdk-go/shared/constant" anthropicVertex "github.com/anthropics/anthropic-sdk-go/vertex" - openAIconstant "github.com/openai/openai-go/shared/constant" "github.com/tidwall/sjson" - "github.com/envoyproxy/ai-gateway/internal/apischema/awsbedrock" "github.com/envoyproxy/ai-gateway/internal/apischema/openai" "github.com/envoyproxy/ai-gateway/internal/internalapi" "github.com/envoyproxy/ai-gateway/internal/json" @@ -33,9 +26,7 @@ import ( // currently a requirement for GCP Vertex / Anthropic API https://docs.anthropic.com/en/api/claude-on-vertex-ai const ( - 
anthropicVersionKey = "anthropic_version" - gcpBackendError = "GCPBackendError" - tempNotSupportedError = "temperature %.2f is not supported by Anthropic (must be between 0.0 and 1.0)" + gcpBackendError = "GCPBackendError" ) // NewChatCompletionOpenAIToGCPAnthropicTranslator implements [Factory] for OpenAI to GCP Anthropic translation. @@ -61,652 +52,18 @@ type openAIToGCPAnthropicTranslatorV1ChatCompletion struct { logger *slog.Logger } -func anthropicToOpenAIFinishReason(stopReason anthropic.StopReason) (openai.ChatCompletionChoicesFinishReason, error) { - switch stopReason { - // The most common stop reason. Indicates Claude finished its response naturally. - // or Claude encountered one of your custom stop sequences. - // TODO: A better way to return pause_turn - // TODO: "pause_turn" Used with server tools like web search when Claude needs to pause a long-running operation. - case anthropic.StopReasonEndTurn, anthropic.StopReasonStopSequence, anthropic.StopReasonPauseTurn: - return openai.ChatCompletionChoicesFinishReasonStop, nil - case anthropic.StopReasonMaxTokens: // Claude stopped because it reached the max_tokens limit specified in your request. - // TODO: do we want to return an error? see: https://docs.anthropic.com/en/docs/agents-and-tools/tool-use/implement-tool-use#handling-the-max-tokens-stop-reason - return openai.ChatCompletionChoicesFinishReasonLength, nil - case anthropic.StopReasonToolUse: - return openai.ChatCompletionChoicesFinishReasonToolCalls, nil - case anthropic.StopReasonRefusal: - return openai.ChatCompletionChoicesFinishReasonContentFilter, nil - default: - return "", fmt.Errorf("received invalid stop reason %v", stopReason) - } -} - -// validateTemperatureForAnthropic checks if the temperature is within Anthropic's supported range (0.0 to 1.0). -// Returns an error if the value is greater than 1.0. 
-func validateTemperatureForAnthropic(temp *float64) error { - if temp != nil && (*temp < 0.0 || *temp > 1.0) { - return fmt.Errorf("%w: temperature must be between 0.0 and 1.0", internalapi.ErrInvalidRequestBody) - } - return nil -} - -func isAnthropicSupportedImageMediaType(mediaType string) bool { - switch anthropic.Base64ImageSourceMediaType(mediaType) { - case anthropic.Base64ImageSourceMediaTypeImageJPEG, - anthropic.Base64ImageSourceMediaTypeImagePNG, - anthropic.Base64ImageSourceMediaTypeImageGIF, - anthropic.Base64ImageSourceMediaTypeImageWebP: - return true - default: - return false - } -} - -// translateAnthropicToolChoice converts the OpenAI tool_choice parameter to the Anthropic format. -func translateAnthropicToolChoice(openAIToolChoice *openai.ChatCompletionToolChoiceUnion, disableParallelToolUse anthropicParam.Opt[bool]) (anthropic.ToolChoiceUnionParam, error) { - var toolChoice anthropic.ToolChoiceUnionParam - - if openAIToolChoice == nil { - return toolChoice, nil - } - - switch choice := openAIToolChoice.Value.(type) { - case string: - switch choice { - case string(openAIconstant.ValueOf[openAIconstant.Auto]()): - toolChoice = anthropic.ToolChoiceUnionParam{OfAuto: &anthropic.ToolChoiceAutoParam{}} - toolChoice.OfAuto.DisableParallelToolUse = disableParallelToolUse - case "required", "any": - toolChoice = anthropic.ToolChoiceUnionParam{OfAny: &anthropic.ToolChoiceAnyParam{}} - toolChoice.OfAny.DisableParallelToolUse = disableParallelToolUse - case "none": - toolChoice = anthropic.ToolChoiceUnionParam{OfNone: &anthropic.ToolChoiceNoneParam{}} - case string(openAIconstant.ValueOf[openAIconstant.Function]()): - // This is how anthropic forces tool use. - // TODO: should we check if strict true in openAI request, and if so, use this? 
- toolChoice = anthropic.ToolChoiceUnionParam{OfTool: &anthropic.ToolChoiceToolParam{Name: choice}} - toolChoice.OfTool.DisableParallelToolUse = disableParallelToolUse - default: - return anthropic.ToolChoiceUnionParam{}, fmt.Errorf("%w: unsupported tool_choice value '%s'", internalapi.ErrInvalidRequestBody, choice) - } - case openai.ChatCompletionNamedToolChoice: - if choice.Type == openai.ToolTypeFunction && choice.Function.Name != "" { - toolChoice = anthropic.ToolChoiceUnionParam{ - OfTool: &anthropic.ToolChoiceToolParam{ - Type: constant.Tool("tool"), - Name: choice.Function.Name, - DisableParallelToolUse: disableParallelToolUse, - }, - } - } - default: - return anthropic.ToolChoiceUnionParam{}, fmt.Errorf("%w: tool_choice type not supported", internalapi.ErrInvalidRequestBody) - } - return toolChoice, nil -} - -// translateOpenAItoAnthropicTools translates OpenAI tool and tool_choice parameters -// into the Anthropic format and returns translated tool & tool choice. -func translateOpenAItoAnthropicTools(openAITools []openai.Tool, openAIToolChoice *openai.ChatCompletionToolChoiceUnion, parallelToolCalls *bool) (tools []anthropic.ToolUnionParam, toolChoice anthropic.ToolChoiceUnionParam, err error) { - if len(openAITools) > 0 { - anthropicTools := make([]anthropic.ToolUnionParam, 0, len(openAITools)) - for _, openAITool := range openAITools { - if openAITool.Type != openai.ToolTypeFunction || openAITool.Function == nil { - // Anthropic only supports 'function' tools, so we skip others. - continue - } - toolParam := anthropic.ToolParam{ - Name: openAITool.Function.Name, - Description: anthropic.String(openAITool.Function.Description), - } - - if isCacheEnabled(openAITool.Function.AnthropicContentFields) { - toolParam.CacheControl = anthropic.NewCacheControlEphemeralParam() - } - - // The parameters for the function are expected to be a JSON Schema object. - // We can pass them through as-is. 
- if openAITool.Function.Parameters != nil { - paramsMap, ok := openAITool.Function.Parameters.(map[string]any) - if !ok { - err = fmt.Errorf("%w: tool parameters must be a JSON object", internalapi.ErrInvalidRequestBody) - return - } - - inputSchema := anthropic.ToolInputSchemaParam{} - - // Dereference json schema - // If the paramsMap contains $refs we need to dereference them - var dereferencedParamsMap any - if dereferencedParamsMap, err = jsonSchemaDereference(paramsMap); err != nil { - return nil, anthropic.ToolChoiceUnionParam{}, fmt.Errorf("invalid JSON schema in tool parameters: %w", err) - } - if paramsMap, ok = dereferencedParamsMap.(map[string]any); !ok { - return nil, anthropic.ToolChoiceUnionParam{}, fmt.Errorf("%w: tool parameters must be a JSON object", internalapi.ErrInvalidRequestBody) - } - - var typeVal string - if typeVal, ok = paramsMap["type"].(string); ok { - inputSchema.Type = constant.Object(typeVal) - } - - var propsVal map[string]any - if propsVal, ok = paramsMap["properties"].(map[string]any); ok { - inputSchema.Properties = propsVal - } - - var requiredVal []any - if requiredVal, ok = paramsMap["required"].([]any); ok { - requiredSlice := make([]string, len(requiredVal)) - for i, v := range requiredVal { - if s, ok := v.(string); ok { - requiredSlice[i] = s - } - } - inputSchema.Required = requiredSlice - } - - toolParam.InputSchema = inputSchema - } - - anthropicTools = append(anthropicTools, anthropic.ToolUnionParam{OfTool: &toolParam}) - if len(anthropicTools) > 0 { - tools = anthropicTools - } - } - - // 2. Handle the tool_choice parameter. - // disable parallel tool use default value is false - // see: https://docs.anthropic.com/en/docs/agents-and-tools/tool-use/implement-tool-use#parallel-tool-use - disableParallelToolUse := anthropic.Bool(false) - if parallelToolCalls != nil { - // OpenAI variable checks to allow parallel tool calls. - // Anthropic variable checks to disable, so need to use the inverse. 
- disableParallelToolUse = anthropic.Bool(!*parallelToolCalls) - } - - toolChoice, err = translateAnthropicToolChoice(openAIToolChoice, disableParallelToolUse) - if err != nil { - return - } - - } - return -} - -// convertImageContentToAnthropic translates an OpenAI image URL into the corresponding Anthropic content block. -// It handles data URIs for various image types and PDFs, as well as remote URLs. -func convertImageContentToAnthropic(imageURL string, fields *openai.AnthropicContentFields) (anthropic.ContentBlockParamUnion, error) { - var cacheControlParam anthropic.CacheControlEphemeralParam - if isCacheEnabled(fields) { - cacheControlParam = fields.CacheControl - } - - switch { - case strings.HasPrefix(imageURL, "data:"): - contentType, data, err := parseDataURI(imageURL) - if err != nil { - return anthropic.ContentBlockParamUnion{}, fmt.Errorf("%w: invalid image data URI", internalapi.ErrInvalidRequestBody) - } - base64Data := base64.StdEncoding.EncodeToString(data) - if contentType == string(constant.ValueOf[constant.ApplicationPDF]()) { - pdfSource := anthropic.Base64PDFSourceParam{Data: base64Data} - docBlock := anthropic.NewDocumentBlock(pdfSource) - docBlock.OfDocument.CacheControl = cacheControlParam - return docBlock, nil - } - if isAnthropicSupportedImageMediaType(contentType) { - imgBlock := anthropic.NewImageBlockBase64(contentType, base64Data) - imgBlock.OfImage.CacheControl = cacheControlParam - return imgBlock, nil - } - return anthropic.ContentBlockParamUnion{}, fmt.Errorf("%w: invalid media_type for image '%s'", internalapi.ErrInvalidRequestBody, contentType) - case strings.HasSuffix(strings.ToLower(imageURL), ".pdf"): - docBlock := anthropic.NewDocumentBlock(anthropic.URLPDFSourceParam{URL: imageURL}) - docBlock.OfDocument.CacheControl = cacheControlParam - return docBlock, nil - default: - imgBlock := anthropic.NewImageBlock(anthropic.URLImageSourceParam{URL: imageURL}) - imgBlock.OfImage.CacheControl = cacheControlParam - return imgBlock, 
nil - } -} - -func isCacheEnabled(fields *openai.AnthropicContentFields) bool { - return fields != nil && fields.CacheControl.Type == constant.ValueOf[constant.Ephemeral]() -} - -// convertContentPartsToAnthropic iterates over a slice of OpenAI content parts -// and converts each into an Anthropic content block. -func convertContentPartsToAnthropic(parts []openai.ChatCompletionContentPartUserUnionParam) ([]anthropic.ContentBlockParamUnion, error) { - resultContent := make([]anthropic.ContentBlockParamUnion, 0, len(parts)) - for _, contentPart := range parts { - switch { - case contentPart.OfText != nil: - textBlock := anthropic.NewTextBlock(contentPart.OfText.Text) - if isCacheEnabled(contentPart.OfText.AnthropicContentFields) { - textBlock.OfText.CacheControl = contentPart.OfText.CacheControl - } - resultContent = append(resultContent, textBlock) - - case contentPart.OfImageURL != nil: - block, err := convertImageContentToAnthropic(contentPart.OfImageURL.ImageURL.URL, contentPart.OfImageURL.AnthropicContentFields) - if err != nil { - return nil, err - } - resultContent = append(resultContent, block) - - case contentPart.OfInputAudio != nil: - return nil, fmt.Errorf("%w: input audio content not supported yet", internalapi.ErrInvalidRequestBody) - case contentPart.OfFile != nil: - return nil, fmt.Errorf("%w: file content not supported yet", internalapi.ErrInvalidRequestBody) - } - } - return resultContent, nil -} - -// Helper: Convert OpenAI message content to Anthropic content. 
-func openAIToAnthropicContent(content any) ([]anthropic.ContentBlockParamUnion, error) { - switch v := content.(type) { - case nil: - return nil, nil - case string: - if v == "" { - return nil, nil - } - return []anthropic.ContentBlockParamUnion{ - anthropic.NewTextBlock(v), - }, nil - case []openai.ChatCompletionContentPartUserUnionParam: - return convertContentPartsToAnthropic(v) - case openai.ContentUnion: - switch val := v.Value.(type) { - case string: - if val == "" { - return nil, nil - } - return []anthropic.ContentBlockParamUnion{ - anthropic.NewTextBlock(val), - }, nil - case []openai.ChatCompletionContentPartTextParam: - var contentBlocks []anthropic.ContentBlockParamUnion - for _, part := range val { - textBlock := anthropic.NewTextBlock(part.Text) - // In an array of text parts, each can have its own cache setting. - if isCacheEnabled(part.AnthropicContentFields) { - textBlock.OfText.CacheControl = part.CacheControl - } - contentBlocks = append(contentBlocks, textBlock) - } - return contentBlocks, nil - default: - return nil, fmt.Errorf("%w: message 'content' must be a string or an array", internalapi.ErrInvalidRequestBody) - } - } - return nil, fmt.Errorf("%w: message 'content' must be a string or an array", internalapi.ErrInvalidRequestBody) -} - -// extractSystemPromptFromDeveloperMsg flattens content and checks for cache flags. -// It returns the combined string and a boolean indicating if any part was cacheable. -func extractSystemPromptFromDeveloperMsg(msg openai.ChatCompletionDeveloperMessageParam) (msgValue string, cacheParam *anthropic.CacheControlEphemeralParam) { - switch v := msg.Content.Value.(type) { - case nil: - return - case string: - msgValue = v - return - case []openai.ChatCompletionContentPartTextParam: - // Concatenate all text parts and check for caching. 
- var sb strings.Builder - for _, part := range v { - sb.WriteString(part.Text) - if isCacheEnabled(part.AnthropicContentFields) { - cacheParam = &part.CacheControl - } - } - msgValue = sb.String() - return - default: - return - } -} - -func anthropicRoleToOpenAIRole(role anthropic.MessageParamRole) (string, error) { - switch role { - case anthropic.MessageParamRoleAssistant: - return openai.ChatMessageRoleAssistant, nil - case anthropic.MessageParamRoleUser: - return openai.ChatMessageRoleUser, nil - default: - return "", fmt.Errorf("invalid anthropic role %v", role) - } -} - -// processAssistantContent processes a single ChatCompletionAssistantMessageParamContent and returns the corresponding Anthropic content block. -func processAssistantContent(content openai.ChatCompletionAssistantMessageParamContent) (*anthropic.ContentBlockParamUnion, error) { - switch content.Type { - case openai.ChatCompletionAssistantMessageParamContentTypeRefusal: - if content.Refusal != nil { - block := anthropic.NewTextBlock(*content.Refusal) - return &block, nil - } - case openai.ChatCompletionAssistantMessageParamContentTypeText: - if content.Text != nil { - textBlock := anthropic.NewTextBlock(*content.Text) - if isCacheEnabled(content.AnthropicContentFields) { - textBlock.OfText.CacheControl = content.CacheControl - } - return &textBlock, nil - } - case openai.ChatCompletionAssistantMessageParamContentTypeThinking: - // thinking can not be cached: https://platform.claude.com/docs/en/build-with-claude/prompt-caching - if content.Text != nil && content.Signature != nil { - thinkBlock := anthropic.NewThinkingBlock(*content.Signature, *content.Text) - return &thinkBlock, nil - } - case openai.ChatCompletionAssistantMessageParamContentTypeRedactedThinking: - if content.RedactedContent != nil { - switch v := content.RedactedContent.Value.(type) { - case string: - redactedThinkingBlock := anthropic.NewRedactedThinkingBlock(v) - return &redactedThinkingBlock, nil - case []byte: - return 
nil, fmt.Errorf("%w: redacted_content must be a string in GCP", internalapi.ErrInvalidRequestBody) - default: - return nil, fmt.Errorf("%w: redacted_content must be a string in GCP", internalapi.ErrInvalidRequestBody) - } - } - default: - return nil, fmt.Errorf("%w: message 'content' must be a string or an array", internalapi.ErrInvalidRequestBody) - } - return nil, nil -} - -// openAIMessageToAnthropicMessageRoleAssistant converts an OpenAI assistant message to Anthropic content blocks. -// The tool_use content is appended to the Anthropic message content list if tool_calls are present. -func openAIMessageToAnthropicMessageRoleAssistant(openAiMessage *openai.ChatCompletionAssistantMessageParam) (anthropicMsg anthropic.MessageParam, err error) { - contentBlocks := make([]anthropic.ContentBlockParamUnion, 0) - if v, ok := openAiMessage.Content.Value.(string); ok && len(v) > 0 { - contentBlocks = append(contentBlocks, anthropic.NewTextBlock(v)) - } else if content, ok := openAiMessage.Content.Value.(openai.ChatCompletionAssistantMessageParamContent); ok { - // Handle single content object - var block *anthropic.ContentBlockParamUnion - block, err = processAssistantContent(content) - if err != nil { - return anthropicMsg, err - } else if block != nil { - contentBlocks = append(contentBlocks, *block) - } - } else if contents, ok := openAiMessage.Content.Value.([]openai.ChatCompletionAssistantMessageParamContent); ok { - // Handle array of content objects - for _, content := range contents { - var block *anthropic.ContentBlockParamUnion - block, err = processAssistantContent(content) - if err != nil { - return anthropicMsg, err - } else if block != nil { - contentBlocks = append(contentBlocks, *block) - } - } - } - - // Handle tool_calls (if any). 
- for i := range openAiMessage.ToolCalls { - toolCall := &openAiMessage.ToolCalls[i] - var input map[string]any - if err = json.Unmarshal([]byte(toolCall.Function.Arguments), &input); err != nil { - err = fmt.Errorf("failed to unmarshal tool call arguments: %w", err) - return - } - toolUse := anthropic.ToolUseBlockParam{ - ID: *toolCall.ID, - Type: "tool_use", - Name: toolCall.Function.Name, - Input: input, - } - - if isCacheEnabled(toolCall.AnthropicContentFields) { - toolUse.CacheControl = toolCall.CacheControl - } - - contentBlocks = append(contentBlocks, anthropic.ContentBlockParamUnion{OfToolUse: &toolUse}) - } - - return anthropic.MessageParam{ - Role: anthropic.MessageParamRoleAssistant, - Content: contentBlocks, - }, nil -} - -// openAIToAnthropicMessages converts OpenAI messages to Anthropic message params type, handling all roles and system/developer logic. -func openAIToAnthropicMessages(openAIMsgs []openai.ChatCompletionMessageParamUnion) (anthropicMessages []anthropic.MessageParam, systemBlocks []anthropic.TextBlockParam, err error) { - for i := 0; i < len(openAIMsgs); { - msg := &openAIMsgs[i] - switch { - case msg.OfSystem != nil: - devParam := systemMsgToDeveloperMsg(*msg.OfSystem) - systemText, cacheControl := extractSystemPromptFromDeveloperMsg(devParam) - systemBlock := anthropic.TextBlockParam{Text: systemText} - if cacheControl != nil { - systemBlock.CacheControl = *cacheControl - } - systemBlocks = append(systemBlocks, systemBlock) - i++ - case msg.OfDeveloper != nil: - systemText, cacheControl := extractSystemPromptFromDeveloperMsg(*msg.OfDeveloper) - systemBlock := anthropic.TextBlockParam{Text: systemText} - if cacheControl != nil { - systemBlock.CacheControl = *cacheControl - } - systemBlocks = append(systemBlocks, systemBlock) - i++ - case msg.OfUser != nil: - message := *msg.OfUser - var content []anthropic.ContentBlockParamUnion - content, err = openAIToAnthropicContent(message.Content.Value) - if err != nil { - return - } - 
anthropicMsg := anthropic.MessageParam{ - Role: anthropic.MessageParamRoleUser, - Content: content, - } - anthropicMessages = append(anthropicMessages, anthropicMsg) - i++ - case msg.OfAssistant != nil: - assistantMessage := msg.OfAssistant - var messages anthropic.MessageParam - messages, err = openAIMessageToAnthropicMessageRoleAssistant(assistantMessage) - if err != nil { - return - } - anthropicMessages = append(anthropicMessages, messages) - i++ - case msg.OfTool != nil: - // Aggregate all consecutive tool messages into a single user message - // to support parallel tool use. - var toolResultBlocks []anthropic.ContentBlockParamUnion - for i < len(openAIMsgs) && openAIMsgs[i].ExtractMessgaeRole() == openai.ChatMessageRoleTool { - currentMsg := &openAIMsgs[i] - toolMsg := currentMsg.OfTool - - var contentBlocks []anthropic.ContentBlockParamUnion - contentBlocks, err = openAIToAnthropicContent(toolMsg.Content) - if err != nil { - return - } - - var toolContent []anthropic.ToolResultBlockParamContentUnion - var cacheControl *anthropic.CacheControlEphemeralParam - - for _, c := range contentBlocks { - var trb anthropic.ToolResultBlockParamContentUnion - // Check if the translated part has caching enabled. 
- switch { - case c.OfText != nil: - trb.OfText = c.OfText - cacheControl = &c.OfText.CacheControl - case c.OfImage != nil: - trb.OfImage = c.OfImage - cacheControl = &c.OfImage.CacheControl - case c.OfDocument != nil: - trb.OfDocument = c.OfDocument - cacheControl = &c.OfDocument.CacheControl - } - toolContent = append(toolContent, trb) - } - - isError := false - if contentStr, ok := toolMsg.Content.Value.(string); ok { - var contentMap map[string]any - if json.Unmarshal([]byte(contentStr), &contentMap) == nil { - if _, ok = contentMap["error"]; ok { - isError = true - } - } - } - - toolResultBlock := anthropic.ToolResultBlockParam{ - ToolUseID: toolMsg.ToolCallID, - Type: "tool_result", - Content: toolContent, - IsError: anthropic.Bool(isError), - } - - if cacheControl != nil { - toolResultBlock.CacheControl = *cacheControl - } - - toolResultBlockUnion := anthropic.ContentBlockParamUnion{OfToolResult: &toolResultBlock} - toolResultBlocks = append(toolResultBlocks, toolResultBlockUnion) - i++ - } - // Append all aggregated tool results. - anthropicMsg := anthropic.MessageParam{ - Role: anthropic.MessageParamRoleUser, - Content: toolResultBlocks, - } - anthropicMessages = append(anthropicMessages, anthropicMsg) - default: - err = fmt.Errorf("%w: unsupported role type: %s", internalapi.ErrInvalidRequestBody, msg.ExtractMessgaeRole()) - return - } - } - return -} - -// NewThinkingConfigParamUnion converts a ThinkingUnion into a ThinkingConfigParamUnion. 
-func getThinkingConfigParamUnion(tu *openai.ThinkingUnion) *anthropic.ThinkingConfigParamUnion { - if tu == nil { - return nil - } - - result := &anthropic.ThinkingConfigParamUnion{} - - if tu.OfEnabled != nil { - result.OfEnabled = &anthropic.ThinkingConfigEnabledParam{ - BudgetTokens: tu.OfEnabled.BudgetTokens, - Type: constant.Enabled(tu.OfEnabled.Type), - } - } else if tu.OfDisabled != nil { - result.OfDisabled = &anthropic.ThinkingConfigDisabledParam{ - Type: constant.Disabled(tu.OfDisabled.Type), - } - } - - return result -} - -// buildAnthropicParams is a helper function that translates an OpenAI request -// into the parameter struct required by the Anthropic SDK. -func buildAnthropicParams(openAIReq *openai.ChatCompletionRequest) (params *anthropic.MessageNewParams, err error) { - // 1. Handle simple parameters and defaults. - maxTokens := cmp.Or(openAIReq.MaxCompletionTokens, openAIReq.MaxTokens) - if maxTokens == nil { - err = fmt.Errorf("%w: max_tokens or max_completion_tokens is required", internalapi.ErrInvalidRequestBody) - return - } - - // Translate openAI contents to anthropic params. - // 2. Translate messages and system prompts. - messages, systemBlocks, err := openAIToAnthropicMessages(openAIReq.Messages) - if err != nil { - return - } - - // 3. Translate tools and tool choice. - tools, toolChoice, err := translateOpenAItoAnthropicTools(openAIReq.Tools, openAIReq.ToolChoice, openAIReq.ParallelToolCalls) - if err != nil { - return - } - - // 4. Construct the final struct in one place. 
- params = &anthropic.MessageNewParams{ - Messages: messages, - MaxTokens: *maxTokens, - System: systemBlocks, - Tools: tools, - ToolChoice: toolChoice, - } - - if openAIReq.Temperature != nil { - if err = validateTemperatureForAnthropic(openAIReq.Temperature); err != nil { - return nil, err - } - params.Temperature = anthropic.Float(*openAIReq.Temperature) - } - if openAIReq.TopP != nil { - params.TopP = anthropic.Float(*openAIReq.TopP) - } - if openAIReq.Stop.OfString.Valid() { - params.StopSequences = []string{openAIReq.Stop.OfString.String()} - } else if openAIReq.Stop.OfStringArray != nil { - params.StopSequences = openAIReq.Stop.OfStringArray - } - - // 5. Handle Vendor specific fields. - // Since GCPAnthropic follows the Anthropic API, we also check for Anthropic vendor fields. - if openAIReq.Thinking != nil { - params.Thinking = *getThinkingConfigParamUnion(openAIReq.Thinking) - } - - return params, nil -} - -// anthropicToolUseToOpenAICalls converts Anthropic tool_use content blocks to OpenAI tool calls. -func anthropicToolUseToOpenAICalls(block *anthropic.ContentBlockUnion) ([]openai.ChatCompletionMessageToolCallParam, error) { - var toolCalls []openai.ChatCompletionMessageToolCallParam - if block.Type != string(constant.ValueOf[constant.ToolUse]()) { - return toolCalls, nil - } - argsBytes, err := json.Marshal(block.Input) - if err != nil { - return nil, fmt.Errorf("failed to marshal tool_use input: %w", err) - } - toolCalls = append(toolCalls, openai.ChatCompletionMessageToolCallParam{ - ID: &block.ID, - Type: openai.ChatCompletionMessageToolCallTypeFunction, - Function: openai.ChatCompletionMessageToolCallFunctionParam{ - Name: block.Name, - Arguments: string(argsBytes), - }, - }) - - return toolCalls, nil -} - // RequestBody implements [OpenAIChatCompletionTranslator.RequestBody] for GCP. 
func (o *openAIToGCPAnthropicTranslatorV1ChatCompletion) RequestBody(_ []byte, openAIReq *openai.ChatCompletionRequest, _ bool) ( newHeaders []internalapi.Header, newBody []byte, err error, ) { params, err := buildAnthropicParams(openAIReq) if err != nil { - return nil, nil, err + return } body, err := json.Marshal(params) if err != nil { - return nil, nil, fmt.Errorf("failed to marshal params: %w", err) + return } o.requestModel = openAIReq.Model @@ -896,12 +253,9 @@ func (o *openAIToGCPAnthropicTranslatorV1ChatCompletion) ResponseBody(_ map[stri responseModel = string(anthropicResp.Model) } - openAIResp := &openai.ChatCompletionResponse{ - ID: anthropicResp.ID, - Model: responseModel, - Object: string(openAIconstant.ValueOf[openAIconstant.ChatCompletion]()), - Choices: make([]openai.ChatCompletionResponseChoice, 0), - Created: openai.JSONUNIXTime(time.Now()), + openAIResp, tokenUsage, err := messageToChatCompletion(&anthropicResp, responseModel) + if err != nil { + return nil, nil, metrics.TokenUsage{}, "", err } // Redact and log response when enabled @@ -912,88 +266,6 @@ func (o *openAIToGCPAnthropicTranslatorV1ChatCompletion) ResponseBody(_ map[stri } } - usage := anthropicResp.Usage - tokenUsage = metrics.ExtractTokenUsageFromExplicitCaching( - usage.InputTokens, - usage.OutputTokens, - &usage.CacheReadInputTokens, - &usage.CacheCreationInputTokens, - ) - inputTokens, _ := tokenUsage.InputTokens() - outputTokens, _ := tokenUsage.OutputTokens() - totalTokens, _ := tokenUsage.TotalTokens() - cachedTokens, _ := tokenUsage.CachedInputTokens() - cacheWriteTokens, _ := tokenUsage.CacheCreationInputTokens() - openAIResp.Usage = openai.Usage{ - CompletionTokens: int(outputTokens), - PromptTokens: int(inputTokens), - TotalTokens: int(totalTokens), - PromptTokensDetails: &openai.PromptTokensDetails{ - CachedTokens: int(cachedTokens), - CacheCreationTokens: int(cacheWriteTokens), - }, - } - - finishReason, err := anthropicToOpenAIFinishReason(anthropicResp.StopReason) - 
if err != nil { - return nil, nil, metrics.TokenUsage{}, "", err - } - - role, err := anthropicRoleToOpenAIRole(anthropic.MessageParamRole(anthropicResp.Role)) - if err != nil { - return nil, nil, metrics.TokenUsage{}, "", err - } - - choice := openai.ChatCompletionResponseChoice{ - Index: 0, - Message: openai.ChatCompletionResponseChoiceMessage{Role: role}, - FinishReason: finishReason, - } - - for i := range anthropicResp.Content { // NOTE: Content structure is massive, do not range over values. - output := &anthropicResp.Content[i] - switch output.Type { - case string(constant.ValueOf[constant.ToolUse]()): - if output.ID != "" { - toolCalls, toolErr := anthropicToolUseToOpenAICalls(output) - if toolErr != nil { - return nil, nil, metrics.TokenUsage{}, "", fmt.Errorf("failed to convert anthropic tool use to openai tool call: %w", toolErr) - } - choice.Message.ToolCalls = append(choice.Message.ToolCalls, toolCalls...) - } - case string(constant.ValueOf[constant.Text]()): - if output.Text != "" { - if choice.Message.Content == nil { - choice.Message.Content = &output.Text - } - } - case string(constant.ValueOf[constant.Thinking]()): - if output.Thinking != "" { - choice.Message.ReasoningContent = &openai.ReasoningContentUnion{ - Value: &openai.ReasoningContent{ - ReasoningContent: &awsbedrock.ReasoningContentBlock{ - ReasoningText: &awsbedrock.ReasoningTextBlock{ - Text: output.Thinking, - Signature: output.Signature, - }, - }, - }, - } - } - case string(constant.ValueOf[constant.RedactedThinking]()): - if output.Data != "" { - choice.Message.ReasoningContent = &openai.ReasoningContentUnion{ - Value: &openai.ReasoningContent{ - ReasoningContent: &awsbedrock.ReasoningContentBlock{ - RedactedContent: []byte(output.Data), - }, - }, - } - } - } - } - openAIResp.Choices = append(openAIResp.Choices, choice) - newBody, err = json.Marshal(openAIResp) if err != nil { return nil, nil, metrics.TokenUsage{}, "", fmt.Errorf("failed to marshal body: %w", err) diff --git 
a/internal/translator/openai_gcpanthropic_stream.go b/internal/translator/openai_gcpanthropic_stream.go deleted file mode 100644 index bee94b8af2..0000000000 --- a/internal/translator/openai_gcpanthropic_stream.go +++ /dev/null @@ -1,421 +0,0 @@ -// Copyright Envoy AI Gateway Authors -// SPDX-License-Identifier: Apache-2.0 -// The full text of the Apache license is available in the LICENSE file at -// the root of the repo. - -package translator - -import ( - "bytes" - "fmt" - "io" - "time" - - "github.com/anthropics/anthropic-sdk-go" - "github.com/anthropics/anthropic-sdk-go/shared/constant" - - "github.com/envoyproxy/ai-gateway/internal/apischema/openai" - "github.com/envoyproxy/ai-gateway/internal/internalapi" - "github.com/envoyproxy/ai-gateway/internal/json" - "github.com/envoyproxy/ai-gateway/internal/metrics" - "github.com/envoyproxy/ai-gateway/internal/tracing/tracingapi" -) - -var sseEventPrefix = []byte("event: ") - -// streamingToolCall holds the state for a single tool call that is being streamed. -type streamingToolCall struct { - id string - name string - inputJSON string -} - -// anthropicStreamParser manages the stateful translation of an Anthropic SSE stream -// to an OpenAI-compatible SSE stream. -type anthropicStreamParser struct { - buffer bytes.Buffer - activeMessageID string - activeToolCalls map[int64]*streamingToolCall - toolIndex int64 - tokenUsage metrics.TokenUsage - stopReason anthropic.StopReason - requestModel internalapi.RequestModel - sentFirstChunk bool - created openai.JSONUNIXTime -} - -// newAnthropicStreamParser creates a new parser for a streaming request. 
-func newAnthropicStreamParser(requestModel string) *anthropicStreamParser { - toolIdx := int64(-1) - return &anthropicStreamParser{ - requestModel: requestModel, - activeToolCalls: make(map[int64]*streamingToolCall), - toolIndex: toolIdx, - } -} - -func (p *anthropicStreamParser) writeChunk(eventBlock []byte, buf *[]byte) error { - chunk, err := p.parseAndHandleEvent(eventBlock) - if err != nil { - return err - } - if chunk != nil { - err := serializeOpenAIChatCompletionChunk(chunk, buf) - if err != nil { - return err - } - } - return nil -} - -// Process reads from the Anthropic SSE stream, translates events to OpenAI chunks, -// and returns the mutations for Envoy. -func (p *anthropicStreamParser) Process(body io.Reader, endOfStream bool, span tracingapi.ChatCompletionSpan) ( - newHeaders []internalapi.Header, newBody []byte, tokenUsage metrics.TokenUsage, responseModel string, err error, -) { - newBody = make([]byte, 0) - _ = span // TODO: add support for streaming chunks in tracingapi. 
- responseModel = p.requestModel - if _, err = p.buffer.ReadFrom(body); err != nil { - err = fmt.Errorf("failed to read from stream body: %w", err) - return - } - - for { - eventBlock, remaining, found := bytes.Cut(p.buffer.Bytes(), []byte("\n\n")) - if !found { - break - } - - if err = p.writeChunk(eventBlock, &newBody); err != nil { - return - } - - p.buffer.Reset() - p.buffer.Write(remaining) - } - - if endOfStream && p.buffer.Len() > 0 { - finalEventBlock := p.buffer.Bytes() - p.buffer.Reset() - - if err = p.writeChunk(finalEventBlock, &newBody); err != nil { - return - } - } - - if endOfStream { - inputTokens, _ := p.tokenUsage.InputTokens() - outputTokens, _ := p.tokenUsage.OutputTokens() - p.tokenUsage.SetTotalTokens(inputTokens + outputTokens) - totalTokens, _ := p.tokenUsage.TotalTokens() - cachedTokens, _ := p.tokenUsage.CachedInputTokens() - cacheCreationTokens, _ := p.tokenUsage.CacheCreationInputTokens() - finalChunk := &openai.ChatCompletionResponseChunk{ - ID: p.activeMessageID, - Created: p.created, - Object: "chat.completion.chunk", - Choices: []openai.ChatCompletionResponseChunkChoice{}, - Usage: &openai.Usage{ - PromptTokens: int(inputTokens), - CompletionTokens: int(outputTokens), - TotalTokens: int(totalTokens), - PromptTokensDetails: &openai.PromptTokensDetails{ - CachedTokens: int(cachedTokens), - CacheCreationTokens: int(cacheCreationTokens), - }, - }, - Model: p.requestModel, - } - - // Add active tool calls to the final chunk. 
- var toolCalls []openai.ChatCompletionChunkChoiceDeltaToolCall - for toolIndex, tool := range p.activeToolCalls { - toolCalls = append(toolCalls, openai.ChatCompletionChunkChoiceDeltaToolCall{ - ID: &tool.id, - Type: openai.ChatCompletionMessageToolCallTypeFunction, - Function: openai.ChatCompletionMessageToolCallFunctionParam{ - Name: tool.name, - Arguments: tool.inputJSON, - }, - Index: toolIndex, - }) - } - - if len(toolCalls) > 0 { - delta := openai.ChatCompletionResponseChunkChoiceDelta{ - ToolCalls: toolCalls, - } - finalChunk.Choices = append(finalChunk.Choices, openai.ChatCompletionResponseChunkChoice{ - Delta: &delta, - }) - } - - if finalChunk.Usage.PromptTokens > 0 || finalChunk.Usage.CompletionTokens > 0 || len(finalChunk.Choices) > 0 { - err := serializeOpenAIChatCompletionChunk(finalChunk, &newBody) - if err != nil { - return nil, nil, metrics.TokenUsage{}, "", fmt.Errorf("failed to marshal final stream chunk: %w", err) - } - } - // Add the final [DONE] message to indicate the end of the stream. - newBody = append(newBody, sseDataPrefix...) - newBody = append(newBody, sseDoneMessage...) - newBody = append(newBody, '\n', '\n') - } - tokenUsage = p.tokenUsage - return -} - -func (p *anthropicStreamParser) parseAndHandleEvent(eventBlock []byte) (*openai.ChatCompletionResponseChunk, error) { - var eventType []byte - var eventData []byte - - lines := bytes.SplitSeq(eventBlock, []byte("\n")) - for line := range lines { - if after, ok := bytes.CutPrefix(line, sseEventPrefix); ok { - eventType = bytes.TrimSpace(after) - } else if after, ok := bytes.CutPrefix(line, sseDataPrefix); ok { - // This handles JSON data that might be split across multiple 'data:' lines - // by concatenating them (Anthropic's format). - data := bytes.TrimSpace(after) - eventData = append(eventData, data...) 
- } - } - - if len(eventType) > 0 && len(eventData) > 0 { - return p.handleAnthropicStreamEvent(eventType, eventData) - } - - return nil, nil -} - -func (p *anthropicStreamParser) handleAnthropicStreamEvent(eventType []byte, data []byte) (*openai.ChatCompletionResponseChunk, error) { - switch string(eventType) { - case string(constant.ValueOf[constant.MessageStart]()): - var event anthropic.MessageStartEvent - if err := json.Unmarshal(data, &event); err != nil { - return nil, fmt.Errorf("unmarshal message_start: %w", err) - } - p.activeMessageID = event.Message.ID - p.created = openai.JSONUNIXTime(time.Now()) - u := event.Message.Usage - usage := metrics.ExtractTokenUsageFromExplicitCaching( - u.InputTokens, - u.OutputTokens, - &u.CacheReadInputTokens, - &u.CacheCreationInputTokens, - ) - // For message_start, we store the initial usage but don't add to the accumulated - // The message_delta event will contain the final totals - if input, ok := usage.InputTokens(); ok { - p.tokenUsage.SetInputTokens(input) - } - if cached, ok := usage.CachedInputTokens(); ok { - p.tokenUsage.SetCachedInputTokens(cached) - } - if cacheCreation, ok := usage.CacheCreationInputTokens(); ok { - p.tokenUsage.SetCacheCreationInputTokens(cacheCreation) - } - - // reset the toolIndex for each message - p.toolIndex = -1 - return nil, nil - - case string(constant.ValueOf[constant.ContentBlockStart]()): - var event anthropic.ContentBlockStartEvent - if err := json.Unmarshal(data, &event); err != nil { - return nil, fmt.Errorf("failed to unmarshal content_block_start: %w", err) - } - if event.ContentBlock.Type == string(constant.ValueOf[constant.ToolUse]()) || event.ContentBlock.Type == string(constant.ValueOf[constant.ServerToolUse]()) { - p.toolIndex++ - var argsJSON string - // Check if the input field is provided directly in the start event. 
- if event.ContentBlock.Input != nil { - switch input := event.ContentBlock.Input.(type) { - case map[string]any: - // for case where "input":{}, skip adding it to arguments. - if len(input) > 0 { - argsBytes, err := json.Marshal(input) - if err != nil { - return nil, fmt.Errorf("failed to marshal tool use input: %w", err) - } - argsJSON = string(argsBytes) - } - default: - // although golang sdk defines type of Input to be any, - // python sdk requires the type of Input to be Dict[str, object]: - // https://github.com/anthropics/anthropic-sdk-python/blob/main/src/anthropic/types/tool_use_block.py#L14. - return nil, fmt.Errorf("unexpected tool use input type: %T", input) - } - } - - // Store the complete input JSON in our state. - p.activeToolCalls[p.toolIndex] = &streamingToolCall{ - id: event.ContentBlock.ID, - name: event.ContentBlock.Name, - inputJSON: argsJSON, - } - - delta := openai.ChatCompletionResponseChunkChoiceDelta{ - ToolCalls: []openai.ChatCompletionChunkChoiceDeltaToolCall{ - { - Index: p.toolIndex, - ID: &event.ContentBlock.ID, - Type: openai.ChatCompletionMessageToolCallTypeFunction, - Function: openai.ChatCompletionMessageToolCallFunctionParam{ - Name: event.ContentBlock.Name, - // Include the arguments if they are available. 
- Arguments: argsJSON, - }, - }, - }, - } - return p.constructOpenAIChatCompletionChunk(delta, ""), nil - } - // do not need to return an empty str for thinking start block - return nil, nil - - case string(constant.ValueOf[constant.MessageDelta]()): - var event anthropic.MessageDeltaEvent - if err := json.Unmarshal(data, &event); err != nil { - return nil, fmt.Errorf("unmarshal message_delta: %w", err) - } - u := event.Usage - usage := metrics.ExtractTokenUsageFromExplicitCaching( - u.InputTokens, - u.OutputTokens, - &u.CacheReadInputTokens, - &u.CacheCreationInputTokens, - ) - // For message_delta, accumulate the incremental output tokens - if output, ok := usage.OutputTokens(); ok { - p.tokenUsage.AddOutputTokens(output) - } - // Update input tokens to include read cache tokens from delta - if cached, ok := usage.CachedInputTokens(); ok { - p.tokenUsage.AddInputTokens(cached) - // Accumulate any additional cache tokens from delta - p.tokenUsage.AddCachedInputTokens(cached) - } - // Update input tokens to include write cache tokens from delta - if cached, ok := usage.CacheCreationInputTokens(); ok { - p.tokenUsage.AddInputTokens(cached) - // Accumulate any additional cache tokens from delta - p.tokenUsage.AddCacheCreationInputTokens(cached) - } - if event.Delta.StopReason != "" { - p.stopReason = event.Delta.StopReason - } - return nil, nil - - case string(constant.ValueOf[constant.ContentBlockDelta]()): - var event anthropic.ContentBlockDeltaEvent - if err := json.Unmarshal(data, &event); err != nil { - return nil, fmt.Errorf("unmarshal content_block_delta: %w", err) - } - switch event.Delta.Type { - case string(constant.ValueOf[constant.TextDelta]()): - delta := openai.ChatCompletionResponseChunkChoiceDelta{Content: &event.Delta.Text} - return p.constructOpenAIChatCompletionChunk(delta, ""), nil - - case string(constant.ValueOf[constant.ThinkingDelta]()): - // this should already include the case for redacted thinking: 
https://platform.claude.com/docs/en/build-with-claude/streaming#content-block-delta-types - - reasoningDelta := &openai.StreamReasoningContent{} - - // Map all relevant fields from the Bedrock delta to our flattened OpenAI delta struct. - if event.Delta.Thinking != "" { - reasoningDelta.Text = event.Delta.Thinking - } - if event.Delta.Signature != "" { - reasoningDelta.Signature = event.Delta.Signature - } - - delta := openai.ChatCompletionResponseChunkChoiceDelta{ - ReasoningContent: reasoningDelta, - } - return p.constructOpenAIChatCompletionChunk(delta, ""), nil - - case string(constant.ValueOf[constant.InputJSONDelta]()): - tool, ok := p.activeToolCalls[p.toolIndex] - if !ok { - return nil, fmt.Errorf("received input_json_delta for unknown tool at index %d", p.toolIndex) - } - delta := openai.ChatCompletionResponseChunkChoiceDelta{ - ToolCalls: []openai.ChatCompletionChunkChoiceDeltaToolCall{ - { - Index: p.toolIndex, - Function: openai.ChatCompletionMessageToolCallFunctionParam{ - Arguments: event.Delta.PartialJSON, - }, - }, - }, - } - tool.inputJSON += event.Delta.PartialJSON - return p.constructOpenAIChatCompletionChunk(delta, ""), nil - } - // Do not process redacted thinking stream? Did not find the source - - case string(constant.ValueOf[constant.ContentBlockStop]()): - // This event is for state cleanup, no chunk is sent. 
- var event anthropic.ContentBlockStopEvent - if err := json.Unmarshal(data, &event); err != nil { - return nil, fmt.Errorf("unmarshal content_block_stop: %w", err) - } - delete(p.activeToolCalls, p.toolIndex) - return nil, nil - - case string(constant.ValueOf[constant.MessageStop]()): - var event anthropic.MessageStopEvent - if err := json.Unmarshal(data, &event); err != nil { - return nil, fmt.Errorf("unmarshal message_stop: %w", err) - } - - if p.stopReason == "" { - p.stopReason = anthropic.StopReasonEndTurn - } - - finishReason, err := anthropicToOpenAIFinishReason(p.stopReason) - if err != nil { - return nil, err - } - return p.constructOpenAIChatCompletionChunk(openai.ChatCompletionResponseChunkChoiceDelta{}, finishReason), nil - - case string(constant.ValueOf[constant.Error]()): - var errEvent anthropic.ErrorResponse - if err := json.Unmarshal(data, &errEvent); err != nil { - return nil, fmt.Errorf("unparsable error event: %s", string(data)) - } - return nil, fmt.Errorf("anthropic stream error: %s - %s", errEvent.Error.Type, errEvent.Error.Message) - - case "ping": - // Per documentation, ping events can be ignored. - return nil, nil - } - return nil, nil -} - -// constructOpenAIChatCompletionChunk builds the stream chunk. -func (p *anthropicStreamParser) constructOpenAIChatCompletionChunk(delta openai.ChatCompletionResponseChunkChoiceDelta, finishReason openai.ChatCompletionChoicesFinishReason) *openai.ChatCompletionResponseChunk { - // Add the 'assistant' role to the very first chunk of the response. - if !p.sentFirstChunk { - // Only add the role if the delta actually contains content or a tool call. 
- if delta.Content != nil || len(delta.ToolCalls) > 0 { - delta.Role = openai.ChatMessageRoleAssistant - p.sentFirstChunk = true - } - } - - return &openai.ChatCompletionResponseChunk{ - ID: p.activeMessageID, - Created: p.created, - Object: "chat.completion.chunk", - Choices: []openai.ChatCompletionResponseChunkChoice{ - { - Delta: &delta, - FinishReason: finishReason, - }, - }, - Model: p.requestModel, - } -} diff --git a/internal/translator/openai_gcpanthropic_stream_test.go b/internal/translator/openai_gcpanthropic_stream_test.go deleted file mode 100644 index c10eafde7b..0000000000 --- a/internal/translator/openai_gcpanthropic_stream_test.go +++ /dev/null @@ -1,1031 +0,0 @@ -// Copyright Envoy AI Gateway Authors -// SPDX-License-Identifier: Apache-2.0 -// The full text of the Apache license is available in the LICENSE file at -// the root of the repo. - -package translator - -import ( - "fmt" - "strings" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "github.com/envoyproxy/ai-gateway/internal/apischema/openai" - "github.com/envoyproxy/ai-gateway/internal/json" - "github.com/envoyproxy/ai-gateway/internal/metrics" -) - -// mockErrorReader is a helper for testing io.Reader failures. 
-type mockErrorReader struct{} - -func (r *mockErrorReader) Read(_ []byte) (n int, err error) { - return 0, fmt.Errorf("mock reader error") -} - -func TestAnthropicStreamParser_ErrorHandling(t *testing.T) { - runStreamErrTest := func(t *testing.T, sseStream string, endOfStream bool) error { - openAIReq := &openai.ChatCompletionRequest{Stream: true, Model: "test-model", MaxTokens: new(int64)} - translator := NewChatCompletionOpenAIToGCPAnthropicTranslator("", "").(*openAIToGCPAnthropicTranslatorV1ChatCompletion) - _, _, err := translator.RequestBody(nil, openAIReq, false) - require.NoError(t, err) - - _, _, _, _, err = translator.ResponseBody(map[string]string{}, strings.NewReader(sseStream), endOfStream, nil) - return err - } - - tests := []struct { - name string - sseStream string - endOfStream bool - expectedError string - }{ - { - name: "malformed message_start event", - sseStream: "event: message_start\ndata: {invalid\n\n", - expectedError: "unmarshal message_start", - }, - { - name: "malformed content_block_start event", - sseStream: "event: content_block_start\ndata: {invalid\n\n", - expectedError: "failed to unmarshal content_block_start", - }, - { - name: "malformed content_block_delta event", - sseStream: "event: content_block_delta\ndata: {invalid\n\n", - expectedError: "unmarshal content_block_delta", - }, - { - name: "malformed content_block_stop event", - sseStream: "event: content_block_stop\ndata: {invalid\n\n", - expectedError: "unmarshal content_block_stop", - }, - { - name: "malformed error event data", - sseStream: "event: error\ndata: {invalid\n\n", - expectedError: "unparsable error event", - }, - { - name: "unknown stop reason", - endOfStream: true, - sseStream: `event: message_delta -data: {"type": "message_delta", "delta": {"stop_reason": "some_future_reason"}, "usage": {"output_tokens": 0}} - -event: message_stop -data: {"type": "message_stop"} -`, - expectedError: "received invalid stop reason", - }, - { - name: 
"malformed_final_event_block", - sseStream: "event: message_stop\ndata: {invalid", // No trailing \n\n. - endOfStream: true, - expectedError: "unmarshal message_stop", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - err := runStreamErrTest(t, tt.sseStream, tt.endOfStream) - require.Error(t, err) - require.Contains(t, err.Error(), tt.expectedError) - }) - } - - t.Run("body read error", func(t *testing.T) { - parser := newAnthropicStreamParser("test-model") - _, _, _, _, err := parser.Process(&mockErrorReader{}, false, nil) - require.Error(t, err) - require.Contains(t, err.Error(), "failed to read from stream body") - }) -} - -// TestResponseModel_GCPAnthropicStreaming tests that GCP Anthropic streaming returns the request model -// GCP Anthropic uses deterministic model mapping without virtualization -func TestResponseModel_GCPAnthropicStreaming(t *testing.T) { - modelName := "claude-sonnet-4@20250514" - sseStream := `event: message_start -data: {"type": "message_start", "message": {"id": "msg_1nZdL29xx5MUA1yADyHTEsnR8uuvGzszyY", "type": "message", "role": "assistant", "content": [], "model": "claude-sonnet-4@20250514", "stop_reason": null, "stop_sequence": null, "usage": {"input_tokens": 10, "output_tokens": 1}}} - -event: content_block_start -data: {"type": "content_block_start", "index": 0, "content_block": {"type": "text", "text": ""}} - -event: content_block_delta -data: {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": "Hello"}} - -event: content_block_stop -data: {"type": "content_block_stop", "index": 0} - -event: message_delta -data: {"type": "message_delta", "delta": {"stop_reason": "end_turn", "stop_sequence":null}, "usage": {"output_tokens": 5}} - -event: message_stop -data: {"type": "message_stop"} - -` - openAIReq := &openai.ChatCompletionRequest{ - Stream: true, - Model: modelName, // Use the actual model name from documentation - MaxTokens: new(int64), - } - - translator := 
NewChatCompletionOpenAIToGCPAnthropicTranslator("", "").(*openAIToGCPAnthropicTranslatorV1ChatCompletion) - _, _, err := translator.RequestBody(nil, openAIReq, false) - require.NoError(t, err) - - // Test streaming response - GCP Anthropic doesn't return model in response, uses request model - _, _, tokenUsage, responseModel, err := translator.ResponseBody(map[string]string{}, strings.NewReader(sseStream), true, nil) - require.NoError(t, err) - require.Equal(t, modelName, responseModel) // Returns the request model since no virtualization - inputTokens, ok := tokenUsage.InputTokens() - require.True(t, ok) - require.Equal(t, uint32(10), inputTokens) - outputTokens, ok := tokenUsage.OutputTokens() - require.True(t, ok) - require.Equal(t, uint32(5), outputTokens) -} - -func TestOpenAIToGCPAnthropicTranslatorV1ChatCompletion_ResponseBody_Streaming(t *testing.T) { - t.Run("handles simple text stream", func(t *testing.T) { - sseStream := ` -event: message_start -data: {"type": "message_start", "message": {"id": "msg_1nZdL29xx5MUA1yADyHTEsnR8uuvGzszyY", "type": "message", "role": "assistant", "content": [], "model": "claude-opus-4-20250514", "stop_reason": null, "stop_sequence": null, "usage": {"input_tokens": 25, "output_tokens": 1}}} - -event: content_block_start -data: {"type": "content_block_start", "index": 0, "content_block": {"type": "text", "text": ""}} - -event: ping -data: {"type": "ping"} - -event: content_block_delta -data: {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": "Hello"}} - -event: content_block_delta -data: {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": "!"}} - -event: content_block_stop -data: {"type": "content_block_stop", "index": 0} - -event: message_delta -data: {"type": "message_delta", "delta": {"stop_reason": "end_turn", "stop_sequence":null}, "usage": {"output_tokens": 15}} - -event: message_stop -data: {"type": "message_stop"} - -` - openAIReq := 
&openai.ChatCompletionRequest{ - Stream: true, - Model: "test-model", - MaxTokens: new(int64), - } - translator := NewChatCompletionOpenAIToGCPAnthropicTranslator("", "").(*openAIToGCPAnthropicTranslatorV1ChatCompletion) - _, _, err := translator.RequestBody(nil, openAIReq, false) - require.NoError(t, err) - - _, bm, _, _, err := translator.ResponseBody(map[string]string{}, strings.NewReader(sseStream), true, nil) - require.NoError(t, err) - require.NotNil(t, bm) - - bodyStr := string(bm) - require.Contains(t, bodyStr, `"content":"Hello"`) - require.Contains(t, bodyStr, `"finish_reason":"stop"`) - require.Contains(t, bodyStr, `"prompt_tokens":25`) - require.Contains(t, bodyStr, `"completion_tokens":15`) - require.Contains(t, bodyStr, string(sseDoneMessage)) - }) - - t.Run("handles text and tool use stream", func(t *testing.T) { - sseStream := `event: message_start -data: {"type":"message_start","message":{"id":"msg_014p7gG3wDgGV9EUtLvnow3U","type":"message","role":"assistant","model":"claude-opus-4-20250514","stop_sequence":null,"usage":{"input_tokens":472,"output_tokens":2},"content":[],"stop_reason":null}} - -event: content_block_start -data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}} - -event: ping -data: {"type": "ping"} - -event: content_block_delta -data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Okay"}} - -event: content_block_delta -data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":","}} - -event: content_block_delta -data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" let"}} - -event: content_block_delta -data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"'s"}} - -event: content_block_delta -data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" check"}} - -event: content_block_delta -data: 
{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" the"}} - -event: content_block_delta -data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" weather"}} - -event: content_block_delta -data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" for"}} - -event: content_block_delta -data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" San"}} - -event: content_block_delta -data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" Francisco"}} - -event: content_block_delta -data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":","}} - -event: content_block_delta -data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" CA"}} - -event: content_block_delta -data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":":"}} - -event: content_block_stop -data: {"type":"content_block_stop","index":0} - -event: content_block_start -data: {"type":"content_block_start","index":1,"content_block":{"type":"tool_use","id":"toolu_01T1x1fJ34qAmk2tNTrN7Up6","name":"get_weather","input":{}}} - -event: content_block_delta -data: {"type":"content_block_delta","index":1,"delta":{"type":"input_json_delta","partial_json":""}} - -event: content_block_delta -data: {"type":"content_block_delta","index":1,"delta":{"type":"input_json_delta","partial_json":"{\"location\":"}} - -event: content_block_delta -data: {"type":"content_block_delta","index":1,"delta":{"type":"input_json_delta","partial_json":" \"San"}} - -event: content_block_delta -data: {"type":"content_block_delta","index":1,"delta":{"type":"input_json_delta","partial_json":" Francisc"}} - -event: content_block_delta -data: {"type":"content_block_delta","index":1,"delta":{"type":"input_json_delta","partial_json":"o,"}} - -event: content_block_delta -data: 
{"type":"content_block_delta","index":1,"delta":{"type":"input_json_delta","partial_json":" CA\""}} - -event: content_block_delta -data: {"type":"content_block_delta","index":1,"delta":{"type":"input_json_delta","partial_json":", "}} - -event: content_block_delta -data: {"type":"content_block_delta","index":1,"delta":{"type":"input_json_delta","partial_json":"\"unit\": \"fah"}} - -event: content_block_delta -data: {"type":"content_block_delta","index":1,"delta":{"type":"input_json_delta","partial_json":"renheit\"}"}} - -event: content_block_stop -data: {"type":"content_block_stop","index":1} - -event: message_delta -data: {"type":"message_delta","delta":{"stop_reason":"tool_use","stop_sequence":null},"usage":{"output_tokens":89}} - -event: message_stop -data: {"type":"message_stop"} -` - - openAIReq := &openai.ChatCompletionRequest{Stream: true, Model: "test-model", MaxTokens: new(int64)} - translator := NewChatCompletionOpenAIToGCPAnthropicTranslator("", "").(*openAIToGCPAnthropicTranslatorV1ChatCompletion) - _, _, err := translator.RequestBody(nil, openAIReq, false) - require.NoError(t, err) - - _, bm, _, _, err := translator.ResponseBody(map[string]string{}, strings.NewReader(sseStream), true, nil) - require.NoError(t, err) - require.NotNil(t, bm) - bodyStr := string(bm) - - // Parse all streaming events to verify the event flow - var chunks []openai.ChatCompletionResponseChunk - var textChunks []string - var toolCallStarted bool - var hasRole bool - var toolCallCompleted bool - var finalFinishReason openai.ChatCompletionChoicesFinishReason - var finalUsageChunk *openai.ChatCompletionResponseChunk - var toolCallChunks []string // Track partial JSON chunks - - lines := strings.SplitSeq(strings.TrimSpace(bodyStr), "\n\n") - for line := range lines { - if !strings.HasPrefix(line, "data: ") || strings.Contains(line, "[DONE]") { - continue - } - jsonBody := strings.TrimPrefix(line, "data: ") - - var chunk openai.ChatCompletionResponseChunk - err = 
json.Unmarshal([]byte(jsonBody), &chunk) - require.NoError(t, err, "Failed to unmarshal chunk: %s", jsonBody) - chunks = append(chunks, chunk) - - // Check if this is the final usage chunk - if strings.Contains(jsonBody, `"usage"`) { - finalUsageChunk = &chunk - } - - if len(chunk.Choices) > 0 { - choice := chunk.Choices[0] - // Check for role in first content chunk - if choice.Delta != nil && choice.Delta.Content != nil && *choice.Delta.Content != "" && !hasRole { - require.NotNil(t, choice.Delta.Role, "Role should be present on first content chunk") - require.Equal(t, openai.ChatMessageRoleAssistant, choice.Delta.Role) - hasRole = true - } - - // Collect text content - if choice.Delta != nil && choice.Delta.Content != nil { - textChunks = append(textChunks, *choice.Delta.Content) - } - - // Check tool calls - start and accumulate partial JSON - if choice.Delta != nil && len(choice.Delta.ToolCalls) > 0 { - toolCall := choice.Delta.ToolCalls[0] - - // Check tool call initiation - if toolCall.Function.Name == "get_weather" && !toolCallStarted { - require.Equal(t, "get_weather", toolCall.Function.Name) - require.NotNil(t, toolCall.ID) - require.Equal(t, "toolu_01T1x1fJ34qAmk2tNTrN7Up6", *toolCall.ID) - require.Equal(t, int64(0), toolCall.Index, "Tool call should be at index 1 (after text content at index 0)") - toolCallStarted = true - } - - // Accumulate partial JSON arguments - these should also be at index 1 - if toolCall.Function.Arguments != "" { - toolCallChunks = append(toolCallChunks, toolCall.Function.Arguments) - - // Verify the index remains consistent at 1 for all tool call chunks - require.Equal(t, int64(0), toolCall.Index, "Tool call argument chunks should be at index 1") - } - } - - // Track finish reason - if choice.FinishReason != "" { - finalFinishReason = choice.FinishReason - if finalFinishReason == "tool_calls" { - toolCallCompleted = true - } - } - } - } - - // Check the final usage chunk for accumulated tool call arguments - if finalUsageChunk 
!= nil { - require.Equal(t, 472, finalUsageChunk.Usage.PromptTokens) - require.Equal(t, 89, finalUsageChunk.Usage.CompletionTokens) - } - - // Verify partial JSON accumulation in streaming chunks - if len(toolCallChunks) > 0 { - // Verify we got multiple partial JSON chunks during streaming - require.GreaterOrEqual(t, len(toolCallChunks), 2, "Should receive multiple partial JSON chunks for tool arguments") - - // Verify some expected partial content appears in the chunks - fullPartialJSON := strings.Join(toolCallChunks, "") - require.Contains(t, fullPartialJSON, `"location":`, "Partial JSON should contain location field") - require.Contains(t, fullPartialJSON, `"unit":`, "Partial JSON should contain unit field") - require.Contains(t, fullPartialJSON, "San Francisco", "Partial JSON should contain location value") - require.Contains(t, fullPartialJSON, "fahrenheit", "Partial JSON should contain unit value") - } - - // Verify streaming event assertions - require.GreaterOrEqual(t, len(chunks), 5, "Should have multiple streaming chunks") - require.True(t, hasRole, "Should have role in first content chunk") - require.True(t, toolCallStarted, "Tool call should have been initiated") - require.True(t, toolCallCompleted, "Tool call should have complete arguments in final chunk") - require.Equal(t, openai.ChatCompletionChoicesFinishReasonToolCalls, finalFinishReason, "Final finish reason should be tool_calls") - - // Verify text content was streamed correctly - fullText := strings.Join(textChunks, "") - require.Contains(t, fullText, "Okay, let's check the weather for San Francisco, CA:") - require.GreaterOrEqual(t, len(textChunks), 3, "Text should be streamed in multiple chunks") - - // Original aggregate response assertions - require.Contains(t, bodyStr, `"content":"Okay"`) - require.Contains(t, bodyStr, `"name":"get_weather"`) - require.Contains(t, bodyStr, "\"arguments\":\"{\\\"location\\\":") - require.NotContains(t, bodyStr, "\"arguments\":\"{}\"") - require.Contains(t, 
bodyStr, "renheit\\\"}\"") - require.Contains(t, bodyStr, `"finish_reason":"tool_calls"`) - require.Contains(t, bodyStr, string(sseDoneMessage)) - }) - - t.Run("handles streaming with web search tool use", func(t *testing.T) { - sseStream := `event: message_start -data: {"type":"message_start","message":{"id":"msg_01G...","type":"message","role":"assistant","usage":{"input_tokens":2679,"output_tokens":3}}} - -event: content_block_start -data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}} - -event: content_block_delta -data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"I'll check"}} - -event: content_block_delta -data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" the current weather in New York City for you"}} - -event: ping -data: {"type": "ping"} - -event: content_block_delta -data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"."}} - -event: content_block_stop -data: {"type":"content_block_stop","index":0} - -event: content_block_start -data: {"type":"content_block_start","index":1,"content_block":{"type":"server_tool_use","id":"srvtoolu_014hJH82Qum7Td6UV8gDXThB","name":"web_search","input":{}}} - -event: content_block_delta -data: {"type":"content_block_delta","index":1,"delta":{"type":"input_json_delta","partial_json":"{\"query\":\"weather NYC today\"}"}} - -event: content_block_stop -data: {"type":"content_block_stop","index":1} - -event: content_block_start -data: {"type":"content_block_start","index":2,"content_block":{"type":"web_search_tool_result","tool_use_id":"srvtoolu_014hJH82Qum7Td6UV8gDXThB","content":[{"type":"web_search_result","title":"Weather in New York City in May 2025 (New York)","url":"https://world-weather.info/forecast/usa/new_york/may-2025/","page_age":null}]}} - -event: content_block_stop -data: {"type":"content_block_stop","index":2} - -event: content_block_start -data: 
{"type":"content_block_start","index":3,"content_block":{"type":"text","text":""}} - -event: content_block_delta -data: {"type":"content_block_delta","index":3,"delta":{"type":"text_delta","text":"Here's the current weather information for New York"}} - -event: content_block_delta -data: {"type":"content_block_delta","index":3,"delta":{"type":"text_delta","text":" City."}} - -event: message_delta -data: {"type":"message_delta","delta":{"stop_reason":"end_turn"},"usage":{"output_tokens":510}} - -event: message_stop -data: {"type":"message_stop"} -` - openAIReq := &openai.ChatCompletionRequest{Stream: true, Model: "test-model", MaxTokens: new(int64)} - translator := NewChatCompletionOpenAIToGCPAnthropicTranslator("", "").(*openAIToGCPAnthropicTranslatorV1ChatCompletion) - _, _, err := translator.RequestBody(nil, openAIReq, false) - require.NoError(t, err) - - _, bm, _, _, err := translator.ResponseBody(map[string]string{}, strings.NewReader(sseStream), true, nil) - require.NoError(t, err) - require.NotNil(t, bm) - bodyStr := string(bm) - - require.Contains(t, bodyStr, `"content":"I'll check"`) - require.Contains(t, bodyStr, `"content":" the current weather in New York City for you"`) - require.Contains(t, bodyStr, `"name":"web_search"`) - require.Contains(t, bodyStr, "\"arguments\":\"{\\\"query\\\":\\\"weather NYC today\\\"}\"") - require.NotContains(t, bodyStr, "\"arguments\":\"{}\"") - require.Contains(t, bodyStr, `"content":"Here's the current weather information for New York"`) - require.Contains(t, bodyStr, `"finish_reason":"stop"`) - require.Contains(t, bodyStr, string(sseDoneMessage)) - }) - - t.Run("handles unterminated tool call at end of stream", func(t *testing.T) { - // This stream starts a tool call but ends without a content_block_stop or message_stop. 
- sseStream := `event: message_start -data: {"type":"message_start","message":{"id":"msg_123","usage":{"input_tokens":10}}} - -event: content_block_start -data: {"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"tool_abc","name":"get_weather"}} - -event: content_block_delta -data: {"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":"{\"location\": \"SF\"}"}} -` - openAIReq := &openai.ChatCompletionRequest{Stream: true, Model: "test-model", MaxTokens: new(int64)} - translator := NewChatCompletionOpenAIToGCPAnthropicTranslator("", "").(*openAIToGCPAnthropicTranslatorV1ChatCompletion) - _, _, err := translator.RequestBody(nil, openAIReq, false) - require.NoError(t, err) - - _, bm, _, _, err := translator.ResponseBody(map[string]string{}, strings.NewReader(sseStream), true, nil) - require.NoError(t, err) - require.NotNil(t, bm) - bodyStr := string(bm) - - var finalToolCallChunk openai.ChatCompletionResponseChunk - - // Split the response into individual SSE messages and find the final data chunk. - lines := strings.SplitSeq(strings.TrimSpace(bodyStr), "\n\n") - for line := range lines { - if !strings.HasPrefix(line, "data: ") || strings.HasPrefix(line, "data: [DONE]") { - continue - } - jsonBody := strings.TrimPrefix(line, "data: ") - // The final chunk with the accumulated tool call is the only one with a "usage" field. 
- if strings.Contains(jsonBody, `"usage"`) { - err := json.Unmarshal([]byte(jsonBody), &finalToolCallChunk) - require.NoError(t, err, "Failed to unmarshal final tool call chunk") - break - } - } - - require.NotEmpty(t, finalToolCallChunk.Choices, "Final chunk should have choices") - require.NotNil(t, finalToolCallChunk.Choices[0].Delta.ToolCalls, "Final chunk should have tool calls") - - finalToolCall := finalToolCallChunk.Choices[0].Delta.ToolCalls[0] - require.Equal(t, "tool_abc", *finalToolCall.ID) - require.Equal(t, "get_weather", finalToolCall.Function.Name) - require.JSONEq(t, `{"location": "SF"}`, finalToolCall.Function.Arguments) - }) - t.Run("handles thinking and tool use stream", func(t *testing.T) { - sseStream := ` -event: message_start -data: {"type": "message_start", "message": {"id": "msg_123", "type": "message", "role": "assistant", "usage": {"input_tokens": 50, "output_tokens": 1}}} - -event: content_block_start -data: {"type": "content_block_start", "index": 0, "content_block": {"type": "thinking", "name": "web_searcher"}} - -event: content_block_delta -data: {"type": "content_block_delta", "index": 0, "delta": {"type": "thinking_delta", "thinking": "Searching for information..."}} - -event: content_block_stop -data: {"type": "content_block_stop", "index": 0} - -event: content_block_start -data: {"type": "content_block_start", "index": 1, "content_block": {"type": "tool_use", "id": "toolu_abc123", "name": "get_weather", "input": {"location": "San Francisco, CA"}}} - -event: message_delta -data: {"type": "message_delta", "delta": {"stop_reason": "tool_use"}, "usage": {"output_tokens": 35}} - -event: message_stop -data: {"type": "message_stop"} -` - openAIReq := &openai.ChatCompletionRequest{Stream: true, Model: "test-model", MaxTokens: new(int64)} - translator := NewChatCompletionOpenAIToGCPAnthropicTranslator("", "").(*openAIToGCPAnthropicTranslatorV1ChatCompletion) - _, _, err := translator.RequestBody(nil, openAIReq, false) - require.NoError(t, 
err) - - _, bm, _, _, err := translator.ResponseBody(map[string]string{}, strings.NewReader(sseStream), true, nil) - require.NoError(t, err) - require.NotNil(t, bm) - bodyStr := string(bm) - - var contentDeltas []string - var reasoningTexts []string - var foundToolCallWithArgs bool - var finalFinishReason openai.ChatCompletionChoicesFinishReason - - lines := strings.SplitSeq(strings.TrimSpace(bodyStr), "\n\n") - for line := range lines { - if !strings.HasPrefix(line, "data: ") || strings.Contains(line, "[DONE]") { - continue - } - jsonBody := strings.TrimPrefix(line, "data: ") - - var chunk openai.ChatCompletionResponseChunk - err = json.Unmarshal([]byte(jsonBody), &chunk) - require.NoError(t, err, "Failed to unmarshal chunk: %s", jsonBody) - - if len(chunk.Choices) == 0 { - continue - } - choice := chunk.Choices[0] - if choice.Delta != nil { - if choice.Delta.Content != nil { - contentDeltas = append(contentDeltas, *choice.Delta.Content) - } - if choice.Delta.ReasoningContent != nil { - if choice.Delta.ReasoningContent.Text != "" { - reasoningTexts = append(reasoningTexts, choice.Delta.ReasoningContent.Text) - } - } - if len(choice.Delta.ToolCalls) > 0 { - toolCall := choice.Delta.ToolCalls[0] - // Check if this is the tool chunk that contains the arguments. 
- if toolCall.Function.Arguments != "" { - expectedArgs := `{"location":"San Francisco, CA"}` - assert.JSONEq(t, expectedArgs, toolCall.Function.Arguments, "Tool call arguments do not match") - assert.Equal(t, "get_weather", toolCall.Function.Name) - assert.Equal(t, "toolu_abc123", *toolCall.ID) - foundToolCallWithArgs = true - } else { - // This should be the initial tool call chunk with empty arguments since input is provided upfront - assert.Equal(t, "get_weather", toolCall.Function.Name) - assert.Equal(t, "toolu_abc123", *toolCall.ID) - } - } - } - if choice.FinishReason != "" { - finalFinishReason = choice.FinishReason - } - } - - fullReasoning := strings.Join(reasoningTexts, "") - - assert.Contains(t, fullReasoning, "Searching for information...") - require.True(t, foundToolCallWithArgs, "Did not find a tool call chunk with arguments to assert against") - assert.Equal(t, openai.ChatCompletionChoicesFinishReasonToolCalls, finalFinishReason, "Final finish reason should be 'tool_calls'") - }) - - t.Run("handles thinking delta stream with text only", func(t *testing.T) { - sseStream := ` -event: message_start -data: {"type": "message_start", "message": {"id": "msg_thinking_1", "type": "message", "role": "assistant", "usage": {"input_tokens": 20, "output_tokens": 1}}} - -event: content_block_start -data: {"type": "content_block_start", "index": 0, "content_block": {"type": "thinking"}} - -event: content_block_delta -data: {"type": "content_block_delta", "index": 0, "delta": {"type": "thinking_delta", "thinking": "Let me think about this problem step by step."}} - -event: content_block_delta -data: {"type": "content_block_delta", "index": 0, "delta": {"type": "thinking_delta", "thinking": " First, I need to understand the requirements."}} - -event: content_block_stop -data: {"type": "content_block_stop", "index": 0} - -event: message_delta -data: {"type": "message_delta", "delta": {"stop_reason": "end_turn"}, "usage": {"output_tokens": 15}} - -event: message_stop 
-data: {"type": "message_stop"} -` - openAIReq := &openai.ChatCompletionRequest{Stream: true, Model: "test-model", MaxTokens: new(int64)} - translator := NewChatCompletionOpenAIToGCPAnthropicTranslator("", "").(*openAIToGCPAnthropicTranslatorV1ChatCompletion) - _, _, err := translator.RequestBody(nil, openAIReq, false) - require.NoError(t, err) - - _, bm, _, _, err := translator.ResponseBody(map[string]string{}, strings.NewReader(sseStream), true, nil) - require.NoError(t, err) - require.NotNil(t, bm) - bodyStr := string(bm) - - var reasoningTexts []string - var foundFinishReason bool - - lines := strings.SplitSeq(strings.TrimSpace(bodyStr), "\n\n") - for line := range lines { - if !strings.HasPrefix(line, "data: ") || strings.Contains(line, "[DONE]") { - continue - } - jsonBody := strings.TrimPrefix(line, "data: ") - - var chunk openai.ChatCompletionResponseChunk - err = json.Unmarshal([]byte(jsonBody), &chunk) - require.NoError(t, err, "Failed to unmarshal chunk: %s", jsonBody) - - if len(chunk.Choices) == 0 { - continue - } - choice := chunk.Choices[0] - if choice.Delta != nil && choice.Delta.ReasoningContent != nil { - if choice.Delta.ReasoningContent.Text != "" { - reasoningTexts = append(reasoningTexts, choice.Delta.ReasoningContent.Text) - } - } - if choice.FinishReason == openai.ChatCompletionChoicesFinishReasonStop { - foundFinishReason = true - } - } - - fullReasoning := strings.Join(reasoningTexts, "") - assert.Contains(t, fullReasoning, "Let me think about this problem step by step.") - assert.Contains(t, fullReasoning, " First, I need to understand the requirements.") - require.True(t, foundFinishReason, "Should find stop finish reason") - }) - - t.Run("handles thinking delta stream with text and signature", func(t *testing.T) { - sseStream := ` -event: message_start -data: {"type": "message_start", "message": {"id": "msg_thinking_2", "type": "message", "role": "assistant", "usage": {"input_tokens": 25, "output_tokens": 1}}} - -event: 
content_block_start -data: {"type": "content_block_start", "index": 0, "content_block": {"type": "thinking"}} - -event: content_block_delta -data: {"type": "content_block_delta", "index": 0, "delta": {"type": "thinking_delta", "thinking": "Processing request...", "signature": "sig_abc123"}} - -event: content_block_delta -data: {"type": "content_block_delta", "index": 0, "delta": {"type": "thinking_delta", "thinking": " Analyzing data...", "signature": "sig_def456"}} - -event: content_block_stop -data: {"type": "content_block_stop", "index": 0} - -event: message_delta -data: {"type": "message_delta", "delta": {"stop_reason": "end_turn"}, "usage": {"output_tokens": 20}} - -event: message_stop -data: {"type": "message_stop"} -` - openAIReq := &openai.ChatCompletionRequest{Stream: true, Model: "test-model", MaxTokens: new(int64)} - translator := NewChatCompletionOpenAIToGCPAnthropicTranslator("", "").(*openAIToGCPAnthropicTranslatorV1ChatCompletion) - _, _, err := translator.RequestBody(nil, openAIReq, false) - require.NoError(t, err) - - _, bm, _, _, err := translator.ResponseBody(map[string]string{}, strings.NewReader(sseStream), true, nil) - require.NoError(t, err) - require.NotNil(t, bm) - bodyStr := string(bm) - - var reasoningTexts []string - var signatures []string - var foundFinishReason bool - - lines := strings.SplitSeq(strings.TrimSpace(bodyStr), "\n\n") - for line := range lines { - if !strings.HasPrefix(line, "data: ") || strings.Contains(line, "[DONE]") { - continue - } - jsonBody := strings.TrimPrefix(line, "data: ") - - var chunk openai.ChatCompletionResponseChunk - err = json.Unmarshal([]byte(jsonBody), &chunk) - require.NoError(t, err, "Failed to unmarshal chunk: %s", jsonBody) - - if len(chunk.Choices) == 0 { - continue - } - choice := chunk.Choices[0] - if choice.Delta != nil && choice.Delta.ReasoningContent != nil { - if choice.Delta.ReasoningContent.Text != "" { - reasoningTexts = append(reasoningTexts, choice.Delta.ReasoningContent.Text) - } - if 
choice.Delta.ReasoningContent.Signature != "" { - signatures = append(signatures, choice.Delta.ReasoningContent.Signature) - } - } - if choice.FinishReason == openai.ChatCompletionChoicesFinishReasonStop { - foundFinishReason = true - } - } - - fullReasoning := strings.Join(reasoningTexts, "") - assert.Contains(t, fullReasoning, "Processing request...") - assert.Contains(t, fullReasoning, " Analyzing data...") - - allSignatures := strings.Join(signatures, ",") - assert.Contains(t, allSignatures, "sig_abc123") - assert.Contains(t, allSignatures, "sig_def456") - - require.True(t, foundFinishReason, "Should find stop finish reason") - }) -} - -func TestAnthropicStreamParser_EventTypes(t *testing.T) { - runStreamTest := func(t *testing.T, sseStream string, endOfStream bool) ([]byte, metrics.TokenUsage, error) { - openAIReq := &openai.ChatCompletionRequest{Stream: true, Model: "test-model", MaxTokens: new(int64)} - translator := NewChatCompletionOpenAIToGCPAnthropicTranslator("", "").(*openAIToGCPAnthropicTranslatorV1ChatCompletion) - _, _, err := translator.RequestBody(nil, openAIReq, false) - require.NoError(t, err) - - _, bm, tokenUsage, _, err := translator.ResponseBody(map[string]string{}, strings.NewReader(sseStream), endOfStream, nil) - return bm, tokenUsage, err - } - - t.Run("handles message_start event", func(t *testing.T) { - sseStream := `event: message_start -data: {"type": "message_start", "message": {"id": "msg_123", "usage": {"input_tokens": 15}}} - -` - bm, _, err := runStreamTest(t, sseStream, false) - require.NoError(t, err) - assert.Empty(t, string(bm), "message_start should produce an empty chunk") - }) - - t.Run("handles content_block events for tool use", func(t *testing.T) { - sseStream := `event: message_start -data: {"type":"message_start","message":{"id":"msg_123","usage":{"input_tokens":10}}} - -event: content_block_start -data: {"type": "content_block_start", "index": 0, "content_block": {"type": "tool_use", "id": "tool_abc", "name": 
"get_weather", "input":{}}} - -event: content_block_delta -data: {"type": "content_block_delta", "index": 0, "delta": {"type": "input_json_delta", "partial_json": "{\"location\": \"SF\"}"}} - -event: content_block_stop -data: {"type": "content_block_stop", "index": 0} - -` - bm, _, err := runStreamTest(t, sseStream, false) - require.NoError(t, err) - require.NotNil(t, bm) - bodyStr := string(bm) - - // 1. Split the stream into individual data chunks - // and remove the "data: " prefix. - var chunks []openai.ChatCompletionResponseChunk - lines := strings.SplitSeq(strings.TrimSpace(bodyStr), "\n\n") - for line := range lines { - if !strings.HasPrefix(line, "data: ") { - continue - } - jsonBody := strings.TrimPrefix(line, "data: ") - - var chunk openai.ChatCompletionResponseChunk - err = json.Unmarshal([]byte(jsonBody), &chunk) - require.NoError(t, err, "Failed to unmarshal chunk: %s", jsonBody) - chunks = append(chunks, chunk) - } - - // 2. Inspect the Go structs directly. - require.Len(t, chunks, 2, "Expected two data chunks for this tool call stream") - - // Check the first chunk (the tool call initiation). - firstChunk := chunks[0] - require.NotNil(t, firstChunk.Choices[0].Delta.ToolCalls) - require.Equal(t, "tool_abc", *firstChunk.Choices[0].Delta.ToolCalls[0].ID) - require.Equal(t, "get_weather", firstChunk.Choices[0].Delta.ToolCalls[0].Function.Name) - // With empty input, arguments should be empty string, not "{}" - require.Empty(t, firstChunk.Choices[0].Delta.ToolCalls[0].Function.Arguments) - - // Check the second chunk (the arguments delta). - secondChunk := chunks[1] - require.NotNil(t, secondChunk.Choices[0].Delta.ToolCalls) - argumentsJSON := secondChunk.Choices[0].Delta.ToolCalls[0].Function.Arguments - - // 3. Unmarshal the arguments string to verify its contents. 
- var args map[string]string - err = json.Unmarshal([]byte(argumentsJSON), &args) - require.NoError(t, err) - require.Equal(t, "SF", args["location"]) - }) - - t.Run("handles ping event", func(t *testing.T) { - sseStream := `event: ping -data: {"type": "ping"} - -` - bm, _, err := runStreamTest(t, sseStream, false) - require.NoError(t, err) - require.Empty(t, bm, "ping should produce an empty chunk") - }) - - t.Run("handles error event", func(t *testing.T) { - sseStream := `event: error -data: {"type": "error", "error": {"type": "overloaded_error", "message": "Overloaded"}} - -` - _, _, err := runStreamTest(t, sseStream, false) - require.Error(t, err) - require.Contains(t, err.Error(), "anthropic stream error: overloaded_error - Overloaded") - }) - - t.Run("gracefully handles unknown event types", func(t *testing.T) { - sseStream := `event: future_event_type -data: {"some_new_data": "value"} - -` - bm, _, err := runStreamTest(t, sseStream, false) - require.NoError(t, err) - require.Empty(t, bm, "unknown events should be ignored and produce an empty chunk") - }) - - t.Run("handles message_stop event", func(t *testing.T) { - sseStream := `event: message_delta -data: {"type": "message_delta", "delta": {"stop_reason": "max_tokens"}, "usage": {"output_tokens": 1}} - -event: message_stop -data: {"type": "message_stop"} - -` - bm, _, err := runStreamTest(t, sseStream, false) - require.NoError(t, err) - require.NotNil(t, bm) - require.Contains(t, string(bm), `"finish_reason":"length"`) - }) - - t.Run("handles chunked input_json_delta for tool use", func(t *testing.T) { - sseStream := `event: message_start -data: {"type":"message_start","message":{"id":"msg_123","usage":{"input_tokens":10}}} - -event: content_block_start -data: {"type": "content_block_start", "index": 0, "content_block": {"type": "tool_use", "id": "tool_123", "name": "get_weather"}} - -event: content_block_delta -data: {"type": "content_block_delta","index": 0,"delta": {"type": 
"input_json_delta","partial_json": "{\"location\": \"San Fra"}} - -event: content_block_delta -data: {"type": "content_block_delta","index": 0,"delta": {"type": "input_json_delta","partial_json": "ncisco\"}"}} - -event: content_block_stop -data: {"type": "content_block_stop", "index": 0} -` - bm, _, err := runStreamTest(t, sseStream, false) - require.NoError(t, err) - require.NotNil(t, bm) - bodyStr := string(bm) - - // 1. Unmarshal all the chunks from the stream response. - var chunks []openai.ChatCompletionResponseChunk - lines := strings.SplitSeq(strings.TrimSpace(bodyStr), "\n\n") - for line := range lines { - if !strings.HasPrefix(line, "data: ") { - continue - } - jsonBody := strings.TrimPrefix(line, "data: ") - - var chunk openai.ChatCompletionResponseChunk - err := json.Unmarshal([]byte(jsonBody), &chunk) - require.NoError(t, err, "Failed to unmarshal chunk: %s", jsonBody) - chunks = append(chunks, chunk) - } - - // 2. We expect 3 chunks: start, delta part 1, delta part 2. - require.Len(t, chunks, 3, "Expected three data chunks for this stream") - - // 3. Verify the contents of each relevant chunk. - - // Chunk 1: Tool call start. - chunk1ToolCalls := chunks[0].Choices[0].Delta.ToolCalls - require.NotNil(t, chunk1ToolCalls) - require.Equal(t, "get_weather", chunk1ToolCalls[0].Function.Name) - - // Chunk 2: First part of the arguments. - chunk2Args := chunks[1].Choices[0].Delta.ToolCalls[0].Function.Arguments - require.Equal(t, `{"location": "San Fra`, chunk2Args) //nolint:testifylint - - // Chunk 3: Second part of the arguments. 
- chunk3Args := chunks[2].Choices[0].Delta.ToolCalls[0].Function.Arguments - require.Equal(t, `ncisco"}`, chunk3Args) - }) - t.Run("sends role on first chunk", func(t *testing.T) { - sseStream := `event: message_start -data: {"type":"message_start","message":{"id":"msg_123","usage":{"input_tokens":10}}} - -event: content_block_delta -data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hello"}} -` - // Set endOfStream to true to ensure all events in the buffer are processed. - bm, _, err := runStreamTest(t, sseStream, true) - require.NoError(t, err) - require.NotNil(t, bm) - bodyStr := string(bm) - - var contentChunk openai.ChatCompletionResponseChunk - foundChunk := false - - lines := strings.SplitSeq(strings.TrimSpace(bodyStr), "\n\n") - for line := range lines { - if after, ok := strings.CutPrefix(line, "data: "); ok { - jsonBody := after - // We only care about the chunk that has the text content. - if strings.Contains(jsonBody, `"content"`) { - err := json.Unmarshal([]byte(jsonBody), &contentChunk) - require.NoError(t, err, "Failed to unmarshal content chunk") - foundChunk = true - break - } - } - } - - require.True(t, foundChunk, "Did not find a data chunk with content in the output") - - require.NotNil(t, contentChunk.Choices[0].Delta.Role, "Role should be present on the first chunk") - require.Equal(t, openai.ChatMessageRoleAssistant, contentChunk.Choices[0].Delta.Role) - }) - - t.Run("accumulates output tokens", func(t *testing.T) { - sseStream := `event: message_start -data: {"type":"message_start","message":{"id":"msg_123","usage":{"input_tokens":20}}} - -event: message_delta -data: {"type":"message_delta","delta":{},"usage":{"output_tokens":10}} - -event: message_delta -data: {"type":"message_delta","delta":{},"usage":{"output_tokens":5}} - -event: message_stop -data: {"type":"message_stop"} -` - // Run with endOfStream:true to get the final usage chunk. 
- bm, _, err := runStreamTest(t, sseStream, true) - require.NoError(t, err) - require.NotNil(t, bm) - bodyStr := string(bm) - - // The final usage chunk should sum the tokens from all message_delta events. - require.Contains(t, bodyStr, `"completion_tokens":15`) - require.Contains(t, bodyStr, `"prompt_tokens":20`) - require.Contains(t, bodyStr, `"total_tokens":35`) - }) - - t.Run("ignores SSE comments", func(t *testing.T) { - sseStream := `event: message_start -data: {"type":"message_start","message":{"id":"msg_123","usage":{"input_tokens":10}}} - -: this is a comment and should be ignored - -event: content_block_delta -data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hello"}} -` - bm, _, err := runStreamTest(t, sseStream, true) - require.NoError(t, err) - require.NotNil(t, bm) - bodyStr := string(bm) - - require.Contains(t, bodyStr, `"content":"Hello"`) - require.NotContains(t, bodyStr, "this is a comment") - }) - t.Run("handles data-only event as a message event", func(t *testing.T) { - sseStream := `data: some text - -data: another message with two lines -` - bm, _, err := runStreamTest(t, sseStream, false) - require.NoError(t, err) - require.Empty(t, bm, "data-only events should be treated as no-op 'message' events and produce an empty chunk") - }) -} diff --git a/internal/translator/openai_gcpanthropic_test.go b/internal/translator/openai_gcpanthropic_test.go index b631edfe23..9f968b2f25 100644 --- a/internal/translator/openai_gcpanthropic_test.go +++ b/internal/translator/openai_gcpanthropic_test.go @@ -10,6 +10,7 @@ import ( "encoding/base64" "fmt" "io" + "log/slog" "strconv" "testing" "time" @@ -1212,872 +1213,6 @@ func TestOpenAIToGCPAnthropicTranslatorV1ChatCompletion_ResponseError(t *testing } } -// New test function for helper coverage. 
-func TestHelperFunctions(t *testing.T) { - t.Run("anthropicToOpenAIFinishReason invalid reason", func(t *testing.T) { - _, err := anthropicToOpenAIFinishReason("unknown_reason") - require.Error(t, err) - require.Contains(t, err.Error(), "received invalid stop reason") - }) - - t.Run("anthropicRoleToOpenAIRole invalid role", func(t *testing.T) { - _, err := anthropicRoleToOpenAIRole("unknown_role") - require.Error(t, err) - require.Contains(t, err.Error(), "invalid anthropic role") - }) -} - -func TestTranslateOpenAItoAnthropicTools(t *testing.T) { - anthropicTestTool := []anthropic.ToolUnionParam{ - {OfTool: &anthropic.ToolParam{Name: "get_weather", Description: anthropic.String("")}}, - } - openaiTestTool := []openai.Tool{ - {Type: "function", Function: &openai.FunctionDefinition{Name: "get_weather"}}, - } - tests := []struct { - name string - openAIReq *openai.ChatCompletionRequest - expectedTools []anthropic.ToolUnionParam - expectedToolChoice anthropic.ToolChoiceUnionParam - expectErr bool - }{ - { - name: "auto tool choice", - openAIReq: &openai.ChatCompletionRequest{ - ToolChoice: &openai.ChatCompletionToolChoiceUnion{Value: "auto"}, - Tools: openaiTestTool, - }, - expectedTools: anthropicTestTool, - expectedToolChoice: anthropic.ToolChoiceUnionParam{ - OfAuto: &anthropic.ToolChoiceAutoParam{ - DisableParallelToolUse: anthropic.Bool(false), - }, - }, - }, - { - name: "any tool choice", - openAIReq: &openai.ChatCompletionRequest{ - ToolChoice: &openai.ChatCompletionToolChoiceUnion{Value: "any"}, - Tools: openaiTestTool, - }, - expectedTools: anthropicTestTool, - expectedToolChoice: anthropic.ToolChoiceUnionParam{ - OfAny: &anthropic.ToolChoiceAnyParam{}, - }, - }, - { - name: "specific tool choice by name", - openAIReq: &openai.ChatCompletionRequest{ - ToolChoice: &openai.ChatCompletionToolChoiceUnion{Value: openai.ChatCompletionNamedToolChoice{Type: "function", Function: openai.ChatCompletionNamedToolChoiceFunction{Name: "my_func"}}}, - Tools: 
openaiTestTool, - }, - expectedTools: anthropicTestTool, - expectedToolChoice: anthropic.ToolChoiceUnionParam{ - OfTool: &anthropic.ToolChoiceToolParam{Type: "tool", Name: "my_func"}, - }, - }, - { - name: "tool definition", - openAIReq: &openai.ChatCompletionRequest{ - Tools: []openai.Tool{ - { - Type: "function", - Function: &openai.FunctionDefinition{ - Name: "get_weather", - Description: "Get the weather", - Parameters: map[string]any{ - "type": "object", - "properties": map[string]any{ - "location": map[string]any{"type": "string"}, - }, - }, - }, - }, - }, - }, - expectedTools: []anthropic.ToolUnionParam{ - { - OfTool: &anthropic.ToolParam{ - Name: "get_weather", - Description: anthropic.String("Get the weather"), - InputSchema: anthropic.ToolInputSchemaParam{ - Type: "object", - Properties: map[string]any{ - "location": map[string]any{"type": "string"}, - }, - }, - }, - }, - }, - }, - { - name: "tool_definition_with_required_field", - openAIReq: &openai.ChatCompletionRequest{ - Tools: []openai.Tool{ - { - Type: "function", - Function: &openai.FunctionDefinition{ - Name: "get_weather", - Description: "Get the weather with a required location", - Parameters: map[string]any{ - "type": "object", - "properties": map[string]any{ - "location": map[string]any{"type": "string"}, - "unit": map[string]any{"type": "string"}, - }, - "required": []any{"location"}, - }, - }, - }, - }, - }, - expectedTools: []anthropic.ToolUnionParam{ - { - OfTool: &anthropic.ToolParam{ - Name: "get_weather", - Description: anthropic.String("Get the weather with a required location"), - InputSchema: anthropic.ToolInputSchemaParam{ - Type: "object", - Properties: map[string]any{ - "location": map[string]any{"type": "string"}, - "unit": map[string]any{"type": "string"}, - }, - Required: []string{"location"}, - }, - }, - }, - }, - }, - { - name: "tool definition with no parameters", - openAIReq: &openai.ChatCompletionRequest{ - Tools: []openai.Tool{ - { - Type: "function", - Function: 
&openai.FunctionDefinition{ - Name: "get_time", - Description: "Get the current time", - }, - }, - }, - }, - expectedTools: []anthropic.ToolUnionParam{ - { - OfTool: &anthropic.ToolParam{ - Name: "get_time", - Description: anthropic.String("Get the current time"), - }, - }, - }, - }, - { - name: "disable parallel tool calls", - openAIReq: &openai.ChatCompletionRequest{ - ToolChoice: &openai.ChatCompletionToolChoiceUnion{Value: "auto"}, - Tools: openaiTestTool, - ParallelToolCalls: ptr.To(false), - }, - expectedTools: anthropicTestTool, - expectedToolChoice: anthropic.ToolChoiceUnionParam{ - OfAuto: &anthropic.ToolChoiceAutoParam{ - DisableParallelToolUse: anthropic.Bool(true), - }, - }, - }, - { - name: "explicitly enable parallel tool calls", - openAIReq: &openai.ChatCompletionRequest{ - Tools: openaiTestTool, - ToolChoice: &openai.ChatCompletionToolChoiceUnion{Value: "auto"}, - ParallelToolCalls: ptr.To(true), - }, - expectedTools: anthropicTestTool, - expectedToolChoice: anthropic.ToolChoiceUnionParam{ - OfAuto: &anthropic.ToolChoiceAutoParam{DisableParallelToolUse: anthropic.Bool(false)}, - }, - }, - { - name: "default disable parallel tool calls to false (nil)", - openAIReq: &openai.ChatCompletionRequest{ - Tools: openaiTestTool, - ToolChoice: &openai.ChatCompletionToolChoiceUnion{Value: "auto"}, - }, - expectedTools: anthropicTestTool, - expectedToolChoice: anthropic.ToolChoiceUnionParam{ - OfAuto: &anthropic.ToolChoiceAutoParam{DisableParallelToolUse: anthropic.Bool(false)}, - }, - }, - { - name: "none tool choice", - openAIReq: &openai.ChatCompletionRequest{ - Tools: openaiTestTool, - ToolChoice: &openai.ChatCompletionToolChoiceUnion{Value: "none"}, - }, - expectedTools: anthropicTestTool, - expectedToolChoice: anthropic.ToolChoiceUnionParam{ - OfNone: &anthropic.ToolChoiceNoneParam{}, - }, - }, - { - name: "function tool choice", - openAIReq: &openai.ChatCompletionRequest{ - Tools: openaiTestTool, - ToolChoice: &openai.ChatCompletionToolChoiceUnion{Value: 
"function"}, - }, - expectedTools: anthropicTestTool, - expectedToolChoice: anthropic.ToolChoiceUnionParam{ - OfTool: &anthropic.ToolChoiceToolParam{Name: "function"}, - }, - }, - { - name: "invalid tool choice string", - openAIReq: &openai.ChatCompletionRequest{ - Tools: openaiTestTool, - ToolChoice: &openai.ChatCompletionToolChoiceUnion{Value: "invalid_choice"}, - }, - expectErr: true, - }, - { - name: "skips function tool with nil function definition", - openAIReq: &openai.ChatCompletionRequest{ - Tools: []openai.Tool{ - { - Type: "function", - Function: nil, // This tool has the correct type but a nil definition and should be skipped. - }, - { - Type: "function", - Function: &openai.FunctionDefinition{Name: "get_weather"}, // This is a valid tool. - }, - }, - }, - // We expect only the valid function tool to be translated. - expectedTools: []anthropic.ToolUnionParam{ - {OfTool: &anthropic.ToolParam{Name: "get_weather", Description: anthropic.String("")}}, - }, - expectErr: false, - }, - { - name: "skips non-function tools", - openAIReq: &openai.ChatCompletionRequest{ - Tools: []openai.Tool{ - { - Type: "retrieval", - }, - { - Type: "function", - Function: &openai.FunctionDefinition{Name: "get_weather"}, - }, - }, - }, - expectedTools: []anthropic.ToolUnionParam{ - {OfTool: &anthropic.ToolParam{Name: "get_weather", Description: anthropic.String("")}}, - }, - expectErr: false, - }, - { - name: "tool definition without type field", - openAIReq: &openai.ChatCompletionRequest{ - Tools: []openai.Tool{ - { - Type: "function", - Function: &openai.FunctionDefinition{ - Name: "get_weather", - Description: "Get the weather without type", - Parameters: map[string]any{ - "properties": map[string]any{ - "location": map[string]any{"type": "string"}, - }, - "required": []any{"location"}, - }, - }, - }, - }, - }, - expectedTools: []anthropic.ToolUnionParam{ - { - OfTool: &anthropic.ToolParam{ - Name: "get_weather", - Description: anthropic.String("Get the weather without 
type"), - InputSchema: anthropic.ToolInputSchemaParam{ - Type: "", - Properties: map[string]any{ - "location": map[string]any{"type": "string"}, - }, - Required: []string{"location"}, - }, - }, - }, - }, - }, - { - name: "tool definition without properties field", - openAIReq: &openai.ChatCompletionRequest{ - Tools: []openai.Tool{ - { - Type: "function", - Function: &openai.FunctionDefinition{ - Name: "get_weather", - Description: "Get the weather without properties", - Parameters: map[string]any{ - "type": "object", - "required": []any{"location"}, - }, - }, - }, - }, - }, - expectedTools: []anthropic.ToolUnionParam{ - { - OfTool: &anthropic.ToolParam{ - Name: "get_weather", - Description: anthropic.String("Get the weather without properties"), - InputSchema: anthropic.ToolInputSchemaParam{ - Type: "object", - Required: []string{"location"}, - }, - }, - }, - }, - }, - { - name: "unsupported tool_choice type", - openAIReq: &openai.ChatCompletionRequest{ - Tools: openaiTestTool, - ToolChoice: &openai.ChatCompletionToolChoiceUnion{Value: 123}, // Use an integer to trigger the default case. 
- }, - expectErr: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - tools, toolChoice, err := translateOpenAItoAnthropicTools(tt.openAIReq.Tools, tt.openAIReq.ToolChoice, tt.openAIReq.ParallelToolCalls) - if tt.expectErr { - require.Error(t, err) - } else { - require.NoError(t, err) - if tt.openAIReq.ToolChoice != nil { - require.NotNil(t, toolChoice) - require.Equal(t, *tt.expectedToolChoice.GetType(), *toolChoice.GetType()) - if tt.expectedToolChoice.GetName() != nil { - require.Equal(t, *tt.expectedToolChoice.GetName(), *toolChoice.GetName()) - } - if tt.expectedToolChoice.OfTool != nil { - require.Equal(t, tt.expectedToolChoice.OfTool.Name, toolChoice.OfTool.Name) - } - if tt.expectedToolChoice.OfAuto != nil { - require.Equal(t, tt.expectedToolChoice.OfAuto.DisableParallelToolUse, toolChoice.OfAuto.DisableParallelToolUse) - } - } - if tt.openAIReq.Tools != nil { - require.NotNil(t, tools) - require.Len(t, tools, len(tt.expectedTools)) - require.Equal(t, tt.expectedTools[0].GetName(), tools[0].GetName()) - require.Equal(t, tt.expectedTools[0].GetType(), tools[0].GetType()) - require.Equal(t, tt.expectedTools[0].GetDescription(), tools[0].GetDescription()) - if tt.expectedTools[0].GetInputSchema().Properties != nil { - require.Equal(t, tt.expectedTools[0].GetInputSchema().Properties, tools[0].GetInputSchema().Properties) - } - } - } - }) - } -} - -// TestFinishReasonTranslation covers specific cases for the anthropicToOpenAIFinishReason function. 
-func TestFinishReasonTranslation(t *testing.T) { - tests := []struct { - name string - input anthropic.StopReason - expectedFinishReason openai.ChatCompletionChoicesFinishReason - expectErr bool - }{ - { - name: "max tokens stop reason", - input: anthropic.StopReasonMaxTokens, - expectedFinishReason: openai.ChatCompletionChoicesFinishReasonLength, - }, - { - name: "refusal stop reason", - input: anthropic.StopReasonRefusal, - expectedFinishReason: openai.ChatCompletionChoicesFinishReasonContentFilter, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - reason, err := anthropicToOpenAIFinishReason(tt.input) - if tt.expectErr { - require.Error(t, err) - } else { - require.NoError(t, err) - require.Equal(t, tt.expectedFinishReason, reason) - } - }) - } -} - -// TestToolParameterDereferencing tests the JSON schema dereferencing functionality -// for tool parameters when translating from OpenAI to GCP Anthropic. -func TestToolParameterDereferencing(t *testing.T) { - tests := []struct { - name string - openAIReq *openai.ChatCompletionRequest - expectedTools []anthropic.ToolUnionParam - expectedToolChoice anthropic.ToolChoiceUnionParam - expectErr bool - expectUserFacingErr bool - }{ - { - name: "tool with complex nested $ref - successful dereferencing", - openAIReq: &openai.ChatCompletionRequest{ - Tools: []openai.Tool{ - { - Type: "function", - Function: &openai.FunctionDefinition{ - Name: "complex_tool", - Description: "Tool with complex nested references", - Parameters: map[string]any{ - "type": "object", - "$defs": map[string]any{ - "BaseType": map[string]any{ - "type": "object", - "properties": map[string]any{ - "id": map[string]any{ - "type": "string", - }, - "required": []any{"id"}, - }, - }, - "NestedType": map[string]any{ - "allOf": []any{ - map[string]any{"$ref": "#/$defs/BaseType"}, - map[string]any{ - "properties": map[string]any{ - "name": map[string]any{ - "type": "string", - }, - }, - }, - }, - }, - }, - "properties": 
map[string]any{ - "nested": map[string]any{ - "$ref": "#/$defs/NestedType", - }, - }, - }, - }, - }, - }, - }, - expectedTools: []anthropic.ToolUnionParam{ - { - OfTool: &anthropic.ToolParam{ - Name: "complex_tool", - Description: anthropic.String("Tool with complex nested references"), - InputSchema: anthropic.ToolInputSchemaParam{ - Type: "object", - Properties: map[string]any{ - "nested": map[string]any{ - "allOf": []any{ - map[string]any{ - "type": "object", - "properties": map[string]any{ - "id": map[string]any{ - "type": "string", - }, - "required": []any{"id"}, - }, - }, - map[string]any{ - "properties": map[string]any{ - "name": map[string]any{ - "type": "string", - }, - }, - }, - }, - }, - }, - }, - }, - }, - }, - }, - { - name: "tool with invalid $ref - dereferencing error", - openAIReq: &openai.ChatCompletionRequest{ - Tools: []openai.Tool{ - { - Type: "function", - Function: &openai.FunctionDefinition{ - Name: "invalid_ref_tool", - Description: "Tool with invalid reference", - Parameters: map[string]any{ - "type": "object", - "properties": map[string]any{ - "location": map[string]any{ - "$ref": "#/$defs/NonExistent", - }, - }, - }, - }, - }, - }, - }, - expectErr: true, - }, - { - name: "tool with circular $ref - dereferencing error", - openAIReq: &openai.ChatCompletionRequest{ - Tools: []openai.Tool{ - { - Type: "function", - Function: &openai.FunctionDefinition{ - Name: "circular_ref_tool", - Description: "Tool with circular reference", - Parameters: map[string]any{ - "type": "object", - "$defs": map[string]any{ - "A": map[string]any{ - "type": "object", - "properties": map[string]any{ - "b": map[string]any{ - "$ref": "#/$defs/B", - }, - }, - }, - "B": map[string]any{ - "type": "object", - "properties": map[string]any{ - "a": map[string]any{ - "$ref": "#/$defs/A", - }, - }, - }, - }, - "properties": map[string]any{ - "circular": map[string]any{ - "$ref": "#/$defs/A", - }, - }, - }, - }, - }, - }, - }, - expectErr: true, - }, - { - name: "tool without 
$ref - no dereferencing needed", - openAIReq: &openai.ChatCompletionRequest{ - Tools: []openai.Tool{ - { - Type: "function", - Function: &openai.FunctionDefinition{ - Name: "simple_tool", - Description: "Simple tool without references", - Parameters: map[string]any{ - "type": "object", - "properties": map[string]any{ - "location": map[string]any{ - "type": "string", - }, - }, - "required": []any{"location"}, - }, - }, - }, - }, - }, - expectedTools: []anthropic.ToolUnionParam{ - { - OfTool: &anthropic.ToolParam{ - Name: "simple_tool", - Description: anthropic.String("Simple tool without references"), - InputSchema: anthropic.ToolInputSchemaParam{ - Type: "object", - Properties: map[string]any{ - "location": map[string]any{ - "type": "string", - }, - }, - Required: []string{"location"}, - }, - }, - }, - }, - }, - { - name: "tool parameter dereferencing returns non-map type - casting error", - openAIReq: &openai.ChatCompletionRequest{ - Tools: []openai.Tool{ - { - Type: "function", - Function: &openai.FunctionDefinition{ - Name: "problematic_tool", - Description: "Tool with parameters that can't be properly dereferenced to map", - // This creates a scenario where jsonSchemaDereference might return a non-map type - // though this is a contrived example since normally the function should return map[string]any - Parameters: map[string]any{ - "$ref": "#/$defs/StringType", // This would resolve to a string, not a map - "$defs": map[string]any{ - "StringType": "not-a-map", // This would cause the casting to fail - }, - }, - }, - }, - }, - }, - expectErr: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - tools, toolChoice, err := translateOpenAItoAnthropicTools(tt.openAIReq.Tools, tt.openAIReq.ToolChoice, tt.openAIReq.ParallelToolCalls) - - if tt.expectErr { - require.Error(t, err) - if tt.expectUserFacingErr { - require.ErrorIs(t, err, internalapi.ErrInvalidRequestBody) - } - return - } - - require.NoError(t, err) - - if 
tt.openAIReq.Tools != nil { - require.NotNil(t, tools) - require.Len(t, tools, len(tt.expectedTools)) - - for i, expectedTool := range tt.expectedTools { - actualTool := tools[i] - require.Equal(t, expectedTool.GetName(), actualTool.GetName()) - require.Equal(t, expectedTool.GetType(), actualTool.GetType()) - require.Equal(t, expectedTool.GetDescription(), actualTool.GetDescription()) - - expectedSchema := expectedTool.GetInputSchema() - actualSchema := actualTool.GetInputSchema() - - require.Equal(t, expectedSchema.Type, actualSchema.Type) - require.Equal(t, expectedSchema.Required, actualSchema.Required) - - // For properties, we'll do a deep comparison to verify dereferencing worked - if expectedSchema.Properties != nil { - require.NotNil(t, actualSchema.Properties) - require.Equal(t, expectedSchema.Properties, actualSchema.Properties) - } - } - } - - if tt.openAIReq.ToolChoice != nil { - require.NotNil(t, toolChoice) - require.Equal(t, *tt.expectedToolChoice.GetType(), *toolChoice.GetType()) - } - }) - } -} - -// TestContentTranslationCoverage adds specific coverage for the openAIToAnthropicContent helper. 
-func TestContentTranslationCoverage(t *testing.T) { - tests := []struct { - name string - inputContent any - expectedContent []anthropic.ContentBlockParamUnion - expectErr bool - }{ - { - name: "nil content", - inputContent: nil, - }, - { - name: "empty string content", - inputContent: "", - }, - { - name: "pdf data uri", - inputContent: []openai.ChatCompletionContentPartUserUnionParam{ - {OfImageURL: &openai.ChatCompletionContentPartImageParam{ImageURL: openai.ChatCompletionContentPartImageImageURLParam{URL: "data:application/pdf;base64,dGVzdA=="}}}, - }, - expectedContent: []anthropic.ContentBlockParamUnion{ - { - OfDocument: &anthropic.DocumentBlockParam{ - Source: anthropic.DocumentBlockParamSourceUnion{ - OfBase64: &anthropic.Base64PDFSourceParam{ - Type: constant.ValueOf[constant.Base64](), - MediaType: constant.ValueOf[constant.ApplicationPDF](), - Data: "dGVzdA==", - }, - }, - }, - }, - }, - }, - { - name: "pdf url", - inputContent: []openai.ChatCompletionContentPartUserUnionParam{ - {OfImageURL: &openai.ChatCompletionContentPartImageParam{ImageURL: openai.ChatCompletionContentPartImageImageURLParam{URL: "https://example.com/doc.pdf"}}}, - }, - expectedContent: []anthropic.ContentBlockParamUnion{ - { - OfDocument: &anthropic.DocumentBlockParam{ - Source: anthropic.DocumentBlockParamSourceUnion{ - OfURL: &anthropic.URLPDFSourceParam{ - Type: constant.ValueOf[constant.URL](), - URL: "https://example.com/doc.pdf", - }, - }, - }, - }, - }, - }, - { - name: "image url", - inputContent: []openai.ChatCompletionContentPartUserUnionParam{ - {OfImageURL: &openai.ChatCompletionContentPartImageParam{ImageURL: openai.ChatCompletionContentPartImageImageURLParam{URL: "https://example.com/image.png"}}}, - }, - expectedContent: []anthropic.ContentBlockParamUnion{ - { - OfImage: &anthropic.ImageBlockParam{ - Source: anthropic.ImageBlockParamSourceUnion{ - OfURL: &anthropic.URLImageSourceParam{ - Type: constant.ValueOf[constant.URL](), - URL: "https://example.com/image.png", 
- }, - }, - }, - }, - }, - }, - { - name: "audio content error", - inputContent: []openai.ChatCompletionContentPartUserUnionParam{{OfInputAudio: &openai.ChatCompletionContentPartInputAudioParam{}}}, - expectErr: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - content, err := openAIToAnthropicContent(tt.inputContent) - if tt.expectErr { - require.Error(t, err) - return - } - require.NoError(t, err) - - // Use direct assertions instead of cmp.Diff to avoid panics on unexported fields. - require.Len(t, content, len(tt.expectedContent), "Number of content blocks should match") - - // Use direct assertions instead of cmp.Diff to avoid panics on unexported fields. - require.Len(t, content, len(tt.expectedContent), "Number of content blocks should match") - for i, expectedBlock := range tt.expectedContent { - actualBlock := content[i] - require.Equal(t, expectedBlock.GetType(), actualBlock.GetType(), "Content block types should match") - if expectedBlock.OfDocument != nil { - require.NotNil(t, actualBlock.OfDocument, "Expected a document block, but got nil") - require.NotNil(t, actualBlock.OfDocument.Source, "Document source should not be nil") - - if expectedBlock.OfDocument.Source.OfBase64 != nil { - require.NotNil(t, actualBlock.OfDocument.Source.OfBase64, "Expected a base64 source") - require.Equal(t, expectedBlock.OfDocument.Source.OfBase64.Data, actualBlock.OfDocument.Source.OfBase64.Data) - } - if expectedBlock.OfDocument.Source.OfURL != nil { - require.NotNil(t, actualBlock.OfDocument.Source.OfURL, "Expected a URL source") - require.Equal(t, expectedBlock.OfDocument.Source.OfURL.URL, actualBlock.OfDocument.Source.OfURL.URL) - } - } - if expectedBlock.OfImage != nil { - require.NotNil(t, actualBlock.OfImage, "Expected an image block, but got nil") - require.NotNil(t, actualBlock.OfImage.Source, "Image source should not be nil") - - if expectedBlock.OfImage.Source.OfURL != nil { - require.NotNil(t, 
actualBlock.OfImage.Source.OfURL, "Expected a URL image source") - require.Equal(t, expectedBlock.OfImage.Source.OfURL.URL, actualBlock.OfImage.Source.OfURL.URL) - } - } - } - - for i, expectedBlock := range tt.expectedContent { - actualBlock := content[i] - if expectedBlock.OfDocument != nil { - require.NotNil(t, actualBlock.OfDocument, "Expected a document block, but got nil") - require.NotNil(t, actualBlock.OfDocument.Source, "Document source should not be nil") - - if expectedBlock.OfDocument.Source.OfBase64 != nil { - require.NotNil(t, actualBlock.OfDocument.Source.OfBase64, "Expected a base64 source") - require.Equal(t, expectedBlock.OfDocument.Source.OfBase64.Data, actualBlock.OfDocument.Source.OfBase64.Data) - } - if expectedBlock.OfDocument.Source.OfURL != nil { - require.NotNil(t, actualBlock.OfDocument.Source.OfURL, "Expected a URL source") - require.Equal(t, expectedBlock.OfDocument.Source.OfURL.URL, actualBlock.OfDocument.Source.OfURL.URL) - } - } - if expectedBlock.OfImage != nil { - require.NotNil(t, actualBlock.OfImage, "Expected an image block, but got nil") - require.NotNil(t, actualBlock.OfImage.Source, "Image source should not be nil") - - if expectedBlock.OfImage.Source.OfURL != nil { - require.NotNil(t, actualBlock.OfImage.Source.OfURL, "Expected a URL image source") - require.Equal(t, expectedBlock.OfImage.Source.OfURL.URL, actualBlock.OfImage.Source.OfURL.URL) - } - } - } - }) - } -} - -// TestSystemPromptExtractionCoverage adds specific coverage for the extractSystemPromptFromDeveloperMsg helper. 
-func TestSystemPromptExtractionCoverage(t *testing.T) { - tests := []struct { - name string - inputMsg openai.ChatCompletionDeveloperMessageParam - expectedPrompt string - }{ - { - name: "developer message with content parts", - inputMsg: openai.ChatCompletionDeveloperMessageParam{ - Content: openai.ContentUnion{Value: []openai.ChatCompletionContentPartTextParam{ - {Type: "text", Text: "part 1"}, - {Type: "text", Text: " part 2"}, - }}, - }, - expectedPrompt: "part 1 part 2", - }, - { - name: "developer message with nil content", - inputMsg: openai.ChatCompletionDeveloperMessageParam{Content: openai.ContentUnion{Value: nil}}, - expectedPrompt: "", - }, - { - name: "developer message with string content", - inputMsg: openai.ChatCompletionDeveloperMessageParam{ - Content: openai.ContentUnion{Value: "simple string"}, - }, - expectedPrompt: "simple string", - }, - { - name: "developer message with text parts array", - inputMsg: openai.ChatCompletionDeveloperMessageParam{ - Content: openai.ContentUnion{Value: []openai.ChatCompletionContentPartTextParam{ - {Type: "text", Text: "text part"}, - }}, - }, - expectedPrompt: "text part", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - prompt, _ := extractSystemPromptFromDeveloperMsg(tt.inputMsg) - require.Equal(t, tt.expectedPrompt, prompt) - }) - } -} - func TestOpenAIToGCPAnthropicTranslatorV1ChatCompletion_Cache(t *testing.T) { t.Run("full request with mixed caching", func(t *testing.T) { openAIReq := &openai.ChatCompletionRequest{ @@ -2605,3 +1740,306 @@ func TestOpenAIToGCPAnthropicTranslatorV1ChatCompletion_Cache(t *testing.T) { require.Equal(t, string(constant.ValueOf[constant.Ephemeral]()), result.Get("messages.0.content.2.cache_control.type").String(), "tool 3 (with cache) should be cached") }) } + +func TestOpenAIToGCPAnthropicTranslatorV1ChatCompletion_SetRedactionConfig(t *testing.T) { + translator := NewChatCompletionOpenAIToGCPAnthropicTranslator("", 
"").(*openAIToGCPAnthropicTranslatorV1ChatCompletion) + + logger := slog.New(slog.NewTextHandler(io.Discard, nil)) + translator.SetRedactionConfig(true, true, logger) + + require.True(t, translator.debugLogEnabled) + require.True(t, translator.enableRedaction) + require.NotNil(t, translator.logger) +} + +func TestOpenAIToGCPAnthropicTranslatorV1ChatCompletion_RedactBody(t *testing.T) { + translator := NewChatCompletionOpenAIToGCPAnthropicTranslator("", "").(*openAIToGCPAnthropicTranslatorV1ChatCompletion) + + t.Run("nil response returns nil", func(t *testing.T) { + result := translator.RedactBody(nil) + require.Nil(t, result) + }) + + t.Run("redacts message content", func(t *testing.T) { + resp := &openai.ChatCompletionResponse{ + ID: "test-id", + Model: "test-model", + Choices: []openai.ChatCompletionResponseChoice{ + { + Index: 0, + Message: openai.ChatCompletionResponseChoiceMessage{Role: "assistant", Content: ptr.To("sensitive content")}, + }, + }, + } + + result := translator.RedactBody(resp) + require.NotNil(t, result) + require.Equal(t, "test-id", result.ID) + require.Len(t, result.Choices, 1) + // Content should be redacted (not the original value) + require.NotNil(t, result.Choices[0].Message.Content) + require.NotEqual(t, "sensitive content", *result.Choices[0].Message.Content) + }) + + t.Run("redacts tool calls", func(t *testing.T) { + resp := &openai.ChatCompletionResponse{ + ID: "test-id", + Model: "test-model", + Choices: []openai.ChatCompletionResponseChoice{ + { + Index: 0, + Message: openai.ChatCompletionResponseChoiceMessage{ + Role: "assistant", + ToolCalls: []openai.ChatCompletionMessageToolCallParam{ + { + ID: ptr.To("tool-1"), + Type: openai.ChatCompletionMessageToolCallTypeFunction, + Function: openai.ChatCompletionMessageToolCallFunctionParam{ + Name: "get_secret", + Arguments: `{"password": "secret123"}`, + }, + }, + }, + }, + }, + }, + } + + result := translator.RedactBody(resp) + require.NotNil(t, result) + require.Len(t, result.Choices, 
1) + require.Len(t, result.Choices[0].Message.ToolCalls, 1) + // Tool call name and arguments should be redacted + require.NotEqual(t, "get_secret", result.Choices[0].Message.ToolCalls[0].Function.Name) + require.NotEqual(t, `{"password": "secret123"}`, result.Choices[0].Message.ToolCalls[0].Function.Arguments) + }) + + t.Run("redacts audio data", func(t *testing.T) { + resp := &openai.ChatCompletionResponse{ + ID: "test-id", + Model: "test-model", + Choices: []openai.ChatCompletionResponseChoice{ + { + Index: 0, + Message: openai.ChatCompletionResponseChoiceMessage{ + Role: "assistant", + Audio: &openai.ChatCompletionResponseChoiceMessageAudio{ + Data: "base64-audio-data", + Transcript: "sensitive transcript", + }, + }, + }, + }, + } + + result := translator.RedactBody(resp) + require.NotNil(t, result) + require.Len(t, result.Choices, 1) + require.NotNil(t, result.Choices[0].Message.Audio) + // Audio data and transcript should be redacted + require.NotEqual(t, "base64-audio-data", result.Choices[0].Message.Audio.Data) + require.NotEqual(t, "sensitive transcript", result.Choices[0].Message.Audio.Transcript) + }) + + t.Run("redacts reasoning content", func(t *testing.T) { + resp := &openai.ChatCompletionResponse{ + ID: "test-id", + Model: "test-model", + Choices: []openai.ChatCompletionResponseChoice{ + { + Index: 0, + Message: openai.ChatCompletionResponseChoiceMessage{ + Role: "assistant", + ReasoningContent: &openai.ReasoningContentUnion{ + Value: &openai.ReasoningContent{ + ReasoningContent: &awsbedrock.ReasoningContentBlock{ + ReasoningText: &awsbedrock.ReasoningTextBlock{ + Text: "sensitive reasoning", + Signature: "sig123", + }, + }, + }, + }, + }, + }, + }, + } + + result := translator.RedactBody(resp) + require.NotNil(t, result) + require.Len(t, result.Choices, 1) + require.NotNil(t, result.Choices[0].Message.ReasoningContent) + }) + + t.Run("empty choices returns empty choices", func(t *testing.T) { + resp := &openai.ChatCompletionResponse{ + ID: 
"test-id", + Model: "test-model", + Choices: []openai.ChatCompletionResponseChoice{}, + } + + result := translator.RedactBody(resp) + require.NotNil(t, result) + require.Empty(t, result.Choices) + }) +} + +func TestOpenAIToGCPAnthropicTranslatorV1ChatCompletion_ResponseHeaders(t *testing.T) { + t.Run("returns event-stream content type for streaming", func(t *testing.T) { + openAIReq := &openai.ChatCompletionRequest{ + Stream: true, + Model: "test-model", + MaxTokens: ptr.To(int64(100)), + } + translator := NewChatCompletionOpenAIToGCPAnthropicTranslator("", "").(*openAIToGCPAnthropicTranslatorV1ChatCompletion) + + // Initialize the stream parser by calling RequestBody with streaming request + _, _, err := translator.RequestBody(nil, openAIReq, false) + require.NoError(t, err) + + // Now ResponseHeaders should return the streaming content type + headers, err := translator.ResponseHeaders(nil) + require.NoError(t, err) + require.Len(t, headers, 1) + require.Equal(t, contentTypeHeaderName, headers[0].Key()) + require.Equal(t, eventStreamContentType, headers[0].Value()) + }) + + t.Run("returns no headers for non-streaming", func(t *testing.T) { + openAIReq := &openai.ChatCompletionRequest{ + Stream: false, + Model: "test-model", + MaxTokens: ptr.To(int64(100)), + } + translator := NewChatCompletionOpenAIToGCPAnthropicTranslator("", "").(*openAIToGCPAnthropicTranslatorV1ChatCompletion) + + // Initialize without streaming + _, _, err := translator.RequestBody(nil, openAIReq, false) + require.NoError(t, err) + + // ResponseHeaders should return nil for non-streaming + headers, err := translator.ResponseHeaders(nil) + require.NoError(t, err) + require.Nil(t, headers) + }) +} + +func TestOpenAIToGCPAnthropicTranslatorV1ChatCompletion_ResponseBody_WithDebugLogging(t *testing.T) { + // Create a buffer to capture log output + var logBuf bytes.Buffer + logger := slog.New(slog.NewTextHandler(&logBuf, &slog.HandlerOptions{Level: slog.LevelDebug})) + + translator := 
NewChatCompletionOpenAIToGCPAnthropicTranslator("", "").(*openAIToGCPAnthropicTranslatorV1ChatCompletion) + translator.SetRedactionConfig(true, true, logger) + + // Initialize translator with the model + req := &openai.ChatCompletionRequest{ + Model: "claude-3", + MaxTokens: ptr.To(int64(100)), + Messages: []openai.ChatCompletionMessageParamUnion{ + { + OfUser: &openai.ChatCompletionUserMessageParam{ + Content: openai.StringOrUserRoleContentUnion{Value: "Hello"}, + Role: openai.ChatMessageRoleUser, + }, + }, + }, + } + reqBody, _ := json.Marshal(req) + _, _, err := translator.RequestBody(reqBody, req, false) + require.NoError(t, err) + + // Create a response + anthropicResponse := anthropic.Message{ + ID: "msg_01XYZ", + Type: constant.ValueOf[constant.Message](), + Role: constant.ValueOf[constant.Assistant](), + Content: []anthropic.ContentBlockUnion{ + { + Type: "text", + Text: "Hello! How can I help you?", + }, + }, + StopReason: anthropic.StopReasonEndTurn, + Usage: anthropic.Usage{ + InputTokens: 10, + OutputTokens: 5, + }, + } + + body, err := json.Marshal(anthropicResponse) + require.NoError(t, err) + + _, _, _, _, err = translator.ResponseBody(nil, bytes.NewReader(body), true, nil) + require.NoError(t, err) + + // Verify that debug logging occurred + logOutput := logBuf.String() + require.Contains(t, logOutput, "response body processing") +} + +// mockSpan implements tracingapi.ChatCompletionSpan for testing +type mockSpan struct { + recordedResponse *openai.ChatCompletionResponse +} + +func (m *mockSpan) RecordResponseChunk(_ *openai.ChatCompletionResponseChunk) {} +func (m *mockSpan) RecordResponse(resp *openai.ChatCompletionResponse) { + m.recordedResponse = resp +} +func (m *mockSpan) EndSpanOnError(_ int, _ []byte) {} +func (m *mockSpan) EndSpan() {} + +func TestOpenAIToGCPAnthropicTranslatorV1ChatCompletion_ResponseBody_WithSpanRecording(t *testing.T) { + translator := NewChatCompletionOpenAIToGCPAnthropicTranslator("", 
"").(*openAIToGCPAnthropicTranslatorV1ChatCompletion) + + // Initialize translator with the model + req := &openai.ChatCompletionRequest{ + Model: "claude-3", + MaxTokens: ptr.To(int64(100)), + Messages: []openai.ChatCompletionMessageParamUnion{ + { + OfUser: &openai.ChatCompletionUserMessageParam{ + Content: openai.StringOrUserRoleContentUnion{Value: "Hello"}, + Role: openai.ChatMessageRoleUser, + }, + }, + }, + } + reqBody, _ := json.Marshal(req) + _, _, err := translator.RequestBody(reqBody, req, false) + require.NoError(t, err) + + // Create a response + anthropicResponse := anthropic.Message{ + ID: "msg_01XYZ", + Type: constant.ValueOf[constant.Message](), + Role: constant.ValueOf[constant.Assistant](), + Content: []anthropic.ContentBlockUnion{ + { + Type: "text", + Text: "Hello!", + }, + }, + StopReason: anthropic.StopReasonEndTurn, + Usage: anthropic.Usage{ + InputTokens: 10, + OutputTokens: 5, + }, + } + + body, err := json.Marshal(anthropicResponse) + require.NoError(t, err) + + // Create a mock span + span := &mockSpan{} + + _, _, _, _, err = translator.ResponseBody(nil, bytes.NewReader(body), true, span) + require.NoError(t, err) + + // Verify the span recorded the response + require.NotNil(t, span.recordedResponse) + require.Equal(t, "msg_01XYZ", span.recordedResponse.ID) + require.Len(t, span.recordedResponse.Choices, 1) + require.Equal(t, "Hello!", *span.recordedResponse.Choices[0].Message.Content) +} diff --git a/site/docs/api/api.mdx b/site/docs/api/api.mdx index 47ceb67c58..240eaa93b5 100644 --- a/site/docs/api/api.mdx +++ b/site/docs/api/api.mdx @@ -982,7 +982,7 @@ APISchema defines the API schema. name="AWSAnthropic" type="enum" required="false" - description="APISchemaAWSAnthropic is the schema for Anthropic models hosted on AWS Bedrock.
Uses the native Anthropic Messages API format for requests and responses.
https://aws.amazon.com/bedrock/anthropic/
https://docs.claude.com/en/api/claude-on-amazon-bedrock
" + description="APISchemaAWSAnthropic is the schema for Anthropic models hosted on AWS Bedrock.
Uses the native Anthropic Messages API format for requests and responses.
When used with /v1/chat/completions endpoint, translates OpenAI format to Anthropic.
When used with /v1/messages endpoint, passes through native Anthropic format.
https://aws.amazon.com/bedrock/anthropic/
https://docs.claude.com/en/api/claude-on-amazon-bedrock
" /> #### AWSCredentialsFile From d0347c8c6196494f498b24e599ae01548bc67f62 Mon Sep 17 00:00:00 2001 From: Takeshi Yoneda Date: Sun, 8 Feb 2026 09:36:56 -0800 Subject: [PATCH 10/11] mcp: query param APIKey (#1800) **Description** This adds support for API Key upstream auth as an HTTP query parameter for MCP servers. There are several public MCP servers that employ this pattern, thus this unlocks several patterns. To add support for that, this removes the use of CredentialsInjector for inserting API Keys as it doesn't support the query parameter as a location. **Related Issues/PRs (if applicable)** Closes #1731 --------- Signed-off-by: Takeshi Yoneda Co-authored-by: Ignasi Barrera --- api/v1alpha1/mcp_route.go | 15 ++ api/v1alpha1/zz_generated.deepcopy.go | 5 + internal/controller/gateway.go | 1 - internal/controller/mcp_route.go | 154 ++++++++---------- internal/controller/mcp_route_test.go | 113 +++++++------ internal/filterapi/mcpconfig.go | 3 - internal/mcpproxy/config_test.go | 27 ++- internal/mcpproxy/handlers.go | 4 - internal/mcpproxy/handlers_test.go | 6 +- internal/mcpproxy/mcpproxy.go | 2 +- internal/mcpproxy/mcpproxy_test.go | 15 +- internal/mcpproxy/session.go | 14 +- internal/mcpproxy/session_test.go | 2 +- .../aigateway.envoyproxy.io_mcproutes.yaml | 13 ++ site/docs/api/api.mdx | 8 +- tests/crdcel/main_test.go | 4 + ...backend_api_key_both_header_and_query.yaml | 25 +++ tests/data-plane-mcp/env.go | 10 +- tests/data-plane-mcp/envoy.yaml | 5 + tests/data-plane-mcp/publicmcp_test.go | 5 +- tests/e2e/mcp_route_test.go | 2 +- tests/e2e/testdata/mcp_route.yaml | 44 ++++- tests/internal/testmcp/server.go | 19 ++- 23 files changed, 301 insertions(+), 195 deletions(-) create mode 100644 tests/crdcel/testdata/mcpgatewayroutes/backend_api_key_both_header_and_query.yaml diff --git a/api/v1alpha1/mcp_route.go b/api/v1alpha1/mcp_route.go index e8416d98e4..1d5d9a097c 100644 --- a/api/v1alpha1/mcp_route.go +++ b/api/v1alpha1/mcp_route.go @@ -152,8 +152,10 @@ type 
MCPBackendSecurityPolicy struct { } // MCPBackendAPIKey defines the configuration for the API Key Authentication to a backend. +// When both `header` and `queryParam` are unspecified, the API key will be injected into the "Authorization" header by default. // // +kubebuilder:validation:XValidation:rule="(has(self.secretRef) && !has(self.inline)) || (!has(self.secretRef) && has(self.inline))", message="exactly one of secretRef or inline must be set" +// +kubebuilder:validation:XValidation:rule="!(has(self.header) && has(self.queryParam))", message="only one of header or queryParam can be set" type MCPBackendAPIKey struct { // secretRef is the Kubernetes secret which contains the API keys. // The key of the secret should be "apiKey". @@ -170,10 +172,23 @@ type MCPBackendAPIKey struct { // When the header is "Authorization", the injected header value will be // prefixed with "Bearer ". // + // Either one of Header or QueryParam can be specified to inject the API key. + // // +kubebuilder:validation:Optional // +kubebuilder:validation:MinLength=1 // +optional Header *string `json:"header,omitempty"` + + // QueryParam is the HTTP query parameter to inject the API key into. + // For example, if QueryParam is set to "api_key", and the API key is "mysecretkey", the request URL will be modified to include + // "?api_key=mysecretkey". + // + // Either one of Header or QueryParam can be specified to inject the API key. + // + // +kubebuilder:validation:Optional + // +kubebuilder:validation:MinLength=1 + // +optional + QueryParam *string `json:"queryParam,omitempty"` } // MCPRouteSecurityPolicy defines the security policy for a MCPRoute. 
diff --git a/api/v1alpha1/zz_generated.deepcopy.go b/api/v1alpha1/zz_generated.deepcopy.go index feffa039be..388014dfad 100644 --- a/api/v1alpha1/zz_generated.deepcopy.go +++ b/api/v1alpha1/zz_generated.deepcopy.go @@ -1137,6 +1137,11 @@ func (in *MCPBackendAPIKey) DeepCopyInto(out *MCPBackendAPIKey) { *out = new(string) **out = **in } + if in.QueryParam != nil { + in, out := &in.QueryParam, &out.QueryParam + *out = new(string) + **out = **in + } } // DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new MCPBackendAPIKey. diff --git a/internal/controller/gateway.go b/internal/controller/gateway.go index 9e605cb178..8228a75b11 100644 --- a/internal/controller/gateway.go +++ b/internal/controller/gateway.go @@ -495,7 +495,6 @@ func mcpConfig(mcpRoutes []aigv1a1.MCPRoute) (_ *filterapi.MCPConfig, hasEffecti mcpBackend := filterapi.MCPBackend{ // MCPRoute doesn't support cross-namespace backend reference so just use the name. Name: filterapi.MCPBackendName(b.Name), - Path: ptr.Deref(b.Path, "/mcp"), } if b.ToolSelector != nil { mcpBackend.ToolSelector = &filterapi.MCPToolSelector{ diff --git a/internal/controller/mcp_route.go b/internal/controller/mcp_route.go index 6ca5b1b170..45a93abbd6 100644 --- a/internal/controller/mcp_route.go +++ b/internal/controller/mcp_route.go @@ -6,14 +6,12 @@ package controller import ( - "cmp" "context" "fmt" "strings" egv1a1 "github.com/envoyproxy/gateway/api/v1alpha1" "github.com/go-logr/logr" - corev1 "k8s.io/api/core/v1" apierrors "k8s.io/apimachinery/pkg/api/errors" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/client-go/kubernetes" @@ -499,26 +497,76 @@ func mcpBackendRefFilterName(mcpRoute *aigv1a1.MCPRoute, backendName gwapiv1.Obj // which is set by the MCP proxy based on its routing logic. // This route rule will eventually be moved to the backend listener in the extension server. 
func (c *MCPRouteController) mcpBackendRefToHTTPRouteRule(ctx context.Context, mcpRoute *aigv1a1.MCPRoute, ref *aigv1a1.MCPRouteBackendRef) (gwapiv1.HTTPRouteRule, error) { - var apiKey *aigv1a1.MCPBackendAPIKey - if ref.SecurityPolicy != nil && ref.SecurityPolicy.APIKey != nil { - apiKey = ref.SecurityPolicy.APIKey - } - // Ensure the HTTPRouteFilter for this backend with its optional security configuration. - filterName := mcpBackendRefFilterName(mcpRoute, ref.Name) - err := c.ensureMCPBackendRefHTTPFilter(ctx, filterName, apiKey, mcpRoute) + egFilterName := mcpBackendRefFilterName(mcpRoute, ref.Name) + err := c.ensureMCPBackendRefHTTPFilter(ctx, egFilterName, mcpRoute) if err != nil { return gwapiv1.HTTPRouteRule{}, fmt.Errorf("failed to ensure MCP backend API key HTTP filter: %w", err) } + filters := []gwapiv1.HTTPRouteFilter{ + { + Type: gwapiv1.HTTPRouteFilterExtensionRef, + ExtensionRef: &gwapiv1.LocalObjectReference{ + Group: "gateway.envoyproxy.io", + Kind: "HTTPRouteFilter", + Name: gwapiv1.ObjectName(egFilterName), + }, + }, + } + + fullPathPtr := ptr.Deref(ref.Path, defaultMCPPath) - filters := []gwapiv1.HTTPRouteFilter{{ - Type: gwapiv1.HTTPRouteFilterExtensionRef, - ExtensionRef: &gwapiv1.LocalObjectReference{ - Group: "gateway.envoyproxy.io", - Kind: "HTTPRouteFilter", - Name: gwapiv1.ObjectName(filterName), + // Add credential injection if apiKey is specified. 
+ if ref.SecurityPolicy != nil && ref.SecurityPolicy.APIKey != nil { + apiKey := ref.SecurityPolicy.APIKey + + apiKeyLiteral, err := c.readAPIKey(ctx, mcpRoute.Namespace, apiKey) + if err != nil { + return gwapiv1.HTTPRouteRule{}, fmt.Errorf("failed to read API key for backend %s: %w", ref.Name, err) + } + switch { + case apiKey.QueryParam != nil: + fullPathPtr = fmt.Sprintf("%s?%s=%s", fullPathPtr, *apiKey.QueryParam, apiKeyLiteral) + case apiKey.Header != nil: + header := *apiKey.Header + if header == "Authorization" { + apiKeyLiteral = "Bearer " + apiKeyLiteral + } + filters = append(filters, + gwapiv1.HTTPRouteFilter{ + Type: gwapiv1.HTTPRouteFilterRequestHeaderModifier, + RequestHeaderModifier: &gwapiv1.HTTPHeaderFilter{ + Set: []gwapiv1.HTTPHeader{ + {Name: gwapiv1.HTTPHeaderName(header), Value: apiKeyLiteral}, + }, + }, + }, + ) + default: + filters = append(filters, + gwapiv1.HTTPRouteFilter{ + Type: gwapiv1.HTTPRouteFilterRequestHeaderModifier, + RequestHeaderModifier: &gwapiv1.HTTPHeaderFilter{ + Set: []gwapiv1.HTTPHeader{ + {Name: "Authorization", Value: "Bearer " + apiKeyLiteral}, + }, + }, + }, + ) + } + } + + filters = append(filters, + gwapiv1.HTTPRouteFilter{ + Type: gwapiv1.HTTPRouteFilterURLRewrite, + URLRewrite: &gwapiv1.HTTPURLRewriteFilter{ + Path: &gwapiv1.HTTPPathModifier{ + Type: gwapiv1.FullPathHTTPPathModifier, + ReplaceFullPath: ptr.To(fullPathPtr), + }, + }, }, - }} + ) return gwapiv1.HTTPRouteRule{ Matches: []gwapiv1.HTTPRouteMatch{ { @@ -555,7 +603,7 @@ func mcpRouteHeaderValue(mcpRoute *aigv1a1.MCPRoute) string { } // ensureMCPBackendRefHTTPFilter ensures that an HTTPRouteFilter exists for the given backend reference in the MCPRoute. 
-func (c *MCPRouteController) ensureMCPBackendRefHTTPFilter(ctx context.Context, filterName string, apiKey *aigv1a1.MCPBackendAPIKey, mcpRoute *aigv1a1.MCPRoute) error { +func (c *MCPRouteController) ensureMCPBackendRefHTTPFilter(ctx context.Context, filterName string, mcpRoute *aigv1a1.MCPRoute) error { // Rewrite the hostname to the backend service name. // This allows Envoy to route to public MCP services with SNI matching the service name. // This could be a standalone filter and moved to the main mcp gateway route logic. @@ -575,25 +623,6 @@ func (c *MCPRouteController) ensureMCPBackendRefHTTPFilter(ctx context.Context, if err := ctrlutil.SetControllerReference(mcpRoute, filter, c.client.Scheme()); err != nil { return fmt.Errorf("failed to set controller reference for HTTPRouteFilter: %w", err) } - - // add credential injection if apiKey is specified. - if apiKey != nil { - secretName := fmt.Sprintf("%s-credential", filterName) - if secretErr := c.ensureCredentialSecret(ctx, mcpRoute.Namespace, secretName, apiKey, mcpRoute); secretErr != nil { - return fmt.Errorf("failed to ensure credential secret: %w", secretErr) - } - header := cmp.Or(ptr.Deref(apiKey.Header, ""), "Authorization") - filter.Spec.CredentialInjection = &egv1a1.HTTPCredentialInjectionFilter{ - Header: ptr.To(header), - Overwrite: ptr.To(true), - Credential: egv1a1.InjectedCredential{ - ValueRef: gwapiv1.SecretObjectReference{ - Name: gwapiv1.ObjectName(secretName), - }, - }, - } - } - // Create or Update the HTTPRouteFilter. 
var existingFilter egv1a1.HTTPRouteFilter err := c.client.Get(ctx, client.ObjectKey{Name: filterName, Namespace: mcpRoute.Namespace}, &existingFilter) @@ -617,64 +646,19 @@ func (c *MCPRouteController) ensureMCPBackendRefHTTPFilter(ctx context.Context, return nil } -func (c *MCPRouteController) ensureCredentialSecret(ctx context.Context, namespace, secretName string, apiKey *aigv1a1.MCPBackendAPIKey, mcpRoute *aigv1a1.MCPRoute) error { - var credentialValue string +func (c *MCPRouteController) readAPIKey(ctx context.Context, namespace string, apiKey *aigv1a1.MCPBackendAPIKey) (string, error) { key := ptr.Deref(apiKey.Inline, "") if key == "" { secretRef := apiKey.SecretRef secret, err := c.kube.CoreV1().Secrets(namespace).Get(ctx, string(secretRef.Name), metav1.GetOptions{}) if err != nil { - return fmt.Errorf("failed to get secret for API key: %w", err) + return "", fmt.Errorf("failed to get secret for API key: %w", err) } if k, ok := secret.Data["apiKey"]; ok { key = string(k) } else if key, ok = secret.StringData["apiKey"]; !ok { - return fmt.Errorf("secret %s/%s does not contain 'apiKey' key", namespace, secretRef.Name) + return "", fmt.Errorf("secret %s/%s does not contain 'apiKey' key", namespace, secretRef.Name) } } - - // Only prepend the "Bearer " prefix if the header is not set or is set to "Authorization". 
- header := cmp.Or(ptr.Deref(apiKey.Header, ""), "Authorization") - if header == "Authorization" { - credentialValue = fmt.Sprintf("Bearer %s", key) - } else { - credentialValue = key - } - - existingSecret, secretErr := c.kube.CoreV1().Secrets(namespace).Get(ctx, secretName, metav1.GetOptions{}) - if secretErr != nil && !apierrors.IsNotFound(secretErr) { - return fmt.Errorf("failed to get credential secret: %w", secretErr) - } - - secretData := map[string][]byte{ - egv1a1.InjectedCredentialKey: []byte(credentialValue), - } - - if apierrors.IsNotFound(secretErr) { - secret := &corev1.Secret{ - ObjectMeta: metav1.ObjectMeta{ - Name: secretName, - Namespace: namespace, - }, - Data: secretData, - } - - if mcpRoute != nil { - if err := ctrlutil.SetControllerReference(mcpRoute, secret, c.client.Scheme()); err != nil { - return fmt.Errorf("failed to set controller reference for credential secret: %w", err) - } - } - - c.logger.Info("Creating credential secret", "namespace", namespace, "name", secretName) - if _, err := c.kube.CoreV1().Secrets(namespace).Create(ctx, secret, metav1.CreateOptions{}); err != nil { - return fmt.Errorf("failed to create credential secret: %w", err) - } - } else if existingSecret.Data == nil || string(existingSecret.Data[egv1a1.InjectedCredentialKey]) != credentialValue { - existingSecret.Data = secretData - c.logger.Info("Updating credential secret", "namespace", namespace, "name", secretName) - if _, err := c.kube.CoreV1().Secrets(namespace).Update(ctx, existingSecret, metav1.UpdateOptions{}); err != nil { - return fmt.Errorf("failed to update credential secret: %w", secretErr) - } - } - return nil + return key, nil } diff --git a/internal/controller/mcp_route_test.go b/internal/controller/mcp_route_test.go index b98de8eabd..1e9bfc64d3 100644 --- a/internal/controller/mcp_route_test.go +++ b/internal/controller/mcp_route_test.go @@ -268,13 +268,36 @@ func TestMCPRouteController_mcpRuleWithAPIKeyBackendSecurity(t *testing.T) { ctrlr := 
NewMCPRouteController(c, kubeClient, logr.Discard(), eventCh.Ch) tests := []struct { - name string - header *string - wantHeader string - wantCredential []byte + name string + key *aigv1a1.MCPBackendAPIKey + expRequestHeader *internalapi.Header + refPath *string + expPath string }{ - {"default header", nil, "Authorization", []byte("Bearer secretvalue")}, - {"custom header", ptr.To("X-Api-Key"), "X-Api-Key", []byte("secretvalue")}, + { + name: "inline API key default header", + key: &aigv1a1.MCPBackendAPIKey{Inline: ptr.To("inline-key")}, + expRequestHeader: &internalapi.Header{"Authorization", "Bearer inline-key"}, + expPath: "/mcp", + }, + { + name: "inline API key custom header", + key: &aigv1a1.MCPBackendAPIKey{Inline: ptr.To("inline-key"), Header: ptr.To("X-API-KEY")}, + expRequestHeader: &internalapi.Header{"X-API-KEY", "inline-key"}, + expPath: "/mcp", + }, + { + name: "secret ref API key default header", + key: &aigv1a1.MCPBackendAPIKey{SecretRef: &gwapiv1.SecretObjectReference{Name: "some-secret"}}, + expRequestHeader: &internalapi.Header{"Authorization", "Bearer secretvalue"}, + refPath: ptr.To("/some/path"), + expPath: "/some/path", + }, + { + name: "query param API key", + key: &aigv1a1.MCPBackendAPIKey{Inline: ptr.To("inline-key"), QueryParam: ptr.To("api_key")}, + expPath: "/mcp?api_key=inline-key", + }, } for _, tt := range tests { @@ -286,14 +309,8 @@ func TestMCPRouteController_mcpRuleWithAPIKeyBackendSecurity(t *testing.T) { Name: "svc-a", Namespace: ptr.To(gwapiv1.Namespace("default")), }, - SecurityPolicy: &aigv1a1.MCPBackendSecurityPolicy{ - APIKey: &aigv1a1.MCPBackendAPIKey{ - Header: tt.header, - SecretRef: &gwapiv1.SecretObjectReference{ - Name: "some-secret", - }, - }, - }, + SecurityPolicy: &aigv1a1.MCPBackendSecurityPolicy{APIKey: tt.key}, + Path: tt.refPath, }, ) require.NoError(t, err) @@ -305,28 +322,44 @@ func TestMCPRouteController_mcpRuleWithAPIKeyBackendSecurity(t *testing.T) { require.Equal(t, "svc-a", headers[0].Value) 
require.Equal(t, internalapi.MCPRouteHeader, string(headers[1].Name)) require.Contains(t, headers[1].Value, "route-a") - require.Len(t, httpRule.Filters, 1) - require.Equal(t, gwapiv1.HTTPRouteFilterExtensionRef, httpRule.Filters[0].Type) - require.NotNil(t, httpRule.Filters[0].ExtensionRef) - require.Equal(t, gwapiv1.Group("gateway.envoyproxy.io"), httpRule.Filters[0].ExtensionRef.Group) - require.Equal(t, gwapiv1.Kind("HTTPRouteFilter"), httpRule.Filters[0].ExtensionRef.Kind) - require.Contains(t, string(httpRule.Filters[0].ExtensionRef.Name), internalapi.MCPPerBackendHTTPRouteFilterPrefix) + // The first filter is the EG extension ref filter for URL host rewrite. + egFilter := httpRule.Filters[0] + require.Equal(t, gwapiv1.HTTPRouteFilterExtensionRef, egFilter.Type) + require.NotNil(t, egFilter.ExtensionRef) + require.Equal(t, gwapiv1.Group("gateway.envoyproxy.io"), egFilter.ExtensionRef.Group) + require.Equal(t, gwapiv1.Kind("HTTPRouteFilter"), egFilter.ExtensionRef.Kind) + require.Contains(t, string(egFilter.ExtensionRef.Name), internalapi.MCPPerBackendHTTPRouteFilterPrefix) var httpFilter egv1a1.HTTPRouteFilter - err = c.Get(t.Context(), types.NamespacedName{Namespace: "default", Name: string(httpRule.Filters[0].ExtensionRef.Name)}, &httpFilter) - require.NoError(t, err) - require.NotNil(t, httpFilter.Spec.CredentialInjection) - require.Equal(t, tt.wantHeader, ptr.Deref(httpFilter.Spec.CredentialInjection.Header, "")) - require.Equal(t, httpFilter.Name+"-credential", string(httpFilter.Spec.CredentialInjection.Credential.ValueRef.Name)) - - secret, err := kubeClient.CoreV1().Secrets("default").Get(t.Context(), - string(httpFilter.Spec.CredentialInjection.Credential.ValueRef.Name), metav1.GetOptions{}) + err = c.Get(t.Context(), types.NamespacedName{Namespace: "default", Name: string(egFilter.ExtensionRef.Name)}, &httpFilter) require.NoError(t, err) - require.Equal(t, tt.wantCredential, secret.Data[egv1a1.InjectedCredentialKey]) - require.NotNil(t, 
httpFilter.Spec.URLRewrite) require.NotNil(t, httpFilter.Spec.URLRewrite.Hostname) require.Equal(t, egv1a1.BackendHTTPHostnameModifier, httpFilter.Spec.URLRewrite.Hostname.Type) + + if tt.expRequestHeader != nil { + // The second filter is the request header modifier for API key injection. + reqHeaderFilter := httpRule.Filters[1] + require.Equal(t, gwapiv1.HTTPRouteFilterRequestHeaderModifier, reqHeaderFilter.Type) + require.NotNil(t, reqHeaderFilter.RequestHeaderModifier) + found := false + for _, set := range reqHeaderFilter.RequestHeaderModifier.Set { + if set.Name == gwapiv1.HTTPHeaderName(tt.expRequestHeader.Key()) && + set.Value == tt.expRequestHeader.Value() { + found = true + break + } + } + require.Truef(t, found, "Expected request header modifier not found in %v", reqHeaderFilter.RequestHeaderModifier.Set) + } + + // Verify the last filter is the path rewrite filter. + pathRewriteFilter := httpRule.Filters[len(httpRule.Filters)-1] + require.Equal(t, gwapiv1.HTTPRouteFilterURLRewrite, pathRewriteFilter.Type) + require.NotNil(t, pathRewriteFilter.URLRewrite) + require.NotNil(t, pathRewriteFilter.URLRewrite.Path) + require.Equal(t, gwapiv1.FullPathHTTPPathModifier, pathRewriteFilter.URLRewrite.Path.Type) + require.Equal(t, tt.expPath, *pathRewriteFilter.URLRewrite.Path.ReplaceFullPath) }) } } @@ -348,31 +381,13 @@ func TestMCPRouteController_ensureMCPBackendRefHTTPFilter(t *testing.T) { require.NoError(t, err) filterName := mcpBackendRefFilterName(mcpRoute, "some-name") - err = ctrlr.ensureMCPBackendRefHTTPFilter(t.Context(), filterName, &aigv1a1.MCPBackendAPIKey{ - SecretRef: &gwapiv1.SecretObjectReference{ - Name: "test-secret", - }, - }, mcpRoute) + err = ctrlr.ensureMCPBackendRefHTTPFilter(t.Context(), filterName, mcpRoute) require.NoError(t, err) // Verify HTTPRouteFilter was created. 
var httpFilter egv1a1.HTTPRouteFilter err = c.Get(t.Context(), types.NamespacedName{Namespace: "default", Name: filterName}, &httpFilter) require.NoError(t, err) - - // Verify filter has credential injection configured. - require.NotNil(t, httpFilter.Spec.CredentialInjection) - require.Equal(t, "Authorization", ptr.Deref(httpFilter.Spec.CredentialInjection.Header, "")) - require.Equal(t, filterName+"-credential", string(httpFilter.Spec.CredentialInjection.Credential.ValueRef.Name)) - - // Update the route without API key and ensure the filter is deleted. - err = ctrlr.ensureMCPBackendRefHTTPFilter(t.Context(), filterName, nil, mcpRoute) - require.NoError(t, err) - - // Check that the HTTPRouteFilter doesn't have CredentialInjection anymore. - err = c.Get(t.Context(), types.NamespacedName{Namespace: "default", Name: filterName}, &httpFilter) - require.NoError(t, err) - require.Nil(t, httpFilter.Spec.CredentialInjection) } func TestMCPRouteController_syncGateways_NamespaceCrossReference(t *testing.T) { diff --git a/internal/filterapi/mcpconfig.go b/internal/filterapi/mcpconfig.go index d70a231860..5e119b81a1 100644 --- a/internal/filterapi/mcpconfig.go +++ b/internal/filterapi/mcpconfig.go @@ -38,9 +38,6 @@ type MCPBackend struct { // This name is set in [internalapi.MCPBackendHeader] header to route the request to the specific backend. Name MCPBackendName `json:"name"` - // Path is the HTTP endpoint path of the backend MCP server. - Path string `json:"path"` - // ToolSelector filters the tools exposed by this backend. If not set, all tools are exposed. 
ToolSelector *MCPToolSelector `json:"toolSelector,omitempty"` } diff --git a/internal/mcpproxy/config_test.go b/internal/mcpproxy/config_test.go index 7fd7576db1..e0ad4d4a0a 100644 --- a/internal/mcpproxy/config_test.go +++ b/internal/mcpproxy/config_test.go @@ -77,9 +77,9 @@ func TestLoadConfig_BasicConfiguration(t *testing.T) { { Name: "route1", Backends: []filterapi.MCPBackend{ - {Name: "backend1", Path: "/mcp1"}, + {Name: "backend1"}, { - Name: "backend2", Path: "/mcp2", + Name: "backend2", ToolSelector: &filterapi.MCPToolSelector{ Include: []string{"tool1", "tool2"}, IncludeRegex: []string{"^test.*"}, @@ -90,8 +90,8 @@ func TestLoadConfig_BasicConfiguration(t *testing.T) { { Name: "route2", Backends: []filterapi.MCPBackend{ - {Name: "backend3", Path: "/mcp3"}, - {Name: "backend4", Path: "/mcp4"}, + {Name: "backend3"}, + {Name: "backend4"}, }, }, }, @@ -130,7 +130,7 @@ func TestLoadConfig_ToolsChangedNotification(t *testing.T) { routes: map[filterapi.MCPRouteName]*mcpProxyConfigRoute{ "route1": { backends: map[filterapi.MCPBackendName]filterapi.MCPBackend{ - "backend1": {Name: "backend1", Path: "/mcp1"}, + "backend1": {Name: "backend1"}, }, toolSelectors: map[filterapi.MCPBackendName]*toolSelector{}, }, @@ -147,8 +147,8 @@ func TestLoadConfig_ToolsChangedNotification(t *testing.T) { { Name: "route1", Backends: []filterapi.MCPBackend{ - {Name: "backend1", Path: "/mcp1"}, - {Name: "backend2", Path: "/mcp2"}, // Added backend + {Name: "backend1"}, + {Name: "backend2"}, // Added backend }, }, }, @@ -178,7 +178,7 @@ func TestLoadConfig_NoToolsChangedNotification(t *testing.T) { routes: map[filterapi.MCPRouteName]*mcpProxyConfigRoute{ "route1": { backends: map[filterapi.MCPBackendName]filterapi.MCPBackend{ - "backend1": {Name: "backend1", Path: "/mcp1"}, + "backend1": {Name: "backend1"}, }, toolSelectors: map[filterapi.MCPBackendName]*toolSelector{}, }, @@ -195,7 +195,7 @@ func TestLoadConfig_NoToolsChangedNotification(t *testing.T) { { Name: "route1", Backends: 
[]filterapi.MCPBackend{ - {Name: "backend1", Path: "/mcp1"}, // Same backend + {Name: "backend1"}, // Same backend }, }, }, @@ -229,7 +229,6 @@ func TestLoadConfig_InvalidRegex(t *testing.T) { Backends: []filterapi.MCPBackend{ { Name: "backend1", - Path: "/mcp1", ToolSelector: &filterapi.MCPToolSelector{ IncludeRegex: []string{"[invalid"}, // Invalid regex }, @@ -256,7 +255,7 @@ func TestLoadConfig_ToolSelectorChange(t *testing.T) { routes: map[filterapi.MCPRouteName]*mcpProxyConfigRoute{ "route1": { backends: map[filterapi.MCPBackendName]filterapi.MCPBackend{ - "backend1": {Name: "backend1", Path: "/mcp1"}, + "backend1": {Name: "backend1"}, }, toolSelectors: map[filterapi.MCPBackendName]*toolSelector{ "backend1": { @@ -279,7 +278,6 @@ func TestLoadConfig_ToolSelectorChange(t *testing.T) { Backends: []filterapi.MCPBackend{ { Name: "backend1", - Path: "/mcp1", ToolSelector: &filterapi.MCPToolSelector{ Include: []string{"tool1", "tool2"}, // Different tools }, @@ -315,11 +313,11 @@ func TestLoadConfig_ToolOrderDoesNotMatter(t *testing.T) { // Initialize proxy with initial configuration directly proxy := &ProxyConfig{ mcpProxyConfig: &mcpProxyConfig{ - backendListenerAddr: "http://localhost:8080", + backendListenerAddr: "http://localhost:8080/", routes: map[filterapi.MCPRouteName]*mcpProxyConfigRoute{ "route1": { backends: map[filterapi.MCPBackendName]filterapi.MCPBackend{ - "backend1": {Name: "backend1", Path: "/mcp1"}, + "backend1": {Name: "backend1"}, }, toolSelectors: map[filterapi.MCPBackendName]*toolSelector{ "backend1": { @@ -351,7 +349,6 @@ func TestLoadConfig_ToolOrderDoesNotMatter(t *testing.T) { Backends: []filterapi.MCPBackend{ { Name: "backend1", - Path: "/mcp1", ToolSelector: &filterapi.MCPToolSelector{ Include: []string{"tool-c", "tool-a", "tool-b"}, // Different order IncludeRegex: []string{"^exact$", ".*suffix$", "^prefix.*"}, // Different order diff --git a/internal/mcpproxy/handlers.go b/internal/mcpproxy/handlers.go index 5a953651bd..af43188667 
100644 --- a/internal/mcpproxy/handlers.go +++ b/internal/mcpproxy/handlers.go @@ -1021,10 +1021,6 @@ func (m *mcpRequestContext) recordResponse(ctx context.Context, rawMsg jsonrpc.M } } -func (m *mcpRequestContext) mcpEndpointForBackend(backend filterapi.MCPBackend) string { - return m.backendListenerAddr + backend.Path -} - func (m *mcpRequestContext) handleResourceReadRequest(ctx context.Context, s *session, w http.ResponseWriter, req *jsonrpc.Request, p *mcp.ReadResourceParams) error { backendName, resourceName, err := upstreamResourceURI(p.URI) if err != nil { diff --git a/internal/mcpproxy/handlers_test.go b/internal/mcpproxy/handlers_test.go index 356dab28af..ef9710e6c7 100644 --- a/internal/mcpproxy/handlers_test.go +++ b/internal/mcpproxy/handlers_test.go @@ -59,13 +59,13 @@ func newTestMCPProxyWithTracer(t tracingapi.MCPTracer) *mcpRequestContext { "backend1": {include: map[string]struct{}{"test-tool": {}}}, }, backends: map[filterapi.MCPBackendName]filterapi.MCPBackend{ - "backend1": {Name: "backend1", Path: "/mcp"}, - "backend2": {Name: "backend2", Path: "/"}, + "backend1": {Name: "backend1"}, + "backend2": {Name: "backend2"}, }, }, "test-route-another": { backends: map[filterapi.MCPBackendName]filterapi.MCPBackend{ - "backend3": {Name: "backend3", Path: "/mcp"}, + "backend3": {Name: "backend3"}, }, }, }, diff --git a/internal/mcpproxy/mcpproxy.go b/internal/mcpproxy/mcpproxy.go index 9fcf0bf8ae..68cb4c8a95 100644 --- a/internal/mcpproxy/mcpproxy.go +++ b/internal/mcpproxy/mcpproxy.go @@ -357,7 +357,7 @@ func (m *mcpRequestContext) invokeJSONRPCRequest(ctx context.Context, routeName if err != nil { return nil, fmt.Errorf("failed to encode MCP message: %w", err) } - req, err := http.NewRequestWithContext(ctx, http.MethodPost, m.mcpEndpointForBackend(backend), bytes.NewReader(encoded)) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, m.backendListenerAddr, bytes.NewReader(encoded)) if err != nil { return nil, fmt.Errorf("failed to create MCP 
notifications/initialized request: %w", err) } diff --git a/internal/mcpproxy/mcpproxy_test.go b/internal/mcpproxy/mcpproxy_test.go index 7d268715b0..8b3dd5b5e4 100644 --- a/internal/mcpproxy/mcpproxy_test.go +++ b/internal/mcpproxy/mcpproxy_test.go @@ -354,7 +354,7 @@ func TestInitializeSession_Success(t *testing.T) { proxy := newTestMCPProxy() proxy.backendListenerAddr = backendServer.URL - res, err := proxy.initializeSession(t.Context(), "route1", filterapi.MCPBackend{Name: "test-backend", Path: "/a/b/c"}, &mcp.InitializeParams{}) + res, err := proxy.initializeSession(t.Context(), "route1", filterapi.MCPBackend{Name: "test-backend"}, &mcp.InitializeParams{}) require.NoError(t, err) require.Equal(t, gatewayToMCPServerSessionID("test-session-123"), res.sessionID) @@ -372,7 +372,7 @@ func TestInitializeSession_InitializeFailure(t *testing.T) { proxy := newTestMCPProxy() proxy.backendListenerAddr = backendServer.URL - sessionID, err := proxy.initializeSession(t.Context(), "route1", filterapi.MCPBackend{Name: "test-backend", Path: "/a/b/c"}, &mcp.InitializeParams{}) + sessionID, err := proxy.initializeSession(t.Context(), "route1", filterapi.MCPBackend{Name: "test-backend"}, &mcp.InitializeParams{}) require.Error(t, err) require.Empty(t, sessionID) @@ -400,7 +400,7 @@ func TestInitializeSession_NotificationsInitializedFailure(t *testing.T) { proxy := newTestMCPProxy() proxy.backendListenerAddr = backendServer.URL - sessionID, err := proxy.initializeSession(t.Context(), "route1", filterapi.MCPBackend{Name: "test-backend", Path: "/aaaaaaaaaaaaaa"}, &mcp.InitializeParams{}) + sessionID, err := proxy.initializeSession(t.Context(), "route1", filterapi.MCPBackend{Name: "test-backend"}, &mcp.InitializeParams{}) require.Error(t, err) require.Empty(t, sessionID) @@ -409,7 +409,7 @@ func TestInitializeSession_NotificationsInitializedFailure(t *testing.T) { func TestInvokeJSONRPCRequest_Success(t *testing.T) { backendServer := httptest.NewServer(http.HandlerFunc(func(w 
http.ResponseWriter, r *http.Request) { - require.Equal(t, "/aaaaaaaaaaaaaa", r.URL.Path) + require.Equal(t, "/", r.URL.Path) require.Equal(t, "test-backend", r.Header.Get("x-ai-eg-mcp-backend")) require.Equal(t, "test-session", r.Header.Get(sessionIDHeader)) require.Equal(t, "application/json", r.Header.Get("Content-Type")) @@ -420,7 +420,7 @@ func TestInvokeJSONRPCRequest_Success(t *testing.T) { m := newTestMCPProxy() m.backendListenerAddr = backendServer.URL - resp, err := m.invokeJSONRPCRequest(t.Context(), "route1", filterapi.MCPBackend{Name: "test-backend", Path: "/aaaaaaaaaaaaaa"}, &compositeSessionEntry{ + resp, err := m.invokeJSONRPCRequest(t.Context(), "route1", filterapi.MCPBackend{Name: "test-backend"}, &compositeSessionEntry{ sessionID: "test-session", }, &jsonrpc.Request{}) @@ -432,8 +432,7 @@ func TestInvokeJSONRPCRequest_Success(t *testing.T) { func TestInvokeJSONRPCRequest_NoSessionID(t *testing.T) { backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Check the path equals /mcp. 
- require.Equal(t, "/mcp", r.URL.Path) + require.Equal(t, "/", r.URL.Path) require.Equal(t, "test-backend", r.Header.Get("x-ai-eg-mcp-backend")) require.Empty(t, r.Header.Get(sessionIDHeader)) w.WriteHeader(http.StatusOK) @@ -443,7 +442,7 @@ func TestInvokeJSONRPCRequest_NoSessionID(t *testing.T) { m := newTestMCPProxy() m.backendListenerAddr = backendServer.URL - resp, err := m.invokeJSONRPCRequest(t.Context(), "route1", filterapi.MCPBackend{Name: "test-backend", Path: "/mcp"}, &compositeSessionEntry{ + resp, err := m.invokeJSONRPCRequest(t.Context(), "route1", filterapi.MCPBackend{Name: "test-backend"}, &compositeSessionEntry{ sessionID: "", }, &jsonrpc.Request{}) diff --git a/internal/mcpproxy/session.go b/internal/mcpproxy/session.go index 35d30d5bb4..65eca30736 100644 --- a/internal/mcpproxy/session.go +++ b/internal/mcpproxy/session.go @@ -55,17 +55,7 @@ func (s *session) Close() error { // Stateless backend, nothing to do. continue } - // Make DELETE request to the MCP server to close the session. 
- backend, err := s.reqCtx.getBackendForRoute(s.route, backendName) - if err != nil { - s.reqCtx.l.Error("failed to get backend for route", - slog.String("backend", backendName), - slog.String("session_id", string(sessionID)), - slog.String("error", err.Error()), - ) - continue - } - req, err := http.NewRequest(http.MethodDelete, s.reqCtx.mcpEndpointForBackend(backend), nil) + req, err := http.NewRequest(http.MethodDelete, s.reqCtx.backendListenerAddr, nil) if err != nil { s.reqCtx.l.Error("failed to create DELETE request to MCP server to close session", slog.String("backend", backendName), @@ -334,7 +324,7 @@ func (s *session) sendRequestPerBackend(ctx context.Context, eventChan chan<- *s body = bytes.NewReader(encodedReq) } - req, err := http.NewRequestWithContext(ctx, httpMethod, s.reqCtx.mcpEndpointForBackend(backend), body) + req, err := http.NewRequestWithContext(ctx, httpMethod, s.reqCtx.backendListenerAddr, body) if err != nil { return fmt.Errorf("failed to create GET request: %w", err) } diff --git a/internal/mcpproxy/session_test.go b/internal/mcpproxy/session_test.go index 7e4799a1cd..7bca61757f 100644 --- a/internal/mcpproxy/session_test.go +++ b/internal/mcpproxy/session_test.go @@ -131,7 +131,7 @@ func TestSendRequestPerBackend_SetsOriginalPathHeaders(t *testing.T) { ch := make(chan *sseEvent, 1) ctx, cancel := context.WithTimeout(t.Context(), 2*time.Second) defer cancel() - err := s.sendRequestPerBackend(ctx, ch, "test-route", filterapi.MCPBackend{Name: "backend1", Path: "/mcp"}, &compositeSessionEntry{ + err := s.sendRequestPerBackend(ctx, ch, "test-route", filterapi.MCPBackend{Name: "backend1"}, &compositeSessionEntry{ sessionID: "sess1", }, http.MethodGet, nil) require.NoError(t, err) diff --git a/manifests/charts/ai-gateway-crds-helm/templates/aigateway.envoyproxy.io_mcproutes.yaml b/manifests/charts/ai-gateway-crds-helm/templates/aigateway.envoyproxy.io_mcproutes.yaml index 06caed8df9..417901544c 100644 --- 
a/manifests/charts/ai-gateway-crds-helm/templates/aigateway.envoyproxy.io_mcproutes.yaml +++ b/manifests/charts/ai-gateway-crds-helm/templates/aigateway.envoyproxy.io_mcproutes.yaml @@ -147,12 +147,23 @@ spec: defaults to "Authorization". When the header is "Authorization", the injected header value will be prefixed with "Bearer ". + + Either one of Header or QueryParam can be specified to inject the API key. minLength: 1 type: string inline: description: Inline contains the API key as an inline string. type: string + queryParam: + description: |- + QueryParam is the HTTP query parameter to inject the API key into. + For example, if QueryParam is set to "api_key", and the API key is "mysecretkey", the request URL will be modified to include + "?api_key=mysecretkey". + + Either one of Header or QueryParam can be specified to inject the API key. + minLength: 1 + type: string secretRef: description: |- secretRef is the Kubernetes secret which contains the API keys. @@ -202,6 +213,8 @@ spec: - message: exactly one of secretRef or inline must be set rule: (has(self.secretRef) && !has(self.inline)) || (!has(self.secretRef) && has(self.inline)) + - message: only one of header or queryParam can be set + rule: '!(has(self.header) && has(self.queryParam))' type: object toolSelector: description: |- diff --git a/site/docs/api/api.mdx b/site/docs/api/api.mdx index 240eaa93b5..c75c840bc4 100644 --- a/site/docs/api/api.mdx +++ b/site/docs/api/api.mdx @@ -1858,6 +1858,7 @@ MCPAuthorizationTarget defines the target of an authorization rule. - [MCPBackendSecurityPolicy](#mcpbackendsecuritypolicy) MCPBackendAPIKey defines the configuration for the API Key Authentication to a backend. +When both `header` and `queryParam` are unspecified, the API key will be injected into the "Authorization" header by default. 
##### Fields @@ -1877,7 +1878,12 @@ MCPBackendAPIKey defines the configuration for the API Key Authentication to a b name="header" type="string" required="false" - description="Header is the HTTP header to inject the API key into. If not specified,
defaults to `Authorization`.
When the header is `Authorization`, the injected header value will be
prefixed with `Bearer `." + description="Header is the HTTP header to inject the API key into. If not specified,
defaults to `Authorization`.
When the header is `Authorization`, the injected header value will be
prefixed with `Bearer `.
Either one of Header or QueryParam can be specified to inject the API key." +/> diff --git a/tests/crdcel/main_test.go b/tests/crdcel/main_test.go index 1bf9b64afd..3b14bfd448 100644 --- a/tests/crdcel/main_test.go +++ b/tests/crdcel/main_test.go @@ -244,6 +244,10 @@ func TestMCPRoutes(t *testing.T) { name: "backend_api_key_missing.yaml", expErr: "spec.backendRefs[0].securityPolicy.apiKey: Invalid value: \"object\": exactly one of secretRef or inline must be set", }, + { + name: "backend_api_key_both_header_and_query.yaml", + expErr: "only one of header or queryParam can be set", + }, { name: "jwks_missing.yaml", expErr: "spec.securityPolicy.oauth.jwks: Invalid value: \"object\": either remoteJWKS or localJWKS must be specified.", diff --git a/tests/crdcel/testdata/mcpgatewayroutes/backend_api_key_both_header_and_query.yaml b/tests/crdcel/testdata/mcpgatewayroutes/backend_api_key_both_header_and_query.yaml new file mode 100644 index 0000000000..49a1d43b87 --- /dev/null +++ b/tests/crdcel/testdata/mcpgatewayroutes/backend_api_key_both_header_and_query.yaml @@ -0,0 +1,25 @@ +# Copyright Envoy AI Gateway Authors +# SPDX-License-Identifier: Apache-2.0 +# The full text of the Apache license is available in the LICENSE file at +# the root of the repo. + +# This should fail validation: API key has both header and query param specified. 
+apiVersion: aigateway.envoyproxy.io/v1alpha1 +kind: MCPRoute +metadata: + name: backend-apikey-both-header-and-query + namespace: default +spec: + parentRefs: + - name: some-gateway + kind: Gateway + group: gateway.networking.k8s.io + backendRefs: + - name: mcp-service + kind: Service + port: 80 + securityPolicy: + apiKey: + inline: "my-api-key" + header: "X-API-KEY" + queryParam: "api_key" diff --git a/tests/data-plane-mcp/env.go b/tests/data-plane-mcp/env.go index 04e95bea60..908ad2ea68 100644 --- a/tests/data-plane-mcp/env.go +++ b/tests/data-plane-mcp/env.go @@ -88,25 +88,25 @@ func requireNewMCPEnv(t *testing.T, forceJSONResponse bool, writeTimeout time.Du { Name: "test-route", Backends: []filterapi.MCPBackend{ - {Name: "dumb-mcp-backend", Path: "/mcp"}, - {Name: "default-mcp-backend", Path: "/mcp"}, + {Name: "dumb-mcp-backend"}, + {Name: "default-mcp-backend"}, }, }, { Name: "yet-another-route", Backends: []filterapi.MCPBackend{ { - Name: "default-mcp-backend", Path: "/mcp", + Name: "default-mcp-backend", // This shouldn't affect any other routes. 
ToolSelector: &filterapi.MCPToolSelector{Include: []string{"non-existent"}}, }, - {Name: "dumb-mcp-backend", Path: "/mcp"}, + {Name: "dumb-mcp-backend"}, }, }, { Name: "awesome-route", Backends: []filterapi.MCPBackend{ - {Name: "dumb-mcp-backend", Path: "/mcp"}, + {Name: "dumb-mcp-backend"}, }, }, }, diff --git a/tests/data-plane-mcp/envoy.yaml b/tests/data-plane-mcp/envoy.yaml index be305d900b..9497699783 100644 --- a/tests/data-plane-mcp/envoy.yaml +++ b/tests/data-plane-mcp/envoy.yaml @@ -128,6 +128,7 @@ static_resources: cluster: dumb-mcp-backend idleTimeout: 3600s timeout: 120s + pathRewrite: /mcp - match: prefix: "/" headers: @@ -138,6 +139,7 @@ static_resources: cluster: default-mcp-backend idleTimeout: 3600s timeout: 120s + pathRewrite: /mcp - match: prefix: "/" headers: @@ -148,6 +150,7 @@ static_resources: autoHostRewrite: true cluster: context7 idleTimeout: 3600s + pathRewrite: /mcp timeout: 120s - match: prefix: "/" @@ -159,6 +162,7 @@ static_resources: autoHostRewrite: true cluster: github idleTimeout: 3600s + pathRewrite: /mcp/readonly timeout: 120s request_headers_to_add: - header: @@ -174,6 +178,7 @@ static_resources: autoHostRewrite: true cluster: kiwi idleTimeout: 3600s + pathRewrite: / timeout: 120s http_filters: - name: envoy.filters.http.header_to_metadata diff --git a/tests/data-plane-mcp/publicmcp_test.go b/tests/data-plane-mcp/publicmcp_test.go index 8ea00379db..f57a6bb769 100644 --- a/tests/data-plane-mcp/publicmcp_test.go +++ b/tests/data-plane-mcp/publicmcp_test.go @@ -29,8 +29,8 @@ func TestPublicMCPServers(t *testing.T) { { Name: "test-route", Backends: []filterapi.MCPBackend{ - {Name: "context7", Path: "/mcp"}, - {Name: "kiwi", Path: "/"}, + {Name: "context7"}, + {Name: "kiwi"}, }, }, }, @@ -42,7 +42,6 @@ func TestPublicMCPServers(t *testing.T) { mcpConfig.Routes[0].Backends = append(mcpConfig.Routes[0].Backends, filterapi.MCPBackend{ Name: "github", - Path: "/mcp/readonly", ToolSelector: &filterapi.MCPToolSelector{ IncludeRegex: 
[]string{".*pull_requests?.*", ".*issues?.*"}, }, diff --git a/tests/e2e/mcp_route_test.go b/tests/e2e/mcp_route_test.go index adee300236..fac428f369 100644 --- a/tests/e2e/mcp_route_test.go +++ b/tests/e2e/mcp_route_test.go @@ -52,7 +52,7 @@ func TestMCP(t *testing.T) { }) t.Run("tenant route with another path suffix", func(t *testing.T) { testMCPRouteTools(t.Context(), t, client, fwd.Address(), "/mcp/another", []string{ - "mcp-backend__sum", + "mcp-backend-query-api-key__sum", }, nil, false, true) }) t.Run("tenant route with different path", func(t *testing.T) { diff --git a/tests/e2e/testdata/mcp_route.yaml b/tests/e2e/testdata/mcp_route.yaml index 6d18804377..1a04c25243 100644 --- a/tests/e2e/testdata/mcp_route.yaml +++ b/tests/e2e/testdata/mcp_route.yaml @@ -86,6 +86,47 @@ spec: port: 1063 targetPort: 1063 type: ClusterIP +--- +apiVersion: apps/v1 +kind: Deployment +metadata: + name: mcp-backend-query-api-key + namespace: default +spec: + replicas: 1 + selector: + matchLabels: + app: mcp-backend-query-api-key + template: + metadata: + labels: + app: mcp-backend-query-api-key + spec: + containers: + - name: mcp-backend + image: docker.io/envoyproxy/ai-gateway-testmcpserver:latest + imagePullPolicy: IfNotPresent + ports: + - containerPort: 1063 + env: + - name: TEST_API_KEY + value: "test-api-key" + - name: TEST_API_KEY_QUERY_PARAM + value: "api_key" +--- +apiVersion: v1 +kind: Service +metadata: + name: mcp-backend-query-api-key + namespace: default +spec: + selector: + app: mcp-backend-query-api-key + ports: + - protocol: TCP + port: 1063 + targetPort: 1063 + type: ClusterIP --- apiVersion: v1 @@ -186,11 +227,12 @@ spec: group: gateway.networking.k8s.io namespace: default backendRefs: - - name: mcp-backend + - name: mcp-backend-query-api-key port: 1063 securityPolicy: apiKey: inline: "test-api-key" + queryParam: "api_key" toolSelector: include: - sum diff --git a/tests/internal/testmcp/server.go b/tests/internal/testmcp/server.go index c767639861..6a683937c7 
100644 --- a/tests/internal/testmcp/server.go +++ b/tests/internal/testmcp/server.go @@ -83,7 +83,11 @@ func NewServer(opts *Options) (*http.Server, *mcp.Server) { }, ) - if apiKey := os.Getenv("TEST_API_KEY"); apiKey != "" { + // Setup API key auth when environment variable TEST_API_KEY is set. + apiKey := os.Getenv("TEST_API_KEY") + apiKeyQueryParam := os.Getenv("TEST_API_KEY_QUERY_PARAM") + // Query param auth takes precedence over header. + if apiKey != "" && apiKeyQueryParam == "" { header := strings.ToLower(cmp.Or(os.Getenv("TEST_API_KEY_HEADER"), "Authorization")) expectedValue := apiKey if header == "authorization" { @@ -119,7 +123,18 @@ func NewServer(opts *Options) (*http.Server, *mcp.Server) { notificationsCounts := newToolNotificationCounts(handlerCounts) mcp.AddTool(s, notificationsCounts.Tool, notificationsCounts.Handler) - handler := mcp.NewStreamableHTTPHandler(func(*http.Request) *mcp.Server { + handler := mcp.NewStreamableHTTPHandler(func(r *http.Request) *mcp.Server { + // Check for API key in query param if configured. + if apiKey != "" && apiKeyQueryParam != "" { + log.Printf("checking for API key in query param %q\n", apiKeyQueryParam) + queryParam := r.URL.Query().Get(apiKeyQueryParam) + if queryParam != apiKey { + // Returning nil will cause 400 response in the current implementation of NewStreamableHTTPHandler. 
+ log.Printf("invalid API key in query param %q: %q\n", apiKeyQueryParam, queryParam) + return nil + } + log.Printf("valid API key in query param %q\n", apiKeyQueryParam) + } return s }, &mcp.StreamableHTTPOptions{JSONResponse: opts.ForceJSONResponse}) From a66942490421e7d9eb3fd31154070b892f31ce1f Mon Sep 17 00:00:00 2001 From: Chang Min Date: Mon, 9 Feb 2026 09:59:37 -0500 Subject: [PATCH 11/11] fix: use constant strings for anthropic api types Signed-off-by: Chang Min --- internal/apischema/anthropic/anthropic.go | 97 ++++++++++++++++------- 1 file changed, 68 insertions(+), 29 deletions(-) diff --git a/internal/apischema/anthropic/anthropic.go b/internal/apischema/anthropic/anthropic.go index ec8d1cef35..0889cd1a07 100644 --- a/internal/apischema/anthropic/anthropic.go +++ b/internal/apischema/anthropic/anthropic.go @@ -251,67 +251,81 @@ type ( } ) +// Content block type constants used by ContentBlockParam and MessagesContentBlock. +const ( + contentBlockTypeText = "text" + contentBlockTypeImage = "image" + contentBlockTypeDocument = "document" + contentBlockTypeSearchResult = "search_result" + contentBlockTypeThinking = "thinking" + contentBlockTypeRedactedThinking = "redacted_thinking" + contentBlockTypeToolUse = "tool_use" + contentBlockTypeToolResult = "tool_result" + contentBlockTypeServerToolUse = "server_tool_use" + contentBlockTypeWebSearchToolResult = "web_search_tool_result" +) + func (m *ContentBlockParam) UnmarshalJSON(data []byte) error { typ := gjson.GetBytes(data, "type") if !typ.Exists() { return errors.New("missing type field in message content block") } switch typ.String() { - case "text": + case contentBlockTypeText: var blockParam TextBlockParam if err := json.Unmarshal(data, &blockParam); err != nil { return fmt.Errorf("failed to unmarshal text blockParam: %w", err) } m.Text = &blockParam - case "image": + case contentBlockTypeImage: var blockParam ImageBlockParam if err := json.Unmarshal(data, &blockParam); err != nil { return 
fmt.Errorf("failed to unmarshal image blockParam: %w", err) } m.Image = &blockParam - case "document": + case contentBlockTypeDocument: var blockParam DocumentBlockParam if err := json.Unmarshal(data, &blockParam); err != nil { return fmt.Errorf("failed to unmarshal document blockParam: %w", err) } m.Document = &blockParam - case "search_result": + case contentBlockTypeSearchResult: var blockParam SearchResultBlockParam if err := json.Unmarshal(data, &blockParam); err != nil { return fmt.Errorf("failed to unmarshal search result blockParam: %w", err) } m.SearchResult = &blockParam - case "thinking": + case contentBlockTypeThinking: var blockParam ThinkingBlockParam if err := json.Unmarshal(data, &blockParam); err != nil { return fmt.Errorf("failed to unmarshal thinking blockParam: %w", err) } m.Thinking = &blockParam - case "redacted_thinking": + case contentBlockTypeRedactedThinking: var blockParam RedactedThinkingBlockParam if err := json.Unmarshal(data, &blockParam); err != nil { return fmt.Errorf("failed to unmarshal redacted thinking blockParam: %w", err) } m.RedactedThinking = &blockParam - case "tool_use": + case contentBlockTypeToolUse: var blockParam ToolUseBlockParam if err := json.Unmarshal(data, &blockParam); err != nil { return fmt.Errorf("failed to unmarshal tool use blockParam: %w", err) } m.ToolUse = &blockParam - case "tool_result": + case contentBlockTypeToolResult: var blockParam ToolResultBlockParam if err := json.Unmarshal(data, &blockParam); err != nil { return fmt.Errorf("failed to unmarshal tool result blockParam: %w", err) } m.ToolResult = &blockParam - case "server_tool_use": + case contentBlockTypeServerToolUse: var blockParam ServerToolUseBlockParam if err := json.Unmarshal(data, &blockParam); err != nil { return fmt.Errorf("failed to unmarshal server tool use blockParam: %w", err) } m.ServerToolUse = &blockParam - case "web_search_tool_result": + case contentBlockTypeWebSearchToolResult: var blockParam WebSearchToolResultBlockParam if 
err := json.Unmarshal(data, &blockParam); err != nil { return fmt.Errorf("failed to unmarshal web search tool result blockParam: %w", err) @@ -463,43 +477,53 @@ type ( } ) +// Tool type constants used by ToolUnion. +const ( + toolTypeCustom = "custom" + toolTypeBash20250124 = "bash_20250124" + toolTypeTextEditor20250124 = "text_editor_20250124" + toolTypeTextEditor20250429 = "text_editor_20250429" + toolTypeTextEditor20250728 = "text_editor_20250728" + toolTypeWebSearch20250305 = "web_search_20250305" +) + func (t *ToolUnion) UnmarshalJSON(data []byte) error { typ := gjson.GetBytes(data, "type") if !typ.Exists() { return errors.New("missing type field in tool") } switch typ.String() { - case "custom": + case toolTypeCustom: var tool Tool if err := json.Unmarshal(data, &tool); err != nil { return fmt.Errorf("failed to unmarshal tool: %w", err) } t.Tool = &tool - case "bash_20250124": + case toolTypeBash20250124: var tool BashTool if err := json.Unmarshal(data, &tool); err != nil { return fmt.Errorf("failed to unmarshal bash tool: %w", err) } t.BashTool = &tool - case "text_editor_20250124": + case toolTypeTextEditor20250124: var tool TextEditorTool20250124 if err := json.Unmarshal(data, &tool); err != nil { return fmt.Errorf("failed to unmarshal text editor tool: %w", err) } t.TextEditorTool20250124 = &tool - case "text_editor_20250429": + case toolTypeTextEditor20250429: var tool TextEditorTool20250429 if err := json.Unmarshal(data, &tool); err != nil { return fmt.Errorf("failed to unmarshal text editor tool: %w", err) } t.TextEditorTool20250429 = &tool - case "text_editor_20250728": + case toolTypeTextEditor20250728: var tool TextEditorTool20250728 if err := json.Unmarshal(data, &tool); err != nil { return fmt.Errorf("failed to unmarshal text editor tool: %w", err) } t.TextEditorTool20250728 = &tool - case "web_search_20250305": + case toolTypeWebSearch20250305: var tool WebSearchTool if err := json.Unmarshal(data, &tool); err != nil { return fmt.Errorf("failed to 
unmarshal web search tool: %w", err) @@ -573,31 +597,39 @@ type ( } ) +// Tool choice type constants used by ToolChoice. +const ( + toolChoiceTypeAuto = "auto" + toolChoiceTypeAny = "any" + toolChoiceTypeTool = "tool" + toolChoiceTypeNone = "none" +) + func (tc *ToolChoice) UnmarshalJSON(data []byte) error { typ := gjson.GetBytes(data, "type") if !typ.Exists() { return errors.New("missing type field in tool choice") } switch typ.String() { - case "auto": + case toolChoiceTypeAuto: var toolChoice ToolChoiceAuto if err := json.Unmarshal(data, &toolChoice); err != nil { return fmt.Errorf("failed to unmarshal tool choice auto: %w", err) } tc.Auto = &toolChoice - case "any": + case toolChoiceTypeAny: var toolChoice ToolChoiceAny if err := json.Unmarshal(data, &toolChoice); err != nil { return fmt.Errorf("failed to unmarshal tool choice any: %w", err) } tc.Any = &toolChoice - case "tool": + case toolChoiceTypeTool: var toolChoice ToolChoiceTool if err := json.Unmarshal(data, &toolChoice); err != nil { return fmt.Errorf("failed to unmarshal tool choice tool: %w", err) } tc.Tool = &toolChoice - case "none": + case toolChoiceTypeNone: var toolChoice ToolChoiceNone if err := json.Unmarshal(data, &toolChoice); err != nil { return fmt.Errorf("failed to unmarshal tool choice none: %w", err) @@ -656,25 +688,32 @@ type ( } ) +// Thinking config type constants used by Thinking. 
+const ( + thinkingConfigTypeEnabled = "enabled" + thinkingConfigTypeDisabled = "disabled" + thinkingConfigTypeAdaptive = "adaptive" +) + func (t *Thinking) UnmarshalJSON(data []byte) error { typ := gjson.GetBytes(data, "type") if !typ.Exists() { return errors.New("missing type field in thinking config") } switch typ.String() { - case "enabled": + case thinkingConfigTypeEnabled: var thinking ThinkingEnabled if err := json.Unmarshal(data, &thinking); err != nil { return fmt.Errorf("failed to unmarshal thinking enabled: %w", err) } t.Enabled = &thinking - case "disabled": + case thinkingConfigTypeDisabled: var thinking ThinkingDisabled if err := json.Unmarshal(data, &thinking); err != nil { return fmt.Errorf("failed to unmarshal thinking disabled: %w", err) } t.Disabled = &thinking - case "adaptive": + case thinkingConfigTypeAdaptive: var thinking ThinkingAdaptive if err := json.Unmarshal(data, &thinking); err != nil { return fmt.Errorf("failed to unmarshal thinking adaptive: %w", err) @@ -851,37 +890,37 @@ func (m *MessagesContentBlock) UnmarshalJSON(data []byte) error { return errors.New("missing type field in message content block") } switch typ.String() { - case "text": + case contentBlockTypeText: var contentBlock TextBlock if err := json.Unmarshal(data, &contentBlock); err != nil { return fmt.Errorf("failed to unmarshal text block: %w", err) } m.Text = &contentBlock - case "tool_use": + case contentBlockTypeToolUse: var contentBlock ToolUseBlock if err := json.Unmarshal(data, &contentBlock); err != nil { return fmt.Errorf("failed to unmarshal tool use block: %w", err) } m.Tool = &contentBlock - case "thinking": + case contentBlockTypeThinking: var contentBlock ThinkingBlock if err := json.Unmarshal(data, &contentBlock); err != nil { return fmt.Errorf("failed to unmarshal thinking block: %w", err) } m.Thinking = &contentBlock - case "redacted_thinking": + case contentBlockTypeRedactedThinking: var contentBlock RedactedThinkingBlock if err := json.Unmarshal(data, 
&contentBlock); err != nil { return fmt.Errorf("failed to unmarshal redacted thinking block: %w", err) } m.RedactedThinking = &contentBlock - case "server_tool_use": + case contentBlockTypeServerToolUse: var contentBlock ServerToolUseBlock if err := json.Unmarshal(data, &contentBlock); err != nil { return fmt.Errorf("failed to unmarshal server tool use block: %w", err) } m.ServerToolUse = &contentBlock - case "web_search_tool_result": + case contentBlockTypeWebSearchToolResult: var contentBlock WebSearchToolResultBlock if err := json.Unmarshal(data, &contentBlock); err != nil { return fmt.Errorf("failed to unmarshal web search tool result block: %w", err)