Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions api/v1alpha1/ai_gateway_route.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,8 @@ type AIGatewayRouteSpec struct {
// type: TotalToken
// - metadataKey: llm_cached_input_token
// type: CachedInputToken
// - metadataKey: llm_cache_creation_input_token
// type: CacheCreationInputToken
// ```
// Then, with the following BackendTrafficPolicy of Envoy Gateway, you can have three
// rate limit buckets for each unique x-user-id header value. One bucket is for the input token,
Expand Down
11 changes: 7 additions & 4 deletions api/v1alpha1/shared_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,9 @@ type LLMRequestCost struct {
MetadataKey string `json:"metadataKey"`
// Type specifies the type of the request cost. The default is "OutputToken",
// and it uses "output token" as the cost. The other types are "InputToken", "TotalToken",
// and "CEL".
// "CachedInputToken", "CacheCreationInputToken", and "CEL".
//
// +kubebuilder:validation:Enum=OutputToken;InputToken;CachedInputToken;TotalToken;CEL
// +kubebuilder:validation:Enum=OutputToken;InputToken;CachedInputToken;CacheCreationInputToken;TotalToken;CEL
Type LLMRequestCostType `json:"type"`
// CEL is the CEL expression to calculate the cost of the request.
// The CEL expression must return a signed or unsigned integer. If the
Expand All @@ -113,15 +113,16 @@ type LLMRequestCost struct {
// * model: the model name extracted from the request content. Type: string.
// * backend: the backend name in the form of "name.namespace". Type: string.
// * input_tokens: the number of input tokens. Type: unsigned integer.
// * cached_input_tokens: the number of cached input tokens. Type: unsigned integer.
// * cached_input_tokens: the number of cached read input tokens. Type: unsigned integer.
// * cache_creation_input_tokens: the number of cache creation input tokens. Type: unsigned integer.
// * output_tokens: the number of output tokens. Type: unsigned integer.
// * total_tokens: the total number of tokens. Type: unsigned integer.
//
// For example, the following expressions are valid:
//
// * "model == 'llama' ? input_tokens + output_token * 0.5 : total_tokens"
// * "backend == 'foo.default' ? input_tokens + output_tokens : total_tokens"
// * "backend == 'bar.default' ? (input_tokens - cached_input_tokens) + cached_input_tokens * 0.1 + output_tokens : total_tokens"
// * "backend == 'bar.default' ? (input_tokens - cached_input_tokens) + cached_input_tokens * 0.1 + cache_creation_input_tokens * 1.25 + output_tokens : total_tokens"
// * "input_tokens + output_tokens + total_tokens"
// * "input_tokens * output_tokens"
//
Expand All @@ -137,6 +138,8 @@ const (
LLMRequestCostTypeInputToken LLMRequestCostType = "InputToken"
// LLMRequestCostTypeCachedInputToken is the cost type of the cached input token.
LLMRequestCostTypeCachedInputToken LLMRequestCostType = "CachedInputToken"
// LLMRequestCostTypeCacheCreationInputToken is the cost type of the cache creation input token.
LLMRequestCostTypeCacheCreationInputToken LLMRequestCostType = "CacheCreationInputToken"
// LLMRequestCostTypeOutputToken is the cost type of the output token.
LLMRequestCostTypeOutputToken LLMRequestCostType = "OutputToken"
// LLMRequestCostTypeTotalToken is the cost type of the total token.
Expand Down
2 changes: 2 additions & 0 deletions examples/token_ratelimit/token_ratelimit.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ spec:
type: InputToken
- metadataKey: llm_cached_input_token
type: CachedInputToken
- metadataKey: llm_cache_creation_input_token
type: CacheCreationInputToken
- metadataKey: llm_output_token
type: OutputToken
- metadataKey: llm_total_token
Expand Down
8 changes: 8 additions & 0 deletions internal/apischema/openai/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -1382,6 +1382,8 @@ type PromptTokensDetails struct {
AudioTokens int `json:"audio_tokens,omitzero"`
// Cached tokens present in the prompt.
CachedTokens int `json:"cached_tokens,omitzero"`
// Tokens written to the cache.
CacheCreationTokens int `json:"cache_creation_input_tokens,omitzero"`
}

// ChatCompletionResponseChunk is described in the OpenAI API documentation:
Expand Down Expand Up @@ -2535,6 +2537,9 @@ type ResponseUsageInputTokensDetails struct {
// The number of tokens that were retrieved from the cache.
// [More on prompt caching](https://platform.openai.com/docs/guides/prompt-caching).
CachedTokens int64 `json:"cached_tokens"`

// The number of tokens that were written to the cache.
CacheCreationTokens int64 `json:"cache_creation_input_tokens"`
}

// A detailed breakdown of the output tokens.
Expand All @@ -2548,6 +2553,9 @@ type ResponseTokensDetails struct {
// CachedTokens: Number of cached tokens.
CachedTokens int `json:"cached_tokens,omitempty"` //nolint:tagliatelle //follow openai api

// CacheCreationTokens: Number of tokens that were written to the cache.
CacheCreationTokens int64 `json:"cache_creation_input_tokens"` //nolint:tagliatelle

// ReasoningTokens: Number of reasoning tokens (for reasoning models).
ReasoningTokens int `json:"reasoning_tokens,omitempty"` //nolint:tagliatelle //follow openai api

Expand Down
38 changes: 23 additions & 15 deletions internal/apischema/openai/openai_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1742,26 +1742,30 @@ func TestPromptTokensDetails(t *testing.T) {
{
name: "with text tokens",
details: PromptTokensDetails{
TextTokens: 15,
AudioTokens: 8,
CachedTokens: 384,
TextTokens: 15,
AudioTokens: 8,
CachedTokens: 384,
CacheCreationTokens: 10,
},
expected: `{
"text_tokens": 15,
"audio_tokens": 8,
"cached_tokens": 384
"cached_tokens": 384,
"cache_creation_input_tokens": 10
}`,
},
{
name: "with zero text tokens omitted",
details: PromptTokensDetails{
TextTokens: 0,
AudioTokens: 8,
CachedTokens: 384,
TextTokens: 0,
AudioTokens: 8,
CachedTokens: 384,
CacheCreationTokens: 10,
},
expected: `{
"audio_tokens": 8,
"cached_tokens": 384
"cached_tokens": 384,
"cache_creation_input_tokens": 10
}`,
},
}
Expand Down Expand Up @@ -1818,8 +1822,9 @@ func TestChatCompletionResponseUsage(t *testing.T) {
RejectedPredictionTokens: 0,
},
PromptTokensDetails: &PromptTokensDetails{
AudioTokens: 8,
CachedTokens: 384,
AudioTokens: 8,
CachedTokens: 384,
CacheCreationTokens: 13,
},
},
expected: `{
Expand All @@ -1832,7 +1837,8 @@ func TestChatCompletionResponseUsage(t *testing.T) {
},
"prompt_tokens_details": {
"audio_tokens": 8,
"cached_tokens": 384
"cached_tokens": 384,
"cache_creation_input_tokens": 13
}
}`,
},
Expand All @@ -1850,9 +1856,10 @@ func TestChatCompletionResponseUsage(t *testing.T) {
RejectedPredictionTokens: 0,
},
PromptTokensDetails: &PromptTokensDetails{
TextTokens: 15,
AudioTokens: 8,
CachedTokens: 384,
TextTokens: 15,
AudioTokens: 8,
CachedTokens: 384,
CacheCreationTokens: 21,
},
},
expected: `{
Expand All @@ -1867,7 +1874,8 @@ func TestChatCompletionResponseUsage(t *testing.T) {
"prompt_tokens_details": {
"text_tokens": 15,
"audio_tokens": 8,
"cached_tokens": 384
"cached_tokens": 384,
"cache_creation_input_tokens": 21
}
}`,
},
Expand Down
2 changes: 2 additions & 0 deletions internal/controller/gateway.go
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,8 @@ func (c *GatewayController) reconcileFilterConfigSecret(
fc.Type = filterapi.LLMRequestCostTypeInputToken
case aigv1a1.LLMRequestCostTypeCachedInputToken:
fc.Type = filterapi.LLMRequestCostTypeCachedInputToken
case aigv1a1.LLMRequestCostTypeCacheCreationInputToken:
fc.Type = filterapi.LLMRequestCostTypeCacheCreationInputToken
case aigv1a1.LLMRequestCostTypeOutputToken:
fc.Type = filterapi.LLMRequestCostTypeOutputToken
case aigv1a1.LLMRequestCostTypeTotalToken:
Expand Down
8 changes: 5 additions & 3 deletions internal/controller/gateway_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ func TestGatewayController_reconcileFilterConfigSecret(t *testing.T) {
{MetadataKey: "bar", Type: aigv1a1.LLMRequestCostTypeOutputToken},
{MetadataKey: "baz", Type: aigv1a1.LLMRequestCostTypeTotalToken},
{MetadataKey: "qux", Type: aigv1a1.LLMRequestCostTypeCachedInputToken},
{MetadataKey: "zoo", Type: aigv1a1.LLMRequestCostTypeCacheCreationInputToken},
},
},
},
Expand Down Expand Up @@ -274,13 +275,14 @@ func TestGatewayController_reconcileFilterConfigSecret(t *testing.T) {
var fc filterapi.Config
require.NoError(t, yaml.Unmarshal([]byte(configStr), &fc))
require.Equal(t, "dev", fc.Version)
require.Len(t, fc.LLMRequestCosts, 5)
require.Len(t, fc.LLMRequestCosts, 6)
require.Equal(t, filterapi.LLMRequestCostTypeInputToken, fc.LLMRequestCosts[0].Type)
require.Equal(t, filterapi.LLMRequestCostTypeOutputToken, fc.LLMRequestCosts[1].Type)
require.Equal(t, filterapi.LLMRequestCostTypeTotalToken, fc.LLMRequestCosts[2].Type)
require.Equal(t, filterapi.LLMRequestCostTypeCachedInputToken, fc.LLMRequestCosts[3].Type)
require.Equal(t, filterapi.LLMRequestCostTypeCEL, fc.LLMRequestCosts[4].Type)
require.Equal(t, `backend == 'foo.default' ? input_tokens + output_tokens : total_tokens`, fc.LLMRequestCosts[4].CEL)
require.Equal(t, filterapi.LLMRequestCostTypeCacheCreationInputToken, fc.LLMRequestCosts[4].Type)
require.Equal(t, filterapi.LLMRequestCostTypeCEL, fc.LLMRequestCosts[5].Type)
require.Equal(t, `backend == 'foo.default' ? input_tokens + output_tokens : total_tokens`, fc.LLMRequestCosts[5].CEL)
require.Len(t, fc.Models, 1)
require.Equal(t, "mymodel", fc.Models[0].Name)

Expand Down
27 changes: 16 additions & 11 deletions internal/extproc/mocks_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,16 +171,17 @@ func (m *mockMetricsFactory) NewMetrics() metrics.Metrics {

// mockMetrics implements [metrics.Metrics] for testing.
type mockMetrics struct {
requestStart time.Time
originalModel string
requestModel string
responseModel string
backend string
requestSuccessCount int
requestErrorCount int
inputTokenCount int
cachedInputTokenCount int
outputTokenCount int
requestStart time.Time
originalModel string
requestModel string
responseModel string
backend string
requestSuccessCount int
requestErrorCount int
inputTokenCount int
cachedInputTokenCount int
cacheCreationInputTokenCount int
outputTokenCount int
// streamingOutputTokens tracks the cumulative output tokens recorded via RecordTokenLatency.
streamingOutputTokens int
timeToFirstToken float64
Expand Down Expand Up @@ -218,6 +219,9 @@ func (m *mockMetrics) RecordTokenUsage(_ context.Context, usage metrics.TokenUsa
if cachedInput, ok := usage.CachedInputTokens(); ok {
m.cachedInputTokenCount += int(cachedInput)
}
if cacheCreationInput, ok := usage.CacheCreationInputTokens(); ok {
m.cacheCreationInputTokenCount += int(cacheCreationInput)
}
if output, ok := usage.OutputTokens(); ok {
m.outputTokenCount += int(output)
}
Expand Down Expand Up @@ -278,9 +282,10 @@ func (m *mockMetrics) RequireRequestFailure(t *testing.T) {
require.Equal(t, 1, m.requestErrorCount)
}

func (m *mockMetrics) RequireTokensRecorded(t *testing.T, expectedInput, expectedCachedInput, expectedOutput int) {
func (m *mockMetrics) RequireTokensRecorded(t *testing.T, expectedInput, expectedCachedInput, expectedWriteCachedInput, expectedOutput int) {
require.Equal(t, expectedInput, m.inputTokenCount)
require.Equal(t, expectedCachedInput, m.cachedInputTokenCount)
require.Equal(t, expectedWriteCachedInput, m.cacheCreationInputTokenCount)
require.Equal(t, expectedOutput, m.outputTokenCount)
}

Expand Down
4 changes: 4 additions & 0 deletions internal/extproc/processor_impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -533,13 +533,16 @@ func buildDynamicMetadata(config *filterapi.RuntimeConfig, costs *metrics.TokenU
cost, _ = costs.InputTokens()
case filterapi.LLMRequestCostTypeCachedInputToken:
cost, _ = costs.CachedInputTokens()
case filterapi.LLMRequestCostTypeCacheCreationInputToken:
cost, _ = costs.CacheCreationInputTokens()
case filterapi.LLMRequestCostTypeOutputToken:
cost, _ = costs.OutputTokens()
case filterapi.LLMRequestCostTypeTotalToken:
cost, _ = costs.TotalTokens()
case filterapi.LLMRequestCostTypeCEL:
in, _ := costs.InputTokens()
cachedIn, _ := costs.CachedInputTokens()
cacheCreation, _ := costs.CacheCreationInputTokens()
out, _ := costs.OutputTokens()
total, _ := costs.TotalTokens()
costU64, err := llmcostcel.EvaluateProgram(
Expand All @@ -548,6 +551,7 @@ func buildDynamicMetadata(config *filterapi.RuntimeConfig, costs *metrics.TokenU
backendName,
in,
cachedIn,
cacheCreation,
out,
total,
)
Expand Down
7 changes: 7 additions & 0 deletions internal/extproc/processor_impl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,7 @@ func Test_chatCompletionProcessorUpstreamFilter_ProcessResponseBody(t *testing.T
mt.retUsedToken.SetOutputTokens(123)
mt.retUsedToken.SetInputTokens(1)
mt.retUsedToken.SetCachedInputTokens(1)
mt.retUsedToken.SetCacheCreationInputTokens(3)

celProgInt, err := llmcostcel.NewProgram("54321")
require.NoError(t, err)
Expand All @@ -274,6 +275,7 @@ func Test_chatCompletionProcessorUpstreamFilter_ProcessResponseBody(t *testing.T
{LLMRequestCost: &filterapi.LLMRequestCost{Type: filterapi.LLMRequestCostTypeOutputToken, MetadataKey: "output_token_usage"}},
{LLMRequestCost: &filterapi.LLMRequestCost{Type: filterapi.LLMRequestCostTypeInputToken, MetadataKey: "input_token_usage"}},
{LLMRequestCost: &filterapi.LLMRequestCost{Type: filterapi.LLMRequestCostTypeCachedInputToken, MetadataKey: "cached_input_token_usage"}},
{LLMRequestCost: &filterapi.LLMRequestCost{Type: filterapi.LLMRequestCostTypeCacheCreationInputToken, MetadataKey: "cache_creation_input_token_usage"}},
{
CELProg: celProgInt,
LLMRequestCost: &filterapi.LLMRequestCost{Type: filterapi.LLMRequestCostTypeCEL, MetadataKey: "cel_int"},
Expand Down Expand Up @@ -309,6 +311,8 @@ func Test_chatCompletionProcessorUpstreamFilter_ProcessResponseBody(t *testing.T
GetStructValue().Fields["input_token_usage"].GetNumberValue())
require.Equal(t, float64(1), md.Fields[internalapi.AIGatewayFilterMetadataNamespace].
GetStructValue().Fields["cached_input_token_usage"].GetNumberValue())
require.Equal(t, float64(3), md.Fields[internalapi.AIGatewayFilterMetadataNamespace].
GetStructValue().Fields["cache_creation_input_token_usage"].GetNumberValue())
require.Equal(t, float64(54321), md.Fields[internalapi.AIGatewayFilterMetadataNamespace].
GetStructValue().Fields["cel_int"].GetNumberValue())
require.Equal(t, float64(9999), md.Fields[internalapi.AIGatewayFilterMetadataNamespace].
Expand Down Expand Up @@ -371,6 +375,7 @@ func Test_chatCompletionProcessorUpstreamFilter_ProcessResponseBody(t *testing.T
mt.expResponseBody = final
mt.retUsedToken.SetInputTokens(5)
mt.retUsedToken.SetCachedInputTokens(3)
mt.retUsedToken.SetCacheCreationInputTokens(21)
mt.retUsedToken.SetOutputTokens(138)
mt.retUsedToken.SetTotalTokens(143)
_, err = p.ProcessResponseBody(t.Context(), final)
Expand All @@ -379,6 +384,8 @@ func Test_chatCompletionProcessorUpstreamFilter_ProcessResponseBody(t *testing.T
require.Equal(t, 5, mm.inputTokenCount)
require.Equal(t, 138, mm.outputTokenCount)
require.Equal(t, 138, mm.streamingOutputTokens) // accumulated output tokens from stream
require.Equal(t, 3, mm.cachedInputTokenCount)
require.Equal(t, 21, mm.cacheCreationInputTokenCount)
})
}

Expand Down
4 changes: 3 additions & 1 deletion internal/filterapi/filterconfig.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,10 @@ const (
LLMRequestCostTypeOutputToken LLMRequestCostType = "OutputToken"
// LLMRequestCostTypeInputToken specifies that the request cost is calculated from the input token.
LLMRequestCostTypeInputToken LLMRequestCostType = "InputToken"
// LLMRequestCostTypeCachedInputToken specifies that the request cost is calculated from the cached input token.
// LLMRequestCostTypeCachedInputToken specifies that the request cost is calculated from the cached read input token.
LLMRequestCostTypeCachedInputToken LLMRequestCostType = "CachedInputToken"
// LLMRequestCostTypeCacheCreationInputToken specifies that the request cost is calculated from the cache creation input token.
LLMRequestCostTypeCacheCreationInputToken LLMRequestCostType = "CacheCreationInputToken"
// LLMRequestCostTypeTotalToken specifies that the request cost is calculated from the total token.
LLMRequestCostTypeTotalToken LLMRequestCostType = "TotalToken"
// LLMRequestCostTypeCEL specifies that the request cost is calculated from the CEL expression.
Expand Down
2 changes: 1 addition & 1 deletion internal/filterapi/runtime_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ func TestServer_LoadConfig(t *testing.T) {
require.Equal(t, "1 + 1", rc.RequestCosts[1].CEL)
prog := rc.RequestCosts[1].CELProg
require.NotNil(t, prog)
val, err := llmcostcel.EvaluateProgram(prog, "", "", 1, 1, 1, 1)
val, err := llmcostcel.EvaluateProgram(prog, "", "", 1, 1, 1, 1, 1)
require.NoError(t, err)
require.Equal(t, uint64(2), val)
require.Equal(t, config.Models, rc.DeclaredModels)
Expand Down
Loading