Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.esql.inference;

import org.apache.lucene.util.BytesRef;
import org.elasticsearch.compute.data.BytesRefBlock;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;

/**
* Helper class that reads text strings from a {@link BytesRefBlock}.
* This class is used by inference operators to extract text content from block data.
*/
public class InputTextReader implements Releasable {
private final BytesRefBlock textBlock;
private final StringBuilder strBuilder = new StringBuilder();
private BytesRef readBuffer = new BytesRef();

public InputTextReader(BytesRefBlock textBlock) {
this.textBlock = textBlock;
}

/**
* Reads the text string at the given position.
* Multiple values at the position are concatenated with newlines.
*
* @param pos the position index in the block
* @return the text string at the position, or null if the position contains a null value
*/
public String readText(int pos) {
return readText(pos, Integer.MAX_VALUE);
}

/**
* Reads the text string at the given position.
*
* @param pos the position index in the block
* @param limit the maximum number of value to read from the position
* @return the text string at the position, or null if the position contains a null value
*/
public String readText(int pos, int limit) {
if (textBlock.isNull(pos)) {
return null;
}

strBuilder.setLength(0);
int maxPos = Math.min(limit, textBlock.getValueCount(pos));
for (int valueIndex = 0; valueIndex < maxPos; valueIndex++) {
readBuffer = textBlock.getBytesRef(textBlock.getFirstValueIndex(pos) + valueIndex, readBuffer);
strBuilder.append(readBuffer.utf8ToString());
if (valueIndex != maxPos - 1) {
strBuilder.append("\n");
}
}

return strBuilder.toString();
}

/**
* Returns the total number of positions (text entries) in the block.
*/
public int estimatedSize() {
return textBlock.getPositionCount();
}

@Override
public void close() {
textBlock.allowPassingToDifferentDriver();
Releasables.close(textBlock);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,12 @@
* {@link CompletionOperatorOutputBuilder} builds the output page for {@link CompletionOperator} by converting {@link ChatCompletionResults}
* into a {@link BytesRefBlock}.
*/
public class CompletionOperatorOutputBuilder implements InferenceOperator.OutputBuilder {
class CompletionOperatorOutputBuilder implements InferenceOperator.OutputBuilder {
private final Page inputPage;
private final BytesRefBlock.Builder outputBlockBuilder;
private final BytesRefBuilder bytesRefBuilder = new BytesRefBuilder();

public CompletionOperatorOutputBuilder(BytesRefBlock.Builder outputBlockBuilder, Page inputPage) {
CompletionOperatorOutputBuilder(BytesRefBlock.Builder outputBlockBuilder, Page inputPage) {
this.inputPage = inputPage;
this.outputBlockBuilder = outputBlockBuilder;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,11 @@

package org.elasticsearch.xpack.esql.inference.completion;

import org.apache.lucene.util.BytesRef;
import org.elasticsearch.compute.data.BytesRefBlock;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.esql.inference.InputTextReader;
import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRequestIterator;

import java.util.List;
Expand All @@ -22,9 +21,9 @@
* This iterator reads prompts from a {@link BytesRefBlock} and converts them into individual {@link InferenceAction.Request} instances
* of type {@link TaskType#COMPLETION}.
*/
public class CompletionOperatorRequestIterator implements BulkInferenceRequestIterator {
class CompletionOperatorRequestIterator implements BulkInferenceRequestIterator {

private final PromptReader promptReader;
private final InputTextReader textReader;
private final String inferenceId;
private final int size;
private int currentPos = 0;
Expand All @@ -35,8 +34,8 @@ public class CompletionOperatorRequestIterator implements BulkInferenceRequestIt
* @param promptBlock The input block containing prompts.
* @param inferenceId The ID of the inference model to invoke.
*/
public CompletionOperatorRequestIterator(BytesRefBlock promptBlock, String inferenceId) {
this.promptReader = new PromptReader(promptBlock);
CompletionOperatorRequestIterator(BytesRefBlock promptBlock, String inferenceId) {
this.textReader = new InputTextReader(promptBlock);
this.size = promptBlock.getPositionCount();
this.inferenceId = inferenceId;
}
Expand All @@ -52,7 +51,7 @@ public InferenceAction.Request next() {
throw new NoSuchElementException();
}

return inferenceRequest(promptReader.readPrompt(currentPos++));
return inferenceRequest(textReader.readText(currentPos++));
}

/**
Expand All @@ -68,60 +67,11 @@ private InferenceAction.Request inferenceRequest(String prompt) {

@Override
public int estimatedSize() {
return promptReader.estimatedSize();
return textReader.estimatedSize();
}

@Override
public void close() {
Releasables.close(promptReader);
}

/**
* Helper class that reads prompts from a {@link BytesRefBlock}.
*/
private static class PromptReader implements Releasable {
private final BytesRefBlock promptBlock;
private final StringBuilder strBuilder = new StringBuilder();
private BytesRef readBuffer = new BytesRef();

private PromptReader(BytesRefBlock promptBlock) {
this.promptBlock = promptBlock;
}

/**
* Reads the prompt string at the given position..
*
* @param pos the position index in the block
*/
public String readPrompt(int pos) {
if (promptBlock.isNull(pos)) {
return null;
}

strBuilder.setLength(0);

for (int valueIndex = 0; valueIndex < promptBlock.getValueCount(pos); valueIndex++) {
readBuffer = promptBlock.getBytesRef(promptBlock.getFirstValueIndex(pos) + valueIndex, readBuffer);
strBuilder.append(readBuffer.utf8ToString());
if (valueIndex != promptBlock.getValueCount(pos) - 1) {
strBuilder.append("\n");
}
}

return strBuilder.toString();
}

/**
* Returns the total number of positions (prompts) in the block.
*/
public int estimatedSize() {
return promptBlock.getPositionCount();
}

@Override
public void close() {
promptBlock.allowPassingToDifferentDriver();
Releasables.close(promptBlock);
}
Releasables.close(textReader);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,13 @@
* * reranked relevance scores into the specified score channel of the input page.
*/

public class RerankOperatorOutputBuilder implements InferenceOperator.OutputBuilder {
class RerankOperatorOutputBuilder implements InferenceOperator.OutputBuilder {

private final Page inputPage;
private final DoubleBlock.Builder scoreBlockBuilder;
private final int scoreChannel;

public RerankOperatorOutputBuilder(DoubleBlock.Builder scoreBlockBuilder, Page inputPage, int scoreChannel) {
RerankOperatorOutputBuilder(DoubleBlock.Builder scoreBlockBuilder, Page inputPage, int scoreChannel) {
this.inputPage = inputPage;
this.scoreBlockBuilder = scoreBlockBuilder;
this.scoreChannel = scoreChannel;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,14 @@
* <p>This iterator reads from a {@link BytesRefBlock} containing input documents or items to be reranked. It slices the input into batches
* of configurable size and converts each batch into an {@link InferenceAction.Request} with the task type {@link TaskType#RERANK}.
*/
public class RerankOperatorRequestIterator implements BulkInferenceRequestIterator {
class RerankOperatorRequestIterator implements BulkInferenceRequestIterator {
private final BytesRefBlock inputBlock;
private final String inferenceId;
private final String queryText;
private final int batchSize;
private int remainingPositions;

public RerankOperatorRequestIterator(BytesRefBlock inputBlock, String inferenceId, String queryText, int batchSize) {
RerankOperatorRequestIterator(BytesRefBlock inputBlock, String inferenceId, String queryText, int batchSize) {
this.inputBlock = inputBlock;
this.inferenceId = inferenceId;
this.queryText = queryText;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.esql.inference.textembedding;

import org.elasticsearch.compute.data.BytesRefBlock;
import org.elasticsearch.compute.data.FloatBlock;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.operator.DriverContext;
import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator;
import org.elasticsearch.compute.operator.Operator;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.xpack.esql.inference.InferenceOperator;
import org.elasticsearch.xpack.esql.inference.InferenceService;
import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRequestIterator;
import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRunner;
import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRunnerConfig;

/**
* {@link TextEmbeddingOperator} is an {@link InferenceOperator} that performs text embedding inference.
* It evaluates a text expression for each input row, constructs text embedding inference requests,
* and emits the dense vector embeddings as output.
*/
public class TextEmbeddingOperator extends InferenceOperator {

private final ExpressionEvaluator textEvaluator;

public TextEmbeddingOperator(
DriverContext driverContext,
BulkInferenceRunner bulkInferenceRunner,
String inferenceId,
ExpressionEvaluator textEvaluator,
int maxOutstandingPages
) {
super(driverContext, bulkInferenceRunner, inferenceId, maxOutstandingPages);
this.textEvaluator = textEvaluator;
}

@Override
protected void doClose() {
Releasables.close(textEvaluator);
}

@Override
public String toString() {
return "TextEmbeddingOperator[inference_id=[" + inferenceId() + "]]";
}

/**
* Constructs the text embedding inference requests iterator for the given input page by evaluating the text expression.
*
* @param inputPage The input data page.
*/
@Override
protected BulkInferenceRequestIterator requests(Page inputPage) {
return new TextEmbeddingOperatorRequestIterator((BytesRefBlock) textEvaluator.eval(inputPage), inferenceId());
}

/**
* Creates a new {@link TextEmbeddingOperatorOutputBuilder} to collect and emit the text embedding results.
*
* @param input The input page for which results will be constructed.
*/
@Override
protected TextEmbeddingOperatorOutputBuilder outputBuilder(Page input) {
FloatBlock.Builder outputBlockBuilder = blockFactory().newFloatBlockBuilder(input.getPositionCount());
return new TextEmbeddingOperatorOutputBuilder(outputBlockBuilder, input);
}

/**
* Factory for creating {@link TextEmbeddingOperator} instances.
*/
public record Factory(InferenceService inferenceService, String inferenceId, ExpressionEvaluator.Factory textEvaluatorFactory)
implements
OperatorFactory {
@Override
public String describe() {
return "TextEmbeddingOperator[inference_id=[" + inferenceId + "]]";
}

@Override
public Operator get(DriverContext driverContext) {
return new TextEmbeddingOperator(
driverContext,
inferenceService.bulkInferenceRunner(),
inferenceId,
textEvaluatorFactory.get(driverContext),
BulkInferenceRunnerConfig.DEFAULT.maxOutstandingBulkRequests()
);
}
}
}
Loading