diff --git a/core/src/main/java/org/opensearch/sql/calcite/utils/UserDefinedFunctionUtils.java b/core/src/main/java/org/opensearch/sql/calcite/utils/UserDefinedFunctionUtils.java index f220c2265b2..ca54a0e0114 100644 --- a/core/src/main/java/org/opensearch/sql/calcite/utils/UserDefinedFunctionUtils.java +++ b/core/src/main/java/org/opensearch/sql/calcite/utils/UserDefinedFunctionUtils.java @@ -115,10 +115,9 @@ public static SqlTypeName convertRelDataTypeToSqlTypeName(RelDataType type) { case EXPR_DATE -> SqlTypeName.DATE; case EXPR_TIME -> SqlTypeName.TIME; case EXPR_TIMESTAMP -> SqlTypeName.TIMESTAMP; - // EXPR_IP is mapped to SqlTypeName.NULL since there is no - // corresponding SqlTypeName in Calcite. This is a workaround to allow - // type checking for IP types in UDFs. - case EXPR_IP -> SqlTypeName.NULL; + // EXPR_IP is mapped to SqlTypeName.OTHER since there is no + // corresponding SqlTypeName in Calcite. + case EXPR_IP -> SqlTypeName.OTHER; case EXPR_BINARY -> SqlTypeName.VARBINARY; default -> type.getSqlTypeName(); }; diff --git a/core/src/main/java/org/opensearch/sql/expression/function/PPLFuncImpTable.java b/core/src/main/java/org/opensearch/sql/expression/function/PPLFuncImpTable.java index 981f3be585f..356054b2a12 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/PPLFuncImpTable.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/PPLFuncImpTable.java @@ -450,7 +450,10 @@ functionName, getActualSignature(argTypes), e.getMessage()), } StringJoiner allowedSignatures = new StringJoiner(","); for (var implement : implementList) { - allowedSignatures.add(implement.getKey().typeChecker().getAllowedSignatures()); + String signature = implement.getKey().typeChecker().getAllowedSignatures(); + if (!signature.isEmpty()) { + allowedSignatures.add(signature); + } } throw new ExpressionEvaluationException( String.format( @@ -500,6 +503,12 @@ void registerOperator(BuiltinFunctionName functionName, SqlOperator operator) { // Comparison operators like EQUAL, GREATER_THAN, LESS_THAN, etc. // SameOperandTypeCheckers like COALESCE, IFNULL, etc. register(functionName, wrapWithComparableTypeChecker(operator, comparableTypeChecker)); + } else if (typeChecker instanceof UDFOperandMetadata.IPOperandMetadata) { + register( + functionName, + createFunctionImpWithTypeChecker( + (builder, arg1, arg2) -> builder.makeCall(operator, arg1, arg2), + new PPLTypeChecker.PPLIPCompareTypeChecker())); } else { logger.info( "Cannot create type checker for function: {}. Will skip its type checking", diff --git a/core/src/main/java/org/opensearch/sql/expression/function/PPLTypeChecker.java b/core/src/main/java/org/opensearch/sql/expression/function/PPLTypeChecker.java index 17e9e290add..f3b2fba2d99 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/PPLTypeChecker.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/PPLTypeChecker.java @@ -23,7 +23,7 @@ import org.apache.calcite.sql.type.SqlTypeFamily; import org.apache.calcite.sql.type.SqlTypeName; import org.apache.calcite.sql.type.SqlTypeUtil; -import org.opensearch.sql.calcite.type.AbstractExprRelDataType; +import org.opensearch.sql.calcite.type.ExprIPType; import org.opensearch.sql.calcite.utils.OpenSearchTypeFactory; import org.opensearch.sql.calcite.utils.UserDefinedFunctionUtils; import org.opensearch.sql.data.type.ExprCoreType; @@ -215,10 +215,6 @@ public boolean checkOperandTypes(List types) { RelDataType type_l = types.get(i); RelDataType type_r = types.get(i + 1); if (!SqlTypeUtil.isComparable(type_l, type_r)) { - if (areIpAndStringTypes(type_l, type_r) || areIpAndStringTypes(type_r, type_l)) { - // Allow IP and string comparison - continue; - } return false; } // Disallow coercing between strings and numeric, boolean @@ -239,14 +235,6 @@ private static boolean cannotConvertStringInCompare(SqlTypeFamily typeFamily) { }; } - private static boolean areIpAndStringTypes(RelDataType typeIp, RelDataType typeString) { - if (typeIp instanceof AbstractExprRelDataType exprRelDataType) { - return exprRelDataType.getExprType() == ExprCoreType.IP - && typeString.getFamily() == SqlTypeFamily.CHARACTER; - } - return false; - } - @Override public String getAllowedSignatures() { int min = innerTypeChecker.getOperandCountRange().getMin(); @@ -269,6 +257,31 @@ public String getAllowedSignatures() { } } + class PPLIPCompareTypeChecker implements PPLTypeChecker { + @Override + public boolean checkOperandTypes(List types) { + if (types.size() != 2) { + return false; + } + RelDataType type1 = types.get(0); + RelDataType type2 = types.get(1); + return areIpAndStringTypes(type1, type2) + || areIpAndStringTypes(type2, type1) + || (type1 instanceof ExprIPType && type2 instanceof ExprIPType); + } + + @Override + public String getAllowedSignatures() { + // Will be merged with the allowed signatures of comparable type checker, + // shown as [COMPARABLE_TYPE,COMPARABLE_TYPE] + return ""; + } + + private static boolean areIpAndStringTypes(RelDataType typeIp, RelDataType typeString) { + return typeIp instanceof ExprIPType && typeString.getFamily() == SqlTypeFamily.CHARACTER; + } + } + /** * Creates a {@link PPLFamilyTypeChecker} with a fixed operand count, validating that each operand * belongs to its corresponding {@link SqlTypeFamily}. diff --git a/core/src/main/java/org/opensearch/sql/expression/function/UDFOperandMetadata.java b/core/src/main/java/org/opensearch/sql/expression/function/UDFOperandMetadata.java index fcd7a6a2be5..76d65cfdc94 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/UDFOperandMetadata.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/UDFOperandMetadata.java @@ -14,10 +14,8 @@ import org.apache.calcite.sql.SqlOperator; import org.apache.calcite.sql.type.CompositeOperandTypeChecker; import org.apache.calcite.sql.type.FamilyOperandTypeChecker; -import org.apache.calcite.sql.type.ImplicitCastOperandTypeChecker; import org.apache.calcite.sql.type.SqlOperandMetadata; import org.apache.calcite.sql.type.SqlOperandTypeChecker; -import org.apache.calcite.sql.type.SqlTypeFamily; import org.apache.calcite.sql.validate.SqlUserDefinedFunction; /** @@ -25,7 +23,7 @@ * creating UDFs, so that a type checker can be passed to the constructor of {@link * SqlUserDefinedFunction} as a {@link SqlOperandMetadata}. */ -public interface UDFOperandMetadata extends SqlOperandMetadata, ImplicitCastOperandTypeChecker { +public interface UDFOperandMetadata extends SqlOperandMetadata { SqlOperandTypeChecker getInnerTypeChecker(); static UDFOperandMetadata wrap(FamilyOperandTypeChecker typeChecker) { @@ -35,17 +33,6 @@ public SqlOperandTypeChecker getInnerTypeChecker() { return typeChecker; } - @Override - public boolean checkOperandTypesWithoutTypeCoercion( - SqlCallBinding callBinding, boolean throwOnFailure) { - return typeChecker.checkOperandTypesWithoutTypeCoercion(callBinding, throwOnFailure); - } - - @Override - public SqlTypeFamily getOperandSqlTypeFamily(int iFormalOperand) { - return typeChecker.getOperandSqlTypeFamily(iFormalOperand); - } - @Override public List paramTypes(RelDataTypeFactory typeFactory) { // This function is not used in the current context, so we return an empty list. @@ -89,18 +76,6 @@ public SqlOperandTypeChecker getInnerTypeChecker() { return typeChecker; } - @Override - public boolean checkOperandTypesWithoutTypeCoercion( - SqlCallBinding callBinding, boolean throwOnFailure) { - return typeChecker.checkOperandTypes(callBinding, throwOnFailure); - } - - @Override - public SqlTypeFamily getOperandSqlTypeFamily(int iFormalOperand) { - throw new IllegalStateException( - "getOperandSqlTypeFamily is not supported for CompositeOperandTypeChecker"); - } - @Override public List paramTypes(RelDataTypeFactory typeFactory) { // This function is not used in the current context, so we return an empty list. @@ -129,4 +104,40 @@ public String getAllowedSignatures(SqlOperator op, String opName) { } }; } + + /** + * A named class that serves as an identifier for IP comparator's operand metadata. It does not + * implement any actual type checking logic. + */ + class IPOperandMetadata implements UDFOperandMetadata { + @Override + public SqlOperandTypeChecker getInnerTypeChecker() { + return this; + } + + @Override + public List paramTypes(RelDataTypeFactory typeFactory) { + return List.of(); + } + + @Override + public List paramNames() { + return List.of(); + } + + @Override + public boolean checkOperandTypes(SqlCallBinding callBinding, boolean throwOnFailure) { + return false; + } + + @Override + public SqlOperandCountRange getOperandCountRange() { + return null; + } + + @Override + public String getAllowedSignatures(SqlOperator op, String opName) { + return ""; + } + } } diff --git a/core/src/main/java/org/opensearch/sql/expression/function/udf/ip/LessIpFunction.java b/core/src/main/java/org/opensearch/sql/expression/function/udf/ip/LessIpFunction.java index e87c44b5260..7b0de6539ad 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/udf/ip/LessIpFunction.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/udf/ip/LessIpFunction.java @@ -44,11 +44,7 @@ public SqlReturnTypeInference getReturnTypeInference() { @Override public UDFOperandMetadata getOperandMetadata() { - return UDFOperandMetadata.wrap( - (CompositeOperandTypeChecker) - OperandTypes.family(SqlTypeFamily.STRING, SqlTypeFamily.NULL) - .or(OperandTypes.family(SqlTypeFamily.NULL, SqlTypeFamily.STRING)) - .or(OperandTypes.family(SqlTypeFamily.NULL, SqlTypeFamily.NULL))); + return new UDFOperandMetadata.IPOperandMetadata(); } public static class LessImplementor implements NotNullImplementor { diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLFunctionTypeTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLFunctionTypeTest.java index 20447be761a..bb6a064c230 100644 --- a/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLFunctionTypeTest.java +++ b/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLFunctionTypeTest.java @@ -50,7 +50,7 @@ public void testComparisonWithDifferentType() { Throwable t = Assert.assertThrows(ExpressionEvaluationException.class, () -> getRelNode(ppl)); verifyErrorMessageContains( t, - "LESS function expects {[STRING,IP],[IP,STRING],[IP,IP],[COMPARABLE_TYPE,COMPARABLE_TYPE]}," + "LESS function expects {[COMPARABLE_TYPE,COMPARABLE_TYPE]}," + " but got [STRING,INTEGER]"); }