diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InputTextReader.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InputTextReader.java new file mode 100644 index 0000000000000..d7bf43cdfbdd5 --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InputTextReader.java @@ -0,0 +1,76 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.inference; + +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.compute.data.BytesRefBlock; +import org.elasticsearch.core.Releasable; +import org.elasticsearch.core.Releasables; + +/** + * Helper class that reads text strings from a {@link BytesRefBlock}. + * This class is used by inference operators to extract text content from block data. + */ +public class InputTextReader implements Releasable { + private final BytesRefBlock textBlock; + private final StringBuilder strBuilder = new StringBuilder(); + private BytesRef readBuffer = new BytesRef(); + + public InputTextReader(BytesRefBlock textBlock) { + this.textBlock = textBlock; + } + + /** + * Reads the text string at the given position. + * Multiple values at the position are concatenated with newlines. + * + * @param pos the position index in the block + * @return the text string at the position, or null if the position contains a null value + */ + public String readText(int pos) { + return readText(pos, Integer.MAX_VALUE); + } + + /** + * Reads the text string at the given position. + * + * @param pos the position index in the block + * @param limit the maximum number of value to read from the position + * @return the text string at the position, or null if the position contains a null value + */ + public String readText(int pos, int limit) { + if (textBlock.isNull(pos)) { + return null; + } + + strBuilder.setLength(0); + int maxPos = Math.min(limit, textBlock.getValueCount(pos)); + for (int valueIndex = 0; valueIndex < maxPos; valueIndex++) { + readBuffer = textBlock.getBytesRef(textBlock.getFirstValueIndex(pos) + valueIndex, readBuffer); + strBuilder.append(readBuffer.utf8ToString()); + if (valueIndex != maxPos - 1) { + strBuilder.append("\n"); + } + } + + return strBuilder.toString(); + } + + /** + * Returns the total number of positions (text entries) in the block. + */ + public int estimatedSize() { + return textBlock.getPositionCount(); + } + + @Override + public void close() { + textBlock.allowPassingToDifferentDriver(); + Releasables.close(textBlock); + } +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperatorOutputBuilder.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperatorOutputBuilder.java index 3e9106f9a1cf6..f95f9a2f451ef 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperatorOutputBuilder.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperatorOutputBuilder.java @@ -20,12 +20,12 @@ * {@link CompletionOperatorOutputBuilder} builds the output page for {@link CompletionOperator} by converting {@link ChatCompletionResults} * into a {@link BytesRefBlock}. */ -public class CompletionOperatorOutputBuilder implements InferenceOperator.OutputBuilder { +class CompletionOperatorOutputBuilder implements InferenceOperator.OutputBuilder { private final Page inputPage; private final BytesRefBlock.Builder outputBlockBuilder; private final BytesRefBuilder bytesRefBuilder = new BytesRefBuilder(); - public CompletionOperatorOutputBuilder(BytesRefBlock.Builder outputBlockBuilder, Page inputPage) { + CompletionOperatorOutputBuilder(BytesRefBlock.Builder outputBlockBuilder, Page inputPage) { this.inputPage = inputPage; this.outputBlockBuilder = outputBlockBuilder; } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperatorRequestIterator.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperatorRequestIterator.java index f526cd9edb077..47281378e64bb 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperatorRequestIterator.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperatorRequestIterator.java @@ -7,12 +7,11 @@ package org.elasticsearch.xpack.esql.inference.completion; -import org.apache.lucene.util.BytesRef; import org.elasticsearch.compute.data.BytesRefBlock; -import org.elasticsearch.core.Releasable; import org.elasticsearch.core.Releasables; import org.elasticsearch.inference.TaskType; import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.esql.inference.InputTextReader; import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRequestIterator; import java.util.List; @@ -22,9 +21,9 @@ * This iterator reads prompts from a {@link BytesRefBlock} and converts them into individual {@link InferenceAction.Request} instances * of type {@link TaskType#COMPLETION}. */ -public class CompletionOperatorRequestIterator implements BulkInferenceRequestIterator { +class CompletionOperatorRequestIterator implements BulkInferenceRequestIterator { - private final PromptReader promptReader; + private final InputTextReader textReader; private final String inferenceId; private final int size; private int currentPos = 0; @@ -35,8 +34,8 @@ public class CompletionOperatorRequestIterator implements BulkInferenceRequestIt * @param promptBlock The input block containing prompts. * @param inferenceId The ID of the inference model to invoke. */ - public CompletionOperatorRequestIterator(BytesRefBlock promptBlock, String inferenceId) { - this.promptReader = new PromptReader(promptBlock); + CompletionOperatorRequestIterator(BytesRefBlock promptBlock, String inferenceId) { + this.textReader = new InputTextReader(promptBlock); this.size = promptBlock.getPositionCount(); this.inferenceId = inferenceId; } @@ -52,7 +51,7 @@ public InferenceAction.Request next() { throw new NoSuchElementException(); } - return inferenceRequest(promptReader.readPrompt(currentPos++)); + return inferenceRequest(textReader.readText(currentPos++)); } /** @@ -68,60 +67,11 @@ private InferenceAction.Request inferenceRequest(String prompt) { @Override public int estimatedSize() { - return promptReader.estimatedSize(); + return textReader.estimatedSize(); } @Override public void close() { - Releasables.close(promptReader); - } - - /** - * Helper class that reads prompts from a {@link BytesRefBlock}. - */ - private static class PromptReader implements Releasable { - private final BytesRefBlock promptBlock; - private final StringBuilder strBuilder = new StringBuilder(); - private BytesRef readBuffer = new BytesRef(); - - private PromptReader(BytesRefBlock promptBlock) { - this.promptBlock = promptBlock; - } - - /** - * Reads the prompt string at the given position.. - * - * @param pos the position index in the block - */ - public String readPrompt(int pos) { - if (promptBlock.isNull(pos)) { - return null; - } - - strBuilder.setLength(0); - - for (int valueIndex = 0; valueIndex < promptBlock.getValueCount(pos); valueIndex++) { - readBuffer = promptBlock.getBytesRef(promptBlock.getFirstValueIndex(pos) + valueIndex, readBuffer); - strBuilder.append(readBuffer.utf8ToString()); - if (valueIndex != promptBlock.getValueCount(pos) - 1) { - strBuilder.append("\n"); - } - } - - return strBuilder.toString(); - } - - /** - * Returns the total number of positions (prompts) in the block. - */ - public int estimatedSize() { - return promptBlock.getPositionCount(); - } - - @Override - public void close() { - promptBlock.allowPassingToDifferentDriver(); - Releasables.close(promptBlock); - } + Releasables.close(textReader); } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/rerank/RerankOperatorOutputBuilder.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/rerank/RerankOperatorOutputBuilder.java index bff95cf54bae9..2554001bbf36a 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/rerank/RerankOperatorOutputBuilder.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/rerank/RerankOperatorOutputBuilder.java @@ -24,13 +24,13 @@ * * reranked relevance scores into the specified score channel of the input page. */ -public class RerankOperatorOutputBuilder implements InferenceOperator.OutputBuilder { +class RerankOperatorOutputBuilder implements InferenceOperator.OutputBuilder { private final Page inputPage; private final DoubleBlock.Builder scoreBlockBuilder; private final int scoreChannel; - public RerankOperatorOutputBuilder(DoubleBlock.Builder scoreBlockBuilder, Page inputPage, int scoreChannel) { + RerankOperatorOutputBuilder(DoubleBlock.Builder scoreBlockBuilder, Page inputPage, int scoreChannel) { this.inputPage = inputPage; this.scoreBlockBuilder = scoreBlockBuilder; this.scoreChannel = scoreChannel; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/rerank/RerankOperatorRequestIterator.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/rerank/RerankOperatorRequestIterator.java index 4b1cfe5870ad7..336d5cb3fca12 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/rerank/RerankOperatorRequestIterator.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/rerank/RerankOperatorRequestIterator.java @@ -25,14 +25,14 @@ *

This iterator reads from a {@link BytesRefBlock} containing input documents or items to be reranked. It slices the input into batches * of configurable size and converts each batch into an {@link InferenceAction.Request} with the task type {@link TaskType#RERANK}. */ -public class RerankOperatorRequestIterator implements BulkInferenceRequestIterator { +class RerankOperatorRequestIterator implements BulkInferenceRequestIterator { private final BytesRefBlock inputBlock; private final String inferenceId; private final String queryText; private final int batchSize; private int remainingPositions; - public RerankOperatorRequestIterator(BytesRefBlock inputBlock, String inferenceId, String queryText, int batchSize) { + RerankOperatorRequestIterator(BytesRefBlock inputBlock, String inferenceId, String queryText, int batchSize) { this.inputBlock = inputBlock; this.inferenceId = inferenceId; this.queryText = queryText; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingOperator.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingOperator.java new file mode 100644 index 0000000000000..7817612007614 --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingOperator.java @@ -0,0 +1,96 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.inference.textembedding; + +import org.elasticsearch.compute.data.BytesRefBlock; +import org.elasticsearch.compute.data.FloatBlock; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator; +import org.elasticsearch.compute.operator.Operator; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.xpack.esql.inference.InferenceOperator; +import org.elasticsearch.xpack.esql.inference.InferenceService; +import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRequestIterator; +import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRunner; +import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRunnerConfig; + +/** + * {@link TextEmbeddingOperator} is an {@link InferenceOperator} that performs text embedding inference. + * It evaluates a text expression for each input row, constructs text embedding inference requests, + * and emits the dense vector embeddings as output. + */ +public class TextEmbeddingOperator extends InferenceOperator { + + private final ExpressionEvaluator textEvaluator; + + public TextEmbeddingOperator( + DriverContext driverContext, + BulkInferenceRunner bulkInferenceRunner, + String inferenceId, + ExpressionEvaluator textEvaluator, + int maxOutstandingPages + ) { + super(driverContext, bulkInferenceRunner, inferenceId, maxOutstandingPages); + this.textEvaluator = textEvaluator; + } + + @Override + protected void doClose() { + Releasables.close(textEvaluator); + } + + @Override + public String toString() { + return "TextEmbeddingOperator[inference_id=[" + inferenceId() + "]]"; + } + + /** + * Constructs the text embedding inference requests iterator for the given input page by evaluating the text expression. + * + * @param inputPage The input data page. + */ + @Override + protected BulkInferenceRequestIterator requests(Page inputPage) { + return new TextEmbeddingOperatorRequestIterator((BytesRefBlock) textEvaluator.eval(inputPage), inferenceId()); + } + + /** + * Creates a new {@link TextEmbeddingOperatorOutputBuilder} to collect and emit the text embedding results. + * + * @param input The input page for which results will be constructed. + */ + @Override + protected TextEmbeddingOperatorOutputBuilder outputBuilder(Page input) { + FloatBlock.Builder outputBlockBuilder = blockFactory().newFloatBlockBuilder(input.getPositionCount()); + return new TextEmbeddingOperatorOutputBuilder(outputBlockBuilder, input); + } + + /** + * Factory for creating {@link TextEmbeddingOperator} instances. + */ + public record Factory(InferenceService inferenceService, String inferenceId, ExpressionEvaluator.Factory textEvaluatorFactory) + implements + OperatorFactory { + @Override + public String describe() { + return "TextEmbeddingOperator[inference_id=[" + inferenceId + "]]"; + } + + @Override + public Operator get(DriverContext driverContext) { + return new TextEmbeddingOperator( + driverContext, + inferenceService.bulkInferenceRunner(), + inferenceId, + textEvaluatorFactory.get(driverContext), + BulkInferenceRunnerConfig.DEFAULT.maxOutstandingBulkRequests() + ); + } + } +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingOperatorOutputBuilder.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingOperatorOutputBuilder.java new file mode 100644 index 0000000000000..a2b0a32e77b05 --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingOperatorOutputBuilder.java @@ -0,0 +1,111 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.inference.textembedding; + +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.FloatBlock; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults; +import org.elasticsearch.xpack.esql.inference.InferenceOperator; + +/** + * {@link TextEmbeddingOperatorOutputBuilder} builds the output page for text embedding by converting + * {@link TextEmbeddingResults} into a {@link FloatBlock} containing dense vector embeddings. + */ +class TextEmbeddingOperatorOutputBuilder implements InferenceOperator.OutputBuilder { + private final Page inputPage; + private final FloatBlock.Builder outputBlockBuilder; + + TextEmbeddingOperatorOutputBuilder(FloatBlock.Builder outputBlockBuilder, Page inputPage) { + this.inputPage = inputPage; + this.outputBlockBuilder = outputBlockBuilder; + } + + @Override + public void close() { + Releasables.close(outputBlockBuilder); + } + + /** + * Adds an inference response to the output builder. + * + *

+ * If the response is null or not of type {@link TextEmbeddingResults} an {@link IllegalStateException} is thrown. + * Else, the embedding vector is added to the output block as a multi-value position. + *

+ * + *

+ * The responses must be added in the same order as the corresponding inference requests were generated. + * Failing to preserve order may lead to incorrect or misaligned output rows. + *

+ */ + @Override + public void addInferenceResponse(InferenceAction.Response inferenceResponse) { + if (inferenceResponse == null) { + outputBlockBuilder.appendNull(); + return; + } + + TextEmbeddingResults embeddingResults = inferenceResults(inferenceResponse); + + var embeddings = embeddingResults.embeddings(); + if (embeddings.isEmpty()) { + outputBlockBuilder.appendNull(); + return; + } + + float[] embeddingArray = getEmbeddingAsFloatArray(embeddingResults); + + outputBlockBuilder.beginPositionEntry(); + for (float component : embeddingArray) { + outputBlockBuilder.appendFloat(component); + } + outputBlockBuilder.endPositionEntry(); + } + + /** + * Builds the final output page by appending the embedding output block to the input page. + */ + @Override + public Page buildOutput() { + Block outputBlock = outputBlockBuilder.build(); + assert outputBlock.getPositionCount() == inputPage.getPositionCount(); + return inputPage.appendBlock(outputBlock); + } + + private TextEmbeddingResults inferenceResults(InferenceAction.Response inferenceResponse) { + return InferenceOperator.OutputBuilder.inferenceResults(inferenceResponse, TextEmbeddingResults.class); + } + + /** + * Extracts the embedding as a float array from the embedding result. + */ + private static float[] getEmbeddingAsFloatArray(TextEmbeddingResults embedding) { + return switch (embedding.embeddings().get(0)) { + case TextEmbeddingFloatResults.Embedding floatEmbedding -> floatEmbedding.values(); + case TextEmbeddingByteResults.Embedding byteEmbedding -> toFloatArray(byteEmbedding.values()); + default -> throw new IllegalArgumentException( + "Unsupported embedding type: " + + embedding.embeddings().get(0).getClass().getName() + + ". Expected TextEmbeddingFloatResults.Embedding or TextEmbeddingByteResults.Embedding." + ); + }; + } + + private static float[] toFloatArray(byte[] values) { + float[] floatArray = new float[values.length]; + for (int i = 0; i < values.length; i++) { + floatArray[i] = ((Byte) values[i]).floatValue(); + } + return floatArray; + } +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingOperatorRequestIterator.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingOperatorRequestIterator.java new file mode 100644 index 0000000000000..e7118c0a0b594 --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingOperatorRequestIterator.java @@ -0,0 +1,81 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.inference.textembedding; + +import org.elasticsearch.compute.data.BytesRefBlock; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.esql.inference.InputTextReader; +import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRequestIterator; + +import java.util.List; +import java.util.NoSuchElementException; + +/** + * This iterator reads text inputs from a {@link BytesRefBlock} and converts them into individual {@link InferenceAction.Request} instances + * of type {@link TaskType#TEXT_EMBEDDING}. + */ +class TextEmbeddingOperatorRequestIterator implements BulkInferenceRequestIterator { + + private final InputTextReader textReader; + private final String inferenceId; + private final int size; + private int currentPos = 0; + + /** + * Constructs a new iterator from the given block of text inputs. + * + * @param textBlock The input block containing text to embed. + * @param inferenceId The ID of the inference model to invoke. + */ + TextEmbeddingOperatorRequestIterator(BytesRefBlock textBlock, String inferenceId) { + this.textReader = new InputTextReader(textBlock); + this.size = textBlock.getPositionCount(); + this.inferenceId = inferenceId; + } + + @Override + public boolean hasNext() { + return currentPos < size; + } + + @Override + public InferenceAction.Request next() { + if (hasNext() == false) { + throw new NoSuchElementException(); + } + + /* + * Keep only the first value in case of multi-valued fields. + * TODO: check if it is consistent with how the query vector builder is working. + */ + return inferenceRequest(textReader.readText(currentPos++, 1)); + } + + /** + * Wraps a single text string into an {@link InferenceAction.Request} for text embedding. + */ + private InferenceAction.Request inferenceRequest(String text) { + if (text == null) { + return null; + } + + return InferenceAction.Request.builder(inferenceId, TaskType.TEXT_EMBEDDING).setInput(List.of(text)).build(); + } + + @Override + public int estimatedSize() { + return textReader.estimatedSize(); + } + + @Override + public void close() { + Releasables.close(textReader); + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/InputTextReaderTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/InputTextReaderTests.java new file mode 100644 index 0000000000000..64dd6a02928e8 --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/InputTextReaderTests.java @@ -0,0 +1,231 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.inference; + +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.compute.data.BytesRefBlock; +import org.elasticsearch.compute.test.ComputeTestCase; + +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.nullValue; + +public class InputTextReaderTests extends ComputeTestCase { + + public void testReadSingleValuePositions() throws Exception { + String[] texts = { "hello", "world", "test" }; + BytesRefBlock block = createSingleValueBlock(texts); + + try (InputTextReader reader = new InputTextReader(block)) { + assertThat(reader.estimatedSize(), equalTo(texts.length)); + + for (int i = 0; i < texts.length; i++) { + assertThat(reader.readText(i), equalTo(texts[i])); + } + } + + allBreakersEmpty(); + } + + public void testReadMultiValuePositions() throws Exception { + BytesRefBlock block = createMultiValueBlock(); + + try (InputTextReader reader = new InputTextReader(block)) { + assertThat(reader.estimatedSize(), equalTo(2)); + + // First position has multiple values that should be concatenated with newlines + assertThat(reader.readText(0), equalTo("first\nsecond\nthird")); + + // Second position has a single value + assertThat(reader.readText(1), equalTo("single")); + } + + allBreakersEmpty(); + } + + public void testReadMultiValuePositionsWithLimit() throws Exception { + BytesRefBlock block = createMultiValueBlock(); + + try (InputTextReader reader = new InputTextReader(block)) { + // Test limiting to first 2 values out of 3 + assertThat(reader.readText(0, 2), equalTo("first\nsecond")); + + // Test limiting to first 1 value out of 3 + assertThat(reader.readText(0, 1), equalTo("first")); + + // Test limit larger than available values + assertThat(reader.readText(0, 10), equalTo("first\nsecond\nthird")); + + // Test limit of 0 + assertThat(reader.readText(0, 0), equalTo("")); + + // Test single value position with limit + assertThat(reader.readText(1, 1), equalTo("single")); + } + + allBreakersEmpty(); + } + + public void testReadNullValues() throws Exception { + BytesRefBlock block = createBlockWithNulls(); + + try (InputTextReader reader = new InputTextReader(block)) { + assertThat(reader.estimatedSize(), equalTo(3)); + + assertThat(reader.readText(0), equalTo("before")); + assertThat(reader.readText(1), nullValue()); + assertThat(reader.readText(2), equalTo("after")); + } + + allBreakersEmpty(); + } + + public void testReadNullValuesWithLimit() throws Exception { + BytesRefBlock block = createBlockWithNulls(); + + try (InputTextReader reader = new InputTextReader(block)) { + // Null values should return null regardless of limit + assertThat(reader.readText(0, 1), equalTo("before")); + assertThat(reader.readText(1, 1), nullValue()); + assertThat(reader.readText(1, 10), nullValue()); + assertThat(reader.readText(2, 1), equalTo("after")); + } + + allBreakersEmpty(); + } + + public void testReadEmptyStrings() throws Exception { + String[] texts = { "", "non-empty", "" }; + BytesRefBlock block = createSingleValueBlock(texts); + + try (InputTextReader reader = new InputTextReader(block)) { + for (int i = 0; i < texts.length; i++) { + assertThat(reader.readText(i), equalTo(texts[i])); + assertThat(reader.readText(i, 1), equalTo(texts[i])); + } + } + + allBreakersEmpty(); + } + + public void testReadLargeInput() throws Exception { + int size = between(1000, 5000); + String[] texts = new String[size]; + for (int i = 0; i < size; i++) { + texts[i] = "text_" + i + "_" + randomAlphaOfLength(10); + } + + BytesRefBlock block = createSingleValueBlock(texts); + + try (InputTextReader reader = new InputTextReader(block)) { + assertThat(reader.estimatedSize(), equalTo(size)); + + for (int i = 0; i < size; i++) { + assertThat(reader.readText(i), equalTo(texts[i])); + assertThat(reader.readText(i, 1), equalTo(texts[i])); + } + } + + allBreakersEmpty(); + } + + public void testReadUnicodeText() throws Exception { + String[] texts = { "café", "naïve", "résumé", "🚀 rocket", "多语言支持" }; + BytesRefBlock block = createSingleValueBlock(texts); + + try (InputTextReader reader = new InputTextReader(block)) { + for (int i = 0; i < texts.length; i++) { + assertThat(reader.readText(i), equalTo(texts[i])); + assertThat(reader.readText(i, 1), equalTo(texts[i])); + } + } + + allBreakersEmpty(); + } + + public void testReadMultipleTimesFromSamePosition() throws Exception { + String[] texts = { "consistent" }; + BytesRefBlock block = createSingleValueBlock(texts); + + try (InputTextReader reader = new InputTextReader(block)) { + // Reading the same position multiple times should return the same result + assertThat(reader.readText(0), equalTo("consistent")); + assertThat(reader.readText(0), equalTo("consistent")); + assertThat(reader.readText(0, 1), equalTo("consistent")); + assertThat(reader.readText(0, 10), equalTo("consistent")); + } + + allBreakersEmpty(); + } + + public void testLimitBoundaryConditions() throws Exception { + BytesRefBlock block = createLargeMultiValueBlock(); + + try (InputTextReader reader = new InputTextReader(block)) { + // Test various limit values on a position with 5 values + assertThat(reader.readText(0, 0), equalTo("")); + assertThat(reader.readText(0, 1), equalTo("value0")); + assertThat(reader.readText(0, 2), equalTo("value0\nvalue1")); + assertThat(reader.readText(0, 3), equalTo("value0\nvalue1\nvalue2")); + assertThat(reader.readText(0, 4), equalTo("value0\nvalue1\nvalue2\nvalue3")); + assertThat(reader.readText(0, 5), equalTo("value0\nvalue1\nvalue2\nvalue3\nvalue4")); + + // Test limit beyond available values + assertThat(reader.readText(0, 10), equalTo("value0\nvalue1\nvalue2\nvalue3\nvalue4")); + assertThat(reader.readText(0, Integer.MAX_VALUE), equalTo("value0\nvalue1\nvalue2\nvalue3\nvalue4")); + } + + allBreakersEmpty(); + } + + private BytesRefBlock createSingleValueBlock(String[] texts) { + try (BytesRefBlock.Builder builder = blockFactory().newBytesRefBlockBuilder(texts.length)) { + for (String text : texts) { + builder.appendBytesRef(new BytesRef(text)); + } + return builder.build(); + } + } + + private BytesRefBlock createMultiValueBlock() { + try (BytesRefBlock.Builder builder = blockFactory().newBytesRefBlockBuilder(2)) { + // First position: multiple values + builder.beginPositionEntry(); + builder.appendBytesRef(new BytesRef("first")); + builder.appendBytesRef(new BytesRef("second")); + builder.appendBytesRef(new BytesRef("third")); + builder.endPositionEntry(); + + // Second position: single value + builder.appendBytesRef(new BytesRef("single")); + + return builder.build(); + } + } + + private BytesRefBlock createBlockWithNulls() { + try (BytesRefBlock.Builder builder = blockFactory().newBytesRefBlockBuilder(3)) { + builder.appendBytesRef(new BytesRef("before")); + builder.appendNull(); + builder.appendBytesRef(new BytesRef("after")); + return builder.build(); + } + } + + private BytesRefBlock createLargeMultiValueBlock() { + try (BytesRefBlock.Builder builder = blockFactory().newBytesRefBlockBuilder(1)) { + // Single position with 5 values for testing limits + builder.beginPositionEntry(); + for (int i = 0; i < 5; i++) { + builder.appendBytesRef(new BytesRef("value" + i)); + } + builder.endPositionEntry(); + + return builder.build(); + } + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingOperatorOutputBuilderTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingOperatorOutputBuilderTests.java new file mode 100644 index 0000000000000..ea77c6bed3c38 --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingOperatorOutputBuilderTests.java @@ -0,0 +1,247 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.inference.textembedding; + +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.FloatBlock; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.test.ComputeTestCase; +import org.elasticsearch.compute.test.RandomBlock; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; + +import java.util.List; + +import static org.hamcrest.Matchers.equalTo; + +public class TextEmbeddingOperatorOutputBuilderTests extends ComputeTestCase { + + public void testBuildSmallOutputWithFloatEmbeddings() throws Exception { + assertBuildOutputWithFloatEmbeddings(between(1, 100)); + } + + public void testBuildLargeOutputWithFloatEmbeddings() throws Exception { + assertBuildOutputWithFloatEmbeddings(between(1_000, 10_000)); + } + + public void testBuildSmallOutputWithByteEmbeddings() throws Exception { + assertBuildOutputWithByteEmbeddings(between(1, 100)); + } + + public void testBuildLargeOutputWithByteEmbeddings() throws Exception { + assertBuildOutputWithByteEmbeddings(between(1_000, 10_000)); + } + + public void testHandleNullResponses() throws Exception { + final int size = between(10, 100); + final Page inputPage = randomInputPage(size, between(1, 5)); + + try ( + TextEmbeddingOperatorOutputBuilder outputBuilder = new TextEmbeddingOperatorOutputBuilder( + blockFactory().newFloatBlockBuilder(size), + inputPage + ) + ) { + // Add some null responses + for (int currentPos = 0; currentPos < inputPage.getPositionCount(); currentPos++) { + if (randomBoolean()) { + outputBuilder.addInferenceResponse(null); + } else { + float[] embedding = randomFloatEmbedding(randomIntBetween(50, 200)); + outputBuilder.addInferenceResponse(createFloatEmbeddingResponse(embedding)); + } + } + + final Page outputPage = outputBuilder.buildOutput(); + assertThat(outputPage.getPositionCount(), equalTo(inputPage.getPositionCount())); + assertThat(outputPage.getBlockCount(), equalTo(inputPage.getBlockCount() + 1)); + + FloatBlock outputBlock = (FloatBlock) outputPage.getBlock(outputPage.getBlockCount() - 1); + assertThat(outputBlock.getPositionCount(), equalTo(size)); + + outputPage.releaseBlocks(); + } + + allBreakersEmpty(); + } + + public void testHandleEmptyEmbeddings() throws Exception { + final int size = between(5, 50); + final Page inputPage = randomInputPage(size, between(1, 3)); + + try ( + TextEmbeddingOperatorOutputBuilder outputBuilder = new TextEmbeddingOperatorOutputBuilder( + blockFactory().newFloatBlockBuilder(size), + inputPage + ) + ) { + // Add responses with empty embeddings + for (int currentPos = 0; currentPos < inputPage.getPositionCount(); currentPos++) { + outputBuilder.addInferenceResponse(createEmptyFloatEmbeddingResponse()); + } + + final Page outputPage = outputBuilder.buildOutput(); + FloatBlock outputBlock = (FloatBlock) outputPage.getBlock(outputPage.getBlockCount() - 1); + + // All positions should be null due to empty embeddings + for (int pos = 0; pos < outputBlock.getPositionCount(); pos++) { + assertThat(outputBlock.isNull(pos), equalTo(true)); + } + + outputPage.releaseBlocks(); + } + + allBreakersEmpty(); + } + + private void assertBuildOutputWithFloatEmbeddings(int size) throws Exception { + final Page inputPage = randomInputPage(size, between(1, 10)); + final int embeddingDim = randomIntBetween(50, 1536); // Common embedding dimensions + final float[][] expectedEmbeddings = new float[size][]; + + try ( + TextEmbeddingOperatorOutputBuilder outputBuilder = new TextEmbeddingOperatorOutputBuilder( + blockFactory().newFloatBlockBuilder(size), + inputPage + ) + ) { + for (int currentPos = 0; currentPos < inputPage.getPositionCount(); currentPos++) { + float[] embedding = randomFloatEmbedding(embeddingDim); + expectedEmbeddings[currentPos] = embedding; + outputBuilder.addInferenceResponse(createFloatEmbeddingResponse(embedding)); + } + + final Page outputPage = outputBuilder.buildOutput(); + assertThat(outputPage.getPositionCount(), equalTo(inputPage.getPositionCount())); + assertThat(outputPage.getBlockCount(), equalTo(inputPage.getBlockCount() + 1)); + + assertFloatEmbeddingContent((FloatBlock) outputPage.getBlock(outputPage.getBlockCount() - 1), expectedEmbeddings); + + outputPage.releaseBlocks(); + } + + allBreakersEmpty(); + } + + private void assertBuildOutputWithByteEmbeddings(int size) throws Exception { + final Page inputPage = randomInputPage(size, between(1, 10)); + final int embeddingDim = randomIntBetween(50, 1536); + final byte[][] expectedByteEmbeddings = new byte[size][]; + + try ( + TextEmbeddingOperatorOutputBuilder outputBuilder = new TextEmbeddingOperatorOutputBuilder( + blockFactory().newFloatBlockBuilder(size), + inputPage + ) + ) { + for (int currentPos = 0; currentPos < inputPage.getPositionCount(); currentPos++) { + byte[] embedding = randomByteEmbedding(embeddingDim); + expectedByteEmbeddings[currentPos] = embedding; + outputBuilder.addInferenceResponse(createByteEmbeddingResponse(embedding)); + } + + final Page outputPage = outputBuilder.buildOutput(); + assertThat(outputPage.getPositionCount(), equalTo(inputPage.getPositionCount())); + assertThat(outputPage.getBlockCount(), equalTo(inputPage.getBlockCount() + 1)); + + assertByteEmbeddingContent((FloatBlock) outputPage.getBlock(outputPage.getBlockCount() - 1), expectedByteEmbeddings); + + outputPage.releaseBlocks(); + } + + allBreakersEmpty(); + } + + private void assertFloatEmbeddingContent(FloatBlock block, float[][] expectedEmbeddings) { + for (int currentPos = 0; currentPos < block.getPositionCount(); currentPos++) { + assertThat(block.isNull(currentPos), equalTo(false)); + assertThat(block.getValueCount(currentPos), equalTo(expectedEmbeddings[currentPos].length)); + + int firstValueIndex = block.getFirstValueIndex(currentPos); + for (int i = 0; i < expectedEmbeddings[currentPos].length; i++) { + float actualValue = block.getFloat(firstValueIndex + i); + float expectedValue = expectedEmbeddings[currentPos][i]; + assertThat(actualValue, equalTo(expectedValue)); + } + } + } + + private void assertByteEmbeddingContent(FloatBlock block, byte[][] expectedByteEmbeddings) { + for (int currentPos = 0; currentPos < block.getPositionCount(); currentPos++) { + assertThat(block.isNull(currentPos), equalTo(false)); + assertThat(block.getValueCount(currentPos), equalTo(expectedByteEmbeddings[currentPos].length)); + + int firstValueIndex = block.getFirstValueIndex(currentPos); + for (int i = 0; i < expectedByteEmbeddings[currentPos].length; i++) { + float actualValue = block.getFloat(firstValueIndex + i); + // Convert byte to float the same way as TextEmbeddingByteResults.Embedding.toFloatArray() + float expectedValue = expectedByteEmbeddings[currentPos][i]; + assertThat(actualValue, equalTo(expectedValue)); + } + } + } + + private float[] randomFloatEmbedding(int dimension) { + float[] embedding = new float[dimension]; + for (int i = 0; i < dimension; i++) { + embedding[i] = randomFloat(); + } + return embedding; + } + + private byte[] randomByteEmbedding(int dimension) { + byte[] embedding = new byte[dimension]; + for (int i = 0; i < dimension; i++) { + embedding[i] = randomByte(); + } + return embedding; + } + + private static InferenceAction.Response createFloatEmbeddingResponse(float[] embedding) { + var embeddingResult = new TextEmbeddingFloatResults.Embedding(embedding); + var textEmbeddingResults = new TextEmbeddingFloatResults(List.of(embeddingResult)); + return new InferenceAction.Response(textEmbeddingResults); + } + + private static InferenceAction.Response createByteEmbeddingResponse(byte[] embedding) { + var embeddingResult = new TextEmbeddingByteResults.Embedding(embedding); + var textEmbeddingResults = new TextEmbeddingByteResults(List.of(embeddingResult)); + return new InferenceAction.Response(textEmbeddingResults); + } + + private static InferenceAction.Response createEmptyFloatEmbeddingResponse() { + var textEmbeddingResults = new TextEmbeddingFloatResults(List.of()); + return new InferenceAction.Response(textEmbeddingResults); + } + + private Page randomInputPage(int positionCount, int columnCount) { + final Block[] blocks = new Block[columnCount]; + try { + for (int i = 0; i < columnCount; i++) { + blocks[i] = RandomBlock.randomBlock( + blockFactory(), + RandomBlock.randomElementExcluding(List.of(ElementType.AGGREGATE_METRIC_DOUBLE)), + positionCount, + randomBoolean(), + 0, + 0, + randomInt(10), + randomInt(10) + ).block(); + } + + return new Page(blocks); + } catch (Exception e) { + Releasables.close(blocks); + throw (e); + } + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingOperatorRequestIteratorTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingOperatorRequestIteratorTests.java new file mode 100644 index 0000000000000..fadd1c0d69e00 --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingOperatorRequestIteratorTests.java @@ -0,0 +1,167 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.inference.textembedding; + +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.compute.data.BytesRefBlock; +import org.elasticsearch.compute.test.ComputeTestCase; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; + +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.nullValue; + +public class TextEmbeddingOperatorRequestIteratorTests extends ComputeTestCase { + + public void testIterateSmallInput() throws Exception { + assertIterate(between(1, 100)); + } + + public void testIterateLargeInput() throws Exception { + assertIterate(between(10_000, 100_000)); + } + + public void testIterateWithNullValues() throws Exception { + final String inferenceId = randomIdentifier(); + final BytesRefBlock inputBlock = createBlockWithNulls(); + + try (TextEmbeddingOperatorRequestIterator requestIterator = new TextEmbeddingOperatorRequestIterator(inputBlock, inferenceId)) { + BytesRef scratch = new BytesRef(); + + // First position: "before" + InferenceAction.Request request1 = requestIterator.next(); + assertThat(request1.getInferenceEntityId(), equalTo(inferenceId)); + assertThat(request1.getTaskType(), equalTo(TaskType.TEXT_EMBEDDING)); + scratch = inputBlock.getBytesRef(inputBlock.getFirstValueIndex(0), scratch); + assertThat(request1.getInput().get(0), equalTo(scratch.utf8ToString())); + + // Second position: null + InferenceAction.Request request2 = requestIterator.next(); + assertThat(request2, nullValue()); + + // Third position: "after" + InferenceAction.Request request3 = requestIterator.next(); + assertThat(request3.getInferenceEntityId(), equalTo(inferenceId)); + assertThat(request3.getTaskType(), equalTo(TaskType.TEXT_EMBEDDING)); + scratch = inputBlock.getBytesRef(inputBlock.getFirstValueIndex(2), scratch); + assertThat(request3.getInput().get(0), equalTo(scratch.utf8ToString())); + } + + allBreakersEmpty(); + } + + public void testIterateWithMultiValuePositions() throws Exception { + final String inferenceId = randomIdentifier(); + final BytesRefBlock inputBlock = createMultiValueBlock(); + + try (TextEmbeddingOperatorRequestIterator requestIterator = new TextEmbeddingOperatorRequestIterator(inputBlock, inferenceId)) { + // First position: multi-value, keep only the first value + InferenceAction.Request request1 = requestIterator.next(); + assertThat(request1.getInferenceEntityId(), equalTo(inferenceId)); + assertThat(request1.getTaskType(), equalTo(TaskType.TEXT_EMBEDDING)); + assertThat(request1.getInput().get(0), equalTo("first")); + + // Second position: single value + InferenceAction.Request request2 = requestIterator.next(); + assertThat(request2.getInferenceEntityId(), equalTo(inferenceId)); + assertThat(request2.getTaskType(), equalTo(TaskType.TEXT_EMBEDDING)); + assertThat(request2.getInput().get(0), equalTo("single")); + } + + allBreakersEmpty(); + } + + public void testEstimatedSize() throws Exception { + final String inferenceId = randomIdentifier(); + final int size = randomIntBetween(10, 1000); + final BytesRefBlock inputBlock = randomInputBlock(size); + + try (TextEmbeddingOperatorRequestIterator requestIterator = new TextEmbeddingOperatorRequestIterator(inputBlock, inferenceId)) { + assertThat(requestIterator.estimatedSize(), equalTo(size)); + } + + allBreakersEmpty(); + } + + public void testHasNextAndIteration() throws Exception { + final String inferenceId = randomIdentifier(); + final int size = randomIntBetween(5, 50); + final BytesRefBlock inputBlock = randomInputBlock(size); + + try (TextEmbeddingOperatorRequestIterator requestIterator = new TextEmbeddingOperatorRequestIterator(inputBlock, inferenceId)) { + int count = 0; + while (requestIterator.hasNext()) { + requestIterator.next(); + count++; + } + assertThat(count, equalTo(size)); + + // Verify hasNext returns false after iteration is complete + assertThat(requestIterator.hasNext(), equalTo(false)); + } + + allBreakersEmpty(); + } + + private void assertIterate(int size) throws Exception { + final String inferenceId = randomIdentifier(); + final BytesRefBlock inputBlock = randomInputBlock(size); + + try (TextEmbeddingOperatorRequestIterator requestIterator = new TextEmbeddingOperatorRequestIterator(inputBlock, inferenceId)) { + BytesRef scratch = new BytesRef(); + + for (int currentPos = 0; requestIterator.hasNext(); currentPos++) { + InferenceAction.Request request = requestIterator.next(); + + assertThat(request.getInferenceEntityId(), equalTo(inferenceId)); + assertThat(request.getTaskType(), equalTo(TaskType.TEXT_EMBEDDING)); + + // Verify the input text matches what's in the block + scratch = inputBlock.getBytesRef(inputBlock.getFirstValueIndex(currentPos), scratch); + assertThat(request.getInput().get(0), equalTo(scratch.utf8ToString())); + } + } + + allBreakersEmpty(); + } + + private BytesRefBlock randomInputBlock(int size) { + try (BytesRefBlock.Builder blockBuilder = blockFactory().newBytesRefBlockBuilder(size)) { + for (int i = 0; i < size; i++) { + blockBuilder.appendBytesRef(new BytesRef(randomAlphaOfLength(10))); + } + + return blockBuilder.build(); + } + } + + private BytesRefBlock createBlockWithNulls() { + try (BytesRefBlock.Builder builder = blockFactory().newBytesRefBlockBuilder(3)) { + builder.appendBytesRef(new BytesRef("before")); + builder.appendNull(); + builder.appendBytesRef(new BytesRef("after")); + return builder.build(); + } + } + + private BytesRefBlock createMultiValueBlock() { + try (BytesRefBlock.Builder builder = blockFactory().newBytesRefBlockBuilder(2)) { + // First position: multiple values + builder.beginPositionEntry(); + builder.appendBytesRef(new BytesRef("first")); + builder.appendBytesRef(new BytesRef("second")); + builder.appendBytesRef(new BytesRef("third")); + builder.endPositionEntry(); + + // Second position: single value + builder.appendBytesRef(new BytesRef("single")); + + return builder.build(); + } + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingOperatorTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingOperatorTests.java new file mode 100644 index 0000000000000..6ff9a90b70b16 --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingOperatorTests.java @@ -0,0 +1,137 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.inference.textembedding; + +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BytesRefBlock; +import org.elasticsearch.compute.data.FloatBlock; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.Operator; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.esql.inference.InferenceOperatorTestCase; +import org.hamcrest.Matcher; +import org.junit.Before; + +import java.util.List; + +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasSize; + +public class TextEmbeddingOperatorTests extends InferenceOperatorTestCase { + private static final String SIMPLE_INFERENCE_ID = "test_text_embedding"; + private static final int EMBEDDING_DIMENSION = 384; // Common embedding dimension + + private int inputChannel; + + @Before + public void initTextEmbeddingChannels() { + inputChannel = between(0, inputsCount - 1); + } + + @Override + protected Operator.OperatorFactory simple(SimpleOptions options) { + return new TextEmbeddingOperator.Factory(mockedInferenceService(), SIMPLE_INFERENCE_ID, evaluatorFactory(inputChannel)); + } + + @Override + protected void assertSimpleOutput(List input, List results) { + assertThat(results, hasSize(input.size())); + + for (int curPage = 0; curPage < input.size(); curPage++) { + Page inputPage = input.get(curPage); + Page resultPage = results.get(curPage); + + assertEquals(inputPage.getPositionCount(), resultPage.getPositionCount()); + assertEquals(inputPage.getBlockCount() + 1, resultPage.getBlockCount()); + + for (int channel = 0; channel < inputPage.getBlockCount(); channel++) { + Block inputBlock = inputPage.getBlock(channel); + Block resultBlock = resultPage.getBlock(channel); + assertBlockContentEquals(inputBlock, resultBlock); + } + + assertTextEmbeddingResults(inputPage, resultPage); + } + } + + private void assertTextEmbeddingResults(Page inputPage, Page resultPage) { + BytesRefBlock inputBlock = resultPage.getBlock(inputChannel); + FloatBlock resultBlock = (FloatBlock) resultPage.getBlock(inputPage.getBlockCount()); + + BlockStringReader blockReader = new InferenceOperatorTestCase.BlockStringReader(); + + for (int curPos = 0; curPos < inputPage.getPositionCount(); curPos++) { + if (inputBlock.isNull(curPos)) { + assertThat(resultBlock.isNull(curPos), equalTo(true)); + } else { + // Verify that we have an embedding vector at this position + assertThat(resultBlock.isNull(curPos), equalTo(false)); + assertThat(resultBlock.getValueCount(curPos), equalTo(EMBEDDING_DIMENSION)); + + // Get the input text to verify our mock embedding generation + String inputText = blockReader.readString(inputBlock, curPos); + + // Verify the embedding values match our mock generation pattern + int firstValueIndex = resultBlock.getFirstValueIndex(curPos); + for (int i = 0; i < EMBEDDING_DIMENSION; i++) { + float expectedValue = generateMockEmbeddingValue(inputText, i); + float actualValue = resultBlock.getFloat(firstValueIndex + i); + assertThat(actualValue, equalTo(expectedValue)); + } + } + } + } + + @Override + protected TextEmbeddingFloatResults mockInferenceResult(InferenceAction.Request request) { + // For text embedding, we expect one input text per request + String inputText = request.getInput().get(0); + + // Generate a deterministic mock embedding based on the input text + float[] mockEmbedding = generateMockEmbedding(inputText, EMBEDDING_DIMENSION); + + var embeddingResult = new TextEmbeddingFloatResults.Embedding(mockEmbedding); + return new TextEmbeddingFloatResults(List.of(embeddingResult)); + } + + @Override + protected Matcher expectedDescriptionOfSimple() { + return expectedToStringOfSimple(); + } + + @Override + protected Matcher expectedToStringOfSimple() { + return equalTo("TextEmbeddingOperator[inference_id=[" + SIMPLE_INFERENCE_ID + "]]"); + } + + /** + * Generates a deterministic mock embedding vector based on the input text. + * This ensures our tests are repeatable and verifiable. + */ + private float[] generateMockEmbedding(String inputText, int dimension) { + float[] embedding = new float[dimension]; + int textHash = inputText.hashCode(); + + for (int i = 0; i < dimension; i++) { + embedding[i] = generateMockEmbeddingValue(inputText, i); + } + + return embedding; + } + + /** + * Generates a single embedding value for a specific dimension based on input text. + * Uses a deterministic function so tests are repeatable. + */ + private float generateMockEmbeddingValue(String inputText, int dimension) { + // Create a deterministic value based on input text and dimension + int hash = (inputText.hashCode() + dimension * 31) % 10000; + return hash / 10000.0f; // Normalize to [0, 1) range + } +}