Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 12 additions & 9 deletions core/bifrost.go
Original file line number Diff line number Diff line change
Expand Up @@ -1026,9 +1026,10 @@ func (bifrost *Bifrost) OCRRequest(ctx *schemas.BifrostContext, req *schemas.Bif
Message: "document type not provided for ocr request",
},
ExtraFields: schemas.BifrostErrorExtraFields{
RequestType: schemas.OCRRequest,
Provider: req.Provider,
ModelRequested: req.Model,
RequestType: schemas.OCRRequest,
Provider: req.Provider,
OriginalModelRequested: req.Model,
ResolvedModelUsed: req.Model,
},
}
}
Expand All @@ -1039,9 +1040,10 @@ func (bifrost *Bifrost) OCRRequest(ctx *schemas.BifrostContext, req *schemas.Bif
Message: "document_url not provided for document_url type ocr request",
},
ExtraFields: schemas.BifrostErrorExtraFields{
RequestType: schemas.OCRRequest,
Provider: req.Provider,
ModelRequested: req.Model,
RequestType: schemas.OCRRequest,
Provider: req.Provider,
OriginalModelRequested: req.Model,
ResolvedModelUsed: req.Model,
},
}
}
Expand All @@ -1052,9 +1054,10 @@ func (bifrost *Bifrost) OCRRequest(ctx *schemas.BifrostContext, req *schemas.Bif
Message: "image_url not provided for image_url type ocr request",
},
ExtraFields: schemas.BifrostErrorExtraFields{
RequestType: schemas.OCRRequest,
Provider: req.Provider,
ModelRequested: req.Model,
RequestType: schemas.OCRRequest,
Provider: req.Provider,
OriginalModelRequested: req.Model,
ResolvedModelUsed: req.Model,
},
}
}
Expand Down
4 changes: 2 additions & 2 deletions core/bifrost_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -605,8 +605,8 @@ func TestHandleProviderRequest_OCROperationNotAllowed(t *testing.T) {
if err.ExtraFields.RequestType != schemas.OCRRequest {
t.Fatalf("expected OCR request type, got %q", err.ExtraFields.RequestType)
}
if err.ExtraFields.ModelRequested != "custom-mistral/mistral-ocr-latest" {
t.Fatalf("expected model to be preserved, got %q", err.ExtraFields.ModelRequested)
if err.ExtraFields.OriginalModelRequested != "custom-mistral/mistral-ocr-latest" {
t.Fatalf("expected model to be preserved, got %q", err.ExtraFields.OriginalModelRequested)
}
}

Expand Down
34 changes: 9 additions & 25 deletions core/providers/azure/azure.go
Original file line number Diff line number Diff line change
Expand Up @@ -2699,12 +2699,8 @@ func (provider *AzureProvider) Passthrough(
key schemas.Key,
req *schemas.BifrostPassthroughRequest,
) (*schemas.BifrostPassthroughResponse, *schemas.BifrostError) {
if key.AzureKeyConfig == nil {
return nil, providerUtils.NewConfigurationError("azure key config not set", provider.GetProviderKey())
}

if key.AzureKeyConfig.Endpoint.GetValue() == "" {
return nil, providerUtils.NewConfigurationError("endpoint not set", provider.GetProviderKey())
return nil, providerUtils.NewConfigurationError("endpoint not set")
}

url := provider.buildPassthroughURL(key, req.Path, req.RawQuery)
Expand Down Expand Up @@ -2743,7 +2739,7 @@ func (provider *AzureProvider) Passthrough(

body, err := providerUtils.CheckAndDecodeBody(resp)
if err != nil {
return nil, providerUtils.NewBifrostOperationError("failed to decode response body", err, provider.GetProviderKey())
return nil, providerUtils.NewBifrostOperationError("failed to decode response body", err)
}

// Remove wire-level encoding headers after decoding; downstream should recalculate them for the buffered body.
Expand All @@ -2759,9 +2755,6 @@ func (provider *AzureProvider) Passthrough(
Body: body,
}

bifrostResponse.ExtraFields.Provider = provider.GetProviderKey()
bifrostResponse.ExtraFields.ModelRequested = req.Model
bifrostResponse.ExtraFields.RequestType = schemas.PassthroughRequest
bifrostResponse.ExtraFields.Latency = latency.Milliseconds()

if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) {
Expand All @@ -2779,12 +2772,8 @@ func (provider *AzureProvider) PassthroughStream(
key schemas.Key,
req *schemas.BifrostPassthroughRequest,
) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
if key.AzureKeyConfig == nil {
return nil, providerUtils.NewConfigurationError("azure key config not set", provider.GetProviderKey())
}

if key.AzureKeyConfig.Endpoint.GetValue() == "" {
return nil, providerUtils.NewConfigurationError("endpoint not set", provider.GetProviderKey())
return nil, providerUtils.NewConfigurationError("endpoint not set")
}

url := provider.buildPassthroughURL(key, req.Path, req.RawQuery)
Expand Down Expand Up @@ -2833,9 +2822,9 @@ func (provider *AzureProvider) PassthroughStream(
}
}
if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) {
return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, provider.GetProviderKey())
return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err)
}
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, provider.GetProviderKey())
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err)
}

