diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 68ddec9fc8d0..946fc7f421ad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -1811,6 +1811,14 @@ object CodeGenerator extends Logging { def boxedType(dt: DataType): String = boxedType(javaType(dt)) + def typeName(clazz: Class[_]): String = { + if (clazz.isArray) { + typeName(clazz.getComponentType) + "[]" + } else { + clazz.getName + } + } + /** * Returns the representation of default value for a given Java Type. * @param jt the string name of the Java type diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 2d187e3c9ebe..5dc5b822919b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -299,7 +299,9 @@ case class HashAggregateExec( if (inputVars.forall(_.isDefined)) { val splitCodes = inputVars.flatten.zipWithIndex.map { case (args, i) => val doAggFunc = ctx.freshName(s"doAggregate_${aggNames(i)}") - val argList = args.map(v => s"${v.javaType.getName} ${v.variableName}").mkString(", ") + val argList = args.map { v => + s"${CodeGenerator.typeName(v.javaType)} ${v.variableName}" + }.mkString(", ") val doAggFuncName = ctx.addNewFunction(doAggFunc, s""" |private void $doAggFunc($argList) throws java.io.IOException { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index 8c7e5bf5ac1d..4a3277f5a7e4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -1028,6 +1028,21 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te } } } + + test("SPARK-29140: HashAggregateExec aggregating binary type doesn't break codegen compilation") { + val schema = new StructType().add("id", IntegerType, nullable = false) + .add("c1", BinaryType, nullable = true) + + withSQLConf( + SQLConf.CODEGEN_SPLIT_AGGREGATE_FUNC.key -> "true", + SQLConf.CODEGEN_METHOD_SPLIT_THRESHOLD.key -> "1") { + val emptyRows = spark.sparkContext.parallelize(Seq.empty[Row], 1) + val aggDf = spark.createDataFrame(emptyRows, schema) + .groupBy($"id" % 10 as "group") + .agg(countDistinct($"c1")) + checkAnswer(aggDf, Seq.empty[Row]) + } + } }