From b39e7b6d12828fed21b553d06455e81b42645bc7 Mon Sep 17 00:00:00 2001 From: Peng Huo Date: Tue, 8 Jun 2021 14:39:27 -0700 Subject: [PATCH 01/23] Support construct AggregationResponseParser during Aggregator build stage (#108) * Support construct AggregationResponseParser during Aggregator build stage * modify the doc Signed-off-by: penghuo --- .../value/OpenSearchExprValueFactory.java | 16 +- .../OpenSearchAggregationResponseParser.java | 114 ----------- .../response/OpenSearchResponse.java | 2 +- .../agg/CompositeAggregationParser.java | 51 +++++ .../opensearch/response/agg/FilterParser.java | 38 ++++ .../opensearch/response/agg/MetricParser.java | 36 ++++ .../response/agg/MetricParserHelper.java | 56 +++++ .../agg/NoBucketAggregationParser.java | 41 ++++ .../OpenSearchAggregationResponseParser.java | 31 +++ .../response/agg/SingleValueParser.java | 39 ++++ .../opensearch/response/agg/StatsParser.java | 41 ++++ .../sql/opensearch/response/agg/Utils.java | 27 +++ .../opensearch/storage/OpenSearchIndex.java | 4 +- .../storage/OpenSearchIndexScan.java | 10 +- .../aggregation/AggregationQueryBuilder.java | 47 +++-- .../dsl/MetricAggregationBuilder.java | 93 ++++++--- .../response/AggregationResponseUtils.java | 4 + ...enSearchAggregationResponseParserTest.java | 192 ++++++++++++------ .../response/OpenSearchResponseTest.java | 42 ++-- .../AggregationQueryBuilderTest.java | 17 +- .../dsl/MetricAggregationBuilderTest.java | 2 +- 21 files changed, 650 insertions(+), 253 deletions(-) delete mode 100644 opensearch/src/main/java/org/opensearch/sql/opensearch/response/OpenSearchAggregationResponseParser.java create mode 100644 opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/CompositeAggregationParser.java create mode 100644 opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/FilterParser.java create mode 100644 opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/MetricParser.java create mode 100644 
opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/MetricParserHelper.java create mode 100644 opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/NoBucketAggregationParser.java create mode 100644 opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/OpenSearchAggregationResponseParser.java create mode 100644 opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/SingleValueParser.java create mode 100644 opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/StatsParser.java create mode 100644 opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/Utils.java diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/data/value/OpenSearchExprValueFactory.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/data/value/OpenSearchExprValueFactory.java index 313347aec1a..001363b4767 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/data/value/OpenSearchExprValueFactory.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/data/value/OpenSearchExprValueFactory.java @@ -63,7 +63,7 @@ import java.util.List; import java.util.Map; import java.util.function.Function; -import lombok.AllArgsConstructor; +import lombok.Getter; import lombok.Setter; import org.opensearch.common.time.DateFormatters; import org.opensearch.sql.data.model.ExprBooleanValue; @@ -86,11 +86,11 @@ import org.opensearch.sql.opensearch.data.utils.Content; import org.opensearch.sql.opensearch.data.utils.ObjectContent; import org.opensearch.sql.opensearch.data.utils.OpenSearchJsonContent; +import org.opensearch.sql.opensearch.response.agg.OpenSearchAggregationResponseParser; /** * Construct ExprValue from OpenSearch response. */ -@AllArgsConstructor public class OpenSearchExprValueFactory { /** * The Mapping of Field and ExprType. 
@@ -98,6 +98,10 @@ public class OpenSearchExprValueFactory { @Setter private Map typeMapping; + @Getter + @Setter + private OpenSearchAggregationResponseParser parser; + private static final DateTimeFormatter DATE_TIME_FORMATTER = new DateTimeFormatterBuilder() .appendOptional(SQL_LITERAL_DATE_TIME_FORMAT) @@ -131,6 +135,14 @@ public class OpenSearchExprValueFactory { .put(OPENSEARCH_BINARY, c -> new OpenSearchExprBinaryValue(c.stringValue())) .build(); + /** + * Constructor of OpenSearchExprValueFactory. + */ + public OpenSearchExprValueFactory( + Map typeMapping) { + this.typeMapping = typeMapping; + } + /** * The struct construction has the following assumption. 1. The field has OpenSearch Object * data type. https://www.elastic.co/guide/en/elasticsearch/reference/current/object.html 2. The diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/response/OpenSearchAggregationResponseParser.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/response/OpenSearchAggregationResponseParser.java deleted file mode 100644 index bb029cddb03..00000000000 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/response/OpenSearchAggregationResponseParser.java +++ /dev/null @@ -1,114 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - * - * Modifications Copyright OpenSearch Contributors. See - * GitHub history for details. - */ - -/* - * - * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * A copy of the License is located at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * or in the "license" file accompanying this file. 
This file is distributed - * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either - * express or implied. See the License for the specific language governing - * permissions and limitations under the License. - * - */ - -package org.opensearch.sql.opensearch.response; - -import com.google.common.annotations.VisibleForTesting; -import com.google.common.collect.ImmutableList; -import java.util.Collections; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import lombok.experimental.UtilityClass; -import org.opensearch.search.aggregations.Aggregation; -import org.opensearch.search.aggregations.Aggregations; -import org.opensearch.search.aggregations.bucket.composite.CompositeAggregation; -import org.opensearch.search.aggregations.bucket.filter.Filter; -import org.opensearch.search.aggregations.metrics.NumericMetricsAggregation; - -/** - * AggregationResponseParser. - */ -@UtilityClass -public class OpenSearchAggregationResponseParser { - - /** - * Parse Aggregations as a list of field and value map. - * - * @param aggregations aggregations - * @return a list of field and value map - */ - public static List> parse(Aggregations aggregations) { - List aggregationList = aggregations.asList(); - ImmutableList.Builder> builder = new ImmutableList.Builder<>(); - Map noBucketMap = new HashMap<>(); - - for (Aggregation aggregation : aggregationList) { - if (aggregation instanceof CompositeAggregation) { - for (CompositeAggregation.Bucket bucket : - ((CompositeAggregation) aggregation).getBuckets()) { - builder.add(parse(bucket)); - } - } else { - noBucketMap.putAll(parseInternal(aggregation)); - } - - } - // Todo, there is no better way to difference the with/without bucket from aggregations result. - return noBucketMap.isEmpty() ? 
builder.build() : Collections.singletonList(noBucketMap); - } - - private static Map parse(CompositeAggregation.Bucket bucket) { - Map resultMap = new HashMap<>(); - // The NodeClient return InternalComposite - - // build pair - resultMap.putAll(bucket.getKey()); - - // build pair - for (Aggregation aggregation : bucket.getAggregations()) { - resultMap.putAll(parseInternal(aggregation)); - } - - return resultMap; - } - - private static Map parseInternal(Aggregation aggregation) { - Map resultMap = new HashMap<>(); - if (aggregation instanceof NumericMetricsAggregation.SingleValue) { - resultMap.put( - aggregation.getName(), - handleNanValue(((NumericMetricsAggregation.SingleValue) aggregation).value())); - } else if (aggregation instanceof Filter) { - // parse sub-aggregations for FilterAggregation response - List aggList = ((Filter) aggregation).getAggregations().asList(); - aggList.forEach(internalAgg -> { - Map intermediateMap = parseInternal(internalAgg); - resultMap.put(internalAgg.getName(), intermediateMap.get(internalAgg.getName())); - }); - } else { - throw new IllegalStateException("unsupported aggregation type " + aggregation.getType()); - } - return resultMap; - } - - @VisibleForTesting - protected static Object handleNanValue(double value) { - return Double.isNaN(value) ? 
null : value; - } -} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/response/OpenSearchResponse.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/response/OpenSearchResponse.java index fc7421aec36..156490d93a4 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/response/OpenSearchResponse.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/response/OpenSearchResponse.java @@ -103,7 +103,7 @@ public boolean isAggregationResponse() { */ public Iterator iterator() { if (isAggregationResponse()) { - return OpenSearchAggregationResponseParser.parse(aggregations).stream().map(entry -> { + return exprValueFactory.getParser().parse(aggregations).stream().map(entry -> { ImmutableMap.Builder builder = new ImmutableMap.Builder<>(); for (Map.Entry value : entry.entrySet()) { builder.put(value.getKey(), exprValueFactory.construct(value.getKey(), value.getValue())); diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/CompositeAggregationParser.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/CompositeAggregationParser.java new file mode 100644 index 00000000000..00e8a5154c4 --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/CompositeAggregationParser.java @@ -0,0 +1,51 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. 
+ */ + +package org.opensearch.sql.opensearch.response.agg; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import org.opensearch.search.aggregations.Aggregations; +import org.opensearch.search.aggregations.bucket.composite.CompositeAggregation; + +/** + * Composite Aggregation Parser which include composite aggregation and metric parsers. + */ +public class CompositeAggregationParser implements OpenSearchAggregationResponseParser { + + private final MetricParserHelper metricsParser; + + public CompositeAggregationParser(MetricParser... metricParserList) { + metricsParser = new MetricParserHelper(Arrays.asList(metricParserList)); + } + + public CompositeAggregationParser(List metricParserList) { + metricsParser = new MetricParserHelper(metricParserList); + } + + @Override + public List> parse(Aggregations aggregations) { + return ((CompositeAggregation) aggregations.asList().get(0)) + .getBuckets().stream().map(this::parse).collect(Collectors.toList()); + } + + private Map parse(CompositeAggregation.Bucket bucket) { + Map resultMap = new HashMap<>(); + resultMap.putAll(bucket.getKey()); + resultMap.putAll(metricsParser.parse(bucket.getAggregations())); + return resultMap; + } +} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/FilterParser.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/FilterParser.java new file mode 100644 index 00000000000..cfcba82c183 --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/FilterParser.java @@ -0,0 +1,38 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * or in the "license" file accompanying this file. 
This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package org.opensearch.sql.opensearch.response.agg; + +import java.util.Map; +import lombok.Builder; +import lombok.Getter; +import org.opensearch.search.aggregations.Aggregation; +import org.opensearch.search.aggregations.bucket.filter.Filter; + +/** + * {@link Filter} Parser. + * The current use case is filter aggregation, e.g. avg(age) filter(balance>0). The filter parser + * do nothing and return the result from metricsParser. + */ +@Builder +public class FilterParser implements MetricParser { + + private final MetricParser metricsParser; + + @Getter private final String name; + + @Override + public Map parse(Aggregation aggregations) { + return metricsParser.parse(((Filter) aggregations).getAggregations().asList().get(0)); + } +} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/MetricParser.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/MetricParser.java new file mode 100644 index 00000000000..15f05e5b059 --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/MetricParser.java @@ -0,0 +1,36 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. 
+ */ + +package org.opensearch.sql.opensearch.response.agg; + +import java.util.Map; +import org.opensearch.search.aggregations.Aggregation; + +/** + * Metric Aggregation Parser. + */ +public interface MetricParser { + + /** + * Get the name of metric parser. + */ + String getName(); + + /** + * Parse the {@link Aggregation}. + * + * @param aggregation {@link Aggregation} + * @return the map between metric name and metric value. + */ + Map parse(Aggregation aggregation); +} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/MetricParserHelper.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/MetricParserHelper.java new file mode 100644 index 00000000000..54b9305f492 --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/MetricParserHelper.java @@ -0,0 +1,56 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package org.opensearch.sql.opensearch.response.agg; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import lombok.RequiredArgsConstructor; +import org.opensearch.search.aggregations.Aggregation; +import org.opensearch.search.aggregations.Aggregations; +import org.opensearch.sql.common.utils.StringUtils; + +/** + * Parse multiple metrics in one bucket. 
+ */ +@RequiredArgsConstructor +public class MetricParserHelper { + + private final Map metricParserMap; + + public MetricParserHelper(List metricParserList) { + metricParserMap = + metricParserList.stream().collect(Collectors.toMap(MetricParser::getName, m -> m)); + } + + /** + * Parse {@link Aggregations}. + * + * @param aggregations {@link Aggregations} + * @return the map between metric name and metric value. + */ + public Map parse(Aggregations aggregations) { + Map resultMap = new HashMap<>(); + for (Aggregation aggregation : aggregations) { + if (metricParserMap.containsKey(aggregation.getName())) { + resultMap.putAll(metricParserMap.get(aggregation.getName()).parse(aggregation)); + } else { + throw new RuntimeException(StringUtils.format("couldn't parse field %s in aggregation " + + "response", aggregation.getName())); + } + } + return resultMap; + } +} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/NoBucketAggregationParser.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/NoBucketAggregationParser.java new file mode 100644 index 00000000000..57560035232 --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/NoBucketAggregationParser.java @@ -0,0 +1,41 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. 
+ */ + +package org.opensearch.sql.opensearch.response.agg; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import org.opensearch.search.aggregations.Aggregations; + +/** + * No Bucket Aggregation Parser which include only metric parsers. + */ +public class NoBucketAggregationParser implements OpenSearchAggregationResponseParser { + + private final MetricParserHelper metricsParser; + + public NoBucketAggregationParser(MetricParser... metricParserList) { + metricsParser = new MetricParserHelper(Arrays.asList(metricParserList)); + } + + public NoBucketAggregationParser(List metricParserList) { + metricsParser = new MetricParserHelper(metricParserList); + } + + @Override + public List> parse(Aggregations aggregations) { + return Collections.singletonList(metricsParser.parse(aggregations)); + } +} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/OpenSearchAggregationResponseParser.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/OpenSearchAggregationResponseParser.java new file mode 100644 index 00000000000..3a19747ef3f --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/OpenSearchAggregationResponseParser.java @@ -0,0 +1,31 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. 
+ */ + +package org.opensearch.sql.opensearch.response.agg; + +import java.util.List; +import java.util.Map; +import org.opensearch.search.aggregations.Aggregations; + +/** + * OpenSearch Aggregation Response Parser. + */ +public interface OpenSearchAggregationResponseParser { + + /** + * Parse the OpenSearch Aggregation Response. + * @param aggregations Aggregations. + * @return aggregation result. + */ + List> parse(Aggregations aggregations); +} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/SingleValueParser.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/SingleValueParser.java new file mode 100644 index 00000000000..7536a246617 --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/SingleValueParser.java @@ -0,0 +1,39 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package org.opensearch.sql.opensearch.response.agg; + +import static org.opensearch.sql.opensearch.response.agg.Utils.handleNanValue; + +import java.util.Collections; +import java.util.Map; +import lombok.Getter; +import lombok.RequiredArgsConstructor; +import org.opensearch.search.aggregations.Aggregation; +import org.opensearch.search.aggregations.metrics.NumericMetricsAggregation; + +/** + * {@link NumericMetricsAggregation.SingleValue} metric parser. 
+ */ +@RequiredArgsConstructor +public class SingleValueParser implements MetricParser { + + @Getter private final String name; + + @Override + public Map parse(Aggregation agg) { + return Collections.singletonMap( + agg.getName(), + handleNanValue(((NumericMetricsAggregation.SingleValue) agg).value())); + } +} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/StatsParser.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/StatsParser.java new file mode 100644 index 00000000000..6cac2fbdc9a --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/StatsParser.java @@ -0,0 +1,41 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package org.opensearch.sql.opensearch.response.agg; + +import static org.opensearch.sql.opensearch.response.agg.Utils.handleNanValue; + +import java.util.Collections; +import java.util.Map; +import java.util.function.Function; +import lombok.Getter; +import lombok.RequiredArgsConstructor; +import org.opensearch.search.aggregations.Aggregation; +import org.opensearch.search.aggregations.metrics.ExtendedStats; + +/** + * {@link ExtendedStats} metric parser. 
+ */ +@RequiredArgsConstructor +public class StatsParser implements MetricParser { + + private final Function valueExtractor; + + @Getter private final String name; + + @Override + public Map parse(Aggregation agg) { + return Collections.singletonMap( + agg.getName(), handleNanValue(valueExtractor.apply((ExtendedStats) agg))); + } +} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/Utils.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/Utils.java new file mode 100644 index 00000000000..53fd66ceef7 --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/Utils.java @@ -0,0 +1,27 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package org.opensearch.sql.opensearch.response.agg; + +import lombok.experimental.UtilityClass; + +@UtilityClass +public class Utils { + /** + * Utils to handle Nan Value. + * @return null if is Nan value. + */ + public static Object handleNanValue(double value) { + return Double.isNaN(value) ? 
null : value; + } +} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchIndex.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchIndex.java index 74e966637fd..0198abe7a12 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchIndex.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchIndex.java @@ -32,6 +32,7 @@ import java.util.Map; import java.util.stream.Collectors; import lombok.RequiredArgsConstructor; +import org.apache.commons.lang3.tuple.Pair; import org.opensearch.index.query.QueryBuilder; import org.opensearch.search.aggregations.AggregationBuilder; import org.opensearch.sql.common.setting.Settings; @@ -43,6 +44,7 @@ import org.opensearch.sql.opensearch.planner.logical.OpenSearchLogicalIndexScan; import org.opensearch.sql.opensearch.planner.logical.OpenSearchLogicalPlanOptimizerFactory; import org.opensearch.sql.opensearch.request.system.OpenSearchDescribeIndexRequest; +import org.opensearch.sql.opensearch.response.agg.OpenSearchAggregationResponseParser; import org.opensearch.sql.opensearch.storage.script.aggregation.AggregationQueryBuilder; import org.opensearch.sql.opensearch.storage.script.filter.FilterQueryBuilder; import org.opensearch.sql.opensearch.storage.script.sort.SortQueryBuilder; @@ -163,7 +165,7 @@ public PhysicalPlan visitIndexAggregation(OpenSearchLogicalIndexAgg node, } AggregationQueryBuilder builder = new AggregationQueryBuilder(new DefaultExpressionSerializer()); - List aggregationBuilder = + Pair, OpenSearchAggregationResponseParser> aggregationBuilder = builder.buildAggregationBuilder(node.getAggregatorList(), node.getGroupByList(), node.getSortList()); context.pushDownAggregation(aggregationBuilder); diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchIndexScan.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchIndexScan.java index 
99b11c21a43..57980f23b90 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchIndexScan.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchIndexScan.java @@ -40,6 +40,7 @@ import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.ToString; +import org.apache.commons.lang3.tuple.Pair; import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryBuilders; @@ -55,6 +56,7 @@ import org.opensearch.sql.opensearch.request.OpenSearchQueryRequest; import org.opensearch.sql.opensearch.request.OpenSearchRequest; import org.opensearch.sql.opensearch.response.OpenSearchResponse; +import org.opensearch.sql.opensearch.response.agg.OpenSearchAggregationResponseParser; import org.opensearch.sql.storage.TableScanOperator; /** @@ -138,12 +140,14 @@ public void pushDown(QueryBuilder query) { /** * Push down aggregation to DSL request. - * @param aggregationBuilderList aggregation query. + * @param aggregationBuilder pair of aggregation query and aggregation parser. 
*/ - public void pushDownAggregation(List aggregationBuilderList) { + public void pushDownAggregation( + Pair, OpenSearchAggregationResponseParser> aggregationBuilder) { SearchSourceBuilder source = request.getSourceBuilder(); - aggregationBuilderList.forEach(aggregationBuilder -> source.aggregation(aggregationBuilder)); + aggregationBuilder.getLeft().forEach(builder -> source.aggregation(builder)); source.size(0); + request.getExprValueFactory().setParser(aggregationBuilder.getRight()); } /** diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/AggregationQueryBuilder.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/AggregationQueryBuilder.java index a89ba042ee4..403f99e593b 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/AggregationQueryBuilder.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/AggregationQueryBuilder.java @@ -42,6 +42,7 @@ import org.apache.commons.lang3.tuple.Pair; import org.opensearch.search.aggregations.AggregationBuilder; import org.opensearch.search.aggregations.AggregationBuilders; +import org.opensearch.search.aggregations.AggregatorFactories; import org.opensearch.search.sort.SortOrder; import org.opensearch.sql.ast.tree.Sort; import org.opensearch.sql.data.type.ExprType; @@ -50,6 +51,10 @@ import org.opensearch.sql.expression.NamedExpression; import org.opensearch.sql.expression.ReferenceExpression; import org.opensearch.sql.expression.aggregation.NamedAggregator; +import org.opensearch.sql.opensearch.response.agg.CompositeAggregationParser; +import org.opensearch.sql.opensearch.response.agg.MetricParser; +import org.opensearch.sql.opensearch.response.agg.NoBucketAggregationParser; +import org.opensearch.sql.opensearch.response.agg.OpenSearchAggregationResponseParser; import org.opensearch.sql.opensearch.storage.script.aggregation.dsl.BucketAggregationBuilder; import 
org.opensearch.sql.opensearch.storage.script.aggregation.dsl.MetricAggregationBuilder; import org.opensearch.sql.opensearch.storage.serialization.ExpressionSerializer; @@ -82,25 +87,35 @@ public AggregationQueryBuilder( this.metricBuilder = new MetricAggregationBuilder(serializer); } - /** - * Build AggregationBuilder. - */ - public List buildAggregationBuilder( - List namedAggregatorList, - List groupByList, - List> sortList) { + /** Build AggregationBuilder. */ + public Pair, OpenSearchAggregationResponseParser> + buildAggregationBuilder( + List namedAggregatorList, + List groupByList, + List> sortList) { + + final Pair> metrics = + metricBuilder.build(namedAggregatorList); + if (groupByList.isEmpty()) { // no bucket - return ImmutableList - .copyOf(metricBuilder.build(namedAggregatorList).getAggregatorFactories()); + return Pair.of( + ImmutableList.copyOf(metrics.getLeft().getAggregatorFactories()), + new NoBucketAggregationParser(metrics.getRight())); } else { - final GroupSortOrder groupSortOrder = new GroupSortOrder(sortList); - return Collections.singletonList(AggregationBuilders.composite("composite_buckets", - bucketBuilder - .build(groupByList.stream().sorted(groupSortOrder).map(expr -> Pair.of(expr, - groupSortOrder.apply(expr))).collect(Collectors.toList()))) - .subAggregations(metricBuilder.build(namedAggregatorList)) - .size(AGGREGATION_BUCKET_SIZE)); + GroupSortOrder groupSortOrder = new GroupSortOrder(sortList); + return Pair.of( + Collections.singletonList( + AggregationBuilders.composite( + "composite_buckets", + bucketBuilder.build( + groupByList.stream() + .sorted(groupSortOrder) + .map(expr -> Pair.of(expr, groupSortOrder.apply(expr))) + .collect(Collectors.toList()))) + .subAggregations(metrics.getLeft()) + .size(AGGREGATION_BUCKET_SIZE)), + new CompositeAggregationParser(metrics.getRight())); } } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java 
b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java index f3807ae662f..0dbfec02c1d 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java @@ -30,7 +30,9 @@ import static org.opensearch.sql.data.type.ExprCoreType.INTEGER; +import java.util.ArrayList; import java.util.List; +import org.apache.commons.lang3.tuple.Pair; import org.opensearch.search.aggregations.AggregationBuilder; import org.opensearch.search.aggregations.AggregationBuilders; import org.opensearch.search.aggregations.AggregatorFactories; @@ -41,20 +43,22 @@ import org.opensearch.sql.expression.LiteralExpression; import org.opensearch.sql.expression.ReferenceExpression; import org.opensearch.sql.expression.aggregation.NamedAggregator; +import org.opensearch.sql.opensearch.response.agg.FilterParser; +import org.opensearch.sql.opensearch.response.agg.MetricParser; +import org.opensearch.sql.opensearch.response.agg.SingleValueParser; import org.opensearch.sql.opensearch.storage.script.filter.FilterQueryBuilder; import org.opensearch.sql.opensearch.storage.serialization.ExpressionSerializer; /** - * Build the Metric Aggregation from {@link NamedAggregator}. + * Build the Metric Aggregation and List of {@link MetricParser} from {@link NamedAggregator}. 
*/ public class MetricAggregationBuilder - extends ExpressionNodeVisitor { + extends ExpressionNodeVisitor, Object> { private final AggregationBuilderHelper> helper; private final FilterQueryBuilder filterBuilder; - public MetricAggregationBuilder( - ExpressionSerializer serializer) { + public MetricAggregationBuilder(ExpressionSerializer serializer) { this.helper = new AggregationBuilderHelper<>(serializer); this.filterBuilder = new FilterQueryBuilder(serializer); } @@ -65,55 +69,89 @@ public MetricAggregationBuilder( * @param aggregatorList aggregator list * @return AggregatorFactories.Builder */ - public AggregatorFactories.Builder build(List aggregatorList) { + public Pair> build( + List aggregatorList) { AggregatorFactories.Builder builder = new AggregatorFactories.Builder(); + List metricParserList = new ArrayList<>(); for (NamedAggregator aggregator : aggregatorList) { - builder.addAggregator(aggregator.accept(this, null)); + Pair pair = aggregator.accept(this, null); + builder.addAggregator(pair.getLeft()); + metricParserList.add(pair.getRight()); } - return builder; + return Pair.of(builder, metricParserList); } @Override - public AggregationBuilder visitNamedAggregator(NamedAggregator node, - Object context) { + public Pair visitNamedAggregator( + NamedAggregator node, Object context) { Expression expression = node.getArguments().get(0); Expression condition = node.getDelegated().condition(); String name = node.getName(); switch (node.getFunctionName().getFunctionName()) { case "avg": - return make(AggregationBuilders.avg(name), expression, condition, name); + return make( + AggregationBuilders.avg(name), + expression, + condition, + name, + new SingleValueParser(name)); case "sum": - return make(AggregationBuilders.sum(name), expression, condition, name); + return make( + AggregationBuilders.sum(name), + expression, + condition, + name, + new SingleValueParser(name)); case "count": return make( - AggregationBuilders.count(name), 
replaceStarOrLiteral(expression), condition, name); + AggregationBuilders.count(name), + replaceStarOrLiteral(expression), + condition, + name, + new SingleValueParser(name)); case "min": - return make(AggregationBuilders.min(name), expression, condition, name); + return make( + AggregationBuilders.min(name), + expression, + condition, + name, + new SingleValueParser(name)); case "max": - return make(AggregationBuilders.max(name), expression, condition, name); + return make( + AggregationBuilders.max(name), + expression, + condition, + name, + new SingleValueParser(name)); default: throw new IllegalStateException( String.format("unsupported aggregator %s", node.getFunctionName().getFunctionName())); } } - private AggregationBuilder make(ValuesSourceAggregationBuilder builder, - Expression expression, Expression condition, String name) { + private Pair make( + ValuesSourceAggregationBuilder builder, + Expression expression, + Expression condition, + String name, + MetricParser parser) { ValuesSourceAggregationBuilder aggregationBuilder = helper.build(expression, builder::field, builder::script); if (condition != null) { - return makeFilterAggregation(aggregationBuilder, condition, name); + return Pair.of( + makeFilterAggregation(aggregationBuilder, condition, name), + FilterParser.builder().name(name).metricsParser(parser).build()); } - return aggregationBuilder; + return Pair.of(aggregationBuilder, parser); } /** - * Replace star or literal with OpenSearch metadata field "_index". Because: - * 1) Analyzer already converts * to string literal, literal check here can handle - * both COUNT(*) and COUNT(1). - * 2) Value count aggregation on _index counts all docs (after filter), therefore - * it has same semantics as COUNT(*) or COUNT(1). + * Replace star or literal with OpenSearch metadata field "_index". Because: 1) Analyzer already + * converts * to string literal, literal check here can handle both COUNT(*) and COUNT(1). 
2) + * Value count aggregation on _index counts all docs (after filter), therefore it has same + * semantics as COUNT(*) or COUNT(1). + * * @param countArg count function argument * @return Reference to _index if literal, otherwise return original argument expression */ @@ -126,16 +164,15 @@ private Expression replaceStarOrLiteral(Expression countArg) { /** * Make builder to build FilterAggregation for aggregations with filter in the bucket. + * * @param subAggBuilder AggregationBuilder instance which the filter is applied to. * @param condition Condition expression in the filter. * @param name Name of the FilterAggregation instance to build. * @return {@link FilterAggregationBuilder}. */ - private FilterAggregationBuilder makeFilterAggregation(AggregationBuilder subAggBuilder, - Expression condition, String name) { - return AggregationBuilders - .filter(name, filterBuilder.build(condition)) + private FilterAggregationBuilder makeFilterAggregation( + AggregationBuilder subAggBuilder, Expression condition, String name) { + return AggregationBuilders.filter(name, filterBuilder.build(condition)) .subAggregation(subAggBuilder); } - } diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/response/AggregationResponseUtils.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/response/AggregationResponseUtils.java index c8ef8306352..173b33575c2 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/response/AggregationResponseUtils.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/response/AggregationResponseUtils.java @@ -55,9 +55,11 @@ import org.opensearch.search.aggregations.bucket.terms.ParsedStringTerms; import org.opensearch.search.aggregations.bucket.terms.StringTerms; import org.opensearch.search.aggregations.metrics.AvgAggregationBuilder; +import org.opensearch.search.aggregations.metrics.ExtendedStatsAggregationBuilder; import org.opensearch.search.aggregations.metrics.MaxAggregationBuilder; import 
org.opensearch.search.aggregations.metrics.MinAggregationBuilder; import org.opensearch.search.aggregations.metrics.ParsedAvg; +import org.opensearch.search.aggregations.metrics.ParsedExtendedStats; import org.opensearch.search.aggregations.metrics.ParsedMax; import org.opensearch.search.aggregations.metrics.ParsedMin; import org.opensearch.search.aggregations.metrics.ParsedSum; @@ -74,6 +76,8 @@ public class AggregationResponseUtils { .put(MaxAggregationBuilder.NAME, (p, c) -> ParsedMax.fromXContent(p, (String) c)) .put(SumAggregationBuilder.NAME, (p, c) -> ParsedSum.fromXContent(p, (String) c)) .put(AvgAggregationBuilder.NAME, (p, c) -> ParsedAvg.fromXContent(p, (String) c)) + .put(ExtendedStatsAggregationBuilder.NAME, + (p, c) -> ParsedExtendedStats.fromXContent(p, (String) c)) .put(StringTerms.NAME, (p, c) -> ParsedStringTerms.fromXContent(p, (String) c)) .put(LongTerms.NAME, (p, c) -> ParsedLongTerms.fromXContent(p, (String) c)) .put(DoubleTerms.NAME, (p, c) -> ParsedDoubleTerms.fromXContent(p, (String) c)) diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/response/OpenSearchAggregationResponseParserTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/response/OpenSearchAggregationResponseParserTest.java index b49bec4d442..120d48b6010 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/response/OpenSearchAggregationResponseParserTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/response/OpenSearchAggregationResponseParserTest.java @@ -34,6 +34,8 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.opensearch.sql.opensearch.response.AggregationResponseUtils.fromJson; +import static org.opensearch.sql.opensearch.response.agg.Utils.handleNanValue; import com.google.common.collect.ImmutableMap; import java.util.List; @@ -41,6 +43,13 @@ import 
org.junit.jupiter.api.DisplayNameGeneration; import org.junit.jupiter.api.DisplayNameGenerator; import org.junit.jupiter.api.Test; +import org.opensearch.search.aggregations.metrics.ExtendedStats; +import org.opensearch.sql.opensearch.response.agg.CompositeAggregationParser; +import org.opensearch.sql.opensearch.response.agg.FilterParser; +import org.opensearch.sql.opensearch.response.agg.NoBucketAggregationParser; +import org.opensearch.sql.opensearch.response.agg.OpenSearchAggregationResponseParser; +import org.opensearch.sql.opensearch.response.agg.SingleValueParser; +import org.opensearch.sql.opensearch.response.agg.StatsParser; @DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) class OpenSearchAggregationResponseParserTest { @@ -55,7 +64,10 @@ void no_bucket_one_metric_should_pass() { + " \"value\": 40\n" + " }\n" + "}"; - assertThat(parse(response), contains(entry("max", 40d))); + NoBucketAggregationParser parser = new NoBucketAggregationParser( + new SingleValueParser("max") + ); + assertThat(parse(parser, response), contains(entry("max", 40d))); } /** @@ -71,7 +83,11 @@ void no_bucket_two_metric_should_pass() { + " \"value\": 20\n" + " }\n" + "}"; - assertThat(parse(response), + NoBucketAggregationParser parser = new NoBucketAggregationParser( + new SingleValueParser("max"), + new SingleValueParser("min") + ); + assertThat(parse(parser, response), contains(entry("max", 40d,"min", 20d))); } @@ -104,7 +120,10 @@ void one_bucket_one_metric_should_pass() { + " ]\n" + " }\n" + "}"; - assertThat(parse(response), + + OpenSearchAggregationResponseParser parser = new CompositeAggregationParser( + new SingleValueParser("avg")); + assertThat(parse(parser, response), containsInAnyOrder(ImmutableMap.of("type", "cost", "avg", 20d), ImmutableMap.of("type", "sale", "avg", 105d))); } @@ -139,7 +158,9 @@ void two_bucket_one_metric_should_pass() { + " ]\n" + " }\n" + "}"; - assertThat(parse(response), + OpenSearchAggregationResponseParser parser = new 
CompositeAggregationParser( + new SingleValueParser("avg")); + assertThat(parse(parser, response), containsInAnyOrder(ImmutableMap.of("type", "cost", "region", "us", "avg", 20d), ImmutableMap.of("type", "sale", "region", "uk", "avg", 130d))); } @@ -147,81 +168,132 @@ void two_bucket_one_metric_should_pass() { @Test void unsupported_aggregation_should_fail() { String response = "{\n" - + " \"date_histogram#max\": {\n" + + " \"date_histogram#date_histogram\": {\n" + " \"value\": 40\n" + " }\n" + "}"; - IllegalStateException exception = - assertThrows(IllegalStateException.class, () -> parse(response)); - assertEquals("unsupported aggregation type date_histogram", exception.getMessage()); + NoBucketAggregationParser parser = new NoBucketAggregationParser( + new SingleValueParser("max") + ); + RuntimeException exception = + assertThrows(RuntimeException.class, () -> parse(parser, response)); + assertEquals( + "couldn't parse field date_histogram in aggregation response", exception.getMessage()); } @Test void nan_value_should_return_null() { - assertNull(OpenSearchAggregationResponseParser.handleNanValue(Double.NaN)); + assertNull(handleNanValue(Double.NaN)); } - /** - * SELECT AVG(age) FILTER(WHERE age > 37) as filtered FROM accounts. 
- */ @Test void filter_aggregation_should_pass() { - String response = "{\n" - + " \"filter#filtered\" : {\n" - + " \"doc_count\" : 3,\n" - + " \"avg#filtered\" : {\n" - + " \"value\" : 37.0\n" - + " }\n" - + " }\n" - + " }"; - assertThat(parse(response), contains(entry("filtered", 37.0))); + String response = "{\n" + + " \"filter#filtered\" : {\n" + + " \"doc_count\" : 3,\n" + + " \"avg#filtered\" : {\n" + + " \"value\" : 37.0\n" + + " }\n" + + " }\n" + + " }"; + OpenSearchAggregationResponseParser parser = + new NoBucketAggregationParser( + FilterParser.builder() + .name("filtered") + .metricsParser(new SingleValueParser("filtered")) + .build()); + assertThat(parse(parser, response), contains(entry("filtered", 37.0))); } - /** - * SELECT AVG(age) FILTER(WHERE age > 37) as filtered FROM accounts GROUP BY gender. - */ @Test void filter_aggregation_group_by_should_pass() { - String response = "{\n" - + " \"composite#composite_buckets\":{\n" - + " \"after_key\":{\n" - + " \"gender\":\"m\"\n" - + " },\n" - + " \"buckets\":[\n" - + " {\n" - + " \"key\":{\n" - + " \"gender\":\"f\"\n" - + " },\n" - + " \"doc_count\":3,\n" - + " \"filter#filter\":{\n" - + " \"doc_count\":1,\n" - + " \"avg#avg\":{\n" - + " \"value\":39.0\n" - + " }\n" - + " }\n" - + " },\n" - + " {\n" - + " \"key\":{\n" - + " \"gender\":\"m\"\n" - + " },\n" - + " \"doc_count\":4,\n" - + " \"filter#filter\":{\n" - + " \"doc_count\":2,\n" - + " \"avg#avg\":{\n" - + " \"value\":36.0\n" - + " }\n" - + " }\n" - + " }\n" - + " ]\n" - + " }\n" - + "}"; - assertThat(parse(response), containsInAnyOrder( + String response = "{\n" + + " \"composite#composite_buckets\":{\n" + + " \"after_key\":{\n" + + " \"gender\":\"m\"\n" + + " },\n" + + " \"buckets\":[\n" + + " {\n" + + " \"key\":{\n" + + " \"gender\":\"f\"\n" + + " },\n" + + " \"doc_count\":3,\n" + + " \"filter#filter\":{\n" + + " \"doc_count\":1,\n" + + " \"avg#avg\":{\n" + + " \"value\":39.0\n" + + " }\n" + + " }\n" + + " },\n" + + " {\n" + + " \"key\":{\n" + + 
" \"gender\":\"m\"\n" + + " },\n" + + " \"doc_count\":4,\n" + + " \"filter#filter\":{\n" + + " \"doc_count\":2,\n" + + " \"avg#avg\":{\n" + + " \"value\":36.0\n" + + " }\n" + + " }\n" + + " }\n" + + " ]\n" + + " }\n" + + "}"; + OpenSearchAggregationResponseParser parser = new CompositeAggregationParser( + FilterParser.builder() + .name("filter") + .metricsParser(new SingleValueParser("avg")) + .build() + ); + assertThat(parse(parser, response), containsInAnyOrder( entry("gender", "f", "avg", 39.0), entry("gender", "m", "avg", 36.0))); } - public List> parse(String json) { - return OpenSearchAggregationResponseParser.parse(AggregationResponseUtils.fromJson(json)); + /** + * SELECT MAX(age) as max, STDDEV(age) as min FROM accounts. + */ + @Test + void no_bucket_max_and_extended_stats() { + String response = "{\n" + + " \"extended_stats#esField\": {\n" + + " \"count\": 2033,\n" + + " \"min\": 0,\n" + + " \"max\": 360,\n" + + " \"avg\": 45.47958681751107,\n" + + " \"sum\": 92460,\n" + + " \"sum_of_squares\": 22059450,\n" + + " \"variance\": 8782.295820390027,\n" + + " \"variance_population\": 8782.295820390027,\n" + + " \"variance_sampling\": 8786.61781636463,\n" + + " \"std_deviation\": 93.71390409320287,\n" + + " \"std_deviation_population\": 93.71390409320287,\n" + + " \"std_deviation_sampling\": 93.73696078049805,\n" + + " \"std_deviation_bounds\": {\n" + + " \"upper\": 232.9073950039168,\n" + + " \"lower\": -141.94822136889468,\n" + + " \"upper_population\": 232.9073950039168,\n" + + " \"lower_population\": -141.94822136889468,\n" + + " \"upper_sampling\": 232.95350837850717,\n" + + " \"lower_sampling\": -141.99433474348504\n" + + " }\n" + + " },\n" + + " \"max#maxField\": {\n" + + " \"value\": 360\n" + + " }\n" + + "}"; + + NoBucketAggregationParser parser = new NoBucketAggregationParser( + new SingleValueParser("maxField"), + new StatsParser(ExtendedStats::getStdDeviation, "esField") + ); + assertThat(parse(parser, response), + contains(entry("esField", 
93.71390409320287, "maxField", 360D))); + } + + public List> parse(OpenSearchAggregationResponseParser parser, String json) { + return parser.parse(fromJson(json)); } public Map entry(String name, Object value) { diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/response/OpenSearchResponseTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/response/OpenSearchResponseTest.java index 184312afa10..c9cde4f6349 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/response/OpenSearchResponseTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/response/OpenSearchResponseTest.java @@ -42,8 +42,6 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; -import org.mockito.MockedStatic; -import org.mockito.Mockito; import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.action.search.SearchResponse; import org.opensearch.search.SearchHit; @@ -53,6 +51,7 @@ import org.opensearch.sql.data.model.ExprTupleValue; import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.opensearch.data.value.OpenSearchExprValueFactory; +import org.opensearch.sql.opensearch.response.agg.OpenSearchAggregationResponseParser; @ExtendWith(MockitoExtension.class) class OpenSearchResponseTest { @@ -72,6 +71,9 @@ class OpenSearchResponseTest { @Mock private Aggregations aggregations; + @Mock + private OpenSearchAggregationResponseParser parser; + private ExprTupleValue exprTupleValue1 = ExprTupleValue.fromExprValueMap(ImmutableMap.of("id1", new ExprIntegerValue(1))); @@ -147,26 +149,24 @@ void response_isnot_aggregation_when_aggregation_is_empty() { @Test void aggregation_iterator() { - try ( - MockedStatic mockedStatic = Mockito - .mockStatic(OpenSearchAggregationResponseParser.class)) { - when(OpenSearchAggregationResponseParser.parse(any())) - .thenReturn(Arrays.asList(ImmutableMap.of("id1", 1), ImmutableMap.of("id2", 2))); - 
when(searchResponse.getAggregations()).thenReturn(aggregations); - when(factory.construct(anyString(), any())).thenReturn(new ExprIntegerValue(1)) - .thenReturn(new ExprIntegerValue(2)); - - int i = 0; - for (ExprValue hit : new OpenSearchResponse(searchResponse, factory)) { - if (i == 0) { - assertEquals(exprTupleValue1, hit); - } else if (i == 1) { - assertEquals(exprTupleValue2, hit); - } else { - fail("More search hits returned than expected"); - } - i++; + when(parser.parse(any())) + .thenReturn(Arrays.asList(ImmutableMap.of("id1", 1), ImmutableMap.of("id2", 2))); + when(searchResponse.getAggregations()).thenReturn(aggregations); + when(factory.getParser()).thenReturn(parser); + when(factory.construct(anyString(), any())) + .thenReturn(new ExprIntegerValue(1)) + .thenReturn(new ExprIntegerValue(2)); + + int i = 0; + for (ExprValue hit : new OpenSearchResponse(searchResponse, factory)) { + if (i == 0) { + assertEquals(exprTupleValue1, hit); + } else if (i == 1) { + assertEquals(exprTupleValue2, hit); + } else { + fail("More search hits returned than expected"); } + i++; } } } diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/aggregation/AggregationQueryBuilderTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/aggregation/AggregationQueryBuilderTest.java index 2242298bede..62643baad2e 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/aggregation/AggregationQueryBuilderTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/aggregation/AggregationQueryBuilderTest.java @@ -423,13 +423,18 @@ private String buildQuery(List namedAggregatorList, } @SneakyThrows - private String buildQuery(List namedAggregatorList, - List groupByList, - List> sortList) { + private String buildQuery( + List namedAggregatorList, + List groupByList, + List> sortList) { ObjectMapper objectMapper = new ObjectMapper(); - return objectMapper.readTree( - 
queryBuilder.buildAggregationBuilder(namedAggregatorList, groupByList, sortList).get(0) - .toString()) + return objectMapper + .readTree( + queryBuilder + .buildAggregationBuilder(namedAggregatorList, groupByList, sortList) + .getLeft() + .get(0) + .toString()) .toPrettyString(); } diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilderTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilderTest.java index b956a2f5a07..85b3bd5a65f 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilderTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilderTest.java @@ -211,7 +211,7 @@ void should_throw_exception_for_unsupported_exception() { private String buildQuery(List namedAggregatorList) { ObjectMapper objectMapper = new ObjectMapper(); return objectMapper.readTree( - aggregationBuilder.build(namedAggregatorList).toString()) + aggregationBuilder.build(namedAggregatorList).getLeft().toString()) .toPrettyString(); } } From 6f5350d7fe5b4e873a5392e124b3640302321804 Mon Sep 17 00:00:00 2001 From: Chloe Zhang Date: Tue, 8 Jun 2021 17:30:10 -0700 Subject: [PATCH 02/23] support distinct count aggregation Signed-off-by: chloe-zh --- .../sql/analysis/ExpressionAnalyzer.java | 1 + .../org/opensearch/sql/ast/dsl/AstDSL.java | 4 ++++ .../sql/ast/expression/AggregateFunction.java | 14 +++++++++++ .../org/opensearch/sql/expression/DSL.java | 20 ++++++++++++++++ .../aggregation/AggregationState.java | 3 +++ .../expression/aggregation/Aggregator.java | 18 +++++++++++++- .../expression/aggregation/AvgAggregator.java | 6 +++++ .../aggregation/CountAggregator.java | 16 ++++++++++++- .../expression/aggregation/MaxAggregator.java | 6 +++++ .../expression/aggregation/MinAggregator.java | 6 +++++ 
.../expression/aggregation/SumAggregator.java | 6 +++++ .../sql/analysis/ExpressionAnalyzerTest.java | 8 +++++++ .../aggregation/AggregationTest.java | 7 ++++++ .../aggregation/AvgAggregatorTest.java | 7 ++++++ .../aggregation/CountAggregatorTest.java | 7 ++++++ .../aggregation/MaxAggregatorTest.java | 7 ++++++ .../aggregation/MinAggregatorTest.java | 7 ++++++ .../aggregation/SumAggregatorTest.java | 7 ++++++ .../correctness/queries/aggregation.txt | 3 ++- .../dsl/MetricAggregationBuilder.java | 24 +++++++++++++++++++ .../dsl/MetricAggregationBuilderTest.java | 17 +++++++++++++ sql/src/main/antlr/OpenSearchSQLParser.g4 | 7 ++++-- .../sql/sql/parser/AstExpressionBuilder.java | 9 +++++++ .../sql/parser/AstAggregationBuilderTest.java | 14 +++++++++++ 24 files changed, 219 insertions(+), 5 deletions(-) diff --git a/core/src/main/java/org/opensearch/sql/analysis/ExpressionAnalyzer.java b/core/src/main/java/org/opensearch/sql/analysis/ExpressionAnalyzer.java index 0f207c03741..3cc1dc95278 100644 --- a/core/src/main/java/org/opensearch/sql/analysis/ExpressionAnalyzer.java +++ b/core/src/main/java/org/opensearch/sql/analysis/ExpressionAnalyzer.java @@ -160,6 +160,7 @@ public Expression visitAggregateFunction(AggregateFunction node, AnalysisContext Expression arg = node.getField().accept(this, context); Aggregator aggregator = (Aggregator) repository.compile( builtinFunctionName.get().getName(), Collections.singletonList(arg)); + aggregator.distinct(node.getDistinct()); if (node.getCondition() != null) { aggregator.condition(analyze(node.getCondition(), context)); } diff --git a/core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java b/core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java index 7400ae20e6f..be8f7095db5 100644 --- a/core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java +++ b/core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java @@ -214,6 +214,10 @@ public static UnresolvedExpression filteredAggregate( return new AggregateFunction(func, field, 
condition); } + public static UnresolvedExpression distinctAggregate(String func, UnresolvedExpression field) { + return new AggregateFunction(func, field, true); + } + public static Function function(String funcName, UnresolvedExpression... funcArgs) { return new Function(funcName, Arrays.asList(funcArgs)); } diff --git a/core/src/main/java/org/opensearch/sql/ast/expression/AggregateFunction.java b/core/src/main/java/org/opensearch/sql/ast/expression/AggregateFunction.java index 8753e35ed9f..d11fdca3ac0 100644 --- a/core/src/main/java/org/opensearch/sql/ast/expression/AggregateFunction.java +++ b/core/src/main/java/org/opensearch/sql/ast/expression/AggregateFunction.java @@ -46,6 +46,7 @@ public class AggregateFunction extends UnresolvedExpression { private final UnresolvedExpression field; private final List argList; private UnresolvedExpression condition; + private Boolean distinct = false; /** * Constructor. @@ -72,6 +73,19 @@ public AggregateFunction(String funcName, UnresolvedExpression field, this.condition = condition; } + /** + * Constructor. + * @param funcName function name. + * @param field {@link UnresolvedExpression}. + * @param distinct field is distinct. + */ + public AggregateFunction(String funcName, UnresolvedExpression field, Boolean distinct) { + this.funcName = funcName; + this.field = field; + this.argList = Collections.emptyList(); + this.distinct = distinct; + } + @Override public List getChild() { return Collections.singletonList(field); diff --git a/core/src/main/java/org/opensearch/sql/expression/DSL.java b/core/src/main/java/org/opensearch/sql/expression/DSL.java index 31050afc871..93f86ca1f82 100644 --- a/core/src/main/java/org/opensearch/sql/expression/DSL.java +++ b/core/src/main/java/org/opensearch/sql/expression/DSL.java @@ -492,14 +492,26 @@ public Aggregator avg(Expression... expressions) { return aggregate(BuiltinFunctionName.AVG, expressions); } + public Aggregator distinctAvg(Expression... 
expressions) { + return avg(expressions).distinct(true); + } + public Aggregator sum(Expression... expressions) { return aggregate(BuiltinFunctionName.SUM, expressions); } + public Aggregator distinctSum(Expression... expressions) { + return sum(expressions).distinct(true); + } + public Aggregator count(Expression... expressions) { return aggregate(BuiltinFunctionName.COUNT, expressions); } + public Aggregator distinctCount(Expression... expressions) { + return count(expressions).distinct(true); + } + public RankingWindowFunction rowNumber() { return (RankingWindowFunction) repository.compile( BuiltinFunctionName.ROW_NUMBER.getName(), Collections.emptyList()); @@ -519,10 +531,18 @@ public Aggregator min(Expression... expressions) { return aggregate(BuiltinFunctionName.MIN, expressions); } + public Aggregator distinctMin(Expression... expressions) { + return min(expressions).distinct(true); + } + public Aggregator max(Expression... expressions) { return aggregate(BuiltinFunctionName.MAX, expressions); } + public Aggregator distinctMax(Expression... expressions) { + return max(expressions).distinct(true); + } + private FunctionExpression function(BuiltinFunctionName functionName, Expression... 
expressions) { return (FunctionExpression) repository.compile( functionName.getName(), Arrays.asList(expressions)); diff --git a/core/src/main/java/org/opensearch/sql/expression/aggregation/AggregationState.java b/core/src/main/java/org/opensearch/sql/expression/aggregation/AggregationState.java index b1c29cb4a7a..ed3eca77513 100644 --- a/core/src/main/java/org/opensearch/sql/expression/aggregation/AggregationState.java +++ b/core/src/main/java/org/opensearch/sql/expression/aggregation/AggregationState.java @@ -26,6 +26,7 @@ package org.opensearch.sql.expression.aggregation; +import java.util.Set; import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.storage.bindingtuple.BindingTuple; @@ -37,4 +38,6 @@ public interface AggregationState { * Get {@link ExprValue} result. */ ExprValue result(); + + Set distinctSet(); } diff --git a/core/src/main/java/org/opensearch/sql/expression/aggregation/Aggregator.java b/core/src/main/java/org/opensearch/sql/expression/aggregation/Aggregator.java index 80944172ea1..1e9af97e8b9 100644 --- a/core/src/main/java/org/opensearch/sql/expression/aggregation/Aggregator.java +++ b/core/src/main/java/org/opensearch/sql/expression/aggregation/Aggregator.java @@ -64,6 +64,12 @@ public abstract class Aggregator @Getter @Accessors(fluent = true) protected Expression condition; + @Setter + @Getter + @Accessors(fluent = true) + protected Boolean distinct = false; + + /** * Create an {@link AggregationState} which will be used for aggregation. 
@@ -89,7 +95,8 @@ public abstract class Aggregator */ public S iterate(BindingTuple tuple, S state) { ExprValue value = getArguments().get(0).valueOf(tuple); - if (value.isNull() || value.isMissing() || !conditionValue(tuple)) { + if (value.isNull() || value.isMissing() || !conditionValue(tuple) + || (distinct && duplicated(value, state))) { return state; } return iterate(value, state); @@ -121,4 +128,13 @@ public boolean conditionValue(BindingTuple tuple) { return ExprValueUtils.getBooleanValue(condition.valueOf(tuple)); } + private Boolean duplicated(ExprValue value, S state) { + for (ExprValue exprValue : state.distinctSet()) { + if (exprValue.compareTo(value) == 0) { + return true; + } + } + return false; + } + } diff --git a/core/src/main/java/org/opensearch/sql/expression/aggregation/AvgAggregator.java b/core/src/main/java/org/opensearch/sql/expression/aggregation/AvgAggregator.java index 0ec0a02a3c1..ca1c32c24f5 100644 --- a/core/src/main/java/org/opensearch/sql/expression/aggregation/AvgAggregator.java +++ b/core/src/main/java/org/opensearch/sql/expression/aggregation/AvgAggregator.java @@ -30,6 +30,7 @@ import java.util.List; import java.util.Locale; +import java.util.Set; import org.opensearch.sql.data.model.ExprNullValue; import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.data.model.ExprValueUtils; @@ -80,5 +81,10 @@ protected static class AvgState implements AggregationState { public ExprValue result() { return count == 0 ? 
ExprNullValue.of() : ExprValueUtils.doubleValue(total / count); } + + @Override + public Set distinctSet() { + return Set.of(); + } } } diff --git a/core/src/main/java/org/opensearch/sql/expression/aggregation/CountAggregator.java b/core/src/main/java/org/opensearch/sql/expression/aggregation/CountAggregator.java index 3195bf39413..36f78c50765 100644 --- a/core/src/main/java/org/opensearch/sql/expression/aggregation/CountAggregator.java +++ b/core/src/main/java/org/opensearch/sql/expression/aggregation/CountAggregator.java @@ -28,8 +28,11 @@ import static org.opensearch.sql.utils.ExpressionUtils.format; + +import java.util.HashSet; import java.util.List; import java.util.Locale; +import java.util.Set; import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.data.model.ExprValueUtils; import org.opensearch.sql.data.type.ExprCoreType; @@ -50,7 +53,7 @@ public CountAggregator.CountState create() { @Override protected CountState iterate(ExprValue value, CountState state) { - state.count++; + state.count(value); return state; } @@ -64,14 +67,25 @@ public String toString() { */ protected static class CountState implements AggregationState { private int count; + private final Set set = new HashSet<>(); CountState() { this.count = 0; } + public void count(ExprValue value) { + set.add(value); + count++; + } + @Override public ExprValue result() { return ExprValueUtils.integerValue(count); } + + @Override + public Set distinctSet() { + return set; + } } } diff --git a/core/src/main/java/org/opensearch/sql/expression/aggregation/MaxAggregator.java b/core/src/main/java/org/opensearch/sql/expression/aggregation/MaxAggregator.java index 11ad63093db..9a1d31caad1 100644 --- a/core/src/main/java/org/opensearch/sql/expression/aggregation/MaxAggregator.java +++ b/core/src/main/java/org/opensearch/sql/expression/aggregation/MaxAggregator.java @@ -30,6 +30,7 @@ import static org.opensearch.sql.utils.ExpressionUtils.format; import java.util.List; +import java.util.Set; 
import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.data.type.ExprCoreType; import org.opensearch.sql.expression.Expression; @@ -74,5 +75,10 @@ public void max(ExprValue value) { public ExprValue result() { return maxResult; } + + @Override + public Set distinctSet() { + return Set.of(); + } } } diff --git a/core/src/main/java/org/opensearch/sql/expression/aggregation/MinAggregator.java b/core/src/main/java/org/opensearch/sql/expression/aggregation/MinAggregator.java index 46f69129ed8..a40315c8c0a 100644 --- a/core/src/main/java/org/opensearch/sql/expression/aggregation/MinAggregator.java +++ b/core/src/main/java/org/opensearch/sql/expression/aggregation/MinAggregator.java @@ -30,6 +30,7 @@ import static org.opensearch.sql.utils.ExpressionUtils.format; import java.util.List; +import java.util.Set; import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.data.type.ExprCoreType; import org.opensearch.sql.expression.Expression; @@ -79,5 +80,10 @@ public void min(ExprValue value) { public ExprValue result() { return minResult; } + + @Override + public Set distinctSet() { + return Set.of(); + } } } diff --git a/core/src/main/java/org/opensearch/sql/expression/aggregation/SumAggregator.java b/core/src/main/java/org/opensearch/sql/expression/aggregation/SumAggregator.java index e658d21471e..afdf61c5ad0 100644 --- a/core/src/main/java/org/opensearch/sql/expression/aggregation/SumAggregator.java +++ b/core/src/main/java/org/opensearch/sql/expression/aggregation/SumAggregator.java @@ -38,6 +38,7 @@ import java.util.List; import java.util.Locale; +import java.util.Set; import org.opensearch.sql.data.model.ExprNullValue; import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.data.model.ExprValueUtils; @@ -116,5 +117,10 @@ public void add(ExprValue value) { public ExprValue result() { return isEmptyCollection ? 
ExprNullValue.of() : sumResult; } + + @Override + public Set distinctSet() { + return Set.of(); + } } } diff --git a/core/src/test/java/org/opensearch/sql/analysis/ExpressionAnalyzerTest.java b/core/src/test/java/org/opensearch/sql/analysis/ExpressionAnalyzerTest.java index aa8d2b12dee..628842b4f09 100644 --- a/core/src/test/java/org/opensearch/sql/analysis/ExpressionAnalyzerTest.java +++ b/core/src/test/java/org/opensearch/sql/analysis/ExpressionAnalyzerTest.java @@ -292,6 +292,14 @@ public void aggregation_filter() { ); } + @Test + public void distinct_aggregation() { + assertAnalyzeEqual( + dsl.distinctCount(DSL.ref("integer_value", INTEGER)), + AstDSL.distinctAggregate("count", qualifiedName("integer_value")) + ); + } + protected Expression analyze(UnresolvedExpression unresolvedExpression) { return expressionAnalyzer.analyze(unresolvedExpression, analysisContext); } diff --git a/core/src/test/java/org/opensearch/sql/expression/aggregation/AggregationTest.java b/core/src/test/java/org/opensearch/sql/expression/aggregation/AggregationTest.java index cc2825858a2..634a3a71920 100644 --- a/core/src/test/java/org/opensearch/sql/expression/aggregation/AggregationTest.java +++ b/core/src/test/java/org/opensearch/sql/expression/aggregation/AggregationTest.java @@ -116,6 +116,13 @@ public class AggregationTest extends ExpressionTestBase { "timestamp_value", "2040-01-01 07:00:00"))); + protected static List tuples_with_duplicates = + Arrays.asList( + ExprValueUtils.tupleValue(ImmutableMap.of("integer_value", 1)), + ExprValueUtils.tupleValue(ImmutableMap.of("integer_value", 1)), + ExprValueUtils.tupleValue(ImmutableMap.of("integer_value", 2)), + ExprValueUtils.tupleValue(ImmutableMap.of("integer_value", 3))); + protected static List tuples_with_null_and_missing = Arrays.asList( ExprValueUtils.tupleValue( diff --git a/core/src/test/java/org/opensearch/sql/expression/aggregation/AvgAggregatorTest.java 
b/core/src/test/java/org/opensearch/sql/expression/aggregation/AvgAggregatorTest.java index 494d3cfab2e..33ea4c91233 100644 --- a/core/src/test/java/org/opensearch/sql/expression/aggregation/AvgAggregatorTest.java +++ b/core/src/test/java/org/opensearch/sql/expression/aggregation/AvgAggregatorTest.java @@ -61,6 +61,13 @@ public void filtered_avg() { assertEquals(3.0, result.value()); } + @Test + public void distinct_avg() { + assertThrows(ExpressionEvaluationException.class, + () -> dsl.distinctAvg(DSL.ref("integer_value", INTEGER)).valueOf(valueEnv()), + "unsupported distinct aggregator avg"); + } + @Test public void avg_with_missing() { ExprValue result = diff --git a/core/src/test/java/org/opensearch/sql/expression/aggregation/CountAggregatorTest.java b/core/src/test/java/org/opensearch/sql/expression/aggregation/CountAggregatorTest.java index 0fdadfc692c..26a53539ace 100644 --- a/core/src/test/java/org/opensearch/sql/expression/aggregation/CountAggregatorTest.java +++ b/core/src/test/java/org/opensearch/sql/expression/aggregation/CountAggregatorTest.java @@ -129,6 +129,13 @@ public void filtered_count() { assertEquals(3, result.value()); } + @Test + public void distinct_count() { + ExprValue result = aggregation(dsl.distinctCount(DSL.ref("integer_value", INTEGER)), + tuples_with_duplicates); + assertEquals(3, result.value()); + } + @Test public void count_with_missing() { ExprValue result = aggregation(dsl.count(DSL.ref("integer_value", INTEGER)), diff --git a/core/src/test/java/org/opensearch/sql/expression/aggregation/MaxAggregatorTest.java b/core/src/test/java/org/opensearch/sql/expression/aggregation/MaxAggregatorTest.java index 5aa9d3a7473..20cde543141 100644 --- a/core/src/test/java/org/opensearch/sql/expression/aggregation/MaxAggregatorTest.java +++ b/core/src/test/java/org/opensearch/sql/expression/aggregation/MaxAggregatorTest.java @@ -116,6 +116,13 @@ public void filtered_max() { assertEquals(3, result.value()); } + @Test + public void distinct_max() 
{ + assertThrows(ExpressionEvaluationException.class, + () -> dsl.distinctMax(DSL.ref("integer_value", INTEGER)).valueOf(valueEnv()), + "unsupported distinct aggregator max"); + } + @Test public void test_max_null() { ExprValue result = diff --git a/core/src/test/java/org/opensearch/sql/expression/aggregation/MinAggregatorTest.java b/core/src/test/java/org/opensearch/sql/expression/aggregation/MinAggregatorTest.java index 01e72b9cdac..e2927772621 100644 --- a/core/src/test/java/org/opensearch/sql/expression/aggregation/MinAggregatorTest.java +++ b/core/src/test/java/org/opensearch/sql/expression/aggregation/MinAggregatorTest.java @@ -116,6 +116,13 @@ public void filtered_min() { assertEquals(2, result.value()); } + @Test + public void distinct_min() { + assertThrows(ExpressionEvaluationException.class, + () -> dsl.distinctMin(DSL.ref("integer_value", INTEGER)).valueOf(valueEnv()), + "unsupported distinct aggregator min"); + } + @Test public void test_min_null() { ExprValue result = diff --git a/core/src/test/java/org/opensearch/sql/expression/aggregation/SumAggregatorTest.java b/core/src/test/java/org/opensearch/sql/expression/aggregation/SumAggregatorTest.java index c0872ed4345..fdd24fb5b16 100644 --- a/core/src/test/java/org/opensearch/sql/expression/aggregation/SumAggregatorTest.java +++ b/core/src/test/java/org/opensearch/sql/expression/aggregation/SumAggregatorTest.java @@ -100,6 +100,13 @@ public void filtered_sum() { assertEquals(9, result.value()); } + @Test + public void distinct_sum() { + assertThrows(ExpressionEvaluationException.class, + () -> dsl.distinctSum(DSL.ref("integer_value", INTEGER)).valueOf(valueEnv()), + "unsupported distinct aggregator sum"); + } + @Test public void sum_with_missing() { ExprValue result = diff --git a/integ-test/src/test/resources/correctness/queries/aggregation.txt b/integ-test/src/test/resources/correctness/queries/aggregation.txt index 6c6e5b73a14..e7cd34451db 100644 --- 
a/integ-test/src/test/resources/correctness/queries/aggregation.txt +++ b/integ-test/src/test/resources/correctness/queries/aggregation.txt @@ -5,4 +5,5 @@ SELECT SUM(AvgTicketPrice) FROM opensearch_dashboards_sample_data_flights SELECT MAX(AvgTicketPrice) FROM opensearch_dashboards_sample_data_flights SELECT MAX(timestamp) FROM opensearch_dashboards_sample_data_flights SELECT MIN(AvgTicketPrice) FROM opensearch_dashboards_sample_data_flights -SELECT MIN(timestamp) FROM opensearch_dashboards_sample_data_flights \ No newline at end of file +SELECT MIN(timestamp) FROM opensearch_dashboards_sample_data_flights +SELECT COUNT(DISTINCT Origin), COUNT(DISTINCT Dest) FROM opensearch_dashboards_sample_data_flights \ No newline at end of file diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java index f3807ae662f..fa06f01f518 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java @@ -35,6 +35,8 @@ import org.opensearch.search.aggregations.AggregationBuilders; import org.opensearch.search.aggregations.AggregatorFactories; import org.opensearch.search.aggregations.bucket.filter.FilterAggregationBuilder; +import org.opensearch.search.aggregations.bucket.terms.TermsAggregationBuilder; +import org.opensearch.search.aggregations.metrics.CardinalityAggregationBuilder; import org.opensearch.search.aggregations.support.ValuesSourceAggregationBuilder; import org.opensearch.sql.expression.Expression; import org.opensearch.sql.expression.ExpressionNodeVisitor; @@ -51,11 +53,15 @@ public class MetricAggregationBuilder extends ExpressionNodeVisitor { private final AggregationBuilderHelper> helper; + private 
final AggregationBuilderHelper cardinalityHelper; + private final AggregationBuilderHelper termsHelper; private final FilterQueryBuilder filterBuilder; public MetricAggregationBuilder( ExpressionSerializer serializer) { this.helper = new AggregationBuilderHelper<>(serializer); + this.cardinalityHelper = new AggregationBuilderHelper<>(serializer); + this.termsHelper = new AggregationBuilderHelper<>(serializer); this.filterBuilder = new FilterQueryBuilder(serializer); } @@ -78,8 +84,19 @@ public AggregationBuilder visitNamedAggregator(NamedAggregator node, Object context) { Expression expression = node.getArguments().get(0); Expression condition = node.getDelegated().condition(); + Boolean distinct = node.getDelegated().distinct(); String name = node.getName(); + if (distinct) { + switch (node.getFunctionName().getFunctionName()) { + case "count": + return make(AggregationBuilders.cardinality(name), expression); + default: + throw new IllegalStateException(String.format( + "unsupported distinct aggregator %s", node.getFunctionName().getFunctionName())); + } + } + switch (node.getFunctionName().getFunctionName()) { case "avg": return make(AggregationBuilders.avg(name), expression, condition, name); @@ -108,6 +125,13 @@ private AggregationBuilder make(ValuesSourceAggregationBuilder builder, return aggregationBuilder; } + /** + * Make {@link CardinalityAggregationBuilder} for distinct count aggregations. + */ + private AggregationBuilder make(CardinalityAggregationBuilder builder, Expression expression) { + return cardinalityHelper.build(expression, builder::field, builder::script); + } + /** * Replace star or literal with OpenSearch metadata field "_index". 
Because: * 1) Analyzer already converts * to string literal, literal check here can handle diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilderTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilderTest.java index b956a2f5a07..c15cb152a38 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilderTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilderTest.java @@ -32,12 +32,14 @@ import static org.junit.jupiter.api.Assertions.assertThrows; import static org.mockito.Mockito.when; import static org.opensearch.sql.data.type.ExprCoreType.INTEGER; +import static org.opensearch.sql.data.type.ExprCoreType.STRING; import static org.opensearch.sql.expression.DSL.literal; import static org.opensearch.sql.expression.DSL.named; import static org.opensearch.sql.expression.DSL.ref; import com.fasterxml.jackson.databind.ObjectMapper; import java.util.Arrays; +import java.util.Collections; import java.util.List; import lombok.SneakyThrows; import org.junit.jupiter.api.BeforeEach; @@ -185,6 +187,21 @@ void should_build_max_aggregation() { new MaxAggregator(Arrays.asList(ref("age", INTEGER)), INTEGER))))); } + @Test + void should_build_cardinality_aggregation() { + assertEquals( + "{\n" + + " \"count(distinct name)\" : {\n" + + " \"cardinality\" : {\n" + + " \"field\" : \"name\"\n" + + " }\n" + + " }\n" + + "}", + buildQuery( + Collections.singletonList(named("count(distinct name)", new CountAggregator( + Collections.singletonList(ref("name", STRING)), STRING).distinct(true))))); + } + @Test void should_throw_exception_for_unsupported_aggregator() { when(aggregator.getFunctionName()).thenReturn(new FunctionName("unsupported_agg")); diff --git a/sql/src/main/antlr/OpenSearchSQLParser.g4 
b/sql/src/main/antlr/OpenSearchSQLParser.g4 index 0ad08781bfe..51c558a68e9 100644 --- a/sql/src/main/antlr/OpenSearchSQLParser.g4 +++ b/sql/src/main/antlr/OpenSearchSQLParser.g4 @@ -336,8 +336,11 @@ caseFuncAlternative ; aggregateFunction - : functionName=aggregationFunctionName LR_BRACKET functionArg RR_BRACKET #regularAggregateFunctionCall - | COUNT LR_BRACKET STAR RR_BRACKET #countStarFunctionCall + : functionName=aggregationFunctionName LR_BRACKET functionArg RR_BRACKET + #regularAggregateFunctionCall + | functionName=aggregationFunctionName LR_BRACKET DISTINCT functionArg RR_BRACKET + #distinctAggregateFunctionCall + | COUNT LR_BRACKET STAR RR_BRACKET #countStarFunctionCall ; filterClause diff --git a/sql/src/main/java/org/opensearch/sql/sql/parser/AstExpressionBuilder.java b/sql/src/main/java/org/opensearch/sql/sql/parser/AstExpressionBuilder.java index b1630aed509..d267a8df4fa 100644 --- a/sql/src/main/java/org/opensearch/sql/sql/parser/AstExpressionBuilder.java +++ b/sql/src/main/java/org/opensearch/sql/sql/parser/AstExpressionBuilder.java @@ -212,6 +212,15 @@ public UnresolvedExpression visitRegularAggregateFunctionCall( visitFunctionArg(ctx.functionArg())); } + @Override + public UnresolvedExpression visitDistinctAggregateFunctionCall( + OpenSearchSQLParser.DistinctAggregateFunctionCallContext ctx) { + return new AggregateFunction( + ctx.functionName.getText(), + visitFunctionArg(ctx.functionArg()), + true); + } + @Override public UnresolvedExpression visitCountStarFunctionCall(CountStarFunctionCallContext ctx) { return new AggregateFunction("COUNT", AllFields.of()); diff --git a/sql/src/test/java/org/opensearch/sql/sql/parser/AstAggregationBuilderTest.java b/sql/src/test/java/org/opensearch/sql/sql/parser/AstAggregationBuilderTest.java index 1d9516f8162..8e7adaaf039 100644 --- a/sql/src/test/java/org/opensearch/sql/sql/parser/AstAggregationBuilderTest.java +++ b/sql/src/test/java/org/opensearch/sql/sql/parser/AstAggregationBuilderTest.java @@ -36,6 +36,7 
@@ import static org.junit.jupiter.api.Assertions.assertThrows; import static org.opensearch.sql.ast.dsl.AstDSL.aggregate; import static org.opensearch.sql.ast.dsl.AstDSL.alias; +import static org.opensearch.sql.ast.dsl.AstDSL.distinctAggregate; import static org.opensearch.sql.ast.dsl.AstDSL.function; import static org.opensearch.sql.ast.dsl.AstDSL.intLiteral; import static org.opensearch.sql.ast.dsl.AstDSL.qualifiedName; @@ -167,6 +168,19 @@ void can_build_implicit_group_by_for_aggregator_in_having_clause() { alias("AVG(age)", aggregate("AVG", qualifiedName("age")))))); } + @Test + void can_build_distinct_aggregator() { + assertThat( + buildAggregation("SELECT COUNT(DISTINCT name), AVG(DISTINCT balance) FROM test"), + allOf( + hasGroupByItems(), + hasAggregators( + alias("COUNT(DISTINCT name)", distinctAggregate("COUNT", qualifiedName( + "name"))), + alias("AVG(DISTINCT balance)", distinctAggregate("AVG", qualifiedName( + "balance")))))); + } + @Test void should_build_nothing_if_no_group_by_and_no_aggregators_in_select() { assertNull(buildAggregation("SELECT name FROM test")); From e30b685149e555985ddb09c89e521551c3a2c78c Mon Sep 17 00:00:00 2001 From: chloe-zh Date: Wed, 9 Jun 2021 12:22:32 -0700 Subject: [PATCH 03/23] fixed tests Signed-off-by: chloe-zh --- .../aggregation/AggregationState.java | 4 ++- .../expression/aggregation/Aggregator.java | 2 +- .../expression/aggregation/AvgAggregator.java | 6 ---- .../aggregation/CountAggregator.java | 3 +- .../expression/aggregation/MaxAggregator.java | 6 ---- .../expression/aggregation/MinAggregator.java | 5 --- .../expression/aggregation/SumAggregator.java | 5 --- .../aggregation/AggregatorStateTest.java | 35 +++++++++++++++++++ .../dsl/MetricAggregationBuilder.java | 3 +- .../dsl/MetricAggregationBuilderTest.java | 9 +++++ ppl/src/main/antlr/OpenSearchPPLParser.g4 | 1 + .../sql/ppl/parser/AstExpressionBuilder.java | 7 ++++ .../ppl/parser/AstExpressionBuilderTest.java | 30 ++++++++++++++++ 13 files changed, 89 
insertions(+), 27 deletions(-) create mode 100644 core/src/test/java/org/opensearch/sql/expression/aggregation/AggregatorStateTest.java diff --git a/core/src/main/java/org/opensearch/sql/expression/aggregation/AggregationState.java b/core/src/main/java/org/opensearch/sql/expression/aggregation/AggregationState.java index ed3eca77513..378490e7663 100644 --- a/core/src/main/java/org/opensearch/sql/expression/aggregation/AggregationState.java +++ b/core/src/main/java/org/opensearch/sql/expression/aggregation/AggregationState.java @@ -39,5 +39,7 @@ public interface AggregationState { */ ExprValue result(); - Set distinctSet(); + default Set distinctValues() { + return Set.of(); + } } diff --git a/core/src/main/java/org/opensearch/sql/expression/aggregation/Aggregator.java b/core/src/main/java/org/opensearch/sql/expression/aggregation/Aggregator.java index 1e9af97e8b9..a0a8037751e 100644 --- a/core/src/main/java/org/opensearch/sql/expression/aggregation/Aggregator.java +++ b/core/src/main/java/org/opensearch/sql/expression/aggregation/Aggregator.java @@ -129,7 +129,7 @@ public boolean conditionValue(BindingTuple tuple) { } private Boolean duplicated(ExprValue value, S state) { - for (ExprValue exprValue : state.distinctSet()) { + for (ExprValue exprValue : state.distinctValues()) { if (exprValue.compareTo(value) == 0) { return true; } diff --git a/core/src/main/java/org/opensearch/sql/expression/aggregation/AvgAggregator.java b/core/src/main/java/org/opensearch/sql/expression/aggregation/AvgAggregator.java index ca1c32c24f5..0ec0a02a3c1 100644 --- a/core/src/main/java/org/opensearch/sql/expression/aggregation/AvgAggregator.java +++ b/core/src/main/java/org/opensearch/sql/expression/aggregation/AvgAggregator.java @@ -30,7 +30,6 @@ import java.util.List; import java.util.Locale; -import java.util.Set; import org.opensearch.sql.data.model.ExprNullValue; import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.data.model.ExprValueUtils; @@ -81,10 +80,5 @@ 
protected static class AvgState implements AggregationState { public ExprValue result() { return count == 0 ? ExprNullValue.of() : ExprValueUtils.doubleValue(total / count); } - - @Override - public Set distinctSet() { - return Set.of(); - } } } diff --git a/core/src/main/java/org/opensearch/sql/expression/aggregation/CountAggregator.java b/core/src/main/java/org/opensearch/sql/expression/aggregation/CountAggregator.java index 36f78c50765..975a39a8cce 100644 --- a/core/src/main/java/org/opensearch/sql/expression/aggregation/CountAggregator.java +++ b/core/src/main/java/org/opensearch/sql/expression/aggregation/CountAggregator.java @@ -28,7 +28,6 @@ import static org.opensearch.sql.utils.ExpressionUtils.format; - import java.util.HashSet; import java.util.List; import java.util.Locale; @@ -84,7 +83,7 @@ public ExprValue result() { } @Override - public Set distinctSet() { + public Set distinctValues() { return set; } } diff --git a/core/src/main/java/org/opensearch/sql/expression/aggregation/MaxAggregator.java b/core/src/main/java/org/opensearch/sql/expression/aggregation/MaxAggregator.java index 9a1d31caad1..11ad63093db 100644 --- a/core/src/main/java/org/opensearch/sql/expression/aggregation/MaxAggregator.java +++ b/core/src/main/java/org/opensearch/sql/expression/aggregation/MaxAggregator.java @@ -30,7 +30,6 @@ import static org.opensearch.sql.utils.ExpressionUtils.format; import java.util.List; -import java.util.Set; import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.data.type.ExprCoreType; import org.opensearch.sql.expression.Expression; @@ -75,10 +74,5 @@ public void max(ExprValue value) { public ExprValue result() { return maxResult; } - - @Override - public Set distinctSet() { - return Set.of(); - } } } diff --git a/core/src/main/java/org/opensearch/sql/expression/aggregation/MinAggregator.java b/core/src/main/java/org/opensearch/sql/expression/aggregation/MinAggregator.java index a40315c8c0a..e9672475bca 100644 --- 
a/core/src/main/java/org/opensearch/sql/expression/aggregation/MinAggregator.java +++ b/core/src/main/java/org/opensearch/sql/expression/aggregation/MinAggregator.java @@ -80,10 +80,5 @@ public void min(ExprValue value) { public ExprValue result() { return minResult; } - - @Override - public Set distinctSet() { - return Set.of(); - } } } diff --git a/core/src/main/java/org/opensearch/sql/expression/aggregation/SumAggregator.java b/core/src/main/java/org/opensearch/sql/expression/aggregation/SumAggregator.java index afdf61c5ad0..8de5ffb7a2d 100644 --- a/core/src/main/java/org/opensearch/sql/expression/aggregation/SumAggregator.java +++ b/core/src/main/java/org/opensearch/sql/expression/aggregation/SumAggregator.java @@ -117,10 +117,5 @@ public void add(ExprValue value) { public ExprValue result() { return isEmptyCollection ? ExprNullValue.of() : sumResult; } - - @Override - public Set distinctSet() { - return Set.of(); - } } } diff --git a/core/src/test/java/org/opensearch/sql/expression/aggregation/AggregatorStateTest.java b/core/src/test/java/org/opensearch/sql/expression/aggregation/AggregatorStateTest.java new file mode 100644 index 00000000000..338a254148f --- /dev/null +++ b/core/src/test/java/org/opensearch/sql/expression/aggregation/AggregatorStateTest.java @@ -0,0 +1,35 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ * + */ + +package org.opensearch.sql.expression.aggregation; + +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import org.junit.jupiter.api.Test; +import org.opensearch.sql.data.model.ExprIntegerValue; + +public class AggregatorStateTest extends AggregationTest { + + @Test + void count_distinct_values() { + CountAggregator.CountState state = new CountAggregator.CountState(); + state.count(new ExprIntegerValue(1)); + assertFalse(state.distinctValues().isEmpty()); + } + + @Test + void default_distinct_values() { + AvgAggregator.AvgState state = new AvgAggregator.AvgState(); + assertTrue(state.distinctValues().isEmpty()); + } +} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java index fa06f01f518..e3b5be881ce 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java @@ -38,6 +38,7 @@ import org.opensearch.search.aggregations.bucket.terms.TermsAggregationBuilder; import org.opensearch.search.aggregations.metrics.CardinalityAggregationBuilder; import org.opensearch.search.aggregations.support.ValuesSourceAggregationBuilder; +import org.opensearch.sql.exception.ExpressionEvaluationException; import org.opensearch.sql.expression.Expression; import org.opensearch.sql.expression.ExpressionNodeVisitor; import org.opensearch.sql.expression.LiteralExpression; @@ -92,7 +93,7 @@ public AggregationBuilder visitNamedAggregator(NamedAggregator node, case "count": return make(AggregationBuilders.cardinality(name), expression); default: - throw new IllegalStateException(String.format( + throw new 
ExpressionEvaluationException(String.format( "unsupported distinct aggregator %s", node.getFunctionName().getFunctionName())); } } diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilderTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilderTest.java index c15cb152a38..bacd5413b9d 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilderTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilderTest.java @@ -49,6 +49,7 @@ import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.sql.exception.ExpressionEvaluationException; import org.opensearch.sql.expression.aggregation.AvgAggregator; import org.opensearch.sql.expression.aggregation.CountAggregator; import org.opensearch.sql.expression.aggregation.MaxAggregator; @@ -202,6 +203,14 @@ void should_build_cardinality_aggregation() { Collections.singletonList(ref("name", STRING)), STRING).distinct(true))))); } + @Test + void should_throw_exception_for_unsupported_distinct_aggregator() { + assertThrows(ExpressionEvaluationException.class, + () -> buildQuery(Collections.singletonList(named("avg(distinct age)", new AvgAggregator( + Collections.singletonList(ref("name", STRING)), STRING).distinct(true)))), + "unsupported distinct aggregator avg"); + } + @Test void should_throw_exception_for_unsupported_aggregator() { when(aggregator.getFunctionName()).thenReturn(new FunctionName("unsupported_agg")); diff --git a/ppl/src/main/antlr/OpenSearchPPLParser.g4 b/ppl/src/main/antlr/OpenSearchPPLParser.g4 index 77aecf5a44e..e8b54dab4da 100644 --- a/ppl/src/main/antlr/OpenSearchPPLParser.g4 +++ b/ppl/src/main/antlr/OpenSearchPPLParser.g4 @@ -135,6 +135,7 @@ statsAggTerm 
statsFunction : statsFunctionName LT_PRTHS valueExpression RT_PRTHS #statsFunctionCall | COUNT LT_PRTHS RT_PRTHS #countAllFunctionCall + | (DISTINCT_COUNT | DC) LT_PRTHS valueExpression? RT_PRTHS #distinctCountFunctionCall | percentileAggFunction #percentileAggFunctionCall ; diff --git a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java index 9fdf8d636d5..ef314072760 100644 --- a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java +++ b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java @@ -35,6 +35,7 @@ import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.CompareExprContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.CountAllFunctionCallContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.DecimalLiteralContext; +import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.DistinctCountFunctionCallContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.EvalClauseContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.EvalFunctionCallContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.FieldExpressionContext; @@ -203,6 +204,12 @@ public UnresolvedExpression visitCountAllFunctionCall(CountAllFunctionCallContex return new AggregateFunction("count", AllFields.of()); } + @Override + public UnresolvedExpression visitDistinctCountFunctionCall(DistinctCountFunctionCallContext ctx) { + return new AggregateFunction("count", + ctx.valueExpression() != null ? 
visit(ctx.valueExpression()) : AllFields.of(), true); + } + @Override public UnresolvedExpression visitPercentileAggFunction(PercentileAggFunctionContext ctx) { return new AggregateFunction(ctx.PERCENTILE().getText(), visit(ctx.aggField), diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstExpressionBuilderTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstExpressionBuilderTest.java index 07ad97401e7..6bbfda7aef0 100644 --- a/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstExpressionBuilderTest.java +++ b/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstExpressionBuilderTest.java @@ -37,6 +37,7 @@ import static org.opensearch.sql.ast.dsl.AstDSL.defaultFieldsArgs; import static org.opensearch.sql.ast.dsl.AstDSL.defaultSortFieldArgs; import static org.opensearch.sql.ast.dsl.AstDSL.defaultStatsArgs; +import static org.opensearch.sql.ast.dsl.AstDSL.distinctAggregate; import static org.opensearch.sql.ast.dsl.AstDSL.doubleLiteral; import static org.opensearch.sql.ast.dsl.AstDSL.equalTo; import static org.opensearch.sql.ast.dsl.AstDSL.eval; @@ -376,6 +377,35 @@ public void testCountFuncCallExpr() { )); } + @Test + public void testDistinctCount() { + assertEqual("source=t | stats distinct_count(a)", + agg( + relation("t"), + exprList( + alias("distinct_count(a)", + distinctAggregate("count", field("a")))), + emptyList(), + emptyList(), + defaultStatsArgs())); + + assertEqual("source=t | stats dc() by b", + agg( + relation("t"), + exprList( + alias( + "dc()", + distinctAggregate("count", AllFields.of()) + ) + ), + emptyList(), + exprList( + alias("b", field("b")) + ), + defaultStatsArgs() + )); + } + @Test public void testEvalFuncCallExpr() { assertEqual("source=t | eval f=abs(a)", From 866d71d7837478d29c7b465094594234e1d5ed73 Mon Sep 17 00:00:00 2001 From: Chloe Zhang Date: Wed, 9 Jun 2021 12:29:50 -0700 Subject: [PATCH 04/23] Merge remote-tracking branch 'upstream/develop' into issue/#100 Signed-off-by: chloe-zh # Conflicts: # 
opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java --- .../aggregation/dsl/MetricAggregationBuilder.java | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java index aa116877dfa..a065b131967 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java @@ -99,7 +99,10 @@ public Pair visitNamedAggregator( if (distinct) { switch (node.getFunctionName().getFunctionName()) { case "count": - return make(AggregationBuilders.cardinality(name), expression); + return make( + AggregationBuilders.cardinality(name), + expression, + new SingleValueParser(name)); default: throw new ExpressionEvaluationException(String.format( "unsupported distinct aggregator %s", node.getFunctionName().getFunctionName())); @@ -171,6 +174,12 @@ private AggregationBuilder make(CardinalityAggregationBuilder builder, Expressio return cardinalityHelper.build(expression, builder::field, builder::script); } + private Pair make(CardinalityAggregationBuilder builder, + Expression expression, + MetricParser parser) { + return Pair.of(cardinalityHelper.build(expression, builder::field, builder::script), parser); + } + /** * Replace star or literal with OpenSearch metadata field "_index". Because: 1) Analyzer already * converts * to string literal, literal check here can handle both COUNT(*) and COUNT(1). 
2) From 8a6ca202fbe7a0e1b66fe94619fa899bd99caf7b Mon Sep 17 00:00:00 2001 From: Chloe Zhang Date: Wed, 9 Jun 2021 12:30:34 -0700 Subject: [PATCH 05/23] update Signed-off-by: chloe-zh --- .../script/aggregation/dsl/MetricAggregationBuilder.java | 4 ---- 1 file changed, 4 deletions(-) diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java index a065b131967..9b6883e4ad4 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java @@ -170,10 +170,6 @@ private Pair make( /** * Make {@link CardinalityAggregationBuilder} for distinct count aggregations. */ - private AggregationBuilder make(CardinalityAggregationBuilder builder, Expression expression) { - return cardinalityHelper.build(expression, builder::field, builder::script); - } - private Pair make(CardinalityAggregationBuilder builder, Expression expression, MetricParser parser) { From 43cbd17ca065644984648f97401c0e7a3a788758 Mon Sep 17 00:00:00 2001 From: Chloe Zhang Date: Wed, 9 Jun 2021 13:12:54 -0700 Subject: [PATCH 06/23] updated user doc Signed-off-by: chloe-zh --- docs/user/dql/aggregations.rst | 13 +++++++++++++ docs/user/ppl/cmd/stats.rst | 15 +++++++++++++++ .../dsl/MetricAggregationBuilder.java | 18 +++++++++--------- 3 files changed, 37 insertions(+), 9 deletions(-) diff --git a/docs/user/dql/aggregations.rst b/docs/user/dql/aggregations.rst index 98b565e1ecd..3c8577fcde7 100644 --- a/docs/user/dql/aggregations.rst +++ b/docs/user/dql/aggregations.rst @@ -135,6 +135,19 @@ Besides regular identifiers, ``COUNT`` aggregate function also accepts arguments 2. ``COUNT(*)`` will count the number of all its input rows. 3. 
``COUNT(1)`` is same as ``COUNT(*)`` because any non-null literal will count. +DISTINCT Aggregation +-------------------- + +To get the aggregation of distinct values of a field, you can add a keyword ``DISTINCT`` before the field in the aggregation function. Currently the distinct aggregation is only supported in ``COUNT`` aggregation. Example:: + + os> SELECT COUNT(DISTINCT gender), COUNT(gender) FROM accounts; + fetched rows / total rows = 1/1 + +--------------------------+-----------------+ + | COUNT(DISTINCT gender) | COUNT(gender) | + |--------------------------+-----------------| + | 2 | 4 | + +--------------------------+-----------------+ + HAVING Clause ============= diff --git a/docs/user/ppl/cmd/stats.rst b/docs/user/ppl/cmd/stats.rst index 3aca304fcd7..8a51811689a 100644 --- a/docs/user/ppl/cmd/stats.rst +++ b/docs/user/ppl/cmd/stats.rst @@ -134,3 +134,18 @@ PPL query:: | 36 | 32 | M | +------------+------------+----------+ +Example 7: Calculate the distinct count of a field +================================================== + +To get the count of distinct values of a field, you can use ``DISTINCT_COUNT`` (or ``DC``) function instead of ``COUNT``. The example calculates both the count and the distinct count of gender field of all the accounts. 
+ +PPL query:: + + os> source=accounts | stats count(gender), distinct_count(gender); + fetched rows / total rows = 1/1 + +-----------------+--------------------------+ + | count(gender) | distinct_count(gender) | + |-----------------+--------------------------| + | 4 | 2 | + +-----------------+--------------------------+ + diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java index 9b6883e4ad4..84127d9a880 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java @@ -37,7 +37,6 @@ import org.opensearch.search.aggregations.AggregationBuilders; import org.opensearch.search.aggregations.AggregatorFactories; import org.opensearch.search.aggregations.bucket.filter.FilterAggregationBuilder; -import org.opensearch.search.aggregations.bucket.terms.TermsAggregationBuilder; import org.opensearch.search.aggregations.metrics.CardinalityAggregationBuilder; import org.opensearch.search.aggregations.support.ValuesSourceAggregationBuilder; import org.opensearch.sql.exception.ExpressionEvaluationException; @@ -58,15 +57,16 @@ public class MetricAggregationBuilder extends ExpressionNodeVisitor, Object> { - private final AggregationBuilderHelper> helper; - private final AggregationBuilderHelper cardinalityHelper; - private final AggregationBuilderHelper termsHelper; + private final AggregationBuilderHelper> valuesSourceAggHelper; + private final AggregationBuilderHelper cardinalityAggHelper; private final FilterQueryBuilder filterBuilder; + /** + * Constructor. 
+ */ public MetricAggregationBuilder(ExpressionSerializer serializer) { - this.helper = new AggregationBuilderHelper<>(serializer); - this.cardinalityHelper = new AggregationBuilderHelper<>(serializer); - this.termsHelper = new AggregationBuilderHelper<>(serializer); + this.valuesSourceAggHelper = new AggregationBuilderHelper<>(serializer); + this.cardinalityAggHelper = new AggregationBuilderHelper<>(serializer); this.filterBuilder = new FilterQueryBuilder(serializer); } @@ -158,7 +158,7 @@ private Pair make( String name, MetricParser parser) { ValuesSourceAggregationBuilder aggregationBuilder = - helper.build(expression, builder::field, builder::script); + valuesSourceAggHelper.build(expression, builder::field, builder::script); if (condition != null) { return Pair.of( makeFilterAggregation(aggregationBuilder, condition, name), @@ -173,7 +173,7 @@ private Pair make( private Pair make(CardinalityAggregationBuilder builder, Expression expression, MetricParser parser) { - return Pair.of(cardinalityHelper.build(expression, builder::field, builder::script), parser); + return Pair.of(cardinalityAggHelper.build(expression, builder::field, builder::script), parser); } /** From 392c96ca0d4f1effd0b8b11dd266de7e2039c97a Mon Sep 17 00:00:00 2001 From: Chloe Zhang Date: Thu, 10 Jun 2021 21:09:14 -0700 Subject: [PATCH 07/23] Update: support only count for distinct aggregations Signed-off-by: chloe-zh --- .../sql/analysis/ExpressionAnalyzer.java | 4 +- .../org/opensearch/sql/ast/dsl/AstDSL.java | 7 +++- .../sql/ast/expression/AggregateFunction.java | 19 +++------ .../sql/data/model/ExprValueUtils.java | 8 ++++ .../aggregation/AggregationState.java | 5 --- .../expression/aggregation/Aggregator.java | 14 +------ .../aggregation/CountAggregator.java | 40 ++++++++++++++----- .../expression/aggregation/MinAggregator.java | 1 - .../expression/aggregation/SumAggregator.java | 1 - .../sql/analysis/ExpressionAnalyzerTest.java | 12 +++++- .../sql/data/model/ExprValueUtilsTest.java | 4 +- 
.../aggregation/AggregationTest.java | 8 ++-- .../aggregation/AggregatorStateTest.java | 35 ---------------- .../aggregation/CountAggregatorTest.java | 8 ++++ .../correctness/queries/aggregation.txt | 4 +- sql/src/main/antlr/OpenSearchSQLParser.g4 | 4 +- .../sql/sql/parser/AstExpressionBuilder.java | 10 ++--- .../sql/parser/AstAggregationBuilderTest.java | 16 +++++--- .../sql/parser/AstExpressionBuilderTest.java | 23 +++++++++++ 19 files changed, 121 insertions(+), 102 deletions(-) delete mode 100644 core/src/test/java/org/opensearch/sql/expression/aggregation/AggregatorStateTest.java diff --git a/core/src/main/java/org/opensearch/sql/analysis/ExpressionAnalyzer.java b/core/src/main/java/org/opensearch/sql/analysis/ExpressionAnalyzer.java index 3cc1dc95278..6de239bef15 100644 --- a/core/src/main/java/org/opensearch/sql/analysis/ExpressionAnalyzer.java +++ b/core/src/main/java/org/opensearch/sql/analysis/ExpressionAnalyzer.java @@ -161,8 +161,8 @@ public Expression visitAggregateFunction(AggregateFunction node, AnalysisContext Aggregator aggregator = (Aggregator) repository.compile( builtinFunctionName.get().getName(), Collections.singletonList(arg)); aggregator.distinct(node.getDistinct()); - if (node.getCondition() != null) { - aggregator.condition(analyze(node.getCondition(), context)); + if (node.condition() != null) { + aggregator.condition(analyze(node.condition(), context)); } return aggregator; } else { diff --git a/core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java b/core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java index be8f7095db5..3b78483736c 100644 --- a/core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java +++ b/core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java @@ -211,13 +211,18 @@ public static UnresolvedExpression aggregate( public static UnresolvedExpression filteredAggregate( String func, UnresolvedExpression field, UnresolvedExpression condition) { - return new AggregateFunction(func, field, condition); + return new 
AggregateFunction(func, field).condition(condition); } public static UnresolvedExpression distinctAggregate(String func, UnresolvedExpression field) { return new AggregateFunction(func, field, true); } + public static UnresolvedExpression filteredDistinctCount( + String func, UnresolvedExpression field, UnresolvedExpression condition) { + return new AggregateFunction(func, field, true).condition(condition); + } + public static Function function(String funcName, UnresolvedExpression... funcArgs) { return new Function(funcName, Arrays.asList(funcArgs)); } diff --git a/core/src/main/java/org/opensearch/sql/ast/expression/AggregateFunction.java b/core/src/main/java/org/opensearch/sql/ast/expression/AggregateFunction.java index d11fdca3ac0..96bd33f1c92 100644 --- a/core/src/main/java/org/opensearch/sql/ast/expression/AggregateFunction.java +++ b/core/src/main/java/org/opensearch/sql/ast/expression/AggregateFunction.java @@ -28,9 +28,12 @@ import java.util.Collections; import java.util.List; +import javax.annotation.Nullable; import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.RequiredArgsConstructor; +import lombok.Setter; +import lombok.experimental.Accessors; import org.opensearch.sql.ast.AbstractNodeVisitor; import org.opensearch.sql.common.utils.StringUtils; @@ -45,6 +48,8 @@ public class AggregateFunction extends UnresolvedExpression { private final String funcName; private final UnresolvedExpression field; private final List argList; + @Setter + @Accessors(fluent = true) private UnresolvedExpression condition; private Boolean distinct = false; @@ -59,20 +64,6 @@ public AggregateFunction(String funcName, UnresolvedExpression field) { this.argList = Collections.emptyList(); } - /** - * Constructor. - * @param funcName function name. - * @param field {@link UnresolvedExpression}. - * @param condition condition in aggregation filter. 
- */ - public AggregateFunction(String funcName, UnresolvedExpression field, - UnresolvedExpression condition) { - this.funcName = funcName; - this.field = field; - this.argList = Collections.emptyList(); - this.condition = condition; - } - /** * Constructor. * @param funcName function name. diff --git a/core/src/main/java/org/opensearch/sql/data/model/ExprValueUtils.java b/core/src/main/java/org/opensearch/sql/data/model/ExprValueUtils.java index e2c5fb6a39f..b2172e54f16 100644 --- a/core/src/main/java/org/opensearch/sql/data/model/ExprValueUtils.java +++ b/core/src/main/java/org/opensearch/sql/data/model/ExprValueUtils.java @@ -157,6 +157,14 @@ public static ExprValue fromObjectValue(Object o, ExprCoreType type) { } } + public static Byte getByteValue(ExprValue exprValue) { + return exprValue.byteValue(); + } + + public static Short getShortValue(ExprValue exprValue) { + return exprValue.shortValue(); + } + public static Integer getIntegerValue(ExprValue exprValue) { return exprValue.integerValue(); } diff --git a/core/src/main/java/org/opensearch/sql/expression/aggregation/AggregationState.java b/core/src/main/java/org/opensearch/sql/expression/aggregation/AggregationState.java index 378490e7663..b1c29cb4a7a 100644 --- a/core/src/main/java/org/opensearch/sql/expression/aggregation/AggregationState.java +++ b/core/src/main/java/org/opensearch/sql/expression/aggregation/AggregationState.java @@ -26,7 +26,6 @@ package org.opensearch.sql.expression.aggregation; -import java.util.Set; import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.storage.bindingtuple.BindingTuple; @@ -38,8 +37,4 @@ public interface AggregationState { * Get {@link ExprValue} result. 
*/ ExprValue result(); - - default Set distinctValues() { - return Set.of(); - } } diff --git a/core/src/main/java/org/opensearch/sql/expression/aggregation/Aggregator.java b/core/src/main/java/org/opensearch/sql/expression/aggregation/Aggregator.java index a0a8037751e..5328e11aadd 100644 --- a/core/src/main/java/org/opensearch/sql/expression/aggregation/Aggregator.java +++ b/core/src/main/java/org/opensearch/sql/expression/aggregation/Aggregator.java @@ -69,8 +69,6 @@ public abstract class Aggregator @Accessors(fluent = true) protected Boolean distinct = false; - - /** * Create an {@link AggregationState} which will be used for aggregation. */ @@ -95,8 +93,7 @@ public abstract class Aggregator */ public S iterate(BindingTuple tuple, S state) { ExprValue value = getArguments().get(0).valueOf(tuple); - if (value.isNull() || value.isMissing() || !conditionValue(tuple) - || (distinct && duplicated(value, state))) { + if (value.isNull() || value.isMissing() || !conditionValue(tuple)) { return state; } return iterate(value, state); @@ -128,13 +125,4 @@ public boolean conditionValue(BindingTuple tuple) { return ExprValueUtils.getBooleanValue(condition.valueOf(tuple)); } - private Boolean duplicated(ExprValue value, S state) { - for (ExprValue exprValue : state.distinctValues()) { - if (exprValue.compareTo(value) == 0) { - return true; - } - } - return false; - } - } diff --git a/core/src/main/java/org/opensearch/sql/expression/aggregation/CountAggregator.java b/core/src/main/java/org/opensearch/sql/expression/aggregation/CountAggregator.java index 975a39a8cce..34d064fe46d 100644 --- a/core/src/main/java/org/opensearch/sql/expression/aggregation/CountAggregator.java +++ b/core/src/main/java/org/opensearch/sql/expression/aggregation/CountAggregator.java @@ -26,6 +26,16 @@ package org.opensearch.sql.expression.aggregation; +import static org.opensearch.sql.data.model.ExprValueUtils.getBooleanValue; +import static org.opensearch.sql.data.model.ExprValueUtils.getByteValue; 
+import static org.opensearch.sql.data.model.ExprValueUtils.getCollectionValue; +import static org.opensearch.sql.data.model.ExprValueUtils.getDoubleValue; +import static org.opensearch.sql.data.model.ExprValueUtils.getFloatValue; +import static org.opensearch.sql.data.model.ExprValueUtils.getIntegerValue; +import static org.opensearch.sql.data.model.ExprValueUtils.getLongValue; +import static org.opensearch.sql.data.model.ExprValueUtils.getShortValue; +import static org.opensearch.sql.data.model.ExprValueUtils.getStringValue; +import static org.opensearch.sql.data.model.ExprValueUtils.getTupleValue; import static org.opensearch.sql.utils.ExpressionUtils.format; import java.util.HashSet; @@ -52,7 +62,7 @@ public CountAggregator.CountState create() { @Override protected CountState iterate(ExprValue value, CountState state) { - state.count(value); + state.count(value, distinct); return state; } @@ -66,25 +76,35 @@ public String toString() { */ protected static class CountState implements AggregationState { private int count; - private final Set set = new HashSet<>(); + private final Set distinctValues = new HashSet<>(); CountState() { this.count = 0; } - public void count(ExprValue value) { - set.add(value); - count++; + public void count(ExprValue value, Boolean distinct) { + if (distinct) { + if (!duplicated(value)) { + distinctValues.add(value); + count++; + } + } else { + count++; + } } - @Override - public ExprValue result() { - return ExprValueUtils.integerValue(count); + private boolean duplicated(ExprValue value) { + for (ExprValue exprValue : distinctValues) { + if (value.compareTo(exprValue) == 0) { + return true; + } + } + return false; } @Override - public Set distinctValues() { - return set; + public ExprValue result() { + return ExprValueUtils.integerValue(count); } } } diff --git a/core/src/main/java/org/opensearch/sql/expression/aggregation/MinAggregator.java b/core/src/main/java/org/opensearch/sql/expression/aggregation/MinAggregator.java index 
e9672475bca..46f69129ed8 100644 --- a/core/src/main/java/org/opensearch/sql/expression/aggregation/MinAggregator.java +++ b/core/src/main/java/org/opensearch/sql/expression/aggregation/MinAggregator.java @@ -30,7 +30,6 @@ import static org.opensearch.sql.utils.ExpressionUtils.format; import java.util.List; -import java.util.Set; import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.data.type.ExprCoreType; import org.opensearch.sql.expression.Expression; diff --git a/core/src/main/java/org/opensearch/sql/expression/aggregation/SumAggregator.java b/core/src/main/java/org/opensearch/sql/expression/aggregation/SumAggregator.java index 8de5ffb7a2d..e658d21471e 100644 --- a/core/src/main/java/org/opensearch/sql/expression/aggregation/SumAggregator.java +++ b/core/src/main/java/org/opensearch/sql/expression/aggregation/SumAggregator.java @@ -38,7 +38,6 @@ import java.util.List; import java.util.Locale; -import java.util.Set; import org.opensearch.sql.data.model.ExprNullValue; import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.data.model.ExprValueUtils; diff --git a/core/src/test/java/org/opensearch/sql/analysis/ExpressionAnalyzerTest.java b/core/src/test/java/org/opensearch/sql/analysis/ExpressionAnalyzerTest.java index 628842b4f09..06233fbc9b6 100644 --- a/core/src/test/java/org/opensearch/sql/analysis/ExpressionAnalyzerTest.java +++ b/core/src/test/java/org/opensearch/sql/analysis/ExpressionAnalyzerTest.java @@ -293,13 +293,23 @@ public void aggregation_filter() { } @Test - public void distinct_aggregation() { + public void distinct_count() { assertAnalyzeEqual( dsl.distinctCount(DSL.ref("integer_value", INTEGER)), AstDSL.distinctAggregate("count", qualifiedName("integer_value")) ); } + @Test + public void filtered_distinct_count() { + assertAnalyzeEqual( + dsl.distinctCount(DSL.ref("integer_value", INTEGER)) + .condition(dsl.greater(DSL.ref("integer_value", INTEGER), DSL.literal(1))), + AstDSL.filteredDistinctCount("count", 
qualifiedName("integer_value"), function( + ">", qualifiedName("integer_value"), intLiteral(1))) + ); + } + protected Expression analyze(UnresolvedExpression unresolvedExpression) { return expressionAnalyzer.analyze(unresolvedExpression, analysisContext); } diff --git a/core/src/test/java/org/opensearch/sql/data/model/ExprValueUtilsTest.java b/core/src/test/java/org/opensearch/sql/data/model/ExprValueUtilsTest.java index a27d90f35d0..af2dbf22fc1 100644 --- a/core/src/test/java/org/opensearch/sql/data/model/ExprValueUtilsTest.java +++ b/core/src/test/java/org/opensearch/sql/data/model/ExprValueUtilsTest.java @@ -96,8 +96,8 @@ public class ExprValueUtilsTest { Lists.newArrayList(Iterables.concat(numberValues, nonNumberValues)); private static List> numberValueExtractor = Arrays.asList( - ExprValue::byteValue, - ExprValue::shortValue, + ExprValueUtils::getByteValue, + ExprValueUtils::getShortValue, ExprValueUtils::getIntegerValue, ExprValueUtils::getLongValue, ExprValueUtils::getFloatValue, diff --git a/core/src/test/java/org/opensearch/sql/expression/aggregation/AggregationTest.java b/core/src/test/java/org/opensearch/sql/expression/aggregation/AggregationTest.java index 634a3a71920..2cce9018bf8 100644 --- a/core/src/test/java/org/opensearch/sql/expression/aggregation/AggregationTest.java +++ b/core/src/test/java/org/opensearch/sql/expression/aggregation/AggregationTest.java @@ -118,10 +118,10 @@ public class AggregationTest extends ExpressionTestBase { protected static List tuples_with_duplicates = Arrays.asList( - ExprValueUtils.tupleValue(ImmutableMap.of("integer_value", 1)), - ExprValueUtils.tupleValue(ImmutableMap.of("integer_value", 1)), - ExprValueUtils.tupleValue(ImmutableMap.of("integer_value", 2)), - ExprValueUtils.tupleValue(ImmutableMap.of("integer_value", 3))); + ExprValueUtils.tupleValue(ImmutableMap.of("integer_value", 1, "double_value", 4d)), + ExprValueUtils.tupleValue(ImmutableMap.of("integer_value", 1, "double_value", 3d)), + 
ExprValueUtils.tupleValue(ImmutableMap.of("integer_value", 2, "double_value", 2d)), + ExprValueUtils.tupleValue(ImmutableMap.of("integer_value", 3, "double_value", 1d))); protected static List tuples_with_null_and_missing = Arrays.asList( diff --git a/core/src/test/java/org/opensearch/sql/expression/aggregation/AggregatorStateTest.java b/core/src/test/java/org/opensearch/sql/expression/aggregation/AggregatorStateTest.java deleted file mode 100644 index 338a254148f..00000000000 --- a/core/src/test/java/org/opensearch/sql/expression/aggregation/AggregatorStateTest.java +++ /dev/null @@ -1,35 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - * - * Modifications Copyright OpenSearch Contributors. See - * GitHub history for details. - * - */ - -package org.opensearch.sql.expression.aggregation; - -import static org.junit.jupiter.api.Assertions.assertFalse; -import static org.junit.jupiter.api.Assertions.assertTrue; - -import org.junit.jupiter.api.Test; -import org.opensearch.sql.data.model.ExprIntegerValue; - -public class AggregatorStateTest extends AggregationTest { - - @Test - void count_distinct_values() { - CountAggregator.CountState state = new CountAggregator.CountState(); - state.count(new ExprIntegerValue(1)); - assertFalse(state.distinctValues().isEmpty()); - } - - @Test - void default_distinct_values() { - AvgAggregator.AvgState state = new AvgAggregator.AvgState(); - assertTrue(state.distinctValues().isEmpty()); - } -} diff --git a/core/src/test/java/org/opensearch/sql/expression/aggregation/CountAggregatorTest.java b/core/src/test/java/org/opensearch/sql/expression/aggregation/CountAggregatorTest.java index 26a53539ace..73bb37a3daf 100644 --- a/core/src/test/java/org/opensearch/sql/expression/aggregation/CountAggregatorTest.java +++ 
b/core/src/test/java/org/opensearch/sql/expression/aggregation/CountAggregatorTest.java @@ -136,6 +136,14 @@ public void distinct_count() { assertEquals(3, result.value()); } + @Test + public void filtered_distinct_count() { + ExprValue result = aggregation(dsl.distinctCount(DSL.ref("integer_value", INTEGER)) + .condition(dsl.greater(DSL.ref("double_value", DOUBLE), DSL.literal(1d))), + tuples_with_duplicates); + assertEquals(2, result.value()); + } + @Test public void count_with_missing() { ExprValue result = aggregation(dsl.count(DSL.ref("integer_value", INTEGER)), diff --git a/integ-test/src/test/resources/correctness/queries/aggregation.txt b/integ-test/src/test/resources/correctness/queries/aggregation.txt index e7cd34451db..4fb07e33055 100644 --- a/integ-test/src/test/resources/correctness/queries/aggregation.txt +++ b/integ-test/src/test/resources/correctness/queries/aggregation.txt @@ -6,4 +6,6 @@ SELECT MAX(AvgTicketPrice) FROM opensearch_dashboards_sample_data_flights SELECT MAX(timestamp) FROM opensearch_dashboards_sample_data_flights SELECT MIN(AvgTicketPrice) FROM opensearch_dashboards_sample_data_flights SELECT MIN(timestamp) FROM opensearch_dashboards_sample_data_flights -SELECT COUNT(DISTINCT Origin), COUNT(DISTINCT Dest) FROM opensearch_dashboards_sample_data_flights \ No newline at end of file +SELECT COUNT(DISTINCT Origin), COUNT(DISTINCT Dest) FROM opensearch_dashboards_sample_data_flights +SELECT COUNT(DISTINCT Origin) FILTER(WHERE AvgTicketPrice < 1000) +SELECT COUNT(DISTINCT Origin) FROM (SELECT * FROM opensearch_dashboards_sample_data_flights) AS flights \ No newline at end of file diff --git a/sql/src/main/antlr/OpenSearchSQLParser.g4 b/sql/src/main/antlr/OpenSearchSQLParser.g4 index 51c558a68e9..ec8ef8bb1a9 100644 --- a/sql/src/main/antlr/OpenSearchSQLParser.g4 +++ b/sql/src/main/antlr/OpenSearchSQLParser.g4 @@ -338,9 +338,9 @@ caseFuncAlternative aggregateFunction : functionName=aggregationFunctionName LR_BRACKET functionArg RR_BRACKET 
#regularAggregateFunctionCall - | functionName=aggregationFunctionName LR_BRACKET DISTINCT functionArg RR_BRACKET - #distinctAggregateFunctionCall | COUNT LR_BRACKET STAR RR_BRACKET #countStarFunctionCall + | COUNT LR_BRACKET DISTINCT (functionArg | STAR) RR_BRACKET + #distinctCountFunctionCall ; filterClause diff --git a/sql/src/main/java/org/opensearch/sql/sql/parser/AstExpressionBuilder.java b/sql/src/main/java/org/opensearch/sql/sql/parser/AstExpressionBuilder.java index d267a8df4fa..62f80eab8e6 100644 --- a/sql/src/main/java/org/opensearch/sql/sql/parser/AstExpressionBuilder.java +++ b/sql/src/main/java/org/opensearch/sql/sql/parser/AstExpressionBuilder.java @@ -43,6 +43,7 @@ import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.CountStarFunctionCallContext; import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.DataTypeFunctionCallContext; import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.DateLiteralContext; +import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.DistinctCountFunctionCallContext; import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.IsNullPredicateContext; import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.LikePredicateContext; import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.MathExpressionAtomContext; @@ -171,7 +172,7 @@ public UnresolvedExpression visitShowDescribePattern( public UnresolvedExpression visitFilteredAggregationFunctionCall( OpenSearchSQLParser.FilteredAggregationFunctionCallContext ctx) { AggregateFunction agg = (AggregateFunction) visit(ctx.aggregateFunction()); - return new AggregateFunction(agg.getFuncName(), agg.getField(), visit(ctx.filterClause())); + return agg.condition(visit(ctx.filterClause())); } @Override @@ -213,11 +214,10 @@ public UnresolvedExpression visitRegularAggregateFunctionCall( } @Override - public UnresolvedExpression visitDistinctAggregateFunctionCall( - 
OpenSearchSQLParser.DistinctAggregateFunctionCallContext ctx) { + public UnresolvedExpression visitDistinctCountFunctionCall(DistinctCountFunctionCallContext ctx) { return new AggregateFunction( - ctx.functionName.getText(), - visitFunctionArg(ctx.functionArg()), + ctx.COUNT().getText(), + ctx.functionArg() != null ? visitFunctionArg(ctx.functionArg()) : AllFields.of(), true); } diff --git a/sql/src/test/java/org/opensearch/sql/sql/parser/AstAggregationBuilderTest.java b/sql/src/test/java/org/opensearch/sql/sql/parser/AstAggregationBuilderTest.java index 8e7adaaf039..437e8953fac 100644 --- a/sql/src/test/java/org/opensearch/sql/sql/parser/AstAggregationBuilderTest.java +++ b/sql/src/test/java/org/opensearch/sql/sql/parser/AstAggregationBuilderTest.java @@ -51,6 +51,7 @@ import org.junit.jupiter.api.DisplayNameGeneration; import org.junit.jupiter.api.DisplayNameGenerator; import org.junit.jupiter.api.Test; +import org.opensearch.sql.ast.expression.AllFields; import org.opensearch.sql.ast.expression.UnresolvedExpression; import org.opensearch.sql.ast.tree.Aggregation; import org.opensearch.sql.ast.tree.UnresolvedPlan; @@ -171,14 +172,19 @@ void can_build_implicit_group_by_for_aggregator_in_having_clause() { @Test void can_build_distinct_aggregator() { assertThat( - buildAggregation("SELECT COUNT(DISTINCT name), AVG(DISTINCT balance) FROM test"), + buildAggregation("SELECT COUNT(DISTINCT name) FROM test group by age"), allOf( - hasGroupByItems(), + hasGroupByItems(alias("age", qualifiedName("age"))), hasAggregators( alias("COUNT(DISTINCT name)", distinctAggregate("COUNT", qualifiedName( - "name"))), - alias("AVG(DISTINCT balance)", distinctAggregate("AVG", qualifiedName( - "balance")))))); + "name")))))); + + assertThat( + buildAggregation("SELECT COUNT(DISTINCT *) FROM test"), + allOf( + hasGroupByItems(), + hasAggregators( + alias("COUNT(DISTINCT *)", distinctAggregate("COUNT", AllFields.of()))))); } @Test diff --git 
a/sql/src/test/java/org/opensearch/sql/sql/parser/AstExpressionBuilderTest.java b/sql/src/test/java/org/opensearch/sql/sql/parser/AstExpressionBuilderTest.java index a3c8494e7a7..8ddbe7feab6 100644 --- a/sql/src/test/java/org/opensearch/sql/sql/parser/AstExpressionBuilderTest.java +++ b/sql/src/test/java/org/opensearch/sql/sql/parser/AstExpressionBuilderTest.java @@ -57,6 +57,7 @@ import org.junit.jupiter.api.Test; import org.opensearch.sql.ast.Node; import org.opensearch.sql.ast.dsl.AstDSL; +import org.opensearch.sql.ast.expression.AllFields; import org.opensearch.sql.ast.expression.DataType; import org.opensearch.sql.ast.tree.Sort.SortOption; import org.opensearch.sql.common.antlr.CaseInsensitiveCharStream; @@ -410,6 +411,28 @@ public void filteredAggregation() { ); } + @Test + public void distinctCount() { + assertEquals( + AstDSL.distinctAggregate("count", qualifiedName("name")), + buildExprAst("count(distinct name)") + ); + + assertEquals( + AstDSL.distinctAggregate("count", AllFields.of()), + buildExprAst("count(distinct *)") + ); + } + + @Test + public void filteredDistinctCount() { + assertEquals( + AstDSL.filteredDistinctCount("count", qualifiedName("name"), function( + ">", qualifiedName("age"), intLiteral(30))), + buildExprAst("count(distinct name) filter(where age > 30)") + ); + } + private Node buildExprAst(String expr) { OpenSearchSQLLexer lexer = new OpenSearchSQLLexer(new CaseInsensitiveCharStream(expr)); OpenSearchSQLParser parser = new OpenSearchSQLParser(new CommonTokenStream(lexer)); From 078eae75e6647e3be5d7e5589d8099b954f599d1 Mon Sep 17 00:00:00 2001 From: Chloe Zhang Date: Thu, 10 Jun 2021 22:09:09 -0700 Subject: [PATCH 08/23] Update doc; removed distinct star Signed-off-by: chloe-zh --- .../aggregation/AvgAggregatorTest.java | 7 ------- .../aggregation/MaxAggregatorTest.java | 7 ------- .../aggregation/MinAggregatorTest.java | 7 ------- .../aggregation/SumAggregatorTest.java | 7 ------- docs/user/dql/aggregations.rst | 19
++++++++++++++++--- sql/src/main/antlr/OpenSearchSQLParser.g4 | 3 +-- .../sql/sql/parser/AstExpressionBuilder.java | 2 +- .../sql/parser/AstExpressionBuilderTest.java | 6 +----- 8 files changed, 19 insertions(+), 39 deletions(-) diff --git a/core/src/test/java/org/opensearch/sql/expression/aggregation/AvgAggregatorTest.java b/core/src/test/java/org/opensearch/sql/expression/aggregation/AvgAggregatorTest.java index 33ea4c91233..494d3cfab2e 100644 --- a/core/src/test/java/org/opensearch/sql/expression/aggregation/AvgAggregatorTest.java +++ b/core/src/test/java/org/opensearch/sql/expression/aggregation/AvgAggregatorTest.java @@ -61,13 +61,6 @@ public void filtered_avg() { assertEquals(3.0, result.value()); } - @Test - public void distinct_avg() { - assertThrows(ExpressionEvaluationException.class, - () -> dsl.distinctAvg(DSL.ref("integer_value", INTEGER)).valueOf(valueEnv()), - "unsupported distinct aggregator avg"); - } - @Test public void avg_with_missing() { ExprValue result = diff --git a/core/src/test/java/org/opensearch/sql/expression/aggregation/MaxAggregatorTest.java b/core/src/test/java/org/opensearch/sql/expression/aggregation/MaxAggregatorTest.java index 20cde543141..5aa9d3a7473 100644 --- a/core/src/test/java/org/opensearch/sql/expression/aggregation/MaxAggregatorTest.java +++ b/core/src/test/java/org/opensearch/sql/expression/aggregation/MaxAggregatorTest.java @@ -116,13 +116,6 @@ public void filtered_max() { assertEquals(3, result.value()); } - @Test - public void distinct_max() { - assertThrows(ExpressionEvaluationException.class, - () -> dsl.distinctMax(DSL.ref("integer_value", INTEGER)).valueOf(valueEnv()), - "unsupported distinct aggregator max"); - } - @Test public void test_max_null() { ExprValue result = diff --git a/core/src/test/java/org/opensearch/sql/expression/aggregation/MinAggregatorTest.java b/core/src/test/java/org/opensearch/sql/expression/aggregation/MinAggregatorTest.java index e2927772621..01e72b9cdac 100644 --- 
a/core/src/test/java/org/opensearch/sql/expression/aggregation/MinAggregatorTest.java +++ b/core/src/test/java/org/opensearch/sql/expression/aggregation/MinAggregatorTest.java @@ -116,13 +116,6 @@ public void filtered_min() { assertEquals(2, result.value()); } - @Test - public void distinct_min() { - assertThrows(ExpressionEvaluationException.class, - () -> dsl.distinctMin(DSL.ref("integer_value", INTEGER)).valueOf(valueEnv()), - "unsupported distinct aggregator min"); - } - @Test public void test_min_null() { ExprValue result = diff --git a/core/src/test/java/org/opensearch/sql/expression/aggregation/SumAggregatorTest.java b/core/src/test/java/org/opensearch/sql/expression/aggregation/SumAggregatorTest.java index fdd24fb5b16..c0872ed4345 100644 --- a/core/src/test/java/org/opensearch/sql/expression/aggregation/SumAggregatorTest.java +++ b/core/src/test/java/org/opensearch/sql/expression/aggregation/SumAggregatorTest.java @@ -100,13 +100,6 @@ public void filtered_sum() { assertEquals(9, result.value()); } - @Test - public void distinct_sum() { - assertThrows(ExpressionEvaluationException.class, - () -> dsl.distinctSum(DSL.ref("integer_value", INTEGER)).valueOf(valueEnv()), - "unsupported distinct aggregator sum"); - } - @Test public void sum_with_missing() { ExprValue result = diff --git a/docs/user/dql/aggregations.rst b/docs/user/dql/aggregations.rst index 3c8577fcde7..e332da7c144 100644 --- a/docs/user/dql/aggregations.rst +++ b/docs/user/dql/aggregations.rst @@ -135,10 +135,10 @@ Besides regular identifiers, ``COUNT`` aggregate function also accepts arguments 2. ``COUNT(*)`` will count the number of all its input rows. 3. ``COUNT(1)`` is same as ``COUNT(*)`` because any non-null literal will count. -DISTINCT Aggregation --------------------- +DISTINCT COUNT Aggregation +-------------------------- -To get the aggregation of distinct values of a field, you can add a keyword ``DISTINCT`` before the field in the aggregation function. 
Currently the distinct aggregation is only supported in ``COUNT`` aggregation. Example:: +To get the count of distinct values of a field, you can add a keyword ``DISTINCT`` before the field in the count aggregation. Example:: os> SELECT COUNT(DISTINCT gender), COUNT(gender) FROM accounts; fetched rows / total rows = 1/1 @@ -247,3 +247,16 @@ The ``FILTER`` clause can be used in aggregation functions without GROUP BY as w | 4 | 1 | +--------------+------------+ +Distinct count aggregate with FILTER +------------------------------------ + +The ``FILTER`` clause can also be used with distinct count to filter rows before counting the distinct values of a specific field. For example:: + + os> SELECT COUNT(DISTINCT firstname) FILTER(WHERE age > 30) AS distinct_count FROM accounts + fetched rows / total rows = 1/1 + +------------------+ + | distinct_count | + |------------------| + | 3 | + +------------------+ + diff --git a/sql/src/main/antlr/OpenSearchSQLParser.g4 b/sql/src/main/antlr/OpenSearchSQLParser.g4 index ec8ef8bb1a9..05c1dffe9c9 100644 --- a/sql/src/main/antlr/OpenSearchSQLParser.g4 +++ b/sql/src/main/antlr/OpenSearchSQLParser.g4 @@ -339,8 +339,7 @@ aggregateFunction : functionName=aggregationFunctionName LR_BRACKET functionArg RR_BRACKET #regularAggregateFunctionCall | COUNT LR_BRACKET STAR RR_BRACKET #countStarFunctionCall - | COUNT LR_BRACKET DISTINCT (functionArg | STAR) RR_BRACKET - #distinctCountFunctionCall + | COUNT LR_BRACKET DISTINCT functionArg RR_BRACKET #distinctCountFunctionCall ; filterClause diff --git a/sql/src/main/java/org/opensearch/sql/sql/parser/AstExpressionBuilder.java b/sql/src/main/java/org/opensearch/sql/sql/parser/AstExpressionBuilder.java index 62f80eab8e6..8dda63b7505 100644 --- a/sql/src/main/java/org/opensearch/sql/sql/parser/AstExpressionBuilder.java +++ b/sql/src/main/java/org/opensearch/sql/sql/parser/AstExpressionBuilder.java @@ -217,7 +217,7 @@ public UnresolvedExpression visitRegularAggregateFunctionCall( public
UnresolvedExpression visitDistinctCountFunctionCall(DistinctCountFunctionCallContext ctx) { return new AggregateFunction( ctx.COUNT().getText(), - ctx.functionArg() != null ? visitFunctionArg(ctx.functionArg()) : AllFields.of(), + visitFunctionArg(ctx.functionArg()), true); } diff --git a/sql/src/test/java/org/opensearch/sql/sql/parser/AstExpressionBuilderTest.java b/sql/src/test/java/org/opensearch/sql/sql/parser/AstExpressionBuilderTest.java index 8ddbe7feab6..c7c0c9f6fd2 100644 --- a/sql/src/test/java/org/opensearch/sql/sql/parser/AstExpressionBuilderTest.java +++ b/sql/src/test/java/org/opensearch/sql/sql/parser/AstExpressionBuilderTest.java @@ -28,6 +28,7 @@ package org.opensearch.sql.sql.parser; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; import static org.opensearch.sql.ast.dsl.AstDSL.aggregate; import static org.opensearch.sql.ast.dsl.AstDSL.and; import static org.opensearch.sql.ast.dsl.AstDSL.booleanLiteral; @@ -417,11 +418,6 @@ public void distinctCount() { AstDSL.distinctAggregate("count", qualifiedName("name")), buildExprAst("count(distinct name)") ); - - assertEquals( - AstDSL.distinctAggregate("count", AllFields.of()), - buildExprAst("count(distinct *)") - ); } @Test From f10b28253b14e2f41474237764cfd752406213df Mon Sep 17 00:00:00 2001 From: Chloe Zhang Date: Thu, 10 Jun 2021 22:22:31 -0700 Subject: [PATCH 09/23] Removed unnecessary methods Signed-off-by: chloe-zh --- .../sql/ast/expression/AggregateFunction.java | 2 +- .../org/opensearch/sql/expression/DSL.java | 16 ---------------- .../dsl/MetricAggregationBuilder.java | 18 +++++++++--------- .../dsl/MetricAggregationBuilderTest.java | 3 +-- 4 files changed, 11 insertions(+), 28 deletions(-) diff --git a/core/src/main/java/org/opensearch/sql/ast/expression/AggregateFunction.java b/core/src/main/java/org/opensearch/sql/ast/expression/AggregateFunction.java index 96bd33f1c92..e909c46ee7a 100644 --- 
a/core/src/main/java/org/opensearch/sql/ast/expression/AggregateFunction.java +++ b/core/src/main/java/org/opensearch/sql/ast/expression/AggregateFunction.java @@ -68,7 +68,7 @@ public AggregateFunction(String funcName, UnresolvedExpression field) { * Constructor. * @param funcName function name. * @param field {@link UnresolvedExpression}. - * @param distinct field is distinct. + * @param distinct whether distinct field is specified or not. */ public AggregateFunction(String funcName, UnresolvedExpression field, Boolean distinct) { this.funcName = funcName; diff --git a/core/src/main/java/org/opensearch/sql/expression/DSL.java b/core/src/main/java/org/opensearch/sql/expression/DSL.java index 93f86ca1f82..50b10d55dd5 100644 --- a/core/src/main/java/org/opensearch/sql/expression/DSL.java +++ b/core/src/main/java/org/opensearch/sql/expression/DSL.java @@ -492,18 +492,10 @@ public Aggregator avg(Expression... expressions) { return aggregate(BuiltinFunctionName.AVG, expressions); } - public Aggregator distinctAvg(Expression... expressions) { - return avg(expressions).distinct(true); - } - public Aggregator sum(Expression... expressions) { return aggregate(BuiltinFunctionName.SUM, expressions); } - public Aggregator distinctSum(Expression... expressions) { - return sum(expressions).distinct(true); - } - public Aggregator count(Expression... expressions) { return aggregate(BuiltinFunctionName.COUNT, expressions); } @@ -531,18 +523,10 @@ public Aggregator min(Expression... expressions) { return aggregate(BuiltinFunctionName.MIN, expressions); } - public Aggregator distinctMin(Expression... expressions) { - return min(expressions).distinct(true); - } - public Aggregator max(Expression... expressions) { return aggregate(BuiltinFunctionName.MAX, expressions); } - public Aggregator distinctMax(Expression... expressions) { - return max(expressions).distinct(true); - } - private FunctionExpression function(BuiltinFunctionName functionName, Expression... 
expressions) { return (FunctionExpression) repository.compile( functionName.getName(), Arrays.asList(expressions)); diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java index 84127d9a880..84f2b016343 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java @@ -32,6 +32,7 @@ import java.util.ArrayList; import java.util.List; +import java.util.Locale; import org.apache.commons.lang3.tuple.Pair; import org.opensearch.search.aggregations.AggregationBuilder; import org.opensearch.search.aggregations.AggregationBuilders; @@ -97,15 +98,14 @@ public Pair visitNamedAggregator( String name = node.getName(); if (distinct) { - switch (node.getFunctionName().getFunctionName()) { - case "count": - return make( - AggregationBuilders.cardinality(name), - expression, - new SingleValueParser(name)); - default: - throw new ExpressionEvaluationException(String.format( - "unsupported distinct aggregator %s", node.getFunctionName().getFunctionName())); + if ("count".equals(node.getFunctionName().getFunctionName().toLowerCase(Locale.ROOT))) { + return make( + AggregationBuilders.cardinality(name), + expression, + new SingleValueParser(name)); + } else { + throw new IllegalStateException(String.format( + "unsupported distinct aggregator %s", node.getFunctionName().getFunctionName())); } } diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilderTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilderTest.java index e8f7fb79edb..e62a6c37c78 100644 --- 
a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilderTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilderTest.java @@ -49,7 +49,6 @@ import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; -import org.opensearch.sql.exception.ExpressionEvaluationException; import org.opensearch.sql.expression.aggregation.AvgAggregator; import org.opensearch.sql.expression.aggregation.CountAggregator; import org.opensearch.sql.expression.aggregation.MaxAggregator; @@ -205,7 +204,7 @@ void should_build_cardinality_aggregation() { @Test void should_throw_exception_for_unsupported_distinct_aggregator() { - assertThrows(ExpressionEvaluationException.class, + assertThrows(IllegalStateException.class, () -> buildQuery(Collections.singletonList(named("avg(distinct age)", new AvgAggregator( Collections.singletonList(ref("name", STRING)), STRING).distinct(true)))), "unsupported distinct aggregator avg"); From df81cfa5b8511f5b34e33df978334dc05cede808 Mon Sep 17 00:00:00 2001 From: Chloe Zhang Date: Thu, 10 Jun 2021 23:43:03 -0700 Subject: [PATCH 10/23] update Signed-off-by: chloe-zh --- .../expression/aggregation/NamedAggregator.java | 4 ++++ .../dsl/MetricAggregationBuilder.java | 14 ++++++++++++-- ppl/src/main/antlr/OpenSearchPPLParser.g4 | 2 +- .../sql/ppl/parser/AstExpressionBuilder.java | 3 +-- .../sql/ppl/parser/AstExpressionBuilderTest.java | 16 ---------------- 5 files changed, 18 insertions(+), 21 deletions(-) diff --git a/core/src/main/java/org/opensearch/sql/expression/aggregation/NamedAggregator.java b/core/src/main/java/org/opensearch/sql/expression/aggregation/NamedAggregator.java index a1bf2b99613..02e9c1e8296 100644 --- a/core/src/main/java/org/opensearch/sql/expression/aggregation/NamedAggregator.java +++ 
b/core/src/main/java/org/opensearch/sql/expression/aggregation/NamedAggregator.java @@ -54,6 +54,8 @@ public class NamedAggregator extends Aggregator { /** * NamedAggregator. + * The aggregator properties {@link #condition} and {@link #distinct} + * are inherited by named aggregator to avoid errors introduced by the property inconsistency. * * @param name name * @param delegated delegated @@ -64,6 +66,8 @@ public NamedAggregator( super(delegated.getFunctionName(), delegated.getArguments(), delegated.returnType); this.name = name; this.delegated = delegated; + this.distinct = delegated.distinct; + this.condition = delegated.condition != null ? delegated.condition : null; } @Override diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java index 84f2b016343..9a4b5138ae4 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java @@ -40,7 +40,6 @@ import org.opensearch.search.aggregations.bucket.filter.FilterAggregationBuilder; import org.opensearch.search.aggregations.metrics.CardinalityAggregationBuilder; import org.opensearch.search.aggregations.support.ValuesSourceAggregationBuilder; -import org.opensearch.sql.exception.ExpressionEvaluationException; import org.opensearch.sql.expression.Expression; import org.opensearch.sql.expression.ExpressionNodeVisitor; import org.opensearch.sql.expression.LiteralExpression; @@ -102,6 +101,8 @@ public Pair visitNamedAggregator( return make( AggregationBuilders.cardinality(name), expression, + condition, + name, new SingleValueParser(name)); } else { throw new IllegalStateException(String.format( @@ -172,8 +173,17 @@ private Pair make( */ private Pair 
make(CardinalityAggregationBuilder builder, Expression expression, + Expression condition, + String name, MetricParser parser) { - return Pair.of(cardinalityAggHelper.build(expression, builder::field, builder::script), parser); + CardinalityAggregationBuilder aggregationBuilder = + cardinalityAggHelper.build(expression, builder::field, builder::script); + if (condition != null) { + return Pair.of( + makeFilterAggregation(aggregationBuilder, condition, name), + FilterParser.builder().name(name).metricsParser(parser).build()); + } + return Pair.of(aggregationBuilder, parser); } /** diff --git a/ppl/src/main/antlr/OpenSearchPPLParser.g4 b/ppl/src/main/antlr/OpenSearchPPLParser.g4 index e8b54dab4da..6581e7cdbbc 100644 --- a/ppl/src/main/antlr/OpenSearchPPLParser.g4 +++ b/ppl/src/main/antlr/OpenSearchPPLParser.g4 @@ -135,7 +135,7 @@ statsAggTerm statsFunction : statsFunctionName LT_PRTHS valueExpression RT_PRTHS #statsFunctionCall | COUNT LT_PRTHS RT_PRTHS #countAllFunctionCall - | (DISTINCT_COUNT | DC) LT_PRTHS valueExpression? RT_PRTHS #distinctCountFunctionCall + | (DISTINCT_COUNT | DC) LT_PRTHS valueExpression RT_PRTHS #distinctCountFunctionCall | percentileAggFunction #percentileAggFunctionCall ; diff --git a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java index ef314072760..7da4f90cf0c 100644 --- a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java +++ b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java @@ -206,8 +206,7 @@ public UnresolvedExpression visitCountAllFunctionCall(CountAllFunctionCallContex @Override public UnresolvedExpression visitDistinctCountFunctionCall(DistinctCountFunctionCallContext ctx) { - return new AggregateFunction("count", - ctx.valueExpression() != null ? 
visit(ctx.valueExpression()) : AllFields.of(), true); + return new AggregateFunction("count", visit(ctx.valueExpression()), true); } @Override diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstExpressionBuilderTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstExpressionBuilderTest.java index 6bbfda7aef0..b1e25420aa2 100644 --- a/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstExpressionBuilderTest.java +++ b/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstExpressionBuilderTest.java @@ -388,22 +388,6 @@ public void testDistinctCount() { emptyList(), emptyList(), defaultStatsArgs())); - - assertEqual("source=t | stats dc() by b", - agg( - relation("t"), - exprList( - alias( - "dc()", - distinctAggregate("count", AllFields.of()) - ) - ), - emptyList(), - exprList( - alias("b", field("b")) - ), - defaultStatsArgs() - )); } @Test From 8632f80f349260b09273c250f9b82bb1a1afd02f Mon Sep 17 00:00:00 2001 From: Peng Huo Date: Fri, 11 Jun 2021 14:35:55 -0700 Subject: [PATCH 11/23] Impl stddev and variance function in SQL and PPL (#115) * impl variance frontend and backend * Support construct AggregationResponseParser during Aggregator build stage * add var and varp for PPL Signed-off-by: penghuo * add UT Signed-off-by: penghuo * fix UT Signed-off-by: penghuo * fix doc format Signed-off-by: penghuo * fix doc format Signed-off-by: penghuo * fix the doc Signed-off-by: penghuo * add stddev_samp and stddev_pop Signed-off-by: penghuo * fix UT coverage * address comments Signed-off-by: penghuo --- core/build.gradle | 1 + .../sql/analysis/ExpressionAnalyzer.java | 3 +- .../org/opensearch/sql/expression/DSL.java | 16 ++ .../aggregation/AggregatorFunction.java | 52 ++++ .../aggregation/StdDevAggregator.java | 110 +++++++++ .../aggregation/VarianceAggregator.java | 109 +++++++++ .../function/BuiltinFunctionName.java | 30 +++ .../sql/analysis/ExpressionAnalyzerTest.java | 8 + .../aggregation/StdDevAggregatorTest.java | 182 ++++++++++++++ 
.../aggregation/VarianceAggregatorTest.java | 190 +++++++++++++++ docs/user/dql/aggregations.rst | 222 ++++++++++++++++++ docs/user/dql/window.rst | 86 ++++++- docs/user/ppl/cmd/stats.rst | 168 +++++++++++++ .../correctness/queries/aggregation.txt | 6 +- .../resources/correctness/queries/window.txt | 12 + .../sql/opensearch/response/agg/Utils.java | 2 +- .../dsl/MetricAggregationBuilder.java | 30 +++ .../dsl/MetricAggregationBuilderTest.java | 73 ++++++ ppl/src/main/antlr/OpenSearchPPLLexer.g4 | 6 +- ppl/src/main/antlr/OpenSearchPPLParser.g4 | 2 +- .../ppl/parser/AstExpressionBuilderTest.java | 84 +++++++ sql/src/main/antlr/OpenSearchSQLLexer.g4 | 7 + sql/src/main/antlr/OpenSearchSQLParser.g4 | 2 +- .../sql/parser/AstExpressionBuilderTest.java | 21 ++ 24 files changed, 1414 insertions(+), 8 deletions(-) create mode 100644 core/src/main/java/org/opensearch/sql/expression/aggregation/StdDevAggregator.java create mode 100644 core/src/main/java/org/opensearch/sql/expression/aggregation/VarianceAggregator.java create mode 100644 core/src/test/java/org/opensearch/sql/expression/aggregation/StdDevAggregatorTest.java create mode 100644 core/src/test/java/org/opensearch/sql/expression/aggregation/VarianceAggregatorTest.java diff --git a/core/build.gradle b/core/build.gradle index 69acf5cef35..1c6c0c04817 100644 --- a/core/build.gradle +++ b/core/build.gradle @@ -51,6 +51,7 @@ dependencies { compile group: 'org.springframework', name: 'spring-beans', version: '5.2.5.RELEASE' compile group: 'org.apache.commons', name: 'commons-lang3', version: '3.10' compile group: 'com.facebook.presto', name: 'presto-matching', version: '0.240' + compile group: 'org.apache.commons', name: 'commons-math3', version: '3.6.1' compile project(':common') testImplementation('org.junit.jupiter:junit-jupiter:5.6.2') diff --git a/core/src/main/java/org/opensearch/sql/analysis/ExpressionAnalyzer.java b/core/src/main/java/org/opensearch/sql/analysis/ExpressionAnalyzer.java index 0f207c03741..d5c1538b77f 
100644 --- a/core/src/main/java/org/opensearch/sql/analysis/ExpressionAnalyzer.java +++ b/core/src/main/java/org/opensearch/sql/analysis/ExpressionAnalyzer.java @@ -155,7 +155,8 @@ public Expression visitNot(Not node, AnalysisContext context) { @Override public Expression visitAggregateFunction(AggregateFunction node, AnalysisContext context) { - Optional builtinFunctionName = BuiltinFunctionName.of(node.getFuncName()); + Optional builtinFunctionName = + BuiltinFunctionName.ofAggregation(node.getFuncName()); if (builtinFunctionName.isPresent()) { Expression arg = node.getField().accept(this, context); Aggregator aggregator = (Aggregator) repository.compile( diff --git a/core/src/main/java/org/opensearch/sql/expression/DSL.java b/core/src/main/java/org/opensearch/sql/expression/DSL.java index 31050afc871..560414592cd 100644 --- a/core/src/main/java/org/opensearch/sql/expression/DSL.java +++ b/core/src/main/java/org/opensearch/sql/expression/DSL.java @@ -500,6 +500,22 @@ public Aggregator count(Expression... expressions) { return aggregate(BuiltinFunctionName.COUNT, expressions); } + public Aggregator varSamp(Expression... expressions) { + return aggregate(BuiltinFunctionName.VARSAMP, expressions); + } + + public Aggregator varPop(Expression... expressions) { + return aggregate(BuiltinFunctionName.VARPOP, expressions); + } + + public Aggregator stddevSamp(Expression... expressions) { + return aggregate(BuiltinFunctionName.STDDEV_SAMP, expressions); + } + + public Aggregator stddevPop(Expression... 
expressions) { + return aggregate(BuiltinFunctionName.STDDEV_POP, expressions); + } + public RankingWindowFunction rowNumber() { return (RankingWindowFunction) repository.compile( BuiltinFunctionName.ROW_NUMBER.getName(), Collections.emptyList()); diff --git a/core/src/main/java/org/opensearch/sql/expression/aggregation/AggregatorFunction.java b/core/src/main/java/org/opensearch/sql/expression/aggregation/AggregatorFunction.java index a6be7378f72..640ae8a9343 100644 --- a/core/src/main/java/org/opensearch/sql/expression/aggregation/AggregatorFunction.java +++ b/core/src/main/java/org/opensearch/sql/expression/aggregation/AggregatorFunction.java @@ -35,6 +35,10 @@ import static org.opensearch.sql.data.type.ExprCoreType.STRING; import static org.opensearch.sql.data.type.ExprCoreType.TIME; import static org.opensearch.sql.data.type.ExprCoreType.TIMESTAMP; +import static org.opensearch.sql.expression.aggregation.StdDevAggregator.stddevPopulation; +import static org.opensearch.sql.expression.aggregation.StdDevAggregator.stddevSample; +import static org.opensearch.sql.expression.aggregation.VarianceAggregator.variancePopulation; +import static org.opensearch.sql.expression.aggregation.VarianceAggregator.varianceSample; import com.google.common.collect.ImmutableMap; import java.util.Collections; @@ -68,6 +72,10 @@ public static void register(BuiltinFunctionRepository repository) { repository.register(count()); repository.register(min()); repository.register(max()); + repository.register(varSamp()); + repository.register(varPop()); + repository.register(stddevSamp()); + repository.register(stddevPop()); } private static FunctionResolver avg() { @@ -159,4 +167,48 @@ private static FunctionResolver max() { .build() ); } + + private static FunctionResolver varSamp() { + FunctionName functionName = BuiltinFunctionName.VARSAMP.getName(); + return new FunctionResolver( + functionName, + new ImmutableMap.Builder() + .put(new FunctionSignature(functionName, 
Collections.singletonList(DOUBLE)), + arguments -> varianceSample(arguments, DOUBLE)) + .build() + ); + } + + private static FunctionResolver varPop() { + FunctionName functionName = BuiltinFunctionName.VARPOP.getName(); + return new FunctionResolver( + functionName, + new ImmutableMap.Builder() + .put(new FunctionSignature(functionName, Collections.singletonList(DOUBLE)), + arguments -> variancePopulation(arguments, DOUBLE)) + .build() + ); + } + + private static FunctionResolver stddevSamp() { + FunctionName functionName = BuiltinFunctionName.STDDEV_SAMP.getName(); + return new FunctionResolver( + functionName, + new ImmutableMap.Builder() + .put(new FunctionSignature(functionName, Collections.singletonList(DOUBLE)), + arguments -> stddevSample(arguments, DOUBLE)) + .build() + ); + } + + private static FunctionResolver stddevPop() { + FunctionName functionName = BuiltinFunctionName.STDDEV_POP.getName(); + return new FunctionResolver( + functionName, + new ImmutableMap.Builder() + .put(new FunctionSignature(functionName, Collections.singletonList(DOUBLE)), + arguments -> stddevPopulation(arguments, DOUBLE)) + .build() + ); + } } diff --git a/core/src/main/java/org/opensearch/sql/expression/aggregation/StdDevAggregator.java b/core/src/main/java/org/opensearch/sql/expression/aggregation/StdDevAggregator.java new file mode 100644 index 00000000000..0cd84944492 --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/expression/aggregation/StdDevAggregator.java @@ -0,0 +1,110 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. 
See the License for the specific language governing + * permissions and limitations under the License. + */ + +package org.opensearch.sql.expression.aggregation; + +import static org.opensearch.sql.data.model.ExprValueUtils.doubleValue; +import static org.opensearch.sql.utils.ExpressionUtils.format; + +import java.util.ArrayList; +import java.util.List; +import org.apache.commons.math3.stat.descriptive.moment.StandardDeviation; +import org.opensearch.sql.common.utils.StringUtils; +import org.opensearch.sql.data.model.ExprNullValue; +import org.opensearch.sql.data.model.ExprValue; +import org.opensearch.sql.data.type.ExprCoreType; +import org.opensearch.sql.expression.Expression; +import org.opensearch.sql.expression.function.BuiltinFunctionName; + +/** + * StandardDeviation Aggregator. + */ +public class StdDevAggregator extends Aggregator { + + private final boolean isSampleStdDev; + + /** + * Build Population Standard Deviation {@link StdDevAggregator}. + */ + public static Aggregator stddevPopulation(List arguments, + ExprCoreType returnType) { + return new StdDevAggregator(false, arguments, returnType); + } + + /** + * Build Sample Standard Deviation {@link StdDevAggregator}. + */ + public static Aggregator stddevSample(List arguments, + ExprCoreType returnType) { + return new StdDevAggregator(true, arguments, returnType); + } + + /** + * StdDevAggregator constructor. + * + * @param isSampleStdDev true for sample standard deviation aggregator, false for population + * standard deviation aggregator. + * @param arguments aggregator arguments. + * @param returnType aggregator return types. + */ + public StdDevAggregator( + Boolean isSampleStdDev, List arguments, ExprCoreType returnType) { + super( + isSampleStdDev + ?
BuiltinFunctionName.STDDEV_SAMP.getName() + : BuiltinFunctionName.STDDEV_POP.getName(), + arguments, + returnType); + this.isSampleStdDev = isSampleStdDev; + } + + @Override + public StdDevAggregator.StdDevState create() { + return new StdDevAggregator.StdDevState(isSampleStdDev); + } + + @Override + protected StdDevAggregator.StdDevState iterate(ExprValue value, + StdDevAggregator.StdDevState state) { + state.evaluate(value); + return state; + } + + @Override + public String toString() { + return StringUtils.format( + "%s(%s)", isSampleStdDev ? "stddev_samp" : "stddev_pop", format(getArguments())); + } + + protected static class StdDevState implements AggregationState { + + private final StandardDeviation standardDeviation; + + private final List values = new ArrayList<>(); + + public StdDevState(boolean isSampleStdDev) { + this.standardDeviation = new StandardDeviation(isSampleStdDev); + } + + public void evaluate(ExprValue value) { + values.add(value.doubleValue()); + } + + @Override + public ExprValue result() { + return values.size() == 0 + ? ExprNullValue.of() + : doubleValue(standardDeviation.evaluate(values.stream().mapToDouble(d -> d).toArray())); + } + } +} diff --git a/core/src/main/java/org/opensearch/sql/expression/aggregation/VarianceAggregator.java b/core/src/main/java/org/opensearch/sql/expression/aggregation/VarianceAggregator.java new file mode 100644 index 00000000000..bd9f0948f61 --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/expression/aggregation/VarianceAggregator.java @@ -0,0 +1,109 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. 
See the License for the specific language governing + * permissions and limitations under the License. + */ + +package org.opensearch.sql.expression.aggregation; + +import static org.opensearch.sql.data.model.ExprValueUtils.doubleValue; +import static org.opensearch.sql.utils.ExpressionUtils.format; + +import java.util.ArrayList; +import java.util.List; +import org.apache.commons.math3.stat.descriptive.moment.Variance; +import org.opensearch.sql.common.utils.StringUtils; +import org.opensearch.sql.data.model.ExprNullValue; +import org.opensearch.sql.data.model.ExprValue; +import org.opensearch.sql.data.type.ExprCoreType; +import org.opensearch.sql.expression.Expression; +import org.opensearch.sql.expression.function.BuiltinFunctionName; + +/** + * Variance Aggregator. + */ +public class VarianceAggregator extends Aggregator { + + private final boolean isSampleVariance; + + /** + * Build Population Variance {@link VarianceAggregator}. + */ + public static Aggregator variancePopulation(List arguments, + ExprCoreType returnType) { + return new VarianceAggregator(false, arguments, returnType); + } + + /** + * Build Sample Variance {@link VarianceAggregator}. + */ + public static Aggregator varianceSample(List arguments, + ExprCoreType returnType) { + return new VarianceAggregator(true, arguments, returnType); + } + + /** + * VarianceAggregator constructor. + * + * @param isSampleVariance true for sample variance aggregator, false for population variance + * aggregator. + * @param arguments aggregator arguments. + * @param returnType aggregator return types. + */ + public VarianceAggregator( + Boolean isSampleVariance, List arguments, ExprCoreType returnType) { + super( + isSampleVariance + ? 
BuiltinFunctionName.VARSAMP.getName() + : BuiltinFunctionName.VARPOP.getName(), + arguments, + returnType); + this.isSampleVariance = isSampleVariance; + } + + @Override + public VarianceState create() { + return new VarianceState(isSampleVariance); + } + + @Override + protected VarianceState iterate(ExprValue value, VarianceState state) { + state.evaluate(value); + return state; + } + + @Override + public String toString() { + return StringUtils.format( + "%s(%s)", isSampleVariance ? "var_samp" : "var_pop", format(getArguments())); + } + + protected static class VarianceState implements AggregationState { + + private final Variance variance; + + private final List values = new ArrayList<>(); + + public VarianceState(boolean isSampleVariance) { + this.variance = new Variance(isSampleVariance); + } + + public void evaluate(ExprValue value) { + values.add(value.doubleValue()); + } + + @Override + public ExprValue result() { + return values.size() == 0 + ? ExprNullValue.of() + : doubleValue(variance.evaluate(values.stream().mapToDouble(d -> d).toArray())); + } + } +} diff --git a/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java b/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java index 0210161abed..24e65d4b5d5 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java @@ -12,6 +12,7 @@ package org.opensearch.sql.expression.function; import com.google.common.collect.ImmutableMap; +import java.util.Locale; import java.util.Map; import java.util.Optional; import lombok.Getter; @@ -126,6 +127,14 @@ public enum BuiltinFunctionName { COUNT(FunctionName.of("count")), MIN(FunctionName.of("min")), MAX(FunctionName.of("max")), + // sample variance + VARSAMP(FunctionName.of("var_samp")), + // population standard variance + VARPOP(FunctionName.of("var_pop")), + // sample standard deviation. 
+ STDDEV_SAMP(FunctionName.of("stddev_samp")), + // population standard deviation. + STDDEV_POP(FunctionName.of("stddev_pop")), /** * Text Functions. @@ -189,7 +198,28 @@ public enum BuiltinFunctionName { ALL_NATIVE_FUNCTIONS = builder.build(); } + private static final Map AGGREGATION_FUNC_MAPPING = + new ImmutableMap.Builder() + .put("max", BuiltinFunctionName.MAX) + .put("min", BuiltinFunctionName.MIN) + .put("avg", BuiltinFunctionName.AVG) + .put("count", BuiltinFunctionName.COUNT) + .put("sum", BuiltinFunctionName.SUM) + .put("var_pop", BuiltinFunctionName.VARPOP) + .put("var_samp", BuiltinFunctionName.VARSAMP) + .put("variance", BuiltinFunctionName.VARPOP) + .put("std", BuiltinFunctionName.STDDEV_POP) + .put("stddev", BuiltinFunctionName.STDDEV_POP) + .put("stddev_pop", BuiltinFunctionName.STDDEV_POP) + .put("stddev_samp", BuiltinFunctionName.STDDEV_SAMP) + .build(); + public static Optional of(String str) { return Optional.ofNullable(ALL_NATIVE_FUNCTIONS.getOrDefault(FunctionName.of(str), null)); } + + public static Optional ofAggregation(String functionName) { + return Optional.ofNullable( + AGGREGATION_FUNC_MAPPING.getOrDefault(functionName.toLowerCase(Locale.ROOT), null)); + } } diff --git a/core/src/test/java/org/opensearch/sql/analysis/ExpressionAnalyzerTest.java b/core/src/test/java/org/opensearch/sql/analysis/ExpressionAnalyzerTest.java index aa8d2b12dee..8cb7288273c 100644 --- a/core/src/test/java/org/opensearch/sql/analysis/ExpressionAnalyzerTest.java +++ b/core/src/test/java/org/opensearch/sql/analysis/ExpressionAnalyzerTest.java @@ -292,6 +292,14 @@ public void aggregation_filter() { ); } + @Test + public void variance_mapto_varPop() { + assertAnalyzeEqual( + dsl.varPop(DSL.ref("integer_value", INTEGER)), + AstDSL.aggregate("variance", qualifiedName("integer_value")) + ); + } + protected Expression analyze(UnresolvedExpression unresolvedExpression) { return expressionAnalyzer.analyze(unresolvedExpression, analysisContext); } diff --git 
a/core/src/test/java/org/opensearch/sql/expression/aggregation/StdDevAggregatorTest.java b/core/src/test/java/org/opensearch/sql/expression/aggregation/StdDevAggregatorTest.java new file mode 100644 index 00000000000..ef085a81d32 --- /dev/null +++ b/core/src/test/java/org/opensearch/sql/expression/aggregation/StdDevAggregatorTest.java @@ -0,0 +1,182 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package org.opensearch.sql.expression.aggregation; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.when; +import static org.opensearch.sql.data.model.ExprValueUtils.doubleValue; +import static org.opensearch.sql.data.model.ExprValueUtils.integerValue; +import static org.opensearch.sql.data.model.ExprValueUtils.missingValue; +import static org.opensearch.sql.data.model.ExprValueUtils.nullValue; +import static org.opensearch.sql.data.type.ExprCoreType.DOUBLE; +import static org.opensearch.sql.data.type.ExprCoreType.INTEGER; +import static org.opensearch.sql.expression.DSL.ref; + +import java.util.ArrayList; +import java.util.List; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.sql.data.model.ExprValue; +import org.opensearch.sql.data.model.ExprValueUtils; +import org.opensearch.sql.expression.DSL; +import 
org.opensearch.sql.expression.Expression; +import org.opensearch.sql.storage.bindingtuple.BindingTuple; + +@ExtendWith(MockitoExtension.class) +public class StdDevAggregatorTest extends AggregationTest { + + @Mock + Expression expression; + + @Mock + ExprValue tupleValue; + + @Mock + BindingTuple tuple; + + @Test + public void stddev_sample_field_expression() { + ExprValue result = + stddevSample(integerValue(1), integerValue(2), integerValue(3), integerValue(4)); + assertEquals(1.2909944487358056, result.value()); + } + + @Test + public void stddev_population_field_expression() { + ExprValue result = + stddevPop(integerValue(1), integerValue(2), integerValue(3), integerValue(4)); + assertEquals(1.118033988749895, result.value()); + } + + @Test + public void stddev_sample_arithmetic_expression() { + ExprValue result = + aggregation( + dsl.stddevSamp(dsl.multiply(ref("integer_value", INTEGER), DSL.literal(10))), tuples); + assertEquals(12.909944487358056, result.value()); + } + + @Test + public void stddev_population_arithmetic_expression() { + ExprValue result = + aggregation( + dsl.stddevPop(dsl.multiply(ref("integer_value", INTEGER), DSL.literal(10))), tuples); + assertEquals(11.180339887498949, result.value()); + } + + @Test + public void filtered_stddev_sample() { + ExprValue result = + aggregation( + dsl.stddevSamp(ref("integer_value", INTEGER)) + .condition(dsl.greater(ref("integer_value", INTEGER), DSL.literal(1))), + tuples); + assertEquals(1.0, result.value()); + } + + @Test + public void filtered_stddev_population() { + ExprValue result = + aggregation( + dsl.stddevPop(ref("integer_value", INTEGER)) + .condition(dsl.greater(ref("integer_value", INTEGER), DSL.literal(1))), + tuples); + assertEquals(0.816496580927726, result.value()); + } + + @Test + public void stddev_sample_with_missing() { + ExprValue result = stddevSample(integerValue(2), integerValue(1), missingValue()); + assertEquals(0.7071067811865476, result.value()); + } + + @Test + public void 
stddev_population_with_missing() { + ExprValue result = stddevPop(integerValue(2), integerValue(1), missingValue()); + assertEquals(0.5, result.value()); + } + + @Test + public void stddev_sample_with_null() { + ExprValue result = stddevSample(doubleValue(3d), doubleValue(4d), nullValue()); + assertEquals(0.7071067811865476, result.value()); + } + + @Test + public void stddev_pop_with_null() { + ExprValue result = stddevPop(doubleValue(3d), doubleValue(4d), nullValue()); + assertEquals(0.5, result.value()); + } + + @Test + public void stddev_sample_with_all_missing_or_null() { + ExprValue result = stddevSample(missingValue(), nullValue()); + assertTrue(result.isNull()); + } + + @Test + public void stddev_pop_with_all_missing_or_null() { + ExprValue result = stddevPop(missingValue(), nullValue()); + assertTrue(result.isNull()); + } + + @Test + public void stddev_sample_to_string() { + Aggregator aggregator = dsl.stddevSamp(ref("integer_value", INTEGER)); + assertEquals("stddev_samp(integer_value)", aggregator.toString()); + } + + @Test + public void stddev_pop_to_string() { + Aggregator aggregator = dsl.stddevPop(ref("integer_value", INTEGER)); + assertEquals("stddev_pop(integer_value)", aggregator.toString()); + } + + @Test + public void stddev_sample_nested_to_string() { + Aggregator avgAggregator = + dsl.stddevSamp( + dsl.multiply( + ref("integer_value", INTEGER), DSL.literal(ExprValueUtils.integerValue(10)))); + assertEquals( + String.format("stddev_samp(*(%s, %d))", ref("integer_value", INTEGER), 10), + avgAggregator.toString()); + } + + private ExprValue stddevSample(ExprValue value, ExprValue... values) { + when(expression.valueOf(any())).thenReturn(value, values); + when(expression.type()).thenReturn(DOUBLE); + return aggregation(dsl.stddevSamp(expression), mockTuples(value, values)); + } + + private ExprValue stddevPop(ExprValue value, ExprValue... 
values) { + when(expression.valueOf(any())).thenReturn(value, values); + when(expression.type()).thenReturn(DOUBLE); + return aggregation(dsl.stddevPop(expression), mockTuples(value, values)); + } + + private List mockTuples(ExprValue value, ExprValue... values) { + List mockTuples = new ArrayList<>(); + when(tupleValue.bindingTuples()).thenReturn(tuple); + mockTuples.add(tupleValue); + for (ExprValue exprValue : values) { + mockTuples.add(tupleValue); + } + return mockTuples; + } +} diff --git a/core/src/test/java/org/opensearch/sql/expression/aggregation/VarianceAggregatorTest.java b/core/src/test/java/org/opensearch/sql/expression/aggregation/VarianceAggregatorTest.java new file mode 100644 index 00000000000..09fb8b8012f --- /dev/null +++ b/core/src/test/java/org/opensearch/sql/expression/aggregation/VarianceAggregatorTest.java @@ -0,0 +1,190 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. 
+ */ + +package org.opensearch.sql.expression.aggregation; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.when; +import static org.opensearch.sql.data.model.ExprValueUtils.doubleValue; +import static org.opensearch.sql.data.model.ExprValueUtils.integerValue; +import static org.opensearch.sql.data.model.ExprValueUtils.missingValue; +import static org.opensearch.sql.data.model.ExprValueUtils.nullValue; +import static org.opensearch.sql.data.type.ExprCoreType.DOUBLE; +import static org.opensearch.sql.data.type.ExprCoreType.INTEGER; +import static org.opensearch.sql.expression.DSL.ref; + +import java.util.ArrayList; +import java.util.List; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.sql.data.model.ExprValue; +import org.opensearch.sql.data.model.ExprValueUtils; +import org.opensearch.sql.exception.ExpressionEvaluationException; +import org.opensearch.sql.expression.DSL; +import org.opensearch.sql.expression.Expression; +import org.opensearch.sql.storage.bindingtuple.BindingTuple; + +@ExtendWith(MockitoExtension.class) +public class VarianceAggregatorTest extends AggregationTest { + + @Mock Expression expression; + + @Mock ExprValue tupleValue; + + @Mock BindingTuple tuple; + + @Test + public void variance_sample_field_expression() { + ExprValue result = + varianceSample(integerValue(1), integerValue(2), integerValue(3), integerValue(4)); + assertEquals(1.6666666666666667, result.value()); + } + + @Test + public void variance_population_field_expression() { + ExprValue result = + variancePop(integerValue(1), integerValue(2), integerValue(3), integerValue(4)); + assertEquals(1.25, result.value()); + } + + 
@Test + public void variance_sample_arithmetic_expression() { + ExprValue result = + aggregation( + dsl.varSamp(dsl.multiply(ref("integer_value", INTEGER), DSL.literal(10))), tuples); + assertEquals(166.66666666666666, result.value()); + } + + @Test + public void variance_pop_arithmetic_expression() { + ExprValue result = + aggregation( + dsl.varPop(dsl.multiply(ref("integer_value", INTEGER), DSL.literal(10))), tuples); + assertEquals(125d, result.value()); + } + + @Test + public void filtered_variance_sample() { + ExprValue result = + aggregation( + dsl.varSamp(ref("integer_value", INTEGER)) + .condition(dsl.greater(ref("integer_value", INTEGER), DSL.literal(1))), + tuples); + assertEquals(1.0, result.value()); + } + + @Test + public void filtered_variance_pop() { + ExprValue result = + aggregation( + dsl.varPop(ref("integer_value", INTEGER)) + .condition(dsl.greater(ref("integer_value", INTEGER), DSL.literal(1))), + tuples); + assertEquals(0.6666666666666666, result.value()); + } + + @Test + public void variance_sample_with_missing() { + ExprValue result = varianceSample(integerValue(2), integerValue(1), missingValue()); + assertEquals(0.5, result.value()); + } + + @Test + public void variance_population_with_missing() { + ExprValue result = variancePop(integerValue(2), integerValue(1), missingValue()); + assertEquals(0.25, result.value()); + } + + @Test + public void variance_sample_with_null() { + ExprValue result = varianceSample(doubleValue(3d), doubleValue(4d), nullValue()); + assertEquals(0.5, result.value()); + } + + @Test + public void variance_pop_with_null() { + ExprValue result = variancePop(doubleValue(3d), doubleValue(4d), nullValue()); + assertEquals(0.25, result.value()); + } + + @Test + public void variance_sample_with_all_missing_or_null() { + ExprValue result = varianceSample(missingValue(), nullValue()); + assertTrue(result.isNull()); + } + + @Test + public void variance_pop_with_all_missing_or_null() { + ExprValue result = 
variancePop(missingValue(), nullValue()); + assertTrue(result.isNull()); + } + + @Test + public void valueOf() { + ExpressionEvaluationException exception = + assertThrows( + ExpressionEvaluationException.class, + () -> dsl.avg(ref("double_value", DOUBLE)).valueOf(valueEnv())); + assertEquals("can't evaluate on aggregator: avg", exception.getMessage()); + } + + @Test + public void variance_sample_to_string() { + Aggregator avgAggregator = dsl.varSamp(ref("integer_value", INTEGER)); + assertEquals("var_samp(integer_value)", avgAggregator.toString()); + } + + @Test + public void variance_pop_to_string() { + Aggregator avgAggregator = dsl.varPop(ref("integer_value", INTEGER)); + assertEquals("var_pop(integer_value)", avgAggregator.toString()); + } + + @Test + public void variance_sample_nested_to_string() { + Aggregator avgAggregator = + dsl.varSamp( + dsl.multiply( + ref("integer_value", INTEGER), DSL.literal(ExprValueUtils.integerValue(10)))); + assertEquals( + String.format("var_samp(*(%s, %d))", ref("integer_value", INTEGER), 10), + avgAggregator.toString()); + } + + private ExprValue varianceSample(ExprValue value, ExprValue... values) { + when(expression.valueOf(any())).thenReturn(value, values); + when(expression.type()).thenReturn(DOUBLE); + return aggregation(dsl.varSamp(expression), mockTuples(value, values)); + } + + private ExprValue variancePop(ExprValue value, ExprValue... values) { + when(expression.valueOf(any())).thenReturn(value, values); + when(expression.type()).thenReturn(DOUBLE); + return aggregation(dsl.varPop(expression), mockTuples(value, values)); + } + + private List mockTuples(ExprValue value, ExprValue... 
values) { + List mockTuples = new ArrayList<>(); + when(tupleValue.bindingTuples()).thenReturn(tuple); + mockTuples.add(tupleValue); + for (ExprValue exprValue : values) { + mockTuples.add(tupleValue); + } + return mockTuples; + } +} diff --git a/docs/user/dql/aggregations.rst b/docs/user/dql/aggregations.rst index 98b565e1ecd..1d6d1729815 100644 --- a/docs/user/dql/aggregations.rst +++ b/docs/user/dql/aggregations.rst @@ -135,6 +135,228 @@ Besides regular identifiers, ``COUNT`` aggregate function also accepts arguments 2. ``COUNT(*)`` will count the number of all its input rows. 3. ``COUNT(1)`` is same as ``COUNT(*)`` because any non-null literal will count. +Aggregation Functions +===================== + +COUNT +----- + +Description +>>>>>>>>>>> + +Usage: Returns a count of the number of expr in the rows retrieved by a SELECT statement. + +Example:: + + os> SELECT gender, count(*) as countV FROM accounts GROUP BY gender; + fetched rows / total rows = 2/2 + +----------+----------+ + | gender | countV | + |----------+----------| + | F | 1 | + | M | 3 | + +----------+----------+ + +SUM +--- + +Description +>>>>>>>>>>> + +Usage: SUM(expr). Returns the sum of expr. + +Example:: + + os> SELECT gender, sum(age) as sumV FROM accounts GROUP BY gender; + fetched rows / total rows = 2/2 + +----------+--------+ + | gender | sumV | + |----------+--------| + | F | 28 | + | M | 101 | + +----------+--------+ + +AVG +--- + +Description +>>>>>>>>>>> + +Usage: AVG(expr). Returns the average value of expr. + +Example:: + + os> SELECT gender, avg(age) as avgV FROM accounts GROUP BY gender; + fetched rows / total rows = 2/2 + +----------+--------------------+ + | gender | avgV | + |----------+--------------------| + | F | 28.0 | + | M | 33.666666666666664 | + +----------+--------------------+ + +MAX +--- + +Description +>>>>>>>>>>> + +Usage: MAX(expr). Returns the maximum value of expr. 
+ +Example:: + + os> SELECT max(age) as maxV FROM accounts; + fetched rows / total rows = 1/1 + +--------+ + | maxV | + |--------| + | 36 | + +--------+ + +MIN +--- + +Description +>>>>>>>>>>> + +Usage: MIN(expr). Returns the minimum value of expr. + +Example:: + + os> SELECT min(age) as minV FROM accounts; + fetched rows / total rows = 1/1 + +--------+ + | minV | + |--------| + | 28 | + +--------+ + +VAR_POP +------- + +Description +>>>>>>>>>>> + +Usage: VAR_POP(expr). Returns the population variance of expr. + +Example:: + + os> SELECT var_pop(age) as varV FROM accounts; + fetched rows / total rows = 1/1 + +--------+ + | varV | + |--------| + | 8.1875 | + +--------+ + +VAR_SAMP +-------- + +Description +>>>>>>>>>>> + +Usage: VAR_SAMP(expr). Returns the sample variance of expr. + +Example:: + + os> SELECT var_samp(age) as varV FROM accounts; + fetched rows / total rows = 1/1 + +--------------------+ + | varV | + |--------------------| + | 10.916666666666666 | + +--------------------+ + +VARIANCE +-------- + +Description +>>>>>>>>>>> + +Usage: VARIANCE(expr). Returns the population variance of expr. VARIANCE() is a synonym for the VAR_POP() function. + +Example:: + + os> SELECT variance(age) as varV FROM accounts; + fetched rows / total rows = 1/1 + +--------+ + | varV | + |--------| + | 8.1875 | + +--------+ + +STDDEV_POP +---------- + +Description +>>>>>>>>>>> + +Usage: STDDEV_POP(expr). Returns the population standard deviation of expr. + +Example:: + + os> SELECT stddev_pop(age) as stddevV FROM accounts; + fetched rows / total rows = 1/1 + +--------------------+ + | stddevV | + |--------------------| + | 2.8613807855648994 | + +--------------------+ + +STDDEV_SAMP +----------- + +Description +>>>>>>>>>>> + +Usage: STDDEV_SAMP(expr). Returns the sample standard deviation of expr.
+ +Example:: + + os> SELECT stddev_samp(age) as stddevV FROM accounts; + fetched rows / total rows = 1/1 + +-------------------+ + | stddevV | + |-------------------| + | 3.304037933599835 | + +-------------------+ + +STD +--- + +Description +>>>>>>>>>>> + +Usage: STD(expr). Returns the population standard deviation of expr. STD() is a synonym for the STDDEV_POP() function. + +Example:: + + os> SELECT std(age) as stddevV FROM accounts; + fetched rows / total rows = 1/1 + +--------------------+ + | stddevV | + |--------------------| + | 2.8613807855648994 | + +--------------------+ + +STDDEV +------ + +Description +>>>>>>>>>>> + +Usage: STDDEV(expr). Returns the population standard deviation of expr. STDDEV() is a synonym for the STDDEV_POP() function. + +Example:: + + os> SELECT stddev(age) as stddevV FROM accounts; + fetched rows / total rows = 1/1 + +--------------------+ + | stddevV | + |--------------------| + | 2.8613807855648994 | + +--------------------+ + HAVING Clause ============= diff --git a/docs/user/dql/window.rst b/docs/user/dql/window.rst index 6d71f0637a5..feb2aaa44ee 100644 --- a/docs/user/dql/window.rst +++ b/docs/user/dql/window.rst @@ -20,7 +20,7 @@ A window function consists of 2 pieces: a function and a window definition. A wi There are three categories of common window functions: -1. **Aggregate Functions**: COUNT(), MIN(), MAX(), AVG() and SUM(). +1. **Aggregate Functions**: COUNT(), MIN(), MAX(), AVG(), SUM(), STDDEV_POP(), STDDEV_SAMP(), VAR_POP() and VAR_SAMP(). 2. **Ranking Functions**: ROW_NUMBER(), RANK(), DENSE_RANK(), PERCENT_RANK() and NTILE(). 3. **Analytic Functions**: CUME_DIST(), LAG() and LEAD(). @@ -146,6 +146,90 @@ Here is an example for ``SUM`` function:: | M | 39225 | 49091 | +----------+-----------+-------+ +STDDEV_POP +---------- + +Here is an example for ``STDDEV_POP`` function:: + + os> SELECT + ... gender, balance, + ... STDDEV_POP(balance) OVER( + ... PARTITION BY gender ORDER BY balance + ... ) AS val + ...
FROM accounts; + fetched rows / total rows = 4/4 + +----------+-----------+--------------------+ + | gender | balance | val | + |----------+-----------+--------------------| + | F | 32838 | 0.0 | + | M | 4180 | 0.0 | + | M | 5686 | 753.0 | + | M | 39225 | 16177.091422406222 | + +----------+-----------+--------------------+ + +STDDEV_SAMP +----------- + +Here is an example for ``STDDEV_SAMP`` function:: + + os> SELECT + ... gender, balance, + ... STDDEV_SAMP(balance) OVER( + ... PARTITION BY gender ORDER BY balance + ... ) AS val + ... FROM accounts; + fetched rows / total rows = 4/4 + +----------+-----------+--------------------+ + | gender | balance | val | + |----------+-----------+--------------------| + | F | 32838 | 0.0 | + | M | 4180 | 0.0 | + | M | 5686 | 1064.9028124669405 | + | M | 39225 | 19812.809753624886 | + +----------+-----------+--------------------+ + +VAR_POP +------- + +Here is an example for ``VAR_POP`` function:: + + os> SELECT + ... gender, balance, + ... VAR_POP(balance) OVER( + ... PARTITION BY gender ORDER BY balance + ... ) AS val + ... FROM accounts; + fetched rows / total rows = 4/4 + +----------+-----------+--------------------+ + | gender | balance | val | + |----------+-----------+--------------------| + | F | 32838 | 0.0 | + | M | 4180 | 0.0 | + | M | 5686 | 567009.0 | + | M | 39225 | 261698286.88888893 | + +----------+-----------+--------------------+ + +VAR_SAMP +-------- + +Here is an example for ``VAR_SAMP`` function:: + + os> SELECT + ... gender, balance, + ... VAR_SAMP(balance) OVER( + ... PARTITION BY gender ORDER BY balance + ... ) AS val + ...
FROM accounts; + fetched rows / total rows = 4/4 + +----------+-----------+-------------------+ + | gender | balance | val | + |----------+-----------+-------------------| + | F | 32838 | 0.0 | + | M | 4180 | 0.0 | + | M | 5686 | 1134018.0 | + | M | 39225 | 392547430.3333334 | + +----------+-----------+-------------------+ + Ranking Functions ================= diff --git a/docs/user/ppl/cmd/stats.rst b/docs/user/ppl/cmd/stats.rst index 3aca304fcd7..f6dad255efc 100644 --- a/docs/user/ppl/cmd/stats.rst +++ b/docs/user/ppl/cmd/stats.rst @@ -38,6 +38,174 @@ stats ... [by-clause]... * aggregation: mandatory. A aggregation function. The argument of aggregation must be field. * by-clause: optional. The one or more fields to group the results by. **Default**: If no is specified, the stats command returns only one row, which is the aggregation over the entire result set. + +Aggregation Functions +===================== + +COUNT +----- + +Description +>>>>>>>>>>> + +Usage: Returns a count of the number of expr in the rows retrieved by a SELECT statement. + +Example:: + + os> source=accounts | stats count(); + fetched rows / total rows = 1/1 + +-----------+ + | count() | + |-----------| + | 4 | + +-----------+ + +SUM +--- + +Description +>>>>>>>>>>> + +Usage: SUM(expr). Returns the sum of expr. + +Example:: + + os> source=accounts | stats sum(age) by gender; + fetched rows / total rows = 2/2 + +------------+----------+ + | sum(age) | gender | + |------------+----------| + | 28 | F | + | 101 | M | + +------------+----------+ + +AVG +--- + +Description +>>>>>>>>>>> + +Usage: AVG(expr). Returns the average value of expr. + +Example:: + + os> source=accounts | stats avg(age) by gender; + fetched rows / total rows = 2/2 + +--------------------+----------+ + | avg(age) | gender | + |--------------------+----------| + | 28.0 | F | + | 33.666666666666664 | M | + +--------------------+----------+ + +MAX +--- + +Description +>>>>>>>>>>> + +Usage: MAX(expr). 
Returns the maximum value of expr. + +Example:: + + os> source=accounts | stats max(age); + fetched rows / total rows = 1/1 + +------------+ + | max(age) | + |------------| + | 36 | + +------------+ + +MIN +--- + +Description +>>>>>>>>>>> + +Usage: MIN(expr). Returns the minimum value of expr. + +Example:: + + os> source=accounts | stats min(age); + fetched rows / total rows = 1/1 + +------------+ + | min(age) | + |------------| + | 28 | + +------------+ + +VAR_SAMP +-------- + +Description +>>>>>>>>>>> + +Usage: VAR_SAMP(expr). Returns the sample variance of expr. + +Example:: + + os> source=accounts | stats var_samp(age); + fetched rows / total rows = 1/1 + +--------------------+ + | var_samp(age) | + |--------------------| + | 10.916666666666666 | + +--------------------+ + +VAR_POP +------- + +Description +>>>>>>>>>>> + +Usage: VAR_POP(expr). Returns the population variance of expr. + +Example:: + + os> source=accounts | stats var_pop(age); + fetched rows / total rows = 1/1 + +----------------+ + | var_pop(age) | + |----------------| + | 8.1875 | + +----------------+ + +STDDEV_SAMP +----------- + +Description +>>>>>>>>>>> + +Usage: STDDEV_SAMP(expr). Returns the sample standard deviation of expr. + +Example:: + + os> source=accounts | stats stddev_samp(age); + fetched rows / total rows = 1/1 + +--------------------+ + | stddev_samp(age) | + |--------------------| + | 3.304037933599835 | + +--------------------+ + +STDDEV_POP +---------- + +Description +>>>>>>>>>>> + +Usage: STDDEV_POP(expr). Returns the population standard deviation of expr.
+ +Example:: + + os> source=accounts | stats stddev_pop(age); + fetched rows / total rows = 1/1 + +--------------------+ + | stddev_pop(age) | + |--------------------| + | 2.8613807855648994 | + +--------------------+ + Example 1: Calculate the count of events ======================================== diff --git a/integ-test/src/test/resources/correctness/queries/aggregation.txt b/integ-test/src/test/resources/correctness/queries/aggregation.txt index 6c6e5b73a14..45aa658783b 100644 --- a/integ-test/src/test/resources/correctness/queries/aggregation.txt +++ b/integ-test/src/test/resources/correctness/queries/aggregation.txt @@ -5,4 +5,8 @@ SELECT SUM(AvgTicketPrice) FROM opensearch_dashboards_sample_data_flights SELECT MAX(AvgTicketPrice) FROM opensearch_dashboards_sample_data_flights SELECT MAX(timestamp) FROM opensearch_dashboards_sample_data_flights SELECT MIN(AvgTicketPrice) FROM opensearch_dashboards_sample_data_flights -SELECT MIN(timestamp) FROM opensearch_dashboards_sample_data_flights \ No newline at end of file +SELECT MIN(timestamp) FROM opensearch_dashboards_sample_data_flights +SELECT VAR_POP(AvgTicketPrice) FROM opensearch_dashboards_sample_data_flights +SELECT VAR_SAMP(AvgTicketPrice) FROM opensearch_dashboards_sample_data_flights +SELECT STDDEV_POP(AvgTicketPrice) FROM opensearch_dashboards_sample_data_flights +SELECT STDDEV_SAMP(AvgTicketPrice) FROM opensearch_dashboards_sample_data_flights \ No newline at end of file diff --git a/integ-test/src/test/resources/correctness/queries/window.txt b/integ-test/src/test/resources/correctness/queries/window.txt index a8d134a2545..c3f27153229 100644 --- a/integ-test/src/test/resources/correctness/queries/window.txt +++ b/integ-test/src/test/resources/correctness/queries/window.txt @@ -9,10 +9,18 @@ SELECT DistanceMiles, SUM(DistanceMiles) OVER () AS num FROM opensearch_dashboar SELECT DistanceMiles, AVG(DistanceMiles) OVER () AS num FROM opensearch_dashboards_sample_data_flights SELECT DistanceMiles, 
MAX(DistanceMiles) OVER () AS num FROM opensearch_dashboards_sample_data_flights SELECT DistanceMiles, MIN(DistanceMiles) OVER () AS num FROM opensearch_dashboards_sample_data_flights +SELECT AvgTicketPrice, STDDEV_POP(AvgTicketPrice) OVER () AS num FROM opensearch_dashboards_sample_data_flights +SELECT AvgTicketPrice, STDDEV_SAMP(AvgTicketPrice) OVER () AS num FROM opensearch_dashboards_sample_data_flights +SELECT AvgTicketPrice, VAR_POP(AvgTicketPrice) OVER () AS num FROM opensearch_dashboards_sample_data_flights +SELECT AvgTicketPrice, VAR_SAMP(AvgTicketPrice) OVER () AS num FROM opensearch_dashboards_sample_data_flights SELECT FlightDelayMin, DistanceMiles, SUM(DistanceMiles) OVER (ORDER BY FlightDelayMin) AS num FROM opensearch_dashboards_sample_data_flights SELECT FlightDelayMin, DistanceMiles, AVG(DistanceMiles) OVER (ORDER BY FlightDelayMin) AS num FROM opensearch_dashboards_sample_data_flights SELECT FlightDelayMin, DistanceMiles, MAX(DistanceMiles) OVER (ORDER BY FlightDelayMin) AS num FROM opensearch_dashboards_sample_data_flights SELECT FlightDelayMin, DistanceMiles, MIN(DistanceMiles) OVER (ORDER BY FlightDelayMin) AS num FROM opensearch_dashboards_sample_data_flights +SELECT FlightDelayMin, AvgTicketPrice, STDDEV_POP(AvgTicketPrice) OVER (ORDER BY FlightDelayMin) AS num FROM opensearch_dashboards_sample_data_flights ORDER BY FlightDelayMin +SELECT FlightDelayMin, AvgTicketPrice, STDDEV_SAMP(AvgTicketPrice) OVER (ORDER BY FlightDelayMin) AS num FROM opensearch_dashboards_sample_data_flights ORDER BY FlightDelayMin +SELECT FlightDelayMin, AvgTicketPrice, VAR_POP(AvgTicketPrice) OVER (ORDER BY FlightDelayMin) AS num FROM opensearch_dashboards_sample_data_flights ORDER BY FlightDelayMin +SELECT FlightDelayMin, AvgTicketPrice, VAR_SAMP(AvgTicketPrice) OVER (ORDER BY FlightDelayMin) AS num FROM opensearch_dashboards_sample_data_flights ORDER BY FlightDelayMin SELECT user, RANK() OVER (ORDER BY user) AS rnk FROM opensearch_dashboards_sample_data_ecommerce 
SELECT user, DENSE_RANK() OVER (ORDER BY user) AS rnk FROM opensearch_dashboards_sample_data_ecommerce SELECT user, COUNT(day_of_week_i) OVER (ORDER BY user) AS cnt FROM opensearch_dashboards_sample_data_ecommerce @@ -20,6 +28,8 @@ SELECT user, SUM(day_of_week_i) OVER (ORDER BY user) AS num FROM opensearch_dash SELECT user, AVG(day_of_week_i) OVER (ORDER BY user) AS num FROM opensearch_dashboards_sample_data_ecommerce SELECT user, MAX(day_of_week_i) OVER (ORDER BY user) AS num FROM opensearch_dashboards_sample_data_ecommerce SELECT user, MIN(day_of_week_i) OVER (ORDER BY user) AS num FROM opensearch_dashboards_sample_data_ecommerce +SELECT user, STDDEV_POP(day_of_week_i) OVER (ORDER BY user) AS num FROM opensearch_dashboards_sample_data_ecommerce ORDER BY user +SELECT user, VAR_POP(day_of_week_i) OVER (ORDER BY user) AS num FROM opensearch_dashboards_sample_data_ecommerce ORDER BY user SELECT user, RANK() OVER (ORDER BY user DESC) AS rnk FROM opensearch_dashboards_sample_data_ecommerce SELECT user, DENSE_RANK() OVER (ORDER BY user DESC) AS rnk FROM opensearch_dashboards_sample_data_ecommerce SELECT user, COUNT(day_of_week_i) OVER (PARTITION BY user ORDER BY order_id) AS cnt FROM opensearch_dashboards_sample_data_ecommerce @@ -27,6 +37,8 @@ SELECT user, SUM(day_of_week_i) OVER (PARTITION BY user ORDER BY order_id) AS nu SELECT user, AVG(day_of_week_i) OVER (PARTITION BY user ORDER BY order_id) AS num FROM opensearch_dashboards_sample_data_ecommerce SELECT user, MAX(day_of_week_i) OVER (PARTITION BY user ORDER BY order_id) AS num FROM opensearch_dashboards_sample_data_ecommerce SELECT user, MIN(day_of_week_i) OVER (PARTITION BY user ORDER BY order_id) AS num FROM opensearch_dashboards_sample_data_ecommerce +SELECT user, STDDEV_POP(day_of_week_i) OVER (PARTITION BY user ORDER BY order_id) AS num FROM opensearch_dashboards_sample_data_ecommerce ORDER BY user +SELECT user, VAR_POP(day_of_week_i) OVER (PARTITION BY user ORDER BY order_id) AS num FROM 
opensearch_dashboards_sample_data_ecommerce ORDER BY user SELECT customer_gender, user, ROW_NUMBER() OVER (PARTITION BY customer_gender ORDER BY user) AS num FROM opensearch_dashboards_sample_data_ecommerce SELECT customer_gender, user, RANK() OVER (PARTITION BY customer_gender ORDER BY user) AS num FROM opensearch_dashboards_sample_data_ecommerce SELECT customer_gender, user, DENSE_RANK() OVER (PARTITION BY customer_gender ORDER BY user) AS num FROM opensearch_dashboards_sample_data_ecommerce diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/Utils.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/Utils.java index 53fd66ceef7..28b9d41e833 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/Utils.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/Utils.java @@ -19,7 +19,7 @@ public class Utils { /** * Utils to handle Nan Value. - * @return null if is Nan value. + * @return null if is Nan. */ public static Object handleNanValue(double value) { return Double.isNaN(value) ? 
null : value; diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java index 0dbfec02c1d..3d402582888 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java @@ -37,6 +37,7 @@ import org.opensearch.search.aggregations.AggregationBuilders; import org.opensearch.search.aggregations.AggregatorFactories; import org.opensearch.search.aggregations.bucket.filter.FilterAggregationBuilder; +import org.opensearch.search.aggregations.metrics.ExtendedStats; import org.opensearch.search.aggregations.support.ValuesSourceAggregationBuilder; import org.opensearch.sql.expression.Expression; import org.opensearch.sql.expression.ExpressionNodeVisitor; @@ -46,6 +47,7 @@ import org.opensearch.sql.opensearch.response.agg.FilterParser; import org.opensearch.sql.opensearch.response.agg.MetricParser; import org.opensearch.sql.opensearch.response.agg.SingleValueParser; +import org.opensearch.sql.opensearch.response.agg.StatsParser; import org.opensearch.sql.opensearch.storage.script.filter.FilterQueryBuilder; import org.opensearch.sql.opensearch.storage.serialization.ExpressionSerializer; @@ -124,6 +126,34 @@ public Pair visitNamedAggregator( condition, name, new SingleValueParser(name)); + case "var_samp": + return make( + AggregationBuilders.extendedStats(name), + expression, + condition, + name, + new StatsParser(ExtendedStats::getVarianceSampling,name)); + case "var_pop": + return make( + AggregationBuilders.extendedStats(name), + expression, + condition, + name, + new StatsParser(ExtendedStats::getVariancePopulation,name)); + case "stddev_samp": + return make( + AggregationBuilders.extendedStats(name), + 
expression, + condition, + name, + new StatsParser(ExtendedStats::getStdDeviationSampling,name)); + case "stddev_pop": + return make( + AggregationBuilders.extendedStats(name), + expression, + condition, + name, + new StatsParser(ExtendedStats::getStdDeviationPopulation,name)); default: throw new IllegalStateException( String.format("unsupported aggregator %s", node.getFunctionName().getFunctionName())); diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilderTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilderTest.java index 85b3bd5a65f..95a23834754 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilderTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilderTest.java @@ -35,6 +35,10 @@ import static org.opensearch.sql.expression.DSL.literal; import static org.opensearch.sql.expression.DSL.named; import static org.opensearch.sql.expression.DSL.ref; +import static org.opensearch.sql.expression.aggregation.StdDevAggregator.stddevPopulation; +import static org.opensearch.sql.expression.aggregation.StdDevAggregator.stddevSample; +import static org.opensearch.sql.expression.aggregation.VarianceAggregator.variancePopulation; +import static org.opensearch.sql.expression.aggregation.VarianceAggregator.varianceSample; import com.fasterxml.jackson.databind.ObjectMapper; import java.util.Arrays; @@ -53,6 +57,7 @@ import org.opensearch.sql.expression.aggregation.MinAggregator; import org.opensearch.sql.expression.aggregation.NamedAggregator; import org.opensearch.sql.expression.aggregation.SumAggregator; +import org.opensearch.sql.expression.aggregation.VarianceAggregator; import org.opensearch.sql.expression.function.FunctionName; import 
org.opensearch.sql.opensearch.storage.serialization.ExpressionSerializer; @@ -185,6 +190,74 @@ void should_build_max_aggregation() { new MaxAggregator(Arrays.asList(ref("age", INTEGER)), INTEGER))))); } + @Test + void should_build_varPop_aggregation() { + assertEquals( + "{\n" + + " \"var_pop(age)\" : {\n" + + " \"extended_stats\" : {\n" + + " \"field\" : \"age\",\n" + + " \"sigma\" : 2.0\n" + + " }\n" + + " }\n" + + "}", + buildQuery( + Arrays.asList( + named("var_pop(age)", + variancePopulation(Arrays.asList(ref("age", INTEGER)), INTEGER))))); + } + + @Test + void should_build_varSamp_aggregation() { + assertEquals( + "{\n" + + " \"var_samp(age)\" : {\n" + + " \"extended_stats\" : {\n" + + " \"field\" : \"age\",\n" + + " \"sigma\" : 2.0\n" + + " }\n" + + " }\n" + + "}", + buildQuery( + Arrays.asList( + named("var_samp(age)", + varianceSample(Arrays.asList(ref("age", INTEGER)), INTEGER))))); + } + + @Test + void should_build_stddevPop_aggregation() { + assertEquals( + "{\n" + + " \"stddev_pop(age)\" : {\n" + + " \"extended_stats\" : {\n" + + " \"field\" : \"age\",\n" + + " \"sigma\" : 2.0\n" + + " }\n" + + " }\n" + + "}", + buildQuery( + Arrays.asList( + named("stddev_pop(age)", + stddevPopulation(Arrays.asList(ref("age", INTEGER)), INTEGER))))); + } + + @Test + void should_build_stddevSamp_aggregation() { + assertEquals( + "{\n" + + " \"stddev_samp(age)\" : {\n" + + " \"extended_stats\" : {\n" + + " \"field\" : \"age\",\n" + + " \"sigma\" : 2.0\n" + + " }\n" + + " }\n" + + "}", + buildQuery( + Arrays.asList( + named("stddev_samp(age)", + stddevSample(Arrays.asList(ref("age", INTEGER)), INTEGER))))); + } + @Test void should_throw_exception_for_unsupported_aggregator() { when(aggregator.getFunctionName()).thenReturn(new FunctionName("unsupported_agg")); diff --git a/ppl/src/main/antlr/OpenSearchPPLLexer.g4 b/ppl/src/main/antlr/OpenSearchPPLLexer.g4 index 3874a0a50ea..cb665f6c887 100644 --- a/ppl/src/main/antlr/OpenSearchPPLLexer.g4 +++ 
b/ppl/src/main/antlr/OpenSearchPPLLexer.g4 @@ -151,8 +151,10 @@ STDEV: 'STDEV'; STDEVP: 'STDEVP'; SUM: 'SUM'; SUMSQ: 'SUMSQ'; -VAR: 'VAR'; -VARP: 'VARP'; +VAR_SAMP: 'VAR_SAMP'; +VAR_POP: 'VAR_POP'; +STDDEV_SAMP: 'STDDEV_SAMP'; +STDDEV_POP: 'STDDEV_POP'; PERCENTILE: 'PERCENTILE'; FIRST: 'FIRST'; LAST: 'LAST'; diff --git a/ppl/src/main/antlr/OpenSearchPPLParser.g4 b/ppl/src/main/antlr/OpenSearchPPLParser.g4 index 77aecf5a44e..d552ad0756d 100644 --- a/ppl/src/main/antlr/OpenSearchPPLParser.g4 +++ b/ppl/src/main/antlr/OpenSearchPPLParser.g4 @@ -139,7 +139,7 @@ statsFunction ; statsFunctionName - : AVG | COUNT | SUM | MIN | MAX + : AVG | COUNT | SUM | MIN | MAX | VAR_SAMP | VAR_POP | STDDEV_SAMP | STDDEV_POP ; percentileAggFunction diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstExpressionBuilderTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstExpressionBuilderTest.java index 07ad97401e7..71ef692abf1 100644 --- a/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstExpressionBuilderTest.java +++ b/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstExpressionBuilderTest.java @@ -335,6 +335,90 @@ public void testAggFuncCallExpr() { )); } + @Test + public void testVarAggregationShouldPass() { + assertEqual("source=t | stats var_samp(a) by b", + agg( + relation("t"), + exprList( + alias( + "var_samp(a)", + aggregate("var_samp", field("a")) + ) + ), + emptyList(), + exprList( + alias( + "b", + field("b") + )), + defaultStatsArgs() + )); + } + + @Test + public void testVarpAggregationShouldPass() { + assertEqual("source=t | stats var_pop(a) by b", + agg( + relation("t"), + exprList( + alias( + "var_pop(a)", + aggregate("var_pop", field("a")) + ) + ), + emptyList(), + exprList( + alias( + "b", + field("b") + )), + defaultStatsArgs() + )); + } + + @Test + public void testStdDevAggregationShouldPass() { + assertEqual("source=t | stats stddev_samp(a) by b", + agg( + relation("t"), + exprList( + alias( + "stddev_samp(a)", + aggregate("stddev_samp", 
field("a")) + ) + ), + emptyList(), + exprList( + alias( + "b", + field("b") + )), + defaultStatsArgs() + )); + } + + @Test + public void testStdDevPAggregationShouldPass() { + assertEqual("source=t | stats stddev_pop(a) by b", + agg( + relation("t"), + exprList( + alias( + "stddev_pop(a)", + aggregate("stddev_pop", field("a")) + ) + ), + emptyList(), + exprList( + alias( + "b", + field("b") + )), + defaultStatsArgs() + )); + } + @Test public void testPercentileAggFuncExpr() { assertEqual("source=t | stats percentile<1>(a)", diff --git a/sql/src/main/antlr/OpenSearchSQLLexer.g4 b/sql/src/main/antlr/OpenSearchSQLLexer.g4 index 94f8e7c87a4..426c77cf06f 100644 --- a/sql/src/main/antlr/OpenSearchSQLLexer.g4 +++ b/sql/src/main/antlr/OpenSearchSQLLexer.g4 @@ -126,6 +126,13 @@ COUNT: 'COUNT'; MAX: 'MAX'; MIN: 'MIN'; SUM: 'SUM'; +VAR_POP: 'VAR_POP'; +VAR_SAMP: 'VAR_SAMP'; +VARIANCE: 'VARIANCE'; +STD: 'STD'; +STDDEV: 'STDDEV'; +STDDEV_POP: 'STDDEV_POP'; +STDDEV_SAMP: 'STDDEV_SAMP'; // Common function Keywords diff --git a/sql/src/main/antlr/OpenSearchSQLParser.g4 b/sql/src/main/antlr/OpenSearchSQLParser.g4 index 0ad08781bfe..18c75b94ffc 100644 --- a/sql/src/main/antlr/OpenSearchSQLParser.g4 +++ b/sql/src/main/antlr/OpenSearchSQLParser.g4 @@ -345,7 +345,7 @@ filterClause ; aggregationFunctionName - : AVG | COUNT | SUM | MIN | MAX + : AVG | COUNT | SUM | MIN | MAX | VAR_POP | VAR_SAMP | VARIANCE | STD | STDDEV | STDDEV_POP | STDDEV_SAMP ; mathematicalFunctionName diff --git a/sql/src/test/java/org/opensearch/sql/sql/parser/AstExpressionBuilderTest.java b/sql/src/test/java/org/opensearch/sql/sql/parser/AstExpressionBuilderTest.java index a3c8494e7a7..e4e8028f054 100644 --- a/sql/src/test/java/org/opensearch/sql/sql/parser/AstExpressionBuilderTest.java +++ b/sql/src/test/java/org/opensearch/sql/sql/parser/AstExpressionBuilderTest.java @@ -410,6 +410,27 @@ public void filteredAggregation() { ); } + @Test + public void canBuildVarSamp() { + assertEquals( + aggregate("var_samp", 
qualifiedName("age")), + buildExprAst("var_samp(age)")); + } + + @Test + public void canBuildVarPop() { + assertEquals( + aggregate("var_pop", qualifiedName("age")), + buildExprAst("var_pop(age)")); + } + + @Test + public void canBuildVariance() { + assertEquals( + aggregate("variance", qualifiedName("age")), + buildExprAst("variance(age)")); + } + private Node buildExprAst(String expr) { OpenSearchSQLLexer lexer = new OpenSearchSQLLexer(new CaseInsensitiveCharStream(expr)); OpenSearchSQLParser parser = new OpenSearchSQLParser(new CommonTokenStream(lexer)); From 9ff27939f9c3532263ea3cfe23f2722c70482236 Mon Sep 17 00:00:00 2001 From: Chloe Date: Fri, 11 Jun 2021 15:45:20 -0700 Subject: [PATCH 12/23] Fix the aggregation filter missing in named aggregators (#123) * Take the condition expression as property to the named aggregator when wrapping the delegated aggregator Signed-off-by: chloe-zh * update Signed-off-by: chloe-zh * Added test case where filtered agg is not pushed down Signed-off-by: chloe-zh * update Signed-off-by: chloe-zh * update Signed-off-by: chloe-zh --- .../aggregation/NamedAggregator.java | 3 ++ .../opensearch/sql/analysis/AnalyzerTest.java | 40 +++++++++++++++++++ .../org/opensearch/sql/sql/AggregationIT.java | 40 +++++++++++++++++++ .../queries/{subquries.txt => subqueries.txt} | 0 4 files changed, 83 insertions(+) create mode 100644 integ-test/src/test/java/org/opensearch/sql/sql/AggregationIT.java rename integ-test/src/test/resources/correctness/queries/{subquries.txt => subqueries.txt} (100%) diff --git a/core/src/main/java/org/opensearch/sql/expression/aggregation/NamedAggregator.java b/core/src/main/java/org/opensearch/sql/expression/aggregation/NamedAggregator.java index a1bf2b99613..346bd2d28cd 100644 --- a/core/src/main/java/org/opensearch/sql/expression/aggregation/NamedAggregator.java +++ b/core/src/main/java/org/opensearch/sql/expression/aggregation/NamedAggregator.java @@ -54,6 +54,8 @@ public class NamedAggregator extends Aggregator { 
/** * NamedAggregator. + * The aggregator properties {@link #condition} is inherited by named aggregator + * to avoid errors introduced by the property inconsistency. * * @param name name * @param delegated delegated @@ -64,6 +66,7 @@ public NamedAggregator( super(delegated.getFunctionName(), delegated.getArguments(), delegated.returnType); this.name = name; this.delegated = delegated; + this.condition = delegated.condition; } @Override diff --git a/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java b/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java index 9b42c70e32b..fc45f34ffeb 100644 --- a/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java +++ b/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java @@ -36,6 +36,7 @@ import static org.opensearch.sql.ast.dsl.AstDSL.compare; import static org.opensearch.sql.ast.dsl.AstDSL.field; import static org.opensearch.sql.ast.dsl.AstDSL.filter; +import static org.opensearch.sql.ast.dsl.AstDSL.filteredAggregate; import static org.opensearch.sql.ast.dsl.AstDSL.function; import static org.opensearch.sql.ast.dsl.AstDSL.intLiteral; import static org.opensearch.sql.ast.dsl.AstDSL.qualifiedName; @@ -624,4 +625,43 @@ public void limit_offset() { ) ); } + + /** + * SELECT COUNT(NAME) FILTER(WHERE age > 1) FROM test. 
+ * This test is to verify that the aggregator properties are taken + * when wrapping it to {@link org.opensearch.sql.expression.aggregation.NamedAggregator} + */ + @Test + public void named_aggregator_with_condition() { + assertAnalyzeEqual( + LogicalPlanDSL.project( + LogicalPlanDSL.aggregation( + LogicalPlanDSL.relation("schema"), + ImmutableList.of( + DSL.named("count(string_value) filter(where integer_value > 1)", + dsl.count(DSL.ref("string_value", STRING)).condition(dsl.greater(DSL.ref( + "integer_value", INTEGER), DSL.literal(1)))) + ), + emptyList() + ), + DSL.named("count(string_value) filter(where integer_value > 1)", DSL.ref( + "count(string_value) filter(where integer_value > 1)", INTEGER)) + ), + AstDSL.project( + AstDSL.agg( + AstDSL.relation("schema"), + ImmutableList.of( + alias("count(string_value) filter(where integer_value > 1)", filteredAggregate( + "count", qualifiedName("string_value"), function( + ">", qualifiedName("integer_value"), intLiteral(1))))), + emptyList(), + emptyList(), + emptyList() + ), + AstDSL.alias("count(string_value) filter(where integer_value > 1)", filteredAggregate( + "count", qualifiedName("string_value"), function( + ">", qualifiedName("integer_value"), intLiteral(1)))) + ) + ); + } } diff --git a/integ-test/src/test/java/org/opensearch/sql/sql/AggregationIT.java b/integ-test/src/test/java/org/opensearch/sql/sql/AggregationIT.java new file mode 100644 index 00000000000..3cbb222afe1 --- /dev/null +++ b/integ-test/src/test/java/org/opensearch/sql/sql/AggregationIT.java @@ -0,0 +1,40 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ * + */ + +package org.opensearch.sql.sql; + +import static org.opensearch.sql.legacy.TestsConstants.TEST_INDEX_BANK; +import static org.opensearch.sql.util.MatcherUtils.rows; +import static org.opensearch.sql.util.MatcherUtils.schema; +import static org.opensearch.sql.util.MatcherUtils.verifyDataRows; +import static org.opensearch.sql.util.MatcherUtils.verifySchema; + +import java.io.IOException; +import org.json.JSONObject; +import org.junit.jupiter.api.Test; +import org.opensearch.sql.legacy.SQLIntegTestCase; + +public class AggregationIT extends SQLIntegTestCase { + @Override + protected void init() throws Exception { + loadIndex(Index.BANK); + } + + @Test + void filteredAggregateWithSubquery() throws IOException { + JSONObject response = executeQuery( + "SELECT COUNT(*) FILTER(WHERE age > 35) FROM (SELECT * FROM " + TEST_INDEX_BANK + + ") AS a"); + verifySchema(response, schema("COUNT(*)", null, "integer")); + verifyDataRows(response, rows(3)); + } +} diff --git a/integ-test/src/test/resources/correctness/queries/subquries.txt b/integ-test/src/test/resources/correctness/queries/subqueries.txt similarity index 100% rename from integ-test/src/test/resources/correctness/queries/subquries.txt rename to integ-test/src/test/resources/correctness/queries/subqueries.txt From 94a045f742eeec9213f4c194347ff75ab34b75f2 Mon Sep 17 00:00:00 2001 From: Chloe Zhang Date: Fri, 11 Jun 2021 16:47:17 -0700 Subject: [PATCH 13/23] update Signed-off-by: chloe-zh --- .../src/test/resources/correctness/queries/aggregation.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integ-test/src/test/resources/correctness/queries/aggregation.txt b/integ-test/src/test/resources/correctness/queries/aggregation.txt index fa543c9c20b..b3dcc11bace 100644 --- a/integ-test/src/test/resources/correctness/queries/aggregation.txt +++ b/integ-test/src/test/resources/correctness/queries/aggregation.txt @@ -11,5 +11,5 @@ SELECT VAR_SAMP(AvgTicketPrice) FROM 
opensearch_dashboards_sample_data_flights SELECT STDDEV_POP(AvgTicketPrice) FROM opensearch_dashboards_sample_data_flights SELECT STDDEV_SAMP(AvgTicketPrice) FROM opensearch_dashboards_sample_data_flights SELECT COUNT(DISTINCT Origin), COUNT(DISTINCT Dest) FROM opensearch_dashboards_sample_data_flights -SELECT COUNT(DISTINCT Origin) FILTER(WHERE AvgTicketPrice < 1000) +SELECT COUNT(DISTINCT Origin) FILTER(WHERE AvgTicketPrice < 1000) FROM opensearch_dashboards_sample_data_flights SELECT COUNT(DISTINCT Origin) FROM (SELECT * FROM opensearch_dashboards_sample_data_flights) AS flights \ No newline at end of file From 11a9758f611b7974ef9166799695bc1fbb295a39 Mon Sep 17 00:00:00 2001 From: Chloe Zhang Date: Mon, 14 Jun 2021 11:21:01 -0700 Subject: [PATCH 14/23] modified comparison test Signed-off-by: chloe-zh --- .../src/test/resources/correctness/queries/aggregation.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integ-test/src/test/resources/correctness/queries/aggregation.txt b/integ-test/src/test/resources/correctness/queries/aggregation.txt index b3dcc11bace..d3bc194e2e0 100644 --- a/integ-test/src/test/resources/correctness/queries/aggregation.txt +++ b/integ-test/src/test/resources/correctness/queries/aggregation.txt @@ -11,5 +11,5 @@ SELECT VAR_SAMP(AvgTicketPrice) FROM opensearch_dashboards_sample_data_flights SELECT STDDEV_POP(AvgTicketPrice) FROM opensearch_dashboards_sample_data_flights SELECT STDDEV_SAMP(AvgTicketPrice) FROM opensearch_dashboards_sample_data_flights SELECT COUNT(DISTINCT Origin), COUNT(DISTINCT Dest) FROM opensearch_dashboards_sample_data_flights -SELECT COUNT(DISTINCT Origin) FILTER(WHERE AvgTicketPrice < 1000) FROM opensearch_dashboards_sample_data_flights +SELECT COUNT(DISTINCT Origin) FILTER(WHERE AvgTicketPrice < 1000) as distinct_count FROM opensearch_dashboards_sample_data_flights SELECT COUNT(DISTINCT Origin) FROM (SELECT * FROM opensearch_dashboards_sample_data_flights) AS flights \ No newline at end of file 
From d5dc9eb93ea84810253ba1ab4a732359c73fff3e Mon Sep 17 00:00:00 2001 From: chloe-zh Date: Tue, 15 Jun 2021 16:06:27 -0700 Subject: [PATCH 15/23] removed a comparison test and added it to aggregationIT Signed-off-by: chloe-zh --- .../java/org/opensearch/sql/sql/AggregationIT.java | 10 +++++++++- .../test/resources/correctness/queries/aggregation.txt | 1 - 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/integ-test/src/test/java/org/opensearch/sql/sql/AggregationIT.java b/integ-test/src/test/java/org/opensearch/sql/sql/AggregationIT.java index 3cbb222afe1..33cddc6f1f9 100644 --- a/integ-test/src/test/java/org/opensearch/sql/sql/AggregationIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/sql/AggregationIT.java @@ -30,7 +30,15 @@ protected void init() throws Exception { } @Test - void filteredAggregateWithSubquery() throws IOException { + void filteredAggregatePushedDown() throws IOException { + JSONObject response = executeQuery( + "SELECT COUNT(*) FILTER(WHERE age > 35) FROM " + TEST_INDEX_BANK); + verifySchema(response, schema("COUNT(*)", null, "integer")); + verifyDataRows(response, rows(3)); + } + + @Test + void filteredAggregateNotPushedDown() throws IOException { JSONObject response = executeQuery( "SELECT COUNT(*) FILTER(WHERE age > 35) FROM (SELECT * FROM " + TEST_INDEX_BANK + ") AS a"); diff --git a/integ-test/src/test/resources/correctness/queries/aggregation.txt b/integ-test/src/test/resources/correctness/queries/aggregation.txt index d3bc194e2e0..0c0648a9371 100644 --- a/integ-test/src/test/resources/correctness/queries/aggregation.txt +++ b/integ-test/src/test/resources/correctness/queries/aggregation.txt @@ -11,5 +11,4 @@ SELECT VAR_SAMP(AvgTicketPrice) FROM opensearch_dashboards_sample_data_flights SELECT STDDEV_POP(AvgTicketPrice) FROM opensearch_dashboards_sample_data_flights SELECT STDDEV_SAMP(AvgTicketPrice) FROM opensearch_dashboards_sample_data_flights SELECT COUNT(DISTINCT Origin), COUNT(DISTINCT Dest) FROM 
opensearch_dashboards_sample_data_flights -SELECT COUNT(DISTINCT Origin) FILTER(WHERE AvgTicketPrice < 1000) as distinct_count FROM opensearch_dashboards_sample_data_flights SELECT COUNT(DISTINCT Origin) FROM (SELECT * FROM opensearch_dashboards_sample_data_flights) AS flights \ No newline at end of file From 684a7421c72b31ccf51e42396fa526878898b48a Mon Sep 17 00:00:00 2001 From: chloe-zh Date: Tue, 15 Jun 2021 16:42:13 -0700 Subject: [PATCH 16/23] added ppl IT test cases; added window function test cases Signed-off-by: chloe-zh --- .../java/org/opensearch/sql/ppl/StatsCommandIT.java | 13 +++++++++++++ .../test/resources/correctness/queries/window.txt | 3 +++ .../aggregation/dsl/MetricAggregationBuilder.java | 2 +- 3 files changed, 17 insertions(+), 1 deletion(-) diff --git a/integ-test/src/test/java/org/opensearch/sql/ppl/StatsCommandIT.java b/integ-test/src/test/java/org/opensearch/sql/ppl/StatsCommandIT.java index ff3ad2a6c8b..4a9603fe6bd 100644 --- a/integ-test/src/test/java/org/opensearch/sql/ppl/StatsCommandIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/ppl/StatsCommandIT.java @@ -77,6 +77,19 @@ public void testStatsCountAll() throws IOException { verifyDataRows(response, rows(1000)); } + @Test + public void testStatsDistinctCount() throws IOException { + JSONObject response = + executeQuery(String.format("source=%s | stats distinct_count(gender)", TEST_INDEX_ACCOUNT)); + verifySchema(response, schema("distinct_count(gender)", null, "integer")); + verifyDataRows(response, rows(2)); + + response = + executeQuery(String.format("source=%s | stats dc(age)", TEST_INDEX_ACCOUNT)); + verifySchema(response, schema("dc(age)", null, "integer")); + verifyDataRows(response, rows(21)); + } + @Test public void testStatsMin() throws IOException { JSONObject response = executeQuery(String.format( diff --git a/integ-test/src/test/resources/correctness/queries/window.txt b/integ-test/src/test/resources/correctness/queries/window.txt index c3f27153229..07f74742323 
100644 --- a/integ-test/src/test/resources/correctness/queries/window.txt +++ b/integ-test/src/test/resources/correctness/queries/window.txt @@ -5,6 +5,7 @@ SELECT DistanceMiles, ROW_NUMBER() OVER (ORDER BY DistanceMiles DESC) AS num FRO SELECT DistanceMiles, RANK() OVER (ORDER BY DistanceMiles DESC) AS rnk FROM opensearch_dashboards_sample_data_flights SELECT DistanceMiles, DENSE_RANK() OVER (ORDER BY DistanceMiles DESC) AS rnk FROM opensearch_dashboards_sample_data_flights SELECT DistanceMiles, COUNT(DistanceMiles) OVER () AS num FROM opensearch_dashboards_sample_data_flights +SELECT DistanceMiles, COUNT(DISTINCT DistanceMiles) OVER () AS num FROM opensearch_dashboards_sample_data_flights SELECT DistanceMiles, SUM(DistanceMiles) OVER () AS num FROM opensearch_dashboards_sample_data_flights SELECT DistanceMiles, AVG(DistanceMiles) OVER () AS num FROM opensearch_dashboards_sample_data_flights SELECT DistanceMiles, MAX(DistanceMiles) OVER () AS num FROM opensearch_dashboards_sample_data_flights @@ -24,6 +25,7 @@ SELECT FlightDelayMin, AvgTicketPrice, VAR_SAMP(AvgTicketPrice) OVER (ORDER BY F SELECT user, RANK() OVER (ORDER BY user) AS rnk FROM opensearch_dashboards_sample_data_ecommerce SELECT user, DENSE_RANK() OVER (ORDER BY user) AS rnk FROM opensearch_dashboards_sample_data_ecommerce SELECT user, COUNT(day_of_week_i) OVER (ORDER BY user) AS cnt FROM opensearch_dashboards_sample_data_ecommerce +SELECT user, COUNT(DISTINCT day_of_week_i) OVER (ORDER BY user) AS cnt FROM opensearch_dashboards_sample_data_ecommerce SELECT user, SUM(day_of_week_i) OVER (ORDER BY user) AS num FROM opensearch_dashboards_sample_data_ecommerce SELECT user, AVG(day_of_week_i) OVER (ORDER BY user) AS num FROM opensearch_dashboards_sample_data_ecommerce SELECT user, MAX(day_of_week_i) OVER (ORDER BY user) AS num FROM opensearch_dashboards_sample_data_ecommerce @@ -33,6 +35,7 @@ SELECT user, VAR_POP(day_of_week_i) OVER (ORDER BY user) AS num FROM opensearch_ SELECT user, RANK() OVER (ORDER 
BY user DESC) AS rnk FROM opensearch_dashboards_sample_data_ecommerce SELECT user, DENSE_RANK() OVER (ORDER BY user DESC) AS rnk FROM opensearch_dashboards_sample_data_ecommerce SELECT user, COUNT(day_of_week_i) OVER (PARTITION BY user ORDER BY order_id) AS cnt FROM opensearch_dashboards_sample_data_ecommerce +SELECT user, COUNT(DISTINCT day_of_week_i) OVER (PARTITION BY user ORDER BY order_id) AS cnt FROM opensearch_dashboards_sample_data_ecommerce SELECT user, SUM(day_of_week_i) OVER (PARTITION BY user ORDER BY order_id) AS num FROM opensearch_dashboards_sample_data_ecommerce SELECT user, AVG(day_of_week_i) OVER (PARTITION BY user ORDER BY order_id) AS num FROM opensearch_dashboards_sample_data_ecommerce SELECT user, MAX(day_of_week_i) OVER (PARTITION BY user ORDER BY order_id) AS num FROM opensearch_dashboards_sample_data_ecommerce diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java index 4641fd134ff..7a321b4fce2 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java @@ -38,8 +38,8 @@ import org.opensearch.search.aggregations.AggregationBuilders; import org.opensearch.search.aggregations.AggregatorFactories; import org.opensearch.search.aggregations.bucket.filter.FilterAggregationBuilder; -import org.opensearch.search.aggregations.metrics.ExtendedStats; import org.opensearch.search.aggregations.metrics.CardinalityAggregationBuilder; +import org.opensearch.search.aggregations.metrics.ExtendedStats; import org.opensearch.search.aggregations.support.ValuesSourceAggregationBuilder; import org.opensearch.sql.expression.Expression; import 
org.opensearch.sql.expression.ExpressionNodeVisitor; From c750f5979096c730a87c5f46550e75d9854146c9 Mon Sep 17 00:00:00 2001 From: chloe-zh Date: Wed, 16 Jun 2021 11:09:57 -0700 Subject: [PATCH 17/23] moved distinct window function test cases to WindowsIT Signed-off-by: chloe-zh --- .../opensearch/sql/sql/WindowFunctionIT.java | 48 +++++++++++++++++++ .../resources/correctness/queries/window.txt | 3 -- 2 files changed, 48 insertions(+), 3 deletions(-) diff --git a/integ-test/src/test/java/org/opensearch/sql/sql/WindowFunctionIT.java b/integ-test/src/test/java/org/opensearch/sql/sql/WindowFunctionIT.java index b92ca17238c..52373a72e32 100644 --- a/integ-test/src/test/java/org/opensearch/sql/sql/WindowFunctionIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/sql/WindowFunctionIT.java @@ -29,6 +29,8 @@ import static org.opensearch.sql.util.MatcherUtils.rows; import static org.opensearch.sql.util.MatcherUtils.verifyDataRows; +import static org.opensearch.sql.util.MatcherUtils.verifyDataRowsInOrder; + import org.json.JSONObject; import org.junit.Test; @@ -40,6 +42,7 @@ public class WindowFunctionIT extends SQLIntegTestCase { @Override protected void init() throws Exception { loadIndex(Index.BANK_WITH_NULL_VALUES); + loadIndex(Index.BANK); } @Test @@ -74,4 +77,49 @@ public void testOrderByNullLast() { rows(null, 7)); } + @Test + public void testDistinctCountOverNull() { + JSONObject response = new JSONObject(executeQuery( + "SELECT lastname, COUNT(DISTINCT gender) OVER() " + + "FROM " + TestsConstants.TEST_INDEX_BANK, "jdbc")); + verifyDataRows(response, + rows("Duke Willmington", 2), + rows("Bond", 2), + rows("Bates", 2), + rows("Adams", 2), + rows("Ratliff", 2), + rows("Ayala", 2), + rows("Mcpherson", 2)); + } + + @Test + public void testDistinctCountOver() { + JSONObject response = new JSONObject(executeQuery( + "SELECT lastname, COUNT(DISTINCT gender) OVER(ORDER BY lastname) " + + "FROM " + TestsConstants.TEST_INDEX_BANK, "jdbc")); + 
verifyDataRowsInOrder(response, + rows("Adams", 1), + rows("Ayala", 2), + rows("Bates", 2), + rows("Bond", 2), + rows("Duke Willmington", 2), + rows("Mcpherson", 2), + rows("Ratliff", 2)); + } + + @Test + public void testDistinctCountPartition() { + JSONObject response = new JSONObject(executeQuery( + "SELECT lastname, COUNT(DISTINCT gender) OVER(PARTITION BY gender ORDER BY lastname) " + + "FROM " + TestsConstants.TEST_INDEX_BANK, "jdbc")); + verifyDataRowsInOrder(response, + rows("Ayala", 1), + rows("Bates", 1), + rows("Mcpherson", 1), + rows("Adams", 1), + rows("Bond", 1), + rows("Duke Willmington", 1), + rows("Ratliff", 1)); + } + } diff --git a/integ-test/src/test/resources/correctness/queries/window.txt b/integ-test/src/test/resources/correctness/queries/window.txt index 07f74742323..c3f27153229 100644 --- a/integ-test/src/test/resources/correctness/queries/window.txt +++ b/integ-test/src/test/resources/correctness/queries/window.txt @@ -5,7 +5,6 @@ SELECT DistanceMiles, ROW_NUMBER() OVER (ORDER BY DistanceMiles DESC) AS num FRO SELECT DistanceMiles, RANK() OVER (ORDER BY DistanceMiles DESC) AS rnk FROM opensearch_dashboards_sample_data_flights SELECT DistanceMiles, DENSE_RANK() OVER (ORDER BY DistanceMiles DESC) AS rnk FROM opensearch_dashboards_sample_data_flights SELECT DistanceMiles, COUNT(DistanceMiles) OVER () AS num FROM opensearch_dashboards_sample_data_flights -SELECT DistanceMiles, COUNT(DISTINCT DistanceMiles) OVER () AS num FROM opensearch_dashboards_sample_data_flights SELECT DistanceMiles, SUM(DistanceMiles) OVER () AS num FROM opensearch_dashboards_sample_data_flights SELECT DistanceMiles, AVG(DistanceMiles) OVER () AS num FROM opensearch_dashboards_sample_data_flights SELECT DistanceMiles, MAX(DistanceMiles) OVER () AS num FROM opensearch_dashboards_sample_data_flights @@ -25,7 +24,6 @@ SELECT FlightDelayMin, AvgTicketPrice, VAR_SAMP(AvgTicketPrice) OVER (ORDER BY F SELECT user, RANK() OVER (ORDER BY user) AS rnk FROM 
opensearch_dashboards_sample_data_ecommerce SELECT user, DENSE_RANK() OVER (ORDER BY user) AS rnk FROM opensearch_dashboards_sample_data_ecommerce SELECT user, COUNT(day_of_week_i) OVER (ORDER BY user) AS cnt FROM opensearch_dashboards_sample_data_ecommerce -SELECT user, COUNT(DISTINCT day_of_week_i) OVER (ORDER BY user) AS cnt FROM opensearch_dashboards_sample_data_ecommerce SELECT user, SUM(day_of_week_i) OVER (ORDER BY user) AS num FROM opensearch_dashboards_sample_data_ecommerce SELECT user, AVG(day_of_week_i) OVER (ORDER BY user) AS num FROM opensearch_dashboards_sample_data_ecommerce SELECT user, MAX(day_of_week_i) OVER (ORDER BY user) AS num FROM opensearch_dashboards_sample_data_ecommerce @@ -35,7 +33,6 @@ SELECT user, VAR_POP(day_of_week_i) OVER (ORDER BY user) AS num FROM opensearch_ SELECT user, RANK() OVER (ORDER BY user DESC) AS rnk FROM opensearch_dashboards_sample_data_ecommerce SELECT user, DENSE_RANK() OVER (ORDER BY user DESC) AS rnk FROM opensearch_dashboards_sample_data_ecommerce SELECT user, COUNT(day_of_week_i) OVER (PARTITION BY user ORDER BY order_id) AS cnt FROM opensearch_dashboards_sample_data_ecommerce -SELECT user, COUNT(DISTINCT day_of_week_i) OVER (PARTITION BY user ORDER BY order_id) AS cnt FROM opensearch_dashboards_sample_data_ecommerce SELECT user, SUM(day_of_week_i) OVER (PARTITION BY user ORDER BY order_id) AS num FROM opensearch_dashboards_sample_data_ecommerce SELECT user, AVG(day_of_week_i) OVER (PARTITION BY user ORDER BY order_id) AS num FROM opensearch_dashboards_sample_data_ecommerce SELECT user, MAX(day_of_week_i) OVER (PARTITION BY user ORDER BY order_id) AS num FROM opensearch_dashboards_sample_data_ecommerce From 9fa771d3389b87fc432b637e88320c2aeb8dd341 Mon Sep 17 00:00:00 2001 From: chloe-zh Date: Wed, 16 Jun 2021 14:05:30 -0700 Subject: [PATCH 18/23] added ut Signed-off-by: chloe-zh --- .../aggregation/CountAggregator.java | 10 ----- .../dsl/MetricAggregationBuilderTest.java | 37 ++++++++++++++++++- 2 files changed, 
36 insertions(+), 11 deletions(-) diff --git a/core/src/main/java/org/opensearch/sql/expression/aggregation/CountAggregator.java b/core/src/main/java/org/opensearch/sql/expression/aggregation/CountAggregator.java index 34d064fe46d..b9653796678 100644 --- a/core/src/main/java/org/opensearch/sql/expression/aggregation/CountAggregator.java +++ b/core/src/main/java/org/opensearch/sql/expression/aggregation/CountAggregator.java @@ -26,16 +26,6 @@ package org.opensearch.sql.expression.aggregation; -import static org.opensearch.sql.data.model.ExprValueUtils.getBooleanValue; -import static org.opensearch.sql.data.model.ExprValueUtils.getByteValue; -import static org.opensearch.sql.data.model.ExprValueUtils.getCollectionValue; -import static org.opensearch.sql.data.model.ExprValueUtils.getDoubleValue; -import static org.opensearch.sql.data.model.ExprValueUtils.getFloatValue; -import static org.opensearch.sql.data.model.ExprValueUtils.getIntegerValue; -import static org.opensearch.sql.data.model.ExprValueUtils.getLongValue; -import static org.opensearch.sql.data.model.ExprValueUtils.getShortValue; -import static org.opensearch.sql.data.model.ExprValueUtils.getStringValue; -import static org.opensearch.sql.data.model.ExprValueUtils.getTupleValue; import static org.opensearch.sql.utils.ExpressionUtils.format; import java.util.HashSet; diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilderTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilderTest.java index 1e157139454..129814d45fb 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilderTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilderTest.java @@ -53,18 +53,21 @@ import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; import 
org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.sql.expression.DSL; import org.opensearch.sql.expression.aggregation.AvgAggregator; import org.opensearch.sql.expression.aggregation.CountAggregator; import org.opensearch.sql.expression.aggregation.MaxAggregator; import org.opensearch.sql.expression.aggregation.MinAggregator; import org.opensearch.sql.expression.aggregation.NamedAggregator; import org.opensearch.sql.expression.aggregation.SumAggregator; +import org.opensearch.sql.expression.config.ExpressionConfig; import org.opensearch.sql.expression.function.FunctionName; import org.opensearch.sql.opensearch.storage.serialization.ExpressionSerializer; @DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) @ExtendWith(MockitoExtension.class) class MetricAggregationBuilderTest { + private final DSL dsl = new ExpressionConfig().dsl(new ExpressionConfig().functionRepository()); @Mock private ExpressionSerializer serializer; @@ -271,7 +274,39 @@ void should_build_cardinality_aggregation() { + "}", buildQuery( Collections.singletonList(named("count(distinct name)", new CountAggregator( - Collections.singletonList(ref("name", STRING)), STRING).distinct(true))))); + Collections.singletonList(ref("name", STRING)), INTEGER).distinct(true))))); + } + + @Test + void should_build_filtered_cardinality_aggregation() { + assertEquals( + "{\n" + + " \"count(distinct name) filter(where age > 30)\" : {\n" + + " \"filter\" : {\n" + + " \"range\" : {\n" + + " \"age\" : {\n" + + " \"from\" : 30,\n" + + " \"to\" : null,\n" + + " \"include_lower\" : false,\n" + + " \"include_upper\" : true,\n" + + " \"boost\" : 1.0\n" + + " }\n" + + " }\n" + + " },\n" + + " \"aggregations\" : {\n" + + " \"count(distinct name) filter(where age > 30)\" : {\n" + + " \"cardinality\" : {\n" + + " \"field\" : \"name\"\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + "}", + buildQuery(Collections.singletonList(named( + "count(distinct name) filter(where age > 30)", + new 
CountAggregator(Collections.singletonList(ref("name", STRING)), INTEGER) + .condition(dsl.greater(ref("age", INTEGER), literal(30))) + .distinct(true))))); } @Test From 5d42554fc70ed755c5d50d53c332cef71b70d897 Mon Sep 17 00:00:00 2001 From: Chloe Zhang Date: Wed, 16 Jun 2021 14:32:17 -0700 Subject: [PATCH 19/23] update Signed-off-by: chloe-zh --- .../sql/sql/parser/AstAggregationBuilderTest.java | 7 ------- 1 file changed, 7 deletions(-) diff --git a/sql/src/test/java/org/opensearch/sql/sql/parser/AstAggregationBuilderTest.java b/sql/src/test/java/org/opensearch/sql/sql/parser/AstAggregationBuilderTest.java index 437e8953fac..44c84495c23 100644 --- a/sql/src/test/java/org/opensearch/sql/sql/parser/AstAggregationBuilderTest.java +++ b/sql/src/test/java/org/opensearch/sql/sql/parser/AstAggregationBuilderTest.java @@ -178,13 +178,6 @@ void can_build_distinct_aggregator() { hasAggregators( alias("COUNT(DISTINCT name)", distinctAggregate("COUNT", qualifiedName( "name")))))); - - assertThat( - buildAggregation("SELECT COUNT(DISTINCT *) FROM test"), - allOf( - hasGroupByItems(), - hasAggregators( - alias("COUNT(DISTINCT *)", distinctAggregate("COUNT", AllFields.of()))))); } @Test From 80a4c611b6931223e39ed2443a44e9338a3a6bba Mon Sep 17 00:00:00 2001 From: chloe-zh Date: Thu, 17 Jun 2021 11:01:42 -0700 Subject: [PATCH 20/23] update Signed-off-by: chloe-zh --- .../sql/expression/aggregation/CountAggregator.java | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/core/src/main/java/org/opensearch/sql/expression/aggregation/CountAggregator.java b/core/src/main/java/org/opensearch/sql/expression/aggregation/CountAggregator.java index b9653796678..579622b546b 100644 --- a/core/src/main/java/org/opensearch/sql/expression/aggregation/CountAggregator.java +++ b/core/src/main/java/org/opensearch/sql/expression/aggregation/CountAggregator.java @@ -74,7 +74,7 @@ protected static class CountState implements AggregationState { public void count(ExprValue 
value, Boolean distinct) { if (distinct) { - if (!duplicated(value)) { + if (!distinctValues.contains(value)) { distinctValues.add(value); count++; } @@ -83,15 +83,6 @@ public void count(ExprValue value, Boolean distinct) { } } - private boolean duplicated(ExprValue value) { - for (ExprValue exprValue : distinctValues) { - if (value.compareTo(exprValue) == 0) { - return true; - } - } - return false; - } - @Override public ExprValue result() { return ExprValueUtils.integerValue(count); From 86db16d961bec757fb7bd1d98ae70385eff19084 Mon Sep 17 00:00:00 2001 From: Chloe Zhang Date: Mon, 26 Jul 2021 15:46:53 -0700 Subject: [PATCH 21/23] addressed comments Signed-off-by: chloe-zh --- .../aggregation/CountAggregator.java | 34 +++++++++++-------- .../dsl/AggregationBuilderHelper.java | 6 ++-- .../dsl/BucketAggregationBuilder.java | 8 ++--- .../dsl/MetricAggregationBuilder.java | 34 +++++++++---------- 4 files changed, 43 insertions(+), 39 deletions(-) diff --git a/core/src/main/java/org/opensearch/sql/expression/aggregation/CountAggregator.java b/core/src/main/java/org/opensearch/sql/expression/aggregation/CountAggregator.java index 579622b546b..f1bd088967b 100644 --- a/core/src/main/java/org/opensearch/sql/expression/aggregation/CountAggregator.java +++ b/core/src/main/java/org/opensearch/sql/expression/aggregation/CountAggregator.java @@ -47,40 +47,34 @@ public CountAggregator(List arguments, ExprCoreType returnType) { @Override public CountAggregator.CountState create() { - return new CountState(); + return distinct ? new DistinctCountState() : new CountState(); } @Override protected CountState iterate(ExprValue value, CountState state) { - state.count(value, distinct); + state.count(value); return state; } @Override public String toString() { - return String.format(Locale.ROOT, "count(%s)", format(getArguments())); + return distinct + ? 
String.format(Locale.ROOT, "count(distinct %s)", format(getArguments())) + : String.format(Locale.ROOT, "count(%s)", format(getArguments())); } /** * Count State. */ protected static class CountState implements AggregationState { - private int count; - private final Set distinctValues = new HashSet<>(); + protected int count; CountState() { this.count = 0; } - public void count(ExprValue value, Boolean distinct) { - if (distinct) { - if (!distinctValues.contains(value)) { - distinctValues.add(value); - count++; - } - } else { - count++; - } + public void count(ExprValue value) { + count++; } @Override @@ -88,4 +82,16 @@ public ExprValue result() { return ExprValueUtils.integerValue(count); } } + + protected static class DistinctCountState extends CountState { + private final Set distinctValues = new HashSet<>(); + + @Override + public void count(ExprValue value) { + if (!distinctValues.contains(value)) { + distinctValues.add(value); + count++; + } + } + } } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/AggregationBuilderHelper.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/AggregationBuilderHelper.java index 73d58d793ee..cd793c9046f 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/AggregationBuilderHelper.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/AggregationBuilderHelper.java @@ -43,11 +43,9 @@ /** * Abstract Aggregation Builder. - * - * @param type of the actual AggregationBuilder to be built. 
*/ @RequiredArgsConstructor -public class AggregationBuilderHelper { +public class AggregationBuilderHelper { private final ExpressionSerializer serializer; @@ -57,7 +55,7 @@ public class AggregationBuilderHelper { * @param expression Expression * @return AggregationBuilder */ - public T build(Expression expression, Function fieldBuilder, + public T build(Expression expression, Function fieldBuilder, Function scriptBuilder) { if (expression instanceof ReferenceExpression) { String fieldName = ((ReferenceExpression) expression).getAttr(); diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/BucketAggregationBuilder.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/BucketAggregationBuilder.java index b1aff2c5b4e..d137cce75d5 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/BucketAggregationBuilder.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/BucketAggregationBuilder.java @@ -42,11 +42,11 @@ */ public class BucketAggregationBuilder { - private final AggregationBuilderHelper> helper; + private final AggregationBuilderHelper helper; public BucketAggregationBuilder( ExpressionSerializer serializer) { - this.helper = new AggregationBuilderHelper<>(serializer); + this.helper = new AggregationBuilderHelper(serializer); } /** @@ -62,8 +62,8 @@ public List> build( .missingBucket(true) .order(groupPair.getRight()); resultBuilder - .add(helper.build(groupPair.getLeft().getDelegated(), valuesSourceBuilder::field, - valuesSourceBuilder::script)); + .add((CompositeValuesSourceBuilder) helper.build(groupPair.getLeft().getDelegated(), + valuesSourceBuilder::field, valuesSourceBuilder::script)); } return resultBuilder.build(); } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java 
b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java index 7a321b4fce2..754da498629 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java @@ -59,16 +59,14 @@ public class MetricAggregationBuilder extends ExpressionNodeVisitor, Object> { - private final AggregationBuilderHelper> valuesSourceAggHelper; - private final AggregationBuilderHelper cardinalityAggHelper; + private final AggregationBuilderHelper helper; private final FilterQueryBuilder filterBuilder; /** * Constructor. */ public MetricAggregationBuilder(ExpressionSerializer serializer) { - this.valuesSourceAggHelper = new AggregationBuilderHelper<>(serializer); - this.cardinalityAggHelper = new AggregationBuilderHelper<>(serializer); + this.helper = new AggregationBuilderHelper(serializer); this.filterBuilder = new FilterQueryBuilder(serializer); } @@ -97,22 +95,24 @@ public Pair visitNamedAggregator( Expression condition = node.getDelegated().condition(); Boolean distinct = node.getDelegated().distinct(); String name = node.getName(); + String functionName = node.getFunctionName().getFunctionName().toLowerCase(Locale.ROOT); if (distinct) { - if ("count".equals(node.getFunctionName().getFunctionName().toLowerCase(Locale.ROOT))) { - return make( - AggregationBuilders.cardinality(name), - expression, - condition, - name, - new SingleValueParser(name)); - } else { - throw new IllegalStateException(String.format( - "unsupported distinct aggregator %s", node.getFunctionName().getFunctionName())); + switch (functionName) { + case "count": + return make( + AggregationBuilders.cardinality(name), + expression, + condition, + name, + new SingleValueParser(name)); + default: + throw new IllegalStateException(String.format( + "unsupported distinct aggregator %s", 
node.getFunctionName().getFunctionName())); } } - switch (node.getFunctionName().getFunctionName()) { + switch (functionName) { case "avg": return make( AggregationBuilders.avg(name), @@ -189,7 +189,7 @@ private Pair make( String name, MetricParser parser) { ValuesSourceAggregationBuilder aggregationBuilder = - valuesSourceAggHelper.build(expression, builder::field, builder::script); + helper.build(expression, builder::field, builder::script); if (condition != null) { return Pair.of( makeFilterAggregation(aggregationBuilder, condition, name), @@ -207,7 +207,7 @@ private Pair make(CardinalityAggregationBuilde String name, MetricParser parser) { CardinalityAggregationBuilder aggregationBuilder = - cardinalityAggHelper.build(expression, builder::field, builder::script); + helper.build(expression, builder::field, builder::script); if (condition != null) { return Pair.of( makeFilterAggregation(aggregationBuilder, condition, name), From f5cece5e6875b14cb9f72ab54cd11043527767ba Mon Sep 17 00:00:00 2001 From: Chloe Zhang Date: Mon, 26 Jul 2021 15:59:35 -0700 Subject: [PATCH 22/23] added test cases to meet the coverage requirement Signed-off-by: chloe-zh --- .../sql/expression/aggregation/CountAggregatorTest.java | 3 +++ 1 file changed, 3 insertions(+) diff --git a/core/src/test/java/org/opensearch/sql/expression/aggregation/CountAggregatorTest.java b/core/src/test/java/org/opensearch/sql/expression/aggregation/CountAggregatorTest.java index 73bb37a3daf..ec7c22de011 100644 --- a/core/src/test/java/org/opensearch/sql/expression/aggregation/CountAggregatorTest.java +++ b/core/src/test/java/org/opensearch/sql/expression/aggregation/CountAggregatorTest.java @@ -181,6 +181,9 @@ public void valueOf() { public void test_to_string() { Aggregator countAggregator = dsl.count(DSL.ref("integer_value", INTEGER)); assertEquals("count(integer_value)", countAggregator.toString()); + + countAggregator = dsl.distinctCount(DSL.ref("integer_value", INTEGER)); + assertEquals("count(distinct 
integer_value)", countAggregator.toString()); } @Test From 832e0190f404691937181c2ad474f6a6130695d0 Mon Sep 17 00:00:00 2001 From: chloe-zh Date: Tue, 27 Jul 2021 12:59:08 -0700 Subject: [PATCH 23/23] added test cases for distinct count map and array types Signed-off-by: chloe-zh --- .../aggregation/AggregationTest.java | 24 +++++++++++++++---- .../aggregation/CountAggregatorTest.java | 14 +++++++++++ 2 files changed, 34 insertions(+), 4 deletions(-) diff --git a/core/src/test/java/org/opensearch/sql/expression/aggregation/AggregationTest.java b/core/src/test/java/org/opensearch/sql/expression/aggregation/AggregationTest.java index 2cce9018bf8..1db33ac9d5f 100644 --- a/core/src/test/java/org/opensearch/sql/expression/aggregation/AggregationTest.java +++ b/core/src/test/java/org/opensearch/sql/expression/aggregation/AggregationTest.java @@ -118,10 +118,26 @@ public class AggregationTest extends ExpressionTestBase { protected static List tuples_with_duplicates = Arrays.asList( - ExprValueUtils.tupleValue(ImmutableMap.of("integer_value", 1, "double_value", 4d)), - ExprValueUtils.tupleValue(ImmutableMap.of("integer_value", 1, "double_value", 3d)), - ExprValueUtils.tupleValue(ImmutableMap.of("integer_value", 2, "double_value", 2d)), - ExprValueUtils.tupleValue(ImmutableMap.of("integer_value", 3, "double_value", 1d))); + ExprValueUtils.tupleValue(ImmutableMap.of( + "integer_value", 1, + "double_value", 4d, + "struct_value", ImmutableMap.of("str", 1), + "array_value", ImmutableList.of(1))), + ExprValueUtils.tupleValue(ImmutableMap.of( + "integer_value", 1, + "double_value", 3d, + "struct_value", ImmutableMap.of("str", 1), + "array_value", ImmutableList.of(1))), + ExprValueUtils.tupleValue(ImmutableMap.of( + "integer_value", 2, + "double_value", 2d, + "struct_value", ImmutableMap.of("str", 2), + "array_value", ImmutableList.of(2))), + ExprValueUtils.tupleValue(ImmutableMap.of( + "integer_value", 3, + "double_value", 1d, + "struct_value", ImmutableMap.of("str1", 1), + 
"array_value", ImmutableList.of(1, 2)))); protected static List tuples_with_null_and_missing = Arrays.asList( diff --git a/core/src/test/java/org/opensearch/sql/expression/aggregation/CountAggregatorTest.java b/core/src/test/java/org/opensearch/sql/expression/aggregation/CountAggregatorTest.java index ec7c22de011..ee183dafcef 100644 --- a/core/src/test/java/org/opensearch/sql/expression/aggregation/CountAggregatorTest.java +++ b/core/src/test/java/org/opensearch/sql/expression/aggregation/CountAggregatorTest.java @@ -144,6 +144,20 @@ public void filtered_distinct_count() { assertEquals(2, result.value()); } + @Test + public void distinct_count_map() { + ExprValue result = aggregation(dsl.distinctCount(DSL.ref("struct_value", STRUCT)), + tuples_with_duplicates); + assertEquals(3, result.value()); + } + + @Test + public void distinct_count_array() { + ExprValue result = aggregation(dsl.distinctCount(DSL.ref("array_value", ARRAY)), + tuples_with_duplicates); + assertEquals(3, result.value()); + } + @Test public void count_with_missing() { ExprValue result = aggregation(dsl.count(DSL.ref("integer_value", INTEGER)),