diff --git a/core/src/main/java/org/opensearch/sql/calcite/CalciteRexNodeVisitor.java b/core/src/main/java/org/opensearch/sql/calcite/CalciteRexNodeVisitor.java index b40480e8049..ad7547a2d54 100644 --- a/core/src/main/java/org/opensearch/sql/calcite/CalciteRexNodeVisitor.java +++ b/core/src/main/java/org/opensearch/sql/calcite/CalciteRexNodeVisitor.java @@ -11,8 +11,6 @@ import static org.opensearch.sql.ast.expression.SpanUnit.NONE; import static org.opensearch.sql.ast.expression.SpanUnit.UNKNOWN; import static org.opensearch.sql.calcite.utils.OpenSearchTypeFactory.TYPE_FACTORY; -import static org.opensearch.sql.utils.DateTimeUtils.findCastType; -import static org.opensearch.sql.utils.DateTimeUtils.transferCompareForDateRelated; import java.math.BigDecimal; import java.util.ArrayList; @@ -23,7 +21,6 @@ import java.util.Locale; import java.util.Map; -import java.util.stream.Collectors; import java.util.stream.IntStream; import javax.annotation.Nullable; import lombok.RequiredArgsConstructor; @@ -32,7 +29,6 @@ import org.apache.calcite.rel.type.RelDataTypeFactory; import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.rex.RexCall; -import org.apache.calcite.rex.RexLambda; import org.apache.calcite.rex.RexLambdaRef; import org.apache.calcite.rex.RexNode; import org.apache.calcite.sql.SqlIntervalQualifier; @@ -217,11 +213,8 @@ public RexNode visitIn(In node, CalcitePlanContext context) { @Override public RexNode visitCompare(Compare node, CalcitePlanContext context) { - RexNode leftCandidate = analyze(node.getLeft(), context); - RexNode rightCandidate = analyze(node.getRight(), context); - SqlTypeName castTarget = findCastType(leftCandidate, rightCandidate); - final RexNode left = transferCompareForDateRelated(leftCandidate, context, castTarget); - final RexNode right = transferCompareForDateRelated(rightCandidate, context, castTarget); + RexNode left = analyze(node.getLeft(), context); + RexNode right = analyze(node.getRight(), context); return PPLFuncImpTable.INSTANCE.resolve(context.rexBuilder, node.getOperator(), left, right); } @@ -470,19 +463,6 @@ private List modifyLambdaTypeByFunction( } } - private List castArgument( - List originalArguments, String functionName, ExtendedRexBuilder rexBuilder) { - switch (functionName.toUpperCase(Locale.ROOT)) { - case "REDUCE": - RexLambda call = (RexLambda) originalArguments.get(2); - originalArguments.set( - 1, rexBuilder.makeCast(call.getType(), originalArguments.get(1), true, true)); - return originalArguments; - default: - return originalArguments; - } - } - @Override public RexNode visitFunction(Function node, CalcitePlanContext context) { List args = node.getFuncArgs(); @@ -509,7 +489,6 @@ public RexNode visitFunction(Function node, CalcitePlanContext context) { } } - arguments = castArgument(arguments, node.getFuncName(), context.rexBuilder); RexNode resolvedNode = PPLFuncImpTable.INSTANCE.resolve( context.rexBuilder, node.getFuncName(), arguments.toArray(new RexNode[0])); diff --git a/core/src/main/java/org/opensearch/sql/calcite/utils/PPLOperandTypes.java b/core/src/main/java/org/opensearch/sql/calcite/utils/PPLOperandTypes.java index 9f818be10ec..8ae08d2c7fd 100644 --- a/core/src/main/java/org/opensearch/sql/calcite/utils/PPLOperandTypes.java +++ b/core/src/main/java/org/opensearch/sql/calcite/utils/PPLOperandTypes.java @@ -27,40 +27,121 @@ private PPLOperandTypes() {} UDFOperandMetadata.wrap( (CompositeOperandTypeChecker) OperandTypes.INTEGER.or(OperandTypes.family())); public static final UDFOperandMetadata STRING = - UDFOperandMetadata.wrap((FamilyOperandTypeChecker) OperandTypes.STRING); + UDFOperandMetadata.wrap((FamilyOperandTypeChecker) OperandTypes.CHARACTER); public static final UDFOperandMetadata INTEGER = UDFOperandMetadata.wrap((FamilyOperandTypeChecker) OperandTypes.INTEGER); public static final UDFOperandMetadata NUMERIC = UDFOperandMetadata.wrap((FamilyOperandTypeChecker) OperandTypes.NUMERIC); + + public static final UDFOperandMetadata NUMERIC_OPTIONAL_STRING = + UDFOperandMetadata.wrap( + (CompositeOperandTypeChecker) + OperandTypes.NUMERIC.or( + OperandTypes.family(SqlTypeFamily.NUMERIC, SqlTypeFamily.CHARACTER))); + public static final UDFOperandMetadata INTEGER_INTEGER = UDFOperandMetadata.wrap((FamilyOperandTypeChecker) OperandTypes.INTEGER_INTEGER); public static final UDFOperandMetadata STRING_STRING = - UDFOperandMetadata.wrap(OperandTypes.STRING_STRING); + UDFOperandMetadata.wrap((FamilyOperandTypeChecker) OperandTypes.CHARACTER_CHARACTER); public static final UDFOperandMetadata NUMERIC_NUMERIC = UDFOperandMetadata.wrap((FamilyOperandTypeChecker) OperandTypes.NUMERIC_NUMERIC); + public static final UDFOperandMetadata STRING_INTEGER = + UDFOperandMetadata.wrap(OperandTypes.family(SqlTypeFamily.CHARACTER, SqlTypeFamily.INTEGER)); + public static final UDFOperandMetadata NUMERIC_NUMERIC_NUMERIC = UDFOperandMetadata.wrap( OperandTypes.family(SqlTypeFamily.NUMERIC, SqlTypeFamily.NUMERIC, SqlTypeFamily.NUMERIC)); + public static final UDFOperandMetadata STRING_OR_INTEGER_INTEGER_INTEGER = + UDFOperandMetadata.wrap( + (CompositeOperandTypeChecker) + OperandTypes.family( + SqlTypeFamily.CHARACTER, SqlTypeFamily.INTEGER, SqlTypeFamily.INTEGER) + .or( + OperandTypes.family( + SqlTypeFamily.INTEGER, SqlTypeFamily.INTEGER, SqlTypeFamily.INTEGER))); + + public static final UDFOperandMetadata OPTIONAL_DATE_OR_TIMESTAMP_OR_NUMERIC = + UDFOperandMetadata.wrap( + (CompositeOperandTypeChecker) + OperandTypes.DATETIME.or(OperandTypes.NUMERIC).or(OperandTypes.family())); public static final UDFOperandMetadata DATETIME_OR_STRING = UDFOperandMetadata.wrap( - (CompositeOperandTypeChecker) OperandTypes.DATETIME.or(OperandTypes.STRING)); + (CompositeOperandTypeChecker) OperandTypes.DATETIME.or(OperandTypes.CHARACTER)); + public static final UDFOperandMetadata TIME_OR_TIMESTAMP_OR_STRING = + UDFOperandMetadata.wrap( + (CompositeOperandTypeChecker) + OperandTypes.CHARACTER.or(OperandTypes.TIME).or(OperandTypes.TIMESTAMP)); + public static final UDFOperandMetadata DATE_OR_TIMESTAMP_OR_STRING = + UDFOperandMetadata.wrap( + (CompositeOperandTypeChecker) OperandTypes.DATE_OR_TIMESTAMP.or(OperandTypes.CHARACTER)); + public static final UDFOperandMetadata DATETIME_OR_STRING_OR_INTEGER = + UDFOperandMetadata.wrap( + (CompositeOperandTypeChecker) + OperandTypes.DATETIME.or(OperandTypes.CHARACTER).or(OperandTypes.INTEGER)); + + public static final UDFOperandMetadata DATETIME_OPTIONAL_INTEGER = + UDFOperandMetadata.wrap( + (CompositeOperandTypeChecker) + OperandTypes.DATETIME.or( + OperandTypes.family(SqlTypeFamily.DATETIME, SqlTypeFamily.INTEGER))); + public static final UDFOperandMetadata DATETIME_DATETIME = UDFOperandMetadata.wrap(OperandTypes.family(SqlTypeFamily.DATETIME, SqlTypeFamily.DATETIME)); + public static final UDFOperandMetadata DATETIME_OR_STRING_STRING = + UDFOperandMetadata.wrap( + (CompositeOperandTypeChecker) + OperandTypes.family(SqlTypeFamily.DATETIME, SqlTypeFamily.CHARACTER) + .or(OperandTypes.CHARACTER_CHARACTER)); public static final UDFOperandMetadata DATETIME_OR_STRING_DATETIME_OR_STRING = UDFOperandMetadata.wrap( (CompositeOperandTypeChecker) - OperandTypes.STRING_STRING + OperandTypes.CHARACTER_CHARACTER .or(OperandTypes.family(SqlTypeFamily.DATETIME, SqlTypeFamily.DATETIME)) - .or(OperandTypes.family(SqlTypeFamily.DATETIME, SqlTypeFamily.STRING)) - .or(OperandTypes.family(SqlTypeFamily.STRING, SqlTypeFamily.DATETIME))); - public static final UDFOperandMetadata TIME_OR_TIMESTAMP_OR_STRING = + .or(OperandTypes.family(SqlTypeFamily.DATETIME, SqlTypeFamily.CHARACTER)) + .or(OperandTypes.family(SqlTypeFamily.CHARACTER, SqlTypeFamily.DATETIME))); + public static final UDFOperandMetadata STRING_TIMESTAMP = + UDFOperandMetadata.wrap( + OperandTypes.family(SqlTypeFamily.CHARACTER, SqlTypeFamily.TIMESTAMP)); + public static final UDFOperandMetadata STRING_DATETIME = + UDFOperandMetadata.wrap(OperandTypes.family(SqlTypeFamily.CHARACTER, SqlTypeFamily.DATETIME)); + public static final UDFOperandMetadata DATETIME_INTERVAL = + UDFOperandMetadata.wrap((FamilyOperandTypeChecker) OperandTypes.DATETIME_INTERVAL); + public static final UDFOperandMetadata TIME_TIME = + UDFOperandMetadata.wrap(OperandTypes.family(SqlTypeFamily.TIME, SqlTypeFamily.TIME)); + + public static final UDFOperandMetadata TIMESTAMP_OR_STRING_STRING_STRING = UDFOperandMetadata.wrap( (CompositeOperandTypeChecker) - OperandTypes.STRING.or(OperandTypes.TIME).or(OperandTypes.TIMESTAMP)); - public static final UDFOperandMetadata DATE_OR_TIMESTAMP_OR_STRING = + OperandTypes.family( + SqlTypeFamily.TIMESTAMP, SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER) + .or( + OperandTypes.family( + SqlTypeFamily.CHARACTER, + SqlTypeFamily.CHARACTER, + SqlTypeFamily.CHARACTER))); + public static final UDFOperandMetadata STRING_INTEGER_DATETIME_OR_STRING = UDFOperandMetadata.wrap( - (CompositeOperandTypeChecker) OperandTypes.DATE_OR_TIMESTAMP.or(OperandTypes.STRING)); - public static final UDFOperandMetadata STRING_TIMESTAMP = - UDFOperandMetadata.wrap(OperandTypes.family(SqlTypeFamily.STRING, SqlTypeFamily.TIMESTAMP)); + (CompositeOperandTypeChecker) + OperandTypes.family( + SqlTypeFamily.CHARACTER, SqlTypeFamily.INTEGER, SqlTypeFamily.CHARACTER) + .or( + OperandTypes.family( + SqlTypeFamily.CHARACTER, SqlTypeFamily.INTEGER, SqlTypeFamily.DATETIME))); + public static final UDFOperandMetadata INTERVAL_DATETIME_DATETIME = + UDFOperandMetadata.wrap( + (CompositeOperandTypeChecker) + OperandTypes.family( + SqlTypeFamily.CHARACTER, SqlTypeFamily.DATETIME, SqlTypeFamily.DATETIME) + .or( + OperandTypes.family( + SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER, SqlTypeFamily.DATETIME)) + .or( + OperandTypes.family( + SqlTypeFamily.CHARACTER, SqlTypeFamily.DATETIME, SqlTypeFamily.CHARACTER)) + .or( + OperandTypes.family( + SqlTypeFamily.CHARACTER, + SqlTypeFamily.CHARACTER, + SqlTypeFamily.CHARACTER))); } diff --git a/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java b/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java index 5834d3f5942..6e278c4192e 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java @@ -9,6 +9,7 @@ import java.util.Locale; import java.util.Map; import java.util.Optional; +import java.util.Set; import lombok.AllArgsConstructor; import lombok.Getter; import lombok.RequiredArgsConstructor; @@ -380,4 +381,13 @@ public static Optional ofWindowFunction(String functionName return Optional.ofNullable( WINDOW_FUNC_MAPPING.getOrDefault(functionName.toLowerCase(Locale.ROOT), null)); } + + public static final Set COMPARATORS = + Set.of( + BuiltinFunctionName.EQUAL, + BuiltinFunctionName.NOTEQUAL, + BuiltinFunctionName.LESS, + BuiltinFunctionName.LTE, + BuiltinFunctionName.GREATER, + BuiltinFunctionName.GTE); } diff --git a/core/src/main/java/org/opensearch/sql/expression/function/CalciteFuncSignature.java b/core/src/main/java/org/opensearch/sql/expression/function/CalciteFuncSignature.java index 1069e3e5425..de2f80e6554 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/CalciteFuncSignature.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/CalciteFuncSignature.java @@ -27,12 +27,12 @@ public PPLTypeChecker getTypeChecker() { return typeChecker; } - public boolean match(FunctionName functionName, List paramTypeList) { + public boolean match(FunctionName functionName, List argTypes) { if (!functionName.equals(this.functionName)) return false; // For complex type checkers (e.g., OperandTypes.COMPARABLE_UNORDERED_COMPARABLE_UNORDERED), // the typeChecker will be null because only simple family-based type checks are currently // supported. if (typeChecker == null) return true; - return typeChecker.checkOperandTypes(paramTypeList); + return typeChecker.checkOperandTypes(argTypes); } } diff --git a/core/src/main/java/org/opensearch/sql/expression/function/CoercionUtils.java b/core/src/main/java/org/opensearch/sql/expression/function/CoercionUtils.java new file mode 100644 index 00000000000..d521df01877 --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/expression/function/CoercionUtils.java @@ -0,0 +1,140 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.expression.function; + +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Collectors; +import javax.annotation.Nullable; +import org.apache.calcite.rex.RexBuilder; +import org.apache.calcite.rex.RexNode; +import org.opensearch.sql.calcite.utils.OpenSearchTypeFactory; +import org.opensearch.sql.data.type.ExprCoreType; +import org.opensearch.sql.data.type.ExprType; +import org.opensearch.sql.data.type.WideningTypeRule; +import org.opensearch.sql.exception.ExpressionEvaluationException; + +public class CoercionUtils { + + /** + * Casts the arguments to the types specified in the typeChecker. Returns null if no combination + * of parameter types matches the arguments or if casting fails. + * + * @param builder RexBuilder to create casts + * @param typeChecker PPLTypeChecker that provides the parameter types + * @param arguments List of RexNode arguments to be cast + * @return List of cast RexNode arguments or null if casting fails + */ + public static @Nullable List castArguments( + RexBuilder builder, PPLTypeChecker typeChecker, List arguments) { + List> paramTypeCombinations = typeChecker.getParameterTypes(); + + // TODO: var args? + + for (List paramTypes : paramTypeCombinations) { + List castedArguments = castArguments(builder, paramTypes, arguments); + if (castedArguments != null) { + return castedArguments; + } + } + return null; + } + + /** + * Widen the arguments to the widest type found among them. If no widest type can be determined, + * returns null. + * + * @param builder RexBuilder to create casts + * @param arguments List of RexNode arguments to be widened + * @return List of widened RexNode arguments or null if no widest type can be determined + */ + public static @Nullable List widenArguments( + RexBuilder builder, List arguments) { + // TODO: Add test on e.g. IP + ExprType widestType = findWidestType(arguments); + if (widestType == null) { + return null; // No widest type found, return null + } + return arguments.stream().map(arg -> cast(builder, widestType, arg)).collect(Collectors.toList()); + } + + /** + * Casts the arguments to the types specified in paramTypes. Returns null if the number of + * parameters does not match or if casting fails. + */ + private static @Nullable List castArguments( + RexBuilder builder, List paramTypes, List arguments) { + if (paramTypes.size() != arguments.size()) { + return null; // Skip if the number of parameters does not match + } + + List castedArguments = new ArrayList<>(); + for (int i = 0; i < paramTypes.size(); i++) { + ExprType toType = paramTypes.get(i); + RexNode arg = arguments.get(i); + + RexNode castedArg = cast(builder, toType, arg); + + if (castedArg == null) { + return null; + } + castedArguments.add(castedArg); + } + return castedArguments; + } + + private static @Nullable RexNode cast(RexBuilder builder, ExprType targetType, RexNode arg) { + ExprType argType = OpenSearchTypeFactory.convertRelDataTypeToExprType(arg.getType()); + if (!argType.shouldCast(targetType)) { + return arg; + } + + if (WideningTypeRule.distance(argType, targetType) != WideningTypeRule.IMPOSSIBLE_WIDENING) { + return builder.makeCast(OpenSearchTypeFactory.convertExprTypeToRelDataType(targetType), arg); + } + return null; + } + + /** + * Finds the widest type among the given arguments. The widest type is determined by applying the + * widening type rule to each pair of types in the arguments. + * + * @param arguments List of RexNode arguments to find the widest type from + * @return the widest ExprType if found, otherwise null + */ + private static @Nullable ExprType findWidestType(List arguments) { + if (arguments.isEmpty()) { + return null; // No arguments to process + } + ExprType widestType = + OpenSearchTypeFactory.convertRelDataTypeToExprType(arguments.get(0).getType()); + if (arguments.size() == 1) { + return widestType; + } + + // Iterate pairwise through the arguments and find the widest type + for (int i = 1; i < arguments.size(); i++) { + var type = OpenSearchTypeFactory.convertRelDataTypeToExprType(arguments.get(i).getType()); + try { + if (areDateAndTime(widestType, type)) { + // If one is date and the other is time, we consider timestamp as the widest type + widestType = ExprCoreType.TIMESTAMP; + } else { + widestType = WideningTypeRule.max(widestType, type); + } + } catch (ExpressionEvaluationException e) { + // the two types are not compatible, return null + return null; + } + } + return widestType; + } + + private static boolean areDateAndTime(ExprType type1, ExprType type2) { + return (type1 == ExprCoreType.DATE && type2 == ExprCoreType.TIME) + || (type1 == ExprCoreType.TIME && type2 == ExprCoreType.DATE); + } +} diff --git a/core/src/main/java/org/opensearch/sql/expression/function/PPLBuiltinOperators.java b/core/src/main/java/org/opensearch/sql/expression/function/PPLBuiltinOperators.java index 02ddfd78dee..257f0c04668 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/PPLBuiltinOperators.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/PPLBuiltinOperators.java @@ -21,10 +21,7 @@ import org.apache.calcite.linq4j.tree.Expression; import org.apache.calcite.rex.RexCall; import org.apache.calcite.sql.SqlOperator; -import org.apache.calcite.sql.type.CompositeOperandTypeChecker; -import org.apache.calcite.sql.type.OperandTypes; import org.apache.calcite.sql.type.ReturnTypes; -import org.apache.calcite.sql.type.SqlTypeFamily; import org.apache.calcite.sql.type.SqlTypeTransforms; import org.apache.calcite.sql.util.ReflectiveSqlOperatorTable; import org.apache.calcite.util.BuiltInMethod; @@ -144,7 +141,7 @@ public class PPLBuiltinOperators extends ReflectiveSqlOperatorTable { "exprAddTime", PPLReturnTypes.TIME_APPLY_RETURN_TYPE, NullPolicy.ANY, - PPLOperandTypes.DATETIME_OR_STRING_DATETIME_OR_STRING) + PPLOperandTypes.DATETIME_DATETIME) .toUDF("ADDTIME"); public static final SqlOperator SUBTIME = adaptExprMethodWithPropertiesToUDF( @@ -152,7 +149,7 @@ public class PPLBuiltinOperators extends ReflectiveSqlOperatorTable { "exprSubTime", PPLReturnTypes.TIME_APPLY_RETURN_TYPE, NullPolicy.ANY, - PPLOperandTypes.DATETIME_OR_STRING_DATETIME_OR_STRING) + PPLOperandTypes.DATETIME_DATETIME) .toUDF("SUBTIME"); public static final SqlOperator ADDDATE = new AddSubDateFunction(true).toUDF("ADDDATE"); public static final SqlOperator SUBDATE = new AddSubDateFunction(false).toUDF("SUBDATE"); @@ -198,11 +195,7 @@ public class PPLBuiltinOperators extends ReflectiveSqlOperatorTable { "exprConvertTZ", PPLReturnTypes.TIMESTAMP_FORCE_NULLABLE, NullPolicy.ANY, - UDFOperandMetadata.wrap( - (CompositeOperandTypeChecker) - OperandTypes.STRING_STRING_STRING.or( - OperandTypes.family( - SqlTypeFamily.DATETIME, SqlTypeFamily.STRING, SqlTypeFamily.STRING)))) + PPLOperandTypes.TIMESTAMP_OR_STRING_STRING_STRING) .toUDF("CONVERT_TZ"); public static final SqlOperator DATEDIFF = adaptExprMethodWithPropertiesToUDF( @@ -210,7 +203,7 @@ public class PPLBuiltinOperators extends ReflectiveSqlOperatorTable { "exprDateDiff", ReturnTypes.BIGINT_FORCE_NULLABLE, NullPolicy.ANY, - PPLOperandTypes.DATETIME_OR_STRING_DATETIME_OR_STRING) + PPLOperandTypes.DATETIME_DATETIME) .toUDF("DATEDIFF"); public static final SqlOperator TIMESTAMPDIFF = new TimestampDiffFunction().toUDF("TIMESTAMPDIFF"); @@ -291,7 +284,7 @@ public class PPLBuiltinOperators extends ReflectiveSqlOperatorTable { "exprTimeToSec", ReturnTypes.BIGINT_FORCE_NULLABLE, NullPolicy.ARG0, - PPLOperandTypes.DATETIME_OR_STRING) + PPLOperandTypes.TIME_OR_TIMESTAMP_OR_STRING) .toUDF("TIME_TO_SEC"); public static final SqlOperator TIMEDIFF = UserDefinedFunctionUtils.adaptExprMethodToUDF( @@ -299,7 +292,7 @@ public class PPLBuiltinOperators extends ReflectiveSqlOperatorTable { "exprTimeDiff", PPLReturnTypes.TIME_FORCE_NULLABLE, NullPolicy.ANY, - PPLOperandTypes.DATETIME_OR_STRING_DATETIME_OR_STRING) + PPLOperandTypes.TIME_TIME) .toUDF("TIME_DIFF"); public static final SqlOperator TIMESTAMPADD = new TimestampAddFunction().toUDF("TIMESTAMPADD"); public static final SqlOperator TO_DAYS = diff --git a/core/src/main/java/org/opensearch/sql/expression/function/PPLFuncImpTable.java b/core/src/main/java/org/opensearch/sql/expression/function/PPLFuncImpTable.java index 98b35930513..a8bb5d0dc42 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/PPLFuncImpTable.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/PPLFuncImpTable.java @@ -226,10 +226,11 @@ import java.util.StringJoiner; import java.util.concurrent.ConcurrentHashMap; import java.util.function.BiFunction; -import java.util.stream.Collectors; import java.util.stream.Stream; +import javax.annotation.Nullable; import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rex.RexBuilder; +import org.apache.calcite.rex.RexLambda; import org.apache.calcite.rex.RexNode; import org.apache.calcite.sql.SqlOperator; import org.apache.calcite.sql.fun.SqlLibraryOperators; @@ -243,7 +244,6 @@ import org.apache.calcite.sql.type.SqlOperandTypeChecker; import org.apache.calcite.sql.type.SqlTypeFamily; import org.apache.calcite.sql.type.SqlTypeName; -import org.checkerframework.checker.nullness.qual.Nullable; import org.apache.calcite.sql.validate.SqlUserDefinedFunction; import org.apache.calcite.tools.RelBuilder; import org.apache.commons.lang3.function.TriFunction; @@ -258,7 +258,6 @@ import org.opensearch.sql.calcite.utils.PlanUtils; import org.opensearch.sql.calcite.utils.UserDefinedFunctionUtils; import org.opensearch.sql.exception.ExpressionEvaluationException; -import org.apache.commons.lang3.tuple.Pair; import org.opensearch.sql.executor.QueryType; public class PPLFuncImpTable { @@ -274,7 +273,6 @@ RelBuilder.AggCall apply( public interface FunctionImp { RelDataType ANY_TYPE = TYPE_FACTORY.createSqlType(SqlTypeName.ANY); - // TODO: Support argument coercion and casting RexNode resolve(RexBuilder builder, RexNode... args); /** @@ -449,6 +447,12 @@ public RexNode resolve( if (implementList == null || implementList.isEmpty()) { throw new IllegalStateException(String.format("Cannot resolve function: %s", functionName)); } + + // Make compulsory casts for some functions that require specific casting of arguments. + // For example, the REDUCE function requires the second argument to be cast to the + // return type of the lambda function. + compulsoryCast(builder, functionName, args); + List argTypes = Arrays.stream(args).map(RexNode::getType).collect(Collectors.toList()); try { for (Map.Entry implement : implementList) { @@ -456,6 +460,13 @@ public RexNode resolve( return implement.getValue().resolve(builder, args); } } + + // If no implementation found with exact match, try to cast arguments to match the + // signatures. + RexNode coerced = resolveWithCoercion(builder, functionName, implementList, args); + if (coerced != null) { + return coerced; + } } catch (Exception e) { throw new ExpressionEvaluationException( String.format( @@ -476,6 +487,63 @@ functionName, getActualSignature(argTypes), e.getMessage()), functionName, allowedSignatures, getActualSignature(argTypes))); } + /** + * Ad-hoc coercion for some functions that require specific casting of arguments. Now it only + * applies to the REDUCE function. + */ + private void compulsoryCast( + final RexBuilder builder, final BuiltinFunctionName functionName, RexNode... args) { + + //noinspection SwitchStatementWithTooFewBranches + switch (functionName) { + case REDUCE: + // Set the second argument to the return type of the lambda function, so that + // code generated with linq4j can correctly accumulate the result. + RexLambda call = (RexLambda) args[2]; + args[1] = builder.makeCast(call.getType(), args[1], true, true); + break; + default: + break; + } + } + + private @Nullable RexNode resolveWithCoercion( + final RexBuilder builder, + final BuiltinFunctionName functionName, + List> implementList, + RexNode... args) { + if (BuiltinFunctionName.COMPARATORS.contains(functionName)) { + for (Map.Entry implement : implementList) { + var widenedArgs = CoercionUtils.widenArguments(builder, List.of(args)); + if (widenedArgs != null) { + boolean matchSignature = + implement + .getKey() + .getTypeChecker() + .checkOperandTypes(widenedArgs.stream().map(RexNode::getType).collect(Collectors.toList())); + if (matchSignature) { + return implement.getValue().resolve(builder, widenedArgs.toArray(new RexNode[0])); + } + } + } + } else { + for (Map.Entry implement : implementList) { + var signature = implement.getKey(); + var castedArgs = + CoercionUtils.castArguments(builder, signature.getTypeChecker(), List.of(args)); + if (castedArgs != null) { + // If compatible function is found, replace the original RexNode with cast node + // TODO: check - this is a return-once-found implementation, rest possible combinations + // will be skipped. + // Maybe can be improved to return the best match? E.g. convert to timestamp when date, + // time, and timestamp are all possible. + return implement.getValue().resolve(builder, castedArgs.toArray(new RexNode[0])); + } + } + } + return null; + } + private static String getActualSignature(List argTypes) { return "[" + argTypes.stream() @@ -881,7 +949,7 @@ void populate() { builder.makeFlag(Flag.BOTH), builder.makeLiteral(" "), arg), - PPLTypeChecker.family(SqlTypeFamily.STRING))); + PPLTypeChecker.family(SqlTypeFamily.CHARACTER))); register( LTRIM, @@ -892,7 +960,7 @@ void populate() { builder.makeFlag(Flag.LEADING), builder.makeLiteral(" "), arg), - PPLTypeChecker.family(SqlTypeFamily.STRING))); + PPLTypeChecker.family(SqlTypeFamily.CHARACTER))); register( RTRIM, createFunctionImpWithTypeChecker( @@ -902,7 +970,7 @@ void populate() { builder.makeFlag(Flag.TRAILING), builder.makeLiteral(" "), arg), - PPLTypeChecker.family(SqlTypeFamily.STRING))); + PPLTypeChecker.family(SqlTypeFamily.CHARACTER))); register( ATAN, createFunctionImpWithTypeChecker( @@ -912,7 +980,7 @@ void populate() { STRCMP, createFunctionImpWithTypeChecker( (builder, arg1, arg2) -> builder.makeCall(SqlLibraryOperators.STRCMP, arg2, arg1), - PPLTypeChecker.family(SqlTypeFamily.STRING, SqlTypeFamily.STRING))); + PPLTypeChecker.family(SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER))); // SqlStdOperatorTable.SUBSTRING.getOperandTypeChecker is null. We manually create a type // checker for it. register( @@ -920,14 +988,24 @@ void populate() { wrapWithCompositeTypeChecker( SqlStdOperatorTable.SUBSTRING, (CompositeOperandTypeChecker) - OperandTypes.STRING_INTEGER.or(OperandTypes.STRING_INTEGER_INTEGER), + OperandTypes.family(SqlTypeFamily.CHARACTER, SqlTypeFamily.INTEGER) + .or( + OperandTypes.family( + SqlTypeFamily.CHARACTER, + SqlTypeFamily.INTEGER, + SqlTypeFamily.INTEGER)), false)); register( SUBSTR, wrapWithCompositeTypeChecker( SqlStdOperatorTable.SUBSTRING, (CompositeOperandTypeChecker) - OperandTypes.STRING_INTEGER.or(OperandTypes.STRING_INTEGER_INTEGER), + OperandTypes.family(SqlTypeFamily.CHARACTER, SqlTypeFamily.INTEGER) + .or( + OperandTypes.family( + SqlTypeFamily.CHARACTER, + SqlTypeFamily.INTEGER, + SqlTypeFamily.INTEGER)), false)); // SqlStdOperatorTable.ITEM.getOperandTypeChecker() checks only the first operand instead of // all operands. Therefore, we wrap it with a custom CompositeOperandTypeChecker to check both diff --git a/core/src/main/java/org/opensearch/sql/expression/function/PPLTypeChecker.java b/core/src/main/java/org/opensearch/sql/expression/function/PPLTypeChecker.java index 325c57689b9..363b806a2c2 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/PPLTypeChecker.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/PPLTypeChecker.java @@ -15,6 +15,9 @@ import java.util.stream.IntStream; import lombok.RequiredArgsConstructor; import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rel.type.RelDataTypeField; +import org.apache.calcite.sql.SqlIntervalQualifier; +import org.apache.calcite.sql.parser.SqlParserPos; import org.apache.calcite.sql.type.CompositeOperandTypeChecker; import org.apache.calcite.sql.type.FamilyOperandTypeChecker; import org.apache.calcite.sql.type.ImplicitCastOperandTypeChecker; @@ -23,6 +26,8 @@ import org.apache.calcite.sql.type.SqlTypeFamily; import org.apache.calcite.sql.type.SqlTypeName; import org.apache.calcite.sql.type.SqlTypeUtil; +import org.apache.calcite.util.Pair; +import org.opensearch.sql.calcite.type.ExprIPType; import org.opensearch.sql.calcite.utils.OpenSearchTypeFactory; import org.opensearch.sql.calcite.utils.UserDefinedFunctionUtils; import org.opensearch.sql.data.type.ExprCoreType; @@ -53,6 +58,16 @@ public interface PPLTypeChecker { */ String getAllowedSignatures(); + /** + * Get a list of all possible parameter type combinations for the function. + * + *

This method is used to generate the allowed signatures for the function based on the + * parameter types. + * + * @return a list of lists, where each inner list represents an allowed parameter type combination + */ + List> getParameterTypes(); + private static boolean validateOperands( List funcTypeFamilies, List operandTypes) { // If the number of actual operands does not match expectation, return false @@ -95,6 +110,11 @@ public String getAllowedSignatures() { return PPLTypeChecker.getFamilySignature(families); } + @Override + public List> getParameterTypes() { + return PPLTypeChecker.getExprSignatures(families); + } + @Override public String toString() { return String.format("PPLFamilyTypeChecker[families=%s]", getAllowedSignatures()); @@ -135,12 +155,24 @@ public boolean checkOperandTypes(List types) { public String getAllowedSignatures() { if (innerTypeChecker instanceof FamilyOperandTypeChecker) { FamilyOperandTypeChecker familyOperandTypeChecker = (FamilyOperandTypeChecker) innerTypeChecker; - var allowedSignatures = PPLTypeChecker.getFamilySignatures(familyOperandTypeChecker); - return String.join(",", allowedSignatures); + var allowedExprSignatures = getExprSignatures(familyOperandTypeChecker); + return PPLTypeChecker.formatExprSignatures(allowedExprSignatures); } else { return ""; } } + + @Override + public List> getParameterTypes() { + if (innerTypeChecker instanceof FamilyOperandTypeChecker) { + FamilyOperandTypeChecker familyOperandTypeChecker = (FamilyOperandTypeChecker) innerTypeChecker; + return getExprSignatures(familyOperandTypeChecker); + } else { + // If the inner type checker is not a FamilyOperandTypeChecker, we cannot provide + // parameter types. + return Collections.emptyList(); + } + } } /** @@ -154,6 +186,7 @@ public String getAllowedSignatures() { * ImplicitCastOperandTypeChecker}. */ class PPLCompositeTypeChecker implements PPLTypeChecker { + private final List allowedRules; public PPLCompositeTypeChecker(CompositeOperandTypeChecker typeChecker) { @@ -190,17 +223,35 @@ public boolean checkOperandTypes(List types) { @Override public String getAllowedSignatures() { - List allowedSignatures = new ArrayList<>(); + StringBuilder builder = new StringBuilder(); for (SqlOperandTypeChecker rule : allowedRules) { if (rule instanceof FamilyOperandTypeChecker) { - FamilyOperandTypeChecker familyOperandTypeChecker = (FamilyOperandTypeChecker) rule; - allowedSignatures.addAll(PPLTypeChecker.getFamilySignatures(familyOperandTypeChecker)); + FamilyOperandTypeChecker familyOperandTypeChecker = (FamilyOperandTypeChecker) rule; + if (builder.length() > 0) { + builder.append(","); + } + builder.append(PPLTypeChecker.getFamilySignatures(familyOperandTypeChecker)); } else { throw new IllegalArgumentException( "Currently only compositions of FamilyOperandTypeChecker are supported"); } } - return String.join(",", allowedSignatures); + return builder.toString(); + } + + @Override + public List> getParameterTypes() { + List> parameterTypes = new ArrayList<>(); + for (SqlOperandTypeChecker rule : allowedRules) { + if (rule instanceof FamilyOperandTypeChecker) { + FamilyOperandTypeChecker familyOperandTypeChecker = (FamilyOperandTypeChecker) rule; + parameterTypes.addAll(getExprSignatures(familyOperandTypeChecker)); + } else { + throw new IllegalArgumentException( + "Currently only compositions of FamilyOperandTypeChecker are supported"); + } + } + return parameterTypes; } } @@ -217,34 +268,60 @@ public boolean checkOperandTypes(List types) { for (int i = 0; i < types.size() - 1; i++) { // TODO: Binary, Array UDT? // DATETIME, NUMERIC, BOOLEAN will be regarded as comparable - // with strings in SqlTypeUtil.isComparable + // with strings in isComparable RelDataType type_l = types.get(i); RelDataType type_r = types.get(i + 1); - if (!SqlTypeUtil.isComparable(type_l, type_r)) { + // Rule out IP types from built-in comparable functions + if (type_l instanceof ExprIPType || type_r instanceof ExprIPType) { return false; } - // Disallow coercing between strings and numeric, boolean - if ((type_l.getFamily() == SqlTypeFamily.CHARACTER - && cannotConvertStringInCompare((SqlTypeFamily) type_r.getFamily())) - || (type_r.getFamily() == SqlTypeFamily.CHARACTER - && cannotConvertStringInCompare((SqlTypeFamily) type_l.getFamily()))) { + if (!isComparable(type_l, type_r)) { return false; } } return true; } - private static boolean cannotConvertStringInCompare(SqlTypeFamily typeFamily) { - switch (typeFamily) { - case BOOLEAN: - case INTEGER: - case NUMERIC: - case EXACT_NUMERIC: - case APPROXIMATE_NUMERIC: - return true; - default: + /** + * Modified from {@link SqlTypeUtil#isComparable(RelDataType, RelDataType)} to + * + * @param type1 first type + * @param type2 second type + * @return true if the two types are comparable, false otherwise + */ + private static boolean isComparable(RelDataType type1, RelDataType type2) { + if (type1.isStruct() != type2.isStruct()) { + return false; + } + + if (type1.isStruct()) { + int n = type1.getFieldCount(); + if (n != type2.getFieldCount()) { return false; + } + for (Pair pair : + Pair.zip(type1.getFieldList(), type2.getFieldList())) { + if (!isComparable(pair.left.getType(), pair.right.getType())) { + return false; + } + } + return true; + } + + // Numeric types are comparable without the need to cast + if (SqlTypeUtil.isNumeric(type1) && SqlTypeUtil.isNumeric(type2)) { + return true; } + + ExprType exprType1 = OpenSearchTypeFactory.convertRelDataTypeToExprType(type1); + ExprType exprType2 = OpenSearchTypeFactory.convertRelDataTypeToExprType(type2); + + if (!exprType1.shouldCast(exprType2)) { + return true; + } + + // If one of the arguments is of type 'ANY', return true. + return type1.getFamily() == SqlTypeFamily.ANY || type2.getFamily() == SqlTypeFamily.ANY; } @Override @@ -267,6 +344,12 @@ public String getAllowedSignatures() { return String.join(",", signatures); } } + + @Override + public List> getParameterTypes() { + // Should not be used + return List.of(List.of(ExprCoreType.UNKNOWN, ExprCoreType.UNKNOWN)); + } } /** @@ -375,7 +458,12 @@ public boolean checkOperandTypes(List types) { @Override public String getAllowedSignatures() { - return PPLTypeChecker.getExprFamilySignature(allowedSignatures); + return PPLTypeChecker.formatExprSignatures(allowedSignatures); + } + + @Override + public List> getParameterTypes() { + return allowedSignatures; } }; } @@ -392,24 +480,27 @@ public String getAllowedSignatures() { * @param typeChecker the {@link FamilyOperandTypeChecker} to use for generating signatures * @return a list of allowed function signatures */ - private static List getFamilySignatures(FamilyOperandTypeChecker typeChecker) { + private static String getFamilySignatures(FamilyOperandTypeChecker typeChecker) { + var allowedExprSignatures = getExprSignatures(typeChecker); + return formatExprSignatures(allowedExprSignatures); + } + + private static List> getExprSignatures(FamilyOperandTypeChecker typeChecker) { var operandCountRange = typeChecker.getOperandCountRange(); int min = operandCountRange.getMin(); int max = operandCountRange.getMax(); - List allowedSignatures = new ArrayList<>(); List families = new ArrayList<>(); for (int i = 0; i < min; i++) { families.add(typeChecker.getOperandSqlTypeFamily(i)); } - allowedSignatures.add(getFamilySignature(families)); + List> allowedSignatures = new ArrayList<>(getExprSignatures(families)); // Avoid enumerating signatures for infinite args final int MAX_ARGS = 10; max = Math.min(max, MAX_ARGS); - for (int i = min; i < max; i++) { families.add(typeChecker.getOperandSqlTypeFamily(i)); - allowedSignatures.add(getFamilySignature(families)); + allowedSignatures.addAll(getExprSignatures(families)); } return allowedSignatures; } @@ -423,42 +514,63 @@ private static List getFamilySignatures(FamilyOperandTypeChecker typeChe */ private static List getExprTypes(SqlTypeFamily family) { List concreteTypes; - switch (family) { - case DATETIME: - concreteTypes = List.of( - OpenSearchTypeFactory.TYPE_FACTORY.createSqlType(SqlTypeName.DATE), - OpenSearchTypeFactory.TYPE_FACTORY.createSqlType(SqlTypeName.TIME), - OpenSearchTypeFactory.TYPE_FACTORY.createSqlType(SqlTypeName.TIMESTAMP)); - break; - case NUMERIC: - concreteTypes = List.of( - OpenSearchTypeFactory.TYPE_FACTORY.createSqlType(SqlTypeName.INTEGER), - OpenSearchTypeFactory.TYPE_FACTORY.createSqlType(SqlTypeName.DOUBLE)); - break; - case INTEGER: - concreteTypes = List.of( - OpenSearchTypeFactory.TYPE_FACTORY.createSqlType(SqlTypeName.INTEGER)); - break; - case ANY: - case IGNORE: - concreteTypes = List.of( - OpenSearchTypeFactory.TYPE_FACTORY.createSqlType(SqlTypeName.ANY)); - break; - default: - RelDataType type = family.getDefaultConcreteType(OpenSearchTypeFactory.TYPE_FACTORY); - if (type == null) { - concreteTypes = List.of( - OpenSearchTypeFactory.TYPE_FACTORY.createSqlType(SqlTypeName.OTHER)); - } else { - concreteTypes = List.of(type); - } - break; - } + switch (family) { + case DATETIME: + concreteTypes = List.of( + OpenSearchTypeFactory.TYPE_FACTORY.createSqlType(SqlTypeName.TIMESTAMP), + OpenSearchTypeFactory.TYPE_FACTORY.createSqlType(SqlTypeName.DATE), + OpenSearchTypeFactory.TYPE_FACTORY.createSqlType(SqlTypeName.TIME)); + break; + case NUMERIC: concreteTypes = List.of( + OpenSearchTypeFactory.TYPE_FACTORY.createSqlType(SqlTypeName.INTEGER), + OpenSearchTypeFactory.TYPE_FACTORY.createSqlType(SqlTypeName.DOUBLE)); + break; + // Integer is mapped to BIGINT in family.getDefaultConcreteType + case INTEGER: concreteTypes = List.of( + OpenSearchTypeFactory.TYPE_FACTORY.createSqlType(SqlTypeName.INTEGER)); + break; + case ANY: + case IGNORE: concreteTypes = List.of( + OpenSearchTypeFactory.TYPE_FACTORY.createSqlType(SqlTypeName.ANY)); + break; + case DATETIME_INTERVAL: concreteTypes = SqlTypeName.INTERVAL_TYPES.stream() + .map( + type -> + OpenSearchTypeFactory.TYPE_FACTORY.createSqlIntervalType( + new SqlIntervalQualifier( + type.getStartUnit(), type.getEndUnit(), SqlParserPos.ZERO))) + .collect(Collectors.toList()); + break; + default: + RelDataType type = family.getDefaultConcreteType(OpenSearchTypeFactory.TYPE_FACTORY); + if (type == null) { + concreteTypes = List.of(OpenSearchTypeFactory.TYPE_FACTORY.createSqlType(SqlTypeName.OTHER)); + } else { + concreteTypes = List.of(type); + } + break; + }; return concreteTypes.stream() .map(OpenSearchTypeFactory::convertRelDataTypeToExprType) + .distinct() .collect(Collectors.toList()); } + /** + * Generates a list of all possible {@link ExprType} signatures based on the provided {@link + * SqlTypeFamily} list. + * + * @param families the list of {@link SqlTypeFamily} to generate signatures for + * @return a list of lists, where each inner list contains {@link ExprType} signatures + */ + private static List> getExprSignatures(List families) { + List> exprTypes = + families.stream().map(PPLTypeChecker::getExprTypes).collect(Collectors.toList()); + + // Do a cartesian product of all ExprTypes in the family + return Lists.cartesianProduct(exprTypes); + } + /** * Generates a string representation of the function signature based on the provided type * families. The format is a list of type families enclosed in square brackets, e.g.: "[INTEGER, @@ -468,27 +580,9 @@ private static List getExprTypes(SqlTypeFamily family) { * @return a string representation of the function signature */ private static String getFamilySignature(List families) { - List> exprTypes = - families.stream().map(PPLTypeChecker::getExprTypes).collect(Collectors.toList()); - - // Do a cartesian product of all ExprTypes in the family - List> signatures = Lists.cartesianProduct(exprTypes); - + List> signatures = getExprSignatures(families); // Convert each signature to a string representation and then concatenate them - return getExprFamilySignature(signatures); - } - - private static String getExprFamilySignature(List> signatures) { - return signatures.stream() - .map( - types -> - "[" - + types.stream() - // Display ExprCoreType.UNDEFINED as "ANY" for better interpretability - .map(t -> t == ExprCoreType.UNDEFINED ? "ANY" : t.toString()) - .collect(Collectors.joining(",")) - + "]") - .collect(Collectors.joining(",")); + return formatExprSignatures(signatures); } /** @@ -511,4 +605,17 @@ private static boolean isCompositionOr(CompositeOperandTypeChecker typeChecker) (CompositeOperandTypeChecker.Composition) compositionField.get(typeChecker); return composition == CompositeOperandTypeChecker.Composition.OR; } + + private static String formatExprSignatures(List> signatures) { + return signatures.stream() + .map( + types -> + "[" + + types.stream() + // Display ExprCoreType.UNDEFINED as "ANY" for better interpretability + .map(t -> t == ExprCoreType.UNDEFINED ? "ANY" : t.toString()) + .collect(Collectors.joining(",")) + + "]") + .collect(Collectors.joining(",")); + } } diff --git a/core/src/main/java/org/opensearch/sql/expression/function/PatternParserFunctionImpl.java b/core/src/main/java/org/opensearch/sql/expression/function/PatternParserFunctionImpl.java index 8a8a5d071f2..b18b85e76b1 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/PatternParserFunctionImpl.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/PatternParserFunctionImpl.java @@ -52,9 +52,9 @@ public SqlReturnTypeInference getReturnTypeInference() { public UDFOperandMetadata getOperandMetadata() { return UDFOperandMetadata.wrap( (CompositeOperandTypeChecker) - OperandTypes.family(SqlTypeFamily.STRING, SqlTypeFamily.ANY) - .or(OperandTypes.family(SqlTypeFamily.STRING, SqlTypeFamily.STRING)) - .or(OperandTypes.family(SqlTypeFamily.STRING, SqlTypeFamily.ARRAY))); + OperandTypes.family(SqlTypeFamily.CHARACTER, SqlTypeFamily.ANY) + .or(OperandTypes.family(SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER)) + .or(OperandTypes.family(SqlTypeFamily.CHARACTER, SqlTypeFamily.ARRAY))); } public static class PatternParserImplementor implements NotNullImplementor { diff --git a/core/src/main/java/org/opensearch/sql/expression/function/udf/CryptographicFunction.java b/core/src/main/java/org/opensearch/sql/expression/function/udf/CryptographicFunction.java index dc0ae0eb5ae..d4739d5fadd 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/udf/CryptographicFunction.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/udf/CryptographicFunction.java @@ -12,14 +12,13 @@ import org.apache.calcite.linq4j.tree.Expression; import org.apache.calcite.linq4j.tree.Expressions; import org.apache.calcite.rex.RexCall; -import org.apache.calcite.sql.type.FamilyOperandTypeChecker; -import org.apache.calcite.sql.type.OperandTypes; import org.apache.calcite.sql.type.ReturnTypes; import org.apache.calcite.sql.type.SqlReturnTypeInference; import org.apache.calcite.sql.type.SqlTypeTransforms; import org.apache.commons.codec.binary.Hex; import org.apache.commons.codec.digest.DigestUtils; import org.apache.commons.codec.digest.MessageDigestAlgorithms; +import org.opensearch.sql.calcite.utils.PPLOperandTypes; import org.opensearch.sql.expression.function.ImplementorUDF; import org.opensearch.sql.expression.function.UDFOperandMetadata; @@ -32,7 +31,7 @@ public static CryptographicFunction sha2() { return new CryptographicFunction(new Sha2Implementor(), NullPolicy.ANY) { @Override public UDFOperandMetadata getOperandMetadata() { - return UDFOperandMetadata.wrap((FamilyOperandTypeChecker) OperandTypes.STRING_INTEGER); + return PPLOperandTypes.STRING_INTEGER; } }; } diff --git a/core/src/main/java/org/opensearch/sql/expression/function/udf/SpanFunction.java b/core/src/main/java/org/opensearch/sql/expression/function/udf/SpanFunction.java index b48636615ba..190e3a54f46 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/udf/SpanFunction.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/udf/SpanFunction.java @@ -52,10 +52,11 @@ public SqlReturnTypeInference getReturnTypeInference() { public UDFOperandMetadata getOperandMetadata() { return UDFOperandMetadata.wrap( (CompositeOperandTypeChecker) - OperandTypes.family(SqlTypeFamily.STRING, SqlTypeFamily.NUMERIC, SqlTypeFamily.STRING) + OperandTypes.family( + SqlTypeFamily.CHARACTER, SqlTypeFamily.NUMERIC, SqlTypeFamily.CHARACTER) .or( OperandTypes.family( - SqlTypeFamily.DATETIME, SqlTypeFamily.NUMERIC, SqlTypeFamily.STRING)) + SqlTypeFamily.DATETIME, SqlTypeFamily.NUMERIC, SqlTypeFamily.CHARACTER)) // TODO: numeric span should support decimal as its interval .or( OperandTypes.family( diff --git a/core/src/main/java/org/opensearch/sql/expression/function/udf/condition/EarliestFunction.java b/core/src/main/java/org/opensearch/sql/expression/function/udf/condition/EarliestFunction.java index dc6a687825a..b183ca30fbe 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/udf/condition/EarliestFunction.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/udf/condition/EarliestFunction.java @@ -5,7 +5,6 @@ package org.opensearch.sql.expression.function.udf.condition; -import static org.opensearch.sql.calcite.utils.PPLOperandTypes.STRING_TIMESTAMP; import static org.opensearch.sql.calcite.utils.UserDefinedFunctionUtils.prependFunctionProperties; import static org.opensearch.sql.utils.DateTimeUtils.getRelativeZonedDateTime; @@ -25,6 +24,7 @@ import org.apache.calcite.rex.RexNode; import org.apache.calcite.sql.type.ReturnTypes; import org.apache.calcite.sql.type.SqlReturnTypeInference; +import org.opensearch.sql.calcite.utils.PPLOperandTypes; import org.opensearch.sql.calcite.utils.UserDefinedFunctionUtils; import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.expression.function.FunctionProperties; @@ -43,7 +43,7 @@ public SqlReturnTypeInference getReturnTypeInference() { @Override public UDFOperandMetadata getOperandMetadata() { - return STRING_TIMESTAMP; + return PPLOperandTypes.STRING_TIMESTAMP; } public static class EarliestImplementor implements NotNullImplementor { diff --git a/core/src/main/java/org/opensearch/sql/expression/function/udf/condition/LatestFunction.java b/core/src/main/java/org/opensearch/sql/expression/function/udf/condition/LatestFunction.java index 2f7a968add1..7bc1eb3b316 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/udf/condition/LatestFunction.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/udf/condition/LatestFunction.java @@ -5,7 +5,6 @@ package org.opensearch.sql.expression.function.udf.condition; -import static org.opensearch.sql.calcite.utils.PPLOperandTypes.STRING_TIMESTAMP; import static org.opensearch.sql.calcite.utils.UserDefinedFunctionUtils.prependFunctionProperties; import static org.opensearch.sql.utils.DateTimeUtils.getRelativeZonedDateTime; @@ -25,6 +24,7 @@ import org.apache.calcite.rex.RexNode; import org.apache.calcite.sql.type.ReturnTypes; import org.apache.calcite.sql.type.SqlReturnTypeInference; +import org.opensearch.sql.calcite.utils.PPLOperandTypes; import org.opensearch.sql.calcite.utils.UserDefinedFunctionUtils; import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.expression.function.FunctionProperties; @@ -43,7 +43,7 @@ public SqlReturnTypeInference getReturnTypeInference() { @Override public UDFOperandMetadata getOperandMetadata() { - return STRING_TIMESTAMP; + return PPLOperandTypes.STRING_TIMESTAMP; } public static class LatestImplementor implements NotNullImplementor { diff --git a/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/AddSubDateFunction.java b/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/AddSubDateFunction.java index 53563f11971..7d335a68dc4 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/AddSubDateFunction.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/AddSubDateFunction.java @@ -76,10 +76,8 @@ public SqlReturnTypeInference getReturnTypeInference() { public UDFOperandMetadata getOperandMetadata() { return UDFOperandMetadata.wrap( (CompositeOperandTypeChecker) - OperandTypes.DATETIME_INTERVAL - .or(OperandTypes.family(SqlTypeFamily.DATETIME, SqlTypeFamily.INTEGER)) - .or(OperandTypes.family(SqlTypeFamily.STRING, SqlTypeFamily.DATETIME_INTERVAL)) - .or(OperandTypes.STRING_INTEGER)); + OperandTypes.DATETIME_INTERVAL.or( + OperandTypes.family(SqlTypeFamily.DATETIME, SqlTypeFamily.INTEGER))); } @RequiredArgsConstructor diff --git a/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/DateAddSubFunction.java b/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/DateAddSubFunction.java index c4a8c0f44b0..de3a2cb65ab 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/DateAddSubFunction.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/DateAddSubFunction.java @@ -15,11 +15,9 @@ import org.apache.calcite.linq4j.tree.Expressions; import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rex.RexCall; -import org.apache.calcite.sql.type.CompositeOperandTypeChecker; -import org.apache.calcite.sql.type.OperandTypes; import org.apache.calcite.sql.type.SqlReturnTypeInference; -import org.apache.calcite.sql.type.SqlTypeFamily; import org.opensearch.sql.calcite.utils.OpenSearchTypeFactory; +import org.opensearch.sql.calcite.utils.PPLOperandTypes; import org.opensearch.sql.calcite.utils.PPLReturnTypes; import org.opensearch.sql.calcite.utils.UserDefinedFunctionUtils; import org.opensearch.sql.calcite.utils.datetime.DateTimeConversionUtils; @@ -50,10 +48,7 @@ public SqlReturnTypeInference getReturnTypeInference() { @Override public UDFOperandMetadata getOperandMetadata() { - return UDFOperandMetadata.wrap( - (CompositeOperandTypeChecker) - OperandTypes.DATETIME_INTERVAL.or( - OperandTypes.family(SqlTypeFamily.STRING, SqlTypeFamily.DATETIME_INTERVAL))); + return PPLOperandTypes.DATETIME_INTERVAL; } @RequiredArgsConstructor diff --git a/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/DatetimeFunction.java b/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/DatetimeFunction.java index 14866a6eadb..d7a876cfeaa 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/DatetimeFunction.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/DatetimeFunction.java @@ -50,9 +50,9 @@ public UDFOperandMetadata getOperandMetadata() { return UDFOperandMetadata.wrap( (CompositeOperandTypeChecker) OperandTypes.TIMESTAMP_STRING - .or(OperandTypes.STRING_STRING) + .or(OperandTypes.CHARACTER_CHARACTER) .or(OperandTypes.TIMESTAMP) - .or(OperandTypes.STRING)); + .or(OperandTypes.CHARACTER)); } public static class DatetimeImplementor implements NotNullImplementor { diff --git a/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/ExtractFunction.java b/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/ExtractFunction.java index 76c0c27d825..90b7c8a783f 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/ExtractFunction.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/ExtractFunction.java @@ -12,12 +12,10 @@ import org.apache.calcite.linq4j.tree.Expression; import org.apache.calcite.linq4j.tree.Expressions; import org.apache.calcite.rex.RexCall; -import org.apache.calcite.sql.type.CompositeOperandTypeChecker; -import org.apache.calcite.sql.type.OperandTypes; import org.apache.calcite.sql.type.ReturnTypes; import org.apache.calcite.sql.type.SqlReturnTypeInference; -import org.apache.calcite.sql.type.SqlTypeFamily; import org.opensearch.sql.calcite.utils.OpenSearchTypeFactory; +import org.opensearch.sql.calcite.utils.PPLOperandTypes; import org.opensearch.sql.calcite.utils.UserDefinedFunctionUtils; import org.opensearch.sql.data.model.ExprStringValue; import org.opensearch.sql.data.model.ExprValue; @@ -51,10 +49,7 @@ public SqlReturnTypeInference getReturnTypeInference() { @Override public UDFOperandMetadata getOperandMetadata() { - return UDFOperandMetadata.wrap( - (CompositeOperandTypeChecker) - OperandTypes.family(SqlTypeFamily.STRING, SqlTypeFamily.DATETIME) - .or(OperandTypes.STRING_STRING)); + return PPLOperandTypes.STRING_DATETIME; } public static class ExtractImplementor implements NotNullImplementor { diff --git a/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/FormatFunction.java b/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/FormatFunction.java index 3c38ad95a93..5cf658fb056 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/FormatFunction.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/FormatFunction.java @@ -17,11 +17,9 @@ import org.apache.calcite.linq4j.tree.Expression; import org.apache.calcite.linq4j.tree.Expressions; import org.apache.calcite.rex.RexCall; -import org.apache.calcite.sql.type.CompositeOperandTypeChecker; -import org.apache.calcite.sql.type.OperandTypes; import org.apache.calcite.sql.type.SqlReturnTypeInference; -import org.apache.calcite.sql.type.SqlTypeFamily; import org.opensearch.sql.calcite.utils.OpenSearchTypeFactory; +import org.opensearch.sql.calcite.utils.PPLOperandTypes; import org.opensearch.sql.calcite.utils.PPLReturnTypes; import org.opensearch.sql.calcite.utils.UserDefinedFunctionUtils; import org.opensearch.sql.data.model.ExprStringValue; @@ -60,10 +58,7 @@ public SqlReturnTypeInference getReturnTypeInference() { @Override public UDFOperandMetadata getOperandMetadata() { - return UDFOperandMetadata.wrap( - (CompositeOperandTypeChecker) - OperandTypes.family(SqlTypeFamily.DATETIME, SqlTypeFamily.STRING) - .or(OperandTypes.STRING_STRING)); + return PPLOperandTypes.DATETIME_OR_STRING_STRING; } @RequiredArgsConstructor diff --git a/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/FromUnixTimeFunction.java b/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/FromUnixTimeFunction.java index 09b59fd96e3..00fa9f690da 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/FromUnixTimeFunction.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/FromUnixTimeFunction.java @@ -18,10 +18,8 @@ import org.apache.calcite.linq4j.tree.Expression; import org.apache.calcite.linq4j.tree.Expressions; import org.apache.calcite.rex.RexCall; -import org.apache.calcite.sql.type.CompositeOperandTypeChecker; -import org.apache.calcite.sql.type.OperandTypes; import org.apache.calcite.sql.type.SqlReturnTypeInference; -import org.apache.calcite.sql.type.SqlTypeFamily; +import org.opensearch.sql.calcite.utils.PPLOperandTypes; import org.opensearch.sql.data.model.ExprDoubleValue; import org.opensearch.sql.data.model.ExprStringValue; import org.opensearch.sql.expression.function.ImplementorUDF; @@ -56,10 +54,7 @@ public SqlReturnTypeInference getReturnTypeInference() { @Override public UDFOperandMetadata getOperandMetadata() { - return UDFOperandMetadata.wrap( - (CompositeOperandTypeChecker) - OperandTypes.NUMERIC.or( - OperandTypes.family(SqlTypeFamily.NUMERIC, SqlTypeFamily.STRING))); + return PPLOperandTypes.NUMERIC_OPTIONAL_STRING; } public static class FromUnixTimeImplementor implements NotNullImplementor { diff --git a/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/LastDayFunction.java b/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/LastDayFunction.java index 4b8e5d9b0e3..28fdb6b978d 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/LastDayFunction.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/LastDayFunction.java @@ -12,10 +12,9 @@ import org.apache.calcite.linq4j.tree.Expression; import org.apache.calcite.linq4j.tree.Expressions; import org.apache.calcite.rex.RexCall; -import org.apache.calcite.sql.type.CompositeOperandTypeChecker; -import org.apache.calcite.sql.type.OperandTypes; import org.apache.calcite.sql.type.SqlReturnTypeInference; import org.opensearch.sql.calcite.utils.OpenSearchTypeFactory; +import org.opensearch.sql.calcite.utils.PPLOperandTypes; import org.opensearch.sql.calcite.utils.PPLReturnTypes; import org.opensearch.sql.calcite.utils.UserDefinedFunctionUtils; import org.opensearch.sql.calcite.utils.datetime.DateTimeConversionUtils; @@ -48,8 +47,7 @@ public SqlReturnTypeInference getReturnTypeInference() { @Override public UDFOperandMetadata getOperandMetadata() { - return UDFOperandMetadata.wrap( - (CompositeOperandTypeChecker) OperandTypes.DATETIME.or(OperandTypes.STRING)); + return PPLOperandTypes.DATETIME_OR_STRING; } public static class LastDayImplementor implements NotNullImplementor { diff --git a/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/PeriodNameFunction.java b/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/PeriodNameFunction.java index 3316c3d82e0..d510c864177 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/PeriodNameFunction.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/PeriodNameFunction.java @@ -16,10 +16,9 @@ import org.apache.calcite.linq4j.tree.Expression; import org.apache.calcite.linq4j.tree.Expressions; import org.apache.calcite.rex.RexCall; -import org.apache.calcite.sql.type.CompositeOperandTypeChecker; -import org.apache.calcite.sql.type.OperandTypes; import org.apache.calcite.sql.type.SqlReturnTypeInference; import org.opensearch.sql.calcite.utils.OpenSearchTypeFactory; +import org.opensearch.sql.calcite.utils.PPLOperandTypes; import org.opensearch.sql.calcite.utils.PPLReturnTypes; import org.opensearch.sql.data.model.ExprDateValue; import org.opensearch.sql.data.type.ExprType; @@ -52,9 +51,7 @@ public SqlReturnTypeInference getReturnTypeInference() { @Override public UDFOperandMetadata getOperandMetadata() { - return UDFOperandMetadata.wrap( - (CompositeOperandTypeChecker) - OperandTypes.DATE.or(OperandTypes.TIMESTAMP).or(OperandTypes.STRING)); + return PPLOperandTypes.DATE_OR_TIMESTAMP_OR_STRING; } public static class PeriodNameFunctionImplementor implements NotNullImplementor { diff --git a/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/TimestampAddFunction.java b/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/TimestampAddFunction.java index 9119013ac21..29de7881fa6 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/TimestampAddFunction.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/TimestampAddFunction.java @@ -16,11 +16,9 @@ import org.apache.calcite.linq4j.tree.Expression; import org.apache.calcite.linq4j.tree.Expressions; import org.apache.calcite.rex.RexCall; -import org.apache.calcite.sql.type.CompositeOperandTypeChecker; -import org.apache.calcite.sql.type.OperandTypes; import org.apache.calcite.sql.type.SqlReturnTypeInference; -import org.apache.calcite.sql.type.SqlTypeFamily; import org.opensearch.sql.calcite.utils.OpenSearchTypeFactory; +import org.opensearch.sql.calcite.utils.PPLOperandTypes; import org.opensearch.sql.calcite.utils.PPLReturnTypes; import org.opensearch.sql.calcite.utils.UserDefinedFunctionUtils; import org.opensearch.sql.data.model.ExprLongValue; @@ -58,12 +56,7 @@ public SqlReturnTypeInference getReturnTypeInference() { @Override public UDFOperandMetadata getOperandMetadata() { - return UDFOperandMetadata.wrap( - (CompositeOperandTypeChecker) - OperandTypes.family(SqlTypeFamily.STRING, SqlTypeFamily.INTEGER, SqlTypeFamily.STRING) - .or( - OperandTypes.family( - SqlTypeFamily.STRING, SqlTypeFamily.INTEGER, SqlTypeFamily.DATETIME))); + return PPLOperandTypes.STRING_INTEGER_DATETIME_OR_STRING; } public static class TimestampAddImplementor implements NotNullImplementor { diff --git a/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/TimestampDiffFunction.java b/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/TimestampDiffFunction.java index c9ca2e1d4f5..295fda44013 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/TimestampDiffFunction.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/TimestampDiffFunction.java @@ -15,12 +15,10 @@ import org.apache.calcite.linq4j.tree.Expression; import org.apache.calcite.linq4j.tree.Expressions; import org.apache.calcite.rex.RexCall; -import org.apache.calcite.sql.type.CompositeOperandTypeChecker; -import org.apache.calcite.sql.type.OperandTypes; import org.apache.calcite.sql.type.ReturnTypes; import org.apache.calcite.sql.type.SqlReturnTypeInference; -import org.apache.calcite.sql.type.SqlTypeFamily; import org.opensearch.sql.calcite.utils.OpenSearchTypeFactory; +import org.opensearch.sql.calcite.utils.PPLOperandTypes; import org.opensearch.sql.calcite.utils.UserDefinedFunctionUtils; import org.opensearch.sql.data.model.ExprStringValue; import org.opensearch.sql.data.model.ExprValue; @@ -55,19 +53,7 @@ public SqlReturnTypeInference getReturnTypeInference() { @Override public UDFOperandMetadata getOperandMetadata() { - return UDFOperandMetadata.wrap( - (CompositeOperandTypeChecker) - OperandTypes.family( - SqlTypeFamily.STRING, SqlTypeFamily.DATETIME, SqlTypeFamily.DATETIME) - .or( - OperandTypes.family( - SqlTypeFamily.STRING, SqlTypeFamily.STRING, SqlTypeFamily.DATETIME)) - .or( - OperandTypes.family( - SqlTypeFamily.STRING, SqlTypeFamily.DATETIME, SqlTypeFamily.STRING)) - .or( - OperandTypes.family( - SqlTypeFamily.STRING, SqlTypeFamily.STRING, SqlTypeFamily.STRING))); + return PPLOperandTypes.INTERVAL_DATETIME_DATETIME; } public static class DiffImplementor implements NotNullImplementor { diff --git a/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/TimestampFunction.java b/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/TimestampFunction.java index 6041c4d5648..a13a1ce1894 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/TimestampFunction.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/TimestampFunction.java @@ -51,12 +51,12 @@ public SqlReturnTypeInference getReturnTypeInference() { public UDFOperandMetadata getOperandMetadata() { return UDFOperandMetadata.wrap( (CompositeOperandTypeChecker) - OperandTypes.STRING + OperandTypes.CHARACTER .or(OperandTypes.DATETIME) - .or(OperandTypes.STRING_STRING) + .or(OperandTypes.CHARACTER_CHARACTER) .or(OperandTypes.family(SqlTypeFamily.DATETIME, SqlTypeFamily.DATETIME)) - .or(OperandTypes.family(SqlTypeFamily.STRING, SqlTypeFamily.DATETIME)) - .or(OperandTypes.family(SqlTypeFamily.DATETIME, SqlTypeFamily.STRING))); + .or(OperandTypes.family(SqlTypeFamily.CHARACTER, SqlTypeFamily.DATETIME)) + .or(OperandTypes.family(SqlTypeFamily.DATETIME, SqlTypeFamily.CHARACTER))); } public static class TimestampImplementor implements NotNullImplementor { diff --git a/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/ToSecondsFunction.java b/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/ToSecondsFunction.java index 75e4a0d7228..7cd6d62063f 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/ToSecondsFunction.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/ToSecondsFunction.java @@ -19,10 +19,9 @@ import org.apache.calcite.linq4j.tree.Expression; import org.apache.calcite.linq4j.tree.Expressions; import org.apache.calcite.rex.RexCall; -import org.apache.calcite.sql.type.CompositeOperandTypeChecker; -import org.apache.calcite.sql.type.OperandTypes; import org.apache.calcite.sql.type.ReturnTypes; import org.apache.calcite.sql.type.SqlReturnTypeInference; +import org.opensearch.sql.calcite.utils.PPLOperandTypes; import org.opensearch.sql.data.model.ExprLongValue; import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.data.type.ExprCoreType; @@ -54,9 +53,7 @@ public SqlReturnTypeInference getReturnTypeInference() { @Override public UDFOperandMetadata getOperandMetadata() { - return UDFOperandMetadata.wrap( - (CompositeOperandTypeChecker) - OperandTypes.DATETIME.or(OperandTypes.STRING).or(OperandTypes.INTEGER)); + return PPLOperandTypes.DATETIME_OR_STRING_OR_INTEGER; } public static class ToSecondsImplementor implements NotNullImplementor { diff --git a/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/UnixTimestampFunction.java b/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/UnixTimestampFunction.java index 509fee1cdca..5103173b46f 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/UnixTimestampFunction.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/UnixTimestampFunction.java @@ -15,10 +15,9 @@ import org.apache.calcite.linq4j.tree.Expression; import org.apache.calcite.linq4j.tree.Expressions; import org.apache.calcite.rex.RexCall; -import org.apache.calcite.sql.type.CompositeOperandTypeChecker; -import org.apache.calcite.sql.type.OperandTypes; import org.apache.calcite.sql.type.ReturnTypes; import org.apache.calcite.sql.type.SqlReturnTypeInference; +import org.opensearch.sql.calcite.utils.PPLOperandTypes; import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.expression.function.FunctionProperties; import org.opensearch.sql.expression.function.ImplementorUDF; @@ -50,9 +49,7 @@ public SqlReturnTypeInference getReturnTypeInference() { @Override public UDFOperandMetadata getOperandMetadata() { - return UDFOperandMetadata.wrap( - (CompositeOperandTypeChecker) - OperandTypes.DATETIME.or(OperandTypes.NUMERIC).or(OperandTypes.family())); + return PPLOperandTypes.OPTIONAL_DATE_OR_TIMESTAMP_OR_NUMERIC; } public static class UnixTimestampImplementor implements NotNullImplementor { diff --git a/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/WeekFunction.java b/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/WeekFunction.java index b4b1eecbcbc..11564ce47f0 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/WeekFunction.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/WeekFunction.java @@ -14,10 +14,8 @@ import org.apache.calcite.linq4j.tree.Expression; import org.apache.calcite.linq4j.tree.Expressions; import org.apache.calcite.rex.RexCall; -import org.apache.calcite.sql.type.CompositeOperandTypeChecker; -import org.apache.calcite.sql.type.OperandTypes; import org.apache.calcite.sql.type.SqlReturnTypeInference; -import org.apache.calcite.sql.type.SqlTypeFamily; +import org.opensearch.sql.calcite.utils.PPLOperandTypes; import org.opensearch.sql.data.model.ExprDateValue; import org.opensearch.sql.data.model.ExprIntegerValue; import org.opensearch.sql.data.model.ExprValue; @@ -52,12 +50,7 @@ public SqlReturnTypeInference getReturnTypeInference() { @Override public UDFOperandMetadata getOperandMetadata() { - return UDFOperandMetadata.wrap( - (CompositeOperandTypeChecker) - OperandTypes.DATETIME - .or(OperandTypes.STRING) - .or(OperandTypes.STRING_INTEGER) - .or(OperandTypes.family(SqlTypeFamily.DATETIME, SqlTypeFamily.INTEGER))); + return PPLOperandTypes.DATETIME_OPTIONAL_INTEGER; } public static class WeekImplementor implements NotNullImplementor { diff --git a/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/WeekdayFunction.java b/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/WeekdayFunction.java index aaa276c8316..7c09645e8e9 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/WeekdayFunction.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/WeekdayFunction.java @@ -16,11 +16,9 @@ import org.apache.calcite.linq4j.tree.Expression; import org.apache.calcite.linq4j.tree.Expressions; import org.apache.calcite.rex.RexCall; -import org.apache.calcite.sql.type.CompositeOperandTypeChecker; -import org.apache.calcite.sql.type.OperandTypes; import org.apache.calcite.sql.type.SqlReturnTypeInference; -import org.apache.calcite.sql.type.SqlTypeFamily; import org.opensearch.sql.calcite.utils.OpenSearchTypeFactory; +import org.opensearch.sql.calcite.utils.PPLOperandTypes; import org.opensearch.sql.calcite.utils.PPLReturnTypes; import org.opensearch.sql.calcite.utils.UserDefinedFunctionUtils; import org.opensearch.sql.data.model.ExprValue; @@ -53,11 +51,7 @@ public SqlReturnTypeInference getReturnTypeInference() { @Override public UDFOperandMetadata getOperandMetadata() { - return UDFOperandMetadata.wrap( - (CompositeOperandTypeChecker) - OperandTypes.DATETIME - .or(OperandTypes.STRING) - .or(OperandTypes.family(SqlTypeFamily.DATETIME, SqlTypeFamily.INTEGER))); + return PPLOperandTypes.DATETIME_OPTIONAL_INTEGER; } public static class WeekdayImplementor implements NotNullImplementor { diff --git a/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/YearweekFunction.java b/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/YearweekFunction.java index 0dd9f5e84bd..9bfe30df39c 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/YearweekFunction.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/udf/datetime/YearweekFunction.java @@ -16,10 +16,8 @@ import org.apache.calcite.linq4j.tree.Expression; import org.apache.calcite.linq4j.tree.Expressions; import org.apache.calcite.rex.RexCall; -import org.apache.calcite.sql.type.CompositeOperandTypeChecker; -import org.apache.calcite.sql.type.OperandTypes; import org.apache.calcite.sql.type.SqlReturnTypeInference; -import org.apache.calcite.sql.type.SqlTypeFamily; +import org.opensearch.sql.calcite.utils.PPLOperandTypes; import org.opensearch.sql.calcite.utils.PPLReturnTypes; import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.data.model.ExprValueUtils; @@ -52,12 +50,7 @@ public SqlReturnTypeInference getReturnTypeInference() { @Override public UDFOperandMetadata getOperandMetadata() { - return UDFOperandMetadata.wrap( - (CompositeOperandTypeChecker) - OperandTypes.DATETIME - .or(OperandTypes.STRING) - .or(OperandTypes.STRING_INTEGER) - .or(OperandTypes.family(SqlTypeFamily.DATETIME, SqlTypeFamily.INTEGER))); + return PPLOperandTypes.DATETIME_OPTIONAL_INTEGER; } public static class YearweekImplementor implements NotNullImplementor { diff --git a/core/src/main/java/org/opensearch/sql/expression/function/udf/ip/CompareIpFunction.java b/core/src/main/java/org/opensearch/sql/expression/function/udf/ip/CompareIpFunction.java index 21b4ccfb579..07550829b16 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/udf/ip/CompareIpFunction.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/udf/ip/CompareIpFunction.java @@ -9,6 +9,7 @@ import org.apache.calcite.adapter.enumerable.NotNullImplementor; import org.apache.calcite.adapter.enumerable.NullPolicy; import org.apache.calcite.adapter.enumerable.RexToLixTranslator; +import org.apache.calcite.linq4j.tree.ConstantExpression; import org.apache.calcite.linq4j.tree.Expression; import org.apache.calcite.linq4j.tree.Expressions; import org.apache.calcite.rex.RexCall; @@ -67,11 +68,7 @@ public SqlReturnTypeInference getReturnTypeInference() { @Override public UDFOperandMetadata getOperandMetadata() { - return UDFOperandMetadata.wrapUDT( - List.of( - List.of(ExprCoreType.IP, ExprCoreType.IP), - List.of(ExprCoreType.IP, ExprCoreType.STRING), - List.of(ExprCoreType.STRING, ExprCoreType.IP))); + return UDFOperandMetadata.wrapUDT(List.of(List.of(ExprCoreType.IP, ExprCoreType.IP))); } public static class CompareImplementor implements NotNullImplementor { @@ -96,19 +93,20 @@ public Expression implement( private static Expression generateComparisonExpression( Expression compareResult, ComparisonType comparisonType) { + final ConstantExpression zero = Expressions.constant(0); switch (comparisonType) { case EQUALS: - return Expressions.equal(compareResult, Expressions.constant(0)); + return Expressions.equal(compareResult, zero); case NOT_EQUALS: - return Expressions.notEqual(compareResult, Expressions.constant(0)); + return Expressions.notEqual(compareResult, zero); case LESS: - return Expressions.lessThan(compareResult, Expressions.constant(0)); + return Expressions.lessThan(compareResult, zero); case LESS_OR_EQUAL: - return Expressions.lessThanOrEqual(compareResult, Expressions.constant(0)); + return Expressions.lessThanOrEqual(compareResult, zero); case GREATER: - return Expressions.greaterThan(compareResult, Expressions.constant(0)); + return Expressions.greaterThan(compareResult, zero); case GREATER_OR_EQUAL: - return Expressions.greaterThanOrEqual(compareResult, Expressions.constant(0)); + return Expressions.greaterThanOrEqual(compareResult, zero); default: throw new IllegalArgumentException("Unexpected comparison type: " + comparisonType); } diff --git a/core/src/main/java/org/opensearch/sql/expression/function/udf/math/ConvFunction.java b/core/src/main/java/org/opensearch/sql/expression/function/udf/math/ConvFunction.java index 2044e668e7d..fea64b6bc6e 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/udf/math/ConvFunction.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/udf/math/ConvFunction.java @@ -14,10 +14,9 @@ import org.apache.calcite.linq4j.tree.Expressions; import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rex.RexCall; -import org.apache.calcite.sql.type.CompositeOperandTypeChecker; -import org.apache.calcite.sql.type.OperandTypes; import org.apache.calcite.sql.type.SqlReturnTypeInference; import org.apache.calcite.sql.type.SqlTypeFamily; +import org.opensearch.sql.calcite.utils.PPLOperandTypes; import org.opensearch.sql.calcite.utils.PPLReturnTypes; import org.opensearch.sql.expression.function.ImplementorUDF; import org.opensearch.sql.expression.function.UDFOperandMetadata; @@ -40,11 +39,7 @@ public SqlReturnTypeInference getReturnTypeInference() { @Override public UDFOperandMetadata getOperandMetadata() { - return UDFOperandMetadata.wrap( - (CompositeOperandTypeChecker) - OperandTypes.STRING_INTEGER_INTEGER.or( - OperandTypes.family( - SqlTypeFamily.INTEGER, SqlTypeFamily.INTEGER, SqlTypeFamily.INTEGER))); + return PPLOperandTypes.STRING_OR_INTEGER_INTEGER_INTEGER; } public static class ConvImplementor implements NotNullImplementor { diff --git a/core/src/main/java/org/opensearch/sql/utils/DateTimeUtils.java b/core/src/main/java/org/opensearch/sql/utils/DateTimeUtils.java index 428ba4accab..6b7e64a8865 100644 --- a/core/src/main/java/org/opensearch/sql/utils/DateTimeUtils.java +++ b/core/src/main/java/org/opensearch/sql/utils/DateTimeUtils.java @@ -5,10 +5,6 @@ package org.opensearch.sql.utils; -import static org.opensearch.sql.calcite.utils.OpenSearchTypeFactory.ExprUDT.EXPR_DATE; -import static org.opensearch.sql.calcite.utils.OpenSearchTypeFactory.ExprUDT.EXPR_TIME; -import static org.opensearch.sql.calcite.utils.OpenSearchTypeFactory.ExprUDT.EXPR_TIMESTAMP; - import java.time.Instant; import java.time.LocalDate; import java.time.LocalDateTime; @@ -19,18 +15,11 @@ import java.time.format.DateTimeParseException; import java.time.temporal.ChronoUnit; import java.util.Locale; -import java.util.Objects; import java.util.regex.Pattern; import lombok.experimental.UtilityClass; -import org.apache.calcite.rex.RexNode; -import org.apache.calcite.sql.type.SqlTypeName; -import org.opensearch.sql.calcite.CalcitePlanContext; -import org.opensearch.sql.calcite.type.ExprSqlType; -import org.opensearch.sql.calcite.utils.OpenSearchTypeFactory; import org.opensearch.sql.data.model.ExprTimeValue; import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.expression.function.FunctionProperties; -import org.opensearch.sql.expression.function.PPLBuiltinOperators; @UtilityClass public class DateTimeUtils { @@ -368,82 +357,4 @@ private static String normalizeUnit(String rawUnit) { throw new IllegalArgumentException("Unsupported unit alias: " + rawUnit); } } - - /** - * The function add cast for date-related target node - * - * @param candidate The candidate node - * @param context calcite context - * @param castTarget the target cast type - * @return the rexnode after casting - */ - public static RexNode transferCompareForDateRelated( - RexNode candidate, CalcitePlanContext context, SqlTypeName castTarget) { - if (!(Objects.isNull(castTarget))) { - switch (castTarget) { - case DATE: - if (!(candidate.getType() instanceof ExprSqlType - && ((ExprSqlType) candidate.getType()).getUdt() == EXPR_DATE)) { - return context.rexBuilder.makeCall(PPLBuiltinOperators.DATE, candidate); - } - break; - case TIME: - if (!(candidate.getType() instanceof ExprSqlType - && ((ExprSqlType) candidate.getType()).getUdt() == EXPR_TIME)) { - return context.rexBuilder.makeCall(PPLBuiltinOperators.TIME, candidate); - } - break; - case TIMESTAMP: - if (!(candidate.getType() instanceof ExprSqlType - && ((ExprSqlType) candidate.getType()).getUdt() == EXPR_TIMESTAMP)) { - return context.rexBuilder.makeCall(PPLBuiltinOperators.TIMESTAMP, candidate); - } - break; - default: - return candidate; - } - } - return candidate; - } - - /** - * The function find the target cast type according to the left and right node. When the two node - * are both related to date with different type, cast to timestamp - * - * @param left - * @param right - * @return - */ - public static SqlTypeName findCastType(RexNode left, RexNode right) { - SqlTypeName leftType = returnCorrespondingSqlType(left); - SqlTypeName rightType = returnCorrespondingSqlType(right); - if (leftType != null && rightType != null && rightType != leftType) { - return SqlTypeName.TIMESTAMP; - } - return leftType == null ? rightType : leftType; - } - - /** - * Find corresponding cast type according to the node's type. If they're not related to the date, - * return null - * - * @param node the candidate node - * @return the sql type name - */ - public static SqlTypeName returnCorrespondingSqlType(RexNode node) { - if (node.getType() instanceof ExprSqlType) { - OpenSearchTypeFactory.ExprUDT udt = ((ExprSqlType) node.getType()).getUdt(); - switch (udt) { - case EXPR_DATE: - return SqlTypeName.DATE; - case EXPR_TIME: - return SqlTypeName.TIME; - case EXPR_TIMESTAMP: - return SqlTypeName.TIMESTAMP; - default: - return null; - } - } - return null; - } } diff --git a/integ-test/src/test/java/org/opensearch/sql/ppl/ExplainIT.java b/integ-test/src/test/java/org/opensearch/sql/ppl/ExplainIT.java index 95a46fc2dc8..cfed5a3384e 100644 --- a/integ-test/src/test/java/org/opensearch/sql/ppl/ExplainIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/ppl/ExplainIT.java @@ -6,10 +6,16 @@ package org.opensearch.sql.ppl; import static org.hamcrest.Matchers.containsString; +import static org.opensearch.sql.legacy.TestsConstants.TEST_INDEX_ACCOUNT; import static org.opensearch.sql.legacy.TestsConstants.TEST_INDEX_BANK; +import static org.opensearch.sql.legacy.TestsConstants.TEST_INDEX_WEBLOGS; import static org.opensearch.sql.util.MatcherUtils.assertJsonEqualsIgnoreId; import java.io.IOException; +import java.util.Locale; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + import org.junit.Ignore; import org.junit.jupiter.api.Test; import org.opensearch.client.ResponseException; @@ -23,6 +29,7 @@ public void init() throws Exception { loadIndex(Index.ACCOUNT); loadIndex(Index.BANK); loadIndex(Index.DATE_FORMATS); + loadIndex(Index.WEBLOG); } @Test @@ -87,6 +94,68 @@ public void testFilterByCompareStringTimePushDownExplain() throws IOException { + "| where custom_time < '2018-11-09 19:00:00.123456789' ")); } + @Test + public void testFilterByCompareIPCoercion() throws IOException { + // Should automatically cast the string literal to IP. + String expected = loadExpectedPlan("explain_filter_compare_ip.json"); + // The index of host is flaky (different from test to test) + assertJsonEqualsIgnoreFieldIndex( + expected, + explainQueryToString( + String.format( + Locale.ROOT, + "source=%s | where host > '1.1.1.1' | fields host", + TEST_INDEX_WEBLOGS))); + } + + private static void assertJsonEqualsIgnoreFieldIndex(String expected, String actual) throws IOException { + String reorderedExpected = maskIndexAndReorderProject(expected); + String reorderedActual = maskIndexAndReorderProject(actual); + assertJsonEqualsIgnoreId(reorderedExpected, reorderedActual); + } + + private static String maskIndexAndReorderProject(String plan) { + // Replace $number or $tnumber with * + Pattern pattern = Pattern.compile("\\$t?(\\d+)"); + Matcher matcher = pattern.matcher(plan); + StringBuilder sb = new StringBuilder(); + while (matcher.find()) { + matcher.appendReplacement(sb, "*"); + } + matcher.appendTail(sb); + String maskedPlan = sb.toString(); + // Reorder logical projects: LogicalProject(b, c, a) -> LogicalProject(a, b, c) + Pattern projectPattern = Pattern.compile("LogicalProject\\(([^)]*)\\)"); + Matcher projectMatcher = projectPattern.matcher(maskedPlan); + StringBuilder result = new StringBuilder(); + while (projectMatcher.find()) { + String fields = projectMatcher.group(1); + String[] fieldArr = fields.split(","); + for (int i = 0; i < fieldArr.length; i++) { + fieldArr[i] = fieldArr[i].trim(); + } + java.util.Arrays.sort(fieldArr, String.CASE_INSENSITIVE_ORDER); + String sortedFields = String.join(", ", fieldArr); + projectMatcher.appendReplacement(result, "LogicalProject(" + sortedFields + ")"); + } + projectMatcher.appendTail(result); + return result.toString(); + } + + @Test + public void testWeekArgumentCoercion() throws IOException { + String expected = loadExpectedPlan("explain_week_argument_coercion.json"); + // Week accepts WEEK(timestamp/date/time, [optional int]), it should cast the string + // argument to timestamp with Calcite. In v2, it accepts string, so there is no cast. + assertJsonEqualsIgnoreId( + expected, + explainQueryToString( + String.format( + Locale.ROOT, + "source=%s | eval w = week('2024-12-10') | fields w", + TEST_INDEX_ACCOUNT))); + } + @Test public void testFilterAndAggPushDownExplain() throws IOException { String expected = loadExpectedPlan("explain_filter_agg_push.json"); diff --git a/integ-test/src/test/resources/expectedOutput/calcite/explain_filter_compare_ip.json b/integ-test/src/test/resources/expectedOutput/calcite/explain_filter_compare_ip.json index 0eecbedb31d..416ab90eb79 100644 --- a/integ-test/src/test/resources/expectedOutput/calcite/explain_filter_compare_ip.json +++ b/integ-test/src/test/resources/expectedOutput/calcite/explain_filter_compare_ip.json @@ -1,6 +1,6 @@ { "calcite": { - "logical": "LogicalSystemLimit(fetch=[10000], type=[QUERY_SIZE_LIMIT])\n LogicalProject(host=[$0])\n LogicalFilter(condition=[GREATER_IP($0, IP('1.1.1.1':VARCHAR))])\n CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_weblogs]])\n", - "physical": "CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_weblogs]], PushDownContext=[[PROJECT->[host], SCRIPT->GREATER_IP($0, IP('1.1.1.1':VARCHAR)), LIMIT->10000], OpenSearchRequestBuilder(sourceBuilder={\"from\":0,\"size\":10000,\"timeout\":\"1m\",\"query\":{\"script\":{\"script\":{\"source\":\"{\\\"langType\\\":\\\"calcite\\\",\\\"script\\\":\\\"rO0ABXNyABFqYXZhLnV0aWwuQ29sbFNlcleOq7Y6G6gRAwABSQADdGFneHAAAAADdwQAAAAGdAAHcm93VHlwZXQAensKICAiZmllbGRzIjogWwogICAgewogICAgICAidHlwZSI6ICJPVEhFUiIsCiAgICAgICJudWxsYWJsZSI6IHRydWUsCiAgICAgICJuYW1lIjogImhvc3QiCiAgICB9CiAgXSwKICAibnVsbGFibGUiOiBmYWxzZQp9dAAEZXhwcnQDfXsKICAib3AiOiB7CiAgICAibmFtZSI6ICJHUkVBVEVSX0lQIiwKICAgICJraW5kIjogIk9USEVSX0ZVTkNUSU9OIiwKICAgICJzeW50YXgiOiAiRlVOQ1RJT04iCiAgfSwKICAib3BlcmFuZHMiOiBbCiAgICB7CiAgICAgICJpbnB1dCI6IDAsCiAgICAgICJuYW1lIjogIiQwIgogICAgfSwKICAgIHsKICAgICAgIm9wIjogewogICAgICAgICJuYW1lIjogIklQIiwKICAgICAgICAia2luZCI6ICJPVEhFUl9GVU5DVElPTiIsCiAgICAgICAgInN5bnRheCI6ICJGVU5DVElPTiIKICAgICAgfSwKICAgICAgIm9wZXJhbmRzIjogWwogICAgICAgIHsKICAgICAgICAgICJsaXRlcmFsIjogIjEuMS4xLjEiLAogICAgICAgICAgInR5cGUiOiB7CiAgICAgICAgICAgICJ0eXBlIjogIlZBUkNIQVIiLAogICAgICAgICAgICAibnVsbGFibGUiOiBmYWxzZSwKICAgICAgICAgICAgInByZWNpc2lvbiI6IC0xCiAgICAgICAgICB9CiAgICAgICAgfQogICAgICBdLAogICAgICAiY2xhc3MiOiAib3JnLm9wZW5zZWFyY2guc3FsLmV4cHJlc3Npb24uZnVuY3Rpb24uVXNlckRlZmluZWRGdW5jdGlvbkJ1aWxkZXIkMSIsCiAgICAgICJ0eXBlIjogewogICAgICAgICJ0eXBlIjogIk9USEVSIiwKICAgICAgICAibnVsbGFibGUiOiB0cnVlCiAgICAgIH0sCiAgICAgICJkZXRlcm1pbmlzdGljIjogdHJ1ZSwKICAgICAgImR5bmFtaWMiOiBmYWxzZQogICAgfQogIF0sCiAgImNsYXNzIjogIm9yZy5vcGVuc2VhcmNoLnNxbC5leHByZXNzaW9uLmZ1bmN0aW9uLlVzZXJEZWZpbmVkRnVuY3Rpb25CdWlsZGVyJDEiLAogICJ0eXBlIjogewogICAgInR5cGUiOiAiQk9PTEVBTiIsCiAgICAibnVsbGFibGUiOiB0cnVlCiAgfSwKICAiZGV0ZXJtaW5pc3RpYyI6IHRydWUsCiAgImR5bmFtaWMiOiBmYWxzZQp9dAAKZmllbGRUeXBlc3NyABFqYXZhLnV0aWwuSGFzaE1hcAUH2sHDFmDRAwACRgAKbG9hZEZhY3RvckkACXRocmVzaG9sZHhwP0AAAAAAAAx3CAAAABAAAAABdAAEaG9zdH5yAClvcmcub3BlbnNlYXJjaC5zcWwuZGF0YS50eXBlLkV4cHJDb3JlVHlwZQAAAAAAAAAAEgAAeHIADmphdmEubGFuZy5FbnVtAAAAAAAAAAASAAB4cHQAAklQeHg=\\\"}\",\"lang\":\"opensearch_compounded_script\",\"params\":{\"utcTimestamp\":*}},\"boost\":1.0}},\"_source\":{\"includes\":[\"host\"],\"excludes\":[]},\"sort\":[{\"_doc\":{\"order\":\"asc\"}}]}, requestedTotalSize=10000, pageSize=null, startFrom=0)])\n" + "logical": "LogicalSystemLimit(fetch=[10000], type=[QUERY_SIZE_LIMIT])\n LogicalProject(host=[$3])\n LogicalFilter(condition=[GREATER_IP($3, IP('1.1.1.1':VARCHAR))])\n CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_weblogs]])\n", + "physical": "CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_weblogs]], PushDownContext=[[PROJECT->[host], SCRIPT->GREATER_IP($0, IP('1.1.1.1':VARCHAR)), LIMIT->10000], OpenSearchRequestBuilder(sourceBuilder={\"from\":0,\"size\":10000,\"timeout\":\"1m\",\"query\":{\"script\":{\"script\":{\"source\":\"{\\\"langType\\\":\\\"calcite\\\",\\\"script\\\":\\\"rO0ABXNyABFqYXZhLnV0aWwuQ29sbFNlcleOq7Y6G6gRAwABSQADdGFneHAAAAADdwQAAAAGdAAHcm93VHlwZXQAensKICAiZmllbGRzIjogWwogICAgewogICAgICAidHlwZSI6ICJPVEhFUiIsCiAgICAgICJudWxsYWJsZSI6IHRydWUsCiAgICAgICJuYW1lIjogImhvc3QiCiAgICB9CiAgXSwKICAibnVsbGFibGUiOiBmYWxzZQp9dAAEZXhwcnQDfXsKICAib3AiOiB7CiAgICAibmFtZSI6ICJHUkVBVEVSX0lQIiwKICAgICJraW5kIjogIk9USEVSX0ZVTkNUSU9OIiwKICAgICJzeW50YXgiOiAiRlVOQ1RJT04iCiAgfSwKICAib3BlcmFuZHMiOiBbCiAgICB7CiAgICAgICJpbnB1dCI6IDAsCiAgICAgICJuYW1lIjogIiQwIgogICAgfSwKICAgIHsKICAgICAgIm9wIjogewogICAgICAgICJuYW1lIjogIklQIiwKICAgICAgICAia2luZCI6ICJPVEhFUl9GVU5DVElPTiIsCiAgICAgICAgInN5bnRheCI6ICJGVU5DVElPTiIKICAgICAgfSwKICAgICAgIm9wZXJhbmRzIjogWwogICAgICAgIHsKICAgICAgICAgICJsaXRlcmFsIjogIjEuMS4xLjEiLAogICAgICAgICAgInR5cGUiOiB7CiAgICAgICAgICAgICJ0eXBlIjogIlZBUkNIQVIiLAogICAgICAgICAgICAibnVsbGFibGUiOiBmYWxzZSwKICAgICAgICAgICAgInByZWNpc2lvbiI6IC0xCiAgICAgICAgICB9CiAgICAgICAgfQogICAgICBdLAogICAgICAiY2xhc3MiOiAib3JnLm9wZW5zZWFyY2guc3FsLmV4cHJlc3Npb24uZnVuY3Rpb24uVXNlckRlZmluZWRGdW5jdGlvbkJ1aWxkZXIkMSIsCiAgICAgICJ0eXBlIjogewogICAgICAgICJ0eXBlIjogIk9USEVSIiwKICAgICAgICAibnVsbGFibGUiOiB0cnVlCiAgICAgIH0sCiAgICAgICJkZXRlcm1pbmlzdGljIjogdHJ1ZSwKICAgICAgImR5bmFtaWMiOiBmYWxzZQogICAgfQogIF0sCiAgImNsYXNzIjogIm9yZy5vcGVuc2VhcmNoLnNxbC5leHByZXNzaW9uLmZ1bmN0aW9uLlVzZXJEZWZpbmVkRnVuY3Rpb25CdWlsZGVyJDEiLAogICJ0eXBlIjogewogICAgInR5cGUiOiAiQk9PTEVBTiIsCiAgICAibnVsbGFibGUiOiB0cnVlCiAgfSwKICAiZGV0ZXJtaW5pc3RpYyI6IHRydWUsCiAgImR5bmFtaWMiOiBmYWxzZQp9dAAKZmllbGRUeXBlc3NyABFqYXZhLnV0aWwuSGFzaE1hcAUH2sHDFmDRAwACRgAKbG9hZEZhY3RvckkACXRocmVzaG9sZHhwP0AAAAAAAAx3CAAAABAAAAABdAAEaG9zdH5yAClvcmcub3BlbnNlYXJjaC5zcWwuZGF0YS50eXBlLkV4cHJDb3JlVHlwZQAAAAAAAAAAEgAAeHIADmphdmEubGFuZy5FbnVtAAAAAAAAAAASAAB4cHQAAklQeHg=\\\"}\",\"lang\":\"opensearch_compounded_script\",\"params\":{\"utcTimestamp\":1754322819643187000}},\"boost\":1.0}},\"_source\":{\"includes\":[\"host\"],\"excludes\":[]},\"sort\":[{\"_doc\":{\"order\":\"asc\"}}]}, requestedTotalSize=10000, pageSize=null, startFrom=0)])\n" } } diff --git a/integ-test/src/test/resources/expectedOutput/calcite_no_pushdown/explain_filter_compare_ip.json b/integ-test/src/test/resources/expectedOutput/calcite_no_pushdown/explain_filter_compare_ip.json index 9d963dd5747..cf026690ac7 100644 --- a/integ-test/src/test/resources/expectedOutput/calcite_no_pushdown/explain_filter_compare_ip.json +++ b/integ-test/src/test/resources/expectedOutput/calcite_no_pushdown/explain_filter_compare_ip.json @@ -1,6 +1,6 @@ { "calcite": { - "logical": "LogicalSystemLimit(fetch=[10000], type=[QUERY_SIZE_LIMIT])\n LogicalProject(host=[$0])\n LogicalFilter(condition=[GREATER_IP($0, IP('1.1.1.1':VARCHAR))])\n CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_weblogs]])\n", - "physical": "EnumerableLimit(fetch=[10000])\n EnumerableCalc(expr#0..11=[{inputs}], expr#12=['1.1.1.1':VARCHAR], expr#13=[IP($t12)], expr#14=[GREATER_IP($t0, $t13)], host=[$t0], $condition=[$t14])\n CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_weblogs]])\n" + "logical": "LogicalSystemLimit(fetch=[10000], type=[QUERY_SIZE_LIMIT])\n LogicalProject(host=[$3])\n LogicalFilter(condition=[GREATER_IP($3, IP('1.1.1.1':VARCHAR))])\n CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_weblogs]])\n", + "physical": "EnumerableLimit(fetch=[10000])\n EnumerableCalc(expr#0..11=[{inputs}], expr#12=['1.1.1.1':VARCHAR], expr#13=[IP($t12)], expr#14=[GREATER_IP($t3, $t13)], host=[$t3], $condition=[$t14])\n CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_weblogs]])\n" } } diff --git a/integ-test/src/test/resources/expectedOutput/ppl/explain_filter_compare_ip.json b/integ-test/src/test/resources/expectedOutput/ppl/explain_filter_compare_ip.json new file mode 100644 index 00000000000..b4bc892285d --- /dev/null +++ b/integ-test/src/test/resources/expectedOutput/ppl/explain_filter_compare_ip.json @@ -0,0 +1,17 @@ +{ + "root": { + "name": "ProjectOperator", + "description": { + "fields": "[host]" + }, + "children": [ + { + "name": "OpenSearchIndexScan", + "description": { + "request": "OpenSearchQueryRequest(indexName=opensearch-sql_test_index_weblogs, sourceBuilder={\"from\":0,\"size\":10000,\"timeout\":\"1m\",\"query\":{\"range\":{\"host\":{\"from\":\"1.1.1.1\",\"to\":null,\"include_lower\":false,\"include_upper\":true,\"boost\":1.0}}},\"_source\":{\"includes\":[\"host\"],\"excludes\":[]},\"sort\":[{\"_doc\":{\"order\":\"asc\"}}]}, needClean=true, searchDone=false, pitId=*, cursorKeepAlive=1m, searchAfter=null, searchResponse=null)" + }, + "children": [] + } + ] + } +} diff --git a/integ-test/src/test/resources/expectedOutput/ppl/explain_week_argument_coercion.json b/integ-test/src/test/resources/expectedOutput/ppl/explain_week_argument_coercion.json new file mode 100644 index 00000000000..bc88f5eedfa --- /dev/null +++ b/integ-test/src/test/resources/expectedOutput/ppl/explain_week_argument_coercion.json @@ -0,0 +1,27 @@ +{ + "root": { + "name": "ProjectOperator", + "description": { + "fields": "[w]" + }, + "children": [ + { + "name": "OpenSearchEvalOperator", + "description": { + "expressions": { + "w": "week(\"2024-12-10\")" + } + }, + "children": [ + { + "name": "OpenSearchIndexScan", + "description": { + "request": "OpenSearchQueryRequest(indexName=opensearch-sql_test_index_account, sourceBuilder={\"from\":0,\"size\":10000,\"timeout\":\"1m\"}, needClean=true, searchDone=false, pitId=*, cursorKeepAlive=1m, searchAfter=null, searchResponse=null)" + }, + "children": [] + } + ] + } + ] + } +} \ No newline at end of file diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/functions/GeoIpFunction.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/functions/GeoIpFunction.java index 8ff8c85a88d..245b5b84693 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/functions/GeoIpFunction.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/functions/GeoIpFunction.java @@ -19,6 +19,7 @@ import org.apache.calcite.sql.type.CompositeOperandTypeChecker; import org.apache.calcite.sql.type.OperandTypes; import org.apache.calcite.sql.type.SqlReturnTypeInference; +import org.apache.calcite.sql.type.SqlTypeFamily; import org.apache.calcite.sql.type.SqlTypeName; import org.opensearch.geospatial.action.IpEnrichmentActionClient; import org.opensearch.sql.common.utils.StringUtils; @@ -59,8 +60,10 @@ public SqlReturnTypeInference getReturnTypeInference() { @Override public UDFOperandMetadata getOperandMetadata() { return UDFOperandMetadata.wrap( - (CompositeOperandTypeChecker) - OperandTypes.STRING_STRING.or(OperandTypes.STRING_STRING_STRING)); + (CompositeOperandTypeChecker) + OperandTypes.CHARACTER_CHARACTER.or( + OperandTypes.family( + SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER))); } public static class GeoIPImplementor implements NotNullImplementor { diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLFunctionTypeTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLFunctionTypeTest.java index b52de9448ed..977805f9d1a 100644 --- a/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLFunctionTypeTest.java +++ b/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLFunctionTypeTest.java @@ -35,11 +35,7 @@ public void testTimeDiffWithUdtInputType() { getRelNode(timePpl); Throwable t = Assert.assertThrows(Exception.class, () -> getRelNode(wrongPpl)); verifyErrorMessageContains( - t, - "TIMEDIFF function expects" - + " {[STRING,STRING],[DATE,DATE],[DATE,TIME],[DATE,TIMESTAMP],[TIME,DATE],[TIME,TIME],[TIME,TIMESTAMP]," - + "[TIMESTAMP,DATE],[TIMESTAMP,TIME],[TIMESTAMP,TIMESTAMP],[DATE,STRING],[TIME,STRING],[TIMESTAMP,STRING],[STRING,DATE],[STRING,TIME],[STRING,TIMESTAMP]}," - + " but got [INTEGER,STRING]"); + t, "TIMEDIFF function expects {[TIME,TIME]}, but got [INTEGER,STRING]"); } @Test @@ -51,8 +47,8 @@ public void testComparisonWithDifferentType() { verifyErrorMessageContains( t, // Temporary fix for the error message as LESS function has two variants. Will remove - // [IP,IP],[IP,STRING],[STRING,IP] when merging the two variants. - "LESS function expects {[IP,IP],[IP,STRING],[STRING,IP],[COMPARABLE_TYPE,COMPARABLE_TYPE]}," + // [IP,IP] when merging the two variants. + "LESS function expects {[IP,IP],[COMPARABLE_TYPE,COMPARABLE_TYPE]}," + " but got [STRING,INTEGER]"); } @@ -107,10 +103,10 @@ public void testTimestampWithWrongArg() { Throwable t = Assert.assertThrows(ExpressionEvaluationException.class, () -> getRelNode(ppl)); verifyErrorMessageContains( t, - "TIMESTAMP function expects" - + " {[STRING],[DATE],[TIME],[TIMESTAMP],[STRING,STRING],[DATE,DATE],[DATE,TIME],[DATE,TIMESTAMP]," - + "[TIME,DATE],[TIME,TIME],[TIME,TIMESTAMP],[TIMESTAMP,DATE],[TIMESTAMP,TIME],[TIMESTAMP,TIMESTAMP]" - + ",[STRING,DATE],[STRING,TIME],[STRING,TIMESTAMP],[DATE,STRING],[TIME,STRING],[TIMESTAMP,STRING]}," + "TIMESTAMP function expects {" + + "[STRING],[TIMESTAMP],[DATE],[TIME],[STRING,STRING],[TIMESTAMP,TIMESTAMP],[TIMESTAMP,DATE]," + + "[TIMESTAMP,TIME],[DATE,TIMESTAMP],[DATE,DATE],[DATE,TIME],[TIME,TIMESTAMP],[TIME,DATE]," + + "[TIME,TIME],[STRING,TIMESTAMP],[STRING,DATE],[STRING,TIME],[TIMESTAMP,STRING],[DATE,STRING],[TIME,STRING]}," + " but got [STRING,INTEGER]"); }