diff --git a/core/src/main/java/org/opensearch/sql/calcite/udf/udaf/TakeAggFunction.java b/core/src/main/java/org/opensearch/sql/calcite/udf/udaf/TakeAggFunction.java index 8a43a847027..09d3c312a94 100644 --- a/core/src/main/java/org/opensearch/sql/calcite/udf/udaf/TakeAggFunction.java +++ b/core/src/main/java/org/opensearch/sql/calcite/udf/udaf/TakeAggFunction.java @@ -24,7 +24,7 @@ public Object result(TakeAccumulator accumulator) { @Override public TakeAccumulator add(TakeAccumulator acc, Object... values) { Object candidateValue = values[0]; - int size = 0; + int size; if (values.length > 1) { size = (int) values[1]; } else { diff --git a/core/src/main/java/org/opensearch/sql/calcite/utils/UserDefinedFunctionUtils.java b/core/src/main/java/org/opensearch/sql/calcite/utils/UserDefinedFunctionUtils.java index 4afc4f0c13c..e57697038de 100644 --- a/core/src/main/java/org/opensearch/sql/calcite/utils/UserDefinedFunctionUtils.java +++ b/core/src/main/java/org/opensearch/sql/calcite/utils/UserDefinedFunctionUtils.java @@ -30,6 +30,7 @@ import org.apache.calcite.rex.RexCall; import org.apache.calcite.rex.RexNode; import org.apache.calcite.schema.impl.AggregateFunctionImpl; +import org.apache.calcite.sql.SqlAggFunction; import org.apache.calcite.sql.SqlIdentifier; import org.apache.calcite.sql.SqlKind; import org.apache.calcite.sql.parser.SqlParserPos; @@ -77,27 +78,71 @@ public class UserDefinedFunctionUtils { public static Set MULTI_FIELDS_RELEVANCE_FUNCTION_SET = ImmutableSet.of("simple_query_string", "query_string", "multi_match"); - public static RelBuilder.AggCall TransferUserDefinedAggFunction( - Class UDAF, + /** + * Creates a SqlUserDefinedAggFunction that wraps a Java class implementing an aggregate function. + * + * @param udafClass The Java class that implements the UserDefinedAggFunction interface + * @param functionName The name of the function to be used in SQL statements + * @param returnType A SqlReturnTypeInference that determines the return type of the function + * @return A SqlUserDefinedAggFunction that can be used in SQL queries + */ + public static SqlUserDefinedAggFunction createUserDefinedAggFunction( + Class> udafClass, String functionName, - SqlReturnTypeInference returnType, + SqlReturnTypeInference returnType) { + return new SqlUserDefinedAggFunction( + new SqlIdentifier(functionName, SqlParserPos.ZERO), + SqlKind.OTHER_FUNCTION, + returnType, + null, + null, + AggregateFunctionImpl.create(udafClass), + false, + false, + Optionality.FORBIDDEN); + } + + /** + * Creates an aggregate call using the provided SqlAggFunction and arguments. + * + * @param aggFunction The aggregate function to call + * @param fields The primary fields to aggregate + * @param argList Additional arguments for the aggregate function + * @param relBuilder The RelBuilder instance used for building relational expressions + * @return An AggCall object representing the aggregate function call + */ + public static RelBuilder.AggCall makeAggregateCall( + SqlAggFunction aggFunction, List fields, List argList, RelBuilder relBuilder) { - SqlUserDefinedAggFunction sqlUDAF = - new SqlUserDefinedAggFunction( - new SqlIdentifier(functionName, SqlParserPos.ZERO), - SqlKind.OTHER_FUNCTION, - returnType, - null, - null, - AggregateFunctionImpl.create(UDAF), - false, - false, - Optionality.FORBIDDEN); List addArgList = new ArrayList<>(fields); addArgList.addAll(argList); - return relBuilder.aggregateCall(sqlUDAF, addArgList); + return relBuilder.aggregateCall(aggFunction, addArgList); + } + + /** + * Creates and registers a User Defined Aggregate Function (UDAF) and returns an AggCall that can + * be used in query plans. + * + * @param udafClass The class implementing the aggregate function behavior + * @param functionName The name of the aggregate function + * @param returnType The return type inference for determining the result type + * @param fields The primary fields to aggregate + * @param argList Additional arguments for the aggregate function + * @param relBuilder The RelBuilder instance used for building relational expressions + * @return An AggCall object representing the aggregate function call + */ + public static RelBuilder.AggCall createAggregateFunction( + Class> udafClass, + String functionName, + SqlReturnTypeInference returnType, + List fields, + List argList, + RelBuilder relBuilder) { + SqlUserDefinedAggFunction udaf = + createUserDefinedAggFunction(udafClass, functionName, returnType); + return makeAggregateCall(udaf, fields, argList, relBuilder); } public static SqlReturnTypeInference getReturnTypeInferenceForArray() { diff --git a/core/src/main/java/org/opensearch/sql/expression/function/PPLFuncImpTable.java b/core/src/main/java/org/opensearch/sql/expression/function/PPLFuncImpTable.java index 9c956dd3682..64f3e5a8f0e 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/PPLFuncImpTable.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/PPLFuncImpTable.java @@ -6,14 +6,13 @@ package org.opensearch.sql.expression.function; import static org.apache.calcite.sql.SqlJsonConstructorNullClause.NULL_ON_NULL; -import static org.apache.calcite.sql.type.SqlTypeFamily.IGNORE; import static org.opensearch.sql.calcite.utils.CalciteToolsHelper.STDDEV_POP_NULLABLE; import static org.opensearch.sql.calcite.utils.CalciteToolsHelper.STDDEV_SAMP_NULLABLE; import static org.opensearch.sql.calcite.utils.CalciteToolsHelper.VAR_POP_NULLABLE; import static org.opensearch.sql.calcite.utils.CalciteToolsHelper.VAR_SAMP_NULLABLE; import static org.opensearch.sql.calcite.utils.OpenSearchTypeFactory.TYPE_FACTORY; import static org.opensearch.sql.calcite.utils.OpenSearchTypeFactory.getLegacyTypeName; -import static org.opensearch.sql.calcite.utils.UserDefinedFunctionUtils.TransferUserDefinedAggFunction; +import static org.opensearch.sql.calcite.utils.UserDefinedFunctionUtils.createAggregateFunction; import static org.opensearch.sql.expression.function.BuiltinFunctionName.ABS; import static org.opensearch.sql.expression.function.BuiltinFunctionName.ACOS; import static org.opensearch.sql.expression.function.BuiltinFunctionName.ADD; @@ -232,7 +231,6 @@ import java.util.Optional; import java.util.StringJoiner; import java.util.concurrent.ConcurrentHashMap; -import java.util.function.BiFunction; import java.util.stream.Collectors; import java.util.stream.Stream; import javax.annotation.Nullable; @@ -240,6 +238,7 @@ import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.rex.RexLambda; import org.apache.calcite.rex.RexNode; +import org.apache.calcite.sql.SqlAggFunction; import org.apache.calcite.sql.SqlOperator; import org.apache.calcite.sql.fun.SqlLibraryOperators; import org.apache.calcite.sql.fun.SqlStdOperatorTable; @@ -252,9 +251,9 @@ import org.apache.calcite.sql.type.SqlOperandTypeChecker; import org.apache.calcite.sql.type.SqlTypeFamily; import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.calcite.sql.validate.SqlUserDefinedAggFunction; import org.apache.calcite.sql.validate.SqlUserDefinedFunction; import org.apache.calcite.tools.RelBuilder; -import org.apache.commons.lang3.function.TriFunction; import org.apache.commons.lang3.tuple.Pair; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; @@ -282,21 +281,11 @@ public interface FunctionImp { RelDataType ANY_TYPE = TYPE_FACTORY.createSqlType(SqlTypeName.ANY); RexNode resolve(RexBuilder builder, RexNode... args); - - /** - * @return the PPLTypeChecker. Default return null implies unknown parameters {@link - * CalciteFuncSignature} won't check parameters if it's null - */ - default PPLTypeChecker getTypeChecker() { - return null; - } } public interface FunctionImp1 extends FunctionImp { RexNode resolve(RexBuilder builder, RexNode arg1); - PPLTypeChecker IGNORE_1 = PPLTypeChecker.family(IGNORE); - @Override default RexNode resolve(RexBuilder builder, RexNode... args) { if (args.length != 1) { @@ -304,16 +293,9 @@ default RexNode resolve(RexBuilder builder, RexNode... args) { } return resolve(builder, args[0]); } - - @Override - default PPLTypeChecker getTypeChecker() { - return IGNORE_1; - } } public interface FunctionImp2 extends FunctionImp { - PPLTypeChecker IGNORE_2 = PPLTypeChecker.family(IGNORE, IGNORE); - RexNode resolve(RexBuilder builder, RexNode arg1, RexNode arg2); @Override @@ -323,11 +305,6 @@ default RexNode resolve(RexBuilder builder, RexNode... args) { } return resolve(builder, args[0], args[1]); } - - @Override - default PPLTypeChecker getTypeChecker() { - return IGNORE_2; - } } /** The singleton instance. */ @@ -362,14 +339,16 @@ default PPLTypeChecker getTypeChecker() { * implementations are independent of any specific data storage, should be registered here * internally. */ - private final ImmutableMap aggFunctionRegistry; + private final ImmutableMap> + aggFunctionRegistry; /** * The external agg function registry. Agg Functions whose implementations depend on a specific * data engine should be registered here. This reduces coupling between the core module and * particular storage backends. */ - private final Map aggExternalFunctionRegistry; + private final Map> + aggExternalFunctionRegistry; private PPLFuncImpTable(Builder builder, AggBuilder aggBuilder) { final ImmutableMap.Builder>> @@ -378,41 +357,53 @@ private PPLFuncImpTable(Builder builder, AggBuilder aggBuilder) { this.functionRegistry = ImmutableMap.copyOf(mapBuilder.build()); this.externalFunctionRegistry = new ConcurrentHashMap<>(); - final ImmutableMap.Builder aggMapBuilder = - ImmutableMap.builder(); + final ImmutableMap.Builder> + aggMapBuilder = ImmutableMap.builder(); aggBuilder.map.forEach(aggMapBuilder::put); this.aggFunctionRegistry = ImmutableMap.copyOf(aggMapBuilder.build()); this.aggExternalFunctionRegistry = new ConcurrentHashMap<>(); } /** - * Register a function implementation from external services dynamically. + * Register an operator from external services dynamically. * * @param functionName the name of the function, has to be defined in BuiltinFunctionName - * @param functionImp the implementation of the function + * @param operator a SqlOperator representing an externally implemented function */ - public void registerExternalFunction(BuiltinFunctionName functionName, FunctionImp functionImp) { - CalciteFuncSignature signature = - new CalciteFuncSignature(functionName.getName(), functionImp.getTypeChecker()); + public void registerExternalOperator(BuiltinFunctionName functionName, SqlOperator operator) { + PPLTypeChecker typeChecker = + wrapSqlOperandTypeChecker( + operator.getOperandTypeChecker(), + functionName.name(), + operator instanceof SqlUserDefinedFunction); + CalciteFuncSignature signature = new CalciteFuncSignature(functionName.getName(), typeChecker); externalFunctionRegistry.compute( functionName, (name, existingList) -> { List> list = existingList == null ? new ArrayList<>() : new ArrayList<>(existingList); - list.add(Pair.of(signature, functionImp)); + list.add(Pair.of(signature, (builder, args) -> builder.makeCall(operator, args))); return list; }); } /** - * Register a function implementation from external services dynamically. + * Register an external aggregate operator dynamically. * * @param functionName the name of the function, has to be defined in BuiltinFunctionName - * @param functionImp the implementation of the agg function + * @param aggFunction a SqlUserDefinedAggFunction representing the aggregate function + * implementation */ - public void registerExternalAggFunction( - BuiltinFunctionName functionName, AggHandler functionImp) { - aggExternalFunctionRegistry.put(functionName, functionImp); + public void registerExternalAggOperator( + BuiltinFunctionName functionName, SqlUserDefinedAggFunction aggFunction) { + PPLTypeChecker typeChecker = + wrapSqlOperandTypeChecker(aggFunction.getOperandTypeChecker(), functionName.name(), true); + CalciteFuncSignature signature = new CalciteFuncSignature(functionName.getName(), typeChecker); + AggHandler handler = + (distinct, field, argList, ctx) -> + UserDefinedFunctionUtils.makeAggregateCall( + aggFunction, List.of(field), argList, ctx.relBuilder); + aggExternalFunctionRegistry.put(functionName, Pair.of(signature, handler)); } public RelBuilder.AggCall resolveAgg( @@ -421,13 +412,37 @@ public RelBuilder.AggCall resolveAgg( RexNode field, List argList, CalcitePlanContext context) { - AggHandler handler = aggExternalFunctionRegistry.get(functionName); - if (handler == null) { - handler = aggFunctionRegistry.get(functionName); + var implementation = aggExternalFunctionRegistry.get(functionName); + if (implementation == null) { + implementation = aggFunctionRegistry.get(functionName); } - if (handler == null) { + if (implementation == null) { throw new IllegalStateException(String.format("Cannot resolve function: %s", functionName)); } + CalciteFuncSignature signature = implementation.getKey(); + List argTypes = new ArrayList<>(); + if (field != null) { + argTypes.add(field.getType()); + } + // Currently only PERCENTILE_APPROX and TAKE have additional arguments. + // Their additional arguments will always come as a map of + List additionalArgTypes = + argList.stream().map(PlanUtils::derefMapCall).map(RexNode::getType).toList(); + argTypes.addAll(additionalArgTypes); + if (!signature.match(functionName.getName(), argTypes)) { + String errorMessagePattern = + argTypes.size() <= 1 + ? "Aggregation function %s expects field type {%s}, but got %s" + : "Aggregation function %s expects field type and additional arguments {%s}, but got" + + " %s"; + throw new ExpressionEvaluationException( + String.format( + errorMessagePattern, + functionName, + signature.typeChecker().getAllowedSignatures(), + getActualSignature(argTypes))); + } + var handler = implementation.getValue(); return handler.apply(distinct, field, argList, context); } @@ -549,6 +564,12 @@ private void compulsoryCast( return null; } + /** + * Get a string representation of the argument types expressed in ExprType for error messages. + * + * @param argTypes the list of argument types as {@link RelDataType} + * @return a string in the format [type1,type2,...] representing the argument types + */ private static String getActualSignature(List argTypes) { return "[" + argTypes.stream() @@ -558,11 +579,64 @@ private static String getActualSignature(List argTypes) { + "]"; } + /** + * Wraps a {@link SqlOperandTypeChecker} into a {@link PPLTypeChecker} for use in function + * signature validation. + * + * @param typeChecker the original SQL operand type checker + * @param functionName the name of the function for error reporting + * @param isUserDefinedFunction true if the function is user-defined, false otherwise + * @return a {@link PPLTypeChecker} that delegates to the provided {@code typeChecker} + */ + private static PPLTypeChecker wrapSqlOperandTypeChecker( + SqlOperandTypeChecker typeChecker, String functionName, boolean isUserDefinedFunction) { + PPLTypeChecker pplTypeChecker; + // Only the composite operand type checker for UDFs are concerned here. + if (isUserDefinedFunction + && typeChecker instanceof CompositeOperandTypeChecker compositeTypeChecker) { + // UDFs implement their own composite type checkers, which always use OR logic for + // argument + // types. Verifying the composition type would require accessing a protected field in + // CompositeOperandTypeChecker. If access to this field is not allowed, type checking will + // be skipped, so we avoid checking the composition type here. + pplTypeChecker = PPLTypeChecker.wrapComposite(compositeTypeChecker, false); + } else if (typeChecker instanceof ImplicitCastOperandTypeChecker implicitCastTypeChecker) { + pplTypeChecker = PPLTypeChecker.wrapFamily(implicitCastTypeChecker); + } else if (typeChecker instanceof CompositeOperandTypeChecker compositeTypeChecker) { + // If compositeTypeChecker contains operand checkers other than family type checkers or + // other than OR compositions, the function with be registered with a null type checker, + // which means the function will not be type checked. + try { + pplTypeChecker = PPLTypeChecker.wrapComposite(compositeTypeChecker, true); + } catch (IllegalArgumentException | UnsupportedOperationException e) { + logger.debug( + String.format( + "Failed to create composite type checker for operator: %s. Will skip its type" + + " checking", + functionName), + e); + pplTypeChecker = null; + } + } else if (typeChecker instanceof SameOperandTypeChecker comparableTypeChecker) { + // Comparison operators like EQUAL, GREATER_THAN, LESS_THAN, etc. + // SameOperandTypeCheckers like COALESCE, IFNULL, etc. + pplTypeChecker = PPLTypeChecker.wrapComparable(comparableTypeChecker); + } else if (typeChecker instanceof UDFOperandMetadata.UDTOperandMetadata udtOperandMetadata) { + pplTypeChecker = PPLTypeChecker.wrapUDT(udtOperandMetadata.allowedParamTypes()); + } else { + logger.info( + "Cannot create type checker for function: {}. Will skip its type checking", functionName); + pplTypeChecker = null; + } + return pplTypeChecker; + } + @SuppressWarnings({"UnusedReturnValue", "SameParameterValue"}) private abstract static class AbstractBuilder { /** Maps an operator to an implementation. */ - abstract void register(BuiltinFunctionName functionName, FunctionImp functionImp); + abstract void register( + BuiltinFunctionName functionName, FunctionImp functionImp, PPLTypeChecker typeChecker); /** * Register one or multiple operators under a single function name. This allows function @@ -585,40 +659,13 @@ public void registerOperator(BuiltinFunctionName functionName, SqlOperator... op typeChecker = operator.getOperandTypeChecker(); } - // Only the composite operand type checker for UDFs are concerned here. - if (operator instanceof SqlUserDefinedFunction - && typeChecker instanceof CompositeOperandTypeChecker compositeTypeChecker) { - // UDFs implement their own composite type checkers, which always use OR logic for - // argument - // types. Verifying the composition type would require accessing a protected field in - // CompositeOperandTypeChecker. If access to this field is not allowed, type checking will - // be skipped, so we avoid checking the composition type here. - register( - functionName, wrapWithCompositeTypeChecker(operator, compositeTypeChecker, false)); - } else if (typeChecker instanceof ImplicitCastOperandTypeChecker implicitCastTypeChecker) { - register( - functionName, wrapWithImplicitCastTypeChecker(operator, implicitCastTypeChecker)); - } else if (typeChecker instanceof CompositeOperandTypeChecker compositeTypeChecker) { - // If compositeTypeChecker contains operand checkers other than family type checkers or - // other than OR compositions, the function with be registered with a null type checker, - // which means the function will not be type checked. - register( - functionName, wrapWithCompositeTypeChecker(operator, compositeTypeChecker, true)); - } else if (typeChecker instanceof SameOperandTypeChecker comparableTypeChecker) { - // Comparison operators like EQUAL, GREATER_THAN, LESS_THAN, etc. - // SameOperandTypeCheckers like COALESCE, IFNULL, etc. - register(functionName, wrapWithComparableTypeChecker(operator, comparableTypeChecker)); - } else if (typeChecker - instanceof UDFOperandMetadata.UDTOperandMetadata udtOperandMetadata) { - register(functionName, wrapWithUdtTypeChecker(operator, udtOperandMetadata)); - } else { - logger.info( - "Cannot create type checker for function: {}. Will skip its type checking", - functionName); - register( - functionName, - (RexBuilder builder, RexNode... node) -> builder.makeCall(operator, node)); - } + PPLTypeChecker pplTypeChecker = + wrapSqlOperandTypeChecker( + typeChecker, operator.getName(), operator instanceof SqlUserDefinedFunction); + register( + functionName, + (RexBuilder builder, RexNode... args) -> builder.makeCall(operator, args), + pplTypeChecker); } } @@ -629,124 +676,6 @@ private static SqlOperandTypeChecker extractTypeCheckerFromUDF( return (udfOperandMetadata == null) ? null : udfOperandMetadata.getInnerTypeChecker(); } - // Such wrapWith*TypeChecker methods are useful in that we don't have to create explicit - // overrides of resolve function for different number of operands. - // I.e. we don't have to explicitly call - // (FuncImp1) (builder, arg1) -> builder.makeCall(operator, arg1); - // (FuncImp2) (builder, arg1, arg2) -> builder.makeCall(operator, arg1, arg2); - // etc. - - /** - * Wrap a SqlOperator into a FunctionImp with a composite type checker. - * - * @param operator the SqlOperator to wrap - * @param typeChecker the CompositeOperandTypeChecker to use for type checking - * @param checkCompositionType if true, the type checker will check whether the composition type - * of the type checker is OR. - * @return a FunctionImp that resolves to the operator and has the specified type checker - */ - private static FunctionImp wrapWithCompositeTypeChecker( - SqlOperator operator, - CompositeOperandTypeChecker typeChecker, - boolean checkCompositionType) { - return new FunctionImp() { - @Override - public RexNode resolve(RexBuilder builder, RexNode... args) { - return builder.makeCall(operator, args); - } - - @Override - public PPLTypeChecker getTypeChecker() { - try { - return PPLTypeChecker.wrapComposite(typeChecker, checkCompositionType); - } catch (IllegalArgumentException | UnsupportedOperationException e) { - logger.debug( - String.format( - "Failed to create composite type checker for operator: %s. Will skip its type" - + " checking", - operator.getName()), - e); - return null; - } - } - }; - } - - private static FunctionImp wrapWithImplicitCastTypeChecker( - SqlOperator operator, ImplicitCastOperandTypeChecker typeChecker) { - return new FunctionImp() { - @Override - public RexNode resolve(RexBuilder builder, RexNode... args) { - return builder.makeCall(operator, args); - } - - @Override - public PPLTypeChecker getTypeChecker() { - return PPLTypeChecker.wrapFamily(typeChecker); - } - }; - } - - private static FunctionImp wrapWithComparableTypeChecker( - SqlOperator operator, SameOperandTypeChecker typeChecker) { - return new FunctionImp() { - @Override - public RexNode resolve(RexBuilder builder, RexNode... args) { - return builder.makeCall(operator, args); - } - - @Override - public PPLTypeChecker getTypeChecker() { - return PPLTypeChecker.wrapComparable(typeChecker); - } - }; - } - - private static FunctionImp wrapWithUdtTypeChecker( - SqlOperator operator, UDFOperandMetadata.UDTOperandMetadata udtOperandMetadata) { - return new FunctionImp() { - @Override - public RexNode resolve(RexBuilder builder, RexNode... args) { - return builder.makeCall(operator, args); - } - - @Override - public PPLTypeChecker getTypeChecker() { - return PPLTypeChecker.wrapUDT(udtOperandMetadata.allowedParamTypes()); - } - }; - } - - private static FunctionImp createFunctionImpWithTypeChecker( - BiFunction resolver, PPLTypeChecker typeChecker) { - return new FunctionImp1() { - @Override - public RexNode resolve(RexBuilder builder, RexNode arg1) { - return resolver.apply(builder, arg1); - } - - @Override - public PPLTypeChecker getTypeChecker() { - return typeChecker; - } - }; - } - - private static FunctionImp createFunctionImpWithTypeChecker( - TriFunction resolver, PPLTypeChecker typeChecker) { - return new FunctionImp2() { - @Override - public RexNode resolve(RexBuilder builder, RexNode arg1, RexNode arg2) { - return resolver.apply(builder, arg1, arg2); - } - - @Override - public PPLTypeChecker getTypeChecker() { - return typeChecker; - } - }; - } - void populate() { // register operators for comparison registerOperator(NOTEQUAL, PPLBuiltinOperators.NOT_EQUALS_IP, SqlStdOperatorTable.NOT_EQUALS); @@ -932,14 +861,16 @@ void populate() { builder.makeCall( SqlStdOperatorTable.JSON_ARRAY, Stream.concat(Stream.of(builder.makeFlag(NULL_ON_NULL)), Arrays.stream(args)) - .toArray(RexNode[]::new)))); + .toArray(RexNode[]::new))), + null); register( JSON_OBJECT, ((builder, args) -> builder.makeCall( SqlStdOperatorTable.JSON_OBJECT, Stream.concat(Stream.of(builder.makeFlag(NULL_ON_NULL)), Arrays.stream(args)) - .toArray(RexNode[]::new)))); + .toArray(RexNode[]::new))), + null); registerOperator(JSON, PPLBuiltinOperators.JSON); registerOperator(JSON_ARRAY_LENGTH, PPLBuiltinOperators.JSON_ARRAY_LENGTH); registerOperator(JSON_EXTRACT, PPLBuiltinOperators.JSON_EXTRACT); @@ -954,51 +885,52 @@ void populate() { // Note, make the implementation an individual class if too complex. register( TRIM, - createFunctionImpWithTypeChecker( + (FunctionImp1) (builder, arg) -> builder.makeCall( SqlStdOperatorTable.TRIM, builder.makeFlag(Flag.BOTH), builder.makeLiteral(" "), arg), - PPLTypeChecker.family(SqlTypeFamily.CHARACTER))); + PPLTypeChecker.family(SqlTypeFamily.CHARACTER)); register( LTRIM, - createFunctionImpWithTypeChecker( + (FunctionImp1) (builder, arg) -> builder.makeCall( SqlStdOperatorTable.TRIM, builder.makeFlag(Flag.LEADING), builder.makeLiteral(" "), arg), - PPLTypeChecker.family(SqlTypeFamily.CHARACTER))); + PPLTypeChecker.family(SqlTypeFamily.CHARACTER)); register( RTRIM, - createFunctionImpWithTypeChecker( + (FunctionImp1) (builder, arg) -> builder.makeCall( SqlStdOperatorTable.TRIM, builder.makeFlag(Flag.TRAILING), builder.makeLiteral(" "), arg), - PPLTypeChecker.family(SqlTypeFamily.CHARACTER))); + PPLTypeChecker.family(SqlTypeFamily.CHARACTER)); register( ATAN, - createFunctionImpWithTypeChecker( + (FunctionImp2) (builder, arg1, arg2) -> builder.makeCall(SqlStdOperatorTable.ATAN2, arg1, arg2), - PPLTypeChecker.family(SqlTypeFamily.NUMERIC, SqlTypeFamily.NUMERIC))); + PPLTypeChecker.family(SqlTypeFamily.NUMERIC, SqlTypeFamily.NUMERIC)); register( STRCMP, - createFunctionImpWithTypeChecker( + (FunctionImp2) (builder, arg1, arg2) -> builder.makeCall(SqlLibraryOperators.STRCMP, arg2, arg1), - PPLTypeChecker.family(SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER))); + PPLTypeChecker.family(SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER)); // SqlStdOperatorTable.SUBSTRING.getOperandTypeChecker is null. We manually create a type // checker for it. register( SUBSTRING, - wrapWithCompositeTypeChecker( - SqlStdOperatorTable.SUBSTRING, + (RexBuilder builder, RexNode... args) -> + builder.makeCall(SqlStdOperatorTable.SUBSTRING, args), + PPLTypeChecker.wrapComposite( (CompositeOperandTypeChecker) OperandTypes.family(SqlTypeFamily.CHARACTER, SqlTypeFamily.INTEGER) .or( @@ -1009,8 +941,9 @@ void populate() { false)); register( SUBSTR, - wrapWithCompositeTypeChecker( - SqlStdOperatorTable.SUBSTRING, + (RexBuilder builder, RexNode... args) -> + builder.makeCall(SqlStdOperatorTable.SUBSTRING, args), + PPLTypeChecker.wrapComposite( (CompositeOperandTypeChecker) OperandTypes.family(SqlTypeFamily.CHARACTER, SqlTypeFamily.INTEGER) .or( @@ -1024,73 +957,77 @@ void populate() { // operands. register( INTERNAL_ITEM, - wrapWithCompositeTypeChecker( - SqlStdOperatorTable.ITEM, + (RexBuilder builder, RexNode... args) -> builder.makeCall(SqlStdOperatorTable.ITEM, args), + PPLTypeChecker.wrapComposite( (CompositeOperandTypeChecker) OperandTypes.family(SqlTypeFamily.ARRAY, SqlTypeFamily.INTEGER) .or(OperandTypes.family(SqlTypeFamily.MAP, SqlTypeFamily.ANY)), false)); register( LOG, - createFunctionImpWithTypeChecker( + (FunctionImp2) (builder, arg1, arg2) -> builder.makeCall(SqlLibraryOperators.LOG, arg2, arg1), - PPLTypeChecker.family(SqlTypeFamily.NUMERIC, SqlTypeFamily.NUMERIC))); + PPLTypeChecker.family(SqlTypeFamily.NUMERIC, SqlTypeFamily.NUMERIC)); register( LOG, - createFunctionImpWithTypeChecker( + (FunctionImp1) (builder, arg) -> builder.makeCall( SqlLibraryOperators.LOG, arg, builder.makeApproxLiteral(BigDecimal.valueOf(Math.E))), - PPLTypeChecker.family(SqlTypeFamily.NUMERIC))); + PPLTypeChecker.family(SqlTypeFamily.NUMERIC)); // SqlStdOperatorTable.SQRT is declared but not implemented. The call to SQRT in Calcite is // converted to POWER(x, 0.5). register( SQRT, - createFunctionImpWithTypeChecker( + (FunctionImp1) (builder, arg) -> builder.makeCall( SqlStdOperatorTable.POWER, arg, builder.makeApproxLiteral(BigDecimal.valueOf(0.5))), - PPLTypeChecker.family(SqlTypeFamily.NUMERIC))); + PPLTypeChecker.family(SqlTypeFamily.NUMERIC)); register( TYPEOF, (FunctionImp1) (builder, arg) -> - builder.makeLiteral(getLegacyTypeName(arg.getType(), QueryType.PPL))); - register(XOR, new XOR_FUNC()); + builder.makeLiteral(getLegacyTypeName(arg.getType(), QueryType.PPL)), + null); + register( + XOR, + (FunctionImp2) + (builder, arg1, arg2) -> builder.makeCall(SqlStdOperatorTable.NOT_EQUALS, arg1, arg2), + PPLTypeChecker.family(SqlTypeFamily.BOOLEAN, SqlTypeFamily.BOOLEAN)); // SqlStdOperatorTable.CASE.getOperandTypeChecker is null. We manually create a type checker // for it. The second and third operands are required to be of the same type. If not, // it will throw an IllegalArgumentException with information Can't find leastRestrictive type register( IF, - wrapWithImplicitCastTypeChecker( - SqlStdOperatorTable.CASE, - OperandTypes.family(SqlTypeFamily.BOOLEAN, SqlTypeFamily.ANY, SqlTypeFamily.ANY))); + (RexBuilder builder, RexNode... args) -> builder.makeCall(SqlStdOperatorTable.CASE, args), + PPLTypeChecker.family(SqlTypeFamily.BOOLEAN, SqlTypeFamily.ANY, SqlTypeFamily.ANY)); register( NULLIF, - createFunctionImpWithTypeChecker( + (FunctionImp2) (builder, arg1, arg2) -> builder.makeCall( SqlStdOperatorTable.CASE, builder.makeCall(SqlStdOperatorTable.EQUALS, arg1, arg2), builder.makeNullLiteral(arg1.getType()), arg1), - PPLTypeChecker.wrapComparable((SameOperandTypeChecker) OperandTypes.SAME_SAME))); + PPLTypeChecker.wrapComparable((SameOperandTypeChecker) OperandTypes.SAME_SAME)); register( IS_EMPTY, - createFunctionImpWithTypeChecker( + (FunctionImp1) (builder, arg) -> builder.makeCall( SqlStdOperatorTable.OR, builder.makeCall(SqlStdOperatorTable.IS_NULL, arg), builder.makeCall(SqlStdOperatorTable.IS_EMPTY, arg)), - PPLTypeChecker.family(SqlTypeFamily.ANY))); + PPLTypeChecker.family(SqlTypeFamily.ANY)); register( IS_BLANK, - createFunctionImpWithTypeChecker( + (FunctionImp1) (builder, arg) -> builder.makeCall( SqlStdOperatorTable.OR, @@ -1102,10 +1039,10 @@ void populate() { builder.makeFlag(Flag.BOTH), builder.makeLiteral(" "), arg))), - PPLTypeChecker.family(SqlTypeFamily.ANY))); + PPLTypeChecker.family(SqlTypeFamily.ANY)); register( LIKE, - createFunctionImpWithTypeChecker( + (FunctionImp2) (builder, arg1, arg2) -> builder.makeCall( SqlLibraryOperators.ILIKE, @@ -1114,7 +1051,7 @@ void populate() { // TODO: Figure out escaping solution. '\\' is used for JSON input but is not // necessary for SQL function input builder.makeLiteral("\\")), - PPLTypeChecker.family(SqlTypeFamily.STRING, SqlTypeFamily.STRING))); + PPLTypeChecker.family(SqlTypeFamily.STRING, SqlTypeFamily.STRING)); } } @@ -1123,9 +1060,10 @@ private static class Builder extends AbstractBuilder { new HashMap<>(); @Override - void register(BuiltinFunctionName functionName, FunctionImp implement) { + void register( + BuiltinFunctionName functionName, FunctionImp implement, PPLTypeChecker typeChecker) { CalciteFuncSignature signature = - new CalciteFuncSignature(functionName.getName(), implement.getTypeChecker()); + new CalciteFuncSignature(functionName.getName(), typeChecker); if (map.containsKey(functionName)) { map.get(functionName).add(Pair.of(signature, implement)); } else { @@ -1134,75 +1072,90 @@ void register(BuiltinFunctionName functionName, FunctionImp implement) { } } - // ------------------------------------------------------------- - // FUNCTIONS - // ------------------------------------------------------------- - /** Implement XOR via NOT_EQUAL, and limit the arguments' type to boolean only */ - private static class XOR_FUNC implements FunctionImp2 { - @Override - public RexNode resolve(RexBuilder builder, RexNode arg1, RexNode arg2) { - return builder.makeCall(SqlStdOperatorTable.NOT_EQUALS, arg1, arg2); - } + private static class AggBuilder { + private final Map> map = + new HashMap<>(); - @Override - public PPLTypeChecker getTypeChecker() { - SqlTypeFamily booleanFamily = SqlTypeName.BOOLEAN.getFamily(); - return PPLTypeChecker.family(booleanFamily, booleanFamily); + void register( + BuiltinFunctionName functionName, AggHandler aggHandler, PPLTypeChecker typeChecker) { + CalciteFuncSignature signature = + new CalciteFuncSignature(functionName.getName(), typeChecker); + map.put(functionName, Pair.of(signature, aggHandler)); } - } - - private static class AggBuilder { - private final Map map = new HashMap<>(); - void register(BuiltinFunctionName functionName, AggHandler aggHandler) { - map.put(functionName, aggHandler); + void registerOperator(BuiltinFunctionName functionName, SqlAggFunction aggFunction) { + PPLTypeChecker typeChecker = + wrapSqlOperandTypeChecker(aggFunction.getOperandTypeChecker(), functionName.name(), true); + AggHandler handler = + (distinct, field, argList, ctx) -> + UserDefinedFunctionUtils.makeAggregateCall( + aggFunction, List.of(field), argList, ctx.relBuilder); + register(functionName, handler, typeChecker); } void populate() { - register(MAX, (distinct, field, argList, ctx) -> ctx.relBuilder.max(field)); - register(MIN, (distinct, field, argList, ctx) -> ctx.relBuilder.min(field)); + registerOperator(MAX, SqlStdOperatorTable.MAX); + registerOperator(MIN, SqlStdOperatorTable.MIN); + registerOperator(SUM, SqlStdOperatorTable.SUM); - register(AVG, (distinct, field, argList, ctx) -> ctx.relBuilder.avg(distinct, null, field)); + register( + AVG, + (distinct, field, argList, ctx) -> ctx.relBuilder.avg(distinct, null, field), + wrapSqlOperandTypeChecker( + SqlStdOperatorTable.AVG.getOperandTypeChecker(), AVG.name(), false)); register( COUNT, (distinct, field, argList, ctx) -> ctx.relBuilder.count( - distinct, null, field == null ? ImmutableList.of() : ImmutableList.of(field))); - register(SUM, (distinct, field, argList, ctx) -> ctx.relBuilder.sum(distinct, null, field)); + distinct, null, field == null ? ImmutableList.of() : ImmutableList.of(field)), + wrapSqlOperandTypeChecker( + SqlStdOperatorTable.COUNT.getOperandTypeChecker(), COUNT.name(), false)); register( VARSAMP, - (distinct, field, argList, ctx) -> - ctx.relBuilder.aggregateCall(VAR_SAMP_NULLABLE, field)); + (distinct, field, argList, ctx) -> ctx.relBuilder.aggregateCall(VAR_SAMP_NULLABLE, field), + wrapSqlOperandTypeChecker( + SqlStdOperatorTable.VAR_SAMP.getOperandTypeChecker(), VARSAMP.name(), false)); register( VARPOP, - (distinct, field, argList, ctx) -> ctx.relBuilder.aggregateCall(VAR_POP_NULLABLE, field)); + (distinct, field, argList, ctx) -> ctx.relBuilder.aggregateCall(VAR_POP_NULLABLE, field), + wrapSqlOperandTypeChecker( + SqlStdOperatorTable.VAR_POP.getOperandTypeChecker(), VARPOP.name(), false)); register( STDDEV_SAMP, (distinct, field, argList, ctx) -> - ctx.relBuilder.aggregateCall(STDDEV_SAMP_NULLABLE, field)); + ctx.relBuilder.aggregateCall(STDDEV_SAMP_NULLABLE, field), + wrapSqlOperandTypeChecker( + SqlStdOperatorTable.STDDEV_SAMP.getOperandTypeChecker(), STDDEV_SAMP.name(), false)); register( STDDEV_POP, (distinct, field, argList, ctx) -> - ctx.relBuilder.aggregateCall(STDDEV_POP_NULLABLE, field)); + ctx.relBuilder.aggregateCall(STDDEV_POP_NULLABLE, field), + wrapSqlOperandTypeChecker( + SqlStdOperatorTable.STDDEV_POP.getOperandTypeChecker(), STDDEV_POP.name(), false)); register( TAKE, (distinct, field, argList, ctx) -> { List newArgList = argList.stream().map(PlanUtils::derefMapCall).collect(Collectors.toList()); - return TransferUserDefinedAggFunction( + return createAggregateFunction( TakeAggFunction.class, "TAKE", UserDefinedFunctionUtils.getReturnTypeInferenceForArray(), List.of(field), newArgList, ctx.relBuilder); - }); + }, + PPLTypeChecker.wrapComposite( + (CompositeOperandTypeChecker) + OperandTypes.ANY.or( + OperandTypes.family(SqlTypeFamily.ANY, SqlTypeFamily.INTEGER)), + false)); register( PERCENTILE_APPROX, @@ -1210,25 +1163,32 @@ void populate() { List newArgList = argList.stream().map(PlanUtils::derefMapCall).collect(Collectors.toList()); newArgList.add(ctx.rexBuilder.makeFlag(field.getType().getSqlTypeName())); - return TransferUserDefinedAggFunction( + return createAggregateFunction( PercentileApproxFunction.class, "percentile_approx", ReturnTypes.ARG0_FORCE_NULLABLE, List.of(field), newArgList, ctx.relBuilder); - }); + }, + PPLTypeChecker.wrapComposite( + (CompositeOperandTypeChecker) + OperandTypes.NUMERIC_NUMERIC.or( + OperandTypes.family( + SqlTypeFamily.NUMERIC, SqlTypeFamily.NUMERIC, SqlTypeFamily.NUMERIC)), + false)); register( INTERNAL_PATTERN, (distinct, field, argList, ctx) -> - TransferUserDefinedAggFunction( + createAggregateFunction( LogPatternAggFunction.class, "pattern", ReturnTypes.explicit(UserDefinedFunctionUtils.nullablePatternAggList), List.of(field), argList, - ctx.relBuilder)); + ctx.relBuilder), + null); } } } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/executor/OpenSearchExecutionEngine.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/executor/OpenSearchExecutionEngine.java index b0002f9ba52..ff3e05b3d0c 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/executor/OpenSearchExecutionEngine.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/executor/OpenSearchExecutionEngine.java @@ -5,10 +5,6 @@ package org.opensearch.sql.opensearch.executor; -import static org.opensearch.sql.calcite.utils.OpenSearchTypeFactory.convertRelDataTypeToExprType; -import static org.opensearch.sql.calcite.utils.UserDefinedFunctionUtils.TransferUserDefinedAggFunction; -import static org.opensearch.sql.expression.function.BuiltinFunctionName.DISTINCT_COUNT_APPROX; - import java.security.AccessController; import java.security.PrivilegedAction; import java.sql.PreparedStatement; @@ -29,9 +25,15 @@ import org.apache.calcite.sql.SqlExplainLevel; import org.apache.calcite.sql.type.ReturnTypes; import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.calcite.sql.validate.SqlUserDefinedAggFunction; +import org.apache.calcite.sql.validate.SqlUserDefinedFunction; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; import org.opensearch.sql.ast.statement.Explain.ExplainFormat; import org.opensearch.sql.calcite.CalcitePlanContext; import org.opensearch.sql.calcite.utils.CalciteToolsHelper.OpenSearchRelRunners; +import org.opensearch.sql.calcite.utils.OpenSearchTypeFactory; +import org.opensearch.sql.calcite.utils.UserDefinedFunctionUtils; import org.opensearch.sql.common.response.ResponseListener; import org.opensearch.sql.data.model.ExprTupleValue; import org.opensearch.sql.data.model.ExprValue; @@ -45,6 +47,7 @@ import org.opensearch.sql.expression.function.BuiltinFunctionName; import org.opensearch.sql.expression.function.PPLFuncImpTable; import org.opensearch.sql.opensearch.client.OpenSearchClient; +import org.opensearch.sql.opensearch.client.OpenSearchNodeClient; import org.opensearch.sql.opensearch.executor.protector.ExecutionProtector; import org.opensearch.sql.opensearch.functions.DistinctCountApproxAggFunction; import org.opensearch.sql.opensearch.functions.GeoIpFunction; @@ -54,6 +57,7 @@ /** OpenSearch execution engine implementation. */ public class OpenSearchExecutionEngine implements ExecutionEngine { + private static final Logger logger = LogManager.getLogger(OpenSearchExecutionEngine.class); private final OpenSearchClient client; @@ -250,7 +254,7 @@ private void buildResultSet( exprType = ExprCoreType.UNDEFINED; } } else { - exprType = convertRelDataTypeToExprType(fieldType); + exprType = OpenSearchTypeFactory.convertRelDataTypeToExprType(fieldType); } columns.add(new Column(columnName, null, exprType)); } @@ -261,20 +265,22 @@ private void buildResultSet( /** Registers opensearch-dependent functions */ private void registerOpenSearchFunctions() { - PPLFuncImpTable.FunctionImp geoIpImpl = - (builder, args) -> - builder.makeCall(new GeoIpFunction(client.getNodeClient()).toUDF("GEOIP"), args); - PPLFuncImpTable.INSTANCE.registerExternalFunction(BuiltinFunctionName.GEOIP, geoIpImpl); + if (client instanceof OpenSearchNodeClient) { + SqlUserDefinedFunction geoIpFunction = + new GeoIpFunction(client.getNodeClient()).toUDF("GEOIP"); + PPLFuncImpTable.INSTANCE.registerExternalOperator(BuiltinFunctionName.GEOIP, geoIpFunction); + } else { + logger.info( + "Function [GEOIP] not registered: incompatible client type {}", + client.getClass().getName()); + } - PPLFuncImpTable.INSTANCE.registerExternalAggFunction( - DISTINCT_COUNT_APPROX, - (distinct, field, argList, ctx) -> - TransferUserDefinedAggFunction( - DistinctCountApproxAggFunction.class, - "APPROX_DISTINCT_COUNT", - ReturnTypes.BIGINT_FORCE_NULLABLE, - List.of(field), - argList, - ctx.relBuilder)); + SqlUserDefinedAggFunction approxDistinctCountFunction = + UserDefinedFunctionUtils.createUserDefinedAggFunction( + DistinctCountApproxAggFunction.class, + "APPROX_DISTINCT_COUNT", + ReturnTypes.BIGINT_FORCE_NULLABLE); + PPLFuncImpTable.INSTANCE.registerExternalAggOperator( + BuiltinFunctionName.DISTINCT_COUNT_APPROX, approxDistinctCountFunction); } } diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLFunctionTypeTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLFunctionTypeTest.java index 977805f9d1a..ef619637709 100644 --- a/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLFunctionTypeTest.java +++ b/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLFunctionTypeTest.java @@ -221,4 +221,73 @@ public void testLog2WithWrongArgShouldThrow() { verifyErrorMessageContains( wrongArgException, "LOG2 function expects {[INTEGER],[DOUBLE]}, but got [STRING,STRING]"); } + + @Test + public void testAvgWithWrongArgType() { + Exception e = + Assert.assertThrows( + ExpressionEvaluationException.class, + () -> getRelNode("source=EMP | stats avg(ENAME) as avg_name")); + verifyErrorMessageContains( + e, "Aggregation function AVG expects field type {[INTEGER],[DOUBLE]}, but got [STRING]"); + } + + @Test + public void testVarsampWithWrongArgType() { + Exception e = + Assert.assertThrows( + ExpressionEvaluationException.class, + () -> getRelNode("source=EMP | stats var_samp(ENAME) as varsamp_name")); + verifyErrorMessageContains( + e, + "Aggregation function VARSAMP expects field type {[INTEGER],[DOUBLE]}, but got [STRING]"); + } + + @Test + public void testVarpopWithWrongArgType() { + Exception e = + Assert.assertThrows( + ExpressionEvaluationException.class, + () -> getRelNode("source=EMP | stats var_pop(ENAME) as varpop_name")); + verifyErrorMessageContains( + e, "Aggregation function VARPOP expects field type {[INTEGER],[DOUBLE]}, but got [STRING]"); + } + + @Test + public void testStddevSampWithWrongArgType() { + Exception e = + Assert.assertThrows( + ExpressionEvaluationException.class, + () -> getRelNode("source=EMP | stats stddev_samp(ENAME) as stddev_name")); + verifyErrorMessageContains( + e, + "Aggregation function STDDEV_SAMP expects field type {[INTEGER],[DOUBLE]}, but got" + + " [STRING]"); + } + + @Test + public void testStddevPopWithWrongArgType() { + Exception e = + Assert.assertThrows( + ExpressionEvaluationException.class, + () -> getRelNode("source=EMP | stats stddev_pop(ENAME) as stddev_name")); + verifyErrorMessageContains( + e, + "Aggregation function STDDEV_POP expects field type {[INTEGER],[DOUBLE]}, but got" + + " [STRING]"); + } + + @Test + public void testPercentileApproxWithWrongArgType() { + // First argument should be numeric + Exception e1 = + Assert.assertThrows( + ExpressionEvaluationException.class, + () -> getRelNode("source=EMP | stats percentile_approx(ENAME, 50) as percentile")); + verifyErrorMessageContains( + e1, + "Aggregation function PERCENTILE_APPROX expects field type and additional arguments" + + " {[INTEGER,INTEGER],[INTEGER,DOUBLE],[DOUBLE,INTEGER],[DOUBLE,DOUBLE],[INTEGER,INTEGER,INTEGER],[INTEGER,INTEGER,DOUBLE],[INTEGER,DOUBLE,INTEGER],[INTEGER,DOUBLE,DOUBLE],[DOUBLE,INTEGER,INTEGER],[DOUBLE,INTEGER,DOUBLE],[DOUBLE,DOUBLE,INTEGER],[DOUBLE,DOUBLE,DOUBLE]}," + + " but got [STRING,INTEGER]"); + } }