From 9f4e2adba021588d1e5d3ecfce9771c7d3c74c15 Mon Sep 17 00:00:00 2001 From: Kathleen DeRusso Date: Wed, 12 Feb 2025 09:18:30 -0500 Subject: [PATCH 01/86] test --- test.txt | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 test.txt diff --git a/test.txt b/test.txt new file mode 100644 index 0000000000000..e69de29bb2d1d From a5610179f10fb87fa44d9b1a4d4049f4f03f0a35 Mon Sep 17 00:00:00 2001 From: Kathleen DeRusso Date: Wed, 12 Feb 2025 09:53:45 -0500 Subject: [PATCH 02/86] Revert "test" This reverts commit 9f4e2adba021588d1e5d3ecfce9771c7d3c74c15. --- test.txt | 0 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 test.txt diff --git a/test.txt b/test.txt deleted file mode 100644 index e69de29bb2d1d..0000000000000 From acb14a22a6d3bd383b739e91b282f248face7fe6 Mon Sep 17 00:00:00 2001 From: Kathleen DeRusso Date: Fri, 24 Jan 2025 10:39:30 -0500 Subject: [PATCH 03/86] Refactor InferenceService to allow passing in chunking settings --- .../inference/InferenceService.java | 39 +++++++++++++++---- .../TestDenseInferenceServiceExtension.java | 2 + .../mock/TestRerankingServiceExtension.java | 2 + .../TestSparseInferenceServiceExtension.java | 2 + ...stStreamingCompletionServiceExtension.java | 2 + .../ShardBulkInferenceActionFilter.java | 11 +++++- .../inference/services/SenderService.java | 5 ++- .../AlibabaCloudSearchService.java | 3 +- .../amazonbedrock/AmazonBedrockService.java | 3 +- .../services/anthropic/AnthropicService.java | 2 + .../azureaistudio/AzureAiStudioService.java | 3 +- .../azureopenai/AzureOpenAiService.java | 3 +- .../services/cohere/CohereService.java | 3 +- .../elastic/ElasticInferenceService.java | 2 + .../ElasticsearchInternalService.java | 5 ++- .../googleaistudio/GoogleAiStudioService.java | 3 +- .../googlevertexai/GoogleVertexAiService.java | 3 +- .../huggingface/HuggingFaceService.java | 3 +- .../elser/HuggingFaceElserService.java | 1 + .../ibmwatsonx/IbmWatsonxService.java | 3 +- .../services/jinaai/JinaAIService.java | 3 +- .../services/mistral/MistralService.java | 3 +- .../services/openai/OpenAiService.java | 3 +- .../services/SenderServiceTests.java | 2 + 24 files changed, 88 insertions(+), 23 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/inference/InferenceService.java b/server/src/main/java/org/elasticsearch/inference/InferenceService.java index e1ebd8bb81ff4..d2cdf225479a3 100644 --- a/server/src/main/java/org/elasticsearch/inference/InferenceService.java +++ b/server/src/main/java/org/elasticsearch/inference/InferenceService.java @@ -126,22 +126,47 @@ void unifiedCompletionInfer( ActionListener listener ); + /** + * Chunk long text. + * + * @param model The model + * @param query Inference query, mainly for re-ranking + * @param input Inference input + * @param taskSettings Settings in the request to override the model's defaults + * @param inputType For search, ingest etc + * @param timeout The timeout for the request + * @param listener Chunked Inference result listener + */ + default void chunkedInfer( + Model model, + @Nullable String query, + List input, + Map taskSettings, + InputType inputType, + TimeValue timeout, + ActionListener> listener + ) { + chunkedInfer(model, query, input, taskSettings, null, inputType, timeout, listener); + } + /** * Chunk long text. 
* - * @param model The model - * @param query Inference query, mainly for re-ranking - * @param input Inference input - * @param taskSettings Settings in the request to override the model's defaults - * @param inputType For search, ingest etc - * @param timeout The timeout for the request - * @param listener Chunked Inference result listener + * @param model The model + * @param query Inference query, mainly for re-ranking + * @param input Inference input + * @param taskSettings Settings in the request to override the model's defaults + * @param chunkingSettings Chunking settings + * @param inputType For search, ingest etc + * @param timeout The timeout for the request + * @param listener Chunked Inference result listener */ void chunkedInfer( Model model, @Nullable String query, List input, Map taskSettings, + @Nullable ChunkingSettings chunkingSettings, InputType inputType, TimeValue timeout, ActionListener> listener diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java index 1f17e335462a7..9e994a31446c7 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java @@ -18,6 +18,7 @@ import org.elasticsearch.core.TimeValue; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.inference.ChunkedInference; +import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; @@ -146,6 +147,7 @@ public void chunkedInfer( @Nullable String query, List input, Map taskSettings, + ChunkingSettings chunkingSettings, InputType inputType, TimeValue timeout, ActionListener> listener diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java index 765c69e28a9ad..5f13ccf808da7 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java @@ -17,6 +17,7 @@ import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.ChunkedInference; +import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; @@ -135,6 +136,7 @@ public void chunkedInfer( @Nullable String query, List input, Map taskSettings, + ChunkingSettings chunkingSettings, InputType inputType, TimeValue timeout, ActionListener> listener diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java 
b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java index f700f6672fd63..3dc2c3195bae3 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java @@ -17,6 +17,7 @@ import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.ChunkedInference; +import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; @@ -136,6 +137,7 @@ public void chunkedInfer( @Nullable String query, List input, Map taskSettings, + ChunkingSettings chunkingSettings, InputType inputType, TimeValue timeout, ActionListener> listener diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java index 9355fa7d0ad48..5e63487148903 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java @@ -18,6 +18,7 @@ import org.elasticsearch.common.xcontent.ChunkedToXContentHelper; import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.ChunkedInference; +import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; @@ -222,6 +223,7 @@ public void chunkedInfer( String query, List input, Map taskSettings, + ChunkingSettings chunkingSettings, InputType inputType, TimeValue timeout, ActionListener> listener diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java index 3933260664b7c..bd2f33926ccd8 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java @@ -362,7 +362,16 @@ private void onFinish() { } }; inferenceProvider.service() - .chunkedInfer(inferenceProvider.model(), null, inputs, Map.of(), InputType.INGEST, TimeValue.MAX_VALUE, completionListener); + .chunkedInfer( + inferenceProvider.model(), + null, + inputs, + Map.of(), + null, + InputType.INGEST, + TimeValue.MAX_VALUE, + completionListener + ); } private FieldInferenceResponseAccumulator ensureResponseAccumulatorSlot(int id) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java 
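The interface change above reads more clearly with its generic type parameters restored; they are not visible in the diff text and are assumed here to be the usual List<String>, Map<String, Object>, and ActionListener<List<ChunkedInference>>. A minimal sketch of the new chunkedInfer variant and the delegating default that preserves the old signature:

    /**
     * Chunk long text, optionally overriding the chunking settings configured on the model.
     * A null chunkingSettings falls back to the model's own settings, as each service
     * implementation below resolves it.
     */
    void chunkedInfer(
        Model model,
        @Nullable String query,
        List<String> input,
        Map<String, Object> taskSettings,
        @Nullable ChunkingSettings chunkingSettings,
        InputType inputType,
        TimeValue timeout,
        ActionListener<List<ChunkedInference>> listener
    );

    // Existing callers keep compiling: the old signature delegates with no override.
    default void chunkedInfer(
        Model model,
        @Nullable String query,
        List<String> input,
        Map<String, Object> taskSettings,
        InputType inputType,
        TimeValue timeout,
        ActionListener<List<ChunkedInference>> listener
    ) {
        chunkedInfer(model, query, input, taskSettings, null, inputType, timeout, listener);
    }

Each service then picks the request-level override or the model default when building its EmbeddingRequestChunker, i.e. chunkingSettings != null ? chunkingSettings : model.getConfigurations().getChunkingSettings(), as the per-service diffs below show.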
index 56bf6c1359a56..fa2ccc5e9818b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java @@ -14,6 +14,7 @@ import org.elasticsearch.core.Strings; import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.ChunkedInference; +import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; @@ -100,13 +101,14 @@ public void chunkedInfer( @Nullable String query, List input, Map taskSettings, + ChunkingSettings chunkingSettings, InputType inputType, TimeValue timeout, ActionListener> listener ) { init(); // a non-null query is not supported and is dropped by all providers - doChunkedInfer(model, new DocumentsOnlyInput(input), taskSettings, inputType, timeout, listener); + doChunkedInfer(model, new DocumentsOnlyInput(input), taskSettings, chunkingSettings, inputType, timeout, listener); } protected abstract void doInfer( @@ -129,6 +131,7 @@ protected abstract void doChunkedInfer( Model model, DocumentsOnlyInput inputs, Map taskSettings, + ChunkingSettings chunkingSettings, InputType inputType, TimeValue timeout, ActionListener> listener diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java index 589ca1e033f06..19cdd21baaba9 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java @@ -293,6 +293,7 @@ protected void doChunkedInfer( Model model, DocumentsOnlyInput inputs, Map taskSettings, + ChunkingSettings chunkingSettings, InputType inputType, TimeValue timeout, ActionListener> listener @@ -309,7 +310,7 @@ protected void doChunkedInfer( inputs.getInputs(), EMBEDDING_MAX_BATCH_SIZE, getEmbeddingTypeFromTaskType(alibabaCloudSearchModel.getTaskType()), - alibabaCloudSearchModel.getConfigurations().getChunkingSettings() + chunkingSettings != null ? 
chunkingSettings : alibabaCloudSearchModel.getConfigurations().getChunkingSettings() ).batchRequestsWithListeners(listener); for (var request : batchedRequests) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java index 493acd3c0cd1a..4dfce3b07f4ff 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java @@ -121,6 +121,7 @@ protected void doChunkedInfer( Model model, DocumentsOnlyInput inputs, Map taskSettings, + ChunkingSettings chunkingSettings, InputType inputType, TimeValue timeout, ActionListener> listener @@ -133,7 +134,7 @@ protected void doChunkedInfer( inputs.getInputs(), maxBatchSize, EmbeddingRequestChunker.EmbeddingType.FLOAT, - baseAmazonBedrockModel.getConfigurations().getChunkingSettings() + chunkingSettings != null ? chunkingSettings : baseAmazonBedrockModel.getConfigurations().getChunkingSettings() ).batchRequestsWithListeners(listener); for (var request : batchedRequests) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java index 64fe42fbbc171..6146ed39edfee 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java @@ -15,6 +15,7 @@ import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.ChunkedInference; +import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; @@ -228,6 +229,7 @@ protected void doChunkedInfer( Model model, DocumentsOnlyInput inputs, Map taskSettings, + ChunkingSettings chunkingSettings, InputType inputType, TimeValue timeout, ActionListener> listener diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java index 34a5c2b4cc1e9..10c0c2a14cf91 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java @@ -114,6 +114,7 @@ protected void doChunkedInfer( Model model, DocumentsOnlyInput inputs, Map taskSettings, + ChunkingSettings chunkingSettings, InputType inputType, TimeValue timeout, ActionListener> listener @@ -125,7 +126,7 @@ protected void doChunkedInfer( inputs.getInputs(), EMBEDDING_MAX_BATCH_SIZE, EmbeddingRequestChunker.EmbeddingType.FLOAT, - baseAzureAiStudioModel.getConfigurations().getChunkingSettings() + chunkingSettings != null ? 
chunkingSettings : baseAzureAiStudioModel.getConfigurations().getChunkingSettings() ).batchRequestsWithListeners(listener); for (var request : batchedRequests) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java index 9a77b63337978..c7dd0f6717142 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java @@ -270,6 +270,7 @@ protected void doChunkedInfer( Model model, DocumentsOnlyInput inputs, Map taskSettings, + ChunkingSettings chunkingSettings, InputType inputType, TimeValue timeout, ActionListener> listener @@ -285,7 +286,7 @@ protected void doChunkedInfer( inputs.getInputs(), EMBEDDING_MAX_BATCH_SIZE, EmbeddingRequestChunker.EmbeddingType.FLOAT, - azureOpenAiModel.getConfigurations().getChunkingSettings() + chunkingSettings != null ? chunkingSettings : azureOpenAiModel.getConfigurations().getChunkingSettings() ).batchRequestsWithListeners(listener); for (var request : batchedRequests) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java index 6c2d3bb96d74d..ea95285947729 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java @@ -271,6 +271,7 @@ protected void doChunkedInfer( Model model, DocumentsOnlyInput inputs, Map taskSettings, + ChunkingSettings chunkingSettings, InputType inputType, TimeValue timeout, ActionListener> listener @@ -287,7 +288,7 @@ protected void doChunkedInfer( inputs.getInputs(), EMBEDDING_MAX_BATCH_SIZE, EmbeddingRequestChunker.EmbeddingType.fromDenseVectorElementType(model.getServiceSettings().elementType()), - cohereModel.getConfigurations().getChunkingSettings() + chunkingSettings != null ? 
chunkingSettings : cohereModel.getConfigurations().getChunkingSettings() ).batchRequestsWithListeners(listener); for (var request : batchedRequests) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java index fee66a9f84ac9..6c7aee58b3cd4 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java @@ -20,6 +20,7 @@ import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.EmptySecretSettings; import org.elasticsearch.inference.EmptyTaskSettings; +import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; @@ -395,6 +396,7 @@ protected void doChunkedInfer( Model model, DocumentsOnlyInput inputs, Map taskSettings, + ChunkingSettings chunkingSettings, InputType inputType, TimeValue timeout, ActionListener> listener diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java index ddc5e3e1aa36c..69511ad407c95 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java @@ -703,7 +703,7 @@ public void chunkedInfer( TimeValue timeout, ActionListener> listener ) { - chunkedInfer(model, null, input, taskSettings, inputType, timeout, listener); + chunkedInfer(model, null, input, taskSettings, null, inputType, timeout, listener); } @Override @@ -712,6 +712,7 @@ public void chunkedInfer( @Nullable String query, List input, Map taskSettings, + ChunkingSettings chunkingSettings, InputType inputType, TimeValue timeout, ActionListener> listener @@ -729,7 +730,7 @@ public void chunkedInfer( input, EMBEDDING_MAX_BATCH_SIZE, embeddingTypeFromTaskTypeAndSettings(model.getTaskType(), esModel.internalServiceSettings), - esModel.getConfigurations().getChunkingSettings() + chunkingSettings != null ? 
chunkingSettings : esModel.getConfigurations().getChunkingSettings() ).batchRequestsWithListeners(listener); if (batchedRequests.isEmpty()) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java index 205cc545a23f0..49bb992cf23a6 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java @@ -322,6 +322,7 @@ protected void doChunkedInfer( Model model, DocumentsOnlyInput inputs, Map taskSettings, + ChunkingSettings chunkingSettings, InputType inputType, TimeValue timeout, ActionListener> listener @@ -332,7 +333,7 @@ protected void doChunkedInfer( inputs.getInputs(), EMBEDDING_MAX_BATCH_SIZE, EmbeddingRequestChunker.EmbeddingType.FLOAT, - googleAiStudioModel.getConfigurations().getChunkingSettings() + chunkingSettings != null ? chunkingSettings : googleAiStudioModel.getConfigurations().getChunkingSettings() ).batchRequestsWithListeners(listener); for (var request : batchedRequests) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java index 3e921f669e864..80761286cc71a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java @@ -221,6 +221,7 @@ protected void doChunkedInfer( Model model, DocumentsOnlyInput inputs, Map taskSettings, + ChunkingSettings chunkingSettings, InputType inputType, TimeValue timeout, ActionListener> listener @@ -232,7 +233,7 @@ protected void doChunkedInfer( inputs.getInputs(), EMBEDDING_MAX_BATCH_SIZE, EmbeddingRequestChunker.EmbeddingType.FLOAT, - googleVertexAiModel.getConfigurations().getChunkingSettings() + chunkingSettings != null ? chunkingSettings : googleVertexAiModel.getConfigurations().getChunkingSettings() ).batchRequestsWithListeners(listener); for (var request : batchedRequests) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java index 73c1446b9bb26..a13591c518798 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java @@ -115,6 +115,7 @@ protected void doChunkedInfer( Model model, DocumentsOnlyInput inputs, Map taskSettings, + ChunkingSettings chunkingSettings, InputType inputType, TimeValue timeout, ActionListener> listener @@ -131,7 +132,7 @@ protected void doChunkedInfer( inputs.getInputs(), EMBEDDING_MAX_BATCH_SIZE, EmbeddingRequestChunker.EmbeddingType.FLOAT, - huggingFaceModel.getConfigurations().getChunkingSettings() + chunkingSettings != null ? 
chunkingSettings : huggingFaceModel.getConfigurations().getChunkingSettings() ).batchRequestsWithListeners(listener); for (var request : batchedRequests) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java index 79001f17a4e96..88dbbdf952574 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java @@ -98,6 +98,7 @@ protected void doChunkedInfer( Model model, DocumentsOnlyInput inputs, Map taskSettings, + ChunkingSettings chunkingSettings, InputType inputType, TimeValue timeout, ActionListener> listener diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java index 3fa423c2dae19..80dbe44a5685f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java @@ -301,6 +301,7 @@ protected void doChunkedInfer( Model model, DocumentsOnlyInput input, Map taskSettings, + ChunkingSettings chunkingSettings, InputType inputType, TimeValue timeout, ActionListener> listener @@ -311,7 +312,7 @@ protected void doChunkedInfer( input.getInputs(), EMBEDDING_MAX_BATCH_SIZE, EmbeddingRequestChunker.EmbeddingType.FLOAT, - model.getConfigurations().getChunkingSettings() + chunkingSettings != null ? chunkingSettings : model.getConfigurations().getChunkingSettings() ).batchRequestsWithListeners(listener); for (var request : batchedRequests) { var action = ibmWatsonxModel.accept(getActionCreator(getSender(), getServiceComponents()), taskSettings, inputType); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java index 37add1e264704..2e18a850868b5 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java @@ -253,6 +253,7 @@ protected void doChunkedInfer( Model model, DocumentsOnlyInput inputs, Map taskSettings, + ChunkingSettings chunkingSettings, InputType inputType, TimeValue timeout, ActionListener> listener @@ -269,7 +270,7 @@ protected void doChunkedInfer( inputs.getInputs(), EMBEDDING_MAX_BATCH_SIZE, EmbeddingRequestChunker.EmbeddingType.fromDenseVectorElementType(model.getServiceSettings().elementType()), - jinaaiModel.getConfigurations().getChunkingSettings() + chunkingSettings != null ? 
chunkingSettings : jinaaiModel.getConfigurations().getChunkingSettings() ).batchRequestsWithListeners(listener); for (var request : batchedRequests) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java index 3e40575e42faf..c07402de83845 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java @@ -103,6 +103,7 @@ protected void doChunkedInfer( Model model, DocumentsOnlyInput inputs, Map taskSettings, + ChunkingSettings chunkingSettings, InputType inputType, TimeValue timeout, ActionListener> listener @@ -114,7 +115,7 @@ protected void doChunkedInfer( inputs.getInputs(), MistralConstants.MAX_BATCH_SIZE, EmbeddingRequestChunker.EmbeddingType.FLOAT, - mistralEmbeddingsModel.getConfigurations().getChunkingSettings() + chunkingSettings != null ? chunkingSettings : mistralEmbeddingsModel.getConfigurations().getChunkingSettings() ).batchRequestsWithListeners(listener); for (var request : batchedRequests) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java index 94312a39882fd..7ec42249329c4 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java @@ -307,6 +307,7 @@ protected void doChunkedInfer( Model model, DocumentsOnlyInput inputs, Map taskSettings, + ChunkingSettings chunkingSettings, InputType inputType, TimeValue timeout, ActionListener> listener @@ -323,7 +324,7 @@ protected void doChunkedInfer( inputs.getInputs(), EMBEDDING_MAX_BATCH_SIZE, EmbeddingRequestChunker.EmbeddingType.FLOAT, - openAiModel.getConfigurations().getChunkingSettings() + chunkingSettings != null ? 
chunkingSettings : openAiModel.getConfigurations().getChunkingSettings() ).batchRequestsWithListeners(listener); for (var request : batchedRequests) { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java index 69b26585d8d49..e0d8eec9d3f74 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java @@ -12,6 +12,7 @@ import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.ChunkedInference; +import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; @@ -130,6 +131,7 @@ protected void doChunkedInfer( Model model, DocumentsOnlyInput inputs, Map taskSettings, + ChunkingSettings chunkingSettings, InputType inputType, TimeValue timeout, ActionListener> listener From 933c74d3c61381a97931b063d19c49d06d182773 Mon Sep 17 00:00:00 2001 From: Kathleen DeRusso Date: Fri, 24 Jan 2025 14:20:33 -0500 Subject: [PATCH 04/86] Add chunking config to inference field metadata and store in semantic_text field --- .../org/elasticsearch/TransportVersions.java | 1 + .../metadata/InferenceFieldMetadata.java | 46 +++++++++++++--- .../inference/ChunkingSettings.java | 4 ++ .../ShardBulkInferenceActionFilter.java | 6 ++- .../SentenceBoundaryChunkingSettings.java | 12 +++++ .../WordBoundaryChunkingSettings.java | 12 +++++ .../inference/mapper/SemanticTextField.java | 54 ++++++++++++++++++- .../mapper/SemanticTextFieldMapper.java | 43 +++++++++++++-- 8 files changed, 165 insertions(+), 13 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index 37473c565189b..db7823b7fce0a 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -185,6 +185,7 @@ static TransportVersion def(int id) { public static final TransportVersion ESQL_DRIVER_TASK_DESCRIPTION = def(9_005_0_00); public static final TransportVersion ESQL_RETRY_ON_SHARD_LEVEL_FAILURE = def(9_006_0_00); public static final TransportVersion ESQL_PROFILE_ASYNC_NANOS = def(9_007_00_0); + public static final TransportVersion SEMANTIC_TEXT_CHUNKING_CONFIG = def(9_008_00_0); /* * STOP! READ THIS FIRST! No, really, diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadata.java b/server/src/main/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadata.java index 8917d5a9cbbb5..f69c1632e324c 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadata.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadata.java @@ -21,9 +21,13 @@ import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.Objects; +import static org.elasticsearch.TransportVersions.SEMANTIC_TEXT_CHUNKING_CONFIG; + /** * Contains inference field data for fields. 
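In outline, this commit threads a field's chunking configuration from the mapping into the ingest-time inference path. A condensed sketch of that path, built only from calls that appear in the diffs below (the variable names settingsMap and fieldChunkingSettings are illustrative only):

    // SemanticTextFieldMapper#getMetadata: export the field's chunking settings as a plain map
    Map<String, Object> settingsMap =
        fieldType().getChunkingSettings() != null ? fieldType().getChunkingSettings().asMap() : null;

    // InferenceFieldMetadata#writeTo: carry the map across the wire, but only to nodes that understand it
    if (out.getTransportVersion().onOrAfter(SEMANTIC_TEXT_CHUNKING_CONFIG)) {
        out.writeGenericMap(chunkingSettings);
    }

    // ShardBulkInferenceActionFilter: rebuild ChunkingSettings from the map when writing the
    // semantic_text field (the chunkedInfer override itself is still marked TODO at this point)
    ChunkingSettings fieldChunkingSettings =
        ChunkingSettingsBuilder.fromMap(inferenceFieldMetadata.getChunkingSettings());

Storing the settings as a generic map keeps the server-side InferenceFieldMetadata class independent of the concrete chunking strategy classes, which live in the x-pack inference plugin alongside ChunkingSettingsBuilder.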
* As inference is done in the coordinator node to avoid re-doing it at shard / replica level, the coordinator needs to check for the need @@ -35,21 +39,30 @@ public final class InferenceFieldMetadata implements SimpleDiffable chunkingSettings; - public InferenceFieldMetadata(String name, String inferenceId, String[] sourceFields) { - this(name, inferenceId, inferenceId, sourceFields); + public InferenceFieldMetadata(String name, String inferenceId, String[] sourceFields, Map chunkingSettings) { + this(name, inferenceId, inferenceId, sourceFields, chunkingSettings); } - public InferenceFieldMetadata(String name, String inferenceId, String searchInferenceId, String[] sourceFields) { + public InferenceFieldMetadata( + String name, + String inferenceId, + String searchInferenceId, + String[] sourceFields, + Map chunkingSettings + ) { this.name = Objects.requireNonNull(name); this.inferenceId = Objects.requireNonNull(inferenceId); this.searchInferenceId = Objects.requireNonNull(searchInferenceId); this.sourceFields = Objects.requireNonNull(sourceFields); + this.chunkingSettings = chunkingSettings; } public InferenceFieldMetadata(StreamInput input) throws IOException { @@ -61,6 +74,11 @@ public InferenceFieldMetadata(StreamInput input) throws IOException { this.searchInferenceId = this.inferenceId; } this.sourceFields = input.readStringArray(); + if (input.getTransportVersion().onOrAfter(SEMANTIC_TEXT_CHUNKING_CONFIG)) { + this.chunkingSettings = input.readGenericMap(); + } else { + this.chunkingSettings = null; + } } @Override @@ -71,6 +89,9 @@ public void writeTo(StreamOutput out) throws IOException { out.writeString(searchInferenceId); } out.writeStringArray(sourceFields); + if (out.getTransportVersion().onOrAfter(SEMANTIC_TEXT_CHUNKING_CONFIG)) { + out.writeGenericMap(chunkingSettings); + } } @Override @@ -81,12 +102,13 @@ public boolean equals(Object o) { return Objects.equals(name, that.name) && Objects.equals(inferenceId, that.inferenceId) && Objects.equals(searchInferenceId, that.searchInferenceId) - && Arrays.equals(sourceFields, that.sourceFields); + && Arrays.equals(sourceFields, that.sourceFields) + && Objects.equals(chunkingSettings, that.chunkingSettings); } @Override public int hashCode() { - int result = Objects.hash(name, inferenceId, searchInferenceId); + int result = Objects.hash(name, inferenceId, searchInferenceId, chunkingSettings); result = 31 * result + Arrays.hashCode(sourceFields); return result; } @@ -107,6 +129,10 @@ public String[] getSourceFields() { return sourceFields; } + public Map getChunkingSettings() { + return chunkingSettings; + } + public static Diff readDiffFrom(StreamInput in) throws IOException { return SimpleDiffable.readDiffFrom(InferenceFieldMetadata::new, in); } @@ -119,6 +145,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field(SEARCH_INFERENCE_ID_FIELD, searchInferenceId); } builder.array(SOURCE_FIELDS_FIELD, sourceFields); + if (chunkingSettings != null) { + builder.field(CHUNKING_SETTINGS_FIELD, chunkingSettings); + } return builder.endObject(); } @@ -131,6 +160,7 @@ public static InferenceFieldMetadata fromXContent(XContentParser parser) throws String currentFieldName = null; String inferenceId = null; String searchInferenceId = null; + Map chunkingSettings = null; List inputFields = new ArrayList<>(); while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) { if (token == XContentParser.Token.FIELD_NAME) { @@ -151,6 +181,9 @@ public static InferenceFieldMetadata 
fromXContent(XContentParser parser) throws } } } + } else if (CHUNKING_SETTINGS_FIELD.equals(currentFieldName)) { + chunkingSettings = parser.map(); + } else { parser.skipChildren(); } @@ -159,7 +192,8 @@ public static InferenceFieldMetadata fromXContent(XContentParser parser) throws name, inferenceId, searchInferenceId == null ? inferenceId : searchInferenceId, - inputFields.toArray(String[]::new) + inputFields.toArray(String[]::new), + chunkingSettings ); } } diff --git a/server/src/main/java/org/elasticsearch/inference/ChunkingSettings.java b/server/src/main/java/org/elasticsearch/inference/ChunkingSettings.java index 2e9072626b0a8..34b3e5a6d58ee 100644 --- a/server/src/main/java/org/elasticsearch/inference/ChunkingSettings.java +++ b/server/src/main/java/org/elasticsearch/inference/ChunkingSettings.java @@ -12,6 +12,10 @@ import org.elasticsearch.common.io.stream.VersionedNamedWriteable; import org.elasticsearch.xcontent.ToXContentObject; +import java.util.Map; + public interface ChunkingSettings extends ToXContentObject, VersionedNamedWriteable { ChunkingStrategy getChunkingStrategy(); + + Map asMap(); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java index bd2f33926ccd8..c0279548f4c7c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java @@ -42,6 +42,7 @@ import org.elasticsearch.tasks.Task; import org.elasticsearch.xcontent.XContent; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError; +import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder; import org.elasticsearch.xpack.inference.mapper.SemanticTextField; import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper; import org.elasticsearch.xpack.inference.mapper.SemanticTextUtils; @@ -367,7 +368,7 @@ private void onFinish() { null, inputs, Map.of(), - null, + inferenceProvider.model().getConfigurations().getChunkingSettings(), // TODO Override here InputType.INGEST, TimeValue.MAX_VALUE, completionListener @@ -451,7 +452,8 @@ private void applyInferenceResponses(BulkItemRequest item, FieldInferenceRespons model != null ? 
new MinimalServiceSettings(model) : null, chunkMap ), - indexRequest.getContentType() + indexRequest.getContentType(), + ChunkingSettingsBuilder.fromMap(inferenceFieldMetadata.getChunkingSettings()) ); if (useLegacyFormat) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkingSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkingSettings.java index 9d6f5bb89218f..0dbf9d60b5d32 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkingSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkingSettings.java @@ -22,6 +22,7 @@ import java.io.IOException; import java.util.Arrays; +import java.util.Locale; import java.util.Map; import java.util.Objects; import java.util.Set; @@ -54,6 +55,17 @@ public SentenceBoundaryChunkingSettings(StreamInput in) throws IOException { } } + public Map asMap() { + return Map.of( + ChunkingSettingsOptions.STRATEGY.toString(), + STRATEGY.toString().toLowerCase(Locale.ROOT), + ChunkingSettingsOptions.MAX_CHUNK_SIZE.toString(), + maxChunkSize, + ChunkingSettingsOptions.SENTENCE_OVERLAP.toString(), + sentenceOverlap + ); + } + public static SentenceBoundaryChunkingSettings fromMap(Map map) { ValidationException validationException = new ValidationException(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunkingSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunkingSettings.java index 7e0378d5b0cd1..1987ffb8278b6 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunkingSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunkingSettings.java @@ -21,6 +21,7 @@ import java.io.IOException; import java.util.Arrays; +import java.util.Locale; import java.util.Map; import java.util.Objects; import java.util.Set; @@ -48,6 +49,17 @@ public WordBoundaryChunkingSettings(StreamInput in) throws IOException { overlap = in.readInt(); } + public Map asMap() { + return Map.of( + ChunkingSettingsOptions.STRATEGY.toString(), + STRATEGY.toString().toLowerCase(Locale.ROOT), + ChunkingSettingsOptions.MAX_CHUNK_SIZE.toString(), + maxChunkSize, + ChunkingSettingsOptions.OVERLAP.toString(), + overlap + ); + } + public static WordBoundaryChunkingSettings fromMap(Map map) { ValidationException validationException = new ValidationException(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java index 489951a206149..6072867acd986 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java @@ -15,6 +15,10 @@ import org.elasticsearch.core.Nullable; import org.elasticsearch.index.IndexVersions; import org.elasticsearch.inference.ChunkedInference; +import org.elasticsearch.inference.ChunkingSettings; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.inference.TaskType; import 
org.elasticsearch.inference.MinimalServiceSettings; import org.elasticsearch.xcontent.ConstructingObjectParser; import org.elasticsearch.xcontent.DeprecationHandler; @@ -55,7 +59,8 @@ public record SemanticTextField( String fieldName, @Nullable List originalValues, InferenceResult inference, - XContentType contentType + XContentType contentType, + @Nullable ChunkingSettings chunkingSettings ) implements ToXContentObject { static final String TEXT_FIELD = "text"; @@ -69,6 +74,16 @@ public record SemanticTextField( static final String CHUNKED_START_OFFSET_FIELD = "start_offset"; static final String CHUNKED_END_OFFSET_FIELD = "end_offset"; static final String MODEL_SETTINGS_FIELD = "model_settings"; + static final String TASK_TYPE_FIELD = "task_type"; + static final String DIMENSIONS_FIELD = "dimensions"; + static final String SIMILARITY_FIELD = "similarity"; + static final String ELEMENT_TYPE_FIELD = "element_type"; + // Chunking settings + static final String CHUNKING_SETTINGS_FIELD = "chunking_settings"; + static final String STRATEGY_FIELD = "strategy"; + static final String MAX_CHUNK_SIZE_FIELD = "max_chunk_size"; + static final String OVERLAP_FIELD = "overlap"; + static final String SENTENCE_OVERLAP_FIELD = "sentence_overlap"; public record InferenceResult(String inferenceId, MinimalServiceSettings modelSettings, Map> chunks) {} @@ -135,6 +150,10 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.startObject(INFERENCE_FIELD); builder.field(INFERENCE_ID_FIELD, inference.inferenceId); builder.field(MODEL_SETTINGS_FIELD, inference.modelSettings); + if (chunkingSettings != null) { + builder.field(CHUNKING_SETTINGS_FIELD, chunkingSettings); + } + if (useLegacyFormat) { builder.startArray(CHUNKS_FIELD); } else { @@ -189,7 +208,8 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws context.fieldName(), originalValues, (InferenceResult) args[1], - context.xContentType() + context.xContentType(), + (ChunkingSettings) args[2] ); }); @@ -212,9 +232,39 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws } ); + private static final ConstructingObjectParser MODEL_SETTINGS_PARSER = new ConstructingObjectParser<>( + MODEL_SETTINGS_FIELD, + true, + args -> { + TaskType taskType = TaskType.fromString((String) args[0]); + Integer dimensions = (Integer) args[1]; + SimilarityMeasure similarity = args[2] == null ? null : SimilarityMeasure.fromString((String) args[2]); + DenseVectorFieldMapper.ElementType elementType = args[3] == null + ? 
null + : DenseVectorFieldMapper.ElementType.fromString((String) args[3]); + return new ModelSettings(taskType, dimensions, similarity, elementType); + } + ); + + private static final ConstructingObjectParser, Void> CHUNKING_SETTINGS_PARSER = new ConstructingObjectParser<>( + CHUNKING_SETTINGS_FIELD, + true, + args -> { + @SuppressWarnings("unchecked") + Map map = (Map) args[0]; + return map; + } + ); + static { SEMANTIC_TEXT_FIELD_PARSER.declareStringArray(optionalConstructorArg(), new ParseField(TEXT_FIELD)); SEMANTIC_TEXT_FIELD_PARSER.declareObject(constructorArg(), INFERENCE_RESULT_PARSER, new ParseField(INFERENCE_FIELD)); + SEMANTIC_TEXT_FIELD_PARSER.declareObjectOrNull( + optionalConstructorArg(), + (p, c) -> CHUNKING_SETTINGS_PARSER.parse(p, null), + null, + new ParseField(CHUNKING_SETTINGS_FIELD) + ); INFERENCE_RESULT_PARSER.declareString(constructorArg(), new ParseField(INFERENCE_ID_FIELD)); INFERENCE_RESULT_PARSER.declareObjectOrNull( diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java index 3bebd8086d792..3efd4870cca87 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java @@ -58,6 +58,7 @@ import org.elasticsearch.index.query.NestedQueryBuilder; import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.SearchExecutionContext; +import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceResults; import org.elasticsearch.inference.MinimalServiceSettings; import org.elasticsearch.inference.SimilarityMeasure; @@ -73,6 +74,7 @@ import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults; import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; import org.elasticsearch.xpack.core.ml.search.SparseVectorQueryBuilder; +import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder; import org.elasticsearch.xpack.inference.highlight.SemanticTextHighlighter; import java.io.IOException; @@ -94,6 +96,7 @@ import static org.elasticsearch.search.SearchService.DEFAULT_SIZE; import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.CHUNKED_EMBEDDINGS_FIELD; import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.CHUNKED_OFFSET_FIELD; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.CHUNKING_SETTINGS_FIELD; import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.CHUNKS_FIELD; import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.INFERENCE_FIELD; import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.INFERENCE_ID_FIELD; @@ -175,6 +178,17 @@ public static class Builder extends FieldMapper.Builder { Objects::toString ).acceptsNull().setMergeValidator(SemanticTextFieldMapper::canMergeModelSettings); + @SuppressWarnings("unchecked") + private final Parameter chunkingSettings = new Parameter<>( + CHUNKING_SETTINGS_FIELD, + true, + () -> null, + (n, c, o) -> ChunkingSettingsBuilder.fromMap((Map) o), + mapper -> ((SemanticTextFieldType) mapper.fieldType()).chunkingSettings, + XContentBuilder::field, + Objects::toString + ).acceptsNull(); + private final Parameter> meta = Parameter.metaParam(); private Function 
inferenceFieldBuilder; @@ -217,9 +231,14 @@ public Builder setModelSettings(MinimalServiceSettings value) { return this; } + public Builder setChunkingSettings(ChunkingSettings value) { + this.chunkingSettings.setValue(value); + return this; + } + @Override protected Parameter[] getParameters() { - return new Parameter[] { inferenceId, searchInferenceId, modelSettings, meta }; + return new Parameter[] { inferenceId, searchInferenceId, modelSettings, chunkingSettings, meta }; } @Override @@ -235,6 +254,7 @@ protected void merge(FieldMapper mergeWith, Conflicts conflicts, MapperMergeCont inferenceFieldBuilder = c -> mergedInferenceField; } + @SuppressWarnings("unchecked") @Override public SemanticTextFieldMapper build(MapperBuilderContext context) { if (useLegacyFormat && copyTo.copyToFields().isEmpty() == false) { @@ -261,6 +281,7 @@ public SemanticTextFieldMapper build(MapperBuilderContext context) { inferenceId.getValue(), searchInferenceId.getValue(), modelSettings.getValue(), + chunkingSettings.getValue(), inferenceField, useLegacyFormat, meta.getValue() @@ -294,11 +315,13 @@ private void validateServiceSettings(MinimalServiceSettings settings) { * @param mapper The mapper * @return A mapper with the copied settings applied */ + @SuppressWarnings("unchecked") private SemanticTextFieldMapper copySettings(SemanticTextFieldMapper mapper, MapperMergeContext mapperMergeContext) { SemanticTextFieldMapper returnedMapper = mapper; if (mapper.fieldType().getModelSettings() == null) { Builder builder = from(mapper); builder.setModelSettings(modelSettings.getValue()); + builder.setChunkingSettings(chunkingSettings.getValue()); returnedMapper = builder.build(mapperMergeContext.getMapperBuilderContext()); } @@ -519,7 +542,13 @@ public InferenceFieldMetadata getMetadata(Set sourcePaths) { String[] copyFields = sourcePaths.toArray(String[]::new); // ensure consistent order Arrays.sort(copyFields); - return new InferenceFieldMetadata(fullPath(), fieldType().getInferenceId(), fieldType().getSearchInferenceId(), copyFields); + return new InferenceFieldMetadata( + fullPath(), + fieldType().getInferenceId(), + fieldType().getSearchInferenceId(), + copyFields, + fieldType().getChunkingSettings() != null ? 
fieldType().getChunkingSettings().asMap() : null + ); } @Override @@ -548,6 +577,7 @@ public static class SemanticTextFieldType extends SimpleMappedFieldType { private final String inferenceId; private final String searchInferenceId; private final MinimalServiceSettings modelSettings; + private final ChunkingSettings chunkingSettings; private final ObjectMapper inferenceField; private final boolean useLegacyFormat; @@ -556,6 +586,7 @@ public SemanticTextFieldType( String inferenceId, String searchInferenceId, MinimalServiceSettings modelSettings, + ChunkingSettings chunkingSettings, ObjectMapper inferenceField, boolean useLegacyFormat, Map meta @@ -564,6 +595,7 @@ public SemanticTextFieldType( this.inferenceId = inferenceId; this.searchInferenceId = searchInferenceId; this.modelSettings = modelSettings; + this.chunkingSettings = chunkingSettings; this.inferenceField = inferenceField; this.useLegacyFormat = useLegacyFormat; } @@ -599,6 +631,10 @@ public MinimalServiceSettings getModelSettings() { return modelSettings; } + public ChunkingSettings getChunkingSettings() { + return chunkingSettings; + } + public ObjectMapper getInferenceField() { return inferenceField; } @@ -869,7 +905,8 @@ public List fetchValues(Source source, int doc, List ignoredValu name(), null, new SemanticTextField.InferenceResult(inferenceId, modelSettings, chunkMap), - source.sourceContentType() + source.sourceContentType(), + chunkingSettings ) ); } From 6b11f01efb23c434be9e978ded72401b84b5a178 Mon Sep 17 00:00:00 2001 From: Kathleen DeRusso Date: Mon, 27 Jan 2025 13:08:56 -0500 Subject: [PATCH 05/86] Fix test compilation errors --- .../cluster/metadata/IndexMetadataTests.java | 3 +- .../metadata/InferenceFieldMetadataTests.java | 44 ++++++++++++++++--- ...appingLookupInferenceFieldMapperTests.java | 9 +++- ...KnnVectorQueryRewriteInterceptorTests.java | 4 +- ...nticMatchQueryRewriteInterceptorTests.java | 2 +- ...rseVectorQueryRewriteInterceptorTests.java | 4 +- .../ShardBulkInferenceActionFilterTests.java | 14 +++--- .../mapper/SemanticTextFieldMapperTests.java | 8 ++-- .../mapper/SemanticTextFieldTests.java | 15 ++++++- .../queries/SemanticQueryBuilderTests.java | 4 +- 10 files changed, 83 insertions(+), 24 deletions(-) diff --git a/server/src/test/java/org/elasticsearch/cluster/metadata/IndexMetadataTests.java b/server/src/test/java/org/elasticsearch/cluster/metadata/IndexMetadataTests.java index 4abd0c4a9d469..a976e37ee2cf1 100644 --- a/server/src/test/java/org/elasticsearch/cluster/metadata/IndexMetadataTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/metadata/IndexMetadataTests.java @@ -694,7 +694,8 @@ private static InferenceFieldMetadata randomInferenceFieldMetadata(String name) name, randomIdentifier(), randomIdentifier(), - randomSet(1, 5, ESTestCase::randomIdentifier).toArray(String[]::new) + randomSet(1, 5, ESTestCase::randomIdentifier).toArray(String[]::new), + InferenceFieldMetadataTests.generateRandomChunkingSettings() ); } diff --git a/server/src/test/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadataTests.java b/server/src/test/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadataTests.java index 2d5805696320d..b9ff79dddc894 100644 --- a/server/src/test/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadataTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadataTests.java @@ -11,10 +11,12 @@ import org.elasticsearch.common.io.stream.BytesStreamOutput; import org.elasticsearch.common.io.stream.StreamInput; +import 
org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.test.AbstractXContentTestCase; import org.elasticsearch.xcontent.XContentParser; import java.io.IOException; +import java.util.Map; import java.util.function.Predicate; import static org.hamcrest.Matchers.equalTo; @@ -63,13 +65,45 @@ private static InferenceFieldMetadata createTestItem() { String inferenceId = randomIdentifier(); String searchInferenceId = randomIdentifier(); String[] inputFields = generateRandomStringArray(5, 10, false, false); - return new InferenceFieldMetadata(name, inferenceId, searchInferenceId, inputFields); + Map chunkingSettings = generateRandomChunkingSettings(); + return new InferenceFieldMetadata(name, inferenceId, searchInferenceId, inputFields, chunkingSettings); + } + + public static Map generateRandomChunkingSettings() { + if (randomBoolean()) { + return null; // Defaults to model chunking settings + } + return randomBoolean() ? generateRandomWordBoundaryChunkingSettings() : generateRandomSentenceBoundaryChunkingSettings(); + } + + private static Map generateRandomWordBoundaryChunkingSettings() { + return Map.of("strategy", "word_boundary", "max_chunk_size", randomIntBetween(20, 100), "overlap", randomIntBetween(0, 50)); + } + + private static Map generateRandomSentenceBoundaryChunkingSettings() { + return Map.of( + "strategy", + "sentence_boundary", + "max_chunk_size", + randomIntBetween(20, 100), + "sentence_overlap", + randomIntBetween(0, 1) + ); } public void testNullCtorArgsThrowException() { - assertThrows(NullPointerException.class, () -> new InferenceFieldMetadata(null, "inferenceId", "searchInferenceId", new String[0])); - assertThrows(NullPointerException.class, () -> new InferenceFieldMetadata("name", null, "searchInferenceId", new String[0])); - assertThrows(NullPointerException.class, () -> new InferenceFieldMetadata("name", "inferenceId", null, new String[0])); - assertThrows(NullPointerException.class, () -> new InferenceFieldMetadata("name", "inferenceId", "searchInferenceId", null)); + assertThrows( + NullPointerException.class, + () -> new InferenceFieldMetadata(null, "inferenceId", "searchInferenceId", new String[0], Map.of()) + ); + assertThrows( + NullPointerException.class, + () -> new InferenceFieldMetadata("name", null, "searchInferenceId", new String[0], Map.of()) + ); + assertThrows(NullPointerException.class, () -> new InferenceFieldMetadata("name", "inferenceId", null, new String[0], Map.of())); + assertThrows( + NullPointerException.class, + () -> new InferenceFieldMetadata("name", "inferenceId", "searchInferenceId", null, Map.of()) + ); } } diff --git a/server/src/test/java/org/elasticsearch/index/mapper/MappingLookupInferenceFieldMapperTests.java b/server/src/test/java/org/elasticsearch/index/mapper/MappingLookupInferenceFieldMapperTests.java index 755b83e8eb7ad..93ac31c9ba582 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/MappingLookupInferenceFieldMapperTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/MappingLookupInferenceFieldMapperTests.java @@ -11,6 +11,7 @@ import org.apache.lucene.search.Query; import org.elasticsearch.cluster.metadata.InferenceFieldMetadata; +import org.elasticsearch.cluster.metadata.InferenceFieldMetadataTests; import org.elasticsearch.index.query.SearchExecutionContext; import org.elasticsearch.plugins.MapperPlugin; import org.elasticsearch.plugins.Plugin; @@ -102,7 +103,13 @@ private static class TestInferenceFieldMapper extends FieldMapper implements Inf @Override public 
InferenceFieldMetadata getMetadata(Set sourcePaths) { - return new InferenceFieldMetadata(fullPath(), INFERENCE_ID, SEARCH_INFERENCE_ID, sourcePaths.toArray(new String[0])); + return new InferenceFieldMetadata( + fullPath(), + INFERENCE_ID, + SEARCH_INFERENCE_ID, + sourcePaths.toArray(new String[0]), + InferenceFieldMetadataTests.generateRandomChunkingSettings() + ); } @Override diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/index/query/SemanticKnnVectorQueryRewriteInterceptorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/index/query/SemanticKnnVectorQueryRewriteInterceptorTests.java index 073bf8f5afb9a..270cdba6d3469 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/index/query/SemanticKnnVectorQueryRewriteInterceptorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/index/query/SemanticKnnVectorQueryRewriteInterceptorTests.java @@ -56,7 +56,7 @@ public void cleanup() { public void testKnnQueryWithVectorBuilderIsInterceptedAndRewritten() throws IOException { Map inferenceFields = Map.of( FIELD_NAME, - new InferenceFieldMetadata(index.getName(), INFERENCE_ID, new String[] { FIELD_NAME }) + new InferenceFieldMetadata(index.getName(), INFERENCE_ID, new String[] { FIELD_NAME }, null) ); QueryRewriteContext context = createQueryRewriteContext(inferenceFields); QueryVectorBuilder queryVectorBuilder = new TextEmbeddingQueryVectorBuilder(INFERENCE_ID, QUERY); @@ -67,7 +67,7 @@ public void testKnnQueryWithVectorBuilderIsInterceptedAndRewritten() throws IOEx public void testKnnWithQueryBuilderWithoutInferenceIdIsInterceptedAndRewritten() throws IOException { Map inferenceFields = Map.of( FIELD_NAME, - new InferenceFieldMetadata(index.getName(), INFERENCE_ID, new String[] { FIELD_NAME }) + new InferenceFieldMetadata(index.getName(), INFERENCE_ID, new String[] { FIELD_NAME }, null) ); QueryRewriteContext context = createQueryRewriteContext(inferenceFields); QueryVectorBuilder queryVectorBuilder = new TextEmbeddingQueryVectorBuilder(null, QUERY); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/index/query/SemanticMatchQueryRewriteInterceptorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/index/query/SemanticMatchQueryRewriteInterceptorTests.java index 47705c14d5941..6987ef33ed63d 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/index/query/SemanticMatchQueryRewriteInterceptorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/index/query/SemanticMatchQueryRewriteInterceptorTests.java @@ -52,7 +52,7 @@ public void cleanup() { public void testMatchQueryOnInferenceFieldIsInterceptedAndRewrittenToSemanticQuery() throws IOException { Map inferenceFields = Map.of( FIELD_NAME, - new InferenceFieldMetadata(index.getName(), "inferenceId", new String[] { FIELD_NAME }) + new InferenceFieldMetadata(index.getName(), "inferenceId", new String[] { FIELD_NAME }, null) ); QueryRewriteContext context = createQueryRewriteContext(inferenceFields); QueryBuilder original = createTestQueryBuilder(); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/index/query/SemanticSparseVectorQueryRewriteInterceptorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/index/query/SemanticSparseVectorQueryRewriteInterceptorTests.java index 1adad1df7b29b..075955766a0a9 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/index/query/SemanticSparseVectorQueryRewriteInterceptorTests.java +++ 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/index/query/SemanticSparseVectorQueryRewriteInterceptorTests.java @@ -54,7 +54,7 @@ public void cleanup() { public void testSparseVectorQueryOnInferenceFieldIsInterceptedAndRewritten() throws IOException { Map inferenceFields = Map.of( FIELD_NAME, - new InferenceFieldMetadata(index.getName(), "inferenceId", new String[] { FIELD_NAME }) + new InferenceFieldMetadata(index.getName(), "inferenceId", new String[] { FIELD_NAME }, null) ); QueryRewriteContext context = createQueryRewriteContext(inferenceFields); QueryBuilder original = new SparseVectorQueryBuilder(FIELD_NAME, INFERENCE_ID, QUERY); @@ -78,7 +78,7 @@ public void testSparseVectorQueryOnInferenceFieldIsInterceptedAndRewritten() thr public void testSparseVectorQueryOnInferenceFieldWithoutInferenceIdIsInterceptedAndRewritten() throws IOException { Map inferenceFields = Map.of( FIELD_NAME, - new InferenceFieldMetadata(index.getName(), "inferenceId", new String[] { FIELD_NAME }) + new InferenceFieldMetadata(index.getName(), "inferenceId", new String[] { FIELD_NAME }, null) ); QueryRewriteContext context = createQueryRewriteContext(inferenceFields); QueryBuilder original = new SparseVectorQueryBuilder(FIELD_NAME, null, QUERY); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java index 1fca17f77ad9a..a32fbb4ffd793 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java @@ -130,7 +130,7 @@ public void testFilterNoop() throws Exception { new BulkItemRequest[0] ); request.setInferenceFieldMap( - Map.of("foo", new InferenceFieldMetadata("foo", "bar", "baz", generateRandomStringArray(5, 10, false, false))) + Map.of("foo", new InferenceFieldMetadata("foo", "bar", "baz", generateRandomStringArray(5, 10, false, false), null)) ); filter.apply(task, TransportShardBulkAction.ACTION_NAME, request, actionListener, actionFilterChain); awaitLatch(chainExecuted, 10, TimeUnit.SECONDS); @@ -165,11 +165,11 @@ public void testInferenceNotFound() throws Exception { Map inferenceFieldMap = Map.of( "field1", - new InferenceFieldMetadata("field1", model.getInferenceEntityId(), new String[] { "field1" }), + new InferenceFieldMetadata("field1", model.getInferenceEntityId(), new String[] { "field1" }, null), "field2", - new InferenceFieldMetadata("field2", "inference_0", new String[] { "field2" }), + new InferenceFieldMetadata("field2", "inference_0", new String[] { "field2" }, null), "field3", - new InferenceFieldMetadata("field3", "inference_0", new String[] { "field3" }) + new InferenceFieldMetadata("field3", "inference_0", new String[] { "field3" }, null) ); BulkItemRequest[] items = new BulkItemRequest[10]; for (int i = 0; i < items.length; i++) { @@ -233,7 +233,7 @@ public void testItemFailures() throws Exception { Map inferenceFieldMap = Map.of( "field1", - new InferenceFieldMetadata("field1", model.getInferenceEntityId(), new String[] { "field1" }) + new InferenceFieldMetadata("field1", model.getInferenceEntityId(), new String[] { "field1" }, null) ); BulkItemRequest[] items = new BulkItemRequest[3]; items[0] = new BulkItemRequest(0, new IndexRequest("index").source("field1", 
"I am a failure")); @@ -301,7 +301,7 @@ public void testExplicitNull() throws Exception { Map inferenceFieldMap = Map.of( "obj.field1", - new InferenceFieldMetadata("obj.field1", model.getInferenceEntityId(), new String[] { "obj.field1" }) + new InferenceFieldMetadata("obj.field1", model.getInferenceEntityId(), new String[] { "obj.field1" }, null) ); Map sourceWithNull = new HashMap<>(); sourceWithNull.put("field1", null); @@ -332,7 +332,7 @@ public void testManyRandomDocs() throws Exception { for (int i = 0; i < numInferenceFields; i++) { String field = randomAlphaOfLengthBetween(5, 10); String inferenceId = randomFrom(inferenceModelMap.keySet()); - inferenceFieldMap.put(field, new InferenceFieldMetadata(field, inferenceId, new String[] { field })); + inferenceFieldMap.put(field, new InferenceFieldMetadata(field, inferenceId, new String[] { field }, null)); } int numRequests = atLeast(100); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java index e837e1b0db989..f0d767cd2262e 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java @@ -183,7 +183,7 @@ protected IngestScriptSupport ingestScriptSupport() { @Override public MappedFieldType getMappedFieldType() { - return new SemanticTextFieldMapper.SemanticTextFieldType("field", "fake-inference-id", null, null, null, false, Map.of()); + return new SemanticTextFieldMapper.SemanticTextFieldType("field", "fake-inference-id", null, null, null, null, false, Map.of()); } @Override @@ -857,7 +857,8 @@ public void testModelSettingsRequiredWithChunks() throws IOException { null, randomSemanticText.inference().chunks() ), - randomSemanticText.contentType() + randomSemanticText.contentType(), + SemanticTextFieldTests.generateRandomChunkingSettings() ); MapperService mapperService = createMapperService( @@ -901,7 +902,8 @@ private MapperService mapperServiceForFieldWithModelSettings( fieldName, List.of(), new SemanticTextField.InferenceResult(inferenceId, modelSettings, Map.of()), - XContentType.JSON + XContentType.JSON, + SemanticTextFieldTests.generateRandomChunkingSettings() ); XContentBuilder builder = JsonXContent.contentBuilder().startObject(); if (useLegacyFormat) { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java index 404713581eddd..4441d67fa7eef 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java @@ -13,6 +13,7 @@ import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.inference.ChunkedInference; +import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.MinimalServiceSettings; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.SimilarityMeasure; @@ -26,6 +27,8 @@ import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingSparse; import 
org.elasticsearch.xpack.core.ml.search.WeightedToken; import org.elasticsearch.xpack.core.utils.FloatConversionUtils; +import org.elasticsearch.xpack.inference.chunking.SentenceBoundaryChunkingSettings; +import org.elasticsearch.xpack.inference.chunking.WordBoundaryChunkingSettings; import org.elasticsearch.xpack.inference.model.TestModel; import java.io.IOException; @@ -273,10 +276,20 @@ public static SemanticTextField semanticTextFieldFromChunkedInferenceResults( new MinimalServiceSettings(model), Map.of(fieldName, chunks) ), - contentType + contentType, + generateRandomChunkingSettings() ); } + public static ChunkingSettings generateRandomChunkingSettings() { + if (randomBoolean()) { + return null; // Use model defaults + } + return randomBoolean() + ? new WordBoundaryChunkingSettings(randomIntBetween(20, 100), randomIntBetween(0, 50)) + : new SentenceBoundaryChunkingSettings(randomIntBetween(20, 100), randomIntBetween(0, 1)); + } + /** * Returns a randomly generated object for Semantic Text tests purpose. */ diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java index 9a3b4eff1958a..d34244bd5ce5f 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java @@ -59,6 +59,7 @@ import org.elasticsearch.xpack.core.ml.search.WeightedToken; import org.elasticsearch.xpack.inference.InferencePlugin; import org.elasticsearch.xpack.inference.mapper.SemanticTextField; +import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests; import org.junit.Before; import org.junit.BeforeClass; @@ -370,7 +371,8 @@ private static SourceToParse buildSemanticTextFieldWithInferenceResults( SEMANTIC_TEXT_FIELD, null, new SemanticTextField.InferenceResult(INFERENCE_ID, modelSettings, Map.of(SEMANTIC_TEXT_FIELD, List.of())), - XContentType.JSON + XContentType.JSON, + SemanticTextFieldTests.generateRandomChunkingSettings() ); XContentBuilder builder = JsonXContent.contentBuilder().startObject(); From 7c4ba49aaeb9e42941008444a3fb20d9adb95457 Mon Sep 17 00:00:00 2001 From: Kathleen DeRusso Date: Mon, 27 Jan 2025 15:47:37 -0500 Subject: [PATCH 06/86] Hacking around trying to get ingest to work --- .../metadata/InferenceFieldMetadata.java | 2 +- .../ShardBulkInferenceActionFilter.java | 33 ++++++++++++------- .../SentenceBoundaryChunkingSettings.java | 14 ++++---- .../WordBoundaryChunkingSettings.java | 14 ++++---- .../mapper/SemanticTextFieldMapper.java | 5 +++ .../inference/services/SenderService.java | 10 +++++- 6 files changed, 48 insertions(+), 30 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadata.java b/server/src/main/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadata.java index f69c1632e324c..caa2217266b2f 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadata.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadata.java @@ -21,7 +21,6 @@ import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; -import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Objects; @@ -183,6 +182,7 @@ public static InferenceFieldMetadata fromXContent(XContentParser parser) throws } 
} else if (CHUNKING_SETTINGS_FIELD.equals(currentFieldName)) { chunkingSettings = parser.map(); + System.out.println("foo"); } else { parser.skipChildren(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java index c0279548f4c7c..3cb3174697f01 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java @@ -32,6 +32,7 @@ import org.elasticsearch.core.TimeValue; import org.elasticsearch.index.mapper.InferenceMetadataFieldsMapper; import org.elasticsearch.inference.ChunkedInference; +import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceRegistry; import org.elasticsearch.inference.InputType; @@ -134,7 +135,7 @@ private void processBulkShardRequest( new AsyncBulkShardInferenceAction(useLegacyFormat, fieldInferenceMap, bulkShardRequest, onCompletion).run(); } - private record InferenceProvider(InferenceService service, Model model) {} + private record InferenceProvider(InferenceService service, Model model, ChunkingSettings chunkingSettings) {} /** * A field inference request on a single input. @@ -242,16 +243,22 @@ private void executeShardBulkInferenceAsync( public void onResponse(UnparsedModel unparsedModel) { var service = inferenceServiceRegistry.getService(unparsedModel.service()); if (service.isEmpty() == false) { - var provider = new InferenceProvider( - service.get(), - service.get() - .parsePersistedConfigWithSecrets( - inferenceId, - unparsedModel.taskType(), - unparsedModel.settings(), - unparsedModel.secrets() - ) + InferenceService inferenceService = service.get(); + Model model = inferenceService.parsePersistedConfigWithSecrets( + inferenceId, + unparsedModel.taskType(), + unparsedModel.settings(), + unparsedModel.secrets() ); + // This assumes that all fields will have the same chunking settings - supporting per field chunking settings + // seems like a pretty big refactor + Map overrideChunkingSettings = fieldInferenceMap.get(requests.getFirst().field()) + .getChunkingSettings(); + ChunkingSettings chunkingSettings = overrideChunkingSettings != null + ? ChunkingSettingsBuilder.fromMap(overrideChunkingSettings) + : model.getConfigurations().getChunkingSettings(); + + var provider = new InferenceProvider(inferenceService, model, chunkingSettings); executeShardBulkInferenceAsync(inferenceId, provider, requests, onFinish); } else { try (onFinish) { @@ -368,7 +375,7 @@ private void onFinish() { null, inputs, Map.of(), - inferenceProvider.model().getConfigurations().getChunkingSettings(), // TODO Override here + inferenceProvider.chunkingSettings(), InputType.INGEST, TimeValue.MAX_VALUE, completionListener @@ -453,7 +460,9 @@ private void applyInferenceResponses(BulkItemRequest item, FieldInferenceRespons chunkMap ), indexRequest.getContentType(), - ChunkingSettingsBuilder.fromMap(inferenceFieldMetadata.getChunkingSettings()) + inferenceFieldMetadata.getChunkingSettings() != null + ? 
ChunkingSettingsBuilder.fromMap(inferenceFieldMetadata.getChunkingSettings()) + : null ); if (useLegacyFormat) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkingSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkingSettings.java index 0dbf9d60b5d32..74568e8f3690a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkingSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkingSettings.java @@ -22,6 +22,7 @@ import java.io.IOException; import java.util.Arrays; +import java.util.HashMap; import java.util.Locale; import java.util.Map; import java.util.Objects; @@ -56,14 +57,11 @@ public SentenceBoundaryChunkingSettings(StreamInput in) throws IOException { } public Map asMap() { - return Map.of( - ChunkingSettingsOptions.STRATEGY.toString(), - STRATEGY.toString().toLowerCase(Locale.ROOT), - ChunkingSettingsOptions.MAX_CHUNK_SIZE.toString(), - maxChunkSize, - ChunkingSettingsOptions.SENTENCE_OVERLAP.toString(), - sentenceOverlap - ); + Map map = new HashMap<>(); + map.put(ChunkingSettingsOptions.STRATEGY.toString(), STRATEGY.toString().toLowerCase(Locale.ROOT)); + map.put(ChunkingSettingsOptions.MAX_CHUNK_SIZE.toString(), maxChunkSize); + map.put(ChunkingSettingsOptions.SENTENCE_OVERLAP.toString(), sentenceOverlap); + return map; } public static SentenceBoundaryChunkingSettings fromMap(Map map) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunkingSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunkingSettings.java index 1987ffb8278b6..e4d5732d23b3f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunkingSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunkingSettings.java @@ -21,6 +21,7 @@ import java.io.IOException; import java.util.Arrays; +import java.util.HashMap; import java.util.Locale; import java.util.Map; import java.util.Objects; @@ -50,14 +51,11 @@ public WordBoundaryChunkingSettings(StreamInput in) throws IOException { } public Map asMap() { - return Map.of( - ChunkingSettingsOptions.STRATEGY.toString(), - STRATEGY.toString().toLowerCase(Locale.ROOT), - ChunkingSettingsOptions.MAX_CHUNK_SIZE.toString(), - maxChunkSize, - ChunkingSettingsOptions.OVERLAP.toString(), - overlap - ); + Map map = new HashMap<>(); + map.put(ChunkingSettingsOptions.STRATEGY.toString(), STRATEGY.toString().toLowerCase(Locale.ROOT)); + map.put(ChunkingSettingsOptions.MAX_CHUNK_SIZE.toString(), maxChunkSize); + map.put(ChunkingSettingsOptions.OVERLAP.toString(), overlap); + return map; } public static WordBoundaryChunkingSettings fromMap(Map map) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java index 3efd4870cca87..d1dbc64cbd578 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java @@ -542,6 +542,11 @@ public InferenceFieldMetadata getMetadata(Set 
sourcePaths) { String[] copyFields = sourcePaths.toArray(String[]::new); // ensure consistent order Arrays.sort(copyFields); + + SemanticTextFieldType fieldType = fieldType(); + ChunkingSettings fieldTypeChunkingSettings = fieldType.getChunkingSettings(); + Map asMap = fieldTypeChunkingSettings != null ? fieldTypeChunkingSettings.asMap() : null; + return new InferenceFieldMetadata( fullPath(), fieldType().getInferenceId(), diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java index fa2ccc5e9818b..eaf5872a720ab 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java @@ -108,7 +108,15 @@ public void chunkedInfer( ) { init(); // a non-null query is not supported and is dropped by all providers - doChunkedInfer(model, new DocumentsOnlyInput(input), taskSettings, chunkingSettings, inputType, timeout, listener); + doChunkedInfer( + model, + new DocumentsOnlyInput(input), + taskSettings, + chunkingSettings != null ? chunkingSettings : model.getConfigurations().getChunkingSettings(), + inputType, + timeout, + listener + ); } protected abstract void doInfer( From edbfde3545d6eaca7897571f7f5b98cf6aa3f208 Mon Sep 17 00:00:00 2001 From: Kathleen DeRusso Date: Mon, 27 Jan 2025 16:22:16 -0500 Subject: [PATCH 07/86] Debugging --- .../cluster/metadata/InferenceFieldMetadata.java | 6 +++++- .../action/filter/ShardBulkInferenceActionFilter.java | 3 +++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadata.java b/server/src/main/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadata.java index caa2217266b2f..e50ca90c644ee 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadata.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadata.java @@ -62,6 +62,11 @@ public InferenceFieldMetadata( this.searchInferenceId = Objects.requireNonNull(searchInferenceId); this.sourceFields = Objects.requireNonNull(sourceFields); this.chunkingSettings = chunkingSettings; + + // TODO remove this, trying to get stack traces where this called + if (chunkingSettings != null && chunkingSettings.size() != 3) { + throw new IllegalArgumentException("Chunking settings must contain exactly 3 settings"); + } } public InferenceFieldMetadata(StreamInput input) throws IOException { @@ -182,7 +187,6 @@ public static InferenceFieldMetadata fromXContent(XContentParser parser) throws } } else if (CHUNKING_SETTINGS_FIELD.equals(currentFieldName)) { chunkingSettings = parser.map(); - System.out.println("foo"); } else { parser.skipChildren(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java index 3cb3174697f01..0b7c1baafef30 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java @@ -448,6 +448,9 @@ private void applyInferenceResponses(BulkItemRequest item, 
FieldInferenceRespons .map(r -> r.input) .collect(Collectors.toList()); + // TODO remove this, used for easy debugging comparisons + Map inferenceChunkingSettings = inferenceFieldMetadata.getChunkingSettings(); + // The model can be null if we are only processing update requests that clear inference results. This is ok because we will // merge in the field's existing model settings on the data node. var result = new SemanticTextField( From c255745a883dfe4d01d27b559e7a2f60355bdac5 Mon Sep 17 00:00:00 2001 From: elasticsearchmachine Date: Tue, 28 Jan 2025 14:21:37 +0000 Subject: [PATCH 08/86] [CI] Auto commit changes from spotless --- .../cluster/metadata/InferenceFieldMetadataTests.java | 1 - 1 file changed, 1 deletion(-) diff --git a/server/src/test/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadataTests.java b/server/src/test/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadataTests.java index b9ff79dddc894..55ecf373810a2 100644 --- a/server/src/test/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadataTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadataTests.java @@ -11,7 +11,6 @@ import org.elasticsearch.common.io.stream.BytesStreamOutput; import org.elasticsearch.common.io.stream.StreamInput; -import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.test.AbstractXContentTestCase; import org.elasticsearch.xcontent.XContentParser; From 889875f53ebf910677b6befd9cf555049db59659 Mon Sep 17 00:00:00 2001 From: Kathleen DeRusso Date: Tue, 28 Jan 2025 09:55:14 -0500 Subject: [PATCH 09/86] POC works and update TODO to fix this --- .../elasticsearch/cluster/metadata/InferenceFieldMetadata.java | 1 + .../action/filter/ShardBulkInferenceActionFilter.java | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadata.java b/server/src/main/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadata.java index e50ca90c644ee..ffaeb6228cd36 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadata.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadata.java @@ -46,6 +46,7 @@ public final class InferenceFieldMetadata implements SimpleDiffable chunkingSettings; + // TODO can this be ChunkingSettings instead of Map? public InferenceFieldMetadata(String name, String inferenceId, String[] sourceFields, Map chunkingSettings) { this(name, inferenceId, inferenceId, sourceFields, chunkingSettings); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java index 0b7c1baafef30..a1065a0b542cb 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java @@ -145,6 +145,7 @@ private record InferenceProvider(InferenceService service, Model model, Chunking * @param input The input to run inference on. * @param inputOrder The original order of the input. * @param offsetAdjustment The adjustment to apply to the chunk text offsets. 
+ * TODO Add chunking settings here instead of provider so we can chunk based on individual field settings */ private record FieldInferenceRequest(int index, String field, String sourceField, String input, int inputOrder, int offsetAdjustment) {} @@ -255,7 +256,7 @@ public void onResponse(UnparsedModel unparsedModel) { Map overrideChunkingSettings = fieldInferenceMap.get(requests.getFirst().field()) .getChunkingSettings(); ChunkingSettings chunkingSettings = overrideChunkingSettings != null - ? ChunkingSettingsBuilder.fromMap(overrideChunkingSettings) + ? ChunkingSettingsBuilder.fromMap(new HashMap<>(overrideChunkingSettings)) : model.getConfigurations().getChunkingSettings(); var provider = new InferenceProvider(inferenceService, model, chunkingSettings); From c87753ee58676ce3780988823359aee04893be9e Mon Sep 17 00:00:00 2001 From: elasticsearchmachine Date: Tue, 28 Jan 2025 15:05:17 +0000 Subject: [PATCH 10/86] [CI] Auto commit changes from spotless --- .../inference/action/filter/ShardBulkInferenceActionFilter.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java index a1065a0b542cb..31f39092be3a8 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java @@ -145,7 +145,7 @@ private record InferenceProvider(InferenceService service, Model model, Chunking * @param input The input to run inference on. * @param inputOrder The original order of the input. * @param offsetAdjustment The adjustment to apply to the chunk text offsets. 
- * TODO Add chunking settings here instead of provider so we can chunk based on individual field settings + * TODO Add chunking settings here instead of provider so we can chunk based on individual field settings */ private record FieldInferenceRequest(int index, String field, String sourceField, String input, int inputOrder, int offsetAdjustment) {} From c2c9a5294a1840d624e8829929854f5e04609a7f Mon Sep 17 00:00:00 2001 From: Kathleen DeRusso Date: Wed, 12 Feb 2025 14:35:42 -0500 Subject: [PATCH 11/86] Refactor chunking settings from model settings to field inference request --- .../ShardBulkInferenceActionFilter.java | 82 ++++++++++++------- .../inference/mapper/SemanticTextField.java | 17 ---- .../elastic/ElasticInferenceService.java | 2 +- 3 files changed, 54 insertions(+), 47 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java index 31f39092be3a8..73204d89f3e26 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java @@ -58,6 +58,7 @@ import java.util.LinkedHashMap; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.stream.Collectors; /** @@ -135,7 +136,9 @@ private void processBulkShardRequest( new AsyncBulkShardInferenceAction(useLegacyFormat, fieldInferenceMap, bulkShardRequest, onCompletion).run(); } - private record InferenceProvider(InferenceService service, Model model, ChunkingSettings chunkingSettings) {} + private record InferenceProvider(InferenceService service, Model model) {} + + private record ChunkedInputs(ChunkingSettings chunkingSettings, List inputs) {} /** * A field inference request on a single input. @@ -145,9 +148,17 @@ private record InferenceProvider(InferenceService service, Model model, Chunking * @param input The input to run inference on. * @param inputOrder The original order of the input. * @param offsetAdjustment The adjustment to apply to the chunk text offsets. - * TODO Add chunking settings here instead of provider so we can chunk based on individual field settings + * @param chunkingSettings Additional explicitly specified chunking settings, or null to use model defaults */ - private record FieldInferenceRequest(int index, String field, String sourceField, String input, int inputOrder, int offsetAdjustment) {} + private record FieldInferenceRequest( + int index, + String field, + String sourceField, + String input, + int inputOrder, + int offsetAdjustment, + ChunkingSettings chunkingSettings + ) {} /** * The field inference response. @@ -251,15 +262,7 @@ public void onResponse(UnparsedModel unparsedModel) { unparsedModel.settings(), unparsedModel.secrets() ); - // This assumes that all fields will have the same chunking settings - supporting per field chunking settings - // seems like a pretty big refactor - Map overrideChunkingSettings = fieldInferenceMap.get(requests.getFirst().field()) - .getChunkingSettings(); - ChunkingSettings chunkingSettings = overrideChunkingSettings != null - ? 
ChunkingSettingsBuilder.fromMap(new HashMap<>(overrideChunkingSettings)) - : model.getConfigurations().getChunkingSettings(); - - var provider = new InferenceProvider(inferenceService, model, chunkingSettings); + var provider = new InferenceProvider(inferenceService, model); executeShardBulkInferenceAsync(inferenceId, provider, requests, onFinish); } else { try (onFinish) { @@ -306,7 +309,23 @@ public void onFailure(Exception exc) { int currentBatchSize = Math.min(requests.size(), batchSize); final List currentBatch = requests.subList(0, currentBatchSize); final List nextBatch = requests.subList(currentBatchSize, requests.size()); - final List inputs = currentBatch.stream().map(FieldInferenceRequest::input).collect(Collectors.toList()); + // + // List chunkedInputs = currentBatch.stream() + // .map(request -> new ChunkedInputs(request.chunkingSettings(), List.of(request.input()))) + // .toList(); + + List chunkedInputs = currentBatch.stream() + .collect(Collectors.groupingBy(request -> Optional.ofNullable(request.chunkingSettings()))) + .entrySet() + .stream() + .map( + entry -> new ChunkedInputs( + entry.getKey().orElse(null), + entry.getValue().stream().map(FieldInferenceRequest::input).collect(Collectors.toList()) + ) + ) + .toList(); + ActionListener> completionListener = new ActionListener<>() { @Override public void onResponse(List results) { @@ -370,17 +389,20 @@ private void onFinish() { } } }; - inferenceProvider.service() - .chunkedInfer( - inferenceProvider.model(), - null, - inputs, - Map.of(), - inferenceProvider.chunkingSettings(), - InputType.INGEST, - TimeValue.MAX_VALUE, - completionListener - ); + + for (ChunkedInputs chunkedInput : chunkedInputs) { + inferenceProvider.service() + .chunkedInfer( + inferenceProvider.model(), + null, + chunkedInput.inputs(), + Map.of(), + chunkedInput.chunkingSettings(), + InputType.INGEST, + TimeValue.MAX_VALUE, + completionListener + ); + } } private FieldInferenceResponseAccumulator ensureResponseAccumulatorSlot(int id) { @@ -449,9 +471,6 @@ private void applyInferenceResponses(BulkItemRequest item, FieldInferenceRespons .map(r -> r.input) .collect(Collectors.toList()); - // TODO remove this, used for easy debugging comparisons - Map inferenceChunkingSettings = inferenceFieldMetadata.getChunkingSettings(); - // The model can be null if we are only processing update requests that clear inference results. This is ok because we will // merge in the field's existing model settings on the data node. var result = new SemanticTextField( @@ -465,7 +484,7 @@ private void applyInferenceResponses(BulkItemRequest item, FieldInferenceRespons ), indexRequest.getContentType(), inferenceFieldMetadata.getChunkingSettings() != null - ? ChunkingSettingsBuilder.fromMap(inferenceFieldMetadata.getChunkingSettings()) + ? ChunkingSettingsBuilder.fromMap(new HashMap<>(inferenceFieldMetadata.getChunkingSettings())) : null ); @@ -525,6 +544,9 @@ private Map> createFieldInferenceRequests(Bu for (var entry : fieldInferenceMap.values()) { String field = entry.getName(); String inferenceId = entry.getInferenceId(); + ChunkingSettings chunkingSettings = entry.getChunkingSettings() != null + ? 
ChunkingSettingsBuilder.fromMap(new HashMap<>(entry.getChunkingSettings())) + : null; if (useLegacyFormat) { var originalFieldValue = XContentMapValues.extractValue(field, docMap); @@ -588,7 +610,9 @@ private Map> createFieldInferenceRequests(Bu List fieldRequests = fieldRequestsMap.computeIfAbsent(inferenceId, k -> new ArrayList<>()); int offsetAdjustment = 0; for (String v : values) { - fieldRequests.add(new FieldInferenceRequest(itemIndex, field, sourceField, v, order++, offsetAdjustment)); + fieldRequests.add( + new FieldInferenceRequest(itemIndex, field, sourceField, v, order++, offsetAdjustment, chunkingSettings) + ); // When using the inference metadata fields format, all the input values are concatenated so that the // chunk text offsets are expressed in the context of a single string. Calculate the offset adjustment diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java index 6072867acd986..94100215b1a8c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java @@ -16,9 +16,6 @@ import org.elasticsearch.index.IndexVersions; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; -import org.elasticsearch.inference.Model; -import org.elasticsearch.inference.SimilarityMeasure; -import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.MinimalServiceSettings; import org.elasticsearch.xcontent.ConstructingObjectParser; import org.elasticsearch.xcontent.DeprecationHandler; @@ -232,20 +229,6 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws } ); - private static final ConstructingObjectParser MODEL_SETTINGS_PARSER = new ConstructingObjectParser<>( - MODEL_SETTINGS_FIELD, - true, - args -> { - TaskType taskType = TaskType.fromString((String) args[0]); - Integer dimensions = (Integer) args[1]; - SimilarityMeasure similarity = args[2] == null ? null : SimilarityMeasure.fromString((String) args[2]); - DenseVectorFieldMapper.ElementType elementType = args[3] == null - ? 
null - : DenseVectorFieldMapper.ElementType.fromString((String) args[3]); - return new ModelSettings(taskType, dimensions, similarity, elementType); - } - ); - private static final ConstructingObjectParser, Void> CHUNKING_SETTINGS_PARSER = new ConstructingObjectParser<>( CHUNKING_SETTINGS_FIELD, true, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java index 6c7aee58b3cd4..a0f6ca260541f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java @@ -18,9 +18,9 @@ import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.ChunkedInference; +import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.EmptySecretSettings; import org.elasticsearch.inference.EmptyTaskSettings; -import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; From c9bfa32b68e4cd7b7dd3f35a080dc5e2d9de3f9e Mon Sep 17 00:00:00 2001 From: Kathleen DeRusso Date: Wed, 12 Feb 2025 15:36:03 -0500 Subject: [PATCH 12/86] A bit of cleanup --- .../ShardBulkInferenceActionFilter.java | 61 +++++++++---------- .../mapper/SemanticTextFieldTests.java | 1 + 2 files changed, 31 insertions(+), 31 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java index 73204d89f3e26..617f45a3f7802 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java @@ -58,7 +58,6 @@ import java.util.LinkedHashMap; import java.util.List; import java.util.Map; -import java.util.Optional; import java.util.stream.Collectors; /** @@ -307,24 +306,26 @@ public void onFailure(Exception exc) { return; } int currentBatchSize = Math.min(requests.size(), batchSize); - final List currentBatch = requests.subList(0, currentBatchSize); + ChunkingSettings chunkingSettings = requests.get(0).chunkingSettings; + List currentBatch = new ArrayList<>(); + List others = new ArrayList<>(); + for (int i = 0; i < currentBatchSize; i++) { + FieldInferenceRequest request = requests.get(i); + if ((chunkingSettings == null && request.chunkingSettings == null) || request.chunkingSettings.equals(chunkingSettings)) { + currentBatch.add(request); + } else { + others.add(request); + } + } + final List nextBatch = requests.subList(currentBatchSize, requests.size()); - // - // List chunkedInputs = currentBatch.stream() - // .map(request -> new ChunkedInputs(request.chunkingSettings(), List.of(request.input()))) - // .toList(); - - List chunkedInputs = currentBatch.stream() - .collect(Collectors.groupingBy(request -> Optional.ofNullable(request.chunkingSettings()))) - .entrySet() - .stream() - .map( - entry -> new ChunkedInputs( - 
entry.getKey().orElse(null), - entry.getValue().stream().map(FieldInferenceRequest::input).collect(Collectors.toList()) - ) - ) - .toList(); + nextBatch.addAll(others); + + // We can assume current batch has all the same chunking settings + ChunkedInputs chunkedInputs = new ChunkedInputs( + chunkingSettings, + currentBatch.stream().map(r -> r.input).collect(Collectors.toList()) + ); ActionListener> completionListener = new ActionListener<>() { @Override @@ -390,19 +391,17 @@ private void onFinish() { } }; - for (ChunkedInputs chunkedInput : chunkedInputs) { - inferenceProvider.service() - .chunkedInfer( - inferenceProvider.model(), - null, - chunkedInput.inputs(), - Map.of(), - chunkedInput.chunkingSettings(), - InputType.INGEST, - TimeValue.MAX_VALUE, - completionListener - ); - } + inferenceProvider.service() + .chunkedInfer( + inferenceProvider.model(), + null, + chunkedInputs.inputs(), + Map.of(), + chunkedInputs.chunkingSettings(), + InputType.INGEST, + TimeValue.MAX_VALUE, + completionListener + ); } private FieldInferenceResponseAccumulator ensureResponseAccumulatorSlot(int id) { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java index 4441d67fa7eef..0320159d347cd 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java @@ -70,6 +70,7 @@ protected void assertEqualInstances(SemanticTextField expectedInstance, Semantic assertThat(newInstance.originalValues(), equalTo(expectedInstance.originalValues())); assertThat(newInstance.inference().modelSettings(), equalTo(expectedInstance.inference().modelSettings())); assertThat(newInstance.inference().chunks().size(), equalTo(expectedInstance.inference().chunks().size())); + assertThat(newInstance.chunkingSettings(), equalTo(expectedInstance.chunkingSettings())); MinimalServiceSettings modelSettings = newInstance.inference().modelSettings(); for (var entry : newInstance.inference().chunks().entrySet()) { var expectedChunks = expectedInstance.inference().chunks().get(entry.getKey()); From 122aeee6bf9a9187d3b4e7ddbf16456b57e3dee9 Mon Sep 17 00:00:00 2001 From: Kathleen DeRusso Date: Fri, 14 Feb 2025 11:40:13 -0500 Subject: [PATCH 13/86] Revert a bunch of changes to try to narrow down what broke CI --- .../metadata/InferenceFieldMetadata.java | 6 ----- .../ShardBulkInferenceActionFilter.java | 25 ++++--------------- 2 files changed, 5 insertions(+), 26 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadata.java b/server/src/main/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadata.java index ffaeb6228cd36..129e72b4e6358 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadata.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadata.java @@ -46,7 +46,6 @@ public final class InferenceFieldMetadata implements SimpleDiffable chunkingSettings; - // TODO can this be ChunkingSettings instead of Map? 
public InferenceFieldMetadata(String name, String inferenceId, String[] sourceFields, Map chunkingSettings) { this(name, inferenceId, inferenceId, sourceFields, chunkingSettings); } @@ -63,11 +62,6 @@ public InferenceFieldMetadata( this.searchInferenceId = Objects.requireNonNull(searchInferenceId); this.sourceFields = Objects.requireNonNull(sourceFields); this.chunkingSettings = chunkingSettings; - - // TODO remove this, trying to get stack traces where this called - if (chunkingSettings != null && chunkingSettings.size() != 3) { - throw new IllegalArgumentException("Chunking settings must contain exactly 3 settings"); - } } public InferenceFieldMetadata(StreamInput input) throws IOException { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java index 617f45a3f7802..58b32768999d8 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java @@ -307,25 +307,10 @@ public void onFailure(Exception exc) { } int currentBatchSize = Math.min(requests.size(), batchSize); ChunkingSettings chunkingSettings = requests.get(0).chunkingSettings; - List currentBatch = new ArrayList<>(); - List others = new ArrayList<>(); - for (int i = 0; i < currentBatchSize; i++) { - FieldInferenceRequest request = requests.get(i); - if ((chunkingSettings == null && request.chunkingSettings == null) || request.chunkingSettings.equals(chunkingSettings)) { - currentBatch.add(request); - } else { - others.add(request); - } - } - + final List currentBatch = requests.subList(0, currentBatchSize); final List nextBatch = requests.subList(currentBatchSize, requests.size()); - nextBatch.addAll(others); - - // We can assume current batch has all the same chunking settings - ChunkedInputs chunkedInputs = new ChunkedInputs( - chunkingSettings, - currentBatch.stream().map(r -> r.input).collect(Collectors.toList()) - ); + final List inputs = currentBatch.stream().map(FieldInferenceRequest::input).toList(); + // TODO create ChunkedInputs here, so we send chunkingSettings in ActionListener> completionListener = new ActionListener<>() { @Override @@ -395,9 +380,9 @@ private void onFinish() { .chunkedInfer( inferenceProvider.model(), null, - chunkedInputs.inputs(), + inputs, Map.of(), - chunkedInputs.chunkingSettings(), + null, // TODO add chunking settings InputType.INGEST, TimeValue.MAX_VALUE, completionListener From b42f2622d2bd668140053aa346d9e29b63f6b5e6 Mon Sep 17 00:00:00 2001 From: Kathleen DeRusso Date: Wed, 12 Feb 2025 09:18:30 -0500 Subject: [PATCH 14/86] test --- test.txt | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 test.txt diff --git a/test.txt b/test.txt new file mode 100644 index 0000000000000..e69de29bb2d1d From 75404b513814e25f058b8b764f9f95bab816248d Mon Sep 17 00:00:00 2001 From: Kathleen DeRusso Date: Wed, 12 Feb 2025 09:53:45 -0500 Subject: [PATCH 15/86] Revert "test" This reverts commit 9f4e2adba021588d1e5d3ecfce9771c7d3c74c15. 
--- test.txt | 0 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 test.txt diff --git a/test.txt b/test.txt deleted file mode 100644 index e69de29bb2d1d..0000000000000 From a54804ee9a56bb18678c97733d2136a848deb1c8 Mon Sep 17 00:00:00 2001 From: Kathleen DeRusso Date: Fri, 14 Feb 2025 15:26:46 -0500 Subject: [PATCH 16/86] Fix InferenceFieldMetadataTest --- .../cluster/metadata/InferenceFieldMetadata.java | 14 ++++++++++++-- .../metadata/InferenceFieldMetadataTests.java | 7 +------ 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadata.java b/server/src/main/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadata.java index 129e72b4e6358..eb9a3131a162f 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadata.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadata.java @@ -12,6 +12,7 @@ import org.elasticsearch.TransportVersions; import org.elasticsearch.cluster.Diff; import org.elasticsearch.cluster.SimpleDiffable; +import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.xcontent.ToXContentFragment; @@ -38,7 +39,8 @@ public final class InferenceFieldMetadata implements SimpleDiffable getRandomFieldsExcludeFilter() { - return p -> p.equals(""); // do not add elements at the top-level as any element at this level is parsed as a new inference field - } - @Override protected InferenceFieldMetadata doParseInstance(XContentParser parser) throws IOException { if (parser.nextToken() == XContentParser.Token.START_OBJECT) { @@ -56,7 +51,7 @@ protected InferenceFieldMetadata doParseInstance(XContentParser parser) throws I @Override protected boolean supportsUnknownFields() { - return true; + return false; } private static InferenceFieldMetadata createTestItem() { From 675819c3c279b5b96f4c36de2302cb63919534fe Mon Sep 17 00:00:00 2001 From: elasticsearchmachine Date: Fri, 14 Feb 2025 20:32:55 +0000 Subject: [PATCH 17/86] [CI] Auto commit changes from spotless --- .../cluster/metadata/InferenceFieldMetadataTests.java | 1 - 1 file changed, 1 deletion(-) diff --git a/server/src/test/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadataTests.java b/server/src/test/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadataTests.java index 1b4eff6b5bbaa..8a450b11c1f15 100644 --- a/server/src/test/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadataTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadataTests.java @@ -16,7 +16,6 @@ import java.io.IOException; import java.util.Map; -import java.util.function.Predicate; import static org.hamcrest.Matchers.equalTo; From ccef5cc8f4563164fc20c61c9d3b82d7e0479d5a Mon Sep 17 00:00:00 2001 From: Kathleen DeRusso Date: Fri, 14 Feb 2025 16:18:16 -0500 Subject: [PATCH 18/86] Add chunking settings back in --- .../ShardBulkInferenceActionFilter.java | 21 ++++++++++++------- .../SentenceBoundaryChunkingSettings.java | 5 +++++ .../WordBoundaryChunkingSettings.java | 5 +++++ 3 files changed, 23 insertions(+), 8 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java index 58b32768999d8..33956019d6b23 100644 
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java @@ -58,6 +58,7 @@ import java.util.LinkedHashMap; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.stream.Collectors; /** @@ -137,8 +138,6 @@ private void processBulkShardRequest( private record InferenceProvider(InferenceService service, Model model) {} - private record ChunkedInputs(ChunkingSettings chunkingSettings, List inputs) {} - /** * A field inference request on a single input. * @param index The index of the request in the original bulk request. @@ -306,11 +305,17 @@ public void onFailure(Exception exc) { return; } int currentBatchSize = Math.min(requests.size(), batchSize); - ChunkingSettings chunkingSettings = requests.get(0).chunkingSettings; - final List currentBatch = requests.subList(0, currentBatchSize); - final List nextBatch = requests.subList(currentBatchSize, requests.size()); - final List inputs = currentBatch.stream().map(FieldInferenceRequest::input).toList(); - // TODO create ChunkedInputs here, so we send chunkingSettings in + + final ChunkingSettings chunkingSettings = requests.getFirst().chunkingSettings; + final List nextBatch = new ArrayList<>(); + final List inputs = new ArrayList<>(); + for (FieldInferenceRequest request : requests) { + if (Objects.equals(chunkingSettings, request.chunkingSettings) && inputs.size() < currentBatchSize) { + inputs.add(request.input); + } else { + nextBatch.add(request); + } + } ActionListener> completionListener = new ActionListener<>() { @Override @@ -382,7 +387,7 @@ private void onFinish() { null, inputs, Map.of(), - null, // TODO add chunking settings + chunkingSettings, InputType.INGEST, TimeValue.MAX_VALUE, completionListener diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkingSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkingSettings.java index 74568e8f3690a..e468ed171c3b8 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkingSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkingSettings.java @@ -151,4 +151,9 @@ public boolean equals(Object o) { public int hashCode() { return Objects.hash(maxChunkSize); } + + @Override + public String toString() { + return Strings.toString(this); + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunkingSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunkingSettings.java index e4d5732d23b3f..1efd02bb0ab99 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunkingSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunkingSettings.java @@ -140,4 +140,9 @@ public boolean equals(Object o) { public int hashCode() { return Objects.hash(maxChunkSize, overlap); } + + @Override + public String toString() { + return Strings.toString(this); + } } From 70ac065b8308cdc7c7bd2752cd7196bf4850b86b Mon Sep 17 00:00:00 2001 From: Kathleen DeRusso Date: Tue, 4 Mar 2025 08:42:09 -0500 Subject: [PATCH 19/86] Update 
builder to use new map --- .../inference/action/filter/ShardBulkInferenceActionFilter.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java index 33956019d6b23..6978091c48c9b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java @@ -473,7 +473,7 @@ private void applyInferenceResponses(BulkItemRequest item, FieldInferenceRespons ), indexRequest.getContentType(), inferenceFieldMetadata.getChunkingSettings() != null - ? ChunkingSettingsBuilder.fromMap(new HashMap<>(inferenceFieldMetadata.getChunkingSettings())) + ? ChunkingSettingsBuilder.fromMap(new HashMap<>(new HashMap<>(inferenceFieldMetadata.getChunkingSettings()))) : null ); From 20d596c55b2e848d165c4917be2880b37b47b8fd Mon Sep 17 00:00:00 2001 From: Kathleen DeRusso Date: Tue, 4 Mar 2025 09:10:21 -0500 Subject: [PATCH 20/86] Fix compilation errors after merge --- .../xpack/inference/services/voyageai/VoyageAIService.java | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java index 16659f075c564..e24522527a24c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java @@ -276,6 +276,7 @@ protected void doChunkedInfer( Model model, DocumentsOnlyInput inputs, Map taskSettings, + ChunkingSettings chunkingSettings, InputType inputType, TimeValue timeout, ActionListener> listener @@ -291,7 +292,7 @@ protected void doChunkedInfer( List batchedRequests = new EmbeddingRequestChunker( inputs.getInputs(), getBatchSize(voyageaiModel), - voyageaiModel.getConfigurations().getChunkingSettings() + chunkingSettings != null ? 
chunkingSettings : voyageaiModel.getConfigurations().getChunkingSettings() ).batchRequestsWithListeners(listener); for (var request : batchedRequests) { From 178a9dbf85ab68a66930dc9322087444574780dc Mon Sep 17 00:00:00 2001 From: Kathleen DeRusso Date: Tue, 4 Mar 2025 09:26:06 -0500 Subject: [PATCH 21/86] Debugging tests --- .../ShardBulkInferenceActionFilter.java | 53 +++++++++++-------- .../ShardBulkInferenceActionFilterTests.java | 2 +- 2 files changed, 33 insertions(+), 22 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java index 9242508a79d0f..015b4960bc506 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java @@ -62,7 +62,6 @@ import java.util.LinkedHashMap; import java.util.List; import java.util.Map; -import java.util.Objects; import java.util.stream.Collectors; import static org.elasticsearch.xpack.inference.InferencePlugin.INFERENCE_API_FEATURE; @@ -264,14 +263,24 @@ private void executeShardBulkInferenceAsync( public void onResponse(UnparsedModel unparsedModel) { var service = inferenceServiceRegistry.getService(unparsedModel.service()); if (service.isEmpty() == false) { - InferenceService inferenceService = service.get(); - Model model = inferenceService.parsePersistedConfigWithSecrets( - inferenceId, - unparsedModel.taskType(), - unparsedModel.settings(), - unparsedModel.secrets() + // InferenceService inferenceService = service.get(); + // Model model = inferenceService.parsePersistedConfigWithSecrets( + // inferenceId, + // unparsedModel.taskType(), + // unparsedModel.settings(), + // unparsedModel.secrets() + // ); + // var provider = new InferenceProvider(inferenceService, model); + var provider = new InferenceProvider( + service.get(), + service.get() + .parsePersistedConfigWithSecrets( + inferenceId, + unparsedModel.taskType(), + unparsedModel.settings(), + unparsedModel.secrets() + ) ); - var provider = new InferenceProvider(inferenceService, model); executeShardBulkInferenceAsync(inferenceId, provider, requests, onFinish); } else { try (onFinish) { @@ -317,17 +326,20 @@ public void onFailure(Exception exc) { } int currentBatchSize = Math.min(requests.size(), batchSize); - final ChunkingSettings chunkingSettings = requests.getFirst().chunkingSettings; - final List nextBatch = new ArrayList<>(); - final List inputs = new ArrayList<>(); - for (FieldInferenceRequest request : requests) { - if (Objects.equals(chunkingSettings, request.chunkingSettings) && inputs.size() < currentBatchSize) { - inputs.add(request.input); - } else { - nextBatch.add(request); - } - } - + // TODO KD adjust here + // final ChunkingSettings chunkingSettings = requests.getFirst().chunkingSettings; + // final List nextBatch = new ArrayList<>(); + // final List inputs = new ArrayList<>(); + // for (FieldInferenceRequest request : requests) { + // if (Objects.equals(chunkingSettings, request.chunkingSettings) && inputs.size() < currentBatchSize) { + // inputs.add(request.input); + // } else { + // nextBatch.add(request); + // } + // } + final List currentBatch = requests.subList(0, currentBatchSize); + final List nextBatch = requests.subList(currentBatchSize, requests.size()); + 
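// A rough sketch of the chunking-aware batching this filter is working toward (the later
// "Always process batches in order" commit settles on this shape): take the leading run of
// requests that share the first request's ChunkingSettings, capped at currentBatchSize, send
// their inputs through a single chunkedInfer call, and defer everything after the first
// mismatch so requests are never reordered. Record accessor names for FieldInferenceRequest
// are assumed to mirror its components.
//
//     ChunkingSettings batchSettings = requests.getFirst().chunkingSettings();
//     List<FieldInferenceRequest> currentBatch = new ArrayList<>();
//     for (FieldInferenceRequest request : requests) {
//         if (Objects.equals(request.chunkingSettings(), batchSettings) == false
//             || currentBatch.size() >= currentBatchSize) {
//             break;
//         }
//         currentBatch.add(request);
//     }
//     List<FieldInferenceRequest> nextBatch = requests.subList(currentBatch.size(), requests.size());
//     List<String> inputs = currentBatch.stream().map(FieldInferenceRequest::input).toList();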
final List inputs = currentBatch.stream().map(FieldInferenceRequest::input).collect(Collectors.toList()); ActionListener> completionListener = new ActionListener<>() { @Override public void onResponse(List results) { @@ -391,14 +403,13 @@ private void onFinish() { } } }; - inferenceProvider.service() .chunkedInfer( inferenceProvider.model(), null, inputs, Map.of(), - chunkingSettings, + null, // TODO pass in chunkingSettings InputType.INGEST, TimeValue.MAX_VALUE, completionListener diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java index 58c836bdb1f50..8d92efc7d5ff9 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java @@ -170,7 +170,7 @@ public void testLicenseInvalidForInference() throws InterruptedException { Map inferenceFieldMap = Map.of( "obj.field1", - new InferenceFieldMetadata("obj.field1", model.getInferenceEntityId(), new String[] { "obj.field1" }) + new InferenceFieldMetadata("obj.field1", model.getInferenceEntityId(), new String[] { "obj.field1" }, null) ); BulkItemRequest[] items = new BulkItemRequest[1]; items[0] = new BulkItemRequest(0, new IndexRequest("test").source("obj.field1", "Test")); From 7f1e99d0f43a1a8206824aa15f9a9247facddb6e Mon Sep 17 00:00:00 2001 From: Kathleen DeRusso Date: Tue, 4 Mar 2025 09:33:07 -0500 Subject: [PATCH 22/86] debugging --- .../ShardBulkInferenceActionFilter.java | 27 +++++++++---------- 1 file changed, 12 insertions(+), 15 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java index 015b4960bc506..525202956f728 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java @@ -62,6 +62,7 @@ import java.util.LinkedHashMap; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.stream.Collectors; import static org.elasticsearch.xpack.inference.InferencePlugin.INFERENCE_API_FEATURE; @@ -326,20 +327,16 @@ public void onFailure(Exception exc) { } int currentBatchSize = Math.min(requests.size(), batchSize); - // TODO KD adjust here - // final ChunkingSettings chunkingSettings = requests.getFirst().chunkingSettings; - // final List nextBatch = new ArrayList<>(); - // final List inputs = new ArrayList<>(); - // for (FieldInferenceRequest request : requests) { - // if (Objects.equals(chunkingSettings, request.chunkingSettings) && inputs.size() < currentBatchSize) { - // inputs.add(request.input); - // } else { - // nextBatch.add(request); - // } - // } - final List currentBatch = requests.subList(0, currentBatchSize); - final List nextBatch = requests.subList(currentBatchSize, requests.size()); - final List inputs = currentBatch.stream().map(FieldInferenceRequest::input).collect(Collectors.toList()); + final ChunkingSettings chunkingSettings = 
requests.getFirst().chunkingSettings; + final List nextBatch = new ArrayList<>(); + final List inputs = new ArrayList<>(); + for (FieldInferenceRequest request : requests) { + if (Objects.equals(chunkingSettings, request.chunkingSettings) && inputs.size() < currentBatchSize) { + inputs.add(request.input); + } else { + nextBatch.add(request); + } + } ActionListener> completionListener = new ActionListener<>() { @Override public void onResponse(List results) { @@ -409,7 +406,7 @@ private void onFinish() { null, inputs, Map.of(), - null, // TODO pass in chunkingSettings + chunkingSettings, InputType.INGEST, TimeValue.MAX_VALUE, completionListener From 65acf8f1513be8d5a954315ea3dac773c09e987c Mon Sep 17 00:00:00 2001 From: Kathleen DeRusso Date: Tue, 4 Mar 2025 09:53:57 -0500 Subject: [PATCH 23/86] Cleanup --- .../action/filter/ShardBulkInferenceActionFilter.java | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java index 525202956f728..3bddc0ec44a7d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java @@ -264,14 +264,6 @@ private void executeShardBulkInferenceAsync( public void onResponse(UnparsedModel unparsedModel) { var service = inferenceServiceRegistry.getService(unparsedModel.service()); if (service.isEmpty() == false) { - // InferenceService inferenceService = service.get(); - // Model model = inferenceService.parsePersistedConfigWithSecrets( - // inferenceId, - // unparsedModel.taskType(), - // unparsedModel.settings(), - // unparsedModel.secrets() - // ); - // var provider = new InferenceProvider(inferenceService, model); var provider = new InferenceProvider( service.get(), service.get() @@ -492,7 +484,7 @@ private void applyInferenceResponses(BulkItemRequest item, FieldInferenceRespons ), indexRequest.getContentType(), inferenceFieldMetadata.getChunkingSettings() != null - ? ChunkingSettingsBuilder.fromMap(new HashMap<>(new HashMap<>(inferenceFieldMetadata.getChunkingSettings()))) + ? 
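// rebuild the field's ChunkingSettings from the map stored in its metadata; copying into a fresh HashMap leaves the stored map untouched, since the builder may consume entries while it validates them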
ChunkingSettingsBuilder.fromMap(new HashMap<>(inferenceFieldMetadata.getChunkingSettings())) : null ); From 51d9aae8cfb6e88a2ce2806b47fbcbe9bf41d802 Mon Sep 17 00:00:00 2001 From: Kathleen DeRusso Date: Tue, 4 Mar 2025 13:55:50 -0500 Subject: [PATCH 24/86] Add yaml test --- .../xpack/inference/InferenceFeatures.java | 4 +- .../mapper/SemanticTextFieldMapper.java | 1 + ...5_semantic_text_field_mapping_chunking.yml | 142 ++++++++++++++++++ 3 files changed, 146 insertions(+), 1 deletion(-) create mode 100644 x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/25_semantic_text_field_mapping_chunking.yml diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java index 032817ce758a2..cfb7cea745ff9 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java @@ -15,6 +15,7 @@ import java.util.Set; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper.SEMANTIC_TEXT_SUPPORT_CHUNKING_CONFIG; import static org.elasticsearch.xpack.inference.queries.SemanticKnnVectorQueryRewriteInterceptor.SEMANTIC_KNN_FILTER_FIX; import static org.elasticsearch.xpack.inference.queries.SemanticKnnVectorQueryRewriteInterceptor.SEMANTIC_KNN_VECTOR_QUERY_REWRITE_INTERCEPTION_SUPPORTED; import static org.elasticsearch.xpack.inference.queries.SemanticMatchQueryRewriteInterceptor.SEMANTIC_MATCH_QUERY_REWRITE_INTERCEPTION_SUPPORTED; @@ -49,7 +50,8 @@ public Set getTestFeatures() { SemanticInferenceMetadataFieldsMapper.INFERENCE_METADATA_FIELDS_ENABLED_BY_DEFAULT, SEMANTIC_TEXT_HIGHLIGHTER_DEFAULT, SEMANTIC_KNN_FILTER_FIX, - TEST_RERANKING_SERVICE_PARSE_TEXT_AS_SCORE + TEST_RERANKING_SERVICE_PARSE_TEXT_AS_SCORE, + SEMANTIC_TEXT_SUPPORT_CHUNKING_CONFIG ); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java index 8eec9cf8896e2..37f385e101f20 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java @@ -121,6 +121,7 @@ public class SemanticTextFieldMapper extends FieldMapper implements InferenceFie "semantic_text.always_emit_inference_id_fix" ); public static final NodeFeature SEMANTIC_TEXT_SKIP_INFERENCE_FIELDS = new NodeFeature("semantic_text.skip_inference_fields"); + public static final NodeFeature SEMANTIC_TEXT_SUPPORT_CHUNKING_CONFIG = new NodeFeature("semantic_text.support_chunking_config"); public static final String CONTENT_TYPE = "semantic_text"; public static final String DEFAULT_ELSER_2_INFERENCE_ID = DEFAULT_ELSER_ID; diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/25_semantic_text_field_mapping_chunking.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/25_semantic_text_field_mapping_chunking.yml new file mode 100644 index 0000000000000..8329482f4f87d --- /dev/null +++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/25_semantic_text_field_mapping_chunking.yml @@ -0,0 +1,142 @@ +setup: + - requires: + 
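# gate the suite on the node feature that advertises per-field chunking support for semantic_text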
cluster_features: "semantic_text.support_chunking_config" + reason: semantic_text chunking configuration added in 8.19 + + - do: + inference.put: + task_type: text_embedding + inference_id: dense-inference-id + body: > + { + "service": "text_embedding_test_service", + "service_settings": { + "model": "my_model", + "dimensions": 10, + "similarity": "cosine", + "api_key": "abc64" + }, + "task_settings": { + } + } + + - do: + inference.put: + task_type: sparse_embedding + inference_id: sparse-inference-id + body: > + { + "service": "test_service", + "service_settings": { + "model": "my_model", + "api_key": "abc64" + }, + "task_settings": { + } + } + + - do: + indices.create: + index: default-chunking + body: + mappings: + properties: + keyword_field: + type: keyword + inference_field: + type: semantic_text + inference_id: sparse-inference-id + + - do: + indices.create: + index: custom-chunking + body: + mappings: + properties: + keyword_field: + type: keyword + inference_field: + type: semantic_text + inference_id: dense-inference-id + chunking_settings: + strategy: word + max_chunk_size: 10 + overlap: 5 + + - do: + index: + index: default-chunking + id: doc_1 + body: + keyword_field: "default sentence chunking" + inference_field: "Elasticsearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." + refresh: true + + - do: + index: + index: custom-chunking + id: doc_2 + body: + keyword_field: "custom word chunking" + inference_field: "Elasticsearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." + refresh: true + +--- +"We return chunking configurations with mappings": + + - do: + indices.get_mapping: + index: default-chunking + + - is_false: default-chunking.mappings.properties.inference_field.chunking_settings + + - do: + indices.get_mapping: + index: custom-chunking + + - match: { "custom-chunking.mappings.properties.inference_field.chunking_settings.strategy": "word" } + - match: { "custom-chunking.mappings.properties.inference_field.chunking_settings.max_chunk_size": 10 } + - match: { "custom-chunking.mappings.properties.inference_field.chunking_settings.overlap": 5 } + +--- +"We return different chunks based on configured chunking overrides or model defaults": + + - do: + search: + index: default-chunking + body: + query: + semantic: + field: "inference_field" + query: "What is Elasticsearch?" + highlight: + fields: + inference_field: + type: "semantic" + number_of_fragments: 2 + + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_1" } + - length: { hits.hits.0.highlight.inference_field: 1 } + - match: { hits.hits.0.highlight.inference_field.0: "Elasticsearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } + + - do: + search: + index: custom-chunking + body: + query: + semantic: + field: "inference_field" + query: "What is Elasticsearch?" 
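# the semantic highlighter emits one fragment per stored chunk (capped by number_of_fragments), so the fragment counts and text asserted below reflect the chunking applied at index time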
+ highlight: + fields: + inference_field: + type: "semantic" + number_of_fragments: 2 + + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_2" } + - length: { hits.hits.0.highlight.inference_field: 2 } + - match: { hits.hits.0.highlight.inference_field.0: "Elasticsearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all" } + - match: { hits.hits.0.highlight.inference_field.1: " Lucene internally and enjoys all the features it provides." } + From 7aaaee851e052628e3c941ee9643586d4b09d735 Mon Sep 17 00:00:00 2001 From: Kathleen DeRusso Date: Tue, 4 Mar 2025 14:35:29 -0500 Subject: [PATCH 25/86] Update tests --- .../action/filter/ShardBulkInferenceActionFilterTests.java | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java index 8d92efc7d5ff9..4958f71780e7d 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java @@ -36,6 +36,7 @@ import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.index.shard.ShardId; import org.elasticsearch.inference.ChunkedInference; +import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceRegistry; import org.elasticsearch.inference.Model; @@ -77,6 +78,7 @@ import static org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter.getIndexRequestOrNull; import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.getChunksFieldName; import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.getOriginalTextFieldName; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.generateRandomChunkingSettings; import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.randomChunkedInferenceEmbeddingSparse; import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.randomSemanticText; import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.randomSemanticTextInput; @@ -380,7 +382,10 @@ public void testManyRandomDocs() throws Exception { for (int i = 0; i < numInferenceFields; i++) { String field = randomAlphaOfLengthBetween(5, 10); String inferenceId = randomFrom(inferenceModelMap.keySet()); - inferenceFieldMap.put(field, new InferenceFieldMetadata(field, inferenceId, new String[] { field }, null)); + Map chunkingSettingsMap = Optional.ofNullable(generateRandomChunkingSettings()) + .map(ChunkingSettings::asMap) + .orElse(null); + inferenceFieldMap.put(field, new InferenceFieldMetadata(field, inferenceId, new String[] { field }, chunkingSettingsMap)); } int numRequests = atLeast(100); From d306eaeadf613c346668e040e5ceb1a181b271ba Mon Sep 17 00:00:00 2001 From: Kathleen DeRusso Date: Tue, 4 Mar 2025 15:53:56 -0500 Subject: [PATCH 26/86] Add chunking to test inference service --- .../mock/AbstractTestInferenceService.java | 33 +++++++++++++++++++ .../TestDenseInferenceServiceExtension.java | 7 +++- .../TestSparseInferenceServiceExtension.java | 8 ++++- 
.../inference/src/main/java/module-info.java | 1 + .../WordBoundaryChunkingSettings.java | 8 +++++ ...5_semantic_text_field_mapping_chunking.yml | 32 ++++++++++++------ 6 files changed, 77 insertions(+), 12 deletions(-) diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/AbstractTestInferenceService.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/AbstractTestInferenceService.java index 3c29cef47d628..024866ee718a5 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/AbstractTestInferenceService.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/AbstractTestInferenceService.java @@ -13,6 +13,8 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkingSettings; +import org.elasticsearch.inference.ChunkingStrategy; import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.ModelConfigurations; @@ -22,9 +24,13 @@ import org.elasticsearch.inference.TaskSettings; import org.elasticsearch.inference.TaskType; import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.chunking.WordBoundaryChunker; +import org.elasticsearch.xpack.inference.chunking.WordBoundaryChunkingSettings; import java.io.IOException; +import java.util.ArrayList; import java.util.HashMap; +import java.util.List; import java.util.Map; import java.util.Random; @@ -105,6 +111,33 @@ public void start(Model model, TimeValue timeout, ActionListener listen @Override public void close() throws IOException {} + protected List chunkInputs(List input, ChunkingSettings chunkingSettings) { + if (chunkingSettings == null) { + return input; + } + List chunkedInputs = new ArrayList<>(); + ChunkingStrategy chunkingStrategy = chunkingSettings.getChunkingStrategy(); + if (chunkingStrategy == ChunkingStrategy.WORD) { + WordBoundaryChunker chunker = new WordBoundaryChunker(); + for (String inputString : input) { + WordBoundaryChunkingSettings wordBoundaryChunkingSettings = (WordBoundaryChunkingSettings) chunkingSettings; + List offsets = chunker.chunk( + inputString, + wordBoundaryChunkingSettings.maxChunkSize(), + wordBoundaryChunkingSettings.overlap() + ); + for (WordBoundaryChunker.ChunkOffset offset : offsets) { + chunkedInputs.add(inputString.substring(offset.start(), offset.end())); + } + } + } else { + // Won't implement till we need it + throw new UnsupportedOperationException("Test inference service only supports word chunking strategies"); + } + + return chunkedInputs; + } + public static class TestServiceModel extends Model { public TestServiceModel( diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java index 1db55288002ba..328426eb4a86f 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java +++ 
b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java @@ -155,7 +155,7 @@ public void chunkedInfer( switch (model.getConfigurations().getTaskType()) { case ANY, TEXT_EMBEDDING -> { ServiceSettings modelServiceSettings = model.getServiceSettings(); - listener.onResponse(makeChunkedResults(input, modelServiceSettings.dimensions())); + listener.onResponse(makeChunkedResults(input, modelServiceSettings.dimensions(), chunkingSettings)); } default -> listener.onFailure( new ElasticsearchStatusException( @@ -175,6 +175,11 @@ private TextEmbeddingFloatResults makeResults(List input, int dimensions return new TextEmbeddingFloatResults(embeddings); } + private List makeChunkedResults(List input, int dimensions, ChunkingSettings chunkingSettings) { + List chunkedInputs = chunkInputs(input, chunkingSettings); + return makeChunkedResults(chunkedInputs, dimensions); + } + private List makeChunkedResults(List input, int dimensions) { TextEmbeddingFloatResults nonChunkedResults = makeResults(input, dimensions); diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java index a2475fe98bfb5..378cf1d9994b1 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java @@ -45,6 +45,7 @@ import java.util.Map; public class TestSparseInferenceServiceExtension implements InferenceServiceExtension { + @Override public List getInferenceServiceFactories() { return List.of(TestInferenceService::new); @@ -143,7 +144,7 @@ public void chunkedInfer( ActionListener> listener ) { switch (model.getConfigurations().getTaskType()) { - case ANY, SPARSE_EMBEDDING -> listener.onResponse(makeChunkedResults(input)); + case ANY, SPARSE_EMBEDDING -> listener.onResponse(makeChunkedResults(input, chunkingSettings)); default -> listener.onFailure( new ElasticsearchStatusException( TaskType.unsupportedTaskTypeErrorMsg(model.getConfigurations().getTaskType(), name()), @@ -165,6 +166,11 @@ private SparseEmbeddingResults makeResults(List input) { return new SparseEmbeddingResults(embeddings); } + private List makeChunkedResults(List input, ChunkingSettings chunkingSettings) { + List chunkedInputs = chunkInputs(input, chunkingSettings); + return makeChunkedResults(chunkedInputs); + } + private List makeChunkedResults(List input) { List results = new ArrayList<>(); for (int i = 0; i < input.size(); i++) { diff --git a/x-pack/plugin/inference/src/main/java/module-info.java b/x-pack/plugin/inference/src/main/java/module-info.java index 78f30e7da0670..d41aa654f59e6 100644 --- a/x-pack/plugin/inference/src/main/java/module-info.java +++ b/x-pack/plugin/inference/src/main/java/module-info.java @@ -42,6 +42,7 @@ exports org.elasticsearch.xpack.inference.services; exports org.elasticsearch.xpack.inference; exports org.elasticsearch.xpack.inference.action.task; + exports org.elasticsearch.xpack.inference.chunking; exports org.elasticsearch.xpack.inference.telemetry; provides org.elasticsearch.features.FeatureSpecification with org.elasticsearch.xpack.inference.InferenceFeatures; diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunkingSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunkingSettings.java index 1efd02bb0ab99..0673b21bfb9af 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunkingSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunkingSettings.java @@ -58,6 +58,14 @@ public Map asMap() { return map; } + public int maxChunkSize() { + return maxChunkSize; + } + + public int overlap() { + return overlap; + } + public static WordBoundaryChunkingSettings fromMap(Map map) { ValidationException validationException = new ValidationException(); diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/25_semantic_text_field_mapping_chunking.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/25_semantic_text_field_mapping_chunking.yml index 8329482f4f87d..ce1274fc1c5aa 100644 --- a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/25_semantic_text_field_mapping_chunking.yml +++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/25_semantic_text_field_mapping_chunking.yml @@ -37,7 +37,7 @@ setup: - do: indices.create: - index: default-chunking + index: default-chunking-sparse body: mappings: properties: @@ -49,7 +49,7 @@ setup: - do: indices.create: - index: custom-chunking + index: default-chunking-dense body: mappings: properties: @@ -58,14 +58,26 @@ setup: inference_field: type: semantic_text inference_id: dense-inference-id + + - do: + indices.create: + index: custom-chunking-sparse + body: + mappings: + properties: + keyword_field: + type: keyword + inference_field: + type: semantic_text + inference_id: sparse-inference-id chunking_settings: strategy: word max_chunk_size: 10 - overlap: 5 + overlap: 1 - do: index: - index: default-chunking + index: default-chunking-sparse id: doc_1 body: keyword_field: "default sentence chunking" @@ -74,7 +86,7 @@ setup: - do: index: - index: custom-chunking + index: custom-chunking-sparse id: doc_2 body: keyword_field: "custom word chunking" @@ -86,13 +98,13 @@ setup: - do: indices.get_mapping: - index: default-chunking + index: default-chunking-sparse - is_false: default-chunking.mappings.properties.inference_field.chunking_settings - do: indices.get_mapping: - index: custom-chunking + index: custom-chunking-sparse - match: { "custom-chunking.mappings.properties.inference_field.chunking_settings.strategy": "word" } - match: { "custom-chunking.mappings.properties.inference_field.chunking_settings.max_chunk_size": 10 } @@ -103,7 +115,7 @@ setup: - do: search: - index: default-chunking + index: default-chunking-sparse body: query: semantic: @@ -122,7 +134,7 @@ setup: - do: search: - index: custom-chunking + index: custom-chunking-sparse body: query: semantic: @@ -138,5 +150,5 @@ setup: - match: { hits.hits.0._id: "doc_2" } - length: { hits.hits.0.highlight.inference_field: 2 } - match: { hits.hits.0.highlight.inference_field.0: "Elasticsearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all" } - - match: { hits.hits.0.highlight.inference_field.1: " Lucene internally and enjoys all the features it provides." } + - match: { hits.hits.0.highlight.inference_field.1: " the features it provides." 
} From 7cf7589125d7317af40e5ad0e0716851a2550d7b Mon Sep 17 00:00:00 2001 From: Kathleen DeRusso Date: Wed, 5 Mar 2025 09:17:55 -0500 Subject: [PATCH 27/86] Trying to get tests to work --- .../inference/mapper/SemanticTextField.java | 18 +++++------- .../ShardBulkInferenceActionFilterTests.java | 14 ++++++++- ...cInferenceMetadataFieldsRecoveryTests.java | 29 +++++++++++++++++-- .../mapper/SemanticTextFieldMapperTests.java | 28 ++++++++++++++---- .../mapper/SemanticTextFieldTests.java | 15 ++++++++-- 5 files changed, 82 insertions(+), 22 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java index 94100215b1a8c..1313cf665eb42 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java @@ -28,6 +28,7 @@ import org.elasticsearch.xcontent.XContentParserConfiguration; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xcontent.support.MapXContentParser; +import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder; import java.io.IOException; import java.util.ArrayList; @@ -50,6 +51,7 @@ * {@link IndexVersions#INFERENCE_METADATA_FIELDS}, null otherwise. * @param inference The inference result. * @param contentType The {@link XContentType} used to store the embeddings chunks. + * @param chunkingSettings The {@link ChunkingSettings} used to override model chunking defaults */ public record SemanticTextField( boolean useLegacyFormat, @@ -71,16 +73,7 @@ public record SemanticTextField( static final String CHUNKED_START_OFFSET_FIELD = "start_offset"; static final String CHUNKED_END_OFFSET_FIELD = "end_offset"; static final String MODEL_SETTINGS_FIELD = "model_settings"; - static final String TASK_TYPE_FIELD = "task_type"; - static final String DIMENSIONS_FIELD = "dimensions"; - static final String SIMILARITY_FIELD = "similarity"; - static final String ELEMENT_TYPE_FIELD = "element_type"; - // Chunking settings static final String CHUNKING_SETTINGS_FIELD = "chunking_settings"; - static final String STRATEGY_FIELD = "strategy"; - static final String MAX_CHUNK_SIZE_FIELD = "max_chunk_size"; - static final String OVERLAP_FIELD = "overlap"; - static final String SENTENCE_OVERLAP_FIELD = "sentence_overlap"; public record InferenceResult(String inferenceId, MinimalServiceSettings modelSettings, Map> chunks) {} @@ -194,6 +187,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws private static final ConstructingObjectParser SEMANTIC_TEXT_FIELD_PARSER = new ConstructingObjectParser<>(SemanticTextFieldMapper.CONTENT_TYPE, true, (args, context) -> { List originalValues = (List) args[0]; + InferenceResult inferenceResult = (InferenceResult) args[1]; + Map chunkingSettingsMap = (Map) args[2]; + ChunkingSettings chunkingSettings = chunkingSettingsMap != null ? 
ChunkingSettingsBuilder.fromMap(chunkingSettingsMap) : null; if (context.useLegacyFormat() == false) { if (originalValues != null && originalValues.isEmpty() == false) { throw new IllegalArgumentException("Unknown field [" + TEXT_FIELD + "]"); @@ -204,9 +200,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws context.useLegacyFormat(), context.fieldName(), originalValues, - (InferenceResult) args[1], + inferenceResult, context.xContentType(), - (ChunkingSettings) args[2] + chunkingSettings ); }); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java index 4958f71780e7d..e5f5aea837fdc 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java @@ -55,6 +55,7 @@ import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError; import org.elasticsearch.xpack.inference.InferencePlugin; +import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder; import org.elasticsearch.xpack.inference.mapper.SemanticTextField; import org.elasticsearch.xpack.inference.model.TestModel; import org.elasticsearch.xpack.inference.registry.ModelRegistry; @@ -543,6 +544,9 @@ private static BulkItemRequest[] randomBulkItemRequest( for (var entry : fieldInferenceMap.values()) { String field = entry.getName(); var model = modelMap.get(entry.getInferenceId()); + ChunkingSettings chunkingSettings = entry.getChunkingSettings() != null + ? 
ChunkingSettingsBuilder.fromMap(new HashMap<>(entry.getChunkingSettings())) + : null; Object inputObject = randomSemanticTextInput(); String inputText = inputObject.toString(); docMap.put(field, inputObject); @@ -562,13 +566,21 @@ private static BulkItemRequest[] randomBulkItemRequest( useLegacyFormat, field, model, + chunkingSettings, List.of(inputText), results, requestContentType ); } else { Map> inputTextMap = Map.of(field, List.of(inputText)); - semanticTextField = randomSemanticText(useLegacyFormat, field, model, List.of(inputText), requestContentType); + semanticTextField = randomSemanticText( + useLegacyFormat, + field, + model, + chunkingSettings, + List.of(inputText), + requestContentType + ); model.putResult(inputText, toChunkedResult(useLegacyFormat, inputTextMap, semanticTextField)); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticInferenceMetadataFieldsRecoveryTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticInferenceMetadataFieldsRecoveryTests.java index 8bf824fb2ed59..16901f78c0829 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticInferenceMetadataFieldsRecoveryTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticInferenceMetadataFieldsRecoveryTests.java @@ -28,6 +28,7 @@ import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.index.translog.Translog; import org.elasticsearch.inference.ChunkedInference; +import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; @@ -43,6 +44,7 @@ import java.util.List; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.generateRandomChunkingSettings; import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.randomChunkedInferenceEmbeddingByte; import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.randomChunkedInferenceEmbeddingSparse; import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.semanticTextFieldFromChunkedInferenceResults; @@ -51,12 +53,14 @@ public class SemanticInferenceMetadataFieldsRecoveryTests extends EngineTestCase { private final Model model1; private final Model model2; + private final ChunkingSettings chunkingSettings; private final boolean useSynthetic; private final boolean useIncludesExcludes; public SemanticInferenceMetadataFieldsRecoveryTests(boolean useSynthetic, boolean useIncludesExcludes) { this.model1 = randomModel(TaskType.TEXT_EMBEDDING); this.model2 = randomModel(TaskType.SPARSE_EMBEDDING); + this.chunkingSettings = generateRandomChunkingSettings(); this.useSynthetic = useSynthetic; this.useIncludesExcludes = useIncludesExcludes; } @@ -105,6 +109,11 @@ protected String defaultMapping() { builder.field("similarity", model1.getServiceSettings().similarity().name()); builder.field("element_type", model1.getServiceSettings().elementType().name()); builder.endObject(); + if (chunkingSettings != null) { + builder.startObject("chunking_settings"); + chunkingSettings.toXContent(builder, null); + builder.endObject(); + } builder.endObject(); builder.startObject("semantic_2"); @@ -113,6 +122,11 @@ protected String defaultMapping() { 
builder.startObject("model_settings"); builder.field("task_type", model2.getTaskType().name()); builder.endObject(); + if (chunkingSettings != null) { + builder.startObject("chunking_settings"); + chunkingSettings.toXContent(builder, null); + builder.endObject(); + } builder.endObject(); builder.endObject(); @@ -244,8 +258,8 @@ private BytesReference randomSource() throws IOException { false, builder, List.of( - randomSemanticText(false, "semantic_2", model2, randomInputs(), XContentType.JSON), - randomSemanticText(false, "semantic_1", model1, randomInputs(), XContentType.JSON) + randomSemanticText(false, "semantic_2", model2, chunkingSettings, randomInputs(), XContentType.JSON), + randomSemanticText(false, "semantic_1", model1, chunkingSettings, randomInputs(), XContentType.JSON) ) ); builder.endObject(); @@ -256,6 +270,7 @@ private static SemanticTextField randomSemanticText( boolean useLegacyFormat, String fieldName, Model model, + ChunkingSettings chunkingSettings, List inputs, XContentType contentType ) throws IOException { @@ -267,7 +282,15 @@ private static SemanticTextField randomSemanticText( case SPARSE_EMBEDDING -> randomChunkedInferenceEmbeddingSparse(inputs, false); default -> throw new AssertionError("invalid task type: " + model.getTaskType().name()); }; - return semanticTextFieldFromChunkedInferenceResults(useLegacyFormat, fieldName, model, inputs, results, contentType); + return semanticTextFieldFromChunkedInferenceResults( + useLegacyFormat, + fieldName, + model, + chunkingSettings, + inputs, + results, + contentType + ); } private static List randomInputs() { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java index f40c496e0fabe..ccf71fd05904d 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java @@ -54,6 +54,7 @@ import org.elasticsearch.index.mapper.vectors.XFeatureField; import org.elasticsearch.index.query.SearchExecutionContext; import org.elasticsearch.index.search.ESToParentBlockJoinQuery; +import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.MinimalServiceSettings; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.SimilarityMeasure; @@ -90,6 +91,7 @@ import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.getChunksFieldName; import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.getEmbeddingsFieldName; import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper.DEFAULT_ELSER_2_INFERENCE_ID; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.generateRandomChunkingSettings; import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.randomSemanticText; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; @@ -642,6 +644,7 @@ public void testSuccessfulParse() throws IOException { Model model1 = TestModel.createRandomInstance(TaskType.SPARSE_EMBEDDING); Model model2 = TestModel.createRandomInstance(TaskType.SPARSE_EMBEDDING); + ChunkingSettings chunkingSettings = generateRandomChunkingSettings(); XContentBuilder mapping = mapping(b -> { addSemanticTextMapping(b, fieldName1, 
model1.getInferenceEntityId(), setSearchInferenceId ? searchInferenceId : null); addSemanticTextMapping(b, fieldName2, model2.getInferenceEntityId(), setSearchInferenceId ? searchInferenceId : null); @@ -670,8 +673,15 @@ public void testSuccessfulParse() throws IOException { useLegacyFormat, b, List.of( - randomSemanticText(useLegacyFormat, fieldName1, model1, List.of("a b", "c"), XContentType.JSON), - randomSemanticText(useLegacyFormat, fieldName2, model2, List.of("d e f"), XContentType.JSON) + randomSemanticText( + useLegacyFormat, + fieldName1, + model1, + chunkingSettings, + List.of("a b", "c"), + XContentType.JSON + ), + randomSemanticText(useLegacyFormat, fieldName2, model2, chunkingSettings, List.of("d e f"), XContentType.JSON) ) ) ) @@ -842,7 +852,15 @@ public void testDenseVectorElementType() throws IOException { public void testModelSettingsRequiredWithChunks() throws IOException { // Create inference results where model settings are set to null and chunks are provided Model model = TestModel.createRandomInstance(TaskType.SPARSE_EMBEDDING); - SemanticTextField randomSemanticText = randomSemanticText(useLegacyFormat, "field", model, List.of("a"), XContentType.JSON); + ChunkingSettings chunkingSettings = generateRandomChunkingSettings(); + SemanticTextField randomSemanticText = randomSemanticText( + useLegacyFormat, + "field", + model, + chunkingSettings, + List.of("a"), + XContentType.JSON + ); SemanticTextField inferenceResults = new SemanticTextField( randomSemanticText.useLegacyFormat(), randomSemanticText.fieldName(), @@ -853,7 +871,7 @@ public void testModelSettingsRequiredWithChunks() throws IOException { randomSemanticText.inference().chunks() ), randomSemanticText.contentType(), - SemanticTextFieldTests.generateRandomChunkingSettings() + chunkingSettings ); MapperService mapperService = createMapperService( @@ -898,7 +916,7 @@ private MapperService mapperServiceForFieldWithModelSettings( List.of(), new SemanticTextField.InferenceResult(inferenceId, modelSettings, Map.of()), XContentType.JSON, - SemanticTextFieldTests.generateRandomChunkingSettings() + generateRandomChunkingSettings() ); XContentBuilder builder = JsonXContent.contentBuilder().startObject(); if (useLegacyFormat) { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java index cbecea37a7fbc..d33e191b748be 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java @@ -118,6 +118,7 @@ protected SemanticTextField createTestInstance() { useLegacyFormat, NAME, TestModel.createRandomInstance(), + generateRandomChunkingSettings(), rawValues, randomFrom(XContentType.values()) ); @@ -218,6 +219,7 @@ public static SemanticTextField randomSemanticText( boolean useLegacyFormat, String fieldName, Model model, + ChunkingSettings chunkingSettings, List inputs, XContentType contentType ) throws IOException { @@ -229,13 +231,22 @@ public static SemanticTextField randomSemanticText( case SPARSE_EMBEDDING -> randomChunkedInferenceEmbeddingSparse(inputs); default -> throw new AssertionError("invalid task type: " + model.getTaskType().name()); }; - return semanticTextFieldFromChunkedInferenceResults(useLegacyFormat, fieldName, model, inputs, results, contentType); + return 
semanticTextFieldFromChunkedInferenceResults( + useLegacyFormat, + fieldName, + model, + chunkingSettings, + inputs, + results, + contentType + ); } public static SemanticTextField semanticTextFieldFromChunkedInferenceResults( boolean useLegacyFormat, String fieldName, Model model, + ChunkingSettings chunkingSettings, List inputs, ChunkedInference results, XContentType contentType @@ -273,7 +284,7 @@ public static SemanticTextField semanticTextFieldFromChunkedInferenceResults( Map.of(fieldName, chunks) ), contentType, - generateRandomChunkingSettings() + chunkingSettings ); } From 89e040da3b88494d755802a894e94adf7842ba92 Mon Sep 17 00:00:00 2001 From: Kathleen DeRusso Date: Wed, 5 Mar 2025 11:14:49 -0500 Subject: [PATCH 28/86] Shard bulk inference test never specifies chunking settings --- .../ShardBulkInferenceActionFilterTests.java | 22 +++---------------- 1 file changed, 3 insertions(+), 19 deletions(-) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java index e5f5aea837fdc..4ec24a4ce1e72 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java @@ -36,7 +36,6 @@ import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.index.shard.ShardId; import org.elasticsearch.inference.ChunkedInference; -import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceRegistry; import org.elasticsearch.inference.Model; @@ -55,7 +54,6 @@ import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError; import org.elasticsearch.xpack.inference.InferencePlugin; -import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder; import org.elasticsearch.xpack.inference.mapper.SemanticTextField; import org.elasticsearch.xpack.inference.model.TestModel; import org.elasticsearch.xpack.inference.registry.ModelRegistry; @@ -79,7 +77,6 @@ import static org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter.getIndexRequestOrNull; import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.getChunksFieldName; import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.getOriginalTextFieldName; -import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.generateRandomChunkingSettings; import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.randomChunkedInferenceEmbeddingSparse; import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.randomSemanticText; import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.randomSemanticTextInput; @@ -383,10 +380,7 @@ public void testManyRandomDocs() throws Exception { for (int i = 0; i < numInferenceFields; i++) { String field = randomAlphaOfLengthBetween(5, 10); String inferenceId = randomFrom(inferenceModelMap.keySet()); - Map chunkingSettingsMap = Optional.ofNullable(generateRandomChunkingSettings()) - .map(ChunkingSettings::asMap) - .orElse(null); - inferenceFieldMap.put(field, new 
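// the trailing constructor argument carries the field's chunking settings map; passing null keeps the inference endpoint's default chunking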
InferenceFieldMetadata(field, inferenceId, new String[] { field }, chunkingSettingsMap)); + inferenceFieldMap.put(field, new InferenceFieldMetadata(field, inferenceId, new String[] { field }, null)); } int numRequests = atLeast(100); @@ -544,9 +538,6 @@ private static BulkItemRequest[] randomBulkItemRequest( for (var entry : fieldInferenceMap.values()) { String field = entry.getName(); var model = modelMap.get(entry.getInferenceId()); - ChunkingSettings chunkingSettings = entry.getChunkingSettings() != null - ? ChunkingSettingsBuilder.fromMap(new HashMap<>(entry.getChunkingSettings())) - : null; Object inputObject = randomSemanticTextInput(); String inputText = inputObject.toString(); docMap.put(field, inputObject); @@ -566,21 +557,14 @@ private static BulkItemRequest[] randomBulkItemRequest( useLegacyFormat, field, model, - chunkingSettings, + null, List.of(inputText), results, requestContentType ); } else { Map> inputTextMap = Map.of(field, List.of(inputText)); - semanticTextField = randomSemanticText( - useLegacyFormat, - field, - model, - chunkingSettings, - List.of(inputText), - requestContentType - ); + semanticTextField = randomSemanticText(useLegacyFormat, field, model, null, List.of(inputText), requestContentType); model.putResult(inputText, toChunkedResult(useLegacyFormat, inputTextMap, semanticTextField)); } From 233defd5eef0555986cbe87bfd97bbb7a7e92bae Mon Sep 17 00:00:00 2001 From: Kathleen DeRusso Date: Wed, 5 Mar 2025 14:23:22 -0500 Subject: [PATCH 29/86] Fix test --- .../SemanticInferenceMetadataFieldsRecoveryTests.java | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticInferenceMetadataFieldsRecoveryTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticInferenceMetadataFieldsRecoveryTests.java index 16901f78c0829..f5e3ad8d9c704 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticInferenceMetadataFieldsRecoveryTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticInferenceMetadataFieldsRecoveryTests.java @@ -110,9 +110,8 @@ protected String defaultMapping() { builder.field("element_type", model1.getServiceSettings().elementType().name()); builder.endObject(); if (chunkingSettings != null) { - builder.startObject("chunking_settings"); + builder.field("chunking_settings"); chunkingSettings.toXContent(builder, null); - builder.endObject(); } builder.endObject(); @@ -123,9 +122,8 @@ protected String defaultMapping() { builder.field("task_type", model2.getTaskType().name()); builder.endObject(); if (chunkingSettings != null) { - builder.startObject("chunking_settings"); + builder.field("chunking_settings"); chunkingSettings.toXContent(builder, null); - builder.endObject(); } builder.endObject(); From 0b2ebf6cd83ac1dc765547f26f6f180b8685dedf Mon Sep 17 00:00:00 2001 From: Kathleen DeRusso Date: Wed, 5 Mar 2025 14:42:08 -0500 Subject: [PATCH 30/86] Always process batches in order --- .../filter/ShardBulkInferenceActionFilter.java | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java index 3bddc0ec44a7d..9e097fc4cfb42 100644 --- 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java @@ -317,18 +317,20 @@ public void onFailure(Exception exc) { modelRegistry.getModelWithSecrets(inferenceId, modelLoadingListener); return; } + // TODO More efficiently batch requests int currentBatchSize = Math.min(requests.size(), batchSize); - final ChunkingSettings chunkingSettings = requests.getFirst().chunkingSettings; - final List nextBatch = new ArrayList<>(); - final List inputs = new ArrayList<>(); + final List currentBatch = new ArrayList<>(); for (FieldInferenceRequest request : requests) { - if (Objects.equals(chunkingSettings, request.chunkingSettings) && inputs.size() < currentBatchSize) { - inputs.add(request.input); - } else { - nextBatch.add(request); + if (Objects.equals(request.chunkingSettings, chunkingSettings) == false || currentBatch.size() >= currentBatchSize) { + break; } + currentBatch.add(request); } + + final List nextBatch = requests.subList(currentBatch.size(), requests.size()); + final List inputs = currentBatch.stream().map(FieldInferenceRequest::input).collect(Collectors.toList()); + ActionListener> completionListener = new ActionListener<>() { @Override public void onResponse(List results) { From a807ab597d4ce148e6f4707818bbed211e86337f Mon Sep 17 00:00:00 2001 From: Kathleen DeRusso Date: Wed, 5 Mar 2025 16:46:57 -0500 Subject: [PATCH 31/86] Fix chunking in test inference service and yaml tests --- .../mock/AbstractTestInferenceService.java | 24 ++-- .../TestDenseInferenceServiceExtension.java | 43 ++++---- .../TestSparseInferenceServiceExtension.java | 34 +++--- ...5_semantic_text_field_mapping_chunking.yml | 103 ++++++++++++++++-- 4 files changed, 146 insertions(+), 58 deletions(-) diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/AbstractTestInferenceService.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/AbstractTestInferenceService.java index 024866ee718a5..67205a5f01d02 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/AbstractTestInferenceService.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/AbstractTestInferenceService.java @@ -29,6 +29,7 @@ import java.io.IOException; import java.util.ArrayList; +import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -111,25 +112,24 @@ public void start(Model model, TimeValue timeout, ActionListener listen @Override public void close() throws IOException {} - protected List chunkInputs(List input, ChunkingSettings chunkingSettings) { + protected List chunkInputs(String input, ChunkingSettings chunkingSettings) { if (chunkingSettings == null) { - return input; + return Collections.singletonList(input); } List chunkedInputs = new ArrayList<>(); ChunkingStrategy chunkingStrategy = chunkingSettings.getChunkingStrategy(); if (chunkingStrategy == ChunkingStrategy.WORD) { WordBoundaryChunker chunker = new WordBoundaryChunker(); - for (String inputString : input) { - WordBoundaryChunkingSettings wordBoundaryChunkingSettings = (WordBoundaryChunkingSettings) chunkingSettings; - List offsets = chunker.chunk( - inputString, - wordBoundaryChunkingSettings.maxChunkSize(), - 
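// for the word strategy both arguments are word counts: the maximum number of words per chunk and the overlap carried over between consecutive chunks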
wordBoundaryChunkingSettings.overlap() - ); - for (WordBoundaryChunker.ChunkOffset offset : offsets) { - chunkedInputs.add(inputString.substring(offset.start(), offset.end())); - } + WordBoundaryChunkingSettings wordBoundaryChunkingSettings = (WordBoundaryChunkingSettings) chunkingSettings; + List offsets = chunker.chunk( + input, + wordBoundaryChunkingSettings.maxChunkSize(), + wordBoundaryChunkingSettings.overlap() + ); + for (WordBoundaryChunker.ChunkOffset offset : offsets) { + chunkedInputs.add(input.substring(offset.start(), offset.end())); } + } else { // Won't implement till we need it throw new UnsupportedOperationException("Test inference service only supports word chunking strategies"); diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java index 328426eb4a86f..41c462a4648ad 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java @@ -36,7 +36,9 @@ import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding; +import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.ml.search.WeightedToken; import java.io.IOException; import java.nio.charset.StandardCharsets; @@ -175,27 +177,28 @@ private TextEmbeddingFloatResults makeResults(List input, int dimensions return new TextEmbeddingFloatResults(embeddings); } - private List makeChunkedResults(List input, int dimensions, ChunkingSettings chunkingSettings) { - List chunkedInputs = chunkInputs(input, chunkingSettings); - return makeChunkedResults(chunkedInputs, dimensions); - } - - private List makeChunkedResults(List input, int dimensions) { - TextEmbeddingFloatResults nonChunkedResults = makeResults(input, dimensions); - - var results = new ArrayList(); - for (int i = 0; i < input.size(); i++) { - results.add( - new ChunkedInferenceEmbedding( - List.of( - new TextEmbeddingFloatResults.Chunk( - nonChunkedResults.embeddings().get(i).values(), - input.get(i), - new ChunkedInference.TextOffset(0, input.get(i).length()) - ) + private List makeChunkedResults(List inputs, int dimensions, ChunkingSettings chunkingSettings) { + + List results = new ArrayList<>(); + for (int i = 0; i < inputs.size(); i++) { + String input = inputs.get(i); + TextEmbeddingFloatResults nonChunkedResults = makeResults(inputs, dimensions); + List chunkedInput = chunkInputs(input, chunkingSettings); + List chunks = new ArrayList<>(); + int offset = 0; + for (String c : chunkedInput) { + offset = input.indexOf(c, offset); + int endOffset = offset + c.length(); + chunks.add( + new TextEmbeddingFloatResults.Chunk( + nonChunkedResults.embeddings().get(i).values(), + c, + new ChunkedInference.TextOffset(offset, endOffset) ) - ) - ); + ); + } + ChunkedInferenceEmbedding chunkedInferenceEmbedding = new ChunkedInferenceEmbedding(chunks); + results.add(chunkedInferenceEmbedding); } return results; } diff --git 
a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java index 378cf1d9994b1..283fdfc6c2b79 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java @@ -34,6 +34,7 @@ import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding; +import org.elasticsearch.xpack.core.inference.results.EmbeddingResults; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; import org.elasticsearch.xpack.core.ml.search.WeightedToken; @@ -166,29 +167,24 @@ private SparseEmbeddingResults makeResults(List input) { return new SparseEmbeddingResults(embeddings); } - private List makeChunkedResults(List input, ChunkingSettings chunkingSettings) { - List chunkedInputs = chunkInputs(input, chunkingSettings); - return makeChunkedResults(chunkedInputs); - } - - private List makeChunkedResults(List input) { + private List makeChunkedResults(List inputs, ChunkingSettings chunkingSettings) { List results = new ArrayList<>(); - for (int i = 0; i < input.size(); i++) { + for (int i = 0; i < inputs.size(); i++) { + String input = inputs.get(i); var tokens = new ArrayList(); for (int j = 0; j < 5; j++) { - tokens.add(new WeightedToken("feature_" + j, generateEmbedding(input.get(i), j))); + tokens.add(new WeightedToken("feature_" + j, generateEmbedding(input, j))); } - results.add( - new ChunkedInferenceEmbedding( - List.of( - new SparseEmbeddingResults.Chunk( - tokens, - input.get(i), - new ChunkedInference.TextOffset(0, input.get(i).length()) - ) - ) - ) - ); + List chunkedInput = chunkInputs(input, chunkingSettings); + List chunks = new ArrayList<>(); + int offset = 0; + for (String c : chunkedInput) { + offset = input.indexOf(c, offset); + int endOffset = offset + c.length(); + chunks.add(new SparseEmbeddingResults.Chunk(tokens, c, new ChunkedInference.TextOffset(offset, endOffset))); + } + ChunkedInferenceEmbedding chunkedInferenceEmbedding = new ChunkedInferenceEmbedding(chunks); + results.add(chunkedInferenceEmbedding); } return results; } diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/25_semantic_text_field_mapping_chunking.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/25_semantic_text_field_mapping_chunking.yml index ce1274fc1c5aa..51eacabf6b070 100644 --- a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/25_semantic_text_field_mapping_chunking.yml +++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/25_semantic_text_field_mapping_chunking.yml @@ -75,6 +75,22 @@ setup: max_chunk_size: 10 overlap: 1 + - do: + indices.create: + index: custom-chunking-dense + body: + mappings: + properties: + keyword_field: + type: keyword + inference_field: + type: semantic_text + inference_id: dense-inference-id + chunking_settings: + strategy: word + max_chunk_size: 10 + overlap: 1 + - do: index: index: default-chunking-sparse @@ -93,6 +109,24 @@ setup: inference_field: 
"Elasticsearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." refresh: true + - do: + index: + index: default-chunking-dense + id: doc_3 + body: + keyword_field: "default sentence chunking" + inference_field: "Elasticsearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." + refresh: true + + - do: + index: + index: custom-chunking-dense + id: doc_4 + body: + keyword_field: "custom word chunking" + inference_field: "Elasticsearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." + refresh: true + --- "We return chunking configurations with mappings": @@ -100,18 +134,32 @@ setup: indices.get_mapping: index: default-chunking-sparse - - is_false: default-chunking.mappings.properties.inference_field.chunking_settings + - is_false: default-chunking-sparse.mappings.properties.inference_field.chunking_settings - do: indices.get_mapping: index: custom-chunking-sparse - - match: { "custom-chunking.mappings.properties.inference_field.chunking_settings.strategy": "word" } - - match: { "custom-chunking.mappings.properties.inference_field.chunking_settings.max_chunk_size": 10 } - - match: { "custom-chunking.mappings.properties.inference_field.chunking_settings.overlap": 5 } + - match: { "custom-chunking-sparse.mappings.properties.inference_field.chunking_settings.strategy": "word" } + - match: { "custom-chunking-sparse.mappings.properties.inference_field.chunking_settings.max_chunk_size": 10 } + - match: { "custom-chunking-sparse.mappings.properties.inference_field.chunking_settings.overlap": 1 } + + - do: + indices.get_mapping: + index: default-chunking-dense + + - is_false: default-chunking-dense.mappings.properties.inference_field.chunking_settings + + - do: + indices.get_mapping: + index: custom-chunking-dense + + - match: { "custom-chunking-dense.mappings.properties.inference_field.chunking_settings.strategy": "word" } + - match: { "custom-chunking-dense.mappings.properties.inference_field.chunking_settings.max_chunk_size": 10 } + - match: { "custom-chunking-dense.mappings.properties.inference_field.chunking_settings.overlap": 1 } --- -"We return different chunks based on configured chunking overrides or model defaults": +"We return different chunks based on configured chunking overrides or model defaults for sparse embeddings": - do: search: @@ -149,6 +197,47 @@ setup: - match: { hits.total.value: 1 } - match: { hits.hits.0._id: "doc_2" } - length: { hits.hits.0.highlight.inference_field: 2 } - - match: { hits.hits.0.highlight.inference_field.0: "Elasticsearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all" } - - match: { hits.hits.0.highlight.inference_field.1: " the features it provides." } + - match: { hits.hits.0.highlight.inference_field.0: "Elasticsearch is an open source, distributed, RESTful, search engine which" } + - match: { hits.hits.0.highlight.inference_field.1: " which is built on top of Lucene internally and enjoys" } +--- +"We return different chunks based on configured chunking overrides or model defaults for dense embeddings": + + - do: + search: + index: default-chunking-dense + body: + query: + semantic: + field: "inference_field" + query: "What is Elasticsearch?" 
+ highlight: + fields: + inference_field: + type: "semantic" + number_of_fragments: 2 + + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_3" } + - length: { hits.hits.0.highlight.inference_field: 1 } + - match: { hits.hits.0.highlight.inference_field.0: "Elasticsearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } + + - do: + search: + index: custom-chunking-dense + body: + query: + semantic: + field: "inference_field" + query: "What is Elasticsearch?" + highlight: + fields: + inference_field: + type: "semantic" + number_of_fragments: 2 + + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_4" } + - length: { hits.hits.0.highlight.inference_field: 2 } + - match: { hits.hits.0.highlight.inference_field.0: "Elasticsearch is an open source, distributed, RESTful, search engine which" } + - match: { hits.hits.0.highlight.inference_field.1: " which is built on top of Lucene internally and enjoys" } From dc48c28620be9306a40a57ca933fbec7390b2fbe Mon Sep 17 00:00:00 2001 From: elasticsearchmachine Date: Wed, 5 Mar 2025 22:04:41 +0000 Subject: [PATCH 32/86] [CI] Auto commit changes from spotless --- .../inference/mock/TestDenseInferenceServiceExtension.java | 2 -- .../inference/mock/TestSparseInferenceServiceExtension.java | 1 - 2 files changed, 3 deletions(-) diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java index 41c462a4648ad..5f8531c2678fb 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java @@ -36,9 +36,7 @@ import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding; -import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; -import org.elasticsearch.xpack.core.ml.search.WeightedToken; import java.io.IOException; import java.nio.charset.StandardCharsets; diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java index 283fdfc6c2b79..fc6065d7173d7 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java @@ -34,7 +34,6 @@ import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding; -import org.elasticsearch.xpack.core.inference.results.EmbeddingResults; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; import 
org.elasticsearch.xpack.core.ml.search.WeightedToken; From 3ed563113d8275f7eeee133f60a6f28346a759c5 Mon Sep 17 00:00:00 2001 From: Kathleen DeRusso Date: Thu, 6 Mar 2025 09:16:01 -0500 Subject: [PATCH 33/86] Refactor - remove convenience method with default chunking settings --- .../inference/InferenceService.java | 23 ------------------- .../ShardBulkInferenceActionFilterTests.java | 2 +- .../AlibabaCloudSearchServiceTests.java | 11 ++++++++- .../AmazonBedrockServiceTests.java | 1 + .../AzureAiStudioServiceTests.java | 1 + .../azureopenai/AzureOpenAiServiceTests.java | 1 + .../services/cohere/CohereServiceTests.java | 2 ++ .../elastic/ElasticInferenceServiceTests.java | 1 + .../ElasticsearchInternalServiceTests.java | 8 +++++++ .../GoogleAiStudioServiceTests.java | 11 ++++++++- .../HuggingFaceElserServiceTests.java | 1 + .../huggingface/HuggingFaceServiceTests.java | 2 ++ .../ibmwatsonx/IbmWatsonxServiceTests.java | 11 ++++++++- .../services/jinaai/JinaAIServiceTests.java | 1 + .../services/mistral/MistralServiceTests.java | 1 + .../services/openai/OpenAiServiceTests.java | 1 + .../voyageai/VoyageAIServiceTests.java | 1 + 17 files changed, 52 insertions(+), 27 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/inference/InferenceService.java b/server/src/main/java/org/elasticsearch/inference/InferenceService.java index 8b045100e082f..646e781bc30b0 100644 --- a/server/src/main/java/org/elasticsearch/inference/InferenceService.java +++ b/server/src/main/java/org/elasticsearch/inference/InferenceService.java @@ -126,29 +126,6 @@ void unifiedCompletionInfer( ActionListener listener ); - /** - * Chunk long text. - * - * @param model The model - * @param query Inference query, mainly for re-ranking - * @param input Inference input - * @param taskSettings Settings in the request to override the model's defaults - * @param inputType For search, ingest etc - * @param timeout The timeout for the request - * @param listener Chunked Inference result listener - */ - default void chunkedInfer( - Model model, - @Nullable String query, - List input, - Map taskSettings, - InputType inputType, - TimeValue timeout, - ActionListener> listener - ) { - chunkedInfer(model, query, input, taskSettings, null, inputType, timeout, listener); - } - /** * Chunk long text. 
* diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java index 4ec24a4ce1e72..67ddfb0f62cf6 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java @@ -481,7 +481,7 @@ private static ShardBulkInferenceActionFilter createFilter( } return null; }; - doAnswer(chunkedInferAnswer).when(inferenceService).chunkedInfer(any(), any(), any(), any(), any(), any(), any()); + doAnswer(chunkedInferAnswer).when(inferenceService).chunkedInfer(any(), any(), any(), any(), any(), any(), any(), any()); Answer modelAnswer = invocationOnMock -> { String inferenceId = (String) invocationOnMock.getArguments()[0]; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java index c4c6b69b117bc..f17be8b4d7a5d 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java @@ -386,7 +386,16 @@ private void testChunkedInfer(TaskType taskType, ChunkingSettings chunkingSettin var model = createModelForTaskType(taskType, chunkingSettings); PlainActionFuture> listener = new PlainActionFuture<>(); - service.chunkedInfer(model, null, input, new HashMap<>(), InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, listener); + service.chunkedInfer( + model, + null, + input, + new HashMap<>(), + null, + InputType.INGEST, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); var results = listener.actionGet(TIMEOUT); assertThat(results, instanceOf(List.class)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java index 970dab45731bd..4858efd1f3424 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java @@ -1446,6 +1446,7 @@ private void testChunkedInfer(AmazonBedrockEmbeddingsModel model) throws IOExcep null, List.of("abc", "xyz"), new HashMap<>(), + null, InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, listener diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java index cdd8494c9b343..6202457d709be 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java +++ 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java @@ -1193,6 +1193,7 @@ private void testChunkedInfer(AzureAiStudioEmbeddingsModel model) throws IOExcep null, List.of("foo", "bar"), new HashMap<>(), + null, InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, listener diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java index 7ee595cddf084..60b7e15d32245 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java @@ -1343,6 +1343,7 @@ private void testChunkedInfer(AzureOpenAiEmbeddingsModel model) throws IOExcepti null, List.of("foo", "bar"), new HashMap<>(), + null, InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, listener diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java index 7d959b9bff0a0..c4a0bbea19707 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java @@ -1454,6 +1454,7 @@ private void testChunkedInfer(CohereEmbeddingsModel model) throws IOException { null, List.of("foo", "bar"), new HashMap<>(), + null, InputType.UNSPECIFIED, InferenceAction.Request.DEFAULT_TIMEOUT, listener @@ -1553,6 +1554,7 @@ public void testChunkedInfer_BatchesCalls_Bytes() throws IOException { null, List.of("foo", "bar"), new HashMap<>(), + null, InputType.UNSPECIFIED, InferenceAction.Request.DEFAULT_TIMEOUT, listener diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java index 2ecd39b3991b8..bd3e564cdf6f6 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java @@ -552,6 +552,7 @@ public void testChunkedInfer_PassesThrough() throws IOException { null, List.of("input text"), new HashMap<>(), + null, InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, listener diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java index e7e654b599fe6..ba3298d83702a 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java @@ -96,6 +96,7 @@ import static 
org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent; import static org.elasticsearch.xpack.core.ml.action.GetTrainedModelsStatsAction.Response.RESULTS_FIELD; import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettingsMap; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.generateRandomChunkingSettings; import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.MULTILINGUAL_E5_SMALL_MODEL_ID; import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.MULTILINGUAL_E5_SMALL_MODEL_ID_LINUX_X86; import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.NAME; @@ -925,6 +926,7 @@ private void testChunkInfer_e5(ChunkingSettings chunkingSettings) throws Interru null, List.of("foo", "bar"), Map.of(), + chunkingSettings, InputType.SEARCH, InferenceAction.Request.DEFAULT_TIMEOUT, latchedListener @@ -997,6 +999,7 @@ private void testChunkInfer_Sparse(ChunkingSettings chunkingSettings) throws Int null, List.of("foo", "bar"), Map.of(), + chunkingSettings, InputType.SEARCH, InferenceAction.Request.DEFAULT_TIMEOUT, latchedListener @@ -1069,6 +1072,7 @@ private void testChunkInfer_Elser(ChunkingSettings chunkingSettings) throws Inte null, List.of("foo", "bar"), Map.of(), + chunkingSettings, InputType.SEARCH, InferenceAction.Request.DEFAULT_TIMEOUT, latchedListener @@ -1156,6 +1160,7 @@ public void testChunkInfer_FailsBatch() throws InterruptedException { null ); var service = createService(client); + var chunkingSettings = generateRandomChunkingSettings(); var gotResults = new AtomicBoolean(); var resultsListener = ActionListener.>wrap(chunkedResponse -> { @@ -1177,6 +1182,7 @@ public void testChunkInfer_FailsBatch() throws InterruptedException { null, List.of("foo", "bar", "baz"), Map.of(), + chunkingSettings, InputType.SEARCH, InferenceAction.Request.DEFAULT_TIMEOUT, latchedListener @@ -1239,6 +1245,7 @@ public void testChunkingLargeDocument() throws InterruptedException { new MultilingualE5SmallInternalServiceSettings(1, 1, "cross-platform", null), new WordBoundaryChunkingSettings(wordsPerChunk, 0) ); + var chunkingSettings = generateRandomChunkingSettings(); var latch = new CountDownLatch(1); var latchedListener = new LatchedActionListener<>(resultsListener, latch); @@ -1249,6 +1256,7 @@ public void testChunkingLargeDocument() throws InterruptedException { null, List.of(input), Map.of(), + chunkingSettings, InputType.SEARCH, InferenceAction.Request.DEFAULT_TIMEOUT, latchedListener diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java index 9828a4f21ab51..9fb7c634ec8fb 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java @@ -871,7 +871,16 @@ private void testChunkedInfer(String modelId, String apiKey, GoogleAiStudioEmbed webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); PlainActionFuture> listener = new PlainActionFuture<>(); - service.chunkedInfer(model, null, input, new HashMap<>(), InputType.INGEST, 
InferenceAction.Request.DEFAULT_TIMEOUT, listener); + service.chunkedInfer( + model, + null, + input, + new HashMap<>(), + null, + InputType.INGEST, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); var results = listener.actionGet(TIMEOUT); assertThat(results, hasSize(2)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceElserServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceElserServiceTests.java index 1050ac137be8d..3c221c754c325 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceElserServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceElserServiceTests.java @@ -97,6 +97,7 @@ public void testChunkedInfer_CallsInfer_Elser_ConvertsFloatResponse() throws IOE null, List.of("abc"), new HashMap<>(), + null, InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, listener diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java index b9e7cda1461cc..c42afb3f63950 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java @@ -778,6 +778,7 @@ public void testChunkedInfer_CallsInfer_TextEmbedding_ConvertsFloatResponse() th null, List.of("abc"), new HashMap<>(), + null, InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, listener @@ -831,6 +832,7 @@ public void testChunkedInfer() throws IOException { null, List.of("abc"), new HashMap<>(), + null, InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, listener diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java index 74d055d44363d..d46d56057d5d5 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java @@ -723,7 +723,16 @@ private void testChunkedInfer_Batches(ChunkingSettings chunkingSettings) throws getUrl(webServer) ); PlainActionFuture> listener = new PlainActionFuture<>(); - service.chunkedInfer(model, null, input, new HashMap<>(), InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, listener); + service.chunkedInfer( + model, + null, + input, + new HashMap<>(), + null, + InputType.INGEST, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); var results = listener.actionGet(TIMEOUT); assertThat(results, hasSize(2)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java index 5d2ab9e6d2f57..88c3979f42da5 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java +++ 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java @@ -1821,6 +1821,7 @@ private void test_Embedding_ChunkedInfer_BatchesCalls(JinaAIEmbeddingsModel mode null, List.of("foo", "bar"), new HashMap<>(), + null, InputType.UNSPECIFIED, InferenceAction.Request.DEFAULT_TIMEOUT, listener diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java index 6acafd59272ef..6b8384c18277c 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java @@ -672,6 +672,7 @@ public void testChunkedInfer(MistralEmbeddingsModel model) throws IOException { null, List.of("abc", "def"), new HashMap<>(), + null, InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, listener diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java index d608f4a33ff52..b05cf111810d5 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java @@ -1859,6 +1859,7 @@ private void testChunkedInfer(OpenAiEmbeddingsModel model) throws IOException { null, List.of("foo", "bar"), new HashMap<>(), + null, InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, listener diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java index 6a0428e962f52..6486fd194189d 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java @@ -1828,6 +1828,7 @@ private void test_Embedding_ChunkedInfer_BatchesCalls(VoyageAIEmbeddingsModel mo null, List.of("foo", "bar"), new HashMap<>(), + null, InputType.UNSPECIFIED, InferenceAction.Request.DEFAULT_TIMEOUT, listener From c95789d1b435768fdcd72e932cbb6644bb4f7b6c Mon Sep 17 00:00:00 2001 From: Kathleen DeRusso Date: Thu, 6 Mar 2025 09:23:23 -0500 Subject: [PATCH 34/86] Fix ShardBulkInferenceActionFilterTests --- .../action/filter/ShardBulkInferenceActionFilterTests.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java index 67ddfb0f62cf6..db6202e3f7965 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java @@ -462,7 +462,7 @@ private static ShardBulkInferenceActionFilter 
createFilter( Answer chunkedInferAnswer = invocationOnMock -> { StaticModel model = (StaticModel) invocationOnMock.getArguments()[0]; List inputs = (List) invocationOnMock.getArguments()[2]; - ActionListener> listener = (ActionListener>) invocationOnMock.getArguments()[6]; + ActionListener> listener = (ActionListener>) invocationOnMock.getArguments()[7]; Runnable runnable = () -> { List results = new ArrayList<>(); for (String input : inputs) { From 75031e17d26698c992432f5b26b04697c11bd695 Mon Sep 17 00:00:00 2001 From: Kathleen DeRusso Date: Thu, 6 Mar 2025 09:30:24 -0500 Subject: [PATCH 35/86] Fix ElasticsearchInternalServiceTests --- .../elasticsearch/ElasticsearchInternalServiceTests.java | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java index ba3298d83702a..c7f759f3c294f 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java @@ -96,7 +96,6 @@ import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent; import static org.elasticsearch.xpack.core.ml.action.GetTrainedModelsStatsAction.Response.RESULTS_FIELD; import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettingsMap; -import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.generateRandomChunkingSettings; import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.MULTILINGUAL_E5_SMALL_MODEL_ID; import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.MULTILINGUAL_E5_SMALL_MODEL_ID_LINUX_X86; import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.NAME; @@ -1160,7 +1159,6 @@ public void testChunkInfer_FailsBatch() throws InterruptedException { null ); var service = createService(client); - var chunkingSettings = generateRandomChunkingSettings(); var gotResults = new AtomicBoolean(); var resultsListener = ActionListener.>wrap(chunkedResponse -> { @@ -1182,7 +1180,7 @@ public void testChunkInfer_FailsBatch() throws InterruptedException { null, List.of("foo", "bar", "baz"), Map.of(), - chunkingSettings, + null, InputType.SEARCH, InferenceAction.Request.DEFAULT_TIMEOUT, latchedListener @@ -1245,7 +1243,6 @@ public void testChunkingLargeDocument() throws InterruptedException { new MultilingualE5SmallInternalServiceSettings(1, 1, "cross-platform", null), new WordBoundaryChunkingSettings(wordsPerChunk, 0) ); - var chunkingSettings = generateRandomChunkingSettings(); var latch = new CountDownLatch(1); var latchedListener = new LatchedActionListener<>(resultsListener, latch); @@ -1256,7 +1253,7 @@ public void testChunkingLargeDocument() throws InterruptedException { null, List.of(input), Map.of(), - chunkingSettings, + null, InputType.SEARCH, InferenceAction.Request.DEFAULT_TIMEOUT, latchedListener From d051cd24a0cdd5dd365501a91a323fe171dc14ba Mon Sep 17 00:00:00 2001 From: Kathleen DeRusso Date: Thu, 6 Mar 2025 10:19:34 -0500 Subject: [PATCH 36/86] Fix SemanticTextFieldMapperTests --- 
.../inference/mapper/SemanticTextField.java | 18 ++++++++++++++++++ .../mapper/SemanticTextFieldMapper.java | 2 +- 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java index 1313cf665eb42..b05d1075dff17 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java @@ -125,6 +125,24 @@ static MinimalServiceSettings parseModelSettingsFromMap(Object node) { } } + static ChunkingSettings parseChunkingSettingsFromMap(Object node) { + if (node == null) { + return null; + } + try { + Map map = XContentMapValues.nodeMapValue(node, CHUNKING_SETTINGS_FIELD); + XContentParser parser = new MapXContentParser( + NamedXContentRegistry.EMPTY, + DeprecationHandler.IGNORE_DEPRECATIONS, + map, + XContentType.JSON + ); + return ChunkingSettingsBuilder.fromMap(map); + } catch (Exception exc) { + throw new ElasticsearchException(exc); + } + } + @Override public List originalValues() { return originalValues != null ? originalValues : Collections.emptyList(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java index 37f385e101f20..a20846b6b9ffc 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java @@ -184,7 +184,7 @@ public static class Builder extends FieldMapper.Builder { CHUNKING_SETTINGS_FIELD, true, () -> null, - (n, c, o) -> ChunkingSettingsBuilder.fromMap((Map) o), + (n, c, o) -> SemanticTextField.parseChunkingSettingsFromMap(o), mapper -> ((SemanticTextFieldType) mapper.fieldType()).chunkingSettings, XContentBuilder::field, Objects::toString From 8913177bb63a8a200eff38cda9665391c7611b5b Mon Sep 17 00:00:00 2001 From: elasticsearchmachine Date: Thu, 6 Mar 2025 15:30:16 +0000 Subject: [PATCH 37/86] [CI] Auto commit changes from spotless --- .../xpack/inference/mapper/SemanticTextFieldMapper.java | 1 - 1 file changed, 1 deletion(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java index a20846b6b9ffc..3e6ffadd7898c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java @@ -74,7 +74,6 @@ import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults; import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; import org.elasticsearch.xpack.core.ml.search.SparseVectorQueryBuilder; -import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder; import org.elasticsearch.xpack.inference.highlight.SemanticTextHighlighter; import java.io.IOException; From 2ab5aec5c392506e0d42386985f24453e21925c4 Mon Sep 17 00:00:00 2001 From: Kathleen DeRusso Date: Thu, 6 Mar 2025 10:24:15 -0500 Subject: 
[PATCH 38/86] Fix test data to fit within bounds

---
 .../xpack/inference/mapper/SemanticTextFieldMapper.java | 2 --
 .../xpack/inference/mapper/SemanticTextFieldTests.java  | 2 +-
 2 files changed, 1 insertion(+), 3 deletions(-)

diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java
index 3e6ffadd7898c..11d521689651e 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java
@@ -254,7 +254,6 @@ protected void merge(FieldMapper mergeWith, Conflicts conflicts, MapperMergeCont
             inferenceFieldBuilder = c -> mergedInferenceField;
         }

-        @SuppressWarnings("unchecked")
         @Override
         public SemanticTextFieldMapper build(MapperBuilderContext context) {
             if (useLegacyFormat && copyTo.copyToFields().isEmpty() == false) {
@@ -315,7 +314,6 @@ private void validateServiceSettings(MinimalServiceSettings settings) {
      * @param mapper The mapper
      * @return A mapper with the copied settings applied
      */
-    @SuppressWarnings("unchecked")
     private SemanticTextFieldMapper copySettings(SemanticTextFieldMapper mapper, MapperMergeContext mapperMergeContext) {
         SemanticTextFieldMapper returnedMapper = mapper;
         if (mapper.fieldType().getModelSettings() == null) {
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java
index d33e191b748be..7e4d93b61347d 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java
@@ -293,7 +293,7 @@ public static ChunkingSettings generateRandomChunkingSettings() {
             return null; // Use model defaults
         }
         return randomBoolean()
-            ? new WordBoundaryChunkingSettings(randomIntBetween(20, 100), randomIntBetween(0, 50))
+            ?
new WordBoundaryChunkingSettings(randomIntBetween(20, 100), randomIntBetween(0, 10)) : new SentenceBoundaryChunkingSettings(randomIntBetween(20, 100), randomIntBetween(0, 1)); } From 92d70dc51b350304b6822556e22574eb227f7668 Mon Sep 17 00:00:00 2001 From: Kathleen DeRusso Date: Thu, 6 Mar 2025 14:13:34 -0500 Subject: [PATCH 39/86] Add additional yaml test cases --- ...5_semantic_text_field_mapping_chunking.yml | 133 ++++++++++++++++++ 1 file changed, 133 insertions(+) diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/25_semantic_text_field_mapping_chunking.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/25_semantic_text_field_mapping_chunking.yml index 51eacabf6b070..a1bb73ae29a51 100644 --- a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/25_semantic_text_field_mapping_chunking.yml +++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/25_semantic_text_field_mapping_chunking.yml @@ -241,3 +241,136 @@ setup: - length: { hits.hits.0.highlight.inference_field: 2 } - match: { hits.hits.0.highlight.inference_field.0: "Elasticsearch is an open source, distributed, RESTful, search engine which" } - match: { hits.hits.0.highlight.inference_field.1: " which is built on top of Lucene internally and enjoys" } + +--- +"We respect multiple semantic_text fields with different chunking configurations": + + - do: + indices.create: + index: mixed-chunking + body: + mappings: + properties: + keyword_field: + type: keyword + default_chunked_inference_field: + type: semantic_text + inference_id: sparse-inference-id + customized_chunked_inference_field: + type: semantic_text + inference_id: sparse-inference-id + chunking_settings: + strategy: word + max_chunk_size: 10 + overlap: 1 + + - do: + index: + index: mixed-chunking + id: doc_1 + body: + default_chunked_inference_field: "Elasticsearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." + customized_chunked_inference_field: "Elasticsearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." + refresh: true + + - do: + search: + index: mixed-chunking + body: + query: + bool: + should: + - semantic: + field: "default_chunked_inference_field" + query: "What is Elasticsearch?" + - semantic: + field: "customized_chunked_inference_field" + query: "What is Elasticsearch?" + highlight: + fields: + default_chunked_inference_field: + type: "semantic" + number_of_fragments: 2 + customized_chunked_inference_field: + type: "semantic" + number_of_fragments: 2 + + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_1" } + - length: { hits.hits.0.highlight.default_chunked_inference_field: 1 } + - match: { hits.hits.0.highlight.default_chunked_inference_field.0: "Elasticsearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." 
} + - length: { hits.hits.0.highlight.customized_chunked_inference_field: 2 } + - match: { hits.hits.0.highlight.customized_chunked_inference_field.0: "Elasticsearch is an open source, distributed, RESTful, search engine which" } + - match: { hits.hits.0.highlight.customized_chunked_inference_field.1: " which is built on top of Lucene internally and enjoys" } + +--- +"Bulk requests are handled appropriately": + + - do: + indices.create: + index: index1 + body: + mappings: + properties: + keyword_field: + type: keyword + inference_field: + type: semantic_text + inference_id: sparse-inference-id + chunking_settings: + strategy: word + max_chunk_size: 10 + overlap: 1 + + - do: + indices.create: + index: index2 + body: + mappings: + properties: + keyword_field: + type: keyword + inference_field: + type: semantic_text + inference_id: sparse-inference-id + + - do: + bulk: + refresh: true + body: | + { "index": { "_index": "index1", "_id": "doc_1" }} + { "inference_field": "Elasticsearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } + { "index": { "_index": "index2", "_id": "doc_2" }} + { "inference_field": "Elasticsearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } + { "index": { "_index": "index1", "_id": "doc_3" }} + { "inference_field": "Elasticsearch is a free, open-source search engine and analytics tool that stores and indexes data." } + + - do: + search: + index: index1,index2 + body: + query: + semantic: + field: "inference_field" + query: "What is Elasticsearch?" + highlight: + fields: + inference_field: + type: "semantic" + number_of_fragments: 2 + + - match: { hits.total.value: 3 } + + - match: { hits.hits.0._id: "doc_3" } + - length: { hits.hits.0.highlight.inference_field: 2 } + - match: { hits.hits.0.highlight.inference_field.0: "Elasticsearch is a free, open-source search engine and analytics" } + - match: { hits.hits.0.highlight.inference_field.1: " analytics tool that stores and indexes data." } + + - match: { hits.hits.1._id: "doc_1" } + - length: { hits.hits.1.highlight.inference_field: 2 } + - match: { hits.hits.1.highlight.inference_field.0: "Elasticsearch is an open source, distributed, RESTful, search engine which" } + - match: { hits.hits.1.highlight.inference_field.1: " which is built on top of Lucene internally and enjoys" } + + - match: { hits.hits.2._id: "doc_2" } + - length: { hits.hits.2.highlight.inference_field: 1 } + - match: { hits.hits.2.highlight.inference_field.0: "Elasticsearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." 
}

From 11066718738b77dfafb4599296580723b853eb74 Mon Sep 17 00:00:00 2001
From: Kathleen DeRusso
Date: Thu, 6 Mar 2025 16:20:11 -0500
Subject: [PATCH 40/86] Playing with xcontent parsing

---
 .../xpack/inference/mapper/SemanticTextField.java | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java
index b05d1075dff17..a6e68f7d77189 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java
@@ -159,7 +159,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
         builder.field(INFERENCE_ID_FIELD, inference.inferenceId);
         builder.field(MODEL_SETTINGS_FIELD, inference.modelSettings);
         if (chunkingSettings != null) {
-            builder.field(CHUNKING_SETTINGS_FIELD, chunkingSettings);
+            builder.startObject(CHUNKING_SETTINGS_FIELD);
+            builder.mapContents(chunkingSettings.asMap());
+            builder.endObject();
         }

         if (useLegacyFormat) {
@@ -206,8 +208,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
         new ConstructingObjectParser<>(SemanticTextFieldMapper.CONTENT_TYPE, true, (args, context) -> {
             List originalValues = (List) args[0];
             InferenceResult inferenceResult = (InferenceResult) args[1];
-            Map chunkingSettingsMap = (Map) args[2];
-            ChunkingSettings chunkingSettings = chunkingSettingsMap != null ? ChunkingSettingsBuilder.fromMap(chunkingSettingsMap) : null;
+            ChunkingSettings chunkingSettings = (ChunkingSettings) args[2];
             if (context.useLegacyFormat() == false) {
                 if (originalValues != null && originalValues.isEmpty() == false) {
                     throw new IllegalArgumentException("Unknown field [" + TEXT_FIELD + "]");
@@ -243,13 +244,13 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
         }
     );

-    private static final ConstructingObjectParser, Void> CHUNKING_SETTINGS_PARSER = new ConstructingObjectParser<>(
+    private static final ConstructingObjectParser CHUNKING_SETTINGS_PARSER = new ConstructingObjectParser<>(
        CHUNKING_SETTINGS_FIELD,
        true,
        args -> {
            @SuppressWarnings("unchecked")
            Map map = (Map) args[0];
-           return map;
+           return map != null && map.isEmpty() == false ?
ChunkingSettingsBuilder.fromMap(map) : null; } ); @@ -258,7 +259,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws SEMANTIC_TEXT_FIELD_PARSER.declareObject(constructorArg(), INFERENCE_RESULT_PARSER, new ParseField(INFERENCE_FIELD)); SEMANTIC_TEXT_FIELD_PARSER.declareObjectOrNull( optionalConstructorArg(), - (p, c) -> CHUNKING_SETTINGS_PARSER.parse(p, null), + (p, c) -> p.map(), null, new ParseField(CHUNKING_SETTINGS_FIELD) ); From fb2cc28d0aeb03145c78d557524396ceda57534a Mon Sep 17 00:00:00 2001 From: Kathleen DeRusso Date: Fri, 7 Mar 2025 16:49:07 -0500 Subject: [PATCH 41/86] A little cleanup --- .../metadata/InferenceFieldMetadata.java | 4 -- .../ShardBulkInferenceActionFilter.java | 1 - .../mapper/SemanticTextFieldMapper.java | 1 - ...5_semantic_text_field_mapping_chunking.yml | 51 +++++++++++++++++++ 4 files changed, 51 insertions(+), 6 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadata.java b/server/src/main/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadata.java index eb9a3131a162f..0729739144cc7 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadata.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadata.java @@ -64,10 +64,6 @@ public InferenceFieldMetadata( this.searchInferenceId = Objects.requireNonNull(searchInferenceId); this.sourceFields = Objects.requireNonNull(sourceFields); this.chunkingSettings = chunkingSettings; - - if (chunkingSettings != null && chunkingSettings.size() != EXPECTED_CHUNKING_SETTINGS_SIZE) { - throw new IllegalArgumentException("Chunking settings did not contain expected number of entries, was: " + chunkingSettings); - } } public InferenceFieldMetadata(StreamInput input) throws IOException { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java index fee6f3eb572f1..8bdb0e8164793 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java @@ -317,7 +317,6 @@ public void onFailure(Exception exc) { modelRegistry.getModelWithSecrets(inferenceId, modelLoadingListener); return; } - // TODO More efficiently batch requests int currentBatchSize = Math.min(requests.size(), batchSize); final ChunkingSettings chunkingSettings = requests.getFirst().chunkingSettings; final List currentBatch = new ArrayList<>(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java index 11d521689651e..cd5dfcec3436d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java @@ -188,7 +188,6 @@ public static class Builder extends FieldMapper.Builder { XContentBuilder::field, Objects::toString ).acceptsNull(); - private final Parameter> meta = Parameter.metaParam(); private Function inferenceFieldBuilder; diff --git 
a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/25_semantic_text_field_mapping_chunking.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/25_semantic_text_field_mapping_chunking.yml index a1bb73ae29a51..b217b518ddabb 100644 --- a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/25_semantic_text_field_mapping_chunking.yml +++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/25_semantic_text_field_mapping_chunking.yml @@ -374,3 +374,54 @@ setup: - match: { hits.hits.2._id: "doc_2" } - length: { hits.hits.2.highlight.inference_field: 1 } - match: { hits.hits.2.highlight.inference_field.0: "Elasticsearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } + + +--- +"Invalid chunking settings will result in an error": + + - do: + catch: /chunking settings can not have the following settings/ + indices.create: + index: invalid-chunking-extra-stuff + body: + mappings: + properties: + keyword_field: + type: keyword + inference_field: + type: semantic_text + inference_id: sparse-inference-id + chunking_settings: + strategy: word + max_chunk_size: 10 + overlap: 1 + extra: stuff + + - do: + catch: /\[chunking_settings\] does not contain the required setting \[max_chunk_size\]/ + indices.create: + index: invalid-chunking-missing-required-settings + body: + mappings: + properties: + keyword_field: + type: keyword + inference_field: + type: semantic_text + inference_id: sparse-inference-id + chunking_settings: + strategy: word + + - do: + catch: /Invalid chunkingStrategy/ + indices.create: + index: invalid-chunking-invalid-strategy + body: + mappings: + properties: + inference_field: + type: semantic_text + inference_id: sparse-inference-id + chunking_settings: + strategy: invalid + From 1745dc91de20d0707439f9c98fb3b1a23e38fcfe Mon Sep 17 00:00:00 2001 From: Kathleen DeRusso Date: Fri, 7 Mar 2025 16:50:34 -0500 Subject: [PATCH 42/86] Update docs/changelog/121041.yaml --- docs/changelog/121041.yaml | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 docs/changelog/121041.yaml diff --git a/docs/changelog/121041.yaml b/docs/changelog/121041.yaml new file mode 100644 index 0000000000000..44a51a966c0a1 --- /dev/null +++ b/docs/changelog/121041.yaml @@ -0,0 +1,5 @@ +pr: 121041 +summary: Support configurable chunking in `semantic_text` fields +area: Relevance +type: enhancement +issues: [] From 525eed28d53915bc47f93bfd1621ac6d9fe07dfc Mon Sep 17 00:00:00 2001 From: Kathleen DeRusso Date: Mon, 10 Mar 2025 13:56:14 -0400 Subject: [PATCH 43/86] Fix failures introduced by merge --- .../action/filter/ShardBulkInferenceActionFilter.java | 8 +++++--- .../filter/ShardBulkInferenceActionFilterTests.java | 2 +- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java index dc2b78ce5566a..8e0f25e3c895d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java @@ -46,8 +46,8 @@ import org.elasticsearch.xcontent.XContent; import 
org.elasticsearch.xpack.core.XPackField;
 import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError;
-import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
 import org.elasticsearch.xpack.inference.InferenceException;
+import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
 import org.elasticsearch.xpack.inference.mapper.SemanticTextField;
 import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper;
 import org.elasticsearch.xpack.inference.mapper.SemanticTextUtils;
@@ -318,7 +318,7 @@ public void onFailure(Exception exc) {
                 return;
             }
             int currentBatchSize = Math.min(requests.size(), batchSize);
-            final ChunkingSettings chunkingSettings = requests.getFirst().chunkingSettings;
+            final ChunkingSettings chunkingSettings = requests.isEmpty() == false ? requests.getFirst().chunkingSettings : null;
             final List currentBatch = new ArrayList<>();
             for (FieldInferenceRequest request : requests) {
                 if (Objects.equals(request.chunkingSettings, chunkingSettings) == false || currentBatch.size() >= currentBatchSize) {
@@ -621,7 +621,9 @@ private Map> createFieldInferenceRequests(Bu
                         new FieldInferenceResponse(field, sourceField, v, order++, 0, null, EMPTY_CHUNKED_INFERENCE)
                     );
                 } else {
-                    fieldRequests.add(new FieldInferenceRequest(itemIndex, field, sourceField, v, order++, offsetAdjustment, chunkingSettings));
+                    fieldRequests.add(
+                        new FieldInferenceRequest(itemIndex, field, sourceField, v, order++, offsetAdjustment, chunkingSettings)
+                    );
                 }

                 // When using the inference metadata fields format, all the input values are concatenated so that the
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java
index 8c58c1ce302d0..ea3a1e43df9d6 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java
@@ -408,7 +408,7 @@ public void testHandleEmptyInput() throws Exception {
         Task task = mock(Task.class);
         Map inferenceFieldMap = Map.of(
             "semantic_text_field",
-            new InferenceFieldMetadata("semantic_text_field", model.getInferenceEntityId(), new String[] { "semantic_text_field" })
+            new InferenceFieldMetadata("semantic_text_field", model.getInferenceEntityId(), new String[] { "semantic_text_field" }, null)
         );
         BulkItemRequest[] items = new BulkItemRequest[3];

From 45ab0eb3b851a76147f6eed75468e88ba13d8991 Mon Sep 17 00:00:00 2001
From: elasticsearchmachine
Date: Mon, 10 Mar 2025 18:16:13 +0000
Subject: [PATCH 44/86] [CI] Auto commit changes from spotless

---
 .../inference/mock/TestDenseInferenceServiceExtension.java | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java
index 9c559592433a2..8e0bacf167434 100644
--- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java
+++
b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java @@ -175,7 +175,11 @@ private TextEmbeddingFloatResults makeResults(List input, ServiceSetting return new TextEmbeddingFloatResults(embeddings); } - private List makeChunkedResults(List inputs, ServiceSettings serviceSettings, ChunkingSettings chunkingSettings) { + private List makeChunkedResults( + List inputs, + ServiceSettings serviceSettings, + ChunkingSettings chunkingSettings + ) { TextEmbeddingFloatResults nonChunkedResults = makeResults(inputs, serviceSettings); var results = new ArrayList(); for (int i = 0; i < inputs.size(); i++) { From 42c84494f841d3c20fb07e109cefb43d615fd7dc Mon Sep 17 00:00:00 2001 From: Kathleen DeRusso Date: Thu, 13 Mar 2025 13:06:26 -0400 Subject: [PATCH 45/86] Address PR feedback --- .../cluster/metadata/InferenceFieldMetadata.java | 5 +++-- .../metadata/InferenceFieldMetadataTests.java | 9 ++++++++- .../mock/AbstractTestInferenceService.java | 2 +- .../mock/TestDenseInferenceServiceExtension.java | 3 +-- .../mock/TestSparseInferenceServiceExtension.java | 11 +++++------ .../filter/ShardBulkInferenceActionFilter.java | 4 ++++ .../chunking/SentenceBoundaryChunkingSettings.java | 1 + .../chunking/WordBoundaryChunkingSettings.java | 1 + .../inference/mapper/SemanticTextFieldMapper.java | 8 +------- .../ElasticsearchInternalService.java | 11 ----------- .../ElasticsearchInternalServiceTests.java | 10 +++++++--- .../25_semantic_text_field_mapping_chunking.yml | 14 ++++++-------- 12 files changed, 38 insertions(+), 41 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadata.java b/server/src/main/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadata.java index 0729739144cc7..63c61caf08824 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadata.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadata.java @@ -22,6 +22,8 @@ import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Objects; @@ -40,7 +42,6 @@ public final class InferenceFieldMetadata implements SimpleDiffable { @@ -50,7 +52,12 @@ protected InferenceFieldMetadata doParseInstance(XContentParser parser) throws I @Override protected boolean supportsUnknownFields() { - return false; + return true; + } + + @Override + protected Predicate getRandomFieldsExcludeFilter() { + return field -> field.equals(CHUNKING_SETTINGS_FIELD) || field.startsWith(CHUNKING_SETTINGS_FIELD + "."); } private static InferenceFieldMetadata createTestItem() { diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/AbstractTestInferenceService.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/AbstractTestInferenceService.java index 67205a5f01d02..28a9ca61d23dc 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/AbstractTestInferenceService.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/AbstractTestInferenceService.java @@ -114,7 +114,7 @@ public void close() throws IOException {} protected List chunkInputs(String input, ChunkingSettings chunkingSettings) { if (chunkingSettings == null) { - 
return Collections.singletonList(input); + return List.of(input); } List chunkedInputs = new ArrayList<>(); ChunkingStrategy chunkingStrategy = chunkingSettings.getChunkingStrategy(); diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java index 8e0bacf167434..6bc9d8a203ca6 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java @@ -180,7 +180,6 @@ private List makeChunkedResults( ServiceSettings serviceSettings, ChunkingSettings chunkingSettings ) { - TextEmbeddingFloatResults nonChunkedResults = makeResults(inputs, serviceSettings); var results = new ArrayList(); for (int i = 0; i < inputs.size(); i++) { String input = inputs.get(i); @@ -192,7 +191,7 @@ private List makeChunkedResults( int endOffset = offset + c.length(); chunks.add( new TextEmbeddingFloatResults.Chunk( - nonChunkedResults.embeddings().get(i).values(), + makeResults(List.of(c), serviceSettings).embeddings().getFirst().values(), new ChunkedInference.TextOffset(offset, endOffset) ) ); diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java index 5c8b5441c92b1..c2ff064e98a82 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java @@ -168,16 +168,15 @@ private SparseEmbeddingResults makeResults(List input) { private List makeChunkedResults(List inputs, ChunkingSettings chunkingSettings) { List results = new ArrayList<>(); - for (int i = 0; i < inputs.size(); i++) { - String input = inputs.get(i); - var tokens = new ArrayList(); - for (int j = 0; j < 5; j++) { - tokens.add(new WeightedToken("feature_" + j, generateEmbedding(input, j))); - } + for (String input : inputs) { List chunkedInput = chunkInputs(input, chunkingSettings); List chunks = new ArrayList<>(); int offset = 0; for (String c : chunkedInput) { + var tokens = new ArrayList(); + for (int i = 0; i < 5; i++) { + tokens.add(new WeightedToken("feature_" + i, generateEmbedding(input, i))); + } offset = input.indexOf(c, offset); int endOffset = offset + c.length(); chunks.add(new SparseEmbeddingResults.Chunk(tokens, new ChunkedInference.TextOffset(offset, endOffset))); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java index 8e0f25e3c895d..5e016de8570f0 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java @@ 
-317,6 +317,10 @@ public void onFailure(Exception exc) { modelRegistry.getModelWithSecrets(inferenceId, modelLoadingListener); return; } + + // Create a batch of requests that come in. The number of requests in the batch must be <= the configured batch size, + // and they must all have the same configured chunking settings. As requests must be processed in order, this means + // that some batches with several different chunking settings may result in more, smaller batches. int currentBatchSize = Math.min(requests.size(), batchSize); final ChunkingSettings chunkingSettings = requests.isEmpty() == false ? requests.getFirst().chunkingSettings : null; final List currentBatch = new ArrayList<>(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkingSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkingSettings.java index e468ed171c3b8..b48e8e78ce4ba 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkingSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkingSettings.java @@ -56,6 +56,7 @@ public SentenceBoundaryChunkingSettings(StreamInput in) throws IOException { } } + @Override public Map asMap() { Map map = new HashMap<>(); map.put(ChunkingSettingsOptions.STRATEGY.toString(), STRATEGY.toString().toLowerCase(Locale.ROOT)); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunkingSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunkingSettings.java index 0673b21bfb9af..46fa35b649702 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunkingSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunkingSettings.java @@ -50,6 +50,7 @@ public WordBoundaryChunkingSettings(StreamInput in) throws IOException { overlap = in.readInt(); } + @Override public Map asMap() { Map map = new HashMap<>(); map.put(ChunkingSettingsOptions.STRATEGY.toString(), STRATEGY.toString().toLowerCase(Locale.ROOT)); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java index 0a225e1c9244d..4a8da81a2e81e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java @@ -546,13 +546,7 @@ public InferenceFieldMetadata getMetadata(Set sourcePaths) { ChunkingSettings fieldTypeChunkingSettings = fieldType.getChunkingSettings(); Map asMap = fieldTypeChunkingSettings != null ? fieldTypeChunkingSettings.asMap() : null; - return new InferenceFieldMetadata( - fullPath(), - fieldType().getInferenceId(), - fieldType().getSearchInferenceId(), - copyFields, - fieldType().getChunkingSettings() != null ? 
fieldType().getChunkingSettings().asMap() : null - ); + return new InferenceFieldMetadata(fullPath(), fieldType().getInferenceId(), fieldType().getSearchInferenceId(), copyFields, asMap); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java index 36a8ec1b6ded3..087383c898671 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java @@ -695,17 +695,6 @@ public void inferRerank( client.execute(InferModelAction.INSTANCE, request, maybeDeployListener); } - public void chunkedInfer( - Model model, - List input, - Map taskSettings, - InputType inputType, - TimeValue timeout, - ActionListener> listener - ) { - chunkedInfer(model, null, input, taskSettings, null, inputType, timeout, listener); - } - @Override public void chunkedInfer( Model model, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java index 86203ab19f0d7..44e130c05d3d2 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java @@ -925,7 +925,7 @@ private void testChunkInfer_e5(ChunkingSettings chunkingSettings) throws Interru null, List.of("a", "bb"), Map.of(), - chunkingSettings, + null, InputType.SEARCH, InferenceAction.Request.DEFAULT_TIMEOUT, latchedListener @@ -998,7 +998,7 @@ private void testChunkInfer_Sparse(ChunkingSettings chunkingSettings) throws Int null, List.of("a", "bb"), Map.of(), - chunkingSettings, + null, InputType.SEARCH, InferenceAction.Request.DEFAULT_TIMEOUT, latchedListener @@ -1071,7 +1071,7 @@ private void testChunkInfer_Elser(ChunkingSettings chunkingSettings) throws Inte null, List.of("a", "bb"), Map.of(), - chunkingSettings, + null, InputType.SEARCH, InferenceAction.Request.DEFAULT_TIMEOUT, latchedListener @@ -1115,8 +1115,10 @@ public void testChunkInferSetsTokenization() { expectedWindowSize.set(null); service.chunkedInfer( model, + null, List.of("foo", "bar"), Map.of(), + null, InputType.SEARCH, InferenceAction.Request.DEFAULT_TIMEOUT, ActionListener.wrap(r -> fail("unexpected result"), e -> fail(e.getMessage())) @@ -1126,8 +1128,10 @@ public void testChunkInferSetsTokenization() { expectedWindowSize.set(256); service.chunkedInfer( model, + null, List.of("foo", "bar"), Map.of(), + null, InputType.SEARCH, InferenceAction.Request.DEFAULT_TIMEOUT, ActionListener.wrap(r -> fail("unexpected result"), e -> fail(e.getMessage())) diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/25_semantic_text_field_mapping_chunking.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/25_semantic_text_field_mapping_chunking.yml index b217b518ddabb..2d4b79b7264fc 100644 --- 
a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/25_semantic_text_field_mapping_chunking.yml +++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/25_semantic_text_field_mapping_chunking.yml @@ -134,7 +134,7 @@ setup: indices.get_mapping: index: default-chunking-sparse - - is_false: default-chunking-sparse.mappings.properties.inference_field.chunking_settings + - not_exists: default-chunking-sparse.mappings.properties.inference_field.chunking_settings - do: indices.get_mapping: @@ -148,7 +148,7 @@ setup: indices.get_mapping: index: default-chunking-dense - - is_false: default-chunking-dense.mappings.properties.inference_field.chunking_settings + - not_exists: default-chunking-dense.mappings.properties.inference_field.chunking_settings - do: indices.get_mapping: @@ -280,12 +280,10 @@ setup: query: bool: should: - - semantic: - field: "default_chunked_inference_field" - query: "What is Elasticsearch?" - - semantic: - field: "customized_chunked_inference_field" - query: "What is Elasticsearch?" + - match: + default_chunked_inference_field: "What is Elasticsearch?" + - match: + customized_chunked_inference_field: "What is Elasticsearch?" highlight: fields: default_chunked_inference_field: From 71edec2c99cae11b135d97f7a3b3d7cb6cc2baae Mon Sep 17 00:00:00 2001 From: elasticsearchmachine Date: Thu, 13 Mar 2025 18:37:45 +0000 Subject: [PATCH 46/86] [CI] Auto commit changes from spotless --- .../elasticsearch/cluster/metadata/InferenceFieldMetadata.java | 2 -- .../xpack/inference/mock/AbstractTestInferenceService.java | 1 - 2 files changed, 3 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadata.java b/server/src/main/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadata.java index 63c61caf08824..70db5517180fe 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadata.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadata.java @@ -22,8 +22,6 @@ import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; -import java.util.Collections; -import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Objects; diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/AbstractTestInferenceService.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/AbstractTestInferenceService.java index 28a9ca61d23dc..3bdede4bff389 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/AbstractTestInferenceService.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/AbstractTestInferenceService.java @@ -29,7 +29,6 @@ import java.io.IOException; import java.util.ArrayList; -import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; From 6a374494dee965ddfcb5e296067f96248ac06aa6 Mon Sep 17 00:00:00 2001 From: Kathleen DeRusso Date: Thu, 13 Mar 2025 14:40:00 -0400 Subject: [PATCH 47/86] Fix predicate in updated test --- .../cluster/metadata/InferenceFieldMetadataTests.java | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/server/src/test/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadataTests.java b/server/src/test/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadataTests.java index 
2a7fece8eed73..28bebec1c9d47 100644 --- a/server/src/test/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadataTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadataTests.java @@ -57,7 +57,8 @@ protected boolean supportsUnknownFields() { @Override protected Predicate getRandomFieldsExcludeFilter() { - return field -> field.equals(CHUNKING_SETTINGS_FIELD) || field.startsWith(CHUNKING_SETTINGS_FIELD + "."); + // do not add elements at the top-level as any element at this level is parsed as a new inference field + return field -> field.equals("") || field.equals(CHUNKING_SETTINGS_FIELD) || field.startsWith(CHUNKING_SETTINGS_FIELD + "."); } private static InferenceFieldMetadata createTestItem() { From c076f92f8d3aa3419bd09d8e636c2e9e6f40e4aa Mon Sep 17 00:00:00 2001 From: Kathleen DeRusso Date: Thu, 13 Mar 2025 15:13:59 -0400 Subject: [PATCH 48/86] Better handling of null/empty ChunkingSettings --- .../ShardBulkInferenceActionFilter.java | 8 +--- .../chunking/ChunkingSettingsBuilder.java | 26 +++++++++---- .../inference/mapper/SemanticTextField.java | 4 +- .../ChunkingSettingsBuilderTests.java | 8 +++- ...5_semantic_text_field_mapping_chunking.yml | 39 +++++++++++++++++++ 5 files changed, 68 insertions(+), 17 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java index 5e016de8570f0..c71c35e899285 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java @@ -488,9 +488,7 @@ private void applyInferenceResponses(BulkItemRequest item, FieldInferenceRespons chunkMap ), indexRequest.getContentType(), - inferenceFieldMetadata.getChunkingSettings() != null - ? ChunkingSettingsBuilder.fromMap(new HashMap<>(inferenceFieldMetadata.getChunkingSettings())) - : null + ChunkingSettingsBuilder.fromMap(inferenceFieldMetadata.getChunkingSettings(), false) ); if (useLegacyFormat) { @@ -549,9 +547,7 @@ private Map> createFieldInferenceRequests(Bu for (var entry : fieldInferenceMap.values()) { String field = entry.getName(); String inferenceId = entry.getInferenceId(); - ChunkingSettings chunkingSettings = entry.getChunkingSettings() != null - ? 
ChunkingSettingsBuilder.fromMap(new HashMap<>(entry.getChunkingSettings())) - : null; + ChunkingSettings chunkingSettings = ChunkingSettingsBuilder.fromMap(entry.getChunkingSettings(), false); if (useLegacyFormat) { var originalFieldValue = XContentMapValues.extractValue(field, docMap); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsBuilder.java index 2ede1684e315b..25553a4c760f0 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsBuilder.java @@ -10,6 +10,7 @@ import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.ChunkingStrategy; +import java.util.HashMap; import java.util.Map; public class ChunkingSettingsBuilder { @@ -18,13 +19,24 @@ public class ChunkingSettingsBuilder { public static final WordBoundaryChunkingSettings OLD_DEFAULT_SETTINGS = new WordBoundaryChunkingSettings(250, 100); public static ChunkingSettings fromMap(Map settings) { - if (settings == null) { - return OLD_DEFAULT_SETTINGS; - } + return fromMap(settings, true); + } - if (settings.isEmpty()) { - return DEFAULT_SETTINGS; + public static ChunkingSettings fromMap(Map settings, boolean returnDefaultValues) { + + if (returnDefaultValues) { + if (settings == null) { + return OLD_DEFAULT_SETTINGS; + } + if (settings.isEmpty()) { + return DEFAULT_SETTINGS; + } + } else { + if (settings == null || settings.isEmpty()) { + return null; + } } + if (settings.containsKey(ChunkingSettingsOptions.STRATEGY.toString()) == false) { throw new IllegalArgumentException("Can't generate Chunker without ChunkingStrategy provided"); } @@ -33,8 +45,8 @@ public static ChunkingSettings fromMap(Map settings) { settings.get(ChunkingSettingsOptions.STRATEGY.toString()).toString() ); return switch (chunkingStrategy) { - case WORD -> WordBoundaryChunkingSettings.fromMap(settings); - case SENTENCE -> SentenceBoundaryChunkingSettings.fromMap(settings); + case WORD -> WordBoundaryChunkingSettings.fromMap(new HashMap<>(settings)); + case SENTENCE -> SentenceBoundaryChunkingSettings.fromMap(new HashMap<>(settings)); }; } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java index 95a8e5c72b260..e1b0056dd6a30 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java @@ -137,7 +137,7 @@ static ChunkingSettings parseChunkingSettingsFromMap(Object node) { map, XContentType.JSON ); - return ChunkingSettingsBuilder.fromMap(map); + return ChunkingSettingsBuilder.fromMap(map, false); } catch (Exception exc) { throw new ElasticsearchException(exc); } @@ -250,7 +250,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws args -> { @SuppressWarnings("unchecked") Map map = (Map) args[0]; - return map != null && map.isEmpty() == false ? 
ChunkingSettingsBuilder.fromMap(map) : null; + return ChunkingSettingsBuilder.fromMap(map, false); } ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsBuilderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsBuilderTests.java index 4a284e0a84ff5..9e6dde60bc641 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsBuilderTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsBuilderTests.java @@ -21,14 +21,18 @@ public class ChunkingSettingsBuilderTests extends ESTestCase { public void testNullChunkingSettingsMap() { ChunkingSettings chunkingSettings = ChunkingSettingsBuilder.fromMap(null); - assertEquals(ChunkingSettingsBuilder.OLD_DEFAULT_SETTINGS, chunkingSettings); + + ChunkingSettings chunkingSettingsOrNull = ChunkingSettingsBuilder.fromMap(null, false); + assertNull(chunkingSettingsOrNull); } public void testEmptyChunkingSettingsMap() { ChunkingSettings chunkingSettings = ChunkingSettingsBuilder.fromMap(Collections.emptyMap()); - assertEquals(DEFAULT_SETTINGS, chunkingSettings); + + ChunkingSettings chunkingSettingsOrNull = ChunkingSettingsBuilder.fromMap(Map.of(), false); + assertNull(chunkingSettingsOrNull); } public void testChunkingStrategyNotProvided() { diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/25_semantic_text_field_mapping_chunking.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/25_semantic_text_field_mapping_chunking.yml index 2d4b79b7264fc..0181312451c5a 100644 --- a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/25_semantic_text_field_mapping_chunking.yml +++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/25_semantic_text_field_mapping_chunking.yml @@ -158,6 +158,45 @@ setup: - match: { "custom-chunking-dense.mappings.properties.inference_field.chunking_settings.max_chunk_size": 10 } - match: { "custom-chunking-dense.mappings.properties.inference_field.chunking_settings.overlap": 1 } + +--- +"We do not set custom chunking settings for null or empty specified chunking settings": + + - do: + indices.create: + index: null-chunking + body: + mappings: + properties: + inference_field: + type: semantic_text + inference_id: dense-inference-id + chunking_settings: null + + - do: + indices.get_mapping: + index: null-chunking + + - not_exists: null-chunking.mappings.properties.inference_field.chunking_settings + + + - do: + indices.create: + index: empty-chunking + body: + mappings: + properties: + inference_field: + type: semantic_text + inference_id: sparse-inference-id + chunking_settings: { } + + - do: + indices.get_mapping: + index: empty-chunking + + - not_exists: empty-chunking.mappings.properties.inference_field.chunking_settings + --- "We return different chunks based on configured chunking overrides or model defaults for sparse embeddings": From 2ef235cf57185cc316e076e0f01b8b8289013dbe Mon Sep 17 00:00:00 2001 From: Kathleen DeRusso Date: Thu, 13 Mar 2025 16:06:59 -0400 Subject: [PATCH 49/86] Update parsing settings --- .../elasticsearch/xpack/inference/mapper/SemanticTextField.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java index e1b0056dd6a30..fc1d497347e76 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java @@ -259,7 +259,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws SEMANTIC_TEXT_FIELD_PARSER.declareObject(constructorArg(), INFERENCE_RESULT_PARSER, new ParseField(INFERENCE_FIELD)); SEMANTIC_TEXT_FIELD_PARSER.declareObjectOrNull( optionalConstructorArg(), - (p, c) -> p.map(), + (p, c) -> CHUNKING_SETTINGS_PARSER.parse(p, null), null, new ParseField(CHUNKING_SETTINGS_FIELD) ); From 311d8406e538220da5e39bb3d77e0541b2a3abc5 Mon Sep 17 00:00:00 2001 From: Kathleen DeRusso Date: Thu, 13 Mar 2025 16:34:26 -0400 Subject: [PATCH 50/86] Fix errors post merge --- .../inference/mock/TestSparseInferenceServiceExtension.java | 4 ++-- .../xpack/inference/services/deepseek/DeepSeekService.java | 2 ++ .../inference/services/deepseek/DeepSeekServiceTests.java | 2 +- .../services/elastic/ElasticInferenceServiceTests.java | 1 + 4 files changed, 6 insertions(+), 3 deletions(-) diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java index 63cf8f763cc84..21aeb01a7848c 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java @@ -34,7 +34,6 @@ import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding; -import org.elasticsearch.xpack.core.inference.results.EmbeddingResults; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; import org.elasticsearch.xpack.core.ml.search.WeightedToken; @@ -180,7 +179,8 @@ private List makeChunkedResults(List inputs, ChunkingS } offset = input.indexOf(c, offset); int endOffset = offset + c.length(); - chunks.add(new EmbeddingResults.Chunk(tokens, new ChunkedInference.TextOffset(offset, endOffset))); + var embeddings = new SparseEmbeddingResults.Embedding(tokens, false); + chunks.add(new SparseEmbeddingResults.Chunk(embeddings, new ChunkedInference.TextOffset(offset, endOffset))); } ChunkedInferenceEmbedding chunkedInferenceEmbedding = new ChunkedInferenceEmbedding(chunks); results.add(chunkedInferenceEmbedding); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekService.java index 6338cee473cbd..2404109da0912 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekService.java @@ -14,6 +14,7 @@ import org.elasticsearch.core.Strings; import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.ChunkedInference; 
+import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; @@ -107,6 +108,7 @@ protected void doChunkedInfer( Model model, DocumentsOnlyInput inputs, Map taskSettings, + ChunkingSettings chunkingSettings, InputType inputType, TimeValue timeout, ActionListener> listener diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekServiceTests.java index 277eba9e7dbfc..d80971ebd1088 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekServiceTests.java @@ -371,7 +371,7 @@ public void testUnifiedCompletionMalformedError() throws Exception { public void testDoChunkedInferAlwaysFails() throws IOException { try (var service = createService()) { - service.doChunkedInfer(mock(), mock(), Map.of(), InputType.UNSPECIFIED, TIMEOUT, assertNoSuccessListener(e -> { + service.doChunkedInfer(mock(), mock(), Map.of(), null, InputType.UNSPECIFIED, TIMEOUT, assertNoSuccessListener(e -> { assertThat(e, isA(UnsupportedOperationException.class)); assertThat(e.getMessage(), equalTo("The deepseek service only supports unified completion")); })); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java index c6ef1581a750f..a93f59a98e2b4 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java @@ -631,6 +631,7 @@ public void testChunkedInfer_PropagatesProductUseCaseHeader() throws IOException null, List.of("input text"), new HashMap<>(), + null, InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, listener From e3e15d259be50da8d9d8eb8667310ad01a743e57 Mon Sep 17 00:00:00 2001 From: Kathleen DeRusso Date: Wed, 19 Mar 2025 11:14:01 -0400 Subject: [PATCH 51/86] PR feedback --- .../mock/TestDenseInferenceServiceExtension.java | 5 +++-- .../mock/TestSparseInferenceServiceExtension.java | 3 ++- .../filter/ShardBulkInferenceActionFilter.java | 5 ++--- .../chunking/SentenceBoundaryChunkingSettings.java | 14 ++++++++------ .../chunking/WordBoundaryChunkingSettings.java | 13 ++++++++----- 5 files changed, 23 insertions(+), 17 deletions(-) diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java index 99d433b402ce8..3aad6c334789c 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java @@ -191,11 +191,12 @@ private List makeChunkedResults( 
offset = input.indexOf(c, offset); int endOffset = offset + c.length(); chunks.add( - new EmbeddingResults.Chunk( - makeResults(List.of(c), serviceSettings).embeddings().getFirst(), + new TextEmbeddingFloatResults.Chunk( + makeResults(List.of(c), serviceSettings).embeddings().get(0), new ChunkedInference.TextOffset(offset, endOffset) ) ); + offset = endOffset; } ChunkedInferenceEmbedding chunkedInferenceEmbedding = new ChunkedInferenceEmbedding(chunks); results.add(chunkedInferenceEmbedding); diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java index 21aeb01a7848c..4cb82caac6cee 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java @@ -175,12 +175,13 @@ private List makeChunkedResults(List inputs, ChunkingS for (String c : chunkedInput) { var tokens = new ArrayList(); for (int i = 0; i < 5; i++) { - tokens.add(new WeightedToken("feature_" + i, generateEmbedding(input, i))); + tokens.add(new WeightedToken("feature_" + i, generateEmbedding(c, i))); } offset = input.indexOf(c, offset); int endOffset = offset + c.length(); var embeddings = new SparseEmbeddingResults.Embedding(tokens, false); chunks.add(new SparseEmbeddingResults.Chunk(embeddings, new ChunkedInference.TextOffset(offset, endOffset))); + offset = endOffset; } ChunkedInferenceEmbedding chunkedInferenceEmbedding = new ChunkedInferenceEmbedding(chunks); results.add(chunkedInferenceEmbedding); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java index c71c35e899285..b77348ffba606 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java @@ -318,9 +318,8 @@ public void onFailure(Exception exc) { return; } - // Create a batch of requests that come in. The number of requests in the batch must be <= the configured batch size, - // and they must all have the same configured chunking settings. As requests must be processed in order, this means - // that some batches with several different chunking settings may result in more, smaller batches. + // Batch requests in the order they are specified, grouping by field and chunking settings. + // As each field may have different chunking settings specified, the size of the batch will be <= the configured batchSize. int currentBatchSize = Math.min(requests.size(), batchSize); final ChunkingSettings chunkingSettings = requests.isEmpty() == false ? 
requests.getFirst().chunkingSettings : null; final List currentBatch = new ArrayList<>(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkingSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkingSettings.java index b48e8e78ce4ba..6eb16d00748f8 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkingSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkingSettings.java @@ -22,7 +22,6 @@ import java.io.IOException; import java.util.Arrays; -import java.util.HashMap; import java.util.Locale; import java.util.Map; import java.util.Objects; @@ -58,11 +57,14 @@ public SentenceBoundaryChunkingSettings(StreamInput in) throws IOException { @Override public Map asMap() { - Map map = new HashMap<>(); - map.put(ChunkingSettingsOptions.STRATEGY.toString(), STRATEGY.toString().toLowerCase(Locale.ROOT)); - map.put(ChunkingSettingsOptions.MAX_CHUNK_SIZE.toString(), maxChunkSize); - map.put(ChunkingSettingsOptions.SENTENCE_OVERLAP.toString(), sentenceOverlap); - return map; + return Map.of( + ChunkingSettingsOptions.STRATEGY.toString(), + STRATEGY.toString().toLowerCase(Locale.ROOT), + ChunkingSettingsOptions.MAX_CHUNK_SIZE.toString(), + maxChunkSize, + ChunkingSettingsOptions.SENTENCE_OVERLAP.toString(), + sentenceOverlap + ); } public static SentenceBoundaryChunkingSettings fromMap(Map map) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunkingSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunkingSettings.java index 46fa35b649702..1f2391f9c4dc0 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunkingSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunkingSettings.java @@ -52,11 +52,14 @@ public WordBoundaryChunkingSettings(StreamInput in) throws IOException { @Override public Map asMap() { - Map map = new HashMap<>(); - map.put(ChunkingSettingsOptions.STRATEGY.toString(), STRATEGY.toString().toLowerCase(Locale.ROOT)); - map.put(ChunkingSettingsOptions.MAX_CHUNK_SIZE.toString(), maxChunkSize); - map.put(ChunkingSettingsOptions.OVERLAP.toString(), overlap); - return map; + return Map.of( + ChunkingSettingsOptions.STRATEGY.toString(), + STRATEGY.toString().toLowerCase(Locale.ROOT), + ChunkingSettingsOptions.MAX_CHUNK_SIZE.toString(), + maxChunkSize, + ChunkingSettingsOptions.OVERLAP.toString(), + overlap + ); } public int maxChunkSize() { From 4224159ac6b7c5a292f4858ed1a9c136b9a5748b Mon Sep 17 00:00:00 2001 From: elasticsearchmachine Date: Wed, 19 Mar 2025 15:28:34 +0000 Subject: [PATCH 52/86] [CI] Auto commit changes from spotless --- .../xpack/inference/mock/TestDenseInferenceServiceExtension.java | 1 - .../xpack/inference/chunking/WordBoundaryChunkingSettings.java | 1 - 2 files changed, 2 deletions(-) diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java index 3aad6c334789c..0561b0b8c5bf6 100644 --- 
a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java @@ -36,7 +36,6 @@ import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding; -import org.elasticsearch.xpack.core.inference.results.EmbeddingResults; import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; import java.io.IOException; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunkingSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunkingSettings.java index 1f2391f9c4dc0..97f8aa49ef4d1 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunkingSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunkingSettings.java @@ -21,7 +21,6 @@ import java.io.IOException; import java.util.Arrays; -import java.util.HashMap; import java.util.Locale; import java.util.Map; import java.util.Objects; From ad090d7faa73c43a51691a09153ef9e7306bcff0 Mon Sep 17 00:00:00 2001 From: Kathleen DeRusso Date: Wed, 19 Mar 2025 14:13:31 -0400 Subject: [PATCH 53/86] PR feedback and fix Xcontent parsing for SemanticTextField --- .../metadata/InferenceFieldMetadataTests.java | 2 +- .../ShardBulkInferenceActionFilter.java | 4 +- .../inference/mapper/SemanticTextField.java | 50 ++++++----------- .../mapper/SemanticTextFieldMapper.java | 19 +++++-- .../mapper/SemanticTextFieldMapperTests.java | 9 ++- .../mapper/SemanticTextFieldTests.java | 6 +- .../queries/SemanticQueryBuilderTests.java | 10 +++- ...5_semantic_text_field_mapping_chunking.yml | 55 +++++++++++++++++++ 8 files changed, 105 insertions(+), 50 deletions(-) diff --git a/server/src/test/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadataTests.java b/server/src/test/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadataTests.java index 28bebec1c9d47..fd6f706c9e392 100644 --- a/server/src/test/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadataTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadataTests.java @@ -58,7 +58,7 @@ protected boolean supportsUnknownFields() { @Override protected Predicate getRandomFieldsExcludeFilter() { // do not add elements at the top-level as any element at this level is parsed as a new inference field - return field -> field.equals("") || field.equals(CHUNKING_SETTINGS_FIELD) || field.startsWith(CHUNKING_SETTINGS_FIELD + "."); + return field -> field.equals("") || field.contains(CHUNKING_SETTINGS_FIELD); } private static InferenceFieldMetadata createTestItem() { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java index b77348ffba606..ac3acfe7114cf 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java +++ 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java @@ -484,10 +484,10 @@ private void applyInferenceResponses(BulkItemRequest item, FieldInferenceRespons new SemanticTextField.InferenceResult( inferenceFieldMetadata.getInferenceId(), model != null ? new MinimalServiceSettings(model) : null, + ChunkingSettingsBuilder.fromMap(inferenceFieldMetadata.getChunkingSettings(), false), chunkMap ), - indexRequest.getContentType(), - ChunkingSettingsBuilder.fromMap(inferenceFieldMetadata.getChunkingSettings(), false) + indexRequest.getContentType() ); if (useLegacyFormat) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java index fc1d497347e76..0408857af127c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java @@ -51,15 +51,13 @@ * {@link IndexVersions#INFERENCE_METADATA_FIELDS}, null otherwise. * @param inference The inference result. * @param contentType The {@link XContentType} used to store the embeddings chunks. - * @param chunkingSettings The {@link ChunkingSettings} used to override model chunking defaults */ public record SemanticTextField( boolean useLegacyFormat, String fieldName, @Nullable List originalValues, InferenceResult inference, - XContentType contentType, - @Nullable ChunkingSettings chunkingSettings + XContentType contentType ) implements ToXContentObject { static final String TEXT_FIELD = "text"; @@ -75,7 +73,12 @@ public record SemanticTextField( static final String MODEL_SETTINGS_FIELD = "model_settings"; static final String CHUNKING_SETTINGS_FIELD = "chunking_settings"; - public record InferenceResult(String inferenceId, MinimalServiceSettings modelSettings, Map> chunks) {} + public record InferenceResult( + String inferenceId, + MinimalServiceSettings modelSettings, + ChunkingSettings chunkingSettings, + Map> chunks + ) {} public record Chunk(@Nullable String text, int startOffset, int endOffset, BytesReference rawEmbeddings) {} @@ -131,12 +134,6 @@ static ChunkingSettings parseChunkingSettingsFromMap(Object node) { } try { Map map = XContentMapValues.nodeMapValue(node, CHUNKING_SETTINGS_FIELD); - XContentParser parser = new MapXContentParser( - NamedXContentRegistry.EMPTY, - DeprecationHandler.IGNORE_DEPRECATIONS, - map, - XContentType.JSON - ); return ChunkingSettingsBuilder.fromMap(map, false); } catch (Exception exc) { throw new ElasticsearchException(exc); @@ -158,9 +155,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.startObject(INFERENCE_FIELD); builder.field(INFERENCE_ID_FIELD, inference.inferenceId); builder.field(MODEL_SETTINGS_FIELD, inference.modelSettings); - if (chunkingSettings != null) { + if (inference.chunkingSettings != null) { builder.startObject(CHUNKING_SETTINGS_FIELD); - builder.mapContents(chunkingSettings.asMap()); + builder.mapContents(inference.chunkingSettings.asMap()); builder.endObject(); } @@ -208,7 +205,6 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws new ConstructingObjectParser<>(SemanticTextFieldMapper.CONTENT_TYPE, true, (args, context) -> { List originalValues = (List) args[0]; InferenceResult inferenceResult = (InferenceResult) args[1]; - 
ChunkingSettings chunkingSettings = (ChunkingSettings) args[2]; if (context.useLegacyFormat() == false) { if (originalValues != null && originalValues.isEmpty() == false) { throw new IllegalArgumentException("Unknown field [" + TEXT_FIELD + "]"); @@ -220,8 +216,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws context.fieldName(), originalValues, inferenceResult, - context.xContentType(), - chunkingSettings + context.xContentType() ); }); @@ -229,7 +224,13 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws private static final ConstructingObjectParser INFERENCE_RESULT_PARSER = new ConstructingObjectParser<>( INFERENCE_FIELD, true, - args -> new InferenceResult((String) args[0], (MinimalServiceSettings) args[1], (Map>) args[2]) + args -> { + String inferenceId = (String) args[0]; + MinimalServiceSettings modelSettings = (MinimalServiceSettings) args[1]; + Map chunkingSettings = (Map) args[2]; + Map> chunks = (Map>) args[3]; + return new InferenceResult(inferenceId, modelSettings, ChunkingSettingsBuilder.fromMap(chunkingSettings, false), chunks); + } ); private static final ConstructingObjectParser CHUNKS_PARSER = new ConstructingObjectParser<>( @@ -244,25 +245,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws } ); - private static final ConstructingObjectParser CHUNKING_SETTINGS_PARSER = new ConstructingObjectParser<>( - CHUNKING_SETTINGS_FIELD, - true, - args -> { - @SuppressWarnings("unchecked") - Map map = (Map) args[0]; - return ChunkingSettingsBuilder.fromMap(map, false); - } - ); - static { SEMANTIC_TEXT_FIELD_PARSER.declareStringArray(optionalConstructorArg(), new ParseField(TEXT_FIELD)); SEMANTIC_TEXT_FIELD_PARSER.declareObject(constructorArg(), INFERENCE_RESULT_PARSER, new ParseField(INFERENCE_FIELD)); - SEMANTIC_TEXT_FIELD_PARSER.declareObjectOrNull( - optionalConstructorArg(), - (p, c) -> CHUNKING_SETTINGS_PARSER.parse(p, null), - null, - new ParseField(CHUNKING_SETTINGS_FIELD) - ); INFERENCE_RESULT_PARSER.declareString(constructorArg(), new ParseField(INFERENCE_ID_FIELD)); INFERENCE_RESULT_PARSER.declareObjectOrNull( @@ -271,6 +256,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws null, new ParseField(MODEL_SETTINGS_FIELD) ); + INFERENCE_RESULT_PARSER.declareObject(optionalConstructorArg(), (p, c) -> p.map(), new ParseField(CHUNKING_SETTINGS_FIELD)); INFERENCE_RESULT_PARSER.declareField(constructorArg(), (p, c) -> { if (c.useLegacyFormat()) { return Map.of(c.fieldName, parseChunksArrayLegacy(p, c)); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java index 4a8da81a2e81e..b1ba36babc263 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java @@ -189,7 +189,8 @@ public static class Builder extends FieldMapper.Builder { mapper -> ((SemanticTextFieldType) mapper.fieldType()).chunkingSettings, XContentBuilder::field, Objects::toString - ).acceptsNull(); + ).acceptsNull().setMergeValidator(SemanticTextFieldMapper::canMergeChunkingSettings); + private final Parameter> meta = Parameter.metaParam(); private Function inferenceFieldBuilder; @@ -905,9 +906,8 @@ public List fetchValues(Source source, 
int doc, List ignoredValu useLegacyFormat, name(), null, - new SemanticTextField.InferenceResult(inferenceId, modelSettings, chunkMap), - source.sourceContentType(), - chunkingSettings + new SemanticTextField.InferenceResult(inferenceId, modelSettings, chunkingSettings, chunkMap), + source.sourceContentType() ) ); } @@ -1025,4 +1025,15 @@ private static boolean canMergeModelSettings(MinimalServiceSettings previous, Mi conflicts.addConflict("model_settings", ""); return false; } + + private static boolean canMergeChunkingSettings(ChunkingSettings previous, ChunkingSettings current, Conflicts conflicts) { + if (Objects.equals(previous, current)) { + return true; + } + if (previous == null || current == null) { + return true; + } + conflicts.addConflict("chunking_settings", ""); + return false; + } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java index ccf71fd05904d..385d7315bd39a 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java @@ -868,10 +868,10 @@ public void testModelSettingsRequiredWithChunks() throws IOException { new SemanticTextField.InferenceResult( randomSemanticText.inference().inferenceId(), null, + chunkingSettings, randomSemanticText.inference().chunks() ), - randomSemanticText.contentType(), - chunkingSettings + randomSemanticText.contentType() ); MapperService mapperService = createMapperService( @@ -914,9 +914,8 @@ private MapperService mapperServiceForFieldWithModelSettings( useLegacyFormat, fieldName, List.of(), - new SemanticTextField.InferenceResult(inferenceId, modelSettings, Map.of()), - XContentType.JSON, - generateRandomChunkingSettings() + new SemanticTextField.InferenceResult(inferenceId, modelSettings, generateRandomChunkingSettings(), Map.of()), + XContentType.JSON ); XContentBuilder builder = JsonXContent.contentBuilder().startObject(); if (useLegacyFormat) { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java index 744009f344efb..0cdf92e0e3a68 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java @@ -73,7 +73,7 @@ protected void assertEqualInstances(SemanticTextField expectedInstance, Semantic assertThat(newInstance.originalValues(), equalTo(expectedInstance.originalValues())); assertThat(newInstance.inference().modelSettings(), equalTo(expectedInstance.inference().modelSettings())); assertThat(newInstance.inference().chunks().size(), equalTo(expectedInstance.inference().chunks().size())); - assertThat(newInstance.chunkingSettings(), equalTo(expectedInstance.chunkingSettings())); + assertThat(newInstance.inference().chunkingSettings(), equalTo(expectedInstance.inference().chunkingSettings())); MinimalServiceSettings modelSettings = newInstance.inference().modelSettings(); for (var entry : newInstance.inference().chunks().entrySet()) { var expectedChunks = expectedInstance.inference().chunks().get(entry.getKey()); @@ -307,10 
+307,10 @@ public static SemanticTextField semanticTextFieldFromChunkedInferenceResults( new SemanticTextField.InferenceResult( model.getInferenceEntityId(), new MinimalServiceSettings(model), + chunkingSettings, Map.of(fieldName, chunks) ), - contentType, - chunkingSettings + contentType ); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java index 925c1ecb7f9bb..04d4269af163c 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java @@ -369,9 +369,13 @@ private static SourceToParse buildSemanticTextFieldWithInferenceResults( useLegacyFormat, SEMANTIC_TEXT_FIELD, null, - new SemanticTextField.InferenceResult(INFERENCE_ID, modelSettings, Map.of(SEMANTIC_TEXT_FIELD, List.of())), - XContentType.JSON, - SemanticTextFieldTests.generateRandomChunkingSettings() + new SemanticTextField.InferenceResult( + INFERENCE_ID, + modelSettings, + SemanticTextFieldTests.generateRandomChunkingSettings(), + Map.of(SEMANTIC_TEXT_FIELD, List.of()) + ), + XContentType.JSON ); XContentBuilder builder = JsonXContent.contentBuilder().startObject(); diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/25_semantic_text_field_mapping_chunking.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/25_semantic_text_field_mapping_chunking.yml index 0181312451c5a..d61587c3a8a79 100644 --- a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/25_semantic_text_field_mapping_chunking.yml +++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/25_semantic_text_field_mapping_chunking.yml @@ -462,3 +462,58 @@ setup: chunking_settings: strategy: invalid +--- +"We can update chunking settings": + + - do: + indices.create: + index: chunking-update + body: + mappings: + properties: + inference_field: + type: semantic_text + inference_id: sparse-inference-id + + + - do: + indices.get_mapping: + index: chunking-update + + - not_exists: chunking-update.mappings.properties.inference_field.chunking_settings + + - do: + indices.put_mapping: + index: chunking-update + body: + properties: + inference_field: + type: semantic_text + inference_id: sparse-inference-id + chunking_settings: + strategy: word + max_chunk_size: 10 + overlap: 1 + + - do: + indices.get_mapping: + index: chunking-update + + - match: { "chunking-update.mappings.properties.inference_field.chunking_settings.strategy": "word" } + - match: { "chunking-update.mappings.properties.inference_field.chunking_settings.max_chunk_size": 10 } + - match: { "chunking-update.mappings.properties.inference_field.chunking_settings.overlap": 1 } + + - do: + catch: /chunking_settings/ + indices.put_mapping: + index: chunking-update + body: + properties: + inference_field: + type: semantic_text + inference_id: sparse-inference-id + chunking_settings: + strategy: word + max_chunk_size: 20 + overlap: 5 + From a8ef9a11116ac50afba09c190c453d2709aa196e Mon Sep 17 00:00:00 2001 From: Kathleen DeRusso Date: Wed, 19 Mar 2025 16:47:23 -0400 Subject: [PATCH 54/86] Remove chunking settings check to use what's passed in from sender service --- 
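Note on this patch: the per-service fallback to the model's configured chunking settings is removed from each doChunkedInfer implementation, which now uses whatever chunkingSettings value the sender service passes down (possibly null). A minimal, illustrative sketch of the resulting call pattern inside a service's doChunkedInfer; inputs, EMBEDDING_MAX_BATCH_SIZE, and listener are the surrounding method's existing parameters and constants, not new API:

    // chunkingSettings arrives already resolved upstream; no
    // "chunkingSettings != null ? chunkingSettings : model.getConfigurations().getChunkingSettings()"
    // fallback is applied at this layer any more.
    var batchedRequests = new EmbeddingRequestChunker<>(
        inputs.getInputs(),
        EMBEDDING_MAX_BATCH_SIZE,
        chunkingSettings
    ).batchRequestsWithListeners(listener);
    for (var request : batchedRequests) {
        // each batched request is executed by the service's existing request/action sender, unchanged by this patch
    }
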
.../alibabacloudsearch/AlibabaCloudSearchService.java | 2 +- .../services/amazonbedrock/AmazonBedrockService.java | 2 +- .../services/azureaistudio/AzureAiStudioService.java | 2 +- .../inference/services/azureopenai/AzureOpenAiService.java | 2 +- .../xpack/inference/services/cohere/CohereService.java | 2 +- .../services/googleaistudio/GoogleAiStudioService.java | 2 +- .../services/googlevertexai/GoogleVertexAiService.java | 2 +- .../inference/services/huggingface/HuggingFaceService.java | 2 +- .../inference/services/ibmwatsonx/IbmWatsonxService.java | 7 ++----- .../xpack/inference/services/jinaai/JinaAIService.java | 2 +- .../xpack/inference/services/mistral/MistralService.java | 2 +- .../xpack/inference/services/openai/OpenAiService.java | 2 +- .../xpack/inference/services/voyageai/VoyageAIService.java | 2 +- 13 files changed, 14 insertions(+), 17 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java index aa77bc49fbef8..cf73d0c67488e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java @@ -309,7 +309,7 @@ protected void doChunkedInfer( List batchedRequests = new EmbeddingRequestChunker<>( inputs.getInputs(), EMBEDDING_MAX_BATCH_SIZE, - chunkingSettings != null ? chunkingSettings : alibabaCloudSearchModel.getConfigurations().getChunkingSettings() + chunkingSettings ).batchRequestsWithListeners(listener); for (var request : batchedRequests) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java index 9898f08dd3c01..40a930608e0c2 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java @@ -133,7 +133,7 @@ protected void doChunkedInfer( List batchedRequests = new EmbeddingRequestChunker<>( inputs.getInputs(), maxBatchSize, - chunkingSettings != null ? chunkingSettings : baseAmazonBedrockModel.getConfigurations().getChunkingSettings() + chunkingSettings ).batchRequestsWithListeners(listener); for (var request : batchedRequests) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java index 71a44edadcd09..ba57eef3eb3d3 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java @@ -125,7 +125,7 @@ protected void doChunkedInfer( List batchedRequests = new EmbeddingRequestChunker<>( inputs.getInputs(), EMBEDDING_MAX_BATCH_SIZE, - chunkingSettings != null ? 
chunkingSettings : baseAzureAiStudioModel.getConfigurations().getChunkingSettings() + chunkingSettings ).batchRequestsWithListeners(listener); for (var request : batchedRequests) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java index 38d96dfed9448..6f993d58e0fef 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java @@ -285,7 +285,7 @@ protected void doChunkedInfer( List batchedRequests = new EmbeddingRequestChunker<>( inputs.getInputs(), EMBEDDING_MAX_BATCH_SIZE, - chunkingSettings != null ? chunkingSettings : azureOpenAiModel.getConfigurations().getChunkingSettings() + chunkingSettings ).batchRequestsWithListeners(listener); for (var request : batchedRequests) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java index b34199faa1698..c0f3ab0e73e6d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java @@ -288,7 +288,7 @@ protected void doChunkedInfer( List batchedRequests = new EmbeddingRequestChunker<>( inputs.getInputs(), EMBEDDING_MAX_BATCH_SIZE, - chunkingSettings != null ? chunkingSettings : cohereModel.getConfigurations().getChunkingSettings() + chunkingSettings ).batchRequestsWithListeners(listener); for (var request : batchedRequests) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java index cfdf8cd7fa3a9..b191081b257ca 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java @@ -329,7 +329,7 @@ protected void doChunkedInfer( List batchedRequests = new EmbeddingRequestChunker<>( inputs.getInputs(), EMBEDDING_MAX_BATCH_SIZE, - chunkingSettings != null ? 
chunkingSettings : googleAiStudioModel.getConfigurations().getChunkingSettings() + chunkingSettings ).batchRequestsWithListeners(listener); for (var request : batchedRequests) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java index cce03441e253e..3c9aeacb9fc21 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java @@ -232,7 +232,7 @@ protected void doChunkedInfer( List batchedRequests = new EmbeddingRequestChunker<>( inputs.getInputs(), EMBEDDING_MAX_BATCH_SIZE, - chunkingSettings != null ? chunkingSettings : googleVertexAiModel.getConfigurations().getChunkingSettings() + chunkingSettings ).batchRequestsWithListeners(listener); for (var request : batchedRequests) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java index d61de3d038dd9..df3d93b15667a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java @@ -131,7 +131,7 @@ protected void doChunkedInfer( List batchedRequests = new EmbeddingRequestChunker<>( inputs.getInputs(), EMBEDDING_MAX_BATCH_SIZE, - chunkingSettings != null ? chunkingSettings : huggingFaceModel.getConfigurations().getChunkingSettings() + chunkingSettings ).batchRequestsWithListeners(listener); for (var request : batchedRequests) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java index 83fba3534413a..0e2458473caf1 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java @@ -308,11 +308,8 @@ protected void doChunkedInfer( ) { IbmWatsonxModel ibmWatsonxModel = (IbmWatsonxModel) model; - var batchedRequests = new EmbeddingRequestChunker<>( - input.getInputs(), - EMBEDDING_MAX_BATCH_SIZE, - chunkingSettings != null ? 
chunkingSettings : model.getConfigurations().getChunkingSettings() - ).batchRequestsWithListeners(listener); + var batchedRequests = new EmbeddingRequestChunker<>(input.getInputs(), EMBEDDING_MAX_BATCH_SIZE, chunkingSettings) + .batchRequestsWithListeners(listener); for (var request : batchedRequests) { var action = ibmWatsonxModel.accept(getActionCreator(getSender(), getServiceComponents()), taskSettings, inputType); action.execute(new DocumentsOnlyInput(request.batch().inputs()), timeout, request.listener()); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java index 490698fd604b7..39cf55871ce30 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java @@ -270,7 +270,7 @@ protected void doChunkedInfer( List batchedRequests = new EmbeddingRequestChunker<>( inputs.getInputs(), EMBEDDING_MAX_BATCH_SIZE, - chunkingSettings != null ? chunkingSettings : jinaaiModel.getConfigurations().getChunkingSettings() + chunkingSettings ).batchRequestsWithListeners(listener); for (var request : batchedRequests) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java index 79d5dbcaa17ce..c08488f3ba709 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java @@ -114,7 +114,7 @@ protected void doChunkedInfer( List batchedRequests = new EmbeddingRequestChunker<>( inputs.getInputs(), MistralConstants.MAX_BATCH_SIZE, - chunkingSettings != null ? chunkingSettings : mistralEmbeddingsModel.getConfigurations().getChunkingSettings() + chunkingSettings ).batchRequestsWithListeners(listener); for (var request : batchedRequests) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java index cda47f8a06669..c9fe286228b28 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java @@ -339,7 +339,7 @@ protected void doChunkedInfer( List batchedRequests = new EmbeddingRequestChunker<>( inputs.getInputs(), EMBEDDING_MAX_BATCH_SIZE, - chunkingSettings != null ? 
chunkingSettings : openAiModel.getConfigurations().getChunkingSettings() + chunkingSettings ).batchRequestsWithListeners(listener); for (var request : batchedRequests) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java index a1a9f83d9fe3c..c99a953c1e39e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java @@ -292,7 +292,7 @@ protected void doChunkedInfer( List batchedRequests = new EmbeddingRequestChunker<>( inputs.getInputs(), getBatchSize(voyageaiModel), - chunkingSettings != null ? chunkingSettings : voyageaiModel.getConfigurations().getChunkingSettings() + chunkingSettings ).batchRequestsWithListeners(listener); for (var request : batchedRequests) { From db06fd954e3afbb746c03859109501e06c0b9ecc Mon Sep 17 00:00:00 2001 From: Kathleen DeRusso Date: Thu, 20 Mar 2025 08:54:35 -0400 Subject: [PATCH 55/86] Fix some tests --- .../xpack/inference/mapper/SemanticTextFieldMapperTests.java | 2 +- .../xpack/inference/mapper/SemanticTextFieldTests.java | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java index 385d7315bd39a..28e796715a1d4 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java @@ -644,7 +644,7 @@ public void testSuccessfulParse() throws IOException { Model model1 = TestModel.createRandomInstance(TaskType.SPARSE_EMBEDDING); Model model2 = TestModel.createRandomInstance(TaskType.SPARSE_EMBEDDING); - ChunkingSettings chunkingSettings = generateRandomChunkingSettings(); + ChunkingSettings chunkingSettings = null; // Some chunking settings configs can produce different Lucene docs counts XContentBuilder mapping = mapping(b -> { addSemanticTextMapping(b, fieldName1, model1.getInferenceEntityId(), setSearchInferenceId ? searchInferenceId : null); addSemanticTextMapping(b, fieldName2, model2.getInferenceEntityId(), setSearchInferenceId ? searchInferenceId : null); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java index 0cdf92e0e3a68..ca25319003e37 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java @@ -319,7 +319,7 @@ public static ChunkingSettings generateRandomChunkingSettings() { return null; // Use model defaults } return randomBoolean() - ? new WordBoundaryChunkingSettings(randomIntBetween(20, 100), randomIntBetween(0, 10)) + ? 
new WordBoundaryChunkingSettings(randomIntBetween(20, 100), randomIntBetween(1, 10)) : new SentenceBoundaryChunkingSettings(randomIntBetween(20, 100), randomIntBetween(0, 1)); } From b5f292922a80f901fe5024fae1b3cb39add2243b Mon Sep 17 00:00:00 2001 From: Kathleen DeRusso Date: Thu, 20 Mar 2025 13:03:28 -0400 Subject: [PATCH 56/86] Cleanup --- .../cluster/metadata/InferenceFieldMetadata.java | 4 +++- .../cluster/metadata/InferenceFieldMetadataTests.java | 5 +++-- .../xpack/inference/mapper/SemanticTextFieldMapper.java | 4 +--- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadata.java b/server/src/main/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadata.java index 70db5517180fe..495403e963e45 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadata.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadata.java @@ -151,7 +151,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws } builder.array(SOURCE_FIELDS_FIELD, sourceFields); if (chunkingSettings != null) { - builder.field(CHUNKING_SETTINGS_FIELD, chunkingSettings); + builder.startObject(CHUNKING_SETTINGS_FIELD); + builder.mapContents(chunkingSettings); + builder.endObject(); } return builder.endObject(); } diff --git a/server/src/test/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadataTests.java b/server/src/test/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadataTests.java index fd6f706c9e392..f0c61b68226e1 100644 --- a/server/src/test/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadataTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadataTests.java @@ -57,7 +57,8 @@ protected boolean supportsUnknownFields() { @Override protected Predicate getRandomFieldsExcludeFilter() { - // do not add elements at the top-level as any element at this level is parsed as a new inference field + // do not add elements at the top-level as any element at this level is parsed as a new inference field, + // and do not add additional elements to chunking maps as they will fail parsing with extra data return field -> field.equals("") || field.contains(CHUNKING_SETTINGS_FIELD); } @@ -78,7 +79,7 @@ public static Map generateRandomChunkingSettings() { } private static Map generateRandomWordBoundaryChunkingSettings() { - return Map.of("strategy", "word_boundary", "max_chunk_size", randomIntBetween(20, 100), "overlap", randomIntBetween(0, 50)); + return Map.of("strategy", "word_boundary", "max_chunk_size", randomIntBetween(20, 100), "overlap", randomIntBetween(1, 50)); } private static Map generateRandomSentenceBoundaryChunkingSettings() { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java index b1ba36babc263..c17cbb4689aba 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java @@ -542,9 +542,7 @@ public InferenceFieldMetadata getMetadata(Set sourcePaths) { String[] copyFields = sourcePaths.toArray(String[]::new); // ensure consistent order Arrays.sort(copyFields); - - SemanticTextFieldType fieldType = fieldType(); - ChunkingSettings 
fieldTypeChunkingSettings = fieldType.getChunkingSettings(); + ChunkingSettings fieldTypeChunkingSettings = fieldType().getChunkingSettings(); Map asMap = fieldTypeChunkingSettings != null ? fieldTypeChunkingSettings.asMap() : null; return new InferenceFieldMetadata(fullPath(), fieldType().getInferenceId(), fieldType().getSearchInferenceId(), copyFields, asMap); From fccbd61781c941c05b4e909bba508883bb30c79f Mon Sep 17 00:00:00 2001 From: Kathleen DeRusso Date: Thu, 20 Mar 2025 15:44:44 -0400 Subject: [PATCH 57/86] Test failure whack-a-mole --- .../TestDenseInferenceServiceExtension.java | 5 +-- .../TestSparseInferenceServiceExtension.java | 7 ++-- ...5_semantic_text_field_mapping_chunking.yml | 41 +++++++++---------- 3 files changed, 25 insertions(+), 28 deletions(-) diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java index 0561b0b8c5bf6..fcf2fc2a41f2b 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java @@ -185,9 +185,9 @@ private List makeChunkedResults( String input = inputs.get(i); List chunkedInput = chunkInputs(input, chunkingSettings); List chunks = new ArrayList<>(); - int offset = 0; for (String c : chunkedInput) { - offset = input.indexOf(c, offset); + // Note: We have to start with an offset of 0 to account for overlaps + int offset = input.indexOf(c); int endOffset = offset + c.length(); chunks.add( new TextEmbeddingFloatResults.Chunk( @@ -195,7 +195,6 @@ private List makeChunkedResults( new ChunkedInference.TextOffset(offset, endOffset) ) ); - offset = endOffset; } ChunkedInferenceEmbedding chunkedInferenceEmbedding = new ChunkedInferenceEmbedding(chunks); results.add(chunkedInferenceEmbedding); diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java index 4cb82caac6cee..60ac28e11c3b8 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java @@ -171,17 +171,16 @@ private List makeChunkedResults(List inputs, ChunkingS for (String input : inputs) { List chunkedInput = chunkInputs(input, chunkingSettings); List chunks = new ArrayList<>(); - int offset = 0; for (String c : chunkedInput) { var tokens = new ArrayList(); for (int i = 0; i < 5; i++) { - tokens.add(new WeightedToken("feature_" + i, generateEmbedding(c, i))); + tokens.add(new WeightedToken("feature_" + i, generateEmbedding(input, i))); } - offset = input.indexOf(c, offset); + // Note: We have to start with an offset of 0 to account for overlaps + int offset = input.indexOf(c); int endOffset = offset + c.length(); var embeddings = new SparseEmbeddingResults.Embedding(tokens, false); chunks.add(new SparseEmbeddingResults.Chunk(embeddings, new 
ChunkedInference.TextOffset(offset, endOffset))); - offset = endOffset; } ChunkedInferenceEmbedding chunkedInferenceEmbedding = new ChunkedInferenceEmbedding(chunks); results.add(chunkedInferenceEmbedding); diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/25_semantic_text_field_mapping_chunking.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/25_semantic_text_field_mapping_chunking.yml index d61587c3a8a79..869331d97448a 100644 --- a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/25_semantic_text_field_mapping_chunking.yml +++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/25_semantic_text_field_mapping_chunking.yml @@ -475,45 +475,44 @@ setup: type: semantic_text inference_id: sparse-inference-id - - do: indices.get_mapping: index: chunking-update - - not_exists: default-chunking-sparse.mappings.properties.inference_field.chunking_settings + - not_exists: chunking-update.mappings.properties.inference_field.chunking_settings - do: indices.put_mapping: index: chunking-update body: - properties: - inference_field: - type: semantic_text - inference_id: sparse-inference-id - chunking_settings: - strategy: word - max_chunk_size: 10 - overlap: 1 + properties: + inference_field: + type: semantic_text + inference_id: sparse-inference-id + chunking_settings: + strategy: word + max_chunk_size: 10 + overlap: 1 - do: indices.get_mapping: index: chunking-update - - match: { "custom-chunking-sparse.mappings.properties.inference_field.chunking_settings.strategy": "word" } - - match: { "custom-chunking-sparse.mappings.properties.inference_field.chunking_settings.max_chunk_size": 10 } - - match: { "custom-chunking-sparse.mappings.properties.inference_field.chunking_settings.overlap": 1 } + - match: { "chunking-update.mappings.properties.inference_field.chunking_settings.strategy": "word" } + - match: { "chunking-update.mappings.properties.inference_field.chunking_settings.max_chunk_size": 10 } + - match: { "chunking-update.mappings.properties.inference_field.chunking_settings.overlap": 1 } - do: catch: /chunking_settings/ indices.put_mapping: index: chunking-update body: - properties: - inference_field: - type: semantic_text - inference_id: sparse-inference-id - chunking_settings: - strategy: word - max_chunk_size: 20 - overlap: 5 + properties: + inference_field: + type: semantic_text + inference_id: sparse-inference-id + chunking_settings: + strategy: word + max_chunk_size: 20 + overlap: 5 From aab16e79b631c353938cb1e08be52f71657998ad Mon Sep 17 00:00:00 2001 From: Kathleen DeRusso Date: Thu, 20 Mar 2025 16:08:49 -0400 Subject: [PATCH 58/86] Cleanup --- .../xpack/inference/mapper/SemanticTextFieldMapperTests.java | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java index 28e796715a1d4..697f8fcd5a10f 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java @@ -852,12 +852,11 @@ public void testDenseVectorElementType() throws IOException { public void testModelSettingsRequiredWithChunks() throws IOException { // Create inference results 
where model settings are set to null and chunks are provided Model model = TestModel.createRandomInstance(TaskType.SPARSE_EMBEDDING); - ChunkingSettings chunkingSettings = generateRandomChunkingSettings(); SemanticTextField randomSemanticText = randomSemanticText( useLegacyFormat, "field", model, - chunkingSettings, + generateRandomChunkingSettings(), List.of("a"), XContentType.JSON ); @@ -868,7 +867,7 @@ public void testModelSettingsRequiredWithChunks() throws IOException { new SemanticTextField.InferenceResult( randomSemanticText.inference().inferenceId(), null, - chunkingSettings, + randomSemanticText.inference().chunkingSettings(), randomSemanticText.inference().chunks() ), randomSemanticText.contentType() From 023c227cd4efadf85adf201d27a9c162e7536d91 Mon Sep 17 00:00:00 2001 From: Kathleen DeRusso Date: Fri, 21 Mar 2025 13:31:18 -0400 Subject: [PATCH 59/86] Refactor to handle memory optimized bulk shard inference actions - this is ugly but at least it compiles --- .../inference/ChunkInferenceInput.java | 25 ++ .../inference/InferenceService.java | 4 +- .../mock/AbstractTestInferenceService.java | 14 +- .../TestDenseInferenceServiceExtension.java | 18 +- .../mock/TestRerankingServiceExtension.java | 4 +- .../TestSparseInferenceServiceExtension.java | 14 +- ...stStreamingCompletionServiceExtension.java | 4 +- .../ShardBulkInferenceActionFilter.java | 42 +-- .../chunking/EmbeddingRequestChunker.java | 40 ++- .../voyageai/VoyageAIActionCreator.java | 2 +- ...baCloudSearchEmbeddingsRequestManager.java | 2 +- ...libabaCloudSearchSparseRequestManager.java | 2 +- ...AmazonBedrockEmbeddingsRequestManager.java | 2 +- ...AzureAiStudioEmbeddingsRequestManager.java | 3 +- .../AzureOpenAiEmbeddingsRequestManager.java | 2 +- .../CohereEmbeddingsRequestManager.java | 2 +- ...ServiceSparseEmbeddingsRequestManager.java | 2 +- .../external/http/sender/EmbeddingsInput.java | 19 +- ...oogleAiStudioEmbeddingsRequestManager.java | 2 +- ...oogleVertexAiEmbeddingsRequestManager.java | 2 +- .../sender/HuggingFaceRequestManager.java | 3 +- .../IbmWatsonxEmbeddingsRequestManager.java | 2 +- .../JinaAIEmbeddingsRequestManager.java | 2 +- .../MistralEmbeddingsRequestManager.java | 2 +- .../http/sender/TruncatingRequestManager.java | 2 +- .../inference/services/SenderService.java | 18 +- .../AlibabaCloudSearchService.java | 3 +- .../amazonbedrock/AmazonBedrockService.java | 3 +- .../services/anthropic/AnthropicService.java | 2 - .../azureaistudio/AzureAiStudioService.java | 3 +- .../azureopenai/AzureOpenAiService.java | 3 +- .../services/cohere/CohereService.java | 3 +- .../services/deepseek/DeepSeekService.java | 2 - .../elastic/ElasticInferenceService.java | 4 +- .../ElasticsearchInternalService.java | 8 +- .../googleaistudio/GoogleAiStudioService.java | 3 +- .../googlevertexai/GoogleVertexAiService.java | 3 +- .../huggingface/HuggingFaceService.java | 3 +- .../elser/HuggingFaceElserService.java | 11 +- .../ibmwatsonx/IbmWatsonxService.java | 8 +- .../services/jinaai/JinaAIService.java | 3 +- .../services/mistral/MistralService.java | 3 +- .../services/openai/OpenAiService.java | 3 +- .../services/voyageai/VoyageAIService.java | 3 +- .../ShardBulkInferenceActionFilterTests.java | 11 +- .../EmbeddingRequestChunkerTests.java | 260 +++++++++++------- ...ingleInputSenderExecutableActionTests.java | 3 +- ...ibabaCloudSearchCompletionActionTests.java | 2 +- .../AmazonBedrockActionCreatorTests.java | 6 +- .../AzureAiStudioActionAndCreatorTests.java | 2 +- .../AzureOpenAiActionCreatorTests.java | 16 +- 
.../AzureOpenAiEmbeddingsActionTests.java | 31 ++- .../cohere/CohereActionCreatorTests.java | 2 +- .../cohere/CohereEmbeddingsActionTests.java | 14 +- ...ticInferenceServiceActionCreatorTests.java | 8 +- .../GoogleAiStudioEmbeddingsActionTests.java | 8 +- .../GoogleVertexAiEmbeddingsActionTests.java | 6 +- .../GoogleVertexAiRerankActionTests.java | 6 +- .../HuggingFaceActionCreatorTests.java | 12 +- .../huggingface/HuggingFaceActionTests.java | 6 +- .../IbmWatsonxEmbeddingsActionTests.java | 8 +- .../openai/OpenAiActionCreatorTests.java | 14 +- .../openai/OpenAiEmbeddingsActionTests.java | 13 +- .../voyageai/VoyageAIActionCreatorTests.java | 7 +- .../VoyageAIEmbeddingsActionTests.java | 27 +- .../AmazonBedrockMockRequestSender.java | 3 +- .../AmazonBedrockRequestSenderTests.java | 3 +- .../http/sender/HttpRequestSenderTests.java | 3 +- .../http/sender/RequestTaskTests.java | 11 +- .../services/SenderServiceTests.java | 2 - .../AlibabaCloudSearchServiceTests.java | 4 +- .../AmazonBedrockServiceTests.java | 4 +- .../AzureAiStudioServiceTests.java | 4 +- .../azureopenai/AzureOpenAiServiceTests.java | 4 +- .../services/cohere/CohereServiceTests.java | 7 +- .../deepseek/DeepSeekServiceTests.java | 2 +- .../elastic/ElasticInferenceServiceTests.java | 7 +- .../ElasticsearchInternalServiceTests.java | 22 +- .../GoogleAiStudioServiceTests.java | 12 +- .../HuggingFaceElserServiceTests.java | 4 +- .../huggingface/HuggingFaceServiceTests.java | 7 +- .../ibmwatsonx/IbmWatsonxServiceTests.java | 8 +- .../services/jinaai/JinaAIServiceTests.java | 4 +- .../services/mistral/MistralServiceTests.java | 4 +- .../services/openai/OpenAiServiceTests.java | 4 +- .../voyageai/VoyageAIServiceTests.java | 4 +- 86 files changed, 524 insertions(+), 366 deletions(-) create mode 100644 server/src/main/java/org/elasticsearch/inference/ChunkInferenceInput.java diff --git a/server/src/main/java/org/elasticsearch/inference/ChunkInferenceInput.java b/server/src/main/java/org/elasticsearch/inference/ChunkInferenceInput.java new file mode 100644 index 0000000000000..678a46bb3d29b --- /dev/null +++ b/server/src/main/java/org/elasticsearch/inference/ChunkInferenceInput.java @@ -0,0 +1,25 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". 
+ */ + +package org.elasticsearch.inference; + +import org.elasticsearch.core.Nullable; + +import java.util.List; + +public record ChunkInferenceInput(String input, @Nullable ChunkingSettings chunkingSettings) { + + public ChunkInferenceInput(String input) { + this(input, null); + } + + public static List convertToStrings(List chunkInferenceInputs) { + return chunkInferenceInputs.stream().map(ChunkInferenceInput::input).toList(); + } +} diff --git a/server/src/main/java/org/elasticsearch/inference/InferenceService.java b/server/src/main/java/org/elasticsearch/inference/InferenceService.java index 646e781bc30b0..a23a1f0ad1e40 100644 --- a/server/src/main/java/org/elasticsearch/inference/InferenceService.java +++ b/server/src/main/java/org/elasticsearch/inference/InferenceService.java @@ -133,7 +133,6 @@ void unifiedCompletionInfer( * @param query Inference query, mainly for re-ranking * @param input Inference input * @param taskSettings Settings in the request to override the model's defaults - * @param chunkingSettings Chunking settings * @param inputType For search, ingest etc * @param timeout The timeout for the request * @param listener Chunked Inference result listener @@ -141,9 +140,8 @@ void unifiedCompletionInfer( void chunkedInfer( Model model, @Nullable String query, - List input, + List input, Map taskSettings, - @Nullable ChunkingSettings chunkingSettings, InputType inputType, TimeValue timeout, ActionListener> listener diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/AbstractTestInferenceService.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/AbstractTestInferenceService.java index 3bdede4bff389..b70e854cbf626 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/AbstractTestInferenceService.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/AbstractTestInferenceService.java @@ -13,6 +13,7 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.ChunkingStrategy; import org.elasticsearch.inference.InferenceService; @@ -111,22 +112,23 @@ public void start(Model model, TimeValue timeout, ActionListener listen @Override public void close() throws IOException {} - protected List chunkInputs(String input, ChunkingSettings chunkingSettings) { + protected List chunkInputs(ChunkInferenceInput input) { + ChunkingSettings chunkingSettings = input.chunkingSettings(); if (chunkingSettings == null) { - return List.of(input); + return List.of(input.input()); } + List chunkedInputs = new ArrayList<>(); - ChunkingStrategy chunkingStrategy = chunkingSettings.getChunkingStrategy(); - if (chunkingStrategy == ChunkingStrategy.WORD) { + if (chunkingSettings.getChunkingStrategy() == ChunkingStrategy.WORD) { WordBoundaryChunker chunker = new WordBoundaryChunker(); WordBoundaryChunkingSettings wordBoundaryChunkingSettings = (WordBoundaryChunkingSettings) chunkingSettings; List offsets = chunker.chunk( - input, + input.input(), wordBoundaryChunkingSettings.maxChunkSize(), wordBoundaryChunkingSettings.overlap() ); for (WordBoundaryChunker.ChunkOffset offset : offsets) { - chunkedInputs.add(input.substring(offset.start(), 
offset.end())); + chunkedInputs.add(input.input().substring(offset.start(), offset.end())); } } else { diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java index 391d537c07818..b4bb45c09a29e 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java @@ -17,6 +17,7 @@ import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -145,9 +146,8 @@ public void unifiedCompletionInfer( public void chunkedInfer( Model model, @Nullable String query, - List input, + List input, Map taskSettings, - ChunkingSettings chunkingSettings, InputType inputType, TimeValue timeout, ActionListener> listener @@ -155,7 +155,7 @@ public void chunkedInfer( switch (model.getConfigurations().getTaskType()) { case ANY, TEXT_EMBEDDING -> { ServiceSettings modelServiceSettings = model.getServiceSettings(); - listener.onResponse(makeChunkedResults(input, modelServiceSettings, chunkingSettings)); + listener.onResponse(makeChunkedResults(input, modelServiceSettings)); } default -> listener.onFailure( new ElasticsearchStatusException( @@ -175,19 +175,15 @@ private TextEmbeddingFloatResults makeResults(List input, ServiceSetting return new TextEmbeddingFloatResults(embeddings); } - private List makeChunkedResults( - List inputs, - ServiceSettings serviceSettings, - ChunkingSettings chunkingSettings - ) { + private List makeChunkedResults(List inputs, ServiceSettings serviceSettings) { var results = new ArrayList(); for (int i = 0; i < inputs.size(); i++) { - String input = inputs.get(i); - List chunkedInput = chunkInputs(input, chunkingSettings); + ChunkInferenceInput input = inputs.get(i); + List chunkedInput = chunkInputs(input); List chunks = new ArrayList<>(); for (String c : chunkedInput) { // Note: We have to start with an offset of 0 to account for overlaps - int offset = input.indexOf(c); + int offset = input.input().indexOf(c); int endOffset = offset + c.length(); chunks.add( new TextEmbeddingFloatResults.Chunk( diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java index c564aac890c50..e0cc69041d30c 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java @@ -16,6 +16,7 @@ import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import 
org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -135,9 +136,8 @@ public void unifiedCompletionInfer( public void chunkedInfer( Model model, @Nullable String query, - List input, + List input, Map taskSettings, - ChunkingSettings chunkingSettings, InputType inputType, TimeValue timeout, ActionListener> listener diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java index 60ac28e11c3b8..0a2dcfd1ff697 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java @@ -16,8 +16,8 @@ import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; -import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; @@ -136,15 +136,14 @@ public void unifiedCompletionInfer( public void chunkedInfer( Model model, @Nullable String query, - List input, + List input, Map taskSettings, - ChunkingSettings chunkingSettings, InputType inputType, TimeValue timeout, ActionListener> listener ) { switch (model.getConfigurations().getTaskType()) { - case ANY, SPARSE_EMBEDDING -> listener.onResponse(makeChunkedResults(input, chunkingSettings)); + case ANY, SPARSE_EMBEDDING -> listener.onResponse(makeChunkedResults(input)); default -> listener.onFailure( new ElasticsearchStatusException( TaskType.unsupportedTaskTypeErrorMsg(model.getConfigurations().getTaskType(), name()), @@ -166,10 +165,11 @@ private SparseEmbeddingResults makeResults(List input) { return new SparseEmbeddingResults(embeddings); } - private List makeChunkedResults(List inputs, ChunkingSettings chunkingSettings) { + private List makeChunkedResults(List inputs) { List results = new ArrayList<>(); - for (String input : inputs) { - List chunkedInput = chunkInputs(input, chunkingSettings); + for (ChunkInferenceInput chunkInferenceInput : inputs) { + String input = chunkInferenceInput.input(); + List chunkedInput = chunkInputs(chunkInferenceInput); List chunks = new ArrayList<>(); for (String c : chunkedInput) { var tokens = new ArrayList(); diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java index 6ac893d932922..269cc6ef91cc0 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java @@ -16,6 +16,7 @@ import 
org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.common.xcontent.ChunkedToXContentHelper; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -258,9 +259,8 @@ public Iterator toXContentChunked(ToXContent.Params params public void chunkedInfer( Model model, String query, - List input, + List input, Map taskSettings, - ChunkingSettings chunkingSettings, InputType inputType, TimeValue timeout, ActionListener> listener diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java index 9767f0723e98a..5575ef0659891 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java @@ -35,6 +35,7 @@ import org.elasticsearch.core.Releasable; import org.elasticsearch.core.TimeValue; import org.elasticsearch.index.mapper.InferenceMetadataFieldsMapper; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceService; @@ -70,7 +71,6 @@ import java.util.LinkedHashMap; import java.util.List; import java.util.Map; -import java.util.Objects; import java.util.stream.Collectors; import static org.elasticsearch.xpack.inference.InferencePlugin.INFERENCE_API_FEATURE; @@ -360,23 +360,26 @@ public void onFailure(Exception exc) { modelRegistry.getModelWithSecrets(inferenceId, modelLoadingListener); return; } - final List inputs = requests.stream().map(FieldInferenceRequest::input).collect(Collectors.toList()); + // final List inputs = requests.stream().map(FieldInferenceRequest::input).collect(Collectors.toList()); + final List inputs = requests.stream() + .map(r -> new ChunkInferenceInput(r.input, r.chunkingSettings)) + .collect(Collectors.toList()); // TODO reconcile -// // Batch requests in the order they are specified, grouping by field and chunking settings. -// // As each field may have different chunking settings specified, the size of the batch will be <= the configured batchSize. -// int currentBatchSize = Math.min(requests.size(), batchSize); -// final ChunkingSettings chunkingSettings = requests.isEmpty() == false ? requests.getFirst().chunkingSettings : null; -// final List currentBatch = new ArrayList<>(); -// for (FieldInferenceRequest request : requests) { -// if (Objects.equals(request.chunkingSettings, chunkingSettings) == false || currentBatch.size() >= currentBatchSize) { -// break; -// } -// currentBatch.add(request); -// } -// -// final List nextBatch = requests.subList(currentBatch.size(), requests.size()); -// final List inputs = currentBatch.stream().map(FieldInferenceRequest::input).collect(Collectors.toList()); + // // Batch requests in the order they are specified, grouping by field and chunking settings. + // // As each field may have different chunking settings specified, the size of the batch will be <= the configured batchSize. 
+ // int currentBatchSize = Math.min(requests.size(), batchSize); + // final ChunkingSettings chunkingSettings = requests.isEmpty() == false ? requests.getFirst().chunkingSettings : null; + // final List currentBatch = new ArrayList<>(); + // for (FieldInferenceRequest request : requests) { + // if (Objects.equals(request.chunkingSettings, chunkingSettings) == false || currentBatch.size() >= currentBatchSize) { + // break; + // } + // currentBatch.add(request); + // } + // + // final List nextBatch = requests.subList(currentBatch.size(), requests.size()); + // final List inputs = currentBatch.stream().map(FieldInferenceRequest::input).collect(Collectors.toList()); ActionListener> completionListener = new ActionListener<>() { @Override @@ -430,7 +433,7 @@ public void onFailure(Exception exc) { } }; inferenceProvider.service() - .chunkedInfer(inferenceProvider.model(), null, inputs, Map.of(), chunkingSettings, InputType.INGEST, TimeValue.MAX_VALUE, completionListener); + .chunkedInfer(inferenceProvider.model(), null, inputs, Map.of(), InputType.INGEST, TimeValue.MAX_VALUE, completionListener); } /** @@ -547,7 +550,9 @@ private long addFieldInferenceRequests(BulkItemRequest item, int itemIndex, Map< new FieldInferenceResponse(field, sourceField, v, order++, 0, null, EMPTY_CHUNKED_INFERENCE) ); } else { - requests.add(new FieldInferenceRequest(itemIndex, field, sourceField, v, order++, offsetAdjustment, chunkingSettings)); + requests.add( + new FieldInferenceRequest(itemIndex, field, sourceField, v, order++, offsetAdjustment, chunkingSettings) + ); } // When using the inference metadata fields format, all the input values are concatenated so that the @@ -628,6 +633,7 @@ private void applyInferenceResponses(BulkItemRequest item, FieldInferenceRespons new SemanticTextField.InferenceResult( inferenceFieldMetadata.getInferenceId(), model != null ? 
new MinimalServiceSettings(model) : null, + ChunkingSettingsBuilder.fromMap(inferenceFieldMetadata.getChunkingSettings(), false), chunkMap ), indexRequest.getContentType() diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java index 13bb406ed481b..256f3d260cd05 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java @@ -10,6 +10,7 @@ import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; import org.elasticsearch.common.util.concurrent.AtomicArray; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceResults; @@ -39,15 +40,24 @@ public class EmbeddingRequestChunker> { // Visible for testing - record Request(int inputIndex, int chunkIndex, ChunkOffset chunk, List inputs) { + record Request(int inputIndex, int chunkIndex, ChunkOffset chunk, List inputs) { public String chunkText() { - return inputs.get(inputIndex).substring(chunk.start(), chunk.end()); + return inputs.get(inputIndex).input().substring(chunk.start(), chunk.end()); + } + + public ChunkInferenceInput chunkInput() { + ChunkInferenceInput chunkInferenceInput = inputs.get(inputIndex); + String chunkText = chunkInferenceInput.input().substring(chunk.start(), chunk.end()); + return new ChunkInferenceInput(chunkText, chunkInferenceInput.chunkingSettings()); } } public record BatchRequest(List requests) { - public List inputs() { - return requests.stream().map(Request::chunkText).collect(Collectors.toList()); + public List inputs() { + // return requests.stream().map(Request::chunkText).collect(Collectors.toList()); + return requests.stream() + .map(r -> new ChunkInferenceInput(r.chunkText(), r.inputs().getFirst().chunkingSettings())) + .collect(Collectors.toList()); } } @@ -71,28 +81,36 @@ public record BatchRequestAndListener(BatchRequest batch, ActionListener resultsErrors; private ActionListener> finalListener; - public EmbeddingRequestChunker(List inputs, int maxNumberOfInputsPerBatch) { + public EmbeddingRequestChunker(List inputs, int maxNumberOfInputsPerBatch) { this(inputs, maxNumberOfInputsPerBatch, null); } - public EmbeddingRequestChunker(List inputs, int maxNumberOfInputsPerBatch, int wordsPerChunk, int chunkOverlap) { + public EmbeddingRequestChunker(List inputs, int maxNumberOfInputsPerBatch, int wordsPerChunk, int chunkOverlap) { this(inputs, maxNumberOfInputsPerBatch, new WordBoundaryChunkingSettings(wordsPerChunk, chunkOverlap)); } - public EmbeddingRequestChunker(List inputs, int maxNumberOfInputsPerBatch, ChunkingSettings chunkingSettings) { + public EmbeddingRequestChunker( + List inputs, + int maxNumberOfInputsPerBatch, + ChunkingSettings defaultChunkingSettings + ) { this.resultEmbeddings = new ArrayList<>(inputs.size()); this.resultOffsetStarts = new ArrayList<>(inputs.size()); this.resultOffsetEnds = new ArrayList<>(inputs.size()); this.resultsErrors = new AtomicArray<>(inputs.size()); - if (chunkingSettings == null) { - chunkingSettings = new WordBoundaryChunkingSettings(DEFAULT_WORDS_PER_CHUNK, DEFAULT_CHUNK_OVERLAP); + if (defaultChunkingSettings 
== null) { + defaultChunkingSettings = new WordBoundaryChunkingSettings(DEFAULT_WORDS_PER_CHUNK, DEFAULT_CHUNK_OVERLAP); } - Chunker chunker = ChunkerBuilder.fromChunkingStrategy(chunkingSettings.getChunkingStrategy()); List allRequests = new ArrayList<>(); for (int inputIndex = 0; inputIndex < inputs.size(); inputIndex++) { - List chunks = chunker.chunk(inputs.get(inputIndex), chunkingSettings); + ChunkingSettings chunkingSettings = inputs.get(inputIndex).chunkingSettings(); + if (chunkingSettings == null) { + chunkingSettings = defaultChunkingSettings; + } + Chunker chunker = ChunkerBuilder.fromChunkingStrategy(chunkingSettings.getChunkingStrategy()); + List chunks = chunker.chunk(inputs.get(inputIndex).input(), chunkingSettings); int resultCount = Math.min(chunks.size(), MAX_CHUNKS); resultEmbeddings.add(new AtomicReferenceArray<>(resultCount)); resultOffsetStarts.add(new ArrayList<>(resultCount)); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/voyageai/VoyageAIActionCreator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/voyageai/VoyageAIActionCreator.java index e87c7f6eb014a..69f4078b6beb7 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/voyageai/VoyageAIActionCreator.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/voyageai/VoyageAIActionCreator.java @@ -57,7 +57,7 @@ public ExecutableAction create(VoyageAIEmbeddingsModel model, Map new VoyageAIEmbeddingsRequest( - embeddingsInput.getInputs(), + embeddingsInput.getStringInputs(), embeddingsInput.getInputType(), overriddenModel ), diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AlibabaCloudSearchEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AlibabaCloudSearchEmbeddingsRequestManager.java index 42fba017b1000..752270f7d0d11 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AlibabaCloudSearchEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AlibabaCloudSearchEmbeddingsRequestManager.java @@ -71,7 +71,7 @@ public void execute( ActionListener listener ) { EmbeddingsInput input = EmbeddingsInput.of(inferenceInputs); - List docsInput = input.getInputs(); + List docsInput = input.getStringInputs(); InputType inputType = input.getInputType(); AlibabaCloudSearchEmbeddingsRequest request = new AlibabaCloudSearchEmbeddingsRequest(account, docsInput, inputType, model); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AlibabaCloudSearchSparseRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AlibabaCloudSearchSparseRequestManager.java index a07c42e9cee82..7ba1d3ca30758 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AlibabaCloudSearchSparseRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AlibabaCloudSearchSparseRequestManager.java @@ -71,7 +71,7 @@ public void execute( ActionListener listener ) { EmbeddingsInput input = EmbeddingsInput.of(inferenceInputs); - List docsInput = input.getInputs(); + List docsInput = input.getStringInputs(); 
InputType inputType = input.getInputType(); AlibabaCloudSearchSparseRequest request = new AlibabaCloudSearchSparseRequest(account, docsInput, inputType, model); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AmazonBedrockEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AmazonBedrockEmbeddingsRequestManager.java index fd04950ed6459..33a83de2486c5 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AmazonBedrockEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AmazonBedrockEmbeddingsRequestManager.java @@ -55,7 +55,7 @@ public void execute( ActionListener listener ) { EmbeddingsInput input = EmbeddingsInput.of(inferenceInputs); - List docsInput = input.getInputs(); + List docsInput = input.getStringInputs(); InputType inputType = input.getInputType(); var serviceSettings = embeddingsModel.getServiceSettings(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureAiStudioEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureAiStudioEmbeddingsRequestManager.java index 32a29490d7d24..53be191b2a7f6 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureAiStudioEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureAiStudioEmbeddingsRequestManager.java @@ -10,6 +10,7 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.threadpool.ThreadPool; @@ -48,7 +49,7 @@ public void execute( ActionListener listener ) { EmbeddingsInput input = EmbeddingsInput.of(inferenceInputs); - List docsInput = input.getInputs(); + List docsInput = ChunkInferenceInput.convertToStrings(input.getInputs()); InputType inputType = input.getInputType(); var truncatedInput = truncate(docsInput, model.getServiceSettings().maxInputTokens()); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureOpenAiEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureOpenAiEmbeddingsRequestManager.java index 5704129180c9a..166504a510006 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureOpenAiEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureOpenAiEmbeddingsRequestManager.java @@ -62,7 +62,7 @@ public void execute( ActionListener listener ) { EmbeddingsInput input = EmbeddingsInput.of(inferenceInputs); - List docsInput = input.getInputs(); + List docsInput = input.getStringInputs(); InputType inputType = input.getInputType(); var truncatedInput = truncate(docsInput, model.getServiceSettings().maxInputTokens()); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereEmbeddingsRequestManager.java 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereEmbeddingsRequestManager.java index 4a0ea36da0e5a..d9a9e6ccfb6c7 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereEmbeddingsRequestManager.java @@ -51,7 +51,7 @@ public void execute( ActionListener listener ) { EmbeddingsInput input = EmbeddingsInput.of(inferenceInputs); - List docsInput = input.getInputs(); + List docsInput = input.getStringInputs(); InputType inputType = input.getInputType(); CohereEmbeddingsRequest request = new CohereEmbeddingsRequest(docsInput, inputType, model); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/ElasticInferenceServiceSparseEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/ElasticInferenceServiceSparseEmbeddingsRequestManager.java index 8acfc05fb79eb..17772e179c5ac 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/ElasticInferenceServiceSparseEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/ElasticInferenceServiceSparseEmbeddingsRequestManager.java @@ -67,7 +67,7 @@ public void execute( ActionListener listener ) { EmbeddingsInput input = EmbeddingsInput.of(inferenceInputs); - List docsInput = input.getInputs(); + List docsInput = input.getStringInputs(); InputType inputType = input.getInputType(); var truncatedInput = truncate(docsInput, model.getServiceSettings().maxInputTokens()); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/EmbeddingsInput.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/EmbeddingsInput.java index d1d017f1c61c5..42f1310e14140 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/EmbeddingsInput.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/EmbeddingsInput.java @@ -8,10 +8,13 @@ package org.elasticsearch.xpack.inference.external.http.sender; import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ChunkInferenceInput; +import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InputType; import java.util.List; import java.util.Objects; +import java.util.stream.Collectors; public class EmbeddingsInput extends InferenceInputs { @@ -23,24 +26,32 @@ public static EmbeddingsInput of(InferenceInputs inferenceInputs) { return (EmbeddingsInput) inferenceInputs; } - private final List input; + private final List input; private final InputType inputType; - public EmbeddingsInput(List input, @Nullable InputType inputType) { + public EmbeddingsInput(List input, @Nullable InputType inputType) { this(input, inputType, false); } - public EmbeddingsInput(List input, @Nullable InputType inputType, boolean stream) { + public EmbeddingsInput(List input, @Nullable ChunkingSettings chunkingSettings, @Nullable InputType inputType) { + this(input.stream().map(i -> new ChunkInferenceInput(i, chunkingSettings)).collect(Collectors.toList()), inputType, false); + } + + public EmbeddingsInput(List input, @Nullable InputType inputType, 
boolean stream) { super(stream); this.input = Objects.requireNonNull(input); this.inputType = inputType; } - public List getInputs() { + public List getInputs() { return this.input; } + public List getStringInputs() { + return this.input.stream().map(ChunkInferenceInput::input).collect(Collectors.toList()); + } + public InputType getInputType() { return this.inputType; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleAiStudioEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleAiStudioEmbeddingsRequestManager.java index d23b26a034d15..6ef00d7d2b305 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleAiStudioEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleAiStudioEmbeddingsRequestManager.java @@ -55,7 +55,7 @@ public void execute( ActionListener listener ) { EmbeddingsInput input = EmbeddingsInput.of(inferenceInputs); - List docsInput = input.getInputs(); + List docsInput = input.getStringInputs(); InputType inputType = input.getInputType(); var truncatedInput = truncate(docsInput, model.getServiceSettings().maxInputTokens()); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleVertexAiEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleVertexAiEmbeddingsRequestManager.java index e9916cc700cb3..b2d342bb53591 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleVertexAiEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleVertexAiEmbeddingsRequestManager.java @@ -63,7 +63,7 @@ public void execute( ActionListener listener ) { EmbeddingsInput input = EmbeddingsInput.of(inferenceInputs); - List docsInput = input.getInputs(); + List docsInput = input.getStringInputs(); InputType inputType = input.getInputType(); var truncatedInput = truncate(docsInput, model.getServiceSettings().maxInputTokens()); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/HuggingFaceRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/HuggingFaceRequestManager.java index 3954f9322cca7..32246464d2f83 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/HuggingFaceRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/HuggingFaceRequestManager.java @@ -10,6 +10,7 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.inference.common.Truncator; @@ -60,7 +61,7 @@ public void execute( Supplier hasRequestCompletedFunction, ActionListener listener ) { - List docsInput = EmbeddingsInput.of(inferenceInputs).getInputs(); + List docsInput = ChunkInferenceInput.convertToStrings(EmbeddingsInput.of(inferenceInputs).getInputs()); var truncatedInput = truncate(docsInput, 
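The EmbeddingsInput change above swaps the stored list of strings for a list of ChunkInferenceInput and adds a plain-string view for callers that only need the raw text. A rough, self-contained sketch of that shape; the constructor overloads and the getInputs/getStringInputs names come from the diff, while the stand-in types and the dropped InputType/stream parameters are simplifications.

import java.util.List;
import java.util.Objects;

public class EmbeddingsInputSketch {

    interface ChunkingSettings {}                                          // stand-in
    record ChunkInferenceInput(String input, ChunkingSettings chunkingSettings) {}

    static class EmbeddingsInput {
        private final List<ChunkInferenceInput> input;

        // String-based overload: wrap every string with the (possibly null) chunking settings.
        EmbeddingsInput(List<String> texts, ChunkingSettings chunkingSettings) {
            this(texts.stream().map(t -> new ChunkInferenceInput(t, chunkingSettings)).toList());
        }

        EmbeddingsInput(List<ChunkInferenceInput> input) {
            this.input = Objects.requireNonNull(input);
        }

        List<ChunkInferenceInput> getInputs() {
            return input;
        }

        // Plain-text view used by request managers that build provider requests.
        List<String> getStringInputs() {
            return input.stream().map(ChunkInferenceInput::input).toList();
        }
    }

    public static void main(String[] args) {
        var in = new EmbeddingsInput(List.of("first doc", "second doc"), null);
        System.out.println(in.getStringInputs()); // prints [first doc, second doc]
    }
}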
model.getTokenLimit()); var request = new HuggingFaceInferenceRequest(truncator, truncatedInput, model); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/IbmWatsonxEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/IbmWatsonxEmbeddingsRequestManager.java index 1e9fec7dcba86..14c62f6237844 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/IbmWatsonxEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/IbmWatsonxEmbeddingsRequestManager.java @@ -53,7 +53,7 @@ public void execute( Supplier hasRequestCompletedFunction, ActionListener listener ) { - List docsInput = EmbeddingsInput.of(inferenceInputs).getInputs(); + List docsInput = EmbeddingsInput.of(inferenceInputs).getStringInputs(); var truncatedInput = truncate(docsInput, model.getServiceSettings().maxInputTokens()); execute( diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/JinaAIEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/JinaAIEmbeddingsRequestManager.java index 10a837afba18c..f83519b42ecb7 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/JinaAIEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/JinaAIEmbeddingsRequestManager.java @@ -51,7 +51,7 @@ public void execute( ActionListener listener ) { EmbeddingsInput input = EmbeddingsInput.of(inferenceInputs); - List docsInput = input.getInputs(); + List docsInput = input.getStringInputs(); InputType inputType = input.getInputType(); JinaAIEmbeddingsRequest request = new JinaAIEmbeddingsRequest(docsInput, inputType, model); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/MistralEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/MistralEmbeddingsRequestManager.java index b224d5ceb14ae..6f77fa3f6e985 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/MistralEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/MistralEmbeddingsRequestManager.java @@ -57,7 +57,7 @@ public void execute( Supplier hasRequestCompletedFunction, ActionListener listener ) { - List docsInput = EmbeddingsInput.of(inferenceInputs).getInputs(); + List docsInput = EmbeddingsInput.of(inferenceInputs).getStringInputs(); var truncatedInput = truncate(docsInput, model.getServiceSettings().maxInputTokens()); MistralEmbeddingsRequest request = new MistralEmbeddingsRequest(truncator, truncatedInput, model); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/TruncatingRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/TruncatingRequestManager.java index 4a485f87858aa..c39387d647f77 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/TruncatingRequestManager.java +++ 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/TruncatingRequestManager.java @@ -52,7 +52,7 @@ public void execute( Supplier hasRequestCompletedFunction, ActionListener listener ) { - var docsInput = inferenceInputs.castTo(EmbeddingsInput.class).getInputs(); + var docsInput = inferenceInputs.castTo(EmbeddingsInput.class).getStringInputs(); var truncatedInput = truncate(docsInput, maxInputTokens); var request = requestCreator.apply(truncatedInput); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java index 97ae81985913f..54e27b93a6952 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java @@ -14,8 +14,8 @@ import org.elasticsearch.core.Nullable; import org.elasticsearch.core.Strings; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; -import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; @@ -69,21 +69,23 @@ public void infer( ActionListener listener ) { init(); - var inferenceInput = createInput(this, model, input, inputType, query, stream); + var chunkInferenceInput = input.stream().map(i -> new ChunkInferenceInput(i, null)).toList(); + var inferenceInput = createInput(this, model, chunkInferenceInput, inputType, query, stream); doInfer(model, inferenceInput, taskSettings, timeout, listener); } private static InferenceInputs createInput( SenderService service, Model model, - List input, + List input, InputType inputType, @Nullable String query, boolean stream ) { + List textInput = ChunkInferenceInput.convertToStrings(input); return switch (model.getTaskType()) { - case COMPLETION, CHAT_COMPLETION -> new ChatCompletionInput(input, stream); - case RERANK -> new QueryAndDocsInputs(query, input, stream); + case COMPLETION, CHAT_COMPLETION -> new ChatCompletionInput(textInput, stream); + case RERANK -> new QueryAndDocsInputs(query, textInput, stream); case TEXT_EMBEDDING, SPARSE_EMBEDDING -> { ValidationException validationException = new ValidationException(); service.validateInputType(inputType, model, validationException); @@ -114,9 +116,8 @@ public void unifiedCompletionInfer( public void chunkedInfer( Model model, @Nullable String query, - List input, + List input, Map taskSettings, - ChunkingSettings chunkingSettings, InputType inputType, TimeValue timeout, ActionListener> listener @@ -130,7 +131,7 @@ public void chunkedInfer( } // a non-null query is not supported and is dropped by all providers - doChunkedInfer(model, new EmbeddingsInput(input, inputType), taskSettings, chunkingSettings != null ? 
chunkingSettings : model.getConfigurations().getChunkingSettings(), inputType, timeout, listener); + doChunkedInfer(model, new EmbeddingsInput(input, inputType), taskSettings, inputType, timeout, listener); } protected abstract void doInfer( @@ -154,7 +155,6 @@ protected abstract void doChunkedInfer( Model model, EmbeddingsInput inputs, Map taskSettings, - ChunkingSettings chunkingSettings, InputType inputType, TimeValue timeout, ActionListener> listener diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java index 50707f6fb7064..fe844bbe0c1a3 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java @@ -305,7 +305,6 @@ protected void doChunkedInfer( Model model, EmbeddingsInput inputs, Map taskSettings, - ChunkingSettings chunkingSettings, InputType inputType, TimeValue timeout, ActionListener> listener @@ -321,7 +320,7 @@ protected void doChunkedInfer( List batchedRequests = new EmbeddingRequestChunker<>( inputs.getInputs(), EMBEDDING_MAX_BATCH_SIZE, - chunkingSettings + alibabaCloudSearchModel.getConfigurations().getChunkingSettings() ).batchRequestsWithListeners(listener); for (var request : batchedRequests) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java index 49c59a0e17dd9..4d1f9044087be 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java @@ -153,7 +153,6 @@ protected void doChunkedInfer( Model model, EmbeddingsInput inputs, Map taskSettings, - ChunkingSettings chunkingSettings, InputType inputType, TimeValue timeout, ActionListener> listener @@ -165,7 +164,7 @@ protected void doChunkedInfer( List batchedRequests = new EmbeddingRequestChunker<>( inputs.getInputs(), maxBatchSize, - chunkingSettings + baseAmazonBedrockModel.getConfigurations().getChunkingSettings() ).batchRequestsWithListeners(listener); for (var request : batchedRequests) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java index a3f58e86a698d..ce356f52d4e07 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java @@ -16,7 +16,6 @@ import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.ChunkedInference; -import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; @@ -232,7 
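In the SenderService changes above, infer() wraps its raw strings as ChunkInferenceInput with no per-input chunking settings, and createInput unwraps them again for completion and rerank inputs. A hedged sketch of that round trip; convertToStrings mirrors the static helper the diff calls, and the stand-in types plus the String placeholders for ChatCompletionInput/QueryAndDocsInputs are assumptions.

import java.util.List;

public class SenderInputSketch {

    interface ChunkingSettings {}                                          // stand-in
    record ChunkInferenceInput(String input, ChunkingSettings chunkingSettings) {

        // Mirrors ChunkInferenceInput.convertToStrings(...) as used throughout the diff.
        static List<String> convertToStrings(List<ChunkInferenceInput> inputs) {
            return inputs.stream().map(ChunkInferenceInput::input).toList();
        }
    }

    enum TaskType { COMPLETION, RERANK, TEXT_EMBEDDING }

    static Object createInput(TaskType taskType, List<ChunkInferenceInput> input, String query) {
        List<String> textInput = ChunkInferenceInput.convertToStrings(input);
        return switch (taskType) {
            case COMPLETION -> "chat completion over " + textInput;        // ChatCompletionInput in the diff
            case RERANK -> query + " against " + textInput;                // QueryAndDocsInputs in the diff
            case TEXT_EMBEDDING -> input;                                  // embeddings keep the wrapped form
        };
    }

    public static void main(String[] args) {
        // infer() side: plain strings are wrapped with null chunking settings.
        var wrapped = List.of("doc one", "doc two").stream()
            .map(s -> new ChunkInferenceInput(s, null))
            .toList();
        System.out.println(createInput(TaskType.RERANK, wrapped, "my query"));
    }
}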
+231,6 @@ protected void doChunkedInfer( Model model, EmbeddingsInput inputs, Map taskSettings, - ChunkingSettings chunkingSettings, InputType inputType, TimeValue timeout, ActionListener> listener diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java index 87baa6d2f6b29..a25fb7a3e4d82 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java @@ -126,7 +126,6 @@ protected void doChunkedInfer( Model model, EmbeddingsInput inputs, Map taskSettings, - ChunkingSettings chunkingSettings, InputType inputType, TimeValue timeout, ActionListener> listener @@ -137,7 +136,7 @@ protected void doChunkedInfer( List batchedRequests = new EmbeddingRequestChunker<>( inputs.getInputs(), EMBEDDING_MAX_BATCH_SIZE, - chunkingSettings + baseAzureAiStudioModel.getConfigurations().getChunkingSettings() ).batchRequestsWithListeners(listener); for (var request : batchedRequests) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java index 809ee469a4e58..3843db7140b30 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java @@ -275,7 +275,6 @@ protected void doChunkedInfer( Model model, EmbeddingsInput inputs, Map taskSettings, - ChunkingSettings chunkingSettings, InputType inputType, TimeValue timeout, ActionListener> listener @@ -290,7 +289,7 @@ protected void doChunkedInfer( List batchedRequests = new EmbeddingRequestChunker<>( inputs.getInputs(), EMBEDDING_MAX_BATCH_SIZE, - chunkingSettings + azureOpenAiModel.getConfigurations().getChunkingSettings() ).batchRequestsWithListeners(listener); for (var request : batchedRequests) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java index 04ceb12ce976e..e22852f0e78e5 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java @@ -286,7 +286,6 @@ protected void doChunkedInfer( Model model, EmbeddingsInput inputs, Map taskSettings, - ChunkingSettings chunkingSettings, InputType inputType, TimeValue timeout, ActionListener> listener @@ -302,7 +301,7 @@ protected void doChunkedInfer( List batchedRequests = new EmbeddingRequestChunker<>( inputs.getInputs(), EMBEDDING_MAX_BATCH_SIZE, - chunkingSettings + cohereModel.getConfigurations().getChunkingSettings() ).batchRequestsWithListeners(listener); for (var request : batchedRequests) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekService.java 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekService.java index 0109f1a14f505..74ad6f6f55a06 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekService.java @@ -15,7 +15,6 @@ import org.elasticsearch.core.Strings; import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.ChunkedInference; -import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; @@ -111,7 +110,6 @@ protected void doChunkedInfer( Model model, EmbeddingsInput inputs, Map taskSettings, - ChunkingSettings chunkingSettings, InputType inputType, TimeValue timeout, ActionListener> listener diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java index cb385bf9e28bb..8e00b758417a3 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java @@ -17,7 +17,6 @@ import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.ChunkedInference; -import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.EmptySecretSettings; import org.elasticsearch.inference.EmptyTaskSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -268,7 +267,6 @@ protected void doChunkedInfer( Model model, EmbeddingsInput inputs, Map taskSettings, - ChunkingSettings chunkingSettings, InputType inputType, TimeValue timeout, ActionListener> listener @@ -437,7 +435,7 @@ public void checkModelConfig(Model model, ActionListener listener) { private static List translateToChunkedResults(InferenceInputs inputs, InferenceServiceResults inferenceResults) { if (inferenceResults instanceof SparseEmbeddingResults sparseEmbeddingResults) { - var inputsAsList = EmbeddingsInput.of(inputs).getInputs(); + var inputsAsList = EmbeddingsInput.of(inputs).getStringInputs(); return ChunkedInferenceEmbedding.listOf(inputsAsList, sparseEmbeddingResults); } else if (inferenceResults instanceof ErrorInferenceResults error) { return List.of(new ChunkedInferenceError(error.getException())); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java index 75838ecd1766e..7cf8fad47d6d9 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java @@ -20,6 +20,7 @@ import org.elasticsearch.core.Nullable; import org.elasticsearch.core.Strings; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import 
org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceResults; @@ -723,9 +724,8 @@ public void inferRerank( public void chunkedInfer( Model model, @Nullable String query, - List input, + List input, Map taskSettings, - ChunkingSettings chunkingSettings, InputType inputType, TimeValue timeout, ActionListener> listener @@ -742,7 +742,7 @@ public void chunkedInfer( List batchedRequests = new EmbeddingRequestChunker<>( input, EMBEDDING_MAX_BATCH_SIZE, - chunkingSettings != null ? chunkingSettings : esModel.getConfigurations().getChunkingSettings() + esModel.getConfigurations().getChunkingSettings() ).batchRequestsWithListeners(listener); if (batchedRequests.isEmpty()) { @@ -1118,7 +1118,7 @@ private void executeRequest(int batchIndex, boolean maybeDeploy, Runnable runAft var inferenceRequest = buildInferenceRequest( esModel.mlNodeDeploymentId(), EmptyConfigUpdate.INSTANCE, - batch.batch().inputs(), + ChunkInferenceInput.convertToStrings(batch.batch().inputs()), inputType, timeout ); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java index 083612ba5d0a6..1be510330fa85 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java @@ -351,7 +351,6 @@ protected void doChunkedInfer( Model model, EmbeddingsInput inputs, Map taskSettings, - ChunkingSettings chunkingSettings, InputType inputType, TimeValue timeout, ActionListener> listener @@ -361,7 +360,7 @@ protected void doChunkedInfer( List batchedRequests = new EmbeddingRequestChunker<>( inputs.getInputs(), EMBEDDING_MAX_BATCH_SIZE, - chunkingSettings + googleAiStudioModel.getConfigurations().getChunkingSettings() ).batchRequestsWithListeners(listener); for (var request : batchedRequests) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java index 60527f2a3fadb..726fb0b5da02f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java @@ -235,7 +235,6 @@ protected void doChunkedInfer( Model model, EmbeddingsInput inputs, Map taskSettings, - ChunkingSettings chunkingSettings, InputType inputType, TimeValue timeout, ActionListener> listener @@ -246,7 +245,7 @@ protected void doChunkedInfer( List batchedRequests = new EmbeddingRequestChunker<>( inputs.getInputs(), EMBEDDING_MAX_BATCH_SIZE, - chunkingSettings + googleVertexAiModel.getConfigurations().getChunkingSettings() ).batchRequestsWithListeners(listener); for (var request : batchedRequests) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java index 7f111d8dcaeeb..ca167a15a65f3 100644 --- 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java @@ -115,7 +115,6 @@ protected void doChunkedInfer( Model model, EmbeddingsInput inputs, Map taskSettings, - ChunkingSettings chunkingSettings, InputType inputType, TimeValue timeout, ActionListener> listener @@ -131,7 +130,7 @@ protected void doChunkedInfer( List batchedRequests = new EmbeddingRequestChunker<>( inputs.getInputs(), EMBEDDING_MAX_BATCH_SIZE, - chunkingSettings + huggingFaceModel.getConfigurations().getChunkingSettings() ).batchRequestsWithListeners(listener); for (var request : batchedRequests) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java index 7296b90e5ddb5..3eaacc8d25458 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java @@ -15,6 +15,7 @@ import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -98,7 +99,6 @@ protected void doChunkedInfer( Model model, EmbeddingsInput inputs, Map taskSettings, - ChunkingSettings chunkingSettings, InputType inputType, TimeValue timeout, ActionListener> listener @@ -113,7 +113,10 @@ protected void doChunkedInfer( private static List translateToChunkedResults(EmbeddingsInput inputs, InferenceServiceResults inferenceResults) { if (inferenceResults instanceof TextEmbeddingFloatResults textEmbeddingResults) { - validateInputSizeAgainstEmbeddings(inputs.getInputs(), textEmbeddingResults.embeddings().size()); + validateInputSizeAgainstEmbeddings( + ChunkInferenceInput.convertToStrings(inputs.getInputs()), + textEmbeddingResults.embeddings().size() + ); var results = new ArrayList(inputs.getInputs().size()); @@ -123,7 +126,7 @@ private static List translateToChunkedResults(EmbeddingsInput List.of( new EmbeddingResults.Chunk( textEmbeddingResults.embeddings().get(i), - new ChunkedInference.TextOffset(0, inputs.getInputs().get(i).length()) + new ChunkedInference.TextOffset(0, inputs.getInputs().get(i).input().length()) ) ) ) @@ -131,7 +134,7 @@ private static List translateToChunkedResults(EmbeddingsInput } return results; } else if (inferenceResults instanceof SparseEmbeddingResults sparseEmbeddingResults) { - var inputsAsList = EmbeddingsInput.of(inputs).getInputs(); + var inputsAsList = ChunkInferenceInput.convertToStrings(EmbeddingsInput.of(inputs).getInputs()); return ChunkedInferenceEmbedding.listOf(inputsAsList, sparseEmbeddingResults); } else if (inferenceResults instanceof ErrorInferenceResults error) { return List.of(new ChunkedInferenceError(error.getException())); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java 
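The HuggingFaceElserService hunk above now builds each TextOffset from the wrapped input's text, spanning the whole string. A small sketch of that per-input offset computation, using the same stand-in types; ChunkedInference.TextOffset is represented here by a simplified record.

import java.util.ArrayList;
import java.util.List;

public class WholeInputOffsetSketch {

    interface ChunkingSettings {}                                          // stand-in
    record ChunkInferenceInput(String input, ChunkingSettings chunkingSettings) {}
    record TextOffset(int start, int end) {}                               // stand-in for ChunkedInference.TextOffset

    // One offset per input, spanning the whole text, as in translateToChunkedResults above.
    static List<TextOffset> wholeInputOffsets(List<ChunkInferenceInput> inputs) {
        var offsets = new ArrayList<TextOffset>(inputs.size());
        for (ChunkInferenceInput in : inputs) {
            offsets.add(new TextOffset(0, in.input().length()));
        }
        return offsets;
    }

    public static void main(String[] args) {
        var offsets = wholeInputOffsets(List.of(new ChunkInferenceInput("some passage", null)));
        System.out.println(offsets);                                       // [TextOffset[start=0, end=12]]
    }
}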
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java index 6e1705ef70820..65a8ef469a2f7 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java @@ -306,15 +306,17 @@ protected void doChunkedInfer( Model model, EmbeddingsInput input, Map taskSettings, - ChunkingSettings chunkingSettings, InputType inputType, TimeValue timeout, ActionListener> listener ) { IbmWatsonxModel ibmWatsonxModel = (IbmWatsonxModel) model; - var batchedRequests = new EmbeddingRequestChunker<>(input.getInputs(), EMBEDDING_MAX_BATCH_SIZE, chunkingSettings) - .batchRequestsWithListeners(listener); + var batchedRequests = new EmbeddingRequestChunker<>( + input.getInputs(), + EMBEDDING_MAX_BATCH_SIZE, + ibmWatsonxModel.getConfigurations().getChunkingSettings() + ).batchRequestsWithListeners(listener); for (var request : batchedRequests) { var action = ibmWatsonxModel.accept(getActionCreator(getSender(), getServiceComponents()), taskSettings); action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener()); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java index 6abbd5e658e77..4975d74e9984a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java @@ -268,7 +268,6 @@ protected void doChunkedInfer( Model model, EmbeddingsInput inputs, Map taskSettings, - ChunkingSettings chunkingSettings, InputType inputType, TimeValue timeout, ActionListener> listener @@ -284,7 +283,7 @@ protected void doChunkedInfer( List batchedRequests = new EmbeddingRequestChunker<>( inputs.getInputs(), EMBEDDING_MAX_BATCH_SIZE, - chunkingSettings + jinaaiModel.getConfigurations().getChunkingSettings() ).batchRequestsWithListeners(listener); for (var request : batchedRequests) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java index 69a8107bce64e..9ed0bb85cbeb8 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java @@ -108,7 +108,6 @@ protected void doChunkedInfer( Model model, EmbeddingsInput inputs, Map taskSettings, - ChunkingSettings chunkingSettings, InputType inputType, TimeValue timeout, ActionListener> listener @@ -119,7 +118,7 @@ protected void doChunkedInfer( List batchedRequests = new EmbeddingRequestChunker<>( inputs.getInputs(), MistralConstants.MAX_BATCH_SIZE, - chunkingSettings + mistralEmbeddingsModel.getConfigurations().getChunkingSettings() ).batchRequestsWithListeners(listener); for (var request : batchedRequests) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java 
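Across the service classes in this patch (Alibaba Cloud Search, Amazon Bedrock, Azure, Cohere, Google, Hugging Face, IBM watsonx, Jina AI, Mistral, OpenAI, VoyageAI), doChunkedInfer drops its ChunkingSettings parameter and instead reads the settings stored on the model configuration. A condensed sketch of that pattern; the getConfigurations().getChunkingSettings() chain is the accessor used in the diff, while the chunker and model types here are simplified stand-ins.

import java.util.List;

public class ModelLevelChunkingSketch {

    interface ChunkingSettings {}                                          // stand-in

    interface ModelConfigurations {                                        // stand-in
        ChunkingSettings getChunkingSettings();
    }

    interface Model {                                                      // stand-in
        ModelConfigurations getConfigurations();
    }

    record ChunkInferenceInput(String input, ChunkingSettings chunkingSettings) {}

    // Simplified stand-in for EmbeddingRequestChunker: it only records which
    // settings it was built with, so the sketch can show where they now come from.
    record EmbeddingRequestChunker(List<ChunkInferenceInput> inputs, int maxBatchSize, ChunkingSettings settings) {}

    // The shape doChunkedInfer now takes in the services: no ChunkingSettings
    // parameter; the settings are resolved from the model's configuration.
    static EmbeddingRequestChunker buildChunker(Model model, List<ChunkInferenceInput> inputs, int maxBatchSize) {
        return new EmbeddingRequestChunker(inputs, maxBatchSize, model.getConfigurations().getChunkingSettings());
    }

    public static void main(String[] args) {
        ChunkingSettings modelDefaults = new ChunkingSettings() {};
        Model model = () -> () -> modelDefaults;   // Model -> ModelConfigurations -> settings
        var chunker = buildChunker(model, List.of(new ChunkInferenceInput("some text", null)), 16);
        System.out.println(chunker.settings() == modelDefaults);           // true
    }
}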
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java index 7d13d7ad46b56..977a025eed2f6 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java @@ -328,7 +328,6 @@ protected void doChunkedInfer( Model model, EmbeddingsInput inputs, Map taskSettings, - ChunkingSettings chunkingSettings, InputType inputType, TimeValue timeout, ActionListener> listener @@ -344,7 +343,7 @@ protected void doChunkedInfer( List batchedRequests = new EmbeddingRequestChunker<>( inputs.getInputs(), EMBEDDING_MAX_BATCH_SIZE, - chunkingSettings + openAiModel.getConfigurations().getChunkingSettings() ).batchRequestsWithListeners(listener); for (var request : batchedRequests) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java index dfe2952e64c51..6b8af209ea4b1 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java @@ -288,7 +288,6 @@ protected void doChunkedInfer( Model model, EmbeddingsInput inputs, Map taskSettings, - ChunkingSettings chunkingSettings, InputType inputType, TimeValue timeout, ActionListener> listener @@ -304,7 +303,7 @@ protected void doChunkedInfer( List batchedRequests = new EmbeddingRequestChunker<>( inputs.getInputs(), getBatchSize(voyageaiModel), - chunkingSettings + voyageaiModel.getConfigurations().getChunkingSettings() ).batchRequestsWithListeners(listener); for (var request : batchedRequests) { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java index b32308f98d38c..905acb7363f7b 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java @@ -36,6 +36,7 @@ import org.elasticsearch.index.IndexVersion; import org.elasticsearch.index.mapper.InferenceMetadataFieldsMapper; import org.elasticsearch.index.shard.ShardId; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceRegistry; @@ -508,12 +509,12 @@ private static ShardBulkInferenceActionFilter createFilter( InferenceService inferenceService = mock(InferenceService.class); Answer chunkedInferAnswer = invocationOnMock -> { StaticModel model = (StaticModel) invocationOnMock.getArguments()[0]; - List inputs = (List) invocationOnMock.getArguments()[2]; - ActionListener> listener = (ActionListener>) invocationOnMock.getArguments()[7]; + List inputs = (List) invocationOnMock.getArguments()[2]; + ActionListener> listener = (ActionListener>) invocationOnMock.getArguments()[6]; Runnable runnable = () -> { List results = new ArrayList<>(); - for (String input : inputs) { - 
results.add(model.getResults(input)); + for (ChunkInferenceInput input : inputs) { + results.add(model.getResults(input.input())); } listener.onResponse(results); }; @@ -528,7 +529,7 @@ private static ShardBulkInferenceActionFilter createFilter( } return null; }; - doAnswer(chunkedInferAnswer).when(inferenceService).chunkedInfer(any(), any(), any(), any(), any(), any(), any(), any()); + doAnswer(chunkedInferAnswer).when(inferenceService).chunkedInfer(any(), any(), any(), any(), any(), any(), any()); Answer modelAnswer = invocationOnMock -> { String inferenceId = (String) invocationOnMock.getArguments()[0]; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java index 39d729e752c25..0c941ed3fdc64 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.inference.chunking; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding; @@ -46,59 +47,73 @@ public void testEmptyInput_SentenceChunker() { } public void testWhitespaceInput_SentenceChunker() { - var batches = new EmbeddingRequestChunker<>(List.of(" "), 10, new SentenceBoundaryChunkingSettings(250, 1)) - .batchRequestsWithListeners(testListener()); + var batches = new EmbeddingRequestChunker<>( + List.of(new ChunkInferenceInput(" ")), + 10, + new SentenceBoundaryChunkingSettings(250, 1) + ).batchRequestsWithListeners(testListener()); assertThat(batches, hasSize(1)); assertThat(batches.get(0).batch().inputs(), hasSize(1)); - assertThat(batches.get(0).batch().inputs().get(0), Matchers.is(" ")); + assertThat(batches.get(0).batch().inputs().get(0).input(), Matchers.is(" ")); } public void testBlankInput_WordChunker() { - var batches = new EmbeddingRequestChunker<>(List.of(""), 100, 100, 10).batchRequestsWithListeners(testListener()); + var batches = new EmbeddingRequestChunker<>(List.of(new ChunkInferenceInput("")), 100, 100, 10).batchRequestsWithListeners( + testListener() + ); assertThat(batches, hasSize(1)); assertThat(batches.get(0).batch().inputs(), hasSize(1)); - assertThat(batches.get(0).batch().inputs().get(0), Matchers.is("")); + assertThat(batches.get(0).batch().inputs().get(0).input(), Matchers.is("")); } public void testBlankInput_SentenceChunker() { - var batches = new EmbeddingRequestChunker<>(List.of(""), 10, new SentenceBoundaryChunkingSettings(250, 1)) + var batches = new EmbeddingRequestChunker<>(List.of(new ChunkInferenceInput("")), 10, new SentenceBoundaryChunkingSettings(250, 1)) .batchRequestsWithListeners(testListener()); assertThat(batches, hasSize(1)); assertThat(batches.get(0).batch().inputs(), hasSize(1)); - assertThat(batches.get(0).batch().inputs().get(0), Matchers.is("")); + assertThat(batches.get(0).batch().inputs().get(0).input(), Matchers.is("")); } public void testInputThatDoesNotChunk_WordChunker() { - var batches = new EmbeddingRequestChunker<>(List.of("ABBAABBA"), 100, 100, 10).batchRequestsWithListeners(testListener()); + var batches = new 
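The ShardBulkInferenceActionFilterTests change above shows the knock-on effect on mocks: chunkedInfer now has seven arguments, the inputs arrive as ChunkInferenceInput, and the listener sits at argument index 6 instead of 7. A sketch of that stubbing pattern, assuming Mockito on the classpath and using stand-in interfaces in place of the Elasticsearch types.

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;

import java.util.List;

public class ChunkedInferMockSketch {

    interface ChunkingSettings {}                                          // stand-in
    record ChunkInferenceInput(String input, ChunkingSettings chunkingSettings) {}

    interface Listener { void onResponse(List<String> results); }          // stand-in for ActionListener

    // Stand-in for the trimmed-down InferenceService#chunkedInfer: seven parameters,
    // with the ChunkingSettings argument gone and the listener now last (index 6).
    interface InferenceService {
        void chunkedInfer(Object model, String query, List<ChunkInferenceInput> input, Object taskSettings,
                          Object inputType, Object timeout, Listener listener);
    }

    public static void main(String[] args) {
        InferenceService service = mock(InferenceService.class);
        doAnswer(invocation -> {
            @SuppressWarnings("unchecked")
            List<ChunkInferenceInput> inputs = (List<ChunkInferenceInput>) invocation.getArguments()[2];
            Listener listener = (Listener) invocation.getArguments()[6];   // was index 7 before the parameter was removed
            listener.onResponse(inputs.stream().map(ChunkInferenceInput::input).toList());
            return null;
        }).when(service).chunkedInfer(any(), any(), any(), any(), any(), any(), any());

        service.chunkedInfer(new Object(), null, List.of(new ChunkInferenceInput("hello", null)),
            null, null, null, results -> System.out.println(results));     // prints [hello]
    }
}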
EmbeddingRequestChunker<>(List.of(new ChunkInferenceInput("ABBAABBA")), 100, 100, 10).batchRequestsWithListeners( + testListener() + ); assertThat(batches, hasSize(1)); assertThat(batches.get(0).batch().inputs(), hasSize(1)); - assertThat(batches.get(0).batch().inputs().get(0), Matchers.is("ABBAABBA")); + assertThat(batches.get(0).batch().inputs().get(0).input(), Matchers.is("ABBAABBA")); } public void testInputThatDoesNotChunk_SentenceChunker() { - var batches = new EmbeddingRequestChunker<>(List.of("ABBAABBA"), 10, new SentenceBoundaryChunkingSettings(250, 1)) - .batchRequestsWithListeners(testListener()); + var batches = new EmbeddingRequestChunker<>( + List.of(new ChunkInferenceInput("ABBAABBA")), + 10, + new SentenceBoundaryChunkingSettings(250, 1) + ).batchRequestsWithListeners(testListener()); assertThat(batches, hasSize(1)); assertThat(batches.get(0).batch().inputs(), hasSize(1)); - assertThat(batches.get(0).batch().inputs().get(0), Matchers.is("ABBAABBA")); + assertThat(batches.get(0).batch().inputs().get(0).input(), Matchers.is("ABBAABBA")); } public void testShortInputsAreSingleBatch() { - String input = "one chunk"; + ChunkInferenceInput input = new ChunkInferenceInput("one chunk"); var batches = new EmbeddingRequestChunker<>(List.of(input), 100, 100, 10).batchRequestsWithListeners(testListener()); assertThat(batches, hasSize(1)); assertThat(batches.get(0).batch().inputs(), contains(input)); } public void testMultipleShortInputsAreSingleBatch() { - List inputs = List.of("1st small", "2nd small", "3rd small"); + List inputs = List.of( + new ChunkInferenceInput("1st small"), + new ChunkInferenceInput("2nd small"), + new ChunkInferenceInput("3rd small") + ); var batches = new EmbeddingRequestChunker<>(inputs, 100, 100, 10).batchRequestsWithListeners(testListener()); assertThat(batches, hasSize(1)); EmbeddingRequestChunker.BatchRequest batch = batches.getFirst().batch(); assertEquals(batch.inputs(), inputs); for (int i = 0; i < inputs.size(); i++) { var request = batch.requests().get(i); - assertThat(request.chunkText(), equalTo(inputs.get(i))); + assertThat(request.chunkText(), equalTo(inputs.get(i).input())); assertEquals(i, request.inputIndex()); assertEquals(0, request.chunkIndex()); } @@ -107,10 +122,10 @@ public void testMultipleShortInputsAreSingleBatch() { public void testManyInputsMakeManyBatches() { int maxNumInputsPerBatch = 10; int numInputs = maxNumInputsPerBatch * 3 + 1; // requires 4 batches - var inputs = new ArrayList(); + var inputs = new ArrayList(); for (int i = 0; i < numInputs; i++) { - inputs.add("input " + i); + inputs.add(new ChunkInferenceInput("input " + i)); } var batches = new EmbeddingRequestChunker<>(inputs, maxNumInputsPerBatch, 100, 10).batchRequestsWithListeners(testListener()); @@ -120,20 +135,20 @@ public void testManyInputsMakeManyBatches() { assertThat(batches.get(2).batch().inputs(), hasSize(maxNumInputsPerBatch)); assertThat(batches.get(3).batch().inputs(), hasSize(1)); - assertEquals("input 0", batches.get(0).batch().inputs().get(0)); - assertEquals("input 9", batches.get(0).batch().inputs().get(9)); + assertEquals("input 0", batches.get(0).batch().inputs().get(0).input()); + assertEquals("input 9", batches.get(0).batch().inputs().get(9).input()); assertThat( - batches.get(1).batch().inputs(), + ChunkInferenceInput.convertToStrings(batches.get(1).batch().inputs()), contains("input 10", "input 11", "input 12", "input 13", "input 14", "input 15", "input 16", "input 17", "input 18", "input 19") ); - assertEquals("input 20", 
batches.get(2).batch().inputs().get(0)); - assertEquals("input 29", batches.get(2).batch().inputs().get(9)); - assertThat(batches.get(3).batch().inputs(), contains("input 30")); + assertEquals("input 20", batches.get(2).batch().inputs().get(0).input()); + assertEquals("input 29", batches.get(2).batch().inputs().get(9).input()); + assertThat(ChunkInferenceInput.convertToStrings(batches.get(3).batch().inputs()), contains("input 30")); List requests = batches.get(0).batch().requests(); for (int i = 0; i < requests.size(); i++) { EmbeddingRequestChunker.Request request = requests.get(i); - assertThat(request.chunkText(), equalTo(inputs.get(i))); + assertThat(request.chunkText(), equalTo(inputs.get(i).input())); assertThat(request.inputIndex(), equalTo(i)); assertThat(request.chunkIndex(), equalTo(0)); } @@ -142,10 +157,10 @@ public void testManyInputsMakeManyBatches() { public void testChunkingSettingsProvided() { int maxNumInputsPerBatch = 10; int numInputs = maxNumInputsPerBatch * 3 + 1; // requires 4 batches - var inputs = new ArrayList(); + var inputs = new ArrayList(); for (int i = 0; i < numInputs; i++) { - inputs.add("input " + i); + inputs.add(new ChunkInferenceInput("input " + i)); } var batches = new EmbeddingRequestChunker<>(inputs, maxNumInputsPerBatch, ChunkingSettingsTests.createRandomChunkingSettings()) @@ -156,20 +171,20 @@ public void testChunkingSettingsProvided() { assertThat(batches.get(2).batch().inputs(), hasSize(maxNumInputsPerBatch)); assertThat(batches.get(3).batch().inputs(), hasSize(1)); - assertEquals("input 0", batches.get(0).batch().inputs().get(0)); - assertEquals("input 9", batches.get(0).batch().inputs().get(9)); + assertEquals("input 0", batches.get(0).batch().inputs().get(0).input()); + assertEquals("input 9", batches.get(0).batch().inputs().get(9).input()); assertThat( - batches.get(1).batch().inputs(), + ChunkInferenceInput.convertToStrings(batches.get(1).batch().inputs()), contains("input 10", "input 11", "input 12", "input 13", "input 14", "input 15", "input 16", "input 17", "input 18", "input 19") ); - assertEquals("input 20", batches.get(2).batch().inputs().get(0)); - assertEquals("input 29", batches.get(2).batch().inputs().get(9)); - assertThat(batches.get(3).batch().inputs(), contains("input 30")); + assertEquals("input 20", batches.get(2).batch().inputs().get(0).input()); + assertEquals("input 29", batches.get(2).batch().inputs().get(9).input()); + assertThat(ChunkInferenceInput.convertToStrings(batches.get(3).batch().inputs()), contains("input 30")); List requests = batches.get(0).batch().requests(); for (int i = 0; i < requests.size(); i++) { EmbeddingRequestChunker.Request request = requests.get(i); - assertThat(request.chunkText(), equalTo(inputs.get(i))); + assertThat(request.chunkText(), equalTo(inputs.get(i).input())); assertThat(request.inputIndex(), equalTo(i)); assertThat(request.chunkIndex(), equalTo(0)); } @@ -188,7 +203,12 @@ public void testLongInputChunkedOverMultipleBatches() { passageBuilder.append("passage_input").append(i).append(" "); // chunk on whitespace } - List inputs = List.of("1st small", passageBuilder.toString(), "2nd small", "3rd small"); + List inputs = List.of( + new ChunkInferenceInput("1st small"), + new ChunkInferenceInput(passageBuilder.toString()), + new ChunkInferenceInput("2nd small"), + new ChunkInferenceInput("3rd small") + ); var batches = new EmbeddingRequestChunker<>(inputs, batchSize, chunkSize, overlap).batchRequestsWithListeners(testListener()); @@ -244,7 +264,11 @@ public void testVeryLongInput_Sparse() 
{ passageBuilder.append("word").append(i).append(" "); // chunk on whitespace } - List inputs = List.of("1st small", passageBuilder.toString(), "2nd small"); + List inputs = List.of( + new ChunkInferenceInput("1st small"), + new ChunkInferenceInput(passageBuilder.toString()), + new ChunkInferenceInput("2nd small") + ); var finalListener = testListener(); List batches = new EmbeddingRequestChunker<>(inputs, batchSize, chunkSize, 0) @@ -278,7 +302,7 @@ public void testVeryLongInput_Sparse() { assertThat(inference, instanceOf(ChunkedInferenceEmbedding.class)); ChunkedInferenceEmbedding chunkedEmbedding = (ChunkedInferenceEmbedding) inference; assertThat(chunkedEmbedding.chunks(), hasSize(1)); - assertThat(getMatchedText(inputs.get(0), chunkedEmbedding.chunks().get(0).offset()), equalTo("1st small")); + assertThat(getMatchedText(inputs.get(0).input(), chunkedEmbedding.chunks().get(0).offset()), equalTo("1st small")); assertThat(chunkedEmbedding.chunks().get(0).embedding(), instanceOf(SparseEmbeddingResults.Embedding.class)); SparseEmbeddingResults.Embedding embedding = (SparseEmbeddingResults.Embedding) chunkedEmbedding.chunks().get(0).embedding(); assertThat(embedding.tokens(), contains(new WeightedToken("word", 1 / 16384f))); @@ -294,16 +318,19 @@ public void testVeryLongInput_Sparse() { // The first merged chunk consists of 20 small chunks (so 400 words) and the max // weight is the weight of the 20th small chunk (so 21/16384). - assertThat(getMatchedText(inputs.get(1), chunkedEmbedding.chunks().get(0).offset()), startsWith("word0 word1 ")); - assertThat(getMatchedText(inputs.get(1), chunkedEmbedding.chunks().get(0).offset()), endsWith(" word398 word399")); + assertThat(getMatchedText(inputs.get(1).input(), chunkedEmbedding.chunks().get(0).offset()), startsWith("word0 word1 ")); + assertThat(getMatchedText(inputs.get(1).input(), chunkedEmbedding.chunks().get(0).offset()), endsWith(" word398 word399")); assertThat(chunkedEmbedding.chunks().get(0).embedding(), instanceOf(SparseEmbeddingResults.Embedding.class)); embedding = (SparseEmbeddingResults.Embedding) chunkedEmbedding.chunks().get(0).embedding(); assertThat(embedding.tokens(), contains(new WeightedToken("word", 21 / 16384f))); // The last merged chunk consists of 19 small chunks (so 380 words) and the max // weight is the weight of the 10000th small chunk (so 10001/16384). 
- assertThat(getMatchedText(inputs.get(1), chunkedEmbedding.chunks().get(511).offset()), startsWith(" word199620 word199621 ")); - assertThat(getMatchedText(inputs.get(1), chunkedEmbedding.chunks().get(511).offset()), endsWith(" word199998 word199999")); + assertThat( + getMatchedText(inputs.get(1).input(), chunkedEmbedding.chunks().get(511).offset()), + startsWith(" word199620 word199621 ") + ); + assertThat(getMatchedText(inputs.get(1).input(), chunkedEmbedding.chunks().get(511).offset()), endsWith(" word199998 word199999")); assertThat(chunkedEmbedding.chunks().get(511).embedding(), instanceOf(SparseEmbeddingResults.Embedding.class)); embedding = (SparseEmbeddingResults.Embedding) chunkedEmbedding.chunks().get(511).embedding(); assertThat(embedding.tokens(), contains(new WeightedToken("word", 10001 / 16384f))); @@ -313,7 +340,7 @@ public void testVeryLongInput_Sparse() { assertThat(inference, instanceOf(ChunkedInferenceEmbedding.class)); chunkedEmbedding = (ChunkedInferenceEmbedding) inference; assertThat(chunkedEmbedding.chunks(), hasSize(1)); - assertThat(getMatchedText(inputs.get(2), chunkedEmbedding.chunks().get(0).offset()), equalTo("2nd small")); + assertThat(getMatchedText(inputs.get(2).input(), chunkedEmbedding.chunks().get(0).offset()), equalTo("2nd small")); assertThat(chunkedEmbedding.chunks().get(0).embedding(), instanceOf(SparseEmbeddingResults.Embedding.class)); embedding = (SparseEmbeddingResults.Embedding) chunkedEmbedding.chunks().get(0).embedding(); assertThat(embedding.tokens(), contains(new WeightedToken("word", 10002 / 16384f))); @@ -329,7 +356,11 @@ public void testVeryLongInput_Float() { passageBuilder.append("word").append(i).append(" "); // chunk on whitespace } - List inputs = List.of("1st small", passageBuilder.toString(), "2nd small"); + List inputs = List.of( + new ChunkInferenceInput("1st small"), + new ChunkInferenceInput(passageBuilder.toString()), + new ChunkInferenceInput("2nd small") + ); var finalListener = testListener(); List batches = new EmbeddingRequestChunker<>(inputs, batchSize, chunkSize, 0) @@ -362,7 +393,7 @@ public void testVeryLongInput_Float() { assertThat(inference, instanceOf(ChunkedInferenceEmbedding.class)); ChunkedInferenceEmbedding chunkedEmbedding = (ChunkedInferenceEmbedding) inference; assertThat(chunkedEmbedding.chunks(), hasSize(1)); - assertThat(getMatchedText(inputs.get(0), chunkedEmbedding.chunks().get(0).offset()), equalTo("1st small")); + assertThat(getMatchedText(inputs.get(0).input(), chunkedEmbedding.chunks().get(0).offset()), equalTo("1st small")); assertThat(chunkedEmbedding.chunks().get(0).embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class)); TextEmbeddingFloatResults.Embedding embedding = (TextEmbeddingFloatResults.Embedding) chunkedEmbedding.chunks().get(0).embedding(); assertThat(embedding.values(), equalTo(new float[] { 1 / 16384f })); @@ -378,16 +409,19 @@ public void testVeryLongInput_Float() { // The first merged chunk consists of 20 small chunks (so 400 words) and the weight // is the average of the weights 2/16384 ... 21/16384. 
- assertThat(getMatchedText(inputs.get(1), chunkedEmbedding.chunks().get(0).offset()), startsWith("word0 word1 ")); - assertThat(getMatchedText(inputs.get(1), chunkedEmbedding.chunks().get(0).offset()), endsWith(" word398 word399")); + assertThat(getMatchedText(inputs.get(1).input(), chunkedEmbedding.chunks().get(0).offset()), startsWith("word0 word1 ")); + assertThat(getMatchedText(inputs.get(1).input(), chunkedEmbedding.chunks().get(0).offset()), endsWith(" word398 word399")); assertThat(chunkedEmbedding.chunks().get(0).embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class)); embedding = (TextEmbeddingFloatResults.Embedding) chunkedEmbedding.chunks().get(0).embedding(); assertThat(embedding.values(), equalTo(new float[] { (2 + 21) / (2 * 16384f) })); // The last merged chunk consists of 19 small chunks (so 380 words) and the weight // is the average of the weights 9983/16384 ... 10001/16384. - assertThat(getMatchedText(inputs.get(1), chunkedEmbedding.chunks().get(511).offset()), startsWith(" word199620 word199621 ")); - assertThat(getMatchedText(inputs.get(1), chunkedEmbedding.chunks().get(511).offset()), endsWith(" word199998 word199999")); + assertThat( + getMatchedText(inputs.get(1).input(), chunkedEmbedding.chunks().get(511).offset()), + startsWith(" word199620 word199621 ") + ); + assertThat(getMatchedText(inputs.get(1).input(), chunkedEmbedding.chunks().get(511).offset()), endsWith(" word199998 word199999")); assertThat(chunkedEmbedding.chunks().get(511).embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class)); embedding = (TextEmbeddingFloatResults.Embedding) chunkedEmbedding.chunks().get(511).embedding(); assertThat(embedding.values(), equalTo(new float[] { (9983 + 10001) / (2 * 16384f) })); @@ -397,7 +431,7 @@ public void testVeryLongInput_Float() { assertThat(inference, instanceOf(ChunkedInferenceEmbedding.class)); chunkedEmbedding = (ChunkedInferenceEmbedding) inference; assertThat(chunkedEmbedding.chunks(), hasSize(1)); - assertThat(getMatchedText(inputs.get(2), chunkedEmbedding.chunks().get(0).offset()), equalTo("2nd small")); + assertThat(getMatchedText(inputs.get(2).input(), chunkedEmbedding.chunks().get(0).offset()), equalTo("2nd small")); assertThat(chunkedEmbedding.chunks().get(0).embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class)); embedding = (TextEmbeddingFloatResults.Embedding) chunkedEmbedding.chunks().get(0).embedding(); assertThat(embedding.values(), equalTo(new float[] { 10002 / 16384f })); @@ -413,7 +447,11 @@ public void testVeryLongInput_Byte() { passageBuilder.append("word").append(i).append(" "); // chunk on whitespace } - List inputs = List.of("1st small", passageBuilder.toString(), "2nd small"); + List inputs = List.of( + new ChunkInferenceInput("1st small"), + new ChunkInferenceInput(passageBuilder.toString()), + new ChunkInferenceInput("2nd small") + ); var finalListener = testListener(); List batches = new EmbeddingRequestChunker<>(inputs, batchSize, chunkSize, 0) @@ -446,7 +484,7 @@ public void testVeryLongInput_Byte() { assertThat(inference, instanceOf(ChunkedInferenceEmbedding.class)); ChunkedInferenceEmbedding chunkedEmbedding = (ChunkedInferenceEmbedding) inference; assertThat(chunkedEmbedding.chunks(), hasSize(1)); - assertThat(getMatchedText(inputs.get(0), chunkedEmbedding.chunks().get(0).offset()), equalTo("1st small")); + assertThat(getMatchedText(inputs.get(0).input(), chunkedEmbedding.chunks().get(0).offset()), equalTo("1st small")); assertThat(chunkedEmbedding.chunks().get(0).embedding(), 
instanceOf(TextEmbeddingByteResults.Embedding.class)); TextEmbeddingByteResults.Embedding embedding = (TextEmbeddingByteResults.Embedding) chunkedEmbedding.chunks().get(0).embedding(); assertThat(embedding.values(), equalTo(new byte[] { 1 })); @@ -462,8 +500,8 @@ public void testVeryLongInput_Byte() { // The first merged chunk consists of 20 small chunks (so 400 words) and the weight // is the average of the weights 2 ... 21, so 11.5, which is rounded to 12. - assertThat(getMatchedText(inputs.get(1), chunkedEmbedding.chunks().get(0).offset()), startsWith("word0 word1 ")); - assertThat(getMatchedText(inputs.get(1), chunkedEmbedding.chunks().get(0).offset()), endsWith(" word398 word399")); + assertThat(getMatchedText(inputs.get(1).input(), chunkedEmbedding.chunks().get(0).offset()), startsWith("word0 word1 ")); + assertThat(getMatchedText(inputs.get(1).input(), chunkedEmbedding.chunks().get(0).offset()), endsWith(" word398 word399")); assertThat(chunkedEmbedding.chunks().get(0).embedding(), instanceOf(TextEmbeddingByteResults.Embedding.class)); embedding = (TextEmbeddingByteResults.Embedding) chunkedEmbedding.chunks().get(0).embedding(); assertThat(embedding.values(), equalTo(new byte[] { 12 })); @@ -471,8 +509,11 @@ public void testVeryLongInput_Byte() { // The last merged chunk consists of 19 small chunks (so 380 words) and the weight // is the average of the weights 9983 ... 10001 modulo 256 (bytes overflowing), so // the average of -1, 0, 1, ... , 17, so 8. - assertThat(getMatchedText(inputs.get(1), chunkedEmbedding.chunks().get(511).offset()), startsWith(" word199620 word199621 ")); - assertThat(getMatchedText(inputs.get(1), chunkedEmbedding.chunks().get(511).offset()), endsWith(" word199998 word199999")); + assertThat( + getMatchedText(inputs.get(1).input(), chunkedEmbedding.chunks().get(511).offset()), + startsWith(" word199620 word199621 ") + ); + assertThat(getMatchedText(inputs.get(1).input(), chunkedEmbedding.chunks().get(511).offset()), endsWith(" word199998 word199999")); assertThat(chunkedEmbedding.chunks().get(511).embedding(), instanceOf(TextEmbeddingByteResults.Embedding.class)); embedding = (TextEmbeddingByteResults.Embedding) chunkedEmbedding.chunks().get(511).embedding(); assertThat(embedding.values(), equalTo(new byte[] { 8 })); @@ -482,7 +523,7 @@ public void testVeryLongInput_Byte() { assertThat(inference, instanceOf(ChunkedInferenceEmbedding.class)); chunkedEmbedding = (ChunkedInferenceEmbedding) inference; assertThat(chunkedEmbedding.chunks(), hasSize(1)); - assertThat(getMatchedText(inputs.get(2), chunkedEmbedding.chunks().get(0).offset()), equalTo("2nd small")); + assertThat(getMatchedText(inputs.get(2).input(), chunkedEmbedding.chunks().get(0).offset()), equalTo("2nd small")); assertThat(chunkedEmbedding.chunks().get(0).embedding(), instanceOf(TextEmbeddingByteResults.Embedding.class)); embedding = (TextEmbeddingByteResults.Embedding) chunkedEmbedding.chunks().get(0).embedding(); assertThat(embedding.values(), equalTo(new byte[] { 18 })); @@ -500,7 +541,12 @@ public void testMergingListener_Float() { for (int i = 0; i < numberOfWordsInPassage; i++) { passageBuilder.append("passage_input").append(i).append(" "); // chunk on whitespace } - List inputs = List.of("1st small", passageBuilder.toString(), "2nd small", "3rd small"); + List inputs = List.of( + new ChunkInferenceInput("1st small"), + new ChunkInferenceInput(passageBuilder.toString()), + new ChunkInferenceInput("2nd small"), + new ChunkInferenceInput("3rd small") + ); var finalListener = testListener(); var 
batches = new EmbeddingRequestChunker<>(inputs, batchSize, chunkSize, overlap).batchRequestsWithListeners(finalListener); @@ -529,7 +575,7 @@ public void testMergingListener_Float() { assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); var chunkedFloatResult = (ChunkedInferenceEmbedding) chunkedResult; assertThat(chunkedFloatResult.chunks(), hasSize(1)); - assertThat(getMatchedText(inputs.get(0), chunkedFloatResult.chunks().get(0).offset()), equalTo("1st small")); + assertThat(getMatchedText(inputs.get(0).input(), chunkedFloatResult.chunks().get(0).offset()), equalTo("1st small")); } { // this is the large input split in multiple chunks @@ -537,26 +583,29 @@ public void testMergingListener_Float() { assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); var chunkedFloatResult = (ChunkedInferenceEmbedding) chunkedResult; assertThat(chunkedFloatResult.chunks(), hasSize(6)); - assertThat(getMatchedText(inputs.get(1), chunkedFloatResult.chunks().get(0).offset()), startsWith("passage_input0 ")); - assertThat(getMatchedText(inputs.get(1), chunkedFloatResult.chunks().get(1).offset()), startsWith(" passage_input20 ")); - assertThat(getMatchedText(inputs.get(1), chunkedFloatResult.chunks().get(2).offset()), startsWith(" passage_input40 ")); - assertThat(getMatchedText(inputs.get(1), chunkedFloatResult.chunks().get(3).offset()), startsWith(" passage_input60 ")); - assertThat(getMatchedText(inputs.get(1), chunkedFloatResult.chunks().get(4).offset()), startsWith(" passage_input80 ")); - assertThat(getMatchedText(inputs.get(1), chunkedFloatResult.chunks().get(5).offset()), startsWith(" passage_input100 ")); + assertThat(getMatchedText(inputs.get(1).input(), chunkedFloatResult.chunks().get(0).offset()), startsWith("passage_input0 ")); + assertThat(getMatchedText(inputs.get(1).input(), chunkedFloatResult.chunks().get(1).offset()), startsWith(" passage_input20 ")); + assertThat(getMatchedText(inputs.get(1).input(), chunkedFloatResult.chunks().get(2).offset()), startsWith(" passage_input40 ")); + assertThat(getMatchedText(inputs.get(1).input(), chunkedFloatResult.chunks().get(3).offset()), startsWith(" passage_input60 ")); + assertThat(getMatchedText(inputs.get(1).input(), chunkedFloatResult.chunks().get(4).offset()), startsWith(" passage_input80 ")); + assertThat( + getMatchedText(inputs.get(1).input(), chunkedFloatResult.chunks().get(5).offset()), + startsWith(" passage_input100 ") + ); } { var chunkedResult = finalListener.results.get(2); assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); var chunkedFloatResult = (ChunkedInferenceEmbedding) chunkedResult; assertThat(chunkedFloatResult.chunks(), hasSize(1)); - assertThat(getMatchedText(inputs.get(2), chunkedFloatResult.chunks().get(0).offset()), equalTo("2nd small")); + assertThat(getMatchedText(inputs.get(2).input(), chunkedFloatResult.chunks().get(0).offset()), equalTo("2nd small")); } { var chunkedResult = finalListener.results.get(3); assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); var chunkedFloatResult = (ChunkedInferenceEmbedding) chunkedResult; assertThat(chunkedFloatResult.chunks(), hasSize(1)); - assertThat(getMatchedText(inputs.get(3), chunkedFloatResult.chunks().get(0).offset()), equalTo("3rd small")); + assertThat(getMatchedText(inputs.get(3).input(), chunkedFloatResult.chunks().get(0).offset()), equalTo("3rd small")); } } @@ -572,7 +621,12 @@ public void testMergingListener_Byte() { for (int i = 0; i < numberOfWordsInPassage; i++) { 
passageBuilder.append("passage_input").append(i).append(" "); // chunk on whitespace } - List inputs = List.of("1st small", passageBuilder.toString(), "2nd small", "3rd small"); + List inputs = List.of( + new ChunkInferenceInput("1st small"), + new ChunkInferenceInput(passageBuilder.toString()), + new ChunkInferenceInput("2nd small"), + new ChunkInferenceInput("3rd small") + ); var finalListener = testListener(); var batches = new EmbeddingRequestChunker<>(inputs, batchSize, chunkSize, overlap).batchRequestsWithListeners(finalListener); @@ -601,7 +655,7 @@ public void testMergingListener_Byte() { assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); var chunkedByteResult = (ChunkedInferenceEmbedding) chunkedResult; assertThat(chunkedByteResult.chunks(), hasSize(1)); - assertThat(getMatchedText(inputs.get(0), chunkedByteResult.chunks().get(0).offset()), equalTo("1st small")); + assertThat(getMatchedText(inputs.get(0).input(), chunkedByteResult.chunks().get(0).offset()), equalTo("1st small")); } { // this is the large input split in multiple chunks @@ -609,26 +663,26 @@ public void testMergingListener_Byte() { assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); var chunkedByteResult = (ChunkedInferenceEmbedding) chunkedResult; assertThat(chunkedByteResult.chunks(), hasSize(6)); - assertThat(getMatchedText(inputs.get(1), chunkedByteResult.chunks().get(0).offset()), startsWith("passage_input0 ")); - assertThat(getMatchedText(inputs.get(1), chunkedByteResult.chunks().get(1).offset()), startsWith(" passage_input20 ")); - assertThat(getMatchedText(inputs.get(1), chunkedByteResult.chunks().get(2).offset()), startsWith(" passage_input40 ")); - assertThat(getMatchedText(inputs.get(1), chunkedByteResult.chunks().get(3).offset()), startsWith(" passage_input60 ")); - assertThat(getMatchedText(inputs.get(1), chunkedByteResult.chunks().get(4).offset()), startsWith(" passage_input80 ")); - assertThat(getMatchedText(inputs.get(1), chunkedByteResult.chunks().get(5).offset()), startsWith(" passage_input100 ")); + assertThat(getMatchedText(inputs.get(1).input(), chunkedByteResult.chunks().get(0).offset()), startsWith("passage_input0 ")); + assertThat(getMatchedText(inputs.get(1).input(), chunkedByteResult.chunks().get(1).offset()), startsWith(" passage_input20 ")); + assertThat(getMatchedText(inputs.get(1).input(), chunkedByteResult.chunks().get(2).offset()), startsWith(" passage_input40 ")); + assertThat(getMatchedText(inputs.get(1).input(), chunkedByteResult.chunks().get(3).offset()), startsWith(" passage_input60 ")); + assertThat(getMatchedText(inputs.get(1).input(), chunkedByteResult.chunks().get(4).offset()), startsWith(" passage_input80 ")); + assertThat(getMatchedText(inputs.get(1).input(), chunkedByteResult.chunks().get(5).offset()), startsWith(" passage_input100 ")); } { var chunkedResult = finalListener.results.get(2); assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); var chunkedByteResult = (ChunkedInferenceEmbedding) chunkedResult; assertThat(chunkedByteResult.chunks(), hasSize(1)); - assertThat(getMatchedText(inputs.get(2), chunkedByteResult.chunks().get(0).offset()), equalTo("2nd small")); + assertThat(getMatchedText(inputs.get(2).input(), chunkedByteResult.chunks().get(0).offset()), equalTo("2nd small")); } { var chunkedResult = finalListener.results.get(3); assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); var chunkedByteResult = (ChunkedInferenceEmbedding) chunkedResult; assertThat(chunkedByteResult.chunks(), 
hasSize(1)); - assertThat(getMatchedText(inputs.get(3), chunkedByteResult.chunks().get(0).offset()), equalTo("3rd small")); + assertThat(getMatchedText(inputs.get(3).input(), chunkedByteResult.chunks().get(0).offset()), equalTo("3rd small")); } } @@ -644,7 +698,12 @@ public void testMergingListener_Bit() { for (int i = 0; i < numberOfWordsInPassage; i++) { passageBuilder.append("passage_input").append(i).append(" "); // chunk on whitespace } - List inputs = List.of("1st small", passageBuilder.toString(), "2nd small", "3rd small"); + List inputs = List.of( + new ChunkInferenceInput("1st small"), + new ChunkInferenceInput(passageBuilder.toString()), + new ChunkInferenceInput("2nd small"), + new ChunkInferenceInput("3rd small") + ); var finalListener = testListener(); var batches = new EmbeddingRequestChunker<>(inputs, batchSize, chunkSize, overlap).batchRequestsWithListeners(finalListener); @@ -673,7 +732,7 @@ public void testMergingListener_Bit() { assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); var chunkedByteResult = (ChunkedInferenceEmbedding) chunkedResult; assertThat(chunkedByteResult.chunks(), hasSize(1)); - assertThat(getMatchedText(inputs.get(0), chunkedByteResult.chunks().get(0).offset()), equalTo("1st small")); + assertThat(getMatchedText(inputs.get(0).input(), chunkedByteResult.chunks().get(0).offset()), equalTo("1st small")); } { // this is the large input split in multiple chunks @@ -681,26 +740,26 @@ public void testMergingListener_Bit() { assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); var chunkedByteResult = (ChunkedInferenceEmbedding) chunkedResult; assertThat(chunkedByteResult.chunks(), hasSize(6)); - assertThat(getMatchedText(inputs.get(1), chunkedByteResult.chunks().get(0).offset()), startsWith("passage_input0 ")); - assertThat(getMatchedText(inputs.get(1), chunkedByteResult.chunks().get(1).offset()), startsWith(" passage_input20 ")); - assertThat(getMatchedText(inputs.get(1), chunkedByteResult.chunks().get(2).offset()), startsWith(" passage_input40 ")); - assertThat(getMatchedText(inputs.get(1), chunkedByteResult.chunks().get(3).offset()), startsWith(" passage_input60 ")); - assertThat(getMatchedText(inputs.get(1), chunkedByteResult.chunks().get(4).offset()), startsWith(" passage_input80 ")); - assertThat(getMatchedText(inputs.get(1), chunkedByteResult.chunks().get(5).offset()), startsWith(" passage_input100 ")); + assertThat(getMatchedText(inputs.get(1).input(), chunkedByteResult.chunks().get(0).offset()), startsWith("passage_input0 ")); + assertThat(getMatchedText(inputs.get(1).input(), chunkedByteResult.chunks().get(1).offset()), startsWith(" passage_input20 ")); + assertThat(getMatchedText(inputs.get(1).input(), chunkedByteResult.chunks().get(2).offset()), startsWith(" passage_input40 ")); + assertThat(getMatchedText(inputs.get(1).input(), chunkedByteResult.chunks().get(3).offset()), startsWith(" passage_input60 ")); + assertThat(getMatchedText(inputs.get(1).input(), chunkedByteResult.chunks().get(4).offset()), startsWith(" passage_input80 ")); + assertThat(getMatchedText(inputs.get(1).input(), chunkedByteResult.chunks().get(5).offset()), startsWith(" passage_input100 ")); } { var chunkedResult = finalListener.results.get(2); assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); var chunkedByteResult = (ChunkedInferenceEmbedding) chunkedResult; assertThat(chunkedByteResult.chunks(), hasSize(1)); - assertThat(getMatchedText(inputs.get(2), chunkedByteResult.chunks().get(0).offset()), equalTo("2nd 
small")); + assertThat(getMatchedText(inputs.get(2).input(), chunkedByteResult.chunks().get(0).offset()), equalTo("2nd small")); } { var chunkedResult = finalListener.results.get(3); assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); var chunkedByteResult = (ChunkedInferenceEmbedding) chunkedResult; assertThat(chunkedByteResult.chunks(), hasSize(1)); - assertThat(getMatchedText(inputs.get(3), chunkedByteResult.chunks().get(0).offset()), equalTo("3rd small")); + assertThat(getMatchedText(inputs.get(3).input(), chunkedByteResult.chunks().get(0).offset()), equalTo("3rd small")); } } @@ -716,7 +775,12 @@ public void testMergingListener_Sparse() { for (int i = 0; i < numberOfWordsInPassage; i++) { passageBuilder.append("passage_input").append(i).append(" "); // chunk on whitespace } - List inputs = List.of("1st small", "2nd small", "3rd small", passageBuilder.toString()); + List inputs = List.of( + new ChunkInferenceInput("1st small"), + new ChunkInferenceInput("2nd small"), + new ChunkInferenceInput("3rd small"), + new ChunkInferenceInput(passageBuilder.toString()) + ); var finalListener = testListener(); var batches = new EmbeddingRequestChunker<>(inputs, batchSize, chunkSize, overlap).batchRequestsWithListeners(finalListener); @@ -752,21 +816,21 @@ public void testMergingListener_Sparse() { assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); var chunkedSparseResult = (ChunkedInferenceEmbedding) chunkedResult; assertThat(chunkedSparseResult.chunks(), hasSize(1)); - assertThat(getMatchedText(inputs.get(0), chunkedSparseResult.chunks().get(0).offset()), equalTo("1st small")); + assertThat(getMatchedText(inputs.get(0).input(), chunkedSparseResult.chunks().get(0).offset()), equalTo("1st small")); } { var chunkedResult = finalListener.results.get(1); assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); var chunkedSparseResult = (ChunkedInferenceEmbedding) chunkedResult; assertThat(chunkedSparseResult.chunks(), hasSize(1)); - assertThat(getMatchedText(inputs.get(1), chunkedSparseResult.chunks().get(0).offset()), equalTo("2nd small")); + assertThat(getMatchedText(inputs.get(1).input(), chunkedSparseResult.chunks().get(0).offset()), equalTo("2nd small")); } { var chunkedResult = finalListener.results.get(2); assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); var chunkedSparseResult = (ChunkedInferenceEmbedding) chunkedResult; assertThat(chunkedSparseResult.chunks(), hasSize(1)); - assertThat(getMatchedText(inputs.get(2), chunkedSparseResult.chunks().get(0).offset()), equalTo("3rd small")); + assertThat(getMatchedText(inputs.get(2).input(), chunkedSparseResult.chunks().get(0).offset()), equalTo("3rd small")); } { // this is the large input split in multiple chunks @@ -774,14 +838,24 @@ public void testMergingListener_Sparse() { assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); var chunkedSparseResult = (ChunkedInferenceEmbedding) chunkedResult; assertThat(chunkedSparseResult.chunks(), hasSize(9)); // passage is split into 9 chunks, 10 words each - assertThat(getMatchedText(inputs.get(3), chunkedSparseResult.chunks().get(0).offset()), startsWith("passage_input0 ")); - assertThat(getMatchedText(inputs.get(3), chunkedSparseResult.chunks().get(1).offset()), startsWith(" passage_input10 ")); - assertThat(getMatchedText(inputs.get(3), chunkedSparseResult.chunks().get(8).offset()), startsWith(" passage_input80 ")); + assertThat(getMatchedText(inputs.get(3).input(), 
chunkedSparseResult.chunks().get(0).offset()), startsWith("passage_input0 ")); + assertThat( + getMatchedText(inputs.get(3).input(), chunkedSparseResult.chunks().get(1).offset()), + startsWith(" passage_input10 ") + ); + assertThat( + getMatchedText(inputs.get(3).input(), chunkedSparseResult.chunks().get(8).offset()), + startsWith(" passage_input80 ") + ); } } public void testListenerErrorsWithWrongNumberOfResponses() { - List inputs = List.of("1st small", "2nd small", "3rd small"); + List inputs = List.of( + new ChunkInferenceInput("1st small"), + new ChunkInferenceInput("2nd small"), + new ChunkInferenceInput("3rd small") + ); var failureMessage = new AtomicReference(); var listener = new ActionListener>() { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/SingleInputSenderExecutableActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/SingleInputSenderExecutableActionTests.java index 2061174813041..1bead6b72ca92 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/SingleInputSenderExecutableActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/SingleInputSenderExecutableActionTests.java @@ -10,6 +10,7 @@ import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.test.ESTestCase; @@ -63,7 +64,7 @@ public void testOneInputIsValid() { public void testMoreThanOneInput() { var badInput = mock(EmbeddingsInput.class); - var input = List.of("one", "two"); + var input = List.of(new ChunkInferenceInput("one", null), new ChunkInferenceInput("two", null)); when(badInput.getInputs()).thenReturn(input); when(badInput.inputSize()).thenReturn(input.size()); var actualException = new AtomicReference(); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/alibabacloudsearch/AlibabaCloudSearchCompletionActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/alibabacloudsearch/AlibabaCloudSearchCompletionActionTests.java index 1e6217518e930..09e58ba7c7e54 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/alibabacloudsearch/AlibabaCloudSearchCompletionActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/alibabacloudsearch/AlibabaCloudSearchCompletionActionTests.java @@ -122,7 +122,7 @@ public void testExecute_ThrowsIllegalArgumentException_WhenInputIsNotChatComplet PlainActionFuture listener = new PlainActionFuture<>(); assertThrows(IllegalArgumentException.class, () -> { action.execute( - new EmbeddingsInput(List.of(randomAlphaOfLength(10)), InputType.INGEST), + new EmbeddingsInput(List.of(randomAlphaOfLength(10)), null, InputType.INGEST), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/amazonbedrock/AmazonBedrockActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/amazonbedrock/AmazonBedrockActionCreatorTests.java index 587687b6ee66a..8ecda6fa3fcb9 100644 --- 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/amazonbedrock/AmazonBedrockActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/amazonbedrock/AmazonBedrockActionCreatorTests.java @@ -74,7 +74,7 @@ public void testEmbeddingsRequestAction_Titan() throws IOException { var action = creator.create(model, Map.of()); PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -112,7 +112,7 @@ public void testEmbeddingsRequestAction_Cohere() throws IOException { var action = creator.create(model, Map.of()); PlainActionFuture listener = new PlainActionFuture<>(); var inputType = InputTypeTests.randomWithNull(); - action.execute(new EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of("abc"), null, inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); assertThat(result.asMap(), is(buildExpectationFloat(List.of(new float[] { 0.0123F, -0.0123F })))); @@ -145,7 +145,7 @@ public void testEmbeddingsRequestAction_HandlesException() throws IOException { ); var action = creator.create(model, Map.of()); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new EmbeddingsInput(List.of("abc"), null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of("abc"), null, null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); assertThat(sender.sendCount(), is(1)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureaistudio/AzureAiStudioActionAndCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureaistudio/AzureAiStudioActionAndCreatorTests.java index 4b5a053a74614..7c257a8c98c43 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureaistudio/AzureAiStudioActionAndCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureaistudio/AzureAiStudioActionAndCreatorTests.java @@ -115,7 +115,7 @@ public void testEmbeddingsRequestAction() throws IOException { var action = creator.create(model, Map.of()); PlainActionFuture listener = new PlainActionFuture<>(); var inputType = InputTypeTests.randomSearchAndIngestWithNull(); - action.execute(new EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of("abc"), null, inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiActionCreatorTests.java index 159402e664648..ea53bb42ca24a 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiActionCreatorTests.java +++ 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiActionCreatorTests.java @@ -119,7 +119,7 @@ public void testCreate_AzureOpenAiEmbeddingsModel() throws IOException { var action = actionCreator.create(model, overriddenTaskSettings); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of("abc"), null, inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); @@ -170,7 +170,7 @@ public void testCreate_AzureOpenAiEmbeddingsModel_WithoutUser() throws IOExcepti var action = actionCreator.create(model, overriddenTaskSettings); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of("abc"), null, inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); @@ -222,7 +222,7 @@ public void testCreate_AzureOpenAiEmbeddingsModel_FailsFromInvalidResponseFormat var action = actionCreator.create(model, overriddenTaskSettings); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of("abc"), null, inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var failureCauseMessage = "Required [data]"; var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); @@ -296,7 +296,7 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From413StatusC var action = actionCreator.create(model, overriddenTaskSettings); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new EmbeddingsInput(List.of("abcd"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of("abcd"), null, inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); @@ -373,7 +373,7 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From400StatusC var action = actionCreator.create(model, overriddenTaskSettings); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new EmbeddingsInput(List.of("abcd"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of("abcd"), null, inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); @@ -433,7 +433,11 @@ public void testExecute_TruncatesInputBeforeSending() throws IOException { var action = actionCreator.create(model, overriddenTaskSettings); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new EmbeddingsInput(List.of("super long input"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute( + new EmbeddingsInput(List.of("super long input"), null, inputType), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); var result = listener.actionGet(TIMEOUT); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiEmbeddingsActionTests.java 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiEmbeddingsActionTests.java index 6c4c05feb1361..a737b38e17762 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiEmbeddingsActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiEmbeddingsActionTests.java @@ -14,6 +14,7 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.test.ESTestCase; @@ -116,7 +117,11 @@ public void testExecute_ReturnsSuccessfulResponse() throws IOException { PlainActionFuture listener = new PlainActionFuture<>(); var inputType = InputTypeTests.randomWithNull(); - action.execute(new EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute( + new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), inputType), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); var result = listener.actionGet(TIMEOUT); @@ -144,7 +149,11 @@ public void testExecute_ThrowsElasticsearchException() { PlainActionFuture listener = new PlainActionFuture<>(); var inputType = InputTypeTests.randomWithNull(); - action.execute(new EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute( + new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), inputType), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -165,7 +174,11 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled PlainActionFuture listener = new PlainActionFuture<>(); var inputType = InputTypeTests.randomWithNull(); - action.execute(new EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute( + new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), inputType), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -186,7 +199,11 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled PlainActionFuture listener = new PlainActionFuture<>(); var inputType = InputTypeTests.randomWithNull(); - action.execute(new EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute( + new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), inputType), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -201,7 +218,11 @@ public void testExecute_ThrowsException() { PlainActionFuture listener = new PlainActionFuture<>(); var inputType = InputTypeTests.randomWithNull(); - action.execute(new EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute( + new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), inputType), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); var thrownException = expectThrows(ElasticsearchException.class, () -> 
listener.actionGet(TIMEOUT)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionCreatorTests.java index 5385cfd7e45fb..c72281b3ff7ed 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionCreatorTests.java @@ -119,7 +119,7 @@ public void testCreate_CohereEmbeddingsModel() throws IOException { var action = actionCreator.create(model, overriddenTaskSettings); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new EmbeddingsInput(List.of("abc"), null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of("abc"), null, null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereEmbeddingsActionTests.java index 226fa3368fd01..196e03ddfeb6c 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereEmbeddingsActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereEmbeddingsActionTests.java @@ -127,7 +127,7 @@ public void testExecute_ReturnsSuccessfulResponse() throws IOException { PlainActionFuture listener = new PlainActionFuture<>(); var inputType = InputTypeTests.randomWithNull(); - action.execute(new EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of("abc"), null, inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); @@ -216,7 +216,7 @@ public void testExecute_ReturnsSuccessfulResponse_ForInt8ResponseType() throws I ); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new EmbeddingsInput(List.of("abc"), null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of("abc"), null, null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); @@ -271,7 +271,7 @@ public void testExecute_ThrowsElasticsearchException() { var action = createAction(getUrl(webServer), "secret", CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, null, null, sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new EmbeddingsInput(List.of("abc"), null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of("abc"), null, null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -291,7 +291,7 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled var action = createAction(getUrl(webServer), "secret", CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, null, null, sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new EmbeddingsInput(List.of("abc"), null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + 
action.execute(new EmbeddingsInput(List.of("abc"), null, null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -311,7 +311,7 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled var action = createAction(null, "secret", CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, null, null, sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new EmbeddingsInput(List.of("abc"), null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of("abc"), null, null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -325,7 +325,7 @@ public void testExecute_ThrowsException() { var action = createAction(getUrl(webServer), "secret", CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, null, null, sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new EmbeddingsInput(List.of("abc"), null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of("abc"), null, null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -339,7 +339,7 @@ public void testExecute_ThrowsExceptionWithNullUrl() { var action = createAction(null, "secret", CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, null, null, sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new EmbeddingsInput(List.of("abc"), null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of("abc"), null, null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/elastic/ElasticInferenceServiceActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/elastic/ElasticInferenceServiceActionCreatorTests.java index 179880b58eab1..2bd1c280100a9 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/elastic/ElasticInferenceServiceActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/elastic/ElasticInferenceServiceActionCreatorTests.java @@ -96,7 +96,7 @@ public void testExecute_ReturnsSuccessfulResponse_ForElserAction() throws IOExce PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("hello world"), InputType.UNSPECIFIED), + new EmbeddingsInput(List.of("hello world"), null, InputType.UNSPECIFIED), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -157,7 +157,7 @@ public void testSend_FailsFromInvalidResponseFormat_ForElserAction() throws IOEx PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("hello world"), InputType.UNSPECIFIED), + new EmbeddingsInput(List.of("hello world"), null, InputType.UNSPECIFIED), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -214,7 +214,7 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating() throws IOExc PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("hello 
world"), InputType.UNSPECIFIED), + new EmbeddingsInput(List.of("hello world"), null, InputType.UNSPECIFIED), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -278,7 +278,7 @@ public void testExecute_TruncatesInputBeforeSending() throws IOException { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("hello world"), InputType.UNSPECIFIED), + new EmbeddingsInput(List.of("hello world"), null, InputType.UNSPECIFIED), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/googleaistudio/GoogleAiStudioEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/googleaistudio/GoogleAiStudioEmbeddingsActionTests.java index a9240a9a3795a..e4c8a07efb265 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/googleaistudio/GoogleAiStudioEmbeddingsActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/googleaistudio/GoogleAiStudioEmbeddingsActionTests.java @@ -108,7 +108,7 @@ public void testExecute_ReturnsSuccessfulResponse() throws IOException { PlainActionFuture listener = new PlainActionFuture<>(); var inputType = InputTypeTests.randomWithNull(); - action.execute(new EmbeddingsInput(List.of(input), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of(input), null, inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); @@ -163,7 +163,7 @@ public void testExecute_ThrowsElasticsearchException() { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -187,7 +187,7 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -205,7 +205,7 @@ public void testExecute_ThrowsException() { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/googlevertexai/GoogleVertexAiEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/googlevertexai/GoogleVertexAiEmbeddingsActionTests.java index 10ad38f17382f..6efceed309850 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/googlevertexai/GoogleVertexAiEmbeddingsActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/googlevertexai/GoogleVertexAiEmbeddingsActionTests.java @@ -75,7 +75,7 @@ public void testExecute_ThrowsElasticsearchException() { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), 
InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -99,7 +99,7 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -117,7 +117,7 @@ public void testExecute_ThrowsException() { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/googlevertexai/GoogleVertexAiRerankActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/googlevertexai/GoogleVertexAiRerankActionTests.java index 42dcdfdb5b341..7bcfcd1ec7cb1 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/googlevertexai/GoogleVertexAiRerankActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/googlevertexai/GoogleVertexAiRerankActionTests.java @@ -71,7 +71,7 @@ public void testExecute_ThrowsElasticsearchException() { var action = createAction(getUrl(webServer), "projectId", sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new EmbeddingsInput(List.of("abc"), null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of("abc"), null, null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -91,7 +91,7 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled var action = createAction(getUrl(webServer), "projectId", sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new EmbeddingsInput(List.of("abc"), null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of("abc"), null, null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -105,7 +105,7 @@ public void testExecute_ThrowsException() { var action = createAction(getUrl(webServer), "projectId", sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new EmbeddingsInput(List.of("abc"), null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of("abc"), null, null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/huggingface/HuggingFaceActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/huggingface/HuggingFaceActionCreatorTests.java index 0fa48f46983be..35284409043e2 100644 --- 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/huggingface/HuggingFaceActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/huggingface/HuggingFaceActionCreatorTests.java @@ -95,7 +95,7 @@ public void testExecute_ReturnsSuccessfulResponse_ForElserAction() throws IOExce PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -166,7 +166,7 @@ public void testSend_FailsFromInvalidResponseFormat_ForElserAction() throws IOEx PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -218,7 +218,7 @@ public void testExecute_ReturnsSuccessfulResponse_ForEmbeddingsAction() throws I PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -280,7 +280,7 @@ public void testSend_FailsFromInvalidResponseFormat_ForEmbeddingsAction() throws PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -339,7 +339,7 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating() throws IOExc PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abcd"), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abcd"), null, InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -401,7 +401,7 @@ public void testExecute_TruncatesInputBeforeSending() throws IOException { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("123456"), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("123456"), null, InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/huggingface/HuggingFaceActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/huggingface/HuggingFaceActionTests.java index d05f16916d0a7..afa1718f6349a 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/huggingface/HuggingFaceActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/huggingface/HuggingFaceActionTests.java @@ -63,7 +63,7 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderThrows() { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -87,7 +87,7 @@ public void 
testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -108,7 +108,7 @@ public void testExecute_ThrowsException() { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/ibmwatsonx/IbmWatsonxEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/ibmwatsonx/IbmWatsonxEmbeddingsActionTests.java index 85b1d7c38d3c6..2b0d70ef387ec 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/ibmwatsonx/IbmWatsonxEmbeddingsActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/ibmwatsonx/IbmWatsonxEmbeddingsActionTests.java @@ -114,7 +114,7 @@ public void testExecute_ReturnsSuccessfulResponse() throws IOException { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of(input), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of(input), null, InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -144,7 +144,7 @@ public void testExecute_ThrowsElasticsearchException() { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -173,7 +173,7 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -197,7 +197,7 @@ public void testExecute_ThrowsException() { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiActionCreatorTests.java index af18e5ab6e720..45869a57bc810 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiActionCreatorTests.java @@ -109,7 +109,7 @@ public void testCreate_OpenAiEmbeddingsModel() throws IOException { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), + new 
EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -166,7 +166,7 @@ public void testCreate_OpenAiEmbeddingsModel_WithoutUser() throws IOException { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -222,7 +222,7 @@ public void testCreate_OpenAiEmbeddingsModel_WithoutOrganization() throws IOExce PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -285,7 +285,7 @@ public void testCreate_OpenAiEmbeddingsModel_FailsFromInvalidResponseFormat() th PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -625,7 +625,7 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From413StatusC PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abcd"), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abcd"), null, InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -712,7 +712,7 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From400StatusC PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abcd"), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abcd"), null, InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -784,7 +784,7 @@ public void testExecute_TruncatesInputBeforeSending() throws IOException { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("super long input"), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("super long input"), null, InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiEmbeddingsActionTests.java index dae652f2a5d6a..e01a36d564eba 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiEmbeddingsActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiEmbeddingsActionTests.java @@ -14,6 +14,7 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.http.MockResponse; @@ -114,7 +115,7 @@ public void testExecute_ReturnsSuccessfulResponse() throws IOException { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), 
InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -154,7 +155,7 @@ public void testExecute_ThrowsElasticsearchException() { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -178,7 +179,7 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -202,7 +203,7 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -220,7 +221,7 @@ public void testExecute_ThrowsException() { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -238,7 +239,7 @@ public void testExecute_ThrowsExceptionWithNullUrl() { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/voyageai/VoyageAIActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/voyageai/VoyageAIActionCreatorTests.java index 0f9532849e146..6f7643beb682d 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/voyageai/VoyageAIActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/voyageai/VoyageAIActionCreatorTests.java @@ -11,6 +11,7 @@ import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.test.ESTestCase; @@ -110,7 +111,11 @@ public void testCreate_VoyageAIEmbeddingsModel() throws IOException { PlainActionFuture listener = new PlainActionFuture<>(); var inputType = InputTypeTests.randomSearchAndIngestWithNull(); - action.execute(new EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute( + new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), inputType), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); var result = listener.actionGet(TIMEOUT); diff --git 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/voyageai/VoyageAIEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/voyageai/VoyageAIEmbeddingsActionTests.java index c58a7ab86a4a9..46a219318d644 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/voyageai/VoyageAIEmbeddingsActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/voyageai/VoyageAIEmbeddingsActionTests.java @@ -14,6 +14,7 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.test.ESTestCase; @@ -120,7 +121,11 @@ public void testExecute_ReturnsSuccessfulResponse() throws IOException { PlainActionFuture listener = new PlainActionFuture<>(); var inputType = InputTypeTests.randomSearchAndIngestWithNull(); - action.execute(new EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute( + new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), inputType), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); var result = listener.actionGet(TIMEOUT); @@ -217,7 +222,11 @@ public void testExecute_ReturnsSuccessfulResponse_ForInt8ResponseType() throws I PlainActionFuture listener = new PlainActionFuture<>(); var inputType = InputTypeTests.randomSearchAndIngestWithNull(); - action.execute(new EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute( + new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), inputType), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); var result = listener.actionGet(TIMEOUT); @@ -314,7 +323,11 @@ public void testExecute_ReturnsSuccessfulResponse_ForBinaryResponseType() throws PlainActionFuture listener = new PlainActionFuture<>(); var inputType = InputTypeTests.randomSearchAndIngestWithNull(); - action.execute(new EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute( + new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), inputType), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); var result = listener.actionGet(TIMEOUT); @@ -383,7 +396,7 @@ public void testExecute_ThrowsElasticsearchException() { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), InputTypeTests.randomSearchAndIngestWithNull()), + new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), InputTypeTests.randomSearchAndIngestWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -407,7 +420,7 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), InputTypeTests.randomSearchAndIngestWithNull()), + new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), InputTypeTests.randomSearchAndIngestWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -425,7 +438,7 @@ public void testExecute_ThrowsException() { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), 
InputTypeTests.randomSearchAndIngestWithNull()), + new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), InputTypeTests.randomSearchAndIngestWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -448,7 +461,7 @@ private ExecutableAction createAction( threadPool, model, EMBEDDINGS_HANDLER, - (embeddingsInput) -> new VoyageAIEmbeddingsRequest(embeddingsInput.getInputs(), embeddingsInput.getInputType(), model), + (embeddingsInput) -> new VoyageAIEmbeddingsRequest(embeddingsInput.getStringInputs(), embeddingsInput.getInputType(), model), EmbeddingsInput.class ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockMockRequestSender.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockMockRequestSender.java index 12736c64d37b1..db938af0c1912 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockMockRequestSender.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockMockRequestSender.java @@ -12,6 +12,7 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; @@ -83,7 +84,7 @@ public void send( ) { sendCounter++; if (inferenceInputs instanceof EmbeddingsInput docsInput) { - inputs.add(docsInput.getInputs()); + inputs.add(ChunkInferenceInput.convertToStrings(docsInput.getInputs())); if (docsInput.getInputType() != null) { inputTypes.add(docsInput.getInputType()); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockRequestSenderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockRequestSenderTests.java index 33992eb6edeb0..f62f7429aeca8 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockRequestSenderTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockRequestSenderTests.java @@ -10,6 +10,7 @@ import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.threadpool.ThreadPool; @@ -83,7 +84,7 @@ public void testCreateSender_SendsEmbeddingsRequestAndReceivesResponse() throws threadPool, new TimeValue(30, TimeUnit.SECONDS) ); - sender.send(requestManager, new EmbeddingsInput(List.of("abc"), null), null, listener); + sender.send(requestManager, new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), null), null, listener); var result = listener.actionGet(TIMEOUT); assertThat(result.asMap(), is(buildExpectationFloat(List.of(new float[] { 0.123F, 0.456F, 0.678F, 0.789F })))); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java index 8289bdb9aee97..2de9bf178e24f 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java @@ -14,6 +14,7 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.ESTestCase; @@ -119,7 +120,7 @@ public void testCreateSender_SendsRequestAndReceivesResponse() throws Exception PlainActionFuture listener = new PlainActionFuture<>(); sender.send( OpenAiEmbeddingsRequestManagerTests.makeCreator(getUrl(webServer), null, "key", "model", null, threadPool), - new EmbeddingsInput(List.of("abc"), null), + new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), null), null, listener ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestTaskTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestTaskTests.java index a232b7724ca98..d40ee517a1c51 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestTaskTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestTaskTests.java @@ -11,6 +11,7 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.threadpool.Scheduler; @@ -61,7 +62,7 @@ public void testExecuting_DoesNotCallOnFailureForTimeout_AfterIllegalArgumentExc var requestTask = new RequestTask( OpenAiEmbeddingsRequestManagerTests.makeCreator("url", null, "key", "model", null, "id", threadPool), - new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), InputTypeTests.randomWithNull()), TimeValue.timeValueMillis(1), mockThreadPool, listener @@ -81,7 +82,7 @@ public void testRequest_ReturnsTimeoutException() { PlainActionFuture listener = new PlainActionFuture<>(); var requestTask = new RequestTask( OpenAiEmbeddingsRequestManagerTests.makeCreator("url", null, "key", "model", null, "id", threadPool), - new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), InputTypeTests.randomWithNull()), TimeValue.timeValueMillis(1), threadPool, listener @@ -105,7 +106,7 @@ public void testRequest_DoesNotCallOnFailureTwiceWhenTimingOut() throws Exceptio var requestTask = new RequestTask( OpenAiEmbeddingsRequestManagerTests.makeCreator("url", null, "key", "model", null, "id", threadPool), - new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), InputTypeTests.randomWithNull()), TimeValue.timeValueMillis(1), threadPool, listener @@ -134,7 +135,7 @@ public void 
testRequest_DoesNotCallOnResponseAfterTimingOut() throws Exception { var requestTask = new RequestTask( OpenAiEmbeddingsRequestManagerTests.makeCreator("url", null, "key", "model", null, "id", threadPool), - new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), InputTypeTests.randomWithNull()), TimeValue.timeValueMillis(1), threadPool, listener @@ -161,7 +162,7 @@ public void testRequest_DoesNotCallOnFailureForTimeout_AfterAlreadyCallingOnResp var requestTask = new RequestTask( OpenAiEmbeddingsRequestManagerTests.makeCreator("url", null, "key", "model", null, "id", threadPool), - new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), InputTypeTests.randomWithNull()), TimeValue.timeValueMillis(1), mockThreadPool, listener diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java index defdadb1e32de..5d7a6a149f941 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java @@ -13,7 +13,6 @@ import org.elasticsearch.common.ValidationException; import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.ChunkedInference; -import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; @@ -134,7 +133,6 @@ protected void doChunkedInfer( Model model, EmbeddingsInput inputs, Map taskSettings, - ChunkingSettings chunkingSettings, InputType inputType, TimeValue timeout, ActionListener> listener diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java index 5098a88d9d540..57c07d06cd53e 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java @@ -16,6 +16,7 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -463,7 +464,7 @@ public void testChunkedInfer_SparseEmbeddingChunkingSettingsNotSet() throws IOEx } private void testChunkedInfer(TaskType taskType, ChunkingSettings chunkingSettings) throws IOException { - var input = List.of("foo", "bar"); + var input = List.of(new ChunkInferenceInput("foo"), new ChunkInferenceInput("bar")); var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); @@ -476,7 +477,6 @@ private void testChunkedInfer(TaskType taskType, ChunkingSettings chunkingSettin null, input, new HashMap<>(), - null, 
InputTypeTests.randomWithIngestAndSearch(), InferenceAction.Request.DEFAULT_TIMEOUT, listener diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java index 42981cfe70b06..f6ffd5025ccf8 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java @@ -19,6 +19,7 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -1490,9 +1491,8 @@ private void testChunkedInfer(AmazonBedrockEmbeddingsModel model) throws IOExcep service.chunkedInfer( model, null, - List.of("a", "bb"), + List.of(new ChunkInferenceInput("a"), new ChunkInferenceInput("bb")), new HashMap<>(), - null, InputType.INTERNAL_INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, listener diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java index dffc09bdadee8..940ee4367fe8a 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java @@ -19,6 +19,7 @@ import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -1231,9 +1232,8 @@ private void testChunkedInfer(AzureAiStudioEmbeddingsModel model) throws IOExcep service.chunkedInfer( model, null, - List.of("a", "bb"), + List.of(new ChunkInferenceInput("a"), new ChunkInferenceInput("bb")), new HashMap<>(), - null, InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, listener diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java index 535e32398fda9..a6213d0112bd6 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java @@ -19,6 +19,7 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import 
org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -1362,9 +1363,8 @@ private void testChunkedInfer(AzureOpenAiEmbeddingsModel model) throws IOExcepti service.chunkedInfer( model, null, - List.of("a", "bb"), + List.of(new ChunkInferenceInput("a"), new ChunkInferenceInput("bb")), new HashMap<>(), - null, InputType.INTERNAL_INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, listener diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java index f3f1c98a00760..f92cfe0e62fec 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java @@ -20,6 +20,7 @@ import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.TimeValue; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -1452,9 +1453,8 @@ private void testChunkedInfer(CohereEmbeddingsModel model) throws IOException { service.chunkedInfer( model, null, - List.of("a", "bb"), + List.of(new ChunkInferenceInput("a"), new ChunkInferenceInput("bb")), new HashMap<>(), - null, InputType.UNSPECIFIED, InferenceAction.Request.DEFAULT_TIMEOUT, listener @@ -1552,9 +1552,8 @@ public void testChunkedInfer_BatchesCalls_Bytes() throws IOException { service.chunkedInfer( model, null, - List.of("a", "bb"), + List.of(new ChunkInferenceInput("a"), new ChunkInferenceInput("bb")), new HashMap<>(), - null, InputType.UNSPECIFIED, InferenceAction.Request.DEFAULT_TIMEOUT, listener diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekServiceTests.java index 05407a14e3298..aa1313793274b 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekServiceTests.java @@ -350,7 +350,7 @@ public void testUnifiedCompletionMalformedError() throws Exception { public void testDoChunkedInferAlwaysFails() throws IOException { try (var service = createService()) { - service.doChunkedInfer(mock(), mock(), Map.of(), null, InputType.UNSPECIFIED, TIMEOUT, assertNoSuccessListener(e -> { + service.doChunkedInfer(mock(), mock(), Map.of(), InputType.UNSPECIFIED, TIMEOUT, assertNoSuccessListener(e -> { assertThat(e, isA(UnsupportedOperationException.class)); assertThat(e.getMessage(), equalTo("The deepseek service only supports unified completion")); })); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java index fcef35ca03a5c..bd11df0987c87 100644 --- 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java @@ -16,6 +16,7 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.EmptySecretSettings; import org.elasticsearch.inference.EmptyTaskSettings; @@ -622,9 +623,8 @@ public void testChunkedInfer_PropagatesProductUseCaseHeader() throws IOException service.chunkedInfer( model, null, - List.of("input text"), + List.of(new ChunkInferenceInput("input text")), new HashMap<>(), - null, InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, listener @@ -751,9 +751,8 @@ public void testChunkedInfer_PassesThrough() throws IOException { service.chunkedInfer( model, null, - List.of("input text"), + List.of(new ChunkInferenceInput("input text")), new HashMap<>(), - null, InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, listener diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java index c96e508378007..7067577b30189 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java @@ -25,6 +25,7 @@ import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.TimeValue; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.EmptyTaskSettings; @@ -950,9 +951,8 @@ private void testChunkInfer_e5(ChunkingSettings chunkingSettings) throws Interru service.chunkedInfer( model, null, - List.of("a", "bb"), + List.of(new ChunkInferenceInput("a"), new ChunkInferenceInput("bb")), Map.of(), - null, InputType.SEARCH, InferenceAction.Request.DEFAULT_TIMEOUT, latchedListener @@ -1023,9 +1023,8 @@ private void testChunkInfer_Sparse(ChunkingSettings chunkingSettings) throws Int service.chunkedInfer( model, null, - List.of("a", "bb"), + List.of(new ChunkInferenceInput("a"), new ChunkInferenceInput("bb")), Map.of(), - null, InputType.SEARCH, InferenceAction.Request.DEFAULT_TIMEOUT, latchedListener @@ -1096,9 +1095,8 @@ private void testChunkInfer_Elser(ChunkingSettings chunkingSettings) throws Inte service.chunkedInfer( model, null, - List.of("a", "bb"), + List.of(new ChunkInferenceInput("a"), new ChunkInferenceInput("bb")), Map.of(), - null, InputType.SEARCH, InferenceAction.Request.DEFAULT_TIMEOUT, latchedListener @@ -1143,9 +1141,8 @@ public void testChunkInferSetsTokenization() { service.chunkedInfer( model, null, - List.of("foo", "bar"), + List.of(new ChunkInferenceInput("foo"), new ChunkInferenceInput("bar")), Map.of(), - null, InputType.SEARCH, InferenceAction.Request.DEFAULT_TIMEOUT, ActionListener.wrap(r -> fail("unexpected result"), e -> 
fail(e.getMessage())) @@ -1156,9 +1153,8 @@ public void testChunkInferSetsTokenization() { service.chunkedInfer( model, null, - List.of("foo", "bar"), + List.of(new ChunkInferenceInput("foo"), new ChunkInferenceInput("bar")), Map.of(), - null, InputType.SEARCH, InferenceAction.Request.DEFAULT_TIMEOUT, ActionListener.wrap(r -> fail("unexpected result"), e -> fail(e.getMessage())) @@ -1209,9 +1205,8 @@ public void testChunkInfer_FailsBatch() throws InterruptedException { service.chunkedInfer( model, null, - List.of("foo", "bar", "baz"), + List.of(new ChunkInferenceInput("foo"), new ChunkInferenceInput("bar"), new ChunkInferenceInput("baz")), Map.of(), - null, InputType.SEARCH, InferenceAction.Request.DEFAULT_TIMEOUT, latchedListener @@ -1236,7 +1231,7 @@ public void testChunkingLargeDocument() throws InterruptedException { // build a doc with enough words to make numChunks of chunks int wordsPerChunk = 10; int numWords = numChunks * wordsPerChunk; - var input = "word ".repeat(numWords); + var input = new ChunkInferenceInput("word ".repeat(numWords), null); Client client = mock(Client.class); when(client.threadPool()).thenReturn(threadPool); @@ -1284,7 +1279,6 @@ public void testChunkingLargeDocument() throws InterruptedException { null, List.of(input), Map.of(), - null, InputType.SEARCH, InferenceAction.Request.DEFAULT_TIMEOUT, latchedListener diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java index b39c9ee98fcff..07066a36b922b 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java @@ -19,6 +19,7 @@ import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.Strings; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.EmptyTaskSettings; @@ -884,7 +885,7 @@ public void testChunkedInfer_ChunkingSettingsSet() throws IOException { private void testChunkedInfer(String modelId, String apiKey, GoogleAiStudioEmbeddingsModel model) throws IOException { - var input = List.of("a", "bb"); + var input = List.of(new ChunkInferenceInput("a"), new ChunkInferenceInput("bb")); var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); @@ -916,7 +917,6 @@ private void testChunkedInfer(String modelId, String apiKey, GoogleAiStudioEmbed null, input, new HashMap<>(), - null, InputType.INTERNAL_INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, listener @@ -930,7 +930,7 @@ private void testChunkedInfer(String modelId, String apiKey, GoogleAiStudioEmbed assertThat(results.get(0), instanceOf(ChunkedInferenceEmbedding.class)); var floatResult = (ChunkedInferenceEmbedding) results.get(0); assertThat(floatResult.chunks(), hasSize(1)); - assertEquals(new ChunkedInference.TextOffset(0, input.get(0).length()), floatResult.chunks().get(0).offset()); + assertEquals(new ChunkedInference.TextOffset(0, input.get(0).input().length()), floatResult.chunks().get(0).offset()); assertThat(floatResult.chunks().get(0).embedding(), 
Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class)); assertTrue( Arrays.equals( @@ -945,7 +945,7 @@ private void testChunkedInfer(String modelId, String apiKey, GoogleAiStudioEmbed assertThat(results.get(1), instanceOf(ChunkedInferenceEmbedding.class)); var floatResult = (ChunkedInferenceEmbedding) results.get(1); assertThat(floatResult.chunks(), hasSize(1)); - assertEquals(new ChunkedInference.TextOffset(0, input.get(1).length()), floatResult.chunks().get(0).offset()); + assertEquals(new ChunkedInference.TextOffset(0, input.get(1).input().length()), floatResult.chunks().get(0).offset()); assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class)); assertTrue( Arrays.equals( @@ -969,7 +969,7 @@ private void testChunkedInfer(String modelId, String apiKey, GoogleAiStudioEmbed "model", Strings.format("%s/%s", "models", modelId), "content", - Map.of("parts", List.of(Map.of("text", input.get(0)))), + Map.of("parts", List.of(Map.of("text", input.get(0).input()))), "taskType", "RETRIEVAL_DOCUMENT" ), @@ -977,7 +977,7 @@ private void testChunkedInfer(String modelId, String apiKey, GoogleAiStudioEmbed "model", Strings.format("%s/%s", "models", modelId), "content", - Map.of("parts", List.of(Map.of("text", input.get(1)))), + Map.of("parts", List.of(Map.of("text", input.get(1).input()))), "taskType", "RETRIEVAL_DOCUMENT" ) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceElserServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceElserServiceTests.java index e03e4b84fcb64..3ef0de5f30b0d 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceElserServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceElserServiceTests.java @@ -14,6 +14,7 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InputType; @@ -96,9 +97,8 @@ public void testChunkedInfer_CallsInfer_Elser_ConvertsFloatResponse() throws IOE service.chunkedInfer( model, null, - List.of("abc"), + List.of(new ChunkInferenceInput("abc")), new HashMap<>(), - null, InputType.INTERNAL_SEARCH, InferenceAction.Request.DEFAULT_TIMEOUT, listener diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java index 4f79bfbcfee19..37be2ef91dd16 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java @@ -19,6 +19,7 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; 
import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -804,9 +805,8 @@ public void testChunkedInfer_CallsInfer_TextEmbedding_ConvertsFloatResponse() th service.chunkedInfer( model, null, - List.of("abc"), + List.of(new ChunkInferenceInput("abc")), new HashMap<>(), - null, InputType.INTERNAL_INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, listener @@ -857,9 +857,8 @@ public void testChunkedInfer() throws IOException { service.chunkedInfer( model, null, - List.of("abc"), + List.of(new ChunkInferenceInput("abc")), new HashMap<>(), - null, InputType.INTERNAL_INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, listener diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java index c847a43cbf640..0d3efd682d329 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java @@ -19,6 +19,7 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.EmptyTaskSettings; @@ -725,7 +726,7 @@ public void testChunkedInfer_ChunkingSettingsSet() throws IOException { } private void testChunkedInfer_Batches(ChunkingSettings chunkingSettings) throws IOException { - var input = List.of("a", "bb"); + var input = List.of(new ChunkInferenceInput("a"), new ChunkInferenceInput("bb")); var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); @@ -767,7 +768,6 @@ private void testChunkedInfer_Batches(ChunkingSettings chunkingSettings) throws null, input, new HashMap<>(), - null, InputType.INTERNAL_INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, listener @@ -781,7 +781,7 @@ private void testChunkedInfer_Batches(ChunkingSettings chunkingSettings) throws assertThat(results.get(0), instanceOf(ChunkedInferenceEmbedding.class)); var floatResult = (ChunkedInferenceEmbedding) results.get(0); assertThat(floatResult.chunks(), hasSize(1)); - assertEquals(new ChunkedInference.TextOffset(0, input.get(0).length()), floatResult.chunks().get(0).offset()); + assertEquals(new ChunkedInference.TextOffset(0, input.get(0).input().length()), floatResult.chunks().get(0).offset()); assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class)); assertTrue( Arrays.equals( @@ -796,7 +796,7 @@ private void testChunkedInfer_Batches(ChunkingSettings chunkingSettings) throws assertThat(results.get(1), instanceOf(ChunkedInferenceEmbedding.class)); var floatResult = (ChunkedInferenceEmbedding) results.get(1); assertThat(floatResult.chunks(), hasSize(1)); - assertEquals(new ChunkedInference.TextOffset(0, input.get(1).length()), floatResult.chunks().get(0).offset()); + assertEquals(new ChunkedInference.TextOffset(0, input.get(1).input().length()), floatResult.chunks().get(0).offset()); assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class)); assertTrue( Arrays.equals( diff --git 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java index e64eb0f5e4522..edff51571a4b0 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java @@ -19,6 +19,7 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -1819,9 +1820,8 @@ private void test_Embedding_ChunkedInfer_BatchesCalls(JinaAIEmbeddingsModel mode service.chunkedInfer( model, null, - List.of("a", "bb"), + List.of(new ChunkInferenceInput("a"), new ChunkInferenceInput("bb")), new HashMap<>(), - null, InputType.UNSPECIFIED, InferenceAction.Request.DEFAULT_TIMEOUT, listener diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java index e759ffcc56a8c..2ed7436005c8e 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java @@ -19,6 +19,7 @@ import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -709,9 +710,8 @@ public void testChunkedInfer(MistralEmbeddingsModel model) throws IOException { service.chunkedInfer( model, null, - List.of("abc", "def"), + List.of(new ChunkInferenceInput("abc"), new ChunkInferenceInput("def")), new HashMap<>(), - null, InputType.INTERNAL_INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, listener diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java index 1a9a715784938..a8fbad507cc4e 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java @@ -21,6 +21,7 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -1895,9 +1896,8 @@ private void testChunkedInfer(OpenAiEmbeddingsModel model) throws IOException { service.chunkedInfer( model, null, - List.of("a", "bb"), + 
List.of(new ChunkInferenceInput("a"), new ChunkInferenceInput("bb")), new HashMap<>(), - null, InputType.INTERNAL_INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, listener diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java index 4bc8db99735ed..13cf07b030dd1 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java @@ -18,6 +18,7 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -1771,9 +1772,8 @@ private void test_Embedding_ChunkedInfer_BatchesCalls(VoyageAIEmbeddingsModel mo service.chunkedInfer( model, null, - List.of("a", "bb"), + List.of(new ChunkInferenceInput("a"), new ChunkInferenceInput("bb")), new HashMap<>(), - null, InputType.UNSPECIFIED, InferenceAction.Request.DEFAULT_TIMEOUT, listener From b26b9a21fb7c597fbee9ce7e997463a68226bc5b Mon Sep 17 00:00:00 2001 From: elasticsearchmachine Date: Fri, 21 Mar 2025 20:24:32 +0000 Subject: [PATCH 60/86] [CI] Auto commit changes from spotless --- .../xpack/inference/mock/TestDenseInferenceServiceExtension.java | 1 - .../xpack/inference/mock/TestRerankingServiceExtension.java | 1 - .../inference/mock/TestStreamingCompletionServiceExtension.java | 1 - 3 files changed, 3 deletions(-) diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java index b4bb45c09a29e..a1cef21bf4301 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java @@ -19,7 +19,6 @@ import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; -import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java index e0cc69041d30c..9b313ca3e3737 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java @@ 
-18,7 +18,6 @@ import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; -import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java index 269cc6ef91cc0..ab8340b1a387c 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java @@ -18,7 +18,6 @@ import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; -import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; From 1c84cc29c2cb65c51336376f3c3d61d067431840 Mon Sep 17 00:00:00 2001 From: Kathleen DeRusso Date: Mon, 24 Mar 2025 09:35:19 -0400 Subject: [PATCH 61/86] Minor cleanup --- .../inference/ChunkInferenceInput.java | 2 +- .../ShardBulkInferenceActionFilter.java | 17 ----------- ...AzureAiStudioEmbeddingsRequestManager.java | 2 +- .../sender/HuggingFaceRequestManager.java | 3 +- .../inference/services/SenderService.java | 2 +- .../ElasticsearchInternalService.java | 2 +- .../elser/HuggingFaceElserService.java | 7 ++--- .../EmbeddingRequestChunkerTests.java | 8 +++--- .../AmazonBedrockMockRequestSender.java | 2 +- .../60_semantic_text_inference_update.yml | 28 +++++++++---------- 10 files changed, 26 insertions(+), 47 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/inference/ChunkInferenceInput.java b/server/src/main/java/org/elasticsearch/inference/ChunkInferenceInput.java index 678a46bb3d29b..bb29d58dc145c 100644 --- a/server/src/main/java/org/elasticsearch/inference/ChunkInferenceInput.java +++ b/server/src/main/java/org/elasticsearch/inference/ChunkInferenceInput.java @@ -19,7 +19,7 @@ public ChunkInferenceInput(String input) { this(input, null); } - public static List convertToStrings(List chunkInferenceInputs) { + public static List asStrings(List chunkInferenceInputs) { return chunkInferenceInputs.stream().map(ChunkInferenceInput::input).toList(); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java index 5575ef0659891..dff611592dfae 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java @@ -360,27 +360,10 @@ public void onFailure(Exception exc) { modelRegistry.getModelWithSecrets(inferenceId, modelLoadingListener); return; 
} - // final List inputs = requests.stream().map(FieldInferenceRequest::input).collect(Collectors.toList()); final List inputs = requests.stream() .map(r -> new ChunkInferenceInput(r.input, r.chunkingSettings)) .collect(Collectors.toList()); - // TODO reconcile - // // Batch requests in the order they are specified, grouping by field and chunking settings. - // // As each field may have different chunking settings specified, the size of the batch will be <= the configured batchSize. - // int currentBatchSize = Math.min(requests.size(), batchSize); - // final ChunkingSettings chunkingSettings = requests.isEmpty() == false ? requests.getFirst().chunkingSettings : null; - // final List currentBatch = new ArrayList<>(); - // for (FieldInferenceRequest request : requests) { - // if (Objects.equals(request.chunkingSettings, chunkingSettings) == false || currentBatch.size() >= currentBatchSize) { - // break; - // } - // currentBatch.add(request); - // } - // - // final List nextBatch = requests.subList(currentBatch.size(), requests.size()); - // final List inputs = currentBatch.stream().map(FieldInferenceRequest::input).collect(Collectors.toList()); - ActionListener> completionListener = new ActionListener<>() { @Override public void onResponse(List results) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureAiStudioEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureAiStudioEmbeddingsRequestManager.java index 53be191b2a7f6..698c2a90b71df 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureAiStudioEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureAiStudioEmbeddingsRequestManager.java @@ -49,7 +49,7 @@ public void execute( ActionListener listener ) { EmbeddingsInput input = EmbeddingsInput.of(inferenceInputs); - List docsInput = ChunkInferenceInput.convertToStrings(input.getInputs()); + List docsInput = input.getStringInputs(); InputType inputType = input.getInputType(); var truncatedInput = truncate(docsInput, model.getServiceSettings().maxInputTokens()); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/HuggingFaceRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/HuggingFaceRequestManager.java index 32246464d2f83..72042b184f635 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/HuggingFaceRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/HuggingFaceRequestManager.java @@ -10,7 +10,6 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.elasticsearch.action.ActionListener; -import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.inference.common.Truncator; @@ -61,7 +60,7 @@ public void execute( Supplier hasRequestCompletedFunction, ActionListener listener ) { - List docsInput = ChunkInferenceInput.convertToStrings(EmbeddingsInput.of(inferenceInputs).getInputs()); + List docsInput = EmbeddingsInput.of(inferenceInputs).getStringInputs(); var truncatedInput = truncate(docsInput, model.getTokenLimit()); 
var request = new HuggingFaceInferenceRequest(truncator, truncatedInput, model); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java index 54e27b93a6952..bdeb3a485cc24 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java @@ -82,7 +82,7 @@ private static InferenceInputs createInput( @Nullable String query, boolean stream ) { - List textInput = ChunkInferenceInput.convertToStrings(input); + List textInput = ChunkInferenceInput.asStrings(input); return switch (model.getTaskType()) { case COMPLETION, CHAT_COMPLETION -> new ChatCompletionInput(textInput, stream); case RERANK -> new QueryAndDocsInputs(query, textInput, stream); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java index 7cf8fad47d6d9..56912c2f9de1e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java @@ -1118,7 +1118,7 @@ private void executeRequest(int batchIndex, boolean maybeDeploy, Runnable runAft var inferenceRequest = buildInferenceRequest( esModel.mlNodeDeploymentId(), EmptyConfigUpdate.INSTANCE, - ChunkInferenceInput.convertToStrings(batch.batch().inputs()), + ChunkInferenceInput.asStrings(batch.batch().inputs()), inputType, timeout ); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java index 3eaacc8d25458..a55a00be1909e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java @@ -113,10 +113,7 @@ protected void doChunkedInfer( private static List translateToChunkedResults(EmbeddingsInput inputs, InferenceServiceResults inferenceResults) { if (inferenceResults instanceof TextEmbeddingFloatResults textEmbeddingResults) { - validateInputSizeAgainstEmbeddings( - ChunkInferenceInput.convertToStrings(inputs.getInputs()), - textEmbeddingResults.embeddings().size() - ); + validateInputSizeAgainstEmbeddings(ChunkInferenceInput.asStrings(inputs.getInputs()), textEmbeddingResults.embeddings().size()); var results = new ArrayList(inputs.getInputs().size()); @@ -134,7 +131,7 @@ private static List translateToChunkedResults(EmbeddingsInput } return results; } else if (inferenceResults instanceof SparseEmbeddingResults sparseEmbeddingResults) { - var inputsAsList = ChunkInferenceInput.convertToStrings(EmbeddingsInput.of(inputs).getInputs()); + var inputsAsList = ChunkInferenceInput.asStrings(EmbeddingsInput.of(inputs).getInputs()); return ChunkedInferenceEmbedding.listOf(inputsAsList, sparseEmbeddingResults); } else if (inferenceResults instanceof 
ErrorInferenceResults error) { return List.of(new ChunkedInferenceError(error.getException())); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java index 0c941ed3fdc64..9bca2b8d9cc19 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java @@ -138,12 +138,12 @@ public void testManyInputsMakeManyBatches() { assertEquals("input 0", batches.get(0).batch().inputs().get(0).input()); assertEquals("input 9", batches.get(0).batch().inputs().get(9).input()); assertThat( - ChunkInferenceInput.convertToStrings(batches.get(1).batch().inputs()), + ChunkInferenceInput.asStrings(batches.get(1).batch().inputs()), contains("input 10", "input 11", "input 12", "input 13", "input 14", "input 15", "input 16", "input 17", "input 18", "input 19") ); assertEquals("input 20", batches.get(2).batch().inputs().get(0).input()); assertEquals("input 29", batches.get(2).batch().inputs().get(9).input()); - assertThat(ChunkInferenceInput.convertToStrings(batches.get(3).batch().inputs()), contains("input 30")); + assertThat(ChunkInferenceInput.asStrings(batches.get(3).batch().inputs()), contains("input 30")); List requests = batches.get(0).batch().requests(); for (int i = 0; i < requests.size(); i++) { @@ -174,12 +174,12 @@ public void testChunkingSettingsProvided() { assertEquals("input 0", batches.get(0).batch().inputs().get(0).input()); assertEquals("input 9", batches.get(0).batch().inputs().get(9).input()); assertThat( - ChunkInferenceInput.convertToStrings(batches.get(1).batch().inputs()), + ChunkInferenceInput.asStrings(batches.get(1).batch().inputs()), contains("input 10", "input 11", "input 12", "input 13", "input 14", "input 15", "input 16", "input 17", "input 18", "input 19") ); assertEquals("input 20", batches.get(2).batch().inputs().get(0).input()); assertEquals("input 29", batches.get(2).batch().inputs().get(9).input()); - assertThat(ChunkInferenceInput.convertToStrings(batches.get(3).batch().inputs()), contains("input 30")); + assertThat(ChunkInferenceInput.asStrings(batches.get(3).batch().inputs()), contains("input 30")); List requests = batches.get(0).batch().requests(); for (int i = 0; i < requests.size(); i++) { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockMockRequestSender.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockMockRequestSender.java index db938af0c1912..f1b6aa4b8fc42 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockMockRequestSender.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockMockRequestSender.java @@ -84,7 +84,7 @@ public void send( ) { sendCounter++; if (inferenceInputs instanceof EmbeddingsInput docsInput) { - inputs.add(ChunkInferenceInput.convertToStrings(docsInput.getInputs())); + inputs.add(ChunkInferenceInput.asStrings(docsInput.getInputs())); if (docsInput.getInputType() != null) { inputTypes.add(docsInput.getInputType()); } diff --git 
a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/60_semantic_text_inference_update.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/60_semantic_text_inference_update.yml index 27c405f6c23bf..35e472e72b06d 100644 --- a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/60_semantic_text_inference_update.yml +++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/60_semantic_text_inference_update.yml @@ -79,7 +79,7 @@ setup: body: fields: [ _inference_fields ] query: - match_all: {} + match_all: { } - match: { hits.total.value: 1 } - match: { hits.total.relation: eq } @@ -104,7 +104,7 @@ setup: body: fields: [ _inference_fields ] query: - match_all: {} + match_all: { } - match: { hits.total.value: 1 } - match: { hits.total.relation: eq } @@ -129,7 +129,7 @@ setup: body: fields: [ _inference_fields ] query: - match_all: {} + match_all: { } - match: { hits.total.value: 1 } - match: { hits.total.relation: eq } @@ -152,7 +152,7 @@ setup: body: fields: [ _inference_fields ] query: - match_all: {} + match_all: { } # We can't directly check that the embeddings are different since there isn't a "does not match" assertion in the # YAML test framework. Check that the start and end offsets change as expected as a proxy. @@ -179,7 +179,7 @@ setup: body: fields: [ _inference_fields ] query: - match_all: {} + match_all: { } - match: { hits.total.value: 1 } - match: { hits.total.relation: eq } @@ -202,7 +202,7 @@ setup: body: fields: [ _inference_fields ] query: - match_all: {} + match_all: { } - match: { hits.total.value: 1 } - match: { hits.total.relation: eq } @@ -254,7 +254,7 @@ setup: body: fields: [ _inference_fields ] query: - match_all: {} + match_all: { } - match: { hits.total.value: 1 } - match: { hits.total.relation: eq } @@ -283,7 +283,7 @@ setup: body: fields: [ _inference_fields ] query: - match_all: {} + match_all: { } - match: { hits.total.value: 1 } - match: { hits.total.relation: eq } @@ -320,7 +320,7 @@ setup: body: fields: [ _inference_fields ] query: - match_all: {} + match_all: { } - match: { hits.total.value: 1 } - match: { hits.total.relation: eq } @@ -367,7 +367,7 @@ setup: index: test-index id: doc_1 body: - doc: { "sparse_field": [{"key": "value"}], "dense_field": [{"key": "value"}] } + doc: { "sparse_field": [ { "key": "value" } ], "dense_field": [ { "key": "value" } ] } - match: { error.type: "status_exception" } - match: { error.reason: "/Invalid\\ format\\ for\\ field\\ \\[(dense|sparse)_field\\].+/" } @@ -415,7 +415,7 @@ setup: body: fields: [ _inference_fields ] query: - match_all: {} + match_all: { } - match: { hits.total.value: 1 } - match: { hits.total.relation: eq } @@ -448,7 +448,7 @@ setup: body: fields: [ _inference_fields ] query: - match_all: {} + match_all: { } - match: { hits.total.value: 1 } - match: { hits.total.relation: eq } @@ -509,7 +509,7 @@ setup: body: fields: [ _inference_fields ] query: - match_all: {} + match_all: { } - match: { hits.total.value: 1 } - match: { hits.total.relation: eq } @@ -540,7 +540,7 @@ setup: body: fields: [ _inference_fields ] query: - match_all: {} + match_all: { } - match: { hits.total.value: 1 } - match: { hits.total.relation: eq } From 165c19eb3fafc66710f7232d614c04bcee4f0eec Mon Sep 17 00:00:00 2001 From: Kathleen DeRusso Date: Mon, 24 Mar 2025 09:46:50 -0400 Subject: [PATCH 62/86] A bit more cleanup --- .../xpack/inference/chunking/EmbeddingRequestChunker.java | 5 +---- 1 file changed, 1 
insertion(+), 4 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java index 256f3d260cd05..5d01f13619db1 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java @@ -54,10 +54,7 @@ public ChunkInferenceInput chunkInput() { public record BatchRequest(List requests) { public List inputs() { - // return requests.stream().map(Request::chunkText).collect(Collectors.toList()); - return requests.stream() - .map(r -> new ChunkInferenceInput(r.chunkText(), r.inputs().getFirst().chunkingSettings())) - .collect(Collectors.toList()); + return requests.stream().flatMap(r -> r.inputs().stream()).collect(Collectors.toList()); } } From 735982a28a6810ad933972ea56afdd26d25b05d3 Mon Sep 17 00:00:00 2001 From: Kathleen DeRusso Date: Mon, 24 Mar 2025 09:55:06 -0400 Subject: [PATCH 63/86] Spotless --- .../http/sender/AzureAiStudioEmbeddingsRequestManager.java | 1 - 1 file changed, 1 deletion(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureAiStudioEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureAiStudioEmbeddingsRequestManager.java index 698c2a90b71df..afef9ca87ad63 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureAiStudioEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureAiStudioEmbeddingsRequestManager.java @@ -10,7 +10,6 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.elasticsearch.action.ActionListener; -import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.threadpool.ThreadPool; From b2839fc7aa122e425df2c4ea57a5ef635e28c88f Mon Sep 17 00:00:00 2001 From: Kathleen DeRusso Date: Mon, 24 Mar 2025 10:51:05 -0400 Subject: [PATCH 64/86] Revert change --- .../xpack/inference/chunking/EmbeddingRequestChunker.java | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java index 5d01f13619db1..a3b6121f3f6ce 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java @@ -54,7 +54,9 @@ public ChunkInferenceInput chunkInput() { public record BatchRequest(List requests) { public List inputs() { - return requests.stream().flatMap(r -> r.inputs().stream()).collect(Collectors.toList()); + return requests.stream() + .map(r -> new ChunkInferenceInput(r.chunkText(), r.inputs().getFirst().chunkingSettings())) + .collect(Collectors.toList()); } } From ecc9bb3096daae723334297c7ed7e3769db57045 Mon Sep 17 00:00:00 2001 From: Kathleen DeRusso Date: Mon, 24 Mar 2025 17:05:02 -0400 Subject: 
[PATCH 65/86] Update chunking setting update logic --- .../mapper/SemanticTextFieldMapper.java | 22 +++++-------------- ...5_semantic_text_field_mapping_chunking.yml | 14 ------------ 2 files changed, 5 insertions(+), 31 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java index 19527f51dd22d..550ca6f6bde6c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java @@ -189,7 +189,7 @@ public static class Builder extends FieldMapper.Builder { mapper -> ((SemanticTextFieldType) mapper.fieldType()).chunkingSettings, XContentBuilder::field, Objects::toString - ).acceptsNull().setMergeValidator(SemanticTextFieldMapper::canMergeChunkingSettings); + ).acceptsNull(); private final Parameter> meta = Parameter.metaParam(); @@ -317,14 +317,13 @@ private void validateServiceSettings(MinimalServiceSettings settings) { * @return A mapper with the copied settings applied */ private SemanticTextFieldMapper copySettings(SemanticTextFieldMapper mapper, MapperMergeContext mapperMergeContext) { - SemanticTextFieldMapper returnedMapper = mapper; + SemanticTextFieldMapper returnedMapper; + Builder builder = from(mapper); if (mapper.fieldType().getModelSettings() == null) { - Builder builder = from(mapper); builder.setModelSettings(modelSettings.getValue()); - builder.setChunkingSettings(chunkingSettings.getValue()); - returnedMapper = builder.build(mapperMergeContext.getMapperBuilderContext()); } - + builder.setChunkingSettings(chunkingSettings.getValue()); + returnedMapper = builder.build(mapperMergeContext.getMapperBuilderContext()); return returnedMapper; } } @@ -1023,15 +1022,4 @@ private static boolean canMergeModelSettings(MinimalServiceSettings previous, Mi conflicts.addConflict("model_settings", ""); return false; } - - private static boolean canMergeChunkingSettings(ChunkingSettings previous, ChunkingSettings current, Conflicts conflicts) { - if (Objects.equals(previous, current)) { - return true; - } - if (previous == null || current == null) { - return true; - } - conflicts.addConflict("chunking_settings", ""); - return false; - } } diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/25_semantic_text_field_mapping_chunking.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/25_semantic_text_field_mapping_chunking.yml index 869331d97448a..996d8d4f08513 100644 --- a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/25_semantic_text_field_mapping_chunking.yml +++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/25_semantic_text_field_mapping_chunking.yml @@ -502,17 +502,3 @@ setup: - match: { "chunking-update.mappings.properties.inference_field.chunking_settings.max_chunk_size": 10 } - match: { "chunking-update.mappings.properties.inference_field.chunking_settings.overlap": 1 } - - do: - catch: /chunking_settings/ - indices.put_mapping: - index: chunking-update - body: - properties: - inference_field: - type: semantic_text - inference_id: sparse-inference-id - chunking_settings: - strategy: word - max_chunk_size: 20 - overlap: 5 - From 4c992d8465042b40b8dd82012f9b5810313623f1 Mon Sep 17 00:00:00 2001 From: 
Kathleen DeRusso Date: Mon, 24 Mar 2025 17:05:56 -0400 Subject: [PATCH 66/86] Go back to serializing maps --- .../action/filter/ShardBulkInferenceActionFilter.java | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java index dff611592dfae..40de650a2efd7 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java @@ -631,10 +631,12 @@ private void applyInferenceResponses(BulkItemRequest item, FieldInferenceRespons } indexRequest.source(newDocMap, indexRequest.getContentType()); } else { - try (XContentBuilder builder = XContentBuilder.builder(indexRequest.getContentType().xContent())) { - appendSourceAndInferenceMetadata(builder, indexRequest.source(), indexRequest.getContentType(), inferenceFieldsMap); - indexRequest.source(builder); - } + var newDocMap = indexRequest.sourceAsMap(); + newDocMap.put(InferenceMetadataFieldsMapper.NAME, inferenceFieldsMap); + // try (XContentBuilder builder = XContentBuilder.builder(indexRequest.getContentType().xContent())) { + // appendSourceAndInferenceMetadata(builder, indexRequest.source(), indexRequest.getContentType(), inferenceFieldsMap); + // indexRequest.source(builder); + // } } } } From c8071995e2c6f2298cabfaa54dfb9a808a600182 Mon Sep 17 00:00:00 2001 From: Kathleen DeRusso Date: Mon, 24 Mar 2025 17:18:49 -0400 Subject: [PATCH 67/86] Revert change to model settings - source still errors on missing model_id --- .../action/filter/ShardBulkInferenceActionFilter.java | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java index 40de650a2efd7..dff611592dfae 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java @@ -631,12 +631,10 @@ private void applyInferenceResponses(BulkItemRequest item, FieldInferenceRespons } indexRequest.source(newDocMap, indexRequest.getContentType()); } else { - var newDocMap = indexRequest.sourceAsMap(); - newDocMap.put(InferenceMetadataFieldsMapper.NAME, inferenceFieldsMap); - // try (XContentBuilder builder = XContentBuilder.builder(indexRequest.getContentType().xContent())) { - // appendSourceAndInferenceMetadata(builder, indexRequest.source(), indexRequest.getContentType(), inferenceFieldsMap); - // indexRequest.source(builder); - // } + try (XContentBuilder builder = XContentBuilder.builder(indexRequest.getContentType().xContent())) { + appendSourceAndInferenceMetadata(builder, indexRequest.source(), indexRequest.getContentType(), inferenceFieldsMap); + indexRequest.source(builder); + } } } } From c15228169031d478e025f887825618baf70161ca Mon Sep 17 00:00:00 2001 From: Kathleen DeRusso Date: Tue, 25 Mar 2025 13:59:37 -0400 Subject: [PATCH 68/86] Fix updating chunking settings --- 
.../xpack/inference/mapper/SemanticTextFieldMapper.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java index 550ca6f6bde6c..18b9cda8a131e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java @@ -322,7 +322,7 @@ private SemanticTextFieldMapper copySettings(SemanticTextFieldMapper mapper, Map if (mapper.fieldType().getModelSettings() == null) { builder.setModelSettings(modelSettings.getValue()); } - builder.setChunkingSettings(chunkingSettings.getValue()); + builder.setChunkingSettings(mapper.fieldType().getChunkingSettings()); returnedMapper = builder.build(mapperMergeContext.getMapperBuilderContext()); return returnedMapper; } From 9907a64dbaf5f0b28a08e2e174bdf07562c20ad1 Mon Sep 17 00:00:00 2001 From: Kathleen DeRusso Date: Wed, 26 Mar 2025 09:41:37 -0400 Subject: [PATCH 69/86] Look up model if null --- .../action/filter/ShardBulkInferenceActionFilter.java | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java index dff611592dfae..512b5baa11fb1 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java @@ -615,7 +615,9 @@ private void applyInferenceResponses(BulkItemRequest item, FieldInferenceRespons inputs, new SemanticTextField.InferenceResult( inferenceFieldMetadata.getInferenceId(), - model != null ? new MinimalServiceSettings(model) : null, + model != null + ? 
new MinimalServiceSettings(model) + : modelRegistry.getMinimalServiceSettings(inferenceFieldMetadata.getInferenceId()), ChunkingSettingsBuilder.fromMap(inferenceFieldMetadata.getChunkingSettings(), false), chunkMap ), From 688c6379b2f902a4a52a81ed4952259d10d78d13 Mon Sep 17 00:00:00 2001 From: Kathleen DeRusso Date: Wed, 26 Mar 2025 10:07:16 -0400 Subject: [PATCH 70/86] Fix test --- .../inference/mapper/SemanticTextFieldMapperTests.java | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java index a5871e46b1f52..29f1720c34ae3 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java @@ -64,6 +64,7 @@ import org.elasticsearch.search.NestedDocuments; import org.elasticsearch.search.SearchHit; import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentParseException; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xcontent.json.JsonXContent; import org.elasticsearch.xpack.core.XPackClientPlugin; @@ -893,12 +894,12 @@ public void testModelSettingsRequiredWithChunks() throws IOException { useLegacyFormat ); SourceToParse source = source(b -> addSemanticTextInferenceResults(useLegacyFormat, b, List.of(inferenceResults))); - DocumentParsingException ex = expectThrows( - DocumentParsingException.class, + Exception ex = expectThrows( DocumentParsingException.class, + XContentParseException.class, () -> mapperService.documentMapper().parse(source) ); - assertThat(ex.getMessage(), containsString("[model_settings] must be set for field [field] when chunks are provided")); + assertThat(ex.getCause().getMessage(), containsString("Required [model_settings]")); } private MapperService mapperServiceForFieldWithModelSettings(String fieldName, String inferenceId, MinimalServiceSettings modelSettings) From fa9247abe8897062d4065b3bcd8fc62c718cf7cb Mon Sep 17 00:00:00 2001 From: Kathleen DeRusso Date: Thu, 27 Mar 2025 10:53:27 -0400 Subject: [PATCH 71/86] Work around https://github.com/elastic/elasticsearch/issues/125723 in semantic text field serialization --- .../filter/ShardBulkInferenceActionFilter.java | 4 +--- .../xpack/inference/mapper/SemanticTextField.java | 13 +++++++------ .../mapper/SemanticTextFieldMapperTests.java | 7 +++---- 3 files changed, 11 insertions(+), 13 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java index 512b5baa11fb1..dff611592dfae 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java @@ -615,9 +615,7 @@ private void applyInferenceResponses(BulkItemRequest item, FieldInferenceRespons inputs, new SemanticTextField.InferenceResult( inferenceFieldMetadata.getInferenceId(), - model != null - ? 
new MinimalServiceSettings(model) - : modelRegistry.getMinimalServiceSettings(inferenceFieldMetadata.getInferenceId()), + model != null ? new MinimalServiceSettings(model) : null, ChunkingSettingsBuilder.fromMap(inferenceFieldMetadata.getChunkingSettings(), false), chunkMap ), diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java index bc6089905dd1d..1993d42fd0729 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java @@ -155,11 +155,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.startObject(INFERENCE_FIELD); builder.field(INFERENCE_ID_FIELD, inference.inferenceId); builder.field(MODEL_SETTINGS_FIELD, inference.modelSettings); - if (inference.chunkingSettings != null) { - builder.startObject(CHUNKING_SETTINGS_FIELD); - builder.mapContents(inference.chunkingSettings.asMap()); - builder.endObject(); - } + builder.field(CHUNKING_SETTINGS_FIELD, inference.chunkingSettings); if (useLegacyFormat) { builder.startArray(CHUNKS_FIELD); @@ -256,7 +252,12 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws null, new ParseField(MODEL_SETTINGS_FIELD) ); - INFERENCE_RESULT_PARSER.declareObject(optionalConstructorArg(), (p, c) -> p.map(), new ParseField(CHUNKING_SETTINGS_FIELD)); + INFERENCE_RESULT_PARSER.declareObjectOrNull( + optionalConstructorArg(), + (p, c) -> p.map(), + null, + new ParseField(CHUNKING_SETTINGS_FIELD) + ); INFERENCE_RESULT_PARSER.declareField(constructorArg(), (p, c) -> { if (c.useLegacyFormat()) { return Map.of(c.fieldName, parseChunksArrayLegacy(p, c)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java index 29f1720c34ae3..a5871e46b1f52 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java @@ -64,7 +64,6 @@ import org.elasticsearch.search.NestedDocuments; import org.elasticsearch.search.SearchHit; import org.elasticsearch.xcontent.XContentBuilder; -import org.elasticsearch.xcontent.XContentParseException; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xcontent.json.JsonXContent; import org.elasticsearch.xpack.core.XPackClientPlugin; @@ -894,12 +893,12 @@ public void testModelSettingsRequiredWithChunks() throws IOException { useLegacyFormat ); SourceToParse source = source(b -> addSemanticTextInferenceResults(useLegacyFormat, b, List.of(inferenceResults))); - Exception ex = expectThrows( + DocumentParsingException ex = expectThrows( + DocumentParsingException.class, DocumentParsingException.class, - XContentParseException.class, () -> mapperService.documentMapper().parse(source) ); - assertThat(ex.getCause().getMessage(), containsString("Required [model_settings]")); + assertThat(ex.getMessage(), containsString("[model_settings] must be set for field [field] when chunks are provided")); } private MapperService mapperServiceForFieldWithModelSettings(String fieldName, String 
inferenceId, MinimalServiceSettings modelSettings) From 1d8931c8eadd8e8c78ef2bb11964a4182ea0e327 Mon Sep 17 00:00:00 2001 From: Kathleen DeRusso Date: Thu, 27 Mar 2025 13:19:39 -0400 Subject: [PATCH 72/86] Add BWC tests --- ...mantic_text_field_mapping_chunking_bwc.yml | 530 ++++++++++++++++++ 1 file changed, 530 insertions(+) create mode 100644 x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/25_semantic_text_field_mapping_chunking_bwc.yml diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/25_semantic_text_field_mapping_chunking_bwc.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/25_semantic_text_field_mapping_chunking_bwc.yml new file mode 100644 index 0000000000000..07a2331edad17 --- /dev/null +++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/25_semantic_text_field_mapping_chunking_bwc.yml @@ -0,0 +1,530 @@ +setup: + - requires: + cluster_features: "semantic_text.support_chunking_config" + reason: semantic_text chunking configuration added in 8.19 + + - do: + inference.put: + task_type: text_embedding + inference_id: dense-inference-id + body: > + { + "service": "text_embedding_test_service", + "service_settings": { + "model": "my_model", + "dimensions": 10, + "similarity": "cosine", + "api_key": "abc64" + }, + "task_settings": { + } + } + + - do: + inference.put: + task_type: sparse_embedding + inference_id: sparse-inference-id + body: > + { + "service": "test_service", + "service_settings": { + "model": "my_model", + "api_key": "abc64" + }, + "task_settings": { + } + } + + - do: + indices.create: + index: default-chunking-sparse + body: + settings: + index.mapping.semantic_text.use_legacy_format: true + mappings: + properties: + keyword_field: + type: keyword + inference_field: + type: semantic_text + inference_id: sparse-inference-id + + - do: + indices.create: + index: default-chunking-dense + body: + settings: + index.mapping.semantic_text.use_legacy_format: true + mappings: + properties: + keyword_field: + type: keyword + inference_field: + type: semantic_text + inference_id: dense-inference-id + + - do: + indices.create: + index: custom-chunking-sparse + body: + settings: + index.mapping.semantic_text.use_legacy_format: true + mappings: + properties: + keyword_field: + type: keyword + inference_field: + type: semantic_text + inference_id: sparse-inference-id + chunking_settings: + strategy: word + max_chunk_size: 10 + overlap: 1 + + - do: + indices.create: + index: custom-chunking-dense + body: + settings: + index.mapping.semantic_text.use_legacy_format: true + mappings: + properties: + keyword_field: + type: keyword + inference_field: + type: semantic_text + inference_id: dense-inference-id + chunking_settings: + strategy: word + max_chunk_size: 10 + overlap: 1 + + - do: + index: + index: default-chunking-sparse + id: doc_1 + body: + keyword_field: "default sentence chunking" + inference_field: "Elasticsearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." + refresh: true + + - do: + index: + index: custom-chunking-sparse + id: doc_2 + body: + keyword_field: "custom word chunking" + inference_field: "Elasticsearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." 
+ refresh: true + + - do: + index: + index: default-chunking-dense + id: doc_3 + body: + keyword_field: "default sentence chunking" + inference_field: "Elasticsearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." + refresh: true + + - do: + index: + index: custom-chunking-dense + id: doc_4 + body: + keyword_field: "custom word chunking" + inference_field: "Elasticsearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." + refresh: true + +--- +"We return chunking configurations with mappings": + + - do: + indices.get_mapping: + index: default-chunking-sparse + + - not_exists: default-chunking-sparse.mappings.properties.inference_field.chunking_settings + + - do: + indices.get_mapping: + index: custom-chunking-sparse + + - match: { "custom-chunking-sparse.mappings.properties.inference_field.chunking_settings.strategy": "word" } + - match: { "custom-chunking-sparse.mappings.properties.inference_field.chunking_settings.max_chunk_size": 10 } + - match: { "custom-chunking-sparse.mappings.properties.inference_field.chunking_settings.overlap": 1 } + + - do: + indices.get_mapping: + index: default-chunking-dense + + - not_exists: default-chunking-dense.mappings.properties.inference_field.chunking_settings + + - do: + indices.get_mapping: + index: custom-chunking-dense + + - match: { "custom-chunking-dense.mappings.properties.inference_field.chunking_settings.strategy": "word" } + - match: { "custom-chunking-dense.mappings.properties.inference_field.chunking_settings.max_chunk_size": 10 } + - match: { "custom-chunking-dense.mappings.properties.inference_field.chunking_settings.overlap": 1 } + + +--- +"We do not set custom chunking settings for null or empty specified chunking settings": + + - do: + indices.create: + index: null-chunking + body: + settings: + index.mapping.semantic_text.use_legacy_format: true + mappings: + properties: + inference_field: + type: semantic_text + inference_id: dense-inference-id + chunking_settings: null + + - do: + indices.get_mapping: + index: null-chunking + + - not_exists: null-chunking.mappings.properties.inference_field.chunking_settings + + + - do: + indices.create: + index: empty-chunking + body: + settings: + index.mapping.semantic_text.use_legacy_format: true + mappings: + properties: + inference_field: + type: semantic_text + inference_id: sparse-inference-id + chunking_settings: { } + + - do: + indices.get_mapping: + index: empty-chunking + + - not_exists: empty-chunking.mappings.properties.inference_field.chunking_settings + +--- +"We return different chunks based on configured chunking overrides or model defaults for sparse embeddings": + + - do: + search: + index: default-chunking-sparse + body: + query: + semantic: + field: "inference_field" + query: "What is Elasticsearch?" + highlight: + fields: + inference_field: + type: "semantic" + number_of_fragments: 2 + + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_1" } + - length: { hits.hits.0.highlight.inference_field: 1 } + - match: { hits.hits.0.highlight.inference_field.0: "Elasticsearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } + + - do: + search: + index: custom-chunking-sparse + body: + query: + semantic: + field: "inference_field" + query: "What is Elasticsearch?" 
+ highlight: + fields: + inference_field: + type: "semantic" + number_of_fragments: 2 + + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_2" } + - length: { hits.hits.0.highlight.inference_field: 2 } + - match: { hits.hits.0.highlight.inference_field.0: "Elasticsearch is an open source, distributed, RESTful, search engine which" } + - match: { hits.hits.0.highlight.inference_field.1: " which is built on top of Lucene internally and enjoys" } + +--- +"We return different chunks based on configured chunking overrides or model defaults for dense embeddings": + + - do: + search: + index: default-chunking-dense + body: + query: + semantic: + field: "inference_field" + query: "What is Elasticsearch?" + highlight: + fields: + inference_field: + type: "semantic" + number_of_fragments: 2 + + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_3" } + - length: { hits.hits.0.highlight.inference_field: 1 } + - match: { hits.hits.0.highlight.inference_field.0: "Elasticsearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } + + - do: + search: + index: custom-chunking-dense + body: + query: + semantic: + field: "inference_field" + query: "What is Elasticsearch?" + highlight: + fields: + inference_field: + type: "semantic" + number_of_fragments: 2 + + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_4" } + - length: { hits.hits.0.highlight.inference_field: 2 } + - match: { hits.hits.0.highlight.inference_field.0: "Elasticsearch is an open source, distributed, RESTful, search engine which" } + - match: { hits.hits.0.highlight.inference_field.1: " which is built on top of Lucene internally and enjoys" } + +--- +"We respect multiple semantic_text fields with different chunking configurations": + + - do: + indices.create: + index: mixed-chunking + body: + settings: + index.mapping.semantic_text.use_legacy_format: true + mappings: + properties: + keyword_field: + type: keyword + default_chunked_inference_field: + type: semantic_text + inference_id: sparse-inference-id + customized_chunked_inference_field: + type: semantic_text + inference_id: sparse-inference-id + chunking_settings: + strategy: word + max_chunk_size: 10 + overlap: 1 + + - do: + index: + index: mixed-chunking + id: doc_1 + body: + default_chunked_inference_field: "Elasticsearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." + customized_chunked_inference_field: "Elasticsearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." + refresh: true + + - do: + search: + index: mixed-chunking + body: + query: + bool: + should: + - match: + default_chunked_inference_field: "What is Elasticsearch?" + - match: + customized_chunked_inference_field: "What is Elasticsearch?" + highlight: + fields: + default_chunked_inference_field: + type: "semantic" + number_of_fragments: 2 + customized_chunked_inference_field: + type: "semantic" + number_of_fragments: 2 + + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_1" } + - length: { hits.hits.0.highlight.default_chunked_inference_field: 1 } + - match: { hits.hits.0.highlight.default_chunked_inference_field.0: "Elasticsearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." 
} + - length: { hits.hits.0.highlight.customized_chunked_inference_field: 2 } + - match: { hits.hits.0.highlight.customized_chunked_inference_field.0: "Elasticsearch is an open source, distributed, RESTful, search engine which" } + - match: { hits.hits.0.highlight.customized_chunked_inference_field.1: " which is built on top of Lucene internally and enjoys" } + +--- +"Bulk requests are handled appropriately": + + - do: + indices.create: + index: index1 + body: + settings: + index.mapping.semantic_text.use_legacy_format: true + mappings: + properties: + keyword_field: + type: keyword + inference_field: + type: semantic_text + inference_id: sparse-inference-id + chunking_settings: + strategy: word + max_chunk_size: 10 + overlap: 1 + + - do: + indices.create: + index: index2 + body: + settings: + index.mapping.semantic_text.use_legacy_format: true + mappings: + properties: + keyword_field: + type: keyword + inference_field: + type: semantic_text + inference_id: sparse-inference-id + + - do: + bulk: + refresh: true + body: | + { "index": { "_index": "index1", "_id": "doc_1" }} + { "inference_field": "Elasticsearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } + { "index": { "_index": "index2", "_id": "doc_2" }} + { "inference_field": "Elasticsearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } + { "index": { "_index": "index1", "_id": "doc_3" }} + { "inference_field": "Elasticsearch is a free, open-source search engine and analytics tool that stores and indexes data." } + + - do: + search: + index: index1,index2 + body: + query: + semantic: + field: "inference_field" + query: "What is Elasticsearch?" + highlight: + fields: + inference_field: + type: "semantic" + number_of_fragments: 2 + + - match: { hits.total.value: 3 } + + - match: { hits.hits.0._id: "doc_3" } + - length: { hits.hits.0.highlight.inference_field: 2 } + - match: { hits.hits.0.highlight.inference_field.0: "Elasticsearch is a free, open-source search engine and analytics" } + - match: { hits.hits.0.highlight.inference_field.1: " analytics tool that stores and indexes data." } + + - match: { hits.hits.1._id: "doc_1" } + - length: { hits.hits.1.highlight.inference_field: 2 } + - match: { hits.hits.1.highlight.inference_field.0: "Elasticsearch is an open source, distributed, RESTful, search engine which" } + - match: { hits.hits.1.highlight.inference_field.1: " which is built on top of Lucene internally and enjoys" } + + - match: { hits.hits.2._id: "doc_2" } + - length: { hits.hits.2.highlight.inference_field: 1 } + - match: { hits.hits.2.highlight.inference_field.0: "Elasticsearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." 
} + + +--- +"Invalid chunking settings will result in an error": + + - do: + catch: /chunking settings can not have the following settings/ + indices.create: + index: invalid-chunking-extra-stuff + body: + settings: + index.mapping.semantic_text.use_legacy_format: true + mappings: + properties: + keyword_field: + type: keyword + inference_field: + type: semantic_text + inference_id: sparse-inference-id + chunking_settings: + strategy: word + max_chunk_size: 10 + overlap: 1 + extra: stuff + + - do: + catch: /\[chunking_settings\] does not contain the required setting \[max_chunk_size\]/ + indices.create: + index: invalid-chunking-missing-required-settings + body: + settings: + index.mapping.semantic_text.use_legacy_format: true + mappings: + properties: + keyword_field: + type: keyword + inference_field: + type: semantic_text + inference_id: sparse-inference-id + chunking_settings: + strategy: word + + - do: + catch: /Invalid chunkingStrategy/ + indices.create: + index: invalid-chunking-invalid-strategy + body: + settings: + index.mapping.semantic_text.use_legacy_format: true + mappings: + properties: + inference_field: + type: semantic_text + inference_id: sparse-inference-id + chunking_settings: + strategy: invalid + +--- +"We can update chunking settings": + + - do: + indices.create: + index: chunking-update + body: + settings: + index.mapping.semantic_text.use_legacy_format: true + mappings: + properties: + inference_field: + type: semantic_text + inference_id: sparse-inference-id + + - do: + indices.get_mapping: + index: chunking-update + + - not_exists: chunking-update.mappings.properties.inference_field.chunking_settings + + - do: + indices.put_mapping: + index: chunking-update + body: + properties: + inference_field: + type: semantic_text + inference_id: sparse-inference-id + chunking_settings: + strategy: word + max_chunk_size: 10 + overlap: 1 + + - do: + indices.get_mapping: + index: chunking-update + + - match: { "chunking-update.mappings.properties.inference_field.chunking_settings.strategy": "word" } + - match: { "chunking-update.mappings.properties.inference_field.chunking_settings.max_chunk_size": 10 } + - match: { "chunking-update.mappings.properties.inference_field.chunking_settings.overlap": 1 } + From ab7752e6d0fccc81c7a962fbcc1417b4d3db588f Mon Sep 17 00:00:00 2001 From: Kathleen DeRusso Date: Thu, 27 Mar 2025 13:34:39 -0400 Subject: [PATCH 73/86] Add chunking_settings to docs --- .../mapping-reference/semantic-text.md | 160 +++++++++++++----- 1 file changed, 119 insertions(+), 41 deletions(-) diff --git a/docs/reference/elasticsearch/mapping-reference/semantic-text.md b/docs/reference/elasticsearch/mapping-reference/semantic-text.md index abd44df2a139a..22a19c84e3cee 100644 --- a/docs/reference/elasticsearch/mapping-reference/semantic-text.md +++ b/docs/reference/elasticsearch/mapping-reference/semantic-text.md @@ -6,15 +6,32 @@ mapped_pages: # Semantic text field type [semantic-text] -The `semantic_text` field type automatically generates embeddings for text content using an inference endpoint. Long passages are [automatically chunked](#auto-text-chunking) to smaller sections to enable the processing of larger corpuses of text. - -The `semantic_text` field type specifies an inference endpoint identifier that will be used to generate embeddings. You can create the inference endpoint by using the [Create {{infer}} API](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-inference-put). 
This field type and the [`semantic` query](/reference/query-languages/query-dsl/query-dsl-semantic-query.md) type make it simpler to perform semantic search on your data. The `semantic_text` field type may also be queried with [match](/reference/query-languages/query-dsl/query-dsl-match-query.md), [sparse_vector](/reference/query-languages/query-dsl/query-dsl-sparse-vector-query.md) or [knn](/reference/query-languages/query-dsl/query-dsl-knn-query.md) queries. - -If you don’t specify an inference endpoint, the `inference_id` field defaults to `.elser-2-elasticsearch`, a preconfigured endpoint for the elasticsearch service. - -Using `semantic_text`, you won’t need to specify how to generate embeddings for your data, or how to index it. The {{infer}} endpoint automatically determines the embedding generation, indexing, and query to use. - -If you use the preconfigured `.elser-2-elasticsearch` endpoint, you can set up `semantic_text` with the following API request: +The `semantic_text` field type automatically generates embeddings for text +content using an inference endpoint. Long passages +are [automatically chunked](#auto-text-chunking) to smaller sections to enable +the processing of larger corpuses of text. + +The `semantic_text` field type specifies an inference endpoint identifier that +will be used to generate embeddings. You can create the inference endpoint by +using +the [Create {{infer}} API](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-inference-put). +This field type and the [ +`semantic` query](/reference/query-languages/query-dsl/query-dsl-semantic-query.md) +type make it simpler to perform semantic search on your data. The +`semantic_text` field type may also be queried +with [match](/reference/query-languages/query-dsl/query-dsl-match-query.md), [sparse_vector](/reference/query-languages/query-dsl/query-dsl-sparse-vector-query.md) +or [knn](/reference/query-languages/query-dsl/query-dsl-knn-query.md) queries. + +If you don’t specify an inference endpoint, the `inference_id` field defaults to +`.elser-2-elasticsearch`, a preconfigured endpoint for the elasticsearch +service. + +Using `semantic_text`, you won’t need to specify how to generate embeddings for +your data, or how to index it. The {{infer}} endpoint automatically determines +the embedding generation, indexing, and query to use. + +If you use the preconfigured `.elser-2-elasticsearch` endpoint, you can set up +`semantic_text` with the following API request: ```console PUT my-index-000001 @@ -29,7 +46,10 @@ PUT my-index-000001 } ``` -To use a custom {{infer}} endpoint instead of the default `.elser-2-elasticsearch`, you must [Create {{infer}} API](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-inference-put) and specify its `inference_id` when setting up the `semantic_text` field type. +To use a custom {{infer}} endpoint instead of the default +`.elser-2-elasticsearch`, you +must [Create {{infer}} API](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-inference-put) +and specify its `inference_id` when setting up the `semantic_text` field type. ```console PUT my-index-000002 @@ -47,8 +67,12 @@ PUT my-index-000002 1. The `inference_id` of the {{infer}} endpoint to use to generate embeddings. - -The recommended way to use `semantic_text` is by having dedicated {{infer}} endpoints for ingestion and search. This ensures that search speed remains unaffected by ingestion workloads, and vice versa. 
After creating dedicated {{infer}} endpoints for both, you can reference them using the `inference_id` and `search_inference_id` parameters when setting up the index mapping for an index that uses the `semantic_text` field. +The recommended way to use `semantic_text` is by having dedicated {{infer}} +endpoints for ingestion and search. This ensures that search speed remains +unaffected by ingestion workloads, and vice versa. After creating dedicated +{{infer}} endpoints for both, you can reference them using the `inference_id` +and `search_inference_id` parameters when setting up the index mapping for an +index that uses the `semantic_text` field. ```console PUT my-index-000003 @@ -65,40 +89,71 @@ PUT my-index-000003 } ``` - ## Parameters for `semantic_text` fields [semantic-text-params] `inference_id` -: (Required, string) {{infer-cap}} endpoint that will be used to generate embeddings for the field. By default, `.elser-2-elasticsearch` is used. This parameter cannot be updated. Use the [Create {{infer}} API](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-inference-put) to create the endpoint. If `search_inference_id` is specified, the {{infer}} endpoint will only be used at index time. +: (Required, string) {{infer-cap}} endpoint that will be used to generate +embeddings for the field. By default, `.elser-2-elasticsearch` is used. This +parameter cannot be updated. Use +the [Create {{infer}} API](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-inference-put) +to create the endpoint. If `search_inference_id` is specified, the {{infer}} +endpoint will only be used at index time. `search_inference_id` -: (Optional, string) {{infer-cap}} endpoint that will be used to generate embeddings at query time. You can update this parameter by using the [Update mapping API](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-indices-put-mapping). Use the [Create {{infer}} API](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-inference-put) to create the endpoint. If not specified, the {{infer}} endpoint defined by `inference_id` will be used at both index and query time. - +: (Optional, string) {{infer-cap}} endpoint that will be used to generate +embeddings at query time. You can update this parameter by using +the [Update mapping API](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-indices-put-mapping). +Use +the [Create {{infer}} API](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-inference-put) +to create the endpoint. If not specified, the {{infer}} endpoint defined by +`inference_id` will be used at both index and query time. + +`chunking_settings` +: (Optional, object) Sets chunking settings that will override the settings +configured by the `inference_id` endpoint. +See [chunking settings attributes](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-inference-put) +in the {{infer}} API documentation for a complete list of available options. ## {{infer-cap}} endpoint validation [infer-endpoint-validation] -The `inference_id` will not be validated when the mapping is created, but when documents are ingested into the index. When the first document is indexed, the `inference_id` will be used to generate underlying indexing structures for the field. +The `inference_id` will not be validated when the mapping is created, but when +documents are ingested into the index. 
When the first document is indexed, the +`inference_id` will be used to generate underlying indexing structures for the +field. ::::{warning} -Removing an {{infer}} endpoint will cause ingestion of documents and semantic queries to fail on indices that define `semantic_text` fields with that {{infer}} endpoint as their `inference_id`. Trying to [delete an {{infer}} endpoint](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-inference-delete) that is used on a `semantic_text` field will result in an error. +Removing an {{infer}} endpoint will cause ingestion of documents and semantic +queries to fail on indices that define `semantic_text` fields with that +{{infer}} endpoint as their `inference_id`. Trying +to [delete an {{infer}} endpoint](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-inference-delete) +that is used on a `semantic_text` field will result in an error. :::: - - ## Text chunking [auto-text-chunking] -{{infer-cap}} endpoints have a limit on the amount of text they can process. To allow for large amounts of text to be used in semantic search, `semantic_text` automatically generates smaller passages if needed, called *chunks*. +{{infer-cap}} endpoints have a limit on the amount of text they can process. To +allow for large amounts of text to be used in semantic search, `semantic_text` +automatically generates smaller passages if needed, called *chunks*. -Each chunk refers to a passage of the text and the corresponding embedding generated from it. When querying, the individual passages will be automatically searched for each document, and the most relevant passage will be used to compute a score. +Each chunk refers to a passage of the text and the corresponding embedding +generated from it. When querying, the individual passages will be automatically +searched for each document, and the most relevant passage will be used to +compute a score. -For more details on chunking and how to configure chunking settings, see [Configuring chunking](https://www.elastic.co/docs/api/doc/elasticsearch/group/endpoint-inference) in the Inference API documentation. - -Refer to [this tutorial](docs-content://solutions/search/semantic-search/semantic-search-semantic-text.md) to learn more about semantic search using `semantic_text` and the `semantic` query. +For more details on chunking and how to configure chunking settings, +see [Configuring chunking](https://www.elastic.co/docs/api/doc/elasticsearch/group/endpoint-inference) +in the Inference API documentation. +Refer +to [this tutorial](docs-content://solutions/search/semantic-search/semantic-search-semantic-text.md) +to learn more about semantic search using `semantic_text` and the `semantic` +query. ## Extracting Relevant Fragments from Semantic Text [semantic-text-highlighting] -You can extract the most relevant fragments from a semantic text field by using the [highlight parameter](/reference/elasticsearch/rest-apis/highlighting.md) in the [Search API](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-search). +You can extract the most relevant fragments from a semantic text field by using +the [highlight parameter](/reference/elasticsearch/rest-apis/highlighting.md) in +the [Search API](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-search). ```console POST test-index/_search @@ -120,10 +175,13 @@ POST test-index/_search ``` 1. Specifies the maximum number of fragments to return. -2. Sorts highlighted fragments by score when set to `score`. 
By default, fragments will be output in the order they appear in the field (order: none). - +2. Sorts highlighted fragments by score when set to `score`. By default, + fragments will be output in the order they appear in the field (order: none). -Highlighting is supported on fields other than semantic_text. However, if you want to restrict highlighting to the semantic highlighter and return no fragments when the field is not of type semantic_text, you can explicitly enforce the `semantic` highlighter in the query: +Highlighting is supported on fields other than semantic_text. However, if you +want to restrict highlighting to the semantic highlighter and return no +fragments when the field is not of type semantic_text, you can explicitly +enforce the `semantic` highlighter in the query: ```console PUT test-index @@ -147,23 +205,42 @@ PUT test-index 1. Ensures that highlighting is applied exclusively to semantic_text fields. - - ## Customizing `semantic_text` indexing [custom-indexing] -`semantic_text` uses defaults for indexing data based on the {{infer}} endpoint specified. It enables you to quickstart your semantic search by providing automatic {{infer}} and a dedicated query so you don’t need to provide further details. - -In case you want to customize data indexing, use the [`sparse_vector`](/reference/elasticsearch/mapping-reference/sparse-vector.md) or [`dense_vector`](/reference/elasticsearch/mapping-reference/dense-vector.md) field types and create an ingest pipeline with an [{{infer}} processor](/reference/enrich-processor/inference-processor.md) to generate the embeddings. [This tutorial](docs-content://solutions/search/semantic-search/semantic-search-inference.md) walks you through the process. In these cases - when you use `sparse_vector` or `dense_vector` field types instead of the `semantic_text` field type to customize indexing - using the [`semantic_query`](/reference/query-languages/query-dsl/query-dsl-semantic-query.md) is not supported for querying the field data. - +`semantic_text` uses defaults for indexing data based on the {{infer}} endpoint +specified. It enables you to quickstart your semantic search by providing +automatic {{infer}} and a dedicated query so you don’t need to provide further +details. + +In case you want to customize data indexing, use the [ +`sparse_vector`](/reference/elasticsearch/mapping-reference/sparse-vector.md) +or [`dense_vector`](/reference/elasticsearch/mapping-reference/dense-vector.md) +field types and create an ingest pipeline with +an [{{infer}} processor](/reference/enrich-processor/inference-processor.md) to +generate the +embeddings. [This tutorial](docs-content://solutions/search/semantic-search/semantic-search-inference.md) +walks you through the process. In these cases - when you use `sparse_vector` or +`dense_vector` field types instead of the `semantic_text` field type to +customize indexing - using the [ +`semantic_query`](/reference/query-languages/query-dsl/query-dsl-semantic-query.md) +is not supported for querying the field data. ## Updates to `semantic_text` fields [update-script] -Updates that use scripts are not supported for an index contains a `semantic_text` field. Even if the script targets non-`semantic_text` fields, the update will fail when the index contains a `semantic_text` field. - +Updates that use scripts are not supported for an index contains a +`semantic_text` field. Even if the script targets non-`semantic_text` fields, +the update will fail when the index contains a `semantic_text` field. 
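For example, a `semantic_text` field can override the chunking configuration of its {{infer}} endpoint directly in the mapping. A minimal sketch (the index name, `inference_id`, and chunking values below are illustrative; the `chunking_settings` options mirror the word-boundary settings exercised in this patch's YAML tests):

```console
PUT my-index-000004
{
  "mappings": {
    "properties": {
      "inference_field": {
        "type": "semantic_text",
        "inference_id": "my-elser-endpoint",
        "chunking_settings": {
          "strategy": "word",
          "max_chunk_size": 10,
          "overlap": 1
        }
      }
    }
  }
}
```

Smaller word-based chunks produce more, shorter passages (and therefore more highlighting fragments) than the endpoint's default sentence-based chunking.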
## `copy_to` and multi-fields support [copy-to-support] -The semantic_text field type can serve as the target of [copy_to fields](/reference/elasticsearch/mapping-reference/copy-to.md), be part of a [multi-field](/reference/elasticsearch/mapping-reference/multi-fields.md) structure, or contain [multi-fields](/reference/elasticsearch/mapping-reference/multi-fields.md) internally. This means you can use a single field to collect the values of other fields for semantic search. +The semantic_text field type can serve as the target +of [copy_to fields](/reference/elasticsearch/mapping-reference/copy-to.md), be +part of +a [multi-field](/reference/elasticsearch/mapping-reference/multi-fields.md) +structure, or +contain [multi-fields](/reference/elasticsearch/mapping-reference/multi-fields.md) +internally. This means you can use a single field to collect the values of other +fields for semantic search. For example, the following mapping: @@ -206,11 +283,12 @@ PUT test-index } ``` - ## Limitations [limitations] `semantic_text` field types have the following limitations: -* `semantic_text` fields are not currently supported as elements of [nested fields](/reference/elasticsearch/mapping-reference/nested.md). -* `semantic_text` fields can’t currently be set as part of [Dynamic templates](docs-content://manage-data/data-store/mapping/dynamic-templates.md). +* `semantic_text` fields are not currently supported as elements + of [nested fields](/reference/elasticsearch/mapping-reference/nested.md). +* `semantic_text` fields can’t currently be set as part + of [Dynamic templates](docs-content://manage-data/data-store/mapping/dynamic-templates.md). From cd4d32b441f6a04220ea100b661626660030250f Mon Sep 17 00:00:00 2001 From: Kathleen DeRusso Date: Fri, 28 Mar 2025 08:40:27 -0400 Subject: [PATCH 74/86] Refactor/rename --- .../org/elasticsearch/inference/ChunkInferenceInput.java | 2 +- .../xpack/inference/services/SenderService.java | 2 +- .../elasticsearch/ElasticsearchInternalService.java | 2 +- .../huggingface/elser/HuggingFaceElserService.java | 4 ++-- .../inference/chunking/EmbeddingRequestChunkerTests.java | 8 ++++---- .../amazonbedrock/AmazonBedrockMockRequestSender.java | 2 +- 6 files changed, 10 insertions(+), 10 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/inference/ChunkInferenceInput.java b/server/src/main/java/org/elasticsearch/inference/ChunkInferenceInput.java index bb29d58dc145c..8e25e0e55f08c 100644 --- a/server/src/main/java/org/elasticsearch/inference/ChunkInferenceInput.java +++ b/server/src/main/java/org/elasticsearch/inference/ChunkInferenceInput.java @@ -19,7 +19,7 @@ public ChunkInferenceInput(String input) { this(input, null); } - public static List asStrings(List chunkInferenceInputs) { + public static List inputs(List chunkInferenceInputs) { return chunkInferenceInputs.stream().map(ChunkInferenceInput::input).toList(); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java index 23fb5dc762a71..ff8ae6fd5aac3 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java @@ -86,7 +86,7 @@ private static InferenceInputs createInput( @Nullable Integer topN, boolean stream ) { - List textInput = ChunkInferenceInput.asStrings(input); + List textInput = 
ChunkInferenceInput.inputs(input); return switch (model.getTaskType()) { case COMPLETION, CHAT_COMPLETION -> new ChatCompletionInput(textInput, stream); case RERANK -> { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java index 7b20b972767ff..1fc647764790f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java @@ -1127,7 +1127,7 @@ private void executeRequest(int batchIndex, boolean maybeDeploy, Runnable runAft var inferenceRequest = buildInferenceRequest( esModel.mlNodeDeploymentId(), EmptyConfigUpdate.INSTANCE, - ChunkInferenceInput.asStrings(batch.batch().inputs()), + ChunkInferenceInput.inputs(batch.batch().inputs()), inputType, timeout ); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java index a55a00be1909e..8116eaf86e74a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java @@ -113,7 +113,7 @@ protected void doChunkedInfer( private static List translateToChunkedResults(EmbeddingsInput inputs, InferenceServiceResults inferenceResults) { if (inferenceResults instanceof TextEmbeddingFloatResults textEmbeddingResults) { - validateInputSizeAgainstEmbeddings(ChunkInferenceInput.asStrings(inputs.getInputs()), textEmbeddingResults.embeddings().size()); + validateInputSizeAgainstEmbeddings(ChunkInferenceInput.inputs(inputs.getInputs()), textEmbeddingResults.embeddings().size()); var results = new ArrayList(inputs.getInputs().size()); @@ -131,7 +131,7 @@ private static List translateToChunkedResults(EmbeddingsInput } return results; } else if (inferenceResults instanceof SparseEmbeddingResults sparseEmbeddingResults) { - var inputsAsList = ChunkInferenceInput.asStrings(EmbeddingsInput.of(inputs).getInputs()); + var inputsAsList = ChunkInferenceInput.inputs(EmbeddingsInput.of(inputs).getInputs()); return ChunkedInferenceEmbedding.listOf(inputsAsList, sparseEmbeddingResults); } else if (inferenceResults instanceof ErrorInferenceResults error) { return List.of(new ChunkedInferenceError(error.getException())); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java index 9bca2b8d9cc19..4809a41e84189 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java @@ -138,12 +138,12 @@ public void testManyInputsMakeManyBatches() { assertEquals("input 0", batches.get(0).batch().inputs().get(0).input()); assertEquals("input 9", 
batches.get(0).batch().inputs().get(9).input()); assertThat( - ChunkInferenceInput.asStrings(batches.get(1).batch().inputs()), + ChunkInferenceInput.inputs(batches.get(1).batch().inputs()), contains("input 10", "input 11", "input 12", "input 13", "input 14", "input 15", "input 16", "input 17", "input 18", "input 19") ); assertEquals("input 20", batches.get(2).batch().inputs().get(0).input()); assertEquals("input 29", batches.get(2).batch().inputs().get(9).input()); - assertThat(ChunkInferenceInput.asStrings(batches.get(3).batch().inputs()), contains("input 30")); + assertThat(ChunkInferenceInput.inputs(batches.get(3).batch().inputs()), contains("input 30")); List requests = batches.get(0).batch().requests(); for (int i = 0; i < requests.size(); i++) { @@ -174,12 +174,12 @@ public void testChunkingSettingsProvided() { assertEquals("input 0", batches.get(0).batch().inputs().get(0).input()); assertEquals("input 9", batches.get(0).batch().inputs().get(9).input()); assertThat( - ChunkInferenceInput.asStrings(batches.get(1).batch().inputs()), + ChunkInferenceInput.inputs(batches.get(1).batch().inputs()), contains("input 10", "input 11", "input 12", "input 13", "input 14", "input 15", "input 16", "input 17", "input 18", "input 19") ); assertEquals("input 20", batches.get(2).batch().inputs().get(0).input()); assertEquals("input 29", batches.get(2).batch().inputs().get(9).input()); - assertThat(ChunkInferenceInput.asStrings(batches.get(3).batch().inputs()), contains("input 30")); + assertThat(ChunkInferenceInput.inputs(batches.get(3).batch().inputs()), contains("input 30")); List requests = batches.get(0).batch().requests(); for (int i = 0; i < requests.size(); i++) { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockMockRequestSender.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockMockRequestSender.java index f1b6aa4b8fc42..48b04a38928ec 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockMockRequestSender.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockMockRequestSender.java @@ -84,7 +84,7 @@ public void send( ) { sendCounter++; if (inferenceInputs instanceof EmbeddingsInput docsInput) { - inputs.add(ChunkInferenceInput.asStrings(docsInput.getInputs())); + inputs.add(ChunkInferenceInput.inputs(docsInput.getInputs())); if (docsInput.getInputType() != null) { inputTypes.add(docsInput.getInputType()); } From 5db7ed48e7b9ab96fda7e706079948a42215eb73 Mon Sep 17 00:00:00 2001 From: Kathleen DeRusso Date: Fri, 28 Mar 2025 08:49:53 -0400 Subject: [PATCH 75/86] Address minor PR feedback --- .../inference/chunking/EmbeddingRequestChunker.java | 11 ++--------- .../inference/mapper/SemanticTextFieldMapper.java | 3 +-- 2 files changed, 3 insertions(+), 11 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java index a3b6121f3f6ce..f470ac444938b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java @@ -44,12 +44,6 @@ record Request(int inputIndex, int chunkIndex, 
ChunkOffset chunk, List requests) { @@ -62,8 +56,7 @@ public List inputs() { public record BatchRequestAndListener(BatchRequest batch, ActionListener listener) {} - private static final int DEFAULT_WORDS_PER_CHUNK = 250; - private static final int DEFAULT_CHUNK_OVERLAP = 100; + private static final ChunkingSettings DEFAULT_CHUNKING_SETTINGS = new WordBoundaryChunkingSettings(250, 100); // The maximum number of chunks that is stored for any input text. // If the configured chunker chunks the text into more chunks, each @@ -99,7 +92,7 @@ public EmbeddingRequestChunker( this.resultsErrors = new AtomicArray<>(inputs.size()); if (defaultChunkingSettings == null) { - defaultChunkingSettings = new WordBoundaryChunkingSettings(DEFAULT_WORDS_PER_CHUNK, DEFAULT_CHUNK_OVERLAP); + defaultChunkingSettings = DEFAULT_CHUNKING_SETTINGS; } List allRequests = new ArrayList<>(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java index 18b9cda8a131e..c2e93e0b24c52 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java @@ -323,8 +323,7 @@ private SemanticTextFieldMapper copySettings(SemanticTextFieldMapper mapper, Map builder.setModelSettings(modelSettings.getValue()); } builder.setChunkingSettings(mapper.fieldType().getChunkingSettings()); - returnedMapper = builder.build(mapperMergeContext.getMapperBuilderContext()); - return returnedMapper; + return builder.build(mapperMergeContext.getMapperBuilderContext()); } } From e4776390d89ab2aa4cec355c68b1acceb689794b Mon Sep 17 00:00:00 2001 From: Kathleen DeRusso Date: Fri, 28 Mar 2025 10:20:19 -0400 Subject: [PATCH 76/86] Add test case for null update --- .../25_semantic_text_field_mapping_chunking.yml | 16 ++++++++++++++++ ..._semantic_text_field_mapping_chunking_bwc.yml | 16 ++++++++++++++++ 2 files changed, 32 insertions(+) diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/25_semantic_text_field_mapping_chunking.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/25_semantic_text_field_mapping_chunking.yml index 996d8d4f08513..ce67a2340aa33 100644 --- a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/25_semantic_text_field_mapping_chunking.yml +++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/25_semantic_text_field_mapping_chunking.yml @@ -502,3 +502,19 @@ setup: - match: { "chunking-update.mappings.properties.inference_field.chunking_settings.max_chunk_size": 10 } - match: { "chunking-update.mappings.properties.inference_field.chunking_settings.overlap": 1 } + - do: + indices.put_mapping: + index: chunking-update + body: + properties: + inference_field: + type: semantic_text + inference_id: sparse-inference-id + chunking_settings: null + + - do: + indices.get_mapping: + index: chunking-update + + - not_exists: chunking-update.mappings.properties.inference_field.chunking_settings + diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/25_semantic_text_field_mapping_chunking_bwc.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/25_semantic_text_field_mapping_chunking_bwc.yml index 
07a2331edad17..6fc376eb1874c 100644 --- a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/25_semantic_text_field_mapping_chunking_bwc.yml +++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/25_semantic_text_field_mapping_chunking_bwc.yml @@ -528,3 +528,19 @@ setup: - match: { "chunking-update.mappings.properties.inference_field.chunking_settings.max_chunk_size": 10 } - match: { "chunking-update.mappings.properties.inference_field.chunking_settings.overlap": 1 } + - do: + indices.put_mapping: + index: chunking-update + body: + properties: + inference_field: + type: semantic_text + inference_id: sparse-inference-id + chunking_settings: null + + - do: + indices.get_mapping: + index: chunking-update + + - not_exists: chunking-update.mappings.properties.inference_field.chunking_settings + From 7d85fd3133fb55e2a7be932e354c74dbc298c19e Mon Sep 17 00:00:00 2001 From: Kathleen DeRusso Date: Fri, 28 Mar 2025 10:53:33 -0400 Subject: [PATCH 77/86] PR feedback - adjust refactor of chunked inputs --- .../chunking/EmbeddingRequestChunker.java | 3 +- .../external/http/sender/EmbeddingsInput.java | 4 ++ .../AlibabaCloudSearchService.java | 2 +- .../amazonbedrock/AmazonBedrockService.java | 2 +- .../azureaistudio/AzureAiStudioService.java | 2 +- .../azureopenai/AzureOpenAiService.java | 2 +- .../services/cohere/CohereService.java | 2 +- .../ElasticsearchInternalService.java | 2 +- .../googleaistudio/GoogleAiStudioService.java | 2 +- .../googlevertexai/GoogleVertexAiService.java | 2 +- .../huggingface/HuggingFaceService.java | 2 +- .../ibmwatsonx/IbmWatsonxService.java | 2 +- .../services/jinaai/JinaAIService.java | 2 +- .../services/mistral/MistralService.java | 2 +- .../services/openai/OpenAiService.java | 2 +- .../services/voyageai/VoyageAIService.java | 2 +- .../EmbeddingRequestChunkerTests.java | 38 +++++++++---------- 17 files changed, 39 insertions(+), 34 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java index f470ac444938b..5ac322c63f5f3 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java @@ -47,9 +47,10 @@ public String chunkText() { } public record BatchRequest(List requests) { - public List inputs() { + public List inputs() { return requests.stream() .map(r -> new ChunkInferenceInput(r.chunkText(), r.inputs().getFirst().chunkingSettings())) + .map(ChunkInferenceInput::input) .collect(Collectors.toList()); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/EmbeddingsInput.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/EmbeddingsInput.java index 42f1310e14140..ea90e23fe69b6 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/EmbeddingsInput.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/EmbeddingsInput.java @@ -44,6 +44,10 @@ public EmbeddingsInput(List input, @Nullable InputType inpu this.inputType = inputType; } + public static EmbeddingsInput fromStrings(List input, @Nullable InputType inputType) { + return new EmbeddingsInput(input, 
null, inputType); + } + public List getInputs() { return this.input; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java index bf1fbda2b826b..ebc3cdc9100ae 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java @@ -344,7 +344,7 @@ protected void doChunkedInfer( for (var request : batchedRequests) { var action = alibabaCloudSearchModel.accept(actionCreator, taskSettings); - action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener()); + action.execute(EmbeddingsInput.fromStrings(request.batch().inputs(), inputType), timeout, request.listener()); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java index c9ab8f287da38..eeb90aa80e923 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java @@ -170,7 +170,7 @@ protected void doChunkedInfer( for (var request : batchedRequests) { var action = baseAmazonBedrockModel.accept(actionCreator, taskSettings); - action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener()); + action.execute(EmbeddingsInput.fromStrings(request.batch().inputs(), inputType), timeout, request.listener()); } } else { listener.onFailure(createInvalidModelException(model)); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java index a25fb7a3e4d82..d2e077bfdfaf9 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java @@ -141,7 +141,7 @@ protected void doChunkedInfer( for (var request : batchedRequests) { var action = baseAzureAiStudioModel.accept(actionCreator, taskSettings); - action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener()); + action.execute(EmbeddingsInput.fromStrings(request.batch().inputs(), inputType), timeout, request.listener()); } } else { listener.onFailure(createInvalidModelException(model)); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java index 3843db7140b30..747480d2bd54a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java +++ 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java @@ -294,7 +294,7 @@ protected void doChunkedInfer( for (var request : batchedRequests) { var action = azureOpenAiModel.accept(actionCreator, taskSettings); - action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener()); + action.execute(EmbeddingsInput.fromStrings(request.batch().inputs(), inputType), timeout, request.listener()); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java index 9f1181a5d4382..2920513ef52bc 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java @@ -307,7 +307,7 @@ protected void doChunkedInfer( for (var request : batchedRequests) { var action = cohereModel.accept(actionCreator, taskSettings); - action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener()); + action.execute(EmbeddingsInput.fromStrings(request.batch().inputs(), inputType), timeout, request.listener()); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java index 1fc647764790f..d4539099345fc 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java @@ -1127,7 +1127,7 @@ private void executeRequest(int batchIndex, boolean maybeDeploy, Runnable runAft var inferenceRequest = buildInferenceRequest( esModel.mlNodeDeploymentId(), EmptyConfigUpdate.INSTANCE, - ChunkInferenceInput.inputs(batch.batch().inputs()), + batch.batch().inputs(), inputType, timeout ); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java index 1be510330fa85..9439b9c230b3b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java @@ -364,7 +364,7 @@ protected void doChunkedInfer( ).batchRequestsWithListeners(listener); for (var request : batchedRequests) { - doInfer(model, new EmbeddingsInput(request.batch().inputs(), inputType), taskSettings, timeout, request.listener()); + doInfer(model, EmbeddingsInput.fromStrings(request.batch().inputs(), inputType), taskSettings, timeout, request.listener()); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java index 726fb0b5da02f..831a06c9ab246 100644 --- 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java @@ -250,7 +250,7 @@ protected void doChunkedInfer( for (var request : batchedRequests) { var action = googleVertexAiModel.accept(actionCreator, taskSettings); - action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener()); + action.execute(EmbeddingsInput.fromStrings(request.batch().inputs(), inputType), timeout, request.listener()); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java index ca167a15a65f3..2c7bdba8de8a7 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java @@ -135,7 +135,7 @@ protected void doChunkedInfer( for (var request : batchedRequests) { var action = huggingFaceModel.accept(actionCreator); - action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener()); + action.execute(EmbeddingsInput.fromStrings(request.batch().inputs(), inputType), timeout, request.listener()); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java index 65a8ef469a2f7..c95cc6113749c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java @@ -319,7 +319,7 @@ protected void doChunkedInfer( ).batchRequestsWithListeners(listener); for (var request : batchedRequests) { var action = ibmWatsonxModel.accept(getActionCreator(getSender(), getServiceComponents()), taskSettings); - action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener()); + action.execute(EmbeddingsInput.fromStrings(request.batch().inputs(), inputType), timeout, request.listener()); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java index 4975d74e9984a..95b00c8804d35 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java @@ -288,7 +288,7 @@ protected void doChunkedInfer( for (var request : batchedRequests) { var action = jinaaiModel.accept(actionCreator, taskSettings); - action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener()); + action.execute(EmbeddingsInput.fromStrings(request.batch().inputs(), inputType), timeout, request.listener()); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java index 9ed0bb85cbeb8..cb94bfd7776e2 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java @@ -123,7 +123,7 @@ protected void doChunkedInfer( for (var request : batchedRequests) { var action = mistralEmbeddingsModel.accept(actionCreator, taskSettings); - action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener()); + action.execute(EmbeddingsInput.fromStrings(request.batch().inputs(), inputType), timeout, request.listener()); } } else { listener.onFailure(createInvalidModelException(model)); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java index 977a025eed2f6..1d615876f7386 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java @@ -348,7 +348,7 @@ protected void doChunkedInfer( for (var request : batchedRequests) { var action = openAiModel.accept(actionCreator, taskSettings); - action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener()); + action.execute(EmbeddingsInput.fromStrings(request.batch().inputs(), inputType), timeout, request.listener()); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java index 6b8af209ea4b1..fade09997aca5 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java @@ -308,7 +308,7 @@ protected void doChunkedInfer( for (var request : batchedRequests) { var action = voyageaiModel.accept(actionCreator, taskSettings); - action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener()); + action.execute(EmbeddingsInput.fromStrings(request.batch().inputs(), inputType), timeout, request.listener()); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java index 4809a41e84189..6667aa61a5802 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java @@ -54,7 +54,7 @@ public void testWhitespaceInput_SentenceChunker() { ).batchRequestsWithListeners(testListener()); assertThat(batches, hasSize(1)); assertThat(batches.get(0).batch().inputs(), hasSize(1)); - assertThat(batches.get(0).batch().inputs().get(0).input(), Matchers.is(" ")); + assertThat(batches.get(0).batch().inputs().get(0), Matchers.is(" ")); } public void testBlankInput_WordChunker() { @@ -63,7 +63,7 @@ public void 
testBlankInput_WordChunker() { ); assertThat(batches, hasSize(1)); assertThat(batches.get(0).batch().inputs(), hasSize(1)); - assertThat(batches.get(0).batch().inputs().get(0).input(), Matchers.is("")); + assertThat(batches.get(0).batch().inputs().get(0), Matchers.is("")); } public void testBlankInput_SentenceChunker() { @@ -71,7 +71,7 @@ public void testBlankInput_SentenceChunker() { .batchRequestsWithListeners(testListener()); assertThat(batches, hasSize(1)); assertThat(batches.get(0).batch().inputs(), hasSize(1)); - assertThat(batches.get(0).batch().inputs().get(0).input(), Matchers.is("")); + assertThat(batches.get(0).batch().inputs().get(0), Matchers.is("")); } public void testInputThatDoesNotChunk_WordChunker() { @@ -80,7 +80,7 @@ public void testInputThatDoesNotChunk_WordChunker() { ); assertThat(batches, hasSize(1)); assertThat(batches.get(0).batch().inputs(), hasSize(1)); - assertThat(batches.get(0).batch().inputs().get(0).input(), Matchers.is("ABBAABBA")); + assertThat(batches.get(0).batch().inputs().get(0), Matchers.is("ABBAABBA")); } public void testInputThatDoesNotChunk_SentenceChunker() { @@ -91,14 +91,14 @@ public void testInputThatDoesNotChunk_SentenceChunker() { ).batchRequestsWithListeners(testListener()); assertThat(batches, hasSize(1)); assertThat(batches.get(0).batch().inputs(), hasSize(1)); - assertThat(batches.get(0).batch().inputs().get(0).input(), Matchers.is("ABBAABBA")); + assertThat(batches.get(0).batch().inputs().get(0), Matchers.is("ABBAABBA")); } public void testShortInputsAreSingleBatch() { ChunkInferenceInput input = new ChunkInferenceInput("one chunk"); var batches = new EmbeddingRequestChunker<>(List.of(input), 100, 100, 10).batchRequestsWithListeners(testListener()); assertThat(batches, hasSize(1)); - assertThat(batches.get(0).batch().inputs(), contains(input)); + assertThat(batches.get(0).batch().inputs(), contains(input.input())); } public void testMultipleShortInputsAreSingleBatch() { @@ -110,7 +110,7 @@ public void testMultipleShortInputsAreSingleBatch() { var batches = new EmbeddingRequestChunker<>(inputs, 100, 100, 10).batchRequestsWithListeners(testListener()); assertThat(batches, hasSize(1)); EmbeddingRequestChunker.BatchRequest batch = batches.getFirst().batch(); - assertEquals(batch.inputs(), inputs); + assertEquals(batch.inputs(), ChunkInferenceInput.inputs(inputs)); for (int i = 0; i < inputs.size(); i++) { var request = batch.requests().get(i); assertThat(request.chunkText(), equalTo(inputs.get(i).input())); @@ -135,15 +135,15 @@ public void testManyInputsMakeManyBatches() { assertThat(batches.get(2).batch().inputs(), hasSize(maxNumInputsPerBatch)); assertThat(batches.get(3).batch().inputs(), hasSize(1)); - assertEquals("input 0", batches.get(0).batch().inputs().get(0).input()); - assertEquals("input 9", batches.get(0).batch().inputs().get(9).input()); + assertEquals("input 0", batches.get(0).batch().inputs().get(0)); + assertEquals("input 9", batches.get(0).batch().inputs().get(9)); assertThat( - ChunkInferenceInput.inputs(batches.get(1).batch().inputs()), + batches.get(1).batch().inputs(), contains("input 10", "input 11", "input 12", "input 13", "input 14", "input 15", "input 16", "input 17", "input 18", "input 19") ); - assertEquals("input 20", batches.get(2).batch().inputs().get(0).input()); - assertEquals("input 29", batches.get(2).batch().inputs().get(9).input()); - assertThat(ChunkInferenceInput.inputs(batches.get(3).batch().inputs()), contains("input 30")); + assertEquals("input 20", batches.get(2).batch().inputs().get(0)); + 
assertEquals("input 29", batches.get(2).batch().inputs().get(9)); + assertThat(batches.get(3).batch().inputs(), contains("input 30")); List requests = batches.get(0).batch().requests(); for (int i = 0; i < requests.size(); i++) { @@ -171,15 +171,15 @@ public void testChunkingSettingsProvided() { assertThat(batches.get(2).batch().inputs(), hasSize(maxNumInputsPerBatch)); assertThat(batches.get(3).batch().inputs(), hasSize(1)); - assertEquals("input 0", batches.get(0).batch().inputs().get(0).input()); - assertEquals("input 9", batches.get(0).batch().inputs().get(9).input()); + assertEquals("input 0", batches.get(0).batch().inputs().get(0)); + assertEquals("input 9", batches.get(0).batch().inputs().get(9)); assertThat( - ChunkInferenceInput.inputs(batches.get(1).batch().inputs()), + batches.get(1).batch().inputs(), contains("input 10", "input 11", "input 12", "input 13", "input 14", "input 15", "input 16", "input 17", "input 18", "input 19") ); - assertEquals("input 20", batches.get(2).batch().inputs().get(0).input()); - assertEquals("input 29", batches.get(2).batch().inputs().get(9).input()); - assertThat(ChunkInferenceInput.inputs(batches.get(3).batch().inputs()), contains("input 30")); + assertEquals("input 20", batches.get(2).batch().inputs().get(0)); + assertEquals("input 29", batches.get(2).batch().inputs().get(9)); + assertThat(batches.get(3).batch().inputs(), contains("input 30")); List requests = batches.get(0).batch().requests(); for (int i = 0; i < requests.size(); i++) { From 845a732a82d2c58f5406845f27cf4bbf410941c3 Mon Sep 17 00:00:00 2001 From: Kathleen DeRusso Date: Fri, 28 Mar 2025 13:09:01 -0400 Subject: [PATCH 78/86] Refactored AbstractTestInferenceService to return offsets instead of just Strings --- .../mock/AbstractTestInferenceService.java | 13 ++++++---- .../TestDenseInferenceServiceExtension.java | 26 +++++++++---------- .../TestSparseInferenceServiceExtension.java | 13 ++++------ 3 files changed, 25 insertions(+), 27 deletions(-) diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/AbstractTestInferenceService.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/AbstractTestInferenceService.java index b70e854cbf626..7d4a120668a8b 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/AbstractTestInferenceService.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/AbstractTestInferenceService.java @@ -37,6 +37,8 @@ public abstract class AbstractTestInferenceService implements InferenceService { + protected record ChunkedInput(String input, int startOffset, int endOffset) {} + protected static final Random random = new Random( System.getProperty("tests.seed") == null ? 
System.currentTimeMillis() @@ -112,23 +114,24 @@ public void start(Model model, TimeValue timeout, ActionListener listen @Override public void close() throws IOException {} - protected List chunkInputs(ChunkInferenceInput input) { + protected List chunkInputs(ChunkInferenceInput input) { ChunkingSettings chunkingSettings = input.chunkingSettings(); + String inputText = input.input(); if (chunkingSettings == null) { - return List.of(input.input()); + return List.of(new ChunkedInput(inputText, 0, inputText.length())); } - List chunkedInputs = new ArrayList<>(); + List chunkedInputs = new ArrayList<>(); if (chunkingSettings.getChunkingStrategy() == ChunkingStrategy.WORD) { WordBoundaryChunker chunker = new WordBoundaryChunker(); WordBoundaryChunkingSettings wordBoundaryChunkingSettings = (WordBoundaryChunkingSettings) chunkingSettings; List offsets = chunker.chunk( - input.input(), + inputText, wordBoundaryChunkingSettings.maxChunkSize(), wordBoundaryChunkingSettings.overlap() ); for (WordBoundaryChunker.ChunkOffset offset : offsets) { - chunkedInputs.add(input.input().substring(offset.start(), offset.end())); + chunkedInputs.add(new ChunkedInput(inputText.substring(offset.start(), offset.end()), offset.start(), offset.end())); } } else { diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java index e6cccfdd55546..5a27a7acfcde0 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java @@ -45,6 +45,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.stream.Collectors; public class TestDenseInferenceServiceExtension implements InferenceServiceExtension { @Override @@ -178,21 +179,18 @@ private TextEmbeddingFloatResults makeResults(List input, ServiceSetting private List makeChunkedResults(List inputs, ServiceSettings serviceSettings) { var results = new ArrayList(); - for (int i = 0; i < inputs.size(); i++) { - ChunkInferenceInput input = inputs.get(i); - List chunkedInput = chunkInputs(input); - List chunks = new ArrayList<>(); - for (String c : chunkedInput) { - // Note: We have to start with an offset of 0 to account for overlaps - int offset = input.input().indexOf(c); - int endOffset = offset + c.length(); - chunks.add( - new TextEmbeddingFloatResults.Chunk( - makeResults(List.of(c), serviceSettings).embeddings().get(0), - new ChunkedInference.TextOffset(offset, endOffset) + for (ChunkInferenceInput input : inputs) { + List chunkedInput = chunkInputs(input); + List chunks = new ArrayList<>( + chunkedInput.stream() + .map( + c -> new TextEmbeddingFloatResults.Chunk( + makeResults(List.of(c.input()), serviceSettings).embeddings().get(0), + new ChunkedInference.TextOffset(c.startOffset(), c.endOffset()) + ) ) - ); - } + .toList() + ); ChunkedInferenceEmbedding chunkedInferenceEmbedding = new ChunkedInferenceEmbedding(chunks); results.add(chunkedInferenceEmbedding); } diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java 
b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java index 09dd63978b349..6b4ebdd5c674b 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java @@ -43,6 +43,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.stream.Collectors; public class TestSparseInferenceServiceExtension implements InferenceServiceExtension { @@ -171,19 +172,15 @@ private List makeChunkedResults(List inpu List results = new ArrayList<>(); for (ChunkInferenceInput chunkInferenceInput : inputs) { String input = chunkInferenceInput.input(); - List chunkedInput = chunkInputs(chunkInferenceInput); - List chunks = new ArrayList<>(); - for (String c : chunkedInput) { + List chunkedInput = chunkInputs(chunkInferenceInput); + List chunks = new ArrayList<>(chunkedInput.stream().map(c -> { var tokens = new ArrayList(); for (int i = 0; i < 5; i++) { tokens.add(new WeightedToken("feature_" + i, generateEmbedding(input, i))); } - // Note: We have to start with an offset of 0 to account for overlaps - int offset = input.indexOf(c); - int endOffset = offset + c.length(); var embeddings = new SparseEmbeddingResults.Embedding(tokens, false); - chunks.add(new SparseEmbeddingResults.Chunk(embeddings, new ChunkedInference.TextOffset(offset, endOffset))); - } + return new SparseEmbeddingResults.Chunk(embeddings, new ChunkedInference.TextOffset(c.startOffset(), c.endOffset())); + }).toList()); ChunkedInferenceEmbedding chunkedInferenceEmbedding = new ChunkedInferenceEmbedding(chunks); results.add(chunkedInferenceEmbedding); } From 2ed86d45cd50380a36c1c214dfe982ff6a1fa2b2 Mon Sep 17 00:00:00 2001 From: elasticsearchmachine Date: Fri, 28 Mar 2025 17:16:50 +0000 Subject: [PATCH 79/86] [CI] Auto commit changes from spotless --- .../xpack/inference/mock/TestDenseInferenceServiceExtension.java | 1 - .../inference/mock/TestSparseInferenceServiceExtension.java | 1 - 2 files changed, 2 deletions(-) diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java index 5a27a7acfcde0..6c6db8616b248 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java @@ -45,7 +45,6 @@ import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.stream.Collectors; public class TestDenseInferenceServiceExtension implements InferenceServiceExtension { @Override diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java index 6b4ebdd5c674b..93e60e7f7db17 100644 --- 
a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java @@ -43,7 +43,6 @@ import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.stream.Collectors; public class TestSparseInferenceServiceExtension implements InferenceServiceExtension { From 0a54972a04b6517c0da56fa3bb239e0b45d3c6c3 Mon Sep 17 00:00:00 2001 From: Kathleen DeRusso Date: Fri, 28 Mar 2025 13:56:48 -0400 Subject: [PATCH 80/86] Fix tests where chunk output was of size 3 --- .../TestSparseInferenceServiceExtension.java | 2 +- ...5_semantic_text_field_mapping_chunking.yml | 19 ++++++++++-------- ...mantic_text_field_mapping_chunking_bwc.yml | 20 +++++++++++-------- 3 files changed, 24 insertions(+), 17 deletions(-) diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java index 93e60e7f7db17..c401217bbf2ec 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java @@ -175,7 +175,7 @@ private List makeChunkedResults(List inpu List chunks = new ArrayList<>(chunkedInput.stream().map(c -> { var tokens = new ArrayList(); for (int i = 0; i < 5; i++) { - tokens.add(new WeightedToken("feature_" + i, generateEmbedding(input, i))); + tokens.add(new WeightedToken("feature_" + i, generateEmbedding(c.input(), i))); } var embeddings = new SparseEmbeddingResults.Embedding(tokens, false); return new SparseEmbeddingResults.Chunk(embeddings, new ChunkedInference.TextOffset(c.startOffset(), c.endOffset())); diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/25_semantic_text_field_mapping_chunking.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/25_semantic_text_field_mapping_chunking.yml index ce67a2340aa33..fc52f0890c84a 100644 --- a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/25_semantic_text_field_mapping_chunking.yml +++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/25_semantic_text_field_mapping_chunking.yml @@ -212,7 +212,7 @@ setup: fields: inference_field: type: "semantic" - number_of_fragments: 2 + number_of_fragments: 3 - match: { hits.total.value: 1 } - match: { hits.hits.0._id: "doc_1" } @@ -231,13 +231,14 @@ setup: fields: inference_field: type: "semantic" - number_of_fragments: 2 + number_of_fragments: 3 - match: { hits.total.value: 1 } - match: { hits.hits.0._id: "doc_2" } - - length: { hits.hits.0.highlight.inference_field: 2 } + - length: { hits.hits.0.highlight.inference_field: 3 } - match: { hits.hits.0.highlight.inference_field.0: "Elasticsearch is an open source, distributed, RESTful, search engine which" } - match: { hits.hits.0.highlight.inference_field.1: " which is built on top of Lucene internally and enjoys" } + - match: { hits.hits.0.highlight.inference_field.2: " enjoys all the features it provides." 
} --- "We return different chunks based on configured chunking overrides or model defaults for dense embeddings": @@ -327,18 +328,19 @@ setup: fields: default_chunked_inference_field: type: "semantic" - number_of_fragments: 2 + number_of_fragments: 3 customized_chunked_inference_field: type: "semantic" - number_of_fragments: 2 + number_of_fragments: 3 - match: { hits.total.value: 1 } - match: { hits.hits.0._id: "doc_1" } - length: { hits.hits.0.highlight.default_chunked_inference_field: 1 } - match: { hits.hits.0.highlight.default_chunked_inference_field.0: "Elasticsearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } - - length: { hits.hits.0.highlight.customized_chunked_inference_field: 2 } + - length: { hits.hits.0.highlight.customized_chunked_inference_field: 3 } - match: { hits.hits.0.highlight.customized_chunked_inference_field.0: "Elasticsearch is an open source, distributed, RESTful, search engine which" } - match: { hits.hits.0.highlight.customized_chunked_inference_field.1: " which is built on top of Lucene internally and enjoys" } + - match: { hits.hits.0.highlight.customized_chunked_inference_field.2: " enjoys all the features it provides." } --- "Bulk requests are handled appropriately": @@ -394,7 +396,7 @@ setup: fields: inference_field: type: "semantic" - number_of_fragments: 2 + number_of_fragments: 3 - match: { hits.total.value: 3 } @@ -404,9 +406,10 @@ setup: - match: { hits.hits.0.highlight.inference_field.1: " analytics tool that stores and indexes data." } - match: { hits.hits.1._id: "doc_1" } - - length: { hits.hits.1.highlight.inference_field: 2 } + - length: { hits.hits.1.highlight.inference_field: 3 } - match: { hits.hits.1.highlight.inference_field.0: "Elasticsearch is an open source, distributed, RESTful, search engine which" } - match: { hits.hits.1.highlight.inference_field.1: " which is built on top of Lucene internally and enjoys" } + - match: { hits.hits.1.highlight.inference_field.2: " enjoys all the features it provides." 
} - match: { hits.hits.2._id: "doc_2" } - length: { hits.hits.2.highlight.inference_field: 1 } diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/25_semantic_text_field_mapping_chunking_bwc.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/25_semantic_text_field_mapping_chunking_bwc.yml index 6fc376eb1874c..50221151d77fb 100644 --- a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/25_semantic_text_field_mapping_chunking_bwc.yml +++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/25_semantic_text_field_mapping_chunking_bwc.yml @@ -224,7 +224,7 @@ setup: fields: inference_field: type: "semantic" - number_of_fragments: 2 + number_of_fragments: 3 - match: { hits.total.value: 1 } - match: { hits.hits.0._id: "doc_1" } @@ -243,13 +243,14 @@ setup: fields: inference_field: type: "semantic" - number_of_fragments: 2 + number_of_fragments: 3 - match: { hits.total.value: 1 } - match: { hits.hits.0._id: "doc_2" } - - length: { hits.hits.0.highlight.inference_field: 2 } + - length: { hits.hits.0.highlight.inference_field: 3 } - match: { hits.hits.0.highlight.inference_field.0: "Elasticsearch is an open source, distributed, RESTful, search engine which" } - match: { hits.hits.0.highlight.inference_field.1: " which is built on top of Lucene internally and enjoys" } + - match: { hits.hits.0.highlight.inference_field.2: " enjoys all the features it provides." } --- "We return different chunks based on configured chunking overrides or model defaults for dense embeddings": @@ -341,18 +342,20 @@ setup: fields: default_chunked_inference_field: type: "semantic" - number_of_fragments: 2 + number_of_fragments: 3 customized_chunked_inference_field: type: "semantic" - number_of_fragments: 2 + number_of_fragments: 3 - match: { hits.total.value: 1 } - match: { hits.hits.0._id: "doc_1" } - length: { hits.hits.0.highlight.default_chunked_inference_field: 1 } - match: { hits.hits.0.highlight.default_chunked_inference_field.0: "Elasticsearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } - - length: { hits.hits.0.highlight.customized_chunked_inference_field: 2 } + - length: { hits.hits.0.highlight.customized_chunked_inference_field: 3 } - match: { hits.hits.0.highlight.customized_chunked_inference_field.0: "Elasticsearch is an open source, distributed, RESTful, search engine which" } - match: { hits.hits.0.highlight.customized_chunked_inference_field.1: " which is built on top of Lucene internally and enjoys" } + - match: { hits.hits.0.highlight.customized_chunked_inference_field.2: " enjoys all the features it provides." } + --- "Bulk requests are handled appropriately": @@ -412,7 +415,7 @@ setup: fields: inference_field: type: "semantic" - number_of_fragments: 2 + number_of_fragments: 3 - match: { hits.total.value: 3 } @@ -422,9 +425,10 @@ setup: - match: { hits.hits.0.highlight.inference_field.1: " analytics tool that stores and indexes data." 
} - match: { hits.hits.1._id: "doc_1" } - - length: { hits.hits.1.highlight.inference_field: 2 } + - length: { hits.hits.1.highlight.inference_field: 3 } - match: { hits.hits.1.highlight.inference_field.0: "Elasticsearch is an open source, distributed, RESTful, search engine which" } - match: { hits.hits.1.highlight.inference_field.1: " which is built on top of Lucene internally and enjoys" } + - match: { hits.hits.1.highlight.inference_field.2: " enjoys all the features it provides." } - match: { hits.hits.2._id: "doc_2" } - length: { hits.hits.2.highlight.inference_field: 1 } From 8cf287b1f1450a2b81ddf8e8690ff9a4ad6e3dad Mon Sep 17 00:00:00 2001 From: Kathleen DeRusso Date: Fri, 28 Mar 2025 15:58:53 -0400 Subject: [PATCH 81/86] Update mappings per PR feedback --- .../inference/mapper/SemanticTextField.java | 8 +- .../mapper/SemanticTextFieldMapperTests.java | 88 ++++++++++++++++--- .../mapper/SemanticTextFieldTests.java | 10 ++- ...5_semantic_text_field_mapping_chunking.yml | 1 - ...mantic_text_field_mapping_chunking_bwc.yml | 1 - 5 files changed, 92 insertions(+), 16 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java index 1993d42fd0729..b6652e499b9fc 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java @@ -155,7 +155,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.startObject(INFERENCE_FIELD); builder.field(INFERENCE_ID_FIELD, inference.inferenceId); builder.field(MODEL_SETTINGS_FIELD, inference.modelSettings); - builder.field(CHUNKING_SETTINGS_FIELD, inference.chunkingSettings); + if (inference.chunkingSettings != null) { + builder.field(CHUNKING_SETTINGS_FIELD, inference.chunkingSettings); + } if (useLegacyFormat) { builder.startArray(CHUNKS_FIELD); @@ -247,7 +249,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws INFERENCE_RESULT_PARSER.declareString(constructorArg(), new ParseField(INFERENCE_ID_FIELD)); INFERENCE_RESULT_PARSER.declareObjectOrNull( - constructorArg(), + optionalConstructorArg(), (p, c) -> MinimalServiceSettings.parse(p), null, new ParseField(MODEL_SETTINGS_FIELD) @@ -332,7 +334,7 @@ public static List toSemanticTextFieldChunksLegacy(String input, ChunkedI return chunks; } - public static Chunk toSemanticTextFieldChunkLegacy(String input, ChunkedInference.Chunk chunk) { + public static Chunk toSemanticTextFieldChunkLegacy(String input, org.elasticsearch.inference.ChunkedInference.Chunk chunk) { var text = input.substring(chunk.textOffset().start(), chunk.textOffset().end()); return new Chunk(text, -1, -1, chunk.bytesReference()); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java index a5871e46b1f52..4d2a76f915af3 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java @@ -92,6 +92,7 @@ import static 
org.elasticsearch.xpack.inference.mapper.SemanticTextField.getEmbeddingsFieldName; import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper.DEFAULT_ELSER_2_INFERENCE_ID; import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.generateRandomChunkingSettings; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.generateRandomChunkingSettingsOtherThan; import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.randomSemanticText; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; @@ -568,6 +569,15 @@ public void testUpdateSearchInferenceId() throws IOException { } private static void assertSemanticTextField(MapperService mapperService, String fieldName, boolean expectedModelSettings) { + assertSemanticTextField(mapperService, fieldName, expectedModelSettings, null); + } + + private static void assertSemanticTextField( + MapperService mapperService, + String fieldName, + boolean expectedModelSettings, + ChunkingSettings expectedChunkingSettings + ) { Mapper mapper = mapperService.mappingLookup().getMapper(fieldName); assertNotNull(mapper); assertThat(mapper, instanceOf(SemanticTextFieldMapper.class)); @@ -619,6 +629,13 @@ private static void assertSemanticTextField(MapperService mapperService, String } else { assertNull(semanticFieldMapper.fieldType().getModelSettings()); } + + if (expectedChunkingSettings != null) { + assertNotNull(semanticFieldMapper.fieldType().getChunkingSettings()); + assertEquals(expectedChunkingSettings, semanticFieldMapper.fieldType().getChunkingSettings()); + } else { + assertNull(semanticFieldMapper.fieldType().getChunkingSettings()); + } } private static void assertInferenceEndpoints( @@ -646,8 +663,20 @@ public void testSuccessfulParse() throws IOException { Model model2 = TestModel.createRandomInstance(TaskType.SPARSE_EMBEDDING); ChunkingSettings chunkingSettings = null; // Some chunking settings configs can produce different Lucene docs counts XContentBuilder mapping = mapping(b -> { - addSemanticTextMapping(b, fieldName1, model1.getInferenceEntityId(), setSearchInferenceId ? searchInferenceId : null); - addSemanticTextMapping(b, fieldName2, model2.getInferenceEntityId(), setSearchInferenceId ? searchInferenceId : null); + addSemanticTextMapping( + b, + fieldName1, + model1.getInferenceEntityId(), + setSearchInferenceId ? searchInferenceId : null, + chunkingSettings + ); + addSemanticTextMapping( + b, + fieldName2, + model2.getInferenceEntityId(), + setSearchInferenceId ? 
searchInferenceId : null, + chunkingSettings + ); }); MapperService mapperService = createMapperService(mapping, useLegacyFormat); @@ -762,7 +791,7 @@ public void testSuccessfulParse() throws IOException { public void testMissingInferenceId() throws IOException { final MapperService mapperService = createMapperService( - mapping(b -> addSemanticTextMapping(b, "field", "my_id", null)), + mapping(b -> addSemanticTextMapping(b, "field", "my_id", null, null)), useLegacyFormat ); @@ -788,8 +817,11 @@ public void testMissingInferenceId() throws IOException { assertThat(ex.getCause().getMessage(), containsString("Required [inference_id]")); } - public void testMissingModelSettings() throws IOException { - MapperService mapperService = createMapperService(mapping(b -> addSemanticTextMapping(b, "field", "my_id", null)), useLegacyFormat); + public void testMissingModelSettingsAndChunks() throws IOException { + MapperService mapperService = createMapperService( + mapping(b -> addSemanticTextMapping(b, "field", "my_id", null, null)), + useLegacyFormat + ); IllegalArgumentException ex = expectThrows( DocumentParsingException.class, IllegalArgumentException.class, @@ -801,11 +833,15 @@ public void testMissingModelSettings() throws IOException { ) ) ); - assertThat(ex.getCause().getMessage(), containsString("Required [model_settings, chunks]")); + // Model settings may be null here so we only error on chunks + assertThat(ex.getCause().getMessage(), containsString("Required [chunks]")); } public void testMissingTaskType() throws IOException { - MapperService mapperService = createMapperService(mapping(b -> addSemanticTextMapping(b, "field", "my_id", null)), useLegacyFormat); + MapperService mapperService = createMapperService( + mapping(b -> addSemanticTextMapping(b, "field", "my_id", null, null)), + useLegacyFormat + ); IllegalArgumentException ex = expectThrows( DocumentParsingException.class, IllegalArgumentException.class, @@ -864,14 +900,40 @@ public void testDenseVectorElementType() throws IOException { assertMapperService.accept(byteMapperService, DenseVectorFieldMapper.ElementType.BYTE); } + public void testSettingAndUpdatingChunkingSettings() throws IOException { + Model model = TestModel.createRandomInstance(TaskType.SPARSE_EMBEDDING); + final ChunkingSettings chunkingSettings = generateRandomChunkingSettings(false); + String fieldName = "field"; + + SemanticTextField randomSemanticText = randomSemanticText( + useLegacyFormat, + fieldName, + model, + chunkingSettings, + List.of("a"), + XContentType.JSON + ); + + MapperService mapperService = createMapperService( + mapping(b -> addSemanticTextMapping(b, fieldName, model.getInferenceEntityId(), null, chunkingSettings)), + useLegacyFormat + ); + assertSemanticTextField(mapperService, fieldName, false, chunkingSettings); + + ChunkingSettings newChunkingSettings = generateRandomChunkingSettingsOtherThan(chunkingSettings); + merge(mapperService, mapping(b -> addSemanticTextMapping(b, fieldName, model.getInferenceEntityId(), null, newChunkingSettings))); + assertSemanticTextField(mapperService, fieldName, false, newChunkingSettings); + } + public void testModelSettingsRequiredWithChunks() throws IOException { // Create inference results where model settings are set to null and chunks are provided Model model = TestModel.createRandomInstance(TaskType.SPARSE_EMBEDDING); + ChunkingSettings chunkingSettings = generateRandomChunkingSettings(false); SemanticTextField randomSemanticText = randomSemanticText( useLegacyFormat, "field", model, - 
generateRandomChunkingSettings(), + chunkingSettings, List.of("a"), XContentType.JSON ); @@ -889,7 +951,7 @@ public void testModelSettingsRequiredWithChunks() throws IOException { ); MapperService mapperService = createMapperService( - mapping(b -> addSemanticTextMapping(b, "field", model.getInferenceEntityId(), null)), + mapping(b -> addSemanticTextMapping(b, "field", model.getInferenceEntityId(), null, chunkingSettings)), useLegacyFormat ); SourceToParse source = source(b -> addSemanticTextInferenceResults(useLegacyFormat, b, List.of(inferenceResults))); @@ -1000,7 +1062,8 @@ private static void addSemanticTextMapping( XContentBuilder mappingBuilder, String fieldName, String inferenceId, - String searchInferenceId + String searchInferenceId, + ChunkingSettings chunkingSettings ) throws IOException { mappingBuilder.startObject(fieldName); mappingBuilder.field("type", SemanticTextFieldMapper.CONTENT_TYPE); @@ -1008,6 +1071,11 @@ private static void addSemanticTextMapping( if (searchInferenceId != null) { mappingBuilder.field("search_inference_id", searchInferenceId); } + if (chunkingSettings != null) { + mappingBuilder.startObject("chunking_settings"); + mappingBuilder.mapContents(chunkingSettings.asMap()); + mappingBuilder.endObject(); + } mappingBuilder.endObject(); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java index 41df8f650cb1f..b4ac5c475d425 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java @@ -322,7 +322,11 @@ public static SemanticTextField semanticTextFieldFromChunkedInferenceResults( } public static ChunkingSettings generateRandomChunkingSettings() { - if (randomBoolean()) { + return generateRandomChunkingSettings(true); + } + + public static ChunkingSettings generateRandomChunkingSettings(boolean allowNull) { + if (allowNull && randomBoolean()) { return null; // Use model defaults } return randomBoolean() @@ -330,6 +334,10 @@ public static ChunkingSettings generateRandomChunkingSettings() { : new SentenceBoundaryChunkingSettings(randomIntBetween(20, 100), randomIntBetween(0, 1)); } + public static ChunkingSettings generateRandomChunkingSettingsOtherThan(ChunkingSettings chunkingSettings) { + return randomValueOtherThan(chunkingSettings, () -> generateRandomChunkingSettings(false)); + } + /** * Returns a randomly generated object for Semantic Text tests purpose. 
*/ diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/25_semantic_text_field_mapping_chunking.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/25_semantic_text_field_mapping_chunking.yml index fc52f0890c84a..0eb871313a1ce 100644 --- a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/25_semantic_text_field_mapping_chunking.yml +++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/25_semantic_text_field_mapping_chunking.yml @@ -513,7 +513,6 @@ setup: inference_field: type: semantic_text inference_id: sparse-inference-id - chunking_settings: null - do: indices.get_mapping: diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/25_semantic_text_field_mapping_chunking_bwc.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/25_semantic_text_field_mapping_chunking_bwc.yml index 50221151d77fb..596b99ef70305 100644 --- a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/25_semantic_text_field_mapping_chunking_bwc.yml +++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/25_semantic_text_field_mapping_chunking_bwc.yml @@ -540,7 +540,6 @@ setup: inference_field: type: semantic_text inference_id: sparse-inference-id - chunking_settings: null - do: indices.get_mapping: From a9c751260ac1de86fa96629e1f6e167921a50f55 Mon Sep 17 00:00:00 2001 From: Kathleen DeRusso Date: Tue, 1 Apr 2025 16:03:24 -0400 Subject: [PATCH 82/86] PR Feedback --- .../mock/TestDenseInferenceServiceExtension.java | 16 +++++++--------- .../TestSparseInferenceServiceExtension.java | 5 ++--- .../chunking/EmbeddingRequestChunker.java | 5 +---- .../mapper/SemanticTextFieldMapper.java | 9 +++++---- .../SingleInputSenderExecutableActionTests.java | 2 +- .../queries/SemanticQueryBuilderTests.java | 8 +------- .../25_semantic_text_field_mapping_chunking.yml | 5 +++-- ..._semantic_text_field_mapping_chunking_bwc.yml | 5 +++-- 8 files changed, 23 insertions(+), 32 deletions(-) diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java index 6c6db8616b248..044af0ab1d37d 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java @@ -180,16 +180,14 @@ private List makeChunkedResults(List inpu var results = new ArrayList(); for (ChunkInferenceInput input : inputs) { List chunkedInput = chunkInputs(input); - List chunks = new ArrayList<>( - chunkedInput.stream() - .map( - c -> new TextEmbeddingFloatResults.Chunk( - makeResults(List.of(c.input()), serviceSettings).embeddings().get(0), - new ChunkedInference.TextOffset(c.startOffset(), c.endOffset()) - ) + List chunks = chunkedInput.stream() + .map( + c -> new TextEmbeddingFloatResults.Chunk( + makeResults(List.of(c.input()), serviceSettings).embeddings().get(0), + new ChunkedInference.TextOffset(c.startOffset(), c.endOffset()) ) - .toList() - ); + ) + .toList(); ChunkedInferenceEmbedding chunkedInferenceEmbedding = new ChunkedInferenceEmbedding(chunks); 
results.add(chunkedInferenceEmbedding); } diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java index c401217bbf2ec..03c5c6201ce33 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java @@ -170,16 +170,15 @@ private SparseEmbeddingResults makeResults(List input) { private List makeChunkedResults(List inputs) { List results = new ArrayList<>(); for (ChunkInferenceInput chunkInferenceInput : inputs) { - String input = chunkInferenceInput.input(); List chunkedInput = chunkInputs(chunkInferenceInput); - List chunks = new ArrayList<>(chunkedInput.stream().map(c -> { + List chunks = chunkedInput.stream().map(c -> { var tokens = new ArrayList(); for (int i = 0; i < 5; i++) { tokens.add(new WeightedToken("feature_" + i, generateEmbedding(c.input(), i))); } var embeddings = new SparseEmbeddingResults.Embedding(tokens, false); return new SparseEmbeddingResults.Chunk(embeddings, new ChunkedInference.TextOffset(c.startOffset(), c.endOffset())); - }).toList()); + }).toList(); ChunkedInferenceEmbedding chunkedInferenceEmbedding = new ChunkedInferenceEmbedding(chunks); results.add(chunkedInferenceEmbedding); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java index 5ac322c63f5f3..3750bf8a1a950 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java @@ -48,10 +48,7 @@ public String chunkText() { public record BatchRequest(List requests) { public List inputs() { - return requests.stream() - .map(r -> new ChunkInferenceInput(r.chunkText(), r.inputs().getFirst().chunkingSettings())) - .map(ChunkInferenceInput::input) - .collect(Collectors.toList()); + return requests.stream().map(Request::chunkText).collect(Collectors.toList()); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java index c2e93e0b24c52..3a942a8e73537 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java @@ -317,13 +317,14 @@ private void validateServiceSettings(MinimalServiceSettings settings) { * @return A mapper with the copied settings applied */ private SemanticTextFieldMapper copySettings(SemanticTextFieldMapper mapper, MapperMergeContext mapperMergeContext) { - SemanticTextFieldMapper returnedMapper; - Builder builder = from(mapper); + SemanticTextFieldMapper returnedMapper = mapper; if (mapper.fieldType().getModelSettings() == null) { + Builder builder = from(mapper); builder.setModelSettings(modelSettings.getValue()); + returnedMapper = 
builder.build(mapperMergeContext.getMapperBuilderContext()); } - builder.setChunkingSettings(mapper.fieldType().getChunkingSettings()); - return builder.build(mapperMergeContext.getMapperBuilderContext()); + + return returnedMapper; } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/SingleInputSenderExecutableActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/SingleInputSenderExecutableActionTests.java index 1bead6b72ca92..440eef69ed7a7 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/SingleInputSenderExecutableActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/SingleInputSenderExecutableActionTests.java @@ -64,7 +64,7 @@ public void testOneInputIsValid() { public void testMoreThanOneInput() { var badInput = mock(EmbeddingsInput.class); - var input = List.of(new ChunkInferenceInput("one", null), new ChunkInferenceInput("two", null)); + var input = List.of(new ChunkInferenceInput("one"), new ChunkInferenceInput("two")); when(badInput.getInputs()).thenReturn(input); when(badInput.inputSize()).thenReturn(input.size()); var actualException = new AtomicReference(); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java index 53d6e58492b80..c4a6b92ac033c 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java @@ -60,7 +60,6 @@ import org.elasticsearch.xpack.core.ml.search.WeightedToken; import org.elasticsearch.xpack.inference.InferencePlugin; import org.elasticsearch.xpack.inference.mapper.SemanticTextField; -import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests; import org.junit.Before; import org.junit.BeforeClass; @@ -370,12 +369,7 @@ private static SourceToParse buildSemanticTextFieldWithInferenceResults( useLegacyFormat, SEMANTIC_TEXT_FIELD, null, - new SemanticTextField.InferenceResult( - INFERENCE_ID, - modelSettings, - SemanticTextFieldTests.generateRandomChunkingSettings(), - Map.of(SEMANTIC_TEXT_FIELD, List.of()) - ), + new SemanticTextField.InferenceResult(INFERENCE_ID, modelSettings, null, Map.of(SEMANTIC_TEXT_FIELD, List.of())), XContentType.JSON ); diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/25_semantic_text_field_mapping_chunking.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/25_semantic_text_field_mapping_chunking.yml index 0eb871313a1ce..a6ff307f0ef4a 100644 --- a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/25_semantic_text_field_mapping_chunking.yml +++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/25_semantic_text_field_mapping_chunking.yml @@ -274,13 +274,14 @@ setup: fields: inference_field: type: "semantic" - number_of_fragments: 2 + number_of_fragments: 3 - match: { hits.total.value: 1 } - match: { hits.hits.0._id: "doc_4" } - - length: { hits.hits.0.highlight.inference_field: 2 } + - length: { hits.hits.0.highlight.inference_field: 3 } - match: { hits.hits.0.highlight.inference_field.0: "Elasticsearch is an open source, 
distributed, RESTful, search engine which" } - match: { hits.hits.0.highlight.inference_field.1: " which is built on top of Lucene internally and enjoys" } + - match: { hits.hits.0.highlight.inference_field.2: " enjoys all the features it provides." } --- "We respect multiple semantic_text fields with different chunking configurations": diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/25_semantic_text_field_mapping_chunking_bwc.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/25_semantic_text_field_mapping_chunking_bwc.yml index 596b99ef70305..f189d5535bb77 100644 --- a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/25_semantic_text_field_mapping_chunking_bwc.yml +++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/25_semantic_text_field_mapping_chunking_bwc.yml @@ -286,13 +286,14 @@ setup: fields: inference_field: type: "semantic" - number_of_fragments: 2 + number_of_fragments: 3 - match: { hits.total.value: 1 } - match: { hits.hits.0._id: "doc_4" } - - length: { hits.hits.0.highlight.inference_field: 2 } + - length: { hits.hits.0.highlight.inference_field: 3 } - match: { hits.hits.0.highlight.inference_field.0: "Elasticsearch is an open source, distributed, RESTful, search engine which" } - match: { hits.hits.0.highlight.inference_field.1: " which is built on top of Lucene internally and enjoys" } + - match: { hits.hits.0.highlight.inference_field.2: " enjoys all the features it provides." } --- "We respect multiple semantic_text fields with different chunking configurations": From 546e3337cf779ce37e1b6b22339e5c9ab9bd30a6 Mon Sep 17 00:00:00 2001 From: Kathleen DeRusso Date: Thu, 3 Apr 2025 14:15:19 -0400 Subject: [PATCH 83/86] Fix problems related to merge --- .../alibabacloudsearch/AlibabaCloudSearchService.java | 2 +- .../services/amazonbedrock/AmazonBedrockService.java | 2 +- .../services/azureaistudio/AzureAiStudioService.java | 2 +- .../services/azureopenai/AzureOpenAiService.java | 2 +- .../xpack/inference/services/cohere/CohereService.java | 2 +- .../services/googleaistudio/GoogleAiStudioService.java | 8 +++++++- .../services/googlevertexai/GoogleVertexAiService.java | 2 +- .../services/huggingface/HuggingFaceService.java | 2 +- .../inference/services/ibmwatsonx/IbmWatsonxService.java | 2 +- .../xpack/inference/services/jinaai/JinaAIService.java | 2 +- .../xpack/inference/services/mistral/MistralService.java | 2 +- .../xpack/inference/services/openai/OpenAiService.java | 2 +- .../inference/services/voyageai/VoyageAIService.java | 2 +- 13 files changed, 19 insertions(+), 13 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java index 4312e43903595..ac0d0df06b48d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java @@ -344,7 +344,7 @@ protected void doChunkedInfer( for (var request : batchedRequests) { var action = alibabaCloudSearchModel.accept(actionCreator, taskSettings); - action.execute(EmbeddingsInput.fromStrings(request.batch().inputs(), inputType), timeout, 
request.listener()); + action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener()); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java index c27ee0a8dd0a3..38d8d61873ce5 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java @@ -170,7 +170,7 @@ protected void doChunkedInfer( for (var request : batchedRequests) { var action = baseAmazonBedrockModel.accept(actionCreator, taskSettings); - action.execute(EmbeddingsInput.fromStrings(request.batch().inputs(), inputType), timeout, request.listener()); + action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener()); } } else { listener.onFailure(createInvalidModelException(model)); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java index 8d808350d9d27..a70f44b91f9f8 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java @@ -141,7 +141,7 @@ protected void doChunkedInfer( for (var request : batchedRequests) { var action = baseAzureAiStudioModel.accept(actionCreator, taskSettings); - action.execute(EmbeddingsInput.fromStrings(request.batch().inputs(), inputType), timeout, request.listener()); + action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener()); } } else { listener.onFailure(createInvalidModelException(model)); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java index 828903e638bb1..03778e4471042 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java @@ -294,7 +294,7 @@ protected void doChunkedInfer( for (var request : batchedRequests) { var action = azureOpenAiModel.accept(actionCreator, taskSettings); - action.execute(EmbeddingsInput.fromStrings(request.batch().inputs(), inputType), timeout, request.listener()); + action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener()); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java index 246ea2e9a2676..66dc7a1de9a75 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java +++ 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java @@ -307,7 +307,7 @@ protected void doChunkedInfer( for (var request : batchedRequests) { var action = cohereModel.accept(actionCreator, taskSettings); - action.execute(EmbeddingsInput.fromStrings(request.batch().inputs(), inputType), timeout, request.listener()); + action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener()); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java index 9439b9c230b3b..87e0a1ba67a90 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java @@ -364,7 +364,13 @@ protected void doChunkedInfer( ).batchRequestsWithListeners(listener); for (var request : batchedRequests) { - doInfer(model, EmbeddingsInput.fromStrings(request.batch().inputs(), inputType), taskSettings, timeout, request.listener()); + doInfer( + model, + EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), + taskSettings, + timeout, + request.listener() + ); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java index 8d42021dedc71..8526e8abbad4d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java @@ -250,7 +250,7 @@ protected void doChunkedInfer( for (var request : batchedRequests) { var action = googleVertexAiModel.accept(actionCreator, taskSettings); - action.execute(EmbeddingsInput.fromStrings(request.batch().inputs(), inputType), timeout, request.listener()); + action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener()); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java index 9cb9ec5677a65..612748f6ede12 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java @@ -135,7 +135,7 @@ protected void doChunkedInfer( for (var request : batchedRequests) { var action = huggingFaceModel.accept(actionCreator); - action.execute(EmbeddingsInput.fromStrings(request.batch().inputs(), inputType), timeout, request.listener()); + action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener()); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java index c016c014c19df..c01d4d142fe16 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java @@ -319,7 +319,7 @@ protected void doChunkedInfer( ).batchRequestsWithListeners(listener); for (var request : batchedRequests) { var action = ibmWatsonxModel.accept(getActionCreator(getSender(), getServiceComponents()), taskSettings); - action.execute(EmbeddingsInput.fromStrings(request.batch().inputs(), inputType), timeout, request.listener()); + action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener()); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java index f1a97e31390f2..afd1d5db213bf 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java @@ -288,7 +288,7 @@ protected void doChunkedInfer( for (var request : batchedRequests) { var action = jinaaiModel.accept(actionCreator, taskSettings); - action.execute(EmbeddingsInput.fromStrings(request.batch().inputs(), inputType), timeout, request.listener()); + action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener()); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java index 4dd84a9f15708..5c6488bfbbda2 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java @@ -123,7 +123,7 @@ protected void doChunkedInfer( for (var request : batchedRequests) { var action = mistralEmbeddingsModel.accept(actionCreator, taskSettings); - action.execute(EmbeddingsInput.fromStrings(request.batch().inputs(), inputType), timeout, request.listener()); + action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener()); } } else { listener.onFailure(createInvalidModelException(model)); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java index ca0458ef2ecca..094b6b27e158b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java @@ -347,7 +347,7 @@ protected void doChunkedInfer( for (var request : batchedRequests) { var action = openAiModel.accept(actionCreator, taskSettings); - action.execute(EmbeddingsInput.fromStrings(request.batch().inputs(), inputType), timeout, request.listener()); + 
action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener()); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java index f135da2d7f21d..229266a5e51ed 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java @@ -308,7 +308,7 @@ protected void doChunkedInfer( for (var request : batchedRequests) { var action = voyageaiModel.accept(actionCreator, taskSettings); - action.execute(EmbeddingsInput.fromStrings(request.batch().inputs(), inputType), timeout, request.listener()); + action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener()); } } From 612d2fba65daf6f9b9f57a451a67183766cb0082 Mon Sep 17 00:00:00 2001 From: Kathleen DeRusso Date: Thu, 3 Apr 2025 14:23:12 -0400 Subject: [PATCH 84/86] PR optimization --- .../inference/chunking/EmbeddingRequestChunker.java | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java index 0fe0dc31e2cc0..818f901e19d85 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java @@ -13,6 +13,7 @@ import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; +import org.elasticsearch.inference.ChunkingStrategy; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding; @@ -22,6 +23,8 @@ import java.util.ArrayList; import java.util.List; +import java.util.Map; +import java.util.Objects; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReferenceArray; import java.util.function.Supplier; @@ -94,13 +97,20 @@ public EmbeddingRequestChunker( defaultChunkingSettings = DEFAULT_CHUNKING_SETTINGS; } + Map chunkers = inputs.stream() + .map(ChunkInferenceInput::chunkingSettings) + .filter(Objects::nonNull) + .map(ChunkingSettings::getChunkingStrategy) + .distinct() + .collect(Collectors.toMap(chunkingStrategy -> chunkingStrategy, ChunkerBuilder::fromChunkingStrategy)); + List allRequests = new ArrayList<>(); for (int inputIndex = 0; inputIndex < inputs.size(); inputIndex++) { ChunkingSettings chunkingSettings = inputs.get(inputIndex).chunkingSettings(); if (chunkingSettings == null) { chunkingSettings = defaultChunkingSettings; } - Chunker chunker = ChunkerBuilder.fromChunkingStrategy(chunkingSettings.getChunkingStrategy()); + Chunker chunker = chunkers.get(chunkingSettings.getChunkingStrategy()); List chunks = chunker.chunk(inputs.get(inputIndex).input(), chunkingSettings); int resultCount = Math.min(chunks.size(), MAX_CHUNKS); resultEmbeddings.add(new AtomicReferenceArray<>(resultCount)); From 
9fb17a675d15e94139cf34789b31cb5eba9be364 Mon Sep 17 00:00:00 2001 From: Kathleen DeRusso Date: Thu, 3 Apr 2025 14:48:08 -0400 Subject: [PATCH 85/86] Fix test --- .../search/rank/feature/RerankSnippet.java | 30 +++++++++++++++++++ .../chunking/EmbeddingRequestChunker.java | 3 +- 2 files changed, 32 insertions(+), 1 deletion(-) create mode 100644 server/src/main/java/org/elasticsearch/search/rank/feature/RerankSnippet.java diff --git a/server/src/main/java/org/elasticsearch/search/rank/feature/RerankSnippet.java b/server/src/main/java/org/elasticsearch/search/rank/feature/RerankSnippet.java new file mode 100644 index 0000000000000..416a74ba9bd2d --- /dev/null +++ b/server/src/main/java/org/elasticsearch/search/rank/feature/RerankSnippet.java @@ -0,0 +1,30 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.search.rank.feature; + +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; + +import java.io.IOException; + +public record RerankSnippet(String snippet, float score) implements Writeable { + + public RerankSnippet(StreamInput in) throws IOException { + this(in.readString(), in.readFloat()); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(snippet); + out.writeFloat(score); + } + +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java index 818f901e19d85..2df2f1e62f89a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java @@ -103,6 +103,7 @@ public EmbeddingRequestChunker( .map(ChunkingSettings::getChunkingStrategy) .distinct() .collect(Collectors.toMap(chunkingStrategy -> chunkingStrategy, ChunkerBuilder::fromChunkingStrategy)); + Chunker defaultChunker = ChunkerBuilder.fromChunkingStrategy(defaultChunkingSettings.getChunkingStrategy()); List allRequests = new ArrayList<>(); for (int inputIndex = 0; inputIndex < inputs.size(); inputIndex++) { @@ -110,7 +111,7 @@ public EmbeddingRequestChunker( if (chunkingSettings == null) { chunkingSettings = defaultChunkingSettings; } - Chunker chunker = chunkers.get(chunkingSettings.getChunkingStrategy()); + Chunker chunker = chunkers.getOrDefault(chunkingSettings.getChunkingStrategy(), defaultChunker); List chunks = chunker.chunk(inputs.get(inputIndex).input(), chunkingSettings); int resultCount = Math.min(chunks.size(), MAX_CHUNKS); resultEmbeddings.add(new AtomicReferenceArray<>(resultCount)); From bf8c46066268fdda482f132ea4b376166d256d74 Mon Sep 17 00:00:00 2001 From: Kathleen DeRusso Date: Thu, 3 Apr 2025 14:50:46 -0400 Subject: [PATCH 86/86] Delete extra file --- .../search/rank/feature/RerankSnippet.java | 30 ------------------- 1 file changed, 30 deletions(-) delete mode 100644 
server/src/main/java/org/elasticsearch/search/rank/feature/RerankSnippet.java diff --git a/server/src/main/java/org/elasticsearch/search/rank/feature/RerankSnippet.java b/server/src/main/java/org/elasticsearch/search/rank/feature/RerankSnippet.java deleted file mode 100644 index 416a74ba9bd2d..0000000000000 --- a/server/src/main/java/org/elasticsearch/search/rank/feature/RerankSnippet.java +++ /dev/null @@ -1,30 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the "Elastic License - * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side - * Public License v 1"; you may not use this file except in compliance with, at - * your election, the "Elastic License 2.0", the "GNU Affero General Public - * License v3.0 only", or the "Server Side Public License, v 1". - */ - -package org.elasticsearch.search.rank.feature; - -import org.elasticsearch.common.io.stream.StreamInput; -import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.common.io.stream.Writeable; - -import java.io.IOException; - -public record RerankSnippet(String snippet, float score) implements Writeable { - - public RerankSnippet(StreamInput in) throws IOException { - this(in.readString(), in.readFloat()); - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - out.writeString(snippet); - out.writeFloat(score); - } - -}
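Note on the EmbeddingRequestChunker changes in the last few patches: the "PR optimization" commit memoizes one Chunker per distinct ChunkingStrategy found in the request inputs instead of constructing a fresh chunker for every input, and the follow-up "Fix test" commit adds a getOrDefault fallback so inputs that rely on the service's default chunking settings still resolve to a chunker even when their strategy never appears in the memoized map (a plain chunkers.get(...) would have returned null for them). The standalone Java sketch below illustrates that pattern only; the types in it (ChunkingStrategy, ChunkingSettings, Chunker, ChunkInferenceInput, fromChunkingStrategy) are simplified stand-ins invented for the example and are not the real Elasticsearch APIs.

// Illustrative sketch of per-strategy chunker memoization with a default fallback.
// All types here are simplified stand-ins, not the actual Elasticsearch classes.
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.Function;
import java.util.stream.Collectors;

public class ChunkerCacheSketch {

    enum ChunkingStrategy { WORD, SENTENCE }

    record ChunkingSettings(ChunkingStrategy strategy) {}

    record ChunkInferenceInput(String input, ChunkingSettings chunkingSettings) {}

    interface Chunker {
        List<String> chunk(String input, ChunkingSettings settings);
    }

    // Stand-in for ChunkerBuilder.fromChunkingStrategy; pretend construction is worth avoiding per input.
    static Chunker fromChunkingStrategy(ChunkingStrategy strategy) {
        return (input, settings) -> List.of(input); // trivial "chunking" for the sketch
    }

    static void chunkAll(List<ChunkInferenceInput> inputs, ChunkingSettings defaultSettings) {
        // Build one chunker per distinct strategy present in the inputs, rather than one per input.
        Map<ChunkingStrategy, Chunker> chunkers = inputs.stream()
            .map(ChunkInferenceInput::chunkingSettings)
            .filter(Objects::nonNull)
            .map(ChunkingSettings::strategy)
            .distinct()
            .collect(Collectors.toMap(Function.identity(), ChunkerCacheSketch::fromChunkingStrategy));

        // Fallback for inputs without explicit settings, or whose strategy is not in the map.
        Chunker defaultChunker = fromChunkingStrategy(defaultSettings.strategy());

        for (ChunkInferenceInput input : inputs) {
            ChunkingSettings settings = input.chunkingSettings() != null ? input.chunkingSettings() : defaultSettings;
            Chunker chunker = chunkers.getOrDefault(settings.strategy(), defaultChunker);
            List<String> chunks = chunker.chunk(input.input(), settings);
            System.out.println(input.input() + " -> " + chunks.size() + " chunk(s)");
        }
    }

    public static void main(String[] args) {
        chunkAll(
            List.of(
                new ChunkInferenceInput("first text", new ChunkingSettings(ChunkingStrategy.WORD)),
                new ChunkInferenceInput("second text", null) // uses the default chunker
            ),
            new ChunkingSettings(ChunkingStrategy.SENTENCE)
        );
    }
}

The design choice presumably works because chunker instances are reusable across all inputs that share a strategy, so the per-request cost drops from one chunker per input to at most one per strategy, while the default-chunker fallback keeps behavior correct for inputs that carry no explicit chunking settings.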