diff --git a/core/bifrost.go b/core/bifrost.go index d089828efc..abaddfc46b 100644 --- a/core/bifrost.go +++ b/core/bifrost.go @@ -594,9 +594,10 @@ func (bifrost *Bifrost) TextCompletionRequest(ctx *schemas.BifrostContext, req * Message: "prompt not provided for text completion request", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.TextCompletionRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.TextCompletionRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } @@ -633,9 +634,10 @@ func (bifrost *Bifrost) TextCompletionStreamRequest(ctx *schemas.BifrostContext, Message: "text not provided for text completion stream request", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.TextCompletionStreamRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.TextCompletionStreamRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } @@ -664,9 +666,10 @@ func (bifrost *Bifrost) makeChatCompletionRequest(ctx *schemas.BifrostContext, r Message: "chats not provided for chat completion request", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.ChatCompletionRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.ChatCompletionRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } @@ -729,9 +732,10 @@ func (bifrost *Bifrost) ChatCompletionStreamRequest(ctx *schemas.BifrostContext, Message: "chats not provided for chat completion request", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.ChatCompletionStreamRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.ChatCompletionStreamRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } @@ 
-765,9 +769,10 @@ func (bifrost *Bifrost) makeResponsesRequest(ctx *schemas.BifrostContext, req *s Message: "responses not provided for responses request", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.ResponsesRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.ResponsesRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } @@ -833,9 +838,10 @@ func (bifrost *Bifrost) ResponsesStreamRequest(ctx *schemas.BifrostContext, req Message: "responses not provided for responses stream request", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.ResponsesStreamRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.ResponsesStreamRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } @@ -868,9 +874,10 @@ func (bifrost *Bifrost) CountTokensRequest(ctx *schemas.BifrostContext, req *sch Message: "input not provided for count tokens request", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.CountTokensRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.CountTokensRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } @@ -907,9 +914,10 @@ func (bifrost *Bifrost) EmbeddingRequest(ctx *schemas.BifrostContext, req *schem Message: "embedding input not provided for embedding request", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.EmbeddingRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.EmbeddingRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } @@ -946,9 +954,10 @@ func (bifrost *Bifrost) RerankRequest(ctx *schemas.BifrostContext, req *schemas. 
Message: "query not provided for rerank request", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.RerankRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.RerankRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } @@ -959,9 +968,10 @@ func (bifrost *Bifrost) RerankRequest(ctx *schemas.BifrostContext, req *schemas. Message: "documents not provided for rerank request", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.RerankRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.RerankRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } @@ -973,9 +983,10 @@ func (bifrost *Bifrost) RerankRequest(ctx *schemas.BifrostContext, req *schemas. Message: fmt.Sprintf("document text is empty at index %d", i), }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.RerankRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.RerankRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } @@ -1011,9 +1022,10 @@ func (bifrost *Bifrost) SpeechRequest(ctx *schemas.BifrostContext, req *schemas. 
Message: "speech input not provided for speech request", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.SpeechRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.SpeechRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } @@ -1050,9 +1062,10 @@ func (bifrost *Bifrost) SpeechStreamRequest(ctx *schemas.BifrostContext, req *sc Message: "speech input not provided for speech stream request", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.SpeechStreamRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.SpeechStreamRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } @@ -1084,9 +1097,10 @@ func (bifrost *Bifrost) TranscriptionRequest(ctx *schemas.BifrostContext, req *s Message: "transcription input not provided for transcription request", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.TranscriptionRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.TranscriptionRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } @@ -1123,9 +1137,10 @@ func (bifrost *Bifrost) TranscriptionStreamRequest(ctx *schemas.BifrostContext, Message: "transcription input not provided for transcription stream request", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.TranscriptionStreamRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.TranscriptionStreamRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } @@ -1158,9 +1173,10 @@ func (bifrost *Bifrost) ImageGenerationRequest(ctx *schemas.BifrostContext, Message: "prompt not provided for image generation request", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: 
schemas.ImageGenerationRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.ImageGenerationRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } @@ -1180,9 +1196,10 @@ func (bifrost *Bifrost) ImageGenerationRequest(ctx *schemas.BifrostContext, Message: "received nil response from provider", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.ImageGenerationRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.ImageGenerationRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } @@ -1211,9 +1228,10 @@ func (bifrost *Bifrost) ImageGenerationStreamRequest(ctx *schemas.BifrostContext Message: "prompt not provided for image generation stream request", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.ImageGenerationStreamRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.ImageGenerationStreamRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } @@ -1245,9 +1263,10 @@ func (bifrost *Bifrost) ImageEditRequest(ctx *schemas.BifrostContext, req *schem Message: "images not provided for image edit request", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.ImageEditRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.ImageEditRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } @@ -1264,9 +1283,10 @@ func (bifrost *Bifrost) ImageEditRequest(ctx *schemas.BifrostContext, req *schem Message: "prompt not provided for image edit request", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.ImageEditRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.ImageEditRequest, + Provider: req.Provider, + 
OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } @@ -1287,9 +1307,10 @@ func (bifrost *Bifrost) ImageEditRequest(ctx *schemas.BifrostContext, req *schem Message: "received nil response from provider", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.ImageEditRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.ImageEditRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } @@ -1317,9 +1338,10 @@ func (bifrost *Bifrost) ImageEditStreamRequest(ctx *schemas.BifrostContext, req Message: "images not provided for image edit stream request", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.ImageEditStreamRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.ImageEditStreamRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } @@ -1336,9 +1358,10 @@ func (bifrost *Bifrost) ImageEditStreamRequest(ctx *schemas.BifrostContext, req Message: "prompt not provided for image edit stream request", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.ImageEditStreamRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.ImageEditStreamRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } @@ -1370,9 +1393,10 @@ func (bifrost *Bifrost) ImageVariationRequest(ctx *schemas.BifrostContext, req * Message: "image not provided for image variation request", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.ImageVariationRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.ImageVariationRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } @@ -1393,9 +1417,10 @@ func (bifrost *Bifrost) ImageVariationRequest(ctx 
*schemas.BifrostContext, req * Message: "received nil response from provider", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.ImageVariationRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.ImageVariationRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } @@ -1424,9 +1449,10 @@ func (bifrost *Bifrost) VideoGenerationRequest(ctx *schemas.BifrostContext, Message: "prompt not provided for video generation request", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.VideoGenerationRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.VideoGenerationRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } @@ -1446,9 +1472,10 @@ func (bifrost *Bifrost) VideoGenerationRequest(ctx *schemas.BifrostContext, Message: "received nil response from provider", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.VideoGenerationRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.VideoGenerationRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } @@ -3203,9 +3230,10 @@ func (bifrost *Bifrost) UpdateProvider(providerKey schemas.ModelProvider) error Message: "request failed during provider concurrency update", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: m.RequestType, - Provider: provider, - ModelRequested: model, + RequestType: m.RequestType, + Provider: provider, + OriginalModelRequested: model, + ResolvedModelUsed: model, }, }: case <-time.After(1 * time.Second): @@ -3845,7 +3873,16 @@ func (bifrost *Bifrost) RunStreamPreHooks(ctx *schemas.BifrostContext, req *sche } } + wsProvider, wsModel, _ := preReq.GetRequestFields() postHookRunner := func(ctx *schemas.BifrostContext, result *schemas.BifrostResponse, err 
*schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError) { + // Populate extra fields before RunPostLLMHooks so plugins (e.g. logging) + // can read requestType/provider/model from the chunk or error. + if result != nil { + result.PopulateExtraFields(req.RequestType, wsProvider, wsModel, wsModel) + } + if err != nil { + err.PopulateExtraFields(req.RequestType, wsProvider, wsModel, wsModel) + } resp, bifrostErr := pipeline.RunPostLLMHooks(ctx, result, err, preCount) if IsFinalChunk(ctx) { drainAndAttachPluginLogs(ctx) @@ -4055,11 +4092,7 @@ func (bifrost *Bifrost) handleRequest(ctx *schemas.BifrostContext, req *schemas. defer bifrost.releaseBifrostRequest(req) provider, model, fallbacks := req.GetRequestFields() if err := validateRequest(req); err != nil { - err.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: req.RequestType, - Provider: provider, - ModelRequested: model, - } + err.PopulateExtraFields(req.RequestType, provider, model, model) return nil, err } @@ -4092,16 +4125,6 @@ func (bifrost *Bifrost) handleRequest(ctx *schemas.BifrostContext, req *schemas. // Check if we should proceed with fallbacks shouldTryFallbacks := bifrost.shouldTryFallbacks(req, primaryErr) if !shouldTryFallbacks { - if primaryErr != nil { - primaryErr.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: req.RequestType, - Provider: provider, - ModelRequested: model, - RawRequest: primaryErr.ExtraFields.RawRequest, - RawResponse: primaryErr.ExtraFields.RawResponse, - KeyStatuses: primaryErr.ExtraFields.KeyStatuses, - } - } return primaryResult, primaryErr } @@ -4144,29 +4167,10 @@ func (bifrost *Bifrost) handleRequest(ctx *schemas.BifrostContext, req *schemas. 
// Check if we should continue with more fallbacks if !bifrost.shouldContinueWithFallbacks(fallback, fallbackErr) { - fallbackErr.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: req.RequestType, - Provider: fallback.Provider, - ModelRequested: fallback.Model, - RawRequest: fallbackErr.ExtraFields.RawRequest, - RawResponse: fallbackErr.ExtraFields.RawResponse, - KeyStatuses: fallbackErr.ExtraFields.KeyStatuses, - } return nil, fallbackErr } } - if primaryErr != nil { - primaryErr.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: req.RequestType, - Provider: provider, - ModelRequested: model, - RawRequest: primaryErr.ExtraFields.RawRequest, - RawResponse: primaryErr.ExtraFields.RawResponse, - KeyStatuses: primaryErr.ExtraFields.KeyStatuses, - } - } - // All providers failed, return the original error return nil, primaryErr } @@ -4181,11 +4185,7 @@ func (bifrost *Bifrost) handleStreamRequest(ctx *schemas.BifrostContext, req *sc provider, model, fallbacks := req.GetRequestFields() if err := validateRequest(req); err != nil { - err.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: req.RequestType, - Provider: provider, - ModelRequested: model, - } + err.PopulateExtraFields(req.RequestType, provider, model, model) err.StatusCode = schemas.Ptr(fasthttp.StatusBadRequest) return nil, err } @@ -4207,16 +4207,6 @@ func (bifrost *Bifrost) handleStreamRequest(ctx *schemas.BifrostContext, req *sc // Check if we should proceed with fallbacks shouldTryFallbacks := bifrost.shouldTryFallbacks(req, primaryErr) if !shouldTryFallbacks { - if primaryErr != nil { - primaryErr.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: req.RequestType, - Provider: provider, - ModelRequested: model, - RawRequest: primaryErr.ExtraFields.RawRequest, - RawResponse: primaryErr.ExtraFields.RawResponse, - KeyStatuses: primaryErr.ExtraFields.KeyStatuses, - } - } return primaryResult, primaryErr } @@ -4257,29 +4247,10 @@ func (bifrost *Bifrost) 
handleStreamRequest(ctx *schemas.BifrostContext, req *sc // Check if we should continue with more fallbacks if !bifrost.shouldContinueWithFallbacks(fallback, fallbackErr) { - fallbackErr.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: req.RequestType, - Provider: fallback.Provider, - ModelRequested: fallback.Model, - RawRequest: fallbackErr.ExtraFields.RawRequest, - RawResponse: fallbackErr.ExtraFields.RawResponse, - KeyStatuses: fallbackErr.ExtraFields.KeyStatuses, - } return nil, fallbackErr } } - if primaryErr != nil { - primaryErr.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: req.RequestType, - Provider: provider, - ModelRequested: model, - RawRequest: primaryErr.ExtraFields.RawRequest, - RawResponse: primaryErr.ExtraFields.RawResponse, - KeyStatuses: primaryErr.ExtraFields.KeyStatuses, - } - } - // All providers failed, return the original error return nil, primaryErr } @@ -4291,11 +4262,7 @@ func (bifrost *Bifrost) tryRequest(ctx *schemas.BifrostContext, req *schemas.Bif pq, err := bifrost.getProviderQueue(provider) if err != nil { bifrostErr := newBifrostError(err) - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: req.RequestType, - Provider: provider, - ModelRequested: model, - } + bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) return nil, bifrostErr } @@ -4306,7 +4273,9 @@ func (bifrost *Bifrost) tryRequest(ctx *schemas.BifrostContext, req *schemas.Bif tracer := bifrost.getTracer() if tracer == nil { - return nil, newBifrostErrorFromMsg("tracer not found in context") + bifrostErr := newBifrostErrorFromMsg("tracer not found in context") + bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) + return nil, bifrostErr } // Store tracer in context BEFORE calling requestHandler, so streaming goroutines @@ -4325,6 +4294,7 @@ func (bifrost *Bifrost) tryRequest(ctx *schemas.BifrostContext, req *schemas.Bif resp, bifrostErr := pipeline.RunPostLLMHooks(ctx, 
shortCircuit.Response, nil, preCount) drainAndAttachPluginLogs(ctx) if bifrostErr != nil { + bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) return nil, bifrostErr } return resp, nil @@ -4334,6 +4304,7 @@ func (bifrost *Bifrost) tryRequest(ctx *schemas.BifrostContext, req *schemas.Bif resp, bifrostErr := pipeline.RunPostLLMHooks(ctx, nil, shortCircuit.Error, preCount) drainAndAttachPluginLogs(ctx) if bifrostErr != nil { + bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) return nil, bifrostErr } return resp, nil @@ -4341,11 +4312,7 @@ func (bifrost *Bifrost) tryRequest(ctx *schemas.BifrostContext, req *schemas.Bif } if preReq == nil { bifrostErr := newBifrostErrorFromMsg("bifrost request after plugin hooks cannot be nil") - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: req.RequestType, - Provider: provider, - ModelRequested: model, - } + bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) return nil, bifrostErr } @@ -4357,11 +4324,7 @@ func (bifrost *Bifrost) tryRequest(ctx *schemas.BifrostContext, req *schemas.Bif if pq.isClosing() { bifrost.releaseChannelMessage(msg) bifrostErr := newBifrostErrorFromMsg("provider is shutting down") - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: req.RequestType, - Provider: provider, - ModelRequested: model, - } + bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) return nil, bifrostErr } @@ -4372,36 +4335,26 @@ func (bifrost *Bifrost) tryRequest(ctx *schemas.BifrostContext, req *schemas.Bif case <-pq.done: bifrost.releaseChannelMessage(msg) bifrostErr := newBifrostErrorFromMsg("provider is shutting down") - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: req.RequestType, - Provider: provider, - ModelRequested: model, - } + bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) return nil, bifrostErr case <-ctx.Done(): 
bifrost.releaseChannelMessage(msg) - return nil, newBifrostCtxDoneError(ctx, provider, model, req.RequestType, "while waiting for queue space") + bifrostErr := newBifrostCtxDoneError(ctx, "while waiting for queue space") + bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) + return nil, bifrostErr default: if bifrost.dropExcessRequests.Load() { bifrost.releaseChannelMessage(msg) bifrost.logger.Warn("request dropped: queue is full, please increase the queue size or set dropExcessRequests to false") bifrostErr := newBifrostErrorFromMsg("request dropped: queue is full") - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: req.RequestType, - Provider: provider, - ModelRequested: model, - } + bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) return nil, bifrostErr } // Re-check closing flag before blocking send (lock-free atomic check) if pq.isClosing() { bifrost.releaseChannelMessage(msg) bifrostErr := newBifrostErrorFromMsg("provider is shutting down") - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: req.RequestType, - Provider: provider, - ModelRequested: model, - } + bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) return nil, bifrostErr } select { @@ -4410,15 +4363,13 @@ func (bifrost *Bifrost) tryRequest(ctx *schemas.BifrostContext, req *schemas.Bif case <-pq.done: bifrost.releaseChannelMessage(msg) bifrostErr := newBifrostErrorFromMsg("provider is shutting down") - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: req.RequestType, - Provider: provider, - ModelRequested: model, - } + bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) return nil, bifrostErr case <-ctx.Done(): bifrost.releaseChannelMessage(msg) - return nil, newBifrostCtxDoneError(ctx, provider, model, req.RequestType, "while waiting for queue space") + bifrostErr := newBifrostCtxDoneError(ctx, "while waiting for queue space") + 
bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) + return nil, bifrostErr } } @@ -4466,7 +4417,9 @@ func (bifrost *Bifrost) tryRequest(ctx *schemas.BifrostContext, req *schemas.Bif case <-ctx.Done(): bifrost.releaseChannelMessage(msg) provider, model, _ := req.GetRequestFields() - return nil, newBifrostCtxDoneError(ctx, provider, model, req.RequestType, "waiting for provider response") + bifrostErr := newBifrostCtxDoneError(ctx, "waiting for provider response") + bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) + return nil, bifrostErr } } @@ -4477,11 +4430,7 @@ func (bifrost *Bifrost) tryStreamRequest(ctx *schemas.BifrostContext, req *schem pq, err := bifrost.getProviderQueue(provider) if err != nil { bifrostErr := newBifrostError(err) - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: req.RequestType, - Provider: provider, - ModelRequested: model, - } + bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) return nil, bifrostErr } @@ -4492,7 +4441,9 @@ func (bifrost *Bifrost) tryStreamRequest(ctx *schemas.BifrostContext, req *schem tracer := bifrost.getTracer() if tracer == nil { - return nil, newBifrostErrorFromMsg("tracer not found in context") + bifrostErr := newBifrostErrorFromMsg("tracer not found in context") + bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) + return nil, bifrostErr } // Store tracer in context BEFORE calling RunLLMPreHooks, so plugins and streaming goroutines @@ -4527,6 +4478,7 @@ func (bifrost *Bifrost) tryStreamRequest(ctx *schemas.BifrostContext, req *schem resp, bifrostErr := pipeline.RunPostLLMHooks(ctx, shortCircuit.Response, nil, preCount) drainAndAttachPluginLogs(ctx) if bifrostErr != nil { + bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) return nil, bifrostErr } return newBifrostMessageChan(resp), nil @@ -4600,6 +4552,7 @@ func (bifrost *Bifrost) tryStreamRequest(ctx *schemas.BifrostContext, req 
*schem resp, bifrostErr := pipeline.RunPostLLMHooks(ctx, nil, shortCircuit.Error, preCount) drainAndAttachPluginLogs(ctx) if bifrostErr != nil { + bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) return nil, bifrostErr } return newBifrostMessageChan(resp), nil @@ -4607,11 +4560,7 @@ func (bifrost *Bifrost) tryStreamRequest(ctx *schemas.BifrostContext, req *schem } if preReq == nil { bifrostErr := newBifrostErrorFromMsg("bifrost request after plugin hooks cannot be nil") - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: req.RequestType, - Provider: provider, - ModelRequested: model, - } + bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) return nil, bifrostErr } @@ -4623,11 +4572,7 @@ func (bifrost *Bifrost) tryStreamRequest(ctx *schemas.BifrostContext, req *schem if pq.isClosing() { bifrost.releaseChannelMessage(msg) bifrostErr := newBifrostErrorFromMsg("provider is shutting down") - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: req.RequestType, - Provider: provider, - ModelRequested: model, - } + bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) return nil, bifrostErr } @@ -4638,36 +4583,26 @@ func (bifrost *Bifrost) tryStreamRequest(ctx *schemas.BifrostContext, req *schem case <-pq.done: bifrost.releaseChannelMessage(msg) bifrostErr := newBifrostErrorFromMsg("provider is shutting down") - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: req.RequestType, - Provider: provider, - ModelRequested: model, - } + bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) return nil, bifrostErr case <-ctx.Done(): bifrost.releaseChannelMessage(msg) - return nil, newBifrostCtxDoneError(ctx, provider, model, req.RequestType, "while waiting for queue space") + bifrostErr := newBifrostCtxDoneError(ctx, "while waiting for queue space") + bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) + return nil, 
bifrostErr default: if bifrost.dropExcessRequests.Load() { bifrost.releaseChannelMessage(msg) bifrost.logger.Warn("request dropped: queue is full, please increase the queue size or set dropExcessRequests to false") bifrostErr := newBifrostErrorFromMsg("request dropped: queue is full") - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: req.RequestType, - Provider: provider, - ModelRequested: model, - } + bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) return nil, bifrostErr } // Re-check closing flag before blocking send (lock-free atomic check) if pq.isClosing() { bifrost.releaseChannelMessage(msg) bifrostErr := newBifrostErrorFromMsg("provider is shutting down") - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: req.RequestType, - Provider: provider, - ModelRequested: model, - } + bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) return nil, bifrostErr } select { @@ -4676,15 +4611,13 @@ func (bifrost *Bifrost) tryStreamRequest(ctx *schemas.BifrostContext, req *schem case <-pq.done: bifrost.releaseChannelMessage(msg) bifrostErr := newBifrostErrorFromMsg("provider is shutting down") - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: req.RequestType, - Provider: provider, - ModelRequested: model, - } + bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) return nil, bifrostErr case <-ctx.Done(): bifrost.releaseChannelMessage(msg) - return nil, newBifrostCtxDoneError(ctx, provider, model, req.RequestType, "while waiting for queue space") + bifrostErr := newBifrostCtxDoneError(ctx, "while waiting for queue space") + bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) + return nil, bifrostErr } } @@ -4980,9 +4913,10 @@ func (bifrost *Bifrost) requestWorker(provider schemas.Provider, config *schemas Error: err, }, ExtraFields: schemas.BifrostErrorExtraFields{ - Provider: provider.GetProviderKey(), - ModelRequested: 
model, - RequestType: req.RequestType, + Provider: provider.GetProviderKey(), + RequestType: req.RequestType, + OriginalModelRequested: model, + ResolvedModelUsed: model, }, } continue @@ -5009,9 +4943,10 @@ func (bifrost *Bifrost) requestWorker(provider schemas.Provider, config *schemas Error: err, }, ExtraFields: schemas.BifrostErrorExtraFields{ - Provider: provider.GetProviderKey(), - ModelRequested: model, - RequestType: req.RequestType, + Provider: provider.GetProviderKey(), + RequestType: req.RequestType, + OriginalModelRequested: model, + ResolvedModelUsed: model, }, } continue @@ -5036,9 +4971,10 @@ func (bifrost *Bifrost) requestWorker(provider schemas.Provider, config *schemas Error: err, }, ExtraFields: schemas.BifrostErrorExtraFields{ - Provider: provider.GetProviderKey(), - ModelRequested: model, - RequestType: req.RequestType, + Provider: provider.GetProviderKey(), + RequestType: req.RequestType, + OriginalModelRequested: model, + ResolvedModelUsed: model, }, } continue @@ -5053,18 +4989,38 @@ func (bifrost *Bifrost) requestWorker(provider schemas.Provider, config *schemas } } } + + originalModelRequested := model + resolvedModel := key.Aliases.Resolve(model) + + // Note: This mutates only the worker's local copy (ChannelMessage.BifrostRequest). + // Key selection already used the original alias. We also record both original and + // resolved values in ExtraFields. + req.SetModel(resolvedModel) + // Create plugin pipeline for streaming requests outside retry loop to prevent leaks var postHookRunner schemas.PostHookRunner var pipeline *PluginPipeline if IsStreamRequestType(req.RequestType) { pipeline = bifrost.getPluginPipeline() postHookRunner = func(ctx *schemas.BifrostContext, result *schemas.BifrostResponse, err *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError) { + // Populate extra fields before RunPostLLMHooks so plugins (e.g. logging) + // can read requestType/provider/model from the chunk or error. 
+ if result != nil { + result.PopulateExtraFields(req.RequestType, provider.GetProviderKey(), originalModelRequested, resolvedModel) + } + if err != nil { + err.PopulateExtraFields(req.RequestType, provider.GetProviderKey(), originalModelRequested, resolvedModel) + } resp, bifrostErr := pipeline.RunPostLLMHooks(ctx, result, err, len(*bifrost.llmPlugins.Load())) if IsFinalChunk(ctx) { drainAndAttachPluginLogs(ctx) } if bifrostErr != nil { + bifrostErr.PopulateExtraFields(req.RequestType, provider.GetProviderKey(), originalModelRequested, resolvedModel) return nil, bifrostErr + } else if resp != nil { + resp.PopulateExtraFields(req.RequestType, provider.GetProviderKey(), originalModelRequested, resolvedModel) } return resp, nil } @@ -5097,14 +5053,7 @@ func (bifrost *Bifrost) requestWorker(provider schemas.Provider, config *schemas } if bifrostError != nil { - bifrostError.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: provider.GetProviderKey(), - ModelRequested: model, - RequestType: req.RequestType, - RawRequest: bifrostError.ExtraFields.RawRequest, - RawResponse: bifrostError.ExtraFields.RawResponse, - KeyStatuses: bifrostError.ExtraFields.KeyStatuses, - } + bifrostError.PopulateExtraFields(req.RequestType, provider.GetProviderKey(), originalModelRequested, resolvedModel) // Send error with context awareness to prevent deadlock select { @@ -5118,6 +5067,9 @@ func (bifrost *Bifrost) requestWorker(provider schemas.Provider, config *schemas bifrost.logger.Warn("Timeout while sending error response, client may have disconnected") } } else { + if result != nil { + result.PopulateExtraFields(req.RequestType, provider.GetProviderKey(), originalModelRequested, resolvedModel) + } if IsStreamRequestType(req.RequestType) { // Send stream with context awareness to prevent deadlock select { @@ -5402,9 +5354,10 @@ func (bifrost *Bifrost) handleProviderRequest(provider schemas.Provider, req *Ch Message: fmt.Sprintf("unsupported request type: %s", req.RequestType), }, 
ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: req.RequestType, - Provider: provider.GetProviderKey(), - ModelRequested: model, + RequestType: req.RequestType, + Provider: provider.GetProviderKey(), + OriginalModelRequested: model, + ResolvedModelUsed: model, }, } } @@ -5438,9 +5391,10 @@ func (bifrost *Bifrost) handleProviderStreamRequest(provider schemas.Provider, r Message: fmt.Sprintf("unsupported request type: %s", req.RequestType), }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: req.RequestType, - Provider: provider.GetProviderKey(), - ModelRequested: model, + RequestType: req.RequestType, + Provider: provider.GetProviderKey(), + OriginalModelRequested: model, + ResolvedModelUsed: model, }, } } @@ -6277,62 +6231,53 @@ func (bifrost *Bifrost) selectKeyFromProviderForModel(ctx *schemas.BifrostContex skipModelCheck := (model == "" && (isFileRequestType(requestType) || isBatchRequestType(requestType) || isContainerRequestType(requestType) || isModellessVideoRequestType(requestType) || isPassthroughRequestType(requestType))) || requestType == schemas.ListModelsRequest if skipModelCheck { // When skipping model check: just verify keys are enabled and have values - for _, k := range keys { + for _, key := range keys { // Skip disabled keys - if k.Enabled != nil && !*k.Enabled { + if key.Enabled != nil && !*key.Enabled { continue } - if strings.TrimSpace(k.Value.GetValue()) != "" || CanProviderKeyValueBeEmpty(baseProviderType) { - supportedKeys = append(supportedKeys, k) + isKeyValid := validateKey(providerKey, &key) + if !isKeyValid { + bifrost.logger.Warn("key %s is not valid for provider: %s", key.ID, providerKey) + continue + } + if strings.TrimSpace(key.Value.GetValue()) != "" || CanProviderKeyValueBeEmpty(baseProviderType) { + supportedKeys = append(supportedKeys, key) } } } else { - // When NOT skipping model check: do full model/deployment filtering + // When NOT skipping model check: do full model filtering for _, key := range 
keys { // Skip disabled keys if key.Enabled != nil && !*key.Enabled { continue } + isKeyValid := validateKey(providerKey, &key) + if !isKeyValid { + bifrost.logger.Warn("key %s is not valid for provider: %s", key.ID, providerKey) + continue + } hasValue := strings.TrimSpace(key.Value.GetValue()) != "" || CanProviderKeyValueBeEmpty(baseProviderType) // ["*"] = allow all models; [] = deny all; specific list = allow only listed + // NOTE: Model filtering uses the original requested model (which may be an alias). + // key.Models and key.BlacklistedModels must therefore be expressed in alias keys. + // The provider-specific identifier is resolved later in requestWorker via key.Aliases.Resolve(model). modelSupported := hasValue && key.Models.IsAllowed(model) && !key.BlacklistedModels.IsBlocked(model) - // Additional deployment checks for Azure, Bedrock and Vertex - deploymentSupported := true - if baseProviderType == schemas.Azure && key.AzureKeyConfig != nil { - // For Azure, check if deployment exists for this model - if len(key.AzureKeyConfig.Deployments) > 0 { - _, deploymentSupported = key.AzureKeyConfig.Deployments[model] - } - } else if baseProviderType == schemas.Bedrock && key.BedrockKeyConfig != nil { - // For Bedrock, check if deployment exists for this model - if len(key.BedrockKeyConfig.Deployments) > 0 { - _, deploymentSupported = key.BedrockKeyConfig.Deployments[model] - } - } else if baseProviderType == schemas.Vertex && key.VertexKeyConfig != nil { - // For Vertex, check if deployment exists for this model - if len(key.VertexKeyConfig.Deployments) > 0 { - _, deploymentSupported = key.VertexKeyConfig.Deployments[model] - } - } else if baseProviderType == schemas.Replicate && key.ReplicateKeyConfig != nil { - // For Replicate, check if deployment exists for this model - if len(key.ReplicateKeyConfig.Deployments) > 0 { - _, deploymentSupported = key.ReplicateKeyConfig.Deployments[model] - } - } else if baseProviderType == schemas.VLLM && key.VLLMKeyConfig 
!= nil { + if baseProviderType == schemas.VLLM && key.VLLMKeyConfig != nil { // For VLLM, check if model name matches the key's configured model if key.VLLMKeyConfig.ModelName != "" { - deploymentSupported = (key.VLLMKeyConfig.ModelName == model) + modelSupported = modelSupported && (key.VLLMKeyConfig.ModelName == model) } } - if modelSupported && deploymentSupported { + if modelSupported { supportedKeys = append(supportedKeys, key) } } } if len(supportedKeys) == 0 { if baseProviderType == schemas.Azure || baseProviderType == schemas.Bedrock || baseProviderType == schemas.Vertex || baseProviderType == schemas.Replicate || baseProviderType == schemas.VLLM { - return schemas.Key{}, fmt.Errorf("no keys found that support model/deployment: %s", model) + return schemas.Key{}, fmt.Errorf("no keys found that support model: %s", model) } return schemas.Key{}, fmt.Errorf("no keys found that support model: %s", model) } diff --git a/core/internal/llmtests/account.go b/core/internal/llmtests/account.go index 7d014cb730..c136df95d3 100644 --- a/core/internal/llmtests/account.go +++ b/core/internal/llmtests/account.go @@ -207,36 +207,36 @@ func (account *ComprehensiveTestAccount) GetKeysForProvider(ctx context.Context, { Models: []string{"*"}, Weight: 1.0, + Aliases: map[string]string{ + "claude-3.7-sonnet": "us.anthropic.claude-3-7-sonnet-20250219-v1:0", + "claude-4-sonnet": "global.anthropic.claude-sonnet-4-20250514-v1:0", + "claude-4.5-sonnet": "global.anthropic.claude-sonnet-4-5-20250929-v1:0", + "claude-4.5-haiku": "global.anthropic.claude-haiku-4-5-20251001-v1:0", + }, BedrockKeyConfig: &schemas.BedrockKeyConfig{ AccessKey: *schemas.NewEnvVar("env.AWS_ACCESS_KEY_ID"), SecretKey: *schemas.NewEnvVar("env.AWS_SECRET_ACCESS_KEY"), SessionToken: schemas.NewEnvVar("env.AWS_SESSION_TOKEN"), Region: schemas.NewEnvVar(getEnvWithDefault("AWS_REGION", "us-east-1")), ARN: schemas.NewEnvVar("env.AWS_ARN"), - Deployments: map[string]string{ - "claude-3.7-sonnet": 
"us.anthropic.claude-3-7-sonnet-20250219-v1:0", - "claude-4-sonnet": "global.anthropic.claude-sonnet-4-20250514-v1:0", - "claude-4.5-sonnet": "global.anthropic.claude-sonnet-4-5-20250929-v1:0", - "claude-4.5-haiku": "global.anthropic.claude-haiku-4-5-20251001-v1:0", - }, }, }, { Models: []string{"*"}, Weight: 1.0, + Aliases: map[string]string{ + "claude-3.5-sonnet": "anthropic.claude-3-5-sonnet-20240620-v1:0", + "claude-3.7-sonnet": "us.anthropic.claude-3-7-sonnet-20250219-v1:0", + "claude-4-sonnet": "global.anthropic.claude-sonnet-4-20250514-v1:0", + "claude-4.5-sonnet": "global.anthropic.claude-sonnet-4-5-20250929-v1:0", + "claude-4.5-haiku": "global.anthropic.claude-haiku-4-5-20251001-v1:0", + }, BedrockKeyConfig: &schemas.BedrockKeyConfig{ AccessKey: *schemas.NewEnvVar("env.AWS_ACCESS_KEY_ID"), SecretKey: *schemas.NewEnvVar("env.AWS_SECRET_ACCESS_KEY"), SessionToken: schemas.NewEnvVar("env.AWS_SESSION_TOKEN"), Region: schemas.NewEnvVar(getEnvWithDefault("AWS_REGION", "us-east-1")), ARN: schemas.NewEnvVar("env.AWS_BEDROCK_ARN"), - Deployments: map[string]string{ - "claude-3.5-sonnet": "anthropic.claude-3-5-sonnet-20240620-v1:0", - "claude-3.7-sonnet": "us.anthropic.claude-3-7-sonnet-20250219-v1:0", - "claude-4-sonnet": "global.anthropic.claude-sonnet-4-20250514-v1:0", - "claude-4.5-sonnet": "global.anthropic.claude-sonnet-4-5-20250929-v1:0", - "claude-4.5-haiku": "global.anthropic.claude-haiku-4-5-20251001-v1:0", - }, }, UseForBatchAPI: bifrost.Ptr(true), }, @@ -266,18 +266,18 @@ func (account *ComprehensiveTestAccount) GetKeysForProvider(ctx context.Context, Value: *schemas.NewEnvVar("env.AZURE_API_KEY"), Models: []string{"*"}, Weight: 1.0, + Aliases: schemas.KeyAliases{ + "gpt-4o": "gpt-4o", + "gpt-4o-backup": "gpt-4o-3", + "claude-opus-4-5": "claude-opus-4-5", + "o1": "o1", + "gpt-image-1": "gpt-image-1", + "text-embedding-ada-002": "text-embedding-ada-002", + "sora-2": "sora-2", + }, AzureKeyConfig: &schemas.AzureKeyConfig{ - Endpoint: 
*schemas.NewEnvVar("env.AZURE_ENDPOINT"), - APIVersion: schemas.NewEnvVar("env.AZURE_API_VERSION"), - Deployments: map[string]string{ - "gpt-4o": "gpt-4o", - "gpt-4o-backup": "gpt-4o-3", - "claude-opus-4-5": "claude-opus-4-5", - "o1": "o1", - "gpt-image-1": "gpt-image-1", - "text-embedding-ada-002": "text-embedding-ada-002", - "sora-2": "sora-2", - }, + Endpoint: *schemas.NewEnvVar("env.AZURE_ENDPOINT"), + APIVersion: schemas.NewEnvVar("env.AZURE_API_VERSION"), ClientID: schemas.NewEnvVar("env.AZURE_CLIENT_ID"), ClientSecret: schemas.NewEnvVar("env.AZURE_CLIENT_SECRET"), TenantID: schemas.NewEnvVar("env.AZURE_TENANT_ID"), @@ -288,14 +288,15 @@ func (account *ComprehensiveTestAccount) GetKeysForProvider(ctx context.Context, Value: *schemas.NewEnvVar("env.AZURE_API_KEY"), Models: []string{"*"}, Weight: 1.0, + Aliases: schemas.KeyAliases{ + "whisper": "whisper", + "whisper-1": "whisper", + "gpt-4o-mini-tts": "gpt-4o-mini-tts", + "gpt-4o-mini-audio-preview": "gpt-4o-mini-audio-preview", + }, AzureKeyConfig: &schemas.AzureKeyConfig{ Endpoint: *schemas.NewEnvVar("env.AZURE_ENDPOINT"), APIVersion: schemas.NewEnvVar("env.AZURE_API_VERSION"), - Deployments: map[string]string{ - "whisper": "whisper", - "gpt-4o-mini-tts": "gpt-4o-mini-tts", - "gpt-4o-mini-audio-preview": "gpt-4o-mini-audio-preview", - }, }, }, }, nil @@ -329,15 +330,15 @@ func (account *ComprehensiveTestAccount) GetKeysForProvider(ctx context.Context, Value: *schemas.NewEnvVar("env.VERTEX_API_KEY"), Models: []string{"claude-sonnet-4-5", "claude-4.5-haiku", "claude-opus-4-5"}, Weight: 1.0, + Aliases: schemas.KeyAliases{ + "claude-sonnet-4-5": "claude-sonnet-4-5", + "claude-4.5-haiku": "claude-haiku-4-5@20251001", + "claude-opus-4-5": "claude-opus-4-5", + }, VertexKeyConfig: &schemas.VertexKeyConfig{ ProjectID: *schemas.NewEnvVar("env.VERTEX_PROJECT_ID"), Region: *schemas.NewEnvVar(getEnvWithDefault("VERTEX_REGION_ANTHROPIC", "us-east5")), AuthCredentials: *schemas.NewEnvVar("env.VERTEX_CREDENTIALS"), - 
Deployments: map[string]string{ - "claude-sonnet-4-5": "claude-sonnet-4-5", - "claude-4.5-haiku": "claude-haiku-4-5@20251001", - "claude-opus-4-5": "claude-opus-4-5", - }, }, UseForBatchAPI: bifrost.Ptr(true), }, diff --git a/core/internal/llmtests/image_edit.go b/core/internal/llmtests/image_edit.go index 56ad66d502..deed0bd820 100644 --- a/core/internal/llmtests/image_edit.go +++ b/core/internal/llmtests/image_edit.go @@ -364,8 +364,8 @@ func RunImageEditTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context t.Error("❌ ExtraFields.Provider is empty") } - if imageEditResponse.ExtraFields.ModelRequested == "" { - t.Error("❌ ExtraFields.ModelRequested is empty") + if imageEditResponse.ExtraFields.OriginalModelRequested == "" { + t.Error("❌ ExtraFields.OriginalModelRequested is empty") } // Validate RequestType is ImageEditRequest @@ -374,7 +374,7 @@ func RunImageEditTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context } t.Logf("✅ Image edit successful: ID=%s, Provider=%s, Model=%s, Images=%d", - imageEditResponse.ID, imageEditResponse.ExtraFields.Provider, imageEditResponse.ExtraFields.ModelRequested, len(imageEditResponse.Data)) + imageEditResponse.ID, imageEditResponse.ExtraFields.Provider, imageEditResponse.ExtraFields.OriginalModelRequested, len(imageEditResponse.Data)) }) } diff --git a/core/internal/llmtests/image_generation.go b/core/internal/llmtests/image_generation.go index 81a0626978..1516ff0088 100644 --- a/core/internal/llmtests/image_generation.go +++ b/core/internal/llmtests/image_generation.go @@ -145,12 +145,12 @@ func RunImageGenerationTest(t *testing.T, client *bifrost.Bifrost, ctx context.C t.Error("❌ ExtraFields.Provider is empty") } - if imageGenerationResponse.ExtraFields.ModelRequested == "" { - t.Error("❌ ExtraFields.ModelRequested is empty") + if imageGenerationResponse.ExtraFields.OriginalModelRequested == "" { + t.Error("❌ ExtraFields.OriginalModelRequested is empty") } t.Logf("✅ Image generation successful: ID=%s, 
Provider=%s, Model=%s, Images=%d", - imageGenerationResponse.ID, imageGenerationResponse.ExtraFields.Provider, imageGenerationResponse.ExtraFields.ModelRequested, len(imageGenerationResponse.Data)) + imageGenerationResponse.ID, imageGenerationResponse.ExtraFields.Provider, imageGenerationResponse.ExtraFields.OriginalModelRequested, len(imageGenerationResponse.Data)) }) } diff --git a/core/internal/llmtests/image_variation.go b/core/internal/llmtests/image_variation.go index 0aca33a63f..d0c4d18e78 100644 --- a/core/internal/llmtests/image_variation.go +++ b/core/internal/llmtests/image_variation.go @@ -162,8 +162,8 @@ func RunImageVariationTest(t *testing.T, client *bifrost.Bifrost, ctx context.Co t.Error("❌ ExtraFields.Provider is empty") } - if imageVariationResponse.ExtraFields.ModelRequested == "" { - t.Error("❌ ExtraFields.ModelRequested is empty") + if imageVariationResponse.ExtraFields.OriginalModelRequested == "" { + t.Error("❌ ExtraFields.OriginalModelRequested is empty") } // Validate RequestType is ImageVariationRequest @@ -172,7 +172,7 @@ func RunImageVariationTest(t *testing.T, client *bifrost.Bifrost, ctx context.Co } t.Logf("✅ Image variation successful: ID=%s, Provider=%s, Model=%s, Images=%d", - imageVariationResponse.ID, imageVariationResponse.ExtraFields.Provider, imageVariationResponse.ExtraFields.ModelRequested, len(imageVariationResponse.Data)) + imageVariationResponse.ID, imageVariationResponse.ExtraFields.Provider, imageVariationResponse.ExtraFields.OriginalModelRequested, len(imageVariationResponse.Data)) }) } diff --git a/core/internal/llmtests/response_validation.go b/core/internal/llmtests/response_validation.go index 628cfcc783..367b1eb02a 100644 --- a/core/internal/llmtests/response_validation.go +++ b/core/internal/llmtests/response_validation.go @@ -847,7 +847,7 @@ func validateResponsesBasicStructure(response *schemas.BifrostResponsesResponse, } provider := response.ExtraFields.Provider - model := response.ExtraFields.ModelDeployment 
+ model := response.ExtraFields.ResolvedModelUsed // Verify top level status is present for OpenAI and Azure with non-Claude models if provider != "" && (provider == schemas.OpenAI || provider == schemas.Azure) && !strings.Contains(strings.ToLower(model), "claude") { @@ -976,8 +976,7 @@ func validateResponsesTechnicalFields(t *testing.T, response *schemas.BifrostRes // Check model field if expectations.ShouldHaveModel { - if strings.TrimSpace(response.Model) == "" && - strings.TrimSpace(response.ExtraFields.ModelDeployment) == "" { + if strings.TrimSpace(response.Model) == "" { result.Passed = false result.Errors = append(result.Errors, fmt.Sprintf("Expected model field but not present or empty (provider: %s)", response.ExtraFields.Provider)) } diff --git a/core/internal/llmtests/speech_synthesis.go b/core/internal/llmtests/speech_synthesis.go index 4e08d6e2c8..aae66423a3 100644 --- a/core/internal/llmtests/speech_synthesis.go +++ b/core/internal/llmtests/speech_synthesis.go @@ -239,8 +239,8 @@ func RunSpeechSynthesisAdvancedTest(t *testing.T, client *bifrost.Bifrost, ctx c t.Fatalf("HD audio data too small: got %d bytes, expected at least 5000", audioSize) } - if speechResponse.ExtraFields.ModelRequested != testConfig.SpeechSynthesisModel { - t.Logf("⚠️ Expected HD model, got: %s", speechResponse.ExtraFields.ModelRequested) + if speechResponse.ExtraFields.OriginalModelRequested != testConfig.SpeechSynthesisModel { + t.Logf("⚠️ Expected HD model, got: %s", speechResponse.ExtraFields.OriginalModelRequested) } t.Logf("✅ HD speech synthesis successful: %d bytes generated", len(speechResponse.Audio)) @@ -344,8 +344,8 @@ func validateSpeechSynthesisSpecific(t *testing.T, response *schemas.BifrostSpee t.Fatalf("Audio data too small: got %d bytes, expected at least %d", audioSize, expectMinBytes) } - if expectedModel != "" && response.ExtraFields.ModelRequested != expectedModel { - t.Logf("⚠️ Expected model, got: %s", response.ExtraFields.ModelRequested) + if 
expectedModel != "" && response.ExtraFields.OriginalModelRequested != expectedModel { + t.Logf("⚠️ Expected model, got: %s", response.ExtraFields.OriginalModelRequested) } t.Logf("✅ Audio validation passed: %d bytes generated", audioSize) diff --git a/core/internal/llmtests/speech_synthesis_stream.go b/core/internal/llmtests/speech_synthesis_stream.go index 87268f3c17..8b7bdc8efb 100644 --- a/core/internal/llmtests/speech_synthesis_stream.go +++ b/core/internal/llmtests/speech_synthesis_stream.go @@ -184,8 +184,8 @@ func RunSpeechSynthesisStreamTest(t *testing.T, client *bifrost.Bifrost, ctx con if response.BifrostSpeechStreamResponse.Type != "" && (response.BifrostSpeechStreamResponse.Type != schemas.SpeechStreamResponseTypeDelta && response.BifrostSpeechStreamResponse.Type != schemas.SpeechStreamResponseTypeDone) { t.Logf("⚠️ Unexpected object type in stream: %s", response.BifrostSpeechStreamResponse.Type) } - if response.BifrostSpeechStreamResponse.ExtraFields.ModelRequested != "" && response.BifrostSpeechStreamResponse.ExtraFields.ModelRequested != testConfig.SpeechSynthesisModel { - t.Logf("⚠️ Unexpected model in stream: %s", response.BifrostSpeechStreamResponse.ExtraFields.ModelRequested) + if response.BifrostSpeechStreamResponse.ExtraFields.OriginalModelRequested != "" && response.BifrostSpeechStreamResponse.ExtraFields.OriginalModelRequested != testConfig.SpeechSynthesisModel { + t.Logf("⚠️ Unexpected model in stream: %s", response.BifrostSpeechStreamResponse.ExtraFields.OriginalModelRequested) } } @@ -348,8 +348,8 @@ func RunSpeechSynthesisStreamAdvancedTest(t *testing.T, client *bifrost.Bifrost, t.Logf("✅ HD chunk %d: %d bytes", chunkCount, chunkSize) } - if response.BifrostSpeechStreamResponse != nil && response.BifrostSpeechStreamResponse.ExtraFields.ModelRequested != "" && response.BifrostSpeechStreamResponse.ExtraFields.ModelRequested != testConfig.SpeechSynthesisModel { - t.Logf("⚠️ Unexpected HD model: %s", 
response.BifrostSpeechStreamResponse.ExtraFields.ModelRequested) + if response.BifrostSpeechStreamResponse != nil && response.BifrostSpeechStreamResponse.ExtraFields.OriginalModelRequested != "" && response.BifrostSpeechStreamResponse.ExtraFields.OriginalModelRequested != testConfig.SpeechSynthesisModel { + t.Logf("⚠️ Unexpected HD model: %s", response.BifrostSpeechStreamResponse.ExtraFields.OriginalModelRequested) } } diff --git a/core/internal/llmtests/transcription_stream.go b/core/internal/llmtests/transcription_stream.go index dfc80fc533..a28239c00f 100644 --- a/core/internal/llmtests/transcription_stream.go +++ b/core/internal/llmtests/transcription_stream.go @@ -242,8 +242,12 @@ func RunTranscriptionStreamTest(t *testing.T, client *bifrost.Bifrost, ctx conte if response.BifrostTranscriptionStreamResponse.Type != schemas.TranscriptionStreamResponseTypeDelta { t.Logf("⚠️ Unexpected object type in stream: %s", response.BifrostTranscriptionStreamResponse.Type) } - if response.BifrostTranscriptionStreamResponse.ExtraFields.ModelRequested != "" && response.BifrostTranscriptionStreamResponse.ExtraFields.ModelRequested != testConfig.TranscriptionModel { - t.Logf("⚠️ Unexpected model in stream: %s", response.BifrostTranscriptionStreamResponse.ExtraFields.ModelRequested) + gotModel := response.BifrostTranscriptionStreamResponse.ExtraFields.OriginalModelRequested + if gotModel == "" { + t.Fatal("❌ Stream chunk missing extra_fields.original_model_requested") + } + if gotModel != testConfig.TranscriptionModel { + t.Fatalf("❌ Unexpected original_model_requested in stream: got %q want %q", gotModel, testConfig.TranscriptionModel) } lastResponse = DeepCopyBifrostStreamChunk(response) diff --git a/core/internal/llmtests/video.go b/core/internal/llmtests/video.go index c622edf6b4..8ac2d6e396 100644 --- a/core/internal/llmtests/video.go +++ b/core/internal/llmtests/video.go @@ -48,8 +48,8 @@ func RunVideoGenerationTest(t *testing.T, client *bifrost.Bifrost, ctx context.C if 
resp.ExtraFields.Provider == "" { t.Fatal("❌ Video generation extra_fields.provider is empty") } - if resp.ExtraFields.ModelRequested == "" { - t.Fatal("❌ Video generation extra_fields.model_requested is empty") + if resp.ExtraFields.OriginalModelRequested == "" { + t.Fatal("❌ Video generation extra_fields.original_model_requested is empty") } t.Logf("✅ Video generation created job: id=%s status=%s", resp.ID, resp.Status) diff --git a/core/providers/anthropic/anthropic.go b/core/providers/anthropic/anthropic.go index 7f683929ef..ca683f26e0 100644 --- a/core/providers/anthropic/anthropic.go +++ b/core/providers/anthropic/anthropic.go @@ -173,7 +173,7 @@ func extractAnthropicResponsesUsageFromPrefetch(data []byte) *schemas.ResponsesR // Returns the response body or an error if the request fails. // When large response streaming is activated (BifrostContextKeyLargeResponseMode set in ctx), // returns (nil, latency, nil) — callers must check the context flag. -func (provider *AnthropicProvider) completeRequest(ctx *schemas.BifrostContext, jsonData []byte, url string, key string, meta *providerUtils.RequestMetadata) ([]byte, time.Duration, map[string]string, *schemas.BifrostError) { +func (provider *AnthropicProvider) completeRequest(ctx *schemas.BifrostContext, jsonData []byte, url string, key string, requestType schemas.RequestType) ([]byte, time.Duration, map[string]string, *schemas.BifrostError) { // Create the request with the JSON body req := fasthttp.AcquireRequest() resp := fasthttp.AcquireResponse() @@ -208,7 +208,7 @@ func (provider *AnthropicProvider) completeRequest(ctx *schemas.BifrostContext, requestClient := provider.client responseThreshold, _ := ctx.Value(schemas.BifrostContextKeyLargeResponseThreshold).(int64) - isCountTokens := meta != nil && meta.RequestType == schemas.CountTokensRequest + isCountTokens := requestType == schemas.CountTokensRequest // CountTokens responses are always tiny — skip streaming client so the response // is buffered normally 
(same approach as OpenAI and Gemini count_tokens handlers). if responseThreshold > 0 && !isCountTokens { @@ -233,20 +233,20 @@ func (provider *AnthropicProvider) completeRequest(ctx *schemas.BifrostContext, if resp.StatusCode() != fasthttp.StatusOK { providerUtils.MaterializeStreamErrorBody(ctx, resp) provider.logger.Debug("error from %s provider: %s", provider.GetProviderKey(), string(resp.Body())) - return nil, latency, providerResponseHeaders, parseAnthropicError(resp, meta) + return nil, latency, providerResponseHeaders, parseAnthropicError(resp) } // CountTokens uses buffered response (streaming skipped above) — decode directly. if isCountTokens { body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, latency, providerResponseHeaders, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, provider.GetProviderKey()) + return nil, latency, providerResponseHeaders, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } return body, latency, providerResponseHeaders, nil } // Delegate large response detection + normal buffered path to shared utility - body, isLarge, respErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, provider.GetProviderKey(), provider.logger) + body, isLarge, respErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, provider.logger) if respErr != nil { return nil, latency, providerResponseHeaders, respErr } @@ -290,10 +290,7 @@ func (provider *AnthropicProvider) listModelsByKey(ctx *schemas.BifrostContext, // Handle error response if resp.StatusCode() != fasthttp.StatusOK { - return nil, parseAnthropicError(resp, &providerUtils.RequestMetadata{ - Provider: provider.GetProviderKey(), - RequestType: schemas.ListModelsRequest, - }) + return nil, parseAnthropicError(resp) } // Parse Anthropic's response @@ -304,7 +301,7 @@ func (provider *AnthropicProvider) listModelsByKey(ctx *schemas.BifrostContext, } // Create final response - response := 
anthropicResponse.ToBifrostListModelsResponse(provider.GetProviderKey(), key.Models, key.BlacklistedModels, request.Unfiltered) + response := anthropicResponse.ToBifrostListModelsResponse(provider.GetProviderKey(), key.Models, key.BlacklistedModels, key.Aliases, request.Unfiltered) response.ExtraFields.Latency = latency.Milliseconds() // Set raw request if enabled @@ -355,18 +352,13 @@ func (provider *AnthropicProvider) TextCompletion(ctx *schemas.BifrostContext, k request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToAnthropicTextCompletionRequest(request), nil - }, - provider.GetProviderKey()) + }) if err != nil { return nil, err } // Use struct directly for JSON marshaling (no beta headers for text completion) - responseBody, latency, providerResponseHeaders, err := provider.completeRequest(ctx, jsonData, provider.buildRequestURL(ctx, "/v1/complete", schemas.TextCompletionRequest), key.Value.GetValue(), &providerUtils.RequestMetadata{ - Provider: provider.GetProviderKey(), - Model: request.Model, - RequestType: schemas.TextCompletionRequest, - }) + responseBody, latency, providerResponseHeaders, err := provider.completeRequest(ctx, jsonData, provider.buildRequestURL(ctx, "/v1/complete", schemas.TextCompletionRequest), key.Value.GetValue(), schemas.TextCompletionRequest) if providerResponseHeaders != nil { ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders) } @@ -379,9 +371,6 @@ func (provider *AnthropicProvider) TextCompletion(ctx *schemas.BifrostContext, k return &schemas.BifrostTextCompletionResponse{ Model: request.Model, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: provider.GetProviderKey(), - ModelRequested: request.Model, - RequestType: schemas.TextCompletionRequest, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerResponseHeaders, }, @@ -400,9 +389,6 @@ func (provider *AnthropicProvider) TextCompletion(ctx *schemas.BifrostContext, k bifrostResponse := 
response.ToBifrostTextCompletionResponse() // Set ExtraFields - bifrostResponse.ExtraFields.Provider = provider.GetProviderKey() - bifrostResponse.ExtraFields.ModelRequested = request.Model - bifrostResponse.ExtraFields.RequestType = schemas.TextCompletionRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -444,18 +430,13 @@ func (provider *AnthropicProvider) ChatCompletion(ctx *schemas.BifrostContext, k } AddMissingBetaHeadersToContext(ctx, anthropicReq, schemas.Anthropic) return anthropicReq, nil - }, - provider.GetProviderKey()) + }) if bifrostErr != nil { return nil, bifrostErr } // Use struct directly for JSON marshaling - responseBody, latency, providerResponseHeaders, err := provider.completeRequest(ctx, jsonData, provider.buildRequestURL(ctx, "/v1/messages", schemas.ChatCompletionRequest), key.Value.GetValue(), &providerUtils.RequestMetadata{ - Provider: provider.GetProviderKey(), - Model: request.Model, - RequestType: schemas.ChatCompletionRequest, - }) + responseBody, latency, providerResponseHeaders, err := provider.completeRequest(ctx, jsonData, provider.buildRequestURL(ctx, "/v1/messages", schemas.ChatCompletionRequest), key.Value.GetValue(), schemas.ChatCompletionRequest) if providerResponseHeaders != nil { ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders) } @@ -468,9 +449,6 @@ func (provider *AnthropicProvider) ChatCompletion(ctx *schemas.BifrostContext, k return &schemas.BifrostChatResponse{ Model: request.Model, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: provider.GetProviderKey(), - ModelRequested: request.Model, - RequestType: schemas.ChatCompletionRequest, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerResponseHeaders, }, @@ -489,9 +467,6 @@ func (provider *AnthropicProvider) ChatCompletion(ctx *schemas.BifrostContext, k bifrostResponse := response.ToBifrostChatResponse(ctx) 
// Set ExtraFields - bifrostResponse.ExtraFields.Provider = provider.GetProviderKey() - bifrostResponse.ExtraFields.ModelRequested = request.Model - bifrostResponse.ExtraFields.RequestType = schemas.ChatCompletionRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -528,8 +503,7 @@ func (provider *AnthropicProvider) ChatCompletionStream(ctx *schemas.BifrostCont anthropicReq.Stream = schemas.Ptr(true) AddMissingBetaHeadersToContext(ctx, anthropicReq, schemas.Anthropic) return anthropicReq, nil - }, - provider.GetProviderKey()) + }) if bifrostErr != nil { return nil, bifrostErr } @@ -563,11 +537,6 @@ func (provider *AnthropicProvider) ChatCompletionStream(ctx *schemas.BifrostCont postHookRunner, nil, provider.logger, - &providerUtils.RequestMetadata{ - Provider: provider.GetProviderKey(), - Model: request.Model, - RequestType: schemas.ChatCompletionStreamRequest, - }, ) } @@ -587,7 +556,6 @@ func HandleAnthropicChatCompletionStreaming( postHookRunner schemas.PostHookRunner, postResponseConverter func(*schemas.BifrostChatResponse) *schemas.BifrostChatResponse, logger schemas.Logger, - meta *providerUtils.RequestMetadata, ) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { req := fasthttp.AcquireRequest() resp := fasthttp.AcquireResponse() @@ -634,9 +602,9 @@ func HandleAnthropicChatCompletionStreaming( }, jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, providerName), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } - return nil, providerUtils.EnrichError(ctx, 
providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, providerName), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } // Store provider response headers in context before status check so error responses also forward them @@ -645,7 +613,7 @@ func HandleAnthropicChatCompletionStreaming( // Check for HTTP errors if resp.StatusCode() != fasthttp.StatusOK { defer providerUtils.ReleaseStreamingResponse(resp) - return nil, providerUtils.EnrichError(ctx, parseAnthropicError(resp, meta), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, parseAnthropicError(resp), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } // Large payload streaming passthrough — pipe raw upstream SSE to client @@ -661,14 +629,10 @@ func HandleAnthropicChatCompletionStreaming( // Start streaming in a goroutine go func() { defer func() { - model := "unknown" - if meta != nil { - model = meta.Model - } if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, model, schemas.ChatCompletionStreamRequest, logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, logger) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, model, schemas.ChatCompletionStreamRequest, logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, logger) } close(responseChan) }() @@ -678,7 +642,6 @@ func HandleAnthropicChatCompletionStreaming( bifrostErr := providerUtils.NewBifrostOperationError( "Provider returned an empty response", fmt.Errorf("provider returned an empty response"), - providerName, ) ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) 
providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, bifrostErr, jsonBody, nil, sendBackRawRequest, sendBackRawResponse), responseChan, logger) @@ -732,7 +695,7 @@ func HandleAnthropicChatCompletionStreaming( if readErr != io.EOF { ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) logger.Warn("Error reading %s stream: %v", providerName, readErr) - providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, schemas.ChatCompletionStreamRequest, providerName, modelName, logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, logger) return } break @@ -791,7 +754,6 @@ func HandleAnthropicChatCompletionStreaming( } } if event.Message != nil { - // Handle different event types modelName = event.Message.Model } @@ -840,11 +802,8 @@ func HandleAnthropicChatCompletionStreaming( }, }, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ChatCompletionStreamRequest, - Provider: providerName, - ModelRequested: modelName, - ChunkIndex: chunkIndex, - Latency: time.Since(lastChunkTime).Milliseconds(), + ChunkIndex: chunkIndex, + Latency: time.Since(lastChunkTime).Milliseconds(), }, } lastChunkTime = time.Now() @@ -868,22 +827,14 @@ func HandleAnthropicChatCompletionStreaming( response, bifrostErr, isLastChunk := event.ToBifrostChatCompletionStream(ctx, structuredOutputToolName, streamState) if bifrostErr != nil { - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: schemas.ChatCompletionStreamRequest, - Provider: providerName, - ModelRequested: modelName, - } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, logger) break } if response != nil { response.ExtraFields = schemas.BifrostResponseExtraFields{ - RequestType: schemas.ChatCompletionStreamRequest, - Provider: providerName, - ModelRequested: modelName, - ChunkIndex: chunkIndex, - 
Latency: time.Since(lastChunkTime).Milliseconds(), + ChunkIndex: chunkIndex, + Latency: time.Since(lastChunkTime).Milliseconds(), } if postResponseConverter != nil { response = postResponseConverter(response) @@ -910,7 +861,7 @@ func HandleAnthropicChatCompletionStreaming( usage.PromptTokens = usage.PromptTokens + usage.PromptTokensDetails.CachedReadTokens + usage.PromptTokensDetails.CachedWriteTokens usage.TotalTokens = usage.TotalTokens + usage.PromptTokensDetails.CachedReadTokens + usage.PromptTokensDetails.CachedWriteTokens } - response := providerUtils.CreateBifrostChatCompletionChunkResponse(messageID, usage, finishReason, chunkIndex, schemas.ChatCompletionStreamRequest, providerName, modelName, 0) + response := providerUtils.CreateBifrostChatCompletionChunkResponse(messageID, usage, finishReason, chunkIndex, modelName, 0) if postResponseConverter != nil { response = postResponseConverter(response) if response == nil { @@ -939,16 +890,12 @@ func (provider *AnthropicProvider) Responses(ctx *schemas.BifrostContext, key sc if err := providerUtils.CheckOperationAllowed(schemas.Anthropic, provider.customProviderConfig, schemas.ResponsesRequest); err != nil { return nil, err } - jsonBody, err := getRequestBodyForResponses(ctx, request, provider.GetProviderKey(), false, nil) + jsonBody, err := getRequestBodyForResponses(ctx, request, false, nil) if err != nil { return nil, err } - responseBody, latency, providerResponseHeaders, err := provider.completeRequest(ctx, jsonBody, provider.buildRequestURL(ctx, "/v1/messages", schemas.ResponsesRequest), key.Value.GetValue(), &providerUtils.RequestMetadata{ - Provider: provider.GetProviderKey(), - Model: request.Model, - RequestType: schemas.ResponsesRequest, - }) + responseBody, latency, providerResponseHeaders, err := provider.completeRequest(ctx, jsonBody, provider.buildRequestURL(ctx, "/v1/messages", schemas.ResponsesRequest), key.Value.GetValue(), schemas.ResponsesRequest) if providerResponseHeaders != nil { 
ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders) } @@ -966,9 +913,6 @@ func (provider *AnthropicProvider) Responses(ctx *schemas.BifrostContext, key sc Model: request.Model, Usage: extractAnthropicResponsesUsageFromPrefetch([]byte(preview)), ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: provider.GetProviderKey(), - ModelRequested: request.Model, - RequestType: schemas.ResponsesRequest, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerResponseHeaders, }, @@ -988,9 +932,6 @@ func (provider *AnthropicProvider) Responses(ctx *schemas.BifrostContext, key sc bifrostResponse := response.ToBifrostResponsesResponse(ctx) // Set ExtraFields - bifrostResponse.ExtraFields.Provider = provider.GetProviderKey() - bifrostResponse.ExtraFields.ModelRequested = request.Model - bifrostResponse.ExtraFields.RequestType = schemas.ResponsesRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -1014,7 +955,7 @@ func (provider *AnthropicProvider) ResponsesStream(ctx *schemas.BifrostContext, } // Convert to Anthropic format using the centralized converter - jsonBody, err := getRequestBodyForResponses(ctx, request, provider.GetProviderKey(), true, nil) + jsonBody, err := getRequestBodyForResponses(ctx, request, true, nil) if err != nil { return nil, err } @@ -1047,11 +988,6 @@ func (provider *AnthropicProvider) ResponsesStream(ctx *schemas.BifrostContext, postHookRunner, nil, provider.logger, - &providerUtils.RequestMetadata{ - Provider: provider.GetProviderKey(), - Model: request.Model, - RequestType: schemas.ResponsesStreamRequest, - }, ) } @@ -1071,7 +1007,6 @@ func HandleAnthropicResponsesStream( postHookRunner schemas.PostHookRunner, postResponseConverter func(*schemas.BifrostResponsesStreamResponse) *schemas.BifrostResponsesStreamResponse, logger schemas.Logger, - meta *providerUtils.RequestMetadata, ) (chan 
*schemas.BifrostStreamChunk, *schemas.BifrostError) { req := fasthttp.AcquireRequest() resp := fasthttp.AcquireResponse() @@ -1120,9 +1055,9 @@ func HandleAnthropicResponsesStream( }, jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, providerName), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, providerName), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } // Store provider response headers in context before status check so error responses also forward them @@ -1131,7 +1066,7 @@ func HandleAnthropicResponsesStream( // Check for HTTP errors if resp.StatusCode() != fasthttp.StatusOK { defer providerUtils.ReleaseStreamingResponse(resp) - return nil, providerUtils.EnrichError(ctx, parseAnthropicError(resp, meta), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, parseAnthropicError(resp), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } // Large payload streaming passthrough — pipe raw upstream SSE to client @@ -1147,14 +1082,10 @@ func HandleAnthropicResponsesStream( // Start streaming in a goroutine go func() { defer func() { - model := "" - if meta != nil { - model = meta.Model - } if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, 
model, schemas.ResponsesStreamRequest, logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, logger) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, model, schemas.ResponsesStreamRequest, logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, logger) } close(responseChan) }() @@ -1164,7 +1095,6 @@ func HandleAnthropicResponsesStream( bifrostErr := providerUtils.NewBifrostOperationError( "Provider returned an empty response", fmt.Errorf("provider returned an empty response"), - providerName, ) ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, bifrostErr, jsonBody, nil, sendBackRawRequest, sendBackRawResponse), responseChan, logger) @@ -1216,7 +1146,7 @@ func HandleAnthropicResponsesStream( if readErr != io.EOF { ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) logger.Warn("Error reading %s stream: %v", providerName, readErr) - providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, schemas.ResponsesStreamRequest, providerName, modelName, logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, logger) } break } @@ -1286,11 +1216,6 @@ func HandleAnthropicResponsesStream( ctx.SetValue(schemas.BifrostContextKeyHasEmittedMessageDelta, true) } if bifrostErr != nil { - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: schemas.ResponsesStreamRequest, - Provider: providerName, - ModelRequested: modelName, - } // If context was cancelled/timed out, let defer handle it if ctx.Err() != nil { return @@ -1307,12 +1232,9 @@ func HandleAnthropicResponsesStream( Type: schemas.ResponsesStreamResponseType(eventType), SequenceNumber: chunkIndex, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ResponsesStreamRequest, - Provider: 
providerName, - ModelRequested: modelName, - ChunkIndex: chunkIndex, - Latency: time.Since(lastChunkTime).Milliseconds(), - RawResponse: eventData, + ChunkIndex: chunkIndex, + Latency: time.Since(lastChunkTime).Milliseconds(), + RawResponse: eventData, }, } lastChunkTime = time.Now() @@ -1326,11 +1248,8 @@ func HandleAnthropicResponsesStream( for i, response := range responses { if response != nil { response.ExtraFields = schemas.BifrostResponseExtraFields{ - RequestType: schemas.ResponsesStreamRequest, - Provider: providerName, - ModelRequested: modelName, - ChunkIndex: chunkIndex, - Latency: time.Since(lastChunkTime).Milliseconds(), + ChunkIndex: chunkIndex, + Latency: time.Since(lastChunkTime).Milliseconds(), } if postResponseConverter != nil { response = postResponseConverter(response) @@ -1384,7 +1303,7 @@ func (provider *AnthropicProvider) BatchCreate(ctx *schemas.BifrostContext, key providerName := provider.GetProviderKey() if len(request.Requests) == 0 { - return nil, providerUtils.NewBifrostOperationError("requests array is required for Anthropic batch API", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("requests array is required for Anthropic batch API", nil) } // Create request @@ -1422,7 +1341,7 @@ func (provider *AnthropicProvider) BatchCreate(ctx *schemas.BifrostContext, key jsonData, err := providerUtils.MarshalSorted(anthropicReq) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } usedLargePayloadBody := setAnthropicRequestBody(ctx, req, jsonData) @@ -1442,12 +1361,12 @@ func (provider *AnthropicProvider) BatchCreate(ctx *schemas.BifrostContext, key // Handle error response if resp.StatusCode() != fasthttp.StatusOK { provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - return nil, ParseAnthropicError(resp, 
schemas.BatchCreateRequest, providerName, "") + return nil, parseAnthropicError(resp) } body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName), jsonData, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err), jsonData, nil, sendBackRawRequest, sendBackRawResponse) } var anthropicResp AnthropicBatchResponse @@ -1456,7 +1375,7 @@ func (provider *AnthropicProvider) BatchCreate(ctx *schemas.BifrostContext, key return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonData, body, sendBackRawRequest, sendBackRawResponse) } - return anthropicResp.ToBifrostBatchCreateResponse(providerName, latency, sendBackRawRequest, sendBackRawResponse, rawRequest, rawResponse), nil + return anthropicResp.ToBifrostBatchCreateResponse(latency, sendBackRawRequest, sendBackRawResponse, rawRequest, rawResponse), nil } // BatchList lists batch jobs using serial pagination across keys. 
@@ -1472,7 +1391,7 @@ func (provider *AnthropicProvider) BatchList(ctx *schemas.BifrostContext, keys [ // Initialize serial pagination helper (Anthropic uses AfterID for pagination) helper, err := providerUtils.NewSerialListHelper(keys, request.AfterID, provider.logger) if err != nil { - return nil, providerUtils.NewBifrostOperationError("invalid pagination cursor", err, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid pagination cursor", err) } // Get current key to query @@ -1483,10 +1402,6 @@ func (provider *AnthropicProvider) BatchList(ctx *schemas.BifrostContext, keys [ Object: "list", Data: []schemas.BifrostBatchRetrieveResponse{}, HasMore: false, - ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.BatchListRequest, - Provider: providerName, - }, }, nil } @@ -1535,12 +1450,12 @@ func (provider *AnthropicProvider) BatchList(ctx *schemas.BifrostContext, keys [ // Handle error response if resp.StatusCode() != fasthttp.StatusOK { provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - return nil, ParseAnthropicError(resp, schemas.BatchListRequest, providerName, "") + return nil, parseAnthropicError(resp) } body, decodeErr := providerUtils.CheckAndDecodeBody(resp) if decodeErr != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, decodeErr, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, decodeErr) } var anthropicResp AnthropicBatchListResponse @@ -1553,7 +1468,7 @@ func (provider *AnthropicProvider) BatchList(ctx *schemas.BifrostContext, keys [ batches := make([]schemas.BifrostBatchRetrieveResponse, 0, len(anthropicResp.Data)) var lastBatchID string for _, batch := range anthropicResp.Data { - batches = append(batches, *batch.ToBifrostBatchRetrieveResponse(providerName, latency, false, false, nil, nil)) + batches = append(batches, *batch.ToBifrostBatchRetrieveResponse(latency, false, 
false, nil, nil)) lastBatchID = batch.ID } @@ -1567,9 +1482,7 @@ func (provider *AnthropicProvider) BatchList(ctx *schemas.BifrostContext, keys [ Data: batches, HasMore: hasMore, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.BatchListRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } if nextCursor != "" { @@ -1587,7 +1500,7 @@ func (provider *AnthropicProvider) BatchRetrieve(ctx *schemas.BifrostContext, ke // batch id is required if request.BatchID == "" { - return nil, providerUtils.NewBifrostOperationError("batch_id is required", nil, schemas.Anthropic) + return nil, providerUtils.NewBifrostOperationError("batch_id is required", nil) } providerName := provider.GetProviderKey() @@ -1628,7 +1541,7 @@ func (provider *AnthropicProvider) BatchRetrieve(ctx *schemas.BifrostContext, ke // Handle error response if resp.StatusCode() != fasthttp.StatusOK { provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - lastErr = ParseAnthropicError(resp, schemas.BatchRetrieveRequest, providerName, "") + lastErr = parseAnthropicError(resp) wait() fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) @@ -1640,7 +1553,7 @@ func (provider *AnthropicProvider) BatchRetrieve(ctx *schemas.BifrostContext, ke wait() fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) - lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) continue } @@ -1658,8 +1571,7 @@ func (provider *AnthropicProvider) BatchRetrieve(ctx *schemas.BifrostContext, ke fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) - result := anthropicResp.ToBifrostBatchRetrieveResponse(providerName, latency, sendBackRawRequest, sendBackRawResponse, rawRequest, rawResponse) - result.ExtraFields.RequestType = schemas.BatchRetrieveRequest + result := 
anthropicResp.ToBifrostBatchRetrieveResponse(latency, sendBackRawRequest, sendBackRawResponse, rawRequest, rawResponse) return result, nil } @@ -1674,7 +1586,7 @@ func (provider *AnthropicProvider) BatchCancel(ctx *schemas.BifrostContext, keys // batch id is required if request.BatchID == "" { - return nil, providerUtils.NewBifrostOperationError("batch_id is required", nil, schemas.Anthropic) + return nil, providerUtils.NewBifrostOperationError("batch_id is required", nil) } providerName := provider.GetProviderKey() @@ -1711,7 +1623,7 @@ func (provider *AnthropicProvider) BatchCancel(ctx *schemas.BifrostContext, keys // Handle error response if resp.StatusCode() != fasthttp.StatusOK { provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - lastErr = ParseAnthropicError(resp, schemas.BatchCancelRequest, providerName, "") + lastErr = parseAnthropicError(resp) wait() fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) @@ -1723,7 +1635,7 @@ func (provider *AnthropicProvider) BatchCancel(ctx *schemas.BifrostContext, keys wait() fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) - lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) continue } @@ -1746,9 +1658,7 @@ func (provider *AnthropicProvider) BatchCancel(ctx *schemas.BifrostContext, keys Object: anthropicResp.Type, Status: ToBifrostBatchStatus(anthropicResp.ProcessingStatus), ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.BatchCancelRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } @@ -1791,7 +1701,7 @@ func (provider *AnthropicProvider) BatchResults(ctx *schemas.BifrostContext, key } if request.BatchID == "" { - return nil, providerUtils.NewBifrostOperationError("batch_id is required", nil, schemas.Anthropic) + return nil, 
providerUtils.NewBifrostOperationError("batch_id is required", nil) } providerName := provider.GetProviderKey() @@ -1825,7 +1735,7 @@ func (provider *AnthropicProvider) BatchResults(ctx *schemas.BifrostContext, key // Handle error response if resp.StatusCode() != fasthttp.StatusOK { provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - lastErr = ParseAnthropicError(resp, schemas.BatchResultsRequest, providerName, "") + lastErr = parseAnthropicError(resp) wait() fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) @@ -1837,7 +1747,7 @@ func (provider *AnthropicProvider) BatchResults(ctx *schemas.BifrostContext, key wait() fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) - lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) continue } @@ -1879,9 +1789,7 @@ func (provider *AnthropicProvider) BatchResults(ctx *schemas.BifrostContext, key BatchID: request.BatchID, Results: results, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.BatchResultsRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } @@ -1959,7 +1867,7 @@ func (provider *AnthropicProvider) FileUpload(ctx *schemas.BifrostContext, key s providerName := provider.GetProviderKey() if len(request.File) == 0 { - return nil, providerUtils.NewBifrostOperationError("file content is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("file content is required", nil) } // Create multipart form data @@ -1973,14 +1881,14 @@ func (provider *AnthropicProvider) FileUpload(ctx *schemas.BifrostContext, key s } part, err := writer.CreateFormFile("file", filename) if err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to create form file", err, providerName) + return nil, 
providerUtils.NewBifrostOperationError("failed to create form file", err) } if _, err := part.Write(request.File); err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to write file content", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to write file content", err) } if err := writer.Close(); err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to close multipart writer", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to close multipart writer", err) } // Create request @@ -2012,12 +1920,12 @@ func (provider *AnthropicProvider) FileUpload(ctx *schemas.BifrostContext, key s // Handle error response if resp.StatusCode() != fasthttp.StatusOK && resp.StatusCode() != fasthttp.StatusCreated { provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - return nil, ParseAnthropicError(resp, schemas.FileUploadRequest, providerName, "") + return nil, parseAnthropicError(resp) } body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } var anthropicResp AnthropicFileResponse @@ -2028,7 +1936,7 @@ func (provider *AnthropicProvider) FileUpload(ctx *schemas.BifrostContext, key s return nil, bifrostErr } - return anthropicResp.ToBifrostFileUploadResponse(providerName, latency, sendBackRawRequest, sendBackRawResponse, rawRequest, rawResponse), nil + return anthropicResp.ToBifrostFileUploadResponse(latency, sendBackRawRequest, sendBackRawResponse, rawRequest, rawResponse), nil } // FileList lists files from all provided keys and aggregates results. 
@@ -2046,7 +1954,7 @@ func (provider *AnthropicProvider) FileList(ctx *schemas.BifrostContext, keys [] // Initialize serial pagination helper helper, err := providerUtils.NewSerialListHelper(keys, request.After, provider.logger) if err != nil { - return nil, providerUtils.NewBifrostOperationError("invalid pagination cursor", err, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid pagination cursor", err) } // Get current key to query @@ -2057,10 +1965,6 @@ func (provider *AnthropicProvider) FileList(ctx *schemas.BifrostContext, keys [] Object: "list", Data: []schemas.FileObject{}, HasMore: false, - ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileListRequest, - Provider: providerName, - }, }, nil } @@ -2106,12 +2010,12 @@ func (provider *AnthropicProvider) FileList(ctx *schemas.BifrostContext, keys [] // Handle error response if resp.StatusCode() != fasthttp.StatusOK { provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - return nil, ParseAnthropicError(resp, schemas.FileListRequest, providerName, "") + return nil, parseAnthropicError(resp) } body, decodeErr := providerUtils.CheckAndDecodeBody(resp) if decodeErr != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, decodeErr, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, decodeErr) } var anthropicResp AnthropicFileListResponse @@ -2146,9 +2050,7 @@ func (provider *AnthropicProvider) FileList(ctx *schemas.BifrostContext, keys [] Data: files, HasMore: hasMore, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileListRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } if nextCursor != "" { @@ -2167,7 +2069,7 @@ func (provider *AnthropicProvider) FileRetrieve(ctx *schemas.BifrostContext, key providerName := provider.GetProviderKey() if request.FileID == "" 
{ - return nil, providerUtils.NewBifrostOperationError("file_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("file_id is required", nil) } sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) @@ -2208,7 +2110,7 @@ func (provider *AnthropicProvider) FileRetrieve(ctx *schemas.BifrostContext, key // Handle error response if resp.StatusCode() != fasthttp.StatusOK { provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - lastErr = ParseAnthropicError(resp, schemas.FileRetrieveRequest, providerName, "") + lastErr = parseAnthropicError(resp) wait() fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) @@ -2220,7 +2122,7 @@ func (provider *AnthropicProvider) FileRetrieve(ctx *schemas.BifrostContext, key wait() fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) - lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) continue } @@ -2238,7 +2140,7 @@ func (provider *AnthropicProvider) FileRetrieve(ctx *schemas.BifrostContext, key fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) - return anthropicResp.ToBifrostFileRetrieveResponse(providerName, latency, sendBackRawRequest, sendBackRawResponse, rawRequest, rawResponse), nil + return anthropicResp.ToBifrostFileRetrieveResponse(latency, sendBackRawRequest, sendBackRawResponse, rawRequest, rawResponse), nil } return nil, lastErr @@ -2253,7 +2155,7 @@ func (provider *AnthropicProvider) FileDelete(ctx *schemas.BifrostContext, keys providerName := provider.GetProviderKey() if request.FileID == "" { - return nil, providerUtils.NewBifrostOperationError("file_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("file_id is required", nil) } sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, 
provider.sendBackRawRequest) @@ -2290,7 +2192,7 @@ func (provider *AnthropicProvider) FileDelete(ctx *schemas.BifrostContext, keys // Handle error response if resp.StatusCode() != fasthttp.StatusOK && resp.StatusCode() != fasthttp.StatusNoContent { provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - lastErr = ParseAnthropicError(resp, schemas.FileDeleteRequest, providerName, "") + lastErr = parseAnthropicError(resp) wait() fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) @@ -2307,9 +2209,7 @@ func (provider *AnthropicProvider) FileDelete(ctx *schemas.BifrostContext, keys Object: "file", Deleted: true, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileDeleteRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, }, nil } @@ -2319,7 +2219,7 @@ func (provider *AnthropicProvider) FileDelete(ctx *schemas.BifrostContext, keys wait() fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) - lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) continue } @@ -2342,9 +2242,7 @@ func (provider *AnthropicProvider) FileDelete(ctx *schemas.BifrostContext, keys Object: "file", Deleted: anthropicResp.Type == "file_deleted", ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileDeleteRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } @@ -2372,7 +2270,7 @@ func (provider *AnthropicProvider) FileContent(ctx *schemas.BifrostContext, keys providerName := provider.GetProviderKey() if request.FileID == "" { - return nil, providerUtils.NewBifrostOperationError("file_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("file_id is required", nil) } var lastErr *schemas.BifrostError @@ -2404,7 +2302,7 @@ func 
(provider *AnthropicProvider) FileContent(ctx *schemas.BifrostContext, keys // Handle error response if resp.StatusCode() != fasthttp.StatusOK { provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - lastErr = ParseAnthropicError(resp, schemas.FileContentRequest, providerName, "") + lastErr = parseAnthropicError(resp) wait() fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) @@ -2416,7 +2314,7 @@ func (provider *AnthropicProvider) FileContent(ctx *schemas.BifrostContext, keys wait() fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) - lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) continue } @@ -2436,9 +2334,7 @@ func (provider *AnthropicProvider) FileContent(ctx *schemas.BifrostContext, keys Content: content, ContentType: contentType, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileContentRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, }, nil } @@ -2451,16 +2347,12 @@ func (provider *AnthropicProvider) CountTokens(ctx *schemas.BifrostContext, key if err := providerUtils.CheckOperationAllowed(schemas.Anthropic, provider.customProviderConfig, schemas.CountTokensRequest); err != nil { return nil, err } - jsonBody, err := getRequestBodyForResponses(ctx, request, provider.GetProviderKey(), false, []string{"max_tokens", "temperature"}) + jsonBody, err := getRequestBodyForResponses(ctx, request, false, []string{"max_tokens", "temperature"}) if err != nil { return nil, err } - responseBody, latency, providerResponseHeaders, bifrostErr := provider.completeRequest(ctx, jsonBody, provider.buildRequestURL(ctx, "/v1/messages/count_tokens", schemas.CountTokensRequest), key.Value.GetValue(), &providerUtils.RequestMetadata{ - Provider: provider.GetProviderKey(), - Model: request.Model, - 
RequestType: schemas.CountTokensRequest, - }) + responseBody, latency, providerResponseHeaders, bifrostErr := provider.completeRequest(ctx, jsonBody, provider.buildRequestURL(ctx, "/v1/messages/count_tokens", schemas.CountTokensRequest), key.Value.GetValue(), schemas.CountTokensRequest) if providerResponseHeaders != nil { ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders) } @@ -2484,9 +2376,6 @@ func (provider *AnthropicProvider) CountTokens(ctx *schemas.BifrostContext, key response := anthropicResponse.ToBifrostCountTokensResponse(request.Model) response.Model = request.Model - response.ExtraFields.Provider = provider.GetProviderKey() - response.ExtraFields.RequestType = schemas.CountTokensRequest - response.ExtraFields.ModelRequested = request.Model response.ExtraFields.Latency = latency.Milliseconds() response.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -2621,7 +2510,7 @@ func (provider *AnthropicProvider) Passthrough( body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to decode response body", err, provider.GetProviderKey()) + return nil, providerUtils.NewBifrostOperationError("failed to decode response body", err) } for k := range headers { @@ -2636,9 +2525,6 @@ func (provider *AnthropicProvider) Passthrough( Body: body, } - bifrostResponse.ExtraFields.Provider = provider.GetProviderKey() - bifrostResponse.ExtraFields.ModelRequested = req.Model - bifrostResponse.ExtraFields.RequestType = schemas.PassthroughRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { @@ -2702,9 +2588,9 @@ func (provider *AnthropicProvider) PassthroughStream( } } if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) { - return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, 
provider.GetProviderKey()) + return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err) } - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, provider.GetProviderKey()) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err) } headers := providerUtils.ExtractProviderResponseHeaders(resp) @@ -2715,7 +2601,6 @@ func (provider *AnthropicProvider) PassthroughStream( return nil, providerUtils.NewBifrostOperationError( "provider returned an empty stream body", fmt.Errorf("provider returned an empty stream body"), - provider.GetProviderKey(), ) } @@ -2727,11 +2612,7 @@ func (provider *AnthropicProvider) PassthroughStream( // Cancellation must close the raw stream to unblock reads. stopCancellation := providerUtils.SetupStreamCancellation(ctx, rawBodyStream, provider.logger) - extraFields := schemas.BifrostResponseExtraFields{ - Provider: provider.GetProviderKey(), - ModelRequested: req.Model, - RequestType: schemas.PassthroughStreamRequest, - } + extraFields := schemas.BifrostResponseExtraFields{} statusCode := resp.StatusCode() if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { @@ -2742,9 +2623,9 @@ func (provider *AnthropicProvider) PassthroughStream( go func() { defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, ch, provider.GetProviderKey(), req.Model, schemas.PassthroughStreamRequest, provider.logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, ch, provider.logger) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, ch, provider.GetProviderKey(), req.Model, schemas.PassthroughStreamRequest, provider.logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, ch, provider.logger) } close(ch) }() @@ -2793,7 +2674,7 @@ func (provider *AnthropicProvider) PassthroughStream( } 
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) extraFields.Latency = time.Since(startTime).Milliseconds() - providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, ch, schemas.PassthroughStreamRequest, provider.GetProviderKey(), req.Model, provider.logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, ch, provider.logger) return } } diff --git a/core/providers/anthropic/batch.go b/core/providers/anthropic/batch.go index 405738330c..ac4b0940c4 100644 --- a/core/providers/anthropic/batch.go +++ b/core/providers/anthropic/batch.go @@ -3,9 +3,7 @@ package anthropic import ( "time" - providerUtils "github.com/maximhq/bifrost/core/providers/utils" "github.com/maximhq/bifrost/core/schemas" - "github.com/valyala/fasthttp" ) // Anthropic Batch API Types @@ -129,7 +127,7 @@ func ToBifrostObjectType(anthropicType string) string { } // ToBifrostBatchCreateResponse converts Anthropic batch response to Bifrost batch create response. -func (r *AnthropicBatchResponse) ToBifrostBatchCreateResponse(providerName schemas.ModelProvider, latency time.Duration, sendBackRawRequest bool, sendBackRawResponse bool, rawRequest interface{}, rawResponse interface{}) *schemas.BifrostBatchCreateResponse { +func (r *AnthropicBatchResponse) ToBifrostBatchCreateResponse(latency time.Duration, sendBackRawRequest bool, sendBackRawResponse bool, rawRequest interface{}, rawResponse interface{}) *schemas.BifrostBatchCreateResponse { expiresAt := parseAnthropicTimestamp(r.ExpiresAt) resp := &schemas.BifrostBatchCreateResponse{ ID: r.ID, @@ -140,9 +138,7 @@ func (r *AnthropicBatchResponse) ToBifrostBatchCreateResponse(providerName schem CreatedAt: parseAnthropicTimestamp(r.CreatedAt), ExpiresAt: &expiresAt, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.BatchCreateRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } @@ -170,7 +166,7 @@ func (r *AnthropicBatchResponse) 
ToBifrostBatchCreateResponse(providerName schem } // ToBifrostBatchRetrieveResponse converts Anthropic batch response to Bifrost batch retrieve response. -func (r *AnthropicBatchResponse) ToBifrostBatchRetrieveResponse(providerName schemas.ModelProvider, latency time.Duration, sendBackRawRequest bool, sendBackRawResponse bool, rawRequest interface{}, rawResponse interface{}) *schemas.BifrostBatchRetrieveResponse { +func (r *AnthropicBatchResponse) ToBifrostBatchRetrieveResponse(latency time.Duration, sendBackRawRequest bool, sendBackRawResponse bool, rawRequest interface{}, rawResponse interface{}) *schemas.BifrostBatchRetrieveResponse { resp := &schemas.BifrostBatchRetrieveResponse{ ID: r.ID, Object: ToBifrostObjectType(r.Type), @@ -179,9 +175,7 @@ func (r *AnthropicBatchResponse) ToBifrostBatchRetrieveResponse(providerName sch ResultsURL: r.ResultsURL, CreatedAt: parseAnthropicTimestamp(r.CreatedAt), ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.BatchRetrieveRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } @@ -228,26 +222,6 @@ func (r *AnthropicBatchResponse) ToBifrostBatchRetrieveResponse(providerName sch return resp } -// ParseAnthropicError parses Anthropic error responses for batch operations. 
-func ParseAnthropicError(resp *fasthttp.Response, requestType schemas.RequestType, providerName schemas.ModelProvider, model string) *schemas.BifrostError { - var errorResp AnthropicError - bifrostErr := providerUtils.HandleProviderAPIError(resp, &errorResp) - if errorResp.Error != nil { - if errorResp.Error.Type != "" { - bifrostErr.Error.Type = &errorResp.Error.Type - } - if errorResp.Error.Message != "" { - bifrostErr.Error.Message = errorResp.Error.Message - } - } - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: requestType, - Provider: providerName, - ModelRequested: model, - } - return bifrostErr -} - // ToAnthropicBatchCreateResponse converts a Bifrost batch create response to Anthropic format. func ToAnthropicBatchCreateResponse(resp *schemas.BifrostBatchCreateResponse) *AnthropicBatchResponse { result := &AnthropicBatchResponse{ diff --git a/core/providers/anthropic/chat.go b/core/providers/anthropic/chat.go index 14f16c5004..38d986f330 100644 --- a/core/providers/anthropic/chat.go +++ b/core/providers/anthropic/chat.go @@ -418,12 +418,8 @@ func (response *AnthropicMessageResponse) ToBifrostChatResponse(ctx *schemas.Bif // Initialize Bifrost response bifrostResponse := &schemas.BifrostChatResponse{ - ID: response.ID, - Model: response.Model, - ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ChatCompletionRequest, - Provider: schemas.Anthropic, - }, + ID: response.ID, + Model: response.Model, Created: int(time.Now().Unix()), } diff --git a/core/providers/anthropic/errors.go b/core/providers/anthropic/errors.go index dd1dfaf698..81bbd49d0c 100644 --- a/core/providers/anthropic/errors.go +++ b/core/providers/anthropic/errors.go @@ -54,7 +54,7 @@ func ToAnthropicResponsesStreamError(bifrostErr *schemas.BifrostError) string { return fmt.Sprintf("event: error\ndata: %s\n\n", jsonData) } -func parseAnthropicError(resp *fasthttp.Response, meta *providerUtils.RequestMetadata) *schemas.BifrostError { +func 
parseAnthropicError(resp *fasthttp.Response) *schemas.BifrostError { var errorResp AnthropicError bifrostErr := providerUtils.HandleProviderAPIError(resp, &errorResp) if errorResp.Error != nil { @@ -64,10 +64,5 @@ func parseAnthropicError(resp *fasthttp.Response, meta *providerUtils.RequestMet bifrostErr.Error.Type = &errorResp.Error.Type bifrostErr.Error.Message = errorResp.Error.Message } - if meta != nil { - bifrostErr.ExtraFields.Provider = meta.Provider - bifrostErr.ExtraFields.ModelRequested = meta.Model - bifrostErr.ExtraFields.RequestType = meta.RequestType - } return bifrostErr } diff --git a/core/providers/anthropic/models.go b/core/providers/anthropic/models.go index bed34e5eee..3815a0244b 100644 --- a/core/providers/anthropic/models.go +++ b/core/providers/anthropic/models.go @@ -1,12 +1,14 @@ package anthropic import ( + "strings" "time" + providerUtils "github.com/maximhq/bifrost/core/providers/utils" "github.com/maximhq/bifrost/core/schemas" ) -func (response *AnthropicListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels schemas.WhiteList, blacklistedModels schemas.BlackList, unfiltered bool) *schemas.BifrostListModelsResponse { +func (response *AnthropicListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels schemas.WhiteList, blacklistedModels schemas.BlackList, aliases map[string]string, unfiltered bool) *schemas.BifrostListModelsResponse { if response == nil { return nil } @@ -18,61 +20,51 @@ func (response *AnthropicListModelsResponse) ToBifrostListModelsResponse(provide HasMore: schemas.Ptr(response.HasMore), } - // Map Anthropic's cursor-based pagination to Bifrost's token-based pagination - // If there are more results, set next_page_token to last_id so it can be used in the next request + // Map Anthropic's cursor-based pagination to Bifrost's token-based pagination. + // If there are more results, set next_page_token to last_id for the next request. 
if response.HasMore && response.LastID != nil { bifrostResponse.NextPageToken = *response.LastID } - if !unfiltered && (allowedModels.IsEmpty() || blacklistedModels.IsBlockAll()) { + pipeline := &providerUtils.ListModelsPipeline{ + AllowedModels: allowedModels, + BlacklistedModels: blacklistedModels, + Aliases: aliases, + Unfiltered: unfiltered, + ProviderKey: providerKey, + MatchFns: providerUtils.DefaultMatchFns(), + } + if pipeline.ShouldEarlyExit() { return bifrostResponse } - includedModels := make(map[string]bool) + included := make(map[string]bool) + for _, model := range response.Data { - modelID := model.ID - if !unfiltered && allowedModels.IsRestricted() { - allowed := false - for _, allowedModel := range allowedModels { - if schemas.SameBaseModel(model.ID, allowedModel) { - modelID = allowedModel - allowed = true - break - } - } - if !allowed { + for _, result := range pipeline.FilterModel(model.ID) { + resolvedKey := strings.ToLower(result.ResolvedID) + if included[resolvedKey] { continue } - } - if !unfiltered && blacklistedModels.IsBlocked(modelID) { - continue - } - bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ - ID: string(providerKey) + "/" + modelID, - Name: schemas.Ptr(model.DisplayName), - Created: schemas.Ptr(model.CreatedAt.Unix()), - MaxInputTokens: model.MaxInputTokens, - MaxOutputTokens: model.MaxTokens, - ProviderExtra: model.Capabilities, - }) - includedModels[modelID] = true - } - - // Backfill allowed models that were not in the response - if !unfiltered && allowedModels.IsRestricted() { - for _, allowedModel := range allowedModels { - if blacklistedModels.IsBlocked(allowedModel) { - continue + entry := schemas.Model{ + ID: string(providerKey) + "/" + result.ResolvedID, + Name: schemas.Ptr(model.DisplayName), + Created: schemas.Ptr(model.CreatedAt.Unix()), + MaxInputTokens: model.MaxInputTokens, + MaxOutputTokens: model.MaxTokens, + ProviderExtra: model.Capabilities, } - if !includedModels[allowedModel] { - 
bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ - ID: string(providerKey) + "/" + allowedModel, - Name: schemas.Ptr(allowedModel), - }) + if result.AliasValue != "" { + entry.Alias = schemas.Ptr(result.AliasValue) } + bifrostResponse.Data = append(bifrostResponse.Data, entry) + included[resolvedKey] = true } } + bifrostResponse.Data = append(bifrostResponse.Data, + pipeline.BackfillModels(included)...) + return bifrostResponse } diff --git a/core/providers/anthropic/responses.go b/core/providers/anthropic/responses.go index 57794fe784..bd8e494063 100644 --- a/core/providers/anthropic/responses.go +++ b/core/providers/anthropic/responses.go @@ -1429,13 +1429,13 @@ func ToAnthropicResponsesStreamResponse(ctx *schemas.BifrostContext, bifrostResp if bifrostResp.Response.ID != nil { streamMessage.ID = *bifrostResp.Response.ID } - // Preserve model from Response if available, otherwise use ExtraFields - if bifrostResp.ExtraFields.ModelRequested != "" { - if bifrostResp.Response != nil && bifrostResp.Response.Model != "" { - streamMessage.Model = bifrostResp.Response.Model - } else { - streamMessage.Model = bifrostResp.ExtraFields.ModelRequested - } + // Prefer Response.Model, then ResolvedModelUsed, then OriginalModelRequested + if bifrostResp.Response != nil && bifrostResp.Response.Model != "" { + streamMessage.Model = bifrostResp.Response.Model + } else if bifrostResp.ExtraFields.ResolvedModelUsed != "" { + streamMessage.Model = bifrostResp.ExtraFields.ResolvedModelUsed + } else if bifrostResp.ExtraFields.OriginalModelRequested != "" { + streamMessage.Model = bifrostResp.ExtraFields.OriginalModelRequested } streamResp.Message = streamMessage } diff --git a/core/providers/anthropic/text.go b/core/providers/anthropic/text.go index 3228ad49f6..39a700499b 100644 --- a/core/providers/anthropic/text.go +++ b/core/providers/anthropic/text.go @@ -103,10 +103,6 @@ func (response *AnthropicTextResponse) ToBifrostTextCompletionResponse() *schema TotalTokens: 
response.Usage.InputTokens + response.Usage.OutputTokens, }, Model: response.Model, - ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.TextCompletionRequest, - Provider: schemas.Anthropic, - }, } } diff --git a/core/providers/anthropic/types.go b/core/providers/anthropic/types.go index 18c4b646aa..36dfe7485b 100644 --- a/core/providers/anthropic/types.go +++ b/core/providers/anthropic/types.go @@ -108,7 +108,7 @@ var ProviderFeatures = map[schemas.ModelProvider]ProviderFeatureSupport{ RedactThinking: true, }, schemas.Vertex: { - WebSearch: true, // only web_search_20250305 (basic), NOT dynamic filtering + WebSearch: true, // only web_search_20250305 (basic), NOT dynamic filtering ComputerUse: true, Bash: true, Memory: true, TextEditor: true, ToolSearch: true, Compaction: true, ContextEditing: true, InterleavedThinking: true, Context1M: true, @@ -160,7 +160,7 @@ func (req *AnthropicTextRequest) IsStreamingRequested() bool { // and the effort parameter (output_config.effort) for controlling token spending. 
type AnthropicOutputConfig struct { Format json.RawMessage `json:"format,omitempty"` - Effort *string `json:"effort,omitempty"` // "low", "medium", "high", "max" (Opus 4.5+) + Effort *string `json:"effort,omitempty"` // "low", "medium", "high", "max" (Opus 4.5+) } // AnthropicMessageRequest represents an Anthropic messages API request @@ -865,24 +865,24 @@ const ( AnthropicToolTypeMemory20250818 AnthropicToolType = "memory_20250818" // Tool search (client-side, for defer_loading) - AnthropicToolTypeToolSearchBM25 AnthropicToolType = "tool_search_tool_bm25" - AnthropicToolTypeToolSearchBM2520251119 AnthropicToolType = "tool_search_tool_bm25_20251119" - AnthropicToolTypeToolSearchRegex AnthropicToolType = "tool_search_tool_regex" - AnthropicToolTypeToolSearchRegex20251119 AnthropicToolType = "tool_search_tool_regex_20251119" + AnthropicToolTypeToolSearchBM25 AnthropicToolType = "tool_search_tool_bm25" + AnthropicToolTypeToolSearchBM2520251119 AnthropicToolType = "tool_search_tool_bm25_20251119" + AnthropicToolTypeToolSearchRegex AnthropicToolType = "tool_search_tool_regex" + AnthropicToolTypeToolSearchRegex20251119 AnthropicToolType = "tool_search_tool_regex_20251119" ) type AnthropicToolName string const ( - AnthropicToolNameComputer AnthropicToolName = "computer" - AnthropicToolNameWebSearch AnthropicToolName = "web_search" - AnthropicToolNameWebFetch AnthropicToolName = "web_fetch" - AnthropicToolNameBash AnthropicToolName = "bash" - AnthropicToolNameTextEditor AnthropicToolName = "str_replace_based_edit_tool" - AnthropicToolNameCodeExecution AnthropicToolName = "code_execution" - AnthropicToolNameMemory AnthropicToolName = "memory" - AnthropicToolNameToolSearchBM25 AnthropicToolName = "tool_search_tool_bm25" - AnthropicToolNameToolSearchRegex AnthropicToolName = "tool_search_tool_regex" + AnthropicToolNameComputer AnthropicToolName = "computer" + AnthropicToolNameWebSearch AnthropicToolName = "web_search" + AnthropicToolNameWebFetch AnthropicToolName = 
"web_fetch" + AnthropicToolNameBash AnthropicToolName = "bash" + AnthropicToolNameTextEditor AnthropicToolName = "str_replace_based_edit_tool" + AnthropicToolNameCodeExecution AnthropicToolName = "code_execution" + AnthropicToolNameMemory AnthropicToolName = "memory" + AnthropicToolNameToolSearchBM25 AnthropicToolName = "tool_search_tool_bm25" + AnthropicToolNameToolSearchRegex AnthropicToolName = "tool_search_tool_regex" ) type AnthropicToolComputerUse struct { @@ -917,7 +917,7 @@ type AnthropicToolWebFetch struct { // AnthropicToolInputExample represents an input example for a tool (beta feature) type AnthropicToolInputExample struct { Input json.RawMessage `json:"input"` - Description *string `json:"description,omitempty"` + Description *string `json:"description,omitempty"` } // AnthropicTool represents a tool in Anthropic format @@ -1248,7 +1248,7 @@ type AnthropicFileDeleteResponse struct { } // ToBifrostFileUploadResponse converts an Anthropic file response to Bifrost file upload response. 
-func (r *AnthropicFileResponse) ToBifrostFileUploadResponse(providerName schemas.ModelProvider, latency time.Duration, sendBackRawRequest bool, sendBackRawResponse bool, rawRequest interface{}, rawResponse interface{}) *schemas.BifrostFileUploadResponse { +func (r *AnthropicFileResponse) ToBifrostFileUploadResponse(latency time.Duration, sendBackRawRequest bool, sendBackRawResponse bool, rawRequest interface{}, rawResponse interface{}) *schemas.BifrostFileUploadResponse { resp := &schemas.BifrostFileUploadResponse{ ID: r.ID, Object: r.Type, @@ -1259,9 +1259,7 @@ func (r *AnthropicFileResponse) ToBifrostFileUploadResponse(providerName schemas Status: schemas.FileStatusProcessed, StorageBackend: schemas.FileStorageAPI, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileUploadRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } @@ -1277,7 +1275,7 @@ func (r *AnthropicFileResponse) ToBifrostFileUploadResponse(providerName schemas } // ToBifrostFileRetrieveResponse converts an Anthropic file response to Bifrost file retrieve response. 
-func (r *AnthropicFileResponse) ToBifrostFileRetrieveResponse(providerName schemas.ModelProvider, latency time.Duration, sendBackRawRequest bool, sendBackRawResponse bool, rawRequest interface{}, rawResponse interface{}) *schemas.BifrostFileRetrieveResponse { +func (r *AnthropicFileResponse) ToBifrostFileRetrieveResponse(latency time.Duration, sendBackRawRequest bool, sendBackRawResponse bool, rawRequest interface{}, rawResponse interface{}) *schemas.BifrostFileRetrieveResponse { resp := &schemas.BifrostFileRetrieveResponse{ ID: r.ID, Object: r.Type, @@ -1288,9 +1286,7 @@ func (r *AnthropicFileResponse) ToBifrostFileRetrieveResponse(providerName schem Status: schemas.FileStatusProcessed, StorageBackend: schemas.FileStorageAPI, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileRetrieveRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } diff --git a/core/providers/anthropic/utils.go b/core/providers/anthropic/utils.go index a0ff905568..c04453644e 100644 --- a/core/providers/anthropic/utils.go +++ b/core/providers/anthropic/utils.go @@ -136,7 +136,7 @@ func setEffortOnOutputConfig(req *AnthropicMessageRequest, effort string) { req.OutputConfig.Effort = &effort } -func getRequestBodyForResponses(ctx *schemas.BifrostContext, request *schemas.BifrostResponsesRequest, providerName schemas.ModelProvider, isStreaming bool, excludeFields []string) ([]byte, *schemas.BifrostError) { +func getRequestBodyForResponses(ctx *schemas.BifrostContext, request *schemas.BifrostResponsesRequest, isStreaming bool, excludeFields []string) ([]byte, *schemas.BifrostError) { // Large payload mode: body streams directly from the LP reader in completeRequest/ // setAnthropicRequestBody — skip all body building here (matches CheckContextAndGetRequestBody). 
if providerUtils.IsLargePayloadPassthroughEnabled(ctx) { @@ -156,7 +156,7 @@ func getRequestBodyForResponses(ctx *schemas.BifrostContext, request *schemas.Bi _, model := schemas.ParseModelString(modelStr, schemas.Anthropic) jsonBody, err = providerUtils.SetJSONField(jsonBody, "model", model) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } } } @@ -168,36 +168,36 @@ func getRequestBodyForResponses(ctx *schemas.BifrostContext, request *schemas.Bi } jsonBody, err = providerUtils.SetJSONField(jsonBody, "max_tokens", defaultMaxTokens) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } } // Add stream if streaming if isStreaming { jsonBody, err = providerUtils.SetJSONField(jsonBody, "stream", true) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } } // Strip auto-injectable server-side tools to prevent conflicts with API auto-injection jsonBody, err = StripAutoInjectableTools(jsonBody) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } // Remove excluded fields for _, field := range excludeFields { jsonBody, err = providerUtils.DeleteJSONField(jsonBody, field) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } } } else { // Convert request to 
Anthropic format reqBody, convErr := ToAnthropicResponsesRequest(ctx, request) if convErr != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrRequestBodyConversion, convErr, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrRequestBodyConversion, convErr) } if reqBody == nil { - return nil, providerUtils.NewBifrostOperationError("request body is not provided", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("request body is not provided", nil) } AddMissingBetaHeadersToContext(ctx, reqBody, schemas.Anthropic) if isStreaming { @@ -206,7 +206,7 @@ func getRequestBodyForResponses(ctx *schemas.BifrostContext, request *schemas.Bi // Marshal struct to JSON bytes jsonBody, err = providerUtils.MarshalSorted(reqBody) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, fmt.Errorf("failed to marshal request body: %w", err), providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, fmt.Errorf("failed to marshal request body: %w", err)) } // Merge ExtraParams into the JSON if passthrough is enabled if ctx.Value(schemas.BifrostContextKeyPassthroughExtraParams) != nil && ctx.Value(schemas.BifrostContextKeyPassthroughExtraParams) == true { @@ -215,14 +215,14 @@ func getRequestBodyForResponses(ctx *schemas.BifrostContext, request *schemas.Bi // Use MergeExtraParamsIntoJSON which preserves key order jsonBody, err = providerUtils.MergeExtraParamsIntoJSON(jsonBody, extraParams) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } } // Remove excluded fields after merging (using sjson to preserve order) for _, field := range excludeFields { jsonBody, err = providerUtils.DeleteJSONField(jsonBody, field) if err != nil { - return nil, 
providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } } } else if len(excludeFields) > 0 { @@ -230,7 +230,7 @@ func getRequestBodyForResponses(ctx *schemas.BifrostContext, request *schemas.Bi for _, field := range excludeFields { jsonBody, err = providerUtils.DeleteJSONField(jsonBody, field) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } } } diff --git a/core/providers/azure/azure.go b/core/providers/azure/azure.go index ec98dcd88d..16d0a3c301 100644 --- a/core/providers/azure/azure.go +++ b/core/providers/azure/azure.go @@ -100,7 +100,7 @@ func (provider *AzureProvider) getAzureAuthHeaders(ctx *schemas.BifrostContext, key.AzureKeyConfig.ClientSecret != nil && key.AzureKeyConfig.TenantID != nil && key.AzureKeyConfig.ClientID.GetValue() != "" && key.AzureKeyConfig.ClientSecret.GetValue() != "" && key.AzureKeyConfig.TenantID.GetValue() != "" { cred, err := provider.getOrCreateAuth(key.AzureKeyConfig.TenantID.GetValue(), key.AzureKeyConfig.ClientID.GetValue(), key.AzureKeyConfig.ClientSecret.GetValue()) if err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to get or create Azure authentication", err, schemas.Azure) + return nil, providerUtils.NewBifrostOperationError("failed to get or create Azure authentication", err) } scopes := getAzureScopes(key.AzureKeyConfig.Scopes) @@ -109,11 +109,11 @@ func (provider *AzureProvider) getAzureAuthHeaders(ctx *schemas.BifrostContext, Scopes: scopes, }) if err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to get Azure access token", err, schemas.Azure) + return nil, providerUtils.NewBifrostOperationError("failed to get Azure access token", err) } if token.Token == "" { - return nil, 
providerUtils.NewBifrostOperationError("Azure access token is empty", errors.New("token is empty"), schemas.Azure) + return nil, providerUtils.NewBifrostOperationError("Azure access token is empty", errors.New("token is empty")) } authHeader["Authorization"] = fmt.Sprintf("Bearer %s", token.Token) @@ -138,16 +138,16 @@ func (provider *AzureProvider) getAzureAuthHeaders(ctx *schemas.BifrostContext, cred, err := provider.getOrCreateDefaultAzureCredential() if err != nil { - return nil, providerUtils.NewBifrostOperationError("no credentials provided and DefaultAzureCredential unavailable", err, schemas.Azure) + return nil, providerUtils.NewBifrostOperationError("no credentials provided and DefaultAzureCredential unavailable", err) } token, err := cred.GetToken(ctx, policy.TokenRequestOptions{Scopes: scopes}) if err != nil { - return nil, providerUtils.NewBifrostOperationError("no credentials provided and DefaultAzureCredential failed to get token", err, schemas.Azure) + return nil, providerUtils.NewBifrostOperationError("no credentials provided and DefaultAzureCredential failed to get token", err) } if token.Token == "" { - return nil, providerUtils.NewBifrostOperationError("no credentials provided and DefaultAzureCredential returned empty token", errors.New("token is empty"), schemas.Azure) + return nil, providerUtils.NewBifrostOperationError("no credentials provided and DefaultAzureCredential returned empty token", errors.New("token is empty")) } authHeader["Authorization"] = fmt.Sprintf("Bearer %s", token.Token) @@ -206,10 +206,8 @@ func (provider *AzureProvider) completeRequest( jsonData []byte, path string, key schemas.Key, - deployment string, model string, - requestType schemas.RequestType, -) ([]byte, string, time.Duration, map[string]string, *schemas.BifrostError) { +) ([]byte, time.Duration, map[string]string, *schemas.BifrostError) { // Create the request with the JSON body req := fasthttp.AcquireRequest() resp := fasthttp.AcquireResponse() @@ -222,7 +220,7 
@@ func (provider *AzureProvider) completeRequest( }() var url string - isAnthropicModel := schemas.IsAnthropicModel(deployment) + isAnthropicModel := schemas.IsAnthropicModel(model) // Set any extra headers from network config. // For Anthropic models, exclude anthropic-beta — it is merged and filtered explicitly below. @@ -237,7 +235,7 @@ func (provider *AzureProvider) completeRequest( // Get authentication headers authHeaders, bifrostErr := provider.getAzureAuthHeaders(ctx, key, isAnthropicModel) if bifrostErr != nil { - return nil, deployment, 0, nil, bifrostErr + return nil, 0, nil, bifrostErr } // Apply headers to request @@ -247,7 +245,7 @@ func (provider *AzureProvider) completeRequest( endpoint := key.AzureKeyConfig.Endpoint.GetValue() if endpoint == "" { - return nil, deployment, 0, nil, providerUtils.NewConfigurationError("endpoint not set", provider.GetProviderKey()) + return nil, 0, nil, providerUtils.NewConfigurationError("endpoint not set") } if isAnthropicModel { @@ -282,7 +280,7 @@ func (provider *AzureProvider) completeRequest( latency, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, activeClient, req, resp) defer wait() if bifrostErr != nil { - return nil, deployment, latency, nil, bifrostErr + return nil, latency, nil, bifrostErr } // Extract provider response headers before body is copied — do this before status check @@ -292,19 +290,20 @@ func (provider *AzureProvider) completeRequest( // Handle error response if resp.StatusCode() != fasthttp.StatusOK { providerUtils.MaterializeStreamErrorBody(ctx, resp) - return nil, deployment, latency, providerResponseHeaders, openai.ParseOpenAIError(resp, requestType, provider.GetProviderKey(), model) + rawErrBody := append([]byte(nil), resp.Body()...) 
+ return rawErrBody, latency, providerResponseHeaders, openai.ParseOpenAIError(resp) } - body, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, provider.GetProviderKey(), provider.logger) + body, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, provider.logger) if decodeErr != nil { - return nil, deployment, latency, providerResponseHeaders, decodeErr + return nil, latency, providerResponseHeaders, decodeErr } if isLargeResp { respOwned = false - return nil, deployment, latency, providerResponseHeaders, nil + return nil, latency, providerResponseHeaders, nil } - return body, deployment, latency, providerResponseHeaders, nil + return body, latency, providerResponseHeaders, nil } // listModelsByKey performs a list models request for a single key. @@ -312,11 +311,11 @@ func (provider *AzureProvider) completeRequest( func (provider *AzureProvider) listModelsByKey(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { // Validate Azure key configuration if key.AzureKeyConfig == nil { - return nil, providerUtils.NewConfigurationError("azure key config not set", schemas.Azure) + return nil, providerUtils.NewConfigurationError("azure key config not set") } if key.AzureKeyConfig.Endpoint.GetValue() == "" { - return nil, providerUtils.NewConfigurationError("endpoint not set", schemas.Azure) + return nil, providerUtils.NewConfigurationError("endpoint not set") } // Get API version @@ -359,12 +358,12 @@ func (provider *AzureProvider) listModelsByKey(ctx *schemas.BifrostContext, key // Handle error response if resp.StatusCode() != fasthttp.StatusOK { - return nil, openai.ParseOpenAIError(resp, schemas.ListModelsRequest, provider.GetProviderKey(), "") + return nil, openai.ParseOpenAIError(resp) } body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, 
providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, provider.GetProviderKey()) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } // Read the response body and copy it before releasing the response @@ -379,9 +378,9 @@ func (provider *AzureProvider) listModelsByKey(ctx *schemas.BifrostContext, key } // Convert to Bifrost response - response := azureResponse.ToBifrostListModelsResponse(key.Models, key.BlacklistedModels, key.AzureKeyConfig.Deployments, request.Unfiltered) + response := azureResponse.ToBifrostListModelsResponse(key.Models, key.BlacklistedModels, key.Aliases, request.Unfiltered) if response == nil { - return nil, providerUtils.NewBifrostOperationError("failed to convert Azure model list response", nil, schemas.Azure) + return nil, providerUtils.NewBifrostOperationError("failed to convert Azure model list response", nil) } response.ExtraFields.Latency = latency.Milliseconds() @@ -415,35 +414,23 @@ func (provider *AzureProvider) ListModels(ctx *schemas.BifrostContext, keys []sc // It formats the request, sends it to Azure, and processes the response. // Returns a BifrostResponse containing the completion results or an error if the request fails. 
func (provider *AzureProvider) TextCompletion(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (*schemas.BifrostTextCompletionResponse, *schemas.BifrostError) { - if err := provider.validateKeyConfig(key); err != nil { - return nil, err - } - - deployment, err := provider.getModelDeployment(key, request.Model) - if err != nil { - return nil, err - } - // Use centralized OpenAI text converter (Azure is OpenAI-compatible) jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, request, func() (providerUtils.RequestBodyWithExtraParams, error) { return openai.ToOpenAITextCompletionRequest(request), nil - }, - provider.GetProviderKey()) + }) if bifrostErr != nil { return nil, bifrostErr } - responseBody, deployment, latency, providerResponseHeaders, err := provider.completeRequest( + responseBody, latency, providerResponseHeaders, err := provider.completeRequest( ctx, jsonData, - fmt.Sprintf("openai/deployments/%s/completions", deployment), + fmt.Sprintf("openai/deployments/%s/completions", request.Model), key, - deployment, request.Model, - schemas.TextCompletionRequest, ) if providerResponseHeaders != nil { ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders) @@ -457,10 +444,6 @@ func (provider *AzureProvider) TextCompletion(ctx *schemas.BifrostContext, key s return &schemas.BifrostTextCompletionResponse{ Model: request.Model, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: provider.GetProviderKey(), - ModelRequested: request.Model, - ModelDeployment: deployment, - RequestType: schemas.TextCompletionRequest, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerResponseHeaders, }, @@ -474,10 +457,6 @@ func (provider *AzureProvider) TextCompletion(ctx *schemas.BifrostContext, key s return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonData, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse) } - 
response.ExtraFields.Provider = provider.GetProviderKey() - response.ExtraFields.ModelRequested = request.Model - response.ExtraFields.ModelDeployment = deployment - response.ExtraFields.RequestType = schemas.TextCompletionRequest response.ExtraFields.Latency = latency.Milliseconds() response.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -498,21 +477,12 @@ func (provider *AzureProvider) TextCompletion(ctx *schemas.BifrostContext, key s // It formats the request, sends it to Azure, and processes the response. // Returns a channel of BifrostStreamChunk objects or an error if the request fails. func (provider *AzureProvider) TextCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { - if err := provider.validateKeyConfig(key); err != nil { - return nil, err - } - - deployment := key.AzureKeyConfig.Deployments[request.Model] - if deployment == "" { - return nil, providerUtils.NewConfigurationError(fmt.Sprintf("deployment not found for model %s", request.Model), provider.GetProviderKey()) - } - apiVersion := key.AzureKeyConfig.APIVersion if apiVersion == nil { apiVersion = schemas.NewEnvVar(AzureAPIVersionDefault) } - url := fmt.Sprintf("%s/openai/deployments/%s/completions?api-version=%s", key.AzureKeyConfig.Endpoint.GetValue(), deployment, apiVersion.GetValue()) + url := fmt.Sprintf("%s/openai/deployments/%s/completions?api-version=%s", key.AzureKeyConfig.Endpoint.GetValue(), request.Model, apiVersion.GetValue()) // Get Azure authentication headers authHeader, err := provider.getAzureAuthHeaders(ctx, key, false) @@ -520,11 +490,6 @@ func (provider *AzureProvider) TextCompletionStream(ctx *schemas.BifrostContext, return nil, err } - customPostResponseConverter := func(response *schemas.BifrostTextCompletionResponse) *schemas.BifrostTextCompletionResponse { - response.ExtraFields.ModelDeployment = 
deployment - return response - } - return openai.HandleOpenAITextCompletionStreaming( ctx, provider.client, @@ -538,7 +503,7 @@ func (provider *AzureProvider) TextCompletionStream(ctx *schemas.BifrostContext, nil, postHookRunner, nil, - customPostResponseConverter, + nil, provider.logger, ) } @@ -547,26 +512,16 @@ func (provider *AzureProvider) TextCompletionStream(ctx *schemas.BifrostContext, // It formats the request, sends it to Azure, and processes the response. // Returns a BifrostResponse containing the completion results or an error if the request fails. func (provider *AzureProvider) ChatCompletion(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) { - if err := provider.validateKeyConfig(key); err != nil { - return nil, err - } - - deployment, err := provider.getModelDeployment(key, request.Model) - if err != nil { - return nil, err - } - jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, request, func() (providerUtils.RequestBodyWithExtraParams, error) { - if schemas.IsAnthropicModel(deployment) { + if schemas.IsAnthropicModel(request.Model) { reqBody, err := anthropic.ToAnthropicChatRequest(ctx, request) if err != nil { return nil, err } if reqBody != nil { - reqBody.Model = deployment // Add provider-aware beta headers for Azure anthropic.AddMissingBetaHeadersToContext(ctx, reqBody, schemas.Azure) } @@ -574,27 +529,24 @@ func (provider *AzureProvider) ChatCompletion(ctx *schemas.BifrostContext, key s } else { return openai.ToOpenAIChatRequest(ctx, request), nil } - }, - provider.GetProviderKey()) + }) if bifrostErr != nil { return nil, bifrostErr } var path string - if schemas.IsAnthropicModel(deployment) { + if schemas.IsAnthropicModel(request.Model) { path = "anthropic/v1/messages" } else { - path = fmt.Sprintf("openai/deployments/%s/chat/completions", deployment) + path = fmt.Sprintf("openai/deployments/%s/chat/completions", request.Model) } - 
responseBody, deployment, latency, providerResponseHeaders, err := provider.completeRequest( + responseBody, latency, providerResponseHeaders, err := provider.completeRequest( ctx, jsonData, path, key, - deployment, request.Model, - schemas.ChatCompletionRequest, ) if providerResponseHeaders != nil { ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders) @@ -608,10 +560,6 @@ func (provider *AzureProvider) ChatCompletion(ctx *schemas.BifrostContext, key s return &schemas.BifrostChatResponse{ Model: request.Model, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: provider.GetProviderKey(), - ModelRequested: request.Model, - ModelDeployment: deployment, - RequestType: schemas.ChatCompletionRequest, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerResponseHeaders, }, @@ -622,7 +570,7 @@ func (provider *AzureProvider) ChatCompletion(ctx *schemas.BifrostContext, key s var rawRequest interface{} var rawResponse interface{} - if schemas.IsAnthropicModel(deployment) { + if schemas.IsAnthropicModel(request.Model) { anthropicResponse := anthropic.AcquireAnthropicMessageResponse() defer anthropic.ReleaseAnthropicMessageResponse(anthropicResponse) rawRequest, rawResponse, bifrostErr = providerUtils.HandleProviderResponse(responseBody, anthropicResponse, jsonData, providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)) @@ -637,12 +585,8 @@ func (provider *AzureProvider) ChatCompletion(ctx *schemas.BifrostContext, key s } } - response.ExtraFields.Provider = provider.GetProviderKey() - response.ExtraFields.ModelRequested = request.Model - response.ExtraFields.ModelDeployment = deployment response.ExtraFields.Latency = latency.Milliseconds() response.ExtraFields.ProviderResponseHeaders = providerResponseHeaders - response.ExtraFields.RequestType = schemas.ChatCompletionRequest // Set raw request if enabled if 
providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { @@ -662,22 +606,8 @@ func (provider *AzureProvider) ChatCompletion(ctx *schemas.BifrostContext, key s // Uses Azure-specific URL construction with deployments and supports both api-key and Bearer token authentication. // Returns a channel containing BifrostResponse objects representing the stream or an error if the request fails. func (provider *AzureProvider) ChatCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { - if err := provider.validateKeyConfig(key); err != nil { - return nil, err - } - - deployment, err := provider.getModelDeployment(key, request.Model) - if err != nil { - return nil, err - } - - postResponseConverter := func(response *schemas.BifrostChatResponse) *schemas.BifrostChatResponse { - response.ExtraFields.ModelDeployment = deployment - return response - } - var url string - if schemas.IsAnthropicModel(deployment) { + if schemas.IsAnthropicModel(request.Model) { authHeader, err := provider.getAzureAuthHeaders(ctx, key, true) if err != nil { return nil, err @@ -694,14 +624,12 @@ func (provider *AzureProvider) ChatCompletionStream(ctx *schemas.BifrostContext, return nil, err } if reqBody != nil { - reqBody.Model = deployment reqBody.Stream = schemas.Ptr(true) // Add provider-aware beta headers for Azure anthropic.AddMissingBetaHeadersToContext(ctx, reqBody, schemas.Azure) } return reqBody, nil - }, - provider.GetProviderKey()) + }) if err != nil { return nil, err } @@ -719,13 +647,8 @@ func (provider *AzureProvider) ChatCompletionStream(ctx *schemas.BifrostContext, providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse), provider.GetProviderKey(), postHookRunner, - postResponseConverter, + nil, provider.logger, - &providerUtils.RequestMetadata{ - Provider: provider.GetProviderKey(), - Model: request.Model, - 
RequestType: schemas.ChatCompletionStreamRequest, - }, ) } else { authHeader, err := provider.getAzureAuthHeaders(ctx, key, false) @@ -736,7 +659,7 @@ func (provider *AzureProvider) ChatCompletionStream(ctx *schemas.BifrostContext, if apiVersion == nil { apiVersion = schemas.NewEnvVar(AzureAPIVersionDefault) } - url = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=%s", key.AzureKeyConfig.Endpoint.GetValue(), deployment, apiVersion.GetValue()) + url = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=%s", key.AzureKeyConfig.Endpoint.GetValue(), request.Model, apiVersion.GetValue()) // Use shared streaming logic from OpenAI return openai.HandleOpenAIChatCompletionStreaming( @@ -754,7 +677,7 @@ func (provider *AzureProvider) ChatCompletionStream(ctx *schemas.BifrostContext, nil, nil, nil, - postResponseConverter, + nil, provider.logger, ) } @@ -764,51 +687,36 @@ func (provider *AzureProvider) ChatCompletionStream(ctx *schemas.BifrostContext, // It formats the request, sends it to Azure, and processes the response. // Returns a BifrostResponse containing the completion results or an error if the request fails. 
func (provider *AzureProvider) Responses(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { - if err := provider.validateKeyConfig(key); err != nil { - return nil, err - } - - deployment, err := provider.getModelDeployment(key, request.Model) - if err != nil { - return nil, err - } - var jsonData []byte var bifrostErr *schemas.BifrostError - if schemas.IsAnthropicModel(deployment) { - jsonData, bifrostErr = getRequestBodyForAnthropicResponses(ctx, request, deployment, provider.GetProviderKey(), false) + if schemas.IsAnthropicModel(request.Model) { + jsonData, bifrostErr = getRequestBodyForAnthropicResponses(ctx, request, request.Model, false) } else { jsonData, bifrostErr = providerUtils.CheckContextAndGetRequestBody( ctx, request, func() (providerUtils.RequestBodyWithExtraParams, error) { reqBody := openai.ToOpenAIResponsesRequest(request) - if reqBody != nil { - reqBody.Model = deployment - } return reqBody, nil - }, - provider.GetProviderKey()) + }) } if bifrostErr != nil { return nil, bifrostErr } var path string - if schemas.IsAnthropicModel(deployment) { + if schemas.IsAnthropicModel(request.Model) { path = "anthropic/v1/messages" } else { path = "openai/v1/responses" } - responseBody, deployment, latency, providerResponseHeaders, err := provider.completeRequest( + responseBody, latency, providerResponseHeaders, err := provider.completeRequest( ctx, jsonData, path, key, - deployment, request.Model, - schemas.ResponsesRequest, ) if providerResponseHeaders != nil { ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders) @@ -822,10 +730,6 @@ func (provider *AzureProvider) Responses(ctx *schemas.BifrostContext, key schema return &schemas.BifrostResponsesResponse{ Model: request.Model, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: provider.GetProviderKey(), - ModelRequested: request.Model, - ModelDeployment: deployment, - 
RequestType: schemas.ResponsesRequest, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerResponseHeaders, }, @@ -836,7 +740,7 @@ func (provider *AzureProvider) Responses(ctx *schemas.BifrostContext, key schema var rawRequest interface{} var rawResponse interface{} - if schemas.IsAnthropicModel(deployment) { + if schemas.IsAnthropicModel(request.Model) { anthropicResponse := anthropic.AcquireAnthropicMessageResponse() defer anthropic.ReleaseAnthropicMessageResponse(anthropicResponse) rawRequest, rawResponse, bifrostErr = providerUtils.HandleProviderResponse(responseBody, anthropicResponse, jsonData, providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)) @@ -851,12 +755,8 @@ func (provider *AzureProvider) Responses(ctx *schemas.BifrostContext, key schema } } - response.ExtraFields.Provider = provider.GetProviderKey() - response.ExtraFields.ModelRequested = request.Model - response.ExtraFields.ModelDeployment = deployment response.ExtraFields.Latency = latency.Milliseconds() response.ExtraFields.ProviderResponseHeaders = providerResponseHeaders - response.ExtraFields.RequestType = schemas.ResponsesRequest // Set raw request if enabled if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { @@ -873,22 +773,8 @@ func (provider *AzureProvider) Responses(ctx *schemas.BifrostContext, key schema // ResponsesStream performs a streaming responses request to Azure's API. 
func (provider *AzureProvider) ResponsesStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { - if err := provider.validateKeyConfig(key); err != nil { - return nil, err - } - - deployment, err := provider.getModelDeployment(key, request.Model) - if err != nil { - return nil, err - } - - postResponseConverter := func(response *schemas.BifrostResponsesStreamResponse) *schemas.BifrostResponsesStreamResponse { - response.ExtraFields.ModelDeployment = deployment - return response - } - var url string - if schemas.IsAnthropicModel(deployment) { + if schemas.IsAnthropicModel(request.Model) { authHeader, err := provider.getAzureAuthHeaders(ctx, key, true) if err != nil { return nil, err @@ -896,7 +782,7 @@ func (provider *AzureProvider) ResponsesStream(ctx *schemas.BifrostContext, post authHeader["anthropic-version"] = AzureAnthropicAPIVersionDefault url = fmt.Sprintf("%s/anthropic/v1/messages", key.AzureKeyConfig.Endpoint.GetValue()) - jsonData, bifrostErr := getRequestBodyForAnthropicResponses(ctx, request, deployment, provider.GetProviderKey(), true) + jsonData, bifrostErr := getRequestBodyForAnthropicResponses(ctx, request, request.Model, true) if bifrostErr != nil { return nil, bifrostErr } @@ -914,13 +800,8 @@ func (provider *AzureProvider) ResponsesStream(ctx *schemas.BifrostContext, post providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse), provider.GetProviderKey(), postHookRunner, - postResponseConverter, + nil, provider.logger, - &providerUtils.RequestMetadata{ - Provider: provider.GetProviderKey(), - Model: request.Model, - RequestType: schemas.ResponsesStreamRequest, - }, ) } else { authHeader, err := provider.getAzureAuthHeaders(ctx, key, false) @@ -929,11 +810,6 @@ func (provider *AzureProvider) ResponsesStream(ctx *schemas.BifrostContext, post } url = 
fmt.Sprintf("%s/openai/v1/responses?api-version=preview", key.AzureKeyConfig.Endpoint.GetValue()) - postRequestConverter := func(req *openai.OpenAIResponsesRequest) *openai.OpenAIResponsesRequest { - req.Model = deployment - return req - } - // Use shared streaming logic from OpenAI return openai.HandleOpenAIResponsesStreaming( ctx, @@ -948,8 +824,8 @@ func (provider *AzureProvider) ResponsesStream(ctx *schemas.BifrostContext, post postHookRunner, nil, nil, - postRequestConverter, - postResponseConverter, + nil, + nil, provider.logger, ) } @@ -959,35 +835,23 @@ func (provider *AzureProvider) ResponsesStream(ctx *schemas.BifrostContext, post // The input can be either a single string or a slice of strings for batch embedding. // Returns a BifrostResponse containing the embedding(s) and any error that occurred. func (provider *AzureProvider) Embedding(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostEmbeddingRequest) (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) { - if err := provider.validateKeyConfig(key); err != nil { - return nil, err - } - - deployment, err := provider.getModelDeployment(key, request.Model) - if err != nil { - return nil, err - } - // Use centralized converter jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, request, func() (providerUtils.RequestBodyWithExtraParams, error) { return openai.ToOpenAIEmbeddingRequest(request), nil - }, - provider.GetProviderKey()) + }) if bifrostErr != nil { return nil, bifrostErr } - responseBody, deployment, latency, providerResponseHeaders, err := provider.completeRequest( + responseBody, latency, providerResponseHeaders, err := provider.completeRequest( ctx, jsonData, - fmt.Sprintf("openai/deployments/%s/embeddings", deployment), + fmt.Sprintf("openai/deployments/%s/embeddings", request.Model), key, - deployment, request.Model, - schemas.EmbeddingRequest, ) if providerResponseHeaders != nil { ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, 
providerResponseHeaders) @@ -1001,10 +865,6 @@ func (provider *AzureProvider) Embedding(ctx *schemas.BifrostContext, key schema return &schemas.BifrostEmbeddingResponse{ Model: request.Model, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: provider.GetProviderKey(), - ModelRequested: request.Model, - ModelDeployment: deployment, - RequestType: schemas.EmbeddingRequest, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerResponseHeaders, }, @@ -1019,12 +879,8 @@ func (provider *AzureProvider) Embedding(ctx *schemas.BifrostContext, key schema return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonData, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse) } - response.ExtraFields.Provider = provider.GetProviderKey() response.ExtraFields.Latency = latency.Milliseconds() response.ExtraFields.ProviderResponseHeaders = providerResponseHeaders - response.ExtraFields.ModelRequested = request.Model - response.ExtraFields.ModelDeployment = deployment - response.ExtraFields.RequestType = schemas.EmbeddingRequest // Set raw request if enabled if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { @@ -1041,15 +897,6 @@ func (provider *AzureProvider) Embedding(ctx *schemas.BifrostContext, key schema // Speech is not supported by the Azure provider. 
func (provider *AzureProvider) Speech(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostSpeechRequest) (*schemas.BifrostSpeechResponse, *schemas.BifrostError) { - if err := provider.validateKeyConfig(key); err != nil { - return nil, err - } - - deployment, err := provider.getModelDeployment(key, request.Model) - if err != nil { - return nil, err - } - apiVersion := key.AzureKeyConfig.APIVersion if apiVersion == nil { apiVersion = schemas.NewEnvVar(AzureAPIVersionDefault) @@ -1057,10 +904,10 @@ func (provider *AzureProvider) Speech(ctx *schemas.BifrostContext, key schemas.K endpoint := key.AzureKeyConfig.Endpoint.GetValue() if endpoint == "" { - return nil, providerUtils.NewConfigurationError("endpoint not set", provider.GetProviderKey()) + return nil, providerUtils.NewConfigurationError("endpoint not set") } - url := fmt.Sprintf("%s/openai/deployments/%s/audio/speech?api-version=%s", endpoint, deployment, apiVersion.GetValue()) + url := fmt.Sprintf("%s/openai/deployments/%s/audio/speech?api-version=%s", endpoint, request.Model, apiVersion.GetValue()) response, err := openai.HandleOpenAISpeechRequest( ctx, @@ -1080,9 +927,6 @@ func (provider *AzureProvider) Speech(ctx *schemas.BifrostContext, key schemas.K return nil, err } - response.ExtraFields.ModelRequested = request.Model - response.ExtraFields.ModelDeployment = deployment - return response, err } @@ -1094,15 +938,6 @@ func (provider *AzureProvider) Rerank(ctx *schemas.BifrostContext, key schemas.K // SpeechStream handles streaming for speech synthesis with Azure. // Azure sends raw binary audio bytes in SSE format, unlike OpenAI which sends JSON. 
func (provider *AzureProvider) SpeechStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { - if err := provider.validateKeyConfig(key); err != nil { - return nil, err - } - - deployment, err := provider.getModelDeployment(key, request.Model) - if err != nil { - return nil, err - } - // Get Azure authentication headers authHeader, err := provider.getAzureAuthHeaders(ctx, key, false) if err != nil { @@ -1113,7 +948,7 @@ func (provider *AzureProvider) SpeechStream(ctx *schemas.BifrostContext, postHoo if apiVersion == nil { apiVersion = schemas.NewEnvVar(AzureAPIVersionDefault) } - url := fmt.Sprintf("%s/openai/deployments/%s/audio/speech?api-version=%s", key.AzureKeyConfig.Endpoint.GetValue(), deployment, apiVersion.GetValue()) + url := fmt.Sprintf("%s/openai/deployments/%s/audio/speech?api-version=%s", key.AzureKeyConfig.Endpoint.GetValue(), request.Model, apiVersion.GetValue()) // Create HTTP request for streaming req := fasthttp.AcquireRequest() @@ -1153,11 +988,9 @@ func (provider *AzureProvider) SpeechStream(ctx *schemas.BifrostContext, postHoo reqBody := openai.ToOpenAISpeechRequest(request) if reqBody != nil { reqBody.StreamFormat = schemas.Ptr("sse") - reqBody.Model = deployment // Replace model with deployment } return reqBody, nil - }, - provider.GetProviderKey()) + }) if bifrostErr != nil { return nil, bifrostErr } @@ -1181,9 +1014,9 @@ func (provider *AzureProvider) SpeechStream(ctx *schemas.BifrostContext, postHoo }, jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } if errors.Is(requestErr, fasthttp.ErrTimeout) || errors.Is(requestErr, context.DeadlineExceeded) { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, requestErr, provider.GetProviderKey()), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, 
providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, requestErr), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, requestErr, provider.GetProviderKey()), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, requestErr), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } // Extract provider response headers before status check so error responses also forward them @@ -1192,7 +1025,7 @@ func (provider *AzureProvider) SpeechStream(ctx *schemas.BifrostContext, postHoo // Check for HTTP errors if resp.StatusCode() != fasthttp.StatusOK { defer providerUtils.ReleaseStreamingResponse(resp) - return nil, providerUtils.EnrichError(ctx, openai.ParseOpenAIError(resp, schemas.SpeechStreamRequest, provider.GetProviderKey(), request.Model), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, openai.ParseOpenAIError(resp), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } // Create response channel @@ -1204,9 +1037,9 @@ func (provider *AzureProvider) SpeechStream(ctx *schemas.BifrostContext, postHoo go func() { defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.GetProviderKey(), request.Model, schemas.SpeechStreamRequest, provider.logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.logger) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.GetProviderKey(), request.Model, schemas.SpeechStreamRequest, provider.logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.logger) } close(responseChan) }() @@ 
-1307,11 +1140,6 @@ func (provider *AzureProvider) SpeechStream(ctx *schemas.BifrostContext, postHoo var bifrostErr schemas.BifrostError if errParseErr := sonic.Unmarshal(audioData, &bifrostErr); errParseErr == nil { if bifrostErr.Error != nil && bifrostErr.Error.Message != "" { - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: provider.GetProviderKey(), - ModelRequested: request.Model, - RequestType: schemas.SpeechStreamRequest, - } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, &bifrostErr, responseChan, provider.logger) return @@ -1333,12 +1161,8 @@ func (provider *AzureProvider) SpeechStream(ctx *schemas.BifrostContext, postHoo // Set extra fields for the response response.ExtraFields = schemas.BifrostResponseExtraFields{ - RequestType: schemas.SpeechStreamRequest, - Provider: provider.GetProviderKey(), - ModelRequested: request.Model, - ModelDeployment: deployment, - ChunkIndex: chunkIndex, - Latency: time.Since(lastChunkTime).Milliseconds(), + ChunkIndex: chunkIndex, + Latency: time.Since(lastChunkTime).Milliseconds(), } lastChunkTime = time.Now() @@ -1367,7 +1191,7 @@ func (provider *AzureProvider) SpeechStream(ctx *schemas.BifrostContext, postHoo // a fake "done" response with truncated audio. 
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) provider.logger.Warn("Error reading stream: %v", readErr) - providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, schemas.SpeechStreamRequest, provider.GetProviderKey(), request.Model, provider.logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, provider.logger) return } break @@ -1380,12 +1204,8 @@ func (provider *AzureProvider) SpeechStream(ctx *schemas.BifrostContext, postHoo finalResponse := schemas.BifrostSpeechStreamResponse{ Type: schemas.SpeechStreamResponseTypeDone, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.SpeechStreamRequest, - Provider: provider.GetProviderKey(), - ModelRequested: request.Model, - ModelDeployment: deployment, - ChunkIndex: chunkIndex + 1, - Latency: time.Since(startTime).Milliseconds(), + ChunkIndex: chunkIndex + 1, + Latency: time.Since(startTime).Milliseconds(), }, } @@ -1408,21 +1228,12 @@ func (provider *AzureProvider) SpeechStream(ctx *schemas.BifrostContext, postHoo // Transcription is not supported by the Azure provider. 
func (provider *AzureProvider) Transcription(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (*schemas.BifrostTranscriptionResponse, *schemas.BifrostError) { - if err := provider.validateKeyConfig(key); err != nil { - return nil, err - } - - deployment, err := provider.getModelDeployment(key, request.Model) - if err != nil { - return nil, err - } - apiVersion := key.AzureKeyConfig.APIVersion if apiVersion == nil { apiVersion = schemas.NewEnvVar(AzureAPIVersionDefault) } - url := fmt.Sprintf("%s/openai/deployments/%s/audio/transcriptions?api-version=%s", key.AzureKeyConfig.Endpoint.GetValue(), deployment, apiVersion.GetValue()) + url := fmt.Sprintf("%s/openai/deployments/%s/audio/transcriptions?api-version=%s", key.AzureKeyConfig.Endpoint.GetValue(), request.Model, apiVersion.GetValue()) response, err := openai.HandleOpenAITranscriptionRequest( ctx, @@ -1441,9 +1252,6 @@ func (provider *AzureProvider) Transcription(ctx *schemas.BifrostContext, key sc return nil, err } - response.ExtraFields.ModelRequested = request.Model - response.ExtraFields.ModelDeployment = deployment - return response, err } @@ -1457,16 +1265,6 @@ func (provider *AzureProvider) TranscriptionStream(ctx *schemas.BifrostContext, // Returns a BifrostResponse containing the bifrost response or an error if the request fails. 
func (provider *AzureProvider) ImageGeneration(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostImageGenerationRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) { - // Validate api key configs - if err := provider.validateKeyConfig(key); err != nil { - return nil, err - } - - deployment := key.AzureKeyConfig.Deployments[request.Model] - if deployment == "" { - return nil, providerUtils.NewConfigurationError(fmt.Sprintf("deployment not found for model %s", request.Model), provider.GetProviderKey()) - } - apiVersion := key.AzureKeyConfig.APIVersion if apiVersion == nil || apiVersion.GetValue() == "" { apiVersion = schemas.NewEnvVar(AzureAPIVersionDefault) @@ -1474,13 +1272,13 @@ func (provider *AzureProvider) ImageGeneration(ctx *schemas.BifrostContext, key endpoint := key.AzureKeyConfig.Endpoint.GetValue() if endpoint == "" { - return nil, providerUtils.NewConfigurationError("endpoint not set", provider.GetProviderKey()) + return nil, providerUtils.NewConfigurationError("endpoint not set") } response, err := openai.HandleOpenAIImageGenerationRequest( ctx, provider.client, - fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", endpoint, deployment, apiVersion.GetValue()), + fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", endpoint, request.Model, apiVersion.GetValue()), request, key, provider.networkConfig.ExtraHeaders, @@ -1493,9 +1291,6 @@ func (provider *AzureProvider) ImageGeneration(ctx *schemas.BifrostContext, key return nil, err } - response.ExtraFields.ModelRequested = request.Model - response.ExtraFields.ModelDeployment = deployment - return response, err } @@ -1508,18 +1303,6 @@ func (provider *AzureProvider) ImageGenerationStream( key schemas.Key, request *schemas.BifrostImageGenerationRequest, ) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { - - // Validate api key configs - if err := provider.validateKeyConfig(key); err != nil { - return nil, err - } - 
- // - deployment := key.AzureKeyConfig.Deployments[request.Model] - if deployment == "" { - return nil, providerUtils.NewConfigurationError(fmt.Sprintf("deployment not found for model %s", request.Model), provider.GetProviderKey()) - } - apiVersion := key.AzureKeyConfig.APIVersion if apiVersion == nil || apiVersion.GetValue() == "" { apiVersion = schemas.NewEnvVar(AzureAPIVersionDefault) @@ -1527,17 +1310,10 @@ func (provider *AzureProvider) ImageGenerationStream( endpoint := key.AzureKeyConfig.Endpoint.GetValue() if endpoint == "" { - return nil, providerUtils.NewConfigurationError("endpoint not set", provider.GetProviderKey()) + return nil, providerUtils.NewConfigurationError("endpoint not set") } - postResponseConverter := func(resp *schemas.BifrostImageGenerationStreamResponse) *schemas.BifrostImageGenerationStreamResponse { - if resp != nil { - resp.ExtraFields.ModelDeployment = deployment - } - return resp - } - - url := fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", endpoint, deployment, apiVersion.GetValue()) + url := fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", endpoint, request.Model, apiVersion.GetValue()) authHeader, err := provider.getAzureAuthHeaders(ctx, key, false) if err != nil { @@ -1558,7 +1334,7 @@ func (provider *AzureProvider) ImageGenerationStream( postHookRunner, nil, nil, - postResponseConverter, + nil, provider.logger, ) @@ -1566,16 +1342,6 @@ func (provider *AzureProvider) ImageGenerationStream( // ImageEdit performs an image edit request to Azure's API. 
func (provider *AzureProvider) ImageEdit(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostImageEditRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) { - // Validate api key configs - if err := provider.validateKeyConfig(key); err != nil { - return nil, err - } - - deployment := key.AzureKeyConfig.Deployments[request.Model] - if deployment == "" { - return nil, providerUtils.NewConfigurationError(fmt.Sprintf("deployment not found for model %s", request.Model), provider.GetProviderKey()) - } - apiVersion := key.AzureKeyConfig.APIVersion if apiVersion == nil || apiVersion.GetValue() == "" { apiVersion = schemas.NewEnvVar(AzureAPIVersionImageEditDefault) @@ -1583,10 +1349,10 @@ func (provider *AzureProvider) ImageEdit(ctx *schemas.BifrostContext, key schema endpoint := key.AzureKeyConfig.Endpoint.GetValue() if endpoint == "" { - return nil, providerUtils.NewConfigurationError("endpoint not set", provider.GetProviderKey()) + return nil, providerUtils.NewConfigurationError("endpoint not set") } - url := fmt.Sprintf("%s/openai/deployments/%s/images/edits?api-version=%s", endpoint, deployment, apiVersion.GetValue()) + url := fmt.Sprintf("%s/openai/deployments/%s/images/edits?api-version=%s", endpoint, request.Model, apiVersion.GetValue()) response, err := openai.HandleOpenAIImageEditRequest( ctx, provider.client, @@ -1603,24 +1369,11 @@ func (provider *AzureProvider) ImageEdit(ctx *schemas.BifrostContext, key schema return nil, err } - response.ExtraFields.ModelRequested = request.Model - response.ExtraFields.ModelDeployment = deployment - return response, err } // ImageEditStream performs a streaming image edit request to Azure's API. 
func (provider *AzureProvider) ImageEditStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostImageEditRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { - // Validate api key configs - if err := provider.validateKeyConfig(key); err != nil { - return nil, err - } - - deployment := key.AzureKeyConfig.Deployments[request.Model] - if deployment == "" { - return nil, providerUtils.NewConfigurationError(fmt.Sprintf("deployment not found for model %s", request.Model), provider.GetProviderKey()) - } - apiVersion := key.AzureKeyConfig.APIVersion if apiVersion == nil || apiVersion.GetValue() == "" { apiVersion = schemas.NewEnvVar(AzureAPIVersionImageEditDefault) @@ -1628,17 +1381,10 @@ func (provider *AzureProvider) ImageEditStream(ctx *schemas.BifrostContext, post endpoint := key.AzureKeyConfig.Endpoint.GetValue() if endpoint == "" { - return nil, providerUtils.NewConfigurationError("endpoint not set", provider.GetProviderKey()) + return nil, providerUtils.NewConfigurationError("endpoint not set") } - postResponseConverter := func(resp *schemas.BifrostImageGenerationStreamResponse) *schemas.BifrostImageGenerationStreamResponse { - if resp != nil { - resp.ExtraFields.ModelDeployment = deployment - } - return resp - } - - url := fmt.Sprintf("%s/openai/deployments/%s/images/edits?api-version=%s", endpoint, deployment, apiVersion.GetValue()) + url := fmt.Sprintf("%s/openai/deployments/%s/images/edits?api-version=%s", endpoint, request.Model, apiVersion.GetValue()) authHeader, err := provider.getAzureAuthHeaders(ctx, key, false) if err != nil { @@ -1659,7 +1405,7 @@ func (provider *AzureProvider) ImageEditStream(ctx *schemas.BifrostContext, post postHookRunner, nil, nil, - postResponseConverter, + nil, provider.logger, ) @@ -1673,30 +1419,19 @@ func (provider *AzureProvider) ImageVariation(ctx *schemas.BifrostContext, key s // VideoGeneration creates a video using Azure's OpenAI-compatible Sora API. 
// This delegates to the OpenAI handler with Azure-specific URL and authentication. func (provider *AzureProvider) VideoGeneration(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostVideoGenerationRequest) (*schemas.BifrostVideoGenerationResponse, *schemas.BifrostError) { - if err := provider.validateKeyConfig(key); err != nil { - return nil, err - } - - deployment, bifrostErr := provider.getModelDeployment(key, request.Model) - if bifrostErr != nil { - return nil, bifrostErr - } - endpoint := key.AzureKeyConfig.Endpoint.GetValue() if endpoint == "" { - return nil, providerUtils.NewConfigurationError("endpoint not set", provider.GetProviderKey()) + return nil, providerUtils.NewConfigurationError("endpoint not set") } // Build Azure URL for OpenAI-compatible video generation endpoint url := fmt.Sprintf("%s/openai/v1/videos", endpoint) - requestCopy := *request - requestCopy.Model = deployment response, bifrostErr := openai.HandleOpenAIVideoGenerationRequest( ctx, provider.client, url, - &requestCopy, + request, key, provider.networkConfig.ExtraHeaders, provider.GetProviderKey(), @@ -1708,27 +1443,20 @@ func (provider *AzureProvider) VideoGeneration(ctx *schemas.BifrostContext, key return nil, bifrostErr } - response.ExtraFields.ModelRequested = request.Model - response.ExtraFields.ModelDeployment = deployment - return response, nil } // VideoRetrieve retrieves the status of a video from Azure's OpenAI-compatible API. 
func (provider *AzureProvider) VideoRetrieve(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostVideoRetrieveRequest) (*schemas.BifrostVideoGenerationResponse, *schemas.BifrostError) { - if err := provider.validateKeyConfig(key); err != nil { - return nil, err - } - providerName := provider.GetProviderKey() if request.ID == "" { - return nil, providerUtils.NewBifrostOperationError("video_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("video_id is required", nil) } videoID := providerUtils.StripVideoIDProviderSuffix(request.ID, providerName) endpoint := key.AzureKeyConfig.Endpoint.GetValue() if endpoint == "" { - return nil, providerUtils.NewConfigurationError("endpoint not set", providerName) + return nil, providerUtils.NewConfigurationError("endpoint not set") } authHeaders, bifrostErr := provider.getAzureAuthHeaders(ctx, key, false) @@ -1754,20 +1482,16 @@ func (provider *AzureProvider) VideoRetrieve(ctx *schemas.BifrostContext, key sc // VideoDownload downloads video content from Azure's OpenAI-compatible API. 
func (provider *AzureProvider) VideoDownload(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostVideoDownloadRequest) (*schemas.BifrostVideoDownloadResponse, *schemas.BifrostError) { - if err := provider.validateKeyConfig(key); err != nil { - return nil, err - } - providerName := provider.GetProviderKey() if request.ID == "" { - return nil, providerUtils.NewBifrostOperationError("video_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("video_id is required", nil) } videoID := providerUtils.StripVideoIDProviderSuffix(request.ID, providerName) endpoint := key.AzureKeyConfig.Endpoint.GetValue() if endpoint == "" { - return nil, providerUtils.NewConfigurationError("endpoint not set", providerName) + return nil, providerUtils.NewConfigurationError("endpoint not set") } // Create request @@ -1803,13 +1527,12 @@ func (provider *AzureProvider) VideoDownload(ctx *schemas.BifrostContext, key sc // Handle error response if resp.StatusCode() != fasthttp.StatusOK { - provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - return nil, openai.ParseOpenAIError(resp, schemas.VideoDownloadRequest, providerName, "") + return nil, openai.ParseOpenAIError(resp) } body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } // Get content type from response @@ -1825,9 +1548,7 @@ func (provider *AzureProvider) VideoDownload(ctx *schemas.BifrostContext, key sc Content: append([]byte(nil), body...), ContentType: contentType, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.VideoDownloadRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } @@ -1836,20 +1557,16 @@ func (provider *AzureProvider) VideoDownload(ctx 
*schemas.BifrostContext, key sc // VideoDelete deletes a video from Azure's OpenAI-compatible API. func (provider *AzureProvider) VideoDelete(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostVideoDeleteRequest) (*schemas.BifrostVideoDeleteResponse, *schemas.BifrostError) { - if err := provider.validateKeyConfig(key); err != nil { - return nil, err - } - providerName := provider.GetProviderKey() if request.ID == "" { - return nil, providerUtils.NewBifrostOperationError("video_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("video_id is required", nil) } videoID := providerUtils.StripVideoIDProviderSuffix(request.ID, providerName) endpoint := key.AzureKeyConfig.Endpoint.GetValue() if endpoint == "" { - return nil, providerUtils.NewConfigurationError("endpoint not set", providerName) + return nil, providerUtils.NewConfigurationError("endpoint not set") } // Build Azure URL @@ -1876,13 +1593,9 @@ func (provider *AzureProvider) VideoDelete(ctx *schemas.BifrostContext, key sche // VideoList lists videos from Azure's OpenAI-compatible API. func (provider *AzureProvider) VideoList(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostVideoListRequest) (*schemas.BifrostVideoListResponse, *schemas.BifrostError) { - if err := provider.validateKeyConfig(key); err != nil { - return nil, err - } - endpoint := key.AzureKeyConfig.Endpoint.GetValue() if endpoint == "" { - return nil, providerUtils.NewConfigurationError("endpoint not set", provider.GetProviderKey()) + return nil, providerUtils.NewConfigurationError("endpoint not set") } // Build Azure URL @@ -1912,64 +1625,14 @@ func (provider *AzureProvider) VideoRemix(_ *schemas.BifrostContext, _ schemas.K return nil, providerUtils.NewUnsupportedOperationError(schemas.VideoRemixRequest, provider.GetProviderKey()) } -// validateKeyConfig validates the key configuration. 
-// It checks if the key config is set, the endpoint is set, and the deployments are set. -// Returns an error if any of the checks fail. -func (provider *AzureProvider) validateKeyConfig(key schemas.Key) *schemas.BifrostError { - if key.AzureKeyConfig == nil { - return providerUtils.NewConfigurationError("azure key config not set", provider.GetProviderKey()) - } - - if key.AzureKeyConfig.Endpoint.GetValue() == "" { - return providerUtils.NewConfigurationError("endpoint not set", provider.GetProviderKey()) - } - - if key.AzureKeyConfig.Deployments == nil { - return providerUtils.NewConfigurationError("deployments not set", provider.GetProviderKey()) - } - - return nil -} - -// validateKeyConfigForFiles validates key config for file/batch APIs, which only -// require a configured Azure endpoint (no per-model deployments needed). -func (provider *AzureProvider) validateKeyConfigForFiles(key schemas.Key) *schemas.BifrostError { - if key.AzureKeyConfig == nil { - return providerUtils.NewConfigurationError("azure key config not set", provider.GetProviderKey()) - } - if key.AzureKeyConfig.Endpoint.GetValue() == "" { - return providerUtils.NewConfigurationError("endpoint not set", provider.GetProviderKey()) - } - return nil -} - -func (provider *AzureProvider) getModelDeployment(key schemas.Key, model string) (string, *schemas.BifrostError) { - if key.AzureKeyConfig == nil { - return "", providerUtils.NewConfigurationError("azure key config not set", provider.GetProviderKey()) - } - - if key.AzureKeyConfig.Deployments != nil { - if deployment, ok := key.AzureKeyConfig.Deployments[model]; ok { - return deployment, nil - } - } - return "", providerUtils.NewConfigurationError(fmt.Sprintf("deployment not found for model %s", model), provider.GetProviderKey()) -} - // FileUpload uploads a file to Azure OpenAI. 
func (provider *AzureProvider) FileUpload(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostFileUploadRequest) (*schemas.BifrostFileUploadResponse, *schemas.BifrostError) { - if err := provider.validateKeyConfigForFiles(key); err != nil { - return nil, err - } - - providerName := provider.GetProviderKey() - if len(request.File) == 0 { - return nil, providerUtils.NewBifrostOperationError("file content is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("file content is required", nil) } if request.Purpose == "" { - return nil, providerUtils.NewBifrostOperationError("purpose is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("purpose is required", nil) } // Get API version @@ -1984,7 +1647,7 @@ func (provider *AzureProvider) FileUpload(ctx *schemas.BifrostContext, key schem // Add purpose field if err := writer.WriteField("purpose", string(request.Purpose)); err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to write purpose field", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to write purpose field", err) } // Add file field @@ -1994,14 +1657,14 @@ func (provider *AzureProvider) FileUpload(ctx *schemas.BifrostContext, key schem } part, err := writer.CreateFormFile("file", filename) if err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to create form file", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to create form file", err) } if _, err := part.Write(request.File); err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to write file content", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to write file content", err) } if err := writer.Close(); err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to close multipart writer", err, providerName) + return nil, 
providerUtils.NewBifrostOperationError("failed to close multipart writer", err) } // Create request @@ -2038,13 +1701,12 @@ func (provider *AzureProvider) FileUpload(ctx *schemas.BifrostContext, key schem // Handle error response if resp.StatusCode() != fasthttp.StatusOK && resp.StatusCode() != fasthttp.StatusCreated { - provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - return nil, openai.ParseOpenAIError(resp, schemas.FileUploadRequest, providerName, "") + return nil, openai.ParseOpenAIError(resp) } body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } var openAIResp openai.OpenAIFileResponse @@ -2055,17 +1717,15 @@ func (provider *AzureProvider) FileUpload(ctx *schemas.BifrostContext, key schem return nil, bifrostErr } - return openAIResp.ToBifrostFileUploadResponse(providerName, latency, sendBackRawRequest, sendBackRawResponse, rawRequest, rawResponse), nil + return openAIResp.ToBifrostFileUploadResponse(latency, sendBackRawRequest, sendBackRawResponse, rawRequest, rawResponse), nil } // FileList lists files from all provided Azure keys and aggregates results. // FileList lists files using serial pagination across keys. // Exhausts all pages from one key before moving to the next. 
func (provider *AzureProvider) FileList(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostFileListRequest) (*schemas.BifrostFileListResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - if len(keys) == 0 { - return nil, providerUtils.NewConfigurationError("no Azure keys available for file list operation", providerName) + return nil, providerUtils.NewConfigurationError("no Azure keys available for file list operation") } sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) @@ -2074,7 +1734,7 @@ func (provider *AzureProvider) FileList(ctx *schemas.BifrostContext, keys []sche // Initialize serial pagination helper helper, err := providerUtils.NewSerialListHelper(keys, request.After, provider.logger) if err != nil { - return nil, providerUtils.NewBifrostOperationError("invalid pagination cursor", err, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid pagination cursor", err) } // Get current key to query @@ -2085,18 +1745,9 @@ func (provider *AzureProvider) FileList(ctx *schemas.BifrostContext, keys []sche Object: "list", Data: []schemas.FileObject{}, HasMore: false, - ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileListRequest, - Provider: providerName, - }, }, nil } - // Validate key config - if err := provider.validateKeyConfigForFiles(key); err != nil { - return nil, err - } - // Get API version apiVersion := key.AzureKeyConfig.APIVersion if apiVersion == nil { @@ -2144,13 +1795,12 @@ func (provider *AzureProvider) FileList(ctx *schemas.BifrostContext, keys []sche // Handle error response if resp.StatusCode() != fasthttp.StatusOK { - provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - return nil, openai.ParseOpenAIError(resp, schemas.FileListRequest, providerName, "") + return nil, openai.ParseOpenAIError(resp) } body, decodeErr := providerUtils.CheckAndDecodeBody(resp) if decodeErr != 
nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, decodeErr, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, decodeErr) } var openAIResp openai.OpenAIFileListResponse @@ -2185,9 +1835,7 @@ func (provider *AzureProvider) FileList(ctx *schemas.BifrostContext, keys []sche Data: files, HasMore: hasMore, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileListRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } if nextCursor != "" { @@ -2202,7 +1850,7 @@ func (provider *AzureProvider) FileRetrieve(ctx *schemas.BifrostContext, keys [] providerName := provider.GetProviderKey() if request.FileID == "" { - return nil, providerUtils.NewBifrostOperationError("file_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("file_id is required", nil) } sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) @@ -2210,11 +1858,6 @@ func (provider *AzureProvider) FileRetrieve(ctx *schemas.BifrostContext, keys [] var lastErr *schemas.BifrostError for _, key := range keys { - if err := provider.validateKeyConfigForFiles(key); err != nil { - lastErr = err - continue - } - // Get API version apiVersion := key.AzureKeyConfig.APIVersion if apiVersion == nil { @@ -2257,8 +1900,7 @@ func (provider *AzureProvider) FileRetrieve(ctx *schemas.BifrostContext, keys [] // Handle error response if resp.StatusCode() != fasthttp.StatusOK { - provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - lastErr = openai.ParseOpenAIError(resp, schemas.FileRetrieveRequest, providerName, "") + lastErr = openai.ParseOpenAIError(resp) wait() fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) @@ -2270,7 +1912,7 @@ func (provider *AzureProvider) FileRetrieve(ctx *schemas.BifrostContext, keys [] wait() 
fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) - lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) continue } @@ -2296,14 +1938,12 @@ func (provider *AzureProvider) FileRetrieve(ctx *schemas.BifrostContext, keys [] // FileDelete deletes a file from Azure OpenAI by trying each key until successful. func (provider *AzureProvider) FileDelete(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostFileDeleteRequest) (*schemas.BifrostFileDeleteResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - if request.FileID == "" { - return nil, providerUtils.NewBifrostOperationError("file_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("file_id is required", nil) } if len(keys) == 0 { - return nil, providerUtils.NewConfigurationError("no Azure keys available for file delete operation", providerName) + return nil, providerUtils.NewConfigurationError("no Azure keys available for file delete operation") } sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) @@ -2311,11 +1951,6 @@ func (provider *AzureProvider) FileDelete(ctx *schemas.BifrostContext, keys []sc var lastErr *schemas.BifrostError for _, key := range keys { - if err := provider.validateKeyConfigForFiles(key); err != nil { - lastErr = err - continue - } - // Get API version apiVersion := key.AzureKeyConfig.APIVersion if apiVersion == nil { @@ -2358,8 +1993,7 @@ func (provider *AzureProvider) FileDelete(ctx *schemas.BifrostContext, keys []sc // Handle error response if resp.StatusCode() != fasthttp.StatusOK && resp.StatusCode() != fasthttp.StatusNoContent { - provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - lastErr = openai.ParseOpenAIError(resp, schemas.FileDeleteRequest, providerName, "") + lastErr = 
openai.ParseOpenAIError(resp) wait() fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) @@ -2375,9 +2009,7 @@ func (provider *AzureProvider) FileDelete(ctx *schemas.BifrostContext, keys []sc Object: "file", Deleted: true, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileDeleteRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, }, nil } @@ -2387,7 +2019,7 @@ func (provider *AzureProvider) FileDelete(ctx *schemas.BifrostContext, keys []sc wait() fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) - lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) continue } @@ -2410,9 +2042,7 @@ func (provider *AzureProvider) FileDelete(ctx *schemas.BifrostContext, keys []sc Object: openAIResp.Object, Deleted: openAIResp.Deleted, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileDeleteRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } @@ -2432,24 +2062,17 @@ func (provider *AzureProvider) FileDelete(ctx *schemas.BifrostContext, keys []sc // FileContent downloads file content from Azure OpenAI by trying each key until found. 
func (provider *AzureProvider) FileContent(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostFileContentRequest) (*schemas.BifrostFileContentResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - if request.FileID == "" { - return nil, providerUtils.NewBifrostOperationError("file_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("file_id is required", nil) } if len(keys) == 0 { - return nil, providerUtils.NewConfigurationError("no Azure keys available for file content operation", providerName) + return nil, providerUtils.NewConfigurationError("no Azure keys available for file content operation") } var lastErr *schemas.BifrostError for _, key := range keys { - if err := provider.validateKeyConfigForFiles(key); err != nil { - lastErr = err - continue - } - // Get API version apiVersion := key.AzureKeyConfig.APIVersion if apiVersion == nil { @@ -2491,8 +2114,7 @@ func (provider *AzureProvider) FileContent(ctx *schemas.BifrostContext, keys []s // Handle error response if resp.StatusCode() != fasthttp.StatusOK { - provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - lastErr = openai.ParseOpenAIError(resp, schemas.FileContentRequest, providerName, "") + lastErr = openai.ParseOpenAIError(resp) wait() fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) @@ -2504,7 +2126,7 @@ func (provider *AzureProvider) FileContent(ctx *schemas.BifrostContext, keys []s wait() fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) - lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) continue } @@ -2524,9 +2146,7 @@ func (provider *AzureProvider) FileContent(ctx *schemas.BifrostContext, keys []s Content: content, ContentType: contentType, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: 
schemas.FileContentRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, }, nil } @@ -2537,12 +2157,6 @@ func (provider *AzureProvider) FileContent(ctx *schemas.BifrostContext, keys []s // BatchCreate creates a new batch job on Azure OpenAI. // Azure Batch API uses the same format as OpenAI but with Azure-specific URL patterns. func (provider *AzureProvider) BatchCreate(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostBatchCreateRequest) (*schemas.BifrostBatchCreateResponse, *schemas.BifrostError) { - if err := provider.validateKeyConfigForFiles(key); err != nil { - return nil, err - } - - providerName := provider.GetProviderKey() - inputFileID := request.InputFileID // If no file_id provided but inline requests are available, upload them first @@ -2550,12 +2164,11 @@ func (provider *AzureProvider) BatchCreate(ctx *schemas.BifrostContext, key sche // Convert inline requests to JSONL format jsonlData, err := openai.ConvertRequestsToJSONL(request.Requests) if err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to convert requests to JSONL", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to convert requests to JSONL", err) } // Upload the file with purpose "batch" uploadResp, bifrostErr := provider.FileUpload(ctx, key, &schemas.BifrostFileUploadRequest{ - Provider: schemas.Azure, File: jsonlData, Filename: "batch_requests.jsonl", Purpose: "batch", @@ -2569,7 +2182,7 @@ func (provider *AzureProvider) BatchCreate(ctx *schemas.BifrostContext, key sche // Validate that we have a file ID (either provided or uploaded) if inputFileID == "" { - return nil, providerUtils.NewBifrostOperationError("either input_file_id or requests array is required for Azure batch API", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("either input_file_id or requests array is required for Azure batch API", nil) } // Get API version @@ -2616,7 
+2229,7 @@ func (provider *AzureProvider) BatchCreate(ctx *schemas.BifrostContext, key sche jsonData, err := providerUtils.MarshalSorted(openAIReq) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } req.SetBody(jsonData) @@ -2629,13 +2242,12 @@ func (provider *AzureProvider) BatchCreate(ctx *schemas.BifrostContext, key sche // Handle error response if resp.StatusCode() != fasthttp.StatusOK && resp.StatusCode() != fasthttp.StatusCreated { - provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - return nil, openai.ParseOpenAIError(resp, schemas.BatchCreateRequest, providerName, "") + return nil, openai.ParseOpenAIError(resp) } body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } var openAIResp openai.OpenAIBatchResponse @@ -2646,25 +2258,24 @@ func (provider *AzureProvider) BatchCreate(ctx *schemas.BifrostContext, key sche return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonData, body, provider.sendBackRawRequest, provider.sendBackRawResponse) } - return openAIResp.ToBifrostBatchCreateResponse(providerName, latency, sendBackRawRequest, sendBackRawResponse, rawRequest, rawResponse), nil + return openAIResp.ToBifrostBatchCreateResponse(latency, sendBackRawRequest, sendBackRawResponse, rawRequest, rawResponse), nil } // BatchList lists batch jobs from all provided Azure keys and aggregates results. 
// BatchList lists batch jobs using serial pagination across keys. // Exhausts all pages from one key before moving to the next. func (provider *AzureProvider) BatchList(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostBatchListRequest) (*schemas.BifrostBatchListResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) sendBackRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) if len(keys) == 0 { - return nil, providerUtils.NewConfigurationError("no Azure keys available for batch list operation", providerName) + return nil, providerUtils.NewConfigurationError("no Azure keys available for batch list operation") } // Initialize serial pagination helper helper, err := providerUtils.NewSerialListHelper(keys, request.After, provider.logger) if err != nil { - return nil, providerUtils.NewBifrostOperationError("invalid pagination cursor", err, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid pagination cursor", err) } // Get current key to query @@ -2675,18 +2286,9 @@ func (provider *AzureProvider) BatchList(ctx *schemas.BifrostContext, keys []sch Object: "list", Data: []schemas.BifrostBatchRetrieveResponse{}, HasMore: false, - ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.BatchListRequest, - Provider: providerName, - }, }, nil } - // Validate key config - if err := provider.validateKeyConfigForFiles(key); err != nil { - return nil, err - } - // Get API version apiVersion := key.AzureKeyConfig.APIVersion if apiVersion == nil { @@ -2732,13 +2334,12 @@ func (provider *AzureProvider) BatchList(ctx *schemas.BifrostContext, keys []sch // Handle error response if resp.StatusCode() != fasthttp.StatusOK { - provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - return nil, openai.ParseOpenAIError(resp, 
schemas.BatchListRequest, providerName, "") + return nil, openai.ParseOpenAIError(resp) } body, decodeErr := providerUtils.CheckAndDecodeBody(resp) if decodeErr != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, decodeErr, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, decodeErr) } var openAIResp openai.OpenAIBatchListResponse @@ -2751,7 +2352,7 @@ func (provider *AzureProvider) BatchList(ctx *schemas.BifrostContext, keys []sch batches := make([]schemas.BifrostBatchRetrieveResponse, 0, len(openAIResp.Data)) var lastBatchID string for _, batch := range openAIResp.Data { - batches = append(batches, *batch.ToBifrostBatchRetrieveResponse(providerName, latency, sendBackRawRequest, sendBackRawResponse, rawRequest, rawResponse)) + batches = append(batches, *batch.ToBifrostBatchRetrieveResponse(latency, sendBackRawRequest, sendBackRawResponse, rawRequest, rawResponse)) lastBatchID = batch.ID } @@ -2764,9 +2365,7 @@ func (provider *AzureProvider) BatchList(ctx *schemas.BifrostContext, keys []sch Data: batches, HasMore: hasMore, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.BatchListRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } if nextCursor != "" { @@ -2778,14 +2377,12 @@ func (provider *AzureProvider) BatchList(ctx *schemas.BifrostContext, keys []sch // BatchRetrieve retrieves a specific batch job from Azure OpenAI by trying each key until found. 
func (provider *AzureProvider) BatchRetrieve(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostBatchRetrieveRequest) (*schemas.BifrostBatchRetrieveResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - if request.BatchID == "" { - return nil, providerUtils.NewBifrostOperationError("batch_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("batch_id is required", nil) } if len(keys) == 0 { - return nil, providerUtils.NewConfigurationError("no Azure keys available for batch retrieve operation", providerName) + return nil, providerUtils.NewConfigurationError("no Azure keys available for batch retrieve operation") } sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) @@ -2793,11 +2390,6 @@ func (provider *AzureProvider) BatchRetrieve(ctx *schemas.BifrostContext, keys [ var lastErr *schemas.BifrostError for _, key := range keys { - if err := provider.validateKeyConfigForFiles(key); err != nil { - lastErr = err - continue - } - // Get API version apiVersion := key.AzureKeyConfig.APIVersion if apiVersion == nil { @@ -2840,8 +2432,7 @@ func (provider *AzureProvider) BatchRetrieve(ctx *schemas.BifrostContext, keys [ // Handle error response if resp.StatusCode() != fasthttp.StatusOK { - provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - lastErr = openai.ParseOpenAIError(resp, schemas.BatchRetrieveRequest, providerName, "") + lastErr = openai.ParseOpenAIError(resp) wait() fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) @@ -2853,7 +2444,7 @@ func (provider *AzureProvider) BatchRetrieve(ctx *schemas.BifrostContext, keys [ wait() fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) - lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) continue } @@ -2871,8 
+2462,7 @@ func (provider *AzureProvider) BatchRetrieve(ctx *schemas.BifrostContext, keys [ fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) - result := openAIResp.ToBifrostBatchRetrieveResponse(providerName, latency, sendBackRawRequest, sendBackRawResponse, rawRequest, rawResponse) - result.ExtraFields.RequestType = schemas.BatchRetrieveRequest + result := openAIResp.ToBifrostBatchRetrieveResponse(latency, sendBackRawRequest, sendBackRawResponse, rawRequest, rawResponse) return result, nil } @@ -2881,14 +2471,12 @@ func (provider *AzureProvider) BatchRetrieve(ctx *schemas.BifrostContext, keys [ // BatchCancel cancels a batch job on Azure OpenAI by trying each key until successful. func (provider *AzureProvider) BatchCancel(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostBatchCancelRequest) (*schemas.BifrostBatchCancelResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - if request.BatchID == "" { - return nil, providerUtils.NewBifrostOperationError("batch_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("batch_id is required", nil) } if len(keys) == 0 { - return nil, providerUtils.NewConfigurationError("no Azure keys available for batch cancel operation", providerName) + return nil, providerUtils.NewConfigurationError("no Azure keys available for batch cancel operation") } sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) @@ -2896,11 +2484,6 @@ func (provider *AzureProvider) BatchCancel(ctx *schemas.BifrostContext, keys []s var lastErr *schemas.BifrostError for _, key := range keys { - if err := provider.validateKeyConfigForFiles(key); err != nil { - lastErr = err - continue - } - // Get API version apiVersion := key.AzureKeyConfig.APIVersion if apiVersion == nil { @@ -2943,8 +2526,7 @@ func (provider *AzureProvider) BatchCancel(ctx *schemas.BifrostContext, keys []s // Handle error response if resp.StatusCode() != 
fasthttp.StatusOK { - provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - lastErr = openai.ParseOpenAIError(resp, schemas.BatchCancelRequest, providerName, "") + lastErr = openai.ParseOpenAIError(resp) wait() fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) @@ -2956,7 +2538,7 @@ func (provider *AzureProvider) BatchCancel(ctx *schemas.BifrostContext, keys []s wait() fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) - lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) continue } @@ -2981,9 +2563,7 @@ func (provider *AzureProvider) BatchCancel(ctx *schemas.BifrostContext, keys []s CancellingAt: openAIResp.CancellingAt, CancelledAt: openAIResp.CancelledAt, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.BatchCancelRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } @@ -3017,8 +2597,6 @@ func (provider *AzureProvider) BatchDelete(ctx *schemas.BifrostContext, keys []s // BatchResults retrieves batch results from Azure OpenAI by trying each key until successful. // For Azure (like OpenAI), batch results are obtained by downloading the output_file_id. 
func (provider *AzureProvider) BatchResults(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostBatchResultsRequest) (*schemas.BifrostBatchResultsResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - // First, retrieve the batch to get the output_file_id (using all keys) batchResp, bifrostErr := provider.BatchRetrieve(ctx, keys, &schemas.BifrostBatchRetrieveRequest{ Provider: request.Provider, @@ -3029,7 +2607,7 @@ func (provider *AzureProvider) BatchResults(ctx *schemas.BifrostContext, keys [] } if batchResp.OutputFileID == nil || *batchResp.OutputFileID == "" { - return nil, providerUtils.NewBifrostOperationError("batch results not available: output_file_id is empty (batch may not be completed)", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("batch results not available: output_file_id is empty (batch may not be completed)", nil) } // Download the output file content (using all keys) @@ -3058,9 +2636,7 @@ func (provider *AzureProvider) BatchResults(ctx *schemas.BifrostContext, keys [] BatchID: request.BatchID, Results: results, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.BatchResultsRequest, - Provider: providerName, - Latency: fileContentResp.ExtraFields.Latency, + Latency: fileContentResp.ExtraFields.Latency, }, } diff --git a/core/providers/azure/files.go b/core/providers/azure/files.go index 4c7ce174f8..d008b146de 100644 --- a/core/providers/azure/files.go +++ b/core/providers/azure/files.go @@ -24,7 +24,7 @@ func (provider *AzureProvider) setAzureAuth(ctx context.Context, req *fasthttp.R key.AzureKeyConfig.ClientSecret != nil && key.AzureKeyConfig.TenantID != nil && key.AzureKeyConfig.ClientID.GetValue() != "" && key.AzureKeyConfig.ClientSecret.GetValue() != "" && key.AzureKeyConfig.TenantID.GetValue() != "" { cred, err := provider.getOrCreateAuth(key.AzureKeyConfig.TenantID.GetValue(), key.AzureKeyConfig.ClientID.GetValue(), 
key.AzureKeyConfig.ClientSecret.GetValue()) if err != nil { - return providerUtils.NewBifrostOperationError("failed to get or create Azure authentication", err, schemas.Azure) + return providerUtils.NewBifrostOperationError("failed to get or create Azure authentication", err) } scopes := getAzureScopes(key.AzureKeyConfig.Scopes) @@ -33,11 +33,11 @@ func (provider *AzureProvider) setAzureAuth(ctx context.Context, req *fasthttp.R Scopes: scopes, }) if err != nil { - return providerUtils.NewBifrostOperationError("failed to get Azure access token", err, schemas.Azure) + return providerUtils.NewBifrostOperationError("failed to get Azure access token", err) } if token.Token == "" { - return providerUtils.NewBifrostOperationError("Azure access token is empty", fmt.Errorf("token is empty"), schemas.Azure) + return providerUtils.NewBifrostOperationError("azure access token is empty", fmt.Errorf("token is empty")) } req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.Token)) @@ -68,16 +68,16 @@ func (provider *AzureProvider) setAzureAuth(ctx context.Context, req *fasthttp.R cred, err := provider.getOrCreateDefaultAzureCredential() if err != nil { - return providerUtils.NewBifrostOperationError("no credentials provided and DefaultAzureCredential unavailable", err, schemas.Azure) + return providerUtils.NewBifrostOperationError("no credentials provided and DefaultAzureCredential unavailable", err) } token, err := cred.GetToken(ctx, policy.TokenRequestOptions{Scopes: scopes}) if err != nil { - return providerUtils.NewBifrostOperationError("no credentials provided and DefaultAzureCredential failed to get token", err, schemas.Azure) + return providerUtils.NewBifrostOperationError("no credentials provided and DefaultAzureCredential failed to get token", err) } if token.Token == "" { - return providerUtils.NewBifrostOperationError("no credentials provided and DefaultAzureCredential returned empty token", fmt.Errorf("token is empty"), schemas.Azure) + return 
providerUtils.NewBifrostOperationError("no credentials provided and DefaultAzureCredential returned empty token", fmt.Errorf("token is empty")) } req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.Token)) @@ -110,9 +110,7 @@ func (r *AzureFileResponse) ToBifrostFileUploadResponse(providerName schemas.Mod StatusDetails: r.StatusDetails, StorageBackend: schemas.FileStorageAPI, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileUploadRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } diff --git a/core/providers/azure/models.go b/core/providers/azure/models.go index 0875cf781b..5daca3836d 100644 --- a/core/providers/azure/models.go +++ b/core/providers/azure/models.go @@ -3,79 +3,11 @@ package azure import ( "strings" + providerUtils "github.com/maximhq/bifrost/core/providers/utils" "github.com/maximhq/bifrost/core/schemas" ) -// findMatchingAllowedModel finds a matching item in a whitelist, considering both -// exact match and base model matches (ignoring version suffixes). -// Returns the matched item from the whitelist if found, empty string otherwise. -// If matched via base model, returns the item from whitelist (not the value parameter). -func findMatchingAllowedModel(wl schemas.WhiteList, value string) string { - // First check exact match (case-insensitive) - if wl.Contains(value) { - return value - } - - // Additional layer: check base model matches (ignoring version suffixes) - // This handles cases where model versions differ but base model is the same - // Return the item from whitelist (not value) to use the actual name from allowedModels - for _, item := range wl { - if schemas.SameBaseModel(item, value) { - return item - } - } - return "" -} - -// findDeploymentMatch finds a matching deployment value in the deployments map, -// considering both exact match and base model matches (ignoring version suffixes). 
-// Returns the deployment value and alias if found, empty strings otherwise. -func findDeploymentMatch(deployments map[string]string, modelID string) (deploymentValue, alias string) { - // Check exact match first (by alias/key) - if deployment, ok := deployments[modelID]; ok { - return deployment, modelID - } - - // Check exact match by deployment value - for aliasKey, depValue := range deployments { - if depValue == modelID { - return depValue, aliasKey - } - } - - // Additional layer: check base model matches (ignoring version suffixes) - // This handles cases where model versions differ but base model is the same - for aliasKey, deploymentValue := range deployments { - // Check if modelID's base matches deploymentValue's base - if schemas.SameBaseModel(deploymentValue, modelID) { - return deploymentValue, aliasKey - } - // Also check if modelID's base matches alias's base (for cases where alias is used as deployment) - if schemas.SameBaseModel(aliasKey, modelID) { - return deploymentValue, aliasKey - } - } - return "", "" -} - -// matchesBlacklist reports whether modelID matches any entry in the blacklist, -// using the same matching logic as findMatchingAllowedModel (exact and base-model). 
-func matchesBlacklist(bl schemas.BlackList, modelID string) bool { - if bl.IsEmpty() { - return false - } - if bl.Contains(modelID) { - return true - } - for _, item := range bl { - if schemas.SameBaseModel(item, modelID) { - return true - } - } - return false -} - -func (response *AzureListModelsResponse) ToBifrostListModelsResponse(allowedModels schemas.WhiteList, blacklistedModels schemas.BlackList, deployments map[string]string, unfiltered bool) *schemas.BifrostListModelsResponse { +func (response *AzureListModelsResponse) ToBifrostListModelsResponse(allowedModels schemas.WhiteList, blacklistedModels schemas.BlackList, aliases map[string]string, unfiltered bool) *schemas.BifrostListModelsResponse { if response == nil { return nil } @@ -84,117 +16,36 @@ func (response *AzureListModelsResponse) ToBifrostListModelsResponse(allowedMode Data: make([]schemas.Model, 0, len(response.Data)), } - if !unfiltered && (allowedModels.IsEmpty() && len(deployments) == 0 || blacklistedModels.IsBlockAll()) { + pipeline := &providerUtils.ListModelsPipeline{ + AllowedModels: allowedModels, + BlacklistedModels: blacklistedModels, + Aliases: aliases, + Unfiltered: unfiltered, + ProviderKey: schemas.Azure, + MatchFns: providerUtils.DefaultMatchFns(), + } + if pipeline.ShouldEarlyExit() { return bifrostResponse } - restrictAllowed := !unfiltered && allowedModels.IsRestricted() + included := make(map[string]bool) - includedModels := make(map[string]bool) for _, model := range response.Data { - modelID := model.ID - matchedAllowedModel := "" - deploymentValue := "" - deploymentAlias := "" - - // Filter if model is not present in both lists (when both are non-empty) - // Empty lists mean "allow all" for that dimension - // Check considering base model matches (ignoring version suffixes) - shouldFilter := false - if restrictAllowed && len(deployments) > 0 { - // Both lists are present: model must be in allowedModels AND deployments - // AND the deployment alias must also be in 
allowedModels - matchedAllowedModel = findMatchingAllowedModel(allowedModels, model.ID) - deploymentValue, deploymentAlias = findDeploymentMatch(deployments, model.ID) - inDeployments := deploymentAlias != "" - - // Check if deployment alias is also in allowedModels (direct string match) - deploymentAliasInAllowedModels := false - if deploymentAlias != "" { - deploymentAliasInAllowedModels = allowedModels.Contains(deploymentAlias) - } - - // Filter if: model not in deployments OR deployment alias not in allowedModels - shouldFilter = !inDeployments || !deploymentAliasInAllowedModels - } else if restrictAllowed { - // Only allowedModels is present: filter if model is not in allowedModels - matchedAllowedModel = findMatchingAllowedModel(allowedModels, model.ID) - shouldFilter = matchedAllowedModel == "" - } else if !unfiltered && len(deployments) > 0 { - // Only deployments is present: filter if model is not in deployments - deploymentValue, deploymentAlias = findDeploymentMatch(deployments, model.ID) - shouldFilter = deploymentValue == "" - } - // If both are empty (or allowedModels is unrestricted and no deployments), shouldFilter remains false - - if shouldFilter { - continue - } - if !unfiltered && (matchesBlacklist(blacklistedModels, model.ID) || - (deploymentAlias != "" && matchesBlacklist(blacklistedModels, deploymentAlias))) { - continue - } - - // Use the matched name from allowedModels or deployments (like Anthropic) - // Priority: deployment value > matched allowedModel > original model.ID - if deploymentValue != "" { - modelID = deploymentValue - } else if matchedAllowedModel != "" { - modelID = matchedAllowedModel - } - - modelEntry := schemas.Model{ - ID: string(schemas.Azure) + "/" + modelID, - Created: schemas.Ptr(model.CreatedAt), - } - // Set deployment info if matched via deployments - if deploymentValue != "" && deploymentAlias != "" { - modelEntry.ID = string(schemas.Azure) + "/" + deploymentAlias - modelEntry.Deployment = 
schemas.Ptr(deploymentValue) - includedModels[strings.ToLower(deploymentAlias)] = true - } else { - includedModels[strings.ToLower(modelID)] = true - } - - bifrostResponse.Data = append(bifrostResponse.Data, modelEntry) - } - - // Backfill deployments that were not matched from the API response - if !unfiltered && len(deployments) > 0 { - for alias, deploymentValue := range deployments { - if includedModels[strings.ToLower(alias)] { - continue + for _, result := range pipeline.FilterModel(model.ID) { + entry := schemas.Model{ + ID: string(schemas.Azure) + "/" + result.ResolvedID, + Created: schemas.Ptr(model.CreatedAt), } - // If allowedModels is restricted, only include if alias is in the list - if restrictAllowed && !allowedModels.Contains(alias) { - continue + if result.AliasValue != "" { + entry.Alias = schemas.Ptr(result.AliasValue) } - if !unfiltered && matchesBlacklist(blacklistedModels, alias) { - continue - } - bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ - ID: string(schemas.Azure) + "/" + alias, - Name: schemas.Ptr(alias), - Deployment: schemas.Ptr(deploymentValue), - }) - includedModels[strings.ToLower(alias)] = true + bifrostResponse.Data = append(bifrostResponse.Data, entry) + included[strings.ToLower(result.ResolvedID)] = true } } - // Backfill allowed models that were not in the response - if restrictAllowed { - for _, allowedModel := range allowedModels { - if matchesBlacklist(blacklistedModels, allowedModel) { - continue - } - if !includedModels[strings.ToLower(allowedModel)] { - bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ - ID: string(schemas.Azure) + "/" + allowedModel, - Name: schemas.Ptr(allowedModel), - }) - } - } - } + bifrostResponse.Data = append(bifrostResponse.Data, + pipeline.BackfillModels(included)...) 
return bifrostResponse } diff --git a/core/providers/azure/utils.go b/core/providers/azure/utils.go index 49d1db8de3..20f216c19e 100644 --- a/core/providers/azure/utils.go +++ b/core/providers/azure/utils.go @@ -9,7 +9,7 @@ import ( "github.com/maximhq/bifrost/core/schemas" ) -func getRequestBodyForAnthropicResponses(ctx *schemas.BifrostContext, request *schemas.BifrostResponsesRequest, deployment string, providerName schemas.ModelProvider, isStreaming bool) ([]byte, *schemas.BifrostError) { +func getRequestBodyForAnthropicResponses(ctx *schemas.BifrostContext, request *schemas.BifrostResponsesRequest, deployment string, isStreaming bool) ([]byte, *schemas.BifrostError) { // Large payload mode: body streams directly from the LP reader — skip all body building // (matches CheckContextAndGetRequestBody guard). if providerUtils.IsLargePayloadPassthroughEnabled(ctx) { @@ -27,24 +27,24 @@ func getRequestBodyForAnthropicResponses(ctx *schemas.BifrostContext, request *s if !providerUtils.JSONFieldExists(jsonBody, "max_tokens") { jsonBody, err = providerUtils.SetJSONField(jsonBody, "max_tokens", providerUtils.GetMaxOutputTokensOrDefault(deployment, anthropic.AnthropicDefaultMaxTokens)) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } } // Replace model with deployment jsonBody, err = providerUtils.SetJSONField(jsonBody, "model", deployment) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } // Delete fallbacks field jsonBody, err = providerUtils.DeleteJSONField(jsonBody, "fallbacks") if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, 
providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } // Add stream if streaming if isStreaming { jsonBody, err = providerUtils.SetJSONField(jsonBody, "stream", true) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } } } else { @@ -52,10 +52,10 @@ func getRequestBodyForAnthropicResponses(ctx *schemas.BifrostContext, request *s request.Model = deployment reqBody, convErr := anthropic.ToAnthropicResponsesRequest(ctx, request) if convErr != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrRequestBodyConversion, convErr, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrRequestBodyConversion, convErr) } if reqBody == nil { - return nil, providerUtils.NewBifrostOperationError("request body is not provided", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("request body is not provided", nil) } if isStreaming { @@ -68,7 +68,7 @@ func getRequestBodyForAnthropicResponses(ctx *schemas.BifrostContext, request *s // Marshal struct to JSON bytes, preserving field order jsonBody, err = providerUtils.MarshalSorted(reqBody) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, fmt.Errorf("failed to marshal request body: %w", err), providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, fmt.Errorf("failed to marshal request body: %w", err)) } } diff --git a/core/providers/bedrock/bedrock.go b/core/providers/bedrock/bedrock.go index bf1a567ad7..147704ff7c 100644 --- a/core/providers/bedrock/bedrock.go +++ b/core/providers/bedrock/bedrock.go @@ -136,22 +136,6 @@ func (provider *BedrockProvider) GetProviderKey() schemas.ModelProvider { return providerUtils.GetProviderName(schemas.Bedrock, provider.customProviderConfig) } -// 
ensureBedrockKeyConfig ensures key.BedrockKeyConfig is non-nil. When the key -// uses API key authentication (key.Value is set) but has no Bedrock-specific -// config, a minimal default is created so the request URL can be constructed -// (region defaults to us-east-1). Returns false only when there is truly no -// way to authenticate (no API key AND no bedrock config). -func ensureBedrockKeyConfig(key *schemas.Key) bool { - if key.BedrockKeyConfig != nil { - return true - } - if key.Value.GetValue() != "" { - key.BedrockKeyConfig = &schemas.BedrockKeyConfig{} - return true - } - return false -} - // completeRequest sends a request to Bedrock's API and handles the response. // It constructs the API URL, sets up AWS authentication, and processes the response. // Returns the response body, request latency, or an error if the request fails. @@ -189,7 +173,7 @@ func (provider *BedrockProvider) completeRequest(ctx *schemas.BifrostContext, js req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", key.Value.GetValue())) } else { // Sign the request using either explicit credentials or IAM role authentication - if err := signAWSRequest(ctx, req, config.AccessKey, config.SecretKey, config.SessionToken, config.RoleARN, config.ExternalID, config.RoleSessionName, region, "bedrock", provider.GetProviderKey()); err != nil { + if err := signAWSRequest(ctx, req, config.AccessKey, config.SecretKey, config.SessionToken, config.RoleARN, config.ExternalID, config.RoleSessionName, region, "bedrock"); err != nil { return nil, 0, nil, err } } @@ -212,10 +196,10 @@ func (provider *BedrockProvider) completeRequest(ctx *schemas.BifrostContext, js // Check for timeout first using net.Error before checking net.OpError var netErr net.Error if errors.As(err, &netErr) && netErr.Timeout() { - return nil, latency, nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, provider.GetProviderKey()) + return nil, latency, nil, 
providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err) } if errors.Is(err, http.ErrHandlerTimeout) || errors.Is(err, context.DeadlineExceeded) { - return nil, latency, nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, provider.GetProviderKey()) + return nil, latency, nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err) } // Check for DNS lookup and network errors after timeout checks var opErr *net.OpError @@ -316,7 +300,7 @@ func (provider *BedrockProvider) completeAgentRuntimeRequest(ctx *schemas.Bifros if key.Value.GetValue() != "" { req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", key.Value.GetValue())) } else { - if err := signAWSRequest(ctx, req, config.AccessKey, config.SecretKey, config.SessionToken, config.RoleARN, config.ExternalID, config.RoleSessionName, region, "bedrock-agent-runtime", provider.GetProviderKey()); err != nil { + if err := signAWSRequest(ctx, req, config.AccessKey, config.SecretKey, config.SessionToken, config.RoleARN, config.ExternalID, config.RoleSessionName, region, "bedrock-agent-runtime"); err != nil { return nil, 0, nil, err } } @@ -337,10 +321,10 @@ func (provider *BedrockProvider) completeAgentRuntimeRequest(ctx *schemas.Bifros } var netErr net.Error if errors.As(err, &netErr) && netErr.Timeout() { - return nil, latency, nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, provider.GetProviderKey()) + return nil, latency, nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err) } if errors.Is(err, http.ErrHandlerTimeout) || errors.Is(err, context.DeadlineExceeded) { - return nil, latency, nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, provider.GetProviderKey()) + return nil, latency, nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err) } var opErr *net.OpError var dnsErr *net.DNSError @@ -387,15 +371,9 @@ func 
(provider *BedrockProvider) completeAgentRuntimeRequest(ctx *schemas.Bifros // makeStreamingRequest creates a streaming request to Bedrock's API. // It formats the request, sends it to Bedrock, and returns the response. // Returns the response body and an error if the request fails. -func (provider *BedrockProvider) makeStreamingRequest(ctx *schemas.BifrostContext, jsonData []byte, key schemas.Key, model string, action string) (*http.Response, string, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - - if !ensureBedrockKeyConfig(&key) { - return nil, "", providerUtils.NewConfigurationError("bedrock key config is not provided", providerName) - } - +func (provider *BedrockProvider) makeStreamingRequest(ctx *schemas.BifrostContext, jsonData []byte, key schemas.Key, model string, action string) (*http.Response, *schemas.BifrostError) { // Format the path with proper model identifier for streaming - path, deployment := provider.getModelPath(action, model, key) + path := provider.getModelPath(action, model, key) region := DefaultBedrockRegion if key.BedrockKeyConfig.Region != nil && key.BedrockKeyConfig.Region.GetValue() != "" { @@ -405,7 +383,7 @@ func (provider *BedrockProvider) makeStreamingRequest(ctx *schemas.BifrostContex // Create HTTP request for streaming req, reqErr := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("https://bedrock-runtime.%s.amazonaws.com/model/%s", region, path), bytes.NewReader(jsonData)) if reqErr != nil { - return nil, deployment, providerUtils.NewBifrostOperationError("error creating request", reqErr, providerName) + return nil, providerUtils.NewBifrostOperationError("error creating request", reqErr) } // Set any extra headers from network config @@ -424,8 +402,8 @@ func (provider *BedrockProvider) makeStreamingRequest(ctx *schemas.BifrostContex } else { req.Header.Set("Accept", "application/vnd.amazon.eventstream") // Sign the request using either explicit credentials or IAM role authentication - if err 
:= signAWSRequest(ctx, req, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, key.BedrockKeyConfig.SessionToken, key.BedrockKeyConfig.RoleARN, key.BedrockKeyConfig.ExternalID, key.BedrockKeyConfig.RoleSessionName, region, "bedrock", providerName); err != nil { - return nil, deployment, err + if err := signAWSRequest(ctx, req, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, key.BedrockKeyConfig.SessionToken, key.BedrockKeyConfig.RoleARN, key.BedrockKeyConfig.ExternalID, key.BedrockKeyConfig.RoleSessionName, region, "bedrock"); err != nil { + return nil, err } } @@ -433,7 +411,7 @@ func (provider *BedrockProvider) makeStreamingRequest(ctx *schemas.BifrostContex resp, respErr := provider.client.Do(req) if respErr != nil { if errors.Is(respErr, context.Canceled) { - return nil, deployment, &schemas.BifrostError{ + return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Type: schemas.Ptr(schemas.RequestCancelled), @@ -445,18 +423,18 @@ func (provider *BedrockProvider) makeStreamingRequest(ctx *schemas.BifrostContex // Check for timeout first using net.Error before checking net.OpError var netErr net.Error if errors.As(respErr, &netErr) && netErr.Timeout() { - return nil, deployment, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, respErr, providerName) + return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, respErr) } if errors.Is(respErr, http.ErrHandlerTimeout) || errors.Is(respErr, context.DeadlineExceeded) { - return nil, deployment, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, respErr, providerName) + return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, respErr) } // Check for DNS lookup and network errors after timeout checks var opErr *net.OpError var dnsErr *net.DNSError if errors.As(respErr, &opErr) || errors.As(respErr, &dnsErr) { - return nil, deployment, 
providerUtils.NewBifrostOperationError(schemas.ErrProviderNetworkError, respErr, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderNetworkError, respErr) } - return nil, deployment, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, respErr, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, respErr) } // Extract provider response headers before status check so error responses also forward them @@ -466,10 +444,10 @@ func (provider *BedrockProvider) makeStreamingRequest(ctx *schemas.BifrostContex if resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(resp.Body) resp.Body.Close() - return nil, deployment, parseBedrockHTTPError(resp.StatusCode, resp.Header, body) + return nil, parseBedrockHTTPError(resp.StatusCode, resp.Header, body) } - return resp, deployment, nil + return resp, nil } // signAWSRequest signs an HTTP request using AWS Signature Version 4. @@ -486,7 +464,6 @@ func signAWSRequest( externalID *schemas.EnvVar, sessionName *schemas.EnvVar, region, service string, - providerName schemas.ModelProvider, ) *schemas.BifrostError { // Set required headers before signing (only if not already set) if req.Header.Get("Content-Type") == "" { @@ -501,7 +478,7 @@ func signAWSRequest( if req.Body != nil { bodyBytes, err := io.ReadAll(req.Body) if err != nil { - return providerUtils.NewBifrostOperationError("error reading request body", err, providerName) + return providerUtils.NewBifrostOperationError("error reading request body", err) } // Restore the body for subsequent reads req.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) @@ -543,11 +520,10 @@ func signAWSRequest( ) } if err != nil { - return providerUtils.NewBifrostOperationError("failed to load aws config", err, providerName) + return providerUtils.NewBifrostOperationError("failed to load aws config", err) } if roleARN != nil && roleARN.GetValue() != "" { - extID := "" if externalID != nil { extID = 
externalID.GetValue() @@ -602,12 +578,12 @@ func signAWSRequest( // Get credentials creds, err := cfg.Credentials.Retrieve(ctx) if err != nil { - return providerUtils.NewBifrostOperationError("failed to retrieve aws credentials", err, providerName) + return providerUtils.NewBifrostOperationError("failed to retrieve aws credentials", err) } // Sign the request with AWS Signature V4 if err := signer.SignHTTP(ctx, creds, req, bodyHash, service, region, time.Now()); err != nil { - return providerUtils.NewBifrostOperationError("failed to sign request", err, providerName) + return providerUtils.NewBifrostOperationError("failed to sign request", err) } return nil @@ -617,13 +593,7 @@ func signAWSRequest( // It retrieves all foundation models available in Amazon Bedrock for a specific key. func (provider *BedrockProvider) listModelsByKey(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { providerName := provider.GetProviderKey() - - if !ensureBedrockKeyConfig(&key) { - return nil, providerUtils.NewConfigurationError("bedrock key config is not provided", providerName) - } - config := key.BedrockKeyConfig - region := DefaultBedrockRegion if config.Region != nil && config.Region.GetValue() != "" { region = config.Region.GetValue() @@ -670,7 +640,7 @@ func (provider *BedrockProvider) listModelsByKey(ctx *schemas.BifrostContext, ke } else { // Sign the request using either explicit credentials or IAM role authentication - if err := signAWSRequest(ctx, req, config.AccessKey, config.SecretKey, config.SessionToken, config.RoleARN, config.ExternalID, config.RoleSessionName, region, "bedrock", providerName); err != nil { + if err := signAWSRequest(ctx, req, config.AccessKey, config.SecretKey, config.SessionToken, config.RoleARN, config.ExternalID, config.RoleSessionName, region, "bedrock"); err != nil { return nil, err } } @@ -693,10 +663,10 @@ func (provider *BedrockProvider) 
listModelsByKey(ctx *schemas.BifrostContext, ke // Check for timeout first using net.Error before checking net.OpError var netErr net.Error if errors.As(err, &netErr) && netErr.Timeout() { - return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, providerName) + return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err) } if errors.Is(err, http.ErrHandlerTimeout) || errors.Is(err, context.DeadlineExceeded) { - return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, providerName) + return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err) } // Check for DNS lookup and network errors after timeout checks var opErr *net.OpError @@ -744,9 +714,9 @@ func (provider *BedrockProvider) listModelsByKey(ctx *schemas.BifrostContext, ke } // Convert to Bifrost response - response := bedrockResponse.ToBifrostListModelsResponse(providerName, key.Models, key.BlacklistedModels, config.Deployments, request.Unfiltered) + response := bedrockResponse.ToBifrostListModelsResponse(providerName, key.Models, key.BlacklistedModels, key.Aliases, request.Unfiltered) if response == nil { - return nil, providerUtils.NewBifrostOperationError("failed to convert Bedrock model list response", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to convert Bedrock model list response", nil) } response.ExtraFields.Latency = time.Since(startTime).Milliseconds() @@ -787,24 +757,17 @@ func (provider *BedrockProvider) TextCompletion(ctx *schemas.BifrostContext, key return nil, err } - providerName := provider.GetProviderKey() - - if !ensureBedrockKeyConfig(&key) { - return nil, providerUtils.NewConfigurationError("bedrock key config is not provided", providerName) - } - jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, request, func() (providerUtils.RequestBodyWithExtraParams, error) { return 
ToBedrockTextCompletionRequest(request), nil - }, - provider.GetProviderKey()) + }) if bifrostErr != nil { return nil, bifrostErr } - path, deployment := provider.getModelPath("invoke", request.Model, key) + path := provider.getModelPath("invoke", request.Model, key) body, latency, providerResponseHeaders, err := provider.completeRequest(ctx, jsonData, path, key) if providerResponseHeaders != nil { ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders) @@ -816,29 +779,25 @@ func (provider *BedrockProvider) TextCompletion(ctx *schemas.BifrostContext, key // Handle model-specific response conversion var bifrostResponse *schemas.BifrostTextCompletionResponse switch { - case schemas.IsAnthropicModel(deployment): + case schemas.IsAnthropicModel(request.Model): var response BedrockAnthropicTextResponse if err := sonic.Unmarshal(body, &response); err != nil { - return nil, providerUtils.NewBifrostOperationError("error parsing anthropic response", err, providerName) + return nil, providerUtils.NewBifrostOperationError("error parsing anthropic response", err) } bifrostResponse = response.ToBifrostTextCompletionResponse() - case schemas.IsMistralModel(deployment): + case schemas.IsMistralModel(request.Model): var response BedrockMistralTextResponse if err := sonic.Unmarshal(body, &response); err != nil { - return nil, providerUtils.NewBifrostOperationError("error parsing mistral response", err, providerName) + return nil, providerUtils.NewBifrostOperationError("error parsing mistral response", err) } bifrostResponse = response.ToBifrostTextCompletionResponse() default: - return nil, providerUtils.NewConfigurationError(fmt.Sprintf("unsupported model type for text completion: %s", request.Model), providerName) + return nil, providerUtils.NewConfigurationError(fmt.Sprintf("unsupported model type for text completion: %s", request.Model)) } // Set ExtraFields - bifrostResponse.ExtraFields.Provider = providerName - 
bifrostResponse.ExtraFields.ModelRequested = request.Model - bifrostResponse.ExtraFields.ModelDeployment = deployment - bifrostResponse.ExtraFields.RequestType = schemas.TextCompletionRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -851,7 +810,7 @@ func (provider *BedrockProvider) TextCompletion(ctx *schemas.BifrostContext, key if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) { var rawResponse interface{} if err := sonic.Unmarshal(body, &rawResponse); err != nil { - return nil, providerUtils.NewBifrostOperationError("error parsing raw response", err, providerName) + return nil, providerUtils.NewBifrostOperationError("error parsing raw response", err) } bifrostResponse.ExtraFields.RawResponse = rawResponse } @@ -869,22 +828,17 @@ func (provider *BedrockProvider) TextCompletionStream(ctx *schemas.BifrostContex providerName := provider.GetProviderKey() - if !ensureBedrockKeyConfig(&key) { - return nil, providerUtils.NewConfigurationError("bedrock key config is not provided", providerName) - } - jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToBedrockTextCompletionRequest(request), nil - }, - provider.GetProviderKey()) + }) if bifrostErr != nil { return nil, bifrostErr } - resp, deployment, bifrostErr := provider.makeStreamingRequest(ctx, jsonData, key, request.Model, "invoke-with-response-stream") + resp, bifrostErr := provider.makeStreamingRequest(ctx, jsonData, key, request.Model, "invoke-with-response-stream") if bifrostErr != nil { return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } @@ -900,9 +854,9 @@ func (provider *BedrockProvider) TextCompletionStream(ctx *schemas.BifrostContex go func() { defer func() { if ctx.Err() == context.Canceled { - 
providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.TextCompletionStreamRequest, provider.logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.logger) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.TextCompletionStreamRequest, provider.logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.logger) } close(responseChan) }() @@ -939,7 +893,7 @@ func (provider *BedrockProvider) TextCompletionStream(ctx *schemas.BifrostContex } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) provider.logger.Warn("error decoding %s EventStream message: %v", providerName, err) - providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, schemas.TextCompletionStreamRequest, providerName, request.Model, provider.logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, provider.logger) return } @@ -959,7 +913,7 @@ func (provider *BedrockProvider) TextCompletionStream(ctx *schemas.BifrostContex errMsg = bedrockErr.Message } err := fmt.Errorf("%s stream %s: %s", providerName, excType, errMsg) - providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, schemas.TextCompletionStreamRequest, providerName, request.Model, provider.logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, provider.logger) return } } @@ -970,18 +924,14 @@ func (provider *BedrockProvider) TextCompletionStream(ctx *schemas.BifrostContex } if err := sonic.Unmarshal(message.Payload, &chunkPayload); err != nil { provider.logger.Debug("Failed to parse JSON from event buffer: %v, data: %s", err, string(message.Payload)) - providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, schemas.TextCompletionStreamRequest, providerName, request.Model, provider.logger) + 
providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, provider.logger) return } // Create BifrostStreamChunk response containing the raw model-specific JSON chunk textResponse := &schemas.BifrostTextCompletionResponse{ ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.TextCompletionStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - ModelDeployment: deployment, - Latency: time.Since(startTime).Milliseconds(), + Latency: time.Since(startTime).Milliseconds(), // Pass the raw JSON string from the chunk bytes RawResponse: string(chunkPayload.Bytes), }, @@ -1003,26 +953,19 @@ func (provider *BedrockProvider) ChatCompletion(ctx *schemas.BifrostContext, key return nil, err } - providerName := provider.GetProviderKey() - - if !ensureBedrockKeyConfig(&key) { - return nil, providerUtils.NewConfigurationError("bedrock key config is not provided", providerName) - } - // Use centralized Bedrock converter jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToBedrockChatCompletionRequest(ctx, request) - }, - provider.GetProviderKey()) + }) if bifrostErr != nil { return nil, bifrostErr } // Format the path with proper model identifier - path, deployment := provider.getModelPath("converse", request.Model, key) + path := provider.getModelPath("converse", request.Model, key) // Create the signed request responseBody, latency, providerResponseHeaders, bifrostErr := provider.completeRequest(ctx, jsonData, path, key) @@ -1039,13 +982,13 @@ func (provider *BedrockProvider) ChatCompletion(ctx *schemas.BifrostContext, key // Parse the response using the new Bedrock type if err := sonic.Unmarshal(responseBody, bedrockResponse); err != nil { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("failed to parse bedrock response", err, providerName), jsonData, responseBody, provider.sendBackRawRequest, 
provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("failed to parse bedrock response", err), jsonData, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse) } // Convert using the new response converter bifrostResponse, err := bedrockResponse.ToBifrostChatResponse(ctx, request.Model) if err != nil { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("failed to convert bedrock response", err, providerName), jsonData, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("failed to convert bedrock response", err), jsonData, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse) } // Override finish reason for structured output @@ -1059,10 +1002,6 @@ func (provider *BedrockProvider) ChatCompletion(ctx *schemas.BifrostContext, key } // Set ExtraFields - bifrostResponse.ExtraFields.Provider = providerName - bifrostResponse.ExtraFields.ModelRequested = request.Model - bifrostResponse.ExtraFields.ModelDeployment = deployment - bifrostResponse.ExtraFields.RequestType = schemas.ChatCompletionRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -1090,20 +1029,17 @@ func (provider *BedrockProvider) ChatCompletionStream(ctx *schemas.BifrostContex return nil, err } - providerName := provider.GetProviderKey() - jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToBedrockChatCompletionRequest(ctx, request) - }, - provider.GetProviderKey()) + }) if bifrostErr != nil { return nil, bifrostErr } - resp, deployment, bifrostErr := provider.makeStreamingRequest(ctx, jsonData, key, request.Model, "converse-stream") + resp, bifrostErr := 
provider.makeStreamingRequest(ctx, jsonData, key, request.Model, "converse-stream") if bifrostErr != nil { return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } @@ -1119,9 +1055,9 @@ func (provider *BedrockProvider) ChatCompletionStream(ctx *schemas.BifrostContex go func() { defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.ChatCompletionStreamRequest, provider.logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.logger) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.ChatCompletionStreamRequest, provider.logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.logger) } close(responseChan) }() @@ -1177,8 +1113,8 @@ func (provider *BedrockProvider) ChatCompletionStream(ctx *schemas.BifrostContex break } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - provider.logger.Warn("Error decoding %s EventStream message: %v", providerName, err) - providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, schemas.ChatCompletionStreamRequest, providerName, request.Model, provider.logger) + provider.logger.Warn("Error decoding EventStream message: %v", err) + providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, provider.logger) return } @@ -1193,8 +1129,8 @@ func (provider *BedrockProvider) ChatCompletionStream(ctx *schemas.BifrostContex } } errMsg := string(message.Payload) - err := fmt.Errorf("%s stream %s: %s", providerName, excType, errMsg) - providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, schemas.ChatCompletionStreamRequest, providerName, request.Model, provider.logger) + err := fmt.Errorf("stream %s: %s", excType, errMsg) + 
providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, provider.logger) return } } @@ -1203,7 +1139,7 @@ func (provider *BedrockProvider) ChatCompletionStream(ctx *schemas.BifrostContex var streamEvent BedrockStreamEvent if err := sonic.Unmarshal(message.Payload, &streamEvent); err != nil { provider.logger.Debug("Failed to parse JSON from event buffer: %v, data: %s", err, string(message.Payload)) - providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, schemas.ChatCompletionStreamRequest, providerName, request.Model, provider.logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, provider.logger) return } @@ -1282,12 +1218,8 @@ func (provider *BedrockProvider) ChatCompletionStream(ctx *schemas.BifrostContex }, }, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ChatCompletionStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - ModelDeployment: deployment, - ChunkIndex: chunkIndex, - Latency: time.Since(lastChunkTime).Milliseconds(), + ChunkIndex: chunkIndex, + Latency: time.Since(lastChunkTime).Milliseconds(), }, } chunkIndex++ @@ -1304,11 +1236,6 @@ func (provider *BedrockProvider) ChatCompletionStream(ctx *schemas.BifrostContex response, bifrostErr, _ := streamEvent.ToBifrostChatCompletionStream(streamState) if bifrostErr != nil { - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: schemas.ChatCompletionStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger) return @@ -1317,12 +1244,8 @@ func (provider *BedrockProvider) ChatCompletionStream(ctx *schemas.BifrostContex response.ID = id response.Model = request.Model response.ExtraFields = schemas.BifrostResponseExtraFields{ - RequestType: schemas.ChatCompletionStreamRequest, - Provider: 
providerName, - ModelRequested: request.Model, - ModelDeployment: deployment, - ChunkIndex: chunkIndex, - Latency: time.Since(lastChunkTime).Milliseconds(), + ChunkIndex: chunkIndex, + Latency: time.Since(lastChunkTime).Milliseconds(), } chunkIndex++ lastChunkTime = time.Now() @@ -1341,8 +1264,7 @@ func (provider *BedrockProvider) ChatCompletionStream(ctx *schemas.BifrostContex } // Send final response - response := providerUtils.CreateBifrostChatCompletionChunkResponse(id, usage, finishReason, chunkIndex, schemas.ChatCompletionStreamRequest, providerName, request.Model, 0) - response.ExtraFields.ModelDeployment = deployment + response := providerUtils.CreateBifrostChatCompletionChunkResponse(id, usage, finishReason, chunkIndex, request.Model, 0) // Set raw request if enabled if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { providerUtils.ParseAndSetRawRequest(&response.ExtraFields, jsonData) @@ -1363,26 +1285,19 @@ func (provider *BedrockProvider) Responses(ctx *schemas.BifrostContext, key sche return nil, err } - providerName := provider.GetProviderKey() - - if !ensureBedrockKeyConfig(&key) { - return nil, providerUtils.NewConfigurationError("bedrock key config is not provided", providerName) - } - // Use centralized Bedrock converter jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToBedrockResponsesRequest(ctx, request) - }, - provider.GetProviderKey()) + }) if bifrostErr != nil { return nil, bifrostErr } // Format the path with proper model identifier - path, deployment := provider.getModelPath("converse", request.Model, key) + path := provider.getModelPath("converse", request.Model, key) // Create the signed request responseBody, latency, providerResponseHeaders, bifrostErr := provider.completeRequest(ctx, jsonData, path, key) @@ -1399,22 +1314,18 @@ func (provider *BedrockProvider) Responses(ctx *schemas.BifrostContext, key sche // 
Parse the response using the new Bedrock type if err := sonic.Unmarshal(responseBody, bedrockResponse); err != nil { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("failed to parse bedrock response", err, providerName), jsonData, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("failed to parse bedrock response", err), jsonData, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse) } // Convert using the new response converter bifrostResponse, err := bedrockResponse.ToBifrostResponsesResponse(ctx) if err != nil { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("failed to convert bedrock response", err, providerName), jsonData, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("failed to convert bedrock response", err), jsonData, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse) } - bifrostResponse.Model = deployment + bifrostResponse.Model = request.Model // Set ExtraFields - bifrostResponse.ExtraFields.Provider = providerName - bifrostResponse.ExtraFields.ModelRequested = request.Model - bifrostResponse.ExtraFields.ModelDeployment = deployment - bifrostResponse.ExtraFields.RequestType = schemas.ResponsesRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -1442,20 +1353,17 @@ func (provider *BedrockProvider) ResponsesStream(ctx *schemas.BifrostContext, po return nil, err } - providerName := provider.GetProviderKey() - jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToBedrockResponsesRequest(ctx, request) - }, - provider.GetProviderKey()) + }) if 
bifrostErr != nil { return nil, bifrostErr } - resp, deployment, bifrostErr := provider.makeStreamingRequest(ctx, jsonData, key, request.Model, "converse-stream") + resp, bifrostErr := provider.makeStreamingRequest(ctx, jsonData, key, request.Model, "converse-stream") if bifrostErr != nil { return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } @@ -1471,9 +1379,9 @@ func (provider *BedrockProvider) ResponsesStream(ctx *schemas.BifrostContext, po go func() { defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.ResponsesStreamRequest, provider.logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.logger) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.ResponsesStreamRequest, provider.logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.logger) } close(responseChan) }() @@ -1494,7 +1402,7 @@ func (provider *BedrockProvider) ResponsesStream(ctx *schemas.BifrostContext, po // Create stream state for stateful conversions streamState := acquireBedrockResponsesStreamState() - streamState.Model = &deployment + streamState.Model = &request.Model defer releaseBedrockResponsesStreamState(streamState) // Check for structured output mode - if set, we need to intercept tool calls @@ -1528,12 +1436,8 @@ func (provider *BedrockProvider) ResponsesStream(ctx *schemas.BifrostContext, po finalResponses := FinalizeBedrockStream(streamState, chunkIndex, usage) for i, finalResponse := range finalResponses { finalResponse.ExtraFields = schemas.BifrostResponseExtraFields{ - RequestType: schemas.ResponsesStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - ModelDeployment: deployment, - ChunkIndex: 
chunkIndex, - Latency: time.Since(lastChunkTime).Milliseconds(), + ChunkIndex: chunkIndex, + Latency: time.Since(lastChunkTime).Milliseconds(), } chunkIndex++ lastChunkTime = time.Now() @@ -1556,8 +1460,8 @@ func (provider *BedrockProvider) ResponsesStream(ctx *schemas.BifrostContext, po break } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - provider.logger.Warn("Error decoding %s EventStream message: %v", providerName, err) - providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, schemas.ResponsesStreamRequest, providerName, request.Model, provider.logger) + provider.logger.Warn("Error decoding EventStream message: %v", err) + providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, provider.logger) return } @@ -1572,8 +1476,8 @@ func (provider *BedrockProvider) ResponsesStream(ctx *schemas.BifrostContext, po } } errMsg := string(message.Payload) - err := fmt.Errorf("%s stream %s: %s", providerName, excType, errMsg) - providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, schemas.ResponsesStreamRequest, providerName, request.Model, provider.logger) + err := fmt.Errorf("stream %s: %s", excType, errMsg) + providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, provider.logger) return } } @@ -1582,7 +1486,7 @@ func (provider *BedrockProvider) ResponsesStream(ctx *schemas.BifrostContext, po var streamEvent BedrockStreamEvent if err := sonic.Unmarshal(message.Payload, &streamEvent); err != nil { provider.logger.Debug("Failed to parse JSON from event buffer: %v, data: %s", err, string(message.Payload)) - providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, schemas.ResponsesStreamRequest, providerName, request.Model, provider.logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, provider.logger) return } @@ -1638,12 +1542,8 @@ func (provider *BedrockProvider) ResponsesStream(ctx *schemas.BifrostContext, po SequenceNumber: 
chunkIndex, Delta: &content, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ResponsesStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - ModelDeployment: deployment, - ChunkIndex: chunkIndex, - Latency: time.Since(lastChunkTime).Milliseconds(), + ChunkIndex: chunkIndex, + Latency: time.Since(lastChunkTime).Milliseconds(), }, } chunkIndex++ @@ -1660,11 +1560,6 @@ func (provider *BedrockProvider) ResponsesStream(ctx *schemas.BifrostContext, po responses, bifrostErr, _ := streamEvent.ToBifrostResponsesStream(chunkIndex, streamState) if bifrostErr != nil { - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: schemas.ResponsesStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger) return @@ -1672,12 +1567,8 @@ func (provider *BedrockProvider) ResponsesStream(ctx *schemas.BifrostContext, po for _, response := range responses { if response != nil { response.ExtraFields = schemas.BifrostResponseExtraFields{ - RequestType: schemas.ResponsesStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - ModelDeployment: deployment, - ChunkIndex: chunkIndex, - Latency: time.Since(lastChunkTime).Milliseconds(), + ChunkIndex: chunkIndex, + Latency: time.Since(lastChunkTime).Milliseconds(), } chunkIndex++ lastChunkTime = time.Now() @@ -1703,15 +1594,10 @@ func (provider *BedrockProvider) Embedding(ctx *schemas.BifrostContext, key sche return nil, err } - providerName := provider.GetProviderKey() - if !ensureBedrockKeyConfig(&key) { - return nil, providerUtils.NewConfigurationError("bedrock key config is not provided", providerName) - } - // Determine model type modelType, err := DetermineEmbeddingModelType(request.Model) if err != nil { - return nil, providerUtils.NewConfigurationError(err.Error(), 
providerName) + return nil, providerUtils.NewConfigurationError(err.Error()) } // Convert request and execute based on model type @@ -1720,7 +1606,6 @@ func (provider *BedrockProvider) Embedding(ctx *schemas.BifrostContext, key sche var latency time.Duration var providerResponseHeaders map[string]string var path string - var deployment string var jsonData []byte switch modelType { @@ -1730,12 +1615,11 @@ func (provider *BedrockProvider) Embedding(ctx *schemas.BifrostContext, key sche request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToBedrockTitanEmbeddingRequest(request) - }, - provider.GetProviderKey()) + }) if bifrostError != nil { return nil, bifrostError } - path, deployment = provider.getModelPath("invoke", request.Model, key) + path = provider.getModelPath("invoke", request.Model, key) rawResponse, latency, providerResponseHeaders, bifrostError = provider.completeRequest(ctx, jsonData, path, key) case "cohere": @@ -1744,16 +1628,15 @@ func (provider *BedrockProvider) Embedding(ctx *schemas.BifrostContext, key sche request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToBedrockCohereEmbeddingRequest(request) - }, - provider.GetProviderKey()) + }) if bifrostError != nil { return nil, bifrostError } - path, deployment = provider.getModelPath("invoke", request.Model, key) + path = provider.getModelPath("invoke", request.Model, key) rawResponse, latency, providerResponseHeaders, bifrostError = provider.completeRequest(ctx, jsonData, path, key) default: - return nil, providerUtils.NewConfigurationError("unsupported embedding model type", providerName) + return nil, providerUtils.NewConfigurationError("unsupported embedding model type") } if providerResponseHeaders != nil { @@ -1769,7 +1652,7 @@ func (provider *BedrockProvider) Embedding(ctx *schemas.BifrostContext, key sche case "titan": var titanResp BedrockTitanEmbeddingResponse if err := sonic.Unmarshal(rawResponse, &titanResp); err != nil { - return nil, 
providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("error parsing Titan embedding response", err, providerName), jsonData, rawResponse, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("error parsing Titan embedding response", err), jsonData, rawResponse, provider.sendBackRawRequest, provider.sendBackRawResponse) } bifrostResponse = titanResp.ToBifrostEmbeddingResponse() bifrostResponse.Model = request.Model @@ -1777,17 +1660,13 @@ func (provider *BedrockProvider) Embedding(ctx *schemas.BifrostContext, key sche case "cohere": var cohereResp cohere.CohereEmbeddingResponse if err := sonic.Unmarshal(rawResponse, &cohereResp); err != nil { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("error parsing Cohere embedding response", err, providerName), jsonData, rawResponse, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("error parsing Cohere embedding response", err), jsonData, rawResponse, provider.sendBackRawRequest, provider.sendBackRawResponse) } bifrostResponse = cohereResp.ToBifrostEmbeddingResponse() bifrostResponse.Model = request.Model } // Set ExtraFields - bifrostResponse.ExtraFields.Provider = providerName - bifrostResponse.ExtraFields.ModelRequested = request.Model - bifrostResponse.ExtraFields.ModelDeployment = deployment - bifrostResponse.ExtraFields.RequestType = schemas.EmbeddingRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -1813,26 +1692,16 @@ func (provider *BedrockProvider) Rerank(ctx *schemas.BifrostContext, key schemas return nil, err } - providerName := provider.GetProviderKey() - if !ensureBedrockKeyConfig(&key) { - return nil, providerUtils.NewConfigurationError("bedrock key config is not provided", 
providerName) - } - - deployment := strings.TrimSpace(resolveBedrockDeployment(request.Model, key)) - if deployment == "" { - return nil, providerUtils.NewConfigurationError("bedrock rerank model is empty", providerName) - } - if !strings.HasPrefix(deployment, "arn:") { - return nil, providerUtils.NewConfigurationError(fmt.Sprintf("bedrock rerank requires an ARN model identifier; got %q", deployment), providerName) + if !strings.HasPrefix(request.Model, "arn:") { + return nil, providerUtils.NewConfigurationError(fmt.Sprintf("bedrock rerank requires an ARN model identifier; got %q", request.Model)) } jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, request, func() (providerUtils.RequestBodyWithExtraParams, error) { - return ToBedrockRerankRequest(request, deployment) + return ToBedrockRerankRequest(request, request.Model) }, - providerName, ) if bifrostErr != nil { return nil, bifrostErr @@ -1856,10 +1725,6 @@ func (provider *BedrockProvider) Rerank(ctx *schemas.BifrostContext, key schemas bifrostResponse := response.ToBifrostRerankResponse(request.Documents, returnDocuments) bifrostResponse.Model = request.Model - bifrostResponse.ExtraFields.Provider = providerName - bifrostResponse.ExtraFields.ModelRequested = request.Model - bifrostResponse.ExtraFields.ModelDeployment = deployment - bifrostResponse.ExtraFields.RequestType = schemas.RerankRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -1896,31 +1761,24 @@ func (provider *BedrockProvider) ImageGeneration(ctx *schemas.BifrostContext, ke return nil, err } - providerName := provider.GetProviderKey() - if !ensureBedrockKeyConfig(&key) { - return nil, providerUtils.NewConfigurationError("bedrock key config is not provided", providerName) - } - var rawResponse []byte var jsonData []byte var bifrostError *schemas.BifrostError var latency time.Duration var providerResponseHeaders map[string]string 
var path string - var deployment string - path, deployment = provider.getModelPath("invoke", request.Model, key) + path = provider.getModelPath("invoke", request.Model, key) jsonData, bifrostError = providerUtils.CheckContextAndGetRequestBody( ctx, request, func() (providerUtils.RequestBodyWithExtraParams, error) { - if isStabilityAIModel(deployment) { + if isStabilityAIModel(request.Model) { return ToStabilityAIImageGenerationRequest(request) } return ToBedrockImageGenerationRequest(request) - }, - provider.GetProviderKey()) + }) if bifrostError != nil { return nil, bifrostError } @@ -1936,19 +1794,15 @@ func (provider *BedrockProvider) ImageGeneration(ctx *schemas.BifrostContext, ke var bifrostResponse *schemas.BifrostImageGenerationResponse var imageResp BedrockImageGenerationResponse if err := sonic.Unmarshal(rawResponse, &imageResp); err != nil { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("error parsing image generation response", err, providerName), jsonData, rawResponse, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("error parsing image generation response", err), jsonData, rawResponse, provider.sendBackRawRequest, provider.sendBackRawResponse) } if imageResp.Error != "" { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(imageResp.Error, nil, providerName), jsonData, rawResponse, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(imageResp.Error, nil), jsonData, rawResponse, provider.sendBackRawRequest, provider.sendBackRawResponse) } bifrostResponse = ToBifrostImageGenerationResponse(&imageResp) bifrostResponse.Model = request.Model - bifrostResponse.ExtraFields.RequestType = schemas.ImageGenerationRequest - bifrostResponse.ExtraFields.Provider = providerName - 
bifrostResponse.ExtraFields.ModelRequested = request.Model - bifrostResponse.ExtraFields.ModelDeployment = deployment bifrostResponse.ExtraFields.Latency = latency.Milliseconds() bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -1983,28 +1837,21 @@ func (provider *BedrockProvider) ImageEdit(ctx *schemas.BifrostContext, key sche return nil, err } - providerName := provider.GetProviderKey() - if !ensureBedrockKeyConfig(&key) { - return nil, providerUtils.NewConfigurationError("bedrock key config is not provided", providerName) - } - var jsonData []byte var bifrostError *schemas.BifrostError - // Resolve deployment alias before building the request body so that // Stability AI routing and task-type inference use the actual model ID. - path, deployment := provider.getModelPath("invoke", request.Model, key) + path := provider.getModelPath("invoke", request.Model, key) jsonData, bifrostError = providerUtils.CheckContextAndGetRequestBody( ctx, request, func() (providerUtils.RequestBodyWithExtraParams, error) { - if isStabilityAIModel(deployment) { - return ToStabilityAIImageEditRequest(request, deployment) + if isStabilityAIModel(request.Model) { + return ToStabilityAIImageEditRequest(request, request.Model) } return ToBedrockImageEditRequest(request) - }, - provider.GetProviderKey()) + }) if bifrostError != nil { return nil, bifrostError } @@ -2021,20 +1868,16 @@ func (provider *BedrockProvider) ImageEdit(ctx *schemas.BifrostContext, key sche // Parse response (reuse BedrockImageGenerationResponse) var imageResp BedrockImageGenerationResponse if err := sonic.Unmarshal(rawResponse, &imageResp); err != nil { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("error parsing image edit response", err, providerName), jsonData, rawResponse, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("error parsing image edit 
response", err), jsonData, rawResponse, provider.sendBackRawRequest, provider.sendBackRawResponse) } if imageResp.Error != "" { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(imageResp.Error, nil, providerName), jsonData, rawResponse, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(imageResp.Error, nil), jsonData, rawResponse, provider.sendBackRawRequest, provider.sendBackRawResponse) } // Convert response and set metadata bifrostResponse := ToBifrostImageGenerationResponse(&imageResp) bifrostResponse.Model = request.Model - bifrostResponse.ExtraFields.RequestType = schemas.ImageEditRequest - bifrostResponse.ExtraFields.Provider = providerName - bifrostResponse.ExtraFields.ModelRequested = request.Model - bifrostResponse.ExtraFields.ModelDeployment = deployment bifrostResponse.ExtraFields.Latency = latency.Milliseconds() bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -2066,11 +1909,6 @@ func (provider *BedrockProvider) ImageVariation(ctx *schemas.BifrostContext, key return nil, err } - providerName := provider.GetProviderKey() - if !ensureBedrockKeyConfig(&key) { - return nil, providerUtils.NewConfigurationError("bedrock key config is not provided", providerName) - } - var jsonData []byte var bifrostError *schemas.BifrostError @@ -2079,14 +1917,13 @@ func (provider *BedrockProvider) ImageVariation(ctx *schemas.BifrostContext, key request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToBedrockImageVariationRequest(request) - }, - provider.GetProviderKey()) + }) if bifrostError != nil { return nil, bifrostError } // Make API request (same URL as image generation) - path, deployment := provider.getModelPath("invoke", request.Model, key) + path := provider.getModelPath("invoke", request.Model, key) rawResponse, latency, providerResponseHeaders, bifrostError := 
provider.completeRequest(ctx, jsonData, path, key) if providerResponseHeaders != nil { ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders) @@ -2098,20 +1935,16 @@ func (provider *BedrockProvider) ImageVariation(ctx *schemas.BifrostContext, key // Parse response (reuse BedrockImageGenerationResponse and ToBifrostImageGenerationResponse) var imageResp BedrockImageGenerationResponse if err := sonic.Unmarshal(rawResponse, &imageResp); err != nil { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("error parsing image variation response", err, providerName), jsonData, rawResponse, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("error parsing image variation response", err), jsonData, rawResponse, provider.sendBackRawRequest, provider.sendBackRawResponse) } if imageResp.Error != "" { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(imageResp.Error, nil, providerName), jsonData, rawResponse, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(imageResp.Error, nil), jsonData, rawResponse, provider.sendBackRawRequest, provider.sendBackRawResponse) } // Convert response and set metadata bifrostResponse := ToBifrostImageGenerationResponse(&imageResp) bifrostResponse.Model = request.Model - bifrostResponse.ExtraFields.RequestType = schemas.ImageVariationRequest - bifrostResponse.ExtraFields.Provider = providerName - bifrostResponse.ExtraFields.ModelRequested = request.Model - bifrostResponse.ExtraFields.ModelDeployment = deployment bifrostResponse.ExtraFields.Latency = latency.Milliseconds() bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -2170,13 +2003,6 @@ func (provider *BedrockProvider) FileUpload(ctx *schemas.BifrostContext, key sch return nil, err } - 
providerName := provider.GetProviderKey() - - if !ensureBedrockKeyConfig(&key) { - provider.logger.Error("bedrock key config is is missing in file upload request") - return nil, providerUtils.NewConfigurationError("bedrock key config is not provided", providerName) - } - // Get S3 bucket from storage config or extra params s3Bucket := "" s3Prefix := "" @@ -2198,7 +2024,7 @@ func (provider *BedrockProvider) FileUpload(ctx *schemas.BifrostContext, key sch if s3Bucket == "" { provider.logger.Error("s3_bucket is required for Bedrock file operations (provide in storage_config.s3 or extra_params)") - return nil, providerUtils.NewBifrostOperationError("s3_bucket is required for Bedrock file operations (provide in storage_config.s3 or extra_params)", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("s3_bucket is required for Bedrock file operations (provide in storage_config.s3 or extra_params)", nil) } // Parse bucket name and optional prefix from s3Bucket (could be "bucket-name" or "s3://bucket-name/prefix/") @@ -2232,14 +2058,14 @@ func (provider *BedrockProvider) FileUpload(ctx *schemas.BifrostContext, key sch httpReq, err := http.NewRequestWithContext(ctx, http.MethodPut, reqURL, bytes.NewReader(request.File)) if err != nil { - return nil, providerUtils.NewBifrostOperationError("error creating request", err, providerName) + return nil, providerUtils.NewBifrostOperationError("error creating request", err) } httpReq.Header.Set("Content-Type", "application/octet-stream") httpReq.ContentLength = int64(len(request.File)) // Sign request for S3 - if err := signAWSRequest(ctx, httpReq, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, key.BedrockKeyConfig.SessionToken, key.BedrockKeyConfig.RoleARN, key.BedrockKeyConfig.ExternalID, key.BedrockKeyConfig.RoleSessionName, region, "s3", providerName); err != nil { + if err := signAWSRequest(ctx, httpReq, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, 
key.BedrockKeyConfig.SessionToken, key.BedrockKeyConfig.RoleARN, key.BedrockKeyConfig.ExternalID, key.BedrockKeyConfig.RoleSessionName, region, "s3"); err != nil { provider.logger.Error("error signing request: %s", err.Error.Message) return nil, err } @@ -2259,14 +2085,14 @@ func (provider *BedrockProvider) FileUpload(ctx *schemas.BifrostContext, key sch }, } } - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated { body, _ := io.ReadAll(resp.Body) provider.logger.Error("s3 upload failed: %d", resp.StatusCode) - return nil, providerUtils.NewProviderAPIError(fmt.Sprintf("S3 upload failed: %s", string(body)), nil, resp.StatusCode, providerName, nil, nil) + return nil, providerUtils.NewProviderAPIError(fmt.Sprintf("S3 upload failed: %s", string(body)), nil, resp.StatusCode, nil, nil) } // Return S3 URI as the file ID @@ -2283,9 +2109,7 @@ func (provider *BedrockProvider) FileUpload(ctx *schemas.BifrostContext, key sch StorageBackend: schemas.FileStorageS3, StorageURI: s3URI, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileUploadRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, }, nil } @@ -2298,8 +2122,6 @@ func (provider *BedrockProvider) FileList(ctx *schemas.BifrostContext, keys []sc return nil, err } - providerName := provider.GetProviderKey() - // Get S3 bucket from storage config or extra params s3Bucket := "" s3Prefix := "" @@ -2321,7 +2143,7 @@ func (provider *BedrockProvider) FileList(ctx *schemas.BifrostContext, keys []sc } if s3Bucket == "" { - return nil, providerUtils.NewBifrostOperationError("s3_bucket is required for Bedrock file operations (provide in storage_config.s3 or extra_params)", nil, providerName) + return nil, 
providerUtils.NewBifrostOperationError("s3_bucket is required for Bedrock file operations (provide in storage_config.s3 or extra_params)", nil) } bucketName, bucketPrefix := parseS3URI(s3Bucket) @@ -2332,7 +2154,7 @@ func (provider *BedrockProvider) FileList(ctx *schemas.BifrostContext, keys []sc // Initialize serial pagination helper helper, err := providerUtils.NewSerialListHelper(keys, request.After, provider.logger) if err != nil { - return nil, providerUtils.NewBifrostOperationError("invalid pagination cursor", err, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid pagination cursor", err) } // Get current key to query @@ -2343,10 +2165,6 @@ func (provider *BedrockProvider) FileList(ctx *schemas.BifrostContext, keys []sc Object: "list", Data: []schemas.FileObject{}, HasMore: false, - ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileListRequest, - Provider: providerName, - }, }, nil } @@ -2373,14 +2191,11 @@ func (provider *BedrockProvider) FileList(ctx *schemas.BifrostContext, keys []sc httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, requestURL, nil) if err != nil { - return nil, providerUtils.NewBifrostOperationError("error creating request", err, providerName) + return nil, providerUtils.NewBifrostOperationError("error creating request", err) } // Sign request for S3 - if !ensureBedrockKeyConfig(&key) { - return nil, providerUtils.NewConfigurationError("bedrock key config is not provided", providerName) - } - if bifrostErr := signAWSRequest(ctx, httpReq, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, key.BedrockKeyConfig.SessionToken, key.BedrockKeyConfig.RoleARN, key.BedrockKeyConfig.ExternalID, key.BedrockKeyConfig.RoleSessionName, region, "s3", providerName); bifrostErr != nil { + if bifrostErr := signAWSRequest(ctx, httpReq, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, key.BedrockKeyConfig.SessionToken, key.BedrockKeyConfig.RoleARN, 
key.BedrockKeyConfig.ExternalID, key.BedrockKeyConfig.RoleSessionName, region, "s3"); bifrostErr != nil { return nil, bifrostErr } @@ -2399,23 +2214,23 @@ func (provider *BedrockProvider) FileList(ctx *schemas.BifrostContext, keys []sc }, } } - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err) } body, err := io.ReadAll(resp.Body) resp.Body.Close() if err != nil { - return nil, providerUtils.NewBifrostOperationError("error reading response", err, providerName) + return nil, providerUtils.NewBifrostOperationError("error reading response", err) } if resp.StatusCode != http.StatusOK { - return nil, providerUtils.NewProviderAPIError(fmt.Sprintf("S3 list failed: %s", string(body)), nil, resp.StatusCode, providerName, nil, nil) + return nil, providerUtils.NewProviderAPIError(fmt.Sprintf("S3 list failed: %s", string(body)), nil, resp.StatusCode, nil, nil) } // Parse S3 ListObjectsV2 XML response var listResp S3ListObjectsResponse if err := parseS3ListResponse(body, &listResp); err != nil { - return nil, providerUtils.NewBifrostOperationError("error parsing S3 response", err, providerName) + return nil, providerUtils.NewBifrostOperationError("error parsing S3 response", err) } // Convert files to Bifrost format @@ -2447,9 +2262,7 @@ func (provider *BedrockProvider) FileList(ctx *schemas.BifrostContext, keys []sc Data: files, HasMore: hasMore, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileListRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } if nextCursor != "" { @@ -2465,25 +2278,18 @@ func (provider *BedrockProvider) FileRetrieve(ctx *schemas.BifrostContext, keys return nil, err } - providerName := provider.GetProviderKey() - if request.FileID == "" { - return nil, providerUtils.NewBifrostOperationError("file_id (S3 URI) is required", nil, 
providerName) + return nil, providerUtils.NewBifrostOperationError("file_id (S3 URI) is required", nil) } // Parse S3 URI bucketName, s3Key := parseS3URI(request.FileID) if bucketName == "" || s3Key == "" { - return nil, providerUtils.NewBifrostOperationError("invalid S3 URI format, expected s3://bucket/key", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid S3 URI format, expected s3://bucket/key", nil) } var lastErr *schemas.BifrostError for _, key := range keys { - if !ensureBedrockKeyConfig(&key) { - lastErr = providerUtils.NewConfigurationError("bedrock key config is not provided", providerName) - continue - } - region := DefaultBedrockRegion if key.BedrockKeyConfig.Region != nil && key.BedrockKeyConfig.Region.GetValue() != "" { region = key.BedrockKeyConfig.Region.GetValue() @@ -2495,12 +2301,12 @@ func (provider *BedrockProvider) FileRetrieve(ctx *schemas.BifrostContext, keys httpReq, err := http.NewRequestWithContext(ctx, http.MethodHead, reqURL, nil) if err != nil { - lastErr = providerUtils.NewBifrostOperationError("error creating request", err, providerName) + lastErr = providerUtils.NewBifrostOperationError("error creating request", err) continue } // Sign request for S3 - if err := signAWSRequest(ctx, httpReq, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, key.BedrockKeyConfig.SessionToken, key.BedrockKeyConfig.RoleARN, key.BedrockKeyConfig.ExternalID, key.BedrockKeyConfig.RoleSessionName, region, "s3", providerName); err != nil { + if err := signAWSRequest(ctx, httpReq, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, key.BedrockKeyConfig.SessionToken, key.BedrockKeyConfig.RoleARN, key.BedrockKeyConfig.ExternalID, key.BedrockKeyConfig.RoleSessionName, region, "s3"); err != nil { lastErr = err continue } @@ -2520,13 +2326,13 @@ func (provider *BedrockProvider) FileRetrieve(ctx *schemas.BifrostContext, keys }, } } - lastErr = 
providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, providerName) + lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err) continue } if resp.StatusCode != http.StatusOK { resp.Body.Close() - lastErr = providerUtils.NewProviderAPIError(fmt.Sprintf("S3 HEAD failed with status %d", resp.StatusCode), nil, resp.StatusCode, providerName, nil, nil) + lastErr = providerUtils.NewProviderAPIError(fmt.Sprintf("S3 HEAD failed with status %d", resp.StatusCode), nil, resp.StatusCode, nil, nil) continue } @@ -2556,9 +2362,7 @@ func (provider *BedrockProvider) FileRetrieve(ctx *schemas.BifrostContext, keys StorageBackend: schemas.FileStorageS3, StorageURI: request.FileID, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileRetrieveRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, }, nil } @@ -2572,25 +2376,18 @@ func (provider *BedrockProvider) FileDelete(ctx *schemas.BifrostContext, keys [] return nil, err } - providerName := provider.GetProviderKey() - if request.FileID == "" { - return nil, providerUtils.NewBifrostOperationError("file_id (S3 URI) is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("file_id (S3 URI) is required", nil) } // Parse S3 URI bucketName, s3Key := parseS3URI(request.FileID) if bucketName == "" || s3Key == "" { - return nil, providerUtils.NewBifrostOperationError("invalid S3 URI format, expected s3://bucket/key", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid S3 URI format, expected s3://bucket/key", nil) } var lastErr *schemas.BifrostError for _, key := range keys { - if !ensureBedrockKeyConfig(&key) { - lastErr = providerUtils.NewConfigurationError("bedrock key config is not provided", providerName) - continue - } - region := DefaultBedrockRegion if key.BedrockKeyConfig.Region != nil && key.BedrockKeyConfig.Region.GetValue() != "" { region = 
key.BedrockKeyConfig.Region.GetValue() @@ -2602,12 +2399,12 @@ func (provider *BedrockProvider) FileDelete(ctx *schemas.BifrostContext, keys [] httpReq, err := http.NewRequestWithContext(ctx, http.MethodDelete, reqURL, nil) if err != nil { - lastErr = providerUtils.NewBifrostOperationError("error creating request", err, providerName) + lastErr = providerUtils.NewBifrostOperationError("error creating request", err) continue } // Sign request for S3 - if err := signAWSRequest(ctx, httpReq, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, key.BedrockKeyConfig.SessionToken, key.BedrockKeyConfig.RoleARN, key.BedrockKeyConfig.ExternalID, key.BedrockKeyConfig.RoleSessionName, region, "s3", providerName); err != nil { + if err := signAWSRequest(ctx, httpReq, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, key.BedrockKeyConfig.SessionToken, key.BedrockKeyConfig.RoleARN, key.BedrockKeyConfig.ExternalID, key.BedrockKeyConfig.RoleSessionName, region, "s3"); err != nil { lastErr = err continue } @@ -2627,7 +2424,7 @@ func (provider *BedrockProvider) FileDelete(ctx *schemas.BifrostContext, keys [] }, } } - lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, providerName) + lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err) continue } @@ -2635,7 +2432,7 @@ func (provider *BedrockProvider) FileDelete(ctx *schemas.BifrostContext, keys [] if resp.StatusCode != http.StatusNoContent && resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(resp.Body) resp.Body.Close() - lastErr = providerUtils.NewProviderAPIError(fmt.Sprintf("S3 DELETE failed: %s", string(body)), nil, resp.StatusCode, providerName, nil, nil) + lastErr = providerUtils.NewProviderAPIError(fmt.Sprintf("S3 DELETE failed: %s", string(body)), nil, resp.StatusCode, nil, nil) continue } @@ -2646,9 +2443,7 @@ func (provider *BedrockProvider) FileDelete(ctx *schemas.BifrostContext, keys [] Object: "file", Deleted: true, 
ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileDeleteRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, }, nil } @@ -2662,25 +2457,18 @@ func (provider *BedrockProvider) FileContent(ctx *schemas.BifrostContext, keys [ return nil, err } - providerName := provider.GetProviderKey() - if request.FileID == "" { - return nil, providerUtils.NewBifrostOperationError("file_id (S3 URI) is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("file_id (S3 URI) is required", nil) } // Parse S3 URI bucketName, s3Key := parseS3URI(request.FileID) if bucketName == "" || s3Key == "" { - return nil, providerUtils.NewBifrostOperationError("invalid S3 URI format, expected s3://bucket/key", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid S3 URI format, expected s3://bucket/key", nil) } var lastErr *schemas.BifrostError for _, key := range keys { - if !ensureBedrockKeyConfig(&key) { - lastErr = providerUtils.NewConfigurationError("bedrock key config is not provided", providerName) - continue - } - region := DefaultBedrockRegion if key.BedrockKeyConfig.Region != nil && key.BedrockKeyConfig.Region.GetValue() != "" { region = key.BedrockKeyConfig.Region.GetValue() @@ -2692,12 +2480,12 @@ func (provider *BedrockProvider) FileContent(ctx *schemas.BifrostContext, keys [ httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, reqURL, nil) if err != nil { - lastErr = providerUtils.NewBifrostOperationError("error creating request", err, providerName) + lastErr = providerUtils.NewBifrostOperationError("error creating request", err) continue } // Sign request for S3 - if err := signAWSRequest(ctx, httpReq, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, key.BedrockKeyConfig.SessionToken, key.BedrockKeyConfig.RoleARN, key.BedrockKeyConfig.ExternalID, key.BedrockKeyConfig.RoleSessionName, region, "s3", providerName); err != nil { 
+ if err := signAWSRequest(ctx, httpReq, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, key.BedrockKeyConfig.SessionToken, key.BedrockKeyConfig.RoleARN, key.BedrockKeyConfig.ExternalID, key.BedrockKeyConfig.RoleSessionName, region, "s3"); err != nil { lastErr = err continue } @@ -2717,21 +2505,21 @@ func (provider *BedrockProvider) FileContent(ctx *schemas.BifrostContext, keys [ }, } } - lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, providerName) + lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err) continue } if resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(resp.Body) resp.Body.Close() - lastErr = providerUtils.NewProviderAPIError(fmt.Sprintf("S3 GET failed: %s", string(body)), nil, resp.StatusCode, providerName, nil, nil) + lastErr = providerUtils.NewProviderAPIError(fmt.Sprintf("S3 GET failed: %s", string(body)), nil, resp.StatusCode, nil, nil) continue } body, err := io.ReadAll(resp.Body) resp.Body.Close() if err != nil { - lastErr = providerUtils.NewBifrostOperationError("error reading S3 object content", err, providerName) + lastErr = providerUtils.NewBifrostOperationError("error reading S3 object content", err) continue } @@ -2745,9 +2533,7 @@ func (provider *BedrockProvider) FileContent(ctx *schemas.BifrostContext, keys [ Content: body, ContentType: contentType, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileContentRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, }, nil } @@ -2762,13 +2548,6 @@ func (provider *BedrockProvider) BatchCreate(ctx *schemas.BifrostContext, key sc return nil, err } - providerName := provider.GetProviderKey() - - if !ensureBedrockKeyConfig(&key) { - provider.logger.Error("bedrock key config is not provided") - return nil, providerUtils.NewConfigurationError("bedrock key config is not provided", providerName) - } - // Require RoleArn in extra params 
roleArn := "" // First we will honor the role_arn coming from the client side if present @@ -2779,14 +2558,14 @@ func (provider *BedrockProvider) BatchCreate(ctx *schemas.BifrostContext, key sc } // If its empty then we will honor the role_arn from the key config if roleArn == "" { - if key.BedrockKeyConfig.ARN != nil { - roleArn = key.BedrockKeyConfig.ARN.GetValue() + if key.BedrockKeyConfig.RoleARN != nil { + roleArn = key.BedrockKeyConfig.RoleARN.GetValue() } } // And if still we don't get role ARN if roleArn == "" { provider.logger.Error("role_arn is required for Bedrock batch API (provide in extra_params)") - return nil, providerUtils.NewBifrostOperationError("role_arn is required for Bedrock batch API (provide in extra_params)", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("role_arn is required for Bedrock batch API (provide in extra_params)", nil) } // Get output S3 URI from extra params outputS3Uri := "" @@ -2797,24 +2576,12 @@ func (provider *BedrockProvider) BatchCreate(ctx *schemas.BifrostContext, key sc } if outputS3Uri == "" { provider.logger.Error("output_s3_uri is required for Bedrock batch API (provide in extra_params)") - return nil, providerUtils.NewBifrostOperationError("output_s3_uri is required for Bedrock batch API (provide in extra_params)", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("output_s3_uri is required for Bedrock batch API (provide in extra_params)", nil) } if request.Model == nil { provider.logger.Error("model is required for Bedrock batch API") - return nil, providerUtils.NewBifrostOperationError("model is required for Bedrock batch API", nil, providerName) - } - - // Get model ID - - var modelID *string - if key.BedrockKeyConfig.Deployments != nil && request.Model != nil { - if deployment, ok := key.BedrockKeyConfig.Deployments[*request.Model]; ok { - modelID = schemas.Ptr(deployment) - } - } - if modelID == nil { - modelID = request.Model + return nil, 
providerUtils.NewBifrostOperationError("model is required for Bedrock batch API", nil) } // Generate job name @@ -2842,9 +2609,9 @@ func (provider *BedrockProvider) BatchCreate(ctx *schemas.BifrostContext, key sc } // Convert inline requests to Bedrock JSONL format - jsonlData, err := ConvertBedrockRequestsToJSONL(request.Requests, modelID) + jsonlData, err := ConvertBedrockRequestsToJSONL(request.Requests, request.Model) if err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to convert requests to JSONL", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to convert requests to JSONL", err) } // Generate S3 key for the input file @@ -2864,7 +2631,6 @@ func (provider *BedrockProvider) BatchCreate(ctx *schemas.BifrostContext, key sc bucket, s3Key, jsonlData, - providerName, ); bifrostErr != nil { return nil, bifrostErr } @@ -2875,13 +2641,13 @@ func (provider *BedrockProvider) BatchCreate(ctx *schemas.BifrostContext, key sc // Validate that we have an input file ID (either provided or uploaded) if inputFileID == "" { provider.logger.Error("either input_file_id (S3 URI) or requests array is required for Bedrock batch API") - return nil, providerUtils.NewBifrostOperationError("either input_file_id (S3 URI) or requests array is required for Bedrock batch API", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("either input_file_id (S3 URI) or requests array is required for Bedrock batch API", nil) } // Build request bedrockReq := &BedrockBatchJobRequest{ JobName: jobName, - ModelID: modelID, + ModelID: request.Model, RoleArn: roleArn, InputDataConfig: BedrockInputDataConfig{ S3InputDataConfig: BedrockS3InputDataConfig{ @@ -2906,7 +2672,7 @@ func (provider *BedrockProvider) BatchCreate(ctx *schemas.BifrostContext, key sc jsonData, err := providerUtils.MarshalSorted(bedrockReq) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) 
+ return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } sendBackRawRequest := provider.sendBackRawRequest @@ -2921,11 +2687,11 @@ func (provider *BedrockProvider) BatchCreate(ctx *schemas.BifrostContext, key sc reqURL := fmt.Sprintf("https://bedrock.%s.amazonaws.com/model-invocation-job", region) httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, reqURL, bytes.NewBuffer(jsonData)) if err != nil { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("error creating request", err, providerName), jsonData, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("error creating request", err), jsonData, nil, sendBackRawRequest, sendBackRawResponse) } // Sign request - if err := signAWSRequest(ctx, httpReq, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, key.BedrockKeyConfig.SessionToken, key.BedrockKeyConfig.RoleARN, key.BedrockKeyConfig.ExternalID, key.BedrockKeyConfig.RoleSessionName, region, "bedrock", providerName); err != nil { + if err := signAWSRequest(ctx, httpReq, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, key.BedrockKeyConfig.SessionToken, key.BedrockKeyConfig.RoleARN, key.BedrockKeyConfig.ExternalID, key.BedrockKeyConfig.RoleSessionName, region, "bedrock"); err != nil { return nil, providerUtils.EnrichError(ctx, err, jsonData, nil, sendBackRawRequest, sendBackRawResponse) } @@ -2944,13 +2710,13 @@ func (provider *BedrockProvider) BatchCreate(ctx *schemas.BifrostContext, key sc }, }, jsonData, nil, sendBackRawRequest, sendBackRawResponse) } - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, providerName), jsonData, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err), jsonData, nil, 
sendBackRawRequest, sendBackRawResponse) } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("error reading response", err, providerName), jsonData, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("error reading response", err), jsonData, nil, sendBackRawRequest, sendBackRawResponse) } if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated { @@ -2959,7 +2725,7 @@ func (provider *BedrockProvider) BatchCreate(ctx *schemas.BifrostContext, key sc var bedrockResp BedrockBatchJobResponse if err := sonic.Unmarshal(body, &bedrockResp); err != nil { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerName), jsonData, body, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err), jsonData, body, sendBackRawRequest, sendBackRawResponse) } // AWS CreateModelInvocationJob only returns jobArn, not status or other details. 
@@ -2976,9 +2742,7 @@ func (provider *BedrockProvider) BatchCreate(ctx *schemas.BifrostContext, key sc InputFileID: inputFileID, Status: schemas.BatchStatusValidating, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.BatchCreateRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, }, nil } @@ -2991,9 +2755,7 @@ func (provider *BedrockProvider) BatchCreate(ctx *schemas.BifrostContext, key sc Status: retrieveResp.Status, CreatedAt: retrieveResp.CreatedAt, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.BatchCreateRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } @@ -3011,12 +2773,10 @@ func (provider *BedrockProvider) BatchList(ctx *schemas.BifrostContext, keys []s return nil, err } - providerName := provider.GetProviderKey() - // Initialize serial pagination helper (Bedrock uses PageToken for pagination) helper, err := providerUtils.NewSerialListHelper(keys, request.PageToken, provider.logger) if err != nil { - return nil, providerUtils.NewBifrostOperationError("invalid pagination cursor", err, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid pagination cursor", err) } // Get current key to query @@ -3027,17 +2787,9 @@ func (provider *BedrockProvider) BatchList(ctx *schemas.BifrostContext, keys []s Object: "list", Data: []schemas.BifrostBatchRetrieveResponse{}, HasMore: false, - ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.BatchListRequest, - Provider: providerName, - }, }, nil } - if !ensureBedrockKeyConfig(&key) { - return nil, providerUtils.NewConfigurationError("bedrock key config is not provided", providerName) - } - region := DefaultBedrockRegion if key.BedrockKeyConfig.Region != nil && key.BedrockKeyConfig.Region.GetValue() != "" { region = key.BedrockKeyConfig.Region.GetValue() @@ -3060,11 +2812,11 @@ func (provider *BedrockProvider) BatchList(ctx 
*schemas.BifrostContext, keys []s httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, reqURL, nil) if err != nil { - return nil, providerUtils.NewBifrostOperationError("error creating request", err, providerName) + return nil, providerUtils.NewBifrostOperationError("error creating request", err) } // Sign request - if bifrostErr := signAWSRequest(ctx, httpReq, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, key.BedrockKeyConfig.SessionToken, key.BedrockKeyConfig.RoleARN, key.BedrockKeyConfig.ExternalID, key.BedrockKeyConfig.RoleSessionName, region, "bedrock", providerName); bifrostErr != nil { + if bifrostErr := signAWSRequest(ctx, httpReq, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, key.BedrockKeyConfig.SessionToken, key.BedrockKeyConfig.RoleARN, key.BedrockKeyConfig.ExternalID, key.BedrockKeyConfig.RoleSessionName, region, "bedrock"); bifrostErr != nil { return nil, bifrostErr } @@ -3083,13 +2835,13 @@ func (provider *BedrockProvider) BatchList(ctx *schemas.BifrostContext, keys []s }, } } - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err) } body, err := io.ReadAll(resp.Body) resp.Body.Close() if err != nil { - return nil, providerUtils.NewBifrostOperationError("error reading response", err, providerName) + return nil, providerUtils.NewBifrostOperationError("error reading response", err) } if resp.StatusCode != http.StatusOK { @@ -3098,7 +2850,7 @@ func (provider *BedrockProvider) BatchList(ctx *schemas.BifrostContext, keys []s var bedrockResp BedrockBatchJobListResponse if err := sonic.Unmarshal(body, &bedrockResp); err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err) } // Convert batches to Bifrost format @@ -3143,9 
+2895,7 @@ func (provider *BedrockProvider) BatchList(ctx *schemas.BifrostContext, keys []s Data: batches, HasMore: hasMore, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.BatchListRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } if nextCursor != "" { @@ -3185,7 +2935,7 @@ func (provider *BedrockProvider) fetchBatchManifest(ctx *schemas.BifrostContext, } // Sign request for S3 - if err := signAWSRequest(ctx, httpReq, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, key.BedrockKeyConfig.SessionToken, key.BedrockKeyConfig.RoleARN, key.BedrockKeyConfig.ExternalID, key.BedrockKeyConfig.RoleSessionName, region, "s3", provider.GetProviderKey()); err != nil { + if err := signAWSRequest(ctx, httpReq, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, key.BedrockKeyConfig.SessionToken, key.BedrockKeyConfig.RoleARN, key.BedrockKeyConfig.ExternalID, key.BedrockKeyConfig.RoleSessionName, region, "s3"); err != nil { provider.logger.Error("failed to sign manifest request: %v", err) return nil } @@ -3223,19 +2973,12 @@ func (provider *BedrockProvider) BatchRetrieve(ctx *schemas.BifrostContext, keys return nil, err } - providerName := provider.GetProviderKey() - if request.BatchID == "" { - return nil, providerUtils.NewBifrostOperationError("batch_id (job ARN) is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("batch_id (job ARN) is required", nil) } var lastErr *schemas.BifrostError for _, key := range keys { - if !ensureBedrockKeyConfig(&key) { - lastErr = providerUtils.NewConfigurationError("bedrock key config is not provided", providerName) - continue - } - region := DefaultBedrockRegion if key.BedrockKeyConfig.Region != nil && key.BedrockKeyConfig.Region.GetValue() != "" { region = key.BedrockKeyConfig.Region.GetValue() @@ -3247,12 +2990,12 @@ func (provider *BedrockProvider) BatchRetrieve(ctx *schemas.BifrostContext, keys 
httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, reqURL, nil) if err != nil { - lastErr = providerUtils.NewBifrostOperationError("error creating request", err, providerName) + lastErr = providerUtils.NewBifrostOperationError("error creating request", err) continue } // Sign request - if err := signAWSRequest(ctx, httpReq, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, key.BedrockKeyConfig.SessionToken, key.BedrockKeyConfig.RoleARN, key.BedrockKeyConfig.ExternalID, key.BedrockKeyConfig.RoleSessionName, region, "bedrock", providerName); err != nil { + if err := signAWSRequest(ctx, httpReq, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, key.BedrockKeyConfig.SessionToken, key.BedrockKeyConfig.RoleARN, key.BedrockKeyConfig.ExternalID, key.BedrockKeyConfig.RoleSessionName, region, "bedrock"); err != nil { lastErr = err continue } @@ -3272,14 +3015,14 @@ func (provider *BedrockProvider) BatchRetrieve(ctx *schemas.BifrostContext, keys }, } } - lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, providerName) + lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err) continue } body, err := io.ReadAll(resp.Body) resp.Body.Close() if err != nil { - lastErr = providerUtils.NewBifrostOperationError("error reading response", err, providerName) + lastErr = providerUtils.NewBifrostOperationError("error reading response", err) continue } @@ -3290,7 +3033,7 @@ func (provider *BedrockProvider) BatchRetrieve(ctx *schemas.BifrostContext, keys var bedrockResp BedrockBatchJobResponse if err := sonic.Unmarshal(body, &bedrockResp); err != nil { - lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerName) + lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err) continue } @@ -3309,9 +3052,7 @@ func (provider *BedrockProvider) BatchRetrieve(ctx *schemas.BifrostContext, keys Status: 
ToBifrostBatchStatus(bedrockResp.Status), Metadata: metadata, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.BatchRetrieveRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } @@ -3367,19 +3108,12 @@ func (provider *BedrockProvider) BatchCancel(ctx *schemas.BifrostContext, keys [ return nil, err } - providerName := provider.GetProviderKey() - if request.BatchID == "" { - return nil, providerUtils.NewBifrostOperationError("batch_id (job ARN) is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("batch_id (job ARN) is required", nil) } var lastErr *schemas.BifrostError for _, key := range keys { - if !ensureBedrockKeyConfig(&key) { - lastErr = providerUtils.NewConfigurationError("bedrock key config is not provided", providerName) - continue - } - region := DefaultBedrockRegion if key.BedrockKeyConfig.Region != nil && key.BedrockKeyConfig.Region.GetValue() != "" { region = key.BedrockKeyConfig.Region.GetValue() @@ -3391,12 +3125,12 @@ func (provider *BedrockProvider) BatchCancel(ctx *schemas.BifrostContext, keys [ httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, reqURL, nil) if err != nil { - lastErr = providerUtils.NewBifrostOperationError("error creating request", err, providerName) + lastErr = providerUtils.NewBifrostOperationError("error creating request", err) continue } // Sign request - if err := signAWSRequest(ctx, httpReq, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, key.BedrockKeyConfig.SessionToken, key.BedrockKeyConfig.RoleARN, key.BedrockKeyConfig.ExternalID, key.BedrockKeyConfig.RoleSessionName, region, "bedrock", providerName); err != nil { + if err := signAWSRequest(ctx, httpReq, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, key.BedrockKeyConfig.SessionToken, key.BedrockKeyConfig.RoleARN, key.BedrockKeyConfig.ExternalID, key.BedrockKeyConfig.RoleSessionName, region, "bedrock"); err != nil 
{ lastErr = err continue } @@ -3416,14 +3150,14 @@ func (provider *BedrockProvider) BatchCancel(ctx *schemas.BifrostContext, keys [ }, } } - lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, providerName) + lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err) continue } body, err := io.ReadAll(resp.Body) resp.Body.Close() if err != nil { - lastErr = providerUtils.NewBifrostOperationError("error reading response", err, providerName) + lastErr = providerUtils.NewBifrostOperationError("error reading response", err) continue } @@ -3446,9 +3180,7 @@ func (provider *BedrockProvider) BatchCancel(ctx *schemas.BifrostContext, keys [ Object: "batch", Status: schemas.BatchStatusCancelling, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.BatchCancelRequest, - Provider: providerName, - Latency: totalLatency.Milliseconds(), + Latency: totalLatency.Milliseconds(), }, }, nil } @@ -3458,9 +3190,7 @@ func (provider *BedrockProvider) BatchCancel(ctx *schemas.BifrostContext, keys [ Object: "batch", Status: retrieveResp.Status, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.BatchCancelRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, }, nil } @@ -3481,8 +3211,6 @@ func (provider *BedrockProvider) BatchResults(ctx *schemas.BifrostContext, keys return nil, err } - providerName := provider.GetProviderKey() - // First, retrieve the batch to get the output S3 URI prefix (using all keys) batchResp, bifrostErr := provider.BatchRetrieve(ctx, keys, &schemas.BifrostBatchRetrieveRequest{ Provider: request.Provider, @@ -3493,7 +3221,7 @@ func (provider *BedrockProvider) BatchResults(ctx *schemas.BifrostContext, keys } if batchResp.OutputFileID == nil || *batchResp.OutputFileID == "" { - return nil, providerUtils.NewBifrostOperationError("batch results not available: output S3 URI is empty (batch may not be completed)", nil, 
providerName) + return nil, providerUtils.NewBifrostOperationError("batch results not available: output S3 URI is empty (batch may not be completed)", nil) } outputS3URI := *batchResp.OutputFileID @@ -3535,7 +3263,7 @@ func (provider *BedrockProvider) BatchResults(ctx *schemas.BifrostContext, keys if directErr != nil { return nil, providerUtils.NewBifrostOperationError( fmt.Sprintf("failed to access batch results at %s: listing failed and direct access failed", outputS3URI), - nil, providerName) + nil) } // Direct download succeeded, parse the content @@ -3544,9 +3272,7 @@ func (provider *BedrockProvider) BatchResults(ctx *schemas.BifrostContext, keys BatchID: request.BatchID, Results: results, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.BatchResultsRequest, - Provider: providerName, - Latency: fileContentResp.ExtraFields.Latency, + Latency: fileContentResp.ExtraFields.Latency, }, } if len(parseErrors) > 0 { @@ -3579,9 +3305,7 @@ func (provider *BedrockProvider) BatchResults(ctx *schemas.BifrostContext, keys BatchID: request.BatchID, Results: allResults, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.BatchResultsRequest, - Provider: providerName, - Latency: totalLatency, + Latency: totalLatency, }, } @@ -3592,26 +3316,14 @@ func (provider *BedrockProvider) BatchResults(ctx *schemas.BifrostContext, keys return batchResultsResp, nil } -func (provider *BedrockProvider) getModelPath(basePath string, model string, key schemas.Key) (string, string) { - deployment := resolveBedrockDeployment(model, key) - // Default: use model/deployment directly - path := fmt.Sprintf("%s/%s", deployment, basePath) +func (provider *BedrockProvider) getModelPath(basePath string, model string, key schemas.Key) string { + path := fmt.Sprintf("%s/%s", model, basePath) // If ARN is present, Bedrock expects the ARN-scoped identifier if key.BedrockKeyConfig != nil && key.BedrockKeyConfig.ARN != nil && key.BedrockKeyConfig.ARN.GetValue() != "" 
{ - encodedModelIdentifier := url.PathEscape(fmt.Sprintf("%s/%s", key.BedrockKeyConfig.ARN.GetValue(), deployment)) + encodedModelIdentifier := url.PathEscape(fmt.Sprintf("%s/%s", key.BedrockKeyConfig.ARN.GetValue(), model)) path = fmt.Sprintf("%s/%s", encodedModelIdentifier, basePath) } - return path, deployment -} - -func resolveBedrockDeployment(model string, key schemas.Key) string { - deployment := model - if key.BedrockKeyConfig != nil && key.BedrockKeyConfig.Deployments != nil { - if mapped, ok := key.BedrockKeyConfig.Deployments[model]; ok && mapped != "" { - deployment = mapped - } - } - return deployment + return path } func (provider *BedrockProvider) CountTokens(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostCountTokensResponse, *schemas.BifrostError) { @@ -3619,16 +3331,10 @@ func (provider *BedrockProvider) CountTokens(ctx *schemas.BifrostContext, key sc return nil, err } - providerName := provider.GetProviderKey() - - if !ensureBedrockKeyConfig(&key) { - return nil, providerUtils.NewConfigurationError("bedrock key config is not provided", providerName) - } - // Convert to Bedrock Converse format using the existing responses converter converseReq, convErr := ToBedrockResponsesRequest(ctx, request) if convErr != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, convErr, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, convErr) } // Wrap in the CountTokens request envelope @@ -3637,11 +3343,11 @@ func (provider *BedrockProvider) CountTokens(ctx *schemas.BifrostContext, key sc jsonData, err := providerUtils.MarshalSorted(countTokensReq) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } // Format the path with proper model identifier - 
path, deployment := provider.getModelPath("count-tokens", request.Model, key) + path := provider.getModelPath("count-tokens", request.Model, key) // Send the request responseBody, latency, providerResponseHeaders, bifrostErr := provider.completeRequest(ctx, jsonData, path, key) @@ -3652,15 +3358,11 @@ func (provider *BedrockProvider) CountTokens(ctx *schemas.BifrostContext, key sc if isCountTokensUnsupported(bifrostErr) { estimated := estimateTokenCount(jsonData) return &schemas.BifrostCountTokensResponse{ - Model: deployment, + Model: request.Model, InputTokens: estimated, TotalTokens: &estimated, Object: "response.input_tokens", ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: providerName, - RequestType: schemas.CountTokensRequest, - ModelRequested: request.Model, - ModelDeployment: deployment, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerResponseHeaders, }, @@ -3683,15 +3385,10 @@ func (provider *BedrockProvider) CountTokens(ctx *schemas.BifrostContext, key sc } // Convert to Bifrost format - response := bedrockResponse.ToBifrostCountTokensResponse(deployment) + response := bedrockResponse.ToBifrostCountTokensResponse(request.Model) - response.ExtraFields.Provider = providerName - response.ExtraFields.RequestType = schemas.CountTokensRequest - response.ExtraFields.ModelRequested = request.Model - response.ExtraFields.ModelDeployment = deployment response.ExtraFields.Latency = latency.Milliseconds() response.ExtraFields.ProviderResponseHeaders = providerResponseHeaders - if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { response.ExtraFields.RawRequest = rawRequest } diff --git a/core/providers/bedrock/bedrock_test.go b/core/providers/bedrock/bedrock_test.go index 1949051c44..9e43a176e3 100644 --- a/core/providers/bedrock/bedrock_test.go +++ b/core/providers/bedrock/bedrock_test.go @@ -3626,17 +3626,17 @@ func TestToBedrockInvokeMessagesStreamResponse_NoDuplicateContentBlockStop(t *te { Type: 
schemas.ResponsesStreamResponseTypeOutputTextDone, ContentIndex: &contentIdx, - ExtraFields: schemas.BifrostResponseExtraFields{ModelRequested: model}, + ExtraFields: schemas.BifrostResponseExtraFields{OriginalModelRequested: model}, }, { Type: schemas.ResponsesStreamResponseTypeContentPartDone, ContentIndex: &contentIdx, - ExtraFields: schemas.BifrostResponseExtraFields{ModelRequested: model}, + ExtraFields: schemas.BifrostResponseExtraFields{OriginalModelRequested: model}, }, { Type: schemas.ResponsesStreamResponseTypeOutputItemDone, ContentIndex: &contentIdx, - ExtraFields: schemas.BifrostResponseExtraFields{ModelRequested: model}, + ExtraFields: schemas.BifrostResponseExtraFields{OriginalModelRequested: model}, }, } diff --git a/core/providers/bedrock/chat.go b/core/providers/bedrock/chat.go index 71e7890935..6459df377b 100644 --- a/core/providers/bedrock/chat.go +++ b/core/providers/bedrock/chat.go @@ -247,8 +247,6 @@ func (response *BedrockConverseResponse) ToBifrostChatResponse(ctx context.Conte Usage: usage, Created: int(time.Now().Unix()), ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ChatCompletionRequest, - Provider: schemas.Bedrock, }, } diff --git a/core/providers/bedrock/invoke.go b/core/providers/bedrock/invoke.go index c21520edc4..29ff971858 100644 --- a/core/providers/bedrock/invoke.go +++ b/core/providers/bedrock/invoke.go @@ -448,9 +448,16 @@ func ToBedrockInvokeMessagesResponse(ctx *schemas.BifrostContext, resp *schemas. 
return nil, fmt.Errorf("bifrost response is nil") } - model := resp.Model - if resp.ExtraFields.ModelRequested != "" { - model = resp.ExtraFields.ModelRequested + model := "" + if resp.Model != "" { + model = resp.Model + } else { + extraFields := resp.ExtraFields + if extraFields.ResolvedModelUsed != "" { + model = extraFields.ResolvedModelUsed + } else if extraFields.OriginalModelRequested != "" { + model = extraFields.OriginalModelRequested + } } // Nova models: delegate to existing ToBedrockConverseResponse (Nova InvokeModel matches Converse format) @@ -623,12 +630,17 @@ func ToBedrockInvokeMessagesStreamResponse(ctx *schemas.BifrostContext, resp *sc // final Completed event). Without checking resp.ExtraFields, early chunks would // have model="" and Nova streams would be mis-routed through the Anthropic path. model := "" - if resp.ExtraFields.ModelRequested != "" { - model = resp.ExtraFields.ModelRequested - } else if resp.Response != nil && resp.Response.ExtraFields.ModelRequested != "" { - model = resp.Response.ExtraFields.ModelRequested - } else if resp.Response != nil && resp.Response.Model != "" { - model = resp.Response.Model + if resp.Response != nil { + if resp.Response.Model != "" { + model = resp.Response.Model + } else { + extraFields := resp.Response.ExtraFields + if extraFields.ResolvedModelUsed != "" { + model = extraFields.ResolvedModelUsed + } else if extraFields.OriginalModelRequested != "" { + model = extraFields.OriginalModelRequested + } + } } // Nova models: delegate to existing converse stream response (same format) @@ -666,8 +678,11 @@ func toAnthropicInvokeStreamBytes(resp *schemas.BifrostResponsesStreamResponse) switch resp.Type { case schemas.ResponsesStreamResponseTypeCreated: - // message_start — use ExtraFields.ModelRequested as fallback for early chunks - model := resp.ExtraFields.ModelRequested + // message_start — prefer resolved model for accurate family detection on early chunks + model := resp.ExtraFields.ResolvedModelUsed + 
if model == "" { + model = resp.ExtraFields.OriginalModelRequested + } msgStart := map[string]interface{}{ "type": "message_start", "message": map[string]interface{}{ @@ -777,7 +792,7 @@ func toAnthropicInvokeStreamBytes(resp *schemas.BifrostResponsesStreamResponse) "type": "content_block_delta", "index": idx, "delta": map[string]interface{}{ - "type": "input_json_delta", + "type": "input_json_delta", "partial_json": *resp.Delta, }, } diff --git a/core/providers/bedrock/models.go b/core/providers/bedrock/models.go index 005998aa4a..6d2f9006f2 100644 --- a/core/providers/bedrock/models.go +++ b/core/providers/bedrock/models.go @@ -3,6 +3,7 @@ package bedrock import ( "strings" + providerUtils "github.com/maximhq/bifrost/core/providers/utils" "github.com/maximhq/bifrost/core/schemas" ) @@ -80,175 +81,8 @@ type BedrockRerankResponseDocument struct { TextDocument *BedrockRerankTextValue `json:"textDocument,omitempty"` } -// regionPrefixes is a list of region prefixes used in Bedrock deployments -// Based on AWS region naming patterns and Bedrock deployment configurations -var regionPrefixes = []string{ - "us.", // US regions (us-east-1, us-west-2, etc.) - "eu.", // Europe regions (eu-west-1, eu-central-1, etc.) - "ap.", // Asia Pacific regions (ap-southeast-1, ap-northeast-1, etc.) - "ca.", // Canada regions (ca-central-1, etc.) - "sa.", // South America regions (sa-east-1, etc.) - "af.", // Africa regions (af-south-1, etc.) - "global.", // Global deployment prefix -} - -// extractPrefix extracts the region prefix ending with '.' from a string -// Only recognizes common region prefixes like "us.", "global.", "eu.", etc. -// Returns the prefix (including the dot) if found, empty string otherwise -func extractPrefix(s string) string { - for _, prefix := range regionPrefixes { - if strings.HasPrefix(s, prefix) { - return prefix - } - } - return "" -} - -// removePrefix removes any region prefix ending with '.' 
from a string -// Only removes common region prefixes like "us.", "global.", "eu.", etc. -// Returns the string without the prefix -func removePrefix(s string) string { - for _, prefix := range regionPrefixes { - if strings.HasPrefix(s, prefix) { - return s[len(prefix):] - } - } - return s -} - -// findMatchingAllowedModel finds a matching item in a whitelist, considering both -// exact match and match with/without region prefixes (e.g., "global.", "us.", "eu."), -// and also checks base model matches (ignoring version suffixes). -// Returns the matched item from the whitelist if found, empty string otherwise. -// If matched via base model, returns the item from whitelist (not the value parameter). -func findMatchingAllowedModel(wl schemas.WhiteList, value string) string { - // First check exact matches (case-insensitive) - if wl.Contains(value) { - return value - } - - // Check with region prefix added/removed - valuePrefix := extractPrefix(value) - if valuePrefix != "" { - // value has a prefix, check if whitelist contains version without prefix - withoutPrefix := removePrefix(value) - if wl.Contains(withoutPrefix) { - return withoutPrefix - } - } - - // Check if any item in whitelist has a prefix that matches value without prefix - for _, item := range wl { - itemPrefix := extractPrefix(item) - if itemPrefix != "" { - // item has prefix, check if value matches without the prefix - itemWithoutPrefix := removePrefix(item) - if itemWithoutPrefix == value { - return item - } - } - } - - // Additional layer: check base model matches (ignoring version suffixes) - // This handles cases where model versions differ but base model is the same - // Normalize value by removing any region prefix for base model comparison - valueNormalized := removePrefix(value) - - for _, item := range wl { - // Normalize item by removing any region prefix for base model comparison - itemNormalized := removePrefix(item) - - // Check base model match with normalized values (prefix removed 
from both) - // Return the item from whitelist (not value) to use the actual name from allowedModels - if schemas.SameBaseModel(itemNormalized, valueNormalized) { - return item - } - } - return "" -} - -// findDeploymentMatch finds a matching deployment value in the deployments map, -// considering both exact match and match with/without region prefixes (e.g., "global.", "us.", "eu."), -// and also checks base model matches (ignoring version suffixes). -// The modelID from the API response should match a deployment value (not the alias/key). -// Returns the deployment value and alias if found, empty strings otherwise. -func findDeploymentMatch(deployments map[string]string, modelID string) (deploymentValue, alias string) { - // Check if any deployment value matches the modelID (with or without prefix) - for aliasKey, deploymentValue := range deployments { - // Exact match - if deploymentValue == modelID || aliasKey == modelID { - return deploymentValue, aliasKey - } - - // Check prefix variations - deploymentPrefix := extractPrefix(deploymentValue) - modelIDPrefix := extractPrefix(modelID) - aliasKeyPrefix := extractPrefix(aliasKey) - - // Case 1: deploymentValue or aliasKey has prefix, modelID doesn't - if (deploymentPrefix != "" && modelIDPrefix == "") || (aliasKeyPrefix != "" && modelIDPrefix == "") { - if removePrefix(deploymentValue) == modelID || removePrefix(aliasKey) == modelID { - return deploymentValue, aliasKey - } - } - - // Case 2: modelID or aliasKey has prefix, deploymentValue doesn't - if (modelIDPrefix != "" && deploymentPrefix == "") || (aliasKeyPrefix != "" && deploymentPrefix == "") { - if removePrefix(modelID) == deploymentValue || removePrefix(modelID) == aliasKey { - return deploymentValue, aliasKey - } - } - - // Case 3: Both have prefixes but different prefixes - if (deploymentPrefix != "" && modelIDPrefix != "" && deploymentPrefix != modelIDPrefix) || (aliasKeyPrefix != "" && modelIDPrefix != "" && aliasKeyPrefix != modelIDPrefix) { - if 
removePrefix(deploymentValue) == removePrefix(modelID) || removePrefix(aliasKey) == removePrefix(modelID) { - return deploymentValue, aliasKey - } - } - - // Additional layer: check base model matches (ignoring version suffixes) - // This handles cases where model versions differ but base model is the same - // Normalize both values by removing any region prefix for base model comparison - deploymentNormalized := removePrefix(deploymentValue) - modelIDNormalized := removePrefix(modelID) - // Check base model match with normalized values (prefix removed from both) - if schemas.SameBaseModel(deploymentNormalized, modelIDNormalized) { - return deploymentValue, aliasKey - } - } - return "", "" -} - -// matchesBlacklist reports whether modelID matches any entry in the blacklist, -// using the same matching logic as findMatchingAllowedModel (exact, prefix-normalized, base-model). -func matchesBlacklist(bl schemas.BlackList, modelID string) bool { - if bl.IsEmpty() { - return false - } - if bl.Contains(modelID) { - return true - } - if extractPrefix(modelID) != "" { - if bl.Contains(removePrefix(modelID)) { - return true - } - } - for _, item := range bl { - if extractPrefix(item) != "" && removePrefix(item) == modelID { - return true - } - } - valueNormalized := removePrefix(modelID) - for _, item := range bl { - if schemas.SameBaseModel(removePrefix(item), valueNormalized) { - return true - } - } - return false -} - -func (response *BedrockListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels schemas.WhiteList, blacklistedModels schemas.BlackList, deployments map[string]string, unfiltered bool) *schemas.BifrostListModelsResponse { +func (response *BedrockListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels schemas.WhiteList, blacklistedModels schemas.BlackList, aliases map[string]string, unfiltered bool) *schemas.BifrostListModelsResponse { if response == nil { return nil } @@ 
-257,127 +91,41 @@ func (response *BedrockListModelsResponse) ToBifrostListModelsResponse(providerK Data: make([]schemas.Model, 0, len(response.ModelSummaries)), } - if !unfiltered && (allowedModels.IsEmpty() && len(deployments) == 0 || blacklistedModels.IsBlockAll()) { - return bifrostResponse + pipeline := &providerUtils.ListModelsPipeline{ + AllowedModels: allowedModels, + BlacklistedModels: blacklistedModels, + Aliases: aliases, + Unfiltered: unfiltered, + ProviderKey: providerKey, + MatchFns: providerUtils.DefaultMatchFns(), } - - deploymentValues := make([]string, 0, len(deployments)) - for _, deployment := range deployments { - deploymentValues = append(deploymentValues, deployment) + if pipeline.ShouldEarlyExit() { + return bifrostResponse } - restrictAllowed := !unfiltered && allowedModels.IsRestricted() + included := make(map[string]bool) - includedModels := make(map[string]bool) for _, model := range response.ModelSummaries { - modelID := model.ModelID - matchedAllowedModel := "" - deploymentValue := "" - deploymentAlias := "" - - // Filter if model is not present in both lists (when both are non-empty) - // Empty lists mean "allow all" for that dimension - // Check considering global prefix variations - shouldFilter := false - if restrictAllowed && len(deploymentValues) > 0 { - // Both lists are present: model must be in allowedModels AND deployments - // AND the deployment alias must also be in allowedModels - matchedAllowedModel = findMatchingAllowedModel(allowedModels, model.ModelID) - deploymentValue, deploymentAlias = findDeploymentMatch(deployments, model.ModelID) - inDeployments := deploymentAlias != "" - - // Check if deployment alias is also in allowedModels (direct string match) - deploymentAliasInAllowedModels := false - if deploymentAlias != "" { - deploymentAliasInAllowedModels = allowedModels.Contains(deploymentAlias) + for _, result := range pipeline.FilterModel(model.ModelID) { + modelEntry := schemas.Model{ + ID: string(providerKey) + 
"/" + result.ResolvedID, + Name: schemas.Ptr(model.ModelName), + OwnedBy: schemas.Ptr(model.ProviderName), + Architecture: &schemas.Architecture{ + InputModalities: model.InputModalities, + OutputModalities: model.OutputModalities, + }, } - - // Filter if: model not in deployments OR deployment alias not in allowedModels - shouldFilter = !inDeployments || !deploymentAliasInAllowedModels - } else if restrictAllowed { - // Only allowedModels is present: filter if model is not in allowedModels - matchedAllowedModel = findMatchingAllowedModel(allowedModels, model.ModelID) - shouldFilter = matchedAllowedModel == "" - } else if !unfiltered && len(deploymentValues) > 0 { - // Only deployments is present: filter if model is not in deployments - deploymentValue, deploymentAlias = findDeploymentMatch(deployments, model.ModelID) - shouldFilter = deploymentValue == "" - } - // If both are empty (or allowedModels is unrestricted and no deployments), shouldFilter remains false - - if shouldFilter { - continue - } - if !unfiltered && (matchesBlacklist(blacklistedModels, model.ModelID) || - (deploymentAlias != "" && matchesBlacklist(blacklistedModels, deploymentAlias))) { - continue - } - - // Use the matched name from allowedModels or deployments (like Anthropic) - // Priority: deployment value > matched allowedModel > original model.ModelID - if deploymentValue != "" { - modelID = deploymentValue - } else if matchedAllowedModel != "" { - modelID = matchedAllowedModel - } - - modelEntry := schemas.Model{ - ID: string(providerKey) + "/" + modelID, - Name: schemas.Ptr(model.ModelName), - OwnedBy: schemas.Ptr(model.ProviderName), - Architecture: &schemas.Architecture{ - InputModalities: model.InputModalities, - OutputModalities: model.OutputModalities, - }, - } - // Set deployment info if matched via deployments - if deploymentValue != "" && deploymentAlias != "" { - modelEntry.ID = string(providerKey) + "/" + deploymentAlias - // Use the actual deployment value (which might have 
global prefix) - modelEntry.Deployment = schemas.Ptr(deploymentValue) - includedModels[strings.ToLower(deploymentAlias)] = true - } else { - includedModels[strings.ToLower(modelID)] = true - } - bifrostResponse.Data = append(bifrostResponse.Data, modelEntry) - } - - // Backfill deployments that were not matched from the API response - if !unfiltered && len(deployments) > 0 { - for alias, deploymentValue := range deployments { - if includedModels[strings.ToLower(alias)] { - continue - } - // If allowedModels is restricted, only include if alias is in the list - if restrictAllowed && !allowedModels.Contains(alias) { - continue - } - if !unfiltered && matchesBlacklist(blacklistedModels, alias) { - continue + if result.AliasValue != "" { + modelEntry.Alias = schemas.Ptr(result.AliasValue) } - bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ - ID: string(providerKey) + "/" + alias, - Name: schemas.Ptr(alias), - Deployment: schemas.Ptr(deploymentValue), - }) - includedModels[strings.ToLower(alias)] = true + bifrostResponse.Data = append(bifrostResponse.Data, modelEntry) + included[strings.ToLower(result.ResolvedID)] = true } } - // Backfill allowed models that were not in the response - if restrictAllowed { - for _, allowedModel := range allowedModels { - if matchesBlacklist(blacklistedModels, allowedModel) { - continue - } - if !includedModels[strings.ToLower(allowedModel)] { - bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ - ID: string(providerKey) + "/" + allowedModel, - Name: schemas.Ptr(allowedModel), - }) - } - } - } + bifrostResponse.Data = append(bifrostResponse.Data, + pipeline.BackfillModels(included)...) 
return bifrostResponse } diff --git a/core/providers/bedrock/rerank_test.go b/core/providers/bedrock/rerank_test.go index 0dff5c3ee2..c1b7bb5480 100644 --- a/core/providers/bedrock/rerank_test.go +++ b/core/providers/bedrock/rerank_test.go @@ -195,27 +195,23 @@ func TestBedrockRerankRequestToBifrostRerankRequestNil(t *testing.T) { func TestResolveBedrockDeployment(t *testing.T) { key := schemas.Key{ - BedrockKeyConfig: &schemas.BedrockKeyConfig{ - Deployments: map[string]string{ - "cohere-rerank": "arn:aws:bedrock:us-east-1::foundation-model/cohere.rerank-v3-5:0", - }, + Aliases: schemas.KeyAliases{ + "cohere-rerank": "arn:aws:bedrock:us-east-1::foundation-model/cohere.rerank-v3-5:0", }, } - deployment := resolveBedrockDeployment("cohere-rerank", key) + deployment := key.Aliases.Resolve("cohere-rerank") assert.Equal(t, "arn:aws:bedrock:us-east-1::foundation-model/cohere.rerank-v3-5:0", deployment) - assert.Equal(t, "cohere.rerank-v3-5:0", resolveBedrockDeployment("cohere.rerank-v3-5:0", key)) - assert.Equal(t, "", resolveBedrockDeployment("", key)) + assert.Equal(t, "cohere.rerank-v3-5:0", key.Aliases.Resolve("cohere.rerank-v3-5:0")) + assert.Equal(t, "", key.Aliases.Resolve("")) } func TestBedrockRerankRequiresARNModelIdentifier(t *testing.T) { provider := &BedrockProvider{} ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) key := schemas.Key{ - BedrockKeyConfig: &schemas.BedrockKeyConfig{ - Deployments: map[string]string{ - "cohere-rerank": "cohere.rerank-v3-5:0", - }, + Aliases: schemas.KeyAliases{ + "cohere-rerank": "cohere.rerank-v3-5:0", }, } diff --git a/core/providers/bedrock/s3.go b/core/providers/bedrock/s3.go index da06e5e820..be2d0afb32 100644 --- a/core/providers/bedrock/s3.go +++ b/core/providers/bedrock/s3.go @@ -22,7 +22,6 @@ func uploadToS3( region string, bucket, key string, content []byte, - providerName schemas.ModelProvider, ) *schemas.BifrostError { // Create AWS config with credentials var cfg aws.Config @@ -47,7 
+46,7 @@ func uploadToS3( } if err != nil { - return providerUtils.NewBifrostOperationError("failed to load AWS config for S3", err, providerName) + return providerUtils.NewBifrostOperationError("failed to load aws config for s3", err) } // Create S3 client @@ -62,7 +61,7 @@ func uploadToS3( }) if err != nil { - return providerUtils.NewBifrostOperationError(fmt.Sprintf("failed to upload to S3: %s/%s", bucket, key), err, providerName) + return providerUtils.NewBifrostOperationError(fmt.Sprintf("failed to upload to s3: %s/%s", bucket, key), err) } return nil diff --git a/core/providers/bedrock/signer.go b/core/providers/bedrock/signer.go index 9f12e3bbaf..b7e87ae8d2 100644 --- a/core/providers/bedrock/signer.go +++ b/core/providers/bedrock/signer.go @@ -280,17 +280,16 @@ func signAWSRequestFastHTTP( accessKey, secretKey string, sessionToken *string, region, service string, - providerName schemas.ModelProvider, ) *schemas.BifrostError { // Get AWS credentials if not provided if accessKey == "" && secretKey == "" { cfg, err := config.LoadDefaultConfig(ctx, config.WithRegion(region)) if err != nil { - return providerUtils.NewBifrostOperationError("failed to load aws config", err, providerName) + return providerUtils.NewBifrostOperationError("failed to load aws config", err) } creds, err := cfg.Credentials.Retrieve(ctx) if err != nil { - return providerUtils.NewBifrostOperationError("failed to retrieve aws credentials", err, providerName) + return providerUtils.NewBifrostOperationError("failed to retrieve aws credentials", err) } accessKey = creds.AccessKeyID secretKey = creds.SecretAccessKey diff --git a/core/providers/bedrock/text.go b/core/providers/bedrock/text.go index 6ad24ee1c8..d31d716ded 100644 --- a/core/providers/bedrock/text.go +++ b/core/providers/bedrock/text.go @@ -127,8 +127,6 @@ func (response *BedrockAnthropicTextResponse) ToBifrostTextCompletionResponse() }, }, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: 
schemas.TextCompletionRequest, - Provider: schemas.Bedrock, }, } } @@ -154,8 +152,6 @@ func (response *BedrockMistralTextResponse) ToBifrostTextCompletionResponse() *s Object: "text_completion", Choices: choices, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.TextCompletionRequest, - Provider: schemas.Bedrock, }, } } @@ -167,11 +163,14 @@ func ToBedrockTextCompletionResponse(bifrostResp *schemas.BifrostTextCompletionR return nil } - // Determine response format based on model - // Use ModelRequested from ExtraFields if available, otherwise use Model + // Determine response format based on resolved model identity. + // Use ResolvedModelUsed (actual provider ID) for accurate family detection, + // falling back to bifrostResp.Model, then OriginalModelRequested as a last resort. model := bifrostResp.Model - if bifrostResp.ExtraFields.ModelRequested != "" { - model = bifrostResp.ExtraFields.ModelRequested + if bifrostResp.ExtraFields.ResolvedModelUsed != "" { + model = bifrostResp.ExtraFields.ResolvedModelUsed + } else if model == "" && bifrostResp.ExtraFields.OriginalModelRequested != "" { + model = bifrostResp.ExtraFields.OriginalModelRequested } if strings.Contains(model, "anthropic.") || strings.Contains(model, "claude") { diff --git a/core/providers/cerebras/cerebras.go b/core/providers/cerebras/cerebras.go index e880087e8b..3a000a76af 100644 --- a/core/providers/cerebras/cerebras.go +++ b/core/providers/cerebras/cerebras.go @@ -178,9 +178,6 @@ func (provider *CerebrasProvider) Responses(ctx *schemas.BifrostContext, key sch } response := chatResponse.ToBifrostResponsesResponse() - response.ExtraFields.RequestType = schemas.ResponsesRequest - response.ExtraFields.Provider = provider.GetProviderKey() - response.ExtraFields.ModelRequested = request.Model return response, nil } diff --git a/core/providers/cohere/chat.go b/core/providers/cohere/chat.go index 33807a22bd..366102831e 100644 --- a/core/providers/cohere/chat.go +++ 
b/core/providers/cohere/chat.go @@ -367,8 +367,6 @@ func (response *CohereChatResponse) ToBifrostChatResponse(model string) *schemas }, Created: int(time.Now().Unix()), ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ChatCompletionRequest, - Provider: schemas.Cohere, }, } diff --git a/core/providers/cohere/cohere.go b/core/providers/cohere/cohere.go index 2986760329..b67ef8668e 100644 --- a/core/providers/cohere/cohere.go +++ b/core/providers/cohere/cohere.go @@ -155,7 +155,7 @@ func (provider *CohereProvider) buildRequestURL(ctx *schemas.BifrostContext, def // completeRequest sends a request to Cohere's API and handles the response. // It constructs the API URL, sets up authentication, and processes the response. // Returns the response body or an error if the request fails. -func (provider *CohereProvider) completeRequest(ctx *schemas.BifrostContext, jsonData []byte, url string, key string, meta *providerUtils.RequestMetadata) ([]byte, time.Duration, map[string]string, *schemas.BifrostError) { +func (provider *CohereProvider) completeRequest(ctx *schemas.BifrostContext, jsonData []byte, url string, key string) ([]byte, time.Duration, map[string]string, *schemas.BifrostError) { // Create the request with the JSON body req := fasthttp.AcquireRequest() resp := fasthttp.AcquireResponse() @@ -199,10 +199,10 @@ func (provider *CohereProvider) completeRequest(ctx *schemas.BifrostContext, jso // Handle error response if resp.StatusCode() != fasthttp.StatusOK { providerUtils.MaterializeStreamErrorBody(ctx, resp) - return nil, latency, providerResponseHeaders, parseCohereError(resp, meta) + return nil, latency, providerResponseHeaders, parseCohereError(resp) } - body, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, provider.GetProviderKey(), provider.logger) + body, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, provider.logger) if decodeErr != nil { return nil, latency, 
providerResponseHeaders, decodeErr } @@ -217,8 +217,6 @@ func (provider *CohereProvider) completeRequest(ctx *schemas.BifrostContext, jso // listModelsByKey performs a list models request for a single key. // Returns the response and latency, or an error if the request fails. func (provider *CohereProvider) listModelsByKey(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - // Create request req := fasthttp.AcquireRequest() resp := fasthttp.AcquireResponse() @@ -234,7 +232,7 @@ func (provider *CohereProvider) listModelsByKey(ctx *schemas.BifrostContext, key // Parse and add query parameters u, err := url.Parse(baseURL) if err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to parse request URL", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to parse request url", err) } q := u.Query() @@ -269,15 +267,12 @@ func (provider *CohereProvider) listModelsByKey(ctx *schemas.BifrostContext, key // Handle error response if resp.StatusCode() != fasthttp.StatusOK { - return nil, parseCohereError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - RequestType: schemas.ListModelsRequest, - }) + return nil, parseCohereError(resp) } body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } // Parse Cohere list models response @@ -288,7 +283,7 @@ func (provider *CohereProvider) listModelsByKey(ctx *schemas.BifrostContext, key } // Convert Cohere v2 response to Bifrost response - response := cohereResponse.ToBifrostListModelsResponse(providerName, key.Models, key.BlacklistedModels, request.Unfiltered) + response := 
cohereResponse.ToBifrostListModelsResponse(provider.GetProviderKey(), key.Models, key.BlacklistedModels, key.Aliases, request.Unfiltered) response.ExtraFields.Latency = latency.Milliseconds() @@ -352,17 +347,12 @@ func (provider *CohereProvider) ChatCompletion(ctx *schemas.BifrostContext, key request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToCohereChatCompletionRequest(request) - }, - provider.GetProviderKey()) + }) if err != nil { return nil, err } - responseBody, latency, providerResponseHeaders, err := provider.completeRequest(ctx, jsonBody, provider.buildRequestURL(ctx, "/v2/chat", schemas.ChatCompletionRequest), key.Value.GetValue(), &providerUtils.RequestMetadata{ - Provider: provider.GetProviderKey(), - Model: request.Model, - RequestType: schemas.ChatCompletionRequest, - }) + responseBody, latency, providerResponseHeaders, err := provider.completeRequest(ctx, jsonBody, provider.buildRequestURL(ctx, "/v2/chat", schemas.ChatCompletionRequest), key.Value.GetValue()) if providerResponseHeaders != nil { ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders) } @@ -375,9 +365,6 @@ func (provider *CohereProvider) ChatCompletion(ctx *schemas.BifrostContext, key return &schemas.BifrostChatResponse{ Model: request.Model, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: provider.GetProviderKey(), - ModelRequested: request.Model, - RequestType: schemas.ChatCompletionRequest, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerResponseHeaders, }, @@ -396,9 +383,6 @@ func (provider *CohereProvider) ChatCompletion(ctx *schemas.BifrostContext, key bifrostResponse := response.ToBifrostChatResponse(request.Model) // Set ExtraFields - bifrostResponse.ExtraFields.Provider = provider.GetProviderKey() - bifrostResponse.ExtraFields.ModelRequested = request.Model - bifrostResponse.ExtraFields.RequestType = schemas.ChatCompletionRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() 
bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -424,7 +408,6 @@ func (provider *CohereProvider) ChatCompletionStream(ctx *schemas.BifrostContext return nil, err } - providerName := provider.GetProviderKey() jsonBody, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, request, @@ -435,8 +418,7 @@ func (provider *CohereProvider) ChatCompletionStream(ctx *schemas.BifrostContext } reqBody.Stream = schemas.Ptr(true) return reqBody, nil - }, - provider.GetProviderKey()) + }) if bifrostErr != nil { return nil, bifrostErr } @@ -486,9 +468,9 @@ func (provider *CohereProvider) ChatCompletionStream(ctx *schemas.BifrostContext }, jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, providerName), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, providerName), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } // Extract provider response headers before status check so error responses also forward them @@ -497,11 +479,7 @@ func (provider *CohereProvider) ChatCompletionStream(ctx *schemas.BifrostContext // Check for HTTP errors if resp.StatusCode() != fasthttp.StatusOK { defer providerUtils.ReleaseStreamingResponse(resp) - return nil, providerUtils.EnrichError(ctx, parseCohereError(resp, &providerUtils.RequestMetadata{ - Provider: 
providerName, - Model: request.Model, - RequestType: schemas.ChatCompletionStreamRequest, - }), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, parseCohereError(resp), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } // Large payload streaming passthrough — pipe raw upstream SSE to client @@ -520,9 +498,9 @@ func (provider *CohereProvider) ChatCompletionStream(ctx *schemas.BifrostContext go func() { defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.ChatCompletionStreamRequest, provider.logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.logger) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.ChatCompletionStreamRequest, provider.logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.logger) } close(responseChan) }() @@ -560,7 +538,7 @@ func (provider *CohereProvider) ChatCompletionStream(ctx *schemas.BifrostContext } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) provider.logger.Warn("Error reading stream: %v", readErr) - providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, schemas.ChatCompletionStreamRequest, providerName, request.Model, provider.logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, provider.logger) return } break @@ -582,11 +560,6 @@ func (provider *CohereProvider) ChatCompletionStream(ctx *schemas.BifrostContext response, bifrostErr, isLastChunk := event.ToBifrostChatCompletionStream() if bifrostErr != nil { - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: schemas.ChatCompletionStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - } 
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger) break @@ -594,11 +567,8 @@ func (provider *CohereProvider) ChatCompletionStream(ctx *schemas.BifrostContext if response != nil { response.ID = responseID response.ExtraFields = schemas.BifrostResponseExtraFields{ - RequestType: schemas.ChatCompletionStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - ChunkIndex: chunkIndex, - Latency: time.Since(lastChunkTime).Milliseconds(), + ChunkIndex: chunkIndex, + Latency: time.Since(lastChunkTime).Milliseconds(), } lastChunkTime = time.Now() @@ -638,18 +608,13 @@ func (provider *CohereProvider) Responses(ctx *schemas.BifrostContext, key schem request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToCohereResponsesRequest(request) - }, - provider.GetProviderKey()) + }) if bifrostErr != nil { return nil, bifrostErr } // Convert to Cohere v2 request - responseBody, latency, providerResponseHeaders, err := provider.completeRequest(ctx, jsonBody, provider.buildRequestURL(ctx, "/v2/chat", schemas.ResponsesRequest), key.Value.GetValue(), &providerUtils.RequestMetadata{ - Provider: provider.GetProviderKey(), - Model: request.Model, - RequestType: schemas.ResponsesRequest, - }) + responseBody, latency, providerResponseHeaders, err := provider.completeRequest(ctx, jsonBody, provider.buildRequestURL(ctx, "/v2/chat", schemas.ResponsesRequest), key.Value.GetValue()) if providerResponseHeaders != nil { ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders) } @@ -662,9 +627,6 @@ func (provider *CohereProvider) Responses(ctx *schemas.BifrostContext, key schem return &schemas.BifrostResponsesResponse{ Model: request.Model, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: provider.GetProviderKey(), - ModelRequested: request.Model, - RequestType: schemas.ResponsesRequest, Latency: 
latency.Milliseconds(), ProviderResponseHeaders: providerResponseHeaders, }, @@ -685,9 +647,6 @@ func (provider *CohereProvider) Responses(ctx *schemas.BifrostContext, key schem bifrostResponse.Model = request.Model // Set ExtraFields - bifrostResponse.ExtraFields.Provider = provider.GetProviderKey() - bifrostResponse.ExtraFields.ModelRequested = request.Model - bifrostResponse.ExtraFields.RequestType = schemas.ResponsesRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -711,7 +670,6 @@ func (provider *CohereProvider) ResponsesStream(ctx *schemas.BifrostContext, pos return nil, err } - providerName := provider.GetProviderKey() // Convert to Cohere v2 request and add streaming jsonBody, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, @@ -725,8 +683,7 @@ func (provider *CohereProvider) ResponsesStream(ctx *schemas.BifrostContext, pos reqBody.Stream = schemas.Ptr(true) } return reqBody, nil - }, - provider.GetProviderKey()) + }) if bifrostErr != nil { return nil, bifrostErr } @@ -774,9 +731,9 @@ func (provider *CohereProvider) ResponsesStream(ctx *schemas.BifrostContext, pos }, jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, providerName), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, providerName), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, 
providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } // Extract provider response headers before status check so error responses also forward them @@ -785,11 +742,7 @@ func (provider *CohereProvider) ResponsesStream(ctx *schemas.BifrostContext, pos // Check for HTTP errors if resp.StatusCode() != fasthttp.StatusOK { defer providerUtils.ReleaseStreamingResponse(resp) - return nil, providerUtils.EnrichError(ctx, parseCohereError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - Model: request.Model, - RequestType: schemas.ResponsesStreamRequest, - }), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, parseCohereError(resp), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } // Large payload streaming passthrough — pipe raw upstream SSE to client @@ -808,9 +761,9 @@ func (provider *CohereProvider) ResponsesStream(ctx *schemas.BifrostContext, pos go func() { defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.ResponsesStreamRequest, provider.logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.logger) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.ResponsesStreamRequest, provider.logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.logger) } close(responseChan) }() @@ -852,8 +805,8 @@ func (provider *CohereProvider) ResponsesStream(ctx *schemas.BifrostContext, pos return } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - provider.logger.Warn("Error reading %s stream: %v", providerName, readErr) - providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, schemas.ResponsesStreamRequest, 
providerName, request.Model, provider.logger) + provider.logger.Warn("Error reading stream: %v", readErr) + providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, provider.logger) return } break @@ -873,11 +826,6 @@ func (provider *CohereProvider) ResponsesStream(ctx *schemas.BifrostContext, pos responses, bifrostErr, isLastChunk := event.ToBifrostResponsesStream(chunkIndex, streamState) if bifrostErr != nil { - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: schemas.ResponsesStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger) break @@ -886,11 +834,8 @@ func (provider *CohereProvider) ResponsesStream(ctx *schemas.BifrostContext, pos for i, response := range responses { if response != nil { response.ExtraFields = schemas.BifrostResponseExtraFields{ - RequestType: schemas.ResponsesStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - ChunkIndex: chunkIndex, - Latency: time.Since(lastChunkTime).Milliseconds(), + ChunkIndex: chunkIndex, + Latency: time.Since(lastChunkTime).Milliseconds(), } lastChunkTime = time.Now() chunkIndex++ @@ -934,18 +879,13 @@ func (provider *CohereProvider) Embedding(ctx *schemas.BifrostContext, key schem request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToCohereEmbeddingRequest(request), nil - }, - provider.GetProviderKey()) + }) if bifrostErr != nil { return nil, bifrostErr } // Create Bifrost request for conversion - responseBody, latency, providerResponseHeaders, err := provider.completeRequest(ctx, jsonBody, provider.buildRequestURL(ctx, "/v2/embed", schemas.EmbeddingRequest), key.Value.GetValue(), &providerUtils.RequestMetadata{ - Provider: provider.GetProviderKey(), - Model: request.Model, - RequestType: schemas.EmbeddingRequest, - }) + responseBody, 
latency, providerResponseHeaders, err := provider.completeRequest(ctx, jsonBody, provider.buildRequestURL(ctx, "/v2/embed", schemas.EmbeddingRequest), key.Value.GetValue()) if providerResponseHeaders != nil { ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders) } @@ -958,9 +898,6 @@ func (provider *CohereProvider) Embedding(ctx *schemas.BifrostContext, key schem return &schemas.BifrostEmbeddingResponse{ Model: request.Model, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: provider.GetProviderKey(), - ModelRequested: request.Model, - RequestType: schemas.EmbeddingRequest, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerResponseHeaders, }, @@ -979,9 +916,6 @@ func (provider *CohereProvider) Embedding(ctx *schemas.BifrostContext, key schem bifrostResponse := response.ToBifrostEmbeddingResponse() // Set ExtraFields - bifrostResponse.ExtraFields.Provider = provider.GetProviderKey() - bifrostResponse.ExtraFields.ModelRequested = request.Model - bifrostResponse.ExtraFields.RequestType = schemas.EmbeddingRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -1010,17 +944,12 @@ func (provider *CohereProvider) Rerank(ctx *schemas.BifrostContext, key schemas. 
request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToCohereRerankRequest(request), nil - }, - provider.GetProviderKey()) + }) if bifrostErr != nil { return nil, bifrostErr } - responseBody, latency, providerResponseHeaders, err := provider.completeRequest(ctx, jsonBody, provider.buildRequestURL(ctx, "/v2/rerank", schemas.RerankRequest), key.Value.GetValue(), &providerUtils.RequestMetadata{ - Provider: provider.GetProviderKey(), - Model: request.Model, - RequestType: schemas.RerankRequest, - }) + responseBody, latency, providerResponseHeaders, err := provider.completeRequest(ctx, jsonBody, provider.buildRequestURL(ctx, "/v2/rerank", schemas.RerankRequest), key.Value.GetValue()) if providerResponseHeaders != nil { ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders) } @@ -1033,9 +962,6 @@ func (provider *CohereProvider) Rerank(ctx *schemas.BifrostContext, key schemas. return &schemas.BifrostRerankResponse{ Model: request.Model, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: provider.GetProviderKey(), - ModelRequested: request.Model, - RequestType: schemas.RerankRequest, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerResponseHeaders, }, @@ -1056,9 +982,6 @@ func (provider *CohereProvider) Rerank(ctx *schemas.BifrostContext, key schemas. 
bifrostResponse.Model = request.Model // Set ExtraFields - bifrostResponse.ExtraFields.Provider = provider.GetProviderKey() - bifrostResponse.ExtraFields.ModelRequested = request.Model - bifrostResponse.ExtraFields.RequestType = schemas.RerankRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -1211,16 +1134,12 @@ func (provider *CohereProvider) CountTokens(ctx *schemas.BifrostContext, key sch return nil, err } - providerName := provider.GetProviderKey() - jsonBody, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToCohereCountTokensRequest(request) - }, - providerName, - ) + }) if bifrostErr != nil { return nil, bifrostErr } @@ -1230,11 +1149,6 @@ func (provider *CohereProvider) CountTokens(ctx *schemas.BifrostContext, key sch jsonBody, provider.buildRequestURL(ctx, "/v1/tokenize", schemas.CountTokensRequest), key.Value.GetValue(), - &providerUtils.RequestMetadata{ - Provider: providerName, - Model: request.Model, - RequestType: schemas.CountTokensRequest, - }, ) if providerResponseHeaders != nil { ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders) @@ -1248,9 +1162,6 @@ func (provider *CohereProvider) CountTokens(ctx *schemas.BifrostContext, key sch return &schemas.BifrostCountTokensResponse{ Model: request.Model, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.CountTokensRequest, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerResponseHeaders, }, @@ -1272,12 +1183,9 @@ func (provider *CohereProvider) CountTokens(ctx *schemas.BifrostContext, key sch bifrostResponse := cohereResponse.ToBifrostCountTokensResponse(request.Model) if bifrostResponse == nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, 
fmt.Errorf("nil Cohere count tokens response"), providerName) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, fmt.Errorf("nil cohere count tokens response")), jsonBody, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse) } - bifrostResponse.ExtraFields.Provider = providerName - bifrostResponse.ExtraFields.ModelRequested = request.Model - bifrostResponse.ExtraFields.RequestType = schemas.CountTokensRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders diff --git a/core/providers/cohere/errors.go b/core/providers/cohere/errors.go index e9183b1b34..e444d86650 100644 --- a/core/providers/cohere/errors.go +++ b/core/providers/cohere/errors.go @@ -6,7 +6,7 @@ import ( "github.com/valyala/fasthttp" ) -func parseCohereError(resp *fasthttp.Response, meta *providerUtils.RequestMetadata) *schemas.BifrostError { +func parseCohereError(resp *fasthttp.Response) *schemas.BifrostError { var errorResp CohereError bifrostErr := providerUtils.HandleProviderAPIError(resp, &errorResp) bifrostErr.Type = &errorResp.Type @@ -17,10 +17,5 @@ func parseCohereError(resp *fasthttp.Response, meta *providerUtils.RequestMetada if errorResp.Code != nil { bifrostErr.Error.Code = errorResp.Code } - if meta != nil { - bifrostErr.ExtraFields.Provider = meta.Provider - bifrostErr.ExtraFields.ModelRequested = meta.Model - bifrostErr.ExtraFields.RequestType = meta.RequestType - } return bifrostErr } diff --git a/core/providers/cohere/models.go b/core/providers/cohere/models.go index e66aeedb1b..3b285f97b6 100644 --- a/core/providers/cohere/models.go +++ b/core/providers/cohere/models.go @@ -4,6 +4,7 @@ import ( "encoding/json" "strings" + providerUtils "github.com/maximhq/bifrost/core/providers/utils" "github.com/maximhq/bifrost/core/schemas" ) @@ -44,7 +45,7 @@ type CohereRerankMeta struct { Tokens *CohereTokenUsage 
`json:"tokens,omitempty"` } -func (response *CohereListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels schemas.WhiteList, blacklistedModels schemas.BlackList, unfiltered bool) *schemas.BifrostListModelsResponse { +func (response *CohereListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels schemas.WhiteList, blacklistedModels schemas.BlackList, aliases map[string]string, unfiltered bool) *schemas.BifrostListModelsResponse { if response == nil { return nil } @@ -53,41 +54,39 @@ func (response *CohereListModelsResponse) ToBifrostListModelsResponse(providerKe Data: make([]schemas.Model, 0, len(response.Models)), } - if !unfiltered && (allowedModels.IsEmpty() || blacklistedModels.IsBlockAll()) { + pipeline := &providerUtils.ListModelsPipeline{ + AllowedModels: allowedModels, + BlacklistedModels: blacklistedModels, + Aliases: aliases, + Unfiltered: unfiltered, + ProviderKey: providerKey, + MatchFns: providerUtils.DefaultMatchFns(), + } + if pipeline.ShouldEarlyExit() { return bifrostResponse } - includedModels := make(map[string]bool) - for _, model := range response.Models { - if !unfiltered && allowedModels.IsRestricted() && !allowedModels.Contains(model.Name) { - continue - } - if !unfiltered && blacklistedModels.IsBlocked(model.Name) { - continue - } - bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ - ID: string(providerKey) + "/" + model.Name, - Name: schemas.Ptr(model.Name), - ContextLength: schemas.Ptr(int(model.ContextLength)), - SupportedMethods: model.Endpoints, - }) - includedModels[strings.ToLower(model.Name)] = true - } + included := make(map[string]bool) - // Backfill allowed models that were not in the response - if !unfiltered && allowedModels.IsRestricted() { - for _, allowedModel := range allowedModels { - if blacklistedModels.IsBlocked(allowedModel) { - continue + for _, model := range response.Models { + // Cohere uses model.Name as the model 
identifier + for _, result := range pipeline.FilterModel(model.Name) { + entry := schemas.Model{ + ID: string(providerKey) + "/" + result.ResolvedID, + Name: schemas.Ptr(model.Name), + ContextLength: schemas.Ptr(int(model.ContextLength)), + SupportedMethods: model.Endpoints, } - if !includedModels[strings.ToLower(allowedModel)] { - bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ - ID: string(providerKey) + "/" + allowedModel, - Name: schemas.Ptr(allowedModel), - }) + if result.AliasValue != "" { + entry.Alias = schemas.Ptr(result.AliasValue) } + bifrostResponse.Data = append(bifrostResponse.Data, entry) + included[strings.ToLower(result.ResolvedID)] = true } } + bifrostResponse.Data = append(bifrostResponse.Data, + pipeline.BackfillModels(included)...) + return bifrostResponse } diff --git a/core/providers/elevenlabs/elevenlabs.go b/core/providers/elevenlabs/elevenlabs.go index afa095e56d..40ddbeb4ad 100644 --- a/core/providers/elevenlabs/elevenlabs.go +++ b/core/providers/elevenlabs/elevenlabs.go @@ -4,7 +4,6 @@ import ( "bytes" "context" "errors" - "fmt" "io" "mime/multipart" "net/http" @@ -74,8 +73,6 @@ func (provider *ElevenlabsProvider) GetProviderKey() schemas.ModelProvider { // listModelsByKey performs a list models request for a single key. // Returns the response and latency, or an error if the request fails. 
func (provider *ElevenlabsProvider) listModelsByKey(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - // Create request req := fasthttp.AcquireRequest() resp := fasthttp.AcquireResponse() @@ -103,10 +100,7 @@ func (provider *ElevenlabsProvider) listModelsByKey(ctx *schemas.BifrostContext, // Extract and set provider response headers so they're available on error paths ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerUtils.ExtractProviderResponseHeaders(resp)) if resp.StatusCode() != fasthttp.StatusOK { - return nil, parseElevenlabsError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - RequestType: schemas.ListModelsRequest, - }) + return nil, parseElevenlabsError(resp) } var elevenlabsResponse ElevenlabsListModelsResponse @@ -115,7 +109,7 @@ func (provider *ElevenlabsProvider) listModelsByKey(ctx *schemas.BifrostContext, return nil, bifrostErr } - response := elevenlabsResponse.ToBifrostListModelsResponse(providerName, key.Models, key.BlacklistedModels, request.Unfiltered) + response := elevenlabsResponse.ToBifrostListModelsResponse(provider.GetProviderKey(), key.Models, key.BlacklistedModels, key.Aliases, request.Unfiltered) response.ExtraFields.Latency = latency.Milliseconds() response.ExtraFields.ProviderResponseHeaders = providerUtils.ExtractProviderResponseHeaders(resp) @@ -188,8 +182,6 @@ func (provider *ElevenlabsProvider) Speech(ctx *schemas.BifrostContext, key sche return nil, err } - providerName := provider.GetProviderKey() - // Create request req := fasthttp.AcquireRequest() resp := fasthttp.AcquireResponse() @@ -211,7 +203,7 @@ func (provider *ElevenlabsProvider) Speech(ctx *schemas.BifrostContext, key sche endpoint = "/v1/text-to-speech/" + voice } } else { - return nil, providerUtils.NewBifrostOperationError("voice parameter is required", nil, providerName) + return 
nil, providerUtils.NewBifrostOperationError("voice parameter is required", nil) } requestURL := provider.buildBaseSpeechRequestURL(ctx, endpoint, schemas.SpeechRequest, request) @@ -228,8 +220,7 @@ func (provider *ElevenlabsProvider) Speech(ctx *schemas.BifrostContext, key sche request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToElevenlabsSpeechRequest(request), nil - }, - providerName) + }) if bifrostErr != nil { return nil, bifrostErr @@ -250,26 +241,18 @@ func (provider *ElevenlabsProvider) Speech(ctx *schemas.BifrostContext, key sche // Handle error response if resp.StatusCode() != fasthttp.StatusOK { - provider.logger.Debug(fmt.Sprintf("error from %s provider: %s", providerName, string(resp.Body()))) - return nil, providerUtils.EnrichError(ctx, parseElevenlabsError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - Model: request.Model, - RequestType: schemas.SpeechRequest, - }), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, parseElevenlabsError(resp), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } // Get the response body body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } // Create response based on whether timestamps were requested bifrostResponse := &schemas.BifrostSpeechResponse{ ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.SpeechRequest, - Provider: providerName, - ModelRequested: request.Model, Latency: latency.Milliseconds(), ProviderResponseHeaders: 
providerUtils.ExtractProviderResponseHeaders(resp), }, @@ -282,7 +265,7 @@ func (provider *ElevenlabsProvider) Speech(ctx *schemas.BifrostContext, key sche if withTimestampsRequest { var timestampResponse ElevenlabsSpeechWithTimestampsResponse if err := sonic.Unmarshal(body, &timestampResponse); err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to parse with-timestamps response", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to parse with-timestamps response", err) } bifrostResponse.AudioBase64 = &timestampResponse.AudioBase64 @@ -321,15 +304,12 @@ func (provider *ElevenlabsProvider) SpeechStream(ctx *schemas.BifrostContext, po return nil, err } - providerName := provider.GetProviderKey() - jsonBody, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToElevenlabsSpeechRequest(request), nil - }, - providerName) + }) if bifrostErr != nil { return nil, bifrostErr @@ -345,7 +325,7 @@ func (provider *ElevenlabsProvider) SpeechStream(ctx *schemas.BifrostContext, po providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, nil) if request.Params == nil || request.Params.VoiceConfig == nil || request.Params.VoiceConfig.Voice == nil { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("voice parameter is required", nil, providerName), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("voice parameter is required", nil), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } req.SetRequestURI(provider.buildBaseSpeechRequestURL(ctx, "/v1/text-to-speech/"+*request.Params.VoiceConfig.Voice+"/stream", schemas.SpeechStreamRequest, request)) @@ -376,9 +356,9 @@ func (provider *ElevenlabsProvider) SpeechStream(ctx *schemas.BifrostContext, po }, jsonBody, nil,
provider.sendBackRawRequest, provider.sendBackRawResponse) } if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, providerName), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, providerName), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } // Extract provider response headers before status check so error responses also forward them @@ -387,11 +367,7 @@ func (provider *ElevenlabsProvider) SpeechStream(ctx *schemas.BifrostContext, po // Check for HTTP errors if resp.StatusCode() != fasthttp.StatusOK { defer providerUtils.ReleaseStreamingResponse(resp) - return nil, providerUtils.EnrichError(ctx, parseElevenlabsError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - Model: request.Model, - RequestType: schemas.SpeechStreamRequest, - }), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, parseElevenlabsError(resp), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } // Create response channel @@ -402,9 +378,9 @@ func (provider *ElevenlabsProvider) SpeechStream(ctx *schemas.BifrostContext, po go func() { defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, request.Model, 
schemas.SpeechStreamRequest, provider.logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.logger) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.SpeechStreamRequest, provider.logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.logger) } close(responseChan) }() @@ -445,7 +421,7 @@ func (provider *ElevenlabsProvider) SpeechStream(ctx *schemas.BifrostContext, po } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) provider.logger.Warn("Error reading stream: %v", err) - providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, schemas.SpeechStreamRequest, providerName, request.Model, provider.logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, provider.logger) return } @@ -458,11 +434,8 @@ func (provider *ElevenlabsProvider) SpeechStream(ctx *schemas.BifrostContext, po Type: schemas.SpeechStreamResponseTypeDelta, Audio: audioChunk, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.SpeechStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - ChunkIndex: chunkIndex, - Latency: time.Since(lastChunkTime).Milliseconds(), + ChunkIndex: chunkIndex, + Latency: time.Since(lastChunkTime).Milliseconds(), }, } @@ -481,11 +454,8 @@ func (provider *ElevenlabsProvider) SpeechStream(ctx *schemas.BifrostContext, po Type: schemas.SpeechStreamResponseTypeDone, Audio: []byte{}, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.SpeechStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - ChunkIndex: chunkIndex + 1, - Latency: time.Since(startTime).Milliseconds(), + ChunkIndex: chunkIndex + 1, + Latency: time.Since(startTime).Milliseconds(), }, } @@ -506,32 +476,30 @@ func (provider *ElevenlabsProvider) Transcription(ctx *schemas.BifrostContext, k return nil, 
err } - providerName := provider.GetProviderKey() - reqBody := ToElevenlabsTranscriptionRequest(request) if reqBody == nil { - return nil, providerUtils.NewBifrostOperationError("transcription request is not provided", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("transcription request is not provided", nil) } hasFile := len(reqBody.File) > 0 hasURL := reqBody.CloudStorageURL != nil && strings.TrimSpace(*reqBody.CloudStorageURL) != "" if hasFile && hasURL { - return nil, providerUtils.NewBifrostOperationError("provide either a file or cloud_storage_url, not both", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("provide either a file or cloud_storage_url, not both", nil) } if !hasFile && !hasURL { - return nil, providerUtils.NewBifrostOperationError("either a transcription file or cloud_storage_url must be provided", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("either a transcription file or cloud_storage_url must be provided", nil) } var body bytes.Buffer writer := multipart.NewWriter(&body) - if bifrostErr := writeTranscriptionMultipart(writer, reqBody, providerName); bifrostErr != nil { + if bifrostErr := writeTranscriptionMultipart(writer, reqBody); bifrostErr != nil { return nil, bifrostErr } contentType := writer.FormDataContentType() if err := writer.Close(); err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to finalize multipart transcription request", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to finalize multipart transcription request", err) } req := fasthttp.AcquireRequest() @@ -562,17 +530,12 @@ func (provider *ElevenlabsProvider) Transcription(ctx *schemas.BifrostContext, k // Extract and set provider response headers so they're available on error paths ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerUtils.ExtractProviderResponseHeaders(resp)) if resp.StatusCode() != fasthttp.StatusOK { - 
provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - return nil, parseElevenlabsError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - Model: request.Model, - RequestType: schemas.TranscriptionRequest, - }) + return nil, parseElevenlabsError(resp) } responseBody, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } // Check for empty response @@ -588,18 +551,15 @@ func (provider *ElevenlabsProvider) Transcription(ctx *schemas.BifrostContext, k chunks, err := parseTranscriptionResponse(responseBody) if err != nil { - return nil, providerUtils.NewBifrostOperationError(err.Error(), nil, providerName) + return nil, providerUtils.NewBifrostOperationError(err.Error(), nil) } if len(chunks) == 0 { - return nil, providerUtils.NewBifrostOperationError("no chunks found in transcription response", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("no chunks found in transcription response", nil) } response := ToBifrostTranscriptionResponse(chunks) response.ExtraFields = schemas.BifrostResponseExtraFields{ - RequestType: schemas.TranscriptionRequest, - Provider: providerName, - ModelRequested: request.Model, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerUtils.ExtractProviderResponseHeaders(resp), } @@ -607,7 +567,7 @@ func (provider *ElevenlabsProvider) Transcription(ctx *schemas.BifrostContext, k if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) { var rawResponse interface{} if err := sonic.Unmarshal(responseBody, &rawResponse); err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRawResponseUnmarshal, err, providerName) + rawResponse = string(responseBody) } response.ExtraFields.RawResponse = rawResponse } @@ -615,9 
+575,9 @@ func (provider *ElevenlabsProvider) Transcription(ctx *schemas.BifrostContext, k return response, nil } -func writeTranscriptionMultipart(writer *multipart.Writer, reqBody *ElevenlabsTranscriptionRequest, providerName schemas.ModelProvider) *schemas.BifrostError { +func writeTranscriptionMultipart(writer *multipart.Writer, reqBody *ElevenlabsTranscriptionRequest) *schemas.BifrostError { if err := writer.WriteField("model_id", reqBody.ModelID); err != nil { - return providerUtils.NewBifrostOperationError("failed to write model_id field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write model_id field", err) } if len(reqBody.File) > 0 { @@ -627,98 +587,98 @@ func writeTranscriptionMultipart(writer *multipart.Writer, reqBody *ElevenlabsTr } fileWriter, err := writer.CreateFormFile("file", filename) if err != nil { - return providerUtils.NewBifrostOperationError("failed to create file field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to create file field", err) } if _, err := fileWriter.Write(reqBody.File); err != nil { - return providerUtils.NewBifrostOperationError("failed to write file data", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write file data", err) } } if reqBody.CloudStorageURL != nil && strings.TrimSpace(*reqBody.CloudStorageURL) != "" { if err := writer.WriteField("cloud_storage_url", *reqBody.CloudStorageURL); err != nil { - return providerUtils.NewBifrostOperationError("failed to write cloud_storage_url field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write cloud_storage_url field", err) } } if reqBody.LanguageCode != nil && strings.TrimSpace(*reqBody.LanguageCode) != "" { if err := writer.WriteField("language_code", *reqBody.LanguageCode); err != nil { - return providerUtils.NewBifrostOperationError("failed to write language_code field", err, providerName) + return 
providerUtils.NewBifrostOperationError("failed to write language_code field", err) } } if reqBody.TagAudioEvents != nil { if err := writer.WriteField("tag_audio_events", strconv.FormatBool(*reqBody.TagAudioEvents)); err != nil { - return providerUtils.NewBifrostOperationError("failed to write tag_audio_events field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write tag_audio_events field", err) } } if reqBody.NumSpeakers != nil && *reqBody.NumSpeakers > 0 { if err := writer.WriteField("num_speakers", strconv.Itoa(*reqBody.NumSpeakers)); err != nil { - return providerUtils.NewBifrostOperationError("failed to write num_speakers field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write num_speakers field", err) } } if reqBody.TimestampsGranularity != nil && *reqBody.TimestampsGranularity != "" { if err := writer.WriteField("timestamps_granularity", string(*reqBody.TimestampsGranularity)); err != nil { - return providerUtils.NewBifrostOperationError("failed to write timestamps_granularity field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write timestamps_granularity field", err) } } if reqBody.Diarize != nil { if err := writer.WriteField("diarize", strconv.FormatBool(*reqBody.Diarize)); err != nil { - return providerUtils.NewBifrostOperationError("failed to write diarize field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write diarize field", err) } } if reqBody.DiarizationThreshold != nil { if err := writer.WriteField("diarization_threshold", strconv.FormatFloat(*reqBody.DiarizationThreshold, 'f', -1, 64)); err != nil { - return providerUtils.NewBifrostOperationError("failed to write diarization_threshold field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write diarization_threshold field", err) } } if len(reqBody.AdditionalFormats) > 0 { payload, err := 
providerUtils.MarshalSorted(reqBody.AdditionalFormats) if err != nil { - return providerUtils.NewBifrostOperationError("failed to marshal additional_formats", err, providerName) + return providerUtils.NewBifrostOperationError("failed to marshal additional_formats", err) } if err := writer.WriteField("additional_formats", string(payload)); err != nil { - return providerUtils.NewBifrostOperationError("failed to write additional_formats field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write additional_formats field", err) } } if reqBody.FileFormat != nil && *reqBody.FileFormat != "" { if err := writer.WriteField("file_format", string(*reqBody.FileFormat)); err != nil { - return providerUtils.NewBifrostOperationError("failed to write file_format field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write file_format field", err) } } if reqBody.Webhook != nil { if err := writer.WriteField("webhook", strconv.FormatBool(*reqBody.Webhook)); err != nil { - return providerUtils.NewBifrostOperationError("failed to write webhook field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write webhook field", err) } } if reqBody.WebhookID != nil && strings.TrimSpace(*reqBody.WebhookID) != "" { if err := writer.WriteField("webhook_id", *reqBody.WebhookID); err != nil { - return providerUtils.NewBifrostOperationError("failed to write webhook_id field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write webhook_id field", err) } } if reqBody.Temperature != nil { if err := writer.WriteField("temperature", strconv.FormatFloat(*reqBody.Temperature, 'f', -1, 64)); err != nil { - return providerUtils.NewBifrostOperationError("failed to write temperature field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write temperature field", err) } } if reqBody.Seed != nil { if err := writer.WriteField("seed", 
strconv.Itoa(*reqBody.Seed)); err != nil { - return providerUtils.NewBifrostOperationError("failed to write seed field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write seed field", err) } } if reqBody.UseMultiChannel != nil { if err := writer.WriteField("use_multi_channel", strconv.FormatBool(*reqBody.UseMultiChannel)); err != nil { - return providerUtils.NewBifrostOperationError("failed to write use_multi_channel field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write use_multi_channel field", err) } } @@ -727,16 +687,16 @@ func writeTranscriptionMultipart(writer *multipart.Writer, reqBody *ElevenlabsTr case string: if strings.TrimSpace(v) != "" { if err := writer.WriteField("webhook_metadata", v); err != nil { - return providerUtils.NewBifrostOperationError("failed to write webhook_metadata field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write webhook_metadata field", err) } } default: payload, err := providerUtils.MarshalSorted(v) if err != nil { - return providerUtils.NewBifrostOperationError("failed to marshal webhook_metadata", err, providerName) + return providerUtils.NewBifrostOperationError("failed to marshal webhook_metadata", err) } if err := writer.WriteField("webhook_metadata", string(payload)); err != nil { - return providerUtils.NewBifrostOperationError("failed to write webhook_metadata field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write webhook_metadata field", err) } } } diff --git a/core/providers/elevenlabs/errors.go b/core/providers/elevenlabs/errors.go index 374e251958..f30807efd5 100644 --- a/core/providers/elevenlabs/errors.go +++ b/core/providers/elevenlabs/errors.go @@ -9,7 +9,7 @@ import ( schemas "github.com/maximhq/bifrost/core/schemas" ) -func parseElevenlabsError(resp *fasthttp.Response, meta *providerUtils.RequestMetadata) *schemas.BifrostError { +func parseElevenlabsError(resp 
*fasthttp.Response) *schemas.BifrostError { var errorResp ElevenlabsError bifrostErr := providerUtils.HandleProviderAPIError(resp, &errorResp) if errorResp.Detail != nil { @@ -64,11 +64,6 @@ func parseElevenlabsError(resp *fasthttp.Response, meta *providerUtils.RequestMe Message: message, }, } - if meta != nil { - result.ExtraFields.Provider = meta.Provider - result.ExtraFields.ModelRequested = meta.Model - result.ExtraFields.RequestType = meta.RequestType - } return result } } @@ -91,10 +86,5 @@ func parseElevenlabsError(resp *fasthttp.Response, meta *providerUtils.RequestMe bifrostErr.Error.Message = message } } - if meta != nil { - bifrostErr.ExtraFields.Provider = meta.Provider - bifrostErr.ExtraFields.ModelRequested = meta.Model - bifrostErr.ExtraFields.RequestType = meta.RequestType - } return bifrostErr } diff --git a/core/providers/elevenlabs/models.go b/core/providers/elevenlabs/models.go index 3c4e939fca..f762d97ee8 100644 --- a/core/providers/elevenlabs/models.go +++ b/core/providers/elevenlabs/models.go @@ -3,10 +3,11 @@ package elevenlabs import ( "strings" + providerUtils "github.com/maximhq/bifrost/core/providers/utils" "github.com/maximhq/bifrost/core/schemas" ) -func (response *ElevenlabsListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels schemas.WhiteList, blacklistedModels schemas.BlackList, unfiltered bool) *schemas.BifrostListModelsResponse { +func (response *ElevenlabsListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels schemas.WhiteList, blacklistedModels schemas.BlackList, aliases map[string]string, unfiltered bool) *schemas.BifrostListModelsResponse { if response == nil { return nil } @@ -15,39 +16,36 @@ func (response *ElevenlabsListModelsResponse) ToBifrostListModelsResponse(provid Data: make([]schemas.Model, 0, len(*response)), } - if !unfiltered && (allowedModels.IsEmpty() || blacklistedModels.IsBlockAll()) { + pipeline := 
&providerUtils.ListModelsPipeline{ + AllowedModels: allowedModels, + BlacklistedModels: blacklistedModels, + Aliases: aliases, + Unfiltered: unfiltered, + ProviderKey: providerKey, + MatchFns: providerUtils.DefaultMatchFns(), + } + if pipeline.ShouldEarlyExit() { return bifrostResponse } - includedModels := make(map[string]bool) - for _, model := range *response { - if !unfiltered && allowedModels.IsRestricted() && !allowedModels.Contains(model.ModelID) { - continue - } - if !unfiltered && blacklistedModels.IsBlocked(model.ModelID) { - continue - } - bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ - ID: string(providerKey) + "/" + model.ModelID, - Name: schemas.Ptr(model.Name), - }) - includedModels[strings.ToLower(model.ModelID)] = true - } + included := make(map[string]bool) - // Backfill allowed models that were not in the response - if !unfiltered && allowedModels.IsRestricted() { - for _, allowedModel := range allowedModels { - if blacklistedModels.IsBlocked(allowedModel) { - continue + for _, model := range *response { + for _, result := range pipeline.FilterModel(model.ModelID) { + entry := schemas.Model{ + ID: string(providerKey) + "/" + result.ResolvedID, + Name: schemas.Ptr(model.Name), } - if !includedModels[strings.ToLower(allowedModel)] { - bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ - ID: string(providerKey) + "/" + allowedModel, - Name: schemas.Ptr(allowedModel), - }) + if result.AliasValue != "" { + entry.Alias = schemas.Ptr(result.AliasValue) } + bifrostResponse.Data = append(bifrostResponse.Data, entry) + included[strings.ToLower(result.ResolvedID)] = true } } + bifrostResponse.Data = append(bifrostResponse.Data, + pipeline.BackfillModels(included)...) 
+ return bifrostResponse } diff --git a/core/providers/gemini/batch.go b/core/providers/gemini/batch.go index e3d92383f6..8f0405e524 100644 --- a/core/providers/gemini/batch.go +++ b/core/providers/gemini/batch.go @@ -249,8 +249,6 @@ func extractBatchIDFromName(name string) string { // downloadBatchResultsFile downloads and parses a batch results file from Gemini. // Returns the parsed result items from the JSONL file and any parse errors encountered. func (provider *GeminiProvider) downloadBatchResultsFile(ctx context.Context, key schemas.Key, fileName string) ([]schemas.BatchResultItem, []schemas.BatchError, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - // Create request to download the file req := fasthttp.AcquireRequest() resp := fasthttp.AcquireResponse() @@ -287,15 +285,12 @@ func (provider *GeminiProvider) downloadBatchResultsFile(ctx context.Context, ke // Handle error response if resp.StatusCode() != fasthttp.StatusOK { - return nil, nil, parseGeminiError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - RequestType: schemas.BatchResultsRequest, - }) + return nil, nil, parseGeminiError(resp) } body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + return nil, nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } // Parse JSONL content - each line is a separate JSON object diff --git a/core/providers/gemini/errors.go b/core/providers/gemini/errors.go index adf217a141..2d60a7bcd3 100644 --- a/core/providers/gemini/errors.go +++ b/core/providers/gemini/errors.go @@ -36,7 +36,7 @@ func ToGeminiError(bifrostErr *schemas.BifrostError) *GeminiGenerationError { } // parseGeminiError parses Gemini error responses -func parseGeminiError(resp *fasthttp.Response, meta *providerUtils.RequestMetadata) *schemas.BifrostError { +func parseGeminiError(resp *fasthttp.Response) 
*schemas.BifrostError { // Try to parse as []GeminiGenerationError var errorResps []GeminiGenerationError bifrostErr := providerUtils.HandleProviderAPIError(resp, &errorResps) @@ -62,11 +62,6 @@ func parseGeminiError(resp *fasthttp.Response, meta *providerUtils.RequestMetada } // Set Message to trimmed concatenated message bifrostErr.Error.Message = message - if meta != nil { - bifrostErr.ExtraFields.Provider = meta.Provider - bifrostErr.ExtraFields.ModelRequested = meta.Model - bifrostErr.ExtraFields.RequestType = meta.RequestType - } return bifrostErr } @@ -80,10 +75,5 @@ func parseGeminiError(resp *fasthttp.Response, meta *providerUtils.RequestMetada bifrostErr.Error.Code = schemas.Ptr(strconv.Itoa(errorResp.Error.Code)) bifrostErr.Error.Message = errorResp.Error.Message } - if meta != nil { - bifrostErr.ExtraFields.Provider = meta.Provider - bifrostErr.ExtraFields.ModelRequested = meta.Model - bifrostErr.ExtraFields.RequestType = meta.RequestType - } return bifrostErr } diff --git a/core/providers/gemini/gemini.go b/core/providers/gemini/gemini.go index e852008e0d..e925cebe27 100644 --- a/core/providers/gemini/gemini.go +++ b/core/providers/gemini/gemini.go @@ -97,9 +97,7 @@ func (provider *GeminiProvider) GetProviderKey() schemas.ModelProvider { // completeRequest handles the common HTTP request pattern for Gemini API calls. // When large response streaming is activated (BifrostContextKeyLargeResponseMode set in ctx), // returns (nil, nil, latency, nil) — callers must check the context flag. 
-func (provider *GeminiProvider) completeRequest(ctx *schemas.BifrostContext, model string, key schemas.Key, jsonBody []byte, endpoint string, meta *providerUtils.RequestMetadata) (*GenerateContentResponse, interface{}, time.Duration, map[string]string, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - +func (provider *GeminiProvider) completeRequest(ctx *schemas.BifrostContext, model string, key schemas.Key, jsonBody []byte, endpoint string) (*GenerateContentResponse, interface{}, time.Duration, map[string]string, *schemas.BifrostError) { // Create request req := fasthttp.AcquireRequest() resp := fasthttp.AcquireResponse() @@ -146,10 +144,10 @@ func (provider *GeminiProvider) completeRequest(ctx *schemas.BifrostContext, mod // Handle error response if resp.StatusCode() != fasthttp.StatusOK { providerUtils.MaterializeStreamErrorBody(ctx, resp) - return nil, nil, latency, providerResponseHeaders, parseGeminiError(resp, meta) + return nil, nil, latency, providerResponseHeaders, parseGeminiError(resp) } - body, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, providerName, provider.logger) + body, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, provider.logger) if decodeErr != nil { return nil, nil, latency, providerResponseHeaders, decodeErr } @@ -161,13 +159,13 @@ func (provider *GeminiProvider) completeRequest(ctx *schemas.BifrostContext, mod // Parse Gemini's response var geminiResponse GenerateContentResponse if err := sonic.Unmarshal(body, &geminiResponse); err != nil { - return nil, nil, latency, providerResponseHeaders, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerName) + return nil, nil, latency, providerResponseHeaders, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err) } var rawResponse interface{} if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) { if err := 
sonic.Unmarshal(body, &rawResponse); err != nil { - return nil, nil, latency, providerResponseHeaders, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerName) + return nil, nil, latency, providerResponseHeaders, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err) } } @@ -208,10 +206,7 @@ func (provider *GeminiProvider) listModelsByKey(ctx *schemas.BifrostContext, key // Handle error response if resp.StatusCode() != fasthttp.StatusOK { - return nil, parseGeminiError(resp, &providerUtils.RequestMetadata{ - Provider: provider.GetProviderKey(), - RequestType: schemas.ListModelsRequest, - }) + return nil, parseGeminiError(resp) } // Parse Gemini's response @@ -227,7 +222,7 @@ func (provider *GeminiProvider) listModelsByKey(ctx *schemas.BifrostContext, key } } - response := geminiResponse.ToBifrostListModelsResponse(providerName, key.Models, key.BlacklistedModels, request.Unfiltered) + response := geminiResponse.ToBifrostListModelsResponse(providerName, key.Models, key.BlacklistedModels, key.Aliases, request.Unfiltered) response.ExtraFields.Latency = latency.Milliseconds() @@ -282,24 +277,17 @@ func (provider *GeminiProvider) ChatCompletion(ctx *schemas.BifrostContext, key return nil, err } - providerName := provider.GetProviderKey() - jsonData, err := providerUtils.CheckContextAndGetRequestBody( ctx, request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToGeminiChatCompletionRequest(request), nil - }, - provider.GetProviderKey()) + }) if err != nil { return nil, err } - geminiResponse, rawResponse, latency, providerResponseHeaders, bifrostErr := provider.completeRequest(ctx, request.Model, key, jsonData, ":generateContent", &providerUtils.RequestMetadata{ - Provider: providerName, - Model: request.Model, - RequestType: schemas.ChatCompletionRequest, - }) + geminiResponse, rawResponse, latency, providerResponseHeaders, bifrostErr := provider.completeRequest(ctx, request.Model, key, 
jsonData, ":generateContent") if providerResponseHeaders != nil { ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders) } @@ -312,9 +300,6 @@ func (provider *GeminiProvider) ChatCompletion(ctx *schemas.BifrostContext, key return &schemas.BifrostChatResponse{ Model: request.Model, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.ChatCompletionRequest, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerResponseHeaders, }, @@ -323,9 +308,6 @@ func (provider *GeminiProvider) ChatCompletion(ctx *schemas.BifrostContext, key bifrostResponse := geminiResponse.ToBifrostChatResponse() - bifrostResponse.ExtraFields.RequestType = schemas.ChatCompletionRequest - bifrostResponse.ExtraFields.Provider = providerName - bifrostResponse.ExtraFields.ModelRequested = request.Model bifrostResponse.ExtraFields.Latency = latency.Milliseconds() bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -360,8 +342,7 @@ func (provider *GeminiProvider) ChatCompletionStream(ctx *schemas.BifrostContext return nil, fmt.Errorf("chat completion request is not provided or could not be converted to Gemini format") } return reqBody, nil - }, - provider.GetProviderKey()) + }) if err != nil { return nil, err } @@ -447,9 +428,9 @@ func HandleGeminiChatCompletionStream( }, jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } if errors.Is(doErr, fasthttp.ErrTimeout) || errors.Is(doErr, context.DeadlineExceeded) { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, doErr, providerName), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, doErr), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } - return nil, providerUtils.EnrichError(ctx, 
providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, doErr, providerName), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, doErr), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } // Extract provider response headers before status check so error responses also forward them @@ -459,11 +440,7 @@ func HandleGeminiChatCompletionStream( if resp.StatusCode() != fasthttp.StatusOK { defer providerUtils.ReleaseStreamingResponse(resp) respBody := append([]byte(nil), resp.Body()...) - return nil, providerUtils.EnrichError(ctx, parseGeminiError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - Model: model, - RequestType: schemas.ChatCompletionStreamRequest, - }), jsonBody, respBody, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, parseGeminiError(resp), jsonBody, respBody, sendBackRawRequest, sendBackRawResponse) } // Large payload streaming passthrough — pipe raw upstream SSE to client @@ -480,9 +457,9 @@ func HandleGeminiChatCompletionStream( go func() { defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, model, schemas.ChatCompletionStreamRequest, logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, logger) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, model, schemas.ChatCompletionStreamRequest, logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, logger) } close(responseChan) }() @@ -492,7 +469,6 @@ func HandleGeminiChatCompletionStream( bifrostErr := providerUtils.NewBifrostOperationError( "Provider returned an empty response", fmt.Errorf("provider returned an empty response"), - providerName, ) 
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, bifrostErr, jsonBody, nil, sendBackRawRequest, sendBackRawResponse), responseChan, logger) @@ -554,7 +530,7 @@ func HandleGeminiChatCompletionStream( } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) logger.Warn("Error reading stream: %v", readErr) - providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, schemas.ChatCompletionStreamRequest, providerName, model, logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, logger) return } // Process chunk using shared function @@ -569,11 +545,6 @@ func HandleGeminiChatCompletionStream( Message: err.Error(), Error: err, }, - ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.ChatCompletionStreamRequest, - Provider: providerName, - ModelRequested: model, - }, } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, bifrostErr, jsonBody, nil, sendBackRawRequest, sendBackRawResponse), responseChan, logger) @@ -594,11 +565,6 @@ func HandleGeminiChatCompletionStream( // Convert to Bifrost stream response response, bifrostErr, isLastChunk := geminiResponse.ToBifrostChatCompletionStream(streamState) if bifrostErr != nil { - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: schemas.ChatCompletionStreamRequest, - Provider: providerName, - ModelRequested: model, - } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, bifrostErr, jsonBody, nil, sendBackRawRequest, sendBackRawResponse), responseChan, logger) return @@ -610,11 +576,8 @@ func HandleGeminiChatCompletionStream( response.Model = modelName } response.ExtraFields = schemas.BifrostResponseExtraFields{ - 
RequestType: schemas.ChatCompletionStreamRequest, - Provider: providerName, - ModelRequested: model, - ChunkIndex: chunkIndex, - Latency: time.Since(lastChunkTime).Milliseconds(), + ChunkIndex: chunkIndex, + Latency: time.Since(lastChunkTime).Milliseconds(), } if postResponseConverter != nil { @@ -686,8 +649,7 @@ func (provider *GeminiProvider) Responses(ctx *schemas.BifrostContext, key schem return nil, fmt.Errorf("responses input is not provided or could not be converted to Gemini format") } return reqBody, nil - }, - provider.GetProviderKey()) + }) if err != nil { return nil, err } @@ -699,11 +661,7 @@ func (provider *GeminiProvider) Responses(ctx *schemas.BifrostContext, key schem } // Use struct directly for JSON marshaling - geminiResponse, rawResponse, latency, providerResponseHeaders, bifrostErr := provider.completeRequest(ctx, request.Model, key, jsonData, ":generateContent", &providerUtils.RequestMetadata{ - Provider: provider.GetProviderKey(), - Model: request.Model, - RequestType: schemas.ResponsesRequest, - }) + geminiResponse, rawResponse, latency, providerResponseHeaders, bifrostErr := provider.completeRequest(ctx, request.Model, key, jsonData, ":generateContent") if providerResponseHeaders != nil { ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders) } @@ -716,9 +674,6 @@ func (provider *GeminiProvider) Responses(ctx *schemas.BifrostContext, key schem return &schemas.BifrostResponsesResponse{ Model: request.Model, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: provider.GetProviderKey(), - ModelRequested: request.Model, - RequestType: schemas.ResponsesRequest, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerResponseHeaders, }, @@ -729,9 +684,6 @@ func (provider *GeminiProvider) Responses(ctx *schemas.BifrostContext, key schem bifrostResponse := geminiResponse.ToResponsesBifrostResponsesResponse() // Set ExtraFields - bifrostResponse.ExtraFields.Provider = provider.GetProviderKey() - 
bifrostResponse.ExtraFields.ModelRequested = request.Model - bifrostResponse.ExtraFields.RequestType = schemas.ResponsesRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -761,13 +713,6 @@ func (provider *GeminiProvider) responsesWithLargeResponseDetection( bodyReader io.Reader, // Optional: for large payload request streaming (pass nil for normal path) bodySize int, // Required if bodyReader is non-nil ) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - meta := &providerUtils.RequestMetadata{ - Provider: providerName, - Model: request.Model, - RequestType: schemas.ResponsesRequest, - } - // Create request req := fasthttp.AcquireRequest() resp := fasthttp.AcquireResponse() @@ -801,14 +746,14 @@ func (provider *GeminiProvider) responsesWithLargeResponseDetection( // Handle error response — materialize stream body for error parsing if resp.StatusCode() != fasthttp.StatusOK { providerUtils.MaterializeStreamErrorBody(ctx, resp) - bifrostErr := parseGeminiError(resp, meta) + bifrostErr := parseGeminiError(resp) wait() fasthttp.ReleaseResponse(resp) return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } // Delegate large response detection + normal buffered path to shared utility - responseBody, isLarge, respErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, providerName, provider.logger) + responseBody, isLarge, respErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, provider.logger) if respErr != nil { wait() fasthttp.ReleaseResponse(resp) @@ -824,9 +769,6 @@ func (provider *GeminiProvider) responsesWithLargeResponseDetection( Model: request.Model, Usage: usage, } - bifrostResponse.ExtraFields.Provider = providerName - bifrostResponse.ExtraFields.ModelRequested = request.Model - 
bifrostResponse.ExtraFields.RequestType = schemas.ResponsesRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() // resp owned by reader in context — don't release wait() @@ -838,12 +780,9 @@ func (provider *GeminiProvider) responsesWithLargeResponseDetection( // Normal parse-and-convert path var geminiResponse GenerateContentResponse if unmarshalErr := sonic.Unmarshal(responseBody, &geminiResponse); unmarshalErr != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, unmarshalErr, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, unmarshalErr) } bifrostResponse := geminiResponse.ToResponsesBifrostResponsesResponse() - bifrostResponse.ExtraFields.Provider = providerName - bifrostResponse.ExtraFields.ModelRequested = request.Model - bifrostResponse.ExtraFields.RequestType = schemas.ResponsesRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { providerUtils.ParseAndSetRawRequest(&bifrostResponse.ExtraFields, jsonData) @@ -892,8 +831,7 @@ func (provider *GeminiProvider) ResponsesStream(ctx *schemas.BifrostContext, pos return nil, fmt.Errorf("responses input is not provided or could not be converted to Gemini format") } return reqBody, nil - }, - provider.GetProviderKey()) + }) if err != nil { return nil, err } @@ -978,9 +916,9 @@ func HandleGeminiResponsesStream( }, jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } if errors.Is(doErr, fasthttp.ErrTimeout) || errors.Is(doErr, context.DeadlineExceeded) { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, doErr, providerName), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, doErr), jsonBody, nil, sendBackRawRequest, 
sendBackRawResponse) } - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, doErr, providerName), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, doErr), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } // Extract provider response headers before status check so error responses also forward them @@ -989,11 +927,7 @@ func HandleGeminiResponsesStream( // Check for HTTP errors — use parseGeminiError to preserve upstream error details if resp.StatusCode() != fasthttp.StatusOK { defer providerUtils.ReleaseStreamingResponse(resp) - return nil, providerUtils.EnrichError(ctx, parseGeminiError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - Model: model, - RequestType: schemas.ResponsesStreamRequest, - }), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, parseGeminiError(resp), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } // Large payload streaming passthrough — pipe raw upstream SSE to client @@ -1010,9 +944,9 @@ func HandleGeminiResponsesStream( go func() { defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, model, schemas.ResponsesStreamRequest, logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, logger) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, model, schemas.ResponsesStreamRequest, logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, logger) } close(responseChan) }() @@ -1023,7 +957,6 @@ func HandleGeminiResponsesStream( bifrostErr := providerUtils.NewBifrostOperationError( "Provider returned an empty response", fmt.Errorf("provider returned an empty response"), - 
providerName, ) ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError( @@ -1092,7 +1025,7 @@ func HandleGeminiResponsesStream( } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) logger.Warn("Error reading stream: %v", readErr) - providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, schemas.ResponsesStreamRequest, providerName, model, logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, logger) return } @@ -1108,11 +1041,6 @@ func HandleGeminiResponsesStream( Message: err.Error(), Error: err, }, - ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.ResponsesStreamRequest, - Provider: providerName, - ModelRequested: model, - }, } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, bifrostErr, jsonBody, nil, sendBackRawRequest, sendBackRawResponse), responseChan, logger) @@ -1130,11 +1058,6 @@ func HandleGeminiResponsesStream( // Convert to Bifrost responses stream response responses, bifrostErr := geminiResponse.ToBifrostResponsesStream(sequenceNumber, streamState) if bifrostErr != nil { - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: schemas.ResponsesStreamRequest, - Provider: providerName, - ModelRequested: model, - } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, bifrostErr, jsonBody, nil, sendBackRawRequest, sendBackRawResponse), responseChan, logger) return @@ -1143,11 +1066,8 @@ func HandleGeminiResponsesStream( for i, response := range responses { if response != nil { response.ExtraFields = schemas.BifrostResponseExtraFields{ - RequestType: schemas.ResponsesStreamRequest, - Provider: providerName, - ModelRequested: model, - ChunkIndex: chunkIndex, - Latency: 
time.Since(lastChunkTime).Milliseconds(), + ChunkIndex: chunkIndex, + Latency: time.Since(lastChunkTime).Milliseconds(), } if postResponseConverter != nil { @@ -1200,11 +1120,8 @@ func HandleGeminiResponsesStream( continue } finalResponse.ExtraFields = schemas.BifrostResponseExtraFields{ - RequestType: schemas.ResponsesStreamRequest, - Provider: providerName, - ModelRequested: model, - ChunkIndex: chunkIndex, - Latency: time.Since(lastChunkTime).Milliseconds(), + ChunkIndex: chunkIndex, + Latency: time.Since(lastChunkTime).Milliseconds(), } if postResponseConverter != nil { @@ -1249,8 +1166,7 @@ func (provider *GeminiProvider) Embedding(ctx *schemas.BifrostContext, key schem request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToGeminiEmbeddingRequest(request), nil - }, - providerName) + }) if err != nil { return nil, err } @@ -1301,17 +1217,13 @@ func (provider *GeminiProvider) Embedding(ctx *schemas.BifrostContext, key schem if resp.StatusCode() != fasthttp.StatusOK { providerUtils.MaterializeStreamErrorBody(ctx, resp) provider.logger.Debug(fmt.Sprintf("error from %s provider: %s", providerName, string(resp.Body()))) - parsedErr := providerUtils.EnrichError(ctx, parseGeminiError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - Model: request.Model, - RequestType: schemas.EmbeddingRequest, - }), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) + parsedErr := providerUtils.EnrichError(ctx, parseGeminiError(resp), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) wait() fasthttp.ReleaseResponse(resp) return nil, parsedErr } - body, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, providerName, provider.logger) + body, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, provider.logger) if decodeErr != nil { wait() fasthttp.ReleaseResponse(resp) @@ -1324,9 +1236,6 @@ func (provider *GeminiProvider) 
Embedding(ctx *schemas.BifrostContext, key schem return &schemas.BifrostEmbeddingResponse{ Model: request.Model, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.EmbeddingRequest, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerResponseHeaders, }, @@ -1348,12 +1257,9 @@ func (provider *GeminiProvider) Embedding(ctx *schemas.BifrostContext, key schem bifrostResponse := ToBifrostEmbeddingResponse(&geminiResponse, request.Model) if bifrostResponse == nil { return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, - fmt.Errorf("failed to convert Gemini embedding response to Bifrost format"), providerName) + fmt.Errorf("failed to convert Gemini embedding response to Bifrost format")) } - bifrostResponse.ExtraFields.Provider = providerName - bifrostResponse.ExtraFields.ModelRequested = request.Model - bifrostResponse.ExtraFields.RequestType = schemas.EmbeddingRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() // Set raw request if enabled @@ -1382,18 +1288,13 @@ func (provider *GeminiProvider) Speech(ctx *schemas.BifrostContext, key schemas. 
request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToGeminiSpeechRequest(request) - }, - provider.GetProviderKey()) + }) if err != nil { return nil, err } // Use common request function - geminiResponse, rawResponse, latency, providerResponseHeaders, bifrostErr := provider.completeRequest(ctx, request.Model, key, jsonData, ":generateContent", &providerUtils.RequestMetadata{ - Provider: provider.GetProviderKey(), - Model: request.Model, - RequestType: schemas.SpeechRequest, - }) + geminiResponse, rawResponse, latency, providerResponseHeaders, bifrostErr := provider.completeRequest(ctx, request.Model, key, jsonData, ":generateContent") if providerResponseHeaders != nil { ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders) } @@ -1405,9 +1306,6 @@ func (provider *GeminiProvider) Speech(ctx *schemas.BifrostContext, key schemas. if isLargeResp, _ := ctx.Value(schemas.BifrostContextKeyLargeResponseMode).(bool); isLargeResp { return &schemas.BifrostSpeechResponse{ ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: provider.GetProviderKey(), - ModelRequested: request.Model, - RequestType: schemas.SpeechRequest, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerResponseHeaders, }, @@ -1419,13 +1317,10 @@ func (provider *GeminiProvider) Speech(ctx *schemas.BifrostContext, key schemas. 
} response, convErr := geminiResponse.ToBifrostSpeechResponse(ctx) if convErr != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, convErr, provider.GetProviderKey()) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, convErr), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } // Set ExtraFields - response.ExtraFields.Provider = provider.GetProviderKey() - response.ExtraFields.ModelRequested = request.Model - response.ExtraFields.RequestType = schemas.SpeechRequest response.ExtraFields.Latency = latency.Milliseconds() response.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -1452,16 +1347,13 @@ func (provider *GeminiProvider) SpeechStream(ctx *schemas.BifrostContext, postHo return nil, err } - providerName := provider.GetProviderKey() - // Prepare request body using speech-specific function jsonBody, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToGeminiSpeechRequest(request) - }, - provider.GetProviderKey()) + }) if bifrostErr != nil { return nil, bifrostErr } @@ -1507,9 +1399,9 @@ func (provider *GeminiProvider) SpeechStream(ctx *schemas.BifrostContext, postHo }, jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, providerName), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } - return nil, providerUtils.EnrichError(ctx, 
providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, providerName), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } // Extract provider response headers before status check so error responses also forward them @@ -1518,11 +1410,7 @@ func (provider *GeminiProvider) SpeechStream(ctx *schemas.BifrostContext, postHo // Check for HTTP errors if resp.StatusCode() != fasthttp.StatusOK { defer providerUtils.ReleaseStreamingResponse(resp) - return nil, providerUtils.EnrichError(ctx, parseGeminiError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - Model: request.Model, - RequestType: schemas.SpeechStreamRequest, - }), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, parseGeminiError(resp), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } // Large payload streaming passthrough — pipe raw upstream SSE to client @@ -1541,9 +1429,9 @@ func (provider *GeminiProvider) SpeechStream(ctx *schemas.BifrostContext, postHo go func() { defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.SpeechStreamRequest, provider.logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.logger) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.SpeechStreamRequest, provider.logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.logger) } close(responseChan) }() @@ -1583,7 +1471,7 @@ func (provider *GeminiProvider) SpeechStream(ctx *schemas.BifrostContext, postHo } 
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) provider.logger.Warn("Error reading stream: %v", readErr) - providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, schemas.SpeechStreamRequest, providerName, request.Model, provider.logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, provider.logger) return } break @@ -1603,11 +1491,6 @@ func (provider *GeminiProvider) SpeechStream(ctx *schemas.BifrostContext, postHo Message: err.Error(), Error: err, }, - ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.SpeechStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - }, } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger) @@ -1658,11 +1541,8 @@ func (provider *GeminiProvider) SpeechStream(ctx *schemas.BifrostContext, postHo Type: schemas.SpeechStreamResponseTypeDelta, Audio: audioChunk, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.SpeechStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - ChunkIndex: chunkIndex, - Latency: time.Since(lastChunkTime).Milliseconds(), + ChunkIndex: chunkIndex, + Latency: time.Since(lastChunkTime).Milliseconds(), }, } lastChunkTime = time.Now() @@ -1679,11 +1559,8 @@ func (provider *GeminiProvider) SpeechStream(ctx *schemas.BifrostContext, postHo Type: schemas.SpeechStreamResponseTypeDone, Usage: usage, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.SpeechStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - ChunkIndex: chunkIndex + 1, - Latency: time.Since(startTime).Milliseconds(), + ChunkIndex: chunkIndex + 1, + Latency: time.Since(startTime).Milliseconds(), }, } response.BackfillParams(request) @@ -1711,18 +1588,13 @@ func (provider *GeminiProvider) Transcription(ctx *schemas.BifrostContext, key s request, func() 
(providerUtils.RequestBodyWithExtraParams, error) { return ToGeminiTranscriptionRequest(request), nil - }, - provider.GetProviderKey()) + }) if bifrostErr != nil { return nil, bifrostErr } // Use common request function - geminiResponse, rawResponse, latency, providerResponseHeaders, bifrostErr := provider.completeRequest(ctx, request.Model, key, jsonData, ":generateContent", &providerUtils.RequestMetadata{ - Provider: provider.GetProviderKey(), - Model: request.Model, - RequestType: schemas.TranscriptionRequest, - }) + geminiResponse, rawResponse, latency, providerResponseHeaders, bifrostErr := provider.completeRequest(ctx, request.Model, key, jsonData, ":generateContent") if providerResponseHeaders != nil { ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders) } @@ -1734,9 +1606,6 @@ func (provider *GeminiProvider) Transcription(ctx *schemas.BifrostContext, key s if isLargeResp, _ := ctx.Value(schemas.BifrostContextKeyLargeResponseMode).(bool); isLargeResp { return &schemas.BifrostTranscriptionResponse{ ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: provider.GetProviderKey(), - ModelRequested: request.Model, - RequestType: schemas.TranscriptionRequest, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerResponseHeaders, }, @@ -1746,9 +1615,6 @@ func (provider *GeminiProvider) Transcription(ctx *schemas.BifrostContext, key s response := geminiResponse.ToBifrostTranscriptionResponse() // Set ExtraFields - response.ExtraFields.Provider = provider.GetProviderKey() - response.ExtraFields.ModelRequested = request.Model - response.ExtraFields.RequestType = schemas.TranscriptionRequest response.ExtraFields.Latency = latency.Milliseconds() response.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -1770,16 +1636,13 @@ func (provider *GeminiProvider) TranscriptionStream(ctx *schemas.BifrostContext, return nil, err } - providerName := provider.GetProviderKey() - // Prepare request body using 
transcription-specific function jsonBody, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToGeminiTranscriptionRequest(request), nil - }, - provider.GetProviderKey()) + }) if bifrostErr != nil { return nil, bifrostErr } @@ -1825,9 +1688,9 @@ func (provider *GeminiProvider) TranscriptionStream(ctx *schemas.BifrostContext, }, jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, providerName), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, provider.GetProviderKey()), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } // Extract provider response headers before status check so error responses also forward them @@ -1836,11 +1699,7 @@ func (provider *GeminiProvider) TranscriptionStream(ctx *schemas.BifrostContext, // Check for HTTP errors if resp.StatusCode() != fasthttp.StatusOK { defer providerUtils.ReleaseStreamingResponse(resp) - return nil, providerUtils.EnrichError(ctx, parseGeminiError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - Model: request.Model, - RequestType: schemas.TranscriptionStreamRequest, - }), jsonBody, nil, provider.sendBackRawRequest, 
provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, parseGeminiError(resp), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } // Large payload streaming passthrough — pipe raw upstream SSE to client @@ -1859,9 +1718,9 @@ func (provider *GeminiProvider) TranscriptionStream(ctx *schemas.BifrostContext, go func() { defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.TranscriptionStreamRequest, provider.logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.logger) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.TranscriptionStreamRequest, provider.logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.logger) } close(responseChan) }() @@ -1901,7 +1760,7 @@ func (provider *GeminiProvider) TranscriptionStream(ctx *schemas.BifrostContext, } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) provider.logger.Warn("Error reading stream: %v", readErr) - providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, schemas.TranscriptionStreamRequest, providerName, request.Model, provider.logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, provider.logger) return } break @@ -1920,11 +1779,6 @@ func (provider *GeminiProvider) TranscriptionStream(ctx *schemas.BifrostContext, Message: err.Error(), Error: err, }, - ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.TranscriptionStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - }, } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger) @@ -1969,11 +1823,8 @@ func (provider 
*GeminiProvider) TranscriptionStream(ctx *schemas.BifrostContext, Type: schemas.TranscriptionStreamResponseTypeDelta, Delta: &deltaText, // Delta text for this chunk ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.TranscriptionStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - ChunkIndex: chunkIndex, - Latency: time.Since(lastChunkTime).Milliseconds(), + ChunkIndex: chunkIndex, + Latency: time.Since(lastChunkTime).Milliseconds(), }, } lastChunkTime = time.Now() @@ -1996,11 +1847,8 @@ func (provider *GeminiProvider) TranscriptionStream(ctx *schemas.BifrostContext, TotalTokens: usage.TotalTokens, }, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.TranscriptionStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - ChunkIndex: chunkIndex + 1, - Latency: time.Since(startTime).Milliseconds(), + ChunkIndex: chunkIndex + 1, + Latency: time.Since(startTime).Milliseconds(), }, } @@ -2033,18 +1881,13 @@ func (provider *GeminiProvider) ImageGeneration(ctx *schemas.BifrostContext, key request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToGeminiImageGenerationRequest(request), nil - }, - provider.GetProviderKey()) + }) if bifrostErr != nil { return nil, bifrostErr } // Use common request function - geminiResponse, rawResponse, latency, providerResponseHeaders, bifrostErr := provider.completeRequest(ctx, request.Model, key, jsonData, ":generateContent", &providerUtils.RequestMetadata{ - Provider: provider.GetProviderKey(), - Model: request.Model, - RequestType: schemas.ImageGenerationRequest, - }) + geminiResponse, rawResponse, latency, providerResponseHeaders, bifrostErr := provider.completeRequest(ctx, request.Model, key, jsonData, ":generateContent") if providerResponseHeaders != nil { ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders) } @@ -2056,9 +1899,6 @@ func (provider *GeminiProvider) ImageGeneration(ctx 
*schemas.BifrostContext, key if isLargeResp, _ := ctx.Value(schemas.BifrostContextKeyLargeResponseMode).(bool); isLargeResp { return &schemas.BifrostImageGenerationResponse{ ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: provider.GetProviderKey(), - ModelRequested: request.Model, - RequestType: schemas.ImageGenerationRequest, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerResponseHeaders, }, @@ -2067,25 +1907,16 @@ func (provider *GeminiProvider) ImageGeneration(ctx *schemas.BifrostContext, key response, bifrostErr := geminiResponse.ToBifrostImageGenerationResponse() if bifrostErr != nil { - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: provider.GetProviderKey(), - ModelRequested: request.Model, - RequestType: schemas.ImageGenerationRequest, - } return nil, bifrostErr } if response == nil { return nil, providerUtils.NewBifrostOperationError( "failed to convert Gemini image generation response", fmt.Errorf("ToBifrostImageGenerationResponse returned nil response"), - provider.GetProviderKey(), ) } // Set ExtraFields - response.ExtraFields.Provider = provider.GetProviderKey() - response.ExtraFields.ModelRequested = request.Model - response.ExtraFields.RequestType = schemas.ImageGenerationRequest response.ExtraFields.Latency = latency.Milliseconds() response.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -2102,16 +1933,13 @@ func (provider *GeminiProvider) ImageGeneration(ctx *schemas.BifrostContext, key // handleImagenImageGeneration handles Imagen model requests using Vertex AI endpoint with API key auth func (provider *GeminiProvider) handleImagenImageGeneration(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostImageGenerationRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - // Prepare Imagen request body jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, request, func() 
(providerUtils.RequestBodyWithExtraParams, error) { return ToImagenImageGenerationRequest(request), nil - }, - providerName) + }) if bifrostErr != nil { return nil, bifrostErr } @@ -2151,16 +1979,11 @@ func (provider *GeminiProvider) handleImagenImageGeneration(ctx *schemas.Bifrost // Handle error response if resp.StatusCode() != fasthttp.StatusOK { providerUtils.MaterializeStreamErrorBody(ctx, resp) - provider.logger.Debug(fmt.Sprintf("error from %s provider: %s", providerName, string(resp.Body()))) - return nil, providerUtils.EnrichError(ctx, parseGeminiError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - Model: request.Model, - RequestType: schemas.ImageGenerationRequest, - }), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, parseGeminiError(resp), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } // Parse Imagen response - body, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, providerName, provider.logger) + body, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, provider.logger) if decodeErr != nil { return nil, decodeErr } @@ -2168,10 +1991,7 @@ func (provider *GeminiProvider) handleImagenImageGeneration(ctx *schemas.Bifrost respOwned = false return &schemas.BifrostImageGenerationResponse{ ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.ImageGenerationRequest, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, }, nil } @@ -2183,9 +2003,6 @@ func (provider *GeminiProvider) handleImagenImageGeneration(ctx *schemas.Bifrost } // Convert to Bifrost format response := imagenResponse.ToBifrostImageGenerationResponse() - response.ExtraFields.Provider = providerName - response.ExtraFields.ModelRequested = request.Model - response.ExtraFields.RequestType = 
schemas.ImageGenerationRequest response.ExtraFields.Latency = latency.Milliseconds() if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { @@ -2210,8 +2027,6 @@ func (provider *GeminiProvider) ImageEdit(ctx *schemas.BifrostContext, key schem return nil, err } - providerName := provider.GetProviderKey() - // Handle Imagen models using :predict endpoint if schemas.IsImagenModel(request.Model) { jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody( @@ -2219,8 +2034,7 @@ func (provider *GeminiProvider) ImageEdit(ctx *schemas.BifrostContext, key schem request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToImagenImageEditRequest(request), nil - }, - providerName) + }) if bifrostErr != nil { return nil, bifrostErr } @@ -2255,15 +2069,10 @@ func (provider *GeminiProvider) ImageEdit(ctx *schemas.BifrostContext, key schem if resp.StatusCode() != fasthttp.StatusOK { providerUtils.MaterializeStreamErrorBody(ctx, resp) - provider.logger.Debug(fmt.Sprintf("error from %s provider: %s", providerName, string(resp.Body()))) - return nil, providerUtils.EnrichError(ctx, parseGeminiError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - Model: request.Model, - RequestType: schemas.ImageEditRequest, - }), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, parseGeminiError(resp), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } - body, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, providerName, provider.logger) + body, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, provider.logger) if decodeErr != nil { return nil, decodeErr } @@ -2271,10 +2080,7 @@ func (provider *GeminiProvider) ImageEdit(ctx *schemas.BifrostContext, key schem imagenRespOwned = false return &schemas.BifrostImageGenerationResponse{ ExtraFields: 
schemas.BifrostResponseExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.ImageEditRequest, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, }, nil } @@ -2286,9 +2092,6 @@ func (provider *GeminiProvider) ImageEdit(ctx *schemas.BifrostContext, key schem } response := imagenResponse.ToBifrostImageGenerationResponse() - response.ExtraFields.Provider = providerName - response.ExtraFields.ModelRequested = request.Model - response.ExtraFields.RequestType = schemas.ImageEditRequest response.ExtraFields.Latency = latency.Milliseconds() if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { @@ -2307,18 +2110,13 @@ func (provider *GeminiProvider) ImageEdit(ctx *schemas.BifrostContext, key schem request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToGeminiImageEditRequest(request), nil - }, - providerName) + }) if bifrostErr != nil { return nil, bifrostErr } // Use common request function - geminiResponse, rawResponse, latency, providerResponseHeaders, bifrostErr := provider.completeRequest(ctx, request.Model, key, jsonData, ":generateContent", &providerUtils.RequestMetadata{ - Provider: providerName, - Model: request.Model, - RequestType: schemas.ImageEditRequest, - }) + geminiResponse, rawResponse, latency, providerResponseHeaders, bifrostErr := provider.completeRequest(ctx, request.Model, key, jsonData, ":generateContent") if providerResponseHeaders != nil { ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders) } @@ -2330,9 +2128,6 @@ func (provider *GeminiProvider) ImageEdit(ctx *schemas.BifrostContext, key schem if isLargeResp, _ := ctx.Value(schemas.BifrostContextKeyLargeResponseMode).(bool); isLargeResp { return &schemas.BifrostImageGenerationResponse{ ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.ImageEditRequest, Latency: 
latency.Milliseconds(), ProviderResponseHeaders: providerResponseHeaders, }, @@ -2341,25 +2136,16 @@ func (provider *GeminiProvider) ImageEdit(ctx *schemas.BifrostContext, key schem response, bifrostErr := geminiResponse.ToBifrostImageGenerationResponse() if bifrostErr != nil { - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.ImageEditRequest, - } return nil, bifrostErr } if response == nil { return nil, providerUtils.NewBifrostOperationError( "failed to convert Gemini image edit response", fmt.Errorf("ToBifrostImageGenerationResponse returned nil response"), - providerName, ) } // Set ExtraFields - response.ExtraFields.Provider = providerName - response.ExtraFields.ModelRequested = request.Model - response.ExtraFields.RequestType = schemas.ImageEditRequest response.ExtraFields.Latency = latency.Milliseconds() response.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -2391,7 +2177,6 @@ func (provider *GeminiProvider) VideoGeneration(ctx *schemas.BifrostContext, key return nil, err } - providerName := provider.GetProviderKey() model := bifrostReq.Model jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody( @@ -2400,7 +2185,6 @@ func (provider *GeminiProvider) VideoGeneration(ctx *schemas.BifrostContext, key func() (providerUtils.RequestBodyWithExtraParams, error) { return ToGeminiVideoGenerationRequest(bifrostReq) }, - providerName, ) if bifrostErr != nil { return nil, bifrostErr @@ -2433,17 +2217,13 @@ func (provider *GeminiProvider) VideoGeneration(ctx *schemas.BifrostContext, key // Handle error response if resp.StatusCode() != fasthttp.StatusOK { - return nil, providerUtils.EnrichError(ctx, parseGeminiError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - Model: model, - RequestType: schemas.VideoGenerationRequest, - }), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, 
providerUtils.EnrichError(ctx, parseGeminiError(resp), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } // use handle provider response body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } // Parse response @@ -2459,12 +2239,9 @@ func (provider *GeminiProvider) VideoGeneration(ctx *schemas.BifrostContext, key return nil, bifrostErr } - bifrostResp.ID = providerUtils.AddVideoIDProviderSuffix(bifrostResp.ID, providerName) + bifrostResp.ID = providerUtils.AddVideoIDProviderSuffix(bifrostResp.ID, provider.GetProviderKey()) bifrostResp.ExtraFields.Latency = latency.Milliseconds() - bifrostResp.ExtraFields.Provider = providerName - bifrostResp.ExtraFields.ModelRequested = model - bifrostResp.ExtraFields.RequestType = schemas.VideoGenerationRequest if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { bifrostResp.ExtraFields.RawRequest = rawRequest @@ -2482,10 +2259,9 @@ func (provider *GeminiProvider) VideoRetrieve(ctx *schemas.BifrostContext, key s return nil, err } - providerName := provider.GetProviderKey() operationID := bifrostReq.ID - operationID = providerUtils.StripVideoIDProviderSuffix(operationID, providerName) + operationID = providerUtils.StripVideoIDProviderSuffix(operationID, provider.GetProviderKey()) // Create HTTP request req := fasthttp.AcquireRequest() @@ -2510,10 +2286,8 @@ func (provider *GeminiProvider) VideoRetrieve(ctx *schemas.BifrostContext, key s // Handle error response if resp.StatusCode() != fasthttp.StatusOK { - return nil, providerUtils.EnrichError(ctx, parseGeminiError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - RequestType: 
schemas.VideoRetrieveRequest, - }), nil, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) + respBody := append([]byte(nil), resp.Body()...) + return nil, providerUtils.EnrichError(ctx, parseGeminiError(resp), nil, respBody, provider.sendBackRawRequest, provider.sendBackRawResponse) } // Parse response @@ -2527,12 +2301,10 @@ func (provider *GeminiProvider) VideoRetrieve(ctx *schemas.BifrostContext, key s if bifrostErr != nil { return nil, bifrostErr } - bifrostResp.ID = providerUtils.AddVideoIDProviderSuffix(bifrostResp.ID, providerName) + bifrostResp.ID = providerUtils.AddVideoIDProviderSuffix(bifrostResp.ID, provider.GetProviderKey()) // Add extra fields bifrostResp.ExtraFields.Latency = latency.Milliseconds() - bifrostResp.ExtraFields.Provider = providerName - bifrostResp.ExtraFields.RequestType = schemas.VideoRetrieveRequest if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) { bifrostResp.ExtraFields.RawResponse = rawResponse @@ -2546,9 +2318,8 @@ func (provider *GeminiProvider) VideoDownload(ctx *schemas.BifrostContext, key s if err := providerUtils.CheckOperationAllowed(schemas.Gemini, provider.customProviderConfig, schemas.VideoDownloadRequest); err != nil { return nil, err } - providerName := provider.GetProviderKey() if request == nil || request.ID == "" { - return nil, providerUtils.NewBifrostOperationError("video_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("video_id is required", nil) } // Retrieve operation first so download behavior follows retrieve status. 
bifrostVideoRetrieveRequest := &schemas.BifrostVideoRetrieveRequest{ @@ -2563,11 +2334,10 @@ func (provider *GeminiProvider) VideoDownload(ctx *schemas.BifrostContext, key s return nil, providerUtils.NewBifrostOperationError( fmt.Sprintf("video not ready, current status: %s", videoResp.Status), nil, - providerName, ) } if len(videoResp.Videos) == 0 { - return nil, providerUtils.NewBifrostOperationError("video URL not available", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("video URL not available", nil) } var content []byte contentType := "video/mp4" @@ -2578,7 +2348,7 @@ func (provider *GeminiProvider) VideoDownload(ctx *schemas.BifrostContext, key s startTime := time.Now() decoded, err := base64.StdEncoding.DecodeString(*videoResp.Videos[0].Base64Data) if err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to decode base64 video data", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to decode base64 video data", err) } content = decoded latency = time.Since(startTime) @@ -2609,17 +2379,16 @@ func (provider *GeminiProvider) VideoDownload(ctx *schemas.BifrostContext, key s return nil, providerUtils.NewBifrostOperationError( fmt.Sprintf("failed to download video: HTTP %d", resp.StatusCode()), nil, - providerName, ) } body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } contentType = string(resp.Header.ContentType()) content = append([]byte(nil), body...) 
} else { - return nil, providerUtils.NewBifrostOperationError("invalid video output type", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid video output type", nil) } bifrostResp := &schemas.BifrostVideoDownloadResponse{ VideoID: request.ID, @@ -2628,8 +2397,6 @@ func (provider *GeminiProvider) VideoDownload(ctx *schemas.BifrostContext, key s } bifrostResp.ExtraFields.Latency = latency.Milliseconds() - bifrostResp.ExtraFields.Provider = providerName - bifrostResp.ExtraFields.RequestType = schemas.VideoDownloadRequest return bifrostResp, nil } @@ -2659,18 +2426,16 @@ func (provider *GeminiProvider) BatchCreate(ctx *schemas.BifrostContext, key sch return nil, err } - providerName := provider.GetProviderKey() - // Validate that either InputFileID or Requests is provided, but not both hasFileInput := request.InputFileID != "" hasInlineRequests := len(request.Requests) > 0 if !hasFileInput && !hasInlineRequests { - return nil, providerUtils.NewBifrostOperationError("either input_file_id or requests must be provided", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("either input_file_id or requests must be provided", nil) } if hasFileInput && hasInlineRequests { - return nil, providerUtils.NewBifrostOperationError("cannot specify both input_file_id and requests", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("cannot specify both input_file_id and requests", nil) } // Build the batch request with proper nested structure @@ -2703,12 +2468,12 @@ func (provider *GeminiProvider) BatchCreate(ctx *schemas.BifrostContext, key sch if rawMessages, ok := body["messages"]; ok { messagesBytes, err := providerUtils.MarshalSorted(rawMessages) if err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to marshal messages", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to marshal messages", err) } var chatMessages []schemas.ChatMessage err = 
sonic.Unmarshal(messagesBytes, &chatMessages) if err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to unmarshal messages", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to unmarshal messages", err) } contents, systemInstruction := convertBifrostMessagesToGemini(chatMessages) @@ -2718,11 +2483,11 @@ func (provider *GeminiProvider) BatchCreate(ctx *schemas.BifrostContext, key sch // If no "messages" key, try direct unmarshal (already in Gemini format) requestBytes, err := providerUtils.MarshalSorted(body) if err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to marshal gemini request", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to marshal gemini request", err) } err = sonic.Unmarshal(requestBytes, &geminiReq) if err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to unmarshal gemini request", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to unmarshal gemini request", err) } } @@ -2746,7 +2511,7 @@ func (provider *GeminiProvider) BatchCreate(ctx *schemas.BifrostContext, key sch jsonData, err := providerUtils.MarshalSorted(batchReq) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } // Create HTTP request @@ -2784,31 +2549,27 @@ func (provider *GeminiProvider) BatchCreate(ctx *schemas.BifrostContext, key sch // Handle error response if resp.StatusCode() != fasthttp.StatusOK { - return nil, providerUtils.EnrichError(ctx, parseGeminiError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - Model: model, - RequestType: schemas.BatchCreateRequest, - }), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, parseGeminiError(resp), jsonData, nil, 
provider.sendBackRawRequest, provider.sendBackRawResponse) } body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } // Parse the batch job response var geminiResp GeminiBatchJobResponse if err := sonic.Unmarshal(body, &geminiResp); err != nil { provider.logger.Error("gemini batch create unmarshal error: " + err.Error()) - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerName), jsonData, body, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err), jsonData, body, provider.sendBackRawRequest, provider.sendBackRawResponse) } // Check for metadata if geminiResp.Metadata == nil { - return nil, providerUtils.NewBifrostOperationError("gemini batch response missing metadata", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("gemini batch response missing metadata", nil) } // Check for batch stats if geminiResp.Metadata.BatchStats == nil { - return nil, providerUtils.NewBifrostOperationError("gemini batch response missing batch stats", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("gemini batch response missing batch stats", nil) } // Calculate request counts based on response totalRequests := geminiResp.Metadata.BatchStats.RequestCount @@ -2851,9 +2612,7 @@ func (provider *GeminiProvider) BatchCreate(ctx *schemas.BifrostContext, key sch Failed: failedCount, }, ExtraFields: 
schemas.BifrostResponseExtraFields{ - RequestType: schemas.BatchCreateRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } @@ -2872,8 +2631,6 @@ func (provider *GeminiProvider) BatchCreate(ctx *schemas.BifrostContext, key sch // batchListByKey lists batch jobs for Gemini for a single key. func (provider *GeminiProvider) batchListByKey(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostBatchListRequest) (*schemas.BifrostBatchListResponse, time.Duration, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - // Create HTTP request req := fasthttp.AcquireRequest() resp := fasthttp.AcquireResponse() @@ -2921,26 +2678,21 @@ func (provider *GeminiProvider) batchListByKey(ctx *schemas.BifrostContext, key Data: []schemas.BifrostBatchRetrieveResponse{}, HasMore: false, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.BatchListRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, }, latency, nil } - return nil, latency, parseGeminiError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - RequestType: schemas.BatchListRequest, - }) + return nil, latency, parseGeminiError(resp) } body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, latency, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + return nil, latency, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } var geminiResp GeminiBatchListResponse if err := sonic.Unmarshal(body, &geminiResp); err != nil { - return nil, latency, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerName) + return nil, latency, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err) } // Convert to Bifrost format @@ -2952,10 +2704,7 @@ func (provider *GeminiProvider) batchListByKey(ctx 
*schemas.BifrostContext, key Status: ToBifrostBatchStatus(batch.Metadata.State), CreatedAt: parseGeminiTimestamp(batch.Metadata.CreateTime), OperationName: &batch.Name, - ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.BatchListRequest, - Provider: providerName, - }, + ExtraFields: schemas.BifrostResponseExtraFields{}, }) } @@ -2971,9 +2720,7 @@ func (provider *GeminiProvider) batchListByKey(ctx *schemas.BifrostContext, key HasMore: hasMore, NextCursor: nextCursor, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.BatchListRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, }, latency, nil } @@ -2987,16 +2734,14 @@ func (provider *GeminiProvider) BatchList(ctx *schemas.BifrostContext, keys []sc return nil, err } - providerName := provider.GetProviderKey() - if len(keys) == 0 { - return nil, providerUtils.NewBifrostOperationError("no keys provided for batch list", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("no keys provided for batch list", nil) } // Initialize serial pagination helper (Gemini uses PageToken for pagination) helper, err := providerUtils.NewSerialListHelper(keys, request.PageToken, provider.logger) if err != nil { - return nil, providerUtils.NewBifrostOperationError("invalid pagination cursor", err, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid pagination cursor", err) } // Get current key to query @@ -3007,10 +2752,6 @@ func (provider *GeminiProvider) BatchList(ctx *schemas.BifrostContext, keys []sc Object: "list", Data: []schemas.BifrostBatchRetrieveResponse{}, HasMore: false, - ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.BatchListRequest, - Provider: providerName, - }, }, nil } @@ -3042,9 +2783,7 @@ func (provider *GeminiProvider) BatchList(ctx *schemas.BifrostContext, keys []sc Data: resp.Data, HasMore: hasMore, ExtraFields: schemas.BifrostResponseExtraFields{ - 
RequestType: schemas.BatchListRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } if nextCursor != "" { @@ -3056,8 +2795,6 @@ func (provider *GeminiProvider) BatchList(ctx *schemas.BifrostContext, keys []sc // batchRetrieveByKey retrieves a specific batch job for Gemini for a single key. func (provider *GeminiProvider) batchRetrieveByKey(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostBatchRetrieveRequest) (*schemas.BifrostBatchRetrieveResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - // Create HTTP request req := fasthttp.AcquireRequest() resp := fasthttp.AcquireResponse() @@ -3090,20 +2827,17 @@ func (provider *GeminiProvider) batchRetrieveByKey(ctx *schemas.BifrostContext, // Handle error response if resp.StatusCode() != fasthttp.StatusOK { - return nil, parseGeminiError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - RequestType: schemas.BatchRetrieveRequest, - }) + return nil, parseGeminiError(resp) } body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } var geminiResp GeminiBatchJobResponse if err := sonic.Unmarshal(body, &geminiResp); err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err) } var completedCount, failedCount int @@ -3132,9 +2866,7 @@ func (provider *GeminiProvider) batchRetrieveByKey(ctx *schemas.BifrostContext, Failed: failedCount, }, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.BatchRetrieveRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, }, nil } @@ -3145,14 
+2877,12 @@ func (provider *GeminiProvider) BatchRetrieve(ctx *schemas.BifrostContext, keys return nil, err } - providerName := provider.GetProviderKey() - if request.BatchID == "" { - return nil, providerUtils.NewBifrostOperationError("batch_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("batch_id is required", nil) } if len(keys) == 0 { - return nil, providerUtils.NewBifrostOperationError("no keys provided for batch retrieve", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("no keys provided for batch retrieve", nil) } // Try each key until we find the batch @@ -3171,8 +2901,6 @@ func (provider *GeminiProvider) BatchRetrieve(ctx *schemas.BifrostContext, keys // batchCancelByKey cancels a batch job for Gemini for a single key. func (provider *GeminiProvider) batchCancelByKey(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostBatchCancelRequest) (*schemas.BifrostBatchCancelResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - // Create HTTP request req := fasthttp.AcquireRequest() resp := fasthttp.AcquireResponse() @@ -3210,15 +2938,9 @@ func (provider *GeminiProvider) batchCancelByKey(ctx *schemas.BifrostContext, ke if resp.StatusCode() == fasthttp.StatusNotFound || resp.StatusCode() == fasthttp.StatusMethodNotAllowed { // 404 could mean batch not found or cancel not supported // Return the error instead of assuming completed - return nil, parseGeminiError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - RequestType: schemas.BatchCancelRequest, - }) + return nil, parseGeminiError(resp) } - return nil, parseGeminiError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - RequestType: schemas.BatchCancelRequest, - }) + return nil, parseGeminiError(resp) } now := time.Now().Unix() @@ -3228,9 +2950,7 @@ func (provider *GeminiProvider) batchCancelByKey(ctx *schemas.BifrostContext, ke Status: schemas.BatchStatusCancelling, 
CancellingAt: &now, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.BatchCancelRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, }, nil } @@ -3242,14 +2962,12 @@ func (provider *GeminiProvider) BatchCancel(ctx *schemas.BifrostContext, keys [] return nil, err } - providerName := provider.GetProviderKey() - if request.BatchID == "" { - return nil, providerUtils.NewBifrostOperationError("batch_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("batch_id is required", nil) } if len(keys) == 0 { - return nil, providerUtils.NewBifrostOperationError("no keys provided for batch cancel", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("no keys provided for batch cancel", nil) } // Try each key until cancellation succeeds @@ -3271,8 +2989,6 @@ func (provider *GeminiProvider) BatchCancel(ctx *schemas.BifrostContext, keys [] // batches.delete indicates the client is no longer interested in the operation result. // It does not cancel the operation. If the server doesn't support this method, it returns UNIMPLEMENTED. 
func (provider *GeminiProvider) batchDeleteByKey(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostBatchDeleteRequest) (*schemas.BifrostBatchDeleteResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - req := fasthttp.AcquireRequest() resp := fasthttp.AcquireResponse() defer fasthttp.ReleaseRequest(req) @@ -3301,10 +3017,7 @@ func (provider *GeminiProvider) batchDeleteByKey(ctx *schemas.BifrostContext, ke } if resp.StatusCode() != fasthttp.StatusOK && resp.StatusCode() != fasthttp.StatusNoContent { - return nil, parseGeminiError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - RequestType: schemas.BatchDeleteRequest, - }) + return nil, parseGeminiError(resp) } return &schemas.BifrostBatchDeleteResponse{ @@ -3312,9 +3025,7 @@ func (provider *GeminiProvider) batchDeleteByKey(ctx *schemas.BifrostContext, ke Object: "batch", Status: schemas.BatchStatusDeleted, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.BatchDeleteRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, }, nil } @@ -3327,14 +3038,12 @@ func (provider *GeminiProvider) BatchDelete(ctx *schemas.BifrostContext, keys [] return nil, err } - providerName := provider.GetProviderKey() - if request.BatchID == "" { - return nil, providerUtils.NewBifrostOperationError("batch_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("batch_id is required", nil) } if len(keys) == 0 { - return nil, providerUtils.NewBifrostOperationError("no keys provided for batch delete", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("no keys provided for batch delete", nil) } var lastError *schemas.BifrostError @@ -3482,8 +3191,6 @@ func readNextSSEDataLine(reader *bufio.Reader, skipInlineData bool) ([]byte, err // batchResultsByKey retrieves batch results for Gemini for a single key. 
func (provider *GeminiProvider) batchResultsByKey(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostBatchResultsRequest) (*schemas.BifrostBatchResultsResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - // We need to get the full batch response with results, so make the API call directly req := fasthttp.AcquireRequest() resp := fasthttp.AcquireResponse() @@ -3517,20 +3224,17 @@ func (provider *GeminiProvider) batchResultsByKey(ctx *schemas.BifrostContext, k // Handle error response if resp.StatusCode() != fasthttp.StatusOK { - return nil, parseGeminiError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - RequestType: schemas.BatchResultsRequest, - }) + return nil, parseGeminiError(resp) } body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } var geminiResp GeminiBatchJobResponse if err := sonic.Unmarshal(body, &geminiResp); err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err) } // Check if batch is still processing @@ -3538,7 +3242,6 @@ func (provider *GeminiProvider) batchResultsByKey(ctx *schemas.BifrostContext, k return nil, providerUtils.NewBifrostOperationError( fmt.Sprintf("batch %s is still processing (state: %s), results not yet available", request.BatchID, geminiResp.Metadata.State), nil, - providerName, ) } @@ -3626,9 +3329,7 @@ func (provider *GeminiProvider) batchResultsByKey(ctx *schemas.BifrostContext, k BatchID: request.BatchID, Results: results, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.BatchResultsRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: 
latency.Milliseconds(), }, } @@ -3647,14 +3348,12 @@ func (provider *GeminiProvider) BatchResults(ctx *schemas.BifrostContext, keys [ return nil, err } - providerName := provider.GetProviderKey() - if request.BatchID == "" { - return nil, providerUtils.NewBifrostOperationError("batch_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("batch_id is required", nil) } if len(keys) == 0 { - return nil, providerUtils.NewBifrostOperationError("no keys provided for batch results", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("no keys provided for batch results", nil) } // Try each key until we get results @@ -3678,10 +3377,8 @@ func (provider *GeminiProvider) FileUpload(ctx *schemas.BifrostContext, key sche return nil, err } - providerName := provider.GetProviderKey() - if len(request.File) == 0 { - return nil, providerUtils.NewBifrostOperationError("file content is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("file content is required", nil) } // Create multipart request @@ -3691,14 +3388,14 @@ func (provider *GeminiProvider) FileUpload(ctx *schemas.BifrostContext, key sche // Add file metadata as JSON metadataField, err := writer.CreateFormField("metadata") if err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to create metadata field", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to create metadata field", err) } metadataJSON, err := providerUtils.SetJSONField([]byte(`{}`), "file.displayName", request.Filename) if err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to marshal metadata", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to marshal metadata", err) } if _, err := metadataField.Write(metadataJSON); err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to write metadata", err, providerName) + return nil, 
providerUtils.NewBifrostOperationError("failed to write metadata", err) } // Add file content @@ -3708,14 +3405,14 @@ func (provider *GeminiProvider) FileUpload(ctx *schemas.BifrostContext, key sche } part, err := writer.CreateFormFile("file", filename) if err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to create form file", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to create form file", err) } if _, err := part.Write(request.File); err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to write file content", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to write file content", err) } if err := writer.Close(); err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to close multipart writer", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to close multipart writer", err) } // Create request @@ -3746,15 +3443,12 @@ func (provider *GeminiProvider) FileUpload(ctx *schemas.BifrostContext, key sche // Handle error response if resp.StatusCode() != fasthttp.StatusOK && resp.StatusCode() != fasthttp.StatusCreated { - return nil, parseGeminiError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - RequestType: schemas.FileUploadRequest, - }) + return nil, parseGeminiError(resp) } body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } // Parse response - wrapped in "file" object @@ -3762,7 +3456,7 @@ func (provider *GeminiProvider) FileUpload(ctx *schemas.BifrostContext, key sche File GeminiFileResponse `json:"file"` } if err := sonic.Unmarshal(body, &responseWrapper); err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, 
providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err) } geminiResp := responseWrapper.File @@ -3798,17 +3492,13 @@ func (provider *GeminiProvider) FileUpload(ctx *schemas.BifrostContext, key sche StorageURI: geminiResp.URI, ExpiresAt: expiresAt, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileUploadRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, }, nil } // fileListByKey lists files from Gemini for a single key. func (provider *GeminiProvider) fileListByKey(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostFileListRequest) (*schemas.BifrostFileListResponse, time.Duration, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - // Create request req := fasthttp.AcquireRequest() resp := fasthttp.AcquireResponse() @@ -3845,20 +3535,17 @@ func (provider *GeminiProvider) fileListByKey(ctx *schemas.BifrostContext, key s // Handle error response if resp.StatusCode() != fasthttp.StatusOK { - return nil, latency, parseGeminiError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - RequestType: schemas.FileListRequest, - }) + return nil, latency, parseGeminiError(resp) } body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, latency, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + return nil, latency, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } var geminiResp GeminiFileListResponse if err := sonic.Unmarshal(body, &geminiResp); err != nil { - return nil, latency, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerName) + return nil, latency, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err) } // Convert to Bifrost response @@ -3867,9 +3554,7 @@ func (provider *GeminiProvider) fileListByKey(ctx 
*schemas.BifrostContext, key s Data: make([]schemas.FileObject, len(geminiResp.Files)), HasMore: geminiResp.NextPageToken != "", ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileListRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } @@ -3922,16 +3607,14 @@ func (provider *GeminiProvider) FileList(ctx *schemas.BifrostContext, keys []sch return nil, err } - providerName := provider.GetProviderKey() - if len(keys) == 0 { - return nil, providerUtils.NewBifrostOperationError("no keys provided for file list", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("no keys provided for file list", nil) } // Initialize serial pagination helper helper, err := providerUtils.NewSerialListHelper(keys, request.After, provider.logger) if err != nil { - return nil, providerUtils.NewBifrostOperationError("invalid pagination cursor", err, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid pagination cursor", err) } // Get current key to query @@ -3942,10 +3625,6 @@ func (provider *GeminiProvider) FileList(ctx *schemas.BifrostContext, keys []sch Object: "list", Data: []schemas.FileObject{}, HasMore: false, - ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileListRequest, - Provider: providerName, - }, }, nil } @@ -3977,9 +3656,7 @@ func (provider *GeminiProvider) FileList(ctx *schemas.BifrostContext, keys []sch Data: resp.Data, HasMore: hasMore, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileListRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } if nextCursor != "" { @@ -3991,8 +3668,6 @@ func (provider *GeminiProvider) FileList(ctx *schemas.BifrostContext, keys []sch // fileRetrieveByKey retrieves file metadata from Gemini for a single key. 
func (provider *GeminiProvider) fileRetrieveByKey(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostFileRetrieveRequest) (*schemas.BifrostFileRetrieveResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - // Create request req := fasthttp.AcquireRequest() resp := fasthttp.AcquireResponse() @@ -4023,20 +3698,17 @@ func (provider *GeminiProvider) fileRetrieveByKey(ctx *schemas.BifrostContext, k // Handle error response if resp.StatusCode() != fasthttp.StatusOK { - return nil, parseGeminiError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - RequestType: schemas.FileRetrieveRequest, - }) + return nil, parseGeminiError(resp) } body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } var geminiResp GeminiFileResponse if err := sonic.Unmarshal(body, &geminiResp); err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err) } var sizeBytes int64 @@ -4073,9 +3745,7 @@ func (provider *GeminiProvider) fileRetrieveByKey(ctx *schemas.BifrostContext, k StorageURI: geminiResp.URI, ExpiresAt: expiresAt, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileRetrieveRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, }, nil } @@ -4086,14 +3756,12 @@ func (provider *GeminiProvider) FileRetrieve(ctx *schemas.BifrostContext, keys [ return nil, err } - providerName := provider.GetProviderKey() - if request.FileID == "" { - return nil, providerUtils.NewBifrostOperationError("file_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("file_id is 
required", nil) } if len(keys) == 0 { - return nil, providerUtils.NewBifrostOperationError("no keys provided for file retrieve", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("no keys provided for file retrieve", nil) } // Try each key until we find the file @@ -4113,8 +3781,6 @@ func (provider *GeminiProvider) FileRetrieve(ctx *schemas.BifrostContext, keys [ // fileDeleteByKey deletes a file from Gemini for a single key. func (provider *GeminiProvider) fileDeleteByKey(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostFileDeleteRequest) (*schemas.BifrostFileDeleteResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - // Create request req := fasthttp.AcquireRequest() resp := fasthttp.AcquireResponse() @@ -4145,10 +3811,7 @@ func (provider *GeminiProvider) fileDeleteByKey(ctx *schemas.BifrostContext, key // Handle error response - DELETE returns 200 with empty body on success if resp.StatusCode() != fasthttp.StatusOK && resp.StatusCode() != fasthttp.StatusNoContent { - return nil, parseGeminiError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - RequestType: schemas.FileDeleteRequest, - }) + return nil, parseGeminiError(resp) } return &schemas.BifrostFileDeleteResponse{ @@ -4156,9 +3819,7 @@ func (provider *GeminiProvider) fileDeleteByKey(ctx *schemas.BifrostContext, key Object: "file", Deleted: true, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileDeleteRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, }, nil } @@ -4169,14 +3830,12 @@ func (provider *GeminiProvider) FileDelete(ctx *schemas.BifrostContext, keys []s return nil, err } - providerName := provider.GetProviderKey() - if request.FileID == "" { - return nil, providerUtils.NewBifrostOperationError("file_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("file_id is required", nil) } if len(keys) == 
0 { - return nil, providerUtils.NewBifrostOperationError("no keys provided for file delete", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("no keys provided for file delete", nil) } // Try each key until deletion succeeds @@ -4202,14 +3861,11 @@ func (provider *GeminiProvider) FileContent(ctx *schemas.BifrostContext, keys [] return nil, err } - providerName := provider.GetProviderKey() - // Gemini doesn't support direct file content download // Files are referenced by their URI in requests return nil, providerUtils.NewBifrostOperationError( "Gemini Files API doesn't support direct content download. Use the file URI in your requests instead.", nil, - providerName, ) } @@ -4240,7 +3896,6 @@ func (provider *GeminiProvider) CountTokens(ctx *schemas.BifrostContext, key sch func() (providerUtils.RequestBodyWithExtraParams, error) { return ToGeminiResponsesRequest(request), nil }, - provider.GetProviderKey(), ) if bifrostErr != nil { return nil, bifrostErr @@ -4252,14 +3907,13 @@ func (provider *GeminiProvider) CountTokens(ctx *schemas.BifrostContext, key sch jsonData, _ = providerUtils.DeleteJSONField(jsonData, "systemInstruction") } - providerName := provider.GetProviderKey() req := fasthttp.AcquireRequest() resp := fasthttp.AcquireResponse() defer fasthttp.ReleaseRequest(req) defer fasthttp.ReleaseResponse(resp) if strings.TrimSpace(request.Model) == "" { - return nil, providerUtils.NewBifrostOperationError("model is required for Gemini count tokens request", fmt.Errorf("missing model"), providerName) + return nil, providerUtils.NewBifrostOperationError("model is required for Gemini count tokens request", fmt.Errorf("missing model")) } // Determine native model name (e.g., parse any provider prefix) @@ -4292,15 +3946,12 @@ func (provider *GeminiProvider) CountTokens(ctx *schemas.BifrostContext, key sch } if resp.StatusCode() != fasthttp.StatusOK { - return nil, providerUtils.EnrichError(ctx, parseGeminiError(resp, 
&providerUtils.RequestMetadata{ - Provider: providerName, - RequestType: schemas.CountTokensRequest, - }), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, parseGeminiError(resp), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } responseBody := append([]byte(nil), body...) @@ -4320,9 +3971,6 @@ func (provider *GeminiProvider) CountTokens(ctx *schemas.BifrostContext, key sch response := geminiResponse.ToBifrostCountTokensResponse(request.Model) // Set ExtraFields - response.ExtraFields.Provider = provider.GetProviderKey() - response.ExtraFields.ModelRequested = request.Model - response.ExtraFields.RequestType = schemas.CountTokensRequest response.ExtraFields.Latency = latency.Milliseconds() if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) { @@ -4425,7 +4073,7 @@ func (provider *GeminiProvider) Passthrough( headers := providerUtils.ExtractProviderResponseHeaders(resp) body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to decode response body", err, provider.GetProviderKey()) + return nil, providerUtils.NewBifrostOperationError("failed to decode response body", err) } for k := range headers { if strings.EqualFold(k, "Content-Encoding") || strings.EqualFold(k, "Content-Length") { @@ -4439,9 +4087,6 @@ func (provider *GeminiProvider) Passthrough( Body: body, } - 
bifrostResponse.ExtraFields.Provider = provider.GetProviderKey() - bifrostResponse.ExtraFields.ModelRequested = req.Model - bifrostResponse.ExtraFields.RequestType = schemas.PassthroughRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { @@ -4505,9 +4150,9 @@ func (provider *GeminiProvider) PassthroughStream( } } if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) { - return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, provider.GetProviderKey()) + return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err) } - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, provider.GetProviderKey()) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err) } headers := providerUtils.ExtractProviderResponseHeaders(resp) @@ -4518,7 +4163,6 @@ func (provider *GeminiProvider) PassthroughStream( return nil, providerUtils.NewBifrostOperationError( "provider returned an empty stream body", fmt.Errorf("provider returned an empty stream body"), - provider.GetProviderKey(), ) } @@ -4530,11 +4174,7 @@ func (provider *GeminiProvider) PassthroughStream( // Cancellation must close the raw stream to unblock reads. 
stopCancellation := providerUtils.SetupStreamCancellation(ctx, rawBodyStream, provider.logger) - extraFields := schemas.BifrostResponseExtraFields{ - Provider: provider.GetProviderKey(), - ModelRequested: req.Model, - RequestType: schemas.PassthroughStreamRequest, - } + extraFields := schemas.BifrostResponseExtraFields{} statusCode := resp.StatusCode() if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { @@ -4545,9 +4185,9 @@ func (provider *GeminiProvider) PassthroughStream( go func() { defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, ch, provider.GetProviderKey(), req.Model, schemas.PassthroughStreamRequest, provider.logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, ch, provider.logger) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, ch, provider.GetProviderKey(), req.Model, schemas.PassthroughStreamRequest, provider.logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, ch, provider.logger) } close(ch) }() @@ -4596,7 +4236,7 @@ func (provider *GeminiProvider) PassthroughStream( } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) extraFields.Latency = time.Since(startTime).Milliseconds() - providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, ch, schemas.PassthroughStreamRequest, provider.GetProviderKey(), req.Model, provider.logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, ch, provider.logger) return } } diff --git a/core/providers/gemini/models.go b/core/providers/gemini/models.go index 34f021a5be..7b9f6410eb 100644 --- a/core/providers/gemini/models.go +++ b/core/providers/gemini/models.go @@ -3,6 +3,7 @@ package gemini import ( "strings" + providerUtils "github.com/maximhq/bifrost/core/providers/utils" "github.com/maximhq/bifrost/core/schemas" ) @@ -16,7 +17,7 @@ func toGeminiModelResourceName(modelID string) string { return "models/" + 
modelID } -func (response *GeminiListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels schemas.WhiteList, blacklistedModels schemas.BlackList, unfiltered bool) *schemas.BifrostListModelsResponse { +func (response *GeminiListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels schemas.WhiteList, blacklistedModels schemas.BlackList, aliases map[string]string, unfiltered bool) *schemas.BifrostListModelsResponse { if response == nil { return nil } @@ -25,49 +26,47 @@ func (response *GeminiListModelsResponse) ToBifrostListModelsResponse(providerKe Data: make([]schemas.Model, 0, len(response.Models)), } - if !unfiltered && (allowedModels.IsEmpty() || blacklistedModels.IsBlockAll()) { + pipeline := &providerUtils.ListModelsPipeline{ + AllowedModels: allowedModels, + BlacklistedModels: blacklistedModels, + Aliases: aliases, + Unfiltered: unfiltered, + ProviderKey: providerKey, + MatchFns: providerUtils.DefaultMatchFns(), + } + if pipeline.ShouldEarlyExit() { return bifrostResponse } - includedModels := make(map[string]bool) - for _, model := range response.Models { + included := make(map[string]bool) + for _, model := range response.Models { contextLength := model.InputTokenLimit + model.OutputTokenLimit - // Remove prefix models/ from model.Name + // Gemini returns model names with a "models/" prefix — strip it before filtering + // so that allowedModels entries like "gemini-1.5-pro" match correctly. 
modelName := strings.TrimPrefix(model.Name, "models/") - if !unfiltered && allowedModels.IsRestricted() && !allowedModels.Contains(modelName) { - continue - } - if !unfiltered && blacklistedModels.IsBlocked(modelName) { - continue - } - bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ - ID: string(providerKey) + "/" + modelName, - Name: schemas.Ptr(model.DisplayName), - Description: schemas.Ptr(model.Description), - ContextLength: schemas.Ptr(int(contextLength)), - MaxInputTokens: schemas.Ptr(model.InputTokenLimit), - MaxOutputTokens: schemas.Ptr(model.OutputTokenLimit), - SupportedMethods: model.SupportedGenerationMethods, - }) - includedModels[strings.ToLower(modelName)] = true - } - // Backfill allowed models that were not in the response - if !unfiltered && allowedModels.IsRestricted() { - for _, allowedModel := range allowedModels { - if blacklistedModels.IsBlocked(allowedModel) { - continue + for _, result := range pipeline.FilterModel(modelName) { + entry := schemas.Model{ + ID: string(providerKey) + "/" + result.ResolvedID, + Name: schemas.Ptr(model.DisplayName), + Description: schemas.Ptr(model.Description), + ContextLength: schemas.Ptr(int(contextLength)), + MaxInputTokens: schemas.Ptr(model.InputTokenLimit), + MaxOutputTokens: schemas.Ptr(model.OutputTokenLimit), + SupportedMethods: model.SupportedGenerationMethods, } - if !includedModels[strings.ToLower(allowedModel)] { - bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ - ID: string(providerKey) + "/" + allowedModel, - Name: schemas.Ptr(allowedModel), - }) + if result.AliasValue != "" { + entry.Alias = schemas.Ptr(result.AliasValue) } + bifrostResponse.Data = append(bifrostResponse.Data, entry) + included[strings.ToLower(result.ResolvedID)] = true } } + bifrostResponse.Data = append(bifrostResponse.Data, + pipeline.BackfillModels(included)...) 
+ return bifrostResponse } diff --git a/core/providers/gemini/videos.go b/core/providers/gemini/videos.go index 62ce110c26..43ece90be4 100644 --- a/core/providers/gemini/videos.go +++ b/core/providers/gemini/videos.go @@ -217,7 +217,7 @@ func ToGeminiVideoGenerationRequest(bifrostReq *schemas.BifrostVideoGenerationRe // ToBifrostVideoGenerationResponse converts Gemini operation response to Bifrost format func ToBifrostVideoGenerationResponse(operation *GenerateVideosOperation, model string) (*schemas.BifrostVideoGenerationResponse, *schemas.BifrostError) { if operation == nil { - return nil, providerUtils.NewBifrostOperationError("operation is nil", nil, schemas.Gemini) + return nil, providerUtils.NewBifrostOperationError("operation is nil", nil) } response := &schemas.BifrostVideoGenerationResponse{ diff --git a/core/providers/groq/groq.go b/core/providers/groq/groq.go index 4bbcfd1395..b3c030b386 100644 --- a/core/providers/groq/groq.go +++ b/core/providers/groq/groq.go @@ -149,9 +149,6 @@ func (provider *GroqProvider) Responses(ctx *schemas.BifrostContext, key schemas } response := chatResponse.ToBifrostResponsesResponse() - response.ExtraFields.RequestType = schemas.ResponsesRequest - response.ExtraFields.Provider = provider.GetProviderKey() - response.ExtraFields.ModelRequested = request.Model return response, nil } diff --git a/core/providers/huggingface/errors.go b/core/providers/huggingface/errors.go index 49ce427df7..d98357e0a8 100644 --- a/core/providers/huggingface/errors.go +++ b/core/providers/huggingface/errors.go @@ -10,7 +10,7 @@ import ( ) // parseHuggingFaceImageError parses HuggingFace error responses -func parseHuggingFaceImageError(resp *fasthttp.Response, meta *providerUtils.RequestMetadata) *schemas.BifrostError { +func parseHuggingFaceImageError(resp *fasthttp.Response) *schemas.BifrostError { var errorResp HuggingFaceResponseError bifrostErr := providerUtils.HandleProviderAPIError(resp, &errorResp) @@ -53,13 +53,5 @@ func 
parseHuggingFaceImageError(resp *fasthttp.Response, meta *providerUtils.Req bifrostErr.Error.Message = errorResp.Error } - if meta != nil { - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: meta.Provider, - ModelRequested: meta.Model, - RequestType: meta.RequestType, - } - } - return bifrostErr } diff --git a/core/providers/huggingface/huggingface.go b/core/providers/huggingface/huggingface.go index f2fa8a4547..32dadecc61 100644 --- a/core/providers/huggingface/huggingface.go +++ b/core/providers/huggingface/huggingface.go @@ -254,12 +254,12 @@ func (provider *HuggingFaceProvider) completeRequest(ctx *schemas.BifrostContext // Handle error response if resp.StatusCode() != fasthttp.StatusOK { - return nil, latency, providerResponseHeaders, parseHuggingFaceImageError(resp, nil) + return nil, latency, providerResponseHeaders, parseHuggingFaceImageError(resp) } body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, latency, providerResponseHeaders, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, provider.GetProviderKey()) + return nil, latency, providerResponseHeaders, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } // Read the response body and copy it before releasing the response @@ -325,7 +325,7 @@ func (provider *HuggingFaceProvider) listModelsByKey(ctx *schemas.BifrostContext body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - resultsChan <- providerResult{provider: inferProvider, err: providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName)} + resultsChan <- providerResult{provider: inferProvider, err: providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err)} return } @@ -384,7 +384,7 @@ func (provider *HuggingFaceProvider) listModelsByKey(ctx *schemas.BifrostContext } if result.response != nil { - providerResponse := result.response.ToBifrostListModelsResponse(providerName, 
result.provider, key.Models, key.BlacklistedModels, request.Unfiltered) + providerResponse := result.response.ToBifrostListModelsResponse(providerName, result.provider, key.Models, key.BlacklistedModels, key.Aliases, request.Unfiltered) if providerResponse != nil { aggregatedResponse.Data = append(aggregatedResponse.Data, providerResponse.Data...) totalLatency += result.latency @@ -459,10 +459,6 @@ func (provider *HuggingFaceProvider) ChatCompletion(ctx *schemas.BifrostContext, Message: nameErr.Error(), Error: nameErr, }, - ExtraFields: schemas.BifrostErrorExtraFields{ - Provider: provider.GetProviderKey(), - RequestType: schemas.ChatCompletionRequest, - }, } } if inferenceProvider != "" { @@ -483,8 +479,7 @@ func (provider *HuggingFaceProvider) ChatCompletion(ctx *schemas.BifrostContext, reqBody.Stream = schemas.Ptr(false) } return reqBody, nil - }, - provider.GetProviderKey()) + }) if err != nil { return nil, err } @@ -518,9 +513,6 @@ func (provider *HuggingFaceProvider) ChatCompletion(ctx *schemas.BifrostContext, bifrostResponse.Object = "chat.completion" } - bifrostResponse.ExtraFields.Provider = provider.GetProviderKey() - bifrostResponse.ExtraFields.ModelRequested = request.Model - bifrostResponse.ExtraFields.RequestType = schemas.ChatCompletionRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -550,10 +542,6 @@ func (provider *HuggingFaceProvider) ChatCompletionStream(ctx *schemas.BifrostCo Message: nameErr.Error(), Error: nameErr, }, - ExtraFields: schemas.BifrostErrorExtraFields{ - Provider: provider.GetProviderKey(), - RequestType: schemas.ChatCompletionStreamRequest, - }, } } if inferenceProvider != "" { @@ -610,9 +598,6 @@ func (provider *HuggingFaceProvider) Responses(ctx *schemas.BifrostContext, key } response := chatResponse.ToBifrostResponsesResponse() - response.ExtraFields.RequestType = schemas.ResponsesRequest - response.ExtraFields.Provider = 
provider.GetProviderKey() - response.ExtraFields.ModelRequested = request.Model return response, nil } @@ -644,10 +629,6 @@ func (provider *HuggingFaceProvider) Embedding(ctx *schemas.BifrostContext, key Message: nameErr.Error(), Error: nameErr, }, - ExtraFields: schemas.BifrostErrorExtraFields{ - Provider: provider.GetProviderKey(), - RequestType: schemas.EmbeddingRequest, - }, } } @@ -657,8 +638,7 @@ func (provider *HuggingFaceProvider) Embedding(ctx *schemas.BifrostContext, key func() (providerUtils.RequestBodyWithExtraParams, error) { req, err := ToHuggingFaceEmbeddingRequest(request) return req, err - }, - provider.GetProviderKey()) + }) if err != nil { return nil, err } @@ -698,13 +678,10 @@ func (provider *HuggingFaceProvider) Embedding(ctx *schemas.BifrostContext, key // Unmarshal directly to BifrostEmbeddingResponse with custom logic bifrostResponse, convErr := UnmarshalHuggingFaceEmbeddingResponse(responseBody, request.Model) if convErr != nil { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, convErr, provider.GetProviderKey()), jsonBody, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, convErr), jsonBody, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse) } // Set ExtraFields - bifrostResponse.ExtraFields.Provider = provider.GetProviderKey() - bifrostResponse.ExtraFields.ModelRequested = request.Model - bifrostResponse.ExtraFields.RequestType = schemas.EmbeddingRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -735,10 +712,6 @@ func (provider *HuggingFaceProvider) Speech(ctx *schemas.BifrostContext, key sch Message: nameErr.Error(), Error: nameErr, }, - ExtraFields: schemas.BifrostErrorExtraFields{ - Provider: 
provider.GetProviderKey(), - RequestType: schemas.SpeechRequest, - }, } } @@ -747,8 +720,7 @@ func (provider *HuggingFaceProvider) Speech(ctx *schemas.BifrostContext, key sch request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToHuggingFaceSpeechRequest(request) - }, - provider.GetProviderKey()) + }) if err != nil { return nil, err } @@ -784,18 +756,15 @@ func (provider *HuggingFaceProvider) Speech(ctx *schemas.BifrostContext, key sch // Download the audio file from the URL audioData, downloadErr := provider.downloadAudioFromURL(ctx, response.Audio.URL) if downloadErr != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, downloadErr, provider.GetProviderKey()) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, downloadErr), jsonData, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse) } bifrostResponse, convErr := response.ToBifrostSpeechResponse(request.Model, audioData) if convErr != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, convErr, provider.GetProviderKey()) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, convErr), jsonData, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse) } // Set ExtraFields - bifrostResponse.ExtraFields.Provider = provider.GetProviderKey() - bifrostResponse.ExtraFields.ModelRequested = request.Model - bifrostResponse.ExtraFields.RequestType = schemas.SpeechRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) { @@ -833,10 +802,6 @@ func (provider *HuggingFaceProvider) Transcription(ctx *schemas.BifrostContext, Message: nameErr.Error(), Error: nameErr, }, - ExtraFields: 
schemas.BifrostErrorExtraFields{ - Provider: provider.GetProviderKey(), - RequestType: schemas.TranscriptionRequest, - }, } } @@ -846,7 +811,7 @@ func (provider *HuggingFaceProvider) Transcription(ctx *schemas.BifrostContext, isHFInferenceAudioRequest := inferenceProvider == hfInference if inferenceProvider == hfInference { if request.Input == nil || len(request.Input.File) == 0 { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderCreateRequest, fmt.Errorf("input file data is required for hf-inference transcription requests"), provider.GetProviderKey()) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderCreateRequest, fmt.Errorf("input file data is required for hf-inference transcription requests")) } jsonData = request.Input.File } else { @@ -856,8 +821,7 @@ func (provider *HuggingFaceProvider) Transcription(ctx *schemas.BifrostContext, request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToHuggingFaceTranscriptionRequest(request) - }, - provider.GetProviderKey()) + }) if err != nil { return nil, err } @@ -905,13 +869,10 @@ func (provider *HuggingFaceProvider) Transcription(ctx *schemas.BifrostContext, bifrostResponse, convErr := response.ToBifrostTranscriptionResponse(request.Model) if convErr != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, convErr, provider.GetProviderKey()) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, convErr), jsonData, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse) } // Set ExtraFields - bifrostResponse.ExtraFields.Provider = provider.GetProviderKey() - bifrostResponse.ExtraFields.ModelRequested = request.Model - bifrostResponse.ExtraFields.RequestType = schemas.TranscriptionRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders if 
providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) { @@ -945,10 +906,6 @@ func (provider *HuggingFaceProvider) ImageGeneration(ctx *schemas.BifrostContext Message: nameErr.Error(), Error: nameErr, }, - ExtraFields: schemas.BifrostErrorExtraFields{ - Provider: provider.GetProviderKey(), - RequestType: schemas.ImageGenerationRequest, - }, } } @@ -958,8 +915,7 @@ func (provider *HuggingFaceProvider) ImageGeneration(ctx *schemas.BifrostContext func() (providerUtils.RequestBodyWithExtraParams, error) { req, err := ToHuggingFaceImageGenerationRequest(request) return req, err - }, - provider.GetProviderKey()) + }) if err != nil { return nil, err } @@ -999,15 +955,12 @@ func (provider *HuggingFaceProvider) ImageGeneration(ctx *schemas.BifrostContext // Unmarshal response using Nebius converter bifrostResponse, convErr := UnmarshalHuggingFaceImageGenerationResponse(responseBody, request.Model) if convErr != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, convErr, provider.GetProviderKey()) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, convErr), jsonBody, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse) } bifrostResponse.Created = time.Now().Unix() // Set ExtraFields - bifrostResponse.ExtraFields.Provider = provider.GetProviderKey() - bifrostResponse.ExtraFields.ModelRequested = request.Model - bifrostResponse.ExtraFields.RequestType = schemas.ImageGenerationRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -1039,10 +992,6 @@ func (provider *HuggingFaceProvider) ImageGenerationStream(ctx *schemas.BifrostC Message: nameErr.Error(), Error: nameErr, }, - ExtraFields: schemas.BifrostErrorExtraFields{ - Provider: provider.GetProviderKey(), - RequestType: schemas.ImageGenerationStreamRequest, - }, } } @@ -1050,11 
+999,8 @@ func (provider *HuggingFaceProvider) ImageGenerationStream(ctx *schemas.BifrostC if inferenceProvider != falAI { return nil, providerUtils.NewBifrostOperationError( fmt.Sprintf("image generation streaming is only supported for fal-ai inference provider, got: %s", inferenceProvider), - nil, - provider.GetProviderKey(), - ) + nil) } - providerName := provider.GetProviderKey() // Set headers headers := map[string]string{ @@ -1072,8 +1018,7 @@ func (provider *HuggingFaceProvider) ImageGenerationStream(ctx *schemas.BifrostC request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToHuggingFaceImageStreamRequest(request) - }, - providerName) + }) if bifrostErr != nil { return nil, bifrostErr } @@ -1105,9 +1050,6 @@ func (provider *HuggingFaceProvider) ImageGenerationStream(ctx *schemas.BifrostC req.SetBody(jsonBody) } - // Capture start time before making the HTTP request for latency calculation - startTime := time.Now() - // Make the request err := provider.client.Do(req, resp) if err != nil { @@ -1123,9 +1065,9 @@ func (provider *HuggingFaceProvider) ImageGenerationStream(ctx *schemas.BifrostC } } if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) { - return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, providerName) + return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err) } - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err) } // Extract provider response headers before status check so error responses also forward them @@ -1134,11 +1076,7 @@ func (provider *HuggingFaceProvider) ImageGenerationStream(ctx *schemas.BifrostC // Check for HTTP errors if resp.StatusCode() != fasthttp.StatusOK { defer providerUtils.ReleaseStreamingResponse(resp) - return nil, providerUtils.EnrichError(ctx, 
parseHuggingFaceImageError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - Model: request.Model, - RequestType: schemas.ImageGenerationStreamRequest, - }), jsonBody, nil, providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)) + return nil, providerUtils.EnrichError(ctx, parseHuggingFaceImageError(resp), jsonBody, nil, providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)) } // Large payload streaming passthrough — pipe raw upstream SSE to client @@ -1161,9 +1099,7 @@ func (provider *HuggingFaceProvider) ImageGenerationStream(ctx *schemas.BifrostC if resp.BodyStream() == nil { bifrostErr := providerUtils.NewBifrostOperationError( "Provider returned an empty response", - fmt.Errorf("provider returned an empty response"), - providerName, - ) + fmt.Errorf("provider returned an empty response")) ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger) return @@ -1184,6 +1120,8 @@ func (provider *HuggingFaceProvider) ImageGenerationStream(ctx *schemas.BifrostC sseReader := providerUtils.GetSSEDataReader(ctx, reader) + // Initialize latency timers post-handshake so chunk latency reflects pure streaming time. 
+ startTime := time.Now() lastChunkTime := startTime chunkIndex := 0 var lastB64Data, lastURLData, lastJsonData string @@ -1202,14 +1140,7 @@ func (provider *HuggingFaceProvider) ImageGenerationStream(ctx *schemas.BifrostC } bifrostErr := providerUtils.NewBifrostOperationError( fmt.Sprintf("Error reading fal-ai stream: %v", readErr), - readErr, - providerName, - ) - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.ImageGenerationStreamRequest, - } + readErr) ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger) return @@ -1232,11 +1163,6 @@ func (provider *HuggingFaceProvider) ImageGenerationStream(ctx *schemas.BifrostC Error: &schemas.ErrorField{ Message: errorResp.Message, }, - ExtraFields: schemas.BifrostErrorExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.ImageGenerationStreamRequest, - }, } if errorResp.Error != "" { bifrostErr.Error.Message = errorResp.Error @@ -1262,11 +1188,8 @@ func (provider *HuggingFaceProvider) ImageGenerationStream(ctx *schemas.BifrostC chunk := &schemas.BifrostImageGenerationStreamResponse{ Type: schemas.ImageGenerationEventTypePartial, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ImageGenerationStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - ChunkIndex: chunkIndex, - Latency: time.Since(lastChunkTime).Milliseconds(), + ChunkIndex: chunkIndex, + Latency: time.Since(lastChunkTime).Milliseconds(), }, } @@ -1306,11 +1229,8 @@ func (provider *HuggingFaceProvider) ImageGenerationStream(ctx *schemas.BifrostC Type: schemas.ImageGenerationEventTypeCompleted, Index: lastIndex, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ImageGenerationStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - 
ChunkIndex: chunkIndex, - Latency: time.Since(startTime).Milliseconds(), + ChunkIndex: chunkIndex, + Latency: time.Since(startTime).Milliseconds(), }, } finalChunk.BackfillParams(&schemas.BifrostRequest{ @@ -1354,10 +1274,6 @@ func (provider *HuggingFaceProvider) ImageEdit(ctx *schemas.BifrostContext, key Message: nameErr.Error(), Error: nameErr, }, - ExtraFields: schemas.BifrostErrorExtraFields{ - Provider: provider.GetProviderKey(), - RequestType: schemas.ImageEditRequest, - }, } } @@ -1372,8 +1288,7 @@ func (provider *HuggingFaceProvider) ImageEdit(ctx *schemas.BifrostContext, key func() (providerUtils.RequestBodyWithExtraParams, error) { req, err := ToHuggingFaceImageEditRequest(request) return req, err - }, - provider.GetProviderKey()) + }) if err != nil { return nil, err } @@ -1409,15 +1324,12 @@ func (provider *HuggingFaceProvider) ImageEdit(ctx *schemas.BifrostContext, key // Unmarshal response bifrostResponse, convErr := UnmarshalHuggingFaceImageGenerationResponse(responseBody, request.Model) if convErr != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, convErr, provider.GetProviderKey()) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, convErr), jsonBody, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse) } bifrostResponse.Created = time.Now().Unix() // Set ExtraFields - bifrostResponse.ExtraFields.Provider = provider.GetProviderKey() - bifrostResponse.ExtraFields.ModelRequested = request.Model - bifrostResponse.ExtraFields.RequestType = schemas.ImageEditRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -1449,10 +1361,6 @@ func (provider *HuggingFaceProvider) ImageEditStream(ctx *schemas.BifrostContext Message: nameErr.Error(), Error: nameErr, }, - ExtraFields: schemas.BifrostErrorExtraFields{ - Provider: 
provider.GetProviderKey(), - RequestType: schemas.ImageEditStreamRequest, - }, } } @@ -1460,9 +1368,7 @@ func (provider *HuggingFaceProvider) ImageEditStream(ctx *schemas.BifrostContext if inferenceProvider != falAI { return nil, providerUtils.NewBifrostOperationError( fmt.Sprintf("image edit streaming is only supported for fal-ai inference provider, got: %s", inferenceProvider), - nil, - provider.GetProviderKey(), - ) + nil) } var authHeader map[string]string @@ -1488,15 +1394,13 @@ func (provider *HuggingFaceProvider) ImageEditStream(ctx *schemas.BifrostContext sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) sendBackRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) - providerName := provider.GetProviderKey() jsonBody, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToHuggingFaceImageEditRequest(request) - }, - providerName) + }) if bifrostErr != nil { return nil, bifrostErr } @@ -1524,9 +1428,6 @@ func (provider *HuggingFaceProvider) ImageEditStream(ctx *schemas.BifrostContext req.SetBody(jsonBody) } - // Capture start time before making the HTTP request for latency calculation - startTime := time.Now() - // Make the request err := provider.client.Do(req, resp) if err != nil { @@ -1542,9 +1443,9 @@ func (provider *HuggingFaceProvider) ImageEditStream(ctx *schemas.BifrostContext } } if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) { - return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, providerName) + return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err) } - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err) } // Extract provider response headers 
before status check so error responses also forward them @@ -1553,11 +1454,7 @@ func (provider *HuggingFaceProvider) ImageEditStream(ctx *schemas.BifrostContext // Check for HTTP errors if resp.StatusCode() != fasthttp.StatusOK { defer providerUtils.ReleaseStreamingResponse(resp) - return nil, providerUtils.EnrichError(ctx, parseHuggingFaceImageError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - Model: request.Model, - RequestType: schemas.ImageEditStreamRequest, - }), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, parseHuggingFaceImageError(resp), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } // Large payload streaming passthrough — pipe raw upstream SSE to client @@ -1580,9 +1477,7 @@ func (provider *HuggingFaceProvider) ImageEditStream(ctx *schemas.BifrostContext if resp.BodyStream() == nil { bifrostErr := providerUtils.NewBifrostOperationError( "Provider returned an empty response", - fmt.Errorf("provider returned an empty response"), - providerName, - ) + fmt.Errorf("provider returned an empty response")) ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger) return @@ -1603,6 +1498,8 @@ func (provider *HuggingFaceProvider) ImageEditStream(ctx *schemas.BifrostContext sseReader := providerUtils.GetSSEDataReader(ctx, reader) + // Initialize latency timers post-handshake so chunk latency reflects pure streaming time. 
+ startTime := time.Now() lastChunkTime := startTime chunkIndex := 0 var lastB64Data, lastURLData, lastJsonData string @@ -1621,14 +1518,7 @@ func (provider *HuggingFaceProvider) ImageEditStream(ctx *schemas.BifrostContext } bifrostErr := providerUtils.NewBifrostOperationError( fmt.Sprintf("Error reading fal-ai stream: %v", readErr), - readErr, - providerName, - ) - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.ImageEditStreamRequest, - } + readErr) ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger) return @@ -1651,11 +1541,6 @@ func (provider *HuggingFaceProvider) ImageEditStream(ctx *schemas.BifrostContext Error: &schemas.ErrorField{ Message: errorResp.Message, }, - ExtraFields: schemas.BifrostErrorExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.ImageEditStreamRequest, - }, } if errorResp.Error != "" { bifrostErr.Error.Message = errorResp.Error @@ -1681,11 +1566,8 @@ func (provider *HuggingFaceProvider) ImageEditStream(ctx *schemas.BifrostContext chunk := &schemas.BifrostImageGenerationStreamResponse{ Type: schemas.ImageEditEventTypePartial, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ImageEditStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - ChunkIndex: chunkIndex, - Latency: time.Since(lastChunkTime).Milliseconds(), + ChunkIndex: chunkIndex, + Latency: time.Since(lastChunkTime).Milliseconds(), }, } @@ -1725,11 +1607,8 @@ func (provider *HuggingFaceProvider) ImageEditStream(ctx *schemas.BifrostContext Type: schemas.ImageEditEventTypeCompleted, Index: lastIndex, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ImageEditStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - ChunkIndex: chunkIndex, - Latency: 
time.Since(startTime).Milliseconds(), + ChunkIndex: chunkIndex, + Latency: time.Since(startTime).Milliseconds(), }, } finalChunk.BackfillParams(&schemas.BifrostRequest{ diff --git a/core/providers/huggingface/models.go b/core/providers/huggingface/models.go index bc2314af31..de615ccec2 100644 --- a/core/providers/huggingface/models.go +++ b/core/providers/huggingface/models.go @@ -5,6 +5,7 @@ import ( "slices" "strings" + providerUtils "github.com/maximhq/bifrost/core/providers/utils" schemas "github.com/maximhq/bifrost/core/schemas" ) @@ -13,7 +14,7 @@ const ( maxModelFetchLimit = 1000 ) -func (response *HuggingFaceListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, inferenceProvider inferenceProvider, allowedModels schemas.WhiteList, blacklistedModels schemas.BlackList, unfiltered bool) *schemas.BifrostListModelsResponse { +func (response *HuggingFaceListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, inferenceProvider inferenceProvider, allowedModels schemas.WhiteList, blacklistedModels schemas.BlackList, aliases map[string]string, unfiltered bool) *schemas.BifrostListModelsResponse { if response == nil { return nil } @@ -22,11 +23,20 @@ func (response *HuggingFaceListModelsResponse) ToBifrostListModelsResponse(provi Data: make([]schemas.Model, 0, len(response.Models)), } - if !unfiltered && (allowedModels.IsEmpty() || blacklistedModels.IsBlockAll()) { + pipeline := &providerUtils.ListModelsPipeline{ + AllowedModels: allowedModels, + BlacklistedModels: blacklistedModels, + Aliases: aliases, + Unfiltered: unfiltered, + ProviderKey: providerKey, + MatchFns: providerUtils.DefaultMatchFns(), + } + if pipeline.ShouldEarlyExit() { return bifrostResponse } - includedModels := make(map[string]bool) + included := make(map[string]bool) + for _, model := range response.Models { if model.ModelID == "" { continue @@ -37,39 +47,33 @@ func (response *HuggingFaceListModelsResponse) ToBifrostListModelsResponse(provi 
continue } - if !unfiltered && allowedModels.IsRestricted() && !allowedModels.Contains(model.ModelID) { - continue - } - if !unfiltered && blacklistedModels.IsBlocked(model.ModelID) { - continue - } - - newModel := schemas.Model{ - ID: fmt.Sprintf("%s/%s/%s", providerKey, inferenceProvider, model.ModelID), - Name: schemas.Ptr(model.ModelID), - SupportedMethods: supported, - HuggingFaceID: schemas.Ptr(model.ID), - } - - bifrostResponse.Data = append(bifrostResponse.Data, newModel) - includedModels[strings.ToLower(model.ModelID)] = true - } - - // Backfill allowed models that were not in the response - if !unfiltered && allowedModels.IsRestricted() { - for _, allowedModel := range allowedModels { - if blacklistedModels.IsBlocked(allowedModel) { - continue + // Aliases apply at the model level (model.ModelID), not at the compound + // "{providerKey}/{inferenceProvider}/{modelID}" level. + for _, result := range pipeline.FilterModel(model.ModelID) { + newModel := schemas.Model{ + // inferenceProvider stays in the compound ID; aliases rename only the model segment + ID: fmt.Sprintf("%s/%s/%s", providerKey, inferenceProvider, result.ResolvedID), + Name: schemas.Ptr(model.ModelID), + SupportedMethods: supported, + HuggingFaceID: schemas.Ptr(model.ID), } - if !includedModels[strings.ToLower(allowedModel)] { - bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ - ID: fmt.Sprintf("%s/%s/%s", providerKey, inferenceProvider, allowedModel), - Name: schemas.Ptr(allowedModel), - }) + if result.AliasValue != "" { + newModel.Alias = schemas.Ptr(result.AliasValue) } + bifrostResponse.Data = append(bifrostResponse.Data, newModel) + included[strings.ToLower(result.ResolvedID)] = true } } + // Backfill: use standard pipeline. Note that backfilled HF entries use a simplified + // compound ID since we don't know which inferenceProvider to assign them to. 
+ for _, m := range pipeline.BackfillModels(included) { + // Re-wrap the backfill ID to include the inferenceProvider segment + rawID := strings.TrimPrefix(m.ID, string(providerKey)+"/") + m.ID = fmt.Sprintf("%s/%s/%s", providerKey, inferenceProvider, rawID) + bifrostResponse.Data = append(bifrostResponse.Data, m) + } + return bifrostResponse } diff --git a/core/providers/huggingface/responses.go b/core/providers/huggingface/responses.go index fd68aa76a8..35ad2c336d 100644 --- a/core/providers/huggingface/responses.go +++ b/core/providers/huggingface/responses.go @@ -43,9 +43,6 @@ func ToBifrostResponsesResponseFromHuggingFace(resp *schemas.BifrostChatResponse responsesResp := resp.ToBifrostResponsesResponse() if responsesResp != nil { - responsesResp.ExtraFields.Provider = schemas.HuggingFace - responsesResp.ExtraFields.ModelRequested = requestedModel - responsesResp.ExtraFields.RequestType = schemas.ResponsesRequest } return responsesResp, nil diff --git a/core/providers/huggingface/speech.go b/core/providers/huggingface/speech.go index 65c0ba6e12..f702d1f39f 100644 --- a/core/providers/huggingface/speech.go +++ b/core/providers/huggingface/speech.go @@ -125,10 +125,6 @@ func (response *HuggingFaceSpeechResponse) ToBifrostSpeechResponse(requestedMode // Create the base Bifrost response with the downloaded audio data bifrostResponse := &schemas.BifrostSpeechResponse{ Audio: audioData, - ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: schemas.HuggingFace, - ModelRequested: requestedModel, - }, } // Note: HuggingFace TTS API typically doesn't return usage information diff --git a/core/providers/huggingface/transcription.go b/core/providers/huggingface/transcription.go index 0d892cb07c..f3ff5c293a 100644 --- a/core/providers/huggingface/transcription.go +++ b/core/providers/huggingface/transcription.go @@ -144,10 +144,6 @@ func (response *HuggingFaceTranscriptionResponse) ToBifrostTranscriptionResponse // Create the base Bifrost response bifrostResponse 
:= &schemas.BifrostTranscriptionResponse{ Text: response.Text, - ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: schemas.HuggingFace, - ModelRequested: requestedModel, - }, } // Map chunks to segments if available diff --git a/core/providers/huggingface/utils.go b/core/providers/huggingface/utils.go index ad68e4c6ac..b96210c832 100644 --- a/core/providers/huggingface/utils.go +++ b/core/providers/huggingface/utils.go @@ -221,8 +221,6 @@ func convertToInferenceProviderMappings(resp *HuggingFaceInferenceProviderMappin } func (provider *HuggingFaceProvider) getModelInferenceProviderMapping(ctx context.Context, huggingfaceModelName string) (map[inferenceProvider]HuggingFaceInferenceProviderMapping, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - // Check cache first if cached, ok := provider.modelProviderMappingCache.Load(huggingfaceModelName); ok { if mappings, ok := cached.(map[inferenceProvider]HuggingFaceInferenceProviderMapping); ok { @@ -259,12 +257,12 @@ func (provider *HuggingFaceProvider) getModelInferenceProviderMapping(ctx contex body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } var mappingResp HuggingFaceInferenceProviderMappingResponse if err := sonic.Unmarshal(body, &mappingResp); err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } mappings := convertToInferenceProviderMappings(&mappingResp) diff --git a/core/providers/mistral/mistral.go b/core/providers/mistral/mistral.go index 244d390437..fb11cfbc88 100644 --- a/core/providers/mistral/mistral.go +++ b/core/providers/mistral/mistral.go @@ -74,8 +74,6 @@ func (provider *MistralProvider) 
GetProviderKey() schemas.ModelProvider { // listModelsByKey performs a list models request for a single key. // Returns the response and latency, or an error if the request fails. func (provider *MistralProvider) listModelsByKey(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - // Create request req := fasthttp.AcquireRequest() resp := fasthttp.AcquireResponse() @@ -101,7 +99,7 @@ func (provider *MistralProvider) listModelsByKey(ctx *schemas.BifrostContext, ke // Handle error response if resp.StatusCode() != fasthttp.StatusOK { - bifrostErr := openai.ParseOpenAIError(resp, schemas.ListModelsRequest, providerName, "") + bifrostErr := openai.ParseOpenAIError(resp) return nil, bifrostErr } @@ -116,7 +114,7 @@ func (provider *MistralProvider) listModelsByKey(ctx *schemas.BifrostContext, ke } // Create final response - response := mistralResponse.ToBifrostListModelsResponse(key.Models, key.BlacklistedModels, request.Unfiltered) + response := mistralResponse.ToBifrostListModelsResponse(key.Models, key.BlacklistedModels, key.Aliases, request.Unfiltered) response.ExtraFields.Latency = latency.Milliseconds() @@ -212,9 +210,6 @@ func (provider *MistralProvider) Responses(ctx *schemas.BifrostContext, key sche } response := chatResponse.ToBifrostResponsesResponse() - response.ExtraFields.RequestType = schemas.ResponsesRequest - response.ExtraFields.Provider = provider.GetProviderKey() - response.ExtraFields.ModelRequested = request.Model return response, nil } @@ -273,7 +268,7 @@ func (provider *MistralProvider) Transcription(ctx *schemas.BifrostContext, key // Convert Bifrost request to Mistral format mistralReq := ToMistralTranscriptionRequest(request) if mistralReq == nil { - return nil, providerUtils.NewBifrostOperationError("transcription input is not provided", nil, providerName) + return nil, 
providerUtils.NewBifrostOperationError("transcription input is not provided", nil) } // Create multipart form body @@ -310,12 +305,12 @@ func (provider *MistralProvider) Transcription(ctx *schemas.BifrostContext, key // Handle error response if resp.StatusCode() != fasthttp.StatusOK { provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - return nil, openai.ParseOpenAIError(resp, schemas.TranscriptionRequest, providerName, request.Model) + return nil, openai.ParseOpenAIError(resp) } responseBody, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } // Check for empty response @@ -343,20 +338,17 @@ func (provider *MistralProvider) Transcription(ctx *schemas.BifrostContext, key }, } } - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err) } // Convert to Bifrost format response := mistralResponse.ToBifrostTranscriptionResponse() if response == nil { - return nil, providerUtils.NewBifrostOperationError("failed to convert transcription response", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to convert transcription response", nil) } // Set extra fields response.ExtraFields.Latency = latency.Milliseconds() - response.ExtraFields.RequestType = schemas.TranscriptionRequest - response.ExtraFields.Provider = providerName - response.ExtraFields.ModelRequested = request.Model // Set raw response if enabled if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) { @@ -378,7 +370,7 @@ func (provider *MistralProvider) TranscriptionStream(ctx *schemas.BifrostContext // Convert Bifrost request to Mistral format mistralReq := 
ToMistralTranscriptionRequest(request) if mistralReq == nil { - return nil, providerUtils.NewBifrostOperationError("transcription input is not provided", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("transcription input is not provided", nil) } mistralReq.Stream = schemas.Ptr(true) @@ -433,9 +425,9 @@ func (provider *MistralProvider) TranscriptionStream(ctx *schemas.BifrostContext } } if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) { - return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, providerName) + return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err) } - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err) } // Store provider response headers in context before status check so error responses also forward them @@ -445,7 +437,7 @@ func (provider *MistralProvider) TranscriptionStream(ctx *schemas.BifrostContext if resp.StatusCode() != fasthttp.StatusOK { defer providerUtils.ReleaseStreamingResponse(resp) provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - return nil, openai.ParseOpenAIError(resp, schemas.TranscriptionStreamRequest, providerName, request.Model) + return nil, openai.ParseOpenAIError(resp) } // Large payload streaming passthrough — pipe raw upstream SSE to client @@ -464,9 +456,9 @@ func (provider *MistralProvider) TranscriptionStream(ctx *schemas.BifrostContext go func() { defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.TranscriptionStreamRequest, provider.logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.logger) } else if ctx.Err() == context.DeadlineExceeded { - 
providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.TranscriptionStreamRequest, provider.logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.logger) } close(responseChan) }() @@ -505,7 +497,7 @@ func (provider *MistralProvider) TranscriptionStream(ctx *schemas.BifrostContext } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) provider.logger.Warn("Error reading stream: %v", readErr) - providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, schemas.TranscriptionStreamRequest, providerName, request.Model, provider.logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, provider.logger) } break } @@ -553,11 +545,6 @@ func (provider *MistralProvider) processTranscriptionStreamEvent( var bifrostErr schemas.BifrostError if err := sonic.UnmarshalString(jsonData, &bifrostErr); err == nil { if bifrostErr.Error != nil && bifrostErr.Error.Message != "" { - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: providerName, - ModelRequested: model, - RequestType: schemas.TranscriptionStreamRequest, - } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, &bifrostErr, responseChan, provider.logger) return @@ -586,11 +573,8 @@ func (provider *MistralProvider) processTranscriptionStreamEvent( // Set extra fields response.ExtraFields = schemas.BifrostResponseExtraFields{ - RequestType: schemas.TranscriptionStreamRequest, - Provider: providerName, - ModelRequested: model, - ChunkIndex: chunkIndex, - Latency: time.Since(*lastChunkTime).Milliseconds(), + ChunkIndex: chunkIndex, + Latency: time.Since(*lastChunkTime).Milliseconds(), } *lastChunkTime = time.Now() diff --git a/core/providers/mistral/models.go b/core/providers/mistral/models.go index 9b1002e54b..8d5fd7f3d6 100644 --- a/core/providers/mistral/models.go +++ 
b/core/providers/mistral/models.go @@ -3,10 +3,11 @@ package mistral import ( "strings" + providerUtils "github.com/maximhq/bifrost/core/providers/utils" "github.com/maximhq/bifrost/core/schemas" ) -func (response *MistralListModelsResponse) ToBifrostListModelsResponse(allowedModels schemas.WhiteList, blacklistedModels schemas.BlackList, unfiltered bool) *schemas.BifrostListModelsResponse { +func (response *MistralListModelsResponse) ToBifrostListModelsResponse(allowedModels schemas.WhiteList, blacklistedModels schemas.BlackList, aliases map[string]string, unfiltered bool) *schemas.BifrostListModelsResponse { if response == nil { return nil } @@ -15,44 +16,40 @@ func (response *MistralListModelsResponse) ToBifrostListModelsResponse(allowedMo Data: make([]schemas.Model, 0, len(response.Data)), } - if !unfiltered && (allowedModels.IsEmpty() || blacklistedModels.IsBlockAll()) { + pipeline := &providerUtils.ListModelsPipeline{ + AllowedModels: allowedModels, + BlacklistedModels: blacklistedModels, + Aliases: aliases, + Unfiltered: unfiltered, + ProviderKey: schemas.Mistral, + MatchFns: providerUtils.DefaultMatchFns(), + } + if pipeline.ShouldEarlyExit() { return bifrostResponse } - includedModels := make(map[string]bool) - for _, model := range response.Data { - if !unfiltered && allowedModels.IsRestricted() && !allowedModels.Contains(model.ID) { - continue - } - if !unfiltered && blacklistedModels.IsBlocked(model.ID) { - continue - } - bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ - ID: string(schemas.Mistral) + "/" + model.ID, - Name: schemas.Ptr(model.Name), - Description: schemas.Ptr(model.Description), - Created: schemas.Ptr(model.Created), - ContextLength: schemas.Ptr(int(model.MaxContextLength)), - OwnedBy: schemas.Ptr(model.OwnedBy), - }) - includedModels[strings.ToLower(model.ID)] = true - } + included := make(map[string]bool) - // Backfill allowed models that were not in the response - if !unfiltered && allowedModels.IsRestricted() { - 
for _, allowedModel := range allowedModels { - if blacklistedModels.IsBlocked(allowedModel) { - continue + for _, model := range response.Data { + for _, result := range pipeline.FilterModel(model.ID) { + entry := schemas.Model{ + ID: string(schemas.Mistral) + "/" + result.ResolvedID, + Name: schemas.Ptr(model.Name), + Description: schemas.Ptr(model.Description), + Created: schemas.Ptr(model.Created), + ContextLength: schemas.Ptr(int(model.MaxContextLength)), + OwnedBy: schemas.Ptr(model.OwnedBy), } - if !includedModels[strings.ToLower(allowedModel)] { - bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ - ID: string(schemas.Mistral) + "/" + allowedModel, - Name: schemas.Ptr(allowedModel), - }) - includedModels[strings.ToLower(allowedModel)] = true + if result.AliasValue != "" { + entry.Alias = schemas.Ptr(result.AliasValue) } + bifrostResponse.Data = append(bifrostResponse.Data, entry) + included[strings.ToLower(result.ResolvedID)] = true } } + bifrostResponse.Data = append(bifrostResponse.Data, + pipeline.BackfillModels(included)...) 
+ return bifrostResponse } diff --git a/core/providers/mistral/transcription.go b/core/providers/mistral/transcription.go index a4a018e5c6..fe9b262126 100644 --- a/core/providers/mistral/transcription.go +++ b/core/providers/mistral/transcription.go @@ -109,58 +109,58 @@ func parseTranscriptionFormDataBodyFromRequest(writer *multipart.Writer, req *Mi } fileWriter, err := writer.CreateFormFile("file", filename) if err != nil { - return providerUtils.NewBifrostOperationError("failed to create form file", err, providerName) + return providerUtils.NewBifrostOperationError("failed to create form file", err) } if _, err := fileWriter.Write(req.File); err != nil { - return providerUtils.NewBifrostOperationError("failed to write file data", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write file data", err) } // Add model field (required) if err := writer.WriteField("model", req.Model); err != nil { - return providerUtils.NewBifrostOperationError("failed to write model field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write model field", err) } // Add stream field if streaming if req.Stream != nil && *req.Stream { if err := writer.WriteField("stream", "true"); err != nil { - return providerUtils.NewBifrostOperationError("failed to write stream field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write stream field", err) } } // Add optional fields if req.Language != nil { if err := writer.WriteField("language", *req.Language); err != nil { - return providerUtils.NewBifrostOperationError("failed to write language field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write language field", err) } } if req.Prompt != nil { if err := writer.WriteField("prompt", *req.Prompt); err != nil { - return providerUtils.NewBifrostOperationError("failed to write prompt field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to 
write prompt field", err) } } if req.ResponseFormat != nil { if err := writer.WriteField("response_format", *req.ResponseFormat); err != nil { - return providerUtils.NewBifrostOperationError("failed to write response_format field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write response_format field", err) } } if req.Temperature != nil { if err := writer.WriteField("temperature", formatFloat64(*req.Temperature)); err != nil { - return providerUtils.NewBifrostOperationError("failed to write temperature field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write temperature field", err) } } for _, granularity := range req.TimestampGranularities { if err := writer.WriteField("timestamp_granularities[]", granularity); err != nil { - return providerUtils.NewBifrostOperationError("failed to write timestamp_granularities field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write timestamp_granularities field", err) } } // Close the multipart writer to finalize the form if err := writer.Close(); err != nil { - return providerUtils.NewBifrostOperationError("failed to close multipart writer", err, providerName) + return providerUtils.NewBifrostOperationError("failed to close multipart writer", err) } return nil diff --git a/core/providers/nebius/errors.go b/core/providers/nebius/errors.go index 98d0fb78d8..de8bcf0d84 100644 --- a/core/providers/nebius/errors.go +++ b/core/providers/nebius/errors.go @@ -9,7 +9,7 @@ import ( ) // parseNebiusImageError parses Nebius error responses -func parseNebiusImageError(resp *fasthttp.Response, meta *providerUtils.RequestMetadata) *schemas.BifrostError { +func parseNebiusImageError(resp *fasthttp.Response) *schemas.BifrostError { var nebiusErr NebiusError bifrostErr := providerUtils.HandleProviderAPIError(resp, &nebiusErr) @@ -60,13 +60,5 @@ func parseNebiusImageError(resp *fasthttp.Response, meta *providerUtils.RequestM 
bifrostErr.Error.Message = message } - if meta != nil { - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: meta.Provider, - ModelRequested: meta.Model, - RequestType: meta.RequestType, - } - } - return bifrostErr } diff --git a/core/providers/nebius/nebius.go b/core/providers/nebius/nebius.go index d8bd8a2256..42429b87ab 100644 --- a/core/providers/nebius/nebius.go +++ b/core/providers/nebius/nebius.go @@ -193,9 +193,6 @@ func (provider *NebiusProvider) Responses(ctx *schemas.BifrostContext, key schem } response := chatResponse.ToBifrostResponsesResponse() - response.ExtraFields.RequestType = schemas.ResponsesRequest - response.ExtraFields.Provider = provider.GetProviderKey() - response.ExtraFields.ModelRequested = request.Model return response, nil } @@ -260,16 +257,15 @@ func (provider *NebiusProvider) TranscriptionStream(ctx *schemas.BifrostContext, func (provider *NebiusProvider) ImageGeneration(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostImageGenerationRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) { // Validate request is not nil if request == nil { - return nil, providerUtils.NewBifrostOperationError("image generation request is nil", nil, provider.GetProviderKey()) + return nil, providerUtils.NewBifrostOperationError("image generation request is nil", nil) } // Validate input and prompt are not nil/empty if request.Input == nil || strings.TrimSpace(request.Input.Prompt) == "" { - return nil, providerUtils.NewBifrostOperationError("prompt cannot be empty", nil, provider.GetProviderKey()) + return nil, providerUtils.NewBifrostOperationError("prompt cannot be empty", nil) } path := providerUtils.GetPathFromContext(ctx, "/v1/images/generations") - providerName := schemas.Nebius // Create request req := fasthttp.AcquireRequest() resp := fasthttp.AcquireResponse() @@ -304,8 +300,7 @@ func (provider *NebiusProvider) ImageGeneration(ctx *schemas.BifrostContext, key request, func() 
(providerUtils.RequestBodyWithExtraParams, error) { return provider.ToNebiusImageGenerationRequest(request) - }, - providerName) + }) if bifrostErr != nil { return nil, bifrostErr } @@ -323,16 +318,12 @@ func (provider *NebiusProvider) ImageGeneration(ctx *schemas.BifrostContext, key // Handle error response if resp.StatusCode() != fasthttp.StatusOK { - return nil, providerUtils.EnrichError(ctx, parseNebiusImageError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - Model: request.Model, - RequestType: schemas.ImageGenerationRequest, - }), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, parseNebiusImageError(resp), jsonData, nil, providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)) } body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } response := &schemas.BifrostImageGenerationResponse{} @@ -352,9 +343,6 @@ func (provider *NebiusProvider) ImageGeneration(ctx *schemas.BifrostContext, key return nil, bifrostErr } - response.ExtraFields.Provider = providerName - response.ExtraFields.ModelRequested = request.Model - response.ExtraFields.RequestType = schemas.ImageGenerationRequest response.ExtraFields.Latency = latency.Milliseconds() // Set raw request if enabled diff --git a/core/providers/ollama/ollama.go b/core/providers/ollama/ollama.go index 04aefd2338..82b9471a41 100644 --- a/core/providers/ollama/ollama.go +++ b/core/providers/ollama/ollama.go @@ -64,40 +64,12 @@ func (provider *OllamaProvider) GetProviderKey() schemas.ModelProvider { return schemas.Ollama } -// 
getBaseURL resolves the base URL for a request from the per-key ollama_key_config. -// Each Ollama key must have its own URL configured — there is no provider-level fallback. -func (provider *OllamaProvider) getBaseURL(key schemas.Key) string { - if key.OllamaKeyConfig != nil && key.OllamaKeyConfig.URL.GetValue() != "" { - return strings.TrimRight(key.OllamaKeyConfig.URL.GetValue(), "/") - } - return "" -} - -// baseURLOrError returns the resolved base URL or a BifrostError when none is configured. -func (provider *OllamaProvider) baseURLOrError(key schemas.Key) (string, *schemas.BifrostError) { - u := provider.getBaseURL(key) - if u == "" { - return "", providerUtils.NewBifrostOperationError( - "no base URL configured: set ollama_key_config.url on the key", - nil, - provider.GetProviderKey(), - ) - } - return u, nil -} - -// listModelsByKey performs a list models request for a single Ollama key, -// resolving the per-key URL so each backend is queried individually. +// listModelsByKey performs a list models request for a single Ollama key. func (provider *OllamaProvider) listModelsByKey(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { - baseURL, bifrostErr := provider.baseURLOrError(key) - if bifrostErr != nil { - return nil, bifrostErr - } - url := baseURL + providerUtils.GetPathFromContext(ctx, "/v1/models") return openai.ListModelsByKey( ctx, provider.client, - url, + key.OllamaKeyConfig.URL.GetValue()+providerUtils.GetPathFromContext(ctx, "/v1/models"), key, request.Unfiltered, provider.networkConfig.ExtraHeaders, @@ -121,14 +93,10 @@ func (provider *OllamaProvider) ListModels(ctx *schemas.BifrostContext, keys []s // TextCompletion performs a text completion request to the Ollama API. 
func (provider *OllamaProvider) TextCompletion(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (*schemas.BifrostTextCompletionResponse, *schemas.BifrostError) { - baseURL, bifrostErr := provider.baseURLOrError(key) - if bifrostErr != nil { - return nil, bifrostErr - } return openai.HandleOpenAITextCompletionRequest( ctx, provider.client, - baseURL+providerUtils.GetPathFromContext(ctx, "/v1/completions"), + key.OllamaKeyConfig.URL.GetValue()+providerUtils.GetPathFromContext(ctx, "/v1/completions"), request, key, provider.networkConfig.ExtraHeaders, @@ -145,14 +113,10 @@ func (provider *OllamaProvider) TextCompletion(ctx *schemas.BifrostContext, key // It formats the request, sends it to Ollama, and processes the response. // Returns a channel of BifrostStreamChunk objects or an error if the request fails. func (provider *OllamaProvider) TextCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { - baseURL, bifrostErr := provider.baseURLOrError(key) - if bifrostErr != nil { - return nil, bifrostErr - } return openai.HandleOpenAITextCompletionStreaming( ctx, provider.client, - baseURL+providerUtils.GetPathFromContext(ctx, "/v1/completions"), + key.OllamaKeyConfig.URL.GetValue()+providerUtils.GetPathFromContext(ctx, "/v1/completions"), request, nil, provider.networkConfig.ExtraHeaders, @@ -169,14 +133,10 @@ func (provider *OllamaProvider) TextCompletionStream(ctx *schemas.BifrostContext // ChatCompletion performs a chat completion request to the Ollama API. 
func (provider *OllamaProvider) ChatCompletion(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) { - baseURL, bifrostErr := provider.baseURLOrError(key) - if bifrostErr != nil { - return nil, bifrostErr - } return openai.HandleOpenAIChatCompletionRequest( ctx, provider.client, - baseURL+providerUtils.GetPathFromContext(ctx, "/v1/chat/completions"), + key.OllamaKeyConfig.URL.GetValue()+providerUtils.GetPathFromContext(ctx, "/v1/chat/completions"), request, key, provider.networkConfig.ExtraHeaders, @@ -194,15 +154,11 @@ func (provider *OllamaProvider) ChatCompletion(ctx *schemas.BifrostContext, key // Uses Ollama's OpenAI-compatible streaming format. // Returns a channel containing BifrostResponse objects representing the stream or an error if the request fails. func (provider *OllamaProvider) ChatCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { - baseURL, bifrostErr := provider.baseURLOrError(key) - if bifrostErr != nil { - return nil, bifrostErr - } // Use shared OpenAI-compatible streaming logic return openai.HandleOpenAIChatCompletionStreaming( ctx, provider.client, - baseURL+providerUtils.GetPathFromContext(ctx, "/v1/chat/completions"), + key.OllamaKeyConfig.URL.GetValue()+providerUtils.GetPathFromContext(ctx, "/v1/chat/completions"), request, nil, provider.networkConfig.ExtraHeaders, @@ -227,9 +183,6 @@ func (provider *OllamaProvider) Responses(ctx *schemas.BifrostContext, key schem } response := chatResponse.ToBifrostResponsesResponse() - response.ExtraFields.RequestType = schemas.ResponsesRequest - response.ExtraFields.Provider = provider.GetProviderKey() - response.ExtraFields.ModelRequested = request.Model return response, nil } @@ -247,14 +200,10 @@ func (provider *OllamaProvider) ResponsesStream(ctx 
*schemas.BifrostContext, pos // Embedding performs an embedding request to the Ollama API. func (provider *OllamaProvider) Embedding(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostEmbeddingRequest) (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) { - baseURL, bifrostErr := provider.baseURLOrError(key) - if bifrostErr != nil { - return nil, bifrostErr - } return openai.HandleOpenAIEmbeddingRequest( ctx, provider.client, - baseURL+providerUtils.GetPathFromContext(ctx, "/v1/embeddings"), + key.OllamaKeyConfig.URL.GetValue()+providerUtils.GetPathFromContext(ctx, "/v1/embeddings"), request, key, provider.networkConfig.ExtraHeaders, diff --git a/core/providers/openai/batch.go b/core/providers/openai/batch.go index ae095e5c77..ec8ce468bb 100644 --- a/core/providers/openai/batch.go +++ b/core/providers/openai/batch.go @@ -10,10 +10,10 @@ import ( // OpenAIBatchRequest represents the request body for creating a batch. type OpenAIBatchRequest struct { - InputFileID string `json:"input_file_id"` - Endpoint string `json:"endpoint"` - CompletionWindow string `json:"completion_window"` - Metadata map[string]string `json:"metadata,omitempty"` + InputFileID string `json:"input_file_id"` + Endpoint string `json:"endpoint"` + CompletionWindow string `json:"completion_window"` + Metadata map[string]string `json:"metadata,omitempty"` OutputExpiresAfter *schemas.BatchExpiresAfter `json:"output_expires_after,omitempty"` } @@ -82,7 +82,7 @@ func ToBifrostBatchStatus(status string) schemas.BatchStatus { } // ToBifrostBatchCreateResponse converts OpenAI batch response to Bifrost batch response. 
-func (r *OpenAIBatchResponse) ToBifrostBatchCreateResponse(providerName schemas.ModelProvider, latency time.Duration, sendBackRawRequest bool, sendBackRawResponse bool, rawRequest interface{}, rawResponse interface{}) *schemas.BifrostBatchCreateResponse { +func (r *OpenAIBatchResponse) ToBifrostBatchCreateResponse(latency time.Duration, sendBackRawRequest bool, sendBackRawResponse bool, rawRequest interface{}, rawResponse interface{}) *schemas.BifrostBatchCreateResponse { resp := &schemas.BifrostBatchCreateResponse{ ID: r.ID, Object: r.Object, @@ -95,9 +95,7 @@ func (r *OpenAIBatchResponse) ToBifrostBatchCreateResponse(providerName schemas. OutputFileID: r.OutputFileID, ErrorFileID: r.ErrorFileID, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.BatchCreateRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } @@ -125,7 +123,7 @@ func (r *OpenAIBatchResponse) ToBifrostBatchCreateResponse(providerName schemas. } // ToBifrostBatchRetrieveResponse converts OpenAI batch response to Bifrost batch retrieve response. 
-func (r *OpenAIBatchResponse) ToBifrostBatchRetrieveResponse(providerName schemas.ModelProvider, latency time.Duration, sendBackRawRequest bool, sendBackRawResponse bool, rawRequest interface{}, rawResponse interface{}) *schemas.BifrostBatchRetrieveResponse { +func (r *OpenAIBatchResponse) ToBifrostBatchRetrieveResponse(latency time.Duration, sendBackRawRequest bool, sendBackRawResponse bool, rawRequest interface{}, rawResponse interface{}) *schemas.BifrostBatchRetrieveResponse { resp := &schemas.BifrostBatchRetrieveResponse{ ID: r.ID, Object: r.Object, @@ -146,9 +144,7 @@ func (r *OpenAIBatchResponse) ToBifrostBatchRetrieveResponse(providerName schema ErrorFileID: r.ErrorFileID, Errors: r.Errors, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.BatchRetrieveRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } @@ -174,35 +170,3 @@ func (r *OpenAIBatchResponse) ToBifrostBatchRetrieveResponse(providerName schema return resp } - -// splitJSONL splits JSONL content into individual lines. -func splitJSONL(data []byte) [][]byte { - var lines [][]byte - start := 0 - for i, b := range data { - if b == '\n' { - if i > start { - end := i - // Strip trailing \r if present (handle CRLF) - if end > start && data[end-1] == '\r' { - end-- - } - if end > start { - lines = append(lines, data[start:end]) - } - } - start = i + 1 - } - } - if start < len(data) { - end := len(data) - // Strip trailing \r if present - if end > start && data[end-1] == '\r' { - end-- - } - if end > start { - lines = append(lines, data[start:end]) - } - } - return lines -} diff --git a/core/providers/openai/errors.go b/core/providers/openai/errors.go index 6a5bc1ce08..69d0aff407 100644 --- a/core/providers/openai/errors.go +++ b/core/providers/openai/errors.go @@ -10,10 +10,10 @@ import ( ) // ErrorConverter is a function that converts provider-specific error responses to BifrostError. 
-type ErrorConverter func(resp *fasthttp.Response, requestType schemas.RequestType, providerName schemas.ModelProvider, model string) *schemas.BifrostError +type ErrorConverter func(resp *fasthttp.Response) *schemas.BifrostError // ParseOpenAIError parses OpenAI error responses. -func ParseOpenAIError(resp *fasthttp.Response, requestType schemas.RequestType, providerName schemas.ModelProvider, model string) *schemas.BifrostError { +func ParseOpenAIError(resp *fasthttp.Response) *schemas.BifrostError { var errorResp schemas.BifrostError bifrostErr := providerUtils.HandleProviderAPIError(resp, &errorResp) @@ -49,11 +49,6 @@ func ParseOpenAIError(resp *fasthttp.Response, requestType schemas.RequestType, } // Set ExtraFields unconditionally so provider/model/request metadata is always attached - if bifrostErr != nil { - bifrostErr.ExtraFields.Provider = providerName - bifrostErr.ExtraFields.ModelRequested = model - bifrostErr.ExtraFields.RequestType = requestType - } return bifrostErr } diff --git a/core/providers/openai/errors_test.go b/core/providers/openai/errors_test.go index f33008600b..1132a92723 100644 --- a/core/providers/openai/errors_test.go +++ b/core/providers/openai/errors_test.go @@ -12,7 +12,7 @@ func TestParseOpenAIError_FallbackMessageWhenProviderBodyIsNonOpenAIShape(t *tes resp.SetStatusCode(fasthttp.StatusUnprocessableEntity) resp.SetBodyString(`{"detail":[{"loc":["body","messages",0,"role"],"msg":"value is not a valid enumeration member"}]}`) - errResp := ParseOpenAIError(&resp, schemas.ResponsesStreamRequest, schemas.Cerebras, "llama3.1-8b") + errResp := ParseOpenAIError(&resp) if errResp == nil || errResp.Error == nil { t.Fatal("expected non-nil error response") } @@ -29,7 +29,7 @@ func TestParseOpenAIError_PreservesProviderMessageWhenPresent(t *testing.T) { resp.SetStatusCode(fasthttp.StatusUnprocessableEntity) resp.SetBodyString(`{"error":{"message":"unsupported role: 
developer","type":"invalid_request_error","param":"messages.0.role","code":"invalid_value"}}`) - errResp := ParseOpenAIError(&resp, schemas.ChatCompletionRequest, schemas.OpenAI, "gpt-4o") + errResp := ParseOpenAIError(&resp) if errResp == nil || errResp.Error == nil { t.Fatal("expected non-nil error response") } @@ -43,7 +43,7 @@ func TestParseOpenAIError_FallbackMessageWhenBodyIsEmpty(t *testing.T) { resp.SetStatusCode(fasthttp.StatusBadRequest) resp.SetBody(nil) - errResp := ParseOpenAIError(&resp, schemas.ChatCompletionRequest, schemas.OpenAI, "gpt-4o") + errResp := ParseOpenAIError(&resp) if errResp == nil || errResp.Error == nil { t.Fatal("expected non-nil error response") } @@ -59,7 +59,7 @@ func TestParseOpenAIError_WhitespaceProviderMessageFallsBack(t *testing.T) { resp.SetStatusCode(fasthttp.StatusBadRequest) resp.SetBodyString(`{"error":{"message":" ","type":"invalid_request_error"}}`) - errResp := ParseOpenAIError(&resp, schemas.ChatCompletionRequest, schemas.OpenAI, "gpt-4o") + errResp := ParseOpenAIError(&resp) if errResp == nil || errResp.Error == nil { t.Fatal("expected non-nil error response") } @@ -73,7 +73,7 @@ func TestParseOpenAIError_DefaultStatusCodeFallsBackWithStatusNumber(t *testing. // fasthttp defaults zero-value response status code to 200. resp.SetBodyString(`{"error":{"message":""}}`) - errResp := ParseOpenAIError(&resp, schemas.ChatCompletionRequest, schemas.OpenAI, "gpt-4o") + errResp := ParseOpenAIError(&resp) if errResp == nil || errResp.Error == nil { t.Fatal("expected non-nil error response") } diff --git a/core/providers/openai/files.go b/core/providers/openai/files.go index bbaf2b2f70..133250cac7 100644 --- a/core/providers/openai/files.go +++ b/core/providers/openai/files.go @@ -55,7 +55,7 @@ func ToBifrostFileStatus(status string) schemas.FileStatus { } // ToBifrostFileUploadResponse converts OpenAI file response to Bifrost file upload response. 
-func (r *OpenAIFileResponse) ToBifrostFileUploadResponse(providerName schemas.ModelProvider, latency time.Duration, sendBackRawRequest bool, sendBackRawResponse bool, rawRequest interface{}, rawResponse interface{}) *schemas.BifrostFileUploadResponse { +func (r *OpenAIFileResponse) ToBifrostFileUploadResponse(latency time.Duration, sendBackRawRequest bool, sendBackRawResponse bool, rawRequest interface{}, rawResponse interface{}) *schemas.BifrostFileUploadResponse { resp := &schemas.BifrostFileUploadResponse{ ID: r.ID, Object: r.Object, @@ -67,9 +67,7 @@ func (r *OpenAIFileResponse) ToBifrostFileUploadResponse(providerName schemas.Mo StatusDetails: r.StatusDetails, StorageBackend: schemas.FileStorageAPI, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileUploadRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } @@ -97,9 +95,7 @@ func (r *OpenAIFileResponse) ToBifrostFileRetrieveResponse(providerName schemas. 
StatusDetails: r.StatusDetails, StorageBackend: schemas.FileStorageAPI, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileRetrieveRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } diff --git a/core/providers/openai/images.go b/core/providers/openai/images.go index f183e17ec5..9176f1e1e7 100644 --- a/core/providers/openai/images.go +++ b/core/providers/openai/images.go @@ -125,18 +125,18 @@ func ToOpenAIImageEditRequest(bifrostReq *schemas.BifrostImageEditRequest) *Open func parseImageEditFormDataBodyFromRequest(writer *multipart.Writer, openaiReq *OpenAIImageEditRequest, providerName schemas.ModelProvider) *schemas.BifrostError { // Add model field (required) if err := writer.WriteField("model", openaiReq.Model); err != nil { - return providerUtils.NewBifrostOperationError("failed to write model field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write model field", err) } // Add prompt field (required) if err := writer.WriteField("prompt", openaiReq.Input.Prompt); err != nil { - return providerUtils.NewBifrostOperationError("failed to write prompt field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write prompt field", err) } // Add stream field when requesting streaming if openaiReq.Stream != nil && *openaiReq.Stream { if err := writer.WriteField("stream", "true"); err != nil { - return providerUtils.NewBifrostOperationError("failed to write stream field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write stream field", err) } } @@ -168,71 +168,71 @@ func parseImageEditFormDataBodyFromRequest(writer *multipart.Writer, openaiReq * "Content-Type": {mimeType}, }) if err != nil { - return providerUtils.NewBifrostOperationError(fmt.Sprintf("failed to create form part for image %d", i), err, providerName) + return providerUtils.NewBifrostOperationError(fmt.Sprintf("failed to create 
form part for image %d", i), err) } if _, err := part.Write(imageInput.Image); err != nil { - return providerUtils.NewBifrostOperationError(fmt.Sprintf("failed to write image %d data", i), err, providerName) + return providerUtils.NewBifrostOperationError(fmt.Sprintf("failed to write image %d data", i), err) } } // Add optional parameters if openaiReq.N != nil { if err := writer.WriteField("n", strconv.Itoa(*openaiReq.N)); err != nil { - return providerUtils.NewBifrostOperationError("failed to write n field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write n field", err) } } if openaiReq.Size != nil { if err := writer.WriteField("size", *openaiReq.Size); err != nil { - return providerUtils.NewBifrostOperationError("failed to write size field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write size field", err) } } if openaiReq.ResponseFormat != nil { if err := writer.WriteField("response_format", *openaiReq.ResponseFormat); err != nil { - return providerUtils.NewBifrostOperationError("failed to write response_format field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write response_format field", err) } } if openaiReq.Quality != nil { if err := writer.WriteField("quality", *openaiReq.Quality); err != nil { - return providerUtils.NewBifrostOperationError("failed to write quality field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write quality field", err) } } if openaiReq.Background != nil { if err := writer.WriteField("background", *openaiReq.Background); err != nil { - return providerUtils.NewBifrostOperationError("failed to write background field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write background field", err) } } if openaiReq.InputFidelity != nil { if err := writer.WriteField("input_fidelity", *openaiReq.InputFidelity); err != nil { - return 
providerUtils.NewBifrostOperationError("failed to write input_fidelity field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write input_fidelity field", err) } } if openaiReq.PartialImages != nil { if err := writer.WriteField("partial_images", strconv.Itoa(*openaiReq.PartialImages)); err != nil { - return providerUtils.NewBifrostOperationError("failed to write partial_images field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write partial_images field", err) } } if openaiReq.OutputFormat != nil { if err := writer.WriteField("output_format", *openaiReq.OutputFormat); err != nil { - return providerUtils.NewBifrostOperationError("failed to write output_format field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write output_format field", err) } } if openaiReq.OutputCompression != nil { if err := writer.WriteField("output_compression", strconv.Itoa(*openaiReq.OutputCompression)); err != nil { - return providerUtils.NewBifrostOperationError("failed to write output_compression field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write output_compression field", err) } } if openaiReq.User != nil { if err := writer.WriteField("user", *openaiReq.User); err != nil { - return providerUtils.NewBifrostOperationError("failed to write user field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write user field", err) } } @@ -260,16 +260,16 @@ func parseImageEditFormDataBodyFromRequest(writer *multipart.Writer, openaiReq * "Content-Type": {maskMimeType}, }) if err != nil { - return providerUtils.NewBifrostOperationError("failed to create mask form part", err, providerName) + return providerUtils.NewBifrostOperationError("failed to create mask form part", err) } if _, err := maskPart.Write(openaiReq.Mask); err != nil { - return providerUtils.NewBifrostOperationError("failed to write mask data", err, providerName) + 
return providerUtils.NewBifrostOperationError("failed to write mask data", err) } } // Close the multipart writer if err := writer.Close(); err != nil { - return providerUtils.NewBifrostOperationError("failed to close multipart writer", err, providerName) + return providerUtils.NewBifrostOperationError("failed to close multipart writer", err) } return nil @@ -299,12 +299,12 @@ func ToOpenAIImageVariationRequest(bifrostReq *schemas.BifrostImageVariationRequ func parseImageVariationFormDataBodyFromRequest(writer *multipart.Writer, openaiReq *OpenAIImageVariationRequest, providerName schemas.ModelProvider) *schemas.BifrostError { // Add model field (required) if err := writer.WriteField("model", openaiReq.Model); err != nil { - return providerUtils.NewBifrostOperationError("failed to write model field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write model field", err) } // Add image file (required) if openaiReq.Input == nil || openaiReq.Input.Image.Image == nil || len(openaiReq.Input.Image.Image) == 0 { - return providerUtils.NewBifrostOperationError("image is required", nil, providerName) + return providerUtils.NewBifrostOperationError("image is required", nil) } // Detect MIME type @@ -320,41 +320,41 @@ func parseImageVariationFormDataBodyFromRequest(writer *multipart.Writer, openai "Content-Type": {mimeType}, }) if err != nil { - return providerUtils.NewBifrostOperationError("failed to create image part", err, providerName) + return providerUtils.NewBifrostOperationError("failed to create image part", err) } if _, err := part.Write(openaiReq.Input.Image.Image); err != nil { - return providerUtils.NewBifrostOperationError("failed to write image data", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write image data", err) } // Add optional parameters if openaiReq.N != nil { if err := writer.WriteField("n", strconv.Itoa(*openaiReq.N)); err != nil { - return 
providerUtils.NewBifrostOperationError("failed to write n field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write n field", err) } } if openaiReq.ResponseFormat != nil { if err := writer.WriteField("response_format", *openaiReq.ResponseFormat); err != nil { - return providerUtils.NewBifrostOperationError("failed to write response_format field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write response_format field", err) } } if openaiReq.Size != nil { if err := writer.WriteField("size", *openaiReq.Size); err != nil { - return providerUtils.NewBifrostOperationError("failed to write size field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write size field", err) } } if openaiReq.User != nil { if err := writer.WriteField("user", *openaiReq.User); err != nil { - return providerUtils.NewBifrostOperationError("failed to write user field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write user field", err) } } // Close the multipart writer if err := writer.Close(); err != nil { - return providerUtils.NewBifrostOperationError("failed to close multipart writer", err, providerName) + return providerUtils.NewBifrostOperationError("failed to close multipart writer", err) } return nil diff --git a/core/providers/openai/large_payload.go b/core/providers/openai/large_payload.go index 461f3417de..fe3aaf1812 100644 --- a/core/providers/openai/large_payload.go +++ b/core/providers/openai/large_payload.go @@ -42,8 +42,6 @@ func handleOpenAILargePayloadPassthrough( key schemas.Key, extraHeaders map[string]string, providerName schemas.ModelProvider, - model string, - requestType schemas.RequestType, logger schemas.Logger, ) (*largePayloadResult, *schemas.BifrostError, bool) { isLargePayload, _ := ctx.Value(schemas.BifrostContextKeyLargePayloadMode).(bool) @@ -91,7 +89,7 @@ func handleOpenAILargePayloadPassthrough( // Error responses are always small 
— materialize stream body for error parsing if resp.StatusCode() != fasthttp.StatusOK { providerUtils.MaterializeStreamErrorBody(ctx, resp) - parsedErr := ParseOpenAIError(resp, requestType, providerName, model) + parsedErr := ParseOpenAIError(resp) fasthttp.ReleaseResponse(resp) return nil, parsedErr, true } @@ -126,7 +124,7 @@ func finalizeOpenAIResponse( providerName schemas.ModelProvider, logger schemas.Logger, ) ([]byte, *largePayloadResult, *schemas.BifrostError) { - body, isLarge, bifrostErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, providerName, logger) + body, isLarge, bifrostErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, logger) if bifrostErr != nil { fasthttp.ReleaseResponse(resp) return nil, nil, bifrostErr diff --git a/core/providers/openai/models.go b/core/providers/openai/models.go index 8268608568..a76d350d28 100644 --- a/core/providers/openai/models.go +++ b/core/providers/openai/models.go @@ -3,11 +3,12 @@ package openai import ( "strings" + providerUtils "github.com/maximhq/bifrost/core/providers/utils" "github.com/maximhq/bifrost/core/schemas" ) // ToBifrostListModelsResponse converts an OpenAI list models response to a Bifrost list models response -func (response *OpenAIListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels schemas.WhiteList, blacklistedModels schemas.BlackList, unfiltered bool) *schemas.BifrostListModelsResponse { +func (response *OpenAIListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels schemas.WhiteList, blacklistedModels schemas.BlackList, aliases map[string]string, unfiltered bool) *schemas.BifrostListModelsResponse { if response == nil { return nil } @@ -16,42 +17,39 @@ func (response *OpenAIListModelsResponse) ToBifrostListModelsResponse(providerKe Data: make([]schemas.Model, 0, len(response.Data)), } - if !unfiltered && (allowedModels.IsEmpty() || blacklistedModels.IsBlockAll()) { + pipeline := 
&providerUtils.ListModelsPipeline{ + AllowedModels: allowedModels, + BlacklistedModels: blacklistedModels, + Aliases: aliases, + Unfiltered: unfiltered, + ProviderKey: providerKey, + MatchFns: providerUtils.DefaultMatchFns(), + } + if pipeline.ShouldEarlyExit() { return bifrostResponse } - includedModels := make(map[string]bool) - for _, model := range response.Data { - if !unfiltered && allowedModels.IsRestricted() && !allowedModels.Contains(model.ID) { - continue - } - if !unfiltered && blacklistedModels.IsBlocked(model.ID) { - continue - } - bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ - ID: string(providerKey) + "/" + model.ID, - Created: model.Created, - OwnedBy: schemas.Ptr(model.OwnedBy), - ContextLength: model.ContextWindow, - }) - includedModels[strings.ToLower(model.ID)] = true - } + included := make(map[string]bool) - // Backfill allowed models that were not in the response - if !unfiltered && allowedModels.IsRestricted() { - for _, allowedModel := range allowedModels { - if blacklistedModels.IsBlocked(allowedModel) { - continue + for _, model := range response.Data { + for _, result := range pipeline.FilterModel(model.ID) { + entry := schemas.Model{ + ID: string(providerKey) + "/" + result.ResolvedID, + Created: model.Created, + OwnedBy: schemas.Ptr(model.OwnedBy), + ContextLength: model.ContextWindow, } - if !includedModels[strings.ToLower(allowedModel)] { - bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ - ID: string(providerKey) + "/" + allowedModel, - Name: schemas.Ptr(allowedModel), - }) + if result.AliasValue != "" { + entry.Alias = schemas.Ptr(result.AliasValue) } + bifrostResponse.Data = append(bifrostResponse.Data, entry) + included[strings.ToLower(result.ResolvedID)] = true } } + bifrostResponse.Data = append(bifrostResponse.Data, + pipeline.BackfillModels(included)...) 
+ return bifrostResponse } diff --git a/core/providers/openai/openai.go b/core/providers/openai/openai.go index de6d151519..7bc9ad73d8 100644 --- a/core/providers/openai/openai.go +++ b/core/providers/openai/openai.go @@ -166,7 +166,7 @@ func ListModelsByKey( // Handle error response if resp.StatusCode() != fasthttp.StatusOK { - bifrostErr := ParseOpenAIError(resp, schemas.ListModelsRequest, providerName, "") + bifrostErr := ParseOpenAIError(resp) return nil, bifrostErr } @@ -181,10 +181,8 @@ func ListModelsByKey( return nil, bifrostErr } - response := openaiResponse.ToBifrostListModelsResponse(providerName, key.Models, key.BlacklistedModels, unfiltered) + response := openaiResponse.ToBifrostListModelsResponse(providerName, key.Models, key.BlacklistedModels, key.Aliases, unfiltered) - response.ExtraFields.Provider = providerName - response.ExtraFields.RequestType = schemas.ListModelsRequest response.ExtraFields.Latency = latency.Milliseconds() response.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -289,22 +287,22 @@ func HandleOpenAITextCompletionRequest( } // Large payload passthrough: stream body directly without JSON marshaling - if lpResult, lpErr, handled := handleOpenAILargePayloadPassthrough(ctx, client, url, key, extraHeaders, providerName, request.Model, schemas.TextCompletionRequest, logger); handled { + if lpResult, lpErr, handled := handleOpenAILargePayloadPassthrough(ctx, client, url, key, extraHeaders, providerName, logger); handled { if lpErr != nil { return nil, lpErr } if len(lpResult.ResponseBody) > 0 { response := &schemas.BifrostTextCompletionResponse{} if err := sonic.Unmarshal(lpResult.ResponseBody, response); err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err) } - response.ExtraFields = schemas.BifrostResponseExtraFields{Provider: providerName, ModelRequested: 
request.Model, RequestType: schemas.TextCompletionRequest, Latency: lpResult.Latency} + response.ExtraFields = schemas.BifrostResponseExtraFields{Latency: lpResult.Latency} return response, nil } return &schemas.BifrostTextCompletionResponse{ Model: request.Model, Usage: lpResult.Usage, - ExtraFields: schemas.BifrostResponseExtraFields{Provider: providerName, ModelRequested: request.Model, RequestType: schemas.TextCompletionRequest, Latency: lpResult.Latency}, + ExtraFields: schemas.BifrostResponseExtraFields{Latency: lpResult.Latency}, }, nil } @@ -313,8 +311,7 @@ func HandleOpenAITextCompletionRequest( request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToOpenAITextCompletionRequest(request), nil - }, - providerName) + }) if bifrostErr != nil { return nil, bifrostErr } @@ -335,9 +332,9 @@ func HandleOpenAITextCompletionRequest( if resp.StatusCode() != fasthttp.StatusOK { providerUtils.MaterializeStreamErrorBody(ctx, resp) if customErrorConverter != nil { - return nil, providerUtils.EnrichError(ctx, customErrorConverter(resp, schemas.TextCompletionRequest, providerName, request.Model), jsonData, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, customErrorConverter(resp), jsonData, nil, sendBackRawRequest, sendBackRawResponse) } - return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp, schemas.TextCompletionRequest, providerName, request.Model), jsonData, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp), jsonData, nil, sendBackRawRequest, sendBackRawResponse) } body, lpResult, finalErr := finalizeOpenAIResponse(ctx, resp, latency, providerName, logger) @@ -349,7 +346,7 @@ func HandleOpenAITextCompletionRequest( return &schemas.BifrostTextCompletionResponse{ Model: request.Model, Usage: lpResult.Usage, - ExtraFields: schemas.BifrostResponseExtraFields{Provider: providerName, ModelRequested: request.Model, RequestType: 
schemas.TextCompletionRequest, Latency: lpResult.Latency}, + ExtraFields: schemas.BifrostResponseExtraFields{Latency: lpResult.Latency}, }, nil } @@ -367,9 +364,6 @@ func HandleOpenAITextCompletionRequest( return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonData, body, sendBackRawRequest, sendBackRawResponse) } - response.ExtraFields.Provider = providerName - response.ExtraFields.ModelRequested = request.Model - response.ExtraFields.RequestType = schemas.TextCompletionRequest response.ExtraFields.Latency = latency.Milliseconds() response.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -455,8 +449,7 @@ func HandleOpenAITextCompletionStreaming( } } return reqBody, nil - }, - providerName) + }) if bifrostErr != nil { return nil, bifrostErr @@ -501,9 +494,9 @@ func HandleOpenAITextCompletionStreaming( }, jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, providerName), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, providerName), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } // Store provider response headers in context before status check so error responses also forward them @@ -514,9 +507,9 @@ func HandleOpenAITextCompletionStreaming( defer providerUtils.ReleaseStreamingResponse(resp) providerUtils.MaterializeStreamErrorBody(ctx, resp) if 
customErrorConverter != nil { - return nil, providerUtils.EnrichError(ctx, customErrorConverter(resp, schemas.TextCompletionStreamRequest, providerName, request.Model), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, customErrorConverter(resp), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } - return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp, schemas.TextCompletionStreamRequest, providerName, request.Model), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } // Large payload streaming passthrough — pipe raw upstream SSE to client @@ -533,9 +526,9 @@ func HandleOpenAITextCompletionStreaming( go func() { defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.TextCompletionStreamRequest, logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, logger) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.TextCompletionStreamRequest, logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, logger) } close(responseChan) }() @@ -557,7 +550,7 @@ func HandleOpenAITextCompletionStreaming( // on non-line-delimited data (e.g. provider returned JSON instead of SSE). 
if providerUtils.DrainNonSSEStreamResponse(resp) { ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendError(ctx, postHookRunner, errors.New("provider returned non-SSE response for streaming request"), responseChan, schemas.TextCompletionStreamRequest, providerName, request.Model, logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, errors.New("provider returned non-SSE response for streaming request"), responseChan, logger) return } @@ -584,7 +577,7 @@ func HandleOpenAITextCompletionStreaming( } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) logger.Warn("Error reading stream: %v", readErr) - providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, schemas.TextCompletionStreamRequest, providerName, request.Model, logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, logger) return } break @@ -595,11 +588,6 @@ func HandleOpenAITextCompletionStreaming( rawRequest, rawResponse, handlerErr := customResponseHandler([]byte(jsonData), &response, nil, sendBackRawRequest, sendBackRawResponse) if handlerErr != nil { // TODO fix this - handlerErr.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.TextCompletionStreamRequest, - } if sendBackRawRequest { handlerErr.ExtraFields.RawRequest = rawRequest } @@ -618,11 +606,6 @@ func HandleOpenAITextCompletionStreaming( var bifrostErr schemas.BifrostError if err := sonic.UnmarshalString(jsonData, &bifrostErr); err == nil { if bifrostErr.Error != nil && bifrostErr.Error.Message != "" { - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.TextCompletionStreamRequest, - } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, &bifrostErr, jsonBody, 
nil, sendBackRawRequest, sendBackRawResponse), responseChan, logger) return @@ -699,9 +682,6 @@ func HandleOpenAITextCompletionStreaming( if choice.TextCompletionResponseChoice != nil && choice.TextCompletionResponseChoice.Text != nil { chunkIndex++ - response.ExtraFields.RequestType = schemas.TextCompletionStreamRequest - response.ExtraFields.Provider = providerName - response.ExtraFields.ModelRequested = request.Model response.ExtraFields.ChunkIndex = chunkIndex response.ExtraFields.Latency = time.Since(lastChunkTime).Milliseconds() lastChunkTime = time.Now() @@ -719,7 +699,7 @@ func HandleOpenAITextCompletionStreaming( } } - response := providerUtils.CreateBifrostTextCompletionChunkResponse(messageID, usage, finishReason, chunkIndex, schemas.TextCompletionStreamRequest, providerName, request.Model) + response := providerUtils.CreateBifrostTextCompletionChunkResponse(messageID, usage, finishReason, chunkIndex, schemas.TextCompletionStreamRequest) if postResponseConverter != nil { response = postResponseConverter(response) if response == nil { @@ -811,22 +791,22 @@ func HandleOpenAIChatCompletionRequest( } // Large payload passthrough: stream body directly without JSON marshaling - if lpResult, lpErr, handled := handleOpenAILargePayloadPassthrough(ctx, client, url, key, extraHeaders, providerName, request.Model, schemas.ChatCompletionRequest, logger); handled { + if lpResult, lpErr, handled := handleOpenAILargePayloadPassthrough(ctx, client, url, key, extraHeaders, providerName, logger); handled { if lpErr != nil { return nil, lpErr } if len(lpResult.ResponseBody) > 0 { response := &schemas.BifrostChatResponse{} if err := sonic.Unmarshal(lpResult.ResponseBody, response); err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err) } - response.ExtraFields = schemas.BifrostResponseExtraFields{Provider: 
providerName, ModelRequested: request.Model, RequestType: schemas.ChatCompletionRequest, Latency: lpResult.Latency} + response.ExtraFields = schemas.BifrostResponseExtraFields{Latency: lpResult.Latency} return response, nil } return &schemas.BifrostChatResponse{ Model: request.Model, Usage: lpResult.Usage, - ExtraFields: schemas.BifrostResponseExtraFields{Provider: providerName, ModelRequested: request.Model, RequestType: schemas.ChatCompletionRequest, Latency: lpResult.Latency}, + ExtraFields: schemas.BifrostResponseExtraFields{Latency: lpResult.Latency}, }, nil } @@ -835,8 +815,7 @@ func HandleOpenAIChatCompletionRequest( request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToOpenAIChatRequest(ctx, request), nil - }, - providerName) + }) if bifrostErr != nil { return nil, bifrostErr } @@ -858,9 +837,9 @@ func HandleOpenAIChatCompletionRequest( providerUtils.MaterializeStreamErrorBody(ctx, resp) logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) if customErrorConverter != nil { - return nil, providerUtils.EnrichError(ctx, customErrorConverter(resp, schemas.ChatCompletionRequest, providerName, request.Model), jsonData, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, customErrorConverter(resp), jsonData, nil, sendBackRawRequest, sendBackRawResponse) } - return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp, schemas.ChatCompletionRequest, providerName, request.Model), jsonData, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp), jsonData, nil, sendBackRawRequest, sendBackRawResponse) } body, lpResult, finalErr := finalizeOpenAIResponse(ctx, resp, latency, providerName, logger) @@ -872,7 +851,7 @@ func HandleOpenAIChatCompletionRequest( return &schemas.BifrostChatResponse{ Model: request.Model, Usage: lpResult.Usage, - ExtraFields: schemas.BifrostResponseExtraFields{Provider: providerName, ModelRequested: 
request.Model, RequestType: schemas.ChatCompletionRequest, Latency: lpResult.Latency}, + ExtraFields: schemas.BifrostResponseExtraFields{Latency: lpResult.Latency}, }, nil } response := &schemas.BifrostChatResponse{} @@ -890,9 +869,6 @@ func HandleOpenAIChatCompletionRequest( return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonData, body, sendBackRawRequest, sendBackRawResponse) } - response.ExtraFields.Provider = providerName - response.ExtraFields.ModelRequested = request.Model - response.ExtraFields.RequestType = schemas.ChatCompletionRequest response.ExtraFields.Latency = latency.Milliseconds() // Set raw request if enabled @@ -1008,8 +984,7 @@ func HandleOpenAIChatCompletionStreaming( } } return reqBody, nil - }, - providerName) + }) if bifrostErr != nil { return nil, bifrostErr } @@ -1054,9 +1029,9 @@ func HandleOpenAIChatCompletionStreaming( }, jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, providerName), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, providerName), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } // Store provider response headers in context before status check so error responses also forward them @@ -1067,9 +1042,9 @@ func HandleOpenAIChatCompletionStreaming( defer providerUtils.ReleaseStreamingResponse(resp) 
providerUtils.MaterializeStreamErrorBody(ctx, resp) if customErrorConverter != nil { - return nil, providerUtils.EnrichError(ctx, customErrorConverter(resp, schemas.ChatCompletionStreamRequest, providerName, request.Model), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, customErrorConverter(resp), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } - return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp, schemas.ChatCompletionStreamRequest, providerName, request.Model), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } // Large payload streaming passthrough — pipe raw upstream SSE to client @@ -1082,19 +1057,13 @@ func HandleOpenAIChatCompletionStreaming( // Create response channel responseChan := make(chan *schemas.BifrostStreamChunk, schemas.DefaultStreamBufferSize) - // Determine request type for cleanup - streamRequestType := schemas.ChatCompletionStreamRequest - if isResponsesToChatCompletionsFallback { - streamRequestType = schemas.ResponsesStreamRequest - } - // Start streaming in a goroutine go func() { defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, request.Model, streamRequestType, logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, logger) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, request.Model, streamRequestType, logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, logger) } // Release the responses stream state if it was acquired (for ResponsesToChatCompletions fallback) schemas.ReleaseChatToResponsesStreamState(responsesStreamState) @@ -1118,7 +1087,7 @@ func HandleOpenAIChatCompletionStreaming( // on 
non-line-delimited data (e.g. provider returned JSON instead of SSE). if providerUtils.DrainNonSSEStreamResponse(resp) { ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendError(ctx, postHookRunner, errors.New("provider returned non-SSE response for streaming request"), responseChan, streamRequestType, providerName, request.Model, logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, errors.New("provider returned non-SSE response for streaming request"), responseChan, logger) return } @@ -1132,6 +1101,7 @@ func HandleOpenAIChatCompletionStreaming( var finishReason *string var messageID string + var modelName string var created int forwardedTerminalFinishReason := false @@ -1148,7 +1118,7 @@ func HandleOpenAIChatCompletionStreaming( } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) logger.Warn("Error reading stream: %v", readErr) - providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, streamRequestType, providerName, request.Model, logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, logger) return } break @@ -1161,11 +1131,6 @@ func HandleOpenAIChatCompletionStreaming( var bifrostErr schemas.BifrostError if err := sonic.UnmarshalString(jsonData, &bifrostErr); err == nil { if bifrostErr.Error != nil && bifrostErr.Error.Message != "" { - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: streamRequestType, - } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, &bifrostErr, jsonBody, nil, sendBackRawRequest, sendBackRawResponse), responseChan, logger) return @@ -1179,11 +1144,6 @@ func HandleOpenAIChatCompletionStreaming( if customResponseHandler != nil { rawRequest, rawResponse, handlerErr := customResponseHandler([]byte(jsonData), &response, nil, 
sendBackRawRequest, sendBackRawResponse) if handlerErr != nil { - handlerErr.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: streamRequestType, - } if sendBackRawRequest { handlerErr.ExtraFields.RawRequest = rawRequest } @@ -1214,11 +1174,6 @@ func HandleOpenAIChatCompletionStreaming( Type: schemas.Ptr(string(schemas.ResponsesStreamResponseTypeError)), IsBifrostError: false, Error: &schemas.ErrorField{}, - ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: streamRequestType, - Provider: providerName, - ModelRequested: request.Model, - }, } if response.Message != nil { @@ -1236,9 +1191,6 @@ func HandleOpenAIChatCompletionStreaming( return } - response.ExtraFields.RequestType = streamRequestType - response.ExtraFields.Provider = providerName - response.ExtraFields.ModelRequested = request.Model response.ExtraFields.ChunkIndex = response.SequenceNumber if sendBackRawResponse { @@ -1301,6 +1253,10 @@ func HandleOpenAIChatCompletionStreaming( response.Usage = nil } + if response.Model != "" { + modelName = response.Model + } + // Skip empty responses or responses without choices if len(response.Choices) == 0 { continue @@ -1333,9 +1289,6 @@ func HandleOpenAIChatCompletionStreaming( } chunkIndex++ - response.ExtraFields.RequestType = schemas.ChatCompletionStreamRequest - response.ExtraFields.Provider = providerName - response.ExtraFields.ModelRequested = request.Model response.ExtraFields.ChunkIndex = chunkIndex response.ExtraFields.Latency = time.Since(lastChunkTime).Milliseconds() lastChunkTime = time.Now() @@ -1359,7 +1312,7 @@ func HandleOpenAIChatCompletionStreaming( if forwardedTerminalFinishReason { finalFinishReason = nil } - response := providerUtils.CreateBifrostChatCompletionChunkResponse(messageID, usage, finalFinishReason, chunkIndex, streamRequestType, providerName, request.Model, created) + response := providerUtils.CreateBifrostChatCompletionChunkResponse(messageID, usage, 
finalFinishReason, chunkIndex, modelName, created) if postResponseConverter != nil { response = postResponseConverter(response) } @@ -1446,21 +1399,21 @@ func HandleOpenAIResponsesRequest( } // Large payload passthrough: stream body directly without JSON marshaling - if lpResult, lpErr, handled := handleOpenAILargePayloadPassthrough(ctx, client, url, key, extraHeaders, providerName, request.Model, schemas.ResponsesRequest, logger); handled { + if lpResult, lpErr, handled := handleOpenAILargePayloadPassthrough(ctx, client, url, key, extraHeaders, providerName, logger); handled { if lpErr != nil { return nil, lpErr } if len(lpResult.ResponseBody) > 0 { response := &schemas.BifrostResponsesResponse{} if err := sonic.Unmarshal(lpResult.ResponseBody, response); err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err) } - response.ExtraFields = schemas.BifrostResponseExtraFields{Provider: providerName, ModelRequested: request.Model, RequestType: schemas.ResponsesRequest, Latency: lpResult.Latency} + response.ExtraFields = schemas.BifrostResponseExtraFields{Latency: lpResult.Latency} return response, nil } return &schemas.BifrostResponsesResponse{ Model: request.Model, - ExtraFields: schemas.BifrostResponseExtraFields{Provider: providerName, ModelRequested: request.Model, RequestType: schemas.ResponsesRequest, Latency: lpResult.Latency}, + ExtraFields: schemas.BifrostResponseExtraFields{Latency: lpResult.Latency}, }, nil } @@ -1470,8 +1423,7 @@ func HandleOpenAIResponsesRequest( request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToOpenAIResponsesRequest(request), nil - }, - providerName) + }) if bifrostErr != nil { return nil, bifrostErr } @@ -1493,9 +1445,9 @@ func HandleOpenAIResponsesRequest( providerUtils.MaterializeStreamErrorBody(ctx, resp) logger.Debug("error from %s provider: %s", 
providerName, string(resp.Body())) if customErrorConverter != nil { - return nil, providerUtils.EnrichError(ctx, customErrorConverter(resp, schemas.ResponsesRequest, providerName, request.Model), jsonData, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, customErrorConverter(resp), jsonData, nil, sendBackRawRequest, sendBackRawResponse) } - return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp, schemas.ResponsesRequest, providerName, request.Model), jsonData, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp), jsonData, nil, sendBackRawRequest, sendBackRawResponse) } body, lpResult, finalErr := finalizeOpenAIResponse(ctx, resp, latency, providerName, logger) @@ -1506,7 +1458,7 @@ func HandleOpenAIResponsesRequest( if lpResult != nil { return &schemas.BifrostResponsesResponse{ Model: request.Model, - ExtraFields: schemas.BifrostResponseExtraFields{Provider: providerName, ModelRequested: request.Model, RequestType: schemas.ResponsesRequest, Latency: lpResult.Latency}, + ExtraFields: schemas.BifrostResponseExtraFields{Latency: lpResult.Latency}, }, nil } @@ -1524,9 +1476,6 @@ func HandleOpenAIResponsesRequest( return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonData, body, sendBackRawRequest, sendBackRawResponse) } - response.ExtraFields.Provider = providerName - response.ExtraFields.ModelRequested = request.Model - response.ExtraFields.RequestType = schemas.ResponsesRequest response.ExtraFields.Latency = latency.Milliseconds() response.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -1623,8 +1572,7 @@ func HandleOpenAIResponsesStreaming( } } return reqBody, nil - }, - providerName) + }) if bifrostErr != nil { return nil, bifrostErr } @@ -1668,9 +1616,9 @@ func HandleOpenAIResponsesStreaming( }, jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, 
context.DeadlineExceeded) { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, providerName), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, providerName), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } // Store provider response headers in context before status check so error responses also forward them @@ -1681,9 +1629,9 @@ func HandleOpenAIResponsesStreaming( defer providerUtils.ReleaseStreamingResponse(resp) providerUtils.MaterializeStreamErrorBody(ctx, resp) if customErrorConverter != nil { - return nil, providerUtils.EnrichError(ctx, customErrorConverter(resp, schemas.ResponsesStreamRequest, providerName, request.Model), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, customErrorConverter(resp), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } - return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp, schemas.ResponsesStreamRequest, providerName, request.Model), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } // Large payload streaming passthrough — pipe raw upstream SSE to client @@ -1700,9 +1648,9 @@ func HandleOpenAIResponsesStreaming( go func() { defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, 
request.Model, schemas.ResponsesStreamRequest, logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, logger) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.ResponsesStreamRequest, logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, logger) } close(responseChan) }() @@ -1724,7 +1672,7 @@ func HandleOpenAIResponsesStreaming( // on non-line-delimited data (e.g. provider returned JSON instead of SSE). if providerUtils.DrainNonSSEStreamResponse(resp) { ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendError(ctx, postHookRunner, errors.New("provider returned non-SSE response for streaming request"), responseChan, schemas.ResponsesStreamRequest, providerName, request.Model, logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, errors.New("provider returned non-SSE response for streaming request"), responseChan, logger) return } @@ -1746,7 +1694,7 @@ func HandleOpenAIResponsesStreaming( } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) logger.Warn("Error reading stream: %v", readErr) - providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, schemas.ResponsesStreamRequest, providerName, request.Model, logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, logger) } break } @@ -1758,11 +1706,6 @@ func HandleOpenAIResponsesStreaming( if customResponseHandler != nil { rawRequest, rawResponse, bifrostErr := customResponseHandler([]byte(jsonData), &response, nil, false, false) if bifrostErr != nil { - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.ResponsesStreamRequest, - } if sendBackRawRequest { bifrostErr.ExtraFields.RawRequest = rawRequest } @@ -1796,11 +1739,6 @@ func 
HandleOpenAIResponsesStreaming( Type: schemas.Ptr(string(schemas.ResponsesStreamResponseTypeError)), IsBifrostError: false, Error: &schemas.ErrorField{}, - ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.ResponsesStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - }, } if response.Message != nil { @@ -1818,9 +1756,6 @@ func HandleOpenAIResponsesStreaming( return } - response.ExtraFields.RequestType = schemas.ResponsesStreamRequest - response.ExtraFields.Provider = providerName - response.ExtraFields.ModelRequested = request.Model response.ExtraFields.ChunkIndex = response.SequenceNumber if response.Type == schemas.ResponsesStreamResponseTypeCompleted || response.Type == schemas.ResponsesStreamResponseTypeIncomplete { @@ -1910,22 +1845,22 @@ func HandleOpenAIEmbeddingRequest( } // Large payload passthrough: stream body directly without JSON marshaling - if lpResult, lpErr, handled := handleOpenAILargePayloadPassthrough(ctx, client, url, key, extraHeaders, providerName, request.Model, schemas.EmbeddingRequest, logger); handled { + if lpResult, lpErr, handled := handleOpenAILargePayloadPassthrough(ctx, client, url, key, extraHeaders, providerName, logger); handled { if lpErr != nil { return nil, lpErr } if len(lpResult.ResponseBody) > 0 { response := &schemas.BifrostEmbeddingResponse{} if err := sonic.Unmarshal(lpResult.ResponseBody, response); err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err) } - response.ExtraFields = schemas.BifrostResponseExtraFields{Provider: providerName, ModelRequested: request.Model, RequestType: schemas.EmbeddingRequest, Latency: lpResult.Latency} + response.ExtraFields = schemas.BifrostResponseExtraFields{Latency: lpResult.Latency} return response, nil } return &schemas.BifrostEmbeddingResponse{ Model: request.Model, Usage: 
lpResult.Usage, - ExtraFields: schemas.BifrostResponseExtraFields{Provider: providerName, ModelRequested: request.Model, RequestType: schemas.EmbeddingRequest, Latency: lpResult.Latency}, + ExtraFields: schemas.BifrostResponseExtraFields{Latency: lpResult.Latency}, }, nil } @@ -1935,8 +1870,7 @@ func HandleOpenAIEmbeddingRequest( request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToOpenAIEmbeddingRequest(request), nil - }, - providerName) + }) if bifrostErr != nil { return nil, bifrostErr } @@ -1957,7 +1891,7 @@ func HandleOpenAIEmbeddingRequest( if resp.StatusCode() != fasthttp.StatusOK { providerUtils.MaterializeStreamErrorBody(ctx, resp) logger.Debug(fmt.Sprintf("error from %s provider: %s", providerName, string(resp.Body()))) - return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp, schemas.EmbeddingRequest, providerName, request.Model), jsonData, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp), jsonData, nil, sendBackRawRequest, sendBackRawResponse) } body, lpResult, finalErr := finalizeOpenAIResponse(ctx, resp, latency, providerName, logger) @@ -1969,7 +1903,7 @@ func HandleOpenAIEmbeddingRequest( return &schemas.BifrostEmbeddingResponse{ Model: request.Model, Usage: lpResult.Usage, - ExtraFields: schemas.BifrostResponseExtraFields{Provider: providerName, ModelRequested: request.Model, RequestType: schemas.EmbeddingRequest, Latency: lpResult.Latency}, + ExtraFields: schemas.BifrostResponseExtraFields{Latency: lpResult.Latency}, }, nil } @@ -1987,9 +1921,6 @@ func HandleOpenAIEmbeddingRequest( return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonData, body, sendBackRawRequest, sendBackRawResponse) } - response.ExtraFields.Provider = providerName - response.ExtraFields.ModelRequested = request.Model - response.ExtraFields.RequestType = schemas.EmbeddingRequest response.ExtraFields.Latency = latency.Milliseconds() response.ExtraFields.ProviderResponseHeaders 
= providerResponseHeaders @@ -2068,22 +1999,21 @@ func HandleOpenAISpeechRequest( } // Large payload passthrough: stream body directly without JSON marshaling - if lpResult, lpErr, handled := handleOpenAILargePayloadPassthrough(ctx, client, url, key, extraHeaders, providerName, request.Model, schemas.SpeechRequest, logger); handled { + if lpResult, lpErr, handled := handleOpenAILargePayloadPassthrough(ctx, client, url, key, extraHeaders, providerName, logger); handled { if lpErr != nil { return nil, lpErr } // Speech response is raw audio bytes (MP3/WAV), not JSON return &schemas.BifrostSpeechResponse{ Audio: lpResult.ResponseBody, - ExtraFields: schemas.BifrostResponseExtraFields{Provider: providerName, ModelRequested: request.Model, RequestType: schemas.SpeechRequest, Latency: lpResult.Latency}, + ExtraFields: schemas.BifrostResponseExtraFields{Latency: lpResult.Latency}, }, nil } jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, request, - func() (providerUtils.RequestBodyWithExtraParams, error) { return ToOpenAISpeechRequest(request), nil }, - providerName) + func() (providerUtils.RequestBodyWithExtraParams, error) { return ToOpenAISpeechRequest(request), nil }) if bifrostErr != nil { return nil, bifrostErr } @@ -2104,7 +2034,7 @@ func HandleOpenAISpeechRequest( if resp.StatusCode() != fasthttp.StatusOK { providerUtils.MaterializeStreamErrorBody(ctx, resp) logger.Debug(fmt.Sprintf("error from %s provider: %s", providerName, string(resp.Body()))) - return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp, schemas.SpeechRequest, providerName, request.Model), jsonData, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp), jsonData, nil, sendBackRawRequest, sendBackRawResponse) } // Get the binary audio data from the response body @@ -2115,7 +2045,7 @@ func HandleOpenAISpeechRequest( } if lpResult != nil { return &schemas.BifrostSpeechResponse{ - ExtraFields: 
schemas.BifrostResponseExtraFields{Provider: providerName, ModelRequested: request.Model, RequestType: schemas.SpeechRequest, Latency: lpResult.Latency}, + ExtraFields: schemas.BifrostResponseExtraFields{Latency: lpResult.Latency}, }, nil } @@ -2125,9 +2055,6 @@ func HandleOpenAISpeechRequest( bifrostResponse := &schemas.BifrostSpeechResponse{ Audio: body, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.SpeechRequest, - Provider: providerName, - ModelRequested: request.Model, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerResponseHeaders, }, @@ -2150,7 +2077,7 @@ func (provider *OpenAIProvider) SpeechStream(ctx *schemas.BifrostContext, postHo for _, model := range providerUtils.UnsupportedSpeechStreamModels { if model == request.Model { - return nil, providerUtils.NewBifrostOperationError(fmt.Sprintf("model %s is not supported for streaming speech synthesis", model), nil, provider.GetProviderKey()) + return nil, providerUtils.NewBifrostOperationError(fmt.Sprintf("model %s is not supported for streaming speech synthesis", model), nil) } } @@ -2235,8 +2162,7 @@ func HandleOpenAISpeechStreamRequest( } } return reqBody, nil - }, - providerName) + }) if bifrostErr != nil { return nil, bifrostErr } @@ -2262,9 +2188,9 @@ func HandleOpenAISpeechStreamRequest( }, jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, providerName), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, providerName), jsonBody, nil, 
sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } // Store provider response headers in context before status check so error responses also forward them @@ -2274,7 +2200,7 @@ func HandleOpenAISpeechStreamRequest( if resp.StatusCode() != fasthttp.StatusOK { defer providerUtils.ReleaseStreamingResponse(resp) providerUtils.MaterializeStreamErrorBody(ctx, resp) - return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp, schemas.SpeechStreamRequest, providerName, request.Model), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } // Large payload streaming passthrough — pipe raw upstream SSE to client @@ -2291,9 +2217,9 @@ func HandleOpenAISpeechStreamRequest( go func() { defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.SpeechStreamRequest, logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, logger) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.SpeechStreamRequest, logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, logger) } close(responseChan) }() @@ -2315,7 +2241,7 @@ func HandleOpenAISpeechStreamRequest( // on non-line-delimited data (e.g. provider returned JSON instead of SSE). 
if providerUtils.DrainNonSSEStreamResponse(resp) { ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendError(ctx, postHookRunner, errors.New("provider returned non-SSE response for streaming request"), responseChan, schemas.SpeechStreamRequest, providerName, request.Model, logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, errors.New("provider returned non-SSE response for streaming request"), responseChan, logger) return } @@ -2339,7 +2265,7 @@ func HandleOpenAISpeechStreamRequest( } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) logger.Warn("Error reading stream: %v", readErr) - providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, schemas.SpeechStreamRequest, providerName, request.Model, logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, logger) } break } @@ -2351,11 +2277,6 @@ func HandleOpenAISpeechStreamRequest( var bifrostErr schemas.BifrostError if err := sonic.UnmarshalString(jsonData, &bifrostErr); err == nil { if bifrostErr.Error != nil && bifrostErr.Error.Message != "" { - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.SpeechStreamRequest, - } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, &bifrostErr, jsonBody, nil, sendBackRawRequest, sendBackRawResponse), responseChan, logger) return @@ -2381,11 +2302,8 @@ func HandleOpenAISpeechStreamRequest( chunkIndex++ response.ExtraFields = schemas.BifrostResponseExtraFields{ - RequestType: schemas.SpeechStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - ChunkIndex: chunkIndex, - Latency: time.Since(lastChunkTime).Milliseconds(), + ChunkIndex: chunkIndex, + Latency: time.Since(lastChunkTime).Milliseconds(), } lastChunkTime = time.Now() @@ -2446,7 +2364,7 @@ 
func HandleOpenAITranscriptionRequest( logger schemas.Logger, ) (*schemas.BifrostTranscriptionResponse, *schemas.BifrostError) { // Large payload passthrough: stream multipart body directly without parsing - if lpResult, lpErr, handled := handleOpenAILargePayloadPassthrough(ctx, client, url, key, extraHeaders, providerName, request.Model, schemas.TranscriptionRequest, logger); handled { + if lpResult, lpErr, handled := handleOpenAILargePayloadPassthrough(ctx, client, url, key, extraHeaders, providerName, logger); handled { if lpErr != nil { return nil, lpErr } @@ -2454,13 +2372,13 @@ func HandleOpenAITranscriptionRequest( if len(lpResult.ResponseBody) > 0 { response := &schemas.BifrostTranscriptionResponse{} if err := sonic.Unmarshal(lpResult.ResponseBody, response); err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err) } - response.ExtraFields = schemas.BifrostResponseExtraFields{Provider: providerName, ModelRequested: request.Model, RequestType: schemas.TranscriptionRequest, Latency: lpResult.Latency} + response.ExtraFields = schemas.BifrostResponseExtraFields{Latency: lpResult.Latency} return response, nil } return &schemas.BifrostTranscriptionResponse{ - ExtraFields: schemas.BifrostResponseExtraFields{Provider: providerName, ModelRequested: request.Model, RequestType: schemas.TranscriptionRequest, Latency: lpResult.Latency}, + ExtraFields: schemas.BifrostResponseExtraFields{Latency: lpResult.Latency}, }, nil } @@ -2489,7 +2407,7 @@ func HandleOpenAITranscriptionRequest( // Use centralized converter reqBody := ToOpenAITranscriptionRequest(request) if reqBody == nil { - return nil, providerUtils.NewBifrostOperationError("transcription input is not provided", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("transcription input is not provided", nil) } // Create multipart form @@ 
-2516,7 +2434,7 @@ func HandleOpenAITranscriptionRequest( if resp.StatusCode() != fasthttp.StatusOK { providerUtils.MaterializeStreamErrorBody(ctx, resp) logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - return nil, ParseOpenAIError(resp, schemas.TranscriptionRequest, providerName, request.Model) + return nil, ParseOpenAIError(resp) } responseBody, lpResult, finalErr := finalizeOpenAIResponse(ctx, resp, latency, providerName, logger) @@ -2526,7 +2444,7 @@ func HandleOpenAITranscriptionRequest( } if lpResult != nil { return &schemas.BifrostTranscriptionResponse{ - ExtraFields: schemas.BifrostResponseExtraFields{Provider: providerName, ModelRequested: request.Model, RequestType: schemas.TranscriptionRequest, Latency: lpResult.Latency}, + ExtraFields: schemas.BifrostResponseExtraFields{Latency: lpResult.Latency}, }, nil } @@ -2565,7 +2483,7 @@ func HandleOpenAITranscriptionRequest( }, } } - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err) } // TODO: add HandleProviderResponse here @@ -2573,7 +2491,7 @@ func HandleOpenAITranscriptionRequest( // Parse raw response for RawResponse field if sendBackRawResponse { if err := sonic.Unmarshal(copiedResponseBody, &rawResponse); err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRawResponseUnmarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRawResponseUnmarshal, err) } } } @@ -2583,9 +2501,6 @@ func HandleOpenAITranscriptionRequest( } response.ExtraFields = schemas.BifrostResponseExtraFields{ - RequestType: schemas.TranscriptionRequest, - Provider: providerName, - ModelRequested: request.Model, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerResponseHeaders, } @@ -2647,7 +2562,7 @@ func HandleOpenAITranscriptionStreamRequest( // Use centralized 
converter reqBody := ToOpenAITranscriptionRequest(request) if reqBody == nil { - return nil, providerUtils.NewBifrostOperationError("transcription input is not provided", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("transcription input is not provided", nil) } reqBody.Stream = schemas.Ptr(true) if postRequestConverter != nil { @@ -2708,9 +2623,9 @@ func HandleOpenAITranscriptionStreamRequest( } } if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) { - return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, providerName) + return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err) } - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err) } // Store provider response headers in context before status check so error responses also forward them @@ -2720,7 +2635,7 @@ func HandleOpenAITranscriptionStreamRequest( if resp.StatusCode() != fasthttp.StatusOK { defer providerUtils.ReleaseStreamingResponse(resp) providerUtils.MaterializeStreamErrorBody(ctx, resp) - return nil, ParseOpenAIError(resp, schemas.TranscriptionStreamRequest, providerName, request.Model) + return nil, ParseOpenAIError(resp) } // Large payload streaming passthrough — pipe raw upstream SSE to client @@ -2737,9 +2652,9 @@ func HandleOpenAITranscriptionStreamRequest( go func() { defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.TranscriptionStreamRequest, logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, logger) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.TranscriptionStreamRequest, logger) + 
providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, logger) } close(responseChan) }() @@ -2761,7 +2676,7 @@ func HandleOpenAITranscriptionStreamRequest( // on non-line-delimited data (e.g. provider returned JSON instead of SSE). if providerUtils.DrainNonSSEStreamResponse(resp) { ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendError(ctx, postHookRunner, errors.New("provider returned non-SSE response for streaming request"), responseChan, schemas.TranscriptionStreamRequest, providerName, request.Model, logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, errors.New("provider returned non-SSE response for streaming request"), responseChan, logger) return } @@ -2786,7 +2701,7 @@ func HandleOpenAITranscriptionStreamRequest( } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) logger.Warn("Error reading stream: %v", readErr) - providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, schemas.TranscriptionStreamRequest, providerName, request.Model, logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, logger) } break } @@ -2797,11 +2712,6 @@ func HandleOpenAITranscriptionStreamRequest( if customResponseHandler != nil { _, _, bifrostErr = customResponseHandler([]byte(jsonData), response, nil, false, false) if bifrostErr != nil { - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.TranscriptionStreamRequest, - } if sendBackRawResponse { bifrostErr.ExtraFields.RawResponse = jsonData } @@ -2816,13 +2726,9 @@ func HandleOpenAITranscriptionStreamRequest( var bifrostErrVal schemas.BifrostError if err := sonic.UnmarshalString(jsonData, &bifrostErrVal); err == nil { if bifrostErrVal.Error != nil && bifrostErrVal.Error.Message != "" { - bifrostErrVal.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: providerName, - ModelRequested: request.Model, 
- RequestType: schemas.TranscriptionStreamRequest, - } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, &bifrostErrVal, nil, nil, false, sendBackRawResponse), responseChan, logger) + respBody := append([]byte(nil), resp.Body()...) + providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, &bifrostErrVal, body.Bytes(), respBody, false, sendBackRawResponse), responseChan, logger) return } } @@ -2846,11 +2752,8 @@ func HandleOpenAITranscriptionStreamRequest( chunkIndex++ response.ExtraFields = schemas.BifrostResponseExtraFields{ - RequestType: schemas.TranscriptionStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - ChunkIndex: chunkIndex, - Latency: time.Since(lastChunkTime).Milliseconds(), + ChunkIndex: chunkIndex, + Latency: time.Since(lastChunkTime).Milliseconds(), } lastChunkTime = time.Now() @@ -2940,20 +2843,20 @@ func HandleOpenAIImageGenerationRequest( } // Large payload passthrough: stream body directly without JSON marshaling - if lpResult, lpErr, handled := handleOpenAILargePayloadPassthrough(ctx, client, url, key, extraHeaders, providerName, request.Model, schemas.ImageGenerationRequest, logger); handled { + if lpResult, lpErr, handled := handleOpenAILargePayloadPassthrough(ctx, client, url, key, extraHeaders, providerName, logger); handled { if lpErr != nil { return nil, lpErr } if len(lpResult.ResponseBody) > 0 { response := &schemas.BifrostImageGenerationResponse{} if err := sonic.Unmarshal(lpResult.ResponseBody, response); err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err) } - response.ExtraFields = schemas.BifrostResponseExtraFields{Provider: providerName, ModelRequested: request.Model, RequestType: 
schemas.ImageGenerationRequest, Latency: lpResult.Latency} + response.ExtraFields = schemas.BifrostResponseExtraFields{Latency: lpResult.Latency} return response, nil } return &schemas.BifrostImageGenerationResponse{ - ExtraFields: schemas.BifrostResponseExtraFields{Provider: providerName, ModelRequested: request.Model, RequestType: schemas.ImageGenerationRequest, Latency: lpResult.Latency}, + ExtraFields: schemas.BifrostResponseExtraFields{Latency: lpResult.Latency}, }, nil } @@ -2963,8 +2866,7 @@ func HandleOpenAIImageGenerationRequest( request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToOpenAIImageGenerationRequest(request), nil - }, - providerName) + }) if bifrostErr != nil { return nil, bifrostErr } @@ -2985,7 +2887,7 @@ func HandleOpenAIImageGenerationRequest( if resp.StatusCode() != fasthttp.StatusOK { providerUtils.MaterializeStreamErrorBody(ctx, resp) logger.Debug(fmt.Sprintf("error from %s provider: %s", providerName, string(resp.Body()))) - return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp, schemas.ImageGenerationRequest, providerName, request.Model), jsonData, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp), jsonData, nil, sendBackRawRequest, sendBackRawResponse) } body, lpResult, finalErr := finalizeOpenAIResponse(ctx, resp, latency, providerName, logger) @@ -2995,7 +2897,7 @@ func HandleOpenAIImageGenerationRequest( } if lpResult != nil { return &schemas.BifrostImageGenerationResponse{ - ExtraFields: schemas.BifrostResponseExtraFields{Provider: providerName, ModelRequested: request.Model, RequestType: schemas.ImageGenerationRequest, Latency: lpResult.Latency}, + ExtraFields: schemas.BifrostResponseExtraFields{Latency: lpResult.Latency}, }, nil } @@ -3007,9 +2909,6 @@ func HandleOpenAIImageGenerationRequest( return nil, bifrostErr } - response.ExtraFields.Provider = providerName - response.ExtraFields.ModelRequested = request.Model - 
response.ExtraFields.RequestType = schemas.ImageGenerationRequest response.ExtraFields.Latency = latency.Milliseconds() response.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -3036,7 +2935,7 @@ func (provider *OpenAIProvider) ImageGenerationStream( request *schemas.BifrostImageGenerationRequest, ) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { if request == nil { - return nil, providerUtils.NewBifrostOperationError("invalid request: nil", nil, provider.GetProviderKey()) + return nil, providerUtils.NewBifrostOperationError("invalid request: nil", nil) } // Check if image generation stream is allowed for this provider @@ -3110,8 +3009,7 @@ func HandleOpenAIImageGenerationStreaming( } } return reqBody, nil - }, - providerName) + }) if bifrostErr != nil { return nil, bifrostErr } @@ -3156,9 +3054,9 @@ func HandleOpenAIImageGenerationStreaming( } } if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) { - return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, providerName) + return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err) } - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err) } // Store provider response headers in context before status check so error responses also forward them @@ -3168,7 +3066,7 @@ func HandleOpenAIImageGenerationStreaming( if resp.StatusCode() != fasthttp.StatusOK { defer providerUtils.ReleaseStreamingResponse(resp) providerUtils.MaterializeStreamErrorBody(ctx, resp) - return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp, schemas.ImageGenerationStreamRequest, providerName, request.Model), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } 
// Large payload streaming passthrough — pipe raw upstream SSE to client @@ -3185,9 +3083,9 @@ func HandleOpenAIImageGenerationStreaming( go func() { defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.ImageGenerationStreamRequest, logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, logger) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.ImageGenerationStreamRequest, logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, logger) } close(responseChan) }() @@ -3209,7 +3107,7 @@ func HandleOpenAIImageGenerationStreaming( // on non-line-delimited data (e.g. provider returned JSON instead of SSE). if providerUtils.DrainNonSSEStreamResponse(resp) { ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendError(ctx, postHookRunner, errors.New("provider returned non-SSE response for streaming request"), responseChan, schemas.ImageGenerationStreamRequest, providerName, request.Model, logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, errors.New("provider returned non-SSE response for streaming request"), responseChan, logger) return } @@ -3236,7 +3134,7 @@ func HandleOpenAIImageGenerationStreaming( if readErr != nil { if readErr != io.EOF { logger.Warn("Error reading stream: %v", readErr) - providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, schemas.ImageGenerationStreamRequest, providerName, request.Model, logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, logger) } break } @@ -3248,11 +3146,6 @@ func HandleOpenAIImageGenerationStreaming( var bifrostErr schemas.BifrostError if err := sonic.UnmarshalString(jsonData, &bifrostErr); err == nil { if bifrostErr.Error != nil && 
bifrostErr.Error.Message != "" { - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.ImageGenerationStreamRequest, - } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, &bifrostErr, jsonBody, nil, sendBackRawRequest, sendBackRawResponse), responseChan, logger) return @@ -3272,11 +3165,6 @@ func HandleOpenAIImageGenerationStreaming( bifrostErr := &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{}, - ExtraFields: schemas.BifrostErrorExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.ImageGenerationStreamRequest, - }, } // Guard access to response.Error fields if response.Error != nil { @@ -3377,11 +3265,8 @@ func HandleOpenAIImageGenerationStreaming( Background: response.Background, OutputFormat: response.OutputFormat, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ImageGenerationStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - ChunkIndex: chunkIndex, // Chunk order within this image - Latency: time.Since(lastChunkTime).Milliseconds(), + ChunkIndex: chunkIndex, // Chunk order within this image + Latency: time.Since(lastChunkTime).Milliseconds(), }, } @@ -3482,7 +3367,7 @@ func (provider *OpenAIProvider) VideoRetrieve(ctx *schemas.BifrostContext, key s providerName := provider.GetProviderKey() if request.ID == "" { - return nil, providerUtils.NewBifrostOperationError("video_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("video_id is required", nil) } videoID := providerUtils.StripVideoIDProviderSuffix(request.ID, providerName) @@ -3511,7 +3396,7 @@ func (provider *OpenAIProvider) VideoDownload(ctx *schemas.BifrostContext, key s providerName := provider.GetProviderKey() if request.ID == "" { - return nil, 
providerUtils.NewBifrostOperationError("video_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("video_id is required", nil) } videoID := providerUtils.StripVideoIDProviderSuffix(request.ID, providerName) @@ -3552,12 +3437,12 @@ func (provider *OpenAIProvider) VideoDownload(ctx *schemas.BifrostContext, key s // Handle error response if resp.StatusCode() != fasthttp.StatusOK { provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - return nil, ParseOpenAIError(resp, schemas.VideoDownloadRequest, providerName, "") + return nil, ParseOpenAIError(resp) } body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } // Get content type from response @@ -3575,8 +3460,6 @@ func (provider *OpenAIProvider) VideoDownload(ctx *schemas.BifrostContext, key s Content: content, ContentType: contentType, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.VideoDownloadRequest, - Provider: providerName, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerResponseHeaders, }, @@ -3592,7 +3475,7 @@ func (provider *OpenAIProvider) VideoDelete(ctx *schemas.BifrostContext, key sch providerName := provider.GetProviderKey() if request.ID == "" { - return nil, providerUtils.NewBifrostOperationError("video_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("video_id is required", nil) } videoID := providerUtils.StripVideoIDProviderSuffix(request.ID, providerName) @@ -3662,10 +3545,10 @@ func HandleOpenAIVideoGenerationRequest( // Use centralized converter reqBody, err := ToOpenAIVideoGenerationRequest(request) if err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to convert video generation request to openai format", err, 
providerName) + return nil, providerUtils.NewBifrostOperationError("failed to convert video generation request to openai format", err) } if reqBody == nil { - return nil, providerUtils.NewBifrostOperationError("video generation input is not provided", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("video generation input is not provided", nil) } // Create multipart form @@ -3691,12 +3574,12 @@ func HandleOpenAIVideoGenerationRequest( // Handle error response if resp.StatusCode() != fasthttp.StatusOK { logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - return nil, ParseOpenAIError(resp, schemas.VideoGenerationRequest, providerName, request.Model) + return nil, ParseOpenAIError(resp) } responseBody, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } // Check for empty response @@ -3722,9 +3605,6 @@ func HandleOpenAIVideoGenerationRequest( } response.ExtraFields = schemas.BifrostResponseExtraFields{ - RequestType: schemas.VideoGenerationRequest, - Provider: providerName, - ModelRequested: request.Model, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerResponseHeaders, } @@ -3789,12 +3669,12 @@ func HandleOpenAIVideoRetrieveRequest( if resp.StatusCode() != fasthttp.StatusOK { logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - return nil, ParseOpenAIError(resp, schemas.VideoRetrieveRequest, providerName, "") + return nil, ParseOpenAIError(resp) } body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } response := 
&schemas.BifrostVideoGenerationResponse{} @@ -3836,8 +3716,6 @@ func HandleOpenAIVideoRetrieveRequest( } response.ExtraFields = schemas.BifrostResponseExtraFields{ - RequestType: schemas.VideoRetrieveRequest, - Provider: providerName, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerResponseHeaders, } @@ -3889,12 +3767,12 @@ func HandleOpenAIVideoDeleteRequest( // Handle error response if resp.StatusCode() != fasthttp.StatusOK { logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - return nil, ParseOpenAIError(resp, schemas.VideoDeleteRequest, providerName, "") + return nil, ParseOpenAIError(resp) } body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } // Parse OpenAI's video response @@ -3908,8 +3786,6 @@ func HandleOpenAIVideoDeleteRequest( } response.ExtraFields = schemas.BifrostResponseExtraFields{ - RequestType: schemas.VideoDeleteRequest, - Provider: providerName, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerResponseHeaders, } @@ -3982,12 +3858,12 @@ func HandleOpenAIVideoListRequest( // Handle error response if resp.StatusCode() != fasthttp.StatusOK { logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - return nil, ParseOpenAIError(resp, schemas.VideoListRequest, providerName, "") + return nil, ParseOpenAIError(resp) } body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } response := &schemas.BifrostVideoListResponse{} @@ -4014,8 +3890,6 @@ func HandleOpenAIVideoListRequest( } response.ExtraFields = schemas.BifrostResponseExtraFields{ - 
RequestType: schemas.VideoListRequest, - Provider: providerName, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerResponseHeaders, } @@ -4085,20 +3959,20 @@ func HandleOpenAICountTokensRequest( } // Large payload passthrough: stream body directly without JSON marshaling - if lpResult, lpErr, handled := handleOpenAILargePayloadPassthrough(ctx, client, url, key, extraHeaders, providerName, request.Model, schemas.CountTokensRequest, logger); handled { + if lpResult, lpErr, handled := handleOpenAILargePayloadPassthrough(ctx, client, url, key, extraHeaders, providerName, logger); handled { if lpErr != nil { return nil, lpErr } if len(lpResult.ResponseBody) > 0 { response := &schemas.BifrostCountTokensResponse{} if err := sonic.Unmarshal(lpResult.ResponseBody, response); err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err) } - response.ExtraFields = schemas.BifrostResponseExtraFields{Provider: providerName, ModelRequested: request.Model, RequestType: schemas.CountTokensRequest, Latency: lpResult.Latency} + response.ExtraFields = schemas.BifrostResponseExtraFields{Latency: lpResult.Latency} return response, nil } return &schemas.BifrostCountTokensResponse{ - ExtraFields: schemas.BifrostResponseExtraFields{Provider: providerName, ModelRequested: request.Model, RequestType: schemas.CountTokensRequest, Latency: lpResult.Latency}, + ExtraFields: schemas.BifrostResponseExtraFields{Latency: lpResult.Latency}, }, nil } @@ -4107,9 +3981,7 @@ func HandleOpenAICountTokensRequest( request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToOpenAIResponsesRequest(request), nil - }, - providerName, - ) + }) if bifrostErr != nil { return nil, bifrostErr } @@ -4130,7 +4002,7 @@ func HandleOpenAICountTokensRequest( if resp.StatusCode() != fasthttp.StatusOK { 
providerUtils.MaterializeStreamErrorBody(ctx, resp) logger.Debug(fmt.Sprintf("error from %s provider: %s", providerName, string(resp.Body()))) - return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp, schemas.CountTokensRequest, providerName, request.Model), jsonData, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp), jsonData, nil, sendBackRawRequest, sendBackRawResponse) } body, lpResult, finalErr := finalizeOpenAIResponse(ctx, resp, latency, providerName, logger) @@ -4140,7 +4012,7 @@ func HandleOpenAICountTokensRequest( } if lpResult != nil { return &schemas.BifrostCountTokensResponse{ - ExtraFields: schemas.BifrostResponseExtraFields{Provider: providerName, ModelRequested: request.Model, RequestType: schemas.CountTokensRequest, Latency: lpResult.Latency}, + ExtraFields: schemas.BifrostResponseExtraFields{Latency: lpResult.Latency}, }, nil } @@ -4153,9 +4025,6 @@ func HandleOpenAICountTokensRequest( } response.Model = request.Model - response.ExtraFields.Provider = providerName - response.ExtraFields.RequestType = schemas.CountTokensRequest - response.ExtraFields.ModelRequested = request.Model response.ExtraFields.Latency = latency.Milliseconds() response.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -4203,26 +4072,26 @@ func HandleOpenAIImageEditRequest( logger schemas.Logger, ) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) { // Large payload passthrough: stream multipart body directly without parsing - if lpResult, lpErr, handled := handleOpenAILargePayloadPassthrough(ctx, client, url, key, extraHeaders, providerName, request.Model, schemas.ImageEditRequest, logger); handled { + if lpResult, lpErr, handled := handleOpenAILargePayloadPassthrough(ctx, client, url, key, extraHeaders, providerName, logger); handled { if lpErr != nil { return nil, lpErr } if len(lpResult.ResponseBody) > 0 { response := &schemas.BifrostImageGenerationResponse{} if err := 
sonic.Unmarshal(lpResult.ResponseBody, response); err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err) } - response.ExtraFields = schemas.BifrostResponseExtraFields{Provider: providerName, ModelRequested: request.Model, RequestType: schemas.ImageEditRequest, Latency: lpResult.Latency} + response.ExtraFields = schemas.BifrostResponseExtraFields{Latency: lpResult.Latency} return response, nil } return &schemas.BifrostImageGenerationResponse{ - ExtraFields: schemas.BifrostResponseExtraFields{Provider: providerName, ModelRequested: request.Model, RequestType: schemas.ImageEditRequest, Latency: lpResult.Latency}, + ExtraFields: schemas.BifrostResponseExtraFields{Latency: lpResult.Latency}, }, nil } openaiReq := ToOpenAIImageEditRequest(request) if openaiReq == nil { - return nil, providerUtils.NewBifrostOperationError("failed to convert request to OpenAI format", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to convert request to OpenAI format", nil) } // Create request @@ -4269,7 +4138,7 @@ func HandleOpenAIImageEditRequest( if resp.StatusCode() != fasthttp.StatusOK { providerUtils.MaterializeStreamErrorBody(ctx, resp) - return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp, schemas.ImageEditRequest, providerName, request.Model), bodyData, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp), bodyData, nil, sendBackRawRequest, sendBackRawResponse) } bodyBytes, lpResult, finalErr := finalizeOpenAIResponse(ctx, resp, latency, providerName, logger) @@ -4279,7 +4148,7 @@ func HandleOpenAIImageEditRequest( } if lpResult != nil { return &schemas.BifrostImageGenerationResponse{ - ExtraFields: schemas.BifrostResponseExtraFields{Provider: providerName, ModelRequested: request.Model, RequestType: 
schemas.ImageEditRequest, Latency: lpResult.Latency}, + ExtraFields: schemas.BifrostResponseExtraFields{Latency: lpResult.Latency}, }, nil } @@ -4288,9 +4157,6 @@ func HandleOpenAIImageEditRequest( if bifrostErr != nil { return nil, bifrostErr } - response.ExtraFields.Provider = providerName - response.ExtraFields.ModelRequested = request.Model - response.ExtraFields.RequestType = schemas.ImageEditRequest response.ExtraFields.Latency = latency.Milliseconds() response.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -4354,7 +4220,7 @@ func HandleOpenAIImageEditStreamRequest( ) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { reqBody := ToOpenAIImageEditRequest(request) if reqBody == nil { - return nil, providerUtils.NewBifrostOperationError("image edit input is not provided", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("image edit input is not provided", nil) } reqBody.Stream = schemas.Ptr(true) @@ -4414,9 +4280,9 @@ func HandleOpenAIImageEditStreamRequest( } } if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) { - return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, providerName) + return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err) } - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err) } // Store provider response headers in context before status check so error responses also forward them ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerUtils.ExtractProviderResponseHeaders(resp)) @@ -4425,7 +4291,7 @@ func HandleOpenAIImageEditStreamRequest( if resp.StatusCode() != fasthttp.StatusOK { defer providerUtils.ReleaseStreamingResponse(resp) providerUtils.MaterializeStreamErrorBody(ctx, resp) - return nil, providerUtils.EnrichError(ctx, 
ParseOpenAIError(resp, schemas.ImageEditStreamRequest, providerName, request.Model), body.Bytes(), nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp), body.Bytes(), nil, sendBackRawRequest, sendBackRawResponse) } // Large payload streaming passthrough — pipe raw upstream SSE to client @@ -4442,9 +4308,9 @@ func HandleOpenAIImageEditStreamRequest( go func() { defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.ImageEditStreamRequest, logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, logger) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.ImageEditStreamRequest, logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, logger) } close(responseChan) }() @@ -4466,7 +4332,7 @@ func HandleOpenAIImageEditStreamRequest( // on non-line-delimited data (e.g. provider returned JSON instead of SSE). 
if providerUtils.DrainNonSSEStreamResponse(resp) { ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendError(ctx, postHookRunner, errors.New("provider returned non-SSE response for streaming request"), responseChan, schemas.ImageEditStreamRequest, providerName, request.Model, logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, errors.New("provider returned non-SSE response for streaming request"), responseChan, logger) return } @@ -4493,7 +4359,7 @@ func HandleOpenAIImageEditStreamRequest( if readErr != nil { if readErr != io.EOF { logger.Warn(fmt.Sprintf("Error reading stream: %v", readErr)) - providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, schemas.ImageEditStreamRequest, providerName, request.Model, logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, logger) } break } @@ -4505,11 +4371,6 @@ func HandleOpenAIImageEditStreamRequest( var bifrostErr schemas.BifrostError if err := sonic.UnmarshalString(jsonData, &bifrostErr); err == nil { if bifrostErr.Error != nil && bifrostErr.Error.Message != "" { - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.ImageEditStreamRequest, - } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, &bifrostErr, body.Bytes(), nil, sendBackRawRequest, sendBackRawResponse), responseChan, logger) return @@ -4529,11 +4390,6 @@ func HandleOpenAIImageEditStreamRequest( bifrostErr := &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{}, - ExtraFields: schemas.BifrostErrorExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.ImageEditStreamRequest, - }, } // Guard access to response.Error fields if response.Error != nil { @@ -4634,11 +4490,8 @@ func 
HandleOpenAIImageEditStreamRequest( Background: response.Background, OutputFormat: response.OutputFormat, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ImageEditStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - ChunkIndex: chunkIndex, // Chunk order within this image - Latency: time.Since(lastChunkTime).Milliseconds(), + ChunkIndex: chunkIndex, // Chunk order within this image + Latency: time.Since(lastChunkTime).Milliseconds(), }, } @@ -4738,26 +4591,26 @@ func HandleOpenAIImageVariationRequest( logger schemas.Logger, ) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) { // Large payload passthrough: stream multipart body directly without parsing - if lpResult, lpErr, handled := handleOpenAILargePayloadPassthrough(ctx, client, url, key, extraHeaders, providerName, request.Model, schemas.ImageVariationRequest, logger); handled { + if lpResult, lpErr, handled := handleOpenAILargePayloadPassthrough(ctx, client, url, key, extraHeaders, providerName, logger); handled { if lpErr != nil { return nil, lpErr } if len(lpResult.ResponseBody) > 0 { response := &schemas.BifrostImageGenerationResponse{} if err := sonic.Unmarshal(lpResult.ResponseBody, response); err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err) } - response.ExtraFields = schemas.BifrostResponseExtraFields{Provider: providerName, ModelRequested: request.Model, RequestType: schemas.ImageVariationRequest, Latency: lpResult.Latency} + response.ExtraFields = schemas.BifrostResponseExtraFields{Latency: lpResult.Latency} return response, nil } return &schemas.BifrostImageGenerationResponse{ - ExtraFields: schemas.BifrostResponseExtraFields{Provider: providerName, ModelRequested: request.Model, RequestType: schemas.ImageVariationRequest, Latency: lpResult.Latency}, + ExtraFields: 
schemas.BifrostResponseExtraFields{Latency: lpResult.Latency}, }, nil } openaiReq := ToOpenAIImageVariationRequest(request) if openaiReq == nil { - return nil, providerUtils.NewBifrostOperationError("failed to convert request to OpenAI format", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to convert request to OpenAI format", nil) } // Create request @@ -4803,7 +4656,7 @@ func HandleOpenAIImageVariationRequest( if resp.StatusCode() != fasthttp.StatusOK { providerUtils.MaterializeStreamErrorBody(ctx, resp) - return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp, schemas.ImageVariationRequest, providerName, request.Model), bodyData, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp), bodyData, nil, sendBackRawRequest, sendBackRawResponse) } bodyBytes, lpResult, finalErr := finalizeOpenAIResponse(ctx, resp, latency, providerName, logger) @@ -4813,7 +4666,7 @@ func HandleOpenAIImageVariationRequest( } if lpResult != nil { return &schemas.BifrostImageGenerationResponse{ - ExtraFields: schemas.BifrostResponseExtraFields{Provider: providerName, ModelRequested: request.Model, RequestType: schemas.ImageVariationRequest, Latency: lpResult.Latency}, + ExtraFields: schemas.BifrostResponseExtraFields{Latency: lpResult.Latency}, }, nil } @@ -4822,9 +4675,6 @@ func HandleOpenAIImageVariationRequest( if bifrostErr != nil { return nil, bifrostErr } - response.ExtraFields.Provider = providerName - response.ExtraFields.ModelRequested = request.Model - response.ExtraFields.RequestType = schemas.ImageVariationRequest response.ExtraFields.Latency = latency.Milliseconds() response.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -4841,14 +4691,12 @@ func (provider *OpenAIProvider) FileUpload(ctx *schemas.BifrostContext, key sche return nil, err } - providerName := provider.GetProviderKey() - if len(request.File) == 0 { - return nil, 
providerUtils.NewBifrostOperationError("file content is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("file content is required", nil) } if request.Purpose == "" { - return nil, providerUtils.NewBifrostOperationError("purpose is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("purpose is required", nil) } // Create multipart form data @@ -4857,16 +4705,16 @@ func (provider *OpenAIProvider) FileUpload(ctx *schemas.BifrostContext, key sche // Add purpose field if err := writer.WriteField("purpose", string(request.Purpose)); err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to write purpose field", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to write purpose field", err) } // Add expires_after fields if provided if request.ExpiresAfter != nil { if err := writer.WriteField("expires_after[anchor]", request.ExpiresAfter.Anchor); err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to write expires_after[anchor] field", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to write expires_after[anchor] field", err) } if err := writer.WriteField("expires_after[seconds]", fmt.Sprintf("%d", request.ExpiresAfter.Seconds)); err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to write expires_after[seconds] field", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to write expires_after[seconds] field", err) } } @@ -4877,14 +4725,14 @@ func (provider *OpenAIProvider) FileUpload(ctx *schemas.BifrostContext, key sche } part, err := writer.CreateFormFile("file", filename) if err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to create form file", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to create form file", err) } if _, err := part.Write(request.File); err != nil { - return nil, 
providerUtils.NewBifrostOperationError("failed to write file content", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to write file content", err) } if err := writer.Close(); err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to close multipart writer", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to close multipart writer", err) } // Create request @@ -4914,13 +4762,13 @@ func (provider *OpenAIProvider) FileUpload(ctx *schemas.BifrostContext, key sche // Handle error response if resp.StatusCode() != fasthttp.StatusOK { - provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - return nil, ParseOpenAIError(resp, schemas.FileUploadRequest, providerName, "") + provider.logger.Debug("error from %s provider: %s", provider.GetProviderKey(), string(resp.Body())) + return nil, ParseOpenAIError(resp) } body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } var openAIResp OpenAIFileResponse @@ -4931,7 +4779,7 @@ func (provider *OpenAIProvider) FileUpload(ctx *schemas.BifrostContext, key sche return nil, bifrostErr } - fileResponse := openAIResp.ToBifrostFileUploadResponse(providerName, latency, sendBackRawRequest, sendBackRawResponse, rawRequest, rawResponse) + fileResponse := openAIResp.ToBifrostFileUploadResponse(latency, sendBackRawRequest, sendBackRawResponse, rawRequest, rawResponse) fileResponse.ExtraFields.ProviderResponseHeaders = providerUtils.ExtractProviderResponseHeaders(resp) return fileResponse, nil } @@ -4950,7 +4798,7 @@ func (provider *OpenAIProvider) FileList(ctx *schemas.BifrostContext, keys []sch // Initialize serial pagination helper helper, err := providerUtils.NewSerialListHelper(keys, request.After, provider.logger) if 
err != nil { - return nil, providerUtils.NewBifrostOperationError("invalid pagination cursor", err, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid pagination cursor", err) } // Get current key to query @@ -4961,10 +4809,6 @@ func (provider *OpenAIProvider) FileList(ctx *schemas.BifrostContext, keys []sch Object: "list", Data: []schemas.FileObject{}, HasMore: false, - ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileListRequest, - Provider: providerName, - }, }, nil } @@ -5014,12 +4858,12 @@ func (provider *OpenAIProvider) FileList(ctx *schemas.BifrostContext, keys []sch // Handle error response if resp.StatusCode() != fasthttp.StatusOK { provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - return nil, ParseOpenAIError(resp, schemas.FileListRequest, providerName, "") + return nil, ParseOpenAIError(resp) } body, decodeErr := providerUtils.CheckAndDecodeBody(resp) if decodeErr != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, decodeErr, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, decodeErr) } var openAIResp OpenAIFileListResponse @@ -5055,8 +4899,6 @@ func (provider *OpenAIProvider) FileList(ctx *schemas.BifrostContext, keys []sch Data: files, HasMore: hasMore, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileListRequest, - Provider: providerName, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerUtils.ExtractProviderResponseHeaders(resp), }, @@ -5077,7 +4919,7 @@ func (provider *OpenAIProvider) FileRetrieve(ctx *schemas.BifrostContext, keys [ providerName := provider.GetProviderKey() if request.FileID == "" { - return nil, providerUtils.NewBifrostOperationError("file_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("file_id is required", nil) } sendBackRawRequest := 
providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) @@ -5112,7 +4954,7 @@ func (provider *OpenAIProvider) FileRetrieve(ctx *schemas.BifrostContext, keys [ // Handle error response if resp.StatusCode() != fasthttp.StatusOK { provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - lastErr = ParseOpenAIError(resp, schemas.FileRetrieveRequest, providerName, "") + lastErr = ParseOpenAIError(resp) fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) continue @@ -5122,7 +4964,7 @@ func (provider *OpenAIProvider) FileRetrieve(ctx *schemas.BifrostContext, keys [ if err != nil { fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) - lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) continue } @@ -5153,7 +4995,7 @@ func (provider *OpenAIProvider) FileDelete(ctx *schemas.BifrostContext, keys []s providerName := provider.GetProviderKey() if request.FileID == "" { - return nil, providerUtils.NewBifrostOperationError("file_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("file_id is required", nil) } sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) @@ -5188,7 +5030,7 @@ func (provider *OpenAIProvider) FileDelete(ctx *schemas.BifrostContext, keys []s // Handle error response if resp.StatusCode() != fasthttp.StatusOK { provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - lastErr = ParseOpenAIError(resp, schemas.FileDeleteRequest, providerName, "") + lastErr = ParseOpenAIError(resp) fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) continue @@ -5198,7 +5040,7 @@ func (provider *OpenAIProvider) FileDelete(ctx *schemas.BifrostContext, keys []s if err != nil { fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) - lastErr = 
providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) continue } @@ -5219,9 +5061,7 @@ func (provider *OpenAIProvider) FileDelete(ctx *schemas.BifrostContext, keys []s Object: openAIResp.Object, Deleted: openAIResp.Deleted, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileDeleteRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } @@ -5248,7 +5088,7 @@ func (provider *OpenAIProvider) FileContent(ctx *schemas.BifrostContext, keys [] providerName := provider.GetProviderKey() if request.FileID == "" { - return nil, providerUtils.NewBifrostOperationError("file_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("file_id is required", nil) } var lastErr *schemas.BifrostError @@ -5279,7 +5119,7 @@ func (provider *OpenAIProvider) FileContent(ctx *schemas.BifrostContext, keys [] // Handle error response if resp.StatusCode() != fasthttp.StatusOK { provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - lastErr = ParseOpenAIError(resp, schemas.FileContentRequest, providerName, "") + lastErr = ParseOpenAIError(resp) fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) continue @@ -5289,7 +5129,7 @@ func (provider *OpenAIProvider) FileContent(ctx *schemas.BifrostContext, keys [] if err != nil { fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) - lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) continue } @@ -5308,9 +5148,7 @@ func (provider *OpenAIProvider) FileContent(ctx *schemas.BifrostContext, keys [] Content: content, ContentType: contentType, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileContentRequest, - 
Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, }, nil } @@ -5327,10 +5165,10 @@ func (provider *OpenAIProvider) VideoRemix(ctx *schemas.BifrostContext, key sche providerName := provider.GetProviderKey() if request.ID == "" { - return nil, providerUtils.NewBifrostOperationError("video_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("video_id is required", nil) } if request.Input == nil || request.Input.Prompt == "" { - return nil, providerUtils.NewBifrostOperationError("prompt is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("prompt is required", nil) } jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody( @@ -5338,8 +5176,7 @@ func (provider *OpenAIProvider) VideoRemix(ctx *schemas.BifrostContext, key sche request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToOpenAIVideoRemixRequest(request) - }, - providerName) + }) if bifrostErr != nil { return nil, bifrostErr } @@ -5377,12 +5214,12 @@ func (provider *OpenAIProvider) VideoRemix(ctx *schemas.BifrostContext, key sche // Handle error response if resp.StatusCode() != fasthttp.StatusOK { provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - return nil, ParseOpenAIError(resp, schemas.VideoRemixRequest, providerName, "") + return nil, ParseOpenAIError(resp) } body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err), jsonData, nil, sendBackRawRequest, sendBackRawResponse) } // Parse OpenAI's video response @@ -5400,9 +5237,7 @@ func (provider *OpenAIProvider) VideoRemix(ctx *schemas.BifrostContext, key sche } response.ExtraFields = schemas.BifrostResponseExtraFields{ - RequestType: 
schemas.VideoRemixRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), } if sendBackRawResponse { @@ -5421,8 +5256,6 @@ func (provider *OpenAIProvider) BatchCreate(ctx *schemas.BifrostContext, key sch return nil, err } - providerName := provider.GetProviderKey() - inputFileID := request.InputFileID // If no file_id provided but inline requests are available, upload them first @@ -5430,7 +5263,7 @@ func (provider *OpenAIProvider) BatchCreate(ctx *schemas.BifrostContext, key sch // Convert inline requests to JSONL format jsonlData, err := ConvertRequestsToJSONL(request.Requests) if err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to convert requests to JSONL", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to convert requests to JSONL", err) } // Upload the file with purpose "batch" @@ -5449,12 +5282,12 @@ func (provider *OpenAIProvider) BatchCreate(ctx *schemas.BifrostContext, key sch // Validate that we have a file ID (either provided or uploaded) if inputFileID == "" { - return nil, providerUtils.NewBifrostOperationError("either input_file_id or requests array is required for OpenAI batch API", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("either input_file_id or requests array is required for OpenAI batch API", nil) } // Validate that we have an endpoint if request.Endpoint == "" { - return nil, providerUtils.NewBifrostOperationError("endpoint is required for OpenAI batch API", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("endpoint is required for OpenAI batch API", nil) } // Create request @@ -5489,7 +5322,7 @@ func (provider *OpenAIProvider) BatchCreate(ctx *schemas.BifrostContext, key sch jsonData, err := providerUtils.MarshalSorted(openAIReq) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, 
providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } req.SetBody(jsonData) @@ -5505,12 +5338,12 @@ func (provider *OpenAIProvider) BatchCreate(ctx *schemas.BifrostContext, key sch // Handle error response if resp.StatusCode() != fasthttp.StatusOK { - return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp, schemas.BatchCreateRequest, providerName, ""), jsonData, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp), jsonData, nil, sendBackRawRequest, sendBackRawResponse) } body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName), jsonData, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err), jsonData, nil, sendBackRawRequest, sendBackRawResponse) } var openAIResp OpenAIBatchResponse @@ -5519,7 +5352,7 @@ func (provider *OpenAIProvider) BatchCreate(ctx *schemas.BifrostContext, key sch return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonData, body, sendBackRawRequest, sendBackRawResponse) } - return openAIResp.ToBifrostBatchCreateResponse(providerName, latency, sendBackRawRequest, sendBackRawResponse, rawRequest, rawResponse), nil + return openAIResp.ToBifrostBatchCreateResponse(latency, sendBackRawRequest, sendBackRawResponse, rawRequest, rawResponse), nil } // BatchList lists batch jobs using serial pagination across keys. 
@@ -5529,14 +5362,13 @@ func (provider *OpenAIProvider) BatchList(ctx *schemas.BifrostContext, keys []sc return nil, err } - providerName := provider.GetProviderKey() sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) sendBackRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) // Initialize serial pagination helper helper, err := providerUtils.NewSerialListHelper(keys, request.After, provider.logger) if err != nil { - return nil, providerUtils.NewBifrostOperationError("invalid pagination cursor", err, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid pagination cursor", err) } // Get current key to query @@ -5547,10 +5379,6 @@ func (provider *OpenAIProvider) BatchList(ctx *schemas.BifrostContext, keys []sc Object: "list", Data: []schemas.BifrostBatchRetrieveResponse{}, HasMore: false, - ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.BatchListRequest, - Provider: providerName, - }, }, nil } @@ -5594,12 +5422,12 @@ func (provider *OpenAIProvider) BatchList(ctx *schemas.BifrostContext, keys []sc // Handle error response if resp.StatusCode() != fasthttp.StatusOK { - return nil, ParseOpenAIError(resp, schemas.BatchListRequest, providerName, "") + return nil, ParseOpenAIError(resp) } body, decodeErr := providerUtils.CheckAndDecodeBody(resp) if decodeErr != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, decodeErr, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, decodeErr) } var openAIResp OpenAIBatchListResponse @@ -5612,7 +5440,7 @@ func (provider *OpenAIProvider) BatchList(ctx *schemas.BifrostContext, keys []sc batches := make([]schemas.BifrostBatchRetrieveResponse, 0, len(openAIResp.Data)) var lastBatchID string for _, batch := range openAIResp.Data { - batches = append(batches, *batch.ToBifrostBatchRetrieveResponse(providerName, latency, 
sendBackRawRequest, sendBackRawResponse, rawRequest, rawResponse)) + batches = append(batches, *batch.ToBifrostBatchRetrieveResponse(latency, sendBackRawRequest, sendBackRawResponse, rawRequest, rawResponse)) lastBatchID = batch.ID } @@ -5626,9 +5454,7 @@ func (provider *OpenAIProvider) BatchList(ctx *schemas.BifrostContext, keys []sc Data: batches, HasMore: hasMore, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.BatchListRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } if nextCursor != "" { @@ -5645,10 +5471,9 @@ func (provider *OpenAIProvider) BatchRetrieve(ctx *schemas.BifrostContext, keys } if request.BatchID == "" { - return nil, providerUtils.NewBifrostOperationError("batch_id is required", nil, request.Provider) + return nil, providerUtils.NewBifrostOperationError("batch_id is required", nil) } - providerName := provider.GetProviderKey() sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) sendBackRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) @@ -5680,7 +5505,7 @@ func (provider *OpenAIProvider) BatchRetrieve(ctx *schemas.BifrostContext, keys // Handle error response if resp.StatusCode() != fasthttp.StatusOK { - lastErr = ParseOpenAIError(resp, schemas.BatchRetrieveRequest, providerName, "") + lastErr = ParseOpenAIError(resp) fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) continue @@ -5690,7 +5515,7 @@ func (provider *OpenAIProvider) BatchRetrieve(ctx *schemas.BifrostContext, keys if err != nil { fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) - lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) continue } @@ -5706,8 +5531,7 @@ func (provider *OpenAIProvider) BatchRetrieve(ctx *schemas.BifrostContext, keys 
fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) - result := openAIResp.ToBifrostBatchRetrieveResponse(providerName, latency, sendBackRawRequest, sendBackRawResponse, rawRequest, rawResponse) - result.ExtraFields.RequestType = schemas.BatchRetrieveRequest + result := openAIResp.ToBifrostBatchRetrieveResponse(latency, sendBackRawRequest, sendBackRawResponse, rawRequest, rawResponse) return result, nil } @@ -5721,10 +5545,9 @@ func (provider *OpenAIProvider) BatchCancel(ctx *schemas.BifrostContext, keys [] } if request.BatchID == "" { - return nil, providerUtils.NewBifrostOperationError("batch_id is required", nil, schemas.OpenAI) + return nil, providerUtils.NewBifrostOperationError("batch_id is required", nil) } - providerName := provider.GetProviderKey() sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) sendBackRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) @@ -5756,7 +5579,7 @@ func (provider *OpenAIProvider) BatchCancel(ctx *schemas.BifrostContext, keys [] // Handle error response if resp.StatusCode() != fasthttp.StatusOK { - lastErr = ParseOpenAIError(resp, schemas.BatchCancelRequest, providerName, "") + lastErr = ParseOpenAIError(resp) fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) continue @@ -5766,7 +5589,7 @@ func (provider *OpenAIProvider) BatchCancel(ctx *schemas.BifrostContext, keys [] if err != nil { fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) - lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) continue } @@ -5789,9 +5612,7 @@ func (provider *OpenAIProvider) BatchCancel(ctx *schemas.BifrostContext, keys [] CancellingAt: openAIResp.CancellingAt, CancelledAt: openAIResp.CancelledAt, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.BatchCancelRequest, - Provider: 
providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } @@ -5831,11 +5652,9 @@ func (provider *OpenAIProvider) BatchResults(ctx *schemas.BifrostContext, keys [ } if request.BatchID == "" { - return nil, providerUtils.NewBifrostOperationError("batch_id is required", nil, schemas.OpenAI) + return nil, providerUtils.NewBifrostOperationError("batch_id is required", nil) } - providerName := provider.GetProviderKey() - // First, retrieve the batch to get the output_file_id (this already iterates over keys) batchResp, bifrostErr := provider.BatchRetrieve(ctx, keys, &schemas.BifrostBatchRetrieveRequest{ Provider: request.Provider, @@ -5846,7 +5665,7 @@ func (provider *OpenAIProvider) BatchResults(ctx *schemas.BifrostContext, keys [ } if batchResp.OutputFileID == nil || *batchResp.OutputFileID == "" { - return nil, providerUtils.NewBifrostOperationError("batch results not available: output_file_id is empty (batch may not be completed)", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("batch results not available: output_file_id is empty (batch may not be completed)", nil) } // Download the output file - try each key @@ -5876,7 +5695,7 @@ func (provider *OpenAIProvider) BatchResults(ctx *schemas.BifrostContext, keys [ // Handle error response if resp.StatusCode() != fasthttp.StatusOK { - lastErr = ParseOpenAIError(resp, schemas.BatchResultsRequest, providerName, "") + lastErr = ParseOpenAIError(resp) fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) continue @@ -5886,7 +5705,7 @@ func (provider *OpenAIProvider) BatchResults(ctx *schemas.BifrostContext, keys [ if err != nil { fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) - lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) continue } @@ -5910,9 +5729,7 @@ func (provider *OpenAIProvider) BatchResults(ctx 
*schemas.BifrostContext, keys [ BatchID: request.BatchID, Results: results, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.BatchResultsRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } @@ -5932,14 +5749,12 @@ func (provider *OpenAIProvider) ContainerCreate(ctx *schemas.BifrostContext, key return nil, err } - providerName := provider.GetProviderKey() - if request == nil { - return nil, providerUtils.NewBifrostOperationError("invalid request: nil", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid request: nil", nil) } if request.Name == "" { - return nil, providerUtils.NewBifrostOperationError("invalid request: name is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid request: name is required", nil) } // Build request body @@ -5975,7 +5790,7 @@ func (provider *OpenAIProvider) ContainerCreate(ctx *schemas.BifrostContext, key jsonBody, err := providerUtils.MarshalSorted(reqBody) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } // Create request @@ -6004,7 +5819,7 @@ func (provider *OpenAIProvider) ContainerCreate(ctx *schemas.BifrostContext, key // Handle error response if resp.StatusCode() != fasthttp.StatusOK && resp.StatusCode() != fasthttp.StatusCreated { - return nil, ParseOpenAIError(resp, schemas.ContainerCreateRequest, providerName, "") + return nil, ParseOpenAIError(resp) } // Parse response @@ -6038,9 +5853,7 @@ func (provider *OpenAIProvider) ContainerCreate(ctx *schemas.BifrostContext, key MemoryLimit: containerResp.MemoryLimit, Metadata: containerResp.Metadata, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: providerName, - RequestType: schemas.ContainerCreateRequest, - Latency: latency.Milliseconds(), + Latency: 
latency.Milliseconds(), }, } @@ -6057,16 +5870,14 @@ func (provider *OpenAIProvider) ContainerCreate(ctx *schemas.BifrostContext, key // ContainerList lists containers via OpenAI's API. // Uses SerialListHelper for multi-key pagination - exhausts all pages from one key before moving to next. func (provider *OpenAIProvider) ContainerList(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostContainerListRequest) (*schemas.BifrostContainerListResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - if request == nil { - return nil, providerUtils.NewBifrostOperationError("invalid request: nil", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid request: nil", nil) } if len(keys) == 0 { if provider.customProviderConfig != nil && provider.customProviderConfig.IsKeyLess { keys = []schemas.Key{{}} } else { - return nil, providerUtils.NewBifrostOperationError("provider config not found", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("provider config not found", nil) } } @@ -6080,7 +5891,7 @@ func (provider *OpenAIProvider) ContainerList(ctx *schemas.BifrostContext, keys // Initialize serial pagination helper for multi-key support helper, err := providerUtils.NewSerialListHelper(keys, request.After, provider.logger) if err != nil { - return nil, providerUtils.NewBifrostOperationError("invalid pagination cursor", err, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid pagination cursor", err) } // Get current key to query @@ -6091,10 +5902,6 @@ func (provider *OpenAIProvider) ContainerList(ctx *schemas.BifrostContext, keys Object: "list", Data: []schemas.ContainerObject{}, HasMore: false, - ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: providerName, - RequestType: schemas.ContainerListRequest, - }, }, nil } @@ -6141,7 +5948,7 @@ func (provider *OpenAIProvider) ContainerList(ctx *schemas.BifrostContext, keys // Handle error response if 
resp.StatusCode() != fasthttp.StatusOK { - return nil, ParseOpenAIError(resp, schemas.ContainerListRequest, providerName, "") + return nil, ParseOpenAIError(resp) } // Parse response @@ -6176,9 +5983,7 @@ func (provider *OpenAIProvider) ContainerList(ctx *schemas.BifrostContext, keys LastID: listResp.LastID, HasMore: hasMore, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: providerName, - RequestType: schemas.ContainerListRequest, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } @@ -6199,20 +6004,18 @@ func (provider *OpenAIProvider) ContainerList(ctx *schemas.BifrostContext, keys // ContainerRetrieve retrieves a specific container via OpenAI's API. func (provider *OpenAIProvider) ContainerRetrieve(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostContainerRetrieveRequest) (*schemas.BifrostContainerRetrieveResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - if request == nil { - return nil, providerUtils.NewBifrostOperationError("invalid request: nil", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid request: nil", nil) } if len(keys) == 0 { if provider.customProviderConfig != nil && provider.customProviderConfig.IsKeyLess { keys = []schemas.Key{{}} } else { - return nil, providerUtils.NewBifrostOperationError("provider config not found", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("provider config not found", nil) } } if request.ContainerID == "" { - return nil, providerUtils.NewBifrostOperationError("container_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("container_id is required", nil) } if err := providerUtils.CheckOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.ContainerRetrieveRequest); err != nil { @@ -6247,7 +6050,7 @@ func (provider *OpenAIProvider) ContainerRetrieve(ctx *schemas.BifrostContext, k // Handle error response if 
resp.StatusCode() != fasthttp.StatusOK { - lastErr = ParseOpenAIError(resp, schemas.ContainerRetrieveRequest, providerName, "") + lastErr = ParseOpenAIError(resp) fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) continue @@ -6287,9 +6090,7 @@ func (provider *OpenAIProvider) ContainerRetrieve(ctx *schemas.BifrostContext, k MemoryLimit: containerResp.MemoryLimit, Metadata: containerResp.Metadata, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: providerName, - RequestType: schemas.ContainerRetrieveRequest, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } @@ -6310,20 +6111,18 @@ func (provider *OpenAIProvider) ContainerRetrieve(ctx *schemas.BifrostContext, k // ContainerDelete deletes a container via OpenAI's API. func (provider *OpenAIProvider) ContainerDelete(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostContainerDeleteRequest) (*schemas.BifrostContainerDeleteResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - if request == nil { - return nil, providerUtils.NewBifrostOperationError("invalid request: nil", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid request: nil", nil) } if len(keys) == 0 { if provider.customProviderConfig != nil && provider.customProviderConfig.IsKeyLess { keys = []schemas.Key{{}} } else { - return nil, providerUtils.NewBifrostOperationError("provider config not found", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("provider config not found", nil) } } if request.ContainerID == "" { - return nil, providerUtils.NewBifrostOperationError("container_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("container_id is required", nil) } if err := providerUtils.CheckOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.ContainerDeleteRequest); err != nil { @@ -6358,7 +6157,7 @@ func (provider *OpenAIProvider) ContainerDelete(ctx 
*schemas.BifrostContext, key // Handle error response if resp.StatusCode() != fasthttp.StatusOK { - lastErr = ParseOpenAIError(resp, schemas.ContainerDeleteRequest, providerName, "") + lastErr = ParseOpenAIError(resp) fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) continue @@ -6386,9 +6185,7 @@ func (provider *OpenAIProvider) ContainerDelete(ctx *schemas.BifrostContext, key Object: deleteResp.Object, Deleted: deleteResp.Deleted, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: providerName, - RequestType: schemas.ContainerDeleteRequest, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } @@ -6417,14 +6214,12 @@ func (provider *OpenAIProvider) ContainerFileCreate(ctx *schemas.BifrostContext, return nil, err } - providerName := provider.GetProviderKey() - if request == nil { - return nil, providerUtils.NewBifrostOperationError("invalid request: nil", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid request: nil", nil) } if request.ContainerID == "" { - return nil, providerUtils.NewBifrostOperationError("invalid request: container_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid request: container_id is required", nil) } // Create request @@ -6441,7 +6236,7 @@ func (provider *OpenAIProvider) ContainerFileCreate(ctx *schemas.BifrostContext, // Handle file upload (multipart only) if len(request.File) == 0 { - return nil, providerUtils.NewBifrostOperationError("invalid request: file is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid request: file is required", nil) } // Multipart file upload @@ -6451,13 +6246,13 @@ func (provider *OpenAIProvider) ContainerFileCreate(ctx *schemas.BifrostContext, // Add file part, err := writer.CreateFormFile("file", "file") if err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to create multipart form", err, providerName) + return nil, 
providerUtils.NewBifrostOperationError("failed to create multipart form", err) } if _, err = part.Write(request.File); err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to write file to multipart form", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to write file to multipart form", err) } if err := writer.Close(); err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to close multipart form", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to close multipart form", err) } req.Header.Set("Content-Type", writer.FormDataContentType()) req.SetBody(body.Bytes()) @@ -6475,13 +6270,13 @@ func (provider *OpenAIProvider) ContainerFileCreate(ctx *schemas.BifrostContext, // Handle error response if resp.StatusCode() >= 400 { - return nil, ParseOpenAIError(resp, schemas.ContainerFileCreateRequest, providerName, "") + return nil, ParseOpenAIError(resp) } // Decode response body (handles content-encoding like gzip) responseBody, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) sendBackRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) @@ -6510,9 +6305,7 @@ func (provider *OpenAIProvider) ContainerFileCreate(ctx *schemas.BifrostContext, Path: fileResp.Path, Source: fileResp.Source, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: providerName, - RequestType: schemas.ContainerFileCreateRequest, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } @@ -6530,21 +6323,19 @@ func (provider *OpenAIProvider) ContainerFileCreate(ctx *schemas.BifrostContext, // ContainerFileList lists files in a container 
via OpenAI's API. // Uses SerialListHelper for multi-key pagination - exhausts all pages from one key before moving to next. func (provider *OpenAIProvider) ContainerFileList(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostContainerFileListRequest) (*schemas.BifrostContainerFileListResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - if request == nil { - return nil, providerUtils.NewBifrostOperationError("invalid request: nil", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid request: nil", nil) } if request.ContainerID == "" { - return nil, providerUtils.NewBifrostOperationError("invalid request: container_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid request: container_id is required", nil) } if len(keys) == 0 { if provider.customProviderConfig != nil && provider.customProviderConfig.IsKeyLess { keys = []schemas.Key{{}} } else { - return nil, providerUtils.NewBifrostOperationError("no keys provided", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("no keys provided", nil) } } @@ -6558,7 +6349,7 @@ func (provider *OpenAIProvider) ContainerFileList(ctx *schemas.BifrostContext, k // Initialize serial pagination helper for multi-key support helper, err := providerUtils.NewSerialListHelper(keys, request.After, provider.logger) if err != nil { - return nil, providerUtils.NewBifrostOperationError("invalid pagination cursor", err, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid pagination cursor", err) } // Get current key to query @@ -6569,10 +6360,6 @@ func (provider *OpenAIProvider) ContainerFileList(ctx *schemas.BifrostContext, k Object: "list", Data: []schemas.ContainerFileObject{}, HasMore: false, - ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: providerName, - RequestType: schemas.ContainerFileListRequest, - }, }, nil } @@ -6618,13 +6405,13 @@ func (provider 
*OpenAIProvider) ContainerFileList(ctx *schemas.BifrostContext, k } if resp.StatusCode() >= 400 { - return nil, ParseOpenAIError(resp, schemas.ContainerFileListRequest, providerName, "") + return nil, ParseOpenAIError(resp) } // Decode response body (handles content-encoding like gzip) responseBody, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } var listResp struct { @@ -6656,9 +6443,7 @@ func (provider *OpenAIProvider) ContainerFileList(ctx *schemas.BifrostContext, k LastID: listResp.LastID, HasMore: hasMore, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: providerName, - RequestType: schemas.ContainerFileListRequest, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } @@ -6679,13 +6464,11 @@ func (provider *OpenAIProvider) ContainerFileList(ctx *schemas.BifrostContext, k // ContainerFileRetrieve retrieves a file from a container via OpenAI's API. 
func (provider *OpenAIProvider) ContainerFileRetrieve(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostContainerFileRetrieveRequest) (*schemas.BifrostContainerFileRetrieveResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - if len(keys) == 0 { if provider.customProviderConfig != nil && provider.customProviderConfig.IsKeyLess { keys = []schemas.Key{{}} } else { - return nil, providerUtils.NewBifrostOperationError("no keys provided", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("no keys provided", nil) } } @@ -6694,15 +6477,15 @@ func (provider *OpenAIProvider) ContainerFileRetrieve(ctx *schemas.BifrostContex } if request == nil { - return nil, providerUtils.NewBifrostOperationError("invalid request: nil", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid request: nil", nil) } if request.ContainerID == "" { - return nil, providerUtils.NewBifrostOperationError("invalid request: container_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid request: container_id is required", nil) } if request.FileID == "" { - return nil, providerUtils.NewBifrostOperationError("invalid request: file_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid request: file_id is required", nil) } var lastErr *schemas.BifrostError @@ -6730,7 +6513,7 @@ func (provider *OpenAIProvider) ContainerFileRetrieve(ctx *schemas.BifrostContex } if resp.StatusCode() >= 400 { - lastErr = ParseOpenAIError(resp, schemas.ContainerFileRetrieveRequest, providerName, "") + lastErr = ParseOpenAIError(resp) fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) continue @@ -6739,7 +6522,7 @@ func (provider *OpenAIProvider) ContainerFileRetrieve(ctx *schemas.BifrostContex // Decode response body (handles content-encoding like gzip) responseBody, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - 
lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) continue @@ -6774,9 +6557,7 @@ func (provider *OpenAIProvider) ContainerFileRetrieve(ctx *schemas.BifrostContex Path: fileResp.Path, Source: fileResp.Source, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: providerName, - RequestType: schemas.ContainerFileRetrieveRequest, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } @@ -6797,13 +6578,11 @@ func (provider *OpenAIProvider) ContainerFileRetrieve(ctx *schemas.BifrostContex // ContainerFileContent retrieves the content of a file from a container via OpenAI's API. func (provider *OpenAIProvider) ContainerFileContent(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostContainerFileContentRequest) (*schemas.BifrostContainerFileContentResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - if len(keys) == 0 { if provider.customProviderConfig != nil && provider.customProviderConfig.IsKeyLess { keys = []schemas.Key{{}} } else { - return nil, providerUtils.NewBifrostOperationError("no keys provided", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("no keys provided", nil) } } @@ -6812,15 +6591,15 @@ func (provider *OpenAIProvider) ContainerFileContent(ctx *schemas.BifrostContext } if request == nil { - return nil, providerUtils.NewBifrostOperationError("invalid request: nil", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid request: nil", nil) } if request.ContainerID == "" { - return nil, providerUtils.NewBifrostOperationError("invalid request: container_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid request: container_id is required", nil) } if request.FileID == "" { - return nil, 
providerUtils.NewBifrostOperationError("invalid request: file_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid request: file_id is required", nil) } var lastErr *schemas.BifrostError @@ -6848,7 +6627,7 @@ func (provider *OpenAIProvider) ContainerFileContent(ctx *schemas.BifrostContext } if resp.StatusCode() >= 400 { - lastErr = ParseOpenAIError(resp, schemas.ContainerFileContentRequest, providerName, "") + lastErr = ParseOpenAIError(resp) fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) continue @@ -6865,7 +6644,7 @@ func (provider *OpenAIProvider) ContainerFileContent(ctx *schemas.BifrostContext if err != nil { fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) - lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) continue } content := append([]byte(nil), body...) @@ -6874,9 +6653,7 @@ func (provider *OpenAIProvider) ContainerFileContent(ctx *schemas.BifrostContext Content: content, ContentType: contentType, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: providerName, - RequestType: schemas.ContainerFileContentRequest, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } @@ -6900,13 +6677,11 @@ func (provider *OpenAIProvider) ContainerFileContent(ctx *schemas.BifrostContext // ContainerFileDelete deletes a file from a container via OpenAI's API. 
func (provider *OpenAIProvider) ContainerFileDelete(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostContainerFileDeleteRequest) (*schemas.BifrostContainerFileDeleteResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - if len(keys) == 0 { if provider.customProviderConfig != nil && provider.customProviderConfig.IsKeyLess { keys = []schemas.Key{{}} } else { - return nil, providerUtils.NewBifrostOperationError("no keys provided", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("no keys provided", nil) } } @@ -6915,15 +6690,15 @@ func (provider *OpenAIProvider) ContainerFileDelete(ctx *schemas.BifrostContext, } if request == nil { - return nil, providerUtils.NewBifrostOperationError("invalid request: nil", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid request: nil", nil) } if request.ContainerID == "" { - return nil, providerUtils.NewBifrostOperationError("invalid request: container_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid request: container_id is required", nil) } if request.FileID == "" { - return nil, providerUtils.NewBifrostOperationError("invalid request: file_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid request: file_id is required", nil) } var lastErr *schemas.BifrostError @@ -6951,7 +6726,7 @@ func (provider *OpenAIProvider) ContainerFileDelete(ctx *schemas.BifrostContext, } if resp.StatusCode() >= 400 { - lastErr = ParseOpenAIError(resp, schemas.ContainerFileDeleteRequest, providerName, "") + lastErr = ParseOpenAIError(resp) fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) continue @@ -6960,7 +6735,7 @@ func (provider *OpenAIProvider) ContainerFileDelete(ctx *schemas.BifrostContext, // Decode response body (handles content-encoding like gzip) responseBody, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - lastErr = 
providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) continue @@ -6987,9 +6762,7 @@ func (provider *OpenAIProvider) ContainerFileDelete(ctx *schemas.BifrostContext, Object: deleteResp.Object, Deleted: deleteResp.Deleted, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: providerName, - RequestType: schemas.ContainerFileDeleteRequest, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } @@ -7058,7 +6831,7 @@ func (provider *OpenAIProvider) Passthrough( body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to decode response body", err, provider.GetProviderKey()) + return nil, providerUtils.NewBifrostOperationError("failed to decode response body", err) } // Remove wire-level encoding headers after decoding; downstream should recalculate them for the buffered body. 
@@ -7074,9 +6847,6 @@ func (provider *OpenAIProvider) Passthrough( Body: body, } - bifrostResponse.ExtraFields.Provider = provider.GetProviderKey() - bifrostResponse.ExtraFields.ModelRequested = req.Model - bifrostResponse.ExtraFields.RequestType = schemas.PassthroughRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { @@ -7144,9 +6914,9 @@ func (provider *OpenAIProvider) PassthroughStream( } } if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) { - return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, provider.GetProviderKey()) + return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err) } - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, provider.GetProviderKey()) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err) } headers := make(map[string]string) @@ -7160,9 +6930,7 @@ func (provider *OpenAIProvider) PassthroughStream( providerUtils.ReleaseStreamingResponse(resp) return nil, providerUtils.NewBifrostOperationError( "provider returned an empty stream body", - fmt.Errorf("provider returned an empty stream body"), - provider.GetProviderKey(), - ) + fmt.Errorf("provider returned an empty stream body")) } // Wrap reader with idle timeout to detect stalled streams. @@ -7171,11 +6939,7 @@ func (provider *OpenAIProvider) PassthroughStream( // Cancellation must close the raw stream to unblock reads. 
stopCancellation := providerUtils.SetupStreamCancellation(ctx, rawBodyStream, provider.logger) - extraFields := schemas.BifrostResponseExtraFields{ - Provider: provider.GetProviderKey(), - ModelRequested: req.Model, - RequestType: schemas.PassthroughStreamRequest, - } + extraFields := schemas.BifrostResponseExtraFields{} if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { providerUtils.ParseAndSetRawRequestIfJSON(fasthttpReq, &extraFields) } @@ -7185,9 +6949,9 @@ func (provider *OpenAIProvider) PassthroughStream( go func() { defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, ch, provider.GetProviderKey(), req.Model, schemas.PassthroughStreamRequest, provider.logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, ch, provider.logger) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, ch, provider.GetProviderKey(), req.Model, schemas.PassthroughStreamRequest, provider.logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, ch, provider.logger) } close(ch) }() @@ -7236,7 +7000,7 @@ func (provider *OpenAIProvider) PassthroughStream( } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) extraFields.Latency = time.Since(startTime).Milliseconds() - providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, ch, schemas.PassthroughStreamRequest, provider.GetProviderKey(), req.Model, provider.logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, ch, provider.logger) return } } diff --git a/core/providers/openai/transcription.go b/core/providers/openai/transcription.go index 8ab2305b05..8c2bf112a1 100644 --- a/core/providers/openai/transcription.go +++ b/core/providers/openai/transcription.go @@ -54,63 +54,63 @@ func ParseTranscriptionFormDataBodyFromRequest(writer *multipart.Writer, openaiR } fileWriter, err := writer.CreateFormFile("file", filename) if err != nil { - return 
utils.NewBifrostOperationError("failed to create form file", err, providerName) + return utils.NewBifrostOperationError("failed to create form file", err) } if _, err := fileWriter.Write(openaiReq.File); err != nil { - return utils.NewBifrostOperationError("failed to write file data", err, providerName) + return utils.NewBifrostOperationError("failed to write file data", err) } // Add model field if err := writer.WriteField("model", openaiReq.Model); err != nil { - return utils.NewBifrostOperationError("failed to write model field", err, providerName) + return utils.NewBifrostOperationError("failed to write model field", err) } // Add optional fields if openaiReq.Language != nil { if err := writer.WriteField("language", *openaiReq.Language); err != nil { - return utils.NewBifrostOperationError("failed to write language field", err, providerName) + return utils.NewBifrostOperationError("failed to write language field", err) } } if openaiReq.Prompt != nil { if err := writer.WriteField("prompt", *openaiReq.Prompt); err != nil { - return utils.NewBifrostOperationError("failed to write prompt field", err, providerName) + return utils.NewBifrostOperationError("failed to write prompt field", err) } } if openaiReq.ResponseFormat != nil { if err := writer.WriteField("response_format", *openaiReq.ResponseFormat); err != nil { - return utils.NewBifrostOperationError("failed to write response_format field", err, providerName) + return utils.NewBifrostOperationError("failed to write response_format field", err) } } if openaiReq.Temperature != nil { if err := writer.WriteField("temperature", fmt.Sprintf("%g", *openaiReq.Temperature)); err != nil { - return utils.NewBifrostOperationError("failed to write temperature field", err, providerName) + return utils.NewBifrostOperationError("failed to write temperature field", err) } } for _, granularity := range openaiReq.TimestampGranularities { if err := writer.WriteField("timestamp_granularities[]", granularity); err != nil { - return 
utils.NewBifrostOperationError("failed to write timestamp_granularities field", err, providerName) + return utils.NewBifrostOperationError("failed to write timestamp_granularities field", err) } } for _, include := range openaiReq.Include { if err := writer.WriteField("include[]", include); err != nil { - return utils.NewBifrostOperationError("failed to write include field", err, providerName) + return utils.NewBifrostOperationError("failed to write include field", err) } } if openaiReq.Stream != nil && *openaiReq.Stream { if err := writer.WriteField("stream", "true"); err != nil { - return utils.NewBifrostOperationError("failed to write stream field", err, providerName) + return utils.NewBifrostOperationError("failed to write stream field", err) } } // Close the multipart writer if err := writer.Close(); err != nil { - return utils.NewBifrostOperationError("failed to close multipart writer", err, providerName) + return utils.NewBifrostOperationError("failed to close multipart writer", err) } return nil diff --git a/core/providers/openai/videos.go b/core/providers/openai/videos.go index aa4052e029..512306b7c7 100644 --- a/core/providers/openai/videos.go +++ b/core/providers/openai/videos.go @@ -132,30 +132,30 @@ func (req *OpenAIVideoGenerationRequest) ToBifrostVideoGenerationRequest(ctx *sc func parseVideoGenerationFormDataBodyFromRequest(writer *multipart.Writer, openaiReq *OpenAIVideoGenerationRequest, providerName schemas.ModelProvider) *schemas.BifrostError { // Add prompt field (required) if openaiReq.Prompt == "" { - return providerUtils.NewBifrostOperationError("prompt is required", nil, providerName) + return providerUtils.NewBifrostOperationError("prompt is required", nil) } if err := writer.WriteField("prompt", openaiReq.Prompt); err != nil { - return providerUtils.NewBifrostOperationError("failed to write prompt field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write prompt field", err) } // Add optional model field 
if openaiReq.Model != "" { if err := writer.WriteField("model", openaiReq.Model); err != nil { - return providerUtils.NewBifrostOperationError("failed to write model field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write model field", err) } } // Add optional seconds field if openaiReq.Seconds != nil { if err := writer.WriteField("seconds", *openaiReq.Seconds); err != nil { - return providerUtils.NewBifrostOperationError("failed to write seconds field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write seconds field", err) } } // Add optional size field if openaiReq.Size != "" { if err := writer.WriteField("size", openaiReq.Size); err != nil { - return providerUtils.NewBifrostOperationError("failed to write size field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write size field", err) } } @@ -196,16 +196,16 @@ func parseVideoGenerationFormDataBodyFromRequest(writer *multipart.Writer, opena "Content-Type": {mimeType}, }) if err != nil { - return providerUtils.NewBifrostOperationError("failed to create form part for input_reference", err, providerName) + return providerUtils.NewBifrostOperationError("failed to create form part for input_reference", err) } if _, err := part.Write(openaiReq.InputReference); err != nil { - return providerUtils.NewBifrostOperationError("failed to write input_reference file data", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write input_reference file data", err) } } // Close the multipart writer if err := writer.Close(); err != nil { - return providerUtils.NewBifrostOperationError("failed to close multipart writer", err, providerName) + return providerUtils.NewBifrostOperationError("failed to close multipart writer", err) } return nil diff --git a/core/providers/openrouter/openrouter.go b/core/providers/openrouter/openrouter.go index eda7353372..a1e746834c 100644 --- 
a/core/providers/openrouter/openrouter.go +++ b/core/providers/openrouter/openrouter.go @@ -94,12 +94,12 @@ func (provider *OpenRouterProvider) validateKey(ctx *schemas.BifrostContext, key // Check for auth errors (401, 403) statusCode := resp.StatusCode() if statusCode == fasthttp.StatusUnauthorized || statusCode == fasthttp.StatusForbidden { - return openai.ParseOpenAIError(resp, schemas.ListModelsRequest, provider.GetProviderKey(), "") + return openai.ParseOpenAIError(resp) } // Any 4xx/5xx error indicates the key might be invalid if statusCode >= 400 { - return openai.ParseOpenAIError(resp, schemas.ListModelsRequest, provider.GetProviderKey(), "") + return openai.ParseOpenAIError(resp) } return nil @@ -108,8 +108,6 @@ func (provider *OpenRouterProvider) validateKey(ctx *schemas.BifrostContext, key // listModelsByKey performs a list models request for a single key. // Returns the response and latency, or an error if the request fails. func (provider *OpenRouterProvider) listModelsByKey(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - // Validate the key first using /v1/auth/key (only during provider add/update). // OpenRouter's /v1/models doesn't require auth, so we need this extra check. shouldValidate := false @@ -157,7 +155,7 @@ func (provider *OpenRouterProvider) listModelsByKey(ctx *schemas.BifrostContext, // Continue with empty response; allowed models will be backfilled below. 
modelsFetched = false } else { - bifrostErr := openai.ParseOpenAIError(resp, schemas.ListModelsRequest, providerName, "") + bifrostErr := openai.ParseOpenAIError(resp) return nil, bifrostErr } } @@ -184,51 +182,62 @@ func (provider *OpenRouterProvider) listModelsByKey(ctx *schemas.BifrostContext, } } - // Filter by key.Models - allowedModels := key.Models - blacklistedModels := key.BlacklistedModels + // OpenRouter model IDs in the API response do NOT include the "openrouter/" prefix + // (e.g. the API returns "openai/gpt-4", not "openrouter/openai/gpt-4"). + // Users may supply allowedModels / aliases with or without the prefix, so we + // normalize both by stripping it before feeding into the shared pipeline. providerPrefix := string(schemas.OpenRouter) + "/" + stripPrefix := func(s string) string { + if strings.HasPrefix(strings.ToLower(s), strings.ToLower(providerPrefix)) { + return s[len(providerPrefix):] + } + return s + } + + normalizedAllowed := make(schemas.WhiteList, 0, len(key.Models)) + for _, m := range key.Models { + normalizedAllowed = append(normalizedAllowed, stripPrefix(m)) + } + normalizedBlacklist := make(schemas.BlackList, 0, len(key.BlacklistedModels)) + for _, m := range key.BlacklistedModels { + normalizedBlacklist = append(normalizedBlacklist, stripPrefix(m)) + } + normalizedAliases := make(map[string]string, len(key.Aliases)) + for k, v := range key.Aliases { + normalizedAliases[stripPrefix(k)] = stripPrefix(v) + } + + pipeline := &providerUtils.ListModelsPipeline{ + AllowedModels: normalizedAllowed, + BlacklistedModels: normalizedBlacklist, + Aliases: normalizedAliases, + Unfiltered: request.Unfiltered, + ProviderKey: schemas.OpenRouter, + MatchFns: providerUtils.DefaultMatchFns(), + } - if !request.Unfiltered && (allowedModels.IsEmpty() || blacklistedModels.IsBlockAll()) { + if pipeline.ShouldEarlyExit() { openrouterResponse.Data = make([]schemas.Model, 0) - } else if !request.Unfiltered && allowedModels.IsRestricted() { + } else { + 
included := make(map[string]bool) filteredData := make([]schemas.Model, 0, len(openrouterResponse.Data)) - includedModels := make(map[string]bool) for i := range openrouterResponse.Data { + // rawID has no "openrouter/" prefix — e.g. "openai/gpt-4" rawID := openrouterResponse.Data[i].ID - if !(allowedModels.Contains(rawID) || allowedModels.Contains(providerPrefix+rawID)) { - continue - } - if blacklistedModels.IsBlocked(rawID) || blacklistedModels.IsBlocked(providerPrefix+rawID) { - continue - } - openrouterResponse.Data[i].ID = providerPrefix + rawID - filteredData = append(filteredData, openrouterResponse.Data[i]) - includedModels[strings.ToLower(rawID)] = true - } - // Backfill allowed models not in the API response - for _, allowedModel := range allowedModels { - // Strip provider prefix case-insensitively to handle any casing users may supply - rawID := allowedModel - if strings.HasPrefix(strings.ToLower(allowedModel), strings.ToLower(providerPrefix)) { - rawID = allowedModel[len(providerPrefix):] - } - if blacklistedModels.IsBlocked(rawID) || blacklistedModels.IsBlocked(providerPrefix+rawID) { - continue - } - if !includedModels[strings.ToLower(rawID)] { - filteredData = append(filteredData, schemas.Model{ - ID: providerPrefix + rawID, - Name: schemas.Ptr(rawID), - }) - includedModels[strings.ToLower(rawID)] = true // avoid duplicate backfill + for _, result := range pipeline.FilterModel(rawID) { + entry := openrouterResponse.Data[i] + entry.ID = providerPrefix + result.ResolvedID + if result.AliasValue != "" { + entry.Alias = schemas.Ptr(result.AliasValue) + } else { + entry.Alias = nil + } + filteredData = append(filteredData, entry) + included[strings.ToLower(result.ResolvedID)] = true } } + filteredData = append(filteredData, pipeline.BackfillModels(included)...) 
openrouterResponse.Data = filteredData - } else { - for i := range openrouterResponse.Data { - openrouterResponse.Data[i].ID = providerPrefix + openrouterResponse.Data[i].ID - } } openrouterResponse.ExtraFields.Latency = latency.Milliseconds() diff --git a/core/providers/parasail/parasail.go b/core/providers/parasail/parasail.go index 6dc6b74cca..280bf0b64d 100644 --- a/core/providers/parasail/parasail.go +++ b/core/providers/parasail/parasail.go @@ -145,9 +145,6 @@ func (provider *ParasailProvider) Responses(ctx *schemas.BifrostContext, key sch } response := chatResponse.ToBifrostResponsesResponse() - response.ExtraFields.RequestType = schemas.ResponsesRequest - response.ExtraFields.Provider = provider.GetProviderKey() - response.ExtraFields.ModelRequested = request.Model return response, nil } diff --git a/core/providers/perplexity/chat.go b/core/providers/perplexity/chat.go index dafe1c615b..a832ac0ad7 100644 --- a/core/providers/perplexity/chat.go +++ b/core/providers/perplexity/chat.go @@ -280,8 +280,6 @@ func (response *PerplexityChatResponse) ToBifrostChatResponse(model string) *sch Object: response.Object, Created: response.Created, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ChatCompletionRequest, - Provider: schemas.Perplexity, }, SearchResults: response.SearchResults, Videos: response.Videos, diff --git a/core/providers/perplexity/perplexity.go b/core/providers/perplexity/perplexity.go index d0e68a6850..2b0abbefb4 100644 --- a/core/providers/perplexity/perplexity.go +++ b/core/providers/perplexity/perplexity.go @@ -101,12 +101,12 @@ func (provider *PerplexityProvider) completeRequest(ctx *schemas.BifrostContext, // Handle error response if resp.StatusCode() != fasthttp.StatusOK { provider.logger.Debug(fmt.Sprintf("error from %s provider: %s", provider.GetProviderKey(), string(resp.Body()))) - return nil, latency, providerResponseHeaders, openai.ParseOpenAIError(resp, schemas.ChatCompletionRequest, provider.GetProviderKey(), 
model) + return nil, latency, providerResponseHeaders, openai.ParseOpenAIError(resp) } body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, latency, providerResponseHeaders, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, provider.GetProviderKey()) + return nil, latency, providerResponseHeaders, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } // Read the response body and copy it before releasing the response @@ -141,8 +141,7 @@ func (provider *PerplexityProvider) ChatCompletion(ctx *schemas.BifrostContext, request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToPerplexityChatCompletionRequest(request), nil - }, - provider.GetProviderKey()) + }) if err != nil { return nil, err } @@ -161,9 +160,6 @@ func (provider *PerplexityProvider) ChatCompletion(ctx *schemas.BifrostContext, bifrostResponse := response.ToBifrostChatResponse(request.Model) // Set ExtraFields - bifrostResponse.ExtraFields.Provider = provider.GetProviderKey() - bifrostResponse.ExtraFields.ModelRequested = request.Model - bifrostResponse.ExtraFields.RequestType = schemas.ChatCompletionRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -223,9 +219,6 @@ func (provider *PerplexityProvider) Responses(ctx *schemas.BifrostContext, key s } response := chatResponse.ToBifrostResponsesResponse() - response.ExtraFields.RequestType = schemas.ResponsesRequest - response.ExtraFields.Provider = provider.GetProviderKey() - response.ExtraFields.ModelRequested = request.Model return response, nil } diff --git a/core/providers/replicate/errors.go b/core/providers/replicate/errors.go index e7fc2051d0..1575d9ca77 100644 --- a/core/providers/replicate/errors.go +++ b/core/providers/replicate/errors.go @@ -15,9 +15,6 @@ func parseReplicateError(body []byte, statusCode int) *schemas.BifrostError { Error: 
&schemas.ErrorField{ Message: replicateErr.Detail, }, - ExtraFields: schemas.BifrostErrorExtraFields{ - Provider: schemas.Replicate, - }, } } @@ -28,8 +25,5 @@ func parseReplicateError(body []byte, statusCode int) *schemas.BifrostError { Error: &schemas.ErrorField{ Message: string(body), }, - ExtraFields: schemas.BifrostErrorExtraFields{ - Provider: schemas.Replicate, - }, } } diff --git a/core/providers/replicate/files.go b/core/providers/replicate/files.go index cdd37a65c5..15ca0e13e8 100644 --- a/core/providers/replicate/files.go +++ b/core/providers/replicate/files.go @@ -30,8 +30,6 @@ func (r *ReplicateFileResponse) ToBifrostFileUploadResponse(providerName schemas Status: ToBifrostFileStatus(r), StorageBackend: schemas.FileStorageAPI, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileUploadRequest, - Provider: providerName, Latency: latency.Milliseconds(), }, } @@ -67,8 +65,6 @@ func (r *ReplicateFileResponse) ToBifrostFileRetrieveResponse(providerName schem Status: ToBifrostFileStatus(r), StorageBackend: schemas.FileStorageAPI, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileRetrieveRequest, - Provider: providerName, Latency: latency.Milliseconds(), }, } diff --git a/core/providers/replicate/images.go b/core/providers/replicate/images.go index 72c031703c..0327b8b5fa 100644 --- a/core/providers/replicate/images.go +++ b/core/providers/replicate/images.go @@ -27,7 +27,6 @@ var modelInputImageFieldMap = map[string]string{ "black-forest-labs/flux-krea-dev": "image", } - // ToReplicateImageGenerationInput converts a Bifrost image generation request to Replicate prediction input func ToReplicateImageGenerationInput(bifrostReq *schemas.BifrostImageGenerationRequest) *ReplicatePredictionRequest { if bifrostReq == nil || bifrostReq.Input == nil { @@ -125,9 +124,6 @@ func ToBifrostImageGenerationResponse( Error: &schemas.ErrorField{ Message: "prediction response is nil", }, - ExtraFields: 
schemas.BifrostErrorExtraFields{ - Provider: schemas.Replicate, - }, } } diff --git a/core/providers/replicate/models.go b/core/providers/replicate/models.go index 206d0e0ca6..6c0c14dbf7 100644 --- a/core/providers/replicate/models.go +++ b/core/providers/replicate/models.go @@ -3,62 +3,64 @@ package replicate import ( "strings" + providerUtils "github.com/maximhq/bifrost/core/providers/utils" "github.com/maximhq/bifrost/core/schemas" ) -// ToBifrostListModelsResponse converts Replicate models and deployments to a Bifrost list models response +// ToBifrostListModelsResponse converts Replicate deployments to a Bifrost list models response. +// Replicate model IDs are composite: "{owner}/{name}" (e.g. "stability-ai/stable-diffusion"). func ToBifrostListModelsResponse( deploymentsResponse *ReplicateDeploymentListResponse, providerKey schemas.ModelProvider, allowedModels schemas.WhiteList, blacklistedModels schemas.BlackList, + aliases map[string]string, unfiltered bool, ) *schemas.BifrostListModelsResponse { bifrostResponse := &schemas.BifrostListModelsResponse{ Data: make([]schemas.Model, 0), } - if !unfiltered && (allowedModels.IsEmpty() || blacklistedModels.IsBlockAll()) { + pipeline := &providerUtils.ListModelsPipeline{ + AllowedModels: allowedModels, + BlacklistedModels: blacklistedModels, + Aliases: aliases, + Unfiltered: unfiltered, + ProviderKey: providerKey, + MatchFns: providerUtils.DefaultMatchFns(), + } + if pipeline.ShouldEarlyExit() { return bifrostResponse } - includedModels := make(map[string]bool) - // Add deployments from /v1/deployments endpoint + included := make(map[string]bool) + if deploymentsResponse != nil { for _, deployment := range deploymentsResponse.Results { + // Replicate model IDs are composite owner/name deploymentID := deployment.Owner + "/" + deployment.Name - modelName := schemas.Ptr(deployment.Name) var created *int64 - - if !unfiltered && allowedModels.IsRestricted() && !allowedModels.Contains(deploymentID) { - continue - } - if 
!unfiltered && blacklistedModels.IsBlocked(deploymentID) { - continue - } - - // Extract information from current release if available - if deployment.CurrentRelease != nil { - // Parse created timestamp - if deployment.CurrentRelease.CreatedAt != "" { - createdTimestamp := ParseReplicateTimestamp(deployment.CurrentRelease.CreatedAt) - if createdTimestamp > 0 { - created = schemas.Ptr(createdTimestamp) - } + if deployment.CurrentRelease != nil && deployment.CurrentRelease.CreatedAt != "" { + createdTimestamp := ParseReplicateTimestamp(deployment.CurrentRelease.CreatedAt) + if createdTimestamp > 0 { + created = schemas.Ptr(createdTimestamp) } } - bifrostModel := schemas.Model{ - ID: string(providerKey) + "/" + deploymentID, - Name: modelName, - Deployment: modelName, - OwnedBy: schemas.Ptr(deployment.Owner), - Created: created, + for _, result := range pipeline.FilterModel(deploymentID) { + bifrostModel := schemas.Model{ + ID: string(providerKey) + "/" + result.ResolvedID, + Name: schemas.Ptr(deployment.Name), + OwnedBy: schemas.Ptr(deployment.Owner), + Created: created, + } + if result.AliasValue != "" { + bifrostModel.Alias = schemas.Ptr(result.AliasValue) + } + bifrostResponse.Data = append(bifrostResponse.Data, bifrostModel) + included[strings.ToLower(result.ResolvedID)] = true } - - bifrostResponse.Data = append(bifrostResponse.Data, bifrostModel) - includedModels[strings.ToLower(deploymentID)] = true } if deploymentsResponse.Next != nil { @@ -66,58 +68,8 @@ func ToBifrostListModelsResponse( } } - // Backfill allowed models that were not in the response - if !unfiltered && allowedModels.IsRestricted() { - for _, allowedModel := range allowedModels { - if blacklistedModels.IsBlocked(allowedModel) { - continue - } - if !includedModels[strings.ToLower(allowedModel)] { - bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ - ID: string(providerKey) + "/" + allowedModel, - Name: schemas.Ptr(allowedModel), - }) - } - } - } + bifrostResponse.Data = 
append(bifrostResponse.Data, + pipeline.BackfillModels(included)...) return bifrostResponse } - -// ToReplicateListModelsResponse converts a Bifrost list models response to a Replicate list models response -// This is mainly used for testing and compatibility -func ToReplicateListModelsResponse(response *schemas.BifrostListModelsResponse) *ReplicateModelListResponse { - if response == nil { - return nil - } - - replicateResponse := &ReplicateModelListResponse{ - Results: make([]ReplicateModelResponse, 0, len(response.Data)), - } - - for _, model := range response.Data { - modelID := strings.TrimPrefix(model.ID, string(schemas.Replicate)+"/") - replicateModel := ReplicateModelResponse{ - URL: "https://replicate.com/" + modelID, - Name: modelID, - } - - if model.Description != nil { - replicateModel.Description = model.Description - } - - if model.OwnedBy != nil { - replicateModel.Owner = *model.OwnedBy - } - - replicateResponse.Results = append(replicateResponse.Results, replicateModel) - } - - // Set next page token if available - if response.NextPageToken != "" { - next := response.NextPageToken - replicateResponse.Next = &next - } - - return replicateResponse -} diff --git a/core/providers/replicate/replicate.go b/core/providers/replicate/replicate.go index 012eaa845e..3fe52615ec 100644 --- a/core/providers/replicate/replicate.go +++ b/core/providers/replicate/replicate.go @@ -149,7 +149,7 @@ func createPrediction( // Parse response body, decodeErr := providerUtils.CheckAndDecodeBody(resp) if decodeErr != nil { - return nil, nil, latency, providerResponseHeaders, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, decodeErr, schemas.Replicate) + return nil, nil, latency, providerResponseHeaders, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, decodeErr) } var prediction ReplicatePredictionResponse @@ -204,7 +204,7 @@ func getPrediction( // Parse response body, err := providerUtils.CheckAndDecodeBody(resp) if err != 
nil { - return nil, nil, providerResponseHeaders, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, schemas.Replicate) + return nil, nil, providerResponseHeaders, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } prediction := &ReplicatePredictionResponse{} @@ -252,9 +252,7 @@ func pollPrediction( case <-pollCtx.Done(): return nil, nil, providerResponseHeaders, providerUtils.NewBifrostOperationError( schemas.ErrProviderRequestTimedOut, - fmt.Errorf("prediction polling timed out after %d seconds", timeoutSeconds), - schemas.Replicate, - ) + fmt.Errorf("prediction polling timed out after %d seconds", timeoutSeconds)) case <-ticker.C: prediction, rawResponse, providerResponseHeaders, err = getPrediction(pollCtx, client, predictionURL, key, logger, sendBackRawResponse) if err != nil { @@ -277,6 +275,17 @@ func (provider *ReplicateProvider) listDeploymentsByKey(ctx *schemas.BifrostCont client := provider.client extraHeaders := provider.networkConfig.ExtraHeaders + if !key.ReplicateKeyConfig.UseDeploymentsEndpoint { + return ToBifrostListModelsResponse( + &ReplicateDeploymentListResponse{}, + providerName, + key.Models, + key.BlacklistedModels, + key.Aliases, + request.Unfiltered, + ), nil + } + // Build deployments URL deploymentsURL := provider.buildRequestURL(ctx, "/v1/deployments", schemas.ListModelsRequest) @@ -335,9 +344,7 @@ func (provider *ReplicateProvider) listDeploymentsByKey(ctx *schemas.BifrostCont if err := sonic.Unmarshal(bodyCopy, &pageResponse); err != nil { return nil, providerUtils.NewBifrostOperationError( "failed to parse deployments response", - err, - schemas.Replicate, - ) + err) } // Append results from this page @@ -362,6 +369,7 @@ func (provider *ReplicateProvider) listDeploymentsByKey(ctx *schemas.BifrostCont providerName, key.Models, key.BlacklistedModels, + key.Aliases, request.Unfiltered, ) @@ -375,11 +383,10 @@ func (provider *ReplicateProvider) ListModels(ctx 
*schemas.BifrostContext, keys } if provider.networkConfig.BaseURL == "" { - return nil, providerUtils.NewConfigurationError("base_url is not set", provider.GetProviderKey()) + return nil, providerUtils.NewConfigurationError("base_url is not set") } startTime := time.Now() - providerName := provider.GetProviderKey() response, err := providerUtils.HandleMultipleListModelsRequests( ctx, @@ -393,8 +400,6 @@ func (provider *ReplicateProvider) ListModels(ctx *schemas.BifrostContext, keys // Update metadata with total latency latency := time.Since(startTime) - response.ExtraFields.Provider = providerName - response.ExtraFields.RequestType = schemas.ListModelsRequest response.ExtraFields.Latency = latency.Milliseconds() return response, nil @@ -406,17 +411,11 @@ func (provider *ReplicateProvider) TextCompletion(ctx *schemas.BifrostContext, k return nil, err } - deployment, isDeployment := resolveDeploymentModel(request.Model, key) - if isDeployment { - request.Model = deployment - } - // build replicate request jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, request, - func() (providerUtils.RequestBodyWithExtraParams, error) { return ToReplicateTextRequest(request) }, - provider.GetProviderKey()) + func() (providerUtils.RequestBodyWithExtraParams, error) { return ToReplicateTextRequest(request) }) if bifrostErr != nil { return nil, bifrostErr } @@ -431,7 +430,7 @@ func (provider *ReplicateProvider) TextCompletion(ctx *schemas.BifrostContext, k request.Model, provider.customProviderConfig, schemas.TextCompletionRequest, - isDeployment, + key.ReplicateKeyConfig.UseDeploymentsEndpoint, ) // create prediction @@ -480,10 +479,7 @@ func (provider *ReplicateProvider) TextCompletion(ctx *schemas.BifrostContext, k bifrostResponse := prediction.ToBifrostTextCompletionResponse() // Set extra fields - bifrostResponse.ExtraFields.Provider = schemas.Replicate - bifrostResponse.ExtraFields.RequestType = schemas.TextCompletionRequest 
bifrostResponse.ExtraFields.Latency = latency.Milliseconds() - bifrostResponse.ExtraFields.ModelRequested = request.Model bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { providerUtils.ParseAndSetRawRequest(&bifrostResponse.ExtraFields, jsonData) @@ -503,11 +499,6 @@ func (provider *ReplicateProvider) TextCompletionStream(ctx *schemas.BifrostCont return nil, err } - deployment, isDeployment := resolveDeploymentModel(request.Model, key) - if isDeployment { - request.Model = deployment - } - // Convert Bifrost request to Replicate format with streaming enabled jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, @@ -519,8 +510,7 @@ func (provider *ReplicateProvider) TextCompletionStream(ctx *schemas.BifrostCont } replicateReq.Stream = schemas.Ptr(true) return replicateReq, nil - }, - provider.GetProviderKey()) + }) if bifrostErr != nil { return nil, bifrostErr } @@ -532,7 +522,7 @@ func (provider *ReplicateProvider) TextCompletionStream(ctx *schemas.BifrostCont request.Model, provider.customProviderConfig, schemas.TextCompletionStreamRequest, - isDeployment, + key.ReplicateKeyConfig.UseDeploymentsEndpoint, ) // Create prediction @@ -556,9 +546,7 @@ func (provider *ReplicateProvider) TextCompletionStream(ctx *schemas.BifrostCont if prediction.URLs == nil || prediction.URLs.Stream == nil || *prediction.URLs.Stream == "" { bifrostErr := providerUtils.NewBifrostOperationError( "stream URL not available in prediction response", - fmt.Errorf("prediction response missing stream URL"), - provider.GetProviderKey(), - ) + fmt.Errorf("prediction response missing stream URL")) return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } @@ -589,9 +577,9 @@ func (provider *ReplicateProvider) TextCompletionStream(ctx *schemas.BifrostCont go func() { defer func() { if ctx.Err() == 
context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.GetProviderKey(), request.Model, schemas.TextCompletionStreamRequest, provider.logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.logger) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.GetProviderKey(), request.Model, schemas.TextCompletionStreamRequest, provider.logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.logger) } close(responseChan) }() @@ -636,7 +624,7 @@ func (provider *ReplicateProvider) TextCompletionStream(ctx *schemas.BifrostCont } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) provider.logger.Warn("Error reading stream: %v", readErr) - enrichedErr := providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, readErr, provider.GetProviderKey()), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) + enrichedErr := providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, readErr), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, enrichedErr, responseChan, provider.logger) } break @@ -667,11 +655,8 @@ func (provider *ReplicateProvider) TextCompletionStream(ctx *schemas.BifrostCont }, }, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.TextCompletionStreamRequest, - Provider: provider.GetProviderKey(), - ModelRequested: request.Model, - ChunkIndex: chunkIndex, - Latency: time.Since(lastChunkTime).Milliseconds(), + ChunkIndex: chunkIndex, + Latency: time.Since(lastChunkTime).Milliseconds(), }, } @@ -705,14 +690,7 @@ func (provider *ReplicateProvider) TextCompletionStream(ctx *schemas.BifrostCont case "canceled": bifrostErr := providerUtils.NewBifrostOperationError( 
"prediction was canceled", - fmt.Errorf("stream ended: prediction canceled"), - provider.GetProviderKey(), - ) - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: provider.GetProviderKey(), - ModelRequested: request.Model, - RequestType: schemas.TextCompletionStreamRequest, - } + fmt.Errorf("stream ended: prediction canceled")) ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) enrichedErr := providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, enrichedErr, responseChan, provider.logger) @@ -727,14 +705,7 @@ func (provider *ReplicateProvider) TextCompletionStream(ctx *schemas.BifrostCont } bifrostErr := providerUtils.NewBifrostOperationError( errorMsg, - fmt.Errorf("stream ended with error"), - provider.GetProviderKey(), - ) - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: provider.GetProviderKey(), - ModelRequested: request.Model, - RequestType: schemas.TextCompletionStreamRequest, - } + fmt.Errorf("stream ended with error")) ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) enrichedErr := providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, enrichedErr, responseChan, provider.logger) @@ -750,10 +721,7 @@ func (provider *ReplicateProvider) TextCompletionStream(ctx *schemas.BifrostCont nil, // usage - not available in done event finishReason, chunkIndex, - schemas.TextCompletionStreamRequest, - provider.GetProviderKey(), - request.Model, - ) + schemas.TextCompletionStreamRequest) // Set raw request if enabled if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { @@ -781,17 +749,11 @@ func (provider *ReplicateProvider) ChatCompletion(ctx *schemas.BifrostContext, k return nil, err } - deployment, isDeployment := 
resolveDeploymentModel(request.Model, key) - if isDeployment { - request.Model = deployment - } - // build replicate request jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, request, - func() (providerUtils.RequestBodyWithExtraParams, error) { return ToReplicateChatRequest(request) }, - provider.GetProviderKey()) + func() (providerUtils.RequestBodyWithExtraParams, error) { return ToReplicateChatRequest(request) }) if bifrostErr != nil { return nil, bifrostErr } @@ -806,7 +768,7 @@ func (provider *ReplicateProvider) ChatCompletion(ctx *schemas.BifrostContext, k request.Model, provider.customProviderConfig, schemas.ChatCompletionRequest, - isDeployment, + key.ReplicateKeyConfig.UseDeploymentsEndpoint, ) // create prediction @@ -855,10 +817,7 @@ func (provider *ReplicateProvider) ChatCompletion(ctx *schemas.BifrostContext, k bifrostResponse := prediction.ToBifrostChatResponse() // Set extra fields - bifrostResponse.ExtraFields.Provider = schemas.Replicate - bifrostResponse.ExtraFields.RequestType = schemas.ChatCompletionRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() - bifrostResponse.ExtraFields.ModelRequested = request.Model bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { providerUtils.ParseAndSetRawRequest(&bifrostResponse.ExtraFields, jsonData) @@ -878,11 +837,6 @@ func (provider *ReplicateProvider) ChatCompletionStream(ctx *schemas.BifrostCont return nil, err } - deployment, isDeployment := resolveDeploymentModel(request.Model, key) - if isDeployment { - request.Model = deployment - } - // Convert Bifrost request to Replicate format with streaming enabled jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, @@ -894,8 +848,7 @@ func (provider *ReplicateProvider) ChatCompletionStream(ctx *schemas.BifrostCont } replicateReq.Stream = schemas.Ptr(true) return replicateReq, nil - }, - 
provider.GetProviderKey()) + }) if bifrostErr != nil { return nil, bifrostErr } @@ -907,7 +860,7 @@ func (provider *ReplicateProvider) ChatCompletionStream(ctx *schemas.BifrostCont request.Model, provider.customProviderConfig, schemas.ChatCompletionStreamRequest, - isDeployment, + key.ReplicateKeyConfig.UseDeploymentsEndpoint, ) // Create prediction @@ -931,9 +884,7 @@ func (provider *ReplicateProvider) ChatCompletionStream(ctx *schemas.BifrostCont if prediction.URLs == nil || prediction.URLs.Stream == nil || *prediction.URLs.Stream == "" { bifrostErr := providerUtils.NewBifrostOperationError( "stream URL not available in prediction response", - fmt.Errorf("prediction response missing stream URL"), - provider.GetProviderKey(), - ) + fmt.Errorf("prediction response missing stream URL")) return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } @@ -964,9 +915,9 @@ func (provider *ReplicateProvider) ChatCompletionStream(ctx *schemas.BifrostCont go func() { defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.GetProviderKey(), request.Model, schemas.ChatCompletionStreamRequest, provider.logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.logger) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.GetProviderKey(), request.Model, schemas.ChatCompletionStreamRequest, provider.logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.logger) } close(responseChan) }() @@ -1011,7 +962,7 @@ func (provider *ReplicateProvider) ChatCompletionStream(ctx *schemas.BifrostCont } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) provider.logger.Warn("Error reading stream: %v", readErr) - enrichedErr := providerUtils.EnrichError(ctx, 
providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, readErr, provider.GetProviderKey()), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) + enrichedErr := providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, readErr), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, enrichedErr, responseChan, provider.logger) } break @@ -1049,11 +1000,8 @@ func (provider *ReplicateProvider) ChatCompletionStream(ctx *schemas.BifrostCont }, }, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ChatCompletionStreamRequest, - Provider: provider.GetProviderKey(), - ModelRequested: request.Model, - ChunkIndex: chunkIndex, - Latency: time.Since(lastChunkTime).Milliseconds(), + ChunkIndex: chunkIndex, + Latency: time.Since(lastChunkTime).Milliseconds(), }, } @@ -1087,14 +1035,7 @@ func (provider *ReplicateProvider) ChatCompletionStream(ctx *schemas.BifrostCont case "canceled": bifrostErr := providerUtils.NewBifrostOperationError( "prediction was canceled", - fmt.Errorf("stream ended: prediction canceled"), - provider.GetProviderKey(), - ) - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: provider.GetProviderKey(), - ModelRequested: request.Model, - RequestType: schemas.ChatCompletionStreamRequest, - } + fmt.Errorf("stream ended: prediction canceled")) ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) enrichedErr := providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, enrichedErr, responseChan, provider.logger) @@ -1109,14 +1050,7 @@ func (provider *ReplicateProvider) ChatCompletionStream(ctx *schemas.BifrostCont } bifrostErr := providerUtils.NewBifrostOperationError( errorMsg, - fmt.Errorf("stream ended with error"), - 
provider.GetProviderKey(), - ) - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: provider.GetProviderKey(), - ModelRequested: request.Model, - RequestType: schemas.ChatCompletionStreamRequest, - } + fmt.Errorf("stream ended with error")) ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) enrichedErr := providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, enrichedErr, responseChan, provider.logger) @@ -1142,11 +1076,8 @@ func (provider *ReplicateProvider) ChatCompletionStream(ctx *schemas.BifrostCont }, }, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ChatCompletionStreamRequest, - Provider: provider.GetProviderKey(), - ModelRequested: request.Model, - ChunkIndex: chunkIndex, - Latency: time.Since(startTime).Milliseconds(), + ChunkIndex: chunkIndex, + Latency: time.Since(startTime).Milliseconds(), }, } @@ -1174,17 +1105,11 @@ func (provider *ReplicateProvider) Responses(ctx *schemas.BifrostContext, key sc return nil, err } - deployment, isDeployment := resolveDeploymentModel(request.Model, key) - if isDeployment { - request.Model = deployment - } - // build replicate request jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, request, - func() (providerUtils.RequestBodyWithExtraParams, error) { return ToReplicateResponsesRequest(request) }, - provider.GetProviderKey()) + func() (providerUtils.RequestBodyWithExtraParams, error) { return ToReplicateResponsesRequest(request) }) if bifrostErr != nil { return nil, bifrostErr } @@ -1199,7 +1124,7 @@ func (provider *ReplicateProvider) Responses(ctx *schemas.BifrostContext, key sc request.Model, provider.customProviderConfig, schemas.ResponsesRequest, - isDeployment, + key.ReplicateKeyConfig.UseDeploymentsEndpoint, ) // create prediction @@ -1246,9 +1171,6 @@ func (provider *ReplicateProvider) Responses(ctx 
*schemas.BifrostContext, key sc // Convert to Bifrost response response := prediction.ToBifrostResponsesResponse() - response.ExtraFields.RequestType = schemas.ResponsesRequest - response.ExtraFields.Provider = provider.GetProviderKey() - response.ExtraFields.ModelRequested = request.Model response.ExtraFields.Latency = latency.Milliseconds() response.ExtraFields.ProviderResponseHeaders = providerResponseHeaders if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { @@ -1266,24 +1188,18 @@ func (provider *ReplicateProvider) ResponsesStream(ctx *schemas.BifrostContext, return nil, err } - deployment, isDeployment := resolveDeploymentModel(request.Model, key) - if isDeployment { - request.Model = deployment - } - // Build replicate request jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, request, - func() (providerUtils.RequestBodyWithExtraParams, error) { return ToReplicateResponsesRequest(request) }, - provider.GetProviderKey()) + func() (providerUtils.RequestBodyWithExtraParams, error) { return ToReplicateResponsesRequest(request) }) if bifrostErr != nil { return nil, bifrostErr } // Enable streaming (using sjson to set field directly, preserving key order) if updatedData, err := providerUtils.SetJSONField(jsonData, "stream", true); err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to set stream field", err, provider.GetProviderKey()) + return nil, providerUtils.NewBifrostOperationError("failed to set stream field", err) } else { jsonData = updatedData } @@ -1295,7 +1211,7 @@ func (provider *ReplicateProvider) ResponsesStream(ctx *schemas.BifrostContext, request.Model, provider.customProviderConfig, schemas.ResponsesStreamRequest, - isDeployment, + key.ReplicateKeyConfig.UseDeploymentsEndpoint, ) // Create prediction @@ -1319,9 +1235,7 @@ func (provider *ReplicateProvider) ResponsesStream(ctx *schemas.BifrostContext, if prediction.URLs == nil || prediction.URLs.Stream == nil || 
*prediction.URLs.Stream == "" { bifrostErr := providerUtils.NewBifrostOperationError( "stream URL not available in prediction response", - fmt.Errorf("prediction response missing stream URL"), - provider.GetProviderKey(), - ) + fmt.Errorf("prediction response missing stream URL")) return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } @@ -1360,9 +1274,9 @@ func (provider *ReplicateProvider) ResponsesStream(ctx *schemas.BifrostContext, }, jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } if errors.Is(streamErr, fasthttp.ErrTimeout) || errors.Is(streamErr, context.DeadlineExceeded) { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, streamErr, provider.GetProviderKey()), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, streamErr), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, streamErr, provider.GetProviderKey()), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, streamErr), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } // Extract provider response headers before status check so error responses also forward them @@ -1391,9 +1305,9 @@ func (provider *ReplicateProvider) ResponsesStream(ctx *schemas.BifrostContext, go func() { defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.GetProviderKey(), request.Model, schemas.ResponsesStreamRequest, provider.logger) + 
providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.logger) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.GetProviderKey(), request.Model, schemas.ResponsesStreamRequest, provider.logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.logger) } close(responseChan) }() @@ -1405,10 +1319,8 @@ func (provider *ReplicateProvider) ResponsesStream(ctx *schemas.BifrostContext, if reader == nil { bifrostErr := providerUtils.NewBifrostOperationError( - "Provider returned an empty response", - fmt.Errorf("provider returned an empty response"), - provider.GetProviderKey(), - ) + "provider returned an empty response", + fmt.Errorf("provider returned an empty response")) ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse), responseChan, provider.logger) return @@ -1455,7 +1367,7 @@ func (provider *ReplicateProvider) ResponsesStream(ctx *schemas.BifrostContext, } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) provider.logger.Warn("Error reading stream: %v", readErr) - bifrostErr := providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, readErr, provider.GetProviderKey()) + bifrostErr := providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, readErr) // Include accumulated raw responses in error if sendBackRawResponse && len(rawResponseChunks) > 0 { @@ -1497,11 +1409,8 @@ func (provider *ReplicateProvider) ResponsesStream(ctx *schemas.BifrostContext, CreatedAt: int(startTime.Unix()), }, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ResponsesStreamRequest, - Provider: provider.GetProviderKey(), - ModelRequested: request.Model, - Latency: time.Since(startTime).Milliseconds(), - 
ChunkIndex: sequenceNumber, + Latency: time.Since(startTime).Milliseconds(), + ChunkIndex: sequenceNumber, }, } if sendBackRawRequest { @@ -1524,10 +1433,7 @@ func (provider *ReplicateProvider) ResponsesStream(ctx *schemas.BifrostContext, CreatedAt: int(startTime.Unix()), }, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ResponsesStreamRequest, - Provider: provider.GetProviderKey(), - ModelRequested: request.Model, - ChunkIndex: sequenceNumber, + ChunkIndex: sequenceNumber, }, } providerUtils.ProcessAndSendResponse(ctx, postHookRunner, @@ -1556,10 +1462,7 @@ func (provider *ReplicateProvider) ResponsesStream(ctx *schemas.BifrostContext, }, }, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ResponsesStreamRequest, - Provider: provider.GetProviderKey(), - ModelRequested: request.Model, - ChunkIndex: sequenceNumber, + ChunkIndex: sequenceNumber, }, } providerUtils.ProcessAndSendResponse(ctx, postHookRunner, @@ -1587,10 +1490,7 @@ func (provider *ReplicateProvider) ResponsesStream(ctx *schemas.BifrostContext, }, }, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ResponsesStreamRequest, - Provider: provider.GetProviderKey(), - ModelRequested: request.Model, - ChunkIndex: sequenceNumber, + ChunkIndex: sequenceNumber, }, } providerUtils.ProcessAndSendResponse(ctx, postHookRunner, @@ -1610,10 +1510,7 @@ func (provider *ReplicateProvider) ResponsesStream(ctx *schemas.BifrostContext, Delta: schemas.Ptr(currentEvent.Data), LogProbs: []schemas.ResponsesOutputMessageContentTextLogProb{}, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ResponsesStreamRequest, - Provider: provider.GetProviderKey(), - ModelRequested: request.Model, - ChunkIndex: sequenceNumber, + ChunkIndex: sequenceNumber, }, } providerUtils.ProcessAndSendResponse(ctx, postHookRunner, @@ -1639,10 +1536,7 @@ func (provider *ReplicateProvider) ResponsesStream(ctx *schemas.BifrostContext, ItemID: schemas.Ptr(itemID), 
LogProbs: []schemas.ResponsesOutputMessageContentTextLogProb{}, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ResponsesStreamRequest, - Provider: provider.GetProviderKey(), - ModelRequested: request.Model, - ChunkIndex: sequenceNumber, + ChunkIndex: sequenceNumber, }, } providerUtils.ProcessAndSendResponse(ctx, postHookRunner, @@ -1665,10 +1559,7 @@ func (provider *ReplicateProvider) ResponsesStream(ctx *schemas.BifrostContext, }, }, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ResponsesStreamRequest, - Provider: provider.GetProviderKey(), - ModelRequested: request.Model, - ChunkIndex: sequenceNumber, + ChunkIndex: sequenceNumber, }, } providerUtils.ProcessAndSendResponse(ctx, postHookRunner, @@ -1702,10 +1593,7 @@ func (provider *ReplicateProvider) ResponsesStream(ctx *schemas.BifrostContext, }, }, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ResponsesStreamRequest, - Provider: provider.GetProviderKey(), - ModelRequested: request.Model, - ChunkIndex: sequenceNumber, + ChunkIndex: sequenceNumber, }, } providerUtils.ProcessAndSendResponse(ctx, postHookRunner, @@ -1725,11 +1613,8 @@ func (provider *ReplicateProvider) ResponsesStream(ctx *schemas.BifrostContext, CompletedAt: schemas.Ptr(int(time.Now().Unix())), }, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ResponsesStreamRequest, - Provider: provider.GetProviderKey(), - ModelRequested: request.Model, - Latency: time.Since(startTime).Milliseconds(), - ChunkIndex: sequenceNumber, + Latency: time.Since(startTime).Milliseconds(), + ChunkIndex: sequenceNumber, }, } @@ -1762,14 +1647,7 @@ func (provider *ReplicateProvider) ResponsesStream(ctx *schemas.BifrostContext, } bifrostErr := providerUtils.NewBifrostOperationError( errorMsg, - fmt.Errorf("stream error: %s", errorMsg), - provider.GetProviderKey(), - ) - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: provider.GetProviderKey(), - 
ModelRequested: request.Model, - RequestType: schemas.ResponsesStreamRequest, - } + fmt.Errorf("stream error: %s", errorMsg)) // Include accumulated raw responses in error if sendBackRawResponse && len(rawResponseChunks) > 0 { @@ -1825,19 +1703,13 @@ func (provider *ReplicateProvider) ImageGeneration(ctx *schemas.BifrostContext, return nil, err } - deployment, isDeployment := resolveDeploymentModel(request.Model, key) - if isDeployment { - request.Model = deployment - } - // Convert Bifrost request to Replicate format jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToReplicateImageGenerationInput(request), nil - }, - provider.GetProviderKey()) + }) if bifrostErr != nil { return nil, bifrostErr } @@ -1852,7 +1724,7 @@ func (provider *ReplicateProvider) ImageGeneration(ctx *schemas.BifrostContext, request.Model, provider.customProviderConfig, schemas.ImageGenerationRequest, - isDeployment, + key.ReplicateKeyConfig.UseDeploymentsEndpoint, ) // Create prediction with appropriate mode @@ -1904,10 +1776,7 @@ func (provider *ReplicateProvider) ImageGeneration(ctx *schemas.BifrostContext, } // Set extra fields - bifrostResponse.ExtraFields.Provider = schemas.Replicate - bifrostResponse.ExtraFields.RequestType = schemas.ImageGenerationRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() - bifrostResponse.ExtraFields.ModelRequested = request.Model bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { providerUtils.ParseAndSetRawRequest(&bifrostResponse.ExtraFields, jsonData) @@ -1926,15 +1795,9 @@ func (provider *ReplicateProvider) ImageGenerationStream(ctx *schemas.BifrostCon return nil, err } - providerName := provider.GetProviderKey() sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) sendBackRawResponse := 
providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) - deployment, isDeployment := resolveDeploymentModel(request.Model, key) - if isDeployment { - request.Model = deployment - } - // Convert Bifrost request to Replicate format with streaming enabled jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, @@ -1943,8 +1806,7 @@ func (provider *ReplicateProvider) ImageGenerationStream(ctx *schemas.BifrostCon replicateReq := ToReplicateImageGenerationInput(request) replicateReq.Stream = schemas.Ptr(true) return replicateReq, nil - }, - providerName) + }) if bifrostErr != nil { return nil, bifrostErr } @@ -1956,7 +1818,7 @@ func (provider *ReplicateProvider) ImageGenerationStream(ctx *schemas.BifrostCon request.Model, provider.customProviderConfig, schemas.ImageGenerationStreamRequest, - isDeployment, + key.ReplicateKeyConfig.UseDeploymentsEndpoint, ) // Create prediction prediction, _, _, _, err := createPrediction( @@ -1977,10 +1839,16 @@ func (provider *ReplicateProvider) ImageGenerationStream(ctx *schemas.BifrostCon // Verify stream URL is available if prediction.URLs == nil || prediction.URLs.Stream == nil || *prediction.URLs.Stream == "" { - return nil, providerUtils.NewBifrostOperationError( - "stream URL not available in prediction response", - fmt.Errorf("prediction response missing stream URL"), - providerName, + return nil, providerUtils.EnrichError( + ctx, + providerUtils.NewBifrostOperationError( + "stream URL not available in prediction response", + fmt.Errorf("prediction response missing stream URL"), + ), + jsonData, + nil, + sendBackRawRequest, + sendBackRawResponse, ) } @@ -2011,9 +1879,9 @@ func (provider *ReplicateProvider) ImageGenerationStream(ctx *schemas.BifrostCon go func() { defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.ImageGenerationStreamRequest, provider.logger) + 
providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.logger) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.ImageGenerationStreamRequest, provider.logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.logger) } close(responseChan) }() @@ -2060,7 +1928,8 @@ func (provider *ReplicateProvider) ImageGenerationStream(ctx *schemas.BifrostCon } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) provider.logger.Warn(fmt.Sprintf("Error reading SSE stream: %v", readErr)) - providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, schemas.ImageGenerationStreamRequest, providerName, request.Model, provider.logger) + enrichedErr := providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, readErr), jsonData, nil, sendBackRawRequest, sendBackRawResponse) + providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, enrichedErr, responseChan, provider.logger) } break } @@ -2105,11 +1974,8 @@ func (provider *ReplicateProvider) ImageGenerationStream(ctx *schemas.BifrostCon CreatedAt: time.Now().Unix(), OutputFormat: outputFormat, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ImageGenerationStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - ChunkIndex: chunkIndex, - Latency: time.Since(lastChunkTime).Milliseconds(), + ChunkIndex: chunkIndex, + Latency: time.Since(lastChunkTime).Milliseconds(), }, } @@ -2143,36 +2009,24 @@ func (provider *ReplicateProvider) ImageGenerationStream(ctx *schemas.BifrostCon case "canceled": bifrostErr := providerUtils.NewBifrostOperationError( "prediction was canceled", - fmt.Errorf("stream ended: prediction canceled"), - providerName, - ) - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - 
RequestType: schemas.ImageGenerationStreamRequest, - } + fmt.Errorf("stream ended: prediction canceled")) // Include accumulated raw responses in error if sendBackRawResponse && len(rawResponseChunks) > 0 { bifrostErr.ExtraFields.RawResponse = rawResponseChunks } + bifrostErr = providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, sendBackRawRequest, sendBackRawResponse) ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger) return case "error": bifrostErr := providerUtils.NewBifrostOperationError( "prediction failed", - fmt.Errorf("stream ended with error"), - providerName, - ) - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.ImageGenerationStreamRequest, - } + fmt.Errorf("stream ended with error")) // Include accumulated raw responses in error if sendBackRawResponse && len(rawResponseChunks) > 0 { bifrostErr.ExtraFields.RawResponse = rawResponseChunks } + bifrostErr = providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, sendBackRawRequest, sendBackRawResponse) ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger) return @@ -2187,11 +2041,8 @@ func (provider *ReplicateProvider) ImageGenerationStream(ctx *schemas.BifrostCon OutputFormat: lastOutputFormat, // Include output format CreatedAt: time.Now().Unix(), ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ImageGenerationStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - ChunkIndex: chunkIndex, - Latency: time.Since(startTime).Milliseconds(), + ChunkIndex: chunkIndex, + Latency: time.Since(startTime).Milliseconds(), }, } @@ -2233,17 +2084,13 @@ func (provider *ReplicateProvider) ImageGenerationStream(ctx *schemas.BifrostCon Error: 
&schemas.ErrorField{ Message: errorMsg, }, - ExtraFields: schemas.BifrostErrorExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.ImageGenerationStreamRequest, - }, } // Include accumulated raw responses in error if sendBackRawResponse { rawResponseChunks = append(rawResponseChunks, ReplicateSSEEvent{Event: eventType, Data: eventData}) bifrostErr.ExtraFields.RawResponse = rawResponseChunks } + bifrostErr = providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, sendBackRawRequest, sendBackRawResponse) ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger) return @@ -2260,19 +2107,13 @@ func (provider *ReplicateProvider) ImageEdit(ctx *schemas.BifrostContext, key sc return nil, err } - deployment, isDeployment := resolveDeploymentModel(request.Model, key) - if isDeployment { - request.Model = deployment - } - // Convert Bifrost request to Replicate format jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToReplicateImageEditInput(request), nil - }, - provider.GetProviderKey()) + }) if bifrostErr != nil { return nil, bifrostErr } @@ -2287,7 +2128,7 @@ func (provider *ReplicateProvider) ImageEdit(ctx *schemas.BifrostContext, key sc request.Model, provider.customProviderConfig, schemas.ImageEditRequest, - isDeployment, + key.ReplicateKeyConfig.UseDeploymentsEndpoint, ) // Create prediction with appropriate mode @@ -2339,10 +2180,7 @@ func (provider *ReplicateProvider) ImageEdit(ctx *schemas.BifrostContext, key sc } // Set extra fields - bifrostResponse.ExtraFields.Provider = schemas.Replicate - bifrostResponse.ExtraFields.RequestType = schemas.ImageEditRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() - bifrostResponse.ExtraFields.ModelRequested = request.Model 
bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { providerUtils.ParseAndSetRawRequest(&bifrostResponse.ExtraFields, jsonData) @@ -2361,15 +2199,9 @@ func (provider *ReplicateProvider) ImageEditStream(ctx *schemas.BifrostContext, return nil, err } - providerName := provider.GetProviderKey() sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) sendBackRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) - deployment, isDeployment := resolveDeploymentModel(request.Model, key) - if isDeployment { - request.Model = deployment - } - // Convert Bifrost request to Replicate format with streaming enabled jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, @@ -2378,8 +2210,7 @@ func (provider *ReplicateProvider) ImageEditStream(ctx *schemas.BifrostContext, replicateReq := ToReplicateImageEditInput(request) replicateReq.Stream = schemas.Ptr(true) return replicateReq, nil - }, - providerName) + }) if bifrostErr != nil { return nil, bifrostErr } @@ -2391,7 +2222,7 @@ func (provider *ReplicateProvider) ImageEditStream(ctx *schemas.BifrostContext, request.Model, provider.customProviderConfig, schemas.ImageEditStreamRequest, - isDeployment, + key.ReplicateKeyConfig.UseDeploymentsEndpoint, ) // Create prediction @@ -2413,10 +2244,16 @@ func (provider *ReplicateProvider) ImageEditStream(ctx *schemas.BifrostContext, // Verify stream URL is available if prediction.URLs == nil || prediction.URLs.Stream == nil || *prediction.URLs.Stream == "" { - return nil, providerUtils.NewBifrostOperationError( - "stream URL not available in prediction response", - fmt.Errorf("prediction response missing stream URL"), - providerName, + return nil, providerUtils.EnrichError( + ctx, + providerUtils.NewBifrostOperationError( + "stream URL not available in prediction response", + 
fmt.Errorf("prediction response missing stream URL"), + ), + jsonData, + nil, + sendBackRawRequest, + sendBackRawResponse, ) } @@ -2447,9 +2284,9 @@ func (provider *ReplicateProvider) ImageEditStream(ctx *schemas.BifrostContext, go func() { defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.ImageEditStreamRequest, provider.logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.logger) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.ImageEditStreamRequest, provider.logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.logger) } close(responseChan) }() @@ -2494,18 +2331,9 @@ func (provider *ReplicateProvider) ImageEditStream(ctx *schemas.BifrostContext, if errors.Is(readErr, context.Canceled) { return } - bifrostErr := providerUtils.NewBifrostOperationError( - "stream read error", - readErr, - providerName, - ) - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.ImageEditStreamRequest, - } + enrichedErr := providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("stream read error", readErr), jsonData, nil, sendBackRawRequest, sendBackRawResponse) ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger) + providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, enrichedErr, responseChan, provider.logger) } break } @@ -2548,11 +2376,8 @@ func (provider *ReplicateProvider) ImageEditStream(ctx *schemas.BifrostContext, CreatedAt: time.Now().Unix(), OutputFormat: outputFormat, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: 
schemas.ImageEditStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - ChunkIndex: chunkIndex, - Latency: time.Since(lastChunkTime).Milliseconds(), + ChunkIndex: chunkIndex, + Latency: time.Since(lastChunkTime).Milliseconds(), }, } @@ -2586,34 +2411,22 @@ func (provider *ReplicateProvider) ImageEditStream(ctx *schemas.BifrostContext, case "canceled": bifrostErr := providerUtils.NewBifrostOperationError( "prediction was canceled", - fmt.Errorf("stream ended: prediction canceled"), - providerName, - ) - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.ImageEditStreamRequest, - } + fmt.Errorf("stream ended: prediction canceled")) if sendBackRawResponse && len(rawResponseChunks) > 0 { bifrostErr.ExtraFields.RawResponse = rawResponseChunks } + bifrostErr = providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, sendBackRawRequest, sendBackRawResponse) ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger) return case "error": bifrostErr := providerUtils.NewBifrostOperationError( "prediction failed", - fmt.Errorf("stream ended with error"), - providerName, - ) - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.ImageEditStreamRequest, - } + fmt.Errorf("stream ended with error")) if sendBackRawResponse && len(rawResponseChunks) > 0 { bifrostErr.ExtraFields.RawResponse = rawResponseChunks } + bifrostErr = providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, sendBackRawRequest, sendBackRawResponse) ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger) return @@ -2628,11 +2441,8 @@ func (provider *ReplicateProvider) ImageEditStream(ctx 
*schemas.BifrostContext, CreatedAt: time.Now().Unix(), OutputFormat: lastOutputFormat, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ImageEditStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - ChunkIndex: chunkIndex, - Latency: time.Since(startTime).Milliseconds(), + ChunkIndex: chunkIndex, + Latency: time.Since(startTime).Milliseconds(), }, } @@ -2660,18 +2470,12 @@ func (provider *ReplicateProvider) ImageEditStream(ctx *schemas.BifrostContext, bifrostErr := providerUtils.NewBifrostOperationError( "stream error", - fmt.Errorf("%s", errorData.Detail), - providerName, - ) - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.ImageEditStreamRequest, - } + fmt.Errorf("%s", errorData.Detail)) if sendBackRawResponse { rawResponseChunks = append(rawResponseChunks, ReplicateSSEEvent{Event: eventType, Data: eventData}) bifrostErr.ExtraFields.RawResponse = rawResponseChunks } + bifrostErr = providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, sendBackRawRequest, sendBackRawResponse) ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger) return @@ -2693,21 +2497,13 @@ func (provider *ReplicateProvider) VideoGeneration(ctx *schemas.BifrostContext, return nil, err } - deployment, isDeployment := resolveDeploymentModel(request.Model, key) - if isDeployment { - request.Model = deployment - } - - providerName := provider.GetProviderKey() - // Convert Bifrost request to Replicate format jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToReplicateVideoGenerationInput(request) - }, - providerName) + }) if bifrostErr != nil { return nil, bifrostErr } @@ -2719,7 +2515,7 @@ func (provider *ReplicateProvider) VideoGeneration(ctx 
*schemas.BifrostContext, request.Model, provider.customProviderConfig, schemas.VideoGenerationRequest, - isDeployment, + key.ReplicateKeyConfig.UseDeploymentsEndpoint, ) // Create prediction with appropriate mode @@ -2748,13 +2544,10 @@ func (provider *ReplicateProvider) VideoGeneration(ctx *schemas.BifrostContext, if err != nil { return nil, providerUtils.EnrichError(ctx, err, jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } - bifrostResponse.ID = providerUtils.AddVideoIDProviderSuffix(bifrostResponse.ID, providerName) + bifrostResponse.ID = providerUtils.AddVideoIDProviderSuffix(bifrostResponse.ID, schemas.Replicate) // Set extra fields - bifrostResponse.ExtraFields.Provider = providerName - bifrostResponse.ExtraFields.RequestType = schemas.VideoGenerationRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() - bifrostResponse.ExtraFields.ModelRequested = request.Model bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { providerUtils.ParseAndSetRawRequest(&bifrostResponse.ExtraFields, jsonData) @@ -2774,7 +2567,7 @@ func (provider *ReplicateProvider) VideoRetrieve(ctx *schemas.BifrostContext, ke providerName := provider.GetProviderKey() if request.ID == "" { - return nil, providerUtils.NewBifrostOperationError("video_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("video_id is required", nil) } videoID := providerUtils.StripVideoIDProviderSuffix(request.ID, providerName) @@ -2816,7 +2609,7 @@ func (provider *ReplicateProvider) VideoRetrieve(ctx *schemas.BifrostContext, ke body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } sendBackRawResponse := 
providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) @@ -2828,12 +2621,10 @@ func (provider *ReplicateProvider) VideoRetrieve(ctx *schemas.BifrostContext, ke bifrostResponse, convertErr := ToBifrostVideoGenerationResponse(&prediction) if convertErr != nil { - return nil, providerUtils.EnrichError(ctx, convertErr, nil, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, convertErr, nil, body, provider.sendBackRawRequest, provider.sendBackRawResponse) } bifrostResponse.ID = providerUtils.AddVideoIDProviderSuffix(bifrostResponse.ID, providerName) - bifrostResponse.ExtraFields.Provider = providerName - bifrostResponse.ExtraFields.RequestType = schemas.VideoRetrieveRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders if sendBackRawResponse { @@ -2848,9 +2639,8 @@ func (provider *ReplicateProvider) VideoDownload(ctx *schemas.BifrostContext, ke if err := providerUtils.CheckOperationAllowed(schemas.Replicate, provider.customProviderConfig, schemas.VideoDownloadRequest); err != nil { return nil, err } - providerName := provider.GetProviderKey() if request.ID == "" { - return nil, providerUtils.NewBifrostOperationError("video_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("video_id is required", nil) } // Retrieve latest status/output first. 
bifrostVideoRetrieveRequest := &schemas.BifrostVideoRetrieveRequest{ @@ -2864,19 +2654,17 @@ func (provider *ReplicateProvider) VideoDownload(ctx *schemas.BifrostContext, ke if videoResp.Status != schemas.VideoStatusCompleted { return nil, providerUtils.NewBifrostOperationError( fmt.Sprintf("video not ready, current status: %s", videoResp.Status), - nil, - providerName, - ) + nil) } if len(videoResp.Videos) == 0 { - return nil, providerUtils.NewBifrostOperationError("video URL not available", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("video URL not available", nil) } var videoUrl string if videoResp.Videos[0].URL != nil { videoUrl = *videoResp.Videos[0].URL } if videoUrl == "" { - return nil, providerUtils.NewBifrostOperationError("invalid video output type", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid video output type", nil) } req := fasthttp.AcquireRequest() resp := fasthttp.AcquireResponse() @@ -2896,9 +2684,7 @@ func (provider *ReplicateProvider) VideoDownload(ctx *schemas.BifrostContext, ke if resp.StatusCode() != fasthttp.StatusOK { return nil, providerUtils.NewBifrostOperationError( fmt.Sprintf("failed to download video: HTTP %d", resp.StatusCode()), - nil, - providerName, - ) + nil) } providerResponseHeaders := providerUtils.ExtractProviderResponseHeaders(resp) @@ -2906,7 +2692,7 @@ func (provider *ReplicateProvider) VideoDownload(ctx *schemas.BifrostContext, ke body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } contentType := string(resp.Header.ContentType()) if contentType == "" { @@ -2920,8 +2706,6 @@ func (provider *ReplicateProvider) VideoDownload(ctx *schemas.BifrostContext, ke } bifrostResp.ExtraFields.Latency = latency.Milliseconds() - bifrostResp.ExtraFields.Provider 
= providerName - bifrostResp.ExtraFields.RequestType = schemas.VideoDownloadRequest bifrostResp.ExtraFields.ProviderResponseHeaders = providerResponseHeaders return bifrostResp, nil @@ -2977,7 +2761,7 @@ func (provider *ReplicateProvider) FileUpload(ctx *schemas.BifrostContext, key s providerName := provider.GetProviderKey() if len(request.File) == 0 { - return nil, providerUtils.NewBifrostOperationError("file content is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("file content is required", nil) } // Create multipart form data @@ -3014,22 +2798,22 @@ func (provider *ReplicateProvider) FileUpload(ctx *schemas.BifrostContext, key s part, err := writer.CreatePart(h) if err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to create form file", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to create form file", err) } if _, err := part.Write(request.File); err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to write file content", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to write file content", err) } // Add filename field if provided if filename != "" { if err := writer.WriteField("filename", filename); err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to write filename field", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to write filename field", err) } } // Add type field (content type) if err := writer.WriteField("type", contentType); err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to write type field", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to write type field", err) } // Add metadata field if provided @@ -3038,24 +2822,24 @@ func (provider *ReplicateProvider) FileUpload(ctx *schemas.BifrostContext, key s if len(metadata) > 0 { metadataJSON, err := providerUtils.MarshalSorted(metadata) if 
err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to marshal metadata", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to marshal metadata", err) } h := make(textproto.MIMEHeader) h.Set("Content-Disposition", `form-data; name="metadata"`) h.Set("Content-Type", "application/json") metadataPart, err := writer.CreatePart(h) if err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to create metadata part", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to create metadata part", err) } if _, err := metadataPart.Write(metadataJSON); err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to write metadata", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to write metadata", err) } } } } if err := writer.Close(); err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to close multipart writer", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to close multipart writer", err) } // Create request @@ -3091,7 +2875,7 @@ func (provider *ReplicateProvider) FileUpload(ctx *schemas.BifrostContext, key s body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } var replicateResp ReplicateFileResponse @@ -3119,7 +2903,7 @@ func (provider *ReplicateProvider) FileList(ctx *schemas.BifrostContext, keys [] // Initialize serial pagination helper (Replicate uses cursor-based pagination) helper, err := providerUtils.NewSerialListHelper(keys, request.After, provider.logger) if err != nil { - return nil, providerUtils.NewBifrostOperationError("invalid pagination cursor", err, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid pagination cursor", err) } 
// Get current key to query @@ -3130,10 +2914,6 @@ func (provider *ReplicateProvider) FileList(ctx *schemas.BifrostContext, keys [] Object: "list", Data: []schemas.FileObject{}, HasMore: false, - ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileListRequest, - Provider: providerName, - }, }, nil } @@ -3182,7 +2962,7 @@ func (provider *ReplicateProvider) FileList(ctx *schemas.BifrostContext, keys [] body, decodeErr := providerUtils.CheckAndDecodeBody(resp) if decodeErr != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, decodeErr, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, decodeErr) } var replicateResp ReplicateFileListResponse @@ -3226,8 +3006,6 @@ func (provider *ReplicateProvider) FileList(ctx *schemas.BifrostContext, keys [] Data: files, HasMore: finalHasMore, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileListRequest, - Provider: providerName, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerResponseHeaders, }, @@ -3244,7 +3022,7 @@ func (provider *ReplicateProvider) FileRetrieve(ctx *schemas.BifrostContext, key providerName := provider.GetProviderKey() if request.FileID == "" { - return nil, providerUtils.NewBifrostOperationError("file_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("file_id is required", nil) } sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) @@ -3289,7 +3067,7 @@ func (provider *ReplicateProvider) FileRetrieve(ctx *schemas.BifrostContext, key if err != nil { fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) - lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) continue } @@ -3321,7 +3099,7 @@ func (provider *ReplicateProvider) 
FileDelete(ctx *schemas.BifrostContext, keys providerName := provider.GetProviderKey() if request.FileID == "" { - return nil, providerUtils.NewBifrostOperationError("file_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("file_id is required", nil) } sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) @@ -3364,8 +3142,6 @@ func (provider *ReplicateProvider) FileDelete(ctx *schemas.BifrostContext, keys Object: "file", Deleted: true, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileDeleteRequest, - Provider: providerName, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerResponseHeaders, }, @@ -3386,7 +3162,7 @@ func (provider *ReplicateProvider) FileDelete(ctx *schemas.BifrostContext, keys if err != nil { fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) - lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) continue } @@ -3411,8 +3187,6 @@ func (provider *ReplicateProvider) FileDelete(ctx *schemas.BifrostContext, keys Object: "file", Deleted: true, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileDeleteRequest, - Provider: providerName, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerResponseHeaders, }, diff --git a/core/providers/replicate/utils.go b/core/providers/replicate/utils.go index 3279b0a847..1d88337539 100644 --- a/core/providers/replicate/utils.go +++ b/core/providers/replicate/utils.go @@ -31,17 +31,13 @@ func checkForErrorStatus(prediction *ReplicatePredictionResponse) *schemas.Bifro } return providerUtils.NewBifrostOperationError( "prediction failed", - fmt.Errorf("%s", errorMsg), - schemas.Replicate, - ) + fmt.Errorf("%s", errorMsg)) } if prediction.Status == ReplicatePredictionStatusCanceled { return 
providerUtils.NewBifrostOperationError( "prediction was canceled", - fmt.Errorf("prediction was canceled"), - schemas.Replicate, - ) + fmt.Errorf("prediction was canceled")) } return nil @@ -126,9 +122,9 @@ func listenToReplicateStreamURL( } } if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) { - return nil, nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, schemas.Replicate) + return nil, nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err) } - return nil, nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, schemas.Replicate) + return nil, nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err) } // Extract provider response headers before status check so error responses also forward them @@ -178,24 +174,12 @@ func isVersionID(s string) bool { return versionIDPattern.MatchString(s) } -// resolveDeploymentModel checks if the model maps to a deployment. -// Returns the resolved model and whether it is a deployment. -func resolveDeploymentModel(model string, key schemas.Key) (string, bool) { - if key.ReplicateKeyConfig == nil || key.ReplicateKeyConfig.Deployments == nil { - return model, false - } - if deployment, ok := key.ReplicateKeyConfig.Deployments[model]; ok && deployment != "" { - return deployment, true - } - return model, false -} - // buildPredictionURL builds the appropriate URL for creating a prediction // Returns the URL for the appropriate prediction endpoint. 
-func buildPredictionURL(ctx *schemas.BifrostContext, baseURL, model string, customProviderConfig *schemas.CustomProviderConfig, requestType schemas.RequestType, isDeployment bool) string { +func buildPredictionURL(ctx *schemas.BifrostContext, baseURL, model string, customProviderConfig *schemas.CustomProviderConfig, requestType schemas.RequestType, useDeploymentsEndpoint bool) string { var defaultPath string - if isDeployment { + if useDeploymentsEndpoint { defaultPath = "/v1/deployments/" + model + "/predictions" } else if isVersionID(model) { // If model is a version ID, use base predictions endpoint diff --git a/core/providers/replicate/videos.go b/core/providers/replicate/videos.go index b6dadaab55..3a277d067d 100644 --- a/core/providers/replicate/videos.go +++ b/core/providers/replicate/videos.go @@ -87,9 +87,6 @@ func ToBifrostVideoGenerationResponse(prediction *ReplicatePredictionResponse) ( Error: &schemas.ErrorField{ Message: "prediction response is nil", }, - ExtraFields: schemas.BifrostErrorExtraFields{ - Provider: schemas.Replicate, - }, } } diff --git a/core/providers/runway/errors.go b/core/providers/runway/errors.go index a64f8ffc60..d9259e825f 100644 --- a/core/providers/runway/errors.go +++ b/core/providers/runway/errors.go @@ -9,7 +9,7 @@ import ( ) // parseRunwayError parses Runway API error responses and converts them to BifrostError. 
-func parseRunwayError(resp *fasthttp.Response, meta *providerUtils.RequestMetadata) *schemas.BifrostError { +func parseRunwayError(resp *fasthttp.Response) *schemas.BifrostError { // Parse as RunwayAPIError var errorResp RunwayAPIError bifrostErr := providerUtils.HandleProviderAPIError(resp, &errorResp) @@ -34,12 +34,5 @@ func parseRunwayError(resp *fasthttp.Response, meta *providerUtils.RequestMetada bifrostErr.Error.Message = strings.TrimRight(bifrostErr.Error.Message, "\n") } - // Set metadata - if meta != nil { - bifrostErr.ExtraFields.Provider = meta.Provider - bifrostErr.ExtraFields.ModelRequested = meta.Model - bifrostErr.ExtraFields.RequestType = meta.RequestType - } - return bifrostErr } diff --git a/core/providers/runway/runway.go b/core/providers/runway/runway.go index d512742afd..a4d95c2bf8 100644 --- a/core/providers/runway/runway.go +++ b/core/providers/runway/runway.go @@ -165,8 +165,7 @@ func (provider *RunwayProvider) VideoGeneration(ctx *schemas.BifrostContext, key bifrostReq, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToRunwayVideoGenerationRequest(bifrostReq) - }, - provider.GetProviderKey()) + }) if bifrostErr != nil { return nil, bifrostErr } @@ -205,17 +204,14 @@ func (provider *RunwayProvider) VideoGeneration(ctx *schemas.BifrostContext, key // Handle error response if resp.StatusCode() != fasthttp.StatusOK { - return nil, providerUtils.EnrichError(ctx, parseRunwayError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - Model: model, - RequestType: schemas.VideoGenerationRequest, - }), jsonData, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, parseRunwayError(resp), jsonData, nil, sendBackRawRequest, sendBackRawResponse) } // Decode response body body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + rawErrBody := append([]byte(nil), 
resp.Body()...) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err), jsonData, rawErrBody, sendBackRawRequest, sendBackRawResponse) } // Parse response @@ -232,10 +228,7 @@ func (provider *RunwayProvider) VideoGeneration(ctx *schemas.BifrostContext, key Object: "video", Status: schemas.VideoStatusQueued, ExtraFields: schemas.BifrostResponseExtraFields{ - Latency: latency.Milliseconds(), - Provider: providerName, - ModelRequested: model, - RequestType: schemas.VideoGenerationRequest, + Latency: latency.Milliseconds(), }, } @@ -282,16 +275,14 @@ func (provider *RunwayProvider) VideoRetrieve(ctx *schemas.BifrostContext, key s // Handle error response if resp.StatusCode() != fasthttp.StatusOK { - return nil, providerUtils.EnrichError(ctx, parseRunwayError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - RequestType: schemas.VideoRetrieveRequest, - }), nil, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, parseRunwayError(resp), nil, nil, sendBackRawRequest, sendBackRawResponse) } // Decode response body body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + rawErrBody := append([]byte(nil), resp.Body()...) 
+ return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err), nil, rawErrBody, sendBackRawRequest, sendBackRawResponse) } // Parse response @@ -309,8 +300,6 @@ func (provider *RunwayProvider) VideoRetrieve(ctx *schemas.BifrostContext, key s bifrostResp.ID = providerUtils.AddVideoIDProviderSuffix(bifrostResp.ID, providerName) bifrostResp.ExtraFields.Latency = latency.Milliseconds() - bifrostResp.ExtraFields.Provider = providerName - bifrostResp.ExtraFields.RequestType = schemas.VideoRetrieveRequest if sendBackRawRequest { bifrostResp.ExtraFields.RawRequest = rawRequest @@ -324,7 +313,6 @@ func (provider *RunwayProvider) VideoRetrieve(ctx *schemas.BifrostContext, key s // VideoDownload retrieves a video from Runway's API. func (provider *RunwayProvider) VideoDownload(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostVideoDownloadRequest) (*schemas.BifrostVideoDownloadResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() // Retrieve task status to get the video URL bifrostVideoRetrieveRequest := &schemas.BifrostVideoRetrieveRequest{ Provider: request.Provider, @@ -338,20 +326,21 @@ func (provider *RunwayProvider) VideoDownload(ctx *schemas.BifrostContext, key s if taskDetails.Status != schemas.VideoStatusCompleted { return nil, providerUtils.NewBifrostOperationError( fmt.Sprintf("video not ready, current status: %s", taskDetails.Status), - nil, - providerName, - ) + nil) } if len(taskDetails.Videos) == 0 { - return nil, providerUtils.NewBifrostOperationError("video URL not available", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("video URL not available", nil) } var videoUrl string if taskDetails.Videos[0].URL != nil { videoUrl = *taskDetails.Videos[0].URL } if videoUrl == "" { - return nil, providerUtils.NewBifrostOperationError("invalid video output type", nil, providerName) + return nil, 
providerUtils.NewBifrostOperationError("invalid video output type", nil) } + sendBackRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) + sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) + // Download video from Runway's URL req := fasthttp.AcquireRequest() resp := fasthttp.AcquireResponse() @@ -367,14 +356,13 @@ func (provider *RunwayProvider) VideoDownload(ctx *schemas.BifrostContext, key s if resp.StatusCode() != fasthttp.StatusOK { return nil, providerUtils.NewBifrostOperationError( fmt.Sprintf("failed to download video: HTTP %d", resp.StatusCode()), - nil, - providerName, - ) + nil) } // Get content and content type body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + rawErrBody := append([]byte(nil), resp.Body()...) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err), nil, rawErrBody, sendBackRawRequest, sendBackRawResponse) } contentType := string(resp.Header.ContentType()) if contentType == "" { @@ -389,8 +377,6 @@ func (provider *RunwayProvider) VideoDownload(ctx *schemas.BifrostContext, key s } bifrostResp.ExtraFields.Latency = latency.Milliseconds() - bifrostResp.ExtraFields.Provider = providerName - bifrostResp.ExtraFields.RequestType = schemas.VideoDownloadRequest return bifrostResp, nil } @@ -402,7 +388,7 @@ func (provider *RunwayProvider) VideoDelete(ctx *schemas.BifrostContext, key sch providerName := provider.GetProviderKey() if request.ID == "" { - return nil, providerUtils.NewBifrostOperationError("task_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("task_id is required", nil) } taskID := providerUtils.StripVideoIDProviderSuffix(request.ID, providerName) @@ -434,10 +420,7 @@ func (provider *RunwayProvider) VideoDelete(ctx 
*schemas.BifrostContext, key sch // Handle error response - Runway returns 204 No Content on success if resp.StatusCode() != fasthttp.StatusNoContent { - return nil, providerUtils.EnrichError(ctx, parseRunwayError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - RequestType: schemas.VideoDeleteRequest, - }), nil, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, parseRunwayError(resp), nil, nil, sendBackRawRequest, sendBackRawResponse) } // Build response - Runway returns empty body on 204 @@ -448,8 +431,6 @@ func (provider *RunwayProvider) VideoDelete(ctx *schemas.BifrostContext, key sch } response.ExtraFields.Latency = latency.Milliseconds() - response.ExtraFields.Provider = providerName - response.ExtraFields.RequestType = schemas.VideoDeleteRequest return response, nil } diff --git a/core/providers/runway/videos.go b/core/providers/runway/videos.go index 49b0cec237..809a8a1038 100644 --- a/core/providers/runway/videos.go +++ b/core/providers/runway/videos.go @@ -121,7 +121,7 @@ func ToRunwayVideoGenerationRequest(bifrostReq *schemas.BifrostVideoGenerationRe // ToBifrostVideoGenerationResponse converts Runway task details to Bifrost video generation response format. 
func ToBifrostVideoGenerationResponse(taskDetails *RunwayTaskDetailsResponse) (*schemas.BifrostVideoGenerationResponse, *schemas.BifrostError) { if taskDetails == nil { - return nil, providerUtils.NewBifrostOperationError("task details is nil", nil, schemas.Runway) + return nil, providerUtils.NewBifrostOperationError("task details is nil", nil) } response := &schemas.BifrostVideoGenerationResponse{ diff --git a/core/providers/sgl/sgl.go b/core/providers/sgl/sgl.go index 66db709a7b..25a1375d06 100644 --- a/core/providers/sgl/sgl.go +++ b/core/providers/sgl/sgl.go @@ -64,40 +64,13 @@ func (provider *SGLProvider) GetProviderKey() schemas.ModelProvider { return schemas.SGL } -// getBaseURL resolves the base URL for a request from the per-key sgl_key_config. -// Each SGL key must have its own URL configured — there is no provider-level fallback. -func (provider *SGLProvider) getBaseURL(key schemas.Key) string { - if key.SGLKeyConfig != nil && key.SGLKeyConfig.URL.GetValue() != "" { - return strings.TrimRight(key.SGLKeyConfig.URL.GetValue(), "/") - } - return "" -} - -// baseURLOrError returns the resolved base URL or a BifrostError when none is configured. -func (provider *SGLProvider) baseURLOrError(key schemas.Key) (string, *schemas.BifrostError) { - u := provider.getBaseURL(key) - if u == "" { - return "", providerUtils.NewBifrostOperationError( - "no base URL configured: set sgl_key_config.url on the key", - nil, - provider.GetProviderKey(), - ) - } - return u, nil -} - // listModelsByKey performs a list models request for a single SGL key, // resolving the per-key URL so each backend is queried individually. 
func (provider *SGLProvider) listModelsByKey(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { - baseURL, bifrostErr := provider.baseURLOrError(key) - if bifrostErr != nil { - return nil, bifrostErr - } - url := baseURL + providerUtils.GetPathFromContext(ctx, "/v1/models") return openai.ListModelsByKey( ctx, provider.client, - url, + key.SGLKeyConfig.URL.GetValue()+providerUtils.GetPathFromContext(ctx, "/v1/models"), key, request.Unfiltered, provider.networkConfig.ExtraHeaders, @@ -121,14 +94,10 @@ func (provider *SGLProvider) ListModels(ctx *schemas.BifrostContext, keys []sche // TextCompletion performs a text completion request to the SGL API. func (provider *SGLProvider) TextCompletion(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (*schemas.BifrostTextCompletionResponse, *schemas.BifrostError) { - baseURL, bifrostErr := provider.baseURLOrError(key) - if bifrostErr != nil { - return nil, bifrostErr - } return openai.HandleOpenAITextCompletionRequest( ctx, provider.client, - baseURL+providerUtils.GetPathFromContext(ctx, "/v1/completions"), + key.SGLKeyConfig.URL.GetValue()+providerUtils.GetPathFromContext(ctx, "/v1/completions"), request, key, provider.networkConfig.ExtraHeaders, @@ -145,14 +114,10 @@ func (provider *SGLProvider) TextCompletion(ctx *schemas.BifrostContext, key sch // It formats the request, sends it to SGL, and processes the response. // Returns a channel of BifrostStreamChunk objects or an error if the request fails. 
func (provider *SGLProvider) TextCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { - baseURL, bifrostErr := provider.baseURLOrError(key) - if bifrostErr != nil { - return nil, bifrostErr - } return openai.HandleOpenAITextCompletionStreaming( ctx, provider.client, - baseURL+providerUtils.GetPathFromContext(ctx, "/v1/completions"), + key.SGLKeyConfig.URL.GetValue()+providerUtils.GetPathFromContext(ctx, "/v1/completions"), request, nil, provider.networkConfig.ExtraHeaders, @@ -169,14 +134,10 @@ func (provider *SGLProvider) TextCompletionStream(ctx *schemas.BifrostContext, p // ChatCompletion performs a chat completion request to the SGL API. func (provider *SGLProvider) ChatCompletion(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) { - baseURL, bifrostErr := provider.baseURLOrError(key) - if bifrostErr != nil { - return nil, bifrostErr - } return openai.HandleOpenAIChatCompletionRequest( ctx, provider.client, - baseURL+providerUtils.GetPathFromContext(ctx, "/v1/chat/completions"), + key.SGLKeyConfig.URL.GetValue()+providerUtils.GetPathFromContext(ctx, "/v1/chat/completions"), request, key, provider.networkConfig.ExtraHeaders, @@ -194,15 +155,11 @@ func (provider *SGLProvider) ChatCompletion(ctx *schemas.BifrostContext, key sch // Uses SGL's OpenAI-compatible streaming format. // Returns a channel containing BifrostResponse objects representing the stream or an error if the request fails. 
func (provider *SGLProvider) ChatCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { - baseURL, bifrostErr := provider.baseURLOrError(key) - if bifrostErr != nil { - return nil, bifrostErr - } // Use shared OpenAI-compatible streaming logic return openai.HandleOpenAIChatCompletionStreaming( ctx, provider.client, - baseURL+providerUtils.GetPathFromContext(ctx, "/v1/chat/completions"), + key.SGLKeyConfig.URL.GetValue()+providerUtils.GetPathFromContext(ctx, "/v1/chat/completions"), request, nil, provider.networkConfig.ExtraHeaders, @@ -227,9 +184,6 @@ func (provider *SGLProvider) Responses(ctx *schemas.BifrostContext, key schemas. } response := chatResponse.ToBifrostResponsesResponse() - response.ExtraFields.RequestType = schemas.ResponsesRequest - response.ExtraFields.Provider = provider.GetProviderKey() - response.ExtraFields.ModelRequested = request.Model return response, nil } @@ -247,14 +201,10 @@ func (provider *SGLProvider) ResponsesStream(ctx *schemas.BifrostContext, postHo // Embedding performs an embedding request to the SGL API. 
func (provider *SGLProvider) Embedding(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostEmbeddingRequest) (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) { - baseURL, bifrostErr := provider.baseURLOrError(key) - if bifrostErr != nil { - return nil, bifrostErr - } return openai.HandleOpenAIEmbeddingRequest( ctx, provider.client, - baseURL+providerUtils.GetPathFromContext(ctx, "/v1/embeddings"), + key.SGLKeyConfig.URL.GetValue()+providerUtils.GetPathFromContext(ctx, "/v1/embeddings"), request, key, provider.networkConfig.ExtraHeaders, @@ -458,4 +408,4 @@ func (provider *SGLProvider) Passthrough(_ *schemas.BifrostContext, _ schemas.Ke func (provider *SGLProvider) PassthroughStream(_ *schemas.BifrostContext, _ schemas.PostHookRunner, _ schemas.Key, _ *schemas.BifrostPassthroughRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.PassthroughStreamRequest, provider.GetProviderKey()) -} \ No newline at end of file +} diff --git a/core/providers/utils/large_response.go b/core/providers/utils/large_response.go index a7e0e7bf36..e62d375c9a 100644 --- a/core/providers/utils/large_response.go +++ b/core/providers/utils/large_response.go @@ -116,7 +116,6 @@ func MaterializeStreamErrorBody(ctx *schemas.BifrostContext, resp *fasthttp.Resp func FinalizeResponseWithLargeDetection( ctx *schemas.BifrostContext, resp *fasthttp.Response, - providerName schemas.ModelProvider, logger schemas.Logger, ) ([]byte, bool, *schemas.BifrostError) { responseThreshold, _ := ctx.Value(schemas.BifrostContextKeyLargeResponseThreshold).(int64) @@ -125,7 +124,7 @@ func FinalizeResponseWithLargeDetection( if responseThreshold <= 0 { body, err := CheckAndDecodeBody(resp) if err != nil { - return nil, false, NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + return nil, false, NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } // Copy body before 
caller releases resp return append([]byte(nil), body...), false, nil @@ -142,14 +141,14 @@ func FinalizeResponseWithLargeDetection( } bodyBytes, readErr := io.ReadAll(reader) if readErr != nil { - return nil, false, NewBifrostOperationError(schemas.ErrProviderResponseDecode, readErr, providerName) + return nil, false, NewBifrostOperationError(schemas.ErrProviderResponseDecode, readErr) } return bodyBytes, false, nil } // No stream — buffered fallback body, err := CheckAndDecodeBody(resp) if err != nil { - return nil, false, NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + return nil, false, NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } return append([]byte(nil), body...), false, nil } @@ -169,7 +168,7 @@ func FinalizeResponseWithLargeDetection( bodyBytes, readErr := io.ReadAll(io.LimitReader(reader, responseThreshold+1)) if readErr != nil { releaseGzip() - return nil, false, NewBifrostOperationError(schemas.ErrProviderResponseDecode, readErr, providerName) + return nil, false, NewBifrostOperationError(schemas.ErrProviderResponseDecode, readErr) } if int64(len(bodyBytes)) <= responseThreshold { releaseGzip() @@ -195,7 +194,7 @@ func FinalizeResponseWithLargeDetection( // No stream — buffered fallback body, err := CheckAndDecodeBody(resp) if err != nil { - return nil, false, NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + return nil, false, NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } return append([]byte(nil), body...), false, nil } @@ -206,11 +205,11 @@ func FinalizeResponseWithLargeDetection( if bodyStream == nil { // No stream available — fall back to buffered read if logger != nil { - logger.Warn("large-response fallback to buffered path: provider=%s content_length=%d threshold=%d body_stream_nil=true", providerName, contentLength, responseThreshold) + logger.Warn("large-response fallback to buffered path: content_length=%d threshold=%d 
body_stream_nil=true", contentLength, responseThreshold) } body, err := CheckAndDecodeBody(resp) if err != nil { - return nil, false, NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + return nil, false, NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } return append([]byte(nil), body...), false, nil } @@ -232,7 +231,7 @@ func FinalizeResponseWithLargeDetection( if wasGzip { ReleaseGzipReader(gz) } - return nil, false, NewBifrostOperationError(schemas.ErrProviderResponseDecode, readErr, providerName) + return nil, false, NewBifrostOperationError(schemas.ErrProviderResponseDecode, readErr) } prefetchBuf = prefetchBuf[:n] diff --git a/core/providers/utils/make_request_test.go b/core/providers/utils/make_request_test.go index ce1610d7bb..ec2bf771bc 100644 --- a/core/providers/utils/make_request_test.go +++ b/core/providers/utils/make_request_test.go @@ -295,7 +295,7 @@ func TestMakeRequestWithContext_ConcurrentRequestsWithCancellation(t *testing.T) } func TestNewBifrostTimeoutError(t *testing.T) { - err := NewBifrostTimeoutError("test timeout", context.DeadlineExceeded, "openai") + err := NewBifrostTimeoutError("test timeout", context.DeadlineExceeded) if !err.IsBifrostError { t.Fatal("expected IsBifrostError to be true") diff --git a/core/providers/utils/models.go b/core/providers/utils/models.go new file mode 100644 index 0000000000..a494f4fd69 --- /dev/null +++ b/core/providers/utils/models.go @@ -0,0 +1,351 @@ +// Package utils — models.go +// Centralised pipeline for filtering and backfilling models in ListModels responses. +// +// Every provider's ToBifrostListModelsResponse follows the same logical steps: +// 1. Resolve each API model's name (alias lookup → alias key; else raw model ID) +// 2. Filter (allowlist + blacklist check on the resolved name) +// 3. 
Backfill entries that were not returned by the API but should appear in output +// +// Providers plug in custom MatchFns to extend the default matching behaviour. +// Example: Bedrock adds region-prefix-aware matching on top of DefaultMatchFns. +package utils + +import ( + "sort" + "strings" + + "github.com/maximhq/bifrost/core/schemas" + "golang.org/x/text/cases" + "golang.org/x/text/language" +) + +// ToDisplayName converts a raw model ID or alias key into a human-readable display name. +// Splits on "-" or "_", title-cases each word, and joins with spaces. +// +// "gemini-pro" → "Gemini Pro" +// "claude_3_opus" → "Claude 3 Opus" +// "gpt-4-turbo" → "Gpt 4 Turbo" +func ToDisplayName(id string) string { + caser := cases.Title(language.English) + parts := strings.FieldsFunc(id, func(r rune) bool { + return r == '-' || r == '_' + }) + if len(parts) == 0 { + return "" + } + for i, part := range parts { + if part != "" { + parts[i] = caser.String(strings.ToLower(part)) + } + } + return strings.Join(parts, " ") +} + +// MatchFn reports whether two model ID strings should be treated as equivalent. +// Functions are applied in order during every comparison — the first one that +// returns true short-circuits the rest. +// +// Example built-in fns (see DefaultMatchFns): +// +// exactMatch("gpt-4", "gpt-4") → true +// sameBaseModel("claude-3-5-sonnet-20241022", "claude-3-5") → true +type MatchFn func(a, b string) bool + +// DefaultMatchFns returns the standard matching functions used by most providers. +// Currently only performs case-insensitive exact matching. +// +// SameBaseModel (strips version suffixes, e.g. "claude-3-5-sonnet-20241022" ≈ "claude-3-5-sonnet") +// is intentionally excluded — users should use aliases for explicit version-to-base-name mapping. +// It can be appended here if fuzzy base-model matching is ever needed globally. 
+func DefaultMatchFns() []MatchFn { + return []MatchFn{ + func(a, b string) bool { return strings.EqualFold(a, b) }, + } +} + +// matches reports whether a and b are considered equal by any of the provided fns. +// Returns true on the first fn that returns true. +func matches(a, b string, fns []MatchFn) bool { + for _, fn := range fns { + if fn(a, b) { + return true + } + } + return false +} + +// FilterResult is the outcome of running ListModelsPipeline.FilterModel for a single model +// from the provider's API response. Each returned result represents one alias +// entry (or the raw model ID when no alias matched) that passed all filters. +type FilterResult struct { + // ResolvedID is the user-facing model name to use as the ID suffix. + // If the model matched an alias VALUE, this is the alias KEY. + // Otherwise this is the original model ID from the API response. + // + // Example: API returns "gpt-4-turbo", aliases={"my-gpt4":"gpt-4-turbo"} + // → ResolvedID = "my-gpt4" + // Example: API returns "gpt-3.5-turbo", no alias match + // → ResolvedID = "gpt-3.5-turbo" + ResolvedID string + + // AliasValue is the provider-specific model ID when the model was matched + // via an alias. Set as the model.Alias field so callers know the underlying ID. + // Empty when the model was matched directly (no alias involved). + // + // Example: API returns "gpt-4-turbo", alias key "my-gpt4" matched + // → AliasValue = "gpt-4-turbo" + AliasValue string +} + +// ListModelsPipeline holds all the context needed to filter and backfill models in a +// single ListModels response. Construct one per ToBifrostListModelsResponse call +// and use its methods instead of passing params + matchFns to every function. 
+// +// pipeline := &providerUtils.ListModelsPipeline{ +// AllowedModels: key.Models, +// BlacklistedModels: key.BlacklistedModels, +// Aliases: key.Aliases, +// Unfiltered: request.Unfiltered, +// ProviderKey: schemas.OpenAI, +// MatchFns: providerUtils.DefaultMatchFns(), +// } +// if pipeline.ShouldEarlyExit() { return empty } +// result := pipeline.FilterModel(model.ID) +// pipeline.BackfillModels(included) +type ListModelsPipeline struct { + AllowedModels schemas.WhiteList + BlacklistedModels schemas.BlackList + // Aliases maps user-facing alias keys to provider-specific model IDs. + // e.g. {"my-gpt4": "gpt-4-turbo-2024-04-09"} + Aliases map[string]string + Unfiltered bool + ProviderKey schemas.ModelProvider + // MatchFns is the ordered list of equivalence functions used for every + // model ID comparison. Use DefaultMatchFns() for standard behaviour; + // providers may append additional fns (e.g. Bedrock's region-prefix remover). + MatchFns []MatchFn +} + +// ShouldEarlyExit reports whether ToBifrostListModelsResponse should immediately +// return an empty response without processing any models. +// +// Returns true when: +// - not unfiltered AND allowlist is empty AND no aliases configured +// (there is nothing to match against — all models would be filtered out anyway) +// - not unfiltered AND blacklist blocks everything +// +// Note: allowlist empty + aliases present → do NOT early exit. +// The aliases drive backfill in the wildcard-allowlist case (Case B of BackfillModels). +func (p *ListModelsPipeline) ShouldEarlyExit() bool { + if p.Unfiltered { + return false + } + if p.BlacklistedModels.IsBlockAll() { + return true + } + if p.AllowedModels.IsEmpty() && len(p.Aliases) == 0 { + return true + } + return false +} + +// aliasMatch holds a single alias key/value pair returned by resolveModelID. +type aliasMatch struct { + key string + value string +} + +// resolveModelID returns all alias entries whose VALUE matches modelID using the pipeline's MatchFns. 
+// Results are sorted by alias key (case-insensitive) for deterministic ordering. +// +// If one or more aliases match → returns one aliasMatch per matching alias key. +// +// Example: modelID="gpt-4-turbo", aliases={"my-gpt4":"gpt-4-turbo","gpt4-alias":"gpt-4-turbo"} +// → [{key:"gpt4-alias", value:"gpt-4-turbo"}, {key:"my-gpt4", value:"gpt-4-turbo"}] +// +// If no alias matches → returns a single entry with the original model ID and no alias value. +// +// Example: modelID="gpt-3.5-turbo", no alias match +// → [{key:"gpt-3.5-turbo", value:""}] +func (p *ListModelsPipeline) resolveModelID(modelID string) []aliasMatch { + var candidates []aliasMatch + for aliasKey, providerID := range p.Aliases { + if matches(modelID, providerID, p.MatchFns) { + candidates = append(candidates, aliasMatch{key: aliasKey, value: providerID}) + } + } + if len(candidates) == 0 { + return []aliasMatch{{key: modelID, value: ""}} + } + sort.Slice(candidates, func(i, j int) bool { + return strings.ToLower(candidates[i].key) < strings.ToLower(candidates[j].key) + }) + return candidates +} + +// FilterModel applies the full filter pipeline for a single model from the API response. +// +// Steps: +// 1. Resolve name — check alias VALUES for a match (uses MatchFns). +// If matched: resolvedName = alias KEY, aliasValue = provider ID. +// If not matched: resolvedName = original modelID, aliasValue = "". +// 2. Allowlist check (only when allowlist is restricted, i.e. not wildcard): +// Skip if resolvedName is not in AllowedModels. +// 3. Blacklist check (always): +// Skip if resolvedName is blacklisted. Blacklist takes precedence over everything. +// 4. Return one FilterResult per passing candidate. +// +// An empty slice means the model should be skipped entirely. +// When multiple aliases map to the same provider model ID, each alias that passes +// the filters produces its own FilterResult entry. 
+// +// Examples: +// +// allowedModels=["my-gpt4"], aliases={"my-gpt4":"gpt-4-turbo"}, blacklist=[] +// FilterModel("gpt-4-turbo") → [{ResolvedID:"my-gpt4", AliasValue:"gpt-4-turbo"}] +// FilterModel("gpt-3.5") → [] (not in allowlist) +// +// allowedModels=*, aliases={"my-gpt4":"gpt-4-turbo","gpt4-alias":"gpt-4-turbo"}, blacklist=[] +// FilterModel("gpt-4-turbo") → [{ResolvedID:"gpt4-alias", AliasValue:"gpt-4-turbo"}, +// {ResolvedID:"my-gpt4", AliasValue:"gpt-4-turbo"}] +// +// allowedModels=["gpt-3.5"], aliases={}, blacklist=[] +// FilterModel("gpt-3.5") → [{ResolvedID:"gpt-3.5", AliasValue:""}] +// FilterModel("gpt-4") → [] +func (p *ListModelsPipeline) FilterModel(modelID string) []FilterResult { + // Step 1: resolve name — collect all alias matches (or the raw ID if none match). + candidates := p.resolveModelID(modelID) + + var results []FilterResult + for _, candidate := range candidates { + resolvedName := candidate.key + + // Step 2: allowlist check. + // IsRestricted() is true for both an explicit list AND an empty list (deny-all). + // Only a wildcard allowlist marker bypasses this check (pass-through). + if !p.Unfiltered && p.AllowedModels.IsRestricted() { + allowed := false + for _, entry := range p.AllowedModels { + if matches(resolvedName, entry, p.MatchFns) { + allowed = true + break + } + } + if !allowed { + continue + } + } + + // Step 3: blacklist check — blacklist always wins regardless of allowlist or aliases. + if !p.Unfiltered { + blacklisted := false + for _, entry := range p.BlacklistedModels { + if matches(resolvedName, entry, p.MatchFns) { + blacklisted = true + break + } + } + if blacklisted { + continue + } + } + + results = append(results, FilterResult{ + ResolvedID: resolvedName, + AliasValue: candidate.value, + }) + } + return results +} + +// BackfillModels adds model entries that were configured by the caller but not +// returned by the provider's API response (or not matched during filtering). 
+// +// The `included` map tracks model IDs (lowercased) already added during the +// filter pass, used to avoid duplicates. +// +// Two cases depending on whether the allowlist is restricted: +// +// Case A — allowlist restricted (caller specified explicit model names): +// +// Add each allowlist entry that is not yet in `included`, skip if blacklisted. +// If the entry has an alias mapping (aliases[entry] exists), set Alias to the +// provider-specific ID so callers can route to the right model. +// +// Example: allowedModels=["my-gpt4","gpt-3.5"], aliases={"my-gpt4":"gpt-4-turbo"} +// "my-gpt4" not in included → add {ID:"openai/my-gpt4", Alias:"gpt-4-turbo"} +// "gpt-3.5" not in included → add {ID:"openai/gpt-3.5"} +// +// Case B — allowlist wildcard (*) only: +// +// We don't know all model names (no explicit list), so we only backfill entries +// that were explicitly configured via aliases and not yet matched from the API. +// Note: an empty allowlist is deny-all (IsRestricted()==true), not wildcard. +// +// Example: aliases={"my-gpt4":"gpt-4-turbo"}, "my-gpt4" not in included +// → add {ID:"openai/my-gpt4", Alias:"gpt-4-turbo"} +// +// Blacklist always wins — nothing blacklisted is added in either case. +func (p *ListModelsPipeline) BackfillModels(included map[string]bool) []schemas.Model { + var result []schemas.Model + + if !p.Unfiltered && p.AllowedModels.IsRestricted() { + // Case A: backfill explicit allowlist entries not yet matched. + for _, entry := range p.AllowedModels { + if included[strings.ToLower(entry)] { + continue + } + // Blacklist check. + blacklisted := false + for _, bl := range p.BlacklistedModels { + if matches(entry, bl, p.MatchFns) { + blacklisted = true + break + } + } + if blacklisted { + continue + } + m := schemas.Model{ + ID: string(p.ProviderKey) + "/" + entry, + Name: schemas.Ptr(ToDisplayName(entry)), + } + // If this allowlist entry has an alias, surface the provider-specific ID. 
+ for aliasKey, providerID := range p.Aliases { + if matches(entry, aliasKey, p.MatchFns) { + m.Alias = schemas.Ptr(providerID) + break + } + } + result = append(result, m) + } + return result + } + + // Case B: wildcard allowlist — backfill only explicitly configured aliases. + if !p.Unfiltered && len(p.Aliases) > 0 { + for aliasKey, providerID := range p.Aliases { + if included[strings.ToLower(aliasKey)] { + continue + } + // Blacklist check. + blacklisted := false + for _, bl := range p.BlacklistedModels { + if matches(aliasKey, bl, p.MatchFns) { + blacklisted = true + break + } + } + if blacklisted { + continue + } + result = append(result, schemas.Model{ + ID: string(p.ProviderKey) + "/" + aliasKey, + Name: schemas.Ptr(ToDisplayName(aliasKey)), + Alias: schemas.Ptr(providerID), + }) + } + } + + return result +} diff --git a/core/providers/utils/utils.go b/core/providers/utils/utils.go index ba6d16a08a..6cadc5a62c 100644 --- a/core/providers/utils/utils.go +++ b/core/providers/utils/utils.go @@ -178,12 +178,12 @@ func MakeRequestWithContext(ctx context.Context, client *fasthttp.Client, req *f } // Check for timeout errors first before checking net.OpError to avoid misclassification if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) { - return latency, NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, ""), noop + return latency, NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err), noop } // Check if error implements net.Error and has Timeout() == true var netErr net.Error if errors.As(err, &netErr) && netErr.Timeout() { - return latency, NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, ""), noop + return latency, NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err), noop } // Check for DNS lookup and network errors after timeout checks var opErr *net.OpError @@ -1043,7 +1043,7 @@ func MergeExtraParamsIntoJSON(jsonBody []byte, extraParams map[string]interface{ } // 
CheckContextAndGetRequestBody checks if the raw request body should be used, and returns it if it exists. -func CheckContextAndGetRequestBody(ctx context.Context, request RequestBodyGetter, requestConverter RequestBodyConverter, providerType schemas.ModelProvider) ([]byte, *schemas.BifrostError) { +func CheckContextAndGetRequestBody(ctx context.Context, request RequestBodyGetter, requestConverter RequestBodyConverter) ([]byte, *schemas.BifrostError) { if IsLargePayloadPassthroughEnabled(ctx) { return nil, nil } @@ -1052,15 +1052,15 @@ func CheckContextAndGetRequestBody(ctx context.Context, request RequestBodyGette if !ok { convertedBody, err := requestConverter() if err != nil { - return nil, NewBifrostOperationError(schemas.ErrRequestBodyConversion, err, providerType) + return nil, NewBifrostOperationError(schemas.ErrRequestBodyConversion, err) } if convertedBody == nil { - return nil, NewBifrostOperationError("request body is not provided", nil, providerType) + return nil, NewBifrostOperationError("request body is not provided", nil) } jsonBody, err := MarshalSortedIndent(convertedBody, "", " ") if err != nil { - return nil, NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerType) + return nil, NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } // Merge ExtraParams into the JSON if passthrough is enabled if ctx.Value(schemas.BifrostContextKeyPassthroughExtraParams) != nil && ctx.Value(schemas.BifrostContextKeyPassthroughExtraParams) == true { @@ -1070,7 +1070,7 @@ func CheckContextAndGetRequestBody(ctx context.Context, request RequestBodyGette // tool schemas and other order-sensitive JSON structures. 
jsonBody, err = MergeExtraParamsIntoJSON(jsonBody, extraParams) if err != nil { - return nil, NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerType) + return nil, NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } } } @@ -1367,10 +1367,6 @@ func NewUnsupportedOperationError(requestType schemas.RequestType, providerName Message: fmt.Sprintf("%s is not supported by %s provider", requestType, providerName), Code: schemas.Ptr("unsupported_operation"), }, - ExtraFields: schemas.BifrostErrorExtraFields{ - Provider: providerName, - RequestType: requestType, - }, } } @@ -1593,37 +1589,31 @@ func ParseJSONL(data []byte, parseLine func(line []byte) error) JSONLParseResult // NewConfigurationError creates a standardized error for configuration errors. // This helper reduces code duplication across providers that have configuration errors. -func NewConfigurationError(message string, providerType schemas.ModelProvider) *schemas.BifrostError { +func NewConfigurationError(message string) *schemas.BifrostError { return &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: message, }, - ExtraFields: schemas.BifrostErrorExtraFields{ - Provider: providerType, - }, } } // NewBifrostOperationError creates a standardized error for bifrost operation errors. // This helper reduces code duplication across providers that have bifrost operation errors. -func NewBifrostOperationError(message string, err error, providerType schemas.ModelProvider) *schemas.BifrostError { +func NewBifrostOperationError(message string, err error) *schemas.BifrostError { return &schemas.BifrostError{ IsBifrostError: true, Error: &schemas.ErrorField{ Message: message, Error: err, }, - ExtraFields: schemas.BifrostErrorExtraFields{ - Provider: providerType, - }, } } // NewBifrostTimeoutError creates a standardized error for provider request timeout errors. 
// Sets StatusCode to 504 (Gateway Timeout) and Error.Type to RequestTimedOut, // consistent with HandleStreamTimeout for streaming requests. -func NewBifrostTimeoutError(message string, err error, providerType schemas.ModelProvider) *schemas.BifrostError { +func NewBifrostTimeoutError(message string, err error) *schemas.BifrostError { statusCode := 504 errorType := schemas.RequestTimedOut return &schemas.BifrostError{ @@ -1634,15 +1624,12 @@ func NewBifrostTimeoutError(message string, err error, providerType schemas.Mode Type: &errorType, Error: err, }, - ExtraFields: schemas.BifrostErrorExtraFields{ - Provider: providerType, - }, } } // NewProviderAPIError creates a standardized error for provider API errors. // This helper reduces code duplication across providers that have provider API errors. -func NewProviderAPIError(message string, err error, statusCode int, providerType schemas.ModelProvider, errorType *string, eventID *string) *schemas.BifrostError { +func NewProviderAPIError(message string, err error, statusCode int, errorType *string, eventID *string) *schemas.BifrostError { return &schemas.BifrostError{ IsBifrostError: false, StatusCode: &statusCode, @@ -1653,20 +1640,9 @@ func NewProviderAPIError(message string, err error, statusCode int, providerType Error: err, Type: errorType, }, - ExtraFields: schemas.BifrostErrorExtraFields{ - Provider: providerType, - }, } } -// RequestMetadata contains metadata about a request for error reporting. -// This struct is used to pass request context to parseError functions. -type RequestMetadata struct { - Provider schemas.ModelProvider - Model string - RequestType schemas.RequestType -} - // ShouldSendBackRawRequest checks if the raw request should be captured. // Context overrides are intentionally restricted to asymmetric behavior: a context value can only // promote false→true and will not override a true config to false, avoiding accidental suppression. 
@@ -1694,17 +1670,14 @@ func ShouldSendBackRawResponse(ctx context.Context, defaultSendBackRawResponse b } // SendCreatedEventResponsesChunk sends a ResponsesStreamResponseTypeCreated event. -func SendCreatedEventResponsesChunk(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, provider schemas.ModelProvider, model string, startTime time.Time, responseChan chan *schemas.BifrostStreamChunk) { +func SendCreatedEventResponsesChunk(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, startTime time.Time, responseChan chan *schemas.BifrostStreamChunk) { firstChunk := &schemas.BifrostResponsesStreamResponse{ Type: schemas.ResponsesStreamResponseTypeCreated, SequenceNumber: 0, Response: &schemas.BifrostResponsesResponse{}, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ResponsesStreamRequest, - Provider: provider, - ModelRequested: model, - ChunkIndex: 0, - Latency: time.Since(startTime).Milliseconds(), + ChunkIndex: 0, + Latency: time.Since(startTime).Milliseconds(), }, } //TODO add bifrost response pooling here @@ -1715,17 +1688,14 @@ func SendCreatedEventResponsesChunk(ctx *schemas.BifrostContext, postHookRunner } // SendInProgressEventResponsesChunk sends a ResponsesStreamResponseTypeInProgress event -func SendInProgressEventResponsesChunk(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, provider schemas.ModelProvider, model string, startTime time.Time, responseChan chan *schemas.BifrostStreamChunk) { +func SendInProgressEventResponsesChunk(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, startTime time.Time, responseChan chan *schemas.BifrostStreamChunk) { chunk := &schemas.BifrostResponsesStreamResponse{ Type: schemas.ResponsesStreamResponseTypeInProgress, SequenceNumber: 1, Response: &schemas.BifrostResponsesResponse{}, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ResponsesStreamRequest, - Provider: provider, - ModelRequested: model, - 
ChunkIndex: 1, - Latency: time.Since(startTime).Milliseconds(), + ChunkIndex: 1, + Latency: time.Since(startTime).Milliseconds(), }, } //TODO add bifrost response pooling here @@ -2015,9 +1985,6 @@ func HandleStreamCancellation( ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, responseChan chan *schemas.BifrostStreamChunk, - provider schemas.ModelProvider, - model string, - requestType schemas.RequestType, logger schemas.Logger, ) { // Check if already handled (StreamEndIndicator already set) @@ -2033,11 +2000,6 @@ func HandleStreamCancellation( Message: "Request cancelled: client disconnected", Type: schemas.Ptr(schemas.RequestCancelled), }, - ExtraFields: schemas.BifrostErrorExtraFields{ - Provider: provider, - ModelRequested: model, - RequestType: requestType, - }, } // Send through PostHook chain - this updates the log to "error" status @@ -2056,9 +2018,6 @@ func HandleStreamTimeout( ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, responseChan chan *schemas.BifrostStreamChunk, - provider schemas.ModelProvider, - model string, - requestType schemas.RequestType, logger schemas.Logger, ) { // Check if already handled (StreamEndIndicator already set) @@ -2074,11 +2033,6 @@ func HandleStreamTimeout( Message: "Request timed out: deadline exceeded", Type: schemas.Ptr(schemas.RequestTimedOut), }, - ExtraFields: schemas.BifrostErrorExtraFields{ - Provider: provider, - ModelRequested: model, - RequestType: requestType, - }, } // Send through PostHook chain - this updates the log to "error" status @@ -2094,9 +2048,6 @@ func ProcessAndSendError( postHookRunner schemas.PostHookRunner, err error, responseChan chan *schemas.BifrostStreamChunk, - requestType schemas.RequestType, - providerName schemas.ModelProvider, - model string, logger schemas.Logger, ) { // Send scanner error through channel @@ -2107,11 +2058,6 @@ func ProcessAndSendError( Message: fmt.Sprintf("Error reading stream: %v", err), Error: err, }, - ExtraFields: 
schemas.BifrostErrorExtraFields{ - RequestType: requestType, - Provider: providerName, - ModelRequested: model, - }, } processedResponse, processedError := postHookRunner(ctx, nil, bifrostError) @@ -2144,8 +2090,6 @@ func CreateBifrostTextCompletionChunkResponse( finishReason *string, currentChunkIndex int, requestType schemas.RequestType, - providerName schemas.ModelProvider, - model string, ) *schemas.BifrostTextCompletionResponse { response := &schemas.BifrostTextCompletionResponse{ ID: id, @@ -2158,10 +2102,7 @@ func CreateBifrostTextCompletionChunkResponse( }, }, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: requestType, - Provider: providerName, - ModelRequested: model, - ChunkIndex: currentChunkIndex + 1, + ChunkIndex: currentChunkIndex + 1, }, } return response @@ -2173,8 +2114,6 @@ func CreateBifrostChatCompletionChunkResponse( usage *schemas.BifrostLLMUsage, finishReason *string, currentChunkIndex int, - requestType schemas.RequestType, - providerName schemas.ModelProvider, model string, created int, ) *schemas.BifrostChatResponse { @@ -2183,7 +2122,7 @@ func CreateBifrostChatCompletionChunkResponse( Model: model, Created: created, Object: "chat.completion.chunk", - Usage: usage, + Usage: usage, Choices: []schemas.BifrostResponseChoice{ { FinishReason: finishReason, @@ -2193,10 +2132,7 @@ func CreateBifrostChatCompletionChunkResponse( }, }, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: requestType, - Provider: providerName, - ModelRequested: model, - ChunkIndex: currentChunkIndex + 1, + ChunkIndex: currentChunkIndex + 1, }, } return response @@ -2366,10 +2302,7 @@ func aggregateListModelsResponses(responses []*schemas.BifrostListModelsResponse // extractSuccessfulListModelsResponses extracts successful responses from a results channel // and tracks per-key status information. This utility reduces code duplication across providers // for handling multi-key ListModels requests. 
-func extractSuccessfulListModelsResponses( - results chan schemas.ListModelsByKeyResult, - providerName schemas.ModelProvider, -) ([]*schemas.BifrostListModelsResponse, []schemas.KeyStatus, *schemas.BifrostError) { +func extractSuccessfulListModelsResponses(results chan schemas.ListModelsByKeyResult, provider schemas.ModelProvider) ([]*schemas.BifrostListModelsResponse, []schemas.KeyStatus, *schemas.BifrostError) { var successfulResponses []*schemas.BifrostListModelsResponse var keyStatuses []schemas.KeyStatus var lastError *schemas.BifrostError @@ -2387,7 +2320,7 @@ func extractSuccessfulListModelsResponses( getLogger().Warn(fmt.Sprintf("failed to list models with key %s: %s", result.KeyID, errMsg)) keyStatuses = append(keyStatuses, schemas.KeyStatus{ KeyID: result.KeyID, - Provider: providerName, + Provider: provider, Status: schemas.KeyStatusListModelsFailed, Error: result.Err, }) @@ -2397,7 +2330,7 @@ func extractSuccessfulListModelsResponses( keyStatuses = append(keyStatuses, schemas.KeyStatus{ KeyID: result.KeyID, - Provider: providerName, + Provider: provider, Status: schemas.KeyStatusSuccess, }) successfulResponses = append(successfulResponses, result.Response) @@ -2412,10 +2345,6 @@ func extractSuccessfulListModelsResponses( Error: &schemas.ErrorField{ Message: "all keys failed to list models", }, - ExtraFields: schemas.BifrostErrorExtraFields{ - Provider: providerName, - RequestType: schemas.ListModelsRequest, - }, } } @@ -2498,8 +2427,6 @@ func HandleMultipleListModelsRequests( // Set ExtraFields latency := time.Since(startTime) - response.ExtraFields.Provider = request.Provider - response.ExtraFields.RequestType = schemas.ListModelsRequest response.ExtraFields.Latency = latency.Milliseconds() return response, nil diff --git a/core/providers/vertex/embedding.go b/core/providers/vertex/embedding.go index 0fc0ad598f..54662f50fe 100644 --- a/core/providers/vertex/embedding.go +++ b/core/providers/vertex/embedding.go @@ -110,8 +110,6 @@ func (response 
*VertexEmbeddingResponse) ToBifrostEmbeddingResponse() *schemas.B Data: embeddings, Usage: usage, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.EmbeddingRequest, - Provider: schemas.Vertex, }, } } diff --git a/core/providers/vertex/errors.go b/core/providers/vertex/errors.go index 6b255835d4..e0ed7f1d3d 100644 --- a/core/providers/vertex/errors.go +++ b/core/providers/vertex/errors.go @@ -10,25 +10,13 @@ import ( "github.com/valyala/fasthttp" ) -func parseVertexError(resp *fasthttp.Response, meta *providerUtils.RequestMetadata) *schemas.BifrostError { - var providerName schemas.ModelProvider - if meta != nil { - providerName = meta.Provider - } - +func parseVertexError(resp *fasthttp.Response) *schemas.BifrostError { var openAIErr schemas.BifrostError var vertexErr []VertexError decodedBody, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - bifrostErr := providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) - if meta != nil { - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: meta.Provider, - ModelRequested: meta.Model, - RequestType: meta.RequestType, - } - } + bifrostErr := providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) return bifrostErr } @@ -42,13 +30,6 @@ func parseVertexError(resp *fasthttp.Response, meta *providerUtils.RequestMetada Message: schemas.ErrProviderResponseEmpty, }, } - if meta != nil { - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: meta.Provider, - ModelRequested: meta.Model, - RequestType: meta.RequestType, - } - } return bifrostErr } @@ -61,26 +42,20 @@ func parseVertexError(resp *fasthttp.Response, meta *providerUtils.RequestMetada Message: schemas.ErrProviderResponseHTML, Error: errors.New(string(decodedBody)), }, - } - if meta != nil { - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: meta.Provider, - ModelRequested: meta.Model, - RequestType: meta.RequestType, - } 
+ ExtraFields: schemas.BifrostErrorExtraFields{ + RawResponse: string(decodedBody), + }, } return bifrostErr } createError := func(message string) *schemas.BifrostError { - bifrostErr := providerUtils.NewProviderAPIError(message, nil, resp.StatusCode(), providerName, nil, nil) - if meta != nil { - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: meta.Provider, - ModelRequested: meta.Model, - RequestType: meta.RequestType, - } + bifrostErr := providerUtils.NewProviderAPIError(message, nil, resp.StatusCode(), nil, nil) + var rawResponse interface{} + if err := sonic.Unmarshal(decodedBody, &rawResponse); err != nil { + rawResponse = string(decodedBody) } + bifrostErr.ExtraFields.RawResponse = rawResponse return bifrostErr } @@ -93,14 +68,7 @@ func parseVertexError(resp *fasthttp.Response, meta *providerUtils.RequestMetada // Try VertexValidationError format (validation errors from Mistral endpoint) var validationErr VertexValidationError if err := sonic.Unmarshal(decodedBody, &validationErr); err != nil { - bifrostErr := providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerName) - if meta != nil { - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: meta.Provider, - ModelRequested: meta.Model, - RequestType: meta.RequestType, - } - } + bifrostErr := providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err) return bifrostErr } if len(validationErr.Detail) > 0 { diff --git a/core/providers/vertex/models.go b/core/providers/vertex/models.go index 54ba41ac28..48837563eb 100644 --- a/core/providers/vertex/models.go +++ b/core/providers/vertex/models.go @@ -3,9 +3,8 @@ package vertex import ( "strings" + providerUtils "github.com/maximhq/bifrost/core/providers/utils" "github.com/maximhq/bifrost/core/schemas" - "golang.org/x/text/cases" - "golang.org/x/text/language" ) // VertexRankRequest represents the Discovery Engine rank API request. 
@@ -55,49 +54,6 @@ type vertexRerankOptions struct { UserLabels map[string]string } -// formatDeploymentName converts a deployment alias into a human-readable name. -// It splits the alias by "-" or "_", capitalizes each word, and joins them with spaces. -// Example: "gemini-pro" → "Gemini Pro", "claude_3_opus" → "Claude 3 Opus" -func formatDeploymentName(alias string) string { - caser := cases.Title(language.English) - - // Try splitting by hyphen first, then underscore - var parts []string - if strings.Contains(alias, "-") { - parts = strings.Split(alias, "-") - } else if strings.Contains(alias, "_") { - parts = strings.Split(alias, "_") - } else { - // No delimiter found, just capitalize the whole string - return caser.String(strings.ToLower(alias)) - } - - // Capitalize each part - for i, part := range parts { - if part != "" { - parts[i] = caser.String(strings.ToLower(part)) - } - } - - return strings.Join(parts, " ") -} - -// findDeploymentMatch finds a matching deployment value in the deployments map. -// Returns the deployment value and alias if found, empty strings otherwise. -func findDeploymentMatch(deployments map[string]string, customModelID string) (deploymentValue, alias string) { - // Check exact match by deployment value - for aliasKey, depValue := range deployments { - if depValue == customModelID { - return depValue, aliasKey - } - } - // Check exact match by alias/key - if deployment, ok := deployments[customModelID]; ok { - return deployment, customModelID - } - return "", "" -} - // ToBifrostListModelsResponse converts a Vertex AI list models response to Bifrost's format. // It processes both custom models (from the API response) and non-custom models (from deployments and allowedModels). 
// @@ -113,7 +69,7 @@ func findDeploymentMatch(deployments map[string]string, customModelID string) (d // - If allowedModels is empty, all models are allowed // - If allowedModels is non-empty, only models/deployments with keys in allowedModels are included // - Deployments map is used to match model IDs to aliases and filter accordingly -func (response *VertexListModelsResponse) ToBifrostListModelsResponse(allowedModels schemas.WhiteList, blacklistedModels schemas.BlackList, deployments map[string]string, unfiltered bool) *schemas.BifrostListModelsResponse { +func (response *VertexListModelsResponse) ToBifrostListModelsResponse(allowedModels schemas.WhiteList, blacklistedModels schemas.BlackList, aliases map[string]string, unfiltered bool) *schemas.BifrostListModelsResponse { if response == nil { return nil } @@ -122,14 +78,22 @@ func (response *VertexListModelsResponse) ToBifrostListModelsResponse(allowedMod Data: make([]schemas.Model, 0, len(response.Models)), } - if !unfiltered && (allowedModels.IsEmpty() && len(deployments) == 0 || blacklistedModels.IsBlockAll()) { + pipeline := &providerUtils.ListModelsPipeline{ + AllowedModels: allowedModels, + BlacklistedModels: blacklistedModels, + Aliases: aliases, + Unfiltered: unfiltered, + ProviderKey: schemas.Vertex, + MatchFns: providerUtils.DefaultMatchFns(), + } + if pipeline.ShouldEarlyExit() { return bifrostResponse } - // Track which model IDs have been added to avoid duplicates - addedModelIDs := make(map[string]bool) + included := make(map[string]bool) - // First pass: Process all models from the Vertex AI API response (custom models) + // Process all models from the Vertex AI API response (custom deployed models). + // The model ID is extracted from the endpoint URL last segment. 
for _, model := range response.Models { if len(model.DeployedModels) == 0 { continue @@ -145,111 +109,28 @@ func (response *VertexListModelsResponse) ToBifrostListModelsResponse(allowedMod continue } - // Filter if model is not present in both lists (when both are non-empty) - var deploymentValue, deploymentAlias string - restrictAllowed := !unfiltered && allowedModels.IsRestricted() - shouldFilter := false - if restrictAllowed && len(deployments) > 0 { - // Both lists are present: model must be in allowedModels AND deployments - // AND the deployment alias must also be in allowedModels - deploymentValue, deploymentAlias = findDeploymentMatch(deployments, customModelID) - inDeployments := deploymentAlias != "" - - // Check if deployment alias is also in allowedModels (direct string match) - deploymentAliasInAllowedModels := false - if deploymentAlias != "" { - deploymentAliasInAllowedModels = allowedModels.Contains(deploymentAlias) + for _, result := range pipeline.FilterModel(customModelID) { + resolvedKey := strings.ToLower(result.ResolvedID) + if included[resolvedKey] { + continue } - - // Filter if: model not in deployments OR deployment alias not in allowedModels - shouldFilter = !inDeployments || !deploymentAliasInAllowedModels - } else if restrictAllowed { - // Only allowedModels is present: filter if model is not in allowedModels - shouldFilter = !allowedModels.Contains(customModelID) - } else if !unfiltered && len(deployments) > 0 { - // Only deployments is present: filter if model is not in deployments - deploymentValue, deploymentAlias = findDeploymentMatch(deployments, customModelID) - shouldFilter = deploymentValue == "" - } - // If both are empty (or allowedModels is unrestricted and no deployments), shouldFilter remains false - - if shouldFilter { - continue - } - if !unfiltered && blacklistedModels.IsBlocked(customModelID) { - continue - } - - modelID := customModelID - - modelEntry := schemas.Model{ - ID: string(schemas.Vertex) + "/" + modelID, - 
Name: schemas.Ptr(model.DisplayName), - Description: schemas.Ptr(model.Description), - Created: schemas.Ptr(model.VersionCreateTime.Unix()), - } - // Set deployment info if matched via deployments - if deploymentValue != "" && deploymentAlias != "" { - modelEntry.ID = string(schemas.Vertex) + "/" + deploymentAlias - modelEntry.Deployment = schemas.Ptr(deploymentValue) - } - bifrostResponse.Data = append(bifrostResponse.Data, modelEntry) - addedModelIDs[modelEntry.ID] = true - } - } - - restrictAllowed := !unfiltered && allowedModels.IsRestricted() - - // Second pass: Backfill deployments that were not matched from the API response - if !unfiltered && len(deployments) > 0 { - for alias, deploymentValue := range deployments { - // Check if model already exists in the list - modelID := string(schemas.Vertex) + "/" + alias - if addedModelIDs[modelID] { - continue - } - // If allowedModels is restricted, only include if alias is in the list - if restrictAllowed && !allowedModels.Contains(alias) { - continue - } - if blacklistedModels.IsBlocked(alias) { - continue - } - - modelName := formatDeploymentName(alias) - modelEntry := schemas.Model{ - ID: modelID, - Name: schemas.Ptr(modelName), - Deployment: schemas.Ptr(deploymentValue), + modelEntry := schemas.Model{ + ID: string(schemas.Vertex) + "/" + result.ResolvedID, + Name: schemas.Ptr(model.DisplayName), + Description: schemas.Ptr(model.Description), + Created: schemas.Ptr(model.VersionCreateTime.Unix()), + } + if result.AliasValue != "" { + modelEntry.Alias = schemas.Ptr(result.AliasValue) + } + bifrostResponse.Data = append(bifrostResponse.Data, modelEntry) + included[resolvedKey] = true } - - bifrostResponse.Data = append(bifrostResponse.Data, modelEntry) - addedModelIDs[modelID] = true } } - // Third pass: Backfill allowed models that were not in the response or deployments - if restrictAllowed { - for _, allowedModel := range allowedModels { - // Check if model already exists in the list - modelID := 
string(schemas.Vertex) + "/" + allowedModel - if addedModelIDs[modelID] { - continue - } - if blacklistedModels.IsBlocked(allowedModel) { - continue - } - - modelName := formatDeploymentName(allowedModel) - modelEntry := schemas.Model{ - ID: modelID, - Name: schemas.Ptr(modelName), - } - - bifrostResponse.Data = append(bifrostResponse.Data, modelEntry) - addedModelIDs[modelID] = true - } - } + bifrostResponse.Data = append(bifrostResponse.Data, + pipeline.BackfillModels(included)...) bifrostResponse.NextPageToken = response.NextPageToken @@ -258,7 +139,7 @@ func (response *VertexListModelsResponse) ToBifrostListModelsResponse(allowedMod // ToBifrostListModelsResponse converts a Vertex AI publisher models response to Bifrost's format. // This is for foundation models from the Model Garden (publishers.models.list endpoint). -func (response *VertexListPublisherModelsResponse) ToBifrostListModelsResponse(allowedModels schemas.WhiteList, blacklistedModels schemas.BlackList, unfiltered bool) *schemas.BifrostListModelsResponse { +func (response *VertexListPublisherModelsResponse) ToBifrostListModelsResponse(allowedModels schemas.WhiteList, blacklistedModels schemas.BlackList, aliases map[string]string, unfiltered bool) *schemas.BifrostListModelsResponse { if response == nil { return nil } @@ -267,12 +148,19 @@ func (response *VertexListPublisherModelsResponse) ToBifrostListModelsResponse(a Data: make([]schemas.Model, 0, len(response.PublisherModels)), } - if !unfiltered && (allowedModels.IsEmpty() || blacklistedModels.IsBlockAll()) { + pipeline := &providerUtils.ListModelsPipeline{ + AllowedModels: allowedModels, + BlacklistedModels: blacklistedModels, + Aliases: aliases, + Unfiltered: unfiltered, + ProviderKey: schemas.Vertex, + MatchFns: providerUtils.DefaultMatchFns(), + } + if pipeline.ShouldEarlyExit() { return bifrostResponse } - // Track which model IDs have been added to avoid duplicates - addedModelIDs := make(map[string]bool) + included := make(map[string]bool) 
for _, model := range response.PublisherModels { // Extract model ID from name (format: "publishers/google/models/gemini-1.5-pro") @@ -281,35 +169,27 @@ func (response *VertexListPublisherModelsResponse) ToBifrostListModelsResponse(a continue } - // Filter based on allowedModels if specified - if !unfiltered && allowedModels.IsRestricted() && !allowedModels.Contains(modelID) { - continue - } - if !unfiltered && blacklistedModels.IsBlocked(modelID) { - continue - } - - // Skip if already added (shouldn't happen, but safety check) - fullModelID := string(schemas.Vertex) + "/" + modelID - if addedModelIDs[fullModelID] { - continue - } - - // Extract display name from supported actions if available - displayName := modelID - if model.SupportedActions != nil && model.SupportedActions.Deploy != nil && model.SupportedActions.Deploy.ModelDisplayName != "" { - displayName = model.SupportedActions.Deploy.ModelDisplayName - } - - modelEntry := schemas.Model{ - ID: fullModelID, - Name: schemas.Ptr(displayName), + for _, result := range pipeline.FilterModel(modelID) { + // Extract display name from supported actions if available + displayName := result.ResolvedID + if model.SupportedActions != nil && model.SupportedActions.Deploy != nil && model.SupportedActions.Deploy.ModelDisplayName != "" { + displayName = model.SupportedActions.Deploy.ModelDisplayName + } + modelEntry := schemas.Model{ + ID: string(schemas.Vertex) + "/" + result.ResolvedID, + Name: schemas.Ptr(displayName), + } + if result.AliasValue != "" { + modelEntry.Alias = schemas.Ptr(result.AliasValue) + } + bifrostResponse.Data = append(bifrostResponse.Data, modelEntry) + included[strings.ToLower(result.ResolvedID)] = true } - - bifrostResponse.Data = append(bifrostResponse.Data, modelEntry) - addedModelIDs[fullModelID] = true } + bifrostResponse.Data = append(bifrostResponse.Data, + pipeline.BackfillModels(included)...) 
+ bifrostResponse.NextPageToken = response.NextPageToken return bifrostResponse diff --git a/core/providers/vertex/rerank.go b/core/providers/vertex/rerank.go index 74372658b2..b06430fcac 100644 --- a/core/providers/vertex/rerank.go +++ b/core/providers/vertex/rerank.go @@ -83,7 +83,7 @@ func getVertexRerankOptions(projectID string, params *schemas.RerankParameters) } // ToVertexRankRequest converts a Bifrost rerank request to Discovery Engine rank API format. -func ToVertexRankRequest(bifrostReq *schemas.BifrostRerankRequest, modelDeployment string, options *vertexRerankOptions) (*VertexRankRequest, error) { +func ToVertexRankRequest(bifrostReq *schemas.BifrostRerankRequest, options *vertexRerankOptions) (*VertexRankRequest, error) { if bifrostReq == nil { return nil, fmt.Errorf("bifrost rerank request is nil") } @@ -132,7 +132,7 @@ func ToVertexRankRequest(bifrostReq *schemas.BifrostRerankRequest, modelDeployme rankRequest.TopN = &topN } - if trimmedModel := strings.TrimSpace(modelDeployment); trimmedModel != "" { + if trimmedModel := strings.TrimSpace(bifrostReq.Model); trimmedModel != "" { rankRequest.Model = &trimmedModel } diff --git a/core/providers/vertex/rerank_test.go b/core/providers/vertex/rerank_test.go index afd8ed225e..3f2efcec52 100644 --- a/core/providers/vertex/rerank_test.go +++ b/core/providers/vertex/rerank_test.go @@ -42,7 +42,6 @@ func TestToVertexRankRequest(t *testing.T) { TopN: schemas.Ptr(10), }, }, - "semantic-ranker-default@latest", &vertexRerankOptions{ RankingConfig: "projects/p/locations/global/rankingConfigs/default_ranking_config", IgnoreRecordDetailsInResponse: true, @@ -77,7 +76,6 @@ func TestToVertexRankRequestTooManyRecords(t *testing.T) { Query: "q", Documents: docs, }, - "", &vertexRerankOptions{ RankingConfig: "projects/p/locations/global/rankingConfigs/default_ranking_config", IgnoreRecordDetailsInResponse: true, diff --git a/core/providers/vertex/utils.go b/core/providers/vertex/utils.go index 8325d20d88..0dfda6763d 100644 
--- a/core/providers/vertex/utils.go +++ b/core/providers/vertex/utils.go @@ -9,7 +9,7 @@ import ( "github.com/maximhq/bifrost/core/schemas" ) -func getRequestBodyForAnthropicResponses(ctx *schemas.BifrostContext, request *schemas.BifrostResponsesRequest, deployment string, providerName schemas.ModelProvider, isStreaming bool, isCountTokens bool, betaHeaderOverrides map[string]bool, providerExtraHeaders map[string]string) ([]byte, *schemas.BifrostError) { +func getRequestBodyForAnthropicResponses(ctx *schemas.BifrostContext, request *schemas.BifrostResponsesRequest, deployment string, isStreaming bool, isCountTokens bool, betaHeaderOverrides map[string]bool, providerExtraHeaders map[string]string) ([]byte, *schemas.BifrostError) { // Large payload mode: body streams directly from the LP reader — skip all body building // (matches CheckContextAndGetRequestBody guard). if providerUtils.IsLargePayloadPassthroughEnabled(ctx) { @@ -26,74 +26,74 @@ func getRequestBodyForAnthropicResponses(ctx *schemas.BifrostContext, request *s if isCountTokens { jsonBody, err = providerUtils.DeleteJSONField(jsonBody, "max_tokens") if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } jsonBody, err = providerUtils.DeleteJSONField(jsonBody, "temperature") if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } jsonBody, err = providerUtils.SetJSONField(jsonBody, "model", deployment) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } } else { // Add max_tokens if not present if 
!providerUtils.JSONFieldExists(jsonBody, "max_tokens") { jsonBody, err = providerUtils.SetJSONField(jsonBody, "max_tokens", providerUtils.GetMaxOutputTokensOrDefault(deployment, anthropic.AnthropicDefaultMaxTokens)) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } } jsonBody, err = providerUtils.DeleteJSONField(jsonBody, "model") if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } // Add stream if streaming if isStreaming { jsonBody, err = providerUtils.SetJSONField(jsonBody, "stream", true) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } } } jsonBody, err = providerUtils.DeleteJSONField(jsonBody, "region") if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } jsonBody, err = providerUtils.DeleteJSONField(jsonBody, "fallbacks") if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } // Remap unsupported tool versions for Vertex (e.g., web_search_20260209 → web_search_20250305) jsonBody, err = anthropic.RemapRawToolVersionsForProvider(jsonBody, schemas.Vertex) if err != nil { - return nil, providerUtils.NewBifrostOperationError(err.Error(), nil, providerName) + return nil, providerUtils.NewBifrostOperationError(err.Error(), nil) } // Add 
anthropic_version if not present if !providerUtils.JSONFieldExists(jsonBody, "anthropic_version") { jsonBody, err = providerUtils.SetJSONField(jsonBody, "anthropic_version", DefaultVertexAnthropicVersion) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } } } else { // Validate tools are supported by Vertex if request.Params != nil && request.Params.Tools != nil { if toolErr := anthropic.ValidateToolsForProvider(request.Params.Tools, schemas.Vertex); toolErr != nil { - return nil, providerUtils.NewBifrostOperationError(toolErr.Error(), nil, providerName) + return nil, providerUtils.NewBifrostOperationError(toolErr.Error(), nil) } } // Convert request to Anthropic format reqBody, convErr := anthropic.ToAnthropicResponsesRequest(ctx, request) if convErr != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrRequestBodyConversion, convErr, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrRequestBodyConversion, convErr) } if reqBody == nil { - return nil, providerUtils.NewBifrostOperationError("request body is not provided", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("request body is not provided", nil) } reqBody.Model = deployment @@ -109,44 +109,44 @@ func getRequestBodyForAnthropicResponses(ctx *schemas.BifrostContext, request *s // Marshal struct to JSON bytes jsonBody, err = providerUtils.MarshalSorted(reqBody) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } // Add anthropic_version if not present (using sjson to preserve order) if !providerUtils.JSONFieldExists(jsonBody, "anthropic_version") { jsonBody, err = providerUtils.SetJSONField(jsonBody, 
"anthropic_version", DefaultVertexAnthropicVersion) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } } if isCountTokens { jsonBody, err = providerUtils.DeleteJSONField(jsonBody, "max_tokens") if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } jsonBody, err = providerUtils.DeleteJSONField(jsonBody, "temperature") if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } } else { // Remove model field for Vertex API (it's in URL) jsonBody, err = providerUtils.DeleteJSONField(jsonBody, "model") if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } } jsonBody, err = providerUtils.DeleteJSONField(jsonBody, "region") if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } } if betaHeaders := anthropic.FilterBetaHeadersForProvider(anthropic.MergeBetaHeaders(providerExtraHeaders, ctx), schemas.Vertex, betaHeaderOverrides); len(betaHeaders) > 0 { jsonBody, err = providerUtils.SetJSONField(jsonBody, "anthropic_beta", betaHeaders) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } } @@ -204,11 +204,11 @@ 
func buildResponseFromConfig(deployments map[string]string, allowedModels schema continue } - modelName := formatDeploymentName(alias) + modelName := providerUtils.ToDisplayName(alias) modelEntry := schemas.Model{ - ID: modelID, - Name: schemas.Ptr(modelName), - Deployment: schemas.Ptr(deploymentValue), + ID: modelID, + Name: schemas.Ptr(modelName), + Alias: schemas.Ptr(deploymentValue), } response.Data = append(response.Data, modelEntry) @@ -228,7 +228,7 @@ func buildResponseFromConfig(deployments map[string]string, allowedModels schema continue } - modelName := formatDeploymentName(allowedModel) + modelName := providerUtils.ToDisplayName(allowedModel) modelEntry := schemas.Model{ ID: modelID, Name: schemas.Ptr(modelName), diff --git a/core/providers/vertex/vertex.go b/core/providers/vertex/vertex.go index 131aac38a8..2c2b7aca7a 100644 --- a/core/providers/vertex/vertex.go +++ b/core/providers/vertex/vertex.go @@ -114,9 +114,6 @@ const cloudPlatformScope = "https://www.googleapis.com/auth/cloud-platform" // It uses the JWT config if auth credentials are provided. // It returns an error if the token source creation fails. func getAuthTokenSource(key schemas.Key) (oauth2.TokenSource, error) { - if key.VertexKeyConfig == nil { - return nil, fmt.Errorf("vertex key config is not set") - } authCredentials := key.VertexKeyConfig.AuthCredentials var tokenSource oauth2.TokenSource if authCredentials.GetValue() == "" { @@ -176,18 +173,12 @@ func (provider *VertexProvider) GetProviderKey() schemas.ModelProvider { // 1. If deployments or allowedModels are configured, return those (no API call needed) // 2. 
Otherwise, fetch from the publishers.models.list API endpoint (Model Garden) func (provider *VertexProvider) listModelsByKey(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - - if key.VertexKeyConfig == nil { - return nil, providerUtils.NewConfigurationError("vertex key config is not set", providerName) - } - region := key.VertexKeyConfig.Region.GetValue() if region == "" { - return nil, providerUtils.NewConfigurationError("region is not set in key config", providerName) + return nil, providerUtils.NewConfigurationError("region is not set in key config") } - deployments := key.VertexKeyConfig.Deployments + deployments := key.Aliases allowedModels := key.Models if !request.Unfiltered && (allowedModels.IsEmpty() && len(deployments) == 0 || key.BlacklistedModels.IsBlockAll()) { @@ -217,11 +208,11 @@ func (provider *VertexProvider) listModelsByKey(ctx *schemas.BifrostContext, key // Getting oauth2 token tokenSource, err := getAuthTokenSource(key) if err != nil { - return nil, providerUtils.NewBifrostOperationError("error creating auth token source (api key auth not supported for list models)", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error creating auth token source (api key auth not supported for list models)", err) } token, err := tokenSource.Token() if err != nil { - return nil, providerUtils.NewBifrostOperationError("error getting token (api key auth not supported for list models)", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error getting token (api key auth not supported for list models)", err) } // Iterate over all supported Vertex publishers to include Google, Anthropic, and Mistral models @@ -250,13 +241,14 @@ func (provider *VertexProvider) listModelsByKey(ctx *schemas.BifrostContext, key _, bifrostErr, wait := 
providerUtils.MakeRequestWithContext(ctx, provider.client, req, resp) if bifrostErr != nil { wait() + respBody := append([]byte(nil), resp.Body()...) fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) // Non-Google publishers may not be available in all regions; skip on error if publisher != "google" { break } - return nil, providerUtils.EnrichError(ctx, bifrostErr, nil, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, bifrostErr, nil, respBody, provider.sendBackRawRequest, provider.sendBackRawResponse) } ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerUtils.ExtractProviderResponseHeaders(resp)) @@ -284,9 +276,9 @@ func (provider *VertexProvider) listModelsByKey(ctx *schemas.BifrostContext, key var errorResp VertexError if err := sonic.Unmarshal(respBody, &errorResp); err != nil { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, schemas.Vertex), nil, respBody, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err), nil, respBody, provider.sendBackRawRequest, provider.sendBackRawResponse) } - return nil, providerUtils.EnrichError(ctx, providerUtils.NewProviderAPIError(errorResp.Error.Message, nil, statusCode, schemas.Vertex, nil, nil), nil, respBody, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewProviderAPIError(errorResp.Error.Message, nil, statusCode, nil, nil), nil, respBody, provider.sendBackRawRequest, provider.sendBackRawResponse) } // Parse Vertex's publisher models response @@ -326,7 +318,7 @@ func (provider *VertexProvider) listModelsByKey(ctx *schemas.BifrostContext, key PublisherModels: allPublisherModels, } - response := aggregatedResponse.ToBifrostListModelsResponse(key.Models, 
key.BlacklistedModels, request.Unfiltered) + response := aggregatedResponse.ToBifrostListModelsResponse(key.Models, key.BlacklistedModels, key.Aliases, request.Unfiltered) if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { response.ExtraFields.RawRequest = rawRequests @@ -372,18 +364,6 @@ func (provider *VertexProvider) TextCompletionStream(ctx *schemas.BifrostContext // It supports both text and image content in messages. // Returns a BifrostResponse containing the completion results or an error if the request fails. func (provider *VertexProvider) ChatCompletion(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - - if key.VertexKeyConfig == nil { - return nil, providerUtils.NewConfigurationError("vertex key config is not set", providerName) - } - - deployment := provider.getModelDeployment(key, request.Model) - // strip google/ prefix if present - if after, ok := strings.CutPrefix(deployment, "google/"); ok { - deployment = after - } - jsonBody, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, request, @@ -393,7 +373,7 @@ func (provider *VertexProvider) ChatCompletion(ctx *schemas.BifrostContext, key var extraParams map[string]interface{} var err error - if schemas.IsAnthropicModel(deployment) { + if schemas.IsAnthropicModel(request.Model) { // Use centralized Anthropic converter reqBody, convErr := anthropic.ToAnthropicChatRequest(ctx, request) if convErr != nil { @@ -403,7 +383,6 @@ func (provider *VertexProvider) ChatCompletion(ctx *schemas.BifrostContext, key return nil, fmt.Errorf("chat completion input is not provided") } extraParams = reqBody.GetExtraParams() - reqBody.Model = deployment // Add provider-aware beta headers for Vertex anthropic.AddMissingBetaHeadersToContext(ctx, reqBody, schemas.Vertex) // Marshal to JSON bytes, preserving struct field order @@ -430,13 +409,12 @@ func 
(provider *VertexProvider) ChatCompletion(ctx *schemas.BifrostContext, key if err != nil { return nil, fmt.Errorf("failed to delete model field: %w", err) } - } else if schemas.IsGeminiModel(deployment) || schemas.IsAllDigitsASCII(deployment) { + } else if schemas.IsGeminiModel(request.Model) || schemas.IsAllDigitsASCII(request.Model) { reqBody := gemini.ToGeminiChatCompletionRequest(request) if reqBody == nil { return nil, fmt.Errorf("chat completion input is not provided") } extraParams = reqBody.GetExtraParams() - reqBody.Model = deployment // Strip unsupported fields for Vertex Gemini stripVertexGeminiUnsupportedFields(reqBody) // Marshal to JSON bytes @@ -451,7 +429,6 @@ func (provider *VertexProvider) ChatCompletion(ctx *schemas.BifrostContext, key return nil, fmt.Errorf("chat completion input is not provided") } extraParams = reqBody.GetExtraParams() - reqBody.Model = deployment // Marshal to JSON bytes rawBody, err = providerUtils.MarshalSorted(reqBody) if err != nil { @@ -466,26 +443,26 @@ func (provider *VertexProvider) ChatCompletion(ctx *schemas.BifrostContext, key } return &VertexRawRequestBody{RawBody: rawBody, ExtraParams: extraParams}, nil }, - provider.GetProviderKey()) + ) if bifrostErr != nil { return nil, bifrostErr } projectID := key.VertexKeyConfig.ProjectID.GetValue() if projectID == "" { - return nil, providerUtils.NewConfigurationError("project ID is not set", providerName) + return nil, providerUtils.NewConfigurationError("project ID is not set") } region := key.VertexKeyConfig.Region.GetValue() if region == "" { - return nil, providerUtils.NewConfigurationError("region is not set in key config", providerName) + return nil, providerUtils.NewConfigurationError("region is not set in key config") } // Remap unsupported tool versions for Vertex (handles raw passthrough bodies) - if schemas.IsAnthropicModel(deployment) && jsonBody != nil { + if schemas.IsAnthropicModel(request.Model) && jsonBody != nil { remappedBody, remapErr := 
anthropic.RemapRawToolVersionsForProvider(jsonBody, schemas.Vertex) if remapErr != nil { - return nil, providerUtils.NewBifrostOperationError(remapErr.Error(), nil, providerName) + return nil, providerUtils.NewBifrostOperationError(remapErr.Error(), nil) } jsonBody = remappedBody } @@ -494,43 +471,43 @@ func (provider *VertexProvider) ChatCompletion(ctx *schemas.BifrostContext, key authQuery := "" // Determine the URL based on model type var completeURL string - if schemas.IsAllDigitsASCII(deployment) { + if schemas.IsAllDigitsASCII(request.Model) { // Custom Fine-tuned models use OpenAPI endpoint projectNumber := key.VertexKeyConfig.ProjectNumber.GetValue() if projectNumber == "" { - return nil, providerUtils.NewConfigurationError("project number is not set for fine-tuned models", providerName) + return nil, providerUtils.NewConfigurationError("project number is not set for fine-tuned models") } if key.Value.GetValue() != "" { authQuery = fmt.Sprintf("key=%s", url.QueryEscape(key.Value.GetValue())) } if region == "global" { - completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1beta1/projects/%s/locations/global/endpoints/%s:generateContent", projectNumber, deployment) + completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1beta1/projects/%s/locations/global/endpoints/%s:generateContent", projectNumber, request.Model) } else { - completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/%s:generateContent", region, projectNumber, region, deployment) + completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/%s:generateContent", region, projectNumber, region, request.Model) } - } else if schemas.IsAnthropicModel(deployment) { + } else if schemas.IsAnthropicModel(request.Model) { // Claude models use Anthropic publisher if region == "global" { - completeURL = 
fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/anthropic/models/%s:rawPredict", projectID, deployment) + completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/anthropic/models/%s:rawPredict", projectID, request.Model) } else { - completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:rawPredict", region, projectID, region, deployment) + completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:rawPredict", region, projectID, region, request.Model) } - } else if schemas.IsMistralModel(deployment) { + } else if schemas.IsMistralModel(request.Model) { // Mistral models use mistralai publisher with rawPredict if region == "global" { - completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/mistralai/models/%s:rawPredict", projectID, deployment) + completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/mistralai/models/%s:rawPredict", projectID, request.Model) } else { - completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/mistralai/models/%s:rawPredict", region, projectID, region, deployment) + completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/mistralai/models/%s:rawPredict", region, projectID, region, request.Model) } - } else if schemas.IsGeminiModel(deployment) { + } else if schemas.IsGeminiModel(request.Model) { // Gemini models support api key if key.Value.GetValue() != "" { authQuery = fmt.Sprintf("key=%s", url.QueryEscape(key.Value.GetValue())) } if region == "global" { - completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:generateContent", projectID, deployment) + completeURL = 
fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:generateContent", projectID, request.Model) } else { - completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:generateContent", region, projectID, region, deployment) + completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:generateContent", region, projectID, region, request.Model) } } else { if region == "global" { @@ -565,11 +542,11 @@ func (provider *VertexProvider) ChatCompletion(ctx *schemas.BifrostContext, key // Getting oauth2 token tokenSource, err := getAuthTokenSource(key) if err != nil { - return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err) } token, err := tokenSource.Token() if err != nil { - return nil, providerUtils.NewBifrostOperationError("error getting token", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error getting token", err) } req.Header.Set("Authorization", "Bearer "+token.AccessToken) } @@ -598,14 +575,10 @@ func (provider *VertexProvider) ChatCompletion(ctx *schemas.BifrostContext, key if resp.StatusCode() == fasthttp.StatusUnauthorized || resp.StatusCode() == fasthttp.StatusForbidden { removeVertexClient(key.VertexKeyConfig.AuthCredentials.GetValue()) } - return nil, providerUtils.EnrichError(ctx, parseVertexError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - Model: request.Model, - RequestType: schemas.ChatCompletionRequest, - }), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, parseVertexError(resp), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } - responseBody, isLargeResp, decodeErr := 
providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, providerName, provider.logger) + responseBody, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, provider.logger) if decodeErr != nil { return nil, providerUtils.EnrichError(ctx, decodeErr, jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } @@ -614,16 +587,13 @@ func (provider *VertexProvider) ChatCompletion(ctx *schemas.BifrostContext, key return &schemas.BifrostChatResponse{ Model: request.Model, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.ChatCompletionRequest, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerUtils.ExtractProviderResponseHeaders(resp), }, }, nil } - if schemas.IsAnthropicModel(deployment) { + if schemas.IsAnthropicModel(request.Model) { // Create response object from pool anthropicResponse := anthropic.AcquireAnthropicMessageResponse() defer anthropic.ReleaseAnthropicMessageResponse(anthropicResponse) @@ -637,17 +607,9 @@ func (provider *VertexProvider) ChatCompletion(ctx *schemas.BifrostContext, key response := anthropicResponse.ToBifrostChatResponse(ctx) response.ExtraFields = schemas.BifrostResponseExtraFields{ - RequestType: schemas.ChatCompletionRequest, - Provider: providerName, - ModelRequested: request.Model, - Latency: latency.Milliseconds(), - } - - response.ExtraFields.ModelRequested = request.Model - if request.Model != deployment { - response.ExtraFields.ModelDeployment = deployment + Latency: latency.Milliseconds(), + ProviderResponseHeaders: providerUtils.ExtractProviderResponseHeaders(resp), } - response.ExtraFields.ProviderResponseHeaders = providerUtils.ExtractProviderResponseHeaders(resp) // Set raw request if enabled if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { @@ -660,7 +622,7 @@ func (provider *VertexProvider) ChatCompletion(ctx *schemas.BifrostContext, key } return 
response, nil - } else if schemas.IsGeminiModel(deployment) || schemas.IsAllDigitsASCII(deployment) { + } else if schemas.IsGeminiModel(request.Model) || schemas.IsAllDigitsASCII(request.Model) { geminiResponse := gemini.GenerateContentResponse{} rawRequest, rawResponse, bifrostErr := providerUtils.HandleProviderResponse(responseBody, &geminiResponse, jsonBody, providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)) @@ -669,12 +631,6 @@ func (provider *VertexProvider) ChatCompletion(ctx *schemas.BifrostContext, key } response := geminiResponse.ToBifrostChatResponse() - response.ExtraFields.RequestType = schemas.ChatCompletionRequest - response.ExtraFields.Provider = providerName - response.ExtraFields.ModelRequested = request.Model - if request.Model != deployment { - response.ExtraFields.ModelDeployment = deployment - } response.ExtraFields.Latency = latency.Milliseconds() response.ExtraFields.ProviderResponseHeaders = providerUtils.ExtractProviderResponseHeaders(resp) @@ -696,12 +652,6 @@ func (provider *VertexProvider) ChatCompletion(ctx *schemas.BifrostContext, key return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonBody, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse) } - response.ExtraFields.RequestType = schemas.ChatCompletionRequest - response.ExtraFields.Provider = providerName - response.ExtraFields.ModelRequested = request.Model - if request.Model != deployment { - response.ExtraFields.ModelDeployment = deployment - } response.ExtraFields.Latency = latency.Milliseconds() response.ExtraFields.ProviderResponseHeaders = providerUtils.ExtractProviderResponseHeaders(resp) @@ -724,35 +674,17 @@ func (provider *VertexProvider) ChatCompletion(ctx *schemas.BifrostContext, key // Returns a channel of BifrostStreamChunk objects for streaming results or an error if the request fails. 
func (provider *VertexProvider) ChatCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { providerName := provider.GetProviderKey() - if key.VertexKeyConfig == nil { - return nil, providerUtils.NewConfigurationError("vertex key config is not set", providerName) - } - projectID := key.VertexKeyConfig.ProjectID.GetValue() if projectID == "" { - return nil, providerUtils.NewConfigurationError("project ID is not set", providerName) + return nil, providerUtils.NewConfigurationError("project ID is not set") } region := key.VertexKeyConfig.Region.GetValue() if region == "" { - return nil, providerUtils.NewConfigurationError("region is not set in key config", providerName) - } - - deployment := provider.getModelDeployment(key, request.Model) - // strip google/ prefix if present - if after, ok := strings.CutPrefix(deployment, "google/"); ok { - deployment = after - } - - postResponseConverter := func(response *schemas.BifrostChatResponse) *schemas.BifrostChatResponse { - response.ExtraFields.ModelRequested = request.Model - if request.Model != deployment { - response.ExtraFields.ModelDeployment = deployment - } - return response + return nil, providerUtils.NewConfigurationError("region is not set in key config") } - if schemas.IsAnthropicModel(deployment) { + if schemas.IsAnthropicModel(request.Model) { // Use Anthropic-style streaming for Claude models jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, @@ -767,7 +699,6 @@ func (provider *VertexProvider) ChatCompletionStream(ctx *schemas.BifrostContext return nil, fmt.Errorf("chat completion input is not provided") } extraParams = reqBody.GetExtraParams() - reqBody.Model = deployment reqBody.Stream = schemas.Ptr(true) // Add provider-aware beta headers for Vertex anthropic.AddMissingBetaHeadersToContext(ctx, reqBody, schemas.Vertex) @@ -804,7 +735,7 @@ func 
(provider *VertexProvider) ChatCompletionStream(ctx *schemas.BifrostContext } return &VertexRawRequestBody{RawBody: rawBody, ExtraParams: extraParams}, nil }, - provider.GetProviderKey()) + ) if bifrostErr != nil { return nil, bifrostErr } @@ -814,15 +745,15 @@ func (provider *VertexProvider) ChatCompletionStream(ctx *schemas.BifrostContext var remapErr error jsonData, remapErr = anthropic.RemapRawToolVersionsForProvider(jsonData, schemas.Vertex) if remapErr != nil { - return nil, providerUtils.NewBifrostOperationError(remapErr.Error(), nil, providerName) + return nil, providerUtils.NewBifrostOperationError(remapErr.Error(), nil) } } var completeURL string if region == "global" { - completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/anthropic/models/%s:streamRawPredict", projectID, deployment) + completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/anthropic/models/%s:streamRawPredict", projectID, request.Model) } else { - completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:streamRawPredict", region, projectID, region, deployment) + completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:streamRawPredict", region, projectID, region, request.Model) } // Prepare headers for Vertex Anthropic @@ -835,11 +766,11 @@ func (provider *VertexProvider) ChatCompletionStream(ctx *schemas.BifrostContext // Adding authorization header tokenSource, err := getAuthTokenSource(key) if err != nil { - return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err) } token, err := tokenSource.Token() if err != nil { - return nil, providerUtils.NewBifrostOperationError("error getting token", err, schemas.Vertex) + 
return nil, providerUtils.NewBifrostOperationError("error getting token", err) } headers["Authorization"] = "Bearer " + token.AccessToken @@ -856,15 +787,10 @@ func (provider *VertexProvider) ChatCompletionStream(ctx *schemas.BifrostContext providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse), providerName, postHookRunner, - postResponseConverter, + nil, provider.logger, - &providerUtils.RequestMetadata{ - Provider: providerName, - Model: request.Model, - RequestType: schemas.ChatCompletionStreamRequest, - }, ) - } else if schemas.IsGeminiModel(deployment) || schemas.IsAllDigitsASCII(deployment) { + } else if schemas.IsGeminiModel(request.Model) || schemas.IsAllDigitsASCII(request.Model) { // Use Gemini-style streaming for Gemini models jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, @@ -874,12 +800,11 @@ func (provider *VertexProvider) ChatCompletionStream(ctx *schemas.BifrostContext if reqBody == nil { return nil, fmt.Errorf("chat completion input is not provided") } - reqBody.Model = deployment // Strip unsupported fields for Vertex Gemini stripVertexGeminiUnsupportedFields(reqBody) return reqBody, nil }, - provider.GetProviderKey()) + ) if bifrostErr != nil { return nil, bifrostErr } @@ -892,12 +817,12 @@ func (provider *VertexProvider) ChatCompletionStream(ctx *schemas.BifrostContext // For custom/fine-tuned models, validate projectNumber is set projectNumber := key.VertexKeyConfig.ProjectNumber.GetValue() - if schemas.IsAllDigitsASCII(deployment) && projectNumber == "" { - return nil, providerUtils.NewConfigurationError("project number is not set for fine-tuned models", providerName) + if schemas.IsAllDigitsASCII(request.Model) && projectNumber == "" { + return nil, providerUtils.NewConfigurationError("project number is not set for fine-tuned models") } // Construct the URL for Gemini streaming - completeURL := getCompleteURLForGeminiEndpoint(deployment, region, projectID, projectNumber, 
":streamGenerateContent") + completeURL := getCompleteURLForGeminiEndpoint(request.Model, region, projectID, projectNumber, ":streamGenerateContent") // Add alt=sse parameter if authQuery != "" { @@ -916,11 +841,11 @@ func (provider *VertexProvider) ChatCompletionStream(ctx *schemas.BifrostContext if authQuery == "" { tokenSource, err := getAuthTokenSource(key) if err != nil { - return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err) } token, err := tokenSource.Token() if err != nil { - return nil, providerUtils.NewBifrostOperationError("error getting token", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error getting token", err) } headers["Authorization"] = "Bearer " + token.AccessToken } @@ -938,7 +863,7 @@ func (provider *VertexProvider) ChatCompletionStream(ctx *schemas.BifrostContext provider.GetProviderKey(), request.Model, postHookRunner, - postResponseConverter, + nil, provider.logger, ) } else { @@ -947,12 +872,12 @@ func (provider *VertexProvider) ChatCompletionStream(ctx *schemas.BifrostContext authQuery := "" // Determine the URL based on model type var completeURL string - if schemas.IsMistralModel(deployment) { + if schemas.IsMistralModel(request.Model) { // Mistral models use mistralai publisher with streamRawPredict if region == "global" { - completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/mistralai/models/%s:streamRawPredict", projectID, deployment) + completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/mistralai/models/%s:streamRawPredict", projectID, request.Model) } else { - completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/mistralai/models/%s:streamRawPredict", region, projectID, region, deployment) + completeURL = 
fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/mistralai/models/%s:streamRawPredict", region, projectID, region, request.Model) } } else { // Other models use OpenAPI endpoint for gemini models @@ -972,22 +897,17 @@ func (provider *VertexProvider) ChatCompletionStream(ctx *schemas.BifrostContext // Getting oauth2 token tokenSource, err := getAuthTokenSource(key) if err != nil { - return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err) } token, err := tokenSource.Token() if err != nil { - return nil, providerUtils.NewBifrostOperationError("error getting token", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error getting token", err) } authHeader = map[string]string{ "Authorization": "Bearer " + token.AccessToken, } } - postRequestConverter := func(reqBody *openai.OpenAIChatRequest) *openai.OpenAIChatRequest { - reqBody.Model = deployment - return reqBody - } - // Use shared OpenAI streaming logic return openai.HandleOpenAIChatCompletionStreaming( ctx, @@ -1003,8 +923,8 @@ func (provider *VertexProvider) ChatCompletionStream(ctx *schemas.BifrostContext nil, nil, nil, - postRequestConverter, - postResponseConverter, + nil, + nil, provider.logger, ) } @@ -1012,40 +932,28 @@ func (provider *VertexProvider) ChatCompletionStream(ctx *schemas.BifrostContext // Responses performs a responses request to the Vertex API. 
func (provider *VertexProvider) Responses(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - - if key.VertexKeyConfig == nil { - return nil, providerUtils.NewConfigurationError("vertex key config is not set", providerName) - } - - deployment := provider.getModelDeployment(key, request.Model) - // strip google/ prefix if present - if after, ok := strings.CutPrefix(deployment, "google/"); ok { - deployment = after - } - - if schemas.IsAnthropicModel(deployment) { - jsonBody, bifrostErr := getRequestBodyForAnthropicResponses(ctx, request, deployment, providerName, false, false, provider.networkConfig.BetaHeaderOverrides, provider.networkConfig.ExtraHeaders) + if schemas.IsAnthropicModel(request.Model) { + jsonBody, bifrostErr := getRequestBodyForAnthropicResponses(ctx, request, request.Model, false, false, provider.networkConfig.BetaHeaderOverrides, provider.networkConfig.ExtraHeaders) if bifrostErr != nil { return nil, bifrostErr } projectID := key.VertexKeyConfig.ProjectID.GetValue() if projectID == "" { - return nil, providerUtils.NewConfigurationError("project ID is not set", providerName) + return nil, providerUtils.NewConfigurationError("project ID is not set") } region := key.VertexKeyConfig.Region.GetValue() if region == "" { - return nil, providerUtils.NewConfigurationError("region is not set in key config", providerName) + return nil, providerUtils.NewConfigurationError("region is not set in key config") } // Claude models use Anthropic publisher var url string if region == "global" { - url = fmt.Sprintf("https://aiplatform.googleapis.com/v1beta1/projects/%s/locations/global/publishers/anthropic/models/%s:rawPredict", projectID, deployment) + url = fmt.Sprintf("https://aiplatform.googleapis.com/v1beta1/projects/%s/locations/global/publishers/anthropic/models/%s:rawPredict", projectID, request.Model) } else { - url = 
fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/publishers/anthropic/models/%s:rawPredict", region, projectID, region, deployment) + url = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/publishers/anthropic/models/%s:rawPredict", region, projectID, region, request.Model) } // Create HTTP request for streaming @@ -1066,11 +974,11 @@ func (provider *VertexProvider) Responses(ctx *schemas.BifrostContext, key schem // Getting oauth2 token tokenSource, err := getAuthTokenSource(key) if err != nil { - return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err) } token, err := tokenSource.Token() if err != nil { - return nil, providerUtils.NewBifrostOperationError("error getting token", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error getting token", err) } req.Header.Set("Authorization", "Bearer "+token.AccessToken) @@ -1098,14 +1006,10 @@ func (provider *VertexProvider) Responses(ctx *schemas.BifrostContext, key schem if resp.StatusCode() == fasthttp.StatusUnauthorized || resp.StatusCode() == fasthttp.StatusForbidden { removeVertexClient(key.VertexKeyConfig.AuthCredentials.GetValue()) } - return nil, providerUtils.EnrichError(ctx, parseVertexError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - Model: request.Model, - RequestType: schemas.ResponsesRequest, - }), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, parseVertexError(resp), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } - responseBody, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, providerName, provider.logger) + responseBody, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, 
provider.logger) if decodeErr != nil { return nil, providerUtils.EnrichError(ctx, decodeErr, jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } @@ -1113,9 +1017,6 @@ func (provider *VertexProvider) Responses(ctx *schemas.BifrostContext, key schem respOwned = false return &schemas.BifrostResponsesResponse{ ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.ResponsesRequest, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerUtils.ExtractProviderResponseHeaders(resp), }, @@ -1135,13 +1036,9 @@ func (provider *VertexProvider) Responses(ctx *schemas.BifrostContext, key schem response := anthropicResponse.ToBifrostResponsesResponse(ctx) response.ExtraFields = schemas.BifrostResponseExtraFields{ - RequestType: schemas.ResponsesRequest, - Provider: providerName, - ModelRequested: request.Model, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), } - response.ExtraFields.ModelRequested = request.Model response.ExtraFields.ProviderResponseHeaders = providerUtils.ExtractProviderResponseHeaders(resp) // Set raw request if enabled if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { @@ -1152,12 +1049,9 @@ func (provider *VertexProvider) Responses(ctx *schemas.BifrostContext, key schem if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) { response.ExtraFields.RawResponse = rawResponse } - if request.Model != deployment { - response.ExtraFields.ModelDeployment = deployment - } return response, nil - } else if schemas.IsGeminiModel(deployment) || schemas.IsAllDigitsASCII(deployment) { + } else if schemas.IsGeminiModel(request.Model) || schemas.IsAllDigitsASCII(request.Model) { jsonBody, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, request, @@ -1166,24 +1060,23 @@ func (provider *VertexProvider) Responses(ctx *schemas.BifrostContext, key schem if reqBody == nil { return nil, 
fmt.Errorf("responses input is not provided") } - reqBody.Model = deployment // Strip unsupported fields for Vertex Gemini stripVertexGeminiUnsupportedFields(reqBody) return reqBody, nil }, - provider.GetProviderKey()) + ) if bifrostErr != nil { return nil, bifrostErr } projectID := key.VertexKeyConfig.ProjectID.GetValue() if projectID == "" { - return nil, providerUtils.NewConfigurationError("project ID is not set", providerName) + return nil, providerUtils.NewConfigurationError("project ID is not set") } region := key.VertexKeyConfig.Region.GetValue() if region == "" { - return nil, providerUtils.NewConfigurationError("region is not set in key config", providerName) + return nil, providerUtils.NewConfigurationError("region is not set in key config") } authQuery := "" @@ -1193,11 +1086,11 @@ func (provider *VertexProvider) Responses(ctx *schemas.BifrostContext, key schem // For custom/fine-tuned models, validate projectNumber is set projectNumber := key.VertexKeyConfig.ProjectNumber.GetValue() - if schemas.IsAllDigitsASCII(deployment) && projectNumber == "" { - return nil, providerUtils.NewConfigurationError("project number is not set for fine-tuned models", providerName) + if schemas.IsAllDigitsASCII(request.Model) && projectNumber == "" { + return nil, providerUtils.NewConfigurationError("project number is not set for fine-tuned models") } - url := getCompleteURLForGeminiEndpoint(deployment, region, projectID, projectNumber, ":generateContent") + url := getCompleteURLForGeminiEndpoint(request.Model, region, projectID, projectNumber, ":generateContent") // Create HTTP request for streaming req := fasthttp.AcquireRequest() @@ -1222,11 +1115,11 @@ func (provider *VertexProvider) Responses(ctx *schemas.BifrostContext, key schem // Getting oauth2 token tokenSource, err := getAuthTokenSource(key) if err != nil { - return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err, schemas.Vertex) + return nil, 
providerUtils.NewBifrostOperationError("error creating auth token source", err) } token, err := tokenSource.Token() if err != nil { - return nil, providerUtils.NewBifrostOperationError("error getting token", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error getting token", err) } req.Header.Set("Authorization", "Bearer "+token.AccessToken) } @@ -1255,14 +1148,10 @@ func (provider *VertexProvider) Responses(ctx *schemas.BifrostContext, key schem if resp.StatusCode() == fasthttp.StatusUnauthorized || resp.StatusCode() == fasthttp.StatusForbidden { removeVertexClient(key.VertexKeyConfig.AuthCredentials.GetValue()) } - return nil, providerUtils.EnrichError(ctx, parseVertexError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - Model: request.Model, - RequestType: schemas.ResponsesRequest, - }), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, parseVertexError(resp), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } - responseBody, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, providerName, provider.logger) + responseBody, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, provider.logger) if decodeErr != nil { return nil, providerUtils.EnrichError(ctx, decodeErr, jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } @@ -1270,9 +1159,6 @@ func (provider *VertexProvider) Responses(ctx *schemas.BifrostContext, key schem respOwned = false return &schemas.BifrostResponsesResponse{ ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.ResponsesRequest, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerUtils.ExtractProviderResponseHeaders(resp), }, @@ -1287,16 +1173,9 @@ func (provider *VertexProvider) Responses(ctx *schemas.BifrostContext, key 
schem } response := geminiResponse.ToResponsesBifrostResponsesResponse() - response.ExtraFields.RequestType = schemas.ResponsesRequest - response.ExtraFields.Provider = providerName - response.ExtraFields.ModelRequested = request.Model response.ExtraFields.Latency = latency.Milliseconds() response.ExtraFields.ProviderResponseHeaders = providerUtils.ExtractProviderResponseHeaders(resp) - if request.Model != deployment { - response.ExtraFields.ModelDeployment = deployment - } - // Set raw response if enabled if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) { response.ExtraFields.RawResponse = rawResponse @@ -1314,52 +1193,33 @@ func (provider *VertexProvider) Responses(ctx *schemas.BifrostContext, key schem } response := chatResponse.ToBifrostResponsesResponse() - response.ExtraFields.RequestType = schemas.ResponsesRequest - response.ExtraFields.Provider = provider.GetProviderKey() - response.ExtraFields.ModelRequested = request.Model - if request.Model != deployment { - response.ExtraFields.ModelDeployment = deployment - } - return response, nil } } // ResponsesStream performs a streaming responses request to the Vertex API. 
func (provider *VertexProvider) ResponsesStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - - if key.VertexKeyConfig == nil { - return nil, providerUtils.NewConfigurationError("vertex key config is not set", providerName) - } - - deployment := provider.getModelDeployment(key, request.Model) - // strip google/ prefix if present - if after, ok := strings.CutPrefix(deployment, "google/"); ok { - deployment = after - } - - if schemas.IsAnthropicModel(deployment) { + if schemas.IsAnthropicModel(request.Model) { region := key.VertexKeyConfig.Region.GetValue() if region == "" { - return nil, providerUtils.NewConfigurationError("region is not set in key config", providerName) + return nil, providerUtils.NewConfigurationError("region is not set in key config") } projectID := key.VertexKeyConfig.ProjectID.GetValue() if projectID == "" { - return nil, providerUtils.NewConfigurationError("project ID is not set", providerName) + return nil, providerUtils.NewConfigurationError("project ID is not set") } - jsonBody, bifrostErr := getRequestBodyForAnthropicResponses(ctx, request, deployment, providerName, true, false, provider.networkConfig.BetaHeaderOverrides, provider.networkConfig.ExtraHeaders) + jsonBody, bifrostErr := getRequestBodyForAnthropicResponses(ctx, request, request.Model, true, false, provider.networkConfig.BetaHeaderOverrides, provider.networkConfig.ExtraHeaders) if bifrostErr != nil { return nil, bifrostErr } var url string if region == "global" { - url = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/anthropic/models/%s:streamRawPredict", projectID, deployment) + url = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/anthropic/models/%s:streamRawPredict", projectID, request.Model) } else { - url 
= fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:streamRawPredict", region, projectID, region, deployment) + url = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:streamRawPredict", region, projectID, region, request.Model) } // Prepare headers for Vertex Anthropic @@ -1372,22 +1232,14 @@ func (provider *VertexProvider) ResponsesStream(ctx *schemas.BifrostContext, pos // Adding authorization header tokenSource, err := getAuthTokenSource(key) if err != nil { - return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err) } token, err := tokenSource.Token() if err != nil { - return nil, providerUtils.NewBifrostOperationError("error getting token", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error getting token", err) } headers["Authorization"] = "Bearer " + token.AccessToken - postResponseConverter := func(response *schemas.BifrostResponsesStreamResponse) *schemas.BifrostResponsesStreamResponse { - response.ExtraFields.ModelRequested = request.Model - if request.Model != deployment { - response.ExtraFields.ModelDeployment = deployment - } - return response - } - // Use shared streaming logic from Anthropic return anthropic.HandleAnthropicResponsesStream( ctx, @@ -1401,23 +1253,18 @@ func (provider *VertexProvider) ResponsesStream(ctx *schemas.BifrostContext, pos providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse), provider.GetProviderKey(), postHookRunner, - postResponseConverter, + nil, provider.logger, - &providerUtils.RequestMetadata{ - Provider: provider.GetProviderKey(), - Model: request.Model, - RequestType: schemas.ResponsesStreamRequest, - }, ) - } else if schemas.IsGeminiModel(deployment) || schemas.IsAllDigitsASCII(deployment) { + } else if 
schemas.IsGeminiModel(request.Model) || schemas.IsAllDigitsASCII(request.Model) { region := key.VertexKeyConfig.Region.GetValue() if region == "" { - return nil, providerUtils.NewConfigurationError("region is not set in key config", providerName) + return nil, providerUtils.NewConfigurationError("region is not set in key config") } projectID := key.VertexKeyConfig.ProjectID.GetValue() if projectID == "" { - return nil, providerUtils.NewConfigurationError("project ID is not set", providerName) + return nil, providerUtils.NewConfigurationError("project ID is not set") } // Use Gemini-style streaming for Gemini models @@ -1429,12 +1276,11 @@ func (provider *VertexProvider) ResponsesStream(ctx *schemas.BifrostContext, pos if reqBody == nil { return nil, fmt.Errorf("responses input is not provided") } - reqBody.Model = deployment // Strip unsupported fields for Vertex Gemini stripVertexGeminiUnsupportedFields(reqBody) return reqBody, nil }, - provider.GetProviderKey()) + ) if bifrostErr != nil { return nil, bifrostErr } @@ -1447,12 +1293,12 @@ func (provider *VertexProvider) ResponsesStream(ctx *schemas.BifrostContext, pos // For custom/fine-tuned models, validate projectNumber is set projectNumber := key.VertexKeyConfig.ProjectNumber.GetValue() - if schemas.IsAllDigitsASCII(deployment) && projectNumber == "" { - return nil, providerUtils.NewConfigurationError("project number is not set for fine-tuned models", providerName) + if schemas.IsAllDigitsASCII(request.Model) && projectNumber == "" { + return nil, providerUtils.NewConfigurationError("project number is not set for fine-tuned models") } // Construct the URL for Gemini streaming - completeURL := getCompleteURLForGeminiEndpoint(deployment, region, projectID, projectNumber, ":streamGenerateContent") + completeURL := getCompleteURLForGeminiEndpoint(request.Model, region, projectID, projectNumber, ":streamGenerateContent") // Add alt=sse parameter if authQuery != "" { completeURL = fmt.Sprintf("%s?alt=sse&%s", 
completeURL, authQuery) @@ -1470,23 +1316,15 @@ func (provider *VertexProvider) ResponsesStream(ctx *schemas.BifrostContext, pos if authQuery == "" { tokenSource, err := getAuthTokenSource(key) if err != nil { - return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err) } token, err := tokenSource.Token() if err != nil { - return nil, providerUtils.NewBifrostOperationError("error getting token", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error getting token", err) } headers["Authorization"] = "Bearer " + token.AccessToken } - postResponseConverter := func(response *schemas.BifrostResponsesStreamResponse) *schemas.BifrostResponsesStreamResponse { - response.ExtraFields.ModelRequested = request.Model - if request.Model != deployment { - response.ExtraFields.ModelDeployment = deployment - } - return response - } - // Use shared streaming logic from Gemini return gemini.HandleGeminiResponsesStream( ctx, @@ -1500,7 +1338,7 @@ func (provider *VertexProvider) ResponsesStream(ctx *schemas.BifrostContext, pos provider.GetProviderKey(), request.Model, postHookRunner, - postResponseConverter, + nil, provider.logger, ) } else { @@ -1518,18 +1356,14 @@ func (provider *VertexProvider) ResponsesStream(ctx *schemas.BifrostContext, pos // All Vertex AI embedding models use the same response format regardless of the model type. // Returns a BifrostResponse containing the embedding(s) and any error that occurred. 
func (provider *VertexProvider) Embedding(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostEmbeddingRequest) (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - if key.VertexKeyConfig == nil { - return nil, providerUtils.NewConfigurationError("vertex key config is not set", providerName) - } projectID := key.VertexKeyConfig.ProjectID.GetValue() if projectID == "" { - return nil, providerUtils.NewConfigurationError("project ID is not set", providerName) + return nil, providerUtils.NewConfigurationError("project ID is not set") } region := key.VertexKeyConfig.Region.GetValue() if region == "" { - return nil, providerUtils.NewConfigurationError("region is not set in key config", providerName) + return nil, providerUtils.NewConfigurationError("region is not set in key config") } jsonBody, bifrostErr := providerUtils.CheckContextAndGetRequestBody( @@ -1538,19 +1372,14 @@ func (provider *VertexProvider) Embedding(ctx *schemas.BifrostContext, key schem func() (providerUtils.RequestBodyWithExtraParams, error) { return ToVertexEmbeddingRequest(request), nil }, - providerName) + ) if bifrostErr != nil { return nil, bifrostErr } - deployment := provider.getModelDeployment(key, request.Model) - - // Remove google/ prefix from deployment - deployment = strings.TrimPrefix(deployment, "google/") - // Build the native Vertex embedding API endpoint url := fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:predict", - region, projectID, region, deployment) + region, projectID, region, request.Model) // Create HTTP request for streaming req := fasthttp.AcquireRequest() @@ -1573,11 +1402,11 @@ func (provider *VertexProvider) Embedding(ctx *schemas.BifrostContext, key schem // Getting oauth2 token tokenSource, err := getAuthTokenSource(key) if err != nil { - return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err, 
schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err) } token, err := tokenSource.Token() if err != nil { - return nil, providerUtils.NewBifrostOperationError("error getting token", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error getting token", err) } req.Header.Set("Authorization", "Bearer "+token.AccessToken) @@ -1613,7 +1442,7 @@ func (provider *VertexProvider) Embedding(ctx *schemas.BifrostContext, key schem // Try to parse Vertex's error format var vertexError map[string]interface{} if err := sonic.Unmarshal(errBody, &vertexError); err != nil { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, schemas.Vertex), jsonBody, errBody, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err), jsonBody, errBody, provider.sendBackRawRequest, provider.sendBackRawResponse) } if errorObj, exists := vertexError["error"]; exists { @@ -1627,10 +1456,10 @@ func (provider *VertexProvider) Embedding(ctx *schemas.BifrostContext, key schem } } - return nil, providerUtils.EnrichError(ctx, providerUtils.NewProviderAPIError(errorMessage, nil, resp.StatusCode(), schemas.Vertex, nil, nil), jsonBody, errBody, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewProviderAPIError(errorMessage, nil, resp.StatusCode(), nil, nil), jsonBody, errBody, provider.sendBackRawRequest, provider.sendBackRawResponse) } - responseBody, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, providerName, provider.logger) + responseBody, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, provider.logger) if decodeErr != nil { return nil, providerUtils.EnrichError(ctx, decodeErr, 
jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } @@ -1638,9 +1467,6 @@ func (provider *VertexProvider) Embedding(ctx *schemas.BifrostContext, key schem respOwned = false return &schemas.BifrostEmbeddingResponse{ ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.EmbeddingRequest, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerUtils.ExtractProviderResponseHeaders(resp), }, @@ -1650,28 +1476,21 @@ func (provider *VertexProvider) Embedding(ctx *schemas.BifrostContext, key schem // Parse Vertex's native embedding response using typed response var vertexResponse VertexEmbeddingResponse if err := sonic.Unmarshal(responseBody, &vertexResponse); err != nil { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, schemas.Vertex), jsonBody, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err), jsonBody, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse) } // Use centralized Vertex converter bifrostResponse := vertexResponse.ToBifrostEmbeddingResponse() // Set ExtraFields - bifrostResponse.ExtraFields.Provider = providerName - bifrostResponse.ExtraFields.ModelRequested = request.Model - bifrostResponse.ExtraFields.RequestType = schemas.EmbeddingRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() bifrostResponse.ExtraFields.ProviderResponseHeaders = providerUtils.ExtractProviderResponseHeaders(resp) - if bifrostResponse.ExtraFields.ModelRequested != deployment { - bifrostResponse.ExtraFields.ModelDeployment = deployment - } - // Set raw response if enabled if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) { var rawResponseMap map[string]interface{} if err := 
sonic.Unmarshal(resp.Body(), &rawResponseMap); err != nil { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderRawResponseUnmarshal, err, providerName), jsonBody, resp.Body(), provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderRawResponseUnmarshal, err), jsonBody, resp.Body(), provider.sendBackRawRequest, provider.sendBackRawResponse) } bifrostResponse.ExtraFields.RawResponse = rawResponseMap } @@ -1686,30 +1505,23 @@ func (provider *VertexProvider) Speech(ctx *schemas.BifrostContext, key schemas. // Rerank performs a rerank request using Vertex Discovery Engine ranking API. func (provider *VertexProvider) Rerank(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostRerankRequest) (*schemas.BifrostRerankResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - - if key.VertexKeyConfig == nil { - return nil, providerUtils.NewConfigurationError("vertex key config is not set", providerName) - } - projectID := strings.TrimSpace(key.VertexKeyConfig.ProjectID.GetValue()) if projectID == "" { - return nil, providerUtils.NewConfigurationError("project ID is not set", providerName) + return nil, providerUtils.NewConfigurationError("project ID is not set") } options, err := getVertexRerankOptions(projectID, request.Params) if err != nil { - return nil, providerUtils.NewConfigurationError(err.Error(), providerName) + return nil, providerUtils.NewConfigurationError(err.Error()) } - modelDeployment := provider.getModelDeployment(key, request.Model) jsonBody, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, request, func() (providerUtils.RequestBodyWithExtraParams, error) { - return ToVertexRankRequest(request, modelDeployment, options) + return ToVertexRankRequest(request, options) }, - providerName) + ) if bifrostErr != nil { return nil, bifrostErr } @@ -1735,11 
+1547,11 @@ func (provider *VertexProvider) Rerank(ctx *schemas.BifrostContext, key schemas. tokenSource, err := getAuthTokenSource(key) if err != nil { - return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err, providerName) + return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err) } token, err := tokenSource.Token() if err != nil { - return nil, providerUtils.NewBifrostOperationError("error getting token", err, providerName) + return nil, providerUtils.NewBifrostOperationError("error getting token", err) } req.Header.Set("Authorization", "Bearer "+token.AccessToken) @@ -1767,11 +1579,7 @@ func (provider *VertexProvider) Rerank(ctx *schemas.BifrostContext, key schemas. } errorMessage := parseDiscoveryEngineErrorMessage(resp.Body()) - parsedError := parseVertexError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - Model: request.Model, - RequestType: schemas.RerankRequest, - }) + parsedError := parseVertexError(resp) if strings.TrimSpace(errorMessage) != "" { shouldOverride := parsedError == nil || @@ -1781,19 +1589,14 @@ func (provider *VertexProvider) Rerank(ctx *schemas.BifrostContext, key schemas. 
parsedError.Error.Message == schemas.ErrProviderResponseUnmarshal if shouldOverride { - parsedError = providerUtils.NewProviderAPIError(errorMessage, nil, resp.StatusCode(), providerName, nil, nil) - parsedError.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.RerankRequest, - } + parsedError = providerUtils.NewProviderAPIError(errorMessage, nil, resp.StatusCode(), nil, nil) } } return nil, providerUtils.EnrichError(ctx, parsedError, jsonBody, resp.Body(), provider.sendBackRawRequest, provider.sendBackRawResponse) } - responseBody, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, providerName, provider.logger) + responseBody, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, provider.logger) if decodeErr != nil { return nil, providerUtils.EnrichError(ctx, decodeErr, jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } @@ -1802,9 +1605,6 @@ func (provider *VertexProvider) Rerank(ctx *schemas.BifrostContext, key schemas. return &schemas.BifrostRerankResponse{ Model: request.Model, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.RerankRequest, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerUtils.ExtractProviderResponseHeaders(resp), }, @@ -1820,16 +1620,9 @@ func (provider *VertexProvider) Rerank(ctx *schemas.BifrostContext, key schemas. 
returnDocuments := request.Params != nil && request.Params.ReturnDocuments != nil && *request.Params.ReturnDocuments bifrostResponse, err := vertexResponse.ToBifrostRerankResponse(request.Documents, returnDocuments) if err != nil { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("error converting rerank response", err, providerName), jsonBody, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("error converting rerank response", err), jsonBody, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse) } - bifrostResponse.Model = request.Model - bifrostResponse.ExtraFields.Provider = providerName - bifrostResponse.ExtraFields.ModelRequested = request.Model - if request.Model != modelDeployment { - bifrostResponse.ExtraFields.ModelDeployment = modelDeployment - } - bifrostResponse.ExtraFields.RequestType = schemas.RerankRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() bifrostResponse.ExtraFields.ProviderResponseHeaders = providerUtils.ExtractProviderResponseHeaders(resp) @@ -1860,21 +1653,9 @@ func (provider *VertexProvider) TranscriptionStream(ctx *schemas.BifrostContext, } func (provider *VertexProvider) ImageGeneration(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostImageGenerationRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - - if key.VertexKeyConfig == nil { - return nil, providerUtils.NewConfigurationError("vertex key config is not set", providerName) - } - - deployment := provider.getModelDeployment(key, request.Model) - // strip google/ prefix if present - if after, ok := strings.CutPrefix(deployment, "google/"); ok { - deployment = after - } - // Validate model type before processing - if !schemas.IsGeminiModel(deployment) && !schemas.IsAllDigitsASCII(deployment) && 
!schemas.IsImagenModel(deployment) { - return nil, providerUtils.NewConfigurationError(fmt.Sprintf("image generation is only supported for Gemini and Imagen models, got: %s", deployment), providerName) + if !schemas.IsGeminiModel(request.Model) && !schemas.IsAllDigitsASCII(request.Model) && !schemas.IsImagenModel(request.Model) { + return nil, providerUtils.NewConfigurationError(fmt.Sprintf("image generation is only supported for Gemini and Imagen models, got: %s", request.Model)) } jsonBody, bifrostErr := providerUtils.CheckContextAndGetRequestBody( @@ -1885,13 +1666,12 @@ func (provider *VertexProvider) ImageGeneration(ctx *schemas.BifrostContext, key var extraParams map[string]interface{} var err error - if schemas.IsGeminiModel(deployment) || schemas.IsAllDigitsASCII(deployment) { + if schemas.IsGeminiModel(request.Model) || schemas.IsAllDigitsASCII(request.Model) { reqBody := gemini.ToGeminiImageGenerationRequest(request) if reqBody == nil { return nil, fmt.Errorf("image generation input is not provided") } extraParams = reqBody.GetExtraParams() - reqBody.Model = deployment // Strip unsupported fields for Vertex Gemini stripVertexGeminiUnsupportedFields(reqBody) // Marshal to JSON bytes, preserving key order @@ -1899,7 +1679,7 @@ func (provider *VertexProvider) ImageGeneration(ctx *schemas.BifrostContext, key if err != nil { return nil, fmt.Errorf("failed to marshal request body: %w", err) } - } else if schemas.IsImagenModel(deployment) { + } else if schemas.IsImagenModel(request.Model) { reqBody := gemini.ToImagenImageGenerationRequest(request) if reqBody == nil { return nil, fmt.Errorf("image generation input is not provided") @@ -1919,58 +1699,58 @@ func (provider *VertexProvider) ImageGeneration(ctx *schemas.BifrostContext, key } return &VertexRawRequestBody{RawBody: rawBody, ExtraParams: extraParams}, nil }, - provider.GetProviderKey()) + ) if bifrostErr != nil { return nil, bifrostErr } projectID := key.VertexKeyConfig.ProjectID.GetValue() if projectID 
== "" { - return nil, providerUtils.NewConfigurationError("project ID is not set", providerName) + return nil, providerUtils.NewConfigurationError("project ID is not set") } region := key.VertexKeyConfig.Region.GetValue() if region == "" { - return nil, providerUtils.NewConfigurationError("region is not set in key config", providerName) + return nil, providerUtils.NewConfigurationError("region is not set in key config") } // Auth query is used for fine-tuned models to pass the API key in the query string authQuery := "" // Determine the URL based on model type var completeURL string - if schemas.IsAllDigitsASCII(deployment) { + if schemas.IsAllDigitsASCII(request.Model) { // Custom Fine-tuned models use OpenAPI endpoint projectNumber := key.VertexKeyConfig.ProjectNumber.GetValue() if projectNumber == "" { - return nil, providerUtils.NewConfigurationError("project number is not set for fine-tuned models", providerName) + return nil, providerUtils.NewConfigurationError("project number is not set for fine-tuned models") } if value := key.Value.GetValue(); value != "" { authQuery = fmt.Sprintf("key=%s", url.QueryEscape(value)) } if region == "global" { - completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1beta1/projects/%s/locations/global/endpoints/%s:generateContent", projectNumber, deployment) + completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1beta1/projects/%s/locations/global/endpoints/%s:generateContent", projectNumber, request.Model) } else { - completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/%s:generateContent", region, projectNumber, region, deployment) + completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/%s:generateContent", region, projectNumber, region, request.Model) } - } else if schemas.IsImagenModel(deployment) { + } else if schemas.IsImagenModel(request.Model) { // Imagen models are published models, use 
publishers/google/models path if value := key.Value.GetValue(); value != "" { authQuery = fmt.Sprintf("key=%s", url.QueryEscape(value)) } if region == "global" { - completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:predict", projectID, deployment) + completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:predict", projectID, request.Model) } else { - completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:predict", region, projectID, region, deployment) + completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:predict", region, projectID, region, request.Model) } - } else if schemas.IsGeminiModel(deployment) { + } else if schemas.IsGeminiModel(request.Model) { if value := key.Value.GetValue(); value != "" { authQuery = fmt.Sprintf("key=%s", url.QueryEscape(value)) } if region == "global" { - completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:generateContent", projectID, deployment) + completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:generateContent", projectID, request.Model) } else { - completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:generateContent", region, projectID, region, deployment) + completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:generateContent", region, projectID, region, request.Model) } } @@ -1997,11 +1777,11 @@ func (provider *VertexProvider) ImageGeneration(ctx *schemas.BifrostContext, key // Getting oauth2 token tokenSource, err := getAuthTokenSource(key) if err != nil { - return nil, 
providerUtils.NewBifrostOperationError("error creating auth token source", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err) } token, err := tokenSource.Token() if err != nil { - return nil, providerUtils.NewBifrostOperationError("error getting token", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error getting token", err) } req.Header.Set("Authorization", "Bearer "+token.AccessToken) } @@ -2030,14 +1810,10 @@ func (provider *VertexProvider) ImageGeneration(ctx *schemas.BifrostContext, key if resp.StatusCode() == fasthttp.StatusUnauthorized || resp.StatusCode() == fasthttp.StatusForbidden { removeVertexClient(key.VertexKeyConfig.AuthCredentials.GetValue()) } - return nil, providerUtils.EnrichError(ctx, parseVertexError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - Model: request.Model, - RequestType: schemas.ImageGenerationRequest, - }), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, parseVertexError(resp), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } - responseBody, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, providerName, provider.logger) + responseBody, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, provider.logger) if decodeErr != nil { return nil, providerUtils.EnrichError(ctx, decodeErr, jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } @@ -2045,16 +1821,13 @@ func (provider *VertexProvider) ImageGeneration(ctx *schemas.BifrostContext, key respOwned = false return &schemas.BifrostImageGenerationResponse{ ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.ImageGenerationRequest, Latency: latency.Milliseconds(), ProviderResponseHeaders: 
providerUtils.ExtractProviderResponseHeaders(resp), }, }, nil } - if schemas.IsGeminiModel(deployment) || schemas.IsAllDigitsASCII(deployment) { + if schemas.IsGeminiModel(request.Model) || schemas.IsAllDigitsASCII(request.Model) { geminiResponse := gemini.GenerateContentResponse{} rawRequest, rawResponse, bifrostErr := providerUtils.HandleProviderResponse(responseBody, &geminiResponse, jsonBody, providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)) @@ -2067,12 +1840,6 @@ func (provider *VertexProvider) ImageGeneration(ctx *schemas.BifrostContext, key return nil, providerUtils.EnrichError(ctx, err, jsonBody, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse) } - response.ExtraFields.RequestType = schemas.ImageGenerationRequest - response.ExtraFields.Provider = providerName - response.ExtraFields.ModelRequested = request.Model - if request.Model != deployment { - response.ExtraFields.ModelDeployment = deployment - } response.ExtraFields.Latency = latency.Milliseconds() response.ExtraFields.ProviderResponseHeaders = providerUtils.ExtractProviderResponseHeaders(resp) @@ -2095,12 +1862,6 @@ func (provider *VertexProvider) ImageGeneration(ctx *schemas.BifrostContext, key } response := imagenResponse.ToBifrostImageGenerationResponse() - response.ExtraFields.RequestType = schemas.ImageGenerationRequest - response.ExtraFields.Provider = providerName - response.ExtraFields.ModelRequested = request.Model - if request.Model != deployment { - response.ExtraFields.ModelDeployment = deployment - } response.ExtraFields.Latency = latency.Milliseconds() response.ExtraFields.ProviderResponseHeaders = providerUtils.ExtractProviderResponseHeaders(resp) @@ -2124,20 +1885,9 @@ func (provider *VertexProvider) ImageGenerationStream(ctx *schemas.BifrostContex // ImageEdit edits images for the given input text(s) using Vertex AI. 
// Returns a BifrostResponse containing the images and any error that occurred. func (provider *VertexProvider) ImageEdit(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostImageEditRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - - if key.VertexKeyConfig == nil { - return nil, providerUtils.NewConfigurationError("vertex key config is not set", providerName) - } - - deployment := provider.getModelDeployment(key, request.Model) - if after, ok := strings.CutPrefix(deployment, "google/"); ok { - deployment = after - } - // Validate model type before processing - if !schemas.IsGeminiModel(deployment) && !schemas.IsAllDigitsASCII(deployment) && !schemas.IsImagenModel(deployment) { - return nil, providerUtils.NewConfigurationError(fmt.Sprintf("image edit is only supported for Gemini and Imagen models, got: %s", deployment), providerName) + if !schemas.IsGeminiModel(request.Model) && !schemas.IsAllDigitsASCII(request.Model) && !schemas.IsImagenModel(request.Model) { + return nil, providerUtils.NewConfigurationError(fmt.Sprintf("image edit is only supported for Gemini and Imagen models, got: %s", request.Model)) } jsonBody, bifrostErr := providerUtils.CheckContextAndGetRequestBody( @@ -2148,13 +1898,12 @@ func (provider *VertexProvider) ImageEdit(ctx *schemas.BifrostContext, key schem var extraParams map[string]interface{} var err error - if schemas.IsGeminiModel(deployment) || schemas.IsAllDigitsASCII(deployment) { + if schemas.IsGeminiModel(request.Model) || schemas.IsAllDigitsASCII(request.Model) { reqBody := gemini.ToGeminiImageEditRequest(request) if reqBody == nil { return nil, fmt.Errorf("image edit input is not provided") } extraParams = reqBody.GetExtraParams() - reqBody.Model = deployment // Strip unsupported fields for Vertex Gemini stripVertexGeminiUnsupportedFields(reqBody) // Marshal to JSON bytes, preserving key order @@ -2162,7 +1911,7 @@ func (provider 
*VertexProvider) ImageEdit(ctx *schemas.BifrostContext, key schem if err != nil { return nil, fmt.Errorf("failed to marshal request body: %w", err) } - } else if schemas.IsImagenModel(deployment) { + } else if schemas.IsImagenModel(request.Model) { reqBody := gemini.ToImagenImageEditRequest(request) if reqBody == nil { return nil, fmt.Errorf("image edit input is not provided") @@ -2182,19 +1931,19 @@ func (provider *VertexProvider) ImageEdit(ctx *schemas.BifrostContext, key schem } return &VertexRawRequestBody{RawBody: rawBody, ExtraParams: extraParams}, nil }, - provider.GetProviderKey()) + ) if bifrostErr != nil { return nil, bifrostErr } projectID := key.VertexKeyConfig.ProjectID.GetValue() if projectID == "" { - return nil, providerUtils.NewConfigurationError("project ID is not set", providerName) + return nil, providerUtils.NewConfigurationError("project ID is not set") } region := key.VertexKeyConfig.Region.GetValue() if region == "" { - return nil, providerUtils.NewConfigurationError("region is not set in key config", providerName) + return nil, providerUtils.NewConfigurationError("region is not set in key config") } authQuery := "" @@ -2203,27 +1952,27 @@ func (provider *VertexProvider) ImageEdit(ctx *schemas.BifrostContext, key schem } var completeURL string - if schemas.IsAllDigitsASCII(deployment) { + if schemas.IsAllDigitsASCII(request.Model) { projectNumber := key.VertexKeyConfig.ProjectNumber.GetValue() if projectNumber == "" { - return nil, providerUtils.NewConfigurationError("project number is not set for fine-tuned models", providerName) + return nil, providerUtils.NewConfigurationError("project number is not set for fine-tuned models") } if region == "global" { - completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1beta1/projects/%s/locations/global/endpoints/%s:generateContent", projectNumber, deployment) + completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1beta1/projects/%s/locations/global/endpoints/%s:generateContent", 
projectNumber, request.Model) } else { - completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/%s:generateContent", region, projectNumber, region, deployment) + completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/%s:generateContent", region, projectNumber, region, request.Model) } - } else if schemas.IsImagenModel(deployment) { + } else if schemas.IsImagenModel(request.Model) { if region == "global" { - completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:predict", projectID, deployment) + completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:predict", projectID, request.Model) } else { - completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:predict", region, projectID, region, deployment) + completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:predict", region, projectID, region, request.Model) } - } else if schemas.IsGeminiModel(deployment) { + } else if schemas.IsGeminiModel(request.Model) { if region == "global" { - completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:generateContent", projectID, deployment) + completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:generateContent", projectID, request.Model) } else { - completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:generateContent", region, projectID, region, deployment) + completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:generateContent", region, projectID, region, 
request.Model) } } @@ -2249,11 +1998,11 @@ func (provider *VertexProvider) ImageEdit(ctx *schemas.BifrostContext, key schem // Getting oauth2 token tokenSource, err := getAuthTokenSource(key) if err != nil { - return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err) } token, err := tokenSource.Token() if err != nil { - return nil, providerUtils.NewBifrostOperationError("error getting token", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error getting token", err) } req.Header.Set("Authorization", "Bearer "+token.AccessToken) } @@ -2281,14 +2030,10 @@ func (provider *VertexProvider) ImageEdit(ctx *schemas.BifrostContext, key schem if resp.StatusCode() == fasthttp.StatusUnauthorized || resp.StatusCode() == fasthttp.StatusForbidden { removeVertexClient(key.VertexKeyConfig.AuthCredentials.GetValue()) } - return nil, providerUtils.EnrichError(ctx, parseVertexError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - Model: request.Model, - RequestType: schemas.ImageEditRequest, - }), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, parseVertexError(resp), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } - responseBody, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, providerName, provider.logger) + responseBody, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, provider.logger) if decodeErr != nil { return nil, providerUtils.EnrichError(ctx, decodeErr, jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } @@ -2296,16 +2041,13 @@ func (provider *VertexProvider) ImageEdit(ctx *schemas.BifrostContext, key schem respOwned = false return &schemas.BifrostImageGenerationResponse{ ExtraFields: 
schemas.BifrostResponseExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.ImageEditRequest, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerUtils.ExtractProviderResponseHeaders(resp), }, }, nil } - if schemas.IsGeminiModel(deployment) || schemas.IsAllDigitsASCII(deployment) { + if schemas.IsGeminiModel(request.Model) || schemas.IsAllDigitsASCII(request.Model) { geminiResponse := gemini.GenerateContentResponse{} rawRequest, rawResponse, bifrostErr := providerUtils.HandleProviderResponse(responseBody, &geminiResponse, jsonBody, providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)) @@ -2318,12 +2060,6 @@ func (provider *VertexProvider) ImageEdit(ctx *schemas.BifrostContext, key schem return nil, providerUtils.EnrichError(ctx, err, jsonBody, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse) } - response.ExtraFields.RequestType = schemas.ImageEditRequest - response.ExtraFields.Provider = providerName - response.ExtraFields.ModelRequested = request.Model - if request.Model != deployment { - response.ExtraFields.ModelDeployment = deployment - } response.ExtraFields.Latency = latency.Milliseconds() response.ExtraFields.ProviderResponseHeaders = providerUtils.ExtractProviderResponseHeaders(resp) @@ -2346,12 +2082,6 @@ func (provider *VertexProvider) ImageEdit(ctx *schemas.BifrostContext, key schem } response := imagenResponse.ToBifrostImageGenerationResponse() - response.ExtraFields.RequestType = schemas.ImageEditRequest - response.ExtraFields.Provider = providerName - response.ExtraFields.ModelRequested = request.Model - if request.Model != deployment { - response.ExtraFields.ModelDeployment = deployment - } response.ExtraFields.Latency = latency.Milliseconds() response.ExtraFields.ProviderResponseHeaders = providerUtils.ExtractProviderResponseHeaders(resp) @@ -2382,18 +2112,9 @@ func (provider 
*VertexProvider) ImageVariation(ctx *schemas.BifrostContext, key func (provider *VertexProvider) VideoGeneration(ctx *schemas.BifrostContext, key schemas.Key, bifrostReq *schemas.BifrostVideoGenerationRequest) (*schemas.BifrostVideoGenerationResponse, *schemas.BifrostError) { providerName := provider.GetProviderKey() - if key.VertexKeyConfig == nil { - return nil, providerUtils.NewConfigurationError("vertex key config is not set", providerName) - } - - deployment := provider.getModelDeployment(key, bifrostReq.Model) - if after, ok := strings.CutPrefix(deployment, "google/"); ok { - deployment = after - } - // Only Gemini models support video generation in Vertex - if !schemas.IsVeoModel(deployment) && !schemas.IsAllDigitsASCII(deployment) { - return nil, providerUtils.NewConfigurationError(fmt.Sprintf("video generation is only supported for Veo models in Vertex, got: %s", deployment), providerName) + if !schemas.IsVeoModel(bifrostReq.Model) && !schemas.IsAllDigitsASCII(bifrostReq.Model) { + return nil, providerUtils.NewConfigurationError(fmt.Sprintf("video generation is only supported for Veo models in Vertex, got: %s", bifrostReq.Model)) } // Convert Bifrost request to Gemini format (reusing Gemini converters) @@ -2403,7 +2124,6 @@ func (provider *VertexProvider) VideoGeneration(ctx *schemas.BifrostContext, key func() (providerUtils.RequestBodyWithExtraParams, error) { return gemini.ToGeminiVideoGenerationRequest(bifrostReq) }, - providerName, ) if bifrostErr != nil { return nil, bifrostErr @@ -2411,12 +2131,12 @@ func (provider *VertexProvider) VideoGeneration(ctx *schemas.BifrostContext, key projectID := key.VertexKeyConfig.ProjectID.GetValue() if projectID == "" { - return nil, providerUtils.NewConfigurationError("project ID is not set", providerName) + return nil, providerUtils.NewConfigurationError("project ID is not set") } region := key.VertexKeyConfig.Region.GetValue() if region == "" { - return nil, providerUtils.NewConfigurationError("region is not set 
in key config", providerName) + return nil, providerUtils.NewConfigurationError("region is not set in key config") } // Auth query is used to pass the API key in the query string @@ -2427,12 +2147,12 @@ func (provider *VertexProvider) VideoGeneration(ctx *schemas.BifrostContext, key // For custom/fine-tuned models, validate projectNumber is set projectNumber := key.VertexKeyConfig.ProjectNumber.GetValue() - if schemas.IsAllDigitsASCII(deployment) && projectNumber == "" { - return nil, providerUtils.NewConfigurationError("project number is not set for fine-tuned models", providerName) + if schemas.IsAllDigitsASCII(bifrostReq.Model) && projectNumber == "" { + return nil, providerUtils.NewConfigurationError("project number is not set for fine-tuned models") } // Construct the URL for Gemini video generation using predictLongRunning - completeURL := getCompleteURLForGeminiEndpoint(deployment, region, projectID, projectNumber, ":predictLongRunning") + completeURL := getCompleteURLForGeminiEndpoint(bifrostReq.Model, region, projectID, projectNumber, ":predictLongRunning") // Create HTTP request req := fasthttp.AcquireRequest() @@ -2451,11 +2171,11 @@ func (provider *VertexProvider) VideoGeneration(ctx *schemas.BifrostContext, key } else { tokenSource, err := getAuthTokenSource(key) if err != nil { - return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err) } token, err := tokenSource.Token() if err != nil { - return nil, providerUtils.NewBifrostOperationError("error getting token", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error getting token", err) } req.Header.Set("Authorization", "Bearer "+token.AccessToken) } @@ -2475,17 +2195,13 @@ func (provider *VertexProvider) VideoGeneration(ctx *schemas.BifrostContext, key if resp.StatusCode() == fasthttp.StatusUnauthorized || resp.StatusCode() == 
fasthttp.StatusForbidden { removeVertexClient(key.VertexKeyConfig.AuthCredentials.GetValue()) } - return nil, providerUtils.EnrichError(ctx, parseVertexError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - Model: bifrostReq.Model, - RequestType: schemas.VideoGenerationRequest, - }), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, parseVertexError(resp), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } // Parse response body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } var operation gemini.GenerateVideosOperation @@ -2501,12 +2217,6 @@ func (provider *VertexProvider) VideoGeneration(ctx *schemas.BifrostContext, key bifrostResp.ID = providerUtils.AddVideoIDProviderSuffix(bifrostResp.ID, providerName) bifrostResp.ExtraFields.Latency = latency.Milliseconds() - bifrostResp.ExtraFields.Provider = providerName - bifrostResp.ExtraFields.ModelRequested = bifrostReq.Model - if bifrostReq.Model != deployment { - bifrostResp.ExtraFields.ModelDeployment = deployment - } - bifrostResp.ExtraFields.RequestType = schemas.VideoGenerationRequest bifrostResp.ExtraFields.ProviderResponseHeaders = providerUtils.ExtractProviderResponseHeaders(resp) if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { @@ -2522,18 +2232,12 @@ func (provider *VertexProvider) VideoGeneration(ctx *schemas.BifrostContext, key // VideoRetrieve retrieves the status of a video generation operation. // Uses the fetchPredictOperation endpoint for Vertex AI. 
func (provider *VertexProvider) VideoRetrieve(ctx *schemas.BifrostContext, key schemas.Key, bifrostReq *schemas.BifrostVideoRetrieveRequest) (*schemas.BifrostVideoGenerationResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - - if key.VertexKeyConfig == nil { - return nil, providerUtils.NewConfigurationError("vertex key config is not set", providerName) - } - sendBackRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) region := key.VertexKeyConfig.Region.GetValue() if region == "" { - return nil, providerUtils.NewConfigurationError("region is not set in key config", providerName) + return nil, providerUtils.NewConfigurationError("region is not set in key config") } // Construct base URL based on region @@ -2549,12 +2253,12 @@ func (provider *VertexProvider) VideoRetrieve(ctx *schemas.BifrostContext, key s // projects/PROJECT_ID/locations/REGION/publishers/google/models/MODEL_ID/operations/OPERATION_ID // We need to extract the model path from it to construct the fetchPredictOperation endpoint // Extract: projects/.../models/MODEL_ID from the operation name - taskID := providerUtils.StripVideoIDProviderSuffix(bifrostReq.ID, providerName) + taskID := providerUtils.StripVideoIDProviderSuffix(bifrostReq.ID, provider.GetProviderKey()) var modelPath string if idx := strings.Index(taskID, "/operations/"); idx != -1 { modelPath = taskID[:idx] } else { - return nil, providerUtils.NewBifrostOperationError("invalid operation ID format", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid operation ID format", nil) } // Construct the URL: https://REGION-aiplatform.googleapis.com/v1/{modelPath}:fetchPredictOperation @@ -2569,7 +2273,7 @@ func (provider *VertexProvider) VideoRetrieve(ctx *schemas.BifrostContext, key s // Create request body with operation name (using sjson to avoid map 
marshaling) jsonBody, err := providerUtils.SetJSONField([]byte(`{}`), "operationName", taskID) if err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to marshal request", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to marshal request", err) } // Create HTTP request @@ -2589,11 +2293,11 @@ func (provider *VertexProvider) VideoRetrieve(ctx *schemas.BifrostContext, key s } else { tokenSource, err := getAuthTokenSource(key) if err != nil { - return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err) } token, err := tokenSource.Token() if err != nil { - return nil, providerUtils.NewBifrostOperationError("error getting token", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error getting token", err) } req.Header.Set("Authorization", "Bearer "+token.AccessToken) } @@ -2613,10 +2317,7 @@ func (provider *VertexProvider) VideoRetrieve(ctx *schemas.BifrostContext, key s if resp.StatusCode() == fasthttp.StatusUnauthorized || resp.StatusCode() == fasthttp.StatusForbidden { removeVertexClient(key.VertexKeyConfig.AuthCredentials.GetValue()) } - return nil, providerUtils.EnrichError(ctx, parseVertexError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - RequestType: schemas.VideoRetrieveRequest, - }), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, parseVertexError(resp), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } // Parse response @@ -2630,10 +2331,8 @@ func (provider *VertexProvider) VideoRetrieve(ctx *schemas.BifrostContext, key s if bifrostErr != nil { return nil, bifrostErr } - bifrostResp.ID = providerUtils.AddVideoIDProviderSuffix(bifrostResp.ID, providerName) + bifrostResp.ID = providerUtils.AddVideoIDProviderSuffix(bifrostResp.ID, provider.GetProviderKey()) 
bifrostResp.ExtraFields.Latency = latency.Milliseconds() - bifrostResp.ExtraFields.Provider = providerName - bifrostResp.ExtraFields.RequestType = schemas.VideoRetrieveRequest bifrostResp.ExtraFields.ProviderResponseHeaders = providerUtils.ExtractProviderResponseHeaders(resp) if sendBackRawResponse { @@ -2647,9 +2346,8 @@ func (provider *VertexProvider) VideoRetrieve(ctx *schemas.BifrostContext, key s // First retrieves the video status to get the URL, then downloads the content. // Handles both regular URLs and data URLs (base64-encoded videos). func (provider *VertexProvider) VideoDownload(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostVideoDownloadRequest) (*schemas.BifrostVideoDownloadResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() if request == nil || request.ID == "" { - return nil, providerUtils.NewBifrostOperationError("video_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("video_id is required", nil) } // Retrieve operation first to get the video URL bifrostVideoRetrieveRequest := &schemas.BifrostVideoRetrieveRequest{ @@ -2663,12 +2361,10 @@ func (provider *VertexProvider) VideoDownload(ctx *schemas.BifrostContext, key s if videoResp.Status != schemas.VideoStatusCompleted { return nil, providerUtils.NewBifrostOperationError( fmt.Sprintf("video not ready, current status: %s", videoResp.Status), - nil, - providerName, - ) + nil) } if len(videoResp.Videos) == 0 { - return nil, providerUtils.NewBifrostOperationError("video URL not available", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("video URL not available", nil) } var content []byte var latency time.Duration @@ -2680,7 +2376,7 @@ func (provider *VertexProvider) VideoDownload(ctx *schemas.BifrostContext, key s startTime := time.Now() decoded, err := base64.StdEncoding.DecodeString(*videoResp.Videos[0].Base64Data) if err != nil { - return nil, 
providerUtils.NewBifrostOperationError("failed to decode base64 video data", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to decode base64 video data", err) } content = decoded contentType = videoResp.Videos[0].ContentType @@ -2710,11 +2406,11 @@ func (provider *VertexProvider) VideoDownload(ctx *schemas.BifrostContext, key s } else { tokenSource, err := getAuthTokenSource(key) if err != nil { - return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err) } token, err := tokenSource.Token() if err != nil { - return nil, providerUtils.NewBifrostOperationError("error getting token", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error getting token", err) } req.Header.Set("Authorization", "Bearer "+token.AccessToken) } @@ -2729,19 +2425,17 @@ func (provider *VertexProvider) VideoDownload(ctx *schemas.BifrostContext, key s if resp.StatusCode() != fasthttp.StatusOK { return nil, providerUtils.NewBifrostOperationError( fmt.Sprintf("failed to download video: HTTP %d", resp.StatusCode()), - nil, - providerName, - ) + nil) } body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } contentType = string(resp.Header.ContentType()) content = append([]byte(nil), body...) 
providerResponseHeaders = providerUtils.ExtractProviderResponseHeaders(resp) } else { - return nil, providerUtils.NewBifrostOperationError("invalid video output type", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid video output type", nil) } bifrostResp := &schemas.BifrostVideoDownloadResponse{ @@ -2751,8 +2445,6 @@ func (provider *VertexProvider) VideoDownload(ctx *schemas.BifrostContext, key s } bifrostResp.ExtraFields.Latency = latency.Milliseconds() - bifrostResp.ExtraFields.Provider = providerName - bifrostResp.ExtraFields.RequestType = schemas.VideoDownloadRequest bifrostResp.ExtraFields.ProviderResponseHeaders = providerResponseHeaders return bifrostResp, nil @@ -2790,19 +2482,6 @@ func stripVertexGeminiUnsupportedFields(requestBody *gemini.GeminiGenerationRequ } } -func (provider *VertexProvider) getModelDeployment(key schemas.Key, model string) string { - if key.VertexKeyConfig == nil { - return model - } - - if key.VertexKeyConfig.Deployments != nil { - if deployment, ok := key.VertexKeyConfig.Deployments[model]; ok { - return deployment - } - } - return model -} - // BatchCreate is not supported by Vertex AI provider. func (provider *VertexProvider) BatchCreate(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostBatchCreateRequest) (*schemas.BifrostBatchCreateResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchCreateRequest, provider.GetProviderKey()) @@ -2862,25 +2541,13 @@ func (provider *VertexProvider) FileContent(_ *schemas.BifrostContext, _ []schem // CountTokens counts the number of tokens in the provided content using Vertex AI's countTokens endpoint. // Supports Gemini models with both text and image content. 
func (provider *VertexProvider) CountTokens(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostCountTokensResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - - if key.VertexKeyConfig == nil { - return nil, providerUtils.NewConfigurationError("vertex key config is not set", providerName) - } - - deployment := provider.getModelDeployment(key, request.Model) - // strip google/ prefix if present - if after, ok := strings.CutPrefix(deployment, "google/"); ok { - deployment = after - } - var ( jsonBody []byte bifrostErr *schemas.BifrostError ) - if schemas.IsAnthropicModel(deployment) { - jsonBody, bifrostErr = getRequestBodyForAnthropicResponses(ctx, request, deployment, providerName, false, true, provider.networkConfig.BetaHeaderOverrides, provider.networkConfig.ExtraHeaders) + if schemas.IsAnthropicModel(request.Model) { + jsonBody, bifrostErr = getRequestBodyForAnthropicResponses(ctx, request, request.Model, false, true, provider.networkConfig.BetaHeaderOverrides, provider.networkConfig.ExtraHeaders) if bifrostErr != nil { return nil, bifrostErr } @@ -2891,7 +2558,6 @@ func (provider *VertexProvider) CountTokens(ctx *schemas.BifrostContext, key sch func() (providerUtils.RequestBodyWithExtraParams, error) { return gemini.ToGeminiResponsesRequest(request), nil }, - providerName, ) if bifrostErr != nil { return nil, bifrostErr @@ -2909,38 +2575,38 @@ func (provider *VertexProvider) CountTokens(ctx *schemas.BifrostContext, key sch projectID := key.VertexKeyConfig.ProjectID.GetValue() if projectID == "" { - return nil, providerUtils.NewConfigurationError("project ID is not set", providerName) + return nil, providerUtils.NewConfigurationError("project ID is not set") } region := key.VertexKeyConfig.Region.GetValue() if region == "" { - return nil, providerUtils.NewConfigurationError("region is not set in key config", providerName) + return nil, providerUtils.NewConfigurationError("region is 
not set in key config") } authQuery := "" var completeURL string - if schemas.IsAnthropicModel(deployment) { + if schemas.IsAnthropicModel(request.Model) { if region == "global" { completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/anthropic/models/count-tokens:rawPredict", projectID) } else { completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/count-tokens:rawPredict", region, projectID, region) } - } else if schemas.IsGeminiModel(deployment) || schemas.IsAllDigitsASCII(deployment) { + } else if schemas.IsGeminiModel(request.Model) || schemas.IsAllDigitsASCII(request.Model) { if key.Value.GetValue() != "" { authQuery = fmt.Sprintf("key=%s", url.QueryEscape(key.Value.GetValue())) } projectNumber := key.VertexKeyConfig.ProjectNumber.GetValue() - if schemas.IsAllDigitsASCII(deployment) && projectNumber == "" { - return nil, providerUtils.NewConfigurationError("project number is not set for fine-tuned models", providerName) + if schemas.IsAllDigitsASCII(request.Model) && projectNumber == "" { + return nil, providerUtils.NewConfigurationError("project number is not set for fine-tuned models") } - completeURL = getCompleteURLForGeminiEndpoint(deployment, region, projectID, projectNumber, ":countTokens") + completeURL = getCompleteURLForGeminiEndpoint(request.Model, region, projectID, projectNumber, ":countTokens") } if completeURL == "" { - return nil, providerUtils.NewConfigurationError(fmt.Sprintf("count tokens is not supported for model/deployment: %s", deployment), providerName) + return nil, providerUtils.NewConfigurationError(fmt.Sprintf("count tokens is not supported for model: %s", request.Model)) } req := fasthttp.AcquireRequest() @@ -2962,11 +2628,11 @@ func (provider *VertexProvider) CountTokens(ctx *schemas.BifrostContext, key sch } else { tokenSource, err := getAuthTokenSource(key) if err != nil { - return nil, 
providerUtils.NewBifrostOperationError("error creating auth token source", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err) } token, err := tokenSource.Token() if err != nil { - return nil, providerUtils.NewBifrostOperationError("error getting token", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error getting token", err) } req.Header.Set("Authorization", "Bearer "+token.AccessToken) } @@ -2994,14 +2660,10 @@ func (provider *VertexProvider) CountTokens(ctx *schemas.BifrostContext, key sch if resp.StatusCode() == fasthttp.StatusUnauthorized || resp.StatusCode() == fasthttp.StatusForbidden { removeVertexClient(key.VertexKeyConfig.AuthCredentials.GetValue()) } - return nil, providerUtils.EnrichError(ctx, parseVertexError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - Model: request.Model, - RequestType: schemas.CountTokensRequest, - }), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, parseVertexError(resp), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } - responseBody, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, providerName, provider.logger) + responseBody, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, provider.logger) if decodeErr != nil { return nil, providerUtils.EnrichError(ctx, decodeErr, jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } @@ -3009,16 +2671,13 @@ func (provider *VertexProvider) CountTokens(ctx *schemas.BifrostContext, key sch respOwned = false return &schemas.BifrostCountTokensResponse{ ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.CountTokensRequest, Latency: latency.Milliseconds(), ProviderResponseHeaders: 
providerUtils.ExtractProviderResponseHeaders(resp), }, }, nil } - if schemas.IsAnthropicModel(deployment) { + if schemas.IsAnthropicModel(request.Model) { anthropicResponse := &anthropic.AnthropicCountTokensResponse{} rawRequest, rawResponse, bifrostErr := providerUtils.HandleProviderResponse(responseBody, anthropicResponse, jsonBody, providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)) @@ -3027,12 +2686,6 @@ func (provider *VertexProvider) CountTokens(ctx *schemas.BifrostContext, key sch } response := anthropicResponse.ToBifrostCountTokensResponse(request.Model) - response.ExtraFields.RequestType = schemas.CountTokensRequest - response.ExtraFields.Provider = providerName - response.ExtraFields.ModelRequested = request.Model - if request.Model != deployment { - response.ExtraFields.ModelDeployment = deployment - } response.ExtraFields.Latency = latency.Milliseconds() response.ExtraFields.ProviderResponseHeaders = providerUtils.ExtractProviderResponseHeaders(resp) @@ -3055,12 +2708,6 @@ func (provider *VertexProvider) CountTokens(ctx *schemas.BifrostContext, key sch } response := vertexResponse.ToBifrostCountTokensResponse(request.Model) - response.ExtraFields.RequestType = schemas.CountTokensRequest - response.ExtraFields.Provider = providerName - response.ExtraFields.ModelRequested = request.Model - if request.Model != deployment { - response.ExtraFields.ModelDeployment = deployment - } response.ExtraFields.Latency = latency.Milliseconds() response.ExtraFields.ProviderResponseHeaders = providerUtils.ExtractProviderResponseHeaders(resp) @@ -3125,14 +2772,9 @@ func (provider *VertexProvider) Passthrough( key schemas.Key, req *schemas.BifrostPassthroughRequest, ) (*schemas.BifrostPassthroughResponse, *schemas.BifrostError) { - - if key.VertexKeyConfig == nil { - return nil, providerUtils.NewBifrostOperationError("vertex key config is not set", nil, schemas.Vertex) - } - 
projectID := strings.TrimSpace(key.VertexKeyConfig.ProjectID.GetValue()) if projectID == "" { - return nil, providerUtils.NewConfigurationError("project ID is not set", provider.GetProviderKey()) + return nil, providerUtils.NewConfigurationError("project ID is not set") } keyRegion := key.VertexKeyConfig.Region.GetValue() @@ -3198,12 +2840,12 @@ func (provider *VertexProvider) Passthrough( tokenSource, err := getAuthTokenSource(key) if err != nil { removeVertexClient(key.VertexKeyConfig.AuthCredentials.GetValue()) - return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err) } token, err := tokenSource.Token() if err != nil { removeVertexClient(key.VertexKeyConfig.AuthCredentials.GetValue()) - return nil, providerUtils.NewBifrostOperationError("error getting token", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error getting token", err) } fasthttpReq.Header.Set("Authorization", "Bearer "+token.AccessToken) } @@ -3253,7 +2895,7 @@ func (provider *VertexProvider) Passthrough( body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to decode response body", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("failed to decode response body", err) } for k := range headers { if strings.EqualFold(k, "Content-Encoding") || strings.EqualFold(k, "Content-Length") { @@ -3267,9 +2909,6 @@ func (provider *VertexProvider) Passthrough( } bifrostResponse.ExtraFields.ProviderResponseHeaders = headers - bifrostResponse.ExtraFields.Provider = provider.GetProviderKey() - bifrostResponse.ExtraFields.RequestType = schemas.PassthroughRequest - bifrostResponse.ExtraFields.ModelRequested = req.Model bifrostResponse.ExtraFields.Latency = latency.Milliseconds() if providerUtils.ShouldSendBackRawRequest(ctx, 
provider.sendBackRawRequest) { @@ -3285,13 +2924,9 @@ func (provider *VertexProvider) PassthroughStream( key schemas.Key, req *schemas.BifrostPassthroughRequest, ) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { - if key.VertexKeyConfig == nil { - return nil, providerUtils.NewBifrostOperationError("vertex key config is not set", nil, schemas.Vertex) - } - projectID := strings.TrimSpace(key.VertexKeyConfig.ProjectID.GetValue()) if projectID == "" { - return nil, providerUtils.NewConfigurationError("project ID is not set", provider.GetProviderKey()) + return nil, providerUtils.NewConfigurationError("project ID is not set") } keyRegion := key.VertexKeyConfig.Region.GetValue() @@ -3357,13 +2992,13 @@ func (provider *VertexProvider) PassthroughStream( if err != nil { removeVertexClient(key.VertexKeyConfig.AuthCredentials.GetValue()) providerUtils.ReleaseStreamingResponse(resp) - return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err) } token, err := tokenSource.Token() if err != nil { removeVertexClient(key.VertexKeyConfig.AuthCredentials.GetValue()) providerUtils.ReleaseStreamingResponse(resp) - return nil, providerUtils.NewBifrostOperationError("error getting token", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error getting token", err) } fasthttpReq.Header.Set("Authorization", "Bearer "+token.AccessToken) } @@ -3403,9 +3038,9 @@ func (provider *VertexProvider) PassthroughStream( } } if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) { - return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, provider.GetProviderKey()) + return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err) } - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, 
provider.GetProviderKey()) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err) } if resp.StatusCode() == fasthttp.StatusUnauthorized || resp.StatusCode() == fasthttp.StatusForbidden { @@ -3420,9 +3055,7 @@ func (provider *VertexProvider) PassthroughStream( providerUtils.ReleaseStreamingResponse(resp) return nil, providerUtils.NewBifrostOperationError( "provider returned an empty stream body", - fmt.Errorf("provider returned an empty stream body"), - provider.GetProviderKey(), - ) + fmt.Errorf("provider returned an empty stream body")) } // Set stream idle timeout from provider config. @@ -3435,11 +3068,7 @@ func (provider *VertexProvider) PassthroughStream( // Cancellation must close the raw stream to unblock reads. stopCancellation := providerUtils.SetupStreamCancellation(ctx, rawBodyStream, provider.logger) - extraFields := schemas.BifrostResponseExtraFields{ - Provider: provider.GetProviderKey(), - ModelRequested: req.Model, - RequestType: schemas.PassthroughStreamRequest, - } + extraFields := schemas.BifrostResponseExtraFields{} statusCode := resp.StatusCode() if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { @@ -3450,9 +3079,9 @@ func (provider *VertexProvider) PassthroughStream( go func() { defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, ch, provider.GetProviderKey(), req.Model, schemas.PassthroughStreamRequest, provider.logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, ch, provider.logger) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, ch, provider.GetProviderKey(), req.Model, schemas.PassthroughStreamRequest, provider.logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, ch, provider.logger) } close(ch) }() @@ -3501,7 +3130,7 @@ func (provider *VertexProvider) PassthroughStream( } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, 
true) extraFields.Latency = time.Since(startTime).Milliseconds() - providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, ch, schemas.PassthroughStreamRequest, provider.GetProviderKey(), req.Model, provider.logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, ch, provider.logger) return } } diff --git a/core/providers/vllm/utils.go b/core/providers/vllm/utils.go index d2cefce786..ab6d694938 100644 --- a/core/providers/vllm/utils.go +++ b/core/providers/vllm/utils.go @@ -13,9 +13,6 @@ func HandleVLLMResponse[T any](responseBody []byte, response *T, requestBody []b return rawRequest, rawResponse, bifrostErr } if err := sonic.Unmarshal(responseBody, &errorResp); err == nil && errorResp.Error != nil && errorResp.Error.Message != "" { - errorResp.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: schemas.VLLM, - } return rawRequest, rawResponse, &errorResp } return rawRequest, rawResponse, nil diff --git a/core/providers/vllm/vllm.go b/core/providers/vllm/vllm.go index eab5828af9..d36aa7354b 100644 --- a/core/providers/vllm/vllm.go +++ b/core/providers/vllm/vllm.go @@ -76,9 +76,7 @@ func (provider *VLLMProvider) baseURLOrError(key schemas.Key) (string, *schemas. 
if u == "" { return "", providerUtils.NewBifrostOperationError( "no base URL configured: set vllm_key_config.url on the key", - nil, - provider.GetProviderKey(), - ) + nil) } return u, nil } @@ -246,9 +244,6 @@ func (provider *VLLMProvider) Responses(ctx *schemas.BifrostContext, key schemas return nil, err } response := chatResponse.ToBifrostResponsesResponse() - response.ExtraFields.RequestType = schemas.ResponsesRequest - response.ExtraFields.Provider = provider.GetProviderKey() - response.ExtraFields.ModelRequested = request.Model return response, nil } @@ -314,12 +309,14 @@ func (provider *VLLMProvider) callVLLMRerankEndpoint( statusCode := resp.StatusCode() if statusCode != fasthttp.StatusOK { - return nil, nil, nil, nil, statusCode, latency, openai.ParseOpenAIError(resp, schemas.RerankRequest, provider.GetProviderKey(), request.Model) + rawErrBody := append([]byte(nil), resp.Body()...) + return nil, nil, nil, rawErrBody, statusCode, latency, openai.ParseOpenAIError(resp) } body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, nil, nil, nil, statusCode, latency, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, provider.GetProviderKey()) + rawErrBody := append([]byte(nil), resp.Body()...) + return nil, nil, nil, rawErrBody, statusCode, latency, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) @@ -336,16 +333,12 @@ func (provider *VLLMProvider) callVLLMRerankEndpoint( // Rerank performs a rerank request to vLLM's API. 
func (provider *VLLMProvider) Rerank(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostRerankRequest) (*schemas.BifrostRerankResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToVLLMRerankRequest(request), nil - }, - providerName, - ) + }) if bifrostErr != nil { return nil, bifrostErr } @@ -358,6 +351,9 @@ func (provider *VLLMProvider) Rerank(ctx *schemas.BifrostContext, key schemas.Ke resolvedPath = "/" + resolvedPath } + sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) + sendBackRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) + responsePayload, rawRequest, rawResponse, responseBody, statusCode, latency, bifrostErr := provider.callVLLMRerankEndpoint(ctx, key, request, resolvedPath, jsonData) if bifrostErr != nil && !hasPathOverride && isRerankFallbackStatus(statusCode) { var fallbackLatency time.Duration @@ -365,7 +361,7 @@ func (provider *VLLMProvider) Rerank(ctx *schemas.BifrostContext, key schemas.Ke latency += fallbackLatency } if bifrostErr != nil { - return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonData, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonData, responseBody, sendBackRawRequest, sendBackRawResponse) } returnDocuments := request.Params != nil && request.Params.ReturnDocuments != nil && *request.Params.ReturnDocuments @@ -373,19 +369,16 @@ func (provider *VLLMProvider) Rerank(ctx *schemas.BifrostContext, key schemas.Ke if err != nil { return nil, providerUtils.EnrichError( ctx, - providerUtils.NewBifrostOperationError("error converting rerank response", err, providerName), + providerUtils.NewBifrostOperationError("error converting rerank response", err), jsonData, 
responseBody, - provider.sendBackRawRequest, - provider.sendBackRawResponse, + sendBackRawRequest, + sendBackRawResponse, ) } // Keep requested model as the canonical model in Bifrost response. bifrostResponse.Model = request.Model - bifrostResponse.ExtraFields.Provider = providerName - bifrostResponse.ExtraFields.ModelRequested = request.Model - bifrostResponse.ExtraFields.RequestType = schemas.RerankRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { @@ -435,7 +428,7 @@ func (provider *VLLMProvider) TranscriptionStream(ctx *schemas.BifrostContext, p // Use centralized converter reqBody := openai.ToOpenAITranscriptionRequest(request) if reqBody == nil { - return nil, providerUtils.NewBifrostOperationError("transcription input is not provided", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("transcription input is not provided", nil) } reqBody.Stream = schemas.Ptr(true) @@ -491,9 +484,9 @@ func (provider *VLLMProvider) TranscriptionStream(ctx *schemas.BifrostContext, p } } if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) { - return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, providerName) + return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err) } - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err) } // Store provider response headers in context before status check so error responses also forward them @@ -502,7 +495,7 @@ func (provider *VLLMProvider) TranscriptionStream(ctx *schemas.BifrostContext, p // Check for HTTP errors if resp.StatusCode() != fasthttp.StatusOK { defer providerUtils.ReleaseStreamingResponse(resp) - return nil, openai.ParseOpenAIError(resp, schemas.TranscriptionStreamRequest, 
providerName, request.Model) + return nil, openai.ParseOpenAIError(resp) } // Large payload streaming passthrough — pipe raw upstream SSE to client @@ -521,9 +514,9 @@ func (provider *VLLMProvider) TranscriptionStream(ctx *schemas.BifrostContext, p go func() { defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.TranscriptionStreamRequest, logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, logger) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.TranscriptionStreamRequest, logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, logger) } close(responseChan) }() @@ -563,7 +556,7 @@ func (provider *VLLMProvider) TranscriptionStream(ctx *schemas.BifrostContext, p } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) logger.Warn("Error reading stream: %v", readErr) - providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, schemas.TranscriptionStreamRequest, providerName, request.Model, logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, logger) } break } @@ -580,11 +573,6 @@ func (provider *VLLMProvider) TranscriptionStream(ctx *schemas.BifrostContext, p _, _, bifrostErr = HandleVLLMResponse(dataBytes, &response, nil, false, false) if bifrostErr != nil { - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.TranscriptionStreamRequest, - } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, bifrostErr, body.Bytes(), dataBytes, false, providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)), responseChan, logger) return @@ 
-603,11 +591,8 @@ func (provider *VLLMProvider) TranscriptionStream(ctx *schemas.BifrostContext, p } response.ExtraFields = schemas.BifrostResponseExtraFields{ - RequestType: schemas.TranscriptionStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - ChunkIndex: chunkIndex, - Latency: time.Since(lastChunkTime).Milliseconds(), + ChunkIndex: chunkIndex, + Latency: time.Since(lastChunkTime).Milliseconds(), } lastChunkTime = time.Now() diff --git a/core/providers/xai/errors.go b/core/providers/xai/errors.go index 78b22463e0..38a46888a8 100644 --- a/core/providers/xai/errors.go +++ b/core/providers/xai/errors.go @@ -15,7 +15,7 @@ type XAIErrorResponse struct { // ParseXAIError parses xAI-specific error responses. // xAI returns errors in format: {"code": "...", "error": "..."} // Unlike OpenAI which uses: {"error": {"message": "...", "type": "...", "code": "..."}} -func ParseXAIError(resp *fasthttp.Response, requestType schemas.RequestType, providerName schemas.ModelProvider, model string) *schemas.BifrostError { +func ParseXAIError(resp *fasthttp.Response) *schemas.BifrostError { // Try to parse xAI error format var xaiErr XAIErrorResponse bifrostErr := providerUtils.HandleProviderAPIError(resp, &xaiErr) @@ -35,10 +35,5 @@ func ParseXAIError(resp *fasthttp.Response, requestType schemas.RequestType, pro } } - // Set ExtraFields individually to preserve RawResponse from HandleProviderAPIError - bifrostErr.ExtraFields.Provider = providerName - bifrostErr.ExtraFields.ModelRequested = model - bifrostErr.ExtraFields.RequestType = requestType - return bifrostErr } diff --git a/core/providers/xai/xai.go b/core/providers/xai/xai.go index ecf285c379..6ec5ebda6b 100644 --- a/core/providers/xai/xai.go +++ b/core/providers/xai/xai.go @@ -65,7 +65,7 @@ func (provider *XAIProvider) GetProviderKey() schemas.ModelProvider { // ListModels performs a list models request to xAI's API. 
func (provider *XAIProvider) ListModels(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { if provider.networkConfig.BaseURL == "" { - return nil, providerUtils.NewConfigurationError("base_url is not set", provider.GetProviderKey()) + return nil, providerUtils.NewConfigurationError("base_url is not set") } return openai.HandleOpenAIListModelsRequest( ctx, diff --git a/core/schemas/account.go b/core/schemas/account.go index 7bb67067bc..3dfb16a0b4 100644 --- a/core/schemas/account.go +++ b/core/schemas/account.go @@ -126,11 +126,12 @@ type Key struct { Models WhiteList `json:"models"` // List of models this key can access BlacklistedModels BlackList `json:"blacklisted_models"` // List of models this key cannot access Weight float64 `json:"weight"` // Weight for load balancing between multiple keys + Aliases KeyAliases `json:"aliases,omitempty"` // Mapping of model identifiers to inference profiles AzureKeyConfig *AzureKeyConfig `json:"azure_key_config,omitempty"` // Azure-specific key configuration VertexKeyConfig *VertexKeyConfig `json:"vertex_key_config,omitempty"` // Vertex-specific key configuration BedrockKeyConfig *BedrockKeyConfig `json:"bedrock_key_config,omitempty"` // AWS Bedrock-specific key configuration - ReplicateKeyConfig *ReplicateKeyConfig `json:"replicate_key_config,omitempty"` // Replicate-specific key configuration VLLMKeyConfig *VLLMKeyConfig `json:"vllm_key_config,omitempty"` // vLLM-specific key configuration + ReplicateKeyConfig *ReplicateKeyConfig `json:"replicate_key_config,omitempty"` // Replicate-specific key configuration OllamaKeyConfig *OllamaKeyConfig `json:"ollama_key_config,omitempty"` // Ollama-specific key configuration SGLKeyConfig *SGLKeyConfig `json:"sgl_key_config,omitempty"` // SGLang-specific key configuration Enabled *bool `json:"enabled,omitempty"` // Whether the key is active (default:true) @@ -140,6 +141,48 @@ type Key 
struct { Description string `json:"description,omitempty"` // Description of key } +type KeyAliases map[string]string + +func (ka KeyAliases) Validate() error { + seen := make(map[string]struct{}, len(ka)) + for from, to := range ka { + if strings.TrimSpace(from) == "" { + return fmt.Errorf("alias source cannot be empty") + } + if strings.TrimSpace(to) == "" { + return fmt.Errorf("alias target for %q cannot be empty", from) + } + if strings.TrimSpace(from) != from { + return fmt.Errorf("alias source %q cannot have leading or trailing whitespace", from) + } + if strings.TrimSpace(to) != to { + return fmt.Errorf("alias target for %q cannot have leading or trailing whitespace", from) + } + normalized := strings.ToLower(from) + if _, ok := seen[normalized]; ok { + return fmt.Errorf("duplicate alias source %q (case-insensitive)", from) + } + seen[normalized] = struct{}{} + } + return nil +} + +func (ka KeyAliases) Resolve(model string) string { + if ka == nil { + return model + } + if alias, ok := ka[model]; ok { + return alias + } + // Fall back to case-insensitive lookup for consistency with WhiteList.Contains + for k, v := range ka { + if strings.EqualFold(k, model) { + return v + } + } + return model +} + type AzureAuthType string const ( @@ -150,9 +193,8 @@ const ( // AzureKeyConfig represents the Azure-specific configuration. // It contains Azure-specific settings required for service access and deployment management. 
type AzureKeyConfig struct { - Endpoint EnvVar `json:"endpoint"` // Azure service endpoint URL - Deployments map[string]string `json:"deployments,omitempty"` // Mapping of model names to deployment names - APIVersion *EnvVar `json:"api_version,omitempty"` // Azure API version to use; defaults to "2024-10-21" + Endpoint EnvVar `json:"endpoint"` // Azure service endpoint URL + APIVersion *EnvVar `json:"api_version,omitempty"` // Azure API version to use; defaults to "2024-10-21" ClientID *EnvVar `json:"client_id,omitempty"` // Azure client ID for authentication ClientSecret *EnvVar `json:"client_secret,omitempty"` // Azure client secret for authentication @@ -163,11 +205,10 @@ type AzureKeyConfig struct { // VertexKeyConfig represents the Vertex-specific configuration. // It contains Vertex-specific settings required for authentication and service access. type VertexKeyConfig struct { - ProjectID EnvVar `json:"project_id"` - ProjectNumber EnvVar `json:"project_number"` - Region EnvVar `json:"region"` - AuthCredentials EnvVar `json:"auth_credentials"` - Deployments map[string]string `json:"deployments,omitempty"` // Mapping of model identifiers to inference profiles + ProjectID EnvVar `json:"project_id"` + ProjectNumber EnvVar `json:"project_number"` + Region EnvVar `json:"region"` + AuthCredentials EnvVar `json:"auth_credentials"` } // NOTE: To use Vertex IAM role authentication, set AuthCredentials to empty string. 
@@ -198,17 +239,12 @@ type BedrockKeyConfig struct { ExternalID *EnvVar `json:"external_id,omitempty"` RoleSessionName *EnvVar `json:"session_name,omitempty"` - Deployments map[string]string `json:"deployments,omitempty"` // Mapping of model identifiers to inference profiles - BatchS3Config *BatchS3Config `json:"batch_s3_config,omitempty"` // S3 bucket configuration for batch operations + BatchS3Config *BatchS3Config `json:"batch_s3_config,omitempty"` // S3 bucket configuration for batch operations } // NOTE: To use Bedrock IAM role authentication, set both AccessKey and SecretKey to empty strings. // To use Bedrock API Key authentication, set Value in Key struct instead. -type ReplicateKeyConfig struct { - Deployments map[string]string `json:"deployments,omitempty"` // Mapping of model identifiers to deployment names -} - // VLLMKeyConfig represents the vLLM-specific key configuration. // It allows each key to target a different vLLM server URL and model name, // enabling per-key routing and round-robin load balancing across multiple vLLM instances. @@ -217,6 +253,12 @@ type VLLMKeyConfig struct { ModelName string `json:"model_name"` // Exact model name served on this VLLM instance (used for key selection) } +// ReplicateKeyConfig represents the Replicate-specific key configuration. +// It contains Replicate-specific settings required for authentication and service access. +type ReplicateKeyConfig struct { + UseDeploymentsEndpoint bool `json:"use_deployments_endpoint"` // Whether to use the deployments endpoint instead of the models endpoint +} + // OllamaKeyConfig represents the Ollama-specific key configuration. // It allows each key to target a different Ollama server URL, // enabling per-key routing and round-robin load balancing across multiple Ollama instances. @@ -249,4 +291,4 @@ type Account interface { // This includes network settings, authentication details, and other provider-specific // configurations. 
GetConfigForProvider(providerKey ModelProvider) (*ProviderConfig, error) -} \ No newline at end of file +} diff --git a/core/schemas/bifrost.go b/core/schemas/bifrost.go index f8ed758d51..69fcaf9894 100644 --- a/core/schemas/bifrost.go +++ b/core/schemas/bifrost.go @@ -209,7 +209,7 @@ const ( BifrostContextKeyTraceCompleter BifrostContextKey = "bifrost-trace-completer" // func() (callback to complete trace after streaming - set by tracing middleware) BifrostContextKeyPostHookSpanFinalizer BifrostContextKey = "bifrost-posthook-span-finalizer" // func(context.Context) (callback to finalize post-hook spans after streaming - set by bifrost) BifrostContextKeyAccumulatorID BifrostContextKey = "bifrost-accumulator-id" // string (ID for streaming accumulator lookup - set by tracer for accumulator operations) - BifrostContextKeyHasEmittedMessageDelta BifrostContextKey = "bifrost-has-emitted-message-delta" // bool (tracks whether message_delta was already emitted during streaming - avoids duplicates) + BifrostContextKeyHasEmittedMessageDelta BifrostContextKey = "bifrost-has-emitted-message-delta" // bool (tracks whether message_delta was already emitted during streaming - avoids duplicates) BifrostContextKeySkipDBUpdate BifrostContextKey = "bifrost-skip-db-update" // bool (set by bifrost - DO NOT SET THIS MANUALLY)) BifrostContextKeyGovernancePluginName BifrostContextKey = "governance-plugin-name" // string (name of the governance plugin that processed the request - set by bifrost) BifrostContextKeyIsEnterprise BifrostContextKey = "is-enterprise" // bool (set by bifrost - DO NOT SET THIS MANUALLY)) @@ -222,7 +222,7 @@ const ( BifrostContextKeyRoutingEnginesUsed BifrostContextKey = "bifrost-routing-engines-used" // []string (set by bifrost - DO NOT SET THIS MANUALLY) - list of routing engines used ("routing-rule", "governance", "loadbalancing", etc.) 
BifrostContextKeyRoutingEngineLogs BifrostContextKey = "bifrost-routing-engine-logs" // []RoutingEngineLogEntry (set by bifrost - DO NOT SET THIS MANUALLY) - list of routing engine log entries BifrostContextKeyTransportPluginLogs BifrostContextKey = "bifrost-transport-plugin-logs" // []PluginLogEntry (transport-layer plugin logs accumulated during HTTP transport hooks) - BifrostContextKeyTransportPostHookCompleter BifrostContextKey = "bifrost-transport-posthook-completer" // func() (callback to run HTTPTransportPostHook after streaming - set by transport interceptor middleware) + BifrostContextKeyTransportPostHookCompleter BifrostContextKey = "bifrost-transport-posthook-completer" // func() (callback to run HTTPTransportPostHook after streaming - set by transport interceptor middleware) BifrostContextKeySkipPluginPipeline BifrostContextKey = "bifrost-skip-plugin-pipeline" // bool - skip plugin pipeline for the request BifrostIsAsyncRequest BifrostContextKey = "bifrost-is-async-request" // bool (set by bifrost - DO NOT SET THIS MANUALLY)) - whether the request is an async request (only used in gateway) BifrostContextKeyRequestHeaders BifrostContextKey = "bifrost-request-headers" // map[string]string (all request headers with lowercased keys) @@ -798,6 +798,213 @@ func (r *BifrostResponse) GetExtraFields() *BifrostResponseExtraFields { return &BifrostResponseExtraFields{} } +func (r *BifrostResponse) PopulateExtraFields(requestType RequestType, provider ModelProvider, originalModelRequested string, resolvedModelUsed string) { + if r == nil { + return + } + resolvedModel := resolvedModelUsed + if resolvedModel == "" { + resolvedModel = originalModelRequested + } + switch { + case r.ListModelsResponse != nil: + r.ListModelsResponse.ExtraFields.RequestType = requestType + r.ListModelsResponse.ExtraFields.Provider = provider + r.ListModelsResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.ListModelsResponse.ExtraFields.ResolvedModelUsed = 
resolvedModel + case r.TextCompletionResponse != nil: + r.TextCompletionResponse.ExtraFields.RequestType = requestType + r.TextCompletionResponse.ExtraFields.Provider = provider + r.TextCompletionResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.TextCompletionResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.ChatResponse != nil: + r.ChatResponse.ExtraFields.RequestType = requestType + r.ChatResponse.ExtraFields.Provider = provider + r.ChatResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.ChatResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.ResponsesResponse != nil: + r.ResponsesResponse.ExtraFields.RequestType = requestType + r.ResponsesResponse.ExtraFields.Provider = provider + r.ResponsesResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.ResponsesResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.ResponsesStreamResponse != nil: + r.ResponsesStreamResponse.ExtraFields.RequestType = requestType + r.ResponsesStreamResponse.ExtraFields.Provider = provider + r.ResponsesStreamResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.ResponsesStreamResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.CountTokensResponse != nil: + r.CountTokensResponse.ExtraFields.RequestType = requestType + r.CountTokensResponse.ExtraFields.Provider = provider + r.CountTokensResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.CountTokensResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.EmbeddingResponse != nil: + r.EmbeddingResponse.ExtraFields.RequestType = requestType + r.EmbeddingResponse.ExtraFields.Provider = provider + r.EmbeddingResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.EmbeddingResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.RerankResponse != nil: + r.RerankResponse.ExtraFields.RequestType = requestType + r.RerankResponse.ExtraFields.Provider = provider + 
r.RerankResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.RerankResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.SpeechResponse != nil: + r.SpeechResponse.ExtraFields.RequestType = requestType + r.SpeechResponse.ExtraFields.Provider = provider + r.SpeechResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.SpeechResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.SpeechStreamResponse != nil: + r.SpeechStreamResponse.ExtraFields.RequestType = requestType + r.SpeechStreamResponse.ExtraFields.Provider = provider + r.SpeechStreamResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.SpeechStreamResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.TranscriptionResponse != nil: + r.TranscriptionResponse.ExtraFields.RequestType = requestType + r.TranscriptionResponse.ExtraFields.Provider = provider + r.TranscriptionResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.TranscriptionResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.TranscriptionStreamResponse != nil: + r.TranscriptionStreamResponse.ExtraFields.RequestType = requestType + r.TranscriptionStreamResponse.ExtraFields.Provider = provider + r.TranscriptionStreamResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.TranscriptionStreamResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.ImageGenerationResponse != nil: + r.ImageGenerationResponse.ExtraFields.RequestType = requestType + r.ImageGenerationResponse.ExtraFields.Provider = provider + r.ImageGenerationResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.ImageGenerationResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.ImageGenerationStreamResponse != nil: + r.ImageGenerationStreamResponse.ExtraFields.RequestType = requestType + r.ImageGenerationStreamResponse.ExtraFields.Provider = provider + r.ImageGenerationStreamResponse.ExtraFields.OriginalModelRequested = 
originalModelRequested + r.ImageGenerationStreamResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.VideoGenerationResponse != nil: + r.VideoGenerationResponse.ExtraFields.RequestType = requestType + r.VideoGenerationResponse.ExtraFields.Provider = provider + r.VideoGenerationResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.VideoGenerationResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.VideoDownloadResponse != nil: + r.VideoDownloadResponse.ExtraFields.RequestType = requestType + r.VideoDownloadResponse.ExtraFields.Provider = provider + r.VideoDownloadResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.VideoDownloadResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.VideoListResponse != nil: + r.VideoListResponse.ExtraFields.RequestType = requestType + r.VideoListResponse.ExtraFields.Provider = provider + r.VideoListResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.VideoListResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.VideoDeleteResponse != nil: + r.VideoDeleteResponse.ExtraFields.RequestType = requestType + r.VideoDeleteResponse.ExtraFields.Provider = provider + r.VideoDeleteResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.VideoDeleteResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.FileUploadResponse != nil: + r.FileUploadResponse.ExtraFields.RequestType = requestType + r.FileUploadResponse.ExtraFields.Provider = provider + r.FileUploadResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.FileUploadResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.FileListResponse != nil: + r.FileListResponse.ExtraFields.RequestType = requestType + r.FileListResponse.ExtraFields.Provider = provider + r.FileListResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.FileListResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.FileRetrieveResponse != nil: + 
r.FileRetrieveResponse.ExtraFields.RequestType = requestType + r.FileRetrieveResponse.ExtraFields.Provider = provider + r.FileRetrieveResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.FileRetrieveResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.FileDeleteResponse != nil: + r.FileDeleteResponse.ExtraFields.RequestType = requestType + r.FileDeleteResponse.ExtraFields.Provider = provider + r.FileDeleteResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.FileDeleteResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.FileContentResponse != nil: + r.FileContentResponse.ExtraFields.RequestType = requestType + r.FileContentResponse.ExtraFields.Provider = provider + r.FileContentResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.FileContentResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.BatchCreateResponse != nil: + r.BatchCreateResponse.ExtraFields.RequestType = requestType + r.BatchCreateResponse.ExtraFields.Provider = provider + r.BatchCreateResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.BatchCreateResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.BatchListResponse != nil: + r.BatchListResponse.ExtraFields.RequestType = requestType + r.BatchListResponse.ExtraFields.Provider = provider + r.BatchListResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.BatchListResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.BatchRetrieveResponse != nil: + r.BatchRetrieveResponse.ExtraFields.RequestType = requestType + r.BatchRetrieveResponse.ExtraFields.Provider = provider + r.BatchRetrieveResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.BatchRetrieveResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.BatchCancelResponse != nil: + r.BatchCancelResponse.ExtraFields.RequestType = requestType + r.BatchCancelResponse.ExtraFields.Provider = provider + 
r.BatchCancelResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.BatchCancelResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.BatchDeleteResponse != nil: + r.BatchDeleteResponse.ExtraFields.RequestType = requestType + r.BatchDeleteResponse.ExtraFields.Provider = provider + r.BatchDeleteResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.BatchDeleteResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.BatchResultsResponse != nil: + r.BatchResultsResponse.ExtraFields.RequestType = requestType + r.BatchResultsResponse.ExtraFields.Provider = provider + r.BatchResultsResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.BatchResultsResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.ContainerCreateResponse != nil: + r.ContainerCreateResponse.ExtraFields.RequestType = requestType + r.ContainerCreateResponse.ExtraFields.Provider = provider + r.ContainerCreateResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.ContainerCreateResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.ContainerListResponse != nil: + r.ContainerListResponse.ExtraFields.RequestType = requestType + r.ContainerListResponse.ExtraFields.Provider = provider + r.ContainerListResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.ContainerListResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.ContainerRetrieveResponse != nil: + r.ContainerRetrieveResponse.ExtraFields.RequestType = requestType + r.ContainerRetrieveResponse.ExtraFields.Provider = provider + r.ContainerRetrieveResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.ContainerRetrieveResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.ContainerDeleteResponse != nil: + r.ContainerDeleteResponse.ExtraFields.RequestType = requestType + r.ContainerDeleteResponse.ExtraFields.Provider = provider + r.ContainerDeleteResponse.ExtraFields.OriginalModelRequested = 
originalModelRequested + r.ContainerDeleteResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.ContainerFileCreateResponse != nil: + r.ContainerFileCreateResponse.ExtraFields.RequestType = requestType + r.ContainerFileCreateResponse.ExtraFields.Provider = provider + r.ContainerFileCreateResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.ContainerFileCreateResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.ContainerFileListResponse != nil: + r.ContainerFileListResponse.ExtraFields.RequestType = requestType + r.ContainerFileListResponse.ExtraFields.Provider = provider + r.ContainerFileListResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.ContainerFileListResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.ContainerFileRetrieveResponse != nil: + r.ContainerFileRetrieveResponse.ExtraFields.RequestType = requestType + r.ContainerFileRetrieveResponse.ExtraFields.Provider = provider + r.ContainerFileRetrieveResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.ContainerFileRetrieveResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.ContainerFileContentResponse != nil: + r.ContainerFileContentResponse.ExtraFields.RequestType = requestType + r.ContainerFileContentResponse.ExtraFields.Provider = provider + r.ContainerFileContentResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.ContainerFileContentResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.ContainerFileDeleteResponse != nil: + r.ContainerFileDeleteResponse.ExtraFields.RequestType = requestType + r.ContainerFileDeleteResponse.ExtraFields.Provider = provider + r.ContainerFileDeleteResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.ContainerFileDeleteResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.PassthroughResponse != nil: + r.PassthroughResponse.ExtraFields.RequestType = requestType + r.PassthroughResponse.ExtraFields.Provider = 
provider + r.PassthroughResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.PassthroughResponse.ExtraFields.ResolvedModelUsed = resolvedModel + } +} + // BifrostMCPResponse is the response struct for all MCP responses. // only ONE of the following fields should be set: // - ChatMessage @@ -812,10 +1019,10 @@ type BifrostMCPResponse struct { type BifrostResponseExtraFields struct { RequestType RequestType `json:"request_type"` Provider ModelProvider `json:"provider,omitempty"` - ModelRequested string `json:"model_requested,omitempty"` - ModelDeployment string `json:"model_deployment,omitempty"` // only present for providers which use model deployments (e.g. Azure, Bedrock) - Latency int64 `json:"latency"` // in milliseconds (for streaming responses this will be each chunk latency, and the last chunk latency will be the total latency) - ChunkIndex int `json:"chunk_index"` // used for streaming responses to identify the chunk index, will be 0 for non-streaming responses + OriginalModelRequested string `json:"original_model_requested,omitempty"` // the model alias the caller sent in the request + ResolvedModelUsed string `json:"resolved_model_used,omitempty"` // the actual provider API identifier used (equals OriginalModelRequested when no alias mapping exists) + Latency int64 `json:"latency"` // in milliseconds (for streaming responses this will be each chunk latency, and the last chunk latency will be the total latency) + ChunkIndex int `json:"chunk_index"` // used for streaming responses to identify the chunk index, will be 0 for non-streaming responses RawRequest interface{} `json:"raw_request,omitempty"` RawResponse interface{} `json:"raw_response,omitempty"` CacheDebug *BifrostCacheDebug `json:"cache_debug,omitempty"` @@ -909,6 +1116,20 @@ type BifrostError struct { ExtraFields BifrostErrorExtraFields `json:"extra_fields"` } +func (e *BifrostError) PopulateExtraFields(requestType RequestType, provider ModelProvider, originalModelRequested 
string, resolvedModelUsed string) { + if e == nil { + return + } + e.ExtraFields.RequestType = requestType + e.ExtraFields.Provider = provider + e.ExtraFields.OriginalModelRequested = originalModelRequested + if resolvedModelUsed != "" { + e.ExtraFields.ResolvedModelUsed = resolvedModelUsed + } else { + e.ExtraFields.ResolvedModelUsed = originalModelRequested + } +} + // StreamControl represents stream control options. type StreamControl struct { LogError *bool `json:"log_error,omitempty"` // Optional: Controls logging of error @@ -982,11 +1203,12 @@ func (e *ErrorField) UnmarshalJSON(data []byte) error { // BifrostErrorExtraFields contains additional fields in an error response. type BifrostErrorExtraFields struct { - Provider ModelProvider `json:"provider,omitempty"` - ModelRequested string `json:"model_requested,omitempty"` - RequestType RequestType `json:"request_type,omitempty"` - RawRequest interface{} `json:"raw_request,omitempty"` - RawResponse interface{} `json:"raw_response,omitempty"` - LiteLLMCompat bool `json:"litellm_compat,omitempty"` - KeyStatuses []KeyStatus `json:"key_statuses,omitempty"` + Provider ModelProvider `json:"provider,omitempty"` + OriginalModelRequested string `json:"original_model_requested,omitempty"` + ResolvedModelUsed string `json:"resolved_model_used,omitempty"` + RequestType RequestType `json:"request_type,omitempty"` + RawRequest interface{} `json:"raw_request,omitempty"` + RawResponse interface{} `json:"raw_response,omitempty"` + LiteLLMCompat bool `json:"litellm_compat,omitempty"` + KeyStatuses []KeyStatus `json:"key_statuses,omitempty"` } diff --git a/core/schemas/chatcompletions.go b/core/schemas/chatcompletions.go index 7fb56abfbe..ee9dcdbed6 100644 --- a/core/schemas/chatcompletions.go +++ b/core/schemas/chatcompletions.go @@ -63,7 +63,8 @@ func (cr *BifrostChatResponse) ToTextCompletionResponse() *BifrostTextCompletion RequestType: TextCompletionRequest, ChunkIndex: cr.ExtraFields.ChunkIndex, Provider: 
cr.ExtraFields.Provider, - ModelRequested: cr.ExtraFields.ModelRequested, + OriginalModelRequested: cr.ExtraFields.OriginalModelRequested, + ResolvedModelUsed: cr.ExtraFields.ResolvedModelUsed, Latency: cr.ExtraFields.Latency, RawResponse: cr.ExtraFields.RawResponse, CacheDebug: cr.ExtraFields.CacheDebug, @@ -96,7 +97,8 @@ func (cr *BifrostChatResponse) ToTextCompletionResponse() *BifrostTextCompletion RequestType: TextCompletionRequest, ChunkIndex: cr.ExtraFields.ChunkIndex, Provider: cr.ExtraFields.Provider, - ModelRequested: cr.ExtraFields.ModelRequested, + OriginalModelRequested: cr.ExtraFields.OriginalModelRequested, + ResolvedModelUsed: cr.ExtraFields.ResolvedModelUsed, Latency: cr.ExtraFields.Latency, RawResponse: cr.ExtraFields.RawResponse, CacheDebug: cr.ExtraFields.CacheDebug, @@ -132,7 +134,8 @@ func (cr *BifrostChatResponse) ToTextCompletionResponse() *BifrostTextCompletion RequestType: TextCompletionRequest, ChunkIndex: cr.ExtraFields.ChunkIndex, Provider: cr.ExtraFields.Provider, - ModelRequested: cr.ExtraFields.ModelRequested, + OriginalModelRequested: cr.ExtraFields.OriginalModelRequested, + ResolvedModelUsed: cr.ExtraFields.ResolvedModelUsed, Latency: cr.ExtraFields.Latency, RawResponse: cr.ExtraFields.RawResponse, CacheDebug: cr.ExtraFields.CacheDebug, @@ -149,13 +152,15 @@ func (cr *BifrostChatResponse) ToTextCompletionResponse() *BifrostTextCompletion SystemFingerprint: cr.SystemFingerprint, Usage: cr.Usage, ExtraFields: BifrostResponseExtraFields{ - RequestType: TextCompletionRequest, - ChunkIndex: cr.ExtraFields.ChunkIndex, - Provider: cr.ExtraFields.Provider, - ModelRequested: cr.ExtraFields.ModelRequested, - Latency: cr.ExtraFields.Latency, - RawResponse: cr.ExtraFields.RawResponse, - CacheDebug: cr.ExtraFields.CacheDebug, + RequestType: TextCompletionRequest, + ChunkIndex: cr.ExtraFields.ChunkIndex, + Provider: cr.ExtraFields.Provider, + OriginalModelRequested: cr.ExtraFields.OriginalModelRequested, + ResolvedModelUsed: 
cr.ExtraFields.ResolvedModelUsed, + Latency: cr.ExtraFields.Latency, + RawResponse: cr.ExtraFields.RawResponse, + CacheDebug: cr.ExtraFields.CacheDebug, + ProviderResponseHeaders: cr.ExtraFields.ProviderResponseHeaders, }, } } diff --git a/core/schemas/models.go b/core/schemas/models.go index 5a0e8588c1..32b82bb104 100644 --- a/core/schemas/models.go +++ b/core/schemas/models.go @@ -138,7 +138,7 @@ type Model struct { ID string `json:"id"` CanonicalSlug *string `json:"canonical_slug,omitempty"` Name *string `json:"name,omitempty"` - Deployment *string `json:"deployment,omitempty"` // Name of the actual deployment + Alias *string `json:"alias,omitempty"` // Provider API identifier this model alias maps to (e.g. Azure deployment name, Bedrock ARN) Created *int64 `json:"created,omitempty"` ContextLength *int `json:"context_length,omitempty"` MaxInputTokens *int `json:"max_input_tokens,omitempty"` diff --git a/core/schemas/models_test.go b/core/schemas/models_test.go index 3e60fdda76..b9748952bd 100644 --- a/core/schemas/models_test.go +++ b/core/schemas/models_test.go @@ -94,7 +94,7 @@ func TestKeyStatusMarshalJSON_PreservesErrorFields(t *testing.T) { Error: &ErrorField{Message: "unauthorized"}, ExtraFields: BifrostErrorExtraFields{ Provider: "openai", - ModelRequested: "gpt-4", + OriginalModelRequested: "gpt-4", }, } keyStatus := KeyStatus{ @@ -112,6 +112,6 @@ func TestKeyStatusMarshalJSON_PreservesErrorFields(t *testing.T) { // Error fields other than key_statuses should be preserved dataStr := string(data) assert.Contains(t, dataStr, `"unauthorized"`) - assert.Contains(t, dataStr, `"model_requested":"gpt-4"`) + assert.Contains(t, dataStr, `"original_model_requested":"gpt-4"`) assert.Contains(t, dataStr, `"status_code":401`) } diff --git a/core/schemas/tracer.go b/core/schemas/tracer.go index d78c4ea97b..3e5bcbf46f 100644 --- a/core/schemas/tracer.go +++ b/core/schemas/tracer.go @@ -14,7 +14,8 @@ type SpanHandle interface{} // This is the return type for tracer's 
streaming accumulation methods. type StreamAccumulatorResult struct { RequestID string // Request ID - Model string // Model used + RequestedModel string // Original model requested by the caller + ResolvedModel string // Actual model used by the provider (equals RequestedModel when no alias mapping exists) Provider ModelProvider // Provider used Status string // Status of the stream Latency int64 // Latency in milliseconds diff --git a/core/utils.go b/core/utils.go index 00dba31ddc..12d86e2508 100644 --- a/core/utils.go +++ b/core/utils.go @@ -132,6 +132,51 @@ func validateRequest(req *schemas.BifrostRequest) *schemas.BifrostError { return nil } +// validateKey validates the given key. +func validateKey(providerKey schemas.ModelProvider, key *schemas.Key) bool { + // Valid the key for the provider + switch providerKey { + case schemas.Azure: + if key.AzureKeyConfig == nil { + return false + } + if key.AzureKeyConfig.Endpoint.GetValue() == "" { + return false + } + case schemas.Bedrock: + // Key is valid if either: + // 1. BedrockKeyConfig is provided + // 2. 
Value is provided and is not empty + if key.BedrockKeyConfig == nil { + if key.Value.GetValue() == "" { + return false + } + key.BedrockKeyConfig = &schemas.BedrockKeyConfig{} + } + case schemas.Vertex: + if key.VertexKeyConfig == nil { + return false + } + case schemas.Replicate: + if key.ReplicateKeyConfig == nil { + return false + } + case schemas.VLLM: + if key.VLLMKeyConfig == nil || key.VLLMKeyConfig.URL.GetValue() == "" { + return false + } + case schemas.Ollama: + if key.OllamaKeyConfig == nil || key.OllamaKeyConfig.URL.GetValue() == "" { + return false + } + case schemas.SGL: + if key.SGLKeyConfig == nil || key.SGLKeyConfig.URL.GetValue() == "" { + return false + } + } + return true +} + // IsRateLimitErrorMessage checks if an error message indicates a rate limit issue func IsRateLimitErrorMessage(errorMessage string) bool { if errorMessage == "" { @@ -176,7 +221,7 @@ func newBifrostErrorFromMsg(message string) *schemas.BifrostError { // newBifrostCtxDoneError creates a BifrostError from a cancelled/expired context. // It distinguishes DeadlineExceeded (504 RequestTimedOut) from Canceled (499 RequestCancelled). 
-func newBifrostCtxDoneError(ctx *schemas.BifrostContext, provider schemas.ModelProvider, model string, requestType schemas.RequestType, stage string) *schemas.BifrostError { +func newBifrostCtxDoneError(ctx *schemas.BifrostContext, stage string) *schemas.BifrostError { var statusCode int var errorType string var message string @@ -200,11 +245,6 @@ func newBifrostCtxDoneError(ctx *schemas.BifrostContext, provider schemas.ModelP Message: message, Error: ctx.Err(), }, - ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: requestType, - Provider: provider, - ModelRequested: model, - }, } } @@ -333,14 +373,14 @@ func IsFinalChunk(ctx *schemas.BifrostContext) bool { return false } -// GetResponseFields extracts the request type, provider, and model from the result or error -func GetResponseFields(result *schemas.BifrostResponse, err *schemas.BifrostError) (requestType schemas.RequestType, provider schemas.ModelProvider, model string) { +// GetResponseFields extracts the request type, provider, original model, and resolved model from the result or error. 
+func GetResponseFields(result *schemas.BifrostResponse, err *schemas.BifrostError) (requestType schemas.RequestType, provider schemas.ModelProvider, originalModel string, resolvedModel string) { if result != nil { extraFields := result.GetExtraFields() - return extraFields.RequestType, extraFields.Provider, extraFields.ModelRequested + return extraFields.RequestType, extraFields.Provider, extraFields.OriginalModelRequested, extraFields.ResolvedModelUsed } if err != nil { - return err.ExtraFields.RequestType, err.ExtraFields.Provider, err.ExtraFields.ModelRequested + return err.ExtraFields.RequestType, err.ExtraFields.Provider, err.ExtraFields.OriginalModelRequested, err.ExtraFields.ResolvedModelUsed } return } diff --git a/docs/features/litellm-compat.mdx b/docs/features/litellm-compat.mdx index b26f94cd7e..51cd26dcd9 100644 --- a/docs/features/litellm-compat.mdx +++ b/docs/features/litellm-compat.mdx @@ -125,7 +125,8 @@ When either transformation is applied: - `extra_fields.litellm_compat`: Set to `true` - `extra_fields.provider`: The provider that handled the request - `extra_fields.request_type`: Reflects the original request type -- `extra_fields.model_requested`: The originally requested model +- `extra_fields.original_model_requested`: The originally requested model +- `extra_fields.resolved_model_used`: The actual provider API identifier used (equals original_model_requested when no alias mapping exists) ### Error Handling diff --git a/docs/openapi/schemas/management/providers.yaml b/docs/openapi/schemas/management/providers.yaml index b5e3a07b41..7b0a9972f6 100644 --- a/docs/openapi/schemas/management/providers.yaml +++ b/docs/openapi/schemas/management/providers.yaml @@ -71,10 +71,6 @@ AzureKeyConfig: properties: endpoint: $ref: '../../schemas/management/common.yaml#/EnvVar' - deployments: - type: object - additionalProperties: - type: string api_version: $ref: '../../schemas/management/common.yaml#/EnvVar' client_id: @@ -101,10 +97,6 @@ VertexKeyConfig: 
$ref: '../../schemas/management/common.yaml#/EnvVar' auth_credentials: $ref: '../../schemas/management/common.yaml#/EnvVar' - deployments: - type: object - additionalProperties: - type: string BedrockKeyConfig: type: object @@ -120,10 +112,6 @@ BedrockKeyConfig: $ref: '../../schemas/management/common.yaml#/EnvVar' arn: $ref: '../../schemas/management/common.yaml#/EnvVar' - deployments: - type: object - additionalProperties: - type: string batch_s3_config: type: object properties: @@ -159,6 +147,14 @@ OllamaKeyConfig: required: - url +ReplicateKeyConfig: + type: object + description: Replicate-specific key configuration + properties: + use_deployments_endpoint: + type: boolean + description: Whether to use the deployments endpoint instead of the models endpoint + SglKeyConfig: type: object description: SGLang-specific key configuration @@ -168,15 +164,6 @@ SglKeyConfig: required: - url -ReplicateKeyConfig: - type: object - description: Replicate-specific key configuration - properties: - deployments: - type: object - additionalProperties: - type: string - VLLMKeyConfig: type: object description: vLLM-specific key configuration for per-key routing to different vLLM instances @@ -214,20 +201,28 @@ Key: weight: type: number description: Weight for load balancing + aliases: + type: object + propertyNames: + minLength: 1 + additionalProperties: + type: string + minLength: 1 + description: Model alias mappings — maps a user-facing model name to a provider-specific identifier (deployment name, inference profile ID, fine-tuned model ID, etc.) 
azure_key_config: $ref: '#/AzureKeyConfig' vertex_key_config: $ref: '#/VertexKeyConfig' bedrock_key_config: $ref: '#/BedrockKeyConfig' - replicate_key_config: - $ref: '#/ReplicateKeyConfig' vllm_key_config: $ref: '#/VllmKeyConfig' ollama_key_config: $ref: '#/OllamaKeyConfig' sgl_key_config: $ref: '#/SglKeyConfig' + replicate_key_config: + $ref: '#/ReplicateKeyConfig' enabled: type: boolean description: Whether the key is active (defaults to true) diff --git a/docs/providers/provider-routing.mdx b/docs/providers/provider-routing.mdx index dc214adb76..7a341a8e40 100644 --- a/docs/providers/provider-routing.mdx +++ b/docs/providers/provider-routing.mdx @@ -427,19 +427,19 @@ This is particularly useful for proxy providers (OpenRouter, Vertex) where you w - + -**Key Concept**: Deployments are **key-specific** mappings that allow user-friendly model names to map to provider-specific deployment identifiers. +**Key Concept**: Aliases are **key-level** mappings that allow user-friendly model names to map to provider-specific identifiers. 
-**How Deployments Work**: +**How Aliases Work**: - Defined at the **Key level**, not Virtual Key level -- Structure: `deployments: {"alias": "deployment-id"}` -- **Alias** (left side): User-facing model name used in requests -- **Deployment ID** (right side): Provider-specific identifier sent to the API +- Structure: `aliases: {"user-facing-name": "provider-specific-id"}` +- **Alias key** (left side): User-facing model name used in requests +- **Provider ID** (right side): Provider-specific identifier sent to the API **Azure OpenAI Example**: -Provider configuration with deployment mapping: +Provider configuration with alias mapping: ```json { "providers": { @@ -448,13 +448,12 @@ Provider configuration with deployment mapping: { "name": "azure-prod-key", "value": "your-api-key", - "models": [], // Not used when deployments exist + "aliases": { + "gpt-4o": "my-prod-gpt4o-deployment", + "gpt-4o-mini": "my-mini-deployment" + }, "azure_key_config": { - "endpoint": "https://your-resource.openai.azure.com", - "deployments": { - "gpt-4o": "my-prod-gpt4o-deployment", - "gpt-4o-mini": "my-mini-deployment" - } + "endpoint": "https://your-resource.openai.azure.com" } } ] @@ -467,9 +466,9 @@ Provider configuration with deployment mapping: 1. **Allowed models derived from aliases**: `["gpt-4o", "gpt-4o-mini"]` 2. **User requests with alias**: `{"model": "gpt-4o"}` 3. **Bifrost validates**: `gpt-4o` is in derived allowed models ✅ -4. **Bifrost maps to deployment**: `gpt-4o` → `my-prod-gpt4o-deployment` +4. **Bifrost resolves alias**: `gpt-4o` → `my-prod-gpt4o-deployment` 5. **Sent to Azure**: Uses `my-prod-gpt4o-deployment` as the deployment name -6. **Pricing lookup**: If pricing for deployment not found, falls back to alias `gpt-4o` +6. 
**Pricing lookup**: If pricing for resolved ID not found, falls back to alias `gpt-4o` **Bedrock Example with Inference Profiles**: @@ -480,15 +479,14 @@ Provider configuration with deployment mapping: "keys": [ { "name": "bedrock-key", - "models": [], + "aliases": { + "claude-sonnet": "us.anthropic.claude-3-5-sonnet-20241022-v2:0", + "claude-opus": "us.anthropic.claude-3-opus-20240229-v1:0" + }, "bedrock_key_config": { "access_key": "your-access-key", "secret_key": "your-secret-key", - "region": "us-east-1", - "deployments": { - "claude-sonnet": "us.anthropic.claude-3-5-sonnet-20241022-v2:0", - "claude-opus": "us.anthropic.claude-3-opus-20240229-v1:0" - } + "region": "us-east-1" } } ] @@ -498,10 +496,10 @@ Provider configuration with deployment mapping: ``` **What Happens**: -1. **Allowed models**: `["claude-sonnet", "claude-opus"]` (from deployment aliases) +1. **Allowed models**: `["claude-sonnet", "claude-opus"]` (from alias keys) 2. **User requests**: `{"model": "claude-sonnet"}` 3. **Bifrost validates**: `claude-sonnet` in allowed models ✅ -4. **Maps to inference profile**: `claude-sonnet` → `us.anthropic.claude-3-5-sonnet-20241022-v2:0` +4. **Resolves alias**: `claude-sonnet` → `us.anthropic.claude-3-5-sonnet-20241022-v2:0` 5. **Sent to Bedrock**: Full ARN used in API call **Priority of Model Restrictions**: @@ -509,7 +507,7 @@ Provider configuration with deployment mapping: When determining allowed models for a key: ``` 1. If key.models is NOT empty → Use key.models -2. Else if deployments exist → Use deployment aliases (map keys) +2. Else if aliases exist → Use alias keys 3. Else → All models allowed (use Model Catalog) ``` @@ -519,11 +517,12 @@ When determining allowed models for a key: "keys": [ { "models": ["gpt-4o", "gpt-3.5-turbo"], // Explicit restriction + "aliases": { + "gpt-4o": "my-deployment", + "gpt-4-turbo": "another-deployment" // NOT accessible! 
+ }, "azure_key_config": { - "deployments": { - "gpt-4o": "my-deployment", - "gpt-4-turbo": "another-deployment" // NOT accessible! - } + "endpoint": "https://your-resource.openai.azure.com" } } ] @@ -536,39 +535,39 @@ Result: Only `["gpt-4o", "gpt-3.5-turbo"]` allowed (models field takes priority) { "keys": [ { + "aliases": { + "claude-3-5-sonnet": "anthropic/claude-3-5-sonnet@20241022", + "gemini-pro": "google/gemini-1.5-pro" + }, "vertex_key_config": { "project_id": "my-project", - "region": "us-central1", - "deployments": { - "claude-3-5-sonnet": "anthropic/claude-3-5-sonnet@20241022", - "gemini-pro": "google/gemini-1.5-pro" - } + "region": "us-central1" } } ] } ``` -**Use Cases for Deployments**: +**Use Cases for Aliases**: - **Azure**: Map generic model names to specific deployment names in your Azure resource - **Bedrock**: Use short aliases for long inference profile ARNs - **Vertex**: Map to specific model versions or regional endpoints -- **Multi-environment**: Different deployments per key (dev/staging/prod) +- **Multi-environment**: Different aliases per key (dev/staging/prod) **Key Insight**: ``` User Request: {"model": "gpt-4o"} ↓ -Validation: Check if "gpt-4o" in allowed models (derived from deployments) +Validation: Check if "gpt-4o" in allowed models (derived from aliases) ↓ -Mapping: deployments["gpt-4o"] → "my-prod-gpt4o-deployment" +Mapping: aliases["gpt-4o"] → "my-prod-gpt4o-deployment" ↓ API Call: Uses "my-prod-gpt4o-deployment" as deployment ID ↓ -Pricing: Falls back to "gpt-4o" if deployment not in pricing data +Pricing: Falls back to "gpt-4o" if resolved ID not in pricing data ``` -This allows user-friendly model names in requests while supporting provider-specific deployment patterns at the key level. +This allows user-friendly model names in requests while supporting provider-specific identifier patterns at the key level. 
diff --git a/docs/providers/supported-providers/azure.mdx b/docs/providers/supported-providers/azure.mdx index 1020b0b855..46f41a8840 100644 --- a/docs/providers/supported-providers/azure.mdx +++ b/docs/providers/supported-providers/azure.mdx @@ -116,12 +116,12 @@ detects the auth environment. ```json { + "aliases": { + "gpt-4": "my-gpt4-deployment" + }, "azure_key_config": { "endpoint": "https://your-org.openai.azure.com", - "api_version": "2024-10-21", - "deployments": { - "gpt-4": "my-gpt4-deployment" - } + "api_version": "2024-10-21" } } ``` @@ -132,18 +132,18 @@ If you set `client_id`, `client_secret`, and `tenant_id`, Azure Entra ID authent ```json { + "aliases": { + "gpt-4": "my-gpt4-deployment", + "gpt-4-turbo": "my-gpt4-turbo-deployment", + "claude-3": "my-claude-deployment" + }, "azure_key_config": { "endpoint": "https://your-org.openai.azure.com", "client_id": "your-client-id", "client_secret": "your-client-secret", "tenant_id": "your-tenant-id", "scopes": ["https://cognitiveservices.azure.com/.default"], - "api_version": "2024-10-21", - "deployments": { - "gpt-4": "my-gpt4-deployment", - "gpt-4-turbo": "my-gpt4-turbo-deployment", - "claude-3": "my-claude-deployment" - } + "api_version": "2024-10-21" } } ``` @@ -156,14 +156,15 @@ If you set `client_id`, `client_secret`, and `tenant_id`, Azure Entra ID authent ```json { + "value": "your-azure-api-key", + "aliases": { + "gpt-4": "my-gpt4-deployment", + "gpt-4-turbo": "my-gpt4-turbo-deployment", + "claude-3": "my-claude-deployment" + }, "azure_key_config": { "endpoint": "https://your-org.openai.azure.com", - "api_version": "2024-10-21", - "deployments": { - "gpt-4": "my-gpt4-deployment", - "gpt-4-turbo": "my-gpt4-turbo-deployment", - "claude-3": "my-claude-deployment" - } + "api_version": "2024-10-21" } } ``` @@ -175,7 +176,7 @@ If you set `client_id`, `client_secret`, and `tenant_id`, Azure Entra ID authent - `tenant_id` - Azure Entra ID tenant ID (optional, for Service Principal auth) - `scopes` - OAuth 
scopes for token requests (default: `["https://cognitiveservices.azure.com/.default"]`) - `api_version` - API version to use (default: `2024-10-21`) -- `deployments` - Map of model names to deployment IDs (optional, can be provided per-request) +- `aliases` - Map of model names to Azure deployment IDs (optional, set at key level) - `allowed_models` - List of allowed models to use from this key (optional) ### Deployment Selection @@ -189,7 +190,7 @@ Deployments can be specified at three levels (in order of precedence): 2. **Key configuration** ```json - {"deployments": {"gpt-4": "my-gpt4-deployment"}} + {"aliases": {"gpt-4": "my-gpt4-deployment"}} ``` 3. **Model name** (lowest priority, if no deployment specified) diff --git a/docs/providers/supported-providers/bedrock.mdx b/docs/providers/supported-providers/bedrock.mdx index 1ca8b8efb2..52ff79ca3f 100644 --- a/docs/providers/supported-providers/bedrock.mdx +++ b/docs/providers/supported-providers/bedrock.mdx @@ -1206,7 +1206,7 @@ S3-backed file operations. Files are stored in S3 buckets integrated with Bedroc - Deployment mapping from configuration - Model allowlist support (`allowed_models` config) -**Multi-key support**: Results aggregated from all keys, filtered by `allowedModels` if configured +**Multi-key support**: Results aggregated from all keys, filtered by the key-level `models` allowlist if configured --- @@ -1280,43 +1280,43 @@ When using AWS Bedrock inference profiles or application inference profiles, you | Field | Purpose | |-------|---------| | **`arn`** | The ARN prefix (everything before the final `/resource-id`). Required for URL formation when using inference profiles. | -| **`deployments`** | Map logical model names to the **model ID or inference profile resource ID only** — not the full ARN. | +| **`aliases`** | Map logical model names to the **model ID or inference profile resource ID only** — not the full ARN. Set at the key level, not inside `bedrock_key_config`. 
| -**Do not** put the full ARN in the deployments mapping. The resource ID (e.g., `abc12xyz`) goes in `deployments`; the ARN prefix goes in the dedicated `arn` field. Putting the full ARN in `deployments` causes malformed URLs and `UnknownOperationException`. +**Do not** put the full ARN in the aliases mapping. The resource ID (e.g., `abc12xyz`) goes in `aliases`; the ARN prefix goes in the dedicated `arn` field inside `bedrock_key_config`. Putting the full ARN in `aliases` causes malformed URLs and `UnknownOperationException`. -**Application inference profiles** — use the resource ID (short alphanumeric suffix) in deployments: +**Application inference profiles** — use the resource ID (short alphanumeric suffix) in aliases: ```json { + "aliases": { + "claude-opus-4-6": "ghi56rst", + "claude-sonnet-4-5": "jkl78mno" + }, "bedrock_key_config": { "access_key": "your-aws-access-key", "secret_key": "your-aws-secret-key", "session_token": "optional-session-token", "region": "eu-west-1", - "arn": "arn:aws:bedrock:eu-west-1:123456789012:application-inference-profile", - "deployments": { - "claude-opus-4-6": "ghi56rst", - "claude-sonnet-4-5": "jkl78mno" - } + "arn": "arn:aws:bedrock:eu-west-1:123456789012:application-inference-profile" } } ``` -**Cross-region inference profiles** — use the model identifier (e.g., `us.anthropic.claude-3-5-sonnet-v1:0`) in deployments: +**Cross-region inference profiles** — use the model identifier (e.g., `us.anthropic.claude-3-5-sonnet-v1:0`) in aliases: ```json { + "aliases": { + "claude-sonnet": "us.anthropic.claude-3-5-sonnet-v1:0" + }, "bedrock_key_config": { "access_key": "your-aws-access-key", "secret_key": "your-aws-secret-key", "session_token": "optional-session-token", "region": "us-east-1", - "arn": "arn:aws:bedrock:us-east-1:123456789012:inference-profile", - "deployments": { - "claude-sonnet": "us.anthropic.claude-3-5-sonnet-v1:0" - } + "arn": "arn:aws:bedrock:us-east-1:123456789012:inference-profile" } } ``` diff --git 
a/docs/providers/supported-providers/replicate.mdx b/docs/providers/supported-providers/replicate.mdx index bec376b731..4ae19e511e 100644 --- a/docs/providers/supported-providers/replicate.mdx +++ b/docs/providers/supported-providers/replicate.mdx @@ -77,10 +77,8 @@ Configure deployed models in the Replicate key configuration. Deployments map cu { "provider": "replicate", "value": "your-api-key", - "replicate_key_config": { - "deployments": { - "my-model": "owner/my-deployment-name" - } + "aliases": { + "my-model": "owner/my-deployment-name" } } ``` diff --git a/docs/providers/supported-providers/vertex.mdx b/docs/providers/supported-providers/vertex.mdx index 538664821a..4ea108d733 100644 --- a/docs/providers/supported-providers/vertex.mdx +++ b/docs/providers/supported-providers/vertex.mdx @@ -522,27 +522,29 @@ To provide a complete model listing experience, Bifrost performs **multi-pass mo - Custom models are identified by having deployment values that contain only digits - Example: `"deployment": "1234567890"` -2. **Second Pass - Non-Custom Models from Deployments** - - Adds standard foundation models from your `deployments` configuration +2. **Second Pass - Non-Custom Models from Aliases** + - Adds standard foundation models from your `aliases` configuration - Non-custom models have alphanumeric deployment values (e.g., `gemini-pro`, `claude-3-5-sonnet`) - - Filters by `allowedModels` if specified + - Filters by the key-level `models` allowlist, if specified - Example: `"deployment": "gemini-2.0-flash"` -3. **Third Pass - Allowed Models Not in Deployments** - - Adds models specified in `allowedModels` that weren't in the `deployments` map +3. 
**Third Pass - Allowed Models Not in Aliases** + - Adds models specified in `models` that weren't in the `aliases` map - Ensures all explicitly allowed models appear in the list - Uses the model name itself as the deployment value - Skips digit-only model IDs (reserved for custom models) ### Model Filtering Logic -- **If `allowedModels` is empty**: All models from all three passes are included -- **If `allowedModels` is non-empty**: Only models/deployments with keys in `allowedModels` are included +- **If `models` is empty and no aliases are configured**: No models are returned +- **If `models` is empty but aliases are configured**: Only aliased models are returned +- **If `models` is `["*"]`**: All models from all three passes are included (unrestricted) +- **If `models` is non-empty**: Only models/aliases whose request names appear in `models` are included - **Duplicate Prevention**: Each model ID is tracked to prevent duplicates across passes ### Model Name Formatting -Non-custom models from deployments and allowed models are automatically formatted for display: +Non-custom models from aliases and allowed models are automatically formatted for display: - `gemini-pro` → "Gemini Pro" - `claude-3-5-sonnet` → "Claude 3 5 Sonnet" @@ -557,13 +559,13 @@ Formatting uses title case and converts hyphens/underscores to spaces. ```json { + "aliases": { + "my-gemini-ft": "1234567890", + "my-claude-ft": "9876543210" + }, "vertex_key_config": { "project_id": "my-project", - "region": "us-central1", - "deployments": { - "my-gemini-ft": "1234567890", - "my-claude-ft": "9876543210" - } + "region": "us-central1" } } ``` @@ -575,33 +577,33 @@ This returns only your custom fine-tuned models from the API. 
```json { + "aliases": { + "gemini-2.0-flash": "gemini-2.0-flash", + "claude-3-5-sonnet": "claude-3-5-sonnet-v2@20241022" + }, "vertex_key_config": { "project_id": "my-project", - "region": "us-central1", - "deployments": { - "gemini-2.0-flash": "gemini-2.0-flash", - "claude-3-5-sonnet": "claude-3-5-sonnet-v2@20241022" - } + "region": "us-central1" } } ``` -This returns both custom models AND foundation models from deployments. +This returns both custom models AND foundation models from aliases. ```json { + "models": ["gemini-2.0-flash", "claude-3-5-sonnet"], + "aliases": { + "gemini-2.0-flash": "gemini-2.0-flash", + "claude-3-5-sonnet": "claude-3-5-sonnet-v2@20241022", + "gemini-1.5-pro": "gemini-1.5-pro" + }, "vertex_key_config": { "project_id": "my-project", - "region": "us-central1", - "deployments": { - "gemini-2.0-flash": "gemini-2.0-flash", - "claude-3-5-sonnet": "claude-3-5-sonnet-v2@20241022", - "gemini-1.5-pro": "gemini-1.5-pro" - }, - "allowedModels": ["gemini-2.0-flash", "claude-3-5-sonnet"] + "region": "us-central1" } } ``` @@ -664,7 +666,7 @@ Model listing is paginated automatically. 
If more than 100 models exist, `next_p **Severity**: High **Behavior**: Vertex AI's List Models API only returns custom fine-tuned models, NOT foundation models -**Impact**: Bifrost performs three-pass discovery to include foundation models from deployments and allowedModels configuration +**Impact**: Bifrost performs three-pass discovery to include foundation models from aliases and the key-level `models` allowlist **Why**: This is a Vertex AI API limitation - foundation models must be explicitly configured **Code**: `models.go:76-217` diff --git a/docs/quickstart/gateway/provider-configuration.mdx b/docs/quickstart/gateway/provider-configuration.mdx index 2e4e2b71e3..5986ed87f6 100644 --- a/docs/quickstart/gateway/provider-configuration.mdx +++ b/docs/quickstart/gateway/provider-configuration.mdx @@ -1054,7 +1054,7 @@ Azure supports three authentication methods: **Managed Identity** (DefaultAzureC #### Managed Identity / DefaultAzureCredential -Leave API key and Entra ID credentials empty. Bifrost uses `DefaultAzureCredential`, which auto-detects managed identity on Azure VMs, App Service, AKS, and similar environments. Provide only `endpoint`, `deployments`, and optionally `api_version`. +Leave API key and Entra ID credentials empty. Bifrost uses `DefaultAzureCredential`, which auto-detects managed identity on Azure VMs, App Service, AKS, and similar environments. Provide only `endpoint` and optionally `api_version`. #### Azure Entra ID (Service Principal) @@ -1070,7 +1070,7 @@ Leave API key and Entra ID credentials empty. Bifrost uses `DefaultAzureCredenti 4. Set **Client Secret**: Your Azure Entra ID client secret 5. Set **Tenant ID**: Your Azure Entra ID tenant ID 6. Set **Endpoint**: Your Azure endpoint URL -7. Configure **Deployments**: Map model names to deployment names +7. Configure **Aliases**: Map model names to deployment names 8. Set **API Version**: e.g., `2024-08-01-preview` 9. 
Save configuration @@ -1089,16 +1089,16 @@ curl --location 'http://localhost:8080/api/providers' \ "value": "", "models": ["gpt-4o", "gpt-4o-mini"], "weight": 1.0, + "aliases": { + "gpt-4o": "gpt-4o-deployment", + "gpt-4o-mini": "gpt-4o-mini-deployment" + }, "azure_key_config": { "endpoint": "env.AZURE_ENDPOINT", "client_id": "env.AZURE_CLIENT_ID", "client_secret": "env.AZURE_CLIENT_SECRET", "tenant_id": "env.AZURE_TENANT_ID", "scopes": ["https://cognitiveservices.azure.com/.default"], - "deployments": { - "gpt-4o": "gpt-4o-deployment", - "gpt-4o-mini": "gpt-4o-mini-deployment" - }, "api_version": "2024-08-01-preview" } } @@ -1120,16 +1120,16 @@ curl --location 'http://localhost:8080/api/providers' \ "value": "", "models": ["gpt-4o", "gpt-4o-mini"], "weight": 1.0, + "aliases": { + "gpt-4o": "gpt-4o-deployment", + "gpt-4o-mini": "gpt-4o-mini-deployment" + }, "azure_key_config": { "endpoint": "env.AZURE_ENDPOINT", "client_id": "env.AZURE_CLIENT_ID", "client_secret": "env.AZURE_CLIENT_SECRET", "tenant_id": "env.AZURE_TENANT_ID", "scopes": ["https://cognitiveservices.azure.com/.default"], - "deployments": { - "gpt-4o": "gpt-4o-deployment", - "gpt-4o-mini": "gpt-4o-mini-deployment" - }, "api_version": "2024-08-01-preview" } } @@ -1156,7 +1156,7 @@ For simpler use cases, provide the authentication credential directly in the `va 1. Navigate to **"Model Providers"** → **"Configurations"** → **"Azure"** 2. Set **API Key**: Your Azure API key 3. Set **Endpoint**: Your Azure endpoint URL -4. Configure **Deployments**: Map model names to deployment names +4. Configure **Aliases**: Map model names to deployment names 5. Set **API Version**: e.g., `2024-08-01-preview` 6. 
Save configuration @@ -1175,12 +1175,12 @@ curl --location 'http://localhost:8080/api/providers' \ "value": "env.AZURE_API_KEY", "models": ["gpt-4o", "gpt-4o-mini"], "weight": 1.0, + "aliases": { + "gpt-4o": "gpt-4o-deployment", + "gpt-4o-mini": "gpt-4o-mini-deployment" + }, "azure_key_config": { "endpoint": "env.AZURE_ENDPOINT", - "deployments": { - "gpt-4o": "gpt-4o-deployment", - "gpt-4o-mini": "gpt-4o-mini-deployment" - }, "api_version": "2024-08-01-preview" } } @@ -1202,12 +1202,12 @@ curl --location 'http://localhost:8080/api/providers' \ "value": "env.AZURE_API_KEY", "models": ["gpt-4o", "gpt-4o-mini"], "weight": 1.0, + "aliases": { + "gpt-4o": "gpt-4o-deployment", + "gpt-4o-mini": "gpt-4o-mini-deployment" + }, "azure_key_config": { "endpoint": "env.AZURE_ENDPOINT", - "deployments": { - "gpt-4o": "gpt-4o-deployment", - "gpt-4o-mini": "gpt-4o-mini-deployment" - }, "api_version": "2024-08-01-preview" } } @@ -1240,8 +1240,8 @@ AWS Bedrock supports both explicit credentials and IAM role authentication: 3. Set **Access Key**: AWS Access Key ID (or leave empty to use IAM in environment) 4. Set **Secret Key**: AWS Secret Access Key (or leave empty to use IAM in environment) 5. Set **Region**: e.g., `us-east-1` -6. Configure **Deployments**: Map model names to inference profiles -7. Set **ARN**: Required for deployments mapping +6. Configure **Aliases**: Map model names to inference profiles +7. Set **ARN**: Required only when Bifrost must construct a full inference-profile ARN for an alias 8. 
Save configuration @@ -1256,16 +1256,16 @@ curl --location 'http://localhost:8080/api/providers' \ "keys": [ { "name": "bedrock-key-1", - "models": ["anthropic.claude-3-sonnet-20240229-v1:0", "anthropic.claude-v2:1"], + "models": ["anthropic.claude-3-sonnet-20240229-v1:0", "anthropic.claude-v2:1", "claude-3-sonnet"], "weight": 1.0, + "aliases": { + "claude-3-sonnet": "us.anthropic.claude-3-sonnet-20240229-v1:0" + }, "bedrock_key_config": { "access_key": "env.AWS_ACCESS_KEY_ID", "secret_key": "env.AWS_SECRET_ACCESS_KEY", "session_token": "env.AWS_SESSION_TOKEN", "region": "us-east-1", - "deployments": { - "claude-3-sonnet": "us.anthropic.claude-3-sonnet-20240229-v1:0" - }, "arn": "arn:aws:bedrock:us-east-1:123456789012:inference-profile" } } @@ -1284,16 +1284,16 @@ curl --location 'http://localhost:8080/api/providers' \ "keys": [ { "name": "bedrock-key-1", - "models": ["anthropic.claude-3-sonnet-20240229-v1:0", "anthropic.claude-v2:1"], + "models": ["anthropic.claude-3-sonnet-20240229-v1:0", "anthropic.claude-v2:1", "claude-3-sonnet"], "weight": 1.0, + "aliases": { + "claude-3-sonnet": "us.anthropic.claude-3-sonnet-20240229-v1:0" + }, "bedrock_key_config": { "access_key": "env.AWS_ACCESS_KEY_ID", "secret_key": "env.AWS_SECRET_ACCESS_KEY", "session_token": "env.AWS_SESSION_TOKEN", "region": "us-east-1", - "deployments": { - "claude-3-sonnet": "us.anthropic.claude-3-sonnet-20240229-v1:0" - }, "arn": "arn:aws:bedrock:us-east-1:123456789012:inference-profile" } } @@ -1310,9 +1310,9 @@ curl --location 'http://localhost:8080/api/providers' \ **Notes:** - If using API Key authentication, set `value` field to the API key, else leave it empty for IAM role authentication. - In IAM role authentication, if both `access_key` and `secret_key` are empty, Bifrost uses IAM role authentication from the environment. -- `arn` is required for URL formation - `deployments` mapping is ignored without it. 
-- When using `arn` + `deployments`, Bifrost uses model profiles; otherwise forms path with incoming model name directly. -- **ARN vs deployments**: Put the ARN prefix in `arn` and the model/inference profile resource ID only in `deployments` — never the full ARN in deployments. See [How to Use ARNs and Application Inference Profiles](/providers/supported-providers/bedrock#how-to-use-arns-and-application-inference-profiles) for details. +- `arn` is required when you want Bifrost to build a full inference-profile ARN from an alias target. +- Aliases are still resolved before provider dispatch; without `arn`, the resolved alias value is sent as the Bedrock model/profile identifier directly. +- **ARN vs aliases**: Put the ARN prefix in `arn` and the model/inference profile resource ID only in the key-level `aliases` map — never the full ARN in alias values. See [How to Use ARNs and Application Inference Profiles](/providers/supported-providers/bedrock#how-to-use-arns-and-application-inference-profiles) for details. 
### Google Vertex @@ -1343,15 +1343,16 @@ curl --location 'http://localhost:8080/api/providers' \ { "name": "vertex-key-1", "value": "env.VERTEX_API_KEY", - "models": ["gemini-pro", "gemini-pro-vision"], + "models": ["gemini-pro", "gemini-pro-vision", "123456789", "fine-tuned-gemini-2.5-pro"], "weight": 1.0, + "aliases": { + "fine-tuned-gemini-2.5-pro": "123456789" + }, "vertex_key_config": { "project_id": "env.VERTEX_PROJECT_ID", + "project_number": "env.VERTEX_PROJECT_NUMBER", "region": "us-central1", - "auth_credentials": "env.VERTEX_CREDENTIALS", - "deployments": { - "fine-tuned-gemini-2.5-pro": "123456789" - } + "auth_credentials": "env.VERTEX_CREDENTIALS" } } ] @@ -1370,15 +1371,16 @@ curl --location 'http://localhost:8080/api/providers' \ { "name": "vertex-key-1", "value": "env.VERTEX_API_KEY", - "models": ["gemini-pro", "gemini-pro-vision"], + "models": ["gemini-pro", "gemini-pro-vision", "123456789", "fine-tuned-gemini-2.5-pro"], "weight": 1.0, + "aliases": { + "fine-tuned-gemini-2.5-pro": "123456789" + }, "vertex_key_config": { "project_id": "env.VERTEX_PROJECT_ID", + "project_number": "env.VERTEX_PROJECT_NUMBER", "region": "us-central1", - "auth_credentials": "env.VERTEX_CREDENTIALS", - "deployments": { - "fine-tuned-gemini-2.5-pro": "123456789" - } + "auth_credentials": "env.VERTEX_CREDENTIALS" } } ] @@ -1395,7 +1397,7 @@ curl --location 'http://localhost:8080/api/providers' \ - You can leave both API Key and Auth Credentials empty to use service account authentication from the environment. - You must set Project Number in Key config if using fine-tuned models. - API Key Authentication is only supported for Gemini and fine-tuned models. -- You can use custom fine-tuned models by passing `vertex/` or `vertex/` if you have set the deployments in the key config. +- You can use custom fine-tuned models by passing `vertex/` or `vertex/` if you have set the aliases on the key. Vertex AI support for fine-tuned models is currently in beta. 
Requests to non-Gemini fine-tuned models may fail, so please test and report any issues. diff --git a/docs/quickstart/go-sdk/provider-configuration.mdx b/docs/quickstart/go-sdk/provider-configuration.mdx index 448d5e8b39..87c2901ece 100644 --- a/docs/quickstart/go-sdk/provider-configuration.mdx +++ b/docs/quickstart/go-sdk/provider-configuration.mdx @@ -417,16 +417,17 @@ func (a *MyAccount) GetKeysForProvider(ctx *context.Context, provider schemas.Mo Value: "", // Leave empty for Service Principal auth Models: []string{"gpt-4o", "gpt-4o-mini"}, Weight: 1.0, + Aliases: schemas.KeyAliases{ + "gpt-4o": "gpt-4o-deployment", + "gpt-4o-mini": "gpt-4o-mini-deployment", + }, AzureKeyConfig: &schemas.AzureKeyConfig{ Endpoint: os.Getenv("AZURE_ENDPOINT"), ClientID: bifrost.Ptr(os.Getenv("AZURE_CLIENT_ID")), ClientSecret: bifrost.Ptr(os.Getenv("AZURE_CLIENT_SECRET")), TenantID: bifrost.Ptr(os.Getenv("AZURE_TENANT_ID")), - Deployments: map[string]string{ - "gpt-4o": "gpt-4o-deployment", - "gpt-4o-mini": "gpt-4o-mini-deployment", - }, - APIVersion: bifrost.Ptr("2024-08-01-preview"), + Scopes: []string{"https://cognitiveservices.azure.com/.default"}, + APIVersion: bifrost.Ptr("2024-08-01-preview"), }, }, }, nil @@ -448,12 +449,12 @@ func (a *MyAccount) GetKeysForProvider(ctx *context.Context, provider schemas.Mo Value: os.Getenv("AZURE_OPENAI_KEY"), Models: []string{"gpt-4o", "gpt-4o-mini"}, Weight: 1.0, + Aliases: schemas.KeyAliases{ + "gpt-4o": "gpt-4o-deployment", + "gpt-4o-mini": "gpt-4o-mini-deployment", + }, AzureKeyConfig: &schemas.AzureKeyConfig{ - Endpoint: os.Getenv("AZURE_ENDPOINT"), - Deployments: map[string]string{ - "gpt-4o": "gpt-4o-deployment", - "gpt-4o-mini": "gpt-4o-mini-deployment", - }, + Endpoint: os.Getenv("AZURE_ENDPOINT"), APIVersion: bifrost.Ptr("2024-08-01-preview"), }, }, @@ -479,19 +480,19 @@ func (a *MyAccount) GetKeysForProvider(ctx *context.Context, provider schemas.Mo case schemas.Bedrock: return []schemas.Key{ { - Models: 
[]string{"anthropic.claude-3-sonnet-20240229-v1:0", "anthropic.claude-v2:1"}, + Models: []string{"anthropic.claude-3-sonnet-20240229-v1:0", "anthropic.claude-v2:1", "claude-3-sonnet"}, Weight: 1.0, Value: os.Getenv("AWS_API_KEY"), // Leave empty for IAM role authentication + // Model profiles (inference profiles): map short names to profile resource IDs + Aliases: schemas.KeyAliases{ + "claude-3-sonnet": "us.anthropic.claude-3-sonnet-20240229-v1:0", + }, BedrockKeyConfig: &schemas.BedrockKeyConfig{ AccessKey: os.Getenv("AWS_ACCESS_KEY_ID"), // Leave empty for API Key authentication or system's IAM pickup SecretKey: os.Getenv("AWS_SECRET_ACCESS_KEY"), // Leave empty for API Key authentication or system's IAM pickup SessionToken: bifrost.Ptr(os.Getenv("AWS_SESSION_TOKEN")), // Optional Region: bifrost.Ptr("us-east-1"), - // For model profiles (inference profiles) - Deployments: map[string]string{ - "claude-3-sonnet": "us.anthropic.claude-3-sonnet-20240229-v1:0", - }, - // For direct model access without profiles + // ARN prefix for profile URLs; put resource IDs only in Aliases, not full ARNs ARN: bifrost.Ptr("arn:aws:bedrock:us-east-1:123456789012:inference-profile"), }, }, @@ -504,9 +505,9 @@ func (a *MyAccount) GetKeysForProvider(ctx *context.Context, provider schemas.Mo **Notes:** - If using API Key authentication, set `Value` field to the API key, else leave it empty for IAM role authentication. - In IAM role authentication, if both `AccessKey` and `SecretKey` are empty, Bifrost uses IAM from the environment. -- `ARN` is required for URL formation - `Deployments` mapping is ignored without it. -- When using `ARN` + `Deployments`, Bifrost uses model profiles; otherwise forms path with incoming model name directly. -- **ARN vs Deployments**: Put the ARN prefix in `ARN` and the model/inference profile resource ID only in `Deployments` — never the full ARN in Deployments. 
See [How to Use ARNs and Application Inference Profiles](/providers/supported-providers/bedrock#how-to-use-arns-and-application-inference-profiles) for details. +- `ARN` is required when you want Bifrost to build a full inference-profile ARN from an alias target. +- Aliases are still resolved before provider dispatch; without `ARN`, the resolved alias value is sent as the Bedrock model/profile identifier directly. +- **ARN vs Aliases**: Put the ARN prefix in `ARN` and the model/inference profile resource ID only in `Aliases` — never the full ARN in alias values. See [How to Use ARNs and Application Inference Profiles](/providers/supported-providers/bedrock#how-to-use-arns-and-application-inference-profiles) for details. @@ -521,16 +522,16 @@ func (a *MyAccount) GetKeysForProvider(ctx *context.Context, provider schemas.Mo return []schemas.Key{ { Value: os.Getenv("VERTEX_API_KEY"), // only when using gemini or fine-tuned models - Models: []string{"gemini-pro", "gemini-pro-vision"}, + Models: []string{"gemini-pro", "gemini-pro-vision", "fine-tuned-gemini-2.5-pro"}, Weight: 1.0, + Aliases: schemas.KeyAliases{ + "fine-tuned-gemini-2.5-pro": "123456789", + }, VertexKeyConfig: &schemas.VertexKeyConfig{ - ProjectID: os.Getenv("VERTEX_PROJECT_ID"), // GCP project ID - ProjectNumber: os.Getenv("VERTEX_PROJECT_NUMBER"), // GCP project number (only when using fine-tuned models) - Region: "us-central1", // GCP region - AuthCredentials: os.Getenv("VERTEX_CREDENTIALS"), // Service account JSON - Deployments: map[string]string{ - "fine-tuned-gemini-2.5-pro": "123456789" - }, + ProjectID: os.Getenv("VERTEX_PROJECT_ID"), // GCP project ID + ProjectNumber: os.Getenv("VERTEX_PROJECT_NUMBER"), // GCP project number (only when using fine-tuned models) + Region: "us-central1", // GCP region + AuthCredentials: os.Getenv("VERTEX_CREDENTIALS"), // Service account JSON }, }, }, nil @@ -543,7 +544,7 @@ func (a *MyAccount) GetKeysForProvider(ctx *context.Context, provider schemas.Mo - You can 
leave both API Key and Auth Credentials empty to use service account authentication from the environment. - You must set Project Number if using fine-tuned models. - API Key Authentication is only supported for Gemini and fine-tuned models. -- You can use custom fine-tuned models by passing `vertex/` or `vertex/` if you have set the deployments in the key config. +- You can use custom fine-tuned models by passing `vertex/` if you have set the aliases on the key. Vertex AI support for fine-tuned models is currently in beta. Requests to non-Gemini fine-tuned models may fail, so please test and report any issues. diff --git a/framework/configstore/clientconfig.go b/framework/configstore/clientconfig.go index 3c5d7d5b78..7e2b3fe6d1 100644 --- a/framework/configstore/clientconfig.go +++ b/framework/configstore/clientconfig.go @@ -4,6 +4,7 @@ import ( "crypto/sha256" "encoding/hex" "encoding/json" + "maps" "sort" "strconv" @@ -338,6 +339,9 @@ func (p *ProviderConfig) Redacted() *ProviderConfig { enabled := *key.Enabled redactedConfig.Keys[i].Enabled = &enabled } + if key.Aliases != nil { + redactedConfig.Keys[i].Aliases = maps.Clone(key.Aliases) + } redactedConfig.Keys[i].Value = *key.Value.Redacted() // Add back use for batch api if key.UseForBatchAPI != nil { @@ -352,9 +356,7 @@ func (p *ProviderConfig) Redacted() *ProviderConfig { // Redact Azure key config if present if key.AzureKeyConfig != nil { - azureConfig := &schemas.AzureKeyConfig{ - Deployments: key.AzureKeyConfig.Deployments, - } + azureConfig := &schemas.AzureKeyConfig{} azureConfig.Endpoint = *key.AzureKeyConfig.Endpoint.Redacted() azureConfig.APIVersion = key.AzureKeyConfig.APIVersion if key.AzureKeyConfig.ClientID != nil { @@ -374,9 +376,7 @@ func (p *ProviderConfig) Redacted() *ProviderConfig { // Redact Vertex key config if present if key.VertexKeyConfig != nil { - vertexConfig := &schemas.VertexKeyConfig{ - Deployments: key.VertexKeyConfig.Deployments, - } + vertexConfig := &schemas.VertexKeyConfig{} 
vertexConfig.ProjectID = *key.VertexKeyConfig.ProjectID.Redacted() vertexConfig.ProjectNumber = *key.VertexKeyConfig.ProjectNumber.Redacted() vertexConfig.Region = *key.VertexKeyConfig.Region.Redacted() @@ -386,9 +386,7 @@ func (p *ProviderConfig) Redacted() *ProviderConfig { // Redact Bedrock key config if present if key.BedrockKeyConfig != nil { - bedrockConfig := &schemas.BedrockKeyConfig{ - Deployments: key.BedrockKeyConfig.Deployments, - } + bedrockConfig := &schemas.BedrockKeyConfig{} bedrockConfig.AccessKey = *key.BedrockKeyConfig.AccessKey.Redacted() bedrockConfig.SecretKey = *key.BedrockKeyConfig.SecretKey.Redacted() if key.BedrockKeyConfig.SessionToken != nil { @@ -416,13 +414,6 @@ func (p *ProviderConfig) Redacted() *ProviderConfig { redactedConfig.Keys[i].BedrockKeyConfig = bedrockConfig } - if key.ReplicateKeyConfig != nil { - replicateConfig := &schemas.ReplicateKeyConfig{ - Deployments: key.ReplicateKeyConfig.Deployments, - } - redactedConfig.Keys[i].ReplicateKeyConfig = replicateConfig - } - if key.VLLMKeyConfig != nil { vllmConfig := &schemas.VLLMKeyConfig{ ModelName: key.VLLMKeyConfig.ModelName, @@ -431,6 +422,13 @@ func (p *ProviderConfig) Redacted() *ProviderConfig { redactedConfig.Keys[i].VLLMKeyConfig = vllmConfig } + if key.ReplicateKeyConfig != nil { + replicateConfig := &schemas.ReplicateKeyConfig{ + UseDeploymentsEndpoint: key.ReplicateKeyConfig.UseDeploymentsEndpoint, + } + redactedConfig.Keys[i].ReplicateKeyConfig = replicateConfig + } + if key.OllamaKeyConfig != nil { ollamaConfig := &schemas.OllamaKeyConfig{} ollamaConfig.URL = *key.OllamaKeyConfig.URL.Redacted() @@ -584,9 +582,9 @@ func GenerateKeyHash(key schemas.Key) (string, error) { } hash.Write(data) } - // Hash ReplicateKeyConfig - if key.ReplicateKeyConfig != nil { - data, err := sonic.Marshal(key.ReplicateKeyConfig) + // Hash Aliases + if key.Aliases != nil { + data, err := sonic.Marshal(key.Aliases) if err != nil { return "", err } @@ -600,6 +598,14 @@ func 
GenerateKeyHash(key schemas.Key) (string, error) { } hash.Write(data) } + // Hash ReplicateKeyConfig + if key.ReplicateKeyConfig != nil { + data, err := sonic.Marshal(key.ReplicateKeyConfig) + if err != nil { + return "", err + } + hash.Write(data) + } // Hash OllamaKeyConfig if key.OllamaKeyConfig != nil { data, err := sonic.Marshal(key.OllamaKeyConfig) diff --git a/framework/configstore/encryption_test.go b/framework/configstore/encryption_test.go index 0d1f3625cd..9ac36baede 100644 --- a/framework/configstore/encryption_test.go +++ b/framework/configstore/encryption_test.go @@ -725,7 +725,7 @@ func TestEncryptPlaintextKeys_BedrockFields_EncryptsAndDecryptsCorrectly(t *test now := time.Now().UTC().Format("2006-01-02 15:04:05") insertPlaintextRow(t, db, - `INSERT INTO config_keys (name, provider_id, provider, key_id, value, bedrock_access_key, bedrock_secret_key, bedrock_session_token, bedrock_region, bedrock_arn, bedrock_deployments_json, bedrock_batch_s3_config_json, encryption_status, created_at, updated_at) + `INSERT INTO config_keys (name, provider_id, provider, key_id, value, bedrock_access_key, bedrock_secret_key, bedrock_session_token, bedrock_region, bedrock_arn, aliases_json, bedrock_batch_s3_config_json, encryption_status, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, 'plain_text', ?, ?)`, "bedrock-key", 1, "bedrock", "br-1", "sk-bedrock-key-value", "AKIAIOSFODNN7EXAMPLE", "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", "FwoGZXIvYXdzEBYaDH7sampleSessionToken", @@ -747,9 +747,15 @@ func TestEncryptPlaintextKeys_BedrockFields_EncryptsAndDecryptsCorrectly(t *test assert.NotEqual(t, "FwoGZXIvYXdzEBYaDH7sampleSessionToken", raw["bedrock_session_token"]) assert.NotEqual(t, "us-west-2", raw["bedrock_region"]) assert.NotEqual(t, "arn:aws:iam::123456789:role/bedrock", raw["bedrock_arn"]) - if rawDeploy, ok := raw["bedrock_deployments_json"].(string); ok { - assert.NotContains(t, rawDeploy, "profile-claude") + rawAliasesVal := 
raw["aliases_json"] + var rawAliasesStr string + switch v := rawAliasesVal.(type) { + case string: + rawAliasesStr = v + case []byte: + rawAliasesStr = string(v) } + assert.NotContains(t, rawAliasesStr, "profile-claude") if rawBatch, ok := raw["bedrock_batch_s3_config_json"].(string); ok { assert.NotContains(t, rawBatch, "my-bucket") } @@ -767,7 +773,7 @@ func TestEncryptPlaintextKeys_BedrockFields_EncryptsAndDecryptsCorrectly(t *test assert.Equal(t, "us-west-2", found.BedrockKeyConfig.Region.GetValue()) require.NotNil(t, found.BedrockKeyConfig.ARN) assert.Equal(t, "arn:aws:iam::123456789:role/bedrock", found.BedrockKeyConfig.ARN.GetValue()) - assert.Equal(t, "profile-claude", found.BedrockKeyConfig.Deployments["claude-3"]) + assert.Equal(t, "profile-claude", found.Aliases["claude-3"]) require.NotNil(t, found.BedrockKeyConfig.BatchS3Config) require.Len(t, found.BedrockKeyConfig.BatchS3Config.Buckets, 1) assert.Equal(t, "my-bucket", found.BedrockKeyConfig.BatchS3Config.Buckets[0].BucketName) diff --git a/framework/configstore/migrations.go b/framework/configstore/migrations.go index d1ce44139a..d9592fdd3e 100644 --- a/framework/configstore/migrations.go +++ b/framework/configstore/migrations.go @@ -346,6 +346,12 @@ func triggerMigrations(ctx context.Context, db *gorm.DB) error { if err := migrationAddChainRuleColumnToRoutingRules(ctx, db); err != nil { return err } + if err := migrationDropDeploymentColumnsAndAddAliases(ctx, db); err != nil { + return err + } + if err := migrationAddReplicateKeyConfigColumn(ctx, db); err != nil { + return err + } if err := migrationAddBudgetCalendarAlignedColumn(ctx, db); err != nil { return err } @@ -1320,15 +1326,15 @@ func migrationAddVertexProjectNumberColumn(ctx context.Context, db *gorm.DB) err return nil } -// migrationAddVertexDeploymentsJSONColumn adds the vertex_deployments_json column to the key table +// migrationAddVertexDeploymentsJSONColumn adds the vertex_deployments_json column to the key table. 
+// This column is later dropped by migrationDropDeploymentColumnsAndAddAliases after data is migrated. func migrationAddVertexDeploymentsJSONColumn(ctx context.Context, db *gorm.DB) error { m := migrator.New(db, migrator.DefaultOptions, []*migrator.Migration{{ ID: "add_vertex_deployments_json_column", Migrate: func(tx *gorm.DB) error { tx = tx.WithContext(ctx) - migrator := tx.Migrator() - if !migrator.HasColumn(&tables.TableKey{}, "vertex_deployments_json") { - if err := migrator.AddColumn(&tables.TableKey{}, "vertex_deployments_json"); err != nil { + if !tx.Migrator().HasColumn(&tables.TableKey{}, "vertex_deployments_json") { + if err := tx.Exec("ALTER TABLE config_keys ADD COLUMN vertex_deployments_json TEXT").Error; err != nil { return err } } @@ -1336,15 +1342,15 @@ func migrationAddVertexDeploymentsJSONColumn(ctx context.Context, db *gorm.DB) e }, Rollback: func(tx *gorm.DB) error { tx = tx.WithContext(ctx) - migrator := tx.Migrator() - if err := migrator.DropColumn(&tables.TableKey{}, "vertex_deployments_json"); err != nil { - return err + if tx.Migrator().HasColumn(&tables.TableKey{}, "vertex_deployments_json") { + if err := tx.Exec("ALTER TABLE config_keys DROP COLUMN vertex_deployments_json").Error; err != nil { + return err + } } return nil }, }}) - err := m.Migrate() - if err != nil { + if err := m.Migrate(); err != nil { return fmt.Errorf("error while running vertex deployments JSON migration: %s", err.Error()) } return nil @@ -2044,14 +2050,14 @@ func migrationAddConfigHashColumn(ctx context.Context, db *gorm.DB) error { if key.ConfigHash == "" { // Convert to schemas.Key and generate hash schemaKey := schemas.Key{ - Name: key.Name, - Value: key.Value, - Models: key.Models, - Weight: getWeight(key.Weight), - AzureKeyConfig: key.AzureKeyConfig, - VertexKeyConfig: key.VertexKeyConfig, - BedrockKeyConfig: key.BedrockKeyConfig, - ReplicateKeyConfig: key.ReplicateKeyConfig, + Name: key.Name, + Value: key.Value, + Models: key.Models, + Weight: 
getWeight(key.Weight), + AzureKeyConfig: key.AzureKeyConfig, + VertexKeyConfig: key.VertexKeyConfig, + BedrockKeyConfig: key.BedrockKeyConfig, + Aliases: key.Aliases, } hash, err := GenerateKeyHash(schemaKey) if err != nil { @@ -3520,15 +3526,15 @@ func migrationAddAzureScopesColumn(ctx context.Context, db *gorm.DB) error { return nil } -// migrationAddReplicateDeploymentsJSONColumn adds the replicate_deployments_json column to the key table +// migrationAddReplicateDeploymentsJSONColumn adds the replicate_deployments_json column to the key table. +// This column is later dropped by migrationDropDeploymentColumnsAndAddAliases after data is migrated. func migrationAddReplicateDeploymentsJSONColumn(ctx context.Context, db *gorm.DB) error { m := migrator.New(db, migrator.DefaultOptions, []*migrator.Migration{{ ID: "add_replicate_deployments_json_column", Migrate: func(tx *gorm.DB) error { tx = tx.WithContext(ctx) - migrator := tx.Migrator() - if !migrator.HasColumn(&tables.TableKey{}, "replicate_deployments_json") { - if err := migrator.AddColumn(&tables.TableKey{}, "replicate_deployments_json"); err != nil { + if !tx.Migrator().HasColumn(&tables.TableKey{}, "replicate_deployments_json") { + if err := tx.Exec("ALTER TABLE config_keys ADD COLUMN replicate_deployments_json TEXT").Error; err != nil { return err } } @@ -3536,20 +3542,123 @@ func migrationAddReplicateDeploymentsJSONColumn(ctx context.Context, db *gorm.DB }, Rollback: func(tx *gorm.DB) error { tx = tx.WithContext(ctx) - migrator := tx.Migrator() - if err := migrator.DropColumn(&tables.TableKey{}, "replicate_deployments_json"); err != nil { - return err + if tx.Migrator().HasColumn(&tables.TableKey{}, "replicate_deployments_json") { + if err := tx.Exec("ALTER TABLE config_keys DROP COLUMN replicate_deployments_json").Error; err != nil { + return err + } } return nil }, }}) - err := m.Migrate() - if err != nil { + if err := m.Migrate(); err != nil { return fmt.Errorf("error while running replicate deployments 
JSON migration: %s", err.Error()) } return nil } +// migrationDropDeploymentColumnsAndAddAliases adds the unified aliases_json column, migrates +// existing per-provider deployment data into it, then drops the legacy columns. +// Only one deployment column will be populated per row (they were mutually exclusive). +func migrationDropDeploymentColumnsAndAddAliases(ctx context.Context, db *gorm.DB) error { + m := migrator.New(db, migrator.DefaultOptions, []*migrator.Migration{{ + ID: "drop_deployment_columns_and_add_aliases", + Migrate: func(tx *gorm.DB) error { + tx = tx.WithContext(ctx) + m := tx.Migrator() + + // Add aliases_json column first + if !m.HasColumn(&tables.TableKey{}, "aliases_json") { + if err := m.AddColumn(&tables.TableKey{}, "aliases_json"); err != nil { + return err + } + } + + // Copy data from whichever legacy deployment column is populated into aliases_json. + // Only rows where aliases_json is not already set are touched. + // Exactly one deployment column will be non-null per row (they were mutually exclusive). + for _, col := range []string{ + "azure_deployments_json", + "vertex_deployments_json", + "bedrock_deployments_json", + "replicate_deployments_json", + } { + if !m.HasColumn(&tables.TableKey{}, col) { + continue + } + if err := tx.Exec( + "UPDATE config_keys SET aliases_json = " + col + + " WHERE aliases_json IS NULL AND " + col + " IS NOT NULL AND " + col + " != ''", + ).Error; err != nil { + return err + } + } + + // Drop legacy deployment columns + for _, col := range []string{ + "azure_deployments_json", + "vertex_deployments_json", + "bedrock_deployments_json", + "replicate_deployments_json", + } { + if m.HasColumn(&tables.TableKey{}, col) { + if err := tx.Exec("ALTER TABLE config_keys DROP COLUMN " + col).Error; err != nil { + return err + } + } + } + + // Recompute config_hash for keys that had aliases_json populated above, + // since aliases_json is part of the hash input and these rows now have stale hashes. 
+ var affectedKeys []tables.TableKey + if err := tx.Where( + "aliases_json IS NOT NULL AND aliases_json != ? AND aliases_json != ?", "", "{}", + ).Find(&affectedKeys).Error; err != nil { + return fmt.Errorf("failed to fetch keys for hash recomputation: %w", err) + } + for _, key := range affectedKeys { + schemaKey := schemas.Key{ + Name: key.Name, + Value: key.Value, + Models: key.Models, + BlacklistedModels: key.BlacklistedModels, + Weight: getWeight(key.Weight), + AzureKeyConfig: key.AzureKeyConfig, + VertexKeyConfig: key.VertexKeyConfig, + BedrockKeyConfig: key.BedrockKeyConfig, + Aliases: key.Aliases, + VLLMKeyConfig: key.VLLMKeyConfig, + ReplicateKeyConfig: key.ReplicateKeyConfig, + Enabled: key.Enabled, + UseForBatchAPI: key.UseForBatchAPI, + } + hash, err := GenerateKeyHash(schemaKey) + if err != nil { + return fmt.Errorf("failed to generate hash for key %s: %w", key.Name, err) + } + if err := tx.Model(&key).Update("config_hash", hash).Error; err != nil { + return fmt.Errorf("failed to update config_hash for key %s: %w", key.Name, err) + } + log.Printf("[Migration] Recomputed config_hash for key '%s' after aliases migration", key.Name) + } + return nil + }, + Rollback: func(tx *gorm.DB) error { + tx = tx.WithContext(ctx) + m := tx.Migrator() + if m.HasColumn(&tables.TableKey{}, "aliases_json") { + if err := m.DropColumn(&tables.TableKey{}, "aliases_json"); err != nil { + return err + } + } + return nil + }, + }}) + if err := m.Migrate(); err != nil { + return fmt.Errorf("error while running drop deployment columns and add aliases migration: %s", err.Error()) + } + return nil +} + // migrationAddKeyStatusColumns adds status and description columns to config_keys table // These columns track the status and description of each individual key func migrationAddKeyStatusColumns(ctx context.Context, db *gorm.DB) error { @@ -5095,8 +5204,9 @@ func migrationBackfillAllowedModelsWildcard(ctx context.Context, db *gorm.DB) er AzureKeyConfig: key.AzureKeyConfig, 
VertexKeyConfig: key.VertexKeyConfig, BedrockKeyConfig: key.BedrockKeyConfig, - ReplicateKeyConfig: key.ReplicateKeyConfig, + Aliases: key.Aliases, VLLMKeyConfig: key.VLLMKeyConfig, + ReplicateKeyConfig: key.ReplicateKeyConfig, OllamaKeyConfig: key.OllamaKeyConfig, SGLKeyConfig: key.SGLKeyConfig, Enabled: key.Enabled, @@ -5375,6 +5485,82 @@ func migrationAddChainRuleColumnToRoutingRules(ctx context.Context, db *gorm.DB) return nil } +// migrationAddReplicateKeyConfigColumn adds the replicate_use_deployments_endpoint column to the key table +func migrationAddReplicateKeyConfigColumn(ctx context.Context, db *gorm.DB) error { + m := migrator.New(db, migrator.DefaultOptions, []*migrator.Migration{{ + ID: "add_replicate_key_config_column", + Migrate: func(tx *gorm.DB) error { + tx = tx.WithContext(ctx) + mg := tx.Migrator() + if !mg.HasColumn(&tables.TableKey{}, "replicate_use_deployments_endpoint") { + if err := mg.AddColumn(&tables.TableKey{}, "replicate_use_deployments_endpoint"); err != nil { + return err + } + // Backfill: Replicate keys that had deployments configured (now in aliases_json after + // migrationDropDeploymentColumnsAndAddAliases) were using the deployments endpoint. + trueVal := true + if err := tx.Model(&tables.TableKey{}). + Where("provider = ? AND aliases_json IS NOT NULL AND aliases_json != ? AND aliases_json != ?", + string(schemas.Replicate), "", "{}", + ). + Update("ReplicateUseDeploymentsEndpoint", &trueVal).Error; err != nil { + return err + } + + // Recompute config_hash for Replicate keys that were updated above, + // since replicate_use_deployments_endpoint is part of the hash input. + var affectedKeys []tables.TableKey + if err := tx.Where( + "provider = ? 
AND replicate_use_deployments_endpoint IS NOT NULL", + string(schemas.Replicate), + ).Find(&affectedKeys).Error; err != nil { + return fmt.Errorf("failed to fetch replicate keys for hash recomputation: %w", err) + } + for _, key := range affectedKeys { + schemaKey := schemas.Key{ + Name: key.Name, + Value: key.Value, + Models: key.Models, + BlacklistedModels: key.BlacklistedModels, + Weight: getWeight(key.Weight), + AzureKeyConfig: key.AzureKeyConfig, + VertexKeyConfig: key.VertexKeyConfig, + BedrockKeyConfig: key.BedrockKeyConfig, + Aliases: key.Aliases, + VLLMKeyConfig: key.VLLMKeyConfig, + ReplicateKeyConfig: key.ReplicateKeyConfig, + Enabled: key.Enabled, + UseForBatchAPI: key.UseForBatchAPI, + } + hash, err := GenerateKeyHash(schemaKey) + if err != nil { + return fmt.Errorf("failed to generate hash for key %s: %w", key.Name, err) + } + if err := tx.Model(&key).Update("config_hash", hash).Error; err != nil { + return fmt.Errorf("failed to update config_hash for key %s: %w", key.Name, err) + } + log.Printf("[Migration] Recomputed config_hash for replicate key '%s' after replicate config backfill", key.Name) + } + } + return nil + }, + Rollback: func(tx *gorm.DB) error { + tx = tx.WithContext(ctx) + mg := tx.Migrator() + if mg.HasColumn(&tables.TableKey{}, "replicate_use_deployments_endpoint") { + if err := mg.DropColumn(&tables.TableKey{}, "replicate_use_deployments_endpoint"); err != nil { + return err + } + } + return nil + }, + }}) + if err := m.Migrate(); err != nil { + return fmt.Errorf("error running add_replicate_key_config_column migration: %s", err.Error()) + } + return nil +} + // migrationAddBudgetCalendarAlignedColumn adds the calendar_aligned column to the governance_budgets table. 
func migrationAddBudgetCalendarAlignedColumn(ctx context.Context, db *gorm.DB) error { m := migrator.New(db, migrator.DefaultOptions, []*migrator.Migration{{ diff --git a/framework/configstore/rdb.go b/framework/configstore/rdb.go index 3e942a14fc..ade838add8 100644 --- a/framework/configstore/rdb.go +++ b/framework/configstore/rdb.go @@ -48,8 +48,9 @@ func schemaKeyFromTableKey(dbKey tables.TableKey) schemas.Key { AzureKeyConfig: dbKey.AzureKeyConfig, VertexKeyConfig: dbKey.VertexKeyConfig, BedrockKeyConfig: dbKey.BedrockKeyConfig, - ReplicateKeyConfig: dbKey.ReplicateKeyConfig, + Aliases: dbKey.Aliases, VLLMKeyConfig: dbKey.VLLMKeyConfig, + ReplicateKeyConfig: dbKey.ReplicateKeyConfig, OllamaKeyConfig: dbKey.OllamaKeyConfig, SGLKeyConfig: dbKey.SGLKeyConfig, ConfigHash: dbKey.ConfigHash, @@ -73,8 +74,9 @@ func tableKeyFromSchemaKey(provider tables.TableProvider, key schemas.Key) (tabl AzureKeyConfig: key.AzureKeyConfig, VertexKeyConfig: key.VertexKeyConfig, BedrockKeyConfig: key.BedrockKeyConfig, - ReplicateKeyConfig: key.ReplicateKeyConfig, + Aliases: key.Aliases, VLLMKeyConfig: key.VLLMKeyConfig, + ReplicateKeyConfig: key.ReplicateKeyConfig, OllamaKeyConfig: key.OllamaKeyConfig, SGLKeyConfig: key.SGLKeyConfig, ConfigHash: key.ConfigHash, @@ -379,8 +381,9 @@ func (s *RDBConfigStore) UpdateProvidersConfig(ctx context.Context, providers ma AzureKeyConfig: key.AzureKeyConfig, VertexKeyConfig: key.VertexKeyConfig, BedrockKeyConfig: key.BedrockKeyConfig, - ReplicateKeyConfig: key.ReplicateKeyConfig, + Aliases: key.Aliases, VLLMKeyConfig: key.VLLMKeyConfig, + ReplicateKeyConfig: key.ReplicateKeyConfig, OllamaKeyConfig: key.OllamaKeyConfig, SGLKeyConfig: key.SGLKeyConfig, ConfigHash: keyHash, @@ -552,8 +555,9 @@ func (s *RDBConfigStore) UpdateProvider(ctx context.Context, provider schemas.Mo AzureKeyConfig: key.AzureKeyConfig, VertexKeyConfig: key.VertexKeyConfig, BedrockKeyConfig: key.BedrockKeyConfig, - ReplicateKeyConfig: key.ReplicateKeyConfig, + Aliases: 
key.Aliases, VLLMKeyConfig: key.VLLMKeyConfig, + ReplicateKeyConfig: key.ReplicateKeyConfig, OllamaKeyConfig: key.OllamaKeyConfig, SGLKeyConfig: key.SGLKeyConfig, ConfigHash: keyHash, @@ -677,8 +681,9 @@ func (s *RDBConfigStore) AddProvider(ctx context.Context, provider schemas.Model AzureKeyConfig: key.AzureKeyConfig, VertexKeyConfig: key.VertexKeyConfig, BedrockKeyConfig: key.BedrockKeyConfig, - ReplicateKeyConfig: key.ReplicateKeyConfig, + Aliases: key.Aliases, VLLMKeyConfig: key.VLLMKeyConfig, + ReplicateKeyConfig: key.ReplicateKeyConfig, OllamaKeyConfig: key.OllamaKeyConfig, SGLKeyConfig: key.SGLKeyConfig, ConfigHash: key.ConfigHash, diff --git a/framework/configstore/tables/encryption_test.go b/framework/configstore/tables/encryption_test.go index 8807570e14..58297384fd 100644 --- a/framework/configstore/tables/encryption_test.go +++ b/framework/configstore/tables/encryption_test.go @@ -175,12 +175,12 @@ func TestTableKey_BedrockFieldsEncryptDecrypt(t *testing.T) { Provider: "bedrock", KeyID: "bedrock-uuid-1", Value: *schemas.NewEnvVar("bedrock-val"), + Aliases: schemas.KeyAliases{"model-a": "profile-a"}, BedrockKeyConfig: &schemas.BedrockKeyConfig{ - AccessKey: *schemas.NewEnvVar("AKIAIOSFODNN7EXAMPLE"), - SecretKey: *schemas.NewEnvVar("wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"), - Region: schemas.NewEnvVar("us-west-2"), - ARN: schemas.NewEnvVar("arn:aws:iam::123456789:role/test"), - Deployments: map[string]string{"model-a": "profile-a"}, + AccessKey: *schemas.NewEnvVar("AKIAIOSFODNN7EXAMPLE"), + SecretKey: *schemas.NewEnvVar("wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"), + Region: schemas.NewEnvVar("us-west-2"), + ARN: schemas.NewEnvVar("arn:aws:iam::123456789:role/test"), BatchS3Config: &schemas.BatchS3Config{ Buckets: []schemas.S3BucketConfig{ {BucketName: "my-batch-bucket", Prefix: "jobs/", IsDefault: true}, @@ -197,9 +197,17 @@ func TestTableKey_BedrockFieldsEncryptDecrypt(t *testing.T) { assert.NotEqual(t, "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", 
raw["bedrock_secret_key"]) assert.NotEqual(t, "us-west-2", raw["bedrock_region"]) assert.NotEqual(t, "arn:aws:iam::123456789:role/test", raw["bedrock_arn"]) - if rawDeploy, ok := raw["bedrock_deployments_json"].(string); ok { - assert.NotContains(t, rawDeploy, "profile-a") - } + rawAliasesVal := raw["aliases_json"] + require.NotNil(t, rawAliasesVal, "aliases_json should be present in raw row") + var rawAliasesStr string + switch v := rawAliasesVal.(type) { + case string: + rawAliasesStr = v + case []byte: + rawAliasesStr = string(v) + } + require.NotEmpty(t, rawAliasesStr, "aliases_json should not be empty") + assert.NotContains(t, rawAliasesStr, "profile-a") if rawBatch, ok := raw["bedrock_batch_s3_config_json"].(string); ok { assert.NotContains(t, rawBatch, "my-batch-bucket") } @@ -213,7 +221,7 @@ func TestTableKey_BedrockFieldsEncryptDecrypt(t *testing.T) { assert.Equal(t, "us-west-2", found.BedrockKeyConfig.Region.GetValue()) require.NotNil(t, found.BedrockKeyConfig.ARN) assert.Equal(t, "arn:aws:iam::123456789:role/test", found.BedrockKeyConfig.ARN.GetValue()) - assert.Equal(t, "profile-a", found.BedrockKeyConfig.Deployments["model-a"]) + assert.Equal(t, "profile-a", found.Aliases["model-a"]) require.NotNil(t, found.BedrockKeyConfig.BatchS3Config) require.Len(t, found.BedrockKeyConfig.BatchS3Config.Buckets, 1) assert.Equal(t, "my-batch-bucket", found.BedrockKeyConfig.BatchS3Config.Buckets[0].BucketName) @@ -1144,6 +1152,7 @@ func TestTableKey_AllProviderConfigs_EncryptDecrypt(t *testing.T) { Provider: "custom", KeyID: "multi-uuid", Value: *schemas.NewEnvVar("multi-api-key"), + Aliases: schemas.KeyAliases{"claude-3": "profile-claude"}, AzureKeyConfig: &schemas.AzureKeyConfig{ Endpoint: *schemas.NewEnvVar("https://azure.endpoint.com"), ClientID: schemas.NewEnvVar("multi-azure-cid"), @@ -1163,7 +1172,6 @@ func TestTableKey_AllProviderConfigs_EncryptDecrypt(t *testing.T) { SessionToken: sessionToken, Region: schemas.NewEnvVar("eu-west-1"), ARN: 
schemas.NewEnvVar("arn:aws:bedrock:eu-west-1:123:role"), - Deployments: map[string]string{"claude-3": "profile-claude"}, }, } @@ -1180,9 +1188,17 @@ func TestTableKey_AllProviderConfigs_EncryptDecrypt(t *testing.T) { assert.NotEqual(t, "us-central1", raw["vertex_region"]) assert.NotEqual(t, "eu-west-1", raw["bedrock_region"]) assert.NotEqual(t, "arn:aws:bedrock:eu-west-1:123:role", raw["bedrock_arn"]) - if rawDeploy, ok := raw["bedrock_deployments_json"].(string); ok { - assert.NotContains(t, rawDeploy, "profile-claude") - } + rawAliasesVal2 := raw["aliases_json"] + require.NotNil(t, rawAliasesVal2, "aliases_json should be present in raw row") + var rawAliasesStr2 string + switch v := rawAliasesVal2.(type) { + case string: + rawAliasesStr2 = v + case []byte: + rawAliasesStr2 = string(v) + } + require.NotEmpty(t, rawAliasesStr2, "aliases_json should not be empty") + assert.NotContains(t, rawAliasesStr2, "profile-claude") var found TableKey require.NoError(t, db.First(&found, key.ID).Error) @@ -1214,7 +1230,7 @@ func TestTableKey_AllProviderConfigs_EncryptDecrypt(t *testing.T) { assert.Equal(t, "eu-west-1", found.BedrockKeyConfig.Region.GetValue()) require.NotNil(t, found.BedrockKeyConfig.ARN) assert.Equal(t, "arn:aws:bedrock:eu-west-1:123:role", found.BedrockKeyConfig.ARN.GetValue()) - assert.Equal(t, "profile-claude", found.BedrockKeyConfig.Deployments["claude-3"]) + assert.Equal(t, "profile-claude", found.Aliases["claude-3"]) } // ============================================================================ @@ -1268,9 +1284,9 @@ func TestTableMCPClient_EncryptionDisabled_StoresPlaintext(t *testing.T) { db := setupTestDB(t) client := &TableMCPClient{ - ClientID: "mcp-dis-1", - Name: "disabled-mcp", - ConnectionType: "sse", + ClientID: "mcp-dis-1", + Name: "disabled-mcp", + ConnectionType: "sse", ConnectionString: schemas.NewEnvVar("https://mcp.example.com"), Headers: map[string]schemas.EnvVar{ "Authorization": *schemas.NewEnvVar("Bearer secret-token"), diff --git 
a/framework/configstore/tables/key.go b/framework/configstore/tables/key.go index 8c68243925..73ba61f2e1 100644 --- a/framework/configstore/tables/key.go +++ b/framework/configstore/tables/key.go @@ -29,21 +29,22 @@ type TableKey struct { // Config hash is used to detect changes synced from config.json file ConfigHash string `gorm:"type:varchar(255);null" json:"config_hash"` + // Unified aliases + AliasesJSON *string `gorm:"type:text" json:"-"` // JSON serialized schemas.KeyAliases + // Azure config fields (embedded instead of separate table for simplicity) - AzureEndpoint *schemas.EnvVar `gorm:"type:text" json:"azure_endpoint,omitempty"` - AzureAPIVersion *schemas.EnvVar `gorm:"type:text" json:"azure_api_version,omitempty"` - AzureDeploymentsJSON *string `gorm:"type:text" json:"-"` // JSON serialized map[string]string - AzureClientID *schemas.EnvVar `gorm:"type:text" json:"azure_client_id,omitempty"` - AzureClientSecret *schemas.EnvVar `gorm:"type:text" json:"azure_client_secret,omitempty"` - AzureTenantID *schemas.EnvVar `gorm:"type:text" json:"azure_tenant_id,omitempty"` - AzureScopesJSON *string `gorm:"column:azure_scopes;type:text" json:"-"` // JSON serialized []string + AzureEndpoint *schemas.EnvVar `gorm:"type:text" json:"azure_endpoint,omitempty"` + AzureAPIVersion *schemas.EnvVar `gorm:"type:text" json:"azure_api_version,omitempty"` + AzureClientID *schemas.EnvVar `gorm:"type:text" json:"azure_client_id,omitempty"` + AzureClientSecret *schemas.EnvVar `gorm:"type:text" json:"azure_client_secret,omitempty"` + AzureTenantID *schemas.EnvVar `gorm:"type:text" json:"azure_tenant_id,omitempty"` + AzureScopesJSON *string `gorm:"column:azure_scopes;type:text" json:"-"` // JSON serialized []string // Vertex config fields (embedded) VertexProjectID *schemas.EnvVar `gorm:"type:text" json:"vertex_project_id,omitempty"` VertexProjectNumber *schemas.EnvVar `gorm:"type:text" json:"vertex_project_number,omitempty"` VertexRegion *schemas.EnvVar `gorm:"type:text" 
json:"vertex_region,omitempty"` VertexAuthCredentials *schemas.EnvVar `gorm:"type:text" json:"vertex_auth_credentials,omitempty"` - VertexDeploymentsJSON *string `gorm:"type:text" json:"-"` // JSON serialized map[string]string // Bedrock config fields (embedded) BedrockAccessKey *schemas.EnvVar `gorm:"type:text" json:"bedrock_access_key,omitempty"` @@ -54,16 +55,15 @@ type TableKey struct { BedrockRoleARN *schemas.EnvVar `gorm:"type:text" json:"bedrock_role_arn,omitempty"` BedrockExternalID *schemas.EnvVar `gorm:"type:text" json:"bedrock_external_id,omitempty"` BedrockRoleSessionName *schemas.EnvVar `gorm:"type:text" json:"bedrock_role_session_name,omitempty"` - BedrockDeploymentsJSON *string `gorm:"type:text" json:"-"` // JSON serialized map[string]string BedrockBatchS3ConfigJSON *string `gorm:"type:text" json:"-"` // JSON serialized schemas.BatchS3Config - // Replicate config fields (embedded) - ReplicateDeploymentsJSON *string `gorm:"type:text" json:"-"` // JSON serialized map[string]string - // VLLM config fields (embedded) VLLMUrl *schemas.EnvVar `gorm:"type:text" json:"vllm_url,omitempty"` VLLMModelName *string `gorm:"type:varchar(255)" json:"vllm_model_name,omitempty"` + // Replicate config fields (embedded) + ReplicateUseDeploymentsEndpoint *bool `gorm:"column:replicate_use_deployments_endpoint" json:"replicate_use_deployments_endpoint,omitempty"` + // Ollama config fields (embedded) OllamaUrl *schemas.EnvVar `gorm:"type:text" json:"ollama_url,omitempty"` @@ -81,11 +81,12 @@ type TableKey struct { // Virtual fields for runtime use (not stored in DB) Models schemas.WhiteList `gorm:"-" json:"models"` // ["*"] allows all models; empty denies all (deny-by-default) BlacklistedModels schemas.BlackList `gorm:"-" json:"blacklisted_models"` + Aliases schemas.KeyAliases `gorm:"-" json:"aliases,omitempty"` AzureKeyConfig *schemas.AzureKeyConfig `gorm:"-" json:"azure_key_config,omitempty"` VertexKeyConfig *schemas.VertexKeyConfig `gorm:"-" 
json:"vertex_key_config,omitempty"` BedrockKeyConfig *schemas.BedrockKeyConfig `gorm:"-" json:"bedrock_key_config,omitempty"` - ReplicateKeyConfig *schemas.ReplicateKeyConfig `gorm:"-" json:"replicate_key_config,omitempty"` VLLMKeyConfig *schemas.VLLMKeyConfig `gorm:"-" json:"vllm_key_config,omitempty"` + ReplicateKeyConfig *schemas.ReplicateKeyConfig `gorm:"-" json:"replicate_key_config,omitempty"` OllamaKeyConfig *schemas.OllamaKeyConfig `gorm:"-" json:"ollama_key_config,omitempty"` SGLKeyConfig *schemas.SGLKeyConfig `gorm:"-" json:"sgl_key_config,omitempty"` } @@ -169,20 +170,9 @@ func (k *TableKey) BeforeSave(tx *gorm.DB) error { } else { k.AzureScopesJSON = nil } - if k.AzureKeyConfig.Deployments != nil { - data, err := json.Marshal(k.AzureKeyConfig.Deployments) - if err != nil { - return err - } - s := string(data) - k.AzureDeploymentsJSON = &s - } else { - k.AzureDeploymentsJSON = nil - } } else { k.AzureEndpoint = nil k.AzureAPIVersion = nil - k.AzureDeploymentsJSON = nil k.AzureClientID = nil k.AzureClientSecret = nil k.AzureTenantID = nil @@ -213,22 +203,11 @@ func (k *TableKey) BeforeSave(tx *gorm.DB) error { } else { k.VertexAuthCredentials = nil } - if k.VertexKeyConfig.Deployments != nil { - data, err := json.Marshal(k.VertexKeyConfig.Deployments) - if err != nil { - return err - } - s := string(data) - k.VertexDeploymentsJSON = &s - } else { - k.VertexDeploymentsJSON = nil - } } else { k.VertexProjectID = nil k.VertexProjectNumber = nil k.VertexRegion = nil k.VertexAuthCredentials = nil - k.VertexDeploymentsJSON = nil } if k.BedrockKeyConfig != nil { if k.BedrockKeyConfig.AccessKey.GetValue() != "" { @@ -282,16 +261,6 @@ func (k *TableKey) BeforeSave(tx *gorm.DB) error { } else { k.BedrockRoleSessionName = nil } - if k.BedrockKeyConfig.Deployments != nil { - data, err := sonic.Marshal(k.BedrockKeyConfig.Deployments) - if err != nil { - return err - } - s := string(data) - k.BedrockDeploymentsJSON = &s - } else { - k.BedrockDeploymentsJSON = nil - } 
if k.BedrockKeyConfig.BatchS3Config != nil { data, err := sonic.Marshal(k.BedrockKeyConfig.BatchS3Config) if err != nil { @@ -311,19 +280,21 @@ func (k *TableKey) BeforeSave(tx *gorm.DB) error { k.BedrockRoleARN = nil k.BedrockExternalID = nil k.BedrockRoleSessionName = nil - k.BedrockDeploymentsJSON = nil k.BedrockBatchS3ConfigJSON = nil } - if k.ReplicateKeyConfig != nil && k.ReplicateKeyConfig.Deployments != nil { - data, err := sonic.Marshal(k.ReplicateKeyConfig.Deployments) + if k.Aliases != nil { + if err := k.Aliases.Validate(); err != nil { + return err + } + data, err := sonic.Marshal(k.Aliases) if err != nil { return err } s := string(data) - k.ReplicateDeploymentsJSON = &s + k.AliasesJSON = &s } else { - k.ReplicateDeploymentsJSON = nil + k.AliasesJSON = nil } if k.VLLMKeyConfig != nil { @@ -344,6 +315,13 @@ func (k *TableKey) BeforeSave(tx *gorm.DB) error { k.VLLMModelName = nil } + if k.ReplicateKeyConfig != nil { + v := k.ReplicateKeyConfig.UseDeploymentsEndpoint + k.ReplicateUseDeploymentsEndpoint = &v + } else { + k.ReplicateUseDeploymentsEndpoint = nil + } + if k.OllamaKeyConfig != nil && k.OllamaKeyConfig.URL.GetValue() != "" { u := k.OllamaKeyConfig.URL k.OllamaUrl = &u @@ -417,12 +395,13 @@ func (k *TableKey) BeforeSave(tx *gorm.DB) error { if err := encryptEnvVarPtr(&k.BedrockRoleSessionName); err != nil { return fmt.Errorf("failed to encrypt bedrock role session name: %w", err) } - if err := encryptString(k.BedrockDeploymentsJSON); err != nil { - return fmt.Errorf("failed to encrypt bedrock deployments: %w", err) - } if err := encryptString(k.BedrockBatchS3ConfigJSON); err != nil { return fmt.Errorf("failed to encrypt bedrock batch s3 config: %w", err) } + // Aliases + if err := encryptString(k.AliasesJSON); err != nil { + return fmt.Errorf("failed to encrypt aliases: %w", err) + } // VLLM if err := encryptEnvVarPtr(&k.VLLMUrl); err != nil { return fmt.Errorf("failed to encrypt vllm url: %w", err) @@ -503,12 +482,13 @@ func (k *TableKey) 
AfterFind(tx *gorm.DB) error { if err := decryptEnvVarPtr(&k.BedrockRoleSessionName); err != nil { return fmt.Errorf("failed to decrypt bedrock role session name: %w", err) } - if err := decryptString(k.BedrockDeploymentsJSON); err != nil { - return fmt.Errorf("failed to decrypt bedrock deployments: %w", err) - } if err := decryptString(k.BedrockBatchS3ConfigJSON); err != nil { return fmt.Errorf("failed to decrypt bedrock batch s3 config: %w", err) } + // Aliases + if err := decryptString(k.AliasesJSON); err != nil { + return fmt.Errorf("failed to decrypt aliases: %w", err) + } // VLLM if err := decryptEnvVarPtr(&k.VLLMUrl); err != nil { return fmt.Errorf("failed to decrypt vllm url: %w", err) @@ -562,20 +542,10 @@ func (k *TableKey) AfterFind(tx *gorm.DB) error { azureConfig.Endpoint = *k.AzureEndpoint } - if k.AzureDeploymentsJSON != nil { - var deployments map[string]string - if err := json.Unmarshal([]byte(*k.AzureDeploymentsJSON), &deployments); err != nil { - return err - } - azureConfig.Deployments = deployments - } else { - azureConfig.Deployments = nil - } - k.AzureKeyConfig = azureConfig } // Reconstruct Vertex config if fields are present - if k.VertexProjectID != nil || k.VertexProjectNumber != nil || k.VertexRegion != nil || k.VertexAuthCredentials != nil || (k.VertexDeploymentsJSON != nil && *k.VertexDeploymentsJSON != "") { + if k.VertexProjectID != nil || k.VertexProjectNumber != nil || k.VertexRegion != nil || k.VertexAuthCredentials != nil { config := &schemas.VertexKeyConfig{} if k.VertexProjectID != nil { @@ -592,20 +562,10 @@ func (k *TableKey) AfterFind(tx *gorm.DB) error { if k.VertexAuthCredentials != nil { config.AuthCredentials = *k.VertexAuthCredentials } - if k.VertexDeploymentsJSON != nil { - var deployments map[string]string - if err := json.Unmarshal([]byte(*k.VertexDeploymentsJSON), &deployments); err != nil { - return err - } - config.Deployments = deployments - } else { - config.Deployments = nil - } - k.VertexKeyConfig = config } 
// Reconstruct Bedrock config if fields are present - if k.BedrockAccessKey != nil || k.BedrockSecretKey != nil || k.BedrockSessionToken != nil || k.BedrockRegion != nil || k.BedrockARN != nil || k.BedrockRoleARN != nil || k.BedrockExternalID != nil || k.BedrockRoleSessionName != nil || (k.BedrockDeploymentsJSON != nil && *k.BedrockDeploymentsJSON != "") || (k.BedrockBatchS3ConfigJSON != nil && *k.BedrockBatchS3ConfigJSON != "") { + if k.BedrockAccessKey != nil || k.BedrockSecretKey != nil || k.BedrockSessionToken != nil || k.BedrockRegion != nil || k.BedrockARN != nil || k.BedrockRoleARN != nil || k.BedrockExternalID != nil || k.BedrockRoleSessionName != nil || (k.BedrockBatchS3ConfigJSON != nil && *k.BedrockBatchS3ConfigJSON != "") { bedrockConfig := &schemas.BedrockKeyConfig{} if k.BedrockAccessKey != nil { @@ -623,16 +583,6 @@ func (k *TableKey) AfterFind(tx *gorm.DB) error { bedrockConfig.SecretKey = *k.BedrockSecretKey } - if k.BedrockDeploymentsJSON != nil { - var deployments map[string]string - if err := json.Unmarshal([]byte(*k.BedrockDeploymentsJSON), &deployments); err != nil { - return err - } - bedrockConfig.Deployments = deployments - } else { - bedrockConfig.Deployments = nil - } - if k.BedrockBatchS3ConfigJSON != nil && *k.BedrockBatchS3ConfigJSON != "" { var batchS3Config schemas.BatchS3Config if err := json.Unmarshal([]byte(*k.BedrockBatchS3ConfigJSON), &batchS3Config); err != nil { @@ -643,15 +593,15 @@ func (k *TableKey) AfterFind(tx *gorm.DB) error { k.BedrockKeyConfig = bedrockConfig } - // Reconstruct Replicate config if fields are present - if k.ReplicateDeploymentsJSON != nil && *k.ReplicateDeploymentsJSON != "" { - replicateConfig := &schemas.ReplicateKeyConfig{} - var deployments map[string]string - if err := json.Unmarshal([]byte(*k.ReplicateDeploymentsJSON), &deployments); err != nil { + // Reconstruct Aliases + if k.AliasesJSON != nil && *k.AliasesJSON != "" { + var aliases schemas.KeyAliases + if err := 
sonic.Unmarshal([]byte(*k.AliasesJSON), &aliases); err != nil { return err } - replicateConfig.Deployments = deployments - k.ReplicateKeyConfig = replicateConfig + k.Aliases = aliases + } else { + k.Aliases = nil } // Reconstruct VLLM config if fields are present if k.VLLMUrl != nil || (k.VLLMModelName != nil && *k.VLLMModelName != "") { @@ -666,6 +616,14 @@ func (k *TableKey) AfterFind(tx *gorm.DB) error { } else { k.VLLMKeyConfig = nil } + // Reconstruct Replicate config if fields are present + if k.ReplicateUseDeploymentsEndpoint != nil { + k.ReplicateKeyConfig = &schemas.ReplicateKeyConfig{ + UseDeploymentsEndpoint: *k.ReplicateUseDeploymentsEndpoint, + } + } else { + k.ReplicateKeyConfig = nil + } // Reconstruct Ollama config if fields are present if k.OllamaUrl != nil { k.OllamaKeyConfig = &schemas.OllamaKeyConfig{ diff --git a/framework/configstore/tables/virtualkey.go b/framework/configstore/tables/virtualkey.go index af6b9f2b7e..d7b1c92a6a 100644 --- a/framework/configstore/tables/virtualkey.go +++ b/framework/configstore/tables/virtualkey.go @@ -122,7 +122,6 @@ func (pc *TableVirtualKeyProviderConfig) AfterFind(tx *gorm.DB) error { key.AzureClientSecret = nil key.AzureTenantID = nil key.AzureScopesJSON = nil - key.AzureDeploymentsJSON = nil key.AzureKeyConfig = nil // Clear all Vertex-related sensitive fields @@ -141,13 +140,8 @@ func (pc *TableVirtualKeyProviderConfig) AfterFind(tx *gorm.DB) error { key.BedrockRoleARN = nil key.BedrockExternalID = nil key.BedrockRoleSessionName = nil - key.BedrockDeploymentsJSON = nil key.BedrockKeyConfig = nil - // Clear all Replicate-related sensitive fields - key.ReplicateDeploymentsJSON = nil - key.ReplicateKeyConfig = nil - pc.Keys[i] = *key } } diff --git a/framework/logstore/migrations.go b/framework/logstore/migrations.go index 399f172f48..65ea54e9ac 100644 --- a/framework/logstore/migrations.go +++ b/framework/logstore/migrations.go @@ -212,6 +212,9 @@ func triggerMigrations(ctx context.Context, db *gorm.DB) 
error { if err := migrationAddPluginLogsColumn(ctx, db); err != nil { return err } + if err := migrationAddAliasColumn(ctx, db); err != nil { + return err + } return nil } @@ -2030,6 +2033,11 @@ var performanceIndexes = []performanceIndexDef{ name: "idx_logs_ts_provider_status", sql: "CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_logs_ts_provider_status ON logs(timestamp, provider, status)", }, + { + table: "logs", + name: "idx_logs_alias", + sql: "CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_logs_alias ON logs(alias)", + }, } // ensurePerformanceIndexes checks whether each performance GIN index exists and is @@ -2176,3 +2184,41 @@ func migrationAddPluginLogsColumn(ctx context.Context, db *gorm.DB) error { } return nil } + +// migrationAddAliasColumn adds the alias column to the logs table. +// The alias field stores the original model name the caller used when routing resolved it to a different model via alias mapping. +// Index creation is deferred to ensurePerformanceIndexes (called post-startup in a background goroutine) +// because CREATE INDEX CONCURRENTLY cannot run inside a transaction and a regular CREATE INDEX +// takes a SHARE lock that blocks writes on large tables during rolling deploys. 
+func migrationAddAliasColumn(ctx context.Context, db *gorm.DB) error { + opts := *migrator.DefaultOptions + opts.UseTransaction = true + m := migrator.New(db, &opts, []*migrator.Migration{{ + ID: "logs_add_alias_column", + Migrate: func(tx *gorm.DB) error { + tx = tx.WithContext(ctx) + mig := tx.Migrator() + if !mig.HasColumn(&Log{}, "alias") { + if err := mig.AddColumn(&Log{}, "alias"); err != nil { + return err + } + } + return nil + }, + Rollback: func(tx *gorm.DB) error { + tx = tx.WithContext(ctx) + mig := tx.Migrator() + if mig.HasColumn(&Log{}, "alias") { + if err := mig.DropColumn(&Log{}, "alias"); err != nil { + return err + } + } + return nil + }, + }}) + err := m.Migrate() + if err != nil { + return fmt.Errorf("error while adding alias column: %s", err.Error()) + } + return nil +} diff --git a/framework/logstore/rdb.go b/framework/logstore/rdb.go index 3e53adfdd9..39bea8720c 100644 --- a/framework/logstore/rdb.go +++ b/framework/logstore/rdb.go @@ -79,6 +79,9 @@ func (s *RDBLogStore) applyFilters(baseQuery *gorm.DB, filters SearchFilters) *g if len(filters.Models) > 0 { baseQuery = baseQuery.Where("model IN ?", filters.Models) } + if len(filters.Aliases) > 0 { + baseQuery = baseQuery.Where("alias IN ?", filters.Aliases) + } if len(filters.Status) > 0 { baseQuery = baseQuery.Where("status IN ?", filters.Status) } @@ -446,7 +449,7 @@ func (s *RDBLogStore) SearchLogs(ctx context.Context, filters SearchFilters, pag // last element from input_history and responses_input_history arrays. 
func (s *RDBLogStore) listSelectColumns() string { baseCols := strings.Join([]string{ - "id", "parent_request_id", "timestamp", "object_type", "provider", "model", + "id", "parent_request_id", "timestamp", "object_type", "provider", "model", "alias", "number_of_retries", "fallback_index", "selected_key_id", "selected_key_name", "virtual_key_id", "virtual_key_name", @@ -2018,6 +2021,20 @@ func (s *RDBLogStore) GetDistinctModels(ctx context.Context) ([]string, error) { return models, nil } +// GetDistinctAliases returns all unique non-empty alias values using SELECT DISTINCT. +// Scoped to recent data to avoid full table scans. +func (s *RDBLogStore) GetDistinctAliases(ctx context.Context) ([]string, error) { + cutoff := time.Now().UTC().AddDate(0, 0, -defaultFilterDataCutoffDays) + var aliases []string + err := s.db.WithContext(ctx).Model(&Log{}). + Where("alias IS NOT NULL AND alias != '' AND timestamp >= ?", cutoff). + Distinct("alias").Limit(defaultFilterDataLimit).Pluck("alias", &aliases).Error + if err != nil { + return nil, fmt.Errorf("failed to get distinct aliases: %w", err) + } + return aliases, nil +} + // allowedKeyPairColumns is a whitelist of column names that can be used in GetDistinctKeyPairs // to prevent SQL injection from interpolated column names. 
var allowedKeyPairColumns = map[string]struct{}{ diff --git a/framework/logstore/store.go b/framework/logstore/store.go index 5942d93db2..27a2d12f7d 100644 --- a/framework/logstore/store.go +++ b/framework/logstore/store.go @@ -50,6 +50,7 @@ type LogStore interface { // Distinct value methods for filter data GetDistinctModels(ctx context.Context) ([]string, error) + GetDistinctAliases(ctx context.Context) ([]string, error) GetDistinctKeyPairs(ctx context.Context, idCol, nameCol string) ([]KeyPairResult, error) GetDistinctRoutingEngines(ctx context.Context) ([]string, error) GetDistinctMetadataKeys(ctx context.Context) (map[string][]string, error) diff --git a/framework/logstore/tables.go b/framework/logstore/tables.go index 53857ffb6d..edf0f879d0 100644 --- a/framework/logstore/tables.go +++ b/framework/logstore/tables.go @@ -31,6 +31,7 @@ const ( type SearchFilters struct { Providers []string `json:"providers,omitempty"` Models []string `json:"models,omitempty"` + Aliases []string `json:"aliases,omitempty"` Status []string `json:"status,omitempty"` Objects []string `json:"objects,omitempty"` // For filtering by request type (chat.completion, text.completion, embedding) SelectedKeyIDs []string `json:"selected_key_ids,omitempty"` @@ -84,6 +85,7 @@ type Log struct { Object string `gorm:"type:varchar(255);index;not null;column:object_type" json:"object"` // text.completion, chat.completion, or embedding Provider string `gorm:"type:varchar(255);index;index:idx_logs_ts_provider_status,priority:2;not null" json:"provider"` Model string `gorm:"type:varchar(255);index;not null" json:"model"` + Alias *string `gorm:"type:varchar(255);index" json:"alias,omitempty"` // Set when model was resolved via alias mapping; the original name the caller used NumberOfRetries int `gorm:"default:0" json:"number_of_retries"` FallbackIndex int `gorm:"default:0" json:"fallback_index"` SelectedKeyID string `gorm:"type:varchar(255);index:idx_logs_selected_key_id" json:"selected_key_id"` diff 
--git a/framework/modelcatalog/pricing.go b/framework/modelcatalog/pricing.go index b9d7525f2f..535aed226d 100644 --- a/framework/modelcatalog/pricing.go +++ b/framework/modelcatalog/pricing.go @@ -87,8 +87,8 @@ func (mc *ModelCatalog) calculateBaseCost(result *schemas.BifrostResponse, scope } provider := string(extraFields.Provider) - model := extraFields.ModelRequested - deployment := extraFields.ModelDeployment + originalModelRequested := extraFields.OriginalModelRequested + resolvedModelUsed := extraFields.ResolvedModelUsed requestType := extraFields.RequestType // Extract usage data from the response @@ -108,7 +108,7 @@ func (mc *ModelCatalog) calculateBaseCost(result *schemas.BifrostResponse, scope requestType = normalizeStreamRequestType(requestType) // Resolve pricing entry with deployment fallback - pricing := mc.resolvePricing(provider, model, deployment, requestType, scopes) + pricing := mc.resolvePricing(provider, originalModelRequested, resolvedModelUsed, requestType, scopes) if pricing == nil { return 0 } @@ -759,37 +759,38 @@ func populateOutputImageCount(imageUsage *schemas.ImageUsage, dataLen int) { // --------------------------------------------------------------------------- // resolvePricing resolves the pricing entry for a model, trying deployment as fallback. 
-func (mc *ModelCatalog) resolvePricing(provider, model, deployment string, requestType schemas.RequestType, scopes PricingLookupScopes) *configstoreTables.TableModelPricing { - mc.logger.Debug("looking up pricing for model %s and provider %s of request type %s", model, provider, normalizeRequestType(requestType)) +func (mc *ModelCatalog) resolvePricing(provider, originalModelRequested, resolvedModelUsed string, requestType schemas.RequestType, scopes PricingLookupScopes) *configstoreTables.TableModelPricing { + if resolvedModelUsed == "" { + resolvedModelUsed = originalModelRequested + } + mc.logger.Debug("looking up pricing for resolved model %s and provider %s of request type %s", resolvedModelUsed, provider, normalizeRequestType(requestType)) if scopes.Provider == "" { scopes.Provider = provider } - base, exists := mc.getBasePricing(model, provider, requestType) + base, exists := mc.getBasePricing(resolvedModelUsed, provider, requestType) if exists && base != nil { - result, _ := mc.applyPricingOverrides(model, requestType, *base, scopes) + result, _ := mc.applyPricingOverrides(resolvedModelUsed, requestType, *base, scopes) return &result } - if deployment != "" { - mc.logger.Debug("pricing not found for model %s, trying deployment %s", model, deployment) - base, exists = mc.getBasePricing(deployment, provider, requestType) - if exists && base != nil { - // Apply overrides using the requested model name, not the deployment name - result, _ := mc.applyPricingOverrides(model, requestType, *base, scopes) - return &result - } + mc.logger.Debug("pricing not found for resolved model %s, trying alias %s", resolvedModelUsed, originalModelRequested) + base, exists = mc.getBasePricing(originalModelRequested, provider, requestType) + if exists && base != nil { + // Apply overrides using the resolved model name, not the alias + result, _ := mc.applyPricingOverrides(resolvedModelUsed, requestType, *base, scopes) + return &result } // No base catalog entry found; still try 
overrides in case the user defined // override-only pricing for a model not in the built-in catalog. - mc.logger.Debug("pricing not found for model %s and provider %s, trying override-only pricing", model, provider) - result, applied := mc.applyPricingOverrides(model, requestType, configstoreTables.TableModelPricing{}, scopes) + mc.logger.Debug("pricing not found for resolved model %s and provider %s, trying override-only pricing", resolvedModelUsed, provider) + result, applied := mc.applyPricingOverrides(resolvedModelUsed, requestType, configstoreTables.TableModelPricing{}, scopes) if applied { return &result } - mc.logger.Debug("no pricing found for model %s and provider %s, skipping cost calculation", model, provider) + mc.logger.Debug("no pricing found for resolved model %s and provider %s, skipping cost calculation", resolvedModelUsed, provider) return nil } diff --git a/framework/modelcatalog/pricing_test.go b/framework/modelcatalog/pricing_test.go index d69301b4a7..d273f32cfb 100644 --- a/framework/modelcatalog/pricing_test.go +++ b/framework/modelcatalog/pricing_test.go @@ -41,9 +41,9 @@ func makeChatResponse(provider schemas.ModelProvider, model string, usage *schem ChatResponse: &schemas.BifrostChatResponse{ Usage: usage, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ChatCompletionRequest, - Provider: provider, - ModelRequested: model, + RequestType: schemas.ChatCompletionRequest, + Provider: provider, + OriginalModelRequested: model, }, }, } @@ -55,9 +55,9 @@ func makeEmbeddingResponse(provider schemas.ModelProvider, model string, usage * EmbeddingResponse: &schemas.BifrostEmbeddingResponse{ Usage: usage, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.EmbeddingRequest, - Provider: provider, - ModelRequested: model, + RequestType: schemas.EmbeddingRequest, + Provider: provider, + OriginalModelRequested: model, }, }, } @@ -69,9 +69,9 @@ func makeRerankResponse(provider schemas.ModelProvider, model string, 
usage *sch RerankResponse: &schemas.BifrostRerankResponse{ Usage: usage, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.RerankRequest, - Provider: provider, - ModelRequested: model, + RequestType: schemas.RerankRequest, + Provider: provider, + OriginalModelRequested: model, }, }, } @@ -83,9 +83,9 @@ func makeImageResponse(provider schemas.ModelProvider, model string, usage *sche ImageGenerationResponse: &schemas.BifrostImageGenerationResponse{ Usage: usage, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ImageGenerationRequest, - Provider: provider, - ModelRequested: model, + RequestType: schemas.ImageGenerationRequest, + Provider: provider, + OriginalModelRequested: model, }, }, } @@ -847,9 +847,9 @@ func TestCalculateCost_SemanticCacheDirectHit(t *testing.T) { ChatResponse: &schemas.BifrostChatResponse{ Usage: &schemas.BifrostLLMUsage{PromptTokens: 100, CompletionTokens: 50, TotalTokens: 150}, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ChatCompletionRequest, - Provider: schemas.OpenAI, - ModelRequested: "gpt-4o", + RequestType: schemas.ChatCompletionRequest, + Provider: schemas.OpenAI, + OriginalModelRequested: "gpt-4o", CacheDebug: &schemas.BifrostCacheDebug{ CacheHit: true, HitType: &hitType, @@ -883,9 +883,9 @@ func TestCalculateCost_SemanticCacheSemanticHit(t *testing.T) { ChatResponse: &schemas.BifrostChatResponse{ Usage: &schemas.BifrostLLMUsage{PromptTokens: 100, CompletionTokens: 50, TotalTokens: 150}, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ChatCompletionRequest, - Provider: schemas.OpenAI, - ModelRequested: "gpt-4o", + RequestType: schemas.ChatCompletionRequest, + Provider: schemas.OpenAI, + OriginalModelRequested: "gpt-4o", CacheDebug: &schemas.BifrostCacheDebug{ CacheHit: true, HitType: &hitType, @@ -922,9 +922,9 @@ func TestCalculateCost_SemanticCacheMiss(t *testing.T) { ChatResponse: &schemas.BifrostChatResponse{ Usage: 
&schemas.BifrostLLMUsage{PromptTokens: 1000, CompletionTokens: 500, TotalTokens: 1500}, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ChatCompletionRequest, - Provider: schemas.OpenAI, - ModelRequested: "gpt-4o", + RequestType: schemas.ChatCompletionRequest, + Provider: schemas.OpenAI, + OriginalModelRequested: "gpt-4o", CacheDebug: &schemas.BifrostCacheDebug{ CacheHit: false, ProviderUsed: &embProvider, @@ -1117,9 +1117,9 @@ func TestCalculateCost_StreamRequestTypeNormalized(t *testing.T) { ChatResponse: &schemas.BifrostChatResponse{ Usage: &schemas.BifrostLLMUsage{PromptTokens: 1000, CompletionTokens: 500, TotalTokens: 1500}, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ChatCompletionStreamRequest, - Provider: schemas.OpenAI, - ModelRequested: "gpt-4o", + RequestType: schemas.ChatCompletionStreamRequest, + Provider: schemas.OpenAI, + OriginalModelRequested: "gpt-4o", }, }, } @@ -1474,9 +1474,9 @@ func TestCalculateCost_ImageGeneration_OutputCountFromData(t *testing.T) { {URL: "https://example.com/img3.png", Index: 2}, }, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ImageGenerationRequest, - Provider: "openai", - ModelRequested: "dall-e-3", + RequestType: schemas.ImageGenerationRequest, + Provider: "openai", + OriginalModelRequested: "dall-e-3", }, }, } diff --git a/framework/modelcatalog/sync.go b/framework/modelcatalog/sync.go index 29c88542a6..69bae551d4 100644 --- a/framework/modelcatalog/sync.go +++ b/framework/modelcatalog/sync.go @@ -414,5 +414,3 @@ func (mc *ModelCatalog) loadModelParametersFromURL(ctx context.Context) (map[str mc.logger.Debug("model-parameters-sync: successfully downloaded and parsed %d model parameters records", len(paramsData)) return paramsData, nil } - - diff --git a/framework/streaming/accumulator_test.go b/framework/streaming/accumulator_test.go index f7df86b565..18eb43f71b 100644 --- a/framework/streaming/accumulator_test.go +++ 
b/framework/streaming/accumulator_test.go @@ -64,10 +64,10 @@ func TestChatStreamingFinalChunkNoDeadlock(t *testing.T) { TotalTokens: 150, }, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ChatCompletionStreamRequest, - Provider: schemas.Anthropic, - ModelRequested: "claude-opus-4", - ChunkIndex: 9, + RequestType: schemas.ChatCompletionStreamRequest, + Provider: schemas.Anthropic, + OriginalModelRequested: "claude-opus-4", + ChunkIndex: 9, }, }, } @@ -140,10 +140,10 @@ func TestResponsesStreamingFinalChunkNoDeadlock(t *testing.T) { OutputTokens: 50, }, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ResponsesStreamRequest, - Provider: schemas.Anthropic, - ModelRequested: "claude-opus-4", - ChunkIndex: 4, + RequestType: schemas.ResponsesStreamRequest, + Provider: schemas.Anthropic, + OriginalModelRequested: "claude-opus-4", + ChunkIndex: 4, }, }, } @@ -488,10 +488,10 @@ func TestAudioStreamingFinalChunkNoDeadlock(t *testing.T) { TotalTokens: 150, }, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.SpeechStreamRequest, - Provider: schemas.OpenAI, - ModelRequested: "tts-1", - ChunkIndex: 7, + RequestType: schemas.SpeechStreamRequest, + Provider: schemas.OpenAI, + OriginalModelRequested: "tts-1", + ChunkIndex: 7, }, }, } @@ -559,10 +559,10 @@ func TestTranscriptionStreamingFinalChunkNoDeadlock(t *testing.T) { TranscriptionResponse: &schemas.BifrostTranscriptionResponse{ Text: "Complete transcription", ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.TranscriptionStreamRequest, - Provider: schemas.OpenAI, - ModelRequested: "whisper-1", - ChunkIndex: 5, + RequestType: schemas.TranscriptionStreamRequest, + Provider: schemas.OpenAI, + OriginalModelRequested: "whisper-1", + ChunkIndex: 5, }, }, } diff --git a/framework/streaming/audio.go b/framework/streaming/audio.go index 9cc2aa6924..0390ea5aaf 100644 --- a/framework/streaming/audio.go +++ b/framework/streaming/audio.go @@ 
-121,7 +121,7 @@ func (a *Accumulator) processAudioStreamingResponse(ctx *schemas.BifrostContext, // Log error but don't fail the request return nil, fmt.Errorf("accumulator-id not found in context or is empty") } - _, provider, model := bifrost.GetResponseFields(result, bifrostErr) + _, provider, requestedModel, resolvedModel := bifrost.GetResponseFields(result, bifrostErr) isFinalChunk := bifrost.IsFinalChunk(ctx) // For audio, all the data comes in the final chunk chunk := a.getAudioStreamChunk() @@ -177,21 +177,23 @@ func (a *Accumulator) processAudioStreamingResponse(ctx *schemas.BifrostContext, rawRequest = result.SpeechStreamResponse.ExtraFields.RawRequest } return &ProcessedStreamResponse{ - RequestID: requestID, - StreamType: StreamTypeAudio, - Model: model, - Provider: provider, - Data: data, - RawRequest: &rawRequest, + RequestID: requestID, + StreamType: StreamTypeAudio, + RequestedModel: requestedModel, + ResolvedModel: resolvedModel, + Provider: provider, + Data: data, + RawRequest: &rawRequest, }, nil } // Non-final chunk: skip expensive rebuild since no consumer uses intermediate data. // Both logging and maxim plugins return early when !isFinalChunk. 
return &ProcessedStreamResponse{ - RequestID: requestID, - StreamType: StreamTypeAudio, - Model: model, - Provider: provider, - Data: nil, + RequestID: requestID, + StreamType: StreamTypeAudio, + RequestedModel: requestedModel, + ResolvedModel: resolvedModel, + Provider: provider, + Data: nil, }, nil } diff --git a/framework/streaming/chat.go b/framework/streaming/chat.go index 1d87106913..6602e9a21a 100644 --- a/framework/streaming/chat.go +++ b/framework/streaming/chat.go @@ -464,7 +464,7 @@ func (a *Accumulator) processChatStreamingResponse(ctx *schemas.BifrostContext, // Log error but don't fail the request return nil, fmt.Errorf("accumulator-id not found in context or is empty") } - requestType, provider, model := bifrost.GetResponseFields(result, bifrostErr) + requestType, provider, model, resolvedModel := bifrost.GetResponseFields(result, bifrostErr) streamType := StreamTypeChat if requestType == schemas.TextCompletionStreamRequest { @@ -496,6 +496,9 @@ func (a *Accumulator) processChatStreamingResponse(ctx *schemas.BifrostContext, chunk.TokenUsage = result.TextCompletionResponse.Usage } chunk.ChunkIndex = result.TextCompletionResponse.ExtraFields.ChunkIndex + if result.TextCompletionResponse.ExtraFields.RawResponse != nil { + chunk.RawResponse = bifrost.Ptr(fmt.Sprintf("%v", result.TextCompletionResponse.ExtraFields.RawResponse)) + } if isFinalChunk { if a.pricingManager != nil { cost := a.pricingManager.CalculateCost(result, modelcatalog.PricingLookupScopesFromContext(ctx, string(result.GetExtraFields().Provider))) @@ -561,7 +564,8 @@ func (a *Accumulator) processChatStreamingResponse(ctx *schemas.BifrostContext, RequestID: requestID, StreamType: streamType, Provider: provider, - Model: model, + RequestedModel: model, + ResolvedModel: resolvedModel, Data: data, RawRequest: &rawRequest, }, nil @@ -569,10 +573,11 @@ func (a *Accumulator) processChatStreamingResponse(ctx *schemas.BifrostContext, // Non-final chunk: skip expensive rebuild since no consumer 
uses intermediate data. // Both logging and maxim plugins return early when !isFinalChunk. return &ProcessedStreamResponse{ - RequestID: requestID, - StreamType: streamType, - Provider: provider, - Model: model, - Data: nil, + RequestID: requestID, + StreamType: streamType, + Provider: provider, + RequestedModel: model, + ResolvedModel: resolvedModel, + Data: nil, }, nil } diff --git a/framework/streaming/images.go b/framework/streaming/images.go index 446d1ca3b3..367b52c037 100644 --- a/framework/streaming/images.go +++ b/framework/streaming/images.go @@ -20,7 +20,7 @@ func (a *Accumulator) buildCompleteImageFromImageStreamChunks(chunks []*ImageStr finalResponse := &schemas.BifrostImageGenerationResponse{ ID: chunks[i].Delta.ID, Created: chunks[i].Delta.CreatedAt, - Model: chunks[i].Delta.ExtraFields.ModelRequested, + Model: chunks[i].Delta.ExtraFields.OriginalModelRequested, Data: []schemas.ImageData{ { B64JSON: chunks[i].Delta.B64JSON, @@ -53,8 +53,8 @@ func (a *Accumulator) buildCompleteImageFromImageStreamChunks(chunks []*ImageStr } // Extract metadata - if model == "" && chunk.Delta.ExtraFields.ModelRequested != "" { - model = chunk.Delta.ExtraFields.ModelRequested + if model == "" && chunk.Delta.ExtraFields.OriginalModelRequested != "" { + model = chunk.Delta.ExtraFields.OriginalModelRequested } // Store revised prompt if present (usually in first chunk) @@ -216,7 +216,7 @@ func (a *Accumulator) processImageStreamingResponse(ctx *schemas.BifrostContext, // Log error but don't fail the request return nil, fmt.Errorf("accumulator-id not found in context or is empty") } - _, provider, model := bifrost.GetResponseFields(result, bifrostErr) + _, provider, requestedModel, resolvedModel := bifrost.GetResponseFields(result, bifrostErr) isFinalChunk := bifrost.IsFinalChunk(ctx) chunk := a.getImageStreamChunk() @@ -310,12 +310,13 @@ func (a *Accumulator) processImageStreamingResponse(ctx *schemas.BifrostContext, rawRequest = 
result.ImageGenerationStreamResponse.ExtraFields.RawRequest } return &ProcessedStreamResponse{ - RequestID: requestID, - StreamType: StreamTypeImage, - Provider: provider, - Model: model, - Data: data, - RawRequest: &rawRequest, + RequestID: requestID, + StreamType: StreamTypeImage, + Provider: provider, + RequestedModel: requestedModel, + ResolvedModel: resolvedModel, + Data: data, + RawRequest: &rawRequest, }, nil } @@ -325,10 +326,11 @@ func (a *Accumulator) processImageStreamingResponse(ctx *schemas.BifrostContext, // Non-final chunk: skip expensive rebuild since no consumer uses intermediate data. // Both logging and maxim plugins return early when !isFinalChunk. return &ProcessedStreamResponse{ - RequestID: requestID, - StreamType: StreamTypeImage, - Provider: provider, - Model: model, - Data: nil, + RequestID: requestID, + StreamType: StreamTypeImage, + Provider: provider, + RequestedModel: requestedModel, + ResolvedModel: resolvedModel, + Data: nil, }, nil } diff --git a/framework/streaming/responses.go b/framework/streaming/responses.go index d33d95e35e..56c461cbcd 100644 --- a/framework/streaming/responses.go +++ b/framework/streaming/responses.go @@ -890,7 +890,7 @@ func (a *Accumulator) processResponsesStreamingResponse(ctx *schemas.BifrostCont return nil, fmt.Errorf("accumulator-id not found in context or is empty") } - _, provider, model := bifrost.GetResponseFields(result, bifrostErr) + _, provider, requestedModel, resolvedModel := bifrost.GetResponseFields(result, bifrostErr) isFinalChunk := bifrost.IsFinalChunk(ctx) chunk := a.getResponsesStreamChunk() @@ -949,20 +949,22 @@ func (a *Accumulator) processResponsesStreamingResponse(ctx *schemas.BifrostCont } return &ProcessedStreamResponse{ - RequestID: requestID, - StreamType: StreamTypeResponses, - Provider: provider, - Model: model, - Data: data, - RawRequest: &rawRequest, + RequestID: requestID, + StreamType: StreamTypeResponses, + Provider: provider, + RequestedModel: requestedModel, + 
ResolvedModel: resolvedModel, + Data: data, + RawRequest: &rawRequest, }, nil } return &ProcessedStreamResponse{ - RequestID: requestID, - StreamType: StreamTypeResponses, - Provider: provider, - Model: model, - Data: nil, + RequestID: requestID, + StreamType: StreamTypeResponses, + Provider: provider, + RequestedModel: requestedModel, + ResolvedModel: resolvedModel, + Data: nil, }, nil } diff --git a/framework/streaming/transcription.go b/framework/streaming/transcription.go index 56fb3e477c..3367e25ad6 100644 --- a/framework/streaming/transcription.go +++ b/framework/streaming/transcription.go @@ -131,7 +131,7 @@ func (a *Accumulator) processTranscriptionStreamingResponse(ctx *schemas.Bifrost // Log error but don't fail the request return nil, fmt.Errorf("accumulator-id not found in context or is empty") } - _, provider, model := bifrost.GetResponseFields(result, bifrostErr) + _, provider, requestedModel, resolvedModel := bifrost.GetResponseFields(result, bifrostErr) isFinalChunk := bifrost.IsFinalChunk(ctx) // For audio, all the data comes in the final chunk chunk := a.getTranscriptionStreamChunk() @@ -194,21 +194,23 @@ func (a *Accumulator) processTranscriptionStreamingResponse(ctx *schemas.Bifrost rawRequest = result.TranscriptionStreamResponse.ExtraFields.RawRequest } return &ProcessedStreamResponse{ - RequestID: requestID, - StreamType: StreamTypeTranscription, - Provider: provider, - Model: model, - Data: data, - RawRequest: &rawRequest, + RequestID: requestID, + StreamType: StreamTypeTranscription, + Provider: provider, + RequestedModel: requestedModel, + ResolvedModel: resolvedModel, + Data: data, + RawRequest: &rawRequest, }, nil } // Non-final chunk: skip expensive rebuild since no consumer uses intermediate data. // Both logging and maxim plugins return early when !isFinalChunk. 
return &ProcessedStreamResponse{ - RequestID: requestID, - StreamType: StreamTypeTranscription, - Provider: provider, - Model: model, - Data: nil, + RequestID: requestID, + StreamType: StreamTypeTranscription, + Provider: provider, + RequestedModel: requestedModel, + ResolvedModel: resolvedModel, + Data: nil, }, nil } diff --git a/framework/streaming/types.go b/framework/streaming/types.go index 9d7cf0183f..eb9d10e3ff 100644 --- a/framework/streaming/types.go +++ b/framework/streaming/types.go @@ -228,12 +228,13 @@ func (sa *StreamAccumulator) getLastAudioChunkLocked() *AudioStreamChunk { // ProcessedStreamResponse represents a processed streaming response type ProcessedStreamResponse struct { - RequestID string - StreamType StreamType - Provider schemas.ModelProvider - Model string - Data *AccumulatedData - RawRequest *interface{} + RequestID string + StreamType StreamType + Provider schemas.ModelProvider + RequestedModel string // original model requested by the caller + ResolvedModel string // actual model used by the provider (equals RequestedModel when no alias mapping exists) + Data *AccumulatedData + RawRequest *interface{} } // ToBifrostResponse converts a ProcessedStreamResponse to a BifrostResponse @@ -253,7 +254,7 @@ func (p *ProcessedStreamResponse) ToBifrostResponse() *schemas.BifrostResponse { textResp := &schemas.BifrostTextCompletionResponse{ ID: p.RequestID, Object: "text_completion", - Model: p.Model, + Model: p.RequestedModel, Choices: []schemas.BifrostResponseChoice{ { Index: 0, @@ -269,10 +270,11 @@ func (p *ProcessedStreamResponse) ToBifrostResponse() *schemas.BifrostResponse { resp.TextCompletionResponse = textResp resp.TextCompletionResponse.ExtraFields = schemas.BifrostResponseExtraFields{ - RequestType: schemas.TextCompletionRequest, - Provider: p.Provider, - ModelRequested: p.Model, - Latency: p.Data.Latency, + RequestType: schemas.TextCompletionRequest, + Provider: p.Provider, + OriginalModelRequested: p.RequestedModel, + 
ResolvedModelUsed: p.ResolvedModel, + Latency: p.Data.Latency, } if p.RawRequest != nil { resp.TextCompletionResponse.ExtraFields.RawRequest = p.RawRequest @@ -297,7 +299,7 @@ func (p *ProcessedStreamResponse) ToBifrostResponse() *schemas.BifrostResponse { chatResp := &schemas.BifrostChatResponse{ ID: p.RequestID, Object: "chat.completion", - Model: p.Model, + Model: p.RequestedModel, Created: int(p.Data.StartTimestamp.Unix()), Choices: []schemas.BifrostResponseChoice{ { @@ -314,10 +316,11 @@ func (p *ProcessedStreamResponse) ToBifrostResponse() *schemas.BifrostResponse { resp.ChatResponse = chatResp resp.ChatResponse.ExtraFields = schemas.BifrostResponseExtraFields{ - RequestType: schemas.ChatCompletionRequest, - Provider: p.Provider, - ModelRequested: p.Model, - Latency: p.Data.Latency, + RequestType: schemas.ChatCompletionRequest, + Provider: p.Provider, + OriginalModelRequested: p.RequestedModel, + ResolvedModelUsed: p.ResolvedModel, + Latency: p.Data.Latency, } if p.RawRequest != nil { resp.ChatResponse.ExtraFields.RawRequest = p.RawRequest @@ -338,10 +341,11 @@ func (p *ProcessedStreamResponse) ToBifrostResponse() *schemas.BifrostResponse { responsesResp.Usage = p.Data.TokenUsage.ToResponsesResponseUsage() } responsesResp.ExtraFields = schemas.BifrostResponseExtraFields{ - RequestType: schemas.ResponsesRequest, - Provider: p.Provider, - ModelRequested: p.Model, - Latency: p.Data.Latency, + RequestType: schemas.ResponsesRequest, + Provider: p.Provider, + OriginalModelRequested: p.RequestedModel, + ResolvedModelUsed: p.ResolvedModel, + Latency: p.Data.Latency, } if p.RawRequest != nil { responsesResp.ExtraFields.RawRequest = p.RawRequest @@ -360,10 +364,11 @@ func (p *ProcessedStreamResponse) ToBifrostResponse() *schemas.BifrostResponse { } resp.SpeechResponse = speechResp resp.SpeechResponse.ExtraFields = schemas.BifrostResponseExtraFields{ - RequestType: schemas.SpeechRequest, - Provider: p.Provider, - ModelRequested: p.Model, - Latency: p.Data.Latency, + 
RequestType: schemas.SpeechRequest, + Provider: p.Provider, + OriginalModelRequested: p.RequestedModel, + ResolvedModelUsed: p.ResolvedModel, + Latency: p.Data.Latency, } if p.RawRequest != nil { resp.SpeechResponse.ExtraFields.RawRequest = p.RawRequest @@ -381,14 +386,21 @@ func (p *ProcessedStreamResponse) ToBifrostResponse() *schemas.BifrostResponse { } resp.TranscriptionResponse = transcriptionResp resp.TranscriptionResponse.ExtraFields = schemas.BifrostResponseExtraFields{ - RequestType: schemas.TranscriptionRequest, - Provider: p.Provider, - ModelRequested: p.Model, - Latency: p.Data.Latency, + RequestType: schemas.TranscriptionRequest, + Provider: p.Provider, + OriginalModelRequested: p.RequestedModel, + ResolvedModelUsed: p.ResolvedModel, + Latency: p.Data.Latency, } if p.RawRequest != nil { resp.TranscriptionResponse.ExtraFields.RawRequest = p.RawRequest } + if p.Data.RawResponse != nil { + resp.TranscriptionResponse.ExtraFields.RawResponse = *p.Data.RawResponse + } + if p.Data.CacheDebug != nil { + resp.TranscriptionResponse.ExtraFields.CacheDebug = p.Data.CacheDebug + } case StreamTypeImage: imageResp := p.Data.ImageGenerationOutput if imageResp == nil { @@ -398,8 +410,8 @@ func (p *ProcessedStreamResponse) ToBifrostResponse() *schemas.BifrostResponse { if p.RequestID != "" { imageResp.ID = p.RequestID } - if p.Model != "" { - imageResp.Model = p.Model + if p.RequestedModel != "" { + imageResp.Model = p.RequestedModel } } // Ensure Data is never nil to serialize as [] instead of null @@ -408,10 +420,11 @@ func (p *ProcessedStreamResponse) ToBifrostResponse() *schemas.BifrostResponse { } resp.ImageGenerationResponse = imageResp resp.ImageGenerationResponse.ExtraFields = schemas.BifrostResponseExtraFields{ - RequestType: schemas.ImageGenerationRequest, - Provider: p.Provider, - ModelRequested: p.Model, - Latency: p.Data.Latency, + RequestType: schemas.ImageGenerationRequest, + Provider: p.Provider, + OriginalModelRequested: p.RequestedModel, + 
ResolvedModelUsed: p.ResolvedModel, + Latency: p.Data.Latency, } if p.RawRequest != nil { resp.ImageGenerationResponse.ExtraFields.RawRequest = p.RawRequest diff --git a/framework/tracing/tracer.go b/framework/tracing/tracer.go index b69ecc981e..c44f513428 100644 --- a/framework/tracing/tracer.go +++ b/framework/tracing/tracer.go @@ -306,9 +306,10 @@ func (t *Tracer) ProcessStreamingChunk(traceID string, isFinalChunk bool, result // Convert ProcessedStreamResponse to StreamAccumulatorResult accResult := &schemas.StreamAccumulatorResult{ - RequestID: processedResp.RequestID, - Model: processedResp.Model, - Provider: processedResp.Provider, + RequestID: processedResp.RequestID, + RequestedModel: processedResp.RequestedModel, + ResolvedModel: processedResp.ResolvedModel, + Provider: processedResp.Provider, } if processedResp.Data != nil { diff --git a/plugins/governance/main.go b/plugins/governance/main.go index f8ba8d8b95..dbaf9abc75 100644 --- a/plugins/governance/main.go +++ b/plugins/governance/main.go @@ -28,9 +28,7 @@ import ( const PluginName = "governance" const ( - governanceRejectedContextKey schemas.BifrostContextKey = "bf-governance-rejected" - governanceIsCacheReadContextKey schemas.BifrostContextKey = "bf-governance-is-cache-read" - governanceIsBatchContextKey schemas.BifrostContextKey = "bf-governance-is-batch" + governanceRejectedContextKey schemas.BifrostContextKey = "bf-governance-rejected" VirtualKeyPrefix = "sk-bf-" ) @@ -1237,7 +1235,7 @@ func (p *GovernancePlugin) PostLLMHook(ctx *schemas.BifrostContext, result *sche } // Extract request type, provider, and model - requestType, provider, model := bifrost.GetResponseFields(result, err) + requestType, provider, requestedModel, _ := bifrost.GetResponseFields(result, err) // Extract governance information virtualKey := bifrost.GetStringFromContext(ctx, schemas.BifrostContextKeyVirtualKey) @@ -1245,20 +1243,6 @@ func (p *GovernancePlugin) PostLLMHook(ctx *schemas.BifrostContext, result *sche // 
Extract user ID for enterprise user-level governance userID := bifrost.GetStringFromContext(ctx, schemas.BifrostContextKeyGovernanceUserID) - // Extract cache and batch flags from context - isCacheRead := false - isBatch := false - if val := ctx.Value(governanceIsCacheReadContextKey); val != nil { - if b, ok := val.(bool); ok { - isCacheRead = b - } - } - if val := ctx.Value(governanceIsBatchContextKey); val != nil { - if b, ok := val.(bool); ok { - isBatch = b - } - } - if requestType == schemas.ListModelsRequest && result != nil && result.ListModelsResponse != nil && virtualKey != "" { // filter models which are not supported on this virtual key result.ListModelsResponse.Data = p.filterModelsForVirtualKey(result.ListModelsResponse.Data, virtualKey) @@ -1277,11 +1261,12 @@ func (p *GovernancePlugin) PostLLMHook(ctx *schemas.BifrostContext, result *sche } // If effectiveVK is empty, it will be passed as empty string to postHookWorker // The tracker will handle empty virtual keys gracefully by only updating provider-level and model-level usage - if model != "" { + if requestedModel != "" { p.wg.Add(1) go func() { defer p.wg.Done() - p.postHookWorker(result, provider, model, requestType, effectiveVK, requestID, userID, isCacheRead, isBatch, isFinalChunk, pricingScopes) + // Use the requested model for usage tracking + p.postHookWorker(result, provider, requestedModel, requestType, effectiveVK, requestID, userID, isFinalChunk, pricingScopes) }() } @@ -1463,7 +1448,7 @@ func (p *GovernancePlugin) Cleanup() error { // - isBatch: Whether the request is a batch request // - isFinalChunk: Whether the request is the final chunk // - pricingScopes: Prebuilt pricing lookup scopes using governance VK ID (nil if not applicable) -func (p *GovernancePlugin) postHookWorker(result *schemas.BifrostResponse, provider schemas.ModelProvider, model string, requestType schemas.RequestType, virtualKey, requestID, userID string, _, _, isFinalChunk bool, pricingScopes 
*modelcatalog.PricingLookupScopes) { +func (p *GovernancePlugin) postHookWorker(result *schemas.BifrostResponse, provider schemas.ModelProvider, model string, requestType schemas.RequestType, virtualKey, requestID, userID string, isFinalChunk bool, pricingScopes *modelcatalog.PricingLookupScopes) { // Determine if request was successful success := (result != nil) diff --git a/plugins/governance/model_provider_governance_test.go b/plugins/governance/model_provider_governance_test.go index 98565bbb4b..e363c15435 100644 --- a/plugins/governance/model_provider_governance_test.go +++ b/plugins/governance/model_provider_governance_test.go @@ -1704,9 +1704,9 @@ func TestPostHook_UpdatesProviderBudgetUsage_NoVirtualKey(t *testing.T) { TotalTokens: 1500, }, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ChatCompletionRequest, - Provider: schemas.OpenAI, - ModelRequested: "gpt-4", + RequestType: schemas.ChatCompletionRequest, + Provider: schemas.OpenAI, + OriginalModelRequested: "gpt-4", }, }, } @@ -1773,9 +1773,9 @@ func TestPostHook_UpdatesProviderRateLimitUsage_NoVirtualKey(t *testing.T) { TotalTokens: 10000, // 10000 tokens used (exactly at limit) }, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ChatCompletionRequest, - Provider: schemas.OpenAI, - ModelRequested: "gpt-4", + RequestType: schemas.ChatCompletionRequest, + Provider: schemas.OpenAI, + OriginalModelRequested: "gpt-4", }, }, } @@ -1840,9 +1840,9 @@ func TestPostHook_UpdatesModelBudgetUsage_NoVirtualKey(t *testing.T) { TotalTokens: 1500, }, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ChatCompletionRequest, - Provider: schemas.OpenAI, - ModelRequested: "gpt-4", + RequestType: schemas.ChatCompletionRequest, + Provider: schemas.OpenAI, + OriginalModelRequested: "gpt-4", }, }, } @@ -1909,9 +1909,9 @@ func TestPostHook_UpdatesModelRateLimitUsage_NoVirtualKey(t *testing.T) { TotalTokens: 10000, // 10000 tokens used (exactly at limit) }, 
ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ChatCompletionRequest, - Provider: schemas.OpenAI, - ModelRequested: "gpt-4", + RequestType: schemas.ChatCompletionRequest, + Provider: schemas.OpenAI, + OriginalModelRequested: "gpt-4", }, }, } diff --git a/plugins/governance/store.go b/plugins/governance/store.go index f9cc7dae0c..6158cc1302 100644 --- a/plugins/governance/store.go +++ b/plugins/governance/store.go @@ -2250,7 +2250,7 @@ func (gs *LocalGovernanceStore) rebuildInMemoryStructures(ctx context.Context, c if rules, ok := value.([]*configstoreTables.TableRoutingRule); ok { for _, rule := range rules { if _, err := gs.GetRoutingProgram(rule); err != nil { - gs.logger.Warn("Failed to pre-compile routing program for rule %s: %v", rule.ID, err) + gs.logger.Warn("Failed to pre-compile routing program for rule %s: %v", rule.Name, err) } } } @@ -3731,7 +3731,7 @@ func (gs *LocalGovernanceStore) UpdateRoutingRuleInMemory(rule *configstoreTable // Recompile the program immediately to update cache with fresh compilation if _, err := gs.GetRoutingProgram(rule); err != nil { - gs.logger.Warn("Failed to recompile routing program for rule %s: %v", rule.ID, err) + gs.logger.Warn("Failed to recompile routing program for rule %s: %v", rule.Name, err) } return nil diff --git a/plugins/litellmcompat/texttochat.go b/plugins/litellmcompat/texttochat.go index 2eb7446348..b0c1b0a309 100644 --- a/plugins/litellmcompat/texttochat.go +++ b/plugins/litellmcompat/texttochat.go @@ -83,7 +83,7 @@ func transformTextToChatResponse(_ *schemas.BifrostContext, resp *schemas.Bifros // Restore original request type metadata textCompletionResponse.ExtraFields.RequestType = tc.OriginalRequestType - textCompletionResponse.ExtraFields.ModelRequested = tc.OriginalModel + textCompletionResponse.ExtraFields.OriginalModelRequested = tc.OriginalModel textCompletionResponse.ExtraFields.LiteLLMCompat = true if logger != nil { @@ -110,7 +110,7 @@ func transformTextToChatError(_ 
*schemas.BifrostContext, err *schemas.BifrostErr // Restore original request type in error metadata err.ExtraFields.RequestType = tc.OriginalRequestType - err.ExtraFields.ModelRequested = tc.OriginalModel + err.ExtraFields.OriginalModelRequested = tc.OriginalModel err.ExtraFields.LiteLLMCompat = true return err @@ -141,7 +141,7 @@ func TransformTextToChatStreamResponse(ctx *schemas.BifrostContext, stream *sche // Restore original request type metadata textCompletionResponse.ExtraFields.RequestType = tc.OriginalRequestType - textCompletionResponse.ExtraFields.ModelRequested = tc.OriginalModel + textCompletionResponse.ExtraFields.OriginalModelRequested = tc.OriginalModel textCompletionResponse.ExtraFields.LiteLLMCompat = true // Return a new stream with the text completion response diff --git a/plugins/logging/main.go b/plugins/logging/main.go index 7d0b06b7ad..886cf3083e 100644 --- a/plugins/logging/main.go +++ b/plugins/logging/main.go @@ -666,7 +666,7 @@ func (p *LoggerPlugin) PostLLMHook(ctx *schemas.BifrostContext, result *schemas. routingRuleName := bifrost.GetStringFromContext(ctx, schemas.BifrostContextKeyGovernanceRoutingRuleName) numberOfRetries := bifrost.GetIntFromContext(ctx, schemas.BifrostContextKeyNumberOfRetries) - requestType, _, _ := bifrost.GetResponseFields(result, bifrostErr) + requestType, _, originalModelRequested, resolvedModelUsed := bifrost.GetResponseFields(result, bifrostErr) isFinalChunk := bifrost.IsFinalChunk(ctx) @@ -706,12 +706,12 @@ func (p *LoggerPlugin) PostLLMHook(ctx *schemas.BifrostContext, result *schemas. 
entry := &logstore.Log{ ID: requestID, Provider: string(bifrostErr.ExtraFields.Provider), - Model: bifrostErr.ExtraFields.ModelRequested, Status: "error", Stream: bifrost.IsStreamRequestType(requestType), Timestamp: time.Now().UTC(), CreatedAt: time.Now().UTC(), } + applyModelAlias(entry, bifrostErr.ExtraFields.OriginalModelRequested, bifrostErr.ExtraFields.ResolvedModelUsed) if data, err := sonic.Marshal(bifrostErr); err == nil { entry.ErrorDetails = string(data) } @@ -742,6 +742,7 @@ func (p *LoggerPlugin) PostLLMHook(ctx *schemas.BifrostContext, result *schemas. // Path A: Error with nil result if result == nil && bifrostErr != nil { entry.Status = "error" + applyModelAlias(entry, bifrostErr.ExtraFields.OriginalModelRequested, bifrostErr.ExtraFields.ResolvedModelUsed) if bifrost.IsStreamRequestType(requestType) { entry.Stream = true } @@ -786,6 +787,7 @@ func (p *LoggerPlugin) PostLLMHook(ctx *schemas.BifrostContext, result *schemas. if bifrostErr != nil { entry.Status = "error" entry.Stream = true + applyModelAlias(entry, originalModelRequested, resolvedModelUsed) if data, err := sonic.Marshal(bifrostErr); err == nil { entry.ErrorDetails = string(data) } @@ -794,6 +796,7 @@ func (p *LoggerPlugin) PostLLMHook(ctx *schemas.BifrostContext, result *schemas. // tracer or traceID not available, or accumulator returned nil - still write what we have entry.Status = "success" entry.Stream = true + applyModelAlias(entry, originalModelRequested, resolvedModelUsed) } else if isFinalChunk { // Apply streaming output fields to the entry entry.Stream = true @@ -823,6 +826,7 @@ func (p *LoggerPlugin) PostLLMHook(ctx *schemas.BifrostContext, result *schemas. 
// Path C: Non-streaming response if bifrostErr != nil { entry.Status = "error" + applyModelAlias(entry, bifrostErr.ExtraFields.OriginalModelRequested, bifrostErr.ExtraFields.ResolvedModelUsed) // Serialize error details immediately since bifrostErr may be released // back to the pool before the async batch writer processes this entry. // Also set ErrorDetailsParsed for UI callback (JSON serialization uses this field). @@ -832,6 +836,8 @@ func (p *LoggerPlugin) PostLLMHook(ctx *schemas.BifrostContext, result *schemas. entry.ErrorDetailsParsed = bifrostErr } else if result != nil { entry.Status = "success" + extraFields := result.GetExtraFields() + applyModelAlias(entry, extraFields.OriginalModelRequested, extraFields.ResolvedModelUsed) p.applyNonStreamingOutputToEntry(entry, result) // Flip status for passthrough error responses (4xx/5xx from provider) if isPassthroughErrorResponse(result) { diff --git a/plugins/logging/operations.go b/plugins/logging/operations.go index f3f727bd56..1ec78b951a 100644 --- a/plugins/logging/operations.go +++ b/plugins/logging/operations.go @@ -404,10 +404,13 @@ func (p *LoggerPlugin) updateStreamingLogEntry( tempEntry := &logstore.Log{} updates["latency"] = float64(streamResponse.Data.Latency) - // Update model if provided - if streamResponse.Data.Model != "" { - updates["model"] = streamResponse.Data.Model + // Update model and alias from resolved/requested model pair. 
+ tempEntry2 := &logstore.Log{} + applyModelAlias(tempEntry2, streamResponse.RequestedModel, streamResponse.ResolvedModel) + if tempEntry2.Model != "" { + updates["model"] = tempEntry2.Model } + updates["alias"] = tempEntry2.Alias needsSerialization := false @@ -544,10 +547,8 @@ func (p *LoggerPlugin) applyStreamingOutputToEntry(entry *logstore.Log, streamRe latF := float64(streamResponse.Data.Latency) entry.Latency = &latF - // Update model if provided - if streamResponse.Data.Model != "" { - entry.Model = streamResponse.Data.Model - } + // Update model and alias from resolved/requested model pair. + applyModelAlias(entry, streamResponse.RequestedModel, streamResponse.ResolvedModel) // Token usage if streamResponse.Data.TokenUsage != nil { @@ -802,6 +803,16 @@ func (p *LoggerPlugin) GetAvailableModels(ctx context.Context) []string { return models } +// GetAvailableAliases returns all unique alias values from logs. +func (p *LoggerPlugin) GetAvailableAliases(ctx context.Context) []string { + aliases, err := p.store.GetDistinctAliases(ctx) + if err != nil { + p.logger.Error("failed to get available aliases: %v", err) + return []string{} + } + return aliases +} + func (p *LoggerPlugin) GetAvailableSelectedKeys(ctx context.Context) []KeyPair { results, err := p.store.GetDistinctKeyPairs(ctx, "selected_key_id", "selected_key_name") if err != nil { @@ -978,11 +989,17 @@ func (p *LoggerPlugin) calculateCostForLog(logEntry *logstore.Log) (float64, err // Build a minimal BifrostResponse matching the request type so that // extractCostInput routes usage into the correct field for each compute function. 
+ originalModelRequested := logEntry.Model + if logEntry.Alias != nil && *logEntry.Alias != "" { + originalModelRequested = *logEntry.Alias + } + extraFields := schemas.BifrostResponseExtraFields{ - RequestType: requestType, - Provider: schemas.ModelProvider(logEntry.Provider), - ModelRequested: logEntry.Model, - CacheDebug: cacheDebug, + RequestType: requestType, + Provider: schemas.ModelProvider(logEntry.Provider), + OriginalModelRequested: originalModelRequested, + ResolvedModelUsed: logEntry.Model, + CacheDebug: cacheDebug, } resp := buildResponseForRequestType(requestType, usage, extraFields) diff --git a/plugins/logging/utils.go b/plugins/logging/utils.go index 80bc953e99..7b64d944bb 100644 --- a/plugins/logging/utils.go +++ b/plugins/logging/utils.go @@ -63,6 +63,9 @@ type LogManager interface { // GetAvailableModels returns all unique models from logs GetAvailableModels(ctx context.Context) []string + // GetAvailableAliases returns all unique alias values from logs + GetAvailableAliases(ctx context.Context) []string + // GetAvailableSelectedKeys returns all unique selected key ID-Name pairs from logs GetAvailableSelectedKeys(ctx context.Context) []KeyPair @@ -211,6 +214,11 @@ func (p *PluginLogManager) GetAvailableModels(ctx context.Context) []string { return p.plugin.GetAvailableModels(ctx) } +// GetAvailableAliases returns all unique alias values from logs +func (p *PluginLogManager) GetAvailableAliases(ctx context.Context) []string { + return p.plugin.GetAvailableAliases(ctx) +} + // GetAvailableSelectedKeys returns all unique selected key ID-Name pairs from logs func (p *PluginLogManager) GetAvailableSelectedKeys(ctx context.Context) []KeyPair { return p.plugin.GetAvailableSelectedKeys(ctx) @@ -480,7 +488,7 @@ func convertToProcessedStreamResponse(result *schemas.StreamAccumulatorResult, r // Build accumulated data data := &streaming.AccumulatedData{ RequestID: result.RequestID, - Model: result.Model, + Model: result.RequestedModel, Status: 
result.Status, Stream: true, Latency: result.Latency, @@ -503,11 +511,12 @@ func convertToProcessedStreamResponse(result *schemas.StreamAccumulatorResult, r } resp := &streaming.ProcessedStreamResponse{ - RequestID: result.RequestID, - StreamType: streamType, - Provider: result.Provider, - Model: result.Model, - Data: data, + RequestID: result.RequestID, + StreamType: streamType, + Provider: result.Provider, + RequestedModel: result.RequestedModel, + ResolvedModel: result.ResolvedModel, + Data: data, } if result.RawRequest != nil { diff --git a/plugins/logging/writer.go b/plugins/logging/writer.go index 224e3ce8d8..24f4125601 100644 --- a/plugins/logging/writer.go +++ b/plugins/logging/writer.go @@ -328,6 +328,23 @@ func buildCompleteLogEntryFromPending(pending *PendingLogData) *logstore.Log { return entry } +// applyModelAlias sets entry.Model to resolvedModel (falling back to requestedModel if empty) +// and entry.Alias to requestedModel when the two differ (i.e. an alias mapping was applied). +func applyModelAlias(entry *logstore.Log, requestedModel, resolvedModel string) { + if resolvedModel != "" && resolvedModel != requestedModel { + entry.Model = resolvedModel + entry.Alias = &requestedModel + } else { + // No alias mapping; keep whichever value is non-empty as the model. + if resolvedModel != "" { + entry.Model = resolvedModel + } else if requestedModel != "" { + entry.Model = requestedModel + } + entry.Alias = nil + } +} + // applyOutputFieldsToEntry sets common output fields on a log entry. 
func applyOutputFieldsToEntry( entry *logstore.Log, diff --git a/plugins/maxim/main.go b/plugins/maxim/main.go index b1fa286615..0e85d79bc4 100644 --- a/plugins/maxim/main.go +++ b/plugins/maxim/main.go @@ -117,10 +117,11 @@ func convertAccResultToProcessedStreamResponse(accResult *schemas.StreamAccumula streamType = streaming.StreamTypeResponses } return &streaming.ProcessedStreamResponse{ - RequestID: accResult.RequestID, - StreamType: streamType, - Model: accResult.Model, - Provider: accResult.Provider, + RequestID: accResult.RequestID, + StreamType: streamType, + RequestedModel: accResult.RequestedModel, + ResolvedModel: accResult.ResolvedModel, + Provider: accResult.Provider, Data: &streaming.AccumulatedData{ Status: accResult.Status, Latency: accResult.Latency, @@ -510,7 +511,7 @@ func (plugin *Plugin) PostLLMHook(ctx *schemas.BifrostContext, result *schemas.B isFinalChunk := bifrost.IsFinalChunk(ctx) go func() { - requestType, _, model := bifrost.GetResponseFields(result, bifrostErr) + requestType, _, originalModel, resolvedModel := bifrost.GetResponseFields(result, bifrostErr) var streamResponse *streaming.ProcessedStreamResponse if bifrost.IsStreamRequestType(requestType) { @@ -610,8 +611,12 @@ func (plugin *Plugin) PostLLMHook(ctx *schemas.BifrostContext, result *schemas.B } } } - logger.AddTagToGeneration(generationID, "model", string(model)) - logger.AddTagToTrace(traceID, "model", string(model)) + logger.AddTagToGeneration(generationID, "model", string(resolvedModel)) + logger.AddTagToTrace(traceID, "model", string(resolvedModel)) + if originalModel != "" && originalModel != resolvedModel { + logger.AddTagToGeneration(generationID, "alias", originalModel) + logger.AddTagToTrace(traceID, "alias", originalModel) + } // Flush only the effective logger that was used for this request logger.Flush() }() diff --git a/plugins/mocker/main.go b/plugins/mocker/main.go index 29189c35cb..d9ccb765f4 100644 --- a/plugins/mocker/main.go +++ b/plugins/mocker/main.go @@ 
-853,7 +853,7 @@ func (p *MockerPlugin) generateSuccessShortCircuit(req *schemas.BifrostRequest, ExtraFields: schemas.BifrostResponseExtraFields{ RequestType: req.RequestType, Provider: provider, - ModelRequested: model, + OriginalModelRequested: model, Latency: int64(time.Since(startTime).Milliseconds()), }, } @@ -877,7 +877,7 @@ func (p *MockerPlugin) generateSuccessShortCircuit(req *schemas.BifrostRequest, ExtraFields: schemas.BifrostResponseExtraFields{ RequestType: schemas.ResponsesRequest, Provider: provider, - ModelRequested: model, + OriginalModelRequested: model, Latency: int64(time.Since(startTime).Milliseconds()), }, } @@ -905,7 +905,7 @@ func (p *MockerPlugin) generateSuccessShortCircuit(req *schemas.BifrostRequest, ExtraFields: schemas.BifrostResponseExtraFields{ RequestType: schemas.ResponsesStreamRequest, Provider: provider, - ModelRequested: model, + OriginalModelRequested: model, Latency: int64(time.Since(startTime).Milliseconds()), }, } @@ -959,7 +959,7 @@ func (p *MockerPlugin) generateErrorShortCircuit(req *schemas.BifrostRequest, re ExtraFields: schemas.BifrostErrorExtraFields{ RequestType: req.RequestType, Provider: provider, - ModelRequested: model, + OriginalModelRequested: model, }, } @@ -1083,7 +1083,7 @@ func (p *MockerPlugin) handleDefaultBehavior(req *schemas.BifrostRequest) (*sche ExtraFields: schemas.BifrostResponseExtraFields{ RequestType: schemas.ChatCompletionRequest, Provider: provider, - ModelRequested: model, + OriginalModelRequested: model, }, }, }, diff --git a/plugins/semanticcache/plugin_cache_type_test.go b/plugins/semanticcache/plugin_cache_type_test.go index b97a09715f..13979df38f 100644 --- a/plugins/semanticcache/plugin_cache_type_test.go +++ b/plugins/semanticcache/plugin_cache_type_test.go @@ -390,9 +390,10 @@ func TestDirectCacheHitPreservesCachedProviderMetadataAcrossProviders(t *testing }, }, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: schemas.OpenAI, - ModelRequested: "gpt-5.2", - RequestType: 
schemas.ChatCompletionRequest, + Provider: schemas.OpenAI, + OriginalModelRequested: "gpt-5.2", + ResolvedModelUsed: "gpt-5.2", + RequestType: schemas.ChatCompletionRequest, }, }, } @@ -417,8 +418,11 @@ func TestDirectCacheHitPreservesCachedProviderMetadataAcrossProviders(t *testing if extraFields.Provider != schemas.OpenAI { t.Fatalf("expected cached provider %q, got %q", schemas.OpenAI, extraFields.Provider) } - if extraFields.ModelRequested != "gpt-5.2" { - t.Fatalf("expected cached model_requested %q, got %q", "gpt-5.2", extraFields.ModelRequested) + if extraFields.OriginalModelRequested != "gpt-5.2" { + t.Fatalf("expected OriginalModelRequested %q, got %q", "gpt-5.2", extraFields.OriginalModelRequested) + } + if extraFields.ResolvedModelUsed != "gpt-5.2" { + t.Fatalf("expected ResolvedModelUsed %q, got %q", "gpt-5.2", extraFields.ResolvedModelUsed) } if extraFields.CacheDebug == nil { t.Fatal("expected cache_debug on cache hit") @@ -491,10 +495,11 @@ func TestStreamingDirectCacheHitPreservesCachedProviderMetadataAcrossProviders(t }, }, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: schemas.OpenAI, - ModelRequested: "gpt-5.2", - RequestType: schemas.ChatCompletionStreamRequest, - ChunkIndex: chunk.chunkIndex, + Provider: schemas.OpenAI, + OriginalModelRequested: "gpt-5.2", + ResolvedModelUsed: "gpt-5.2", + RequestType: schemas.ChatCompletionStreamRequest, + ChunkIndex: chunk.chunkIndex, }, }, } @@ -526,8 +531,11 @@ func TestStreamingDirectCacheHitPreservesCachedProviderMetadataAcrossProviders(t if extraFields.Provider != schemas.OpenAI { t.Fatalf("expected cached provider %q on chunk %d, got %q", schemas.OpenAI, chunkCount, extraFields.Provider) } - if extraFields.ModelRequested != "gpt-5.2" { - t.Fatalf("expected cached model_requested %q on chunk %d, got %q", "gpt-5.2", chunkCount, extraFields.ModelRequested) + if extraFields.OriginalModelRequested != "gpt-5.2" { + t.Fatalf("expected OriginalModelRequested %q on chunk %d, got %q", "gpt-5.2", 
chunkCount, extraFields.OriginalModelRequested) + } + if extraFields.ResolvedModelUsed != "gpt-5.2" { + t.Fatalf("expected ResolvedModelUsed %q on chunk %d, got %q", "gpt-5.2", chunkCount, extraFields.ResolvedModelUsed) } if chunkCount == len(chunks)-1 { if extraFields.CacheDebug == nil || !extraFields.CacheDebug.CacheHit { diff --git a/plugins/semanticcache/plugin_integration_test.go b/plugins/semanticcache/plugin_integration_test.go index 92e3cd16e2..58ab9d04c3 100644 --- a/plugins/semanticcache/plugin_integration_test.go +++ b/plugins/semanticcache/plugin_integration_test.go @@ -18,7 +18,7 @@ func TestSemanticCacheBasicFlow(t *testing.T) { ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) ctx.SetValue(CacheKey, "test-cache-enabled") - + // Test request request := &schemas.BifrostRequest{ RequestType: schemas.ChatCompletionRequest, @@ -75,9 +75,9 @@ func TestSemanticCacheBasicFlow(t *testing.T) { }, }, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: schemas.OpenAI, - ModelRequested: "gpt-4o-mini", - RequestType: schemas.ChatCompletionRequest, + Provider: schemas.OpenAI, + OriginalModelRequested: "gpt-4o-mini", + RequestType: schemas.ChatCompletionRequest, }, }, } @@ -213,9 +213,9 @@ func TestSemanticCacheStrictFiltering(t *testing.T) { }, }, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: schemas.OpenAI, - ModelRequested: "gpt-4o-mini", - RequestType: schemas.ChatCompletionRequest, + Provider: schemas.OpenAI, + OriginalModelRequested: "gpt-4o-mini", + RequestType: schemas.ChatCompletionRequest, }, }, } @@ -309,7 +309,7 @@ func TestSemanticCacheStreamingFlow(t *testing.T) { setup := NewTestSetup(t) defer setup.Cleanup() - ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) + ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) ctx.SetValue(CacheKey, "test-cache-enabled") request := &schemas.BifrostRequest{ @@ -375,10 +375,10 @@ func TestSemanticCacheStreamingFlow(t 
*testing.T) { }, }, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: schemas.OpenAI, - ModelRequested: "gpt-4o-mini", - RequestType: schemas.ChatCompletionStreamRequest, - ChunkIndex: i, + Provider: schemas.OpenAI, + OriginalModelRequested: "gpt-4o-mini", + RequestType: schemas.ChatCompletionStreamRequest, + ChunkIndex: i, }, }, } @@ -524,9 +524,9 @@ func TestSemanticCache_CustomTTLHandling(t *testing.T) { }, }, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: schemas.OpenAI, - ModelRequested: "gpt-4o-mini", - RequestType: schemas.ChatCompletionRequest, + Provider: schemas.OpenAI, + OriginalModelRequested: "gpt-4o-mini", + RequestType: schemas.ChatCompletionRequest, }, }, } @@ -547,7 +547,7 @@ func TestSemanticCache_CustomThresholdHandling(t *testing.T) { defer setup.Cleanup() // Configure plugin with custom threshold key - ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) + ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) ctx.SetValue(CacheKey, "test-cache-enabled") ctx.SetValue(CacheThresholdKey, 0.95) // Very high threshold @@ -635,9 +635,9 @@ func TestSemanticCache_ProviderModelCachingFlags(t *testing.T) { }, }, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: schemas.OpenAI, - ModelRequested: "gpt-4o-mini", - RequestType: schemas.ChatCompletionRequest, + Provider: schemas.OpenAI, + OriginalModelRequested: "gpt-4o-mini", + RequestType: schemas.ChatCompletionRequest, }, }, } diff --git a/plugins/semanticcache/plugin_vectorstore_test.go b/plugins/semanticcache/plugin_vectorstore_test.go index d0b9ef78d0..5e390bbe80 100644 --- a/plugins/semanticcache/plugin_vectorstore_test.go +++ b/plugins/semanticcache/plugin_vectorstore_test.go @@ -132,9 +132,9 @@ func TestSemanticCache_AllVectorStores_BasicFlow(t *testing.T) { }, }, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: schemas.OpenAI, - ModelRequested: "gpt-4o-mini", - RequestType: schemas.ChatCompletionRequest, + 
Provider: schemas.OpenAI, + OriginalModelRequested: "gpt-4o-mini", + RequestType: schemas.ChatCompletionRequest, }, }, } @@ -331,9 +331,9 @@ func TestSemanticCache_AllVectorStores_ParameterFiltering(t *testing.T) { }, }, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: schemas.OpenAI, - ModelRequested: "gpt-4o-mini", - RequestType: schemas.ChatCompletionRequest, + Provider: schemas.OpenAI, + OriginalModelRequested: "gpt-4o-mini", + RequestType: schemas.ChatCompletionRequest, }, }, } diff --git a/plugins/telemetry/main.go b/plugins/telemetry/main.go index 755a8e2f44..7025b902e7 100644 --- a/plugins/telemetry/main.go +++ b/plugins/telemetry/main.go @@ -136,6 +136,7 @@ func Init(config *Config, pricingManager *modelcatalog.ModelCatalog, logger sche defaultBifrostLabels := []string{ "provider", "model", + "alias", "method", "virtual_key_id", "virtual_key_name", @@ -359,7 +360,17 @@ func (p *PrometheusPlugin) PreLLMHook(ctx *schemas.BifrostContext, req *schemas. // - Request latency // - Total request count func (p *PrometheusPlugin) PostLLMHook(ctx *schemas.BifrostContext, result *schemas.BifrostResponse, bifrostErr *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError, error) { - requestType, provider, model := bifrost.GetResponseFields(result, bifrostErr) + requestType, provider, originalModel, resolvedModel := bifrost.GetResponseFields(result, bifrostErr) + + // Determine effective model label and alias label (mirrors applyModelAlias logic in logging) + model := originalModel + alias := "" + if resolvedModel != "" { + model = resolvedModel + if resolvedModel != originalModel { + alias = originalModel + } + } startTime, ok := ctx.Value(startTimeKey).(time.Time) if !ok { @@ -393,6 +404,7 @@ func (p *PrometheusPlugin) PostLLMHook(ctx *schemas.BifrostContext, result *sche labelValues := map[string]string{ "provider": string(provider), "model": model, + "alias": alias, "method": string(requestType), "virtual_key_id": virtualKeyID, 
"virtual_key_name": virtualKeyName, diff --git a/tests/governance/e2e_test.go b/tests/governance/e2e_test.go index 381a83ade1..bab26fff30 100644 --- a/tests/governance/e2e_test.go +++ b/tests/governance/e2e_test.go @@ -1426,7 +1426,7 @@ func TestWeightedProviderLoadBalancing(t *testing.T) { // Try to detect which provider was used // Check if model in response contains provider name if provider, ok := resp.Body["extra_fields"].(map[string]interface{})["provider"].(string); ok { - model, ok := resp.Body["extra_fields"].(map[string]interface{})["model_requested"].(string) + model, ok := resp.Body["extra_fields"].(map[string]interface{})["original_model_requested"].(string) if !ok { t.Logf("Request %d failed to get model requested", i+1) continue diff --git a/transports/bifrost-http/handlers/inference.go b/transports/bifrost-http/handlers/inference.go index 733403bcf9..f5acfe8c73 100644 --- a/transports/bifrost-http/handlers/inference.go +++ b/transports/bifrost-http/handlers/inference.go @@ -741,9 +741,9 @@ func (h *CompletionHandler) listModels(ctx *fasthttp.RequestCtx) { for i, modelEntry := range resp.Data { provider, modelName := schemas.ParseModelString(modelEntry.ID, "") pricingEntry := h.config.ModelCatalog.GetPricingEntryForModel(modelName, provider) - if pricingEntry == nil && modelEntry.Deployment != nil { - // Retry with deployment - pricingEntry = h.config.ModelCatalog.GetPricingEntryForModel(*modelEntry.Deployment, provider) + if pricingEntry == nil && modelEntry.Alias != nil { + // Retry with alias + pricingEntry = h.config.ModelCatalog.GetPricingEntryForModel(*modelEntry.Alias, provider) } if pricingEntry != nil && modelEntry.Pricing == nil { pricing := &schemas.Pricing{} diff --git a/transports/bifrost-http/handlers/logging.go b/transports/bifrost-http/handlers/logging.go index 7d48f0945a..852a0695b0 100644 --- a/transports/bifrost-http/handlers/logging.go +++ b/transports/bifrost-http/handlers/logging.go @@ -94,6 +94,9 @@ func (h *LoggingHandler) 
getLogs(ctx *fasthttp.RequestCtx) { if models := string(ctx.QueryArgs().Peek("models")); models != "" { filters.Models = parseCommaSeparated(models) } + if aliases := string(ctx.QueryArgs().Peek("aliases")); aliases != "" { + filters.Aliases = parseCommaSeparated(aliases) + } if statuses := string(ctx.QueryArgs().Peek("status")); statuses != "" { filters.Status = parseCommaSeparated(statuses) } @@ -305,6 +308,9 @@ func (h *LoggingHandler) getLogsStats(ctx *fasthttp.RequestCtx) { if models := string(ctx.QueryArgs().Peek("models")); models != "" { filters.Models = parseCommaSeparated(models) } + if aliases := string(ctx.QueryArgs().Peek("aliases")); aliases != "" { + filters.Aliases = parseCommaSeparated(aliases) + } if statuses := string(ctx.QueryArgs().Peek("status")); statuses != "" { filters.Status = parseCommaSeparated(statuses) } @@ -434,6 +440,9 @@ func parseHistogramFilters(ctx *fasthttp.RequestCtx) *logstore.SearchFilters { if models := string(ctx.QueryArgs().Peek("models")); models != "" { filters.Models = parseCommaSeparated(models) } + if aliases := string(ctx.QueryArgs().Peek("aliases")); aliases != "" { + filters.Aliases = parseCommaSeparated(aliases) + } if statuses := string(ctx.QueryArgs().Peek("status")); statuses != "" { filters.Status = parseCommaSeparated(statuses) } @@ -636,6 +645,7 @@ func (h *LoggingHandler) getAvailableFilterData(ctx *fasthttp.RequestCtx) { var ( models []string + aliases []string selectedKeys []logging.KeyPair virtualKeys []logging.KeyPair routingRules []logging.KeyPair @@ -653,6 +663,13 @@ func (h *LoggingHandler) getAvailableFilterData(ctx *fasthttp.RequestCtx) { mu.Unlock() return nil }) + g.Go(func() error { + result := h.logManager.GetAvailableAliases(gCtx) + mu.Lock() + aliases = result + mu.Unlock() + return nil + }) g.Go(func() error { result := h.logManager.GetAvailableSelectedKeys(gCtx) mu.Lock() @@ -780,7 +797,7 @@ func (h *LoggingHandler) getAvailableFilterData(ctx *fasthttp.RequestCtx) { if metadataKeys == nil { 
metadataKeys = make(map[string][]string) } - SendJSON(ctx, map[string]interface{}{"models": models, "selected_keys": selectedKeysArray, "virtual_keys": virtualKeysArray, "routing_rules": routingRulesArray, "routing_engines": routingEngines, "metadata_keys": metadataKeys}) + SendJSON(ctx, map[string]interface{}{"models": models, "aliases": aliases, "selected_keys": selectedKeysArray, "virtual_keys": virtualKeysArray, "routing_rules": routingRulesArray, "routing_engines": routingEngines, "metadata_keys": metadataKeys}) } // deleteLogs handles DELETE /api/logs - Delete logs by their IDs diff --git a/transports/bifrost-http/handlers/provider_keys.go b/transports/bifrost-http/handlers/provider_keys.go index c287c68f52..efd23d0bfb 100644 --- a/transports/bifrost-http/handlers/provider_keys.go +++ b/transports/bifrost-http/handlers/provider_keys.go @@ -113,6 +113,11 @@ func (h *ProviderHandler) createProviderKey(ctx *fasthttp.RequestCtx) { return } + if err := key.Aliases.Validate(); err != nil { + SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid aliases: %v", err)) + return + } + if key.ID == "" { key.ID = uuid.NewString() } @@ -219,6 +224,11 @@ func (h *ProviderHandler) updateProviderKey(ctx *fasthttp.RequestCtx) { return } + if err := mergedKey.Aliases.Validate(); err != nil { + SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid aliases: %v", err)) + return + } + if err := validateProviderKeyURL(provider, mergedKey); err != nil { SendError(ctx, fasthttp.StatusBadRequest, err.Error()) return @@ -425,6 +435,11 @@ func (h *ProviderHandler) mergeUpdatedKey(oldRawKey, oldRedactedKey, updateKey s } } + // ReplicateKeyConfig has no sensitive fields — pass through as-is + if updateKey.ReplicateKeyConfig == nil && oldRawKey.ReplicateKeyConfig != nil { + mergedKey.ReplicateKeyConfig = oldRawKey.ReplicateKeyConfig + } + if updateKey.OllamaKeyConfig != nil && oldRedactedKey.OllamaKeyConfig != nil && oldRawKey.OllamaKeyConfig != nil { if 
updateKey.OllamaKeyConfig.URL.IsRedacted() && updateKey.OllamaKeyConfig.URL.Equals(&oldRedactedKey.OllamaKeyConfig.URL) { diff --git a/transports/bifrost-http/handlers/wsresponses.go b/transports/bifrost-http/handlers/wsresponses.go index a1608e9722..18a0377f41 100644 --- a/transports/bifrost-http/handlers/wsresponses.go +++ b/transports/bifrost-http/handlers/wsresponses.go @@ -378,7 +378,7 @@ func parseUpstreamWSEvent(data []byte, provider schemas.ModelProvider, model str } streamResp.ExtraFields.RequestType = schemas.ResponsesStreamRequest streamResp.ExtraFields.Provider = provider - streamResp.ExtraFields.ModelRequested = model + streamResp.ExtraFields.OriginalModelRequested = model return &streamResp } diff --git a/transports/bifrost-http/integrations/anthropic.go b/transports/bifrost-http/integrations/anthropic.go index 25033b4928..7d3fdd9b22 100644 --- a/transports/bifrost-http/integrations/anthropic.go +++ b/transports/bifrost-http/integrations/anthropic.go @@ -43,7 +43,7 @@ func createAnthropicCompleteRouteConfig(pathPrefix string) RouteConfig { return nil, errors.New("invalid request type") }, TextResponseConverter: func(ctx *schemas.BifrostContext, resp *schemas.BifrostTextCompletionResponse) (interface{}, error) { - if shouldUsePassthrough(ctx, resp.ExtraFields.Provider, resp.ExtraFields.ModelRequested, resp.ExtraFields.ModelDeployment) { + if shouldUsePassthrough(ctx, resp.ExtraFields.Provider, resp.ExtraFields.OriginalModelRequested, resp.ExtraFields.ResolvedModelUsed) { if resp.ExtraFields.RawResponse != nil { return resp.ExtraFields.RawResponse, nil } @@ -85,7 +85,7 @@ func createAnthropicMessagesRouteConfig(pathPrefix string, logger schemas.Logger return nil, errors.New("invalid request type") }, ResponsesResponseConverter: func(ctx *schemas.BifrostContext, resp *schemas.BifrostResponsesResponse) (interface{}, error) { - if isClaudeModel(resp.ExtraFields.ModelRequested, resp.ExtraFields.ModelDeployment, string(resp.ExtraFields.Provider)) { + if 
isClaudeModel(resp.ExtraFields.OriginalModelRequested, resp.ExtraFields.ResolvedModelUsed, string(resp.ExtraFields.Provider)) { if resp.ExtraFields.RawResponse != nil { return resp.ExtraFields.RawResponse, nil } @@ -113,7 +113,7 @@ func createAnthropicMessagesRouteConfig(pathPrefix string, logger schemas.Logger }, StreamConfig: &StreamConfig{ ResponsesStreamResponseConverter: func(ctx *schemas.BifrostContext, resp *schemas.BifrostResponsesStreamResponse) (string, interface{}, error) { - if shouldUsePassthrough(ctx, resp.ExtraFields.Provider, resp.ExtraFields.ModelRequested, resp.ExtraFields.ModelDeployment) { + if shouldUsePassthrough(ctx, resp.ExtraFields.Provider, resp.ExtraFields.OriginalModelRequested, resp.ExtraFields.ResolvedModelUsed) { if resp.ExtraFields.RawResponse != nil { raw, ok := resp.ExtraFields.RawResponse.(string) if !ok { @@ -396,15 +396,15 @@ func checkAnthropicPassthrough(ctx *fasthttp.RequestCtx, bifrostCtx *schemas.Bif } // shouldUsePassthrough checks if the request should be sent to the passthrough endpoint. 
-func shouldUsePassthrough(ctx *schemas.BifrostContext, provider schemas.ModelProvider, model string, deployment string) bool { - return anthropic.IsClaudeCodeRequest(ctx) && isClaudeModel(model, deployment, string(provider)) +func shouldUsePassthrough(ctx *schemas.BifrostContext, provider schemas.ModelProvider, model string, alias string) bool { + return anthropic.IsClaudeCodeRequest(ctx) && isClaudeModel(model, alias, string(provider)) } -func isClaudeModel(model, deployment, provider string) bool { +func isClaudeModel(model, alias, provider string) bool { return (provider == string(schemas.Anthropic) || - (provider == "" && schemas.IsAnthropicModel(model))) || - (provider == string(schemas.Vertex) && (schemas.IsAnthropicModel(model) || schemas.IsAnthropicModel(deployment))) || - (provider == string(schemas.Azure) && (schemas.IsAnthropicModel(model) || schemas.IsAnthropicModel(deployment))) + (provider == "" && (schemas.IsAnthropicModel(model) || schemas.IsAnthropicModel(alias)))) || + (provider == string(schemas.Vertex) && (schemas.IsAnthropicModel(model) || schemas.IsAnthropicModel(alias))) || + (provider == string(schemas.Azure) && (schemas.IsAnthropicModel(model) || schemas.IsAnthropicModel(alias))) } // extractAnthropicListModelsParams extracts query parameters for list models request diff --git a/transports/bifrost-http/integrations/cursor.go b/transports/bifrost-http/integrations/cursor.go index a4ad12bc33..29513c1c05 100644 --- a/transports/bifrost-http/integrations/cursor.go +++ b/transports/bifrost-http/integrations/cursor.go @@ -104,10 +104,10 @@ func cursorChunkID(extras *schemas.BifrostResponseExtraFields) string { // cursorModel returns the best model name available from extra fields. 
func cursorModel(extras *schemas.BifrostResponseExtraFields) string { - if extras.ModelDeployment != "" { - return extras.ModelDeployment + if extras.ResolvedModelUsed != "" { + return extras.ResolvedModelUsed } - return extras.ModelRequested + return extras.OriginalModelRequested } // convertResponsesStreamToChatChunk maps a Responses API stream event to a diff --git a/transports/bifrost-http/integrations/openai.go b/transports/bifrost-http/integrations/openai.go index c4eaa27033..abe9adf0a2 100644 --- a/transports/bifrost-http/integrations/openai.go +++ b/transports/bifrost-http/integrations/openai.go @@ -286,9 +286,6 @@ func AzureEndpointPreHook(handlerStore lib.HandlerStore) func(ctx *fasthttp.Requ if deploymentEndpointStr != "" && deploymentIDStr != "" && azureKeyStr != "" { key.Value = *schemas.NewEnvVar(strings.TrimPrefix(azureKeyStr, "Bearer ")) key.AzureKeyConfig.Endpoint = *schemas.NewEnvVar(deploymentEndpointStr) - key.AzureKeyConfig.Deployments = map[string]string{ - deploymentIDStr: deploymentIDStr, - } } if apiVersionStr != "" { diff --git a/transports/bifrost-http/lib/config.go b/transports/bifrost-http/lib/config.go index a417a7f1d5..19c8f3b7db 100644 --- a/transports/bifrost-http/lib/config.go +++ b/transports/bifrost-http/lib/config.go @@ -715,6 +715,9 @@ func processProvider( if providerKeyInFile.ID == "" { providerCfgInFile.Keys[i].ID = uuid.NewString() } + if err := providerKeyInFile.Aliases.Validate(); err != nil { + return fmt.Errorf("invalid aliases for key %q in provider %s: %w", providerKeyInFile.Name, provider, err) + } } // Generate hash from config.json provider config fileProviderConfigHash, err := providerCfgInFile.GenerateConfigHash(string(provider)) @@ -798,6 +801,7 @@ func mergeProviderKeys(provider schemas.ModelProvider, fileKeys, dbKeys []schema VertexKeyConfig: dbKey.VertexKeyConfig, BedrockKeyConfig: dbKey.BedrockKeyConfig, ReplicateKeyConfig: dbKey.ReplicateKeyConfig, + Aliases: dbKey.Aliases, VLLMKeyConfig: 
dbKey.VLLMKeyConfig, OllamaKeyConfig: dbKey.OllamaKeyConfig, SGLKeyConfig: dbKey.SGLKeyConfig, @@ -878,6 +882,7 @@ func reconcileProviderKeys(provider schemas.ModelProvider, fileKeys, dbKeys []sc VertexKeyConfig: dbKey.VertexKeyConfig, BedrockKeyConfig: dbKey.BedrockKeyConfig, ReplicateKeyConfig: dbKey.ReplicateKeyConfig, + Aliases: dbKey.Aliases, VLLMKeyConfig: dbKey.VLLMKeyConfig, OllamaKeyConfig: dbKey.OllamaKeyConfig, SGLKeyConfig: dbKey.SGLKeyConfig, diff --git a/transports/bifrost-http/lib/config_test.go b/transports/bifrost-http/lib/config_test.go index 3b3bccc82b..fe1e3bebee 100644 --- a/transports/bifrost-http/lib/config_test.go +++ b/transports/bifrost-http/lib/config_test.go @@ -2259,9 +2259,8 @@ func TestGenerateKeyHash(t *testing.T) { Models: []string{"gpt-4", "gpt-3.5-turbo"}, Weight: 1.5, AzureKeyConfig: &schemas.AzureKeyConfig{ - Endpoint: *schemas.NewEnvVar("https://my-azure.openai.azure.com"), - Deployments: map[string]string{"gpt-4": "gpt-4-deployment"}, - APIVersion: schemas.NewEnvVar(apiVersion), + Endpoint: *schemas.NewEnvVar("https://my-azure.openai.azure.com"), + APIVersion: schemas.NewEnvVar(apiVersion), }, } @@ -2282,12 +2281,30 @@ func TestGenerateKeyHash(t *testing.T) { Models: []string{"gpt-4", "gpt-3.5-turbo"}, Weight: 1.5, AzureKeyConfig: &schemas.AzureKeyConfig{ - Endpoint: *schemas.NewEnvVar("https://different-azure.openai.azure.com"), // Different endpoint - Deployments: map[string]string{"gpt-4": "gpt-4-deployment"}, - APIVersion: schemas.NewEnvVar(apiVersion), + Endpoint: *schemas.NewEnvVar("https://different-azure.openai.azure.com"), // Different endpoint + APIVersion: schemas.NewEnvVar(apiVersion), }, } + // Aliases alone should produce different hash + keyWithAliases := schemas.Key{ + ID: "key-1", + Name: "test-key", + Value: *schemas.NewEnvVar("sk-123"), + Models: []string{"gpt-4", "gpt-3.5-turbo"}, + Weight: 1.5, + Aliases: schemas.KeyAliases{"gpt-4": "gpt-4-deployment"}, + } + + hashWithAliases, err := 
configstore.GenerateKeyHash(keyWithAliases) + if err != nil { + t.Fatalf("Failed to generate hash: %v", err) + } + + if hash1 == hashWithAliases { + t.Error("Expected different hash for keys with Aliases") + } + hash6b, err := configstore.GenerateKeyHash(key6b) if err != nil { t.Fatalf("Failed to generate hash: %v", err) @@ -4744,12 +4761,10 @@ func TestKeyHashComparison_AzureConfigSyncScenarios(t *testing.T) { Name: "azure-key", Value: *schemas.NewEnvVar("azure-api-key-123"), Weight: 1, + Aliases: schemas.KeyAliases{"gpt-4": "gpt-4-deployment"}, AzureKeyConfig: &schemas.AzureKeyConfig{ Endpoint: *schemas.NewEnvVar("https://myazure.openai.azure.com"), APIVersion: schemas.NewEnvVar("2024-02-01"), - Deployments: map[string]string{ - "gpt-4": "gpt-4-deployment", - }, }, } @@ -4758,12 +4773,10 @@ func TestKeyHashComparison_AzureConfigSyncScenarios(t *testing.T) { Name: "azure-key", Value: *schemas.NewEnvVar("azure-api-key-123"), Weight: 1, + Aliases: schemas.KeyAliases{"gpt-4": "gpt-4-deployment"}, AzureKeyConfig: &schemas.AzureKeyConfig{ Endpoint: *schemas.NewEnvVar("https://myazure.openai.azure.com"), APIVersion: schemas.NewEnvVar("2024-02-01"), - Deployments: map[string]string{ - "gpt-4": "gpt-4-deployment", - }, }, } @@ -4783,12 +4796,10 @@ func TestKeyHashComparison_AzureConfigSyncScenarios(t *testing.T) { Name: "azure-key", Value: *schemas.NewEnvVar("azure-api-key-123"), Weight: 1, + Aliases: schemas.KeyAliases{"gpt-4": "gpt-4-deployment"}, AzureKeyConfig: &schemas.AzureKeyConfig{ Endpoint: *schemas.NewEnvVar("https://myazure.openai.azure.com"), APIVersion: schemas.NewEnvVar("2024-02-01"), - Deployments: map[string]string{ - "gpt-4": "gpt-4-deployment", - }, }, } @@ -4797,12 +4808,10 @@ func TestKeyHashComparison_AzureConfigSyncScenarios(t *testing.T) { Name: "azure-key", Value: *schemas.NewEnvVar("azure-api-key-123"), Weight: 1, + Aliases: schemas.KeyAliases{"gpt-4": "gpt-4-deployment"}, AzureKeyConfig: &schemas.AzureKeyConfig{ Endpoint: 
*schemas.NewEnvVar("https://different-azure.openai.azure.com"), // Changed! APIVersion: schemas.NewEnvVar("2024-02-01"), - Deployments: map[string]string{ - "gpt-4": "gpt-4-deployment", - }, }, } @@ -4822,12 +4831,10 @@ func TestKeyHashComparison_AzureConfigSyncScenarios(t *testing.T) { Name: "azure-key", Value: *schemas.NewEnvVar("azure-api-key-123"), Weight: 1, + Aliases: schemas.KeyAliases{"gpt-4": "gpt-4-deployment"}, AzureKeyConfig: &schemas.AzureKeyConfig{ Endpoint: *schemas.NewEnvVar("https://myazure.openai.azure.com"), APIVersion: schemas.NewEnvVar("2024-02-01"), - Deployments: map[string]string{ - "gpt-4": "gpt-4-deployment", - }, }, } @@ -4836,12 +4843,10 @@ func TestKeyHashComparison_AzureConfigSyncScenarios(t *testing.T) { Name: "azure-key", Value: *schemas.NewEnvVar("azure-api-key-123"), Weight: 1, + Aliases: schemas.KeyAliases{"gpt-4": "gpt-4-deployment"}, AzureKeyConfig: &schemas.AzureKeyConfig{ Endpoint: *schemas.NewEnvVar("https://myazure.openai.azure.com"), APIVersion: schemas.NewEnvVar("2024-10-21"), // Changed! 
- Deployments: map[string]string{ - "gpt-4": "gpt-4-deployment", - }, }, } @@ -4861,11 +4866,9 @@ func TestKeyHashComparison_AzureConfigSyncScenarios(t *testing.T) { Name: "azure-key", Value: *schemas.NewEnvVar("azure-api-key-123"), Weight: 1, + Aliases: schemas.KeyAliases{"gpt-4": "gpt-4-deployment"}, AzureKeyConfig: &schemas.AzureKeyConfig{ Endpoint: *schemas.NewEnvVar("https://myazure.openai.azure.com"), - Deployments: map[string]string{ - "gpt-4": "gpt-4-deployment", - }, }, } @@ -4874,12 +4877,9 @@ func TestKeyHashComparison_AzureConfigSyncScenarios(t *testing.T) { Name: "azure-key", Value: *schemas.NewEnvVar("azure-api-key-123"), Weight: 1, + Aliases: schemas.KeyAliases{"gpt-4": "gpt-4-deployment", "gpt-3.5-turbo": "gpt-35-turbo-deployment"}, AzureKeyConfig: &schemas.AzureKeyConfig{ Endpoint: *schemas.NewEnvVar("https://myazure.openai.azure.com"), - Deployments: map[string]string{ - "gpt-4": "gpt-4-deployment", - "gpt-3.5-turbo": "gpt-35-turbo-deployment", // Added! - }, }, } @@ -4912,9 +4912,6 @@ func TestKeyHashComparison_AzureConfigSyncScenarios(t *testing.T) { AzureKeyConfig: &schemas.AzureKeyConfig{ Endpoint: *schemas.NewEnvVar("https://myazure.openai.azure.com"), APIVersion: schemas.NewEnvVar("2024-02-01"), - Deployments: map[string]string{ - "gpt-4": "gpt-4-deployment", - }, }, } @@ -4938,9 +4935,6 @@ func TestKeyHashComparison_AzureConfigSyncScenarios(t *testing.T) { AzureKeyConfig: &schemas.AzureKeyConfig{ Endpoint: *schemas.NewEnvVar("https://myazure.openai.azure.com"), APIVersion: schemas.NewEnvVar("2024-02-01"), - Deployments: map[string]string{ - "gpt-4": "gpt-4-deployment", - }, }, } @@ -4969,12 +4963,10 @@ func TestKeyHashComparison_AzureConfigSyncScenarios(t *testing.T) { Name: "azure-key", Value: *schemas.NewEnvVar("azure-api-key-123"), Weight: 1, + Aliases: schemas.KeyAliases{"gpt-4": "gpt-4-deployment"}, AzureKeyConfig: &schemas.AzureKeyConfig{ Endpoint: *schemas.NewEnvVar("https://myazure.openai.azure.com"), // APIVersion is nil (will use 
default) - Deployments: map[string]string{ - "gpt-4": "gpt-4-deployment", - }, }, } @@ -4983,12 +4975,10 @@ func TestKeyHashComparison_AzureConfigSyncScenarios(t *testing.T) { Name: "azure-key", Value: *schemas.NewEnvVar("azure-api-key-123"), Weight: 1, + Aliases: schemas.KeyAliases{"gpt-4": "gpt-4-deployment"}, AzureKeyConfig: &schemas.AzureKeyConfig{ Endpoint: *schemas.NewEnvVar("https://myazure.openai.azure.com"), APIVersion: schemas.NewEnvVar("2024-02-01"), // Explicitly set - Deployments: map[string]string{ - "gpt-4": "gpt-4-deployment", - }, }, } @@ -5011,13 +5001,11 @@ func TestKeyHashComparison_BedrockConfigSyncScenarios(t *testing.T) { Name: "bedrock-key", Value: *schemas.NewEnvVar("bedrock-api-key-123"), Weight: 1, + Aliases: schemas.KeyAliases{"claude-3": "claude-3-inference-profile"}, BedrockKeyConfig: &schemas.BedrockKeyConfig{ AccessKey: *schemas.NewEnvVar("AKIAIOSFODNN7EXAMPLE"), SecretKey: *schemas.NewEnvVar("wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"), Region: schemas.NewEnvVar("us-east-1"), - Deployments: map[string]string{ - "claude-3": "claude-3-inference-profile", - }, }, } @@ -5026,13 +5014,11 @@ func TestKeyHashComparison_BedrockConfigSyncScenarios(t *testing.T) { Name: "bedrock-key", Value: *schemas.NewEnvVar("bedrock-api-key-123"), Weight: 1, + Aliases: schemas.KeyAliases{"claude-3": "claude-3-inference-profile"}, BedrockKeyConfig: &schemas.BedrockKeyConfig{ AccessKey: *schemas.NewEnvVar("AKIAIOSFODNN7EXAMPLE"), SecretKey: *schemas.NewEnvVar("wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"), Region: schemas.NewEnvVar("us-east-1"), - Deployments: map[string]string{ - "claude-3": "claude-3-inference-profile", - }, }, } @@ -5052,13 +5038,11 @@ func TestKeyHashComparison_BedrockConfigSyncScenarios(t *testing.T) { Name: "bedrock-key", Value: *schemas.NewEnvVar("bedrock-api-key-123"), Weight: 1, + Aliases: schemas.KeyAliases{"claude-3": "claude-3-inference-profile"}, BedrockKeyConfig: &schemas.BedrockKeyConfig{ AccessKey: 
*schemas.NewEnvVar("AKIAIOSFODNN7EXAMPLE"), SecretKey: *schemas.NewEnvVar("wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"), Region: schemas.NewEnvVar("us-east-1"), - Deployments: map[string]string{ - "claude-3": "claude-3-inference-profile", - }, }, } @@ -5067,13 +5051,11 @@ func TestKeyHashComparison_BedrockConfigSyncScenarios(t *testing.T) { Name: "bedrock-key", Value: *schemas.NewEnvVar("bedrock-api-key-123"), Weight: 1, + Aliases: schemas.KeyAliases{"claude-3": "claude-3-inference-profile"}, BedrockKeyConfig: &schemas.BedrockKeyConfig{ AccessKey: *schemas.NewEnvVar("AKIAI44QH8DHBEXAMPLE"), // Changed! SecretKey: *schemas.NewEnvVar("wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"), Region: schemas.NewEnvVar("us-east-1"), - Deployments: map[string]string{ - "claude-3": "claude-3-inference-profile", - }, }, } @@ -5093,13 +5075,11 @@ func TestKeyHashComparison_BedrockConfigSyncScenarios(t *testing.T) { Name: "bedrock-key", Value: *schemas.NewEnvVar("bedrock-api-key-123"), Weight: 1, + Aliases: schemas.KeyAliases{"claude-3": "claude-3-inference-profile"}, BedrockKeyConfig: &schemas.BedrockKeyConfig{ AccessKey: *schemas.NewEnvVar("AKIAIOSFODNN7EXAMPLE"), SecretKey: *schemas.NewEnvVar("wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"), Region: schemas.NewEnvVar("us-east-1"), - Deployments: map[string]string{ - "claude-3": "claude-3-inference-profile", - }, }, } @@ -5108,13 +5088,11 @@ func TestKeyHashComparison_BedrockConfigSyncScenarios(t *testing.T) { Name: "bedrock-key", Value: *schemas.NewEnvVar("bedrock-api-key-123"), Weight: 1, + Aliases: schemas.KeyAliases{"claude-3": "claude-3-inference-profile"}, BedrockKeyConfig: &schemas.BedrockKeyConfig{ AccessKey: *schemas.NewEnvVar("AKIAIOSFODNN7EXAMPLE"), SecretKey: *schemas.NewEnvVar("differentSecretKey/NEWKEY/bPxRfiCYEXAMPLEKEY"), // Changed! 
Region: schemas.NewEnvVar("us-east-1"), - Deployments: map[string]string{ - "claude-3": "claude-3-inference-profile", - }, }, } @@ -5134,13 +5112,11 @@ func TestKeyHashComparison_BedrockConfigSyncScenarios(t *testing.T) { Name: "bedrock-key", Value: *schemas.NewEnvVar("bedrock-api-key-123"), Weight: 1, + Aliases: schemas.KeyAliases{"claude-3": "claude-3-inference-profile"}, BedrockKeyConfig: &schemas.BedrockKeyConfig{ AccessKey: *schemas.NewEnvVar("AKIAIOSFODNN7EXAMPLE"), SecretKey: *schemas.NewEnvVar("wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"), Region: schemas.NewEnvVar("us-east-1"), - Deployments: map[string]string{ - "claude-3": "claude-3-inference-profile", - }, }, } @@ -5149,13 +5125,11 @@ func TestKeyHashComparison_BedrockConfigSyncScenarios(t *testing.T) { Name: "bedrock-key", Value: *schemas.NewEnvVar("bedrock-api-key-123"), Weight: 1, + Aliases: schemas.KeyAliases{"claude-3": "claude-3-inference-profile"}, BedrockKeyConfig: &schemas.BedrockKeyConfig{ AccessKey: *schemas.NewEnvVar("AKIAIOSFODNN7EXAMPLE"), SecretKey: *schemas.NewEnvVar("wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"), Region: schemas.NewEnvVar("us-west-2"), // Changed! 
- Deployments: map[string]string{ - "claude-3": "claude-3-inference-profile", - }, }, } @@ -5175,14 +5149,12 @@ func TestKeyHashComparison_BedrockConfigSyncScenarios(t *testing.T) { Name: "bedrock-key", Value: *schemas.NewEnvVar("bedrock-api-key-123"), Weight: 1, + Aliases: schemas.KeyAliases{"claude-3": "claude-3-inference-profile"}, BedrockKeyConfig: &schemas.BedrockKeyConfig{ AccessKey: *schemas.NewEnvVar("AKIAIOSFODNN7EXAMPLE"), SecretKey: *schemas.NewEnvVar("wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"), Region: schemas.NewEnvVar("us-east-1"), ARN: schemas.NewEnvVar("arn:aws:bedrock:us-east-1:123456789012:inference-profile/old-profile"), - Deployments: map[string]string{ - "claude-3": "claude-3-inference-profile", - }, }, } @@ -5191,14 +5163,12 @@ func TestKeyHashComparison_BedrockConfigSyncScenarios(t *testing.T) { Name: "bedrock-key", Value: *schemas.NewEnvVar("bedrock-api-key-123"), Weight: 1, + Aliases: schemas.KeyAliases{"claude-3": "claude-3-inference-profile"}, BedrockKeyConfig: &schemas.BedrockKeyConfig{ AccessKey: *schemas.NewEnvVar("AKIAIOSFODNN7EXAMPLE"), SecretKey: *schemas.NewEnvVar("wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"), Region: schemas.NewEnvVar("us-east-1"), ARN: schemas.NewEnvVar("arn:aws:bedrock:us-east-1:123456789012:inference-profile/new-profile"), // Changed! 
- Deployments: map[string]string{ - "claude-3": "claude-3-inference-profile", - }, }, } @@ -5218,13 +5188,11 @@ func TestKeyHashComparison_BedrockConfigSyncScenarios(t *testing.T) { Name: "bedrock-key", Value: *schemas.NewEnvVar("bedrock-api-key-123"), Weight: 1, + Aliases: schemas.KeyAliases{"claude-3": "claude-3-inference-profile"}, BedrockKeyConfig: &schemas.BedrockKeyConfig{ AccessKey: *schemas.NewEnvVar("AKIAIOSFODNN7EXAMPLE"), SecretKey: *schemas.NewEnvVar("wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"), Region: schemas.NewEnvVar("us-east-1"), - Deployments: map[string]string{ - "claude-3": "claude-3-inference-profile", - }, }, } @@ -5233,14 +5201,11 @@ func TestKeyHashComparison_BedrockConfigSyncScenarios(t *testing.T) { Name: "bedrock-key", Value: *schemas.NewEnvVar("bedrock-api-key-123"), Weight: 1, + Aliases: schemas.KeyAliases{"claude-3": "claude-3-inference-profile", "claude-3.5": "claude-35-inference-profile"}, BedrockKeyConfig: &schemas.BedrockKeyConfig{ AccessKey: *schemas.NewEnvVar("AKIAIOSFODNN7EXAMPLE"), SecretKey: *schemas.NewEnvVar("wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"), Region: schemas.NewEnvVar("us-east-1"), - Deployments: map[string]string{ - "claude-3": "claude-3-inference-profile", - "claude-3.5": "claude-35-inference-profile", // Added! 
- }, }, } @@ -5274,9 +5239,6 @@ func TestKeyHashComparison_BedrockConfigSyncScenarios(t *testing.T) { AccessKey: *schemas.NewEnvVar("AKIAIOSFODNN7EXAMPLE"), SecretKey: *schemas.NewEnvVar("wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"), Region: schemas.NewEnvVar("us-east-1"), - Deployments: map[string]string{ - "claude-3": "claude-3-inference-profile", - }, }, } @@ -5301,9 +5263,6 @@ func TestKeyHashComparison_BedrockConfigSyncScenarios(t *testing.T) { AccessKey: *schemas.NewEnvVar("AKIAIOSFODNN7EXAMPLE"), SecretKey: *schemas.NewEnvVar("wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"), Region: schemas.NewEnvVar("us-east-1"), - Deployments: map[string]string{ - "claude-3": "claude-3-inference-profile", - }, }, } @@ -5332,14 +5291,12 @@ func TestKeyHashComparison_BedrockConfigSyncScenarios(t *testing.T) { Name: "bedrock-key", Value: *schemas.NewEnvVar("bedrock-api-key-123"), Weight: 1, + Aliases: schemas.KeyAliases{"claude-3": "claude-3-inference-profile"}, BedrockKeyConfig: &schemas.BedrockKeyConfig{ AccessKey: *schemas.NewEnvVar("AKIAIOSFODNN7EXAMPLE"), SecretKey: *schemas.NewEnvVar("wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"), Region: schemas.NewEnvVar("us-east-1"), // SessionToken is nil - Deployments: map[string]string{ - "claude-3": "claude-3-inference-profile", - }, }, } @@ -5348,14 +5305,12 @@ func TestKeyHashComparison_BedrockConfigSyncScenarios(t *testing.T) { Name: "bedrock-key", Value: *schemas.NewEnvVar("bedrock-api-key-123"), Weight: 1, + Aliases: schemas.KeyAliases{"claude-3": "claude-3-inference-profile"}, BedrockKeyConfig: &schemas.BedrockKeyConfig{ AccessKey: *schemas.NewEnvVar("AKIAIOSFODNN7EXAMPLE"), SecretKey: *schemas.NewEnvVar("wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"), Region: schemas.NewEnvVar("us-east-1"), SessionToken: schemas.NewEnvVar("AQoDYXdzEJr..."), // Explicitly set - Deployments: map[string]string{ - "claude-3": "claude-3-inference-profile", - }, }, } @@ -5376,13 +5331,11 @@ func TestKeyHashComparison_BedrockConfigSyncScenarios(t *testing.T) { 
Name: "bedrock-key", Value: *schemas.NewEnvVar(""), Weight: 1, + Aliases: schemas.KeyAliases{"claude-3": "claude-3-inference-profile"}, BedrockKeyConfig: &schemas.BedrockKeyConfig{ AccessKey: *schemas.NewEnvVar(""), // Empty for IAM role auth SecretKey: *schemas.NewEnvVar(""), // Empty for IAM role auth Region: schemas.NewEnvVar("us-east-1"), - Deployments: map[string]string{ - "claude-3": "claude-3-inference-profile", - }, }, } @@ -5392,13 +5345,11 @@ func TestKeyHashComparison_BedrockConfigSyncScenarios(t *testing.T) { Name: "bedrock-key", Value: *schemas.NewEnvVar(""), Weight: 1, + Aliases: schemas.KeyAliases{"claude-3": "claude-3-inference-profile"}, BedrockKeyConfig: &schemas.BedrockKeyConfig{ AccessKey: *schemas.NewEnvVar("AKIAIOSFODNN7EXAMPLE"), SecretKey: *schemas.NewEnvVar("wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"), Region: schemas.NewEnvVar("us-east-1"), - Deployments: map[string]string{ - "claude-3": "claude-3-inference-profile", - }, }, } @@ -5420,12 +5371,10 @@ func TestProviderHashComparison_AzureProviderFullLifecycle(t *testing.T) { Name: "azure-openai-key", Value: *schemas.NewEnvVar("azure-api-key-initial"), Weight: 1, + Aliases: schemas.KeyAliases{"gpt-4": "gpt-4-deployment"}, AzureKeyConfig: &schemas.AzureKeyConfig{ Endpoint: *schemas.NewEnvVar("https://myazure.openai.azure.com"), APIVersion: schemas.NewEnvVar("2024-02-01"), - Deployments: map[string]string{ - "gpt-4": "gpt-4-deployment", - }, }, } @@ -5457,12 +5406,10 @@ func TestProviderHashComparison_AzureProviderFullLifecycle(t *testing.T) { Name: "azure-openai-key", Value: *schemas.NewEnvVar("azure-api-key-dashboard-edited"), // Changed via dashboard! 
Weight: 1, + Aliases: schemas.KeyAliases{"gpt-4": "gpt-4-deployment"}, AzureKeyConfig: &schemas.AzureKeyConfig{ Endpoint: *schemas.NewEnvVar("https://myazure.openai.azure.com"), APIVersion: schemas.NewEnvVar("2024-02-01"), - Deployments: map[string]string{ - "gpt-4": "gpt-4-deployment", - }, }, } @@ -5492,12 +5439,10 @@ func TestProviderHashComparison_AzureProviderFullLifecycle(t *testing.T) { Name: "azure-openai-key", Value: *schemas.NewEnvVar("azure-api-key-initial"), // Original value from file Weight: 1, + Aliases: schemas.KeyAliases{"gpt-4": "gpt-4-deployment"}, AzureKeyConfig: &schemas.AzureKeyConfig{ Endpoint: *schemas.NewEnvVar("https://myazure.openai.azure.com"), APIVersion: schemas.NewEnvVar("2024-02-01"), - Deployments: map[string]string{ - "gpt-4": "gpt-4-deployment", - }, }, }, }, @@ -5531,13 +5476,10 @@ func TestProviderHashComparison_AzureProviderFullLifecycle(t *testing.T) { Name: "azure-openai-key", Value: *schemas.NewEnvVar("azure-api-key-initial"), Weight: 1, + Aliases: schemas.KeyAliases{"gpt-4": "gpt-4-deployment", "gpt-4o": "gpt-4o-deployment"}, AzureKeyConfig: &schemas.AzureKeyConfig{ Endpoint: *schemas.NewEnvVar("https://new-azure.openai.azure.com"), // Changed! APIVersion: schemas.NewEnvVar("2024-10-21"), // Changed! - Deployments: map[string]string{ - "gpt-4": "gpt-4-deployment", - "gpt-4o": "gpt-4o-deployment", // Added! 
- }, }, }, }, @@ -5629,8 +5571,8 @@ func TestProviderHashComparison_AzureProviderFullLifecycle(t *testing.T) { if finalConfig.Keys[0].AzureKeyConfig.APIVersion.GetValue() != "2024-10-21" { t.Errorf("Expected updated APIVersion, got %s", finalConfig.Keys[0].AzureKeyConfig.APIVersion.GetValue()) } - if len(finalConfig.Keys[0].AzureKeyConfig.Deployments) != 2 { - t.Errorf("Expected 2 deployments, got %d", len(finalConfig.Keys[0].AzureKeyConfig.Deployments)) + if len(finalConfig.Keys[0].Aliases) != 2 { + t.Errorf("Expected 2 deployments, got %d", len(finalConfig.Keys[0].Aliases)) } t.Log("Step 5 - Final state verified, Azure provider lifecycle complete ✓") @@ -5644,13 +5586,11 @@ func TestProviderHashComparison_BedrockProviderFullLifecycle(t *testing.T) { Name: "aws-bedrock-key", Value: *schemas.NewEnvVar(""), // Empty for Bedrock with IAM or AccessKey auth Weight: 1, + Aliases: schemas.KeyAliases{"claude-3-sonnet": "anthropic.claude-3-sonnet-20240229-v1:0"}, BedrockKeyConfig: &schemas.BedrockKeyConfig{ AccessKey: *schemas.NewEnvVar("AKIAIOSFODNN7EXAMPLE"), SecretKey: *schemas.NewEnvVar("wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"), Region: schemas.NewEnvVar("us-east-1"), - Deployments: map[string]string{ - "claude-3-sonnet": "anthropic.claude-3-sonnet-20240229-v1:0", - }, }, } @@ -5681,13 +5621,11 @@ func TestProviderHashComparison_BedrockProviderFullLifecycle(t *testing.T) { Name: "aws-bedrock-key-eu", Value: *schemas.NewEnvVar(""), Weight: 1, + Aliases: schemas.KeyAliases{"claude-3-sonnet": "anthropic.claude-3-sonnet-20240229-v1:0"}, BedrockKeyConfig: &schemas.BedrockKeyConfig{ AccessKey: *schemas.NewEnvVar("AKIAI44QH8DHBEXAMPLE"), SecretKey: *schemas.NewEnvVar("je7MtGbClwBF/2Zp9Utk/h3yCo8nvbEXAMPLEKEY"), Region: schemas.NewEnvVar("eu-west-1"), // Different region - Deployments: map[string]string{ - "claude-3-sonnet": "anthropic.claude-3-sonnet-20240229-v1:0", - }, }, } @@ -5710,13 +5648,11 @@ func TestProviderHashComparison_BedrockProviderFullLifecycle(t *testing.T) 
{ Name: "aws-bedrock-key", Value: *schemas.NewEnvVar(""), Weight: 1, + Aliases: schemas.KeyAliases{"claude-3-sonnet": "anthropic.claude-3-sonnet-20240229-v1:0"}, BedrockKeyConfig: &schemas.BedrockKeyConfig{ AccessKey: *schemas.NewEnvVar("AKIAIOSFODNN7EXAMPLE"), SecretKey: *schemas.NewEnvVar("wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"), Region: schemas.NewEnvVar("us-east-1"), - Deployments: map[string]string{ - "claude-3-sonnet": "anthropic.claude-3-sonnet-20240229-v1:0", - }, }, }, }, @@ -5751,15 +5687,12 @@ func TestProviderHashComparison_BedrockProviderFullLifecycle(t *testing.T) { Name: "aws-bedrock-key", Value: *schemas.NewEnvVar(""), Weight: 1, + Aliases: schemas.KeyAliases{"claude-3-sonnet": "anthropic.claude-3-sonnet-20240229-v1:0", "claude-3-opus": "anthropic.claude-3-opus-20240229-v1:0"}, BedrockKeyConfig: &schemas.BedrockKeyConfig{ AccessKey: *schemas.NewEnvVar("AKIAIOSFODNN7EXAMPLE"), SecretKey: *schemas.NewEnvVar("wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"), Region: schemas.NewEnvVar("us-west-2"), // Changed! ARN: schemas.NewEnvVar("arn:aws:bedrock:us-west-2:123456789012:inference-profile/my-profile"), // Added! - Deployments: map[string]string{ - "claude-3-sonnet": "anthropic.claude-3-sonnet-20240229-v1:0", - "claude-3-opus": "anthropic.claude-3-opus-20240229-v1:0", // Added! 
- }, }, }, }, @@ -5863,8 +5796,8 @@ func TestProviderHashComparison_BedrockProviderFullLifecycle(t *testing.T) { if fileKey.BedrockKeyConfig.ARN == nil || fileKey.BedrockKeyConfig.ARN.GetValue() != "arn:aws:bedrock:us-west-2:123456789012:inference-profile/my-profile" { t.Error("Expected ARN to be set") } - if len(fileKey.BedrockKeyConfig.Deployments) != 2 { - t.Errorf("Expected 2 deployments, got %d", len(fileKey.BedrockKeyConfig.Deployments)) + if len(fileKey.Aliases) != 2 { + t.Errorf("Expected 2 deployments, got %d", len(fileKey.Aliases)) } // Verify dashboard-added key is preserved @@ -5882,15 +5815,12 @@ func TestProviderHashComparison_BedrockProviderFullLifecycle(t *testing.T) { Name: "aws-bedrock-key", Value: *schemas.NewEnvVar(""), Weight: 1, + Aliases: schemas.KeyAliases{"claude-3-sonnet": "anthropic.claude-3-sonnet-20240229-v1:0", "claude-3-opus": "anthropic.claude-3-opus-20240229-v1:0"}, BedrockKeyConfig: &schemas.BedrockKeyConfig{ AccessKey: *schemas.NewEnvVar("AKIAIOSFODNN7EXAMPLE"), SecretKey: *schemas.NewEnvVar("wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"), Region: schemas.NewEnvVar("us-west-2"), ARN: schemas.NewEnvVar("arn:aws:bedrock:us-west-2:123456789012:inference-profile/my-profile"), - Deployments: map[string]string{ - "claude-3-sonnet": "anthropic.claude-3-sonnet-20240229-v1:0", - "claude-3-opus": "anthropic.claude-3-opus-20240229-v1:0", - }, }, }, }, @@ -5928,12 +5858,10 @@ func TestProviderHashComparison_AzureNewProviderFromConfig(t *testing.T) { Name: "azure-openai-key", Value: *schemas.NewEnvVar("azure-api-key-123"), Weight: 1, + Aliases: schemas.KeyAliases{"gpt-4": "gpt-4-deployment"}, AzureKeyConfig: &schemas.AzureKeyConfig{ Endpoint: *schemas.NewEnvVar("https://myazure.openai.azure.com"), APIVersion: schemas.NewEnvVar("2024-02-01"), - Deployments: map[string]string{ - "gpt-4": "gpt-4-deployment", - }, }, }, }, @@ -5997,13 +5925,11 @@ func TestProviderHashComparison_BedrockNewProviderFromConfig(t *testing.T) { Name: "aws-bedrock-key", 
Value: *schemas.NewEnvVar(""), Weight: 1, + Aliases: schemas.KeyAliases{"claude-3": "anthropic.claude-3-sonnet-20240229-v1:0"}, BedrockKeyConfig: &schemas.BedrockKeyConfig{ AccessKey: *schemas.NewEnvVar("AKIAIOSFODNN7EXAMPLE"), SecretKey: *schemas.NewEnvVar("wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"), Region: schemas.NewEnvVar("us-east-1"), - Deployments: map[string]string{ - "claude-3": "anthropic.claude-3-sonnet-20240229-v1:0", - }, }, }, }, @@ -6068,12 +5994,10 @@ func TestProviderHashComparison_AzureDBValuePreservedWhenHashMatches(t *testing. Name: "azure-openai-key", Value: *schemas.NewEnvVar("DASHBOARD-EDITED-SECRET-KEY"), // Dashboard edited this! Weight: 1, + Aliases: schemas.KeyAliases{"gpt-4": "gpt-4-deployment"}, AzureKeyConfig: &schemas.AzureKeyConfig{ Endpoint: *schemas.NewEnvVar("https://myazure.openai.azure.com"), APIVersion: schemas.NewEnvVar("2024-02-01"), - Deployments: map[string]string{ - "gpt-4": "gpt-4-deployment", - }, }, }, }, @@ -6099,12 +6023,10 @@ func TestProviderHashComparison_AzureDBValuePreservedWhenHashMatches(t *testing. Name: "azure-openai-key", Value: *schemas.NewEnvVar("original-key-from-file"), // Different value than DB! Weight: 1, + Aliases: schemas.KeyAliases{"gpt-4": "gpt-4-deployment"}, AzureKeyConfig: &schemas.AzureKeyConfig{ Endpoint: *schemas.NewEnvVar("https://myazure.openai.azure.com"), // Same APIVersion: schemas.NewEnvVar("2024-02-01"), // Same - Deployments: map[string]string{ - "gpt-4": "gpt-4-deployment", // Same - }, }, }, }, @@ -6158,13 +6080,11 @@ func TestProviderHashComparison_BedrockDBValuePreservedWhenHashMatches(t *testin Name: "aws-bedrock-key", Value: *schemas.NewEnvVar(""), Weight: 1, + Aliases: schemas.KeyAliases{"claude-3": "anthropic.claude-3-sonnet-20240229-v1:0"}, BedrockKeyConfig: &schemas.BedrockKeyConfig{ AccessKey: *schemas.NewEnvVar("DASHBOARD-EDITED-ACCESS-KEY"), // Dashboard edited! SecretKey: *schemas.NewEnvVar("DASHBOARD-EDITED-SECRET-KEY"), // Dashboard edited! 
Region: schemas.NewEnvVar("us-east-1"), - Deployments: map[string]string{ - "claude-3": "anthropic.claude-3-sonnet-20240229-v1:0", - }, }, }, }, @@ -6190,13 +6110,11 @@ func TestProviderHashComparison_BedrockDBValuePreservedWhenHashMatches(t *testin Name: "aws-bedrock-key", Value: *schemas.NewEnvVar(""), Weight: 1, + Aliases: schemas.KeyAliases{"claude-3": "anthropic.claude-3-sonnet-20240229-v1:0"}, BedrockKeyConfig: &schemas.BedrockKeyConfig{ AccessKey: *schemas.NewEnvVar("AKIAIOSFODNN7EXAMPLE"), // Different! SecretKey: *schemas.NewEnvVar("wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"), // Different! Region: schemas.NewEnvVar("us-east-1"), // Same - Deployments: map[string]string{ - "claude-3": "anthropic.claude-3-sonnet-20240229-v1:0", // Same - }, }, }, }, @@ -6281,12 +6199,10 @@ func TestProviderHashComparison_AzureConfigChangedInFile(t *testing.T) { Name: "azure-openai-key", Value: *schemas.NewEnvVar("azure-api-key-123"), Weight: 1, + Aliases: schemas.KeyAliases{"gpt-4o": "gpt-4o-deployment"}, AzureKeyConfig: &schemas.AzureKeyConfig{ Endpoint: *schemas.NewEnvVar("https://NEW-azure.openai.azure.com"), // Changed! APIVersion: schemas.NewEnvVar("2024-10-21"), // Changed! - Deployments: map[string]string{ - "gpt-4o": "gpt-4o-deployment", // Added! - }, }, }, }, @@ -6371,14 +6287,12 @@ func TestProviderHashComparison_BedrockConfigChangedInFile(t *testing.T) { Name: "aws-bedrock-key", Value: *schemas.NewEnvVar(""), Weight: 1, + Aliases: schemas.KeyAliases{"claude-3-opus": "anthropic.claude-3-opus-20240229-v1:0"}, BedrockKeyConfig: &schemas.BedrockKeyConfig{ AccessKey: *schemas.NewEnvVar("AKIAIOSFODNN7EXAMPLE"), SecretKey: *schemas.NewEnvVar("wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"), Region: schemas.NewEnvVar("us-west-2"), // Changed! ARN: schemas.NewEnvVar("arn:aws:bedrock:us-west-2:123456789012:inference-profile/new-profile"), // Added! - Deployments: map[string]string{ - "claude-3-opus": "anthropic.claude-3-opus-20240229-v1:0", // Added! 
- }, }, }, }, @@ -13349,9 +13263,8 @@ func TestGenerateKeyHash_RuntimeVsMigrationParity(t *testing.T) { t.Run("AzureKeyConfig_GORMRoundTrip", func(t *testing.T) { apiVersion := "2024-02-01" azureConfig := &schemas.AzureKeyConfig{ - Endpoint: *schemas.NewEnvVar("https://myresource.openai.azure.com"), - APIVersion: schemas.NewEnvVar(apiVersion), - Deployments: map[string]string{"gpt-4": "gpt-4-deployment"}, + Endpoint: *schemas.NewEnvVar("https://myresource.openai.azure.com"), + APIVersion: schemas.NewEnvVar(apiVersion), } keyToSave := tables.TableKey{ @@ -13362,6 +13275,7 @@ func TestGenerateKeyHash_RuntimeVsMigrationParity(t *testing.T) { Value: *schemas.NewEnvVar("azure-key-value"), Weight: ptrFloat64(1.0), AzureKeyConfig: azureConfig, + Aliases: schemas.KeyAliases{"gpt-4": "gpt-4-deployment"}, } schemaKey := schemas.Key{ @@ -13369,6 +13283,7 @@ func TestGenerateKeyHash_RuntimeVsMigrationParity(t *testing.T) { Value: keyToSave.Value, Weight: getWeight(keyToSave.Weight), AzureKeyConfig: keyToSave.AzureKeyConfig, + Aliases: keyToSave.Aliases, } hashBeforeSave, _ := configstore.GenerateKeyHash(schemaKey) @@ -13382,6 +13297,7 @@ func TestGenerateKeyHash_RuntimeVsMigrationParity(t *testing.T) { Value: keyFromDB.Value, Weight: getWeight(keyFromDB.Weight), AzureKeyConfig: keyFromDB.AzureKeyConfig, + Aliases: keyFromDB.Aliases, } hashAfterLoad, _ := configstore.GenerateKeyHash(schemaKeyFromDB) @@ -14466,14 +14382,12 @@ func TestKeyHashComparison_VertexConfigSyncScenarios(t *testing.T) { Name: "vertex-key", Value: *schemas.NewEnvVar("vertex-api-key-123"), Weight: 1, + Aliases: schemas.KeyAliases{"gemini-pro": "gemini-pro-endpoint"}, VertexKeyConfig: &schemas.VertexKeyConfig{ ProjectID: *schemas.NewEnvVar("my-project-123"), ProjectNumber: *schemas.NewEnvVar("123456789"), Region: *schemas.NewEnvVar("us-central1"), AuthCredentials: *schemas.NewEnvVar(`{"type":"service_account"}`), - Deployments: map[string]string{ - "gemini-pro": "gemini-pro-endpoint", - }, }, } @@ -14482,14 
+14396,12 @@ func TestKeyHashComparison_VertexConfigSyncScenarios(t *testing.T) { Name: "vertex-key", Value: *schemas.NewEnvVar("vertex-api-key-123"), Weight: 1, + Aliases: schemas.KeyAliases{"gemini-pro": "gemini-pro-endpoint"}, VertexKeyConfig: &schemas.VertexKeyConfig{ ProjectID: *schemas.NewEnvVar("my-project-123"), ProjectNumber: *schemas.NewEnvVar("123456789"), Region: *schemas.NewEnvVar("us-central1"), AuthCredentials: *schemas.NewEnvVar(`{"type":"service_account"}`), - Deployments: map[string]string{ - "gemini-pro": "gemini-pro-endpoint", - }, }, } @@ -14614,12 +14526,10 @@ func TestKeyHashComparison_VertexConfigSyncScenarios(t *testing.T) { Name: "vertex-key", Value: *schemas.NewEnvVar("vertex-api-key-123"), Weight: 1, + Aliases: schemas.KeyAliases{"gemini-pro": "gemini-pro-endpoint"}, VertexKeyConfig: &schemas.VertexKeyConfig{ ProjectID: *schemas.NewEnvVar("my-project-123"), Region: *schemas.NewEnvVar("us-central1"), - Deployments: map[string]string{ - "gemini-pro": "gemini-pro-endpoint", - }, }, } @@ -14628,13 +14538,10 @@ func TestKeyHashComparison_VertexConfigSyncScenarios(t *testing.T) { Name: "vertex-key", Value: *schemas.NewEnvVar("vertex-api-key-123"), Weight: 1, + Aliases: schemas.KeyAliases{"gemini-pro": "gemini-pro-endpoint", "gemini-1.5-pro": "gemini-15-pro-endpoint"}, VertexKeyConfig: &schemas.VertexKeyConfig{ ProjectID: *schemas.NewEnvVar("my-project-123"), Region: *schemas.NewEnvVar("us-central1"), - Deployments: map[string]string{ - "gemini-pro": "gemini-pro-endpoint", - "gemini-1.5-pro": "gemini-15-pro-endpoint", // Added! 
- }, }, } @@ -15018,11 +14925,9 @@ func TestKeyHashComparison_AzureDeploymentsChange(t *testing.T) { Name: "azure-key", Value: *schemas.NewEnvVar("azure-api-key"), Weight: 1, + Aliases: schemas.KeyAliases{"gpt-4": "gpt-4-deployment"}, AzureKeyConfig: &schemas.AzureKeyConfig{ Endpoint: *schemas.NewEnvVar("https://myazure.openai.azure.com"), - Deployments: map[string]string{ - "gpt-4": "gpt-4-deployment", - }, }, } @@ -15031,12 +14936,9 @@ func TestKeyHashComparison_AzureDeploymentsChange(t *testing.T) { Name: "azure-key", Value: *schemas.NewEnvVar("azure-api-key"), Weight: 1, + Aliases: schemas.KeyAliases{"gpt-4": "gpt-4-deployment", "gpt-4o": "gpt-4o-deployment"}, AzureKeyConfig: &schemas.AzureKeyConfig{ Endpoint: *schemas.NewEnvVar("https://myazure.openai.azure.com"), - Deployments: map[string]string{ - "gpt-4": "gpt-4-deployment", - "gpt-4o": "gpt-4o-deployment", // Added - }, }, } @@ -15054,12 +14956,9 @@ func TestKeyHashComparison_AzureDeploymentsChange(t *testing.T) { Name: "azure-key", Value: *schemas.NewEnvVar("azure-api-key"), Weight: 1, + Aliases: schemas.KeyAliases{"gpt-4": "gpt-4-deployment", "gpt-4o": "gpt-4o-deployment"}, AzureKeyConfig: &schemas.AzureKeyConfig{ Endpoint: *schemas.NewEnvVar("https://myazure.openai.azure.com"), - Deployments: map[string]string{ - "gpt-4": "gpt-4-deployment", - "gpt-4o": "gpt-4o-deployment", - }, }, } @@ -15068,11 +14967,9 @@ func TestKeyHashComparison_AzureDeploymentsChange(t *testing.T) { Name: "azure-key", Value: *schemas.NewEnvVar("azure-api-key"), Weight: 1, + Aliases: schemas.KeyAliases{"gpt-4": "gpt-4-deployment"}, AzureKeyConfig: &schemas.AzureKeyConfig{ Endpoint: *schemas.NewEnvVar("https://myazure.openai.azure.com"), - Deployments: map[string]string{ - "gpt-4": "gpt-4-deployment", // gpt-4o removed - }, }, } @@ -15090,11 +14987,9 @@ func TestKeyHashComparison_AzureDeploymentsChange(t *testing.T) { Name: "azure-key", Value: *schemas.NewEnvVar("azure-api-key"), Weight: 1, + Aliases: schemas.KeyAliases{"gpt-4": 
"gpt-4-deployment-v1"}, AzureKeyConfig: &schemas.AzureKeyConfig{ Endpoint: *schemas.NewEnvVar("https://myazure.openai.azure.com"), - Deployments: map[string]string{ - "gpt-4": "gpt-4-deployment-v1", - }, }, } @@ -15103,11 +14998,9 @@ func TestKeyHashComparison_AzureDeploymentsChange(t *testing.T) { Name: "azure-key", Value: *schemas.NewEnvVar("azure-api-key"), Weight: 1, + Aliases: schemas.KeyAliases{"gpt-4": "gpt-4-deployment-v2"}, AzureKeyConfig: &schemas.AzureKeyConfig{ Endpoint: *schemas.NewEnvVar("https://myazure.openai.azure.com"), - Deployments: map[string]string{ - "gpt-4": "gpt-4-deployment-v2", // Value changed - }, }, } @@ -15126,8 +15019,7 @@ func TestKeyHashComparison_AzureDeploymentsChange(t *testing.T) { Value: *schemas.NewEnvVar("azure-api-key"), Weight: 1, AzureKeyConfig: &schemas.AzureKeyConfig{ - Endpoint: *schemas.NewEnvVar("https://myazure.openai.azure.com"), - Deployments: nil, // No deployments + Endpoint: *schemas.NewEnvVar("https://myazure.openai.azure.com"), }, } @@ -15136,11 +15028,9 @@ func TestKeyHashComparison_AzureDeploymentsChange(t *testing.T) { Name: "azure-key", Value: *schemas.NewEnvVar("azure-api-key"), Weight: 1, + Aliases: schemas.KeyAliases{"gpt-4": "gpt-4-deployment"}, AzureKeyConfig: &schemas.AzureKeyConfig{ Endpoint: *schemas.NewEnvVar("https://myazure.openai.azure.com"), - Deployments: map[string]string{ - "gpt-4": "gpt-4-deployment", - }, }, } @@ -15161,13 +15051,11 @@ func TestKeyHashComparison_BedrockDeploymentsChange(t *testing.T) { Name: "bedrock-key", Value: *schemas.NewEnvVar("bedrock-key"), Weight: 1, + Aliases: schemas.KeyAliases{"claude-3": "arn:aws:bedrock:us-east-1::inference-profile/claude-3"}, BedrockKeyConfig: &schemas.BedrockKeyConfig{ AccessKey: *schemas.NewEnvVar("AKIAIOSFODNN7EXAMPLE"), SecretKey: *schemas.NewEnvVar("wJalrXUtnFEMI"), Region: schemas.NewEnvVar("us-east-1"), - Deployments: map[string]string{ - "claude-3": "arn:aws:bedrock:us-east-1::inference-profile/claude-3", - }, }, } @@ -15176,14 
+15064,11 @@ func TestKeyHashComparison_BedrockDeploymentsChange(t *testing.T) { Name: "bedrock-key", Value: *schemas.NewEnvVar("bedrock-key"), Weight: 1, + Aliases: schemas.KeyAliases{"claude-3": "arn:aws:bedrock:us-east-1::inference-profile/claude-3", "claude-3.5": "arn:aws:bedrock:us-east-1::inference-profile/claude-3.5"}, BedrockKeyConfig: &schemas.BedrockKeyConfig{ AccessKey: *schemas.NewEnvVar("AKIAIOSFODNN7EXAMPLE"), SecretKey: *schemas.NewEnvVar("wJalrXUtnFEMI"), Region: schemas.NewEnvVar("us-east-1"), - Deployments: map[string]string{ - "claude-3": "arn:aws:bedrock:us-east-1::inference-profile/claude-3", - "claude-3.5": "arn:aws:bedrock:us-east-1::inference-profile/claude-3.5", // Added - }, }, } @@ -15201,14 +15086,11 @@ func TestKeyHashComparison_BedrockDeploymentsChange(t *testing.T) { Name: "bedrock-key", Value: *schemas.NewEnvVar("bedrock-key"), Weight: 1, + Aliases: schemas.KeyAliases{"claude-3": "arn:aws:bedrock:us-east-1::inference-profile/claude-3", "claude-3.5": "arn:aws:bedrock:us-east-1::inference-profile/claude-3.5"}, BedrockKeyConfig: &schemas.BedrockKeyConfig{ AccessKey: *schemas.NewEnvVar("AKIAIOSFODNN7EXAMPLE"), SecretKey: *schemas.NewEnvVar("wJalrXUtnFEMI"), Region: schemas.NewEnvVar("us-east-1"), - Deployments: map[string]string{ - "claude-3": "arn:aws:bedrock:us-east-1::inference-profile/claude-3", - "claude-3.5": "arn:aws:bedrock:us-east-1::inference-profile/claude-3.5", - }, }, } @@ -15217,13 +15099,11 @@ func TestKeyHashComparison_BedrockDeploymentsChange(t *testing.T) { Name: "bedrock-key", Value: *schemas.NewEnvVar("bedrock-key"), Weight: 1, + Aliases: schemas.KeyAliases{"claude-3": "arn:aws:bedrock:us-east-1::inference-profile/claude-3"}, BedrockKeyConfig: &schemas.BedrockKeyConfig{ AccessKey: *schemas.NewEnvVar("AKIAIOSFODNN7EXAMPLE"), SecretKey: *schemas.NewEnvVar("wJalrXUtnFEMI"), Region: schemas.NewEnvVar("us-east-1"), - Deployments: map[string]string{ - "claude-3": "arn:aws:bedrock:us-east-1::inference-profile/claude-3", // 
claude-3.5 removed - }, }, } @@ -15241,13 +15121,11 @@ func TestKeyHashComparison_BedrockDeploymentsChange(t *testing.T) { Name: "bedrock-key", Value: *schemas.NewEnvVar("bedrock-key"), Weight: 1, + Aliases: schemas.KeyAliases{"claude-3": "arn:aws:bedrock:us-east-1::inference-profile/claude-3-old"}, BedrockKeyConfig: &schemas.BedrockKeyConfig{ AccessKey: *schemas.NewEnvVar("AKIAIOSFODNN7EXAMPLE"), SecretKey: *schemas.NewEnvVar("wJalrXUtnFEMI"), Region: schemas.NewEnvVar("us-east-1"), - Deployments: map[string]string{ - "claude-3": "arn:aws:bedrock:us-east-1::inference-profile/claude-3-old", - }, }, } @@ -15256,13 +15134,11 @@ func TestKeyHashComparison_BedrockDeploymentsChange(t *testing.T) { Name: "bedrock-key", Value: *schemas.NewEnvVar("bedrock-key"), Weight: 1, + Aliases: schemas.KeyAliases{"claude-3": "arn:aws:bedrock:us-east-1::inference-profile/claude-3-new"}, BedrockKeyConfig: &schemas.BedrockKeyConfig{ AccessKey: *schemas.NewEnvVar("AKIAIOSFODNN7EXAMPLE"), SecretKey: *schemas.NewEnvVar("wJalrXUtnFEMI"), Region: schemas.NewEnvVar("us-east-1"), - Deployments: map[string]string{ - "claude-3": "arn:aws:bedrock:us-east-1::inference-profile/claude-3-new", // Value changed - }, }, } @@ -15283,12 +15159,10 @@ func TestKeyHashComparison_VertexDeploymentsChange(t *testing.T) { Name: "vertex-key", Value: *schemas.NewEnvVar("vertex-creds"), Weight: 1, + Aliases: schemas.KeyAliases{"gemini-pro": "gemini-pro-endpoint"}, VertexKeyConfig: &schemas.VertexKeyConfig{ ProjectID: *schemas.NewEnvVar("my-project"), Region: *schemas.NewEnvVar("us-central1"), - Deployments: map[string]string{ - "gemini-pro": "gemini-pro-endpoint", - }, }, } @@ -15297,13 +15171,10 @@ func TestKeyHashComparison_VertexDeploymentsChange(t *testing.T) { Name: "vertex-key", Value: *schemas.NewEnvVar("vertex-creds"), Weight: 1, + Aliases: schemas.KeyAliases{"gemini-pro": "gemini-pro-endpoint", "gemini-1.5-pro": "gemini-15-pro-endpoint"}, VertexKeyConfig: &schemas.VertexKeyConfig{ ProjectID: 
*schemas.NewEnvVar("my-project"), Region: *schemas.NewEnvVar("us-central1"), - Deployments: map[string]string{ - "gemini-pro": "gemini-pro-endpoint", - "gemini-1.5-pro": "gemini-15-pro-endpoint", // Added - }, }, } @@ -15321,13 +15192,10 @@ func TestKeyHashComparison_VertexDeploymentsChange(t *testing.T) { Name: "vertex-key", Value: *schemas.NewEnvVar("vertex-creds"), Weight: 1, + Aliases: schemas.KeyAliases{"gemini-pro": "gemini-pro-endpoint", "gemini-1.5-pro": "gemini-15-pro-endpoint"}, VertexKeyConfig: &schemas.VertexKeyConfig{ ProjectID: *schemas.NewEnvVar("my-project"), Region: *schemas.NewEnvVar("us-central1"), - Deployments: map[string]string{ - "gemini-pro": "gemini-pro-endpoint", - "gemini-1.5-pro": "gemini-15-pro-endpoint", - }, }, } @@ -15336,12 +15204,10 @@ func TestKeyHashComparison_VertexDeploymentsChange(t *testing.T) { Name: "vertex-key", Value: *schemas.NewEnvVar("vertex-creds"), Weight: 1, + Aliases: schemas.KeyAliases{"gemini-pro": "gemini-pro-endpoint"}, VertexKeyConfig: &schemas.VertexKeyConfig{ ProjectID: *schemas.NewEnvVar("my-project"), Region: *schemas.NewEnvVar("us-central1"), - Deployments: map[string]string{ - "gemini-pro": "gemini-pro-endpoint", // gemini-1.5-pro removed - }, }, } @@ -15359,12 +15225,10 @@ func TestKeyHashComparison_VertexDeploymentsChange(t *testing.T) { Name: "vertex-key", Value: *schemas.NewEnvVar("vertex-creds"), Weight: 1, + Aliases: schemas.KeyAliases{"gemini-pro": "gemini-pro-endpoint-v1"}, VertexKeyConfig: &schemas.VertexKeyConfig{ ProjectID: *schemas.NewEnvVar("my-project"), Region: *schemas.NewEnvVar("us-central1"), - Deployments: map[string]string{ - "gemini-pro": "gemini-pro-endpoint-v1", - }, }, } @@ -15373,12 +15237,10 @@ func TestKeyHashComparison_VertexDeploymentsChange(t *testing.T) { Name: "vertex-key", Value: *schemas.NewEnvVar("vertex-creds"), Weight: 1, + Aliases: schemas.KeyAliases{"gemini-pro": "gemini-pro-endpoint-v2"}, VertexKeyConfig: &schemas.VertexKeyConfig{ ProjectID: 
*schemas.NewEnvVar("my-project"), Region: *schemas.NewEnvVar("us-central1"), - Deployments: map[string]string{ - "gemini-pro": "gemini-pro-endpoint-v2", // Value changed - }, }, } @@ -15397,9 +15259,8 @@ func TestKeyHashComparison_VertexDeploymentsChange(t *testing.T) { Value: *schemas.NewEnvVar("vertex-creds"), Weight: 1, VertexKeyConfig: &schemas.VertexKeyConfig{ - ProjectID: *schemas.NewEnvVar("my-project"), - Region: *schemas.NewEnvVar("us-central1"), - Deployments: nil, // No deployments + ProjectID: *schemas.NewEnvVar("my-project"), + Region: *schemas.NewEnvVar("us-central1"), }, } @@ -15408,12 +15269,10 @@ func TestKeyHashComparison_VertexDeploymentsChange(t *testing.T) { Name: "vertex-key", Value: *schemas.NewEnvVar("vertex-creds"), Weight: 1, + Aliases: schemas.KeyAliases{"gemini-pro": "gemini-pro-endpoint"}, VertexKeyConfig: &schemas.VertexKeyConfig{ ProjectID: *schemas.NewEnvVar("my-project"), Region: *schemas.NewEnvVar("us-central1"), - Deployments: map[string]string{ - "gemini-pro": "gemini-pro-endpoint", - }, }, } diff --git a/transports/config.schema.json b/transports/config.schema.json index 03c383ed70..df417ddd47 100644 --- a/transports/config.schema.json +++ b/transports/config.schema.json @@ -238,7 +238,7 @@ "$ref": "#/$defs/provider" }, "replicate": { - "$ref": "#/$defs/provider" + "$ref": "#/$defs/provider_with_replicate_config" }, "elevenlabs": { "$ref": "#/$defs/provider" @@ -1835,6 +1835,17 @@ "type": "boolean", "description": "Whether this key can be used for batch API operations (default: false)", "default": false + }, + "aliases": { + "type": "object", + "additionalProperties": { + "type": "string", + "minLength": 1 + }, + "propertyNames": { + "minLength": 1 + }, + "description": "Model alias mappings: maps a model name to a provider-specific identifier (deployment name, inference profile ID, fine-tuned model ID, etc.)" } }, "required": [ @@ -1865,13 +1876,6 @@ "type": "string", "description": "AWS session token (can use env. 
prefix)" }, - "deployments": { - "type": "object", - "additionalProperties": { - "type": "string" - }, - "description": "Model to deployment mappings" - }, "arn": { "type": "string", "description": "AWS ARN" @@ -1928,6 +1932,28 @@ } ] }, + "replicate_key": { + "allOf": [ + { + "$ref": "#/$defs/base_key" + }, + { + "type": "object", + "properties": { + "replicate_key_config": { + "type": "object", + "properties": { + "use_deployments_endpoint": { + "type": "boolean", + "description": "Whether to use the deployments endpoint instead of the models endpoint (default: false)" + } + }, + "additionalProperties": false + } + } + } + ] + }, "ollama_key": { "allOf": [ { @@ -2001,13 +2027,6 @@ "type": "string", "description": "Azure endpoint (can use env. prefix)" }, - "deployments": { - "type": "object", - "additionalProperties": { - "type": "string" - }, - "description": "Model to deployment mappings" - }, "api_version": { "type": "string", "description": "Azure API version" @@ -2052,13 +2071,6 @@ "auth_credentials": { "type": "string", "description": "Authentication credentials (can use env. 
prefix)" - }, - "deployments": { - "type": "object", - "additionalProperties": { - "type": "string" - }, - "description": "Model to deployment mappings" - } }, "required": [ @@ -2197,6 +2209,47 @@ ], "additionalProperties": false }, + "provider_with_replicate_config": { + "type": "object", + "properties": { + "keys": { + "type": "array", + "items": { + "$ref": "#/$defs/replicate_key" + }, + "minItems": 1, + "description": "API keys for this provider" + }, + "network_config": { + "$ref": "#/$defs/network_config" + }, + "concurrency_and_buffer_size": { + "$ref": "#/$defs/concurrency_config" + }, + "proxy_config": { + "$ref": "#/$defs/proxy_config" + }, + "send_back_raw_request": { + "type": "boolean", + "description": "Include raw request in BifrostResponse (default: false)" + }, + "send_back_raw_response": { + "type": "boolean", + "description": "Include raw response in BifrostResponse (default: false)" + }, + 
"store_raw_request_response": { + "type": "boolean", + "description": "Capture raw request/response for internal logging only; strip from API responses returned to clients (default: false)" + }, + "custom_provider_config": { + "$ref": "#/$defs/custom_provider_config" + } + }, + "required": [ + "keys" + ], + "additionalProperties": false + }, "provider_with_azure_config": { "type": "object", "properties": { diff --git a/transports/schema_test/config_schema_test.go b/transports/schema_test/config_schema_test.go index cb551591cc..a523467ee8 100644 --- a/transports/schema_test/config_schema_test.go +++ b/transports/schema_test/config_schema_test.go @@ -165,21 +165,13 @@ func validateConfig(t *testing.T, schema *jsonschema.Schema, configJSON string) return schema.Validate(v) } -func TestSchemaVertexKeyDeployments(t *testing.T) { - schemaPath := getSchemaPath(t) - data, err := os.ReadFile(schemaPath) - if err != nil { - t.Fatalf("failed to read schema: %v", err) - } - var schema map[string]interface{} - if err := json.Unmarshal(data, &schema); err != nil { - t.Fatalf("failed to parse schema: %v", err) - } +func TestSchemaKeyAliases(t *testing.T) { + schema := loadSchema(t) - t.Run("vertex_key $def includes deployments field", func(t *testing.T) { - _, found := navigateJSON(schema, "$defs", "vertex_key", "allOf", 1, "properties", "vertex_key_config", "properties", "deployments") + t.Run("base_key $def includes aliases field", func(t *testing.T) { + _, found := navigateJSON(schema, "$defs", "base_key", "properties", "aliases") if !found { - t.Error("$defs/vertex_key is missing 'deployments' property — vertex provider uses getModelDeployment() on every request") + t.Error("$defs/base_key is missing 'aliases' property — aliases replaced per-provider deployments maps") } }) @@ -190,30 +182,60 @@ func TestSchemaVertexKeyDeployments(t *testing.T) { } }) - t.Run("vertex config with deployments validates successfully", func(t *testing.T) { + t.Run("vertex_key_config does not 
include deployments field", func(t *testing.T) { + _, found := navigateJSON(schema, "$defs", "vertex_key", "allOf", 1, "properties", "vertex_key_config", "properties", "deployments") + if found { + t.Error("$defs/vertex_key still has 'deployments' in vertex_key_config — deployments were moved to top-level key aliases") + } + }) + + t.Run("key with aliases validates successfully", func(t *testing.T) { compiled := compileSchema(t) config := `{ "providers": { "vertex": { "keys": [{ - "key_id": "test", "name": "test", "value": "", "weight": 1, "models": ["gemini-2.0-flash"], + "aliases": {"gemini-2.0-flash": "gemini-2.0-flash-001"}, "vertex_key_config": { "project_id": "my-project", "region": "us-central1", "auth_credentials": "", - "project_number": "123456", - "deployments": {"gemini-2.0-flash": "gemini-2.0-flash-001"} + "project_number": "123456" + } + }] + } + } + }` + if err := validateConfig(t, compiled, config); err != nil { + t.Errorf("key with aliases should be valid, got: %v", err) + } + }) + + t.Run("azure key with aliases validates successfully", func(t *testing.T) { + compiled := compileSchema(t) + config := `{ + "providers": { + "azure": { + "keys": [{ + "name": "test", + "value": "my-api-key", + "weight": 1, + "models": ["gpt-4o"], + "aliases": {"gpt-4o": "gpt-4o-deployment"}, + "azure_key_config": { + "endpoint": "https://my-resource.openai.azure.com", + "api_version": "2024-02-01" } }] } } }` if err := validateConfig(t, compiled, config); err != nil { - t.Errorf("vertex config with deployments should be valid, got: %v", err) + t.Errorf("azure key with aliases should be valid, got: %v", err) } }) } diff --git a/ui/app/globals.css b/ui/app/globals.css index 1402e40487..418f86667f 100644 --- a/ui/app/globals.css +++ b/ui/app/globals.css @@ -226,19 +226,62 @@ body { } div.content-container:has(.no-padding-parent) { - @apply !p-0; + @apply p-0!; } div.content-container main.content-container-inner:has(.no-padding-parent) { - @apply !p-0; + @apply p-0!; } 
div.content-container:has(.no-border-parent) { - @apply !border-0; + @apply border-0!; +} + +/* ReactFlow Controls — follow Bifrost colour schema */ + +.react-flow__controls { + background-color: var(--card); + border: 1px solid var(--border); + border-radius: var(--radius); + box-shadow: 0 1px 3px 0 rgb(0 0 0 / 0.1); +} + +.react-flow__controls-button { + background-color: var(--card); + border-bottom: 1px solid var(--border); + fill: var(--foreground); +} + +.react-flow__controls-button:hover { + background-color: var(--muted); +} + +.react-flow__controls-button svg { + fill: var(--foreground); +} + +/* Dark mode — needs !important to beat ReactFlow's bundled specificity */ +.dark .react-flow__controls { + background-color: var(--card) !important; + border-color: var(--border) !important; +} + +.dark .react-flow__controls-button { + background-color: var(--card) !important; + border-bottom-color: var(--border) !important; + fill: var(--foreground) !important; +} + +.dark .react-flow__controls-button:hover { + background-color: var(--muted) !important; +} + +.dark .react-flow__controls-button svg { + fill: var(--foreground) !important; } /* // Custom styling for streamdown */ [data-streamdown="code-block"], [data-streamdown="code-block-body"]{ - @apply !rounded-sm; + @apply rounded-sm!; } \ No newline at end of file diff --git a/ui/app/workspace/logs/page.tsx b/ui/app/workspace/logs/page.tsx index 2f1bf161b2..bc2540e906 100644 --- a/ui/app/workspace/logs/page.tsx +++ b/ui/app/workspace/logs/page.tsx @@ -80,6 +80,7 @@ export default function LogsPage() { { providers: parseAsArrayOf(parseAsString).withDefault([]), models: parseAsArrayOf(parseAsString).withDefault([]), + aliases: parseAsArrayOf(parseAsString).withDefault([]), status: parseAsArrayOf(parseAsString).withDefault([]), objects: parseAsArrayOf(parseAsString).withDefault([]), selected_key_ids: parseAsArrayOf(parseAsString).withDefault([]), @@ -188,6 +189,7 @@ export default function LogsPage() { () => ({ 
providers: urlState.providers, models: urlState.models, + aliases: urlState.aliases, status: urlState.status, objects: urlState.objects, selected_key_ids: urlState.selected_key_ids, @@ -208,7 +210,7 @@ export default function LogsPage() { }), // Only re-derive filters when filter-related URL params change (not pagination) [ - urlState.providers, urlState.models, urlState.status, urlState.objects, + urlState.providers, urlState.models, urlState.aliases, urlState.status, urlState.objects, urlState.selected_key_ids, urlState.virtual_key_ids, urlState.routing_rule_ids, urlState.routing_engine_used, urlState.content_search, urlState.start_time, urlState.end_time, @@ -239,6 +241,7 @@ export default function LogsPage() { setUrlState({ providers: newFilters.providers || [], models: newFilters.models || [], + aliases: newFilters.aliases || [], status: newFilters.status || [], objects: newFilters.objects || [], selected_key_ids: newFilters.selected_key_ids || [], @@ -663,6 +666,9 @@ export default function LogsPage() { if (filters.providers?.length && !filters.providers.includes(log.provider)) { return false; } + if (filters.aliases?.length && !filters.aliases.includes(log.alias ?? "")) { + return false; + } if (filters.models?.length && !filters.models.includes(log.model)) { return false; } diff --git a/ui/app/workspace/logs/sheets/logDetailsSheet.tsx b/ui/app/workspace/logs/sheets/logDetailsSheet.tsx index 0b351c839c..9bba51bbdd 100644 --- a/ui/app/workspace/logs/sheets/logDetailsSheet.tsx +++ b/ui/app/workspace/logs/sheets/logDetailsSheet.tsx @@ -293,6 +293,9 @@ export function LogDetailSheet({ } /> {!isContainer && } + {!isContainer && displayLog.alias && ( + + )} void, hasDeleteAccess = true, metadataKeys: string[] = []): ColumnDef[] => { const baseColumns: ColumnDef[] = [ - { - accessorKey: "status", - header: "", - size: 8, - maxSize: 8, - cell: ({ row }) => { - const status = row.original.status as Status; - return
; + { + accessorKey: "status", + header: "", + size: 8, + maxSize: 8, + cell: ({ row }) => { + const status = row.original.status as Status; + return
; + }, }, - }, - { - accessorKey: "timestamp", - header: ({ column }) => ( - - ), - cell: ({ row }) => { - const timestamp = row.original.timestamp; - return
{moment(timestamp).format("YYYY-MM-DD hh:mm:ss A (Z)")}
; + { + accessorKey: "timestamp", + header: ({ column }) => ( + + ), + cell: ({ row }) => { + const timestamp = row.original.timestamp; + return
{moment(timestamp).format("YYYY-MM-DD hh:mm:ss A (Z)")}
; + }, }, - }, - { - id: "request_type", - header: "Type", - cell: ({ row }) => { - return ( - - {RequestTypeLabels[row.original.object as keyof typeof RequestTypeLabels]} - - ); + { + id: "request_type", + header: "Type", + cell: ({ row }) => { + return ( + + {RequestTypeLabels[row.original.object as keyof typeof RequestTypeLabels]} + + ); + }, }, - }, - { - accessorKey: "input", - header: "Message", - cell: ({ row }) => { - const input = getMessage(row.original); - const isLargePayload = row.original.is_large_payload_request || row.original.is_large_payload_response; - return ( -
- {isLargePayload && ( - - LP - - )} -
- {input || (isLargePayload - ? `Large payload ${row.original.is_large_payload_request && row.original.is_large_payload_response ? "request & response" : row.original.is_large_payload_request ? "request" : "response"}` - : "-")} + { + accessorKey: "input", + header: "Message", + cell: ({ row }) => { + const input = getMessage(row.original); + const isLargePayload = row.original.is_large_payload_request || row.original.is_large_payload_response; + return ( +
+ {isLargePayload && ( + + LP + + )} +
+ {input || (isLargePayload + ? `Large payload ${row.original.is_large_payload_request && row.original.is_large_payload_response ? "request & response" : row.original.is_large_payload_request ? "request" : "response"}` + : "-")} +
-
- ); + ); + }, }, - }, - { - accessorKey: "provider", - header: "Provider", - cell: ({ row }) => { - const provider = row.original.provider as ProviderName; - return ( - - - {provider} - - ); + { + accessorKey: "provider", + header: "Provider", + cell: ({ row }) => { + const provider = row.original.provider as ProviderName; + return ( + + + {provider} + + ); + }, }, - }, - { - accessorKey: "model", - header: "Model", - cell: ({ row }) =>
{row.original.model || "N/A"}
, - }, - { - accessorKey: "latency", - header: ({ column }) => ( - - ), - cell: ({ row }) => { - const latency = row.original.latency; - return ( -
{latency === undefined || latency === null ? "N/A" : `${latency.toLocaleString()}ms`}
- ); + { + accessorKey: "model", + header: "Model", + cell: ({ row }) =>
{row.original.model || "N/A"}
, + }, - }, - { - accessorKey: "tokens", - header: ({ column }) => ( - - ), - cell: ({ row }) => { - const tokenUsage = row.original.token_usage; - if (!tokenUsage) { - return
N/A
; - } + { + accessorKey: "latency", + header: ({ column }) => ( + + ), + cell: ({ row }) => { + const latency = row.original.latency; + return ( +
{latency === undefined || latency === null ? "N/A" : `${latency.toLocaleString()}ms`}
+ ); + }, + }, + { + accessorKey: "tokens", + header: ({ column }) => ( + + ), + cell: ({ row }) => { + const tokenUsage = row.original.token_usage; + if (!tokenUsage) { + return
N/A
; + } - return ( -
-
- {tokenUsage.total_tokens.toLocaleString()}{" "} - {tokenUsage.completion_tokens != null && tokenUsage.prompt_tokens != null - ? `(${tokenUsage.prompt_tokens.toLocaleString()}+${tokenUsage.completion_tokens.toLocaleString()})` - : ""} + return ( +
+
+ {tokenUsage.total_tokens.toLocaleString()}{" "} + {tokenUsage.completion_tokens != null && tokenUsage.prompt_tokens != null + ? `(${tokenUsage.prompt_tokens.toLocaleString()}+${tokenUsage.completion_tokens.toLocaleString()})` + : ""} +
-
- ); + ); + }, }, - }, - { - accessorKey: "cost", - header: ({ column }) => ( - - ), - cell: ({ row }) => { - if (!row.original.cost) { - return
N/A
; - } + { + accessorKey: "cost", + header: ({ column }) => ( + + ), + cell: ({ row }) => { + if (!row.original.cost) { + return
N/A
; + } - return ( -
-
{row.original.cost?.toFixed(4)}
-
- ); + return ( +
+
{row.original.cost?.toFixed(4)}
+
+ ); + }, }, - }, ]; // Generate dynamic metadata columns diff --git a/ui/app/workspace/providers/dialogs/addNewKeySheet.tsx b/ui/app/workspace/providers/dialogs/addNewKeySheet.tsx index a724a3f8d2..1c958c915f 100644 --- a/ui/app/workspace/providers/dialogs/addNewKeySheet.tsx +++ b/ui/app/workspace/providers/dialogs/addNewKeySheet.tsx @@ -33,7 +33,6 @@ export default function AddNewKeySheet({ show, onCancel, provider, keyId, provid className="custom-scrollbar p-8" data-testid="key-form" onInteractOutside={(e) => e.preventDefault()} - onEscapeKeyDown={(e) => e.preventDefault()} > diff --git a/ui/app/workspace/providers/fragments/apiKeysFormFragment.tsx b/ui/app/workspace/providers/fragments/apiKeysFormFragment.tsx index f978cee951..33413edcd1 100644 --- a/ui/app/workspace/providers/fragments/apiKeysFormFragment.tsx +++ b/ui/app/workspace/providers/fragments/apiKeysFormFragment.tsx @@ -4,13 +4,13 @@ import { Alert, AlertDescription, AlertTitle } from "@/components/ui/alert"; import { Button } from "@/components/ui/button"; import { EnvVarInput } from "@/components/ui/envVarInput"; import { FormControl, FormDescription, FormField, FormItem, FormLabel, FormMessage } from "@/components/ui/form"; +import { HeadersTable, type CellRenderParams } from "@/components/ui/headersTable"; import { Input } from "@/components/ui/input"; import { ModelMultiselect } from "@/components/ui/modelMultiselect"; import { Separator } from "@/components/ui/separator"; import { Switch } from "@/components/ui/switch"; import { Tabs, TabsList, TabsTrigger } from "@/components/ui/tabs"; import { TagInput } from "@/components/ui/tagInput"; -import { Textarea } from "@/components/ui/textarea"; import { Tooltip, TooltipContent, TooltipProvider, TooltipTrigger } from "@/components/ui/tooltip"; import { isRedacted } from "@/lib/utils/validation"; import { Info, Plus, Trash2 } from "lucide-react"; @@ -20,6 +20,34 @@ import { Control, UseFormReturn } from "react-hook-form"; // Providers that support 
batch APIs const BATCH_SUPPORTED_PROVIDERS = ["openai", "bedrock", "anthropic", "gemini", "azure"]; +/** Normalize form value (object or legacy JSON string) for the alias map editor. */ +function normalizeAliasesValue( + v: Record | string | undefined | null, +): Record { + if (v == null) { + return {}; + } + if (typeof v === "string") { + const t = v.trim(); + if (!t) { + return {}; + } + try { + const p = JSON.parse(t) as unknown; + if (typeof p === "object" && p !== null && !Array.isArray(p)) { + return Object.fromEntries(Object.entries(p as Record).map(([k, val]) => [k, String(val ?? "")])); + } + } catch { + return {}; + } + return {}; + } + if (typeof v === "object" && !Array.isArray(v)) { + return Object.fromEntries(Object.entries(v).map(([k, val]) => [k, typeof val === "string" ? val : String(val ?? "")])); + } + return {}; +} + interface Props { control: Control; providerName: string; @@ -174,7 +202,7 @@ export function ApiKeyFormFragment({ control, providerName, form }: Props) { />
{/* Hide API Key field for Azure when using Entra ID/Default Credential, and for Bedrock when not using API Key auth */} - {!isAzure && !isBedrock && !isKeylessProvider && ( + {!(isAzure && (azureAuthType === "entra_id" || azureAuthType === "default_credential")) && !(isBedrock) && ( )} /> + ( + + Aliases (Optional) + + Map each request model name to the provider's identifier (deployment name, inference profile ID, fine-tuned endpoint ID, + etc.) or just a custom name, e.g. "claude-sonnet-4-5" -> "custom-claude-4.5-sonnet". + + +
+ { + form.clearErrors("key.aliases"); + field.onChange(Object.keys(next).length > 0 ? next : {}); + }} + keyPlaceholder="Request model name" + valuePlaceholder="Deployment / profile / resource ID" + renderValueInput={({ value: cellValue, onChange, placeholder, disabled }: CellRenderParams) => ( + + )} + /> +
+
+ +
+ )} + /> )} {supportsBatchAPI && !isBedrock && !isAzure && } @@ -465,46 +532,6 @@ export function ApiKeyFormFragment({ control, providerName, form }: Props) { /> )} - - ( - - Deployments (Required) - JSON object mapping model names to deployment names - -