diff --git a/api/v1alpha1/shared_types.go b/api/v1alpha1/shared_types.go index 09fad0431a..036e8b7a47 100644 --- a/api/v1alpha1/shared_types.go +++ b/api/v1alpha1/shared_types.go @@ -80,6 +80,8 @@ const ( APISchemaAnthropic APISchema = "Anthropic" // APISchemaAWSAnthropic is the schema for Anthropic models hosted on AWS Bedrock. // Uses the native Anthropic Messages API format for requests and responses. + // When used with /v1/chat/completions endpoint, translates OpenAI format to Anthropic. + // When used with /v1/messages endpoint, passes through native Anthropic format. // // https://aws.amazon.com/bedrock/anthropic/ // https://docs.claude.com/en/api/claude-on-amazon-bedrock diff --git a/examples/basic/aws-bedrock-openai-anthropic.yaml b/examples/basic/aws-bedrock-openai-anthropic.yaml new file mode 100644 index 0000000000..bf36817dcf --- /dev/null +++ b/examples/basic/aws-bedrock-openai-anthropic.yaml @@ -0,0 +1,109 @@ +# Copyright Envoy AI Gateway Authors +# SPDX-License-Identifier: Apache-2.0 +# The full text of the Apache license is available in the LICENSE file at +# the root of the repo. + +# This example demonstrates using the AWSAnthropic schema to access +# Claude models on AWS Bedrock via the InvokeModel API with OpenAI-compatible requests. 
+# +# The AWSAnthropic schema works with both input formats: +# - /v1/chat/completions: Translates OpenAI ChatCompletion requests to Anthropic Messages API format +# - /v1/messages: Passes through native Anthropic Messages API format +# +# Use cases: +# - When you want to use OpenAI SDK/format with Claude models on AWS Bedrock +# - When migrating from OpenAI to Claude on AWS without changing client code +# - When using tools that only support OpenAI format but need Claude on AWS + +apiVersion: aigateway.envoyproxy.io/v1alpha1 +kind: AIGatewayRoute +metadata: + name: envoy-ai-gateway-aws-bedrock-claude-openai-format + namespace: default +spec: + parentRefs: + - name: envoy-ai-gateway-basic + kind: Gateway + group: gateway.networking.k8s.io + rules: + - matches: + - headers: + - type: Exact + name: x-ai-eg-model + value: anthropic.claude-3-5-sonnet-20241022-v2:0 + backendRefs: + - name: envoy-ai-gateway-aws-bedrock-claude-openai +--- +apiVersion: aigateway.envoyproxy.io/v1alpha1 +kind: AIServiceBackend +metadata: + name: envoy-ai-gateway-aws-bedrock-claude-openai + namespace: default +spec: + # AWSAnthropic schema supports both OpenAI and Anthropic input formats. + # The endpoint path determines the translator used. 
+ schema: + name: AWSAnthropic + # Optional: Specify Anthropic API version for Bedrock + # Default: bedrock-2023-05-31 + version: bedrock-2023-05-31 + backendRef: + name: envoy-ai-gateway-basic-aws + kind: Backend + group: gateway.envoyproxy.io +--- +apiVersion: aigateway.envoyproxy.io/v1alpha1 +kind: BackendSecurityPolicy +metadata: + name: envoy-ai-gateway-aws-bedrock-credentials + namespace: default +spec: + targetRefs: + - group: aigateway.envoyproxy.io + kind: AIServiceBackend + name: envoy-ai-gateway-aws-bedrock-claude-openai + type: AWSCredentials + awsCredentials: + region: us-east-1 + credentialsFile: + secretRef: + name: envoy-ai-gateway-basic-aws-credentials +--- +apiVersion: gateway.envoyproxy.io/v1alpha1 +kind: Backend +metadata: + name: envoy-ai-gateway-basic-aws + namespace: default +spec: + endpoints: + - fqdn: + hostname: bedrock-runtime.us-east-1.amazonaws.com + port: 443 +--- +apiVersion: gateway.networking.k8s.io/v1alpha3 +kind: BackendTLSPolicy +metadata: + name: envoy-ai-gateway-basic-aws-tls + namespace: default +spec: + targetRefs: + - group: "gateway.envoyproxy.io" + kind: Backend + name: envoy-ai-gateway-basic-aws + validation: + wellKnownCACertificates: "System" + hostname: bedrock-runtime.us-east-1.amazonaws.com +--- +apiVersion: v1 +kind: Secret +metadata: + name: envoy-ai-gateway-basic-aws-credentials + namespace: default +type: Opaque +stringData: + # Replace this with your AWS credentials. + # You can also use AWS IAM roles for service accounts (IRSA) in EKS. 
+ credentials: | + [default] + aws_access_key_id = AWS_ACCESS_KEY_ID + aws_secret_access_key = AWS_SECRET_ACCESS_KEY diff --git a/internal/endpointspec/endpointspec.go b/internal/endpointspec/endpointspec.go index a330f0873d..209ad2a8b7 100644 --- a/internal/endpointspec/endpointspec.go +++ b/internal/endpointspec/endpointspec.go @@ -129,6 +129,8 @@ func (ChatCompletionsEndpointSpec) GetTranslator(schema filterapi.VersionedAPISc return translator.NewChatCompletionOpenAIToOpenAITranslator(schema.OpenAIPrefix(), modelNameOverride), nil case filterapi.APISchemaAWSBedrock: return translator.NewChatCompletionOpenAIToAWSBedrockTranslator(modelNameOverride), nil + case filterapi.APISchemaAWSAnthropic: + return translator.NewChatCompletionOpenAIToAWSAnthropicTranslator(schema.Version, modelNameOverride), nil case filterapi.APISchemaAzureOpenAI: return translator.NewChatCompletionOpenAIToAzureOpenAITranslator(schema.Version, modelNameOverride), nil case filterapi.APISchemaGCPVertexAI: diff --git a/internal/endpointspec/endpointspec_test.go b/internal/endpointspec/endpointspec_test.go index b182760445..371c5cc998 100644 --- a/internal/endpointspec/endpointspec_test.go +++ b/internal/endpointspec/endpointspec_test.go @@ -81,6 +81,7 @@ func TestChatCompletionsEndpointSpec_GetTranslator(t *testing.T) { supported := []filterapi.VersionedAPISchema{ {Name: filterapi.APISchemaOpenAI, Prefix: "v1"}, {Name: filterapi.APISchemaAWSBedrock}, + {Name: filterapi.APISchemaAWSAnthropic}, {Name: filterapi.APISchemaAzureOpenAI, Version: "2024-02-01"}, {Name: filterapi.APISchemaGCPVertexAI}, {Name: filterapi.APISchemaGCPAnthropic, Version: "2024-05-01"}, diff --git a/internal/filterapi/filterconfig.go b/internal/filterapi/filterconfig.go index 947cce5f4f..529bb1cf59 100644 --- a/internal/filterapi/filterconfig.go +++ b/internal/filterapi/filterconfig.go @@ -114,7 +114,7 @@ const ( APISchemaOpenAI APISchemaName = "OpenAI" // APISchemaCohere represents the Cohere API schema. 
APISchemaCohere APISchemaName = "Cohere" - // APISchemaAWSBedrock represents the AWS Bedrock API schema. + // APISchemaAWSBedrock represents the AWS Bedrock Converse API schema. APISchemaAWSBedrock APISchemaName = "AWSBedrock" // APISchemaAzureOpenAI represents the Azure OpenAI API schema. APISchemaAzureOpenAI APISchemaName = "AzureOpenAI" @@ -127,7 +127,8 @@ const ( // APISchemaAnthropic represents the standard Anthropic API schema. APISchemaAnthropic APISchemaName = "Anthropic" // APISchemaAWSAnthropic represents the AWS Bedrock Anthropic API schema. - // Used for Claude models hosted on AWS Bedrock using the native Anthropic Messages API. + // Used for Claude models hosted on AWS Bedrock. Supports both OpenAI and Anthropic input formats + // depending on the endpoint path, similar to APISchemaGCPAnthropic. APISchemaAWSAnthropic APISchemaName = "AWSAnthropic" ) diff --git a/internal/translator/anthropic_helper.go b/internal/translator/anthropic_helper.go new file mode 100644 index 0000000000..3798c88469 --- /dev/null +++ b/internal/translator/anthropic_helper.go @@ -0,0 +1,1133 @@ +// Copyright Envoy AI Gateway Authors +// SPDX-License-Identifier: Apache-2.0 +// The full text of the Apache license is available in the LICENSE file at +// the root of the repo. 
+ +package translator + +import ( + "bytes" + "cmp" + "encoding/base64" + "fmt" + "io" + "strings" + "time" + + "github.com/anthropics/anthropic-sdk-go" + anthropicParam "github.com/anthropics/anthropic-sdk-go/packages/param" + "github.com/anthropics/anthropic-sdk-go/shared/constant" + openAIconstant "github.com/openai/openai-go/shared/constant" + "k8s.io/utils/ptr" + + "github.com/envoyproxy/ai-gateway/internal/apischema/awsbedrock" + "github.com/envoyproxy/ai-gateway/internal/apischema/openai" + "github.com/envoyproxy/ai-gateway/internal/internalapi" + "github.com/envoyproxy/ai-gateway/internal/json" + "github.com/envoyproxy/ai-gateway/internal/metrics" + "github.com/envoyproxy/ai-gateway/internal/tracing/tracingapi" +) + +const ( + anthropicVersionKey = "anthropic_version" + tempNotSupportedError = "temperature %.2f is not supported by Anthropic (must be between 0.0 and 1.0)" +) + +func anthropicToOpenAIFinishReason(stopReason anthropic.StopReason) (openai.ChatCompletionChoicesFinishReason, error) { + switch stopReason { + // The most common stop reason. Indicates Claude finished its response naturally. + // or Claude encountered one of your custom stop sequences. + // TODO: A better way to return pause_turn + // TODO: "pause_turn" Used with server tools like web search when Claude needs to pause a long-running operation. + case anthropic.StopReasonEndTurn, anthropic.StopReasonStopSequence, anthropic.StopReasonPauseTurn: + return openai.ChatCompletionChoicesFinishReasonStop, nil + case anthropic.StopReasonMaxTokens: // Claude stopped because it reached the max_tokens limit specified in your request. + // TODO: do we want to return an error? 
see: https://docs.anthropic.com/en/docs/agents-and-tools/tool-use/implement-tool-use#handling-the-max-tokens-stop-reason + return openai.ChatCompletionChoicesFinishReasonLength, nil + case anthropic.StopReasonToolUse: + return openai.ChatCompletionChoicesFinishReasonToolCalls, nil + case anthropic.StopReasonRefusal: + return openai.ChatCompletionChoicesFinishReasonContentFilter, nil + default: + return "", fmt.Errorf("received invalid stop reason %v", stopReason) + } +} + +// validateTemperatureForAnthropic checks if the temperature is within Anthropic's supported range (0.0 to 1.0). +// Returns an error if the value is greater than 1.0. +func validateTemperatureForAnthropic(temp *float64) error { + if temp != nil && (*temp < 0.0 || *temp > 1.0) { + return fmt.Errorf("%w: "+tempNotSupportedError, internalapi.ErrInvalidRequestBody, *temp) + } + return nil +} + +// translateAnthropicToolChoice converts the OpenAI tool_choice parameter to the Anthropic format. +func translateAnthropicToolChoice(openAIToolChoice *openai.ChatCompletionToolChoiceUnion, disableParallelToolUse anthropicParam.Opt[bool]) (anthropic.ToolChoiceUnionParam, error) { + var toolChoice anthropic.ToolChoiceUnionParam + + if openAIToolChoice == nil { + return toolChoice, nil + } + + switch choice := openAIToolChoice.Value.(type) { + case string: + switch choice { + case string(openAIconstant.ValueOf[openAIconstant.Auto]()): + toolChoice = anthropic.ToolChoiceUnionParam{OfAuto: &anthropic.ToolChoiceAutoParam{}} + toolChoice.OfAuto.DisableParallelToolUse = disableParallelToolUse + case "required", "any": + toolChoice = anthropic.ToolChoiceUnionParam{OfAny: &anthropic.ToolChoiceAnyParam{}} + toolChoice.OfAny.DisableParallelToolUse = disableParallelToolUse + case "none": + toolChoice = anthropic.ToolChoiceUnionParam{OfNone: &anthropic.ToolChoiceNoneParam{}} + case string(openAIconstant.ValueOf[openAIconstant.Function]()): + // This is how anthropic forces tool use. 
+ // TODO: should we check if strict true in openAI request, and if so, use this? + toolChoice = anthropic.ToolChoiceUnionParam{OfTool: &anthropic.ToolChoiceToolParam{Name: choice}} + toolChoice.OfTool.DisableParallelToolUse = disableParallelToolUse + default: + return anthropic.ToolChoiceUnionParam{}, fmt.Errorf("unsupported tool_choice value: %s", choice) + } + case openai.ChatCompletionNamedToolChoice: + if choice.Type == openai.ToolTypeFunction && choice.Function.Name != "" { + toolChoice = anthropic.ToolChoiceUnionParam{ + OfTool: &anthropic.ToolChoiceToolParam{ + Type: constant.Tool("tool"), + Name: choice.Function.Name, + DisableParallelToolUse: disableParallelToolUse, + }, + } + } + default: + return anthropic.ToolChoiceUnionParam{}, fmt.Errorf("unsupported tool_choice type: %T", openAIToolChoice) + } + return toolChoice, nil +} + +func isAnthropicSupportedImageMediaType(mediaType string) bool { + switch anthropic.Base64ImageSourceMediaType(mediaType) { + case anthropic.Base64ImageSourceMediaTypeImageJPEG, + anthropic.Base64ImageSourceMediaTypeImagePNG, + anthropic.Base64ImageSourceMediaTypeImageGIF, + anthropic.Base64ImageSourceMediaTypeImageWebP: + return true + default: + return false + } +} + +// translateOpenAItoAnthropicTools translates OpenAI tool and tool_choice parameters +// into the Anthropic format and returns translated tool & tool choice. +func translateOpenAItoAnthropicTools(openAITools []openai.Tool, openAIToolChoice *openai.ChatCompletionToolChoiceUnion, parallelToolCalls *bool) (tools []anthropic.ToolUnionParam, toolChoice anthropic.ToolChoiceUnionParam, err error) { + if len(openAITools) > 0 { + anthropicTools := make([]anthropic.ToolUnionParam, 0, len(openAITools)) + for _, openAITool := range openAITools { + if openAITool.Type != openai.ToolTypeFunction || openAITool.Function == nil { + // Anthropic only supports 'function' tools, so we skip others. 
+ continue + } + toolParam := anthropic.ToolParam{ + Name: openAITool.Function.Name, + Description: anthropic.String(openAITool.Function.Description), + } + + if isCacheEnabled(openAITool.Function.AnthropicContentFields) { + toolParam.CacheControl = anthropic.NewCacheControlEphemeralParam() + } + + // The parameters for the function are expected to be a JSON Schema object. + // We can pass them through as-is. + if openAITool.Function.Parameters != nil { + paramsMap, ok := openAITool.Function.Parameters.(map[string]any) + if !ok { + err = fmt.Errorf("failed to cast tool parameters to map[string]interface{}") + return + } + + inputSchema := anthropic.ToolInputSchemaParam{} + + // Dereference json schema + // If the paramsMap contains $refs we need to dereference them + var dereferencedParamsMap any + if dereferencedParamsMap, err = jsonSchemaDereference(paramsMap); err != nil { + return nil, anthropic.ToolChoiceUnionParam{}, fmt.Errorf("failed to dereference tool parameters: %w", err) + } + if paramsMap, ok = dereferencedParamsMap.(map[string]any); !ok { + return nil, anthropic.ToolChoiceUnionParam{}, fmt.Errorf("failed to cast dereferenced tool parameters to map[string]interface{}") + } + + var typeVal string + if typeVal, ok = paramsMap["type"].(string); ok { + inputSchema.Type = constant.Object(typeVal) + } + + var propsVal map[string]any + if propsVal, ok = paramsMap["properties"].(map[string]any); ok { + inputSchema.Properties = propsVal + } + + var requiredVal []any + if requiredVal, ok = paramsMap["required"].([]any); ok { + requiredSlice := make([]string, len(requiredVal)) + for i, v := range requiredVal { + if s, ok := v.(string); ok { + requiredSlice[i] = s + } + } + inputSchema.Required = requiredSlice + } + + toolParam.InputSchema = inputSchema + } + + anthropicTools = append(anthropicTools, anthropic.ToolUnionParam{OfTool: &toolParam}) + if len(anthropicTools) > 0 { + tools = anthropicTools + } + } + + // 2. Handle the tool_choice parameter. 
+ // disable parallel tool use default value is false + // see: https://docs.anthropic.com/en/docs/agents-and-tools/tool-use/implement-tool-use#parallel-tool-use + disableParallelToolUse := anthropic.Bool(false) + if parallelToolCalls != nil { + // OpenAI variable checks to allow parallel tool calls. + // Anthropic variable checks to disable, so need to use the inverse. + disableParallelToolUse = anthropic.Bool(!*parallelToolCalls) + } + + toolChoice, err = translateAnthropicToolChoice(openAIToolChoice, disableParallelToolUse) + if err != nil { + return + } + + } + return +} + +// convertImageContentToAnthropic translates an OpenAI image URL into the corresponding Anthropic content block. +// It handles data URIs for various image types and PDFs, as well as remote URLs. +func convertImageContentToAnthropic(imageURL string, fields *openai.AnthropicContentFields) (anthropic.ContentBlockParamUnion, error) { + var cacheControlParam anthropic.CacheControlEphemeralParam + if isCacheEnabled(fields) { + cacheControlParam = fields.CacheControl + } + + switch { + case strings.HasPrefix(imageURL, "data:"): + contentType, data, err := parseDataURI(imageURL) + if err != nil { + return anthropic.ContentBlockParamUnion{}, fmt.Errorf("failed to parse image URL: %w", err) + } + base64Data := base64.StdEncoding.EncodeToString(data) + if contentType == string(constant.ValueOf[constant.ApplicationPDF]()) { + pdfSource := anthropic.Base64PDFSourceParam{Data: base64Data} + docBlock := anthropic.NewDocumentBlock(pdfSource) + docBlock.OfDocument.CacheControl = cacheControlParam + return docBlock, nil + } + if isAnthropicSupportedImageMediaType(contentType) { + imgBlock := anthropic.NewImageBlockBase64(contentType, base64Data) + imgBlock.OfImage.CacheControl = cacheControlParam + return imgBlock, nil + } + return anthropic.ContentBlockParamUnion{}, fmt.Errorf("invalid media_type for image '%s'", contentType) + case strings.HasSuffix(strings.ToLower(imageURL), ".pdf"): + docBlock := 
anthropic.NewDocumentBlock(anthropic.URLPDFSourceParam{URL: imageURL}) + docBlock.OfDocument.CacheControl = cacheControlParam + return docBlock, nil + default: + imgBlock := anthropic.NewImageBlock(anthropic.URLImageSourceParam{URL: imageURL}) + imgBlock.OfImage.CacheControl = cacheControlParam + return imgBlock, nil + } +} + +func isCacheEnabled(fields *openai.AnthropicContentFields) bool { + return fields != nil && fields.CacheControl.Type == constant.ValueOf[constant.Ephemeral]() +} + +// convertContentPartsToAnthropic iterates over a slice of OpenAI content parts +// and converts each into an Anthropic content block. +func convertContentPartsToAnthropic(parts []openai.ChatCompletionContentPartUserUnionParam) ([]anthropic.ContentBlockParamUnion, error) { + resultContent := make([]anthropic.ContentBlockParamUnion, 0, len(parts)) + for _, contentPart := range parts { + switch { + case contentPart.OfText != nil: + textBlock := anthropic.NewTextBlock(contentPart.OfText.Text) + if isCacheEnabled(contentPart.OfText.AnthropicContentFields) { + textBlock.OfText.CacheControl = contentPart.OfText.CacheControl + } + resultContent = append(resultContent, textBlock) + + case contentPart.OfImageURL != nil: + block, err := convertImageContentToAnthropic(contentPart.OfImageURL.ImageURL.URL, contentPart.OfImageURL.AnthropicContentFields) + if err != nil { + return nil, err + } + resultContent = append(resultContent, block) + + case contentPart.OfInputAudio != nil: + return nil, fmt.Errorf("input audio content not supported yet") + case contentPart.OfFile != nil: + return nil, fmt.Errorf("file content not supported yet") + } + } + return resultContent, nil +} + +// Helper: Convert OpenAI message content to Anthropic content. 
+func openAIToAnthropicContent(content any) ([]anthropic.ContentBlockParamUnion, error) { + switch v := content.(type) { + case nil: + return nil, nil + case string: + if v == "" { + return nil, nil + } + return []anthropic.ContentBlockParamUnion{ + anthropic.NewTextBlock(v), + }, nil + case []openai.ChatCompletionContentPartUserUnionParam: + return convertContentPartsToAnthropic(v) + case openai.ContentUnion: + switch val := v.Value.(type) { + case string: + if val == "" { + return nil, nil + } + return []anthropic.ContentBlockParamUnion{ + anthropic.NewTextBlock(val), + }, nil + case []openai.ChatCompletionContentPartTextParam: + var contentBlocks []anthropic.ContentBlockParamUnion + for _, part := range val { + textBlock := anthropic.NewTextBlock(part.Text) + // In an array of text parts, each can have its own cache setting. + if isCacheEnabled(part.AnthropicContentFields) { + textBlock.OfText.CacheControl = part.CacheControl + } + contentBlocks = append(contentBlocks, textBlock) + } + return contentBlocks, nil + default: + return nil, fmt.Errorf("unsupported ContentUnion value type: %T", val) + } + } + return nil, fmt.Errorf("unsupported OpenAI content type: %T", content) +} + +// extractSystemPromptFromDeveloperMsg flattens content and checks for cache flags. +// It returns the combined string and a boolean indicating if any part was cacheable. +func extractSystemPromptFromDeveloperMsg(msg openai.ChatCompletionDeveloperMessageParam) (msgValue string, cacheParam *anthropic.CacheControlEphemeralParam) { + switch v := msg.Content.Value.(type) { + case nil: + return + case string: + msgValue = v + return + case []openai.ChatCompletionContentPartTextParam: + // Concatenate all text parts and check for caching. 
+ var sb strings.Builder + for _, part := range v { + sb.WriteString(part.Text) + if isCacheEnabled(part.AnthropicContentFields) { + cacheParam = &part.CacheControl + } + } + msgValue = sb.String() + return + default: + return + } +} + +func anthropicRoleToOpenAIRole(role anthropic.MessageParamRole) (string, error) { + switch role { + case anthropic.MessageParamRoleAssistant: + return openai.ChatMessageRoleAssistant, nil + case anthropic.MessageParamRoleUser: + return openai.ChatMessageRoleUser, nil + default: + return "", fmt.Errorf("invalid anthropic role %v", role) + } +} + +// processAssistantContent processes a single assistant content block and adds it to the content blocks. +func processAssistantContent(contentBlocks []anthropic.ContentBlockParamUnion, content openai.ChatCompletionAssistantMessageParamContent) ([]anthropic.ContentBlockParamUnion, error) { + switch content.Type { + case openai.ChatCompletionAssistantMessageParamContentTypeRefusal: + if content.Refusal != nil { + contentBlocks = append(contentBlocks, anthropic.NewTextBlock(*content.Refusal)) + } + case openai.ChatCompletionAssistantMessageParamContentTypeText: + if content.Text != nil { + textBlock := anthropic.NewTextBlock(*content.Text) + if isCacheEnabled(content.AnthropicContentFields) { + textBlock.OfText.CacheControl = content.CacheControl + } + contentBlocks = append(contentBlocks, textBlock) + } + case openai.ChatCompletionAssistantMessageParamContentTypeThinking: + // Thinking content requires both text and signature + if content.Text != nil && content.Signature != nil { + contentBlocks = append(contentBlocks, anthropic.NewThinkingBlock(*content.Signature, *content.Text)) + } + case openai.ChatCompletionAssistantMessageParamContentTypeRedactedThinking: + if content.RedactedContent != nil { + switch v := content.RedactedContent.Value.(type) { + case string: + contentBlocks = append(contentBlocks, anthropic.NewRedactedThinkingBlock(v)) + default: + return nil, fmt.Errorf("unsupported 
RedactedContent type: %T, expected string", v) + } + } + default: + return nil, fmt.Errorf("content type not supported: %v", content.Type) + } + return contentBlocks, nil +} + +// openAIMessageToAnthropicMessageRoleAssistant converts an OpenAI assistant message to Anthropic content blocks. +// The tool_use content is appended to the Anthropic message content list if tool_calls are present. +func openAIMessageToAnthropicMessageRoleAssistant(openAiMessage *openai.ChatCompletionAssistantMessageParam) (anthropicMsg anthropic.MessageParam, err error) { + contentBlocks := make([]anthropic.ContentBlockParamUnion, 0) + if v, ok := openAiMessage.Content.Value.(string); ok && len(v) > 0 { + contentBlocks = append(contentBlocks, anthropic.NewTextBlock(v)) + } else if content, ok := openAiMessage.Content.Value.(openai.ChatCompletionAssistantMessageParamContent); ok { + contentBlocks, err = processAssistantContent(contentBlocks, content) + if err != nil { + return + } + } else if contents, ok := openAiMessage.Content.Value.([]openai.ChatCompletionAssistantMessageParamContent); ok { + for _, content := range contents { + contentBlocks, err = processAssistantContent(contentBlocks, content) + if err != nil { + return + } + } + } + + // Handle tool_calls (if any). 
+ for i := range openAiMessage.ToolCalls { + toolCall := &openAiMessage.ToolCalls[i] + var input map[string]any + if err = json.Unmarshal([]byte(toolCall.Function.Arguments), &input); err != nil { + err = fmt.Errorf("failed to unmarshal tool call arguments: %w", err) + return + } + toolUse := anthropic.ToolUseBlockParam{ + ID: *toolCall.ID, + Type: "tool_use", + Name: toolCall.Function.Name, + Input: input, + } + + if isCacheEnabled(toolCall.AnthropicContentFields) { + toolUse.CacheControl = toolCall.CacheControl + } + + contentBlocks = append(contentBlocks, anthropic.ContentBlockParamUnion{OfToolUse: &toolUse}) + } + + return anthropic.MessageParam{ + Role: anthropic.MessageParamRoleAssistant, + Content: contentBlocks, + }, nil +} + +// openAIToAnthropicMessages converts OpenAI messages to Anthropic message params type, handling all roles and system/developer logic. +func openAIToAnthropicMessages(openAIMsgs []openai.ChatCompletionMessageParamUnion) (anthropicMessages []anthropic.MessageParam, systemBlocks []anthropic.TextBlockParam, err error) { + for i := 0; i < len(openAIMsgs); { + msg := &openAIMsgs[i] + switch { + case msg.OfSystem != nil: + devParam := systemMsgToDeveloperMsg(*msg.OfSystem) + systemText, cacheControl := extractSystemPromptFromDeveloperMsg(devParam) + systemBlock := anthropic.TextBlockParam{Text: systemText} + if cacheControl != nil { + systemBlock.CacheControl = *cacheControl + } + systemBlocks = append(systemBlocks, systemBlock) + i++ + case msg.OfDeveloper != nil: + systemText, cacheControl := extractSystemPromptFromDeveloperMsg(*msg.OfDeveloper) + systemBlock := anthropic.TextBlockParam{Text: systemText} + if cacheControl != nil { + systemBlock.CacheControl = *cacheControl + } + systemBlocks = append(systemBlocks, systemBlock) + i++ + case msg.OfUser != nil: + message := *msg.OfUser + var content []anthropic.ContentBlockParamUnion + content, err = openAIToAnthropicContent(message.Content.Value) + if err != nil { + return + } + 
anthropicMsg := anthropic.MessageParam{ + Role: anthropic.MessageParamRoleUser, + Content: content, + } + anthropicMessages = append(anthropicMessages, anthropicMsg) + i++ + case msg.OfAssistant != nil: + assistantMessage := msg.OfAssistant + var messages anthropic.MessageParam + messages, err = openAIMessageToAnthropicMessageRoleAssistant(assistantMessage) + if err != nil { + return + } + anthropicMessages = append(anthropicMessages, messages) + i++ + case msg.OfTool != nil: + // Aggregate all consecutive tool messages into a single user message + // to support parallel tool use. + var toolResultBlocks []anthropic.ContentBlockParamUnion + for i < len(openAIMsgs) && openAIMsgs[i].ExtractMessgaeRole() == openai.ChatMessageRoleTool { + currentMsg := &openAIMsgs[i] + toolMsg := currentMsg.OfTool + + var contentBlocks []anthropic.ContentBlockParamUnion + contentBlocks, err = openAIToAnthropicContent(toolMsg.Content) + if err != nil { + return + } + + var toolContent []anthropic.ToolResultBlockParamContentUnion + var cacheControl *anthropic.CacheControlEphemeralParam + + for _, c := range contentBlocks { + var trb anthropic.ToolResultBlockParamContentUnion + // Check if the translated part has caching enabled. 
+ switch { + case c.OfText != nil: + trb.OfText = c.OfText + cacheControl = &c.OfText.CacheControl + case c.OfImage != nil: + trb.OfImage = c.OfImage + cacheControl = &c.OfImage.CacheControl + case c.OfDocument != nil: + trb.OfDocument = c.OfDocument + cacheControl = &c.OfDocument.CacheControl + } + toolContent = append(toolContent, trb) + } + + isError := false + if contentStr, ok := toolMsg.Content.Value.(string); ok { + var contentMap map[string]any + if json.Unmarshal([]byte(contentStr), &contentMap) == nil { + if _, ok = contentMap["error"]; ok { + isError = true + } + } + } + + toolResultBlock := anthropic.ToolResultBlockParam{ + ToolUseID: toolMsg.ToolCallID, + Type: "tool_result", + Content: toolContent, + IsError: anthropic.Bool(isError), + } + + if cacheControl != nil { + toolResultBlock.CacheControl = *cacheControl + } + + toolResultBlockUnion := anthropic.ContentBlockParamUnion{OfToolResult: &toolResultBlock} + toolResultBlocks = append(toolResultBlocks, toolResultBlockUnion) + i++ + } + // Append all aggregated tool results. + anthropicMsg := anthropic.MessageParam{ + Role: anthropic.MessageParamRoleUser, + Content: toolResultBlocks, + } + anthropicMessages = append(anthropicMessages, anthropicMsg) + default: + err = fmt.Errorf("unsupported OpenAI role type: %s", msg.ExtractMessgaeRole()) + return + } + } + return +} + +// NewThinkingConfigParamUnion converts a ThinkingUnion into a ThinkingConfigParamUnion. 
+func getThinkingConfigParamUnion(tu *openai.ThinkingUnion) *anthropic.ThinkingConfigParamUnion { + if tu == nil { + return nil + } + + result := &anthropic.ThinkingConfigParamUnion{} + + if tu.OfEnabled != nil { + result.OfEnabled = &anthropic.ThinkingConfigEnabledParam{ + BudgetTokens: tu.OfEnabled.BudgetTokens, + Type: constant.Enabled(tu.OfEnabled.Type), + } + } else if tu.OfDisabled != nil { + result.OfDisabled = &anthropic.ThinkingConfigDisabledParam{ + Type: constant.Disabled(tu.OfDisabled.Type), + } + } + + return result +} + +// buildAnthropicParams is a helper function that translates an OpenAI request +// into the parameter struct required by the Anthropic SDK. +func buildAnthropicParams(openAIReq *openai.ChatCompletionRequest) (params *anthropic.MessageNewParams, err error) { + // 1. Handle simple parameters and defaults. + maxTokens := cmp.Or(openAIReq.MaxCompletionTokens, openAIReq.MaxTokens) + if maxTokens == nil { + err = fmt.Errorf("%w: max_tokens or max_completion_tokens is required", internalapi.ErrInvalidRequestBody) + return + } + + // Translate openAI contents to anthropic params. + // 2. Translate messages and system prompts. + messages, systemBlocks, err := openAIToAnthropicMessages(openAIReq.Messages) + if err != nil { + return + } + + // 3. Translate tools and tool choice. + tools, toolChoice, err := translateOpenAItoAnthropicTools(openAIReq.Tools, openAIReq.ToolChoice, openAIReq.ParallelToolCalls) + if err != nil { + return + } + + // 4. Construct the final struct in one place. 
+ params = &anthropic.MessageNewParams{ + Messages: messages, + MaxTokens: *maxTokens, + System: systemBlocks, + Tools: tools, + ToolChoice: toolChoice, + } + + if openAIReq.Temperature != nil { + if err = validateTemperatureForAnthropic(openAIReq.Temperature); err != nil { + return nil, err + } + params.Temperature = anthropic.Float(*openAIReq.Temperature) + } + if openAIReq.TopP != nil { + params.TopP = anthropic.Float(*openAIReq.TopP) + } + if openAIReq.Stop.OfString.Valid() { + params.StopSequences = []string{openAIReq.Stop.OfString.String()} + } else if openAIReq.Stop.OfStringArray != nil { + params.StopSequences = openAIReq.Stop.OfStringArray + } + + // 5. Handle Vendor specific fields. + // Since GCPAnthropic follows the Anthropic API, we also check for Anthropic vendor fields. + if openAIReq.Thinking != nil { + params.Thinking = *getThinkingConfigParamUnion(openAIReq.Thinking) + } + + return params, nil +} + +// anthropicToolUseToOpenAICalls converts Anthropic tool_use content blocks to OpenAI tool calls. +func anthropicToolUseToOpenAICalls(block *anthropic.ContentBlockUnion) ([]openai.ChatCompletionMessageToolCallParam, error) { + var toolCalls []openai.ChatCompletionMessageToolCallParam + if block.Type != string(constant.ValueOf[constant.ToolUse]()) { + return toolCalls, nil + } + argsBytes, err := json.Marshal(block.Input) + if err != nil { + return nil, fmt.Errorf("failed to marshal tool_use input: %w", err) + } + toolCalls = append(toolCalls, openai.ChatCompletionMessageToolCallParam{ + ID: &block.ID, + Type: openai.ChatCompletionMessageToolCallTypeFunction, + Function: openai.ChatCompletionMessageToolCallFunctionParam{ + Name: block.Name, + Arguments: string(argsBytes), + }, + }) + + return toolCalls, nil +} + +// following are streaming part + +var ( + sseEventPrefix = []byte("event: ") + emptyStrPtr = ptr.To("") +) + +// streamingToolCall holds the state for a single tool call that is being streamed. 
+type streamingToolCall struct { + id string + name string + inputJSON string +} + +// anthropicStreamParser manages the stateful translation of an Anthropic SSE stream +// to an OpenAI-compatible SSE stream. +type anthropicStreamParser struct { + buffer bytes.Buffer + activeMessageID string + activeToolCalls map[int64]*streamingToolCall + toolIndex int64 + tokenUsage metrics.TokenUsage + stopReason anthropic.StopReason + requestModel internalapi.RequestModel + sentFirstChunk bool + created openai.JSONUNIXTime +} + +// newAnthropicStreamParser creates a new parser for a streaming request. +func newAnthropicStreamParser(requestModel string) *anthropicStreamParser { + toolIdx := int64(-1) + return &anthropicStreamParser{ + requestModel: requestModel, + activeToolCalls: make(map[int64]*streamingToolCall), + toolIndex: toolIdx, + } +} + +func (p *anthropicStreamParser) writeChunk(eventBlock []byte, buf *[]byte) error { + chunk, err := p.parseAndHandleEvent(eventBlock) + if err != nil { + return err + } + if chunk != nil { + err := serializeOpenAIChatCompletionChunk(chunk, buf) + if err != nil { + return err + } + } + return nil +} + +// Process reads from the Anthropic SSE stream, translates events to OpenAI chunks, +// and returns the mutations for Envoy. +func (p *anthropicStreamParser) Process(body io.Reader, endOfStream bool, span tracingapi.ChatCompletionSpan) ( + newHeaders []internalapi.Header, newBody []byte, tokenUsage metrics.TokenUsage, responseModel string, err error, +) { + newBody = make([]byte, 0) + _ = span // TODO: add support for streaming chunks in tracing. 
+ responseModel = p.requestModel + if _, err = p.buffer.ReadFrom(body); err != nil { + err = fmt.Errorf("failed to read from stream body: %w", err) + return + } + + for { + eventBlock, remaining, found := bytes.Cut(p.buffer.Bytes(), []byte("\n\n")) + if !found { + break + } + + if err = p.writeChunk(eventBlock, &newBody); err != nil { + return + } + + p.buffer.Reset() + p.buffer.Write(remaining) + } + + if endOfStream && p.buffer.Len() > 0 { + finalEventBlock := p.buffer.Bytes() + p.buffer.Reset() + + if err = p.writeChunk(finalEventBlock, &newBody); err != nil { + return + } + } + + if endOfStream { + inputTokens, _ := p.tokenUsage.InputTokens() + outputTokens, _ := p.tokenUsage.OutputTokens() + p.tokenUsage.SetTotalTokens(inputTokens + outputTokens) + totalTokens, _ := p.tokenUsage.TotalTokens() + cachedTokens, _ := p.tokenUsage.CachedInputTokens() + cacheCreationTokens, _ := p.tokenUsage.CacheCreationInputTokens() + finalChunk := openai.ChatCompletionResponseChunk{ + ID: p.activeMessageID, + Created: p.created, + Object: "chat.completion.chunk", + Choices: []openai.ChatCompletionResponseChunkChoice{}, + Usage: &openai.Usage{ + PromptTokens: int(inputTokens), + CompletionTokens: int(outputTokens), + TotalTokens: int(totalTokens), + PromptTokensDetails: &openai.PromptTokensDetails{ + CachedTokens: int(cachedTokens), + CacheCreationTokens: int(cacheCreationTokens), + }, + }, + Model: p.requestModel, + } + + // Add active tool calls to the final chunk. 
+ var toolCalls []openai.ChatCompletionChunkChoiceDeltaToolCall + for toolIndex, tool := range p.activeToolCalls { + toolCalls = append(toolCalls, openai.ChatCompletionChunkChoiceDeltaToolCall{ + ID: &tool.id, + Type: openai.ChatCompletionMessageToolCallTypeFunction, + Function: openai.ChatCompletionMessageToolCallFunctionParam{ + Name: tool.name, + Arguments: tool.inputJSON, + }, + Index: toolIndex, + }) + } + + if len(toolCalls) > 0 { + delta := openai.ChatCompletionResponseChunkChoiceDelta{ + ToolCalls: toolCalls, + } + finalChunk.Choices = append(finalChunk.Choices, openai.ChatCompletionResponseChunkChoice{ + Delta: &delta, + }) + } + + if finalChunk.Usage.PromptTokens > 0 || finalChunk.Usage.CompletionTokens > 0 || len(finalChunk.Choices) > 0 { + err := serializeOpenAIChatCompletionChunk(&finalChunk, &newBody) + if err != nil { + return nil, nil, metrics.TokenUsage{}, "", fmt.Errorf("failed to marshal final stream chunk: %w", err) + } + } + // Add the final [DONE] message to indicate the end of the stream. + newBody = append(newBody, sseDataPrefix...) + newBody = append(newBody, sseDoneMessage...) + newBody = append(newBody, '\n', '\n') + } + tokenUsage = p.tokenUsage + return +} + +func (p *anthropicStreamParser) parseAndHandleEvent(eventBlock []byte) (*openai.ChatCompletionResponseChunk, error) { + var eventType []byte + var eventData []byte + + lines := bytes.SplitSeq(eventBlock, []byte("\n")) + for line := range lines { + if after, ok := bytes.CutPrefix(line, sseEventPrefix); ok { + eventType = bytes.TrimSpace(after) + } else if after, ok := bytes.CutPrefix(line, sseDataPrefix); ok { + // This handles JSON data that might be split across multiple 'data:' lines + // by concatenating them (Anthropic's format). + data := bytes.TrimSpace(after) + eventData = append(eventData, data...) 
+ } + } + + if len(eventType) > 0 && len(eventData) > 0 { + return p.handleAnthropicStreamEvent(eventType, eventData) + } + + return nil, nil +} + +func (p *anthropicStreamParser) handleAnthropicStreamEvent(eventType []byte, data []byte) (*openai.ChatCompletionResponseChunk, error) { + switch string(eventType) { + case string(constant.ValueOf[constant.MessageStart]()): + var event anthropic.MessageStartEvent + if err := json.Unmarshal(data, &event); err != nil { + return nil, fmt.Errorf("unmarshal message_start: %w", err) + } + p.activeMessageID = event.Message.ID + p.created = openai.JSONUNIXTime(time.Now()) + u := event.Message.Usage + usage := metrics.ExtractTokenUsageFromExplicitCaching( + u.InputTokens, + u.OutputTokens, + &u.CacheReadInputTokens, + &u.CacheCreationInputTokens, + ) + // For message_start, we store the initial usage but don't add to the accumulated + // The message_delta event will contain the final totals + if input, ok := usage.InputTokens(); ok { + p.tokenUsage.SetInputTokens(input) + } + if cached, ok := usage.CachedInputTokens(); ok { + p.tokenUsage.SetCachedInputTokens(cached) + } + + // reset the toolIndex for each message + p.toolIndex = -1 + return nil, nil + + case string(constant.ValueOf[constant.ContentBlockStart]()): + var event anthropic.ContentBlockStartEvent + if err := json.Unmarshal(data, &event); err != nil { + return nil, fmt.Errorf("failed to unmarshal content_block_start: %w", err) + } + if event.ContentBlock.Type == string(constant.ValueOf[constant.ToolUse]()) || event.ContentBlock.Type == string(constant.ValueOf[constant.ServerToolUse]()) { + p.toolIndex++ + var argsJSON string + // Check if the input field is provided directly in the start event. + if event.ContentBlock.Input != nil { + switch input := event.ContentBlock.Input.(type) { + case map[string]any: + // for case where "input":{}, skip adding it to arguments. 
+ if len(input) > 0 { + argsBytes, err := json.Marshal(input) + if err != nil { + return nil, fmt.Errorf("failed to marshal tool use input: %w", err) + } + argsJSON = string(argsBytes) + } + default: + // although golang sdk defines type of Input to be any, + // python sdk requires the type of Input to be Dict[str, object]: + // https://github.com/anthropics/anthropic-sdk-python/blob/main/src/anthropic/types/tool_use_block.py#L14. + return nil, fmt.Errorf("unexpected tool use input type: %T", input) + } + } + + // Store the complete input JSON in our state. + p.activeToolCalls[p.toolIndex] = &streamingToolCall{ + id: event.ContentBlock.ID, + name: event.ContentBlock.Name, + inputJSON: argsJSON, + } + + delta := openai.ChatCompletionResponseChunkChoiceDelta{ + ToolCalls: []openai.ChatCompletionChunkChoiceDeltaToolCall{ + { + Index: p.toolIndex, + ID: &event.ContentBlock.ID, + Type: openai.ChatCompletionMessageToolCallTypeFunction, + Function: openai.ChatCompletionMessageToolCallFunctionParam{ + Name: event.ContentBlock.Name, + // Include the arguments if they are available. + Arguments: argsJSON, + }, + }, + }, + } + return p.constructOpenAIChatCompletionChunk(delta, ""), nil + } + if event.ContentBlock.Type == string(constant.ValueOf[constant.Thinking]()) { + delta := openai.ChatCompletionResponseChunkChoiceDelta{Content: emptyStrPtr} + return p.constructOpenAIChatCompletionChunk(delta, ""), nil + } + + if event.ContentBlock.Type == string(constant.ValueOf[constant.RedactedThinking]()) { + // This is a latency-hiding event, ignore it. 
+ return nil, nil + } + + return nil, nil + + case string(constant.ValueOf[constant.MessageDelta]()): + var event anthropic.MessageDeltaEvent + if err := json.Unmarshal(data, &event); err != nil { + return nil, fmt.Errorf("unmarshal message_delta: %w", err) + } + u := event.Usage + usage := metrics.ExtractTokenUsageFromExplicitCaching( + u.InputTokens, + u.OutputTokens, + &u.CacheReadInputTokens, + &u.CacheCreationInputTokens, + ) + // For message_delta, accumulate the incremental output tokens + if output, ok := usage.OutputTokens(); ok { + p.tokenUsage.AddOutputTokens(output) + } + // Update input tokens to include any cache tokens from delta + if cached, ok := usage.CachedInputTokens(); ok { + p.tokenUsage.AddInputTokens(cached) + // Accumulate any additional cache tokens from delta + p.tokenUsage.AddCachedInputTokens(cached) + } + if event.Delta.StopReason != "" { + p.stopReason = event.Delta.StopReason + } + return nil, nil + + case string(constant.ValueOf[constant.ContentBlockDelta]()): + var event anthropic.ContentBlockDeltaEvent + if err := json.Unmarshal(data, &event); err != nil { + return nil, fmt.Errorf("unmarshal content_block_delta: %w", err) + } + switch event.Delta.Type { + case string(constant.ValueOf[constant.TextDelta]()), string(constant.ValueOf[constant.ThinkingDelta]()): + // Treat thinking_delta just like a text_delta. 
+ delta := openai.ChatCompletionResponseChunkChoiceDelta{Content: &event.Delta.Text} + return p.constructOpenAIChatCompletionChunk(delta, ""), nil + case string(constant.ValueOf[constant.InputJSONDelta]()): + tool, ok := p.activeToolCalls[p.toolIndex] + if !ok { + return nil, fmt.Errorf("received input_json_delta for unknown tool at index %d", p.toolIndex) + } + delta := openai.ChatCompletionResponseChunkChoiceDelta{ + ToolCalls: []openai.ChatCompletionChunkChoiceDeltaToolCall{ + { + Index: p.toolIndex, + Function: openai.ChatCompletionMessageToolCallFunctionParam{ + Arguments: event.Delta.PartialJSON, + }, + }, + }, + } + tool.inputJSON += event.Delta.PartialJSON + return p.constructOpenAIChatCompletionChunk(delta, ""), nil + } + + case string(constant.ValueOf[constant.ContentBlockStop]()): + // This event is for state cleanup, no chunk is sent. + var event anthropic.ContentBlockStopEvent + if err := json.Unmarshal(data, &event); err != nil { + return nil, fmt.Errorf("unmarshal content_block_stop: %w", err) + } + delete(p.activeToolCalls, p.toolIndex) + return nil, nil + + case string(constant.ValueOf[constant.MessageStop]()): + var event anthropic.MessageStopEvent + if err := json.Unmarshal(data, &event); err != nil { + return nil, fmt.Errorf("unmarshal message_stop: %w", err) + } + + if p.stopReason == "" { + p.stopReason = anthropic.StopReasonEndTurn + } + + finishReason, err := anthropicToOpenAIFinishReason(p.stopReason) + if err != nil { + return nil, err + } + return p.constructOpenAIChatCompletionChunk(openai.ChatCompletionResponseChunkChoiceDelta{}, finishReason), nil + + case string(constant.ValueOf[constant.Error]()): + var errEvent anthropic.ErrorResponse + if err := json.Unmarshal(data, &errEvent); err != nil { + return nil, fmt.Errorf("unparsable error event: %s", string(data)) + } + return nil, fmt.Errorf("anthropic stream error: %s - %s", errEvent.Error.Type, errEvent.Error.Message) + + case "ping": + // Per documentation, ping events can be 
ignored. + return nil, nil + } + return nil, nil +} + +// constructOpenAIChatCompletionChunk builds the stream chunk. +func (p *anthropicStreamParser) constructOpenAIChatCompletionChunk(delta openai.ChatCompletionResponseChunkChoiceDelta, finishReason openai.ChatCompletionChoicesFinishReason) *openai.ChatCompletionResponseChunk { + // Add the 'assistant' role to the very first chunk of the response. + if !p.sentFirstChunk { + // Only add the role if the delta actually contains content or a tool call. + if delta.Content != nil || len(delta.ToolCalls) > 0 { + delta.Role = openai.ChatMessageRoleAssistant + p.sentFirstChunk = true + } + } + + return &openai.ChatCompletionResponseChunk{ + ID: p.activeMessageID, + Created: p.created, + Object: "chat.completion.chunk", + Choices: []openai.ChatCompletionResponseChunkChoice{ + { + Delta: &delta, + FinishReason: finishReason, + }, + }, + Model: p.requestModel, + } +} + +// messageToChatCompletion is to translate from anthropic API's response Message into OpenAI API's response ChatCompletion +func messageToChatCompletion(anthropicResp *anthropic.Message, responseModel internalapi.RequestModel) (openAIResp *openai.ChatCompletionResponse, tokenUsage metrics.TokenUsage, err error) { + openAIResp = &openai.ChatCompletionResponse{ + ID: anthropicResp.ID, + Model: responseModel, + Object: string(openAIconstant.ValueOf[openAIconstant.ChatCompletion]()), + Choices: make([]openai.ChatCompletionResponseChoice, 0), + Created: openai.JSONUNIXTime(time.Now()), + } + usage := anthropicResp.Usage + tokenUsage = metrics.ExtractTokenUsageFromExplicitCaching( + usage.InputTokens, + usage.OutputTokens, + &usage.CacheReadInputTokens, + &usage.CacheCreationInputTokens, + ) + inputTokens, _ := tokenUsage.InputTokens() + outputTokens, _ := tokenUsage.OutputTokens() + totalTokens, _ := tokenUsage.TotalTokens() + cachedTokens, _ := tokenUsage.CachedInputTokens() + cacheCreationTokens, _ := tokenUsage.CacheCreationInputTokens() + openAIResp.Usage = 
openai.Usage{ + CompletionTokens: int(outputTokens), + PromptTokens: int(inputTokens), + TotalTokens: int(totalTokens), + PromptTokensDetails: &openai.PromptTokensDetails{ + CachedTokens: int(cachedTokens), + CacheCreationTokens: int(cacheCreationTokens), + }, + } + + finishReason, err := anthropicToOpenAIFinishReason(anthropicResp.StopReason) + if err != nil { + return nil, metrics.TokenUsage{}, err + } + + role, err := anthropicRoleToOpenAIRole(anthropic.MessageParamRole(anthropicResp.Role)) + if err != nil { + return nil, metrics.TokenUsage{}, err + } + + choice := openai.ChatCompletionResponseChoice{ + Index: 0, + Message: openai.ChatCompletionResponseChoiceMessage{Role: role}, + FinishReason: finishReason, + } + + for i := range anthropicResp.Content { // NOTE: Content structure is massive, do not range over values. + output := &anthropicResp.Content[i] + switch output.Type { + case string(constant.ValueOf[constant.ToolUse]()): + if output.ID != "" { + toolCalls, toolErr := anthropicToolUseToOpenAICalls(output) + if toolErr != nil { + return nil, metrics.TokenUsage{}, fmt.Errorf("failed to convert anthropic tool use to openai tool call: %w", toolErr) + } + choice.Message.ToolCalls = append(choice.Message.ToolCalls, toolCalls...) 
+ } + case string(constant.ValueOf[constant.Text]()): + if output.Text != "" { + if choice.Message.Content == nil { + choice.Message.Content = &output.Text + } + } + case string(constant.ValueOf[constant.Thinking]()): + if output.Thinking != "" { + choice.Message.ReasoningContent = &openai.ReasoningContentUnion{ + Value: &openai.ReasoningContent{ + ReasoningContent: &awsbedrock.ReasoningContentBlock{ + ReasoningText: &awsbedrock.ReasoningTextBlock{ + Text: output.Thinking, + Signature: output.Signature, + }, + }, + }, + } + } + case string(constant.ValueOf[constant.RedactedThinking]()): + if output.Data != "" { + choice.Message.ReasoningContent = &openai.ReasoningContentUnion{ + Value: &openai.ReasoningContent{ + ReasoningContent: &awsbedrock.ReasoningContentBlock{ + RedactedContent: []byte(output.Data), + }, + }, + } + } + } + } + openAIResp.Choices = append(openAIResp.Choices, choice) + return openAIResp, tokenUsage, nil +} diff --git a/internal/translator/anthropic_helper_test.go b/internal/translator/anthropic_helper_test.go new file mode 100644 index 0000000000..f6d13ffad6 --- /dev/null +++ b/internal/translator/anthropic_helper_test.go @@ -0,0 +1,894 @@ +// Copyright Envoy AI Gateway Authors +// SPDX-License-Identifier: Apache-2.0 +// The full text of the Apache license is available in the LICENSE file at +// the root of the repo. + +package translator + +import ( + "fmt" + "testing" + + "github.com/anthropics/anthropic-sdk-go" + "github.com/anthropics/anthropic-sdk-go/shared/constant" + "github.com/stretchr/testify/require" + "k8s.io/utils/ptr" + + "github.com/envoyproxy/ai-gateway/internal/apischema/openai" +) + +// mockErrorReader is a helper for testing io.Reader failures. +type mockErrorReader struct{} + +func (r *mockErrorReader) Read(_ []byte) (n int, err error) { + return 0, fmt.Errorf("mock reader error") +} + +// New test function for helper coverage. 
+func TestHelperFunctions(t *testing.T) { + t.Run("anthropicToOpenAIFinishReason invalid reason", func(t *testing.T) { + _, err := anthropicToOpenAIFinishReason("unknown_reason") + require.Error(t, err) + require.Contains(t, err.Error(), "received invalid stop reason") + }) + + t.Run("anthropicRoleToOpenAIRole invalid role", func(t *testing.T) { + _, err := anthropicRoleToOpenAIRole("unknown_role") + require.Error(t, err) + require.Contains(t, err.Error(), "invalid anthropic role") + }) +} + +func TestTranslateOpenAItoAnthropicTools(t *testing.T) { + anthropicTestTool := []anthropic.ToolUnionParam{ + {OfTool: &anthropic.ToolParam{Name: "get_weather", Description: anthropic.String("")}}, + } + openaiTestTool := []openai.Tool{ + {Type: "function", Function: &openai.FunctionDefinition{Name: "get_weather"}}, + } + tests := []struct { + name string + openAIReq *openai.ChatCompletionRequest + expectedTools []anthropic.ToolUnionParam + expectedToolChoice anthropic.ToolChoiceUnionParam + expectErr bool + }{ + { + name: "auto tool choice", + openAIReq: &openai.ChatCompletionRequest{ + ToolChoice: &openai.ChatCompletionToolChoiceUnion{Value: "auto"}, + Tools: openaiTestTool, + }, + expectedTools: anthropicTestTool, + expectedToolChoice: anthropic.ToolChoiceUnionParam{ + OfAuto: &anthropic.ToolChoiceAutoParam{ + DisableParallelToolUse: anthropic.Bool(false), + }, + }, + }, + { + name: "any tool choice", + openAIReq: &openai.ChatCompletionRequest{ + ToolChoice: &openai.ChatCompletionToolChoiceUnion{Value: "any"}, + Tools: openaiTestTool, + }, + expectedTools: anthropicTestTool, + expectedToolChoice: anthropic.ToolChoiceUnionParam{ + OfAny: &anthropic.ToolChoiceAnyParam{}, + }, + }, + { + name: "specific tool choice by name", + openAIReq: &openai.ChatCompletionRequest{ + ToolChoice: &openai.ChatCompletionToolChoiceUnion{Value: openai.ChatCompletionNamedToolChoice{Type: "function", Function: openai.ChatCompletionNamedToolChoiceFunction{Name: "my_func"}}}, + Tools: 
openaiTestTool, + }, + expectedTools: anthropicTestTool, + expectedToolChoice: anthropic.ToolChoiceUnionParam{ + OfTool: &anthropic.ToolChoiceToolParam{Type: "tool", Name: "my_func"}, + }, + }, + { + name: "tool definition", + openAIReq: &openai.ChatCompletionRequest{ + Tools: []openai.Tool{ + { + Type: "function", + Function: &openai.FunctionDefinition{ + Name: "get_weather", + Description: "Get the weather", + Parameters: map[string]any{ + "type": "object", + "properties": map[string]any{ + "location": map[string]any{"type": "string"}, + }, + }, + }, + }, + }, + }, + expectedTools: []anthropic.ToolUnionParam{ + { + OfTool: &anthropic.ToolParam{ + Name: "get_weather", + Description: anthropic.String("Get the weather"), + InputSchema: anthropic.ToolInputSchemaParam{ + Type: "object", + Properties: map[string]any{ + "location": map[string]any{"type": "string"}, + }, + }, + }, + }, + }, + }, + { + name: "tool_definition_with_required_field", + openAIReq: &openai.ChatCompletionRequest{ + Tools: []openai.Tool{ + { + Type: "function", + Function: &openai.FunctionDefinition{ + Name: "get_weather", + Description: "Get the weather with a required location", + Parameters: map[string]any{ + "type": "object", + "properties": map[string]any{ + "location": map[string]any{"type": "string"}, + "unit": map[string]any{"type": "string"}, + }, + "required": []any{"location"}, + }, + }, + }, + }, + }, + expectedTools: []anthropic.ToolUnionParam{ + { + OfTool: &anthropic.ToolParam{ + Name: "get_weather", + Description: anthropic.String("Get the weather with a required location"), + InputSchema: anthropic.ToolInputSchemaParam{ + Type: "object", + Properties: map[string]any{ + "location": map[string]any{"type": "string"}, + "unit": map[string]any{"type": "string"}, + }, + Required: []string{"location"}, + }, + }, + }, + }, + }, + { + name: "tool definition with no parameters", + openAIReq: &openai.ChatCompletionRequest{ + Tools: []openai.Tool{ + { + Type: "function", + Function: 
&openai.FunctionDefinition{ + Name: "get_time", + Description: "Get the current time", + }, + }, + }, + }, + expectedTools: []anthropic.ToolUnionParam{ + { + OfTool: &anthropic.ToolParam{ + Name: "get_time", + Description: anthropic.String("Get the current time"), + }, + }, + }, + }, + { + name: "disable parallel tool calls", + openAIReq: &openai.ChatCompletionRequest{ + ToolChoice: &openai.ChatCompletionToolChoiceUnion{Value: "auto"}, + Tools: openaiTestTool, + ParallelToolCalls: ptr.To(false), + }, + expectedTools: anthropicTestTool, + expectedToolChoice: anthropic.ToolChoiceUnionParam{ + OfAuto: &anthropic.ToolChoiceAutoParam{ + DisableParallelToolUse: anthropic.Bool(true), + }, + }, + }, + { + name: "explicitly enable parallel tool calls", + openAIReq: &openai.ChatCompletionRequest{ + Tools: openaiTestTool, + ToolChoice: &openai.ChatCompletionToolChoiceUnion{Value: "auto"}, + ParallelToolCalls: ptr.To(true), + }, + expectedTools: anthropicTestTool, + expectedToolChoice: anthropic.ToolChoiceUnionParam{ + OfAuto: &anthropic.ToolChoiceAutoParam{DisableParallelToolUse: anthropic.Bool(false)}, + }, + }, + { + name: "default disable parallel tool calls to false (nil)", + openAIReq: &openai.ChatCompletionRequest{ + Tools: openaiTestTool, + ToolChoice: &openai.ChatCompletionToolChoiceUnion{Value: "auto"}, + }, + expectedTools: anthropicTestTool, + expectedToolChoice: anthropic.ToolChoiceUnionParam{ + OfAuto: &anthropic.ToolChoiceAutoParam{DisableParallelToolUse: anthropic.Bool(false)}, + }, + }, + { + name: "none tool choice", + openAIReq: &openai.ChatCompletionRequest{ + Tools: openaiTestTool, + ToolChoice: &openai.ChatCompletionToolChoiceUnion{Value: "none"}, + }, + expectedTools: anthropicTestTool, + expectedToolChoice: anthropic.ToolChoiceUnionParam{ + OfNone: &anthropic.ToolChoiceNoneParam{}, + }, + }, + { + name: "function tool choice", + openAIReq: &openai.ChatCompletionRequest{ + Tools: openaiTestTool, + ToolChoice: &openai.ChatCompletionToolChoiceUnion{Value: 
"function"}, + }, + expectedTools: anthropicTestTool, + expectedToolChoice: anthropic.ToolChoiceUnionParam{ + OfTool: &anthropic.ToolChoiceToolParam{Name: "function"}, + }, + }, + { + name: "invalid tool choice string", + openAIReq: &openai.ChatCompletionRequest{ + Tools: openaiTestTool, + ToolChoice: &openai.ChatCompletionToolChoiceUnion{Value: "invalid_choice"}, + }, + expectErr: true, + }, + { + name: "skips function tool with nil function definition", + openAIReq: &openai.ChatCompletionRequest{ + Tools: []openai.Tool{ + { + Type: "function", + Function: nil, // This tool has the correct type but a nil definition and should be skipped. + }, + { + Type: "function", + Function: &openai.FunctionDefinition{Name: "get_weather"}, // This is a valid tool. + }, + }, + }, + // We expect only the valid function tool to be translated. + expectedTools: []anthropic.ToolUnionParam{ + {OfTool: &anthropic.ToolParam{Name: "get_weather", Description: anthropic.String("")}}, + }, + expectErr: false, + }, + { + name: "skips non-function tools", + openAIReq: &openai.ChatCompletionRequest{ + Tools: []openai.Tool{ + { + Type: "retrieval", + }, + { + Type: "function", + Function: &openai.FunctionDefinition{Name: "get_weather"}, + }, + }, + }, + expectedTools: []anthropic.ToolUnionParam{ + {OfTool: &anthropic.ToolParam{Name: "get_weather", Description: anthropic.String("")}}, + }, + expectErr: false, + }, + { + name: "tool definition without type field", + openAIReq: &openai.ChatCompletionRequest{ + Tools: []openai.Tool{ + { + Type: "function", + Function: &openai.FunctionDefinition{ + Name: "get_weather", + Description: "Get the weather without type", + Parameters: map[string]any{ + "properties": map[string]any{ + "location": map[string]any{"type": "string"}, + }, + "required": []any{"location"}, + }, + }, + }, + }, + }, + expectedTools: []anthropic.ToolUnionParam{ + { + OfTool: &anthropic.ToolParam{ + Name: "get_weather", + Description: anthropic.String("Get the weather without 
type"), + InputSchema: anthropic.ToolInputSchemaParam{ + Type: "", + Properties: map[string]any{ + "location": map[string]any{"type": "string"}, + }, + Required: []string{"location"}, + }, + }, + }, + }, + }, + { + name: "tool definition without properties field", + openAIReq: &openai.ChatCompletionRequest{ + Tools: []openai.Tool{ + { + Type: "function", + Function: &openai.FunctionDefinition{ + Name: "get_weather", + Description: "Get the weather without properties", + Parameters: map[string]any{ + "type": "object", + "required": []any{"location"}, + }, + }, + }, + }, + }, + expectedTools: []anthropic.ToolUnionParam{ + { + OfTool: &anthropic.ToolParam{ + Name: "get_weather", + Description: anthropic.String("Get the weather without properties"), + InputSchema: anthropic.ToolInputSchemaParam{ + Type: "object", + Required: []string{"location"}, + }, + }, + }, + }, + }, + { + name: "unsupported tool_choice type", + openAIReq: &openai.ChatCompletionRequest{ + Tools: openaiTestTool, + ToolChoice: &openai.ChatCompletionToolChoiceUnion{Value: 123}, // Use an integer to trigger the default case. 
+ }, + expectErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tools, toolChoice, err := translateOpenAItoAnthropicTools(tt.openAIReq.Tools, tt.openAIReq.ToolChoice, tt.openAIReq.ParallelToolCalls) + if tt.expectErr { + require.Error(t, err) + } else { + require.NoError(t, err) + if tt.openAIReq.ToolChoice != nil { + require.NotNil(t, toolChoice) + require.Equal(t, *tt.expectedToolChoice.GetType(), *toolChoice.GetType()) + if tt.expectedToolChoice.GetName() != nil { + require.Equal(t, *tt.expectedToolChoice.GetName(), *toolChoice.GetName()) + } + if tt.expectedToolChoice.OfTool != nil { + require.Equal(t, tt.expectedToolChoice.OfTool.Name, toolChoice.OfTool.Name) + } + if tt.expectedToolChoice.OfAuto != nil { + require.Equal(t, tt.expectedToolChoice.OfAuto.DisableParallelToolUse, toolChoice.OfAuto.DisableParallelToolUse) + } + } + if tt.openAIReq.Tools != nil { + require.NotNil(t, tools) + require.Len(t, tools, len(tt.expectedTools)) + require.Equal(t, tt.expectedTools[0].GetName(), tools[0].GetName()) + require.Equal(t, tt.expectedTools[0].GetType(), tools[0].GetType()) + require.Equal(t, tt.expectedTools[0].GetDescription(), tools[0].GetDescription()) + if tt.expectedTools[0].GetInputSchema().Properties != nil { + require.Equal(t, tt.expectedTools[0].GetInputSchema().Properties, tools[0].GetInputSchema().Properties) + } + } + } + }) + } +} + +// TestFinishReasonTranslation covers specific cases for the anthropicToOpenAIFinishReason function. 
+func TestFinishReasonTranslation(t *testing.T) { + tests := []struct { + name string + input anthropic.StopReason + expectedFinishReason openai.ChatCompletionChoicesFinishReason + expectErr bool + }{ + { + name: "max tokens stop reason", + input: anthropic.StopReasonMaxTokens, + expectedFinishReason: openai.ChatCompletionChoicesFinishReasonLength, + }, + { + name: "refusal stop reason", + input: anthropic.StopReasonRefusal, + expectedFinishReason: openai.ChatCompletionChoicesFinishReasonContentFilter, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + reason, err := anthropicToOpenAIFinishReason(tt.input) + if tt.expectErr { + require.Error(t, err) + } else { + require.NoError(t, err) + require.Equal(t, tt.expectedFinishReason, reason) + } + }) + } +} + +// TestToolParameterDereferencing tests the JSON schema dereferencing functionality +// for tool parameters when translating from OpenAI to GCP Anthropic. +func TestToolParameterDereferencing(t *testing.T) { + tests := []struct { + name string + openAIReq *openai.ChatCompletionRequest + expectedTools []anthropic.ToolUnionParam + expectedToolChoice anthropic.ToolChoiceUnionParam + expectErr bool + expectedErrMsg string + }{ + { + name: "tool with complex nested $ref - successful dereferencing", + openAIReq: &openai.ChatCompletionRequest{ + Tools: []openai.Tool{ + { + Type: "function", + Function: &openai.FunctionDefinition{ + Name: "complex_tool", + Description: "Tool with complex nested references", + Parameters: map[string]any{ + "type": "object", + "$defs": map[string]any{ + "BaseType": map[string]any{ + "type": "object", + "properties": map[string]any{ + "id": map[string]any{ + "type": "string", + }, + "required": []any{"id"}, + }, + }, + "NestedType": map[string]any{ + "allOf": []any{ + map[string]any{"$ref": "#/$defs/BaseType"}, + map[string]any{ + "properties": map[string]any{ + "name": map[string]any{ + "type": "string", + }, + }, + }, + }, + }, + }, + "properties": 
map[string]any{ + "nested": map[string]any{ + "$ref": "#/$defs/NestedType", + }, + }, + }, + }, + }, + }, + }, + expectedTools: []anthropic.ToolUnionParam{ + { + OfTool: &anthropic.ToolParam{ + Name: "complex_tool", + Description: anthropic.String("Tool with complex nested references"), + InputSchema: anthropic.ToolInputSchemaParam{ + Type: "object", + Properties: map[string]any{ + "nested": map[string]any{ + "allOf": []any{ + map[string]any{ + "type": "object", + "properties": map[string]any{ + "id": map[string]any{ + "type": "string", + }, + "required": []any{"id"}, + }, + }, + map[string]any{ + "properties": map[string]any{ + "name": map[string]any{ + "type": "string", + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + { + name: "tool with invalid $ref - dereferencing error", + openAIReq: &openai.ChatCompletionRequest{ + Tools: []openai.Tool{ + { + Type: "function", + Function: &openai.FunctionDefinition{ + Name: "invalid_ref_tool", + Description: "Tool with invalid reference", + Parameters: map[string]any{ + "type": "object", + "properties": map[string]any{ + "location": map[string]any{ + "$ref": "#/$defs/NonExistent", + }, + }, + }, + }, + }, + }, + }, + expectErr: true, + expectedErrMsg: "failed to dereference tool parameters", + }, + { + name: "tool with circular $ref - dereferencing error", + openAIReq: &openai.ChatCompletionRequest{ + Tools: []openai.Tool{ + { + Type: "function", + Function: &openai.FunctionDefinition{ + Name: "circular_ref_tool", + Description: "Tool with circular reference", + Parameters: map[string]any{ + "type": "object", + "$defs": map[string]any{ + "A": map[string]any{ + "type": "object", + "properties": map[string]any{ + "b": map[string]any{ + "$ref": "#/$defs/B", + }, + }, + }, + "B": map[string]any{ + "type": "object", + "properties": map[string]any{ + "a": map[string]any{ + "$ref": "#/$defs/A", + }, + }, + }, + }, + "properties": map[string]any{ + "circular": map[string]any{ + "$ref": "#/$defs/A", + }, + }, + }, + }, + }, + 
}, + }, + expectErr: true, + expectedErrMsg: "failed to dereference tool parameters", + }, + { + name: "tool without $ref - no dereferencing needed", + openAIReq: &openai.ChatCompletionRequest{ + Tools: []openai.Tool{ + { + Type: "function", + Function: &openai.FunctionDefinition{ + Name: "simple_tool", + Description: "Simple tool without references", + Parameters: map[string]any{ + "type": "object", + "properties": map[string]any{ + "location": map[string]any{ + "type": "string", + }, + }, + "required": []any{"location"}, + }, + }, + }, + }, + }, + expectedTools: []anthropic.ToolUnionParam{ + { + OfTool: &anthropic.ToolParam{ + Name: "simple_tool", + Description: anthropic.String("Simple tool without references"), + InputSchema: anthropic.ToolInputSchemaParam{ + Type: "object", + Properties: map[string]any{ + "location": map[string]any{ + "type": "string", + }, + }, + Required: []string{"location"}, + }, + }, + }, + }, + }, + { + name: "tool parameter dereferencing returns non-map type - casting error", + openAIReq: &openai.ChatCompletionRequest{ + Tools: []openai.Tool{ + { + Type: "function", + Function: &openai.FunctionDefinition{ + Name: "problematic_tool", + Description: "Tool with parameters that can't be properly dereferenced to map", + // This creates a scenario where jsonSchemaDereference might return a non-map type + // though this is a contrived example since normally the function should return map[string]any + Parameters: map[string]any{ + "$ref": "#/$defs/StringType", // This would resolve to a string, not a map + "$defs": map[string]any{ + "StringType": "not-a-map", // This would cause the casting to fail + }, + }, + }, + }, + }, + }, + expectErr: true, + expectedErrMsg: "failed to cast dereferenced tool parameters", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tools, toolChoice, err := translateOpenAItoAnthropicTools(tt.openAIReq.Tools, tt.openAIReq.ToolChoice, tt.openAIReq.ParallelToolCalls) + + if tt.expectErr { + 
require.Error(t, err) + if tt.expectedErrMsg != "" { + require.Contains(t, err.Error(), tt.expectedErrMsg) + } + return + } + + require.NoError(t, err) + + if tt.openAIReq.Tools != nil { + require.NotNil(t, tools) + require.Len(t, tools, len(tt.expectedTools)) + + for i, expectedTool := range tt.expectedTools { + actualTool := tools[i] + require.Equal(t, expectedTool.GetName(), actualTool.GetName()) + require.Equal(t, expectedTool.GetType(), actualTool.GetType()) + require.Equal(t, expectedTool.GetDescription(), actualTool.GetDescription()) + + expectedSchema := expectedTool.GetInputSchema() + actualSchema := actualTool.GetInputSchema() + + require.Equal(t, expectedSchema.Type, actualSchema.Type) + require.Equal(t, expectedSchema.Required, actualSchema.Required) + + // For properties, we'll do a deep comparison to verify dereferencing worked + if expectedSchema.Properties != nil { + require.NotNil(t, actualSchema.Properties) + require.Equal(t, expectedSchema.Properties, actualSchema.Properties) + } + } + } + + if tt.openAIReq.ToolChoice != nil { + require.NotNil(t, toolChoice) + require.Equal(t, *tt.expectedToolChoice.GetType(), *toolChoice.GetType()) + } + }) + } +} + +// TestContentTranslationCoverage adds specific coverage for the openAIToAnthropicContent helper. 
+// TestContentTranslationCoverage adds specific coverage for the openAIToAnthropicContent helper,
+// exercising nil/empty input, PDF (data-URI and URL), image URL, and the audio error path.
+func TestContentTranslationCoverage(t *testing.T) {
+	tests := []struct {
+		name            string
+		inputContent    any
+		expectedContent []anthropic.ContentBlockParamUnion
+		expectErr       bool
+	}{
+		{
+			name:         "nil content",
+			inputContent: nil,
+		},
+		{
+			name:         "empty string content",
+			inputContent: "",
+		},
+		{
+			name: "pdf data uri",
+			inputContent: []openai.ChatCompletionContentPartUserUnionParam{
+				{OfImageURL: &openai.ChatCompletionContentPartImageParam{ImageURL: openai.ChatCompletionContentPartImageImageURLParam{URL: "data:application/pdf;base64,dGVzdA=="}}},
+			},
+			expectedContent: []anthropic.ContentBlockParamUnion{
+				{
+					OfDocument: &anthropic.DocumentBlockParam{
+						Source: anthropic.DocumentBlockParamSourceUnion{
+							OfBase64: &anthropic.Base64PDFSourceParam{
+								Type:      constant.ValueOf[constant.Base64](),
+								MediaType: constant.ValueOf[constant.ApplicationPDF](),
+								Data:      "dGVzdA==",
+							},
+						},
+					},
+				},
+			},
+		},
+		{
+			name: "pdf url",
+			inputContent: []openai.ChatCompletionContentPartUserUnionParam{
+				{OfImageURL: &openai.ChatCompletionContentPartImageParam{ImageURL: openai.ChatCompletionContentPartImageImageURLParam{URL: "https://example.com/doc.pdf"}}},
+			},
+			expectedContent: []anthropic.ContentBlockParamUnion{
+				{
+					OfDocument: &anthropic.DocumentBlockParam{
+						Source: anthropic.DocumentBlockParamSourceUnion{
+							OfURL: &anthropic.URLPDFSourceParam{
+								Type: constant.ValueOf[constant.URL](),
+								URL:  "https://example.com/doc.pdf",
+							},
+						},
+					},
+				},
+			},
+		},
+		{
+			name: "image url",
+			inputContent: []openai.ChatCompletionContentPartUserUnionParam{
+				{OfImageURL: &openai.ChatCompletionContentPartImageParam{ImageURL: openai.ChatCompletionContentPartImageImageURLParam{URL: "https://example.com/image.png"}}},
+			},
+			expectedContent: []anthropic.ContentBlockParamUnion{
+				{
+					OfImage: &anthropic.ImageBlockParam{
+						Source: anthropic.ImageBlockParamSourceUnion{
+							OfURL: &anthropic.URLImageSourceParam{
+								Type: constant.ValueOf[constant.URL](),
+								URL:  "https://example.com/image.png",
+							},
+						},
+					},
+				},
+			},
+		},
+		{
+			name:         "audio content error",
+			inputContent: []openai.ChatCompletionContentPartUserUnionParam{{OfInputAudio: &openai.ChatCompletionContentPartInputAudioParam{}}},
+			expectErr:    true,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			content, err := openAIToAnthropicContent(tt.inputContent)
+			if tt.expectErr {
+				require.Error(t, err)
+				return
+			}
+			require.NoError(t, err)
+
+			// Use direct assertions instead of cmp.Diff to avoid panics on unexported fields.
+			// Note: this verification block previously appeared twice (copy-paste duplication);
+			// it is now executed exactly once per test case.
+			require.Len(t, content, len(tt.expectedContent), "Number of content blocks should match")
+			for i, expectedBlock := range tt.expectedContent {
+				actualBlock := content[i]
+				require.Equal(t, expectedBlock.GetType(), actualBlock.GetType(), "Content block types should match")
+				if expectedBlock.OfDocument != nil {
+					require.NotNil(t, actualBlock.OfDocument, "Expected a document block, but got nil")
+					require.NotNil(t, actualBlock.OfDocument.Source, "Document source should not be nil")
+
+					if expectedBlock.OfDocument.Source.OfBase64 != nil {
+						require.NotNil(t, actualBlock.OfDocument.Source.OfBase64, "Expected a base64 source")
+						require.Equal(t, expectedBlock.OfDocument.Source.OfBase64.Data, actualBlock.OfDocument.Source.OfBase64.Data)
+					}
+					if expectedBlock.OfDocument.Source.OfURL != nil {
+						require.NotNil(t, actualBlock.OfDocument.Source.OfURL, "Expected a URL source")
+						require.Equal(t, expectedBlock.OfDocument.Source.OfURL.URL, actualBlock.OfDocument.Source.OfURL.URL)
+					}
+				}
+				if expectedBlock.OfImage != nil {
+					require.NotNil(t, actualBlock.OfImage, "Expected an image block, but got nil")
+					require.NotNil(t, actualBlock.OfImage.Source, "Image source should not be nil")
+
+					if expectedBlock.OfImage.Source.OfURL != nil {
+						require.NotNil(t, actualBlock.OfImage.Source.OfURL, "Expected a URL image source")
+						require.Equal(t, expectedBlock.OfImage.Source.OfURL.URL, actualBlock.OfImage.Source.OfURL.URL)
+					}
+				}
+			}
+		})
+	}
+}
+
+// TestSystemPromptExtractionCoverage adds specific coverage for the extractSystemPromptFromDeveloperMsg helper.
+func TestSystemPromptExtractionCoverage(t *testing.T) {
+	tests := []struct {
+		name           string
+		inputMsg       openai.ChatCompletionDeveloperMessageParam
+		expectedPrompt string
+	}{
+		{
+			// Multiple text parts are concatenated in order.
+			name: "developer message with content parts",
+			inputMsg: openai.ChatCompletionDeveloperMessageParam{
+				Content: openai.ContentUnion{Value: []openai.ChatCompletionContentPartTextParam{
+					{Type: "text", Text: "part 1"},
+					{Type: "text", Text: " part 2"},
+				}},
+			},
+			expectedPrompt: "part 1 part 2",
+		},
+		{
+			// Nil content yields an empty prompt rather than an error.
+			name:           "developer message with nil content",
+			inputMsg:       openai.ChatCompletionDeveloperMessageParam{Content: openai.ContentUnion{Value: nil}},
+			expectedPrompt: "",
+		},
+		{
+			// Plain string content is passed through unchanged.
+			name: "developer message with string content",
+			inputMsg: openai.ChatCompletionDeveloperMessageParam{
+				Content: openai.ContentUnion{Value: "simple string"},
+			},
+			expectedPrompt: "simple string",
+		},
+		{
+			// Single-element text-part array behaves the same as a plain string.
+			name: "developer message with text parts array",
+			inputMsg: openai.ChatCompletionDeveloperMessageParam{
+				Content: openai.ContentUnion{Value: []openai.ChatCompletionContentPartTextParam{
+					{Type: "text", Text: "text part"},
+				}},
+			},
+			expectedPrompt: "text part",
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			// Only the extracted prompt string is asserted here; the helper's
+			// second return value is intentionally ignored by this coverage test.
+			prompt, _ := extractSystemPromptFromDeveloperMsg(tt.inputMsg)
+			require.Equal(t, tt.expectedPrompt, prompt)
+		})
+	}
+}
diff --git a/internal/translator/openai_awsanthropic.go b/internal/translator/openai_awsanthropic.go
new file mode 100644
index 0000000000..3da7e8ab69
--- /dev/null
+++ b/internal/translator/openai_awsanthropic.go
@@ -0,0 +1,261 @@
+// Copyright Envoy AI Gateway Authors
+// SPDX-License-Identifier: Apache-2.0
+// The full text of the Apache license is available in the LICENSE file at
+// the root of the repo.
+
+package translator
+
+import (
+	"bytes"
+	"encoding/base64"
+	"fmt"
+	"io"
+	"net/url"
+	"strconv"
+	"strings"
+
+	"github.com/anthropics/anthropic-sdk-go"
+	"github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream"
+	"github.com/tidwall/gjson"
+	"github.com/tidwall/sjson"
+
+	"github.com/envoyproxy/ai-gateway/internal/apischema/awsbedrock"
+	"github.com/envoyproxy/ai-gateway/internal/apischema/openai"
+	"github.com/envoyproxy/ai-gateway/internal/internalapi"
+	"github.com/envoyproxy/ai-gateway/internal/json"
+	"github.com/envoyproxy/ai-gateway/internal/metrics"
+	"github.com/envoyproxy/ai-gateway/internal/tracing/tracingapi"
+)
+
+// BedrockDefaultVersion is the value written as "anthropic_version" in the
+// request body when no explicit API version is configured on the translator.
+// https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html
+const BedrockDefaultVersion = "bedrock-2023-05-31"
+
+// NewChatCompletionOpenAIToAWSAnthropicTranslator implements [Factory] for OpenAI to AWS Anthropic translation.
+// This translator converts OpenAI ChatCompletion API requests to AWS Anthropic API format.
+// apiVersion, when non-empty, overrides the "anthropic_version" field in the outgoing body;
+// modelNameOverride, when non-empty, replaces the model taken from the incoming request.
+func NewChatCompletionOpenAIToAWSAnthropicTranslator(apiVersion string, modelNameOverride internalapi.ModelNameOverride) OpenAIChatCompletionTranslator {
+	return &openAIToAWSAnthropicTranslatorV1ChatCompletion{
+		apiVersion:        apiVersion,
+		modelNameOverride: modelNameOverride,
+	}
+}
+
+// openAIToAWSAnthropicTranslatorV1ChatCompletion translates OpenAI Chat Completions API to AWS Anthropic Claude API.
+// This uses the AWS Bedrock InvokeModel and InvokeModelWithResponseStream APIs:
+// https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html
+type openAIToAWSAnthropicTranslatorV1ChatCompletion struct {
+	// apiVersion is written as "anthropic_version"; empty means BedrockDefaultVersion.
+	apiVersion string
+	// modelNameOverride, when non-empty, replaces the request's model name.
+	modelNameOverride internalapi.ModelNameOverride
+	// streamParser is non-nil only for streaming requests (set in RequestBody).
+	streamParser *anthropicStreamParser
+	// requestModel is the effective model used in the request path and as the
+	// fallback response model.
+	requestModel internalapi.RequestModel
+	// bufferedBody accumulates partial AWS EventStream frames across
+	// ResponseBody calls until a complete message can be decoded.
+	bufferedBody []byte
+}
+
+// RequestBody implements [OpenAIChatCompletionTranslator.RequestBody] for AWS Anthropic.
+func (o *openAIToAWSAnthropicTranslatorV1ChatCompletion) RequestBody(_ []byte, openAIReq *openai.ChatCompletionRequest, _ bool) (
+	newHeaders []internalapi.Header, newBody []byte, err error,
+) {
+	// Resolve the effective model: the configured override wins over the
+	// model named in the incoming OpenAI request.
+	o.requestModel = openAIReq.Model
+	if o.modelNameOverride != "" {
+		o.requestModel = o.modelNameOverride
+	}
+
+	// URL encode the model name for the path to handle special characters (e.g., ARNs)
+	encodedModelName := url.PathEscape(o.requestModel)
+
+	// Set the path for AWS Bedrock InvokeModel API
+	// https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_InvokeModel.html#API_runtime_InvokeModel_RequestSyntax
+	pathTemplate := "/model/%s/invoke"
+	if openAIReq.Stream {
+		// Streaming requests use the dedicated streaming endpoint; creating the
+		// parser here also marks this translator as streaming for the
+		// ResponseHeaders/ResponseBody paths.
+		pathTemplate = "/model/%s/invoke-with-response-stream"
+		o.streamParser = newAnthropicStreamParser(o.requestModel)
+	}
+
+	params, err := buildAnthropicParams(openAIReq)
+	if err != nil {
+		return
+	}
+
+	body, err := json.Marshal(params)
+	if err != nil {
+		return
+	}
+
+	// Set the "anthropic_version" key in the JSON body.
+	// Using same logic as anthropic go SDK: https://github.com/anthropics/anthropic-sdk-go/blob/e252e284244755b2b2f6eef292b09d6d1e6cd989/bedrock/bedrock.go#L167
+	anthropicVersion := BedrockDefaultVersion
+	if o.apiVersion != "" {
+		anthropicVersion = o.apiVersion
+	}
+	body, err = sjson.SetBytes(body, anthropicVersionKey, anthropicVersion)
+	if err != nil {
+		return
+	}
+	newBody = body
+
+	newHeaders = []internalapi.Header{
+		{pathHeaderName, fmt.Sprintf(pathTemplate, encodedModelName)},
+		{contentLengthHeaderName, strconv.Itoa(len(newBody))},
+	}
+	return
+}
+
+// ResponseError implements [OpenAIChatCompletionTranslator.ResponseError].
+// Translate AWS Bedrock exceptions to OpenAI error type.
+// The error type is stored in the "x-amzn-errortype" HTTP header for AWS error responses.
+func (o *openAIToAWSAnthropicTranslatorV1ChatCompletion) ResponseError(respHeaders map[string]string, body io.Reader) (
+	newHeaders []internalapi.Header, newBody []byte, err error,
+) {
+	statusCode := respHeaders[statusHeaderName]
+	var openaiError openai.Error
+	if v, ok := respHeaders[contentTypeHeaderName]; ok && strings.Contains(v, jsonContentType) {
+		// JSON error body: decode the Bedrock exception and surface its message,
+		// carrying the AWS error type from the x-amzn-errortype header.
+		var bedrockError awsbedrock.BedrockException
+		if err = json.NewDecoder(body).Decode(&bedrockError); err != nil {
+			return nil, nil, fmt.Errorf("failed to unmarshal error body: %w", err)
+		}
+		openaiError = openai.Error{
+			Type: "error",
+			Error: openai.ErrorType{
+				Type:    respHeaders[awsErrorTypeHeaderName],
+				Message: bedrockError.Message,
+				Code:    &statusCode,
+			},
+		}
+	} else {
+		// Non-JSON error body: pass the raw payload through as the message.
+		var buf []byte
+		buf, err = io.ReadAll(body)
+		if err != nil {
+			return nil, nil, fmt.Errorf("failed to read error body: %w", err)
+		}
+		openaiError = openai.Error{
+			Type: "error",
+			Error: openai.ErrorType{
+				Type:    awsBedrockBackendError,
+				Message: string(buf),
+				Code:    &statusCode,
+			},
+		}
+	}
+	newBody, err = json.Marshal(openaiError)
+	if err != nil {
+		return nil, nil, fmt.Errorf("failed to marshal error body: %w", err)
+	}
+	newHeaders = []internalapi.Header{
+		{contentTypeHeaderName, jsonContentType},
+		{contentLengthHeaderName, strconv.Itoa(len(newBody))},
+	}
+	return
+}
+
+// ResponseHeaders implements [OpenAIChatCompletionTranslator.ResponseHeaders].
+// For streaming requests the response content type is rewritten to SSE;
+// non-streaming responses leave the headers untouched.
+func (o *openAIToAWSAnthropicTranslatorV1ChatCompletion) ResponseHeaders(_ map[string]string) (
+	newHeaders []internalapi.Header, err error,
+) {
+	if o.streamParser != nil {
+		newHeaders = []internalapi.Header{{contentTypeHeaderName, eventStreamContentType}}
+	}
+	return
+}
+
+// ResponseBody implements [OpenAIChatCompletionTranslator.ResponseBody] for AWS Anthropic.
+// AWS Anthropic uses deterministic model mapping without virtualization, where the requested model
+// is exactly what gets executed. The response does not contain a model field, so we return
+// the request model that was originally sent.
+func (o *openAIToAWSAnthropicTranslatorV1ChatCompletion) ResponseBody(_ map[string]string, body io.Reader, endOfStream bool, span tracingapi.ChatCompletionSpan) (
+	newHeaders []internalapi.Header, newBody []byte, tokenUsage metrics.TokenUsage, responseModel string, err error,
+) {
+	// If a stream parser was initialized, this is a streaming request.
+	if o.streamParser != nil {
+		// AWS Bedrock wraps Anthropic events in EventStream binary format
+		// We need to decode EventStream and extract the SSE payload
+		buf, readErr := io.ReadAll(body)
+		if readErr != nil {
+			return nil, nil, metrics.TokenUsage{}, "", fmt.Errorf("failed to read stream body: %w", readErr)
+		}
+
+		// Buffer the data for EventStream decoding
+		o.bufferedBody = append(o.bufferedBody, buf...)
+
+		// Extract Anthropic SSE from AWS EventStream wrapper
+		// This decodes the base64-encoded events and formats them as SSE
+		anthropicSSE := o.extractAnthropicSSEFromEventStream()
+
+		// Pass the extracted SSE to the Anthropic parser
+		return o.streamParser.Process(bytes.NewReader(anthropicSSE), endOfStream, span)
+	}
+
+	var anthropicResp anthropic.Message
+	if err = json.NewDecoder(body).Decode(&anthropicResp); err != nil {
+		return nil, nil, tokenUsage, "", fmt.Errorf("failed to unmarshal body: %w", err)
+	}
+
+	// Prefer a model echoed by the backend when present; otherwise fall back to
+	// the request model (see the doc comment above: responses may omit it).
+	responseModel = o.requestModel
+	if anthropicResp.Model != "" {
+		responseModel = string(anthropicResp.Model)
+	}
+
+	openAIResp, tokenUsage, err := messageToChatCompletion(&anthropicResp, responseModel)
+	if err != nil {
+		return nil, nil, metrics.TokenUsage{}, "", err
+	}
+
+	newBody, err = json.Marshal(openAIResp)
+	if err != nil {
+		return nil, nil, metrics.TokenUsage{}, "", fmt.Errorf("failed to marshal body: %w", err)
+	}
+
+	if span != nil {
+		span.RecordResponse(openAIResp)
+	}
+	newHeaders = []internalapi.Header{{contentLengthHeaderName, strconv.Itoa(len(newBody))}}
+	return
+}
+
+// extractAnthropicSSEFromEventStream decodes AWS EventStream binary format
+// and extracts Anthropic events, converting them to SSE format.
+// AWS Bedrock wraps each Anthropic event as base64-encoded JSON in EventStream messages.
+// Fully decoded messages are consumed from bufferedBody; any trailing partial
+// frame is kept buffered for the next call.
+func (o *openAIToAWSAnthropicTranslatorV1ChatCompletion) extractAnthropicSSEFromEventStream() []byte {
+	if len(o.bufferedBody) == 0 {
+		return nil
+	}
+
+	r := bytes.NewReader(o.bufferedBody)
+	dec := eventstream.NewDecoder()
+	var result []byte
+	// lastRead is the byte offset just past the last fully decoded message.
+	var lastRead int64
+
+	for {
+		msg, err := dec.Decode(r, nil)
+		if err != nil {
+			// End of stream or incomplete message - keep remaining data buffered
+			// NOTE(review): any decode error takes this path, so a corrupt frame
+			// stays buffered and would be retried on the next call — confirm the
+			// upstream eventually terminates the stream in that case.
+			o.bufferedBody = o.bufferedBody[lastRead:]
+			return result
+		}
+
+		// AWS Bedrock payload format: {"bytes":"base64-encoded-json","p":"..."}
+		var payload struct {
+			Bytes string `json:"bytes"` // base64-encoded Anthropic event JSON
+		}
+		if unMarshalErr := json.Unmarshal(msg.Payload, &payload); unMarshalErr != nil || payload.Bytes == "" {
+			// Skip frames without a decodable "bytes" payload; still advance
+			// lastRead so the frame is consumed from the buffer.
+			lastRead = r.Size() - int64(r.Len())
+			continue
+		}
+
+		// Base64 decode to get the Anthropic event JSON
+		decodedBytes, err := base64.StdEncoding.DecodeString(payload.Bytes)
+		if err != nil {
+			lastRead = r.Size() - int64(r.Len())
+			continue
+		}
+
+		// Extract the event type from JSON
+		// Use gjson for robust extraction even from malformed JSON
+		eventType := gjson.GetBytes(decodedBytes, "type").String()
+
+		// Convert to SSE format: "event: TYPE\ndata: JSON\n\n"
+		// Pass through even if malformed - streamParser will detect and report errors
+		sseEvent := fmt.Sprintf("event: %s\ndata: %s\n\n", eventType, string(decodedBytes))
+		result = append(result, []byte(sseEvent)...)
+
+		lastRead = r.Size() - int64(r.Len())
+	}
+}
diff --git a/internal/translator/openai_awsanthropic_test.go b/internal/translator/openai_awsanthropic_test.go
new file mode 100644
index 0000000000..4c983bc03c
--- /dev/null
+++ b/internal/translator/openai_awsanthropic_test.go
@@ -0,0 +1,812 @@
+// Copyright Envoy AI Gateway Authors
+// SPDX-License-Identifier: Apache-2.0
+// The full text of the Apache license is available in the LICENSE file at
+// the root of the repo.
+
+package translator
+
+import (
+	"bytes"
+	"encoding/base64"
+	stdjson "encoding/json" // nolint: depguard
+	"fmt"
+	"io"
+	"strconv"
+	"testing"
+	"time"
+
+	"github.com/anthropics/anthropic-sdk-go"
+	"github.com/anthropics/anthropic-sdk-go/shared/constant"
+	"github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream"
+	"github.com/google/go-cmp/cmp"
+	"github.com/google/go-cmp/cmp/cmpopts"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+	"github.com/tidwall/gjson"
+	"k8s.io/utils/ptr"
+
+	"github.com/envoyproxy/ai-gateway/internal/apischema/awsbedrock"
+	"github.com/envoyproxy/ai-gateway/internal/apischema/openai"
+	"github.com/envoyproxy/ai-gateway/internal/json"
+)
+
+// wrapAnthropicSSEInEventStream wraps Anthropic SSE data in AWS EventStream format.
+// AWS Bedrock base64-encodes each event's JSON data (which includes the type field) and wraps it in EventStream messages.
+func wrapAnthropicSSEInEventStream(sseData string) ([]byte, error) { + buf := bytes.NewBuffer(nil) + encoder := eventstream.NewEncoder() + + // Parse SSE format to extract individual events + // SSE format: "event: TYPE\ndata: JSON\n\n" + events := bytes.Split([]byte(sseData), []byte("\n\n")) + + for _, eventBlock := range events { + if len(bytes.TrimSpace(eventBlock)) == 0 { + continue + } + + // Extract both event type and data from the SSE event + lines := bytes.Split(eventBlock, []byte("\n")) + var eventType string + var jsonData []byte + for _, line := range lines { + if bytes.HasPrefix(line, []byte("event: ")) { + eventType = string(bytes.TrimPrefix(line, []byte("event: "))) + } else if bytes.HasPrefix(line, []byte("data: ")) { + jsonData = bytes.TrimPrefix(line, []byte("data: ")) + } + } + + if len(jsonData) == 0 { + continue + } + + // AWS Bedrock Anthropic format includes the type in the JSON data itself + // If the JSON doesn't already have a "type" field (like in malformed test cases), + // we need to add it to match real AWS Bedrock behavior + var finalJSON []byte + if eventType != "" && !bytes.Contains(jsonData, []byte(`"type"`)) { + // Prepend the type field to simulate real Anthropic event format + // For malformed JSON, this creates something like: {"type": "message_start", {invalid...} + // which is still malformed, but has the type field that can be extracted + finalJSON = []byte(fmt.Sprintf(`{"type": "%s", %s`, eventType, string(jsonData[1:]))) + if jsonData[0] != '{' { + // If it doesn't even start with {, just wrap it + finalJSON = []byte(fmt.Sprintf(`{"type": "%s", "data": %s}`, eventType, string(jsonData))) + } + } else { + finalJSON = jsonData + } + + // Base64 encode the JSON data (this is what AWS Bedrock does) + base64Data := base64.StdEncoding.EncodeToString(finalJSON) + + // Create a payload with the base64-encoded data in the "bytes" field + payload := struct { + Bytes string `json:"bytes"` + }{ + Bytes: base64Data, + } + + 
payloadBytes, err := json.Marshal(payload) + if err != nil { + return nil, err + } + + // Encode as EventStream message + err = encoder.Encode(buf, eventstream.Message{ + Headers: eventstream.Headers{{Name: ":event-type", Value: eventstream.StringValue("chunk")}}, + Payload: payloadBytes, + }) + if err != nil { + return nil, err + } + } + + return buf.Bytes(), nil +} + +// TestResponseModel_AWSAnthropic tests that AWS Anthropic (non-streaming) returns the request model +// AWS Anthropic uses deterministic model mapping without virtualization +func TestResponseModel_AWSAnthropic(t *testing.T) { + modelName := "anthropic.claude-sonnet-4-20250514-v1:0" + translator := NewChatCompletionOpenAIToAWSAnthropicTranslator("", modelName) + + // Initialize translator with the model + req := &openai.ChatCompletionRequest{ + Model: "claude-sonnet-4", + MaxTokens: ptr.To(int64(100)), + Messages: []openai.ChatCompletionMessageParamUnion{ + { + OfUser: &openai.ChatCompletionUserMessageParam{ + Content: openai.StringOrUserRoleContentUnion{Value: "Hello"}, + Role: openai.ChatMessageRoleUser, + }, + }, + }, + } + reqBody, _ := json.Marshal(req) + _, _, err := translator.RequestBody(reqBody, req, false) + require.NoError(t, err) + + // AWS Anthropic response doesn't have model field, uses Anthropic format + anthropicResponse := anthropic.Message{ + ID: "msg_01XYZ", + Type: constant.ValueOf[constant.Message](), + Role: constant.ValueOf[constant.Assistant](), + Content: []anthropic.ContentBlockUnion{ + { + Type: "text", + Text: "Hello!", + }, + }, + StopReason: anthropic.StopReasonEndTurn, + Usage: anthropic.Usage{ + InputTokens: 10, + OutputTokens: 5, + }, + } + + body, err := json.Marshal(anthropicResponse) + require.NoError(t, err) + + _, _, tokenUsage, responseModel, err := translator.ResponseBody(nil, bytes.NewReader(body), true, nil) + require.NoError(t, err) + require.Equal(t, modelName, responseModel) // Returns the request model since no virtualization + inputTokens, ok := 
tokenUsage.InputTokens() + require.True(t, ok) + require.Equal(t, uint32(10), inputTokens) + outputTokens, ok := tokenUsage.OutputTokens() + require.True(t, ok) + require.Equal(t, uint32(5), outputTokens) +} + +func TestOpenAIToAWSAnthropicTranslatorV1ChatCompletion_RequestBody(t *testing.T) { + // Define a common input request to use for both standard and vertex tests. + openAIReq := &openai.ChatCompletionRequest{ + Model: "anthropic.claude-3-opus-20240229-v1:0", + Messages: []openai.ChatCompletionMessageParamUnion{ + { + OfSystem: &openai.ChatCompletionSystemMessageParam{Content: openai.ContentUnion{Value: "You are a helpful assistant."}, Role: openai.ChatMessageRoleSystem}, + }, + { + OfUser: &openai.ChatCompletionUserMessageParam{Content: openai.StringOrUserRoleContentUnion{Value: "Hello!"}, Role: openai.ChatMessageRoleUser}, + }, + }, + MaxTokens: ptr.To(int64(1024)), + Temperature: ptr.To(0.7), + } + + t.Run("AWS Bedrock InvokeModel Values Configured Correctly", func(t *testing.T) { + translator := NewChatCompletionOpenAIToAWSAnthropicTranslator("", "") + hm, body, err := translator.RequestBody(nil, openAIReq, false) + require.NoError(t, err) + require.NotNil(t, hm) + require.NotNil(t, body) + + // Check the path header. + pathHeader := hm[0] + require.Equal(t, pathHeaderName, pathHeader.Key()) + expectedPath := fmt.Sprintf("/model/%s/invoke", openAIReq.Model) + require.Equal(t, expectedPath, pathHeader.Value()) + + // Check the body content. + require.NotNil(t, body) + // Model should NOT be present in the body for AWS Bedrock. + require.False(t, gjson.GetBytes(body, "model").Exists()) + // Anthropic version should be present for AWS Bedrock. + require.Equal(t, BedrockDefaultVersion, gjson.GetBytes(body, "anthropic_version").String()) + }) + + t.Run("Model Name Override", func(t *testing.T) { + overrideModelName := "anthropic.claude-3-haiku-20240307-v1:0" + // Instantiate the translator with the model name override. 
+ translator := NewChatCompletionOpenAIToAWSAnthropicTranslator("", overrideModelName) + + // Call RequestBody with the original request, which has a different model name. + hm, _, err := translator.RequestBody(nil, openAIReq, false) + require.NoError(t, err) + require.NotNil(t, hm) + + // Check that the :path header uses the override model name. + pathHeader := hm[0] + require.Equal(t, pathHeaderName, pathHeader.Key()) + expectedPath := fmt.Sprintf("/model/%s/invoke", overrideModelName) + require.Equal(t, expectedPath, pathHeader.Value()) + }) + + t.Run("Model Name with ARN (URL encoding)", func(t *testing.T) { + arnModelName := "arn:aws:bedrock:us-east-1::foundation-model/anthropic.claude-3-opus-20240229-v1:0" + translator := NewChatCompletionOpenAIToAWSAnthropicTranslator("", arnModelName) + + hm, _, err := translator.RequestBody(nil, openAIReq, false) + require.NoError(t, err) + require.NotNil(t, hm) + + // Check that the :path header uses URL-encoded model name. + pathHeader := hm[0] + require.Equal(t, pathHeaderName, pathHeader.Key()) + // url.PathEscape encodes slashes but not colons (colons are valid in URL paths) + // So we expect slashes to be encoded as %2F + require.Contains(t, pathHeader.Value(), "arn:aws:bedrock") // Colons are not encoded + require.Contains(t, pathHeader.Value(), "%2Fanthropic") // Slashes are encoded + }) + + t.Run("Streaming Request Validation", func(t *testing.T) { + streamReq := &openai.ChatCompletionRequest{ + Model: "anthropic.claude-3-sonnet-20240229-v1:0", + Messages: []openai.ChatCompletionMessageParamUnion{}, + MaxTokens: ptr.To(int64(100)), + Stream: true, + } + translator := NewChatCompletionOpenAIToAWSAnthropicTranslator("", "") + hm, body, err := translator.RequestBody(nil, streamReq, false) + require.NoError(t, err) + require.NotNil(t, hm) + + // Check that the :path header uses the invoke-with-response-stream endpoint. 
+ pathHeader := hm + require.Equal(t, pathHeaderName, pathHeader[0].Key()) + expectedPath := fmt.Sprintf("/model/%s/invoke-with-response-stream", streamReq.Model) + require.Equal(t, expectedPath, pathHeader[0].Value()) + + // AWS Bedrock uses the endpoint path to indicate streaming (invoke-with-response-stream) + // The Anthropic Messages API body format doesn't require a "stream" field + // Verify the body is valid JSON with expected Anthropic fields + require.True(t, gjson.GetBytes(body, "max_tokens").Exists()) + require.True(t, gjson.GetBytes(body, "anthropic_version").Exists()) + }) + + t.Run("API Version Override", func(t *testing.T) { + customAPIVersion := "bedrock-2024-01-01" + // Instantiate the translator with the custom API version. + translator := NewChatCompletionOpenAIToAWSAnthropicTranslator(customAPIVersion, "") + + // Call RequestBody with a standard request. + _, body, err := translator.RequestBody(nil, openAIReq, false) + require.NoError(t, err) + require.NotNil(t, body) + + // Check that the anthropic_version in the body uses the custom version. 
+ require.Equal(t, customAPIVersion, gjson.GetBytes(body, "anthropic_version").String()) + }) + + t.Run("Invalid Temperature (above bound)", func(t *testing.T) { + invalidTempReq := &openai.ChatCompletionRequest{ + Model: "anthropic.claude-3-opus-20240229-v1:0", + Messages: []openai.ChatCompletionMessageParamUnion{}, + MaxTokens: ptr.To(int64(100)), + Temperature: ptr.To(2.5), + } + translator := NewChatCompletionOpenAIToAWSAnthropicTranslator("", "") + _, _, err := translator.RequestBody(nil, invalidTempReq, false) + require.Error(t, err) + require.Contains(t, err.Error(), fmt.Sprintf(tempNotSupportedError, *invalidTempReq.Temperature)) + }) + + t.Run("Missing MaxTokens Throws Error", func(t *testing.T) { + missingTokensReq := &openai.ChatCompletionRequest{ + Model: "anthropic.claude-3-opus-20240229-v1:0", + Messages: []openai.ChatCompletionMessageParamUnion{}, + MaxTokens: nil, + } + translator := NewChatCompletionOpenAIToAWSAnthropicTranslator("", "") + _, _, err := translator.RequestBody(nil, missingTokensReq, false) + require.ErrorContains(t, err, "max_tokens or max_completion_tokens is required") + }) +} + +func TestOpenAIToAWSAnthropicTranslatorV1ChatCompletion_ResponseBody(t *testing.T) { + t.Run("invalid json body", func(t *testing.T) { + translator := NewChatCompletionOpenAIToAWSAnthropicTranslator("", "") + _, _, _, _, err := translator.ResponseBody(map[string]string{statusHeaderName: "200"}, bytes.NewBufferString("invalid json"), true, nil) + require.Error(t, err) + require.Contains(t, err.Error(), "failed to unmarshal body") + }) + + tests := []struct { + name string + inputResponse *anthropic.Message + respHeaders map[string]string + expectedOpenAIResponse openai.ChatCompletionResponse + }{ + { + name: "basic text response", + inputResponse: &anthropic.Message{ + ID: "msg_01XYZ123", + Model: "claude-3-5-sonnet-20241022", + Role: constant.Assistant(anthropic.MessageParamRoleAssistant), + Content: []anthropic.ContentBlockUnion{{Type: "text", Text: 
"Hello there!"}}, + StopReason: anthropic.StopReasonEndTurn, + Usage: anthropic.Usage{InputTokens: 10, OutputTokens: 20, CacheReadInputTokens: 5}, + }, + respHeaders: map[string]string{statusHeaderName: "200"}, + expectedOpenAIResponse: openai.ChatCompletionResponse{ + ID: "msg_01XYZ123", + Model: "claude-3-5-sonnet-20241022", + Created: openai.JSONUNIXTime(time.Unix(releaseDateUnix, 0)), + Object: "chat.completion", + Usage: openai.Usage{ + PromptTokens: 15, + CompletionTokens: 20, + TotalTokens: 35, + PromptTokensDetails: &openai.PromptTokensDetails{ + CachedTokens: 5, + }, + }, + Choices: []openai.ChatCompletionResponseChoice{ + { + Index: 0, + Message: openai.ChatCompletionResponseChoiceMessage{Role: "assistant", Content: ptr.To("Hello there!")}, + FinishReason: openai.ChatCompletionChoicesFinishReasonStop, + }, + }, + }, + }, + { + name: "response with tool use", + inputResponse: &anthropic.Message{ + ID: "msg_01XYZ123", + Model: "claude-3-5-sonnet-20241022", + Role: constant.Assistant(anthropic.MessageParamRoleAssistant), + Content: []anthropic.ContentBlockUnion{ + {Type: "text", Text: "Ok, I will call the tool."}, + {Type: "tool_use", ID: "toolu_01", Name: "get_weather", Input: stdjson.RawMessage(`{"location": "Tokyo", "unit": "celsius"}`)}, + }, + StopReason: anthropic.StopReasonToolUse, + Usage: anthropic.Usage{InputTokens: 25, OutputTokens: 15, CacheReadInputTokens: 10}, + }, + respHeaders: map[string]string{statusHeaderName: "200"}, + expectedOpenAIResponse: openai.ChatCompletionResponse{ + ID: "msg_01XYZ123", + Model: "claude-3-5-sonnet-20241022", + Created: openai.JSONUNIXTime(time.Unix(releaseDateUnix, 0)), + Object: "chat.completion", + Usage: openai.Usage{ + PromptTokens: 35, CompletionTokens: 15, TotalTokens: 50, + PromptTokensDetails: &openai.PromptTokensDetails{ + CachedTokens: 10, + }, + }, + Choices: []openai.ChatCompletionResponseChoice{ + { + Index: 0, + FinishReason: openai.ChatCompletionChoicesFinishReasonToolCalls, + Message: 
openai.ChatCompletionResponseChoiceMessage{ + Role: string(anthropic.MessageParamRoleAssistant), + Content: ptr.To("Ok, I will call the tool."), + ToolCalls: []openai.ChatCompletionMessageToolCallParam{ + { + ID: ptr.To("toolu_01"), + Type: openai.ChatCompletionMessageToolCallTypeFunction, + Function: openai.ChatCompletionMessageToolCallFunctionParam{ + Name: "get_weather", + Arguments: `{"location": "Tokyo", "unit": "celsius"}`, + }, + }, + }, + }, + }, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + body, err := json.Marshal(tt.inputResponse) + require.NoError(t, err, "Test setup failed: could not marshal input struct") + + translator := NewChatCompletionOpenAIToAWSAnthropicTranslator("", "") + hm, body, usedToken, _, err := translator.ResponseBody(tt.respHeaders, bytes.NewBuffer(body), true, nil) + + require.NoError(t, err, "Translator returned an unexpected internal error") + require.NotNil(t, hm) + require.NotNil(t, body) + + newBody := body + require.NotNil(t, newBody) + require.Len(t, hm, 1) + require.Equal(t, contentLengthHeaderName, hm[0].Key()) + require.Equal(t, strconv.Itoa(len(newBody)), hm[0].Value()) + + var gotResp openai.ChatCompletionResponse + err = json.Unmarshal(newBody, &gotResp) + require.NoError(t, err) + + expectedTokenUsage := tokenUsageFrom( + int32(tt.expectedOpenAIResponse.Usage.PromptTokens), // nolint:gosec + int32(tt.expectedOpenAIResponse.Usage.PromptTokensDetails.CachedTokens), // nolint:gosec + int32(tt.expectedOpenAIResponse.Usage.PromptTokensDetails.CacheCreationTokens), // nolint:gosec + int32(tt.expectedOpenAIResponse.Usage.CompletionTokens), // nolint:gosec + int32(tt.expectedOpenAIResponse.Usage.TotalTokens), // nolint:gosec + ) + require.Equal(t, expectedTokenUsage, usedToken) + + if diff := cmp.Diff(tt.expectedOpenAIResponse, gotResp, cmpopts.IgnoreFields(openai.ChatCompletionResponse{}, "Created")); diff != "" { + t.Errorf("ResponseBody mismatch (-want +got):\n%s", diff) + } + }) 
+ } +} + +func TestOpenAIToAWSAnthropicTranslator_ResponseError(t *testing.T) { + tests := []struct { + name string + responseHeaders map[string]string + inputBody any + expectedOutput openai.Error + }{ + { + name: "non-json error response", + responseHeaders: map[string]string{ + statusHeaderName: "503", + contentTypeHeaderName: "text/plain; charset=utf-8", + }, + inputBody: "Service Unavailable", + expectedOutput: openai.Error{ + Type: "error", + Error: openai.ErrorType{ + Type: awsBedrockBackendError, + Code: ptr.To("503"), + Message: "Service Unavailable", + }, + }, + }, + { + name: "json error response", + responseHeaders: map[string]string{ + statusHeaderName: "400", + contentTypeHeaderName: "application/json", + awsErrorTypeHeaderName: "ValidationException", + }, + inputBody: &awsbedrock.BedrockException{ + Message: "messages: field is required", + }, + expectedOutput: openai.Error{ + Type: "error", + Error: openai.ErrorType{ + Type: "ValidationException", + Code: ptr.To("400"), + Message: "messages: field is required", + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var reader io.Reader + if bodyStr, ok := tt.inputBody.(string); ok { + reader = bytes.NewBufferString(bodyStr) + } else { + bodyBytes, err := json.Marshal(tt.inputBody) + require.NoError(t, err) + reader = bytes.NewBuffer(bodyBytes) + } + + o := &openAIToAWSAnthropicTranslatorV1ChatCompletion{} + hm, body, err := o.ResponseError(tt.responseHeaders, reader) + + require.NoError(t, err) + require.NotNil(t, body) + require.NotNil(t, hm) + require.Len(t, hm, 2) + require.Equal(t, contentTypeHeaderName, hm[0].Key()) + require.Equal(t, jsonContentType, hm[0].Value()) //nolint:testifylint + require.Equal(t, contentLengthHeaderName, hm[1].Key()) + require.Equal(t, strconv.Itoa(len(body)), hm[1].Value()) + + var gotError openai.Error + err = json.Unmarshal(body, &gotError) + require.NoError(t, err) + + if diff := cmp.Diff(tt.expectedOutput, gotError); diff != "" { 
+ t.Errorf("ResponseError() mismatch (-want +got):\n%s", diff) + } + }) + } +} + +// TestResponseModel_AWSAnthropicStreaming tests that AWS Anthropic streaming returns the request model +// AWS Anthropic uses deterministic model mapping without virtualization +func TestResponseModel_AWSAnthropicStreaming(t *testing.T) { + modelName := "anthropic.claude-sonnet-4-20250514-v1:0" + sseStream := `event: message_start +data: {"type": "message_start", "message": {"id": "msg_1nZdL29xx5MUA1yADyHTEsnR8uuvGzszyY", "type": "message", "role": "assistant", "content": [], "model": "claude-sonnet-4@20250514", "stop_reason": null, "stop_sequence": null, "usage": {"input_tokens": 10, "output_tokens": 1}}} + +event: content_block_start +data: {"type": "content_block_start", "index": 0, "content_block": {"type": "text", "text": ""}} + +event: content_block_delta +data: {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": "Hello"}} + +event: content_block_stop +data: {"type": "content_block_stop", "index": 0} + +event: message_delta +data: {"type": "message_delta", "delta": {"stop_reason": "end_turn", "stop_sequence":null}, "usage": {"output_tokens": 5}} + +event: message_stop +data: {"type": "message_stop"} + +` + // Wrap SSE data in AWS EventStream format + eventStreamData, err := wrapAnthropicSSEInEventStream(sseStream) + require.NoError(t, err) + + openAIReq := &openai.ChatCompletionRequest{ + Stream: true, + Model: modelName, + MaxTokens: new(int64), + } + + translator := NewChatCompletionOpenAIToAWSAnthropicTranslator("", "").(*openAIToAWSAnthropicTranslatorV1ChatCompletion) + _, _, err = translator.RequestBody(nil, openAIReq, false) + require.NoError(t, err) + + // Test streaming response - AWS Anthropic doesn't return model in response, uses request model + _, _, tokenUsage, responseModel, err := translator.ResponseBody(map[string]string{}, bytes.NewReader(eventStreamData), true, nil) + require.NoError(t, err) + require.Equal(t, modelName, 
responseModel) // Returns the request model since no virtualization + inputTokens, ok := tokenUsage.InputTokens() + require.True(t, ok) + require.Equal(t, uint32(10), inputTokens) + outputTokens, ok := tokenUsage.OutputTokens() + require.True(t, ok) + require.Equal(t, uint32(5), outputTokens) +} + +func TestOpenAIToAWSAnthropicTranslatorV1ChatCompletion_ResponseBody_Streaming(t *testing.T) { + t.Run("handles simple text stream", func(t *testing.T) { + sseStream := ` +event: message_start +data: {"type": "message_start", "message": {"id": "msg_1nZdL29xx5MUA1yADyHTEsnR8uuvGzszyY", "type": "message", "role": "assistant", "content": [], "model": "claude-opus-4-20250514", "stop_reason": null, "stop_sequence": null, "usage": {"input_tokens": 25, "output_tokens": 1}}} + +event: content_block_start +data: {"type": "content_block_start", "index": 0, "content_block": {"type": "text", "text": ""}} + +event: ping +data: {"type": "ping"} + +event: content_block_delta +data: {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": "Hello"}} + +event: content_block_delta +data: {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": "!"}} + +event: content_block_stop +data: {"type": "content_block_stop", "index": 0} + +event: message_delta +data: {"type": "message_delta", "delta": {"stop_reason": "end_turn", "stop_sequence":null}, "usage": {"output_tokens": 15}} + +event: message_stop +data: {"type": "message_stop"} + +` + // Wrap SSE data in AWS EventStream format + eventStreamData, err := wrapAnthropicSSEInEventStream(sseStream) + require.NoError(t, err) + + openAIReq := &openai.ChatCompletionRequest{ + Stream: true, + Model: "test-model", + MaxTokens: new(int64), + } + translator := NewChatCompletionOpenAIToAWSAnthropicTranslator("", "").(*openAIToAWSAnthropicTranslatorV1ChatCompletion) + _, _, err = translator.RequestBody(nil, openAIReq, false) + require.NoError(t, err) + + _, bm, _, _, err := 
translator.ResponseBody(map[string]string{}, bytes.NewReader(eventStreamData), true, nil) + require.NoError(t, err) + require.NotNil(t, bm) + + bodyStr := string(bm) + require.Contains(t, bodyStr, `"content":"Hello"`) + require.Contains(t, bodyStr, `"finish_reason":"stop"`) + require.Contains(t, bodyStr, `"prompt_tokens":25`) + require.Contains(t, bodyStr, `"completion_tokens":15`) + require.Contains(t, bodyStr, string(sseDoneMessage)) + }) + + t.Run("handles tool use stream", func(t *testing.T) { + sseStream := `event: message_start +data: {"type":"message_start","message":{"id":"msg_014p7gG3wDgGV9EUtLvnow3U","type":"message","role":"assistant","model":"claude-opus-4-20250514","stop_sequence":null,"usage":{"input_tokens":472,"output_tokens":2},"content":[],"stop_reason":null}} + +event: content_block_start +data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}} + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Checking weather"}} + +event: content_block_stop +data: {"type":"content_block_stop","index":0} + +event: content_block_start +data: {"type":"content_block_start","index":1,"content_block":{"type":"tool_use","id":"toolu_01T1x1fJ34qAmk2tNTrN7Up6","name":"get_weather","input":{}}} + +event: content_block_delta +data: {"type":"content_block_delta","index":1,"delta":{"type":"input_json_delta","partial_json":"{\"location\": \"San Francisco, CA\", \"unit\": \"fahrenheit\"}"}} + +event: content_block_stop +data: {"type":"content_block_stop","index":1} + +event: message_delta +data: {"type":"message_delta","delta":{"stop_reason":"tool_use","stop_sequence":null},"usage":{"output_tokens":89}} + +event: message_stop +data: {"type":"message_stop"} +` + + // Wrap SSE data in AWS EventStream format + eventStreamData, err := wrapAnthropicSSEInEventStream(sseStream) + require.NoError(t, err) + + openAIReq := &openai.ChatCompletionRequest{Stream: true, Model: "test-model", 
MaxTokens: new(int64)} + translator := NewChatCompletionOpenAIToAWSAnthropicTranslator("", "").(*openAIToAWSAnthropicTranslatorV1ChatCompletion) + _, _, err = translator.RequestBody(nil, openAIReq, false) + require.NoError(t, err) + + _, bm, _, _, err := translator.ResponseBody(map[string]string{}, bytes.NewReader(eventStreamData), true, nil) + require.NoError(t, err) + require.NotNil(t, bm) + bodyStr := string(bm) + + require.Contains(t, bodyStr, `"content":"Checking weather"`) + require.Contains(t, bodyStr, `"name":"get_weather"`) + require.Contains(t, bodyStr, `"finish_reason":"tool_calls"`) + require.Contains(t, bodyStr, string(sseDoneMessage)) + }) +} + +func TestAWSAnthropicStreamParser_ErrorHandling(t *testing.T) { + runStreamErrTest := func(t *testing.T, sseStream string, endOfStream bool) error { + // Wrap SSE data in AWS EventStream format + eventStreamData, err := wrapAnthropicSSEInEventStream(sseStream) + require.NoError(t, err) + + openAIReq := &openai.ChatCompletionRequest{Stream: true, Model: "test-model", MaxTokens: new(int64)} + translator := NewChatCompletionOpenAIToAWSAnthropicTranslator("", "").(*openAIToAWSAnthropicTranslatorV1ChatCompletion) + _, _, err = translator.RequestBody(nil, openAIReq, false) + require.NoError(t, err) + + _, _, _, _, err = translator.ResponseBody(map[string]string{}, bytes.NewReader(eventStreamData), endOfStream, nil) + return err + } + + tests := []struct { + name string + sseStream string + endOfStream bool + expectedError string + }{ + { + name: "malformed message_start event", + sseStream: "event: message_start\ndata: {invalid\n\n", + expectedError: "unmarshal message_start", + }, + { + name: "malformed content_block_start event", + sseStream: "event: content_block_start\ndata: {invalid\n\n", + expectedError: "failed to unmarshal content_block_start", + }, + { + name: "malformed error event data", + sseStream: "event: error\ndata: {invalid\n\n", + expectedError: "unparsable error event", + }, + } + + for _, tt := 
range tests { + t.Run(tt.name, func(t *testing.T) { + err := runStreamErrTest(t, tt.sseStream, tt.endOfStream) + require.Error(t, err) + require.Contains(t, err.Error(), tt.expectedError) + }) + } + + t.Run("body read error", func(t *testing.T) { + parser := newAnthropicStreamParser("test-model") + _, _, _, _, err := parser.Process(&mockErrorReader{}, false, nil) + require.Error(t, err) + require.Contains(t, err.Error(), "failed to read from stream body") + }) +} + +func TestOpenAIToAWSAnthropicTranslator_ResponseHeaders(t *testing.T) { + t.Run("non-streaming request", func(t *testing.T) { + translator := &openAIToAWSAnthropicTranslatorV1ChatCompletion{ + streamParser: nil, // Not streaming + } + headers, err := translator.ResponseHeaders(nil) + require.NoError(t, err) + require.Empty(t, headers) + }) + + t.Run("streaming request", func(t *testing.T) { + translator := &openAIToAWSAnthropicTranslatorV1ChatCompletion{ + streamParser: newAnthropicStreamParser("test-model"), + } + headers, err := translator.ResponseHeaders(nil) + require.NoError(t, err) + require.Len(t, headers, 1) + require.Equal(t, contentTypeHeaderName, headers[0].Key()) + require.Equal(t, eventStreamContentType, headers[0].Value()) + }) +} + +func TestOpenAIToAWSAnthropicTranslator_EdgeCases(t *testing.T) { + t.Run("response with model field from API", func(t *testing.T) { + // AWS Anthropic may return model field in response + modelName := "custom-override-model" + translator := NewChatCompletionOpenAIToAWSAnthropicTranslator("", modelName) + + req := &openai.ChatCompletionRequest{ + Model: "original-model", + MaxTokens: ptr.To(int64(100)), + Messages: []openai.ChatCompletionMessageParamUnion{ + {OfUser: &openai.ChatCompletionUserMessageParam{ + Content: openai.StringOrUserRoleContentUnion{Value: "Test"}, + Role: openai.ChatMessageRoleUser, + }}, + }, + } + _, _, err := translator.RequestBody(nil, req, false) + require.NoError(t, err) + + // Response with model field + anthropicResp := 
anthropic.Message{ + ID: "msg_123", + Model: "claude-3-opus-20240229", + Role: constant.Assistant(anthropic.MessageParamRoleAssistant), + Content: []anthropic.ContentBlockUnion{{Type: "text", Text: "Response"}}, + StopReason: anthropic.StopReasonEndTurn, + Usage: anthropic.Usage{InputTokens: 5, OutputTokens: 3}, + } + + body, err := json.Marshal(anthropicResp) + require.NoError(t, err) + + _, _, _, responseModel, err := translator.ResponseBody(nil, bytes.NewReader(body), true, nil) + require.NoError(t, err) + // Should use model from response when available + assert.Equal(t, string(anthropicResp.Model), responseModel) + }) + + t.Run("response without model field", func(t *testing.T) { + // AWS Anthropic typically doesn't return model field + modelName := "anthropic.claude-3-haiku-20240307-v1:0" + translator := NewChatCompletionOpenAIToAWSAnthropicTranslator("", modelName) + + req := &openai.ChatCompletionRequest{ + Model: "original-model", + MaxTokens: ptr.To(int64(100)), + Messages: []openai.ChatCompletionMessageParamUnion{ + {OfUser: &openai.ChatCompletionUserMessageParam{ + Content: openai.StringOrUserRoleContentUnion{Value: "Test"}, + Role: openai.ChatMessageRoleUser, + }}, + }, + } + _, _, err := translator.RequestBody(nil, req, false) + require.NoError(t, err) + + // Response without model field (typical for AWS Bedrock) + anthropicResp := anthropic.Message{ + ID: "msg_123", + Role: constant.Assistant(anthropic.MessageParamRoleAssistant), + Content: []anthropic.ContentBlockUnion{{Type: "text", Text: "Response"}}, + StopReason: anthropic.StopReasonEndTurn, + Usage: anthropic.Usage{InputTokens: 5, OutputTokens: 3}, + } + + body, err := json.Marshal(anthropicResp) + require.NoError(t, err) + + _, _, _, responseModel, err := translator.ResponseBody(nil, bytes.NewReader(body), true, nil) + require.NoError(t, err) + // Should use request model when response doesn't have model field + assert.Equal(t, modelName, responseModel) + }) +} diff --git 
a/internal/translator/openai_awsbedrock.go b/internal/translator/openai_awsbedrock.go index 1bc9b5d03a..90975ea1de 100644 --- a/internal/translator/openai_awsbedrock.go +++ b/internal/translator/openai_awsbedrock.go @@ -376,7 +376,6 @@ func (o *openAIToAWSBedrockTranslatorV1ChatCompletion) openAIMessageToBedrockMes CachePoint: cachePointBlock, }) } - case string: return nil, fmt.Errorf("%w: redacted_content must be a binary/bytes value in bedrock", internalapi.ErrInvalidRequestBody) default: @@ -774,7 +773,7 @@ func (o *openAIToAWSBedrockTranslatorV1ChatCompletion) ResponseBody(_ map[string } } - // AWS Bedrock does not support N(multiple choices) > 0, so there could be only one choice. + // AWS Bedrock Converse API does not support N(multiple choices) > 0, so there could be only one choice. choice := openai.ChatCompletionResponseChoice{ Index: (int64)(0), Message: openai.ChatCompletionResponseChoiceMessage{ diff --git a/internal/translator/openai_gcpanthropic.go b/internal/translator/openai_gcpanthropic.go index 197dae26ef..6fcbf829ea 100644 --- a/internal/translator/openai_gcpanthropic.go +++ b/internal/translator/openai_gcpanthropic.go @@ -6,23 +6,16 @@ package translator import ( - "cmp" - "encoding/base64" "fmt" "io" "log/slog" "strconv" "strings" - "time" "github.com/anthropics/anthropic-sdk-go" - anthropicParam "github.com/anthropics/anthropic-sdk-go/packages/param" - "github.com/anthropics/anthropic-sdk-go/shared/constant" anthropicVertex "github.com/anthropics/anthropic-sdk-go/vertex" - openAIconstant "github.com/openai/openai-go/shared/constant" "github.com/tidwall/sjson" - "github.com/envoyproxy/ai-gateway/internal/apischema/awsbedrock" "github.com/envoyproxy/ai-gateway/internal/apischema/openai" "github.com/envoyproxy/ai-gateway/internal/internalapi" "github.com/envoyproxy/ai-gateway/internal/json" @@ -33,9 +26,7 @@ import ( // currently a requirement for GCP Vertex / Anthropic API https://docs.anthropic.com/en/api/claude-on-vertex-ai const ( - 
anthropicVersionKey = "anthropic_version" - gcpBackendError = "GCPBackendError" - tempNotSupportedError = "temperature %.2f is not supported by Anthropic (must be between 0.0 and 1.0)" + gcpBackendError = "GCPBackendError" ) // NewChatCompletionOpenAIToGCPAnthropicTranslator implements [Factory] for OpenAI to GCP Anthropic translation. @@ -61,652 +52,18 @@ type openAIToGCPAnthropicTranslatorV1ChatCompletion struct { logger *slog.Logger } -func anthropicToOpenAIFinishReason(stopReason anthropic.StopReason) (openai.ChatCompletionChoicesFinishReason, error) { - switch stopReason { - // The most common stop reason. Indicates Claude finished its response naturally. - // or Claude encountered one of your custom stop sequences. - // TODO: A better way to return pause_turn - // TODO: "pause_turn" Used with server tools like web search when Claude needs to pause a long-running operation. - case anthropic.StopReasonEndTurn, anthropic.StopReasonStopSequence, anthropic.StopReasonPauseTurn: - return openai.ChatCompletionChoicesFinishReasonStop, nil - case anthropic.StopReasonMaxTokens: // Claude stopped because it reached the max_tokens limit specified in your request. - // TODO: do we want to return an error? see: https://docs.anthropic.com/en/docs/agents-and-tools/tool-use/implement-tool-use#handling-the-max-tokens-stop-reason - return openai.ChatCompletionChoicesFinishReasonLength, nil - case anthropic.StopReasonToolUse: - return openai.ChatCompletionChoicesFinishReasonToolCalls, nil - case anthropic.StopReasonRefusal: - return openai.ChatCompletionChoicesFinishReasonContentFilter, nil - default: - return "", fmt.Errorf("received invalid stop reason %v", stopReason) - } -} - -// validateTemperatureForAnthropic checks if the temperature is within Anthropic's supported range (0.0 to 1.0). -// Returns an error if the value is greater than 1.0. 
-func validateTemperatureForAnthropic(temp *float64) error { - if temp != nil && (*temp < 0.0 || *temp > 1.0) { - return fmt.Errorf("%w: temperature must be between 0.0 and 1.0", internalapi.ErrInvalidRequestBody) - } - return nil -} - -func isAnthropicSupportedImageMediaType(mediaType string) bool { - switch anthropic.Base64ImageSourceMediaType(mediaType) { - case anthropic.Base64ImageSourceMediaTypeImageJPEG, - anthropic.Base64ImageSourceMediaTypeImagePNG, - anthropic.Base64ImageSourceMediaTypeImageGIF, - anthropic.Base64ImageSourceMediaTypeImageWebP: - return true - default: - return false - } -} - -// translateAnthropicToolChoice converts the OpenAI tool_choice parameter to the Anthropic format. -func translateAnthropicToolChoice(openAIToolChoice *openai.ChatCompletionToolChoiceUnion, disableParallelToolUse anthropicParam.Opt[bool]) (anthropic.ToolChoiceUnionParam, error) { - var toolChoice anthropic.ToolChoiceUnionParam - - if openAIToolChoice == nil { - return toolChoice, nil - } - - switch choice := openAIToolChoice.Value.(type) { - case string: - switch choice { - case string(openAIconstant.ValueOf[openAIconstant.Auto]()): - toolChoice = anthropic.ToolChoiceUnionParam{OfAuto: &anthropic.ToolChoiceAutoParam{}} - toolChoice.OfAuto.DisableParallelToolUse = disableParallelToolUse - case "required", "any": - toolChoice = anthropic.ToolChoiceUnionParam{OfAny: &anthropic.ToolChoiceAnyParam{}} - toolChoice.OfAny.DisableParallelToolUse = disableParallelToolUse - case "none": - toolChoice = anthropic.ToolChoiceUnionParam{OfNone: &anthropic.ToolChoiceNoneParam{}} - case string(openAIconstant.ValueOf[openAIconstant.Function]()): - // This is how anthropic forces tool use. - // TODO: should we check if strict true in openAI request, and if so, use this? 
- toolChoice = anthropic.ToolChoiceUnionParam{OfTool: &anthropic.ToolChoiceToolParam{Name: choice}} - toolChoice.OfTool.DisableParallelToolUse = disableParallelToolUse - default: - return anthropic.ToolChoiceUnionParam{}, fmt.Errorf("%w: unsupported tool_choice value '%s'", internalapi.ErrInvalidRequestBody, choice) - } - case openai.ChatCompletionNamedToolChoice: - if choice.Type == openai.ToolTypeFunction && choice.Function.Name != "" { - toolChoice = anthropic.ToolChoiceUnionParam{ - OfTool: &anthropic.ToolChoiceToolParam{ - Type: constant.Tool("tool"), - Name: choice.Function.Name, - DisableParallelToolUse: disableParallelToolUse, - }, - } - } - default: - return anthropic.ToolChoiceUnionParam{}, fmt.Errorf("%w: tool_choice type not supported", internalapi.ErrInvalidRequestBody) - } - return toolChoice, nil -} - -// translateOpenAItoAnthropicTools translates OpenAI tool and tool_choice parameters -// into the Anthropic format and returns translated tool & tool choice. -func translateOpenAItoAnthropicTools(openAITools []openai.Tool, openAIToolChoice *openai.ChatCompletionToolChoiceUnion, parallelToolCalls *bool) (tools []anthropic.ToolUnionParam, toolChoice anthropic.ToolChoiceUnionParam, err error) { - if len(openAITools) > 0 { - anthropicTools := make([]anthropic.ToolUnionParam, 0, len(openAITools)) - for _, openAITool := range openAITools { - if openAITool.Type != openai.ToolTypeFunction || openAITool.Function == nil { - // Anthropic only supports 'function' tools, so we skip others. - continue - } - toolParam := anthropic.ToolParam{ - Name: openAITool.Function.Name, - Description: anthropic.String(openAITool.Function.Description), - } - - if isCacheEnabled(openAITool.Function.AnthropicContentFields) { - toolParam.CacheControl = anthropic.NewCacheControlEphemeralParam() - } - - // The parameters for the function are expected to be a JSON Schema object. - // We can pass them through as-is. 
- if openAITool.Function.Parameters != nil { - paramsMap, ok := openAITool.Function.Parameters.(map[string]any) - if !ok { - err = fmt.Errorf("%w: tool parameters must be a JSON object", internalapi.ErrInvalidRequestBody) - return - } - - inputSchema := anthropic.ToolInputSchemaParam{} - - // Dereference json schema - // If the paramsMap contains $refs we need to dereference them - var dereferencedParamsMap any - if dereferencedParamsMap, err = jsonSchemaDereference(paramsMap); err != nil { - return nil, anthropic.ToolChoiceUnionParam{}, fmt.Errorf("invalid JSON schema in tool parameters: %w", err) - } - if paramsMap, ok = dereferencedParamsMap.(map[string]any); !ok { - return nil, anthropic.ToolChoiceUnionParam{}, fmt.Errorf("%w: tool parameters must be a JSON object", internalapi.ErrInvalidRequestBody) - } - - var typeVal string - if typeVal, ok = paramsMap["type"].(string); ok { - inputSchema.Type = constant.Object(typeVal) - } - - var propsVal map[string]any - if propsVal, ok = paramsMap["properties"].(map[string]any); ok { - inputSchema.Properties = propsVal - } - - var requiredVal []any - if requiredVal, ok = paramsMap["required"].([]any); ok { - requiredSlice := make([]string, len(requiredVal)) - for i, v := range requiredVal { - if s, ok := v.(string); ok { - requiredSlice[i] = s - } - } - inputSchema.Required = requiredSlice - } - - toolParam.InputSchema = inputSchema - } - - anthropicTools = append(anthropicTools, anthropic.ToolUnionParam{OfTool: &toolParam}) - if len(anthropicTools) > 0 { - tools = anthropicTools - } - } - - // 2. Handle the tool_choice parameter. - // disable parallel tool use default value is false - // see: https://docs.anthropic.com/en/docs/agents-and-tools/tool-use/implement-tool-use#parallel-tool-use - disableParallelToolUse := anthropic.Bool(false) - if parallelToolCalls != nil { - // OpenAI variable checks to allow parallel tool calls. - // Anthropic variable checks to disable, so need to use the inverse. 
- disableParallelToolUse = anthropic.Bool(!*parallelToolCalls) - } - - toolChoice, err = translateAnthropicToolChoice(openAIToolChoice, disableParallelToolUse) - if err != nil { - return - } - - } - return -} - -// convertImageContentToAnthropic translates an OpenAI image URL into the corresponding Anthropic content block. -// It handles data URIs for various image types and PDFs, as well as remote URLs. -func convertImageContentToAnthropic(imageURL string, fields *openai.AnthropicContentFields) (anthropic.ContentBlockParamUnion, error) { - var cacheControlParam anthropic.CacheControlEphemeralParam - if isCacheEnabled(fields) { - cacheControlParam = fields.CacheControl - } - - switch { - case strings.HasPrefix(imageURL, "data:"): - contentType, data, err := parseDataURI(imageURL) - if err != nil { - return anthropic.ContentBlockParamUnion{}, fmt.Errorf("%w: invalid image data URI", internalapi.ErrInvalidRequestBody) - } - base64Data := base64.StdEncoding.EncodeToString(data) - if contentType == string(constant.ValueOf[constant.ApplicationPDF]()) { - pdfSource := anthropic.Base64PDFSourceParam{Data: base64Data} - docBlock := anthropic.NewDocumentBlock(pdfSource) - docBlock.OfDocument.CacheControl = cacheControlParam - return docBlock, nil - } - if isAnthropicSupportedImageMediaType(contentType) { - imgBlock := anthropic.NewImageBlockBase64(contentType, base64Data) - imgBlock.OfImage.CacheControl = cacheControlParam - return imgBlock, nil - } - return anthropic.ContentBlockParamUnion{}, fmt.Errorf("%w: invalid media_type for image '%s'", internalapi.ErrInvalidRequestBody, contentType) - case strings.HasSuffix(strings.ToLower(imageURL), ".pdf"): - docBlock := anthropic.NewDocumentBlock(anthropic.URLPDFSourceParam{URL: imageURL}) - docBlock.OfDocument.CacheControl = cacheControlParam - return docBlock, nil - default: - imgBlock := anthropic.NewImageBlock(anthropic.URLImageSourceParam{URL: imageURL}) - imgBlock.OfImage.CacheControl = cacheControlParam - return imgBlock, 
nil - } -} - -func isCacheEnabled(fields *openai.AnthropicContentFields) bool { - return fields != nil && fields.CacheControl.Type == constant.ValueOf[constant.Ephemeral]() -} - -// convertContentPartsToAnthropic iterates over a slice of OpenAI content parts -// and converts each into an Anthropic content block. -func convertContentPartsToAnthropic(parts []openai.ChatCompletionContentPartUserUnionParam) ([]anthropic.ContentBlockParamUnion, error) { - resultContent := make([]anthropic.ContentBlockParamUnion, 0, len(parts)) - for _, contentPart := range parts { - switch { - case contentPart.OfText != nil: - textBlock := anthropic.NewTextBlock(contentPart.OfText.Text) - if isCacheEnabled(contentPart.OfText.AnthropicContentFields) { - textBlock.OfText.CacheControl = contentPart.OfText.CacheControl - } - resultContent = append(resultContent, textBlock) - - case contentPart.OfImageURL != nil: - block, err := convertImageContentToAnthropic(contentPart.OfImageURL.ImageURL.URL, contentPart.OfImageURL.AnthropicContentFields) - if err != nil { - return nil, err - } - resultContent = append(resultContent, block) - - case contentPart.OfInputAudio != nil: - return nil, fmt.Errorf("%w: input audio content not supported yet", internalapi.ErrInvalidRequestBody) - case contentPart.OfFile != nil: - return nil, fmt.Errorf("%w: file content not supported yet", internalapi.ErrInvalidRequestBody) - } - } - return resultContent, nil -} - -// Helper: Convert OpenAI message content to Anthropic content. 
-func openAIToAnthropicContent(content any) ([]anthropic.ContentBlockParamUnion, error) { - switch v := content.(type) { - case nil: - return nil, nil - case string: - if v == "" { - return nil, nil - } - return []anthropic.ContentBlockParamUnion{ - anthropic.NewTextBlock(v), - }, nil - case []openai.ChatCompletionContentPartUserUnionParam: - return convertContentPartsToAnthropic(v) - case openai.ContentUnion: - switch val := v.Value.(type) { - case string: - if val == "" { - return nil, nil - } - return []anthropic.ContentBlockParamUnion{ - anthropic.NewTextBlock(val), - }, nil - case []openai.ChatCompletionContentPartTextParam: - var contentBlocks []anthropic.ContentBlockParamUnion - for _, part := range val { - textBlock := anthropic.NewTextBlock(part.Text) - // In an array of text parts, each can have its own cache setting. - if isCacheEnabled(part.AnthropicContentFields) { - textBlock.OfText.CacheControl = part.CacheControl - } - contentBlocks = append(contentBlocks, textBlock) - } - return contentBlocks, nil - default: - return nil, fmt.Errorf("%w: message 'content' must be a string or an array", internalapi.ErrInvalidRequestBody) - } - } - return nil, fmt.Errorf("%w: message 'content' must be a string or an array", internalapi.ErrInvalidRequestBody) -} - -// extractSystemPromptFromDeveloperMsg flattens content and checks for cache flags. -// It returns the combined string and a boolean indicating if any part was cacheable. -func extractSystemPromptFromDeveloperMsg(msg openai.ChatCompletionDeveloperMessageParam) (msgValue string, cacheParam *anthropic.CacheControlEphemeralParam) { - switch v := msg.Content.Value.(type) { - case nil: - return - case string: - msgValue = v - return - case []openai.ChatCompletionContentPartTextParam: - // Concatenate all text parts and check for caching. 
- var sb strings.Builder - for _, part := range v { - sb.WriteString(part.Text) - if isCacheEnabled(part.AnthropicContentFields) { - cacheParam = &part.CacheControl - } - } - msgValue = sb.String() - return - default: - return - } -} - -func anthropicRoleToOpenAIRole(role anthropic.MessageParamRole) (string, error) { - switch role { - case anthropic.MessageParamRoleAssistant: - return openai.ChatMessageRoleAssistant, nil - case anthropic.MessageParamRoleUser: - return openai.ChatMessageRoleUser, nil - default: - return "", fmt.Errorf("invalid anthropic role %v", role) - } -} - -// processAssistantContent processes a single ChatCompletionAssistantMessageParamContent and returns the corresponding Anthropic content block. -func processAssistantContent(content openai.ChatCompletionAssistantMessageParamContent) (*anthropic.ContentBlockParamUnion, error) { - switch content.Type { - case openai.ChatCompletionAssistantMessageParamContentTypeRefusal: - if content.Refusal != nil { - block := anthropic.NewTextBlock(*content.Refusal) - return &block, nil - } - case openai.ChatCompletionAssistantMessageParamContentTypeText: - if content.Text != nil { - textBlock := anthropic.NewTextBlock(*content.Text) - if isCacheEnabled(content.AnthropicContentFields) { - textBlock.OfText.CacheControl = content.CacheControl - } - return &textBlock, nil - } - case openai.ChatCompletionAssistantMessageParamContentTypeThinking: - // thinking can not be cached: https://platform.claude.com/docs/en/build-with-claude/prompt-caching - if content.Text != nil && content.Signature != nil { - thinkBlock := anthropic.NewThinkingBlock(*content.Signature, *content.Text) - return &thinkBlock, nil - } - case openai.ChatCompletionAssistantMessageParamContentTypeRedactedThinking: - if content.RedactedContent != nil { - switch v := content.RedactedContent.Value.(type) { - case string: - redactedThinkingBlock := anthropic.NewRedactedThinkingBlock(v) - return &redactedThinkingBlock, nil - case []byte: - return 
nil, fmt.Errorf("%w: redacted_content must be a string in GCP", internalapi.ErrInvalidRequestBody) - default: - return nil, fmt.Errorf("%w: redacted_content must be a string in GCP", internalapi.ErrInvalidRequestBody) - } - } - default: - return nil, fmt.Errorf("%w: message 'content' must be a string or an array", internalapi.ErrInvalidRequestBody) - } - return nil, nil -} - -// openAIMessageToAnthropicMessageRoleAssistant converts an OpenAI assistant message to Anthropic content blocks. -// The tool_use content is appended to the Anthropic message content list if tool_calls are present. -func openAIMessageToAnthropicMessageRoleAssistant(openAiMessage *openai.ChatCompletionAssistantMessageParam) (anthropicMsg anthropic.MessageParam, err error) { - contentBlocks := make([]anthropic.ContentBlockParamUnion, 0) - if v, ok := openAiMessage.Content.Value.(string); ok && len(v) > 0 { - contentBlocks = append(contentBlocks, anthropic.NewTextBlock(v)) - } else if content, ok := openAiMessage.Content.Value.(openai.ChatCompletionAssistantMessageParamContent); ok { - // Handle single content object - var block *anthropic.ContentBlockParamUnion - block, err = processAssistantContent(content) - if err != nil { - return anthropicMsg, err - } else if block != nil { - contentBlocks = append(contentBlocks, *block) - } - } else if contents, ok := openAiMessage.Content.Value.([]openai.ChatCompletionAssistantMessageParamContent); ok { - // Handle array of content objects - for _, content := range contents { - var block *anthropic.ContentBlockParamUnion - block, err = processAssistantContent(content) - if err != nil { - return anthropicMsg, err - } else if block != nil { - contentBlocks = append(contentBlocks, *block) - } - } - } - - // Handle tool_calls (if any). 
- for i := range openAiMessage.ToolCalls { - toolCall := &openAiMessage.ToolCalls[i] - var input map[string]any - if err = json.Unmarshal([]byte(toolCall.Function.Arguments), &input); err != nil { - err = fmt.Errorf("failed to unmarshal tool call arguments: %w", err) - return - } - toolUse := anthropic.ToolUseBlockParam{ - ID: *toolCall.ID, - Type: "tool_use", - Name: toolCall.Function.Name, - Input: input, - } - - if isCacheEnabled(toolCall.AnthropicContentFields) { - toolUse.CacheControl = toolCall.CacheControl - } - - contentBlocks = append(contentBlocks, anthropic.ContentBlockParamUnion{OfToolUse: &toolUse}) - } - - return anthropic.MessageParam{ - Role: anthropic.MessageParamRoleAssistant, - Content: contentBlocks, - }, nil -} - -// openAIToAnthropicMessages converts OpenAI messages to Anthropic message params type, handling all roles and system/developer logic. -func openAIToAnthropicMessages(openAIMsgs []openai.ChatCompletionMessageParamUnion) (anthropicMessages []anthropic.MessageParam, systemBlocks []anthropic.TextBlockParam, err error) { - for i := 0; i < len(openAIMsgs); { - msg := &openAIMsgs[i] - switch { - case msg.OfSystem != nil: - devParam := systemMsgToDeveloperMsg(*msg.OfSystem) - systemText, cacheControl := extractSystemPromptFromDeveloperMsg(devParam) - systemBlock := anthropic.TextBlockParam{Text: systemText} - if cacheControl != nil { - systemBlock.CacheControl = *cacheControl - } - systemBlocks = append(systemBlocks, systemBlock) - i++ - case msg.OfDeveloper != nil: - systemText, cacheControl := extractSystemPromptFromDeveloperMsg(*msg.OfDeveloper) - systemBlock := anthropic.TextBlockParam{Text: systemText} - if cacheControl != nil { - systemBlock.CacheControl = *cacheControl - } - systemBlocks = append(systemBlocks, systemBlock) - i++ - case msg.OfUser != nil: - message := *msg.OfUser - var content []anthropic.ContentBlockParamUnion - content, err = openAIToAnthropicContent(message.Content.Value) - if err != nil { - return - } - 
anthropicMsg := anthropic.MessageParam{ - Role: anthropic.MessageParamRoleUser, - Content: content, - } - anthropicMessages = append(anthropicMessages, anthropicMsg) - i++ - case msg.OfAssistant != nil: - assistantMessage := msg.OfAssistant - var messages anthropic.MessageParam - messages, err = openAIMessageToAnthropicMessageRoleAssistant(assistantMessage) - if err != nil { - return - } - anthropicMessages = append(anthropicMessages, messages) - i++ - case msg.OfTool != nil: - // Aggregate all consecutive tool messages into a single user message - // to support parallel tool use. - var toolResultBlocks []anthropic.ContentBlockParamUnion - for i < len(openAIMsgs) && openAIMsgs[i].ExtractMessgaeRole() == openai.ChatMessageRoleTool { - currentMsg := &openAIMsgs[i] - toolMsg := currentMsg.OfTool - - var contentBlocks []anthropic.ContentBlockParamUnion - contentBlocks, err = openAIToAnthropicContent(toolMsg.Content) - if err != nil { - return - } - - var toolContent []anthropic.ToolResultBlockParamContentUnion - var cacheControl *anthropic.CacheControlEphemeralParam - - for _, c := range contentBlocks { - var trb anthropic.ToolResultBlockParamContentUnion - // Check if the translated part has caching enabled. 
- switch { - case c.OfText != nil: - trb.OfText = c.OfText - cacheControl = &c.OfText.CacheControl - case c.OfImage != nil: - trb.OfImage = c.OfImage - cacheControl = &c.OfImage.CacheControl - case c.OfDocument != nil: - trb.OfDocument = c.OfDocument - cacheControl = &c.OfDocument.CacheControl - } - toolContent = append(toolContent, trb) - } - - isError := false - if contentStr, ok := toolMsg.Content.Value.(string); ok { - var contentMap map[string]any - if json.Unmarshal([]byte(contentStr), &contentMap) == nil { - if _, ok = contentMap["error"]; ok { - isError = true - } - } - } - - toolResultBlock := anthropic.ToolResultBlockParam{ - ToolUseID: toolMsg.ToolCallID, - Type: "tool_result", - Content: toolContent, - IsError: anthropic.Bool(isError), - } - - if cacheControl != nil { - toolResultBlock.CacheControl = *cacheControl - } - - toolResultBlockUnion := anthropic.ContentBlockParamUnion{OfToolResult: &toolResultBlock} - toolResultBlocks = append(toolResultBlocks, toolResultBlockUnion) - i++ - } - // Append all aggregated tool results. - anthropicMsg := anthropic.MessageParam{ - Role: anthropic.MessageParamRoleUser, - Content: toolResultBlocks, - } - anthropicMessages = append(anthropicMessages, anthropicMsg) - default: - err = fmt.Errorf("%w: unsupported role type: %s", internalapi.ErrInvalidRequestBody, msg.ExtractMessgaeRole()) - return - } - } - return -} - -// NewThinkingConfigParamUnion converts a ThinkingUnion into a ThinkingConfigParamUnion. 
-func getThinkingConfigParamUnion(tu *openai.ThinkingUnion) *anthropic.ThinkingConfigParamUnion { - if tu == nil { - return nil - } - - result := &anthropic.ThinkingConfigParamUnion{} - - if tu.OfEnabled != nil { - result.OfEnabled = &anthropic.ThinkingConfigEnabledParam{ - BudgetTokens: tu.OfEnabled.BudgetTokens, - Type: constant.Enabled(tu.OfEnabled.Type), - } - } else if tu.OfDisabled != nil { - result.OfDisabled = &anthropic.ThinkingConfigDisabledParam{ - Type: constant.Disabled(tu.OfDisabled.Type), - } - } - - return result -} - -// buildAnthropicParams is a helper function that translates an OpenAI request -// into the parameter struct required by the Anthropic SDK. -func buildAnthropicParams(openAIReq *openai.ChatCompletionRequest) (params *anthropic.MessageNewParams, err error) { - // 1. Handle simple parameters and defaults. - maxTokens := cmp.Or(openAIReq.MaxCompletionTokens, openAIReq.MaxTokens) - if maxTokens == nil { - err = fmt.Errorf("%w: max_tokens or max_completion_tokens is required", internalapi.ErrInvalidRequestBody) - return - } - - // Translate openAI contents to anthropic params. - // 2. Translate messages and system prompts. - messages, systemBlocks, err := openAIToAnthropicMessages(openAIReq.Messages) - if err != nil { - return - } - - // 3. Translate tools and tool choice. - tools, toolChoice, err := translateOpenAItoAnthropicTools(openAIReq.Tools, openAIReq.ToolChoice, openAIReq.ParallelToolCalls) - if err != nil { - return - } - - // 4. Construct the final struct in one place. 
- params = &anthropic.MessageNewParams{ - Messages: messages, - MaxTokens: *maxTokens, - System: systemBlocks, - Tools: tools, - ToolChoice: toolChoice, - } - - if openAIReq.Temperature != nil { - if err = validateTemperatureForAnthropic(openAIReq.Temperature); err != nil { - return nil, err - } - params.Temperature = anthropic.Float(*openAIReq.Temperature) - } - if openAIReq.TopP != nil { - params.TopP = anthropic.Float(*openAIReq.TopP) - } - if openAIReq.Stop.OfString.Valid() { - params.StopSequences = []string{openAIReq.Stop.OfString.String()} - } else if openAIReq.Stop.OfStringArray != nil { - params.StopSequences = openAIReq.Stop.OfStringArray - } - - // 5. Handle Vendor specific fields. - // Since GCPAnthropic follows the Anthropic API, we also check for Anthropic vendor fields. - if openAIReq.Thinking != nil { - params.Thinking = *getThinkingConfigParamUnion(openAIReq.Thinking) - } - - return params, nil -} - -// anthropicToolUseToOpenAICalls converts Anthropic tool_use content blocks to OpenAI tool calls. -func anthropicToolUseToOpenAICalls(block *anthropic.ContentBlockUnion) ([]openai.ChatCompletionMessageToolCallParam, error) { - var toolCalls []openai.ChatCompletionMessageToolCallParam - if block.Type != string(constant.ValueOf[constant.ToolUse]()) { - return toolCalls, nil - } - argsBytes, err := json.Marshal(block.Input) - if err != nil { - return nil, fmt.Errorf("failed to marshal tool_use input: %w", err) - } - toolCalls = append(toolCalls, openai.ChatCompletionMessageToolCallParam{ - ID: &block.ID, - Type: openai.ChatCompletionMessageToolCallTypeFunction, - Function: openai.ChatCompletionMessageToolCallFunctionParam{ - Name: block.Name, - Arguments: string(argsBytes), - }, - }) - - return toolCalls, nil -} - // RequestBody implements [OpenAIChatCompletionTranslator.RequestBody] for GCP. 
func (o *openAIToGCPAnthropicTranslatorV1ChatCompletion) RequestBody(_ []byte, openAIReq *openai.ChatCompletionRequest, _ bool) ( newHeaders []internalapi.Header, newBody []byte, err error, ) { params, err := buildAnthropicParams(openAIReq) if err != nil { - return nil, nil, err + return } body, err := json.Marshal(params) if err != nil { - return nil, nil, fmt.Errorf("failed to marshal params: %w", err) + return } o.requestModel = openAIReq.Model @@ -896,12 +253,9 @@ func (o *openAIToGCPAnthropicTranslatorV1ChatCompletion) ResponseBody(_ map[stri responseModel = string(anthropicResp.Model) } - openAIResp := &openai.ChatCompletionResponse{ - ID: anthropicResp.ID, - Model: responseModel, - Object: string(openAIconstant.ValueOf[openAIconstant.ChatCompletion]()), - Choices: make([]openai.ChatCompletionResponseChoice, 0), - Created: openai.JSONUNIXTime(time.Now()), + openAIResp, tokenUsage, err := messageToChatCompletion(&anthropicResp, responseModel) + if err != nil { + return nil, nil, metrics.TokenUsage{}, "", err } // Redact and log response when enabled @@ -912,88 +266,6 @@ func (o *openAIToGCPAnthropicTranslatorV1ChatCompletion) ResponseBody(_ map[stri } } - usage := anthropicResp.Usage - tokenUsage = metrics.ExtractTokenUsageFromExplicitCaching( - usage.InputTokens, - usage.OutputTokens, - &usage.CacheReadInputTokens, - &usage.CacheCreationInputTokens, - ) - inputTokens, _ := tokenUsage.InputTokens() - outputTokens, _ := tokenUsage.OutputTokens() - totalTokens, _ := tokenUsage.TotalTokens() - cachedTokens, _ := tokenUsage.CachedInputTokens() - cacheWriteTokens, _ := tokenUsage.CacheCreationInputTokens() - openAIResp.Usage = openai.Usage{ - CompletionTokens: int(outputTokens), - PromptTokens: int(inputTokens), - TotalTokens: int(totalTokens), - PromptTokensDetails: &openai.PromptTokensDetails{ - CachedTokens: int(cachedTokens), - CacheCreationTokens: int(cacheWriteTokens), - }, - } - - finishReason, err := anthropicToOpenAIFinishReason(anthropicResp.StopReason) - 
if err != nil { - return nil, nil, metrics.TokenUsage{}, "", err - } - - role, err := anthropicRoleToOpenAIRole(anthropic.MessageParamRole(anthropicResp.Role)) - if err != nil { - return nil, nil, metrics.TokenUsage{}, "", err - } - - choice := openai.ChatCompletionResponseChoice{ - Index: 0, - Message: openai.ChatCompletionResponseChoiceMessage{Role: role}, - FinishReason: finishReason, - } - - for i := range anthropicResp.Content { // NOTE: Content structure is massive, do not range over values. - output := &anthropicResp.Content[i] - switch output.Type { - case string(constant.ValueOf[constant.ToolUse]()): - if output.ID != "" { - toolCalls, toolErr := anthropicToolUseToOpenAICalls(output) - if toolErr != nil { - return nil, nil, metrics.TokenUsage{}, "", fmt.Errorf("failed to convert anthropic tool use to openai tool call: %w", toolErr) - } - choice.Message.ToolCalls = append(choice.Message.ToolCalls, toolCalls...) - } - case string(constant.ValueOf[constant.Text]()): - if output.Text != "" { - if choice.Message.Content == nil { - choice.Message.Content = &output.Text - } - } - case string(constant.ValueOf[constant.Thinking]()): - if output.Thinking != "" { - choice.Message.ReasoningContent = &openai.ReasoningContentUnion{ - Value: &openai.ReasoningContent{ - ReasoningContent: &awsbedrock.ReasoningContentBlock{ - ReasoningText: &awsbedrock.ReasoningTextBlock{ - Text: output.Thinking, - Signature: output.Signature, - }, - }, - }, - } - } - case string(constant.ValueOf[constant.RedactedThinking]()): - if output.Data != "" { - choice.Message.ReasoningContent = &openai.ReasoningContentUnion{ - Value: &openai.ReasoningContent{ - ReasoningContent: &awsbedrock.ReasoningContentBlock{ - RedactedContent: []byte(output.Data), - }, - }, - } - } - } - } - openAIResp.Choices = append(openAIResp.Choices, choice) - newBody, err = json.Marshal(openAIResp) if err != nil { return nil, nil, metrics.TokenUsage{}, "", fmt.Errorf("failed to marshal body: %w", err) diff --git 
a/internal/translator/openai_gcpanthropic_stream.go b/internal/translator/openai_gcpanthropic_stream.go deleted file mode 100644 index bee94b8af2..0000000000 --- a/internal/translator/openai_gcpanthropic_stream.go +++ /dev/null @@ -1,421 +0,0 @@ -// Copyright Envoy AI Gateway Authors -// SPDX-License-Identifier: Apache-2.0 -// The full text of the Apache license is available in the LICENSE file at -// the root of the repo. - -package translator - -import ( - "bytes" - "fmt" - "io" - "time" - - "github.com/anthropics/anthropic-sdk-go" - "github.com/anthropics/anthropic-sdk-go/shared/constant" - - "github.com/envoyproxy/ai-gateway/internal/apischema/openai" - "github.com/envoyproxy/ai-gateway/internal/internalapi" - "github.com/envoyproxy/ai-gateway/internal/json" - "github.com/envoyproxy/ai-gateway/internal/metrics" - "github.com/envoyproxy/ai-gateway/internal/tracing/tracingapi" -) - -var sseEventPrefix = []byte("event: ") - -// streamingToolCall holds the state for a single tool call that is being streamed. -type streamingToolCall struct { - id string - name string - inputJSON string -} - -// anthropicStreamParser manages the stateful translation of an Anthropic SSE stream -// to an OpenAI-compatible SSE stream. -type anthropicStreamParser struct { - buffer bytes.Buffer - activeMessageID string - activeToolCalls map[int64]*streamingToolCall - toolIndex int64 - tokenUsage metrics.TokenUsage - stopReason anthropic.StopReason - requestModel internalapi.RequestModel - sentFirstChunk bool - created openai.JSONUNIXTime -} - -// newAnthropicStreamParser creates a new parser for a streaming request. 
-func newAnthropicStreamParser(requestModel string) *anthropicStreamParser { - toolIdx := int64(-1) - return &anthropicStreamParser{ - requestModel: requestModel, - activeToolCalls: make(map[int64]*streamingToolCall), - toolIndex: toolIdx, - } -} - -func (p *anthropicStreamParser) writeChunk(eventBlock []byte, buf *[]byte) error { - chunk, err := p.parseAndHandleEvent(eventBlock) - if err != nil { - return err - } - if chunk != nil { - err := serializeOpenAIChatCompletionChunk(chunk, buf) - if err != nil { - return err - } - } - return nil -} - -// Process reads from the Anthropic SSE stream, translates events to OpenAI chunks, -// and returns the mutations for Envoy. -func (p *anthropicStreamParser) Process(body io.Reader, endOfStream bool, span tracingapi.ChatCompletionSpan) ( - newHeaders []internalapi.Header, newBody []byte, tokenUsage metrics.TokenUsage, responseModel string, err error, -) { - newBody = make([]byte, 0) - _ = span // TODO: add support for streaming chunks in tracingapi. 
- responseModel = p.requestModel - if _, err = p.buffer.ReadFrom(body); err != nil { - err = fmt.Errorf("failed to read from stream body: %w", err) - return - } - - for { - eventBlock, remaining, found := bytes.Cut(p.buffer.Bytes(), []byte("\n\n")) - if !found { - break - } - - if err = p.writeChunk(eventBlock, &newBody); err != nil { - return - } - - p.buffer.Reset() - p.buffer.Write(remaining) - } - - if endOfStream && p.buffer.Len() > 0 { - finalEventBlock := p.buffer.Bytes() - p.buffer.Reset() - - if err = p.writeChunk(finalEventBlock, &newBody); err != nil { - return - } - } - - if endOfStream { - inputTokens, _ := p.tokenUsage.InputTokens() - outputTokens, _ := p.tokenUsage.OutputTokens() - p.tokenUsage.SetTotalTokens(inputTokens + outputTokens) - totalTokens, _ := p.tokenUsage.TotalTokens() - cachedTokens, _ := p.tokenUsage.CachedInputTokens() - cacheCreationTokens, _ := p.tokenUsage.CacheCreationInputTokens() - finalChunk := &openai.ChatCompletionResponseChunk{ - ID: p.activeMessageID, - Created: p.created, - Object: "chat.completion.chunk", - Choices: []openai.ChatCompletionResponseChunkChoice{}, - Usage: &openai.Usage{ - PromptTokens: int(inputTokens), - CompletionTokens: int(outputTokens), - TotalTokens: int(totalTokens), - PromptTokensDetails: &openai.PromptTokensDetails{ - CachedTokens: int(cachedTokens), - CacheCreationTokens: int(cacheCreationTokens), - }, - }, - Model: p.requestModel, - } - - // Add active tool calls to the final chunk. 
- var toolCalls []openai.ChatCompletionChunkChoiceDeltaToolCall - for toolIndex, tool := range p.activeToolCalls { - toolCalls = append(toolCalls, openai.ChatCompletionChunkChoiceDeltaToolCall{ - ID: &tool.id, - Type: openai.ChatCompletionMessageToolCallTypeFunction, - Function: openai.ChatCompletionMessageToolCallFunctionParam{ - Name: tool.name, - Arguments: tool.inputJSON, - }, - Index: toolIndex, - }) - } - - if len(toolCalls) > 0 { - delta := openai.ChatCompletionResponseChunkChoiceDelta{ - ToolCalls: toolCalls, - } - finalChunk.Choices = append(finalChunk.Choices, openai.ChatCompletionResponseChunkChoice{ - Delta: &delta, - }) - } - - if finalChunk.Usage.PromptTokens > 0 || finalChunk.Usage.CompletionTokens > 0 || len(finalChunk.Choices) > 0 { - err := serializeOpenAIChatCompletionChunk(finalChunk, &newBody) - if err != nil { - return nil, nil, metrics.TokenUsage{}, "", fmt.Errorf("failed to marshal final stream chunk: %w", err) - } - } - // Add the final [DONE] message to indicate the end of the stream. - newBody = append(newBody, sseDataPrefix...) - newBody = append(newBody, sseDoneMessage...) - newBody = append(newBody, '\n', '\n') - } - tokenUsage = p.tokenUsage - return -} - -func (p *anthropicStreamParser) parseAndHandleEvent(eventBlock []byte) (*openai.ChatCompletionResponseChunk, error) { - var eventType []byte - var eventData []byte - - lines := bytes.SplitSeq(eventBlock, []byte("\n")) - for line := range lines { - if after, ok := bytes.CutPrefix(line, sseEventPrefix); ok { - eventType = bytes.TrimSpace(after) - } else if after, ok := bytes.CutPrefix(line, sseDataPrefix); ok { - // This handles JSON data that might be split across multiple 'data:' lines - // by concatenating them (Anthropic's format). - data := bytes.TrimSpace(after) - eventData = append(eventData, data...) 
- } - } - - if len(eventType) > 0 && len(eventData) > 0 { - return p.handleAnthropicStreamEvent(eventType, eventData) - } - - return nil, nil -} - -func (p *anthropicStreamParser) handleAnthropicStreamEvent(eventType []byte, data []byte) (*openai.ChatCompletionResponseChunk, error) { - switch string(eventType) { - case string(constant.ValueOf[constant.MessageStart]()): - var event anthropic.MessageStartEvent - if err := json.Unmarshal(data, &event); err != nil { - return nil, fmt.Errorf("unmarshal message_start: %w", err) - } - p.activeMessageID = event.Message.ID - p.created = openai.JSONUNIXTime(time.Now()) - u := event.Message.Usage - usage := metrics.ExtractTokenUsageFromExplicitCaching( - u.InputTokens, - u.OutputTokens, - &u.CacheReadInputTokens, - &u.CacheCreationInputTokens, - ) - // For message_start, we store the initial usage but don't add to the accumulated - // The message_delta event will contain the final totals - if input, ok := usage.InputTokens(); ok { - p.tokenUsage.SetInputTokens(input) - } - if cached, ok := usage.CachedInputTokens(); ok { - p.tokenUsage.SetCachedInputTokens(cached) - } - if cacheCreation, ok := usage.CacheCreationInputTokens(); ok { - p.tokenUsage.SetCacheCreationInputTokens(cacheCreation) - } - - // reset the toolIndex for each message - p.toolIndex = -1 - return nil, nil - - case string(constant.ValueOf[constant.ContentBlockStart]()): - var event anthropic.ContentBlockStartEvent - if err := json.Unmarshal(data, &event); err != nil { - return nil, fmt.Errorf("failed to unmarshal content_block_start: %w", err) - } - if event.ContentBlock.Type == string(constant.ValueOf[constant.ToolUse]()) || event.ContentBlock.Type == string(constant.ValueOf[constant.ServerToolUse]()) { - p.toolIndex++ - var argsJSON string - // Check if the input field is provided directly in the start event. 
- if event.ContentBlock.Input != nil { - switch input := event.ContentBlock.Input.(type) { - case map[string]any: - // for case where "input":{}, skip adding it to arguments. - if len(input) > 0 { - argsBytes, err := json.Marshal(input) - if err != nil { - return nil, fmt.Errorf("failed to marshal tool use input: %w", err) - } - argsJSON = string(argsBytes) - } - default: - // although golang sdk defines type of Input to be any, - // python sdk requires the type of Input to be Dict[str, object]: - // https://github.com/anthropics/anthropic-sdk-python/blob/main/src/anthropic/types/tool_use_block.py#L14. - return nil, fmt.Errorf("unexpected tool use input type: %T", input) - } - } - - // Store the complete input JSON in our state. - p.activeToolCalls[p.toolIndex] = &streamingToolCall{ - id: event.ContentBlock.ID, - name: event.ContentBlock.Name, - inputJSON: argsJSON, - } - - delta := openai.ChatCompletionResponseChunkChoiceDelta{ - ToolCalls: []openai.ChatCompletionChunkChoiceDeltaToolCall{ - { - Index: p.toolIndex, - ID: &event.ContentBlock.ID, - Type: openai.ChatCompletionMessageToolCallTypeFunction, - Function: openai.ChatCompletionMessageToolCallFunctionParam{ - Name: event.ContentBlock.Name, - // Include the arguments if they are available. 
- Arguments: argsJSON, - }, - }, - }, - } - return p.constructOpenAIChatCompletionChunk(delta, ""), nil - } - // do not need to return an empty str for thinking start block - return nil, nil - - case string(constant.ValueOf[constant.MessageDelta]()): - var event anthropic.MessageDeltaEvent - if err := json.Unmarshal(data, &event); err != nil { - return nil, fmt.Errorf("unmarshal message_delta: %w", err) - } - u := event.Usage - usage := metrics.ExtractTokenUsageFromExplicitCaching( - u.InputTokens, - u.OutputTokens, - &u.CacheReadInputTokens, - &u.CacheCreationInputTokens, - ) - // For message_delta, accumulate the incremental output tokens - if output, ok := usage.OutputTokens(); ok { - p.tokenUsage.AddOutputTokens(output) - } - // Update input tokens to include read cache tokens from delta - if cached, ok := usage.CachedInputTokens(); ok { - p.tokenUsage.AddInputTokens(cached) - // Accumulate any additional cache tokens from delta - p.tokenUsage.AddCachedInputTokens(cached) - } - // Update input tokens to include write cache tokens from delta - if cached, ok := usage.CacheCreationInputTokens(); ok { - p.tokenUsage.AddInputTokens(cached) - // Accumulate any additional cache tokens from delta - p.tokenUsage.AddCacheCreationInputTokens(cached) - } - if event.Delta.StopReason != "" { - p.stopReason = event.Delta.StopReason - } - return nil, nil - - case string(constant.ValueOf[constant.ContentBlockDelta]()): - var event anthropic.ContentBlockDeltaEvent - if err := json.Unmarshal(data, &event); err != nil { - return nil, fmt.Errorf("unmarshal content_block_delta: %w", err) - } - switch event.Delta.Type { - case string(constant.ValueOf[constant.TextDelta]()): - delta := openai.ChatCompletionResponseChunkChoiceDelta{Content: &event.Delta.Text} - return p.constructOpenAIChatCompletionChunk(delta, ""), nil - - case string(constant.ValueOf[constant.ThinkingDelta]()): - // this should already include the case for redacted thinking: 
https://platform.claude.com/docs/en/build-with-claude/streaming#content-block-delta-types - - reasoningDelta := &openai.StreamReasoningContent{} - - // Map all relevant fields from the Bedrock delta to our flattened OpenAI delta struct. - if event.Delta.Thinking != "" { - reasoningDelta.Text = event.Delta.Thinking - } - if event.Delta.Signature != "" { - reasoningDelta.Signature = event.Delta.Signature - } - - delta := openai.ChatCompletionResponseChunkChoiceDelta{ - ReasoningContent: reasoningDelta, - } - return p.constructOpenAIChatCompletionChunk(delta, ""), nil - - case string(constant.ValueOf[constant.InputJSONDelta]()): - tool, ok := p.activeToolCalls[p.toolIndex] - if !ok { - return nil, fmt.Errorf("received input_json_delta for unknown tool at index %d", p.toolIndex) - } - delta := openai.ChatCompletionResponseChunkChoiceDelta{ - ToolCalls: []openai.ChatCompletionChunkChoiceDeltaToolCall{ - { - Index: p.toolIndex, - Function: openai.ChatCompletionMessageToolCallFunctionParam{ - Arguments: event.Delta.PartialJSON, - }, - }, - }, - } - tool.inputJSON += event.Delta.PartialJSON - return p.constructOpenAIChatCompletionChunk(delta, ""), nil - } - // Do not process redacted thinking stream? Did not find the source - - case string(constant.ValueOf[constant.ContentBlockStop]()): - // This event is for state cleanup, no chunk is sent. 
- var event anthropic.ContentBlockStopEvent - if err := json.Unmarshal(data, &event); err != nil { - return nil, fmt.Errorf("unmarshal content_block_stop: %w", err) - } - delete(p.activeToolCalls, p.toolIndex) - return nil, nil - - case string(constant.ValueOf[constant.MessageStop]()): - var event anthropic.MessageStopEvent - if err := json.Unmarshal(data, &event); err != nil { - return nil, fmt.Errorf("unmarshal message_stop: %w", err) - } - - if p.stopReason == "" { - p.stopReason = anthropic.StopReasonEndTurn - } - - finishReason, err := anthropicToOpenAIFinishReason(p.stopReason) - if err != nil { - return nil, err - } - return p.constructOpenAIChatCompletionChunk(openai.ChatCompletionResponseChunkChoiceDelta{}, finishReason), nil - - case string(constant.ValueOf[constant.Error]()): - var errEvent anthropic.ErrorResponse - if err := json.Unmarshal(data, &errEvent); err != nil { - return nil, fmt.Errorf("unparsable error event: %s", string(data)) - } - return nil, fmt.Errorf("anthropic stream error: %s - %s", errEvent.Error.Type, errEvent.Error.Message) - - case "ping": - // Per documentation, ping events can be ignored. - return nil, nil - } - return nil, nil -} - -// constructOpenAIChatCompletionChunk builds the stream chunk. -func (p *anthropicStreamParser) constructOpenAIChatCompletionChunk(delta openai.ChatCompletionResponseChunkChoiceDelta, finishReason openai.ChatCompletionChoicesFinishReason) *openai.ChatCompletionResponseChunk { - // Add the 'assistant' role to the very first chunk of the response. - if !p.sentFirstChunk { - // Only add the role if the delta actually contains content or a tool call. 
- if delta.Content != nil || len(delta.ToolCalls) > 0 { - delta.Role = openai.ChatMessageRoleAssistant - p.sentFirstChunk = true - } - } - - return &openai.ChatCompletionResponseChunk{ - ID: p.activeMessageID, - Created: p.created, - Object: "chat.completion.chunk", - Choices: []openai.ChatCompletionResponseChunkChoice{ - { - Delta: &delta, - FinishReason: finishReason, - }, - }, - Model: p.requestModel, - } -} diff --git a/internal/translator/openai_gcpanthropic_stream_test.go b/internal/translator/openai_gcpanthropic_stream_test.go deleted file mode 100644 index c10eafde7b..0000000000 --- a/internal/translator/openai_gcpanthropic_stream_test.go +++ /dev/null @@ -1,1031 +0,0 @@ -// Copyright Envoy AI Gateway Authors -// SPDX-License-Identifier: Apache-2.0 -// The full text of the Apache license is available in the LICENSE file at -// the root of the repo. - -package translator - -import ( - "fmt" - "strings" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "github.com/envoyproxy/ai-gateway/internal/apischema/openai" - "github.com/envoyproxy/ai-gateway/internal/json" - "github.com/envoyproxy/ai-gateway/internal/metrics" -) - -// mockErrorReader is a helper for testing io.Reader failures. 
-type mockErrorReader struct{} - -func (r *mockErrorReader) Read(_ []byte) (n int, err error) { - return 0, fmt.Errorf("mock reader error") -} - -func TestAnthropicStreamParser_ErrorHandling(t *testing.T) { - runStreamErrTest := func(t *testing.T, sseStream string, endOfStream bool) error { - openAIReq := &openai.ChatCompletionRequest{Stream: true, Model: "test-model", MaxTokens: new(int64)} - translator := NewChatCompletionOpenAIToGCPAnthropicTranslator("", "").(*openAIToGCPAnthropicTranslatorV1ChatCompletion) - _, _, err := translator.RequestBody(nil, openAIReq, false) - require.NoError(t, err) - - _, _, _, _, err = translator.ResponseBody(map[string]string{}, strings.NewReader(sseStream), endOfStream, nil) - return err - } - - tests := []struct { - name string - sseStream string - endOfStream bool - expectedError string - }{ - { - name: "malformed message_start event", - sseStream: "event: message_start\ndata: {invalid\n\n", - expectedError: "unmarshal message_start", - }, - { - name: "malformed content_block_start event", - sseStream: "event: content_block_start\ndata: {invalid\n\n", - expectedError: "failed to unmarshal content_block_start", - }, - { - name: "malformed content_block_delta event", - sseStream: "event: content_block_delta\ndata: {invalid\n\n", - expectedError: "unmarshal content_block_delta", - }, - { - name: "malformed content_block_stop event", - sseStream: "event: content_block_stop\ndata: {invalid\n\n", - expectedError: "unmarshal content_block_stop", - }, - { - name: "malformed error event data", - sseStream: "event: error\ndata: {invalid\n\n", - expectedError: "unparsable error event", - }, - { - name: "unknown stop reason", - endOfStream: true, - sseStream: `event: message_delta -data: {"type": "message_delta", "delta": {"stop_reason": "some_future_reason"}, "usage": {"output_tokens": 0}} - -event: message_stop -data: {"type": "message_stop"} -`, - expectedError: "received invalid stop reason", - }, - { - name: 
"malformed_final_event_block", - sseStream: "event: message_stop\ndata: {invalid", // No trailing \n\n. - endOfStream: true, - expectedError: "unmarshal message_stop", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - err := runStreamErrTest(t, tt.sseStream, tt.endOfStream) - require.Error(t, err) - require.Contains(t, err.Error(), tt.expectedError) - }) - } - - t.Run("body read error", func(t *testing.T) { - parser := newAnthropicStreamParser("test-model") - _, _, _, _, err := parser.Process(&mockErrorReader{}, false, nil) - require.Error(t, err) - require.Contains(t, err.Error(), "failed to read from stream body") - }) -} - -// TestResponseModel_GCPAnthropicStreaming tests that GCP Anthropic streaming returns the request model -// GCP Anthropic uses deterministic model mapping without virtualization -func TestResponseModel_GCPAnthropicStreaming(t *testing.T) { - modelName := "claude-sonnet-4@20250514" - sseStream := `event: message_start -data: {"type": "message_start", "message": {"id": "msg_1nZdL29xx5MUA1yADyHTEsnR8uuvGzszyY", "type": "message", "role": "assistant", "content": [], "model": "claude-sonnet-4@20250514", "stop_reason": null, "stop_sequence": null, "usage": {"input_tokens": 10, "output_tokens": 1}}} - -event: content_block_start -data: {"type": "content_block_start", "index": 0, "content_block": {"type": "text", "text": ""}} - -event: content_block_delta -data: {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": "Hello"}} - -event: content_block_stop -data: {"type": "content_block_stop", "index": 0} - -event: message_delta -data: {"type": "message_delta", "delta": {"stop_reason": "end_turn", "stop_sequence":null}, "usage": {"output_tokens": 5}} - -event: message_stop -data: {"type": "message_stop"} - -` - openAIReq := &openai.ChatCompletionRequest{ - Stream: true, - Model: modelName, // Use the actual model name from documentation - MaxTokens: new(int64), - } - - translator := 
NewChatCompletionOpenAIToGCPAnthropicTranslator("", "").(*openAIToGCPAnthropicTranslatorV1ChatCompletion) - _, _, err := translator.RequestBody(nil, openAIReq, false) - require.NoError(t, err) - - // Test streaming response - GCP Anthropic doesn't return model in response, uses request model - _, _, tokenUsage, responseModel, err := translator.ResponseBody(map[string]string{}, strings.NewReader(sseStream), true, nil) - require.NoError(t, err) - require.Equal(t, modelName, responseModel) // Returns the request model since no virtualization - inputTokens, ok := tokenUsage.InputTokens() - require.True(t, ok) - require.Equal(t, uint32(10), inputTokens) - outputTokens, ok := tokenUsage.OutputTokens() - require.True(t, ok) - require.Equal(t, uint32(5), outputTokens) -} - -func TestOpenAIToGCPAnthropicTranslatorV1ChatCompletion_ResponseBody_Streaming(t *testing.T) { - t.Run("handles simple text stream", func(t *testing.T) { - sseStream := ` -event: message_start -data: {"type": "message_start", "message": {"id": "msg_1nZdL29xx5MUA1yADyHTEsnR8uuvGzszyY", "type": "message", "role": "assistant", "content": [], "model": "claude-opus-4-20250514", "stop_reason": null, "stop_sequence": null, "usage": {"input_tokens": 25, "output_tokens": 1}}} - -event: content_block_start -data: {"type": "content_block_start", "index": 0, "content_block": {"type": "text", "text": ""}} - -event: ping -data: {"type": "ping"} - -event: content_block_delta -data: {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": "Hello"}} - -event: content_block_delta -data: {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": "!"}} - -event: content_block_stop -data: {"type": "content_block_stop", "index": 0} - -event: message_delta -data: {"type": "message_delta", "delta": {"stop_reason": "end_turn", "stop_sequence":null}, "usage": {"output_tokens": 15}} - -event: message_stop -data: {"type": "message_stop"} - -` - openAIReq := 
&openai.ChatCompletionRequest{ - Stream: true, - Model: "test-model", - MaxTokens: new(int64), - } - translator := NewChatCompletionOpenAIToGCPAnthropicTranslator("", "").(*openAIToGCPAnthropicTranslatorV1ChatCompletion) - _, _, err := translator.RequestBody(nil, openAIReq, false) - require.NoError(t, err) - - _, bm, _, _, err := translator.ResponseBody(map[string]string{}, strings.NewReader(sseStream), true, nil) - require.NoError(t, err) - require.NotNil(t, bm) - - bodyStr := string(bm) - require.Contains(t, bodyStr, `"content":"Hello"`) - require.Contains(t, bodyStr, `"finish_reason":"stop"`) - require.Contains(t, bodyStr, `"prompt_tokens":25`) - require.Contains(t, bodyStr, `"completion_tokens":15`) - require.Contains(t, bodyStr, string(sseDoneMessage)) - }) - - t.Run("handles text and tool use stream", func(t *testing.T) { - sseStream := `event: message_start -data: {"type":"message_start","message":{"id":"msg_014p7gG3wDgGV9EUtLvnow3U","type":"message","role":"assistant","model":"claude-opus-4-20250514","stop_sequence":null,"usage":{"input_tokens":472,"output_tokens":2},"content":[],"stop_reason":null}} - -event: content_block_start -data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}} - -event: ping -data: {"type": "ping"} - -event: content_block_delta -data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Okay"}} - -event: content_block_delta -data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":","}} - -event: content_block_delta -data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" let"}} - -event: content_block_delta -data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"'s"}} - -event: content_block_delta -data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" check"}} - -event: content_block_delta -data: 
{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" the"}} - -event: content_block_delta -data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" weather"}} - -event: content_block_delta -data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" for"}} - -event: content_block_delta -data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" San"}} - -event: content_block_delta -data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" Francisco"}} - -event: content_block_delta -data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":","}} - -event: content_block_delta -data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" CA"}} - -event: content_block_delta -data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":":"}} - -event: content_block_stop -data: {"type":"content_block_stop","index":0} - -event: content_block_start -data: {"type":"content_block_start","index":1,"content_block":{"type":"tool_use","id":"toolu_01T1x1fJ34qAmk2tNTrN7Up6","name":"get_weather","input":{}}} - -event: content_block_delta -data: {"type":"content_block_delta","index":1,"delta":{"type":"input_json_delta","partial_json":""}} - -event: content_block_delta -data: {"type":"content_block_delta","index":1,"delta":{"type":"input_json_delta","partial_json":"{\"location\":"}} - -event: content_block_delta -data: {"type":"content_block_delta","index":1,"delta":{"type":"input_json_delta","partial_json":" \"San"}} - -event: content_block_delta -data: {"type":"content_block_delta","index":1,"delta":{"type":"input_json_delta","partial_json":" Francisc"}} - -event: content_block_delta -data: {"type":"content_block_delta","index":1,"delta":{"type":"input_json_delta","partial_json":"o,"}} - -event: content_block_delta -data: 
{"type":"content_block_delta","index":1,"delta":{"type":"input_json_delta","partial_json":" CA\""}} - -event: content_block_delta -data: {"type":"content_block_delta","index":1,"delta":{"type":"input_json_delta","partial_json":", "}} - -event: content_block_delta -data: {"type":"content_block_delta","index":1,"delta":{"type":"input_json_delta","partial_json":"\"unit\": \"fah"}} - -event: content_block_delta -data: {"type":"content_block_delta","index":1,"delta":{"type":"input_json_delta","partial_json":"renheit\"}"}} - -event: content_block_stop -data: {"type":"content_block_stop","index":1} - -event: message_delta -data: {"type":"message_delta","delta":{"stop_reason":"tool_use","stop_sequence":null},"usage":{"output_tokens":89}} - -event: message_stop -data: {"type":"message_stop"} -` - - openAIReq := &openai.ChatCompletionRequest{Stream: true, Model: "test-model", MaxTokens: new(int64)} - translator := NewChatCompletionOpenAIToGCPAnthropicTranslator("", "").(*openAIToGCPAnthropicTranslatorV1ChatCompletion) - _, _, err := translator.RequestBody(nil, openAIReq, false) - require.NoError(t, err) - - _, bm, _, _, err := translator.ResponseBody(map[string]string{}, strings.NewReader(sseStream), true, nil) - require.NoError(t, err) - require.NotNil(t, bm) - bodyStr := string(bm) - - // Parse all streaming events to verify the event flow - var chunks []openai.ChatCompletionResponseChunk - var textChunks []string - var toolCallStarted bool - var hasRole bool - var toolCallCompleted bool - var finalFinishReason openai.ChatCompletionChoicesFinishReason - var finalUsageChunk *openai.ChatCompletionResponseChunk - var toolCallChunks []string // Track partial JSON chunks - - lines := strings.SplitSeq(strings.TrimSpace(bodyStr), "\n\n") - for line := range lines { - if !strings.HasPrefix(line, "data: ") || strings.Contains(line, "[DONE]") { - continue - } - jsonBody := strings.TrimPrefix(line, "data: ") - - var chunk openai.ChatCompletionResponseChunk - err = 
json.Unmarshal([]byte(jsonBody), &chunk) - require.NoError(t, err, "Failed to unmarshal chunk: %s", jsonBody) - chunks = append(chunks, chunk) - - // Check if this is the final usage chunk - if strings.Contains(jsonBody, `"usage"`) { - finalUsageChunk = &chunk - } - - if len(chunk.Choices) > 0 { - choice := chunk.Choices[0] - // Check for role in first content chunk - if choice.Delta != nil && choice.Delta.Content != nil && *choice.Delta.Content != "" && !hasRole { - require.NotNil(t, choice.Delta.Role, "Role should be present on first content chunk") - require.Equal(t, openai.ChatMessageRoleAssistant, choice.Delta.Role) - hasRole = true - } - - // Collect text content - if choice.Delta != nil && choice.Delta.Content != nil { - textChunks = append(textChunks, *choice.Delta.Content) - } - - // Check tool calls - start and accumulate partial JSON - if choice.Delta != nil && len(choice.Delta.ToolCalls) > 0 { - toolCall := choice.Delta.ToolCalls[0] - - // Check tool call initiation - if toolCall.Function.Name == "get_weather" && !toolCallStarted { - require.Equal(t, "get_weather", toolCall.Function.Name) - require.NotNil(t, toolCall.ID) - require.Equal(t, "toolu_01T1x1fJ34qAmk2tNTrN7Up6", *toolCall.ID) - require.Equal(t, int64(0), toolCall.Index, "Tool call should be at index 1 (after text content at index 0)") - toolCallStarted = true - } - - // Accumulate partial JSON arguments - these should also be at index 1 - if toolCall.Function.Arguments != "" { - toolCallChunks = append(toolCallChunks, toolCall.Function.Arguments) - - // Verify the index remains consistent at 1 for all tool call chunks - require.Equal(t, int64(0), toolCall.Index, "Tool call argument chunks should be at index 1") - } - } - - // Track finish reason - if choice.FinishReason != "" { - finalFinishReason = choice.FinishReason - if finalFinishReason == "tool_calls" { - toolCallCompleted = true - } - } - } - } - - // Check the final usage chunk for accumulated tool call arguments - if finalUsageChunk 
!= nil { - require.Equal(t, 472, finalUsageChunk.Usage.PromptTokens) - require.Equal(t, 89, finalUsageChunk.Usage.CompletionTokens) - } - - // Verify partial JSON accumulation in streaming chunks - if len(toolCallChunks) > 0 { - // Verify we got multiple partial JSON chunks during streaming - require.GreaterOrEqual(t, len(toolCallChunks), 2, "Should receive multiple partial JSON chunks for tool arguments") - - // Verify some expected partial content appears in the chunks - fullPartialJSON := strings.Join(toolCallChunks, "") - require.Contains(t, fullPartialJSON, `"location":`, "Partial JSON should contain location field") - require.Contains(t, fullPartialJSON, `"unit":`, "Partial JSON should contain unit field") - require.Contains(t, fullPartialJSON, "San Francisco", "Partial JSON should contain location value") - require.Contains(t, fullPartialJSON, "fahrenheit", "Partial JSON should contain unit value") - } - - // Verify streaming event assertions - require.GreaterOrEqual(t, len(chunks), 5, "Should have multiple streaming chunks") - require.True(t, hasRole, "Should have role in first content chunk") - require.True(t, toolCallStarted, "Tool call should have been initiated") - require.True(t, toolCallCompleted, "Tool call should have complete arguments in final chunk") - require.Equal(t, openai.ChatCompletionChoicesFinishReasonToolCalls, finalFinishReason, "Final finish reason should be tool_calls") - - // Verify text content was streamed correctly - fullText := strings.Join(textChunks, "") - require.Contains(t, fullText, "Okay, let's check the weather for San Francisco, CA:") - require.GreaterOrEqual(t, len(textChunks), 3, "Text should be streamed in multiple chunks") - - // Original aggregate response assertions - require.Contains(t, bodyStr, `"content":"Okay"`) - require.Contains(t, bodyStr, `"name":"get_weather"`) - require.Contains(t, bodyStr, "\"arguments\":\"{\\\"location\\\":") - require.NotContains(t, bodyStr, "\"arguments\":\"{}\"") - require.Contains(t, 
bodyStr, "renheit\\\"}\"") - require.Contains(t, bodyStr, `"finish_reason":"tool_calls"`) - require.Contains(t, bodyStr, string(sseDoneMessage)) - }) - - t.Run("handles streaming with web search tool use", func(t *testing.T) { - sseStream := `event: message_start -data: {"type":"message_start","message":{"id":"msg_01G...","type":"message","role":"assistant","usage":{"input_tokens":2679,"output_tokens":3}}} - -event: content_block_start -data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}} - -event: content_block_delta -data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"I'll check"}} - -event: content_block_delta -data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" the current weather in New York City for you"}} - -event: ping -data: {"type": "ping"} - -event: content_block_delta -data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"."}} - -event: content_block_stop -data: {"type":"content_block_stop","index":0} - -event: content_block_start -data: {"type":"content_block_start","index":1,"content_block":{"type":"server_tool_use","id":"srvtoolu_014hJH82Qum7Td6UV8gDXThB","name":"web_search","input":{}}} - -event: content_block_delta -data: {"type":"content_block_delta","index":1,"delta":{"type":"input_json_delta","partial_json":"{\"query\":\"weather NYC today\"}"}} - -event: content_block_stop -data: {"type":"content_block_stop","index":1} - -event: content_block_start -data: {"type":"content_block_start","index":2,"content_block":{"type":"web_search_tool_result","tool_use_id":"srvtoolu_014hJH82Qum7Td6UV8gDXThB","content":[{"type":"web_search_result","title":"Weather in New York City in May 2025 (New York)","url":"https://world-weather.info/forecast/usa/new_york/may-2025/","page_age":null}]}} - -event: content_block_stop -data: {"type":"content_block_stop","index":2} - -event: content_block_start -data: 
{"type":"content_block_start","index":3,"content_block":{"type":"text","text":""}} - -event: content_block_delta -data: {"type":"content_block_delta","index":3,"delta":{"type":"text_delta","text":"Here's the current weather information for New York"}} - -event: content_block_delta -data: {"type":"content_block_delta","index":3,"delta":{"type":"text_delta","text":" City."}} - -event: message_delta -data: {"type":"message_delta","delta":{"stop_reason":"end_turn"},"usage":{"output_tokens":510}} - -event: message_stop -data: {"type":"message_stop"} -` - openAIReq := &openai.ChatCompletionRequest{Stream: true, Model: "test-model", MaxTokens: new(int64)} - translator := NewChatCompletionOpenAIToGCPAnthropicTranslator("", "").(*openAIToGCPAnthropicTranslatorV1ChatCompletion) - _, _, err := translator.RequestBody(nil, openAIReq, false) - require.NoError(t, err) - - _, bm, _, _, err := translator.ResponseBody(map[string]string{}, strings.NewReader(sseStream), true, nil) - require.NoError(t, err) - require.NotNil(t, bm) - bodyStr := string(bm) - - require.Contains(t, bodyStr, `"content":"I'll check"`) - require.Contains(t, bodyStr, `"content":" the current weather in New York City for you"`) - require.Contains(t, bodyStr, `"name":"web_search"`) - require.Contains(t, bodyStr, "\"arguments\":\"{\\\"query\\\":\\\"weather NYC today\\\"}\"") - require.NotContains(t, bodyStr, "\"arguments\":\"{}\"") - require.Contains(t, bodyStr, `"content":"Here's the current weather information for New York"`) - require.Contains(t, bodyStr, `"finish_reason":"stop"`) - require.Contains(t, bodyStr, string(sseDoneMessage)) - }) - - t.Run("handles unterminated tool call at end of stream", func(t *testing.T) { - // This stream starts a tool call but ends without a content_block_stop or message_stop. 
- sseStream := `event: message_start -data: {"type":"message_start","message":{"id":"msg_123","usage":{"input_tokens":10}}} - -event: content_block_start -data: {"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"tool_abc","name":"get_weather"}} - -event: content_block_delta -data: {"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":"{\"location\": \"SF\"}"}} -` - openAIReq := &openai.ChatCompletionRequest{Stream: true, Model: "test-model", MaxTokens: new(int64)} - translator := NewChatCompletionOpenAIToGCPAnthropicTranslator("", "").(*openAIToGCPAnthropicTranslatorV1ChatCompletion) - _, _, err := translator.RequestBody(nil, openAIReq, false) - require.NoError(t, err) - - _, bm, _, _, err := translator.ResponseBody(map[string]string{}, strings.NewReader(sseStream), true, nil) - require.NoError(t, err) - require.NotNil(t, bm) - bodyStr := string(bm) - - var finalToolCallChunk openai.ChatCompletionResponseChunk - - // Split the response into individual SSE messages and find the final data chunk. - lines := strings.SplitSeq(strings.TrimSpace(bodyStr), "\n\n") - for line := range lines { - if !strings.HasPrefix(line, "data: ") || strings.HasPrefix(line, "data: [DONE]") { - continue - } - jsonBody := strings.TrimPrefix(line, "data: ") - // The final chunk with the accumulated tool call is the only one with a "usage" field. 
- if strings.Contains(jsonBody, `"usage"`) { - err := json.Unmarshal([]byte(jsonBody), &finalToolCallChunk) - require.NoError(t, err, "Failed to unmarshal final tool call chunk") - break - } - } - - require.NotEmpty(t, finalToolCallChunk.Choices, "Final chunk should have choices") - require.NotNil(t, finalToolCallChunk.Choices[0].Delta.ToolCalls, "Final chunk should have tool calls") - - finalToolCall := finalToolCallChunk.Choices[0].Delta.ToolCalls[0] - require.Equal(t, "tool_abc", *finalToolCall.ID) - require.Equal(t, "get_weather", finalToolCall.Function.Name) - require.JSONEq(t, `{"location": "SF"}`, finalToolCall.Function.Arguments) - }) - t.Run("handles thinking and tool use stream", func(t *testing.T) { - sseStream := ` -event: message_start -data: {"type": "message_start", "message": {"id": "msg_123", "type": "message", "role": "assistant", "usage": {"input_tokens": 50, "output_tokens": 1}}} - -event: content_block_start -data: {"type": "content_block_start", "index": 0, "content_block": {"type": "thinking", "name": "web_searcher"}} - -event: content_block_delta -data: {"type": "content_block_delta", "index": 0, "delta": {"type": "thinking_delta", "thinking": "Searching for information..."}} - -event: content_block_stop -data: {"type": "content_block_stop", "index": 0} - -event: content_block_start -data: {"type": "content_block_start", "index": 1, "content_block": {"type": "tool_use", "id": "toolu_abc123", "name": "get_weather", "input": {"location": "San Francisco, CA"}}} - -event: message_delta -data: {"type": "message_delta", "delta": {"stop_reason": "tool_use"}, "usage": {"output_tokens": 35}} - -event: message_stop -data: {"type": "message_stop"} -` - openAIReq := &openai.ChatCompletionRequest{Stream: true, Model: "test-model", MaxTokens: new(int64)} - translator := NewChatCompletionOpenAIToGCPAnthropicTranslator("", "").(*openAIToGCPAnthropicTranslatorV1ChatCompletion) - _, _, err := translator.RequestBody(nil, openAIReq, false) - require.NoError(t, 
err) - - _, bm, _, _, err := translator.ResponseBody(map[string]string{}, strings.NewReader(sseStream), true, nil) - require.NoError(t, err) - require.NotNil(t, bm) - bodyStr := string(bm) - - var contentDeltas []string - var reasoningTexts []string - var foundToolCallWithArgs bool - var finalFinishReason openai.ChatCompletionChoicesFinishReason - - lines := strings.SplitSeq(strings.TrimSpace(bodyStr), "\n\n") - for line := range lines { - if !strings.HasPrefix(line, "data: ") || strings.Contains(line, "[DONE]") { - continue - } - jsonBody := strings.TrimPrefix(line, "data: ") - - var chunk openai.ChatCompletionResponseChunk - err = json.Unmarshal([]byte(jsonBody), &chunk) - require.NoError(t, err, "Failed to unmarshal chunk: %s", jsonBody) - - if len(chunk.Choices) == 0 { - continue - } - choice := chunk.Choices[0] - if choice.Delta != nil { - if choice.Delta.Content != nil { - contentDeltas = append(contentDeltas, *choice.Delta.Content) - } - if choice.Delta.ReasoningContent != nil { - if choice.Delta.ReasoningContent.Text != "" { - reasoningTexts = append(reasoningTexts, choice.Delta.ReasoningContent.Text) - } - } - if len(choice.Delta.ToolCalls) > 0 { - toolCall := choice.Delta.ToolCalls[0] - // Check if this is the tool chunk that contains the arguments. 
- if toolCall.Function.Arguments != "" { - expectedArgs := `{"location":"San Francisco, CA"}` - assert.JSONEq(t, expectedArgs, toolCall.Function.Arguments, "Tool call arguments do not match") - assert.Equal(t, "get_weather", toolCall.Function.Name) - assert.Equal(t, "toolu_abc123", *toolCall.ID) - foundToolCallWithArgs = true - } else { - // This should be the initial tool call chunk with empty arguments since input is provided upfront - assert.Equal(t, "get_weather", toolCall.Function.Name) - assert.Equal(t, "toolu_abc123", *toolCall.ID) - } - } - } - if choice.FinishReason != "" { - finalFinishReason = choice.FinishReason - } - } - - fullReasoning := strings.Join(reasoningTexts, "") - - assert.Contains(t, fullReasoning, "Searching for information...") - require.True(t, foundToolCallWithArgs, "Did not find a tool call chunk with arguments to assert against") - assert.Equal(t, openai.ChatCompletionChoicesFinishReasonToolCalls, finalFinishReason, "Final finish reason should be 'tool_calls'") - }) - - t.Run("handles thinking delta stream with text only", func(t *testing.T) { - sseStream := ` -event: message_start -data: {"type": "message_start", "message": {"id": "msg_thinking_1", "type": "message", "role": "assistant", "usage": {"input_tokens": 20, "output_tokens": 1}}} - -event: content_block_start -data: {"type": "content_block_start", "index": 0, "content_block": {"type": "thinking"}} - -event: content_block_delta -data: {"type": "content_block_delta", "index": 0, "delta": {"type": "thinking_delta", "thinking": "Let me think about this problem step by step."}} - -event: content_block_delta -data: {"type": "content_block_delta", "index": 0, "delta": {"type": "thinking_delta", "thinking": " First, I need to understand the requirements."}} - -event: content_block_stop -data: {"type": "content_block_stop", "index": 0} - -event: message_delta -data: {"type": "message_delta", "delta": {"stop_reason": "end_turn"}, "usage": {"output_tokens": 15}} - -event: message_stop 
-data: {"type": "message_stop"} -` - openAIReq := &openai.ChatCompletionRequest{Stream: true, Model: "test-model", MaxTokens: new(int64)} - translator := NewChatCompletionOpenAIToGCPAnthropicTranslator("", "").(*openAIToGCPAnthropicTranslatorV1ChatCompletion) - _, _, err := translator.RequestBody(nil, openAIReq, false) - require.NoError(t, err) - - _, bm, _, _, err := translator.ResponseBody(map[string]string{}, strings.NewReader(sseStream), true, nil) - require.NoError(t, err) - require.NotNil(t, bm) - bodyStr := string(bm) - - var reasoningTexts []string - var foundFinishReason bool - - lines := strings.SplitSeq(strings.TrimSpace(bodyStr), "\n\n") - for line := range lines { - if !strings.HasPrefix(line, "data: ") || strings.Contains(line, "[DONE]") { - continue - } - jsonBody := strings.TrimPrefix(line, "data: ") - - var chunk openai.ChatCompletionResponseChunk - err = json.Unmarshal([]byte(jsonBody), &chunk) - require.NoError(t, err, "Failed to unmarshal chunk: %s", jsonBody) - - if len(chunk.Choices) == 0 { - continue - } - choice := chunk.Choices[0] - if choice.Delta != nil && choice.Delta.ReasoningContent != nil { - if choice.Delta.ReasoningContent.Text != "" { - reasoningTexts = append(reasoningTexts, choice.Delta.ReasoningContent.Text) - } - } - if choice.FinishReason == openai.ChatCompletionChoicesFinishReasonStop { - foundFinishReason = true - } - } - - fullReasoning := strings.Join(reasoningTexts, "") - assert.Contains(t, fullReasoning, "Let me think about this problem step by step.") - assert.Contains(t, fullReasoning, " First, I need to understand the requirements.") - require.True(t, foundFinishReason, "Should find stop finish reason") - }) - - t.Run("handles thinking delta stream with text and signature", func(t *testing.T) { - sseStream := ` -event: message_start -data: {"type": "message_start", "message": {"id": "msg_thinking_2", "type": "message", "role": "assistant", "usage": {"input_tokens": 25, "output_tokens": 1}}} - -event: 
content_block_start -data: {"type": "content_block_start", "index": 0, "content_block": {"type": "thinking"}} - -event: content_block_delta -data: {"type": "content_block_delta", "index": 0, "delta": {"type": "thinking_delta", "thinking": "Processing request...", "signature": "sig_abc123"}} - -event: content_block_delta -data: {"type": "content_block_delta", "index": 0, "delta": {"type": "thinking_delta", "thinking": " Analyzing data...", "signature": "sig_def456"}} - -event: content_block_stop -data: {"type": "content_block_stop", "index": 0} - -event: message_delta -data: {"type": "message_delta", "delta": {"stop_reason": "end_turn"}, "usage": {"output_tokens": 20}} - -event: message_stop -data: {"type": "message_stop"} -` - openAIReq := &openai.ChatCompletionRequest{Stream: true, Model: "test-model", MaxTokens: new(int64)} - translator := NewChatCompletionOpenAIToGCPAnthropicTranslator("", "").(*openAIToGCPAnthropicTranslatorV1ChatCompletion) - _, _, err := translator.RequestBody(nil, openAIReq, false) - require.NoError(t, err) - - _, bm, _, _, err := translator.ResponseBody(map[string]string{}, strings.NewReader(sseStream), true, nil) - require.NoError(t, err) - require.NotNil(t, bm) - bodyStr := string(bm) - - var reasoningTexts []string - var signatures []string - var foundFinishReason bool - - lines := strings.SplitSeq(strings.TrimSpace(bodyStr), "\n\n") - for line := range lines { - if !strings.HasPrefix(line, "data: ") || strings.Contains(line, "[DONE]") { - continue - } - jsonBody := strings.TrimPrefix(line, "data: ") - - var chunk openai.ChatCompletionResponseChunk - err = json.Unmarshal([]byte(jsonBody), &chunk) - require.NoError(t, err, "Failed to unmarshal chunk: %s", jsonBody) - - if len(chunk.Choices) == 0 { - continue - } - choice := chunk.Choices[0] - if choice.Delta != nil && choice.Delta.ReasoningContent != nil { - if choice.Delta.ReasoningContent.Text != "" { - reasoningTexts = append(reasoningTexts, choice.Delta.ReasoningContent.Text) - } - if 
choice.Delta.ReasoningContent.Signature != "" { - signatures = append(signatures, choice.Delta.ReasoningContent.Signature) - } - } - if choice.FinishReason == openai.ChatCompletionChoicesFinishReasonStop { - foundFinishReason = true - } - } - - fullReasoning := strings.Join(reasoningTexts, "") - assert.Contains(t, fullReasoning, "Processing request...") - assert.Contains(t, fullReasoning, " Analyzing data...") - - allSignatures := strings.Join(signatures, ",") - assert.Contains(t, allSignatures, "sig_abc123") - assert.Contains(t, allSignatures, "sig_def456") - - require.True(t, foundFinishReason, "Should find stop finish reason") - }) -} - -func TestAnthropicStreamParser_EventTypes(t *testing.T) { - runStreamTest := func(t *testing.T, sseStream string, endOfStream bool) ([]byte, metrics.TokenUsage, error) { - openAIReq := &openai.ChatCompletionRequest{Stream: true, Model: "test-model", MaxTokens: new(int64)} - translator := NewChatCompletionOpenAIToGCPAnthropicTranslator("", "").(*openAIToGCPAnthropicTranslatorV1ChatCompletion) - _, _, err := translator.RequestBody(nil, openAIReq, false) - require.NoError(t, err) - - _, bm, tokenUsage, _, err := translator.ResponseBody(map[string]string{}, strings.NewReader(sseStream), endOfStream, nil) - return bm, tokenUsage, err - } - - t.Run("handles message_start event", func(t *testing.T) { - sseStream := `event: message_start -data: {"type": "message_start", "message": {"id": "msg_123", "usage": {"input_tokens": 15}}} - -` - bm, _, err := runStreamTest(t, sseStream, false) - require.NoError(t, err) - assert.Empty(t, string(bm), "message_start should produce an empty chunk") - }) - - t.Run("handles content_block events for tool use", func(t *testing.T) { - sseStream := `event: message_start -data: {"type":"message_start","message":{"id":"msg_123","usage":{"input_tokens":10}}} - -event: content_block_start -data: {"type": "content_block_start", "index": 0, "content_block": {"type": "tool_use", "id": "tool_abc", "name": 
"get_weather", "input":{}}} - -event: content_block_delta -data: {"type": "content_block_delta", "index": 0, "delta": {"type": "input_json_delta", "partial_json": "{\"location\": \"SF\"}"}} - -event: content_block_stop -data: {"type": "content_block_stop", "index": 0} - -` - bm, _, err := runStreamTest(t, sseStream, false) - require.NoError(t, err) - require.NotNil(t, bm) - bodyStr := string(bm) - - // 1. Split the stream into individual data chunks - // and remove the "data: " prefix. - var chunks []openai.ChatCompletionResponseChunk - lines := strings.SplitSeq(strings.TrimSpace(bodyStr), "\n\n") - for line := range lines { - if !strings.HasPrefix(line, "data: ") { - continue - } - jsonBody := strings.TrimPrefix(line, "data: ") - - var chunk openai.ChatCompletionResponseChunk - err = json.Unmarshal([]byte(jsonBody), &chunk) - require.NoError(t, err, "Failed to unmarshal chunk: %s", jsonBody) - chunks = append(chunks, chunk) - } - - // 2. Inspect the Go structs directly. - require.Len(t, chunks, 2, "Expected two data chunks for this tool call stream") - - // Check the first chunk (the tool call initiation). - firstChunk := chunks[0] - require.NotNil(t, firstChunk.Choices[0].Delta.ToolCalls) - require.Equal(t, "tool_abc", *firstChunk.Choices[0].Delta.ToolCalls[0].ID) - require.Equal(t, "get_weather", firstChunk.Choices[0].Delta.ToolCalls[0].Function.Name) - // With empty input, arguments should be empty string, not "{}" - require.Empty(t, firstChunk.Choices[0].Delta.ToolCalls[0].Function.Arguments) - - // Check the second chunk (the arguments delta). - secondChunk := chunks[1] - require.NotNil(t, secondChunk.Choices[0].Delta.ToolCalls) - argumentsJSON := secondChunk.Choices[0].Delta.ToolCalls[0].Function.Arguments - - // 3. Unmarshal the arguments string to verify its contents. 
- var args map[string]string - err = json.Unmarshal([]byte(argumentsJSON), &args) - require.NoError(t, err) - require.Equal(t, "SF", args["location"]) - }) - - t.Run("handles ping event", func(t *testing.T) { - sseStream := `event: ping -data: {"type": "ping"} - -` - bm, _, err := runStreamTest(t, sseStream, false) - require.NoError(t, err) - require.Empty(t, bm, "ping should produce an empty chunk") - }) - - t.Run("handles error event", func(t *testing.T) { - sseStream := `event: error -data: {"type": "error", "error": {"type": "overloaded_error", "message": "Overloaded"}} - -` - _, _, err := runStreamTest(t, sseStream, false) - require.Error(t, err) - require.Contains(t, err.Error(), "anthropic stream error: overloaded_error - Overloaded") - }) - - t.Run("gracefully handles unknown event types", func(t *testing.T) { - sseStream := `event: future_event_type -data: {"some_new_data": "value"} - -` - bm, _, err := runStreamTest(t, sseStream, false) - require.NoError(t, err) - require.Empty(t, bm, "unknown events should be ignored and produce an empty chunk") - }) - - t.Run("handles message_stop event", func(t *testing.T) { - sseStream := `event: message_delta -data: {"type": "message_delta", "delta": {"stop_reason": "max_tokens"}, "usage": {"output_tokens": 1}} - -event: message_stop -data: {"type": "message_stop"} - -` - bm, _, err := runStreamTest(t, sseStream, false) - require.NoError(t, err) - require.NotNil(t, bm) - require.Contains(t, string(bm), `"finish_reason":"length"`) - }) - - t.Run("handles chunked input_json_delta for tool use", func(t *testing.T) { - sseStream := `event: message_start -data: {"type":"message_start","message":{"id":"msg_123","usage":{"input_tokens":10}}} - -event: content_block_start -data: {"type": "content_block_start", "index": 0, "content_block": {"type": "tool_use", "id": "tool_123", "name": "get_weather"}} - -event: content_block_delta -data: {"type": "content_block_delta","index": 0,"delta": {"type": 
"input_json_delta","partial_json": "{\"location\": \"San Fra"}} - -event: content_block_delta -data: {"type": "content_block_delta","index": 0,"delta": {"type": "input_json_delta","partial_json": "ncisco\"}"}} - -event: content_block_stop -data: {"type": "content_block_stop", "index": 0} -` - bm, _, err := runStreamTest(t, sseStream, false) - require.NoError(t, err) - require.NotNil(t, bm) - bodyStr := string(bm) - - // 1. Unmarshal all the chunks from the stream response. - var chunks []openai.ChatCompletionResponseChunk - lines := strings.SplitSeq(strings.TrimSpace(bodyStr), "\n\n") - for line := range lines { - if !strings.HasPrefix(line, "data: ") { - continue - } - jsonBody := strings.TrimPrefix(line, "data: ") - - var chunk openai.ChatCompletionResponseChunk - err := json.Unmarshal([]byte(jsonBody), &chunk) - require.NoError(t, err, "Failed to unmarshal chunk: %s", jsonBody) - chunks = append(chunks, chunk) - } - - // 2. We expect 3 chunks: start, delta part 1, delta part 2. - require.Len(t, chunks, 3, "Expected three data chunks for this stream") - - // 3. Verify the contents of each relevant chunk. - - // Chunk 1: Tool call start. - chunk1ToolCalls := chunks[0].Choices[0].Delta.ToolCalls - require.NotNil(t, chunk1ToolCalls) - require.Equal(t, "get_weather", chunk1ToolCalls[0].Function.Name) - - // Chunk 2: First part of the arguments. - chunk2Args := chunks[1].Choices[0].Delta.ToolCalls[0].Function.Arguments - require.Equal(t, `{"location": "San Fra`, chunk2Args) //nolint:testifylint - - // Chunk 3: Second part of the arguments. 
- chunk3Args := chunks[2].Choices[0].Delta.ToolCalls[0].Function.Arguments - require.Equal(t, `ncisco"}`, chunk3Args) - }) - t.Run("sends role on first chunk", func(t *testing.T) { - sseStream := `event: message_start -data: {"type":"message_start","message":{"id":"msg_123","usage":{"input_tokens":10}}} - -event: content_block_delta -data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hello"}} -` - // Set endOfStream to true to ensure all events in the buffer are processed. - bm, _, err := runStreamTest(t, sseStream, true) - require.NoError(t, err) - require.NotNil(t, bm) - bodyStr := string(bm) - - var contentChunk openai.ChatCompletionResponseChunk - foundChunk := false - - lines := strings.SplitSeq(strings.TrimSpace(bodyStr), "\n\n") - for line := range lines { - if after, ok := strings.CutPrefix(line, "data: "); ok { - jsonBody := after - // We only care about the chunk that has the text content. - if strings.Contains(jsonBody, `"content"`) { - err := json.Unmarshal([]byte(jsonBody), &contentChunk) - require.NoError(t, err, "Failed to unmarshal content chunk") - foundChunk = true - break - } - } - } - - require.True(t, foundChunk, "Did not find a data chunk with content in the output") - - require.NotNil(t, contentChunk.Choices[0].Delta.Role, "Role should be present on the first chunk") - require.Equal(t, openai.ChatMessageRoleAssistant, contentChunk.Choices[0].Delta.Role) - }) - - t.Run("accumulates output tokens", func(t *testing.T) { - sseStream := `event: message_start -data: {"type":"message_start","message":{"id":"msg_123","usage":{"input_tokens":20}}} - -event: message_delta -data: {"type":"message_delta","delta":{},"usage":{"output_tokens":10}} - -event: message_delta -data: {"type":"message_delta","delta":{},"usage":{"output_tokens":5}} - -event: message_stop -data: {"type":"message_stop"} -` - // Run with endOfStream:true to get the final usage chunk. 
- bm, _, err := runStreamTest(t, sseStream, true) - require.NoError(t, err) - require.NotNil(t, bm) - bodyStr := string(bm) - - // The final usage chunk should sum the tokens from all message_delta events. - require.Contains(t, bodyStr, `"completion_tokens":15`) - require.Contains(t, bodyStr, `"prompt_tokens":20`) - require.Contains(t, bodyStr, `"total_tokens":35`) - }) - - t.Run("ignores SSE comments", func(t *testing.T) { - sseStream := `event: message_start -data: {"type":"message_start","message":{"id":"msg_123","usage":{"input_tokens":10}}} - -: this is a comment and should be ignored - -event: content_block_delta -data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hello"}} -` - bm, _, err := runStreamTest(t, sseStream, true) - require.NoError(t, err) - require.NotNil(t, bm) - bodyStr := string(bm) - - require.Contains(t, bodyStr, `"content":"Hello"`) - require.NotContains(t, bodyStr, "this is a comment") - }) - t.Run("handles data-only event as a message event", func(t *testing.T) { - sseStream := `data: some text - -data: another message with two lines -` - bm, _, err := runStreamTest(t, sseStream, false) - require.NoError(t, err) - require.Empty(t, bm, "data-only events should be treated as no-op 'message' events and produce an empty chunk") - }) -} diff --git a/internal/translator/openai_gcpanthropic_test.go b/internal/translator/openai_gcpanthropic_test.go index b631edfe23..9f968b2f25 100644 --- a/internal/translator/openai_gcpanthropic_test.go +++ b/internal/translator/openai_gcpanthropic_test.go @@ -10,6 +10,7 @@ import ( "encoding/base64" "fmt" "io" + "log/slog" "strconv" "testing" "time" @@ -1212,872 +1213,6 @@ func TestOpenAIToGCPAnthropicTranslatorV1ChatCompletion_ResponseError(t *testing } } -// New test function for helper coverage. 
-func TestHelperFunctions(t *testing.T) { - t.Run("anthropicToOpenAIFinishReason invalid reason", func(t *testing.T) { - _, err := anthropicToOpenAIFinishReason("unknown_reason") - require.Error(t, err) - require.Contains(t, err.Error(), "received invalid stop reason") - }) - - t.Run("anthropicRoleToOpenAIRole invalid role", func(t *testing.T) { - _, err := anthropicRoleToOpenAIRole("unknown_role") - require.Error(t, err) - require.Contains(t, err.Error(), "invalid anthropic role") - }) -} - -func TestTranslateOpenAItoAnthropicTools(t *testing.T) { - anthropicTestTool := []anthropic.ToolUnionParam{ - {OfTool: &anthropic.ToolParam{Name: "get_weather", Description: anthropic.String("")}}, - } - openaiTestTool := []openai.Tool{ - {Type: "function", Function: &openai.FunctionDefinition{Name: "get_weather"}}, - } - tests := []struct { - name string - openAIReq *openai.ChatCompletionRequest - expectedTools []anthropic.ToolUnionParam - expectedToolChoice anthropic.ToolChoiceUnionParam - expectErr bool - }{ - { - name: "auto tool choice", - openAIReq: &openai.ChatCompletionRequest{ - ToolChoice: &openai.ChatCompletionToolChoiceUnion{Value: "auto"}, - Tools: openaiTestTool, - }, - expectedTools: anthropicTestTool, - expectedToolChoice: anthropic.ToolChoiceUnionParam{ - OfAuto: &anthropic.ToolChoiceAutoParam{ - DisableParallelToolUse: anthropic.Bool(false), - }, - }, - }, - { - name: "any tool choice", - openAIReq: &openai.ChatCompletionRequest{ - ToolChoice: &openai.ChatCompletionToolChoiceUnion{Value: "any"}, - Tools: openaiTestTool, - }, - expectedTools: anthropicTestTool, - expectedToolChoice: anthropic.ToolChoiceUnionParam{ - OfAny: &anthropic.ToolChoiceAnyParam{}, - }, - }, - { - name: "specific tool choice by name", - openAIReq: &openai.ChatCompletionRequest{ - ToolChoice: &openai.ChatCompletionToolChoiceUnion{Value: openai.ChatCompletionNamedToolChoice{Type: "function", Function: openai.ChatCompletionNamedToolChoiceFunction{Name: "my_func"}}}, - Tools: 
openaiTestTool, - }, - expectedTools: anthropicTestTool, - expectedToolChoice: anthropic.ToolChoiceUnionParam{ - OfTool: &anthropic.ToolChoiceToolParam{Type: "tool", Name: "my_func"}, - }, - }, - { - name: "tool definition", - openAIReq: &openai.ChatCompletionRequest{ - Tools: []openai.Tool{ - { - Type: "function", - Function: &openai.FunctionDefinition{ - Name: "get_weather", - Description: "Get the weather", - Parameters: map[string]any{ - "type": "object", - "properties": map[string]any{ - "location": map[string]any{"type": "string"}, - }, - }, - }, - }, - }, - }, - expectedTools: []anthropic.ToolUnionParam{ - { - OfTool: &anthropic.ToolParam{ - Name: "get_weather", - Description: anthropic.String("Get the weather"), - InputSchema: anthropic.ToolInputSchemaParam{ - Type: "object", - Properties: map[string]any{ - "location": map[string]any{"type": "string"}, - }, - }, - }, - }, - }, - }, - { - name: "tool_definition_with_required_field", - openAIReq: &openai.ChatCompletionRequest{ - Tools: []openai.Tool{ - { - Type: "function", - Function: &openai.FunctionDefinition{ - Name: "get_weather", - Description: "Get the weather with a required location", - Parameters: map[string]any{ - "type": "object", - "properties": map[string]any{ - "location": map[string]any{"type": "string"}, - "unit": map[string]any{"type": "string"}, - }, - "required": []any{"location"}, - }, - }, - }, - }, - }, - expectedTools: []anthropic.ToolUnionParam{ - { - OfTool: &anthropic.ToolParam{ - Name: "get_weather", - Description: anthropic.String("Get the weather with a required location"), - InputSchema: anthropic.ToolInputSchemaParam{ - Type: "object", - Properties: map[string]any{ - "location": map[string]any{"type": "string"}, - "unit": map[string]any{"type": "string"}, - }, - Required: []string{"location"}, - }, - }, - }, - }, - }, - { - name: "tool definition with no parameters", - openAIReq: &openai.ChatCompletionRequest{ - Tools: []openai.Tool{ - { - Type: "function", - Function: 
&openai.FunctionDefinition{ - Name: "get_time", - Description: "Get the current time", - }, - }, - }, - }, - expectedTools: []anthropic.ToolUnionParam{ - { - OfTool: &anthropic.ToolParam{ - Name: "get_time", - Description: anthropic.String("Get the current time"), - }, - }, - }, - }, - { - name: "disable parallel tool calls", - openAIReq: &openai.ChatCompletionRequest{ - ToolChoice: &openai.ChatCompletionToolChoiceUnion{Value: "auto"}, - Tools: openaiTestTool, - ParallelToolCalls: ptr.To(false), - }, - expectedTools: anthropicTestTool, - expectedToolChoice: anthropic.ToolChoiceUnionParam{ - OfAuto: &anthropic.ToolChoiceAutoParam{ - DisableParallelToolUse: anthropic.Bool(true), - }, - }, - }, - { - name: "explicitly enable parallel tool calls", - openAIReq: &openai.ChatCompletionRequest{ - Tools: openaiTestTool, - ToolChoice: &openai.ChatCompletionToolChoiceUnion{Value: "auto"}, - ParallelToolCalls: ptr.To(true), - }, - expectedTools: anthropicTestTool, - expectedToolChoice: anthropic.ToolChoiceUnionParam{ - OfAuto: &anthropic.ToolChoiceAutoParam{DisableParallelToolUse: anthropic.Bool(false)}, - }, - }, - { - name: "default disable parallel tool calls to false (nil)", - openAIReq: &openai.ChatCompletionRequest{ - Tools: openaiTestTool, - ToolChoice: &openai.ChatCompletionToolChoiceUnion{Value: "auto"}, - }, - expectedTools: anthropicTestTool, - expectedToolChoice: anthropic.ToolChoiceUnionParam{ - OfAuto: &anthropic.ToolChoiceAutoParam{DisableParallelToolUse: anthropic.Bool(false)}, - }, - }, - { - name: "none tool choice", - openAIReq: &openai.ChatCompletionRequest{ - Tools: openaiTestTool, - ToolChoice: &openai.ChatCompletionToolChoiceUnion{Value: "none"}, - }, - expectedTools: anthropicTestTool, - expectedToolChoice: anthropic.ToolChoiceUnionParam{ - OfNone: &anthropic.ToolChoiceNoneParam{}, - }, - }, - { - name: "function tool choice", - openAIReq: &openai.ChatCompletionRequest{ - Tools: openaiTestTool, - ToolChoice: &openai.ChatCompletionToolChoiceUnion{Value: 
"function"}, - }, - expectedTools: anthropicTestTool, - expectedToolChoice: anthropic.ToolChoiceUnionParam{ - OfTool: &anthropic.ToolChoiceToolParam{Name: "function"}, - }, - }, - { - name: "invalid tool choice string", - openAIReq: &openai.ChatCompletionRequest{ - Tools: openaiTestTool, - ToolChoice: &openai.ChatCompletionToolChoiceUnion{Value: "invalid_choice"}, - }, - expectErr: true, - }, - { - name: "skips function tool with nil function definition", - openAIReq: &openai.ChatCompletionRequest{ - Tools: []openai.Tool{ - { - Type: "function", - Function: nil, // This tool has the correct type but a nil definition and should be skipped. - }, - { - Type: "function", - Function: &openai.FunctionDefinition{Name: "get_weather"}, // This is a valid tool. - }, - }, - }, - // We expect only the valid function tool to be translated. - expectedTools: []anthropic.ToolUnionParam{ - {OfTool: &anthropic.ToolParam{Name: "get_weather", Description: anthropic.String("")}}, - }, - expectErr: false, - }, - { - name: "skips non-function tools", - openAIReq: &openai.ChatCompletionRequest{ - Tools: []openai.Tool{ - { - Type: "retrieval", - }, - { - Type: "function", - Function: &openai.FunctionDefinition{Name: "get_weather"}, - }, - }, - }, - expectedTools: []anthropic.ToolUnionParam{ - {OfTool: &anthropic.ToolParam{Name: "get_weather", Description: anthropic.String("")}}, - }, - expectErr: false, - }, - { - name: "tool definition without type field", - openAIReq: &openai.ChatCompletionRequest{ - Tools: []openai.Tool{ - { - Type: "function", - Function: &openai.FunctionDefinition{ - Name: "get_weather", - Description: "Get the weather without type", - Parameters: map[string]any{ - "properties": map[string]any{ - "location": map[string]any{"type": "string"}, - }, - "required": []any{"location"}, - }, - }, - }, - }, - }, - expectedTools: []anthropic.ToolUnionParam{ - { - OfTool: &anthropic.ToolParam{ - Name: "get_weather", - Description: anthropic.String("Get the weather without 
type"), - InputSchema: anthropic.ToolInputSchemaParam{ - Type: "", - Properties: map[string]any{ - "location": map[string]any{"type": "string"}, - }, - Required: []string{"location"}, - }, - }, - }, - }, - }, - { - name: "tool definition without properties field", - openAIReq: &openai.ChatCompletionRequest{ - Tools: []openai.Tool{ - { - Type: "function", - Function: &openai.FunctionDefinition{ - Name: "get_weather", - Description: "Get the weather without properties", - Parameters: map[string]any{ - "type": "object", - "required": []any{"location"}, - }, - }, - }, - }, - }, - expectedTools: []anthropic.ToolUnionParam{ - { - OfTool: &anthropic.ToolParam{ - Name: "get_weather", - Description: anthropic.String("Get the weather without properties"), - InputSchema: anthropic.ToolInputSchemaParam{ - Type: "object", - Required: []string{"location"}, - }, - }, - }, - }, - }, - { - name: "unsupported tool_choice type", - openAIReq: &openai.ChatCompletionRequest{ - Tools: openaiTestTool, - ToolChoice: &openai.ChatCompletionToolChoiceUnion{Value: 123}, // Use an integer to trigger the default case. 
- }, - expectErr: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - tools, toolChoice, err := translateOpenAItoAnthropicTools(tt.openAIReq.Tools, tt.openAIReq.ToolChoice, tt.openAIReq.ParallelToolCalls) - if tt.expectErr { - require.Error(t, err) - } else { - require.NoError(t, err) - if tt.openAIReq.ToolChoice != nil { - require.NotNil(t, toolChoice) - require.Equal(t, *tt.expectedToolChoice.GetType(), *toolChoice.GetType()) - if tt.expectedToolChoice.GetName() != nil { - require.Equal(t, *tt.expectedToolChoice.GetName(), *toolChoice.GetName()) - } - if tt.expectedToolChoice.OfTool != nil { - require.Equal(t, tt.expectedToolChoice.OfTool.Name, toolChoice.OfTool.Name) - } - if tt.expectedToolChoice.OfAuto != nil { - require.Equal(t, tt.expectedToolChoice.OfAuto.DisableParallelToolUse, toolChoice.OfAuto.DisableParallelToolUse) - } - } - if tt.openAIReq.Tools != nil { - require.NotNil(t, tools) - require.Len(t, tools, len(tt.expectedTools)) - require.Equal(t, tt.expectedTools[0].GetName(), tools[0].GetName()) - require.Equal(t, tt.expectedTools[0].GetType(), tools[0].GetType()) - require.Equal(t, tt.expectedTools[0].GetDescription(), tools[0].GetDescription()) - if tt.expectedTools[0].GetInputSchema().Properties != nil { - require.Equal(t, tt.expectedTools[0].GetInputSchema().Properties, tools[0].GetInputSchema().Properties) - } - } - } - }) - } -} - -// TestFinishReasonTranslation covers specific cases for the anthropicToOpenAIFinishReason function. 
-func TestFinishReasonTranslation(t *testing.T) { - tests := []struct { - name string - input anthropic.StopReason - expectedFinishReason openai.ChatCompletionChoicesFinishReason - expectErr bool - }{ - { - name: "max tokens stop reason", - input: anthropic.StopReasonMaxTokens, - expectedFinishReason: openai.ChatCompletionChoicesFinishReasonLength, - }, - { - name: "refusal stop reason", - input: anthropic.StopReasonRefusal, - expectedFinishReason: openai.ChatCompletionChoicesFinishReasonContentFilter, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - reason, err := anthropicToOpenAIFinishReason(tt.input) - if tt.expectErr { - require.Error(t, err) - } else { - require.NoError(t, err) - require.Equal(t, tt.expectedFinishReason, reason) - } - }) - } -} - -// TestToolParameterDereferencing tests the JSON schema dereferencing functionality -// for tool parameters when translating from OpenAI to GCP Anthropic. -func TestToolParameterDereferencing(t *testing.T) { - tests := []struct { - name string - openAIReq *openai.ChatCompletionRequest - expectedTools []anthropic.ToolUnionParam - expectedToolChoice anthropic.ToolChoiceUnionParam - expectErr bool - expectUserFacingErr bool - }{ - { - name: "tool with complex nested $ref - successful dereferencing", - openAIReq: &openai.ChatCompletionRequest{ - Tools: []openai.Tool{ - { - Type: "function", - Function: &openai.FunctionDefinition{ - Name: "complex_tool", - Description: "Tool with complex nested references", - Parameters: map[string]any{ - "type": "object", - "$defs": map[string]any{ - "BaseType": map[string]any{ - "type": "object", - "properties": map[string]any{ - "id": map[string]any{ - "type": "string", - }, - "required": []any{"id"}, - }, - }, - "NestedType": map[string]any{ - "allOf": []any{ - map[string]any{"$ref": "#/$defs/BaseType"}, - map[string]any{ - "properties": map[string]any{ - "name": map[string]any{ - "type": "string", - }, - }, - }, - }, - }, - }, - "properties": 
map[string]any{ - "nested": map[string]any{ - "$ref": "#/$defs/NestedType", - }, - }, - }, - }, - }, - }, - }, - expectedTools: []anthropic.ToolUnionParam{ - { - OfTool: &anthropic.ToolParam{ - Name: "complex_tool", - Description: anthropic.String("Tool with complex nested references"), - InputSchema: anthropic.ToolInputSchemaParam{ - Type: "object", - Properties: map[string]any{ - "nested": map[string]any{ - "allOf": []any{ - map[string]any{ - "type": "object", - "properties": map[string]any{ - "id": map[string]any{ - "type": "string", - }, - "required": []any{"id"}, - }, - }, - map[string]any{ - "properties": map[string]any{ - "name": map[string]any{ - "type": "string", - }, - }, - }, - }, - }, - }, - }, - }, - }, - }, - }, - { - name: "tool with invalid $ref - dereferencing error", - openAIReq: &openai.ChatCompletionRequest{ - Tools: []openai.Tool{ - { - Type: "function", - Function: &openai.FunctionDefinition{ - Name: "invalid_ref_tool", - Description: "Tool with invalid reference", - Parameters: map[string]any{ - "type": "object", - "properties": map[string]any{ - "location": map[string]any{ - "$ref": "#/$defs/NonExistent", - }, - }, - }, - }, - }, - }, - }, - expectErr: true, - }, - { - name: "tool with circular $ref - dereferencing error", - openAIReq: &openai.ChatCompletionRequest{ - Tools: []openai.Tool{ - { - Type: "function", - Function: &openai.FunctionDefinition{ - Name: "circular_ref_tool", - Description: "Tool with circular reference", - Parameters: map[string]any{ - "type": "object", - "$defs": map[string]any{ - "A": map[string]any{ - "type": "object", - "properties": map[string]any{ - "b": map[string]any{ - "$ref": "#/$defs/B", - }, - }, - }, - "B": map[string]any{ - "type": "object", - "properties": map[string]any{ - "a": map[string]any{ - "$ref": "#/$defs/A", - }, - }, - }, - }, - "properties": map[string]any{ - "circular": map[string]any{ - "$ref": "#/$defs/A", - }, - }, - }, - }, - }, - }, - }, - expectErr: true, - }, - { - name: "tool without 
$ref - no dereferencing needed", - openAIReq: &openai.ChatCompletionRequest{ - Tools: []openai.Tool{ - { - Type: "function", - Function: &openai.FunctionDefinition{ - Name: "simple_tool", - Description: "Simple tool without references", - Parameters: map[string]any{ - "type": "object", - "properties": map[string]any{ - "location": map[string]any{ - "type": "string", - }, - }, - "required": []any{"location"}, - }, - }, - }, - }, - }, - expectedTools: []anthropic.ToolUnionParam{ - { - OfTool: &anthropic.ToolParam{ - Name: "simple_tool", - Description: anthropic.String("Simple tool without references"), - InputSchema: anthropic.ToolInputSchemaParam{ - Type: "object", - Properties: map[string]any{ - "location": map[string]any{ - "type": "string", - }, - }, - Required: []string{"location"}, - }, - }, - }, - }, - }, - { - name: "tool parameter dereferencing returns non-map type - casting error", - openAIReq: &openai.ChatCompletionRequest{ - Tools: []openai.Tool{ - { - Type: "function", - Function: &openai.FunctionDefinition{ - Name: "problematic_tool", - Description: "Tool with parameters that can't be properly dereferenced to map", - // This creates a scenario where jsonSchemaDereference might return a non-map type - // though this is a contrived example since normally the function should return map[string]any - Parameters: map[string]any{ - "$ref": "#/$defs/StringType", // This would resolve to a string, not a map - "$defs": map[string]any{ - "StringType": "not-a-map", // This would cause the casting to fail - }, - }, - }, - }, - }, - }, - expectErr: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - tools, toolChoice, err := translateOpenAItoAnthropicTools(tt.openAIReq.Tools, tt.openAIReq.ToolChoice, tt.openAIReq.ParallelToolCalls) - - if tt.expectErr { - require.Error(t, err) - if tt.expectUserFacingErr { - require.ErrorIs(t, err, internalapi.ErrInvalidRequestBody) - } - return - } - - require.NoError(t, err) - - if 
tt.openAIReq.Tools != nil { - require.NotNil(t, tools) - require.Len(t, tools, len(tt.expectedTools)) - - for i, expectedTool := range tt.expectedTools { - actualTool := tools[i] - require.Equal(t, expectedTool.GetName(), actualTool.GetName()) - require.Equal(t, expectedTool.GetType(), actualTool.GetType()) - require.Equal(t, expectedTool.GetDescription(), actualTool.GetDescription()) - - expectedSchema := expectedTool.GetInputSchema() - actualSchema := actualTool.GetInputSchema() - - require.Equal(t, expectedSchema.Type, actualSchema.Type) - require.Equal(t, expectedSchema.Required, actualSchema.Required) - - // For properties, we'll do a deep comparison to verify dereferencing worked - if expectedSchema.Properties != nil { - require.NotNil(t, actualSchema.Properties) - require.Equal(t, expectedSchema.Properties, actualSchema.Properties) - } - } - } - - if tt.openAIReq.ToolChoice != nil { - require.NotNil(t, toolChoice) - require.Equal(t, *tt.expectedToolChoice.GetType(), *toolChoice.GetType()) - } - }) - } -} - -// TestContentTranslationCoverage adds specific coverage for the openAIToAnthropicContent helper. 
-func TestContentTranslationCoverage(t *testing.T) { - tests := []struct { - name string - inputContent any - expectedContent []anthropic.ContentBlockParamUnion - expectErr bool - }{ - { - name: "nil content", - inputContent: nil, - }, - { - name: "empty string content", - inputContent: "", - }, - { - name: "pdf data uri", - inputContent: []openai.ChatCompletionContentPartUserUnionParam{ - {OfImageURL: &openai.ChatCompletionContentPartImageParam{ImageURL: openai.ChatCompletionContentPartImageImageURLParam{URL: "data:application/pdf;base64,dGVzdA=="}}}, - }, - expectedContent: []anthropic.ContentBlockParamUnion{ - { - OfDocument: &anthropic.DocumentBlockParam{ - Source: anthropic.DocumentBlockParamSourceUnion{ - OfBase64: &anthropic.Base64PDFSourceParam{ - Type: constant.ValueOf[constant.Base64](), - MediaType: constant.ValueOf[constant.ApplicationPDF](), - Data: "dGVzdA==", - }, - }, - }, - }, - }, - }, - { - name: "pdf url", - inputContent: []openai.ChatCompletionContentPartUserUnionParam{ - {OfImageURL: &openai.ChatCompletionContentPartImageParam{ImageURL: openai.ChatCompletionContentPartImageImageURLParam{URL: "https://example.com/doc.pdf"}}}, - }, - expectedContent: []anthropic.ContentBlockParamUnion{ - { - OfDocument: &anthropic.DocumentBlockParam{ - Source: anthropic.DocumentBlockParamSourceUnion{ - OfURL: &anthropic.URLPDFSourceParam{ - Type: constant.ValueOf[constant.URL](), - URL: "https://example.com/doc.pdf", - }, - }, - }, - }, - }, - }, - { - name: "image url", - inputContent: []openai.ChatCompletionContentPartUserUnionParam{ - {OfImageURL: &openai.ChatCompletionContentPartImageParam{ImageURL: openai.ChatCompletionContentPartImageImageURLParam{URL: "https://example.com/image.png"}}}, - }, - expectedContent: []anthropic.ContentBlockParamUnion{ - { - OfImage: &anthropic.ImageBlockParam{ - Source: anthropic.ImageBlockParamSourceUnion{ - OfURL: &anthropic.URLImageSourceParam{ - Type: constant.ValueOf[constant.URL](), - URL: "https://example.com/image.png", 
- }, - }, - }, - }, - }, - }, - { - name: "audio content error", - inputContent: []openai.ChatCompletionContentPartUserUnionParam{{OfInputAudio: &openai.ChatCompletionContentPartInputAudioParam{}}}, - expectErr: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - content, err := openAIToAnthropicContent(tt.inputContent) - if tt.expectErr { - require.Error(t, err) - return - } - require.NoError(t, err) - - // Use direct assertions instead of cmp.Diff to avoid panics on unexported fields. - require.Len(t, content, len(tt.expectedContent), "Number of content blocks should match") - - // Use direct assertions instead of cmp.Diff to avoid panics on unexported fields. - require.Len(t, content, len(tt.expectedContent), "Number of content blocks should match") - for i, expectedBlock := range tt.expectedContent { - actualBlock := content[i] - require.Equal(t, expectedBlock.GetType(), actualBlock.GetType(), "Content block types should match") - if expectedBlock.OfDocument != nil { - require.NotNil(t, actualBlock.OfDocument, "Expected a document block, but got nil") - require.NotNil(t, actualBlock.OfDocument.Source, "Document source should not be nil") - - if expectedBlock.OfDocument.Source.OfBase64 != nil { - require.NotNil(t, actualBlock.OfDocument.Source.OfBase64, "Expected a base64 source") - require.Equal(t, expectedBlock.OfDocument.Source.OfBase64.Data, actualBlock.OfDocument.Source.OfBase64.Data) - } - if expectedBlock.OfDocument.Source.OfURL != nil { - require.NotNil(t, actualBlock.OfDocument.Source.OfURL, "Expected a URL source") - require.Equal(t, expectedBlock.OfDocument.Source.OfURL.URL, actualBlock.OfDocument.Source.OfURL.URL) - } - } - if expectedBlock.OfImage != nil { - require.NotNil(t, actualBlock.OfImage, "Expected an image block, but got nil") - require.NotNil(t, actualBlock.OfImage.Source, "Image source should not be nil") - - if expectedBlock.OfImage.Source.OfURL != nil { - require.NotNil(t, 
actualBlock.OfImage.Source.OfURL, "Expected a URL image source") - require.Equal(t, expectedBlock.OfImage.Source.OfURL.URL, actualBlock.OfImage.Source.OfURL.URL) - } - } - } - - for i, expectedBlock := range tt.expectedContent { - actualBlock := content[i] - if expectedBlock.OfDocument != nil { - require.NotNil(t, actualBlock.OfDocument, "Expected a document block, but got nil") - require.NotNil(t, actualBlock.OfDocument.Source, "Document source should not be nil") - - if expectedBlock.OfDocument.Source.OfBase64 != nil { - require.NotNil(t, actualBlock.OfDocument.Source.OfBase64, "Expected a base64 source") - require.Equal(t, expectedBlock.OfDocument.Source.OfBase64.Data, actualBlock.OfDocument.Source.OfBase64.Data) - } - if expectedBlock.OfDocument.Source.OfURL != nil { - require.NotNil(t, actualBlock.OfDocument.Source.OfURL, "Expected a URL source") - require.Equal(t, expectedBlock.OfDocument.Source.OfURL.URL, actualBlock.OfDocument.Source.OfURL.URL) - } - } - if expectedBlock.OfImage != nil { - require.NotNil(t, actualBlock.OfImage, "Expected an image block, but got nil") - require.NotNil(t, actualBlock.OfImage.Source, "Image source should not be nil") - - if expectedBlock.OfImage.Source.OfURL != nil { - require.NotNil(t, actualBlock.OfImage.Source.OfURL, "Expected a URL image source") - require.Equal(t, expectedBlock.OfImage.Source.OfURL.URL, actualBlock.OfImage.Source.OfURL.URL) - } - } - } - }) - } -} - -// TestSystemPromptExtractionCoverage adds specific coverage for the extractSystemPromptFromDeveloperMsg helper. 
-func TestSystemPromptExtractionCoverage(t *testing.T) { - tests := []struct { - name string - inputMsg openai.ChatCompletionDeveloperMessageParam - expectedPrompt string - }{ - { - name: "developer message with content parts", - inputMsg: openai.ChatCompletionDeveloperMessageParam{ - Content: openai.ContentUnion{Value: []openai.ChatCompletionContentPartTextParam{ - {Type: "text", Text: "part 1"}, - {Type: "text", Text: " part 2"}, - }}, - }, - expectedPrompt: "part 1 part 2", - }, - { - name: "developer message with nil content", - inputMsg: openai.ChatCompletionDeveloperMessageParam{Content: openai.ContentUnion{Value: nil}}, - expectedPrompt: "", - }, - { - name: "developer message with string content", - inputMsg: openai.ChatCompletionDeveloperMessageParam{ - Content: openai.ContentUnion{Value: "simple string"}, - }, - expectedPrompt: "simple string", - }, - { - name: "developer message with text parts array", - inputMsg: openai.ChatCompletionDeveloperMessageParam{ - Content: openai.ContentUnion{Value: []openai.ChatCompletionContentPartTextParam{ - {Type: "text", Text: "text part"}, - }}, - }, - expectedPrompt: "text part", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - prompt, _ := extractSystemPromptFromDeveloperMsg(tt.inputMsg) - require.Equal(t, tt.expectedPrompt, prompt) - }) - } -} - func TestOpenAIToGCPAnthropicTranslatorV1ChatCompletion_Cache(t *testing.T) { t.Run("full request with mixed caching", func(t *testing.T) { openAIReq := &openai.ChatCompletionRequest{ @@ -2605,3 +1740,306 @@ func TestOpenAIToGCPAnthropicTranslatorV1ChatCompletion_Cache(t *testing.T) { require.Equal(t, string(constant.ValueOf[constant.Ephemeral]()), result.Get("messages.0.content.2.cache_control.type").String(), "tool 3 (with cache) should be cached") }) } + +func TestOpenAIToGCPAnthropicTranslatorV1ChatCompletion_SetRedactionConfig(t *testing.T) { + translator := NewChatCompletionOpenAIToGCPAnthropicTranslator("", 
"").(*openAIToGCPAnthropicTranslatorV1ChatCompletion) + + logger := slog.New(slog.NewTextHandler(io.Discard, nil)) + translator.SetRedactionConfig(true, true, logger) + + require.True(t, translator.debugLogEnabled) + require.True(t, translator.enableRedaction) + require.NotNil(t, translator.logger) +} + +func TestOpenAIToGCPAnthropicTranslatorV1ChatCompletion_RedactBody(t *testing.T) { + translator := NewChatCompletionOpenAIToGCPAnthropicTranslator("", "").(*openAIToGCPAnthropicTranslatorV1ChatCompletion) + + t.Run("nil response returns nil", func(t *testing.T) { + result := translator.RedactBody(nil) + require.Nil(t, result) + }) + + t.Run("redacts message content", func(t *testing.T) { + resp := &openai.ChatCompletionResponse{ + ID: "test-id", + Model: "test-model", + Choices: []openai.ChatCompletionResponseChoice{ + { + Index: 0, + Message: openai.ChatCompletionResponseChoiceMessage{Role: "assistant", Content: ptr.To("sensitive content")}, + }, + }, + } + + result := translator.RedactBody(resp) + require.NotNil(t, result) + require.Equal(t, "test-id", result.ID) + require.Len(t, result.Choices, 1) + // Content should be redacted (not the original value) + require.NotNil(t, result.Choices[0].Message.Content) + require.NotEqual(t, "sensitive content", *result.Choices[0].Message.Content) + }) + + t.Run("redacts tool calls", func(t *testing.T) { + resp := &openai.ChatCompletionResponse{ + ID: "test-id", + Model: "test-model", + Choices: []openai.ChatCompletionResponseChoice{ + { + Index: 0, + Message: openai.ChatCompletionResponseChoiceMessage{ + Role: "assistant", + ToolCalls: []openai.ChatCompletionMessageToolCallParam{ + { + ID: ptr.To("tool-1"), + Type: openai.ChatCompletionMessageToolCallTypeFunction, + Function: openai.ChatCompletionMessageToolCallFunctionParam{ + Name: "get_secret", + Arguments: `{"password": "secret123"}`, + }, + }, + }, + }, + }, + }, + } + + result := translator.RedactBody(resp) + require.NotNil(t, result) + require.Len(t, result.Choices, 
1) + require.Len(t, result.Choices[0].Message.ToolCalls, 1) + // Tool call name and arguments should be redacted + require.NotEqual(t, "get_secret", result.Choices[0].Message.ToolCalls[0].Function.Name) + require.NotEqual(t, `{"password": "secret123"}`, result.Choices[0].Message.ToolCalls[0].Function.Arguments) + }) + + t.Run("redacts audio data", func(t *testing.T) { + resp := &openai.ChatCompletionResponse{ + ID: "test-id", + Model: "test-model", + Choices: []openai.ChatCompletionResponseChoice{ + { + Index: 0, + Message: openai.ChatCompletionResponseChoiceMessage{ + Role: "assistant", + Audio: &openai.ChatCompletionResponseChoiceMessageAudio{ + Data: "base64-audio-data", + Transcript: "sensitive transcript", + }, + }, + }, + }, + } + + result := translator.RedactBody(resp) + require.NotNil(t, result) + require.Len(t, result.Choices, 1) + require.NotNil(t, result.Choices[0].Message.Audio) + // Audio data and transcript should be redacted + require.NotEqual(t, "base64-audio-data", result.Choices[0].Message.Audio.Data) + require.NotEqual(t, "sensitive transcript", result.Choices[0].Message.Audio.Transcript) + }) + + t.Run("redacts reasoning content", func(t *testing.T) { + resp := &openai.ChatCompletionResponse{ + ID: "test-id", + Model: "test-model", + Choices: []openai.ChatCompletionResponseChoice{ + { + Index: 0, + Message: openai.ChatCompletionResponseChoiceMessage{ + Role: "assistant", + ReasoningContent: &openai.ReasoningContentUnion{ + Value: &openai.ReasoningContent{ + ReasoningContent: &awsbedrock.ReasoningContentBlock{ + ReasoningText: &awsbedrock.ReasoningTextBlock{ + Text: "sensitive reasoning", + Signature: "sig123", + }, + }, + }, + }, + }, + }, + }, + } + + result := translator.RedactBody(resp) + require.NotNil(t, result) + require.Len(t, result.Choices, 1) + require.NotNil(t, result.Choices[0].Message.ReasoningContent) + }) + + t.Run("empty choices returns empty choices", func(t *testing.T) { + resp := &openai.ChatCompletionResponse{ + ID: 
"test-id", + Model: "test-model", + Choices: []openai.ChatCompletionResponseChoice{}, + } + + result := translator.RedactBody(resp) + require.NotNil(t, result) + require.Empty(t, result.Choices) + }) +} + +func TestOpenAIToGCPAnthropicTranslatorV1ChatCompletion_ResponseHeaders(t *testing.T) { + t.Run("returns event-stream content type for streaming", func(t *testing.T) { + openAIReq := &openai.ChatCompletionRequest{ + Stream: true, + Model: "test-model", + MaxTokens: ptr.To(int64(100)), + } + translator := NewChatCompletionOpenAIToGCPAnthropicTranslator("", "").(*openAIToGCPAnthropicTranslatorV1ChatCompletion) + + // Initialize the stream parser by calling RequestBody with streaming request + _, _, err := translator.RequestBody(nil, openAIReq, false) + require.NoError(t, err) + + // Now ResponseHeaders should return the streaming content type + headers, err := translator.ResponseHeaders(nil) + require.NoError(t, err) + require.Len(t, headers, 1) + require.Equal(t, contentTypeHeaderName, headers[0].Key()) + require.Equal(t, eventStreamContentType, headers[0].Value()) + }) + + t.Run("returns no headers for non-streaming", func(t *testing.T) { + openAIReq := &openai.ChatCompletionRequest{ + Stream: false, + Model: "test-model", + MaxTokens: ptr.To(int64(100)), + } + translator := NewChatCompletionOpenAIToGCPAnthropicTranslator("", "").(*openAIToGCPAnthropicTranslatorV1ChatCompletion) + + // Initialize without streaming + _, _, err := translator.RequestBody(nil, openAIReq, false) + require.NoError(t, err) + + // ResponseHeaders should return nil for non-streaming + headers, err := translator.ResponseHeaders(nil) + require.NoError(t, err) + require.Nil(t, headers) + }) +} + +func TestOpenAIToGCPAnthropicTranslatorV1ChatCompletion_ResponseBody_WithDebugLogging(t *testing.T) { + // Create a buffer to capture log output + var logBuf bytes.Buffer + logger := slog.New(slog.NewTextHandler(&logBuf, &slog.HandlerOptions{Level: slog.LevelDebug})) + + translator := 
NewChatCompletionOpenAIToGCPAnthropicTranslator("", "").(*openAIToGCPAnthropicTranslatorV1ChatCompletion) + translator.SetRedactionConfig(true, true, logger) + + // Initialize translator with the model + req := &openai.ChatCompletionRequest{ + Model: "claude-3", + MaxTokens: ptr.To(int64(100)), + Messages: []openai.ChatCompletionMessageParamUnion{ + { + OfUser: &openai.ChatCompletionUserMessageParam{ + Content: openai.StringOrUserRoleContentUnion{Value: "Hello"}, + Role: openai.ChatMessageRoleUser, + }, + }, + }, + } + reqBody, _ := json.Marshal(req) + _, _, err := translator.RequestBody(reqBody, req, false) + require.NoError(t, err) + + // Create a response + anthropicResponse := anthropic.Message{ + ID: "msg_01XYZ", + Type: constant.ValueOf[constant.Message](), + Role: constant.ValueOf[constant.Assistant](), + Content: []anthropic.ContentBlockUnion{ + { + Type: "text", + Text: "Hello! How can I help you?", + }, + }, + StopReason: anthropic.StopReasonEndTurn, + Usage: anthropic.Usage{ + InputTokens: 10, + OutputTokens: 5, + }, + } + + body, err := json.Marshal(anthropicResponse) + require.NoError(t, err) + + _, _, _, _, err = translator.ResponseBody(nil, bytes.NewReader(body), true, nil) + require.NoError(t, err) + + // Verify that debug logging occurred + logOutput := logBuf.String() + require.Contains(t, logOutput, "response body processing") +} + +// mockSpan implements tracingapi.ChatCompletionSpan for testing +type mockSpan struct { + recordedResponse *openai.ChatCompletionResponse +} + +func (m *mockSpan) RecordResponseChunk(_ *openai.ChatCompletionResponseChunk) {} +func (m *mockSpan) RecordResponse(resp *openai.ChatCompletionResponse) { + m.recordedResponse = resp +} +func (m *mockSpan) EndSpanOnError(_ int, _ []byte) {} +func (m *mockSpan) EndSpan() {} + +func TestOpenAIToGCPAnthropicTranslatorV1ChatCompletion_ResponseBody_WithSpanRecording(t *testing.T) { + translator := NewChatCompletionOpenAIToGCPAnthropicTranslator("", 
"").(*openAIToGCPAnthropicTranslatorV1ChatCompletion) + + // Initialize translator with the model + req := &openai.ChatCompletionRequest{ + Model: "claude-3", + MaxTokens: ptr.To(int64(100)), + Messages: []openai.ChatCompletionMessageParamUnion{ + { + OfUser: &openai.ChatCompletionUserMessageParam{ + Content: openai.StringOrUserRoleContentUnion{Value: "Hello"}, + Role: openai.ChatMessageRoleUser, + }, + }, + }, + } + reqBody, _ := json.Marshal(req) + _, _, err := translator.RequestBody(reqBody, req, false) + require.NoError(t, err) + + // Create a response + anthropicResponse := anthropic.Message{ + ID: "msg_01XYZ", + Type: constant.ValueOf[constant.Message](), + Role: constant.ValueOf[constant.Assistant](), + Content: []anthropic.ContentBlockUnion{ + { + Type: "text", + Text: "Hello!", + }, + }, + StopReason: anthropic.StopReasonEndTurn, + Usage: anthropic.Usage{ + InputTokens: 10, + OutputTokens: 5, + }, + } + + body, err := json.Marshal(anthropicResponse) + require.NoError(t, err) + + // Create a mock span + span := &mockSpan{} + + _, _, _, _, err = translator.ResponseBody(nil, bytes.NewReader(body), true, span) + require.NoError(t, err) + + // Verify the span recorded the response + require.NotNil(t, span.recordedResponse) + require.Equal(t, "msg_01XYZ", span.recordedResponse.ID) + require.Len(t, span.recordedResponse.Choices, 1) + require.Equal(t, "Hello!", *span.recordedResponse.Choices[0].Message.Content) +} diff --git a/site/docs/api/api.mdx b/site/docs/api/api.mdx index 47ceb67c58..240eaa93b5 100644 --- a/site/docs/api/api.mdx +++ b/site/docs/api/api.mdx @@ -982,7 +982,7 @@ APISchema defines the API schema. name="AWSAnthropic" type="enum" required="false" - description="APISchemaAWSAnthropic is the schema for Anthropic models hosted on AWS Bedrock.
Uses the native Anthropic Messages API format for requests and responses.
https://aws.amazon.com/bedrock/anthropic/
https://docs.claude.com/en/api/claude-on-amazon-bedrock
" + description="APISchemaAWSAnthropic is the schema for Anthropic models hosted on AWS Bedrock.
Uses the native Anthropic Messages API format for requests and responses.
When used with /v1/chat/completions endpoint, translates OpenAI format to Anthropic.
When used with /v1/messages endpoint, passes through native Anthropic format.
https://aws.amazon.com/bedrock/anthropic/
https://docs.claude.com/en/api/claude-on-amazon-bedrock
" /> #### AWSCredentialsFile