Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/139812.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 139812
summary: Add VoyageAI multimodal embeddings support with new embedding task type
area: Machine Learning
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@
"description": "The task type",
"options": [
"text_embedding",
"rerank"
"rerank",
"embedding"
]
},
"voyageai_inference_id": {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,8 @@
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerSchemas;
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIServiceSettings;
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsTaskSettings;
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.text.VoyageAIEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.text.VoyageAIEmbeddingsTaskSettings;
import org.elasticsearch.xpack.inference.services.voyageai.rerank.VoyageAIRerankServiceSettings;
import org.elasticsearch.xpack.inference.services.voyageai.rerank.VoyageAIRerankTaskSettings;

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.services.voyageai;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
import org.elasticsearch.xpack.inference.external.http.sender.BaseRequestManager;
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.contextual.VoyageAIContextualEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.voyageai.request.VoyageAIContextualizedEmbeddingsRequest;
import org.elasticsearch.xpack.inference.services.voyageai.response.VoyageAIContextualizedEmbeddingsResponseEntity;
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
import org.elasticsearch.xpack.inference.external.http.sender.ExecutableInferenceRequest;

import java.util.List;
import java.util.Objects;
import java.util.function.Supplier;

/**
 * Request manager that builds and dispatches VoyageAI contextualized embeddings requests.
 * <p>
 * Rate limiting is configured from the model's inference entity id, the grouping computed by
 * {@link VoyageAIRequestManager.RateLimitGrouping#of(VoyageAIModel)} and the model's rate limit settings.
 */
public class VoyageAIContextualEmbeddingsRequestManager extends BaseRequestManager {
    private static final Logger logger = LogManager.getLogger(VoyageAIContextualEmbeddingsRequestManager.class);

    /** Shared handler that parses responses with {@link VoyageAIContextualizedEmbeddingsResponseEntity#fromResponse}. */
    private static final ResponseHandler HANDLER = new VoyageAIResponseHandler(
        "voyageai contextual embedding",
        VoyageAIContextualizedEmbeddingsResponseEntity::fromResponse
    );

    private final VoyageAIContextualEmbeddingsModel model;

    private VoyageAIContextualEmbeddingsRequestManager(VoyageAIContextualEmbeddingsModel model, ThreadPool threadPool) {
        super(threadPool, model.getInferenceEntityId(), VoyageAIRequestManager.RateLimitGrouping.of(model), model.rateLimitSettings());
        this.model = Objects.requireNonNull(model);
    }

    /**
     * Creates a request manager for the given contextual embeddings model.
     *
     * @param model the model describing the endpoint to call; must not be null
     * @param threadPool the thread pool handed to the base request manager; must not be null
     * @return a new request manager
     * @throws NullPointerException if either argument is null
     */
    public static VoyageAIContextualEmbeddingsRequestManager of(VoyageAIContextualEmbeddingsModel model, ThreadPool threadPool) {
        Objects.requireNonNull(model);
        Objects.requireNonNull(threadPool);
        return new VoyageAIContextualEmbeddingsRequestManager(model, threadPool);
    }

    /**
     * Builds a {@link VoyageAIContextualizedEmbeddingsRequest} from the supplied inputs and hands it to the sender.
     *
     * @param inferenceInputs the inputs, cast to {@link EmbeddingsInput}
     * @param requestSender the sender used to perform the request
     * @param hasRequestCompletedFunction supplier passed through to the executable request
     * @param listener receives the inference results or a failure
     */
    @Override
    public void execute(
        InferenceInputs inferenceInputs,
        RequestSender requestSender,
        Supplier<Boolean> hasRequestCompletedFunction,
        ActionListener<InferenceServiceResults> listener
    ) {
        EmbeddingsInput input = inferenceInputs.castTo(EmbeddingsInput.class);

        // The request takes a list of lists; every text of this call goes into one inner list:
        // ["text1", "text2", "text3"] becomes [["text1", "text2", "text3"]]
        List<String> texts = input.getTextInputs();
        List<List<String>> singleGroup = List.of(texts);

        VoyageAIContextualizedEmbeddingsRequest contextualRequest = new VoyageAIContextualizedEmbeddingsRequest(
            singleGroup,
            input.getInputType(),
            model
        );

        ExecutableInferenceRequest executable = new ExecutableInferenceRequest(
            requestSender,
            logger,
            contextualRequest,
            HANDLER,
            hasRequestCompletedFunction,
            listener
        );
        execute(executable);
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.services.voyageai;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.text.VoyageAIEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.voyageai.request.VoyageAIEmbeddingsRequest;
import org.elasticsearch.xpack.inference.services.voyageai.response.VoyageAIEmbeddingsResponseEntity;
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
import org.elasticsearch.xpack.inference.external.http.sender.ExecutableInferenceRequest;

import java.util.List;
import java.util.Objects;
import java.util.function.Supplier;

/**
 * Request manager that builds and dispatches VoyageAI text embeddings requests.
 */
public class VoyageAIEmbeddingsRequestManager extends VoyageAIRequestManager {
    private static final Logger logger = LogManager.getLogger(VoyageAIEmbeddingsRequestManager.class);

    /** Shared handler that parses responses with {@link VoyageAIEmbeddingsResponseEntity#fromResponse}. */
    private static final ResponseHandler HANDLER = new VoyageAIResponseHandler(
        "voyageai text embedding",
        VoyageAIEmbeddingsResponseEntity::fromResponse
    );

    private final VoyageAIEmbeddingsModel model;

    private VoyageAIEmbeddingsRequestManager(VoyageAIEmbeddingsModel model, ThreadPool threadPool) {
        super(threadPool, model);
        this.model = Objects.requireNonNull(model);
    }

    /**
     * Creates a request manager for the given text embeddings model.
     *
     * @param model the model describing the endpoint to call; must not be null
     * @param threadPool the thread pool handed to the base request manager; must not be null
     * @return a new request manager
     * @throws NullPointerException if either argument is null
     */
    public static VoyageAIEmbeddingsRequestManager of(VoyageAIEmbeddingsModel model, ThreadPool threadPool) {
        Objects.requireNonNull(model);
        Objects.requireNonNull(threadPool);
        return new VoyageAIEmbeddingsRequestManager(model, threadPool);
    }

    /**
     * Builds a {@link VoyageAIEmbeddingsRequest} from the text inputs and hands it to the sender.
     *
     * @param inferenceInputs the inputs, cast to {@link EmbeddingsInput}
     * @param requestSender the sender used to perform the request
     * @param hasRequestCompletedFunction supplier passed through to the executable request
     * @param listener receives the inference results or a failure
     */
    @Override
    public void execute(
        InferenceInputs inferenceInputs,
        RequestSender requestSender,
        Supplier<Boolean> hasRequestCompletedFunction,
        ActionListener<InferenceServiceResults> listener
    ) {
        EmbeddingsInput input = inferenceInputs.castTo(EmbeddingsInput.class);
        List<String> texts = input.getTextInputs();
        VoyageAIEmbeddingsRequest embeddingsRequest = new VoyageAIEmbeddingsRequest(texts, input.getInputType(), model);

        ExecutableInferenceRequest executable = new ExecutableInferenceRequest(
            requestSender,
            logger,
            embeddingsRequest,
            HANDLER,
            hasRequestCompletedFunction,
            listener
        );
        execute(executable);
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,10 @@
import org.elasticsearch.inference.ModelSecrets;
import org.elasticsearch.inference.ServiceSettings;
import org.elasticsearch.inference.TaskSettings;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.services.RateLimitGroupingModel;
import org.elasticsearch.xpack.inference.services.ServiceUtils;
import org.elasticsearch.xpack.inference.services.settings.ApiKeySecrets;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
import org.elasticsearch.xpack.inference.services.voyageai.action.VoyageAIActionVisitor;

import java.net.URI;
import java.util.Collections;
Expand All @@ -34,7 +32,9 @@ public abstract class VoyageAIModel extends RateLimitGroupingModel {
Map<String, String> tempMap = new HashMap<>();
tempMap.put("voyage-3.5", "embed_medium");
tempMap.put("voyage-3.5-lite", "embed_small");
tempMap.put("voyage-context-3", "embed_context");
tempMap.put("voyage-multimodal-3", "embed_multimodal");
tempMap.put("voyage-multimodal-3.5", "embed_multimodal");
tempMap.put("voyage-3-large", "embed_large");
tempMap.put("voyage-code-3", "embed_large");
tempMap.put("voyage-3", "embed_medium");
Expand Down Expand Up @@ -101,5 +101,4 @@ public URI uri() {
return uri;
}

public abstract ExecutableAction accept(VoyageAIActionVisitor creator, Map<String, Object> taskSettings);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.services.voyageai;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.InferenceStringGroup;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
import org.elasticsearch.xpack.inference.external.http.sender.BaseRequestManager;
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
import org.elasticsearch.xpack.inference.external.http.sender.ExecutableInferenceRequest;
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.multimodal.VoyageAIMultimodalEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.voyageai.request.VoyageAIMultimodalEmbeddingsRequest;
import org.elasticsearch.xpack.inference.services.voyageai.response.VoyageAIEmbeddingsResponseEntity;

import java.util.List;
import java.util.Objects;
import java.util.function.Supplier;

/**
 * Request manager for VoyageAI multimodal embeddings.
 * <p>
 * Inputs are forwarded as {@link InferenceStringGroup}s so a single request can carry multimodal content
 * (text, images and videos). Rate limiting is configured from the model's inference entity id, the grouping
 * computed by {@link VoyageAIRequestManager.RateLimitGrouping#of(VoyageAIModel)} and the model's rate limit settings.
 */
public class VoyageAIMultimodalEmbeddingsRequestManager extends BaseRequestManager {
    private static final Logger logger = LogManager.getLogger(VoyageAIMultimodalEmbeddingsRequestManager.class);

    /** Shared handler that parses responses with {@link VoyageAIEmbeddingsResponseEntity#fromResponse}. */
    private static final ResponseHandler HANDLER = new VoyageAIResponseHandler(
        "voyageai multimodal embedding",
        VoyageAIEmbeddingsResponseEntity::fromResponse
    );

    private final VoyageAIMultimodalEmbeddingsModel model;

    private VoyageAIMultimodalEmbeddingsRequestManager(VoyageAIMultimodalEmbeddingsModel model, ThreadPool threadPool) {
        super(threadPool, model.getInferenceEntityId(), VoyageAIRequestManager.RateLimitGrouping.of(model), model.rateLimitSettings());
        this.model = Objects.requireNonNull(model);
    }

    /**
     * Creates a request manager for the given multimodal embeddings model.
     *
     * @param model the model describing the endpoint to call; must not be null
     * @param threadPool the thread pool handed to the base request manager; must not be null
     * @return a new request manager
     * @throws NullPointerException if either argument is null
     */
    public static VoyageAIMultimodalEmbeddingsRequestManager of(VoyageAIMultimodalEmbeddingsModel model, ThreadPool threadPool) {
        Objects.requireNonNull(model);
        Objects.requireNonNull(threadPool);
        return new VoyageAIMultimodalEmbeddingsRequestManager(model, threadPool);
    }

    /**
     * Builds a {@link VoyageAIMultimodalEmbeddingsRequest} from the input groups and hands it to the sender.
     *
     * @param inferenceInputs the inputs, cast to {@link EmbeddingsInput}
     * @param requestSender the sender used to perform the request
     * @param hasRequestCompletedFunction supplier passed through to the executable request
     * @param listener receives the inference results or a failure
     */
    @Override
    public void execute(
        InferenceInputs inferenceInputs,
        RequestSender requestSender,
        Supplier<Boolean> hasRequestCompletedFunction,
        ActionListener<InferenceServiceResults> listener
    ) {
        EmbeddingsInput input = inferenceInputs.castTo(EmbeddingsInput.class);

        // Use the grouped form of the inputs rather than plain text so multimodal content is preserved
        List<InferenceStringGroup> groups = input.getInputs();

        VoyageAIMultimodalEmbeddingsRequest multimodalRequest = new VoyageAIMultimodalEmbeddingsRequest(
            groups,
            input.getInputType(),
            model
        );

        ExecutableInferenceRequest executable = new ExecutableInferenceRequest(
            requestSender,
            logger,
            multimodalRequest,
            HANDLER,
            hasRequestCompletedFunction,
            listener
        );
        execute(executable);
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.services.voyageai;

import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.inference.external.http.sender.BaseRequestManager;

import java.util.Map;
import java.util.Objects;

/**
 * Base class for VoyageAI request managers. It wires the model's inference entity id, rate limit grouping
 * and rate limit settings into {@link BaseRequestManager}.
 */
abstract class VoyageAIRequestManager extends BaseRequestManager {
    /** Family used for any model id that has no entry in {@link #MODEL_TO_MODEL_FAMILY}. */
    private static final String DEFAULT_MODEL_FAMILY = "default_model_family";

    /**
     * Maps a VoyageAI model id to the model family used for rate limit grouping.
     * <p>
     * The entries for {@code voyage-3.5}, {@code voyage-3.5-lite} and {@code voyage-context-3} match the families
     * that {@code VoyageAIModel} assigns to those ids, so these models are not lumped into the default family.
     */
    private static final Map<String, String> MODEL_TO_MODEL_FAMILY = Map.ofEntries(
        Map.entry("voyage-3.5", "embed_medium"),
        Map.entry("voyage-3.5-lite", "embed_small"),
        Map.entry("voyage-context-3", "embed_context"),
        Map.entry("voyage-multimodal-3", "embed_multimodal"),
        Map.entry("voyage-multimodal-3.5", "embed_multimodal"),
        Map.entry("voyage-3-large", "embed_large"),
        Map.entry("voyage-code-3", "embed_large"),
        Map.entry("voyage-3", "embed_medium"),
        Map.entry("voyage-3-lite", "embed_small"),
        Map.entry("voyage-finance-2", "embed_large"),
        Map.entry("voyage-law-2", "embed_large"),
        Map.entry("voyage-code-2", "embed_large"),
        Map.entry("rerank-2", "rerank_large"),
        Map.entry("rerank-2-lite", "rerank_small")
    );

    /**
     * @param threadPool the thread pool handed to the base request manager
     * @param model the model supplying the inference entity id, rate limit grouping and rate limit settings
     */
    protected VoyageAIRequestManager(ThreadPool threadPool, VoyageAIModel model) {
        super(threadPool, model.getInferenceEntityId(), RateLimitGrouping.of(model), model.rateLimitSettings());
    }

    /**
     * Key that identifies which requests share a rate limit group.
     *
     * @param apiKeyHash despite its name, this holds the hash code of the model family string, not a hash of
     *                   the API key; the component name is kept so existing accessors keep working
     */
    record RateLimitGrouping(int apiKeyHash) {
        /**
         * Computes the grouping for a model from its model id.
         *
         * @param model the model whose service settings provide the model id; must not be null
         * @return a grouping built from the hash of the model's family, or of the default family for unknown ids
         * @throws NullPointerException if {@code model} is null
         */
        public static RateLimitGrouping of(VoyageAIModel model) {
            Objects.requireNonNull(model);
            String modelId = model.getServiceSettings().modelId();
            String modelFamily = MODEL_TO_MODEL_FAMILY.getOrDefault(modelId, DEFAULT_MODEL_FAMILY);

            return new RateLimitGrouping(modelFamily.hashCode());
        }
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.services.voyageai;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
import org.elasticsearch.xpack.inference.services.voyageai.request.VoyageAIRerankRequest;
import org.elasticsearch.xpack.inference.services.voyageai.rerank.VoyageAIRerankModel;
import org.elasticsearch.xpack.inference.services.voyageai.response.VoyageAIRerankResponseEntity;
import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs;
import org.elasticsearch.xpack.inference.external.http.sender.ExecutableInferenceRequest;

import java.util.Objects;
import java.util.function.Supplier;

/**
 * Request manager that builds and dispatches VoyageAI rerank requests.
 */
public class VoyageAIRerankRequestManager extends VoyageAIRequestManager {
    private static final Logger logger = LogManager.getLogger(VoyageAIRerankRequestManager.class);

    /** Shared handler; only the response (not the request) is needed to parse a rerank result. */
    private static final ResponseHandler HANDLER = new VoyageAIResponseHandler(
        "voyageai rerank",
        (request, response) -> VoyageAIRerankResponseEntity.fromResponse(response)
    );

    private final VoyageAIRerankModel model;

    private VoyageAIRerankRequestManager(VoyageAIRerankModel model, ThreadPool threadPool) {
        super(threadPool, model);
        this.model = model;
    }

    /**
     * Creates a request manager for the given rerank model.
     *
     * @param model the model describing the endpoint to call; must not be null
     * @param threadPool the thread pool handed to the base request manager; must not be null
     * @return a new request manager
     * @throws NullPointerException if either argument is null
     */
    public static VoyageAIRerankRequestManager of(VoyageAIRerankModel model, ThreadPool threadPool) {
        Objects.requireNonNull(model);
        Objects.requireNonNull(threadPool);
        return new VoyageAIRerankRequestManager(model, threadPool);
    }

    /**
     * Builds a {@link VoyageAIRerankRequest} from the query and documents and hands it to the sender.
     * The request's {@code returnDocuments} and {@code topN} arguments are passed as {@code null}.
     *
     * @param inferenceInputs the inputs, cast to {@link QueryAndDocsInputs}
     * @param requestSender the sender used to perform the request
     * @param hasRequestCompletedFunction supplier passed through to the executable request
     * @param listener receives the inference results or a failure
     */
    @Override
    public void execute(
        InferenceInputs inferenceInputs,
        RequestSender requestSender,
        Supplier<Boolean> hasRequestCompletedFunction,
        ActionListener<InferenceServiceResults> listener
    ) {
        QueryAndDocsInputs queryAndDocs = inferenceInputs.castTo(QueryAndDocsInputs.class);

        VoyageAIRerankRequest rerankRequest = new VoyageAIRerankRequest(
            queryAndDocs.getQuery(),
            queryAndDocs.getChunks(),
            null, // returnDocuments
            null, // topN
            model
        );

        ExecutableInferenceRequest executable = new ExecutableInferenceRequest(
            requestSender,
            logger,
            rerankRequest,
            HANDLER,
            hasRequestCompletedFunction,
            listener
        );
        execute(executable);
    }
}
Loading