diff --git a/core/bifrost.go b/core/bifrost.go index 40cc6219b2..c8bc5e0e3b 100644 --- a/core/bifrost.go +++ b/core/bifrost.go @@ -22,8 +22,19 @@ type RequestType string const ( TextCompletionRequest RequestType = "text_completion" ChatCompletionRequest RequestType = "chat_completion" + EmbeddingRequest RequestType = "embedding" ) +// executor is a function type that handles specific request types. +type executor func(provider schemas.Provider, req *ChannelMessage, key string) (*schemas.BifrostResponse, *schemas.BifrostError) + +// messageExecutors is a factory map for handling different request types. +var messageExecutors = map[RequestType]executor{ + TextCompletionRequest: handleTextCompletion, + ChatCompletionRequest: handleChatCompletion, + EmbeddingRequest: handleEmbedding, +} + // ChannelMessage represents a message passed through the request channel. // It contains the request, response and error channels, and the request type. type ChannelMessage struct { @@ -380,6 +391,45 @@ func (bifrost *Bifrost) calculateBackoff(attempt int, config *schemas.ProviderCo return time.Duration(jitter) } +// handleTextCompletion executes a text completion request +func handleTextCompletion(provider schemas.Provider, req *ChannelMessage, key string) (*schemas.BifrostResponse, *schemas.BifrostError) { + if req.Input.TextCompletionInput == nil { + return nil, &schemas.BifrostError{ + IsBifrostError: false, + Error: schemas.ErrorField{ + Message: "text not provided for text completion request", + }, + } + } + return provider.TextCompletion(req.Context, req.Model, key, *req.Input.TextCompletionInput, req.Params) +} + +// handleChatCompletion executes a chat completion request +func handleChatCompletion(provider schemas.Provider, req *ChannelMessage, key string) (*schemas.BifrostResponse, *schemas.BifrostError) { + if req.Input.ChatCompletionInput == nil { + return nil, &schemas.BifrostError{ + IsBifrostError: false, + Error: schemas.ErrorField{ + Message: "chats not provided for 
chat completion request", + }, + } + } + return provider.ChatCompletion(req.Context, req.Model, key, *req.Input.ChatCompletionInput, req.Params) +} + +// handleEmbedding executes an embedding request +func handleEmbedding(provider schemas.Provider, req *ChannelMessage, key string) (*schemas.BifrostResponse, *schemas.BifrostError) { + if req.Input.EmbeddingInput == nil { + return nil, &schemas.BifrostError{ + IsBifrostError: false, + Error: schemas.ErrorField{ + Message: "input not provided for embedding request", + }, + } + } + return provider.Embedding(req.Context, req.Model, key, req.Input.EmbeddingInput, req.Params) +} + // requestWorker handles incoming requests from the queue for a specific provider. // It manages retries, error handling, and response processing. func (bifrost *Bifrost) requestWorker(provider schemas.Provider, queue chan ChannelMessage) { @@ -439,31 +489,21 @@ func (bifrost *Bifrost) requestWorker(provider schemas.Provider, queue chan Chan bifrost.logger.Debug(fmt.Sprintf("Attempting request for provider %s", provider.GetProviderKey())) - // Attempt the request - if req.Type == TextCompletionRequest { - if req.Input.TextCompletionInput == nil { - bifrostError = &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: "text not provided for text completion request", - }, - } - break // Don't retry client errors - } else { - result, bifrostError = provider.TextCompletion(req.Context, req.Model, key, *req.Input.TextCompletionInput, req.Params) - } - } else if req.Type == ChatCompletionRequest { - if req.Input.ChatCompletionInput == nil { - bifrostError = &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: "chats not provided for chat completion request", - }, - } - break // Don't retry client errors - } else { - result, bifrostError = provider.ChatCompletion(req.Context, req.Model, key, *req.Input.ChatCompletionInput, req.Params) + // Attempt the request using factory pattern + 
executor := messageExecutors[req.Type] + if executor == nil { + bifrostError = &schemas.BifrostError{ + IsBifrostError: false, + Error: schemas.ErrorField{ + Message: fmt.Sprintf("unsupported request type: %s", req.Type), + }, } + break + } + + result, bifrostError = executor(provider, &req, key) + if bifrostError != nil && !bifrostError.IsBifrostError { + break // Don't retry client errors } bifrost.logger.Debug(fmt.Sprintf("Request for provider %s completed", provider.GetProviderKey())) @@ -614,93 +654,81 @@ func (bifrost *Bifrost) TextCompletionRequest(ctx context.Context, req *schemas. // tryTextCompletion attempts a text completion request with a single provider. // This is a helper function used by TextCompletionRequest to handle individual provider attempts. func (bifrost *Bifrost) tryTextCompletion(req *schemas.BifrostRequest, ctx context.Context) (*schemas.BifrostResponse, *schemas.BifrostError) { - queue, err := bifrost.getProviderQueue(req.Provider) - if err != nil { - return nil, newBifrostError(err) - } + return bifrost.tryRequest(req, ctx, TextCompletionRequest, true) +} - // Add MCP tools to request if MCP is configured - if bifrost.mcpManager != nil { - req = bifrost.mcpManager.addMCPToolsToBifrostRequest(ctx, req) +// ChatCompletionRequest sends a chat completion request to the specified provider. +// It handles plugin hooks, request validation, response processing, and fallback providers. +// If the primary provider fails, it will try each fallback provider in order until one succeeds. 
+func (bifrost *Bifrost) ChatCompletionRequest(ctx context.Context, req *schemas.BifrostRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { + if req == nil { + return nil, newBifrostErrorFromMsg("bifrost request cannot be nil") } - pipeline := NewPluginPipeline(bifrost.plugins, bifrost.logger) - preReq, shortCircuit, preCount := pipeline.RunPreHooks(&ctx, req) - if shortCircuit != nil { - // Handle short-circuit with response (success case) - if shortCircuit.Response != nil { - resp, bifrostErr := pipeline.RunPostHooks(&ctx, shortCircuit.Response, nil, preCount) - if bifrostErr != nil { - return nil, bifrostErr - } - return resp, nil - } - // Handle short-circuit with error - if shortCircuit.Error != nil { - resp, bifrostErr := pipeline.RunPostHooks(&ctx, nil, shortCircuit.Error, preCount) - if bifrostErr != nil { - return nil, bifrostErr - } - return resp, nil - } + if req.Provider == "" { + return nil, newBifrostErrorFromMsg("provider is required") } - if preReq == nil { - return nil, newBifrostErrorFromMsg("bifrost request after plugin hooks cannot be nil") + + if req.Model == "" { + return nil, newBifrostErrorFromMsg("model is required") } - msg := bifrost.getChannelMessage(*preReq, TextCompletionRequest) - msg.Context = ctx + // Try the primary provider first + primaryResult, primaryErr := bifrost.tryChatCompletion(req, ctx) + if primaryErr == nil { + return primaryResult, nil + } - select { - case queue <- *msg: - // Message was sent successfully - case <-ctx.Done(): - bifrost.releaseChannelMessage(msg) - return nil, newBifrostErrorFromMsg("request cancelled while waiting for queue space") - default: - if bifrost.dropExcessRequests { - bifrost.releaseChannelMessage(msg) - bifrost.logger.Warn("Request dropped: queue is full, please increase the queue size or set dropExcessRequests to false") - return nil, newBifrostErrorFromMsg("request dropped: queue is full") - } - if ctx == nil { - ctx = bifrost.backgroundCtx - } - select { - case queue <- *msg: - 
// Message was sent successfully - case <-ctx.Done(): - bifrost.releaseChannelMessage(msg) - return nil, newBifrostErrorFromMsg("request cancelled while waiting for queue space") - } + // Check if this is a short-circuit error that doesn't allow fallbacks + // Note: AllowFallbacks = nil is treated as true (allow fallbacks by default) + if primaryErr.AllowFallbacks != nil && !*primaryErr.AllowFallbacks { + return nil, primaryErr } - var result *schemas.BifrostResponse - var resp *schemas.BifrostResponse - select { - case result = <-msg.Response: - resp, bifrostErr := pipeline.RunPostHooks(&ctx, result, nil, len(bifrost.plugins)) - if bifrostErr != nil { - bifrost.releaseChannelMessage(msg) - return nil, bifrostErr - } - bifrost.releaseChannelMessage(msg) - return resp, nil - case bifrostErrVal := <-msg.Err: - bifrostErrPtr := &bifrostErrVal - resp, bifrostErrPtr = pipeline.RunPostHooks(&ctx, nil, bifrostErrPtr, len(bifrost.plugins)) - bifrost.releaseChannelMessage(msg) - if bifrostErrPtr != nil { - return nil, bifrostErrPtr + // If primary provider failed and we have fallbacks, try them in order + // This includes both regular provider errors and plugin short-circuit errors with AllowFallbacks=true/nil + if len(req.Fallbacks) > 0 { + for _, fallback := range req.Fallbacks { + // Check if we have config for this fallback provider + _, err := bifrost.account.GetConfigForProvider(fallback.Provider) + if err != nil { + bifrost.logger.Warn(fmt.Sprintf("Config not found for provider %s, skipping fallback: %v", fallback.Provider, err)) + continue + } + + // Create a new request with the fallback provider and model + fallbackReq := *req + fallbackReq.Provider = fallback.Provider + fallbackReq.Model = fallback.Model + + // Try the fallback provider + result, fallbackErr := bifrost.tryChatCompletion(&fallbackReq, ctx) + if fallbackErr == nil { + bifrost.logger.Info(fmt.Sprintf("Successfully used fallback provider %s with model %s", fallback.Provider, fallback.Model)) + return 
result, nil + } + if fallbackErr.Error.Type != nil && *fallbackErr.Error.Type == schemas.RequestCancelled { + return nil, fallbackErr + } + + bifrost.logger.Warn(fmt.Sprintf("Fallback provider %s failed: %s", fallback.Provider, fallbackErr.Error.Message)) } - return resp, nil } + + // All providers failed, return the original error + return nil, primaryErr } -// ChatCompletionRequest sends a chat completion request to the specified provider. +// tryChatCompletion attempts a chat completion request with a single provider. +// This is a helper function used by ChatCompletionRequest to handle individual provider attempts. +func (bifrost *Bifrost) tryChatCompletion(req *schemas.BifrostRequest, ctx context.Context) (*schemas.BifrostResponse, *schemas.BifrostError) { + return bifrost.tryRequest(req, ctx, ChatCompletionRequest, true) +} + +// EmbeddingRequest sends an embedding request to the specified provider. // It handles plugin hooks, request validation, response processing, and fallback providers. // If the primary provider fails, it will try each fallback provider in order until one succeeds. -func (bifrost *Bifrost) ChatCompletionRequest(ctx context.Context, req *schemas.BifrostRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { +func (bifrost *Bifrost) EmbeddingRequest(ctx context.Context, req *schemas.BifrostRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { if req == nil { return nil, newBifrostErrorFromMsg("bifrost request cannot be nil") } @@ -713,8 +741,12 @@ func (bifrost *Bifrost) ChatCompletionRequest(ctx context.Context, req *schemas. 
return nil, newBifrostErrorFromMsg("model is required") } + if req.Input.EmbeddingInput == nil { + return nil, newBifrostErrorFromMsg("embedding_input is required") + } + // Try the primary provider first - primaryResult, primaryErr := bifrost.tryChatCompletion(req, ctx) + primaryResult, primaryErr := bifrost.tryEmbedding(req, ctx) if primaryErr == nil { return primaryResult, nil } @@ -726,7 +758,6 @@ func (bifrost *Bifrost) ChatCompletionRequest(ctx context.Context, req *schemas. } // If primary provider failed and we have fallbacks, try them in order - // This includes both regular provider errors and plugin short-circuit errors with AllowFallbacks=true/nil if len(req.Fallbacks) > 0 { for _, fallback := range req.Fallbacks { // Check if we have config for this fallback provider @@ -742,7 +773,7 @@ func (bifrost *Bifrost) ChatCompletionRequest(ctx context.Context, req *schemas. fallbackReq.Model = fallback.Model // Try the fallback provider - result, fallbackErr := bifrost.tryChatCompletion(&fallbackReq, ctx) + result, fallbackErr := bifrost.tryEmbedding(&fallbackReq, ctx) if fallbackErr == nil { bifrost.logger.Info(fmt.Sprintf("Successfully used fallback provider %s with model %s", fallback.Provider, fallback.Model)) return result, nil @@ -759,16 +790,22 @@ func (bifrost *Bifrost) ChatCompletionRequest(ctx context.Context, req *schemas. return nil, primaryErr } -// tryChatCompletion attempts a chat completion request with a single provider. -// This is a helper function used by ChatCompletionRequest to handle individual provider attempts. -func (bifrost *Bifrost) tryChatCompletion(req *schemas.BifrostRequest, ctx context.Context) (*schemas.BifrostResponse, *schemas.BifrostError) { +// tryEmbedding attempts an embedding request with a single provider. +// This is a helper function used by EmbeddingRequest to handle individual provider attempts. 
+func (bifrost *Bifrost) tryEmbedding(req *schemas.BifrostRequest, ctx context.Context) (*schemas.BifrostResponse, *schemas.BifrostError) { + return bifrost.tryRequest(req, ctx, EmbeddingRequest, false) +} + +// tryRequest is a generic function that handles common request processing logic +// It consolidates queue setup, plugin pipeline execution, enqueue logic, and response handling +func (bifrost *Bifrost) tryRequest(req *schemas.BifrostRequest, ctx context.Context, requestType RequestType, includeMCP bool) (*schemas.BifrostResponse, *schemas.BifrostError) { queue, err := bifrost.getProviderQueue(req.Provider) if err != nil { return nil, newBifrostError(err) } - // Add MCP tools to request if MCP is configured - if bifrost.mcpManager != nil { + // Add MCP tools to request if MCP is configured and requested + if includeMCP && bifrost.mcpManager != nil { req = bifrost.mcpManager.addMCPToolsToBifrostRequest(ctx, req) } @@ -796,7 +833,7 @@ func (bifrost *Bifrost) tryChatCompletion(req *schemas.BifrostRequest, ctx conte return nil, newBifrostErrorFromMsg("bifrost request after plugin hooks cannot be nil") } - msg := bifrost.getChannelMessage(*preReq, ChatCompletionRequest) + msg := bifrost.getChannelMessage(*preReq, requestType) msg.Context = ctx select { diff --git a/core/providers/anthropic.go b/core/providers/anthropic.go index 466a73a572..794a9fbbea 100644 --- a/core/providers/anthropic.go +++ b/core/providers/anthropic.go @@ -694,3 +694,8 @@ func parseAnthropicResponse(response *AnthropicChatResponse, bifrostResponse *sc return bifrostResponse, nil } + +// Embedding is not supported by the Anthropic provider. 
+func (provider *AnthropicProvider) Embedding(ctx context.Context, model, key string, input *schemas.EmbeddingInput, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { + return nil, newUnsupportedOperationError("embedding", "anthropic") +} diff --git a/core/providers/azure.go b/core/providers/azure.go index 5867778334..be4ec44b17 100644 --- a/core/providers/azure.go +++ b/core/providers/azure.go @@ -43,6 +43,20 @@ type AzureChatResponse struct { Usage schemas.LLMUsage `json:"usage"` // Token usage statistics } +// AzureEmbeddingResponse represents the response structure from Azure's embedding API. +type AzureEmbeddingResponse struct { + Object string `json:"object"` + Data []struct { + Object string `json:"object"` + Embedding interface{} `json:"embedding"` + Index int `json:"index"` + } `json:"data"` + Model string `json:"model"` + Usage schemas.LLMUsage `json:"usage"` + ID string `json:"id"` + SystemFingerprint *string `json:"system_fingerprint"` +} + // AzureError represents the error response structure from Azure's API. // It includes error code and message information. type AzureError struct { @@ -356,3 +370,111 @@ func (provider *AzureProvider) ChatCompletion(ctx context.Context, model, key st return bifrostResponse, nil } + +// Embedding generates embeddings for the given input text(s) using Azure OpenAI. +// The input can be either a single string or a slice of strings for batch embedding. +// Returns a BifrostResponse containing the embedding(s) and any error that occurred. 
+func (provider *AzureProvider) Embedding(ctx context.Context, model string, key string, input *schemas.EmbeddingInput, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { + if len(input.Texts) == 0 { + return nil, &schemas.BifrostError{ + IsBifrostError: true, + Error: schemas.ErrorField{Message: "no input text provided for embedding"}, + } + } + + // Prepare request body - Azure uses deployment-scoped URLs, so model is not needed in body + requestBody := map[string]interface{}{ + "input": input.Texts, + } + + // Merge any additional parameters + if params != nil { + if params.EncodingFormat != nil { + requestBody["encoding_format"] = *params.EncodingFormat + } + if params.Dimensions != nil { + requestBody["dimensions"] = *params.Dimensions + } + if params.User != nil { + requestBody["user"] = *params.User + } + requestBody = mergeConfig(requestBody, params.ExtraParams) + } + + responseBody, err := provider.completeRequest(ctx, requestBody, "embeddings", key, model) + if err != nil { + return nil, err + } + + // Parse response + var response AzureEmbeddingResponse + if err := json.Unmarshal(responseBody, &response); err != nil { + return nil, &schemas.BifrostError{ + IsBifrostError: true, + Error: schemas.ErrorField{ + Message: schemas.ErrProviderResponseUnmarshal, + Error: err, + }, + } + } + + bifrostResponse := &schemas.BifrostResponse{ + ID: response.ID, + Object: response.Object, + Model: response.Model, + Usage: response.Usage, + SystemFingerprint: response.SystemFingerprint, + ExtraFields: schemas.BifrostResponseExtraFields{ + Provider: schemas.Azure, + RawResponse: responseBody, + }, + } + + // Extract embeddings from response data + if len(response.Data) > 0 { + embeddings := make([][]float32, len(response.Data)) + for i, data := range response.Data { + switch v := data.Embedding.(type) { + case []float32: + embeddings[i] = v + case []float64: + // Direct conversion from []float64 to []float32 + floatArray := 
make([]float32, len(v)) + for j := range v { + floatArray[j] = float32(v[j]) + } + embeddings[i] = floatArray + case []interface{}: + // Fallback: element-by-element conversion for []interface{} + floatArray := make([]float32, len(v)) + for j := range v { + if num, ok := v[j].(float64); ok { + floatArray[j] = float32(num) + } else { + return nil, &schemas.BifrostError{ + IsBifrostError: true, + Error: schemas.ErrorField{ + Message: fmt.Sprintf("unsupported number type in embedding array: %T", v[j]), + }, + } + } + } + embeddings[i] = floatArray + default: + return nil, &schemas.BifrostError{ + IsBifrostError: true, + Error: schemas.ErrorField{ + Message: fmt.Sprintf("unsupported embedding type: %T", data.Embedding), + }, + } + } + } + bifrostResponse.Embedding = embeddings + } + + if params != nil { + bifrostResponse.ExtraFields.Params = *params + } + + return bifrostResponse, nil +} diff --git a/core/providers/bedrock.go b/core/providers/bedrock.go index 169e2a53e1..bba35da713 100644 --- a/core/providers/bedrock.go +++ b/core/providers/bedrock.go @@ -10,6 +10,7 @@ import ( "errors" "fmt" "io" + "maps" "net/http" "net/url" "strings" @@ -1091,3 +1092,160 @@ func signAWSRequest(req *http.Request, accessKey, secretKey string, sessionToken return nil } + +// Embedding generates embeddings for the given input text(s) using Amazon Bedrock. +// Supports Titan and Cohere embedding models. Returns a BifrostResponse containing the embedding(s) and any error that occurred. 
+func (provider *BedrockProvider) Embedding(ctx context.Context, model string, key string, input *schemas.EmbeddingInput, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { + switch { + case strings.HasPrefix(model, "amazon.titan-embed-text"): + return provider.handleTitanEmbedding(ctx, model, key, input, params) + case strings.HasPrefix(model, "cohere.embed"): + return provider.handleCohereEmbedding(ctx, model, key, input, params) + default: + return nil, &schemas.BifrostError{ + IsBifrostError: false, + Error: schemas.ErrorField{Message: "embedding is not supported for this Bedrock model"}, + } + } +} + +// handleTitanEmbedding handles embedding requests for Amazon Titan models. +func (provider *BedrockProvider) handleTitanEmbedding(ctx context.Context, model string, key string, input *schemas.EmbeddingInput, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { + // Titan Text Embeddings V1/V2 - only supports single text input + if len(input.Texts) == 0 { + return nil, &schemas.BifrostError{ + IsBifrostError: true, + Error: schemas.ErrorField{Message: "no input text provided for embedding"}, + } + } + if len(input.Texts) > 1 { + return nil, &schemas.BifrostError{ + IsBifrostError: false, + Error: schemas.ErrorField{Message: "Amazon Titan embedding models support only single text input, received multiple texts"}, + } + } + + requestBody := map[string]interface{}{ + "inputText": input.Texts[0], + } + + if params != nil { + // Titan models do not support the dimensions parameter - they have fixed dimensions + if params.Dimensions != nil { + return nil, &schemas.BifrostError{ + IsBifrostError: false, + Error: schemas.ErrorField{Message: "Amazon Titan embedding models do not support custom dimensions parameter"}, + } + } + if params.ExtraParams != nil { + for k, v := range params.ExtraParams { + requestBody[k] = v + } + } + } + + // Properly escape model name for URL path to ensure AWS SIGv4 signing 
works correctly + path := url.PathEscape(model) + "/invoke" + rawResponse, err := provider.completeRequest(ctx, requestBody, path, key) + if err != nil { + return nil, err + } + + // Parse Titan response from raw message + var titanResp struct { + Embedding []float32 `json:"embedding"` + InputTextTokenCount int `json:"inputTextTokenCount"` + } + if err := json.Unmarshal(rawResponse, &titanResp); err != nil { + return nil, &schemas.BifrostError{ + IsBifrostError: true, + Error: schemas.ErrorField{ + Message: "error parsing Titan embedding response", + Error: err, + }, + } + } + + bifrostResponse := &schemas.BifrostResponse{ + Embedding: [][]float32{titanResp.Embedding}, + Model: model, + Usage: schemas.LLMUsage{ + PromptTokens: titanResp.InputTextTokenCount, + TotalTokens: titanResp.InputTextTokenCount, + }, + ExtraFields: schemas.BifrostResponseExtraFields{ + Provider: schemas.Bedrock, + RawResponse: rawResponse, + }, + } + + if params != nil { + bifrostResponse.ExtraFields.Params = *params + } + + return bifrostResponse, nil +} + +// handleCohereEmbedding handles embedding requests for Cohere models on Bedrock. 
+func (provider *BedrockProvider) handleCohereEmbedding(ctx context.Context, model string, key string, input *schemas.EmbeddingInput, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { + if len(input.Texts) == 0 { + return nil, &schemas.BifrostError{ + IsBifrostError: true, + Error: schemas.ErrorField{Message: "no input text provided for embedding"}, + } + } + + requestBody := map[string]interface{}{ + "texts": input.Texts, + "input_type": "search_document", + } + if params != nil && params.ExtraParams != nil { + maps.Copy(requestBody, params.ExtraParams) + } + + // Properly escape model name for URL path to ensure AWS SIGv4 signing works correctly + path := url.PathEscape(model) + "/invoke" + rawResponse, err := provider.completeRequest(ctx, requestBody, path, key) + if err != nil { + return nil, err + } + + // Parse Cohere response + var cohereResp struct { + Embeddings [][]float32 `json:"embeddings"` + ID string `json:"id"` + Texts []string `json:"texts"` + } + if err := json.Unmarshal(rawResponse, &cohereResp); err != nil { + return nil, &schemas.BifrostError{ + IsBifrostError: true, + Error: schemas.ErrorField{ + Message: "error parsing Cohere embedding response", + Error: err, + }, + } + } + + // Calculate token usage based on input texts (approximation since Cohere doesn't provide this) + totalInputTokens := approximateTokenCount(input.Texts) + + bifrostResponse := &schemas.BifrostResponse{ + Embedding: cohereResp.Embeddings, + ID: cohereResp.ID, + Model: model, + Usage: schemas.LLMUsage{ + PromptTokens: totalInputTokens, + TotalTokens: totalInputTokens, + }, + ExtraFields: schemas.BifrostResponseExtraFields{ + Provider: schemas.Bedrock, + RawResponse: rawResponse, + }, + } + + if params != nil { + bifrostResponse.ExtraFields.Params = *params + } + + return bifrostResponse, nil +} diff --git a/core/providers/cohere.go b/core/providers/cohere.go index c8b17b3f74..89e83e056e 100644 --- a/core/providers/cohere.go +++ 
b/core/providers/cohere.go @@ -93,6 +93,14 @@ type CohereError struct { Message string `json:"message"` // Error message } +// CohereEmbeddingResponse represents the response from Cohere's embedding API. +type CohereEmbeddingResponse struct { + ID string `json:"id"` // Unique identifier for the embedding request + Embeddings struct { + Float [][]float32 `json:"float"` // Array of float embeddings, one for each input text + } `json:"embeddings"` // Embeddings in the response +} + // CohereProvider implements the Provider interface for Cohere. type CohereProvider struct { logger schemas.Logger // Logger for provider operations @@ -139,12 +147,7 @@ func (provider *CohereProvider) GetProviderKey() schemas.ModelProvider { // TextCompletion is not supported by the Cohere provider. // Returns an error indicating that text completion is not supported. func (provider *CohereProvider) TextCompletion(ctx context.Context, model, key, text string, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { - return nil, &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: "text completion is not supported by cohere provider", - }, - } + return nil, newUnsupportedOperationError("text completion", "cohere") } // ChatCompletion performs a chat completion request to the Cohere API. @@ -537,3 +540,139 @@ func convertChatHistory(history []struct { } return &converted } + +// Embedding generates embeddings for the given input text(s) using the Cohere API. +// Supports Cohere's embedding models and returns a BifrostResponse containing the embedding(s). 
+func (provider *CohereProvider) Embedding(ctx context.Context, model string, key string, input *schemas.EmbeddingInput, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { + if len(input.Texts) == 0 { + return nil, &schemas.BifrostError{ + IsBifrostError: true, + Error: schemas.ErrorField{Message: "no input text provided for embedding"}, + } + } + + // Prepare request body with default values + requestBody := map[string]interface{}{ + "texts": input.Texts, + "model": model, + "input_type": "search_document", // Default input type - can be overridden via ExtraParams + "embedding_types": []string{"float"}, // Default to float embeddings + } + + // Apply additional parameters if provided + if params != nil { + // Validate encoding format - Cohere API supports float, int8, uint8, binary, ubinary, but our provider only implements float + if params.EncodingFormat != nil { + if *params.EncodingFormat != "float" { + return nil, &schemas.BifrostError{ + IsBifrostError: false, + Error: schemas.ErrorField{ + Message: fmt.Sprintf("Cohere provider currently only supports 'float' encoding format, received: %s", *params.EncodingFormat), + }, + } + } + // Override default with the specified format + requestBody["embedding_types"] = []string{*params.EncodingFormat} + } + + // Merge extra parameters - this allows overriding input_type and other parameters + if params.ExtraParams != nil { + for k, v := range params.ExtraParams { + requestBody[k] = v + } + } + } + + // Marshal request body + jsonBody, err := json.Marshal(requestBody) + if err != nil { + return nil, &schemas.BifrostError{ + IsBifrostError: true, + Error: schemas.ErrorField{ + Message: schemas.ErrProviderJSONMarshaling, + Error: err, + }, + } + } + + // Create request + req := fasthttp.AcquireRequest() + resp := fasthttp.AcquireResponse() + defer fasthttp.ReleaseRequest(req) + defer fasthttp.ReleaseResponse(resp) + + // Set any extra headers from network config + setExtraHeaders(req, 
provider.networkConfig.ExtraHeaders, nil) + + req.SetRequestURI(provider.networkConfig.BaseURL + "/v2/embed") + req.Header.SetMethod("POST") + req.Header.SetContentType("application/json") + req.Header.Set("Authorization", "Bearer "+key) + + req.SetBody(jsonBody) + + // Make request + bifrostErr := makeRequestWithContext(ctx, provider.client, req, resp) + if bifrostErr != nil { + return nil, bifrostErr + } + + // Handle error response + if resp.StatusCode() != fasthttp.StatusOK { + provider.logger.Debug(fmt.Sprintf("error from cohere embedding provider: %s", string(resp.Body()))) + + var errorResp CohereError + bifrostErr := handleProviderAPIError(resp, &errorResp) + bifrostErr.Error.Message = errorResp.Message + + return nil, bifrostErr + } + + // Parse response + var cohereResp CohereEmbeddingResponse + if err := json.Unmarshal(resp.Body(), &cohereResp); err != nil { + return nil, &schemas.BifrostError{ + IsBifrostError: true, + Error: schemas.ErrorField{ + Message: "error parsing Cohere embedding response", + Error: err, + }, + } + } + + // Parse raw response for consistent format + var rawResponse interface{} + if err := json.Unmarshal(resp.Body(), &rawResponse); err != nil { + return nil, &schemas.BifrostError{ + IsBifrostError: true, + Error: schemas.ErrorField{ + Message: "error parsing raw response for Cohere embedding", + Error: err, + }, + } + } + + // Calculate token usage approximation (since Cohere doesn't provide this for embeddings) + totalInputTokens := approximateTokenCount(input.Texts) + + // Create BifrostResponse + bifrostResponse := &schemas.BifrostResponse{ + ID: cohereResp.ID, + Embedding: cohereResp.Embeddings.Float, + Model: model, + Usage: schemas.LLMUsage{ + PromptTokens: totalInputTokens, + TotalTokens: totalInputTokens, + }, + ExtraFields: schemas.BifrostResponseExtraFields{ + Provider: schemas.Cohere, + RawResponse: rawResponse, + }, + } + + if params != nil { + bifrostResponse.ExtraFields.Params = *params + } + + return 
bifrostResponse, nil +} diff --git a/core/providers/mistral.go b/core/providers/mistral.go index 6bc34da683..26b645ece5 100644 --- a/core/providers/mistral.go +++ b/core/providers/mistral.go @@ -25,6 +25,20 @@ type MistralResponse struct { Usage schemas.LLMUsage `json:"usage"` } +// MistralEmbeddingResponse represents the response structure from Mistral's embedding API. +type MistralEmbeddingResponse struct { + Object string `json:"object"` + Data []struct { + Object string `json:"object"` + Embedding []float32 `json:"embedding"` + Index int `json:"index"` + } `json:"data"` + Model string `json:"model"` + Usage schemas.LLMUsage `json:"usage"` + ID string `json:"id"` + SystemFingerprint *string `json:"system_fingerprint"` +} + // mistralResponsePool provides a pool for Mistral response objects. var mistralResponsePool = sync.Pool{ New: func() interface{} { @@ -93,12 +107,7 @@ func (provider *MistralProvider) GetProviderKey() schemas.ModelProvider { // TextCompletion is not supported by the Mistral provider. func (provider *MistralProvider) TextCompletion(ctx context.Context, model, key, text string, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { - return nil, &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: "text completion is not supported by mistral provider", - }, - } + return nil, newUnsupportedOperationError("text completion", "mistral") } // ChatCompletion performs a chat completion request to the Mistral API. @@ -185,3 +194,145 @@ func (provider *MistralProvider) ChatCompletion(ctx context.Context, model, key return bifrostResponse, nil } + +// Embedding generates embeddings for the given input text(s) using the Mistral API. +// Supports Mistral's embedding models and returns a BifrostResponse containing the embedding(s). 
+func (provider *MistralProvider) Embedding(ctx context.Context, model string, key string, input *schemas.EmbeddingInput, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { + if len(input.Texts) == 0 { + return nil, &schemas.BifrostError{ + IsBifrostError: false, + Error: schemas.ErrorField{Message: "no input text provided for embedding"}, + } + } + + // Prepare request body with base parameters + requestBody := map[string]interface{}{ + "model": model, + "input": input.Texts, + } + + // Merge any additional parameters + if params != nil { + // Validate encoding format - Mistral API supports multiple formats, but our provider only implements float + if params.EncodingFormat != nil { + if *params.EncodingFormat != "float" { + return nil, &schemas.BifrostError{ + IsBifrostError: false, + Error: schemas.ErrorField{ + Message: fmt.Sprintf("Mistral provider currently only supports 'float' encoding format, received: %s", *params.EncodingFormat), + }, + } + } + // Map to Mistral's parameter name + requestBody["output_dtype"] = *params.EncodingFormat + } + + // Map dimensions to Mistral's parameter name + if params.Dimensions != nil { + requestBody["output_dimension"] = *params.Dimensions + } + + // Merge any extra parameters + if params.ExtraParams != nil { + for k, v := range params.ExtraParams { + requestBody[k] = v + } + } + } + + jsonBody, err := json.Marshal(requestBody) + if err != nil { + return nil, &schemas.BifrostError{ + IsBifrostError: true, + Error: schemas.ErrorField{ + Message: schemas.ErrProviderJSONMarshaling, + Error: err, + }, + } + } + + // Create request + req := fasthttp.AcquireRequest() + resp := fasthttp.AcquireResponse() + defer fasthttp.ReleaseRequest(req) + defer fasthttp.ReleaseResponse(resp) + + // Set any extra headers from network config + setExtraHeaders(req, provider.networkConfig.ExtraHeaders, nil) + + req.SetRequestURI(provider.networkConfig.BaseURL + "/v1/embeddings") + req.Header.SetMethod("POST") + 
req.Header.SetContentType("application/json") + req.Header.Set("Authorization", "Bearer "+key) + + req.SetBody(jsonBody) + + // Make request + bifrostErr := makeRequestWithContext(ctx, provider.client, req, resp) + if bifrostErr != nil { + return nil, bifrostErr + } + + // Handle error response + if resp.StatusCode() != fasthttp.StatusOK { + provider.logger.Debug(fmt.Sprintf("error from mistral embedding provider: %s", string(resp.Body()))) + + var errorResp map[string]interface{} + bifrostErr := handleProviderAPIError(resp, &errorResp) + bifrostErr.Error.Message = fmt.Sprintf("Mistral embedding error: %v", errorResp) + return nil, bifrostErr + } + + // Parse response using json.RawMessage to avoid double parsing + var rawMessage json.RawMessage = resp.Body() + + // Parse into structured response + var mistralResp MistralEmbeddingResponse + if err := json.Unmarshal(rawMessage, &mistralResp); err != nil { + return nil, &schemas.BifrostError{ + IsBifrostError: true, + Error: schemas.ErrorField{ + Message: "error parsing Mistral embedding response", + Error: err, + }, + } + } + + // Parse raw response for consistent format + var rawResponse interface{} + if err := json.Unmarshal(rawMessage, &rawResponse); err != nil { + return nil, &schemas.BifrostError{ + IsBifrostError: true, + Error: schemas.ErrorField{ + Message: "error parsing raw response for Mistral embedding", + Error: err, + }, + } + } + + // Convert data to embeddings array + var embeddings [][]float32 + for _, data := range mistralResp.Data { + embeddings = append(embeddings, data.Embedding) + } + + // Create BifrostResponse + bifrostResponse := &schemas.BifrostResponse{ + ID: mistralResp.ID, + Object: mistralResp.Object, + Embedding: embeddings, + Model: mistralResp.Model, + Usage: mistralResp.Usage, + SystemFingerprint: mistralResp.SystemFingerprint, + ExtraFields: schemas.BifrostResponseExtraFields{ + Provider: schemas.Mistral, + RawResponse: rawResponse, + }, + } + + if params != nil { + 
bifrostResponse.ExtraFields.Params = *params + } + + return bifrostResponse, nil +} diff --git a/core/providers/ollama.go b/core/providers/ollama.go index bed89846dd..f00f764e6f 100644 --- a/core/providers/ollama.go +++ b/core/providers/ollama.go @@ -94,12 +94,7 @@ func (provider *OllamaProvider) GetProviderKey() schemas.ModelProvider { // TextCompletion is not supported by the Ollama provider. func (provider *OllamaProvider) TextCompletion(ctx context.Context, model, key, text string, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { - return nil, &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: "text completion is not supported by ollama provider", - }, - } + return nil, newUnsupportedOperationError("text completion", "ollama") } // ChatCompletion performs a chat completion request to the Ollama API. @@ -188,3 +183,8 @@ func (provider *OllamaProvider) ChatCompletion(ctx context.Context, model, key s return bifrostResponse, nil } + +// Embedding is not supported by the Ollama provider. +func (provider *OllamaProvider) Embedding(ctx context.Context, model string, key string, input *schemas.EmbeddingInput, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { + return nil, newUnsupportedOperationError("embedding", "ollama") +} diff --git a/core/providers/openai.go b/core/providers/openai.go index bd345e3c0b..9577d8c72f 100644 --- a/core/providers/openai.go +++ b/core/providers/openai.go @@ -4,7 +4,10 @@ package providers import ( "context" + "encoding/base64" + "encoding/binary" "fmt" + "math" "strings" "sync" "time" @@ -18,14 +21,19 @@ import ( // OpenAIResponse represents the response structure from the OpenAI API. // It includes completion choices, model information, and usage statistics. 
type OpenAIResponse struct { - ID string `json:"id"` // Unique identifier for the completion - Object string `json:"object"` // Type of completion (text.completion or chat.completion) - Choices []schemas.BifrostResponseChoice `json:"choices"` // Array of completion choices - Model string `json:"model"` // Model used for the completion - Created int `json:"created"` // Unix timestamp of completion creation - ServiceTier *string `json:"service_tier"` // Service tier used for the request - SystemFingerprint *string `json:"system_fingerprint"` // System fingerprint for the request - Usage schemas.LLMUsage `json:"usage"` // Token usage statistics + ID string `json:"id"` // Unique identifier for the completion + Object string `json:"object"` // Type of completion (text.completion, chat.completion, or embedding) + Choices []schemas.BifrostResponseChoice `json:"choices"` // Array of completion choices + Data []struct { // Embedding data + Object string `json:"object"` + Embedding any `json:"embedding"` + Index int `json:"index"` + } `json:"data,omitempty"` + Model string `json:"model"` // Model used for the completion + Created int `json:"created"` // Unix timestamp of completion creation + ServiceTier *string `json:"service_tier"` // Service tier used for the request + SystemFingerprint *string `json:"system_fingerprint"` // System fingerprint for the request + Usage schemas.LLMUsage `json:"usage"` // Token usage statistics } // OpenAIError represents the error response structure from the OpenAI API. @@ -111,12 +119,7 @@ func (provider *OpenAIProvider) GetProviderKey() schemas.ModelProvider { // TextCompletion is not supported by the OpenAI provider. // Returns an error indicating that text completion is not available. 
func (provider *OpenAIProvider) TextCompletion(ctx context.Context, model, key, text string, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { - return nil, &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: "text completion is not supported by openai provider", - }, - } + return nil, newUnsupportedOperationError("text completion", "openai") } // ChatCompletion performs a chat completion request to the OpenAI API. @@ -264,3 +267,192 @@ func prepareOpenAIChatRequest(messages []schemas.BifrostMessage, params *schemas return formattedMessages, preparedParams } + +// Embedding generates embeddings for the given input text(s). +// The input can be either a single string or a slice of strings for batch embedding. +// Returns a BifrostResponse containing the embedding(s) and any error that occurred. +func (provider *OpenAIProvider) Embedding(ctx context.Context, model string, key string, input *schemas.EmbeddingInput, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { + // Validate input texts are not empty + if len(input.Texts) == 0 { + return nil, &schemas.BifrostError{ + IsBifrostError: false, + Error: schemas.ErrorField{ + Message: "input texts cannot be empty", + }, + } + } + + // Prepare request body with base parameters + requestBody := map[string]interface{}{ + "model": model, + "input": input.Texts, + } + + // Merge any additional parameters + if params != nil { + // Map standard parameters + if params.EncodingFormat != nil { + requestBody["encoding_format"] = *params.EncodingFormat + } + if params.Dimensions != nil { + requestBody["dimensions"] = *params.Dimensions + } + if params.User != nil { + requestBody["user"] = *params.User + } + + // Merge any extra parameters + requestBody = mergeConfig(requestBody, params.ExtraParams) + } + + jsonBody, err := json.Marshal(requestBody) + if err != nil { + return nil, &schemas.BifrostError{ + IsBifrostError: true, + 
Error: schemas.ErrorField{ + Message: schemas.ErrProviderJSONMarshaling, + Error: err, + }, + } + } + + // Create request + req := fasthttp.AcquireRequest() + resp := fasthttp.AcquireResponse() + defer fasthttp.ReleaseRequest(req) + defer fasthttp.ReleaseResponse(resp) + + // Set any extra headers from network config + setExtraHeaders(req, provider.networkConfig.ExtraHeaders, nil) + + req.SetRequestURI(provider.networkConfig.BaseURL + "/v1/embeddings") + req.Header.SetMethod("POST") + req.Header.SetContentType("application/json") + req.Header.Set("Authorization", "Bearer "+key) + + req.SetBody(jsonBody) + + // Make request + bifrostErr := makeRequestWithContext(ctx, provider.client, req, resp) + if bifrostErr != nil { + return nil, bifrostErr + } + + // Handle error response + if resp.StatusCode() != fasthttp.StatusOK { + provider.logger.Debug(fmt.Sprintf("error from openai provider: %s", string(resp.Body()))) + + var errorResp OpenAIError + + bifrostErr := handleProviderAPIError(resp, &errorResp) + + if errorResp.EventID != "" { + bifrostErr.EventID = &errorResp.EventID + } + bifrostErr.Error.Type = &errorResp.Error.Type + bifrostErr.Error.Code = &errorResp.Error.Code + bifrostErr.Error.Message = errorResp.Error.Message + bifrostErr.Error.Param = errorResp.Error.Param + if errorResp.Error.EventID != "" { + bifrostErr.Error.EventID = &errorResp.Error.EventID + } + + return nil, bifrostErr + } + + // Parse response + var response OpenAIResponse + if err := json.Unmarshal(resp.Body(), &response); err != nil { + return nil, &schemas.BifrostError{ + IsBifrostError: true, + Error: schemas.ErrorField{ + Message: schemas.ErrProviderResponseUnmarshal, + Error: err, + }, + } + } + + // Create final response + bifrostResponse := &schemas.BifrostResponse{ + ID: response.ID, + Object: response.Object, + Model: response.Model, + Created: response.Created, + Usage: response.Usage, + ServiceTier: response.ServiceTier, + SystemFingerprint: response.SystemFingerprint, + 
ExtraFields: schemas.BifrostResponseExtraFields{ + Provider: schemas.OpenAI, + }, + } + + // Extract embeddings from response data + if len(response.Data) > 0 { + embeddings := make([][]float32, len(response.Data)) + for i, data := range response.Data { + switch v := data.Embedding.(type) { + case []float32: + embeddings[i] = v + case []interface{}: + // Convert []interface{} to []float32 + floatArray := make([]float32, len(v)) + for j := range v { + if num, ok := v[j].(float64); ok { + floatArray[j] = float32(num) + } else { + return nil, &schemas.BifrostError{ + IsBifrostError: true, + Error: schemas.ErrorField{ + Message: fmt.Sprintf("unsupported number type in embedding array: %T", v[j]), + }, + } + } + } + embeddings[i] = floatArray + case string: + // Decode base64 string into float32 array + decodedData, err := base64.StdEncoding.DecodeString(v) + if err != nil { + return nil, &schemas.BifrostError{ + IsBifrostError: true, + Error: schemas.ErrorField{ + Message: "failed to decode base64 embedding", + Error: err, + }, + } + } + + // Validate that decoded data length is divisible by 4 (size of float32) + const sizeOfFloat32 = 4 + if len(decodedData)%sizeOfFloat32 != 0 { + return nil, &schemas.BifrostError{ + IsBifrostError: true, + Error: schemas.ErrorField{ + Message: "malformed base64 embedding data: length not divisible by 4", + }, + } + } + + floats := make([]float32, len(decodedData)/sizeOfFloat32) + for i := 0; i < len(floats); i++ { + floats[i] = math.Float32frombits(binary.LittleEndian.Uint32(decodedData[i*4 : (i+1)*4])) + } + embeddings[i] = floats + default: + return nil, &schemas.BifrostError{ + IsBifrostError: true, + Error: schemas.ErrorField{ + Message: fmt.Sprintf("unsupported embedding type: %T", data.Embedding), + }, + } + } + } + bifrostResponse.Embedding = embeddings + } + + if params != nil { + bifrostResponse.ExtraFields.Params = *params + } + + return bifrostResponse, nil +} diff --git a/core/providers/utils.go b/core/providers/utils.go 
index 94712f91dc..9899755f60 100644 --- a/core/providers/utils.go +++ b/core/providers/utils.go @@ -541,3 +541,37 @@ func isLikelyBase64(s string) bool { // Check if it contains only base64 characters using pre-compiled regex return base64Regex.MatchString(cleanData) } + +// newUnsupportedOperationError creates a standardized error for unsupported operations. +// This helper reduces code duplication across providers that don't support certain operations. +func newUnsupportedOperationError(operation string, providerName string) *schemas.BifrostError { + return &schemas.BifrostError{ + IsBifrostError: false, + Error: schemas.ErrorField{ + Message: fmt.Sprintf("%s is not supported by %s provider", operation, providerName), + }, + } +} + +// approximateTokenCount provides a rough approximation of token count for text. +// WARNING: This is a best-effort approximation using 1 token per 4 characters. +// This heuristic is particularly inaccurate for: +// - Non-ASCII text (multi-byte characters) +// - Short texts +// - Different languages and tokenization methods +// - Various model-specific tokenizers +// +// The actual token count may vary significantly based on tokenization method, +// language, and text structure. Consider omitting token metrics when precise +// counts are unavailable to avoid misleading usage information. +// +// For precise token usage tracking, implement a proper tokenizer that matches +// the model's tokenization method. 
+func approximateTokenCount(texts []string) int { + totalInputTokens := 0 + for _, text := range texts { + // Rough approximation: 1 token per 4 characters + totalInputTokens += len(text) / 4 + } + return totalInputTokens +} diff --git a/core/providers/vertex.go b/core/providers/vertex.go index 6dc6ac9967..c739f00c38 100644 --- a/core/providers/vertex.go +++ b/core/providers/vertex.go @@ -80,12 +80,7 @@ func (provider *VertexProvider) GetProviderKey() schemas.ModelProvider { // TextCompletion is not supported by the Vertex provider. // Returns an error indicating that text completion is not available. func (provider *VertexProvider) TextCompletion(ctx context.Context, model, key, text string, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { - return nil, &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: "text completion is not supported by vertex provider", - }, - } + return nil, newUnsupportedOperationError("text completion", "vertex") } // ChatCompletion performs a chat completion request to the Vertex API. @@ -307,3 +302,8 @@ func (provider *VertexProvider) ChatCompletion(ctx context.Context, model, key s return bifrostResponse, nil } } + +// Embedding is not supported by the Vertex provider. +func (provider *VertexProvider) Embedding(ctx context.Context, model string, key string, input *schemas.EmbeddingInput, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { + return nil, newUnsupportedOperationError("embedding", "vertex") +} diff --git a/core/schemas/bifrost.go b/core/schemas/bifrost.go index 64c8e61e31..a6274deca6 100644 --- a/core/schemas/bifrost.go +++ b/core/schemas/bifrost.go @@ -50,14 +50,20 @@ const ( //* Request Structs // RequestInput represents the input for a model request, which can be either -// a text completion or a chat completion, but either one must be provided. +// a text completion, a chat completion, or an embedding request. 
type RequestInput struct { TextCompletionInput *string `json:"text_completion_input,omitempty"` ChatCompletionInput *[]BifrostMessage `json:"chat_completion_input,omitempty"` + EmbeddingInput *EmbeddingInput `json:"embedding_input,omitempty"` +} + +// EmbeddingInput represents the input for an embedding request. +type EmbeddingInput struct { + Texts []string `json:"texts"` } // BifrostRequest represents a request to be processed by Bifrost. -// It must be provided when calling the Bifrost for text completion or chat completion. +// It must be provided when calling the Bifrost for text completion, chat completion, or embedding. // It contains the model identifier, input data, and parameters for the request. type BifrostRequest struct { Provider ModelProvider `json:"provider"` @@ -91,6 +97,9 @@ type ModelParameters struct { PresencePenalty *float64 `json:"presence_penalty,omitempty"` // Penalizes repeated tokens FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` // Penalizes frequent tokens ParallelToolCalls *bool `json:"parallel_tool_calls,omitempty"` // Enables parallel tool calls + EncodingFormat *string `json:"encoding_format,omitempty"` // Format for embedding output (e.g., "float", "base64") + Dimensions *int `json:"dimensions,omitempty"` // Number of dimensions for embedding output + User *string `json:"user,omitempty"` // User identifier for tracking // Dynamic parameters that can be provider-specific, they are directly // added to the request as is. ExtraParams map[string]interface{} `json:"-"` @@ -289,7 +298,7 @@ type ImageURLStruct struct { // BifrostResponse represents the complete result from any bifrost request. 
type BifrostResponse struct { ID string `json:"id,omitempty"` - Object string `json:"object,omitempty"` // text.completion or chat.completion + Object string `json:"object,omitempty"` // text.completion, chat.completion, or embedding Choices []BifrostResponseChoice `json:"choices,omitempty"` Model string `json:"model,omitempty"` Created int `json:"created,omitempty"` // The Unix timestamp (in seconds). @@ -297,6 +306,7 @@ type BifrostResponse struct { SystemFingerprint *string `json:"system_fingerprint,omitempty"` Usage LLMUsage `json:"usage,omitempty"` ExtraFields BifrostResponseExtraFields `json:"extra_fields"` + Embedding [][]float32 `json:"data,omitempty"` // Maps to "data" field in provider responses (e.g., OpenAI embedding format) } // LLMUsage represents token usage information diff --git a/core/schemas/provider.go b/core/schemas/provider.go index 800178063f..53fcc3bba5 100644 --- a/core/schemas/provider.go +++ b/core/schemas/provider.go @@ -158,4 +158,6 @@ type Provider interface { TextCompletion(ctx context.Context, model, key, text string, params *ModelParameters) (*BifrostResponse, *BifrostError) // ChatCompletion performs a chat completion request ChatCompletion(ctx context.Context, model, key string, messages []BifrostMessage, params *ModelParameters) (*BifrostResponse, *BifrostError) + // Embedding performs an embedding request + Embedding(ctx context.Context, model string, key string, input *EmbeddingInput, params *ModelParameters) (*BifrostResponse, *BifrostError) }