Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 12 additions & 9 deletions core/bifrost.go
Original file line number Diff line number Diff line change
Expand Up @@ -1026,9 +1026,10 @@ func (bifrost *Bifrost) OCRRequest(ctx *schemas.BifrostContext, req *schemas.Bif
Message: "document type not provided for ocr request",
},
ExtraFields: schemas.BifrostErrorExtraFields{
RequestType: schemas.OCRRequest,
Provider: req.Provider,
ModelRequested: req.Model,
RequestType: schemas.OCRRequest,
Provider: req.Provider,
OriginalModelRequested: req.Model,
ResolvedModelUsed: req.Model,
},
}
}
Expand All @@ -1039,9 +1040,10 @@ func (bifrost *Bifrost) OCRRequest(ctx *schemas.BifrostContext, req *schemas.Bif
Message: "document_url not provided for document_url type ocr request",
},
ExtraFields: schemas.BifrostErrorExtraFields{
RequestType: schemas.OCRRequest,
Provider: req.Provider,
ModelRequested: req.Model,
RequestType: schemas.OCRRequest,
Provider: req.Provider,
OriginalModelRequested: req.Model,
ResolvedModelUsed: req.Model,
},
}
}
Expand All @@ -1052,9 +1054,10 @@ func (bifrost *Bifrost) OCRRequest(ctx *schemas.BifrostContext, req *schemas.Bif
Message: "image_url not provided for image_url type ocr request",
},
ExtraFields: schemas.BifrostErrorExtraFields{
RequestType: schemas.OCRRequest,
Provider: req.Provider,
ModelRequested: req.Model,
RequestType: schemas.OCRRequest,
Provider: req.Provider,
OriginalModelRequested: req.Model,
ResolvedModelUsed: req.Model,
},
}
}
Expand Down
4 changes: 2 additions & 2 deletions core/bifrost_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -605,8 +605,8 @@ func TestHandleProviderRequest_OCROperationNotAllowed(t *testing.T) {
if err.ExtraFields.RequestType != schemas.OCRRequest {
t.Fatalf("expected OCR request type, got %q", err.ExtraFields.RequestType)
}
if err.ExtraFields.ModelRequested != "custom-mistral/mistral-ocr-latest" {
t.Fatalf("expected model to be preserved, got %q", err.ExtraFields.ModelRequested)
if err.ExtraFields.OriginalModelRequested != "custom-mistral/mistral-ocr-latest" {
t.Fatalf("expected model to be preserved, got %q", err.ExtraFields.OriginalModelRequested)
}
Comment thread
Pratham-Mishra04 marked this conversation as resolved.
}

Expand Down
43 changes: 8 additions & 35 deletions core/providers/azure/azure.go
Original file line number Diff line number Diff line change
Expand Up @@ -2699,14 +2699,6 @@ func (provider *AzureProvider) Passthrough(
key schemas.Key,
req *schemas.BifrostPassthroughRequest,
) (*schemas.BifrostPassthroughResponse, *schemas.BifrostError) {
if key.AzureKeyConfig == nil {
return nil, providerUtils.NewConfigurationError("azure key config not set", provider.GetProviderKey())
}

if key.AzureKeyConfig.Endpoint.GetValue() == "" {
return nil, providerUtils.NewConfigurationError("endpoint not set", provider.GetProviderKey())
}

url := provider.buildPassthroughURL(key, req.Path, req.RawQuery)

fasthttpReq := fasthttp.AcquireRequest()
Expand Down Expand Up @@ -2743,7 +2735,7 @@ func (provider *AzureProvider) Passthrough(

body, err := providerUtils.CheckAndDecodeBody(resp)
if err != nil {
return nil, providerUtils.NewBifrostOperationError("failed to decode response body", err, provider.GetProviderKey())
return nil, providerUtils.NewBifrostOperationError("failed to decode response body", err)
}

// Remove wire-level encoding headers after decoding; downstream should recalculate them for the buffered body.
Expand All @@ -2759,9 +2751,6 @@ func (provider *AzureProvider) Passthrough(
Body: body,
}

bifrostResponse.ExtraFields.Provider = provider.GetProviderKey()
bifrostResponse.ExtraFields.ModelRequested = req.Model
bifrostResponse.ExtraFields.RequestType = schemas.PassthroughRequest
bifrostResponse.ExtraFields.Latency = latency.Milliseconds()

if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) {
Expand All @@ -2779,14 +2768,6 @@ func (provider *AzureProvider) PassthroughStream(
key schemas.Key,
req *schemas.BifrostPassthroughRequest,
) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
if key.AzureKeyConfig == nil {
return nil, providerUtils.NewConfigurationError("azure key config not set", provider.GetProviderKey())
}

if key.AzureKeyConfig.Endpoint.GetValue() == "" {
return nil, providerUtils.NewConfigurationError("endpoint not set", provider.GetProviderKey())
}

url := provider.buildPassthroughURL(key, req.Path, req.RawQuery)

fasthttpReq := fasthttp.AcquireRequest()
Expand Down Expand Up @@ -2833,31 +2814,23 @@ func (provider *AzureProvider) PassthroughStream(
}
}
if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) {
return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, provider.GetProviderKey())
return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err)
}
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, provider.GetProviderKey())
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err)
}

