Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions core/bifrost.go
Original file line number Diff line number Diff line change
Expand Up @@ -4971,6 +4971,7 @@ func (bifrost *Bifrost) handleProviderRequest(provider schemas.Provider, req *Ch
if bifrostError != nil {
return nil, bifrostError
}
speechResponse.BackfillParams(req.BifrostRequest.SpeechRequest)
response.SpeechResponse = speechResponse
case schemas.TranscriptionRequest:
transcriptionResponse, bifrostError := provider.Transcription(req.Context, key, req.BifrostRequest.TranscriptionRequest)
Expand All @@ -4983,18 +4984,21 @@ func (bifrost *Bifrost) handleProviderRequest(provider schemas.Provider, req *Ch
if bifrostError != nil {
return nil, bifrostError
}
imageResponse.BackfillParams(&req.BifrostRequest)
response.ImageGenerationResponse = imageResponse
case schemas.ImageEditRequest:
imageEditResponse, bifrostError := provider.ImageEdit(req.Context, key, req.BifrostRequest.ImageEditRequest)
if bifrostError != nil {
return nil, bifrostError
}
imageEditResponse.BackfillParams(&req.BifrostRequest)
response.ImageGenerationResponse = imageEditResponse
case schemas.ImageVariationRequest:
imageVariationResponse, bifrostError := provider.ImageVariation(req.Context, key, req.BifrostRequest.ImageVariationRequest)
if bifrostError != nil {
return nil, bifrostError
}
imageVariationResponse.BackfillParams(&req.BifrostRequest)
response.ImageGenerationResponse = imageVariationResponse
case schemas.VideoGenerationRequest:
videoGenerationResponse, bifrostError := provider.VideoGeneration(req.Context, key, req.BifrostRequest.VideoGenerationRequest)
Expand Down
12 changes: 6 additions & 6 deletions core/providers/anthropic/anthropic.go
Original file line number Diff line number Diff line change
Expand Up @@ -1300,7 +1300,7 @@ func HandleAnthropicResponsesStream(
}
}

}
}
}()

return responseChan, nil
Expand Down Expand Up @@ -1820,11 +1820,6 @@ func (provider *AnthropicProvider) Speech(ctx *schemas.BifrostContext, key schem
return nil, providerUtils.NewUnsupportedOperationError(schemas.SpeechRequest, provider.GetProviderKey())
}

// Rerank always fails for the Anthropic provider, which does not support reranking.
// The ctx, key and request arguments are not consulted: the call returns a nil
// response together with the unsupported-operation error built for
// schemas.RerankRequest and this provider's key.
func (provider *AnthropicProvider) Rerank(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostRerankRequest) (*schemas.BifrostRerankResponse, *schemas.BifrostError) {
	unsupportedErr := providerUtils.NewUnsupportedOperationError(schemas.RerankRequest, provider.GetProviderKey())
	return nil, unsupportedErr
}

// SpeechStream is not supported by the Anthropic provider.
func (provider *AnthropicProvider) SpeechStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
return nil, providerUtils.NewUnsupportedOperationError(schemas.SpeechStreamRequest, provider.GetProviderKey())
Expand Down Expand Up @@ -1865,6 +1860,11 @@ func (provider *AnthropicProvider) ImageVariation(ctx *schemas.BifrostContext, k
return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageVariationRequest, provider.GetProviderKey())
}

// Rerank always fails for the Anthropic provider, which does not support reranking.
// The ctx, key and request arguments are not consulted: the call returns a nil
// response together with the unsupported-operation error built for
// schemas.RerankRequest and this provider's key.
func (provider *AnthropicProvider) Rerank(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostRerankRequest) (*schemas.BifrostRerankResponse, *schemas.BifrostError) {
	unsupportedErr := providerUtils.NewUnsupportedOperationError(schemas.RerankRequest, provider.GetProviderKey())
	return nil, unsupportedErr
}