headers := providerUtils.ExtractProviderResponseHeaders(resp)
Expand All @@ -2846,18 +2835,13 @@ func (provider *AzureProvider) PassthroughStream(
return nil, providerUtils.NewBifrostOperationError(
"provider returned an empty stream body",
fmt.Errorf("provider returned an empty stream body"),
provider.GetProviderKey(),
)
}

bodyStream, stopIdleTimeout := providerUtils.NewIdleTimeoutReader(rawBodyStream, rawBodyStream, providerUtils.GetStreamIdleTimeout(ctx))
stopCancellation := providerUtils.SetupStreamCancellation(ctx, rawBodyStream, provider.logger)

extraFields := schemas.BifrostResponseExtraFields{
Provider: provider.GetProviderKey(),
ModelRequested: req.Model,
RequestType: schemas.PassthroughStreamRequest,
}
extraFields := schemas.BifrostResponseExtraFields{}
if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) {
providerUtils.ParseAndSetRawRequestIfJSON(fasthttpReq, &extraFields)
}
Expand All @@ -2867,9 +2851,9 @@ func (provider *AzureProvider) PassthroughStream(
go func() {
defer func() {
if ctx.Err() == context.Canceled {
providerUtils.HandleStreamCancellation(ctx, postHookRunner, ch, provider.GetProviderKey(), req.Model, schemas.PassthroughStreamRequest, provider.logger)
providerUtils.HandleStreamCancellation(ctx, postHookRunner, ch, provider.logger)
} else if ctx.Err() == context.DeadlineExceeded {
providerUtils.HandleStreamTimeout(ctx, postHookRunner, ch, provider.GetProviderKey(), req.Model, schemas.PassthroughStreamRequest, provider.logger)
providerUtils.HandleStreamTimeout(ctx, postHookRunner, ch, provider.logger)
}
close(ch)
}()
Expand Down Expand Up @@ -2918,7 +2902,7 @@ func (provider *AzureProvider) PassthroughStream(
}
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
extraFields.Latency = time.Since(startTime).Milliseconds()
providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, ch, schemas.PassthroughStreamRequest, provider.GetProviderKey(), req.Model, provider.logger)
providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, ch, provider.logger)
return
}
}
Expand Down
6 changes: 3 additions & 3 deletions core/providers/mistral/custom_provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ func TestParseMistralError_UsesExportedConverterMetadata(t *testing.T) {
resp.SetStatusCode(http.StatusBadRequest)
resp.SetBodyString(`{"message":"invalid request","type":"invalid_request_error","code":"bad_request"}`)

bifrostErr := ParseMistralError(resp, schemas.OCRRequest, customMistralProviderName, "mistral-ocr-latest")
bifrostErr := ParseMistralError(resp)
require.NotNil(t, bifrostErr)
require.NotNil(t, bifrostErr.Error)

Expand All @@ -34,7 +34,7 @@ func TestParseMistralError_UsesExportedConverterMetadata(t *testing.T) {
assert.Equal(t, schemas.Ptr("bad_request"), bifrostErr.Error.Code)
assert.Equal(t, customMistralProviderName, bifrostErr.ExtraFields.Provider)
assert.Equal(t, schemas.OCRRequest, bifrostErr.ExtraFields.RequestType)
assert.Equal(t, "mistral-ocr-latest", bifrostErr.ExtraFields.ModelRequested)
assert.Equal(t, "mistral-ocr-latest", bifrostErr.ExtraFields.OriginalModelRequested)
}

func TestMistralProvider_CustomAliasChatStreamUsesBaseCompatibilityAndAliasMetadata(t *testing.T) {
Expand Down Expand Up @@ -156,5 +156,5 @@ func TestMistralProvider_CustomAliasEmbeddingReportsAliasMetadata(t *testing.T)
require.NotNil(t, response)

assert.Equal(t, customMistralProviderName, response.ExtraFields.Provider)
assert.Equal(t, "codestral-embed", response.ExtraFields.ModelRequested)
assert.Equal(t, "codestral-embed", response.ExtraFields.OriginalModelRequested)
}
6 changes: 1 addition & 5 deletions core/providers/mistral/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ type MistralErrorResponse struct {
}

// ParseMistralError parses Mistral-specific error responses.
func ParseMistralError(resp *fasthttp.Response, requestType schemas.RequestType, providerName schemas.ModelProvider, model string) *schemas.BifrostError {
func ParseMistralError(resp *fasthttp.Response) *schemas.BifrostError {
var errorResp MistralErrorResponse
bifrostErr := providerUtils.HandleProviderAPIError(resp, &errorResp)
Comment thread
Pratham-Mishra04 marked this conversation as resolved.
if bifrostErr == nil {
Expand Down Expand Up @@ -67,9 +67,5 @@ func ParseMistralError(resp *fasthttp.Response, requestType schemas.RequestType,
}
}

bifrostErr.ExtraFields.Provider = providerName
bifrostErr.ExtraFields.ModelRequested = model
bifrostErr.ExtraFields.RequestType = requestType

return bifrostErr
}
21 changes: 10 additions & 11 deletions core/providers/mistral/mistral.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ func (provider *MistralProvider) listModelsByKey(ctx *schemas.BifrostContext, ke

// Handle error response
if resp.StatusCode() != fasthttp.StatusOK {
bifrostErr := ParseMistralError(resp, schemas.ListModelsRequest, providerName, "")
bifrostErr := ParseMistralError(resp)
return nil, bifrostErr
}

Expand Down Expand Up @@ -264,20 +264,20 @@ func (provider *MistralProvider) OCR(ctx *schemas.BifrostContext, key schemas.Ke
// Convert Bifrost request to Mistral format
mistralReq := ToMistralOCRRequest(request)
if mistralReq == nil {
return nil, providerUtils.NewBifrostOperationError("ocr request input is not provided", nil, providerName)
return nil, providerUtils.NewBifrostOperationError("ocr request input is not provided", nil)
}

// Marshal request body
requestBody, err := sonic.Marshal(mistralReq)
if err != nil {
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName)
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err)
}

// Merge extra params into JSON payload
if len(mistralReq.ExtraParams) > 0 {
requestBody, err = providerUtils.MergeExtraParamsIntoJSON(requestBody, mistralReq.ExtraParams)
if err != nil {
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName)
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err)
}
}

