From 5151350cbd10517d7ef371493268659e2a69af00 Mon Sep 17 00:00:00 2001 From: Xinyu Hao Date: Fri, 4 Jul 2025 16:18:50 +0800 Subject: [PATCH] add type checker for cidr Signed-off-by: Xinyu Hao --- .../utils/UserDefinedFunctionUtils.java | 7 +- .../expression/function/PPLFuncImpTable.java | 17 +++- .../expression/function/PPLTypeChecker.java | 65 ++++++++---- .../function/UDFOperandMetadata.java | 99 ++++++++++++++----- .../function/udf/ip/CidrMatchFunction.java | 11 +-- .../function/udf/ip/CompareIpFunction.java | 9 +- .../calcite/CalcitePPLFunctionTypeTest.java | 3 +- 7 files changed, 145 insertions(+), 66 deletions(-) diff --git a/core/src/main/java/org/opensearch/sql/calcite/utils/UserDefinedFunctionUtils.java b/core/src/main/java/org/opensearch/sql/calcite/utils/UserDefinedFunctionUtils.java index f220c2265b2..163ac108392 100644 --- a/core/src/main/java/org/opensearch/sql/calcite/utils/UserDefinedFunctionUtils.java +++ b/core/src/main/java/org/opensearch/sql/calcite/utils/UserDefinedFunctionUtils.java @@ -115,10 +115,9 @@ public static SqlTypeName convertRelDataTypeToSqlTypeName(RelDataType type) { case EXPR_DATE -> SqlTypeName.DATE; case EXPR_TIME -> SqlTypeName.TIME; case EXPR_TIMESTAMP -> SqlTypeName.TIMESTAMP; - // EXPR_IP is mapped to SqlTypeName.NULL since there is no - // corresponding SqlTypeName in Calcite. This is a workaround to allow - // type checking for IP types in UDFs. - case EXPR_IP -> SqlTypeName.NULL; + // EXPR_IP is mapped to SqlTypeName.OTHER since there is no + // corresponding SqlTypeName in Calcite. 
+ case EXPR_IP -> SqlTypeName.OTHER; case EXPR_BINARY -> SqlTypeName.VARBINARY; default -> type.getSqlTypeName(); }; diff --git a/core/src/main/java/org/opensearch/sql/expression/function/PPLFuncImpTable.java b/core/src/main/java/org/opensearch/sql/expression/function/PPLFuncImpTable.java index 1713699022f..994c2a3a8f9 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/PPLFuncImpTable.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/PPLFuncImpTable.java @@ -450,7 +450,10 @@ functionName, getActualSignature(argTypes), e.getMessage()), } StringJoiner allowedSignatures = new StringJoiner(","); for (var implement : implementList) { - allowedSignatures.add(implement.getKey().typeChecker().getAllowedSignatures()); + String signature = implement.getKey().typeChecker().getAllowedSignatures(); + if (!signature.isEmpty()) { + allowedSignatures.add(signature); + } } throw new ExpressionEvaluationException( String.format( @@ -505,6 +508,18 @@ public void registerOperator(BuiltinFunctionName functionName, SqlOperator... op // Comparison operators like EQUAL, GREATER_THAN, LESS_THAN, etc. // SameOperandTypeCheckers like COALESCE, IFNULL, etc. register(functionName, wrapWithComparableTypeChecker(operator, comparableTypeChecker)); + } else if (typeChecker instanceof UDFOperandMetadata.IPOperandMetadata) { + register( + functionName, + createFunctionImpWithTypeChecker( + (builder, arg1, arg2) -> builder.makeCall(operator, arg1, arg2), + new PPLTypeChecker.PPLIPCompareTypeChecker())); + } else if (typeChecker instanceof UDFOperandMetadata.CidrOperandMetadata) { + register( + functionName, + createFunctionImpWithTypeChecker( + (builder, arg1, arg2) -> builder.makeCall(operator, arg1, arg2), + new PPLTypeChecker.PPLCidrTypeChecker())); } else { logger.info( "Cannot create type checker for function: {}. 
Will skip its type checking", diff --git a/core/src/main/java/org/opensearch/sql/expression/function/PPLTypeChecker.java b/core/src/main/java/org/opensearch/sql/expression/function/PPLTypeChecker.java index 17e9e290add..b3de4de4f51 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/PPLTypeChecker.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/PPLTypeChecker.java @@ -23,7 +23,7 @@ import org.apache.calcite.sql.type.SqlTypeFamily; import org.apache.calcite.sql.type.SqlTypeName; import org.apache.calcite.sql.type.SqlTypeUtil; -import org.opensearch.sql.calcite.type.AbstractExprRelDataType; +import org.opensearch.sql.calcite.type.ExprIPType; import org.opensearch.sql.calcite.utils.OpenSearchTypeFactory; import org.opensearch.sql.calcite.utils.UserDefinedFunctionUtils; import org.opensearch.sql.data.type.ExprCoreType; @@ -215,10 +215,6 @@ public boolean checkOperandTypes(List types) { RelDataType type_l = types.get(i); RelDataType type_r = types.get(i + 1); if (!SqlTypeUtil.isComparable(type_l, type_r)) { - if (areIpAndStringTypes(type_l, type_r) || areIpAndStringTypes(type_r, type_l)) { - // Allow IP and string comparison - continue; - } return false; } // Disallow coercing between strings and numeric, boolean @@ -239,14 +235,6 @@ private static boolean cannotConvertStringInCompare(SqlTypeFamily typeFamily) { }; } - private static boolean areIpAndStringTypes(RelDataType typeIp, RelDataType typeString) { - if (typeIp instanceof AbstractExprRelDataType exprRelDataType) { - return exprRelDataType.getExprType() == ExprCoreType.IP - && typeString.getFamily() == SqlTypeFamily.CHARACTER; - } - return false; - } - @Override public String getAllowedSignatures() { int min = innerTypeChecker.getOperandCountRange().getMin(); @@ -269,6 +257,53 @@ public String getAllowedSignatures() { } } + class PPLIPCompareTypeChecker implements PPLTypeChecker { + @Override + public boolean checkOperandTypes(List types) { + if (types.size() != 2) { + 
return false; + } + RelDataType type1 = types.get(0); + RelDataType type2 = types.get(1); + return areIpAndStringTypes(type1, type2) + || areIpAndStringTypes(type2, type1) + || (type1 instanceof ExprIPType && type2 instanceof ExprIPType); + } + + @Override + public String getAllowedSignatures() { + // Will be merged with the allowed signatures of comparable type checker, + // shown as [COMPARABLE_TYPE,COMPARABLE_TYPE] + return ""; + } + + private static boolean areIpAndStringTypes(RelDataType typeIp, RelDataType typeString) { + return typeIp instanceof ExprIPType && typeString.getFamily() == SqlTypeFamily.CHARACTER; + } + } + + class PPLCidrTypeChecker implements PPLTypeChecker { + @Override + public boolean checkOperandTypes(List<RelDataType> types) { + if (types.size() != 2) { + return false; + } + RelDataType type1 = types.get(0); + RelDataType type2 = types.get(1); + + // accept (STRING, STRING) or (IP, STRING) + if (type2.getFamily() != SqlTypeFamily.CHARACTER) { + return false; + } + return type1 instanceof ExprIPType || type1.getFamily() == SqlTypeFamily.CHARACTER; + } + + @Override + public String getAllowedSignatures() { + return "[STRING,STRING],[IP,STRING]"; + } + } + /** * Creates a {@link PPLFamilyTypeChecker} with a fixed operand count, validating that each operand * belongs to its corresponding {@link SqlTypeFamily}. @@ -400,10 +435,6 @@ private static List getExprTypes(SqlTypeFamily family) { OpenSearchTypeFactory.TYPE_FACTORY.createSqlType(SqlTypeName.INTEGER)); case ANY, IGNORE -> List.of( OpenSearchTypeFactory.TYPE_FACTORY.createSqlType(SqlTypeName.ANY)); - // We borrow SqlTypeFamily.NULL to represent EXPR_IP. This is a workaround - // since there is no corresponding IP type family in Calcite. 
- case NULL -> List.of( - OpenSearchTypeFactory.TYPE_FACTORY.createUDT(OpenSearchTypeFactory.ExprUDT.EXPR_IP)); default -> { RelDataType type = family.getDefaultConcreteType(OpenSearchTypeFactory.TYPE_FACTORY); if (type == null) { diff --git a/core/src/main/java/org/opensearch/sql/expression/function/UDFOperandMetadata.java b/core/src/main/java/org/opensearch/sql/expression/function/UDFOperandMetadata.java index fcd7a6a2be5..a7d12cbbbaf 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/UDFOperandMetadata.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/UDFOperandMetadata.java @@ -14,10 +14,8 @@ import org.apache.calcite.sql.SqlOperator; import org.apache.calcite.sql.type.CompositeOperandTypeChecker; import org.apache.calcite.sql.type.FamilyOperandTypeChecker; -import org.apache.calcite.sql.type.ImplicitCastOperandTypeChecker; import org.apache.calcite.sql.type.SqlOperandMetadata; import org.apache.calcite.sql.type.SqlOperandTypeChecker; -import org.apache.calcite.sql.type.SqlTypeFamily; import org.apache.calcite.sql.validate.SqlUserDefinedFunction; /** @@ -25,7 +23,7 @@ * creating UDFs, so that a type checker can be passed to the constructor of {@link * SqlUserDefinedFunction} as a {@link SqlOperandMetadata}. 
*/ -public interface UDFOperandMetadata extends SqlOperandMetadata, ImplicitCastOperandTypeChecker { +public interface UDFOperandMetadata extends SqlOperandMetadata { SqlOperandTypeChecker getInnerTypeChecker(); static UDFOperandMetadata wrap(FamilyOperandTypeChecker typeChecker) { @@ -35,17 +33,6 @@ public SqlOperandTypeChecker getInnerTypeChecker() { return typeChecker; } - @Override - public boolean checkOperandTypesWithoutTypeCoercion( - SqlCallBinding callBinding, boolean throwOnFailure) { - return typeChecker.checkOperandTypesWithoutTypeCoercion(callBinding, throwOnFailure); - } - - @Override - public SqlTypeFamily getOperandSqlTypeFamily(int iFormalOperand) { - return typeChecker.getOperandSqlTypeFamily(iFormalOperand); - } - @Override public List paramTypes(RelDataTypeFactory typeFactory) { // This function is not used in the current context, so we return an empty list. @@ -89,18 +76,6 @@ public SqlOperandTypeChecker getInnerTypeChecker() { return typeChecker; } - @Override - public boolean checkOperandTypesWithoutTypeCoercion( - SqlCallBinding callBinding, boolean throwOnFailure) { - return typeChecker.checkOperandTypes(callBinding, throwOnFailure); - } - - @Override - public SqlTypeFamily getOperandSqlTypeFamily(int iFormalOperand) { - throw new IllegalStateException( - "getOperandSqlTypeFamily is not supported for CompositeOperandTypeChecker"); - } - @Override public List paramTypes(RelDataTypeFactory typeFactory) { // This function is not used in the current context, so we return an empty list. @@ -129,4 +104,76 @@ public String getAllowedSignatures(SqlOperator op, String opName) { } }; } + + /** + * A named class that serves as an identifier for IP comparator's operand metadata. It does not + * implement any actual type checking logic. 
+ */ + class IPOperandMetadata implements UDFOperandMetadata { + @Override + public SqlOperandTypeChecker getInnerTypeChecker() { + return this; + } + + @Override + public List<RelDataType> paramTypes(RelDataTypeFactory typeFactory) { + return List.of(); + } + + @Override + public List<String> paramNames() { + return List.of(); + } + + @Override + public boolean checkOperandTypes(SqlCallBinding callBinding, boolean throwOnFailure) { + return false; + } + + @Override + public SqlOperandCountRange getOperandCountRange() { + return org.apache.calcite.sql.type.SqlOperandCountRanges.of(2); + } + + @Override + public String getAllowedSignatures(SqlOperator op, String opName) { + return ""; + } + } + + /** + * A named class that serves as an identifier for cidr's operand metadata. It does not implement + * any actual type checking logic. + */ + class CidrOperandMetadata implements UDFOperandMetadata { + @Override + public SqlOperandTypeChecker getInnerTypeChecker() { + return this; + } + + @Override + public List<RelDataType> paramTypes(RelDataTypeFactory typeFactory) { + return List.of(); + } + + @Override + public List<String> paramNames() { + return List.of(); + } + + @Override + public boolean checkOperandTypes(SqlCallBinding callBinding, boolean throwOnFailure) { + return false; + } + + @Override + public SqlOperandCountRange getOperandCountRange() { + return org.apache.calcite.sql.type.SqlOperandCountRanges.of(2); + } + + @Override + public String getAllowedSignatures(SqlOperator op, String opName) { + return ""; + } + } } diff --git a/core/src/main/java/org/opensearch/sql/expression/function/udf/ip/CidrMatchFunction.java b/core/src/main/java/org/opensearch/sql/expression/function/udf/ip/CidrMatchFunction.java index 4c8e532cbb0..bba375079a7 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/udf/ip/CidrMatchFunction.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/udf/ip/CidrMatchFunction.java @@ -12,11 +12,8 @@ import org.apache.calcite.linq4j.tree.Expression; import org.apache.calcite.linq4j.tree.Expressions; import org.apache.calcite.rex.RexCall; -import 
org.apache.calcite.sql.type.CompositeOperandTypeChecker; -import org.apache.calcite.sql.type.OperandTypes; import org.apache.calcite.sql.type.ReturnTypes; import org.apache.calcite.sql.type.SqlReturnTypeInference; -import org.apache.calcite.sql.type.SqlTypeFamily; import org.opensearch.sql.data.model.ExprIpValue; import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.data.model.ExprValueUtils; @@ -46,12 +43,10 @@ public SqlReturnTypeInference getReturnTypeInference() { @Override public UDFOperandMetadata getOperandMetadata() { - // EXPR_IP is mapped to SqlTypeFamily.NULL in + // EXPR_IP is mapped to SqlTypeName.OTHER in // UserDefinedFunctionUtils.convertRelDataTypeToSqlTypeName - return UDFOperandMetadata.wrap( - (CompositeOperandTypeChecker) - OperandTypes.family(SqlTypeFamily.STRING, SqlTypeFamily.STRING) - .or(OperandTypes.family(SqlTypeFamily.NULL, SqlTypeFamily.STRING))); + // so this marker metadata makes registerOperator use the dedicated cidr type checker. + return new UDFOperandMetadata.CidrOperandMetadata(); } public static class CidrMatchImplementor implements NotNullImplementor { diff --git a/core/src/main/java/org/opensearch/sql/expression/function/udf/ip/CompareIpFunction.java b/core/src/main/java/org/opensearch/sql/expression/function/udf/ip/CompareIpFunction.java index 12a6a42516d..9704d0dbd08 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/udf/ip/CompareIpFunction.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/udf/ip/CompareIpFunction.java @@ -13,11 +13,8 @@ import org.apache.calcite.linq4j.tree.Expression; import org.apache.calcite.linq4j.tree.Expressions; import org.apache.calcite.rex.RexCall; -import org.apache.calcite.sql.type.CompositeOperandTypeChecker; -import org.apache.calcite.sql.type.OperandTypes; import org.apache.calcite.sql.type.ReturnTypes; import org.apache.calcite.sql.type.SqlReturnTypeInference; -import org.apache.calcite.sql.type.SqlTypeFamily; import org.opensearch.sql.data.model.ExprIpValue; import 
org.opensearch.sql.expression.function.ImplementorUDF; import org.opensearch.sql.expression.function.UDFOperandMetadata; @@ -71,11 +68,7 @@ public SqlReturnTypeInference getReturnTypeInference() { @Override public UDFOperandMetadata getOperandMetadata() { - return UDFOperandMetadata.wrap( - (CompositeOperandTypeChecker) - OperandTypes.family(SqlTypeFamily.STRING, SqlTypeFamily.NULL) - .or(OperandTypes.family(SqlTypeFamily.NULL, SqlTypeFamily.STRING)) - .or(OperandTypes.family(SqlTypeFamily.NULL, SqlTypeFamily.NULL))); + return new UDFOperandMetadata.IPOperandMetadata(); } public static class CompareImplementor implements NotNullImplementor { diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLFunctionTypeTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLFunctionTypeTest.java index 20447be761a..ef1237fa2ec 100644 --- a/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLFunctionTypeTest.java +++ b/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLFunctionTypeTest.java @@ -50,8 +50,7 @@ public void testComparisonWithDifferentType() { Throwable t = Assert.assertThrows(ExpressionEvaluationException.class, () -> getRelNode(ppl)); verifyErrorMessageContains( t, - "LESS function expects {[STRING,IP],[IP,STRING],[IP,IP],[COMPARABLE_TYPE,COMPARABLE_TYPE]}," - + " but got [STRING,INTEGER]"); + "LESS function expects {[COMPARABLE_TYPE,COMPARABLE_TYPE]}," + " but got [STRING,INTEGER]"); } @Test