Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,10 @@

package org.opensearch.sql.calcite;

import static org.opensearch.sql.calcite.utils.OpenSearchTypeFactory.ExprUDT.EXPR_DATE;
import static org.opensearch.sql.calcite.utils.OpenSearchTypeFactory.ExprUDT.EXPR_TIME;
import static org.opensearch.sql.calcite.utils.OpenSearchTypeFactory.ExprUDT.EXPR_TIMESTAMP;

import com.google.common.collect.ImmutableList;
import java.util.Arrays;
import java.util.List;
import java.util.Set;
import java.util.Locale;
import org.apache.calcite.avatica.util.TimeUnit;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rex.RexBuilder;
Expand All @@ -24,7 +20,11 @@
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.sql.type.SqlTypeUtil;
import org.opensearch.sql.ast.expression.SpanUnit;
import org.opensearch.sql.calcite.type.ExprSqlType;
import org.opensearch.sql.calcite.type.AbstractExprRelDataType;
import org.opensearch.sql.calcite.utils.OpenSearchTypeFactory;
import org.opensearch.sql.data.type.ExprCoreType;
import org.opensearch.sql.exception.ExpressionEvaluationException;
import org.opensearch.sql.exception.SemanticCheckException;
import org.opensearch.sql.expression.function.PPLBuiltinOperators;

