Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 86 additions & 3 deletions core/bifrost.go
Original file line number Diff line number Diff line change
Expand Up @@ -938,9 +938,7 @@ func (bifrost *Bifrost) EmbeddingRequest(ctx *schemas.BifrostContext, req *schem
},
}
}
hasExtraInputs := req.Params != nil && req.Params.ExtraParams != nil &&
(req.Params.ExtraParams["inputs"] != nil || req.Params.ExtraParams["images"] != nil)
if (req.Input == nil || (req.Input.Text == nil && req.Input.Texts == nil && req.Input.Embedding == nil && req.Input.Embeddings == nil)) && !hasExtraInputs && !isLargePayloadPassthrough(ctx) {
if (req.Input == nil || len(req.Input.Contents) == 0) && !isLargePayloadPassthrough(ctx) {
return nil, &schemas.BifrostError{
IsBifrostError: false,
Error: &schemas.ErrorField{
Expand All @@ -954,6 +952,22 @@ func (bifrost *Bifrost) EmbeddingRequest(ctx *schemas.BifrostContext, req *schem
},
}
}
if req.Input != nil {
if err := req.Input.Validate(); err != nil {
return nil, &schemas.BifrostError{
IsBifrostError: false,
Error: &schemas.ErrorField{
Message: err.Error(),
},
ExtraFields: schemas.BifrostErrorExtraFields{
RequestType: schemas.EmbeddingRequest,
Provider: req.Provider,
OriginalModelRequested: req.Model,
ResolvedModelUsed: req.Model,
},
}
}
}

bifrostReq := bifrost.getBifrostRequest()
bifrostReq.RequestType = schemas.EmbeddingRequest
Expand All @@ -967,6 +981,59 @@ func (bifrost *Bifrost) EmbeddingRequest(ctx *schemas.BifrostContext, req *schem
return response.EmbeddingResponse, nil
}

// BatchEmbeddingRequest sends a batch embedding request with optional per-item
// parameter overrides and returns the EmbeddingResponse carried by the result
// of handleRequest.
//
// The request is rejected before dispatch, with IsBifrostError set to false
// and RequestType set to schemas.BatchEmbeddingRequest, when:
//   - req is nil;
//   - req.Items is empty and isLargePayloadPassthrough(ctx) reports false;
//   - req.Validate() returns an error (its message is forwarded unchanged).
//
// Any error returned by handleRequest is passed through untouched.
func (bifrost *Bifrost) BatchEmbeddingRequest(ctx *schemas.BifrostContext, req *schemas.BifrostBatchEmbeddingRequest) (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) {
	// reject builds the pre-dispatch error; only the message and the extra
	// fields differ between the rejection paths below.
	reject := func(message string, extra schemas.BifrostErrorExtraFields) *schemas.BifrostError {
		return &schemas.BifrostError{
			IsBifrostError: false,
			Error:          &schemas.ErrorField{Message: message},
			ExtraFields:    extra,
		}
	}

	if req == nil {
		// No provider/model is available to report when the request itself is missing.
		return nil, reject("batch embedding request is nil", schemas.BifrostErrorExtraFields{
			RequestType: schemas.BatchEmbeddingRequest,
		})
	}

	// From here on every rejection reports the caller's provider and model;
	// the requested model is also reported as the resolved one.
	requestFields := schemas.BifrostErrorExtraFields{
		RequestType:            schemas.BatchEmbeddingRequest,
		Provider:               req.Provider,
		OriginalModelRequested: req.Model,
		ResolvedModelUsed:      req.Model,
	}

	// An empty item list is tolerated only in large-payload passthrough mode.
	if len(req.Items) == 0 && !isLargePayloadPassthrough(ctx) {
		return nil, reject("batch embedding request has no items", requestFields)
	}

	validationErr := req.Validate()
	if validationErr != nil {
		return nil, reject(validationErr.Error(), requestFields)
	}

	pooled := bifrost.getBifrostRequest()
	pooled.RequestType = schemas.BatchEmbeddingRequest
	pooled.BatchEmbeddingRequest = req

	result, bifrostErr := bifrost.handleRequest(ctx, pooled)
	if bifrostErr != nil {
		return nil, bifrostErr
	}
	return result.EmbeddingResponse, nil
}

// RerankRequest sends a rerank request to the specified provider.
func (bifrost *Bifrost) RerankRequest(ctx *schemas.BifrostContext, req *schemas.BifrostRerankRequest) (*schemas.BifrostRerankResponse, *schemas.BifrostError) {
if req == nil {
Expand Down Expand Up @@ -4522,6 +4589,12 @@ func (bifrost *Bifrost) prepareFallbackRequest(req *schemas.BifrostRequest, fall
tmp.Model = fallback.Model
fallbackReq.EmbeddingRequest = &tmp
}
if req.BatchEmbeddingRequest != nil {
tmp := *req.BatchEmbeddingRequest
tmp.Provider = fallback.Provider
tmp.Model = fallback.Model
fallbackReq.BatchEmbeddingRequest = &tmp
}
if req.RerankRequest != nil {
tmp := *req.RerankRequest
tmp.Provider = fallback.Provider
Expand Down Expand Up @@ -6049,6 +6122,15 @@ func (bifrost *Bifrost) handleProviderRequest(provider schemas.Provider, config
}
embeddingResponse.BackfillParams(req.BifrostRequest.EmbeddingRequest)
response.EmbeddingResponse = embeddingResponse
case schemas.BatchEmbeddingRequest:
embeddingResponse, bifrostError := provider.BatchEmbedding(req.Context, key, req.BifrostRequest.BatchEmbeddingRequest)
if bifrostError != nil {
return nil, bifrostError
}
if strings.TrimSpace(embeddingResponse.Model) == "" {
embeddingResponse.Model = req.BifrostRequest.BatchEmbeddingRequest.Model
}
response.EmbeddingResponse = embeddingResponse
case schemas.RerankRequest:
rerankResponse, bifrostError := provider.Rerank(req.Context, key, req.BifrostRequest.RerankRequest)
if bifrostError != nil {
Expand Down Expand Up @@ -7072,6 +7154,7 @@ func resetBifrostRequest(req *schemas.BifrostRequest) {
req.ResponsesRequest = nil
req.CountTokensRequest = nil
req.EmbeddingRequest = nil
req.BatchEmbeddingRequest = nil
req.RerankRequest = nil
req.OCRRequest = nil
req.SpeechRequest = nil
Expand Down
14 changes: 13 additions & 1 deletion core/internal/llmtests/account.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ type TestScenarios struct {
Transcription bool // Speech-to-text functionality
TranscriptionStream bool // Streaming speech-to-text functionality
Embedding bool // Embedding functionality
MultimodalEmbedding bool // Multimodal embedding functionality (text + image)
Reasoning bool // Reasoning/thinking functionality via Responses API
PromptCaching bool // Prompt caching functionality
ListModels bool // List available models functionality
Expand Down Expand Up @@ -109,6 +110,7 @@ type ComprehensiveTestConfig struct {
VisionModel string
ReasoningModel string
EmbeddingModel string
MultimodalEmbeddingModel string // Model for multimodal embedding tests (text + image)
RerankModel string
TranscriptionModel string
SpeechSynthesisModel string
Expand Down Expand Up @@ -340,7 +342,17 @@ func (account *ComprehensiveTestAccount) GetKeysForProvider(ctx context.Context,
return []schemas.Key{
{
Value: *schemas.NewEnvVar("env.VERTEX_API_KEY"),
Models: []string{"text-multilingual-embedding-002", "gemini-2.5-pro", "gemini-2.5-flash-image", "imagen-4.0-generate-001", "imagen-3.0-capability-001", "semantic-ranker-default@latest", "semantic-ranker-default-004"},
Models: []string{"text-multilingual-embedding-002", "gemini-2.5-pro", "google/gemini-2.0-flash-001", "gemini-2.5-flash-image", "imagen-4.0-generate-001", "imagen-3.0-capability-001", "semantic-ranker-default@latest", "semantic-ranker-default-004", "multimodalembedding@001"},
Weight: 1.0,
VertexKeyConfig: &schemas.VertexKeyConfig{
ProjectID: *schemas.NewEnvVar("env.VERTEX_PROJECT_ID"),
Region: *schemas.NewEnvVar(getEnvWithDefault("VERTEX_REGION", "us-central1")),
AuthCredentials: *schemas.NewEnvVar("env.VERTEX_CREDENTIALS"),
},
UseForBatchAPI: bifrost.Ptr(true),
},
{
Models: []string{"gemini-embedding-2-preview"},
Weight: 1.0,
VertexKeyConfig: &schemas.VertexKeyConfig{
ProjectID: *schemas.NewEnvVar("env.VERTEX_PROJECT_ID"),
Expand Down
16 changes: 7 additions & 9 deletions core/internal/llmtests/embedding.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,15 @@ func RunEmbeddingTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context
"Goodnight, moon!",
}

contents := make([]schemas.EmbeddingContent, len(testTexts))
for i, text := range testTexts {
t := text
contents[i] = schemas.EmbeddingContent{{Type: schemas.EmbeddingContentPartTypeText, Text: &t}}
}
request := &schemas.BifrostEmbeddingRequest{
Provider: testConfig.Provider,
Model: testConfig.EmbeddingModel,
Input: &schemas.EmbeddingInput{
Texts: testTexts,
},
Input: &schemas.EmbeddingInput{Contents: contents},
Params: &schemas.EmbeddingParameters{
EncodingFormat: bifrost.Ptr("float"),
},
Expand Down Expand Up @@ -123,12 +126,7 @@ func validateEmbeddingSemantics(t *testing.T, response *schemas.BifrostEmbedding
embeddings := make([][]float64, len(testTexts))
responseDataLength := len(response.Data)
if responseDataLength != len(testTexts) {
if responseDataLength > 0 && response.Data[0].Embedding.Embedding2DArray != nil {
responseDataLength = len(response.Data[0].Embedding.Embedding2DArray)
}
if responseDataLength != len(testTexts) {
t.Fatalf("Expected %d embedding results, got %d", len(testTexts), responseDataLength)
}
t.Fatalf("Expected %d embedding results, got %d", len(testTexts), responseDataLength)
}

for i := range responseDataLength {
Expand Down
Loading
Loading