diff --git a/docs/reference/query-languages/esql/_snippets/functions/parameters/text_embedding.md b/docs/reference/query-languages/esql/_snippets/functions/parameters/text_embedding.md new file mode 100644 index 0000000000000..e2b852912c5f5 --- /dev/null +++ b/docs/reference/query-languages/esql/_snippets/functions/parameters/text_embedding.md @@ -0,0 +1,10 @@ +% This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See + +**Parameters** + +`text` +: Text to generate embeddings from + +`inference_id` +: Identifier of the inference endpoint + diff --git a/docs/reference/query-languages/esql/images/functions/text_embedding.svg b/docs/reference/query-languages/esql/images/functions/text_embedding.svg new file mode 100644 index 0000000000000..dab58c5e5bda0 --- /dev/null +++ b/docs/reference/query-languages/esql/images/functions/text_embedding.svg @@ -0,0 +1 @@ +TEXT_EMBEDDING(text,inference_id) \ No newline at end of file diff --git a/docs/reference/query-languages/esql/kibana/definition/functions/text_embedding.json b/docs/reference/query-languages/esql/kibana/definition/functions/text_embedding.json new file mode 100644 index 0000000000000..5f1f68a2b14bd --- /dev/null +++ b/docs/reference/query-languages/esql/kibana/definition/functions/text_embedding.json @@ -0,0 +1,12 @@ +{ + "comment" : "This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. 
See ../README.md for how to regenerate it.", + "type" : "scalar", + "name" : "text_embedding", + "description" : "Generates dense vector embeddings for text using a specified inference endpoint.", + "signatures" : [ ], + "examples" : [ + "ROW input=\"Who is Victor Hugo?\"\n| EVAL embedding = TEXT_EMBEDDING(\"Who is Victor Hugo?\", \"test_dense_inference\")\n;" + ], + "preview" : true, + "snapshot_only" : true +} diff --git a/docs/reference/query-languages/esql/kibana/docs/functions/text_embedding.md b/docs/reference/query-languages/esql/kibana/docs/functions/text_embedding.md new file mode 100644 index 0000000000000..f8981fb3be66a --- /dev/null +++ b/docs/reference/query-languages/esql/kibana/docs/functions/text_embedding.md @@ -0,0 +1,10 @@ +% This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it. + +### TEXT EMBEDDING +Generates dense vector embeddings for text using a specified inference endpoint. + +```esql +ROW input="Who is Victor Hugo?" +| EVAL embedding = TEXT_EMBEDDING("Who is Victor Hugo?", "test_dense_inference") +; +``` diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/text-embedding.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/text-embedding.csv-spec new file mode 100644 index 0000000000000..f026800598e10 --- /dev/null +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/text-embedding.csv-spec @@ -0,0 +1,15 @@ +placeholder +required_capability: text_embedding_function +required_capability: not_existing_capability + +// tag::embedding-eval[] +ROW input="Who is Victor Hugo?" +| EVAL embedding = TEXT_EMBEDDING("Who is Victor Hugo?", "test_dense_inference") +; +// end::embedding-eval[] + + +input:keyword | embedding:dense_vector +Who is Victor Hugo? 
| [56.0, 50.0, 48.0] +; + diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java index 150495b7a5087..0b8e0b7238d57 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java @@ -1319,6 +1319,11 @@ public enum Cap { */ KNN_FUNCTION_V5(Build.current().isSnapshot()), + /** + * Support for the {@code TEXT_EMBEDDING} function for generating dense vector embeddings. + */ + TEXT_EMBEDDING_FUNCTION(Build.current().isSnapshot()), + /** * Support for the LIKE operator with a list of wildcards. */ diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java index a5cbd69fe603d..bf0f88ed962f0 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java @@ -73,6 +73,7 @@ import org.elasticsearch.xpack.esql.expression.function.aggregate.SummationMode; import org.elasticsearch.xpack.esql.expression.function.aggregate.Values; import org.elasticsearch.xpack.esql.expression.function.grouping.GroupingFunction; +import org.elasticsearch.xpack.esql.expression.function.inference.InferenceFunction; import org.elasticsearch.xpack.esql.expression.function.scalar.EsqlScalarFunction; import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Case; import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Greatest; @@ -1419,7 +1420,8 @@ private static class ResolveInference extends ParameterizedRule resolveInferencePlan(p, context)); + return plan.transformDown(InferencePlan.class, p -> resolveInferencePlan(p, context)) + 
.transformExpressionsOnly(InferenceFunction.class, f -> resolveInferenceFunction(f, context)); } private LogicalPlan resolveInferencePlan(InferencePlan plan, AnalyzerContext context) { @@ -1448,6 +1450,36 @@ private LogicalPlan resolveInferencePlan(InferencePlan plan, AnalyzerContext return plan; } + + private InferenceFunction resolveInferenceFunction(InferenceFunction inferenceFunction, AnalyzerContext context) { + if (inferenceFunction.inferenceId().resolved() + && inferenceFunction.inferenceId().foldable() + && DataType.isString(inferenceFunction.inferenceId().dataType())) { + + String inferenceId = BytesRefs.toString(inferenceFunction.inferenceId().fold(FoldContext.small())); + ResolvedInference resolvedInference = context.inferenceResolution().getResolvedInference(inferenceId); + + if (resolvedInference == null) { + String error = context.inferenceResolution().getError(inferenceId); + return inferenceFunction.withInferenceResolutionError(inferenceId, error); + } + + if (resolvedInference.taskType() != inferenceFunction.taskType()) { + String error = "cannot use inference endpoint [" + + inferenceId + + "] with task type [" + + resolvedInference.taskType() + + "] within a " + + context.functionRegistry().snapshotRegistry().functionName(inferenceFunction.getClass()) + + " function. 
Only inference endpoints with the task type [" + + inferenceFunction.taskType() + + "] are supported."; + return inferenceFunction.withInferenceResolutionError(inferenceId, error); + } + } + + return inferenceFunction; + } } private static class AddImplicitLimit extends ParameterizedRule { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java index a08ba6123a794..0224edcb4cfe3 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java @@ -63,6 +63,7 @@ import org.elasticsearch.xpack.esql.expression.function.grouping.Bucket; import org.elasticsearch.xpack.esql.expression.function.grouping.Categorize; import org.elasticsearch.xpack.esql.expression.function.grouping.TBucket; +import org.elasticsearch.xpack.esql.expression.function.inference.TextEmbedding; import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Case; import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Greatest; import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Least; @@ -543,7 +544,8 @@ private static FunctionDefinition[][] snapshotFunctions() { def(Hamming.class, Hamming::new, "v_hamming"), def(UrlEncode.class, UrlEncode::new, "url_encode"), def(UrlEncodeComponent.class, UrlEncodeComponent::new, "url_encode_component"), - def(UrlDecode.class, UrlDecode::new, "url_decode") } }; + def(UrlDecode.class, UrlDecode::new, "url_decode"), + def(TextEmbedding.class, bi(TextEmbedding::new), "text_embedding") } }; } public EsqlFunctionRegistry snapshotRegistry() { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/InferenceFunction.java 
b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/InferenceFunction.java new file mode 100644 index 0000000000000..d2d6d9b6e2af7 --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/InferenceFunction.java @@ -0,0 +1,41 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.inference; + +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.function.Function; +import org.elasticsearch.xpack.esql.core.tree.Source; + +import java.util.List; + +/** + * Base class for ESQL functions that use inference endpoints (e.g., TEXT_EMBEDDING). + */ +public abstract class InferenceFunction> extends Function { + + public static final String INFERENCE_ID_PARAMETER_NAME = "inference_id"; + + protected InferenceFunction(Source source, List children) { + super(source, children); + } + + /** The inference endpoint identifier expression. */ + public abstract Expression inferenceId(); + + /** The task type required by this function (e.g., TEXT_EMBEDDING). */ + public abstract TaskType taskType(); + + /** Returns a copy with inference resolution error for display to user. */ + public abstract PlanType withInferenceResolutionError(String inferenceId, String error); + + /** True if this function contains nested inference function calls. 
*/ + public boolean hasNestedInferenceFunction() { + return anyMatch(e -> e instanceof InferenceFunction && e != this); + } +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/TextEmbedding.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/TextEmbedding.java new file mode 100644 index 0000000000000..ab8b8c8c3f3c9 --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/TextEmbedding.java @@ -0,0 +1,157 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.inference; + +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.UnresolvedAttribute; +import org.elasticsearch.xpack.esql.core.tree.NodeInfo; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.expression.function.Example; +import org.elasticsearch.xpack.esql.expression.function.FunctionAppliesTo; +import org.elasticsearch.xpack.esql.expression.function.FunctionAppliesToLifecycle; +import org.elasticsearch.xpack.esql.expression.function.FunctionInfo; +import org.elasticsearch.xpack.esql.expression.function.Param; + +import java.io.IOException; +import java.util.List; +import java.util.Objects; + +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.FIRST; +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.SECOND; +import static 
org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isFoldable; +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isNotNull; +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isType; + +/** + * TEXT_EMBEDDING function converts text to dense vector embeddings using an inference endpoint. + */ +public class TextEmbedding extends InferenceFunction { + + private final Expression inferenceId; + private final Expression inputText; + + @FunctionInfo( + returnType = "dense_vector", + description = "Generates dense vector embeddings for text using a specified inference endpoint.", + appliesTo = { @FunctionAppliesTo(lifeCycle = FunctionAppliesToLifecycle.DEVELOPMENT) }, + preview = true, + examples = { + @Example( + description = "Generate text embeddings using the 'test_dense_inference' inference endpoint.", + file = "text-embedding", + tag = "embedding-eval" + ) } + ) + public TextEmbedding( + Source source, + @Param(name = "text", type = { "keyword" }, description = "Text to generate embeddings from") Expression inputText, + @Param( + name = InferenceFunction.INFERENCE_ID_PARAMETER_NAME, + type = { "keyword" }, + description = "Identifier of the inference endpoint" + ) Expression inferenceId + ) { + super(source, List.of(inputText, inferenceId)); + this.inferenceId = inferenceId; + this.inputText = inputText; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + throw new UnsupportedOperationException("doesn't escape the node"); + } + + @Override + public String getWriteableName() { + throw new UnsupportedOperationException("doesn't escape the node"); + } + + public Expression inputText() { + return inputText; + } + + @Override + public Expression inferenceId() { + return inferenceId; + } + + @Override + public boolean foldable() { + return inferenceId.foldable() && inputText.foldable(); + } + + @Override + public DataType dataType() { + return DataType.DENSE_VECTOR; + } + + @Override 
+ protected TypeResolution resolveType() { + if (childrenResolved() == false) { + return new TypeResolution("Unresolved children"); + } + + TypeResolution textResolution = isNotNull(inputText, sourceText(), FIRST).and(isFoldable(inputText, sourceText(), FIRST)) + .and(isType(inputText, DataType.KEYWORD::equals, sourceText(), FIRST, "string")); + + if (textResolution.unresolved()) { + return textResolution; + } + + TypeResolution inferenceIdResolution = isNotNull(inferenceId, sourceText(), SECOND).and( + isType(inferenceId, DataType.KEYWORD::equals, sourceText(), SECOND, "string") + ).and(isFoldable(inferenceId, sourceText(), SECOND)); + + if (inferenceIdResolution.unresolved()) { + return inferenceIdResolution; + } + + return TypeResolution.TYPE_RESOLVED; + } + + @Override + public TaskType taskType() { + return TaskType.TEXT_EMBEDDING; + } + + @Override + public TextEmbedding withInferenceResolutionError(String inferenceId, String error) { + return new TextEmbedding(source(), inputText, new UnresolvedAttribute(inferenceId().source(), inferenceId, error)); + } + + @Override + public Expression replaceChildren(List newChildren) { + return new TextEmbedding(source(), newChildren.get(0), newChildren.get(1)); + } + + @Override + protected NodeInfo info() { + return NodeInfo.create(this, TextEmbedding::new, inputText, inferenceId); + } + + @Override + public String toString() { + return "TEXT_EMBEDDING(" + inputText + ", " + inferenceId + ")"; + } + + @Override + public boolean equals(Object o) { + if (o == null || getClass() != o.getClass()) return false; + if (super.equals(o) == false) return false; + TextEmbedding textEmbedding = (TextEmbedding) o; + return Objects.equals(inferenceId, textEmbedding.inferenceId) && Objects.equals(inputText, textEmbedding.inputText); + } + + @Override + public int hashCode() { + return Objects.hash(super.hashCode(), inferenceId, inputText); + } +} diff --git 
a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceResolver.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceResolver.java index 637c4d3b1ad76..28ae71ca71023 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceResolver.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceResolver.java @@ -17,6 +17,11 @@ import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.FoldContext; +import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.expression.function.EsqlFunctionRegistry; +import org.elasticsearch.xpack.esql.expression.function.FunctionDefinition; +import org.elasticsearch.xpack.esql.expression.function.UnresolvedFunction; +import org.elasticsearch.xpack.esql.expression.function.inference.InferenceFunction; import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.plan.logical.inference.InferencePlan; @@ -31,7 +36,7 @@ public class InferenceResolver { private final Client client; - + private final EsqlFunctionRegistry functionRegistry; private final ThreadPool threadPool; /** @@ -39,8 +44,9 @@ public class InferenceResolver { * * @param client The Elasticsearch client for executing inference deployment lookups */ - public InferenceResolver(Client client, ThreadPool threadPool) { + public InferenceResolver(Client client, EsqlFunctionRegistry functionRegistry, ThreadPool threadPool) { this.client = client; + this.functionRegistry = functionRegistry; this.threadPool = threadPool; } @@ -57,9 +63,8 @@ public InferenceResolver(Client client, ThreadPool threadPool) { * @param listener Callback to receive the resolution results */ public void resolveInferenceIds(LogicalPlan plan, ActionListener listener) { - List 
inferenceIds = new ArrayList<>(); - collectInferenceIds(plan, inferenceIds::add); - resolveInferenceIds(inferenceIds, listener); + + resolveInferenceIds(collectInferenceIds(plan), listener); } /** @@ -69,13 +74,17 @@ public void resolveInferenceIds(LogicalPlan plan, ActionListener *
  • {@link InferencePlan} objects (Completion, etc.)
  • + *
  • {@link InferenceFunction} objects (TextEmbedding, etc.)
  • * * * @param plan The logical plan to scan for inference operations - * @param c Consumer function to receive each discovered inference ID */ - void collectInferenceIds(LogicalPlan plan, Consumer c) { - collectInferenceIdsFromInferencePlans(plan, c); + List collectInferenceIds(LogicalPlan plan) { + List inferenceIds = new ArrayList<>(); + collectInferenceIdsFromInferencePlans(plan, inferenceIds::add); + collectInferenceIdsFromInferenceFunctions(plan, inferenceIds::add); + + return inferenceIds; } /** @@ -135,6 +144,28 @@ private void collectInferenceIdsFromInferencePlans(LogicalPlan plan, Consumer c.accept(inferenceId(inferencePlan))); } + /** + * Collects inference IDs from function expressions within the logical plan. + * + * @param plan The logical plan to scan for function expressions + * @param c Consumer function to receive each discovered inference ID + */ + private void collectInferenceIdsFromInferenceFunctions(LogicalPlan plan, Consumer c) { + EsqlFunctionRegistry snapshotRegistry = functionRegistry.snapshotRegistry(); + plan.forEachExpressionUp(UnresolvedFunction.class, f -> { + String functionName = snapshotRegistry.resolveAlias(f.name()); + if (snapshotRegistry.functionExists(functionName)) { + FunctionDefinition def = snapshotRegistry.resolveFunction(functionName); + if (InferenceFunction.class.isAssignableFrom(def.clazz())) { + String inferenceId = inferenceId(f, def); + if (inferenceId != null) { + c.accept(inferenceId); + } + } + } + }); + } + /** * Extracts the inference ID from an InferencePlan object. * @@ -142,11 +173,31 @@ private void collectInferenceIdsFromInferencePlans(LogicalPlan plan, Consumer plan) { - return inferenceId(plan.inferenceId()); + return BytesRefs.toString(plan.inferenceId().fold(FoldContext.small())); } - private static String inferenceId(Expression e) { - return BytesRefs.toString(e.fold(FoldContext.small())); + /** + * Extracts the inference ID from an InferenceFunction expression that is not yet resolved. 
+ * + * @param f The UnresolvedFunction expression representing the inference function + * @param def The FunctionDefinition of the inference function + * @return The inference ID as a string, or null if not found, invalid, or if the call supplies fewer arguments than the signature declares + */ + private static String inferenceId(UnresolvedFunction f, FunctionDefinition def) { + EsqlFunctionRegistry.FunctionDescription functionDescription = EsqlFunctionRegistry.description(def); + + for (int i = 0; i < functionDescription.args().size(); i++) { + EsqlFunctionRegistry.ArgSignature arg = functionDescription.args().get(i); + + if (i < f.arguments().size() && arg.name().equals(InferenceFunction.INFERENCE_ID_PARAMETER_NAME)) { + Expression inferenceId = f.arguments().get(i); + if (inferenceId != null && inferenceId.foldable() && DataType.isString(inferenceId.dataType())) { + return BytesRefs.toString(inferenceId.fold(FoldContext.small())); + } + } + } + + return null; } public static Factory factory(Client client) { @@ -162,8 +213,8 @@ private Factory(Client client, ThreadPool threadPool) { this.threadPool = threadPool; } - public InferenceResolver create() { - return new InferenceResolver(client, threadPool); + public InferenceResolver create(EsqlFunctionRegistry functionRegistry) { + return new InferenceResolver(client, functionRegistry, threadPool); } } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceService.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceService.java index 37c163beaecda..630477a20f447 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceService.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceService.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.esql.inference; import org.elasticsearch.client.internal.Client; +import org.elasticsearch.xpack.esql.expression.function.EsqlFunctionRegistry; import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRunner; import
org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRunnerConfig; @@ -33,10 +34,12 @@ private InferenceService(InferenceResolver.Factory inferenceResolverFactory, Bul /** * Creates an inference resolver for resolving inference IDs in logical plans. * + * @param functionRegistry the function registry to resolve functions + * * @return a new inference resolver instance */ - public InferenceResolver inferenceResolver() { - return inferenceResolverFactory.create(); + public InferenceResolver inferenceResolver(EsqlFunctionRegistry functionRegistry) { + return inferenceResolverFactory.create(functionRegistry); } public BulkInferenceRunner bulkInferenceRunner() { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/session/EsqlSession.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/session/EsqlSession.java index 95d28081696d1..19f72320e9178 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/session/EsqlSession.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/session/EsqlSession.java @@ -454,7 +454,7 @@ public void analyzedPlan( enrichPolicyResolver.resolvePolicies(preAnalysis.enriches(), executionInfo, l.map(r::withEnrichResolution)); }) .andThen((l, r) -> { - inferenceService.inferenceResolver().resolveInferenceIds(parsed, l.map(r::withInferenceResolution)); + inferenceService.inferenceResolver(functionRegistry).resolveInferenceIds(parsed, l.map(r::withInferenceResolution)); }) .andThen((l, r) -> analyzeWithRetry(parsed, requestFilter, preAnalysis, executionInfo, r, l)) .addListener(logicalPlanListener); diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java index 172017d9cde90..ae85670ed1714 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java @@ 
-333,6 +333,10 @@ public final void test() throws Throwable { "CSV tests cannot currently handle FORK", testCase.requiredCapabilities.contains(EsqlCapabilities.Cap.FORK_V9.capabilityName()) ); + assumeFalse( + "CSV tests cannot currently handle TEXT_EMBEDDING function", + testCase.requiredCapabilities.contains(EsqlCapabilities.Cap.TEXT_EMBEDDING_FUNCTION.capabilityName()) + ); assumeFalse( "CSV tests cannot currently handle multi_match function that depends on Lucene", testCase.requiredCapabilities.contains(EsqlCapabilities.Cap.MULTI_MATCH_FUNCTION.capabilityName()) diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTestUtils.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTestUtils.java index fbfa18dccc477..6f6c76efaf08e 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTestUtils.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTestUtils.java @@ -9,6 +9,7 @@ import org.elasticsearch.index.IndexMode; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.core.enrich.EnrichPolicy; import org.elasticsearch.xpack.esql.EsqlTestUtils; import org.elasticsearch.xpack.esql.core.type.EsField; @@ -26,6 +27,7 @@ import org.elasticsearch.xpack.esql.session.Configuration; import java.util.ArrayList; +import java.util.Arrays; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; @@ -195,14 +197,39 @@ public static EnrichResolution defaultEnrichResolution() { return enrichResolution; } + public static final String RERANKING_INFERENCE_ID = "reranking-inference-id"; + public static final String COMPLETION_INFERENCE_ID = "completion-inference-id"; + public static final String TEXT_EMBEDDING_INFERENCE_ID = "text-embedding-inference-id"; + public static final String CHAT_COMPLETION_INFERENCE_ID = "chat-completion-inference-id"; + public 
static final String SPARSE_EMBEDDING_INFERENCE_ID = "sparse-embedding-inference-id"; + public static final List VALID_INFERENCE_IDS = List.of( + RERANKING_INFERENCE_ID, + COMPLETION_INFERENCE_ID, + TEXT_EMBEDDING_INFERENCE_ID, + CHAT_COMPLETION_INFERENCE_ID, + SPARSE_EMBEDDING_INFERENCE_ID + ); + public static final String ERROR_INFERENCE_ID = "error-inference-id"; + public static InferenceResolution defaultInferenceResolution() { return InferenceResolution.builder() - .withResolvedInference(new ResolvedInference("reranking-inference-id", TaskType.RERANK)) - .withResolvedInference(new ResolvedInference("completion-inference-id", TaskType.COMPLETION)) - .withError("error-inference-id", "error with inference resolution") + .withResolvedInference(new ResolvedInference(RERANKING_INFERENCE_ID, TaskType.RERANK)) + .withResolvedInference(new ResolvedInference(COMPLETION_INFERENCE_ID, TaskType.COMPLETION)) + .withResolvedInference(new ResolvedInference(TEXT_EMBEDDING_INFERENCE_ID, TaskType.TEXT_EMBEDDING)) + .withResolvedInference(new ResolvedInference(CHAT_COMPLETION_INFERENCE_ID, TaskType.CHAT_COMPLETION)) + .withResolvedInference(new ResolvedInference(SPARSE_EMBEDDING_INFERENCE_ID, TaskType.SPARSE_EMBEDDING)) + .withError(ERROR_INFERENCE_ID, "error with inference resolution") .build(); } + public static String randomInferenceId() { + return ESTestCase.randomFrom(VALID_INFERENCE_IDS); + } + + public static String randomInferenceIdOtherThan(String... 
excludes) { + return ESTestCase.randomValueOtherThanMany(Arrays.asList(excludes)::contains, AnalyzerTestUtils::randomInferenceId); + } + public static void loadEnrichPolicyResolution( EnrichResolution enrich, String policyType, diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java index 911bd45edd800..896f9d7036b1f 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java @@ -56,6 +56,7 @@ import org.elasticsearch.xpack.esql.expression.function.fulltext.QueryString; import org.elasticsearch.xpack.esql.expression.function.grouping.Bucket; import org.elasticsearch.xpack.esql.expression.function.grouping.TBucket; +import org.elasticsearch.xpack.esql.expression.function.inference.TextEmbedding; import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToDateNanos; import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToDatetime; import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToDenseVector; @@ -123,6 +124,7 @@ import static org.elasticsearch.xpack.esql.EsqlTestUtils.referenceAttribute; import static org.elasticsearch.xpack.esql.EsqlTestUtils.withDefaultLimitWarning; import static org.elasticsearch.xpack.esql.analysis.Analyzer.NO_FIELDS; +import static org.elasticsearch.xpack.esql.analysis.AnalyzerTestUtils.TEXT_EMBEDDING_INFERENCE_ID; import static org.elasticsearch.xpack.esql.analysis.AnalyzerTestUtils.analyze; import static org.elasticsearch.xpack.esql.analysis.AnalyzerTestUtils.analyzer; import static org.elasticsearch.xpack.esql.analysis.AnalyzerTestUtils.analyzerDefaultMapping; @@ -130,6 +132,7 @@ import static org.elasticsearch.xpack.esql.analysis.AnalyzerTestUtils.defaultInferenceResolution; import static 
org.elasticsearch.xpack.esql.analysis.AnalyzerTestUtils.indexWithDateDateNanosUnionType; import static org.elasticsearch.xpack.esql.analysis.AnalyzerTestUtils.loadMapping; +import static org.elasticsearch.xpack.esql.analysis.AnalyzerTestUtils.randomInferenceIdOtherThan; import static org.elasticsearch.xpack.esql.analysis.AnalyzerTestUtils.tsdbIndexResolution; import static org.elasticsearch.xpack.esql.core.plugin.EsqlCorePlugin.DENSE_VECTOR_FEATURE_FLAG; import static org.elasticsearch.xpack.esql.core.tree.Source.EMPTY; @@ -3765,6 +3768,112 @@ private void assertEmptyEsRelation(LogicalPlan plan) { assertThat(esRelation.output(), equalTo(NO_FIELDS)); } + public void testTextEmbeddingResolveInferenceId() { + assumeTrue("TEXT_EMBEDDING function required", EsqlCapabilities.Cap.TEXT_EMBEDDING_FUNCTION.isEnabled()); + + LogicalPlan plan = analyze( + String.format(Locale.ROOT, """ + FROM books METADATA _score | EVAL embedding = TEXT_EMBEDDING("italian food recipe", "%s")""", TEXT_EMBEDDING_INFERENCE_ID), + "mapping-books.json" + ); + + Eval eval = as(as(plan, Limit.class).child(), Eval.class); + assertThat(eval.fields(), hasSize(1)); + Alias alias = as(eval.fields().get(0), Alias.class); + assertThat(alias.name(), equalTo("embedding")); + TextEmbedding function = as(alias.child(), TextEmbedding.class); + + assertThat(function.inputText(), equalTo(string("italian food recipe"))); + assertThat(function.inferenceId(), equalTo(string(TEXT_EMBEDDING_INFERENCE_ID))); + } + + public void testTextEmbeddingFunctionResolveType() { + assumeTrue("TEXT_EMBEDDING function required", EsqlCapabilities.Cap.TEXT_EMBEDDING_FUNCTION.isEnabled()); + + LogicalPlan plan = analyze( + String.format(Locale.ROOT, """ + FROM books METADATA _score| EVAL embedding = TEXT_EMBEDDING("italian food recipe", "%s")""", TEXT_EMBEDDING_INFERENCE_ID), + "mapping-books.json" + ); + + Eval eval = as(as(plan, Limit.class).child(), Eval.class); + assertThat(eval.fields(), hasSize(1)); + Alias alias = 
as(eval.fields().get(0), Alias.class); + assertThat(alias.name(), equalTo("embedding")); + + TextEmbedding function = as(alias.child(), TextEmbedding.class); + + assertThat(function.foldable(), equalTo(true)); + assertThat(function.dataType(), equalTo(DENSE_VECTOR)); + } + + public void testTextEmbeddingFunctionMissingInferenceIdError() { + assumeTrue("TEXT_EMBEDDING function required", EsqlCapabilities.Cap.TEXT_EMBEDDING_FUNCTION.isEnabled()); + + VerificationException ve = expectThrows( + VerificationException.class, + () -> analyze( + String.format(Locale.ROOT, """ + FROM books METADATA _score| EVAL embedding = TEXT_EMBEDDING("italian food recipe", "%s")""", "unknow-inference-id"), + "mapping-books.json" + ) + ); + + assertThat(ve.getMessage(), containsString("unresolved inference [unknow-inference-id]")); + } + + public void testTextEmbeddingFunctionInvalidInferenceIdError() { + assumeTrue("TEXT_EMBEDDING function required", EsqlCapabilities.Cap.TEXT_EMBEDDING_FUNCTION.isEnabled()); + + String inferenceId = randomInferenceIdOtherThan(TEXT_EMBEDDING_INFERENCE_ID); + VerificationException ve = expectThrows( + VerificationException.class, + () -> analyze( + String.format(Locale.ROOT, """ + FROM books METADATA _score| EVAL embedding = TEXT_EMBEDDING("italian food recipe", "%s")""", inferenceId), + "mapping-books.json" + ) + ); + + assertThat( + ve.getMessage(), + containsString(String.format(Locale.ROOT, "cannot use inference endpoint [%s] with task type", inferenceId)) + ); + } + + public void testTextEmbeddingFunctionWithoutModel() { + assumeTrue("TEXT_EMBEDDING function required", EsqlCapabilities.Cap.TEXT_EMBEDDING_FUNCTION.isEnabled()); + + ParsingException ve = expectThrows(ParsingException.class, () -> analyze(""" + FROM books METADATA _score| EVAL embedding = TEXT_EMBEDDING("italian food recipe")""", "mapping-books.json")); + + assertThat( + ve.getMessage(), + containsString(" error building [text_embedding]: function [text_embedding] expects exactly two 
arguments") + ); + } + + public void testKnnFunctionWithTextEmbedding() { + assumeTrue("KNN function capability required", EsqlCapabilities.Cap.KNN_FUNCTION_V5.isEnabled()); + assumeTrue("TEXT_EMBEDDING function required", EsqlCapabilities.Cap.TEXT_EMBEDDING_FUNCTION.isEnabled()); + + LogicalPlan plan = analyze( + String.format(Locale.ROOT, """ + from test | where KNN(float_vector, TEXT_EMBEDDING("italian food recipe", "%s"))""", TEXT_EMBEDDING_INFERENCE_ID), + "mapping-dense_vector.json" + ); + + Limit limit = as(plan, Limit.class); + Filter filter = as(limit.child(), Filter.class); + Knn knn = as(filter.condition(), Knn.class); + assertThat(knn.field(), instanceOf(FieldAttribute.class)); + assertThat(((FieldAttribute) knn.field()).name(), equalTo("float_vector")); + + TextEmbedding textEmbedding = as(knn.query(), TextEmbedding.class); + assertThat(textEmbedding.inputText(), equalTo(string("italian food recipe"))); + assertThat(textEmbedding.inferenceId(), equalTo(string(TEXT_EMBEDDING_INFERENCE_ID))); + } + public void testResolveRerankInferenceId() { { LogicalPlan plan = analyze(""" diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java index 2854105c40249..79ca5bf5fd0ef 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java @@ -41,6 +41,7 @@ import static org.elasticsearch.xpack.esql.EsqlTestUtils.TEST_CFG; import static org.elasticsearch.xpack.esql.EsqlTestUtils.paramAsConstant; import static org.elasticsearch.xpack.esql.EsqlTestUtils.withDefaultLimitWarning; +import static org.elasticsearch.xpack.esql.analysis.AnalyzerTestUtils.TEXT_EMBEDDING_INFERENCE_ID; import static org.elasticsearch.xpack.esql.analysis.AnalyzerTestUtils.loadMapping; import static 
org.elasticsearch.xpack.esql.core.type.DataType.BOOLEAN; import static org.elasticsearch.xpack.esql.core.type.DataType.CARTESIAN_POINT; @@ -2704,6 +2705,42 @@ public void testSortInTimeSeries() { and the first aggregation [STATS avg(network.connections)] is not allowed""")); } + public void testTextEmbeddingFunctionInvalidQuery() { + assumeTrue("TEXT_EMBEDDING is not enabled", EsqlCapabilities.Cap.TEXT_EMBEDDING_FUNCTION.isEnabled()); + assertThat( + error("from test | EVAL embedding = TEXT_EMBEDDING(null, ?)", defaultAnalyzer, TEXT_EMBEDDING_INFERENCE_ID), + equalTo("1:30: first argument of [TEXT_EMBEDDING(null, ?)] cannot be null, received [null]") + ); + + assertThat( + error("from test | EVAL embedding = TEXT_EMBEDDING(42, ?)", defaultAnalyzer, TEXT_EMBEDDING_INFERENCE_ID), + equalTo("1:30: first argument of [TEXT_EMBEDDING(42, ?)] must be [string], found value [42] type [integer]") + ); + + assertThat( + error("from test | EVAL embedding = TEXT_EMBEDDING(last_name, ?)", defaultAnalyzer, TEXT_EMBEDDING_INFERENCE_ID), + equalTo("1:30: first argument of [TEXT_EMBEDDING(last_name, ?)] must be a constant, received [last_name]") + ); + } + + public void testTextEmbeddingFunctionInvalidInferenceId() { + assumeTrue("TEXT_EMBEDDING is not enabled", EsqlCapabilities.Cap.TEXT_EMBEDDING_FUNCTION.isEnabled()); + assertThat( + error("from test | EVAL embedding = TEXT_EMBEDDING(?, null)", defaultAnalyzer, "query text"), + equalTo("1:30: second argument of [TEXT_EMBEDDING(?, null)] cannot be null, received [null]") + ); + + assertThat( + error("from test | EVAL embedding = TEXT_EMBEDDING(?, 42)", defaultAnalyzer, "query text"), + equalTo("1:30: second argument of [TEXT_EMBEDDING(?, 42)] must be [string], found value [42] type [integer]") + ); + + assertThat( + error("from test | EVAL embedding = TEXT_EMBEDDING(?, last_name)", defaultAnalyzer, "query text"), + equalTo("1:30: second argument of [TEXT_EMBEDDING(?, last_name)] must be a constant, received [last_name]") + ); + } + 
private void checkVectorFunctionsNullArgs(String functionInvocation) throws Exception { query("from test | eval similarity = " + functionInvocation, fullTextAnalyzer); } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/InferenceResolverTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/InferenceResolverTests.java index 7e2dfc82a9890..8ccda8a7010d1 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/InferenceResolverTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/InferenceResolverTests.java @@ -23,13 +23,13 @@ import org.elasticsearch.threadpool.TestThreadPool; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction; +import org.elasticsearch.xpack.esql.action.EsqlCapabilities; +import org.elasticsearch.xpack.esql.expression.function.EsqlFunctionRegistry; import org.elasticsearch.xpack.esql.parser.EsqlParser; import org.junit.After; import org.junit.Before; -import java.util.HashSet; import java.util.List; -import java.util.Set; import static org.elasticsearch.xpack.esql.EsqlTestUtils.configuration; import static org.hamcrest.Matchers.contains; @@ -44,6 +44,7 @@ public class InferenceResolverTests extends ESTestCase { private TestThreadPool threadPool; + private EsqlFunctionRegistry functionRegistry; @Before public void setThreadPool() { @@ -60,6 +61,11 @@ public void setThreadPool() { ); } + @Before + public void setUpFunctionRegistry() { + functionRegistry = new EsqlFunctionRegistry(); + } + @After public void shutdownThreadPool() { terminate(threadPool); @@ -78,6 +84,26 @@ public void testCollectInferenceIds() { List.of("completion-inference-id") ); + if (EsqlCapabilities.Cap.TEXT_EMBEDDING_FUNCTION.isEnabled()) { + // Text embedding inference plan + assertCollectInferenceIds( + "FROM books METADATA _score | EVAL embedding = 
TEXT_EMBEDDING(\"description\", \"text-embedding-inference-id\")", + List.of("text-embedding-inference-id") + ); + + // Test inference ID collection from an inference function + assertCollectInferenceIds( + "FROM books METADATA _score | EVAL embedding = TEXT_EMBEDDING(\"description\", \"text-embedding-inference-id\")", + List.of("text-embedding-inference-id") + ); + + // Test inference ID collection with nested functions + assertCollectInferenceIds( + "FROM books METADATA _score | EVAL embedding = TEXT_EMBEDDING(TEXT_EMBEDDING(\"nested\", \"nested-id\"), \"outer-id\")", + List.of("nested-id", "outer-id") + ); + } + // Multiple inference plans assertCollectInferenceIds(""" FROM books METADATA _score @@ -90,9 +116,8 @@ public void testCollectInferenceIds() { } private void assertCollectInferenceIds(String query, List expectedInferenceIds) { - Set inferenceIds = new HashSet<>(); InferenceResolver inferenceResolver = inferenceResolver(); - inferenceResolver.collectInferenceIds(new EsqlParser().createStatement(query, configuration(query)), inferenceIds::add); + List inferenceIds = inferenceResolver.collectInferenceIds(new EsqlParser().createStatement(query, configuration(query))); assertThat(inferenceIds, containsInAnyOrder(expectedInferenceIds.toArray(new String[0]))); } @@ -141,7 +166,7 @@ public void testResolveMultipleInferenceIds() throws Exception { public void testResolveMissingInferenceIds() throws Exception { InferenceResolver inferenceResolver = inferenceResolver(); - List inferenceIds = List.of("missing-plan"); + List inferenceIds = List.of("missing-inference-id"); SetOnce inferenceResolutionSetOnce = new SetOnce<>(); @@ -156,7 +181,7 @@ public void testResolveMissingInferenceIds() throws Exception { assertThat(inferenceResolution.resolvedInferences(), empty()); assertThat(inferenceResolution.hasError(), equalTo(true)); - assertThat(inferenceResolution.getError("missing-plan"), equalTo("inference endpoint not found")); + 
assertThat(inferenceResolution.getError("missing-inference-id"), equalTo("inference endpoint not found")); }); } @@ -204,7 +229,7 @@ private static ActionResponse getInferenceModelResponse(GetInferenceModelAction. } private InferenceResolver inferenceResolver() { - return new InferenceResolver(mockClient(), threadPool); + return new InferenceResolver(mockClient(), functionRegistry, threadPool); } private static ModelConfigurations mockModelConfig(String inferenceId, TaskType taskType) { diff --git a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/esql/60_usage.yml b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/esql/60_usage.yml index 510a90fec619a..b89d0636014b7 100644 --- a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/esql/60_usage.yml +++ b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/esql/60_usage.yml @@ -90,6 +90,7 @@ setup: - set: { esql.functions.to_long: functions_to_long } - set: { esql.functions.coalesce: functions_coalesce } - set: { esql.functions.categorize: functions_categorize } + - set: {esql.functions.text_embedding: functions_text_embedding} - do: esql.query: @@ -133,6 +134,7 @@ setup: - gt: { esql.functions.to_long: $functions_to_long } - match: { esql.functions.coalesce: $functions_coalesce } - gt: { esql.functions.categorize: $functions_categorize } + - match: {esql.functions.text_embedding: $functions_text_embedding} # There's one of these per function but that's a ton of things to check. So we just spot check that a few exist. - exists: esql.functions.delay