Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/changelog/143695.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
area: ES|QL
issues:
- 139928
pr: 143695
summary: "Feat: add implicit `dense_vector` casting to coalesce"
type: enhancement
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,74 @@ id:l | v:dense_vector
4 | [0.5, 0.5, 0.5]
;

coalesceDenseVectorImplicitCastFromInts
required_capability: coalesce_dense_vector_implicit_casting

FROM dense_vector
| EVAL v = COALESCE(float_vector, [0, 0, 0])
| KEEP id, v
| SORT id
;

id:l | v:dense_vector
0 | [1.0, 2.0, 3.0]
1 | [4.0, 5.0, 6.0]
2 | [9.0, 8.0, 7.0]
3 | [0.0625, 0.25, 0.75]
4 | [0.0, 0.0, 0.0]
;

coalesceDenseVectorImplicitCastFromDoubles
required_capability: coalesce_dense_vector_implicit_casting

FROM dense_vector
| EVAL v = COALESCE(float_vector, [0.5, 0.5, 0.5])
| KEEP id, v
| SORT id
;

id:l | v:dense_vector
0 | [1.0, 2.0, 3.0]
1 | [4.0, 5.0, 6.0]
2 | [9.0, 8.0, 7.0]
3 | [0.0625, 0.25, 0.75]
4 | [0.5, 0.5, 0.5]
;

coalesceDenseVectorImplicitCastDenseVectorNotFirst
required_capability: coalesce_dense_vector_implicit_casting

FROM dense_vector
| EVAL v = COALESCE([0.5, 0.5, 0.5], float_vector)
| KEEP id, v
| SORT id
;

id:l | v:dense_vector
0 | [0.5, 0.5, 0.5]
1 | [0.5, 0.5, 0.5]
2 | [0.5, 0.5, 0.5]
3 | [0.5, 0.5, 0.5]
4 | [0.5, 0.5, 0.5]
;

caseDenseVectorImplicitCastFromInts
required_capability: coalesce_dense_vector_implicit_casting

FROM dense_vector
| EVAL v = CASE(id > 3, [0, 0, 0], float_vector)
| KEEP id, v
| SORT id
;

id:l | v:dense_vector
0 | [1.0, 2.0, 3.0]
1 | [4.0, 5.0, 6.0]
2 | [9.0, 8.0, 7.0]
3 | [0.0625, 0.25, 0.75]
4 | [0.0, 0.0, 0.0]
;

