Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
211c2c7
Change the use of SqlTypeFamily.STRING to SqlTypeFamily.CHARACTER as …
yuancu Jul 22, 2025
ae72725
Merge remote-tracking branch 'origin/main' into rel-coerc
yuancu Jul 22, 2025
4667384
Implement basic argument type coercion at RelNode level
yuancu Jul 22, 2025
93a56ab
Conform type checkers with their definition in documentation
yuancu Jul 22, 2025
25d069f
Merge remote-tracking branch 'origin/main' into rel-coerc
yuancu Jul 22, 2025
6c3faa4
Implement type widening for comparator functions
yuancu Jul 23, 2025
42fd079
Update error messages of datetime functions with invalid args
yuancu Jul 23, 2025
ec7ce78
Simplify datetime-string compare logic with implicit coercion
yuancu Jul 23, 2025
86e3741
Refactor resolve with coercion
yuancu Jul 23, 2025
b126b87
Move down argument cast for reduce function
yuancu Jul 23, 2025
bd9f3bb
Merge comparators and their IP variants so that coercion works for IP…
yuancu Jul 24, 2025
c539056
Refactor ip comparator to comparator
yuancu Jul 24, 2025
260fd19
Revert "Refactor ip comparator to comparator"
yuancu Jul 24, 2025
4ea73dc
Revert "Merge comparators and their IP variants so that coercion work…
yuancu Jul 24, 2025
3d32da0
Rule out ip from built-in comparator via its type checker
yuancu Jul 24, 2025
00926b5
Merge remote-tracking branch 'origin/main' into rel-coerc
yuancu Jul 28, 2025
daaee69
Restrict CompareIP's parameter type
yuancu Jul 24, 2025
c708ae1
Revert to previous implementation of CompareIpFunction to temporarily…
yuancu Jul 29, 2025
5e8858a
Test argument coercion explain
yuancu Jul 29, 2025
abc5309
Fix error msg in CalcitePPLFunctionTypeTest
yuancu Jul 29, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@
import static org.opensearch.sql.ast.expression.SpanUnit.NONE;
import static org.opensearch.sql.ast.expression.SpanUnit.UNKNOWN;
import static org.opensearch.sql.calcite.utils.OpenSearchTypeFactory.TYPE_FACTORY;
import static org.opensearch.sql.utils.DateTimeUtils.findCastType;
import static org.opensearch.sql.utils.DateTimeUtils.transferCompareForDateRelated;

import java.math.BigDecimal;
import java.util.ArrayList;
Expand All @@ -30,7 +28,6 @@
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexLambda;
import org.apache.calcite.rex.RexLambdaRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlIntervalQualifier;
Expand Down Expand Up @@ -215,11 +212,8 @@ public RexNode visitIn(In node, CalcitePlanContext context) {

@Override
public RexNode visitCompare(Compare node, CalcitePlanContext context) {
RexNode leftCandidate = analyze(node.getLeft(), context);
RexNode rightCandidate = analyze(node.getRight(), context);
SqlTypeName castTarget = findCastType(leftCandidate, rightCandidate);
final RexNode left = transferCompareForDateRelated(leftCandidate, context, castTarget);
final RexNode right = transferCompareForDateRelated(rightCandidate, context, castTarget);
RexNode left = analyze(node.getLeft(), context);
RexNode right = analyze(node.getRight(), context);
return PPLFuncImpTable.INSTANCE.resolve(context.rexBuilder, node.getOperator(), left, right);
}

Expand Down Expand Up @@ -468,19 +462,6 @@ private List<RelDataType> modifyLambdaTypeByFunction(
}
}

private List<RexNode> castArgument(
List<RexNode> originalArguments, String functionName, ExtendedRexBuilder rexBuilder) {
switch (functionName.toUpperCase(Locale.ROOT)) {
case "REDUCE":
RexLambda call = (RexLambda) originalArguments.get(2);
originalArguments.set(
1, rexBuilder.makeCast(call.getType(), originalArguments.get(1), true, true));
return originalArguments;
default:
return originalArguments;
}
}

@Override
public RexNode visitFunction(Function node, CalcitePlanContext context) {
List<UnresolvedExpression> args = node.getFuncArgs();
Expand All @@ -507,8 +488,6 @@ public RexNode visitFunction(Function node, CalcitePlanContext context) {
}
}