headers := providerUtils.ExtractProviderResponseHeaders(resp)

rawBodyStream := resp.BodyStream()
if rawBodyStream == nil {
providerUtils.ReleaseStreamingResponse(resp)
return nil, providerUtils.NewBifrostOperationError(
"provider returned an empty stream body",
fmt.Errorf("provider returned an empty stream body"),
provider.GetProviderKey(),
)
return nil, providerUtils.NewBifrostOperationError("provider returned an empty stream body", fmt.Errorf("provider returned an empty stream body"))
}

bodyStream, stopIdleTimeout := providerUtils.NewIdleTimeoutReader(rawBodyStream, rawBodyStream, providerUtils.GetStreamIdleTimeout(ctx))
stopCancellation := providerUtils.SetupStreamCancellation(ctx, rawBodyStream, provider.logger)

extraFields := schemas.BifrostResponseExtraFields{
Provider: provider.GetProviderKey(),
ModelRequested: req.Model,
RequestType: schemas.PassthroughStreamRequest,
}
extraFields := schemas.BifrostResponseExtraFields{}
if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) {
providerUtils.ParseAndSetRawRequestIfJSON(fasthttpReq, &extraFields)
}
Expand All @@ -2867,9 +2840,9 @@ func (provider *AzureProvider) PassthroughStream(
go func() {
defer func() {
if ctx.Err() == context.Canceled {
providerUtils.HandleStreamCancellation(ctx, postHookRunner, ch, provider.GetProviderKey(), req.Model, schemas.PassthroughStreamRequest, provider.logger)
providerUtils.HandleStreamCancellation(ctx, postHookRunner, ch, provider.logger)
} else if ctx.Err() == context.DeadlineExceeded {
providerUtils.HandleStreamTimeout(ctx, postHookRunner, ch, provider.GetProviderKey(), req.Model, schemas.PassthroughStreamRequest, provider.logger)
providerUtils.HandleStreamTimeout(ctx, postHookRunner, ch, provider.logger)
}
close(ch)
}()
Expand Down Expand Up @@ -2918,7 +2891,7 @@ func (provider *AzureProvider) PassthroughStream(
}
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
extraFields.Latency = time.Since(startTime).Milliseconds()
providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, ch, schemas.PassthroughStreamRequest, provider.GetProviderKey(), req.Model, provider.logger)
providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, ch, provider.logger)
return
}
}
Expand Down
6 changes: 2 additions & 4 deletions core/providers/mistral/custom_provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,14 @@ func TestParseMistralError_UsesExportedConverterMetadata(t *testing.T) {
resp.SetStatusCode(http.StatusBadRequest)
resp.SetBodyString(`{"message":"invalid request","type":"invalid_request_error","code":"bad_request"}`)

bifrostErr := ParseMistralError(resp, schemas.OCRRequest, customMistralProviderName, "mistral-ocr-latest")
bifrostErr := ParseMistralError(resp)
require.NotNil(t, bifrostErr)
require.NotNil(t, bifrostErr.Error)

assert.Equal(t, "invalid request", bifrostErr.Error.Message)
assert.Equal(t, schemas.Ptr("invalid_request_error"), bifrostErr.Error.Type)
assert.Equal(t, schemas.Ptr("bad_request"), bifrostErr.Error.Code)
assert.Equal(t, customMistralProviderName, bifrostErr.ExtraFields.Provider)
assert.Equal(t, schemas.OCRRequest, bifrostErr.ExtraFields.RequestType)
assert.Equal(t, "mistral-ocr-latest", bifrostErr.ExtraFields.ModelRequested)
}