denseVectorEqualsSameVector
required_capability: dense_vector_equality

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2279,6 +2279,11 @@ public enum Cap {
*/
TS_INFO_COMMAND,

/**
* Implicit casting of numeric and keyword arguments to dense_vector in COALESCE and CASE.
*/
COALESCE_DENSE_VECTOR_IMPLICIT_CASTING,

/**
* FORK with no implicit LIMIT
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToUnsignedLong;
import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvCount;
import org.elasticsearch.xpack.esql.expression.function.scalar.nulls.Coalesce;
import org.elasticsearch.xpack.esql.expression.function.vector.VectorCastable;
import org.elasticsearch.xpack.esql.expression.function.vector.VectorFunction;
import org.elasticsearch.xpack.esql.expression.predicate.Predicates;
import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.DateTimeArithmeticOperation;
Expand Down Expand Up @@ -2162,10 +2163,10 @@ private static Expression processScalarOrGroupingFunction(
if (arg.resolved()) {
var dataType = arg.dataType();
if (dataType == KEYWORD) {
if (i < targetDataTypes.size()) {
targetDataType = targetDataTypes.get(i);
} // else the last type applies to all elements in a possible list (variadic)
if (arg.foldable() && ((arg instanceof EsqlScalarFunction) == false)) {
if (i < targetDataTypes.size()) {
targetDataType = targetDataTypes.get(i);
} // else the last type applies to all elements in a possible list (variadic)
if (targetDataType != NULL && targetDataType != UNSUPPORTED) {
Expression e = castStringLiteral(arg, targetDataType, configuration);
if (e != arg) {
Expand All @@ -2186,9 +2187,18 @@ private static Expression processScalarOrGroupingFunction(
newChildren.add(args.get(i));
}
Expression resultF = childrenChanged ? f.replaceChildren(newChildren) : f;
return targetNumericType != null && castNumericArgs
? castMixedNumericTypes((EsqlScalarFunction) resultF, targetNumericType)
: resultF;
if (resultF instanceof EsqlScalarFunction sf) {
if (sf instanceof VectorCastable vc) {
Expression dvResult = castDenseVectorArgs(sf, vc.denseVectorCastArgIndices(), configuration);
if (dvResult != sf) {
return dvResult;
}
}
if (targetNumericType != null && castNumericArgs) {
return castMixedNumericTypes(sf, targetNumericType);
}
}
return resultF;
}

private static Expression processBinaryOperator(BinaryOperator<?, ?, ?, ?> o, Configuration configuration) {
Expand Down Expand Up @@ -2349,31 +2359,66 @@ private static Expression processVectorFunction(
EsqlFunctionRegistry registry,
Configuration configuration
) {
// Perform implicit casting for dense_vector from numeric and keyword values
List<Expression> args = vectorFunction.arguments();
List<DataType> targetDataTypes = registry.getDataTypeForStringLiteralConversion(vectorFunction.getClass());
List<Expression> newArgs = new ArrayList<>();
List<Expression> newArgs = new ArrayList<>(args.size());
boolean changed = false;
for (int i = 0; i < args.size(); i++) {
Expression arg = args.get(i);
if (targetDataTypes.get(i) == DENSE_VECTOR && arg.resolved()) {
var dataType = arg.dataType();
if (dataType == KEYWORD) {
if (arg.foldable()) {
Expression exp = castStringLiteral(arg, DENSE_VECTOR, configuration);
if (exp != arg) {
newArgs.add(exp);
continue;
}
}
} else if (dataType.isNumeric()) {
newArgs.add(new ToDenseVector(vectorFunction.source(), arg));
continue;
if (targetDataTypes.get(i) == DENSE_VECTOR) {
Expression cast = castArgToDenseVector(arg, vectorFunction.source(), configuration);
if (cast != arg) {
changed = true;
}
newArgs.add(cast);
} else {
newArgs.add(arg);
}
newArgs.add(arg);
}
return changed ? vectorFunction.replaceChildren(newArgs) : vectorFunction;
}

return vectorFunction.replaceChildren(newArgs);
/**
 * Rebuilds {@code f} with the children at the positions listed in {@code castIndices}
 * passed through {@code castArgToDenseVector}; children at any other position are
 * copied over as they are.
 *
 * @param f           the expression whose children may be cast
 * @param castIndices positions into {@code f.children()} that are eligible for the cast
 * @param cfg         configuration forwarded to {@code castArgToDenseVector}
 * @return {@code f} itself when {@code castIndices} is empty or no child was replaced,
 *         otherwise the result of {@code f.replaceChildren(...)} with the new children
 */
private static Expression castDenseVectorArgs(Expression f, Set<Integer> castIndices, Configuration cfg) {
    if (castIndices.isEmpty()) {
        return f;
    }
    final List<Expression> original = f.children();
    final List<Expression> rewritten = new ArrayList<>(original.size());
    boolean anyReplaced = false;
    int position = 0;
    for (Expression current : original) {
        Expression replacement = current;
        if (castIndices.contains(position)) {
            replacement = castArgToDenseVector(current, f.source(), cfg);
        }
        // Identity comparison: only a different instance counts as a change.
        anyReplaced |= replacement != current;
        rewritten.add(replacement);
        position++;
    }
    if (anyReplaced == false) {
        return f;
    }
    return f.replaceChildren(rewritten);
}

/**
 * Converts one argument to dense_vector when it qualifies for implicit casting.
 * <ul>
 *   <li>An unresolved argument is returned untouched.</li>
 *   <li>A foldable keyword argument is handed to {@code castStringLiteral} with
 *       {@code DENSE_VECTOR} as the target type.</li>
 *   <li>A numeric argument is wrapped in a {@link ToDenseVector}.</li>
 *   <li>Anything else is returned untouched.</li>
 * </ul>
 *
 * @param arg           the argument to inspect
 * @param source        the source attached to a newly created {@link ToDenseVector}
 * @param configuration configuration forwarded to {@code castStringLiteral}
 * @return the cast expression, or {@code arg} itself when no cast applies
 */
private static Expression castArgToDenseVector(Expression arg, Source source, Configuration configuration) {
    if (arg.resolved() == false) {
        return arg;
    }
    var argType = arg.dataType();
    if (argType == KEYWORD && arg.foldable()) {
        // Returning the result directly matches an explicit "changed?" identity check:
        // if the call yields arg itself, arg is what gets returned either way.
        return castStringLiteral(arg, DENSE_VECTOR, configuration);
    }
    return argType.isNumeric() ? new ToDenseVector(source, arg) : arg;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,20 +34,23 @@
import org.elasticsearch.xpack.esql.expression.function.FunctionInfo;
import org.elasticsearch.xpack.esql.expression.function.Param;
import org.elasticsearch.xpack.esql.expression.function.scalar.EsqlScalarFunction;
import org.elasticsearch.xpack.esql.expression.function.vector.VectorCastable;
import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
import org.elasticsearch.xpack.esql.planner.PlannerUtils;

import java.io.IOException;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.IntStream;
import java.util.stream.Stream;

import static org.elasticsearch.common.logging.LoggerMessageFormat.format;
import static org.elasticsearch.xpack.esql.core.type.DataType.NULL;

public final class Case extends EsqlScalarFunction {
public final class Case extends EsqlScalarFunction implements VectorCastable {
public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "Case", Case::new);

record Condition(Expression condition, Expression value) {
Expand Down Expand Up @@ -176,6 +179,31 @@ private boolean elseValueIsExplicit() {
return children().size() % 2 == 1;
}

/**
 * Reports which children may be implicitly cast to dense_vector: every branch value
 * (the odd child positions) plus the explicit else value when one is present.
 * Condition children are never included. Nothing is reported unless at least one
 * branch value, or the else value, is already resolved to dense_vector.
 */
@Override
public Set<Integer> denseVectorCastArgIndices() {
    if (anyResultIsDenseVector() == false) {
        return Set.of();
    }
    final int childCount = children().size();
    Set<Integer> valuePositions = new HashSet<>();
    // Children alternate condition, value, condition, value, ... so values sit at odd positions.
    int position = 1;
    while (position < childCount) {
        valuePositions.add(position);
        position += 2;
    }
    // An odd child count means the last child is an explicit else value.
    if (elseValueIsExplicit()) {
        valuePositions.add(childCount - 1);
    }
    return valuePositions;
}

/**
 * Returns {@code true} when any branch value, or failing that the else value,
 * is resolved and typed as dense_vector.
 */
private boolean anyResultIsDenseVector() {
    for (Condition branch : conditions) {
        Expression branchValue = branch.value();
        if (branchValue.resolved() && branchValue.dataType() == DataType.DENSE_VECTOR) {
            return true;
        }
    }
    return elseValue.resolved() && elseValue.dataType() == DataType.DENSE_VECTOR;
}

@Override
public DataType dataType() {
if (dataType == null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,13 @@
import org.elasticsearch.xpack.esql.expression.function.OptionalArgument;
import org.elasticsearch.xpack.esql.expression.function.Param;
import org.elasticsearch.xpack.esql.expression.function.scalar.EsqlScalarFunction;
import org.elasticsearch.xpack.esql.expression.function.vector.VectorCastable;
import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;

import java.io.IOException;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.stream.Stream;

import static org.elasticsearch.xpack.esql.core.type.DataType.NULL;
Expand All @@ -37,7 +40,7 @@
* Function returning the first non-null value. {@code COALESCE} runs as though
* it were lazily evaluating each position in each incoming {@link Block}.
*/
public class Coalesce extends EsqlScalarFunction implements OptionalArgument {
public class Coalesce extends EsqlScalarFunction implements OptionalArgument, VectorCastable {
public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "Coalesce", Coalesce::new);

private DataType dataType;
Expand Down Expand Up @@ -143,6 +146,20 @@ public String getWriteableName() {
return ENTRY.name;
}

/**
 * Reports which children may be implicitly cast to dense_vector. As soon as one
 * child is resolved to dense_vector, every child position is reported; otherwise
 * the result is empty and no casting is requested.
 */
@Override
public Set<Integer> denseVectorCastArgIndices() {
    final List<Expression> args = children();
    boolean sawDenseVector = false;
    for (Expression arg : args) {
        if (arg.resolved() && arg.dataType() == DataType.DENSE_VECTOR) {
            sawDenseVector = true;
            break;
        }
    }
    if (sawDenseVector == false) {
        return Set.of();
    }
    Set<Integer> everyPosition = new HashSet<>();
    int position = 0;
    while (position < args.size()) {
        everyPosition.add(position);
        position++;
    }
    return everyPosition;
}

@Override
public DataType dataType() {
if (dataType == null) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.esql.expression.function.vector;

import java.util.Set;

/**
* Implemented by functions whose arguments may need implicit casting to dense_vector
* during analysis. Each implementation declares which child positions are eligible for
* the cast; the Analyzer uses those indices to apply {@code castArgToDenseVector}
* selectively, leaving other children (e.g. boolean conditions) untouched.
* <p>
* Being listed as eligible does not force a conversion: the Analyzer only converts an
* eligible child that is resolved and is either a foldable keyword or a numeric value,
* and keeps every other child as it is.
*/
public interface VectorCastable {
/**
* Returns the indices into {@code children()} that should be implicitly
* cast to dense_vector when their current type is keyword (foldable hex
* string) or numeric. Called during analysis before type resolution.
* Implementations may examine their current children to make dynamic
* decisions (e.g. only cast when a sibling is already dense_vector).
*
* @return the child positions eligible for the cast; an empty set tells the
*         Analyzer to leave the function unchanged
*/
Set<Integer> denseVectorCastArgIndices();
}
Loading
Loading