arguments = castArgument(arguments, node.getFuncName(), context.rexBuilder);

RexNode resolvedNode =
PPLFuncImpTable.INSTANCE.resolve(
context.rexBuilder, node.getFuncName(), arguments.toArray(new RexNode[0]));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,40 +27,121 @@ private PPLOperandTypes() {}
UDFOperandMetadata.wrap(
(CompositeOperandTypeChecker) OperandTypes.INTEGER.or(OperandTypes.family()));
public static final UDFOperandMetadata STRING =
UDFOperandMetadata.wrap((FamilyOperandTypeChecker) OperandTypes.STRING);
UDFOperandMetadata.wrap((FamilyOperandTypeChecker) OperandTypes.CHARACTER);
public static final UDFOperandMetadata INTEGER =
UDFOperandMetadata.wrap((FamilyOperandTypeChecker) OperandTypes.INTEGER);
public static final UDFOperandMetadata NUMERIC =
UDFOperandMetadata.wrap((FamilyOperandTypeChecker) OperandTypes.NUMERIC);

public static final UDFOperandMetadata NUMERIC_OPTIONAL_STRING =
UDFOperandMetadata.wrap(
(CompositeOperandTypeChecker)
OperandTypes.NUMERIC.or(
OperandTypes.family(SqlTypeFamily.NUMERIC, SqlTypeFamily.CHARACTER)));

public static final UDFOperandMetadata INTEGER_INTEGER =
UDFOperandMetadata.wrap((FamilyOperandTypeChecker) OperandTypes.INTEGER_INTEGER);
public static final UDFOperandMetadata STRING_STRING =
UDFOperandMetadata.wrap(OperandTypes.STRING_STRING);
UDFOperandMetadata.wrap((FamilyOperandTypeChecker) OperandTypes.CHARACTER_CHARACTER);
public static final UDFOperandMetadata NUMERIC_NUMERIC =
UDFOperandMetadata.wrap((FamilyOperandTypeChecker) OperandTypes.NUMERIC_NUMERIC);
public static final UDFOperandMetadata STRING_INTEGER =
UDFOperandMetadata.wrap(OperandTypes.family(SqlTypeFamily.CHARACTER, SqlTypeFamily.INTEGER));

public static final UDFOperandMetadata NUMERIC_NUMERIC_NUMERIC =
UDFOperandMetadata.wrap(
OperandTypes.family(SqlTypeFamily.NUMERIC, SqlTypeFamily.NUMERIC, SqlTypeFamily.NUMERIC));
public static final UDFOperandMetadata STRING_OR_INTEGER_INTEGER_INTEGER =
UDFOperandMetadata.wrap(
(CompositeOperandTypeChecker)
OperandTypes.family(
SqlTypeFamily.CHARACTER, SqlTypeFamily.INTEGER, SqlTypeFamily.INTEGER)
.or(
OperandTypes.family(
SqlTypeFamily.INTEGER, SqlTypeFamily.INTEGER, SqlTypeFamily.INTEGER)));

public static final UDFOperandMetadata OPTIONAL_DATE_OR_TIMESTAMP_OR_NUMERIC =
UDFOperandMetadata.wrap(
(CompositeOperandTypeChecker)
OperandTypes.DATETIME.or(OperandTypes.NUMERIC).or(OperandTypes.family()));

public static final UDFOperandMetadata DATETIME_OR_STRING =
UDFOperandMetadata.wrap(
(CompositeOperandTypeChecker) OperandTypes.DATETIME.or(OperandTypes.STRING));
(CompositeOperandTypeChecker) OperandTypes.DATETIME.or(OperandTypes.CHARACTER));
public static final UDFOperandMetadata TIME_OR_TIMESTAMP_OR_STRING =
UDFOperandMetadata.wrap(
(CompositeOperandTypeChecker)
OperandTypes.CHARACTER.or(OperandTypes.TIME).or(OperandTypes.TIMESTAMP));
public static final UDFOperandMetadata DATE_OR_TIMESTAMP_OR_STRING =
UDFOperandMetadata.wrap(
(CompositeOperandTypeChecker) OperandTypes.DATE_OR_TIMESTAMP.or(OperandTypes.CHARACTER));
public static final UDFOperandMetadata DATETIME_OR_STRING_OR_INTEGER =
UDFOperandMetadata.wrap(
(CompositeOperandTypeChecker)
OperandTypes.DATETIME.or(OperandTypes.CHARACTER).or(OperandTypes.INTEGER));

