Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions internal/apischema/awsbedrock/awsbedrock.go
Original file line number Diff line number Diff line change
Expand Up @@ -404,11 +404,11 @@ type ConverseOutput struct {
// TokenUsage is defined in the AWS Bedrock API:
// https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_TokenUsage.html
type TokenUsage struct {
InputTokens int `json:"inputTokens"`
OutputTokens int `json:"outputTokens"`
TotalTokens int `json:"totalTokens"`
CacheReadInputTokens *int `json:"cacheReadInputTokens,omitempty"`
CacheWriteInputTokens *int `json:"cacheWriteInputTokens,omitempty"`
InputTokens int64 `json:"inputTokens"`
OutputTokens int64 `json:"outputTokens"`
TotalTokens int64 `json:"totalTokens"`
CacheReadInputTokens *int64 `json:"cacheReadInputTokens,omitempty"`
CacheWriteInputTokens *int64 `json:"cacheWriteInputTokens,omitempty"`
}

// ConverseStreamEventType represents a distinct event type received from the Bedrock ConverseStream API.
Expand Down
29 changes: 17 additions & 12 deletions internal/metrics/metrics.go
Original file line number Diff line number Diff line change
Expand Up @@ -259,21 +259,26 @@ func (u *TokenUsage) Override(other TokenUsage) {
}
}

