diff --git a/internal/apischema/awsbedrock/awsbedrock.go b/internal/apischema/awsbedrock/awsbedrock.go index 503bb94012..894c40b425 100644 --- a/internal/apischema/awsbedrock/awsbedrock.go +++ b/internal/apischema/awsbedrock/awsbedrock.go @@ -404,11 +404,11 @@ type ConverseOutput struct { // TokenUsage is defined in the AWS Bedrock API: // https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_TokenUsage.html type TokenUsage struct { - InputTokens int `json:"inputTokens"` - OutputTokens int `json:"outputTokens"` - TotalTokens int `json:"totalTokens"` - CacheReadInputTokens *int `json:"cacheReadInputTokens,omitempty"` - CacheWriteInputTokens *int `json:"cacheWriteInputTokens,omitempty"` + InputTokens int64 `json:"inputTokens"` + OutputTokens int64 `json:"outputTokens"` + TotalTokens int64 `json:"totalTokens"` + CacheReadInputTokens *int64 `json:"cacheReadInputTokens,omitempty"` + CacheWriteInputTokens *int64 `json:"cacheWriteInputTokens,omitempty"` } // ConverseStreamEventType represents a distinct event type received from the Bedrock ConverseStream API. diff --git a/internal/metrics/metrics.go b/internal/metrics/metrics.go index e9929df812..61058ff9b0 100644 --- a/internal/metrics/metrics.go +++ b/internal/metrics/metrics.go @@ -259,21 +259,26 @@ func (u *TokenUsage) Override(other TokenUsage) { } } -// ExtractTokenUsageFromAnthropic extracts the correct token usage from Anthropic API response. -// According to Claude API documentation, total input tokens is the summation of: +// ExtractTokenUsageFromExplicitCaching extracts the correct token usage from upstream Anthropic or AWS Bedrock token usage response. +// The total input tokens is the summation of: // input_tokens + cache_creation_input_tokens + cache_read_input_tokens +// This is to unify the usage response returned by envoy ai gateway for both explicit and implicit caching. // // This function works for both streaming and non-streaming responses by accepting -// the common usage fields that exist in all Anthropic usage structures. -func ExtractTokenUsageFromAnthropic(inputTokens, outputTokens, cacheReadTokens, cacheCreationTokens int64) TokenUsage { - // Calculate total input tokens as per Anthropic API documentation - totalInputTokens := inputTokens + cacheCreationTokens + cacheReadTokens - +// the common usage fields that exist from anthropic or AWS bedrock usage structures. +func ExtractTokenUsageFromExplicitCaching(inputTokens, outputTokens int64, cacheReadTokens, cacheCreationTokens *int64) TokenUsage { var usage TokenUsage - usage.SetInputTokens(uint32(totalInputTokens)) //nolint:gosec - usage.SetOutputTokens(uint32(outputTokens)) //nolint:gosec - usage.SetTotalTokens(uint32(totalInputTokens + outputTokens)) //nolint:gosec - usage.SetCachedInputTokens(uint32(cacheReadTokens)) //nolint:gosec - usage.SetCacheCreationInputTokens(uint32(cacheCreationTokens)) //nolint:gosec + totalInputTokens := inputTokens + if cacheCreationTokens != nil { + totalInputTokens += *cacheCreationTokens + usage.SetCacheCreationInputTokens(uint32(*cacheCreationTokens)) //nolint:gosec + } + if cacheReadTokens != nil { + totalInputTokens += *cacheReadTokens + usage.SetCachedInputTokens(uint32(*cacheReadTokens)) //nolint:gosec + } + usage.SetInputTokens(uint32(totalInputTokens)) //nolint:gosec + usage.SetOutputTokens(uint32(outputTokens)) //nolint:gosec + usage.SetTotalTokens(uint32(totalInputTokens + outputTokens)) //nolint:gosec return usage } diff --git a/internal/tracing/openinference/anthropic/messages.go b/internal/tracing/openinference/anthropic/messages.go index c513611b46..27da0c8bb1 100644 --- a/internal/tracing/openinference/anthropic/messages.go +++ b/internal/tracing/openinference/anthropic/messages.go @@ -207,11 +207,13 @@ func buildResponseAttributes(resp *anthropic.MessagesResponse, config *openinfer // Token counts are considered metadata and are still included even when output content is hidden. u := resp.Usage - cost := metrics.ExtractTokenUsageFromAnthropic( + cacheReadTokens := int64(u.CacheReadInputTokens) + cacheCreationTokens := int64(u.CacheCreationInputTokens) + cost := metrics.ExtractTokenUsageFromExplicitCaching( int64(u.InputTokens), int64(u.OutputTokens), - int64(u.CacheReadInputTokens), - int64(u.CacheCreationInputTokens), + &cacheReadTokens, + &cacheCreationTokens, ) input, _ := cost.InputTokens() cacheRead, _ := cost.CachedInputTokens() diff --git a/internal/translator/anthropic_anthropic.go b/internal/translator/anthropic_anthropic.go index 0a5294a955..43b8f223f0 100644 --- a/internal/translator/anthropic_anthropic.go +++ b/internal/translator/anthropic_anthropic.go @@ -14,6 +14,7 @@ import ( "strings" "github.com/tidwall/sjson" + "k8s.io/utils/ptr" "github.com/envoyproxy/ai-gateway/internal/apischema/anthropic" "github.com/envoyproxy/ai-gateway/internal/internalapi" @@ -99,11 +100,11 @@ func (a *anthropicToAnthropicTranslator) ResponseBody(_ map[string]string, body return nil, nil, tokenUsage, responseModel, fmt.Errorf("failed to unmarshal body: %w", err) } usage := anthropicResp.Usage - tokenUsage = metrics.ExtractTokenUsageFromAnthropic( + tokenUsage = metrics.ExtractTokenUsageFromExplicitCaching( int64(usage.InputTokens), int64(usage.OutputTokens), - int64(usage.CacheReadInputTokens), - int64(usage.CacheCreationInputTokens), + ptr.To(int64(usage.CacheReadInputTokens)), + ptr.To(int64(usage.CacheCreationInputTokens)), ) if span != nil { span.RecordResponse(anthropicResp) @@ -144,11 +145,11 @@ func (a *anthropicToAnthropicTranslator) extractUsageFromBufferEvent(s tracing.M } // Extract usage from message_start event - this sets the baseline input tokens if u := message.Usage; u != nil { - messageStartUsage := metrics.ExtractTokenUsageFromAnthropic( + messageStartUsage := metrics.ExtractTokenUsageFromExplicitCaching( int64(u.InputTokens), int64(u.OutputTokens), - int64(u.CacheReadInputTokens), - int64(u.CacheCreationInputTokens), + ptr.To(int64(u.CacheReadInputTokens)), + ptr.To(int64(u.CacheCreationInputTokens)), ) // Override with message_start usage (contains input tokens and initial state) a.streamingTokenUsage.Override(messageStartUsage) diff --git a/internal/translator/anthropic_usage_test.go b/internal/translator/anthropic_usage_test.go index cb75e8a7d0..dd35db7f74 100644 --- a/internal/translator/anthropic_usage_test.go +++ b/internal/translator/anthropic_usage_test.go @@ -10,6 +10,7 @@ import ( "github.com/anthropics/anthropic-sdk-go" "github.com/stretchr/testify/assert" + "k8s.io/utils/ptr" "github.com/envoyproxy/ai-gateway/internal/metrics" ) @@ -103,11 +104,11 @@ func TestExtractLLMTokenUsage(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result := metrics.ExtractTokenUsageFromAnthropic( + result := metrics.ExtractTokenUsageFromExplicitCaching( tt.inputTokens, tt.outputTokens, - tt.cacheReadTokens, - tt.cacheCreationTokens, + &tt.cacheReadTokens, + &tt.cacheCreationTokens, ) expected := tokenUsageFrom( @@ -178,10 +179,10 @@ func TestExtractLLMTokenUsageFromUsage(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result := metrics.ExtractTokenUsageFromAnthropic(tt.usage.InputTokens, + result := metrics.ExtractTokenUsageFromExplicitCaching(tt.usage.InputTokens, tt.usage.OutputTokens, - tt.usage.CacheReadInputTokens, - tt.usage.CacheCreationInputTokens, + &tt.usage.CacheReadInputTokens, + &tt.usage.CacheCreationInputTokens, ) expected := tokenUsageFrom(tt.expectedInputTokens, int32(tt.expectedCachedTokens), int32(tt.expectedCacheCreationTokens), tt.expectedOutputTokens, tt.expectedTotalTokens) // nolint:gosec assert.Equal(t, expected, result) @@ -245,10 +246,10 @@ func TestExtractLLMTokenUsageFromDeltaUsage(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result := metrics.ExtractTokenUsageFromAnthropic(tt.usage.InputTokens, + result := metrics.ExtractTokenUsageFromExplicitCaching(tt.usage.InputTokens, tt.usage.OutputTokens, - tt.usage.CacheReadInputTokens, - tt.usage.CacheCreationInputTokens, + &tt.usage.CacheReadInputTokens, + &tt.usage.CacheCreationInputTokens, ) expected := tokenUsageFrom(tt.expectedInputTokens, int32(tt.expectedCachedTokens), int32(tt.expectedCacheCreationTokens), tt.expectedOutputTokens, tt.expectedTotalTokens) // nolint:gosec assert.Equal(t, expected, result) @@ -261,7 +262,8 @@ func TestExtractLLMTokenUsage_EdgeCases(t *testing.T) { t.Run("negative values should be handled", func(t *testing.T) { // Note: In practice, the Anthropic API shouldn't return negative values, // but our function should handle them gracefully by casting to uint32. - result := metrics.ExtractTokenUsageFromAnthropic(-10, -5, -2, -1) + result := metrics.ExtractTokenUsageFromExplicitCaching(-10, -5, ptr.To[int64](-2), + ptr.To[int64](-1)) // Negative int64 values will wrap around when cast to uint32. // This test documents current behavior rather than prescribing it. @@ -272,7 +274,8 @@ func TestExtractLLMTokenUsage_EdgeCases(t *testing.T) { t.Run("maximum int64 values", func(t *testing.T) { // Test with very large values to ensure no overflow issues. // Note: This will result in truncation when casting to uint32. - result := metrics.ExtractTokenUsageFromAnthropic(9223372036854775807, 1000, 500, 100) + result := metrics.ExtractTokenUsageFromExplicitCaching(9223372036854775807, 1000, + ptr.To[int64](500), ptr.To[int64](100)) assert.NotNil(t, result) }) } @@ -285,14 +288,14 @@ func TestExtractLLMTokenUsage_ClaudeAPIDocumentationCompliance(t *testing.T) { // cache_creation_input_tokens, and cache_read_input_tokens". inputTokens := int64(100) - cachedWriteTokens := int64(20) + cacheCreationTokens := int64(20) cacheReadTokens := int64(30) outputTokens := int64(50) - result := metrics.ExtractTokenUsageFromAnthropic(inputTokens, outputTokens, cacheReadTokens, cachedWriteTokens) + result := metrics.ExtractTokenUsageFromExplicitCaching(inputTokens, outputTokens, &cacheReadTokens, &cacheCreationTokens) // Total input should be sum of all input token types. - expectedTotalInputInt := inputTokens + cachedWriteTokens + cacheReadTokens + expectedTotalInputInt := inputTokens + cacheCreationTokens + cacheReadTokens expectedTotalInput := uint32(expectedTotalInputInt) // #nosec G115 - test values are small and safe inputTokensVal, ok := result.InputTokens() assert.True(t, ok) @@ -301,16 +304,16 @@ func TestExtractLLMTokenUsage_ClaudeAPIDocumentationCompliance(t *testing.T) { cachedTokens, ok := result.CachedInputTokens() assert.True(t, ok) - assert.Equal(t, uint32(cacheReadTokens), cachedTokens, + assert.Equal(t, uint32(cacheReadTokens), cachedTokens, // #nosec G115 - test values are small and safe "CachedInputTokens should be cache_read_input_tokens") - cacheCreationTokens, ok := result.CacheCreationInputTokens() + cacheCreationResult, ok := result.CacheCreationInputTokens() assert.True(t, ok) - assert.Equal(t, uint32(cachedWriteTokens), cacheCreationTokens, + assert.Equal(t, uint32(cacheCreationTokens), cacheCreationResult, // #nosec G115 - test values are small and safe "CacheCreationInputTokens should be cache_creation_input_tokens") // Total tokens should be input + output. - expectedTotal := expectedTotalInput + uint32(outputTokens) + expectedTotal := expectedTotalInput + uint32(outputTokens) // #nosec G115 - test values are small and safe totalTokens, ok := result.TotalTokens() assert.True(t, ok) assert.Equal(t, expectedTotal, totalTokens, diff --git a/internal/translator/openai_awsbedrock.go b/internal/translator/openai_awsbedrock.go index f0ecf6a69e..79c3a12235 100644 --- a/internal/translator/openai_awsbedrock.go +++ b/internal/translator/openai_awsbedrock.go @@ -701,15 +701,8 @@ func (o *openAIToAWSBedrockTranslatorV1ChatCompletion) ResponseBody(_ map[string for i := range o.events { event := &o.events[i] if usage := event.Usage; usage != nil { - tokenUsage.SetInputTokens(uint32(usage.InputTokens)) //nolint:gosec - tokenUsage.SetOutputTokens(uint32(usage.OutputTokens)) //nolint:gosec - tokenUsage.SetTotalTokens(uint32(usage.TotalTokens)) //nolint:gosec - if usage.CacheReadInputTokens != nil { - tokenUsage.SetCachedInputTokens(uint32(*usage.CacheReadInputTokens)) //nolint:gosec - } - if usage.CacheWriteInputTokens != nil { - tokenUsage.SetCacheCreationInputTokens(uint32(*usage.CacheWriteInputTokens)) //nolint:gosec - } + tokenUsage = metrics.ExtractTokenUsageFromExplicitCaching(usage.InputTokens, usage.OutputTokens, + usage.CacheReadInputTokens, usage.CacheWriteInputTokens) } oaiEvent, ok := o.convertEvent(event) if !ok { @@ -744,24 +737,26 @@ func (o *openAIToAWSBedrockTranslatorV1ChatCompletion) ResponseBody(_ map[string } // Convert token usage. if bedrockResp.Usage != nil { - tokenUsage.SetInputTokens(uint32(bedrockResp.Usage.InputTokens)) //nolint:gosec - tokenUsage.SetOutputTokens(uint32(bedrockResp.Usage.OutputTokens)) //nolint:gosec - tokenUsage.SetTotalTokens(uint32(bedrockResp.Usage.TotalTokens)) //nolint:gosec + tokenUsage = metrics.ExtractTokenUsageFromExplicitCaching(bedrockResp.Usage.InputTokens, bedrockResp.Usage.OutputTokens, + bedrockResp.Usage.CacheReadInputTokens, bedrockResp.Usage.CacheWriteInputTokens) + totalTokens, _ := tokenUsage.TotalTokens() + inputTokens, _ := tokenUsage.InputTokens() + outputTokens, _ := tokenUsage.OutputTokens() openAIResp.Usage = openai.Usage{ - TotalTokens: bedrockResp.Usage.TotalTokens, - PromptTokens: bedrockResp.Usage.InputTokens, - CompletionTokens: bedrockResp.Usage.OutputTokens, + TotalTokens: int(totalTokens), + PromptTokens: int(inputTokens), + CompletionTokens: int(outputTokens), } if bedrockResp.Usage.CacheReadInputTokens != nil || bedrockResp.Usage.CacheWriteInputTokens != nil { openAIResp.Usage.PromptTokensDetails = &openai.PromptTokensDetails{} } if bedrockResp.Usage.CacheReadInputTokens != nil { tokenUsage.SetCachedInputTokens(uint32(*bedrockResp.Usage.CacheReadInputTokens)) //nolint:gosec - openAIResp.Usage.PromptTokensDetails.CachedTokens = *bedrockResp.Usage.CacheReadInputTokens + openAIResp.Usage.PromptTokensDetails.CachedTokens = int(*bedrockResp.Usage.CacheReadInputTokens) } if bedrockResp.Usage.CacheWriteInputTokens != nil { tokenUsage.SetCacheCreationInputTokens(uint32(*bedrockResp.Usage.CacheWriteInputTokens)) //nolint:gosec - openAIResp.Usage.PromptTokensDetails.CacheCreationTokens = *bedrockResp.Usage.CacheWriteInputTokens + openAIResp.Usage.PromptTokensDetails.CacheCreationTokens = int(*bedrockResp.Usage.CacheWriteInputTokens) } } @@ -852,19 +847,24 @@ func (o *openAIToAWSBedrockTranslatorV1ChatCompletion) convertEvent(event *awsbe if event.Usage == nil { return chunk, false } + tokenUsage := metrics.ExtractTokenUsageFromExplicitCaching(event.Usage.InputTokens, event.Usage.OutputTokens, + event.Usage.CacheReadInputTokens, event.Usage.CacheWriteInputTokens) + totalTokens, _ := tokenUsage.TotalTokens() + inputTokens, _ := tokenUsage.InputTokens() + outputTokens, _ := tokenUsage.OutputTokens() chunk.Usage = &openai.Usage{ - TotalTokens: event.Usage.TotalTokens, - PromptTokens: event.Usage.InputTokens, - CompletionTokens: event.Usage.OutputTokens, + TotalTokens: int(totalTokens), + PromptTokens: int(inputTokens), + CompletionTokens: int(outputTokens), } if event.Usage.CacheReadInputTokens != nil || event.Usage.CacheWriteInputTokens != nil { chunk.Usage.PromptTokensDetails = &openai.PromptTokensDetails{} } if event.Usage.CacheReadInputTokens != nil { - chunk.Usage.PromptTokensDetails.CachedTokens = *event.Usage.CacheReadInputTokens + chunk.Usage.PromptTokensDetails.CachedTokens = int(*event.Usage.CacheReadInputTokens) } if event.Usage.CacheWriteInputTokens != nil { - chunk.Usage.PromptTokensDetails.CacheCreationTokens = *event.Usage.CacheWriteInputTokens + chunk.Usage.PromptTokensDetails.CacheCreationTokens = int(*event.Usage.CacheWriteInputTokens) } // messageStart event. case awsbedrock.ConverseStreamEventTypeMessageStart.String(): diff --git a/internal/translator/openai_awsbedrock_test.go b/internal/translator/openai_awsbedrock_test.go index fd9ef5c6f8..865632e8d6 100644 --- a/internal/translator/openai_awsbedrock_test.go +++ b/internal/translator/openai_awsbedrock_test.go @@ -1450,8 +1450,8 @@ func TestOpenAIToAWSBedrockTranslatorV1ChatCompletion_ResponseBody(t *testing.T) InputTokens: 10, OutputTokens: 20, TotalTokens: 30, - CacheReadInputTokens: ptr.To(5), - CacheWriteInputTokens: ptr.To(7), + CacheWriteInputTokens: ptr.To[int64](7), + CacheReadInputTokens: ptr.To[int64](5), }, Output: &awsbedrock.ConverseOutput{ Message: awsbedrock.Message{ @@ -1470,8 +1470,8 @@ func TestOpenAIToAWSBedrockTranslatorV1ChatCompletion_ResponseBody(t *testing.T) Created: openai.JSONUNIXTime(time.Unix(ReleaseDateUnix, 0)), Object: "chat.completion", Usage: openai.Usage{ - TotalTokens: 30, - PromptTokens: 10, + TotalTokens: 42, + PromptTokens: 22, CompletionTokens: 20, PromptTokensDetails: &openai.PromptTokensDetails{ CachedTokens: 5, @@ -1902,7 +1902,7 @@ func TestOpenAIToAWSBedrockTranslatorExtractAmazonEventStreamEvents(t *testing.T strings.Join(texts, ""), ) require.NotNil(t, usage) - require.Equal(t, 461, usage.TotalTokens) + require.Equal(t, int64(461), usage.TotalTokens) }) } @@ -1921,7 +1921,7 @@ func TestOpenAIToAWSBedrockTranslator_convertEvent(t *testing.T) { InputTokens: 10, OutputTokens: 20, TotalTokens: 30, - CacheReadInputTokens: ptr.To(5), + CacheReadInputTokens: ptr.To[int64](5), }, }, out: &openai.ChatCompletionResponseChunk{ @@ -1930,8 +1930,8 @@ func TestOpenAIToAWSBedrockTranslator_convertEvent(t *testing.T) { Created: openai.JSONUNIXTime(time.Unix(ReleaseDateUnix, 0)), // 0 nanoseconds Object: "chat.completion.chunk", Usage: &openai.Usage{ - TotalTokens: 30, - PromptTokens: 10, + TotalTokens: 35, + PromptTokens: 15, CompletionTokens: 20, PromptTokensDetails: &openai.PromptTokensDetails{ CachedTokens: 5, @@ -2016,7 +2016,8 @@ func TestOpenAIToAWSBedrockTranslator_convertEvent(t *testing.T) { require.False(t, ok) } else { // Use require.True and cmp.Equal with the options - require.True(t, cmp.Equal(*tc.out, *chunk, ignoreDynamicFields), "The ChatCompletionResponseChunk structs should be equal ignoring the ID and Created field") + require.True(t, cmp.Equal(*tc.out, *chunk, ignoreDynamicFields), + "The ChatCompletionResponseChunk structs should be equal ignoring the ID and Created field") } }) } diff --git a/internal/translator/openai_gcpanthropic.go b/internal/translator/openai_gcpanthropic.go index 716053cf4b..9c5493a250 100644 --- a/internal/translator/openai_gcpanthropic.go +++ b/internal/translator/openai_gcpanthropic.go @@ -829,11 +829,11 @@ func (o *openAIToGCPAnthropicTranslatorV1ChatCompletion) ResponseBody(_ map[stri Created: openai.JSONUNIXTime(time.Now()), } usage := anthropicResp.Usage - tokenUsage = metrics.ExtractTokenUsageFromAnthropic( + tokenUsage = metrics.ExtractTokenUsageFromExplicitCaching( usage.InputTokens, usage.OutputTokens, - usage.CacheReadInputTokens, - usage.CacheCreationInputTokens, + &usage.CacheReadInputTokens, + &usage.CacheCreationInputTokens, ) inputTokens, _ := tokenUsage.InputTokens() outputTokens, _ := tokenUsage.OutputTokens() diff --git a/internal/translator/openai_gcpanthropic_stream.go b/internal/translator/openai_gcpanthropic_stream.go index 1846ec7358..d43808ec30 100644 --- a/internal/translator/openai_gcpanthropic_stream.go +++ b/internal/translator/openai_gcpanthropic_stream.go @@ -199,11 +199,11 @@ func (p *anthropicStreamParser) handleAnthropicStreamEvent(eventType []byte, dat p.activeMessageID = event.Message.ID p.created = openai.JSONUNIXTime(time.Now()) u := event.Message.Usage - usage := metrics.ExtractTokenUsageFromAnthropic( + usage := metrics.ExtractTokenUsageFromExplicitCaching( u.InputTokens, u.OutputTokens, - u.CacheReadInputTokens, - u.CacheCreationInputTokens, + &u.CacheReadInputTokens, + &u.CacheCreationInputTokens, ) // For message_start, we store the initial usage but don't add to the accumulated // The message_delta event will contain the final totals @@ -281,11 +281,11 @@ func (p *anthropicStreamParser) handleAnthropicStreamEvent(eventType []byte, dat return nil, fmt.Errorf("unmarshal message_delta: %w", err) } u := event.Usage - usage := metrics.ExtractTokenUsageFromAnthropic( + usage := metrics.ExtractTokenUsageFromExplicitCaching( u.InputTokens, u.OutputTokens, - u.CacheReadInputTokens, - u.CacheCreationInputTokens, + &u.CacheReadInputTokens, + &u.CacheCreationInputTokens, ) // For message_delta, accumulate the incremental output tokens if output, ok := usage.OutputTokens(); ok {