func TestMistralProvider_CustomAliasChatStreamUsesBaseCompatibilityAndAliasMetadata(t *testing.T) {
Expand Down Expand Up @@ -156,5 +154,5 @@ func TestMistralProvider_CustomAliasEmbeddingReportsAliasMetadata(t *testing.T)
require.NotNil(t, response)

assert.Equal(t, customMistralProviderName, response.ExtraFields.Provider)
assert.Equal(t, "codestral-embed", response.ExtraFields.ModelRequested)
assert.Equal(t, "codestral-embed", response.ExtraFields.ResolvedModelUsed)
}
6 changes: 1 addition & 5 deletions core/providers/mistral/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ type MistralErrorResponse struct {
}

// ParseMistralError parses Mistral-specific error responses.
func ParseMistralError(resp *fasthttp.Response, requestType schemas.RequestType, providerName schemas.ModelProvider, model string) *schemas.BifrostError {
func ParseMistralError(resp *fasthttp.Response) *schemas.BifrostError {
var errorResp MistralErrorResponse
bifrostErr := providerUtils.HandleProviderAPIError(resp, &errorResp)
if bifrostErr == nil {
Expand Down Expand Up @@ -67,9 +67,5 @@ func ParseMistralError(resp *fasthttp.Response, requestType schemas.RequestType,
}
}

bifrostErr.ExtraFields.Provider = providerName
bifrostErr.ExtraFields.ModelRequested = model
bifrostErr.ExtraFields.RequestType = requestType

return bifrostErr
}
32 changes: 11 additions & 21 deletions core/providers/mistral/mistral.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ func (provider *MistralProvider) listModelsByKey(ctx *schemas.BifrostContext, ke

// Handle error response
if resp.StatusCode() != fasthttp.StatusOK {
bifrostErr := ParseMistralError(resp, schemas.ListModelsRequest, providerName, "")
bifrostErr := ParseMistralError(resp)
return nil, bifrostErr
}

Expand Down Expand Up @@ -259,25 +259,23 @@ func (provider *MistralProvider) Rerank(ctx *schemas.BifrostContext, key schemas
// OCR performs an OCR request to the Mistral API.
// It sends a JSON request to Mistral's OCR endpoint and returns the extracted content.
func (provider *MistralProvider) OCR(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostOCRRequest) (*schemas.BifrostOCRResponse, *schemas.BifrostError) {
providerName := provider.GetProviderKey()

// Convert Bifrost request to Mistral format
mistralReq := ToMistralOCRRequest(request)
if mistralReq == nil {
return nil, providerUtils.NewBifrostOperationError("ocr request input is not provided", nil, providerName)
return nil, providerUtils.NewBifrostOperationError("ocr request input is not provided", nil)
}

// Marshal request body
requestBody, err := sonic.Marshal(mistralReq)
if err != nil {
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName)
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err)
}

// Merge extra params into JSON payload
if len(mistralReq.ExtraParams) > 0 {
requestBody, err = providerUtils.MergeExtraParamsIntoJSON(requestBody, mistralReq.ExtraParams)
if err != nil {
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName)
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err)
}
}

Expand Down Expand Up @@ -313,13 +311,12 @@ func (provider *MistralProvider) OCR(ctx *schemas.BifrostContext, key schemas.Ke

// Handle error response
if resp.StatusCode() != fasthttp.StatusOK {
provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body()))
return nil, ParseMistralError(resp, schemas.OCRRequest, providerName, request.Model)
return nil, ParseMistralError(resp)
}

