Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,8 @@
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.NotEquals;
import org.elasticsearch.xpack.esql.index.EsIndex;
import org.elasticsearch.xpack.esql.inference.InferenceResolution;
import org.elasticsearch.xpack.esql.inference.InferenceRunner;
import org.elasticsearch.xpack.esql.inference.InferenceResolver;
import org.elasticsearch.xpack.esql.inference.InferenceServices;
import org.elasticsearch.xpack.esql.optimizer.LogicalOptimizerContext;
import org.elasticsearch.xpack.esql.parser.QueryParam;
import org.elasticsearch.xpack.esql.plan.logical.Enrich;
Expand Down Expand Up @@ -164,6 +165,7 @@
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

public final class EsqlTestUtils {

Expand Down Expand Up @@ -401,18 +403,26 @@ public static LogicalOptimizerContext unboundLogicalOptimizerContext() {
mock(ProjectResolver.class),
mock(IndexNameExpressionResolver.class),
null,
mockInferenceRunner()
mockInferenceServices()
);

/**
 * Builds a Mockito {@code InferenceServices} stub whose {@code inferenceResolver(...)}
 * call always hands back the stubbed resolver produced by {@link #mockInferenceRunner()}.
 */
private static InferenceServices mockInferenceServices() {
InferenceServices services = mock(InferenceServices.class);
when(services.inferenceResolver(any())).thenReturn(mockInferenceRunner());
return services;
}

@SuppressWarnings("unchecked")
private static InferenceRunner mockInferenceRunner() {
InferenceRunner inferenceRunner = mock(InferenceRunner.class);
private static InferenceResolver mockInferenceRunner() {
InferenceResolver inferenceResolver = mock(InferenceResolver.class);
doAnswer(i -> {
i.getArgument(1, ActionListener.class).onResponse(emptyInferenceResolution());
return null;
}).when(inferenceRunner).resolveInferenceIds(any(), any());
}).when(inferenceResolver).resolveInferenceIds(any(), any());

return inferenceRunner;
return inferenceResolver;
}