public class ExtendedRexBuilder extends RexBuilder {
Expand Down Expand Up @@ -124,16 +124,31 @@ public RexNode makeCast(
// SqlStdOperatorTable.NOT_EQUALS,
// ImmutableList.of(exp, makeZeroLiteral(exp.getType())));
}
} else if (type instanceof ExprSqlType exprSqlType
&& Set.of(EXPR_DATE, EXPR_TIME, EXPR_TIMESTAMP).contains(exprSqlType.getUdt())) {
switch (exprSqlType.getUdt()) {
case EXPR_DATE:
return makeCall(type, PPLBuiltinOperators.DATE, List.of(exp));
case EXPR_TIME:
return makeCall(type, PPLBuiltinOperators.TIME, List.of(exp));
case EXPR_TIMESTAMP:
return makeCall(type, PPLBuiltinOperators.TIMESTAMP, List.of(exp));
}
} else if (OpenSearchTypeFactory.isUserDefinedType(type)) {
var udt = ((AbstractExprRelDataType<?>) type).getUdt();
var argExprType = OpenSearchTypeFactory.convertRelDataTypeToExprType(exp.getType());
return switch (udt) {
case EXPR_DATE -> makeCall(type, PPLBuiltinOperators.DATE, List.of(exp));
case EXPR_TIME -> makeCall(type, PPLBuiltinOperators.TIME, List.of(exp));
case EXPR_TIMESTAMP -> makeCall(type, PPLBuiltinOperators.TIMESTAMP, List.of(exp));
case EXPR_IP -> {
if (argExprType == ExprCoreType.IP) {
yield exp;
} else if (argExprType == ExprCoreType.STRING) {
yield makeCall(type, PPLBuiltinOperators.IP, List.of(exp));
}
// Throwing error inside implementation will be suppressed by Calcite, thus
// throwing 500 error. Therefore, we throw error here to ensure the error
// information is displayed properly.
throw new ExpressionEvaluationException(
String.format(
Locale.ROOT,
"Cannot convert %s to IP, only STRING and IP types are supported",
argExprType));
}
default -> throw new SemanticCheckException(
String.format(Locale.ROOT, "Cannot cast from %s to %s", argExprType, udt.name()));
};
}
return super.makeCast(pos, type, exp, matchNullability, safe, format);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,8 @@ public static String getLegacyTypeName(RelDataType relDataType, QueryType queryT

/** Converts a Calcite data type to OpenSearch ExprCoreType. */
public static ExprType convertRelDataTypeToExprType(RelDataType type) {
if (type instanceof AbstractExprRelDataType<?> udt) {
if (isUserDefinedType(type)) {
AbstractExprRelDataType<?> udt = (AbstractExprRelDataType<?>) type;
return udt.getExprType();
}
ExprType exprType = convertSqlTypeNameToExprType(type.getSqlTypeName());
Expand Down Expand Up @@ -326,4 +327,14 @@ public Type getJavaClass(RelDataType type) {
}
return super.getJavaClass(type);
}

/**
* Whether a given RelDataType is a user-defined type (UDT)
*
* @param type the RelDataType to check
* @return true if the type is a user-defined type, false otherwise
*/
public static boolean isUserDefinedType(RelDataType type) {
return type instanceof AbstractExprRelDataType<?>;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ public boolean equal(ExprValue other) {

@Override
public String toString() {
return String.format("IP %s", value());
// used for casting to string
return value();
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@
import org.opensearch.sql.expression.function.udf.datetime.YearweekFunction;
import org.opensearch.sql.expression.function.udf.ip.CidrMatchFunction;
import org.opensearch.sql.expression.function.udf.ip.CompareIpFunction;
import org.opensearch.sql.expression.function.udf.ip.IPFunction;
import org.opensearch.sql.expression.function.udf.math.CRC32Function;
import org.opensearch.sql.expression.function.udf.math.ConvFunction;
import org.opensearch.sql.expression.function.udf.math.DivideFunction;
Expand Down Expand Up @@ -281,6 +282,9 @@ public class PPLBuiltinOperators extends ReflectiveSqlOperatorTable {
NullPolicy.ARG0,
PPLOperandTypes.DATETIME_OR_STRING)
.toUDF("TIME");

// IP cast function
public static final SqlOperator IP = new IPFunction().toUDF("IP");
public static final SqlOperator TIME_TO_SEC =
adaptExprMethodToUDF(
DateTimeFunctions.class,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -528,18 +528,9 @@ public void registerOperator(BuiltinFunctionName functionName, SqlOperator... op
// Comparison operators like EQUAL, GREATER_THAN, LESS_THAN, etc.
// SameOperandTypeCheckers like COALESCE, IFNULL, etc.
register(functionName, wrapWithComparableTypeChecker(operator, comparableTypeChecker));
} else if (typeChecker instanceof UDFOperandMetadata.IPOperandMetadata) {
register(
functionName,
createFunctionImpWithTypeChecker(
(builder, arg1, arg2) -> builder.makeCall(operator, arg1, arg2),
new PPLTypeChecker.PPLIPCompareTypeChecker()));
} else if (typeChecker instanceof UDFOperandMetadata.CidrOperandMetadata) {
register(
functionName,
createFunctionImpWithTypeChecker(
(builder, arg1, arg2) -> builder.makeCall(operator, arg1, arg2),
new PPLTypeChecker.PPLCidrTypeChecker()));
} else if (typeChecker
instanceof UDFOperandMetadata.UDTOperandMetadata udtOperandMetadata) {
register(functionName, wrapWithUdtTypeChecker(operator, udtOperandMetadata));
} else {
logger.info(
"Cannot create type checker for function: {}. Will skip its type checking",
Expand All @@ -558,6 +549,13 @@ private static SqlOperandTypeChecker extractTypeCheckerFromUDF(
return (udfOperandMetadata == null) ? null : udfOperandMetadata.getInnerTypeChecker();
}

// Such wrapWith*TypeChecker methods are useful in that we don't have to create explicit
// overrides of resolve function for different number of operands.
// I.e. we don't have to explicitly call
// (FuncImp1) (builder, arg1) -> builder.makeCall(operator, arg1);
// (FuncImp2) (builder, arg1, arg2) -> builder.makeCall(operator, arg1, arg2);
// etc.

/**
* Wrap a SqlOperator into a FunctionImp with a composite type checker.
*
Expand Down Expand Up @@ -624,6 +622,21 @@ public PPLTypeChecker getTypeChecker() {
};
}

private static FunctionImp wrapWithUdtTypeChecker(
SqlOperator operator, UDFOperandMetadata.UDTOperandMetadata udtOperandMetadata) {
return new FunctionImp() {
@Override
public RexNode resolve(RexBuilder builder, RexNode... args) {
return builder.makeCall(operator, args);
}

@Override
public PPLTypeChecker getTypeChecker() {
return PPLTypeChecker.wrapUDT(udtOperandMetadata.allowedParamTypes());
}
};
}

private static FunctionImp createFunctionImpWithTypeChecker(
BiFunction<RexBuilder, RexNode, RexNode> resolver, PPLTypeChecker typeChecker) {
return new FunctionImp1() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import org.apache.calcite.sql.type.SqlTypeFamily;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.sql.type.SqlTypeUtil;
import org.opensearch.sql.calcite.type.ExprIPType;
import org.opensearch.sql.calcite.utils.OpenSearchTypeFactory;
import org.opensearch.sql.calcite.utils.UserDefinedFunctionUtils;
import org.opensearch.sql.data.type.ExprCoreType;
Expand Down Expand Up @@ -257,53 +256,6 @@ public String getAllowedSignatures() {
}
}

class PPLIPCompareTypeChecker implements PPLTypeChecker {
@Override
public boolean checkOperandTypes(List<RelDataType> types) {
if (types.size() != 2) {
return false;
}
RelDataType type1 = types.get(0);
RelDataType type2 = types.get(1);
return areIpAndStringTypes(type1, type2)
|| areIpAndStringTypes(type2, type1)
|| (type1 instanceof ExprIPType && type2 instanceof ExprIPType);
}

@Override
public String getAllowedSignatures() {
// Will be merged with the allowed signatures of comparable type checker,
// shown as [COMPARABLE_TYPE,COMPARABLE_TYPE]
return "";
}

private static boolean areIpAndStringTypes(RelDataType typeIp, RelDataType typeString) {
return typeIp instanceof ExprIPType && typeString.getFamily() == SqlTypeFamily.CHARACTER;
}
}

class PPLCidrTypeChecker implements PPLTypeChecker {
@Override
public boolean checkOperandTypes(List<RelDataType> types) {
if (types.size() != 2) {
return false;
}
RelDataType type1 = types.get(0);
RelDataType type2 = types.get(1);

// accept (STRING, STRING) or (IP, STRING)
if (type2.getFamily() != SqlTypeFamily.CHARACTER) {
return false;
}
return type1 instanceof ExprIPType || type1.getFamily() == SqlTypeFamily.CHARACTER;
}

@Override
public String getAllowedSignatures() {
return "[STRING,STRING],[IP,STRING]";
}
}

/**
* Creates a {@link PPLFamilyTypeChecker} with a fixed operand count, validating that each operand
* belongs to its corresponding {@link SqlTypeFamily}.
Expand Down Expand Up @@ -379,6 +331,42 @@ static PPLComparableTypeChecker wrapComparable(SameOperandTypeChecker typeChecke
return new PPLComparableTypeChecker(typeChecker);
}

/**
* Create a {@link PPLTypeChecker} from a list of allowed signatures consisted of {@link
* ExprType}. This is useful to validate arguments against user-defined types (UDT) that does not
* match any Calcite {@link SqlTypeFamily}.
*
* @param allowedSignatures a list of allowed signatures, where each signature is a list of {@link
* ExprType} representing the expected types of the function arguments.
* @return a {@link PPLTypeChecker} that checks if the operand types match any of the allowed
* signatures
*/
static PPLTypeChecker wrapUDT(List<List<ExprType>> allowedSignatures) {
return new PPLTypeChecker() {
@Override
public boolean checkOperandTypes(List<RelDataType> types) {
List<ExprType> argExprTypes =
types.stream().map(OpenSearchTypeFactory::convertRelDataTypeToExprType).toList();
for (var allowedSignature : allowedSignatures) {
if (allowedSignature.size() != types.size()) {
continue; // Skip signatures that do not match the operand count
}
// Check if the argument types match the allowed signature
if (IntStream.range(0, allowedSignature.size())
.allMatch(i -> allowedSignature.get(i).equals(argExprTypes.get(i)))) {
return true;
}
}
return false;
}

@Override
public String getAllowedSignatures() {
return PPLTypeChecker.getExprFamilySignature(allowedSignatures);
}
};
}

// Util Functions
/**
* Generates a list of allowed function signatures based on the provided {@link
Expand Down Expand Up @@ -464,6 +452,10 @@ private static String getFamilySignature(List<SqlTypeFamily> families) {
List<List<ExprType>> signatures = Lists.cartesianProduct(exprTypes);

// Convert each signature to a string representation and then concatenate them
return getExprFamilySignature(signatures);
}

private static String getExprFamilySignature(List<List<ExprType>> signatures) {
return signatures.stream()
.map(
types ->
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import org.apache.calcite.sql.type.SqlOperandMetadata;
import org.apache.calcite.sql.type.SqlOperandTypeChecker;
import org.apache.calcite.sql.validate.SqlUserDefinedFunction;
import org.opensearch.sql.data.type.ExprType;

/**
* This class is created for the compatibility with {@link SqlUserDefinedFunction} constructors when
Expand Down Expand Up @@ -105,47 +106,11 @@ public String getAllowedSignatures(SqlOperator op, String opName) {
};
}

/**
* A named class that serves as an identifier for IP comparator's operand metadata. It does not
* implement any actual type checking logic.
*/
class IPOperandMetadata implements UDFOperandMetadata {
@Override
public SqlOperandTypeChecker getInnerTypeChecker() {
return this;
}

@Override
public List<RelDataType> paramTypes(RelDataTypeFactory typeFactory) {
return List.of();
}

@Override
public List<String> paramNames() {
return List.of();
}

@Override
public boolean checkOperandTypes(SqlCallBinding callBinding, boolean throwOnFailure) {
return false;
}

@Override
public SqlOperandCountRange getOperandCountRange() {
return null;
}

@Override
public String getAllowedSignatures(SqlOperator op, String opName) {
return "";
}
static UDFOperandMetadata wrapUDT(List<List<ExprType>> allowSignatures) {
return new UDTOperandMetadata(allowSignatures);
}

/**
* A named class that serves as an identifier for cidr's operand metadata. It does not implement
* any actual type checking logic.
*/
class CidrOperandMetadata implements UDFOperandMetadata {
record UDTOperandMetadata(List<List<ExprType>> allowedParamTypes) implements UDFOperandMetadata {
@Override
public SqlOperandTypeChecker getInnerTypeChecker() {
return this;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import org.opensearch.sql.data.model.ExprIpValue;
import org.opensearch.sql.data.model.ExprValue;
import org.opensearch.sql.data.model.ExprValueUtils;
import org.opensearch.sql.data.type.ExprCoreType;
import org.opensearch.sql.expression.function.ImplementorUDF;
import org.opensearch.sql.expression.function.UDFOperandMetadata;
import org.opensearch.sql.expression.ip.IPFunctions;
Expand Down Expand Up @@ -46,7 +47,10 @@ public UDFOperandMetadata getOperandMetadata() {
// EXPR_IP is mapped to SqlTypeFamily.OTHER in
// UserDefinedFunctionUtils.convertRelDataTypeToSqlTypeName
// We use a specific type checker to serve
return new UDFOperandMetadata.CidrOperandMetadata();
return UDFOperandMetadata.wrapUDT(
List.of(
List.of(ExprCoreType.IP, ExprCoreType.STRING),
List.of(ExprCoreType.STRING, ExprCoreType.STRING)));
}

public static class CidrMatchImplementor implements NotNullImplementor {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import org.apache.calcite.sql.type.ReturnTypes;
import org.apache.calcite.sql.type.SqlReturnTypeInference;
import org.opensearch.sql.data.model.ExprIpValue;
import org.opensearch.sql.data.type.ExprCoreType;
import org.opensearch.sql.expression.function.ImplementorUDF;
import org.opensearch.sql.expression.function.UDFOperandMetadata;

Expand Down Expand Up @@ -66,7 +67,11 @@ public SqlReturnTypeInference getReturnTypeInference() {

@Override
public UDFOperandMetadata getOperandMetadata() {
return new UDFOperandMetadata.IPOperandMetadata();
return UDFOperandMetadata.wrapUDT(
List.of(
List.of(ExprCoreType.IP, ExprCoreType.IP),
List.of(ExprCoreType.IP, ExprCoreType.STRING),
List.of(ExprCoreType.STRING, ExprCoreType.IP)));
}

public static class CompareImplementor implements NotNullImplementor {
Expand Down
Loading
Loading