responseBody, err := providerUtils.CheckAndDecodeBody(resp)
if err != nil {
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName)
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err)
}

// Check for empty response
Expand Down Expand Up @@ -347,20 +344,17 @@ func (provider *MistralProvider) OCR(ctx *schemas.BifrostContext, key schemas.Ke
},
}
}
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerName)
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err)
}

// Convert to Bifrost format
response := mistralResponse.ToBifrostOCRResponse()
if response == nil {
return nil, providerUtils.NewBifrostOperationError("failed to convert ocr response", nil, providerName)
return nil, providerUtils.NewBifrostOperationError("failed to convert ocr response", nil)
}

// Set extra fields
response.ExtraFields.Latency = latency.Milliseconds()
response.ExtraFields.RequestType = schemas.OCRRequest
response.ExtraFields.Provider = providerName
response.ExtraFields.ModelRequested = request.Model

// Set raw response if enabled
if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) {
Expand All @@ -382,16 +376,14 @@ func (provider *MistralProvider) SpeechStream(ctx *schemas.BifrostContext, postH
// It creates a multipart form with the audio file and sends it to Mistral's transcription endpoint.
// Returns the transcribed text and metadata, or an error if the request fails.
func (provider *MistralProvider) Transcription(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (*schemas.BifrostTranscriptionResponse, *schemas.BifrostError) {
providerName := provider.GetProviderKey()

// Convert Bifrost request to Mistral format
mistralReq := ToMistralTranscriptionRequest(request)
if mistralReq == nil {
return nil, providerUtils.NewBifrostOperationError("transcription input is not provided", nil)
}

// Create multipart form body
body, contentType, bifrostErr := createMistralTranscriptionMultipartBody(mistralReq, providerName)
body, contentType, bifrostErr := createMistralTranscriptionMultipartBody(mistralReq, provider.GetProviderKey())
if bifrostErr != nil {
return nil, bifrostErr
}
Expand Down Expand Up @@ -423,8 +415,7 @@ func (provider *MistralProvider) Transcription(ctx *schemas.BifrostContext, key

// Handle error response
if resp.StatusCode() != fasthttp.StatusOK {
provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body()))
return nil, ParseMistralError(resp, schemas.TranscriptionRequest, providerName, request.Model)
return nil, ParseMistralError(resp)
}

responseBody, err := providerUtils.CheckAndDecodeBody(resp)
Expand Down Expand Up @@ -555,8 +546,7 @@ func (provider *MistralProvider) TranscriptionStream(ctx *schemas.BifrostContext
// Check for HTTP errors
if resp.StatusCode() != fasthttp.StatusOK {
defer providerUtils.ReleaseStreamingResponse(resp)
provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body()))
return nil, ParseMistralError(resp, schemas.TranscriptionStreamRequest, providerName, request.Model)
return nil, ParseMistralError(resp)
}

// Large payload streaming passthrough — pipe raw upstream SSE to client
Expand Down
8 changes: 0 additions & 8 deletions core/providers/mistral/ocr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -436,9 +436,6 @@ func TestOCRWithMockServer(t *testing.T) {
assert.Equal(t, 1, resp.Pages[1].Index)
require.NotNil(t, resp.UsageInfo)
assert.Equal(t, 2, resp.UsageInfo.PagesProcessed)
assert.Equal(t, schemas.OCRRequest, resp.ExtraFields.RequestType)
assert.Equal(t, schemas.Mistral, resp.ExtraFields.Provider)
assert.Equal(t, "mistral-ocr-latest", resp.ExtraFields.ModelRequested)
},
},
{
Expand Down Expand Up @@ -503,9 +500,6 @@ func TestOCRWithMockServer(t *testing.T) {
assert.Equal(t, "server_error", *err.Error.Type)
require.NotNil(t, err.Error.Code)
assert.Equal(t, "internal_error", *err.Error.Code)
assert.Equal(t, schemas.Mistral, err.ExtraFields.Provider)
assert.Equal(t, schemas.OCRRequest, err.ExtraFields.RequestType)
assert.Equal(t, "mistral-ocr-latest", err.ExtraFields.ModelRequested)
},
},
{
Expand Down Expand Up @@ -757,7 +751,5 @@ func TestMistralOCRIntegration(t *testing.T) {
require.NotEmpty(t, resp.Pages, "Expected at least one page")
assert.Equal(t, 0, resp.Pages[0].Index)
assert.NotEmpty(t, resp.Pages[0].Markdown, "Expected non-empty markdown for page 0")
assert.Equal(t, schemas.OCRRequest, resp.ExtraFields.RequestType)
assert.Equal(t, schemas.Mistral, resp.ExtraFields.Provider)
assert.Greater(t, resp.ExtraFields.Latency, int64(0))
}
25 changes: 15 additions & 10 deletions core/schemas/chatcompletions.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,8 @@ func (cr *BifrostChatResponse) ToTextCompletionResponse() *BifrostTextCompletion
RequestType: TextCompletionRequest,
ChunkIndex: cr.ExtraFields.ChunkIndex,
Provider: cr.ExtraFields.Provider,
ModelRequested: cr.ExtraFields.ModelRequested,
OriginalModelRequested: cr.ExtraFields.OriginalModelRequested,
ResolvedModelUsed: cr.ExtraFields.ResolvedModelUsed,
Latency: cr.ExtraFields.Latency,
RawResponse: cr.ExtraFields.RawResponse,
CacheDebug: cr.ExtraFields.CacheDebug,
Expand Down Expand Up @@ -113,7 +114,8 @@ func (cr *BifrostChatResponse) ToTextCompletionResponse() *BifrostTextCompletion
RequestType: TextCompletionRequest,
ChunkIndex: cr.ExtraFields.ChunkIndex,
Provider: cr.ExtraFields.Provider,
ModelRequested: cr.ExtraFields.ModelRequested,
OriginalModelRequested: cr.ExtraFields.OriginalModelRequested,
ResolvedModelUsed: cr.ExtraFields.ResolvedModelUsed,
Latency: cr.ExtraFields.Latency,
RawResponse: cr.ExtraFields.RawResponse,
CacheDebug: cr.ExtraFields.CacheDebug,
Expand Down Expand Up @@ -149,7 +151,8 @@ func (cr *BifrostChatResponse) ToTextCompletionResponse() *BifrostTextCompletion
RequestType: TextCompletionRequest,
ChunkIndex: cr.ExtraFields.ChunkIndex,
Provider: cr.ExtraFields.Provider,
ModelRequested: cr.ExtraFields.ModelRequested,
OriginalModelRequested: cr.ExtraFields.OriginalModelRequested,
ResolvedModelUsed: cr.ExtraFields.ResolvedModelUsed,
Latency: cr.ExtraFields.Latency,
RawResponse: cr.ExtraFields.RawResponse,
CacheDebug: cr.ExtraFields.CacheDebug,
Expand All @@ -166,13 +169,15 @@ func (cr *BifrostChatResponse) ToTextCompletionResponse() *BifrostTextCompletion
SystemFingerprint: cr.SystemFingerprint,
Usage: cr.Usage,
ExtraFields: BifrostResponseExtraFields{
RequestType: TextCompletionRequest,
ChunkIndex: cr.ExtraFields.ChunkIndex,
Provider: cr.ExtraFields.Provider,
ModelRequested: cr.ExtraFields.ModelRequested,
Latency: cr.ExtraFields.Latency,
RawResponse: cr.ExtraFields.RawResponse,
CacheDebug: cr.ExtraFields.CacheDebug,
RequestType: TextCompletionRequest,
ChunkIndex: cr.ExtraFields.ChunkIndex,
Provider: cr.ExtraFields.Provider,
OriginalModelRequested: cr.ExtraFields.OriginalModelRequested,
ResolvedModelUsed: cr.ExtraFields.ResolvedModelUsed,
Latency: cr.ExtraFields.Latency,
RawResponse: cr.ExtraFields.RawResponse,
CacheDebug: cr.ExtraFields.CacheDebug,
ProviderResponseHeaders: cr.ExtraFields.ProviderResponseHeaders,
},
}
}
Expand Down
Loading
Loading