Patch summary (free-form preamble; text before the first "diff --git" header is not part of any hunk):
  * grouping.scala: in the doc example output, moves the row "Bob 0 5 180.0" so it follows "NULL 3 7 172.5".
  * SQLConf.scala: adds the internal boolean config "spark.sql.codegen.aggregate.map.twolevel.partialOnly"
    (ENABLE_TWOLEVEL_AGG_MAP_PARTIAL_ONLY, version 3.2.1, default true).
  * HashAggregateExec.scala: the final boolean of the visible check now also requires isEnabledForAggModes,
    which is true when every mode is Partial or PartialMerge and otherwise the negation of the new config.
NOTE(review): this patch was rebuilt from a whitespace-collapsed copy. Hunk line counts match the headers,
but leading indentation, blank context lines and the column separators in the example rows are assumed,
not observed -- confirm against the target files (or apply with whitespace tolerance) before relying on it.

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala
index 8b95ee0c4d3c..8ce0e57b6915 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala
@@ -217,9 +217,9 @@ case class Grouping(child: Expression) extends Expression with Unevaluable
     Examples:
       > SELECT name, _FUNC_(), sum(age), avg(height) FROM VALUES (2, 'Alice', 165), (5, 'Bob', 180) people(age, name, height) GROUP BY cube(name, height);
         Alice 0 2 165.0
-        Bob 0 5 180.0
         Alice 1 2 165.0
         NULL 3 7 172.5
+        Bob 0 5 180.0
         Bob 1 5 180.0
         NULL 2 2 165.0
         NULL 2 5 180.0
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 1dc4a8eaabaa..f7b9f1f466ef 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -1710,6 +1710,16 @@ object SQLConf {
       .booleanConf
       .createWithDefault(true)
 
+  val ENABLE_TWOLEVEL_AGG_MAP_PARTIAL_ONLY =
+    buildConf("spark.sql.codegen.aggregate.map.twolevel.partialOnly")
+      .internal()
+      .doc("Enable two-level aggregate hash map for partial aggregate only, " +
+        "because final aggregate might get more distinct keys compared to partial aggregate. " +
+        "Overhead of looking up 1st-level map might dominate when having a lot of distinct keys.")
+      .version("3.2.1")
+      .booleanConf
+      .createWithDefault(true)
+
   val ENABLE_VECTORIZED_HASH_MAP =
     buildConf("spark.sql.codegen.aggregate.map.vectorized.enable")
       .internal()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index da310b6e4be7..854515402860 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -667,7 +667,14 @@ case class HashAggregateExec(
     val isNotByteArrayDecimalType = bufferSchema.map(_.dataType).filter(_.isInstanceOf[DecimalType])
       .forall(!DecimalType.isByteArrayDecimalType(_))
 
-    isSupported && isNotByteArrayDecimalType
+    val isEnabledForAggModes =
+      if (modes.forall(mode => mode == Partial || mode == PartialMerge)) {
+        true
+      } else {
+        !conf.getConf(SQLConf.ENABLE_TWOLEVEL_AGG_MAP_PARTIAL_ONLY)
+      }
+
+    isSupported && isNotByteArrayDecimalType && isEnabledForAggModes
   }
 
   private def enableTwoLevelHashMap(ctx: CodegenContext): Unit = {