Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
987a709
Adding the TEXT_EMBEDDING_FUNCTION capability.
Sep 19, 2025
516a0b6
Add InferenceFunction and TextEmbedding classes for TEXT_EMBEDDING fu…
Sep 19, 2025
36df7cf
Adding tests for the TextEmbedding function.
Sep 19, 2025
918bdb7
Update ESQL usage tests
Sep 19, 2025
4418a32
Add text_embedding to the EsqlFunctionRegistry
Sep 19, 2025
4bb147d
Add text_embedding tests generated doc
Sep 19, 2025
ddf3db5
InferenceResolver can now resolve inference ids used in a logical pla…
Sep 19, 2025
aadf880
Analyzer now resolve inference endpoints for inference function.
Sep 19, 2025
6fb48b0
[CI] Auto commit changes from spotless
Sep 19, 2025
2c423fb
Apply suggestions from code review
afoucret Sep 19, 2025
9406d37
Apply suggestion from review.
Sep 19, 2025
774986f
TextEmbedding accepts only keyword parameters.
Sep 19, 2025
39b4323
Merge branch 'main' of https://github.com/elastic/elasticsearch into …
Sep 19, 2025
8d4a832
Update TextEmbeddingTests supported parameters data types.
Sep 19, 2025
71bae69
Fix text embedding type validation.
Sep 19, 2025
1275582
Add a dummy example (waiting for real CSV tests to be implemented)
Sep 19, 2025
89dfcec
Add a dummy example (waiting for real CSV tests to be implemented)
Sep 19, 2025
39919fa
Fix breaking release tests.
Sep 19, 2025
b7e821d
Made the code more readable.
Sep 19, 2025
97ecf80
Filter out TEXT_EMBEDDING FROM CSV TESTS
Sep 19, 2025
d5cf81c
Fixing CSV tests
Sep 19, 2025
778d5e7
Merge branch 'main' of https://github.com/elastic/elasticsearch into …
Sep 24, 2025
0d01b5e
Fix typo
Sep 24, 2025
e8ca515
Make TextEmbedding not serializable.
Sep 24, 2025
5b56232
[CI] Auto commit changes from spotless
Sep 24, 2025
128688b
Remove failing tests as it is failing.
Sep 24, 2025
a8b99c4
Remove start import
Sep 24, 2025
a0cc965
Remove test
Sep 24, 2025
ca85b86
Merge branch 'main' into esql_text_embedding_function_definition
afoucret Sep 24, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
placeholder
required_capability: text_embedding_function
required_capability: not_existing_capability

// tag::embedding-eval[]
ROW input="Who is Victor Hugo?"
| EVAL embedding = TEXT_EMBEDDING("Who is Victor Hugo?", "test_dense_inference")
;
// end::embedding-eval[]


input:keyword | embedding:dense_vector
Who is Victor Hugo? | [56.0, 50.0, 48.0]
;