public static final UDFOperandMetadata DATETIME_OPTIONAL_INTEGER =
UDFOperandMetadata.wrap(
(CompositeOperandTypeChecker)
OperandTypes.DATETIME.or(
OperandTypes.family(SqlTypeFamily.DATETIME, SqlTypeFamily.INTEGER)));

public static final UDFOperandMetadata DATETIME_DATETIME =
UDFOperandMetadata.wrap(OperandTypes.family(SqlTypeFamily.DATETIME, SqlTypeFamily.DATETIME));
public static final UDFOperandMetadata DATETIME_OR_STRING_STRING =
UDFOperandMetadata.wrap(
(CompositeOperandTypeChecker)
OperandTypes.family(SqlTypeFamily.DATETIME, SqlTypeFamily.CHARACTER)
.or(OperandTypes.CHARACTER_CHARACTER));
public static final UDFOperandMetadata DATETIME_OR_STRING_DATETIME_OR_STRING =
UDFOperandMetadata.wrap(
(CompositeOperandTypeChecker)
OperandTypes.STRING_STRING
OperandTypes.CHARACTER_CHARACTER
.or(OperandTypes.family(SqlTypeFamily.DATETIME, SqlTypeFamily.DATETIME))
.or(OperandTypes.family(SqlTypeFamily.DATETIME, SqlTypeFamily.STRING))
.or(OperandTypes.family(SqlTypeFamily.STRING, SqlTypeFamily.DATETIME)));
public static final UDFOperandMetadata TIME_OR_TIMESTAMP_OR_STRING =
.or(OperandTypes.family(SqlTypeFamily.DATETIME, SqlTypeFamily.CHARACTER))
.or(OperandTypes.family(SqlTypeFamily.CHARACTER, SqlTypeFamily.DATETIME)));
public static final UDFOperandMetadata STRING_TIMESTAMP =
UDFOperandMetadata.wrap(
OperandTypes.family(SqlTypeFamily.CHARACTER, SqlTypeFamily.TIMESTAMP));
public static final UDFOperandMetadata STRING_DATETIME =
UDFOperandMetadata.wrap(OperandTypes.family(SqlTypeFamily.CHARACTER, SqlTypeFamily.DATETIME));
public static final UDFOperandMetadata DATETIME_INTERVAL =
UDFOperandMetadata.wrap((FamilyOperandTypeChecker) OperandTypes.DATETIME_INTERVAL);
public static final UDFOperandMetadata TIME_TIME =
UDFOperandMetadata.wrap(OperandTypes.family(SqlTypeFamily.TIME, SqlTypeFamily.TIME));

public static final UDFOperandMetadata TIMESTAMP_OR_STRING_STRING_STRING =
UDFOperandMetadata.wrap(
(CompositeOperandTypeChecker)
OperandTypes.STRING.or(OperandTypes.TIME).or(OperandTypes.TIMESTAMP));
public static final UDFOperandMetadata DATE_OR_TIMESTAMP_OR_STRING =
OperandTypes.family(
SqlTypeFamily.TIMESTAMP, SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER)
.or(
OperandTypes.family(
SqlTypeFamily.CHARACTER,
SqlTypeFamily.CHARACTER,
SqlTypeFamily.CHARACTER)));
public static final UDFOperandMetadata STRING_INTEGER_DATETIME_OR_STRING =
UDFOperandMetadata.wrap(
(CompositeOperandTypeChecker) OperandTypes.DATE_OR_TIMESTAMP.or(OperandTypes.STRING));
public static final UDFOperandMetadata STRING_TIMESTAMP =
UDFOperandMetadata.wrap(OperandTypes.family(SqlTypeFamily.STRING, SqlTypeFamily.TIMESTAMP));
(CompositeOperandTypeChecker)
OperandTypes.family(
SqlTypeFamily.CHARACTER, SqlTypeFamily.INTEGER, SqlTypeFamily.CHARACTER)
.or(
OperandTypes.family(
SqlTypeFamily.CHARACTER, SqlTypeFamily.INTEGER, SqlTypeFamily.DATETIME)));
public static final UDFOperandMetadata INTERVAL_DATETIME_DATETIME =
UDFOperandMetadata.wrap(
(CompositeOperandTypeChecker)
OperandTypes.family(
SqlTypeFamily.CHARACTER, SqlTypeFamily.DATETIME, SqlTypeFamily.DATETIME)
.or(
OperandTypes.family(
SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER, SqlTypeFamily.DATETIME))
.or(
OperandTypes.family(
SqlTypeFamily.CHARACTER, SqlTypeFamily.DATETIME, SqlTypeFamily.CHARACTER))
.or(
OperandTypes.family(
SqlTypeFamily.CHARACTER,
SqlTypeFamily.CHARACTER,
SqlTypeFamily.CHARACTER)));
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
Expand Down Expand Up @@ -380,4 +381,13 @@ public static Optional<BuiltinFunctionName> ofWindowFunction(String functionName
return Optional.ofNullable(
WINDOW_FUNC_MAPPING.getOrDefault(functionName.toLowerCase(Locale.ROOT), null));
}

