diff --git a/core/src/main/java/org/opensearch/sql/calcite/utils/PlanUtils.java b/core/src/main/java/org/opensearch/sql/calcite/utils/PlanUtils.java index 922f654f8be..0b023777f1c 100644 --- a/core/src/main/java/org/opensearch/sql/calcite/utils/PlanUtils.java +++ b/core/src/main/java/org/opensearch/sql/calcite/utils/PlanUtils.java @@ -10,11 +10,6 @@ import static org.apache.calcite.rex.RexWindowBounds.UNBOUNDED_PRECEDING; import static org.apache.calcite.rex.RexWindowBounds.following; import static org.apache.calcite.rex.RexWindowBounds.preceding; -import static org.opensearch.sql.calcite.utils.CalciteToolsHelper.STDDEV_POP_NULLABLE; -import static org.opensearch.sql.calcite.utils.CalciteToolsHelper.STDDEV_SAMP_NULLABLE; -import static org.opensearch.sql.calcite.utils.CalciteToolsHelper.VAR_POP_NULLABLE; -import static org.opensearch.sql.calcite.utils.CalciteToolsHelper.VAR_SAMP_NULLABLE; -import static org.opensearch.sql.calcite.utils.UserDefinedFunctionUtils.TransferUserDefinedAggFunction; import com.google.common.collect.ImmutableList; import java.util.ArrayList; @@ -25,7 +20,6 @@ import org.apache.calcite.rex.RexVisitorImpl; import org.apache.calcite.rex.RexWindowBound; import org.apache.calcite.sql.fun.SqlStdOperatorTable; -import org.apache.calcite.sql.type.ReturnTypes; import org.apache.calcite.sql.type.SqlTypeName; import org.apache.calcite.tools.RelBuilder; import org.opensearch.sql.ast.AbstractNodeVisitor; @@ -37,9 +31,8 @@ import org.opensearch.sql.ast.tree.Relation; import org.opensearch.sql.ast.tree.UnresolvedPlan; import org.opensearch.sql.calcite.CalcitePlanContext; -import org.opensearch.sql.calcite.udf.udaf.PercentileApproxFunction; -import org.opensearch.sql.calcite.udf.udaf.TakeAggFunction; import org.opensearch.sql.expression.function.BuiltinFunctionName; +import org.opensearch.sql.expression.function.PPLFuncImpTable; public interface PlanUtils { @@ -232,56 +225,7 @@ static RelBuilder.AggCall makeAggCall( boolean distinct, RexNode field, List 
argList) { - switch (functionName) { - case MAX: - return context.relBuilder.max(field); - case MIN: - return context.relBuilder.min(field); - case AVG: - return context.relBuilder.avg(distinct, null, field); - case COUNT: - return context.relBuilder.count( - distinct, null, field == null ? ImmutableList.of() : ImmutableList.of(field)); - case SUM: - return context.relBuilder.sum(distinct, null, field); - // case MEAN: - // throw new UnsupportedOperationException("MEAN is not supported in PPL"); - // case STDDEV: - // return context.relBuilder.aggregateCall(SqlStdOperatorTable.STDDEV, - // field); - case VARSAMP: - return context.relBuilder.aggregateCall(VAR_SAMP_NULLABLE, field); - case VARPOP: - return context.relBuilder.aggregateCall(VAR_POP_NULLABLE, field); - case STDDEV_POP: - return context.relBuilder.aggregateCall(STDDEV_POP_NULLABLE, field); - case STDDEV_SAMP: - return context.relBuilder.aggregateCall(STDDEV_SAMP_NULLABLE, field); - // case PERCENTILE_APPROX: - // return - // context.relBuilder.aggregateCall(SqlStdOperatorTable.PERCENTILE_CONT, field); - case TAKE: - return TransferUserDefinedAggFunction( - TakeAggFunction.class, - "TAKE", - UserDefinedFunctionUtils.getReturnTypeInferenceForArray(), - List.of(field), - argList, - context.relBuilder); - case PERCENTILE_APPROX: - List newArgList = new ArrayList<>(argList); - newArgList.add(context.rexBuilder.makeFlag(field.getType().getSqlTypeName())); - return TransferUserDefinedAggFunction( - PercentileApproxFunction.class, - "percentile_approx", - ReturnTypes.ARG0_FORCE_NULLABLE, - List.of(field), - newArgList, - context.relBuilder); - default: - throw new UnsupportedOperationException( - "Unexpected aggregation: " + functionName.getName().getFunctionName()); - } + return PPLFuncImpTable.INSTANCE.resolveAgg(functionName, distinct, field, argList, context); } /** Get all uniq input references from a RexNode. 
*/ diff --git a/core/src/main/java/org/opensearch/sql/calcite/utils/UserDefinedFunctionUtils.java b/core/src/main/java/org/opensearch/sql/calcite/utils/UserDefinedFunctionUtils.java index c85ff86271e..ab9fe651192 100644 --- a/core/src/main/java/org/opensearch/sql/calcite/utils/UserDefinedFunctionUtils.java +++ b/core/src/main/java/org/opensearch/sql/calcite/utils/UserDefinedFunctionUtils.java @@ -76,7 +76,7 @@ public static RelBuilder.AggCall TransferUserDefinedAggFunction( return relBuilder.aggregateCall(sqlUDAF, addArgList); } - static SqlReturnTypeInference getReturnTypeInferenceForArray() { + public static SqlReturnTypeInference getReturnTypeInferenceForArray() { return opBinding -> { RelDataTypeFactory typeFactory = opBinding.getTypeFactory(); diff --git a/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java b/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java index aa99baee61b..58f63d2419c 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java @@ -195,6 +195,7 @@ public enum BuiltinFunctionName { TAKE(FunctionName.of("take")), // t-digest percentile which is used in OpenSearch core by default. 
PERCENTILE_APPROX(FunctionName.of("percentile_approx")), + DISTINCT_COUNT_APPROX(FunctionName.of("distinct_count_approx")), // Not always an aggregation query NESTED(FunctionName.of("nested")), @@ -336,6 +337,7 @@ public enum BuiltinFunctionName { .put("take", BuiltinFunctionName.TAKE) .put("percentile", BuiltinFunctionName.PERCENTILE_APPROX) .put("percentile_approx", BuiltinFunctionName.PERCENTILE_APPROX) + .put("distinct_count_approx", BuiltinFunctionName.DISTINCT_COUNT_APPROX) .build(); private static final Map WINDOW_FUNC_MAPPING = diff --git a/core/src/main/java/org/opensearch/sql/expression/function/PPLFuncImpTable.java b/core/src/main/java/org/opensearch/sql/expression/function/PPLFuncImpTable.java index 5a66bd8d58d..fc830a9e34a 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/PPLFuncImpTable.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/PPLFuncImpTable.java @@ -7,7 +7,12 @@ import static org.apache.calcite.sql.SqlJsonConstructorNullClause.NULL_ON_NULL; import static org.apache.calcite.sql.type.SqlTypeFamily.IGNORE; +import static org.opensearch.sql.calcite.utils.CalciteToolsHelper.STDDEV_POP_NULLABLE; +import static org.opensearch.sql.calcite.utils.CalciteToolsHelper.STDDEV_SAMP_NULLABLE; +import static org.opensearch.sql.calcite.utils.CalciteToolsHelper.VAR_POP_NULLABLE; +import static org.opensearch.sql.calcite.utils.CalciteToolsHelper.VAR_SAMP_NULLABLE; import static org.opensearch.sql.calcite.utils.OpenSearchTypeFactory.getLegacyTypeName; +import static org.opensearch.sql.calcite.utils.UserDefinedFunctionUtils.TransferUserDefinedAggFunction; import static org.opensearch.sql.expression.function.BuiltinFunctionName.ABS; import static org.opensearch.sql.expression.function.BuiltinFunctionName.ACOS; import static org.opensearch.sql.expression.function.BuiltinFunctionName.ADD; @@ -18,6 +23,7 @@ import static org.opensearch.sql.expression.function.BuiltinFunctionName.ASIN; import static 
org.opensearch.sql.expression.function.BuiltinFunctionName.ATAN; import static org.opensearch.sql.expression.function.BuiltinFunctionName.ATAN2; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.AVG; import static org.opensearch.sql.expression.function.BuiltinFunctionName.CBRT; import static org.opensearch.sql.expression.function.BuiltinFunctionName.CEIL; import static org.opensearch.sql.expression.function.BuiltinFunctionName.CEILING; @@ -29,6 +35,7 @@ import static org.opensearch.sql.expression.function.BuiltinFunctionName.CONVERT_TZ; import static org.opensearch.sql.expression.function.BuiltinFunctionName.COS; import static org.opensearch.sql.expression.function.BuiltinFunctionName.COT; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.COUNT; import static org.opensearch.sql.expression.function.BuiltinFunctionName.CRC32; import static org.opensearch.sql.expression.function.BuiltinFunctionName.CURDATE; import static org.opensearch.sql.expression.function.BuiltinFunctionName.CURRENT_DATE; @@ -102,8 +109,10 @@ import static org.opensearch.sql.expression.function.BuiltinFunctionName.LTRIM; import static org.opensearch.sql.expression.function.BuiltinFunctionName.MAKEDATE; import static org.opensearch.sql.expression.function.BuiltinFunctionName.MAKETIME; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.MAX; import static org.opensearch.sql.expression.function.BuiltinFunctionName.MD5; import static org.opensearch.sql.expression.function.BuiltinFunctionName.MICROSECOND; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.MIN; import static org.opensearch.sql.expression.function.BuiltinFunctionName.MINUTE; import static org.opensearch.sql.expression.function.BuiltinFunctionName.MINUTE_OF_DAY; import static org.opensearch.sql.expression.function.BuiltinFunctionName.MINUTE_OF_HOUR; @@ -119,6 +128,7 @@ import static org.opensearch.sql.expression.function.BuiltinFunctionName.NOW; 
import static org.opensearch.sql.expression.function.BuiltinFunctionName.NULLIF; import static org.opensearch.sql.expression.function.BuiltinFunctionName.OR; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.PERCENTILE_APPROX; import static org.opensearch.sql.expression.function.BuiltinFunctionName.PERIOD_ADD; import static org.opensearch.sql.expression.function.BuiltinFunctionName.PERIOD_DIFF; import static org.opensearch.sql.expression.function.BuiltinFunctionName.PI; @@ -143,6 +153,8 @@ import static org.opensearch.sql.expression.function.BuiltinFunctionName.SIN; import static org.opensearch.sql.expression.function.BuiltinFunctionName.SPAN; import static org.opensearch.sql.expression.function.BuiltinFunctionName.SQRT; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.STDDEV_POP; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.STDDEV_SAMP; import static org.opensearch.sql.expression.function.BuiltinFunctionName.STRCMP; import static org.opensearch.sql.expression.function.BuiltinFunctionName.STR_TO_DATE; import static org.opensearch.sql.expression.function.BuiltinFunctionName.SUBDATE; @@ -150,7 +162,9 @@ import static org.opensearch.sql.expression.function.BuiltinFunctionName.SUBSTRING; import static org.opensearch.sql.expression.function.BuiltinFunctionName.SUBTIME; import static org.opensearch.sql.expression.function.BuiltinFunctionName.SUBTRACT; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.SUM; import static org.opensearch.sql.expression.function.BuiltinFunctionName.SYSDATE; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.TAKE; import static org.opensearch.sql.expression.function.BuiltinFunctionName.TIME; import static org.opensearch.sql.expression.function.BuiltinFunctionName.TIMEDIFF; import static org.opensearch.sql.expression.function.BuiltinFunctionName.TIMESTAMP; @@ -168,6 +182,8 @@ import static 
org.opensearch.sql.expression.function.BuiltinFunctionName.UTC_DATE; import static org.opensearch.sql.expression.function.BuiltinFunctionName.UTC_TIME; import static org.opensearch.sql.expression.function.BuiltinFunctionName.UTC_TIMESTAMP; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.VARPOP; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.VARSAMP; import static org.opensearch.sql.expression.function.BuiltinFunctionName.WEEK; import static org.opensearch.sql.expression.function.BuiltinFunctionName.WEEKDAY; import static org.opensearch.sql.expression.function.BuiltinFunctionName.WEEKOFYEAR; @@ -176,6 +192,7 @@ import static org.opensearch.sql.expression.function.BuiltinFunctionName.YEAR; import static org.opensearch.sql.expression.function.BuiltinFunctionName.YEARWEEK; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import java.math.BigDecimal; import java.util.ArrayList; @@ -186,6 +203,7 @@ import java.util.Objects; import java.util.Optional; import java.util.StringJoiner; +import java.util.concurrent.ConcurrentHashMap; import java.util.function.BiFunction; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -199,22 +217,35 @@ import org.apache.calcite.sql.type.CompositeOperandTypeChecker; import org.apache.calcite.sql.type.ImplicitCastOperandTypeChecker; import org.apache.calcite.sql.type.OperandTypes; +import org.apache.calcite.sql.type.ReturnTypes; import org.apache.calcite.sql.type.SameOperandTypeChecker; import org.apache.calcite.sql.type.SqlOperandTypeChecker; import org.apache.calcite.sql.type.SqlTypeFamily; import org.apache.calcite.sql.type.SqlTypeName; import org.apache.calcite.sql.validate.SqlUserDefinedFunction; +import org.apache.calcite.tools.RelBuilder; import org.apache.commons.lang3.function.TriFunction; import org.apache.commons.lang3.tuple.Pair; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; 
+import org.opensearch.sql.calcite.CalcitePlanContext; +import org.opensearch.sql.calcite.udf.udaf.PercentileApproxFunction; +import org.opensearch.sql.calcite.udf.udaf.TakeAggFunction; import org.opensearch.sql.calcite.utils.OpenSearchTypeFactory; +import org.opensearch.sql.calcite.utils.UserDefinedFunctionUtils; import org.opensearch.sql.exception.ExpressionEvaluationException; import org.opensearch.sql.executor.QueryType; public class PPLFuncImpTable { private static final Logger logger = LogManager.getLogger(PPLFuncImpTable.class); + /** A lambda function interface which could apply parameters to get AggCall. */ + @FunctionalInterface + public interface AggHandler { + RelBuilder.AggCall apply( + boolean distinct, RexNode field, List argList, CalcitePlanContext context); + } + public interface FunctionImp { RexNode resolve(RexBuilder builder, RexNode... args); @@ -271,7 +302,9 @@ default PPLTypeChecker getTypeChecker() { static { final Builder builder = new Builder(); builder.populate(); - INSTANCE = new PPLFuncImpTable(builder); + final AggBuilder aggBuilder = new AggBuilder(); + aggBuilder.populate(); + INSTANCE = new PPLFuncImpTable(builder, aggBuilder); } /** @@ -290,12 +323,32 @@ default PPLTypeChecker getTypeChecker() { private final Map>> externalFunctionRegistry; - private PPLFuncImpTable(Builder builder) { + /** + * The registry for built-in agg functions. Agg Functions defined by the PPL specification, whose + * implementations are independent of any specific data storage, should be registered here + * internally. + */ + private final ImmutableMap aggFunctionRegistry; + + /** + * The external agg function registry. Agg Functions whose implementations depend on a specific + * data engine should be registered here. This reduces coupling between the core module and + * particular storage backends. 
+ */ + private final Map aggExternalFunctionRegistry; + + private PPLFuncImpTable(Builder builder, AggBuilder aggBuilder) { final ImmutableMap.Builder>> mapBuilder = ImmutableMap.builder(); builder.map.forEach((k, v) -> mapBuilder.put(k, List.copyOf(v))); this.functionRegistry = ImmutableMap.copyOf(mapBuilder.build()); - this.externalFunctionRegistry = new HashMap<>(); + this.externalFunctionRegistry = new ConcurrentHashMap<>(); + + final ImmutableMap.Builder aggMapBuilder = + ImmutableMap.builder(); + aggBuilder.map.forEach(aggMapBuilder::put); + this.aggFunctionRegistry = ImmutableMap.copyOf(aggMapBuilder.build()); + this.aggExternalFunctionRegistry = new ConcurrentHashMap<>(); } /** @@ -307,12 +360,41 @@ private PPLFuncImpTable(Builder builder) { public void registerExternalFunction(BuiltinFunctionName functionName, FunctionImp functionImp) { CalciteFuncSignature signature = new CalciteFuncSignature(functionName.getName(), functionImp.getTypeChecker()); - if (externalFunctionRegistry.containsKey(functionName)) { - externalFunctionRegistry.get(functionName).add(Pair.of(signature, functionImp)); - } else { - externalFunctionRegistry.put( - functionName, new ArrayList<>(List.of(Pair.of(signature, functionImp)))); + externalFunctionRegistry.compute( + functionName, + (name, existingList) -> { + List> list = + existingList == null ? new ArrayList<>() : new ArrayList<>(existingList); + list.add(Pair.of(signature, functionImp)); + return list; + }); + } + + /** + * Register a function implementation from external services dynamically. 
+ * + * @param functionName the name of the function, has to be defined in BuiltinFunctionName + * @param functionImp the implementation of the agg function + */ + public void registerExternalAggFunction( + BuiltinFunctionName functionName, AggHandler functionImp) { + aggExternalFunctionRegistry.put(functionName, functionImp); + } + + public RelBuilder.AggCall resolveAgg( + BuiltinFunctionName functionName, + boolean distinct, + RexNode field, + List argList, + CalcitePlanContext context) { + AggHandler handler = aggExternalFunctionRegistry.get(functionName); + if (handler == null) { + handler = aggFunctionRegistry.get(functionName); + } + if (handler == null) { + throw new IllegalStateException(String.format("Cannot resolve function: %s", functionName)); } + return handler.apply(distinct, field, argList, context); } public RexNode resolve(final RexBuilder builder, final String functionName, RexNode... args) { @@ -864,4 +946,70 @@ public PPLTypeChecker getTypeChecker() { return PPLTypeChecker.family(booleanFamily, booleanFamily); } } + + private static class AggBuilder { + private final Map map = new HashMap<>(); + + void register(BuiltinFunctionName functionName, AggHandler aggHandler) { + map.put(functionName, aggHandler); + } + + void populate() { + register(MAX, (distinct, field, argList, ctx) -> ctx.relBuilder.max(field)); + register(MIN, (distinct, field, argList, ctx) -> ctx.relBuilder.min(field)); + + register(AVG, (distinct, field, argList, ctx) -> ctx.relBuilder.avg(distinct, null, field)); + + register( + COUNT, + (distinct, field, argList, ctx) -> + ctx.relBuilder.count( + distinct, null, field == null ? 
DISTINCT_COUNT_APPROX
---------------------

Description
>>>>>>>>>>>

Version: 3.1.0

Usage: DISTINCT_COUNT_APPROX(expr). Returns the approximate number of distinct values of expr, estimated with the HyperLogLog++ algorithm.

Example::

    PPL> source=accounts | stats distinct_count_approx(gender);
    fetched rows / total rows = 1/1
    +-------------------------------+
    | distinct_count_approx(gender) |
    |-------------------------------|
    | 2                             |
    +-------------------------------+
"F"), rows(4, "M")); + } + + @Test + public void testCountDistinctWithAlias() { + JSONObject actual = + executeQuery( + String.format( + "source=%s | stats distinct_count(state) as dc by gender", TEST_INDEX_BANK)); + verifySchema(actual, schema("gender", "string"), schema("dc", "long")); + verifyDataRows(actual, rows(3, "F"), rows(4, "M")); } @Test diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/executor/OpenSearchExecutionEngine.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/executor/OpenSearchExecutionEngine.java index 6183aa109bf..9eaa0d777e2 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/executor/OpenSearchExecutionEngine.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/executor/OpenSearchExecutionEngine.java @@ -6,6 +6,8 @@ package org.opensearch.sql.opensearch.executor; import static org.opensearch.sql.calcite.utils.OpenSearchTypeFactory.convertRelDataTypeToExprType; +import static org.opensearch.sql.calcite.utils.UserDefinedFunctionUtils.TransferUserDefinedAggFunction; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.DISTINCT_COUNT_APPROX; import java.security.AccessController; import java.security.PrivilegedAction; @@ -25,6 +27,7 @@ import org.apache.calcite.rel.type.RelDataTypeField; import org.apache.calcite.runtime.Hook; import org.apache.calcite.sql.SqlExplainLevel; +import org.apache.calcite.sql.type.ReturnTypes; import org.opensearch.sql.ast.statement.Explain.ExplainFormat; import org.opensearch.sql.calcite.CalcitePlanContext; import org.opensearch.sql.calcite.utils.CalciteToolsHelper.OpenSearchRelRunners; @@ -41,6 +44,7 @@ import org.opensearch.sql.expression.function.PPLFuncImpTable; import org.opensearch.sql.opensearch.client.OpenSearchClient; import org.opensearch.sql.opensearch.executor.protector.ExecutionProtector; +import org.opensearch.sql.opensearch.functions.DistinctCountApproxAggFunction; import 
org.opensearch.sql.opensearch.functions.GeoIpFunction; import org.opensearch.sql.opensearch.util.JdbcOpenSearchDataTypeConvertor; import org.opensearch.sql.planner.physical.PhysicalPlan; @@ -246,5 +250,16 @@ private void registerOpenSearchFunctions() { (builder, args) -> builder.makeCall(new GeoIpFunction(client.getNodeClient()).toUDF("GEOIP"), args); PPLFuncImpTable.INSTANCE.registerExternalFunction(BuiltinFunctionName.GEOIP, geoIpImpl); + + PPLFuncImpTable.INSTANCE.registerExternalAggFunction( + DISTINCT_COUNT_APPROX, + (distinct, field, argList, ctx) -> + TransferUserDefinedAggFunction( + DistinctCountApproxAggFunction.class, + "APPROX_DISTINCT_COUNT", + ReturnTypes.BIGINT_FORCE_NULLABLE, + List.of(field), + argList, + ctx.relBuilder)); } } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/functions/DistinctCountApproxAggFunction.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/functions/DistinctCountApproxAggFunction.java new file mode 100644 index 00000000000..0f822062a9e --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/functions/DistinctCountApproxAggFunction.java @@ -0,0 +1,85 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.functions; + +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.charset.StandardCharsets; +import org.opensearch.common.hash.MurmurHash3; +import org.opensearch.common.util.BigArrays; +import org.opensearch.search.aggregations.metrics.HyperLogLogPlusPlus; +import org.opensearch.sql.calcite.udf.UserDefinedAggFunction; + +/** + * The Function depends on Opensearch Core's HLL++ algorithm to implement an approximate distinct + * count + */ +public class DistinctCountApproxAggFunction + implements UserDefinedAggFunction { + + @Override + public DistinctCountApproxAggFunction.HLLAccumulator init() { + return new DistinctCountApproxAggFunction.HLLAccumulator(); + } + + @Override + 
public Object result(DistinctCountApproxAggFunction.HLLAccumulator accumulator) { + return accumulator.value(); + } + + @Override + public DistinctCountApproxAggFunction.HLLAccumulator add( + DistinctCountApproxAggFunction.HLLAccumulator acc, Object... values) { + for (Object value : values) { + if (value != null) { + acc.add(value); + } + } + return acc; + } + + public static class HLLAccumulator implements UserDefinedAggFunction.Accumulator { + private final HyperLogLogPlusPlus hll; + + public HLLAccumulator() { + this.hll = + new HyperLogLogPlusPlus( + HyperLogLogPlusPlus.DEFAULT_PRECISION, BigArrays.NON_RECYCLING_INSTANCE, 1); + } + + public void add(Object value) { + hll.collect(0, hash(value)); + } + + @Override + public Object value(Object... args) { + return hll.cardinality(0); + } + } + + private static long hash(Object data) { + MurmurHash3.Hash128 hash = new MurmurHash3.Hash128(); + if (data == null) { + return 0L; + } + + byte[] bytes; + + if (data instanceof byte[]) { + bytes = (byte[]) data; + } else if (data instanceof String) { + bytes = ((String) data).getBytes(StandardCharsets.UTF_8); + } else if (data instanceof Number) { + long value = ((Number) data).longValue(); + bytes = ByteBuffer.allocate(Long.BYTES).order(ByteOrder.LITTLE_ENDIAN).putLong(value).array(); + } else { + bytes = data.toString().getBytes(StandardCharsets.UTF_8); + } + + MurmurHash3.hash128(bytes, 0, bytes.length, 0, hash); + return hash.h1 ^ hash.h2; + } +} diff --git a/ppl/src/main/antlr/OpenSearchPPLLexer.g4 b/ppl/src/main/antlr/OpenSearchPPLLexer.g4 index 9bfc33dd869..4b52b70464c 100644 --- a/ppl/src/main/antlr/OpenSearchPPLLexer.g4 +++ b/ppl/src/main/antlr/OpenSearchPPLLexer.g4 @@ -207,6 +207,7 @@ BIT_XOR_OP: '^'; AVG: 'AVG'; COUNT: 'COUNT'; DISTINCT_COUNT: 'DISTINCT_COUNT'; +DISTINCT_COUNT_APPROX: 'DISTINCT_COUNT_APPROX'; ESTDC: 'ESTDC'; ESTDC_ERROR: 'ESTDC_ERROR'; MAX: 'MAX'; diff --git a/ppl/src/main/antlr/OpenSearchPPLParser.g4 b/ppl/src/main/antlr/OpenSearchPPLParser.g4 
index e71ba1d8d21..9d355667f24 100644 --- a/ppl/src/main/antlr/OpenSearchPPLParser.g4 +++ b/ppl/src/main/antlr/OpenSearchPPLParser.g4 @@ -393,7 +393,7 @@ statsAggTerm statsFunction : statsFunctionName LT_PRTHS valueExpression RT_PRTHS # statsFunctionCall | COUNT LT_PRTHS RT_PRTHS # countAllFunctionCall - | (DISTINCT_COUNT | DC) LT_PRTHS valueExpression RT_PRTHS # distinctCountFunctionCall + | (DISTINCT_COUNT | DC | DISTINCT_COUNT_APPROX) LT_PRTHS valueExpression RT_PRTHS # distinctCountFunctionCall | takeAggFunction # takeAggFunctionCall | percentileApproxFunction # percentileApproxFunctionCall ; @@ -1133,6 +1133,7 @@ keywordsCanBeId | statsFunctionName | windowFunctionName | DISTINCT_COUNT + | DISTINCT_COUNT_APPROX | ESTDC | ESTDC_ERROR | MEAN diff --git a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java index a277cf7004c..6bc2c55c5b8 100644 --- a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java +++ b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java @@ -203,7 +203,8 @@ public UnresolvedExpression visitCountAllFunctionCall(CountAllFunctionCallContex @Override public UnresolvedExpression visitDistinctCountFunctionCall(DistinctCountFunctionCallContext ctx) { - return new AggregateFunction("count", visit(ctx.valueExpression()), true); + String funcName = ctx.DISTINCT_COUNT_APPROX() != null ? "distinct_count_approx" : "count"; + return new AggregateFunction(funcName, visit(ctx.valueExpression()), true); } @Override