// FileUpload uploads a file to Anthropic's Files API.
func (provider *AnthropicProvider) FileUpload(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostFileUploadRequest) (*schemas.BifrostFileUploadResponse, *schemas.BifrostError) {
if err := providerUtils.CheckOperationAllowed(schemas.Anthropic, provider.customProviderConfig, schemas.FileUploadRequest); err != nil {
Expand Down
1 change: 1 addition & 0 deletions core/providers/azure/azure.go
Original file line number Diff line number Diff line change
Expand Up @@ -1355,6 +1355,7 @@ func (provider *AzureProvider) SpeechStream(ctx *schemas.BifrostContext, postHoo
providerUtils.ParseAndSetRawRequest(&finalResponse.ExtraFields, jsonBody)
}

finalResponse.BackfillParams(request)
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, nil, &finalResponse, nil, nil), responseChan)
}
Expand Down
1 change: 1 addition & 0 deletions core/providers/gemini/gemini.go
Original file line number Diff line number Diff line change
Expand Up @@ -1647,6 +1647,7 @@ func (provider *GeminiProvider) SpeechStream(ctx *schemas.BifrostContext, postHo
Latency: time.Since(startTime).Milliseconds(),
},
}
response.BackfillParams(request)
// Set raw request if enabled
if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) {
providerUtils.ParseAndSetRawRequest(&response.ExtraFields, jsonBody)
Expand Down
77 changes: 24 additions & 53 deletions core/providers/huggingface/huggingface.go
Original file line number Diff line number Diff line change
Expand Up @@ -1058,46 +1058,7 @@ func (provider *HuggingFaceProvider) ImageGenerationStream(ctx *schemas.BifrostC
provider.GetProviderKey(),
)
}

var authHeader map[string]string

if value := key.Value.GetValue(); value != "" {
authHeader = map[string]string{"Authorization": "Bearer " + value}
}

// Build streaming URL - append /stream to the fal-ai route, honoring path overrides
defaultPath := fmt.Sprintf("/fal-ai/%s/stream", modelName)
streamURL := provider.buildRequestURL(ctx, defaultPath, schemas.ImageGenerationStreamRequest)

return HandleHuggingFaceImageGenerationStreaming(
ctx,
provider.client,
streamURL,
request,
authHeader,
provider.networkConfig.ExtraHeaders,
providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest),
providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse),
provider.GetProviderKey(),
postHookRunner,
provider.logger,
)
}

// HandleHuggingFaceImageGenerationStreaming handles image generation streaming for fal-ai through HuggingFace router.
func HandleHuggingFaceImageGenerationStreaming(
ctx *schemas.BifrostContext,
client *fasthttp.Client,
url string,
request *schemas.BifrostImageGenerationRequest,
authHeader map[string]string,
extraHeaders map[string]string,
sendBackRawRequest bool,
sendBackRawResponse bool,
providerName schemas.ModelProvider,
postHookRunner schemas.PostHookRunner,
logger schemas.Logger,
) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
providerName := provider.GetProviderKey()

// Set headers
headers := map[string]string{
Expand All @@ -1106,8 +1067,8 @@ func HandleHuggingFaceImageGenerationStreaming(
"Cache-Control": "no-cache",
}

if authHeader != nil {
maps.Copy(headers, authHeader)
if value := key.Value.GetValue(); value != "" {
headers["Authorization"] = "Bearer " + value
}

jsonBody, bifrostErr := providerUtils.CheckContextAndGetRequestBody(
Expand All @@ -1127,13 +1088,17 @@ func HandleHuggingFaceImageGenerationStreaming(
resp.StreamBody = true
defer fasthttp.ReleaseRequest(req)

// Build streaming URL - append /stream to the fal-ai route, honoring path overrides
defaultPath := fmt.Sprintf("/fal-ai/%s/stream", modelName)
url := provider.buildRequestURL(ctx, defaultPath, schemas.ImageGenerationStreamRequest)
Comment thread
Pratham-Mishra04 marked this conversation as resolved.

// Setup request
req.Header.SetMethod(http.MethodPost)
req.SetRequestURI(url)
req.Header.SetContentType("application/json")

// Set any extra headers from network config
providerUtils.SetExtraHeaders(ctx, req, extraHeaders, nil)
providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, nil)

// Set headers
for key, value := range headers {
Expand All @@ -1148,7 +1113,7 @@ func HandleHuggingFaceImageGenerationStreaming(
startTime := time.Now()

// Make the request
err := client.Do(req, resp)
err := provider.client.Do(req, resp)
if err != nil {
Comment thread
Pratham-Mishra04 marked this conversation as resolved.
defer providerUtils.ReleaseStreamingResponse(resp)
if errors.Is(err, context.Canceled) {
Expand Down Expand Up @@ -1177,7 +1142,7 @@ func HandleHuggingFaceImageGenerationStreaming(
Provider: providerName,
Model: request.Model,
RequestType: schemas.ImageGenerationStreamRequest,
}), jsonBody, nil, sendBackRawRequest, sendBackRawResponse)
}), jsonBody, nil, providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse))
}

