diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala index af3a8fe684bb3..aa2610d5f87c8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala @@ -292,13 +292,17 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { // Final aggregate val operators = expressions.map { e => val af = e.aggregateFunction - val naf = patchAggregateFunctionChildren(af) { x => - val condition = if (e.filter.isDefined) { - e.filter.map(distinctAggFilterAttrLookup.get(_)).get - } else { - None + val condition = e.filter.map(distinctAggFilterAttrLookup.get(_)).flatten + val naf = if (af.children.forall(_.foldable)) { + // If aggregateFunction's children are all foldable, we only put the first child in + // distinctAggGroups. So here we only need to rewrite the first child to + // `if (gid = ...) ...` or `if (gid = ... and condition) ...`. + val firstChild = evalWithinGroup(id, af.children.head, condition) + af.withNewChildren(firstChild +: af.children.drop(1)).asInstanceOf[AggregateFunction] + } else { + patchAggregateFunctionChildren(af) { x => + distinctAggChildAttrLookup.get(x).map(evalWithinGroup(id, _, condition)) } - distinctAggChildAttrLookup.get(x).map(evalWithinGroup(id, _, condition)) } (e, e.copy(aggregateFunction = naf, isDistinct = false, filter = None)) } diff --git a/sql/core/src/test/resources/sql-tests/inputs/count.sql b/sql/core/src/test/resources/sql-tests/inputs/count.sql index 9f9ee4a873d4f..203f04c589373 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/count.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/count.sql @@ -25,3 +25,13 @@ SELECT count(a, b), count(b, a), count(testData.*) FROM testData; SELECT count(DISTINCT a, b), count(DISTINCT b, a), count(DISTINCT *), count(DISTINCT testData.*) FROM testData; + +-- distinct count with multiple literals +SELECT count(DISTINCT 3,2); +SELECT count(DISTINCT 2), count(DISTINCT 2,3); +SELECT count(DISTINCT 2), count(DISTINCT 3,2); +SELECT count(DISTINCT a), count(DISTINCT 2,3) FROM testData; +SELECT count(DISTINCT a), count(DISTINCT 3,2) FROM testData; +SELECT count(DISTINCT a), count(DISTINCT 2), count(DISTINCT 2,3) FROM testData; +SELECT count(DISTINCT a), count(DISTINCT 2), count(DISTINCT 3,2) FROM testData; +SELECT count(distinct 0.8), percentile_approx(distinct a, 0.8) FROM testData; diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by-filter.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by-filter.sql index 24d303621faea..e4193d845f2e2 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/group-by-filter.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/group-by-filter.sql @@ -43,6 +43,14 @@ SELECT SUM(salary), COUNT(DISTINCT id), COUNT(DISTINCT id) FILTER (WHERE hiredat SELECT COUNT(DISTINCT 1) FILTER (WHERE a = 1) FROM testData; SELECT COUNT(DISTINCT id) FILTER (WHERE true) FROM emp; SELECT COUNT(DISTINCT id) FILTER (WHERE false) FROM emp; +SELECT COUNT(DISTINCT 2), COUNT(DISTINCT 2,3) FILTER (WHERE dept_id = 40) FROM emp; +SELECT COUNT(DISTINCT 2), COUNT(DISTINCT 3,2) FILTER (WHERE dept_id = 40) FROM emp; +SELECT COUNT(DISTINCT 2), COUNT(DISTINCT 2,3) FILTER (WHERE dept_id > 0) FROM emp; +SELECT COUNT(DISTINCT 2), COUNT(DISTINCT 3,2) FILTER (WHERE dept_id > 0) FROM emp; +SELECT COUNT(DISTINCT id), COUNT(DISTINCT 2,3) FILTER (WHERE dept_id = 40) FROM emp; +SELECT COUNT(DISTINCT id), COUNT(DISTINCT 3,2) FILTER (WHERE dept_id = 40) FROM emp; +SELECT COUNT(DISTINCT id), COUNT(DISTINCT 2,3) FILTER (WHERE dept_id > 0) FROM emp; +SELECT COUNT(DISTINCT id), COUNT(DISTINCT 3,2) FILTER (WHERE dept_id > 0) FROM emp; -- Aggregate with filter and non-empty GroupBy expressions. SELECT a, COUNT(b) FILTER (WHERE a >= 2) FROM testData GROUP BY a; diff --git a/sql/core/src/test/resources/sql-tests/results/count.sql.out b/sql/core/src/test/resources/sql-tests/results/count.sql.out index 68a5114bb5859..c0cdd0d697538 100644 --- a/sql/core/src/test/resources/sql-tests/results/count.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/count.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 5 +-- Number of queries: 13 -- !query @@ -53,3 +53,67 @@ FROM testData struct -- !query output 3 3 3 3 + + +-- !query +SELECT count(DISTINCT 3,2) +-- !query schema +struct +-- !query output +1 + + +-- !query +SELECT count(DISTINCT 2), count(DISTINCT 2,3) +-- !query schema +struct +-- !query output +1 1 + + +-- !query +SELECT count(DISTINCT 2), count(DISTINCT 3,2) +-- !query schema +struct +-- !query output +1 1 + + +-- !query +SELECT count(DISTINCT a), count(DISTINCT 2,3) FROM testData +-- !query schema +struct +-- !query output +2 1 + + +-- !query +SELECT count(DISTINCT a), count(DISTINCT 3,2) FROM testData +-- !query schema +struct +-- !query output +2 1 + + +-- !query +SELECT count(DISTINCT a), count(DISTINCT 2), count(DISTINCT 2,3) FROM testData +-- !query schema +struct +-- !query output +2 1 1 + + +-- !query +SELECT count(DISTINCT a), count(DISTINCT 2), count(DISTINCT 3,2) FROM testData +-- !query schema +struct +-- !query output +2 1 1 + + +-- !query +SELECT count(distinct 0.8), percentile_approx(distinct a, 0.8) FROM testData +-- !query schema +struct +-- !query output +1 2 \ No newline at end of file diff --git a/sql/core/src/test/resources/sql-tests/results/group-by-filter.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by-filter.sql.out index c349d9d84c226..89a4da116a6b3 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-by-filter.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-by-filter.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 68 +-- Number of queries: 76 -- !query @@ -150,6 +150,70 @@ struct 0 +-- !query +SELECT COUNT(DISTINCT 2), COUNT(DISTINCT 2,3) FILTER (WHERE dept_id = 40) FROM emp +-- !query schema +struct +-- !query output +1 0 + + +-- !query +SELECT COUNT(DISTINCT 2), COUNT(DISTINCT 3,2) FILTER (WHERE dept_id = 40) FROM emp +-- !query schema +struct +-- !query output +1 0 + + +-- !query +SELECT COUNT(DISTINCT 2), COUNT(DISTINCT 2,3) FILTER (WHERE dept_id > 0) FROM emp +-- !query schema +struct 0)):bigint> +-- !query output +1 1 + + +-- !query +SELECT COUNT(DISTINCT 2), COUNT(DISTINCT 3,2) FILTER (WHERE dept_id > 0) FROM emp +-- !query schema +struct 0)):bigint> +-- !query output +1 1 + + +-- !query +SELECT COUNT(DISTINCT id), COUNT(DISTINCT 2,3) FILTER (WHERE dept_id = 40) FROM emp +-- !query schema +struct +-- !query output +8 0 + + +-- !query +SELECT COUNT(DISTINCT id), COUNT(DISTINCT 3,2) FILTER (WHERE dept_id = 40) FROM emp +-- !query schema +struct +-- !query output +8 0 + + +-- !query +SELECT COUNT(DISTINCT id), COUNT(DISTINCT 2,3) FILTER (WHERE dept_id > 0) FROM emp +-- !query schema +struct 0)):bigint> +-- !query output +8 1 + + +-- !query +SELECT COUNT(DISTINCT id), COUNT(DISTINCT 3,2) FILTER (WHERE dept_id > 0) FROM emp +-- !query schema +struct 0)):bigint> +-- !query output +8 1 + + -- !query SELECT a, COUNT(b) FILTER (WHERE a >= 2) FROM testData GROUP BY a -- !query schema