// ExtractTokenUsageFromAnthropic extracts the correct token usage from Anthropic API response.
// According to Claude API documentation, total input tokens is the summation of:
// ExtractTokenUsageFromExplicitCaching extracts the correct token usage from an upstream Anthropic or AWS Bedrock token usage response.
// The total input tokens is the summation of:
// input_tokens + cache_creation_input_tokens + cache_read_input_tokens
// This is to unify the usage response returned by Envoy AI Gateway for both explicit and implicit caching.
//
// This function works for both streaming and non-streaming responses by accepting
// the common usage fields that exist in all Anthropic usage structures.
func ExtractTokenUsageFromAnthropic(inputTokens, outputTokens, cacheReadTokens, cacheCreationTokens int64) TokenUsage {
// Calculate total input tokens as per Anthropic API documentation
totalInputTokens := inputTokens + cacheCreationTokens + cacheReadTokens

// the common usage fields that exist in both the Anthropic and AWS Bedrock usage structures.
func ExtractTokenUsageFromExplicitCaching(inputTokens, outputTokens int64, cacheReadTokens, cacheCreationTokens *int64) TokenUsage {
var usage TokenUsage
usage.SetInputTokens(uint32(totalInputTokens)) //nolint:gosec
usage.SetOutputTokens(uint32(outputTokens)) //nolint:gosec
usage.SetTotalTokens(uint32(totalInputTokens + outputTokens)) //nolint:gosec
usage.SetCachedInputTokens(uint32(cacheReadTokens)) //nolint:gosec
usage.SetCacheCreationInputTokens(uint32(cacheCreationTokens)) //nolint:gosec
totalInputTokens := inputTokens
if cacheCreationTokens != nil {
totalInputTokens += *cacheCreationTokens
usage.SetCacheCreationInputTokens(uint32(*cacheCreationTokens)) //nolint:gosec
}
if cacheReadTokens != nil {
totalInputTokens += *cacheReadTokens
usage.SetCachedInputTokens(uint32(*cacheReadTokens)) //nolint:gosec
}
usage.SetInputTokens(uint32(totalInputTokens)) //nolint:gosec
usage.SetOutputTokens(uint32(outputTokens)) //nolint:gosec
usage.SetTotalTokens(uint32(totalInputTokens + outputTokens)) //nolint:gosec
return usage
}
8 changes: 5 additions & 3 deletions internal/tracing/openinference/anthropic/messages.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,11 +207,13 @@ func buildResponseAttributes(resp *anthropic.MessagesResponse, config *openinfer

// Token counts are considered metadata and are still included even when output content is hidden.
u := resp.Usage
cost := metrics.ExtractTokenUsageFromAnthropic(
cacheReadTokens := int64(u.CacheReadInputTokens)
cacheCreationTokens := int64(u.CacheCreationInputTokens)
cost := metrics.ExtractTokenUsageFromExplicitCaching(
int64(u.InputTokens),
int64(u.OutputTokens),
int64(u.CacheReadInputTokens),
int64(u.CacheCreationInputTokens),
&cacheReadTokens,
&cacheCreationTokens,
)
input, _ := cost.InputTokens()
cacheRead, _ := cost.CachedInputTokens()
Expand Down
13 changes: 7 additions & 6 deletions internal/translator/anthropic_anthropic.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"strings"

"github.com/tidwall/sjson"
"k8s.io/utils/ptr"

"github.com/envoyproxy/ai-gateway/internal/apischema/anthropic"
"github.com/envoyproxy/ai-gateway/internal/internalapi"
Expand Down Expand Up @@ -99,11 +100,11 @@ func (a *anthropicToAnthropicTranslator) ResponseBody(_ map[string]string, body
return nil, nil, tokenUsage, responseModel, fmt.Errorf("failed to unmarshal body: %w", err)
}
usage := anthropicResp.Usage
tokenUsage = metrics.ExtractTokenUsageFromAnthropic(
tokenUsage = metrics.ExtractTokenUsageFromExplicitCaching(
int64(usage.InputTokens),
int64(usage.OutputTokens),
int64(usage.CacheReadInputTokens),
int64(usage.CacheCreationInputTokens),
ptr.To(int64(usage.CacheReadInputTokens)),
ptr.To(int64(usage.CacheCreationInputTokens)),
)
if span != nil {
span.RecordResponse(anthropicResp)
Expand Down Expand Up @@ -144,11 +145,11 @@ func (a *anthropicToAnthropicTranslator) extractUsageFromBufferEvent(s tracing.M
}
// Extract usage from message_start event - this sets the baseline input tokens
if u := message.Usage; u != nil {
messageStartUsage := metrics.ExtractTokenUsageFromAnthropic(
messageStartUsage := metrics.ExtractTokenUsageFromExplicitCaching(
int64(u.InputTokens),
int64(u.OutputTokens),
int64(u.CacheReadInputTokens),
int64(u.CacheCreationInputTokens),
ptr.To(int64(u.CacheReadInputTokens)),
ptr.To(int64(u.CacheCreationInputTokens)),
)
// Override with message_start usage (contains input tokens and initial state)
a.streamingTokenUsage.Override(messageStartUsage)
Expand Down
39 changes: 21 additions & 18 deletions internal/translator/anthropic_usage_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (

"github.com/anthropics/anthropic-sdk-go"
"github.com/stretchr/testify/assert"
"k8s.io/utils/ptr"

"github.com/envoyproxy/ai-gateway/internal/metrics"
)
Expand Down Expand Up @@ -103,11 +104,11 @@ func TestExtractLLMTokenUsage(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := metrics.ExtractTokenUsageFromAnthropic(
result := metrics.ExtractTokenUsageFromExplicitCaching(
tt.inputTokens,
tt.outputTokens,
tt.cacheReadTokens,
tt.cacheCreationTokens,
&tt.cacheReadTokens,
&tt.cacheCreationTokens,
)

expected := tokenUsageFrom(
Expand Down Expand Up @@ -178,10 +179,10 @@ func TestExtractLLMTokenUsageFromUsage(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := metrics.ExtractTokenUsageFromAnthropic(tt.usage.InputTokens,
result := metrics.ExtractTokenUsageFromExplicitCaching(tt.usage.InputTokens,
tt.usage.OutputTokens,
tt.usage.CacheReadInputTokens,
tt.usage.CacheCreationInputTokens,
&tt.usage.CacheReadInputTokens,
&tt.usage.CacheCreationInputTokens,
)
expected := tokenUsageFrom(tt.expectedInputTokens, int32(tt.expectedCachedTokens), int32(tt.expectedCacheCreationTokens), tt.expectedOutputTokens, tt.expectedTotalTokens) // nolint:gosec
assert.Equal(t, expected, result)
Expand Down Expand Up @@ -245,10 +246,10 @@ func TestExtractLLMTokenUsageFromDeltaUsage(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := metrics.ExtractTokenUsageFromAnthropic(tt.usage.InputTokens,
result := metrics.ExtractTokenUsageFromExplicitCaching(tt.usage.InputTokens,
tt.usage.OutputTokens,
tt.usage.CacheReadInputTokens,
tt.usage.CacheCreationInputTokens,
&tt.usage.CacheReadInputTokens,
&tt.usage.CacheCreationInputTokens,
)
expected := tokenUsageFrom(tt.expectedInputTokens, int32(tt.expectedCachedTokens), int32(tt.expectedCacheCreationTokens), tt.expectedOutputTokens, tt.expectedTotalTokens) // nolint:gosec
assert.Equal(t, expected, result)
Expand All @@ -261,7 +262,8 @@ func TestExtractLLMTokenUsage_EdgeCases(t *testing.T) {
t.Run("negative values should be handled", func(t *testing.T) {
// Note: In practice, the Anthropic API shouldn't return negative values,
// but our function should handle them gracefully by casting to uint32.
result := metrics.ExtractTokenUsageFromAnthropic(-10, -5, -2, -1)
result := metrics.ExtractTokenUsageFromExplicitCaching(-10, -5, ptr.To[int64](-2),
ptr.To[int64](-1))

// Negative int64 values will wrap around when cast to uint32.
// This test documents current behavior rather than prescribing it.
Expand All @@ -272,7 +274,8 @@ func TestExtractLLMTokenUsage_EdgeCases(t *testing.T) {
t.Run("maximum int64 values", func(t *testing.T) {
// Test with very large values to ensure no overflow issues.
// Note: This will result in truncation when casting to uint32.
result := metrics.ExtractTokenUsageFromAnthropic(9223372036854775807, 1000, 500, 100)
result := metrics.ExtractTokenUsageFromExplicitCaching(9223372036854775807, 1000,
ptr.To[int64](500), ptr.To[int64](100))
assert.NotNil(t, result)
})
}
Expand All @@ -285,14 +288,14 @@ func TestExtractLLMTokenUsage_ClaudeAPIDocumentationCompliance(t *testing.T) {
// cache_creation_input_tokens, and cache_read_input_tokens".

inputTokens := int64(100)
cachedWriteTokens := int64(20)
cacheCreationTokens := int64(20)
cacheReadTokens := int64(30)
outputTokens := int64(50)

result := metrics.ExtractTokenUsageFromAnthropic(inputTokens, outputTokens, cacheReadTokens, cachedWriteTokens)
result := metrics.ExtractTokenUsageFromExplicitCaching(inputTokens, outputTokens, &cacheReadTokens, &cacheCreationTokens)

// Total input should be sum of all input token types.
expectedTotalInputInt := inputTokens + cachedWriteTokens + cacheReadTokens
expectedTotalInputInt := inputTokens + cacheCreationTokens + cacheReadTokens
expectedTotalInput := uint32(expectedTotalInputInt) // #nosec G115 - test values are small and safe
inputTokensVal, ok := result.InputTokens()
assert.True(t, ok)
Expand All @@ -301,16 +304,16 @@ func TestExtractLLMTokenUsage_ClaudeAPIDocumentationCompliance(t *testing.T) {

cachedTokens, ok := result.CachedInputTokens()
assert.True(t, ok)
assert.Equal(t, uint32(cacheReadTokens), cachedTokens,
assert.Equal(t, uint32(cacheReadTokens), cachedTokens, // #nosec G115 - test values are small and safe
"CachedInputTokens should be cache_read_input_tokens")

cacheCreationTokens, ok := result.CacheCreationInputTokens()
cacheCreationResult, ok := result.CacheCreationInputTokens()
assert.True(t, ok)
assert.Equal(t, uint32(cachedWriteTokens), cacheCreationTokens,
assert.Equal(t, uint32(cacheCreationTokens), cacheCreationResult, // #nosec G115 - test values are small and safe
"CacheCreationInputTokens should be cache_creation_input_tokens")

// Total tokens should be input + output.
expectedTotal := expectedTotalInput + uint32(outputTokens)
expectedTotal := expectedTotalInput + uint32(outputTokens) // #nosec G115 - test values are small and safe
totalTokens, ok := result.TotalTokens()
assert.True(t, ok)
assert.Equal(t, expectedTotal, totalTokens,
Expand Down
44 changes: 22 additions & 22 deletions internal/translator/openai_awsbedrock.go
Original file line number Diff line number Diff line change
Expand Up @@ -701,15 +701,8 @@ func (o *openAIToAWSBedrockTranslatorV1ChatCompletion) ResponseBody(_ map[string
for i := range o.events {
event := &o.events[i]
if usage := event.Usage; usage != nil {
tokenUsage.SetInputTokens(uint32(usage.InputTokens)) //nolint:gosec
tokenUsage.SetOutputTokens(uint32(usage.OutputTokens)) //nolint:gosec
tokenUsage.SetTotalTokens(uint32(usage.TotalTokens)) //nolint:gosec
if usage.CacheReadInputTokens != nil {
tokenUsage.SetCachedInputTokens(uint32(*usage.CacheReadInputTokens)) //nolint:gosec
}
if usage.CacheWriteInputTokens != nil {
tokenUsage.SetCacheCreationInputTokens(uint32(*usage.CacheWriteInputTokens)) //nolint:gosec
}
tokenUsage = metrics.ExtractTokenUsageFromExplicitCaching(usage.InputTokens, usage.OutputTokens,
usage.CacheReadInputTokens, usage.CacheWriteInputTokens)
}
oaiEvent, ok := o.convertEvent(event)
if !ok {
Expand Down Expand Up @@ -744,24 +737,26 @@ func (o *openAIToAWSBedrockTranslatorV1ChatCompletion) ResponseBody(_ map[string
}
// Convert token usage.
if bedrockResp.Usage != nil {
tokenUsage.SetInputTokens(uint32(bedrockResp.Usage.InputTokens)) //nolint:gosec
tokenUsage.SetOutputTokens(uint32(bedrockResp.Usage.OutputTokens)) //nolint:gosec
tokenUsage.SetTotalTokens(uint32(bedrockResp.Usage.TotalTokens)) //nolint:gosec
tokenUsage = metrics.ExtractTokenUsageFromExplicitCaching(bedrockResp.Usage.InputTokens, bedrockResp.Usage.OutputTokens,
bedrockResp.Usage.CacheReadInputTokens, bedrockResp.Usage.CacheWriteInputTokens)
totalTokens, _ := tokenUsage.TotalTokens()
inputTokens, _ := tokenUsage.InputTokens()
outputTokens, _ := tokenUsage.OutputTokens()
openAIResp.Usage = openai.Usage{
TotalTokens: bedrockResp.Usage.TotalTokens,
PromptTokens: bedrockResp.Usage.InputTokens,
CompletionTokens: bedrockResp.Usage.OutputTokens,
TotalTokens: int(totalTokens),
PromptTokens: int(inputTokens),
CompletionTokens: int(outputTokens),
}
if bedrockResp.Usage.CacheReadInputTokens != nil || bedrockResp.Usage.CacheWriteInputTokens != nil {
openAIResp.Usage.PromptTokensDetails = &openai.PromptTokensDetails{}
}
if bedrockResp.Usage.CacheReadInputTokens != nil {
tokenUsage.SetCachedInputTokens(uint32(*bedrockResp.Usage.CacheReadInputTokens)) //nolint:gosec
openAIResp.Usage.PromptTokensDetails.CachedTokens = *bedrockResp.Usage.CacheReadInputTokens
openAIResp.Usage.PromptTokensDetails.CachedTokens = int(*bedrockResp.Usage.CacheReadInputTokens)
}
if bedrockResp.Usage.CacheWriteInputTokens != nil {
tokenUsage.SetCacheCreationInputTokens(uint32(*bedrockResp.Usage.CacheWriteInputTokens)) //nolint:gosec
openAIResp.Usage.PromptTokensDetails.CacheCreationTokens = *bedrockResp.Usage.CacheWriteInputTokens
openAIResp.Usage.PromptTokensDetails.CacheCreationTokens = int(*bedrockResp.Usage.CacheWriteInputTokens)
}
}

Expand Down Expand Up @@ -852,19 +847,24 @@ func (o *openAIToAWSBedrockTranslatorV1ChatCompletion) convertEvent(event *awsbe
if event.Usage == nil {
return chunk, false
}
tokenUsage := metrics.ExtractTokenUsageFromExplicitCaching(event.Usage.InputTokens, event.Usage.OutputTokens,
event.Usage.CacheReadInputTokens, event.Usage.CacheWriteInputTokens)
totalTokens, _ := tokenUsage.TotalTokens()
inputTokens, _ := tokenUsage.InputTokens()
outputTokens, _ := tokenUsage.OutputTokens()
chunk.Usage = &openai.Usage{
TotalTokens: event.Usage.TotalTokens,
PromptTokens: event.Usage.InputTokens,
CompletionTokens: event.Usage.OutputTokens,
TotalTokens: int(totalTokens),
PromptTokens: int(inputTokens),
CompletionTokens: int(outputTokens),
}
if event.Usage.CacheReadInputTokens != nil || event.Usage.CacheWriteInputTokens != nil {
chunk.Usage.PromptTokensDetails = &openai.PromptTokensDetails{}
}
if event.Usage.CacheReadInputTokens != nil {
chunk.Usage.PromptTokensDetails.CachedTokens = *event.Usage.CacheReadInputTokens
chunk.Usage.PromptTokensDetails.CachedTokens = int(*event.Usage.CacheReadInputTokens)
}
if event.Usage.CacheWriteInputTokens != nil {
chunk.Usage.PromptTokensDetails.CacheCreationTokens = *event.Usage.CacheWriteInputTokens
chunk.Usage.PromptTokensDetails.CacheCreationTokens = int(*event.Usage.CacheWriteInputTokens)
}
// messageStart event.
case awsbedrock.ConverseStreamEventTypeMessageStart.String():
Expand Down
19 changes: 10 additions & 9 deletions internal/translator/openai_awsbedrock_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1450,8 +1450,8 @@ func TestOpenAIToAWSBedrockTranslatorV1ChatCompletion_ResponseBody(t *testing.T)
InputTokens: 10,
OutputTokens: 20,
TotalTokens: 30,
CacheReadInputTokens: ptr.To(5),
CacheWriteInputTokens: ptr.To(7),
CacheWriteInputTokens: ptr.To[int64](7),
CacheReadInputTokens: ptr.To[int64](5),
},
Output: &awsbedrock.ConverseOutput{
Message: awsbedrock.Message{
Expand All @@ -1470,8 +1470,8 @@ func TestOpenAIToAWSBedrockTranslatorV1ChatCompletion_ResponseBody(t *testing.T)
Created: openai.JSONUNIXTime(time.Unix(ReleaseDateUnix, 0)),
Object: "chat.completion",
Usage: openai.Usage{
TotalTokens: 30,
PromptTokens: 10,
TotalTokens: 42,
PromptTokens: 22,
CompletionTokens: 20,
PromptTokensDetails: &openai.PromptTokensDetails{
CachedTokens: 5,
Expand Down Expand Up @@ -1902,7 +1902,7 @@ func TestOpenAIToAWSBedrockTranslatorExtractAmazonEventStreamEvents(t *testing.T
strings.Join(texts, ""),
)
require.NotNil(t, usage)
require.Equal(t, 461, usage.TotalTokens)
require.Equal(t, int64(461), usage.TotalTokens)
})
}

Expand All @@ -1921,7 +1921,7 @@ func TestOpenAIToAWSBedrockTranslator_convertEvent(t *testing.T) {
InputTokens: 10,
OutputTokens: 20,
TotalTokens: 30,
CacheReadInputTokens: ptr.To(5),
CacheReadInputTokens: ptr.To[int64](5),
},
},
out: &openai.ChatCompletionResponseChunk{
Expand All @@ -1930,8 +1930,8 @@ func TestOpenAIToAWSBedrockTranslator_convertEvent(t *testing.T) {
Created: openai.JSONUNIXTime(time.Unix(ReleaseDateUnix, 0)), // 0 nanoseconds
Object: "chat.completion.chunk",
Usage: &openai.Usage{
TotalTokens: 30,
PromptTokens: 10,
TotalTokens: 35,
PromptTokens: 15,
CompletionTokens: 20,
PromptTokensDetails: &openai.PromptTokensDetails{
CachedTokens: 5,
Expand Down Expand Up @@ -2016,7 +2016,8 @@ func TestOpenAIToAWSBedrockTranslator_convertEvent(t *testing.T) {
require.False(t, ok)
} else {
// Use require.True and cmp.Equal with the options
require.True(t, cmp.Equal(*tc.out, *chunk, ignoreDynamicFields), "The ChatCompletionResponseChunk structs should be equal ignoring the ID and Created field")
require.True(t, cmp.Equal(*tc.out, *chunk, ignoreDynamicFields),
"The ChatCompletionResponseChunk structs should be equal ignoring the ID and Created field")
}
})
}
Expand Down
6 changes: 3 additions & 3 deletions internal/translator/openai_gcpanthropic.go
Original file line number Diff line number Diff line change
Expand Up @@ -829,11 +829,11 @@ func (o *openAIToGCPAnthropicTranslatorV1ChatCompletion) ResponseBody(_ map[stri
Created: openai.JSONUNIXTime(time.Now()),
}
usage := anthropicResp.Usage
tokenUsage = metrics.ExtractTokenUsageFromAnthropic(
tokenUsage = metrics.ExtractTokenUsageFromExplicitCaching(
usage.InputTokens,
usage.OutputTokens,
usage.CacheReadInputTokens,
usage.CacheCreationInputTokens,
&usage.CacheReadInputTokens,
&usage.CacheCreationInputTokens,
)
inputTokens, _ := tokenUsage.InputTokens()
outputTokens, _ := tokenUsage.OutputTokens()
Expand Down
Loading
Loading