diff --git a/core/src/main/java/org/opensearch/sql/calcite/CalciteRexNodeVisitor.java b/core/src/main/java/org/opensearch/sql/calcite/CalciteRexNodeVisitor.java index a9a32a707b6..f545f224d04 100644 --- a/core/src/main/java/org/opensearch/sql/calcite/CalciteRexNodeVisitor.java +++ b/core/src/main/java/org/opensearch/sql/calcite/CalciteRexNodeVisitor.java @@ -11,8 +11,6 @@ import static org.opensearch.sql.ast.expression.SpanUnit.NONE; import static org.opensearch.sql.ast.expression.SpanUnit.UNKNOWN; import static org.opensearch.sql.calcite.utils.OpenSearchTypeFactory.TYPE_FACTORY; -import static org.opensearch.sql.utils.DateTimeUtils.findCastType; -import static org.opensearch.sql.utils.DateTimeUtils.transferCompareForDateRelated; import java.math.BigDecimal; import java.util.ArrayList; @@ -30,7 +28,6 @@ import org.apache.calcite.rel.type.RelDataTypeFactory; import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.rex.RexCall; -import org.apache.calcite.rex.RexLambda; import org.apache.calcite.rex.RexLambdaRef; import org.apache.calcite.rex.RexNode; import org.apache.calcite.sql.SqlIntervalQualifier; @@ -215,11 +212,8 @@ public RexNode visitIn(In node, CalcitePlanContext context) { @Override public RexNode visitCompare(Compare node, CalcitePlanContext context) { - RexNode leftCandidate = analyze(node.getLeft(), context); - RexNode rightCandidate = analyze(node.getRight(), context); - SqlTypeName castTarget = findCastType(leftCandidate, rightCandidate); - final RexNode left = transferCompareForDateRelated(leftCandidate, context, castTarget); - final RexNode right = transferCompareForDateRelated(rightCandidate, context, castTarget); + RexNode left = analyze(node.getLeft(), context); + RexNode right = analyze(node.getRight(), context); return PPLFuncImpTable.INSTANCE.resolve(context.rexBuilder, node.getOperator(), left, right); } @@ -468,19 +462,6 @@ private List modifyLambdaTypeByFunction( } } - private List castArgument( - List originalArguments, String functionName, ExtendedRexBuilder rexBuilder) { - switch (functionName.toUpperCase(Locale.ROOT)) { - case "REDUCE": - RexLambda call = (RexLambda) originalArguments.get(2); - originalArguments.set( - 1, rexBuilder.makeCast(call.getType(), originalArguments.get(1), true, true)); - return originalArguments; - default: - return originalArguments; - } - } - @Override public RexNode visitFunction(Function node, CalcitePlanContext context) { List args = node.getFuncArgs(); @@ -507,8 +488,6 @@ public RexNode visitFunction(Function node, CalcitePlanContext context) { } } - arguments = castArgument(arguments, node.getFuncName(), context.rexBuilder); - RexNode resolvedNode = PPLFuncImpTable.INSTANCE.resolve( context.rexBuilder, node.getFuncName(), arguments.toArray(new RexNode[0])); diff --git a/core/src/main/java/org/opensearch/sql/calcite/utils/PPLOperandTypes.java b/core/src/main/java/org/opensearch/sql/calcite/utils/PPLOperandTypes.java index 9f818be10ec..8ae08d2c7fd 100644 --- a/core/src/main/java/org/opensearch/sql/calcite/utils/PPLOperandTypes.java +++ b/core/src/main/java/org/opensearch/sql/calcite/utils/PPLOperandTypes.java @@ -27,40 +27,121 @@ private PPLOperandTypes() {} UDFOperandMetadata.wrap( (CompositeOperandTypeChecker) OperandTypes.INTEGER.or(OperandTypes.family())); public static final UDFOperandMetadata STRING = - UDFOperandMetadata.wrap((FamilyOperandTypeChecker) OperandTypes.STRING); + UDFOperandMetadata.wrap((FamilyOperandTypeChecker) OperandTypes.CHARACTER); public static final UDFOperandMetadata INTEGER = UDFOperandMetadata.wrap((FamilyOperandTypeChecker) OperandTypes.INTEGER); public static final UDFOperandMetadata NUMERIC = UDFOperandMetadata.wrap((FamilyOperandTypeChecker) OperandTypes.NUMERIC); + + public static final UDFOperandMetadata NUMERIC_OPTIONAL_STRING = + UDFOperandMetadata.wrap( + (CompositeOperandTypeChecker) + OperandTypes.NUMERIC.or( + OperandTypes.family(SqlTypeFamily.NUMERIC, SqlTypeFamily.CHARACTER))); + public static final UDFOperandMetadata INTEGER_INTEGER = UDFOperandMetadata.wrap((FamilyOperandTypeChecker) OperandTypes.INTEGER_INTEGER); public static final UDFOperandMetadata STRING_STRING = - UDFOperandMetadata.wrap(OperandTypes.STRING_STRING); + UDFOperandMetadata.wrap((FamilyOperandTypeChecker) OperandTypes.CHARACTER_CHARACTER); public static final UDFOperandMetadata NUMERIC_NUMERIC = UDFOperandMetadata.wrap((FamilyOperandTypeChecker) OperandTypes.NUMERIC_NUMERIC); + public static final UDFOperandMetadata STRING_INTEGER = + UDFOperandMetadata.wrap(OperandTypes.family(SqlTypeFamily.CHARACTER, SqlTypeFamily.INTEGER)); + public static final UDFOperandMetadata NUMERIC_NUMERIC_NUMERIC = UDFOperandMetadata.wrap( OperandTypes.family(SqlTypeFamily.NUMERIC, SqlTypeFamily.NUMERIC, SqlTypeFamily.NUMERIC)); + public static final UDFOperandMetadata STRING_OR_INTEGER_INTEGER_INTEGER = + UDFOperandMetadata.wrap( + (CompositeOperandTypeChecker) + OperandTypes.family( + SqlTypeFamily.CHARACTER, SqlTypeFamily.INTEGER, SqlTypeFamily.INTEGER) + .or( + OperandTypes.family( + SqlTypeFamily.INTEGER, SqlTypeFamily.INTEGER, SqlTypeFamily.INTEGER))); + + public static final UDFOperandMetadata OPTIONAL_DATE_OR_TIMESTAMP_OR_NUMERIC = + UDFOperandMetadata.wrap( + (CompositeOperandTypeChecker) + OperandTypes.DATETIME.or(OperandTypes.NUMERIC).or(OperandTypes.family())); public static final UDFOperandMetadata DATETIME_OR_STRING = UDFOperandMetadata.wrap( - (CompositeOperandTypeChecker) OperandTypes.DATETIME.or(OperandTypes.STRING)); + (CompositeOperandTypeChecker) OperandTypes.DATETIME.or(OperandTypes.CHARACTER)); + public static final UDFOperandMetadata TIME_OR_TIMESTAMP_OR_STRING = + UDFOperandMetadata.wrap( + (CompositeOperandTypeChecker) + OperandTypes.CHARACTER.or(OperandTypes.TIME).or(OperandTypes.TIMESTAMP)); + public static final UDFOperandMetadata DATE_OR_TIMESTAMP_OR_STRING = + UDFOperandMetadata.wrap( + (CompositeOperandTypeChecker) OperandTypes.DATE_OR_TIMESTAMP.or(OperandTypes.CHARACTER)); + public static final UDFOperandMetadata DATETIME_OR_STRING_OR_INTEGER = + UDFOperandMetadata.wrap( + (CompositeOperandTypeChecker) + OperandTypes.DATETIME.or(OperandTypes.CHARACTER).or(OperandTypes.INTEGER)); + + public static final UDFOperandMetadata DATETIME_OPTIONAL_INTEGER = + UDFOperandMetadata.wrap( + (CompositeOperandTypeChecker) + OperandTypes.DATETIME.or( + OperandTypes.family(SqlTypeFamily.DATETIME, SqlTypeFamily.INTEGER))); + public static final UDFOperandMetadata DATETIME_DATETIME = UDFOperandMetadata.wrap(OperandTypes.family(SqlTypeFamily.DATETIME, SqlTypeFamily.DATETIME)); + public static final UDFOperandMetadata DATETIME_OR_STRING_STRING = + UDFOperandMetadata.wrap( + (CompositeOperandTypeChecker) + OperandTypes.family(SqlTypeFamily.DATETIME, SqlTypeFamily.CHARACTER) + .or(OperandTypes.CHARACTER_CHARACTER)); public static final UDFOperandMetadata DATETIME_OR_STRING_DATETIME_OR_STRING = UDFOperandMetadata.wrap( (CompositeOperandTypeChecker) - OperandTypes.STRING_STRING + OperandTypes.CHARACTER_CHARACTER .or(OperandTypes.family(SqlTypeFamily.DATETIME, SqlTypeFamily.DATETIME)) - .or(OperandTypes.family(SqlTypeFamily.DATETIME, SqlTypeFamily.STRING)) - .or(OperandTypes.family(SqlTypeFamily.STRING, SqlTypeFamily.DATETIME))); - public static final UDFOperandMetadata TIME_OR_TIMESTAMP_OR_STRING = + .or(OperandTypes.family(SqlTypeFamily.DATETIME, SqlTypeFamily.CHARACTER)) + .or(OperandTypes.family(SqlTypeFamily.CHARACTER, SqlTypeFamily.DATETIME))); + public static final UDFOperandMetadata STRING_TIMESTAMP = + UDFOperandMetadata.wrap( + OperandTypes.family(SqlTypeFamily.CHARACTER, SqlTypeFamily.TIMESTAMP)); + public static final UDFOperandMetadata STRING_DATETIME = + UDFOperandMetadata.wrap(OperandTypes.family(SqlTypeFamily.CHARACTER, SqlTypeFamily.DATETIME)); + public static final UDFOperandMetadata DATETIME_INTERVAL = + UDFOperandMetadata.wrap((FamilyOperandTypeChecker) OperandTypes.DATETIME_INTERVAL); + public static final UDFOperandMetadata TIME_TIME = + UDFOperandMetadata.wrap(OperandTypes.family(SqlTypeFamily.TIME, SqlTypeFamily.TIME)); + + public static final UDFOperandMetadata TIMESTAMP_OR_STRING_STRING_STRING = UDFOperandMetadata.wrap( (CompositeOperandTypeChecker) - OperandTypes.STRING.or(OperandTypes.TIME).or(OperandTypes.TIMESTAMP)); - public static final UDFOperandMetadata DATE_OR_TIMESTAMP_OR_STRING = + OperandTypes.family( + SqlTypeFamily.TIMESTAMP, SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER) + .or( + OperandTypes.family( + SqlTypeFamily.CHARACTER, + SqlTypeFamily.CHARACTER, + SqlTypeFamily.CHARACTER))); + public static final UDFOperandMetadata STRING_INTEGER_DATETIME_OR_STRING = UDFOperandMetadata.wrap( - (CompositeOperandTypeChecker) OperandTypes.DATE_OR_TIMESTAMP.or(OperandTypes.STRING)); - public static final UDFOperandMetadata STRING_TIMESTAMP = - UDFOperandMetadata.wrap(OperandTypes.family(SqlTypeFamily.STRING, SqlTypeFamily.TIMESTAMP)); + (CompositeOperandTypeChecker) + OperandTypes.family( + SqlTypeFamily.CHARACTER, SqlTypeFamily.INTEGER, SqlTypeFamily.CHARACTER) + .or( + OperandTypes.family( + SqlTypeFamily.CHARACTER, SqlTypeFamily.INTEGER, SqlTypeFamily.DATETIME))); + public static final UDFOperandMetadata INTERVAL_DATETIME_DATETIME = + UDFOperandMetadata.wrap( + (CompositeOperandTypeChecker) + OperandTypes.family( + SqlTypeFamily.CHARACTER, SqlTypeFamily.DATETIME, SqlTypeFamily.DATETIME) + .or( + OperandTypes.family( + SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER, SqlTypeFamily.DATETIME)) + .or( + OperandTypes.family( + SqlTypeFamily.CHARACTER, SqlTypeFamily.DATETIME, SqlTypeFamily.CHARACTER)) + .or( + OperandTypes.family( + SqlTypeFamily.CHARACTER, + SqlTypeFamily.CHARACTER, + SqlTypeFamily.CHARACTER))); } diff --git a/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java b/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java index 5834d3f5942..6e278c4192e 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java @@ -9,6 +9,7 @@ import java.util.Locale; import java.util.Map; import java.util.Optional; +import java.util.Set; import lombok.AllArgsConstructor; import lombok.Getter; import lombok.RequiredArgsConstructor; @@ -380,4 +381,13 @@ public static Optional ofWindowFunction(String functionName return Optional.ofNullable( WINDOW_FUNC_MAPPING.getOrDefault(functionName.toLowerCase(Locale.ROOT), null)); } + + public static final Set COMPARATORS = + Set.of( + BuiltinFunctionName.EQUAL, + BuiltinFunctionName.NOTEQUAL, + BuiltinFunctionName.LESS, + BuiltinFunctionName.LTE, + BuiltinFunctionName.GREATER, + BuiltinFunctionName.GTE); } diff --git a/core/src/main/java/org/opensearch/sql/expression/function/CalciteFuncSignature.java b/core/src/main/java/org/opensearch/sql/expression/function/CalciteFuncSignature.java index af89e98fa64..a8c4be11102 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/CalciteFuncSignature.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/CalciteFuncSignature.java @@ -11,12 +11,12 @@ /** Function signature is composed by function name and arguments list. */ public record CalciteFuncSignature(FunctionName functionName, PPLTypeChecker typeChecker) { - public boolean match(FunctionName functionName, List paramTypeList) { + public boolean match(FunctionName functionName, List argTypes) { if (!functionName.equals(this.functionName())) return false; // For complex type checkers (e.g., OperandTypes.COMPARABLE_UNORDERED_COMPARABLE_UNORDERED), // the typeChecker will be null because only simple family-based type checks are currently // supported. if (typeChecker == null) return true; - return typeChecker.checkOperandTypes(paramTypeList); + return typeChecker.checkOperandTypes(argTypes); } } diff --git a/core/src/main/java/org/opensearch/sql/expression/function/CoercionUtils.java b/core/src/main/java/org/opensearch/sql/expression/function/CoercionUtils.java new file mode 100644 index 00000000000..f0c6fc84837 --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/expression/function/CoercionUtils.java @@ -0,0 +1,139 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.expression.function; + +import java.util.ArrayList; +import java.util.List; +import javax.annotation.Nullable; +import org.apache.calcite.rex.RexBuilder; +import org.apache.calcite.rex.RexNode; +import org.opensearch.sql.calcite.utils.OpenSearchTypeFactory; +import org.opensearch.sql.data.type.ExprCoreType; +import org.opensearch.sql.data.type.ExprType; +import org.opensearch.sql.data.type.WideningTypeRule; +import org.opensearch.sql.exception.ExpressionEvaluationException; + +public class CoercionUtils { + + /** + * Casts the arguments to the types specified in the typeChecker. Returns null if no combination + * of parameter types matches the arguments or if casting fails. + * + * @param builder RexBuilder to create casts + * @param typeChecker PPLTypeChecker that provides the parameter types + * @param arguments List of RexNode arguments to be cast + * @return List of cast RexNode arguments or null if casting fails + */ + public static @Nullable List castArguments( + RexBuilder builder, PPLTypeChecker typeChecker, List arguments) { + List> paramTypeCombinations = typeChecker.getParameterTypes(); + + // TODO: var args? + + for (List paramTypes : paramTypeCombinations) { + List castedArguments = castArguments(builder, paramTypes, arguments); + if (castedArguments != null) { + return castedArguments; + } + } + return null; + } + + /** + * Widen the arguments to the widest type found among them. If no widest type can be determined, + * returns null. + * + * @param builder RexBuilder to create casts + * @param arguments List of RexNode arguments to be widened + * @return List of widened RexNode arguments or null if no widest type can be determined + */ + public static @Nullable List widenArguments( + RexBuilder builder, List arguments) { + // TODO: Add test on e.g. IP + ExprType widestType = findWidestType(arguments); + if (widestType == null) { + return null; // No widest type found, return null + } + return arguments.stream().map(arg -> cast(builder, widestType, arg)).toList(); + } + + /** + * Casts the arguments to the types specified in paramTypes. Returns null if the number of + * parameters does not match or if casting fails. + */ + private static @Nullable List castArguments( + RexBuilder builder, List paramTypes, List arguments) { + if (paramTypes.size() != arguments.size()) { + return null; // Skip if the number of parameters does not match + } + + List castedArguments = new ArrayList<>(); + for (int i = 0; i < paramTypes.size(); i++) { + ExprType toType = paramTypes.get(i); + RexNode arg = arguments.get(i); + + RexNode castedArg = cast(builder, toType, arg); + + if (castedArg == null) { + return null; + } + castedArguments.add(castedArg); + } + return castedArguments; + } + + private static @Nullable RexNode cast(RexBuilder builder, ExprType targetType, RexNode arg) { + ExprType argType = OpenSearchTypeFactory.convertRelDataTypeToExprType(arg.getType()); + if (!argType.shouldCast(targetType)) { + return arg; + } + + if (WideningTypeRule.distance(argType, targetType) != WideningTypeRule.IMPOSSIBLE_WIDENING) { + return builder.makeCast(OpenSearchTypeFactory.convertExprTypeToRelDataType(targetType), arg); + } + return null; + } + + /** + * Finds the widest type among the given arguments. The widest type is determined by applying the + * widening type rule to each pair of types in the arguments. + * + * @param arguments List of RexNode arguments to find the widest type from + * @return the widest ExprType if found, otherwise null + */ + private static @Nullable ExprType findWidestType(List arguments) { + if (arguments.isEmpty()) { + return null; // No arguments to process + } + ExprType widestType = + OpenSearchTypeFactory.convertRelDataTypeToExprType(arguments.getFirst().getType()); + if (arguments.size() == 1) { + return widestType; + } + + // Iterate pairwise through the arguments and find the widest type + for (int i = 1; i < arguments.size(); i++) { + var type = OpenSearchTypeFactory.convertRelDataTypeToExprType(arguments.get(i).getType()); + try { + if (areDateAndTime(widestType, type)) { + // If one is date and the other is time, we consider timestamp as the widest type + widestType = ExprCoreType.TIMESTAMP; + } else { + widestType = WideningTypeRule.max(widestType, type); + } + } catch (ExpressionEvaluationException e) { + // the two types are not compatible, return null + return null; + } + } + return widestType; + } + + private static boolean areDateAndTime(ExprType type1, ExprType type2) { + return (type1 == ExprCoreType.DATE && type2 == ExprCoreType.TIME) + || (type1 == ExprCoreType.TIME && type2 == ExprCoreType.DATE); + } +} diff --git a/core/src/main/java/org/opensearch/sql/expression/function/PPLBuiltinOperators.java b/core/src/main/java/org/opensearch/sql/expression/function/PPLBuiltinOperators.java index 02ddfd78dee..257f0c04668 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/PPLBuiltinOperators.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/PPLBuiltinOperators.java @@ -21,10 +21,7 @@ import org.apache.calcite.linq4j.tree.Expression; import org.apache.calcite.rex.RexCall; import org.apache.calcite.sql.SqlOperator; -import org.apache.calcite.sql.type.CompositeOperandTypeChecker; -import org.apache.calcite.sql.type.OperandTypes; import org.apache.calcite.sql.type.ReturnTypes; -import org.apache.calcite.sql.type.SqlTypeFamily; import org.apache.calcite.sql.type.SqlTypeTransforms; import org.apache.calcite.sql.util.ReflectiveSqlOperatorTable; import org.apache.calcite.util.BuiltInMethod; @@ -144,7 +141,7 @@ public class PPLBuiltinOperators extends ReflectiveSqlOperatorTable { "exprAddTime", PPLReturnTypes.TIME_APPLY_RETURN_TYPE, NullPolicy.ANY, - PPLOperandTypes.DATETIME_OR_STRING_DATETIME_OR_STRING) + PPLOperandTypes.DATETIME_DATETIME) .toUDF("ADDTIME"); public static final SqlOperator SUBTIME = adaptExprMethodWithPropertiesToUDF( @@ -152,7 +149,7 @@ public class PPLBuiltinOperators extends ReflectiveSqlOperatorTable { "exprSubTime", PPLReturnTypes.TIME_APPLY_RETURN_TYPE, NullPolicy.ANY, - PPLOperandTypes.DATETIME_OR_STRING_DATETIME_OR_STRING) + PPLOperandTypes.DATETIME_DATETIME) .toUDF("SUBTIME"); public static final SqlOperator ADDDATE = new AddSubDateFunction(true).toUDF("ADDDATE"); public static final SqlOperator SUBDATE = new AddSubDateFunction(false).toUDF("SUBDATE"); @@ -198,11 +195,7 @@ public class PPLBuiltinOperators extends ReflectiveSqlOperatorTable { "exprConvertTZ", PPLReturnTypes.TIMESTAMP_FORCE_NULLABLE, NullPolicy.ANY, - UDFOperandMetadata.wrap( - (CompositeOperandTypeChecker) - OperandTypes.STRING_STRING_STRING.or( - OperandTypes.family( - SqlTypeFamily.DATETIME, SqlTypeFamily.STRING, SqlTypeFamily.STRING)))) + PPLOperandTypes.TIMESTAMP_OR_STRING_STRING_STRING) .toUDF("CONVERT_TZ"); public static final SqlOperator DATEDIFF = adaptExprMethodWithPropertiesToUDF( @@ -210,7 +203,7 @@ public class PPLBuiltinOperators extends ReflectiveSqlOperatorTable { "exprDateDiff", ReturnTypes.BIGINT_FORCE_NULLABLE, NullPolicy.ANY, - PPLOperandTypes.DATETIME_OR_STRING_DATETIME_OR_STRING) + PPLOperandTypes.DATETIME_DATETIME) .toUDF("DATEDIFF"); public static final SqlOperator TIMESTAMPDIFF = new TimestampDiffFunction().toUDF("TIMESTAMPDIFF"); @@ -291,7 +284,7 @@ public class PPLBuiltinOperators extends ReflectiveSqlOperatorTable { "exprTimeToSec", ReturnTypes.BIGINT_FORCE_NULLABLE, NullPolicy.ARG0, - PPLOperandTypes.DATETIME_OR_STRING) + PPLOperandTypes.TIME_OR_TIMESTAMP_OR_STRING) .toUDF("TIME_TO_SEC"); public static final SqlOperator TIMEDIFF = UserDefinedFunctionUtils.adaptExprMethodToUDF( @@ -299,7 +292,7 @@ public class PPLBuiltinOperators extends ReflectiveSqlOperatorTable { "exprTimeDiff", PPLReturnTypes.TIME_FORCE_NULLABLE, NullPolicy.ANY, - PPLOperandTypes.DATETIME_OR_STRING_DATETIME_OR_STRING) + PPLOperandTypes.TIME_TIME) .toUDF("TIME_DIFF"); public static final SqlOperator TIMESTAMPADD = new TimestampAddFunction().toUDF("TIMESTAMPADD"); public static final SqlOperator TO_DAYS = diff --git a/core/src/main/java/org/opensearch/sql/expression/function/PPLFuncImpTable.java b/core/src/main/java/org/opensearch/sql/expression/function/PPLFuncImpTable.java index f9a713fc49a..1e871169eae 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/PPLFuncImpTable.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/PPLFuncImpTable.java @@ -226,8 +226,10 @@ import java.util.function.BiFunction; import java.util.stream.Collectors; import java.util.stream.Stream; +import javax.annotation.Nullable; import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rex.RexBuilder; +import org.apache.calcite.rex.RexLambda; import org.apache.calcite.rex.RexNode; import org.apache.calcite.sql.SqlOperator; import org.apache.calcite.sql.fun.SqlLibraryOperators; @@ -270,7 +272,6 @@ RelBuilder.AggCall apply( public interface FunctionImp { RelDataType ANY_TYPE = TYPE_FACTORY.createSqlType(SqlTypeName.ANY); - // TODO: Support argument coercion and casting RexNode resolve(RexBuilder builder, RexNode... args); /** @@ -442,6 +443,12 @@ public RexNode resolve( if (implementList == null || implementList.isEmpty()) { throw new IllegalStateException(String.format("Cannot resolve function: %s", functionName)); } + + // Make compulsory casts for some functions that require specific casting of arguments. + // For example, the REDUCE function requires the second argument to be cast to the + // return type of the lambda function. + compulsoryCast(builder, functionName, args); + List argTypes = Arrays.stream(args).map(RexNode::getType).toList(); try { for (Map.Entry implement : implementList) { @@ -449,6 +456,13 @@ public RexNode resolve( return implement.getValue().resolve(builder, args); } } + + // If no implementation found with exact match, try to cast arguments to match the + // signatures. + RexNode coerced = resolveWithCoercion(builder, functionName, implementList, args); + if (coerced != null) { + return coerced; + } } catch (Exception e) { throw new ExpressionEvaluationException( String.format( @@ -469,6 +483,63 @@ functionName, getActualSignature(argTypes), e.getMessage()), functionName, allowedSignatures, getActualSignature(argTypes))); } + /** + * Ad-hoc coercion for some functions that require specific casting of arguments. Now it only + * applies to the REDUCE function. + */ + private void compulsoryCast( + final RexBuilder builder, final BuiltinFunctionName functionName, RexNode... args) { + + //noinspection SwitchStatementWithTooFewBranches + switch (functionName) { + case BuiltinFunctionName.REDUCE: + // Set the second argument to the return type of the lambda function, so that + // code generated with linq4j can correctly accumulate the result. + RexLambda call = (RexLambda) args[2]; + args[1] = builder.makeCast(call.getType(), args[1], true, true); + break; + default: + break; + } + } + + private @Nullable RexNode resolveWithCoercion( + final RexBuilder builder, + final BuiltinFunctionName functionName, + List> implementList, + RexNode... args) { + if (BuiltinFunctionName.COMPARATORS.contains(functionName)) { + for (Map.Entry implement : implementList) { + var widenedArgs = CoercionUtils.widenArguments(builder, List.of(args)); + if (widenedArgs != null) { + boolean matchSignature = + implement + .getKey() + .typeChecker() + .checkOperandTypes(widenedArgs.stream().map(RexNode::getType).toList()); + if (matchSignature) { + return implement.getValue().resolve(builder, widenedArgs.toArray(new RexNode[0])); + } + } + } + } else { + for (Map.Entry implement : implementList) { + var signature = implement.getKey(); + var castedArgs = + CoercionUtils.castArguments(builder, signature.typeChecker(), List.of(args)); + if (castedArgs != null) { + // If compatible function is found, replace the original RexNode with cast node + // TODO: check - this is a return-once-found implementation, rest possible combinations + // will be skipped. + // Maybe can be improved to return the best match? E.g. convert to timestamp when date, + // time, and timestamp are all possible. + return implement.getValue().resolve(builder, castedArgs.toArray(new RexNode[0])); + } + } + } + return null; + } + private static String getActualSignature(List argTypes) { return "[" + argTypes.stream() @@ -873,7 +944,7 @@ void populate() { builder.makeFlag(Flag.BOTH), builder.makeLiteral(" "), arg), - PPLTypeChecker.family(SqlTypeFamily.STRING))); + PPLTypeChecker.family(SqlTypeFamily.CHARACTER))); register( LTRIM, @@ -884,7 +955,7 @@ void populate() { builder.makeFlag(Flag.LEADING), builder.makeLiteral(" "), arg), - PPLTypeChecker.family(SqlTypeFamily.STRING))); + PPLTypeChecker.family(SqlTypeFamily.CHARACTER))); register( RTRIM, createFunctionImpWithTypeChecker( @@ -894,7 +965,7 @@ void populate() { builder.makeFlag(Flag.TRAILING), builder.makeLiteral(" "), arg), - PPLTypeChecker.family(SqlTypeFamily.STRING))); + PPLTypeChecker.family(SqlTypeFamily.CHARACTER))); register( ATAN, createFunctionImpWithTypeChecker( @@ -904,7 +975,7 @@ void populate() { STRCMP, createFunctionImpWithTypeChecker( (builder, arg1, arg2) -> builder.makeCall(SqlLibraryOperators.STRCMP, arg2, arg1), - PPLTypeChecker.family(SqlTypeFamily.STRING, SqlTypeFamily.STRING))); + PPLTypeChecker.family(SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER))); // SqlStdOperatorTable.SUBSTRING.getOperandTypeChecker is null. We manually create a type // checker for it. register( @@ -912,14 +983,24 @@ void populate() { wrapWithCompositeTypeChecker( SqlStdOperatorTable.SUBSTRING, (CompositeOperandTypeChecker) - OperandTypes.STRING_INTEGER.or(OperandTypes.STRING_INTEGER_INTEGER), + OperandTypes.family(SqlTypeFamily.CHARACTER, SqlTypeFamily.INTEGER) + .or( + OperandTypes.family( + SqlTypeFamily.CHARACTER, + SqlTypeFamily.INTEGER, + SqlTypeFamily.INTEGER)), false)); register( SUBSTR, wrapWithCompositeTypeChecker( SqlStdOperatorTable.SUBSTRING, (CompositeOperandTypeChecker) - OperandTypes.STRING_INTEGER.or(OperandTypes.STRING_INTEGER_INTEGER), + OperandTypes.family(SqlTypeFamily.CHARACTER, SqlTypeFamily.INTEGER) + .or( + OperandTypes.family( + SqlTypeFamily.CHARACTER, + SqlTypeFamily.INTEGER, + SqlTypeFamily.INTEGER)), false)); // SqlStdOperatorTable.ITEM.getOperandTypeChecker() checks only the first operand instead of // all operands. Therefore, we wrap it with a custom CompositeOperandTypeChecker to check both diff --git a/core/src/main/java/org/opensearch/sql/expression/function/PPLTypeChecker.java b/core/src/main/java/org/opensearch/sql/expression/function/PPLTypeChecker.java index df2225ab9c7..adc10e63b71 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/PPLTypeChecker.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/PPLTypeChecker.java @@ -15,6 +15,9 @@ import java.util.stream.IntStream; import lombok.RequiredArgsConstructor; import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rel.type.RelDataTypeField; +import org.apache.calcite.sql.SqlIntervalQualifier; +import org.apache.calcite.sql.parser.SqlParserPos; import org.apache.calcite.sql.type.CompositeOperandTypeChecker; import org.apache.calcite.sql.type.FamilyOperandTypeChecker; import org.apache.calcite.sql.type.ImplicitCastOperandTypeChecker; @@ -23,6 +26,8 @@ import org.apache.calcite.sql.type.SqlTypeFamily; import org.apache.calcite.sql.type.SqlTypeName; import org.apache.calcite.sql.type.SqlTypeUtil; +import org.apache.calcite.util.Pair; +import org.opensearch.sql.calcite.type.ExprIPType; import org.opensearch.sql.calcite.utils.OpenSearchTypeFactory; import org.opensearch.sql.calcite.utils.UserDefinedFunctionUtils; import org.opensearch.sql.data.type.ExprCoreType; @@ -53,6 +58,16 @@ public interface PPLTypeChecker { */ String getAllowedSignatures(); + /** + * Get a list of all possible parameter type combinations for the function. + * + *

This method is used to generate the allowed signatures for the function based on the + * parameter types. + * + * @return a list of lists, where each inner list represents an allowed parameter type combination + */ + List> getParameterTypes(); + private static boolean validateOperands( List funcTypeFamilies, List operandTypes) { // If the number of actual operands does not match expectation, return false @@ -95,6 +110,11 @@ public String getAllowedSignatures() { return PPLTypeChecker.getFamilySignature(families); } + @Override + public List> getParameterTypes() { + return PPLTypeChecker.getExprSignatures(families); + } + @Override public String toString() { return String.format("PPLFamilyTypeChecker[families=%s]", getAllowedSignatures()); @@ -130,12 +150,23 @@ public boolean checkOperandTypes(List types) { @Override public String getAllowedSignatures() { if (innerTypeChecker instanceof FamilyOperandTypeChecker familyOperandTypeChecker) { - var allowedSignatures = PPLTypeChecker.getFamilySignatures(familyOperandTypeChecker); - return String.join(",", allowedSignatures); + var allowedExprSignatures = getExprSignatures(familyOperandTypeChecker); + return PPLTypeChecker.formatExprSignatures(allowedExprSignatures); } else { return ""; } } + + @Override + public List> getParameterTypes() { + if (innerTypeChecker instanceof FamilyOperandTypeChecker familyOperandTypeChecker) { + return getExprSignatures(familyOperandTypeChecker); + } else { + // If the inner type checker is not a FamilyOperandTypeChecker, we cannot provide + // parameter types. + return Collections.emptyList(); + } + } } /** @@ -149,6 +180,7 @@ public String getAllowedSignatures() { * ImplicitCastOperandTypeChecker}. */ class PPLCompositeTypeChecker implements PPLTypeChecker { + private final List allowedRules; public PPLCompositeTypeChecker(CompositeOperandTypeChecker typeChecker) { @@ -184,16 +216,33 @@ public boolean checkOperandTypes(List types) { @Override public String getAllowedSignatures() { - List allowedSignatures = new ArrayList<>(); + StringBuilder builder = new StringBuilder(); + for (SqlOperandTypeChecker rule : allowedRules) { + if (rule instanceof FamilyOperandTypeChecker familyOperandTypeChecker) { + if (!builder.isEmpty()) { + builder.append(","); + } + builder.append(PPLTypeChecker.getFamilySignatures(familyOperandTypeChecker)); + } else { + throw new IllegalArgumentException( + "Currently only compositions of FamilyOperandTypeChecker are supported"); + } + } + return builder.toString(); + } + + @Override + public List> getParameterTypes() { + List> parameterTypes = new ArrayList<>(); for (SqlOperandTypeChecker rule : allowedRules) { if (rule instanceof FamilyOperandTypeChecker familyOperandTypeChecker) { - allowedSignatures.addAll(PPLTypeChecker.getFamilySignatures(familyOperandTypeChecker)); + parameterTypes.addAll(getExprSignatures(familyOperandTypeChecker)); } else { throw new IllegalArgumentException( "Currently only compositions of FamilyOperandTypeChecker are supported"); } } - return String.join(",", allowedSignatures); + return parameterTypes; } } @@ -210,28 +259,60 @@ public boolean checkOperandTypes(List types) { for (int i = 0; i < types.size() - 1; i++) { // TODO: Binary, Array UDT? // DATETIME, NUMERIC, BOOLEAN will be regarded as comparable - // with strings in SqlTypeUtil.isComparable + // with strings in isComparable RelDataType type_l = types.get(i); RelDataType type_r = types.get(i + 1); - if (!SqlTypeUtil.isComparable(type_l, type_r)) { + // Rule out IP types from built-in comparable functions + if (type_l instanceof ExprIPType || type_r instanceof ExprIPType) { return false; } - // Disallow coercing between strings and numeric, boolean - if ((type_l.getFamily() == SqlTypeFamily.CHARACTER - && cannotConvertStringInCompare((SqlTypeFamily) type_r.getFamily())) - || (type_r.getFamily() == SqlTypeFamily.CHARACTER - && cannotConvertStringInCompare((SqlTypeFamily) type_l.getFamily()))) { + if (!isComparable(type_l, type_r)) { return false; } } return true; } - private static boolean cannotConvertStringInCompare(SqlTypeFamily typeFamily) { - return switch (typeFamily) { - case BOOLEAN, INTEGER, NUMERIC, EXACT_NUMERIC, APPROXIMATE_NUMERIC -> true; - default -> false; - }; + /** + * Modified from {@link SqlTypeUtil#isComparable(RelDataType, RelDataType)} to + * + * @param type1 first type + * @param type2 second type + * @return true if the two types are comparable, false otherwise + */ + private static boolean isComparable(RelDataType type1, RelDataType type2) { + if (type1.isStruct() != type2.isStruct()) { + return false; + } + + if (type1.isStruct()) { + int n = type1.getFieldCount(); + if (n != type2.getFieldCount()) { + return false; + } + for (Pair pair : + Pair.zip(type1.getFieldList(), type2.getFieldList())) { + if (!isComparable(pair.left.getType(), pair.right.getType())) { + return false; + } + } + return true; + } + + // Numeric types are comparable without the need to cast + if (SqlTypeUtil.isNumeric(type1) && SqlTypeUtil.isNumeric(type2)) { + return true; + } + + ExprType exprType1 = OpenSearchTypeFactory.convertRelDataTypeToExprType(type1); + ExprType exprType2 = OpenSearchTypeFactory.convertRelDataTypeToExprType(type2); + + if (!exprType1.shouldCast(exprType2)) { + return true; + } + + // If one of the arguments is of type 'ANY', return true. + return type1.getFamily() == SqlTypeFamily.ANY || type2.getFamily() == SqlTypeFamily.ANY; } @Override @@ -254,6 +335,12 @@ public String getAllowedSignatures() { return String.join(",", signatures); } } + + @Override + public List> getParameterTypes() { + // Should not be used + return List.of(List.of(ExprCoreType.UNKNOWN, ExprCoreType.UNKNOWN)); + } } /** @@ -362,7 +449,12 @@ public boolean checkOperandTypes(List types) { @Override public String getAllowedSignatures() { - return PPLTypeChecker.getExprFamilySignature(allowedSignatures); + return PPLTypeChecker.formatExprSignatures(allowedSignatures); + } + + @Override + public List> getParameterTypes() { + return allowedSignatures; } }; } @@ -379,24 +471,27 @@ public String getAllowedSignatures() { * @param typeChecker the {@link FamilyOperandTypeChecker} to use for generating signatures * @return a list of allowed function signatures */ - private static List getFamilySignatures(FamilyOperandTypeChecker typeChecker) { + private static String getFamilySignatures(FamilyOperandTypeChecker typeChecker) { + var allowedExprSignatures = getExprSignatures(typeChecker); + return formatExprSignatures(allowedExprSignatures); + } + + private static List> getExprSignatures(FamilyOperandTypeChecker typeChecker) { var operandCountRange = typeChecker.getOperandCountRange(); int min = operandCountRange.getMin(); int max = operandCountRange.getMax(); - List allowedSignatures = new ArrayList<>(); List families = new ArrayList<>(); for (int i = 0; i < min; i++) { families.add(typeChecker.getOperandSqlTypeFamily(i)); } - allowedSignatures.add(getFamilySignature(families)); + List> allowedSignatures = new ArrayList<>(getExprSignatures(families)); // Avoid enumerating signatures for infinite args final int MAX_ARGS = 10; max = Math.min(max, MAX_ARGS); - for (int i = min; i < max; i++) { families.add(typeChecker.getOperandSqlTypeFamily(i)); - allowedSignatures.add(getFamilySignature(families)); + allowedSignatures.addAll(getExprSignatures(families)); } return allowedSignatures; } @@ -412,9 +507,9 @@ private static List getExprTypes(SqlTypeFamily family) { List concreteTypes = switch (family) { case DATETIME -> List.of( + OpenSearchTypeFactory.TYPE_FACTORY.createSqlType(SqlTypeName.TIMESTAMP), OpenSearchTypeFactory.TYPE_FACTORY.createSqlType(SqlTypeName.DATE), - OpenSearchTypeFactory.TYPE_FACTORY.createSqlType(SqlTypeName.TIME), - OpenSearchTypeFactory.TYPE_FACTORY.createSqlType(SqlTypeName.TIMESTAMP)); + OpenSearchTypeFactory.TYPE_FACTORY.createSqlType(SqlTypeName.TIME)); case NUMERIC -> List.of( OpenSearchTypeFactory.TYPE_FACTORY.createSqlType(SqlTypeName.INTEGER), OpenSearchTypeFactory.TYPE_FACTORY.createSqlType(SqlTypeName.DOUBLE)); @@ -423,6 +518,13 @@ private static List getExprTypes(SqlTypeFamily family) { OpenSearchTypeFactory.TYPE_FACTORY.createSqlType(SqlTypeName.INTEGER)); case ANY, IGNORE -> List.of( OpenSearchTypeFactory.TYPE_FACTORY.createSqlType(SqlTypeName.ANY)); + case DATETIME_INTERVAL -> SqlTypeName.INTERVAL_TYPES.stream() + .map( + type -> + OpenSearchTypeFactory.TYPE_FACTORY.createSqlIntervalType( + new SqlIntervalQualifier( + type.getStartUnit(), type.getEndUnit(), SqlParserPos.ZERO))) + .collect(Collectors.toList()); default -> { RelDataType type = family.getDefaultConcreteType(OpenSearchTypeFactory.TYPE_FACTORY); if (type == null) { @@ -433,9 +535,25 @@ private static List getExprTypes(SqlTypeFamily family) { }; return concreteTypes.stream() .map(OpenSearchTypeFactory::convertRelDataTypeToExprType) + .distinct() .collect(Collectors.toList()); } + /** + * Generates a list of all possible {@link ExprType} signatures based on the provided {@link + * SqlTypeFamily} list. + * + * @param families the list of {@link SqlTypeFamily} to generate signatures for + * @return a list of lists, where each inner list contains {@link ExprType} signatures + */ + private static List> getExprSignatures(List families) { + List> exprTypes = + families.stream().map(PPLTypeChecker::getExprTypes).collect(Collectors.toList()); + + // Do a cartesian product of all ExprTypes in the family + return Lists.cartesianProduct(exprTypes); + } + /** * Generates a string representation of the function signature based on the provided type * families. The format is a list of type families enclosed in square brackets, e.g.: "[INTEGER, @@ -445,27 +563,9 @@ private static List getExprTypes(SqlTypeFamily family) { * @return a string representation of the function signature */ private static String getFamilySignature(List families) { - List> exprTypes = - families.stream().map(PPLTypeChecker::getExprTypes).collect(Collectors.toList()); - - // Do a cartesian product of all ExprTypes in the family - List> signatures = Lists.cartesianProduct(exprTypes); - + List> signatures = getExprSignatures(families); // Convert each signature to a string representation and then concatenate them - return getExprFamilySignature(signatures); - } - - private static String getExprFamilySignature(List> signatures) { - return signatures.stream() - .map( - types -> - "[" - + types.stream() - // Display ExprCoreType.UNDEFINED as "ANY" for better interpretability - .map(t -> t == ExprCoreType.UNDEFINED ? "ANY" : t.toString()) - .collect(Collectors.joining(",")) - + "]") - .collect(Collectors.joining(",")); + return formatExprSignatures(signatures); } /** @@ -488,4 +588,17 @@ private static boolean isCompositionOr(CompositeOperandTypeChecker typeChecker) (CompositeOperandTypeChecker.Composition) compositionField.get(typeChecker); return composition == CompositeOperandTypeChecker.Composition.OR; } + + private static String formatExprSignatures(List> signatures) { + return signatures.stream() + .map( + types -> + "[" + + types.stream() + // Display ExprCoreType.UNDEFINED as "ANY" for better interpretability + .map(t -> t == ExprCoreType.UNDEFINED ? "ANY" : t.toString()) + .collect(Collectors.joining(",")) + + "]") + .collect(Collectors.joining(",")); + } } diff --git a/core/src/main/java/org/opensearch/sql/expression/function/PatternParserFunctionImpl.java b/core/src/main/java/org/opensearch/sql/expression/function/PatternParserFunctionImpl.java index 263ffb1e1b2..242cd723591 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/PatternParserFunctionImpl.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/PatternParserFunctionImpl.java @@ -50,9 +50,9 @@ public SqlReturnTypeInference getReturnTypeInference() { public UDFOperandMetadata getOperandMetadata() { return UDFOperandMetadata.wrap( (CompositeOperandTypeChecker) - OperandTypes.family(SqlTypeFamily.STRING, SqlTypeFamily.ANY) - .or(OperandTypes.family(SqlTypeFamily.STRING, SqlTypeFamily.STRING)) - .or(OperandTypes.family(SqlTypeFamily.STRING, SqlTypeFamily.ARRAY))); + OperandTypes.family(SqlTypeFamily.CHARACTER, SqlTypeFamily.ANY) + .or(OperandTypes.family(SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER)) + .or(OperandTypes.family(SqlTypeFamily.CHARACTER, SqlTypeFamily.ARRAY))); } public static class PatternParserImplementor implements NotNullImplementor { diff --git a/core/src/main/java/org/opensearch/sql/expression/function/udf/CryptographicFunction.java b/core/src/main/java/org/opensearch/sql/expression/function/udf/CryptographicFunction.java index af6a45a1ec7..bb228a1b0e2 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/udf/CryptographicFunction.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/udf/CryptographicFunction.java @@ -12,14 +12,13 @@ import org.apache.calcite.linq4j.tree.Expression; import org.apache.calcite.linq4j.tree.Expressions; import org.apache.calcite.rex.RexCall; -import org.apache.calcite.sql.type.FamilyOperandTypeChecker; -import org.apache.calcite.sql.type.OperandTypes; import org.apache.calcite.sql.type.ReturnTypes; import org.apache.calcite.sql.type.SqlReturnTypeInference; import org.apache.calcite.sql.type.SqlTypeTransforms; import org.apache.commons.codec.binary.Hex; import org.apache.commons.codec.digest.DigestUtils; import org.apache.commons.codec.digest.MessageDigestAlgorithms; +import org.opensearch.sql.calcite.utils.PPLOperandTypes; import org.opensearch.sql.expression.function.ImplementorUDF; import org.opensearch.sql.expression.function.UDFOperandMetadata; @@ -32,7 +31,7 @@ public static CryptographicFunction sha2() { return new CryptographicFunction(new Sha2Implementor(), NullPolicy.ANY) { @Override public UDFOperandMetadata getOperandMetadata() { - return UDFOperandMetadata.wrap((FamilyOperandTypeChecker) OperandTypes.STRING_INTEGER); + return PPLOperandTypes.STRING_INTEGER; } }; } diff --git a/core/src/main/java/org/opensearch/sql/expression/function/udf/SpanFunction.java b/core/src/main/java/org/opensearch/sql/expression/function/udf/SpanFunction.java index 060dc0cc4c8..cb5a0501ebb 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/udf/SpanFunction.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/udf/SpanFunction.java @@ -51,10 +51,11 @@ public SqlReturnTypeInference getReturnTypeInference() { public UDFOperandMetadata getOperandMetadata() { return UDFOperandMetadata.wrap( (CompositeOperandTypeChecker) - OperandTypes.family(SqlTypeFamily.STRING, SqlTypeFamily.NUMERIC, SqlTypeFamily.STRING) + OperandTypes.family( + SqlTypeFamily.CHARACTER, SqlTypeFamily.NUMERIC, SqlTypeFamily.CHARACTER) .or( OperandTypes.family( - SqlTypeFamily.DATETIME, SqlTypeFamily.NUMERIC, SqlTypeFamily.STRING)) + SqlTypeFamily.DATETIME, SqlTypeFamily.NUMERIC, SqlTypeFamily.CHARACTER)) // TODO: numeric span should support decimal as its interval .or( OperandTypes.family( diff --git a/core/src/main/java/org/opensearch/sql/expression/function/udf/condition/EarliestFunction.java b/core/src/main/java/org/opensearch/sql/expression/function/udf/condition/EarliestFunction.java index 142e108face..9d61e18aefd 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/udf/condition/EarliestFunction.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/udf/condition/EarliestFunction.java @@ -5,7 +5,6 @@ package org.opensearch.sql.expression.function.udf.condition; -import static org.opensearch.sql.calcite.utils.PPLOperandTypes.STRING_TIMESTAMP; import static org.opensearch.sql.calcite.utils.UserDefinedFunctionUtils.prependFunctionProperties; import static org.opensearch.sql.utils.DateTimeUtils.getRelativeZonedDateTime; @@ -23,6 +22,7 @@ import org.apache.calcite.rex.RexNode; import org.apache.calcite.sql.type.ReturnTypes; import org.apache.calcite.sql.type.SqlReturnTypeInference; +import org.opensearch.sql.calcite.utils.PPLOperandTypes; import org.opensearch.sql.calcite.utils.UserDefinedFunctionUtils; import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.expression.function.FunctionProperties; @@ -41,7 +41,7 @@ public SqlReturnTypeInference getReturnTypeInference() { @Override public UDFOperandMetadata getOperandMetadata() { - return STRING_TIMESTAMP; + return PPLOperandTypes.STRING_TIMESTAMP; } public static class EarliestImplementor implements NotNullImplementor { diff --git a/core/src/main/java/org/opensearch/sql/expression/function/udf/condition/LatestFunction.java b/core/src/main/java/org/opensearch/sql/expression/function/udf/condition/LatestFunction.java index 713b9a91ccc..ffa6e9e3cea 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/udf/condition/LatestFunction.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/udf/condition/LatestFunction.java @@ -5,7 +5,6 @@ package org.opensearch.sql.expression.function.udf.condition; -import static org.opensearch.sql.calcite.utils.PPLOperandTypes.STRING_TIMESTAMP; import static org.opensearch.sql.calcite.utils.UserDefinedFunctionUtils.prependFunctionProperties; import static org.opensearch.sql.utils.DateTimeUtils.getRelativeZonedDateTime; @@ -23,6 +22,7 @@ import org.apache.calcite.rex.RexNode; import org.apache.calcite.sql.type.ReturnTypes; import org.apache.calcite.sql.type.SqlReturnTypeInference; +import org.opensearch.sql.calcite.utils.PPLOperandTypes; import org.opensearch.sql.calcite.utils.UserDefinedFunctionUtils; import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.expression.function.FunctionProperties; @@ -41,7 +41,7 @@ public SqlReturnTypeInference getReturnTypeInference() { @Override public UDFOperandMetadata getOperandMetadata() { - return STRING_TIMESTAMP; + return PPLOperandTypes.STRING_TIMESTAMP; } public static class LatestImplementor implements NotNullImplementor { diff --git a/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/AddSubDateFunction.java b/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/AddSubDateFunction.java index 4bbb917bc65..8398a508388 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/AddSubDateFunction.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/AddSubDateFunction.java @@ -76,10 +76,8 @@ public SqlReturnTypeInference getReturnTypeInference() { public UDFOperandMetadata getOperandMetadata() { return UDFOperandMetadata.wrap( (CompositeOperandTypeChecker) - OperandTypes.DATETIME_INTERVAL - .or(OperandTypes.family(SqlTypeFamily.DATETIME, SqlTypeFamily.INTEGER)) - .or(OperandTypes.family(SqlTypeFamily.STRING, SqlTypeFamily.DATETIME_INTERVAL)) - .or(OperandTypes.STRING_INTEGER)); + OperandTypes.DATETIME_INTERVAL.or( + OperandTypes.family(SqlTypeFamily.DATETIME, SqlTypeFamily.INTEGER))); } @RequiredArgsConstructor diff --git a/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/DateAddSubFunction.java b/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/DateAddSubFunction.java index c4a8c0f44b0..de3a2cb65ab 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/DateAddSubFunction.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/DateAddSubFunction.java @@ -15,11 +15,9 @@ import org.apache.calcite.linq4j.tree.Expressions; import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rex.RexCall; -import org.apache.calcite.sql.type.CompositeOperandTypeChecker; -import org.apache.calcite.sql.type.OperandTypes; import org.apache.calcite.sql.type.SqlReturnTypeInference; -import org.apache.calcite.sql.type.SqlTypeFamily; import org.opensearch.sql.calcite.utils.OpenSearchTypeFactory; +import org.opensearch.sql.calcite.utils.PPLOperandTypes; import org.opensearch.sql.calcite.utils.PPLReturnTypes; import org.opensearch.sql.calcite.utils.UserDefinedFunctionUtils; import org.opensearch.sql.calcite.utils.datetime.DateTimeConversionUtils; @@ -50,10 +48,7 @@ public SqlReturnTypeInference getReturnTypeInference() { @Override public UDFOperandMetadata getOperandMetadata() { - return UDFOperandMetadata.wrap( - (CompositeOperandTypeChecker) - OperandTypes.DATETIME_INTERVAL.or( - OperandTypes.family(SqlTypeFamily.STRING, SqlTypeFamily.DATETIME_INTERVAL))); + return PPLOperandTypes.DATETIME_INTERVAL; } @RequiredArgsConstructor diff --git a/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/DatetimeFunction.java b/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/DatetimeFunction.java index 14866a6eadb..d7a876cfeaa 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/DatetimeFunction.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/DatetimeFunction.java @@ -50,9 +50,9 @@ public UDFOperandMetadata getOperandMetadata() { return UDFOperandMetadata.wrap( (CompositeOperandTypeChecker) OperandTypes.TIMESTAMP_STRING - .or(OperandTypes.STRING_STRING) + .or(OperandTypes.CHARACTER_CHARACTER) .or(OperandTypes.TIMESTAMP) - .or(OperandTypes.STRING)); + .or(OperandTypes.CHARACTER)); } public static class DatetimeImplementor implements NotNullImplementor { diff --git a/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/ExtractFunction.java b/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/ExtractFunction.java index 76c0c27d825..90b7c8a783f 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/ExtractFunction.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/ExtractFunction.java @@ -12,12 +12,10 @@ import org.apache.calcite.linq4j.tree.Expression; import org.apache.calcite.linq4j.tree.Expressions; import org.apache.calcite.rex.RexCall; -import org.apache.calcite.sql.type.CompositeOperandTypeChecker; -import org.apache.calcite.sql.type.OperandTypes; import org.apache.calcite.sql.type.ReturnTypes; import org.apache.calcite.sql.type.SqlReturnTypeInference; -import org.apache.calcite.sql.type.SqlTypeFamily; import org.opensearch.sql.calcite.utils.OpenSearchTypeFactory; +import org.opensearch.sql.calcite.utils.PPLOperandTypes; import org.opensearch.sql.calcite.utils.UserDefinedFunctionUtils; import org.opensearch.sql.data.model.ExprStringValue; import org.opensearch.sql.data.model.ExprValue; @@ -51,10 +49,7 @@ public SqlReturnTypeInference getReturnTypeInference() { @Override public UDFOperandMetadata getOperandMetadata() { - return UDFOperandMetadata.wrap( - (CompositeOperandTypeChecker) - OperandTypes.family(SqlTypeFamily.STRING, SqlTypeFamily.DATETIME) - .or(OperandTypes.STRING_STRING)); + return PPLOperandTypes.STRING_DATETIME; } public static class ExtractImplementor implements NotNullImplementor { diff --git a/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/FormatFunction.java b/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/FormatFunction.java index a9892b1c80b..0a7f214b939 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/FormatFunction.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/FormatFunction.java @@ -17,11 +17,9 @@ import org.apache.calcite.linq4j.tree.Expression; import org.apache.calcite.linq4j.tree.Expressions; import org.apache.calcite.rex.RexCall; -import org.apache.calcite.sql.type.CompositeOperandTypeChecker; -import org.apache.calcite.sql.type.OperandTypes; import org.apache.calcite.sql.type.SqlReturnTypeInference; -import org.apache.calcite.sql.type.SqlTypeFamily; import org.opensearch.sql.calcite.utils.OpenSearchTypeFactory; +import org.opensearch.sql.calcite.utils.PPLOperandTypes; import org.opensearch.sql.calcite.utils.PPLReturnTypes; import org.opensearch.sql.calcite.utils.UserDefinedFunctionUtils; import org.opensearch.sql.data.model.ExprStringValue; @@ -60,10 +58,7 @@ public SqlReturnTypeInference getReturnTypeInference() { @Override public UDFOperandMetadata getOperandMetadata() { - return UDFOperandMetadata.wrap( - (CompositeOperandTypeChecker) - OperandTypes.family(SqlTypeFamily.DATETIME, SqlTypeFamily.STRING) - .or(OperandTypes.STRING_STRING)); + return PPLOperandTypes.DATETIME_OR_STRING_STRING; } @RequiredArgsConstructor diff --git a/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/FromUnixTimeFunction.java b/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/FromUnixTimeFunction.java index 09b59fd96e3..00fa9f690da 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/FromUnixTimeFunction.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/FromUnixTimeFunction.java @@ -18,10 +18,8 @@ import org.apache.calcite.linq4j.tree.Expression; import org.apache.calcite.linq4j.tree.Expressions; import org.apache.calcite.rex.RexCall; -import org.apache.calcite.sql.type.CompositeOperandTypeChecker; -import org.apache.calcite.sql.type.OperandTypes; import org.apache.calcite.sql.type.SqlReturnTypeInference; -import org.apache.calcite.sql.type.SqlTypeFamily; +import org.opensearch.sql.calcite.utils.PPLOperandTypes; import org.opensearch.sql.data.model.ExprDoubleValue; import org.opensearch.sql.data.model.ExprStringValue; import org.opensearch.sql.expression.function.ImplementorUDF; @@ -56,10 +54,7 @@ public SqlReturnTypeInference getReturnTypeInference() { @Override public UDFOperandMetadata getOperandMetadata() { - return UDFOperandMetadata.wrap( - (CompositeOperandTypeChecker) - OperandTypes.NUMERIC.or( - OperandTypes.family(SqlTypeFamily.NUMERIC, SqlTypeFamily.STRING))); + return PPLOperandTypes.NUMERIC_OPTIONAL_STRING; } public static class FromUnixTimeImplementor implements NotNullImplementor { diff --git a/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/LastDayFunction.java b/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/LastDayFunction.java index d98579f009a..c1eb8aa776d 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/LastDayFunction.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/LastDayFunction.java @@ -12,10 +12,9 @@ import org.apache.calcite.linq4j.tree.Expression; import org.apache.calcite.linq4j.tree.Expressions; import org.apache.calcite.rex.RexCall; -import org.apache.calcite.sql.type.CompositeOperandTypeChecker; -import org.apache.calcite.sql.type.OperandTypes; import org.apache.calcite.sql.type.SqlReturnTypeInference; import org.opensearch.sql.calcite.utils.OpenSearchTypeFactory; +import org.opensearch.sql.calcite.utils.PPLOperandTypes; import org.opensearch.sql.calcite.utils.PPLReturnTypes; import org.opensearch.sql.calcite.utils.UserDefinedFunctionUtils; import org.opensearch.sql.calcite.utils.datetime.DateTimeConversionUtils; @@ -48,8 +47,7 @@ public SqlReturnTypeInference getReturnTypeInference() { @Override public UDFOperandMetadata getOperandMetadata() { - return UDFOperandMetadata.wrap( - (CompositeOperandTypeChecker) OperandTypes.DATETIME.or(OperandTypes.STRING)); + return PPLOperandTypes.DATETIME_OR_STRING; } public static class LastDayImplementor implements NotNullImplementor { diff --git a/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/PeriodNameFunction.java b/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/PeriodNameFunction.java index 30cf660462b..109bad16bf1 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/PeriodNameFunction.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/PeriodNameFunction.java @@ -16,10 +16,9 @@ import org.apache.calcite.linq4j.tree.Expression; import org.apache.calcite.linq4j.tree.Expressions; import org.apache.calcite.rex.RexCall; -import org.apache.calcite.sql.type.CompositeOperandTypeChecker; -import org.apache.calcite.sql.type.OperandTypes; import org.apache.calcite.sql.type.SqlReturnTypeInference; import org.opensearch.sql.calcite.utils.OpenSearchTypeFactory; +import org.opensearch.sql.calcite.utils.PPLOperandTypes; import org.opensearch.sql.calcite.utils.PPLReturnTypes; import org.opensearch.sql.data.model.ExprDateValue; import org.opensearch.sql.data.type.ExprType; @@ -52,9 +51,7 @@ public SqlReturnTypeInference getReturnTypeInference() { @Override public UDFOperandMetadata getOperandMetadata() { - return UDFOperandMetadata.wrap( - (CompositeOperandTypeChecker) - OperandTypes.DATE.or(OperandTypes.TIMESTAMP).or(OperandTypes.STRING)); + return PPLOperandTypes.DATE_OR_TIMESTAMP_OR_STRING; } public static class PeriodNameFunctionImplementor implements NotNullImplementor { diff --git a/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/TimestampAddFunction.java b/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/TimestampAddFunction.java index 9119013ac21..29de7881fa6 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/TimestampAddFunction.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/TimestampAddFunction.java @@ -16,11 +16,9 @@ import org.apache.calcite.linq4j.tree.Expression; import org.apache.calcite.linq4j.tree.Expressions; import org.apache.calcite.rex.RexCall; -import org.apache.calcite.sql.type.CompositeOperandTypeChecker; -import org.apache.calcite.sql.type.OperandTypes; import org.apache.calcite.sql.type.SqlReturnTypeInference; -import org.apache.calcite.sql.type.SqlTypeFamily; import org.opensearch.sql.calcite.utils.OpenSearchTypeFactory; +import org.opensearch.sql.calcite.utils.PPLOperandTypes; import org.opensearch.sql.calcite.utils.PPLReturnTypes; import org.opensearch.sql.calcite.utils.UserDefinedFunctionUtils; import org.opensearch.sql.data.model.ExprLongValue; @@ -58,12 +56,7 @@ public SqlReturnTypeInference getReturnTypeInference() { @Override public UDFOperandMetadata getOperandMetadata() { - return UDFOperandMetadata.wrap( - (CompositeOperandTypeChecker) - OperandTypes.family(SqlTypeFamily.STRING, SqlTypeFamily.INTEGER, SqlTypeFamily.STRING) - .or( - OperandTypes.family( - SqlTypeFamily.STRING, SqlTypeFamily.INTEGER, SqlTypeFamily.DATETIME))); + return PPLOperandTypes.STRING_INTEGER_DATETIME_OR_STRING; } public static class TimestampAddImplementor implements NotNullImplementor { diff --git a/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/TimestampDiffFunction.java b/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/TimestampDiffFunction.java index 82746a31ba9..5adf7bf8bf7 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/TimestampDiffFunction.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/TimestampDiffFunction.java @@ -15,12 +15,10 @@ import org.apache.calcite.linq4j.tree.Expression; import org.apache.calcite.linq4j.tree.Expressions; import org.apache.calcite.rex.RexCall; -import org.apache.calcite.sql.type.CompositeOperandTypeChecker; -import org.apache.calcite.sql.type.OperandTypes; import org.apache.calcite.sql.type.ReturnTypes; import org.apache.calcite.sql.type.SqlReturnTypeInference; -import org.apache.calcite.sql.type.SqlTypeFamily; import org.opensearch.sql.calcite.utils.OpenSearchTypeFactory; +import org.opensearch.sql.calcite.utils.PPLOperandTypes; import org.opensearch.sql.calcite.utils.UserDefinedFunctionUtils; import org.opensearch.sql.data.model.ExprStringValue; import org.opensearch.sql.data.model.ExprValue; @@ -55,19 +53,7 @@ public SqlReturnTypeInference getReturnTypeInference() { @Override public UDFOperandMetadata getOperandMetadata() { - return UDFOperandMetadata.wrap( - (CompositeOperandTypeChecker) - OperandTypes.family( - SqlTypeFamily.STRING, SqlTypeFamily.DATETIME, SqlTypeFamily.DATETIME) - .or( - OperandTypes.family( - SqlTypeFamily.STRING, SqlTypeFamily.STRING, SqlTypeFamily.DATETIME)) - .or( - OperandTypes.family( - SqlTypeFamily.STRING, SqlTypeFamily.DATETIME, SqlTypeFamily.STRING)) - .or( - OperandTypes.family( - SqlTypeFamily.STRING, SqlTypeFamily.STRING, SqlTypeFamily.STRING))); + return PPLOperandTypes.INTERVAL_DATETIME_DATETIME; } public static class DiffImplementor implements NotNullImplementor { diff --git a/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/TimestampFunction.java b/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/TimestampFunction.java index 6041c4d5648..a13a1ce1894 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/TimestampFunction.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/TimestampFunction.java @@ -51,12 +51,12 @@ public SqlReturnTypeInference getReturnTypeInference() { public UDFOperandMetadata getOperandMetadata() { return UDFOperandMetadata.wrap( (CompositeOperandTypeChecker) - OperandTypes.STRING + OperandTypes.CHARACTER .or(OperandTypes.DATETIME) - .or(OperandTypes.STRING_STRING) + .or(OperandTypes.CHARACTER_CHARACTER) .or(OperandTypes.family(SqlTypeFamily.DATETIME, SqlTypeFamily.DATETIME)) - .or(OperandTypes.family(SqlTypeFamily.STRING, SqlTypeFamily.DATETIME)) - .or(OperandTypes.family(SqlTypeFamily.DATETIME, SqlTypeFamily.STRING))); + .or(OperandTypes.family(SqlTypeFamily.CHARACTER, SqlTypeFamily.DATETIME)) + .or(OperandTypes.family(SqlTypeFamily.DATETIME, SqlTypeFamily.CHARACTER))); } public static class TimestampImplementor implements NotNullImplementor { diff --git a/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/ToSecondsFunction.java b/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/ToSecondsFunction.java index 6639faa96f7..7a89d5e918b 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/ToSecondsFunction.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/ToSecondsFunction.java @@ -18,10 +18,9 @@ import org.apache.calcite.linq4j.tree.Expression; import org.apache.calcite.linq4j.tree.Expressions; import org.apache.calcite.rex.RexCall; -import org.apache.calcite.sql.type.CompositeOperandTypeChecker; -import org.apache.calcite.sql.type.OperandTypes; import org.apache.calcite.sql.type.ReturnTypes; import org.apache.calcite.sql.type.SqlReturnTypeInference; +import org.opensearch.sql.calcite.utils.PPLOperandTypes; import org.opensearch.sql.data.model.ExprLongValue; import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.data.type.ExprCoreType; @@ -52,9 +51,7 @@ public SqlReturnTypeInference getReturnTypeInference() { @Override public UDFOperandMetadata getOperandMetadata() { - return UDFOperandMetadata.wrap( - (CompositeOperandTypeChecker) - OperandTypes.DATETIME.or(OperandTypes.STRING).or(OperandTypes.INTEGER)); + return PPLOperandTypes.DATETIME_OR_STRING_OR_INTEGER; } public static class ToSecondsImplementor implements NotNullImplementor { diff --git a/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/UnixTimestampFunction.java b/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/UnixTimestampFunction.java index 509fee1cdca..5103173b46f 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/UnixTimestampFunction.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/UnixTimestampFunction.java @@ -15,10 +15,9 @@ import org.apache.calcite.linq4j.tree.Expression; import org.apache.calcite.linq4j.tree.Expressions; import org.apache.calcite.rex.RexCall; -import org.apache.calcite.sql.type.CompositeOperandTypeChecker; -import org.apache.calcite.sql.type.OperandTypes; import org.apache.calcite.sql.type.ReturnTypes; import org.apache.calcite.sql.type.SqlReturnTypeInference; +import org.opensearch.sql.calcite.utils.PPLOperandTypes; import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.expression.function.FunctionProperties; import org.opensearch.sql.expression.function.ImplementorUDF; @@ -50,9 +49,7 @@ public SqlReturnTypeInference getReturnTypeInference() { @Override public UDFOperandMetadata getOperandMetadata() { - return UDFOperandMetadata.wrap( - (CompositeOperandTypeChecker) - OperandTypes.DATETIME.or(OperandTypes.NUMERIC).or(OperandTypes.family())); + return PPLOperandTypes.OPTIONAL_DATE_OR_TIMESTAMP_OR_NUMERIC; } public static class UnixTimestampImplementor implements NotNullImplementor { diff --git a/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/WeekFunction.java b/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/WeekFunction.java index b4b1eecbcbc..11564ce47f0 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/WeekFunction.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/WeekFunction.java @@ -14,10 +14,8 @@ import org.apache.calcite.linq4j.tree.Expression; import org.apache.calcite.linq4j.tree.Expressions; import org.apache.calcite.rex.RexCall; -import org.apache.calcite.sql.type.CompositeOperandTypeChecker; -import org.apache.calcite.sql.type.OperandTypes; import org.apache.calcite.sql.type.SqlReturnTypeInference; -import org.apache.calcite.sql.type.SqlTypeFamily; +import org.opensearch.sql.calcite.utils.PPLOperandTypes; import org.opensearch.sql.data.model.ExprDateValue; import org.opensearch.sql.data.model.ExprIntegerValue; import org.opensearch.sql.data.model.ExprValue; @@ -52,12 +50,7 @@ public SqlReturnTypeInference getReturnTypeInference() { @Override public UDFOperandMetadata getOperandMetadata() { - return UDFOperandMetadata.wrap( - (CompositeOperandTypeChecker) - OperandTypes.DATETIME - .or(OperandTypes.STRING) - .or(OperandTypes.STRING_INTEGER) - .or(OperandTypes.family(SqlTypeFamily.DATETIME, SqlTypeFamily.INTEGER))); + return PPLOperandTypes.DATETIME_OPTIONAL_INTEGER; } public static class WeekImplementor implements NotNullImplementor { diff --git a/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/WeekdayFunction.java b/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/WeekdayFunction.java index 2ba25530f8d..472203f3079 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/WeekdayFunction.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/WeekdayFunction.java @@ -16,11 +16,9 @@ import org.apache.calcite.linq4j.tree.Expression; import org.apache.calcite.linq4j.tree.Expressions; import org.apache.calcite.rex.RexCall; -import org.apache.calcite.sql.type.CompositeOperandTypeChecker; -import org.apache.calcite.sql.type.OperandTypes; import org.apache.calcite.sql.type.SqlReturnTypeInference; -import org.apache.calcite.sql.type.SqlTypeFamily; import org.opensearch.sql.calcite.utils.OpenSearchTypeFactory; +import org.opensearch.sql.calcite.utils.PPLOperandTypes; import org.opensearch.sql.calcite.utils.PPLReturnTypes; import org.opensearch.sql.calcite.utils.UserDefinedFunctionUtils; import org.opensearch.sql.data.model.ExprValue; @@ -53,11 +51,7 @@ public SqlReturnTypeInference getReturnTypeInference() { @Override public UDFOperandMetadata getOperandMetadata() { - return UDFOperandMetadata.wrap( - (CompositeOperandTypeChecker) - OperandTypes.DATETIME - .or(OperandTypes.STRING) - .or(OperandTypes.family(SqlTypeFamily.DATETIME, SqlTypeFamily.INTEGER))); + return PPLOperandTypes.DATETIME_OPTIONAL_INTEGER; } public static class WeekdayImplementor implements NotNullImplementor { diff --git a/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/YearweekFunction.java b/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/YearweekFunction.java index 0dd9f5e84bd..9bfe30df39c 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/YearweekFunction.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/YearweekFunction.java @@ -16,10 +16,8 @@ import org.apache.calcite.linq4j.tree.Expression; import org.apache.calcite.linq4j.tree.Expressions; import org.apache.calcite.rex.RexCall; -import org.apache.calcite.sql.type.CompositeOperandTypeChecker; -import org.apache.calcite.sql.type.OperandTypes; import org.apache.calcite.sql.type.SqlReturnTypeInference; -import org.apache.calcite.sql.type.SqlTypeFamily; +import org.opensearch.sql.calcite.utils.PPLOperandTypes; import org.opensearch.sql.calcite.utils.PPLReturnTypes; import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.data.model.ExprValueUtils; @@ -52,12 +50,7 @@ public SqlReturnTypeInference getReturnTypeInference() { @Override public UDFOperandMetadata getOperandMetadata() { - return UDFOperandMetadata.wrap( - (CompositeOperandTypeChecker) - OperandTypes.DATETIME - .or(OperandTypes.STRING) - .or(OperandTypes.STRING_INTEGER) - .or(OperandTypes.family(SqlTypeFamily.DATETIME, SqlTypeFamily.INTEGER))); + return PPLOperandTypes.DATETIME_OPTIONAL_INTEGER; } public static class YearweekImplementor implements NotNullImplementor { diff --git a/core/src/main/java/org/opensearch/sql/expression/function/udf/ip/CompareIpFunction.java b/core/src/main/java/org/opensearch/sql/expression/function/udf/ip/CompareIpFunction.java index edb898f3e0f..a580a5f18f4 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/udf/ip/CompareIpFunction.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/udf/ip/CompareIpFunction.java @@ -9,6 +9,7 @@ import org.apache.calcite.adapter.enumerable.NotNullImplementor; import org.apache.calcite.adapter.enumerable.NullPolicy; import org.apache.calcite.adapter.enumerable.RexToLixTranslator; +import org.apache.calcite.linq4j.tree.ConstantExpression; import org.apache.calcite.linq4j.tree.Expression; import org.apache.calcite.linq4j.tree.Expressions; import org.apache.calcite.rex.RexCall; @@ -67,11 +68,7 @@ public SqlReturnTypeInference getReturnTypeInference() { @Override public UDFOperandMetadata getOperandMetadata() { - return UDFOperandMetadata.wrapUDT( - List.of( - List.of(ExprCoreType.IP, ExprCoreType.IP), - List.of(ExprCoreType.IP, ExprCoreType.STRING), - List.of(ExprCoreType.STRING, ExprCoreType.IP))); + return UDFOperandMetadata.wrapUDT(List.of(List.of(ExprCoreType.IP, ExprCoreType.IP))); } public static class CompareImplementor implements NotNullImplementor { @@ -96,14 +93,14 @@ public Expression implement( private static Expression generateComparisonExpression( Expression compareResult, ComparisonType comparisonType) { + final ConstantExpression zero = Expressions.constant(0); return switch (comparisonType) { - case EQUALS -> Expressions.equal(compareResult, Expressions.constant(0)); - case NOT_EQUALS -> Expressions.notEqual(compareResult, Expressions.constant(0)); - case LESS -> Expressions.lessThan(compareResult, Expressions.constant(0)); - case LESS_OR_EQUAL -> Expressions.lessThanOrEqual(compareResult, Expressions.constant(0)); - case GREATER -> Expressions.greaterThan(compareResult, Expressions.constant(0)); - case GREATER_OR_EQUAL -> Expressions.greaterThanOrEqual( - compareResult, Expressions.constant(0)); + case EQUALS -> Expressions.equal(compareResult, zero); + case NOT_EQUALS -> Expressions.notEqual(compareResult, zero); + case LESS -> Expressions.lessThan(compareResult, zero); + case LESS_OR_EQUAL -> Expressions.lessThanOrEqual(compareResult, zero); + case GREATER -> Expressions.greaterThan(compareResult, zero); + case GREATER_OR_EQUAL -> Expressions.greaterThanOrEqual(compareResult, zero); }; } diff --git a/core/src/main/java/org/opensearch/sql/expression/function/udf/math/ConvFunction.java b/core/src/main/java/org/opensearch/sql/expression/function/udf/math/ConvFunction.java index 5681f8f19f9..e29c17dba49 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/udf/math/ConvFunction.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/udf/math/ConvFunction.java @@ -14,10 +14,9 @@ import org.apache.calcite.linq4j.tree.Expressions; import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rex.RexCall; -import org.apache.calcite.sql.type.CompositeOperandTypeChecker; -import org.apache.calcite.sql.type.OperandTypes; import org.apache.calcite.sql.type.SqlReturnTypeInference; import org.apache.calcite.sql.type.SqlTypeFamily; +import org.opensearch.sql.calcite.utils.PPLOperandTypes; import org.opensearch.sql.calcite.utils.PPLReturnTypes; import org.opensearch.sql.expression.function.ImplementorUDF; import org.opensearch.sql.expression.function.UDFOperandMetadata; @@ -40,11 +39,7 @@ public SqlReturnTypeInference getReturnTypeInference() { @Override public UDFOperandMetadata getOperandMetadata() { - return UDFOperandMetadata.wrap( - (CompositeOperandTypeChecker) - OperandTypes.STRING_INTEGER_INTEGER.or( - OperandTypes.family( - SqlTypeFamily.INTEGER, SqlTypeFamily.INTEGER, SqlTypeFamily.INTEGER))); + return PPLOperandTypes.STRING_OR_INTEGER_INTEGER_INTEGER; } public static class ConvImplementor implements NotNullImplementor { diff --git a/core/src/main/java/org/opensearch/sql/utils/DateTimeUtils.java b/core/src/main/java/org/opensearch/sql/utils/DateTimeUtils.java index 04d8c791529..34c7a198e8c 100644 --- a/core/src/main/java/org/opensearch/sql/utils/DateTimeUtils.java +++ b/core/src/main/java/org/opensearch/sql/utils/DateTimeUtils.java @@ -5,10 +5,6 @@ package org.opensearch.sql.utils; -import static org.opensearch.sql.calcite.utils.OpenSearchTypeFactory.ExprUDT.EXPR_DATE; -import static org.opensearch.sql.calcite.utils.OpenSearchTypeFactory.ExprUDT.EXPR_TIME; -import static org.opensearch.sql.calcite.utils.OpenSearchTypeFactory.ExprUDT.EXPR_TIMESTAMP; - import java.time.Instant; import java.time.LocalDate; import java.time.LocalDateTime; @@ -19,18 +15,11 @@ import java.time.format.DateTimeParseException; import java.time.temporal.ChronoUnit; import java.util.Locale; -import java.util.Objects; import java.util.regex.Pattern; import lombok.experimental.UtilityClass; -import org.apache.calcite.rex.RexNode; -import org.apache.calcite.sql.type.SqlTypeName; -import org.opensearch.sql.calcite.CalcitePlanContext; -import org.opensearch.sql.calcite.type.ExprSqlType; -import org.opensearch.sql.calcite.utils.OpenSearchTypeFactory; import org.opensearch.sql.data.model.ExprTimeValue; import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.expression.function.FunctionProperties; -import org.opensearch.sql.expression.function.PPLBuiltinOperators; @UtilityClass public class DateTimeUtils { @@ -305,78 +294,4 @@ private static String normalizeUnit(String rawUnit) { } } } - - /** - * The function add cast for date-related target node - * - * @param candidate The candidate node - * @param context calcite context - * @param castTarget the target cast type - * @return the rexnode after casting - */ - public static RexNode transferCompareForDateRelated( - RexNode candidate, CalcitePlanContext context, SqlTypeName castTarget) { - if (!(Objects.isNull(castTarget))) { - switch (castTarget) { - case DATE: - if (!(candidate.getType() instanceof ExprSqlType - && ((ExprSqlType) candidate.getType()).getUdt() == EXPR_DATE)) { - return context.rexBuilder.makeCall(PPLBuiltinOperators.DATE, candidate); - } - break; - case TIME: - if (!(candidate.getType() instanceof ExprSqlType - && ((ExprSqlType) candidate.getType()).getUdt() == EXPR_TIME)) { - return context.rexBuilder.makeCall(PPLBuiltinOperators.TIME, candidate); - } - break; - case TIMESTAMP: - if (!(candidate.getType() instanceof ExprSqlType - && ((ExprSqlType) candidate.getType()).getUdt() == EXPR_TIMESTAMP)) { - return context.rexBuilder.makeCall(PPLBuiltinOperators.TIMESTAMP, candidate); - } - break; - default: - return candidate; - } - } - return candidate; - } - - /** - * The function find the target cast type according to the left and right node. When the two node - * are both related to date with different type, cast to timestamp - * - * @param left - * @param right - * @return - */ - public static SqlTypeName findCastType(RexNode left, RexNode right) { - SqlTypeName leftType = returnCorrespondingSqlType(left); - SqlTypeName rightType = returnCorrespondingSqlType(right); - if (leftType != null && rightType != null && rightType != leftType) { - return SqlTypeName.TIMESTAMP; - } - return leftType == null ? rightType : leftType; - } - - /** - * Find corresponding cast type according to the node's type. If they're not related to the date, - * return null - * - * @param node the candidate node - * @return the sql type name - */ - public static SqlTypeName returnCorrespondingSqlType(RexNode node) { - if (node.getType() instanceof ExprSqlType) { - OpenSearchTypeFactory.ExprUDT udt = ((ExprSqlType) node.getType()).getUdt(); - return switch (udt) { - case EXPR_DATE -> SqlTypeName.DATE; - case EXPR_TIME -> SqlTypeName.TIME; - case EXPR_TIMESTAMP -> SqlTypeName.TIMESTAMP; - default -> null; - }; - } - return null; - } } diff --git a/integ-test/src/test/java/org/opensearch/sql/ppl/ExplainIT.java b/integ-test/src/test/java/org/opensearch/sql/ppl/ExplainIT.java index 764a95417e0..4467ab06e0e 100644 --- a/integ-test/src/test/java/org/opensearch/sql/ppl/ExplainIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/ppl/ExplainIT.java @@ -6,10 +6,13 @@ package org.opensearch.sql.ppl; import static org.hamcrest.Matchers.containsString; +import static org.opensearch.sql.legacy.TestsConstants.TEST_INDEX_ACCOUNT; import static org.opensearch.sql.legacy.TestsConstants.TEST_INDEX_BANK; +import static org.opensearch.sql.legacy.TestsConstants.TEST_INDEX_WEBLOGS; import static org.opensearch.sql.util.MatcherUtils.assertJsonEqualsIgnoreId; import java.io.IOException; +import java.util.Locale; import org.junit.Ignore; import org.junit.jupiter.api.Test; import org.opensearch.client.ResponseException; @@ -23,6 +26,7 @@ public void init() throws Exception { loadIndex(Index.ACCOUNT); loadIndex(Index.BANK); loadIndex(Index.DATE_FORMATS); + loadIndex(Index.WEBLOG); } @Test @@ -87,6 +91,34 @@ public void testFilterByCompareStringTimePushDownExplain() throws IOException { + "| where custom_time < '2018-11-09 19:00:00.123456789' ")); } + @Test + public void testFilterByCompareIPCoercion() throws IOException { + // Should automatically cast the string literal to IP. + // TODO: Push down IP comparison as range query with Calcite + String expected = loadExpectedPlan("explain_filter_compare_ip.json"); + assertJsonEqualsIgnoreId( + expected, + explainQueryToString( + String.format( + Locale.ROOT, + "source=%s | where host > '1.1.1.1' | fields host", + TEST_INDEX_WEBLOGS))); + } + + @Test + public void testWeekArgumentCoercion() throws IOException { + String expected = loadExpectedPlan("explain_week_argument_coercion.json"); + // Week accepts WEEK(timestamp/date/time, [optional int]), it should cast the string + // argument to timestamp with Calcite. In v2, it accepts string, so there is no cast. + assertJsonEqualsIgnoreId( + expected, + explainQueryToString( + String.format( + Locale.ROOT, + "source=%s | eval w = week('2024-12-10') | fields w", + TEST_INDEX_ACCOUNT))); + } + @Test public void testFilterAndAggPushDownExplain() throws IOException { String expected = loadExpectedPlan("explain_filter_agg_push.json"); diff --git a/integ-test/src/test/resources/expectedOutput/calcite/explain_filter_compare_ip.json b/integ-test/src/test/resources/expectedOutput/calcite/explain_filter_compare_ip.json new file mode 100644 index 00000000000..ae49a08b653 --- /dev/null +++ b/integ-test/src/test/resources/expectedOutput/calcite/explain_filter_compare_ip.json @@ -0,0 +1,6 @@ +{ + "calcite": { + "logical": "LogicalProject(host=[$0])\n LogicalFilter(condition=[GREATER_IP($0, IP('1.1.1.1':VARCHAR))])\n CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_weblogs]])\n", + "physical": "CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_weblogs]], PushDownContext=[[PROJECT->[host], SCRIPT->GREATER_IP($0, IP('1.1.1.1':VARCHAR))], OpenSearchRequestBuilder(sourceBuilder={\"from\":0,\"timeout\":\"1m\",\"query\":{\"script\":{\"script\":{\"source\":\"{\\\"langType\\\":\\\"calcite\\\",\\\"script\\\":\\\"rO0ABXNyABFqYXZhLnV0aWwuQ29sbFNlcleOq7Y6G6gRAwABSQADdGFneHAAAAADdwQAAAAGdAAHcm93VHlwZXQAensKICAiZmllbGRzIjogWwogICAgewogICAgICAidHlwZSI6ICJPVEhFUiIsCiAgICAgICJudWxsYWJsZSI6IHRydWUsCiAgICAgICJuYW1lIjogImhvc3QiCiAgICB9CiAgXSwKICAibnVsbGFibGUiOiBmYWxzZQp9dAAEZXhwcnQDfXsKICAib3AiOiB7CiAgICAibmFtZSI6ICJHUkVBVEVSX0lQIiwKICAgICJraW5kIjogIk9USEVSX0ZVTkNUSU9OIiwKICAgICJzeW50YXgiOiAiRlVOQ1RJT04iCiAgfSwKICAib3BlcmFuZHMiOiBbCiAgICB7CiAgICAgICJpbnB1dCI6IDAsCiAgICAgICJuYW1lIjogIiQwIgogICAgfSwKICAgIHsKICAgICAgIm9wIjogewogICAgICAgICJuYW1lIjogIklQIiwKICAgICAgICAia2luZCI6ICJPVEhFUl9GVU5DVElPTiIsCiAgICAgICAgInN5bnRheCI6ICJGVU5DVElPTiIKICAgICAgfSwKICAgICAgIm9wZXJhbmRzIjogWwogICAgICAgIHsKICAgICAgICAgICJsaXRlcmFsIjogIjEuMS4xLjEiLAogICAgICAgICAgInR5cGUiOiB7CiAgICAgICAgICAgICJ0eXBlIjogIlZBUkNIQVIiLAogICAgICAgICAgICAibnVsbGFibGUiOiBmYWxzZSwKICAgICAgICAgICAgInByZWNpc2lvbiI6IC0xCiAgICAgICAgICB9CiAgICAgICAgfQogICAgICBdLAogICAgICAiY2xhc3MiOiAib3JnLm9wZW5zZWFyY2guc3FsLmV4cHJlc3Npb24uZnVuY3Rpb24uVXNlckRlZmluZWRGdW5jdGlvbkJ1aWxkZXIkMSIsCiAgICAgICJ0eXBlIjogewogICAgICAgICJ0eXBlIjogIk9USEVSIiwKICAgICAgICAibnVsbGFibGUiOiB0cnVlCiAgICAgIH0sCiAgICAgICJkZXRlcm1pbmlzdGljIjogdHJ1ZSwKICAgICAgImR5bmFtaWMiOiBmYWxzZQogICAgfQogIF0sCiAgImNsYXNzIjogIm9yZy5vcGVuc2VhcmNoLnNxbC5leHByZXNzaW9uLmZ1bmN0aW9uLlVzZXJEZWZpbmVkRnVuY3Rpb25CdWlsZGVyJDEiLAogICJ0eXBlIjogewogICAgInR5cGUiOiAiQk9PTEVBTiIsCiAgICAibnVsbGFibGUiOiB0cnVlCiAgfSwKICAiZGV0ZXJtaW5pc3RpYyI6IHRydWUsCiAgImR5bmFtaWMiOiBmYWxzZQp9dAAKZmllbGRUeXBlc3NyABFqYXZhLnV0aWwuSGFzaE1hcAUH2sHDFmDRAwACRgAKbG9hZEZhY3RvckkACXRocmVzaG9sZHhwP0AAAAAAAAx3CAAAABAAAAABdAAEaG9zdH5yAClvcmcub3BlbnNlYXJjaC5zcWwuZGF0YS50eXBlLkV4cHJDb3JlVHlwZQAAAAAAAAAAEgAAeHIADmphdmEubGFuZy5FbnVtAAAAAAAAAAASAAB4cHQAAklQeHg=\\\"}\",\"lang\":\"opensearch_compounded_script\",\"params\":{\"utcTimestamp\":1753756416891521000}},\"boost\":1.0}},\"_source\":{\"includes\":[\"host\"],\"excludes\":[]},\"sort\":[{\"_doc\":{\"order\":\"asc\"}}]}, requestedTotalSize=2147483647, pageSize=null, startFrom=0)])\n" + } +} diff --git a/integ-test/src/test/resources/expectedOutput/calcite/explain_week_argument_coercion.json b/integ-test/src/test/resources/expectedOutput/calcite/explain_week_argument_coercion.json new file mode 100644 index 00000000000..c65b74d161e --- /dev/null +++ b/integ-test/src/test/resources/expectedOutput/calcite/explain_week_argument_coercion.json @@ -0,0 +1,6 @@ +{ + "calcite": { + "logical": "LogicalProject(w=[WEEK(TIMESTAMP('2024-12-10':VARCHAR))])\n CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]])\n", + "physical": "EnumerableCalc(expr#0..16=[{inputs}], expr#17=['2024-12-10':VARCHAR], expr#18=[TIMESTAMP($t17)], expr#19=[WEEK($t18)], w=[$t19])\n CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]])\n" + } +} diff --git a/integ-test/src/test/resources/expectedOutput/calcite_no_pushdown/explain_filter_compare_ip.json b/integ-test/src/test/resources/expectedOutput/calcite_no_pushdown/explain_filter_compare_ip.json new file mode 100644 index 00000000000..4d275fcd4ab --- /dev/null +++ b/integ-test/src/test/resources/expectedOutput/calcite_no_pushdown/explain_filter_compare_ip.json @@ -0,0 +1,6 @@ +{ + "calcite": { + "logical": "LogicalProject(host=[$0])\n LogicalFilter(condition=[GREATER_IP($0, IP('1.1.1.1':VARCHAR))])\n CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_weblogs]])\n", + "physical": "EnumerableCalc(expr#0..11=[{inputs}], expr#12=['1.1.1.1':VARCHAR], expr#13=[IP($t12)], expr#14=[GREATER_IP($t0, $t13)], host=[$t0], $condition=[$t14])\n CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_weblogs]])\n" + } +} \ No newline at end of file diff --git a/integ-test/src/test/resources/expectedOutput/calcite_no_pushdown/explain_week_argument_coercion.json b/integ-test/src/test/resources/expectedOutput/calcite_no_pushdown/explain_week_argument_coercion.json new file mode 100644 index 00000000000..c65b74d161e --- /dev/null +++ b/integ-test/src/test/resources/expectedOutput/calcite_no_pushdown/explain_week_argument_coercion.json @@ -0,0 +1,6 @@ +{ + "calcite": { + "logical": "LogicalProject(w=[WEEK(TIMESTAMP('2024-12-10':VARCHAR))])\n CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]])\n", + "physical": "EnumerableCalc(expr#0..16=[{inputs}], expr#17=['2024-12-10':VARCHAR], expr#18=[TIMESTAMP($t17)], expr#19=[WEEK($t18)], w=[$t19])\n CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]])\n" + } +} diff --git a/integ-test/src/test/resources/expectedOutput/ppl/explain_filter_compare_ip.json b/integ-test/src/test/resources/expectedOutput/ppl/explain_filter_compare_ip.json new file mode 100644 index 00000000000..8f6706d8fb4 --- /dev/null +++ b/integ-test/src/test/resources/expectedOutput/ppl/explain_filter_compare_ip.json @@ -0,0 +1,17 @@ +{ + "root": { + "name": "ProjectOperator", + "description": { + "fields": "[host]" + }, + "children": [ + { + "name": "OpenSearchIndexScan", + "description": { + "request": "OpenSearchQueryRequest(indexName=opensearch-sql_test_index_weblogs, sourceBuilder={\"from\":0,\"size\":10000,\"timeout\":\"1m\",\"query\":{\"range\":{\"host\":{\"from\":\"1.1.1.1\",\"to\":null,\"include_lower\":false,\"include_upper\":true,\"boost\":1.0}}},\"_source\":{\"includes\":[\"host\"],\"excludes\":[]},\"sort\":[{\"_doc\":{\"order\":\"asc\"}}]}, needClean=true, searchDone=false, pitId=*, cursorKeepAlive=1m, searchAfter=null, searchResponse=null)" + }, + "children": [] + } + ] + } +} \ No newline at end of file diff --git a/integ-test/src/test/resources/expectedOutput/ppl/explain_week_argument_coercion.json b/integ-test/src/test/resources/expectedOutput/ppl/explain_week_argument_coercion.json new file mode 100644 index 00000000000..bc88f5eedfa --- /dev/null +++ b/integ-test/src/test/resources/expectedOutput/ppl/explain_week_argument_coercion.json @@ -0,0 +1,27 @@ +{ + "root": { + "name": "ProjectOperator", + "description": { + "fields": "[w]" + }, + "children": [ + { + "name": "OpenSearchEvalOperator", + "description": { + "expressions": { + "w": "week(\"2024-12-10\")" + } + }, + "children": [ + { + "name": "OpenSearchIndexScan", + "description": { + "request": "OpenSearchQueryRequest(indexName=opensearch-sql_test_index_account, sourceBuilder={\"from\":0,\"size\":10000,\"timeout\":\"1m\"}, needClean=true, searchDone=false, pitId=*, cursorKeepAlive=1m, searchAfter=null, searchResponse=null)" + }, + "children": [] + } + ] + } + ] + } +} \ No newline at end of file diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/functions/GeoIpFunction.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/functions/GeoIpFunction.java index ae7eb1a3a4e..9b3c4b0a1f5 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/functions/GeoIpFunction.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/functions/GeoIpFunction.java @@ -19,6 +19,7 @@ import org.apache.calcite.sql.type.CompositeOperandTypeChecker; import org.apache.calcite.sql.type.OperandTypes; import org.apache.calcite.sql.type.SqlReturnTypeInference; +import org.apache.calcite.sql.type.SqlTypeFamily; import org.apache.calcite.sql.type.SqlTypeName; import org.opensearch.geospatial.action.IpEnrichmentActionClient; import org.opensearch.sql.common.utils.StringUtils; @@ -60,7 +61,9 @@ public SqlReturnTypeInference getReturnTypeInference() { public UDFOperandMetadata getOperandMetadata() { return UDFOperandMetadata.wrap( (CompositeOperandTypeChecker) - OperandTypes.STRING_STRING.or(OperandTypes.STRING_STRING_STRING)); + OperandTypes.CHARACTER_CHARACTER.or( + OperandTypes.family( + SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER))); } public static class GeoIPImplementor implements NotNullImplementor { diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLFunctionTypeTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLFunctionTypeTest.java index b52de9448ed..977805f9d1a 100644 --- a/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLFunctionTypeTest.java +++ b/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLFunctionTypeTest.java @@ -35,11 +35,7 @@ public void testTimeDiffWithUdtInputType() { getRelNode(timePpl); Throwable t = Assert.assertThrows(Exception.class, () -> getRelNode(wrongPpl)); verifyErrorMessageContains( - t, - "TIMEDIFF function expects" - + " {[STRING,STRING],[DATE,DATE],[DATE,TIME],[DATE,TIMESTAMP],[TIME,DATE],[TIME,TIME],[TIME,TIMESTAMP]," - + "[TIMESTAMP,DATE],[TIMESTAMP,TIME],[TIMESTAMP,TIMESTAMP],[DATE,STRING],[TIME,STRING],[TIMESTAMP,STRING],[STRING,DATE],[STRING,TIME],[STRING,TIMESTAMP]}," - + " but got [INTEGER,STRING]"); + t, "TIMEDIFF function expects {[TIME,TIME]}, but got [INTEGER,STRING]"); } @Test @@ -51,8 +47,8 @@ public void testComparisonWithDifferentType() { verifyErrorMessageContains( t, // Temporary fix for the error message as LESS function has two variants. Will remove - // [IP,IP],[IP,STRING],[STRING,IP] when merging the two variants. - "LESS function expects {[IP,IP],[IP,STRING],[STRING,IP],[COMPARABLE_TYPE,COMPARABLE_TYPE]}," + // [IP,IP] when merging the two variants. + "LESS function expects {[IP,IP],[COMPARABLE_TYPE,COMPARABLE_TYPE]}," + " but got [STRING,INTEGER]"); } @@ -107,10 +103,10 @@ public void testTimestampWithWrongArg() { Throwable t = Assert.assertThrows(ExpressionEvaluationException.class, () -> getRelNode(ppl)); verifyErrorMessageContains( t, - "TIMESTAMP function expects" - + " {[STRING],[DATE],[TIME],[TIMESTAMP],[STRING,STRING],[DATE,DATE],[DATE,TIME],[DATE,TIMESTAMP]," - + "[TIME,DATE],[TIME,TIME],[TIME,TIMESTAMP],[TIMESTAMP,DATE],[TIMESTAMP,TIME],[TIMESTAMP,TIMESTAMP]" - + ",[STRING,DATE],[STRING,TIME],[STRING,TIMESTAMP],[DATE,STRING],[TIME,STRING],[TIMESTAMP,STRING]}," + "TIMESTAMP function expects {" + + "[STRING],[TIMESTAMP],[DATE],[TIME],[STRING,STRING],[TIMESTAMP,TIMESTAMP],[TIMESTAMP,DATE]," + + "[TIMESTAMP,TIME],[DATE,TIMESTAMP],[DATE,DATE],[DATE,TIME],[TIME,TIMESTAMP],[TIME,DATE]," + + "[TIME,TIME],[STRING,TIMESTAMP],[STRING,DATE],[STRING,TIME],[TIMESTAMP,STRING],[DATE,STRING],[TIME,STRING]}," + " but got [STRING,INTEGER]"); }