diff --git a/docs/reference/query-languages/esql/images/functions/embed_text.svg b/docs/reference/query-languages/esql/images/functions/embed_text.svg
new file mode 100644
index 0000000000000..9bb6cab692c4e
--- /dev/null
+++ b/docs/reference/query-languages/esql/images/functions/embed_text.svg
@@ -0,0 +1 @@
+
\ No newline at end of file
diff --git a/docs/reference/query-languages/esql/kibana/definition/functions/embed_text.json b/docs/reference/query-languages/esql/kibana/definition/functions/embed_text.json
new file mode 100644
index 0000000000000..edfafb213e16d
--- /dev/null
+++ b/docs/reference/query-languages/esql/kibana/definition/functions/embed_text.json
@@ -0,0 +1,9 @@
+{
+ "comment" : "This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it.",
+ "type" : "scalar",
+ "name" : "embed_text",
+ "description" : "Generates dense vector embeddings for text using a specified inference deployment.",
+ "signatures" : [ ],
+ "preview" : true,
+ "snapshot_only" : true
+}
diff --git a/docs/reference/query-languages/esql/kibana/docs/functions/embed_text.md b/docs/reference/query-languages/esql/kibana/docs/functions/embed_text.md
new file mode 100644
index 0000000000000..fc6d8d0772d64
--- /dev/null
+++ b/docs/reference/query-languages/esql/kibana/docs/functions/embed_text.md
@@ -0,0 +1,4 @@
+% This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it.
+
+### EMBED TEXT
+Generates dense vector embeddings for text using a specified inference deployment.
diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/EsqlTestUtils.java b/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/EsqlTestUtils.java
index e8acabe71ab41..ff1d078e022a5 100644
--- a/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/EsqlTestUtils.java
+++ b/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/EsqlTestUtils.java
@@ -76,7 +76,8 @@
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.NotEquals;
import org.elasticsearch.xpack.esql.index.EsIndex;
import org.elasticsearch.xpack.esql.inference.InferenceResolution;
-import org.elasticsearch.xpack.esql.inference.InferenceRunner;
+import org.elasticsearch.xpack.esql.inference.InferenceResolver;
+import org.elasticsearch.xpack.esql.inference.InferenceServices;
import org.elasticsearch.xpack.esql.optimizer.LogicalOptimizerContext;
import org.elasticsearch.xpack.esql.parser.QueryParam;
import org.elasticsearch.xpack.esql.plan.logical.Enrich;
@@ -164,6 +165,7 @@
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
public final class EsqlTestUtils {
@@ -401,18 +403,26 @@ public static LogicalOptimizerContext unboundLogicalOptimizerContext() {
mock(ProjectResolver.class),
mock(IndexNameExpressionResolver.class),
null,
- mockInferenceRunner()
+ mockInferenceServices()
);
+ private static InferenceServices mockInferenceServices() {
+ InferenceServices inferenceServices = mock(InferenceServices.class);
+ InferenceResolver inferenceResolver = mockInferenceRunner();
+ when(inferenceServices.inferenceResolver(any())).thenReturn(inferenceResolver);
+
+ return inferenceServices;
+ }
+
@SuppressWarnings("unchecked")
- private static InferenceRunner mockInferenceRunner() {
- InferenceRunner inferenceRunner = mock(InferenceRunner.class);
+ private static InferenceResolver mockInferenceRunner() {
+ InferenceResolver inferenceResolver = mock(InferenceResolver.class);
doAnswer(i -> {
i.getArgument(1, ActionListener.class).onResponse(emptyInferenceResolution());
return null;
- }).when(inferenceRunner).resolveInferenceIds(any(), any());
+ }).when(inferenceResolver).resolveInferenceIds(any(), any());
- return inferenceRunner;
+ return inferenceResolver;
}
private EsqlTestUtils() {}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java
index 0b51241beebbd..153abbfbf2bb6 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java
@@ -1245,6 +1245,11 @@ public enum Cap {
*/
AGGREGATE_METRIC_DOUBLE_AVG(AGGREGATE_METRIC_DOUBLE_FEATURE_FLAG),
+ /**
+ * Support for the {@code EMBED_TEXT} function for generating dense vector embeddings.
+ */
+ EMBED_TEXT_FUNCTION(Build.current().isSnapshot()),
+
/**
* Forbid usage of brackets in unquoted index and enrich policy names
* https://github.com/elastic/elasticsearch/issues/130378
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java
index e4b8949af5bdb..f91305d471727 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java
@@ -54,6 +54,7 @@
import org.elasticsearch.xpack.esql.expression.function.UnresolvedFunction;
import org.elasticsearch.xpack.esql.expression.function.UnsupportedAttribute;
import org.elasticsearch.xpack.esql.expression.function.grouping.GroupingFunction;
+import org.elasticsearch.xpack.esql.expression.function.inference.InferenceFunction;
import org.elasticsearch.xpack.esql.expression.function.scalar.EsqlScalarFunction;
import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Case;
import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Greatest;
@@ -173,9 +174,9 @@ public class Analyzer extends ParameterizedRuleExecutor(
@@ -398,34 +399,6 @@ private static NamedExpression createEnrichFieldExpression(
}
}
- private static class ResolveInference extends ParameterizedAnalyzerRule, AnalyzerContext> {
- @Override
- protected LogicalPlan rule(InferencePlan> plan, AnalyzerContext context) {
- assert plan.inferenceId().resolved() && plan.inferenceId().foldable();
-
- String inferenceId = BytesRefs.toString(plan.inferenceId().fold(FoldContext.small()));
- ResolvedInference resolvedInference = context.inferenceResolution().getResolvedInference(inferenceId);
-
- if (resolvedInference != null && resolvedInference.taskType() == plan.taskType()) {
- return plan;
- } else if (resolvedInference != null) {
- String error = "cannot use inference endpoint ["
- + inferenceId
- + "] with task type ["
- + resolvedInference.taskType()
- + "] within a "
- + plan.nodeName()
- + " command. Only inference endpoints with the task type ["
- + plan.taskType()
- + "] are supported.";
- return plan.withInferenceResolutionError(inferenceId, error);
- } else {
- String error = context.inferenceResolution().getError(inferenceId);
- return plan.withInferenceResolutionError(inferenceId, error);
- }
- }
- }
-
private static class ResolveLookupTables extends ParameterizedAnalyzerRule {
@Override
@@ -1319,6 +1292,70 @@ public static org.elasticsearch.xpack.esql.core.expression.function.Function res
}
}
+ private static class ResolveInference extends ParameterizedRule {
+
+ @Override
+ public LogicalPlan apply(LogicalPlan plan, AnalyzerContext context) {
+ return plan.transformDown(InferencePlan.class, p -> resolveInferencePlan(p, context))
+ .transformExpressionsOnly(InferenceFunction.class, f -> resolveInferenceFunction(f, context));
+
+ }
+
+ private InferenceFunction> resolveInferenceFunction(InferenceFunction> inferenceFunction, AnalyzerContext context) {
+ assert inferenceFunction.inferenceId().resolved() && inferenceFunction.inferenceId().foldable();
+
+ String inferenceId = BytesRefs.toString(inferenceFunction.inferenceId().fold(FoldContext.small()));
+ ResolvedInference resolvedInference = context.inferenceResolution().getResolvedInference(inferenceId);
+
+ if (resolvedInference == null) {
+ String error = context.inferenceResolution().getError(inferenceId);
+ return inferenceFunction.withInferenceResolutionError(inferenceId, error);
+ }
+
+ if (resolvedInference.taskType() != inferenceFunction.taskType()) {
+ String error = "cannot use inference endpoint ["
+ + inferenceId
+ + "] with task type ["
+ + resolvedInference.taskType()
+ + "] within a "
+ + context.functionRegistry().snapshotRegistry().functionName(inferenceFunction.getClass())
+ + " function. Only inference endpoints with the task type ["
+ + inferenceFunction.taskType()
+ + "] are supported.";
+ return inferenceFunction.withInferenceResolutionError(inferenceId, error);
+ }
+
+ return inferenceFunction;
+ }
+
+ private LogicalPlan resolveInferencePlan(InferencePlan> plan, AnalyzerContext context) {
+ assert plan.inferenceId().resolved() && plan.inferenceId().foldable();
+
+ String inferenceId = BytesRefs.toString(plan.inferenceId().fold(FoldContext.small()));
+ ResolvedInference resolvedInference = context.inferenceResolution().getResolvedInference(inferenceId);
+
+ if (resolvedInference == null) {
+ String error = context.inferenceResolution().getError(inferenceId);
+ return plan.withInferenceResolutionError(inferenceId, error);
+ }
+
+ if (resolvedInference.taskType() != plan.taskType()) {
+ String error = "cannot use inference endpoint ["
+ + inferenceId
+ + "] with task type ["
+ + resolvedInference.taskType()
+ + "] within a "
+ + plan.nodeName()
+ + " command. Only inference endpoints with the task type ["
+ + plan.taskType()
+ + "] are supported.";
+ return plan.withInferenceResolutionError(inferenceId, error);
+ }
+
+ return plan;
+ }
+ }
+
private static class AddImplicitLimit extends ParameterizedRule {
@Override
public LogicalPlan apply(LogicalPlan logicalPlan, AnalyzerContext context) {
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/PreAnalyzer.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/PreAnalyzer.java
index 5b9f41876d6e1..958f3cc9d946f 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/PreAnalyzer.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/PreAnalyzer.java
@@ -13,7 +13,6 @@
import org.elasticsearch.xpack.esql.plan.logical.Enrich;
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
import org.elasticsearch.xpack.esql.plan.logical.UnresolvedRelation;
-import org.elasticsearch.xpack.esql.plan.logical.inference.InferencePlan;
import java.util.ArrayList;
import java.util.HashSet;
@@ -33,38 +32,39 @@ public static class PreAnalysis {
public final IndexMode indexMode;
public final List indices;
public final List enriches;
- public final List> inferencePlans;
+ public final List inferenceIds;
public final List lookupIndices;
public PreAnalysis(
IndexMode indexMode,
List indices,
List enriches,
- List> inferencePlans,
+ List inferenceIds,
List lookupIndices
) {
this.indexMode = indexMode;
this.indices = indices;
this.enriches = enriches;
- this.inferencePlans = inferencePlans;
+ this.inferenceIds = inferenceIds;
this.lookupIndices = lookupIndices;
}
}
- public PreAnalysis preAnalyze(LogicalPlan plan) {
+ public PreAnalysis preAnalyze(LogicalPlan plan, PreAnalyzerContext context) {
if (plan.analyzed()) {
return PreAnalysis.EMPTY;
}
- return doPreAnalyze(plan);
+ return doPreAnalyze(plan, context);
}
- protected PreAnalysis doPreAnalyze(LogicalPlan plan) {
+ protected PreAnalysis doPreAnalyze(LogicalPlan plan, PreAnalyzerContext context) {
Set indices = new HashSet<>();
List unresolvedEnriches = new ArrayList<>();
List lookupIndices = new ArrayList<>();
- List> unresolvedInferencePlans = new ArrayList<>();
+ Set inferenceIds = new HashSet<>();
+
Holder indexMode = new Holder<>();
plan.forEachUp(UnresolvedRelation.class, p -> {
if (p.indexMode() == IndexMode.LOOKUP) {
@@ -78,11 +78,12 @@ protected PreAnalysis doPreAnalyze(LogicalPlan plan) {
});
plan.forEachUp(Enrich.class, unresolvedEnriches::add);
- plan.forEachUp(InferencePlan.class, unresolvedInferencePlans::add);
+ context.inferenceResolver().collectInferenceIds(plan, inferenceIds::add);
// mark plan as preAnalyzed (if it were marked, there would be no analysis)
plan.forEachUp(LogicalPlan::setPreAnalyzed);
- return new PreAnalysis(indexMode.get(), indices.stream().toList(), unresolvedEnriches, unresolvedInferencePlans, lookupIndices);
+ return new PreAnalysis(indexMode.get(), indices.stream().toList(), unresolvedEnriches, List.copyOf(inferenceIds), lookupIndices);
}
+
}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/PreAnalyzerContext.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/PreAnalyzerContext.java
new file mode 100644
index 0000000000000..a2571d5384f89
--- /dev/null
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/PreAnalyzerContext.java
@@ -0,0 +1,14 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.esql.analysis;
+
+import org.elasticsearch.xpack.esql.inference.InferenceResolver;
+
+public record PreAnalyzerContext(InferenceResolver inferenceResolver) {
+
+}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/execution/PlanExecutor.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/execution/PlanExecutor.java
index 620d86650abf4..1e66ce413206d 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/execution/PlanExecutor.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/execution/PlanExecutor.java
@@ -91,7 +91,8 @@ public void esql(
verifier,
planTelemetry,
indicesExpressionGrouper,
- services
+ services,
+ services.inferenceServices().inferenceResolver(functionRegistry)
);
QueryMetric clientId = QueryMetric.fromString("rest");
metrics.total(clientId);
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/ExpressionWritables.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/ExpressionWritables.java
index a3f6d3a089d49..59b64cccd929e 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/ExpressionWritables.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/ExpressionWritables.java
@@ -13,6 +13,7 @@
import org.elasticsearch.xpack.esql.expression.function.UnsupportedAttribute;
import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateWritables;
import org.elasticsearch.xpack.esql.expression.function.fulltext.FullTextWritables;
+import org.elasticsearch.xpack.esql.expression.function.inference.EmbedText;
import org.elasticsearch.xpack.esql.expression.function.scalar.ScalarFunctionWritables;
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.FromBase64;
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToAggregateMetricDouble;
@@ -119,6 +120,7 @@ public static List getNamedWriteables() {
entries.addAll(fullText());
entries.addAll(unaryScalars());
entries.addAll(vector());
+ entries.addAll(inference());
return entries;
}
@@ -264,4 +266,11 @@ private static List vector() {
}
return List.of();
}
+
+ private static List inference() {
+ if (EsqlCapabilities.Cap.EMBED_TEXT_FUNCTION.isEnabled()) {
+ return List.of(EmbedText.ENTRY);
+ }
+ return List.of();
+ }
}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java
index 630c9c2008a13..bffd90362943f 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java
@@ -52,6 +52,7 @@
import org.elasticsearch.xpack.esql.expression.function.fulltext.Term;
import org.elasticsearch.xpack.esql.expression.function.grouping.Bucket;
import org.elasticsearch.xpack.esql.expression.function.grouping.Categorize;
+import org.elasticsearch.xpack.esql.expression.function.inference.EmbedText;
import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Case;
import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Greatest;
import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Least;
@@ -479,6 +480,7 @@ private static FunctionDefinition[][] snapshotFunctions() {
def(FirstOverTime.class, uni(FirstOverTime::new), "first_over_time"),
def(Term.class, bi(Term::new), "term"),
def(Knn.class, Knn::new, "knn"),
+ def(EmbedText.class, EmbedText::new, "embed_text"),
def(StGeohash.class, StGeohash::new, "st_geohash"),
def(StGeohashToLong.class, StGeohashToLong::new, "st_geohash_to_long"),
def(StGeohashToString.class, StGeohashToString::new, "st_geohash_to_string"),
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/EmbedText.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/EmbedText.java
new file mode 100644
index 0000000000000..cc63f7e465f48
--- /dev/null
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/EmbedText.java
@@ -0,0 +1,163 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.esql.expression.function.inference;
+
+import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.expression.UnresolvedAttribute;
+import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
+import org.elasticsearch.xpack.esql.core.tree.Source;
+import org.elasticsearch.xpack.esql.core.type.DataType;
+import org.elasticsearch.xpack.esql.expression.function.FunctionAppliesTo;
+import org.elasticsearch.xpack.esql.expression.function.FunctionAppliesToLifecycle;
+import org.elasticsearch.xpack.esql.expression.function.FunctionInfo;
+import org.elasticsearch.xpack.esql.expression.function.Param;
+import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Objects;
+
+import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.*;
+import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.FIRST;
+import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.SECOND;
+
+/**
+ * EMBED_TEXT function that generates dense vector embeddings for text using a specified inference deployment.
+ */
+public class EmbedText extends InferenceFunction {
+
+ public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(
+ Expression.class,
+ "EmbedText",
+ EmbedText::new
+ );
+
+ private final Expression inferenceId;
+ private final Expression inputText;
+
+ @FunctionInfo(
+ returnType = "dense_vector",
+ description = "Generates dense vector embeddings for text using a specified inference deployment.",
+ appliesTo = { @FunctionAppliesTo(lifeCycle = FunctionAppliesToLifecycle.DEVELOPMENT) },
+ preview = true
+ )
+ public EmbedText(
+ Source source,
+ @Param(name = "text", type = { "keyword", "text" }, description = "Text to embed") Expression inputText,
+ @Param(
+ name = InferenceFunction.INFERENCE_ID_PARAMETER_NAME,
+ type = { "keyword", "text" },
+ description = "Inference deployment ID"
+ ) Expression inferenceId
+ ) {
+ super(source, List.of(inputText, inferenceId));
+ this.inferenceId = inferenceId;
+ this.inputText = inputText;
+ }
+
+ private EmbedText(StreamInput in) throws IOException {
+ this(Source.readFrom((PlanStreamInput) in), in.readNamedWriteable(Expression.class), in.readNamedWriteable(Expression.class));
+ }
+
+ @Override
+ public void writeTo(StreamOutput out) throws IOException {
+ source().writeTo(out);
+ out.writeNamedWriteable(inputText);
+ out.writeNamedWriteable(inferenceId);
+ }
+
+ @Override
+ public String getWriteableName() {
+ return ENTRY.name;
+ }
+
+ public Expression inputText() {
+ return inputText;
+ }
+
+ @Override
+ public Expression inferenceId() {
+ return inferenceId;
+ }
+
+ @Override
+ public DataType dataType() {
+ return DataType.DENSE_VECTOR;
+ }
+
+ @Override
+ protected TypeResolution resolveType() {
+ if (childrenResolved() == false) {
+ return new TypeResolution("Unresolved children");
+ }
+
+ TypeResolution textResolution = isNotNull(inputText, sourceText(), FIRST).and(isFoldable(inputText, sourceText(), FIRST))
+ .and(isString(inputText, sourceText(), FIRST));
+
+ if (textResolution.unresolved()) {
+ return textResolution;
+ }
+
+ TypeResolution inferenceIdResolution = isNotNull(inferenceId, sourceText(), SECOND).and(isString(inferenceId, sourceText(), SECOND))
+ .and(isFoldable(inferenceId, sourceText(), SECOND));
+
+ if (inferenceIdResolution.unresolved()) {
+ return inferenceIdResolution;
+ }
+
+ return TypeResolution.TYPE_RESOLVED;
+ }
+
+ @Override
+ public boolean foldable() {
+ // The function is foldable only if both arguments are foldable
+ return inputText.foldable() && inferenceId.foldable();
+ }
+
+ @Override
+ public TaskType taskType() {
+ return TaskType.TEXT_EMBEDDING;
+ }
+
+ @Override
+ public EmbedText withInferenceResolutionError(String inferenceId, String error) {
+ return new EmbedText(source(), inputText, new UnresolvedAttribute(inferenceId().source(), inferenceId, error));
+ }
+
+ @Override
+ public Expression replaceChildren(List newChildren) {
+ return new EmbedText(source(), newChildren.get(0), newChildren.get(1));
+ }
+
+ @Override
+ protected NodeInfo extends Expression> info() {
+ return NodeInfo.create(this, EmbedText::new, inputText, inferenceId);
+ }
+
+ @Override
+ public String toString() {
+ return "EMBED_TEXT(" + inputText + ", " + inferenceId + ")";
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (o == null || getClass() != o.getClass()) return false;
+ if (super.equals(o) == false) return false;
+ EmbedText embedText = (EmbedText) o;
+ return Objects.equals(inferenceId, embedText.inferenceId) && Objects.equals(inputText, embedText.inputText);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(super.hashCode(), inferenceId, inputText);
+ }
+}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/InferenceFunction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/InferenceFunction.java
new file mode 100644
index 0000000000000..d32587dc71058
--- /dev/null
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/InferenceFunction.java
@@ -0,0 +1,39 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.esql.expression.function.inference;
+
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.expression.function.Function;
+import org.elasticsearch.xpack.esql.core.tree.Source;
+
+import java.util.List;
+
+/**
+ * A function is a function using an inference model.
+ */
+public abstract class InferenceFunction> extends Function {
+
+ public static final String INFERENCE_ID_PARAMETER_NAME = "inference_id";
+
+ protected InferenceFunction(Source source, List children) {
+ super(source, children);
+ }
+
+ /**
+ * Returns the inference model ID expression.
+ */
+ public abstract Expression inferenceId();
+
+ /**
+ * Returns the task type of the inference model.
+ */
+ public abstract TaskType taskType();
+
+ public abstract PlanType withInferenceResolutionError(String inferenceId, String error);
+}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Knn.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Knn.java
index 63026fb9d7201..d66c8fdc86680 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Knn.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Knn.java
@@ -56,7 +56,6 @@
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isFoldable;
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isMapExpression;
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isNotNull;
-import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isNotNullAndFoldable;
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isType;
import static org.elasticsearch.xpack.esql.core.type.DataType.DENSE_VECTOR;
import static org.elasticsearch.xpack.esql.core.type.DataType.FLOAT;
@@ -190,8 +189,8 @@ private TypeResolution resolveField() {
private TypeResolution resolveQuery() {
return isType(query(), dt -> dt == DENSE_VECTOR, sourceText(), TypeResolutions.ParamOrdinal.SECOND, "dense_vector").and(
- isNotNullAndFoldable(query(), sourceText(), SECOND)
- );
+ isNotNull(query(), sourceText(), SECOND)
+ ).and(isFoldable(query(), sourceText(), SECOND));
}
private TypeResolution resolveK() {
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceOperator.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceOperator.java
index fe6ab6e9a998c..54613e31212e4 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceOperator.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceOperator.java
@@ -8,6 +8,7 @@
package org.elasticsearch.xpack.esql.inference;
import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.operator.AsyncOperator;
@@ -15,7 +16,6 @@
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.inference.InferenceServiceResults;
-import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceExecutionConfig;
import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceExecutor;
@@ -26,7 +26,7 @@
import static org.elasticsearch.common.logging.LoggerMessageFormat.format;
/**
- * An abstract asynchronous operator that performs throttled bulk inference execution using an {@link InferenceRunner}.
+ * An abstract asynchronous operator that performs throttled bulk inference execution using an {@link InferenceResolver}.
*
+ * The {@code InferenceOperator} integrates with the compute framework and supports throttled bulk execution of inference requests. It
* transforms input {@link Page} into inference requests, asynchronously executes them, and converts the responses into a new {@link Page}.
@@ -41,21 +41,21 @@ public abstract class InferenceOperator extends AsyncOperator
+ * This method traverses the logical plan tree and identifies all inference operations,
+ * extracting their deployment IDs for subsequent validation. Currently supports:
+ *
+ *
+ * @param plan The logical plan to scan for inference operations
+ * @param c Consumer function to receive each discovered inference ID
+ */
+ public void collectInferenceIds(LogicalPlan plan, Consumer c) {
+ collectInferenceIdsFromInferencePlans(plan, c);
+ collectInferenceIdsFromInferenceFunctions(plan, c);
+ }
+
+ /**
+ * Resolves a list of inference deployment IDs to their metadata.
+ *
+ * For each inference ID, this method:
+ *
+ *
Queries the inference service to verify the deployment exists
+ *
Retrieves the deployment's task type and configuration
+ *
Builds an {@link InferenceResolution} containing resolved metadata or errors
+ *
+ *
+ * @param inferenceIds List of inference deployment IDs to resolve
+ * @param listener Callback to receive the resolution results
+ */
+ public void resolveInferenceIds(List inferenceIds, ActionListener listener) {
+ resolveInferenceIds(Set.copyOf(inferenceIds), listener);
+ }
+
+ private void resolveInferenceIds(Set inferenceIds, ActionListener listener) {
+
+ if (inferenceIds.isEmpty()) {
+ listener.onResponse(InferenceResolution.EMPTY);
+ return;
+ }
+
+ final InferenceResolution.Builder inferenceResolutionBuilder = InferenceResolution.builder();
+
+ final CountDownActionListener countdownListener = new CountDownActionListener(
+ inferenceIds.size(),
+ ActionListener.wrap(_r -> listener.onResponse(inferenceResolutionBuilder.build()), listener::onFailure)
+ );
+
+ for (var inferenceId : inferenceIds) {
+ client.execute(
+ GetInferenceModelAction.INSTANCE,
+ new GetInferenceModelAction.Request(inferenceId, TaskType.ANY),
+ ActionListener.wrap(r -> {
+ ResolvedInference resolvedInference = new ResolvedInference(inferenceId, r.getEndpoints().getFirst().getTaskType());
+ inferenceResolutionBuilder.withResolvedInference(resolvedInference);
+ countdownListener.onResponse(null);
+ }, e -> {
+ inferenceResolutionBuilder.withError(inferenceId, e.getMessage());
+ countdownListener.onResponse(null);
+ })
+ );
+ }
+ }
+
+ /**
+ * Collects inference IDs from inference function calls within the logical plan.
+ *
+ * This method scans the logical plan for {@link UnresolvedFunction} instances that represent
+ * inference functions (e.g., EMBED_TEXT). For each inference function found:
+ *
+ *
Resolves the function definition through the function registry and checks if the function implements {@link InferenceFunction}
+ *
Extracts the inference deployment ID from the function arguments
+ *
+ *
+ * This operates during pre-analysis when functions are still unresolved, allowing early
+ * validation of inference deployments before query optimization.
+ *
+ * @param plan The logical plan to scan for inference function calls
+ * @param c Consumer function to receive each discovered inference ID
+ */
+ private void collectInferenceIdsFromInferenceFunctions(LogicalPlan plan, Consumer c) {
+ EsqlFunctionRegistry snapshotRegistry = functionRegistry.snapshotRegistry();
+ plan.forEachExpressionUp(UnresolvedFunction.class, f -> {
+ String functionName = snapshotRegistry.resolveAlias(f.name());
+ if (snapshotRegistry.functionExists(functionName)) {
+ FunctionDefinition def = snapshotRegistry.resolveFunction(functionName);
+ if (InferenceFunction.class.isAssignableFrom(def.clazz())) {
+ String inferenceId = inferenceId(f, def);
+ if (inferenceId != null) {
+ c.accept(inferenceId);
+ }
+ }
+ }
+ });
+ }
+
+ /**
+ * Collects inference IDs from InferencePlan objects within the logical plan.
+ *
+ * @param plan The logical plan to scan for InferencePlan objects
+ * @param c Consumer function to receive each discovered inference ID
+ */
+ private void collectInferenceIdsFromInferencePlans(LogicalPlan plan, Consumer c) {
+ plan.forEachUp(InferencePlan.class, inferencePlan -> c.accept(inferenceId(inferencePlan)));
+ }
+
+ /**
+ * Extracts the inference ID from an InferencePlan object.
+ *
+ * @param plan The InferencePlan object to extract the ID from
+ * @return The inference ID as a string
+ */
+ private static String inferenceId(InferencePlan> plan) {
+ return BytesRefs.toString(plan.inferenceId().fold(FoldContext.small()));
+ }
+
+ private static String inferenceId(Expression e) {
+ return BytesRefs.toString(e.fold(FoldContext.small()));
+ }
+
+ /**
+ * Extracts the inference ID from an {@link UnresolvedFunction} instance.
+ *
+ * This method inspects the function's arguments to find the inference ID.
+ * Currently, it only supports positional parameters named "inference_id".
+ *
+ * @param f The unresolved function to extract the ID from
+ * @param def The function definition
+ * @return The inference ID as a string, or null if not found
+ */
+ public String inferenceId(UnresolvedFunction f, FunctionDefinition def) {
+ EsqlFunctionRegistry.FunctionDescription functionDescription = EsqlFunctionRegistry.description(def);
+
+ for (int i = 0; i < functionDescription.args().size(); i++) {
+ EsqlFunctionRegistry.ArgSignature arg = functionDescription.args().get(i);
+
+ if (arg.name().equals(InferenceFunction.INFERENCE_ID_PARAMETER_NAME)) {
+ // Found a positional parameter named "inference_id", so use its value
+ Expression argValue = f.arguments().get(i);
+ if (argValue != null && argValue.foldable()) {
+ return inferenceId(argValue);
+ }
+ }
+
+ // TODO: support inference ID as an optional named parameter
+ }
+
+ return null;
+ }
+}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceRunner.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceRunner.java
deleted file mode 100644
index d67d6817742c0..0000000000000
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceRunner.java
+++ /dev/null
@@ -1,84 +0,0 @@
-/*
- * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
- * or more contributor license agreements. Licensed under the Elastic License
- * 2.0; you may not use this file except in compliance with the Elastic License
- * 2.0.
- */
-
-package org.elasticsearch.xpack.esql.inference;
-
-import org.elasticsearch.action.ActionListener;
-import org.elasticsearch.action.support.CountDownActionListener;
-import org.elasticsearch.client.internal.Client;
-import org.elasticsearch.common.lucene.BytesRefs;
-import org.elasticsearch.inference.TaskType;
-import org.elasticsearch.threadpool.ThreadPool;
-import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction;
-import org.elasticsearch.xpack.core.inference.action.InferenceAction;
-import org.elasticsearch.xpack.esql.core.expression.FoldContext;
-import org.elasticsearch.xpack.esql.plan.logical.inference.InferencePlan;
-
-import java.util.List;
-import java.util.Set;
-import java.util.stream.Collectors;
-
-import static org.elasticsearch.xpack.core.ClientHelper.INFERENCE_ORIGIN;
-import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin;
-
-public class InferenceRunner {
-
- private final Client client;
- private final ThreadPool threadPool;
-
- public InferenceRunner(Client client, ThreadPool threadPool) {
- this.client = client;
- this.threadPool = threadPool;
- }
-
- public ThreadPool threadPool() {
- return threadPool;
- }
-
- public void resolveInferenceIds(List> plans, ActionListener listener) {
- resolveInferenceIds(plans.stream().map(InferenceRunner::planInferenceId).collect(Collectors.toSet()), listener);
-
- }
-
- private void resolveInferenceIds(Set inferenceIds, ActionListener listener) {
-
- if (inferenceIds.isEmpty()) {
- listener.onResponse(InferenceResolution.EMPTY);
- return;
- }
-
- final InferenceResolution.Builder inferenceResolutionBuilder = InferenceResolution.builder();
-
- final CountDownActionListener countdownListener = new CountDownActionListener(
- inferenceIds.size(),
- ActionListener.wrap(_r -> listener.onResponse(inferenceResolutionBuilder.build()), listener::onFailure)
- );
-
- for (var inferenceId : inferenceIds) {
- client.execute(
- GetInferenceModelAction.INSTANCE,
- new GetInferenceModelAction.Request(inferenceId, TaskType.ANY),
- ActionListener.wrap(r -> {
- ResolvedInference resolvedInference = new ResolvedInference(inferenceId, r.getEndpoints().getFirst().getTaskType());
- inferenceResolutionBuilder.withResolvedInference(resolvedInference);
- countdownListener.onResponse(null);
- }, e -> {
- inferenceResolutionBuilder.withError(inferenceId, e.getMessage());
- countdownListener.onResponse(null);
- })
- );
- }
- }
-
- private static String planInferenceId(InferencePlan> plan) {
- return BytesRefs.toString(plan.inferenceId().fold(FoldContext.small()));
- }
-
- public void doInference(InferenceAction.Request request, ActionListener listener) {
- executeAsyncWithOrigin(client, INFERENCE_ORIGIN, InferenceAction.INSTANCE, request, listener);
- }
-}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceServices.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceServices.java
new file mode 100644
index 0000000000000..34236f2fb57fe
--- /dev/null
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceServices.java
@@ -0,0 +1,36 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.esql.inference;
+
+import org.elasticsearch.client.internal.Client;
+import org.elasticsearch.threadpool.ThreadPool;
+import org.elasticsearch.xpack.esql.expression.function.EsqlFunctionRegistry;
+import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceExecutionConfig;
+import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceExecutor;
+
+public class InferenceServices {
+ private final Client client;
+ private final BulkInferenceExecutor.Factory bulkInferenceExecutorFactory;
+
+ public InferenceServices(Client client, ThreadPool threadPool) {
+ this.client = client;
+ this.bulkInferenceExecutorFactory = new BulkInferenceExecutor.Factory(client, threadPool);
+ }
+
+ public BulkInferenceExecutor bulkInferenceExecutor(BulkInferenceExecutionConfig bulkExecutionConfig) {
+ return bulkInferenceExecutorFactory.create(bulkExecutionConfig);
+ }
+
+ public BulkInferenceExecutor.Factory bulkInferenceExecutorFactory() {
+ return bulkInferenceExecutorFactory;
+ }
+
+ public InferenceResolver inferenceResolver(EsqlFunctionRegistry functionRegistry) {
+ return new InferenceResolver(functionRegistry, client);
+ }
+}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceExecutor.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceExecutor.java
index 257799962dda7..cd382df3ece9b 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceExecutor.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceExecutor.java
@@ -8,10 +8,10 @@
package org.elasticsearch.xpack.esql.inference.bulk;
import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.client.internal.Client;
import org.elasticsearch.common.util.concurrent.AbstractRunnable;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
-import org.elasticsearch.xpack.esql.inference.InferenceRunner;
import org.elasticsearch.xpack.esql.plugin.EsqlPlugin;
import java.util.ArrayList;
@@ -22,6 +22,9 @@
import java.util.concurrent.Semaphore;
import java.util.concurrent.atomic.AtomicBoolean;
+import static org.elasticsearch.xpack.core.ClientHelper.INFERENCE_ORIGIN;
+import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin;
+
/**
* Executes a sequence of inference requests in bulk with throttling and concurrency control.
*/
@@ -32,12 +35,11 @@ public class BulkInferenceExecutor {
/**
* Constructs a new {@code BulkInferenceExecutor}.
*
- * @param inferenceRunner The inference runner used to execute individual inference requests.
- * @param threadPool The thread pool for executing inference tasks.
- * @param bulkExecutionConfig Configuration options (throttling and concurrency limits).
+ * @param throttledInferenceRunner The throttled inference runner used to execute individual inference requests.
+ * @param bulkExecutionConfig Configuration options (throttling and concurrency limits).
*/
- public BulkInferenceExecutor(InferenceRunner inferenceRunner, ThreadPool threadPool, BulkInferenceExecutionConfig bulkExecutionConfig) {
- this.throttledInferenceRunner = ThrottledInferenceRunner.create(inferenceRunner, executorService(threadPool), bulkExecutionConfig);
+ private BulkInferenceExecutor(ThrottledInferenceRunner throttledInferenceRunner, BulkInferenceExecutionConfig bulkExecutionConfig) {
+ this.throttledInferenceRunner = throttledInferenceRunner;
this.bulkExecutionConfig = bulkExecutionConfig;
}
@@ -149,33 +151,18 @@ private void sendResponseOnCompletion() {
* Manages throttled inference tasks execution.
*/
private static class ThrottledInferenceRunner {
- private final InferenceRunner inferenceRunner;
+ private final Client client;
private final ExecutorService executorService;
private final BlockingQueue pendingRequestsQueue;
private final Semaphore permits;
- private ThrottledInferenceRunner(InferenceRunner inferenceRunner, ExecutorService executorService, int maxRunningTasks) {
+ private ThrottledInferenceRunner(Client client, ExecutorService executorService, int maxRunningTasks) {
this.executorService = executorService;
this.permits = new Semaphore(maxRunningTasks);
- this.inferenceRunner = inferenceRunner;
+ this.client = client;
this.pendingRequestsQueue = new ArrayBlockingQueue<>(maxRunningTasks);
}
- /**
- * Creates a new {@code ThrottledInferenceRunner} with the specified configuration.
- *
- * @param inferenceRunner TThe inference runner used to execute individual inference requests.
- * @param executorService The executor used for asynchronous execution.
- * @param bulkExecutionConfig Configuration options (throttling and concurrency limits).
- */
- public static ThrottledInferenceRunner create(
- InferenceRunner inferenceRunner,
- ExecutorService executorService,
- BulkInferenceExecutionConfig bulkExecutionConfig
- ) {
- return new ThrottledInferenceRunner(inferenceRunner, executorService, bulkExecutionConfig.maxOutstandingRequests());
- }
-
/**
* Schedules the inference task for execution. If a permit is available, the task runs immediately; otherwise, it is queued.
*
@@ -212,7 +199,7 @@ private void executePendingRequests() {
* Add an inference task to the queue.
*
* @param request The inference request.
- * * @param listener The listener to notify on response or failure.
+ * @param listener The listener to notify on response or failure.
*/
private void enqueueTask(InferenceAction.Request request, ActionListener listener) {
try {
@@ -240,7 +227,7 @@ private AbstractRunnable createTask(InferenceAction.Request request, ActionListe
@Override
protected void doRun() {
try {
- inferenceRunner.doInference(request, completionListener);
- executeAsyncWithOrigin(client, INFERENCE_ORIGIN, InferenceAction.INSTANCE, request, completionListener);
} catch (Throwable e) {
listener.onFailure(new RuntimeException("Unexpected failure while running inference", e));
}
@@ -257,4 +244,13 @@ public void onFailure(Exception e) {
private static ExecutorService executorService(ThreadPool threadPool) {
return threadPool.executor(EsqlPlugin.ESQL_WORKER_THREAD_POOL_NAME);
}
+
+ public record Factory(Client client, ThreadPool threadPool) {
+ public BulkInferenceExecutor create(BulkInferenceExecutionConfig bulkExecutionConfig) {
+ return new BulkInferenceExecutor(
+ new ThrottledInferenceRunner(client, executorService(threadPool), bulkExecutionConfig.maxOutstandingRequests()),
+ bulkExecutionConfig
+ );
+ }
+ }
}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperator.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperator.java
index e53fda90c88b3..6c686ed02b233 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperator.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperator.java
@@ -7,16 +7,16 @@
package org.elasticsearch.xpack.esql.inference.completion;
+import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.compute.data.BytesRefBlock;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.operator.DriverContext;
import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator;
import org.elasticsearch.compute.operator.Operator;
import org.elasticsearch.core.Releasables;
-import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.esql.inference.InferenceOperator;
-import org.elasticsearch.xpack.esql.inference.InferenceRunner;
import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceExecutionConfig;
+import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceExecutor;
import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRequestIterator;
import java.util.stream.IntStream;
@@ -31,12 +31,12 @@ public class CompletionOperator extends InferenceOperator {
public CompletionOperator(
DriverContext driverContext,
- InferenceRunner inferenceRunner,
- ThreadPool threadPool,
+ ThreadContext threadContext,
+ BulkInferenceExecutor.Factory bulkInferenceExecutorFactory,
String inferenceId,
ExpressionEvaluator promptEvaluator
) {
- super(driverContext, inferenceRunner, BulkInferenceExecutionConfig.DEFAULT, threadPool, inferenceId);
+ super(driverContext, threadContext, bulkInferenceExecutorFactory, BulkInferenceExecutionConfig.DEFAULT, inferenceId);
this.promptEvaluator = promptEvaluator;
}
@@ -88,9 +88,11 @@ protected CompletionOperatorOutputBuilder outputBuilder(Page input) {
/**
* Factory for creating {@link CompletionOperator} instances.
*/
- public record Factory(InferenceRunner inferenceRunner, String inferenceId, ExpressionEvaluator.Factory promptEvaluatorFactory)
- implements
- OperatorFactory {
+ public record Factory(
+ BulkInferenceExecutor.Factory inferenceExecutorFactory,
+ String inferenceId,
+ ExpressionEvaluator.Factory promptEvaluatorFactory
+ ) implements OperatorFactory {
@Override
public String describe() {
return "CompletionOperator[inference_id=[" + inferenceId + "]]";
@@ -100,8 +102,8 @@ public String describe() {
public Operator get(DriverContext driverContext) {
return new CompletionOperator(
driverContext,
- inferenceRunner,
- inferenceRunner.threadPool(),
+ inferenceExecutorFactory.threadPool().getThreadContext(),
+ inferenceExecutorFactory,
inferenceId,
promptEvaluatorFactory.get(driverContext)
);
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/rerank/RerankOperator.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/rerank/RerankOperator.java
index ca628fdba8a8f..cbc71d61fe6ed 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/rerank/RerankOperator.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/rerank/RerankOperator.java
@@ -7,6 +7,7 @@
package org.elasticsearch.xpack.esql.inference.rerank;
+import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.DoubleBlock;
import org.elasticsearch.compute.data.Page;
@@ -14,10 +15,9 @@
import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator;
import org.elasticsearch.compute.operator.Operator;
import org.elasticsearch.core.Releasables;
-import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.esql.inference.InferenceOperator;
-import org.elasticsearch.xpack.esql.inference.InferenceRunner;
import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceExecutionConfig;
+import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceExecutor;
import java.util.stream.IntStream;
@@ -40,14 +40,14 @@ public class RerankOperator extends InferenceOperator {
public RerankOperator(
DriverContext driverContext,
- InferenceRunner inferenceRunner,
- ThreadPool threadPool,
+ ThreadContext threadContext,
+ BulkInferenceExecutor.Factory bulkInferenceExecutorFactory,
String inferenceId,
String queryText,
ExpressionEvaluator rowEncoder,
int scoreChannel
) {
- super(driverContext, inferenceRunner, BulkInferenceExecutionConfig.DEFAULT, threadPool, inferenceId);
+ super(driverContext, threadContext, bulkInferenceExecutorFactory, BulkInferenceExecutionConfig.DEFAULT, inferenceId);
this.queryText = queryText;
this.rowEncoder = rowEncoder;
this.scoreChannel = scoreChannel;
@@ -100,7 +100,7 @@ protected RerankOperatorOutputBuilder outputBuilder(Page input) {
* Factory for creating {@link RerankOperator} instances
*/
public record Factory(
- InferenceRunner inferenceRunner,
+ BulkInferenceExecutor.Factory inferenceExecutorFactory,
String inferenceId,
String queryText,
ExpressionEvaluator.Factory rowEncoderFactory,
@@ -116,8 +116,8 @@ public String describe() {
public Operator get(DriverContext driverContext) {
return new RerankOperator(
driverContext,
- inferenceRunner,
- inferenceRunner.threadPool(),
+ inferenceExecutorFactory.threadPool().getThreadContext(),
+ inferenceExecutorFactory,
inferenceId,
queryText,
rowEncoderFactory().get(driverContext),
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlanner.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlanner.java
index ad6cb42f7f835..7c1e7372ff883 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlanner.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlanner.java
@@ -86,7 +86,7 @@
import org.elasticsearch.xpack.esql.evaluator.EvalMapper;
import org.elasticsearch.xpack.esql.evaluator.command.GrokEvaluatorExtracter;
import org.elasticsearch.xpack.esql.expression.Order;
-import org.elasticsearch.xpack.esql.inference.InferenceRunner;
+import org.elasticsearch.xpack.esql.inference.InferenceServices;
import org.elasticsearch.xpack.esql.inference.XContentRowEncoder;
import org.elasticsearch.xpack.esql.inference.completion.CompletionOperator;
import org.elasticsearch.xpack.esql.inference.rerank.RerankOperator;
@@ -159,7 +159,7 @@ public class LocalExecutionPlanner {
private final Supplier exchangeSinkSupplier;
private final EnrichLookupService enrichLookupService;
private final LookupFromIndexService lookupFromIndexService;
- private final InferenceRunner inferenceRunner;
+ private final InferenceServices inferenceServices;
private final PhysicalOperationProviders physicalOperationProviders;
private final List shardContexts;
@@ -175,7 +175,7 @@ public LocalExecutionPlanner(
Supplier exchangeSinkSupplier,
EnrichLookupService enrichLookupService,
LookupFromIndexService lookupFromIndexService,
- InferenceRunner inferenceRunner,
+ InferenceServices inferenceServices,
PhysicalOperationProviders physicalOperationProviders,
List shardContexts
) {
@@ -191,7 +191,7 @@ public LocalExecutionPlanner(
this.exchangeSinkSupplier = exchangeSinkSupplier;
this.enrichLookupService = enrichLookupService;
this.lookupFromIndexService = lookupFromIndexService;
- this.inferenceRunner = inferenceRunner;
+ this.inferenceServices = inferenceServices;
this.physicalOperationProviders = physicalOperationProviders;
this.shardContexts = shardContexts;
}
@@ -318,7 +318,10 @@ private PhysicalOperation planCompletion(CompletionExec completion, LocalExecuti
source.layout
);
- return source.with(new CompletionOperator.Factory(inferenceRunner, inferenceId, promptEvaluatorFactory), outputLayout);
+ return source.with(
+ new CompletionOperator.Factory(inferenceServices.bulkInferenceExecutorFactory(), inferenceId, promptEvaluatorFactory),
+ outputLayout
+ );
}
private PhysicalOperation planRrfScoreEvalExec(RrfScoreEvalExec rrf, LocalExecutionPlannerContext context) {
@@ -465,7 +468,7 @@ private PhysicalOperation planParallelNode(ParallelExec parallelExec, LocalExecu
statusInterval,
settings
),
- DriverParallelism.SINGLE
+ context.driverParallelism().get()
)
);
context.driverParallelism.set(DriverParallelism.SINGLE);
@@ -654,7 +657,13 @@ private PhysicalOperation planRerank(RerankExec rerank, LocalExecutionPlannerCon
int scoreChannel = outputLayout.get(rerank.scoreAttribute().id()).channel();
return source.with(
- new RerankOperator.Factory(inferenceRunner, inferenceId, queryText, rowEncoderFactory, scoreChannel),
+ new RerankOperator.Factory(
+ inferenceServices.bulkInferenceExecutorFactory(),
+ inferenceId,
+ queryText,
+ rowEncoderFactory,
+ scoreChannel
+ ),
outputLayout
);
}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeService.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeService.java
index 39e3503b5fdd9..d0b176635d435 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeService.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeService.java
@@ -52,7 +52,7 @@
import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.enrich.EnrichLookupService;
import org.elasticsearch.xpack.esql.enrich.LookupFromIndexService;
-import org.elasticsearch.xpack.esql.inference.InferenceRunner;
+import org.elasticsearch.xpack.esql.inference.InferenceServices;
import org.elasticsearch.xpack.esql.plan.physical.ExchangeSinkExec;
import org.elasticsearch.xpack.esql.plan.physical.ExchangeSourceExec;
import org.elasticsearch.xpack.esql.plan.physical.OutputExec;
@@ -130,7 +130,7 @@ public class ComputeService {
private final DriverTaskRunner driverRunner;
private final EnrichLookupService enrichLookupService;
private final LookupFromIndexService lookupFromIndexService;
- private final InferenceRunner inferenceRunner;
+ private final InferenceServices inferenceServices;
private final ClusterService clusterService;
private final ProjectResolver projectResolver;
private final AtomicLong childSessionIdGenerator = new AtomicLong();
@@ -158,7 +158,7 @@ public ComputeService(
this.driverRunner = new DriverTaskRunner(transportService, esqlExecutor);
this.enrichLookupService = enrichLookupService;
this.lookupFromIndexService = lookupFromIndexService;
- this.inferenceRunner = transportActionServices.inferenceRunner();
+ this.inferenceServices = transportActionServices.inferenceServices();
this.clusterService = transportActionServices.clusterService();
this.projectResolver = transportActionServices.projectResolver();
this.dataNodeComputeHandler = new DataNodeComputeHandler(
@@ -576,7 +576,7 @@ public SourceProvider createSourceProvider() {
context.exchangeSinkSupplier(),
enrichLookupService,
lookupFromIndexService,
- inferenceRunner,
+ inferenceServices,
physicalOperationProviders,
contexts
);
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/TransportActionServices.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/TransportActionServices.java
index ccabe09fd466c..3aa466740b689 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/TransportActionServices.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/TransportActionServices.java
@@ -14,7 +14,7 @@
import org.elasticsearch.search.SearchService;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.usage.UsageService;
-import org.elasticsearch.xpack.esql.inference.InferenceRunner;
+import org.elasticsearch.xpack.esql.inference.InferenceServices;
public record TransportActionServices(
TransportService transportService,
@@ -24,5 +24,5 @@ public record TransportActionServices(
ProjectResolver projectResolver,
IndexNameExpressionResolver indexNameExpressionResolver,
UsageService usageService,
- InferenceRunner inferenceRunner
+ InferenceServices inferenceServices
) {}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/TransportEsqlQueryAction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/TransportEsqlQueryAction.java
index 4be7b31bb96c0..4ca54b4d9b789 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/TransportEsqlQueryAction.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/TransportEsqlQueryAction.java
@@ -51,7 +51,7 @@
import org.elasticsearch.xpack.esql.enrich.LookupFromIndexService;
import org.elasticsearch.xpack.esql.execution.PlanExecutor;
import org.elasticsearch.xpack.esql.expression.function.UnsupportedAttribute;
-import org.elasticsearch.xpack.esql.inference.InferenceRunner;
+import org.elasticsearch.xpack.esql.inference.InferenceServices;
import org.elasticsearch.xpack.esql.session.Configuration;
import org.elasticsearch.xpack.esql.session.EsqlSession.PlanRunner;
import org.elasticsearch.xpack.esql.session.Result;
@@ -166,7 +166,7 @@ public TransportEsqlQueryAction(
projectResolver,
indexNameExpressionResolver,
usageService,
- new InferenceRunner(client, threadPool)
+ new InferenceServices(client, threadPool)
);
this.computeService = new ComputeService(
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/session/EsqlSession.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/session/EsqlSession.java
index 40a859e3f5b58..6b9c823a53707 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/session/EsqlSession.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/session/EsqlSession.java
@@ -38,6 +38,7 @@
import org.elasticsearch.xpack.esql.analysis.AnalyzerContext;
import org.elasticsearch.xpack.esql.analysis.EnrichResolution;
import org.elasticsearch.xpack.esql.analysis.PreAnalyzer;
+import org.elasticsearch.xpack.esql.analysis.PreAnalyzerContext;
import org.elasticsearch.xpack.esql.analysis.Verifier;
import org.elasticsearch.xpack.esql.core.expression.Alias;
import org.elasticsearch.xpack.esql.core.expression.Attribute;
@@ -61,7 +62,7 @@
import org.elasticsearch.xpack.esql.index.IndexResolution;
import org.elasticsearch.xpack.esql.index.MappingException;
import org.elasticsearch.xpack.esql.inference.InferenceResolution;
-import org.elasticsearch.xpack.esql.inference.InferenceRunner;
+import org.elasticsearch.xpack.esql.inference.InferenceResolver;
import org.elasticsearch.xpack.esql.optimizer.LogicalPlanOptimizer;
import org.elasticsearch.xpack.esql.optimizer.PhysicalOptimizerContext;
import org.elasticsearch.xpack.esql.optimizer.PhysicalPlanOptimizer;
@@ -88,7 +89,6 @@
import org.elasticsearch.xpack.esql.plan.logical.TopN;
import org.elasticsearch.xpack.esql.plan.logical.UnresolvedRelation;
import org.elasticsearch.xpack.esql.plan.logical.inference.Completion;
-import org.elasticsearch.xpack.esql.plan.logical.inference.InferencePlan;
import org.elasticsearch.xpack.esql.plan.logical.join.InlineJoin;
import org.elasticsearch.xpack.esql.plan.logical.join.JoinTypes;
import org.elasticsearch.xpack.esql.plan.logical.join.LookupJoin;
@@ -148,7 +148,7 @@ public interface PlanRunner {
private final PlanTelemetry planTelemetry;
private final IndicesExpressionGrouper indicesExpressionGrouper;
private Set configuredClusters;
- private final InferenceRunner inferenceRunner;
+ private final InferenceResolver inferenceResolver;
private final RemoteClusterService remoteClusterService;
private boolean explainMode;
@@ -167,7 +167,8 @@ public EsqlSession(
Verifier verifier,
PlanTelemetry planTelemetry,
IndicesExpressionGrouper indicesExpressionGrouper,
- TransportActionServices services
+ TransportActionServices services,
+ InferenceResolver inferenceResolver
) {
this.sessionId = sessionId;
this.configuration = configuration;
@@ -181,7 +182,7 @@ public EsqlSession(
this.physicalPlanOptimizer = new PhysicalPlanOptimizer(new PhysicalOptimizerContext(configuration));
this.planTelemetry = planTelemetry;
this.indicesExpressionGrouper = indicesExpressionGrouper;
- this.inferenceRunner = services.inferenceRunner();
+ this.inferenceResolver = inferenceResolver;
this.preMapper = new PreMapper(services);
this.remoteClusterService = services.transportService().getRemoteClusterService();
}
@@ -354,7 +355,7 @@ public void analyzedPlan(
// Capture configured remotes list to ensure consistency throughout the session
configuredClusters = Set.copyOf(indicesExpressionGrouper.getConfiguredClusters());
- PreAnalyzer.PreAnalysis preAnalysis = preAnalyzer.preAnalyze(parsed);
+ PreAnalyzer.PreAnalysis preAnalysis = preAnalyzer.preAnalyze(parsed, new PreAnalyzerContext(inferenceResolver));
var unresolvedPolicies = preAnalysis.enriches.stream()
.map(
e -> new EnrichPolicyResolver.UnresolvedPolicy(
@@ -372,7 +373,7 @@ public void analyzedPlan(
l -> enrichPolicyResolver.resolvePolicies(unresolvedPolicies, executionInfo, l)
)
.andThen((l, enrichResolution) -> resolveFieldNames(parsed, enrichResolution, l))
- .andThen((l, preAnalysisResult) -> resolveInferences(preAnalysis.inferencePlans, preAnalysisResult, l));
+ .andThen((l, preAnalysisResult) -> resolveInferences(preAnalysis.inferenceIds, preAnalysisResult, l));
// first resolve the lookup indices, then the main indices
for (var index : preAnalysis.lookupIndices) {
listener = listener.andThen((l, preAnalysisResult) -> preAnalyzeLookupIndex(index, preAnalysisResult, executionInfo, l));
@@ -768,12 +769,8 @@ private static void resolveFieldNames(LogicalPlan parsed, EnrichResolution enric
}
}
- private void resolveInferences(
- List> inferencePlans,
- PreAnalysisResult preAnalysisResult,
- ActionListener l
- ) {
- inferenceRunner.resolveInferenceIds(inferencePlans, l.map(preAnalysisResult::withInferenceResolution));
+ private void resolveInferences(List<String> inferenceIds, PreAnalysisResult preAnalysisResult, ActionListener<PreAnalysisResult> l) {
+ inferenceResolver.resolveInferenceIds(inferenceIds, l.map(preAnalysisResult::withInferenceResolution));
}
static PreAnalysisResult fieldNames(LogicalPlan parsed, Set enrichPolicyMatchFields, PreAnalysisResult result) {
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java
index bdf2ba39edc66..ed95abe3a3059 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java
@@ -56,6 +56,7 @@
import org.elasticsearch.xpack.esql.analysis.AnalyzerContext;
import org.elasticsearch.xpack.esql.analysis.EnrichResolution;
import org.elasticsearch.xpack.esql.analysis.PreAnalyzer;
+import org.elasticsearch.xpack.esql.analysis.PreAnalyzerContext;
import org.elasticsearch.xpack.esql.core.expression.Attribute;
import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.type.DataType;
@@ -67,7 +68,8 @@
import org.elasticsearch.xpack.esql.expression.function.EsqlFunctionRegistry;
import org.elasticsearch.xpack.esql.index.EsIndex;
import org.elasticsearch.xpack.esql.index.IndexResolution;
-import org.elasticsearch.xpack.esql.inference.InferenceRunner;
+import org.elasticsearch.xpack.esql.inference.InferenceResolver;
+import org.elasticsearch.xpack.esql.inference.InferenceServices;
import org.elasticsearch.xpack.esql.optimizer.LocalLogicalOptimizerContext;
import org.elasticsearch.xpack.esql.optimizer.LocalLogicalPlanOptimizer;
import org.elasticsearch.xpack.esql.optimizer.LocalPhysicalOptimizerContext;
@@ -98,7 +100,6 @@
import org.elasticsearch.xpack.esql.telemetry.PlanTelemetry;
import org.junit.After;
import org.junit.Before;
-import org.mockito.Mockito;
import java.io.IOException;
import java.net.URL;
@@ -131,6 +132,7 @@
import static org.hamcrest.Matchers.in;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.notNullValue;
+import static org.mockito.Mockito.mock;
/**
* CSV-based unit testing.
@@ -517,7 +519,7 @@ private LogicalPlan analyzedPlan(LogicalPlan parsed, CsvTestsDataLoader.MultiInd
}
private static CsvTestsDataLoader.MultiIndexTestDataset testDatasets(LogicalPlan parsed) {
- var preAnalysis = new PreAnalyzer().preAnalyze(parsed);
+ var preAnalysis = new PreAnalyzer().preAnalyze(parsed, new PreAnalyzerContext(mock(InferenceResolver.class)));
var indices = preAnalysis.indices;
if (indices.isEmpty()) {
/*
@@ -584,7 +586,8 @@ private ActualResults executePlan(BigArrays bigArrays) throws Exception {
TEST_VERIFIER,
new PlanTelemetry(functionRegistry),
null,
- EsqlTestUtils.MOCK_TRANSPORT_ACTION_SERVICES
+ EsqlTestUtils.MOCK_TRANSPORT_ACTION_SERVICES,
+ EsqlTestUtils.MOCK_TRANSPORT_ACTION_SERVICES.inferenceServices().inferenceResolver(functionRegistry)
);
TestPhysicalOperationProviders physicalOperationProviders = testOperationProviders(foldCtx, testDatasets);
@@ -695,9 +698,9 @@ void executeSubPlan(
configuration,
exchangeSource::createExchangeSource,
() -> exchangeSink.createExchangeSink(() -> {}),
- Mockito.mock(EnrichLookupService.class),
- Mockito.mock(LookupFromIndexService.class),
- Mockito.mock(InferenceRunner.class),
+ mock(EnrichLookupService.class),
+ mock(LookupFromIndexService.class),
+ mock(InferenceServices.class),
physicalOperationProviders,
List.of()
);
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTestUtils.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTestUtils.java
index cbb825ca9581b..6009fa5b725dc 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTestUtils.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTestUtils.java
@@ -200,6 +200,7 @@ public static InferenceResolution defaultInferenceResolution() {
return InferenceResolution.builder()
.withResolvedInference(new ResolvedInference("reranking-inference-id", TaskType.RERANK))
.withResolvedInference(new ResolvedInference("completion-inference-id", TaskType.COMPLETION))
+ .withResolvedInference(new ResolvedInference("text-embedding-inference-id", TaskType.TEXT_EMBEDDING))
.withError("error-inference-id", "error with inference resolution")
.build();
}
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java
index b2521bddfb47b..2a5ae50a25b69 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java
@@ -50,6 +50,7 @@
import org.elasticsearch.xpack.esql.expression.function.fulltext.MultiMatch;
import org.elasticsearch.xpack.esql.expression.function.fulltext.QueryString;
import org.elasticsearch.xpack.esql.expression.function.grouping.Bucket;
+import org.elasticsearch.xpack.esql.expression.function.inference.EmbedText;
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToDateNanos;
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToDatetime;
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToInteger;
@@ -3458,7 +3459,11 @@ private void assertProjectionWithMapping(String query, String mapping, QueryPara
}
private void assertError(String query, String mapping, QueryParams params, String error) {
- Throwable e = expectThrows(VerificationException.class, () -> analyze(query, mapping, params));
+ assertError(query, mapping, params, error, VerificationException.class);
+ }
+
+ private void assertError(String query, String mapping, QueryParams params, String error, Class<? extends Throwable> clazz) {
+ Throwable e = expectThrows(clazz, () -> analyze(query, mapping, params));
assertThat(e.getMessage(), containsString(error));
}
@@ -3802,6 +3807,129 @@ public void testResolveCompletionOutputField() {
assertThat(getAttributeByName(esRelation.output(), "description"), not(equalTo(completion.targetField())));
}
+ public void testResolveEmbedTextInferenceId() {
+ LogicalPlan plan = analyze("""
+ FROM books METADATA _score
+ | EVAL embedding = EMBED_TEXT("description", "text-embedding-inference-id")
+ """, "mapping-books.json");
+
+ var limit = as(plan, Limit.class);
+ var eval = as(limit.child(), Eval.class);
+ var embedTextAlias = as(eval.fields().get(0), Alias.class);
+ var embedText = as(embedTextAlias.child(), EmbedText.class);
+
+ assertThat(embedText.inferenceId(), equalTo(string("text-embedding-inference-id")));
+ assertThat(embedText.inputText(), equalTo(string("description")));
+ }
+
+ public void testResolveEmbedTextInferenceIdInvalidTaskType() {
+ assertError(
+ """
+ FROM books METADATA _score
+ | EVAL embedding = EMBED_TEXT(description, "completion-inference-id")
+ """,
+ "mapping-books.json",
+ new QueryParams(),
+ "cannot use inference endpoint [completion-inference-id] with task type [completion] within a embed_text function."
+ + " Only inference endpoints with the task type [text_embedding] are supported"
+ );
+ }
+
+ public void testResolveEmbedTextInferenceMissingInferenceId() {
+ assertError("""
+ FROM books METADATA _score
+ | EVAL embedding = EMBED_TEXT(description, "unknown-inference-id")
+ """, "mapping-books.json", new QueryParams(), "unresolved inference [unknown-inference-id]");
+ }
+
+ public void testResolveEmbedTextInferenceIdResolutionError() {
+ assertError("""
+ FROM books METADATA _score
+ | EVAL embedding = EMBED_TEXT(description, "error-inference-id")
+ """, "mapping-books.json", new QueryParams(), "error with inference resolution");
+ }
+
+ public void testResolveEmbedTextInNestedExpression() {
+ LogicalPlan plan = analyze("""
+ FROM colors METADATA _score
+ | WHERE KNN(rgb_vector, EMBED_TEXT("blue", "text-embedding-inference-id"), 10)
+ """, "mapping-colors.json");
+
+ var limit = as(plan, Limit.class);
+ var filter = as(limit.child(), Filter.class);
+
+ // Navigate to the EMBED_TEXT function within the KNN function
+ filter.condition().forEachDown(EmbedText.class, embedText -> {
+ assertThat(embedText.inferenceId(), equalTo(string("text-embedding-inference-id")));
+ assertThat(embedText.inputText(), equalTo(string("blue")));
+ });
+ }
+
+ public void testResolveEmbedTextDataType() {
+ LogicalPlan plan = analyze("""
+ FROM books METADATA _score
+ | EVAL embedding = EMBED_TEXT("description", "text-embedding-inference-id")
+ """, "mapping-books.json");
+
+ var limit = as(plan, Limit.class);
+ var eval = as(limit.child(), Eval.class);
+ var embedTextAlias = as(eval.fields().get(0), Alias.class);
+ var embedText = as(embedTextAlias.child(), EmbedText.class);
+
+ assertThat(embedText.dataType(), equalTo(DataType.DENSE_VECTOR));
+ }
+
+ public void testResolveEmbedTextInvalidParameters() {
+ assertError(
+ "FROM books METADATA _score| EVAL embedding = EMBED_TEXT(description, \"text-embedding-inference-id\")",
+ "mapping-books.json",
+ new QueryParams(),
+ "first argument of [EMBED_TEXT(description, \"text-embedding-inference-id\")] must be a constant, received [description]"
+ );
+
+ assertError(
+ "FROM books METADATA _score| EVAL embedding = EMBED_TEXT(description)",
+ "mapping-books.json",
+ new QueryParams(),
+ "error building [embed_text]: function [embed_text] expects exactly two arguments, it received 1",
+ ParsingException.class
+ );
+ }
+
+ public void testResolveEmbedTextWithPositionalQueryParams() {
+ LogicalPlan plan = analyze(
+ "FROM books METADATA _score| EVAL embedding = EMBED_TEXT(?, ?)",
+ "mapping-books.json",
+ new QueryParams(List.of(paramAsConstant(null, "description"), paramAsConstant(null, "text-embedding-inference-id")))
+ );
+
+ var limit = as(plan, Limit.class);
+ var eval = as(limit.child(), Eval.class);
+ var embedTextAlias = as(eval.fields().get(0), Alias.class);
+ var embedText = as(embedTextAlias.child(), EmbedText.class);
+
+ assertThat(embedText.inferenceId(), equalTo(string("text-embedding-inference-id")));
+ assertThat(embedText.inputText(), equalTo(string("description")));
+ }
+
+ public void testResolveEmbedTextWithNamedQueryParams() {
+ LogicalPlan plan = analyze(
+ "FROM books METADATA _score| EVAL embedding = EMBED_TEXT(?inputText, ?inferenceId)",
+ "mapping-books.json",
+ new QueryParams(
+ List.of(paramAsConstant("inputText", "description"), paramAsConstant("inferenceId", "text-embedding-inference-id"))
+ )
+ );
+
+ var limit = as(plan, Limit.class);
+ var eval = as(limit.child(), Eval.class);
+ var embedTextAlias = as(eval.fields().get(0), Alias.class);
+ var embedText = as(embedTextAlias.child(), EmbedText.class);
+
+ assertThat(embedText.inferenceId(), equalTo(string("text-embedding-inference-id")));
+ assertThat(embedText.inputText(), equalTo(string("description")));
+ }
+
public void testResolveGroupingsBeforeResolvingImplicitReferencesToGroupings() {
var plan = analyze("""
FROM test
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/inference/EmbedTextErrorTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/inference/EmbedTextErrorTests.java
new file mode 100644
index 0000000000000..4b490a86b7587
--- /dev/null
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/inference/EmbedTextErrorTests.java
@@ -0,0 +1,73 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.esql.expression.function.inference;
+
+import org.elasticsearch.xpack.esql.action.EsqlCapabilities;
+import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.expression.TypeResolutions;
+import org.elasticsearch.xpack.esql.core.tree.Source;
+import org.elasticsearch.xpack.esql.core.type.DataType;
+import org.elasticsearch.xpack.esql.expression.function.AbstractFunctionTestCase;
+import org.elasticsearch.xpack.esql.expression.function.ErrorsForCasesWithoutExamplesTestCase;
+import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier;
+import org.hamcrest.Matcher;
+import org.junit.Before;
+
+import java.util.List;
+import java.util.Locale;
+import java.util.Set;
+
+import static org.hamcrest.Matchers.equalTo;
+
+public class EmbedTextErrorTests extends ErrorsForCasesWithoutExamplesTestCase {
+
+ @Before
+ public void checkCapability() {
+ assumeTrue("EMBED_TEXT is not enabled", EsqlCapabilities.Cap.EMBED_TEXT_FUNCTION.isEnabled());
+ }
+
+ @Override
+ protected List<TestCaseSupplier> cases() {
+ return paramsToSuppliers(EmbedTextTests.parameters());
+ }
+
+ @Override
+ protected Expression build(Source source, List<Expression> args) {
+ return new EmbedText(source, args.get(0), args.get(1));
+ }
+
+ @Override
+ protected Matcher<String> expectedTypeErrorMatcher(List<Set<DataType>> validPerPosition, List<DataType> signature) {
+ return equalTo(typeErrorMessage(true, validPerPosition, signature, (v, p) -> "string"));
+ }
+
+ protected static String typeErrorMessage(
+ boolean includeOrdinal,
+ List<Set<DataType>> validPerPosition,
+ List<DataType> signature,
+ AbstractFunctionTestCase.PositionalErrorMessageSupplier positionalErrorMessageSupplier
+ ) {
+ for (int i = 0; i < signature.size(); i++) {
+ if (signature.get(i) == DataType.NULL) {
+ String ordinal = includeOrdinal ? TypeResolutions.ParamOrdinal.fromIndex(i).name().toLowerCase(Locale.ROOT) + " " : "";
+ return ordinal + "argument of [" + sourceForSignature(signature) + "] cannot be null, received []";
+ }
+
+ if (validPerPosition.get(i).contains(signature.get(i)) == false) {
+ break;
+ }
+ }
+
+ return ErrorsForCasesWithoutExamplesTestCase.typeErrorMessage(
+ includeOrdinal,
+ validPerPosition,
+ signature,
+ positionalErrorMessageSupplier
+ );
+ }
+}
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/inference/EmbedTextSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/inference/EmbedTextSerializationTests.java
new file mode 100644
index 0000000000000..59d5377e36f76
--- /dev/null
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/inference/EmbedTextSerializationTests.java
@@ -0,0 +1,45 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.esql.expression.function.inference;
+
+import org.elasticsearch.xpack.esql.action.EsqlCapabilities;
+import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.tree.Source;
+import org.elasticsearch.xpack.esql.expression.AbstractExpressionSerializationTests;
+import org.junit.Before;
+
+import java.io.IOException;
+
+ public class EmbedTextSerializationTests extends AbstractExpressionSerializationTests<EmbedText> {
+
+ @Before
+ public void checkCapability() {
+ assumeTrue("EMBED_TEXT is not enabled", EsqlCapabilities.Cap.EMBED_TEXT_FUNCTION.isEnabled());
+ }
+
+ @Override
+ protected EmbedText createTestInstance() {
+ Source source = randomSource();
+ Expression inputText = randomChild();
+ Expression inferenceId = randomChild();
+ return new EmbedText(source, inputText, inferenceId);
+ }
+
+ @Override
+ protected EmbedText mutateInstance(EmbedText instance) throws IOException {
+ Source source = instance.source();
+ Expression inputText = instance.inputText();
+ Expression inferenceId = instance.inferenceId();
+ if (randomBoolean()) {
+ inputText = randomValueOtherThan(inputText, AbstractExpressionSerializationTests::randomChild);
+ } else {
+ inferenceId = randomValueOtherThan(inferenceId, AbstractExpressionSerializationTests::randomChild);
+ }
+ return new EmbedText(source, inputText, inferenceId);
+ }
+}
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/inference/EmbedTextTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/inference/EmbedTextTests.java
new file mode 100644
index 0000000000000..c342ee143e6af
--- /dev/null
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/inference/EmbedTextTests.java
@@ -0,0 +1,72 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.esql.expression.function.inference;
+
+import com.carrotsearch.randomizedtesting.annotations.Name;
+import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;
+
+import org.elasticsearch.xpack.esql.action.EsqlCapabilities;
+import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.tree.Source;
+import org.elasticsearch.xpack.esql.core.type.DataType;
+import org.elasticsearch.xpack.esql.expression.function.AbstractFunctionTestCase;
+import org.elasticsearch.xpack.esql.expression.function.FunctionName;
+import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier;
+import org.hamcrest.Matchers;
+import org.junit.Before;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.function.Supplier;
+
+import static org.elasticsearch.xpack.esql.core.type.DataType.DENSE_VECTOR;
+import static org.hamcrest.Matchers.equalTo;
+
+@FunctionName("embed_text")
+public class EmbedTextTests extends AbstractFunctionTestCase {
+ @Before
+ public void checkCapability() {
+ assumeTrue("EMBED_TEXT is not enabled", EsqlCapabilities.Cap.EMBED_TEXT_FUNCTION.isEnabled());
+ }
+
+ public EmbedTextTests(@Name("TestCase") Supplier<TestCaseSupplier.TestCase> testCaseSupplier) {
+ this.testCase = testCaseSupplier.get();
+ }
+
+ @ParametersFactory
+ public static Iterable