diff --git a/docs/reference/query-languages/esql/images/functions/embed_text.svg b/docs/reference/query-languages/esql/images/functions/embed_text.svg new file mode 100644 index 0000000000000..9bb6cab692c4e --- /dev/null +++ b/docs/reference/query-languages/esql/images/functions/embed_text.svg @@ -0,0 +1 @@ +EMBED_TEXT(text,inference_id) \ No newline at end of file diff --git a/docs/reference/query-languages/esql/kibana/definition/functions/embed_text.json b/docs/reference/query-languages/esql/kibana/definition/functions/embed_text.json new file mode 100644 index 0000000000000..edfafb213e16d --- /dev/null +++ b/docs/reference/query-languages/esql/kibana/definition/functions/embed_text.json @@ -0,0 +1,9 @@ +{ + "comment" : "This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it.", + "type" : "scalar", + "name" : "embed_text", + "description" : "Generates dense vector embeddings for text using a specified inference deployment.", + "signatures" : [ ], + "preview" : true, + "snapshot_only" : true +} diff --git a/docs/reference/query-languages/esql/kibana/docs/functions/embed_text.md b/docs/reference/query-languages/esql/kibana/docs/functions/embed_text.md new file mode 100644 index 0000000000000..fc6d8d0772d64 --- /dev/null +++ b/docs/reference/query-languages/esql/kibana/docs/functions/embed_text.md @@ -0,0 +1,4 @@ +% This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it. + +### EMBED TEXT +Generates dense vector embeddings for text using a specified inference deployment. 
diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/EsqlTestUtils.java b/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/EsqlTestUtils.java index e8acabe71ab41..ff1d078e022a5 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/EsqlTestUtils.java +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/EsqlTestUtils.java @@ -76,7 +76,8 @@ import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.NotEquals; import org.elasticsearch.xpack.esql.index.EsIndex; import org.elasticsearch.xpack.esql.inference.InferenceResolution; -import org.elasticsearch.xpack.esql.inference.InferenceRunner; +import org.elasticsearch.xpack.esql.inference.InferenceResolver; +import org.elasticsearch.xpack.esql.inference.InferenceServices; import org.elasticsearch.xpack.esql.optimizer.LogicalOptimizerContext; import org.elasticsearch.xpack.esql.parser.QueryParam; import org.elasticsearch.xpack.esql.plan.logical.Enrich; @@ -164,6 +165,7 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; public final class EsqlTestUtils { @@ -401,18 +403,26 @@ public static LogicalOptimizerContext unboundLogicalOptimizerContext() { mock(ProjectResolver.class), mock(IndexNameExpressionResolver.class), null, - mockInferenceRunner() + mockInferenceServices() ); + private static InferenceServices mockInferenceServices() { + InferenceServices inferenceServices = mock(InferenceServices.class); + InferenceResolver inferenceResolver = mockInferenceRunner(); + when(inferenceServices.inferenceResolver(any())).thenReturn(inferenceResolver); + + return inferenceServices; + } + @SuppressWarnings("unchecked") - private static InferenceRunner mockInferenceRunner() { - InferenceRunner inferenceRunner = mock(InferenceRunner.class); + private static 
InferenceResolver mockInferenceRunner() { + InferenceResolver inferenceResolver = mock(InferenceResolver.class); doAnswer(i -> { i.getArgument(1, ActionListener.class).onResponse(emptyInferenceResolution()); return null; - }).when(inferenceRunner).resolveInferenceIds(any(), any()); + }).when(inferenceResolver).resolveInferenceIds(any(), any()); - return inferenceRunner; + return inferenceResolver; } private EsqlTestUtils() {} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java index 0b51241beebbd..153abbfbf2bb6 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java @@ -1245,6 +1245,11 @@ public enum Cap { */ AGGREGATE_METRIC_DOUBLE_AVG(AGGREGATE_METRIC_DOUBLE_FEATURE_FLAG), + /** + * Support for the {@code EMBED_TEXT} function for generating dense vector embeddings. 
+ */ + EMBED_TEXT_FUNCTION(Build.current().isSnapshot()), + /** * Forbid usage of brackets in unquoted index and enrich policy names * https://github.com/elastic/elasticsearch/issues/130378 diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java index e4b8949af5bdb..f91305d471727 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java @@ -54,6 +54,7 @@ import org.elasticsearch.xpack.esql.expression.function.UnresolvedFunction; import org.elasticsearch.xpack.esql.expression.function.UnsupportedAttribute; import org.elasticsearch.xpack.esql.expression.function.grouping.GroupingFunction; +import org.elasticsearch.xpack.esql.expression.function.inference.InferenceFunction; import org.elasticsearch.xpack.esql.expression.function.scalar.EsqlScalarFunction; import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Case; import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Greatest; @@ -173,9 +174,9 @@ public class Analyzer extends ParameterizedRuleExecutor( @@ -398,34 +399,6 @@ private static NamedExpression createEnrichFieldExpression( } } - private static class ResolveInference extends ParameterizedAnalyzerRule, AnalyzerContext> { - @Override - protected LogicalPlan rule(InferencePlan plan, AnalyzerContext context) { - assert plan.inferenceId().resolved() && plan.inferenceId().foldable(); - - String inferenceId = BytesRefs.toString(plan.inferenceId().fold(FoldContext.small())); - ResolvedInference resolvedInference = context.inferenceResolution().getResolvedInference(inferenceId); - - if (resolvedInference != null && resolvedInference.taskType() == plan.taskType()) { - return plan; - } else if (resolvedInference != null) { - String error = "cannot use inference endpoint [" - 
+ inferenceId - + "] with task type [" - + resolvedInference.taskType() - + "] within a " - + plan.nodeName() - + " command. Only inference endpoints with the task type [" - + plan.taskType() - + "] are supported."; - return plan.withInferenceResolutionError(inferenceId, error); - } else { - String error = context.inferenceResolution().getError(inferenceId); - return plan.withInferenceResolutionError(inferenceId, error); - } - } - } - private static class ResolveLookupTables extends ParameterizedAnalyzerRule { @Override @@ -1319,6 +1292,70 @@ public static org.elasticsearch.xpack.esql.core.expression.function.Function res } } + private static class ResolveInference extends ParameterizedRule { + + @Override + public LogicalPlan apply(LogicalPlan plan, AnalyzerContext context) { + return plan.transformDown(InferencePlan.class, p -> resolveInferencePlan(p, context)) + .transformExpressionsOnly(InferenceFunction.class, f -> resolveInferenceFunction(f, context)); + + } + + private InferenceFunction resolveInferenceFunction(InferenceFunction inferenceFunction, AnalyzerContext context) { + assert inferenceFunction.inferenceId().resolved() && inferenceFunction.inferenceId().foldable(); + + String inferenceId = BytesRefs.toString(inferenceFunction.inferenceId().fold(FoldContext.small())); + ResolvedInference resolvedInference = context.inferenceResolution().getResolvedInference(inferenceId); + + if (resolvedInference == null) { + String error = context.inferenceResolution().getError(inferenceId); + return inferenceFunction.withInferenceResolutionError(inferenceId, error); + } + + if (resolvedInference.taskType() != inferenceFunction.taskType()) { + String error = "cannot use inference endpoint [" + + inferenceId + + "] with task type [" + + resolvedInference.taskType() + + "] within a " + + context.functionRegistry().snapshotRegistry().functionName(inferenceFunction.getClass()) + + " function. 
Only inference endpoints with the task type [" + + inferenceFunction.taskType() + + "] are supported."; + return inferenceFunction.withInferenceResolutionError(inferenceId, error); + } + + return inferenceFunction; + } + + private LogicalPlan resolveInferencePlan(InferencePlan plan, AnalyzerContext context) { + assert plan.inferenceId().resolved() && plan.inferenceId().foldable(); + + String inferenceId = BytesRefs.toString(plan.inferenceId().fold(FoldContext.small())); + ResolvedInference resolvedInference = context.inferenceResolution().getResolvedInference(inferenceId); + + if (resolvedInference == null) { + String error = context.inferenceResolution().getError(inferenceId); + return plan.withInferenceResolutionError(inferenceId, error); + } + + if (resolvedInference.taskType() != plan.taskType()) { + String error = "cannot use inference endpoint [" + + inferenceId + + "] with task type [" + + resolvedInference.taskType() + + "] within a " + + plan.nodeName() + + " command. Only inference endpoints with the task type [" + + plan.taskType() + + "] are supported."; + return plan.withInferenceResolutionError(inferenceId, error); + } + + return plan; + } + } + private static class AddImplicitLimit extends ParameterizedRule { @Override public LogicalPlan apply(LogicalPlan logicalPlan, AnalyzerContext context) { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/PreAnalyzer.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/PreAnalyzer.java index 5b9f41876d6e1..958f3cc9d946f 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/PreAnalyzer.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/PreAnalyzer.java @@ -13,7 +13,6 @@ import org.elasticsearch.xpack.esql.plan.logical.Enrich; import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.plan.logical.UnresolvedRelation; -import 
org.elasticsearch.xpack.esql.plan.logical.inference.InferencePlan; import java.util.ArrayList; import java.util.HashSet; @@ -33,38 +32,39 @@ public static class PreAnalysis { public final IndexMode indexMode; public final List indices; public final List enriches; - public final List> inferencePlans; + public final List inferenceIds; public final List lookupIndices; public PreAnalysis( IndexMode indexMode, List indices, List enriches, - List> inferencePlans, + List inferenceIds, List lookupIndices ) { this.indexMode = indexMode; this.indices = indices; this.enriches = enriches; - this.inferencePlans = inferencePlans; + this.inferenceIds = inferenceIds; this.lookupIndices = lookupIndices; } } - public PreAnalysis preAnalyze(LogicalPlan plan) { + public PreAnalysis preAnalyze(LogicalPlan plan, PreAnalyzerContext context) { if (plan.analyzed()) { return PreAnalysis.EMPTY; } - return doPreAnalyze(plan); + return doPreAnalyze(plan, context); } - protected PreAnalysis doPreAnalyze(LogicalPlan plan) { + protected PreAnalysis doPreAnalyze(LogicalPlan plan, PreAnalyzerContext context) { Set indices = new HashSet<>(); List unresolvedEnriches = new ArrayList<>(); List lookupIndices = new ArrayList<>(); - List> unresolvedInferencePlans = new ArrayList<>(); + Set inferenceIds = new HashSet<>(); + Holder indexMode = new Holder<>(); plan.forEachUp(UnresolvedRelation.class, p -> { if (p.indexMode() == IndexMode.LOOKUP) { @@ -78,11 +78,12 @@ protected PreAnalysis doPreAnalyze(LogicalPlan plan) { }); plan.forEachUp(Enrich.class, unresolvedEnriches::add); - plan.forEachUp(InferencePlan.class, unresolvedInferencePlans::add); + context.inferenceResolver().collectInferenceIds(plan, inferenceIds::add); // mark plan as preAnalyzed (if it were marked, there would be no analysis) plan.forEachUp(LogicalPlan::setPreAnalyzed); - return new PreAnalysis(indexMode.get(), indices.stream().toList(), unresolvedEnriches, unresolvedInferencePlans, lookupIndices); + return new 
PreAnalysis(indexMode.get(), indices.stream().toList(), unresolvedEnriches, List.copyOf(inferenceIds), lookupIndices); } + } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/PreAnalyzerContext.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/PreAnalyzerContext.java new file mode 100644 index 0000000000000..a2571d5384f89 --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/PreAnalyzerContext.java @@ -0,0 +1,14 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.analysis; + +import org.elasticsearch.xpack.esql.inference.InferenceResolver; + +public record PreAnalyzerContext(InferenceResolver inferenceResolver) { + +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/execution/PlanExecutor.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/execution/PlanExecutor.java index 620d86650abf4..1e66ce413206d 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/execution/PlanExecutor.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/execution/PlanExecutor.java @@ -91,7 +91,8 @@ public void esql( verifier, planTelemetry, indicesExpressionGrouper, - services + services, + services.inferenceServices().inferenceResolver(functionRegistry) ); QueryMetric clientId = QueryMetric.fromString("rest"); metrics.total(clientId); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/ExpressionWritables.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/ExpressionWritables.java index a3f6d3a089d49..59b64cccd929e 100644 --- 
a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/ExpressionWritables.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/ExpressionWritables.java @@ -13,6 +13,7 @@ import org.elasticsearch.xpack.esql.expression.function.UnsupportedAttribute; import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateWritables; import org.elasticsearch.xpack.esql.expression.function.fulltext.FullTextWritables; +import org.elasticsearch.xpack.esql.expression.function.inference.EmbedText; import org.elasticsearch.xpack.esql.expression.function.scalar.ScalarFunctionWritables; import org.elasticsearch.xpack.esql.expression.function.scalar.convert.FromBase64; import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToAggregateMetricDouble; @@ -119,6 +120,7 @@ public static List getNamedWriteables() { entries.addAll(fullText()); entries.addAll(unaryScalars()); entries.addAll(vector()); + entries.addAll(inference()); return entries; } @@ -264,4 +266,11 @@ private static List vector() { } return List.of(); } + + private static List inference() { + if (EsqlCapabilities.Cap.EMBED_TEXT_FUNCTION.isEnabled()) { + return List.of(EmbedText.ENTRY); + } + return List.of(); + } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java index 630c9c2008a13..bffd90362943f 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java @@ -52,6 +52,7 @@ import org.elasticsearch.xpack.esql.expression.function.fulltext.Term; import org.elasticsearch.xpack.esql.expression.function.grouping.Bucket; import org.elasticsearch.xpack.esql.expression.function.grouping.Categorize; +import 
org.elasticsearch.xpack.esql.expression.function.inference.EmbedText; import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Case; import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Greatest; import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Least; @@ -479,6 +480,7 @@ private static FunctionDefinition[][] snapshotFunctions() { def(FirstOverTime.class, uni(FirstOverTime::new), "first_over_time"), def(Term.class, bi(Term::new), "term"), def(Knn.class, Knn::new, "knn"), + def(EmbedText.class, EmbedText::new, "embed_text"), def(StGeohash.class, StGeohash::new, "st_geohash"), def(StGeohashToLong.class, StGeohashToLong::new, "st_geohash_to_long"), def(StGeohashToString.class, StGeohashToString::new, "st_geohash_to_string"), diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/EmbedText.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/EmbedText.java new file mode 100644 index 0000000000000..cc63f7e465f48 --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/EmbedText.java @@ -0,0 +1,163 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.esql.expression.function.inference; + +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.UnresolvedAttribute; +import org.elasticsearch.xpack.esql.core.tree.NodeInfo; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.expression.function.FunctionAppliesTo; +import org.elasticsearch.xpack.esql.expression.function.FunctionAppliesToLifecycle; +import org.elasticsearch.xpack.esql.expression.function.FunctionInfo; +import org.elasticsearch.xpack.esql.expression.function.Param; +import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; + +import java.io.IOException; +import java.util.List; +import java.util.Objects; + +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.*; +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.FIRST; +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.SECOND; + +/** + * EMBED_TEXT function that generates dense vector embeddings for text using a specified inference deployment. 
+ */ +public class EmbedText extends InferenceFunction { + + public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry( + Expression.class, + "EmbedText", + EmbedText::new + ); + + private final Expression inferenceId; + private final Expression inputText; + + @FunctionInfo( + returnType = "dense_vector", + description = "Generates dense vector embeddings for text using a specified inference deployment.", + appliesTo = { @FunctionAppliesTo(lifeCycle = FunctionAppliesToLifecycle.DEVELOPMENT) }, + preview = true + ) + public EmbedText( + Source source, + @Param(name = "text", type = { "keyword", "text" }, description = "Text to embed") Expression inputText, + @Param( + name = InferenceFunction.INFERENCE_ID_PARAMETER_NAME, + type = { "keyword", "text" }, + description = "Inference deployment ID" + ) Expression inferenceId + ) { + super(source, List.of(inputText, inferenceId)); + this.inferenceId = inferenceId; + this.inputText = inputText; + } + + private EmbedText(StreamInput in) throws IOException { + this(Source.readFrom((PlanStreamInput) in), in.readNamedWriteable(Expression.class), in.readNamedWriteable(Expression.class)); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + source().writeTo(out); + out.writeNamedWriteable(inputText); + out.writeNamedWriteable(inferenceId); + } + + @Override + public String getWriteableName() { + return ENTRY.name; + } + + public Expression inputText() { + return inputText; + } + + @Override + public Expression inferenceId() { + return inferenceId; + } + + @Override + public DataType dataType() { + return DataType.DENSE_VECTOR; + } + + @Override + protected TypeResolution resolveType() { + if (childrenResolved() == false) { + return new TypeResolution("Unresolved children"); + } + + TypeResolution textResolution = isNotNull(inputText, sourceText(), FIRST).and(isFoldable(inputText, sourceText(), FIRST)) + .and(isString(inputText, sourceText(), FIRST)); + + if 
(textResolution.unresolved()) { + return textResolution; + } + + TypeResolution inferenceIdResolution = isNotNull(inferenceId, sourceText(), SECOND).and(isString(inferenceId, sourceText(), SECOND)) + .and(isFoldable(inferenceId, sourceText(), SECOND)); + + if (inferenceIdResolution.unresolved()) { + return inferenceIdResolution; + } + + return TypeResolution.TYPE_RESOLVED; + } + + @Override + public boolean foldable() { + // The function is foldable only if both arguments are foldable + return inputText.foldable() && inferenceId.foldable(); + } + + @Override + public TaskType taskType() { + return TaskType.TEXT_EMBEDDING; + } + + @Override + public EmbedText withInferenceResolutionError(String inferenceId, String error) { + return new EmbedText(source(), inputText, new UnresolvedAttribute(inferenceId().source(), inferenceId, error)); + } + + @Override + public Expression replaceChildren(List newChildren) { + return new EmbedText(source(), newChildren.get(0), newChildren.get(1)); + } + + @Override + protected NodeInfo info() { + return NodeInfo.create(this, EmbedText::new, inputText, inferenceId); + } + + @Override + public String toString() { + return "EMBED_TEXT(" + inputText + ", " + inferenceId + ")"; + } + + @Override + public boolean equals(Object o) { + if (o == null || getClass() != o.getClass()) return false; + if (super.equals(o) == false) return false; + EmbedText embedText = (EmbedText) o; + return Objects.equals(inferenceId, embedText.inferenceId) && Objects.equals(inputText, embedText.inputText); + } + + @Override + public int hashCode() { + return Objects.hash(super.hashCode(), inferenceId, inputText); + } +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/InferenceFunction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/InferenceFunction.java new file mode 100644 index 0000000000000..d32587dc71058 --- /dev/null +++ 
b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/InferenceFunction.java @@ -0,0 +1,39 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.inference; + +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.function.Function; +import org.elasticsearch.xpack.esql.core.tree.Source; + +import java.util.List; + +/** + * A function is a function using an inference model. + */ +public abstract class InferenceFunction> extends Function { + + public static final String INFERENCE_ID_PARAMETER_NAME = "inference_id"; + + protected InferenceFunction(Source source, List children) { + super(source, children); + } + + /** + * Returns the inference model ID expression. + */ + public abstract Expression inferenceId(); + + /** + * Returns the task type of the inference model. 
+ */ + public abstract TaskType taskType(); + + public abstract PlanType withInferenceResolutionError(String inferenceId, String error); +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Knn.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Knn.java index 63026fb9d7201..d66c8fdc86680 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Knn.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Knn.java @@ -56,7 +56,6 @@ import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isFoldable; import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isMapExpression; import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isNotNull; -import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isNotNullAndFoldable; import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isType; import static org.elasticsearch.xpack.esql.core.type.DataType.DENSE_VECTOR; import static org.elasticsearch.xpack.esql.core.type.DataType.FLOAT; @@ -190,8 +189,8 @@ private TypeResolution resolveField() { private TypeResolution resolveQuery() { return isType(query(), dt -> dt == DENSE_VECTOR, sourceText(), TypeResolutions.ParamOrdinal.SECOND, "dense_vector").and( - isNotNullAndFoldable(query(), sourceText(), SECOND) - ); + isNotNull(query(), sourceText(), SECOND) + ).and(isFoldable(query(), sourceText(), SECOND)); } private TypeResolution resolveK() { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceOperator.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceOperator.java index fe6ab6e9a998c..54613e31212e4 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceOperator.java +++ 
b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceOperator.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.esql.inference; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.compute.data.BlockFactory; import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.operator.AsyncOperator; @@ -15,7 +16,6 @@ import org.elasticsearch.core.Releasable; import org.elasticsearch.core.Releasables; import org.elasticsearch.inference.InferenceServiceResults; -import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceExecutionConfig; import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceExecutor; @@ -26,7 +26,7 @@ import static org.elasticsearch.common.logging.LoggerMessageFormat.format; /** - * An abstract asynchronous operator that performs throttled bulk inference execution using an {@link InferenceRunner}. + * An abstract asynchronous operator that performs throttled bulk inference execution using an {@link InferenceResolver}. *

* The {@code InferenceOperator} integrates with the compute framework supports throttled bulk execution of inference requests. It * transforms input {@link Page} into inference requests, asynchronously executes them, and converts the responses into a new {@link Page}. @@ -41,21 +41,21 @@ public abstract class InferenceOperator extends AsyncOperator + * This method traverses the logical plan tree and identifies all inference operations, + * extracting their deployment IDs for subsequent validation. Currently, supports: + *

+     * <ul>
+     *   <li>{@link InferencePlan} objects (Completion, etc.)</li>
+     *   <li>{@link InferenceFunction} objects (EmbedText, etc.)</li>
+     * </ul>
+ * + * @param plan The logical plan to scan for inference operations + * @param c Consumer function to receive each discovered inference ID + */ + public void collectInferenceIds(LogicalPlan plan, Consumer c) { + collectInferenceIdsFromInferencePlans(plan, c); + collectInferenceIdsFromInferenceFunctions(plan, c); + } + + /** + * Resolves a list of inference deployment IDs to their metadata. + *

+ * For each inference ID, this method: + *

+     * <ol>
+     *   <li>Queries the inference service to verify the deployment exists</li>
+     *   <li>Retrieves the deployment's task type and configuration</li>
+     *   <li>Builds an {@link InferenceResolution} containing resolved metadata or errors</li>
+     * </ol>
+ * + * @param inferenceIds List of inference deployment IDs to resolve + * @param listener Callback to receive the resolution results + */ + public void resolveInferenceIds(List inferenceIds, ActionListener listener) { + resolveInferenceIds(Set.copyOf(inferenceIds), listener); + } + + private void resolveInferenceIds(Set inferenceIds, ActionListener listener) { + + if (inferenceIds.isEmpty()) { + listener.onResponse(InferenceResolution.EMPTY); + return; + } + + final InferenceResolution.Builder inferenceResolutionBuilder = InferenceResolution.builder(); + + final CountDownActionListener countdownListener = new CountDownActionListener( + inferenceIds.size(), + ActionListener.wrap(_r -> listener.onResponse(inferenceResolutionBuilder.build()), listener::onFailure) + ); + + for (var inferenceId : inferenceIds) { + client.execute( + GetInferenceModelAction.INSTANCE, + new GetInferenceModelAction.Request(inferenceId, TaskType.ANY), + ActionListener.wrap(r -> { + ResolvedInference resolvedInference = new ResolvedInference(inferenceId, r.getEndpoints().getFirst().getTaskType()); + inferenceResolutionBuilder.withResolvedInference(resolvedInference); + countdownListener.onResponse(null); + }, e -> { + inferenceResolutionBuilder.withError(inferenceId, e.getMessage()); + countdownListener.onResponse(null); + }) + ); + } + } + + /** + * Collects inference IDs from inference function calls within the logical plan. + *

+ * This method scans the logical plan for {@link UnresolvedFunction} instances that represent + * inference functions (e.g., EMBED_TEXT). For each inference function found: + *

+     * <ol>
+     *   <li>Resolves the function definition through the function registry and checks if the function implements
+     *       {@link InferenceFunction}</li>
+     *   <li>Extracts the inference deployment ID from the function arguments</li>
+     * </ol>
+ *

+ * This operates during pre-analysis when functions are still unresolved, allowing early + * validation of inference deployments before query optimization. + * + * @param plan The logical plan to scan for inference function calls + * @param c Consumer function to receive each discovered inference ID + */ + private void collectInferenceIdsFromInferenceFunctions(LogicalPlan plan, Consumer c) { + EsqlFunctionRegistry snapshotRegistry = functionRegistry.snapshotRegistry(); + plan.forEachExpressionUp(UnresolvedFunction.class, f -> { + String functionName = snapshotRegistry.resolveAlias(f.name()); + if (snapshotRegistry.functionExists(functionName)) { + FunctionDefinition def = snapshotRegistry.resolveFunction(functionName); + if (InferenceFunction.class.isAssignableFrom(def.clazz())) { + String inferenceId = inferenceId(f, def); + if (inferenceId != null) { + c.accept(inferenceId); + } + } + } + }); + } + + /** + * Collects inference IDs from InferencePlan objects within the logical plan. + * + * @param plan The logical plan to scan for InferencePlan objects + * @param c Consumer function to receive each discovered inference ID + */ + private void collectInferenceIdsFromInferencePlans(LogicalPlan plan, Consumer c) { + plan.forEachUp(InferencePlan.class, inferencePlan -> c.accept(inferenceId(inferencePlan))); + } + + /** + * Extracts the inference ID from an InferencePlan object. + * + * @param plan The InferencePlan object to extract the ID from + * @return The inference ID as a string + */ + private static String inferenceId(InferencePlan plan) { + return BytesRefs.toString(plan.inferenceId().fold(FoldContext.small())); + } + + private static String inferenceId(Expression e) { + return BytesRefs.toString(e.fold(FoldContext.small())); + } + + /** + * Extracts the inference ID from an {@link UnresolvedFunction} instance. + *

+ * This method inspects the function's arguments to find the inference ID. + * Currently, it only supports positional parameters named "inference_id". + * + * @param f The unresolved function to extract the ID from + * @param def The function definition + * @return The inference ID as a string, or null if not found + */ + public String inferenceId(UnresolvedFunction f, FunctionDefinition def) { + EsqlFunctionRegistry.FunctionDescription functionDescription = EsqlFunctionRegistry.description(def); + + for (int i = 0; i < functionDescription.args().size(); i++) { + EsqlFunctionRegistry.ArgSignature arg = functionDescription.args().get(i); + + if (arg.name().equals(InferenceFunction.INFERENCE_ID_PARAMETER_NAME)) { + // Found a positional parameter named "inference_id", so use its value + Expression argValue = f.arguments().get(i); + if (argValue != null && argValue.foldable()) { + return inferenceId(argValue); + } + } + + // TODO: support inference ID as an optional named parameter + } + + return null; + } +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceRunner.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceRunner.java deleted file mode 100644 index d67d6817742c0..0000000000000 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceRunner.java +++ /dev/null @@ -1,84 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. 
- */ - -package org.elasticsearch.xpack.esql.inference; - -import org.elasticsearch.action.ActionListener; -import org.elasticsearch.action.support.CountDownActionListener; -import org.elasticsearch.client.internal.Client; -import org.elasticsearch.common.lucene.BytesRefs; -import org.elasticsearch.inference.TaskType; -import org.elasticsearch.threadpool.ThreadPool; -import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction; -import org.elasticsearch.xpack.core.inference.action.InferenceAction; -import org.elasticsearch.xpack.esql.core.expression.FoldContext; -import org.elasticsearch.xpack.esql.plan.logical.inference.InferencePlan; - -import java.util.List; -import java.util.Set; -import java.util.stream.Collectors; - -import static org.elasticsearch.xpack.core.ClientHelper.INFERENCE_ORIGIN; -import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin; - -public class InferenceRunner { - - private final Client client; - private final ThreadPool threadPool; - - public InferenceRunner(Client client, ThreadPool threadPool) { - this.client = client; - this.threadPool = threadPool; - } - - public ThreadPool threadPool() { - return threadPool; - } - - public void resolveInferenceIds(List> plans, ActionListener listener) { - resolveInferenceIds(plans.stream().map(InferenceRunner::planInferenceId).collect(Collectors.toSet()), listener); - - } - - private void resolveInferenceIds(Set inferenceIds, ActionListener listener) { - - if (inferenceIds.isEmpty()) { - listener.onResponse(InferenceResolution.EMPTY); - return; - } - - final InferenceResolution.Builder inferenceResolutionBuilder = InferenceResolution.builder(); - - final CountDownActionListener countdownListener = new CountDownActionListener( - inferenceIds.size(), - ActionListener.wrap(_r -> listener.onResponse(inferenceResolutionBuilder.build()), listener::onFailure) - ); - - for (var inferenceId : inferenceIds) { - client.execute( - GetInferenceModelAction.INSTANCE, - new 
GetInferenceModelAction.Request(inferenceId, TaskType.ANY), - ActionListener.wrap(r -> { - ResolvedInference resolvedInference = new ResolvedInference(inferenceId, r.getEndpoints().getFirst().getTaskType()); - inferenceResolutionBuilder.withResolvedInference(resolvedInference); - countdownListener.onResponse(null); - }, e -> { - inferenceResolutionBuilder.withError(inferenceId, e.getMessage()); - countdownListener.onResponse(null); - }) - ); - } - } - - private static String planInferenceId(InferencePlan plan) { - return BytesRefs.toString(plan.inferenceId().fold(FoldContext.small())); - } - - public void doInference(InferenceAction.Request request, ActionListener listener) { - executeAsyncWithOrigin(client, INFERENCE_ORIGIN, InferenceAction.INSTANCE, request, listener); - } -} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceServices.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceServices.java new file mode 100644 index 0000000000000..34236f2fb57fe --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceServices.java @@ -0,0 +1,36 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.esql.inference; + +import org.elasticsearch.client.internal.Client; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.esql.expression.function.EsqlFunctionRegistry; +import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceExecutionConfig; +import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceExecutor; + +public class InferenceServices { + private final Client client; + private final BulkInferenceExecutor.Factory bulkInferenceExecutorFactory; + + public InferenceServices(Client client, ThreadPool threadPool) { + this.client = client; + this.bulkInferenceExecutorFactory = new BulkInferenceExecutor.Factory(client, threadPool); + } + + public BulkInferenceExecutor bulkInferenceExecutor(BulkInferenceExecutionConfig bulkExecutionConfig) { + return bulkInferenceExecutorFactory.create(bulkExecutionConfig); + } + + public BulkInferenceExecutor.Factory bulkInferenceExecutorFactory() { + return bulkInferenceExecutorFactory; + } + + public InferenceResolver inferenceResolver(EsqlFunctionRegistry functionRegistry) { + return new InferenceResolver(functionRegistry, client); + } +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceExecutor.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceExecutor.java index 257799962dda7..cd382df3ece9b 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceExecutor.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceExecutor.java @@ -8,10 +8,10 @@ package org.elasticsearch.xpack.esql.inference.bulk; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.client.internal.Client; import org.elasticsearch.common.util.concurrent.AbstractRunnable; import org.elasticsearch.threadpool.ThreadPool; import 
org.elasticsearch.xpack.core.inference.action.InferenceAction; -import org.elasticsearch.xpack.esql.inference.InferenceRunner; import org.elasticsearch.xpack.esql.plugin.EsqlPlugin; import java.util.ArrayList; @@ -22,6 +22,9 @@ import java.util.concurrent.Semaphore; import java.util.concurrent.atomic.AtomicBoolean; +import static org.elasticsearch.xpack.core.ClientHelper.INFERENCE_ORIGIN; +import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin; + /** * Executes a sequence of inference requests in bulk with throttling and concurrency control. */ @@ -32,12 +35,11 @@ public class BulkInferenceExecutor { /** * Constructs a new {@code BulkInferenceExecutor}. * - * @param inferenceRunner The inference runner used to execute individual inference requests. - * @param threadPool The thread pool for executing inference tasks. - * @param bulkExecutionConfig Configuration options (throttling and concurrency limits). + * @param throttledInferenceRunner The throttled inference runner used to execute individual inference requests. + * @param bulkExecutionConfig Configuration options (throttling and concurrency limits). */ - public BulkInferenceExecutor(InferenceRunner inferenceRunner, ThreadPool threadPool, BulkInferenceExecutionConfig bulkExecutionConfig) { - this.throttledInferenceRunner = ThrottledInferenceRunner.create(inferenceRunner, executorService(threadPool), bulkExecutionConfig); + private BulkInferenceExecutor(ThrottledInferenceRunner throttledInferenceRunner, BulkInferenceExecutionConfig bulkExecutionConfig) { + this.throttledInferenceRunner = throttledInferenceRunner; this.bulkExecutionConfig = bulkExecutionConfig; } @@ -149,33 +151,18 @@ private void sendResponseOnCompletion() { * Manages throttled inference tasks execution. 
*/ private static class ThrottledInferenceRunner { - private final InferenceRunner inferenceRunner; + private final Client client; private final ExecutorService executorService; private final BlockingQueue pendingRequestsQueue; private final Semaphore permits; - private ThrottledInferenceRunner(InferenceRunner inferenceRunner, ExecutorService executorService, int maxRunningTasks) { + private ThrottledInferenceRunner(Client client, ExecutorService executorService, int maxRunningTasks) { this.executorService = executorService; this.permits = new Semaphore(maxRunningTasks); - this.inferenceRunner = inferenceRunner; + this.client = client; this.pendingRequestsQueue = new ArrayBlockingQueue<>(maxRunningTasks); } - /** - * Creates a new {@code ThrottledInferenceRunner} with the specified configuration. - * - * @param inferenceRunner TThe inference runner used to execute individual inference requests. - * @param executorService The executor used for asynchronous execution. - * @param bulkExecutionConfig Configuration options (throttling and concurrency limits). - */ - public static ThrottledInferenceRunner create( - InferenceRunner inferenceRunner, - ExecutorService executorService, - BulkInferenceExecutionConfig bulkExecutionConfig - ) { - return new ThrottledInferenceRunner(inferenceRunner, executorService, bulkExecutionConfig.maxOutstandingRequests()); - } - /** * Schedules the inference task for execution. If a permit is available, the task runs immediately; otherwise, it is queued. * @@ -212,7 +199,7 @@ private void executePendingRequests() { * Add an inference task to the queue. * * @param request The inference request. - * * @param listener The listener to notify on response or failure. + * @param listener The listener to notify on response or failure. 
*/ private void enqueueTask(InferenceAction.Request request, ActionListener listener) { try { @@ -240,7 +227,7 @@ private AbstractRunnable createTask(InferenceAction.Request request, ActionListe @Override protected void doRun() { try { - inferenceRunner.doInference(request, completionListener); + executeAsyncWithOrigin(client, INFERENCE_ORIGIN, InferenceAction.INSTANCE, request, listener); } catch (Throwable e) { listener.onFailure(new RuntimeException("Unexpected failure while running inference", e)); } @@ -257,4 +244,13 @@ public void onFailure(Exception e) { private static ExecutorService executorService(ThreadPool threadPool) { return threadPool.executor(EsqlPlugin.ESQL_WORKER_THREAD_POOL_NAME); } + + public record Factory(Client client, ThreadPool threadPool) { + public BulkInferenceExecutor create(BulkInferenceExecutionConfig bulkExecutionConfig) { + return new BulkInferenceExecutor( + new ThrottledInferenceRunner(client, executorService(threadPool), bulkExecutionConfig.maxOutstandingRequests()), + bulkExecutionConfig + ); + } + } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperator.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperator.java index e53fda90c88b3..6c686ed02b233 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperator.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperator.java @@ -7,16 +7,16 @@ package org.elasticsearch.xpack.esql.inference.completion; +import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.compute.data.BytesRefBlock; import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator; import org.elasticsearch.compute.operator.Operator; import 
org.elasticsearch.core.Releasables; -import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.esql.inference.InferenceOperator; -import org.elasticsearch.xpack.esql.inference.InferenceRunner; import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceExecutionConfig; +import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceExecutor; import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRequestIterator; import java.util.stream.IntStream; @@ -31,12 +31,12 @@ public class CompletionOperator extends InferenceOperator { public CompletionOperator( DriverContext driverContext, - InferenceRunner inferenceRunner, - ThreadPool threadPool, + ThreadContext threadContext, + BulkInferenceExecutor.Factory bulkInferenceExecutorFactory, String inferenceId, ExpressionEvaluator promptEvaluator ) { - super(driverContext, inferenceRunner, BulkInferenceExecutionConfig.DEFAULT, threadPool, inferenceId); + super(driverContext, threadContext, bulkInferenceExecutorFactory, BulkInferenceExecutionConfig.DEFAULT, inferenceId); this.promptEvaluator = promptEvaluator; } @@ -88,9 +88,11 @@ protected CompletionOperatorOutputBuilder outputBuilder(Page input) { /** * Factory for creating {@link CompletionOperator} instances. 
*/ - public record Factory(InferenceRunner inferenceRunner, String inferenceId, ExpressionEvaluator.Factory promptEvaluatorFactory) - implements - OperatorFactory { + public record Factory( + BulkInferenceExecutor.Factory inferenceExecutorFactory, + String inferenceId, + ExpressionEvaluator.Factory promptEvaluatorFactory + ) implements OperatorFactory { @Override public String describe() { return "CompletionOperator[inference_id=[" + inferenceId + "]]"; @@ -100,8 +102,8 @@ public String describe() { public Operator get(DriverContext driverContext) { return new CompletionOperator( driverContext, - inferenceRunner, - inferenceRunner.threadPool(), + inferenceExecutorFactory.threadPool().getThreadContext(), + inferenceExecutorFactory, inferenceId, promptEvaluatorFactory.get(driverContext) ); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/rerank/RerankOperator.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/rerank/RerankOperator.java index ca628fdba8a8f..cbc71d61fe6ed 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/rerank/RerankOperator.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/rerank/RerankOperator.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.esql.inference.rerank; +import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.DoubleBlock; import org.elasticsearch.compute.data.Page; @@ -14,10 +15,9 @@ import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator; import org.elasticsearch.compute.operator.Operator; import org.elasticsearch.core.Releasables; -import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.esql.inference.InferenceOperator; -import org.elasticsearch.xpack.esql.inference.InferenceRunner; import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceExecutionConfig; +import 
org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceExecutor; import java.util.stream.IntStream; @@ -40,14 +40,14 @@ public class RerankOperator extends InferenceOperator { public RerankOperator( DriverContext driverContext, - InferenceRunner inferenceRunner, - ThreadPool threadPool, + ThreadContext threadContext, + BulkInferenceExecutor.Factory bulkInferenceExecutorFactory, String inferenceId, String queryText, ExpressionEvaluator rowEncoder, int scoreChannel ) { - super(driverContext, inferenceRunner, BulkInferenceExecutionConfig.DEFAULT, threadPool, inferenceId); + super(driverContext, threadContext, bulkInferenceExecutorFactory, BulkInferenceExecutionConfig.DEFAULT, inferenceId); this.queryText = queryText; this.rowEncoder = rowEncoder; this.scoreChannel = scoreChannel; @@ -100,7 +100,7 @@ protected RerankOperatorOutputBuilder outputBuilder(Page input) { * Factory for creating {@link RerankOperator} instances */ public record Factory( - InferenceRunner inferenceRunner, + BulkInferenceExecutor.Factory inferenceExecutorFactory, String inferenceId, String queryText, ExpressionEvaluator.Factory rowEncoderFactory, @@ -116,8 +116,8 @@ public String describe() { public Operator get(DriverContext driverContext) { return new RerankOperator( driverContext, - inferenceRunner, - inferenceRunner.threadPool(), + inferenceExecutorFactory.threadPool().getThreadContext(), + inferenceExecutorFactory, inferenceId, queryText, rowEncoderFactory().get(driverContext), diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlanner.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlanner.java index ad6cb42f7f835..7c1e7372ff883 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlanner.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlanner.java @@ -86,7 +86,7 @@ import 
org.elasticsearch.xpack.esql.evaluator.EvalMapper; import org.elasticsearch.xpack.esql.evaluator.command.GrokEvaluatorExtracter; import org.elasticsearch.xpack.esql.expression.Order; -import org.elasticsearch.xpack.esql.inference.InferenceRunner; +import org.elasticsearch.xpack.esql.inference.InferenceServices; import org.elasticsearch.xpack.esql.inference.XContentRowEncoder; import org.elasticsearch.xpack.esql.inference.completion.CompletionOperator; import org.elasticsearch.xpack.esql.inference.rerank.RerankOperator; @@ -159,7 +159,7 @@ public class LocalExecutionPlanner { private final Supplier exchangeSinkSupplier; private final EnrichLookupService enrichLookupService; private final LookupFromIndexService lookupFromIndexService; - private final InferenceRunner inferenceRunner; + private final InferenceServices inferenceServices; private final PhysicalOperationProviders physicalOperationProviders; private final List shardContexts; @@ -175,7 +175,7 @@ public LocalExecutionPlanner( Supplier exchangeSinkSupplier, EnrichLookupService enrichLookupService, LookupFromIndexService lookupFromIndexService, - InferenceRunner inferenceRunner, + InferenceServices inferenceServices, PhysicalOperationProviders physicalOperationProviders, List shardContexts ) { @@ -191,7 +191,7 @@ public LocalExecutionPlanner( this.exchangeSinkSupplier = exchangeSinkSupplier; this.enrichLookupService = enrichLookupService; this.lookupFromIndexService = lookupFromIndexService; - this.inferenceRunner = inferenceRunner; + this.inferenceServices = inferenceServices; this.physicalOperationProviders = physicalOperationProviders; this.shardContexts = shardContexts; } @@ -318,7 +318,10 @@ private PhysicalOperation planCompletion(CompletionExec completion, LocalExecuti source.layout ); - return source.with(new CompletionOperator.Factory(inferenceRunner, inferenceId, promptEvaluatorFactory), outputLayout); + return source.with( + new 
CompletionOperator.Factory(inferenceServices.bulkInferenceExecutorFactory(), inferenceId, promptEvaluatorFactory), + outputLayout + ); } private PhysicalOperation planRrfScoreEvalExec(RrfScoreEvalExec rrf, LocalExecutionPlannerContext context) { @@ -465,7 +468,7 @@ private PhysicalOperation planParallelNode(ParallelExec parallelExec, LocalExecu statusInterval, settings ), - DriverParallelism.SINGLE + context.driverParallelism().get() ) ); context.driverParallelism.set(DriverParallelism.SINGLE); @@ -654,7 +657,13 @@ private PhysicalOperation planRerank(RerankExec rerank, LocalExecutionPlannerCon int scoreChannel = outputLayout.get(rerank.scoreAttribute().id()).channel(); return source.with( - new RerankOperator.Factory(inferenceRunner, inferenceId, queryText, rowEncoderFactory, scoreChannel), + new RerankOperator.Factory( + inferenceServices.bulkInferenceExecutorFactory(), + inferenceId, + queryText, + rowEncoderFactory, + scoreChannel + ), outputLayout ); } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeService.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeService.java index 39e3503b5fdd9..d0b176635d435 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeService.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeService.java @@ -52,7 +52,7 @@ import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.enrich.EnrichLookupService; import org.elasticsearch.xpack.esql.enrich.LookupFromIndexService; -import org.elasticsearch.xpack.esql.inference.InferenceRunner; +import org.elasticsearch.xpack.esql.inference.InferenceServices; import org.elasticsearch.xpack.esql.plan.physical.ExchangeSinkExec; import org.elasticsearch.xpack.esql.plan.physical.ExchangeSourceExec; import org.elasticsearch.xpack.esql.plan.physical.OutputExec; @@ -130,7 +130,7 @@ public class ComputeService { private 
final DriverTaskRunner driverRunner; private final EnrichLookupService enrichLookupService; private final LookupFromIndexService lookupFromIndexService; - private final InferenceRunner inferenceRunner; + private final InferenceServices inferenceServices; private final ClusterService clusterService; private final ProjectResolver projectResolver; private final AtomicLong childSessionIdGenerator = new AtomicLong(); @@ -158,7 +158,7 @@ public ComputeService( this.driverRunner = new DriverTaskRunner(transportService, esqlExecutor); this.enrichLookupService = enrichLookupService; this.lookupFromIndexService = lookupFromIndexService; - this.inferenceRunner = transportActionServices.inferenceRunner(); + this.inferenceServices = transportActionServices.inferenceServices(); this.clusterService = transportActionServices.clusterService(); this.projectResolver = transportActionServices.projectResolver(); this.dataNodeComputeHandler = new DataNodeComputeHandler( @@ -576,7 +576,7 @@ public SourceProvider createSourceProvider() { context.exchangeSinkSupplier(), enrichLookupService, lookupFromIndexService, - inferenceRunner, + inferenceServices, physicalOperationProviders, contexts ); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/TransportActionServices.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/TransportActionServices.java index ccabe09fd466c..3aa466740b689 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/TransportActionServices.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/TransportActionServices.java @@ -14,7 +14,7 @@ import org.elasticsearch.search.SearchService; import org.elasticsearch.transport.TransportService; import org.elasticsearch.usage.UsageService; -import org.elasticsearch.xpack.esql.inference.InferenceRunner; +import org.elasticsearch.xpack.esql.inference.InferenceServices; public record TransportActionServices( TransportService 
transportService, @@ -24,5 +24,5 @@ public record TransportActionServices( ProjectResolver projectResolver, IndexNameExpressionResolver indexNameExpressionResolver, UsageService usageService, - InferenceRunner inferenceRunner + InferenceServices inferenceServices ) {} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/TransportEsqlQueryAction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/TransportEsqlQueryAction.java index 4be7b31bb96c0..4ca54b4d9b789 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/TransportEsqlQueryAction.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/TransportEsqlQueryAction.java @@ -51,7 +51,7 @@ import org.elasticsearch.xpack.esql.enrich.LookupFromIndexService; import org.elasticsearch.xpack.esql.execution.PlanExecutor; import org.elasticsearch.xpack.esql.expression.function.UnsupportedAttribute; -import org.elasticsearch.xpack.esql.inference.InferenceRunner; +import org.elasticsearch.xpack.esql.inference.InferenceServices; import org.elasticsearch.xpack.esql.session.Configuration; import org.elasticsearch.xpack.esql.session.EsqlSession.PlanRunner; import org.elasticsearch.xpack.esql.session.Result; @@ -166,7 +166,7 @@ public TransportEsqlQueryAction( projectResolver, indexNameExpressionResolver, usageService, - new InferenceRunner(client, threadPool) + new InferenceServices(client, threadPool) ); this.computeService = new ComputeService( diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/session/EsqlSession.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/session/EsqlSession.java index 40a859e3f5b58..6b9c823a53707 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/session/EsqlSession.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/session/EsqlSession.java @@ -38,6 +38,7 @@ import 
org.elasticsearch.xpack.esql.analysis.AnalyzerContext; import org.elasticsearch.xpack.esql.analysis.EnrichResolution; import org.elasticsearch.xpack.esql.analysis.PreAnalyzer; +import org.elasticsearch.xpack.esql.analysis.PreAnalyzerContext; import org.elasticsearch.xpack.esql.analysis.Verifier; import org.elasticsearch.xpack.esql.core.expression.Alias; import org.elasticsearch.xpack.esql.core.expression.Attribute; @@ -61,7 +62,7 @@ import org.elasticsearch.xpack.esql.index.IndexResolution; import org.elasticsearch.xpack.esql.index.MappingException; import org.elasticsearch.xpack.esql.inference.InferenceResolution; -import org.elasticsearch.xpack.esql.inference.InferenceRunner; +import org.elasticsearch.xpack.esql.inference.InferenceResolver; import org.elasticsearch.xpack.esql.optimizer.LogicalPlanOptimizer; import org.elasticsearch.xpack.esql.optimizer.PhysicalOptimizerContext; import org.elasticsearch.xpack.esql.optimizer.PhysicalPlanOptimizer; @@ -88,7 +89,6 @@ import org.elasticsearch.xpack.esql.plan.logical.TopN; import org.elasticsearch.xpack.esql.plan.logical.UnresolvedRelation; import org.elasticsearch.xpack.esql.plan.logical.inference.Completion; -import org.elasticsearch.xpack.esql.plan.logical.inference.InferencePlan; import org.elasticsearch.xpack.esql.plan.logical.join.InlineJoin; import org.elasticsearch.xpack.esql.plan.logical.join.JoinTypes; import org.elasticsearch.xpack.esql.plan.logical.join.LookupJoin; @@ -148,7 +148,7 @@ public interface PlanRunner { private final PlanTelemetry planTelemetry; private final IndicesExpressionGrouper indicesExpressionGrouper; private Set configuredClusters; - private final InferenceRunner inferenceRunner; + private final InferenceResolver inferenceResolver; private final RemoteClusterService remoteClusterService; private boolean explainMode; @@ -167,7 +167,8 @@ public EsqlSession( Verifier verifier, PlanTelemetry planTelemetry, IndicesExpressionGrouper indicesExpressionGrouper, - TransportActionServices services 
+ TransportActionServices services, + InferenceResolver inferenceResolver ) { this.sessionId = sessionId; this.configuration = configuration; @@ -181,7 +182,7 @@ public EsqlSession( this.physicalPlanOptimizer = new PhysicalPlanOptimizer(new PhysicalOptimizerContext(configuration)); this.planTelemetry = planTelemetry; this.indicesExpressionGrouper = indicesExpressionGrouper; - this.inferenceRunner = services.inferenceRunner(); + this.inferenceResolver = inferenceResolver; this.preMapper = new PreMapper(services); this.remoteClusterService = services.transportService().getRemoteClusterService(); } @@ -354,7 +355,7 @@ public void analyzedPlan( // Capture configured remotes list to ensure consistency throughout the session configuredClusters = Set.copyOf(indicesExpressionGrouper.getConfiguredClusters()); - PreAnalyzer.PreAnalysis preAnalysis = preAnalyzer.preAnalyze(parsed); + PreAnalyzer.PreAnalysis preAnalysis = preAnalyzer.preAnalyze(parsed, new PreAnalyzerContext(inferenceResolver)); var unresolvedPolicies = preAnalysis.enriches.stream() .map( e -> new EnrichPolicyResolver.UnresolvedPolicy( @@ -372,7 +373,7 @@ public void analyzedPlan( l -> enrichPolicyResolver.resolvePolicies(unresolvedPolicies, executionInfo, l) ) .andThen((l, enrichResolution) -> resolveFieldNames(parsed, enrichResolution, l)) - .andThen((l, preAnalysisResult) -> resolveInferences(preAnalysis.inferencePlans, preAnalysisResult, l)); + .andThen((l, preAnalysisResult) -> resolveInferences(preAnalysis.inferenceIds, preAnalysisResult, l)); // first resolve the lookup indices, then the main indices for (var index : preAnalysis.lookupIndices) { listener = listener.andThen((l, preAnalysisResult) -> preAnalyzeLookupIndex(index, preAnalysisResult, executionInfo, l)); @@ -768,12 +769,8 @@ private static void resolveFieldNames(LogicalPlan parsed, EnrichResolution enric } } - private void resolveInferences( - List> inferencePlans, - PreAnalysisResult preAnalysisResult, - ActionListener l - ) { - 
inferenceRunner.resolveInferenceIds(inferencePlans, l.map(preAnalysisResult::withInferenceResolution)); + private void resolveInferences(List inferenceIds, PreAnalysisResult preAnalysisResult, ActionListener l) { + inferenceResolver.resolveInferenceIds(inferenceIds, l.map(preAnalysisResult::withInferenceResolution)); } static PreAnalysisResult fieldNames(LogicalPlan parsed, Set enrichPolicyMatchFields, PreAnalysisResult result) { diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java index bdf2ba39edc66..ed95abe3a3059 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java @@ -56,6 +56,7 @@ import org.elasticsearch.xpack.esql.analysis.AnalyzerContext; import org.elasticsearch.xpack.esql.analysis.EnrichResolution; import org.elasticsearch.xpack.esql.analysis.PreAnalyzer; +import org.elasticsearch.xpack.esql.analysis.PreAnalyzerContext; import org.elasticsearch.xpack.esql.core.expression.Attribute; import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.core.type.DataType; @@ -67,7 +68,8 @@ import org.elasticsearch.xpack.esql.expression.function.EsqlFunctionRegistry; import org.elasticsearch.xpack.esql.index.EsIndex; import org.elasticsearch.xpack.esql.index.IndexResolution; -import org.elasticsearch.xpack.esql.inference.InferenceRunner; +import org.elasticsearch.xpack.esql.inference.InferenceResolver; +import org.elasticsearch.xpack.esql.inference.InferenceServices; import org.elasticsearch.xpack.esql.optimizer.LocalLogicalOptimizerContext; import org.elasticsearch.xpack.esql.optimizer.LocalLogicalPlanOptimizer; import org.elasticsearch.xpack.esql.optimizer.LocalPhysicalOptimizerContext; @@ -98,7 +100,6 @@ import org.elasticsearch.xpack.esql.telemetry.PlanTelemetry; import org.junit.After; import 
org.junit.Before; -import org.mockito.Mockito; import java.io.IOException; import java.net.URL; @@ -131,6 +132,7 @@ import static org.hamcrest.Matchers.in; import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.notNullValue; +import static org.mockito.Mockito.mock; /** * CSV-based unit testing. @@ -517,7 +519,7 @@ private LogicalPlan analyzedPlan(LogicalPlan parsed, CsvTestsDataLoader.MultiInd } private static CsvTestsDataLoader.MultiIndexTestDataset testDatasets(LogicalPlan parsed) { - var preAnalysis = new PreAnalyzer().preAnalyze(parsed); + var preAnalysis = new PreAnalyzer().preAnalyze(parsed, new PreAnalyzerContext(mock(InferenceResolver.class))); var indices = preAnalysis.indices; if (indices.isEmpty()) { /* @@ -584,7 +586,8 @@ private ActualResults executePlan(BigArrays bigArrays) throws Exception { TEST_VERIFIER, new PlanTelemetry(functionRegistry), null, - EsqlTestUtils.MOCK_TRANSPORT_ACTION_SERVICES + EsqlTestUtils.MOCK_TRANSPORT_ACTION_SERVICES, + EsqlTestUtils.MOCK_TRANSPORT_ACTION_SERVICES.inferenceServices().inferenceResolver(functionRegistry) ); TestPhysicalOperationProviders physicalOperationProviders = testOperationProviders(foldCtx, testDatasets); @@ -695,9 +698,9 @@ void executeSubPlan( configuration, exchangeSource::createExchangeSource, () -> exchangeSink.createExchangeSink(() -> {}), - Mockito.mock(EnrichLookupService.class), - Mockito.mock(LookupFromIndexService.class), - Mockito.mock(InferenceRunner.class), + mock(EnrichLookupService.class), + mock(LookupFromIndexService.class), + mock(InferenceServices.class), physicalOperationProviders, List.of() ); diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTestUtils.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTestUtils.java index cbb825ca9581b..6009fa5b725dc 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTestUtils.java +++ 
b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTestUtils.java @@ -200,6 +200,7 @@ public static InferenceResolution defaultInferenceResolution() { return InferenceResolution.builder() .withResolvedInference(new ResolvedInference("reranking-inference-id", TaskType.RERANK)) .withResolvedInference(new ResolvedInference("completion-inference-id", TaskType.COMPLETION)) + .withResolvedInference(new ResolvedInference("text-embedding-inference-id", TaskType.TEXT_EMBEDDING)) .withError("error-inference-id", "error with inference resolution") .build(); } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java index b2521bddfb47b..2a5ae50a25b69 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java @@ -50,6 +50,7 @@ import org.elasticsearch.xpack.esql.expression.function.fulltext.MultiMatch; import org.elasticsearch.xpack.esql.expression.function.fulltext.QueryString; import org.elasticsearch.xpack.esql.expression.function.grouping.Bucket; +import org.elasticsearch.xpack.esql.expression.function.inference.EmbedText; import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToDateNanos; import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToDatetime; import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToInteger; @@ -3458,7 +3459,11 @@ private void assertProjectionWithMapping(String query, String mapping, QueryPara } private void assertError(String query, String mapping, QueryParams params, String error) { - Throwable e = expectThrows(VerificationException.class, () -> analyze(query, mapping, params)); + assertError(query, mapping, params, error, VerificationException.class); + } + + private void assertError(String 
query, String mapping, QueryParams params, String error, Class clazz) { + Throwable e = expectThrows(clazz, () -> analyze(query, mapping, params)); assertThat(e.getMessage(), containsString(error)); } @@ -3802,6 +3807,129 @@ public void testResolveCompletionOutputField() { assertThat(getAttributeByName(esRelation.output(), "description"), not(equalTo(completion.targetField()))); } + public void testResolveEmbedTextInferenceId() { + LogicalPlan plan = analyze(""" + FROM books METADATA _score + | EVAL embedding = EMBED_TEXT("description", "text-embedding-inference-id") + """, "mapping-books.json"); + + var limit = as(plan, Limit.class); + var eval = as(limit.child(), Eval.class); + var embedTextAlias = as(eval.fields().get(0), Alias.class); + var embedText = as(embedTextAlias.child(), EmbedText.class); + + assertThat(embedText.inferenceId(), equalTo(string("text-embedding-inference-id"))); + assertThat(embedText.inputText(), equalTo(string("description"))); + } + + public void testResolveEmbedTextInferenceIdInvalidTaskType() { + assertError( + """ + FROM books METADATA _score + | EVAL embedding = EMBED_TEXT(description, "completion-inference-id") + """, + "mapping-books.json", + new QueryParams(), + "cannot use inference endpoint [completion-inference-id] with task type [completion] within a embed_text function." 
+ + " Only inference endpoints with the task type [text_embedding] are supported" + ); + } + + public void testResolveEmbedTextInferenceMissingInferenceId() { + assertError(""" + FROM books METADATA _score + | EVAL embedding = EMBED_TEXT(description, "unknown-inference-id") + """, "mapping-books.json", new QueryParams(), "unresolved inference [unknown-inference-id]"); + } + + public void testResolveEmbedTextInferenceIdResolutionError() { + assertError(""" + FROM books METADATA _score + | EVAL embedding = EMBED_TEXT(description, "error-inference-id") + """, "mapping-books.json", new QueryParams(), "error with inference resolution"); + } + + public void testResolveEmbedTextInNestedExpression() { + LogicalPlan plan = analyze(""" + FROM colors METADATA _score + | WHERE KNN(rgb_vector, EMBED_TEXT("blue", "text-embedding-inference-id"), 10) + """, "mapping-colors.json"); + + var limit = as(plan, Limit.class); + var filter = as(limit.child(), Filter.class); + + // Navigate to the EMBED_TEXT function within the KNN function + filter.condition().forEachDown(EmbedText.class, embedText -> { + assertThat(embedText.inferenceId(), equalTo(string("text-embedding-inference-id"))); + assertThat(embedText.inputText(), equalTo(string("blue"))); + }); + } + + public void testResolveEmbedTextDataType() { + LogicalPlan plan = analyze(""" + FROM books METADATA _score + | EVAL embedding = EMBED_TEXT("description", "text-embedding-inference-id") + """, "mapping-books.json"); + + var limit = as(plan, Limit.class); + var eval = as(limit.child(), Eval.class); + var embedTextAlias = as(eval.fields().get(0), Alias.class); + var embedText = as(embedTextAlias.child(), EmbedText.class); + + assertThat(embedText.dataType(), equalTo(DataType.DENSE_VECTOR)); + } + + public void testResolveEmbedTextInvalidParameters() { + assertError( + "FROM books METADATA _score| EVAL embedding = EMBED_TEXT(description, \"text-embedding-inference-id\")", + "mapping-books.json", + new QueryParams(), + "first argument 
of [EMBED_TEXT(description, \"text-embedding-inference-id\")] must be a constant, received [description]" + ); + + assertError( + "FROM books METADATA _score| EVAL embedding = EMBED_TEXT(description)", + "mapping-books.json", + new QueryParams(), + "error building [embed_text]: function [embed_text] expects exactly two arguments, it received 1", + ParsingException.class + ); + } + + public void testResolveEmbedTextWithPositionalQueryParams() { + LogicalPlan plan = analyze( + "FROM books METADATA _score| EVAL embedding = EMBED_TEXT(?, ?)", + "mapping-books.json", + new QueryParams(List.of(paramAsConstant(null, "description"), paramAsConstant(null, "text-embedding-inference-id"))) + ); + + var limit = as(plan, Limit.class); + var eval = as(limit.child(), Eval.class); + var embedTextAlias = as(eval.fields().get(0), Alias.class); + var embedText = as(embedTextAlias.child(), EmbedText.class); + + assertThat(embedText.inferenceId(), equalTo(string("text-embedding-inference-id"))); + assertThat(embedText.inputText(), equalTo(string("description"))); + } + + public void testResolveEmbedTextWithNamedQueryParams() { + LogicalPlan plan = analyze( + "FROM books METADATA _score| EVAL embedding = EMBED_TEXT(?inputText, ?inferenceId)", + "mapping-books.json", + new QueryParams( + List.of(paramAsConstant("inputText", "description"), paramAsConstant("inferenceId", "text-embedding-inference-id")) + ) + ); + + var limit = as(plan, Limit.class); + var eval = as(limit.child(), Eval.class); + var embedTextAlias = as(eval.fields().get(0), Alias.class); + var embedText = as(embedTextAlias.child(), EmbedText.class); + + assertThat(embedText.inferenceId(), equalTo(string("text-embedding-inference-id"))); + assertThat(embedText.inputText(), equalTo(string("description"))); + } + public void testResolveGroupingsBeforeResolvingImplicitReferencesToGroupings() { var plan = analyze(""" FROM test diff --git
a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/inference/EmbedTextErrorTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/inference/EmbedTextErrorTests.java new file mode 100644 index 0000000000000..4b490a86b7587 --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/inference/EmbedTextErrorTests.java @@ -0,0 +1,73 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.inference; + +import org.elasticsearch.xpack.esql.action.EsqlCapabilities; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.TypeResolutions; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.expression.function.AbstractFunctionTestCase; +import org.elasticsearch.xpack.esql.expression.function.ErrorsForCasesWithoutExamplesTestCase; +import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier; +import org.hamcrest.Matcher; +import org.junit.Before; + +import java.util.List; +import java.util.Locale; +import java.util.Set; + +import static org.hamcrest.Matchers.equalTo; + +public class EmbedTextErrorTests extends ErrorsForCasesWithoutExamplesTestCase { + + @Before + public void checkCapability() { + assumeTrue("EMBED_TEXT is not enabled", EsqlCapabilities.Cap.EMBED_TEXT_FUNCTION.isEnabled()); + } + + @Override + protected List cases() { + return paramsToSuppliers(EmbedTextTests.parameters()); + } + + @Override + protected Expression build(Source source, List args) { + return new EmbedText(source, args.get(0), args.get(1)); + } + + 
@Override + protected Matcher expectedTypeErrorMatcher(List> validPerPosition, List signature) { + return equalTo(typeErrorMessage(true, validPerPosition, signature, (v, p) -> "string")); + } + + protected static String typeErrorMessage( + boolean includeOrdinal, + List> validPerPosition, + List signature, + AbstractFunctionTestCase.PositionalErrorMessageSupplier positionalErrorMessageSupplier + ) { + for (int i = 0; i < signature.size(); i++) { + if (signature.get(i) == DataType.NULL) { + String ordinal = includeOrdinal ? TypeResolutions.ParamOrdinal.fromIndex(i).name().toLowerCase(Locale.ROOT) + " " : ""; + return ordinal + "argument of [" + sourceForSignature(signature) + "] cannot be null, received []"; + } + + if (validPerPosition.get(i).contains(signature.get(i)) == false) { + break; + } + } + + return ErrorsForCasesWithoutExamplesTestCase.typeErrorMessage( + includeOrdinal, + validPerPosition, + signature, + positionalErrorMessageSupplier + ); + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/inference/EmbedTextSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/inference/EmbedTextSerializationTests.java new file mode 100644 index 0000000000000..59d5377e36f76 --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/inference/EmbedTextSerializationTests.java @@ -0,0 +1,45 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.esql.expression.function.inference; + +import org.elasticsearch.xpack.esql.action.EsqlCapabilities; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.expression.AbstractExpressionSerializationTests; +import org.junit.Before; + +import java.io.IOException; + +public class EmbedTextSerializationTests extends AbstractExpressionSerializationTests { + + @Before + public void checkCapability() { + assumeTrue("EMBED_TEXT is not enabled", EsqlCapabilities.Cap.EMBED_TEXT_FUNCTION.isEnabled()); + } + + @Override + protected EmbedText createTestInstance() { + Source source = randomSource(); + Expression inputText = randomChild(); + Expression inferenceId = randomChild(); + return new EmbedText(source, inputText, inferenceId); + } + + @Override + protected EmbedText mutateInstance(EmbedText instance) throws IOException { + Source source = instance.source(); + Expression inputText = instance.inputText(); + Expression inferenceId = instance.inferenceId(); + if (randomBoolean()) { + inputText = randomValueOtherThan(inputText, AbstractExpressionSerializationTests::randomChild); + } else { + inferenceId = randomValueOtherThan(inferenceId, AbstractExpressionSerializationTests::randomChild); + } + return new EmbedText(source, inputText, inferenceId); + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/inference/EmbedTextTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/inference/EmbedTextTests.java new file mode 100644 index 0000000000000..c342ee143e6af --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/inference/EmbedTextTests.java @@ -0,0 +1,72 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.inference; + +import com.carrotsearch.randomizedtesting.annotations.Name; +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; + +import org.elasticsearch.xpack.esql.action.EsqlCapabilities; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.expression.function.AbstractFunctionTestCase; +import org.elasticsearch.xpack.esql.expression.function.FunctionName; +import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier; +import org.hamcrest.Matchers; +import org.junit.Before; + +import java.util.ArrayList; +import java.util.List; +import java.util.function.Supplier; + +import static org.elasticsearch.xpack.esql.core.type.DataType.DENSE_VECTOR; +import static org.hamcrest.Matchers.equalTo; + +@FunctionName("embed_text") +public class EmbedTextTests extends AbstractFunctionTestCase { + @Before + public void checkCapability() { + assumeTrue("EMBED_TEXT is not enabled", EsqlCapabilities.Cap.EMBED_TEXT_FUNCTION.isEnabled()); + } + + public EmbedTextTests(@Name("TestCase") Supplier testCaseSupplier) { + this.testCase = testCaseSupplier.get(); + } + + @ParametersFactory + public static Iterable parameters() { + List suppliers = new ArrayList<>(); + + // Valid cases with string types for input text and inference_id + for (DataType inputTextDataType : DataType.stringTypes()) { + for (DataType inferenceIdDataType : DataType.stringTypes()) { + suppliers.add( + new TestCaseSupplier( + List.of(inputTextDataType, inferenceIdDataType), + () -> new TestCaseSupplier.TestCase( + List.of( + new TestCaseSupplier.TypedData(randomBytesReference(10).toBytesRef(), inputTextDataType, "inputText"), + new 
TestCaseSupplier.TypedData(randomBytesReference(10).toBytesRef(), inferenceIdDataType, "inference_id") + ), + Matchers.blankOrNullString(), + DENSE_VECTOR, + equalTo(true) + ) + ) + ); + } + } + + return parameterSuppliersFromTypedData(suppliers); + } + + @Override + protected Expression build(Source source, List args) { + return new EmbedText(source, args.get(0), args.get(1)); + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/InferenceOperatorTestCase.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/InferenceOperatorTestCase.java index c49e301968aa0..4081b165fc48b 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/InferenceOperatorTestCase.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/InferenceOperatorTestCase.java @@ -39,6 +39,7 @@ import org.elasticsearch.threadpool.TestThreadPool; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceExecutor; import org.elasticsearch.xpack.esql.plugin.EsqlPlugin; import org.junit.After; import org.junit.Before; @@ -118,7 +119,7 @@ public void testOperatorStatus() { } @SuppressWarnings("unchecked") - protected InferenceRunner mockedSimpleInferenceRunner() { + protected BulkInferenceExecutor.Factory mockedBulkInferenceExecutorFactory() { Client client = new NoOpClient(threadPool) { @Override protected void doExecute( @@ -144,7 +145,7 @@ protected void } }; - return new InferenceRunner(client, threadPool); + return new BulkInferenceExecutor.Factory(client, threadPool); } protected abstract InferenceResultsType mockInferenceResult(InferenceAction.Request request); diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/InferenceRunnerTests.java 
b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/InferenceResolverTests.java similarity index 60% rename from x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/InferenceRunnerTests.java rename to x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/InferenceResolverTests.java index ef7b3984bd532..a80119705b022 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/InferenceRunnerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/InferenceResolverTests.java @@ -22,16 +22,19 @@ import org.elasticsearch.threadpool.FixedExecutorBuilder; import org.elasticsearch.threadpool.TestThreadPool; import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction; -import org.elasticsearch.xpack.esql.core.expression.Literal; -import org.elasticsearch.xpack.esql.core.tree.Source; -import org.elasticsearch.xpack.esql.plan.logical.inference.InferencePlan; +import org.elasticsearch.xpack.esql.expression.function.EsqlFunctionRegistry; +import org.elasticsearch.xpack.esql.parser.EsqlParser; import org.elasticsearch.xpack.esql.plugin.EsqlPlugin; import org.junit.After; import org.junit.Before; +import java.util.HashSet; import java.util.List; +import java.util.Set; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.configuration; import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.equalTo; import static org.mockito.ArgumentMatchers.any; @@ -40,7 +43,7 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; -public class InferenceRunnerTests extends ESTestCase { +public class InferenceResolverTests extends ESTestCase { private TestThreadPool threadPool; @Before @@ -63,12 +66,74 @@ public void shutdownThreadPool() { terminate(threadPool); } + public void testCollectInferenceIds() { + 
// Rerank inference plan + assertCollectInferenceIds( + "FROM books METADATA _score | RERANK \"italian food recipe\" ON title WITH inferenceId=`rerank-inference-id`", + List.of("rerank-inference-id") + ); + + // Completion inference plan + assertCollectInferenceIds( + "FROM books METADATA _score | COMPLETION \"italian food recipe\" WITH `completion-inference-id`", + List.of("completion-inference-id") + ); + + // Multiple inference plans + assertCollectInferenceIds(""" + FROM books METADATA _score + | RERANK \"italian food recipe\" ON title WITH inferenceId=`rerank-inference-id` + | COMPLETION \"italian food recipe\" WITH `completion-inference-id` + """, List.of("rerank-inference-id", "completion-inference-id")); + + // From an inference function (EMBED_TEXT) + assertCollectInferenceIds( + "FROM books METADATA _score | EVAL embedding = EMBED_TEXT(\"italian food recipe\", \"text-embedding-inference-id\")", + List.of("text-embedding-inference-id") + ); + + // From an inference function nested in another function + assertCollectInferenceIds( + "FROM books METADATA _score | WHERE KNN(field, EMBED_TEXT(\"italian food recipe\", \"text-embedding-inference-id\"))", + List.of("text-embedding-inference-id") + ); + + // Multiple functions + assertCollectInferenceIds(""" + FROM books METADATA _score + | WHERE KNN(fieldA, EMBED_TEXT("italian food recipe", "text-embedding-inference-id-a")) + | WHERE KNN(fieldB, EMBED_TEXT("italian food recipe", "text-embedding-inference-id-b")) + """, List.of("text-embedding-inference-id-a", "text-embedding-inference-id-b")); + + // All the way + assertCollectInferenceIds( + """ + FROM books METADATA _score + | RERANK "italian food recipe" ON title WITH inferenceId=`rerank-inference-id` + | COMPLETION "italian food recipe" WITH `completion-inference-id` + | WHERE KNN(fieldA, EMBED_TEXT("italian food recipe", "text-embedding-inference-id-a")) + | WHERE KNN(fieldB, EMBED_TEXT("italian food recipe", "text-embedding-inference-id-b")) + """,
List.of("rerank-inference-id", "completion-inference-id", "text-embedding-inference-id-a", "text-embedding-inference-id-b") + ); + + // No inference operations + assertCollectInferenceIds("FROM books | WHERE title:\"test\"", List.of()); + } + + private void assertCollectInferenceIds(String query, List expectedInferenceIds) { + Set inferenceIds = new HashSet<>(); + InferenceResolver inferenceResolver = inferenceResolver(); + inferenceResolver.collectInferenceIds(new EsqlParser().createStatement(query, configuration(query)), inferenceIds::add); + assertThat(inferenceIds, containsInAnyOrder(expectedInferenceIds.toArray(new String[0]))); + } + public void testResolveInferenceIds() throws Exception { - InferenceRunner inferenceRunner = new InferenceRunner(mockClient(), threadPool); - List> inferencePlans = List.of(mockInferencePlan("rerank-plan")); + InferenceResolver inferenceResolver = inferenceResolver(); + List inferenceIds = List.of("rerank-plan"); SetOnce inferenceResolutionSetOnce = new SetOnce<>(); - inferenceRunner.resolveInferenceIds(inferencePlans, ActionListener.wrap(inferenceResolutionSetOnce::set, e -> { + inferenceResolver.resolveInferenceIds(inferenceIds, ActionListener.wrap(inferenceResolutionSetOnce::set, e -> { throw new RuntimeException(e); })); @@ -81,15 +146,11 @@ public void testResolveInferenceIds() throws Exception { } public void testResolveMultipleInferenceIds() throws Exception { - InferenceRunner inferenceRunner = new InferenceRunner(mockClient(), threadPool); - List> inferencePlans = List.of( - mockInferencePlan("rerank-plan"), - mockInferencePlan("rerank-plan"), - mockInferencePlan("completion-plan") - ); + InferenceResolver inferenceResolver = inferenceResolver(); + List inferenceIds = List.of("rerank-plan", "rerank-plan", "completion-plan"); SetOnce inferenceResolutionSetOnce = new SetOnce<>(); - inferenceRunner.resolveInferenceIds(inferencePlans, ActionListener.wrap(inferenceResolutionSetOnce::set, e -> { + 
inferenceResolver.resolveInferenceIds(inferenceIds, ActionListener.wrap(inferenceResolutionSetOnce::set, e -> { throw new RuntimeException(e); })); @@ -109,12 +170,12 @@ public void testResolveMultipleInferenceIds() throws Exception { } public void testResolveMissingInferenceIds() throws Exception { - InferenceRunner inferenceRunner = new InferenceRunner(mockClient(), threadPool); - List> inferencePlans = List.of(mockInferencePlan("missing-plan")); + InferenceResolver inferenceResolver = inferenceResolver(); + List inferenceIds = List.of("missing-plan"); SetOnce inferenceResolutionSetOnce = new SetOnce<>(); - inferenceRunner.resolveInferenceIds(inferencePlans, ActionListener.wrap(inferenceResolutionSetOnce::set, e -> { + inferenceResolver.resolveInferenceIds(inferenceIds, ActionListener.wrap(inferenceResolutionSetOnce::set, e -> { throw new RuntimeException(e); })); @@ -175,13 +236,11 @@ private static ActionResponse getInferenceModelResponse(GetInferenceModelAction. return null; } - private static ModelConfigurations mockModelConfig(String inferenceId, TaskType taskType) { - return new ModelConfigurations(inferenceId, taskType, randomIdentifier(), mock(ServiceSettings.class)); + private InferenceResolver inferenceResolver() { + return new InferenceResolver(new EsqlFunctionRegistry(), mockClient()); } - private static InferencePlan mockInferencePlan(String inferenceId) { - InferencePlan plan = mock(InferencePlan.class); - when(plan.inferenceId()).thenReturn(Literal.keyword(Source.EMPTY, inferenceId)); - return plan; + private static ModelConfigurations mockModelConfig(String inferenceId, TaskType taskType) { + return new ModelConfigurations(inferenceId, taskType, randomIdentifier(), mock(ServiceSettings.class)); } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceExecutorTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceExecutorTests.java index 
7e44c681c6fc4..b27b09932520d 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceExecutorTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceExecutorTests.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.esql.inference.bulk; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.client.internal.Client; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.core.TimeValue; @@ -17,7 +18,6 @@ import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.results.RankedDocsResults; -import org.elasticsearch.xpack.esql.inference.InferenceRunner; import org.elasticsearch.xpack.esql.plugin.EsqlPlugin; import org.junit.After; import org.junit.Before; @@ -33,6 +33,7 @@ import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.notNullValue; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -64,10 +65,10 @@ public void testSuccessfulExecution() throws Exception { List requests = randomInferenceRequestList(between(1, 1000)); List responses = randomInferenceResponseList(requests.size()); - InferenceRunner inferenceRunner = mockInferenceRunner(invocation -> { + Client client = mockClient(invocation -> { runWithRandomDelay(() -> { - ActionListener l = invocation.getArgument(1); - l.onResponse(responses.get(requests.indexOf(invocation.getArgument(0, InferenceAction.Request.class)))); + ActionListener l = invocation.getArgument(2); + l.onResponse(responses.get(requests.indexOf(invocation.getArgument(1, InferenceAction.Request.class)))); }); return null; }); @@ -75,7 +76,7 @@ public 
void testSuccessfulExecution() throws Exception { AtomicReference> output = new AtomicReference<>(); ActionListener> listener = ActionListener.wrap(output::set, r -> fail("Unexpected exception")); - bulkExecutor(inferenceRunner).execute(requestIterator(requests), listener); + bulkExecutor(client).execute(requestIterator(requests), listener); assertBusy(() -> assertThat(output.get(), allOf(notNullValue(), equalTo(responses)))); } @@ -87,7 +88,7 @@ public void testSuccessfulExecutionOnEmptyRequest() throws Exception { AtomicReference> output = new AtomicReference<>(); ActionListener> listener = ActionListener.wrap(output::set, r -> fail("Unexpected exception")); - bulkExecutor(mock(InferenceRunner.class)).execute(requestIterator, listener); + bulkExecutor(mock(Client.class)).execute(requestIterator, listener); assertBusy(() -> assertThat(output.get(), allOf(notNullValue(), empty()))); } @@ -95,9 +96,9 @@ public void testSuccessfulExecutionOnEmptyRequest() throws Exception { public void testInferenceRunnerAlwaysFails() throws Exception { List requests = randomInferenceRequestList(between(1, 1000)); - InferenceRunner inferenceRunner = mock(invocation -> { + Client client = mockClient(invocation -> { runWithRandomDelay(() -> { - ActionListener listener = invocation.getArgument(1); + ActionListener listener = invocation.getArgument(2); listener.onFailure(new RuntimeException("inference failure")); }); return null; @@ -106,7 +107,7 @@ public void testInferenceRunnerAlwaysFails() throws Exception { AtomicReference exception = new AtomicReference<>(); ActionListener> listener = ActionListener.wrap(r -> fail("Expceted exception"), exception::set); - bulkExecutor(inferenceRunner).execute(requestIterator(requests), listener); + bulkExecutor(client).execute(requestIterator(requests), listener); assertBusy(() -> { assertThat(exception.get(), notNullValue()); @@ -117,10 +118,10 @@ public void testInferenceRunnerAlwaysFails() throws Exception { public void 
testInferenceRunnerSometimesFails() throws Exception { List requests = randomInferenceRequestList(between(1, 1000)); - InferenceRunner inferenceRunner = mockInferenceRunner(invocation -> { - ActionListener listener = invocation.getArgument(1); + Client client = mockClient(invocation -> { + ActionListener listener = invocation.getArgument(2); runWithRandomDelay(() -> { - if ((requests.indexOf(invocation.getArgument(0, InferenceAction.Request.class)) % requests.size()) == 0) { + if ((requests.indexOf(invocation.getArgument(1, InferenceAction.Request.class)) % requests.size()) == 0) { listener.onFailure(new RuntimeException("inference failure")); } else { listener.onResponse(mockInferenceResponse()); @@ -133,7 +134,7 @@ public void testInferenceRunnerSometimesFails() throws Exception { AtomicReference exception = new AtomicReference<>(); ActionListener> listener = ActionListener.wrap(r -> fail("Expceted exception"), exception::set); - bulkExecutor(inferenceRunner).execute(requestIterator(requests), listener); + bulkExecutor(client).execute(requestIterator(requests), listener); assertBusy(() -> { assertThat(exception.get(), notNullValue()); @@ -141,8 +142,8 @@ public void testInferenceRunnerSometimesFails() throws Exception { }); } - private BulkInferenceExecutor bulkExecutor(InferenceRunner inferenceRunner) { - return new BulkInferenceExecutor(inferenceRunner, threadPool, randomBulkExecutionConfig()); + private BulkInferenceExecutor bulkExecutor(Client client) { + return new BulkInferenceExecutor.Factory(client, threadPool).create(randomBulkExecutionConfig()); } private InferenceAction.Request mockInferenceRequest() { @@ -185,10 +186,11 @@ private List randomInferenceResponseList(int size) { return response; } - private InferenceRunner mockInferenceRunner(Answer doInferenceAnswer) { - InferenceRunner inferenceRunner = mock(InferenceRunner.class); - doAnswer(doInferenceAnswer).when(inferenceRunner).doInference(any(), any()); - return inferenceRunner; + private Client 
mockClient(Answer doInferenceAnswer) { + Client client = mock(Client.class); + when(client.threadPool()).thenReturn(threadPool); + doAnswer(doInferenceAnswer).when(client).execute(eq(InferenceAction.INSTANCE), any(InferenceAction.Request.class), any()); + return client; } private void runWithRandomDelay(Runnable runnable) { diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperatorTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperatorTests.java index add8155240ad1..c1b8c9a236b49 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperatorTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperatorTests.java @@ -29,7 +29,7 @@ public class CompletionOperatorTests extends InferenceOperatorTestCase