Expand Down Expand Up @@ -314,12 +314,12 @@ func (provider *MistralProvider) OCR(ctx *schemas.BifrostContext, key schemas.Ke
// Handle error response
if resp.StatusCode() != fasthttp.StatusOK {
provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body()))
return nil, ParseMistralError(resp, schemas.OCRRequest, providerName, request.Model)
return nil, ParseMistralError(resp)
}

responseBody, err := providerUtils.CheckAndDecodeBody(resp)
if err != nil {
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName)
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err)
}

// Check for empty response
Expand Down Expand Up @@ -347,20 +347,19 @@ func (provider *MistralProvider) OCR(ctx *schemas.BifrostContext, key schemas.Ke
},
}
}
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerName)
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err)
}

// Convert to Bifrost format
response := mistralResponse.ToBifrostOCRResponse()
if response == nil {
return nil, providerUtils.NewBifrostOperationError("failed to convert ocr response", nil, providerName)
return nil, providerUtils.NewBifrostOperationError("failed to convert ocr response", nil)
}

// Set extra fields
response.ExtraFields.Latency = latency.Milliseconds()
response.ExtraFields.RequestType = schemas.OCRRequest
response.ExtraFields.Provider = providerName
response.ExtraFields.ModelRequested = request.Model

// Set raw response if enabled
if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) {
Expand Down Expand Up @@ -424,7 +423,7 @@ func (provider *MistralProvider) Transcription(ctx *schemas.BifrostContext, key
// Handle error response
if resp.StatusCode() != fasthttp.StatusOK {
provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body()))
return nil, ParseMistralError(resp, schemas.TranscriptionRequest, providerName, request.Model)
return nil, ParseMistralError(resp)
}

responseBody, err := providerUtils.CheckAndDecodeBody(resp)
Expand Down Expand Up @@ -556,7 +555,7 @@ func (provider *MistralProvider) TranscriptionStream(ctx *schemas.BifrostContext
if resp.StatusCode() != fasthttp.StatusOK {
defer providerUtils.ReleaseStreamingResponse(resp)
provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body()))
return nil, ParseMistralError(resp, schemas.TranscriptionStreamRequest, providerName, request.Model)
return nil, ParseMistralError(resp)
}

