Skip to content
Merged
138 changes: 138 additions & 0 deletions core/bifrost.go
Comment thread
connyay marked this conversation as resolved.
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ type RequestType string
// Supported values of RequestType. Each value tags a queued request with the
// kind of provider call it needs.
const (
// TextCompletionRequest marks a text completion request.
TextCompletionRequest RequestType = "text_completion"
// ChatCompletionRequest marks a chat completion request.
ChatCompletionRequest RequestType = "chat_completion"
// EmbeddingRequest marks an embedding request; requestWorker routes
// messages of this type to provider.Embedding.
EmbeddingRequest RequestType = "embedding"
)

// ChannelMessage represents a message passed through the request channel.
Expand Down Expand Up @@ -452,6 +453,18 @@ func (bifrost *Bifrost) requestWorker(provider schemas.Provider, queue chan Chan
} else {
result, bifrostError = provider.ChatCompletion(req.Context, req.Model, key, *req.Input.ChatCompletionInput, req.Params)
}
} else if req.Type == EmbeddingRequest {

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Pratham-Mishra04 as a matter of code hygiene, I think removing this if-else ladder is critical.

We can probably use a factory (dispatch-table) pattern, where:

type executor func(provider, message)

type MessageExecutors map[string]executor

var executors = MessageExecutors{<add pre-defined handlers for each message>}

if req.Input.EmbeddingInput == nil {
bifrostError = &schemas.BifrostError{
IsBifrostError: false,
Error: schemas.ErrorField{
Message: "input not provided for embedding request",
},
}
break // Don't retry client errors
} else {
result, bifrostError = provider.Embedding(req.Context, req.Model, key, *req.Input.EmbeddingInput, req.Params)
}
Comment thread
akshaydeo marked this conversation as resolved.
Outdated
}

bifrost.logger.Debug(fmt.Sprintf("Request for provider %s completed", provider.GetProviderKey()))
Expand Down Expand Up @@ -788,6 +801,131 @@ func (bifrost *Bifrost) tryChatCompletion(req *schemas.BifrostRequest, ctx conte
}
}

// EmbeddingRequest sends an embedding request to the specified provider.
// It validates the request, tries the primary provider, and, if that fails,
// tries each fallback provider in order until one succeeds.
//
// Parameters:
//   - ctx: context forwarded to every provider attempt.
//   - req: the request; it must be non-nil and carry a Provider and a Model.
//
// Returns the first successful response. When every attempt fails, the error
// from the primary provider is returned. A RequestCancelled error, whether it
// comes from the primary provider or from a fallback, is returned immediately
// and no further providers are tried.
func (bifrost *Bifrost) EmbeddingRequest(ctx context.Context, req *schemas.BifrostRequest) (*schemas.BifrostResponse, *schemas.BifrostError) {
	if req == nil {
		return nil, newBifrostErrorFromMsg("bifrost request cannot be nil")
	}

	if req.Provider == "" {
		return nil, newBifrostErrorFromMsg("provider is required")
	}

	if req.Model == "" {
		return nil, newBifrostErrorFromMsg("model is required")
	}

	// Try the primary provider first.
	primaryResult, primaryErr := bifrost.tryEmbedding(req, ctx)
	if primaryErr == nil {
		return primaryResult, nil
	}

	// Treat a cancelled primary attempt the same way as a cancelled fallback:
	// stop here instead of walking the fallback list.
	if primaryErr.Error.Type != nil && *primaryErr.Error.Type == schemas.RequestCancelled {
		return nil, primaryErr
	}

	// Primary provider failed; try the fallbacks in order (no-op when empty).
	for _, fallback := range req.Fallbacks {
		// Skip fallbacks for which the account has no provider config.
		if _, err := bifrost.account.GetConfigForProvider(fallback.Provider); err != nil {
			bifrost.logger.Warn(fmt.Sprintf("Config not found for provider %s, skipping fallback: %v", fallback.Provider, err))
			continue
		}

		// Shallow-copy the request and point it at the fallback provider/model
		// so the caller's request is not mutated.
		fallbackReq := *req
		fallbackReq.Provider = fallback.Provider
		fallbackReq.Model = fallback.Model

		result, fallbackErr := bifrost.tryEmbedding(&fallbackReq, ctx)
		if fallbackErr == nil {
			bifrost.logger.Info(fmt.Sprintf("Successfully used fallback provider %s with model %s", fallback.Provider, fallback.Model))
			return result, nil
		}
		if fallbackErr.Error.Type != nil && *fallbackErr.Error.Type == schemas.RequestCancelled {
			return nil, fallbackErr
		}

		bifrost.logger.Warn(fmt.Sprintf("Fallback provider %s failed: %s", fallback.Provider, fallbackErr.Error.Message))
	}

	// All providers failed, return the original error.
	return nil, primaryErr
}
Comment thread
coderabbitai[bot] marked this conversation as resolved.

