diff --git a/docs/changelog/143695.yaml b/docs/changelog/143695.yaml new file mode 100644 index 0000000000000..4f80d3534a55c --- /dev/null +++ b/docs/changelog/143695.yaml @@ -0,0 +1,6 @@ +area: ES|QL +issues: + - 139928 +pr: 143695 +summary: "Feat: add implicit `dense_vector` casting to coalesce" +type: enhancement diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/dense_vector.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/dense_vector.csv-spec index a7e40426b3b97..451ec3480eb7e 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/dense_vector.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/dense_vector.csv-spec @@ -213,6 +213,74 @@ id:l | v:dense_vector 4 | [0.5, 0.5, 0.5] ; +coalesceDenseVectorImplicitCastFromInts +required_capability: coalesce_dense_vector_implicit_casting + +FROM dense_vector +| EVAL v = COALESCE(float_vector, [0, 0, 0]) +| KEEP id, v +| SORT id +; + +id:l | v:dense_vector +0 | [1.0, 2.0, 3.0] +1 | [4.0, 5.0, 6.0] +2 | [9.0, 8.0, 7.0] +3 | [0.0625, 0.25, 0.75] +4 | [0.0, 0.0, 0.0] +; + +coalesceDenseVectorImplicitCastFromDoubles +required_capability: coalesce_dense_vector_implicit_casting + +FROM dense_vector +| EVAL v = COALESCE(float_vector, [0.5, 0.5, 0.5]) +| KEEP id, v +| SORT id +; + +id:l | v:dense_vector +0 | [1.0, 2.0, 3.0] +1 | [4.0, 5.0, 6.0] +2 | [9.0, 8.0, 7.0] +3 | [0.0625, 0.25, 0.75] +4 | [0.5, 0.5, 0.5] +; + +coalesceDenseVectorImplicitCastDenseVectorNotFirst +required_capability: coalesce_dense_vector_implicit_casting + +FROM dense_vector +| EVAL v = COALESCE([0.5, 0.5, 0.5], float_vector) +| KEEP id, v +| SORT id +; + +id:l | v:dense_vector +0 | [0.5, 0.5, 0.5] +1 | [0.5, 0.5, 0.5] +2 | [0.5, 0.5, 0.5] +3 | [0.5, 0.5, 0.5] +4 | [0.5, 0.5, 0.5] +; + +caseDenseVectorImplicitCastFromInts +required_capability: coalesce_dense_vector_implicit_casting + +FROM dense_vector +| EVAL v = CASE(id > 3, [0, 0, 0], float_vector) +| KEEP id, v +| SORT id +; + +id:l | v:dense_vector +0 | [1.0, 2.0, 3.0] +1 | [4.0, 5.0, 6.0] +2 | [9.0, 8.0, 7.0] +3 | [0.0625, 0.25, 0.75] +4 | [0.0, 0.0, 0.0] +; + denseVectorEqualsSameVector required_capability: dense_vector_equality diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java index f475c13936600..72bd84ffc9bd8 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java @@ -2279,6 +2279,11 @@ public enum Cap { */ TS_INFO_COMMAND, + /** + * Implicit casting of numeric and keyword arguments to dense_vector in COALESCE. + */ + COALESCE_DENSE_VECTOR_IMPLICIT_CASTING, + /** * FORK with no implicit LIMIT */ diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java index fbc1756aba8e5..6ab52570e8dde 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java @@ -108,6 +108,7 @@ import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToUnsignedLong; import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvCount; import org.elasticsearch.xpack.esql.expression.function.scalar.nulls.Coalesce; +import org.elasticsearch.xpack.esql.expression.function.vector.VectorCastable; import org.elasticsearch.xpack.esql.expression.function.vector.VectorFunction; import org.elasticsearch.xpack.esql.expression.predicate.Predicates; import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.DateTimeArithmeticOperation; @@ -2162,10 +2163,10 @@ private static Expression processScalarOrGroupingFunction( if (arg.resolved()) { var dataType = arg.dataType(); if (dataType == KEYWORD) { + if (i < targetDataTypes.size()) { + targetDataType = targetDataTypes.get(i); + } // else the last type applies to all elements in a possible list (variadic) if (arg.foldable() && ((arg instanceof EsqlScalarFunction) == false)) { - if (i < targetDataTypes.size()) { - targetDataType = targetDataTypes.get(i); - } // else the last type applies to all elements in a possible list (variadic) if (targetDataType != NULL && targetDataType != UNSUPPORTED) { Expression e = castStringLiteral(arg, targetDataType, configuration); if (e != arg) { @@ -2186,9 +2187,18 @@ private static Expression processScalarOrGroupingFunction( newChildren.add(args.get(i)); } Expression resultF = childrenChanged ? f.replaceChildren(newChildren) : f; - return targetNumericType != null && castNumericArgs - ? castMixedNumericTypes((EsqlScalarFunction) resultF, targetNumericType) - : resultF; + if (resultF instanceof EsqlScalarFunction sf) { + if (sf instanceof VectorCastable vc) { + Expression dvResult = castDenseVectorArgs(sf, vc.denseVectorCastArgIndices(), configuration); + if (dvResult != sf) { + return dvResult; + } + } + if (targetNumericType != null && castNumericArgs) { + return castMixedNumericTypes(sf, targetNumericType); + } + } + return resultF; } private static Expression processBinaryOperator(BinaryOperator o, Configuration configuration) { @@ -2349,31 +2359,66 @@ private static Expression processVectorFunction( EsqlFunctionRegistry registry, Configuration configuration ) { - // Perform implicit casting for dense_vector from numeric and keyword values List args = vectorFunction.arguments(); List targetDataTypes = registry.getDataTypeForStringLiteralConversion(vectorFunction.getClass()); - List newArgs = new ArrayList<>(); + List newArgs = new ArrayList<>(args.size()); + boolean changed = false; for (int i = 0; i < args.size(); i++) { Expression arg = args.get(i); - if (targetDataTypes.get(i) == DENSE_VECTOR && arg.resolved()) { - var dataType = arg.dataType(); - if (dataType == KEYWORD) { - if (arg.foldable()) { - Expression exp = castStringLiteral(arg, DENSE_VECTOR, configuration); - if (exp != arg) { - newArgs.add(exp); - continue; - } - } - } else if (dataType.isNumeric()) { - newArgs.add(new ToDenseVector(vectorFunction.source(), arg)); - continue; + if (targetDataTypes.get(i) == DENSE_VECTOR) { + Expression cast = castArgToDenseVector(arg, vectorFunction.source(), configuration); + if (cast != arg) { + changed = true; } + newArgs.add(cast); + } else { + newArgs.add(arg); } - newArgs.add(arg); } + return changed ? vectorFunction.replaceChildren(newArgs) : vectorFunction; + } - return vectorFunction.replaceChildren(newArgs); + /** + * Cast selected children to dense_vector using the indices provided by {@link VectorCastable}. + */ + private static Expression castDenseVectorArgs(Expression f, Set castIndices, Configuration cfg) { + if (castIndices.isEmpty()) { + return f; + } + List children = f.children(); + List newChildren = new ArrayList<>(children.size()); + boolean changed = false; + for (int i = 0; i < children.size(); i++) { + Expression child = children.get(i); + if (castIndices.contains(i)) { + Expression cast = castArgToDenseVector(child, f.source(), cfg); + if (cast != child) { + changed = true; + } + newChildren.add(cast); + } else { + newChildren.add(child); + } + } + return changed ? f.replaceChildren(newChildren) : f; + } + + /** + * Cast a single argument to dense_vector if it is a foldable keyword (hex string) or a numeric value. + */ + private static Expression castArgToDenseVector(Expression arg, Source source, Configuration configuration) { + if (arg.resolved()) { + var dataType = arg.dataType(); + if (dataType == KEYWORD && arg.foldable()) { + Expression exp = castStringLiteral(arg, DENSE_VECTOR, configuration); + if (exp != arg) { + return exp; + } + } else if (dataType.isNumeric()) { + return new ToDenseVector(source, arg); + } + } + return arg; } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/conditional/Case.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/conditional/Case.java index 67377ce093360..98038a4f527ee 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/conditional/Case.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/conditional/Case.java @@ -34,12 +34,15 @@ import org.elasticsearch.xpack.esql.expression.function.FunctionInfo; import org.elasticsearch.xpack.esql.expression.function.Param; import org.elasticsearch.xpack.esql.expression.function.scalar.EsqlScalarFunction; +import org.elasticsearch.xpack.esql.expression.function.vector.VectorCastable; import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; import org.elasticsearch.xpack.esql.planner.PlannerUtils; import java.io.IOException; import java.util.ArrayList; +import java.util.HashSet; import java.util.List; +import java.util.Set; import java.util.function.Function; import java.util.stream.IntStream; import java.util.stream.Stream; @@ -47,7 +50,7 @@ import static org.elasticsearch.common.logging.LoggerMessageFormat.format; import static org.elasticsearch.xpack.esql.core.type.DataType.NULL; -public final class Case extends EsqlScalarFunction { +public final class Case extends EsqlScalarFunction implements VectorCastable { public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "Case", Case::new); record Condition(Expression condition, Expression value) { @@ -176,6 +179,31 @@ private boolean elseValueIsExplicit() { return children().size() % 2 == 1; } + @Override + public Set denseVectorCastArgIndices() { + boolean hasDenseVector = false; + for (Condition c : conditions) { + if (c.value().resolved() && c.value().dataType() == DataType.DENSE_VECTOR) { + hasDenseVector = true; + break; + } + } + if (hasDenseVector == false && elseValue.resolved() && elseValue.dataType() == DataType.DENSE_VECTOR) { + hasDenseVector = true; + } + if (hasDenseVector == false) { + return Set.of(); + } + Set indices = new HashSet<>(); + for (int i = 1; i < children().size(); i += 2) { + indices.add(i); + } + if (children().size() % 2 == 1) { + indices.add(children().size() - 1); + } + return indices; + } + @Override public DataType dataType() { if (dataType == null) { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/nulls/Coalesce.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/nulls/Coalesce.java index 052d57420a523..754c876eaa728 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/nulls/Coalesce.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/nulls/Coalesce.java @@ -25,10 +25,13 @@ import org.elasticsearch.xpack.esql.expression.function.OptionalArgument; import org.elasticsearch.xpack.esql.expression.function.Param; import org.elasticsearch.xpack.esql.expression.function.scalar.EsqlScalarFunction; +import org.elasticsearch.xpack.esql.expression.function.vector.VectorCastable; import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; import java.io.IOException; +import java.util.HashSet; import java.util.List; +import java.util.Set; import java.util.stream.Stream; import static org.elasticsearch.xpack.esql.core.type.DataType.NULL; @@ -37,7 +40,7 @@ * Function returning the first non-null value. {@code COALESCE} runs as though * it were lazily evaluating each position in each incoming {@link Block}. */ -public class Coalesce extends EsqlScalarFunction implements OptionalArgument { +public class Coalesce extends EsqlScalarFunction implements OptionalArgument, VectorCastable { public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "Coalesce", Coalesce::new); private DataType dataType; @@ -143,6 +146,20 @@ public String getWriteableName() { return ENTRY.name; } + @Override + public Set denseVectorCastArgIndices() { + for (Expression child : children()) { + if (child.resolved() && child.dataType() == DataType.DENSE_VECTOR) { + Set indices = new HashSet<>(); + for (int i = 0; i < children().size(); i++) { + indices.add(i); + } + return indices; + } + } + return Set.of(); + } + @Override public DataType dataType() { if (dataType == null) { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorCastable.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorCastable.java new file mode 100644 index 0000000000000..15b9107a5dd86 --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorCastable.java @@ -0,0 +1,27 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.vector; + +import java.util.Set; + +/** + * Implemented by functions whose arguments may need implicit casting to dense_vector + * during analysis. Each implementation declares which child positions are eligible for + * the cast; the Analyzer uses those indices to apply {@code castArgToDenseVector} + * selectively, leaving other children (e.g. boolean conditions) untouched. + */ +public interface VectorCastable { + /** + * Returns the indices into {@code children()} that should be implicitly + * cast to dense_vector when their current type is keyword (foldable hex + * string) or numeric. Called during analysis before type resolution. + * Implementations may examine their current children to make dynamic + * decisions (e.g. only cast when a sibling is already dense_vector). + */ + Set denseVectorCastArgIndices(); +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java index 8c5be9189ffef..aaf0a2f60857a 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java @@ -65,12 +65,14 @@ import org.elasticsearch.xpack.esql.expression.function.inference.CompletionFunction; import org.elasticsearch.xpack.esql.expression.function.inference.TextEmbedding; import org.elasticsearch.xpack.esql.expression.function.scalar.approximate.Random; +import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Case; import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToDateNanos; import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToDatetime; import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToDenseVector; import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToInteger; import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToLong; import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToString; +import org.elasticsearch.xpack.esql.expression.function.scalar.nulls.Coalesce; import org.elasticsearch.xpack.esql.expression.function.scalar.string.Concat; import org.elasticsearch.xpack.esql.expression.function.scalar.string.Substring; import org.elasticsearch.xpack.esql.expression.function.scalar.string.regex.RLike; @@ -2683,6 +2685,105 @@ private void checkVectorFunctionHexImplicitCastingError(String clause) { ); } + public void testCoalesceDenseVectorImplicitCastingFromNumeric() { + var plan = denseVector().query(""" + from test | eval v = coalesce(float_vector, [0, 0, 0]) + """); + + var limit = as(plan, Limit.class); + var eval = as(limit.child(), Eval.class); + var alias = as(eval.fields().get(0), Alias.class); + assertEquals("v", alias.name()); + var coalesce = as(alias.child(), Coalesce.class); + var first = as(coalesce.children().get(0), FieldAttribute.class); + assertThat(first.name(), is("float_vector")); + var second = as(coalesce.children().get(1), ToDenseVector.class); + var literal = as(second.field(), Literal.class); + assertThat(literal.value(), equalTo(List.of(0, 0, 0))); + } + + public void testCoalesceDenseVectorImplicitCastingFromHexString() { + var plan = denseVector().query(""" + from test | eval v = coalesce(float_vector, "3f8000003f8000003f800000") + """); + + var limit = as(plan, Limit.class); + var eval = as(limit.child(), Eval.class); + var alias = as(eval.fields().get(0), Alias.class); + assertEquals("v", alias.name()); + var coalesce = as(alias.child(), Coalesce.class); + var first = as(coalesce.children().get(0), FieldAttribute.class); + assertThat(first.name(), is("float_vector")); + assertThat(coalesce.children().get(1), instanceOf(Literal.class)); + assertThat(coalesce.children().get(1).dataType(), is(DENSE_VECTOR)); + } + + public void testCoalesceDenseVectorImplicitCastingNullFirst() { + var plan = denseVector().query(""" + from test | eval v = coalesce(null, float_vector, [0, 0, 0]) + """); + + var limit = as(plan, Limit.class); + var eval = as(limit.child(), Eval.class); + var alias = as(eval.fields().get(0), Alias.class); + assertEquals("v", alias.name()); + var coalesce = as(alias.child(), Coalesce.class); + assertThat(coalesce.children().get(1), instanceOf(FieldAttribute.class)); + var third = as(coalesce.children().get(2), ToDenseVector.class); + var literal = as(third.field(), Literal.class); + assertThat(literal.value(), equalTo(List.of(0, 0, 0))); + } + + public void testCoalesceDenseVectorImplicitCastingDenseVectorNotFirst() { + var plan = denseVector().query(""" + from test | eval v = coalesce([0, 0, 0], float_vector) + """); + + var limit = as(plan, Limit.class); + var eval = as(limit.child(), Eval.class); + var alias = as(eval.fields().get(0), Alias.class); + assertEquals("v", alias.name()); + var coalesce = as(alias.child(), Coalesce.class); + var first = as(coalesce.children().get(0), ToDenseVector.class); + var literal = as(first.field(), Literal.class); + assertThat(literal.value(), equalTo(List.of(0, 0, 0))); + var second = as(coalesce.children().get(1), FieldAttribute.class); + assertThat(second.name(), is("float_vector")); + } + + public void testCaseDenseVectorImplicitCastingFromNumeric() { + var plan = denseVector().query(""" + from test | eval v = case(id > 3, float_vector, [0, 0, 0]) + """); + + var limit = as(plan, Limit.class); + var eval = as(limit.child(), Eval.class); + var alias = as(eval.fields().get(0), Alias.class); + assertEquals("v", alias.name()); + var caseExpr = as(alias.child(), Case.class); + var thenValue = as(caseExpr.children().get(1), FieldAttribute.class); + assertThat(thenValue.name(), is("float_vector")); + var elseValue = as(caseExpr.children().get(2), ToDenseVector.class); + var literal = as(elseValue.field(), Literal.class); + assertThat(literal.value(), equalTo(List.of(0, 0, 0))); + } + + public void testCaseDenseVectorImplicitCastingFromHexString() { + var plan = denseVector().query(""" + from test | eval v = case(id > 3, float_vector, "3f8000003f8000003f800000") + """); + + var limit = as(plan, Limit.class); + var eval = as(limit.child(), Eval.class); + var alias = as(eval.fields().get(0), Alias.class); + assertEquals("v", alias.name()); + var caseExpr = as(alias.child(), Case.class); + var thenValue = as(caseExpr.children().get(1), FieldAttribute.class); + assertThat(thenValue.name(), is("float_vector")); + assertThat(caseExpr.children().get(2), instanceOf(Literal.class)); + assertThat(caseExpr.children().get(2).dataType(), is(DENSE_VECTOR)); + } + public void testMagnitudePlanWithDenseVectorImplicitCasting() { assumeTrue("v_magnitude not available", EsqlCapabilities.Cap.MAGNITUDE_SCALAR_VECTOR_FUNCTION.isEnabled());