Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,10 @@ public static SqlTypeName convertRelDataTypeToSqlTypeName(RelDataType type) {
return SqlTypeName.TIME;
case EXPR_TIMESTAMP:
return SqlTypeName.TIMESTAMP;
// EXPR_IP is mapped to SqlTypeName.OTHER since there is no
// corresponding SqlTypeName in Calcite.
case EXPR_IP:
return SqlTypeName.VARCHAR;
return SqlTypeName.OTHER;
case EXPR_BINARY:
return SqlTypeName.VARBINARY;
default:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@
import org.opensearch.sql.expression.function.udf.datetime.WeekdayFunction;
import org.opensearch.sql.expression.function.udf.datetime.YearweekFunction;
import org.opensearch.sql.expression.function.udf.ip.CidrMatchFunction;
import org.opensearch.sql.expression.function.udf.ip.CompareIpFunction;
import org.opensearch.sql.expression.function.udf.math.CRC32Function;
import org.opensearch.sql.expression.function.udf.math.ConvFunction;
import org.opensearch.sql.expression.function.udf.math.DivideFunction;
Expand Down Expand Up @@ -103,6 +104,15 @@ public class PPLBuiltinOperators extends ReflectiveSqlOperatorTable {
public static final SqlOperator SHA2 = CryptographicFunction.sha2().toUDF("SHA2");
public static final SqlOperator CIDRMATCH = new CidrMatchFunction().toUDF("CIDRMATCH");

// IP comparing functions
public static final SqlOperator NOT_EQUALS_IP =
CompareIpFunction.notEquals().toUDF("NOT_EQUALS_IP");
public static final SqlOperator EQUALS_IP = CompareIpFunction.equals().toUDF("EQUALS_IP");
public static final SqlOperator GREATER_IP = CompareIpFunction.greater().toUDF("GREATER_IP");
public static final SqlOperator GTE_IP = CompareIpFunction.greaterOrEquals().toUDF("GTE_IP");
public static final SqlOperator LESS_IP = CompareIpFunction.less().toUDF("LESS_IP");
public static final SqlOperator LTE_IP = CompareIpFunction.lessOrEquals().toUDF("LTE_IP");

// Condition function
public static final SqlOperator EARLIEST = new EarliestFunction().toUDF("EARLIEST");
public static final SqlOperator LATEST = new LatestFunction().toUDF("LATEST");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -465,7 +465,10 @@ functionName, getActualSignature(argTypes), e.getMessage()),
}
StringJoiner allowedSignatures = new StringJoiner(",");
for (var implement : implementList) {
allowedSignatures.add(implement.getKey().getTypeChecker().getAllowedSignatures());
String signature = implement.getKey().getTypeChecker().getAllowedSignatures();
if (!signature.isEmpty()) {
allowedSignatures.add(signature);
}
}
throw new ExpressionEvaluationException(
String.format(
Expand All @@ -488,45 +491,71 @@ private abstract static class AbstractBuilder {
/** Maps an operator to an implementation. */
abstract void register(BuiltinFunctionName functionName, FunctionImp functionImp);

void registerOperator(BuiltinFunctionName functionName, SqlOperator operator) {
SqlOperandTypeChecker typeChecker;
if (operator instanceof SqlUserDefinedFunction) {
SqlUserDefinedFunction udfOperator = (SqlUserDefinedFunction) operator;
typeChecker = extractTypeCheckerFromUDF(udfOperator);
} else {
typeChecker = operator.getOperandTypeChecker();
}
/**
* Register one or multiple operators under a single function name. This allows function
* overloading based on operand types.
*
* <p>When a function is called, the system will try each registered operator in sequence,
* checking if the provided arguments match the operator's type requirements. The first operator
* whose type checker accepts the arguments will be used to execute the function.
*
* @param functionName the built-in function name under which to register the operators
* @param operators the operators to associate with this function name, tried in sequence until
* one matches the argument types during resolution
*/
void registerOperator(BuiltinFunctionName functionName, SqlOperator... operators) {
for (SqlOperator operator : operators) {
SqlOperandTypeChecker typeChecker;
if (operator instanceof SqlUserDefinedFunction) {
SqlUserDefinedFunction udfOperator = (SqlUserDefinedFunction) operator;
typeChecker = extractTypeCheckerFromUDF(udfOperator);
} else {
typeChecker = operator.getOperandTypeChecker();
}

// Only the composite operand type checker for UDFs are concerned here.
if (operator instanceof SqlUserDefinedFunction
&& typeChecker instanceof CompositeOperandTypeChecker) {
CompositeOperandTypeChecker compositeTypeChecker = (CompositeOperandTypeChecker) typeChecker;
// UDFs implement their own composite type checkers, which always use OR logic for argument
// types. Verifying the composition type would require accessing a protected field in
// CompositeOperandTypeChecker. If access to this field is not allowed, type checking will
// be skipped, so we avoid checking the composition type here.
register(functionName, wrapWithCompositeTypeChecker(operator, compositeTypeChecker, false));
} else if (typeChecker instanceof ImplicitCastOperandTypeChecker) {
ImplicitCastOperandTypeChecker implicitCastTypeChecker = (ImplicitCastOperandTypeChecker) typeChecker;
register(functionName, wrapWithImplicitCastTypeChecker(operator, implicitCastTypeChecker));
} else if (typeChecker instanceof CompositeOperandTypeChecker) {
CompositeOperandTypeChecker compositeTypeChecker = (CompositeOperandTypeChecker) typeChecker;
// If compositeTypeChecker contains operand checkers other than family type checkers or
// other than OR compositions, the function with be registered with a null type checker,
// which means the function will not be type checked.
register(functionName, wrapWithCompositeTypeChecker(operator, compositeTypeChecker, true));
} else if (typeChecker instanceof SameOperandTypeChecker) {
SameOperandTypeChecker comparableTypeChecker = (SameOperandTypeChecker) typeChecker;
// Comparison operators like EQUAL, GREATER_THAN, LESS_THAN, etc.
// SameOperandTypeCheckers like COALESCE, IFNULL, etc.
register(functionName, wrapWithComparableTypeChecker(operator, comparableTypeChecker));
} else {
logger.info(
"Cannot create type checker for function: {}. Will skip its type checking",
functionName);
register(
functionName,
(RexBuilder builder, RexNode... node) -> builder.makeCall(operator, node));
// Only the composite operand type checker for UDFs are concerned here.
if (operator instanceof SqlUserDefinedFunction
&& typeChecker instanceof CompositeOperandTypeChecker) {
CompositeOperandTypeChecker compositeTypeChecker = (CompositeOperandTypeChecker) typeChecker;
// UDFs implement their own composite type checkers, which always use OR logic for argument
// types. Verifying the composition type would require accessing a protected field in
// CompositeOperandTypeChecker. If access to this field is not allowed, type checking will
// be skipped, so we avoid checking the composition type here.
register(functionName, wrapWithCompositeTypeChecker(operator, compositeTypeChecker, false));
} else if (typeChecker instanceof ImplicitCastOperandTypeChecker) {
ImplicitCastOperandTypeChecker implicitCastTypeChecker = (ImplicitCastOperandTypeChecker) typeChecker;
register(functionName, wrapWithImplicitCastTypeChecker(operator, implicitCastTypeChecker));
} else if (typeChecker instanceof CompositeOperandTypeChecker) {
CompositeOperandTypeChecker compositeTypeChecker = (CompositeOperandTypeChecker) typeChecker;
// If compositeTypeChecker contains operand checkers other than family type checkers or
// other than OR compositions, the function with be registered with a null type checker,
// which means the function will not be type checked.
register(functionName, wrapWithCompositeTypeChecker(operator, compositeTypeChecker, true));
} else if (typeChecker instanceof SameOperandTypeChecker) {
SameOperandTypeChecker comparableTypeChecker = (SameOperandTypeChecker) typeChecker;
// Comparison operators like EQUAL, GREATER_THAN, LESS_THAN, etc.
// SameOperandTypeCheckers like COALESCE, IFNULL, etc.
register(functionName, wrapWithComparableTypeChecker(operator, comparableTypeChecker));
} else if (typeChecker instanceof UDFOperandMetadata.IPOperandMetadata) {
register(
functionName,
createFunctionImpWithTypeChecker(
(builder, arg1, arg2) -> builder.makeCall(operator, arg1, arg2),
new PPLTypeChecker.PPLIPCompareTypeChecker()));
} else if (typeChecker instanceof UDFOperandMetadata.CidrOperandMetadata) {
register(
functionName,
createFunctionImpWithTypeChecker(
(builder, arg1, arg2) -> builder.makeCall(operator, arg1, arg2),
new PPLTypeChecker.PPLCidrTypeChecker()));
} else {
logger.info(
"Cannot create type checker for function: {}. Will skip its type checking",
functionName);
register(
functionName,
(RexBuilder builder, RexNode... node) -> builder.makeCall(operator, node));
}
}
}

Expand Down Expand Up @@ -634,16 +663,18 @@ public PPLTypeChecker getTypeChecker() {
}

void populate() {
// register operators for comparison
registerOperator(NOTEQUAL, PPLBuiltinOperators.NOT_EQUALS_IP, SqlStdOperatorTable.NOT_EQUALS);
registerOperator(EQUAL, PPLBuiltinOperators.EQUALS_IP, SqlStdOperatorTable.EQUALS);
registerOperator(GREATER, PPLBuiltinOperators.GREATER_IP, SqlStdOperatorTable.GREATER_THAN);
registerOperator(GTE, PPLBuiltinOperators.GTE_IP, SqlStdOperatorTable.GREATER_THAN_OR_EQUAL);
registerOperator(LESS, PPLBuiltinOperators.LESS_IP, SqlStdOperatorTable.LESS_THAN);
registerOperator(LTE, PPLBuiltinOperators.LTE_IP, SqlStdOperatorTable.LESS_THAN_OR_EQUAL);

// Register std operator
registerOperator(AND, SqlStdOperatorTable.AND);
registerOperator(OR, SqlStdOperatorTable.OR);
registerOperator(NOT, SqlStdOperatorTable.NOT);
registerOperator(NOTEQUAL, SqlStdOperatorTable.NOT_EQUALS);
registerOperator(EQUAL, SqlStdOperatorTable.EQUALS);
registerOperator(GREATER, SqlStdOperatorTable.GREATER_THAN);
registerOperator(GTE, SqlStdOperatorTable.GREATER_THAN_OR_EQUAL);
registerOperator(LESS, SqlStdOperatorTable.LESS_THAN);
registerOperator(LTE, SqlStdOperatorTable.LESS_THAN_OR_EQUAL);
registerOperator(ADD, SqlStdOperatorTable.PLUS);
registerOperator(SUBTRACT, SqlStdOperatorTable.MINUS);
registerOperator(MULTIPLY, SqlStdOperatorTable.MULTIPLY);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import org.apache.calcite.sql.type.SqlTypeFamily;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.sql.type.SqlTypeUtil;
import org.opensearch.sql.calcite.type.AbstractExprRelDataType;
import org.opensearch.sql.calcite.type.ExprIPType;
import org.opensearch.sql.calcite.utils.OpenSearchTypeFactory;
import org.opensearch.sql.calcite.utils.UserDefinedFunctionUtils;
import org.opensearch.sql.data.type.ExprCoreType;
Expand Down Expand Up @@ -222,10 +222,6 @@ public boolean checkOperandTypes(List<RelDataType> types) {
RelDataType type_l = types.get(i);
RelDataType type_r = types.get(i + 1);
if (!SqlTypeUtil.isComparable(type_l, type_r)) {
if (areIpAndStringTypes(type_l, type_r) || areIpAndStringTypes(type_r, type_l)) {
// Allow IP and string comparison
continue;
}
return false;
}
// Disallow coercing between strings and numeric, boolean
Expand All @@ -252,15 +248,6 @@ private static boolean cannotConvertStringInCompare(SqlTypeFamily typeFamily) {
}
}

private static boolean areIpAndStringTypes(RelDataType typeIp, RelDataType typeString) {
if (typeIp instanceof AbstractExprRelDataType<?>) {
AbstractExprRelDataType<?> exprRelDataType = (AbstractExprRelDataType<?>) typeIp;
return exprRelDataType.getExprType() == ExprCoreType.IP
&& typeString.getFamily() == SqlTypeFamily.CHARACTER;
}
return false;
}

@Override
public String getAllowedSignatures() {
int min = innerTypeChecker.getOperandCountRange().getMin();
Expand All @@ -283,6 +270,53 @@ public String getAllowedSignatures() {
}
}

class PPLIPCompareTypeChecker implements PPLTypeChecker {
@Override
public boolean checkOperandTypes(List<RelDataType> types) {
if (types.size() != 2) {
return false;
}
RelDataType type1 = types.get(0);
RelDataType type2 = types.get(1);
return areIpAndStringTypes(type1, type2)
|| areIpAndStringTypes(type2, type1)
|| (type1 instanceof ExprIPType && type2 instanceof ExprIPType);
}

@Override
public String getAllowedSignatures() {
// Will be merged with the allowed signatures of comparable type checker,
// shown as [COMPARABLE_TYPE,COMPARABLE_TYPE]
return "";
}

private static boolean areIpAndStringTypes(RelDataType typeIp, RelDataType typeString) {
return typeIp instanceof ExprIPType && typeString.getFamily() == SqlTypeFamily.CHARACTER;
}
}

class PPLCidrTypeChecker implements PPLTypeChecker {
@Override
public boolean checkOperandTypes(List<RelDataType> types) {
if (types.size() != 2) {
return false;
}
RelDataType type1 = types.get(0);
RelDataType type2 = types.get(1);

// accept (STRING, STRING) or (IP, STRING)
if (type2.getFamily() != SqlTypeFamily.CHARACTER) {
return false;
}
return type1 instanceof ExprIPType || type1.getFamily() == SqlTypeFamily.CHARACTER;
}

@Override
public String getAllowedSignatures() {
return "[STRING,STRING],[IP,STRING]";
}
}

/**
* Creates a {@link PPLFamilyTypeChecker} with a fixed operand count, validating that each operand
* belongs to its corresponding {@link SqlTypeFamily}.
Expand Down
Loading
Loading