diff --git a/core/src/main/java/org/opensearch/sql/calcite/utils/UserDefinedFunctionUtils.java b/core/src/main/java/org/opensearch/sql/calcite/utils/UserDefinedFunctionUtils.java index 9f691479f5d..aac002f1943 100644 --- a/core/src/main/java/org/opensearch/sql/calcite/utils/UserDefinedFunctionUtils.java +++ b/core/src/main/java/org/opensearch/sql/calcite/utils/UserDefinedFunctionUtils.java @@ -121,7 +121,9 @@ public static SqlTypeName convertRelDataTypeToSqlTypeName(RelDataType type) { case EXPR_DATE -> SqlTypeName.DATE; case EXPR_TIME -> SqlTypeName.TIME; case EXPR_TIMESTAMP -> SqlTypeName.TIMESTAMP; - case EXPR_IP -> SqlTypeName.VARCHAR; + // EXPR_IP is mapped to SqlTypeName.OTHER since there is no + // corresponding SqlTypeName in Calcite. + case EXPR_IP -> SqlTypeName.OTHER; case EXPR_BINARY -> SqlTypeName.VARBINARY; default -> type.getSqlTypeName(); }; diff --git a/core/src/main/java/org/opensearch/sql/expression/function/PPLBuiltinOperators.java b/core/src/main/java/org/opensearch/sql/expression/function/PPLBuiltinOperators.java index 3da366da827..6c2d0af8eba 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/PPLBuiltinOperators.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/PPLBuiltinOperators.java @@ -72,6 +72,7 @@ import org.opensearch.sql.expression.function.udf.datetime.WeekdayFunction; import org.opensearch.sql.expression.function.udf.datetime.YearweekFunction; import org.opensearch.sql.expression.function.udf.ip.CidrMatchFunction; +import org.opensearch.sql.expression.function.udf.ip.CompareIpFunction; import org.opensearch.sql.expression.function.udf.math.CRC32Function; import org.opensearch.sql.expression.function.udf.math.ConvFunction; import org.opensearch.sql.expression.function.udf.math.DivideFunction; @@ -103,6 +104,15 @@ public class PPLBuiltinOperators extends ReflectiveSqlOperatorTable { public static final SqlOperator SHA2 = CryptographicFunction.sha2().toUDF("SHA2"); public static final SqlOperator CIDRMATCH = new CidrMatchFunction().toUDF("CIDRMATCH"); + // IP comparing functions + public static final SqlOperator NOT_EQUALS_IP = + CompareIpFunction.notEquals().toUDF("NOT_EQUALS_IP"); + public static final SqlOperator EQUALS_IP = CompareIpFunction.equals().toUDF("EQUALS_IP"); + public static final SqlOperator GREATER_IP = CompareIpFunction.greater().toUDF("GREATER_IP"); + public static final SqlOperator GTE_IP = CompareIpFunction.greaterOrEquals().toUDF("GTE_IP"); + public static final SqlOperator LESS_IP = CompareIpFunction.less().toUDF("LESS_IP"); + public static final SqlOperator LTE_IP = CompareIpFunction.lessOrEquals().toUDF("LTE_IP"); + // Condition function public static final SqlOperator EARLIEST = new EarliestFunction().toUDF("EARLIEST"); public static final SqlOperator LATEST = new LatestFunction().toUDF("LATEST"); diff --git a/core/src/main/java/org/opensearch/sql/expression/function/PPLFuncImpTable.java b/core/src/main/java/org/opensearch/sql/expression/function/PPLFuncImpTable.java index 363baff3e4c..dc4f195d6be 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/PPLFuncImpTable.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/PPLFuncImpTable.java @@ -458,7 +458,10 @@ functionName, getActualSignature(argTypes), e.getMessage()), } StringJoiner allowedSignatures = new StringJoiner(","); for (var implement : implementList) { - allowedSignatures.add(implement.getKey().typeChecker().getAllowedSignatures()); + String signature = implement.getKey().typeChecker().getAllowedSignatures(); + if (!signature.isEmpty()) { + allowedSignatures.add(signature); + } } throw new ExpressionEvaluationException( String.format( @@ -481,40 +484,70 @@ private abstract static class AbstractBuilder { /** Maps an operator to an implementation. */ abstract void register(BuiltinFunctionName functionName, FunctionImp functionImp); - void registerOperator(BuiltinFunctionName functionName, SqlOperator operator) { - SqlOperandTypeChecker typeChecker; - if (operator instanceof SqlUserDefinedFunction udfOperator) { - typeChecker = extractTypeCheckerFromUDF(udfOperator); - } else { - typeChecker = operator.getOperandTypeChecker(); - } + /** + * Register one or multiple operators under a single function name. This allows function + * overloading based on operand types. + * + *
When a function is called, the system will try each registered operator in sequence,
+ * checking if the provided arguments match the operator's type requirements. The first operator
+ * whose type checker accepts the arguments will be used to execute the function.
+ *
+ * @param functionName the built-in function name under which to register the operators
+ * @param operators the operators to associate with this function name, tried in sequence until
+ * one matches the argument types during resolution
+ */
+ public void registerOperator(BuiltinFunctionName functionName, SqlOperator... operators) {
+ for (SqlOperator operator : operators) {
+ SqlOperandTypeChecker typeChecker;
+ if (operator instanceof SqlUserDefinedFunction udfOperator) {
+ typeChecker = extractTypeCheckerFromUDF(udfOperator);
+ } else {
+ typeChecker = operator.getOperandTypeChecker();
+ }
- // Only the composite operand type checker for UDFs are concerned here.
- if (operator instanceof SqlUserDefinedFunction
- && typeChecker instanceof CompositeOperandTypeChecker compositeTypeChecker) {
- // UDFs implement their own composite type checkers, which always use OR logic for argument
- // types. Verifying the composition type would require accessing a protected field in
- // CompositeOperandTypeChecker. If access to this field is not allowed, type checking will
- // be skipped, so we avoid checking the composition type here.
- register(functionName, wrapWithCompositeTypeChecker(operator, compositeTypeChecker, false));
- } else if (typeChecker instanceof ImplicitCastOperandTypeChecker implicitCastTypeChecker) {
- register(functionName, wrapWithImplicitCastTypeChecker(operator, implicitCastTypeChecker));
- } else if (typeChecker instanceof CompositeOperandTypeChecker compositeTypeChecker) {
- // If compositeTypeChecker contains operand checkers other than family type checkers or
- // other than OR compositions, the function with be registered with a null type checker,
- // which means the function will not be type checked.
- register(functionName, wrapWithCompositeTypeChecker(operator, compositeTypeChecker, true));
- } else if (typeChecker instanceof SameOperandTypeChecker comparableTypeChecker) {
- // Comparison operators like EQUAL, GREATER_THAN, LESS_THAN, etc.
- // SameOperandTypeCheckers like COALESCE, IFNULL, etc.
- register(functionName, wrapWithComparableTypeChecker(operator, comparableTypeChecker));
- } else {
- logger.info(
- "Cannot create type checker for function: {}. Will skip its type checking",
- functionName);
- register(
- functionName,
- (RexBuilder builder, RexNode... node) -> builder.makeCall(operator, node));
+ // Only the composite operand type checker for UDFs are concerned here.
+ if (operator instanceof SqlUserDefinedFunction
+ && typeChecker instanceof CompositeOperandTypeChecker compositeTypeChecker) {
+ // UDFs implement their own composite type checkers, which always use OR logic for
+ // argument
+ // types. Verifying the composition type would require accessing a protected field in
+ // CompositeOperandTypeChecker. If access to this field is not allowed, type checking will
+ // be skipped, so we avoid checking the composition type here.
+ register(
+ functionName, wrapWithCompositeTypeChecker(operator, compositeTypeChecker, false));
+ } else if (typeChecker instanceof ImplicitCastOperandTypeChecker implicitCastTypeChecker) {
+ register(
+ functionName, wrapWithImplicitCastTypeChecker(operator, implicitCastTypeChecker));
+ } else if (typeChecker instanceof CompositeOperandTypeChecker compositeTypeChecker) {
+ // If compositeTypeChecker contains operand checkers other than family type checkers or
+ // other than OR compositions, the function with be registered with a null type checker,
+ // which means the function will not be type checked.
+ register(
+ functionName, wrapWithCompositeTypeChecker(operator, compositeTypeChecker, true));
+ } else if (typeChecker instanceof SameOperandTypeChecker comparableTypeChecker) {
+ // Comparison operators like EQUAL, GREATER_THAN, LESS_THAN, etc.
+ // SameOperandTypeCheckers like COALESCE, IFNULL, etc.
+ register(functionName, wrapWithComparableTypeChecker(operator, comparableTypeChecker));
+ } else if (typeChecker instanceof UDFOperandMetadata.IPOperandMetadata) {
+ register(
+ functionName,
+ createFunctionImpWithTypeChecker(
+ (builder, arg1, arg2) -> builder.makeCall(operator, arg1, arg2),
+ new PPLTypeChecker.PPLIPCompareTypeChecker()));
+ } else if (typeChecker instanceof UDFOperandMetadata.CidrOperandMetadata) {
+ register(
+ functionName,
+ createFunctionImpWithTypeChecker(
+ (builder, arg1, arg2) -> builder.makeCall(operator, arg1, arg2),
+ new PPLTypeChecker.PPLCidrTypeChecker()));
+ } else {
+ logger.info(
+ "Cannot create type checker for function: {}. Will skip its type checking",
+ functionName);
+ register(
+ functionName,
+ (RexBuilder builder, RexNode... node) -> builder.makeCall(operator, node));
+ }
}
}
@@ -622,16 +655,18 @@ public PPLTypeChecker getTypeChecker() {
}
void populate() {
+ // register operators for comparison
+ registerOperator(NOTEQUAL, PPLBuiltinOperators.NOT_EQUALS_IP, SqlStdOperatorTable.NOT_EQUALS);
+ registerOperator(EQUAL, PPLBuiltinOperators.EQUALS_IP, SqlStdOperatorTable.EQUALS);
+ registerOperator(GREATER, PPLBuiltinOperators.GREATER_IP, SqlStdOperatorTable.GREATER_THAN);
+ registerOperator(GTE, PPLBuiltinOperators.GTE_IP, SqlStdOperatorTable.GREATER_THAN_OR_EQUAL);
+ registerOperator(LESS, PPLBuiltinOperators.LESS_IP, SqlStdOperatorTable.LESS_THAN);
+ registerOperator(LTE, PPLBuiltinOperators.LTE_IP, SqlStdOperatorTable.LESS_THAN_OR_EQUAL);
+
// Register std operator
registerOperator(AND, SqlStdOperatorTable.AND);
registerOperator(OR, SqlStdOperatorTable.OR);
registerOperator(NOT, SqlStdOperatorTable.NOT);
- registerOperator(NOTEQUAL, SqlStdOperatorTable.NOT_EQUALS);
- registerOperator(EQUAL, SqlStdOperatorTable.EQUALS);
- registerOperator(GREATER, SqlStdOperatorTable.GREATER_THAN);
- registerOperator(GTE, SqlStdOperatorTable.GREATER_THAN_OR_EQUAL);
- registerOperator(LESS, SqlStdOperatorTable.LESS_THAN);
- registerOperator(LTE, SqlStdOperatorTable.LESS_THAN_OR_EQUAL);
registerOperator(ADD, SqlStdOperatorTable.PLUS);
registerOperator(SUBTRACT, SqlStdOperatorTable.MINUS);
registerOperator(MULTIPLY, SqlStdOperatorTable.MULTIPLY);
diff --git a/core/src/main/java/org/opensearch/sql/expression/function/PPLTypeChecker.java b/core/src/main/java/org/opensearch/sql/expression/function/PPLTypeChecker.java
index d3443ab850e..b3de4de4f51 100644
--- a/core/src/main/java/org/opensearch/sql/expression/function/PPLTypeChecker.java
+++ b/core/src/main/java/org/opensearch/sql/expression/function/PPLTypeChecker.java
@@ -23,7 +23,7 @@
import org.apache.calcite.sql.type.SqlTypeFamily;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.sql.type.SqlTypeUtil;
-import org.opensearch.sql.calcite.type.AbstractExprRelDataType;
+import org.opensearch.sql.calcite.type.ExprIPType;
import org.opensearch.sql.calcite.utils.OpenSearchTypeFactory;
import org.opensearch.sql.calcite.utils.UserDefinedFunctionUtils;
import org.opensearch.sql.data.type.ExprCoreType;
@@ -215,10 +215,6 @@ public boolean checkOperandTypes(List Signature:
+ *
+ *
+ *
+ */
+public class CompareIpFunction extends ImplementorUDF {
+
+ private CompareIpFunction(ComparisonType comparisonType) {
+ super(new CompareImplementor(comparisonType), NullPolicy.ANY);
+ }
+
+ public static CompareIpFunction less() {
+ return new CompareIpFunction(ComparisonType.LESS);
+ }
+
+ public static CompareIpFunction greater() {
+ return new CompareIpFunction(ComparisonType.GREATER);
+ }
+
+ public static CompareIpFunction lessOrEquals() {
+ return new CompareIpFunction(ComparisonType.LESS_OR_EQUAL);
+ }
+
+ public static CompareIpFunction greaterOrEquals() {
+ return new CompareIpFunction(ComparisonType.GREATER_OR_EQUAL);
+ }
+
+ public static CompareIpFunction equals() {
+ return new CompareIpFunction(ComparisonType.EQUALS);
+ }
+
+ public static CompareIpFunction notEquals() {
+ return new CompareIpFunction(ComparisonType.NOT_EQUALS);
+ }
+
+ @Override
+ public SqlReturnTypeInference getReturnTypeInference() {
+ return ReturnTypes.BOOLEAN_FORCE_NULLABLE;
+ }
+
+ @Override
+ public UDFOperandMetadata getOperandMetadata() {
+ return new UDFOperandMetadata.IPOperandMetadata();
+ }
+
+ public static class CompareImplementor implements NotNullImplementor {
+ private final ComparisonType comparisonType;
+
+ public CompareImplementor(ComparisonType comparisonType) {
+ this.comparisonType = comparisonType;
+ }
+
+ @Override
+ public Expression implement(
+ RexToLixTranslator translator, RexCall call, List