diff --git a/core/bifrost.go b/core/bifrost.go index 21441b517d..ffc31d3feb 100644 --- a/core/bifrost.go +++ b/core/bifrost.go @@ -1026,9 +1026,10 @@ func (bifrost *Bifrost) OCRRequest(ctx *schemas.BifrostContext, req *schemas.Bif Message: "document type not provided for ocr request", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.OCRRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.OCRRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } @@ -1039,9 +1040,10 @@ func (bifrost *Bifrost) OCRRequest(ctx *schemas.BifrostContext, req *schemas.Bif Message: "document_url not provided for document_url type ocr request", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.OCRRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.OCRRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } @@ -1052,9 +1054,10 @@ func (bifrost *Bifrost) OCRRequest(ctx *schemas.BifrostContext, req *schemas.Bif Message: "image_url not provided for image_url type ocr request", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.OCRRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.OCRRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } diff --git a/core/bifrost_test.go b/core/bifrost_test.go index 0c020ba04e..0dd46a62b8 100644 --- a/core/bifrost_test.go +++ b/core/bifrost_test.go @@ -605,8 +605,8 @@ func TestHandleProviderRequest_OCROperationNotAllowed(t *testing.T) { if err.ExtraFields.RequestType != schemas.OCRRequest { t.Fatalf("expected OCR request type, got %q", err.ExtraFields.RequestType) } - if err.ExtraFields.ModelRequested != "custom-mistral/mistral-ocr-latest" { - t.Fatalf("expected model to be preserved, got %q", err.ExtraFields.ModelRequested) + if 
err.ExtraFields.OriginalModelRequested != "custom-mistral/mistral-ocr-latest" { + t.Fatalf("expected model to be preserved, got %q", err.ExtraFields.OriginalModelRequested) } } diff --git a/core/providers/azure/azure.go b/core/providers/azure/azure.go index 451f9883ca..cecc825887 100644 --- a/core/providers/azure/azure.go +++ b/core/providers/azure/azure.go @@ -2699,12 +2699,8 @@ func (provider *AzureProvider) Passthrough( key schemas.Key, req *schemas.BifrostPassthroughRequest, ) (*schemas.BifrostPassthroughResponse, *schemas.BifrostError) { - if key.AzureKeyConfig == nil { - return nil, providerUtils.NewConfigurationError("azure key config not set", provider.GetProviderKey()) - } - if key.AzureKeyConfig.Endpoint.GetValue() == "" { - return nil, providerUtils.NewConfigurationError("endpoint not set", provider.GetProviderKey()) + return nil, providerUtils.NewConfigurationError("endpoint not set") } url := provider.buildPassthroughURL(key, req.Path, req.RawQuery) @@ -2743,7 +2739,7 @@ func (provider *AzureProvider) Passthrough( body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to decode response body", err, provider.GetProviderKey()) + return nil, providerUtils.NewBifrostOperationError("failed to decode response body", err) } // Remove wire-level encoding headers after decoding; downstream should recalculate them for the buffered body. 
@@ -2759,9 +2755,6 @@ func (provider *AzureProvider) Passthrough( Body: body, } - bifrostResponse.ExtraFields.Provider = provider.GetProviderKey() - bifrostResponse.ExtraFields.ModelRequested = req.Model - bifrostResponse.ExtraFields.RequestType = schemas.PassthroughRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { @@ -2779,12 +2772,8 @@ func (provider *AzureProvider) PassthroughStream( key schemas.Key, req *schemas.BifrostPassthroughRequest, ) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { - if key.AzureKeyConfig == nil { - return nil, providerUtils.NewConfigurationError("azure key config not set", provider.GetProviderKey()) - } - if key.AzureKeyConfig.Endpoint.GetValue() == "" { - return nil, providerUtils.NewConfigurationError("endpoint not set", provider.GetProviderKey()) + return nil, providerUtils.NewConfigurationError("endpoint not set") } url := provider.buildPassthroughURL(key, req.Path, req.RawQuery) @@ -2833,9 +2822,9 @@ func (provider *AzureProvider) PassthroughStream( } } if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) { - return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, provider.GetProviderKey()) + return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err) } - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, provider.GetProviderKey()) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err) } headers := providerUtils.ExtractProviderResponseHeaders(resp) @@ -2846,18 +2835,13 @@ func (provider *AzureProvider) PassthroughStream( return nil, providerUtils.NewBifrostOperationError( "provider returned an empty stream body", fmt.Errorf("provider returned an empty stream body"), - provider.GetProviderKey(), ) } bodyStream, stopIdleTimeout := 
providerUtils.NewIdleTimeoutReader(rawBodyStream, rawBodyStream, providerUtils.GetStreamIdleTimeout(ctx)) stopCancellation := providerUtils.SetupStreamCancellation(ctx, rawBodyStream, provider.logger) - extraFields := schemas.BifrostResponseExtraFields{ - Provider: provider.GetProviderKey(), - ModelRequested: req.Model, - RequestType: schemas.PassthroughStreamRequest, - } + extraFields := schemas.BifrostResponseExtraFields{} if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { providerUtils.ParseAndSetRawRequestIfJSON(fasthttpReq, &extraFields) } @@ -2867,9 +2851,9 @@ func (provider *AzureProvider) PassthroughStream( go func() { defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, ch, provider.GetProviderKey(), req.Model, schemas.PassthroughStreamRequest, provider.logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, ch, provider.logger) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, ch, provider.GetProviderKey(), req.Model, schemas.PassthroughStreamRequest, provider.logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, ch, provider.logger) } close(ch) }() @@ -2918,7 +2902,7 @@ func (provider *AzureProvider) PassthroughStream( } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) extraFields.Latency = time.Since(startTime).Milliseconds() - providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, ch, schemas.PassthroughStreamRequest, provider.GetProviderKey(), req.Model, provider.logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, ch, provider.logger) return } } diff --git a/core/providers/mistral/custom_provider_test.go b/core/providers/mistral/custom_provider_test.go index 0d8e283f51..d83639e917 100644 --- a/core/providers/mistral/custom_provider_test.go +++ b/core/providers/mistral/custom_provider_test.go @@ -25,7 +25,7 @@ func 
TestParseMistralError_UsesExportedConverterMetadata(t *testing.T) { resp.SetStatusCode(http.StatusBadRequest) resp.SetBodyString(`{"message":"invalid request","type":"invalid_request_error","code":"bad_request"}`) - bifrostErr := ParseMistralError(resp, schemas.OCRRequest, customMistralProviderName, "mistral-ocr-latest") + bifrostErr := ParseMistralError(resp) require.NotNil(t, bifrostErr) require.NotNil(t, bifrostErr.Error) @@ -34,7 +34,7 @@ func TestParseMistralError_UsesExportedConverterMetadata(t *testing.T) { assert.Equal(t, schemas.Ptr("bad_request"), bifrostErr.Error.Code) assert.Equal(t, customMistralProviderName, bifrostErr.ExtraFields.Provider) assert.Equal(t, schemas.OCRRequest, bifrostErr.ExtraFields.RequestType) - assert.Equal(t, "mistral-ocr-latest", bifrostErr.ExtraFields.ModelRequested) + assert.Equal(t, "mistral-ocr-latest", bifrostErr.ExtraFields.OriginalModelRequested) } func TestMistralProvider_CustomAliasChatStreamUsesBaseCompatibilityAndAliasMetadata(t *testing.T) { @@ -156,5 +156,5 @@ func TestMistralProvider_CustomAliasEmbeddingReportsAliasMetadata(t *testing.T) require.NotNil(t, response) assert.Equal(t, customMistralProviderName, response.ExtraFields.Provider) - assert.Equal(t, "codestral-embed", response.ExtraFields.ModelRequested) + assert.Equal(t, "codestral-embed", response.ExtraFields.OriginalModelRequested) } diff --git a/core/providers/mistral/errors.go b/core/providers/mistral/errors.go index 2ae260eb9a..cbbd40a560 100644 --- a/core/providers/mistral/errors.go +++ b/core/providers/mistral/errors.go @@ -19,7 +19,7 @@ type MistralErrorResponse struct { } // ParseMistralError parses Mistral-specific error responses. 
-func ParseMistralError(resp *fasthttp.Response, requestType schemas.RequestType, providerName schemas.ModelProvider, model string) *schemas.BifrostError { +func ParseMistralError(resp *fasthttp.Response) *schemas.BifrostError { var errorResp MistralErrorResponse bifrostErr := providerUtils.HandleProviderAPIError(resp, &errorResp) if bifrostErr == nil { @@ -67,9 +67,5 @@ func ParseMistralError(resp *fasthttp.Response, requestType schemas.RequestType, } } - bifrostErr.ExtraFields.Provider = providerName - bifrostErr.ExtraFields.ModelRequested = model - bifrostErr.ExtraFields.RequestType = requestType - return bifrostErr } diff --git a/core/providers/mistral/mistral.go b/core/providers/mistral/mistral.go index 9befe35fee..bcf1359552 100644 --- a/core/providers/mistral/mistral.go +++ b/core/providers/mistral/mistral.go @@ -101,7 +101,7 @@ func (provider *MistralProvider) listModelsByKey(ctx *schemas.BifrostContext, ke // Handle error response if resp.StatusCode() != fasthttp.StatusOK { - bifrostErr := ParseMistralError(resp, schemas.ListModelsRequest, providerName, "") + bifrostErr := ParseMistralError(resp) return nil, bifrostErr } @@ -264,20 +264,20 @@ func (provider *MistralProvider) OCR(ctx *schemas.BifrostContext, key schemas.Ke // Convert Bifrost request to Mistral format mistralReq := ToMistralOCRRequest(request) if mistralReq == nil { - return nil, providerUtils.NewBifrostOperationError("ocr request input is not provided", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("ocr request input is not provided", nil) } // Marshal request body requestBody, err := sonic.Marshal(mistralReq) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } // Merge extra params into JSON payload if len(mistralReq.ExtraParams) > 0 { requestBody, err = 
providerUtils.MergeExtraParamsIntoJSON(requestBody, mistralReq.ExtraParams) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } } @@ -314,12 +314,12 @@ func (provider *MistralProvider) OCR(ctx *schemas.BifrostContext, key schemas.Ke // Handle error response if resp.StatusCode() != fasthttp.StatusOK { provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - return nil, ParseMistralError(resp, schemas.OCRRequest, providerName, request.Model) + return nil, ParseMistralError(resp) } responseBody, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } // Check for empty response @@ -347,20 +347,19 @@ func (provider *MistralProvider) OCR(ctx *schemas.BifrostContext, key schemas.Ke }, } } - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err) } // Convert to Bifrost format response := mistralResponse.ToBifrostOCRResponse() if response == nil { - return nil, providerUtils.NewBifrostOperationError("failed to convert ocr response", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to convert ocr response", nil) } // Set extra fields response.ExtraFields.Latency = latency.Milliseconds() response.ExtraFields.RequestType = schemas.OCRRequest response.ExtraFields.Provider = providerName - response.ExtraFields.ModelRequested = request.Model // Set raw response if enabled if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) { @@ -424,7 +423,7 @@ func (provider *MistralProvider) 
Transcription(ctx *schemas.BifrostContext, key // Handle error response if resp.StatusCode() != fasthttp.StatusOK { provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - return nil, ParseMistralError(resp, schemas.TranscriptionRequest, providerName, request.Model) + return nil, ParseMistralError(resp) } responseBody, err := providerUtils.CheckAndDecodeBody(resp) @@ -556,7 +555,7 @@ func (provider *MistralProvider) TranscriptionStream(ctx *schemas.BifrostContext if resp.StatusCode() != fasthttp.StatusOK { defer providerUtils.ReleaseStreamingResponse(resp) provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - return nil, ParseMistralError(resp, schemas.TranscriptionStreamRequest, providerName, request.Model) + return nil, ParseMistralError(resp) } // Large payload streaming passthrough — pipe raw upstream SSE to client diff --git a/core/providers/mistral/ocr_test.go b/core/providers/mistral/ocr_test.go index c88bca5f48..92978f2191 100644 --- a/core/providers/mistral/ocr_test.go +++ b/core/providers/mistral/ocr_test.go @@ -438,7 +438,7 @@ func TestOCRWithMockServer(t *testing.T) { assert.Equal(t, 2, resp.UsageInfo.PagesProcessed) assert.Equal(t, schemas.OCRRequest, resp.ExtraFields.RequestType) assert.Equal(t, schemas.Mistral, resp.ExtraFields.Provider) - assert.Equal(t, "mistral-ocr-latest", resp.ExtraFields.ModelRequested) + assert.Equal(t, "mistral-ocr-latest", resp.ExtraFields.OriginalModelRequested) }, }, { @@ -505,7 +505,7 @@ func TestOCRWithMockServer(t *testing.T) { assert.Equal(t, "internal_error", *err.Error.Code) assert.Equal(t, schemas.Mistral, err.ExtraFields.Provider) assert.Equal(t, schemas.OCRRequest, err.ExtraFields.RequestType) - assert.Equal(t, "mistral-ocr-latest", err.ExtraFields.ModelRequested) + assert.Equal(t, "mistral-ocr-latest", err.ExtraFields.OriginalModelRequested) }, }, { diff --git a/core/schemas/bifrost.go b/core/schemas/bifrost.go index b5d1d242dc..bed848fb21 100644 
--- a/core/schemas/bifrost.go +++ b/core/schemas/bifrost.go @@ -212,7 +212,7 @@ const ( BifrostContextKeyStreamStartTime BifrostContextKey = "bifrost-stream-start-time" // time.Time (start time for streaming TTFT calculation - set by bifrost) BifrostContextKeyTracer BifrostContextKey = "bifrost-tracer" // Tracer (tracer instance for completing deferred spans - set by bifrost) BifrostContextKeyDeferTraceCompletion BifrostContextKey = "bifrost-defer-trace-completion" // bool (signals trace completion should be deferred for streaming - set by streaming handlers) - BifrostContextKeyTraceCompleter BifrostContextKey = "bifrost-trace-completer" // func() (callback to complete trace after streaming - set by tracing middleware) + BifrostContextKeyTraceCompleter BifrostContextKey = "bifrost-trace-completer" // func([]PluginLogEntry) (callback to complete trace after streaming - accepts transport logs so it never reads ctx - set by tracing middleware) BifrostContextKeyPostHookSpanFinalizer BifrostContextKey = "bifrost-posthook-span-finalizer" // func(context.Context) (callback to finalize post-hook spans after streaming - set by bifrost) BifrostContextKeyAccumulatorID BifrostContextKey = "bifrost-accumulator-id" // string (ID for streaming accumulator lookup - set by tracer for accumulator operations) BifrostContextKeyMCPUserSession BifrostContextKey = "bifrost-mcp-user-session" // string (per-user OAuth session token, automatically generated by bifrost) @@ -233,7 +233,7 @@ const ( BifrostContextKeyPromptStreamRequest BifrostContextKey = "bifrost-prompt-stream-request" // bool (set by prompts HTTP plugin when prompt version model_params.stream is true and body omitted stream) BifrostContextKeyRoutingEngineLogs BifrostContextKey = "bifrost-routing-engine-logs" // []RoutingEngineLogEntry (set by bifrost - DO NOT SET THIS MANUALLY) - list of routing engine log entries BifrostContextKeyTransportPluginLogs BifrostContextKey = "bifrost-transport-plugin-logs" // []PluginLogEntry 
(transport-layer plugin logs accumulated during HTTP transport hooks) - BifrostContextKeyTransportPostHookCompleter BifrostContextKey = "bifrost-transport-posthook-completer" // func() (callback to run HTTPTransportPostHook after streaming - set by transport interceptor middleware) + BifrostContextKeyTransportPostHookCompleter BifrostContextKey = "bifrost-transport-posthook-completer" // *atomic.Value holding func() ([]PluginLogEntry, error) — set by streaming handler, populated by transport interceptor middleware BifrostContextKeySkipPluginPipeline BifrostContextKey = "bifrost-skip-plugin-pipeline" // bool - skip plugin pipeline for the request BifrostContextKeyParentRequestID BifrostContextKey = "bifrost-parent-request-id" // string (parent linkage for grouped request logs like realtime turns) BifrostContextKeyRealtimeSessionID BifrostContextKey = "bifrost-realtime-session-id" // string diff --git a/core/schemas/chatcompletions.go b/core/schemas/chatcompletions.go index b8e1590da2..6a107e3902 100644 --- a/core/schemas/chatcompletions.go +++ b/core/schemas/chatcompletions.go @@ -80,7 +80,8 @@ func (cr *BifrostChatResponse) ToTextCompletionResponse() *BifrostTextCompletion RequestType: TextCompletionRequest, ChunkIndex: cr.ExtraFields.ChunkIndex, Provider: cr.ExtraFields.Provider, - ModelRequested: cr.ExtraFields.ModelRequested, + OriginalModelRequested: cr.ExtraFields.OriginalModelRequested, + ResolvedModelUsed: cr.ExtraFields.ResolvedModelUsed, Latency: cr.ExtraFields.Latency, RawResponse: cr.ExtraFields.RawResponse, CacheDebug: cr.ExtraFields.CacheDebug, @@ -113,7 +114,8 @@ func (cr *BifrostChatResponse) ToTextCompletionResponse() *BifrostTextCompletion RequestType: TextCompletionRequest, ChunkIndex: cr.ExtraFields.ChunkIndex, Provider: cr.ExtraFields.Provider, - ModelRequested: cr.ExtraFields.ModelRequested, + OriginalModelRequested: cr.ExtraFields.OriginalModelRequested, + ResolvedModelUsed: cr.ExtraFields.ResolvedModelUsed, Latency: cr.ExtraFields.Latency, 
RawResponse: cr.ExtraFields.RawResponse, CacheDebug: cr.ExtraFields.CacheDebug, @@ -149,7 +151,8 @@ func (cr *BifrostChatResponse) ToTextCompletionResponse() *BifrostTextCompletion RequestType: TextCompletionRequest, ChunkIndex: cr.ExtraFields.ChunkIndex, Provider: cr.ExtraFields.Provider, - ModelRequested: cr.ExtraFields.ModelRequested, + OriginalModelRequested: cr.ExtraFields.OriginalModelRequested, + ResolvedModelUsed: cr.ExtraFields.ResolvedModelUsed, Latency: cr.ExtraFields.Latency, RawResponse: cr.ExtraFields.RawResponse, CacheDebug: cr.ExtraFields.CacheDebug, @@ -166,13 +169,15 @@ func (cr *BifrostChatResponse) ToTextCompletionResponse() *BifrostTextCompletion SystemFingerprint: cr.SystemFingerprint, Usage: cr.Usage, ExtraFields: BifrostResponseExtraFields{ - RequestType: TextCompletionRequest, - ChunkIndex: cr.ExtraFields.ChunkIndex, - Provider: cr.ExtraFields.Provider, - ModelRequested: cr.ExtraFields.ModelRequested, - Latency: cr.ExtraFields.Latency, - RawResponse: cr.ExtraFields.RawResponse, - CacheDebug: cr.ExtraFields.CacheDebug, + RequestType: TextCompletionRequest, + ChunkIndex: cr.ExtraFields.ChunkIndex, + Provider: cr.ExtraFields.Provider, + OriginalModelRequested: cr.ExtraFields.OriginalModelRequested, + ResolvedModelUsed: cr.ExtraFields.ResolvedModelUsed, + Latency: cr.ExtraFields.Latency, + RawResponse: cr.ExtraFields.RawResponse, + CacheDebug: cr.ExtraFields.CacheDebug, + ProviderResponseHeaders: cr.ExtraFields.ProviderResponseHeaders, }, } } diff --git a/framework/modelcatalog/config.go b/framework/modelcatalog/config.go index f7445186e7..235053bced 100644 --- a/framework/modelcatalog/config.go +++ b/framework/modelcatalog/config.go @@ -5,8 +5,8 @@ import ( ) const ( - DefaultPricingSyncInterval = 24 * time.Hour - MinimumPricingSyncIntervalSec = int64(3600) + DefaultSyncInterval = 24 * time.Hour + MinimumSyncIntervalSec = int64(3600) // syncWorkerTickerPeriod is the fixed interval at which the background sync worker // wakes up to check 
whether a sync is due. This is independent of pricingSyncInterval — diff --git a/framework/modelcatalog/main.go b/framework/modelcatalog/main.go index 5d9577dca8..b42f50397b 100644 --- a/framework/modelcatalog/main.go +++ b/framework/modelcatalog/main.go @@ -9,20 +9,12 @@ import ( "sync" "time" - "github.com/bytedance/sonic" providerUtils "github.com/maximhq/bifrost/core/providers/utils" "github.com/maximhq/bifrost/core/schemas" "github.com/maximhq/bifrost/framework/configstore" configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables" ) -// Default sync interval and config key -const ( - TokenTierAbove272K = 272000 - TokenTierAbove200K = 200000 - TokenTierAbove128K = 128000 -) - type ModelCatalog struct { configStore configstore.ConfigStore distributedLockManager *configstore.DistributedLockManager @@ -70,145 +62,6 @@ type ModelCatalog struct { syncCancel context.CancelFunc } -// PricingEntry represents a single model's pricing information. -// Field names and JSON tags match the datasheet schema exactly. 
-type PricingEntry struct { - BaseModel string `json:"base_model,omitempty"` - Provider string `json:"provider"` - Mode string `json:"mode"` - - ContextLength *int `json:"context_length,omitempty"` - MaxInputTokens *int `json:"max_input_tokens,omitempty"` - MaxOutputTokens *int `json:"max_output_tokens,omitempty"` - Architecture *schemas.Architecture `json:"architecture,omitempty"` - - // Costs - Text - InputCostPerToken float64 `json:"input_cost_per_token"` - OutputCostPerToken float64 `json:"output_cost_per_token"` - InputCostPerTokenBatches *float64 `json:"input_cost_per_token_batches,omitempty"` - OutputCostPerTokenBatches *float64 `json:"output_cost_per_token_batches,omitempty"` - InputCostPerTokenPriority *float64 `json:"input_cost_per_token_priority,omitempty"` - OutputCostPerTokenPriority *float64 `json:"output_cost_per_token_priority,omitempty"` - InputCostPerTokenFlex *float64 `json:"input_cost_per_token_flex,omitempty"` - OutputCostPerTokenFlex *float64 `json:"output_cost_per_token_flex,omitempty"` - InputCostPerCharacter *float64 `json:"input_cost_per_character,omitempty"` - // Costs - 128k Tier - InputCostPerTokenAbove128kTokens *float64 `json:"input_cost_per_token_above_128k_tokens,omitempty"` - InputCostPerImageAbove128kTokens *float64 `json:"input_cost_per_image_above_128k_tokens,omitempty"` - InputCostPerVideoPerSecondAbove128kTokens *float64 `json:"input_cost_per_video_per_second_above_128k_tokens,omitempty"` - InputCostPerAudioPerSecondAbove128kTokens *float64 `json:"input_cost_per_audio_per_second_above_128k_tokens,omitempty"` - OutputCostPerTokenAbove128kTokens *float64 `json:"output_cost_per_token_above_128k_tokens,omitempty"` - // Costs - 200k Tier - InputCostPerTokenAbove200kTokens *float64 `json:"input_cost_per_token_above_200k_tokens,omitempty"` - InputCostPerTokenAbove200kTokensPriority *float64 `json:"input_cost_per_token_above_200k_tokens_priority,omitempty"` - OutputCostPerTokenAbove200kTokens *float64 
`json:"output_cost_per_token_above_200k_tokens,omitempty"` - OutputCostPerTokenAbove200kTokensPriority *float64 `json:"output_cost_per_token_above_200k_tokens_priority,omitempty"` - // Costs - 272k Tier - InputCostPerTokenAbove272kTokens *float64 `json:"input_cost_per_token_above_272k_tokens,omitempty"` - InputCostPerTokenAbove272kTokensPriority *float64 `json:"input_cost_per_token_above_272k_tokens_priority,omitempty"` - OutputCostPerTokenAbove272kTokens *float64 `json:"output_cost_per_token_above_272k_tokens,omitempty"` - OutputCostPerTokenAbove272kTokensPriority *float64 `json:"output_cost_per_token_above_272k_tokens_priority,omitempty"` - - // Costs - Cache - CacheCreationInputTokenCost *float64 `json:"cache_creation_input_token_cost,omitempty"` - CacheReadInputTokenCost *float64 `json:"cache_read_input_token_cost,omitempty"` - CacheCreationInputTokenCostAbove200kTokens *float64 `json:"cache_creation_input_token_cost_above_200k_tokens,omitempty"` - CacheReadInputTokenCostAbove200kTokens *float64 `json:"cache_read_input_token_cost_above_200k_tokens,omitempty"` - CacheReadInputTokenCostAbove200kTokensPriority *float64 `json:"cache_read_input_token_cost_above_200k_tokens_priority,omitempty"` - CacheCreationInputTokenCostAbove1hr *float64 `json:"cache_creation_input_token_cost_above_1hr,omitempty"` - CacheCreationInputTokenCostAbove1hrAbove200kTokens *float64 `json:"cache_creation_input_token_cost_above_1hr_above_200k_tokens,omitempty"` - CacheCreationInputAudioTokenCost *float64 `json:"cache_creation_input_audio_token_cost,omitempty"` - CacheReadInputTokenCostPriority *float64 `json:"cache_read_input_token_cost_priority,omitempty"` - CacheReadInputTokenCostFlex *float64 `json:"cache_read_input_token_cost_flex,omitempty"` - CacheReadInputImageTokenCost *float64 `json:"cache_read_input_image_token_cost,omitempty"` - CacheReadInputTokenCostAbove272kTokens *float64 `json:"cache_read_input_token_cost_above_272k_tokens,omitempty"` - 
CacheReadInputTokenCostAbove272kTokensPriority *float64 `json:"cache_read_input_token_cost_above_272k_tokens_priority,omitempty"` - - // Costs - Image - InputCostPerImage *float64 `json:"input_cost_per_image,omitempty"` - InputCostPerPixel *float64 `json:"input_cost_per_pixel,omitempty"` - OutputCostPerImage *float64 `json:"output_cost_per_image,omitempty"` - OutputCostPerPixel *float64 `json:"output_cost_per_pixel,omitempty"` - OutputCostPerImagePremiumImage *float64 `json:"output_cost_per_image_premium_image,omitempty"` - OutputCostPerImageAbove512x512Pixels *float64 `json:"output_cost_per_image_above_512_and_512_pixels,omitempty"` - OutputCostPerImageAbove512x512PixelsPremium *float64 `json:"output_cost_per_image_above_512_and_512_pixels_and_premium_image,omitempty"` - OutputCostPerImageAbove1024x1024Pixels *float64 `json:"output_cost_per_image_above_1024_and_1024_pixels,omitempty"` - OutputCostPerImageAbove1024x1024PixelsPremium *float64 `json:"output_cost_per_image_above_1024_and_1024_pixels_and_premium_image,omitempty"` - OutputCostPerImageAbove2048x2048Pixels *float64 `json:"output_cost_per_image_above_2048_and_2048_pixels,omitempty"` - OutputCostPerImageAbove4096x4096Pixels *float64 `json:"output_cost_per_image_above_4096_and_4096_pixels,omitempty"` - OutputCostPerImageLowQuality *float64 `json:"output_cost_per_image_low_quality,omitempty"` - OutputCostPerImageMediumQuality *float64 `json:"output_cost_per_image_medium_quality,omitempty"` - OutputCostPerImageHighQuality *float64 `json:"output_cost_per_image_high_quality,omitempty"` - OutputCostPerImageAutoQuality *float64 `json:"output_cost_per_image_auto_quality,omitempty"` - InputCostPerImageToken *float64 `json:"input_cost_per_image_token,omitempty"` - OutputCostPerImageToken *float64 `json:"output_cost_per_image_token,omitempty"` - - // Costs - Audio/Video - InputCostPerAudioToken *float64 `json:"input_cost_per_audio_token,omitempty"` - InputCostPerAudioPerSecond *float64 
`json:"input_cost_per_audio_per_second,omitempty"` - InputCostPerSecond *float64 `json:"input_cost_per_second,omitempty"` - InputCostPerVideoPerSecond *float64 `json:"input_cost_per_video_per_second,omitempty"` - OutputCostPerAudioToken *float64 `json:"output_cost_per_audio_token,omitempty"` - OutputCostPerVideoPerSecond *float64 `json:"output_cost_per_video_per_second,omitempty"` - OutputCostPerSecond *float64 `json:"output_cost_per_second,omitempty"` - - // Costs - Other - // - // SearchContextCostPerQuery is stored as a single float64, but the pricing datasheet - // represents it as a tiered object with three keys: search_context_size_low, - // search_context_size_medium, and search_context_size_high. For every provider except - // Perplexity the three tier values are identical, so we collapse the object to its - // medium tier value (falling back to low then high). Perplexity always returns a - // pre-computed total_cost in its usage response, so the per-query rate is never - // consumed for that provider; the collapsed value is therefore correct in all cases. - // See UnmarshalJSON below for the custom decoding logic. - SearchContextCostPerQuery *float64 `json:"search_context_cost_per_query,omitempty"` - CodeInterpreterCostPerSession *float64 `json:"code_interpreter_cost_per_session,omitempty"` -} - -// UnmarshalJSON implements json.Unmarshaler for PricingEntry. -// It handles the special case where search_context_cost_per_query may arrive as either -// a plain float64 or a tiered object {"search_context_size_low":…, -// "search_context_size_medium":…, "search_context_size_high":…}. -func (p *PricingEntry) UnmarshalJSON(data []byte) error { - // Type alias breaks the UnmarshalJSON recursion while keeping all other fields. 
- type PricingEntryAlias PricingEntry - var raw struct { - PricingEntryAlias - SearchContextCostPerQuery *struct { - Low *float64 `json:"search_context_size_low"` - Medium *float64 `json:"search_context_size_medium"` - High *float64 `json:"search_context_size_high"` - } `json:"search_context_cost_per_query,omitempty"` - } - if err := sonic.Unmarshal(data, &raw); err != nil { - return err - } - *p = PricingEntry(raw.PricingEntryAlias) - - // search_context_cost_per_query arrives as a tiered object – all three values are - // equal for non-Perplexity providers; we prefer medium, then low, then high. - // Perplexity always returns a pre-computed total_cost so the per-query rate is - // never consumed for that provider. - if q := raw.SearchContextCostPerQuery; q != nil { - switch { - case q.Medium != nil: - p.SearchContextCostPerQuery = q.Medium - case q.Low != nil: - p.SearchContextCostPerQuery = q.Low - case q.High != nil: - p.SearchContextCostPerQuery = q.High - } - } - return nil -} - -// ShouldSyncPricingFunc is a function that determines if pricing data should be synced -// It returns a boolean indicating if syncing is needed -// It is completely optional and can be nil if not needed -// syncPricing function will be called if this function returns true -type ShouldSyncPricingFunc func(ctx context.Context) bool - // Init initializes the model catalog func Init(ctx context.Context, config *Config, configStore configstore.ConfigStore, logger schemas.Logger) (*ModelCatalog, error) { // Initialize pricing URL and sync interval @@ -218,12 +71,12 @@ func Init(ctx context.Context, config *Config, configStore configstore.ConfigSto } syncInterval := DefaultSyncInterval if config.PricingSyncInterval != nil { - pricingSyncInterval = time.Duration(*config.PricingSyncInterval) * time.Second + syncInterval = time.Duration(*config.PricingSyncInterval) * time.Second } // Log the active interval and the scheduler's actual check frequency so operators // are not surprised that 
setting interval=1h does not mean checks happen every second. // Actual syncs occur when: (1) the 1-hour ticker fires AND (2) time.Since(lastSync) >= pricingSyncInterval. - logger.Info("pricing sync interval set to %v (scheduler checks every %v)", pricingSyncInterval, syncWorkerTickerPeriod) + logger.Info("pricing sync interval set to %v (scheduler checks every %v)", syncInterval, syncWorkerTickerPeriod) mc := &ModelCatalog{ pricingURL: pricingURL, @@ -413,7 +266,7 @@ func (mc *ModelCatalog) UpdateSyncConfig(ctx context.Context, config *Config) er mc.syncInterval = DefaultSyncInterval if config.PricingSyncInterval != nil { - mc.pricingSyncInterval = time.Duration(*config.PricingSyncInterval) * time.Second + mc.syncInterval = time.Duration(*config.PricingSyncInterval) * time.Second } // Create new sync worker with updated configuration diff --git a/framework/modelcatalog/pricing.go b/framework/modelcatalog/pricing.go index ccb3813e1c..def8b047e9 100644 --- a/framework/modelcatalog/pricing.go +++ b/framework/modelcatalog/pricing.go @@ -4,10 +4,155 @@ import ( "strconv" "strings" + "github.com/bytedance/sonic" "github.com/maximhq/bifrost/core/schemas" configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables" ) +// Default sync interval and config key +const ( + TokenTierAbove272K = 272000 + TokenTierAbove200K = 200000 + TokenTierAbove128K = 128000 +) + +// PricingEntry represents a single model's pricing information. +// Field names and JSON tags match the datasheet schema exactly. +type PricingEntry struct { + BaseModel string `json:"base_model,omitempty"` + Provider string `json:"provider"` + Mode string `json:"mode"` + + ContextLength *int `json:"context_length,omitempty"` + MaxInputTokens *int `json:"max_input_tokens,omitempty"` + MaxOutputTokens *int `json:"max_output_tokens,omitempty"` + Architecture *schemas.Architecture `json:"architecture,omitempty"` + + PricingOptions +} + +// UnmarshalJSON implements json.Unmarshaler for PricingEntry. 
+// It handles the special case where search_context_cost_per_query may arrive as either +// a plain float64 or a tiered object {"search_context_size_low":…, +// "search_context_size_medium":…, "search_context_size_high":…}. +func (p *PricingEntry) UnmarshalJSON(data []byte) error { + // Type alias breaks the UnmarshalJSON recursion while keeping all other fields. + type PricingEntryAlias PricingEntry + var raw struct { + PricingEntryAlias + SearchContextCostPerQuery *struct { + Low *float64 `json:"search_context_size_low"` + Medium *float64 `json:"search_context_size_medium"` + High *float64 `json:"search_context_size_high"` + } `json:"search_context_cost_per_query,omitempty"` + } + if err := sonic.Unmarshal(data, &raw); err != nil { + return err + } + *p = PricingEntry(raw.PricingEntryAlias) + + // search_context_cost_per_query arrives as a tiered object – all three values are + // equal for non-Perplexity providers; we prefer medium, then low, then high. + // Perplexity always returns a pre-computed total_cost so the per-query rate is + // never consumed for that provider. 
+ if q := raw.SearchContextCostPerQuery; q != nil { + switch { + case q.Medium != nil: + p.SearchContextCostPerQuery = q.Medium + case q.Low != nil: + p.SearchContextCostPerQuery = q.Low + case q.High != nil: + p.SearchContextCostPerQuery = q.High + } + } + return nil +} + +type PricingOptions struct { + // Costs - Text + InputCostPerToken *float64 `json:"input_cost_per_token"` + OutputCostPerToken *float64 `json:"output_cost_per_token"` + InputCostPerTokenBatches *float64 `json:"input_cost_per_token_batches,omitempty"` + OutputCostPerTokenBatches *float64 `json:"output_cost_per_token_batches,omitempty"` + InputCostPerTokenPriority *float64 `json:"input_cost_per_token_priority,omitempty"` + OutputCostPerTokenPriority *float64 `json:"output_cost_per_token_priority,omitempty"` + InputCostPerTokenFlex *float64 `json:"input_cost_per_token_flex,omitempty"` + OutputCostPerTokenFlex *float64 `json:"output_cost_per_token_flex,omitempty"` + InputCostPerCharacter *float64 `json:"input_cost_per_character,omitempty"` + // Costs - 128k Tier + InputCostPerTokenAbove128kTokens *float64 `json:"input_cost_per_token_above_128k_tokens,omitempty"` + InputCostPerImageAbove128kTokens *float64 `json:"input_cost_per_image_above_128k_tokens,omitempty"` + InputCostPerVideoPerSecondAbove128kTokens *float64 `json:"input_cost_per_video_per_second_above_128k_tokens,omitempty"` + InputCostPerAudioPerSecondAbove128kTokens *float64 `json:"input_cost_per_audio_per_second_above_128k_tokens,omitempty"` + OutputCostPerTokenAbove128kTokens *float64 `json:"output_cost_per_token_above_128k_tokens,omitempty"` + // Costs - 200k Tier + InputCostPerTokenAbove200kTokens *float64 `json:"input_cost_per_token_above_200k_tokens,omitempty"` + InputCostPerTokenAbove200kTokensPriority *float64 `json:"input_cost_per_token_above_200k_tokens_priority,omitempty"` + OutputCostPerTokenAbove200kTokens *float64 `json:"output_cost_per_token_above_200k_tokens,omitempty"` + OutputCostPerTokenAbove200kTokensPriority *float64 
`json:"output_cost_per_token_above_200k_tokens_priority,omitempty"` + // Costs - 272k Tier + InputCostPerTokenAbove272kTokens *float64 `json:"input_cost_per_token_above_272k_tokens,omitempty"` + InputCostPerTokenAbove272kTokensPriority *float64 `json:"input_cost_per_token_above_272k_tokens_priority,omitempty"` + OutputCostPerTokenAbove272kTokens *float64 `json:"output_cost_per_token_above_272k_tokens,omitempty"` + OutputCostPerTokenAbove272kTokensPriority *float64 `json:"output_cost_per_token_above_272k_tokens_priority,omitempty"` + + // Costs - Cache + CacheCreationInputTokenCost *float64 `json:"cache_creation_input_token_cost,omitempty"` + CacheReadInputTokenCost *float64 `json:"cache_read_input_token_cost,omitempty"` + CacheCreationInputTokenCostAbove200kTokens *float64 `json:"cache_creation_input_token_cost_above_200k_tokens,omitempty"` + CacheReadInputTokenCostAbove200kTokens *float64 `json:"cache_read_input_token_cost_above_200k_tokens,omitempty"` + CacheReadInputTokenCostAbove200kTokensPriority *float64 `json:"cache_read_input_token_cost_above_200k_tokens_priority,omitempty"` + CacheCreationInputTokenCostAbove1hr *float64 `json:"cache_creation_input_token_cost_above_1hr,omitempty"` + CacheCreationInputTokenCostAbove1hrAbove200kTokens *float64 `json:"cache_creation_input_token_cost_above_1hr_above_200k_tokens,omitempty"` + CacheCreationInputAudioTokenCost *float64 `json:"cache_creation_input_audio_token_cost,omitempty"` + CacheReadInputTokenCostPriority *float64 `json:"cache_read_input_token_cost_priority,omitempty"` + CacheReadInputTokenCostFlex *float64 `json:"cache_read_input_token_cost_flex,omitempty"` + CacheReadInputImageTokenCost *float64 `json:"cache_read_input_image_token_cost,omitempty"` + CacheReadInputTokenCostAbove272kTokens *float64 `json:"cache_read_input_token_cost_above_272k_tokens,omitempty"` + CacheReadInputTokenCostAbove272kTokensPriority *float64 `json:"cache_read_input_token_cost_above_272k_tokens_priority,omitempty"` + + // Costs - 
Image + InputCostPerImage *float64 `json:"input_cost_per_image,omitempty"` + InputCostPerPixel *float64 `json:"input_cost_per_pixel,omitempty"` + OutputCostPerImage *float64 `json:"output_cost_per_image,omitempty"` + OutputCostPerPixel *float64 `json:"output_cost_per_pixel,omitempty"` + OutputCostPerImagePremiumImage *float64 `json:"output_cost_per_image_premium_image,omitempty"` + OutputCostPerImageAbove512x512Pixels *float64 `json:"output_cost_per_image_above_512_and_512_pixels,omitempty"` + OutputCostPerImageAbove512x512PixelsPremium *float64 `json:"output_cost_per_image_above_512_and_512_pixels_and_premium_image,omitempty"` + OutputCostPerImageAbove1024x1024Pixels *float64 `json:"output_cost_per_image_above_1024_and_1024_pixels,omitempty"` + OutputCostPerImageAbove1024x1024PixelsPremium *float64 `json:"output_cost_per_image_above_1024_and_1024_pixels_and_premium_image,omitempty"` + OutputCostPerImageAbove2048x2048Pixels *float64 `json:"output_cost_per_image_above_2048_and_2048_pixels,omitempty"` + OutputCostPerImageAbove4096x4096Pixels *float64 `json:"output_cost_per_image_above_4096_and_4096_pixels,omitempty"` + OutputCostPerImageLowQuality *float64 `json:"output_cost_per_image_low_quality,omitempty"` + OutputCostPerImageMediumQuality *float64 `json:"output_cost_per_image_medium_quality,omitempty"` + OutputCostPerImageHighQuality *float64 `json:"output_cost_per_image_high_quality,omitempty"` + OutputCostPerImageAutoQuality *float64 `json:"output_cost_per_image_auto_quality,omitempty"` + InputCostPerImageToken *float64 `json:"input_cost_per_image_token,omitempty"` + OutputCostPerImageToken *float64 `json:"output_cost_per_image_token,omitempty"` + + // Costs - Audio/Video + InputCostPerAudioToken *float64 `json:"input_cost_per_audio_token,omitempty"` + InputCostPerAudioPerSecond *float64 `json:"input_cost_per_audio_per_second,omitempty"` + InputCostPerSecond *float64 `json:"input_cost_per_second,omitempty"` + InputCostPerVideoPerSecond *float64 
`json:"input_cost_per_video_per_second,omitempty"` + OutputCostPerAudioToken *float64 `json:"output_cost_per_audio_token,omitempty"` + OutputCostPerVideoPerSecond *float64 `json:"output_cost_per_video_per_second,omitempty"` + OutputCostPerSecond *float64 `json:"output_cost_per_second,omitempty"` + + // Costs - Other + // + // SearchContextCostPerQuery is stored as a single float64, but the pricing datasheet + // represents it as a tiered object with three keys: search_context_size_low, + // search_context_size_medium, and search_context_size_high. For every provider except + // Perplexity the three tier values are identical, so we collapse the object to its + // medium tier value (falling back to low then high). Perplexity always returns a + // pre-computed total_cost in its usage response, so the per-query rate is never + // consumed for that provider; the collapsed value is therefore correct in all cases. + // See UnmarshalJSON above for the custom decoding logic. + SearchContextCostPerQuery *float64 `json:"search_context_cost_per_query,omitempty"` + CodeInterpreterCostPerSession *float64 `json:"code_interpreter_cost_per_session,omitempty"` +} + // serviceTier captures the OpenAI service_tier value from a response. // Add new tier flags here as OpenAI introduces them. type serviceTier struct { @@ -681,7 +826,10 @@ func tieredInputRate(pricing *configstoreTables.TableModelPricing, totalTokens i if tier.isPriority && pricing.InputCostPerTokenPriority != nil { return *pricing.InputCostPerTokenPriority } - return pricing.InputCostPerToken + if pricing.InputCostPerToken != nil { + return *pricing.InputCostPerToken + } + return 0 } // tieredOutputRate returns the effective per-token output rate based on total token count.
@@ -712,7 +860,10 @@ func tieredOutputRate(pricing *configstoreTables.TableModelPricing, totalTokens if tier.isPriority && pricing.OutputCostPerTokenPriority != nil { return *pricing.OutputCostPerTokenPriority } - return pricing.OutputCostPerToken + if pricing.OutputCostPerToken != nil { + return *pricing.OutputCostPerToken + } + return 0 } // tieredImageInputRate returns the effective rate for image tokens on the input side. diff --git a/framework/modelcatalog/pricing_test.go b/framework/modelcatalog/pricing_test.go index 3694506eea..ebc9d57379 100644 --- a/framework/modelcatalog/pricing_test.go +++ b/framework/modelcatalog/pricing_test.go @@ -190,10 +190,10 @@ func TestComputeTextCost_Below200kUsesBaseRate(t *testing.T) { func TestComputeTextCost_Tiered272k(t *testing.T) { p := chatPricing(0.000003, 0.000015) - p.InputCostPerTokenAbove200kTokens = ptr(0.000006) - p.OutputCostPerTokenAbove200kTokens = ptr(0.00003) - p.InputCostPerTokenAbove272kTokens = ptr(0.000009) - p.OutputCostPerTokenAbove272kTokens = ptr(0.000045) + p.InputCostPerTokenAbove200kTokens = new(0.000006) + p.OutputCostPerTokenAbove200kTokens = new(0.00003) + p.InputCostPerTokenAbove272kTokens = new(0.000009) + p.OutputCostPerTokenAbove272kTokens = new(0.000045) usage := &schemas.BifrostLLMUsage{ PromptTokens: 250000, @@ -210,10 +210,10 @@ func TestComputeTextCost_Tiered272k(t *testing.T) { func TestComputeTextCost_Between200kAnd272kUses200kRate(t *testing.T) { p := chatPricing(0.000003, 0.000015) - p.InputCostPerTokenAbove200kTokens = ptr(0.000006) - p.OutputCostPerTokenAbove200kTokens = ptr(0.00003) - p.InputCostPerTokenAbove272kTokens = ptr(0.000009) - p.OutputCostPerTokenAbove272kTokens = ptr(0.000045) + p.InputCostPerTokenAbove200kTokens = new(0.000006) + p.OutputCostPerTokenAbove200kTokens = new(0.00003) + p.InputCostPerTokenAbove272kTokens = new(0.000009) + p.OutputCostPerTokenAbove272kTokens = new(0.000045) usage := &schemas.BifrostLLMUsage{ PromptTokens: 200000, @@ -230,10 +230,10 @@ func 
TestComputeTextCost_Between200kAnd272kUses200kRate(t *testing.T) { func TestComputeTextCost_272kTierWithCacheRead(t *testing.T) { p := chatPricing(0.000003, 0.000015) - p.InputCostPerTokenAbove272kTokens = ptr(0.000009) - p.OutputCostPerTokenAbove272kTokens = ptr(0.000045) - p.CacheReadInputTokenCost = ptr(0.0000003) - p.CacheReadInputTokenCostAbove272kTokens = ptr(0.0000009) + p.InputCostPerTokenAbove272kTokens = new(0.000009) + p.OutputCostPerTokenAbove272kTokens = new(0.000045) + p.CacheReadInputTokenCost = new(0.0000003) + p.CacheReadInputTokenCostAbove272kTokens = new(0.0000009) usage := &schemas.BifrostLLMUsage{ PromptTokens: 250000, @@ -315,7 +315,7 @@ func TestComputeEmbeddingCost_Basic(t *testing.T) { } func TestComputeEmbeddingCost_NilUsage(t *testing.T) { - p := configstoreTables.TableModelPricing{InputCostPerToken: 0.0000001} + p := configstoreTables.TableModelPricing{InputCostPerToken: new(0.0000001)} assert.Equal(t, 0.0, computeEmbeddingCost(&p, nil, serviceTier{})) } @@ -355,7 +355,7 @@ func TestComputeRerankCost_WithSearchCost(t *testing.T) { } func TestComputeRerankCost_NilUsage(t *testing.T) { - p := configstoreTables.TableModelPricing{InputCostPerToken: 0.001} + p := configstoreTables.TableModelPricing{InputCostPerToken: new(0.001)} assert.Equal(t, 0.0, computeRerankCost(&p, nil, serviceTier{})) } @@ -600,7 +600,7 @@ func TestComputeImageCost_TokenBasedWithDetails(t *testing.T) { } func TestComputeImageCost_NilUsage(t *testing.T) { - p := configstoreTables.TableModelPricing{OutputCostPerImage: ptr(0.05)} + p := configstoreTables.TableModelPricing{OutputCostPerImage: new(0.05)} assert.Equal(t, 0.0, computeImageCost(&p, nil, "", "", serviceTier{})) } @@ -1429,15 +1429,15 @@ func TestCalculateCost_272kTier_EndToEnd(t *testing.T) { Model: "claude-3-7-sonnet", Provider: "anthropic", Mode: "chat", - InputCostPerToken: 0.000003, - OutputCostPerToken: 0.000015, - InputCostPerTokenAbove200kTokens: ptr(0.000006), - OutputCostPerTokenAbove200kTokens: 
ptr(0.00003), - InputCostPerTokenAbove272kTokens: ptr(0.000009), - OutputCostPerTokenAbove272kTokens: ptr(0.000045), - CacheReadInputTokenCost: ptr(0.0000003), - CacheReadInputTokenCostAbove200kTokens: ptr(0.0000006), - CacheReadInputTokenCostAbove272kTokens: ptr(0.0000009), + InputCostPerToken: new(0.000003), + OutputCostPerToken: new(0.000015), + InputCostPerTokenAbove200kTokens: new(0.000006), + OutputCostPerTokenAbove200kTokens: new(0.00003), + InputCostPerTokenAbove272kTokens: new(0.000009), + OutputCostPerTokenAbove272kTokens: new(0.000045), + CacheReadInputTokenCost: new(0.0000003), + CacheReadInputTokenCostAbove200kTokens: new(0.0000006), + CacheReadInputTokenCostAbove272kTokens: new(0.0000009), }, }) @@ -1447,7 +1447,7 @@ func TestCalculateCost_272kTier_EndToEnd(t *testing.T) { TotalTokens: 280000, // Above 272k }) - cost := mc.CalculateCost(resp) + cost := mc.CalculateCost(resp, nil) // Tiered rate: input=0.000009, output=0.000045 // 250000*0.000009 + 30000*0.000045 = 2.25 + 1.35 = 3.60 assert.InDelta(t, 3.60, cost, 1e-9) @@ -1460,13 +1460,13 @@ func TestCalculateCost_272kTier_CacheReadFallbackChain(t *testing.T) { Model: "claude-3-7-sonnet", Provider: "anthropic", Mode: "chat", - InputCostPerToken: 0.000003, - OutputCostPerToken: 0.000015, - InputCostPerTokenAbove272kTokens: ptr(0.000009), - OutputCostPerTokenAbove272kTokens: ptr(0.000045), - CacheReadInputTokenCost: ptr(0.0000003), - CacheReadInputTokenCostAbove200kTokens: ptr(0.0000006), - CacheReadInputTokenCostAbove272kTokens: ptr(0.0000009), + InputCostPerToken: new(0.000003), + OutputCostPerToken: new(0.000015), + InputCostPerTokenAbove272kTokens: new(0.000009), + OutputCostPerTokenAbove272kTokens: new(0.000045), + CacheReadInputTokenCost: new(0.0000003), + CacheReadInputTokenCostAbove200kTokens: new(0.0000006), + CacheReadInputTokenCostAbove272kTokens: new(0.0000009), }, }) @@ -1479,7 +1479,7 @@ func TestCalculateCost_272kTier_CacheReadFallbackChain(t *testing.T) { }, }) - cost := 
mc.CalculateCost(resp) + cost := mc.CalculateCost(resp, nil) // Non-cached input: (250000-50000) * 0.000009 = 200000 * 0.000009 = 1.80 // Cached read (272k rate): 50000 * 0.0000009 = 0.045 // Output: 30000 * 0.000045 = 1.35 @@ -1493,8 +1493,8 @@ func TestCalculateCost_272kTier_CacheReadFallbackChain(t *testing.T) { func TestComputeTextCost_PriorityUsesInputOutputPriorityRate(t *testing.T) { p := chatPricing(0.000003, 0.000015) - p.InputCostPerTokenPriority = ptr(0.000006) - p.OutputCostPerTokenPriority = ptr(0.00003) + p.InputCostPerTokenPriority = new(0.000006) + p.OutputCostPerTokenPriority = new(0.00003) usage := &schemas.BifrostLLMUsage{ PromptTokens: 1000, @@ -1510,8 +1510,8 @@ func TestComputeTextCost_PriorityUsesInputOutputPriorityRate(t *testing.T) { func TestComputeTextCost_NonPriorityIgnoresPriorityRate(t *testing.T) { p := chatPricing(0.000003, 0.000015) - p.InputCostPerTokenPriority = ptr(0.000006) - p.OutputCostPerTokenPriority = ptr(0.00003) + p.InputCostPerTokenPriority = new(0.000006) + p.OutputCostPerTokenPriority = new(0.00003) usage := &schemas.BifrostLLMUsage{ PromptTokens: 1000, @@ -1527,12 +1527,12 @@ func TestComputeTextCost_NonPriorityIgnoresPriorityRate(t *testing.T) { func TestComputeTextCost_Priority272kTier(t *testing.T) { p := chatPricing(0.000003, 0.000015) - p.InputCostPerTokenPriority = ptr(0.000006) - p.OutputCostPerTokenPriority = ptr(0.00003) - p.InputCostPerTokenAbove272kTokens = ptr(0.000009) - p.InputCostPerTokenAbove272kTokensPriority = ptr(0.000012) - p.OutputCostPerTokenAbove272kTokens = ptr(0.000045) - p.OutputCostPerTokenAbove272kTokensPriority = ptr(0.00006) + p.InputCostPerTokenPriority = new(0.000006) + p.OutputCostPerTokenPriority = new(0.00003) + p.InputCostPerTokenAbove272kTokens = new(0.000009) + p.InputCostPerTokenAbove272kTokensPriority = new(0.000012) + p.OutputCostPerTokenAbove272kTokens = new(0.000045) + p.OutputCostPerTokenAbove272kTokensPriority = new(0.00006) usage := &schemas.BifrostLLMUsage{ PromptTokens: 
250000, @@ -1549,8 +1549,8 @@ func TestComputeTextCost_Priority272kTier(t *testing.T) { func TestComputeTextCost_Priority272kTierFallsBackToNonPriority272k(t *testing.T) { // Priority flag set but no priority-specific 272k rate — fall back to non-priority 272k p := chatPricing(0.000003, 0.000015) - p.InputCostPerTokenAbove272kTokens = ptr(0.000009) - p.OutputCostPerTokenAbove272kTokens = ptr(0.000045) + p.InputCostPerTokenAbove272kTokens = new(0.000009) + p.OutputCostPerTokenAbove272kTokens = new(0.000045) usage := &schemas.BifrostLLMUsage{ PromptTokens: 250000, @@ -1566,10 +1566,10 @@ func TestComputeTextCost_Priority272kTierFallsBackToNonPriority272k(t *testing.T func TestComputeTextCost_PriorityCacheReadRate(t *testing.T) { p := chatPricing(0.000003, 0.000015) - p.InputCostPerTokenPriority = ptr(0.000006) - p.OutputCostPerTokenPriority = ptr(0.00003) - p.CacheReadInputTokenCost = ptr(0.0000003) - p.CacheReadInputTokenCostPriority = ptr(0.0000006) + p.InputCostPerTokenPriority = new(0.000006) + p.OutputCostPerTokenPriority = new(0.00003) + p.CacheReadInputTokenCost = new(0.0000003) + p.CacheReadInputTokenCostPriority = new(0.0000006) usage := &schemas.BifrostLLMUsage{ PromptTokens: 1000, @@ -1596,10 +1596,10 @@ func TestCalculateCost_PriorityTier_EndToEnd(t *testing.T) { Model: "gpt-4o", Provider: "openai", Mode: "chat", - InputCostPerToken: 0.000005, - OutputCostPerToken: 0.000015, - InputCostPerTokenPriority: ptr(0.000010), - OutputCostPerTokenPriority: ptr(0.000030), + InputCostPerToken: new(0.000005), + OutputCostPerToken: new(0.000015), + InputCostPerTokenPriority: new(0.000010), + OutputCostPerTokenPriority: new(0.000030), }, }) @@ -1612,14 +1612,14 @@ func TestCalculateCost_PriorityTier_EndToEnd(t *testing.T) { TotalTokens: 1500, }, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ChatCompletionRequest, - Provider: schemas.OpenAI, - ModelRequested: "gpt-4o", + RequestType: schemas.ChatCompletionRequest, + Provider: schemas.OpenAI, + 
OriginalModelRequested: "gpt-4o", }, }, } - cost := mc.CalculateCost(resp) + cost := mc.CalculateCost(resp, nil) // Priority rates: 1000*0.000010 + 500*0.000030 = 0.010 + 0.015 = 0.025 assert.InDelta(t, 0.025, cost, 1e-12) } @@ -1631,10 +1631,10 @@ func TestCalculateCost_NonPriorityServiceTier_UsesBaseRate(t *testing.T) { Model: "gpt-4o", Provider: "openai", Mode: "chat", - InputCostPerToken: 0.000005, - OutputCostPerToken: 0.000015, - InputCostPerTokenPriority: ptr(0.000010), - OutputCostPerTokenPriority: ptr(0.000030), + InputCostPerToken: new(0.000005), + OutputCostPerToken: new(0.000015), + InputCostPerTokenPriority: new(0.000010), + OutputCostPerTokenPriority: new(0.000030), }, }) @@ -1647,14 +1647,14 @@ func TestCalculateCost_NonPriorityServiceTier_UsesBaseRate(t *testing.T) { TotalTokens: 1500, }, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ChatCompletionRequest, - Provider: schemas.OpenAI, - ModelRequested: "gpt-4o", + RequestType: schemas.ChatCompletionRequest, + Provider: schemas.OpenAI, + OriginalModelRequested: "gpt-4o", }, }, } - cost := mc.CalculateCost(resp) + cost := mc.CalculateCost(resp, nil) // Base rates (not priority): 1000*0.000005 + 500*0.000015 = 0.005 + 0.0075 = 0.0125 assert.InDelta(t, 0.0125, cost, 1e-12) } @@ -1663,23 +1663,23 @@ func TestTieredCacheReadRate_FallbackOrder(t *testing.T) { // 272k rate takes precedence over 200k, 200k over base, base over input rate t.Run("uses_272k_when_above_272k", func(t *testing.T) { p := chatPricing(0.000003, 0.000015) - p.CacheReadInputTokenCost = ptr(0.0000003) - p.CacheReadInputTokenCostAbove200kTokens = ptr(0.0000006) - p.CacheReadInputTokenCostAbove272kTokens = ptr(0.0000009) + p.CacheReadInputTokenCost = new(0.0000003) + p.CacheReadInputTokenCostAbove200kTokens = new(0.0000006) + p.CacheReadInputTokenCostAbove272kTokens = new(0.0000009) assert.Equal(t, 0.0000009, tieredCacheReadInputTokenRate(&p, 280000, serviceTier{})) }) t.Run("uses_200k_when_between_200k_and_272k", 
func(t *testing.T) { p := chatPricing(0.000003, 0.000015) - p.CacheReadInputTokenCost = ptr(0.0000003) - p.CacheReadInputTokenCostAbove200kTokens = ptr(0.0000006) - p.CacheReadInputTokenCostAbove272kTokens = ptr(0.0000009) + p.CacheReadInputTokenCost = new(0.0000003) + p.CacheReadInputTokenCostAbove200kTokens = new(0.0000006) + p.CacheReadInputTokenCostAbove272kTokens = new(0.0000009) assert.Equal(t, 0.0000006, tieredCacheReadInputTokenRate(&p, 230000, serviceTier{})) }) t.Run("uses_base_cache_rate_when_below_200k", func(t *testing.T) { p := chatPricing(0.000003, 0.000015) - p.CacheReadInputTokenCost = ptr(0.0000003) - p.CacheReadInputTokenCostAbove200kTokens = ptr(0.0000006) - p.CacheReadInputTokenCostAbove272kTokens = ptr(0.0000009) + p.CacheReadInputTokenCost = new(0.0000003) + p.CacheReadInputTokenCostAbove200kTokens = new(0.0000006) + p.CacheReadInputTokenCostAbove272kTokens = new(0.0000009) assert.Equal(t, 0.0000003, tieredCacheReadInputTokenRate(&p, 1500, serviceTier{})) }) t.Run("falls_back_to_input_rate_when_no_cache_rate_set", func(t *testing.T) { @@ -1689,49 +1689,49 @@ func TestTieredCacheReadRate_FallbackOrder(t *testing.T) { }) t.Run("priority_uses_272k_priority_rate", func(t *testing.T) { p := chatPricing(0.000003, 0.000015) - p.CacheReadInputTokenCost = ptr(0.0000003) - p.CacheReadInputTokenCostPriority = ptr(0.0000006) - p.CacheReadInputTokenCostAbove272kTokens = ptr(0.0000009) - p.CacheReadInputTokenCostAbove272kTokensPriority = ptr(0.0000012) + p.CacheReadInputTokenCost = new(0.0000003) + p.CacheReadInputTokenCostPriority = new(0.0000006) + p.CacheReadInputTokenCostAbove272kTokens = new(0.0000009) + p.CacheReadInputTokenCostAbove272kTokensPriority = new(0.0000012) assert.Equal(t, 0.0000012, tieredCacheReadInputTokenRate(&p, 280000, serviceTier{isPriority: true})) }) t.Run("priority_falls_back_to_272k_non_priority_when_priority_rate_missing", func(t *testing.T) { p := chatPricing(0.000003, 0.000015) - p.CacheReadInputTokenCostAbove272kTokens = 
ptr(0.0000009) + p.CacheReadInputTokenCostAbove272kTokens = new(0.0000009) assert.Equal(t, 0.0000009, tieredCacheReadInputTokenRate(&p, 280000, serviceTier{isPriority: true})) }) t.Run("priority_uses_priority_base_cache_rate_below_tiers", func(t *testing.T) { p := chatPricing(0.000003, 0.000015) - p.CacheReadInputTokenCost = ptr(0.0000003) - p.CacheReadInputTokenCostPriority = ptr(0.0000006) + p.CacheReadInputTokenCost = new(0.0000003) + p.CacheReadInputTokenCostPriority = new(0.0000006) assert.Equal(t, 0.0000006, tieredCacheReadInputTokenRate(&p, 1500, serviceTier{isPriority: true})) }) t.Run("flex_uses_flex_cache_rate", func(t *testing.T) { p := chatPricing(0.000003, 0.000015) - p.CacheReadInputTokenCost = ptr(0.0000003) - p.CacheReadInputTokenCostFlex = ptr(0.0000005) + p.CacheReadInputTokenCost = new(0.0000003) + p.CacheReadInputTokenCostFlex = new(0.0000005) assert.Equal(t, 0.0000005, tieredCacheReadInputTokenRate(&p, 1500, serviceTier{isFlex: true})) }) t.Run("flex_uses_flex_cache_rate_regardless_of_token_count", func(t *testing.T) { p := chatPricing(0.000003, 0.000015) - p.CacheReadInputTokenCost = ptr(0.0000003) - p.CacheReadInputTokenCostFlex = ptr(0.0000005) - p.CacheReadInputTokenCostAbove272kTokens = ptr(0.0000009) + p.CacheReadInputTokenCost = new(0.0000003) + p.CacheReadInputTokenCostFlex = new(0.0000005) + p.CacheReadInputTokenCostAbove272kTokens = new(0.0000009) // Even above 272k, flex flat rate takes precedence assert.Equal(t, 0.0000005, tieredCacheReadInputTokenRate(&p, 280000, serviceTier{isFlex: true})) }) t.Run("flex_falls_back_to_base_cache_rate_when_no_flex_cache_rate", func(t *testing.T) { p := chatPricing(0.000003, 0.000015) - p.CacheReadInputTokenCost = ptr(0.0000003) + p.CacheReadInputTokenCost = new(0.0000003) // No flex cache rate — falls back to base cache rate assert.Equal(t, 0.0000003, tieredCacheReadInputTokenRate(&p, 1500, serviceTier{isFlex: true})) }) t.Run("flex_wins_over_272k_priority_and_priority_base_when_all_present", 
func(t *testing.T) { p := chatPricing(0.000003, 0.000015) - p.CacheReadInputTokenCostAbove272kTokens = ptr(5e-7) - p.CacheReadInputTokenCostFlex = ptr(1.3e-7) - p.CacheReadInputTokenCostPriority = ptr(5e-7) - p.CacheReadInputTokenCostAbove272kTokensPriority = ptr(0.000001) + p.CacheReadInputTokenCostAbove272kTokens = new(5e-7) + p.CacheReadInputTokenCostFlex = new(1.3e-7) + p.CacheReadInputTokenCostPriority = new(5e-7) + p.CacheReadInputTokenCostAbove272kTokensPriority = new(0.000001) // token count exceeds 272k — but flex flat rate should still win assert.Equal(t, 1.3e-7, tieredCacheReadInputTokenRate(&p, 280000, serviceTier{isFlex: true})) }) @@ -1775,8 +1775,8 @@ func TestTierFromString_Nil(t *testing.T) { func TestComputeTextCost_FlexUsesFlexRate(t *testing.T) { p := chatPricing(0.000003, 0.000015) - p.InputCostPerTokenFlex = ptr(0.0000015) - p.OutputCostPerTokenFlex = ptr(0.0000075) + p.InputCostPerTokenFlex = new(0.0000015) + p.OutputCostPerTokenFlex = new(0.0000075) usage := &schemas.BifrostLLMUsage{ PromptTokens: 1000, @@ -1792,8 +1792,8 @@ func TestComputeTextCost_FlexUsesFlexRate(t *testing.T) { func TestComputeTextCost_NonFlexIgnoresFlexRate(t *testing.T) { p := chatPricing(0.000003, 0.000015) - p.InputCostPerTokenFlex = ptr(0.0000015) - p.OutputCostPerTokenFlex = ptr(0.0000075) + p.InputCostPerTokenFlex = new(0.0000015) + p.OutputCostPerTokenFlex = new(0.0000075) usage := &schemas.BifrostLLMUsage{ PromptTokens: 1000, @@ -1810,10 +1810,10 @@ func TestComputeTextCost_NonFlexIgnoresFlexRate(t *testing.T) { func TestComputeTextCost_FlexIgnoresTokenTiers(t *testing.T) { // Flex is a flat rate — token-count tiers (272k, 200k, 128k) do not apply. 
p := chatPricing(0.000003, 0.000015) - p.InputCostPerTokenFlex = ptr(0.0000015) - p.OutputCostPerTokenFlex = ptr(0.0000075) - p.InputCostPerTokenAbove272kTokens = ptr(0.000009) - p.OutputCostPerTokenAbove272kTokens = ptr(0.000045) + p.InputCostPerTokenFlex = new(0.0000015) + p.OutputCostPerTokenFlex = new(0.0000075) + p.InputCostPerTokenAbove272kTokens = new(0.000009) + p.OutputCostPerTokenAbove272kTokens = new(0.000045) usage := &schemas.BifrostLLMUsage{ PromptTokens: 250000, @@ -1829,10 +1829,10 @@ func TestComputeTextCost_FlexIgnoresTokenTiers(t *testing.T) { func TestComputeTextCost_FlexCacheReadRate(t *testing.T) { p := chatPricing(0.000003, 0.000015) - p.InputCostPerTokenFlex = ptr(0.0000015) - p.OutputCostPerTokenFlex = ptr(0.0000075) - p.CacheReadInputTokenCost = ptr(0.0000003) - p.CacheReadInputTokenCostFlex = ptr(0.0000006) + p.InputCostPerTokenFlex = new(0.0000015) + p.OutputCostPerTokenFlex = new(0.0000075) + p.CacheReadInputTokenCost = new(0.0000003) + p.CacheReadInputTokenCostFlex = new(0.0000006) usage := &schemas.BifrostLLMUsage{ PromptTokens: 1000, @@ -1875,10 +1875,10 @@ func TestCalculateCost_FlexTier_EndToEnd(t *testing.T) { Model: "gpt-4o", Provider: "openai", Mode: "chat", - InputCostPerToken: 0.000005, - OutputCostPerToken: 0.000015, - InputCostPerTokenFlex: ptr(0.0000025), - OutputCostPerTokenFlex: ptr(0.0000075), + InputCostPerToken: new(0.000005), + OutputCostPerToken: new(0.000015), + InputCostPerTokenFlex: new(0.0000025), + OutputCostPerTokenFlex: new(0.0000075), }, }) @@ -1891,14 +1891,14 @@ func TestCalculateCost_FlexTier_EndToEnd(t *testing.T) { TotalTokens: 1500, }, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ChatCompletionRequest, - Provider: schemas.OpenAI, - ModelRequested: "gpt-4o", + RequestType: schemas.ChatCompletionRequest, + Provider: schemas.OpenAI, + OriginalModelRequested: "gpt-4o", }, }, } - cost := mc.CalculateCost(resp) + cost := mc.CalculateCost(resp, nil) // Flex rates: 1000*0.0000025 + 
500*0.0000075 = 0.0025 + 0.00375 = 0.00625 assert.InDelta(t, 0.00625, cost, 1e-12) } @@ -1918,14 +1918,14 @@ func TestCalculateCost_FlexTier_FallsBackToBaseWhenNoFlexRate(t *testing.T) { TotalTokens: 1500, }, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ChatCompletionRequest, - Provider: schemas.OpenAI, - ModelRequested: "gpt-4o", + RequestType: schemas.ChatCompletionRequest, + Provider: schemas.OpenAI, + OriginalModelRequested: "gpt-4o", }, }, } - cost := mc.CalculateCost(resp) + cost := mc.CalculateCost(resp, nil) // No flex rates configured — falls back to base: 1000*0.000005 + 500*0.000015 = 0.005 + 0.0075 = 0.0125 assert.InDelta(t, 0.0125, cost, 1e-12) } @@ -1952,7 +1952,7 @@ func TestCalculateCost_ProviderCostZeroTotalStillCalculates(t *testing.T) { func TestCalculateCost_AllCachedTokens(t *testing.T) { // All prompt tokens are from cache p := chatPricing(0.000005, 0.000015) - p.CacheReadInputTokenCost = bifrost.Ptr(0.0000005) + p.CacheReadInputTokenCost = new(0.0000005) usage := &schemas.BifrostLLMUsage{ PromptTokens: 1000, @@ -1978,8 +1978,8 @@ func TestCalculateCost_ImageGeneration_NilUsage_PerImagePricing(t *testing.T) { Model: "dall-e-3", Provider: "openai", Mode: "image_generation", - InputCostPerToken: bifrost.Ptr(0.0), - OutputCostPerImage: bifrost.Ptr(0.04), + InputCostPerToken: new(0.0), + OutputCostPerImage: new(0.04), } mc := testCatalogWithPricing(map[string]configstoreTables.TableModelPricing{ @@ -1998,8 +1998,8 @@ func TestCalculateCost_ImageGeneration_NilUsage_InputAndOutputPerImage(t *testin Model: "test-image-model", Provider: "test", Mode: "image_generation", - InputCostPerImage: bifrost.Ptr(0.01), - OutputCostPerImage: bifrost.Ptr(0.04), + InputCostPerImage: new(0.01), + OutputCostPerImage: new(0.04), } mc := testCatalogWithPricing(map[string]configstoreTables.TableModelPricing{ @@ -2019,8 +2019,8 @@ func TestCalculateCost_ImageGeneration_WithInputImages(t *testing.T) { Model: "gpt-image-1", Provider: "openai", 
Mode: "image_generation", - InputCostPerImage: bifrost.Ptr(0.01), - OutputCostPerImage: bifrost.Ptr(0.04), + InputCostPerImage: new(0.01), + OutputCostPerImage: new(0.04), } mc := testCatalogWithPricing(map[string]configstoreTables.TableModelPricing{ diff --git a/framework/modelcatalog/utils.go b/framework/modelcatalog/utils.go index cf18c5b919..8f539bdfd9 100644 --- a/framework/modelcatalog/utils.go +++ b/framework/modelcatalog/utils.go @@ -152,16 +152,16 @@ func convertPricingDataToTableModelPricing(modelKey string, entry PricingEntry) Architecture: entry.Architecture, // Costs - Text - InputCostPerToken: entry.InputCostPerToken, - OutputCostPerToken: entry.OutputCostPerToken, - InputCostPerTokenBatches: entry.InputCostPerTokenBatches, - OutputCostPerTokenBatches: entry.OutputCostPerTokenBatches, - InputCostPerTokenPriority: entry.InputCostPerTokenPriority, - OutputCostPerTokenPriority: entry.OutputCostPerTokenPriority, - InputCostPerTokenFlex: entry.InputCostPerTokenFlex, - OutputCostPerTokenFlex: entry.OutputCostPerTokenFlex, - InputCostPerTokenAbove200kTokens: entry.InputCostPerTokenAbove200kTokens, - InputCostPerTokenAbove200kTokensPriority: entry.InputCostPerTokenAbove200kTokensPriority, + InputCostPerToken: entry.InputCostPerToken, + OutputCostPerToken: entry.OutputCostPerToken, + InputCostPerTokenBatches: entry.InputCostPerTokenBatches, + OutputCostPerTokenBatches: entry.OutputCostPerTokenBatches, + InputCostPerTokenPriority: entry.InputCostPerTokenPriority, + OutputCostPerTokenPriority: entry.OutputCostPerTokenPriority, + InputCostPerTokenFlex: entry.InputCostPerTokenFlex, + OutputCostPerTokenFlex: entry.OutputCostPerTokenFlex, + InputCostPerTokenAbove200kTokens: entry.InputCostPerTokenAbove200kTokens, + InputCostPerTokenAbove200kTokensPriority: entry.InputCostPerTokenAbove200kTokensPriority, OutputCostPerTokenAbove200kTokens: entry.OutputCostPerTokenAbove200kTokens, OutputCostPerTokenAbove200kTokensPriority: 
entry.OutputCostPerTokenAbove200kTokensPriority, // Costs - 272k Tier @@ -232,16 +232,16 @@ func convertPricingDataToTableModelPricing(modelKey string, entry PricingEntry) func convertTableModelPricingToPricingData(pricing *configstoreTables.TableModelPricing) *PricingEntry { options := PricingOptions{ // Costs - Text - InputCostPerToken: pricing.InputCostPerToken, - OutputCostPerToken: pricing.OutputCostPerToken, - InputCostPerTokenBatches: pricing.InputCostPerTokenBatches, - OutputCostPerTokenBatches: pricing.OutputCostPerTokenBatches, - InputCostPerTokenPriority: pricing.InputCostPerTokenPriority, - OutputCostPerTokenPriority: pricing.OutputCostPerTokenPriority, - InputCostPerTokenFlex: pricing.InputCostPerTokenFlex, - OutputCostPerTokenFlex: pricing.OutputCostPerTokenFlex, - InputCostPerTokenAbove200kTokens: pricing.InputCostPerTokenAbove200kTokens, - InputCostPerTokenAbove200kTokensPriority: pricing.InputCostPerTokenAbove200kTokensPriority, + InputCostPerToken: pricing.InputCostPerToken, + OutputCostPerToken: pricing.OutputCostPerToken, + InputCostPerTokenBatches: pricing.InputCostPerTokenBatches, + OutputCostPerTokenBatches: pricing.OutputCostPerTokenBatches, + InputCostPerTokenPriority: pricing.InputCostPerTokenPriority, + OutputCostPerTokenPriority: pricing.OutputCostPerTokenPriority, + InputCostPerTokenFlex: pricing.InputCostPerTokenFlex, + OutputCostPerTokenFlex: pricing.OutputCostPerTokenFlex, + InputCostPerTokenAbove200kTokens: pricing.InputCostPerTokenAbove200kTokens, + InputCostPerTokenAbove200kTokensPriority: pricing.InputCostPerTokenAbove200kTokensPriority, OutputCostPerTokenAbove200kTokens: pricing.OutputCostPerTokenAbove200kTokens, OutputCostPerTokenAbove200kTokensPriority: pricing.OutputCostPerTokenAbove200kTokensPriority, // Costs - 272k Tier diff --git a/framework/oauth2/main.go b/framework/oauth2/main.go index 0916c908d3..fcf1ab8f49 100644 --- a/framework/oauth2/main.go +++ b/framework/oauth2/main.go @@ -905,6 +905,7 @@ func (p 
*OAuth2Provider) CompleteUserOAuthFlow(ctx context.Context, state string redirectURI = templateConfig.RedirectURI } tokenResponse, err := p.exchangeCodeForTokensWithPKCE( + ctx, templateConfig.TokenURL, code, templateConfig.ClientID, @@ -1055,6 +1056,7 @@ func (p *OAuth2Provider) RefreshUserAccessToken(ctx context.Context, sessionToke // Exchange refresh token newTokenResponse, err := p.exchangeRefreshToken( + ctx, templateConfig.TokenURL, templateConfig.ClientID, templateConfig.ClientSecret, diff --git a/plugins/prompts/go.mod b/plugins/prompts/go.mod index 73eaf7ad3c..5ceaa9ab29 100644 --- a/plugins/prompts/go.mod +++ b/plugins/prompts/go.mod @@ -45,6 +45,7 @@ require ( github.com/golang-jwt/jwt/v5 v5.3.0 // indirect github.com/google/uuid v1.6.0 // indirect github.com/invopop/jsonschema v0.13.0 // indirect + github.com/jackc/pgx/v5 v5.9.1 // indirect github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/now v1.1.5 // indirect github.com/klauspost/compress v1.18.2 // indirect diff --git a/plugins/prompts/go.sum b/plugins/prompts/go.sum index d7829db75b..2339673b68 100644 --- a/plugins/prompts/go.sum +++ b/plugins/prompts/go.sum @@ -90,8 +90,7 @@ github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsI github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= -github.com/jackc/pgx/v5 v5.7.6 h1:rWQc5FwZSPX58r1OQmkuaNicxdmExaEz5A2DO2hUuTk= -github.com/jackc/pgx/v5 v5.7.6/go.mod h1:aruU7o91Tc2q2cFp5h4uP3f6ztExVpyVv88Xl/8Vl8M= +github.com/jackc/pgx/v5 v5.9.1 h1:uwrxJXBnx76nyISkhr33kQLlUqjv7et7b9FjCen/tdc= github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= github.com/jackc/puddle/v2 v2.2.2/go.mod 
h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= diff --git a/transports/bifrost-http/handlers/asyncinference.go b/transports/bifrost-http/handlers/asyncinference.go index b7c7f1a371..da2d27594c 100644 --- a/transports/bifrost-http/handlers/asyncinference.go +++ b/transports/bifrost-http/handlers/asyncinference.go @@ -472,7 +472,7 @@ func (h *AsyncHandler) asyncOCR(ctx *fasthttp.RequestCtx) { return } - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) if bifrostCtx == nil { SendError(ctx, fasthttp.StatusInternalServerError, "Failed to convert context") return diff --git a/transports/bifrost-http/handlers/config.go b/transports/bifrost-http/handlers/config.go index 719967a1a9..384e0e3656 100644 --- a/transports/bifrost-http/handlers/config.go +++ b/transports/bifrost-http/handlers/config.go @@ -494,7 +494,7 @@ func (h *ConfigHandler) updateConfig(ctx *fasthttp.RequestCtx) { if frameworkConfig.PricingSyncInterval != nil { syncSeconds = *frameworkConfig.PricingSyncInterval } else { - syncSeconds = int64(modelcatalog.DefaultPricingSyncInterval.Seconds()) + syncSeconds = int64(modelcatalog.DefaultSyncInterval.Seconds()) } h.store.FrameworkConfig = &framework.FrameworkConfig{ Pricing: &modelcatalog.Config{ diff --git a/transports/bifrost-http/handlers/inference.go b/transports/bifrost-http/handlers/inference.go index cec3b6fbbd..08a507134b 100644 --- a/transports/bifrost-http/handlers/inference.go +++ b/transports/bifrost-http/handlers/inference.go @@ -14,6 +14,7 @@ import ( "path/filepath" "strconv" "strings" + "sync/atomic" "github.com/bytedance/sonic" "github.com/fasthttp/router" @@ -156,20 +157,20 @@ var rerankParamsKnownFields = map[string]bool{ 
} var ocrParamsKnownFields = map[string]bool{ - "model": true, - "id": true, - "document": true, - "fallbacks": true, - "include_image_base64": true, - "pages": true, - "image_limit": true, - "image_min_size": true, - "table_format": true, - "extract_header": true, - "extract_footer": true, - "bbox_annotation_format": true, - "document_annotation_format": true, - "document_annotation_prompt": true, + "model": true, + "id": true, + "document": true, + "fallbacks": true, + "include_image_base64": true, + "pages": true, + "image_limit": true, + "image_min_size": true, + "table_format": true, + "extract_header": true, + "extract_footer": true, + "bbox_annotation_format": true, + "document_annotation_format": true, + "document_annotation_prompt": true, } var speechParamsKnownFields = map[string]bool{ @@ -459,7 +460,7 @@ type RerankRequest struct { // OCRHandlerRequest is a bifrost OCR request type OCRHandlerRequest struct { - ID *string `json:"id,omitempty"` + ID *string `json:"id,omitempty"` Document schemas.OCRDocument `json:"document"` BifrostParams *schemas.OCRParameters @@ -1317,7 +1318,7 @@ func (h *CompletionHandler) ocr(ctx *fasthttp.RequestCtx) { } // Convert context - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) defer cancel() if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") @@ -1667,8 +1668,16 @@ func (h *CompletionHandler) handleStreamingResponse(ctx *fasthttp.RequestCtx, bi // The streaming callback will complete the trace after the stream ends ctx.SetUserValue(schemas.BifrostContextKeyDeferTraceCompletion, true) - // Get the trace completer function for use in the streaming callback - traceCompleter, _ := ctx.UserValue(schemas.BifrostContextKeyTraceCompleter).(func()) + 
// Capture trace completer BEFORE goroutine — ctx may be recycled inside goroutine. + // Signature: func(transportLogs []schemas.PluginLogEntry) + traceCompleter, _ := ctx.UserValue(schemas.BifrostContextKeyTraceCompleter).(func([]schemas.PluginLogEntry)) + + // Create atomic slot for the transport post-hook completer. + // TransportInterceptorMiddleware will populate this after next(ctx) returns + // with a closure that uses pre-captured data (no ctx access). + // The goroutine reads from its closure-captured copy of the slot. + var completerSlot atomic.Value + ctx.SetUserValue(schemas.BifrostContextKeyTransportPostHookCompleter, &completerSlot) // Get stream chunk interceptor for plugin hooks interceptor := h.config.GetStreamChunkInterceptor() @@ -1686,12 +1695,15 @@ func (h *CompletionHandler) handleStreamingResponse(ctx *fasthttp.RequestCtx, bi go func() { defer func() { schemas.ReleaseHTTPRequest(httpReq) - // Retrieve and run transport post-hook completer before closing the stream - // so errors can still be communicated to the client as SSE events. - // Must retrieve here (not before goroutine) because TransportInterceptorMiddleware - // sets BifrostContextKeyTransportPostHookCompleter after next(ctx) returns. - if postHookCompleter, ok := ctx.UserValue(schemas.BifrostContextKeyTransportPostHookCompleter).(func() error); ok && postHookCompleter != nil { - if err := postHookCompleter(); err != nil { + // Run transport post-hooks using the pre-captured completer. + // The completerSlot was populated by TransportInterceptorMiddleware + // (which runs after handleStreamingResponse returns from next(ctx)). + // The closure does NOT access ctx — it uses pre-captured request/response data. 
+ var transportLogs []schemas.PluginLogEntry + if fn, ok := completerSlot.Load().(func() ([]schemas.PluginLogEntry, error)); ok && fn != nil { + logs, err := fn() + transportLogs = logs + if err != nil { errorJSON, marshalErr := sonic.Marshal(map[string]string{"error": err.Error()}) if marshalErr == nil { reader.SendError(errorJSON) @@ -1699,10 +1711,11 @@ func (h *CompletionHandler) handleStreamingResponse(ctx *fasthttp.RequestCtx, bi } } reader.Done() - // Complete the trace after streaming finishes - // This ensures all spans (including llm.call) are properly ended before the trace is sent to OTEL + // Complete the trace after streaming finishes, passing transport logs directly. + // This ensures all spans (including llm.call) are properly ended before the trace is sent to OTEL. + // Safe to call after reader.Done() because traceCompleter no longer accesses ctx. if traceCompleter != nil { - traceCompleter() + traceCompleter(transportLogs) } }() diff --git a/transports/bifrost-http/handlers/middlewares.go b/transports/bifrost-http/handlers/middlewares.go index 830d1135f8..14418b2658 100644 --- a/transports/bifrost-http/handlers/middlewares.go +++ b/transports/bifrost-http/handlers/middlewares.go @@ -357,10 +357,37 @@ func TransportInterceptorMiddleware(config *lib.Config) schemas.BifrostHTTPMiddl // The streaming handler calls this BEFORE reader.Done() so that errors can // still be sent as SSE events. applyResponse=false because the response is // already on the wire and mutating ctx.Response would corrupt the chunked stream. + // + // IMPORTANT: The callback must NOT access ctx — fasthttp recycles RequestCtx + // after the response body stream completes. All needed data is eagerly captured + // here (while ctx is still valid) and passed through the closure. 
if deferred, ok := ctx.UserValue(schemas.BifrostContextKeyDeferTraceCompletion).(bool); ok && deferred { - ctx.SetUserValue(schemas.BifrostContextKeyTransportPostHookCompleter, func() error { - return runTransportPostHooks(ctx, plugins, bifrostCtx, false) - }) + // Eagerly snapshot request/response from ctx before it can be recycled. + capturedReq := lib.BuildHTTPRequestFromFastHTTP(ctx) + capturedResp := lib.BuildHTTPResponseFromFastHTTP(ctx) + // Snapshot pre-hook transport plugin logs already accumulated on ctx. + var preHookLogs []schemas.PluginLogEntry + if logs, ok := ctx.UserValue(schemas.BifrostContextKeyTransportPluginLogs).([]schemas.PluginLogEntry); ok { + preHookLogs = logs + } + + completer := func() ([]schemas.PluginLogEntry, error) { + defer schemas.ReleaseHTTPRequest(capturedReq) + defer schemas.ReleaseHTTPResponse(capturedResp) + postHookLogs, err := runTransportPostHooksCaptured(capturedReq, capturedResp, plugins, bifrostCtx) + allLogs := preHookLogs + if len(postHookLogs) > 0 { + allLogs = append(allLogs, postHookLogs...) + } + return allLogs, err + } + + // Store the completer in the atomic.Value slot that the streaming handler + // placed on ctx. The goroutine reads from its closure-captured copy of + // the slot, avoiding any ctx access after the handler returns. + if slot, ok := ctx.UserValue(schemas.BifrostContextKeyTransportPostHookCompleter).(*atomic.Value); ok { + slot.Store(completer) + } return } @@ -427,6 +454,58 @@ func runTransportPostHooks(ctx *fasthttp.RequestCtx, plugins []schemas.HTTPTrans return nil } +// runTransportPostHooksCaptured is the goroutine-safe variant of runTransportPostHooks. +// It uses pre-captured HTTPRequest and HTTPResponse snapshots instead of reading from +// a fasthttp RequestCtx, which may have been recycled by the time this runs in a +// streaming goroutine. Returns accumulated plugin logs (instead of writing them to +// ctx.UserValue) so the caller can forward them to the trace completer. 
+func runTransportPostHooksCaptured(capturedReq *schemas.HTTPRequest, capturedResp *schemas.HTTPResponse, plugins []schemas.HTTPTransportPlugin, bifrostCtx *schemas.BifrostContext) ([]schemas.PluginLogEntry, error) { + // Clone into fresh pooled objects so plugins can mutate without affecting the snapshots. + req := schemas.AcquireHTTPRequest() + defer schemas.ReleaseHTTPRequest(req) + req.Method = capturedReq.Method + req.Path = capturedReq.Path + for k, v := range capturedReq.Headers { + req.Headers[k] = v + } + for k, v := range capturedReq.Query { + req.Query[k] = v + } + for k, v := range capturedReq.PathParams { + req.PathParams[k] = v + } + + httpResp := schemas.AcquireHTTPResponse() + defer schemas.ReleaseHTTPResponse(httpResp) + httpResp.StatusCode = capturedResp.StatusCode + for k, v := range capturedResp.Headers { + httpResp.Headers[k] = v + } + + var allLogs []schemas.PluginLogEntry + + // Run http post-hooks in reverse order + for i := len(plugins) - 1; i >= 0; i-- { + plugin := plugins[i] + pluginName := plugin.GetName() + pluginCtx := bifrostCtx.WithPluginScope(&pluginName) + err := plugin.HTTPTransportPostHook(pluginCtx, req, httpResp) + pluginCtx.ReleasePluginScope() + if err != nil { + logger.Warn("error in HTTPTransportPostHook for plugin %s: %s", pluginName, err.Error()) + if postHookLogs := bifrostCtx.DrainPluginLogs(); len(postHookLogs) > 0 { + allLogs = append(allLogs, postHookLogs...) + } + return allLogs, fmt.Errorf("HTTPTransportPostHook plugin %s: %w", pluginName, err) + } + } + // Drain post-hook plugin logs + if postHookLogs := bifrostCtx.DrainPluginLogs(); len(postHookLogs) > 0 { + allLogs = append(allLogs, postHookLogs...) + } + return allLogs, nil +} + // getBifrostContextFromFastHTTP gets or creates a BifrostContext from fasthttp context. 
func getBifrostContextFromFastHTTP(ctx *fasthttp.RequestCtx) *schemas.BifrostContext { return schemas.NewBifrostContext(ctx, schemas.NoDeadline) @@ -936,10 +1015,11 @@ func (m *TracingMiddleware) Middleware() schemas.BifrostHTTPMiddleware { ctx.SetUserValue(schemas.BifrostContextKeyParentSpanID, parentSpanID) } - // Store a trace completion callback for streaming handlers to use - ctx.SetUserValue(schemas.BifrostContextKeyTraceCompleter, func() { - // Attach transport plugin logs before completing the trace (streaming path) - if transportLogs, ok := ctx.UserValue(schemas.BifrostContextKeyTransportPluginLogs).([]schemas.PluginLogEntry); ok && len(transportLogs) > 0 { + // Store a trace completion callback for streaming handlers to use. + // Accepts transport plugin logs as a parameter so it never reads from + // ctx.UserValue — ctx may be recycled by the time this runs in a goroutine. + ctx.SetUserValue(schemas.BifrostContextKeyTraceCompleter, func(transportLogs []schemas.PluginLogEntry) { + if len(transportLogs) > 0 { tracer.AttachPluginLogs(traceID, transportLogs) } tracer.CompleteAndFlushTrace(traceID) diff --git a/transports/bifrost-http/handlers/providers.go b/transports/bifrost-http/handlers/providers.go index 5e028684b9..f3a3ca87cd 100644 --- a/transports/bifrost-http/handlers/providers.go +++ b/transports/bifrost-http/handlers/providers.go @@ -7,7 +7,6 @@ import ( "errors" "fmt" "net/url" - "regexp" "slices" "sort" "strings" @@ -1037,101 +1036,6 @@ func (h *ProviderHandler) getProviderResponseFromConfig(provider schemas.ModelPr } } -func validatePricingOverrides(overrides []schemas.ProviderPricingOverride) error { - for i, override := range overrides { - if strings.TrimSpace(override.ModelPattern) == "" { - return fmt.Errorf("override[%d]: model_pattern is required", i) - } - - switch override.MatchType { - case schemas.PricingOverrideMatchExact: - if strings.Contains(override.ModelPattern, "*") { - return fmt.Errorf("override[%d]: exact match_type cannot 
include '*'", i) - } - case schemas.PricingOverrideMatchWildcard: - if !strings.Contains(override.ModelPattern, "*") { - return fmt.Errorf("override[%d]: wildcard match_type requires '*' in model_pattern", i) - } - case schemas.PricingOverrideMatchRegex: - if _, err := regexp.Compile(override.ModelPattern); err != nil { - return fmt.Errorf("override[%d]: invalid regex pattern: %w", i, err) - } - default: - return fmt.Errorf("override[%d]: unsupported match_type %q", i, override.MatchType) - } - - for _, requestType := range override.RequestTypes { - if !isSupportedOverrideRequestType(requestType) { - return fmt.Errorf("override[%d]: unsupported request_type %q", i, requestType) - } - } - - if err := validatePricingOverrideNonNegativeFields(i, override); err != nil { - return err - } - } - - return nil -} - -func isSupportedOverrideRequestType(requestType schemas.RequestType) bool { - switch requestType { - case schemas.TextCompletionRequest, - schemas.TextCompletionStreamRequest, - schemas.ChatCompletionRequest, - schemas.ChatCompletionStreamRequest, - schemas.ResponsesRequest, - schemas.ResponsesStreamRequest, - schemas.EmbeddingRequest, - schemas.RerankRequest, - schemas.SpeechRequest, - schemas.SpeechStreamRequest, - schemas.TranscriptionRequest, - schemas.TranscriptionStreamRequest, - schemas.ImageGenerationRequest, - schemas.ImageGenerationStreamRequest: - return true - default: - return false - } -} - -func validatePricingOverrideNonNegativeFields(index int, override schemas.ProviderPricingOverride) error { - optionalValues := map[string]*float64{ - "input_cost_per_token": override.InputCostPerToken, - "output_cost_per_token": override.OutputCostPerToken, - "input_cost_per_video_per_second": override.InputCostPerVideoPerSecond, - "input_cost_per_audio_per_second": override.InputCostPerAudioPerSecond, - "input_cost_per_character": override.InputCostPerCharacter, - "input_cost_per_token_above_128k_tokens": override.InputCostPerTokenAbove128kTokens, - 
"input_cost_per_image_above_128k_tokens": override.InputCostPerImageAbove128kTokens, - "input_cost_per_video_per_second_above_128k_tokens": override.InputCostPerVideoPerSecondAbove128kTokens, - "input_cost_per_audio_per_second_above_128k_tokens": override.InputCostPerAudioPerSecondAbove128kTokens, - "output_cost_per_token_above_128k_tokens": override.OutputCostPerTokenAbove128kTokens, - "input_cost_per_token_above_200k_tokens": override.InputCostPerTokenAbove200kTokens, - "output_cost_per_token_above_200k_tokens": override.OutputCostPerTokenAbove200kTokens, - "cache_creation_input_token_cost_above_200k_tokens": override.CacheCreationInputTokenCostAbove200kTokens, - "cache_read_input_token_cost_above_200k_tokens": override.CacheReadInputTokenCostAbove200kTokens, - "cache_read_input_token_cost": override.CacheReadInputTokenCost, - "cache_creation_input_token_cost": override.CacheCreationInputTokenCost, - "input_cost_per_token_batches": override.InputCostPerTokenBatches, - "output_cost_per_token_batches": override.OutputCostPerTokenBatches, - "input_cost_per_image_token": override.InputCostPerImageToken, - "output_cost_per_image_token": override.OutputCostPerImageToken, - "input_cost_per_image": override.InputCostPerImage, - "output_cost_per_image": override.OutputCostPerImage, - "cache_read_input_image_token_cost": override.CacheReadInputImageTokenCost, - } - - for fieldName, value := range optionalValues { - if value != nil && *value < 0 { - return fmt.Errorf("override[%d]: %s must be non-negative", index, fieldName) - } - } - - return nil -} - func getProviderFromCtx(ctx *fasthttp.RequestCtx) (schemas.ModelProvider, error) { providerValue := ctx.UserValue("provider") if providerValue == nil { diff --git a/transports/bifrost-http/integrations/router.go b/transports/bifrost-http/integrations/router.go index 88109fd448..3050e526fd 100644 --- a/transports/bifrost-http/integrations/router.go +++ b/transports/bifrost-http/integrations/router.go @@ -2314,8 +2314,9 @@ func 
(g *GenericRouter) handleStreaming(ctx *fasthttp.RequestCtx, bifrostCtx *sc // The streaming callback will complete the trace after the stream ends ctx.SetUserValue(schemas.BifrostContextKeyDeferTraceCompletion, true) - // Get the trace completer function for use in the streaming callback - traceCompleter, _ := ctx.UserValue(schemas.BifrostContextKeyTraceCompleter).(func()) + // Capture trace completer BEFORE goroutine — ctx may be recycled inside goroutine. + // Signature: func(transportLogs []schemas.PluginLogEntry) + traceCompleter, _ := ctx.UserValue(schemas.BifrostContextKeyTraceCompleter).(func([]schemas.PluginLogEntry)) // Get stream chunk interceptor for plugin hooks interceptor := g.handlerStore.GetStreamChunkInterceptor() @@ -2336,9 +2337,10 @@ func (g *GenericRouter) handleStreaming(ctx *fasthttp.RequestCtx, bifrostCtx *sc defer schemas.ReleaseHTTPRequest(httpReq) defer func() { // Complete the trace after streaming finishes - // This ensures all spans (including llm.call) are properly ended before the trace is sent to OTEL + // This ensures all spans (including llm.call) are properly ended before the trace is sent to OTEL. + // Router path has no transport post-hooks, so pass nil for transport logs. if traceCompleter != nil { - traceCompleter() + traceCompleter(nil) } }() diff --git a/transports/bifrost-http/lib/config.go b/transports/bifrost-http/lib/config.go index 66ffbd7725..bc39738d2d 100644 --- a/transports/bifrost-http/lib/config.go +++ b/transports/bifrost-http/lib/config.go @@ -127,9 +127,9 @@ type ConfigData struct { // from config.json. Omitting this field or setting it to 2 uses v1.5.0+ semantics: // empty = deny all, ["*"] = allow all. Setting it to 1 restores v1.4.x semantics: // empty = allow all (equivalent to ["*"]). 
- Version int `json:"version,omitempty"` - Client *configstore.ClientConfig `json:"client"` - EncryptionKey *schemas.EnvVar `json:"encryption_key"` + Version int `json:"version,omitempty"` + Client *configstore.ClientConfig `json:"client"` + EncryptionKey *schemas.EnvVar `json:"encryption_key"` // Deprecated: Use GovernanceConfig.AuthConfig instead AuthConfig *configstore.AuthConfig `json:"auth_config,omitempty"` Providers map[string]configstore.ProviderConfig `json:"providers"` @@ -2159,9 +2159,9 @@ func ResolveFrameworkPricingConfig( case val <= 0: // Zero or negative values are meaningless for a sync eligibility threshold. logger.Warn("pricing_sync_interval in config.json is invalid (%d seconds), ignoring — using default (%d seconds)", val, defaultSyncSeconds) - case val < modelcatalog.MinimumPricingSyncIntervalSec: + case val < modelcatalog.MinimumSyncIntervalSec: // Accept but clamp to the schema-declared minimum of 3600 s (1 hour). - clamped := modelcatalog.MinimumPricingSyncIntervalSec + clamped := modelcatalog.MinimumSyncIntervalSec logger.Warn("pricing_sync_interval in config.json is below minimum (%d seconds), clamping to %d seconds", val, clamped) fileSyncSeconds = &clamped default: @@ -2213,11 +2213,11 @@ func ResolveFrameworkPricingConfig( // Ignore and backfill the DB with the correctly resolved value. logger.Warn("pricing_sync_interval in DB is corrupted (%d seconds), ignoring — backfilling with %d seconds", val, *resolvedSyncSeconds) needsDBUpdate = true - } else if val < modelcatalog.MinimumPricingSyncIntervalSec { + } else if val < modelcatalog.MinimumSyncIntervalSec { // DB has a positive value below the minimum — clamp and backfill, // consistent with the file-path validation in Phase 1. 
- logger.Warn("pricing_sync_interval in DB is below minimum (%d seconds), clamping to %d seconds — backfilling", val, modelcatalog.MinimumPricingSyncIntervalSec) - clamped := modelcatalog.MinimumPricingSyncIntervalSec + logger.Warn("pricing_sync_interval in DB is below minimum (%d seconds), clamping to %d seconds — backfilling", val, modelcatalog.MinimumSyncIntervalSec) + clamped := modelcatalog.MinimumSyncIntervalSec resolvedSyncSeconds = &clamped intervalSource = "db" needsDBUpdate = true @@ -4262,4 +4262,4 @@ func DeepCopy[T any](in T) (T, error) { } err = sonic.Unmarshal(b, &out) return out, err -} \ No newline at end of file +} diff --git a/transports/bifrost-http/lib/config_test.go b/transports/bifrost-http/lib/config_test.go index 0b0756127b..cf6b3237b5 100644 --- a/transports/bifrost-http/lib/config_test.go +++ b/transports/bifrost-http/lib/config_test.go @@ -15884,8 +15884,8 @@ func TestResolveFrameworkPricingConfig(t *testing.T) { normalizedTable, normalizedModelCatalog, needsDBUpdate := ResolveFrameworkPricingConfig(nil, fileConfig) require.False(t, needsDBUpdate) - require.Equal(t, modelcatalog.MinimumPricingSyncIntervalSec, *normalizedTable.PricingSyncInterval) - require.Equal(t, modelcatalog.MinimumPricingSyncIntervalSec, *normalizedModelCatalog.PricingSyncInterval) + require.Equal(t, modelcatalog.MinimumSyncIntervalSec, *normalizedTable.PricingSyncInterval) + require.Equal(t, modelcatalog.MinimumSyncIntervalSec, *normalizedModelCatalog.PricingSyncInterval) }) t.Run("file interval of zero is ignored and defaults apply", func(t *testing.T) { diff --git a/transports/bifrost-http/lib/ctx.go b/transports/bifrost-http/lib/ctx.go index 7f6127406d..a2cfec6ee2 100644 --- a/transports/bifrost-http/lib/ctx.go +++ b/transports/bifrost-http/lib/ctx.go @@ -610,4 +610,17 @@ func BuildHTTPRequestFromFastHTTP(ctx *fasthttp.RequestCtx) *schemas.HTTPRequest // Note: Body not copied - for streaming, body was already consumed return req +} + +// 
BuildHTTPResponseFromFastHTTP creates an HTTPResponse snapshot from fasthttp context. +// Only captures status code and headers — body is skipped because for streaming +// responses it is an active io.Reader that cannot be materialized. +// The returned response should be released with schemas.ReleaseHTTPResponse when done. +func BuildHTTPResponseFromFastHTTP(ctx *fasthttp.RequestCtx) *schemas.HTTPResponse { + resp := schemas.AcquireHTTPResponse() + resp.StatusCode = ctx.Response.StatusCode() + for key, value := range ctx.Response.Header.All() { + resp.Headers[string(key)] = string(value) + } + return resp } \ No newline at end of file diff --git a/transports/bifrost-http/lib/pricing_integration_test.go b/transports/bifrost-http/lib/pricing_integration_test.go index 42c893abcc..d6c1865178 100644 --- a/transports/bifrost-http/lib/pricing_integration_test.go +++ b/transports/bifrost-http/lib/pricing_integration_test.go @@ -48,12 +48,12 @@ func (l *capturingLogger) append(level, msg string, args ...any) { l.mu.Unlock() } -func (l *capturingLogger) Debug(msg string, args ...any) { l.append("DEBUG", msg, args...) } -func (l *capturingLogger) Info(msg string, args ...any) { l.append("INFO", msg, args...) } -func (l *capturingLogger) Warn(msg string, args ...any) { l.append("WARN", msg, args...) } -func (l *capturingLogger) Error(msg string, args ...any) { l.append("ERROR", msg, args...) } -func (l *capturingLogger) Fatal(msg string, args ...any) { l.append("FATAL", msg, args...) } -func (l *capturingLogger) SetLevel(_ schemas.LogLevel) {} +func (l *capturingLogger) Debug(msg string, args ...any) { l.append("DEBUG", msg, args...) } +func (l *capturingLogger) Info(msg string, args ...any) { l.append("INFO", msg, args...) } +func (l *capturingLogger) Warn(msg string, args ...any) { l.append("WARN", msg, args...) } +func (l *capturingLogger) Error(msg string, args ...any) { l.append("ERROR", msg, args...) 
} +func (l *capturingLogger) Fatal(msg string, args ...any) { l.append("FATAL", msg, args...) } +func (l *capturingLogger) SetLevel(_ schemas.LogLevel) {} func (l *capturingLogger) SetOutputType(_ schemas.LoggerOutputType) {} func (l *capturingLogger) LogHTTPRequest(_ schemas.LogLevel, _ string) schemas.LogEventBuilder { return schemas.NoopLogEvent @@ -141,12 +141,12 @@ func initWithCapture( func getLogger() schemas.Logger { return logger } // pt is a generic pointer helper. -func ptStr(s string) *string { return &s } -func ptI64(n int64) *int64 { return &n } +func ptStr(s string) *string { return &s } +func ptI64(n int64) *int64 { return &n } func ptF64(f float64) *float64 { return &f } // defaultSyncSecs is the production default converted to seconds. -var defaultSyncSecs = int64(modelcatalog.DefaultPricingSyncInterval.Seconds()) +var defaultSyncSecs = int64(modelcatalog.DefaultSyncInterval.Seconds()) // ============================================================================= // STEP 2 — Baseline: no config.json, no DB → built-in defaults @@ -228,7 +228,7 @@ func TestPricingE2E_Step4A_TooLow_ClampsTo3600(t *testing.T) { tableOut, catalogOut, _ := resolveWithCapture(t, log, nil, fileConfig) - require.Equal(t, modelcatalog.MinimumPricingSyncIntervalSec, *tableOut.PricingSyncInterval, + require.Equal(t, modelcatalog.MinimumSyncIntervalSec, *tableOut.PricingSyncInterval, "too-low interval must be clamped to minimum 3600 s") require.Equal(t, *tableOut.PricingSyncInterval, *catalogOut.PricingSyncInterval) @@ -511,8 +511,9 @@ func TestPricingE2E_Step7_RuntimeInterval_StoredCorrectly(t *testing.T) { PricingURL: ptStr("https://example.com/pricing.json"), PricingSyncInterval: &syncSeconds, } - mc, err := modelcatalog.Init(ctx, cfg, store, noSyncFunc, clg) + mc, err := modelcatalog.Init(ctx, cfg, store, clg) require.NoError(t, err) + mc.SetShouldSyncGate(noSyncFunc) defer mc.Cleanup() // The startup Info log must reflect the correct duration. 
@@ -538,8 +539,9 @@ func TestPricingE2E_Step7_RuntimeInterval_24h_Default(t *testing.T) { // Nil PricingURL: defaults apply. noSyncFunc prevents real HTTP requests. cfg := &modelcatalog.Config{} - mc, err := modelcatalog.Init(ctx, cfg, store, noSyncFunc, clg) + mc, err := modelcatalog.Init(ctx, cfg, store, clg) require.NoError(t, err) + mc.SetShouldSyncGate(noSyncFunc) defer mc.Cleanup() // Must show 24h default. @@ -739,9 +741,9 @@ func TestPricingE2E_Step9B_MissingEnvURL_NotReplacedWithDefault(t *testing.T) { func TestPricingE2E_Step10_NoNilPointers_AllInputCombinations(t *testing.T) { type tc struct { - name string - db *configstoreTables.TableFrameworkConfig - file *framework.FrameworkConfig + name string + db *configstoreTables.TableFrameworkConfig + file *framework.FrameworkConfig } cases := []tc{ {"nil/nil", nil, nil}, @@ -905,8 +907,9 @@ func TestPricingE2E_Step10_SecondsToDurationConversion(t *testing.T) { syncSeconds := int64(3600) cfg := &modelcatalog.Config{PricingSyncInterval: &syncSeconds} // noSyncFunc prevents real HTTP requests to the pricing URL during this unit test. 
- mc, err := modelcatalog.Init(ctx, cfg, store, noSyncFunc, clg) + mc, err := modelcatalog.Init(ctx, cfg, store, clg) require.NoError(t, err) + mc.SetShouldSyncGate(noSyncFunc) defer mc.Cleanup() // The critical assertion: if the old *time.Duration bug were present, @@ -919,8 +922,9 @@ func TestPricingE2E_Step10_SecondsToDurationConversion(t *testing.T) { SetLogger(clg2) syncSeconds2 := int64(7200) cfg2 := &modelcatalog.Config{PricingSyncInterval: &syncSeconds2} - mc2, err := modelcatalog.Init(ctx, cfg2, store, noSyncFunc, clg2) + mc2, err := modelcatalog.Init(ctx, cfg2, store, clg2) require.NoError(t, err) + mc2.SetShouldSyncGate(noSyncFunc) defer mc2.Cleanup() SetLogger(prev) diff --git a/ui/app/workspace/dashboard/page.tsx b/ui/app/workspace/dashboard/page.tsx index 6cc995fa24..cc4f78295a 100644 --- a/ui/app/workspace/dashboard/page.tsx +++ b/ui/app/workspace/dashboard/page.tsx @@ -288,13 +288,7 @@ export default function DashboardPage() { const fetchFilters = { filters }; - const [ - histogramResult, - tokenResult, - costResult, - modelResult, - latencyResult, - ] = await Promise.all([ + const [histogramResult, tokenResult, costResult, modelResult, latencyResult] = await Promise.all([ triggerHistogram(fetchFilters, false), triggerTokens(fetchFilters, false), triggerCost(fetchFilters, false), @@ -312,14 +306,7 @@ export default function DashboardPage() { setLoadingModels(false); setLatencyData(latencyResult.data ?? 
null); setLoadingLatency(false); - }, [ - filters, - triggerHistogram, - triggerTokens, - triggerCost, - triggerModels, - triggerLatency, - ]); + }, [filters, triggerHistogram, triggerTokens, triggerCost, triggerModels, triggerLatency]); // Fetch Provider Usage tab data (3 calls) const fetchProviderData = useCallback(async () => { @@ -329,11 +316,7 @@ export default function DashboardPage() { const fetchFilters = { filters }; - const [ - providerCostResult, - providerTokenResult, - providerLatencyResult, - ] = await Promise.all([ + const [providerCostResult, providerTokenResult, providerLatencyResult] = await Promise.all([ triggerProviderCost(fetchFilters, false), triggerProviderTokens(fetchFilters, false), triggerProviderLatency(fetchFilters, false), @@ -345,12 +328,7 @@ export default function DashboardPage() { setLoadingProviderTokens(false); setProviderLatencyData(providerLatencyResult.data ?? null); setLoadingProviderLatency(false); - }, [ - filters, - triggerProviderCost, - triggerProviderTokens, - triggerProviderLatency, - ]); + }, [filters, triggerProviderCost, triggerProviderTokens, triggerProviderLatency]); // Fetch MCP data const fetchMcpData = useCallback(async () => { @@ -382,7 +360,6 @@ export default function DashboardPage() { setLoadingRankings(false); }, [filters, triggerRankings]); - // --- Lazy-load refs: each tab fetches only once per filter change --- const overviewFetchedRef = useRef(false); const overviewLoadingRef = useRef(false); const overviewGenRef = useRef(0); @@ -408,14 +385,16 @@ export default function DashboardPage() { if (overviewLoadingRef.current) return overviewPromiseRef.current ?? 
undefined; const gen = overviewGenRef.current; overviewLoadingRef.current = true; - const promise = fetchOverviewData().then( - () => { if (gen === overviewGenRef.current) overviewFetchedRef.current = true; }, - ).finally(() => { - if (gen === overviewGenRef.current) { - overviewLoadingRef.current = false; - overviewPromiseRef.current = null; - } - }); + const promise = fetchOverviewData() + .then(() => { + if (gen === overviewGenRef.current) overviewFetchedRef.current = true; + }) + .finally(() => { + if (gen === overviewGenRef.current) { + overviewLoadingRef.current = false; + overviewPromiseRef.current = null; + } + }); overviewPromiseRef.current = promise; return promise; }, [fetchOverviewData]); @@ -425,14 +404,16 @@ export default function DashboardPage() { if (providerLoadingRef.current) return providerPromiseRef.current ?? undefined; const gen = providerGenRef.current; providerLoadingRef.current = true; - const promise = fetchProviderData().then( - () => { if (gen === providerGenRef.current) providerFetchedRef.current = true; }, - ).finally(() => { - if (gen === providerGenRef.current) { - providerLoadingRef.current = false; - providerPromiseRef.current = null; - } - }); + const promise = fetchProviderData() + .then(() => { + if (gen === providerGenRef.current) providerFetchedRef.current = true; + }) + .finally(() => { + if (gen === providerGenRef.current) { + providerLoadingRef.current = false; + providerPromiseRef.current = null; + } + }); providerPromiseRef.current = promise; return promise; }, [fetchProviderData]); @@ -442,14 +423,16 @@ export default function DashboardPage() { if (mcpLoadingRef.current) return mcpPromiseRef.current ?? 
undefined; const gen = mcpGenRef.current; mcpLoadingRef.current = true; - const promise = fetchMcpData().then( - () => { if (gen === mcpGenRef.current) mcpFetchedRef.current = true; }, - ).finally(() => { - if (gen === mcpGenRef.current) { - mcpLoadingRef.current = false; - mcpPromiseRef.current = null; - } - }); + const promise = fetchMcpData() + .then(() => { + if (gen === mcpGenRef.current) mcpFetchedRef.current = true; + }) + .finally(() => { + if (gen === mcpGenRef.current) { + mcpLoadingRef.current = false; + mcpPromiseRef.current = null; + } + }); mcpPromiseRef.current = promise; return promise; }, [fetchMcpData]); @@ -459,14 +442,16 @@ export default function DashboardPage() { if (rankingsLoadingRef.current) return rankingsPromiseRef.current ?? undefined; const gen = rankingsGenRef.current; rankingsLoadingRef.current = true; - const promise = fetchRankingsData().then( - () => { if (gen === rankingsGenRef.current) rankingsFetchedRef.current = true; }, - ).finally(() => { - if (gen === rankingsGenRef.current) { - rankingsLoadingRef.current = false; - rankingsPromiseRef.current = null; - } - }); + const promise = fetchRankingsData() + .then(() => { + if (gen === rankingsGenRef.current) rankingsFetchedRef.current = true; + }) + .finally(() => { + if (gen === rankingsGenRef.current) { + rankingsLoadingRef.current = false; + rankingsPromiseRef.current = null; + } + }); rankingsPromiseRef.current = promise; return promise; }, [fetchRankingsData]); @@ -646,12 +631,7 @@ export default function DashboardPage() { // Preload all tab data (used by CSV and PDF export) const handlePreloadData = useCallback(async () => { - await Promise.all([ - ensureOverviewDataLoaded(), - ensureProviderDataLoaded(), - ensureRankingsDataLoaded(), - ensureMcpDataLoaded(), - ]); + await Promise.all([ensureOverviewDataLoaded(), ensureProviderDataLoaded(), ensureRankingsDataLoaded(), ensureMcpDataLoaded()]); }, [ensureOverviewDataLoaded, ensureProviderDataLoaded, ensureRankingsDataLoaded, 
ensureMcpDataLoaded]); // PDF export mode — when true, all TabsContent are force-mounted so @@ -679,9 +659,7 @@ export default function DashboardPage() { // Radix sets `hidden` on inactive force-mounted TabsContent. // Temporarily remove it so html2canvas can capture them. - const hiddenTabs = document.querySelectorAll( - '[data-slot="tabs-content"][hidden]', - ); + const hiddenTabs = document.querySelectorAll('[data-slot="tabs-content"][hidden]'); hiddenTabsRef.current = Array.from(hiddenTabs); for (const tab of hiddenTabs) { tab.removeAttribute("hidden"); @@ -704,12 +682,7 @@ export default function DashboardPage() { }); }); - const ids = [ - "dashboard-section-overview", - "dashboard-section-provider-usage", - "dashboard-section-rankings", - "dashboard-section-mcp", - ]; + const ids = ["dashboard-section-overview", "dashboard-section-provider-usage", "dashboard-section-rankings", "dashboard-section-mcp"]; return ids.map((id) => document.getElementById(id)).filter(Boolean) as HTMLElement[]; }, [handlePreloadData]); @@ -758,7 +731,12 @@ export default function DashboardPage() {

Dashboard

- + {(urlState.tab === "overview" || urlState.tab === "provider-usage" || urlState.tab === "rankings") && ( )} @@ -831,103 +809,103 @@ export default function DashboardPage() { {/* Overview Tab */}
- +
{/* Provider Usage Tab */}
- +
{/* Model Rankings Tab */}
- +
{/* MCP Tab */}
- +