Original file line number Diff line number Diff line change
Expand Up @@ -1319,6 +1319,11 @@ public enum Cap {
*/
KNN_FUNCTION_V5(Build.current().isSnapshot()),

/**
* Support for the {@code TEXT_EMBEDDING} function for generating dense vector embeddings.
*/
TEXT_EMBEDDING_FUNCTION(Build.current().isSnapshot()),

/**
* Support for the LIKE operator with a list of wildcards.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@
import org.elasticsearch.xpack.esql.expression.function.aggregate.SummationMode;
import org.elasticsearch.xpack.esql.expression.function.aggregate.Values;
import org.elasticsearch.xpack.esql.expression.function.grouping.GroupingFunction;
import org.elasticsearch.xpack.esql.expression.function.inference.InferenceFunction;
import org.elasticsearch.xpack.esql.expression.function.scalar.EsqlScalarFunction;
import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Case;
import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Greatest;
Expand Down Expand Up @@ -1419,7 +1420,8 @@ private static class ResolveInference extends ParameterizedRule<LogicalPlan, Log

@Override
public LogicalPlan apply(LogicalPlan plan, AnalyzerContext context) {
return plan.transformDown(InferencePlan.class, p -> resolveInferencePlan(p, context));
return plan.transformDown(InferencePlan.class, p -> resolveInferencePlan(p, context))
.transformExpressionsOnly(InferenceFunction.class, f -> resolveInferenceFunction(f, context));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Non-blocking question for my ES|QL education: Can you help me understand why this change?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

plan.transformDown(InferencePlan.class, p -> resolveInferencePlan(p, context))

will replace every node in the plan that is an InferencePlan with the result of the resolveInferencePlan(p, context) call. Inference plans are typically commands that use inference: RERANK and COMPLETION.

.transformExpressionsOnly(InferenceFunction.class, f -> resolveInferenceFunction(f, context));

will do the same, but for InferenceFunction expressions instead of plans. Because text embedding is our first inference function, this was not yet done, so I added it.

}

private LogicalPlan resolveInferencePlan(InferencePlan<?> plan, AnalyzerContext context) {
Expand Down Expand Up @@ -1448,6 +1450,36 @@ private LogicalPlan resolveInferencePlan(InferencePlan<?> plan, AnalyzerContext

return plan;
}

/**
 * Checks the inference endpoint referenced by an inference function against the endpoints resolved for the query.
 * <p>
 * The function is returned untouched when its inference id is not yet resolved, is not foldable or is not a string,
 * and also when the endpoint is known and has the task type the function asks for. Otherwise a copy of the function
 * carrying a resolution error is returned.
 *
 * @param inferenceFunction the function whose inference id should be checked
 * @param context           the analyzer context providing the inference resolution and the function registry
 * @return the original function, or a copy flagged with an inference resolution error
 */
private InferenceFunction<?> resolveInferenceFunction(InferenceFunction<?> inferenceFunction, AnalyzerContext context) {
    boolean idIsConstantString = inferenceFunction.inferenceId().resolved()
        && inferenceFunction.inferenceId().foldable()
        && DataType.isString(inferenceFunction.inferenceId().dataType());
    if (idIsConstantString == false) {
        // Nothing can be looked up yet; leave the function as it is.
        return inferenceFunction;
    }

    String endpointId = BytesRefs.toString(inferenceFunction.inferenceId().fold(FoldContext.small()));
    ResolvedInference endpoint = context.inferenceResolution().getResolvedInference(endpointId);

    if (endpoint == null) {
        // Unknown endpoint: surface the error recorded during inference resolution.
        return inferenceFunction.withInferenceResolutionError(endpointId, context.inferenceResolution().getError(endpointId));
    }

    if (endpoint.taskType() == inferenceFunction.taskType()) {
        return inferenceFunction;
    }

    // The endpoint exists but serves a different task type than the one this function needs.
    String functionName = context.functionRegistry().snapshotRegistry().functionName(inferenceFunction.getClass());
    String mismatch = "cannot use inference endpoint ["
        + endpointId
        + "] with task type ["
        + endpoint.taskType()
        + "] within a "
        + functionName
        + " function. Only inference endpoints with the task type ["
        + inferenceFunction.taskType()
        + "] are supported.";
    return inferenceFunction.withInferenceResolutionError(endpointId, mismatch);
}
}

private static class AddImplicitLimit extends ParameterizedRule<LogicalPlan, LogicalPlan, AnalyzerContext> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
import org.elasticsearch.xpack.esql.expression.function.grouping.Bucket;
import org.elasticsearch.xpack.esql.expression.function.grouping.Categorize;
import org.elasticsearch.xpack.esql.expression.function.grouping.TBucket;
import org.elasticsearch.xpack.esql.expression.function.inference.TextEmbedding;
import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Case;
import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Greatest;
import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Least;
Expand Down Expand Up @@ -543,7 +544,8 @@ private static FunctionDefinition[][] snapshotFunctions() {
def(Hamming.class, Hamming::new, "v_hamming"),
def(UrlEncode.class, UrlEncode::new, "url_encode"),
def(UrlEncodeComponent.class, UrlEncodeComponent::new, "url_encode_component"),
def(UrlDecode.class, UrlDecode::new, "url_decode") } };
def(UrlDecode.class, UrlDecode::new, "url_decode"),
def(TextEmbedding.class, bi(TextEmbedding::new), "text_embedding") } };
}

