Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 24 additions & 99 deletions core/providers/azure.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,32 +31,6 @@ type AzureTextResponse struct {
Usage schemas.LLMUsage `json:"usage"` // Token usage statistics
}

// AzureChatResponse represents the response structure from Azure's chat completion API.
// It includes completion choices, model information, and usage statistics.
type AzureChatResponse struct {
ID string `json:"id"` // Unique identifier for the completion
Object string `json:"object"` // Type of completion (always "chat.completion")
Choices []schemas.BifrostResponseChoice `json:"choices"` // Array of completion choices
Model string `json:"model"` // Model used for the completion
Created int `json:"created"` // Unix timestamp of completion creation
SystemFingerprint *string `json:"system_fingerprint"` // System fingerprint for the request
Usage schemas.LLMUsage `json:"usage"` // Token usage statistics
}

// AzureEmbeddingResponse represents the response structure from Azure's embedding API.
type AzureEmbeddingResponse struct {
Object string `json:"object"`
Data []struct {
Object string `json:"object"`
Embedding interface{} `json:"embedding"`
Index int `json:"index"`
} `json:"data"`
Model string `json:"model"`
Usage schemas.LLMUsage `json:"usage"`
ID string `json:"id"`
SystemFingerprint *string `json:"system_fingerprint"`
}

