Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,9 @@ public static SqlTypeName convertRelDataTypeToSqlTypeName(RelDataType type) {
case EXPR_DATE -> SqlTypeName.DATE;
case EXPR_TIME -> SqlTypeName.TIME;
case EXPR_TIMESTAMP -> SqlTypeName.TIMESTAMP;
case EXPR_IP -> SqlTypeName.VARCHAR;
// EXPR_IP is mapped to SqlTypeName.OTHER since there is no
// corresponding SqlTypeName in Calcite.
case EXPR_IP -> SqlTypeName.OTHER;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please update the PR description to align changes.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for reminding me.

case EXPR_BINARY -> SqlTypeName.VARBINARY;
default -> type.getSqlTypeName();
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@
import org.opensearch.sql.expression.function.udf.datetime.WeekdayFunction;
import org.opensearch.sql.expression.function.udf.datetime.YearweekFunction;
import org.opensearch.sql.expression.function.udf.ip.CidrMatchFunction;
import org.opensearch.sql.expression.function.udf.ip.CompareIpFunction;
import org.opensearch.sql.expression.function.udf.math.CRC32Function;
import org.opensearch.sql.expression.function.udf.math.ConvFunction;
import org.opensearch.sql.expression.function.udf.math.DivideFunction;
Expand Down Expand Up @@ -103,6 +104,15 @@ public class PPLBuiltinOperators extends ReflectiveSqlOperatorTable {
public static final SqlOperator SHA2 = CryptographicFunction.sha2().toUDF("SHA2");
public static final SqlOperator CIDRMATCH = new CidrMatchFunction().toUDF("CIDRMATCH");

// IP comparing functions
public static final SqlOperator NOT_EQUALS_IP =
CompareIpFunction.notEquals().toUDF("NOT_EQUALS_IP");
public static final SqlOperator EQUALS_IP = CompareIpFunction.equals().toUDF("EQUALS_IP");
public static final SqlOperator GREATER_IP = CompareIpFunction.greater().toUDF("GREATER_IP");
public static final SqlOperator GTE_IP = CompareIpFunction.greaterOrEquals().toUDF("GTE_IP");
public static final SqlOperator LESS_IP = CompareIpFunction.less().toUDF("LESS_IP");
public static final SqlOperator LTE_IP = CompareIpFunction.lessOrEquals().toUDF("LTE_IP");

// Condition function
public static final SqlOperator EARLIEST = new EarliestFunction().toUDF("EARLIEST");
public static final SqlOperator LATEST = new LatestFunction().toUDF("LATEST");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,10 @@ functionName, getActualSignature(argTypes), e.getMessage()),
}
StringJoiner allowedSignatures = new StringJoiner(",");
for (var implement : implementList) {
allowedSignatures.add(implement.getKey().typeChecker().getAllowedSignatures());
String signature = implement.getKey().typeChecker().getAllowedSignatures();
if (!signature.isEmpty()) {
allowedSignatures.add(signature);
}
}
throw new ExpressionEvaluationException(
String.format(
Expand All @@ -481,40 +484,70 @@ private abstract static class AbstractBuilder {
/** Maps an operator to an implementation. */
abstract void register(BuiltinFunctionName functionName, FunctionImp functionImp);

void registerOperator(BuiltinFunctionName functionName, SqlOperator operator) {
SqlOperandTypeChecker typeChecker;
if (operator instanceof SqlUserDefinedFunction udfOperator) {
typeChecker = extractTypeCheckerFromUDF(udfOperator);
} else {
typeChecker = operator.getOperandTypeChecker();
}
/**
* Register one or multiple operators under a single function name. This allows function
* overloading based on operand types.
*
* <p>When a function is called, the system will try each registered operator in sequence,
* checking if the provided arguments match the operator's type requirements. The first operator
* whose type checker accepts the arguments will be used to execute the function.
*
* @param functionName the built-in function name under which to register the operators
* @param operators the operators to associate with this function name, tried in sequence until
* one matches the argument types during resolution
*/
public void registerOperator(BuiltinFunctionName functionName, SqlOperator... operators) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please add a javadoc for this method to explain what are the operators for

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added.

for (SqlOperator operator : operators) {
SqlOperandTypeChecker typeChecker;
if (operator instanceof SqlUserDefinedFunction udfOperator) {
typeChecker = extractTypeCheckerFromUDF(udfOperator);
} else {
typeChecker = operator.getOperandTypeChecker();
}

// Only the composite operand type checker for UDFs are concerned here.
if (operator instanceof SqlUserDefinedFunction
&& typeChecker instanceof CompositeOperandTypeChecker compositeTypeChecker) {
// UDFs implement their own composite type checkers, which always use OR logic for argument
// types. Verifying the composition type would require accessing a protected field in
// CompositeOperandTypeChecker. If access to this field is not allowed, type checking will
// be skipped, so we avoid checking the composition type here.
register(functionName, wrapWithCompositeTypeChecker(operator, compositeTypeChecker, false));
} else if (typeChecker instanceof ImplicitCastOperandTypeChecker implicitCastTypeChecker) {
register(functionName, wrapWithImplicitCastTypeChecker(operator, implicitCastTypeChecker));
} else if (typeChecker instanceof CompositeOperandTypeChecker compositeTypeChecker) {
// If compositeTypeChecker contains operand checkers other than family type checkers or
// other than OR compositions, the function with be registered with a null type checker,
// which means the function will not be type checked.
register(functionName, wrapWithCompositeTypeChecker(operator, compositeTypeChecker, true));
} else if (typeChecker instanceof SameOperandTypeChecker comparableTypeChecker) {
// Comparison operators like EQUAL, GREATER_THAN, LESS_THAN, etc.
// SameOperandTypeCheckers like COALESCE, IFNULL, etc.
register(functionName, wrapWithComparableTypeChecker(operator, comparableTypeChecker));
} else {
logger.info(
"Cannot create type checker for function: {}. Will skip its type checking",
functionName);
register(
functionName,
(RexBuilder builder, RexNode... node) -> builder.makeCall(operator, node));
// Only the composite operand type checker for UDFs are concerned here.
if (operator instanceof SqlUserDefinedFunction
&& typeChecker instanceof CompositeOperandTypeChecker compositeTypeChecker) {
// UDFs implement their own composite type checkers, which always use OR logic for
// argument
// types. Verifying the composition type would require accessing a protected field in
// CompositeOperandTypeChecker. If access to this field is not allowed, type checking will
// be skipped, so we avoid checking the composition type here.
register(
functionName, wrapWithCompositeTypeChecker(operator, compositeTypeChecker, false));
} else if (typeChecker instanceof ImplicitCastOperandTypeChecker implicitCastTypeChecker) {
register(
functionName, wrapWithImplicitCastTypeChecker(operator, implicitCastTypeChecker));
} else if (typeChecker instanceof CompositeOperandTypeChecker compositeTypeChecker) {
// If compositeTypeChecker contains operand checkers other than family type checkers or
// other than OR compositions, the function with be registered with a null type checker,
// which means the function will not be type checked.
register(
functionName, wrapWithCompositeTypeChecker(operator, compositeTypeChecker, true));
} else if (typeChecker instanceof SameOperandTypeChecker comparableTypeChecker) {
// Comparison operators like EQUAL, GREATER_THAN, LESS_THAN, etc.
// SameOperandTypeCheckers like COALESCE, IFNULL, etc.
register(functionName, wrapWithComparableTypeChecker(operator, comparableTypeChecker));
} else if (typeChecker instanceof UDFOperandMetadata.IPOperandMetadata) {
register(
functionName,
createFunctionImpWithTypeChecker(
(builder, arg1, arg2) -> builder.makeCall(operator, arg1, arg2),
new PPLTypeChecker.PPLIPCompareTypeChecker()));
} else if (typeChecker instanceof UDFOperandMetadata.CidrOperandMetadata) {
register(
functionName,
createFunctionImpWithTypeChecker(
(builder, arg1, arg2) -> builder.makeCall(operator, arg1, arg2),
new PPLTypeChecker.PPLCidrTypeChecker()));
} else {
logger.info(
"Cannot create type checker for function: {}. Will skip its type checking",
functionName);
register(
functionName,
(RexBuilder builder, RexNode... node) -> builder.makeCall(operator, node));
}
}
}

Expand Down Expand Up @@ -622,16 +655,18 @@ public PPLTypeChecker getTypeChecker() {
}

void populate() {
// register operators for comparison
registerOperator(NOTEQUAL, PPLBuiltinOperators.NOT_EQUALS_IP, SqlStdOperatorTable.NOT_EQUALS);
registerOperator(EQUAL, PPLBuiltinOperators.EQUALS_IP, SqlStdOperatorTable.EQUALS);
registerOperator(GREATER, PPLBuiltinOperators.GREATER_IP, SqlStdOperatorTable.GREATER_THAN);
registerOperator(GTE, PPLBuiltinOperators.GTE_IP, SqlStdOperatorTable.GREATER_THAN_OR_EQUAL);
registerOperator(LESS, PPLBuiltinOperators.LESS_IP, SqlStdOperatorTable.LESS_THAN);
registerOperator(LTE, PPLBuiltinOperators.LTE_IP, SqlStdOperatorTable.LESS_THAN_OR_EQUAL);

// Register std operator
registerOperator(AND, SqlStdOperatorTable.AND);
registerOperator(OR, SqlStdOperatorTable.OR);
registerOperator(NOT, SqlStdOperatorTable.NOT);
registerOperator(NOTEQUAL, SqlStdOperatorTable.NOT_EQUALS);
registerOperator(EQUAL, SqlStdOperatorTable.EQUALS);
registerOperator(GREATER, SqlStdOperatorTable.GREATER_THAN);
registerOperator(GTE, SqlStdOperatorTable.GREATER_THAN_OR_EQUAL);
registerOperator(LESS, SqlStdOperatorTable.LESS_THAN);
registerOperator(LTE, SqlStdOperatorTable.LESS_THAN_OR_EQUAL);
registerOperator(ADD, SqlStdOperatorTable.PLUS);
registerOperator(SUBTRACT, SqlStdOperatorTable.MINUS);
registerOperator(MULTIPLY, SqlStdOperatorTable.MULTIPLY);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import org.apache.calcite.sql.type.SqlTypeFamily;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.sql.type.SqlTypeUtil;
import org.opensearch.sql.calcite.type.AbstractExprRelDataType;
import org.opensearch.sql.calcite.type.ExprIPType;
import org.opensearch.sql.calcite.utils.OpenSearchTypeFactory;
import org.opensearch.sql.calcite.utils.UserDefinedFunctionUtils;
import org.opensearch.sql.data.type.ExprCoreType;
Expand Down Expand Up @@ -215,10 +215,6 @@ public boolean checkOperandTypes(List<RelDataType> types) {
RelDataType type_l = types.get(i);
RelDataType type_r = types.get(i + 1);
if (!SqlTypeUtil.isComparable(type_l, type_r)) {
if (areIpAndStringTypes(type_l, type_r) || areIpAndStringTypes(type_r, type_l)) {
// Allow IP and string comparison
continue;
}
return false;
}
// Disallow coercing between strings and numeric, boolean
Expand All @@ -239,14 +235,6 @@ private static boolean cannotConvertStringInCompare(SqlTypeFamily typeFamily) {
};
}

private static boolean areIpAndStringTypes(RelDataType typeIp, RelDataType typeString) {
if (typeIp instanceof AbstractExprRelDataType<?> exprRelDataType) {
return exprRelDataType.getExprType() == ExprCoreType.IP
&& typeString.getFamily() == SqlTypeFamily.CHARACTER;
}
return false;
}

@Override
public String getAllowedSignatures() {
int min = innerTypeChecker.getOperandCountRange().getMin();
Expand All @@ -269,6 +257,53 @@ public String getAllowedSignatures() {
}
}

class PPLIPCompareTypeChecker implements PPLTypeChecker {
@Override
public boolean checkOperandTypes(List<RelDataType> types) {
if (types.size() != 2) {
return false;
}
RelDataType type1 = types.get(0);
RelDataType type2 = types.get(1);
return areIpAndStringTypes(type1, type2)
|| areIpAndStringTypes(type2, type1)
|| (type1 instanceof ExprIPType && type2 instanceof ExprIPType);
}

@Override
public String getAllowedSignatures() {
// Will be merged with the allowed signatures of comparable type checker,
// shown as [COMPARABLE_TYPE,COMPARABLE_TYPE]
return "";
}

private static boolean areIpAndStringTypes(RelDataType typeIp, RelDataType typeString) {
return typeIp instanceof ExprIPType && typeString.getFamily() == SqlTypeFamily.CHARACTER;
}
}

class PPLCidrTypeChecker implements PPLTypeChecker {
@Override
public boolean checkOperandTypes(List<RelDataType> types) {
if (types.size() != 2) {
return false;
}
RelDataType type1 = types.get(0);
RelDataType type2 = types.get(1);

// accept (STRING, STRING) or (IP, STRING)
if (type2.getFamily() != SqlTypeFamily.CHARACTER) {
return false;
}
return type1 instanceof ExprIPType || type1.getFamily() == SqlTypeFamily.CHARACTER;
}

@Override
public String getAllowedSignatures() {
return "[STRING,STRING],[IP,STRING]";
}
}

/**
* Creates a {@link PPLFamilyTypeChecker} with a fixed operand count, validating that each operand
* belongs to its corresponding {@link SqlTypeFamily}.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,16 @@
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.type.CompositeOperandTypeChecker;
import org.apache.calcite.sql.type.FamilyOperandTypeChecker;
import org.apache.calcite.sql.type.ImplicitCastOperandTypeChecker;
import org.apache.calcite.sql.type.SqlOperandMetadata;
import org.apache.calcite.sql.type.SqlOperandTypeChecker;
import org.apache.calcite.sql.type.SqlTypeFamily;
import org.apache.calcite.sql.validate.SqlUserDefinedFunction;

/**
* This class is created for the compatibility with {@link SqlUserDefinedFunction} constructors when
* creating UDFs, so that a type checker can be passed to the constructor of {@link
* SqlUserDefinedFunction} as a {@link SqlOperandMetadata}.
*/
public interface UDFOperandMetadata extends SqlOperandMetadata, ImplicitCastOperandTypeChecker {
public interface UDFOperandMetadata extends SqlOperandMetadata {
SqlOperandTypeChecker getInnerTypeChecker();

static UDFOperandMetadata wrap(FamilyOperandTypeChecker typeChecker) {
Expand All @@ -35,17 +33,6 @@ public SqlOperandTypeChecker getInnerTypeChecker() {
return typeChecker;
}

@Override
public boolean checkOperandTypesWithoutTypeCoercion(
SqlCallBinding callBinding, boolean throwOnFailure) {
return typeChecker.checkOperandTypesWithoutTypeCoercion(callBinding, throwOnFailure);
}

@Override
public SqlTypeFamily getOperandSqlTypeFamily(int iFormalOperand) {
return typeChecker.getOperandSqlTypeFamily(iFormalOperand);
}

@Override
public List<RelDataType> paramTypes(RelDataTypeFactory typeFactory) {
// This function is not used in the current context, so we return an empty list.
Expand Down Expand Up @@ -89,18 +76,6 @@ public SqlOperandTypeChecker getInnerTypeChecker() {
return typeChecker;
}

@Override
public boolean checkOperandTypesWithoutTypeCoercion(
SqlCallBinding callBinding, boolean throwOnFailure) {
return typeChecker.checkOperandTypes(callBinding, throwOnFailure);
}

@Override
public SqlTypeFamily getOperandSqlTypeFamily(int iFormalOperand) {
throw new IllegalStateException(
"getOperandSqlTypeFamily is not supported for CompositeOperandTypeChecker");
}

@Override
public List<RelDataType> paramTypes(RelDataTypeFactory typeFactory) {
// This function is not used in the current context, so we return an empty list.
Expand Down Expand Up @@ -129,4 +104,76 @@ public String getAllowedSignatures(SqlOperator op, String opName) {
}
};
}

/**
* A named class that serves as an identifier for IP comparator's operand metadata. It does not
* implement any actual type checking logic.
*/
class IPOperandMetadata implements UDFOperandMetadata {
@Override
public SqlOperandTypeChecker getInnerTypeChecker() {
return this;
}

@Override
public List<RelDataType> paramTypes(RelDataTypeFactory typeFactory) {
return List.of();
}

@Override
public List<String> paramNames() {
return List.of();
}

@Override
public boolean checkOperandTypes(SqlCallBinding callBinding, boolean throwOnFailure) {
return false;
}

@Override
public SqlOperandCountRange getOperandCountRange() {
return null;
}

@Override
public String getAllowedSignatures(SqlOperator op, String opName) {
return "";
}
}

/**
* A named class that serves as an identifier for cidr's operand metadata. It does not implement
* any actual type checking logic.
*/
class CidrOperandMetadata implements UDFOperandMetadata {
@Override
public SqlOperandTypeChecker getInnerTypeChecker() {
return this;
}

@Override
public List<RelDataType> paramTypes(RelDataTypeFactory typeFactory) {
return List.of();
}

@Override
public List<String> paramNames() {
return List.of();
}

@Override
public boolean checkOperandTypes(SqlCallBinding callBinding, boolean throwOnFailure) {
return false;
}

@Override
public SqlOperandCountRange getOperandCountRange() {
return null;
}

@Override
public String getAllowedSignatures(SqlOperator op, String opName) {
return "";
}
}
}
Loading
Loading