diff --git a/core/bifrost.go b/core/bifrost.go index c383161fb0..06bbd64db8 100644 --- a/core/bifrost.go +++ b/core/bifrost.go @@ -938,9 +938,7 @@ func (bifrost *Bifrost) EmbeddingRequest(ctx *schemas.BifrostContext, req *schem }, } } - hasExtraInputs := req.Params != nil && req.Params.ExtraParams != nil && - (req.Params.ExtraParams["inputs"] != nil || req.Params.ExtraParams["images"] != nil) - if (req.Input == nil || (req.Input.Text == nil && req.Input.Texts == nil && req.Input.Embedding == nil && req.Input.Embeddings == nil)) && !hasExtraInputs && !isLargePayloadPassthrough(ctx) { + if (req.Input == nil || len(req.Input.Contents) == 0) && !isLargePayloadPassthrough(ctx) { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ @@ -954,6 +952,22 @@ func (bifrost *Bifrost) EmbeddingRequest(ctx *schemas.BifrostContext, req *schem }, } } + if req.Input != nil { + if err := req.Input.Validate(); err != nil { + return nil, &schemas.BifrostError{ + IsBifrostError: false, + Error: &schemas.ErrorField{ + Message: err.Error(), + }, + ExtraFields: schemas.BifrostErrorExtraFields{ + RequestType: schemas.EmbeddingRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, + }, + } + } + } bifrostReq := bifrost.getBifrostRequest() bifrostReq.RequestType = schemas.EmbeddingRequest @@ -967,6 +981,59 @@ func (bifrost *Bifrost) EmbeddingRequest(ctx *schemas.BifrostContext, req *schem return response.EmbeddingResponse, nil } +// BatchEmbeddingRequest sends a batch embedding request with optional per-item parameter overrides. 
+func (bifrost *Bifrost) BatchEmbeddingRequest(ctx *schemas.BifrostContext, req *schemas.BifrostBatchEmbeddingRequest) (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) { + if req == nil { + return nil, &schemas.BifrostError{ + IsBifrostError: false, + Error: &schemas.ErrorField{ + Message: "batch embedding request is nil", + }, + ExtraFields: schemas.BifrostErrorExtraFields{ + RequestType: schemas.BatchEmbeddingRequest, + }, + } + } + if len(req.Items) == 0 && !isLargePayloadPassthrough(ctx) { + return nil, &schemas.BifrostError{ + IsBifrostError: false, + Error: &schemas.ErrorField{ + Message: "batch embedding request has no items", + }, + ExtraFields: schemas.BifrostErrorExtraFields{ + RequestType: schemas.BatchEmbeddingRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, + }, + } + } + if err := req.Validate(); err != nil { + return nil, &schemas.BifrostError{ + IsBifrostError: false, + Error: &schemas.ErrorField{ + Message: err.Error(), + }, + ExtraFields: schemas.BifrostErrorExtraFields{ + RequestType: schemas.BatchEmbeddingRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, + }, + } + } + + bifrostReq := bifrost.getBifrostRequest() + bifrostReq.RequestType = schemas.BatchEmbeddingRequest + bifrostReq.BatchEmbeddingRequest = req + + response, err := bifrost.handleRequest(ctx, bifrostReq) + if err != nil { + return nil, err + } + return response.EmbeddingResponse, nil +} + // RerankRequest sends a rerank request to the specified provider. 
func (bifrost *Bifrost) RerankRequest(ctx *schemas.BifrostContext, req *schemas.BifrostRerankRequest) (*schemas.BifrostRerankResponse, *schemas.BifrostError) { if req == nil { @@ -4522,6 +4589,12 @@ func (bifrost *Bifrost) prepareFallbackRequest(req *schemas.BifrostRequest, fall tmp.Model = fallback.Model fallbackReq.EmbeddingRequest = &tmp } + if req.BatchEmbeddingRequest != nil { + tmp := *req.BatchEmbeddingRequest + tmp.Provider = fallback.Provider + tmp.Model = fallback.Model + fallbackReq.BatchEmbeddingRequest = &tmp + } if req.RerankRequest != nil { tmp := *req.RerankRequest tmp.Provider = fallback.Provider @@ -6049,6 +6122,15 @@ func (bifrost *Bifrost) handleProviderRequest(provider schemas.Provider, config } embeddingResponse.BackfillParams(req.BifrostRequest.EmbeddingRequest) response.EmbeddingResponse = embeddingResponse + case schemas.BatchEmbeddingRequest: + embeddingResponse, bifrostError := provider.BatchEmbedding(req.Context, key, req.BifrostRequest.BatchEmbeddingRequest) + if bifrostError != nil { + return nil, bifrostError + } + if strings.TrimSpace(embeddingResponse.Model) == "" { + embeddingResponse.Model = req.BifrostRequest.BatchEmbeddingRequest.Model + } + response.EmbeddingResponse = embeddingResponse case schemas.RerankRequest: rerankResponse, bifrostError := provider.Rerank(req.Context, key, req.BifrostRequest.RerankRequest) if bifrostError != nil { @@ -7072,6 +7154,7 @@ func resetBifrostRequest(req *schemas.BifrostRequest) { req.ResponsesRequest = nil req.CountTokensRequest = nil req.EmbeddingRequest = nil + req.BatchEmbeddingRequest = nil req.RerankRequest = nil req.OCRRequest = nil req.SpeechRequest = nil diff --git a/core/internal/llmtests/account.go b/core/internal/llmtests/account.go index 87726014b1..5f238ca2a3 100644 --- a/core/internal/llmtests/account.go +++ b/core/internal/llmtests/account.go @@ -49,6 +49,7 @@ type TestScenarios struct { Transcription bool // Speech-to-text functionality TranscriptionStream bool // Streaming 
speech-to-text functionality Embedding bool // Embedding functionality + MultimodalEmbedding bool // Multimodal embedding functionality (text + image) Reasoning bool // Reasoning/thinking functionality via Responses API PromptCaching bool // Prompt caching functionality ListModels bool // List available models functionality @@ -109,6 +110,7 @@ type ComprehensiveTestConfig struct { VisionModel string ReasoningModel string EmbeddingModel string + MultimodalEmbeddingModel string // Model for multimodal embedding tests (text + image) RerankModel string TranscriptionModel string SpeechSynthesisModel string @@ -340,7 +342,17 @@ func (account *ComprehensiveTestAccount) GetKeysForProvider(ctx context.Context, return []schemas.Key{ { Value: *schemas.NewEnvVar("env.VERTEX_API_KEY"), - Models: []string{"text-multilingual-embedding-002", "gemini-2.5-pro", "gemini-2.5-flash-image", "imagen-4.0-generate-001", "imagen-3.0-capability-001", "semantic-ranker-default@latest", "semantic-ranker-default-004"}, + Models: []string{"text-multilingual-embedding-002", "gemini-2.5-pro", "google/gemini-2.0-flash-001", "gemini-2.5-flash-image", "imagen-4.0-generate-001", "imagen-3.0-capability-001", "semantic-ranker-default@latest", "semantic-ranker-default-004", "multimodalembedding@001"}, + Weight: 1.0, + VertexKeyConfig: &schemas.VertexKeyConfig{ + ProjectID: *schemas.NewEnvVar("env.VERTEX_PROJECT_ID"), + Region: *schemas.NewEnvVar(getEnvWithDefault("VERTEX_REGION", "us-central1")), + AuthCredentials: *schemas.NewEnvVar("env.VERTEX_CREDENTIALS"), + }, + UseForBatchAPI: bifrost.Ptr(true), + }, + { + Models: []string{"gemini-embedding-2-preview"}, Weight: 1.0, VertexKeyConfig: &schemas.VertexKeyConfig{ ProjectID: *schemas.NewEnvVar("env.VERTEX_PROJECT_ID"), diff --git a/core/internal/llmtests/embedding.go b/core/internal/llmtests/embedding.go index ee722bbc3e..c3105bef2f 100644 --- a/core/internal/llmtests/embedding.go +++ b/core/internal/llmtests/embedding.go @@ -58,12 +58,15 @@ func 
RunEmbeddingTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context "Goodnight, moon!", } + contents := make([]schemas.EmbeddingContent, len(testTexts)) + for i, text := range testTexts { + t := text + contents[i] = schemas.EmbeddingContent{{Type: schemas.EmbeddingContentPartTypeText, Text: &t}} + } request := &schemas.BifrostEmbeddingRequest{ Provider: testConfig.Provider, Model: testConfig.EmbeddingModel, - Input: &schemas.EmbeddingInput{ - Texts: testTexts, - }, + Input: &schemas.EmbeddingInput{Contents: contents}, Params: &schemas.EmbeddingParameters{ EncodingFormat: bifrost.Ptr("float"), }, @@ -123,12 +126,7 @@ func validateEmbeddingSemantics(t *testing.T, response *schemas.BifrostEmbedding embeddings := make([][]float64, len(testTexts)) responseDataLength := len(response.Data) if responseDataLength != len(testTexts) { - if responseDataLength > 0 && response.Data[0].Embedding.Embedding2DArray != nil { - responseDataLength = len(response.Data[0].Embedding.Embedding2DArray) - } - if responseDataLength != len(testTexts) { - t.Fatalf("Expected %d embedding results, got %d", len(testTexts), responseDataLength) - } + t.Fatalf("Expected %d embedding results, got %d", len(testTexts), responseDataLength) } for i := range responseDataLength { diff --git a/core/internal/llmtests/embedding_multimodal.go b/core/internal/llmtests/embedding_multimodal.go new file mode 100644 index 0000000000..9a2e199847 --- /dev/null +++ b/core/internal/llmtests/embedding_multimodal.go @@ -0,0 +1,239 @@ +package llmtests + +import ( + "context" + "os" + "strings" + "testing" + + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" +) + +// testImageDataURI is a 1×1 red pixel PNG encoded as a data URI. +// Used as a lightweight inline image for multimodal embedding tests — +// no external network dependency. 
+const testImageDataURI = "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAACAAAAAgCAIAAAD8GO2jAAAAOklEQVR4nO3RQREAQAjDwHKS8C8AWSchfPhlBZSZUNOdS+90PR5Y8AfIRMhEyETIRMhEyETIRMhEIR/EvwFs/VkrpgAAAABJRU5ErkJggg==" + +// makeTextContent returns a single-text EmbeddingContent for the given string. +func makeTextContent(text string) schemas.EmbeddingContent { + t := text + return schemas.EmbeddingContent{{ + Type: schemas.EmbeddingContentPartTypeText, + Text: &t, + }} +} + +// makeImageDataContent returns a single-image (inline data URI) EmbeddingContent. +func makeImageDataContent(dataURI string) schemas.EmbeddingContent { + d := dataURI + return schemas.EmbeddingContent{{ + Type: schemas.EmbeddingContentPartTypeImage, + Image: &schemas.EmbeddingMediaPart{Data: &d}, + }} +} + +// makeMultimodalContent returns a EmbeddingContent with both text and image parts, +// producing a single aggregated multimodal embedding. +func makeMultimodalContent(text, imageDataURI string) schemas.EmbeddingContent { + t := text + d := imageDataURI + return schemas.EmbeddingContent{ + {Type: schemas.EmbeddingContentPartTypeText, Text: &t}, + {Type: schemas.EmbeddingContentPartTypeImage, Image: &schemas.EmbeddingMediaPart{Data: &d}}, + } +} + +// validateEmbeddingCount asserts that the response contains exactly wantCount embeddings. +func validateEmbeddingCount(t *testing.T, resp *schemas.BifrostEmbeddingResponse, wantCount int) { + t.Helper() + if resp == nil { + t.Fatal("embedding response is nil") + } + got := len(resp.Data) + if got != wantCount { + t.Fatalf("expected %d embeddings, got %d", wantCount, got) + } +} + +// validateNonEmptyVectors asserts every embedding vector is non-empty and all +// share the same dimension. 
+func validateNonEmptyVectors(t *testing.T, resp *schemas.BifrostEmbeddingResponse) { + t.Helper() + if resp == nil || len(resp.Data) == 0 { + t.Fatal("embedding response has no data") + } + + var dim int + for i, item := range resp.Data { + vec, err := getEmbeddingVector(item) + if err != nil { + t.Fatalf("embedding[%d]: failed to extract vector: %v", i, err) + } + if len(vec) == 0 { + t.Fatalf("embedding[%d]: vector is empty", i) + } + if dim == 0 { + dim = len(vec) + } else if len(vec) != dim { + t.Fatalf("embedding[%d]: dimension mismatch: got %d, expected %d", i, len(vec), dim) + } + } + t.Logf("✅ %d embedding vector(s), %d dimensions each", len(resp.Data), dim) +} + +// RunMultimodalEmbeddingTest runs all multimodal embedding sub-scenarios for +// providers that declare MultimodalEmbedding support. +// +// Scenarios covered: +// 1. Single text input → 1 embedding +// 2. Batch text inputs → N embeddings +// 3. Single image input (inline data URI) → 1 embedding +// 4. Single multimodal content (text + image) → 1 aggregated embedding +// 5. Batch images → N embeddings (skipped for Vertex: no batch on multimodal path) +// 6. Batch multimodal (text+image per entry) → N embeddings (same skip) +func RunMultimodalEmbeddingTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) { + if !testConfig.Scenarios.MultimodalEmbedding { + t.Logf("MultimodalEmbedding not enabled for provider %s", testConfig.Provider) + return + } + + model := testConfig.MultimodalEmbeddingModel + if strings.TrimSpace(model) == "" { + t.Skipf("MultimodalEmbedding enabled but MultimodalEmbeddingModel not set for %s; skipping", testConfig.Provider) + } + + t.Run("MultimodalEmbedding", func(t *testing.T) { + // Vertex Gemini path does not support batch for multimodal inputs. 
+ vertexNoBatch := testConfig.Provider == schemas.Vertex + + run := func(name string, req *schemas.BifrostEmbeddingRequest, wantCount int) { + t.Run(name, func(t *testing.T) { + if os.Getenv("SKIP_PARALLEL_TESTS") != "true" { + t.Parallel() + } + + retryConfig := GetTestRetryConfigForScenario("MultimodalEmbedding", testConfig) + retryContext := TestRetryContext{ + ScenarioName: name, + ExpectedBehavior: map[string]interface{}{ + "should_return_embeddings": true, + "should_have_valid_vectors": true, + }, + TestMetadata: map[string]interface{}{ + "provider": testConfig.Provider, + "model": model, + "want_count": wantCount, + }, + } + + // Build a dummy string slice of the right length for EmbeddingExpectations. + dummyTexts := make([]string, wantCount) + expectations := EmbeddingExpectations(dummyTexts) + expectations = ModifyExpectationsForProvider(expectations, testConfig.Provider) + + embeddingRetryConfig := EmbeddingRetryConfig{ + MaxAttempts: retryConfig.MaxAttempts, + BaseDelay: retryConfig.BaseDelay, + MaxDelay: retryConfig.MaxDelay, + Conditions: []EmbeddingRetryCondition{}, + OnRetry: retryConfig.OnRetry, + OnFinalFail: retryConfig.OnFinalFail, + } + + resp, bifrostErr := WithEmbeddingTestRetry(t, embeddingRetryConfig, retryContext, expectations, name, func() (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) { + bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) + return client.EmbeddingRequest(bfCtx, req) + }) + + if bifrostErr != nil { + t.Fatalf("❌ %s multimodal embedding request failed after retries: %v", name, GetErrorMessage(bifrostErr)) + } + + validateEmbeddingCount(t, resp, wantCount) + validateNonEmptyVectors(t, resp) + }) + } + + // ── 1. 
Single text ──────────────────────────────────────────────────────── + run("SingleText", &schemas.BifrostEmbeddingRequest{ + Provider: testConfig.Provider, + Model: model, + Input: &schemas.EmbeddingInput{ + Contents: []schemas.EmbeddingContent{ + makeTextContent("The quick brown fox jumps over the lazy dog."), + }, + }, + Params: &schemas.EmbeddingParameters{EncodingFormat: bifrost.Ptr("float")}, + }, 1) + + // ── 2. Batch text ───────────────────────────────────────────────────────── + run("BatchText", &schemas.BifrostEmbeddingRequest{ + Provider: testConfig.Provider, + Model: model, + Input: &schemas.EmbeddingInput{ + Contents: []schemas.EmbeddingContent{ + makeTextContent("Cats are great pets."), + makeTextContent("Dogs are loyal companions."), + makeTextContent("The sky is blue."), + }, + }, + Params: &schemas.EmbeddingParameters{EncodingFormat: bifrost.Ptr("float")}, + }, 3) + + // ── 3. Single image (inline data URI) ──────────────────────────────────── + run("SingleImage", &schemas.BifrostEmbeddingRequest{ + Provider: testConfig.Provider, + Model: model, + Input: &schemas.EmbeddingInput{ + Contents: []schemas.EmbeddingContent{ + makeImageDataContent(testImageDataURI), + }, + }, + Params: &schemas.EmbeddingParameters{EncodingFormat: bifrost.Ptr("float")}, + }, 1) + + // ── 4. Single multimodal content (text + image → 1 aggregated embedding) ─ + run("SingleMultimodal", &schemas.BifrostEmbeddingRequest{ + Provider: testConfig.Provider, + Model: model, + Input: &schemas.EmbeddingInput{ + Contents: []schemas.EmbeddingContent{ + makeMultimodalContent("A red pixel.", testImageDataURI), + }, + }, + Params: &schemas.EmbeddingParameters{EncodingFormat: bifrost.Ptr("float")}, + }, 1) + + if vertexNoBatch { + t.Logf("⏭ Skipping batch multimodal scenarios for Vertex (single-content only on Gemini embedding path)") + return + } + + // ── 5. 
Batch images ─────────────────────────────────────────────────────── + run("BatchImages", &schemas.BifrostEmbeddingRequest{ + Provider: testConfig.Provider, + Model: model, + Input: &schemas.EmbeddingInput{ + Contents: []schemas.EmbeddingContent{ + makeImageDataContent(testImageDataURI), + makeImageDataContent(testImageDataURI), + }, + }, + Params: &schemas.EmbeddingParameters{EncodingFormat: bifrost.Ptr("float")}, + }, 2) + + // ── 6. Batch multimodal (text+image per entry) ──────────────────────────── + run("BatchMultimodal", &schemas.BifrostEmbeddingRequest{ + Provider: testConfig.Provider, + Model: model, + Input: &schemas.EmbeddingInput{ + Contents: []schemas.EmbeddingContent{ + makeMultimodalContent("First image description.", testImageDataURI), + makeMultimodalContent("Second image description.", testImageDataURI), + }, + }, + Params: &schemas.EmbeddingParameters{EncodingFormat: bifrost.Ptr("float")}, + }, 2) + }) // end t.Run("MultimodalEmbedding") +} diff --git a/core/internal/llmtests/response_validation.go b/core/internal/llmtests/response_validation.go index 1c3c760d1c..8f37184c87 100644 --- a/core/internal/llmtests/response_validation.go +++ b/core/internal/llmtests/response_validation.go @@ -1363,12 +1363,8 @@ func validateEmbeddingFields(t *testing.T, response *schemas.BifrostEmbeddingRes if expectations.ProviderSpecific != nil { if raw, exists := expectations.ProviderSpecific["expected_embedding_count"]; exists { if expectedCount, ok := intFromProviderSpecific(raw); ok { - actualCount := len(response.Data) - // Also check for 2D arrays (some providers return single embedding with 2D array) - if actualCount == 1 && response.Data[0].Embedding.Embedding2DArray != nil { - actualCount = len(response.Data[0].Embedding.Embedding2DArray) - } - if actualCount != expectedCount { + actualCount := len(response.Data) + if actualCount != expectedCount { result.Passed = false result.Errors = append(result.Errors, fmt.Sprintf("Expected %d embeddings, got %d", 
expectedCount, actualCount)) @@ -1379,13 +1375,8 @@ func validateEmbeddingFields(t *testing.T, response *schemas.BifrostEmbeddingRes // Validate each embedding has non-empty vector data for i, embedding := range response.Data { - hasData := false - if embedding.Embedding.EmbeddingArray != nil && len(embedding.Embedding.EmbeddingArray) > 0 { - hasData = true - } - if embedding.Embedding.Embedding2DArray != nil && len(embedding.Embedding.Embedding2DArray) > 0 { - hasData = true - } + e := embedding.Embedding + hasData := len(e.Float) > 0 || e.Base64 != nil || len(e.Int8) > 0 || len(e.Uint8) > 0 || len(e.Binary) > 0 || len(e.Ubinary) > 0 if !hasData { result.Passed = false result.Errors = append(result.Errors, fmt.Sprintf("Embedding %d has no vector data", i)) @@ -1396,12 +1387,8 @@ func validateEmbeddingFields(t *testing.T, response *schemas.BifrostEmbeddingRes if expectedDimensions, ok := expectations.ProviderSpecific["expected_dimensions"].(int); ok { for i, embedding := range response.Data { var actualDimensions int - if embedding.Embedding.EmbeddingArray != nil { - actualDimensions = len(embedding.Embedding.EmbeddingArray) - } else if embedding.Embedding.Embedding2DArray != nil { - if len(embedding.Embedding.Embedding2DArray) > 0 { - actualDimensions = len(embedding.Embedding.Embedding2DArray[0]) - } + if vec, err := getEmbeddingVector(embedding); err == nil { + actualDimensions = len(vec) } if actualDimensions != expectedDimensions { result.Passed = false @@ -1566,10 +1553,8 @@ func collectEmbeddingResponseMetrics(response *schemas.BifrostEmbeddingResponse, result.MetricsCollected["has_usage"] = response.Usage != nil if len(response.Data) > 0 { var dimensions int - if response.Data[0].Embedding.EmbeddingArray != nil { - dimensions = len(response.Data[0].Embedding.EmbeddingArray) - } else if len(response.Data[0].Embedding.Embedding2DArray) > 0 { - dimensions = len(response.Data[0].Embedding.Embedding2DArray[0]) + if vec, err := 
getEmbeddingVector(response.Data[0]); err == nil { + dimensions = len(vec) } result.MetricsCollected["embedding_dimensions"] = dimensions } diff --git a/core/internal/llmtests/tests.go b/core/internal/llmtests/tests.go index 6b79796f84..90b0bf8d35 100644 --- a/core/internal/llmtests/tests.go +++ b/core/internal/llmtests/tests.go @@ -57,6 +57,7 @@ func RunAllComprehensiveTests(t *testing.T, client *bifrost.Bifrost, ctx context RunTranscriptionStreamTest, RunTranscriptionStreamAdvancedTest, RunEmbeddingTest, + RunMultimodalEmbeddingTest, RunRerankTest, RunChatCompletionReasoningTest, RunMultiTurnReasoningTest, @@ -174,6 +175,7 @@ func printTestSummary(t *testing.T, testConfig ComprehensiveTestConfig) { {"Transcription", testConfig.Scenarios.Transcription}, {"TranscriptionStream", testConfig.Scenarios.TranscriptionStream}, {"Embedding", testConfig.Scenarios.Embedding && testConfig.EmbeddingModel != ""}, + {"MultimodalEmbedding", testConfig.Scenarios.MultimodalEmbedding && testConfig.MultimodalEmbeddingModel != ""}, {"Rerank", testConfig.Scenarios.Rerank && testConfig.RerankModel != ""}, {"ChatCompletionReasoning", testConfig.Scenarios.Reasoning && testConfig.ReasoningModel != ""}, {"MultiTurnReasoning", testConfig.Scenarios.Reasoning && testConfig.ReasoningModel != ""}, diff --git a/core/internal/llmtests/utils.go b/core/internal/llmtests/utils.go index f0badae1b2..9c2559acb4 100644 --- a/core/internal/llmtests/utils.go +++ b/core/internal/llmtests/utils.go @@ -3,7 +3,9 @@ package llmtests import ( "context" "encoding/base64" + "encoding/binary" "fmt" + "math" "os" "path/filepath" "runtime" @@ -623,20 +625,45 @@ func ExtractToolCalls(response *schemas.BifrostResponse) []ToolCallInfo { // getEmbeddingVector extracts the float64 vector from a BifrostEmbeddingResponse. 
func getEmbeddingVector(embedding schemas.EmbeddingData) ([]float64, error) { - if embedding.Embedding.EmbeddingArray != nil { - return embedding.Embedding.EmbeddingArray, nil + e := embedding.Embedding + + if len(e.Float) > 0 { + return e.Float, nil + } + + if e.Base64 != nil { + b, err := base64.StdEncoding.DecodeString(*e.Base64) + if err != nil { + b, err = base64.URLEncoding.DecodeString(*e.Base64) + if err != nil { + return nil, fmt.Errorf("base64 embedding decode failed: %w", err) + } + } + if len(b)%4 != 0 { + return nil, fmt.Errorf("base64 embedding byte length %d is not a multiple of 4", len(b)) + } + out := make([]float64, len(b)/4) + for i := range out { + bits := binary.LittleEndian.Uint32(b[i*4 : i*4+4]) + out[i] = float64(math.Float32frombits(bits)) + } + return out, nil } - if embedding.Embedding.Embedding2DArray != nil { - // For 2D arrays, return the first vector - if len(embedding.Embedding.Embedding2DArray) > 0 { - return embedding.Embedding.Embedding2DArray[0], nil + if len(e.Int8) > 0 { + out := make([]float64, len(e.Int8)) + for i, v := range e.Int8 { + out[i] = float64(v) } - return nil, fmt.Errorf("2D embedding array is empty") + return out, nil } - if embedding.Embedding.EmbeddingStr != nil { - return nil, fmt.Errorf("string embeddings not supported for vector extraction") + if len(e.Uint8) > 0 { + out := make([]float64, len(e.Uint8)) + for i, v := range e.Uint8 { + out[i] = float64(v) + } + return out, nil } return nil, fmt.Errorf("no valid embedding data found") diff --git a/core/providers/anthropic/anthropic.go b/core/providers/anthropic/anthropic.go index e6f8ca62dd..5a6e35eeff 100644 --- a/core/providers/anthropic/anthropic.go +++ b/core/providers/anthropic/anthropic.go @@ -1880,6 +1880,10 @@ func (provider *AnthropicProvider) Embedding(ctx *schemas.BifrostContext, key sc return nil, providerUtils.NewUnsupportedOperationError(schemas.EmbeddingRequest, provider.GetProviderKey()) } +func (provider *AnthropicProvider) BatchEmbedding(ctx 
*schemas.BifrostContext, key schemas.Key, request *schemas.BifrostBatchEmbeddingRequest) (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) { + return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchEmbeddingRequest, provider.GetProviderKey()) +} + // Speech is not supported by the Anthropic provider. func (provider *AnthropicProvider) Speech(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostSpeechRequest) (*schemas.BifrostSpeechResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.SpeechRequest, provider.GetProviderKey()) diff --git a/core/providers/azure/azure.go b/core/providers/azure/azure.go index c1cc8a7929..ca9f6a1ede 100644 --- a/core/providers/azure/azure.go +++ b/core/providers/azure/azure.go @@ -839,7 +839,7 @@ func (provider *AzureProvider) Embedding(ctx *schemas.BifrostContext, key schema ctx, request, func() (providerUtils.RequestBodyWithExtraParams, error) { - return openai.ToOpenAIEmbeddingRequest(request), nil + return openai.ToOpenAIEmbeddingRequest(request) }) if bifrostErr != nil { return nil, bifrostErr @@ -870,14 +870,16 @@ func (provider *AzureProvider) Embedding(ctx *schemas.BifrostContext, key schema }, nil } - response := &schemas.BifrostEmbeddingResponse{} + openaiResp := &openai.OpenaiEmbeddingResponse{} - // Use enhanced response handler with pre-allocated response - rawRequest, rawResponse, bifrostErr := providerUtils.HandleProviderResponse(responseBody, response, jsonData, providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)) + rawRequest, rawResponse, bifrostErr := providerUtils.HandleProviderResponse(responseBody, openaiResp, jsonData, providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)) if bifrostErr != nil { return nil, providerUtils.EnrichError(ctx, 
bifrostErr, jsonData, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse) } + response := openaiResp.ToBifrostEmbeddingResponse() + + response.ExtraFields.Provider = provider.GetProviderKey() response.ExtraFields.Latency = latency.Milliseconds() response.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -894,6 +896,10 @@ func (provider *AzureProvider) Embedding(ctx *schemas.BifrostContext, key schema return response, nil } +func (provider *AzureProvider) BatchEmbedding(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostBatchEmbeddingRequest) (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) { + return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchEmbeddingRequest, provider.GetProviderKey()) +} + // Speech is not supported by the Azure provider. func (provider *AzureProvider) Speech(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostSpeechRequest) (*schemas.BifrostSpeechResponse, *schemas.BifrostError) { apiVersion := key.AzureKeyConfig.APIVersion diff --git a/core/providers/bedrock/bedrock.go b/core/providers/bedrock/bedrock.go index 2cae210e23..d42c841a08 100644 --- a/core/providers/bedrock/bedrock.go +++ b/core/providers/bedrock/bedrock.go @@ -1873,6 +1873,10 @@ func (provider *BedrockProvider) Embedding(ctx *schemas.BifrostContext, key sche return bifrostResponse, nil } +func (provider *BedrockProvider) BatchEmbedding(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostBatchEmbeddingRequest) (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) { + return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchEmbeddingRequest, provider.GetProviderKey()) +} + // Speech is not supported by the Bedrock provider. 
func (provider *BedrockProvider) Speech(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostSpeechRequest) (*schemas.BifrostSpeechResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.SpeechRequest, schemas.Bedrock) diff --git a/core/providers/bedrock/embedding.go b/core/providers/bedrock/embedding.go index cb4ef19e88..f423a78004 100644 --- a/core/providers/bedrock/embedding.go +++ b/core/providers/bedrock/embedding.go @@ -5,31 +5,41 @@ import ( "fmt" "strings" + cohere "github.com/maximhq/bifrost/core/providers/cohere" "github.com/maximhq/bifrost/core/schemas" ) -// ToBedrockTitanEmbeddingRequest converts a Bifrost embedding request to Bedrock Titan format +// ToBedrockTitanEmbeddingRequest converts a Bifrost embedding request to Bedrock Titan format. func ToBedrockTitanEmbeddingRequest(bifrostReq *schemas.BifrostEmbeddingRequest) (*BedrockTitanEmbeddingRequest, error) { if bifrostReq == nil { return nil, fmt.Errorf("bifrost embedding request is nil") } - // Validate that only single text input is provided for Titan models - if bifrostReq.Input.Text == nil && len(bifrostReq.Input.Texts) == 0 { - return nil, fmt.Errorf("no input text provided for embedding") + if bifrostReq.Input == nil || len(bifrostReq.Input.Contents) == 0 { + return nil, fmt.Errorf("no input provided for Titan embedding") } - titanReq := &BedrockTitanEmbeddingRequest{} + if len(bifrostReq.Input.Contents) != 1 { + return nil, fmt.Errorf("amazon Titan embedding models support exactly one content item per request; got %d", len(bifrostReq.Input.Contents)) + } + + if bifrostReq.Params != nil && bifrostReq.Params.Dimensions != nil { + return nil, fmt.Errorf("amazon Titan embedding models do not support custom dimensions parameter") + } - // Set input text - if bifrostReq.Input.Text != nil { - titanReq.InputText = *bifrostReq.Input.Text - } else if len(bifrostReq.Input.Texts) > 0 { - var embeddingText string - for _, text := range 
bifrostReq.Input.Texts { - embeddingText += text + " \n" + var sb strings.Builder + for _, part := range bifrostReq.Input.Contents[0] { + if part.Type != schemas.EmbeddingContentPartTypeText || part.Text == nil { + return nil, fmt.Errorf("amazon Titan embedding models only support text input") + } + if sb.Len() > 0 { + sb.WriteString(" \n") } - titanReq.InputText = embeddingText + sb.WriteString(*part.Text) + } + + titanReq := &BedrockTitanEmbeddingRequest{ + InputText: sb.String(), } if bifrostReq.Params != nil { @@ -39,7 +49,6 @@ func ToBedrockTitanEmbeddingRequest(bifrostReq *schemas.BifrostEmbeddingRequest) titanReq.Normalize = &b } } - // Forward remaining extra params (excluding normalize which is now a first-class field) if len(bifrostReq.Params.ExtraParams) > 0 { extra := make(map[string]interface{}) for k, v := range bifrostReq.Params.ExtraParams { @@ -68,8 +77,8 @@ func (response *BedrockTitanEmbeddingResponse) ToBifrostEmbeddingResponse() *sch { Index: 0, Object: "embedding", - Embedding: schemas.EmbeddingStruct{ - EmbeddingArray: response.Embedding, + Embedding: schemas.EmbeddingsByType{ + Float: response.Embedding, }, }, }, @@ -83,80 +92,18 @@ func (response *BedrockTitanEmbeddingResponse) ToBifrostEmbeddingResponse() *sch } // ToBedrockCohereEmbeddingRequest converts a Bifrost embedding request to Bedrock Cohere format. -// Unlike the direct Cohere API, Bedrock does not accept a "model" field in the request body. -func ToBedrockCohereEmbeddingRequest(bifrostReq *schemas.BifrostEmbeddingRequest) (*BedrockCohereEmbeddingRequest, error) { +// Reuses the Cohere converter since the format is identical. 
+func ToBedrockCohereEmbeddingRequest(bifrostReq *schemas.BifrostEmbeddingRequest) (*cohere.CohereEmbeddingRequest, error) { if bifrostReq == nil { return nil, fmt.Errorf("bifrost embedding request is nil") } - if bifrostReq.Input == nil || (bifrostReq.Input.Text == nil && len(bifrostReq.Input.Texts) == 0) { - return nil, fmt.Errorf("no input provided for embedding") - } - req := &BedrockCohereEmbeddingRequest{} - - // Map texts - if bifrostReq.Input.Text != nil { - req.Texts = []string{*bifrostReq.Input.Text} - } else if len(bifrostReq.Input.Texts) > 0 { - req.Texts = bifrostReq.Input.Texts - } - - if bifrostReq.Params != nil { - extra := make(map[string]interface{}, len(bifrostReq.Params.ExtraParams)) - for k, v := range bifrostReq.Params.ExtraParams { - extra[k] = v - } - - if v, ok := extra["input_type"]; ok { - if s, ok := v.(string); ok { - req.InputType = s - delete(extra, "input_type") - } - } - if v, ok := extra["truncate"]; ok { - if s, ok := v.(string); ok { - req.Truncate = &s - delete(extra, "truncate") - } - } - if v, ok := extra["embedding_types"]; ok { - if ss, ok := v.([]string); ok { - req.EmbeddingTypes = ss - delete(extra, "embedding_types") - } - } - if v, ok := extra["images"]; ok { - if ss, ok := v.([]string); ok { - req.Images = ss - delete(extra, "images") - } - } - if v, ok := extra["inputs"]; ok { - if inputs, ok := v.([]BedrockCohereEmbeddingInput); ok { - req.Inputs = inputs - delete(extra, "inputs") - } - } - if v, ok := extra["max_tokens"]; ok { - switch n := v.(type) { - case int: - req.MaxTokens = &n - delete(extra, "max_tokens") - case float64: - i := int(n) - req.MaxTokens = &i - delete(extra, "max_tokens") - } - } - if bifrostReq.Params.Dimensions != nil { - req.OutputDimension = bifrostReq.Params.Dimensions - } - if len(extra) > 0 { - req.ExtraParams = extra - } + cohereReq := cohere.ToCohereEmbeddingRequest(bifrostReq) + if cohereReq == nil { + return nil, fmt.Errorf("failed to convert to Cohere embedding request") } - return 
req, nil + return cohereReq, nil } // DetermineEmbeddingModelType determines the embedding model type from the model name @@ -189,63 +136,47 @@ func (r *BedrockCohereEmbeddingResponse) ToBifrostEmbeddingResponse() (*schemas. Float [][]float32 `json:"float"` Base64 []string `json:"base64"` Int8 [][]int8 `json:"int8"` - Uint8 [][]int32 `json:"uint8"` // int32 avoids []byte→base64 JSON issue + Uint8 [][]int32 `json:"uint8"` // int32 avoids []byte→base64 JSON issue Binary [][]int8 `json:"binary"` Ubinary [][]int32 `json:"ubinary"` // int32 avoids []byte→base64 JSON issue } if err := json.Unmarshal(r.Embeddings, &typed); err != nil { return nil, fmt.Errorf("error parsing embeddings_by_type: %w", err) } - if typed.Float != nil { - for i, emb := range typed.Float { - float64Emb := make([]float64, len(emb)) - for j, v := range emb { - float64Emb[j] = float64(v) + + // Determine document count from whichever type was returned. + count := max(len(typed.Float), len(typed.Base64), len(typed.Int8), len(typed.Uint8), len(typed.Binary), len(typed.Ubinary)) + for i := range count { + entry := schemas.EmbeddingData{Object: "embedding", Index: i} + if i < len(typed.Float) { + entry.Embedding.Float = make([]float64, len(typed.Float[i])) + for j, v := range typed.Float[i] { + entry.Embedding.Float[j] = float64(v) } - bifrostResponse.Data = append(bifrostResponse.Data, schemas.EmbeddingData{ - Object: "embedding", - Index: i, - Embedding: schemas.EmbeddingStruct{EmbeddingArray: float64Emb}, - }) } - } - if typed.Base64 != nil { - for i, emb := range typed.Base64 { - e := emb - bifrostResponse.Data = append(bifrostResponse.Data, schemas.EmbeddingData{ - Object: "embedding", - Index: i, - Embedding: schemas.EmbeddingStruct{EmbeddingStr: &e}, - }) + if i < len(typed.Base64) { + s := typed.Base64[i] + entry.Embedding.Base64 = &s } - } - for i, emb := range typed.Int8 { - bifrostResponse.Data = append(bifrostResponse.Data, schemas.EmbeddingData{ - Object: "embedding", - Index: i, - 
Embedding: schemas.EmbeddingStruct{EmbeddingInt8Array: emb}, - }) - } - for i, emb := range typed.Binary { - bifrostResponse.Data = append(bifrostResponse.Data, schemas.EmbeddingData{ - Object: "embedding", - Index: i, - Embedding: schemas.EmbeddingStruct{EmbeddingInt8Array: emb}, - }) - } - for i, emb := range typed.Uint8 { - bifrostResponse.Data = append(bifrostResponse.Data, schemas.EmbeddingData{ - Object: "embedding", - Index: i, - Embedding: schemas.EmbeddingStruct{EmbeddingInt32Array: emb}, - }) - } - for i, emb := range typed.Ubinary { - bifrostResponse.Data = append(bifrostResponse.Data, schemas.EmbeddingData{ - Object: "embedding", - Index: i, - Embedding: schemas.EmbeddingStruct{EmbeddingInt32Array: emb}, - }) + if i < len(typed.Int8) { + entry.Embedding.Int8 = typed.Int8[i] + } + if i < len(typed.Binary) { + entry.Embedding.Binary = typed.Binary[i] + } + if i < len(typed.Uint8) { + entry.Embedding.Uint8 = make([]uint8, len(typed.Uint8[i])) + for j, v := range typed.Uint8[i] { + entry.Embedding.Uint8[j] = uint8(v) + } + } + if i < len(typed.Ubinary) { + entry.Embedding.Ubinary = make([]uint8, len(typed.Ubinary[i])) + for j, v := range typed.Ubinary[i] { + entry.Embedding.Ubinary[j] = uint8(v) + } + } + bifrostResponse.Data = append(bifrostResponse.Data, entry) } default: @@ -262,7 +193,7 @@ func (r *BedrockCohereEmbeddingResponse) ToBifrostEmbeddingResponse() (*schemas. 
bifrostResponse.Data = append(bifrostResponse.Data, schemas.EmbeddingData{ Object: "embedding", Index: i, - Embedding: schemas.EmbeddingStruct{EmbeddingArray: float64Emb}, + Embedding: schemas.EmbeddingsByType{Float: float64Emb}, }) } } diff --git a/core/providers/bedrock/embedding_test.go b/core/providers/bedrock/embedding_test.go index 244cb85322..4fdc0b7eee 100644 --- a/core/providers/bedrock/embedding_test.go +++ b/core/providers/bedrock/embedding_test.go @@ -22,7 +22,6 @@ func TestToBedrockCohereEmbeddingRequest(t *testing.T) { req, err := ToBedrockCohereEmbeddingRequest(&schemas.BifrostEmbeddingRequest{}) require.Error(t, err) assert.Nil(t, req) - assert.Contains(t, err.Error(), "no input") }) t.Run("returns error for non-nil but empty input", func(t *testing.T) { @@ -31,23 +30,27 @@ func TestToBedrockCohereEmbeddingRequest(t *testing.T) { }) require.Error(t, err) assert.Nil(t, req) - assert.Contains(t, err.Error(), "no input") }) - t.Run("single text strips model and extracts typed params", func(t *testing.T) { + t.Run("single text content extracts typed params", func(t *testing.T) { text := "hello" truncate := "RIGHT" dimensions := 512 + maxTokens := 128 bifrostReq := &schemas.BifrostEmbeddingRequest{ Model: "cohere.embed-english-v3", - Input: &schemas.EmbeddingInput{Text: &text}, + Input: &schemas.EmbeddingInput{ + Contents: []schemas.EmbeddingContent{ + {{Type: schemas.EmbeddingContentPartTypeText, Text: &text}}, + }, + }, Params: &schemas.EmbeddingParameters{ Dimensions: &dimensions, + MaxTokens: &maxTokens, + Truncate: &truncate, ExtraParams: map[string]interface{}{ "input_type": "search_query", "embedding_types": []string{"float"}, - "truncate": truncate, - "max_tokens": float64(128), "trace_id": "req-123", }, }, @@ -60,16 +63,23 @@ func TestToBedrockCohereEmbeddingRequest(t *testing.T) { assert.Equal(t, []string{"hello"}, req.Texts) assert.Equal(t, []string{"float"}, req.EmbeddingTypes) assert.Equal(t, &dimensions, req.OutputDimension) - 
assert.Equal(t, 128, *req.MaxTokens) + assert.Equal(t, &maxTokens, req.MaxTokens) require.NotNil(t, req.Truncate) assert.Equal(t, truncate, *req.Truncate) assert.Equal(t, map[string]interface{}{"trace_id": "req-123"}, req.ExtraParams) }) - t.Run("multiple texts preserve bedrock body shape", func(t *testing.T) { + t.Run("multiple text contents batch into texts array", func(t *testing.T) { + hello := "hello" + world := "world" bifrostReq := &schemas.BifrostEmbeddingRequest{ Model: "cohere.embed-multilingual-v3", - Input: &schemas.EmbeddingInput{Texts: []string{"hello", "world"}}, + Input: &schemas.EmbeddingInput{ + Contents: []schemas.EmbeddingContent{ + {{Type: schemas.EmbeddingContentPartTypeText, Text: &hello}}, + {{Type: schemas.EmbeddingContentPartTypeText, Text: &world}}, + }, + }, Params: &schemas.EmbeddingParameters{ ExtraParams: map[string]interface{}{ "input_type": "search_document", @@ -84,11 +94,15 @@ func TestToBedrockCohereEmbeddingRequest(t *testing.T) { }) } -func TestToBedrockCohereEmbeddingRequestBodyOmitsModel(t *testing.T) { +func TestToBedrockCohereEmbeddingRequestWireBody(t *testing.T) { text := "hello" bifrostReq := &schemas.BifrostEmbeddingRequest{ Model: "cohere.embed-english-v3", - Input: &schemas.EmbeddingInput{Text: &text}, + Input: &schemas.EmbeddingInput{ + Contents: []schemas.EmbeddingContent{ + {{Type: schemas.EmbeddingContentPartTypeText, Text: &text}}, + }, + }, Params: &schemas.EmbeddingParameters{ ExtraParams: map[string]interface{}{ "input_type": "search_document", @@ -105,8 +119,8 @@ func TestToBedrockCohereEmbeddingRequestBodyOmitsModel(t *testing.T) { }, ) require.Nil(t, bifrostErr) - assert.NotContains(t, string(wireBody), `"model"`) assert.JSONEq(t, `{ + "model": "cohere.embed-english-v3", "input_type": "search_document", "texts": ["hello"], "embedding_types": ["float"] diff --git a/core/providers/bedrock/invoke.go b/core/providers/bedrock/invoke.go index 8227e8639a..8625545976 100644 --- a/core/providers/bedrock/invoke.go 
+++ b/core/providers/bedrock/invoke.go @@ -393,12 +393,49 @@ func (r *BedrockInvokeRequest) ToBifrostEmbeddingRequest(ctx *schemas.BifrostCon Model: model, } + var contents []schemas.EmbeddingContent if r.InputText != "" { - req.Input = &schemas.EmbeddingInput{Text: &r.InputText} + inputText := r.InputText + contents = append(contents, schemas.EmbeddingContent{ + {Type: schemas.EmbeddingContentPartTypeText, Text: &inputText}, + }) } else if len(r.Texts) > 0 { - req.Input = &schemas.EmbeddingInput{Texts: r.Texts} + for _, t := range r.Texts { + text := t + contents = append(contents, schemas.EmbeddingContent{ + {Type: schemas.EmbeddingContentPartTypeText, Text: &text}, + }) + } + } + for _, img := range r.Images { + imgCopy := img + contents = append(contents, schemas.EmbeddingContent{ + {Type: schemas.EmbeddingContentPartTypeImage, Image: &schemas.EmbeddingMediaPart{Data: &imgCopy}}, + }) + } + for _, input := range r.Inputs { + content := make(schemas.EmbeddingContent, 0, len(input.Content)) + for _, block := range input.Content { + switch block.Type { + case "text": + if block.Text != nil { + t := *block.Text + content = append(content, schemas.EmbeddingContentPart{Type: schemas.EmbeddingContentPartTypeText, Text: &t}) + } + case "image_url": + if block.ImageURL != nil { + u := block.ImageURL.URL + content = append(content, schemas.EmbeddingContentPart{Type: schemas.EmbeddingContentPartTypeImage, Image: &schemas.EmbeddingMediaPart{URL: &u}}) + } + } + } + if len(content) > 0 { + contents = append(contents, content) + } + } + if len(contents) > 0 { + req.Input = &schemas.EmbeddingInput{Contents: contents} } - // image-only (r.Images) or mixed (r.Inputs): req.Input stays nil; data flows via ExtraParams extraParams := make(map[string]interface{}) // Forward known embedding-only params into ExtraParams so the provider can pick them up @@ -989,8 +1026,8 @@ func ToBedrockEmbeddingInvokeResponse(resp *schemas.BifrostEmbeddingResponse) (i if 
strings.Contains(strings.ToLower(model), "cohere") { floats := make([][]float32, 0, len(resp.Data)) for _, d := range resp.Data { - float32Emb := make([]float32, len(d.Embedding.EmbeddingArray)) - for i, v := range d.Embedding.EmbeddingArray { + float32Emb := make([]float32, len(d.Embedding.Float)) + for i, v := range d.Embedding.Float { float32Emb[i] = float32(v) } floats = append(floats, float32Emb) @@ -1002,11 +1039,11 @@ func ToBedrockEmbeddingInvokeResponse(resp *schemas.BifrostEmbeddingResponse) (i } // Titan format - if resp.Data[0].Embedding.EmbeddingArray == nil { + if len(resp.Data[0].Embedding.Float) == 0 { return &BedrockInvokeEmbeddingResp{InputTextTokenCount: tokenCount}, nil } - float32Emb := make([]float32, len(resp.Data[0].Embedding.EmbeddingArray)) - for i, v := range resp.Data[0].Embedding.EmbeddingArray { + float32Emb := make([]float32, len(resp.Data[0].Embedding.Float)) + for i, v := range resp.Data[0].Embedding.Float { float32Emb[i] = float32(v) } return &BedrockInvokeEmbeddingResp{ diff --git a/core/providers/cerebras/cerebras.go b/core/providers/cerebras/cerebras.go index 45292d6d24..f928312070 100644 --- a/core/providers/cerebras/cerebras.go +++ b/core/providers/cerebras/cerebras.go @@ -205,6 +205,10 @@ func (provider *CerebrasProvider) Embedding(ctx *schemas.BifrostContext, key sch return nil, providerUtils.NewUnsupportedOperationError(schemas.EmbeddingRequest, provider.GetProviderKey()) } +func (provider *CerebrasProvider) BatchEmbedding(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostBatchEmbeddingRequest) (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) { + return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchEmbeddingRequest, provider.GetProviderKey()) +} + // Speech is not supported by the Cerebras provider. 
func (provider *CerebrasProvider) Speech(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostSpeechRequest) (*schemas.BifrostSpeechResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.SpeechRequest, provider.GetProviderKey()) diff --git a/core/providers/cohere/cohere.go b/core/providers/cohere/cohere.go index 1bffa10e87..7604c3522c 100644 --- a/core/providers/cohere/cohere.go +++ b/core/providers/cohere/cohere.go @@ -938,6 +938,10 @@ func (provider *CohereProvider) Embedding(ctx *schemas.BifrostContext, key schem return bifrostResponse, nil } +func (provider *CohereProvider) BatchEmbedding(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostBatchEmbeddingRequest) (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) { + return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchEmbeddingRequest, provider.GetProviderKey()) +} + // Rerank performs a rerank request using the Cohere /v2/rerank API. 
func (provider *CohereProvider) Rerank(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostRerankRequest) (*schemas.BifrostRerankResponse, *schemas.BifrostError) { // Check if rerank is allowed diff --git a/core/providers/cohere/cohere_test.go b/core/providers/cohere/cohere_test.go index 23c73911bc..bfb2b86b35 100644 --- a/core/providers/cohere/cohere_test.go +++ b/core/providers/cohere/cohere_test.go @@ -28,7 +28,8 @@ func TestCohere(t *testing.T) { ChatModel: "command-a-03-2025", VisionModel: "command-a-vision-07-2025", // Cohere's latest vision model TextModel: "", // Cohere focuses on chat - EmbeddingModel: "embed-v4.0", + EmbeddingModel: "embed-v4.0", + MultimodalEmbeddingModel: "embed-v4.0", RerankModel: "rerank-v3.5", ReasoningModel: "command-a-reasoning-08-2025", Scenarios: llmtests.TestScenarios{ @@ -49,6 +50,7 @@ func TestCohere(t *testing.T) { FileURL: false, // Not supported CompleteEnd2End: false, Embedding: true, + MultimodalEmbedding: true, Rerank: true, Reasoning: true, ListModels: true, diff --git a/core/providers/cohere/embedding.go b/core/providers/cohere/embedding.go index a99ef14294..c855f139cf 100644 --- a/core/providers/cohere/embedding.go +++ b/core/providers/cohere/embedding.go @@ -1,81 +1,179 @@ package cohere import ( + "fmt" + + "github.com/google/uuid" "github.com/maximhq/bifrost/core/providers/utils" "github.com/maximhq/bifrost/core/schemas" ) -// ToCohereEmbeddingRequest converts a Bifrost embedding request to Cohere format -func ToCohereEmbeddingRequest(bifrostReq *schemas.BifrostEmbeddingRequest) *CohereEmbeddingRequest { - if bifrostReq == nil || bifrostReq.Input == nil || (bifrostReq.Input.Text == nil && bifrostReq.Input.Texts == nil) { - return nil +func cohereContentBlockFromEmbeddingPart(part schemas.EmbeddingContentPart) (*CohereContentBlock, error) { + if err := part.Validate(); err != nil { + return nil, err } - - embeddingInput := bifrostReq.Input - cohereReq := &CohereEmbeddingRequest{ - Model: 
bifrostReq.Model, + switch part.Type { + case schemas.EmbeddingContentPartTypeText: + text := *part.Text + return &CohereContentBlock{Type: CohereContentBlockTypeText, Text: &text}, nil + case schemas.EmbeddingContentPartTypeImage: + if part.Image.URL != nil { + return &CohereContentBlock{ + Type: CohereContentBlockTypeImage, + ImageURL: &CohereImageURL{URL: *part.Image.URL}, + }, nil + } + if part.Image.Data != nil { + return &CohereContentBlock{ + Type: CohereContentBlockTypeImage, + ImageURL: &CohereImageURL{URL: *part.Image.Data}, + }, nil + } + return nil, fmt.Errorf("cohere image part missing data") + default: + return nil, fmt.Errorf("cohere embeddings support only text and image parts") } +} - texts := []string{} - if embeddingInput.Text != nil { - texts = append(texts, *embeddingInput.Text) - } else { - texts = embeddingInput.Texts +func embeddingContentFromCohereBlocks(blocks []CohereContentBlock) (schemas.EmbeddingContent, error) { + result := make(schemas.EmbeddingContent, 0, len(blocks)) + for _, block := range blocks { + switch block.Type { + case CohereContentBlockTypeText: + if block.Text == nil { + return nil, fmt.Errorf("cohere text block missing text") + } + text := *block.Text + result = append(result, schemas.EmbeddingContentPart{ + Type: schemas.EmbeddingContentPartTypeText, + Text: &text, + }) + case CohereContentBlockTypeImage: + if block.ImageURL == nil { + return nil, fmt.Errorf("cohere image block missing image_url") + } + url := block.ImageURL.URL + result = append(result, schemas.EmbeddingContentPart{ + Type: schemas.EmbeddingContentPartTypeImage, + Image: &schemas.EmbeddingMediaPart{URL: &url}, + }) + default: + return nil, fmt.Errorf("unsupported cohere embedding block type %q", block.Type) + } } + return result, nil +} - // Convert texts from Bifrost format - if len(texts) > 0 { - cohereReq.Texts = texts +func isSingleImageContent(content schemas.EmbeddingContent) (string, bool) { + if len(content) != 1 || content[0].Type != 
schemas.EmbeddingContentPartTypeImage || content[0].Image == nil { + return "", false + } + if content[0].Image.URL != nil { + return *content[0].Image.URL, true } + if content[0].Image.Data != nil { + return *content[0].Image.Data, true + } + return "", false +} - // Set default input type if not specified in extra params - cohereReq.InputType = "search_document" // Default value +// ToCohereEmbeddingRequest converts a Bifrost embedding request to Cohere format. +func ToCohereEmbeddingRequest(bifrostReq *schemas.BifrostEmbeddingRequest) *CohereEmbeddingRequest { + if bifrostReq == nil || bifrostReq.Input == nil || len(bifrostReq.Input.Contents) == 0 { + return nil + } + cohereReq := &CohereEmbeddingRequest{ + Model: bifrostReq.Model, + InputType: "search_document", + } if bifrostReq.Params != nil { cohereReq.OutputDimension = bifrostReq.Params.Dimensions + cohereReq.MaxTokens = bifrostReq.Params.MaxTokens + cohereReq.Truncate = bifrostReq.Params.Truncate cohereReq.ExtraParams = bifrostReq.Params.ExtraParams - if bifrostReq.Params.ExtraParams != nil { - if maxTokens, ok := schemas.SafeExtractIntPointer(bifrostReq.Params.ExtraParams["max_tokens"]); ok { - delete(cohereReq.ExtraParams, "max_tokens") - cohereReq.MaxTokens = maxTokens + } + + contents := bifrostReq.Input.Contents + + // All single-text contents → texts[] + texts := make([]string, 0, len(contents)) + allSingleText := true + for _, content := range contents { + if len(content) == 1 && content[0].Type == schemas.EmbeddingContentPartTypeText && content[0].Text != nil { + texts = append(texts, *content[0].Text) + } else { + allSingleText = false + break + } + } + if allSingleText { + cohereReq.Texts = texts + } else if len(contents) == 1 { + // Single content with single image → images[] + if imageURL, ok := isSingleImageContent(contents[0]); ok { + cohereReq.Images = []string{imageURL} + } else { + // Single multimodal content → inputs[] with one entry + blocks := make([]CohereContentBlock, 0, 
len(contents[0])) + for _, part := range contents[0] { + block, err := cohereContentBlockFromEmbeddingPart(part) + if err != nil { + return nil + } + blocks = append(blocks, *block) + } + cohereReq.Inputs = []CohereEmbeddingInput{{Content: blocks}} + } + } else { + // Batch multimodal → inputs[], one entry per content + inputs := make([]CohereEmbeddingInput, 0, len(contents)) + for _, content := range contents { + blocks := make([]CohereContentBlock, 0, len(content)) + for _, part := range content { + block, err := cohereContentBlockFromEmbeddingPart(part) + if err != nil { + return nil + } + blocks = append(blocks, *block) } + inputs = append(inputs, CohereEmbeddingInput{Content: blocks}) } + cohereReq.Inputs = inputs } - // Handle extra params if bifrostReq.Params != nil && bifrostReq.Params.ExtraParams != nil { - // Input type - if inputType, ok := schemas.SafeExtractString(bifrostReq.Params.ExtraParams["input_type"]); ok { - delete(cohereReq.ExtraParams, "input_type") - cohereReq.InputType = inputType - } - - // Embedding types if embeddingTypes, ok := schemas.SafeExtractStringSlice(bifrostReq.Params.ExtraParams["embedding_types"]); ok { - if len(embeddingTypes) > 0 { + cohereReq.EmbeddingTypes = embeddingTypes + if cohereReq.ExtraParams != nil { delete(cohereReq.ExtraParams, "embedding_types") - cohereReq.EmbeddingTypes = embeddingTypes } } - - // Truncate - if truncate, ok := schemas.SafeExtractStringPointer(bifrostReq.Params.ExtraParams["truncate"]); ok { - delete(cohereReq.ExtraParams, "truncate") - cohereReq.Truncate = truncate + if inputType, ok := schemas.SafeExtractString(bifrostReq.Params.ExtraParams["input_type"]); ok { + cohereReq.InputType = inputType + if cohereReq.ExtraParams != nil { + delete(cohereReq.ExtraParams, "input_type") + } + } + if priority, ok := schemas.SafeExtractIntPointer(bifrostReq.Params.ExtraParams["priority"]); ok { + cohereReq.Priority = priority + if cohereReq.ExtraParams != nil { + delete(cohereReq.ExtraParams, "priority") + 
} } } return cohereReq } -// ToBifrostEmbeddingRequest converts a Cohere embedding request to Bifrost format +// ToBifrostEmbeddingRequest converts a Cohere embedding request to Bifrost format. +// Each Cohere input entry maps to one element in Contents (one output embedding). func (req *CohereEmbeddingRequest) ToBifrostEmbeddingRequest(ctx *schemas.BifrostContext) *schemas.BifrostEmbeddingRequest { if req == nil { return nil } provider, model := schemas.ParseModelString(req.Model, utils.CheckAndSetDefaultProvider(ctx, schemas.Cohere)) - bifrostReq := &schemas.BifrostEmbeddingRequest{ Provider: provider, Model: model, @@ -83,33 +181,56 @@ func (req *CohereEmbeddingRequest) ToBifrostEmbeddingRequest(ctx *schemas.Bifros Params: &schemas.EmbeddingParameters{}, } - // Convert texts - if len(req.Texts) > 0 { - if len(req.Texts) == 1 { - bifrostReq.Input.Text = &req.Texts[0] - } else { - bifrostReq.Input.Texts = req.Texts + switch { + case len(req.Texts) > 0: + contents := make([]schemas.EmbeddingContent, len(req.Texts)) + for i, text := range req.Texts { + t := text + contents[i] = schemas.EmbeddingContent{{ + Type: schemas.EmbeddingContentPartTypeText, + Text: &t, + }} } + bifrostReq.Input.Contents = contents + case len(req.Images) > 0: + url := req.Images[0] + bifrostReq.Input.Contents = []schemas.EmbeddingContent{{{ + Type: schemas.EmbeddingContentPartTypeImage, + Image: &schemas.EmbeddingMediaPart{URL: &url}, + }}} + case len(req.Inputs) > 0: + contents := make([]schemas.EmbeddingContent, 0, len(req.Inputs)) + for _, input := range req.Inputs { + content, err := embeddingContentFromCohereBlocks(input.Content) + if err != nil { + return nil + } + contents = append(contents, content) + } + bifrostReq.Input.Contents = contents } - // Convert parameters - if req.OutputDimension != nil { - bifrostReq.Params.Dimensions = req.OutputDimension - } - - // Convert extra params - extraParams := make(map[string]interface{}) + bifrostReq.Params.Dimensions = req.OutputDimension + 
bifrostReq.Params.MaxTokens = req.MaxTokens + bifrostReq.Params.Truncate = req.Truncate + extraParams := req.ExtraParams if req.InputType != "" { + if extraParams == nil { + extraParams = map[string]interface{}{} + } extraParams["input_type"] = req.InputType } - if req.EmbeddingTypes != nil { + if len(req.EmbeddingTypes) > 0 { + if extraParams == nil { + extraParams = map[string]interface{}{} + } extraParams["embedding_types"] = req.EmbeddingTypes } - if req.Truncate != nil { - extraParams["truncate"] = *req.Truncate - } - if req.MaxTokens != nil { - extraParams["max_tokens"] = *req.MaxTokens + if req.Priority != nil { + if extraParams == nil { + extraParams = map[string]interface{}{} + } + extraParams["priority"] = req.Priority } if len(extraParams) > 0 { bifrostReq.Params.ExtraParams = extraParams @@ -128,42 +249,57 @@ func (response *CohereEmbeddingResponse) ToBifrostEmbeddingResponse() *schemas.B Object: "list", } - // Convert embeddings data if response.Embeddings != nil { - var bifrostEmbeddings []schemas.EmbeddingData - - // Handle different embedding types - prioritize float embeddings - if response.Embeddings.Float != nil { - for i, embedding := range response.Embeddings.Float { - bifrostEmbedding := schemas.EmbeddingData{ - Object: "embedding", - Index: i, - Embedding: schemas.EmbeddingStruct{ - EmbeddingArray: embedding, - }, - } - bifrostEmbeddings = append(bifrostEmbeddings, bifrostEmbedding) - } - } else if response.Embeddings.Base64 != nil { - // Handle base64 embeddings as strings - for i, embedding := range response.Embeddings.Base64 { - bifrostEmbedding := schemas.EmbeddingData{ - Object: "embedding", - Index: i, - Embedding: schemas.EmbeddingStruct{ - EmbeddingStr: &embedding, - }, - } - bifrostEmbeddings = append(bifrostEmbeddings, bifrostEmbedding) - } + emb := response.Embeddings + + // Determine the number of entries from whichever type is populated + count := 0 + switch { + case len(emb.Float) > 0: + count = len(emb.Float) + case 
len(emb.Int8) > 0: + count = len(emb.Int8) + case len(emb.Uint8) > 0: + count = len(emb.Uint8) + case len(emb.Binary) > 0: + count = len(emb.Binary) + case len(emb.Ubinary) > 0: + count = len(emb.Ubinary) + case len(emb.Base64) > 0: + count = len(emb.Base64) } - // Note: Int8, Uint8, Binary, Ubinary types would need special handling - // depending on how Bifrost wants to represent them - bifrostResponse.Data = bifrostEmbeddings + bifrostResponse.Data = make([]schemas.EmbeddingData, count) + for i := 0; i < count; i++ { + entry := schemas.EmbeddingData{ + Index: i, + Object: "embedding", + Embedding: schemas.EmbeddingsByType{}, + } + + if len(emb.Float) > i { + entry.Embedding.Float = emb.Float[i] + } + if len(emb.Int8) > i { + entry.Embedding.Int8 = emb.Int8[i] + } + if len(emb.Uint8) > i { + entry.Embedding.Uint8 = emb.Uint8[i] + } + if len(emb.Binary) > i { + entry.Embedding.Binary = emb.Binary[i] + } + if len(emb.Ubinary) > i { + entry.Embedding.Ubinary = emb.Ubinary[i] + } + if len(emb.Base64) > i { + entry.Embedding.Base64 = &emb.Base64[i] + } + + bifrostResponse.Data[i] = entry + } } - // Convert usage information if response.Meta != nil { if response.Meta.Tokens != nil { bifrostResponse.Usage = &schemas.BifrostLLMUsage{} @@ -188,3 +324,57 @@ func (response *CohereEmbeddingResponse) ToBifrostEmbeddingResponse() *schemas.B return bifrostResponse } + +// ToCohereEmbeddingResponse converts a BifrostEmbeddingResponse to Cohere's native embedding response format. 
+func ToCohereEmbeddingResponse(bifrostResp *schemas.BifrostEmbeddingResponse) *CohereEmbeddingResponse { + if bifrostResp == nil || len(bifrostResp.Data) == 0 { + return nil + } + + cohereResp := &CohereEmbeddingResponse{ + ID: uuid.New().String(), + Embeddings: &CohereEmbeddingData{}, + } + + for _, item := range bifrostResp.Data { + emb := item.Embedding + + if emb.Float != nil { + cohereResp.Embeddings.Float = append(cohereResp.Embeddings.Float, emb.Float) + } + if emb.Int8 != nil { + cohereResp.Embeddings.Int8 = append(cohereResp.Embeddings.Int8, emb.Int8) + } + if emb.Uint8 != nil { + cohereResp.Embeddings.Uint8 = append(cohereResp.Embeddings.Uint8, emb.Uint8) + } + if emb.Binary != nil { + cohereResp.Embeddings.Binary = append(cohereResp.Embeddings.Binary, emb.Binary) + } + if emb.Ubinary != nil { + cohereResp.Embeddings.Ubinary = append(cohereResp.Embeddings.Ubinary, emb.Ubinary) + } + if emb.Base64 != nil { + cohereResp.Embeddings.Base64 = append(cohereResp.Embeddings.Base64, *emb.Base64) + } + } + + cohereResp.ResponseType = schemas.Ptr("embeddings_by_type") + + if bifrostResp.Usage != nil { + inputTokens := bifrostResp.Usage.PromptTokens + outputTokens := bifrostResp.Usage.CompletionTokens + cohereResp.Meta = &CohereEmbeddingMeta{ + BilledUnits: &CohereBilledUnits{ + InputTokens: &inputTokens, + OutputTokens: &outputTokens, + }, + Tokens: &CohereTokenUsage{ + InputTokens: &inputTokens, + OutputTokens: &outputTokens, + }, + } + } + + return cohereResp +} diff --git a/core/providers/cohere/embedding_multimodal_test.go b/core/providers/cohere/embedding_multimodal_test.go new file mode 100644 index 0000000000..77292ac6dc --- /dev/null +++ b/core/providers/cohere/embedding_multimodal_test.go @@ -0,0 +1,54 @@ +package cohere + +import ( + "testing" + + "github.com/maximhq/bifrost/core/schemas" + "github.com/stretchr/testify/require" +) + +func TestToCohereEmbeddingRequestTextOnlyUsesTexts(t *testing.T) { + text := "hello" + req := 
ToCohereEmbeddingRequest(&schemas.BifrostEmbeddingRequest{ + Model: "embed-v4.0", + Input: &schemas.EmbeddingInput{ + Contents: []schemas.EmbeddingContent{{{Type: schemas.EmbeddingContentPartTypeText, Text: &text}}}, + }, + }) + + require.NotNil(t, req) + require.Equal(t, []string{"hello"}, req.Texts) + require.Empty(t, req.Inputs) +} + +func TestToCohereEmbeddingRequestMultimodalUsesInputs(t *testing.T) { + caption := "caption" + req := ToCohereEmbeddingRequest(&schemas.BifrostEmbeddingRequest{ + Model: "embed-v4.0", + Input: &schemas.EmbeddingInput{ + Contents: []schemas.EmbeddingContent{{ + {Type: schemas.EmbeddingContentPartTypeText, Text: &caption}, + {Type: schemas.EmbeddingContentPartTypeImage, Image: &schemas.EmbeddingMediaPart{URL: schemas.Ptr("https://example.com/cat.png")}}, + }}, + }, + }) + + require.NotNil(t, req) + require.Len(t, req.Inputs, 1) + require.Len(t, req.Inputs[0].Content, 2) + require.Equal(t, CohereContentBlockTypeText, req.Inputs[0].Content[0].Type) + require.Equal(t, CohereContentBlockTypeImage, req.Inputs[0].Content[1].Type) +} + +func TestToCohereEmbeddingRequestRejectsUnsupportedModalities(t *testing.T) { + req := ToCohereEmbeddingRequest(&schemas.BifrostEmbeddingRequest{ + Model: "embed-v4.0", + Input: &schemas.EmbeddingInput{ + Contents: []schemas.EmbeddingContent{{ + {Type: schemas.EmbeddingContentPartTypeAudio, Audio: &schemas.EmbeddingMediaPart{URL: schemas.Ptr("https://example.com/audio.mp3")}}, + }}, + }, + }) + + require.Nil(t, req) +} diff --git a/core/providers/cohere/embedding_test.go b/core/providers/cohere/embedding_test.go index a161c95837..167fc97a5a 100644 --- a/core/providers/cohere/embedding_test.go +++ b/core/providers/cohere/embedding_test.go @@ -19,21 +19,25 @@ func TestToCohereEmbeddingRequest(t *testing.T) { })) }) - t.Run("single text keeps model in direct cohere body", func(t *testing.T) { + t.Run("single text content extracts typed params", func(t *testing.T) { text := "hello" truncate := "END" dimensions 
:= 1024 maxTokens := 256 bifrostReq := &schemas.BifrostEmbeddingRequest{ Model: "embed-v4.0", - Input: &schemas.EmbeddingInput{Text: &text}, + Input: &schemas.EmbeddingInput{ + Contents: []schemas.EmbeddingContent{ + {{Type: schemas.EmbeddingContentPartTypeText, Text: &text}}, + }, + }, Params: &schemas.EmbeddingParameters{ Dimensions: &dimensions, + MaxTokens: &maxTokens, + Truncate: &truncate, ExtraParams: map[string]interface{}{ "input_type": "classification", "embedding_types": []string{"float", "int8"}, - "truncate": truncate, - "max_tokens": maxTokens, "priority": "high", }, }, @@ -52,10 +56,17 @@ func TestToCohereEmbeddingRequest(t *testing.T) { assert.Equal(t, map[string]interface{}{"priority": "high"}, req.ExtraParams) }) - t.Run("multiple texts use default input type", func(t *testing.T) { + t.Run("multiple text contents batch into texts array with default input type", func(t *testing.T) { + hello := "hello" + world := "world" req := ToCohereEmbeddingRequest(&schemas.BifrostEmbeddingRequest{ Model: "embed-english-v3.0", - Input: &schemas.EmbeddingInput{Texts: []string{"hello", "world"}}, + Input: &schemas.EmbeddingInput{ + Contents: []schemas.EmbeddingContent{ + {{Type: schemas.EmbeddingContentPartTypeText, Text: &hello}}, + {{Type: schemas.EmbeddingContentPartTypeText, Text: &world}}, + }, + }, }) require.NotNil(t, req) @@ -64,13 +75,39 @@ func TestToCohereEmbeddingRequest(t *testing.T) { assert.Equal(t, []string{"hello", "world"}, req.Texts) assert.Nil(t, req.ExtraParams) }) + + t.Run("multimodal content uses inputs array", func(t *testing.T) { + text := "describe this" + imageURL := "data:image/jpeg;base64,abc123" + req := ToCohereEmbeddingRequest(&schemas.BifrostEmbeddingRequest{ + Model: "embed-v4.0", + Input: &schemas.EmbeddingInput{ + Contents: []schemas.EmbeddingContent{ + { + {Type: schemas.EmbeddingContentPartTypeText, Text: &text}, + {Type: schemas.EmbeddingContentPartTypeImage, Image: &schemas.EmbeddingMediaPart{URL: &imageURL}}, + }, + }, + 
}, + }) + + require.NotNil(t, req) + require.Len(t, req.Inputs, 1) + require.Len(t, req.Inputs[0].Content, 2) + assert.Equal(t, CohereContentBlockTypeText, req.Inputs[0].Content[0].Type) + assert.Equal(t, CohereContentBlockTypeImage, req.Inputs[0].Content[1].Type) + }) } func TestToCohereEmbeddingRequestBodyIncludesModelForDirectCohere(t *testing.T) { text := "hello" bifrostReq := &schemas.BifrostEmbeddingRequest{ Model: "embed-v4.0", - Input: &schemas.EmbeddingInput{Text: &text}, + Input: &schemas.EmbeddingInput{ + Contents: []schemas.EmbeddingContent{ + {{Type: schemas.EmbeddingContentPartTypeText, Text: &text}}, + }, + }, } wireBody, bifrostErr := providerUtils.CheckContextAndGetRequestBody( diff --git a/core/providers/cohere/types.go b/core/providers/cohere/types.go index 8e5aa31402..510659bf55 100644 --- a/core/providers/cohere/types.go +++ b/core/providers/cohere/types.go @@ -272,6 +272,7 @@ type CohereEmbeddingRequest struct { OutputDimension *int `json:"output_dimension,omitempty"` // Optional: Embedding dimensions (256, 512, 1024, 1536) EmbeddingTypes []string `json:"embedding_types,omitempty"` // Optional: Types of embeddings to return Truncate *string `json:"truncate,omitempty"` // Optional: How to handle long inputs + Priority *int `json:"priority,omitempty"` // Optional: Priority of the request ExtraParams map[string]interface{} `json:"-"` // Optional: Extra parameters } diff --git a/core/providers/elevenlabs/elevenlabs.go b/core/providers/elevenlabs/elevenlabs.go index f092ce6d16..5ef27cb004 100644 --- a/core/providers/elevenlabs/elevenlabs.go +++ b/core/providers/elevenlabs/elevenlabs.go @@ -179,6 +179,10 @@ func (provider *ElevenlabsProvider) Embedding(ctx *schemas.BifrostContext, key s return nil, providerUtils.NewUnsupportedOperationError(schemas.EmbeddingRequest, provider.GetProviderKey()) } +func (provider *ElevenlabsProvider) BatchEmbedding(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostBatchEmbeddingRequest) 
(*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) { + return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchEmbeddingRequest, provider.GetProviderKey()) +} + // Speech performs a text to speech request func (provider *ElevenlabsProvider) Speech(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostSpeechRequest) (*schemas.BifrostSpeechResponse, *schemas.BifrostError) { if err := providerUtils.CheckOperationAllowed(schemas.Elevenlabs, provider.customProviderConfig, schemas.SpeechRequest); err != nil { diff --git a/core/providers/fireworks/fireworks.go b/core/providers/fireworks/fireworks.go index 827d1777df..39693eff8c 100644 --- a/core/providers/fireworks/fireworks.go +++ b/core/providers/fireworks/fireworks.go @@ -234,6 +234,10 @@ func (provider *FireworksProvider) Embedding(ctx *schemas.BifrostContext, key sc ) } +func (provider *FireworksProvider) BatchEmbedding(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostBatchEmbeddingRequest) (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) { + return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchEmbeddingRequest, provider.GetProviderKey()) +} + // Speech is not supported by the Fireworks AI provider. 
func (provider *FireworksProvider) Speech(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostSpeechRequest) (*schemas.BifrostSpeechResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.SpeechRequest, provider.GetProviderKey()) diff --git a/core/providers/fireworks/fireworks_test.go b/core/providers/fireworks/fireworks_test.go index 445c6c83d0..e74de7d720 100644 --- a/core/providers/fireworks/fireworks_test.go +++ b/core/providers/fireworks/fireworks_test.go @@ -188,7 +188,9 @@ func fireworksModelSupportsEmbeddings(t *testing.T, client *bifrost.Bifrost, ctx Provider: schemas.Fireworks, Model: model, Input: &schemas.EmbeddingInput{ - Text: &text, + Contents: []schemas.EmbeddingContent{ + {{Type: schemas.EmbeddingContentPartTypeText, Text: &text}}, + }, }, }) if bifrostErr != nil { @@ -331,14 +333,16 @@ func TestFireworksProviderUsesNativeEndpoints(t *testing.T) { resp, err := provider.Embedding(ctx, key, &schemas.BifrostEmbeddingRequest{ Provider: schemas.Fireworks, Model: "accounts/fireworks/models/nomic-embed-text-v1.5", - Input: &schemas.EmbeddingInput{ - Text: schemas.Ptr("embedding test"), + Input: &schemas.EmbeddingInput{ + Contents: []schemas.EmbeddingContent{ + {{Type: schemas.EmbeddingContentPartTypeText, Text: schemas.Ptr("embedding test")}}, }, + }, }) if err != nil { t.Fatalf("Embedding returned error: %v", llmtests.GetErrorMessage(err)) } - if resp == nil || len(resp.Data) != 1 || len(resp.Data[0].Embedding.EmbeddingArray) != 3 { + if resp == nil || len(resp.Data) != 1 || len(resp.Data[0].Embedding.Float) != 3 { t.Fatalf("unexpected embedding response: %#v", resp) } }, diff --git a/core/providers/gemini/embedding.go b/core/providers/gemini/embedding.go index 79785a4fd2..87ca9d50df 100644 --- a/core/providers/gemini/embedding.go +++ b/core/providers/gemini/embedding.go @@ -1,76 +1,209 @@ package gemini import ( - "github.com/maximhq/bifrost/core/providers/utils" + "fmt" + "strings" + + 
providerUtils "github.com/maximhq/bifrost/core/providers/utils" "github.com/maximhq/bifrost/core/schemas" ) -// ToGeminiEmbeddingRequest converts a BifrostRequest with embedding input to Gemini's batch embedding request format -// GeminiGenerationRequest contains requests array for batch embed content endpoint -func ToGeminiEmbeddingRequest(bifrostReq *schemas.BifrostEmbeddingRequest) *GeminiBatchEmbeddingRequest { - if bifrostReq == nil || bifrostReq.Input == nil || (bifrostReq.Input.Text == nil && bifrostReq.Input.Texts == nil) { - return nil +func mediaPartToGeminiPart(partType schemas.EmbeddingContentPartType, media *schemas.EmbeddingMediaPart) (*Part, error) { + if err := media.Validate(); err != nil { + return nil, err } - embeddingInput := bifrostReq.Input + defaultMime := map[schemas.EmbeddingContentPartType]string{ + schemas.EmbeddingContentPartTypeImage: "image/jpeg", + schemas.EmbeddingContentPartTypeAudio: "audio/mpeg", + schemas.EmbeddingContentPartTypeFile: "application/pdf", + schemas.EmbeddingContentPartTypeVideo: "video/mp4", + }[partType] + + if media.Data != nil { + dataBytes, extractedMime := convertFileDataToBytes(*media.Data) + if len(dataBytes) == 0 { + return nil, fmt.Errorf("empty media data for %s part", partType) + } + mimeType := defaultMime + if media.MIMEType != nil && strings.TrimSpace(*media.MIMEType) != "" { + mimeType = *media.MIMEType + } else if extractedMime != "" { + mimeType = extractedMime + } + return &Part{ + InlineData: &Blob{ + MIMEType: mimeType, + Data: encodeBytesToBase64String(dataBytes), + }, + }, nil + } - // Collect all texts to embed - var texts []string - if embeddingInput.Text != nil { - texts = append(texts, *embeddingInput.Text) + mimeType := defaultMime + if media.MIMEType != nil && strings.TrimSpace(*media.MIMEType) != "" { + mimeType = *media.MIMEType } - if len(embeddingInput.Texts) > 0 { - texts = append(texts, embeddingInput.Texts...) 
+ url := *media.URL + if partType == schemas.EmbeddingContentPartTypeImage { + sanitizedURL, err := schemas.SanitizeImageURL(url) + if err != nil { + return nil, err + } + urlInfo := schemas.ExtractURLTypeInfo(sanitizedURL) + if urlInfo.Type == schemas.ImageContentTypeBase64 { + data := "" + if urlInfo.DataURLWithoutPrefix != nil { + data = *urlInfo.DataURLWithoutPrefix + } + decoded, err := decodeBase64StringToBytes(data) + if err != nil { + return nil, err + } + if urlInfo.MediaType != nil && (media.MIMEType == nil || *media.MIMEType == "") { + mimeType = *urlInfo.MediaType + } + return &Part{ + InlineData: &Blob{ + MIMEType: mimeType, + Data: encodeBytesToBase64String(decoded), + }, + }, nil + } + url = sanitizedURL } - if len(texts) == 0 { - return nil + return &Part{ + FileData: &FileData{ + FileURI: url, + MIMEType: mimeType, + DisplayName: func() string { + if media.Filename != nil { + return *media.Filename + } + return "" + }(), + }, + }, nil +} + +func embeddingContentPartToGeminiPart(part schemas.EmbeddingContentPart) (*Part, error) { + if err := part.Validate(); err != nil { + return nil, err } - // Create batch embedding request with one request per text - batchRequest := &GeminiBatchEmbeddingRequest{ - Requests: make([]GeminiEmbeddingRequest, len(texts)), + switch part.Type { + case schemas.EmbeddingContentPartTypeText: + return &Part{Text: *part.Text}, nil + case schemas.EmbeddingContentPartTypeImage: + return mediaPartToGeminiPart(part.Type, part.Image) + case schemas.EmbeddingContentPartTypeAudio: + return mediaPartToGeminiPart(part.Type, part.Audio) + case schemas.EmbeddingContentPartTypeFile: + return mediaPartToGeminiPart(part.Type, part.File) + case schemas.EmbeddingContentPartTypeVideo: + return mediaPartToGeminiPart(part.Type, part.Video) + default: + return nil, fmt.Errorf("unsupported embedding content part type %q", part.Type) } - if bifrostReq.Params != nil { - batchRequest.ExtraParams = bifrostReq.Params.ExtraParams +} + +// 
EmbeddingContentToGeminiContent converts a Bifrost EmbeddingContent (a slice +// of typed parts) into the Gemini Content struct used by both the embedContent +// and batchEmbedContents endpoints. +func EmbeddingContentToGeminiContent(content schemas.EmbeddingContent) (*Content, error) { + if err := content.Validate(); err != nil { + return nil, err + } + parts := make([]*Part, 0, len(content)) + for _, contentPart := range content { + part, err := embeddingContentPartToGeminiPart(contentPart) + if err != nil { + return nil, err + } + parts = append(parts, part) } + return &Content{Parts: parts}, nil +} - // Create individual embedding requests for each text - for i, text := range texts { - embeddingReq := GeminiEmbeddingRequest{ - Model: "models/" + bifrostReq.Model, - Content: &Content{ - Parts: []*Part{ - { - Text: text, - }, - }, - }, +func applyGeminiEmbeddingParams(req *GeminiEmbeddingRequest, params *schemas.EmbeddingParameters) { + if params == nil { + return + } + req.OutputDimensionality = params.Dimensions + req.TaskType = params.TaskType + req.Title = params.Title + + if params.ExtraParams != nil { + req.ExtraParams = params.ExtraParams + if documentOCR, ok := schemas.SafeExtractBoolPointer(params.ExtraParams["documentOcr"]); ok { + delete(req.ExtraParams, "documentOcr") + req.DocumentOCR = documentOCR + } + if audioTrackExtraction, ok := schemas.SafeExtractBoolPointer(params.ExtraParams["audioTrackExtraction"]); ok { + delete(req.ExtraParams, "audioTrackExtraction") + req.AudioTrackExtraction = audioTrackExtraction } + } +} - // Add parameters if available - if bifrostReq.Params != nil { - if bifrostReq.Params.Dimensions != nil { - embeddingReq.OutputDimensionality = bifrostReq.Params.Dimensions - } +// ToGeminiEmbeddingRequest converts a Bifrost embedding request to Gemini request format. +// Each element in Contents maps to one GeminiEmbeddingRequest (one output embedding). +// Parts within a single content are aggregated into one embedding by Gemini. 
+func ToGeminiEmbeddingRequest(bifrostReq *schemas.BifrostEmbeddingRequest) (*GeminiBatchEmbeddingRequest, error) { + if bifrostReq == nil || bifrostReq.Input == nil || len(bifrostReq.Input.Contents) == 0 { + return nil, fmt.Errorf("bifrost request is nil or input is nil") + } - // Handle extra parameters - if bifrostReq.Params.ExtraParams != nil { - if taskType, ok := schemas.SafeExtractStringPointer(bifrostReq.Params.ExtraParams["taskType"]); ok { - delete(batchRequest.ExtraParams, "taskType") - embeddingReq.TaskType = taskType - } - if title, ok := schemas.SafeExtractStringPointer(bifrostReq.Params.ExtraParams["title"]); ok { - delete(batchRequest.ExtraParams, "title") - embeddingReq.Title = title - } - } + contents := bifrostReq.Input.Contents + + batchRequest := &GeminiBatchEmbeddingRequest{ + Requests: make([]GeminiEmbeddingRequest, 0, len(contents)), + } + if bifrostReq.Params != nil { + batchRequest.ExtraParams = bifrostReq.Params.ExtraParams + } + for _, contentItem := range contents { + content, err := EmbeddingContentToGeminiContent(contentItem) + if err != nil { + return nil, fmt.Errorf("error converting embedding content to gemini content: %w", err) + } + req := GeminiEmbeddingRequest{ + Model: "models/" + bifrostReq.Model, + Content: content, } + applyGeminiEmbeddingParams(&req, bifrostReq.Params) + batchRequest.Requests = append(batchRequest.Requests, req) + } + return batchRequest, nil +} - batchRequest.Requests[i] = embeddingReq +// ToGeminiBatchEmbeddingRequest converts a BifrostBatchEmbeddingRequest to Gemini's batchEmbedContents format. +// Each item maps to one GeminiEmbeddingRequest. Item-level Params overrides the batch-level default. 
+func ToGeminiBatchEmbeddingRequest(bifrostReq *schemas.BifrostBatchEmbeddingRequest) (*GeminiBatchEmbeddingRequest, error) { + if bifrostReq == nil || len(bifrostReq.Items) == 0 { + return nil, fmt.Errorf("batch embedding request has no items") + } + + batchRequest := &GeminiBatchEmbeddingRequest{ + Requests: make([]GeminiEmbeddingRequest, 0, len(bifrostReq.Items)), + } + if bifrostReq.Params != nil { + batchRequest.ExtraParams = bifrostReq.Params.ExtraParams } - return batchRequest + for _, item := range bifrostReq.Items { + content, err := EmbeddingContentToGeminiContent(item.Content) + if err != nil { + return nil, fmt.Errorf("error converting embedding content to gemini content: %w", err) + } + req := GeminiEmbeddingRequest{ + Model: "models/" + bifrostReq.Model, + Content: content, + } + applyGeminiEmbeddingParams(&req, item.EffectiveParams(bifrostReq.Params)) + batchRequest.Requests = append(batchRequest.Requests, req) + } + return batchRequest, nil } // ToGeminiEmbeddingResponse converts a BifrostResponse with embedding data to Gemini's embedding response format @@ -83,164 +216,268 @@ func ToGeminiEmbeddingResponse(bifrostResp *schemas.BifrostEmbeddingResponse) *G Embeddings: make([]GeminiEmbedding, len(bifrostResp.Data)), } - // Convert each embedding from Bifrost format to Gemini format for i, embedding := range bifrostResp.Data { - var values []float64 - - // Extract embedding values from BifrostEmbeddingResponse - if embedding.Embedding.EmbeddingArray != nil { - values = append([]float64(nil), embedding.Embedding.EmbeddingArray...) - } else if len(embedding.Embedding.Embedding2DArray) > 0 { - // If it's a 2D array, take the first array - values = append([]float64(nil), embedding.Embedding.Embedding2DArray[0]...) 
- } - - geminiEmbedding := GeminiEmbedding{ - Values: values, - } - - // Add statistics if available (token count from usage metadata) - if bifrostResp.Usage != nil { + geminiEmbedding := GeminiEmbedding{Values: append([]float64(nil), embedding.Embedding.Float...)} + if bifrostResp.Usage != nil && len(bifrostResp.Data) == 1 { geminiEmbedding.Statistics = &ContentEmbeddingStatistics{ TokenCount: int32(bifrostResp.Usage.PromptTokens), } } - geminiResp.Embeddings[i] = geminiEmbedding } - // Set metadata if available (for Vertex API compatibility) + if len(geminiResp.Embeddings) == 1 { + geminiResp.Embedding = &geminiResp.Embeddings[0] + } if bifrostResp.Usage != nil { geminiResp.Metadata = &EmbedContentMetadata{ BillableCharacterCount: int32(bifrostResp.Usage.PromptTokens), } } - return geminiResp } +func geminiResponseEmbeddings(resp *GeminiEmbeddingResponse) []GeminiEmbedding { + if resp == nil { + return nil + } + if len(resp.Embeddings) > 0 { + return resp.Embeddings + } + if resp.Embedding != nil { + return []GeminiEmbedding{*resp.Embedding} + } + return nil +} + // ToBifrostEmbeddingResponse converts a Gemini embedding response to BifrostEmbeddingResponse format func ToBifrostEmbeddingResponse(geminiResp *GeminiEmbeddingResponse, model string) *schemas.BifrostEmbeddingResponse { - if geminiResp == nil || len(geminiResp.Embeddings) == 0 { + embeddings := geminiResponseEmbeddings(geminiResp) + if len(embeddings) == 0 { return nil } bifrostResp := &schemas.BifrostEmbeddingResponse{ - Data: make([]schemas.EmbeddingData, len(geminiResp.Embeddings)), + Data: make([]schemas.EmbeddingData, len(embeddings)), Model: model, Object: "list", } - // Convert each embedding from Gemini format to Bifrost format - for i, geminiEmbedding := range geminiResp.Embeddings { - embeddingData := schemas.EmbeddingData{ + for i, geminiEmbedding := range embeddings { + bifrostResp.Data[i] = schemas.EmbeddingData{ Index: i, Object: "embedding", - Embedding: schemas.EmbeddingStruct{ - 
EmbeddingArray: geminiEmbedding.Values, + Embedding: schemas.EmbeddingsByType{ + Float: geminiEmbedding.Values, }, } - - bifrostResp.Data[i] = embeddingData } - // Convert usage metadata if available - if geminiResp.Metadata != nil || (len(geminiResp.Embeddings) > 0 && geminiResp.Embeddings[0].Statistics != nil) { + hasStats := false + for _, emb := range embeddings { + if emb.Statistics != nil { + hasStats = true + break + } + } + if geminiResp.Metadata != nil || hasStats { bifrostResp.Usage = &schemas.BifrostLLMUsage{} - - // Use statistics from the first embedding if available - if geminiResp.Embeddings[0].Statistics != nil { - bifrostResp.Usage.PromptTokens = int(geminiResp.Embeddings[0].Statistics.TokenCount) + var totalTokens int + for _, emb := range embeddings { + if emb.Statistics != nil { + totalTokens += int(emb.Statistics.TokenCount) + } + } + if totalTokens > 0 { + bifrostResp.Usage.PromptTokens = totalTokens } else if geminiResp.Metadata != nil { - // Fall back to metadata if statistics are not available bifrostResp.Usage.PromptTokens = int(geminiResp.Metadata.BillableCharacterCount) } - - // Set total tokens same as prompt tokens for embeddings bifrostResp.Usage.TotalTokens = bifrostResp.Usage.PromptTokens } return bifrostResp } -// ToBifrostEmbeddingRequest converts a GeminiGenerationRequest to BifrostEmbeddingRequest format -func (request *GeminiGenerationRequest) ToBifrostEmbeddingRequest(ctx *schemas.BifrostContext) *schemas.BifrostEmbeddingRequest { - if request == nil { +func geminiPartToEmbeddingContentPart(part *Part) (*schemas.EmbeddingContentPart, error) { + if part == nil { + return nil, fmt.Errorf("gemini part is nil") + } + switch { + case part.Text != "": + text := part.Text + return &schemas.EmbeddingContentPart{ + Type: schemas.EmbeddingContentPartTypeText, + Text: &text, + }, nil + case part.InlineData != nil: + mimeType := strings.ToLower(strings.TrimSpace(part.InlineData.MIMEType)) + data := fmt.Sprintf("data:%s;base64,%s", 
part.InlineData.MIMEType, part.InlineData.Data) + mime := part.InlineData.MIMEType + media := &schemas.EmbeddingMediaPart{ + Data: &data, + MIMEType: &mime, + } + switch { + case strings.HasPrefix(mimeType, "image/"): + return &schemas.EmbeddingContentPart{Type: schemas.EmbeddingContentPartTypeImage, Image: media}, nil + case strings.HasPrefix(mimeType, "audio/"): + return &schemas.EmbeddingContentPart{Type: schemas.EmbeddingContentPartTypeAudio, Audio: media}, nil + case strings.HasPrefix(mimeType, "video/"): + return &schemas.EmbeddingContentPart{Type: schemas.EmbeddingContentPartTypeVideo, Video: media}, nil + default: + return &schemas.EmbeddingContentPart{Type: schemas.EmbeddingContentPartTypeFile, File: media}, nil + } + case part.FileData != nil: + uri := part.FileData.FileURI + mime := part.FileData.MIMEType + media := &schemas.EmbeddingMediaPart{ + URL: &uri, + MIMEType: &mime, + } + if part.FileData.DisplayName != "" { + name := part.FileData.DisplayName + media.Filename = &name + } + mimeType := strings.ToLower(strings.TrimSpace(part.FileData.MIMEType)) + switch { + case strings.HasPrefix(mimeType, "image/"): + return &schemas.EmbeddingContentPart{Type: schemas.EmbeddingContentPartTypeImage, Image: media}, nil + case strings.HasPrefix(mimeType, "audio/"): + return &schemas.EmbeddingContentPart{Type: schemas.EmbeddingContentPartTypeAudio, Audio: media}, nil + case strings.HasPrefix(mimeType, "video/"): + return &schemas.EmbeddingContentPart{Type: schemas.EmbeddingContentPartTypeVideo, Video: media}, nil + default: + return &schemas.EmbeddingContentPart{Type: schemas.EmbeddingContentPartTypeFile, File: media}, nil + } + default: + return nil, fmt.Errorf("unsupported gemini embedding part") + } +} + +func geminiContentToEmbeddingContent(content *Content) (schemas.EmbeddingContent, error) { + if content == nil { + return nil, fmt.Errorf("gemini embedding content is nil") + } + result := make(schemas.EmbeddingContent, 0, len(content.Parts)) + for _, part := 
range content.Parts { + converted, err := geminiPartToEmbeddingContentPart(part) + if err != nil { + return nil, err + } + result = append(result, *converted) + } + return result, nil +} + +func applyBifrostEmbeddingParams(params *schemas.EmbeddingParameters, req GeminiEmbeddingRequest) *schemas.EmbeddingParameters { + if params == nil { + params = &schemas.EmbeddingParameters{} + } + changed := false + if req.OutputDimensionality != nil { + params.Dimensions = req.OutputDimensionality + changed = true + } + if req.TaskType != nil { + params.TaskType = req.TaskType + changed = true + } + if req.Title != nil { + params.Title = req.Title + changed = true + } + if req.DocumentOCR != nil { + if params.ExtraParams == nil { + params.ExtraParams = map[string]interface{}{} + } + params.ExtraParams["documentOcr"] = req.DocumentOCR + changed = true + } + if req.AudioTrackExtraction != nil { + if params.ExtraParams == nil { + params.ExtraParams = map[string]interface{}{} + } + params.ExtraParams["audioTrackExtraction"] = req.AudioTrackExtraction + changed = true + } + if !changed { + return nil + } + return params +} + +// ToBifrostBatchEmbeddingRequest converts a GeminiBatchEmbeddingRequest (wire format from +// :batchEmbedContents) to BifrostBatchEmbeddingRequest. Per-item taskType/title/dimensions +// are preserved as item-level Params; the first request's params become the batch-level default. 
+func (r *GeminiBatchEmbeddingRequest) ToBifrostBatchEmbeddingRequest(ctx *schemas.BifrostContext) (*schemas.BifrostBatchEmbeddingRequest, error) { + if r == nil || len(r.Requests) == 0 { + return nil, fmt.Errorf("batch embedding request is empty") + } - provider, model := schemas.ParseModelString(request.Model, utils.CheckAndSetDefaultProvider(ctx, schemas.Gemini)) + provider, model := schemas.ParseModelString(r.Model, providerUtils.CheckAndSetDefaultProvider(ctx, schemas.Gemini)) - // Create the embedding request - bifrostReq := &schemas.BifrostEmbeddingRequest{ + bifrostReq := &schemas.BifrostBatchEmbeddingRequest{ Provider: provider, Model: model, - Fallbacks: schemas.ParseFallbacks(request.Fallbacks), + Params: applyBifrostEmbeddingParams(nil, r.Requests[0]), // first request params as batch-level default + Fallbacks: schemas.ParseFallbacks(nil), + Items: make([]schemas.BifrostEmbeddingBatchItem, 0, len(r.Requests)), } - // SDK batch embedding request contains multiple embedding requests with same parameters but different text fields. 
- if len(request.Requests) > 0 { - var texts []string - for _, req := range request.Requests { - if req.Content != nil && len(req.Content.Parts) > 0 { - for _, part := range req.Content.Parts { - if part != nil && part.Text != "" { - texts = append(texts, part.Text) - } - } - } - } - if len(texts) > 0 { - bifrostReq.Input = &schemas.EmbeddingInput{} - if len(texts) == 1 { - bifrostReq.Input.Text = &texts[0] - } else { - bifrostReq.Input.Texts = texts - } + for _, req := range r.Requests { + content, err := geminiContentToEmbeddingContent(req.Content) + if err != nil { + return nil, fmt.Errorf("error converting embedding content: %w", err) } + bifrostReq.Items = append(bifrostReq.Items, schemas.BifrostEmbeddingBatchItem{ + Content: content, + Params: applyBifrostEmbeddingParams(nil, req), + }) + } - embeddingRequest := request.Requests[0] + return bifrostReq, nil +} - // Convert parameters - if embeddingRequest.OutputDimensionality != nil || embeddingRequest.TaskType != nil || embeddingRequest.Title != nil { - bifrostReq.Params = &schemas.EmbeddingParameters{} +// ToBifrostEmbeddingRequest converts a GeminiGenerationRequest to BifrostEmbeddingRequest format. +// Each request entry maps to one element in Contents (one output embedding). 
+func (request *GeminiGenerationRequest) ToBifrostEmbeddingRequest(ctx *schemas.BifrostContext) *schemas.BifrostEmbeddingRequest { + if request == nil { + return nil + } - if embeddingRequest.OutputDimensionality != nil { - bifrostReq.Params.Dimensions = embeddingRequest.OutputDimensionality - } + provider, model := schemas.ParseModelString(request.Model, providerUtils.CheckAndSetDefaultProvider(ctx, schemas.Gemini)) + bifrostReq := &schemas.BifrostEmbeddingRequest{ + Provider: provider, + Model: model, + Input: &schemas.EmbeddingInput{}, + Fallbacks: schemas.ParseFallbacks(request.Fallbacks), + } - // Handle extra parameters - if embeddingRequest.TaskType != nil || embeddingRequest.Title != nil { - bifrostReq.Params.ExtraParams = make(map[string]interface{}) - if embeddingRequest.TaskType != nil { - bifrostReq.Params.ExtraParams["taskType"] = embeddingRequest.TaskType - } - if embeddingRequest.Title != nil { - bifrostReq.Params.ExtraParams["title"] = embeddingRequest.Title - } + if len(request.Requests) > 0 { + contents := make([]schemas.EmbeddingContent, 0, len(request.Requests)) + for _, req := range request.Requests { + content, err := geminiContentToEmbeddingContent(req.Content) + if err != nil { + return nil } + contents = append(contents, content) } + bifrostReq.Input.Contents = contents + bifrostReq.Params = applyBifrostEmbeddingParams(bifrostReq.Params, request.Requests[0]) + return bifrostReq } - // Generation-style requests (e.g., non-Imagen :predict) carry text in contents[].parts[]. - // If no SDK requests[] were provided, derive embedding input from contents. 
- if bifrostReq.Input == nil { - var texts []string + if len(request.Contents) > 0 { + contents := make([]schemas.EmbeddingContent, 0, len(request.Contents)) for _, content := range request.Contents { - for _, part := range content.Parts { - if part != nil && part.Text != "" { - texts = append(texts, part.Text) - } - } - } - if len(texts) > 0 { - bifrostReq.Input = &schemas.EmbeddingInput{} - if len(texts) == 1 { - bifrostReq.Input.Text = &texts[0] - } else { - bifrostReq.Input.Texts = texts + converted, err := geminiContentToEmbeddingContent(&content) + if err != nil { + return nil } + contents = append(contents, converted) } + bifrostReq.Input.Contents = contents } return bifrostReq diff --git a/core/providers/gemini/embedding_multimodal_test.go b/core/providers/gemini/embedding_multimodal_test.go new file mode 100644 index 0000000000..4f1c228d56 --- /dev/null +++ b/core/providers/gemini/embedding_multimodal_test.go @@ -0,0 +1,51 @@ +package gemini + +import ( + "testing" + + "github.com/maximhq/bifrost/core/schemas" + "github.com/stretchr/testify/require" +) + +func TestToGeminiEmbeddingRequestBatchContentUsesBatchRequest(t *testing.T) { + one := "one" + two := "two" + req, err := ToGeminiEmbeddingRequest(&schemas.BifrostEmbeddingRequest{ + Model: "gemini-embedding-001", + Input: &schemas.EmbeddingInput{ + Contents: []schemas.EmbeddingContent{ + {{Type: schemas.EmbeddingContentPartTypeText, Text: &one}}, + {{Type: schemas.EmbeddingContentPartTypeText, Text: &two}}, + }, + }, + }) + require.NoError(t, err) + + require.Len(t, req.Requests, 2) + require.Equal(t, "one", req.Requests[0].Content.Parts[0].Text) + require.Equal(t, "two", req.Requests[1].Content.Parts[0].Text) +} + +func TestGeminiGenerationRequestToBifrostEmbeddingRequestPreservesMultimodalContent(t *testing.T) { + request := &GeminiGenerationRequest{ + Model: "gemini/gemini-embedding-001", + Requests: []GeminiEmbeddingRequest{ + { + Content: &Content{ + Parts: []*Part{ + {Text: "hello"}, + {FileData: 
&FileData{FileURI: "https://example.com/img.png", MIMEType: "image/png"}}, + }, + }, + }, + }, + } + + bifrostReq := request.ToBifrostEmbeddingRequest(schemas.NewBifrostContext(nil, schemas.NoDeadline)) + require.NotNil(t, bifrostReq) + require.NotNil(t, bifrostReq.Input) + require.Len(t, bifrostReq.Input.Contents, 1) + require.Len(t, bifrostReq.Input.Contents[0], 2) + require.Equal(t, schemas.EmbeddingContentPartTypeText, bifrostReq.Input.Contents[0][0].Type) + require.Equal(t, schemas.EmbeddingContentPartTypeImage, bifrostReq.Input.Contents[0][1].Type) +} diff --git a/core/providers/gemini/gemini.go b/core/providers/gemini/gemini.go index d0ff23711e..98b3582d8a 100644 --- a/core/providers/gemini/gemini.go +++ b/core/providers/gemini/gemini.go @@ -1170,12 +1170,12 @@ func (provider *GeminiProvider) Embedding(ctx *schemas.BifrostContext, key schem providerName := provider.GetProviderKey() - // Convert Bifrost request to Gemini batch embedding request format + // Convert Bifrost request to Gemini embedding request format jsonData, err := providerUtils.CheckContextAndGetRequestBody( ctx, request, func() (providerUtils.RequestBodyWithExtraParams, error) { - return ToGeminiEmbeddingRequest(request), nil + return ToGeminiEmbeddingRequest(request) }) if err != nil { return nil, err @@ -1190,8 +1190,8 @@ func (provider *GeminiProvider) Embedding(ctx *schemas.BifrostContext, key schem // Set any extra headers from network config providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, nil) - // Use Gemini's batchEmbedContents endpoint - req.SetRequestURI(provider.networkConfig.BaseURL + providerUtils.GetPathFromContext(ctx, "/models/"+request.Model+":batchEmbedContents")) + endpoint := "/models/" + request.Model + ":batchEmbedContents" + req.SetRequestURI(provider.networkConfig.BaseURL + providerUtils.GetPathFromContext(ctx, endpoint)) req.Header.SetMethod(http.MethodPost) req.Header.SetContentType("application/json") if key.Value.GetValue() != "" { @@ 
-1254,7 +1254,7 @@ func (provider *GeminiProvider) Embedding(ctx *schemas.BifrostContext, key schem wait() fasthttp.ReleaseResponse(resp) - // Parse Gemini's batch embedding response + // Parse Gemini embedding response var geminiResponse GeminiEmbeddingResponse rawRequest, rawResponse, bifrostErr := providerUtils.HandleProviderResponse(body, &geminiResponse, jsonData, providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), @@ -1285,6 +1285,113 @@ func (provider *GeminiProvider) Embedding(ctx *schemas.BifrostContext, key schem return bifrostResponse, nil } +// BatchEmbedding performs a batch embedding request to the Gemini API using batchEmbedContents. +// Each item can carry its own parameter overrides (taskType, title, outputDimensionality, etc.). +func (provider *GeminiProvider) BatchEmbedding(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostBatchEmbeddingRequest) (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) { + if err := providerUtils.CheckOperationAllowed(schemas.Gemini, provider.customProviderConfig, schemas.BatchEmbeddingRequest); err != nil { + return nil, err + } + + providerName := provider.GetProviderKey() + + jsonData, err := providerUtils.CheckContextAndGetRequestBody( + ctx, + request, + func() (providerUtils.RequestBodyWithExtraParams, error) { + return ToGeminiBatchEmbeddingRequest(request) + }) + if err != nil { + return nil, err + } + + req := fasthttp.AcquireRequest() + resp := fasthttp.AcquireResponse() + defer fasthttp.ReleaseRequest(req) + + providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, nil) + + endpoint := "/models/" + request.Model + ":batchEmbedContents" + req.SetRequestURI(provider.networkConfig.BaseURL + providerUtils.GetPathFromContext(ctx, endpoint)) + req.Header.SetMethod(http.MethodPost) + req.Header.SetContentType("application/json") + if key.Value.GetValue() != "" { + req.Header.Set("x-goog-api-key", key.Value.GetValue()) + } + + usedLargePayloadBody 
:= providerUtils.ApplyLargePayloadRequestBody(ctx, req) + if !usedLargePayloadBody { + req.SetBody(jsonData) + } + + activeClient := providerUtils.PrepareResponseStreaming(ctx, provider.client, resp) + latency, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, activeClient, req, resp) + if bifrostErr != nil { + wait() + fasthttp.ReleaseResponse(resp) + return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) + } + if usedLargePayloadBody { + providerUtils.DrainLargePayloadRemainder(ctx) + } + + providerResponseHeaders := providerUtils.ExtractProviderResponseHeaders(resp) + ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders) + + if resp.StatusCode() != fasthttp.StatusOK { + providerUtils.MaterializeStreamErrorBody(ctx, resp) + provider.logger.Debug(fmt.Sprintf("error from %s provider: %s", providerName, string(resp.Body()))) + parsedErr := providerUtils.EnrichError(ctx, parseGeminiError(resp), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) + wait() + fasthttp.ReleaseResponse(resp) + return nil, parsedErr + } + + body, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, provider.logger) + if decodeErr != nil { + wait() + fasthttp.ReleaseResponse(resp) + return nil, providerUtils.EnrichError(ctx, decodeErr, jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) + } + if isLargeResp { + wait() + return &schemas.BifrostEmbeddingResponse{ + Model: request.Model, + ExtraFields: schemas.BifrostResponseExtraFields{ + Latency: latency.Milliseconds(), + ProviderResponseHeaders: providerResponseHeaders, + }, + }, nil + } + wait() + fasthttp.ReleaseResponse(resp) + + var geminiResponse GeminiEmbeddingResponse + rawRequest, rawResponse, bifrostErr := providerUtils.HandleProviderResponse(body, &geminiResponse, jsonData, + providerUtils.ShouldSendBackRawRequest(ctx, 
provider.sendBackRawRequest), + providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)) + if bifrostErr != nil { + return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonData, body, provider.sendBackRawRequest, provider.sendBackRawResponse) + } + + bifrostResponse := ToBifrostEmbeddingResponse(&geminiResponse, request.Model) + if bifrostResponse == nil { + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, + fmt.Errorf("failed to convert Gemini batch embedding response to Bifrost format")) + } + + bifrostResponse.ExtraFields.Latency = latency.Milliseconds() + bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders + + if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { + bifrostResponse.ExtraFields.RawRequest = rawRequest + } + if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) { + bifrostResponse.ExtraFields.RawResponse = rawResponse + } + + return bifrostResponse, nil +} + // Speech performs a speech synthesis request to the Gemini API. 
func (provider *GeminiProvider) Speech(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostSpeechRequest) (*schemas.BifrostSpeechResponse, *schemas.BifrostError) { // Check if speech is allowed for this provider diff --git a/core/providers/gemini/gemini_test.go b/core/providers/gemini/gemini_test.go index e1fb192f66..6683ab4f84 100644 --- a/core/providers/gemini/gemini_test.go +++ b/core/providers/gemini/gemini_test.go @@ -36,7 +36,8 @@ func TestGemini(t *testing.T) { {Provider: schemas.Gemini, Model: "gemini-2.5-flash"}, }, VisionModel: "gemini-2.5-flash", - EmbeddingModel: "gemini-embedding-001", + EmbeddingModel: "gemini-embedding-001", + MultimodalEmbeddingModel: "gemini-embedding-001", TranscriptionModel: "gemini-2.5-flash", SpeechSynthesisModel: "gemini-2.5-flash-preview-tts", ImageGenerationModel: "gemini-2.5-flash-image", @@ -72,6 +73,7 @@ func TestGemini(t *testing.T) { FileURL: false, // supported files via gemini files api CompleteEnd2End: true, Embedding: true, + MultimodalEmbedding: true, Transcription: false, TranscriptionStream: false, SpeechSynthesis: true, @@ -214,7 +216,7 @@ func TestToBifrostEmbeddingResponsePreservesPrecision(t *testing.T) { require.NotNil(t, resp) - got := resp.Data[0].Embedding.EmbeddingArray[0] + got := resp.Data[0].Embedding.Float[0] assert.Equal(t, want, got) assert.NotEqual(t, float64(float32(want)), got) } diff --git a/core/providers/gemini/types.go b/core/providers/gemini/types.go index a30eadcaf5..c1bec70f7b 100644 --- a/core/providers/gemini/types.go +++ b/core/providers/gemini/types.go @@ -1150,6 +1150,7 @@ func (tc *GenerationConfigThinkingConfig) UnmarshalJSON(data []byte) error { } type GeminiBatchEmbeddingRequest struct { + Model string `json:"-"` // populated from URL path by Bifrost; not part of wire format Requests []GeminiEmbeddingRequest `json:"requests,omitempty"` ExtraParams map[string]interface{} `json:"-"` // Optional: Extra parameters } @@ -1162,6 +1163,8 @@ func (r 
*GeminiBatchEmbeddingRequest) GetExtraParams() map[string]interface{} { // GeminiEmbeddingRequest represents a single embedding request in a batch. type GeminiEmbeddingRequest struct { Content *Content `json:"content,omitempty"` + DocumentOCR *bool `json:"documentOcr,omitempty"` + AudioTrackExtraction *bool `json:"audioTrackExtraction,omitempty"` TaskType *string `json:"taskType,omitempty"` Title *string `json:"title,omitempty"` OutputDimensionality *int `json:"outputDimensionality,omitempty"` @@ -1262,7 +1265,9 @@ func (p *Part) UnmarshalJSON(data []byte) error { VideoMetadata *VideoMetadata `json:"videoMetadata,omitempty"` Thought bool `json:"thought,omitempty"` InlineData *Blob `json:"inlineData,omitempty"` + InlineDataSnake *Blob `json:"inline_data,omitempty"` // Python SDK uses snake_case FileData *FileData `json:"fileData,omitempty"` + FileDataSnake *FileData `json:"file_data,omitempty"` // Python SDK uses snake_case ThoughtSignature string `json:"thoughtSignature,omitempty"` CodeExecutionResult *CodeExecutionResult `json:"codeExecutionResult,omitempty"` ExecutableCode *ExecutableCode `json:"executableCode,omitempty"` @@ -1279,7 +1284,13 @@ func (p *Part) UnmarshalJSON(data []byte) error { p.VideoMetadata = aux.VideoMetadata p.Thought = aux.Thought p.InlineData = aux.InlineData + if p.InlineData == nil { + p.InlineData = aux.InlineDataSnake + } p.FileData = aux.FileData + if p.FileData == nil { + p.FileData = aux.FileDataSnake + } p.CodeExecutionResult = aux.CodeExecutionResult p.ExecutableCode = aux.ExecutableCode p.FunctionCall = aux.FunctionCall @@ -1324,9 +1335,11 @@ type Blob struct { // UnmarshalJSON custom unmarshaler for Blob to handle URL-safe base64 func (b *Blob) UnmarshalJSON(data []byte) error { type BlobAlias struct { - DisplayName string `json:"displayName,omitempty"` - Data string `json:"data,omitempty"` - MIMEType string `json:"mimeType,omitempty"` + DisplayName string `json:"displayName,omitempty"` + DisplayNameSnake string 
`json:"display_name,omitempty"` // Python SDK uses snake_case + Data string `json:"data,omitempty"` + MIMEType string `json:"mimeType,omitempty"` + MIMETypeSnake string `json:"mime_type,omitempty"` // Python SDK uses snake_case } var aux BlobAlias @@ -1335,7 +1348,13 @@ func (b *Blob) UnmarshalJSON(data []byte) error { } b.DisplayName = aux.DisplayName + if b.DisplayName == "" { + b.DisplayName = aux.DisplayNameSnake + } b.MIMEType = aux.MIMEType + if b.MIMEType == "" { + b.MIMEType = aux.MIMETypeSnake + } if aux.Data != "" { // Convert URL-safe base64 to standard base64 @@ -1455,7 +1474,8 @@ type FunctionResponse struct { // GeminiEmbeddingResponse represents a Google GenAI embedding response. type GeminiEmbeddingResponse struct { - Embeddings []GeminiEmbedding `json:"embeddings"` + Embedding *GeminiEmbedding `json:"embedding,omitempty"` + Embeddings []GeminiEmbedding `json:"embeddings,omitempty"` Metadata *EmbedContentMetadata `json:"metadata,omitempty"` } diff --git a/core/providers/groq/groq.go b/core/providers/groq/groq.go index 9667b989ff..c467e65084 100644 --- a/core/providers/groq/groq.go +++ b/core/providers/groq/groq.go @@ -175,6 +175,10 @@ func (provider *GroqProvider) Embedding(ctx *schemas.BifrostContext, key schemas return nil, providerUtils.NewUnsupportedOperationError(schemas.EmbeddingRequest, provider.GetProviderKey()) } +func (provider *GroqProvider) BatchEmbedding(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostBatchEmbeddingRequest) (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) { + return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchEmbeddingRequest, provider.GetProviderKey()) +} + // Speech handles non-streaming speech synthesis requests. // It formats the request body, makes the API call, and returns the response. // Returns the response and any error that occurred. 
diff --git a/core/providers/huggingface/embedding.go b/core/providers/huggingface/embedding.go index 5a6ce5d1c8..1f9a707954 100644 --- a/core/providers/huggingface/embedding.go +++ b/core/providers/huggingface/embedding.go @@ -2,6 +2,7 @@ package huggingface import ( "fmt" + "strings" "github.com/bytedance/sonic" "github.com/maximhq/bifrost/core/schemas" @@ -28,15 +29,30 @@ func ToHuggingFaceEmbeddingRequest(bifrostReq *schemas.BifrostEmbeddingRequest) hfReq = &HuggingFaceEmbeddingRequest{} } - // Convert input - if bifrostReq.Input != nil { + if bifrostReq.Input != nil && len(bifrostReq.Input.Contents) > 0 { + contents := bifrostReq.Input.Contents var input InputsCustomType - if bifrostReq.Input.Text != nil { - input = InputsCustomType{Text: bifrostReq.Input.Text} - } else if bifrostReq.Input.Texts != nil { - input = InputsCustomType{Texts: bifrostReq.Input.Texts} + if len(contents) == 1 { + // Single content: extract text from the single entry + text, err := extractTextFromContent(contents[0]) + if err != nil { + return nil, err + } + input = InputsCustomType{Text: &text} + } else { + // Batch: extract text from each content entry + texts := make([]string, 0, len(contents)) + for _, content := range contents { + text, err := extractTextFromContent(content) + if err != nil { + return nil, err + } + texts = append(texts, text) + } + input = InputsCustomType{Texts: texts} } + if inferenceProvider == hfInference { hfReq.Inputs = &input } else { @@ -44,11 +60,9 @@ func ToHuggingFaceEmbeddingRequest(bifrostReq *schemas.BifrostEmbeddingRequest) } } - // Map parameters if bifrostReq.Params != nil { params := bifrostReq.Params - // Map standard parameters if params.EncodingFormat != nil { encodingType := EncodingType(*params.EncodingFormat) hfReq.EncodingFormat = &encodingType @@ -57,7 +71,6 @@ func ToHuggingFaceEmbeddingRequest(bifrostReq *schemas.BifrostEmbeddingRequest) hfReq.Dimensions = params.Dimensions } - // Check for HuggingFace-specific parameters in ExtraParams 
if params.ExtraParams != nil { if normalize, ok := params.ExtraParams["normalize"].(bool); ok { delete(params.ExtraParams, "normalize") @@ -82,6 +95,25 @@ func ToHuggingFaceEmbeddingRequest(bifrostReq *schemas.BifrostEmbeddingRequest) return hfReq, nil } +// extractTextFromContent extracts a single text string from a content entry. +// All parts must be text-only; multiple text parts are stitched together. +func extractTextFromContent(content schemas.EmbeddingContent) (string, error) { + var sb strings.Builder + for _, part := range content { + if part.Type != schemas.EmbeddingContentPartTypeText || part.Text == nil { + return "", fmt.Errorf("huggingface embedding only supports text input") + } + if sb.Len() > 0 { + sb.WriteString(" \n") + } + sb.WriteString(*part.Text) + } + if sb.Len() == 0 { + return "", fmt.Errorf("huggingface embedding content has no text") + } + return sb.String(), nil +} + // UnmarshalHuggingFaceEmbeddingResponse unmarshals HuggingFace API response directly into BifrostEmbeddingResponse // Handles multiple formats: standard object, 2D array, or 1D array func UnmarshalHuggingFaceEmbeddingResponse(data []byte, model string) (*schemas.BifrostEmbeddingResponse, error) { @@ -109,11 +141,7 @@ func UnmarshalHuggingFaceEmbeddingResponse(data []byte, model string) (*schemas. if obj.Usage != nil { bifrostResponse.Usage = obj.Usage } else { - bifrostResponse.Usage = &schemas.BifrostLLMUsage{ - PromptTokens: 0, - CompletionTokens: 0, - TotalTokens: 0, - } + bifrostResponse.Usage = &schemas.BifrostLLMUsage{} } return bifrostResponse, nil } @@ -125,7 +153,7 @@ func UnmarshalHuggingFaceEmbeddingResponse(data []byte, model string) (*schemas. 
embeddings := make([]schemas.EmbeddingData, len(arr2D)) for idx, embedding := range arr2D { embeddings[idx] = schemas.EmbeddingData{ - Embedding: schemas.EmbeddingStruct{EmbeddingArray: append([]float64(nil), embedding...)}, + Embedding: schemas.EmbeddingsByType{Float: append([]float64(nil), embedding...)}, Index: idx, Object: "embedding", } @@ -134,11 +162,7 @@ func UnmarshalHuggingFaceEmbeddingResponse(data []byte, model string) (*schemas. Data: embeddings, Model: model, Object: "list", - Usage: &schemas.BifrostLLMUsage{ - PromptTokens: 0, - CompletionTokens: 0, - TotalTokens: 0, - }, + Usage: &schemas.BifrostLLMUsage{}, }, nil } @@ -146,18 +170,14 @@ func UnmarshalHuggingFaceEmbeddingResponse(data []byte, model string) (*schemas. var arr1D []float64 if err := sonic.Unmarshal(data, &arr1D); err == nil { return &schemas.BifrostEmbeddingResponse{ - Data: []schemas.EmbeddingData{{ - Embedding: schemas.EmbeddingStruct{EmbeddingArray: append([]float64(nil), arr1D...)}, - Index: 0, - Object: "embedding", - }}, + Data: []schemas.EmbeddingData{{ + Embedding: schemas.EmbeddingsByType{Float: append([]float64(nil), arr1D...)}, + Index: 0, + Object: "embedding", + }}, Model: model, Object: "list", - Usage: &schemas.BifrostLLMUsage{ - PromptTokens: 0, - CompletionTokens: 0, - TotalTokens: 0, - }, + Usage: &schemas.BifrostLLMUsage{}, }, nil } diff --git a/core/providers/huggingface/huggingface.go b/core/providers/huggingface/huggingface.go index 38ddd84e9f..4fd720206c 100644 --- a/core/providers/huggingface/huggingface.go +++ b/core/providers/huggingface/huggingface.go @@ -703,6 +703,10 @@ func (provider *HuggingFaceProvider) Embedding(ctx *schemas.BifrostContext, key return bifrostResponse, nil } +func (provider *HuggingFaceProvider) BatchEmbedding(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostBatchEmbeddingRequest) (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) { + return nil, 
providerUtils.NewUnsupportedOperationError(schemas.BatchEmbeddingRequest, provider.GetProviderKey()) +} + func (provider *HuggingFaceProvider) Speech(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostSpeechRequest) (*schemas.BifrostSpeechResponse, *schemas.BifrostError) { // Check if Speech is allowed for this provider if err := providerUtils.CheckOperationAllowed(schemas.HuggingFace, provider.customProviderConfig, schemas.SpeechRequest); err != nil { diff --git a/core/providers/huggingface/huggingface_test.go b/core/providers/huggingface/huggingface_test.go index 4c98cec1f0..031ef52e92 100644 --- a/core/providers/huggingface/huggingface_test.go +++ b/core/providers/huggingface/huggingface_test.go @@ -92,11 +92,11 @@ func TestUnmarshalHuggingFaceEmbeddingResponsePreservesPrecision(t *testing.T) { if resp == nil || len(resp.Data) != 1 { t.Fatalf("expected single embedding response, got %#v", resp) } - if len(resp.Data[0].Embedding.EmbeddingArray) != 1 { - t.Fatalf("expected single embedding value, got %#v", resp.Data[0].Embedding.EmbeddingArray) + if len(resp.Data[0].Embedding.Float) != 1 { + t.Fatalf("expected single embedding value, got %#v", resp.Data[0].Embedding.Float) } - got := resp.Data[0].Embedding.EmbeddingArray[0] + got := resp.Data[0].Embedding.Float[0] if got != want { t.Fatalf("expected %0.18f, got %0.18f", want, got) } @@ -117,11 +117,11 @@ func TestUnmarshalHuggingFaceEmbeddingResponse1DPreservesPrecision(t *testing.T) if resp == nil || len(resp.Data) != 1 { t.Fatalf("expected single embedding response, got %#v", resp) } - if len(resp.Data[0].Embedding.EmbeddingArray) != 1 { - t.Fatalf("expected single embedding value, got %#v", resp.Data[0].Embedding.EmbeddingArray) + if len(resp.Data[0].Embedding.Float) != 1 { + t.Fatalf("expected single embedding value, got %#v", resp.Data[0].Embedding.Float) } - got := resp.Data[0].Embedding.EmbeddingArray[0] + got := resp.Data[0].Embedding.Float[0] if got != want { t.Fatalf("expected %0.18f, 
got %0.18f", want, got) } diff --git a/core/providers/mistral/custom_provider_test.go b/core/providers/mistral/custom_provider_test.go index cd7278f721..df720abc7c 100644 --- a/core/providers/mistral/custom_provider_test.go +++ b/core/providers/mistral/custom_provider_test.go @@ -148,7 +148,9 @@ func TestMistralProvider_CustomAliasEmbeddingReportsAliasMetadata(t *testing.T) Provider: customMistralProviderName, Model: "codestral-embed", Input: &schemas.EmbeddingInput{ - Texts: []string{"hello"}, + Contents: []schemas.EmbeddingContent{ + {{Type: schemas.EmbeddingContentPartTypeText, Text: schemas.Ptr("hello")}}, + }, }, } diff --git a/core/providers/mistral/mistral.go b/core/providers/mistral/mistral.go index b833f8b7cc..a85e9a9236 100644 --- a/core/providers/mistral/mistral.go +++ b/core/providers/mistral/mistral.go @@ -265,6 +265,10 @@ func (provider *MistralProvider) Embedding(ctx *schemas.BifrostContext, key sche ) } +func (provider *MistralProvider) BatchEmbedding(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostBatchEmbeddingRequest) (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) { + return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchEmbeddingRequest, provider.GetProviderKey()) +} + // Speech is not supported by the Mistral provider. 
func (provider *MistralProvider) Speech(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostSpeechRequest) (*schemas.BifrostSpeechResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.SpeechRequest, provider.GetProviderKey()) diff --git a/core/providers/nebius/nebius.go b/core/providers/nebius/nebius.go index 13e2cb4e33..41c8c0e3a9 100644 --- a/core/providers/nebius/nebius.go +++ b/core/providers/nebius/nebius.go @@ -233,6 +233,10 @@ func (provider *NebiusProvider) Embedding(ctx *schemas.BifrostContext, key schem provider.logger) } +func (provider *NebiusProvider) BatchEmbedding(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostBatchEmbeddingRequest) (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) { + return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchEmbeddingRequest, provider.GetProviderKey()) +} + // Speech is not supported by the Nebius provider. func (provider *NebiusProvider) Speech(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostSpeechRequest) (*schemas.BifrostSpeechResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.SpeechRequest, provider.GetProviderKey()) diff --git a/core/providers/ollama/ollama.go b/core/providers/ollama/ollama.go index 1bc620e947..e3db194e2f 100644 --- a/core/providers/ollama/ollama.go +++ b/core/providers/ollama/ollama.go @@ -270,6 +270,10 @@ func (provider *OllamaProvider) Embedding(ctx *schemas.BifrostContext, key schem ) } +func (provider *OllamaProvider) BatchEmbedding(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostBatchEmbeddingRequest) (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) { + return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchEmbeddingRequest, provider.GetProviderKey()) +} + // Speech is not supported by the Ollama provider. 
func (provider *OllamaProvider) Speech(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostSpeechRequest) (*schemas.BifrostSpeechResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.SpeechRequest, provider.GetProviderKey()) diff --git a/core/providers/openai/embedding.go b/core/providers/openai/embedding.go index fa243ac5b8..bcab79c979 100644 --- a/core/providers/openai/embedding.go +++ b/core/providers/openai/embedding.go @@ -1,40 +1,188 @@ package openai import ( + "fmt" + "strings" + "github.com/maximhq/bifrost/core/providers/utils" "github.com/maximhq/bifrost/core/schemas" ) -// ToBifrostEmbeddingRequest converts an OpenAI embedding request to Bifrost format +// ToBifrostEmbeddingResponse converts an OpenAI embedding response to Bifrost format. +func (r *OpenaiEmbeddingResponse) ToBifrostEmbeddingResponse() *schemas.BifrostEmbeddingResponse { + data := make([]schemas.EmbeddingData, len(r.Data)) + for i, d := range r.Data { + var embeddingsByType schemas.EmbeddingsByType + switch { + case d.Embedding.EmbeddingStr != nil: + embeddingsByType.Base64 = d.Embedding.EmbeddingStr + case d.Embedding.EmbeddingArray != nil: + embeddingsByType.Float = d.Embedding.EmbeddingArray + case d.Embedding.Embedding2DArray != nil: + // Flatten 2D array into a single float slice (OpenAI does not return 2D embeddings in practice) + var flat []float64 + for _, inner := range d.Embedding.Embedding2DArray { + flat = append(flat, inner...) 
+ } + embeddingsByType.Float = flat + } + data[i] = schemas.EmbeddingData{ + Index: d.Index, + Object: d.Object, + Embedding: embeddingsByType, + } + } + return &schemas.BifrostEmbeddingResponse{ + Data: data, + Model: r.Model, + Object: r.Object, + Usage: r.Usage, + } +} + +// ToOpenAIEmbeddingResponse converts a Bifrost embedding response to OpenAI +func ToOpenAIEmbeddingResponse(resp *schemas.BifrostEmbeddingResponse) *OpenaiEmbeddingResponse { + if resp == nil { + return nil + } + data := make([]EmbeddingData, len(resp.Data)) + for i, d := range resp.Data { + var embStruct EmbeddingStruct + switch { + case d.Embedding.Base64 != nil: + embStruct.EmbeddingStr = d.Embedding.Base64 + case d.Embedding.Float != nil: + embStruct.EmbeddingArray = d.Embedding.Float + } + data[i] = EmbeddingData{ + Index: d.Index, + Object: d.Object, + Embedding: embStruct, + } + } + return &OpenaiEmbeddingResponse{ + Data: data, + Model: resp.Model, + Object: resp.Object, + Usage: resp.Usage, + } +} + +// ToBifrostEmbeddingRequest converts an OpenAI embedding request to Bifrost format. 
func (request *OpenAIEmbeddingRequest) ToBifrostEmbeddingRequest(ctx *schemas.BifrostContext) *schemas.BifrostEmbeddingRequest { provider, model := schemas.ParseModelString(request.Model, utils.CheckAndSetDefaultProvider(ctx, schemas.OpenAI)) + var embeddingInput *schemas.EmbeddingInput + if request.Input != nil { + switch { + case request.Input.Text != nil: + t := *request.Input.Text + embeddingInput = &schemas.EmbeddingInput{ + Contents: []schemas.EmbeddingContent{ + {{Type: schemas.EmbeddingContentPartTypeText, Text: &t}}, + }, + } + case request.Input.Texts != nil: + contents := make([]schemas.EmbeddingContent, len(request.Input.Texts)) + for i, text := range request.Input.Texts { + t := text + contents[i] = schemas.EmbeddingContent{ + {Type: schemas.EmbeddingContentPartTypeText, Text: &t}, + } + } + embeddingInput = &schemas.EmbeddingInput{Contents: contents} + case request.Input.Embedding != nil: + tokens := request.Input.Embedding + embeddingInput = &schemas.EmbeddingInput{ + Contents: []schemas.EmbeddingContent{ + {{Type: schemas.EmbeddingContentPartTypeTokens, Tokens: tokens}}, + }, + } + case request.Input.Embeddings != nil: + contents := make([]schemas.EmbeddingContent, len(request.Input.Embeddings)) + for i, tokens := range request.Input.Embeddings { + t := tokens + contents[i] = schemas.EmbeddingContent{ + {Type: schemas.EmbeddingContentPartTypeTokens, Tokens: t}, + } + } + embeddingInput = &schemas.EmbeddingInput{Contents: contents} + } + } + return &schemas.BifrostEmbeddingRequest{ Provider: provider, Model: model, - Input: request.Input, + Input: embeddingInput, Params: &request.EmbeddingParameters, Fallbacks: schemas.ParseFallbacks(request.Fallbacks), } } -// ToOpenAIEmbeddingRequest converts a Bifrost embedding request to OpenAI format -func ToOpenAIEmbeddingRequest(bifrostReq *schemas.BifrostEmbeddingRequest) *OpenAIEmbeddingRequest { +// ToOpenAIEmbeddingRequest converts a Bifrost embedding request to OpenAI format. 
+func ToOpenAIEmbeddingRequest(bifrostReq *schemas.BifrostEmbeddingRequest) (*OpenAIEmbeddingRequest, error) { if bifrostReq == nil { - return nil + return nil, nil } - params := bifrostReq.Params + var input *OpenAIEmbeddingInput + if bifrostReq.Input != nil { + var texts []string + var tokenBatches [][]int + for _, content := range bifrostReq.Input.Contents { + var sb strings.Builder + var tokens []int + for _, part := range content { + switch part.Type { + case schemas.EmbeddingContentPartTypeText: + if part.Text != nil { + if sb.Len() > 0 { + sb.WriteString(" \n") + } + sb.WriteString(*part.Text) + } + case schemas.EmbeddingContentPartTypeTokens: + if part.Tokens != nil { + tokens = append(tokens, part.Tokens...) + } + default: + return nil, fmt.Errorf("openai embedding does not support %q input", part.Type) + } + } + if sb.Len() > 0 && len(tokens) > 0 { + return nil, fmt.Errorf("openai embedding does not support mixing text and token inputs within a single content entry") + } + if sb.Len() > 0 { + texts = append(texts, sb.String()) + } else if len(tokens) > 0 { + tokenBatches = append(tokenBatches, tokens) + } + } + + if len(texts) > 0 && len(tokenBatches) > 0 { + return nil, fmt.Errorf("openai embedding does not support mixing text and token inputs in the same request") + } + switch { + case len(texts) == 1: + input = &OpenAIEmbeddingInput{Text: &texts[0]} + case len(texts) > 1: + input = &OpenAIEmbeddingInput{Texts: texts} + case len(tokenBatches) == 1: + input = &OpenAIEmbeddingInput{Embedding: tokenBatches[0]} + case len(tokenBatches) > 1: + input = &OpenAIEmbeddingInput{Embeddings: tokenBatches} + } + } openaiReq := &OpenAIEmbeddingRequest{ Model: bifrostReq.Model, - Input: bifrostReq.Input, + Input: input, } - // Map parameters - if params != nil { - openaiReq.EmbeddingParameters = *params - openaiReq.ExtraParams = params.ExtraParams + if bifrostReq.Params != nil { + openaiReq.EmbeddingParameters = *bifrostReq.Params + openaiReq.ExtraParams = 
bifrostReq.Params.ExtraParams } - return openaiReq + + return openaiReq, nil } diff --git a/core/providers/openai/openai.go b/core/providers/openai/openai.go index b1dbd4b090..3c82dc4149 100644 --- a/core/providers/openai/openai.go +++ b/core/providers/openai/openai.go @@ -1840,6 +1840,10 @@ func (provider *OpenAIProvider) Embedding(ctx *schemas.BifrostContext, key schem ) } +func (provider *OpenAIProvider) BatchEmbedding(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostBatchEmbeddingRequest) (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) { + return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchEmbeddingRequest, provider.GetProviderKey()) +} + // HandleOpenAIEmbeddingRequest handles embedding requests for OpenAI-compatible APIs. // This shared function reduces code duplication between providers that use the same embedding request format. func HandleOpenAIEmbeddingRequest( @@ -1852,7 +1856,7 @@ func HandleOpenAIEmbeddingRequest( providerName schemas.ModelProvider, sendBackRawRequest bool, sendBackRawResponse bool, - customResponseHandler responseHandler[schemas.BifrostEmbeddingResponse], + customResponseHandler responseHandler[OpenaiEmbeddingResponse], logger schemas.Logger, ) (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) { // Create request @@ -1885,10 +1889,11 @@ func HandleOpenAIEmbeddingRequest( return nil, lpErr } if len(lpResult.ResponseBody) > 0 { - response := &schemas.BifrostEmbeddingResponse{} - if err := sonic.Unmarshal(lpResult.ResponseBody, response); err != nil { + openaiResp := &OpenaiEmbeddingResponse{} + if err := sonic.Unmarshal(lpResult.ResponseBody, openaiResp); err != nil { return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err) } + response := openaiResp.ToBifrostEmbeddingResponse() response.ExtraFields = schemas.BifrostResponseExtraFields{Latency: lpResult.Latency} return response, nil } @@ -1904,7 +1909,7 @@ func HandleOpenAIEmbeddingRequest( 
ctx, request, func() (providerUtils.RequestBodyWithExtraParams, error) { - return ToOpenAIEmbeddingRequest(request), nil + return ToOpenAIEmbeddingRequest(request) }) if bifrostErr != nil { return nil, bifrostErr @@ -1942,20 +1947,24 @@ func HandleOpenAIEmbeddingRequest( }, nil } - response := &schemas.BifrostEmbeddingResponse{} + openaiResp := &OpenaiEmbeddingResponse{} var rawRequest, rawResponse interface{} if customResponseHandler != nil { - rawRequest, rawResponse, bifrostErr = customResponseHandler(body, response, jsonData, sendBackRawRequest, sendBackRawResponse) + rawRequest, rawResponse, bifrostErr = customResponseHandler(body, openaiResp, jsonData, sendBackRawRequest, sendBackRawResponse) } else { - rawRequest, rawResponse, bifrostErr = providerUtils.HandleProviderResponse(body, response, jsonData, sendBackRawRequest, sendBackRawResponse) + rawRequest, rawResponse, bifrostErr = providerUtils.HandleProviderResponse(body, openaiResp, jsonData, sendBackRawRequest, sendBackRawResponse) } if bifrostErr != nil { return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonData, body, sendBackRawRequest, sendBackRawResponse) } + response := openaiResp.ToBifrostEmbeddingResponse() + response.ExtraFields.Provider = providerName + response.ExtraFields.ResolvedModelUsed = request.Model + response.ExtraFields.RequestType = schemas.EmbeddingRequest response.ExtraFields.Latency = latency.Milliseconds() response.ExtraFields.ProviderResponseHeaders = providerResponseHeaders diff --git a/core/providers/openai/types.go b/core/providers/openai/types.go index e2eab5245a..df3fa672fb 100644 --- a/core/providers/openai/types.go +++ b/core/providers/openai/types.go @@ -51,10 +51,88 @@ func (req *OpenAITextCompletionRequest) IsStreamingRequested() bool { return req.Stream != nil && *req.Stream } +type OpenAIEmbeddingInput struct { + Text *string + Texts []string + Embedding []int + Embeddings [][]int +} + +func (e *OpenAIEmbeddingInput) MarshalJSON() ([]byte, error) { + // enforce 
one-of + set := 0 + if e.Text != nil { + set++ + } + if e.Texts != nil { + set++ + } + if e.Embedding != nil { + set++ + } + if e.Embeddings != nil { + set++ + } + if set == 0 { + return nil, fmt.Errorf("embedding input is empty") + } + if set > 1 { + return nil, fmt.Errorf("embedding input must set exactly one of: text, texts, embedding, embeddings") + } + + if e.Text != nil { + return providerUtils.MarshalSorted(*e.Text) + } + if e.Texts != nil { + return providerUtils.MarshalSorted(e.Texts) + } + if e.Embedding != nil { + return providerUtils.MarshalSorted(e.Embedding) + } + if e.Embeddings != nil { + return providerUtils.MarshalSorted(e.Embeddings) + } + + return nil, fmt.Errorf("invalid embedding input") +} + +func (e *OpenAIEmbeddingInput) UnmarshalJSON(data []byte) error { + e.Text = nil + e.Texts = nil + e.Embedding = nil + e.Embeddings = nil + // Try string + var s string + if err := sonic.Unmarshal(data, &s); err == nil { + e.Text = &s + return nil + } + // Try []string + var ss []string + if err := sonic.Unmarshal(data, &ss); err == nil { + e.Texts = ss + return nil + } + // Try []int + var i []int + if err := sonic.Unmarshal(data, &i); err == nil { + e.Embedding = i + return nil + } + // Try [][]int + var i2 [][]int + if err := sonic.Unmarshal(data, &i2); err == nil { + e.Embeddings = i2 + return nil + } + + return fmt.Errorf("unsupported embedding input shape") +} + // OpenAIEmbeddingRequest represents an OpenAI embedding request type OpenAIEmbeddingRequest struct { - Model string `json:"model"` - Input *schemas.EmbeddingInput `json:"input"` // Can be string or []string + Model string `json:"model"` + Input *OpenAIEmbeddingInput `json:"input"` // Can be string or []string schemas.EmbeddingParameters @@ -1023,3 +1101,61 @@ func (r *OpenAIVideoRemixRequest) GetExtraParams() map[string]interface{} { // ErrVideoNotReady is an error that is returned when a video is not ready yet var ErrVideoNotReady = errors.New("video is not ready yet, use GET 
/v1/videos/{video_id} to check status") + +type OpenaiEmbeddingResponse struct { + Data []EmbeddingData `json:"data"` // Maps to "data" field in provider responses (e.g., OpenAI embedding format) + Model string `json:"model"` + Object string `json:"object"` // "list" + Usage *schemas.BifrostLLMUsage `json:"usage"` +} + +type EmbeddingData struct { + Index int `json:"index"` + Object string `json:"object"` // "embedding" + Embedding EmbeddingStruct `json:"embedding"` // can be string, []float64 or [][]float64 +} + +type EmbeddingStruct struct { + // Embedding responses preserve provider precision in normalized API output. + EmbeddingStr *string + EmbeddingArray []float64 + Embedding2DArray [][]float64 +} + +func (be EmbeddingStruct) MarshalJSON() ([]byte, error) { + if be.EmbeddingStr != nil { + return providerUtils.MarshalSorted(be.EmbeddingStr) + } + if be.EmbeddingArray != nil { + return providerUtils.MarshalSorted(be.EmbeddingArray) + } + if be.Embedding2DArray != nil { + return providerUtils.MarshalSorted(be.Embedding2DArray) + } + return nil, fmt.Errorf("no embedding found") +} + +func (be *EmbeddingStruct) UnmarshalJSON(data []byte) error { + // First, try to unmarshal as a direct string + var stringContent string + if err := sonic.Unmarshal(data, &stringContent); err == nil { + be.EmbeddingStr = &stringContent + return nil + } + + // Try to unmarshal as a direct array of float64 + var arrayContent []float64 + if err := sonic.Unmarshal(data, &arrayContent); err == nil { + be.EmbeddingArray = arrayContent + return nil + } + + // Try to unmarshal as a direct 2D array of float64 + var arrayContent2D [][]float64 + if err := sonic.Unmarshal(data, &arrayContent2D); err == nil { + be.Embedding2DArray = arrayContent2D + return nil + } + + return fmt.Errorf("embedding field is neither a string nor an array of float64 nor a 2D array of float64") +} diff --git a/core/providers/openrouter/openrouter.go b/core/providers/openrouter/openrouter.go index 
36e4ff0566..7239275af0 100644 --- a/core/providers/openrouter/openrouter.go +++ b/core/providers/openrouter/openrouter.go @@ -418,6 +418,10 @@ func (provider *OpenRouterProvider) Embedding(ctx *schemas.BifrostContext, key s ) } +func (provider *OpenRouterProvider) BatchEmbedding(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostBatchEmbeddingRequest) (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) { + return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchEmbeddingRequest, provider.GetProviderKey()) +} + // Speech is not supported by the OpenRouter provider. func (provider *OpenRouterProvider) Speech(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostSpeechRequest) (*schemas.BifrostSpeechResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.SpeechRequest, provider.GetProviderKey()) diff --git a/core/providers/parasail/parasail.go b/core/providers/parasail/parasail.go index ae4cb22ab7..386153630b 100644 --- a/core/providers/parasail/parasail.go +++ b/core/providers/parasail/parasail.go @@ -171,6 +171,10 @@ func (provider *ParasailProvider) Embedding(ctx *schemas.BifrostContext, key sch return nil, providerUtils.NewUnsupportedOperationError(schemas.EmbeddingRequest, provider.GetProviderKey()) } +func (provider *ParasailProvider) BatchEmbedding(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostBatchEmbeddingRequest) (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) { + return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchEmbeddingRequest, provider.GetProviderKey()) +} + // Speech is not supported by the Parasail provider. 
func (provider *ParasailProvider) Speech(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostSpeechRequest) (*schemas.BifrostSpeechResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.SpeechRequest, provider.GetProviderKey()) diff --git a/core/providers/perplexity/perplexity.go b/core/providers/perplexity/perplexity.go index addb6a5fb8..9ef645d520 100644 --- a/core/providers/perplexity/perplexity.go +++ b/core/providers/perplexity/perplexity.go @@ -245,6 +245,10 @@ func (provider *PerplexityProvider) Embedding(ctx *schemas.BifrostContext, key s return nil, providerUtils.NewUnsupportedOperationError(schemas.EmbeddingRequest, provider.GetProviderKey()) } +func (provider *PerplexityProvider) BatchEmbedding(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostBatchEmbeddingRequest) (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) { + return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchEmbeddingRequest, provider.GetProviderKey()) +} + // Speech is not supported by the Perplexity provider. 
func (provider *PerplexityProvider) Speech(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostSpeechRequest) (*schemas.BifrostSpeechResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.SpeechRequest, provider.GetProviderKey()) diff --git a/core/providers/replicate/replicate.go b/core/providers/replicate/replicate.go index 652f99c63b..6aee8b95b4 100644 --- a/core/providers/replicate/replicate.go +++ b/core/providers/replicate/replicate.go @@ -1687,6 +1687,10 @@ func (provider *ReplicateProvider) Embedding(ctx *schemas.BifrostContext, key sc return nil, providerUtils.NewUnsupportedOperationError(schemas.EmbeddingRequest, provider.GetProviderKey()) } +func (provider *ReplicateProvider) BatchEmbedding(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostBatchEmbeddingRequest) (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) { + return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchEmbeddingRequest, provider.GetProviderKey()) +} + // Speech is not supported by the replicate provider. 
func (provider *ReplicateProvider) Speech(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostSpeechRequest) (*schemas.BifrostSpeechResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.SpeechRequest, provider.GetProviderKey()) diff --git a/core/providers/runway/runway.go b/core/providers/runway/runway.go index e51e5e6355..fc71eba593 100644 --- a/core/providers/runway/runway.go +++ b/core/providers/runway/runway.go @@ -105,6 +105,10 @@ func (provider *RunwayProvider) Embedding(ctx *schemas.BifrostContext, key schem return nil, providerUtils.NewUnsupportedOperationError(schemas.EmbeddingRequest, provider.GetProviderKey()) } +func (provider *RunwayProvider) BatchEmbedding(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostBatchEmbeddingRequest) (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) { + return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchEmbeddingRequest, provider.GetProviderKey()) +} + // Speech is not supported by the Runway provider. func (provider *RunwayProvider) Speech(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostSpeechRequest) (*schemas.BifrostSpeechResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.SpeechRequest, provider.GetProviderKey()) diff --git a/core/providers/sgl/sgl.go b/core/providers/sgl/sgl.go index f47c7b34e4..fbd9f241db 100644 --- a/core/providers/sgl/sgl.go +++ b/core/providers/sgl/sgl.go @@ -275,6 +275,10 @@ func (provider *SGLProvider) Embedding(ctx *schemas.BifrostContext, key schemas. ) } +func (provider *SGLProvider) BatchEmbedding(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostBatchEmbeddingRequest) (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) { + return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchEmbeddingRequest, provider.GetProviderKey()) +} + // Speech is not supported by the SGL provider. 
func (provider *SGLProvider) Speech(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostSpeechRequest) (*schemas.BifrostSpeechResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.SpeechRequest, provider.GetProviderKey()) diff --git a/core/providers/vertex/embedding.go b/core/providers/vertex/embedding.go index 54662f50fe..d3062d4e84 100644 --- a/core/providers/vertex/embedding.go +++ b/core/providers/vertex/embedding.go @@ -1,115 +1,249 @@ package vertex import ( + "fmt" + "strings" + + "github.com/maximhq/bifrost/core/providers/gemini" "github.com/maximhq/bifrost/core/schemas" ) -// ToVertexEmbeddingRequest converts a Bifrost embedding request to Vertex AI format +// isVertexNativeMultimodalEmbeddingModel returns true for the Vertex-native +// multimodal embedding model (multimodalembedding@001). This model uses the +// :predict endpoint but with a different instance format (text/image/video fields +// instead of content) and a different response format (textEmbedding/imageEmbedding). +func isVertexNativeMultimodalEmbeddingModel(model string) bool { + return strings.Contains(strings.ToLower(strings.TrimSpace(model)), "multimodalembedding") +} + +func isVertexGeminiEmbeddingModel(model string) bool { + model = strings.ToLower(strings.TrimSpace(model)) + return strings.Contains(model, "gemini-embedding-2") +} + +// ToVertexEmbeddingRequest converts a Bifrost embedding request to Vertex AI text embedding format. +// All contents must be text-only. Each content entry maps to one instance (one output embedding). 
func ToVertexEmbeddingRequest(bifrostReq *schemas.BifrostEmbeddingRequest) *VertexEmbeddingRequest { - if bifrostReq == nil || bifrostReq.Input == nil || (bifrostReq.Input.Text == nil && bifrostReq.Input.Texts == nil) { + if bifrostReq == nil || bifrostReq.Input == nil || len(bifrostReq.Input.Contents) == 0 { return nil } - // Create the request + vertexReq := &VertexEmbeddingRequest{} if bifrostReq.Params != nil { vertexReq.ExtraParams = bifrostReq.Params.ExtraParams } - var texts []string - if bifrostReq.Input.Text != nil { - texts = []string{*bifrostReq.Input.Text} - } else { - texts = bifrostReq.Input.Texts - } - // Create instances for each text - instances := make([]VertexEmbeddingInstance, 0, len(texts)) - for _, text := range texts { - instance := VertexEmbeddingInstance{ - Content: text, + instances := make([]VertexEmbeddingInstance, 0, len(bifrostReq.Input.Contents)) + for _, content := range bifrostReq.Input.Contents { + // Vertex text embedding expects a single text string per instance; + // stitch multiple text parts together. 
+ var sb strings.Builder + for _, part := range content { + if part.Type != schemas.EmbeddingContentPartTypeText || part.Text == nil { + return nil + } + sb.WriteString(*part.Text) } - - // Add optional task_type and title from params + instance := VertexEmbeddingInstance{Content: sb.String()} if bifrostReq.Params != nil { - if taskTypeStr, ok := schemas.SafeExtractStringPointer(bifrostReq.Params.ExtraParams["task_type"]); ok { - delete(vertexReq.ExtraParams, "task_type") - instance.TaskType = taskTypeStr - } - if title, ok := schemas.SafeExtractStringPointer(bifrostReq.Params.ExtraParams["title"]); ok { - delete(vertexReq.ExtraParams, "title") - instance.Title = title - } + instance.TaskType = bifrostReq.Params.TaskType + instance.Title = bifrostReq.Params.Title } - instances = append(instances, instance) } vertexReq.Instances = instances - // Add parameters if present + if bifrostReq.Params != nil { parameters := &VertexEmbeddingParameters{} - - // Set autoTruncate (defaults to true) autoTruncate := true - if bifrostReq.Params.ExtraParams != nil { - if autoTruncateVal, ok := schemas.SafeExtractBool(bifrostReq.Params.ExtraParams["autoTruncate"]); ok { - delete(vertexReq.ExtraParams, "autoTruncate") - autoTruncate = autoTruncateVal - } + if bifrostReq.Params.AutoTruncate != nil { + autoTruncate = *bifrostReq.Params.AutoTruncate } parameters.AutoTruncate = &autoTruncate + parameters.OutputDimensionality = bifrostReq.Params.Dimensions + vertexReq.Parameters = parameters + } + + return vertexReq +} + +// ToVertexGeminiEmbeddingRequest converts a Bifrost embedding request to Vertex Gemini embedding format. +// Only a single content entry is supported (len == 1); batch is not supported by this endpoint. 
+func ToVertexGeminiEmbeddingRequest(bifrostReq *schemas.BifrostEmbeddingRequest) (*VertexGeminiEmbeddingRequest, error) { + if bifrostReq == nil || bifrostReq.Input == nil || len(bifrostReq.Input.Contents) == 0 { + return nil, fmt.Errorf("embedding input is not provided") + } + if len(bifrostReq.Input.Contents) > 1 { + return nil, fmt.Errorf("vertex gemini embedding does not support batch inputs (multiple contents); use a single content entry") + } - // Add outputDimensionality if specified - if bifrostReq.Params.Dimensions != nil { - delete(vertexReq.ExtraParams, "dimensions") - parameters.OutputDimensionality = bifrostReq.Params.Dimensions + content := bifrostReq.Input.Contents[0] + params := bifrostReq.Params + gemContent, err := gemini.EmbeddingContentToGeminiContent(content) + if err != nil { + return nil, err + } + req := &VertexGeminiEmbeddingRequest{ + Content: gemContent, + } + if params != nil { + req.TaskType = params.TaskType + req.Title = params.Title + req.OutputDimensionality = params.Dimensions + req.AutoTruncate = params.AutoTruncate + + if params.ExtraParams != nil { + req.ExtraParams = params.ExtraParams + if documentOCR, ok := schemas.SafeExtractBoolPointer(params.ExtraParams["documentOcr"]); ok { + delete(req.ExtraParams, "documentOcr") + req.DocumentOCR = documentOCR + } + if audioTrackExtraction, ok := schemas.SafeExtractBoolPointer(params.ExtraParams["audioTrackExtraction"]); ok { + delete(req.ExtraParams, "audioTrackExtraction") + req.AudioTrackExtraction = audioTrackExtraction + } } + } + return req, nil +} - vertexReq.Parameters = parameters +// extractBase64FromDataURI strips the "data:<mediatype>;base64," prefix from a data URI, +// returning the raw base64 string that Vertex multimodal embedding expects. 
+func extractBase64FromDataURI(dataURI string) string { + if !strings.HasPrefix(dataURI, "data:") { + return dataURI // already raw base64 or a GCS URI + } + info := schemas.ExtractURLTypeInfo(dataURI) + if info.DataURLWithoutPrefix != nil { + return *info.DataURLWithoutPrefix } + return dataURI +} - return vertexReq +// ToVertexMultimodalEmbeddingRequest converts a Bifrost embedding request to the +// Vertex native multimodal embedding format (multimodalembedding@001). +// Each content entry maps to one instance. Parts within a content are merged into +// the instance fields (text, image, video). Only text, image, and video are supported; +// audio and file parts will return an error. +func ToVertexMultimodalEmbeddingRequest(bifrostReq *schemas.BifrostEmbeddingRequest) (*VertexEmbeddingRequest, error) { + if bifrostReq == nil || bifrostReq.Input == nil || len(bifrostReq.Input.Contents) == 0 { + return nil, fmt.Errorf("embedding input is not provided") + } + + instances := make([]VertexEmbeddingInstance, 0, len(bifrostReq.Input.Contents)) + for _, content := range bifrostReq.Input.Contents { + instance := VertexEmbeddingInstance{} + for _, part := range content { + switch part.Type { + case schemas.EmbeddingContentPartTypeText: + instance.Text = part.Text + case schemas.EmbeddingContentPartTypeImage: + if part.Image == nil { + return nil, fmt.Errorf("image part has no payload") + } + img := &VertexMultimodalImageInput{} + if part.Image.Data != nil { + b64 := extractBase64FromDataURI(*part.Image.Data) + img.BytesBase64Encoded = &b64 + } else if part.Image.URL != nil { + if !strings.HasPrefix(*part.Image.URL, "gs://") { + return nil, fmt.Errorf("vertex multimodal embedding requires a GCS URI (gs://) for image URL input") + } + img.GCSUri = part.Image.URL + } else { + return nil, fmt.Errorf("image part must set data or url") + } + instance.Image = img + case schemas.EmbeddingContentPartTypeVideo: + if part.Video == nil { + return nil, fmt.Errorf("video part has no 
payload") + } + if part.Video.URL == nil || !strings.HasPrefix(*part.Video.URL, "gs://") { + return nil, fmt.Errorf("vertex multimodal embedding requires a GCS URI (gs://) for video input") + } + vid := &VertexMultimodalVideoInput{GCSUri: part.Video.URL} + instance.Video = vid + default: + return nil, fmt.Errorf("vertex multimodalembedding@001 does not support %q parts", part.Type) + } + } + instances = append(instances, instance) + } + + req := &VertexEmbeddingRequest{Instances: instances} + if bifrostReq.Params != nil { + req.Parameters = &VertexEmbeddingParameters{ + Dimension: bifrostReq.Params.Dimensions, + AutoTruncate: bifrostReq.Params.AutoTruncate, + } + req.ExtraParams = bifrostReq.Params.ExtraParams + } + return req, nil } -// ToBifrostEmbeddingResponse converts a Vertex AI embedding response to Bifrost format +// ToBifrostEmbeddingResponse converts a Vertex AI embedding response to Bifrost format. +// Handles both text embedding responses (Embeddings.Values) and native multimodal +// responses (TextEmbedding / ImageEmbedding / VideoEmbeddings). 
func (response *VertexEmbeddingResponse) ToBifrostEmbeddingResponse() *schemas.BifrostEmbeddingResponse { if response == nil || len(response.Predictions) == 0 { return nil } - // Convert predictions to Bifrost embeddings embeddings := make([]schemas.EmbeddingData, 0, len(response.Predictions)) var usage *schemas.BifrostLLMUsage - - for i, prediction := range response.Predictions { - if prediction.Embeddings == nil || len(prediction.Embeddings.Values) == 0 { + idx := 0 + + for _, prediction := range response.Predictions { + // Text embedding model response + if prediction.Embeddings != nil && len(prediction.Embeddings.Values) > 0 { + embeddings = append(embeddings, schemas.EmbeddingData{ + Object: "embedding", + Embedding: schemas.EmbeddingsByType{Float: append([]float64(nil), prediction.Embeddings.Values...)}, + Index: idx, + }) + idx++ + if prediction.Embeddings.Statistics != nil { + if usage == nil { + usage = &schemas.BifrostLLMUsage{} + } + usage.TotalTokens += prediction.Embeddings.Statistics.TokenCount + usage.PromptTokens += prediction.Embeddings.Statistics.TokenCount + } continue } - // Create embedding object - embedding := schemas.EmbeddingData{ - Object: "embedding", - Embedding: schemas.EmbeddingStruct{ - EmbeddingArray: append([]float64(nil), prediction.Embeddings.Values...), - }, - Index: i, + // Native multimodal model response — textEmbedding, imageEmbedding, videoEmbeddings + // are all in the same embedding space so each is returned as a separate EmbeddingData. 
+ if len(prediction.TextEmbedding) > 0 { + embeddings = append(embeddings, schemas.EmbeddingData{ + Object: "embedding", + Embedding: schemas.EmbeddingsByType{Float: append([]float64(nil), prediction.TextEmbedding...)}, + Index: idx, + }) + idx++ } - - // Extract statistics if available - if prediction.Embeddings.Statistics != nil { - if usage == nil { - usage = &schemas.BifrostLLMUsage{} - } - usage.TotalTokens += prediction.Embeddings.Statistics.TokenCount - usage.PromptTokens += prediction.Embeddings.Statistics.TokenCount + if len(prediction.ImageEmbedding) > 0 { + embeddings = append(embeddings, schemas.EmbeddingData{ + Object: "embedding", + Embedding: schemas.EmbeddingsByType{Float: append([]float64(nil), prediction.ImageEmbedding...)}, + Index: idx, + }) + idx++ + } + for _, ve := range prediction.VideoEmbeddings { + embeddings = append(embeddings, schemas.EmbeddingData{ + Object: "embedding", + Embedding: schemas.EmbeddingsByType{Float: append([]float64(nil), ve.Embedding...)}, + Index: idx, + }) + idx++ } - - embeddings = append(embeddings, embedding) } return &schemas.BifrostEmbeddingResponse{ - Object: "list", - Data: embeddings, - Usage: usage, - ExtraFields: schemas.BifrostResponseExtraFields{ - }, + Object: "list", + Data: embeddings, + Usage: usage, + ExtraFields: schemas.BifrostResponseExtraFields{}, } } diff --git a/core/providers/vertex/embedding_multimodal_test.go b/core/providers/vertex/embedding_multimodal_test.go new file mode 100644 index 0000000000..bd8d28e14b --- /dev/null +++ b/core/providers/vertex/embedding_multimodal_test.go @@ -0,0 +1,45 @@ +package vertex + +import ( + "testing" + + "github.com/maximhq/bifrost/core/schemas" + "github.com/stretchr/testify/require" +) + +func TestToVertexGeminiEmbeddingRequest(t *testing.T) { + text := "hello" + req, err := ToVertexGeminiEmbeddingRequest(&schemas.BifrostEmbeddingRequest{ + Input: &schemas.EmbeddingInput{ + Contents: []schemas.EmbeddingContent{{ + {Type: 
schemas.EmbeddingContentPartTypeText, Text: &text}, + {Type: schemas.EmbeddingContentPartTypeImage, Image: &schemas.EmbeddingMediaPart{URL: schemas.Ptr("https://example.com/img.png")}}, + }}, + }, + Params: &schemas.EmbeddingParameters{ + TaskType: schemas.Ptr("RETRIEVAL_DOCUMENT"), + Dimensions: schemas.Ptr(128), + }, + }) + require.NoError(t, err) + require.NotNil(t, req.Content) + require.Len(t, req.Content.Parts, 2) + require.Equal(t, "hello", req.Content.Parts[0].Text) + require.NotNil(t, req.Content.Parts[1].FileData) + require.Equal(t, 128, *req.OutputDimensionality) +} + +func TestToVertexGeminiEmbeddingRequestRejectsBatch(t *testing.T) { + t1 := "first" + t2 := "second" + _, err := ToVertexGeminiEmbeddingRequest(&schemas.BifrostEmbeddingRequest{ + Input: &schemas.EmbeddingInput{ + Contents: []schemas.EmbeddingContent{ + {{Type: schemas.EmbeddingContentPartTypeText, Text: &t1}}, + {{Type: schemas.EmbeddingContentPartTypeText, Text: &t2}}, + }, + }, + }) + require.Error(t, err) + require.Contains(t, err.Error(), "batch") +} diff --git a/core/providers/vertex/types.go b/core/providers/vertex/types.go index bbdb89d17f..b4b600699a 100644 --- a/core/providers/vertex/types.go +++ b/core/providers/vertex/types.go @@ -3,6 +3,7 @@ package vertex import ( "time" + "github.com/maximhq/bifrost/core/providers/gemini" providerUtils "github.com/maximhq/bifrost/core/providers/utils" ) @@ -110,17 +111,54 @@ type VertexAdvancedVoiceOptions struct { LowLatencyJourneySynthesis bool `json:"lowLatencyJourneySynthesis,omitempty"` } -// VertexEmbeddingInstance represents a single embedding instance in the request +// VertexMultimodalImageInput represents an image for the multimodalembedding@001 model. +// Exactly one of BytesBase64Encoded or GCSUri must be set. 
+type VertexMultimodalImageInput struct { + BytesBase64Encoded *string `json:"bytesBase64Encoded,omitempty"` // Raw base64 string (no data URI prefix) + GCSUri *string `json:"gcsUri,omitempty"` // gs://bucket/object +} + +// VertexVideoSegmentConfig controls which portion of a video is embedded. +type VertexVideoSegmentConfig struct { + StartOffsetSec *int `json:"startOffsetSec,omitempty"` + EndOffsetSec *int `json:"endOffsetSec,omitempty"` + IntervalSec *int `json:"intervalSec,omitempty"` // 4=Essential, 8=Standard, 16=Plus +} + +// VertexMultimodalVideoInput represents a video for the multimodalembedding@001 model. +type VertexMultimodalVideoInput struct { + GCSUri *string `json:"gcsUri,omitempty"` + VideoSegmentConfig *VertexVideoSegmentConfig `json:"videoSegmentConfig,omitempty"` +} + +// VertexVideoEmbedding is one segment's embedding returned for a video input. +type VertexVideoEmbedding struct { + StartOffsetSec int `json:"startOffsetSec"` + EndOffsetSec int `json:"endOffsetSec"` + Embedding []float64 `json:"embedding"` +} + +// VertexEmbeddingInstance represents a single embedding instance in the request. +// For text embedding models (text-multilingual-embedding-*): populate Content, TaskType, Title. +// For the native multimodal model (multimodalembedding@001): populate Text, Image, and/or Video. 
type VertexEmbeddingInstance struct { - Content string `json:"content"` // The text to generate embeddings for - TaskType *string `json:"task_type,omitempty"` // Intended downstream application (optional) - Title *string `json:"title,omitempty"` // Used to help the model produce better embeddings (optional) + // Text embedding fields + Content string `json:"content,omitempty"` // Plain text for text-only embedding models + TaskType *string `json:"task_type,omitempty"` // Downstream task hint (text models only) + Title *string `json:"title,omitempty"` // Optional title (text models only) + + // Native multimodal embedding fields (multimodalembedding@001) + Text *string `json:"text,omitempty"` + Image *VertexMultimodalImageInput `json:"image,omitempty"` + Video *VertexMultimodalVideoInput `json:"video,omitempty"` } -// VertexEmbeddingParameters represents the parameters for the embedding request +// VertexEmbeddingParameters represents the parameters for the embedding request. +// Dimension applies to multimodalembedding@001; OutputDimensionality to text models. 
type VertexEmbeddingParameters struct { - AutoTruncate *bool `json:"autoTruncate,omitempty"` // When true, input text will be truncated (defaults to true) - OutputDimensionality *int `json:"outputDimensionality,omitempty"` // Output embedding size (optional) + AutoTruncate *bool `json:"autoTruncate,omitempty"` // Truncate long inputs (text models) + OutputDimensionality *int `json:"outputDimensionality,omitempty"` // Output dimensions (text models) + Dimension *int `json:"dimension,omitempty"` // Output dimensions (multimodalembedding@001) } // VertexEmbeddingRequest represents the complete embedding request to Vertex AI @@ -134,6 +172,21 @@ func (r *VertexEmbeddingRequest) GetExtraParams() map[string]interface{} { return r.ExtraParams } +type VertexGeminiEmbeddingRequest struct { + Content *gemini.Content `json:"content,omitempty"` + DocumentOCR *bool `json:"documentOcr,omitempty"` + AudioTrackExtraction *bool `json:"audioTrackExtraction,omitempty"` + TaskType *string `json:"taskType,omitempty"` + Title *string `json:"title,omitempty"` + OutputDimensionality *int `json:"outputDimensionality,omitempty"` + AutoTruncate *bool `json:"autoTruncate,omitempty"` + ExtraParams map[string]interface{} `json:"-"` +} + +func (r *VertexGeminiEmbeddingRequest) GetExtraParams() map[string]interface{} { + return r.ExtraParams +} + // VertexEmbeddingStatistics represents statistics computed from the input text type VertexEmbeddingStatistics struct { Truncated bool `json:"truncated"` // Whether the input text was truncated @@ -146,9 +199,18 @@ type VertexEmbeddingValues struct { Statistics *VertexEmbeddingStatistics `json:"statistics"` // Statistics about the input text } -// VertexEmbeddingPrediction represents a single prediction in the response +// VertexEmbeddingPrediction represents a single prediction in the response. +// Text embedding models populate Embeddings. 
+// The native multimodal model (multimodalembedding@001) populates TextEmbedding, +// ImageEmbedding, and/or VideoEmbeddings — all in the same embedding space. type VertexEmbeddingPrediction struct { - Embeddings *VertexEmbeddingValues `json:"embeddings"` // The embedding result + // Text embedding model response + Embeddings *VertexEmbeddingValues `json:"embeddings,omitempty"` + + // Native multimodal model response (multimodalembedding@001) + TextEmbedding []float64 `json:"textEmbedding,omitempty"` + ImageEmbedding []float64 `json:"imageEmbedding,omitempty"` + VideoEmbeddings []VertexVideoEmbedding `json:"videoEmbeddings,omitempty"` } // VertexEmbeddingResponse represents the complete embedding response from Vertex AI diff --git a/core/providers/vertex/vertex.go b/core/providers/vertex/vertex.go index 04701ee18d..ed2fece6f5 100644 --- a/core/providers/vertex/vertex.go +++ b/core/providers/vertex/vertex.go @@ -1418,10 +1418,19 @@ func (provider *VertexProvider) Embedding(ctx *schemas.BifrostContext, key schem return nil, providerUtils.NewConfigurationError("region is not set in key config") } + isGeminiEmbedding2Request := isVertexGeminiEmbeddingModel(request.Model) + isNativeMultimodalRequest := isVertexNativeMultimodalEmbeddingModel(request.Model) + jsonBody, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, request, func() (providerUtils.RequestBodyWithExtraParams, error) { + if isGeminiEmbedding2Request { + return ToVertexGeminiEmbeddingRequest(request) + } + if isNativeMultimodalRequest { + return ToVertexMultimodalEmbeddingRequest(request) + } return ToVertexEmbeddingRequest(request), nil }, ) @@ -1429,14 +1438,25 @@ func (provider *VertexProvider) Embedding(ctx *schemas.BifrostContext, key schem return nil, bifrostErr } + authQuery := "" + if key.Value.GetValue() != "" { + authQuery = fmt.Sprintf("key=%s", url.QueryEscape(key.Value.GetValue())) + } + // For custom/fine-tuned models, validate projectNumber is set projectNumber := 
key.VertexKeyConfig.ProjectNumber.GetValue() if schemas.IsAllDigitsASCII(request.Model) && projectNumber == "" { return nil, providerUtils.NewConfigurationError("project number is not set for fine-tuned models") } - // Build the native Vertex embedding API endpoint - url := getCompleteURLForGeminiEndpoint(request.Model, region, projectID, projectNumber, ":predict") + // Build the native Vertex embedding API endpoint. + // Gemini embedding models use :embedContent; all others (text + native multimodal) use :predict. + var url string + if isGeminiEmbedding2Request { + url = getCompleteURLForGeminiEndpoint(request.Model, region, projectID, projectNumber, ":embedContent") + } else { + url = getCompleteURLForGeminiEndpoint(request.Model, region, projectID, projectNumber, ":predict") + } // Create HTTP request for streaming req := fasthttp.AcquireRequest() @@ -1450,22 +1470,29 @@ func (provider *VertexProvider) Embedding(ctx *schemas.BifrostContext, key schem }() req.Header.SetMethod(http.MethodPost) - req.SetRequestURI(url) req.Header.SetContentType("application/json") // Set any extra headers from network config providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, nil) - // Getting oauth2 token - tokenSource, err := getAuthTokenSource(key) - if err != nil { - return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err) - } - token, err := tokenSource.Token() - if err != nil { - return nil, providerUtils.NewBifrostOperationError("error getting token", err) + // If auth query is set, add it to the URL + // Otherwise, get the oauth2 token and set the Authorization header + if authQuery != "" { + url = fmt.Sprintf("%s?%s", url, authQuery) + } else { + // Getting oauth2 token + tokenSource, err := getAuthTokenSource(key) + if err != nil { + return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err) + } + token, err := tokenSource.Token() + if err != nil { + return nil, 
providerUtils.NewBifrostOperationError("error getting token", err) + } + req.Header.Set("Authorization", "Bearer "+token.AccessToken) } - req.Header.Set("Authorization", "Bearer "+token.AccessToken) + + req.SetRequestURI(url) usedLargePayloadBody := providerUtils.ApplyLargePayloadRequestBody(ctx, req) if !usedLargePayloadBody { @@ -1530,18 +1557,38 @@ func (provider *VertexProvider) Embedding(ctx *schemas.BifrostContext, key schem }, nil } - // Parse Vertex's native embedding response using typed response - var vertexResponse VertexEmbeddingResponse - if err := sonic.Unmarshal(responseBody, &vertexResponse); err != nil { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err), jsonBody, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse) + var bifrostResponse *schemas.BifrostEmbeddingResponse + // Use centralized Vertex converter + if isGeminiEmbedding2Request { + var geminiResponse gemini.GeminiEmbeddingResponse + if err := sonic.Unmarshal(responseBody, &geminiResponse); err != nil { + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err), jsonBody, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse) + } + bifrostResponse = gemini.ToBifrostEmbeddingResponse(&geminiResponse, request.Model) + } else { + // Parse Vertex's native embedding response using typed response + var vertexResponse VertexEmbeddingResponse + if err := sonic.Unmarshal(responseBody, &vertexResponse); err != nil { + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err), jsonBody, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse) + } + bifrostResponse = vertexResponse.ToBifrostEmbeddingResponse() } - // Use centralized Vertex converter - bifrostResponse := vertexResponse.ToBifrostEmbeddingResponse() + if bifrostResponse == 
nil { + return nil, providerUtils.EnrichError( + ctx, + providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, fmt.Errorf("provider returned empty embedding response")), + jsonBody, + responseBody, + provider.sendBackRawRequest, + provider.sendBackRawResponse, + ) + } // Set ExtraFields bifrostResponse.ExtraFields.Latency = latency.Milliseconds() bifrostResponse.ExtraFields.ProviderResponseHeaders = providerUtils.ExtractProviderResponseHeaders(resp) + bifrostResponse.Model = request.Model // Set raw response if enabled if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) { @@ -1555,6 +1602,10 @@ func (provider *VertexProvider) Embedding(ctx *schemas.BifrostContext, key schem return bifrostResponse, nil } +func (provider *VertexProvider) BatchEmbedding(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostBatchEmbeddingRequest) (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) { + return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchEmbeddingRequest, provider.GetProviderKey()) +} + // Speech is not supported by the Vertex provider. 
func (provider *VertexProvider) Speech(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostSpeechRequest) (*schemas.BifrostSpeechResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.SpeechRequest, provider.GetProviderKey()) diff --git a/core/providers/vertex/vertex_test.go b/core/providers/vertex/vertex_test.go index d754f33d22..94adc0def5 100644 --- a/core/providers/vertex/vertex_test.go +++ b/core/providers/vertex/vertex_test.go @@ -26,49 +26,51 @@ func TestVertex(t *testing.T) { rerankModel := strings.TrimSpace(os.Getenv("VERTEX_RERANK_MODEL")) testConfig := llmtests.ComprehensiveTestConfig{ - Provider: schemas.Vertex, - ChatModel: "gemini-2.5-pro", - PromptCachingModel: "claude-sonnet-4-5", - VisionModel: "claude-sonnet-4-5", - TextModel: "", // Vertex doesn't support text completion in newer models - EmbeddingModel: "text-multilingual-embedding-002", - RerankModel: rerankModel, - ReasoningModel: "claude-4.5-haiku", - ImageGenerationModel: "gemini-2.5-flash-image", - ImageEditModel: "imagen-3.0-capability-001", - VideoGenerationModel: "veo-3.1-generate-preview", + Provider: schemas.Vertex, + ChatModel: "gemini-2.5-pro", + PromptCachingModel: "claude-sonnet-4-5", + VisionModel: "claude-sonnet-4-5", + TextModel: "", // Vertex doesn't support text completion in newer models + EmbeddingModel: "text-multilingual-embedding-002", + MultimodalEmbeddingModel: "gemini-embedding-2-preview", + RerankModel: rerankModel, + ReasoningModel: "claude-4.5-haiku", + ImageGenerationModel: "gemini-2.5-flash-image", + ImageEditModel: "imagen-3.0-capability-001", + VideoGenerationModel: "veo-3.1-generate-preview", Scenarios: llmtests.TestScenarios{ - TextCompletion: false, // Not supported - SimpleChat: true, - CompletionStream: true, - MultiTurnConversation: true, - ToolCalls: true, - ToolCallsStreaming: true, - MultipleToolCalls: true, - MultipleToolCallsStreaming: true, - End2EndToolCalling: true, - 
AutomaticFunctionCall: true, - ImageURL: false, - ImageBase64: true, - ImageGeneration: true, - ImageGenerationStream: false, - ImageEdit: true, - VideoGeneration: false, // disabled for now because of long running operations - VideoRetrieve: false, - VideoRemix: false, - VideoDownload: false, - VideoList: false, - VideoDelete: false, - MultipleImages: true, - CompleteEnd2End: true, - FileBase64: true, - Embedding: true, - Rerank: rerankModel != "", - Reasoning: true, - PromptCaching: true, - ListModels: false, - CountTokens: true, - StructuredOutputs: true, // Structured outputs with nullable enum support + TextCompletion: false, // Not supported + SimpleChat: true, + CompletionStream: true, + MultiTurnConversation: true, + ToolCalls: true, + ToolCallsStreaming: true, + MultipleToolCalls: true, + MultipleToolCallsStreaming: true, + End2EndToolCalling: true, + AutomaticFunctionCall: true, + ImageURL: false, + ImageBase64: true, + ImageGeneration: true, + ImageGenerationStream: false, + ImageEdit: true, + VideoGeneration: false, // disabled for now because of long running operations + VideoRetrieve: false, + VideoRemix: false, + VideoDownload: false, + VideoList: false, + VideoDelete: false, + MultipleImages: true, + CompleteEnd2End: true, + FileBase64: true, + Embedding: true, + MultimodalEmbedding: true, + Rerank: rerankModel != "", + Reasoning: true, + PromptCaching: true, + ListModels: false, + CountTokens: true, + StructuredOutputs: true, // Structured outputs with nullable enum support InterleavedThinking: true, EagerInputStreaming: true, // fine-grained-tool-streaming-2025-05-14 (GA on Vertex) ServerToolsViaOpenAIEndpoint: true, // web_search only on Vertex per Table 20 (web_fetch/code_execution skip) diff --git a/core/providers/vllm/vllm.go b/core/providers/vllm/vllm.go index 7952e161cd..bd2cfb201b 100644 --- a/core/providers/vllm/vllm.go +++ b/core/providers/vllm/vllm.go @@ -246,6 +246,10 @@ func (provider *VLLMProvider) Embedding(ctx 
*schemas.BifrostContext, key schemas ) } +func (provider *VLLMProvider) BatchEmbedding(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostBatchEmbeddingRequest) (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) { + return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchEmbeddingRequest, provider.GetProviderKey()) +} + // Responses performs a responses request to vLLM's API (via chat completion). func (provider *VLLMProvider) Responses(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { chatResponse, err := provider.ChatCompletion(ctx, key, request.ToChatRequest()) diff --git a/core/providers/xai/xai.go b/core/providers/xai/xai.go index e787f307fd..87bb04bae3 100644 --- a/core/providers/xai/xai.go +++ b/core/providers/xai/xai.go @@ -223,6 +223,10 @@ func (provider *XAIProvider) Embedding(ctx *schemas.BifrostContext, key schemas. return nil, providerUtils.NewUnsupportedOperationError(schemas.EmbeddingRequest, provider.GetProviderKey()) } +func (provider *XAIProvider) BatchEmbedding(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostBatchEmbeddingRequest) (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) { + return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchEmbeddingRequest, provider.GetProviderKey()) +} + // Speech is not supported by the xAI provider. 
func (provider *XAIProvider) Speech(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostSpeechRequest) (*schemas.BifrostSpeechResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.SpeechRequest, provider.GetProviderKey()) diff --git a/core/schemas/bifrost.go b/core/schemas/bifrost.go index 25eaa6fd01..379b135192 100644 --- a/core/schemas/bifrost.go +++ b/core/schemas/bifrost.go @@ -110,6 +110,7 @@ const ( ResponsesRequest RequestType = "responses" ResponsesStreamRequest RequestType = "responses_stream" EmbeddingRequest RequestType = "embedding" + BatchEmbeddingRequest RequestType = "batch_embedding" SpeechRequest RequestType = "speech" SpeechStreamRequest RequestType = "speech_stream" TranscriptionRequest RequestType = "transcription" @@ -399,6 +400,7 @@ type BifrostRequest struct { ResponsesRequest *BifrostResponsesRequest CountTokensRequest *BifrostResponsesRequest EmbeddingRequest *BifrostEmbeddingRequest + BatchEmbeddingRequest *BifrostBatchEmbeddingRequest RerankRequest *BifrostRerankRequest OCRRequest *BifrostOCRRequest SpeechRequest *BifrostSpeechRequest @@ -455,6 +457,8 @@ func (br *BifrostRequest) GetRequestFields() (provider ModelProvider, model stri return br.CountTokensRequest.Provider, br.CountTokensRequest.Model, br.CountTokensRequest.Fallbacks case br.EmbeddingRequest != nil: return br.EmbeddingRequest.Provider, br.EmbeddingRequest.Model, br.EmbeddingRequest.Fallbacks + case br.BatchEmbeddingRequest != nil: + return br.BatchEmbeddingRequest.Provider, br.BatchEmbeddingRequest.Model, br.BatchEmbeddingRequest.Fallbacks case br.RerankRequest != nil: return br.RerankRequest.Provider, br.RerankRequest.Model, br.RerankRequest.Fallbacks case br.OCRRequest != nil: @@ -596,6 +600,8 @@ func (br *BifrostRequest) SetProvider(provider ModelProvider) { br.CountTokensRequest.Provider = provider case br.EmbeddingRequest != nil: br.EmbeddingRequest.Provider = provider + case br.BatchEmbeddingRequest != nil: 
+ br.BatchEmbeddingRequest.Provider = provider case br.RerankRequest != nil: br.RerankRequest.Provider = provider case br.OCRRequest != nil: @@ -647,6 +653,8 @@ func (br *BifrostRequest) SetModel(model string) { br.CountTokensRequest.Model = model case br.EmbeddingRequest != nil: br.EmbeddingRequest.Model = model + case br.BatchEmbeddingRequest != nil: + br.BatchEmbeddingRequest.Model = model case br.RerankRequest != nil: br.RerankRequest.Model = model case br.OCRRequest != nil: @@ -700,6 +708,8 @@ func (br *BifrostRequest) SetFallbacks(fallbacks []Fallback) { br.CountTokensRequest.Fallbacks = fallbacks case br.EmbeddingRequest != nil: br.EmbeddingRequest.Fallbacks = fallbacks + case br.BatchEmbeddingRequest != nil: + br.BatchEmbeddingRequest.Fallbacks = fallbacks case br.RerankRequest != nil: br.RerankRequest.Fallbacks = fallbacks case br.OCRRequest != nil: @@ -731,6 +741,8 @@ func (br *BifrostRequest) SetRawRequestBody(rawRequestBody []byte) { br.CountTokensRequest.RawRequestBody = rawRequestBody case br.EmbeddingRequest != nil: br.EmbeddingRequest.RawRequestBody = rawRequestBody + case br.BatchEmbeddingRequest != nil: + br.BatchEmbeddingRequest.RawRequestBody = rawRequestBody case br.RerankRequest != nil: br.RerankRequest.RawRequestBody = rawRequestBody case br.OCRRequest != nil: diff --git a/core/schemas/embedding.go b/core/schemas/embedding.go index 8dfe6ea34b..0cc7e8131e 100644 --- a/core/schemas/embedding.go +++ b/core/schemas/embedding.go @@ -18,6 +18,47 @@ func (r *BifrostEmbeddingRequest) GetRawRequestBody() []byte { return r.RawRequestBody } +// BifrostBatchEmbeddingRequest supports batch embeddings where each item can carry its own parameter overrides. +type BifrostBatchEmbeddingRequest struct { + Provider ModelProvider `json:"provider"` + Model string `json:"model"` + Params *EmbeddingParameters `json:"params,omitempty"` // default for all items + Items []BifrostEmbeddingBatchItem `json:"items"` + Fallbacks []Fallback `json:"fallbacks,omitempty"` + RawRequestBody []byte 
`json:"-"` +} + +func (r *BifrostBatchEmbeddingRequest) GetRawRequestBody() []byte { + return r.RawRequestBody +} + +func (r *BifrostBatchEmbeddingRequest) Validate() error { + if r == nil || len(r.Items) == 0 { + return fmt.Errorf("batch embedding request has no items") + } + for i, item := range r.Items { + if err := item.Content.Validate(); err != nil { + return fmt.Errorf("item %d: %w", i, err) + } + } + return nil +} + +// BifrostEmbeddingBatchItem is one entry in a BifrostBatchEmbeddingRequest. +// Params nil means inherit from BifrostBatchEmbeddingRequest.Params. +type BifrostEmbeddingBatchItem struct { + Content EmbeddingContent `json:"content"` + Params *EmbeddingParameters `json:"params,omitempty"` +} + +// EffectiveParams returns the item-level Params if set, otherwise the batch-level default. +func (i *BifrostEmbeddingBatchItem) EffectiveParams(defaultParams *EmbeddingParameters) *EmbeddingParameters { + if i.Params != nil { + return i.Params + } + return defaultParams +} + type BifrostEmbeddingResponse struct { Data []EmbeddingData `json:"data"` // Maps to "data" field in provider responses (e.g., OpenAI embedding format) Model string `json:"model"` @@ -38,86 +79,165 @@ func (r *BifrostEmbeddingResponse) BackfillParams(request *BifrostEmbeddingReque // EmbeddingInput represents the input for an embedding request. 
type EmbeddingInput struct { - Text *string - Texts []string - Embedding []int - Embeddings [][]int + Contents []EmbeddingContent `json:"contents,omitempty"` +} + +type EmbeddingContent []EmbeddingContentPart + +type EmbeddingContentPartType string + +const ( + EmbeddingContentPartTypeText EmbeddingContentPartType = "text" + EmbeddingContentPartTypeImage EmbeddingContentPartType = "image" + EmbeddingContentPartTypeAudio EmbeddingContentPartType = "audio" + EmbeddingContentPartTypeFile EmbeddingContentPartType = "file" + EmbeddingContentPartTypeVideo EmbeddingContentPartType = "video" + EmbeddingContentPartTypeTokens EmbeddingContentPartType = "tokens" +) + +type EmbeddingContentPart struct { + Type EmbeddingContentPartType `json:"type"` + + Text *string `json:"text,omitempty"` + Image *EmbeddingMediaPart `json:"image,omitempty"` + Audio *EmbeddingMediaPart `json:"audio,omitempty"` + File *EmbeddingMediaPart `json:"file,omitempty"` + Video *EmbeddingMediaPart `json:"video,omitempty"` + Tokens []int `json:"tokens,omitempty"` } -func (e *EmbeddingInput) MarshalJSON() ([]byte, error) { - // enforce one-of +type EmbeddingMediaPart struct { + Data *string `json:"data,omitempty"` + URL *string `json:"url,omitempty"` + MIMEType *string `json:"mime_type,omitempty"` + Filename *string `json:"filename,omitempty"` + Detail *string `json:"detail,omitempty"` + Metadata map[string]interface{} `json:"metadata,omitempty"` +} + +func (m *EmbeddingMediaPart) Validate() error { + if m == nil { + return fmt.Errorf("embedding media payload is nil") + } set := 0 - if e.Text != nil { + if m.Data != nil { + if *m.Data == "" { + return fmt.Errorf("embedding media data is empty") + } set++ } - if e.Texts != nil { + if m.URL != nil { + if *m.URL == "" { + return fmt.Errorf("embedding media url is empty") + } set++ } - if e.Embedding != nil { + if set != 1 { + return fmt.Errorf("embedding media payload must set exactly one of data or url") + } + return nil +} + +func (p EmbeddingContentPart) 
Validate() error { + set := 0 + if p.Text != nil { set++ } - if e.Embeddings != nil { + if p.Image != nil { set++ } - if set == 0 { - return nil, fmt.Errorf("embedding input is empty") + if p.Audio != nil { + set++ } - if set > 1 { - return nil, fmt.Errorf("embedding input must set exactly one of: text, texts, embedding, embeddings") + if p.File != nil { + set++ } - - if e.Text != nil { - return MarshalSorted(*e.Text) + if p.Video != nil { + set++ } - if e.Texts != nil { - return MarshalSorted(e.Texts) + if len(p.Tokens) > 0 { + set++ } - if e.Embedding != nil { - return MarshalSorted(e.Embedding) + if set != 1 { + return fmt.Errorf("embedding content part must set exactly one modality") } - if e.Embeddings != nil { - return MarshalSorted(e.Embeddings) + + switch p.Type { + case EmbeddingContentPartTypeText: + if p.Text == nil { + return fmt.Errorf("embedding content part type %q requires text payload", p.Type) + } + case EmbeddingContentPartTypeImage: + if p.Image == nil { + return fmt.Errorf("embedding content part type %q requires image payload", p.Type) + } + return p.Image.Validate() + case EmbeddingContentPartTypeAudio: + if p.Audio == nil { + return fmt.Errorf("embedding content part type %q requires audio payload", p.Type) + } + return p.Audio.Validate() + case EmbeddingContentPartTypeFile: + if p.File == nil { + return fmt.Errorf("embedding content part type %q requires file payload", p.Type) + } + return p.File.Validate() + case EmbeddingContentPartTypeVideo: + if p.Video == nil { + return fmt.Errorf("embedding content part type %q requires video payload", p.Type) + } + return p.Video.Validate() + case EmbeddingContentPartTypeTokens: + if len(p.Tokens) == 0 { + return fmt.Errorf("embedding content part type %q requires tokens payload", p.Type) + } + default: + return fmt.Errorf("unsupported embedding content part type %q", p.Type) } - return nil, fmt.Errorf("invalid embedding input") + return nil } -func (e *EmbeddingInput) UnmarshalJSON(data []byte) 
error { - e.Text = nil - e.Texts = nil - e.Embedding = nil - e.Embeddings = nil - // Try string - var s string - if err := Unmarshal(data, &s); err == nil { - e.Text = &s - return nil +func (c EmbeddingContent) Validate() error { + if len(c) == 0 { + return fmt.Errorf("embedding content is empty") } - // Try []string - var ss []string - if err := Unmarshal(data, &ss); err == nil { - e.Texts = ss - return nil + for _, part := range c { + if err := part.Validate(); err != nil { + return err + } } - // Try []int - var i []int - if err := Unmarshal(data, &i); err == nil { - e.Embedding = i - return nil + return nil +} + +func (e *EmbeddingInput) Validate() error { + if e == nil || len(e.Contents) == 0 { + return fmt.Errorf("embedding input is empty") } - // Try [][]int - var i2 [][]int - if err := Unmarshal(data, &i2); err == nil { - e.Embeddings = i2 - return nil + for _, content := range e.Contents { + if err := content.Validate(); err != nil { + return err + } } + return nil +} - return fmt.Errorf("unsupported embedding input shape") +// GetContents returns the contents slice directly. +func (e *EmbeddingInput) GetContents() []EmbeddingContent { + if e == nil { + return nil + } + return e.Contents } type EmbeddingParameters struct { EncodingFormat *string `json:"encoding_format,omitempty"` // Format for embedding output (e.g., "float", "base64") Dimensions *int `json:"dimensions,omitempty"` // Number of dimensions for embedding output + TaskType *string `json:"task_type,omitempty"` // Intended embedding task + Title *string `json:"title,omitempty"` // Optional title for the content + AutoTruncate *bool `json:"auto_truncate,omitempty"` // Automatically truncate long inputs + Truncate *string `json:"truncate,omitempty"` // Provider-specific truncation strategy + MaxTokens *int `json:"max_tokens,omitempty"` // Maximum tokens to process // Dynamic parameters that can be provider-specific, they are directly // added to the request as is. 
@@ -125,74 +245,16 @@ type EmbeddingParameters struct { } type EmbeddingData struct { - Index int `json:"index"` - Object string `json:"object"` // "embedding" - Embedding EmbeddingStruct `json:"embedding"` // can be string, []float64, [][]float64, []int8, or []int32 + Index int `json:"index"` + Object string `json:"object"` // "embedding" + Embedding EmbeddingsByType `json:"embedding"` // can be string, []float64, [][]float64, []int8, or []int32 } -type EmbeddingStruct struct { - // Embedding responses preserve provider precision in normalized API output. - EmbeddingStr *string - EmbeddingArray []float64 - Embedding2DArray [][]float64 - EmbeddingInt8Array []int8 // for int8 / binary formats - EmbeddingInt32Array []int32 // for uint8 / ubinary formats -} - -func (be EmbeddingStruct) MarshalJSON() ([]byte, error) { - if be.EmbeddingStr != nil { - return MarshalSorted(be.EmbeddingStr) - } - if be.EmbeddingArray != nil { - return MarshalSorted(be.EmbeddingArray) - } - if be.Embedding2DArray != nil { - return MarshalSorted(be.Embedding2DArray) - } - if be.EmbeddingInt8Array != nil { - return Marshal(be.EmbeddingInt8Array) - } - if be.EmbeddingInt32Array != nil { - return Marshal(be.EmbeddingInt32Array) - } - return nil, fmt.Errorf("no embedding found") -} - -func (be *EmbeddingStruct) UnmarshalJSON(data []byte) error { - // First, try to unmarshal as a direct string - var stringContent string - if err := Unmarshal(data, &stringContent); err == nil { - be.EmbeddingStr = &stringContent - return nil - } - - // Try to unmarshal as a direct array of float64 - var arrayContent []float64 - if err := Unmarshal(data, &arrayContent); err == nil { - be.EmbeddingArray = arrayContent - return nil - } - - // Try to unmarshal as a direct 2D array of float64 - var arrayContent2D [][]float64 - if err := Unmarshal(data, &arrayContent2D); err == nil { - be.Embedding2DArray = arrayContent2D - return nil - } - - // Try to unmarshal as a direct array of int8 - var int8Content []int8 - if 
err := Unmarshal(data, &int8Content); err == nil { - be.EmbeddingInt8Array = int8Content - return nil - } - - // Try to unmarshal as a direct array of int32 - var int32Content []int32 - if err := Unmarshal(data, &int32Content); err == nil { - be.EmbeddingInt32Array = int32Content - return nil - } - - return fmt.Errorf("embedding field is neither a string, []float64, [][]float64, []int8, nor []int32") +type EmbeddingsByType struct { + Float []float64 `json:"float,omitempty"` // Float embeddings + Int8 []int8 `json:"int8,omitempty"` // Int8 embeddings + Uint8 []uint8 `json:"uint8,omitempty"` // Uint8 embeddings + Binary []int8 `json:"binary,omitempty"` // Binary embeddings + Ubinary []uint8 `json:"ubinary,omitempty"` // Unsigned binary embeddings + Base64 *string `json:"base64,omitempty"` // Base64 embeddings } diff --git a/core/schemas/embedding_multimodal_test.go b/core/schemas/embedding_multimodal_test.go new file mode 100644 index 0000000000..07b688790d --- /dev/null +++ b/core/schemas/embedding_multimodal_test.go @@ -0,0 +1,27 @@ +package schemas + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestEmbeddingInputValidateRejectsEmpty(t *testing.T) { + input := &EmbeddingInput{} + err := input.Validate() + require.Error(t, err) + require.Contains(t, err.Error(), "empty") +} + +func TestEmbeddingContentPartValidateRejectsMultipleModalities(t *testing.T) { + text := "bad" + part := EmbeddingContentPart{ + Type: EmbeddingContentPartTypeImage, + Text: &text, + Image: &EmbeddingMediaPart{URL: Ptr("https://example.com/img.png")}, + } + + err := part.Validate() + require.Error(t, err) + require.Contains(t, err.Error(), "exactly one modality") +} diff --git a/core/schemas/provider.go b/core/schemas/provider.go index 80f6bf3d91..0113bf148e 100644 --- a/core/schemas/provider.go +++ b/core/schemas/provider.go @@ -308,6 +308,7 @@ type AllowedRequests struct { ResponsesStream bool `json:"responses_stream"` CountTokens bool `json:"count_tokens"` 
Embedding bool `json:"embedding"` + BatchEmbedding bool `json:"batch_embedding"` Rerank bool `json:"rerank"` OCR bool `json:"ocr"` Speech bool `json:"speech"` @@ -381,6 +382,8 @@ func (ar *AllowedRequests) IsOperationAllowed(operation RequestType) bool { return ar.CountTokens case EmbeddingRequest: return ar.Embedding + case BatchEmbeddingRequest: + return ar.BatchEmbedding case RerankRequest: return ar.Rerank case OCRRequest: @@ -592,6 +595,8 @@ type Provider interface { CountTokens(ctx *BifrostContext, key Key, request *BifrostResponsesRequest) (*BifrostCountTokensResponse, *BifrostError) // Embedding performs an embedding request Embedding(ctx *BifrostContext, key Key, request *BifrostEmbeddingRequest) (*BifrostEmbeddingResponse, *BifrostError) + // BatchEmbedding performs a batch embedding request with optional per-item parameter overrides + BatchEmbedding(ctx *BifrostContext, key Key, request *BifrostBatchEmbeddingRequest) (*BifrostEmbeddingResponse, *BifrostError) // Rerank performs a rerank request to reorder documents by relevance to a query Rerank(ctx *BifrostContext, key Key, request *BifrostRerankRequest) (*BifrostRerankResponse, *BifrostError) // OCR performs an optical character recognition request on a document diff --git a/core/schemas/serialization_test.go b/core/schemas/serialization_test.go index f7ba3d80f8..033438304c 100644 --- a/core/schemas/serialization_test.go +++ b/core/schemas/serialization_test.go @@ -158,29 +158,29 @@ func TestSonic_OrderedMap_NestedPreservesOrder(t *testing.T) { assert.Equal(t, input, string(output)) } -func TestSonic_EmbeddingStruct_PreservesFloat64Precision(t *testing.T) { +func TestSonic_EmbeddingsByType_PreservesFloat64Precision(t *testing.T) { const want = 0.12345678901234568 - var embedding EmbeddingStruct - err := embedding.UnmarshalJSON([]byte(`[0.12345678901234568]`)) + var embedding EmbeddingsByType + err := Unmarshal([]byte(`{"float":[0.12345678901234568]}`), &embedding) require.NoError(t, err) - 
require.Len(t, embedding.EmbeddingArray, 1) + require.Len(t, embedding.Float, 1) - got := embedding.EmbeddingArray[0] + got := embedding.Float[0] assert.Equal(t, want, got) float32Rounded := float64(float32(want)) assert.NotEqual(t, float32Rounded, got) - marshaled, err := embedding.MarshalJSON() + marshaled, err := Marshal(embedding) require.NoError(t, err) - var roundTrip []float64 + var roundTrip EmbeddingsByType err = Unmarshal(marshaled, &roundTrip) require.NoError(t, err) - require.Len(t, roundTrip, 1) - assert.Equal(t, math.Float64bits(got), math.Float64bits(roundTrip[0])) + require.Len(t, roundTrip.Float, 1) + assert.Equal(t, math.Float64bits(got), math.Float64bits(roundTrip.Float[0])) } // --- ToolFunctionParameters through sonic --- diff --git a/core/schemas/trace.go b/core/schemas/trace.go index 6a3b2b9290..219a906e1d 100644 --- a/core/schemas/trace.go +++ b/core/schemas/trace.go @@ -8,16 +8,16 @@ import ( // Trace represents a distributed trace that captures the full lifecycle of a request type Trace struct { - RequestID string // Request ID for the trace - TraceID string // Unique identifier for this trace - ParentID string // Parent trace ID from incoming W3C traceparent header - RootSpan *Span // The root span of this trace - Spans []*Span // All spans in this trace - StartTime time.Time // When the trace started - EndTime time.Time // When the trace completed - Attributes map[string]any // Additional attributes for the trace + RequestID string // Request ID for the trace + TraceID string // Unique identifier for this trace + ParentID string // Parent trace ID from incoming W3C traceparent header + RootSpan *Span // The root span of this trace + Spans []*Span // All spans in this trace + StartTime time.Time // When the trace started + EndTime time.Time // When the trace completed + Attributes map[string]any // Additional attributes for the trace PluginLogs []PluginLogEntry // Plugin log entries accumulated during request processing - mu sync.Mutex // 
Mutex for thread-safe span operations + mu sync.Mutex // Mutex for thread-safe span operations } // AddSpan adds a span to the trace in a thread-safe manner @@ -216,7 +216,7 @@ const ( AttrN = "gen_ai.request.n" AttrSeed = "gen_ai.request.seed" AttrSuffix = "gen_ai.request.suffix" - AttrDimensions = "gen_ai.request.dimensions" + AttrDimensions = "gen_ai.embeddings.dimension.count" AttrEncodingFormat = "gen_ai.request.encoding_format" AttrLanguage = "gen_ai.request.language" AttrPrompt = "gen_ai.request.prompt" diff --git a/framework/logstore/migrations.go b/framework/logstore/migrations.go index 24e7882897..cb4d8920d0 100644 --- a/framework/logstore/migrations.go +++ b/framework/logstore/migrations.go @@ -245,6 +245,9 @@ func triggerMigrations(ctx context.Context, db *gorm.DB) error { if err := migrationAddStopReasonColumn(ctx, db); err != nil { return err } + if err := migrationAddEmbeddingInputColumn(ctx, db); err != nil { + return err + } return nil } @@ -2664,3 +2667,81 @@ func migrationAddStopReasonColumn(ctx context.Context, db *gorm.DB) error { return nil } +// migrationAddEmbeddingInputColumn adds the embedding_input column to the logs table and +// backfills historical embedding logs by reconstructing EmbeddingContent entries from +// text blocks stored in the old input_history column. 
+func migrationAddEmbeddingInputColumn(ctx context.Context, db *gorm.DB) error { + opts := *migrator.DefaultOptions + opts.UseTransaction = true + m := migrator.New(db, &opts, []*migrator.Migration{{ + ID: "logs_add_embedding_input_column", + Migrate: func(tx *gorm.DB) error { + tx = tx.WithContext(ctx) + gormMigrator := tx.Migrator() + if !gormMigrator.HasColumn(&Log{}, "embedding_input") { + if err := gormMigrator.AddColumn(&Log{}, "embedding_input"); err != nil { + return err + } + } + dialect := tx.Dialector.Name() + switch dialect { + case "postgres": + return tx.Exec(` + UPDATE logs + SET embedding_input = ( + SELECT jsonb_agg( + jsonb_build_array( + jsonb_build_object('type', 'text', 'text', block->>'text') + ) + ) + FROM jsonb_array_elements(input_history::jsonb) AS history_item, + jsonb_array_elements(history_item -> 'content' -> 'content_blocks') AS block + WHERE block->>'type' = 'text' + AND (block->>'text') IS NOT NULL + AND (block->>'text') != '' + ) + WHERE object_type = 'embedding' + AND input_history IS NOT NULL + AND input_history NOT IN ('', '[]', 'null') + AND embedding_input IS NULL + `).Error + case "sqlite": + return tx.Exec(` + UPDATE logs + SET embedding_input = ( + SELECT json_group_array( + json_array( + json_object('type', 'text', 'text', json_extract(block.value, '$.text')) + ) + ) + FROM json_each(input_history) AS history_item, + json_each(json_extract(history_item.value, '$.content.content_blocks')) AS block + WHERE json_extract(block.value, '$.type') = 'text' + AND json_extract(block.value, '$.text') IS NOT NULL + AND json_extract(block.value, '$.text') != '' + ) + WHERE object_type = 'embedding' + AND input_history IS NOT NULL + AND input_history NOT IN ('', '[]', 'null') + AND embedding_input IS NULL + `).Error + default: + return nil + } + }, + Rollback: func(tx *gorm.DB) error { + tx = tx.WithContext(ctx) + gormMigrator := tx.Migrator() + if gormMigrator.HasColumn(&Log{}, "embedding_input") { + if err := 
gormMigrator.DropColumn(&Log{}, "embedding_input"); err != nil { + return err + } + } + return nil + }, + }}) + if err := m.Migrate(); err != nil { + return fmt.Errorf("error while adding embedding_input column: %s", err.Error()) + } + return nil +} diff --git a/framework/logstore/tables.go b/framework/logstore/tables.go index bff72c75b1..7861d2f4fd 100644 --- a/framework/logstore/tables.go +++ b/framework/logstore/tables.go @@ -141,6 +141,7 @@ type Log struct { ResponsesInputHistory string `gorm:"type:text" json:"-"` // JSON serialized []schemas.ResponsesMessage OutputMessage string `gorm:"type:text" json:"-"` // JSON serialized *schemas.ChatMessage ResponsesOutput string `gorm:"type:text" json:"-"` // JSON serialized *schemas.ResponsesMessage + EmbeddingInput string `gorm:"type:text" json:"-"` // JSON serialized []schemas.EmbeddingContent EmbeddingOutput string `gorm:"type:text" json:"-"` // JSON serialized [][]float32 RerankOutput string `gorm:"type:text" json:"-"` // JSON serialized []schemas.RerankResult OCROutput string `gorm:"type:text" json:"-"` // JSON serialized *schemas.BifrostOCRResponse @@ -197,6 +198,7 @@ type Log struct { ResponsesInputHistoryParsed []schemas.ResponsesMessage `gorm:"-" json:"responses_input_history,omitempty"` OutputMessageParsed *schemas.ChatMessage `gorm:"-" json:"output_message,omitempty"` ResponsesOutputParsed []schemas.ResponsesMessage `gorm:"-" json:"responses_output,omitempty"` + EmbeddingInputParsed []schemas.EmbeddingContent `gorm:"-" json:"embedding_input,omitempty"` EmbeddingOutputParsed []schemas.EmbeddingData `gorm:"-" json:"embedding_output,omitempty"` RerankOutputParsed []schemas.RerankResult `gorm:"-" json:"rerank_output,omitempty"` OCROutputParsed *schemas.BifrostOCRResponse `gorm:"-" json:"ocr_output,omitempty"` @@ -305,6 +307,14 @@ func (l *Log) SerializeFields() error { } } + if l.EmbeddingInputParsed != nil { + if data, err := sonic.Marshal(l.EmbeddingInputParsed); err != nil { + return err + } else { + 
l.EmbeddingInput = string(data) + } + } + if l.EmbeddingOutputParsed != nil { if data, err := sonic.Marshal(l.EmbeddingOutputParsed); err != nil { return err @@ -572,6 +582,12 @@ func (l *Log) DeserializeFields() error { } } + if l.EmbeddingInput != "" { + if err := sonic.Unmarshal([]byte(l.EmbeddingInput), &l.EmbeddingInputParsed); err != nil { + l.EmbeddingInputParsed = nil + } + } + if l.EmbeddingOutput != "" { if err := sonic.Unmarshal([]byte(l.EmbeddingOutput), &l.EmbeddingOutputParsed); err != nil { // Log error but don't fail the operation - initialize as nil @@ -1029,6 +1045,15 @@ type MCPToolLogStats struct { func (l *Log) BuildContentSummary() string { var parts []string + // Add embedding input text parts + for _, content := range l.EmbeddingInputParsed { + for _, part := range content { + if part.Type == schemas.EmbeddingContentPartTypeText && part.Text != nil && *part.Text != "" { + parts = append(parts, *part.Text) + } + } + } + // Add input messages for _, msg := range l.InputHistoryParsed { if msg.Content != nil { diff --git a/framework/tracing/llmspan.go b/framework/tracing/llmspan.go index cd981a5947..11540f428a 100644 --- a/framework/tracing/llmspan.go +++ b/framework/tracing/llmspan.go @@ -435,17 +435,16 @@ func PopulateEmbeddingRequestAttributes(req *schemas.BifrostEmbeddingRequest, at // Extract input if req.Input != nil { - if req.Input.Text != nil { - attrs[schemas.AttrInputText] = *req.Input.Text - } else if req.Input.Texts != nil { - attrs[schemas.AttrInputText] = strings.Join(req.Input.Texts, ",") - } else if req.Input.Embedding != nil { - embedding := make([]string, len(req.Input.Embedding)) - for i, v := range req.Input.Embedding { - // Use a float‑safe representation; adjust precision as needed. 
- embedding[i] = fmt.Sprintf("%v", v) + var texts []string + for _, content := range req.Input.Contents { + for _, part := range content { + if part.Type == schemas.EmbeddingContentPartTypeText && part.Text != nil { + texts = append(texts, *part.Text) + } } - attrs[schemas.AttrInputEmbedding] = strings.Join(embedding, ",") + } + if len(texts) > 0 { + attrs[schemas.AttrInputText] = strings.Join(texts, ",") } } } diff --git a/plugins/logging/main.go b/plugins/logging/main.go index d635991107..855e104de8 100644 --- a/plugins/logging/main.go +++ b/plugins/logging/main.go @@ -242,6 +242,7 @@ type InitialLogData struct { Object string InputHistory []schemas.ChatMessage ResponsesInputHistory []schemas.ResponsesMessage + EmbeddingInput []schemas.EmbeddingContent Params any SpeechInput *schemas.SpeechInput TranscriptionInput *schemas.TranscriptionInput @@ -530,6 +531,24 @@ func (p *LoggerPlugin) PreLLMHook(ctx *schemas.BifrostContext, req *schemas.Bifr } case schemas.EmbeddingRequest: initialData.Params = req.EmbeddingRequest.Params + embContents := extractEmbeddingInput(req) + reqThreshold, _ := ctx.Value(schemas.BifrostContextKeyLargePayloadRequestThreshold).(int64) + if reqThreshold > 0 && len(embContents) > 0 { + var totalDataSize int64 + for _, content := range embContents { + for _, part := range content { + for _, media := range []*schemas.EmbeddingMediaPart{part.Image, part.Audio, part.File, part.Video} { + if media != nil && media.Data != nil { + totalDataSize += int64(len(*media.Data)) + } + } + } + } + if totalDataSize > reqThreshold { + embContents = redactEmbeddingMediaData(embContents) + } + } + initialData.EmbeddingInput = embContents case schemas.RerankRequest: initialData.Params = req.RerankRequest.Params case schemas.OCRRequest: diff --git a/plugins/logging/utils.go b/plugins/logging/utils.go index 8c90b517bc..92ee407ddb 100644 --- a/plugins/logging/utils.go +++ b/plugins/logging/utils.go @@ -512,36 +512,6 @@ func (p *LoggerPlugin) 
extractInputHistory(request *schemas.BifrostRequest) ([]s }, }, []schemas.ResponsesMessage{} } - if request.EmbeddingRequest != nil { - // Large payload passthrough can intentionally leave Input nil to avoid - // materializing giant request bodies. Logging should degrade gracefully. - if request.EmbeddingRequest.Input == nil { - return []schemas.ChatMessage{}, []schemas.ResponsesMessage{} - } - texts := request.EmbeddingRequest.Input.Texts - - if len(texts) == 0 && request.EmbeddingRequest.Input.Text != nil { - texts = []string{*request.EmbeddingRequest.Input.Text} - } - - contentBlocks := make([]schemas.ChatContentBlock, len(texts)) - for i, text := range texts { - // Create a per-iteration copy to avoid reusing the same memory address - t := text - contentBlocks[i] = schemas.ChatContentBlock{ - Type: schemas.ChatContentBlockTypeText, - Text: &t, - } - } - return []schemas.ChatMessage{ - { - Role: schemas.ChatMessageRoleUser, - Content: &schemas.ChatMessageContent{ - ContentBlocks: contentBlocks, - }, - }, - }, []schemas.ResponsesMessage{} - } if request.RerankRequest != nil { query := request.RerankRequest.Query return []schemas.ChatMessage{ @@ -559,6 +529,51 @@ func (p *LoggerPlugin) extractInputHistory(request *schemas.BifrostRequest) ([]s return []schemas.ChatMessage{}, []schemas.ResponsesMessage{} } +// extractEmbeddingInput returns the full EmbeddingContent slice from the request, +// preserving all part types (text, image, audio, file, video). Returns nil when +// Input is nil (large-payload passthrough). +func extractEmbeddingInput(request *schemas.BifrostRequest) []schemas.EmbeddingContent { + if request.EmbeddingRequest == nil || request.EmbeddingRequest.Input == nil { + return nil + } + return request.EmbeddingRequest.Input.Contents +} + +// redactEmbeddingMediaData returns a copy of contents with inline Data fields +// stripped from all media parts. Only called when the total data size exceeds +// the large-payload threshold. 
URL-based parts are preserved. +func redactEmbeddingMediaData(contents []schemas.EmbeddingContent) []schemas.EmbeddingContent { + stripped := make([]schemas.EmbeddingContent, len(contents)) + for i, content := range contents { + parts := make(schemas.EmbeddingContent, len(content)) + for j, part := range content { + if part.Image != nil { + cp := *part.Image + cp.Data = nil + part.Image = &cp + } + if part.Audio != nil { + cp := *part.Audio + cp.Data = nil + part.Audio = &cp + } + if part.File != nil { + cp := *part.File + cp.Data = nil + part.File = &cp + } + if part.Video != nil { + cp := *part.Video + cp.Data = nil + part.Video = &cp + } + parts[j] = part + } + stripped[i] = parts + } + return stripped +} + func extractRealtimeInputHistory(input []schemas.ResponsesMessage) []schemas.ChatMessage { messages := make([]schemas.ChatMessage, 0, len(input)) for _, item := range input { diff --git a/plugins/logging/writer.go b/plugins/logging/writer.go index fe34bc4a8f..54dec7be4f 100644 --- a/plugins/logging/writer.go +++ b/plugins/logging/writer.go @@ -235,6 +235,7 @@ func estimateLogEntrySize(log *logstore.Log) int { // baseline below. 
n := len(log.InputHistory) + len(log.ResponsesInputHistory) + + len(log.EmbeddingInput) + len(log.OutputMessage) + len(log.ResponsesOutput) + len(log.EmbeddingOutput) + @@ -284,6 +285,7 @@ func buildInitialLogEntry(pending *PendingLogData) *logstore.Log { CreatedAt: pending.Timestamp, InputHistoryParsed: pending.InitialData.InputHistory, ResponsesInputHistoryParsed: pending.InitialData.ResponsesInputHistory, + EmbeddingInputParsed: pending.InitialData.EmbeddingInput, ParamsParsed: pending.InitialData.Params, ToolsParsed: pending.InitialData.Tools, PassthroughRequestBody: pending.InitialData.PassthroughRequestBody, @@ -312,6 +314,7 @@ func buildCompleteLogEntryFromPending(pending *PendingLogData) *logstore.Log { // Set parsed fields for serialization via GORM hooks InputHistoryParsed: pending.InitialData.InputHistory, ResponsesInputHistoryParsed: pending.InitialData.ResponsesInputHistory, + EmbeddingInputParsed: pending.InitialData.EmbeddingInput, ParamsParsed: pending.InitialData.Params, ToolsParsed: pending.InitialData.Tools, SpeechInputParsed: pending.InitialData.SpeechInput, diff --git a/plugins/semanticcache/go.mod b/plugins/semanticcache/go.mod index b9de1b9be6..f27725a5b1 100644 --- a/plugins/semanticcache/go.mod +++ b/plugins/semanticcache/go.mod @@ -3,6 +3,7 @@ module github.com/maximhq/bifrost/plugins/semanticcache go 1.26.2 require ( + github.com/bytedance/sonic v1.15.0 github.com/cespare/xxhash/v2 v2.3.0 github.com/google/uuid v1.6.0 github.com/maximhq/bifrost/core v1.5.8 @@ -41,7 +42,6 @@ require ( github.com/bahlo/generic-list-go v0.2.0 // indirect github.com/buger/jsonparser v1.1.2 // indirect github.com/bytedance/gopkg v0.1.3 // indirect - github.com/bytedance/sonic v1.15.0 // indirect github.com/bytedance/sonic/loader v0.5.0 // indirect github.com/cloudwego/base64x v0.1.6 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect diff --git a/plugins/semanticcache/plugin_core_test.go 
b/plugins/semanticcache/plugin_core_test.go index 5bed26528d..1008afc5d8 100644 --- a/plugins/semanticcache/plugin_core_test.go +++ b/plugins/semanticcache/plugin_core_test.go @@ -222,28 +222,6 @@ func TestToFloat32Embedding(t *testing.T) { } } -func TestFlattenToFloat32Embedding(t *testing.T) { - input := [][]float64{ - {0.25, 0.5}, - {-0.75}, - {}, - {1.25, 2.5}, - } - - got := flattenToFloat32Embedding(input) - want := []float32{0.25, 0.5, -0.75, 1.25, 2.5} - - if len(got) != len(want) { - t.Fatalf("expected %d elements, got %d", len(want), len(got)) - } - - for i := range want { - if got[i] != want[i] { - t.Fatalf("expected element %d to be %v, got %v", i, want[i], got[i]) - } - } -} - // TestDirectVsSemanticSearch tests the difference between direct hash matching and semantic search func TestDirectVsSemanticSearch(t *testing.T) { setup := NewTestSetup(t) diff --git a/plugins/semanticcache/test_utils.go b/plugins/semanticcache/test_utils.go index e9b847c6dc..a5a5c034e7 100644 --- a/plugins/semanticcache/test_utils.go +++ b/plugins/semanticcache/test_utils.go @@ -548,9 +548,14 @@ func CreateEmbeddingRequest(texts []string) *schemas.BifrostEmbeddingRequest { return &schemas.BifrostEmbeddingRequest{ Provider: schemas.OpenAI, Model: "text-embedding-3-small", - Input: &schemas.EmbeddingInput{ - Texts: texts, - }, + Input: func() *schemas.EmbeddingInput { + contents := make([]schemas.EmbeddingContent, len(texts)) + for i, text := range texts { + t := text + contents[i] = schemas.EmbeddingContent{{Type: schemas.EmbeddingContentPartTypeText, Text: &t}} + } + return &schemas.EmbeddingInput{Contents: contents} + }(), } } diff --git a/plugins/semanticcache/utils.go b/plugins/semanticcache/utils.go index 125ae4670e..fc94bb6860 100644 --- a/plugins/semanticcache/utils.go +++ b/plugins/semanticcache/utils.go @@ -2,9 +2,11 @@ package semanticcache import ( "context" - "encoding/json" + "encoding/base64" + "encoding/binary" "fmt" "maps" + "math" "strings" "time" @@ -39,21 
+41,24 @@ func toFloat32Embedding(values []float64) []float32 { return embedding } -func flattenToFloat32Embedding(values [][]float64) []float32 { - total := 0 - for _, arr := range values { - total += len(arr) +// decodeBase64Embedding decodes a base64-encoded embedding of raw IEEE 754 float32 bytes (little-endian). +func decodeBase64Embedding(s string) ([]float32, error) { + b, err := base64.StdEncoding.DecodeString(s) + if err != nil { + b, err = base64.URLEncoding.DecodeString(s) + if err != nil { + return nil, fmt.Errorf("base64 decode failed: %w", err) + } } - if total == 0 { - return nil + if len(b)%4 != 0 { + return nil, fmt.Errorf("base64 embedding byte length %d is not a multiple of 4", len(b)) } - - embedding := make([]float32, 0, total) - for _, arr := range values { - embedding = append(embedding, toFloat32Embedding(arr)...) + vals := make([]float32, len(b)/4) + for i := range vals { + bits := binary.LittleEndian.Uint32(b[i*4 : i*4+4]) + vals[i] = math.Float32frombits(bits) } - - return embedding + return vals, nil } // generateEmbedding generates an embedding for the given text using the configured provider. 
@@ -63,7 +68,7 @@ func (plugin *Plugin) generateEmbedding(ctx *schemas.BifrostContext, text string Provider: plugin.config.Provider, Model: plugin.config.EmbeddingModel, Input: &schemas.EmbeddingInput{ - Text: &text, + Contents: []schemas.EmbeddingContent{{{Type: schemas.EmbeddingContentPartTypeText, Text: &text}}}, }, } @@ -93,17 +98,14 @@ func (plugin *Plugin) generateEmbedding(ctx *schemas.BifrostContext, text string inputTokens = response.Usage.TotalTokens } - if embedding.EmbeddingStr != nil { - // decode embedding.EmbeddingStr to []float32 - var vals []float32 - if err := json.Unmarshal([]byte(*embedding.EmbeddingStr), &vals); err != nil { - return nil, 0, fmt.Errorf("failed to parse string embedding: %w", err) + if len(embedding.Float) > 0 { + return toFloat32Embedding(embedding.Float), inputTokens, nil + } else if embedding.Base64 != nil { + vals, err := decodeBase64Embedding(*embedding.Base64) + if err != nil { + return nil, 0, fmt.Errorf("failed to decode base64 embedding: %w", err) } return vals, inputTokens, nil - } else if embedding.EmbeddingArray != nil { - return toFloat32Embedding(embedding.EmbeddingArray), inputTokens, nil - } else if len(embedding.Embedding2DArray) > 0 { - return flattenToFloat32Embedding(embedding.Embedding2DArray), inputTokens, nil } return nil, 0, fmt.Errorf("embedding data is not in expected format") @@ -438,18 +440,16 @@ func (plugin *Plugin) extractTextForEmbedding(req *schemas.BifrostRequest) (stri return "", "", fmt.Errorf("failed to marshal metadata for metadata hash: %w", err) } - texts := req.EmbeddingRequest.Input.Texts - - if len(texts) == 0 && req.EmbeddingRequest.Input.Text != nil { - texts = []string{*req.EmbeddingRequest.Input.Text} - } - - var text string - for _, t := range texts { - text += t + " " + var textParts []string + for _, content := range req.EmbeddingRequest.Input.Contents { + for _, part := range content { + if part.Type == schemas.EmbeddingContentPartTypeText && part.Text != nil { + textParts = 
append(textParts, normalizeText(*part.Text)) + } + } } - return strings.TrimSpace(text), metadataHash, nil + return strings.Join(textParts, " "), metadataHash, nil case req.TranscriptionRequest != nil: // Skip semantic caching for transcription requests @@ -757,35 +757,22 @@ func (plugin *Plugin) getNormalizedInputForCaching(req *schemas.BifrostRequest) case schemas.SpeechRequest, schemas.SpeechStreamRequest: return normalizeText(req.SpeechRequest.Input.Input) case schemas.EmbeddingRequest: - // Create a deep copy of the input to avoid mutating the original request - copiedInput := schemas.EmbeddingInput{} - if req.EmbeddingRequest.Input.Text != nil { - copiedText := *req.EmbeddingRequest.Input.Text - copiedInput.Text = &copiedText - } else if len(req.EmbeddingRequest.Input.Texts) > 0 { - copiedTexts := make([]string, len(req.EmbeddingRequest.Input.Texts)) - copy(copiedTexts, req.EmbeddingRequest.Input.Texts) - copiedInput.Texts = copiedTexts - } else if req.EmbeddingRequest.Input.Embedding != nil { - copiedEmbedding := make([]int, len(req.EmbeddingRequest.Input.Embedding)) - copy(copiedEmbedding, req.EmbeddingRequest.Input.Embedding) - copiedInput.Embedding = copiedEmbedding - } else if req.EmbeddingRequest.Input.Embeddings != nil { - copiedEmbeddings := make([][]int, len(req.EmbeddingRequest.Input.Embeddings)) - copy(copiedEmbeddings, req.EmbeddingRequest.Input.Embeddings) - copiedInput.Embeddings = copiedEmbeddings - } - if copiedInput.Text != nil { - normalizedText := normalizeText(*copiedInput.Text) - copiedInput.Text = &normalizedText - } else if len(copiedInput.Texts) > 0 { - normalizedTexts := make([]string, len(copiedInput.Texts)) - for i, text := range copiedInput.Texts { - normalizedTexts[i] = normalizeText(text) + // Deep copy Contents and normalize all text parts. 
+ src := req.EmbeddingRequest.Input.Contents + copiedContents := make([]schemas.EmbeddingContent, len(src)) + for i, content := range src { + copiedContent := make(schemas.EmbeddingContent, len(content)) + for j, part := range content { + copied := part + if part.Type == schemas.EmbeddingContentPartTypeText && part.Text != nil { + normalized := normalizeText(*part.Text) + copied.Text = &normalized + } + copiedContent[j] = copied } - copiedInput.Texts = normalizedTexts + copiedContents[i] = copiedContent } - return copiedInput + return schemas.EmbeddingInput{Contents: copiedContents} case schemas.TranscriptionRequest, schemas.TranscriptionStreamRequest: return req.TranscriptionRequest.Input case schemas.ImageGenerationRequest, schemas.ImageGenerationStreamRequest: diff --git a/tests/integrations/python/config.yml b/tests/integrations/python/config.yml index f1b01580b0..178cae779f 100644 --- a/tests/integrations/python/config.yml +++ b/tests/integrations/python/config.yml @@ -127,6 +127,7 @@ providers: speech: "gemini-2.5-flash-preview-tts" transcription: "gemini-2.5-flash" embeddings: "gemini-embedding-001" + multimodal_embeddings: "gemini-embedding-2-preview" image_generation: "gemini-2.5-flash-image" image_edit: "gemini-3-pro-image-preview" imagen: "imagen-4.0-generate-001" @@ -159,6 +160,7 @@ providers: file: "claude-sonnet-4-5" thinking: "gemini-2.5-pro" embeddings: "gemini-embedding-001" + multimodal_embeddings: "gemini-embedding-2-preview" image_generation: "imagen-4.0-generate-001" image_edit: "imagen-3.0-capability-001" imagen: "imagen-4.0-generate-001" @@ -196,6 +198,7 @@ providers: vision: "command-a-vision-07-2025" tools: "command-a-03-2025" embeddings: "embed-v4.0" + multimodal_embeddings: "embed-v4.0" streaming: "command-a-03-2025" count_tokens: "command-a-03-2025" alternatives: @@ -404,6 +407,7 @@ provider_scenarios: transcription: true transcription_streaming: true embeddings: true + multimodal_embeddings: true image_generation: true # Gemini image 
generation via responseModalities image_edit: true # Gemini image editing imagen: true # Imagen via :predict endpoint @@ -451,6 +455,7 @@ provider_scenarios: transcription: false transcription_streaming: false embeddings: true + multimodal_embeddings: true image_generation: true image_edit: true imagen: true # Imagen via :predict endpoint @@ -540,6 +545,7 @@ provider_scenarios: transcription: false transcription_streaming: false embeddings: true + multimodal_embeddings: true thinking: false prompt_caching: false citations: false @@ -601,6 +607,7 @@ scenario_capabilities: transcription: "transcription" transcription_streaming: "transcription" embeddings: "embeddings" + multimodal_embeddings: "multimodal_embeddings" image_generation: "image_generation" # Uses image_generation model image_edit: "image_edit" # Uses image_edit model imagen: "imagen" # Uses imagen model (Gemini/Vertex) diff --git a/tests/integrations/python/tests/test_cohere.py b/tests/integrations/python/tests/test_cohere.py new file mode 100644 index 0000000000..c1769019f2 --- /dev/null +++ b/tests/integrations/python/tests/test_cohere.py @@ -0,0 +1,261 @@ +""" +Integration tests for Cohere SDK with Bifrost. 
+ +Covers embedding scenarios only: + - Text embeddings (single, batch, input_type variations) + - Custom dimensions and embedding types + - Truncation + - Image embeddings + - Multimodal mixed inputs (text + image) +""" + +import httpx +import pytest +import cohere + +from .utils.common import ( + BASE64_IMAGE, + EMBEDDINGS_SINGLE_TEXT, + EMBEDDINGS_MULTIPLE_TEXTS, + Config, + get_api_key, +) +from .utils.config_loader import get_config, get_integration_url +from .utils.parametrize import format_provider_model, get_cross_provider_params_for_scenario + + +def get_provider_cohere_client(provider: str = "cohere") -> cohere.ClientV2: + """Create Cohere ClientV2 pointed at Bifrost with x-model-provider header.""" + api_key = get_api_key(provider) + base_url = get_integration_url("cohere") + config = get_config() + api_config = config.get_api_config() + timeout = api_config.get("timeout", 30) + + return cohere.ClientV2( + api_key=api_key, + base_url=base_url, + httpx_client=httpx.Client( + headers={"x-model-provider": provider}, + timeout=float(timeout), + ), + ) + + +@pytest.fixture +def test_config(): + return Config() + + +def assert_valid_cohere_embedding_response(response, expected_count: int, expected_dimensions: int | None = None): + """Assert a Cohere embed response contains valid float embeddings.""" + assert response is not None, "Response should not be None" + assert response.embeddings is not None, "Response should have embeddings" + assert response.embeddings.float is not None, "Response embeddings should have float vectors" + vectors = response.embeddings.float + assert len(vectors) == expected_count, ( + f"Expected {expected_count} embeddings, got {len(vectors)}" + ) + for i, vec in enumerate(vectors): + assert isinstance(vec, list), f"Embedding {i} should be a list" + assert len(vec) > 0, f"Embedding {i} should not be empty" + assert all(isinstance(v, float) for v in vec), f"Embedding {i} values should be floats" + if expected_dimensions is not None: + 
assert len(vec) == expected_dimensions, ( + f"Embedding {i}: expected {expected_dimensions} dims, got {len(vec)}" + ) + + +class TestCohereIntegration: + """Cohere SDK embedding tests via Bifrost.""" + + @pytest.mark.parametrize("provider,model", get_cross_provider_params_for_scenario("embeddings")) + def test_01_single_text_embedding(self, test_config, provider, model): + """Single string with input_type=search_document.""" + if provider == "_no_providers_" or model == "_no_model_": + pytest.skip("No providers configured for embeddings scenario") + + client = get_provider_cohere_client(provider) + response = client.embed( + model=format_provider_model(provider, model), + texts=[EMBEDDINGS_SINGLE_TEXT], + input_type="search_document", + embedding_types=["float"], + ) + + assert_valid_cohere_embedding_response(response, expected_count=1) + print(f"✓ Single text embedding: provider={provider} dims={len(response.embeddings.float[0])}") + + @pytest.mark.parametrize("provider,model", get_cross_provider_params_for_scenario("embeddings")) + def test_02_batch_text_embeddings(self, test_config, provider, model): + """Batch of 3 strings with input_type=search_document.""" + if provider == "_no_providers_" or model == "_no_model_": + pytest.skip("No providers configured for embeddings scenario") + + texts = EMBEDDINGS_MULTIPLE_TEXTS[:3] + client = get_provider_cohere_client(provider) + response = client.embed( + model=format_provider_model(provider, model), + texts=texts, + input_type="search_document", + embedding_types=["float"], + ) + + assert_valid_cohere_embedding_response(response, expected_count=3) + print(f"✓ Batch text embeddings: provider={provider} count=3 dims={len(response.embeddings.float[0])}") + + @pytest.mark.parametrize("provider,model", get_cross_provider_params_for_scenario("embeddings")) + def test_03_search_query_embedding(self, test_config, provider, model): + """Single string with input_type=search_query.""" + if provider == "_no_providers_" or model 
== "_no_model_": + pytest.skip("No providers configured for embeddings scenario") + + client = get_provider_cohere_client(provider) + response = client.embed( + model=format_provider_model(provider, model), + texts=["What is machine learning?"], + input_type="search_query", + embedding_types=["float"], + ) + + assert_valid_cohere_embedding_response(response, expected_count=1) + print(f"✓ Search query embedding: provider={provider}") + + @pytest.mark.parametrize("provider,model", get_cross_provider_params_for_scenario("embeddings")) + def test_04_classification_embedding(self, test_config, provider, model): + """Single string with input_type=classification.""" + if provider == "_no_providers_" or model == "_no_model_": + pytest.skip("No providers configured for embeddings scenario") + + client = get_provider_cohere_client(provider) + response = client.embed( + model=format_provider_model(provider, model), + texts=["This is a positive review."], + input_type="classification", + embedding_types=["float"], + ) + + assert_valid_cohere_embedding_response(response, expected_count=1) + print(f"✓ Classification embedding: provider={provider}") + + @pytest.mark.parametrize("provider,model", get_cross_provider_params_for_scenario("embeddings")) + def test_05_clustering_embedding(self, test_config, provider, model): + """Single string with input_type=clustering.""" + if provider == "_no_providers_" or model == "_no_model_": + pytest.skip("No providers configured for embeddings scenario") + + client = get_provider_cohere_client(provider) + response = client.embed( + model=format_provider_model(provider, model), + texts=["Renewable energy sources include solar and wind."], + input_type="clustering", + embedding_types=["float"], + ) + + assert_valid_cohere_embedding_response(response, expected_count=1) + print(f"✓ Clustering embedding: provider={provider}") + + @pytest.mark.parametrize("provider,model", get_cross_provider_params_for_scenario("embeddings")) + def 
test_06_custom_dimensions_embedding(self, test_config, provider, model): + """Single string with output_dimension=512 (embed-v4.0 only).""" + if provider == "_no_providers_" or model == "_no_model_": + pytest.skip("No providers configured for embeddings scenario") + + client = get_provider_cohere_client(provider) + response = client.embed( + model=format_provider_model(provider, model), + texts=[EMBEDDINGS_SINGLE_TEXT], + input_type="search_document", + embedding_types=["float"], + output_dimension=512, + ) + + assert_valid_cohere_embedding_response(response, expected_count=1, expected_dimensions=512) + print(f"✓ Custom dimensions embedding: provider={provider} dims=512") + + @pytest.mark.parametrize("provider,model", get_cross_provider_params_for_scenario("embeddings")) + def test_07_multiple_embedding_types(self, test_config, provider, model): + """Single string requesting float and int8 embedding types.""" + if provider == "_no_providers_" or model == "_no_model_": + pytest.skip("No providers configured for embeddings scenario") + + client = get_provider_cohere_client(provider) + response = client.embed( + model=format_provider_model(provider, model), + texts=[EMBEDDINGS_SINGLE_TEXT], + input_type="search_document", + embedding_types=["float", "int8"], + ) + + assert response is not None, "Response should not be None" + assert response.embeddings is not None, "Response should have embeddings" + assert response.embeddings.float is not None, "Response should include float embeddings" + assert response.embeddings.int8 is not None, "Response should include int8 embeddings" + assert len(response.embeddings.float) == 1 + assert len(response.embeddings.int8) == 1 + print(f"✓ Multiple embedding types: provider={provider}") + + @pytest.mark.parametrize("provider,model", get_cross_provider_params_for_scenario("embeddings")) + def test_08_truncation_embedding(self, test_config, provider, model): + """Long text with truncate=END to verify truncation is handled.""" + if 
provider == "_no_providers_" or model == "_no_model_": + pytest.skip("No providers configured for embeddings scenario") + + long_text = " ".join(EMBEDDINGS_MULTIPLE_TEXTS) * 10 + + client = get_provider_cohere_client(provider) + response = client.embed( + model=format_provider_model(provider, model), + texts=[long_text], + input_type="search_document", + embedding_types=["float"], + truncate="END", + ) + + assert_valid_cohere_embedding_response(response, expected_count=1) + print(f"✓ Truncation embedding: provider={provider}") + + @pytest.mark.parametrize("provider,model", get_cross_provider_params_for_scenario("multimodal_embeddings")) + def test_09_image_embedding(self, test_config, provider, model): + """Single image data URI with input_type=image.""" + if provider == "_no_providers_" or model == "_no_model_": + pytest.skip("No providers configured for multimodal_embeddings scenario") + + image_data_uri = f"data:image/png;base64,{BASE64_IMAGE}" + + client = get_provider_cohere_client(provider) + response = client.embed( + model=format_provider_model(provider, model), + images=[image_data_uri], + input_type="image", + embedding_types=["float"], + ) + + assert_valid_cohere_embedding_response(response, expected_count=1) + print(f"✓ Image embedding: provider={provider} dims={len(response.embeddings.float[0])}") + + @pytest.mark.parametrize("provider,model", get_cross_provider_params_for_scenario("multimodal_embeddings")) + def test_10_multimodal_mixed_inputs_embedding(self, test_config, provider, model): + """Mixed text + image content via inputs field.""" + if provider == "_no_providers_" or model == "_no_model_": + pytest.skip("No providers configured for multimodal_embeddings scenario") + + image_data_uri = f"data:image/png;base64,{BASE64_IMAGE}" + + mixed_input = cohere.EmbedInput( + content=[ + {"type": "text", "text": "A colorful geometric pattern"}, + {"type": "image_url", "image_url": {"url": image_data_uri}}, + ] + ) + + client = 
get_provider_cohere_client(provider) + response = client.embed( + model=format_provider_model(provider, model), + inputs=[mixed_input], + input_type="search_document", + embedding_types=["float"], + ) + + assert_valid_cohere_embedding_response(response, expected_count=1) + print(f"✓ Multimodal mixed inputs embedding: provider={provider} dims={len(response.embeddings.float[0])}") diff --git a/tests/integrations/python/tests/test_google.py b/tests/integrations/python/tests/test_google.py index ec3864b51d..379c40f19a 100644 --- a/tests/integrations/python/tests/test_google.py +++ b/tests/integrations/python/tests/test_google.py @@ -22,6 +22,8 @@ 12. Error handling 13. Streaming chat 14. Single text embedding +45. Multimodal embedding - text + image content (Gemini/Vertex gemini-embedding-2-preview) +46. Multimodal embedding - batch contents (Gemini/Vertex gemini-embedding-2-preview) 15. List models 16. Audio transcription 17. Audio transcription with parameters @@ -54,6 +56,7 @@ 44. Context caching (Gemini Caches API) - create, list, get, update, delete, generate with cache """ +import base64 import io import json import os @@ -3077,6 +3080,103 @@ def test_43_google_search_grounding_streaming(self, test_config): print("✓ Google Search grounding test (streaming) passed!") + # ========================================================================= + # MULTIMODAL EMBEDDING TEST CASES (gemini-embedding-2-preview) + # ========================================================================= + + @pytest.mark.parametrize("provider,model", get_cross_provider_params_for_scenario("multimodal_embeddings")) + def test_45_multimodal_embedding_text_and_image(self, test_config, provider, model): + """Test Case 45: Single multimodal content embedding - text + image (gemini-embedding-2-preview). + + Sends one types.Content with a text part and an inline-base64 image part. + Expects a single embedding vector back with the requested dimensionality. 
+ """ + if provider == "_no_providers_" or model == "_no_model_": + pytest.skip("No providers configured for multimodal_embeddings scenario") + + client = get_provider_google_client(provider) + image_bytes = base64.b64decode(BASE64_IMAGE) + + response = client.models.embed_content( + model=format_provider_model(provider, model), + contents=types.Content( + parts=[ + types.Part(text="A colorful geometric pattern"), + types.Part.from_bytes(data=image_bytes, mime_type="image/png"), + ] + ), + config=types.EmbedContentConfig(output_dimensionality=512), + ) + + assert response is not None, "Multimodal embedding response should not be None" + assert hasattr(response, "embeddings"), "Response should have 'embeddings' attribute" + assert len(response.embeddings) == 1, ( + f"Single content item should produce exactly one embedding, got {len(response.embeddings)}" + ) + values = response.embeddings[0].values + assert isinstance(values, list), f"Embedding values should be a list, got {type(values)}" + assert len(values) == 512, f"Expected 512-dimensional embedding, got {len(values)}" + assert all(isinstance(v, float) for v in values), "All embedding values should be floats" + + print(f"✓ Multimodal (text+image) embedding: provider={provider} dims={len(values)}") + + @pytest.mark.parametrize("provider,model", get_cross_provider_params_for_scenario("multimodal_embeddings")) + def test_46_multimodal_embedding_batch(self, test_config, provider, model): + """Test Case 46: Batch multimodal embedding - multiple contents (gemini-embedding-2-preview). + + Sends three separate contents: text-only, image-only, and text+image. + Maps to batchEmbedContents on the Gemini side. + Expects one embedding vector per content item. 
+ """ + if provider == "_no_providers_" or model == "_no_model_": + pytest.skip("No providers configured for multimodal_embeddings scenario") + + client = get_provider_google_client(provider) + image_bytes = base64.b64decode(BASE64_IMAGE) + dimensions = 512 + + contents = [ + # Text-only content + types.Content(parts=[types.Part(text="Artificial intelligence research paper")]), + # Image-only content + types.Content(parts=[types.Part.from_bytes(data=image_bytes, mime_type="image/png")]), + # Mixed text + image content + types.Content( + parts=[ + types.Part(text="A geometric shape shown in the image"), + types.Part.from_bytes(data=image_bytes, mime_type="image/png"), + ] + ), + ] + + response = client.models.embed_content( + model=format_provider_model(provider, model), + contents=contents, + config=types.EmbedContentConfig(output_dimensionality=dimensions), + ) + + assert response is not None, "Batch multimodal embedding response should not be None" + assert hasattr(response, "embeddings"), "Response should have 'embeddings' attribute" + assert len(response.embeddings) == 3, ( + f"Expected 3 embeddings for 3 content items, got {len(response.embeddings)}" + ) + for i, embedding in enumerate(response.embeddings): + values = embedding.values + assert isinstance(values, list), ( + f"Content item {i}: embedding values should be a list, got {type(values)}" + ) + assert len(values) == dimensions, ( + f"Content item {i}: expected {dimensions}-dimensional embedding, got {len(values)}" + ) + assert all(isinstance(v, float) for v in values), ( + f"Content item {i}: all values should be floats" + ) + + print( + f"✓ Batch multimodal embedding: provider={provider} " + f"count={len(response.embeddings)} dims={len(response.embeddings[0].values)}" + ) + # ========================================================================= # GEMINI VIDEO GENERATION TEST CASES # ========================================================================= diff --git 
a/transports/bifrost-http/handlers/inference.go b/transports/bifrost-http/handlers/inference.go index 712449e4a9..aff7520e75 100644 --- a/transports/bifrost-http/handlers/inference.go +++ b/transports/bifrost-http/handlers/inference.go @@ -200,6 +200,24 @@ var embeddingParamsKnownFields = map[string]bool{ "fallbacks": true, "encoding_format": true, "dimensions": true, + "task_type": true, + "title": true, + "auto_truncate": true, + "truncate": true, + "max_tokens": true, +} + +var batchEmbeddingParamsKnownFields = map[string]bool{ + "model": true, + "items": true, + "fallbacks": true, + "encoding_format": true, + "dimensions": true, + "task_type": true, + "title": true, + "auto_truncate": true, + "truncate": true, + "max_tokens": true, } var rerankParamsKnownFields = map[string]bool{ @@ -516,6 +534,15 @@ type EmbeddingRequest struct { *schemas.EmbeddingParameters } +// BatchEmbeddingHTTPRequest is a bifrost batch embedding request. +// Top-level EmbeddingParameters serve as the default for all items; +// each item may carry its own Params override. 
+type BatchEmbeddingHTTPRequest struct { + Items []schemas.BifrostEmbeddingBatchItem `json:"items"` + BifrostParams + *schemas.EmbeddingParameters +} + // RerankRequest is a bifrost rerank request type RerankRequest struct { Query string `json:"query"` @@ -661,6 +688,7 @@ var PathToTypeMapping = map[string]schemas.RequestType{ "/v1/chat/completions": schemas.ChatCompletionRequest, "/v1/responses": schemas.ResponsesRequest, "/v1/embeddings": schemas.EmbeddingRequest, + "/v1/embeddings/batch": schemas.BatchEmbeddingRequest, "/v1/rerank": schemas.RerankRequest, "/v1/ocr": schemas.OCRRequest, "/v1/audio/speech": schemas.SpeechRequest, @@ -706,6 +734,7 @@ func (h *CompletionHandler) RegisterRoutes(r *router.Router, middlewares ...sche r.POST("/v1/chat/completions", lib.ChainMiddlewares(h.chatCompletion, baseMiddlewares...)) r.POST("/v1/responses", lib.ChainMiddlewares(h.responses, baseMiddlewares...)) r.POST("/v1/embeddings", lib.ChainMiddlewares(h.embeddings, baseMiddlewares...)) + r.POST("/v1/embeddings/batch", lib.ChainMiddlewares(h.batchEmbeddings, baseMiddlewares...)) r.POST("/v1/rerank", lib.ChainMiddlewares(h.rerank, baseMiddlewares...)) r.POST("/v1/ocr", lib.ChainMiddlewares(h.ocr, baseMiddlewares...)) r.POST("/v1/audio/speech", lib.ChainMiddlewares(h.speech, baseMiddlewares...)) @@ -1089,9 +1118,12 @@ func prepareEmbeddingRequest(ctx *fasthttp.RequestCtx, config *lib.Config) (*Emb if err != nil { return nil, nil, err } - if req.Input == nil || (req.Input.Text == nil && req.Input.Texts == nil && req.Input.Embedding == nil && req.Input.Embeddings == nil) { + if req.Input == nil || len(req.Input.Contents) == 0 { return nil, nil, fmt.Errorf("input is required for embeddings") } + if err := req.Input.Validate(); err != nil { + return nil, nil, err + } if req.EmbeddingParameters == nil { req.EmbeddingParameters = &schemas.EmbeddingParameters{} } @@ -1137,6 +1169,63 @@ func (h *CompletionHandler) embeddings(ctx *fasthttp.RequestCtx) { SendJSON(ctx, resp) } +// 
prepareBatchEmbeddingRequest prepares a BifrostBatchEmbeddingRequest from the HTTP request body +func prepareBatchEmbeddingRequest(ctx *fasthttp.RequestCtx, config *lib.Config) (*BatchEmbeddingHTTPRequest, *schemas.BifrostBatchEmbeddingRequest, error) { + req, base, err := prepareRequest[BatchEmbeddingHTTPRequest](ctx, config, batchEmbeddingParamsKnownFields) + if err != nil { + return nil, nil, err + } + if len(req.Items) == 0 { + return nil, nil, fmt.Errorf("items are required for batch embedding") + } + if req.EmbeddingParameters == nil { + req.EmbeddingParameters = &schemas.EmbeddingParameters{} + } + req.EmbeddingParameters.ExtraParams = base.ExtraParams + bifrostReq := &schemas.BifrostBatchEmbeddingRequest{ + Provider: base.Provider, + Model: base.ModelName, + Params: req.EmbeddingParameters, + Items: req.Items, + Fallbacks: base.Fallbacks, + } + if err := bifrostReq.Validate(); err != nil { + return nil, nil, err + } + return req, bifrostReq, nil +} + +// batchEmbeddings handles POST /v1/embeddings/batch - Process batch embedding requests +func (h *CompletionHandler) batchEmbeddings(ctx *fasthttp.RequestCtx) { + _, bifrostBatchReq, err := prepareBatchEmbeddingRequest(ctx, h.config) + if err != nil { + SendError(ctx, fasthttp.StatusBadRequest, err.Error()) + return + } + + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.config) + defer cancel() + if bifrostCtx == nil { + SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") + return + } + + resp, bifrostErr := h.client.BatchEmbeddingRequest(bifrostCtx, bifrostBatchReq) + if bifrostErr != nil { + forwardProviderHeadersFromContext(ctx, bifrostCtx) + SendBifrostError(ctx, bifrostErr) + return + } + + if resp != nil && resp.ExtraFields.ProviderResponseHeaders != nil { + forwardProviderHeaders(ctx, resp.ExtraFields.ProviderResponseHeaders) + } + if streamLargeResponseIfActive(ctx, bifrostCtx) { + return + } + SendJSON(ctx, resp) +} + // prepareRerankRequest prepares a BifrostRerankRequest 
from the HTTP request body func prepareRerankRequest(ctx *fasthttp.RequestCtx, config *lib.Config) (*RerankRequest, *schemas.BifrostRerankRequest, error) { req, base, err := prepareRequest[RerankRequest](ctx, config, rerankParamsKnownFields) diff --git a/transports/bifrost-http/integrations/cohere.go b/transports/bifrost-http/integrations/cohere.go index cf6b7ceaca..045057d660 100644 --- a/transports/bifrost-http/integrations/cohere.go +++ b/transports/bifrost-http/integrations/cohere.go @@ -163,7 +163,7 @@ func CreateCohereRouteConfigs(pathPrefix string) []RouteConfig { return resp.ExtraFields.RawResponse, nil } } - return resp, nil + return cohere.ToCohereEmbeddingResponse(resp), nil }, ErrorConverter: func(ctx *schemas.BifrostContext, err *schemas.BifrostError) interface{} { return err diff --git a/transports/bifrost-http/integrations/genai.go b/transports/bifrost-http/integrations/genai.go index 46cb6fe659..8de58252a1 100644 --- a/transports/bifrost-http/integrations/genai.go +++ b/transports/bifrost-http/integrations/genai.go @@ -24,6 +24,8 @@ import ( const isGeminiEmbedContentRequestContextKey schemas.BifrostContextKey = "bifrost-is-gemini-embed-content-request" +const isGeminiBatchEmbedContentsRequestContextKey schemas.BifrostContextKey = "bifrost-is-gemini-batch-embed-contents-request" + const isGeminiVideoGenerationRequestContextKey schemas.BifrostContextKey = "bifrost-is-gemini-video-generation-request" const isGeminiBatchCreateRequestContextKey schemas.BifrostContextKey = "bifrost-is-gemini-batch-create-request" @@ -46,6 +48,8 @@ func genAIModelGetter(ctx *fasthttp.RequestCtx, req interface{}) (string, error) return r.Model, nil case *gemini.GeminiEmbeddingRequest: return r.Model, nil + case *gemini.GeminiBatchEmbeddingRequest: + return r.Model, nil case *gemini.GeminiVideoGenerationRequest: return r.Model, nil case *gemini.GeminiBatchCreateRequest: @@ -106,13 +110,17 @@ func CreateGenAIRouteConfigs(pathPrefix string) []RouteConfig { return requestType 
}, GetRequestTypeInstance: func(ctx context.Context) interface{} { - if requestType, ok := ctx.Value(schemas.BifrostContextKeyHTTPRequestType).(schemas.RequestType); ok && requestType == schemas.EmbeddingRequest && ctx.Value(isGeminiEmbedContentRequestContextKey) != nil { + requestType, _ := ctx.Value(schemas.BifrostContextKeyHTTPRequestType).(schemas.RequestType) + if requestType == schemas.EmbeddingRequest && ctx.Value(isGeminiEmbedContentRequestContextKey) != nil { return &gemini.GeminiEmbeddingRequest{} } - if requestType, ok := ctx.Value(schemas.BifrostContextKeyHTTPRequestType).(schemas.RequestType); ok && requestType == schemas.VideoGenerationRequest && ctx.Value(isGeminiVideoGenerationRequestContextKey) != nil { + if requestType == schemas.BatchEmbeddingRequest && ctx.Value(isGeminiBatchEmbedContentsRequestContextKey) != nil { + return &gemini.GeminiBatchEmbeddingRequest{} + } + if requestType == schemas.VideoGenerationRequest && ctx.Value(isGeminiVideoGenerationRequestContextKey) != nil { return &gemini.GeminiVideoGenerationRequest{} } - if requestType, ok := ctx.Value(schemas.BifrostContextKeyHTTPRequestType).(schemas.RequestType); ok && requestType == schemas.BatchCreateRequest && ctx.Value(isGeminiBatchCreateRequestContextKey) != nil { + if requestType == schemas.BatchCreateRequest && ctx.Value(isGeminiBatchCreateRequestContextKey) != nil { return &gemini.GeminiBatchCreateRequest{} } return &gemini.GeminiGenerationRequest{} @@ -159,6 +167,28 @@ func CreateGenAIRouteConfigs(pathPrefix string) []RouteConfig { return &schemas.BifrostRequest{ EmbeddingRequest: req.ToBifrostEmbeddingRequest(ctx), }, nil + } else if geminiReq, ok := req.(*gemini.GeminiBatchEmbeddingRequest); ok { + bifrostBatchReq, err := geminiReq.ToBifrostBatchEmbeddingRequest(ctx) + if err != nil { + return nil, err + } + // Only Gemini supports BatchEmbedding + if bifrostBatchReq.Provider == schemas.Gemini { + return &schemas.BifrostRequest{BatchEmbeddingRequest: bifrostBatchReq}, nil + } 
+ contents := make([]schemas.EmbeddingContent, 0, len(bifrostBatchReq.Items)) + for _, item := range bifrostBatchReq.Items { + contents = append(contents, item.Content) + } + return &schemas.BifrostRequest{ + EmbeddingRequest: &schemas.BifrostEmbeddingRequest{ + Provider: bifrostBatchReq.Provider, + Model: bifrostBatchReq.Model, + Input: &schemas.EmbeddingInput{Contents: contents}, + Params: bifrostBatchReq.Params, + Fallbacks: bifrostBatchReq.Fallbacks, + }, + }, nil } else if geminiReq, ok := req.(*gemini.GeminiVideoGenerationRequest); ok { // convert to bifrost video generation request bifrostReq, err := geminiReq.ToBifrostVideoGenerationRequest(ctx) @@ -1231,6 +1261,11 @@ func extractAndSetModelAndRequestType(ctx *fasthttp.RequestCtx, bifrostCtx *sche r.Model = modelStr } return nil + case *gemini.GeminiBatchEmbeddingRequest: + if modelStr != "" { + r.Model = modelStr + } + return nil case *gemini.GeminiVideoGenerationRequest: if modelStr != "" { r.Model = modelStr @@ -1274,6 +1309,11 @@ func extractModelAndRequestType(ctx *fasthttp.RequestCtx) (string, schemas.Reque } if strings.HasSuffix(modelStr, ":embedContent") { ctx.SetUserValue(isGeminiEmbedContentRequestContextKey, true) + return modelStr, schemas.EmbeddingRequest + } + if strings.HasSuffix(modelStr, ":batchEmbedContents") { + ctx.SetUserValue(isGeminiBatchEmbedContentsRequestContextKey, true) + return modelStr, schemas.BatchEmbeddingRequest } if isEmbedding { return modelStr, schemas.EmbeddingRequest diff --git a/transports/bifrost-http/integrations/openai.go b/transports/bifrost-http/integrations/openai.go index 5b09f9395c..83ecbf6b07 100644 --- a/transports/bifrost-http/integrations/openai.go +++ b/transports/bifrost-http/integrations/openai.go @@ -458,7 +458,7 @@ func CreateOpenAIRouteConfigs(pathPrefix string, handlerStore lib.HandlerStore) return resp.ExtraFields.RawResponse, nil } } - return resp, nil + return openai.ToOpenAIEmbeddingResponse(resp), nil }, SpeechResponseConverter: func(ctx 
*schemas.BifrostContext, resp *schemas.BifrostSpeechResponse) (interface{}, error) { if resp.ExtraFields.Provider == schemas.OpenAI { @@ -860,7 +860,7 @@ func CreateOpenAIRouteConfigs(pathPrefix string, handlerStore lib.HandlerStore) return resp.ExtraFields.RawResponse, nil } } - return resp, nil + return openai.ToOpenAIEmbeddingResponse(resp), nil }, ErrorConverter: func(ctx *schemas.BifrostContext, err *schemas.BifrostError) interface{} { return err diff --git a/transports/bifrost-http/integrations/router.go b/transports/bifrost-http/integrations/router.go index f4bc8f685a..86af916ec6 100644 --- a/transports/bifrost-http/integrations/router.go +++ b/transports/bifrost-http/integrations/router.go @@ -1082,6 +1082,24 @@ func (g *GenericRouter) handleNonStreamingRequest(ctx *fasthttp.RequestCtx, conf providerResponseHeaders = embeddingResponse.ExtraFields.ProviderResponseHeaders // Convert Bifrost response to integration-specific format and send response, err = config.EmbeddingResponseConverter(bifrostCtx, embeddingResponse) + case bifrostReq.BatchEmbeddingRequest != nil: + embeddingResponse, bifrostErr := g.client.BatchEmbeddingRequest(bifrostCtx, bifrostReq.BatchEmbeddingRequest) + if bifrostErr != nil { + g.sendError(ctx, bifrostCtx, config.ErrorConverter, bifrostErr) + return + } + if config.PostCallback != nil { + if err := config.PostCallback(ctx, req, embeddingResponse); err != nil { + g.sendError(ctx, bifrostCtx, config.ErrorConverter, newBifrostError(err, "failed to execute post-request callback")) + return + } + } + if embeddingResponse == nil { + g.sendError(ctx, bifrostCtx, config.ErrorConverter, newBifrostError(nil, "Bifrost response is nil after post-request callback")) + return + } + providerResponseHeaders = embeddingResponse.ExtraFields.ProviderResponseHeaders + response, err = config.EmbeddingResponseConverter(bifrostCtx, embeddingResponse) case bifrostReq.RerankRequest != nil: rerankResponse, bifrostErr := g.client.RerankRequest(bifrostCtx, 
bifrostReq.RerankRequest) if bifrostErr != nil { diff --git a/ui/app/workspace/logs/sheets/logDetailView.tsx b/ui/app/workspace/logs/sheets/logDetailView.tsx index 22ed484adf..92faa15770 100644 --- a/ui/app/workspace/logs/sheets/logDetailView.tsx +++ b/ui/app/workspace/logs/sheets/logDetailView.tsx @@ -45,7 +45,12 @@ import { RoutingEngineUsedLabels, Status } from "@/lib/constants/logs"; -import { ContentBlock, LogEntry, ResponsesMessage } from "@/lib/types/logs"; +import { + ContentBlock, + EmbeddingContent, + LogEntry, + ResponsesMessage, +} from "@/lib/types/logs"; import { cn } from "@/lib/utils"; import { downloadAsJson } from "@/lib/utils/browser-download"; import { Link } from "@tanstack/react-router"; @@ -621,6 +626,38 @@ function MessageRow({ ); } +function EmbeddingInputView({ contents }: { contents: EmbeddingContent[] }) { + const label = + contents.length === 1 ? "Input" : `Input (${contents.length} documents)`; + const json = JSON.stringify( + contents.length === 1 ? contents[0] : contents, + null, + 2, + ); + + return ( + <> +
{label}
+ json} collapsedHeight={150}> + + + + ); +} + interface LogDetailViewProps { log: LogEntry | null; resolvedSelectedPromptName?: string; // Current prompt name from prompt-repo when `selected_prompt_id` is set; falls back to stored log name @@ -2252,7 +2289,8 @@ export function LogDetailView({ {log.is_large_payload_request && !log.input_history?.length && - !log.responses_input_history?.length && ( + !log.responses_input_history?.length && + !log.embedding_input?.length && (
Large payload request — input content was streamed directly to the provider and is not available for display. @@ -2292,6 +2330,13 @@ export function LogDetailView({ />
)} + {log.object === "embedding" && + log.embedding_input && + log.embedding_input.length > 0 && ( +
+ +
+ )} {log.status !== "processing" && log.rerank_output && !log.error_details?.error.message && ( @@ -2725,11 +2770,18 @@ const copyRequestBody = async ( if (prompt) { requestBody.prompt = prompt; } + } else if ( + isEmbedding && + log.embedding_input && + log.embedding_input.length > 0 + ) { + requestBody.input = log.embedding_input; } else if ( isEmbedding && log.input_history && log.input_history.length > 0 ) { + // Fallback for logs created before embedding_input was introduced. const texts: string[] = []; for (const message of log.input_history) { const messageTexts = extractTextsFromMessage(message); diff --git a/ui/lib/types/logs.ts b/ui/lib/types/logs.ts index a49537c484..30dea670bc 100644 --- a/ui/lib/types/logs.ts +++ b/ui/lib/types/logs.ts @@ -198,6 +198,25 @@ export interface BifrostEmbedding { embedding: string | number[] | number[][]; } +export interface EmbeddingMediaPart { + data?: string; + url?: string; + mime_type?: string; + filename?: string; +} + +export interface EmbeddingContentPart { + type: "text" | "image" | "audio" | "file" | "video" | "tokens"; + text?: string; + image?: EmbeddingMediaPart; + audio?: EmbeddingMediaPart; + file?: EmbeddingMediaPart; + video?: EmbeddingMediaPart; + tokens?: number[]; +} + +export type EmbeddingContent = EmbeddingContentPart[]; + export interface RerankDocument { text: string; id?: string; @@ -524,6 +543,7 @@ export interface LogEntry { content_summary?: string; output_message?: ChatMessage; responses_output?: ResponsesMessage[]; + embedding_input?: EmbeddingContent[]; embedding_output?: BifrostEmbedding[]; rerank_output?: RerankResult[]; ocr_input?: OCRDocument;