// AzureError represents the error response structure from Azure's API.
// It includes error code and message information.
type AzureError struct {
Expand All @@ -79,19 +53,19 @@ var azureTextCompletionResponsePool = sync.Pool{
// azureChatResponsePool provides a pool for Azure chat response objects.
var azureChatResponsePool = sync.Pool{
New: func() interface{} {
return &AzureChatResponse{}
return &schemas.BifrostResponse{}
},
}

// acquireAzureChatResponse gets an Azure chat response from the pool and resets it.
func acquireAzureChatResponse() *AzureChatResponse {
resp := azureChatResponsePool.Get().(*AzureChatResponse)
*resp = AzureChatResponse{} // Reset the struct
func acquireAzureChatResponse() *schemas.BifrostResponse {
resp := azureChatResponsePool.Get().(*schemas.BifrostResponse)
*resp = schemas.BifrostResponse{} // Reset the struct
return resp
}

// releaseAzureChatResponse returns an Azure chat response to the pool.
func releaseAzureChatResponse(resp *AzureChatResponse) {
func releaseAzureChatResponse(resp *schemas.BifrostResponse) {
if resp != nil {
azureChatResponsePool.Put(resp)
}
Expand Down Expand Up @@ -139,7 +113,7 @@ func NewAzureProvider(config *schemas.ProviderConfig, logger schemas.Logger) (*A

// Pre-warm response pools
for range config.ConcurrencyAndBufferSize.Concurrency {
azureChatResponsePool.Put(&AzureChatResponse{})
azureChatResponsePool.Put(&schemas.BifrostResponse{})
azureTextCompletionResponsePool.Put(&AzureTextResponse{})

}
Expand Down Expand Up @@ -342,39 +316,24 @@ func (provider *AzureProvider) ChatCompletion(ctx context.Context, model string,
return nil, bifrostErr
}

// Create final response
bifrostResponse := &schemas.BifrostResponse{
ID: response.ID,
Choices: response.Choices,
Model: response.Model,
Created: response.Created,
SystemFingerprint: response.SystemFingerprint,
Usage: &response.Usage,
ExtraFields: schemas.BifrostResponseExtraFields{
Provider: schemas.Azure,
},
}
response.ExtraFields.Provider = schemas.Azure

// Set raw response if enabled
if provider.sendBackRawResponse {
bifrostResponse.ExtraFields.RawResponse = rawResponse
response.ExtraFields.RawResponse = rawResponse
}

if params != nil {
bifrostResponse.ExtraFields.Params = *params
response.ExtraFields.Params = *params
}

return bifrostResponse, nil
return response, nil
}

// Embedding generates embeddings for the given input text(s) using Azure OpenAI.
// The input can be either a single string or a slice of strings for batch embedding.
// Returns a BifrostResponse containing the embedding(s) and any error that occurred.
func (provider *AzureProvider) Embedding(ctx context.Context, model string, key schemas.Key, input *schemas.EmbeddingInput, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) {
if len(input.Texts) == 0 {
return nil, newBifrostOperationError("no input text provided for embedding", nil, schemas.Azure)
}

// Prepare request body - Azure uses deployment-scoped URLs, so model is not needed in body
requestBody := map[string]interface{}{
"input": input.Texts,
Expand All @@ -399,61 +358,27 @@ func (provider *AzureProvider) Embedding(ctx context.Context, model string, key
return nil, err
}

// Parse response
var response AzureEmbeddingResponse
if err := sonic.Unmarshal(responseBody, &response); err != nil {
return nil, newBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, schemas.Azure)
}
// Pre-allocate response structs from pools
response := acquireAzureChatResponse()
defer releaseAzureChatResponse(response)

bifrostResponse := &schemas.BifrostResponse{
ID: response.ID,
Object: response.Object,
Model: response.Model,
Usage: &response.Usage,
SystemFingerprint: response.SystemFingerprint,
ExtraFields: schemas.BifrostResponseExtraFields{
Provider: schemas.Azure,
RawResponse: responseBody,
},
// Use enhanced response handler with pre-allocated response
rawResponse, bifrostErr := handleProviderResponse(responseBody, response, provider.sendBackRawResponse)
if bifrostErr != nil {
return nil, bifrostErr
}

// Extract embeddings from response data
if len(response.Data) > 0 {
embeddings := make([][]float32, len(response.Data))
for i, data := range response.Data {
switch v := data.Embedding.(type) {
case []float32:
embeddings[i] = v
case []float64:
// Direct conversion from []float64 to []float32
floatArray := make([]float32, len(v))
for j := range v {
floatArray[j] = float32(v[j])
}
embeddings[i] = floatArray
case []interface{}:
// Fallback: element-by-element conversion for []interface{}
floatArray := make([]float32, len(v))
for j := range v {
if num, ok := v[j].(float64); ok {
floatArray[j] = float32(num)
} else {
return nil, newBifrostOperationError(fmt.Sprintf("unsupported number type in embedding array: %T", v[j]), nil, schemas.Azure)
}
}
embeddings[i] = floatArray
default:
return nil, newBifrostOperationError(fmt.Sprintf("unsupported embedding type: %T", data.Embedding), nil, schemas.Azure)
}
}
bifrostResponse.Embedding = embeddings
}
response.ExtraFields.Provider = schemas.Azure

if params != nil {
bifrostResponse.ExtraFields.Params = *params
response.ExtraFields.Params = *params
}

return bifrostResponse, nil
if provider.sendBackRawResponse {
response.ExtraFields.RawResponse = rawResponse
}

return response, nil
}

// ChatCompletionStream performs a streaming chat completion request to Azure's OpenAI API.
Expand Down
40 changes: 23 additions & 17 deletions core/providers/bedrock.go
Original file line number Diff line number Diff line change
Expand Up @@ -1130,14 +1130,6 @@ func (provider *BedrockProvider) Embedding(ctx context.Context, model string, ke

// handleTitanEmbedding handles embedding requests for Amazon Titan models.
func (provider *BedrockProvider) handleTitanEmbedding(ctx context.Context, model string, key string, input *schemas.EmbeddingInput, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) {
// Titan Text Embeddings V1/V2 - only supports single text input
if len(input.Texts) == 0 {
return nil, newConfigurationError("no input text provided for embedding", schemas.Bedrock)
}
if len(input.Texts) > 1 {
return nil, newConfigurationError("Amazon Titan embedding models support only single text input, received multiple texts", schemas.Bedrock)
}

requestBody := map[string]interface{}{
"inputText": input.Texts[0],
}
Expand Down Expand Up @@ -1171,8 +1163,17 @@ func (provider *BedrockProvider) handleTitanEmbedding(ctx context.Context, model
}

bifrostResponse := &schemas.BifrostResponse{
Embedding: [][]float32{titanResp.Embedding},
Model: model,
Object: "list",
Data: []schemas.BifrostEmbedding{
{
Index: 0,
Object: "embedding",
Embedding: schemas.BifrostEmbeddingResponse{
Embedding2DArray: &[][]float32{titanResp.Embedding},
},
},
},
Model: model,
Usage: &schemas.LLMUsage{
PromptTokens: titanResp.InputTextTokenCount,
TotalTokens: titanResp.InputTextTokenCount,
Expand All @@ -1192,10 +1193,6 @@ func (provider *BedrockProvider) handleTitanEmbedding(ctx context.Context, model

// handleCohereEmbedding handles embedding requests for Cohere models on Bedrock.
func (provider *BedrockProvider) handleCohereEmbedding(ctx context.Context, model string, key string, input *schemas.EmbeddingInput, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) {
if len(input.Texts) == 0 {
return nil, newConfigurationError("no input text provided for embedding", schemas.Bedrock)
}

requestBody := map[string]interface{}{
"texts": input.Texts,
"input_type": "search_document",
Expand Down Expand Up @@ -1225,9 +1222,18 @@ func (provider *BedrockProvider) handleCohereEmbedding(ctx context.Context, mode
totalInputTokens := approximateTokenCount(input.Texts)

bifrostResponse := &schemas.BifrostResponse{
Embedding: cohereResp.Embeddings,
ID: cohereResp.ID,
Model: model,
Object: "list",
Data: []schemas.BifrostEmbedding{
{
Index: 0,
Object: "embedding",
Embedding: schemas.BifrostEmbeddingResponse{
Embedding2DArray: &cohereResp.Embeddings,
},
},
},
ID: cohereResp.ID,
Model: model,
Usage: &schemas.LLMUsage{
PromptTokens: totalInputTokens,
TotalTokens: totalInputTokens,
Expand Down
19 changes: 12 additions & 7 deletions core/providers/cohere.go
Original file line number Diff line number Diff line change
Expand Up @@ -596,10 +596,6 @@ func convertChatHistory(history []struct {
// Embedding generates embeddings for the given input text(s) using the Cohere API.
// Supports Cohere's embedding models and returns a BifrostResponse containing the embedding(s).
func (provider *CohereProvider) Embedding(ctx context.Context, model string, key schemas.Key, input *schemas.EmbeddingInput, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) {
if len(input.Texts) == 0 {
return nil, newConfigurationError("no input text provided for embedding", schemas.Cohere)
}

// Prepare request body with default values
requestBody := map[string]interface{}{
"texts": input.Texts,
Expand Down Expand Up @@ -683,9 +679,18 @@ func (provider *CohereProvider) Embedding(ctx context.Context, model string, key

// Create BifrostResponse
bifrostResponse := &schemas.BifrostResponse{
ID: cohereResp.ID,
Embedding: cohereResp.Embeddings.Float,
Model: model,
ID: cohereResp.ID,
Object: "list",
Data: []schemas.BifrostEmbedding{
{
Index: 0,
Object: "embedding",
Embedding: schemas.BifrostEmbeddingResponse{
Embedding2DArray: &cohereResp.Embeddings.Float,
},
},
},
Model: model,
Usage: &schemas.LLMUsage{
PromptTokens: totalInputTokens,
TotalTokens: totalInputTokens,
Expand Down
63 changes: 14 additions & 49 deletions core/providers/mistral.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,6 @@ import (
"github.com/valyala/fasthttp"
)

// MistralEmbeddingResponse represents the response structure from Mistral's embedding API.
type MistralEmbeddingResponse struct {
Object string `json:"object"`
Data []struct {
Object string `json:"object"`
Embedding []float32 `json:"embedding"`
Index int `json:"index"`
} `json:"data"`
Model string `json:"model"`
Usage schemas.LLMUsage `json:"usage"`
ID string `json:"id"`
SystemFingerprint *string `json:"system_fingerprint"`
}

// mistralResponsePool provides a pool for Mistral response objects.
var mistralResponsePool = sync.Pool{
New: func() interface{} {
Expand Down Expand Up @@ -183,10 +169,6 @@ func (provider *MistralProvider) ChatCompletion(ctx context.Context, model strin
// Embedding generates embeddings for the given input text(s) using the Mistral API.
// Supports Mistral's embedding models and returns a BifrostResponse containing the embedding(s).
func (provider *MistralProvider) Embedding(ctx context.Context, model string, key schemas.Key, input *schemas.EmbeddingInput, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) {
if len(input.Texts) == 0 {
return nil, newConfigurationError("no input text provided for embedding", schemas.Mistral)
}

// Prepare request body with base parameters
requestBody := map[string]interface{}{
"model": model,
Expand Down Expand Up @@ -254,46 +236,29 @@ func (provider *MistralProvider) Embedding(ctx context.Context, model string, ke
return nil, bifrostErr
}

// Parse response using sonic.RawMessage to avoid double parsing
rawMessage := resp.Body()
responseBody := resp.Body()

// Parse into structured response
var mistralResp MistralEmbeddingResponse
if err := sonic.Unmarshal(rawMessage, &mistralResp); err != nil {
return nil, newBifrostOperationError("error parsing Mistral embedding response", err, schemas.Mistral)
}
// Pre-allocate response structs from pools
response := acquireMistralResponse()
defer releaseMistralResponse(response)

// Parse raw response for consistent format
var rawResponse interface{}
if err := sonic.Unmarshal(rawMessage, &rawResponse); err != nil {
return nil, newBifrostOperationError("error parsing raw response for Mistral embedding", err, schemas.Mistral)
// Use enhanced response handler with pre-allocated response
rawResponse, bifrostErr := handleProviderResponse(responseBody, response, provider.sendBackRawResponse)
if bifrostErr != nil {
return nil, bifrostErr
}

// Convert data to embeddings array
var embeddings [][]float32
for _, data := range mistralResp.Data {
embeddings = append(embeddings, data.Embedding)
}
response.ExtraFields.Provider = schemas.Mistral

// Create BifrostResponse
bifrostResponse := &schemas.BifrostResponse{
ID: mistralResp.ID,
Object: mistralResp.Object,
Embedding: embeddings,
Model: mistralResp.Model,
Usage: &mistralResp.Usage,
SystemFingerprint: mistralResp.SystemFingerprint,
ExtraFields: schemas.BifrostResponseExtraFields{
Provider: schemas.Mistral,
RawResponse: rawResponse,
},
if params != nil {
response.ExtraFields.Params = *params
}

if params != nil {
bifrostResponse.ExtraFields.Params = *params
if provider.sendBackRawResponse {
response.ExtraFields.RawResponse = rawResponse
}

return bifrostResponse, nil
return response, nil
}

// ChatCompletionStream performs a streaming chat completion request to the Mistral API.
Expand Down
Loading