diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
index 1b8d33b94fbd2..394ba3f8bb8c2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
@@ -25,7 +25,6 @@ import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskCon
 import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Count, CountStar, Max, Min, Sum}
 import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo
 import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects}
 import org.apache.spark.sql.sources._
@@ -140,34 +139,6 @@ object JDBCRDD extends Logging {
     })
   }
 
-  def compileAggregates(
-      aggregates: Seq[AggregateFunc],
-      dialect: JdbcDialect): Option[Seq[String]] = {
-    def quote(colName: String): String = dialect.quoteIdentifier(colName)
-
-    Some(aggregates.map {
-      case min: Min =>
-        if (min.column.fieldNames.length != 1) return None
-        s"MIN(${quote(min.column.fieldNames.head)})"
-      case max: Max =>
-        if (max.column.fieldNames.length != 1) return None
-        s"MAX(${quote(max.column.fieldNames.head)})"
-      case count: Count =>
-        if (count.column.fieldNames.length != 1) return None
-        val distinct = if (count.isDistinct) "DISTINCT " else ""
-        val column = quote(count.column.fieldNames.head)
-        s"COUNT($distinct$column)"
-      case sum: Sum =>
-        if (sum.column.fieldNames.length != 1) return None
-        val distinct = if (sum.isDistinct) "DISTINCT " else ""
-        val column = quote(sum.column.fieldNames.head)
-        s"SUM($distinct$column)"
-      case _: CountStar =>
-        s"COUNT(*)"
-      case _ => return None
-    })
-  }
-
   /**
    * Build and return JDBCRDD from the given information.
    *
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala
index 7605b03f49ea5..d3c141ed53c5c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala
@@ -72,8 +72,8 @@ case class JDBCScanBuilder(
     if (!jdbcOptions.pushDownAggregate) return false
 
     val dialect = JdbcDialects.get(jdbcOptions.url)
-    val compiledAgg = JDBCRDD.compileAggregates(aggregation.aggregateExpressions, dialect)
-    if (compiledAgg.isEmpty) return false
+    val compiledAggs = aggregation.aggregateExpressions.flatMap(dialect.compileAggregate(_))
+    if (compiledAggs.length != aggregation.aggregateExpressions.length) return false
 
     val groupByCols = aggregation.groupByColumns.map { col =>
       if (col.fieldNames.length != 1) return false
@@ -84,7 +84,7 @@ case class JDBCScanBuilder(
     // e.g. "DEPT","NAME",MAX("SALARY"),MIN("BONUS") =>
     // SELECT "DEPT","NAME",MAX("SALARY"),MIN("BONUS") FROM "test"."employee"
     // GROUP BY "DEPT", "NAME"
-    val selectList = groupByCols ++ compiledAgg.get
+    val selectList = groupByCols ++ compiledAggs
     val groupByClause = if (groupByCols.isEmpty) {
       ""
     } else {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
index c7db771c73a68..fcc2be2d16d88 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
@@ -33,6 +33,7 @@ import org.apache.spark.sql.connector.catalog.TableChange
 import org.apache.spark.sql.connector.catalog.TableChange._
 import org.apache.spark.sql.connector.catalog.index.TableIndex
 import org.apache.spark.sql.connector.expressions.NamedReference
+import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Count, CountStar, Max, Min, Sum}
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils}
 import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo
@@ -193,6 +194,36 @@ abstract class JdbcDialect extends Serializable with Logging{
     case _ => value
   }
 
+  /**
+   * Converts an aggregate function to a String representing a SQL expression.
+   * @param aggFunction The aggregate function to be converted.
+   * @return Converted value.
+   */
+  @Since("3.3.0")
+  def compileAggregate(aggFunction: AggregateFunc): Option[String] = {
+    aggFunction match {
+      case min: Min =>
+        if (min.column.fieldNames.length != 1) return None
+        Some(s"MIN(${quoteIdentifier(min.column.fieldNames.head)})")
+      case max: Max =>
+        if (max.column.fieldNames.length != 1) return None
+        Some(s"MAX(${quoteIdentifier(max.column.fieldNames.head)})")
+      case count: Count =>
+        if (count.column.fieldNames.length != 1) return None
+        val distinct = if (count.isDistinct) "DISTINCT " else ""
+        val column = quoteIdentifier(count.column.fieldNames.head)
+        Some(s"COUNT($distinct$column)")
+      case sum: Sum =>
+        if (sum.column.fieldNames.length != 1) return None
+        val distinct = if (sum.isDistinct) "DISTINCT " else ""
+        val column = quoteIdentifier(sum.column.fieldNames.head)
+        Some(s"SUM($distinct$column)")
+      case _: CountStar =>
+        Some(s"COUNT(*)")
+      case _ => None
+    }
+  }
+
   /**
    * Return Some[true] iff `TRUNCATE TABLE` causes cascading default.
    * Some[true] : TRUNCATE TABLE causes cascading.