diff --git a/core/bifrost.go b/core/bifrost.go index 21441b517d..ffc31d3feb 100644 --- a/core/bifrost.go +++ b/core/bifrost.go @@ -1026,9 +1026,10 @@ func (bifrost *Bifrost) OCRRequest(ctx *schemas.BifrostContext, req *schemas.Bif Message: "document type not provided for ocr request", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.OCRRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.OCRRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } @@ -1039,9 +1040,10 @@ func (bifrost *Bifrost) OCRRequest(ctx *schemas.BifrostContext, req *schemas.Bif Message: "document_url not provided for document_url type ocr request", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.OCRRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.OCRRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } @@ -1052,9 +1054,10 @@ func (bifrost *Bifrost) OCRRequest(ctx *schemas.BifrostContext, req *schemas.Bif Message: "image_url not provided for image_url type ocr request", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.OCRRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.OCRRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } diff --git a/core/bifrost_test.go b/core/bifrost_test.go index 0c020ba04e..0dd46a62b8 100644 --- a/core/bifrost_test.go +++ b/core/bifrost_test.go @@ -605,8 +605,8 @@ func TestHandleProviderRequest_OCROperationNotAllowed(t *testing.T) { if err.ExtraFields.RequestType != schemas.OCRRequest { t.Fatalf("expected OCR request type, got %q", err.ExtraFields.RequestType) } - if err.ExtraFields.ModelRequested != "custom-mistral/mistral-ocr-latest" { - t.Fatalf("expected model to be preserved, got %q", err.ExtraFields.ModelRequested) + if err.ExtraFields.OriginalModelRequested != "custom-mistral/mistral-ocr-latest" { + t.Fatalf("expected model to be preserved, got %q", err.ExtraFields.OriginalModelRequested) } } diff --git a/core/providers/azure/azure.go b/core/providers/azure/azure.go index 451f9883ca..82da1326d7 100644 --- a/core/providers/azure/azure.go +++ b/core/providers/azure/azure.go @@ -2699,14 +2699,6 @@ func (provider *AzureProvider) Passthrough( key schemas.Key, req *schemas.BifrostPassthroughRequest, ) (*schemas.BifrostPassthroughResponse, *schemas.BifrostError) { - if key.AzureKeyConfig == nil { - return nil, providerUtils.NewConfigurationError("azure key config not set", provider.GetProviderKey()) - } - - if key.AzureKeyConfig.Endpoint.GetValue() == "" { - return nil, providerUtils.NewConfigurationError("endpoint not set", provider.GetProviderKey()) - } - url := provider.buildPassthroughURL(key, req.Path, req.RawQuery) fasthttpReq := fasthttp.AcquireRequest() @@ -2743,7 +2735,7 @@ func (provider *AzureProvider) Passthrough( body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to decode response body", err, provider.GetProviderKey()) + return nil, providerUtils.NewBifrostOperationError("failed to decode response body", err) } // Remove wire-level encoding headers after decoding; downstream should recalculate them for the buffered body. @@ -2759,9 +2751,6 @@ func (provider *AzureProvider) Passthrough( Body: body, } - bifrostResponse.ExtraFields.Provider = provider.GetProviderKey() - bifrostResponse.ExtraFields.ModelRequested = req.Model - bifrostResponse.ExtraFields.RequestType = schemas.PassthroughRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { @@ -2779,14 +2768,6 @@ func (provider *AzureProvider) PassthroughStream( key schemas.Key, req *schemas.BifrostPassthroughRequest, ) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { - if key.AzureKeyConfig == nil { - return nil, providerUtils.NewConfigurationError("azure key config not set", provider.GetProviderKey()) - } - - if key.AzureKeyConfig.Endpoint.GetValue() == "" { - return nil, providerUtils.NewConfigurationError("endpoint not set", provider.GetProviderKey()) - } - url := provider.buildPassthroughURL(key, req.Path, req.RawQuery) fasthttpReq := fasthttp.AcquireRequest() @@ -2833,9 +2814,9 @@ func (provider *AzureProvider) PassthroughStream( } } if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) { - return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, provider.GetProviderKey()) + return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err) } - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, provider.GetProviderKey()) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err) } headers := providerUtils.ExtractProviderResponseHeaders(resp) @@ -2843,21 +2824,13 @@ func (provider *AzureProvider) PassthroughStream( rawBodyStream := resp.BodyStream() if rawBodyStream == nil { providerUtils.ReleaseStreamingResponse(resp) - return nil, providerUtils.NewBifrostOperationError( - "provider returned an empty stream body", - fmt.Errorf("provider returned an empty stream body"), - provider.GetProviderKey(), - ) + return nil, providerUtils.NewBifrostOperationError("provider returned an empty stream body", fmt.Errorf("provider returned an empty stream body")) } bodyStream, stopIdleTimeout := providerUtils.NewIdleTimeoutReader(rawBodyStream, rawBodyStream, providerUtils.GetStreamIdleTimeout(ctx)) stopCancellation := providerUtils.SetupStreamCancellation(ctx, rawBodyStream, provider.logger) - extraFields := schemas.BifrostResponseExtraFields{ - Provider: provider.GetProviderKey(), - ModelRequested: req.Model, - RequestType: schemas.PassthroughStreamRequest, - } + extraFields := schemas.BifrostResponseExtraFields{} if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { providerUtils.ParseAndSetRawRequestIfJSON(fasthttpReq, &extraFields) } @@ -2867,9 +2840,9 @@ func (provider *AzureProvider) PassthroughStream( go func() { defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, ch, provider.GetProviderKey(), req.Model, schemas.PassthroughStreamRequest, provider.logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, ch, provider.logger) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, ch, provider.GetProviderKey(), req.Model, schemas.PassthroughStreamRequest, provider.logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, ch, provider.logger) } close(ch) }() @@ -2918,7 +2891,7 @@ func (provider *AzureProvider) PassthroughStream( } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) extraFields.Latency = time.Since(startTime).Milliseconds() - providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, ch, schemas.PassthroughStreamRequest, provider.GetProviderKey(), req.Model, provider.logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, ch, provider.logger) return } } diff --git a/core/providers/mistral/custom_provider_test.go b/core/providers/mistral/custom_provider_test.go index 0d8e283f51..9015230544 100644 --- a/core/providers/mistral/custom_provider_test.go +++ b/core/providers/mistral/custom_provider_test.go @@ -25,7 +25,7 @@ func TestParseMistralError_UsesExportedConverterMetadata(t *testing.T) { resp.SetStatusCode(http.StatusBadRequest) resp.SetBodyString(`{"message":"invalid request","type":"invalid_request_error","code":"bad_request"}`) - bifrostErr := ParseMistralError(resp, schemas.OCRRequest, customMistralProviderName, "mistral-ocr-latest") + bifrostErr := ParseMistralError(resp) require.NotNil(t, bifrostErr) require.NotNil(t, bifrostErr.Error) @@ -33,8 +33,6 @@ func TestParseMistralError_UsesExportedConverterMetadata(t *testing.T) { assert.Equal(t, schemas.Ptr("invalid_request_error"), bifrostErr.Error.Type) assert.Equal(t, schemas.Ptr("bad_request"), bifrostErr.Error.Code) assert.Equal(t, customMistralProviderName, bifrostErr.ExtraFields.Provider) - assert.Equal(t, schemas.OCRRequest, bifrostErr.ExtraFields.RequestType) - assert.Equal(t, "mistral-ocr-latest", bifrostErr.ExtraFields.ModelRequested) } func TestMistralProvider_CustomAliasChatStreamUsesBaseCompatibilityAndAliasMetadata(t *testing.T) { @@ -156,5 +154,5 @@ func TestMistralProvider_CustomAliasEmbeddingReportsAliasMetadata(t *testing.T) require.NotNil(t, response) assert.Equal(t, customMistralProviderName, response.ExtraFields.Provider) - assert.Equal(t, "codestral-embed", response.ExtraFields.ModelRequested) + assert.Equal(t, "codestral-embed", response.ExtraFields.ResolvedModelUsed) } diff --git a/core/providers/mistral/errors.go b/core/providers/mistral/errors.go index 2ae260eb9a..cbbd40a560 100644 --- a/core/providers/mistral/errors.go +++ b/core/providers/mistral/errors.go @@ -19,7 +19,7 @@ type MistralErrorResponse struct { } // ParseMistralError parses Mistral-specific error responses. -func ParseMistralError(resp *fasthttp.Response, requestType schemas.RequestType, providerName schemas.ModelProvider, model string) *schemas.BifrostError { +func ParseMistralError(resp *fasthttp.Response) *schemas.BifrostError { var errorResp MistralErrorResponse bifrostErr := providerUtils.HandleProviderAPIError(resp, &errorResp) if bifrostErr == nil { @@ -67,9 +67,5 @@ func ParseMistralError(resp *fasthttp.Response, requestType schemas.RequestType, } } - bifrostErr.ExtraFields.Provider = providerName - bifrostErr.ExtraFields.ModelRequested = model - bifrostErr.ExtraFields.RequestType = requestType - return bifrostErr } diff --git a/core/providers/mistral/mistral.go b/core/providers/mistral/mistral.go index 9befe35fee..a7a5ab4a29 100644 --- a/core/providers/mistral/mistral.go +++ b/core/providers/mistral/mistral.go @@ -101,7 +101,7 @@ func (provider *MistralProvider) listModelsByKey(ctx *schemas.BifrostContext, ke // Handle error response if resp.StatusCode() != fasthttp.StatusOK { - bifrostErr := ParseMistralError(resp, schemas.ListModelsRequest, providerName, "") + bifrostErr := ParseMistralError(resp) return nil, bifrostErr } @@ -259,25 +259,23 @@ func (provider *MistralProvider) Rerank(ctx *schemas.BifrostContext, key schemas // OCR performs an OCR request to the Mistral API. // It sends a JSON request to Mistral's OCR endpoint and returns the extracted content. func (provider *MistralProvider) OCR(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostOCRRequest) (*schemas.BifrostOCRResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - // Convert Bifrost request to Mistral format mistralReq := ToMistralOCRRequest(request) if mistralReq == nil { - return nil, providerUtils.NewBifrostOperationError("ocr request input is not provided", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("ocr request input is not provided", nil) } // Marshal request body requestBody, err := sonic.Marshal(mistralReq) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } // Merge extra params into JSON payload if len(mistralReq.ExtraParams) > 0 { requestBody, err = providerUtils.MergeExtraParamsIntoJSON(requestBody, mistralReq.ExtraParams) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } } @@ -313,13 +311,12 @@ func (provider *MistralProvider) OCR(ctx *schemas.BifrostContext, key schemas.Ke // Handle error response if resp.StatusCode() != fasthttp.StatusOK { - provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - return nil, ParseMistralError(resp, schemas.OCRRequest, providerName, request.Model) + return nil, ParseMistralError(resp) } responseBody, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } // Check for empty response @@ -347,20 +344,17 @@ func (provider *MistralProvider) OCR(ctx *schemas.BifrostContext, key schemas.Ke }, } } - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err) } // Convert to Bifrost format response := mistralResponse.ToBifrostOCRResponse() if response == nil { - return nil, providerUtils.NewBifrostOperationError("failed to convert ocr response", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to convert ocr response", nil) } // Set extra fields response.ExtraFields.Latency = latency.Milliseconds() - response.ExtraFields.RequestType = schemas.OCRRequest - response.ExtraFields.Provider = providerName - response.ExtraFields.ModelRequested = request.Model // Set raw response if enabled if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) { @@ -382,8 +376,6 @@ func (provider *MistralProvider) SpeechStream(ctx *schemas.BifrostContext, postH // It creates a multipart form with the audio file and sends it to Mistral's transcription endpoint. // Returns the transcribed text and metadata, or an error if the request fails. func (provider *MistralProvider) Transcription(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (*schemas.BifrostTranscriptionResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - // Convert Bifrost request to Mistral format mistralReq := ToMistralTranscriptionRequest(request) if mistralReq == nil { @@ -391,7 +383,7 @@ func (provider *MistralProvider) Transcription(ctx *schemas.BifrostContext, key } // Create multipart form body - body, contentType, bifrostErr := createMistralTranscriptionMultipartBody(mistralReq, providerName) + body, contentType, bifrostErr := createMistralTranscriptionMultipartBody(mistralReq, provider.GetProviderKey()) if bifrostErr != nil { return nil, bifrostErr } @@ -423,8 +415,7 @@ func (provider *MistralProvider) Transcription(ctx *schemas.BifrostContext, key // Handle error response if resp.StatusCode() != fasthttp.StatusOK { - provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - return nil, ParseMistralError(resp, schemas.TranscriptionRequest, providerName, request.Model) + return nil, ParseMistralError(resp) } responseBody, err := providerUtils.CheckAndDecodeBody(resp) @@ -555,8 +546,7 @@ func (provider *MistralProvider) TranscriptionStream(ctx *schemas.BifrostContext // Check for HTTP errors if resp.StatusCode() != fasthttp.StatusOK { defer providerUtils.ReleaseStreamingResponse(resp) - provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - return nil, ParseMistralError(resp, schemas.TranscriptionStreamRequest, providerName, request.Model) + return nil, ParseMistralError(resp) } // Large payload streaming passthrough — pipe raw upstream SSE to client diff --git a/core/providers/mistral/ocr_test.go b/core/providers/mistral/ocr_test.go index c88bca5f48..ccf7223e4c 100644 --- a/core/providers/mistral/ocr_test.go +++ b/core/providers/mistral/ocr_test.go @@ -436,9 +436,6 @@ func TestOCRWithMockServer(t *testing.T) { assert.Equal(t, 1, resp.Pages[1].Index) require.NotNil(t, resp.UsageInfo) assert.Equal(t, 2, resp.UsageInfo.PagesProcessed) - assert.Equal(t, schemas.OCRRequest, resp.ExtraFields.RequestType) - assert.Equal(t, schemas.Mistral, resp.ExtraFields.Provider) - assert.Equal(t, "mistral-ocr-latest", resp.ExtraFields.ModelRequested) }, }, { @@ -503,9 +500,6 @@ func TestOCRWithMockServer(t *testing.T) { assert.Equal(t, "server_error", *err.Error.Type) require.NotNil(t, err.Error.Code) assert.Equal(t, "internal_error", *err.Error.Code) - assert.Equal(t, schemas.Mistral, err.ExtraFields.Provider) - assert.Equal(t, schemas.OCRRequest, err.ExtraFields.RequestType) - assert.Equal(t, "mistral-ocr-latest", err.ExtraFields.ModelRequested) }, }, { @@ -757,7 +751,5 @@ func TestMistralOCRIntegration(t *testing.T) { require.NotEmpty(t, resp.Pages, "Expected at least one page") assert.Equal(t, 0, resp.Pages[0].Index) assert.NotEmpty(t, resp.Pages[0].Markdown, "Expected non-empty markdown for page 0") - assert.Equal(t, schemas.OCRRequest, resp.ExtraFields.RequestType) - assert.Equal(t, schemas.Mistral, resp.ExtraFields.Provider) assert.Greater(t, resp.ExtraFields.Latency, int64(0)) } diff --git a/core/schemas/chatcompletions.go b/core/schemas/chatcompletions.go index b8e1590da2..6a107e3902 100644 --- a/core/schemas/chatcompletions.go +++ b/core/schemas/chatcompletions.go @@ -80,7 +80,8 @@ func (cr *BifrostChatResponse) ToTextCompletionResponse() *BifrostTextCompletion RequestType: TextCompletionRequest, ChunkIndex: cr.ExtraFields.ChunkIndex, Provider: cr.ExtraFields.Provider, - ModelRequested: cr.ExtraFields.ModelRequested, + OriginalModelRequested: cr.ExtraFields.OriginalModelRequested, + ResolvedModelUsed: cr.ExtraFields.ResolvedModelUsed, Latency: cr.ExtraFields.Latency, RawResponse: cr.ExtraFields.RawResponse, CacheDebug: cr.ExtraFields.CacheDebug, @@ -113,7 +114,8 @@ func (cr *BifrostChatResponse) ToTextCompletionResponse() *BifrostTextCompletion RequestType: TextCompletionRequest, ChunkIndex: cr.ExtraFields.ChunkIndex, Provider: cr.ExtraFields.Provider, - ModelRequested: cr.ExtraFields.ModelRequested, + OriginalModelRequested: cr.ExtraFields.OriginalModelRequested, + ResolvedModelUsed: cr.ExtraFields.ResolvedModelUsed, Latency: cr.ExtraFields.Latency, RawResponse: cr.ExtraFields.RawResponse, CacheDebug: cr.ExtraFields.CacheDebug, @@ -149,7 +151,8 @@ func (cr *BifrostChatResponse) ToTextCompletionResponse() *BifrostTextCompletion RequestType: TextCompletionRequest, ChunkIndex: cr.ExtraFields.ChunkIndex, Provider: cr.ExtraFields.Provider, - ModelRequested: cr.ExtraFields.ModelRequested, + OriginalModelRequested: cr.ExtraFields.OriginalModelRequested, + ResolvedModelUsed: cr.ExtraFields.ResolvedModelUsed, Latency: cr.ExtraFields.Latency, RawResponse: cr.ExtraFields.RawResponse, CacheDebug: cr.ExtraFields.CacheDebug, @@ -166,13 +169,15 @@ func (cr *BifrostChatResponse) ToTextCompletionResponse() *BifrostTextCompletion SystemFingerprint: cr.SystemFingerprint, Usage: cr.Usage, ExtraFields: BifrostResponseExtraFields{ - RequestType: TextCompletionRequest, - ChunkIndex: cr.ExtraFields.ChunkIndex, - Provider: cr.ExtraFields.Provider, - ModelRequested: cr.ExtraFields.ModelRequested, - Latency: cr.ExtraFields.Latency, - RawResponse: cr.ExtraFields.RawResponse, - CacheDebug: cr.ExtraFields.CacheDebug, + RequestType: TextCompletionRequest, + ChunkIndex: cr.ExtraFields.ChunkIndex, + Provider: cr.ExtraFields.Provider, + OriginalModelRequested: cr.ExtraFields.OriginalModelRequested, + ResolvedModelUsed: cr.ExtraFields.ResolvedModelUsed, + Latency: cr.ExtraFields.Latency, + RawResponse: cr.ExtraFields.RawResponse, + CacheDebug: cr.ExtraFields.CacheDebug, + ProviderResponseHeaders: cr.ExtraFields.ProviderResponseHeaders, }, } } diff --git a/framework/modelcatalog/main.go b/framework/modelcatalog/main.go index 811bf471dd..e9a2b811d4 100644 --- a/framework/modelcatalog/main.go +++ b/framework/modelcatalog/main.go @@ -77,7 +77,7 @@ func Init(ctx context.Context, config *Config, configStore configstore.ConfigSto // Log the active interval and the scheduler's actual check frequency so operators // are not surprised that setting interval=1h does not mean checks happen every second. // Actual syncs occur when: (1) the 1-hour ticker fires AND (2) time.Since(lastSync) >= pricingSyncInterval. - logger.Info("pricing sync interval set to %v (scheduler checks every %v)", pricingSyncInterval, syncWorkerTickerPeriod) + logger.Info("pricing sync interval set to %v (scheduler checks every %v)", syncInterval, syncWorkerTickerPeriod) mc := &ModelCatalog{ pricingURL: pricingURL, diff --git a/framework/modelcatalog/pricing.go b/framework/modelcatalog/pricing.go index 9f4d7827b9..e1a961e713 100644 --- a/framework/modelcatalog/pricing.go +++ b/framework/modelcatalog/pricing.go @@ -295,8 +295,6 @@ func (mc *ModelCatalog) calculateBaseCost(result *schemas.BifrostResponse, scope // Route to the appropriate compute function switch requestType { - case schemas.ChatCompletionRequest, schemas.TextCompletionRequest, schemas.ResponsesRequest, schemas.RealtimeRequest: - return computeTextCost(pricing, input.usage) case schemas.ChatCompletionRequest, schemas.TextCompletionRequest, schemas.ResponsesRequest: return computeTextCost(pricing, input.usage, input.tier) case schemas.EmbeddingRequest: diff --git a/framework/modelcatalog/pricing_test.go b/framework/modelcatalog/pricing_test.go index e9cdc5507b..fe832f025a 100644 --- a/framework/modelcatalog/pricing_test.go +++ b/framework/modelcatalog/pricing_test.go @@ -190,10 +190,10 @@ func TestComputeTextCost_Below200kUsesBaseRate(t *testing.T) { func TestComputeTextCost_Tiered272k(t *testing.T) { p := chatPricing(0.000003, 0.000015) - p.InputCostPerTokenAbove200kTokens = ptr(0.000006) - p.OutputCostPerTokenAbove200kTokens = ptr(0.00003) - p.InputCostPerTokenAbove272kTokens = ptr(0.000009) - p.OutputCostPerTokenAbove272kTokens = ptr(0.000045) + p.InputCostPerTokenAbove200kTokens = new(0.000006) + p.OutputCostPerTokenAbove200kTokens = new(0.00003) + p.InputCostPerTokenAbove272kTokens = new(0.000009) + p.OutputCostPerTokenAbove272kTokens = new(0.000045) usage := &schemas.BifrostLLMUsage{ PromptTokens: 250000, @@ -210,10 +210,10 @@ func TestComputeTextCost_Tiered272k(t *testing.T) { func TestComputeTextCost_Between200kAnd272kUses200kRate(t *testing.T) { p := chatPricing(0.000003, 0.000015) - p.InputCostPerTokenAbove200kTokens = ptr(0.000006) - p.OutputCostPerTokenAbove200kTokens = ptr(0.00003) - p.InputCostPerTokenAbove272kTokens = ptr(0.000009) - p.OutputCostPerTokenAbove272kTokens = ptr(0.000045) + p.InputCostPerTokenAbove200kTokens = new(0.000006) + p.OutputCostPerTokenAbove200kTokens = new(0.00003) + p.InputCostPerTokenAbove272kTokens = new(0.000009) + p.OutputCostPerTokenAbove272kTokens = new(0.000045) usage := &schemas.BifrostLLMUsage{ PromptTokens: 200000, @@ -230,10 +230,10 @@ func TestComputeTextCost_Between200kAnd272kUses200kRate(t *testing.T) { func TestComputeTextCost_272kTierWithCacheRead(t *testing.T) { p := chatPricing(0.000003, 0.000015) - p.InputCostPerTokenAbove272kTokens = ptr(0.000009) - p.OutputCostPerTokenAbove272kTokens = ptr(0.000045) - p.CacheReadInputTokenCost = ptr(0.0000003) - p.CacheReadInputTokenCostAbove272kTokens = ptr(0.0000009) + p.InputCostPerTokenAbove272kTokens = new(0.000009) + p.OutputCostPerTokenAbove272kTokens = new(0.000045) + p.CacheReadInputTokenCost = new(0.0000003) + p.CacheReadInputTokenCostAbove272kTokens = new(0.0000009) usage := &schemas.BifrostLLMUsage{ PromptTokens: 250000, @@ -1429,15 +1429,15 @@ func TestCalculateCost_272kTier_EndToEnd(t *testing.T) { Model: "claude-3-7-sonnet", Provider: "anthropic", Mode: "chat", - InputCostPerToken: 0.000003, - OutputCostPerToken: 0.000015, - InputCostPerTokenAbove200kTokens: ptr(0.000006), - OutputCostPerTokenAbove200kTokens: ptr(0.00003), - InputCostPerTokenAbove272kTokens: ptr(0.000009), - OutputCostPerTokenAbove272kTokens: ptr(0.000045), - CacheReadInputTokenCost: ptr(0.0000003), - CacheReadInputTokenCostAbove200kTokens: ptr(0.0000006), - CacheReadInputTokenCostAbove272kTokens: ptr(0.0000009), + InputCostPerToken: new(0.000003), + OutputCostPerToken: new(0.000015), + InputCostPerTokenAbove200kTokens: new(0.000006), + OutputCostPerTokenAbove200kTokens: new(0.00003), + InputCostPerTokenAbove272kTokens: new(0.000009), + OutputCostPerTokenAbove272kTokens: new(0.000045), + CacheReadInputTokenCost: new(0.0000003), + CacheReadInputTokenCostAbove200kTokens: new(0.0000006), + CacheReadInputTokenCostAbove272kTokens: new(0.0000009), }, }) @@ -1447,7 +1447,7 @@ func TestCalculateCost_272kTier_EndToEnd(t *testing.T) { TotalTokens: 280000, // Above 272k }) - cost := mc.CalculateCost(resp) + cost := mc.CalculateCost(resp, nil) // Tiered rate: input=0.000009, output=0.000045 // 250000*0.000009 + 30000*0.000045 = 2.25 + 1.35 = 3.60 assert.InDelta(t, 3.60, cost, 1e-9) @@ -1460,13 +1460,13 @@ func TestCalculateCost_272kTier_CacheReadFallbackChain(t *testing.T) { Model: "claude-3-7-sonnet", Provider: "anthropic", Mode: "chat", - InputCostPerToken: 0.000003, - OutputCostPerToken: 0.000015, - InputCostPerTokenAbove272kTokens: ptr(0.000009), - OutputCostPerTokenAbove272kTokens: ptr(0.000045), - CacheReadInputTokenCost: ptr(0.0000003), - CacheReadInputTokenCostAbove200kTokens: ptr(0.0000006), - CacheReadInputTokenCostAbove272kTokens: ptr(0.0000009), + InputCostPerToken: new(0.000003), + OutputCostPerToken: new(0.000015), + InputCostPerTokenAbove272kTokens: new(0.000009), + OutputCostPerTokenAbove272kTokens: new(0.000045), + CacheReadInputTokenCost: new(0.0000003), + CacheReadInputTokenCostAbove200kTokens: new(0.0000006), + CacheReadInputTokenCostAbove272kTokens: new(0.0000009), }, }) @@ -1479,7 +1479,7 @@ func TestCalculateCost_272kTier_CacheReadFallbackChain(t *testing.T) { }, }) - cost := mc.CalculateCost(resp) + cost := mc.CalculateCost(resp, nil) // Non-cached input: (250000-50000) * 0.000009 = 200000 * 0.000009 = 1.80 // Cached read (272k rate): 50000 * 0.0000009 = 0.045 // Output: 30000 * 0.000045 = 1.35 @@ -1493,8 +1493,8 @@ func TestCalculateCost_272kTier_CacheReadFallbackChain(t *testing.T) { func TestComputeTextCost_PriorityUsesInputOutputPriorityRate(t *testing.T) { p := chatPricing(0.000003, 0.000015) - p.InputCostPerTokenPriority = ptr(0.000006) - p.OutputCostPerTokenPriority = ptr(0.00003) + p.InputCostPerTokenPriority = new(0.000006) + p.OutputCostPerTokenPriority = new(0.00003) usage := &schemas.BifrostLLMUsage{ PromptTokens: 1000, @@ -1510,8 +1510,8 @@ func TestComputeTextCost_PriorityUsesInputOutputPriorityRate(t *testing.T) { func TestComputeTextCost_NonPriorityIgnoresPriorityRate(t *testing.T) { p := chatPricing(0.000003, 0.000015) - p.InputCostPerTokenPriority = ptr(0.000006) - p.OutputCostPerTokenPriority = ptr(0.00003) + p.InputCostPerTokenPriority = new(0.000006) + p.OutputCostPerTokenPriority = new(0.00003) usage := &schemas.BifrostLLMUsage{ PromptTokens: 1000, @@ -1527,12 +1527,12 @@ func TestComputeTextCost_NonPriorityIgnoresPriorityRate(t *testing.T) { func TestComputeTextCost_Priority272kTier(t *testing.T) { p := chatPricing(0.000003, 0.000015) - p.InputCostPerTokenPriority = ptr(0.000006) - p.OutputCostPerTokenPriority = ptr(0.00003) - p.InputCostPerTokenAbove272kTokens = ptr(0.000009) - p.InputCostPerTokenAbove272kTokensPriority = ptr(0.000012) - p.OutputCostPerTokenAbove272kTokens = ptr(0.000045) - p.OutputCostPerTokenAbove272kTokensPriority = ptr(0.00006) + p.InputCostPerTokenPriority = new(0.000006) + p.OutputCostPerTokenPriority = new(0.00003) + p.InputCostPerTokenAbove272kTokens = new(0.000009) + p.InputCostPerTokenAbove272kTokensPriority = new(0.000012) + p.OutputCostPerTokenAbove272kTokens = new(0.000045) + p.OutputCostPerTokenAbove272kTokensPriority = new(0.00006) usage := &schemas.BifrostLLMUsage{ PromptTokens: 250000, @@ -1549,8 +1549,8 @@ func TestComputeTextCost_Priority272kTier(t *testing.T) { func TestComputeTextCost_Priority272kTierFallsBackToNonPriority272k(t *testing.T) { // Priority flag set but no priority-specific 272k rate — fall back to non-priority 272k p := chatPricing(0.000003, 0.000015) - p.InputCostPerTokenAbove272kTokens = ptr(0.000009) - p.OutputCostPerTokenAbove272kTokens = ptr(0.000045) + p.InputCostPerTokenAbove272kTokens = new(0.000009) + p.OutputCostPerTokenAbove272kTokens = new(0.000045) usage := &schemas.BifrostLLMUsage{ PromptTokens: 250000, @@ -1566,10 +1566,10 @@ func TestComputeTextCost_Priority272kTierFallsBackToNonPriority272k(t *testing.T func TestComputeTextCost_PriorityCacheReadRate(t *testing.T) { p := chatPricing(0.000003, 0.000015) - p.InputCostPerTokenPriority = ptr(0.000006) - p.OutputCostPerTokenPriority = ptr(0.00003) - p.CacheReadInputTokenCost = ptr(0.0000003) - p.CacheReadInputTokenCostPriority = ptr(0.0000006) + p.InputCostPerTokenPriority = new(0.000006) + p.OutputCostPerTokenPriority = new(0.00003) + p.CacheReadInputTokenCost = new(0.0000003) + p.CacheReadInputTokenCostPriority = new(0.0000006) usage := &schemas.BifrostLLMUsage{ PromptTokens: 1000, @@ -1596,10 +1596,10 @@ func TestCalculateCost_PriorityTier_EndToEnd(t *testing.T) { Model: "gpt-4o", Provider: "openai", Mode: "chat", - InputCostPerToken: 0.000005, - OutputCostPerToken: 0.000015, - InputCostPerTokenPriority: ptr(0.000010), - OutputCostPerTokenPriority: ptr(0.000030), + InputCostPerToken: new(0.000005), + OutputCostPerToken: new(0.000015), + InputCostPerTokenPriority: new(0.000010), + OutputCostPerTokenPriority: new(0.000030), }, }) @@ -1612,14 +1612,15 @@ func TestCalculateCost_PriorityTier_EndToEnd(t *testing.T) { TotalTokens: 1500, }, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ChatCompletionRequest, - Provider: schemas.OpenAI, - ModelRequested: "gpt-4o", + RequestType: schemas.ChatCompletionRequest, + Provider: schemas.OpenAI, + OriginalModelRequested: "gpt-4o", + ResolvedModelUsed: "gpt-4o", }, }, } - cost := mc.CalculateCost(resp) + cost := mc.CalculateCost(resp, nil) // Priority rates: 1000*0.000010 + 500*0.000030 = 0.010 + 0.015 = 0.025 assert.InDelta(t, 0.025, cost, 1e-12) } @@ -1631,10 +1632,10 @@ func TestCalculateCost_NonPriorityServiceTier_UsesBaseRate(t *testing.T) { Model: "gpt-4o", Provider: "openai", Mode: "chat", - InputCostPerToken: 0.000005, - OutputCostPerToken: 0.000015, - InputCostPerTokenPriority: ptr(0.000010), - OutputCostPerTokenPriority: ptr(0.000030), + InputCostPerToken: new(0.000005), + OutputCostPerToken: new(0.000015), + InputCostPerTokenPriority: new(0.000010), + OutputCostPerTokenPriority: new(0.000030), }, }) @@ -1647,14 +1648,15 @@ func TestCalculateCost_NonPriorityServiceTier_UsesBaseRate(t *testing.T) { TotalTokens: 1500, }, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ChatCompletionRequest, - Provider: schemas.OpenAI, - ModelRequested: "gpt-4o", + RequestType: schemas.ChatCompletionRequest, + Provider: schemas.OpenAI, + OriginalModelRequested: "gpt-4o", + ResolvedModelUsed: "gpt-4o", }, }, } - cost := mc.CalculateCost(resp) + cost := mc.CalculateCost(resp, nil) // Base rates (not priority): 1000*0.000005 + 500*0.000015 = 0.005 + 0.0075 = 0.0125 assert.InDelta(t, 0.0125, cost, 1e-12) } @@ -1663,23 +1665,23 @@ func TestTieredCacheReadRate_FallbackOrder(t *testing.T) { // 272k rate takes precedence over 200k, 200k over base, base over input rate t.Run("uses_272k_when_above_272k", func(t *testing.T) { p := chatPricing(0.000003, 0.000015) - p.CacheReadInputTokenCost = ptr(0.0000003) - p.CacheReadInputTokenCostAbove200kTokens = ptr(0.0000006) - p.CacheReadInputTokenCostAbove272kTokens = ptr(0.0000009) + p.CacheReadInputTokenCost = new(0.0000003) + p.CacheReadInputTokenCostAbove200kTokens = new(0.0000006) + p.CacheReadInputTokenCostAbove272kTokens = new(0.0000009) assert.Equal(t, 0.0000009, tieredCacheReadInputTokenRate(&p, 280000, serviceTier{})) }) t.Run("uses_200k_when_between_200k_and_272k", func(t *testing.T) { p := chatPricing(0.000003, 0.000015) - p.CacheReadInputTokenCost = ptr(0.0000003) - p.CacheReadInputTokenCostAbove200kTokens = ptr(0.0000006) - p.CacheReadInputTokenCostAbove272kTokens = ptr(0.0000009) + p.CacheReadInputTokenCost = new(0.0000003) + p.CacheReadInputTokenCostAbove200kTokens = new(0.0000006) + p.CacheReadInputTokenCostAbove272kTokens = new(0.0000009) assert.Equal(t, 0.0000006, tieredCacheReadInputTokenRate(&p, 230000, serviceTier{})) }) t.Run("uses_base_cache_rate_when_below_200k", func(t *testing.T) { p := chatPricing(0.000003, 0.000015) - p.CacheReadInputTokenCost = ptr(0.0000003) - p.CacheReadInputTokenCostAbove200kTokens = ptr(0.0000006) - p.CacheReadInputTokenCostAbove272kTokens = ptr(0.0000009) + p.CacheReadInputTokenCost = new(0.0000003) + p.CacheReadInputTokenCostAbove200kTokens = new(0.0000006) + p.CacheReadInputTokenCostAbove272kTokens = new(0.0000009) assert.Equal(t, 0.0000003, tieredCacheReadInputTokenRate(&p, 1500, serviceTier{})) }) t.Run("falls_back_to_input_rate_when_no_cache_rate_set", func(t *testing.T) { @@ -1689,49 +1691,49 @@ func TestTieredCacheReadRate_FallbackOrder(t *testing.T) { }) t.Run("priority_uses_272k_priority_rate", func(t *testing.T) { p := chatPricing(0.000003, 0.000015) - p.CacheReadInputTokenCost = ptr(0.0000003) - p.CacheReadInputTokenCostPriority = ptr(0.0000006) - p.CacheReadInputTokenCostAbove272kTokens = ptr(0.0000009) - p.CacheReadInputTokenCostAbove272kTokensPriority = ptr(0.0000012) + p.CacheReadInputTokenCost = new(0.0000003) + p.CacheReadInputTokenCostPriority = new(0.0000006) + p.CacheReadInputTokenCostAbove272kTokens = new(0.0000009) + p.CacheReadInputTokenCostAbove272kTokensPriority = new(0.0000012) assert.Equal(t, 0.0000012, tieredCacheReadInputTokenRate(&p, 280000, serviceTier{isPriority: true})) }) t.Run("priority_falls_back_to_272k_non_priority_when_priority_rate_missing", func(t *testing.T) { p := chatPricing(0.000003, 0.000015) - p.CacheReadInputTokenCostAbove272kTokens = ptr(0.0000009) + p.CacheReadInputTokenCostAbove272kTokens = new(0.0000009) assert.Equal(t, 0.0000009, tieredCacheReadInputTokenRate(&p, 280000, serviceTier{isPriority: true})) }) t.Run("priority_uses_priority_base_cache_rate_below_tiers", func(t *testing.T) { p := chatPricing(0.000003, 0.000015) - p.CacheReadInputTokenCost = ptr(0.0000003) - p.CacheReadInputTokenCostPriority = ptr(0.0000006) + p.CacheReadInputTokenCost = new(0.0000003) + p.CacheReadInputTokenCostPriority = new(0.0000006) assert.Equal(t, 0.0000006, tieredCacheReadInputTokenRate(&p, 1500, serviceTier{isPriority: true})) }) t.Run("flex_uses_flex_cache_rate", func(t *testing.T) { p := chatPricing(0.000003, 0.000015) - p.CacheReadInputTokenCost = ptr(0.0000003) - p.CacheReadInputTokenCostFlex = ptr(0.0000005) + p.CacheReadInputTokenCost = new(0.0000003) + p.CacheReadInputTokenCostFlex = new(0.0000005) assert.Equal(t, 0.0000005, tieredCacheReadInputTokenRate(&p, 1500, serviceTier{isFlex: true})) }) t.Run("flex_uses_flex_cache_rate_regardless_of_token_count", func(t *testing.T) { p := chatPricing(0.000003, 0.000015) - p.CacheReadInputTokenCost = ptr(0.0000003) - p.CacheReadInputTokenCostFlex = ptr(0.0000005) - p.CacheReadInputTokenCostAbove272kTokens = ptr(0.0000009) + p.CacheReadInputTokenCost = new(0.0000003) + p.CacheReadInputTokenCostFlex = new(0.0000005) + p.CacheReadInputTokenCostAbove272kTokens = new(0.0000009) // Even above 272k, flex flat rate takes precedence assert.Equal(t, 0.0000005, tieredCacheReadInputTokenRate(&p, 280000, serviceTier{isFlex: true})) }) t.Run("flex_falls_back_to_base_cache_rate_when_no_flex_cache_rate", func(t *testing.T) { p := chatPricing(0.000003, 0.000015) - p.CacheReadInputTokenCost = ptr(0.0000003) + p.CacheReadInputTokenCost = new(0.0000003) // No flex cache rate — falls back to base cache rate assert.Equal(t, 0.0000003, tieredCacheReadInputTokenRate(&p, 1500, serviceTier{isFlex: true})) }) t.Run("flex_wins_over_272k_priority_and_priority_base_when_all_present", func(t *testing.T) { p := chatPricing(0.000003, 0.000015) - p.CacheReadInputTokenCostAbove272kTokens = ptr(5e-7) - p.CacheReadInputTokenCostFlex = ptr(1.3e-7) - p.CacheReadInputTokenCostPriority = ptr(5e-7) - p.CacheReadInputTokenCostAbove272kTokensPriority = ptr(0.000001) + p.CacheReadInputTokenCostAbove272kTokens = new(5e-7) + p.CacheReadInputTokenCostFlex = new(1.3e-7) + p.CacheReadInputTokenCostPriority = new(5e-7) + p.CacheReadInputTokenCostAbove272kTokensPriority = new(0.000001) // token count exceeds 272k — but flex flat rate should still win assert.Equal(t, 1.3e-7, tieredCacheReadInputTokenRate(&p, 280000, serviceTier{isFlex: true})) }) @@ -1775,8 +1777,8 @@ func TestTierFromString_Nil(t *testing.T) { func TestComputeTextCost_FlexUsesFlexRate(t *testing.T) { p := chatPricing(0.000003, 0.000015) - p.InputCostPerTokenFlex = ptr(0.0000015) - p.OutputCostPerTokenFlex = ptr(0.0000075) + p.InputCostPerTokenFlex = new(0.0000015) + p.OutputCostPerTokenFlex = new(0.0000075) usage := &schemas.BifrostLLMUsage{ PromptTokens: 1000, @@ -1792,8 +1794,8 @@ func TestComputeTextCost_FlexUsesFlexRate(t *testing.T) { func TestComputeTextCost_NonFlexIgnoresFlexRate(t *testing.T) { p := chatPricing(0.000003, 0.000015) - p.InputCostPerTokenFlex = ptr(0.0000015) - p.OutputCostPerTokenFlex = ptr(0.0000075) + p.InputCostPerTokenFlex = new(0.0000015) + p.OutputCostPerTokenFlex = new(0.0000075) usage := &schemas.BifrostLLMUsage{ PromptTokens: 1000, @@ -1810,10 +1812,10 @@ func TestComputeTextCost_NonFlexIgnoresFlexRate(t *testing.T) { func TestComputeTextCost_FlexIgnoresTokenTiers(t *testing.T) { // Flex is a flat rate — token-count tiers (272k, 200k, 128k) do not apply. p := chatPricing(0.000003, 0.000015) - p.InputCostPerTokenFlex = ptr(0.0000015) - p.OutputCostPerTokenFlex = ptr(0.0000075) - p.InputCostPerTokenAbove272kTokens = ptr(0.000009) - p.OutputCostPerTokenAbove272kTokens = ptr(0.000045) + p.InputCostPerTokenFlex = new(0.0000015) + p.OutputCostPerTokenFlex = new(0.0000075) + p.InputCostPerTokenAbove272kTokens = new(0.000009) + p.OutputCostPerTokenAbove272kTokens = new(0.000045) usage := &schemas.BifrostLLMUsage{ PromptTokens: 250000, @@ -1829,10 +1831,10 @@ func TestComputeTextCost_FlexIgnoresTokenTiers(t *testing.T) { func TestComputeTextCost_FlexCacheReadRate(t *testing.T) { p := chatPricing(0.000003, 0.000015) - p.InputCostPerTokenFlex = ptr(0.0000015) - p.OutputCostPerTokenFlex = ptr(0.0000075) - p.CacheReadInputTokenCost = ptr(0.0000003) - p.CacheReadInputTokenCostFlex = ptr(0.0000006) + p.InputCostPerTokenFlex = new(0.0000015) + p.OutputCostPerTokenFlex = new(0.0000075) + p.CacheReadInputTokenCost = new(0.0000003) + p.CacheReadInputTokenCostFlex = new(0.0000006) usage := &schemas.BifrostLLMUsage{ PromptTokens: 1000, @@ -1875,10 +1877,10 @@ func TestCalculateCost_FlexTier_EndToEnd(t *testing.T) { Model: "gpt-4o", Provider: "openai", Mode: "chat", - InputCostPerToken: 0.000005, - OutputCostPerToken: 0.000015, - InputCostPerTokenFlex: ptr(0.0000025), - OutputCostPerTokenFlex: ptr(0.0000075), + InputCostPerToken: new(0.000005), + OutputCostPerToken: new(0.000015), + InputCostPerTokenFlex: new(0.0000025), + OutputCostPerTokenFlex: new(0.0000075), }, }) @@ -1891,14 +1893,15 @@ func TestCalculateCost_FlexTier_EndToEnd(t *testing.T) { TotalTokens: 1500, }, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ChatCompletionRequest, - Provider: schemas.OpenAI, - ModelRequested: "gpt-4o", + RequestType: schemas.ChatCompletionRequest, + Provider: schemas.OpenAI, + OriginalModelRequested: "gpt-4o", + ResolvedModelUsed: "gpt-4o", }, }, } - cost := mc.CalculateCost(resp) + cost := mc.CalculateCost(resp, nil) // Flex rates: 1000*0.0000025 + 500*0.0000075 = 0.0025 + 0.00375 = 0.00625 assert.InDelta(t, 0.00625, cost, 1e-12) } @@ -1918,14 +1921,15 @@ func TestCalculateCost_FlexTier_FallsBackToBaseWhenNoFlexRate(t *testing.T) { TotalTokens: 1500, }, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ChatCompletionRequest, - Provider: schemas.OpenAI, - ModelRequested: "gpt-4o", + RequestType: schemas.ChatCompletionRequest, + Provider: schemas.OpenAI, + OriginalModelRequested: "gpt-4o", + ResolvedModelUsed: "gpt-4o", }, }, } - cost := mc.CalculateCost(resp) + cost := mc.CalculateCost(resp, nil) // No flex rates configured — falls back to base: 1000*0.000005 + 500*0.000015 = 0.005 + 0.0075 = 0.0125 assert.InDelta(t, 0.0125, cost, 1e-12) } diff --git a/framework/oauth2/main.go b/framework/oauth2/main.go index 0916c908d3..fcf1ab8f49 100644 --- a/framework/oauth2/main.go +++ b/framework/oauth2/main.go @@ -905,6 +905,7 @@ func (p *OAuth2Provider) CompleteUserOAuthFlow(ctx context.Context, state string redirectURI = templateConfig.RedirectURI } tokenResponse, err := p.exchangeCodeForTokensWithPKCE( + ctx, templateConfig.TokenURL, code, templateConfig.ClientID, @@ -1055,6 +1056,7 @@ func (p *OAuth2Provider) RefreshUserAccessToken(ctx context.Context, sessionToke // Exchange refresh token newTokenResponse, err := p.exchangeRefreshToken( + ctx, templateConfig.TokenURL, templateConfig.ClientID, templateConfig.ClientSecret, diff --git a/plugins/prompts/go.mod b/plugins/prompts/go.mod index 73eaf7ad3c..5ceaa9ab29 100644 --- a/plugins/prompts/go.mod +++ b/plugins/prompts/go.mod @@ -45,6 +45,7 @@ require ( github.com/golang-jwt/jwt/v5 v5.3.0 // indirect github.com/google/uuid v1.6.0 // indirect github.com/invopop/jsonschema v0.13.0 // indirect + github.com/jackc/pgx/v5 v5.9.1 // indirect github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/now v1.1.5 // indirect github.com/klauspost/compress v1.18.2 // indirect diff --git a/plugins/prompts/go.sum b/plugins/prompts/go.sum index d7829db75b..2339673b68 100644 --- a/plugins/prompts/go.sum +++ b/plugins/prompts/go.sum @@ -90,8 +90,7 @@ github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsI github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= -github.com/jackc/pgx/v5 v5.7.6 h1:rWQc5FwZSPX58r1OQmkuaNicxdmExaEz5A2DO2hUuTk= -github.com/jackc/pgx/v5 v5.7.6/go.mod h1:aruU7o91Tc2q2cFp5h4uP3f6ztExVpyVv88Xl/8Vl8M= +github.com/jackc/pgx/v5 v5.9.1 h1:uwrxJXBnx76nyISkhr33kQLlUqjv7et7b9FjCen/tdc= github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= diff --git a/transports/bifrost-http/handlers/asyncinference.go b/transports/bifrost-http/handlers/asyncinference.go index b7c7f1a371..da2d27594c 100644 --- a/transports/bifrost-http/handlers/asyncinference.go +++ b/transports/bifrost-http/handlers/asyncinference.go @@ -472,7 +472,7 @@ func (h *AsyncHandler) asyncOCR(ctx *fasthttp.RequestCtx) { return } - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) if bifrostCtx == nil { SendError(ctx, fasthttp.StatusInternalServerError, "Failed to convert context") return diff --git a/transports/bifrost-http/handlers/config.go b/transports/bifrost-http/handlers/config.go index 719967a1a9..384e0e3656 100644 --- a/transports/bifrost-http/handlers/config.go +++ b/transports/bifrost-http/handlers/config.go @@ -494,7 +494,7 @@ func (h *ConfigHandler) updateConfig(ctx *fasthttp.RequestCtx) { if frameworkConfig.PricingSyncInterval != nil { syncSeconds = *frameworkConfig.PricingSyncInterval } else { - syncSeconds = int64(modelcatalog.DefaultPricingSyncInterval.Seconds()) + syncSeconds = int64(modelcatalog.DefaultSyncInterval.Seconds()) } h.store.FrameworkConfig = &framework.FrameworkConfig{ Pricing: &modelcatalog.Config{ diff --git a/transports/bifrost-http/handlers/inference.go b/transports/bifrost-http/handlers/inference.go index cec3b6fbbd..57d3d7a706 100644 --- a/transports/bifrost-http/handlers/inference.go +++ b/transports/bifrost-http/handlers/inference.go @@ -156,20 +156,20 @@ var rerankParamsKnownFields = map[string]bool{ } var ocrParamsKnownFields = map[string]bool{ - "model": true, - "id": true, - "document": true, - "fallbacks": true, - "include_image_base64": true, - "pages": true, - "image_limit": true, - "image_min_size": true, - "table_format": true, - "extract_header": true, - "extract_footer": true, - "bbox_annotation_format": true, - "document_annotation_format": true, - "document_annotation_prompt": true, + "model": true, + "id": true, + "document": true, + "fallbacks": true, + "include_image_base64": true, + "pages": true, + "image_limit": true, + "image_min_size": true, + "table_format": true, + "extract_header": true, + "extract_footer": true, + "bbox_annotation_format": true, + "document_annotation_format": true, + "document_annotation_prompt": true, } var speechParamsKnownFields = map[string]bool{ @@ -459,7 +459,7 @@ type RerankRequest struct { // OCRHandlerRequest is a bifrost OCR request type OCRHandlerRequest struct { - ID *string `json:"id,omitempty"` + ID *string `json:"id,omitempty"` Document schemas.OCRDocument `json:"document"` BifrostParams *schemas.OCRParameters @@ -1317,7 +1317,7 @@ func (h *CompletionHandler) ocr(ctx *fasthttp.RequestCtx) { } // Convert context - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) defer cancel() if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") diff --git a/transports/bifrost-http/lib/pricing_integration_test.go b/transports/bifrost-http/lib/pricing_integration_test.go index 42c893abcc..9e21c083b4 100644 --- a/transports/bifrost-http/lib/pricing_integration_test.go +++ b/transports/bifrost-http/lib/pricing_integration_test.go @@ -48,12 +48,12 @@ func (l *capturingLogger) append(level, msg string, args ...any) { l.mu.Unlock() } -func (l *capturingLogger) Debug(msg string, args ...any) { l.append("DEBUG", msg, args...) } -func (l *capturingLogger) Info(msg string, args ...any) { l.append("INFO", msg, args...) } -func (l *capturingLogger) Warn(msg string, args ...any) { l.append("WARN", msg, args...) } -func (l *capturingLogger) Error(msg string, args ...any) { l.append("ERROR", msg, args...) } -func (l *capturingLogger) Fatal(msg string, args ...any) { l.append("FATAL", msg, args...) } -func (l *capturingLogger) SetLevel(_ schemas.LogLevel) {} +func (l *capturingLogger) Debug(msg string, args ...any) { l.append("DEBUG", msg, args...) } +func (l *capturingLogger) Info(msg string, args ...any) { l.append("INFO", msg, args...) } +func (l *capturingLogger) Warn(msg string, args ...any) { l.append("WARN", msg, args...) } +func (l *capturingLogger) Error(msg string, args ...any) { l.append("ERROR", msg, args...) } +func (l *capturingLogger) Fatal(msg string, args ...any) { l.append("FATAL", msg, args...) } +func (l *capturingLogger) SetLevel(_ schemas.LogLevel) {} func (l *capturingLogger) SetOutputType(_ schemas.LoggerOutputType) {} func (l *capturingLogger) LogHTTPRequest(_ schemas.LogLevel, _ string) schemas.LogEventBuilder { return schemas.NoopLogEvent @@ -141,12 +141,12 @@ func initWithCapture( func getLogger() schemas.Logger { return logger } // pt is a generic pointer helper. -func ptStr(s string) *string { return &s } -func ptI64(n int64) *int64 { return &n } +func ptStr(s string) *string { return &s } +func ptI64(n int64) *int64 { return &n } func ptF64(f float64) *float64 { return &f } // defaultSyncSecs is the production default converted to seconds. -var defaultSyncSecs = int64(modelcatalog.DefaultPricingSyncInterval.Seconds()) +var defaultSyncSecs = int64(modelcatalog.DefaultSyncInterval.Seconds()) // ============================================================================= // STEP 2 — Baseline: no config.json, no DB → built-in defaults @@ -511,8 +511,9 @@ func TestPricingE2E_Step7_RuntimeInterval_StoredCorrectly(t *testing.T) { PricingURL: ptStr("https://example.com/pricing.json"), PricingSyncInterval: &syncSeconds, } - mc, err := modelcatalog.Init(ctx, cfg, store, noSyncFunc, clg) + mc, err := modelcatalog.Init(ctx, cfg, store, clg) require.NoError(t, err) + mc.SetShouldSyncGate(noSyncFunc) defer mc.Cleanup() // The startup Info log must reflect the correct duration. @@ -538,8 +539,9 @@ func TestPricingE2E_Step7_RuntimeInterval_24h_Default(t *testing.T) { // Nil PricingURL: defaults apply. noSyncFunc prevents real HTTP requests. cfg := &modelcatalog.Config{} - mc, err := modelcatalog.Init(ctx, cfg, store, noSyncFunc, clg) + mc, err := modelcatalog.Init(ctx, cfg, store, clg) require.NoError(t, err) + mc.SetShouldSyncGate(noSyncFunc) defer mc.Cleanup() // Must show 24h default. @@ -739,9 +741,9 @@ func TestPricingE2E_Step9B_MissingEnvURL_NotReplacedWithDefault(t *testing.T) { func TestPricingE2E_Step10_NoNilPointers_AllInputCombinations(t *testing.T) { type tc struct { - name string - db *configstoreTables.TableFrameworkConfig - file *framework.FrameworkConfig + name string + db *configstoreTables.TableFrameworkConfig + file *framework.FrameworkConfig } cases := []tc{ {"nil/nil", nil, nil}, @@ -905,8 +907,9 @@ func TestPricingE2E_Step10_SecondsToDurationConversion(t *testing.T) { syncSeconds := int64(3600) cfg := &modelcatalog.Config{PricingSyncInterval: &syncSeconds} // noSyncFunc prevents real HTTP requests to the pricing URL during this unit test. - mc, err := modelcatalog.Init(ctx, cfg, store, noSyncFunc, clg) + mc, err := modelcatalog.Init(ctx, cfg, store, clg) require.NoError(t, err) + mc.SetShouldSyncGate(noSyncFunc) defer mc.Cleanup() // The critical assertion: if the old *time.Duration bug were present, @@ -919,8 +922,9 @@ func TestPricingE2E_Step10_SecondsToDurationConversion(t *testing.T) { SetLogger(clg2) syncSeconds2 := int64(7200) cfg2 := &modelcatalog.Config{PricingSyncInterval: &syncSeconds2} - mc2, err := modelcatalog.Init(ctx, cfg2, store, noSyncFunc, clg2) + mc2, err := modelcatalog.Init(ctx, cfg2, store, clg2) require.NoError(t, err) + mc2.SetShouldSyncGate(noSyncFunc) defer mc2.Cleanup() SetLogger(prev)