// tryEmbedding attempts an embedding request with a single provider.
// It is the helper EmbeddingRequest uses for each individual provider attempt.
//
// Flow:
//  1. Look up the queue for req.Provider.
//  2. Run the plugin pre-hooks; if a plugin supplies a response, run the
//     post-hooks on it and return without contacting the provider.
//  3. Enqueue a channel message for the provider. When the queue is full the
//     request is dropped if dropExcessRequests is set; otherwise the call
//     blocks until space is available or ctx is done.
//  4. Wait for the response or the error on the message's channels, run the
//     post-hooks on it, release the message and return the outcome.
//
// A nil ctx is replaced with bifrost.backgroundCtx before the message is
// enqueued.
func (bifrost *Bifrost) tryEmbedding(req *schemas.BifrostRequest, ctx context.Context) (*schemas.BifrostResponse, *schemas.BifrostError) {
	queue, err := bifrost.getProviderQueue(req.Provider)
	if err != nil {
		return nil, newBifrostError(err)
	}

	pipeline := NewPluginPipeline(bifrost.plugins, bifrost.logger)
	preReq, preResp, preCount := pipeline.RunPreHooks(&ctx, req)
	if preResp != nil {
		// A plugin short-circuited the request; only unwind the hooks that ran.
		resp, bifrostErr := pipeline.RunPostHooks(&ctx, preResp, nil, preCount)
		if bifrostErr != nil {
			return nil, bifrostErr
		}
		return resp, nil
	}
	if preReq == nil {
		return nil, newBifrostErrorFromMsg("bifrost request after plugin hooks cannot be nil")
	}

	// Normalize a nil context before it is used. A select statement evaluates
	// ctx.Done() on entry, so a nil ctx would panic in the select below.
	if ctx == nil {
		ctx = bifrost.backgroundCtx
	}

	msg := bifrost.getChannelMessage(*preReq, EmbeddingRequest)
	msg.Context = ctx

	select {
	case queue <- *msg:
		// Message was sent successfully
	case <-ctx.Done():
		bifrost.releaseChannelMessage(msg)
		return nil, newBifrostErrorFromMsg("request cancelled while waiting for queue space")
	default:
		// Queue is full right now: either drop the request or block until
		// there is space (or the context is done).
		if bifrost.dropExcessRequests {
			bifrost.releaseChannelMessage(msg)
			bifrost.logger.Warn("Request dropped: queue is full, please increase the queue size or set dropExcessRequests to false")
			return nil, newBifrostErrorFromMsg("request dropped: queue is full")
		}
		select {
		case queue <- *msg:
			// Message was sent successfully
		case <-ctx.Done():
			bifrost.releaseChannelMessage(msg)
			return nil, newBifrostErrorFromMsg("request cancelled while waiting for queue space")
		}
	}

	// Wait for the outcome, then let the plugins post-process it. The message
	// is released as soon as the post-hooks have run.
	select {
	case result := <-msg.Response:
		resp, bifrostErr := pipeline.RunPostHooks(&ctx, result, nil, len(bifrost.plugins))
		bifrost.releaseChannelMessage(msg)
		if bifrostErr != nil {
			return nil, bifrostErr
		}
		return resp, nil
	case bifrostErrVal := <-msg.Err:
		// Post-hooks may turn the error into a response, in which case the
		// returned error is nil.
		resp, bifrostErr := pipeline.RunPostHooks(&ctx, nil, &bifrostErrVal, len(bifrost.plugins))
		bifrost.releaseChannelMessage(msg)
		if bifrostErr != nil {
			return nil, bifrostErr
		}
		return resp, nil
	}
}
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Comment thread
coderabbitai[bot] marked this conversation as resolved.

