diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java index 8da1229a528ea..e6b27b6a641cb 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java @@ -17,6 +17,7 @@ import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.ChunkedInference; +import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.EmptySecretSettings; import org.elasticsearch.inference.EmptyTaskSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -36,6 +37,8 @@ import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults; +import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder; +import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker; import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction; import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; @@ -71,6 +74,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; @@ -80,6 +84,7 @@ public class ElasticInferenceService extends SenderService { public static final String NAME = "elastic"; public static final String ELASTIC_INFERENCE_SERVICE_IDENTIFIER = "Elastic Inference Service"; + public static final int SPARSE_TEXT_EMBEDDING_MAX_BATCH_SIZE = 512; private static final EnumSet IMPLEMENTED_TASK_TYPES = EnumSet.of( TaskType.SPARSE_EMBEDDING, @@ -161,7 +166,8 @@ private static Map initDefaultEndpoints( new ElasticInferenceServiceSparseEmbeddingsServiceSettings(DEFAULT_ELSER_MODEL_ID_V2, null, null), EmptyTaskSettings.INSTANCE, EmptySecretSettings.INSTANCE, - elasticInferenceServiceComponents + elasticInferenceServiceComponents, + ChunkingSettingsBuilder.DEFAULT_SETTINGS ), MinimalServiceSettings.sparseEmbedding(NAME) ), @@ -304,12 +310,25 @@ protected void doChunkedInfer( TimeValue timeout, ActionListener> listener ) { - // Pass-through without actually performing chunking (result will have a single chunk per input) - ActionListener inferListener = listener.delegateFailureAndWrap( - (delegate, response) -> delegate.onResponse(translateToChunkedResults(inputs, response)) - ); + if (model instanceof ElasticInferenceServiceSparseEmbeddingsModel sparseTextEmbeddingsModel) { + var actionCreator = new ElasticInferenceServiceActionCreator(getSender(), getServiceComponents(), getCurrentTraceInfo()); + + List batchedRequests = new EmbeddingRequestChunker<>( + inputs.getInputs(), + SPARSE_TEXT_EMBEDDING_MAX_BATCH_SIZE, + model.getConfigurations().getChunkingSettings() + ).batchRequestsWithListeners(listener); + + for (var request : batchedRequests) { + var action = sparseTextEmbeddingsModel.accept(actionCreator, taskSettings); + action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener()); + } + + return; + } - doInfer(model, inputs, taskSettings, timeout, inferListener); + // Model cannot perform chunked inference + listener.onFailure(createInvalidModelException(model)); } @Override @@ -328,6 +347,13 @@ public void parseRequestConfig( Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); + ChunkingSettings chunkingSettings = null; + if (TaskType.SPARSE_EMBEDDING.equals(taskType)) { + chunkingSettings = ChunkingSettingsBuilder.fromMap( + removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS) + ); + } + ElasticInferenceServiceModel model = createModel( inferenceEntityId, taskType, @@ -336,7 +362,8 @@ public void parseRequestConfig( serviceSettingsMap, elasticInferenceServiceComponents, TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME), - ConfigurationParseContext.REQUEST + ConfigurationParseContext.REQUEST, + chunkingSettings ); throwIfNotEmptyMap(config, NAME); @@ -372,7 +399,8 @@ private static ElasticInferenceServiceModel createModel( @Nullable Map secretSettings, ElasticInferenceServiceComponents elasticInferenceServiceComponents, String failureMessage, - ConfigurationParseContext context + ConfigurationParseContext context, + ChunkingSettings chunkingSettings ) { return switch (taskType) { case SPARSE_EMBEDDING -> new ElasticInferenceServiceSparseEmbeddingsModel( @@ -383,7 +411,8 @@ private static ElasticInferenceServiceModel createModel( taskSettings, secretSettings, elasticInferenceServiceComponents, - context + context, + chunkingSettings ); case CHAT_COMPLETION -> new ElasticInferenceServiceCompletionModel( inferenceEntityId, @@ -420,13 +449,19 @@ public Model parsePersistedConfigWithSecrets( Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); Map secretSettingsMap = removeFromMapOrDefaultEmpty(secrets, ModelSecrets.SECRET_SETTINGS); + ChunkingSettings chunkingSettings = null; + if (TaskType.SPARSE_EMBEDDING.equals(taskType)) { + chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS)); + } + return createModelFromPersistent( inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, secretSettingsMap, - parsePersistedConfigErrorMsg(inferenceEntityId, NAME) + parsePersistedConfigErrorMsg(inferenceEntityId, NAME), + chunkingSettings ); } @@ -435,13 +470,19 @@ public Model parsePersistedConfig(String inferenceEntityId, TaskType taskType, M Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); + ChunkingSettings chunkingSettings = null; + if (TaskType.SPARSE_EMBEDDING.equals(taskType)) { + chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS)); + } + return createModelFromPersistent( inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, null, - parsePersistedConfigErrorMsg(inferenceEntityId, NAME) + parsePersistedConfigErrorMsg(inferenceEntityId, NAME), + chunkingSettings ); } @@ -456,7 +497,8 @@ private ElasticInferenceServiceModel createModelFromPersistent( Map serviceSettings, Map taskSettings, @Nullable Map secretSettings, - String failureMessage + String failureMessage, + ChunkingSettings chunkingSettings ) { return createModel( inferenceEntityId, @@ -466,7 +508,8 @@ private ElasticInferenceServiceModel createModelFromPersistent( secretSettings, elasticInferenceServiceComponents, failureMessage, - ConfigurationParseContext.PERSISTENT + ConfigurationParseContext.PERSISTENT, + chunkingSettings ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/sparseembeddings/ElasticInferenceServiceSparseEmbeddingsModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/sparseembeddings/ElasticInferenceServiceSparseEmbeddingsModel.java index 3ffd426ca6b6b..4dead9850a423 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/sparseembeddings/ElasticInferenceServiceSparseEmbeddingsModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/sparseembeddings/ElasticInferenceServiceSparseEmbeddingsModel.java @@ -9,6 +9,7 @@ import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.EmptySecretSettings; import org.elasticsearch.inference.EmptyTaskSettings; import org.elasticsearch.inference.ModelConfigurations; @@ -39,7 +40,8 @@ public ElasticInferenceServiceSparseEmbeddingsModel( Map taskSettings, Map secrets, ElasticInferenceServiceComponents elasticInferenceServiceComponents, - ConfigurationParseContext context + ConfigurationParseContext context, + ChunkingSettings chunkingSettings ) { this( inferenceEntityId, @@ -48,7 +50,8 @@ public ElasticInferenceServiceSparseEmbeddingsModel( ElasticInferenceServiceSparseEmbeddingsServiceSettings.fromMap(serviceSettings, context), EmptyTaskSettings.INSTANCE, EmptySecretSettings.INSTANCE, - elasticInferenceServiceComponents + elasticInferenceServiceComponents, + chunkingSettings ); } @@ -67,10 +70,11 @@ public ElasticInferenceServiceSparseEmbeddingsModel( ElasticInferenceServiceSparseEmbeddingsServiceSettings serviceSettings, @Nullable TaskSettings taskSettings, @Nullable SecretSettings secretSettings, - ElasticInferenceServiceComponents elasticInferenceServiceComponents + ElasticInferenceServiceComponents elasticInferenceServiceComponents, + ChunkingSettings chunkingSettings ) { super( - new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings), + new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings, chunkingSettings), new ModelSecrets(secretSettings), serviceSettings, elasticInferenceServiceComponents diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsModelTests.java index 764b3b090cf0b..9c95fbfdfa996 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsModelTests.java @@ -11,6 +11,7 @@ import org.elasticsearch.inference.EmptyTaskSettings; import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder; import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsModel; import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsServiceSettings; @@ -28,7 +29,8 @@ public static ElasticInferenceServiceSparseEmbeddingsModel createModel(String ur new ElasticInferenceServiceSparseEmbeddingsServiceSettings(modelId, maxInputTokens, null), EmptyTaskSettings.INSTANCE, EmptySecretSettings.INSTANCE, - ElasticInferenceServiceComponents.of(url) + ElasticInferenceServiceComponents.of(url), + ChunkingSettingsBuilder.DEFAULT_SETTINGS ); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java index 71a073c02e02b..9ff4a04add8b3 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java @@ -835,7 +835,7 @@ public void testUnifiedCompletionInfer_PropagatesProductUseCaseHeader() throws I } } - public void testChunkedInfer_PassesThrough() throws IOException { + public void testChunkedInfer() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); var elasticInferenceServiceURL = getUrl(webServer); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandlerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandlerTests.java index 3332e55cf1f5d..d8c8c9e5b7abf 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandlerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandlerTests.java @@ -21,6 +21,7 @@ import org.elasticsearch.test.ESSingleNodeTestCase; import org.elasticsearch.xpack.inference.LocalStateInferencePlugin; import org.elasticsearch.xpack.inference.Utils; +import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder; import org.elasticsearch.xpack.inference.external.http.sender.Sender; import org.elasticsearch.xpack.inference.registry.ModelRegistry; import org.elasticsearch.xpack.inference.services.elastic.DefaultModelConfig; @@ -196,7 +197,8 @@ private static Map initDefaultEndpoints() { new ElasticInferenceServiceSparseEmbeddingsServiceSettings("elser-v2", null, null), EmptyTaskSettings.INSTANCE, EmptySecretSettings.INSTANCE, - ElasticInferenceServiceComponents.EMPTY_INSTANCE + ElasticInferenceServiceComponents.EMPTY_INSTANCE, + ChunkingSettingsBuilder.DEFAULT_SETTINGS ), MinimalServiceSettings.sparseEmbedding(ElasticInferenceService.NAME) )