Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions core/bifrost.go
Original file line number Diff line number Diff line change
Expand Up @@ -5167,6 +5167,7 @@ func (bifrost *Bifrost) handleProviderRequest(provider schemas.Provider, req *Ch
if bifrostError != nil {
return nil, bifrostError
}
transcriptionResponse.BackfillParams(req.BifrostRequest.TranscriptionRequest)
response.TranscriptionResponse = transcriptionResponse
case schemas.ImageGenerationRequest:
imageResponse, bifrostError := provider.ImageGeneration(req.Context, key, req.BifrostRequest.ImageGenerationRequest)
Expand Down
22 changes: 9 additions & 13 deletions core/providers/openai/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -1806,7 +1806,6 @@ func HandleOpenAIResponsesStreaming(
providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, &response, nil, nil, nil), responseChan)
}
}

}()

return responseChan, nil
Expand Down Expand Up @@ -2373,7 +2372,6 @@ func HandleOpenAISpeechStreamRequest(

providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, nil, &response, nil, nil), responseChan)
}

}()

return responseChan, nil
Expand Down Expand Up @@ -2514,7 +2512,12 @@ func HandleOpenAITranscriptionRequest(
// Parse OpenAI's transcription response directly into BifrostTranscribe
response := &schemas.BifrostTranscriptionResponse{}
var rawResponse interface{}
if customResponseHandler != nil {
if request.Params != nil && schemas.IsPlainTextTranscriptionFormat(request.Params.ResponseFormat) {
response.Text = string(copiedResponseBody)
if sendBackRawResponse {
rawResponse = string(copiedResponseBody)
}
} else if customResponseHandler != nil {
_, rawResponse, bifrostErr = customResponseHandler(copiedResponseBody, response, nil, false, sendBackRawResponse)
} else {
if err := sonic.Unmarshal(copiedResponseBody, response); err != nil {
Expand All @@ -2531,7 +2534,7 @@ func HandleOpenAITranscriptionRequest(
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerName)
}

//TODO: add HandleProviderResponse here
// TODO: add HandleProviderResponse here

// Parse raw response for RawResponse field
if sendBackRawResponse {
Expand Down Expand Up @@ -2835,7 +2838,6 @@ func HandleOpenAITranscriptionStreamRequest(

providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, nil, nil, response, nil), responseChan)
}

}()

return responseChan, nil
Expand All @@ -2845,8 +2847,8 @@ func HandleOpenAITranscriptionStreamRequest(
// It formats the request, sends it to OpenAI, and processes the response.
// Returns a BifrostResponse containing the bifrost response or an error if the request fails.
func (provider *OpenAIProvider) ImageGeneration(ctx *schemas.BifrostContext, key schemas.Key,
req *schemas.BifrostImageGenerationRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) {

req *schemas.BifrostImageGenerationRequest,
) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) {
if err := providerUtils.CheckOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.ImageGenerationRequest); err != nil {
return nil, err
}
Expand Down Expand Up @@ -2879,7 +2881,6 @@ func HandleOpenAIImageGenerationRequest(
sendBackRawResponse bool,
logger schemas.Logger,
) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) {

// Create request
req := fasthttp.AcquireRequest()
resp := fasthttp.AcquireResponse()
Expand Down Expand Up @@ -3000,7 +3001,6 @@ func (provider *OpenAIProvider) ImageGenerationStream(
key schemas.Key,
request *schemas.BifrostImageGenerationRequest,
) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {

if request == nil {
return nil, providerUtils.NewBifrostOperationError("invalid request: nil", nil, provider.GetProviderKey())
}
Expand Down Expand Up @@ -3049,7 +3049,6 @@ func HandleOpenAIImageGenerationStreaming(
postResponseConverter func(*schemas.BifrostImageGenerationStreamResponse) *schemas.BifrostImageGenerationStreamResponse,
logger schemas.Logger,
) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {

// Set headers
headers := map[string]string{
"Content-Type": "application/json",
Expand Down Expand Up @@ -3411,7 +3410,6 @@ func HandleOpenAIImageGenerationStreaming(
return
}
}

}()

return responseChan, nil
Expand Down Expand Up @@ -4320,7 +4318,6 @@ func HandleOpenAIImageEditStreamRequest(
postResponseConverter func(*schemas.BifrostImageGenerationStreamResponse) *schemas.BifrostImageGenerationStreamResponse,
logger schemas.Logger,
) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {

reqBody := ToOpenAIImageEditRequest(request)
if reqBody == nil {
return nil, providerUtils.NewBifrostOperationError("image edit input is not provided", nil, providerName)
Expand Down Expand Up @@ -4666,7 +4663,6 @@ func HandleOpenAIImageEditStreamRequest(
return
}
}

}()

return responseChan, nil
Expand Down
40 changes: 31 additions & 9 deletions core/schemas/transcriptions.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,37 @@ func (r *BifrostTranscriptionRequest) GetRawRequestBody() []byte {
}

type BifrostTranscriptionResponse struct {
Duration *float64 `json:"duration,omitempty"` // Duration in seconds
Language *string `json:"language,omitempty"` // e.g., "english"
LogProbs []TranscriptionLogProb `json:"logprobs,omitempty"`
Segments []TranscriptionSegment `json:"segments,omitempty"`
Task *string `json:"task,omitempty"` // e.g., "transcribe"
Text string `json:"text"`
Usage *TranscriptionUsage `json:"usage,omitempty"`
Words []TranscriptionWord `json:"words,omitempty"`
ExtraFields BifrostResponseExtraFields `json:"extra_fields"`
Duration *float64 `json:"duration,omitempty"` // Duration in seconds
Language *string `json:"language,omitempty"` // e.g., "english"
LogProbs []TranscriptionLogProb `json:"logprobs,omitempty"`
Segments []TranscriptionSegment `json:"segments,omitempty"`
Task *string `json:"task,omitempty"` // e.g., "transcribe"
Text string `json:"text"`
Usage *TranscriptionUsage `json:"usage,omitempty"`
Words []TranscriptionWord `json:"words,omitempty"`
ResponseFormat *string `json:"-"` // Set by provider for non-JSON formats (text, srt, vtt); used by integration response converters
ExtraFields BifrostResponseExtraFields `json:"extra_fields"`
}

func (r *BifrostTranscriptionResponse) BackfillParams(req *BifrostTranscriptionRequest) {
if r == nil || req == nil || req.Params == nil || req.Params.ResponseFormat == nil {
return
}
r.ResponseFormat = req.Params.ResponseFormat
}
Comment thread
coderabbitai[bot] marked this conversation as resolved.

// IsPlainTextTranscriptionFormat returns true if the given response format
// produces a plain-text response body (not JSON).
func IsPlainTextTranscriptionFormat(format *string) bool {
if format == nil {
return false
}
switch *format {
case "text", "srt", "vtt":
return true
default:
return false
}
}

type TranscriptionInput struct {
Expand Down
10 changes: 6 additions & 4 deletions transports/bifrost-http/integrations/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -459,6 +459,9 @@ func CreateOpenAIRouteConfigs(pathPrefix string, handlerStore lib.HandlerStore)
return resp, nil
},
TranscriptionResponseConverter: func(ctx *schemas.BifrostContext, resp *schemas.BifrostTranscriptionResponse) (interface{}, error) {
if schemas.IsPlainTextTranscriptionFormat(resp.ResponseFormat) {
return []byte(resp.Text), nil
}
if resp.ExtraFields.Provider == schemas.OpenAI {
if resp.ExtraFields.RawResponse != nil {
return resp.ExtraFields.RawResponse, nil
Expand Down Expand Up @@ -648,7 +651,6 @@ func CreateOpenAIRouteConfigs(pathPrefix string, handlerStore lib.HandlerStore)
return &schemas.BifrostRequest{
ResponsesRequest: openaiReq.ToBifrostResponsesRequest(ctx),
}, nil

}
return nil, errors.New("invalid request type")
},
Expand Down Expand Up @@ -854,6 +856,9 @@ func CreateOpenAIRouteConfigs(pathPrefix string, handlerStore lib.HandlerStore)
return nil, errors.New("invalid transcription request type")
},
TranscriptionResponseConverter: func(ctx *schemas.BifrostContext, resp *schemas.BifrostTranscriptionResponse) (interface{}, error) {
if schemas.IsPlainTextTranscriptionFormat(resp.ResponseFormat) {
return []byte(resp.Text), nil
}
if resp.ExtraFields.Provider == schemas.OpenAI {
if resp.ExtraFields.RawResponse != nil {
return resp.ExtraFields.RawResponse, nil
Expand Down Expand Up @@ -2406,7 +2411,6 @@ func extractContainerListQueryParams(_ lib.HandlerStore) PreRequestCallback {
// extractContainerIDFromPath extracts container_id from path parameters and provider from query params
func extractContainerIDFromPath(_ lib.HandlerStore) PreRequestCallback {
return func(ctx *fasthttp.RequestCtx, bifrostCtx *schemas.BifrostContext, req interface{}) error {

containerID := ctx.UserValue("container_id")
if containerID == nil {
return errors.New("container_id is required")
Expand Down Expand Up @@ -2655,7 +2659,6 @@ func extractContainerFileCreateParams(_ lib.HandlerStore) PreRequestCallback {
// extractContainerFileListQueryParams extracts query parameters for container file list requests
func extractContainerFileListQueryParams(_ lib.HandlerStore) PreRequestCallback {
return func(ctx *fasthttp.RequestCtx, bifrostCtx *schemas.BifrostContext, req interface{}) error {

containerID := ctx.UserValue("container_id")
if containerID == nil {
return errors.New("container_id is required")
Expand Down Expand Up @@ -2702,7 +2705,6 @@ func extractContainerFileListQueryParams(_ lib.HandlerStore) PreRequestCallback
// extractContainerAndFileIDFromPath extracts container_id and file_id from path parameters and provider from query params
func extractContainerAndFileIDFromPath(handlerStore lib.HandlerStore) PreRequestCallback {
return func(ctx *fasthttp.RequestCtx, bifrostCtx *schemas.BifrostContext, req interface{}) error {

containerID := ctx.UserValue("container_id")
if containerID == nil {
return errors.New("container_id is required")
Expand Down
14 changes: 13 additions & 1 deletion transports/bifrost-http/integrations/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -1064,6 +1064,19 @@ func (g *GenericRouter) handleNonStreamingRequest(ctx *fasthttp.RequestCtx, conf
// Convert Bifrost response to integration-specific format and send
response, err = config.TranscriptionResponseConverter(bifrostCtx, transcriptionResponse)
providerResponseHeaders = transcriptionResponse.ExtraFields.ProviderResponseHeaders

// If converter returns raw bytes, write directly with provider headers.
// Used for plain-text transcription formats (text, srt, vtt).
if err == nil {
if rawBytes, ok := response.([]byte); ok {
for key, value := range providerResponseHeaders {
ctx.Response.Header.Set(key, value)
}
ctx.SetStatusCode(fasthttp.StatusOK)
ctx.SetBody(rawBytes)
return
}
}
case bifrostReq.ImageGenerationRequest != nil:
imageGenerationResponse, bifrostErr := g.client.ImageGenerationRequest(bifrostCtx, bifrostReq.ImageGenerationRequest)
if bifrostErr != nil {
Expand Down Expand Up @@ -1709,7 +1722,6 @@ func (g *GenericRouter) handleBatchRequest(ctx *fasthttp.RequestCtx, config Rout

// handleFileRequest handles file API requests (upload, list, retrieve, delete, content)
func (g *GenericRouter) handleFileRequest(ctx *fasthttp.RequestCtx, config RouteConfig, req interface{}, fileReq *FileRequest, bifrostCtx *schemas.BifrostContext) {

var response interface{}
var err error

Expand Down