/**
 * The six comparison operators: {@code EQUAL}, {@code NOTEQUAL}, {@code LESS}, {@code LTE},
 * {@code GREATER} and {@code GTE}. Built with {@code Set.of}, so the set is unmodifiable.
 */
public static final Set<BuiltinFunctionName> COMPARATORS =
Set.of(
BuiltinFunctionName.EQUAL,
BuiltinFunctionName.NOTEQUAL,
BuiltinFunctionName.LESS,
BuiltinFunctionName.LTE,
BuiltinFunctionName.GREATER,
BuiltinFunctionName.GTE);
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,12 @@
/** Function signature is composed by function name and arguments list. */
public record CalciteFuncSignature(FunctionName functionName, PPLTypeChecker typeChecker) {

public boolean match(FunctionName functionName, List<RelDataType> paramTypeList) {
public boolean match(FunctionName functionName, List<RelDataType> argTypes) {
if (!functionName.equals(this.functionName())) return false;
// For complex type checkers (e.g., OperandTypes.COMPARABLE_UNORDERED_COMPARABLE_UNORDERED),
// the typeChecker will be null because only simple family-based type checks are currently
// supported.
if (typeChecker == null) return true;
return typeChecker.checkOperandTypes(paramTypeList);
return typeChecker.checkOperandTypes(argTypes);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.expression.function;

import java.util.ArrayList;
import java.util.List;
import javax.annotation.Nullable;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexNode;
import org.opensearch.sql.calcite.utils.OpenSearchTypeFactory;
import org.opensearch.sql.data.type.ExprCoreType;
import org.opensearch.sql.data.type.ExprType;
import org.opensearch.sql.data.type.WideningTypeRule;
import org.opensearch.sql.exception.ExpressionEvaluationException;

public class CoercionUtils {

/**
 * Coerces {@code arguments} to the first parameter-type signature offered by {@code typeChecker}
 * that all of them can be cast to.
 *
 * <p>Signatures are tried in the order returned by {@code typeChecker.getParameterTypes()}, and
 * evaluation stops at the first signature for which every cast succeeds.
 *
 * @param builder RexBuilder used to create the casts
 * @param typeChecker PPLTypeChecker that supplies the candidate parameter-type signatures
 * @param arguments RexNode arguments to coerce
 * @return the coerced arguments, or null when no signature accepts the arguments
 */
public static @Nullable List<RexNode> castArguments(
    RexBuilder builder, PPLTypeChecker typeChecker, List<RexNode> arguments) {
  // TODO: var args?
  // The pipeline is lazy: signatures after the first successful one are never attempted.
  return typeChecker.getParameterTypes().stream()
      .map(signature -> castArguments(builder, signature, arguments))
      .filter(candidate -> candidate != null)
      .findFirst()
      .orElse(null);
}

/**
 * Widens every argument to the widest type found among them.
 *
 * <p>Returns null when no common widest type can be determined, or when any individual argument
 * cannot be cast to that type. The returned list therefore never contains null elements.
 *
 * @param builder RexBuilder to create casts
 * @param arguments List of RexNode arguments to be widened
 * @return unmodifiable list of widened RexNode arguments, or null if widening is not possible
 */
public static @Nullable List<RexNode> widenArguments(
    RexBuilder builder, List<RexNode> arguments) {
  // TODO: Add test on e.g. IP
  ExprType widestType = findWidestType(arguments);
  if (widestType == null) {
    return null; // No widest type found
  }

  List<RexNode> widened = new ArrayList<>(arguments.size());
  for (RexNode arg : arguments) {
    RexNode widenedArg = cast(builder, widestType, arg);
    if (widenedArg == null) {
      // cast is @Nullable; report the failure as null instead of leaking a null element
      // into the result, mirroring what castArguments does.
      return null;
    }
    widened.add(widenedArg);
  }
  // Keep the result unmodifiable, as the previous Stream.toList() result was.
  return List.copyOf(widened);
}

/**
 * Coerces each argument to the parameter type at the same position.
 *
 * @return the coerced arguments, or null when the argument count differs from the parameter
 *     count or when any single argument cannot be cast
 */
private static @Nullable List<RexNode> castArguments(
    RexBuilder builder, List<ExprType> paramTypes, List<RexNode> arguments) {
  final int arity = paramTypes.size();
  if (arity != arguments.size()) {
    // Arity mismatch: this signature cannot apply.
    return null;
  }

  List<RexNode> coerced = new ArrayList<>();
  for (int position = 0; position < arity; position++) {
    ExprType expected = paramTypes.get(position);
    RexNode original = arguments.get(position);
    RexNode converted = cast(builder, expected, original);
    if (converted == null) {
      // One failed cast rejects the whole signature.
      return null;
    }
    coerced.add(converted);
  }
  return coerced;
}

/**
 * Casts {@code arg} to {@code targetType} when a cast is needed and permitted.
 *
 * @return {@code arg} itself when its type reports that no cast to {@code targetType} is
 *     needed; a cast expression when {@link WideningTypeRule#distance} does not report
 *     {@code IMPOSSIBLE_WIDENING}; otherwise null
 */
private static @Nullable RexNode cast(RexBuilder builder, ExprType targetType, RexNode arg) {
  ExprType sourceType = OpenSearchTypeFactory.convertRelDataTypeToExprType(arg.getType());
  if (!sourceType.shouldCast(targetType)) {
    return arg;
  }

  boolean widenable =
      WideningTypeRule.distance(sourceType, targetType) != WideningTypeRule.IMPOSSIBLE_WIDENING;
  return widenable
      ? builder.makeCast(OpenSearchTypeFactory.convertExprTypeToRelDataType(targetType), arg)
      : null;
}

/**
* Finds the widest type among the given arguments. The widest type is determined by applying the
* widening type rule to each pair of types in the arguments.
*
* @param arguments List of RexNode arguments to find the widest type from
* @return the widest ExprType if found, otherwise null
*/
private static @Nullable ExprType findWidestType(List<RexNode> arguments) {
if (arguments.isEmpty()) {
return null; // No arguments to process
}
ExprType widestType =
OpenSearchTypeFactory.convertRelDataTypeToExprType(arguments.getFirst().getType());
if (arguments.size() == 1) {
return widestType;
}

// Iterate pairwise through the arguments and find the widest type
for (int i = 1; i < arguments.size(); i++) {
var type = OpenSearchTypeFactory.convertRelDataTypeToExprType(arguments.get(i).getType());
try {
if (areDateAndTime(widestType, type)) {
// If one is date and the other is time, we consider timestamp as the widest type
widestType = ExprCoreType.TIMESTAMP;
} else {
widestType = WideningTypeRule.max(widestType, type);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[non-blocking] Calcite has a similar operation, RelDataTypeFactory::leastRestrictive. If it is functionally equivalent to this, I would prefer to use that one.

We may deprecate ExprValue and ExprType in the future. But for now, it's OK to keep using this one.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea, we could override leastRestrictive in favor of RelDataType in the future.

}
} catch (ExpressionEvaluationException e) {
// the two types are not compatible, return null
return null;
}
}
return widestType;
}

/** Returns true when one of the two types is DATE and the other is TIME, in either order. */
private static boolean areDateAndTime(ExprType type1, ExprType type2) {
  if (type1 == ExprCoreType.DATE) {
    return type2 == ExprCoreType.TIME;
  }
  if (type1 == ExprCoreType.TIME) {
    return type2 == ExprCoreType.DATE;
  }
  return false;
}
}
Loading
Loading