// Large payload streaming passthrough — pipe raw upstream SSE to client
Expand Down
4 changes: 2 additions & 2 deletions core/providers/mistral/ocr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -438,7 +438,7 @@ func TestOCRWithMockServer(t *testing.T) {
assert.Equal(t, 2, resp.UsageInfo.PagesProcessed)
assert.Equal(t, schemas.OCRRequest, resp.ExtraFields.RequestType)
assert.Equal(t, schemas.Mistral, resp.ExtraFields.Provider)
assert.Equal(t, "mistral-ocr-latest", resp.ExtraFields.ModelRequested)
assert.Equal(t, "mistral-ocr-latest", resp.ExtraFields.OriginalModelRequested)
},
},
{
Expand Down Expand Up @@ -505,7 +505,7 @@ func TestOCRWithMockServer(t *testing.T) {
assert.Equal(t, "internal_error", *err.Error.Code)
assert.Equal(t, schemas.Mistral, err.ExtraFields.Provider)
assert.Equal(t, schemas.OCRRequest, err.ExtraFields.RequestType)
assert.Equal(t, "mistral-ocr-latest", err.ExtraFields.ModelRequested)
assert.Equal(t, "mistral-ocr-latest", err.ExtraFields.OriginalModelRequested)
},
},
{
Expand Down
4 changes: 2 additions & 2 deletions core/schemas/bifrost.go
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ const (
BifrostContextKeyStreamStartTime BifrostContextKey = "bifrost-stream-start-time" // time.Time (start time for streaming TTFT calculation - set by bifrost)
BifrostContextKeyTracer BifrostContextKey = "bifrost-tracer" // Tracer (tracer instance for completing deferred spans - set by bifrost)
BifrostContextKeyDeferTraceCompletion BifrostContextKey = "bifrost-defer-trace-completion" // bool (signals trace completion should be deferred for streaming - set by streaming handlers)
BifrostContextKeyTraceCompleter BifrostContextKey = "bifrost-trace-completer" // func() (callback to complete trace after streaming - set by tracing middleware)
BifrostContextKeyTraceCompleter BifrostContextKey = "bifrost-trace-completer" // func([]PluginLogEntry) (callback to complete trace after streaming - accepts transport logs so it never reads ctx - set by tracing middleware)
BifrostContextKeyPostHookSpanFinalizer BifrostContextKey = "bifrost-posthook-span-finalizer" // func(context.Context) (callback to finalize post-hook spans after streaming - set by bifrost)
BifrostContextKeyAccumulatorID BifrostContextKey = "bifrost-accumulator-id" // string (ID for streaming accumulator lookup - set by tracer for accumulator operations)
BifrostContextKeyMCPUserSession BifrostContextKey = "bifrost-mcp-user-session" // string (per-user OAuth session token, automatically generated by bifrost)
Expand All @@ -233,7 +233,7 @@ const (
BifrostContextKeyPromptStreamRequest BifrostContextKey = "bifrost-prompt-stream-request" // bool (set by prompts HTTP plugin when prompt version model_params.stream is true and body omitted stream)
BifrostContextKeyRoutingEngineLogs BifrostContextKey = "bifrost-routing-engine-logs" // []RoutingEngineLogEntry (set by bifrost - DO NOT SET THIS MANUALLY) - list of routing engine log entries
BifrostContextKeyTransportPluginLogs BifrostContextKey = "bifrost-transport-plugin-logs" // []PluginLogEntry (transport-layer plugin logs accumulated during HTTP transport hooks)
BifrostContextKeyTransportPostHookCompleter BifrostContextKey = "bifrost-transport-posthook-completer" // func() (callback to run HTTPTransportPostHook after streaming - set by transport interceptor middleware)
BifrostContextKeyTransportPostHookCompleter BifrostContextKey = "bifrost-transport-posthook-completer" // *atomic.Value holding func() ([]PluginLogEntry, error) — set by streaming handler, populated by transport interceptor middleware
BifrostContextKeySkipPluginPipeline BifrostContextKey = "bifrost-skip-plugin-pipeline" // bool - skip plugin pipeline for the request
BifrostContextKeyParentRequestID BifrostContextKey = "bifrost-parent-request-id" // string (parent linkage for grouped request logs like realtime turns)
BifrostContextKeyRealtimeSessionID BifrostContextKey = "bifrost-realtime-session-id" // string
Expand Down
Loading
Loading