Skip to content
Closed
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ public Expression visitAggregateFunction(AggregateFunction node, AnalysisContext
Expression arg = node.getField().accept(this, context);
Aggregator aggregator = (Aggregator) repository.compile(
builtinFunctionName.get().getName(), Collections.singletonList(arg));
aggregator.distinct(node.getDistinct());
if (node.getCondition() != null) {
aggregator.condition(analyze(node.getCondition(), context));
}
Expand Down
4 changes: 4 additions & 0 deletions core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,10 @@ public static UnresolvedExpression filteredAggregate(
return new AggregateFunction(func, field, condition);
}

/**
 * Builds an aggregate function node with its distinct flag set.
 *
 * @param func aggregate function name.
 * @param field expression the aggregate function is applied to.
 * @return an {@link AggregateFunction} constructed with {@code distinct = true}.
 */
public static UnresolvedExpression distinctAggregate(String func, UnresolvedExpression field) {
  final Boolean distinct = Boolean.TRUE;
  return new AggregateFunction(func, field, distinct);
}

/**
* Builds a {@link Function} node from a function name and its arguments.
*
* @param funcName function name.
* @param funcArgs function arguments, handed to the node as a list in the given order.
* @return the constructed {@link Function}.
*/
public static Function function(String funcName, UnresolvedExpression... funcArgs) {
return new Function(funcName, Arrays.asList(funcArgs));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ public class AggregateFunction extends UnresolvedExpression {
private final UnresolvedExpression field;
private final List<UnresolvedExpression> argList;
private UnresolvedExpression condition;
private Boolean distinct = false;

/**
* Constructor.
Expand All @@ -72,6 +73,19 @@ public AggregateFunction(String funcName, UnresolvedExpression field,
this.condition = condition;
}

/**
 * Constructor.
 *
 * @param funcName function name.
 * @param field {@link UnresolvedExpression}.
 * @param distinct whether only distinct values of the field are aggregated;
 *                 {@code null} is treated as {@code false}.
 */
public AggregateFunction(String funcName, UnresolvedExpression field, Boolean distinct) {
  this.funcName = funcName;
  this.field = field;
  this.argList = Collections.emptyList();
  // Normalize null to false so that code which auto-unboxes the flag
  // (e.g. "distinct && ...") cannot fail with a NullPointerException.
  this.distinct = Boolean.TRUE.equals(distinct);
}

@Override
public List<UnresolvedExpression> getChild() {
return Collections.singletonList(field);
Expand Down
20 changes: 20 additions & 0 deletions core/src/main/java/org/opensearch/sql/expression/DSL.java
Original file line number Diff line number Diff line change
Expand Up @@ -492,14 +492,26 @@ public Aggregator avg(Expression... expressions) {
return aggregate(BuiltinFunctionName.AVG, expressions);
}

/**
 * Builds the AVG aggregator for the given expressions and sets its distinct flag to true.
 */
public Aggregator distinctAvg(Expression... expressions) {
  Aggregator avgAggregator = avg(expressions);
  return avgAggregator.distinct(true);
}

/**
 * Builds an aggregator by delegating to {@code aggregate} with {@link BuiltinFunctionName#SUM}.
 */
public Aggregator sum(Expression... expressions) {
  final BuiltinFunctionName sumFunction = BuiltinFunctionName.SUM;
  return aggregate(sumFunction, expressions);
}

/**
 * Builds the SUM aggregator for the given expressions and sets its distinct flag to true.
 */
public Aggregator distinctSum(Expression... expressions) {
  Aggregator sumAggregator = sum(expressions);
  return sumAggregator.distinct(true);
}

/**
 * Builds an aggregator by delegating to {@code aggregate} with {@link BuiltinFunctionName#COUNT}.
 */
public Aggregator count(Expression... expressions) {
  final BuiltinFunctionName countFunction = BuiltinFunctionName.COUNT;
  return aggregate(countFunction, expressions);
}

/**
 * Builds the COUNT aggregator for the given expressions and sets its distinct flag to true.
 */
public Aggregator distinctCount(Expression... expressions) {
  Aggregator countAggregator = count(expressions);
  return countAggregator.distinct(true);
}

public RankingWindowFunction rowNumber() {
return (RankingWindowFunction) repository.compile(
BuiltinFunctionName.ROW_NUMBER.getName(), Collections.emptyList());
Expand All @@ -519,10 +531,18 @@ public Aggregator min(Expression... expressions) {
return aggregate(BuiltinFunctionName.MIN, expressions);
}

/**
 * Builds the MIN aggregator for the given expressions and sets its distinct flag to true.
 */
public Aggregator distinctMin(Expression... expressions) {
  Aggregator minAggregator = min(expressions);
  return minAggregator.distinct(true);
}

/**
 * Builds an aggregator by delegating to {@code aggregate} with {@link BuiltinFunctionName#MAX}.
 */
public Aggregator max(Expression... expressions) {
  final BuiltinFunctionName maxFunction = BuiltinFunctionName.MAX;
  return aggregate(maxFunction, expressions);
}

/**
 * Builds the MAX aggregator for the given expressions and sets its distinct flag to true.
 */
public Aggregator distinctMax(Expression... expressions) {
  Aggregator maxAggregator = max(expressions);
  return maxAggregator.distinct(true);
}

private FunctionExpression function(BuiltinFunctionName functionName, Expression... expressions) {
return (FunctionExpression) repository.compile(
functionName.getName(), Arrays.asList(expressions));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

package org.opensearch.sql.expression.aggregation;

import java.util.Set;
import org.opensearch.sql.data.model.ExprValue;
import org.opensearch.sql.storage.bindingtuple.BindingTuple;

Expand All @@ -37,4 +38,8 @@ public interface AggregationState {
* Get {@link ExprValue} result.
*/
ExprValue result();

/**
* Get the values this state exposes for duplicate detection in distinct aggregation.
* The default implementation returns an empty, immutable set; states that track
* values override it.
*/
default Set<ExprValue> distinctValues() {
return Set.of();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,12 @@ public abstract class Aggregator<S extends AggregationState>
@Getter
@Accessors(fluent = true)
protected Expression condition;
@Setter
@Getter
@Accessors(fluent = true)
protected Boolean distinct = false;



/**
* Create an {@link AggregationState} which will be used for aggregation.
Expand All @@ -89,7 +95,8 @@ public abstract class Aggregator<S extends AggregationState>
*/
public S iterate(BindingTuple tuple, S state) {
ExprValue value = getArguments().get(0).valueOf(tuple);
if (value.isNull() || value.isMissing() || !conditionValue(tuple)) {
if (value.isNull() || value.isMissing() || !conditionValue(tuple)
|| (distinct && duplicated(value, state))) {
return state;
}
return iterate(value, state);
Expand Down Expand Up @@ -121,4 +128,13 @@ public boolean conditionValue(BindingTuple tuple) {
return ExprValueUtils.getBooleanValue(condition.valueOf(tuple));
}

/**
 * Tells whether {@code value} compares equal (via {@code compareTo}) to any value the
 * state reports through {@link AggregationState#distinctValues()}.
 *
 * <p>NOTE(review): this walks the whole set for every call. If {@code ExprValue}'s
 * {@code equals}/{@code hashCode} agree with {@code compareTo}, a plain
 * {@code Set.contains} lookup would avoid the scan - confirm before changing.
 *
 * @param value candidate value.
 * @param state aggregation state holding the values seen so far.
 * @return true if an equal value is already present, false otherwise.
 */
private Boolean duplicated(ExprValue value, S state) {
  return state.distinctValues().stream()
      .anyMatch(seen -> seen.compareTo(value) == 0);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,10 @@

import static org.opensearch.sql.utils.ExpressionUtils.format;

import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Set;
import org.opensearch.sql.data.model.ExprValue;
import org.opensearch.sql.data.model.ExprValueUtils;
import org.opensearch.sql.data.type.ExprCoreType;
Expand All @@ -50,7 +52,7 @@ public CountAggregator.CountState create() {

@Override
protected CountState iterate(ExprValue value, CountState state) {
  if (Boolean.TRUE.equals(distinct)) {
    // Distinct counting: count the value and remember it so that later
    // occurrences can be recognised as duplicates.
    state.count(value);
  } else {
    // Plain counting: bump the counter exactly once. Remembering every value here
    // would only grow the state's set; duplicate detection consults it solely when
    // the distinct flag is enabled.
    state.count++;
  }
  return state;
}

Expand All @@ -64,14 +66,25 @@ public String toString() {
*/
protected static class CountState implements AggregationState {
  /** Values handed to {@link #count(ExprValue)} so far. */
  private final Set<ExprValue> seenValues = new HashSet<>();
  /** Number of values counted so far. */
  private int count = 0;

  CountState() {
    // Nothing to do: both fields are initialised at their declarations.
  }

  /**
   * Remembers the value and increases the counter by one.
   *
   * @param value value being counted.
   */
  public void count(ExprValue value) {
    seenValues.add(value);
    count += 1;
  }

  @Override
  public Set<ExprValue> distinctValues() {
    return seenValues;
  }

  @Override
  public ExprValue result() {
    return ExprValueUtils.integerValue(count);
  }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import static org.opensearch.sql.utils.ExpressionUtils.format;

import java.util.List;
import java.util.Set;
import org.opensearch.sql.data.model.ExprValue;
import org.opensearch.sql.data.type.ExprCoreType;
import org.opensearch.sql.expression.Expression;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@

import java.util.List;
import java.util.Locale;
import java.util.Set;
import org.opensearch.sql.data.model.ExprNullValue;
import org.opensearch.sql.data.model.ExprValue;
import org.opensearch.sql.data.model.ExprValueUtils;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,14 @@ public void aggregation_filter() {
);
}

@Test
public void distinct_aggregation() {
  // Expected resolved expression: COUNT over integer_value with the distinct flag set.
  Expression expected = dsl.distinctCount(DSL.ref("integer_value", INTEGER));
  // Unresolved AST counterpart that the analyzer has to turn into the expression above.
  UnresolvedExpression unresolved =
      AstDSL.distinctAggregate("count", qualifiedName("integer_value"));
  assertAnalyzeEqual(expected, unresolved);
}

/**
 * Runs the expression analyzer on the given unresolved expression using the test's
 * analysis context.
 */
protected Expression analyze(UnresolvedExpression unresolvedExpression) {
  Expression analyzed = expressionAnalyzer.analyze(unresolvedExpression, analysisContext);
  return analyzed;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,13 @@ public class AggregationTest extends ExpressionTestBase {
"timestamp_value",
"2040-01-01 07:00:00")));

// Four tuples whose integer_value is 1, 1, 2 and 3: three distinct values, one duplicate.
protected static List<ExprValue> tuples_with_duplicates =
Arrays.asList(
ExprValueUtils.tupleValue(ImmutableMap.of("integer_value", 1)),
ExprValueUtils.tupleValue(ImmutableMap.of("integer_value", 1)),
ExprValueUtils.tupleValue(ImmutableMap.of("integer_value", 2)),
ExprValueUtils.tupleValue(ImmutableMap.of("integer_value", 3)));

protected static List<ExprValue> tuples_with_null_and_missing =
Arrays.asList(
ExprValueUtils.tupleValue(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*
*/

package org.opensearch.sql.expression.aggregation;

import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;

import org.junit.jupiter.api.Test;
import org.opensearch.sql.data.model.ExprIntegerValue;

/**
 * Checks what aggregation states report through {@code distinctValues()}.
 */
public class AggregatorStateTest extends AggregationTest {

  /** A count state exposes the values it has been asked to count. */
  @Test
  void count_distinct_values() {
    final CountAggregator.CountState countState = new CountAggregator.CountState();
    final ExprIntegerValue one = new ExprIntegerValue(1);
    countState.count(one);
    assertFalse(countState.distinctValues().isEmpty());
  }

  /** A freshly created avg state reports no distinct values. */
  @Test
  void default_distinct_values() {
    final AvgAggregator.AvgState avgState = new AvgAggregator.AvgState();
    assertTrue(avgState.distinctValues().isEmpty());
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,13 @@ public void filtered_avg() {
assertEquals(3.0, result.value());
}

@Test
public void distinct_avg() {
// Evaluating a distinct AVG aggregator is expected to raise ExpressionEvaluationException.
// NOTE(review): the third argument of assertThrows is the message reported when this
// assertion fails; it is not compared with the thrown exception's message. To verify the
// exception text, assert on getMessage() of the exception returned by assertThrows.
assertThrows(ExpressionEvaluationException.class,
() -> dsl.distinctAvg(DSL.ref("integer_value", INTEGER)).valueOf(valueEnv()),
"unsupported distinct aggregator avg");
}

@Test
public void avg_with_missing() {
ExprValue result =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,13 @@ public void filtered_count() {
assertEquals(3, result.value());
}

@Test
public void distinct_count() {
  // tuples_with_duplicates carries integer_value 1, 1, 2, 3, i.e. three distinct values.
  ExprValue distinctCount =
      aggregation(dsl.distinctCount(DSL.ref("integer_value", INTEGER)), tuples_with_duplicates);
  assertEquals(3, distinctCount.value());
}

@Test
public void count_with_missing() {
ExprValue result = aggregation(dsl.count(DSL.ref("integer_value", INTEGER)),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,13 @@ public void filtered_max() {
assertEquals(3, result.value());
}

@Test
public void distinct_max() {
// Evaluating a distinct MAX aggregator is expected to raise ExpressionEvaluationException.
// NOTE(review): the third argument of assertThrows is the message reported when this
// assertion fails; it is not compared with the thrown exception's message. To verify the
// exception text, assert on getMessage() of the exception returned by assertThrows.
assertThrows(ExpressionEvaluationException.class,
() -> dsl.distinctMax(DSL.ref("integer_value", INTEGER)).valueOf(valueEnv()),
"unsupported distinct aggregator max");
}

@Test
public void test_max_null() {
ExprValue result =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,13 @@ public void filtered_min() {
assertEquals(2, result.value());
}

@Test
public void distinct_min() {
// Evaluating a distinct MIN aggregator is expected to raise ExpressionEvaluationException.
// NOTE(review): the third argument of assertThrows is the message reported when this
// assertion fails; it is not compared with the thrown exception's message. To verify the
// exception text, assert on getMessage() of the exception returned by assertThrows.
assertThrows(ExpressionEvaluationException.class,
() -> dsl.distinctMin(DSL.ref("integer_value", INTEGER)).valueOf(valueEnv()),
"unsupported distinct aggregator min");
}

@Test
public void test_min_null() {
ExprValue result =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,13 @@ public void filtered_sum() {
assertEquals(9, result.value());
}

@Test
public void distinct_sum() {
// Evaluating a distinct SUM aggregator is expected to raise ExpressionEvaluationException.
// NOTE(review): the third argument of assertThrows is the message reported when this
// assertion fails; it is not compared with the thrown exception's message. To verify the
// exception text, assert on getMessage() of the exception returned by assertThrows.
assertThrows(ExpressionEvaluationException.class,
() -> dsl.distinctSum(DSL.ref("integer_value", INTEGER)).valueOf(valueEnv()),
"unsupported distinct aggregator sum");
}

@Test
public void sum_with_missing() {
ExprValue result =
Expand Down
13 changes: 13 additions & 0 deletions docs/user/dql/aggregations.rst
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,19 @@ Besides regular identifiers, ``COUNT`` aggregate function also accepts arguments
2. ``COUNT(*)`` will count the number of all its input rows.
3. ``COUNT(1)`` is the same as ``COUNT(*)`` because any non-null literal will count.

DISTINCT Aggregation
--------------------

To aggregate over only the distinct values of a field, add the keyword ``DISTINCT`` before the field in the aggregate function. Currently, distinct aggregation is supported only by the ``COUNT`` aggregate function. Example::

os> SELECT COUNT(DISTINCT gender), COUNT(gender) FROM accounts;
fetched rows / total rows = 1/1
+--------------------------+-----------------+
| COUNT(DISTINCT gender) | COUNT(gender) |
|--------------------------+-----------------|
| 2 | 4 |
+--------------------------+-----------------+

HAVING Clause
=============

Expand Down
15 changes: 15 additions & 0 deletions docs/user/ppl/cmd/stats.rst
Original file line number Diff line number Diff line change
Expand Up @@ -134,3 +134,18 @@ PPL query::
| 36 | 32 | M |
+------------+------------+----------+

Example 7: Calculate the distinct count of a field
==================================================

To get the count of distinct values of a field, you can use the ``DISTINCT_COUNT`` (or ``DC``) function instead of ``COUNT``. The example calculates both the count and the distinct count of the gender field across all accounts.

PPL query::

os> source=accounts | stats count(gender), distinct_count(gender);
fetched rows / total rows = 1/1
+-----------------+--------------------------+
| count(gender) | distinct_count(gender) |
|-----------------+--------------------------|
| 4 | 2 |
+-----------------+--------------------------+

Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,5 @@ SELECT SUM(AvgTicketPrice) FROM opensearch_dashboards_sample_data_flights
SELECT MAX(AvgTicketPrice) FROM opensearch_dashboards_sample_data_flights
SELECT MAX(timestamp) FROM opensearch_dashboards_sample_data_flights
SELECT MIN(AvgTicketPrice) FROM opensearch_dashboards_sample_data_flights
SELECT MIN(timestamp) FROM opensearch_dashboards_sample_data_flights
SELECT MIN(timestamp) FROM opensearch_dashboards_sample_data_flights
SELECT COUNT(DISTINCT Origin), COUNT(DISTINCT Dest) FROM opensearch_dashboards_sample_data_flights
Loading