// Cleanup gracefully stops all workers when triggered.
// It closes all request channels and waits for workers to exit.
func (bifrost *Bifrost) Cleanup() {
Expand Down
10 changes: 10 additions & 0 deletions core/providers/anthropic.go
Original file line number Diff line number Diff line change
Expand Up @@ -694,3 +694,13 @@ func parseAnthropicResponse(response *AnthropicChatResponse, bifrostResponse *sc

return bifrostResponse, nil
}

// Embedding always fails for the Anthropic provider: it performs no request
// and returns a nil response together with an error stating that embeddings
// are not supported. All arguments are ignored.
func (provider *AnthropicProvider) Embedding(ctx context.Context, model, key string, input schemas.EmbeddingInput, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) {
	unsupported := schemas.BifrostError{
		IsBifrostError: false,
		Error: schemas.ErrorField{
			Message: "embedding is not supported by anthropic provider",
		},
	}
	return nil, &unsupported
}
Comment thread
coderabbitai[bot] marked this conversation as resolved.
114 changes: 114 additions & 0 deletions core/providers/azure.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,20 @@ type AzureChatResponse struct {
Usage schemas.LLMUsage `json:"usage"` // Token usage statistics
}

// AzureEmbeddingResponse represents the response structure from Azure's embedding API.
// AzureProvider.Embedding unmarshals the raw response body into this type and
// copies ID, Object, Model, Usage and SystemFingerprint into the BifrostResponse.
type AzureEmbeddingResponse struct {
// Object is the top-level "object" field of the response.
Object string `json:"object"`
// Data holds one entry per returned embedding.
Data []struct {
// Object is the "object" field of this entry.
Object string `json:"object"`
// Embedding is decoded loosely. AzureProvider.Embedding accepts a
// []float32 or a []interface{} of float64 values and rejects any
// other shape with an error.
Embedding interface{} `json:"embedding"`
// Index is the "index" field of this entry.
Index int `json:"index"`
} `json:"data"`
// Model is the "model" field reported in the response.
Model string `json:"model"`
// Usage holds the token usage statistics.
Usage schemas.LLMUsage `json:"usage"`
// ID is the response identifier.
ID string `json:"id"`
// SystemFingerprint is optional; it stays nil when the field is absent.
SystemFingerprint *string `json:"system_fingerprint"`
}

// AzureError represents the error response structure from Azure's API.
// It includes error code and message information.
type AzureError struct {
Expand Down Expand Up @@ -356,3 +370,103 @@ func (provider *AzureProvider) ChatCompletion(ctx context.Context, model, key st

return bifrostResponse, nil
}

// Embedding generates embeddings for the given input text(s) using Azure OpenAI.
// The input can be either a single string or a slice of strings for batch embedding.
// Returns a BifrostResponse containing the embedding(s) and any error that occurred.
func (provider *AzureProvider) Embedding(ctx context.Context, model string, key string, input schemas.EmbeddingInput, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) {
if len(input.Texts) == 0 {
return nil, &schemas.BifrostError{
IsBifrostError: true,
Error: schemas.ErrorField{Message: "no input text provided for embedding"},
}
}

// Prepare request body - Azure uses deployment-scoped URLs, so model is not needed in body
requestBody := map[string]interface{}{
"input": input.Texts,
}

// Merge any additional parameters
if params != nil {
if params.EncodingFormat != nil {
requestBody["encoding_format"] = *params.EncodingFormat
}
if params.Dimensions != nil {
requestBody["dimensions"] = *params.Dimensions
}
if params.User != nil {
requestBody["user"] = *params.User
}
requestBody = mergeConfig(requestBody, params.ExtraParams)
}

responseBody, err := provider.completeRequest(ctx, requestBody, "embeddings", key, model)
if err != nil {
return nil, err
}

// Parse response
var response AzureEmbeddingResponse
if err := json.Unmarshal(responseBody, &response); err != nil {
return nil, &schemas.BifrostError{
IsBifrostError: true,
Error: schemas.ErrorField{
Message: schemas.ErrProviderResponseUnmarshal,
Error: err,
},
}
}

bifrostResponse := &schemas.BifrostResponse{
ID: response.ID,
Object: response.Object,
Model: response.Model,
Usage: response.Usage,
SystemFingerprint: response.SystemFingerprint,
ExtraFields: schemas.BifrostResponseExtraFields{
Provider: schemas.Azure,
RawResponse: responseBody,
},
}

// Extract embeddings from response data
if len(response.Data) > 0 {
embeddings := make([][]float32, len(response.Data))
for i, data := range response.Data {
switch v := data.Embedding.(type) {
case []float32:
embeddings[i] = v
case []interface{}:
floatArray := make([]float32, len(v))
for j := range v {
if num, ok := v[j].(float64); ok {
floatArray[j] = float32(num)
} else {
return nil, &schemas.BifrostError{
IsBifrostError: true,
Error: schemas.ErrorField{
Message: fmt.Sprintf("unsupported number type in embedding array: %T", v[j]),
},
}
}
}
embeddings[i] = floatArray
default:
Comment thread
coderabbitai[bot] marked this conversation as resolved.
return nil, &schemas.BifrostError{
IsBifrostError: true,
Error: schemas.ErrorField{
Message: fmt.Sprintf("unsupported embedding type: %T", data.Embedding),
},
}
}
}
bifrostResponse.Embedding = embeddings
}

if params != nil {
bifrostResponse.ExtraFields.Params = *params
}

return bifrostResponse, nil
}
Loading