Skip to content

Commit ac63795

Browse files
committed
[ES|QL] COMPLETION command grammar and logical plan (#126319)
1 parent e5552ce commit ac63795

File tree

16 files changed

+331
-25
lines changed

16 files changed

+331
-25
lines changed

docs/changelog/126319.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 126319
2+
summary: COMPLETION command grammar and logical plan
3+
area: ES|QL
4+
type: feature
5+
issues: []

x-pack/plugin/esql/src/main/antlr/EsqlBaseParser.g4

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ processingCommand
5454
| joinCommand
5555
| changePointCommand
5656
// in development
57+
| {this.isDevVersion()}? completionCommand
5758
| {this.isDevVersion()}? inlinestatsCommand
5859
| {this.isDevVersion()}? lookupCommand
5960
| {this.isDevVersion()}? rerankCommand
@@ -371,3 +372,7 @@ joinPredicate
371372
rerankCommand
372373
: DEV_RERANK queryText=constant ON fields WITH inferenceId=identifierOrParameter
373374
;
375+
376+
completionCommand
377+
: DEV_COMPLETION prompt=primaryExpression WITH inferenceId=identifierOrParameter (AS targetField=qualifiedName)?
378+
;

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -771,6 +771,11 @@ public enum Cap {
771771
*/
772772
RERANK(Build.current().isSnapshot()),
773773

774+
/**
775+
* Support for COMPLETION command
776+
*/
777+
COMPLETION(Build.current().isSnapshot()),
778+
774779
/**
775780
* Allow mixed numeric types in conditional functions - case, greatest and least
776781
*/

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -385,9 +385,9 @@ private static NamedExpression createEnrichFieldExpression(
385385
}
386386
}
387387

388-
private static class ResolveInference extends ParameterizedAnalyzerRule<InferencePlan, AnalyzerContext> {
388+
private static class ResolveInference extends ParameterizedAnalyzerRule<InferencePlan<?>, AnalyzerContext> {
389389
@Override
390-
protected LogicalPlan rule(InferencePlan plan, AnalyzerContext context) {
390+
protected LogicalPlan rule(InferencePlan<?> plan, AnalyzerContext context) {
391391
assert plan.inferenceId().resolved() && plan.inferenceId().foldable();
392392

393393
String inferenceId = plan.inferenceId().fold(FoldContext.small()).toString();

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/PreAnalyzer.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,10 @@ public static class PreAnalysis {
2828

2929
public final List<TableInfo> indices;
3030
public final List<Enrich> enriches;
31-
public final List<InferencePlan> inferencePlans;
31+
public final List<InferencePlan<?>> inferencePlans;
3232
public final List<TableInfo> lookupIndices;
3333

34-
public PreAnalysis(List<TableInfo> indices, List<Enrich> enriches, List<InferencePlan> inferencePlans, List<TableInfo> lookupIndices) {
34+
public PreAnalysis(List<TableInfo> indices, List<Enrich> enriches, List<InferencePlan<?>> inferencePlans, List<TableInfo> lookupIndices) {
3535
this.indices = indices;
3636
this.enriches = enriches;
3737
this.inferencePlans = inferencePlans;
@@ -51,7 +51,7 @@ protected PreAnalysis doPreAnalyze(LogicalPlan plan) {
5151
List<TableInfo> indices = new ArrayList<>();
5252
List<Enrich> unresolvedEnriches = new ArrayList<>();
5353
List<TableInfo> lookupIndices = new ArrayList<>();
54-
List<InferencePlan> unresolvedInferencePlans = new ArrayList<>();
54+
List<InferencePlan<?>> unresolvedInferencePlans = new ArrayList<>();
5555

5656
plan.forEachUp(UnresolvedRelation.class, p -> {
5757
List<TableInfo> list = p.indexMode() == IndexMode.LOOKUP ? lookupIndices : indices;

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceRunner.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ public ThreadContext getThreadContext() {
3333
return client.threadPool().getThreadContext();
3434
}
3535

36-
public void resolveInferenceIds(List<InferencePlan> plans, ActionListener<InferenceResolution> listener) {
36+
public void resolveInferenceIds(List<InferencePlan<?>> plans, ActionListener<InferenceResolution> listener) {
3737
resolveInferenceIds(plans.stream().map(InferenceRunner::planInferenceId).collect(Collectors.toSet()), listener);
3838

3939
}
@@ -68,7 +68,7 @@ private void resolveInferenceIds(Set<String> inferenceIds, ActionListener<Infere
6868
}
6969
}
7070

71-
private static String planInferenceId(InferencePlan plan) {
71+
private static String planInferenceId(InferencePlan<?> plan) {
7272
return plan.inferenceId().fold(FoldContext.small()).toString();
7373
}
7474

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/LogicalPlanBuilder.java

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@
6161
import org.elasticsearch.xpack.esql.plan.logical.Rename;
6262
import org.elasticsearch.xpack.esql.plan.logical.Row;
6363
import org.elasticsearch.xpack.esql.plan.logical.UnresolvedRelation;
64+
import org.elasticsearch.xpack.esql.plan.logical.inference.Completion;
6465
import org.elasticsearch.xpack.esql.plan.logical.inference.Rerank;
6566
import org.elasticsearch.xpack.esql.plan.logical.join.LookupJoin;
6667
import org.elasticsearch.xpack.esql.plan.logical.show.ShowInfo;
@@ -643,12 +644,7 @@ public PlanFactory visitJoinCommand(EsqlBaseParser.JoinCommandContext ctx) {
643644

644645
@Override
645646
public PlanFactory visitRerankCommand(EsqlBaseParser.RerankCommandContext ctx) {
646-
var source = source(ctx);
647-
648-
if (false == EsqlCapabilities.Cap.RERANK.isEnabled()) {
649-
throw new ParsingException(source, "RERANK is in preview and only available in SNAPSHOT build");
650-
}
651-
647+
Source source = source(ctx);
652648
Expression queryText = expression(ctx.queryText);
653649
if (queryText instanceof Literal queryTextLiteral && DataType.isString(queryText.dataType())) {
654650
if (queryTextLiteral.value() == null) {
@@ -669,6 +665,18 @@ public PlanFactory visitRerankCommand(EsqlBaseParser.RerankCommandContext ctx) {
669665
return p -> new Rerank(source, p, inferenceId(ctx.inferenceId), queryText, visitFields(ctx.fields()));
670666
}
671667

668+
@Override
669+
public PlanFactory visitCompletionCommand(EsqlBaseParser.CompletionCommandContext ctx) {
670+
Source source = source(ctx);
671+
Expression prompt = expression(ctx.prompt);
672+
Literal inferenceId = inferenceId(ctx.inferenceId);
673+
Attribute targetField = ctx.targetField == null
674+
? new UnresolvedAttribute(source, Completion.DEFAULT_OUTPUT_FIELD_NAME)
675+
: visitQualifiedName(ctx.targetField);
676+
677+
return p -> new Completion(source, p, inferenceId, prompt, targetField);
678+
}
679+
672680
public Literal inferenceId(EsqlBaseParser.IdentifierOrParameterContext ctx) {
673681
if (ctx.identifier() != null) {
674682
return new Literal(source(ctx), visitIdentifier(ctx.identifier()), KEYWORD);

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/PlanWritables.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import org.elasticsearch.xpack.esql.plan.logical.OrderBy;
2323
import org.elasticsearch.xpack.esql.plan.logical.Project;
2424
import org.elasticsearch.xpack.esql.plan.logical.TopN;
25+
import org.elasticsearch.xpack.esql.plan.logical.inference.Completion;
2526
import org.elasticsearch.xpack.esql.plan.logical.inference.Rerank;
2627
import org.elasticsearch.xpack.esql.plan.logical.join.InlineJoin;
2728
import org.elasticsearch.xpack.esql.plan.logical.join.Join;
@@ -65,6 +66,7 @@ public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {
6566
public static List<NamedWriteableRegistry.Entry> logical() {
6667
return List.of(
6768
Aggregate.ENTRY,
69+
Completion.ENTRY,
6870
Dissect.ENTRY,
6971
Enrich.ENTRY,
7072
EsRelation.ENTRY,
Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.esql.plan.logical.inference;
9+
10+
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
11+
import org.elasticsearch.common.io.stream.StreamInput;
12+
import org.elasticsearch.common.io.stream.StreamOutput;
13+
import org.elasticsearch.inference.TaskType;
14+
import org.elasticsearch.xpack.esql.core.expression.Attribute;
15+
import org.elasticsearch.xpack.esql.core.expression.AttributeSet;
16+
import org.elasticsearch.xpack.esql.core.expression.Expression;
17+
import org.elasticsearch.xpack.esql.core.expression.NameId;
18+
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
19+
import org.elasticsearch.xpack.esql.core.tree.Source;
20+
import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
21+
import org.elasticsearch.xpack.esql.plan.GeneratingPlan;
22+
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
23+
import org.elasticsearch.xpack.esql.plan.logical.SortAgnostic;
24+
25+
import java.io.IOException;
26+
import java.util.List;
27+
import java.util.Objects;
28+
29+
import static org.elasticsearch.xpack.esql.expression.NamedExpressions.mergeOutputAttributes;
30+
31+
public class Completion extends InferencePlan<Completion> implements GeneratingPlan<Completion>, SortAgnostic {
32+
33+
public static final String DEFAULT_OUTPUT_FIELD_NAME = "completion";
34+
35+
public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(
36+
LogicalPlan.class,
37+
"Completion",
38+
Completion::new
39+
);
40+
private final Expression prompt;
41+
private final Attribute targetField;
42+
private List<Attribute> lazyOutput;
43+
44+
public Completion(Source source, LogicalPlan child, Expression inferenceId, Expression prompt, Attribute targetField) {
45+
super(source, child, inferenceId);
46+
this.prompt = prompt;
47+
this.targetField = targetField;
48+
}
49+
50+
public Completion(StreamInput in) throws IOException {
51+
this(
52+
Source.readFrom((PlanStreamInput) in),
53+
in.readNamedWriteable(LogicalPlan.class),
54+
in.readNamedWriteable(Expression.class),
55+
in.readNamedWriteable(Expression.class),
56+
in.readNamedWriteable(Attribute.class)
57+
);
58+
}
59+
60+
@Override
61+
public void writeTo(StreamOutput out) throws IOException {
62+
super.writeTo(out);
63+
out.writeNamedWriteable(prompt);
64+
out.writeNamedWriteable(targetField);
65+
}
66+
67+
public Expression prompt() {
68+
return prompt;
69+
}
70+
71+
public Attribute targetField() {
72+
return targetField;
73+
}
74+
75+
@Override
76+
public Completion withInferenceId(Expression newInferenceId) {
77+
return new Completion(source(), child(), newInferenceId, prompt, targetField);
78+
}
79+
80+
@Override
81+
public Completion replaceChild(LogicalPlan newChild) {
82+
return new Completion(source(), newChild, inferenceId(), prompt, targetField);
83+
}
84+
85+
@Override
86+
public TaskType taskType() {
87+
return TaskType.COMPLETION;
88+
}
89+
90+
@Override
91+
public String getWriteableName() {
92+
return ENTRY.name;
93+
}
94+
95+
@Override
96+
public List<Attribute> output() {
97+
if (lazyOutput == null) {
98+
lazyOutput = mergeOutputAttributes(List.of(targetField), child().output());
99+
}
100+
101+
return lazyOutput;
102+
}
103+
104+
@Override
105+
public List<Attribute> generatedAttributes() {
106+
return List.of(targetField);
107+
}
108+
109+
@Override
110+
public Completion withGeneratedNames(List<String> newNames) {
111+
checkNumberOfNewNames(newNames);
112+
return new Completion(source(), child(), inferenceId(), prompt, this.renameTargetField(newNames.get(0)));
113+
}
114+
115+
private Attribute renameTargetField(String newName) {
116+
if (newName.equals(targetField.name())) {
117+
return targetField;
118+
}
119+
120+
return targetField.withName(newName).withId(new NameId());
121+
}
122+
123+
@Override
124+
protected AttributeSet computeReferences() {
125+
return prompt.references();
126+
}
127+
128+
@Override
129+
public boolean expressionsResolved() {
130+
return super.expressionsResolved() && prompt.resolved();
131+
}
132+
133+
@Override
134+
protected NodeInfo<? extends LogicalPlan> info() {
135+
return NodeInfo.create(this, Completion::new, child(), inferenceId(), prompt, targetField);
136+
}
137+
138+
@Override
139+
public boolean equals(Object o) {
140+
if (this == o) return true;
141+
if (o == null || getClass() != o.getClass()) return false;
142+
if (super.equals(o) == false) return false;
143+
Completion completion = (Completion) o;
144+
145+
return Objects.equals(prompt, completion.prompt) && Objects.equals(targetField, completion.targetField);
146+
}
147+
148+
@Override
149+
public int hashCode() {
150+
return Objects.hash(super.hashCode(), prompt, targetField);
151+
}
152+
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/inference/InferencePlan.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import java.io.IOException;
1919
import java.util.Objects;
2020

21-
public abstract class InferencePlan extends UnaryPlan {
21+
public abstract class InferencePlan<PlanType extends InferencePlan<PlanType>> extends UnaryPlan {
2222

2323
private final Expression inferenceId;
2424

@@ -48,7 +48,7 @@ public boolean equals(Object o) {
4848
if (this == o) return true;
4949
if (o == null || getClass() != o.getClass()) return false;
5050
if (super.equals(o) == false) return false;
51-
InferencePlan other = (InferencePlan) o;
51+
InferencePlan<?> other = (InferencePlan<?>) o;
5252
return Objects.equals(inferenceId(), other.inferenceId());
5353
}
5454

@@ -59,9 +59,9 @@ public int hashCode() {
5959

6060
public abstract TaskType taskType();
6161

62-
public abstract InferencePlan withInferenceId(Expression newInferenceId);
62+
public abstract PlanType withInferenceId(Expression newInferenceId);
6363

64-
public InferencePlan withInferenceResolutionError(String inferenceId, String error) {
64+
public PlanType withInferenceResolutionError(String inferenceId, String error) {
6565
return withInferenceId(new UnresolvedAttribute(inferenceId().source(), inferenceId, error));
6666
}
6767
}

0 commit comments

Comments
 (0)