// Large payload streaming passthrough — pipe raw upstream SSE to client
Expand All @@ -1202,7 +1167,7 @@ func HandleHuggingFaceImageGenerationStreaming(
providerName,
)
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, logger)
providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger)
return
}

Expand All @@ -1212,7 +1177,7 @@ func HandleHuggingFaceImageGenerationStreaming(

// Setup cancellation handler to close the raw network stream on ctx cancellation,
// which immediately unblocks any in-progress read (including reads blocked inside a gzip decompression layer).
stopCancellation := providerUtils.SetupStreamCancellation(ctx, resp.BodyStream(), logger)
stopCancellation := providerUtils.SetupStreamCancellation(ctx, resp.BodyStream(), provider.logger)
defer stopCancellation()

sseReader := providerUtils.GetSSEDataReader(ctx, reader)
Expand Down Expand Up @@ -1244,7 +1209,7 @@ func HandleHuggingFaceImageGenerationStreaming(
RequestType: schemas.ImageGenerationStreamRequest,
}
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, logger)
providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger)
return
}
break
Expand Down Expand Up @@ -1275,7 +1240,7 @@ func HandleHuggingFaceImageGenerationStreaming(
bifrostErr.Error.Message = errorResp.Error
}
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, logger)
providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger)
return
}
}
Expand All @@ -1284,7 +1249,7 @@ func HandleHuggingFaceImageGenerationStreaming(
// Parse fal-ai response
var response HuggingFaceFalAIImageStreamResponse
if err := sonic.UnmarshalString(jsonData, &response); err != nil {
logger.Warn(fmt.Sprintf("Failed to parse fal-ai stream response: %v", err))
provider.logger.Warn(fmt.Sprintf("Failed to parse fal-ai stream response: %v", err))
continue
}
// Extract images from response (handles both Data.Images and top-level Images)
Expand Down Expand Up @@ -1314,7 +1279,7 @@ func HandleHuggingFaceImageGenerationStreaming(
chunk.CreatedAt = time.Now().Unix()
}
// Set raw response if enabled
if sendBackRawResponse {
if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) {
chunk.ExtraFields.RawResponse = jsonData
}

Expand Down Expand Up @@ -1346,6 +1311,9 @@ func HandleHuggingFaceImageGenerationStreaming(
Latency: time.Since(startTime).Milliseconds(),
},
}
finalChunk.BackfillParams(&schemas.BifrostRequest{
ImageGenerationRequest: request,
})
if lastURLData != "" {
finalChunk.URL = lastURLData
} else if lastB64Data != "" {
Expand All @@ -1354,10 +1322,10 @@ func HandleHuggingFaceImageGenerationStreaming(
if finalChunk.CreatedAt == 0 {
finalChunk.CreatedAt = time.Now().Unix()
}
if sendBackRawRequest {
if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) {
providerUtils.ParseAndSetRawRequest(&finalChunk.ExtraFields, jsonBody)
}
if sendBackRawResponse {
if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) {
finalChunk.ExtraFields.RawResponse = lastJsonData
}
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
Expand Down Expand Up @@ -1756,6 +1724,9 @@ func (provider *HuggingFaceProvider) ImageEditStream(ctx *schemas.BifrostContext
Latency: time.Since(startTime).Milliseconds(),
},
}
finalChunk.BackfillParams(&schemas.BifrostRequest{
ImageEditRequest: request,
})
if lastURLData != "" {
finalChunk.URL = lastURLData
} else if lastB64Data != "" {
Expand Down
30 changes: 19 additions & 11 deletions core/providers/openai/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -1976,11 +1976,6 @@ func (provider *OpenAIProvider) Speech(ctx *schemas.BifrostContext, key schemas.
)
}

// Rerank always fails for the OpenAI provider, which does not support reranking.
// The ctx, key and request arguments are not consulted: the call returns a nil
// response together with the unsupported-operation error built for
// schemas.RerankRequest and this provider's key.
func (provider *OpenAIProvider) Rerank(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostRerankRequest) (*schemas.BifrostRerankResponse, *schemas.BifrostError) {
	unsupportedErr := providerUtils.NewUnsupportedOperationError(schemas.RerankRequest, provider.GetProviderKey())
	return nil, unsupportedErr
}

// HandleOpenAISpeechRequest handles speech requests for OpenAI-compatible APIs.
// This shared function reduces code duplication between providers that use the same speech request format.
func HandleOpenAISpeechRequest(
Expand Down Expand Up @@ -2345,6 +2340,7 @@ func HandleOpenAISpeechStreamRequest(
if sendBackRawRequest {
providerUtils.ParseAndSetRawRequest(&response.ExtraFields, jsonBody)
}
response.BackfillParams(request)
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, nil, &response, nil, nil), responseChan)
return
Expand Down Expand Up @@ -3081,9 +3077,6 @@ func HandleOpenAIImageGenerationStreaming(
// MaxResponseBodySize > 0 so ErrBodyTooLarge triggers StreamBody for Content-Length responses.
activeClient := providerUtils.PrepareResponseStreaming(ctx, client, resp)

// Capture start time before making the HTTP request for latency calculation
startTime := time.Now()

// Make the request
err := activeClient.Do(req, resp)
if err != nil {
Expand Down Expand Up @@ -3154,6 +3147,7 @@ func HandleOpenAIImageGenerationStreaming(

sseReader := providerUtils.GetSSEDataReader(ctx, reader)

startTime := time.Now()
lastChunkTime := startTime
var collectedUsage *schemas.ImageUsage
// Track chunk indices per image - similar to how speech/transcription track chunkIndex
Expand Down Expand Up @@ -3364,6 +3358,9 @@ func HandleOpenAIImageGenerationStreaming(
}
// For completed chunk, use total latency from start
chunk.ExtraFields.Latency = time.Since(startTime).Milliseconds()
chunk.BackfillParams(&schemas.BifrostRequest{
ImageGenerationRequest: request,
})
// Set raw request only on final chunk if enabled
if sendBackRawRequest {
providerUtils.ParseAndSetRawRequest(&chunk.ExtraFields, jsonBody)
Expand All @@ -3385,6 +3382,11 @@ func HandleOpenAIImageGenerationStreaming(
return responseChan, nil
}

// Rerank always fails for the OpenAI provider, which does not support reranking.
// The ctx, key and request arguments are not consulted: the call returns a nil
// response together with the unsupported-operation error built for
// schemas.RerankRequest and this provider's key.
func (provider *OpenAIProvider) Rerank(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostRerankRequest) (*schemas.BifrostRerankResponse, *schemas.BifrostError) {
	unsupportedErr := providerUtils.NewUnsupportedOperationError(schemas.RerankRequest, provider.GetProviderKey())
	return nil, unsupportedErr
}

// VideoGeneration performs a video generation request via the OpenAI API.
func (provider *OpenAIProvider) VideoGeneration(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostVideoGenerationRequest) (*schemas.BifrostVideoGenerationResponse, *schemas.BifrostError) {
if err := providerUtils.CheckOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.VideoGenerationRequest); err != nil {
Expand Down Expand Up @@ -4393,7 +4395,8 @@ func HandleOpenAIImageEditStreamRequest(

sseReader := providerUtils.GetSSEDataReader(ctx, reader)

lastChunkTime := time.Now()
startTime := time.Now()
lastChunkTime := startTime
var collectedUsage *schemas.ImageUsage
// Track chunk indices per image - similar to how speech/transcription track chunkIndex
imageChunkIndices := make(map[int]int) // image index -> chunk index
Expand Down Expand Up @@ -4602,7 +4605,10 @@ func HandleOpenAIImageEditStreamRequest(
chunk.Usage = collectedUsage
}
// For completed chunk, use total latency from start
chunk.ExtraFields.Latency = time.Since(lastChunkTime).Milliseconds()
chunk.ExtraFields.Latency = time.Since(startTime).Milliseconds()
chunk.BackfillParams(&schemas.BifrostRequest{
ImageEditRequest: request,
})
Comment thread
coderabbitai[bot] marked this conversation as resolved.
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
}

Expand Down Expand Up @@ -7000,7 +7006,6 @@ func (provider *OpenAIProvider) PassthroughStream(
if req.RawQuery != "" {
url += "?" + req.RawQuery
}
startTime := time.Now()

fasthttpReq := fasthttp.AcquireRequest()
resp := fasthttp.AcquireResponse()
Expand All @@ -7025,6 +7030,9 @@ func (provider *OpenAIProvider) PassthroughStream(
fasthttpReq.SetBody(req.Body)

activeClient := providerUtils.PrepareResponseStreaming(ctx, provider.client, resp)

startTime := time.Now()

Comment thread
Pratham-Mishra04 marked this conversation as resolved.
if err := activeClient.Do(fasthttpReq, resp); err != nil {
providerUtils.ReleaseStreamingResponse(resp)
if errors.Is(err, context.Canceled) {
Expand Down
Loading
Loading