public EsqlFunctionRegistry snapshotRegistry() {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.esql.expression.function.inference;

import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.function.Function;
import org.elasticsearch.xpack.esql.core.tree.Source;

import java.util.List;

/**
 * Common parent for ESQL functions that rely on an inference endpoint (for instance TEXT_EMBEDDING).
 *
 * @param <PlanType> the concrete function type, used as the return type of
 *                   {@link #withInferenceResolutionError(String, String)}
 */
public abstract class InferenceFunction<PlanType extends InferenceFunction<PlanType>> extends Function {

    public static final String INFERENCE_ID_PARAMETER_NAME = "inference_id";

    protected InferenceFunction(Source source, List<Expression> children) {
        super(source, children);
    }

    /**
     * @return the expression that identifies the inference endpoint
     */
    public abstract Expression inferenceId();

    /**
     * @return the task type this function requires (for instance TEXT_EMBEDDING)
     */
    public abstract TaskType taskType();

    /**
     * Creates a copy of this function that carries an inference resolution error to be shown to the user.
     *
     * @param inferenceId the inference endpoint id that could not be used
     * @param error       the error message
     * @return the copy carrying the error
     */
    public abstract PlanType withInferenceResolutionError(String inferenceId, String error);

    /**
     * @return {@code true} if an inference function other than this one is matched by {@code anyMatch}
     */
    public boolean hasNestedInferenceFunction() {
        return anyMatch(node -> node != this && node instanceof InferenceFunction);
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.esql.expression.function.inference;

import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.UnresolvedAttribute;
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.core.type.DataType;
import org.elasticsearch.xpack.esql.expression.function.Example;
import org.elasticsearch.xpack.esql.expression.function.FunctionAppliesTo;
import org.elasticsearch.xpack.esql.expression.function.FunctionAppliesToLifecycle;
import org.elasticsearch.xpack.esql.expression.function.FunctionInfo;
import org.elasticsearch.xpack.esql.expression.function.Param;

import java.io.IOException;
import java.util.List;
import java.util.Objects;

import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.FIRST;
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.SECOND;
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isFoldable;
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isNotNull;
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isType;

/**
* TEXT_EMBEDDING function converts text to dense vector embeddings using an inference endpoint.
*/
public class TextEmbedding extends InferenceFunction<TextEmbedding> {

private final Expression inferenceId;
private final Expression inputText;

@FunctionInfo(
returnType = "dense_vector",
description = "Generates dense vector embeddings for text using a specified inference endpoint.",
appliesTo = { @FunctionAppliesTo(lifeCycle = FunctionAppliesToLifecycle.DEVELOPMENT) },
preview = true,
examples = {
@Example(
description = "Generate text embeddings using the 'test_dense_inference' inference endpoint.",
file = "text-embedding",
tag = "embedding-eval"
) }
)
public TextEmbedding(
Source source,
@Param(name = "text", type = { "keyword" }, description = "Text to generate embeddings from") Expression inputText,
@Param(
name = InferenceFunction.INFERENCE_ID_PARAMETER_NAME,
type = { "keyword" },
description = "Identifier of the inference endpoint"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we need to add a function example and then regen the docs
otherwise Kibana CI will fail when we try to bring the json specification of this function to Kibana even if it's under snapshot.
this happened recently for decay too #134705 - opened a separate PR to address this #135094 since I promised to look into it.

Copy link
Contributor Author

@afoucret afoucret Sep 19, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this should be fixed in Kibana CI so it does not break in such a case.

I mean, it is expected for a function that has appliesTo set to FunctionAppliesToLifecycle.DEVELOPMENT to be incomplete. In particular, examples usually come last, when the CSV tests are written.

Anyway, I added a placeholder example so that I do not break anything. It will be replaced when more realistic CSV tests are added.

) Expression inferenceId
) {
super(source, List.of(inputText, inferenceId));
this.inferenceId = inferenceId;
this.inputText = inputText;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we need to serialize this function - since we will always resolve it on the coordinator and replace it with its result.
We have other instances in ES|QL where we don't serialize - FORK, ROW, RENAME come to mind:
But this would probably be the first function that does not require serialization.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would prefer to keep the serialization even if the expression is not supposed to be moved between nodes.

The first reason is that Functions are supposed to implement it. If for some reason the execution infrastructure needs to move them between nodes in the future, I would not like to be the black sheep that makes things more difficult.

Also, TextEmbeddingTests extends AbstractFunctionTestCase, which expects the function to be serializable.

I consider that serialization / deserialization is simple to implement, and it is probably not beneficial to leave it unimplemented.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The first reason is because the Function are supposed to implement it.

That does not mean it should. If the text_embedding function is actually serialized and sent between nodes, that's an execution path that should never happen.

Anything that extends from LogicalPlan is also supposed to implement serialization, but as I said before there are many examples of plans where we don't serialize.
We could have made the same argument for logical plans that the serialization is simple to implement, or that it's not beneficial to let it unimplemented.
However raising an exception in the case where we attempt to serialize these plans is intentional and it captures very clearly the fact that this should never happen.

I get that making this unserializable is more work, especially for AbstractFunctionTestCase, but at least then the behaviour would be correct.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It feels to me that we are removing some code that we will ultimately reintroduce when we support text embeddings with non-constant input. Anyway, I pushed a version without serialization so we can move forward and merge this PR.

throw new UnsupportedOperationException("doesn't escape the node");
}

/**
 * Not supported: this function has no writeable name because it is not meant to be serialized.
 *
 * @throws UnsupportedOperationException always
 */
@Override
public String getWriteableName() {
throw new UnsupportedOperationException("doesn't escape the node");
}

/**
 * @return the expression holding the text to generate embeddings from (first argument of the function)
 */
public Expression inputText() {
    return this.inputText;
}

/**
 * @return the expression identifying the inference endpoint (second argument of the function)
 */
@Override
public Expression inferenceId() {
    return this.inferenceId;
}

/**
 * @return {@code true} only when both the inference id and the input text are foldable
 */
@Override
public boolean foldable() {
    if (inferenceId.foldable() == false) {
        return false;
    }
    return inputText.foldable();
}

/**
 * @return {@link DataType#DENSE_VECTOR}, the fixed result type of this function
 */
@Override
public DataType dataType() {
return DataType.DENSE_VECTOR;
}

/**
 * Validates the two arguments of the function.
 * <p>
 * Resolution fails immediately when the children are not resolved. Otherwise the text argument goes through the
 * {@code isNotNull}, {@code isFoldable} and {@code isType} (keyword only) checks, and a failure there is returned
 * before the inference id is looked at. The inference id then goes through {@code isNotNull}, {@code isType}
 * (keyword only) and {@code isFoldable}.
 * <p>
 * NOTE(review): the accepted type is labelled "string" in both checks although only {@code KEYWORD} is accepted -
 * confirm this is the intended wording.
 *
 * @return the failed resolution of the first argument that does not validate, or
 *         {@link TypeResolution#TYPE_RESOLVED} when both arguments validate
 */
@Override
protected TypeResolution resolveType() {
    if (childrenResolved() == false) {
        return new TypeResolution("Unresolved children");
    }

    // First argument: the text to embed.
    TypeResolution textCheck = isNotNull(inputText, sourceText(), FIRST);
    textCheck = textCheck.and(isFoldable(inputText, sourceText(), FIRST));
    textCheck = textCheck.and(isType(inputText, DataType.KEYWORD::equals, sourceText(), FIRST, "string"));
    if (textCheck.unresolved()) {
        return textCheck;
    }

    // Second argument: the inference endpoint id.
    TypeResolution idCheck = isNotNull(inferenceId, sourceText(), SECOND);
    idCheck = idCheck.and(isType(inferenceId, DataType.KEYWORD::equals, sourceText(), SECOND, "string"));
    idCheck = idCheck.and(isFoldable(inferenceId, sourceText(), SECOND));
    if (idCheck.unresolved()) {
        return idCheck;
    }

    return TypeResolution.TYPE_RESOLVED;
}

/**
 * @return {@link TaskType#TEXT_EMBEDDING}, the task type required by this function
 */
@Override
public TaskType taskType() {
return TaskType.TEXT_EMBEDDING;
}

/**
 * Creates a copy of this function whose inference id is replaced by an {@link UnresolvedAttribute} carrying the
 * given error. The attribute reuses the source of the current inference id expression.
 *
 * @param inferenceId the inference endpoint id, used as the name of the unresolved attribute
 * @param error       the message attached to the unresolved attribute
 * @return a new {@code TextEmbedding} with the same source and input text
 */
@Override
public TextEmbedding withInferenceResolutionError(String inferenceId, String error) {
    // The parameter shadows the field, so the current expression is reached through the accessor.
    Source inferenceIdSource = inferenceId().source();
    UnresolvedAttribute unresolvedInferenceId = new UnresolvedAttribute(inferenceIdSource, inferenceId, error);
    return new TextEmbedding(source(), inputText, unresolvedInferenceId);
}

/**
 * Rebuilds this function from a new list of children.
 *
 * @param newChildren the replacement children: the input text at index 0 and the inference id at index 1
 * @return a new {@code TextEmbedding} with the same source and the given children
 */
@Override
public Expression replaceChildren(List<Expression> newChildren) {
    Expression newInputText = newChildren.get(0);
    Expression newInferenceId = newChildren.get(1);
    return new TextEmbedding(source(), newInputText, newInferenceId);
}

/**
 * Builds the {@link NodeInfo} of this node from the {@code TextEmbedding} constructor and its two arguments,
 * the input text followed by the inference id.
 */
@Override
protected NodeInfo<? extends Expression> info() {
return NodeInfo.create(this, TextEmbedding::new, inputText, inferenceId);
}

/**
 * @return a textual form of the call: {@code TEXT_EMBEDDING(<inputText>, <inferenceId>)}
 */
@Override
public String toString() {
    StringBuilder text = new StringBuilder("TEXT_EMBEDDING(");
    text.append(inputText);
    text.append(", ");
    text.append(inferenceId);
    text.append(")");
    return text.toString();
}

/**
 * Two {@code TextEmbedding} instances are equal when they have exactly the same class, the parent class considers
 * them equal, and both their inference id and input text are equal.
 */
@Override
public boolean equals(Object o) {
    boolean sameClass = o != null && getClass() == o.getClass();
    if (sameClass == false || super.equals(o) == false) {
        return false;
    }
    TextEmbedding that = (TextEmbedding) o;
    if (Objects.equals(inferenceId, that.inferenceId) == false) {
        return false;
    }
    return Objects.equals(inputText, that.inputText);
}

/**
 * Combines the parent hash code with the hashes of the inference id and the input text, in that order, using the
 * same 31-based fold as {@code Objects.hash}.
 */
@Override
public int hashCode() {
    int result = 31 + super.hashCode();
    result = 31 * result + Objects.hashCode(inferenceId);
    result = 31 * result + Objects.hashCode(inputText);
    return result;
}
}
Loading