private EsqlTestUtils() {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1245,6 +1245,11 @@ public enum Cap {
*/
AGGREGATE_METRIC_DOUBLE_AVG(AGGREGATE_METRIC_DOUBLE_FEATURE_FLAG),

/**
* Support for the {@code EMBED_TEXT} function for generating dense vector embeddings.
*/
EMBED_TEXT_FUNCTION(Build.current().isSnapshot()),

/**
* Forbid usage of brackets in unquoted index and enrich policy names
* https://github.com/elastic/elasticsearch/issues/130378
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
import org.elasticsearch.xpack.esql.expression.function.UnresolvedFunction;
import org.elasticsearch.xpack.esql.expression.function.UnsupportedAttribute;
import org.elasticsearch.xpack.esql.expression.function.grouping.GroupingFunction;
import org.elasticsearch.xpack.esql.expression.function.inference.InferenceFunction;
import org.elasticsearch.xpack.esql.expression.function.scalar.EsqlScalarFunction;
import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Case;
import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Greatest;
Expand Down Expand Up @@ -173,9 +174,9 @@ public class Analyzer extends ParameterizedRuleExecutor<LogicalPlan, AnalyzerCon
Limiter.ONCE,
new ResolveTable(),
new ResolveEnrich(),
new ResolveInference(),
new ResolveLookupTables(),
new ResolveFunctions(),
new ResolveInference(),
new DateMillisToNanosInEsRelation(IMPLICIT_CASTING_DATE_AND_DATE_NANOS.isEnabled())
),
new Batch<>(
Expand Down Expand Up @@ -398,34 +399,6 @@ private static NamedExpression createEnrichFieldExpression(
}
}

/**
 * Analyzer rule validating the inference endpoint referenced by an inference command
 * ({@code InferencePlan} node) against the {@code InferenceResolution} carried in the
 * {@code AnalyzerContext}. On a missing endpoint or a task-type mismatch the plan node
 * is replaced by one carrying a resolution error.
 */
private static class ResolveInference extends ParameterizedAnalyzerRule<InferencePlan<?>, AnalyzerContext> {
@Override
protected LogicalPlan rule(InferencePlan<?> plan, AnalyzerContext context) {
// The inference id must already be a resolved, foldable literal at this point.
assert plan.inferenceId().resolved() && plan.inferenceId().foldable();

// Fold the literal to the endpoint id string and look it up in the pre-resolved set.
String inferenceId = BytesRefs.toString(plan.inferenceId().fold(FoldContext.small()));
ResolvedInference resolvedInference = context.inferenceResolution().getResolvedInference(inferenceId);

if (resolvedInference != null && resolvedInference.taskType() == plan.taskType()) {
// Endpoint exists and its task type matches what this command requires.
return plan;
} else if (resolvedInference != null) {
// Endpoint exists but serves a different task type: attach a descriptive error.
String error = "cannot use inference endpoint ["
+ inferenceId
+ "] with task type ["
+ resolvedInference.taskType()
+ "] within a "
+ plan.nodeName()
+ " command. Only inference endpoints with the task type ["
+ plan.taskType()
+ "] are supported.";
return plan.withInferenceResolutionError(inferenceId, error);
} else {
// Endpoint was not resolved at all: surface the resolution error recorded earlier.
String error = context.inferenceResolution().getError(inferenceId);
return plan.withInferenceResolutionError(inferenceId, error);
}
}
}

private static class ResolveLookupTables extends ParameterizedAnalyzerRule<Lookup, AnalyzerContext> {

@Override
Expand Down Expand Up @@ -1319,6 +1292,70 @@ public static org.elasticsearch.xpack.esql.core.expression.function.Function res
}
}

/**
 * Analyzer rule that validates every inference endpoint referenced in the plan — both by
 * inference commands ({@code InferencePlan} nodes) and by inference functions
 * ({@code InferenceFunction} expressions) — against the {@code InferenceResolution}
 * carried in the {@code AnalyzerContext}. On a missing endpoint or a task-type mismatch,
 * the corresponding node/expression is replaced with one carrying a resolution error.
 */
private static class ResolveInference extends ParameterizedRule<LogicalPlan, LogicalPlan, AnalyzerContext> {

@Override
public LogicalPlan apply(LogicalPlan plan, AnalyzerContext context) {
// Resolve inference commands first, then inference functions embedded in expressions.
return plan.transformDown(InferencePlan.class, p -> resolveInferencePlan(p, context))
.transformExpressionsOnly(InferenceFunction.class, f -> resolveInferenceFunction(f, context));
}

private InferenceFunction<?> resolveInferenceFunction(InferenceFunction<?> inferenceFunction, AnalyzerContext context) {
// The inference id must already be a resolved, foldable literal at this point.
assert inferenceFunction.inferenceId().resolved() && inferenceFunction.inferenceId().foldable();

String inferenceId = BytesRefs.toString(inferenceFunction.inferenceId().fold(FoldContext.small()));
ResolvedInference resolvedInference = context.inferenceResolution().getResolvedInference(inferenceId);

if (resolvedInference == null) {
// Endpoint could not be resolved at all; propagate the recorded resolution error.
return inferenceFunction.withInferenceResolutionError(inferenceId, context.inferenceResolution().getError(inferenceId));
}

if (resolvedInference.taskType() != inferenceFunction.taskType()) {
// Endpoint exists but serves a different task type than the function requires.
String functionName = context.functionRegistry().snapshotRegistry().functionName(inferenceFunction.getClass());
return inferenceFunction.withInferenceResolutionError(
inferenceId,
taskTypeMismatchError(inferenceId, resolvedInference.taskType(), inferenceFunction.taskType(), functionName, "function")
);
}

return inferenceFunction;
}

private LogicalPlan resolveInferencePlan(InferencePlan<?> plan, AnalyzerContext context) {
// The inference id must already be a resolved, foldable literal at this point.
assert plan.inferenceId().resolved() && plan.inferenceId().foldable();

String inferenceId = BytesRefs.toString(plan.inferenceId().fold(FoldContext.small()));
ResolvedInference resolvedInference = context.inferenceResolution().getResolvedInference(inferenceId);

if (resolvedInference == null) {
// Endpoint could not be resolved at all; propagate the recorded resolution error.
return plan.withInferenceResolutionError(inferenceId, context.inferenceResolution().getError(inferenceId));
}

if (resolvedInference.taskType() != plan.taskType()) {
// Endpoint exists but serves a different task type than the command requires.
return plan.withInferenceResolutionError(
inferenceId,
taskTypeMismatchError(inferenceId, resolvedInference.taskType(), plan.taskType(), plan.nodeName(), "command")
);
}

return plan;
}

// Shared builder for the "wrong task type" error message used by both commands and functions.
// Task types are taken as Object so this helper needs no direct TaskType import; string
// concatenation invokes their toString(), matching the original inline message exactly.
private static String taskTypeMismatchError(String inferenceId, Object actualTaskType, Object expectedTaskType, String target, String targetKind) {
return "cannot use inference endpoint ["
+ inferenceId
+ "] with task type ["
+ actualTaskType
+ "] within a "
+ target
+ " "
+ targetKind
+ ". Only inference endpoints with the task type ["
+ expectedTaskType
+ "] are supported.";
}
}

private static class AddImplicitLimit extends ParameterizedRule<LogicalPlan, LogicalPlan, AnalyzerContext> {
@Override
public LogicalPlan apply(LogicalPlan logicalPlan, AnalyzerContext context) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import org.elasticsearch.xpack.esql.plan.logical.Enrich;
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
import org.elasticsearch.xpack.esql.plan.logical.UnresolvedRelation;
import org.elasticsearch.xpack.esql.plan.logical.inference.InferencePlan;

import java.util.ArrayList;
import java.util.HashSet;
Expand All @@ -33,38 +32,39 @@ public static class PreAnalysis {
public final IndexMode indexMode;
public final List<IndexPattern> indices;
public final List<Enrich> enriches;
public final List<InferencePlan<?>> inferencePlans;
public final List<String> inferenceIds;
public final List<IndexPattern> lookupIndices;

public PreAnalysis(
IndexMode indexMode,
List<IndexPattern> indices,
List<Enrich> enriches,
List<InferencePlan<?>> inferencePlans,
List<String> inferenceIds,
List<IndexPattern> lookupIndices
) {
this.indexMode = indexMode;
this.indices = indices;
this.enriches = enriches;
this.inferencePlans = inferencePlans;
this.inferenceIds = inferenceIds;
this.lookupIndices = lookupIndices;
}
}

public PreAnalysis preAnalyze(LogicalPlan plan) {
public PreAnalysis preAnalyze(LogicalPlan plan, PreAnalyzerContext context) {
if (plan.analyzed()) {
return PreAnalysis.EMPTY;
}

return doPreAnalyze(plan);
return doPreAnalyze(plan, context);
}

protected PreAnalysis doPreAnalyze(LogicalPlan plan) {
protected PreAnalysis doPreAnalyze(LogicalPlan plan, PreAnalyzerContext context) {
Set<IndexPattern> indices = new HashSet<>();

List<Enrich> unresolvedEnriches = new ArrayList<>();
List<IndexPattern> lookupIndices = new ArrayList<>();
List<InferencePlan<?>> unresolvedInferencePlans = new ArrayList<>();
Set<String> inferenceIds = new HashSet<>();

Holder<IndexMode> indexMode = new Holder<>();
plan.forEachUp(UnresolvedRelation.class, p -> {
if (p.indexMode() == IndexMode.LOOKUP) {
Expand All @@ -78,11 +78,12 @@ protected PreAnalysis doPreAnalyze(LogicalPlan plan) {
});

plan.forEachUp(Enrich.class, unresolvedEnriches::add);
plan.forEachUp(InferencePlan.class, unresolvedInferencePlans::add);

context.inferenceResolver().collectInferenceIds(plan, inferenceIds::add);
// mark plan as preAnalyzed (if it were marked, there would be no analysis)
plan.forEachUp(LogicalPlan::setPreAnalyzed);

return new PreAnalysis(indexMode.get(), indices.stream().toList(), unresolvedEnriches, unresolvedInferencePlans, lookupIndices);
return new PreAnalysis(indexMode.get(), indices.stream().toList(), unresolvedEnriches, List.copyOf(inferenceIds), lookupIndices);
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.esql.analysis;

import org.elasticsearch.xpack.esql.inference.InferenceResolver;

/**
 * Context passed to the pre-analyzer.
 *
 * @param inferenceResolver resolver used during pre-analysis to collect the inference
 *                          endpoint ids referenced by a logical plan
 */
public record PreAnalyzerContext(InferenceResolver inferenceResolver) {

}
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,8 @@ public void esql(
verifier,
planTelemetry,
indicesExpressionGrouper,
services
services,
services.inferenceServices().inferenceResolver(functionRegistry)
);
QueryMetric clientId = QueryMetric.fromString("rest");
metrics.total(clientId);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import org.elasticsearch.xpack.esql.expression.function.UnsupportedAttribute;
import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateWritables;
import org.elasticsearch.xpack.esql.expression.function.fulltext.FullTextWritables;
import org.elasticsearch.xpack.esql.expression.function.inference.EmbedText;
import org.elasticsearch.xpack.esql.expression.function.scalar.ScalarFunctionWritables;
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.FromBase64;
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToAggregateMetricDouble;
Expand Down Expand Up @@ -119,6 +120,7 @@ public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {
entries.addAll(fullText());
entries.addAll(unaryScalars());
entries.addAll(vector());
entries.addAll(inference());
return entries;
}

Expand Down Expand Up @@ -264,4 +266,11 @@ private static List<NamedWriteableRegistry.Entry> vector() {
}
return List.of();
}

/**
 * Named-writeable entries for inference functions; empty unless the
 * {@code EMBED_TEXT_FUNCTION} capability is enabled.
 */
private static List<NamedWriteableRegistry.Entry> inference() {
return EsqlCapabilities.Cap.EMBED_TEXT_FUNCTION.isEnabled()
? List.of(EmbedText.ENTRY)
: List.of();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
import org.elasticsearch.xpack.esql.expression.function.fulltext.Term;
import org.elasticsearch.xpack.esql.expression.function.grouping.Bucket;
import org.elasticsearch.xpack.esql.expression.function.grouping.Categorize;
import org.elasticsearch.xpack.esql.expression.function.inference.EmbedText;
import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Case;
import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Greatest;
import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Least;
Expand Down Expand Up @@ -479,6 +480,7 @@ private static FunctionDefinition[][] snapshotFunctions() {
def(FirstOverTime.class, uni(FirstOverTime::new), "first_over_time"),
def(Term.class, bi(Term::new), "term"),
def(Knn.class, Knn::new, "knn"),
def(EmbedText.class, EmbedText::new, "embed_text"),
def(StGeohash.class, StGeohash::new, "st_geohash"),
def(StGeohashToLong.class, StGeohashToLong::new, "st_geohash_to_long"),
def(StGeohashToString.class, StGeohashToString::new, "st_geohash_to_string"),
Expand Down
Loading
Loading